feat: add held-out validation and dual-decoder lewm imf

This commit is contained in:
Logic
2026-04-17 19:26:56 +08:00
parent 74f4963613
commit 395f5a1645
8 changed files with 693 additions and 86 deletions

View File

@@ -12,18 +12,21 @@ from roboimi.vla.data.simpe_robot_dataset import SimpleRobotDataset
class SimpleRobotDatasetImageLoadingTest(unittest.TestCase):
def _write_episode(self, dataset_dir: Path) -> None:
episode_path = dataset_dir / "episode_0.hdf5"
def _write_episode(self, dataset_dir: Path, episode_idx: int = 0, *, base_value: float = 0.0) -> None:
episode_path = dataset_dir / f"episode_{episode_idx}.hdf5"
with h5py.File(episode_path, "w") as root:
root.create_dataset("action", data=np.arange(8, dtype=np.float32).reshape(4, 2))
root.create_dataset(
"action",
data=(np.arange(8, dtype=np.float32).reshape(4, 2) + base_value),
)
root.create_dataset(
"observations/qpos",
data=np.arange(16, dtype=np.float32).reshape(4, 4),
data=(np.arange(16, dtype=np.float32).reshape(4, 4) + base_value),
)
root.create_dataset("task", data=np.array([b"sim_transfer"]))
root.create_dataset(
"observations/images/front",
data=np.arange(4 * 8 * 8 * 3, dtype=np.uint8).reshape(4, 8, 8, 3),
data=((np.arange(4 * 8 * 8 * 3, dtype=np.uint8) + int(base_value)) % 255).reshape(4, 8, 8, 3),
)
def test_getitem_only_resizes_observation_horizon_images(self):
@@ -100,3 +103,25 @@ class SimpleRobotDatasetImageLoadingTest(unittest.TestCase):
self.assertEqual(tuple(sample["lewm.observation.front"].shape), (3, 3, 8, 8))
self.assertEqual(tuple(sample["lewm.future.state"].shape), (2, 4))
self.assertEqual(tuple(sample["lewm.future.front"].shape), (2, 3, 8, 8))
def test_dataset_can_limit_loading_to_specific_episode_indices(self):
    """Check that ``episode_indices`` restricts which episode files get loaded.

    Two episode files are written with different ``base_value`` offsets and
    the dataset is built with ``episode_indices=[1]``. The assertions then
    verify that only one file is tracked, that the reported available
    indices are ``[1]``, that the dataset length is 4, and that the first
    observed state row carries the 100.0 offset of episode 1.
    """
    with tempfile.TemporaryDirectory() as scratch:
        root_dir = Path(scratch)

        # Distinct offsets make the two episodes distinguishable by value.
        for idx, offset in ((0, 0.0), (1, 100.0)):
            self._write_episode(root_dir, episode_idx=idx, base_value=offset)

        subset = SimpleRobotDataset(
            root_dir,
            obs_horizon=2,
            pred_horizon=3,
            camera_names=["front"],
            image_resize_shape=None,
            episode_indices=[1],
        )
        # Fetch the sample before asserting, as the original test does.
        first_sample = subset[0]

        self.assertEqual(len(subset.hdf5_files), 1)
        self.assertEqual(subset.available_episode_indices, [1])
        self.assertEqual(len(subset), 4)

        observed_state = first_sample["observation.state"][0].numpy()
        expected_state = np.array([100.0, 101.0, 102.0, 103.0])
        self.assertTrue(np.allclose(observed_state, expected_state))