feat(vla): align transformer training stack and rollout validation
tests/test_transformer1d_external_alignment.py · 262 additions · new file
@@ -0,0 +1,262 @@
import contextlib
import importlib.util
import inspect
import sys
import types
import unittest
import warnings
from pathlib import Path

import torch


_REPO_ROOT = Path(__file__).resolve().parents[1]
_LOCAL_MODULE_PATH = _REPO_ROOT / 'roboimi/vla/models/heads/transformer1d.py'
_EXTERNAL_CHECKOUT_ROOT = _REPO_ROOT.parent / 'diffusion_policy'
_TRANSFORMER_WARNING_MESSAGE = (
    r'enable_nested_tensor is True, but self.use_nested_tensor is False '
    r'because encoder_layer\.norm_first was True'
)
_MISSING = object()
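

# Load a module directly from its source file via importlib, so neither the
# local package nor the external checkout needs to be installed. The
# `register` flag additionally publishes the module in sys.modules so that
# intra-package imports inside the loaded file resolve.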
def _load_module_from_path(name: str, path: Path, *, register: bool = False):
    spec = importlib.util.spec_from_file_location(name, path)
    assert spec is not None and spec.loader is not None
    module = importlib.util.module_from_spec(spec)
    if register:
        sys.modules[name] = module
    spec.loader.exec_module(module)
    return module
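

# Resolve the three source files needed from a sibling diffusion_policy
# checkout; return None when the checkout is absent so callers can skip.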
def _resolve_external_module_paths(external_checkout_root: Path):
    diffusion_policy_root = external_checkout_root / 'diffusion_policy'
    paths = {
        'positional_embedding': diffusion_policy_root / 'model/diffusion/positional_embedding.py',
        'module_attr_mixin': diffusion_policy_root / 'model/common/module_attr_mixin.py',
        'transformer_for_diffusion': diffusion_policy_root / 'model/diffusion/transformer_for_diffusion.py',
    }
    if not all(path.exists() for path in paths.values()):
        return None
    return paths
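

# Context manager that lets a test inject modules (and the synthetic parent
# packages they need) into sys.modules, then restores the interpreter's
# previous state on exit, even if a load fails midway.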
@contextlib.contextmanager
def _temporary_registered_modules():
    previous_modules = {}

    def remember(name: str) -> None:
        if name not in previous_modules:
            previous_modules[name] = sys.modules.get(name, _MISSING)

    def ensure_package(name: str) -> None:
        if not name or name in sys.modules:
            return
        remember(name)
        package = types.ModuleType(name)
        package.__path__ = []
        sys.modules[name] = package

    def load(name: str, path: Path):
        package_parts = name.split('.')[:-1]
        for idx in range(1, len(package_parts) + 1):
            ensure_package('.'.join(package_parts[:idx]))

        remember(name)
        return _load_module_from_path(name, path, register=True)

    try:
        yield load
    finally:
        for name, previous in reversed(list(previous_modules.items())):
            if previous is _MISSING:
                sys.modules.pop(name, None)
            else:
                sys.modules[name] = previous
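

# PyTorch's nn.TransformerEncoder emits a UserWarning when nested tensors
# are requested but norm_first layers disable the fast path; both models
# trigger it during construction, so silence just that message.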
@contextlib.contextmanager
def _suppress_nested_tensor_warning():
    with warnings.catch_warnings():
        warnings.filterwarnings(
            'ignore',
            message=_TRANSFORMER_WARNING_MESSAGE,
            category=UserWarning,
            module=r'torch\.nn\.modules\.transformer',
        )
        yield


def _load_local_module():
    return _load_module_from_path('local_transformer1d_alignment', _LOCAL_MODULE_PATH)
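

# Alignment tests: verify that the local Transformer1D head is a drop-in
# match for diffusion_policy's TransformerForDiffusion in construction,
# state dict layout, forward output, and optimizer grouping.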
class Transformer1DExternalAlignmentTest(unittest.TestCase):
    def _load_transformer_classes_or_skip(self):
        external_paths = _resolve_external_module_paths(_EXTERNAL_CHECKOUT_ROOT)
        if external_paths is None:
            self.skipTest(f'external diffusion_policy checkout unavailable under {_EXTERNAL_CHECKOUT_ROOT}')

        local_module = _load_local_module()
        with _temporary_registered_modules() as load_external:
            load_external(
                'diffusion_policy.model.diffusion.positional_embedding',
                external_paths['positional_embedding'],
            )
            load_external(
                'diffusion_policy.model.common.module_attr_mixin',
                external_paths['module_attr_mixin'],
            )
            external_module = load_external(
                'diffusion_policy.model.diffusion.transformer_for_diffusion',
                external_paths['transformer_for_diffusion'],
            )

        return local_module.Transformer1D, local_module.create_transformer1d, external_module.TransformerForDiffusion
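
    # Map each optimizer param group back to parameter names so group
    # membership can be compared across the two models.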
    def _optim_group_names(self, model, groups):
        names_by_param = {id(param): name for name, param in model.named_parameters()}
        return [
            {names_by_param[id(param)] for param in group['params']}
            for group in groups
        ]

    def test_missing_external_checkout_resolution_returns_none(self):
        self.assertIsNone(_resolve_external_module_paths(_REPO_ROOT / '__missing_diffusion_policy_checkout__'))
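
    # The loader context must leave sys.modules exactly as it found it,
    # including any synthetic parent packages it created along the way.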
    def test_external_loader_restores_injected_sys_modules(self):
        external_paths = _resolve_external_module_paths(_EXTERNAL_CHECKOUT_ROOT)
        if external_paths is None:
            self.skipTest(f'external diffusion_policy checkout unavailable under {_EXTERNAL_CHECKOUT_ROOT}')

        watched_names = [
            'diffusion_policy',
            'diffusion_policy.model',
            'diffusion_policy.model.common',
            'diffusion_policy.model.common.module_attr_mixin',
            'diffusion_policy.model.diffusion',
            'diffusion_policy.model.diffusion.positional_embedding',
            'diffusion_policy.model.diffusion.transformer_for_diffusion',
        ]
        before = {name: sys.modules.get(name, _MISSING) for name in watched_names}

        with _temporary_registered_modules() as load_external:
            load_external(
                'diffusion_policy.model.diffusion.positional_embedding',
                external_paths['positional_embedding'],
            )
            load_external(
                'diffusion_policy.model.common.module_attr_mixin',
                external_paths['module_attr_mixin'],
            )
            load_external(
                'diffusion_policy.model.diffusion.transformer_for_diffusion',
                external_paths['transformer_for_diffusion'],
            )

        after = {name: sys.modules.get(name, _MISSING) for name in watched_names}
        self.assertEqual(after, before)
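
    # Guard the local head's public defaults so direct callers are not
    # silently broken by alignment work.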
    def test_transformer1d_preserves_local_direct_call_defaults(self):
        local_module = _load_local_module()
        ctor = inspect.signature(local_module.Transformer1D.__init__).parameters
        helper = inspect.signature(local_module.create_transformer1d).parameters

        self.assertEqual(ctor['n_layer'].default, 8)
        self.assertEqual(ctor['n_head'].default, 8)
        self.assertEqual(ctor['n_emb'].default, 256)
        self.assertEqual(helper['n_layer'].default, 8)
        self.assertEqual(helper['n_head'].default, 8)
        self.assertEqual(helper['n_emb'].default, 256)
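
    # With time_as_cond=False the timestep is folded into the input sequence
    # rather than the conditioning stream; both implementations must account
    # for the extra token identically in T and T_cond.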
    def test_time_as_cond_false_token_accounting_matches_external(self):
        Transformer1D, _, TransformerForDiffusion = self._load_transformer_classes_or_skip()
        self.assertIn('time_as_cond', inspect.signature(Transformer1D.__init__).parameters)

        config = dict(
            input_dim=4,
            output_dim=4,
            horizon=6,
            n_obs_steps=3,
            cond_dim=0,
            n_layer=2,
            n_head=2,
            n_emb=8,
            p_drop_emb=0.0,
            p_drop_attn=0.0,
            causal_attn=False,
            time_as_cond=False,
            obs_as_cond=False,
            n_cond_layers=0,
        )

        torch.manual_seed(5)
        with _suppress_nested_tensor_warning():
            external_model = TransformerForDiffusion(**config)
            local_model = Transformer1D(**config)
        external_model.eval()
        local_model.eval()

        self.assertEqual(local_model.T, external_model.T)
        self.assertEqual(local_model.T_cond, external_model.T_cond)
        self.assertEqual(local_model.time_as_cond, external_model.time_as_cond)
        self.assertEqual(local_model.obs_as_cond, external_model.obs_as_cond)
        self.assertEqual(local_model.encoder_only, external_model.encoder_only)
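
    # End-to-end equivalence: identical state dict keys, strict weight
    # transplant, matching forward outputs, and matching decay/no-decay
    # optimizer parameter grouping.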
    def test_nocausal_state_dict_forward_and_optim_groups_match_external(self):
        Transformer1D, _, TransformerForDiffusion = self._load_transformer_classes_or_skip()
        config = dict(
            input_dim=4,
            output_dim=4,
            horizon=6,
            n_obs_steps=3,
            cond_dim=5,
            n_layer=2,
            n_head=2,
            n_emb=8,
            p_drop_emb=0.0,
            p_drop_attn=0.0,
            causal_attn=False,
            obs_as_cond=True,
            n_cond_layers=1,
        )

        torch.manual_seed(7)
        with _suppress_nested_tensor_warning():
            external_model = TransformerForDiffusion(**config)
            local_model = Transformer1D(**config)
        external_model.eval()
        local_model.eval()

        external_state_dict = external_model.state_dict()
        self.assertEqual(set(local_model.state_dict().keys()), set(external_state_dict.keys()))
        local_model.load_state_dict(external_state_dict, strict=True)

        batch_size = 2
        sample = torch.randn(batch_size, config['horizon'], config['input_dim'])
        cond = torch.randn(batch_size, config['n_obs_steps'], config['cond_dim'])
        timestep = torch.tensor([11, 17], dtype=torch.long)

        with torch.no_grad():
            external_out = external_model(sample=sample, timestep=timestep, cond=cond)
            local_out = local_model(sample=sample, timestep=timestep, cond=cond)

        self.assertEqual(local_out.shape, (batch_size, config['horizon'], config['output_dim']))
        self.assertEqual(local_out.shape, external_out.shape)
        self.assertTrue(torch.allclose(local_out, external_out, atol=1e-6, rtol=1e-5))

        weight_decay = 0.123
        external_groups = external_model.get_optim_groups(weight_decay=weight_decay)
        local_groups = local_model.get_optim_groups(weight_decay=weight_decay)

        self.assertEqual(len(local_groups), len(external_groups))
        self.assertEqual([group['weight_decay'] for group in local_groups], [weight_decay, 0.0])
        self.assertEqual(
            self._optim_group_names(local_model, local_groups),
            self._optim_group_names(external_model, external_groups),
        )


if __name__ == '__main__':
    unittest.main()