197 lines
7.2 KiB
Python
197 lines
7.2 KiB
Python
import contextlib
|
|
import importlib
|
|
import inspect
|
|
import subprocess
|
|
import sys
|
|
import types
|
|
import unittest
|
|
from pathlib import Path
|
|
|
|
import torch
|
|
|
|
|
|
# Repository root: the parent of the directory that contains this file.
_REPO_ROOT = Path(__file__).resolve().parents[1]

# Make the repository importable (e.g. ``roboimi``) without installing it.
# Prepending means this checkout takes precedence over any installed copy.
if str(_REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(_REPO_ROOT))

# Commit of the external ``diffusion_policy`` repository whose sources are used
# as the reference implementation.
_EXTERNAL_COMMIT = '185ed659'

# Dotted import path of the local module under test.
_LOCAL_MODULE_NAME = 'roboimi.vla.models.heads.imf_transformer1d'

# Sentinel meaning "this name was absent from sys.modules"; distinct from every
# real value, including None.
_MISSING = object()
|
|
|
|
|
|
def _find_external_checkout_root() -> Path | None:
    """Locate a git checkout of ``diffusion_policy`` beside this repo or above it.

    Check ``_REPO_ROOT`` and then each of its ancestors, nearest first. For each
    directory ``d``, ``d / 'diffusion_policy'`` qualifies when it contains a
    ``.git`` entry. ``Path.exists`` accepts either a directory or a file there.

    Returns:
        The first qualifying ``diffusion_policy`` directory, or ``None`` if no
        directory on the path qualifies.
    """
    search_dirs = (_REPO_ROOT, *_REPO_ROOT.parents)
    candidates = (directory / 'diffusion_policy' for directory in search_dirs)
    return next(
        (candidate for candidate in candidates if (candidate / '.git').exists()),
        None,
    )
|
|
|
|
|
|
# Root of the external ``diffusion_policy`` git checkout, or None when no checkout
# was found. In that case the cross-implementation test skips itself.
_EXTERNAL_CHECKOUT_ROOT = _find_external_checkout_root()

# External modules to read at ``_EXTERNAL_COMMIT``, mapped to their repository
# paths. Each path is the dotted module name with '/' separators plus '.py'.
# Modules are loaded in this order. Presumably each one only imports modules
# listed before it; confirm this when adding entries.
_EXTERNAL_MODULE_PATHS = {
    module_name: module_name.replace('.', '/') + '.py'
    for module_name in (
        'diffusion_policy.model.common.module_attr_mixin',
        'diffusion_policy.model.diffusion.positional_embedding',
        'diffusion_policy.model.diffusion.attnres_transformer_components',
        'diffusion_policy.model.diffusion.imf_transformer_for_diffusion',
    )
}
|
|
|
|
|
|
@contextlib.contextmanager
def _temporary_registered_modules():
    """Temporarily register source-loaded modules in ``sys.modules``.

    Yields a ``load(name, source, origin)`` callable. Each call does two things:

    * For every parent package of ``name`` missing from ``sys.modules``, it
      creates an empty stub package (``__path__ == []``).
    * It creates module ``name`` with ``__file__ = origin`` and registers it in
      ``sys.modules``. It then executes ``source`` in the module namespace,
      compiled with ``origin`` as the filename, and returns the module.

    On exit, whether normal or exceptional, every ``sys.modules`` entry created
    or replaced here is restored to its original value. Entries that did not
    exist before are removed.
    """
    # Maps each name to the sys.modules entry it had before the first change
    # here, or _MISSING if it had none.
    originals = {}

    def remember(name: str) -> None:
        # setdefault keeps the FIRST recorded value, i.e. the true original.
        originals.setdefault(name, sys.modules.get(name, _MISSING))

    def ensure_package(name: str) -> None:
        # Only fill gaps: empty names and already-registered modules are left alone.
        if name and name not in sys.modules:
            remember(name)
            stub = types.ModuleType(name)
            stub.__path__ = []  # makes it a package with no on-disk search path
            sys.modules[name] = stub

    def load(name: str, source: str, origin: str):
        *packages, _leaf = name.split('.')
        for depth in range(len(packages)):
            ensure_package('.'.join(packages[:depth + 1]))

        remember(name)
        module = types.ModuleType(name)
        module.__file__ = origin
        module.__package__ = name.rpartition('.')[0]
        # Register before executing, as importlib does. Imports of ``name``
        # made while its own code runs then find this module object.
        sys.modules[name] = module
        code = compile(source, origin, 'exec')
        exec(code, module.__dict__)
        return module

    try:
        yield load
    finally:
        # Undo changes in reverse order of first modification.
        for name in reversed(list(originals)):
            original = originals[name]
            if original is _MISSING:
                sys.modules.pop(name, None)
            else:
                sys.modules[name] = original
