"""Alignment tests for the local Transformer1D diffusion head.

Compares `roboimi/vla/models/heads/transformer1d.py` against the upstream
diffusion_policy `TransformerForDiffusion` implementation when an external
checkout is available next to the repo; comparison tests skip otherwise.
"""

import contextlib
import importlib.util
import inspect
import sys
import types
import unittest
import warnings
from pathlib import Path

import torch

_REPO_ROOT = Path(__file__).resolve().parents[1]
_LOCAL_MODULE_PATH = _REPO_ROOT / 'roboimi/vla/models/heads/transformer1d.py'
_EXTERNAL_CHECKOUT_ROOT = _REPO_ROOT.parent / 'diffusion_policy'
_TRANSFORMER_WARNING_MESSAGE = (
    r'enable_nested_tensor is True, but self.use_nested_tensor is False '
    r'because encoder_layer\.norm_first was True'
)
# Sentinel distinguishing "absent from sys.modules" from "present but None".
_MISSING = object()


def _load_module_from_path(name: str, path: Path, *, register: bool = False):
    """Import a module directly from a file path, optionally registering it."""
    spec = importlib.util.spec_from_file_location(name, path)
    assert spec is not None and spec.loader is not None
    module = importlib.util.module_from_spec(spec)
    if register:
        sys.modules[name] = module
    spec.loader.exec_module(module)
    return module


def _resolve_external_module_paths(external_checkout_root: Path):
    """Return paths to the external modules, or None if the checkout is incomplete."""
    diffusion_policy_root = external_checkout_root / 'diffusion_policy'
    paths = {
        'positional_embedding': diffusion_policy_root / 'model/diffusion/positional_embedding.py',
        'module_attr_mixin': diffusion_policy_root / 'model/common/module_attr_mixin.py',
        'transformer_for_diffusion': diffusion_policy_root / 'model/diffusion/transformer_for_diffusion.py',
    }
    if not all(path.exists() for path in paths.values()):
        return None
    return paths


@contextlib.contextmanager
def _temporary_registered_modules():
    """Yield a loader that registers modules (and synthetic parent packages)
    in sys.modules, restoring the previous state on exit."""
    previous_modules = {}

    def remember(name: str) -> None:
        if name not in previous_modules:
            previous_modules[name] = sys.modules.get(name, _MISSING)

    def ensure_package(name: str) -> None:
        if not name or name in sys.modules:
            return
        remember(name)
        package = types.ModuleType(name)
        package.__path__ = []
        sys.modules[name] = package

    def load(name: str, path: Path):
        package_parts = name.split('.')[:-1]
        for idx in range(1, len(package_parts) + 1):
            ensure_package('.'.join(package_parts[:idx]))
        remember(name)
        return _load_module_from_path(name, path, register=True)

    try:
        yield load
    finally:
        for name, previous in reversed(list(previous_modules.items())):
            if previous is _MISSING:
                sys.modules.pop(name, None)
            else:
                sys.modules[name] = previous


@contextlib.contextmanager
def _suppress_nested_tensor_warning():
    """Silence the benign nn.TransformerEncoder nested-tensor UserWarning."""
    with warnings.catch_warnings():
        warnings.filterwarnings(
            'ignore',
            message=_TRANSFORMER_WARNING_MESSAGE,
            category=UserWarning,
            module=r'torch\.nn\.modules\.transformer',
        )
        yield


def _load_local_module():
    return _load_module_from_path('local_transformer1d_alignment', _LOCAL_MODULE_PATH)


