"""Alignment tests for the local IMFTransformer1D head.

Loads the reference ``IMFTransformerForDiffusion`` implementation from an
external ``diffusion_policy`` git checkout at a pinned commit (via
``git show``, without touching the working tree) and checks that the local
module preserves its constructor defaults, state-dict layout, forward
outputs, and optimizer parameter grouping. Tests are skipped when the
external checkout or pinned commit is unavailable.
"""
import contextlib
import importlib
import inspect
import subprocess
import sys
import types
import unittest
from pathlib import Path

import torch

# Make the repository root importable so the local module under test resolves.
_REPO_ROOT = Path(__file__).resolve().parents[1]
if str(_REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(_REPO_ROOT))

# Pinned external diffusion_policy commit the local implementation must match.
_EXTERNAL_COMMIT = '185ed659'
# Dotted path of the local module under test.
_LOCAL_MODULE_NAME = 'roboimi.vla.models.heads.imf_transformer1d'
# Sentinel marking "name was absent from sys.modules" during save/restore.
_MISSING = object()


def _find_external_checkout_root() -> Path | None:
    """Return the nearest ``diffusion_policy`` git checkout, or ``None``.

    Walks from the repo root upward through its ancestors, looking for a
    sibling/child directory named ``diffusion_policy`` that contains a
    ``.git`` entry (i.e. a real checkout, not just a stray folder).
    """
    for ancestor in (_REPO_ROOT, *_REPO_ROOT.parents):
        candidate = ancestor / 'diffusion_policy'
        if (candidate / '.git').exists():
            return candidate
    return None


_EXTERNAL_CHECKOUT_ROOT = _find_external_checkout_root()

# Module name -> repo-relative file path, in dependency order: earlier
# entries are imported by later ones, so they must be loaded first.
_EXTERNAL_MODULE_PATHS = {
    'diffusion_policy.model.common.module_attr_mixin': 'diffusion_policy/model/common/module_attr_mixin.py',
    'diffusion_policy.model.diffusion.positional_embedding': 'diffusion_policy/model/diffusion/positional_embedding.py',
    'diffusion_policy.model.diffusion.attnres_transformer_components': 'diffusion_policy/model/diffusion/attnres_transformer_components.py',
    'diffusion_policy.model.diffusion.imf_transformer_for_diffusion': 'diffusion_policy/model/diffusion/imf_transformer_for_diffusion.py',
}


@contextlib.contextmanager
def _temporary_registered_modules():
    """Yield a loader that registers modules in ``sys.modules`` temporarily.

    The yielded ``load(name, source, origin)`` callable compiles *source*
    into a fresh module registered under *name*, synthesizing any missing
    parent packages. On exit, every name touched is restored to its prior
    state (re-bound to the previous module, or removed if it was absent),
    so the test leaves the interpreter's module table unchanged.
    """
    # Maps each touched name to its prior sys.modules entry (or _MISSING).
    previous_modules = {}

    def remember(name: str) -> None:
        # Record only the *first* observation so nested loads don't
        # overwrite the true original entry.
        if name not in previous_modules:
            previous_modules[name] = sys.modules.get(name, _MISSING)

    def ensure_package(name: str) -> None:
        # Synthesize an empty package module so dotted names resolve.
        if not name or name in sys.modules:
            return
        remember(name)
        package = types.ModuleType(name)
        package.__path__ = []
        sys.modules[name] = package

    def load(name: str, source: str, origin: str):
        # Create every ancestor package before registering the leaf module.
        package_parts = name.split('.')[:-1]
        for idx in range(1, len(package_parts) + 1):
            ensure_package('.'.join(package_parts[:idx]))
        remember(name)
        module = types.ModuleType(name)
        module.__file__ = origin
        module.__package__ = name.rpartition('.')[0]
        # Register BEFORE exec so intra-package imports in `source` resolve.
        sys.modules[name] = module
        exec(compile(source, origin, 'exec'), module.__dict__)
        return module

    try:
        yield load
    finally:
        # Restore in reverse registration order: leaf modules are undone
        # before the packages that were synthesized for them.
        for name, previous in reversed(list(previous_modules.items())):
            if previous is _MISSING:
                sys.modules.pop(name, None)
            else:
                sys.modules[name] = previous


def _git_show(repo_root: Path, commit: str, relative_path: str) -> str:
    """Return the contents of ``relative_path`` at ``commit`` via ``git show``.

    Raises ``subprocess.CalledProcessError`` (check=True) when the commit
    or path does not exist in the repository at ``repo_root``.
    """
    result = subprocess.run(
        ['git', '-C', str(repo_root), 'show', f'{commit}:{relative_path}'],
        check=True,
        capture_output=True,
        text=True,
    )
    return result.stdout


