feat: add IMF AttnRes policy training path
This commit is contained in:
196
tests/test_imf_transformer1d_external_alignment.py
Normal file
196
tests/test_imf_transformer1d_external_alignment.py
Normal file
@@ -0,0 +1,196 @@
|
||||
import contextlib
|
||||
import importlib
|
||||
import inspect
|
||||
import subprocess
|
||||
import sys
|
||||
import types
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
# Make the repository root importable so the `roboimi...` package resolves
# when this file is executed directly (e.g. `python tests/test_...py`).
_REPO_ROOT = Path(__file__).resolve().parents[1]
if str(_REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(_REPO_ROOT))
|
||||
|
||||
# Pinned upstream diffusion_policy commit the local port must stay aligned with.
_EXTERNAL_COMMIT = '185ed659'
# Dotted path of the local module under test.
_LOCAL_MODULE_NAME = 'roboimi.vla.models.heads.imf_transformer1d'
# Sentinel distinguishing "entry was absent" from "entry was None" in sys.modules.
_MISSING = object()
|
||||
|
||||
|
||||
def _find_external_checkout_root() -> Path | None:
|
||||
for ancestor in (_REPO_ROOT, *_REPO_ROOT.parents):
|
||||
candidate = ancestor / 'diffusion_policy'
|
||||
if (candidate / '.git').exists():
|
||||
return candidate
|
||||
return None
|
||||
|
||||
|
||||
# Resolved once at import time; None when no external checkout is available,
# in which case the alignment test skips itself.
_EXTERNAL_CHECKOUT_ROOT = _find_external_checkout_root()
# External modules to materialize, mapped to their path inside the external
# repository. NOTE(review): listed in what appears to be dependency order
# (later entries presumably import earlier ones) — confirm before reordering.
_EXTERNAL_MODULE_PATHS = {
    'diffusion_policy.model.common.module_attr_mixin': 'diffusion_policy/model/common/module_attr_mixin.py',
    'diffusion_policy.model.diffusion.positional_embedding': 'diffusion_policy/model/diffusion/positional_embedding.py',
    'diffusion_policy.model.diffusion.attnres_transformer_components': 'diffusion_policy/model/diffusion/attnres_transformer_components.py',
    'diffusion_policy.model.diffusion.imf_transformer_for_diffusion': 'diffusion_policy/model/diffusion/imf_transformer_for_diffusion.py',
}
|
||||
|
||||
|
||||
@contextlib.contextmanager
def _temporary_registered_modules():
    """Scope ad-hoc module registration in ``sys.modules`` to a with-block.

    Yields a ``load(name, source, origin)`` callable that compiles *source*
    into a fresh module registered under *name* (creating empty namespace
    packages for any missing parent packages). On exit, every ``sys.modules``
    entry that was touched is restored to its previous object, or removed
    entirely if it had been absent.
    """
    # Snapshot of overwritten sys.modules entries; _MISSING marks "was absent".
    previous_modules = {}

    def remember(name: str) -> None:
        # Record only the first observation so the pre-existing state wins
        # on restore even if the same name is registered twice.
        if name not in previous_modules:
            previous_modules[name] = sys.modules.get(name, _MISSING)

    def ensure_package(name: str) -> None:
        # Register an empty package (with a __path__) so dotted submodule
        # names resolve; real pre-existing packages are left untouched.
        if not name or name in sys.modules:
            return
        remember(name)
        package = types.ModuleType(name)
        package.__path__ = []
        sys.modules[name] = package

    def load(name: str, source: str, origin: str):
        # Create parent packages first ('a', then 'a.b', for name 'a.b.c').
        package_parts = name.split('.')[:-1]
        for idx in range(1, len(package_parts) + 1):
            ensure_package('.'.join(package_parts[:idx]))

        remember(name)
        module = types.ModuleType(name)
        module.__file__ = origin
        module.__package__ = name.rpartition('.')[0]
        # Register before exec so imports inside `source` that reference
        # sibling modules already loaded here can resolve.
        sys.modules[name] = module
        exec(compile(source, origin, 'exec'), module.__dict__)
        return module

    try:
        yield load
    finally:
        # Restore in reverse registration order: submodules before the
        # synthetic parent packages that were created for them.
        for name, previous in reversed(list(previous_modules.items())):
            if previous is _MISSING:
                sys.modules.pop(name, None)
            else:
                sys.modules[name] = previous
|
||||
|
||||
|
||||
def _git_show(repo_root: Path, commit: str, relative_path: str) -> str:
    """Read one file's contents at a pinned commit via ``git show``.

    Raises:
        subprocess.CalledProcessError: if the commit or path is unavailable
            (``check=True`` propagates any non-zero git exit status).
    """
    git_argv = ['git', '-C', str(repo_root), 'show', f'{commit}:{relative_path}']
    completed = subprocess.run(git_argv, check=True, capture_output=True, text=True)
    return completed.stdout
