feat: add held-out validation and dual-decoder lewm imf

This commit is contained in:
Logic
2026-04-17 19:26:56 +08:00
parent 74f4963613
commit 395f5a1645
8 changed files with 693 additions and 86 deletions

View File

@@ -388,6 +388,26 @@ class _StubFutureTokenPredictor(nn.Module):
return summary.repeat(1, self.num_future_tokens, 1)
class _RecordingDirectFutureDecoder(nn.Module):
def __init__(self):
super().__init__()
self.scale = nn.Parameter(torch.tensor(0.5))
self.calls = []
def forward(self, sample, r, t, cond=None):
record = {
'sample': sample.detach().clone(),
'r': r.detach().clone(),
't': t.detach().clone(),
'cond': None if cond is None else cond.detach().clone(),
}
self.calls.append(record)
cond_term = 0.0
if cond is not None:
cond_term = cond.mean(dim=1, keepdim=True)
return self.scale * sample + cond_term
class _RecordingSigReg(nn.Module):
def __init__(self, value=0.5):
super().__init__()
@@ -687,6 +707,148 @@ class IMFVLAAgentTest(unittest.TestCase):
)
torch.testing.assert_close(sigreg.calls[0], expected_lewm_history.transpose(0, 1))
def test_predict_action_with_dual_decoder_keeps_action_condition_history_only(self):
agent_cls, agent_module = _load_imf_agent_class()
head = _RecordingLinearIMFHead()
future_decoder = _RecordingDirectFutureDecoder()
agent = agent_cls(
vision_backbone=_StubVisionBackbone(),
state_encoder=nn.Identity(),
action_encoder=nn.Identity(),
head=head,
future_decoder=future_decoder,
action_dim=2,
obs_dim=1,
pred_horizon=3,
obs_horizon=2,
diffusion_steps=10,
inference_steps=1,
num_cams=len(_CAMERA_NAMES),
camera_names=_CAMERA_NAMES,
num_action_steps=2,
head_type='transformer',
lewm_history_horizon=3,
lewm_query_offsets=[8],
lewm_loss_weight=1.0,
)
agent.infer_scheduler = _ForbiddenScheduler()
with torch.no_grad():
agent.future_query_tokens.copy_(torch.tensor([[[0.1, 0.2, 0.3, 0.4]]], dtype=torch.float32))
images = _make_images(
batch_size=1,
obs_horizon=2,
per_camera_fill={'r_vis': 10.0, 'top': 20.0, 'front': 30.0},
)
qpos = torch.tensor([[[1.0], [2.0]]], dtype=torch.float32)
initial_noise = torch.tensor(
[[[1.0, -1.0], [0.0, 2.0], [3.0, -2.0]]],
dtype=torch.float32,
)
with mock.patch.object(agent_module.torch, 'randn', return_value=initial_noise):
_ = agent.predict_action(images, qpos)
expected_history = torch.tensor(
[[[10.0, 20.0, 30.0, 1.0], [10.0, 20.0, 30.0, 2.0]]],
dtype=torch.float32,
)
self.assertEqual(len(head.calls), 1)
self.assertTrue(torch.allclose(head.calls[0]['cond'], expected_history))
self.assertEqual(len(future_decoder.calls), 0)
def test_compute_loss_with_dual_decoder_tracks_lewm_loss_breakdown(self):
agent_cls, agent_module = _load_imf_agent_class()
head = _RecordingLinearIMFHead()
future_decoder = _RecordingDirectFutureDecoder()
sigreg = _RecordingSigReg(value=0.75)
agent = agent_cls(
vision_backbone=_StubVisionBackbone(),
state_encoder=nn.Identity(),
action_encoder=nn.Identity(),
head=head,
future_decoder=future_decoder,
action_dim=2,
obs_dim=1,
pred_horizon=3,
obs_horizon=2,
diffusion_steps=10,
inference_steps=1,
num_cams=len(_CAMERA_NAMES),
camera_names=_CAMERA_NAMES,
num_action_steps=2,
head_type='transformer',
lewm_history_horizon=3,
lewm_query_offsets=[8],
lewm_sigreg=sigreg,
lewm_sigreg_weight=0.09,
lewm_loss_weight=1.0,
)
with torch.no_grad():
agent.future_query_tokens.copy_(torch.tensor([[[0.2, 0.4, 0.6, 0.8]]], dtype=torch.float32))
images = _make_images(
batch_size=1,
obs_horizon=2,
per_camera_fill={'r_vis': 1.0, 'top': 2.0, 'front': 3.0},
)
qpos = torch.tensor([[[0.25], [0.75]]], dtype=torch.float32)
actions = torch.tensor(
[[[1.0, -1.0], [0.5, 0.25], [-0.5, 1.5]]],
dtype=torch.float32,
)
lewm_images = _make_images(
batch_size=1,
obs_horizon=3,
per_camera_fill={'r_vis': 1.0, 'top': 2.0, 'front': 3.0},
)
lewm_qpos = torch.tensor([[[0.1], [0.2], [0.3]]], dtype=torch.float32)
lewm_future_images = _make_images(
batch_size=1,
obs_horizon=1,
per_camera_fill={'r_vis': 1.0, 'top': 2.0, 'front': 3.0},
)
lewm_future_qpos = torch.tensor([[[0.4]]], dtype=torch.float32)
noise = torch.tensor(
[[[0.2, -0.4], [0.1, 0.3], [0.5, -0.2]]],
dtype=torch.float32,
)
t_sample = torch.tensor([0.8], dtype=torch.float32)
r_sample = torch.tensor([0.25], dtype=torch.float32)
with mock.patch.object(agent_module.torch, 'randn_like', return_value=noise), \
mock.patch.object(agent_module.torch, 'rand', side_effect=[t_sample, r_sample]):
loss = agent.compute_loss(
{
'images': images,
'qpos': qpos,
'action': actions,
'lewm_images': lewm_images,
'lewm_qpos': lewm_qpos,
'lewm_future_images': lewm_future_images,
'lewm_future_qpos': lewm_future_qpos,
}
)
metrics = agent.get_last_loss_breakdown()
self.assertAlmostEqual(loss.item(), metrics['loss'], places=6)
self.assertEqual(len(head.calls), 2)
self.assertEqual(head.calls[0]['cond'].shape, (1, 2, 4))
self.assertEqual(len(future_decoder.calls), 1)
self.assertEqual(future_decoder.calls[0]['cond'].shape, (1, 3, 4))
self.assertAlmostEqual(
metrics['loss'],
metrics['action_loss'] + metrics['lewm_loss'],
places=5,
)
self.assertAlmostEqual(
metrics['lewm_loss'],
metrics['lewm_pred_loss'] + 0.09 * metrics['lewm_sigreg_loss'],
places=5,
)
self.assertGreater(metrics['lewm_pred_loss'], 0.0)
self.assertAlmostEqual(metrics['lewm_sigreg_loss'], 0.75, places=6)
def test_select_action_only_regenerates_when_action_queue_is_empty(self):
agent, _head, _agent_module = self._make_agent(pred_horizon=4, obs_horizon=2, num_action_steps=2)
observation = {
@@ -1077,6 +1239,36 @@ class IMFVLAAgentTest(unittest.TestCase):
)
self.assertIsNotNone(agent.lewm_sigreg)
def test_hydra_config_instantiates_lewm_resnet_dual_decoder_imf_attnres(self):
cfg = _compose_cfg(
overrides=[
'agent=lewm_resnet_dual_decoder_imf_attnres',
'agent.head.n_layer=1',
'agent.head.n_emb=16',
'agent.future_decoder.n_layer=1',
'agent.future_decoder.n_emb=16',
'agent.lewm_query_offsets=[8]',
]
)
self.assertEqual(cfg.agent._target_, 'roboimi.vla.agent_imf.IMFVLAAgent')
self.assertEqual(cfg.agent.extra_condition_tokens, 0)
self.assertEqual(
cfg.agent.future_decoder._target_,
'roboimi.vla.models.heads.imf_transformer1d.IMFTransformer1D',
)
self.assertEqual(cfg.agent.head.cond_dim, 288)
self.assertEqual(cfg.agent.future_decoder.cond_dim, 288)
with _stub_optional_modules(include_imf_head=True):
agent = instantiate(cfg.agent)
self.assertEqual(agent.per_step_cond_dim, 288)
self.assertEqual(agent.condition_sequence_length, agent.obs_horizon)
self.assertEqual(agent.noise_pred_net.constructor_kwargs['n_obs_steps'], agent.obs_horizon)
self.assertEqual(agent.future_decoder.constructor_kwargs['cond_dim'], 288)
self.assertEqual(agent.future_query_tokens.shape, (1, 1, 288))
def test_hydra_config_instantiates_resnet_imf_attnres_multitoken_with_sequence_length_three_times_obs_horizon(self):
cfg = _compose_cfg(

View File

@@ -12,18 +12,21 @@ from roboimi.vla.data.simpe_robot_dataset import SimpleRobotDataset
class SimpleRobotDatasetImageLoadingTest(unittest.TestCase):
def _write_episode(self, dataset_dir: Path) -> None:
episode_path = dataset_dir / "episode_0.hdf5"
def _write_episode(self, dataset_dir: Path, episode_idx: int = 0, *, base_value: float = 0.0) -> None:
episode_path = dataset_dir / f"episode_{episode_idx}.hdf5"
with h5py.File(episode_path, "w") as root:
root.create_dataset("action", data=np.arange(8, dtype=np.float32).reshape(4, 2))
root.create_dataset(
"action",
data=(np.arange(8, dtype=np.float32).reshape(4, 2) + base_value),
)
root.create_dataset(
"observations/qpos",
data=np.arange(16, dtype=np.float32).reshape(4, 4),
data=(np.arange(16, dtype=np.float32).reshape(4, 4) + base_value),
)
root.create_dataset("task", data=np.array([b"sim_transfer"]))
root.create_dataset(
"observations/images/front",
data=np.arange(4 * 8 * 8 * 3, dtype=np.uint8).reshape(4, 8, 8, 3),
data=((np.arange(4 * 8 * 8 * 3, dtype=np.uint8) + int(base_value)) % 255).reshape(4, 8, 8, 3),
)
def test_getitem_only_resizes_observation_horizon_images(self):
@@ -100,3 +103,25 @@ class SimpleRobotDatasetImageLoadingTest(unittest.TestCase):
self.assertEqual(tuple(sample["lewm.observation.front"].shape), (3, 3, 8, 8))
self.assertEqual(tuple(sample["lewm.future.state"].shape), (2, 4))
self.assertEqual(tuple(sample["lewm.future.front"].shape), (2, 3, 8, 8))
def test_dataset_can_limit_loading_to_specific_episode_indices(self):
with tempfile.TemporaryDirectory() as tmpdir:
dataset_dir = Path(tmpdir)
self._write_episode(dataset_dir, episode_idx=0, base_value=0.0)
self._write_episode(dataset_dir, episode_idx=1, base_value=100.0)
dataset = SimpleRobotDataset(
dataset_dir,
obs_horizon=2,
pred_horizon=3,
camera_names=["front"],
image_resize_shape=None,
episode_indices=[1],
)
sample = dataset[0]
self.assertEqual(len(dataset.hdf5_files), 1)
self.assertEqual(dataset.available_episode_indices, [1])
self.assertEqual(len(dataset), 4)
self.assertTrue(np.allclose(sample["observation.state"][0].numpy(), np.array([100.0, 101.0, 102.0, 103.0])))

View File

@@ -41,6 +41,19 @@ class FakeDataset:
return 4
class SplitAwareFakeDataset(FakeDataset):
def __init__(self, episode_indices=None):
self.episode_indices = None if episode_indices is None else list(episode_indices)
if self.episode_indices is None:
self.episodes = {0: [0], 1: [1], 2: [2]}
else:
self.episodes = {idx: [idx] for idx in self.episode_indices}
@property
def available_episode_indices(self):
return sorted(self.episodes.keys())
class FakeLoader:
def __init__(self, batch):
self.batch = batch
@@ -123,6 +136,10 @@ class RecordingAgent(FakeAgent):
self.seen_inputs.append(agent_input)
return super().compute_loss(agent_input)
def predict_action_chunk(self, agent_input):
self.seen_inputs.append({'predict_action_chunk': agent_input})
return torch.ones_like(agent_input['action'])
class ShapeMixedFakeAgent(FakeAgent):
def __init__(self):
@@ -355,6 +372,8 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):
batch_size=2,
num_workers=0,
val_split=0.25,
val_episode_indices=None,
action_mse_val_freq_epochs=0,
seed=0,
lr=1e-3,
max_steps=2,
@@ -479,6 +498,8 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):
'batch_size': 2,
'num_workers': 0,
'val_split': 0.25,
'val_episode_indices': None,
'action_mse_val_freq_epochs': 0,
'seed': 0,
'lr': 1e-3,
'max_steps': 2,
@@ -561,6 +582,58 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):
self.assertIn('front', first_input['lewm_images'])
self.assertIn('front', first_input['lewm_future_images'])
def test_run_training_logs_epoch_action_mse_for_held_out_val_episode(self):
module = self._load_train_vla_module()
run_training = self._get_run_training(module)
cfg = self._make_cfg()
cfg.train.max_steps = 1
cfg.train.save_freq = 100
cfg.train.val_split = 0.0
cfg.train.val_episode_indices = [2]
cfg.train.action_mse_val_freq_epochs = 1
agent = RecordingAgent()
fake_swanlab = FakeSwanLab()
real_import_module = importlib.import_module
def fake_instantiate(config_node, **kwargs):
if config_node is cfg.data:
return SplitAwareFakeDataset(kwargs.get('episode_indices'))
if config_node is cfg.agent:
return agent
raise AssertionError(f'unexpected instantiate config: {config_node!r}')
def fake_loader_factory(dataset, *, shuffle, **_kwargs):
action_value = 0.0 if shuffle else 2.0
batch = {
'observation.front': torch.zeros(1, 3, 2, 2),
'observation.state': torch.zeros(1, 4),
'action': torch.full((1, 1, 2), action_value),
'action_is_pad': torch.zeros(1, 1, dtype=torch.bool),
}
return FakeLoader(batch)
def fake_import_module(name, package=None):
if name == 'swanlab':
return fake_swanlab
return real_import_module(name, package)
with tempfile.TemporaryDirectory() as tempdir:
previous_cwd = os.getcwd()
try:
os.chdir(tempdir)
with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
mock.patch.object(module, 'DataLoader', side_effect=fake_loader_factory), \
mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=FakeScheduler()), \
mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: FakeProgressBar(iterable)), \
mock.patch.object(module.torch, 'save', return_value=None), \
mock.patch.object(module.importlib, 'import_module', side_effect=fake_import_module):
run_training(cfg)
finally:
os.chdir(previous_cwd)
logged_keys = set().union(*(payload.keys() for payload, _ in fake_swanlab.log_calls))
self.assertIn('val/action_mse', logged_keys)
def test_run_training_skips_swanlab_when_disabled(self):
module = self._load_train_vla_module()
run_training = self._get_run_training(module)