feat: add lewm-conditioned imf training and sigreg loss

This commit is contained in:
Logic
2026-04-17 18:46:02 +08:00
parent ff7c9c1f2a
commit 74f4963613
14 changed files with 1634 additions and 24 deletions

View File

@@ -79,3 +79,24 @@ class SimpleRobotDatasetImageLoadingTest(unittest.TestCase):
fake_cv2.resize.assert_not_called()
self.assertEqual(tuple(sample["observation.front"].shape), (2, 3, 8, 8))
def test_getitem_can_emit_lewm_history_and_future_observations(self):
with tempfile.TemporaryDirectory() as tmpdir:
dataset_dir = Path(tmpdir)
self._write_episode(dataset_dir)
dataset = SimpleRobotDataset(
dataset_dir,
obs_horizon=2,
pred_horizon=3,
camera_names=["front"],
image_resize_shape=None,
lewm_history_horizon=3,
lewm_query_offsets=[1, 2],
)
sample = dataset[1]
self.assertEqual(tuple(sample["lewm.observation.state"].shape), (3, 4))
self.assertEqual(tuple(sample["lewm.observation.front"].shape), (3, 3, 8, 8))
self.assertEqual(tuple(sample["lewm.future.state"].shape), (2, 4))
self.assertEqual(tuple(sample["lewm.future.front"].shape), (2, 3, 8, 8))