|
|
|
|
|
|
def _git_show(repo_root: Path, commit: str, relative_path: str) -> str:
    """Return the contents of ``relative_path`` as of ``commit`` in ``repo_root``.

    Runs ``git -C <repo_root> show <commit>:<relative_path>``. Stdout is decoded
    as UTF-8, the default encoding of Python source files (PEP 3120). The result
    therefore does not depend on the process locale. ``text=True`` would decode
    with the locale's preferred encoding, which can corrupt or reject non-ASCII
    UTF-8 sources on non-UTF-8 locales.

    Raises:
        subprocess.CalledProcessError: git exited non-zero (unknown commit,
            missing path, not a repository, ...). ``stderr`` is a ``str`` with
            git's message.
        OSError: the ``git`` executable could not be started (typically
            ``FileNotFoundError``).
        UnicodeDecodeError: the file at that commit is not valid UTF-8.
    """
    completed = subprocess.run(
        ['git', '-C', str(repo_root), 'show', f'{commit}:{relative_path}'],
        capture_output=True,
    )
    if completed.returncode != 0:
        # Mirror check=True, but decode leniently: the output is only used
        # for diagnostics, and git's messages may use the locale's encoding.
        raise subprocess.CalledProcessError(
            completed.returncode,
            completed.args,
            output=completed.stdout.decode('utf-8', errors='replace'),
            stderr=completed.stderr.decode('utf-8', errors='replace'),
        )
    return completed.stdout.decode('utf-8')
|
|
|
|
|
|
@contextlib.contextmanager
def _load_external_module_or_skip(test_case: unittest.TestCase):
    """Yield the external ``imf_transformer_for_diffusion`` module, or skip the test.

    Each module in ``_EXTERNAL_MODULE_PATHS`` is read from the external checkout
    at ``_EXTERNAL_COMMIT`` via ``git show``, so the checkout's working-tree state
    is irrelevant. The modules are executed in order inside
    ``_temporary_registered_modules``. The yielded object is
    ``sys.modules['diffusion_policy.model.diffusion.imf_transformer_for_diffusion']``.
    All touched ``sys.modules`` entries are restored when the ``with`` exits.

    Calls ``test_case.skipTest``, which raises ``unittest.SkipTest``, when:

    * no checkout was found;
    * git cannot read the commit or paths;
    * the ``git`` executable cannot be run at all.
    """
    if _EXTERNAL_CHECKOUT_ROOT is None:
        test_case.skipTest('external diffusion_policy checkout unavailable')

    try:
        sources = {
            name: _git_show(_EXTERNAL_CHECKOUT_ROOT, _EXTERNAL_COMMIT, relative_path)
            for name, relative_path in _EXTERNAL_MODULE_PATHS.items()
        }
    except subprocess.CalledProcessError as exc:
        detail = (exc.stderr or '').strip() or exc
        test_case.skipTest(
            f'external diffusion_policy commit {_EXTERNAL_COMMIT} is unavailable: {detail}'
        )
    except OSError as exc:
        # For example, FileNotFoundError when no ``git`` is on PATH. The reference
        # sources cannot be read, which is an environment gap, not a test failure.
        test_case.skipTest(
            f'git is unavailable to read external diffusion_policy sources: {exc}'
        )

    with _temporary_registered_modules() as load_external:
        for name, relative_path in _EXTERNAL_MODULE_PATHS.items():
            load_external(
                name,
                sources[name],
                origin=f'{_EXTERNAL_CHECKOUT_ROOT}:{_EXTERNAL_COMMIT}:{relative_path}',
            )
        yield sys.modules['diffusion_policy.model.diffusion.imf_transformer_for_diffusion']
|
|
|
|
|
|
def _load_local_module():
    """Import ``_LOCAL_MODULE_NAME`` afresh and return the resulting module object.

    Any cached copy is dropped from ``sys.modules`` first. Importing it again
    therefore re-executes the module's top-level code. Parent packages stay
    cached.
    """
    # Let finders notice files created or changed since the last import.
    importlib.invalidate_caches()
    if _LOCAL_MODULE_NAME in sys.modules:
        del sys.modules[_LOCAL_MODULE_NAME]
    return importlib.import_module(_LOCAL_MODULE_NAME)