class Transformer1DExternalAlignmentTest(unittest.TestCase):

    def _load_transformer_classes_or_skip(self):
        external_paths = _resolve_external_module_paths(_EXTERNAL_CHECKOUT_ROOT)
        if external_paths is None:
            self.skipTest(f'external diffusion_policy checkout unavailable under {_EXTERNAL_CHECKOUT_ROOT}')
        local_module = _load_local_module()
        with _temporary_registered_modules() as load_external:
            load_external(
                'diffusion_policy.model.diffusion.positional_embedding',
                external_paths['positional_embedding'],
            )
            load_external(
                'diffusion_policy.model.common.module_attr_mixin',
                external_paths['module_attr_mixin'],
            )
            external_module = load_external(
                'diffusion_policy.model.diffusion.transformer_for_diffusion',
                external_paths['transformer_for_diffusion'],
            )
            return (
                local_module.Transformer1D,
                local_module.create_transformer1d,
                external_module.TransformerForDiffusion,
            )

    def _optim_group_names(self, model, groups):
        names_by_param = {id(param): name for name, param in model.named_parameters()}
        return [
            {names_by_param[id(param)] for param in group['params']}
            for group in groups
        ]

    def test_missing_external_checkout_resolution_returns_none(self):
        self.assertIsNone(
            _resolve_external_module_paths(_REPO_ROOT / '__missing_diffusion_policy_checkout__')
        )

    def test_external_loader_restores_injected_sys_modules(self):
        external_paths = _resolve_external_module_paths(_EXTERNAL_CHECKOUT_ROOT)
        if external_paths is None:
            self.skipTest(f'external diffusion_policy checkout unavailable under {_EXTERNAL_CHECKOUT_ROOT}')
        watched_names = [
            'diffusion_policy',
            'diffusion_policy.model',
            'diffusion_policy.model.common',
            'diffusion_policy.model.common.module_attr_mixin',
            'diffusion_policy.model.diffusion',
            'diffusion_policy.model.diffusion.positional_embedding',
            'diffusion_policy.model.diffusion.transformer_for_diffusion',
        ]
        before = {name: sys.modules.get(name, _MISSING) for name in watched_names}
        with _temporary_registered_modules() as load_external:
            load_external(
                'diffusion_policy.model.diffusion.positional_embedding',
                external_paths['positional_embedding'],
            )
            load_external(
                'diffusion_policy.model.common.module_attr_mixin',
                external_paths['module_attr_mixin'],
            )
            load_external(
                'diffusion_policy.model.diffusion.transformer_for_diffusion',
                external_paths['transformer_for_diffusion'],
            )
        after = {name: sys.modules.get(name, _MISSING) for name in watched_names}
        self.assertEqual(after, before)

    def test_transformer1d_preserves_local_direct_call_defaults(self):
        local_module = _load_local_module()
        ctor = inspect.signature(local_module.Transformer1D.__init__).parameters
        helper = inspect.signature(local_module.create_transformer1d).parameters
        self.assertEqual(ctor['n_layer'].default, 8)
        self.assertEqual(ctor['n_head'].default, 8)
        self.assertEqual(ctor['n_emb'].default, 256)
        self.assertEqual(helper['n_layer'].default, 8)
        self.assertEqual(helper['n_head'].default, 8)
        self.assertEqual(helper['n_emb'].default, 256)

    def test_time_as_cond_false_token_accounting_matches_external(self):
        Transformer1D, _, TransformerForDiffusion = self._load_transformer_classes_or_skip()
        self.assertIn('time_as_cond', inspect.signature(Transformer1D.__init__).parameters)
        config = dict(
            input_dim=4,
            output_dim=4,
            horizon=6,
            n_obs_steps=3,
            cond_dim=0,
            n_layer=2,
            n_head=2,
            n_emb=8,
            p_drop_emb=0.0,
            p_drop_attn=0.0,
            causal_attn=False,
            time_as_cond=False,
            obs_as_cond=False,
            n_cond_layers=0,
        )
        torch.manual_seed(5)
        with _suppress_nested_tensor_warning():
            external_model = TransformerForDiffusion(**config)
            local_model = Transformer1D(**config)
        external_model.eval()
        local_model.eval()
        self.assertEqual(local_model.T, external_model.T)
        self.assertEqual(local_model.T_cond, external_model.T_cond)
        self.assertEqual(local_model.time_as_cond, external_model.time_as_cond)
        self.assertEqual(local_model.obs_as_cond, external_model.obs_as_cond)
        self.assertEqual(local_model.encoder_only, external_model.encoder_only)

    def test_nocausal_state_dict_forward_and_optim_groups_match_external(self):
        Transformer1D, _, TransformerForDiffusion = self._load_transformer_classes_or_skip()
        config = dict(
            input_dim=4,
            output_dim=4,
            horizon=6,
            n_obs_steps=3,
            cond_dim=5,
            n_layer=2,
            n_head=2,
            n_emb=8,
            p_drop_emb=0.0,
            p_drop_attn=0.0,
            causal_attn=False,
            obs_as_cond=True,
            n_cond_layers=1,
        )
        torch.manual_seed(7)
        with _suppress_nested_tensor_warning():
            external_model = TransformerForDiffusion(**config)
            local_model = Transformer1D(**config)
        external_model.eval()
        local_model.eval()
        external_state_dict = external_model.state_dict()
        self.assertEqual(set(local_model.state_dict().keys()), set(external_state_dict.keys()))
        local_model.load_state_dict(external_state_dict, strict=True)
        batch_size = 2
        sample = torch.randn(batch_size, config['horizon'], config['input_dim'])
        cond = torch.randn(batch_size, config['n_obs_steps'], config['cond_dim'])
        timestep = torch.tensor([11, 17], dtype=torch.long)
        with torch.no_grad():
            external_out = external_model(sample=sample, timestep=timestep, cond=cond)
            local_out = local_model(sample=sample, timestep=timestep, cond=cond)
        self.assertEqual(local_out.shape, (batch_size, config['horizon'], config['output_dim']))
        self.assertEqual(local_out.shape, external_out.shape)
        self.assertTrue(torch.allclose(local_out, external_out, atol=1e-6, rtol=1e-5))
        weight_decay = 0.123
        external_groups = external_model.get_optim_groups(weight_decay=weight_decay)
        local_groups = local_model.get_optim_groups(weight_decay=weight_decay)
        self.assertEqual(len(local_groups), len(external_groups))
        self.assertEqual([group['weight_decay'] for group in local_groups], [weight_decay, 0.0])
        self.assertEqual(
            self._optim_group_names(local_model, local_groups),
            self._optim_group_names(external_model, external_groups),
        )


if __name__ == '__main__':
    unittest.main()