feat: add IMF AttnRes policy training path

This commit is contained in: Logic
2026-04-01 23:35:31 +08:00
parent 8d6060224a
commit c2000b5533
10 changed files with 1566 additions and 11 deletions


@@ -0,0 +1,196 @@
import contextlib
import importlib
import inspect
import subprocess
import sys
import types
import unittest
from pathlib import Path

import torch

_REPO_ROOT = Path(__file__).resolve().parents[1]
if str(_REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(_REPO_ROOT))

_EXTERNAL_COMMIT = '185ed659'
_LOCAL_MODULE_NAME = 'roboimi.vla.models.heads.imf_transformer1d'
_MISSING = object()


def _find_external_checkout_root() -> Path | None:
    """Walk upward from the repo root looking for a sibling diffusion_policy checkout."""
    for ancestor in (_REPO_ROOT, *_REPO_ROOT.parents):
        candidate = ancestor / 'diffusion_policy'
        if (candidate / '.git').exists():
            return candidate
    return None


_EXTERNAL_CHECKOUT_ROOT = _find_external_checkout_root()
_EXTERNAL_MODULE_PATHS = {
    'diffusion_policy.model.common.module_attr_mixin': 'diffusion_policy/model/common/module_attr_mixin.py',
    'diffusion_policy.model.diffusion.positional_embedding': 'diffusion_policy/model/diffusion/positional_embedding.py',
    'diffusion_policy.model.diffusion.attnres_transformer_components': 'diffusion_policy/model/diffusion/attnres_transformer_components.py',
    'diffusion_policy.model.diffusion.imf_transformer_for_diffusion': 'diffusion_policy/model/diffusion/imf_transformer_for_diffusion.py',
}


@contextlib.contextmanager
def _temporary_registered_modules():
    """Yield a loader that registers modules in sys.modules and restores the prior state on exit."""
    previous_modules = {}

    def remember(name: str) -> None:
        if name not in previous_modules:
            previous_modules[name] = sys.modules.get(name, _MISSING)

    def ensure_package(name: str) -> None:
        if not name or name in sys.modules:
            return
        remember(name)
        package = types.ModuleType(name)
        package.__path__ = []
        sys.modules[name] = package

    def load(name: str, source: str, origin: str):
        # Register synthetic parent packages so dotted imports resolve.
        package_parts = name.split('.')[:-1]
        for idx in range(1, len(package_parts) + 1):
            ensure_package('.'.join(package_parts[:idx]))
        remember(name)
        module = types.ModuleType(name)
        module.__file__ = origin
        module.__package__ = name.rpartition('.')[0]
        sys.modules[name] = module
        exec(compile(source, origin, 'exec'), module.__dict__)
        return module

    try:
        yield load
    finally:
        # Undo registrations in reverse order so parent packages outlive children.
        for name, previous in reversed(list(previous_modules.items())):
            if previous is _MISSING:
                sys.modules.pop(name, None)
            else:
                sys.modules[name] = previous


def _git_show(repo_root: Path, commit: str, relative_path: str) -> str:
    result = subprocess.run(
        ['git', '-C', str(repo_root), 'show', f'{commit}:{relative_path}'],
        check=True,
        capture_output=True,
        text=True,
    )
    return result.stdout


@contextlib.contextmanager
def _load_external_module_or_skip(test_case: unittest.TestCase):
    if _EXTERNAL_CHECKOUT_ROOT is None:
        test_case.skipTest('external diffusion_policy checkout unavailable')
    try:
        sources = {
            name: _git_show(_EXTERNAL_CHECKOUT_ROOT, _EXTERNAL_COMMIT, relative_path)
            for name, relative_path in _EXTERNAL_MODULE_PATHS.items()
        }
    except subprocess.CalledProcessError as exc:
        test_case.skipTest(
            f'external diffusion_policy commit {_EXTERNAL_COMMIT} is unavailable: {exc.stderr.strip() or exc}'
        )
    with _temporary_registered_modules() as load_external:
        for name, relative_path in _EXTERNAL_MODULE_PATHS.items():
            load_external(
                name,
                sources[name],
                origin=f'{_EXTERNAL_CHECKOUT_ROOT}:{_EXTERNAL_COMMIT}:{relative_path}',
            )
        yield sys.modules['diffusion_policy.model.diffusion.imf_transformer_for_diffusion']


def _load_local_module():
    importlib.invalidate_caches()
    sys.modules.pop(_LOCAL_MODULE_NAME, None)
    return importlib.import_module(_LOCAL_MODULE_NAME)


class IMFTransformer1DExternalAlignmentTest(unittest.TestCase):
    def _optim_group_names(self, model, groups):
        names_by_param = {id(param): name for name, param in model.named_parameters()}
        return [
            {names_by_param[id(param)] for param in group['params']}
            for group in groups
        ]

    def test_local_defaults_preserve_supported_attnres_config(self):
        local_module = _load_local_module()
        ctor = inspect.signature(local_module.IMFTransformer1D.__init__).parameters
        self.assertEqual(ctor['backbone_type'].default, 'attnres_full')
        self.assertEqual(ctor['n_head'].default, 1)
        self.assertEqual(ctor['n_kv_head'].default, 1)
        self.assertEqual(ctor['n_cond_layers'].default, 0)
        self.assertTrue(ctor['time_as_cond'].default)
        self.assertFalse(ctor['causal_attn'].default)

    def test_attnres_full_state_dict_forward_and_optim_groups_match_external(self):
        local_module = _load_local_module()
        with _load_external_module_or_skip(self) as external_module:
            config = dict(
                input_dim=4,
                output_dim=4,
                horizon=6,
                n_obs_steps=3,
                cond_dim=5,
                n_layer=2,
                n_head=1,
                n_emb=16,
                p_drop_emb=0.0,
                p_drop_attn=0.0,
                causal_attn=False,
                time_as_cond=True,
                n_cond_layers=0,
                backbone_type='attnres_full',
                n_kv_head=1,
            )
            torch.manual_seed(7)
            external_model = external_module.IMFTransformerForDiffusion(**config)
            local_model = local_module.IMFTransformer1D(**config)
            external_model.eval()
            local_model.eval()
            external_state_dict = external_model.state_dict()
            self.assertEqual(set(local_model.state_dict().keys()), set(external_state_dict.keys()))
            local_model.load_state_dict(external_state_dict, strict=True)
            batch_size = 2
            sample = torch.randn(batch_size, config['horizon'], config['input_dim'])
            r = torch.tensor([0.1, 0.4], dtype=torch.float32)
            t = torch.tensor([0.7, 0.9], dtype=torch.float32)
            cond = torch.randn(batch_size, config['n_obs_steps'], config['cond_dim'])
            with torch.no_grad():
                external_out = external_model(sample=sample, r=r, t=t, cond=cond)
                local_out = local_model(sample=sample, r=r, t=t, cond=cond)
            self.assertEqual(local_out.shape, (batch_size, config['horizon'], config['output_dim']))
            self.assertEqual(local_out.shape, external_out.shape)
            self.assertTrue(torch.allclose(local_out, external_out, atol=1e-6, rtol=1e-5))
            weight_decay = 0.123
            external_groups = external_model.get_optim_groups(weight_decay=weight_decay)
            local_groups = local_model.get_optim_groups(weight_decay=weight_decay)
            self.assertEqual(len(local_groups), len(external_groups))
            self.assertEqual([group['weight_decay'] for group in local_groups], [weight_decay, 0.0])
            self.assertEqual(
                self._optim_group_names(local_model, local_groups),
                self._optim_group_names(external_model, external_groups),
            )


if __name__ == '__main__':
    unittest.main()
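
Note: the fixture above distills to a reusable pattern: fetch a module's source at a pinned commit via git show, exec it under a synthetic module name, and restore sys.modules afterwards. A minimal standalone sketch of the same idea follows; the function name and arguments are illustrative, not part of the commit, and the sys.modules rollback that _temporary_registered_modules performs is omitted for brevity.

import subprocess
import sys
import types


def load_module_from_commit(repo_root, commit, relative_path, module_name):
    # Read the file exactly as it existed at the pinned commit.
    source = subprocess.run(
        ['git', '-C', repo_root, 'show', f'{commit}:{relative_path}'],
        check=True, capture_output=True, text=True,
    ).stdout
    module = types.ModuleType(module_name)
    module.__file__ = f'{repo_root}:{commit}:{relative_path}'
    # Register before exec so imports inside the source can resolve the name.
    sys.modules[module_name] = module
    exec(compile(source, module.__file__, 'exec'), module.__dict__)
    return module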

tests/test_imf_vla_agent.py (new file, 427 additions)

@@ -0,0 +1,427 @@
import contextlib
import importlib
import sys
import types
import unittest
from pathlib import Path
from unittest import mock

import torch
from hydra import compose, initialize_config_dir
from hydra.core.global_hydra import GlobalHydra
from hydra.utils import instantiate
from omegaconf import OmegaConf
from torch import nn

_REPO_ROOT = Path(__file__).resolve().parents[1]
_CONFIG_DIR = str((_REPO_ROOT / 'roboimi/vla/conf').resolve())
_MISSING = object()
_CAMERA_NAMES = ('r_vis', 'top', 'front')


class _FakeScheduler:
    def __init__(self, num_train_timesteps=100, **kwargs):
        self.config = types.SimpleNamespace(num_train_timesteps=num_train_timesteps)
        self.timesteps = []

    def add_noise(self, sample, noise, timestep):
        return sample + noise

    def set_timesteps(self, num_inference_steps):
        self.timesteps = list(range(num_inference_steps - 1, -1, -1))

    def step(self, noise_pred, timestep, sample):
        return types.SimpleNamespace(prev_sample=sample)


class _IdentityCrop:
    def __init__(self, size):
        self.size = size

    def __call__(self, x):
        return x


class _FakeResNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(8, 16, kernel_size=3, padding=1, stride=2)
        self.relu2 = nn.ReLU()
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(16, 16)

    def forward(self, x):
        x = self.relu1(self.conv1(x))
        x = self.relu2(self.conv2(x))
        x = self.avgpool(x)
        x = torch.flatten(x, start_dim=1)
        return self.fc(x)


class _FakeRearrange(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()

    def forward(self, x):
        return x


class _StubIMFHead(nn.Module):
    def __init__(
        self,
        input_dim,
        output_dim,
        horizon,
        n_obs_steps,
        cond_dim,
        **kwargs,
    ):
        super().__init__()
        self.constructor_kwargs = {
            'input_dim': input_dim,
            'output_dim': output_dim,
            'horizon': horizon,
            'n_obs_steps': n_obs_steps,
            'cond_dim': cond_dim,
            **kwargs,
        }
        self.proj = nn.Linear(input_dim, output_dim)
        self.cond_obs_emb = nn.Linear(cond_dim, max(cond_dim, 1))

    def forward(self, sample, r, t, cond=None):
        return torch.zeros_like(sample)

    def get_optim_groups(self, weight_decay):
        return [
            {'params': [self.proj.weight], 'weight_decay': weight_decay},
            {'params': [self.proj.bias, self.cond_obs_emb.weight, self.cond_obs_emb.bias], 'weight_decay': 0.0},
        ]


@contextlib.contextmanager
def _stub_optional_modules(include_imf_head=False):
    """Install lightweight stand-ins for diffusers, torchvision, and einops; restore sys.modules on exit."""
    previous_modules = {}

    def inject(name, module):
        if name not in previous_modules:
            previous_modules[name] = sys.modules.get(name, _MISSING)
        sys.modules[name] = module

    diffusers_module = types.ModuleType('diffusers')
    schedulers_module = types.ModuleType('diffusers.schedulers')
    ddpm_module = types.ModuleType('diffusers.schedulers.scheduling_ddpm')
    ddim_module = types.ModuleType('diffusers.schedulers.scheduling_ddim')
    ddpm_module.DDPMScheduler = _FakeScheduler
    ddim_module.DDIMScheduler = _FakeScheduler
    diffusers_module.DDPMScheduler = _FakeScheduler
    diffusers_module.DDIMScheduler = _FakeScheduler
    diffusers_module.schedulers = schedulers_module
    schedulers_module.scheduling_ddpm = ddpm_module
    schedulers_module.scheduling_ddim = ddim_module
    torchvision_module = types.ModuleType('torchvision')
    models_module = types.ModuleType('torchvision.models')
    transforms_module = types.ModuleType('torchvision.transforms')
    models_module.resnet18 = lambda weights=None: _FakeResNet()
    transforms_module.CenterCrop = _IdentityCrop
    transforms_module.RandomCrop = _IdentityCrop
    torchvision_module.models = models_module
    torchvision_module.transforms = transforms_module
    einops_module = types.ModuleType('einops')
    einops_module.rearrange = lambda x, *args, **kwargs: x
    einops_layers_module = types.ModuleType('einops.layers')
    einops_layers_torch_module = types.ModuleType('einops.layers.torch')
    einops_layers_torch_module.Rearrange = _FakeRearrange
    einops_module.layers = einops_layers_module
    einops_layers_module.torch = einops_layers_torch_module
    try:
        inject('diffusers', diffusers_module)
        inject('diffusers.schedulers', schedulers_module)
        inject('diffusers.schedulers.scheduling_ddpm', ddpm_module)
        inject('diffusers.schedulers.scheduling_ddim', ddim_module)
        inject('torchvision', torchvision_module)
        inject('torchvision.models', models_module)
        inject('torchvision.transforms', transforms_module)
        inject('einops', einops_module)
        inject('einops.layers', einops_layers_module)
        inject('einops.layers.torch', einops_layers_torch_module)
        if include_imf_head:
            import roboimi.vla.models.heads as heads_package

            imf_head_module = types.ModuleType('roboimi.vla.models.heads.imf_transformer1d')
            imf_head_module.IMFTransformer1D = _StubIMFHead
            inject('roboimi.vla.models.heads.imf_transformer1d', imf_head_module)
            setattr(heads_package, 'imf_transformer1d', imf_head_module)
        yield
    finally:
        for name, previous in reversed(list(previous_modules.items())):
            if previous is _MISSING:
                sys.modules.pop(name, None)
            else:
                sys.modules[name] = previous


def _compose_cfg(overrides=None):
    if not OmegaConf.has_resolver('len'):
        OmegaConf.register_new_resolver('len', lambda x: len(x))
    GlobalHydra.instance().clear()
    with initialize_config_dir(version_base=None, config_dir=_CONFIG_DIR):
        return compose(config_name='config', overrides=list(overrides or []))


def _load_imf_agent_class():
    with _stub_optional_modules():
        sys.modules.pop('roboimi.vla.agent_imf', None)
        module = importlib.import_module('roboimi.vla.agent_imf')
    return module.IMFVLAAgent, module


class _StubVisionBackbone(nn.Module):
    output_dim = 1

    def __init__(self, camera_names=_CAMERA_NAMES):
        super().__init__()
        self.camera_names = tuple(camera_names)
        self.num_cameras = len(self.camera_names)

    def forward(self, images):
        per_camera_features = []
        for camera_name in self.camera_names:
            image_batch = images[camera_name]
            per_camera_features.append(image_batch.mean(dim=(2, 3, 4), keepdim=False).unsqueeze(-1))
        return torch.cat(per_camera_features, dim=-1)


class _RecordingLinearIMFHead(nn.Module):
    """Analytic head: u(z, r, t, cond) = scale * z + r + 2 * t + mean(cond); records every call."""

    def __init__(self):
        super().__init__()
        self.scale = nn.Parameter(torch.tensor(0.5))
        self.calls = []

    @staticmethod
    def _broadcast_batch_time(value, reference):
        while value.ndim < reference.ndim:
            value = value.unsqueeze(-1)
        return value

    def forward(self, sample, r, t, cond=None):
        record = {
            'sample': sample.detach().clone(),
            'r': r.detach().clone(),
            't': t.detach().clone(),
            'cond': None if cond is None else cond.detach().clone(),
        }
        self.calls.append(record)
        cond_term = 0.0
        if cond is not None:
            cond_term = cond.mean(dim=(1, 2), keepdim=True)
        r_b = self._broadcast_batch_time(r, sample)
        t_b = self._broadcast_batch_time(t, sample)
        return self.scale * sample + r_b + 2.0 * t_b + cond_term


class _ForbiddenScheduler:
    def set_timesteps(self, *args, **kwargs):  # pragma: no cover - only runs on regression
        raise AssertionError('IMF inference should not use DDIM scheduler set_timesteps')

    def step(self, *args, **kwargs):  # pragma: no cover - only runs on regression
        raise AssertionError('IMF inference should not use DDIM scheduler step')


def _make_images(batch_size, obs_horizon, per_camera_fill):
    return {
        name: torch.full((batch_size, obs_horizon, 1, 2, 2), fill_value=value, dtype=torch.float32)
        for name, value in per_camera_fill.items()
    }


class IMFVLAAgentTest(unittest.TestCase):
    def _make_agent(self, pred_horizon=3, obs_horizon=2, num_action_steps=2):
        agent_cls, agent_module = _load_imf_agent_class()
        head = _RecordingLinearIMFHead()
        agent = agent_cls(
            vision_backbone=_StubVisionBackbone(),
            state_encoder=nn.Identity(),
            action_encoder=nn.Identity(),
            head=head,
            action_dim=2,
            obs_dim=1,
            pred_horizon=pred_horizon,
            obs_horizon=obs_horizon,
            diffusion_steps=10,
            inference_steps=1,
            num_cams=len(_CAMERA_NAMES),
            camera_names=_CAMERA_NAMES,
            num_action_steps=num_action_steps,
            head_type='transformer',
        )
        return agent, head, agent_module

    def test_compute_loss_matches_imf_objective_and_masks_padded_actions(self):
        agent, head, agent_module = self._make_agent(pred_horizon=3, obs_horizon=2)
        images = _make_images(
            batch_size=1,
            obs_horizon=2,
            per_camera_fill={'r_vis': 1.0, 'top': 2.0, 'front': 3.0},
        )
        qpos = torch.tensor([[[0.25], [0.75]]], dtype=torch.float32)
        actions = torch.tensor(
            [[[1.0, -1.0], [0.5, 0.25], [-0.5, 1.5]]],
            dtype=torch.float32,
        )
        action_is_pad = torch.tensor([[False, False, True]])
        noise = torch.tensor(
            [[[0.2, -0.4], [0.1, 0.3], [0.5, -0.2]]],
            dtype=torch.float32,
        )
        t_sample = torch.tensor([0.8], dtype=torch.float32)
        r_sample = torch.tensor([0.25], dtype=torch.float32)
        with mock.patch.object(agent_module.torch, 'randn_like', return_value=noise), \
                mock.patch.object(agent_module.torch, 'rand', side_effect=[t_sample, r_sample]):
            loss = agent.compute_loss(
                {
                    'images': images,
                    'qpos': qpos,
                    'action': actions,
                    'action_is_pad': action_is_pad,
                }
            )
        cond = torch.tensor([[[1.0, 2.0, 3.0, 0.25], [1.0, 2.0, 3.0, 0.75]]], dtype=torch.float32)
        cond_term = cond.mean(dim=(1, 2), keepdim=True)
        t = t_sample
        r = r_sample
        # Reproduce the head's analytic mean velocity u, instantaneous velocity
        # v = u(z_t, r=t, t), and total derivative du/dt along the interpolation.
        z_t = (1 - t.view(1, 1, 1)) * actions + t.view(1, 1, 1) * noise
        scale = head.scale.detach()
        u = scale * z_t + r.view(1, 1, 1) + 2.0 * t.view(1, 1, 1) + cond_term
        v = scale * z_t + 3.0 * t.view(1, 1, 1) + cond_term
        du_dt = scale * v + 2.0
        compound_velocity = u + (t - r).view(1, 1, 1) * du_dt
        target = noise - actions
        elementwise_loss = (compound_velocity - target) ** 2
        mask = (~action_is_pad).unsqueeze(-1).to(elementwise_loss.dtype)
        expected_loss = (elementwise_loss * mask).sum() / (mask.sum() * elementwise_loss.shape[-1])
        self.assertAlmostEqual(loss.item(), expected_loss.item(), places=6)
        self.assertEqual(len(head.calls), 2)
        self.assertTrue(torch.allclose(head.calls[0]['r'], t_sample))
        self.assertTrue(torch.allclose(head.calls[0]['t'], t_sample))
        self.assertTrue(torch.allclose(head.calls[0]['cond'], cond))

    def test_predict_action_uses_one_step_imf_sampling_and_image_conditioning(self):
        agent, head, agent_module = self._make_agent(pred_horizon=3, obs_horizon=2)
        agent.infer_scheduler = _ForbiddenScheduler()
        images = _make_images(
            batch_size=2,
            obs_horizon=2,
            per_camera_fill={'r_vis': 10.0, 'top': 20.0, 'front': 30.0},
        )
        qpos = torch.tensor(
            [
                [[1.0], [2.0]],
                [[3.0], [4.0]],
            ],
            dtype=torch.float32,
        )
        initial_noise = torch.tensor(
            [
                [[1.0, -1.0], [0.0, 2.0], [3.0, -2.0]],
                [[-1.0, 1.0], [2.0, -3.0], [0.5, 0.25]],
            ],
            dtype=torch.float32,
        )
        with mock.patch.object(agent_module.torch, 'randn', return_value=initial_noise):
            predicted_actions = agent.predict_action(images, qpos)
        expected_cond = torch.tensor(
            [
                [[10.0, 20.0, 30.0, 1.0], [10.0, 20.0, 30.0, 2.0]],
                [[10.0, 20.0, 30.0, 3.0], [10.0, 20.0, 30.0, 4.0]],
            ],
            dtype=torch.float32,
        )
        cond_term = expected_cond.mean(dim=(1, 2), keepdim=True)
        # One-step IMF sampling with the analytic head: x0 = z1 - u(z1, r=0, t=1, cond).
        expected_actions = 0.5 * initial_noise - 2.0 - cond_term
        self.assertEqual(predicted_actions.shape, (2, 3, 2))
        self.assertTrue(torch.allclose(predicted_actions, expected_actions))
        self.assertEqual(len(head.calls), 1)
        self.assertTrue(torch.allclose(head.calls[0]['r'], torch.zeros(2)))
        self.assertTrue(torch.allclose(head.calls[0]['t'], torch.ones(2)))
        self.assertTrue(torch.allclose(head.calls[0]['cond'], expected_cond))

    def test_select_action_only_regenerates_when_action_queue_is_empty(self):
        agent, _head, _agent_module = self._make_agent(pred_horizon=4, obs_horizon=2, num_action_steps=2)
        observation = {
            'qpos': torch.tensor([0.25], dtype=torch.float32),
            'images': {
                'front': torch.full((1, 2, 2), 3.0, dtype=torch.float32),
                'top': torch.full((1, 2, 2), 2.0, dtype=torch.float32),
                'r_vis': torch.full((1, 2, 2), 1.0, dtype=torch.float32),
            },
        }
        first_chunk = torch.tensor(
            [[[10.0, 11.0], [12.0, 13.0], [14.0, 15.0], [16.0, 17.0]]],
            dtype=torch.float32,
        )
        second_chunk = torch.tensor(
            [[[20.0, 21.0], [22.0, 23.0], [24.0, 25.0], [26.0, 27.0]]],
            dtype=torch.float32,
        )
        with mock.patch.object(agent, 'predict_action_chunk', side_effect=[first_chunk, second_chunk]) as mock_predict_chunk:
            first_action = agent.select_action(observation)
            second_action = agent.select_action(observation)
            third_action = agent.select_action(observation)
        self.assertTrue(torch.equal(first_action, first_chunk[0, 1]))
        self.assertTrue(torch.equal(second_action, first_chunk[0, 2]))
        self.assertTrue(torch.equal(third_action, second_chunk[0, 1]))
        self.assertEqual(mock_predict_chunk.call_count, 2)

    def test_hydra_config_instantiates_resnet_imf_attnres_with_stub_head(self):
        cfg = _compose_cfg(
            overrides=[
                'agent=resnet_imf_attnres',
                'agent.vision_backbone.pretrained_backbone_weights=null',
                'agent.vision_backbone.input_shape=[3,16,16]',
                'agent.vision_backbone.freeze_backbone=false',
                'agent.head.n_layer=1',
                'agent.head.n_emb=16',
            ]
        )
        self.assertEqual(cfg.agent._target_, 'roboimi.vla.agent_imf.IMFVLAAgent')
        self.assertEqual(cfg.agent.head._target_, 'roboimi.vla.models.heads.imf_transformer1d.IMFTransformer1D')
        self.assertEqual(cfg.agent.head.backbone_type, 'attnres_full')
        self.assertEqual(cfg.agent.head.n_head, 1)
        self.assertEqual(cfg.agent.head.n_kv_head, 1)
        self.assertEqual(cfg.agent.head.n_cond_layers, 0)
        self.assertTrue(cfg.agent.head.time_as_cond)
        self.assertFalse(cfg.agent.head.causal_attn)
        self.assertEqual(cfg.agent.inference_steps, 1)
        self.assertEqual(list(cfg.agent.camera_names), list(_CAMERA_NAMES))
        with _stub_optional_modules(include_imf_head=True):
            agent = instantiate(cfg.agent)
        self.assertEqual(agent.head_type, 'transformer')
        self.assertEqual(agent.per_step_cond_dim, agent.vision_encoder.output_dim * agent.num_cams + agent.obs_dim)
        self.assertIsInstance(agent.noise_pred_net, _StubIMFHead)
        self.assertEqual(agent.noise_pred_net.constructor_kwargs['cond_dim'], agent.per_step_cond_dim)
        self.assertEqual(agent.noise_pred_net.constructor_kwargs['backbone_type'], 'attnres_full')


if __name__ == '__main__':
    unittest.main()
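
The two agent tests above pin down the IMF math: training matches the head's mean velocity plus a total-derivative correction against the straight-path velocity noise - actions, and inference is a single step x0 = z1 - u(z1, r=0, t=1, cond). A minimal sketch of that objective, assuming the head signature head(sample, r, t, cond) used throughout; the torch.func.jvp route to du/dt and the r-sampling scheme are assumptions, since the tests only fix specific draws and the resulting value scale * v + 2 for the linear stub head.

import torch


def imf_training_loss(head, actions, cond, action_is_pad):
    """Sketch of the IMF objective checked by the loss test, under stated assumptions."""
    batch = actions.shape[0]
    noise = torch.randn_like(actions)
    t = torch.rand(batch)
    r = torch.rand(batch) * t  # assumption: r <= t; only the test's fixed draws are known
    t_b, r_b = t.view(-1, 1, 1), r.view(-1, 1, 1)
    z_t = (1 - t_b) * actions + t_b * noise

    # First head call: instantaneous velocity v(z_t, t) = u(z_t, r=t, t).
    v = head(z_t, t, t, cond=cond)

    # Second head call under forward-mode AD: total derivative du/dt with
    # tangents dz_t/dt = v, dr/dt = 0, dt/dt = 1 (assumed jvp-based).
    u, du_dt = torch.func.jvp(
        lambda z, r_, t_: head(z, r_, t_, cond=cond),
        (z_t, r, t),
        (v, torch.zeros_like(r), torch.ones_like(t)),
    )

    compound_velocity = u + (t_b - r_b) * du_dt
    target = noise - actions  # velocity of the straight interpolation path
    sq_err = (compound_velocity - target) ** 2
    mask = (~action_is_pad).unsqueeze(-1).to(sq_err.dtype)
    return (sq_err * mask).sum() / (mask.sum() * sq_err.shape[-1])


def imf_one_step_sample(head, cond, horizon, action_dim):
    """One-step sampling asserted by the inference test: x0 = z1 - u(z1, r=0, t=1)."""
    batch = cond.shape[0]
    z1 = torch.randn(batch, horizon, action_dim)
    u = head(z1, torch.zeros(batch), torch.ones(batch), cond=cond)
    return z1 - u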


@@ -101,10 +101,19 @@ class RecordingTransformerHead(nn.Module):
         ]
-class FakeTransformerAgent(nn.Module):
+class FakeIMFAgent(nn.Module):
     def __init__(self):
         super().__init__()
-        self.head_type = 'transformer'
+        self.head_type = 'imf_transformer'
         self.noise_pred_net = RecordingTransformerHead()
         self.backbone = nn.Linear(4, 3)
         self.adapter = nn.Linear(3, 2, bias=False)
+
+
+class FakeTransformerAgent(nn.Module):
+    def __init__(self, *, head_type='transformer'):
+        super().__init__()
+        self.head_type = head_type
+        self.noise_pred_net = RecordingTransformerHead()
+        self.backbone = nn.Linear(4, 3)
+        self.adapter = nn.Linear(3, 2, bias=False)
@@ -205,6 +214,47 @@ class TrainVLATransformerOptimizerTest(unittest.TestCase):
             for group in optimizer.param_groups
         ]
+
+    def test_configure_cuda_runtime_can_disable_cudnn_for_training(self):
+        module = self._load_train_vla_module()
+        cfg = AttrDict(train=AttrDict(device='cuda', disable_cudnn=True))
+        original = module.torch.backends.cudnn.enabled
+        try:
+            module.torch.backends.cudnn.enabled = True
+            module._configure_cuda_runtime(cfg)
+            self.assertFalse(module.torch.backends.cudnn.enabled)
+        finally:
+            module.torch.backends.cudnn.enabled = original
+
+    def test_train_script_uses_file_based_repo_root_on_sys_path(self):
+        module = self._load_train_vla_module()
+        fake_sys_path = ['/tmp/site-packages', '/another/path']
+        with mock.patch.object(module.sys, 'path', fake_sys_path):
+            repo_root = module._ensure_repo_root_on_syspath()
+        self.assertEqual(Path(repo_root).resolve(), _REPO_ROOT.resolve())
+        self.assertEqual(Path(fake_sys_path[0]).resolve(), _REPO_ROOT.resolve())
+
+    def test_non_transformer_head_with_get_optim_groups_still_uses_custom_groups(self):
+        module = self._load_train_vla_module()
+        agent = FakeIMFAgent()
+        optimizer = module.build_training_optimizer(agent, lr=1e-4, weight_decay=0.123)
+        self.assertEqual(agent.noise_pred_net.optim_group_calls, [0.123])
+        group_names = self._group_names(agent, optimizer)
+        self.assertEqual(group_names[0], {'noise_pred_net.proj.weight'})
+        self.assertEqual(group_names[1], {
+            'noise_pred_net.proj.bias',
+            'noise_pred_net.norm.weight',
+            'noise_pred_net.norm.bias',
+        })
+        self.assertEqual(group_names[2], {'backbone.weight', 'backbone.bias', 'adapter.weight'})
+
     def test_transformer_training_prefers_head_optim_groups_and_keeps_remaining_trainable_params(self):
         module = self._load_train_vla_module()
         agent = FakeTransformerAgent()
@@ -268,6 +318,22 @@ class TrainVLATransformerOptimizerTest(unittest.TestCase):
         self.assertNotIn('frozen.weight', optimizer_names)
         self.assertNotIn('frozen.bias', optimizer_names)
+
+    def test_any_head_with_get_optim_groups_uses_custom_groups_even_without_transformer_head_type(self):
+        module = self._load_train_vla_module()
+        agent = FakeTransformerAgent(head_type='imf')
+        with mock.patch.object(module, 'AdamW', RecordingAdamW):
+            optimizer = module.build_training_optimizer(agent, lr=1e-4, weight_decay=0.123)
+        self.assertEqual(agent.noise_pred_net.optim_group_calls, [0.123])
+        grouped_names = self._group_names(agent, optimizer)
+        self.assertEqual(grouped_names[0], {'noise_pred_net.proj.weight'})
+        self.assertEqual(
+            grouped_names[1],
+            {'noise_pred_net.proj.bias', 'noise_pred_net.norm.weight', 'noise_pred_net.norm.bias'},
+        )
+        self.assertEqual(grouped_names[2], {'backbone.weight', 'backbone.bias', 'adapter.weight'})
+
     def test_transformer_optimizer_ignores_frozen_head_params_returned_by_head_groups(self):
         module = self._load_train_vla_module()
         agent = FakeTransformerAgent()
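
For orientation, a plausible shape of build_training_optimizer consistent with the optimizer-grouping tests above. This is a reconstruction from the assertions, not the committed implementation: any head exposing get_optim_groups supplies its own parameter groups regardless of head_type, frozen parameters are dropped even when the head returns them, and the remaining trainable parameters share one trailing group.

from torch.optim import AdamW


def build_training_optimizer(agent, lr, weight_decay):
    # Sketch: heads that expose get_optim_groups() own the grouping of their
    # parameters, whatever agent.head_type says.
    head = getattr(agent, 'noise_pred_net', None)
    groups, head_param_ids = [], set()
    if head is not None and hasattr(head, 'get_optim_groups'):
        for group in head.get_optim_groups(weight_decay=weight_decay):
            # Drop frozen parameters even if the head returns them.
            params = [p for p in group['params'] if p.requires_grad]
            head_param_ids.update(id(p) for p in params)
            groups.append({'params': params, 'weight_decay': group['weight_decay']})
    # Remaining trainable parameters (backbone, adapters, ...) share one group;
    # frozen modules are excluded by the requires_grad filter.
    remaining = [
        p for p in agent.parameters()
        if p.requires_grad and id(p) not in head_param_ids
    ]
    if remaining:
        groups.append({'params': remaining, 'weight_decay': weight_decay})
    return AdamW(groups, lr=lr)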