feat: add IMF AttnRes policy training path

This commit is contained in:
Logic
2026-04-01 23:35:31 +08:00
parent 8d6060224a
commit c2000b5533
10 changed files with 1566 additions and 11 deletions

View File

@@ -101,10 +101,19 @@ class RecordingTransformerHead(nn.Module):
]
class FakeIMFAgent(nn.Module):
    """Test double for an agent whose ``head_type`` is ``'imf_transformer'``.

    Carries a ``RecordingTransformerHead`` as ``noise_pred_net`` plus two
    small linear layers (``backbone`` and ``adapter``), which the optimizer
    tests use to check how parameters are distributed across groups.
    """

    def __init__(self):
        super().__init__()
        # Fixed head type; the stale 'transformer' assignment that used to
        # precede this line was overwritten immediately and has been removed.
        self.head_type = 'imf_transformer'
        self.noise_pred_net = RecordingTransformerHead()
        self.backbone = nn.Linear(4, 3)
        # The adapter has no bias, so it contributes only ``adapter.weight``.
        self.adapter = nn.Linear(3, 2, bias=False)


class FakeTransformerAgent(nn.Module):
    """Test double for an agent with a configurable ``head_type``.

    Args:
        head_type: Keyword-only label stored on the instance as
            ``self.head_type``. Defaults to ``'transformer'``.

    The submodules match ``FakeIMFAgent``: a ``RecordingTransformerHead`` as
    ``noise_pred_net``, a ``backbone`` linear layer, and a bias-free
    ``adapter`` linear layer.
    """

    def __init__(self, *, head_type='transformer'):
        super().__init__()
        self.head_type = head_type
        self.noise_pred_net = RecordingTransformerHead()
        self.backbone = nn.Linear(4, 3)
        # The adapter has no bias, so it contributes only ``adapter.weight``.
        self.adapter = nn.Linear(3, 2, bias=False)
@@ -205,6 +214,47 @@ class TrainVLATransformerOptimizerTest(unittest.TestCase):
for group in optimizer.param_groups
]
def test_configure_cuda_runtime_can_disable_cudnn_for_training(self):
    """``_configure_cuda_runtime`` must switch cuDNN off when asked to.

    The config requests ``device='cuda'`` with ``disable_cudnn=True``. The
    test forces ``cudnn.enabled`` to ``True`` first, so the assertion can only
    pass if the call under test flips it.
    """
    train_vla = self._load_train_vla_module()
    config = AttrDict(train=AttrDict(device='cuda', disable_cudnn=True))
    cudnn = train_vla.torch.backends.cudnn
    saved_enabled = cudnn.enabled
    try:
        cudnn.enabled = True
        train_vla._configure_cuda_runtime(config)
        self.assertFalse(cudnn.enabled)
    finally:
        # Restore the flag so later tests see the value they started with.
        cudnn.enabled = saved_enabled
def test_train_script_uses_file_based_repo_root_on_sys_path(self):
    """``_ensure_repo_root_on_syspath`` returns the repo root and puts it first.

    ``sys.path`` on the loaded module is swapped for a throwaway list, so the
    test can inspect what was placed at index 0 without touching the real
    interpreter path.
    """
    train_vla = self._load_train_vla_module()
    patched_path = ['/tmp/site-packages', '/another/path']
    with mock.patch.object(train_vla.sys, 'path', patched_path):
        returned_root = train_vla._ensure_repo_root_on_syspath()

    # Both the return value and the new head of the path list must resolve
    # to the repository root.
    self.assertEqual(Path(returned_root).resolve(), _REPO_ROOT.resolve())
    self.assertEqual(Path(patched_path[0]).resolve(), _REPO_ROOT.resolve())
def test_non_transformer_head_with_get_optim_groups_still_uses_custom_groups(self):
    """A ``FakeIMFAgent`` head is still asked for its own optimizer groups.

    Verifies that ``build_training_optimizer`` forwards ``weight_decay`` to
    the head (recorded in ``optim_group_calls``) and that the first three
    parameter groups contain the expected parameter names.
    """
    train_vla = self._load_train_vla_module()
    imf_agent = FakeIMFAgent()
    built = train_vla.build_training_optimizer(imf_agent, lr=1e-4, weight_decay=0.123)

    self.assertEqual(imf_agent.noise_pred_net.optim_group_calls, [0.123])

    expected_groups = (
        {'noise_pred_net.proj.weight'},
        {
            'noise_pred_net.proj.bias',
            'noise_pred_net.norm.weight',
            'noise_pred_net.norm.bias',
        },
        {'backbone.weight', 'backbone.bias', 'adapter.weight'},
    )
    actual_groups = self._group_names(imf_agent, built)
    # Checked in order; the first mismatch fails the test.
    for position, expected in enumerate(expected_groups):
        self.assertEqual(actual_groups[position], expected)
def test_transformer_training_prefers_head_optim_groups_and_keeps_remaining_trainable_params(self):
module = self._load_train_vla_module()
agent = FakeTransformerAgent()
@@ -268,6 +318,22 @@ class TrainVLATransformerOptimizerTest(unittest.TestCase):
self.assertNotIn('frozen.weight', optimizer_names)
self.assertNotIn('frozen.bias', optimizer_names)
def test_any_head_with_get_optim_groups_uses_custom_groups_even_without_transformer_head_type(self):
    """Head-provided optimizer groups are used when ``head_type`` is ``'imf'``.

    ``AdamW`` on the loaded module is replaced by ``RecordingAdamW`` while the
    optimizer is built. The test then checks that the head received the
    ``weight_decay`` value and that the first three parameter groups hold the
    expected parameter names.
    """
    train_vla = self._load_train_vla_module()
    imf_agent = FakeTransformerAgent(head_type='imf')

    with mock.patch.object(train_vla, 'AdamW', RecordingAdamW):
        built = train_vla.build_training_optimizer(imf_agent, lr=1e-4, weight_decay=0.123)

    self.assertEqual(imf_agent.noise_pred_net.optim_group_calls, [0.123])

    expected_groups = (
        {'noise_pred_net.proj.weight'},
        {'noise_pred_net.proj.bias', 'noise_pred_net.norm.weight', 'noise_pred_net.norm.bias'},
        {'backbone.weight', 'backbone.bias', 'adapter.weight'},
    )
    actual_groups = self._group_names(imf_agent, built)
    # Checked in order; the first mismatch fails the test.
    for position, expected in enumerate(expected_groups):
        self.assertEqual(actual_groups[position], expected)
def test_transformer_optimizer_ignores_frozen_head_params_returned_by_head_groups(self):
module = self._load_train_vla_module()
agent = FakeTransformerAgent()