merge: imf attnres policy
# Conflicts: # roboimi/demos/vla_scripts/eval_vla.py # roboimi/envs/double_base.py
This commit is contained in:
26
tests/test_attnres_resnet2d_backbone.py
Normal file
26
tests/test_attnres_resnet2d_backbone.py
Normal file
@@ -0,0 +1,26 @@
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class AttnResResNet2DBackboneTest(unittest.TestCase):
|
||||
def test_backbone_preserves_resnet_like_stage_contract(self):
|
||||
from roboimi.vla.models.backbones.attnres_resnet2d import AttnResResNetLikeBackbone2D
|
||||
|
||||
backbone = AttnResResNetLikeBackbone2D(
|
||||
input_channels=3,
|
||||
stem_dim=16,
|
||||
stage_dims=(16, 32, 64, 128),
|
||||
stage_depths=(1, 1, 1, 1),
|
||||
stage_heads=(2, 4, 4, 8),
|
||||
stage_kv_heads=(1, 1, 1, 1),
|
||||
stage_window_sizes=(7, 7, 7, 7),
|
||||
dropout=0.0,
|
||||
)
|
||||
x = torch.randn(2, 3, 56, 56)
|
||||
y = backbone(x)
|
||||
self.assertEqual(y.shape, (2, 128, 2, 2))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -90,6 +90,36 @@ class _FakeRenderer:
|
||||
|
||||
|
||||
class EvalVLAHeadlessTest(unittest.TestCase):
|
||||
def test_prepare_observation_skips_resize_when_image_resize_shape_is_none(self):
|
||||
obs = {
|
||||
"images": {
|
||||
"front": np.arange(8 * 8 * 3, dtype=np.uint8).reshape(8, 8, 3),
|
||||
},
|
||||
"qpos": np.zeros(16, dtype=np.float32),
|
||||
}
|
||||
|
||||
with mock.patch("cv2.resize", side_effect=AssertionError("resize should be skipped")):
|
||||
prepared = eval_vla.prepare_observation(
|
||||
obs,
|
||||
["front"],
|
||||
image_resize_shape=None,
|
||||
)
|
||||
|
||||
self.assertEqual(tuple(prepared["images"]["front"].shape), (3, 8, 8))
|
||||
self.assertEqual(tuple(prepared["qpos"].shape), (16,))
|
||||
|
||||
def test_headless_eval_sets_mujoco_gl_to_egl_when_display_missing(self):
|
||||
cfg = OmegaConf.create({"eval": {"headless": True}})
|
||||
with mock.patch.dict(eval_vla.os.environ, {}, clear=True):
|
||||
eval_vla._configure_headless_mujoco_gl(cfg.eval)
|
||||
self.assertEqual(eval_vla.os.environ.get("MUJOCO_GL"), "egl")
|
||||
|
||||
def test_headless_eval_preserves_existing_mujoco_gl(self):
|
||||
cfg = OmegaConf.create({"eval": {"headless": True}})
|
||||
with mock.patch.dict(eval_vla.os.environ, {"MUJOCO_GL": "osmesa"}, clear=True):
|
||||
eval_vla._configure_headless_mujoco_gl(cfg.eval)
|
||||
self.assertEqual(eval_vla.os.environ.get("MUJOCO_GL"), "osmesa")
|
||||
|
||||
def test_eval_config_exposes_headless_default(self):
|
||||
eval_cfg = OmegaConf.load(Path("roboimi/vla/conf/eval/eval.yaml"))
|
||||
|
||||
@@ -117,6 +147,49 @@ class EvalVLAHeadlessTest(unittest.TestCase):
|
||||
cam_view="left_side",
|
||||
)
|
||||
|
||||
def test_headless_sync_camera_capture_populates_images_without_gui_calls(self):
|
||||
env = DualDianaMed.__new__(DualDianaMed)
|
||||
env.mj_model = object()
|
||||
env.mj_data = object()
|
||||
env.exit_flag = False
|
||||
env.is_render = False
|
||||
env.cam = 'angle'
|
||||
env.r_vis = None
|
||||
env.l_vis = None
|
||||
env.top = None
|
||||
env.angle = None
|
||||
env.front = None
|
||||
env._offscreen_renderer = None
|
||||
|
||||
with mock.patch(
|
||||
'roboimi.envs.double_base.mj.Renderer',
|
||||
side_effect=lambda *args, **kwargs: _FakeRenderer(env),
|
||||
) as renderer_cls, mock.patch('roboimi.envs.double_base.cv2.namedWindow') as named_window, mock.patch(
|
||||
'roboimi.envs.double_base.cv2.imshow'
|
||||
) as imshow, mock.patch('roboimi.envs.double_base.cv2.waitKey') as wait_key:
|
||||
env._update_camera_images_sync()
|
||||
|
||||
renderer_cls.assert_called_once()
|
||||
named_window.assert_not_called()
|
||||
imshow.assert_not_called()
|
||||
wait_key.assert_not_called()
|
||||
self.assertIsNotNone(env.r_vis)
|
||||
self.assertIsNotNone(env.l_vis)
|
||||
self.assertIsNotNone(env.top)
|
||||
self.assertIsNotNone(env.angle)
|
||||
self.assertIsNotNone(env.front)
|
||||
|
||||
def test_cam_start_skips_background_thread_when_headless(self):
|
||||
env = DualDianaMed.__new__(DualDianaMed)
|
||||
env.is_render = False
|
||||
env.cam_thread = None
|
||||
|
||||
with mock.patch('roboimi.envs.double_base.threading.Thread') as thread_cls:
|
||||
env.cam_start()
|
||||
|
||||
thread_cls.assert_not_called()
|
||||
self.assertIsNone(env.cam_thread)
|
||||
|
||||
def test_camera_viewer_headless_updates_images_without_gui_calls(self):
|
||||
env = DualDianaMed.__new__(DualDianaMed)
|
||||
env.mj_model = object()
|
||||
|
||||
26
tests/test_eval_vla_headless_import.py
Normal file
26
tests/test_eval_vla_headless_import.py
Normal file
@@ -0,0 +1,26 @@
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
|
||||
def test_eval_vla_import_does_not_import_mujoco_early_when_headless_backend_not_set():
|
||||
env = os.environ.copy()
|
||||
env.pop('MUJOCO_GL', None)
|
||||
proc = subprocess.run(
|
||||
[
|
||||
sys.executable,
|
||||
'-c',
|
||||
(
|
||||
'import json, sys; '
|
||||
'from roboimi.demos.vla_scripts import eval_vla; '
|
||||
'print(json.dumps({"mujoco_in_sys_modules": "mujoco" in sys.modules}))'
|
||||
),
|
||||
],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
env=env,
|
||||
check=True,
|
||||
)
|
||||
payload = json.loads(proc.stdout.strip())
|
||||
assert payload['mujoco_in_sys_modules'] is False
|
||||
@@ -102,8 +102,10 @@ class EvalVLARolloutArtifactsTest(unittest.TestCase):
|
||||
self.assertIn('artifact_dir', eval_cfg)
|
||||
self.assertFalse(eval_cfg.save_summary_json)
|
||||
self.assertFalse(eval_cfg.save_trajectory_npz)
|
||||
self.assertFalse(eval_cfg.save_trajectory_image)
|
||||
self.assertFalse(eval_cfg.record_video)
|
||||
self.assertIsNone(eval_cfg.artifact_dir)
|
||||
self.assertIsNone(eval_cfg.trajectory_image_camera_name)
|
||||
self.assertIsNone(eval_cfg.video_camera_name)
|
||||
self.assertEqual(eval_cfg.video_fps, 30)
|
||||
|
||||
@@ -133,6 +135,8 @@ class EvalVLARolloutArtifactsTest(unittest.TestCase):
|
||||
'artifact_dir': tmpdir,
|
||||
'save_summary_json': True,
|
||||
'save_trajectory_npz': True,
|
||||
'save_trajectory_image': True,
|
||||
'trajectory_image_camera_name': 'front',
|
||||
'record_video': True,
|
||||
'video_camera_name': 'front',
|
||||
'video_fps': 12,
|
||||
@@ -176,12 +180,14 @@ class EvalVLARolloutArtifactsTest(unittest.TestCase):
|
||||
trajectory_path = Path(artifacts['trajectory_npz'])
|
||||
summary_path = Path(artifacts['summary_json'])
|
||||
video_path = Path(artifacts['video_mp4'])
|
||||
trajectory_image_path = Path(summary['episodes'][0]['artifact_paths']['trajectory_image'])
|
||||
|
||||
self.assertEqual(Path(artifacts['output_dir']), Path(tmpdir))
|
||||
self.assertEqual(artifacts['video_camera_name'], 'front')
|
||||
self.assertTrue(trajectory_path.exists())
|
||||
self.assertTrue(summary_path.exists())
|
||||
self.assertTrue(video_path.exists())
|
||||
self.assertTrue(trajectory_image_path.exists())
|
||||
|
||||
rollout_npz = np.load(trajectory_path)
|
||||
np.testing.assert_array_equal(rollout_npz['episode_index'], np.array([0, 0]))
|
||||
@@ -218,11 +224,121 @@ class EvalVLARolloutArtifactsTest(unittest.TestCase):
|
||||
saved_summary = json.load(fh)
|
||||
self.assertEqual(saved_summary['artifacts']['trajectory_npz'], str(trajectory_path))
|
||||
self.assertEqual(saved_summary['artifacts']['video_mp4'], str(video_path))
|
||||
self.assertEqual(
|
||||
saved_summary['episodes'][0]['artifact_paths']['trajectory_image'],
|
||||
str(trajectory_image_path),
|
||||
)
|
||||
self.assertEqual(saved_summary['episode_rewards'], [3.0])
|
||||
self.assertAlmostEqual(summary['avg_reward'], 3.0)
|
||||
self.assertIn('avg_obs_read_time_ms', summary)
|
||||
self.assertIn('avg_env_step_time_ms', summary)
|
||||
|
||||
def test_run_eval_exports_front_trajectory_images_without_video_dependency(self):
|
||||
actions = [
|
||||
np.arange(16, dtype=np.float32),
|
||||
np.arange(16, dtype=np.float32) + 10.0,
|
||||
np.arange(16, dtype=np.float32) + 100.0,
|
||||
np.arange(16, dtype=np.float32) + 110.0,
|
||||
]
|
||||
fake_agent = _FakeAgent(actions)
|
||||
fake_env = _FakeEnv()
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
cfg = OmegaConf.create(
|
||||
{
|
||||
'agent': {},
|
||||
'eval': {
|
||||
'ckpt_path': 'checkpoints/vla_model_best.pt',
|
||||
'num_episodes': 2,
|
||||
'max_timesteps': 2,
|
||||
'device': 'cpu',
|
||||
'task_name': 'sim_transfer',
|
||||
'camera_names': ['top', 'front'],
|
||||
'use_smoothing': True,
|
||||
'smooth_alpha': 0.5,
|
||||
'verbose_action': False,
|
||||
'headless': True,
|
||||
'artifact_dir': tmpdir,
|
||||
'save_trajectory_image': True,
|
||||
'record_video': False,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
trajectory_image_calls = []
|
||||
|
||||
def fake_save_rollout_trajectory_image(
|
||||
env,
|
||||
output_path,
|
||||
raw_actions,
|
||||
camera_name,
|
||||
*,
|
||||
line_radius=0.004,
|
||||
max_markers=1500,
|
||||
):
|
||||
del env, line_radius, max_markers
|
||||
trajectory_image_calls.append(
|
||||
{
|
||||
'output_path': output_path,
|
||||
'camera_name': camera_name,
|
||||
'raw_actions': [np.array(action, copy=True) for action in raw_actions],
|
||||
}
|
||||
)
|
||||
if output_path is None:
|
||||
return None
|
||||
output_path = Path(output_path)
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
output_path.write_bytes(b'fake-png')
|
||||
return str(output_path)
|
||||
|
||||
with mock.patch.object(
|
||||
eval_vla,
|
||||
'load_checkpoint',
|
||||
return_value=(fake_agent, None),
|
||||
), mock.patch.object(
|
||||
eval_vla,
|
||||
'make_sim_env',
|
||||
return_value=fake_env,
|
||||
), mock.patch.object(
|
||||
eval_vla,
|
||||
'sample_transfer_pose',
|
||||
return_value=np.array([0.1, 0.2, 0.3], dtype=np.float32),
|
||||
), mock.patch.object(
|
||||
eval_vla,
|
||||
'tqdm',
|
||||
side_effect=lambda iterable, **kwargs: iterable,
|
||||
), mock.patch.object(
|
||||
eval_vla,
|
||||
'_save_rollout_trajectory_image',
|
||||
side_effect=fake_save_rollout_trajectory_image,
|
||||
) as save_trajectory_image_mock, mock.patch.object(
|
||||
eval_vla,
|
||||
'_open_video_writer',
|
||||
) as open_video_writer_mock:
|
||||
summary = eval_vla._run_eval(cfg)
|
||||
|
||||
self.assertEqual(save_trajectory_image_mock.call_count, 2)
|
||||
open_video_writer_mock.assert_not_called()
|
||||
self.assertIsNone(summary['artifacts']['video_mp4'])
|
||||
self.assertEqual(summary['artifacts']['trajectory_image_camera_name'], 'front')
|
||||
self.assertEqual(
|
||||
[call['camera_name'] for call in trajectory_image_calls],
|
||||
['front', 'front'],
|
||||
)
|
||||
|
||||
first_episode_path = Path(summary['episodes'][0]['artifact_paths']['trajectory_image'])
|
||||
second_episode_path = Path(summary['episodes'][1]['artifact_paths']['trajectory_image'])
|
||||
self.assertTrue(first_episode_path.exists())
|
||||
self.assertTrue(second_episode_path.exists())
|
||||
self.assertNotEqual(first_episode_path, second_episode_path)
|
||||
self.assertEqual(first_episode_path.parent, Path(tmpdir))
|
||||
self.assertEqual(second_episode_path.parent, Path(tmpdir))
|
||||
|
||||
np.testing.assert_array_equal(trajectory_image_calls[0]['raw_actions'][0], actions[0])
|
||||
np.testing.assert_array_equal(trajectory_image_calls[0]['raw_actions'][1], actions[1])
|
||||
np.testing.assert_array_equal(trajectory_image_calls[1]['raw_actions'][0], actions[2])
|
||||
np.testing.assert_array_equal(trajectory_image_calls[1]['raw_actions'][1], actions[3])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
196
tests/test_imf_transformer1d_external_alignment.py
Normal file
196
tests/test_imf_transformer1d_external_alignment.py
Normal file
@@ -0,0 +1,196 @@
|
||||
import contextlib
|
||||
import importlib
|
||||
import inspect
|
||||
import subprocess
|
||||
import sys
|
||||
import types
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
_REPO_ROOT = Path(__file__).resolve().parents[1]
|
||||
if str(_REPO_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(_REPO_ROOT))
|
||||
|
||||
_EXTERNAL_COMMIT = '185ed659'
|
||||
_LOCAL_MODULE_NAME = 'roboimi.vla.models.heads.imf_transformer1d'
|
||||
_MISSING = object()
|
||||
|
||||
|
||||
def _find_external_checkout_root() -> Path | None:
|
||||
for ancestor in (_REPO_ROOT, *_REPO_ROOT.parents):
|
||||
candidate = ancestor / 'diffusion_policy'
|
||||
if (candidate / '.git').exists():
|
||||
return candidate
|
||||
return None
|
||||
|
||||
|
||||
_EXTERNAL_CHECKOUT_ROOT = _find_external_checkout_root()
|
||||
_EXTERNAL_MODULE_PATHS = {
|
||||
'diffusion_policy.model.common.module_attr_mixin': 'diffusion_policy/model/common/module_attr_mixin.py',
|
||||
'diffusion_policy.model.diffusion.positional_embedding': 'diffusion_policy/model/diffusion/positional_embedding.py',
|
||||
'diffusion_policy.model.diffusion.attnres_transformer_components': 'diffusion_policy/model/diffusion/attnres_transformer_components.py',
|
||||
'diffusion_policy.model.diffusion.imf_transformer_for_diffusion': 'diffusion_policy/model/diffusion/imf_transformer_for_diffusion.py',
|
||||
}
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _temporary_registered_modules():
|
||||
previous_modules = {}
|
||||
|
||||
def remember(name: str) -> None:
|
||||
if name not in previous_modules:
|
||||
previous_modules[name] = sys.modules.get(name, _MISSING)
|
||||
|
||||
def ensure_package(name: str) -> None:
|
||||
if not name or name in sys.modules:
|
||||
return
|
||||
remember(name)
|
||||
package = types.ModuleType(name)
|
||||
package.__path__ = []
|
||||
sys.modules[name] = package
|
||||
|
||||
def load(name: str, source: str, origin: str):
|
||||
package_parts = name.split('.')[:-1]
|
||||
for idx in range(1, len(package_parts) + 1):
|
||||
ensure_package('.'.join(package_parts[:idx]))
|
||||
|
||||
remember(name)
|
||||
module = types.ModuleType(name)
|
||||
module.__file__ = origin
|
||||
module.__package__ = name.rpartition('.')[0]
|
||||
sys.modules[name] = module
|
||||
exec(compile(source, origin, 'exec'), module.__dict__)
|
||||
return module
|
||||
|
||||
try:
|
||||
yield load
|
||||
finally:
|
||||
for name, previous in reversed(list(previous_modules.items())):
|
||||
if previous is _MISSING:
|
||||
sys.modules.pop(name, None)
|
||||
else:
|
||||
sys.modules[name] = previous
|
||||
|
||||
|
||||
def _git_show(repo_root: Path, commit: str, relative_path: str) -> str:
|
||||
result = subprocess.run(
|
||||
['git', '-C', str(repo_root), 'show', f'{commit}:{relative_path}'],
|
||||
check=True,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
return result.stdout
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _load_external_module_or_skip(test_case: unittest.TestCase):
|
||||
if _EXTERNAL_CHECKOUT_ROOT is None:
|
||||
test_case.skipTest('external diffusion_policy checkout unavailable')
|
||||
|
||||
try:
|
||||
sources = {
|
||||
name: _git_show(_EXTERNAL_CHECKOUT_ROOT, _EXTERNAL_COMMIT, relative_path)
|
||||
for name, relative_path in _EXTERNAL_MODULE_PATHS.items()
|
||||
}
|
||||
except subprocess.CalledProcessError as exc:
|
||||
test_case.skipTest(
|
||||
f'external diffusion_policy commit {_EXTERNAL_COMMIT} is unavailable: {exc.stderr.strip() or exc}'
|
||||
)
|
||||
|
||||
with _temporary_registered_modules() as load_external:
|
||||
for name, relative_path in _EXTERNAL_MODULE_PATHS.items():
|
||||
load_external(
|
||||
name,
|
||||
sources[name],
|
||||
origin=f'{_EXTERNAL_CHECKOUT_ROOT}:{_EXTERNAL_COMMIT}:{relative_path}',
|
||||
)
|
||||
yield sys.modules['diffusion_policy.model.diffusion.imf_transformer_for_diffusion']
|
||||
|
||||
|
||||
def _load_local_module():
|
||||
importlib.invalidate_caches()
|
||||
sys.modules.pop(_LOCAL_MODULE_NAME, None)
|
||||
return importlib.import_module(_LOCAL_MODULE_NAME)
|
||||
|
||||
|
||||
class IMFTransformer1DExternalAlignmentTest(unittest.TestCase):
|
||||
def _optim_group_names(self, model, groups):
|
||||
names_by_param = {id(param): name for name, param in model.named_parameters()}
|
||||
return [
|
||||
{names_by_param[id(param)] for param in group['params']}
|
||||
for group in groups
|
||||
]
|
||||
|
||||
def test_local_defaults_preserve_supported_attnres_config(self):
|
||||
local_module = _load_local_module()
|
||||
ctor = inspect.signature(local_module.IMFTransformer1D.__init__).parameters
|
||||
|
||||
self.assertEqual(ctor['backbone_type'].default, 'attnres_full')
|
||||
self.assertEqual(ctor['n_head'].default, 1)
|
||||
self.assertEqual(ctor['n_kv_head'].default, 1)
|
||||
self.assertEqual(ctor['n_cond_layers'].default, 0)
|
||||
self.assertTrue(ctor['time_as_cond'].default)
|
||||
self.assertFalse(ctor['causal_attn'].default)
|
||||
|
||||
def test_attnres_full_state_dict_forward_and_optim_groups_match_external(self):
|
||||
local_module = _load_local_module()
|
||||
with _load_external_module_or_skip(self) as external_module:
|
||||
config = dict(
|
||||
input_dim=4,
|
||||
output_dim=4,
|
||||
horizon=6,
|
||||
n_obs_steps=3,
|
||||
cond_dim=5,
|
||||
n_layer=2,
|
||||
n_head=1,
|
||||
n_emb=16,
|
||||
p_drop_emb=0.0,
|
||||
p_drop_attn=0.0,
|
||||
causal_attn=False,
|
||||
time_as_cond=True,
|
||||
n_cond_layers=0,
|
||||
backbone_type='attnres_full',
|
||||
n_kv_head=1,
|
||||
)
|
||||
|
||||
torch.manual_seed(7)
|
||||
external_model = external_module.IMFTransformerForDiffusion(**config)
|
||||
local_model = local_module.IMFTransformer1D(**config)
|
||||
external_model.eval()
|
||||
local_model.eval()
|
||||
|
||||
external_state_dict = external_model.state_dict()
|
||||
self.assertEqual(set(local_model.state_dict().keys()), set(external_state_dict.keys()))
|
||||
local_model.load_state_dict(external_state_dict, strict=True)
|
||||
|
||||
batch_size = 2
|
||||
sample = torch.randn(batch_size, config['horizon'], config['input_dim'])
|
||||
r = torch.tensor([0.1, 0.4], dtype=torch.float32)
|
||||
t = torch.tensor([0.7, 0.9], dtype=torch.float32)
|
||||
cond = torch.randn(batch_size, config['n_obs_steps'], config['cond_dim'])
|
||||
|
||||
with torch.no_grad():
|
||||
external_out = external_model(sample=sample, r=r, t=t, cond=cond)
|
||||
local_out = local_model(sample=sample, r=r, t=t, cond=cond)
|
||||
|
||||
self.assertEqual(local_out.shape, (batch_size, config['horizon'], config['output_dim']))
|
||||
self.assertEqual(local_out.shape, external_out.shape)
|
||||
self.assertTrue(torch.allclose(local_out, external_out, atol=1e-6, rtol=1e-5))
|
||||
|
||||
weight_decay = 0.123
|
||||
external_groups = external_model.get_optim_groups(weight_decay=weight_decay)
|
||||
local_groups = local_model.get_optim_groups(weight_decay=weight_decay)
|
||||
|
||||
self.assertEqual(len(local_groups), len(external_groups))
|
||||
self.assertEqual([group['weight_decay'] for group in local_groups], [weight_decay, 0.0])
|
||||
self.assertEqual(
|
||||
self._optim_group_names(local_model, local_groups),
|
||||
self._optim_group_names(external_model, external_groups),
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
889
tests/test_imf_vla_agent.py
Normal file
889
tests/test_imf_vla_agent.py
Normal file
@@ -0,0 +1,889 @@
|
||||
import contextlib
|
||||
import importlib
|
||||
import importlib.machinery
|
||||
import sys
|
||||
import types
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from unittest import mock
|
||||
|
||||
import torch
|
||||
from hydra import compose, initialize_config_dir
|
||||
from hydra.core.global_hydra import GlobalHydra
|
||||
from hydra.utils import instantiate
|
||||
from omegaconf import OmegaConf
|
||||
from torch import nn
|
||||
|
||||
|
||||
_REPO_ROOT = Path(__file__).resolve().parents[1]
|
||||
_CONFIG_DIR = str((_REPO_ROOT / 'roboimi/vla/conf').resolve())
|
||||
_MISSING = object()
|
||||
_CAMERA_NAMES = ('r_vis', 'top', 'front')
|
||||
|
||||
|
||||
class _FakeScheduler:
|
||||
def __init__(self, num_train_timesteps=100, **kwargs):
|
||||
self.config = types.SimpleNamespace(num_train_timesteps=num_train_timesteps)
|
||||
self.timesteps = []
|
||||
|
||||
def add_noise(self, sample, noise, timestep):
|
||||
return sample + noise
|
||||
|
||||
def set_timesteps(self, num_inference_steps):
|
||||
self.timesteps = list(range(num_inference_steps - 1, -1, -1))
|
||||
|
||||
def step(self, noise_pred, timestep, sample):
|
||||
return types.SimpleNamespace(prev_sample=sample)
|
||||
|
||||
|
||||
class _IdentityCrop:
|
||||
def __init__(self, size):
|
||||
self.size = size
|
||||
|
||||
def __call__(self, x):
|
||||
return x
|
||||
|
||||
|
||||
class _FakeResNet(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.conv1 = nn.Conv2d(3, 8, kernel_size=3, padding=1)
|
||||
self.relu1 = nn.ReLU()
|
||||
self.conv2 = nn.Conv2d(8, 16, kernel_size=3, padding=1, stride=2)
|
||||
self.relu2 = nn.ReLU()
|
||||
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
||||
self.fc = nn.Linear(16, 16)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.relu1(self.conv1(x))
|
||||
x = self.relu2(self.conv2(x))
|
||||
x = self.avgpool(x)
|
||||
x = torch.flatten(x, start_dim=1)
|
||||
return self.fc(x)
|
||||
|
||||
|
||||
class _FakeRearrange(nn.Module):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x):
|
||||
return x
|
||||
|
||||
|
||||
class _FakeViTConfig:
|
||||
def __init__(self, **kwargs):
|
||||
for key, value in kwargs.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
|
||||
class _FakeViTModel(nn.Module):
|
||||
def __init__(self, config, add_pooling_layer=False):
|
||||
super().__init__()
|
||||
del add_pooling_layer
|
||||
self.config = config
|
||||
hidden_size = int(getattr(config, 'hidden_size', 192))
|
||||
self.proj = nn.Linear(hidden_size, hidden_size)
|
||||
|
||||
def forward(self, pixel_values=None, interpolate_pos_encoding=False, **kwargs):
|
||||
del interpolate_pos_encoding, kwargs
|
||||
batch_size = pixel_values.shape[0]
|
||||
hidden_size = int(getattr(self.config, 'hidden_size', 192))
|
||||
seq_len = 2
|
||||
last_hidden_state = torch.zeros(batch_size, seq_len, hidden_size, dtype=pixel_values.dtype, device=pixel_values.device)
|
||||
return types.SimpleNamespace(last_hidden_state=last_hidden_state)
|
||||
|
||||
|
||||
class _FakeSiglipVisionOutput:
|
||||
def __init__(self, pooler_output):
|
||||
self.pooler_output = pooler_output
|
||||
|
||||
|
||||
class _FakeSiglipVisionConfig:
|
||||
def __init__(self, hidden_size=768, image_size=256):
|
||||
self.hidden_size = hidden_size
|
||||
self.image_size = image_size
|
||||
|
||||
|
||||
class _FakeSiglipVisionModel(nn.Module):
|
||||
load_calls = []
|
||||
|
||||
def __init__(self, hidden_size=768):
|
||||
super().__init__()
|
||||
self.config = _FakeSiglipVisionConfig(hidden_size=hidden_size)
|
||||
self.scale = nn.Parameter(torch.tensor(1.0))
|
||||
self.forward_calls = []
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
|
||||
model = cls()
|
||||
cls.load_calls.append({
|
||||
'pretrained_model_name_or_path': pretrained_model_name_or_path,
|
||||
'args': args,
|
||||
'kwargs': kwargs,
|
||||
})
|
||||
return model
|
||||
|
||||
def forward(self, pixel_values=None, **kwargs):
|
||||
self.forward_calls.append({
|
||||
'pixel_values': pixel_values.detach().clone(),
|
||||
'kwargs': dict(kwargs),
|
||||
})
|
||||
pooled = pixel_values.mean(dim=(2, 3), keepdim=False) * self.scale
|
||||
return _FakeSiglipVisionOutput(pooler_output=pooled)
|
||||
|
||||
|
||||
class _StubIMFHead(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_dim,
|
||||
output_dim,
|
||||
horizon,
|
||||
n_obs_steps,
|
||||
cond_dim,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.constructor_kwargs = {
|
||||
'input_dim': input_dim,
|
||||
'output_dim': output_dim,
|
||||
'horizon': horizon,
|
||||
'n_obs_steps': n_obs_steps,
|
||||
'cond_dim': cond_dim,
|
||||
**kwargs,
|
||||
}
|
||||
self.proj = nn.Linear(input_dim, output_dim)
|
||||
self.cond_obs_emb = nn.Linear(cond_dim, max(cond_dim, 1))
|
||||
|
||||
def forward(self, sample, r, t, cond=None):
|
||||
return torch.zeros_like(sample)
|
||||
|
||||
def get_optim_groups(self, weight_decay):
|
||||
return [
|
||||
{'params': [self.proj.weight], 'weight_decay': weight_decay},
|
||||
{'params': [self.proj.bias, self.cond_obs_emb.weight, self.cond_obs_emb.bias], 'weight_decay': 0.0},
|
||||
]
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _stub_optional_modules(include_imf_head=False):
|
||||
previous_modules = {}
|
||||
|
||||
def remember_and_remove(name):
|
||||
if name not in previous_modules:
|
||||
previous_modules[name] = sys.modules.get(name, _MISSING)
|
||||
sys.modules.pop(name, None)
|
||||
|
||||
def inject(name, module):
|
||||
if name not in previous_modules:
|
||||
previous_modules[name] = sys.modules.get(name, _MISSING)
|
||||
sys.modules[name] = module
|
||||
|
||||
diffusers_module = types.ModuleType('diffusers')
|
||||
schedulers_module = types.ModuleType('diffusers.schedulers')
|
||||
ddpm_module = types.ModuleType('diffusers.schedulers.scheduling_ddpm')
|
||||
ddim_module = types.ModuleType('diffusers.schedulers.scheduling_ddim')
|
||||
ddpm_module.DDPMScheduler = _FakeScheduler
|
||||
ddim_module.DDIMScheduler = _FakeScheduler
|
||||
diffusers_module.DDPMScheduler = _FakeScheduler
|
||||
diffusers_module.DDIMScheduler = _FakeScheduler
|
||||
diffusers_module.schedulers = schedulers_module
|
||||
schedulers_module.scheduling_ddpm = ddpm_module
|
||||
schedulers_module.scheduling_ddim = ddim_module
|
||||
|
||||
torchvision_module = types.ModuleType('torchvision')
|
||||
models_module = types.ModuleType('torchvision.models')
|
||||
transforms_module = types.ModuleType('torchvision.transforms')
|
||||
torchvision_module.__spec__ = importlib.machinery.ModuleSpec('torchvision', loader=None)
|
||||
models_module.__spec__ = importlib.machinery.ModuleSpec('torchvision.models', loader=None)
|
||||
transforms_module.__spec__ = importlib.machinery.ModuleSpec('torchvision.transforms', loader=None)
|
||||
models_module.resnet18 = lambda weights=None: _FakeResNet()
|
||||
transforms_module.CenterCrop = _IdentityCrop
|
||||
transforms_module.RandomCrop = _IdentityCrop
|
||||
torchvision_module.models = models_module
|
||||
torchvision_module.transforms = transforms_module
|
||||
|
||||
einops_module = types.ModuleType('einops')
|
||||
einops_module.rearrange = lambda x, *args, **kwargs: x
|
||||
einops_layers_module = types.ModuleType('einops.layers')
|
||||
einops_layers_torch_module = types.ModuleType('einops.layers.torch')
|
||||
einops_layers_torch_module.Rearrange = _FakeRearrange
|
||||
einops_module.layers = einops_layers_module
|
||||
einops_layers_module.torch = einops_layers_torch_module
|
||||
|
||||
transformers_module = types.ModuleType('transformers')
|
||||
transformers_module.__spec__ = importlib.machinery.ModuleSpec('transformers', loader=None)
|
||||
transformers_module.ViTConfig = _FakeViTConfig
|
||||
transformers_module.ViTModel = _FakeViTModel
|
||||
transformers_module.SiglipVisionModel = _FakeSiglipVisionModel
|
||||
|
||||
try:
|
||||
remember_and_remove('roboimi.vla.models.backbones.siglip2_diffusion_backbone')
|
||||
inject('diffusers', diffusers_module)
|
||||
inject('diffusers.schedulers', schedulers_module)
|
||||
inject('diffusers.schedulers.scheduling_ddpm', ddpm_module)
|
||||
inject('diffusers.schedulers.scheduling_ddim', ddim_module)
|
||||
inject('torchvision', torchvision_module)
|
||||
inject('torchvision.models', models_module)
|
||||
inject('torchvision.transforms', transforms_module)
|
||||
inject('einops', einops_module)
|
||||
inject('einops.layers', einops_layers_module)
|
||||
inject('einops.layers.torch', einops_layers_torch_module)
|
||||
inject('transformers', transformers_module)
|
||||
|
||||
if include_imf_head:
|
||||
import roboimi.vla.models.heads as heads_package
|
||||
|
||||
imf_head_module = types.ModuleType('roboimi.vla.models.heads.imf_transformer1d')
|
||||
imf_head_module.IMFTransformer1D = _StubIMFHead
|
||||
inject('roboimi.vla.models.heads.imf_transformer1d', imf_head_module)
|
||||
setattr(heads_package, 'imf_transformer1d', imf_head_module)
|
||||
|
||||
yield
|
||||
finally:
|
||||
for name, previous in reversed(list(previous_modules.items())):
|
||||
if previous is _MISSING:
|
||||
sys.modules.pop(name, None)
|
||||
else:
|
||||
sys.modules[name] = previous
|
||||
|
||||
|
||||
def _compose_cfg(overrides=None):
|
||||
if not OmegaConf.has_resolver('len'):
|
||||
OmegaConf.register_new_resolver('len', lambda x: len(x))
|
||||
|
||||
GlobalHydra.instance().clear()
|
||||
with initialize_config_dir(version_base=None, config_dir=_CONFIG_DIR):
|
||||
return compose(config_name='config', overrides=list(overrides or []))
|
||||
|
||||
|
||||
def _load_imf_agent_class():
|
||||
with _stub_optional_modules():
|
||||
sys.modules.pop('roboimi.vla.agent_imf', None)
|
||||
module = importlib.import_module('roboimi.vla.agent_imf')
|
||||
return module.IMFVLAAgent, module
|
||||
|
||||
|
||||
class _StubVisionBackbone(nn.Module):
|
||||
output_dim = 1
|
||||
|
||||
def __init__(self, camera_names=_CAMERA_NAMES):
|
||||
super().__init__()
|
||||
self.camera_names = tuple(camera_names)
|
||||
self.num_cameras = len(self.camera_names)
|
||||
|
||||
def forward(self, images):
|
||||
per_camera_features = []
|
||||
for camera_name in self.camera_names:
|
||||
image_batch = images[camera_name]
|
||||
per_camera_features.append(image_batch.mean(dim=(2, 3, 4), keepdim=False).unsqueeze(-1))
|
||||
return torch.cat(per_camera_features, dim=-1)
|
||||
|
||||
|
||||
class _StubJointVisionBackbone(nn.Module):
|
||||
joint_output_dim = 5
|
||||
output_dim = 5
|
||||
|
||||
def __init__(self, camera_names=_CAMERA_NAMES):
|
||||
super().__init__()
|
||||
self.camera_names = tuple(camera_names)
|
||||
self.num_cameras = len(self.camera_names)
|
||||
|
||||
def forward(self, images):
|
||||
batch_size, obs_horizon = next(iter(images.values())).shape[:2]
|
||||
features = []
|
||||
for camera_name in ('front', 'top', 'r_vis'):
|
||||
image_batch = images[camera_name]
|
||||
features.append(image_batch.mean(dim=(2, 3, 4), keepdim=False).unsqueeze(-1))
|
||||
joint_features = torch.cat(features, dim=-1)
|
||||
front_top_sum = joint_features[..., :2].sum(dim=-1, keepdim=True)
|
||||
r_vis_minus_front = (joint_features[..., 2:] - joint_features[..., :1])
|
||||
time_marker = torch.arange(obs_horizon, dtype=joint_features.dtype).view(1, obs_horizon, 1)
|
||||
time_marker = time_marker.expand(batch_size, -1, -1)
|
||||
return torch.cat([joint_features, front_top_sum, r_vis_minus_front + time_marker], dim=-1)
|
||||
|
||||
|
||||
class _StubMultiTokenVisionBackbone(nn.Module):
|
||||
output_dim = 2
|
||||
tokens_per_step = 3
|
||||
|
||||
def __init__(self, camera_names=_CAMERA_NAMES):
|
||||
super().__init__()
|
||||
self.camera_names = tuple(camera_names)
|
||||
self.num_cameras = len(self.camera_names)
|
||||
|
||||
def forward(self, images):
|
||||
batch_size, obs_horizon = next(iter(images.values())).shape[:2]
|
||||
features = []
|
||||
time_marker = torch.arange(obs_horizon, dtype=torch.float32).view(1, obs_horizon, 1).expand(batch_size, -1, -1)
|
||||
for camera_name in self.camera_names:
|
||||
image_batch = images[camera_name]
|
||||
camera_marker = image_batch.mean(dim=(2, 3, 4), keepdim=False).unsqueeze(-1)
|
||||
features.append(torch.cat([camera_marker, camera_marker + time_marker], dim=-1))
|
||||
return torch.stack(features, dim=2)
|
||||
|
||||
|
||||
class _StubMultiTokenVisionBackbone(nn.Module):
|
||||
output_dim = 2
|
||||
tokens_per_step = 3
|
||||
|
||||
def __init__(self, camera_names=_CAMERA_NAMES):
|
||||
super().__init__()
|
||||
self.camera_names = tuple(camera_names)
|
||||
self.num_cameras = len(self.camera_names)
|
||||
|
||||
def forward(self, images):
|
||||
per_camera = []
|
||||
for camera_name in self.camera_names:
|
||||
image_batch = images[camera_name]
|
||||
base = image_batch.mean(dim=(2, 3, 4), keepdim=False)
|
||||
per_camera.append(torch.stack([base, base + 0.5], dim=-1))
|
||||
return torch.stack(per_camera, dim=2)
|
||||
|
||||
|
||||
class _RecordingLinearIMFHead(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.scale = nn.Parameter(torch.tensor(0.5))
|
||||
self.calls = []
|
||||
|
||||
@staticmethod
|
||||
def _broadcast_batch_time(value, reference):
|
||||
while value.ndim < reference.ndim:
|
||||
value = value.unsqueeze(-1)
|
||||
return value
|
||||
|
||||
def forward(self, sample, r, t, cond=None):
|
||||
record = {
|
||||
'sample': sample.detach().clone(),
|
||||
'r': r.detach().clone(),
|
||||
't': t.detach().clone(),
|
||||
'cond': None if cond is None else cond.detach().clone(),
|
||||
}
|
||||
self.calls.append(record)
|
||||
cond_term = 0.0
|
||||
if cond is not None:
|
||||
cond_term = cond.mean(dim=(1, 2), keepdim=True)
|
||||
r_b = self._broadcast_batch_time(r, sample)
|
||||
t_b = self._broadcast_batch_time(t, sample)
|
||||
return self.scale * sample + r_b + 2.0 * t_b + cond_term
|
||||
|
||||
|
||||
class _ForbiddenScheduler:
|
||||
def set_timesteps(self, *args, **kwargs): # pragma: no cover - only runs on regression
|
||||
raise AssertionError('IMF inference should not use DDIM scheduler set_timesteps')
|
||||
|
||||
def step(self, *args, **kwargs): # pragma: no cover - only runs on regression
|
||||
raise AssertionError('IMF inference should not use DDIM scheduler step')
|
||||
|
||||
|
||||
def _make_images(batch_size, obs_horizon, per_camera_fill):
|
||||
return {
|
||||
name: torch.full((batch_size, obs_horizon, 1, 2, 2), fill_value=value, dtype=torch.float32)
|
||||
for name, value in per_camera_fill.items()
|
||||
}
|
||||
|
||||
|
||||
class IMFVLAAgentTest(unittest.TestCase):
|
||||
def _make_agent(self, pred_horizon=3, obs_horizon=2, num_action_steps=2):
|
||||
agent_cls, agent_module = _load_imf_agent_class()
|
||||
head = _RecordingLinearIMFHead()
|
||||
agent = agent_cls(
|
||||
vision_backbone=_StubVisionBackbone(),
|
||||
state_encoder=nn.Identity(),
|
||||
action_encoder=nn.Identity(),
|
||||
head=head,
|
||||
action_dim=2,
|
||||
obs_dim=1,
|
||||
pred_horizon=pred_horizon,
|
||||
obs_horizon=obs_horizon,
|
||||
diffusion_steps=10,
|
||||
inference_steps=1,
|
||||
num_cams=len(_CAMERA_NAMES),
|
||||
camera_names=_CAMERA_NAMES,
|
||||
num_action_steps=num_action_steps,
|
||||
head_type='transformer',
|
||||
)
|
||||
return agent, head, agent_module
|
||||
|
||||
def test_compute_loss_matches_imf_objective_and_masks_padded_actions(self):
|
||||
agent, head, agent_module = self._make_agent(pred_horizon=3, obs_horizon=2)
|
||||
images = _make_images(
|
||||
batch_size=1,
|
||||
obs_horizon=2,
|
||||
per_camera_fill={'r_vis': 1.0, 'top': 2.0, 'front': 3.0},
|
||||
)
|
||||
qpos = torch.tensor([[[0.25], [0.75]]], dtype=torch.float32)
|
||||
actions = torch.tensor(
|
||||
[[[1.0, -1.0], [0.5, 0.25], [-0.5, 1.5]]],
|
||||
dtype=torch.float32,
|
||||
)
|
||||
action_is_pad = torch.tensor([[False, False, True]])
|
||||
noise = torch.tensor(
|
||||
[[[0.2, -0.4], [0.1, 0.3], [0.5, -0.2]]],
|
||||
dtype=torch.float32,
|
||||
)
|
||||
t_sample = torch.tensor([0.8], dtype=torch.float32)
|
||||
r_sample = torch.tensor([0.25], dtype=torch.float32)
|
||||
|
||||
with mock.patch.object(agent_module.torch, 'randn_like', return_value=noise), \
|
||||
mock.patch.object(agent_module.torch, 'rand', side_effect=[t_sample, r_sample]):
|
||||
loss = agent.compute_loss(
|
||||
{
|
||||
'images': images,
|
||||
'qpos': qpos,
|
||||
'action': actions,
|
||||
'action_is_pad': action_is_pad,
|
||||
}
|
||||
)
|
||||
|
||||
cond = torch.tensor([[[1.0, 2.0, 3.0, 0.25], [1.0, 2.0, 3.0, 0.75]]], dtype=torch.float32)
|
||||
cond_term = cond.mean(dim=(1, 2), keepdim=True)
|
||||
t = t_sample
|
||||
r = r_sample
|
||||
z_t = (1 - t.view(1, 1, 1)) * actions + t.view(1, 1, 1) * noise
|
||||
scale = head.scale.detach()
|
||||
u = scale * z_t + r.view(1, 1, 1) + 2.0 * t.view(1, 1, 1) + cond_term
|
||||
v = scale * z_t + 3.0 * t.view(1, 1, 1) + cond_term
|
||||
du_dt = scale * v + 2.0
|
||||
compound_velocity = u + (t - r).view(1, 1, 1) * du_dt
|
||||
target = noise - actions
|
||||
elementwise_loss = (compound_velocity - target) ** 2
|
||||
mask = (~action_is_pad).unsqueeze(-1).to(elementwise_loss.dtype)
|
||||
expected_loss = (elementwise_loss * mask).sum() / (mask.sum() * elementwise_loss.shape[-1])
|
||||
|
||||
self.assertAlmostEqual(loss.item(), expected_loss.item(), places=6)
|
||||
self.assertEqual(len(head.calls), 2)
|
||||
self.assertTrue(torch.allclose(head.calls[0]['r'], t_sample))
|
||||
self.assertTrue(torch.allclose(head.calls[0]['t'], t_sample))
|
||||
self.assertTrue(torch.allclose(head.calls[0]['cond'], cond))
|
||||
|
||||
def test_predict_action_uses_one_step_imf_sampling_and_image_conditioning(self):
|
||||
agent, head, agent_module = self._make_agent(pred_horizon=3, obs_horizon=2)
|
||||
agent.infer_scheduler = _ForbiddenScheduler()
|
||||
|
||||
images = _make_images(
|
||||
batch_size=2,
|
||||
obs_horizon=2,
|
||||
per_camera_fill={'r_vis': 10.0, 'top': 20.0, 'front': 30.0},
|
||||
)
|
||||
qpos = torch.tensor(
|
||||
[
|
||||
[[1.0], [2.0]],
|
||||
[[3.0], [4.0]],
|
||||
],
|
||||
dtype=torch.float32,
|
||||
)
|
||||
initial_noise = torch.tensor(
|
||||
[
|
||||
[[1.0, -1.0], [0.0, 2.0], [3.0, -2.0]],
|
||||
[[-1.0, 1.0], [2.0, -3.0], [0.5, 0.25]],
|
||||
],
|
||||
dtype=torch.float32,
|
||||
)
|
||||
|
||||
with mock.patch.object(agent_module.torch, 'randn', return_value=initial_noise):
|
||||
predicted_actions = agent.predict_action(images, qpos)
|
||||
|
||||
expected_cond = torch.tensor(
|
||||
[
|
||||
[[10.0, 20.0, 30.0, 1.0], [10.0, 20.0, 30.0, 2.0]],
|
||||
[[10.0, 20.0, 30.0, 3.0], [10.0, 20.0, 30.0, 4.0]],
|
||||
],
|
||||
dtype=torch.float32,
|
||||
)
|
||||
cond_term = expected_cond.mean(dim=(1, 2), keepdim=True)
|
||||
expected_actions = 0.5 * initial_noise - 2.0 - cond_term
|
||||
|
||||
self.assertEqual(predicted_actions.shape, (2, 3, 2))
|
||||
self.assertTrue(torch.allclose(predicted_actions, expected_actions))
|
||||
self.assertEqual(len(head.calls), 1)
|
||||
self.assertTrue(torch.allclose(head.calls[0]['r'], torch.zeros(2)))
|
||||
self.assertTrue(torch.allclose(head.calls[0]['t'], torch.ones(2)))
|
||||
self.assertTrue(torch.allclose(head.calls[0]['cond'], expected_cond))
|
||||
|
||||
def test_select_action_only_regenerates_when_action_queue_is_empty(self):
|
||||
agent, _head, _agent_module = self._make_agent(pred_horizon=4, obs_horizon=2, num_action_steps=2)
|
||||
observation = {
|
||||
'qpos': torch.tensor([0.25], dtype=torch.float32),
|
||||
'images': {
|
||||
'front': torch.full((1, 2, 2), 3.0, dtype=torch.float32),
|
||||
'top': torch.full((1, 2, 2), 2.0, dtype=torch.float32),
|
||||
'r_vis': torch.full((1, 2, 2), 1.0, dtype=torch.float32),
|
||||
},
|
||||
}
|
||||
first_chunk = torch.tensor(
|
||||
[[[10.0, 11.0], [12.0, 13.0], [14.0, 15.0], [16.0, 17.0]]],
|
||||
dtype=torch.float32,
|
||||
)
|
||||
second_chunk = torch.tensor(
|
||||
[[[20.0, 21.0], [22.0, 23.0], [24.0, 25.0], [26.0, 27.0]]],
|
||||
dtype=torch.float32,
|
||||
)
|
||||
|
||||
with mock.patch.object(agent, 'predict_action_chunk', side_effect=[first_chunk, second_chunk]) as mock_predict_chunk:
|
||||
first_action = agent.select_action(observation)
|
||||
second_action = agent.select_action(observation)
|
||||
third_action = agent.select_action(observation)
|
||||
|
||||
self.assertTrue(torch.equal(first_action, first_chunk[0, 1]))
|
||||
self.assertTrue(torch.equal(second_action, first_chunk[0, 2]))
|
||||
self.assertTrue(torch.equal(third_action, second_chunk[0, 1]))
|
||||
self.assertEqual(mock_predict_chunk.call_count, 2)
|
||||
|
||||
def test_joint_visual_backbone_uses_joint_output_dim_for_conditioning(self):
|
||||
agent_cls, _agent_module = _load_imf_agent_class()
|
||||
head = _RecordingLinearIMFHead()
|
||||
vision_backbone = _StubJointVisionBackbone()
|
||||
agent = agent_cls(
|
||||
vision_backbone=vision_backbone,
|
||||
state_encoder=nn.Identity(),
|
||||
action_encoder=nn.Identity(),
|
||||
head=head,
|
||||
action_dim=2,
|
||||
obs_dim=1,
|
||||
pred_horizon=3,
|
||||
obs_horizon=2,
|
||||
diffusion_steps=10,
|
||||
inference_steps=1,
|
||||
num_cams=len(_CAMERA_NAMES),
|
||||
camera_names=_CAMERA_NAMES,
|
||||
num_action_steps=2,
|
||||
head_type='transformer',
|
||||
)
|
||||
|
||||
self.assertEqual(agent.per_step_cond_dim, vision_backbone.joint_output_dim + agent.obs_dim)
|
||||
self.assertEqual(
|
||||
agent.global_cond_dim,
|
||||
vision_backbone.joint_output_dim * agent.obs_horizon + agent.obs_dim * agent.obs_horizon,
|
||||
)
|
||||
|
||||
images = _make_images(
|
||||
batch_size=1,
|
||||
obs_horizon=2,
|
||||
per_camera_fill={'r_vis': 10.0, 'top': 20.0, 'front': 30.0},
|
||||
)
|
||||
qpos = torch.tensor([[[1.0], [2.0]]], dtype=torch.float32)
|
||||
initial_noise = torch.tensor(
|
||||
[[[1.0, -1.0], [0.0, 2.0], [3.0, -2.0]]],
|
||||
dtype=torch.float32,
|
||||
)
|
||||
|
||||
with mock.patch.object(torch, 'randn', return_value=initial_noise):
|
||||
predicted_actions = agent.predict_action(images, qpos)
|
||||
|
||||
self.assertEqual(predicted_actions.shape, (1, 3, 2))
|
||||
self.assertEqual(len(head.calls), 1)
|
||||
expected_cond = torch.tensor(
|
||||
[[[30.0, 20.0, 10.0, 50.0, -20.0, 1.0], [30.0, 20.0, 10.0, 50.0, -19.0, 2.0]]],
|
||||
dtype=torch.float32,
|
||||
)
|
||||
self.assertEqual(head.calls[0]['cond'].shape[-1], 6)
|
||||
self.assertTrue(torch.allclose(head.calls[0]['cond'], expected_cond))
|
||||
|
||||
def test_multitoken_visual_backbone_flattens_camera_tokens_and_projects_each_with_state(self):
|
||||
agent_cls, _agent_module = _load_imf_agent_class()
|
||||
head = _RecordingLinearIMFHead()
|
||||
projector = nn.Linear(3, 4, bias=False)
|
||||
with torch.no_grad():
|
||||
projector.weight.copy_(
|
||||
torch.tensor(
|
||||
[
|
||||
[1.0, 0.0, 0.0],
|
||||
[0.0, 1.0, 0.0],
|
||||
[0.0, 0.0, 1.0],
|
||||
[1.0, 0.0, 1.0],
|
||||
],
|
||||
dtype=torch.float32,
|
||||
)
|
||||
)
|
||||
agent = agent_cls(
|
||||
vision_backbone=_StubMultiTokenVisionBackbone(),
|
||||
state_encoder=nn.Identity(),
|
||||
action_encoder=nn.Identity(),
|
||||
head=head,
|
||||
action_dim=2,
|
||||
obs_dim=1,
|
||||
pred_horizon=3,
|
||||
obs_horizon=2,
|
||||
diffusion_steps=10,
|
||||
inference_steps=1,
|
||||
num_cams=len(_CAMERA_NAMES),
|
||||
camera_names=_CAMERA_NAMES,
|
||||
num_action_steps=2,
|
||||
head_type='transformer',
|
||||
cond_projector=projector,
|
||||
)
|
||||
|
||||
self.assertEqual(agent.condition_tokens_per_step, 3)
|
||||
self.assertEqual(agent.condition_sequence_length, 6)
|
||||
self.assertEqual(agent.per_step_cond_dim, 4)
|
||||
self.assertEqual(agent.global_cond_dim, 24)
|
||||
|
||||
images = _make_images(
|
||||
batch_size=1,
|
||||
obs_horizon=2,
|
||||
per_camera_fill={'r_vis': 10.0, 'top': 20.0, 'front': 30.0},
|
||||
)
|
||||
qpos = torch.tensor([[[1.0], [2.0]]], dtype=torch.float32)
|
||||
cond = agent._build_cond(images, qpos)
|
||||
|
||||
expected = torch.tensor(
|
||||
[
|
||||
[
|
||||
[10.0, 10.5, 1.0, 11.0],
|
||||
[20.0, 20.5, 1.0, 21.0],
|
||||
[30.0, 30.5, 1.0, 31.0],
|
||||
[10.0, 10.5, 2.0, 12.0],
|
||||
[20.0, 20.5, 2.0, 22.0],
|
||||
[30.0, 30.5, 2.0, 32.0],
|
||||
]
|
||||
],
|
||||
dtype=torch.float32,
|
||||
)
|
||||
self.assertEqual(cond.shape, (1, 6, 4))
|
||||
self.assertTrue(torch.allclose(cond, expected))
|
||||
|
||||
def test_multi_token_visual_backbone_pairs_state_per_camera_and_flattens_condition_sequence(self):
|
||||
agent_cls, agent_module = _load_imf_agent_class()
|
||||
head = _RecordingLinearIMFHead()
|
||||
cond_projector = nn.Linear(3, 4, bias=False)
|
||||
with torch.no_grad():
|
||||
cond_projector.weight.copy_(torch.tensor([
|
||||
[1.0, 0.0, 0.0],
|
||||
[0.0, 1.0, 0.0],
|
||||
[0.0, 0.0, 1.0],
|
||||
[1.0, 0.0, 1.0],
|
||||
], dtype=torch.float32))
|
||||
|
||||
agent = agent_cls(
|
||||
vision_backbone=_StubMultiTokenVisionBackbone(),
|
||||
state_encoder=nn.Identity(),
|
||||
action_encoder=nn.Identity(),
|
||||
head=head,
|
||||
action_dim=2,
|
||||
obs_dim=1,
|
||||
pred_horizon=3,
|
||||
obs_horizon=2,
|
||||
diffusion_steps=10,
|
||||
inference_steps=1,
|
||||
num_cams=len(_CAMERA_NAMES),
|
||||
camera_names=_CAMERA_NAMES,
|
||||
num_action_steps=2,
|
||||
head_type='transformer',
|
||||
cond_projector=cond_projector,
|
||||
)
|
||||
agent.infer_scheduler = _ForbiddenScheduler()
|
||||
|
||||
images = _make_images(
|
||||
batch_size=1,
|
||||
obs_horizon=2,
|
||||
per_camera_fill={'r_vis': 10.0, 'top': 20.0, 'front': 30.0},
|
||||
)
|
||||
qpos = torch.tensor([[[1.0], [2.0]]], dtype=torch.float32)
|
||||
initial_noise = torch.tensor([[[1.0, -1.0], [0.0, 2.0], [3.0, -2.0]]], dtype=torch.float32)
|
||||
|
||||
with mock.patch.object(agent_module.torch, 'randn', return_value=initial_noise):
|
||||
predicted_actions = agent.predict_action(images, qpos)
|
||||
|
||||
expected_cond = torch.tensor([[[10.0, 10.5, 1.0, 11.0],
|
||||
[20.0, 20.5, 1.0, 21.0],
|
||||
[30.0, 30.5, 1.0, 31.0],
|
||||
[10.0, 10.5, 2.0, 12.0],
|
||||
[20.0, 20.5, 2.0, 22.0],
|
||||
[30.0, 30.5, 2.0, 32.0]]], dtype=torch.float32)
|
||||
|
||||
self.assertEqual(agent.condition_tokens_per_step, 3)
|
||||
self.assertEqual(agent.condition_sequence_length, 6)
|
||||
self.assertEqual(agent.raw_per_step_cond_dim, 3)
|
||||
self.assertEqual(agent.per_step_cond_dim, 4)
|
||||
self.assertEqual(agent.global_cond_dim, 24)
|
||||
self.assertEqual(predicted_actions.shape, (1, 3, 2))
|
||||
self.assertEqual(len(head.calls), 1)
|
||||
self.assertEqual(head.calls[0]['cond'].shape, (1, 6, 4))
|
||||
self.assertTrue(torch.allclose(head.calls[0]['cond'], expected_cond))
|
||||
|
||||
def test_hydra_config_instantiates_resnet_imf_attnres_with_stub_head(self):
|
||||
cfg = _compose_cfg(
|
||||
overrides=[
|
||||
'agent=resnet_imf_attnres',
|
||||
'agent.vision_backbone.pretrained_backbone_weights=null',
|
||||
'agent.vision_backbone.input_shape=[3,16,16]',
|
||||
'agent.vision_backbone.freeze_backbone=false',
|
||||
'agent.head.n_layer=1',
|
||||
'agent.head.n_emb=16',
|
||||
]
|
||||
)
|
||||
|
||||
self.assertEqual(cfg.agent._target_, 'roboimi.vla.agent_imf.IMFVLAAgent')
|
||||
self.assertEqual(cfg.agent.head._target_, 'roboimi.vla.models.heads.imf_transformer1d.IMFTransformer1D')
|
||||
self.assertEqual(cfg.agent.head.backbone_type, 'attnres_full')
|
||||
self.assertEqual(cfg.agent.head.n_head, 1)
|
||||
self.assertEqual(cfg.agent.head.n_kv_head, 1)
|
||||
self.assertEqual(cfg.agent.head.n_cond_layers, 0)
|
||||
self.assertTrue(cfg.agent.head.time_as_cond)
|
||||
self.assertFalse(cfg.agent.head.causal_attn)
|
||||
self.assertEqual(cfg.agent.inference_steps, 1)
|
||||
self.assertEqual(list(cfg.agent.camera_names), list(_CAMERA_NAMES))
|
||||
|
||||
with _stub_optional_modules(include_imf_head=True):
|
||||
agent = instantiate(cfg.agent)
|
||||
|
||||
self.assertEqual(agent.head_type, 'transformer')
|
||||
self.assertEqual(agent.per_step_cond_dim, agent.vision_encoder.output_dim * agent.num_cams + agent.obs_dim)
|
||||
self.assertIsInstance(agent.noise_pred_net, _StubIMFHead)
|
||||
self.assertEqual(agent.noise_pred_net.constructor_kwargs['cond_dim'], agent.per_step_cond_dim)
|
||||
self.assertEqual(agent.noise_pred_net.constructor_kwargs['backbone_type'], 'attnres_full')
|
||||
|
||||
def test_hydra_config_instantiates_resnet_imf_attnres_with_full_attnres_vision_backbone(self):
|
||||
cfg = _compose_cfg(
|
||||
overrides=[
|
||||
'agent=resnet_imf_attnres',
|
||||
'agent.vision_backbone.vision_backbone_mode=attnres_resnet',
|
||||
'agent.vision_backbone.pretrained_backbone_weights=null',
|
||||
'agent.vision_backbone.input_shape=[3,56,56]',
|
||||
'agent.vision_backbone.freeze_backbone=false',
|
||||
'agent.vision_backbone.attnres_stem_dim=16',
|
||||
'agent.vision_backbone.attnres_stage_dims=[16,32,64,128]',
|
||||
'agent.vision_backbone.attnres_stage_depths=[1,1,1,1]',
|
||||
'agent.vision_backbone.attnres_stage_heads=[2,4,4,8]',
|
||||
'agent.vision_backbone.attnres_stage_kv_heads=[1,1,1,1]',
|
||||
'agent.vision_backbone.attnres_stage_window_sizes=[7,7,7,7]',
|
||||
'agent.head.n_layer=1',
|
||||
'agent.head.n_emb=16',
|
||||
]
|
||||
)
|
||||
|
||||
with _stub_optional_modules(include_imf_head=True):
|
||||
agent = instantiate(cfg.agent)
|
||||
|
||||
self.assertEqual(agent.vision_encoder.output_dim, 64)
|
||||
self.assertEqual(agent.per_step_cond_dim, 64 * agent.num_cams + agent.obs_dim)
|
||||
self.assertEqual(agent.noise_pred_net.constructor_kwargs['cond_dim'], agent.per_step_cond_dim)
|
||||
|
||||
def test_hydra_config_instantiates_lewm_imf_attnres_with_joint_visual_condition_dim(self):
|
||||
cfg = _compose_cfg(
|
||||
overrides=[
|
||||
'agent=lewm_imf_attnres',
|
||||
'agent.vision_backbone.checkpoint_path=null',
|
||||
'agent.head.n_layer=1',
|
||||
'agent.head.n_emb=16',
|
||||
]
|
||||
)
|
||||
|
||||
self.assertEqual(cfg.agent._target_, 'roboimi.vla.agent_imf.IMFVLAAgent')
|
||||
self.assertEqual(cfg.agent.vision_backbone._target_, 'roboimi.vla.models.backbones.lewm_vit_backbone.LEWMViTBackbone')
|
||||
self.assertEqual(list(cfg.agent.camera_names), list(_CAMERA_NAMES))
|
||||
self.assertEqual(list(cfg.agent.vision_backbone.camera_names), list(_CAMERA_NAMES))
|
||||
self.assertEqual(list(cfg.agent.vision_backbone.fused_camera_names), ['front', 'top', 'r_vis'])
|
||||
self.assertIsNone(cfg.agent.vision_backbone.dataset_image_resize_shape)
|
||||
self.assertEqual(list(cfg.agent.vision_backbone.eval_image_resize_shape), [256, 256])
|
||||
self.assertEqual(cfg.agent.head.cond_dim, 208)
|
||||
|
||||
with _stub_optional_modules(include_imf_head=True):
|
||||
agent = instantiate(cfg.agent)
|
||||
|
||||
self.assertEqual(agent.per_step_cond_dim, agent.vision_encoder.joint_output_dim + agent.obs_dim)
|
||||
self.assertEqual(agent.per_step_cond_dim, 208)
|
||||
self.assertEqual(agent.global_cond_dim, agent.obs_horizon * 208)
|
||||
self.assertIsNone(agent.vision_encoder.dataset_image_resize_shape)
|
||||
self.assertEqual(agent.vision_encoder.eval_image_resize_shape, (256, 256))
|
||||
self.assertIsInstance(agent.noise_pred_net, _StubIMFHead)
|
||||
self.assertEqual(agent.noise_pred_net.constructor_kwargs['cond_dim'], 208)
|
||||
|
||||
def test_hydra_config_instantiates_resnet_imf_attnres_multitoken_with_projected_camera_tokens(self):
|
||||
cfg = _compose_cfg(
|
||||
overrides=[
|
||||
'agent=resnet_imf_attnres_multitoken',
|
||||
'agent.vision_backbone.pretrained_backbone_weights=null',
|
||||
'agent.vision_backbone.input_shape=[3,16,16]',
|
||||
'agent.head.n_layer=1',
|
||||
'agent.head.n_emb=32',
|
||||
]
|
||||
)
|
||||
|
||||
self.assertEqual(cfg.agent._target_, 'roboimi.vla.agent_imf.IMFVLAAgent')
|
||||
self.assertEqual(cfg.agent.vision_backbone.vision_backbone_mode, 'resnet')
|
||||
self.assertTrue(cfg.agent.vision_backbone.use_separate_rgb_encoder_per_camera)
|
||||
self.assertTrue(cfg.agent.vision_backbone.output_tokens_per_camera)
|
||||
self.assertEqual(cfg.agent.cond_projector.output_dim, 32)
|
||||
self.assertEqual(cfg.agent.head.cond_dim, 32)
|
||||
|
||||
with _stub_optional_modules(include_imf_head=True):
|
||||
agent = instantiate(cfg.agent)
|
||||
|
||||
self.assertEqual(agent.condition_tokens_per_step, 3)
|
||||
self.assertEqual(agent.condition_sequence_length, agent.obs_horizon * 3)
|
||||
self.assertEqual(agent.per_step_cond_dim, 32)
|
||||
self.assertEqual(agent.global_cond_dim, agent.condition_sequence_length * 32)
|
||||
self.assertIsInstance(agent.noise_pred_net, _StubIMFHead)
|
||||
self.assertEqual(agent.noise_pred_net.constructor_kwargs['cond_dim'], 32)
|
||||
self.assertEqual(agent.noise_pred_net.constructor_kwargs['n_obs_steps'], 6)
|
||||
|
||||
|
||||
def test_hydra_config_instantiates_siglip2_imf_attnres_with_condition_projection(self):
|
||||
cfg = _compose_cfg(
|
||||
overrides=[
|
||||
'agent=siglip2_imf_attnres',
|
||||
'agent.vision_backbone.per_view_output_dim=96',
|
||||
'agent.head.n_layer=1',
|
||||
'agent.head.n_emb=16',
|
||||
'agent.cond_projector.output_dim=384',
|
||||
]
|
||||
)
|
||||
|
||||
self.assertEqual(cfg.agent._target_, 'roboimi.vla.agent_imf.IMFVLAAgent')
|
||||
self.assertEqual(
|
||||
cfg.agent.vision_backbone._target_,
|
||||
'roboimi.vla.models.backbones.siglip2_diffusion_backbone.SigLIP2DiffusionBackbone',
|
||||
)
|
||||
self.assertEqual(list(cfg.agent.camera_names), list(_CAMERA_NAMES))
|
||||
self.assertIsNone(cfg.agent.vision_backbone.dataset_image_resize_shape)
|
||||
self.assertEqual(list(cfg.agent.vision_backbone.eval_image_resize_shape), [256, 256])
|
||||
self.assertEqual(cfg.agent.head.cond_dim, 384)
|
||||
|
||||
with _stub_optional_modules(include_imf_head=True):
|
||||
agent = instantiate(cfg.agent)
|
||||
|
||||
self.assertEqual(agent.raw_per_step_cond_dim, 3 * 96 + agent.obs_dim)
|
||||
self.assertEqual(agent.per_step_cond_dim, 384)
|
||||
self.assertEqual(agent.global_cond_dim, agent.obs_horizon * 384)
|
||||
self.assertEqual(agent.noise_pred_net.constructor_kwargs['cond_dim'], 384)
|
||||
self.assertEqual(agent.vision_encoder.output_dim, 96)
|
||||
self.assertEqual(agent.vision_encoder.eval_image_resize_shape, (256, 256))
|
||||
|
||||
|
||||
def test_hydra_config_instantiates_resnet_imf_attnres_multitoken_with_sequence_length_three_times_obs_horizon(self):
|
||||
cfg = _compose_cfg(
|
||||
overrides=[
|
||||
'agent=resnet_imf_attnres_multitoken',
|
||||
'agent.vision_backbone.pretrained_backbone_weights=null',
|
||||
'agent.vision_backbone.input_shape=[3,16,16]',
|
||||
'agent.vision_backbone.freeze_backbone=false',
|
||||
'agent.head.n_layer=1',
|
||||
'agent.head.n_emb=16',
|
||||
]
|
||||
)
|
||||
|
||||
self.assertEqual(cfg.agent._target_, 'roboimi.vla.agent_imf.IMFVLAAgent')
|
||||
self.assertEqual(list(cfg.agent.camera_names), list(_CAMERA_NAMES))
|
||||
self.assertTrue(cfg.agent.vision_backbone.use_separate_rgb_encoder_per_camera)
|
||||
self.assertTrue(cfg.agent.vision_backbone.output_tokens_per_camera)
|
||||
self.assertEqual(cfg.agent.vision_backbone.vision_backbone_mode, 'resnet')
|
||||
self.assertEqual(cfg.agent.cond_projector.output_dim, 16)
|
||||
self.assertEqual(cfg.agent.head.cond_dim, 16)
|
||||
|
||||
with _stub_optional_modules(include_imf_head=True):
|
||||
agent = instantiate(cfg.agent)
|
||||
|
||||
self.assertEqual(agent.condition_tokens_per_step, 3)
|
||||
self.assertEqual(agent.condition_sequence_length, agent.obs_horizon * 3)
|
||||
self.assertEqual(agent.per_step_cond_dim, 16)
|
||||
self.assertEqual(agent.global_cond_dim, agent.condition_sequence_length * 16)
|
||||
self.assertEqual(agent.vision_encoder.tokens_per_step, 3)
|
||||
self.assertIsInstance(agent.noise_pred_net, _StubIMFHead)
|
||||
self.assertEqual(agent.noise_pred_net.constructor_kwargs['cond_dim'], 16)
|
||||
self.assertEqual(agent.noise_pred_net.constructor_kwargs['n_obs_steps'], agent.condition_sequence_length)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
220
tests/test_lewm_vit_backbone.py
Normal file
220
tests/test_lewm_vit_backbone.py
Normal file
@@ -0,0 +1,220 @@
|
||||
import tempfile
|
||||
import types
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers import ViTConfig, ViTModel
|
||||
|
||||
|
||||
_INPUT_CAMERA_NAMES = ("r_vis", "top", "front")
|
||||
_FUSED_CAMERA_NAMES = ("front", "top", "r_vis")
|
||||
|
||||
|
||||
class _ReferenceProjector(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Linear(192, 2048),
|
||||
nn.BatchNorm1d(2048),
|
||||
nn.GELU(),
|
||||
nn.Linear(2048, 192),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
def _build_reference_encoder() -> ViTModel:
|
||||
return ViTModel(
|
||||
ViTConfig(
|
||||
image_size=224,
|
||||
patch_size=14,
|
||||
num_channels=3,
|
||||
hidden_size=192,
|
||||
intermediate_size=768,
|
||||
num_hidden_layers=12,
|
||||
num_attention_heads=3,
|
||||
qkv_bias=True,
|
||||
),
|
||||
add_pooling_layer=False,
|
||||
)
|
||||
|
||||
|
||||
def _write_synthetic_lightning_ckpt(path: Path):
|
||||
torch.manual_seed(7)
|
||||
encoder = _build_reference_encoder()
|
||||
projector = _ReferenceProjector()
|
||||
lightning_state_dict = {}
|
||||
for key, value in encoder.state_dict().items():
|
||||
lightning_state_dict[f"model.encoder.{key}"] = value.detach().clone()
|
||||
for key, value in projector.state_dict().items():
|
||||
lightning_state_dict[f"model.projector.{key}"] = value.detach().clone()
|
||||
torch.save({"state_dict": lightning_state_dict}, path)
|
||||
return encoder.state_dict(), projector.state_dict()
|
||||
|
||||
|
||||
class LEWMViTBackboneTest(unittest.TestCase):
|
||||
def test_loads_lightning_encoder_and_projector_checkpoint_and_emits_joint_embedding(self):
|
||||
from roboimi.vla.models.backbones.lewm_vit_backbone import LEWMViTBackbone
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
ckpt_path = Path(tmpdir) / "synthetic-lewm.ckpt"
|
||||
reference_encoder_state, reference_projector_state = _write_synthetic_lightning_ckpt(
|
||||
ckpt_path
|
||||
)
|
||||
|
||||
backbone = LEWMViTBackbone(
|
||||
checkpoint_path=ckpt_path,
|
||||
camera_names=_INPUT_CAMERA_NAMES,
|
||||
fused_camera_names=_FUSED_CAMERA_NAMES,
|
||||
freeze_backbone=True,
|
||||
)
|
||||
|
||||
self.assertEqual(backbone.camera_names, _INPUT_CAMERA_NAMES)
|
||||
self.assertEqual(backbone.fused_camera_names, _FUSED_CAMERA_NAMES)
|
||||
self.assertEqual(backbone.num_cameras, 3)
|
||||
self.assertEqual(backbone.joint_output_dim, 192)
|
||||
self.assertEqual(backbone.output_dim, 192)
|
||||
self.assertEqual(backbone.encoder.config.hidden_size, 192)
|
||||
self.assertEqual(backbone.encoder.config.patch_size, 14)
|
||||
self.assertEqual(backbone.encoder.config.num_hidden_layers, 12)
|
||||
self.assertEqual(backbone.encoder.config.num_attention_heads, 3)
|
||||
|
||||
for key, value in reference_encoder_state.items():
|
||||
self.assertTrue(torch.equal(backbone.encoder.state_dict()[key], value), key)
|
||||
for key, value in reference_projector_state.items():
|
||||
self.assertTrue(torch.equal(backbone.projector.state_dict()[key], value), key)
|
||||
|
||||
images = {
|
||||
cam_name: torch.rand(1, 1, 3, 224, 224)
|
||||
for cam_name in _INPUT_CAMERA_NAMES
|
||||
}
|
||||
output = backbone(images)
|
||||
|
||||
self.assertEqual(output.shape, (1, 1, 192))
|
||||
self.assertFalse(output.requires_grad)
|
||||
|
||||
def test_forward_uses_front_top_rvis_fusion_order_and_exact_lewm_cwh_resize_path(self):
|
||||
from roboimi.vla.models.backbones.lewm_vit_backbone import LEWMViTBackbone
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
ckpt_path = Path(tmpdir) / "synthetic-lewm.ckpt"
|
||||
_write_synthetic_lightning_ckpt(ckpt_path)
|
||||
|
||||
backbone = LEWMViTBackbone(
|
||||
checkpoint_path=ckpt_path,
|
||||
camera_names=_INPUT_CAMERA_NAMES,
|
||||
fused_camera_names=_FUSED_CAMERA_NAMES,
|
||||
freeze_backbone=True,
|
||||
)
|
||||
captured = {}
|
||||
|
||||
def fake_encoder_forward(module, pixel_values, interpolate_pos_encoding=False, **kwargs):
|
||||
del module, kwargs
|
||||
captured["pixel_values"] = pixel_values.detach().clone()
|
||||
captured["interpolate_pos_encoding"] = interpolate_pos_encoding
|
||||
batch = pixel_values.shape[0]
|
||||
patch_tokens = (pixel_values.shape[-2] // 14) * (pixel_values.shape[-1] // 14)
|
||||
cls = (
|
||||
torch.arange(192, dtype=pixel_values.dtype, device=pixel_values.device)
|
||||
.unsqueeze(0)
|
||||
.expand(batch, -1)
|
||||
)
|
||||
last_hidden_state = torch.zeros(
|
||||
batch,
|
||||
patch_tokens + 1,
|
||||
192,
|
||||
dtype=pixel_values.dtype,
|
||||
device=pixel_values.device,
|
||||
)
|
||||
last_hidden_state[:, 0] = cls
|
||||
return types.SimpleNamespace(last_hidden_state=last_hidden_state)
|
||||
|
||||
backbone.encoder.forward = types.MethodType(fake_encoder_forward, backbone.encoder)
|
||||
|
||||
r_vis = torch.full((1, 1, 3, 256, 256), 0.30)
|
||||
top = torch.full((1, 1, 3, 256, 256), 0.20)
|
||||
front = torch.full((1, 1, 3, 256, 256), 0.10)
|
||||
bn = backbone.projector.net[1]
|
||||
running_mean_before = bn.running_mean.detach().clone()
|
||||
running_var_before = bn.running_var.detach().clone()
|
||||
|
||||
backbone.train()
|
||||
self.assertFalse(backbone.encoder.training)
|
||||
self.assertFalse(backbone.projector.training)
|
||||
|
||||
output = backbone({"r_vis": r_vis, "top": top, "front": front})
|
||||
|
||||
self.assertEqual(output.shape, (1, 1, 192))
|
||||
self.assertEqual(captured["pixel_values"].shape, (1, 3, 672, 224))
|
||||
self.assertTrue(captured["interpolate_pos_encoding"])
|
||||
|
||||
normalized_views = [
|
||||
((view.reshape(-1, *view.shape[2:]).float()).clamp(0.0, 1.0) - backbone.mean) / backbone.std
|
||||
for view in (front, top, r_vis)
|
||||
]
|
||||
expected_fuse_then_resize = F.interpolate(
|
||||
torch.cat(normalized_views, dim=-2),
|
||||
size=(672, 224),
|
||||
mode="bilinear",
|
||||
align_corners=False,
|
||||
antialias=True,
|
||||
)
|
||||
expected_pre_resize_then_fuse = torch.cat(
|
||||
[
|
||||
F.interpolate(
|
||||
view,
|
||||
size=(224, 224),
|
||||
mode="bilinear",
|
||||
align_corners=False,
|
||||
antialias=True,
|
||||
)
|
||||
for view in normalized_views
|
||||
],
|
||||
dim=-2,
|
||||
)
|
||||
|
||||
self.assertTrue(
|
||||
torch.allclose(captured["pixel_values"], expected_fuse_then_resize, atol=1e-6, rtol=1e-6)
|
||||
)
|
||||
self.assertFalse(
|
||||
torch.allclose(
|
||||
expected_fuse_then_resize,
|
||||
expected_pre_resize_then_fuse,
|
||||
atol=1e-6,
|
||||
rtol=1e-6,
|
||||
)
|
||||
)
|
||||
self.assertFalse(
|
||||
torch.allclose(
|
||||
captured["pixel_values"],
|
||||
expected_pre_resize_then_fuse,
|
||||
atol=1e-6,
|
||||
rtol=1e-6,
|
||||
)
|
||||
)
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
captured["pixel_values"][0, :, 223, :],
|
||||
expected_fuse_then_resize[0, :, 223, :],
|
||||
atol=1e-6,
|
||||
rtol=1e-6,
|
||||
)
|
||||
)
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
captured["pixel_values"][0, :, 447, :],
|
||||
expected_fuse_then_resize[0, :, 447, :],
|
||||
atol=1e-6,
|
||||
rtol=1e-6,
|
||||
)
|
||||
)
|
||||
self.assertTrue(torch.equal(bn.running_mean, running_mean_before))
|
||||
self.assertTrue(torch.equal(bn.running_var, running_var_before))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -180,6 +180,14 @@ def _extract_camera_markers(cond, feature_dim, num_cams):
|
||||
return camera_block[:, 0]
|
||||
|
||||
|
||||
def _extract_token_camera_markers(tokens):
|
||||
return tokens[0, 0, :, 0]
|
||||
|
||||
|
||||
def _extract_token_markers(token_sequence):
|
||||
return token_sequence[0, 0, :, 0]
|
||||
|
||||
|
||||
class ResNetTransformerAgentWiringTest(unittest.TestCase):
|
||||
def test_hydra_wiring_uses_required_three_camera_transformer_conditioning_in_agent_order_and_ignores_extra_keys(self):
|
||||
cfg = _compose_cfg(
|
||||
@@ -246,6 +254,36 @@ class ResNetTransformerAgentWiringTest(unittest.TestCase):
|
||||
with self.assertRaisesRegex(ValueError, 'missing=.*top'):
|
||||
agent.predict_action(missing_images, proprioception)
|
||||
|
||||
def test_multitoken_resnet_backbone_emits_one_token_per_camera_in_agent_order(self):
|
||||
cfg = _compose_cfg(
|
||||
overrides=[
|
||||
'agent=resnet_imf_attnres_multitoken',
|
||||
'agent.vision_backbone.pretrained_backbone_weights=null',
|
||||
'agent.vision_backbone.input_shape=[3,16,16]',
|
||||
]
|
||||
)
|
||||
|
||||
with _stub_optional_modules():
|
||||
backbone = instantiate(cfg.agent.vision_backbone)
|
||||
_patch_backbone_for_order_tracking(backbone)
|
||||
images = _make_images(
|
||||
batch_size=1,
|
||||
obs_horizon=cfg.agent.obs_horizon,
|
||||
image_shape=tuple(cfg.agent.vision_backbone.input_shape),
|
||||
per_camera_fill={
|
||||
'front': 30.0,
|
||||
'top': 20.0,
|
||||
'r_vis': 10.0,
|
||||
'left_wrist': 99.0,
|
||||
},
|
||||
)
|
||||
tokens = backbone(images)
|
||||
|
||||
self.assertEqual(tokens.shape, (1, cfg.agent.obs_horizon, 3, backbone.output_dim))
|
||||
self.assertEqual(backbone.tokens_per_step, 3)
|
||||
camera_markers = _extract_token_camera_markers(tokens)
|
||||
self.assertTrue(torch.allclose(camera_markers, torch.tensor([10.0, 20.0, 30.0])))
|
||||
|
||||
def test_agent_rejects_conflicting_explicit_backbone_camera_names(self):
|
||||
cfg = _compose_cfg(
|
||||
overrides=[
|
||||
@@ -382,6 +420,36 @@ class ResNetTransformerAgentWiringTest(unittest.TestCase):
|
||||
with self.assertRaisesRegex(InstantiationException, 'num_cams'):
|
||||
instantiate(cfg.agent)
|
||||
|
||||
def test_multitoken_resnet_backbone_emits_one_token_per_camera_in_agent_order(self):
|
||||
cfg = _compose_cfg(
|
||||
overrides=[
|
||||
'agent=resnet_imf_attnres_multitoken',
|
||||
'agent.vision_backbone.pretrained_backbone_weights=null',
|
||||
'agent.vision_backbone.input_shape=[3,16,16]',
|
||||
'agent.head.n_layer=1',
|
||||
'agent.head.n_emb=32',
|
||||
]
|
||||
)
|
||||
|
||||
with _stub_optional_modules():
|
||||
backbone = instantiate(cfg.agent.vision_backbone)
|
||||
_patch_backbone_for_order_tracking(backbone)
|
||||
images = _make_images(
|
||||
batch_size=1,
|
||||
obs_horizon=cfg.agent.obs_horizon,
|
||||
image_shape=tuple(cfg.agent.vision_backbone.input_shape),
|
||||
per_camera_fill={
|
||||
'front': 30.0,
|
||||
'top': 20.0,
|
||||
'r_vis': 10.0,
|
||||
},
|
||||
)
|
||||
output = backbone(images)
|
||||
|
||||
self.assertEqual(output.shape, (1, cfg.agent.obs_horizon, 3, backbone.output_dim))
|
||||
token_markers = _extract_token_markers(output)
|
||||
self.assertTrue(torch.allclose(token_markers, torch.tensor([10.0, 20.0, 30.0])))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
121
tests/test_siglip2_diffusion_backbone.py
Normal file
121
tests/test_siglip2_diffusion_backbone.py
Normal file
@@ -0,0 +1,121 @@
|
||||
import types
|
||||
import unittest
|
||||
from unittest import mock
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
_CAMERA_NAMES = ("r_vis", "top", "front")
|
||||
|
||||
|
||||
class _FakeSiglipVisionOutput:
|
||||
def __init__(self, pooler_output):
|
||||
self.pooler_output = pooler_output
|
||||
|
||||
|
||||
class _FakeSiglipVisionConfig:
|
||||
def __init__(self, hidden_size=768, image_size=256):
|
||||
self.hidden_size = hidden_size
|
||||
self.image_size = image_size
|
||||
|
||||
|
||||
class _FakeSiglipVisionModel(nn.Module):
|
||||
def __init__(self, hidden_size=768):
|
||||
super().__init__()
|
||||
self.config = _FakeSiglipVisionConfig(hidden_size=hidden_size)
|
||||
self.forward_calls = []
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
del args, kwargs
|
||||
return cls()
|
||||
|
||||
def forward(self, pixel_values=None, **kwargs):
|
||||
self.forward_calls.append({
|
||||
"pixel_values": pixel_values.detach().clone(),
|
||||
"kwargs": dict(kwargs),
|
||||
})
|
||||
pooled = pixel_values.mean(dim=(2, 3), keepdim=False)
|
||||
return _FakeSiglipVisionOutput(pooler_output=pooled)
|
||||
|
||||
|
||||
class SigLIP2DiffusionBackboneTest(unittest.TestCase):
|
||||
def test_forward_encodes_each_view_independently_and_concatenates_projected_features(self):
|
||||
from roboimi.vla.models.backbones.siglip2_diffusion_backbone import SigLIP2DiffusionBackbone
|
||||
|
||||
fake_model = _FakeSiglipVisionModel(hidden_size=3)
|
||||
with mock.patch(
|
||||
"roboimi.vla.models.backbones.siglip2_diffusion_backbone.SiglipVisionModel.from_pretrained",
|
||||
return_value=fake_model,
|
||||
) as mock_from_pretrained:
|
||||
backbone = SigLIP2DiffusionBackbone(
|
||||
model_name="google/siglip2-base-patch16-256",
|
||||
camera_names=_CAMERA_NAMES,
|
||||
num_cameras=3,
|
||||
per_view_output_dim=2,
|
||||
freeze_backbone=True,
|
||||
)
|
||||
|
||||
self.assertEqual(backbone.camera_names, _CAMERA_NAMES)
|
||||
self.assertEqual(backbone.num_cameras, 3)
|
||||
self.assertEqual(backbone.output_dim, 2)
|
||||
self.assertEqual(backbone.joint_output_dim, 6)
|
||||
self.assertIsNone(backbone.dataset_image_resize_shape)
|
||||
self.assertEqual(backbone.eval_image_resize_shape, (256, 256))
|
||||
mock_from_pretrained.assert_called_once_with("google/siglip2-base-patch16-256")
|
||||
self.assertTrue(all(not p.requires_grad for p in backbone.encoder.parameters()))
|
||||
self.assertFalse(backbone.encoder.training)
|
||||
|
||||
with torch.no_grad():
|
||||
backbone.view_projector.weight.zero_()
|
||||
backbone.view_projector.bias.zero_()
|
||||
backbone.view_projector.weight[0, 0] = 1.0
|
||||
backbone.view_projector.weight[1, 1] = 1.0
|
||||
|
||||
images = {
|
||||
"r_vis": torch.full((1, 2, 3, 256, 256), 0.25),
|
||||
"top": torch.full((1, 2, 3, 256, 256), 0.50),
|
||||
"front": torch.full((1, 2, 3, 256, 256), 0.75),
|
||||
}
|
||||
output = backbone(images)
|
||||
|
||||
self.assertEqual(output.shape, (1, 2, 6))
|
||||
self.assertEqual(len(fake_model.forward_calls), 3)
|
||||
|
||||
expected_per_camera = []
|
||||
for cam_name in _CAMERA_NAMES:
|
||||
img = images[cam_name].reshape(2, 3, 256, 256)
|
||||
normalized = (img - 0.5) / 0.5
|
||||
expected_per_camera.append(normalized.mean(dim=(2, 3))[:, :2])
|
||||
expected = torch.cat(expected_per_camera, dim=-1).view(1, 2, 6)
|
||||
self.assertTrue(torch.allclose(output, expected, atol=1e-6, rtol=1e-6))
|
||||
|
||||
for call, cam_name in zip(fake_model.forward_calls, _CAMERA_NAMES):
|
||||
pixels = call["pixel_values"]
|
||||
self.assertEqual(tuple(pixels.shape), (2, 3, 256, 256))
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
pixels,
|
||||
(images[cam_name].reshape(2, 3, 256, 256) - 0.5) / 0.5,
|
||||
)
|
||||
)
|
||||
|
||||
def test_forward_rejects_missing_required_camera(self):
|
||||
from roboimi.vla.models.backbones.siglip2_diffusion_backbone import SigLIP2DiffusionBackbone
|
||||
|
||||
backbone = SigLIP2DiffusionBackbone(
|
||||
vision_model=_FakeSiglipVisionModel(hidden_size=4),
|
||||
camera_names=_CAMERA_NAMES,
|
||||
num_cameras=3,
|
||||
)
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "missing"):
|
||||
backbone({
|
||||
"r_vis": torch.rand(1, 1, 3, 256, 256),
|
||||
"top": torch.rand(1, 1, 3, 256, 256),
|
||||
})
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -56,3 +56,26 @@ class SimpleRobotDatasetImageLoadingTest(unittest.TestCase):
|
||||
|
||||
self.assertEqual(len(resize_calls), 2)
|
||||
self.assertEqual(tuple(sample["observation.front"].shape), (2, 3, 8, 8))
|
||||
|
||||
def test_getitem_skips_resize_when_image_resize_shape_is_none(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
dataset_dir = Path(tmpdir)
|
||||
self._write_episode(dataset_dir)
|
||||
dataset = SimpleRobotDataset(
|
||||
dataset_dir,
|
||||
obs_horizon=2,
|
||||
pred_horizon=3,
|
||||
camera_names=["front"],
|
||||
image_resize_shape=None,
|
||||
)
|
||||
|
||||
fake_cv2 = types.SimpleNamespace(
|
||||
INTER_LINEAR=1,
|
||||
resize=mock.Mock(side_effect=AssertionError("resize should be skipped when image_resize_shape=None")),
|
||||
)
|
||||
|
||||
with mock.patch.dict(sys.modules, {"cv2": fake_cv2}):
|
||||
sample = dataset[1]
|
||||
|
||||
fake_cv2.resize.assert_not_called()
|
||||
self.assertEqual(tuple(sample["observation.front"].shape), (2, 3, 8, 8))
|
||||
|
||||
@@ -159,6 +159,92 @@ class TrainVLARolloutValidationTest(unittest.TestCase):
|
||||
self.assertGreater(cfg.train.num_workers, 8)
|
||||
self.assertEqual(cfg.train.rollout_val_freq_epochs, 50)
|
||||
|
||||
def test_training_passes_backbone_image_resize_override_to_dataset_instantiation(self):
|
||||
cfg = OmegaConf.create(
|
||||
{
|
||||
'agent': {
|
||||
'vision_backbone': {
|
||||
'dataset_image_resize_shape': None,
|
||||
},
|
||||
'normalization_type': 'min_max',
|
||||
},
|
||||
'data': {
|
||||
'dataset_dir': 'unused',
|
||||
'camera_names': ['front'],
|
||||
},
|
||||
'train': {
|
||||
'batch_size': 2,
|
||||
'lr': 1e-4,
|
||||
'max_steps': 0,
|
||||
'device': 'cpu',
|
||||
'disable_cudnn': False,
|
||||
'num_workers': 0,
|
||||
'val_split': 0.0,
|
||||
'seed': 42,
|
||||
'log_freq': 1,
|
||||
'save_freq': 10,
|
||||
'use_swanlab': False,
|
||||
'rollout_val_freq_epochs': 0,
|
||||
'rollout_validate_on_checkpoint': False,
|
||||
'rollout_num_episodes': 1,
|
||||
'warmup_steps': 1,
|
||||
'scheduler_type': 'constant',
|
||||
'min_lr': 1e-6,
|
||||
'weight_decay': 1e-5,
|
||||
'grad_clip': 1.0,
|
||||
'pretrained_ckpt': None,
|
||||
},
|
||||
'eval': {
|
||||
'ckpt_path': 'unused.pt',
|
||||
'num_episodes': 1,
|
||||
'headless': True,
|
||||
'device': 'cpu',
|
||||
'verbose_action': False,
|
||||
},
|
||||
'experiment': {},
|
||||
}
|
||||
)
|
||||
captured_dataset_kwargs = {}
|
||||
|
||||
def fake_instantiate(config_node, **kwargs):
|
||||
if config_node is cfg.data:
|
||||
captured_dataset_kwargs.update(kwargs)
|
||||
return _FakeDataset()
|
||||
if config_node is cfg.agent:
|
||||
return _FakeAgent()
|
||||
raise AssertionError(f'unexpected instantiate config: {config_node!r}')
|
||||
|
||||
def fake_dataloader(_dataset, *, shuffle, **_kwargs):
|
||||
del shuffle, _kwargs
|
||||
return _FakeLoader(
|
||||
{
|
||||
'observation.front': torch.zeros(1, 3, 2, 2),
|
||||
'observation.state': torch.zeros(1, 4),
|
||||
'action': torch.zeros(1, 2),
|
||||
'action_is_pad': torch.zeros(1, 1, dtype=torch.bool),
|
||||
},
|
||||
length=1,
|
||||
)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tempdir:
|
||||
previous_cwd = os.getcwd()
|
||||
try:
|
||||
os.chdir(tempdir)
|
||||
with mock.patch.object(train_vla, 'instantiate', side_effect=fake_instantiate), \
|
||||
mock.patch.object(train_vla, 'DataLoader', side_effect=fake_dataloader), \
|
||||
mock.patch.object(train_vla, 'build_training_optimizer', return_value=_FakeOptimizer(cfg.train.lr)), \
|
||||
mock.patch.object(train_vla, 'get_lr_schedule_with_warmup', return_value=_FakeScheduler()), \
|
||||
mock.patch.object(train_vla, 'tqdm', side_effect=lambda iterable, **kwargs: _FakeProgressBar(iterable)), \
|
||||
mock.patch.object(train_vla, '_init_swanlab', return_value=None), \
|
||||
mock.patch.object(train_vla, '_finish_swanlab', return_value=None), \
|
||||
mock.patch.object(train_vla.torch, 'save', return_value=None):
|
||||
train_vla._run_training(cfg)
|
||||
finally:
|
||||
os.chdir(previous_cwd)
|
||||
|
||||
self.assertIn('image_resize_shape', captured_dataset_kwargs)
|
||||
self.assertIsNone(captured_dataset_kwargs['image_resize_shape'])
|
||||
|
||||
def test_eval_main_delegates_to_plain_run_eval_helper(self):
|
||||
cfg = OmegaConf.create(
|
||||
{
|
||||
@@ -234,7 +320,28 @@ class TrainVLARolloutValidationTest(unittest.TestCase):
|
||||
}
|
||||
)
|
||||
agent = _FakeAgent()
|
||||
rollout_mock = mock.Mock(side_effect=[{'avg_reward': 2.0}, {'avg_reward': 1.0}])
|
||||
rollout_mock = mock.Mock(
|
||||
side_effect=[
|
||||
{
|
||||
'avg_reward': 2.0,
|
||||
'episodes': [
|
||||
{
|
||||
'episode_index': 0,
|
||||
'artifact_paths': {'trajectory_image': 'artifacts/epoch_49_front.png'},
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
'avg_reward': 1.0,
|
||||
'episodes': [
|
||||
{
|
||||
'episode_index': 0,
|
||||
'artifact_paths': {'trajectory_image': 'artifacts/epoch_99_front.png'},
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
)
|
||||
swanlab_log_mock = mock.Mock()
|
||||
saved_checkpoints = []
|
||||
|
||||
@@ -281,17 +388,22 @@ class TrainVLARolloutValidationTest(unittest.TestCase):
|
||||
self.assertEqual(rollout_mock.call_count, 2)
|
||||
first_rollout_cfg = rollout_mock.call_args_list[0].args[0]
|
||||
second_rollout_cfg = rollout_mock.call_args_list[1].args[0]
|
||||
self.assertEqual(first_rollout_cfg.eval.ckpt_path, 'checkpoints/vla_model_step_49.pt')
|
||||
self.assertEqual(second_rollout_cfg.eval.ckpt_path, 'checkpoints/vla_model_step_99.pt')
|
||||
self.assertTrue(first_rollout_cfg.eval.ckpt_path.endswith('checkpoints/vla_model_step_49.pt'))
|
||||
self.assertTrue(second_rollout_cfg.eval.ckpt_path.endswith('checkpoints/vla_model_step_99.pt'))
|
||||
self.assertEqual(first_rollout_cfg.eval.num_episodes, 3)
|
||||
self.assertTrue(first_rollout_cfg.eval.headless)
|
||||
self.assertEqual(first_rollout_cfg.eval.device, 'cpu')
|
||||
self.assertFalse(first_rollout_cfg.eval.verbose_action)
|
||||
self.assertFalse(first_rollout_cfg.eval.record_video)
|
||||
self.assertTrue(first_rollout_cfg.eval.save_trajectory_image)
|
||||
self.assertEqual(first_rollout_cfg.eval.trajectory_image_camera_name, 'front')
|
||||
self.assertEqual(cfg.eval.ckpt_path, 'unused.pt')
|
||||
self.assertEqual(cfg.eval.num_episodes, 99)
|
||||
self.assertFalse(cfg.eval.headless)
|
||||
self.assertEqual(cfg.eval.device, 'cpu')
|
||||
self.assertFalse(cfg.eval.verbose_action)
|
||||
self.assertNotIn('save_trajectory_image', cfg.eval)
|
||||
self.assertNotIn('trajectory_image_camera_name', cfg.eval)
|
||||
|
||||
rollout_reward_logs = [
|
||||
call.args[1]['rollout/avg_reward']
|
||||
@@ -769,10 +881,8 @@ class TrainVLARolloutValidationTest(unittest.TestCase):
|
||||
'dataset_len': 1,
|
||||
},
|
||||
)
|
||||
self.assertEqual(
|
||||
[path for path, _payload in saved_checkpoints],
|
||||
['checkpoints/vla_model_final.pt'],
|
||||
)
|
||||
self.assertEqual(len(saved_checkpoints), 1)
|
||||
self.assertTrue(saved_checkpoints[0][0].endswith('checkpoints/vla_model_final.pt'))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@@ -115,13 +115,15 @@ class FakeAgent(nn.Module):
|
||||
|
||||
|
||||
class FakeSwanLab:
|
||||
def __init__(self, init_error=None, log_errors=None, finish_error=None):
|
||||
def __init__(self, init_error=None, log_errors=None, finish_error=None, image_errors=None):
|
||||
self.init_error = init_error
|
||||
self.log_errors = list(log_errors or [])
|
||||
self.finish_error = finish_error
|
||||
self.image_errors = list(image_errors or [])
|
||||
self.init_calls = []
|
||||
self.log_calls = []
|
||||
self.finish_calls = 0
|
||||
self.image_calls = []
|
||||
|
||||
def init(self, project, experiment_name=None, config=None):
|
||||
self.init_calls.append({
|
||||
@@ -138,6 +140,18 @@ class FakeSwanLab:
|
||||
if self.log_errors:
|
||||
raise self.log_errors.pop(0)
|
||||
|
||||
def Image(self, path, caption=None):
|
||||
self.image_calls.append({
|
||||
'path': path,
|
||||
'caption': caption,
|
||||
})
|
||||
if self.image_errors:
|
||||
raise self.image_errors.pop(0)
|
||||
return {
|
||||
'path': path,
|
||||
'caption': caption,
|
||||
}
|
||||
|
||||
def finish(self):
|
||||
self.finish_calls += 1
|
||||
if self.finish_error is not None:
|
||||
@@ -149,6 +163,119 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):
|
||||
config_text = _CONFIG_PATH.read_text(encoding='utf-8')
|
||||
self.assertIn('use_swanlab: false', config_text)
|
||||
|
||||
def test_log_rollout_trajectory_images_to_swanlab_uploads_episode_artifacts(self):
|
||||
module = self._load_train_vla_module()
|
||||
fake_swanlab = FakeSwanLab()
|
||||
|
||||
module._log_rollout_trajectory_images_to_swanlab(
|
||||
fake_swanlab,
|
||||
{
|
||||
'episodes': [
|
||||
{
|
||||
'episode_index': 0,
|
||||
'artifact_paths': {'trajectory_image': 'artifacts/episode_0_front.png'},
|
||||
},
|
||||
{
|
||||
'episode_index': 3,
|
||||
'artifact_paths': {'trajectory_image': 'artifacts/episode_3_front.png'},
|
||||
},
|
||||
{
|
||||
'episode_index': 7,
|
||||
'artifact_paths': {'trajectory_image': None},
|
||||
},
|
||||
{
|
||||
'episode_index': 8,
|
||||
'artifact_paths': {},
|
||||
},
|
||||
],
|
||||
},
|
||||
step=12,
|
||||
context_label='epoch 1 rollout',
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
fake_swanlab.image_calls,
|
||||
[
|
||||
{
|
||||
'path': 'artifacts/episode_0_front.png',
|
||||
'caption': 'epoch 1 rollout trajectory image - episode 0 (front)',
|
||||
},
|
||||
{
|
||||
'path': 'artifacts/episode_3_front.png',
|
||||
'caption': 'epoch 1 rollout trajectory image - episode 3 (front)',
|
||||
},
|
||||
],
|
||||
)
|
||||
self.assertIn(
|
||||
(
|
||||
{
|
||||
'rollout/trajectory_image_episode_0': {
|
||||
'path': 'artifacts/episode_0_front.png',
|
||||
'caption': 'epoch 1 rollout trajectory image - episode 0 (front)',
|
||||
},
|
||||
'rollout/trajectory_image_episode_3': {
|
||||
'path': 'artifacts/episode_3_front.png',
|
||||
'caption': 'epoch 1 rollout trajectory image - episode 3 (front)',
|
||||
},
|
||||
},
|
||||
12,
|
||||
),
|
||||
fake_swanlab.log_calls,
|
||||
)
|
||||
|
||||
def test_log_rollout_trajectory_images_to_swanlab_is_best_effort(self):
|
||||
module = self._load_train_vla_module()
|
||||
fake_swanlab = FakeSwanLab(image_errors=[RuntimeError('decode failed')])
|
||||
|
||||
with mock.patch.object(module.log, 'warning') as warning_mock:
|
||||
module._log_rollout_trajectory_images_to_swanlab(
|
||||
fake_swanlab,
|
||||
{
|
||||
'episodes': [
|
||||
{
|
||||
'episode_index': 0,
|
||||
'artifact_paths': {'trajectory_image': 'artifacts/bad_episode.png'},
|
||||
},
|
||||
{
|
||||
'episode_index': 1,
|
||||
'artifact_paths': {'trajectory_image': 'artifacts/good_episode.png'},
|
||||
},
|
||||
],
|
||||
},
|
||||
step=7,
|
||||
context_label='checkpoint rollout',
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
fake_swanlab.image_calls,
|
||||
[
|
||||
{
|
||||
'path': 'artifacts/bad_episode.png',
|
||||
'caption': 'checkpoint rollout trajectory image - episode 0 (front)',
|
||||
},
|
||||
{
|
||||
'path': 'artifacts/good_episode.png',
|
||||
'caption': 'checkpoint rollout trajectory image - episode 1 (front)',
|
||||
},
|
||||
],
|
||||
)
|
||||
self.assertIn(
|
||||
(
|
||||
{
|
||||
'rollout/trajectory_image_episode_1': {
|
||||
'path': 'artifacts/good_episode.png',
|
||||
'caption': 'checkpoint rollout trajectory image - episode 1 (front)',
|
||||
},
|
||||
},
|
||||
7,
|
||||
),
|
||||
fake_swanlab.log_calls,
|
||||
)
|
||||
warning_messages = [call.args[0] for call in warning_mock.call_args_list]
|
||||
self.assertTrue(
|
||||
any('SwanLab rollout trajectory image upload prep failed' in message for message in warning_messages)
|
||||
)
|
||||
|
||||
def _load_train_vla_module(self):
|
||||
hydra_module = types.ModuleType('hydra')
|
||||
hydra_utils_module = types.ModuleType('hydra.utils')
|
||||
@@ -356,8 +483,8 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):
|
||||
|
||||
final_payload, final_step = fake_swanlab.log_calls[-1]
|
||||
self.assertEqual(final_step, cfg.train.max_steps)
|
||||
self.assertEqual(final_payload['final/checkpoint_path'], 'checkpoints/vla_model_final.pt')
|
||||
self.assertEqual(final_payload['final/best_checkpoint_path'], 'checkpoints/vla_model_best.pt')
|
||||
self.assertTrue(final_payload['final/checkpoint_path'].endswith('checkpoints/vla_model_final.pt'))
|
||||
self.assertTrue(final_payload['final/best_checkpoint_path'].endswith('checkpoints/vla_model_best.pt'))
|
||||
self.assertEqual(fake_swanlab.finish_calls, 1)
|
||||
|
||||
def test_run_training_skips_swanlab_when_disabled(self):
|
||||
@@ -512,10 +639,10 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):
|
||||
|
||||
def fake_torch_load(path, map_location=None):
|
||||
del map_location
|
||||
path = Path(path)
|
||||
if path == resume_path:
|
||||
path = Path(path).resolve()
|
||||
if path == resume_path.resolve():
|
||||
return resume_checkpoint_state
|
||||
if path == best_path:
|
||||
if path == best_path.resolve():
|
||||
return best_checkpoint_state
|
||||
raise AssertionError(f'unexpected load path: {path}')
|
||||
|
||||
@@ -538,8 +665,8 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):
|
||||
|
||||
final_payload, final_step = fake_swanlab.log_calls[-1]
|
||||
self.assertEqual(final_step, cfg.train.max_steps)
|
||||
self.assertEqual(final_payload['final/best_checkpoint_path'], 'checkpoints/vla_model_best.pt')
|
||||
self.assertNotIn('checkpoints/vla_model_best.pt', saved_paths)
|
||||
self.assertTrue(final_payload['final/best_checkpoint_path'].endswith('checkpoints/vla_model_best.pt'))
|
||||
self.assertFalse(any(path.endswith('checkpoints/vla_model_best.pt') for path in saved_paths))
|
||||
|
||||
def test_run_training_resume_ignores_best_checkpoint_without_rollout_metric(self):
|
||||
module = self._load_train_vla_module()
|
||||
@@ -594,10 +721,10 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):
|
||||
|
||||
def fake_torch_load(path, map_location=None):
|
||||
del map_location
|
||||
path = Path(path)
|
||||
if path == resume_path:
|
||||
path = Path(path).resolve()
|
||||
if path == resume_path.resolve():
|
||||
return resume_checkpoint_state
|
||||
if path == best_path:
|
||||
if path == best_path.resolve():
|
||||
return stale_best_checkpoint_state
|
||||
raise AssertionError(f'unexpected load path: {path}')
|
||||
|
||||
|
||||
@@ -101,10 +101,19 @@ class RecordingTransformerHead(nn.Module):
|
||||
]
|
||||
|
||||
|
||||
class FakeTransformerAgent(nn.Module):
|
||||
class FakeIMFAgent(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.head_type = 'transformer'
|
||||
self.head_type = 'imf_transformer'
|
||||
self.noise_pred_net = RecordingTransformerHead()
|
||||
self.backbone = nn.Linear(4, 3)
|
||||
self.adapter = nn.Linear(3, 2, bias=False)
|
||||
|
||||
|
||||
class FakeTransformerAgent(nn.Module):
|
||||
def __init__(self, *, head_type='transformer'):
|
||||
super().__init__()
|
||||
self.head_type = head_type
|
||||
self.noise_pred_net = RecordingTransformerHead()
|
||||
self.backbone = nn.Linear(4, 3)
|
||||
self.adapter = nn.Linear(3, 2, bias=False)
|
||||
@@ -205,6 +214,95 @@ class TrainVLATransformerOptimizerTest(unittest.TestCase):
|
||||
for group in optimizer.param_groups
|
||||
]
|
||||
|
||||
def test_clean_ld_preload_value_removes_problematic_nxegl_entry(self):
|
||||
module = self._load_train_vla_module()
|
||||
|
||||
cleaned, changed = module._clean_ld_preload_value(
|
||||
'/usr/lib/libfoo.so /usr/NX/lib/libnxegl.so /usr/lib/libbar.so'
|
||||
)
|
||||
|
||||
self.assertTrue(changed)
|
||||
self.assertEqual(cleaned, '/usr/lib/libfoo.so /usr/lib/libbar.so')
|
||||
|
||||
def test_clean_ld_preload_value_leaves_safe_entries_unchanged(self):
|
||||
module = self._load_train_vla_module()
|
||||
|
||||
cleaned, changed = module._clean_ld_preload_value('/usr/lib/libfoo.so /usr/lib/libbar.so')
|
||||
|
||||
self.assertFalse(changed)
|
||||
self.assertEqual(cleaned, '/usr/lib/libfoo.so /usr/lib/libbar.so')
|
||||
|
||||
|
||||
def test_configure_cuda_runtime_can_disable_cudnn_for_training(self):
|
||||
module = self._load_train_vla_module()
|
||||
cfg = AttrDict(train=AttrDict(device='cuda', disable_cudnn=True))
|
||||
|
||||
original = module.torch.backends.cudnn.enabled
|
||||
try:
|
||||
module.torch.backends.cudnn.enabled = True
|
||||
module._configure_cuda_runtime(cfg)
|
||||
self.assertFalse(module.torch.backends.cudnn.enabled)
|
||||
finally:
|
||||
module.torch.backends.cudnn.enabled = original
|
||||
|
||||
|
||||
def test_resolve_run_output_dir_prefers_hydra_runtime_output_dir(self):
|
||||
module = self._load_train_vla_module()
|
||||
hydra_core_module = types.ModuleType('hydra.core')
|
||||
hydra_hydra_config_module = types.ModuleType('hydra.core.hydra_config')
|
||||
|
||||
class _Runtime:
|
||||
output_dir = '/tmp/hydra-output'
|
||||
|
||||
class _Cfg:
|
||||
runtime = _Runtime()
|
||||
|
||||
class HydraConfigStub:
|
||||
@staticmethod
|
||||
def initialized():
|
||||
return True
|
||||
@staticmethod
|
||||
def get():
|
||||
return _Cfg()
|
||||
|
||||
hydra_hydra_config_module.HydraConfig = HydraConfigStub
|
||||
with mock.patch.dict(sys.modules, {
|
||||
'hydra.core': hydra_core_module,
|
||||
'hydra.core.hydra_config': hydra_hydra_config_module,
|
||||
}):
|
||||
output_dir = module._resolve_run_output_dir()
|
||||
|
||||
self.assertEqual(Path(output_dir).resolve(), Path('/tmp/hydra-output').resolve())
|
||||
|
||||
|
||||
def test_train_script_uses_file_based_repo_root_on_sys_path(self):
|
||||
module = self._load_train_vla_module()
|
||||
|
||||
fake_sys_path = ['/tmp/site-packages', '/another/path']
|
||||
with mock.patch.object(module.sys, 'path', fake_sys_path):
|
||||
repo_root = module._ensure_repo_root_on_syspath()
|
||||
|
||||
self.assertEqual(Path(repo_root).resolve(), _REPO_ROOT.resolve())
|
||||
self.assertEqual(Path(fake_sys_path[0]).resolve(), _REPO_ROOT.resolve())
|
||||
|
||||
|
||||
def test_non_transformer_head_with_get_optim_groups_still_uses_custom_groups(self):
|
||||
module = self._load_train_vla_module()
|
||||
agent = FakeIMFAgent()
|
||||
|
||||
optimizer = module.build_training_optimizer(agent, lr=1e-4, weight_decay=0.123)
|
||||
|
||||
self.assertEqual(agent.noise_pred_net.optim_group_calls, [0.123])
|
||||
group_names = self._group_names(agent, optimizer)
|
||||
self.assertEqual(group_names[0], {'noise_pred_net.proj.weight'})
|
||||
self.assertEqual(group_names[1], {
|
||||
'noise_pred_net.proj.bias',
|
||||
'noise_pred_net.norm.weight',
|
||||
'noise_pred_net.norm.bias',
|
||||
})
|
||||
self.assertEqual(group_names[2], {'backbone.weight', 'backbone.bias', 'adapter.weight'})
|
||||
|
||||
|
||||
def test_transformer_training_prefers_head_optim_groups_and_keeps_remaining_trainable_params(self):
|
||||
module = self._load_train_vla_module()
|
||||
agent = FakeTransformerAgent()
|
||||
@@ -268,6 +366,22 @@ class TrainVLATransformerOptimizerTest(unittest.TestCase):
|
||||
self.assertNotIn('frozen.weight', optimizer_names)
|
||||
self.assertNotIn('frozen.bias', optimizer_names)
|
||||
|
||||
def test_any_head_with_get_optim_groups_uses_custom_groups_even_without_transformer_head_type(self):
|
||||
module = self._load_train_vla_module()
|
||||
agent = FakeTransformerAgent(head_type='imf')
|
||||
|
||||
with mock.patch.object(module, 'AdamW', RecordingAdamW):
|
||||
optimizer = module.build_training_optimizer(agent, lr=1e-4, weight_decay=0.123)
|
||||
|
||||
self.assertEqual(agent.noise_pred_net.optim_group_calls, [0.123])
|
||||
grouped_names = self._group_names(agent, optimizer)
|
||||
self.assertEqual(grouped_names[0], {'noise_pred_net.proj.weight'})
|
||||
self.assertEqual(
|
||||
grouped_names[1],
|
||||
{'noise_pred_net.proj.bias', 'noise_pred_net.norm.weight', 'noise_pred_net.norm.bias'},
|
||||
)
|
||||
self.assertEqual(grouped_names[2], {'backbone.weight', 'backbone.bias', 'adapter.weight'})
|
||||
|
||||
def test_transformer_optimizer_ignores_frozen_head_params_returned_by_head_groups(self):
|
||||
module = self._load_train_vla_module()
|
||||
agent = FakeTransformerAgent()
|
||||
|
||||
Reference in New Issue
Block a user