|
|
|
|
|
|
class IMFTransformer1DExternalAlignmentTest(unittest.TestCase):
    """Check the local ``IMFTransformer1D`` against the external reference.

    The reference is ``IMFTransformerForDiffusion`` from the ``diffusion_policy``
    checkout pinned at ``_EXTERNAL_COMMIT``. Tests that need it are skipped when
    that checkout or commit is unavailable.
    """

    def _optim_group_names(self, model, groups):
        """Map each optimizer group's parameters back to their names in ``model``.

        Returns one set of parameter names per group, in group order. Groups from
        two models with identical parameter naming can then be compared directly.
        Raises ``KeyError`` if a group holds a tensor that is not a parameter of
        ``model``.
        """
        names_by_param = {id(param): name for name, param in model.named_parameters()}
        return [
            {names_by_param[id(param)] for param in group['params']}
            for group in groups
        ]

    def test_local_defaults_preserve_supported_attnres_config(self):
        """Local constructor defaults select the supported AttnRes configuration."""
        local_module = _load_local_module()
        ctor = inspect.signature(local_module.IMFTransformer1D.__init__).parameters

        self.assertEqual(ctor['backbone_type'].default, 'attnres_full')
        self.assertEqual(ctor['n_head'].default, 1)
        self.assertEqual(ctor['n_kv_head'].default, 1)
        self.assertEqual(ctor['n_cond_layers'].default, 0)
        self.assertTrue(ctor['time_as_cond'].default)
        self.assertFalse(ctor['causal_attn'].default)

    def test_attnres_full_state_dict_forward_and_optim_groups_match_external(self):
        """State dict, forward output and optimizer groups match the reference."""
        local_module = _load_local_module()

        with _load_external_module_or_skip(self) as external_module:
            config = dict(
                input_dim=4,
                output_dim=4,
                horizon=6,
                n_obs_steps=3,
                cond_dim=5,
                n_layer=2,
                n_head=1,
                n_emb=16,
                p_drop_emb=0.0,
                p_drop_attn=0.0,
                causal_attn=False,
                time_as_cond=True,
                n_cond_layers=0,
                backbone_type='attnres_full',
                n_kv_head=1,
            )

            torch.manual_seed(7)
            external_model = external_module.IMFTransformerForDiffusion(**config)
            local_model = local_module.IMFTransformer1D(**config)
            external_model.eval()
            local_model.eval()

            # Same parameter/buffer names, so the reference weights load strictly.
            external_state_dict = external_model.state_dict()
            self.assertEqual(set(local_model.state_dict().keys()), set(external_state_dict.keys()))
            local_model.load_state_dict(external_state_dict, strict=True)

            batch_size = 2
            sample = torch.randn(batch_size, config['horizon'], config['input_dim'])
            r = torch.tensor([0.1, 0.4], dtype=torch.float32)
            t = torch.tensor([0.7, 0.9], dtype=torch.float32)
            cond = torch.randn(batch_size, config['n_obs_steps'], config['cond_dim'])

            with torch.no_grad():
                external_out = external_model(sample=sample, r=r, t=t, cond=cond)
                local_out = local_model(sample=sample, r=r, t=t, cond=cond)

            self.assertEqual(local_out.shape, (batch_size, config['horizon'], config['output_dim']))
            self.assertEqual(local_out.shape, external_out.shape)
            # Same tolerance rule as torch.allclose, but a failure reports how many
            # elements differ and by how much, not just "False is not true".
            torch.testing.assert_close(local_out, external_out, atol=1e-6, rtol=1e-5)

            weight_decay = 0.123
            external_groups = external_model.get_optim_groups(weight_decay=weight_decay)
            local_groups = local_model.get_optim_groups(weight_decay=weight_decay)

            self.assertEqual(len(local_groups), len(external_groups))
            local_decays = [group['weight_decay'] for group in local_groups]
            self.assertEqual(local_decays, [weight_decay, 0.0])
            # Pin the reference too: the groups must agree on decay, not only on
            # which parameters they contain.
            self.assertEqual(
                local_decays,
                [group['weight_decay'] for group in external_groups],
            )
            self.assertEqual(
                self._optim_group_names(local_model, local_groups),
                self._optim_group_names(external_model, external_groups),
            )
|
|
|
|
|
|
if __name__ == '__main__':
    # Running this file directly runs its TestCase classes via unittest's CLI.
    # module='__main__' is unittest.main's default, spelled out for clarity.
    unittest.main(module='__main__')
|