# Files
# roboimi/tests/test_imf_transformer1d_external_alignment.py
# 2026-04-01 23:35:31 +08:00
#
# 197 lines
# 7.2 KiB
# Python
import contextlib
import importlib
import inspect
import subprocess
import sys
import types
import unittest
from pathlib import Path
import torch
# Make the repository root importable no matter how the test is invoked.
_REPO_ROOT = Path(__file__).resolve().parents[1]
_repo_root_str = str(_REPO_ROOT)
if _repo_root_str not in sys.path:
    sys.path.insert(0, _repo_root_str)

# Pinned commit of the external diffusion_policy checkout we align against.
_EXTERNAL_COMMIT = '185ed659'
# Dotted path of the local module under test.
_LOCAL_MODULE_NAME = 'roboimi.vla.models.heads.imf_transformer1d'
# Sentinel distinguishing "absent from sys.modules" from a stored None entry.
_MISSING = object()
def _find_external_checkout_root() -> Path | None:
    """Locate a 'diffusion_policy' git checkout near this repository.

    Walks from the repository root upward through every ancestor directory
    and returns the first sibling ``diffusion_policy`` directory that is a
    git checkout (contains ``.git``), or ``None`` when none exists.
    """
    for base in (_REPO_ROOT, *_REPO_ROOT.parents):
        checkout = base / 'diffusion_policy'
        if (checkout / '.git').exists():
            return checkout
    return None
# Resolved once at import time; None when no external checkout is available
# (tests that need it will then skip).
_EXTERNAL_CHECKOUT_ROOT = _find_external_checkout_root()
# External module dotted names mapped to their file paths inside the checkout.
# Listed so that each module's prerequisites appear before it; they are
# loaded in this order.
_EXTERNAL_MODULE_PATHS = {
    'diffusion_policy.model.common.module_attr_mixin': 'diffusion_policy/model/common/module_attr_mixin.py',
    'diffusion_policy.model.diffusion.positional_embedding': 'diffusion_policy/model/diffusion/positional_embedding.py',
    'diffusion_policy.model.diffusion.attnres_transformer_components': 'diffusion_policy/model/diffusion/attnres_transformer_components.py',
    'diffusion_policy.model.diffusion.imf_transformer_for_diffusion': 'diffusion_policy/model/diffusion/imf_transformer_for_diffusion.py',
}
@contextlib.contextmanager
def _temporary_registered_modules():
    """Yield a loader that registers modules in ``sys.modules`` temporarily.

    Every module (and every implicitly created parent package) registered
    through the yielded ``load`` callable is recorded; on exit
    ``sys.modules`` is restored — entries that did not exist before are
    removed, and entries that were shadowed are put back.
    """
    # Snapshot of the sys.modules entries we touch. _MISSING marks "was
    # absent" (None itself is a legal sys.modules value, hence a sentinel).
    previous_modules = {}

    def remember(name: str) -> None:
        # Record the pre-existing entry only the first time *name* is touched.
        if name not in previous_modules:
            previous_modules[name] = sys.modules.get(name, _MISSING)

    def ensure_package(name: str) -> None:
        # Create an empty stand-in package so dotted submodule names resolve;
        # leave real, already-imported packages untouched.
        if not name or name in sys.modules:
            return
        remember(name)
        package = types.ModuleType(name)
        # A (possibly empty) __path__ is what marks a module as a package.
        package.__path__ = []
        sys.modules[name] = package

    def load(name: str, source: str, origin: str):
        # Materialize ancestor packages first (e.g. 'a', then 'a.b' for 'a.b.c').
        package_parts = name.split('.')[:-1]
        for idx in range(1, len(package_parts) + 1):
            ensure_package('.'.join(package_parts[:idx]))
        remember(name)
        module = types.ModuleType(name)
        module.__file__ = origin
        module.__package__ = name.rpartition('.')[0]
        # Register before exec so imports inside *source* that refer back to
        # this module (or its package) resolve during execution.
        sys.modules[name] = module
        exec(compile(source, origin, 'exec'), module.__dict__)
        return module

    try:
        yield load
    finally:
        # Undo in reverse registration order: modules are unwound before the
        # stand-in packages that were created to contain them.
        for name, previous in reversed(list(previous_modules.items())):
            if previous is _MISSING:
                sys.modules.pop(name, None)
            else:
                sys.modules[name] = previous
def _git_show(repo_root: Path, commit: str, relative_path: str) -> str:
    """Return the text of *relative_path* at *commit* within *repo_root*.

    Raises ``subprocess.CalledProcessError`` (check=True) when the commit or
    the path does not exist in the repository.
    """
    completed = subprocess.run(
        ['git', '-C', str(repo_root), 'show', f'{commit}:{relative_path}'],
        check=True,
        capture_output=True,
        text=True,
    )
    return completed.stdout
