feat: add lewm-conditioned imf training and sigreg loss
This commit is contained in:
@@ -79,3 +79,24 @@ class SimpleRobotDatasetImageLoadingTest(unittest.TestCase):
|
||||
|
||||
fake_cv2.resize.assert_not_called()
|
||||
self.assertEqual(tuple(sample["observation.front"].shape), (2, 3, 8, 8))
|
||||
|
||||
def test_getitem_can_emit_lewm_history_and_future_observations(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
dataset_dir = Path(tmpdir)
|
||||
self._write_episode(dataset_dir)
|
||||
dataset = SimpleRobotDataset(
|
||||
dataset_dir,
|
||||
obs_horizon=2,
|
||||
pred_horizon=3,
|
||||
camera_names=["front"],
|
||||
image_resize_shape=None,
|
||||
lewm_history_horizon=3,
|
||||
lewm_query_offsets=[1, 2],
|
||||
)
|
||||
|
||||
sample = dataset[1]
|
||||
|
||||
self.assertEqual(tuple(sample["lewm.observation.state"].shape), (3, 4))
|
||||
self.assertEqual(tuple(sample["lewm.observation.front"].shape), (3, 3, 8, 8))
|
||||
self.assertEqual(tuple(sample["lewm.future.state"].shape), (2, 4))
|
||||
self.assertEqual(tuple(sample["lewm.future.front"].shape), (2, 3, 8, 8))
|
||||
|
||||
Reference in New Issue
Block a user