Files
roboimi/tests/test_transformer1d_external_alignment.py

263 lines
9.8 KiB
Python

import contextlib
import importlib.util
import inspect
import sys
import types
import unittest
import warnings
from pathlib import Path
import torch
# Root of the repository that contains this test file.
# NOTE(review): parents[1] assumes this file sits two levels below the repo
# root (e.g. roboimi/tests/...) — confirm against the actual layout.
_REPO_ROOT = Path(__file__).resolve().parents[1]
# The local implementation under test.
_LOCAL_MODULE_PATH = _REPO_ROOT / 'roboimi/vla/models/heads/transformer1d.py'
# Sibling checkout of the upstream diffusion_policy project, expected to live
# next to this repository; tests skip when it is absent.
_EXTERNAL_CHECKOUT_ROOT = _REPO_ROOT.parent / 'diffusion_policy'
# Regex matching the UserWarning torch's TransformerEncoder emits when
# norm_first is enabled; suppressed during model construction in these tests.
_TRANSFORMER_WARNING_MESSAGE = (
    r'enable_nested_tensor is True, but self.use_nested_tensor is False '
    r'because encoder_layer\.norm_first was True'
)
# Sentinel distinguishing "module absent from sys.modules" from "module is None".
_MISSING = object()
def _load_module_from_path(name: str, path: Path, *, register: bool = False):
spec = importlib.util.spec_from_file_location(name, path)
module = importlib.util.module_from_spec(spec)
assert spec.loader is not None
if register:
sys.modules[name] = module
spec.loader.exec_module(module)
return module
def _resolve_external_module_paths(external_checkout_root: Path):
diffusion_policy_root = external_checkout_root / 'diffusion_policy'
paths = {
'positional_embedding': diffusion_policy_root / 'model/diffusion/positional_embedding.py',
'module_attr_mixin': diffusion_policy_root / 'model/common/module_attr_mixin.py',
'transformer_for_diffusion': diffusion_policy_root / 'model/diffusion/transformer_for_diffusion.py',
}
if not all(path.exists() for path in paths.values()):
return None
return paths
@contextlib.contextmanager
def _temporary_registered_modules():
    """Yield a loader that injects modules into ``sys.modules`` temporarily.

    The yielded callable takes ``(name, path)``, creates empty package stubs
    for every missing parent package of *name*, then loads and registers the
    module. On exit, every ``sys.modules`` entry touched by the loader is
    restored to its prior binding (or removed if it did not exist before).
    """
    saved = {}

    def snapshot(name: str) -> None:
        # Record the pre-existing binding (or its absence) at most once,
        # so the first observation wins even if a name is touched twice.
        saved.setdefault(name, sys.modules.get(name, _MISSING))

    def ensure_parent(name: str) -> None:
        # Install a minimal namespace-package stub when the parent package
        # is not already importable.
        if name and name not in sys.modules:
            snapshot(name)
            stub = types.ModuleType(name)
            stub.__path__ = []
            sys.modules[name] = stub

    def loader(name: str, path: Path):
        parts = name.split('.')
        # Walk every proper prefix of the dotted name (outermost first).
        for depth in range(1, len(parts)):
            ensure_parent('.'.join(parts[:depth]))
        snapshot(name)
        return _load_module_from_path(name, path, register=True)

    try:
        yield loader
    finally:
        # Undo in reverse registration order so children are removed
        # before their parent packages.
        for name, prior in reversed(list(saved.items())):
            if prior is _MISSING:
                sys.modules.pop(name, None)
            else:
                sys.modules[name] = prior
@contextlib.contextmanager
def _suppress_nested_tensor_warning():
    """Silence torch's nested-tensor ``UserWarning`` within the context.

    ``torch.nn.TransformerEncoder`` warns when ``norm_first`` disables nested
    tensors; the filter is scoped via ``catch_warnings`` so the global
    warning state is untouched afterwards.
    """
    with warnings.catch_warnings():
        warnings.filterwarnings(
            action='ignore',
            message=_TRANSFORMER_WARNING_MESSAGE,
            category=UserWarning,
            module=r'torch\.nn\.modules\.transformer',
        )
        yield