@contextlib.contextmanager
def _load_external_module_or_skip(test_case: unittest.TestCase):
    """Yield the external IMF transformer module loaded at the pinned commit.

    Skips *test_case* when no external diffusion_policy checkout is found, or
    when the pinned commit cannot be read from it. All externally loaded
    modules stay registered only while the context is active.
    """
    if _EXTERNAL_CHECKOUT_ROOT is None:
        test_case.skipTest('external diffusion_policy checkout unavailable')
    sources = {}
    try:
        for name, relative_path in _EXTERNAL_MODULE_PATHS.items():
            sources[name] = _git_show(_EXTERNAL_CHECKOUT_ROOT, _EXTERNAL_COMMIT, relative_path)
    except subprocess.CalledProcessError as exc:
        # Prefer git's stderr; fall back to the exception itself when empty.
        detail = exc.stderr.strip() or exc
        test_case.skipTest(
            f'external diffusion_policy commit {_EXTERNAL_COMMIT} is unavailable: {detail}'
        )
    with _temporary_registered_modules() as register:
        for name, relative_path in _EXTERNAL_MODULE_PATHS.items():
            register(
                name,
                sources[name],
                origin=f'{_EXTERNAL_CHECKOUT_ROOT}:{_EXTERNAL_COMMIT}:{relative_path}',
            )
        # Yield inside the with-block so the modules remain registered for
        # the duration of the caller's body.
        yield sys.modules['diffusion_policy.model.diffusion.imf_transformer_for_diffusion']
def _load_local_module():
    """Import and return a fresh instance of the local module under test."""
    importlib.invalidate_caches()
    # Drop any cached copy so the import below re-executes the module.
    if _LOCAL_MODULE_NAME in sys.modules:
        del sys.modules[_LOCAL_MODULE_NAME]
    return importlib.import_module(_LOCAL_MODULE_NAME)
class IMFTransformer1DExternalAlignmentTest(unittest.TestCase):
    """Alignment tests between the local ``IMFTransformer1D`` head and the
    external diffusion_policy ``IMFTransformerForDiffusion`` at the pinned
    commit: state-dict layout, forward outputs, and optimizer grouping."""

    def _optim_group_names(self, model, groups):
        """Return, per optimizer group, the set of parameter names it holds.

        Parameters are mapped back to names by object identity, so *groups*
        must reference *model*'s own parameter tensors.
        """
        names_by_param = {id(param): name for name, param in model.named_parameters()}
        return [
            {names_by_param[id(param)] for param in group['params']}
            for group in groups
        ]

    def test_local_defaults_preserve_supported_attnres_config(self):
        """Constructor defaults must stay on the supported attnres config."""
        local_module = _load_local_module()
        ctor = inspect.signature(local_module.IMFTransformer1D.__init__).parameters
        self.assertEqual(ctor['backbone_type'].default, 'attnres_full')
        self.assertEqual(ctor['n_head'].default, 1)
        self.assertEqual(ctor['n_kv_head'].default, 1)
        self.assertEqual(ctor['n_cond_layers'].default, 0)
        self.assertTrue(ctor['time_as_cond'].default)
        self.assertFalse(ctor['causal_attn'].default)

    def test_attnres_full_state_dict_forward_and_optim_groups_match_external(self):
        """With identical config and weights, the local and external models
        must agree on state-dict keys, forward outputs, and optim groups."""
        local_module = _load_local_module()
        with _load_external_module_or_skip(self) as external_module:
            # Small but non-trivial configuration exercising the attnres path.
            config = dict(
                input_dim=4,
                output_dim=4,
                horizon=6,
                n_obs_steps=3,
                cond_dim=5,
                n_layer=2,
                n_head=1,
                n_emb=16,
                p_drop_emb=0.0,
                p_drop_attn=0.0,
                causal_attn=False,
                time_as_cond=True,
                n_cond_layers=0,
                backbone_type='attnres_full',
                n_kv_head=1,
            )
            # Seed before construction so initializer draws are reproducible.
            torch.manual_seed(7)
            external_model = external_module.IMFTransformerForDiffusion(**config)
            local_model = local_module.IMFTransformer1D(**config)
            # eval() disables dropout so forward passes are deterministic.
            external_model.eval()
            local_model.eval()
            # Key sets must match exactly before copying weights across.
            external_state_dict = external_model.state_dict()
            self.assertEqual(set(local_model.state_dict().keys()), set(external_state_dict.keys()))
            local_model.load_state_dict(external_state_dict, strict=True)
            batch_size = 2
            sample = torch.randn(batch_size, config['horizon'], config['input_dim'])
            # NOTE(review): r/t appear to be per-sample scalar conditioning
            # values in [0, 1] — confirm against the model's forward signature.
            r = torch.tensor([0.1, 0.4], dtype=torch.float32)
            t = torch.tensor([0.7, 0.9], dtype=torch.float32)
            cond = torch.randn(batch_size, config['n_obs_steps'], config['cond_dim'])
            with torch.no_grad():
                external_out = external_model(sample=sample, r=r, t=t, cond=cond)
                local_out = local_model(sample=sample, r=r, t=t, cond=cond)
            self.assertEqual(local_out.shape, (batch_size, config['horizon'], config['output_dim']))
            self.assertEqual(local_out.shape, external_out.shape)
            self.assertTrue(torch.allclose(local_out, external_out, atol=1e-6, rtol=1e-5))
            # Optimizer groups: expected to be [decayed, non-decayed], with
            # identical parameter-name membership in both implementations.
            weight_decay = 0.123
            external_groups = external_model.get_optim_groups(weight_decay=weight_decay)
            local_groups = local_model.get_optim_groups(weight_decay=weight_decay)
            self.assertEqual(len(local_groups), len(external_groups))
            self.assertEqual([group['weight_decay'] for group in local_groups], [weight_decay, 0.0])
            self.assertEqual(
                self._optim_group_names(local_model, local_groups),
                self._optim_group_names(external_model, external_groups),
            )
# Allow running this test module directly (outside a test runner).
if __name__ == '__main__':
    unittest.main()