feat: add IMF AttnRes policy training path
This commit is contained in:
@@ -101,10 +101,19 @@ class RecordingTransformerHead(nn.Module):
|
||||
]
|
||||
|
||||
|
||||
class FakeTransformerAgent(nn.Module):
|
||||
class FakeIMFAgent(nn.Module):
    """Agent stub whose head reports a non-'transformer' head type.

    Mirrors the agent surface that ``build_training_optimizer`` reads: a
    ``head_type`` string, a ``noise_pred_net`` head (which the tests show
    records ``weight_decay`` values via ``optim_group_calls``), plus extra
    trainable submodules that should land in the optimizer's remaining-
    parameters group.
    """

    def __init__(self):
        super().__init__()
        # Deliberately not 'transformer': tests assert the custom optim
        # groups are still used when the head exposes its own grouping.
        self.head_type = 'imf_transformer'
        # Head stub; presumably exposes a get_optim_groups-style hook that
        # records calls — defined elsewhere in this file (TODO confirm).
        self.noise_pred_net = RecordingTransformerHead()
        # Trainable params outside the head (expected in the catch-all group).
        self.backbone = nn.Linear(4, 3)
        self.adapter = nn.Linear(3, 2, bias=False)
|
||||
|
||||
|
||||
class FakeTransformerAgent(nn.Module):
    """Agent stub with a configurable ``head_type`` (default 'transformer').

    Same structure as ``FakeIMFAgent`` but the head type is injectable via a
    keyword-only argument, so tests can exercise the optimizer-grouping path
    for arbitrary head-type strings without a second class per variant.
    """

    def __init__(self, *, head_type='transformer'):
        super().__init__()
        # Keyword-only so call sites stay explicit about the variant under test.
        self.head_type = head_type
        # Head stub that records optim-group calls (defined elsewhere in file).
        self.noise_pred_net = RecordingTransformerHead()
        # Trainable non-head params; expected in the remaining-params group.
        self.backbone = nn.Linear(4, 3)
        self.adapter = nn.Linear(3, 2, bias=False)
|
||||
@@ -205,6 +214,47 @@ class TrainVLATransformerOptimizerTest(unittest.TestCase):
|
||||
for group in optimizer.param_groups
|
||||
]
|
||||
|
||||
    def test_configure_cuda_runtime_can_disable_cudnn_for_training(self):
        """``disable_cudnn=True`` in the train config must switch cuDNN off."""
        module = self._load_train_vla_module()
        cfg = AttrDict(train=AttrDict(device='cuda', disable_cudnn=True))

        # torch.backends.cudnn.enabled is process-global state: save it and
        # restore in `finally` so this test cannot leak into other tests.
        original = module.torch.backends.cudnn.enabled
        try:
            # Force the flag on first so the assertion proves the function
            # actually flipped it (rather than it already being False).
            module.torch.backends.cudnn.enabled = True
            module._configure_cuda_runtime(cfg)
            self.assertFalse(module.torch.backends.cudnn.enabled)
        finally:
            module.torch.backends.cudnn.enabled = original
|
||||
|
||||
|
||||
    def test_train_script_uses_file_based_repo_root_on_sys_path(self):
        """Repo root must be derived from __file__, not from existing sys.path."""
        module = self._load_train_vla_module()

        # Start from a sys.path that does NOT contain the repo root, so the
        # only way it can appear is if the function inserts it itself.
        fake_sys_path = ['/tmp/site-packages', '/another/path']
        with mock.patch.object(module.sys, 'path', fake_sys_path):
            repo_root = module._ensure_repo_root_on_syspath()

        # The returned root matches this test file's repo root, and it was
        # prepended (index 0) so it wins over site-packages on import.
        self.assertEqual(Path(repo_root).resolve(), _REPO_ROOT.resolve())
        self.assertEqual(Path(fake_sys_path[0]).resolve(), _REPO_ROOT.resolve())
|
||||
|
||||
|
||||
    def test_non_transformer_head_with_get_optim_groups_still_uses_custom_groups(self):
        """A head with get_optim_groups drives grouping even when head_type
        is not 'transformer' (here: 'imf_transformer')."""
        module = self._load_train_vla_module()
        agent = FakeIMFAgent()

        optimizer = module.build_training_optimizer(agent, lr=1e-4, weight_decay=0.123)

        # The head's grouping hook was invoked exactly once, with the
        # configured weight decay passed through unchanged.
        self.assertEqual(agent.noise_pred_net.optim_group_calls, [0.123])
        group_names = self._group_names(agent, optimizer)
        # Expected layout: group 0 = decayed head weights, group 1 = no-decay
        # head params (biases/norms), group 2 = all remaining trainable
        # parameters on the agent outside the head.
        self.assertEqual(group_names[0], {'noise_pred_net.proj.weight'})
        self.assertEqual(group_names[1], {
            'noise_pred_net.proj.bias',
            'noise_pred_net.norm.weight',
            'noise_pred_net.norm.bias',
        })
        self.assertEqual(group_names[2], {'backbone.weight', 'backbone.bias', 'adapter.weight'})
|
||||
|
||||
|
||||
def test_transformer_training_prefers_head_optim_groups_and_keeps_remaining_trainable_params(self):
|
||||
module = self._load_train_vla_module()
|
||||
agent = FakeTransformerAgent()
|
||||
@@ -268,6 +318,22 @@ class TrainVLATransformerOptimizerTest(unittest.TestCase):
|
||||
self.assertNotIn('frozen.weight', optimizer_names)
|
||||
self.assertNotIn('frozen.bias', optimizer_names)
|
||||
|
||||
    def test_any_head_with_get_optim_groups_uses_custom_groups_even_without_transformer_head_type(self):
        """An arbitrary head_type string ('imf') must not bypass the head's
        custom optim groups — grouping keys off the hook, not the type name."""
        module = self._load_train_vla_module()
        agent = FakeTransformerAgent(head_type='imf')

        # Swap in a recording optimizer class so the constructed param groups
        # can be inspected without touching real AdamW internals.
        with mock.patch.object(module, 'AdamW', RecordingAdamW):
            optimizer = module.build_training_optimizer(agent, lr=1e-4, weight_decay=0.123)

        # Hook called once with the configured weight decay.
        self.assertEqual(agent.noise_pred_net.optim_group_calls, [0.123])
        grouped_names = self._group_names(agent, optimizer)
        # Same three-group layout as the IMF-agent test: decayed head weights,
        # no-decay head params, then everything else trainable on the agent.
        self.assertEqual(grouped_names[0], {'noise_pred_net.proj.weight'})
        self.assertEqual(
            grouped_names[1],
            {'noise_pred_net.proj.bias', 'noise_pred_net.norm.weight', 'noise_pred_net.norm.bias'},
        )
        self.assertEqual(grouped_names[2], {'backbone.weight', 'backbone.bias', 'adapter.weight'})
|
||||
|
||||
def test_transformer_optimizer_ignores_frozen_head_params_returned_by_head_groups(self):
|
||||
module = self._load_train_vla_module()
|
||||
agent = FakeTransformerAgent()
|
||||
|
||||
Reference in New Issue
Block a user