def _load_local_module():
    """Load the in-repo transformer1d module under a throwaway module name."""
    return _load_module_from_path(
        'local_transformer1d_alignment',
        _LOCAL_MODULE_PATH,
    )
class Transformer1DExternalAlignmentTest(unittest.TestCase):
    """Verify the local Transformer1D stays aligned with the external
    diffusion_policy TransformerForDiffusion: same token accounting,
    state-dict layout, forward outputs, and optimizer param groups."""

    def _load_transformer_classes_or_skip(self):
        # Skip (not fail) when the sibling diffusion_policy checkout is absent.
        external_paths = _resolve_external_module_paths(_EXTERNAL_CHECKOUT_ROOT)
        if external_paths is None:
            self.skipTest(f'external diffusion_policy checkout unavailable under {_EXTERNAL_CHECKOUT_ROOT}')
        local_module = _load_local_module()
        with _temporary_registered_modules() as load_external:
            # Register the dependencies first: the transformer module imports
            # them by their fully qualified package names.
            load_external(
                'diffusion_policy.model.diffusion.positional_embedding',
                external_paths['positional_embedding'],
            )
            load_external(
                'diffusion_policy.model.common.module_attr_mixin',
                external_paths['module_attr_mixin'],
            )
            external_module = load_external(
                'diffusion_policy.model.diffusion.transformer_for_diffusion',
                external_paths['transformer_for_diffusion'],
            )
        return local_module.Transformer1D, local_module.create_transformer1d, external_module.TransformerForDiffusion

    def _optim_group_names(self, model, groups):
        # Convert optimizer param groups into sets of parameter *names* so
        # group membership can be compared across two distinct model objects
        # (raw Parameter objects would never compare equal across models).
        names_by_param = {id(param): name for name, param in model.named_parameters()}
        return [
            {names_by_param[id(param)] for param in group['params']}
            for group in groups
        ]

    def test_missing_external_checkout_resolution_returns_none(self):
        # A nonexistent checkout directory must resolve to None, which is
        # what triggers the skip path in the alignment tests.
        self.assertIsNone(_resolve_external_module_paths(_REPO_ROOT / '__missing_diffusion_policy_checkout__'))

    def test_external_loader_restores_injected_sys_modules(self):
        external_paths = _resolve_external_module_paths(_EXTERNAL_CHECKOUT_ROOT)
        if external_paths is None:
            self.skipTest(f'external diffusion_policy checkout unavailable under {_EXTERNAL_CHECKOUT_ROOT}')
        # Every package/module name the loader may inject, including the
        # implicitly created parent package stubs.
        watched_names = [
            'diffusion_policy',
            'diffusion_policy.model',
            'diffusion_policy.model.common',
            'diffusion_policy.model.common.module_attr_mixin',
            'diffusion_policy.model.diffusion',
            'diffusion_policy.model.diffusion.positional_embedding',
            'diffusion_policy.model.diffusion.transformer_for_diffusion',
        ]
        before = {name: sys.modules.get(name, _MISSING) for name in watched_names}
        with _temporary_registered_modules() as load_external:
            load_external(
                'diffusion_policy.model.diffusion.positional_embedding',
                external_paths['positional_embedding'],
            )
            load_external(
                'diffusion_policy.model.common.module_attr_mixin',
                external_paths['module_attr_mixin'],
            )
            load_external(
                'diffusion_policy.model.diffusion.transformer_for_diffusion',
                external_paths['transformer_for_diffusion'],
            )
        # After the context exits, sys.modules must be exactly as before.
        after = {name: sys.modules.get(name, _MISSING) for name in watched_names}
        self.assertEqual(after, before)

    def test_transformer1d_preserves_local_direct_call_defaults(self):
        # Pin the local defaults so an upstream sync cannot silently change
        # the behavior of existing direct callers.
        local_module = _load_local_module()
        ctor = inspect.signature(local_module.Transformer1D.__init__).parameters
        helper = inspect.signature(local_module.create_transformer1d).parameters
        self.assertEqual(ctor['n_layer'].default, 8)
        self.assertEqual(ctor['n_head'].default, 8)
        self.assertEqual(ctor['n_emb'].default, 256)
        self.assertEqual(helper['n_layer'].default, 8)
        self.assertEqual(helper['n_head'].default, 8)
        self.assertEqual(helper['n_emb'].default, 256)

    def test_time_as_cond_false_token_accounting_matches_external(self):
        # With time_as_cond=False the timestep is prepended to the input
        # sequence instead of the conditioning; check both models agree on
        # the resulting token counts and mode flags.
        Transformer1D, _, TransformerForDiffusion = self._load_transformer_classes_or_skip()
        self.assertIn('time_as_cond', inspect.signature(Transformer1D.__init__).parameters)
        config = dict(
            input_dim=4,
            output_dim=4,
            horizon=6,
            n_obs_steps=3,
            cond_dim=0,
            n_layer=2,
            n_head=2,
            n_emb=8,
            p_drop_emb=0.0,
            p_drop_attn=0.0,
            causal_attn=False,
            time_as_cond=False,
            obs_as_cond=False,
            n_cond_layers=0,
        )
        # Seed before each construction pair so both models are built from
        # a known RNG state.
        torch.manual_seed(5)
        with _suppress_nested_tensor_warning():
            external_model = TransformerForDiffusion(**config)
            local_model = Transformer1D(**config)
        external_model.eval()
        local_model.eval()
        self.assertEqual(local_model.T, external_model.T)
        self.assertEqual(local_model.T_cond, external_model.T_cond)
        self.assertEqual(local_model.time_as_cond, external_model.time_as_cond)
        self.assertEqual(local_model.obs_as_cond, external_model.obs_as_cond)
        self.assertEqual(local_model.encoder_only, external_model.encoder_only)

    def test_nocausal_state_dict_forward_and_optim_groups_match_external(self):
        # End-to-end alignment: identical state-dict keys, bit-compatible
        # forward outputs after weight transfer, and matching optimizer
        # parameter grouping.
        Transformer1D, _, TransformerForDiffusion = self._load_transformer_classes_or_skip()
        config = dict(
            input_dim=4,
            output_dim=4,
            horizon=6,
            n_obs_steps=3,
            cond_dim=5,
            n_layer=2,
            n_head=2,
            n_emb=8,
            p_drop_emb=0.0,
            p_drop_attn=0.0,
            causal_attn=False,
            obs_as_cond=True,
            n_cond_layers=1,
        )
        torch.manual_seed(7)
        with _suppress_nested_tensor_warning():
            external_model = TransformerForDiffusion(**config)
            local_model = Transformer1D(**config)
        external_model.eval()
        local_model.eval()
        external_state_dict = external_model.state_dict()
        self.assertEqual(set(local_model.state_dict().keys()), set(external_state_dict.keys()))
        # Copy the external weights in so the forward passes are comparable
        # value-for-value, not just shape-for-shape.
        local_model.load_state_dict(external_state_dict, strict=True)
        batch_size = 2
        sample = torch.randn(batch_size, config['horizon'], config['input_dim'])
        cond = torch.randn(batch_size, config['n_obs_steps'], config['cond_dim'])
        timestep = torch.tensor([11, 17], dtype=torch.long)
        with torch.no_grad():
            external_out = external_model(sample=sample, timestep=timestep, cond=cond)
            local_out = local_model(sample=sample, timestep=timestep, cond=cond)
        self.assertEqual(local_out.shape, (batch_size, config['horizon'], config['output_dim']))
        self.assertEqual(local_out.shape, external_out.shape)
        self.assertTrue(torch.allclose(local_out, external_out, atol=1e-6, rtol=1e-5))
        weight_decay = 0.123
        external_groups = external_model.get_optim_groups(weight_decay=weight_decay)
        local_groups = local_model.get_optim_groups(weight_decay=weight_decay)
        self.assertEqual(len(local_groups), len(external_groups))
        # Convention: first group decayed, second group decay-free.
        self.assertEqual([group['weight_decay'] for group in local_groups], [weight_decay, 0.0])
        self.assertEqual(
            self._optim_group_names(local_model, local_groups),
            self._optim_group_names(external_model, external_groups),
        )
# Allow running this file directly (outside pytest/unittest discovery).
if __name__ == '__main__':
    unittest.main()