@contextlib.contextmanager
def _load_external_module_or_skip(test_case: unittest.TestCase):
    """Yield the external reference module, or skip ``test_case``.

    Skips when no external checkout was found, or when the pinned commit
    (or any required file) cannot be read from it. Modules are registered
    only for the duration of the ``with`` block.
    """
    if _EXTERNAL_CHECKOUT_ROOT is None:
        test_case.skipTest('external diffusion_policy checkout unavailable')
    try:
        # Fetch all sources up front so a missing file skips before any
        # module is registered.
        sources = {
            name: _git_show(_EXTERNAL_CHECKOUT_ROOT, _EXTERNAL_COMMIT, relative_path)
            for name, relative_path in _EXTERNAL_MODULE_PATHS.items()
        }
    except subprocess.CalledProcessError as exc:
        test_case.skipTest(
            f'external diffusion_policy commit {_EXTERNAL_COMMIT} is unavailable: {exc.stderr.strip() or exc}'
        )
    with _temporary_registered_modules() as load_external:
        # Dict order is dependency order (see _EXTERNAL_MODULE_PATHS).
        for name, relative_path in _EXTERNAL_MODULE_PATHS.items():
            load_external(
                name,
                sources[name],
                origin=f'{_EXTERNAL_CHECKOUT_ROOT}:{_EXTERNAL_COMMIT}:{relative_path}',
            )
        yield sys.modules['diffusion_policy.model.diffusion.imf_transformer_for_diffusion']


def _load_local_module():
    """Import the local module under test fresh, bypassing any cached copy."""
    importlib.invalidate_caches()
    sys.modules.pop(_LOCAL_MODULE_NAME, None)
    return importlib.import_module(_LOCAL_MODULE_NAME)


class IMFTransformer1DExternalAlignmentTest(unittest.TestCase):
    """Checks that the local IMFTransformer1D matches the external reference."""

    def _optim_group_names(self, model, groups):
        """Map each optimizer group's params back to their qualified names.

        Returns one ``set`` of parameter names per group, enabling an
        order-insensitive comparison of group membership across models.
        """
        names_by_param = {id(param): name for name, param in model.named_parameters()}
        return [
            {names_by_param[id(param)] for param in group['params']}
            for group in groups
        ]

    def test_local_defaults_preserve_supported_attnres_config(self):
        # Pin the constructor defaults the rest of the codebase relies on.
        local_module = _load_local_module()
        ctor = inspect.signature(local_module.IMFTransformer1D.__init__).parameters
        self.assertEqual(ctor['backbone_type'].default, 'attnres_full')
        self.assertEqual(ctor['n_head'].default, 1)
        self.assertEqual(ctor['n_kv_head'].default, 1)
        self.assertEqual(ctor['n_cond_layers'].default, 0)
        self.assertTrue(ctor['time_as_cond'].default)
        self.assertFalse(ctor['causal_attn'].default)

    def test_attnres_full_state_dict_forward_and_optim_groups_match_external(self):
        """End-to-end parity: state dict keys, forward outputs, optim groups."""
        local_module = _load_local_module()
        with _load_external_module_or_skip(self) as external_module:
            # Small 'attnres_full' configuration shared by both models.
            config = dict(
                input_dim=4,
                output_dim=4,
                horizon=6,
                n_obs_steps=3,
                cond_dim=5,
                n_layer=2,
                n_head=1,
                n_emb=16,
                p_drop_emb=0.0,
                p_drop_attn=0.0,
                causal_attn=False,
                time_as_cond=True,
                n_cond_layers=0,
                backbone_type='attnres_full',
                n_kv_head=1,
            )
            torch.manual_seed(7)
            external_model = external_module.IMFTransformerForDiffusion(**config)
            local_model = local_module.IMFTransformer1D(**config)
            external_model.eval()
            local_model.eval()
            # The local model must accept the external state dict verbatim.
            external_state_dict = external_model.state_dict()
            self.assertEqual(set(local_model.state_dict().keys()), set(external_state_dict.keys()))
            local_model.load_state_dict(external_state_dict, strict=True)
            batch_size = 2
            sample = torch.randn(batch_size, config['horizon'], config['input_dim'])
            r = torch.tensor([0.1, 0.4], dtype=torch.float32)
            t = torch.tensor([0.7, 0.9], dtype=torch.float32)
            cond = torch.randn(batch_size, config['n_obs_steps'], config['cond_dim'])
            # With identical weights, forward outputs must match numerically.
            with torch.no_grad():
                external_out = external_model(sample=sample, r=r, t=t, cond=cond)
                local_out = local_model(sample=sample, r=r, t=t, cond=cond)
            self.assertEqual(local_out.shape, (batch_size, config['horizon'], config['output_dim']))
            self.assertEqual(local_out.shape, external_out.shape)
            self.assertTrue(torch.allclose(local_out, external_out, atol=1e-6, rtol=1e-5))
            # Optimizer groups: [decayed, non-decayed] with identical membership.
            weight_decay = 0.123
            external_groups = external_model.get_optim_groups(weight_decay=weight_decay)
            local_groups = local_model.get_optim_groups(weight_decay=weight_decay)
            self.assertEqual(len(local_groups), len(external_groups))
            self.assertEqual([group['weight_decay'] for group in local_groups], [weight_decay, 0.0])
            self.assertEqual(
                self._optim_group_names(local_model, local_groups),
                self._optim_group_names(external_model, external_groups),
            )


if __name__ == '__main__':
    unittest.main()