|
||||
|
||||
|
||||
@contextlib.contextmanager
def _load_external_module_or_skip(test_case: unittest.TestCase):
    """Yield the pinned external transformer module, or skip *test_case*.

    Skips when no external ``diffusion_policy`` checkout was found, or when
    the pinned commit cannot be read from it. The external modules are only
    registered in ``sys.modules`` for the duration of the with-block.
    """
    if _EXTERNAL_CHECKOUT_ROOT is None:
        test_case.skipTest('external diffusion_policy checkout unavailable')

    sources: dict[str, str] = {}
    try:
        for module_name, rel_path in _EXTERNAL_MODULE_PATHS.items():
            sources[module_name] = _git_show(
                _EXTERNAL_CHECKOUT_ROOT, _EXTERNAL_COMMIT, rel_path
            )
    except subprocess.CalledProcessError as exc:
        # Prefer git's own stderr; fall back to the exception text when empty.
        detail = exc.stderr.strip() or exc
        test_case.skipTest(
            f'external diffusion_policy commit {_EXTERNAL_COMMIT} is unavailable: {detail}'
        )

    with _temporary_registered_modules() as register_external:
        for module_name, rel_path in _EXTERNAL_MODULE_PATHS.items():
            register_external(
                module_name,
                sources[module_name],
                origin=f'{_EXTERNAL_CHECKOUT_ROOT}:{_EXTERNAL_COMMIT}:{rel_path}',
            )
        yield sys.modules['diffusion_policy.model.diffusion.imf_transformer_for_diffusion']
|
||||
|
||||
|
||||
def _load_local_module():
    """Import the local head module fresh, bypassing any cached copy."""
    importlib.invalidate_caches()
    # Drop a stale entry (if any) so import_module re-executes the source.
    if _LOCAL_MODULE_NAME in sys.modules:
        del sys.modules[_LOCAL_MODULE_NAME]
    return importlib.import_module(_LOCAL_MODULE_NAME)
|
||||
|
||||
|
||||
class IMFTransformer1DExternalAlignmentTest(unittest.TestCase):
    """Checks the local ``IMFTransformer1D`` against the pinned external model.

    Covers constructor defaults, state-dict key compatibility, forward-pass
    numerical agreement, and optimizer-group partitioning.
    """

    def _optim_group_names(self, model, groups):
        """Map each optimizer group to the set of parameter names it holds."""
        name_of = {id(param): name for name, param in model.named_parameters()}
        return [
            {name_of[id(param)] for param in group['params']}
            for group in groups
        ]

    def test_local_defaults_preserve_supported_attnres_config(self):
        ctor_params = inspect.signature(
            _load_local_module().IMFTransformer1D.__init__
        ).parameters

        self.assertEqual(ctor_params['backbone_type'].default, 'attnres_full')
        self.assertEqual(ctor_params['n_head'].default, 1)
        self.assertEqual(ctor_params['n_kv_head'].default, 1)
        self.assertEqual(ctor_params['n_cond_layers'].default, 0)
        self.assertTrue(ctor_params['time_as_cond'].default)
        self.assertFalse(ctor_params['causal_attn'].default)

    def test_attnres_full_state_dict_forward_and_optim_groups_match_external(self):
        local_module = _load_local_module()
        with _load_external_module_or_skip(self) as external_module:
            config = dict(
                input_dim=4,
                output_dim=4,
                horizon=6,
                n_obs_steps=3,
                cond_dim=5,
                n_layer=2,
                n_head=1,
                n_emb=16,
                p_drop_emb=0.0,
                p_drop_attn=0.0,
                causal_attn=False,
                time_as_cond=True,
                n_cond_layers=0,
                backbone_type='attnres_full',
                n_kv_head=1,
            )

            # Seed before construction; the reference model's init is then
            # copied into the local model via load_state_dict below.
            torch.manual_seed(7)
            reference = external_module.IMFTransformerForDiffusion(**config)
            candidate = local_module.IMFTransformer1D(**config)
            reference.eval()
            candidate.eval()

            # Key sets must match exactly, then a strict load must succeed.
            reference_state = reference.state_dict()
            self.assertEqual(
                set(candidate.state_dict().keys()), set(reference_state.keys())
            )
            candidate.load_state_dict(reference_state, strict=True)

            batch_size = 2
            sample = torch.randn(batch_size, config['horizon'], config['input_dim'])
            r = torch.tensor([0.1, 0.4], dtype=torch.float32)
            t = torch.tensor([0.7, 0.9], dtype=torch.float32)
            cond = torch.randn(batch_size, config['n_obs_steps'], config['cond_dim'])

            with torch.no_grad():
                reference_out = reference(sample=sample, r=r, t=t, cond=cond)
                candidate_out = candidate(sample=sample, r=r, t=t, cond=cond)

            self.assertEqual(
                candidate_out.shape,
                (batch_size, config['horizon'], config['output_dim']),
            )
            self.assertEqual(candidate_out.shape, reference_out.shape)
            self.assertTrue(
                torch.allclose(candidate_out, reference_out, atol=1e-6, rtol=1e-5)
            )

            # Optimizer groups: [decayed, non-decayed], identical membership.
            weight_decay = 0.123
            reference_groups = reference.get_optim_groups(weight_decay=weight_decay)
            candidate_groups = candidate.get_optim_groups(weight_decay=weight_decay)

            self.assertEqual(len(candidate_groups), len(reference_groups))
            self.assertEqual(
                [group['weight_decay'] for group in candidate_groups],
                [weight_decay, 0.0],
            )
            self.assertEqual(
                self._optim_group_names(candidate, candidate_groups),
                self._optim_group_names(reference, reference_groups),
            )
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Allow running this file directly without a pytest/unittest runner.
    unittest.main()
|
||||
Reference in New Issue
Block a user