"""Alignment tests: verify the local Transformer1D head against the external
diffusion_policy reference implementation (state dict, forward pass, and
optimizer-group parity)."""
import contextlib
|
|
import importlib.util
|
|
import inspect
|
|
import sys
|
|
import types
|
|
import unittest
|
|
import warnings
|
|
from pathlib import Path
|
|
|
|
import torch
|
|
|
|
|
|
# Repository root, assuming this test file lives one directory below it.
_REPO_ROOT = Path(__file__).resolve().parents[1]
# Local implementation under test.
_LOCAL_MODULE_PATH = _REPO_ROOT / 'roboimi/vla/models/heads/transformer1d.py'
# Optional sibling checkout of the upstream diffusion_policy project;
# tests that need it skip when it is absent.
_EXTERNAL_CHECKOUT_ROOT = _REPO_ROOT.parent / 'diffusion_policy'
# Regex for the UserWarning torch's TransformerEncoder emits when
# norm_first encoder layers disable nested-tensor optimization.
_TRANSFORMER_WARNING_MESSAGE = (
    r'enable_nested_tensor is True, but self.use_nested_tensor is False '
    r'because encoder_layer\.norm_first was True'
)
# Sentinel distinguishing "name absent from sys.modules" from a stored None.
_MISSING = object()
|
|
|
|
|
|
def _load_module_from_path(name: str, path: Path, *, register: bool = False):
|
|
spec = importlib.util.spec_from_file_location(name, path)
|
|
module = importlib.util.module_from_spec(spec)
|
|
assert spec.loader is not None
|
|
if register:
|
|
sys.modules[name] = module
|
|
spec.loader.exec_module(module)
|
|
return module
|
|
|
|
|
|
def _resolve_external_module_paths(external_checkout_root: Path):
|
|
diffusion_policy_root = external_checkout_root / 'diffusion_policy'
|
|
paths = {
|
|
'positional_embedding': diffusion_policy_root / 'model/diffusion/positional_embedding.py',
|
|
'module_attr_mixin': diffusion_policy_root / 'model/common/module_attr_mixin.py',
|
|
'transformer_for_diffusion': diffusion_policy_root / 'model/diffusion/transformer_for_diffusion.py',
|
|
}
|
|
if not all(path.exists() for path in paths.values()):
|
|
return None
|
|
return paths
|
|
|
|
|
|
@contextlib.contextmanager
def _temporary_registered_modules():
    """Yield a loader that registers modules in ``sys.modules`` temporarily.

    The yielded ``load(name, path)`` callable executes a module file under a
    dotted name, synthesizing any missing parent packages on the fly. Every
    ``sys.modules`` entry touched (leaf modules and synthetic packages alike)
    is restored to its pre-context state on exit — or removed entirely if it
    did not exist before.
    """
    previous_modules = {}

    def remember(name: str) -> None:
        # Record only the FIRST observed state for each name, so repeated
        # loads never overwrite the true pre-context value.
        if name not in previous_modules:
            previous_modules[name] = sys.modules.get(name, _MISSING)

    def ensure_package(name: str) -> None:
        # Fabricate an empty package so dotted submodule registration works
        # even without a real package on disk.
        if not name or name in sys.modules:
            return
        remember(name)
        package = types.ModuleType(name)
        # A __path__ attribute marks the module as a package, which the
        # import system requires before submodules can be attached.
        package.__path__ = []
        sys.modules[name] = package

    def load(name: str, path: Path):
        # Register every ancestor package (outermost first) before the leaf.
        package_parts = name.split('.')[:-1]
        for idx in range(1, len(package_parts) + 1):
            ensure_package('.'.join(package_parts[:idx]))

        remember(name)
        return _load_module_from_path(name, path, register=True)

    try:
        yield load
    finally:
        # Restore in reverse insertion order: leaves before their parents.
        for name, previous in reversed(list(previous_modules.items())):
            if previous is _MISSING:
                sys.modules.pop(name, None)
            else:
                sys.modules[name] = previous
|
|
|
|
|
|
@contextlib.contextmanager
def _suppress_nested_tensor_warning():
    """Silence torch's nested-tensor UserWarning during model construction.

    The filter is scoped to this context via ``catch_warnings`` and only
    matches the specific message emitted by ``torch.nn.modules.transformer``.
    """
    with warnings.catch_warnings():
        suppression = dict(
            message=_TRANSFORMER_WARNING_MESSAGE,
            category=UserWarning,
            module=r'torch\.nn\.modules\.transformer',
        )
        warnings.filterwarnings('ignore', **suppression)
        yield
|
|
|
|
|
|
def _load_local_module():
    """Load the local transformer1d implementation without registering it."""
    return _load_module_from_path(
        'local_transformer1d_alignment',
        _LOCAL_MODULE_PATH,
    )
|
|
|
|
|
|
class Transformer1DExternalAlignmentTest(unittest.TestCase):
    """Verify the local Transformer1D head matches the external reference.

    Tests that require the external diffusion_policy checkout skip when it is
    not present next to the repository root.
    """

    def _load_transformer_classes_or_skip(self):
        """Load the local and external transformer classes, or skip.

        Returns a tuple ``(Transformer1D, create_transformer1d,
        TransformerForDiffusion)``; skips the running test when the external
        checkout is unavailable.
        """
        external_paths = _resolve_external_module_paths(_EXTERNAL_CHECKOUT_ROOT)
        if external_paths is None:
            self.skipTest(f'external diffusion_policy checkout unavailable under {_EXTERNAL_CHECKOUT_ROOT}')

        local_module = _load_local_module()
        with _temporary_registered_modules() as load_external:
            # Dependencies must be registered before transformer_for_diffusion,
            # which imports them by these dotted names at execution time.
            load_external(
                'diffusion_policy.model.diffusion.positional_embedding',
                external_paths['positional_embedding'],
            )
            load_external(
                'diffusion_policy.model.common.module_attr_mixin',
                external_paths['module_attr_mixin'],
            )
            external_module = load_external(
                'diffusion_policy.model.diffusion.transformer_for_diffusion',
                external_paths['transformer_for_diffusion'],
            )

        return local_module.Transformer1D, local_module.create_transformer1d, external_module.TransformerForDiffusion

    def _optim_group_names(self, model, groups):
        """Map each optimizer param group to the set of its parameter names."""
        names_by_param = {id(param): name for name, param in model.named_parameters()}
        return [
            {names_by_param[id(param)] for param in group['params']}
            for group in groups
        ]

    def test_missing_external_checkout_resolution_returns_none(self):
        """A nonexistent checkout directory resolves to None (no exception)."""
        self.assertIsNone(_resolve_external_module_paths(_REPO_ROOT / '__missing_diffusion_policy_checkout__'))

    def test_external_loader_restores_injected_sys_modules(self):
        """Loading external modules must leave sys.modules exactly as found."""
        external_paths = _resolve_external_module_paths(_EXTERNAL_CHECKOUT_ROOT)
        if external_paths is None:
            self.skipTest(f'external diffusion_policy checkout unavailable under {_EXTERNAL_CHECKOUT_ROOT}')

        # Every name the loader may inject, including synthetic parent packages.
        watched_names = [
            'diffusion_policy',
            'diffusion_policy.model',
            'diffusion_policy.model.common',
            'diffusion_policy.model.common.module_attr_mixin',
            'diffusion_policy.model.diffusion',
            'diffusion_policy.model.diffusion.positional_embedding',
            'diffusion_policy.model.diffusion.transformer_for_diffusion',
        ]
        before = {name: sys.modules.get(name, _MISSING) for name in watched_names}

        with _temporary_registered_modules() as load_external:
            load_external(
                'diffusion_policy.model.diffusion.positional_embedding',
                external_paths['positional_embedding'],
            )
            load_external(
                'diffusion_policy.model.common.module_attr_mixin',
                external_paths['module_attr_mixin'],
            )
            load_external(
                'diffusion_policy.model.diffusion.transformer_for_diffusion',
                external_paths['transformer_for_diffusion'],
            )

        after = {name: sys.modules.get(name, _MISSING) for name in watched_names}
        self.assertEqual(after, before)

    def test_transformer1d_preserves_local_direct_call_defaults(self):
        """The local class and helper keep their documented default sizes."""
        local_module = _load_local_module()
        ctor = inspect.signature(local_module.Transformer1D.__init__).parameters
        helper = inspect.signature(local_module.create_transformer1d).parameters

        self.assertEqual(ctor['n_layer'].default, 8)
        self.assertEqual(ctor['n_head'].default, 8)
        self.assertEqual(ctor['n_emb'].default, 256)
        self.assertEqual(helper['n_layer'].default, 8)
        self.assertEqual(helper['n_head'].default, 8)
        self.assertEqual(helper['n_emb'].default, 256)

    def test_time_as_cond_false_token_accounting_matches_external(self):
        """With time_as_cond=False, token/cond bookkeeping matches external."""
        Transformer1D, _, TransformerForDiffusion = self._load_transformer_classes_or_skip()
        self.assertIn('time_as_cond', inspect.signature(Transformer1D.__init__).parameters)

        config = dict(
            input_dim=4,
            output_dim=4,
            horizon=6,
            n_obs_steps=3,
            cond_dim=0,
            n_layer=2,
            n_head=2,
            n_emb=8,
            p_drop_emb=0.0,
            p_drop_attn=0.0,
            causal_attn=False,
            time_as_cond=False,
            obs_as_cond=False,
            n_cond_layers=0,
        )

        torch.manual_seed(5)
        with _suppress_nested_tensor_warning():
            external_model = TransformerForDiffusion(**config)
            local_model = Transformer1D(**config)
        external_model.eval()
        local_model.eval()

        # Compare the derived bookkeeping attributes, not the weights.
        self.assertEqual(local_model.T, external_model.T)
        self.assertEqual(local_model.T_cond, external_model.T_cond)
        self.assertEqual(local_model.time_as_cond, external_model.time_as_cond)
        self.assertEqual(local_model.obs_as_cond, external_model.obs_as_cond)
        self.assertEqual(local_model.encoder_only, external_model.encoder_only)

    def test_nocausal_state_dict_forward_and_optim_groups_match_external(self):
        """State dict keys, forward output, and optim groups match external."""
        Transformer1D, _, TransformerForDiffusion = self._load_transformer_classes_or_skip()
        config = dict(
            input_dim=4,
            output_dim=4,
            horizon=6,
            n_obs_steps=3,
            cond_dim=5,
            n_layer=2,
            n_head=2,
            n_emb=8,
            p_drop_emb=0.0,
            p_drop_attn=0.0,
            causal_attn=False,
            obs_as_cond=True,
            n_cond_layers=1,
        )

        torch.manual_seed(7)
        with _suppress_nested_tensor_warning():
            external_model = TransformerForDiffusion(**config)
            local_model = Transformer1D(**config)
        external_model.eval()
        local_model.eval()

        # Loading the external weights makes the forward comparison exact.
        external_state_dict = external_model.state_dict()
        self.assertEqual(set(local_model.state_dict().keys()), set(external_state_dict.keys()))
        local_model.load_state_dict(external_state_dict, strict=True)

        batch_size = 2
        sample = torch.randn(batch_size, config['horizon'], config['input_dim'])
        cond = torch.randn(batch_size, config['n_obs_steps'], config['cond_dim'])
        timestep = torch.tensor([11, 17], dtype=torch.long)

        with torch.no_grad():
            external_out = external_model(sample=sample, timestep=timestep, cond=cond)
            local_out = local_model(sample=sample, timestep=timestep, cond=cond)

        self.assertEqual(local_out.shape, (batch_size, config['horizon'], config['output_dim']))
        self.assertEqual(local_out.shape, external_out.shape)
        self.assertTrue(torch.allclose(local_out, external_out, atol=1e-6, rtol=1e-5))

        # Optim groups: same partition of parameters into decay / no-decay.
        weight_decay = 0.123
        external_groups = external_model.get_optim_groups(weight_decay=weight_decay)
        local_groups = local_model.get_optim_groups(weight_decay=weight_decay)

        self.assertEqual(len(local_groups), len(external_groups))
        self.assertEqual([group['weight_decay'] for group in local_groups], [weight_decay, 0.0])
        self.assertEqual(
            self._optim_group_names(local_model, local_groups),
            self._optim_group_names(external_model, external_groups),
        )
|
|
|
|
|
|
# Allow running this test module directly: `python <this file>`.
if __name__ == '__main__':
    unittest.main()