feat: add rollout trajectory image artifacts and swanlab logging

This commit is contained in:
Logic
2026-04-03 09:39:16 +08:00
parent 48f0eb8dd0
commit 0586a6e6c7
8 changed files with 626 additions and 21 deletions

View File

@@ -0,0 +1,79 @@
# IMF Rollout Trajectory Images and Short-Horizon Training Implementation Plan
> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking.
**Goal:** Add training-time rollout front trajectory image export plus SwanLab image logging, then start a new local IMF training run with `emb=384`, `layer=12`, `pred_horizon=8`, `num_action_steps=4`, `max_steps=50000`.
**Architecture:** Extend `eval_vla.py` so a rollout can emit one per-episode static front-view image with red EE trajectory overlay. Extend `train_vla.py` so rollout validation forces image export, forces video off, and uploads those per-episode images to SwanLab. Launch the requested new run through explicit command-line overrides rather than branch-default config changes.
**Tech Stack:** Python, PyTorch, Hydra/OmegaConf, MuJoCo, OpenCV, SwanLab.
---
### Task 1: Add and validate rollout image tests
**Files:**
- Modify: `tests/test_eval_vla_rollout_artifacts.py`
- Modify: `tests/test_train_vla_swanlab_logging.py`
- Modify: `tests/test_train_vla_rollout_validation.py`
- [ ] Add/adjust eval tests so they assert per-episode trajectory image paths are produced without requiring video export.
- [ ] Add/adjust training tests so they assert training-time rollout validation forces `record_video=false`.
- [ ] Add/adjust training tests so they assert trajectory image paths flow from eval summary into SwanLab media logging.
- [ ] Add/adjust training tests so they assert image media is logged, not only scalar reward metrics.
### Task 2: Implement per-episode front trajectory image export in eval
**Files:**
- Modify: `roboimi/demos/vla_scripts/eval_vla.py`
- Reuse/Read: `roboimi/utils/raw_action_trajectory_viewer.py`
- Modify: `roboimi/vla/conf/eval/eval.yaml`
- [ ] Add config plumbing for `save_trajectory_image` and `trajectory_image_camera_name`.
- [ ] Ensure the default training-time camera resolution path is pinned to `front`.
- [ ] Implement distinct per-episode image naming so 5 rollout episodes create 5 distinct PNGs.
- [ ] Reuse the existing red trajectory representation logic when composing the PNG.
- [ ] Ensure headless eval works under EGL even on machines with `DISPLAY` set.
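The EGL checklist item above can be sketched as follows; `configure_headless_gl` is a hypothetical stand-in for the helper in `eval_vla.py`, shown only to illustrate that a `DISPLAY` check must not short-circuit the EGL fallback:

```python
import os

def configure_headless_gl(headless: bool) -> None:
    """Force MuJoCo onto EGL for headless eval.

    Deliberately no early return when DISPLAY is set: the desktop GLX
    path would otherwise win and break offscreen rendering.
    """
    if not headless:
        return
    if os.environ.get('MUJOCO_GL'):
        return  # respect an explicit user override
    os.environ['MUJOCO_GL'] = 'egl'
```
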
### Task 3: Implement SwanLab rollout image logging in training
**Files:**
- Modify: `roboimi/demos/vla_scripts/train_vla.py`
- Modify: `tests/test_train_vla_swanlab_logging.py`
- Modify: `tests/test_train_vla_rollout_validation.py`
- [ ] Make `run_rollout_validation()` force `record_video=false`.
- [ ] Make `run_rollout_validation()` force `save_trajectory_image=true` and `trajectory_image_camera_name=front`.
- [ ] Ensure rollout validation still uses 5 episodes per validation event for the requested run.
- [ ] Add a best-effort helper that converts per-episode image paths into SwanLab image media payloads.
- [ ] Keep image-upload failures non-fatal and warning-only.
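A minimal sketch of the best-effort helper described above (names are illustrative, not the actual `train_vla.py` implementation): it skips episodes without an image and reduces any failure to a warning:

```python
import warnings

def build_image_payload(swanlab_module, episodes, context_label='rollout'):
    # Best-effort conversion of per-episode image paths into SwanLab media
    # payload entries; any failure is warning-only by design.
    image_factory = getattr(swanlab_module, 'Image', None)
    if image_factory is None:
        return {}
    payload = {}
    for fallback_idx, episode in enumerate(episodes):
        paths = episode.get('artifact_paths') or {}
        image_path = paths.get('trajectory_image')
        if not image_path:
            continue  # episode produced no image; nothing to upload
        idx = int(episode.get('episode_index', fallback_idx))
        try:
            payload[f'rollout/trajectory_image_episode_{idx}'] = image_factory(
                str(image_path), caption=f'{context_label} episode {idx}'
            )
        except Exception as exc:  # non-fatal by design
            warnings.warn(f'image payload prep failed for episode {idx}: {exc}')
    return payload
```
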
### Task 4: Verify action-chunk semantics for the new run
**Files:**
- Verify: `roboimi/vla/agent.py`
- Verify: `roboimi/vla/agent_imf.py`
- Test: `tests/test_imf_vla_agent.py`
- [ ] Confirm the existing queue logic still means “predict 8, execute first 4”.
- [ ] Do not change branch defaults unless strictly necessary; prefer launch-time overrides.
### Task 5: Verify and launch the requested local training run
**Files:**
- Use: `roboimi/demos/vla_scripts/train_vla.py`
- Use: `roboimi/demos/vla_scripts/eval_vla.py`
- [ ] Run the targeted verification suite.
- [ ] Run one real headless smoke eval and confirm a front trajectory PNG is produced while `video_mp4` stays null.
- [ ] Launch the new local training run with explicit overrides including:
- `agent=resnet_imf_attnres`
- `agent.head.n_emb=384`
- `agent.head.n_layer=12`
- `agent.pred_horizon=8`
- `agent.num_action_steps=4`
- `train.max_steps=50000`
- `train.rollout_num_episodes=5`
- `train.use_swanlab=true`
- current local baseline dataset/camera/CUDA/batch/lr/num_workers/backbone settings
- [ ] Verify PID, GPU allocation, log tail, and SwanLab run URL.

View File

@@ -0,0 +1,75 @@
# IMF Rollout Trajectory Images + Short-Horizon Training Design
## Background
The current RoboIMI IMF training flow can perform rollout validation and log scalar reward metrics to SwanLab, but it does not yet emit the qualitative rollout artifacts now required for analysis. The user wants training-time rollout validation to save front-view trajectory images with the model-generated trajectory drawn in red, upload those images to SwanLab, and then start a new local short-horizon IMF training run.
## Goals
1. During training-time rollout validation, save one **front-camera** trajectory image per rollout episode.
2. The image must show the rollout EE trajectory in red.
3. Reuse the existing repository trajectory visualization logic as much as practical, especially the existing red capsule-marker trajectory representation.
4. Save 5 rollout images locally for each validation event and upload the same 5 images to SwanLab.
5. Do **not** record rollout videos for this training-time validation flow.
6. Start a new local IMF-AttnRes training run with:
- `agent.head.n_emb=384`
- `agent.head.n_layer=12`
- `agent.pred_horizon=8`
- `agent.num_action_steps=4`
- `train.max_steps=50000`
- `train.rollout_num_episodes=5`
- `train.use_swanlab=true`
## Non-Goals
- No IMF architecture or loss-function change.
- No dataset schema change.
- No rollout video generation for the new training flow.
- No interactive viewer requirement.
## Existing Relevant Code
- `roboimi/demos/vla_scripts/eval_vla.py`
- already supports rollout summaries, optional trajectory export, and optional video export.
- `roboimi/utils/raw_action_trajectory_viewer.py`
- already contains the red trajectory capsule-marker construction logic.
- `roboimi/demos/vla_scripts/train_vla.py`
- already performs periodic rollout validation and scalar SwanLab logging.
- `roboimi/vla/agent.py`
- already implements “predict pred_horizon, execute first num_action_steps” queue semantics.
## Design Decisions
### 1. Artifact contract
Each rollout episode will emit one distinct PNG file under the eval artifact directory. The file naming/path contract must be per-episode, not shared, so a 5-episode validation event yields 5 stable image paths without overwriting.
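The per-episode contract can be pinned down with a small naming sketch (mirroring the scheme the eval script uses; the helper name is illustrative):

```python
from pathlib import Path

def episode_image_path(output_dir: str, camera_name: str, episode_idx: int) -> str:
    # One distinct, stable PNG path per episode (ep01, ep02, ...) so a
    # 5-episode validation event never overwrites an earlier image.
    filename = f'rollout_{camera_name}_ep{episode_idx + 1:02d}_trajectory.png'
    return str(Path(output_dir) / filename)
```
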
### 2. Trajectory definition
The red trajectory corresponds to the **actually executed model action sequence** over the rollout loop: the raw EE actions returned and consumed step-by-step by the policy loop. For the requested short-horizon run, this means the visualization reflects repeated execution of the first 4 actions from each predicted 8-action chunk, not every discarded future prediction from replanning.
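Under the stated assumption of 8-step prediction with 4-step execution, the executed sequence the image visualizes can be sketched as:

```python
from collections import deque

def executed_action_sequence(predicted_chunks, num_action_steps=4):
    # "Predict pred_horizon, execute first num_action_steps": only the first
    # num_action_steps actions of each predicted chunk are consumed; the
    # remainder is discarded when the policy replans.
    executed = []
    for chunk in predicted_chunks:  # each chunk holds pred_horizon actions
        queue = deque(chunk)
        for _ in range(min(num_action_steps, len(queue))):
            executed.append(queue.popleft())
    return executed
```
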
### 3. Camera choice
The training-time image export path is explicitly pinned to the repo's concrete `front` camera key. It must not silently fall back to `camera_names[0]` when that is not `front`.
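A sketch of the resolution rule (using a plain dict as a stand-in for the Hydra eval config): explicit overrides win, and `front` is preferred over a blind `camera_names[0]` fallback:

```python
def resolve_trajectory_camera(eval_cfg: dict) -> str:
    # An explicit camera name (or its alias) takes priority.
    name = (eval_cfg.get('trajectory_image_camera_name')
            or eval_cfg.get('trajectory_image_camera'))
    if name:
        return str(name)
    cameras = [str(c) for c in eval_cfg.get('camera_names') or []]
    if cameras:
        # Prefer 'front' whenever it is configured at all.
        return 'front' if 'front' in cameras else cameras[0]
    raise ValueError('save_trajectory_image=true requires a camera name')
```
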
### 4. Rendering path
`eval_vla.py` will add a lightweight headless image-export path that:
- renders the `front` camera frame,
- overlays the trajectory using the existing red trajectory representation,
- saves a static PNG per episode.
The implementation may reuse the existing marker-construction logic directly and add a minimal helper for final image composition/export.
### 5. Training-time behavior
`train_vla.py` rollout validation must explicitly:
- request/save trajectory images,
- keep `record_video=false`,
- return the 5 per-episode image paths in the rollout summary payload,
- upload those 5 images to SwanLab,
- keep image-upload failures non-fatal.
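The forced settings above amount to a small override step on the rollout eval config, sketched here on a plain dict rather than the actual DictConfig:

```python
def force_rollout_artifact_settings(eval_cfg: dict) -> dict:
    # Training-time rollout validation: trajectory images on, video off,
    # camera pinned to 'front', summary JSON kept for path plumbing.
    forced = dict(eval_cfg)
    forced.update(
        record_video=False,
        save_trajectory_image=True,
        trajectory_image_camera_name='front',
        save_summary_json=True,
    )
    return forced
```
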
## Expected User-Visible Outcome
For each scheduled validation event in the new training run:
- 5 rollout episodes execute,
- 5 front-view PNG trajectory images are saved locally,
- the same 5 images are uploaded to SwanLab,
- scalar reward metrics continue to be logged,
- no rollout videos are generated.
## Risks and Mitigations
- **Headless rendering conflicts from desktop env vars**: force headless eval onto EGL when `headless=true`.
- **Image overwrite risk**: use explicit per-episode artifact paths.
- **SwanLab media API mismatch**: isolate media logging in a small best-effort helper.

View File

@@ -26,6 +26,7 @@ from hydra.utils import instantiate
from einops import rearrange
from roboimi.utils.act_ex_utils import sample_transfer_pose
from roboimi.utils.raw_action_trajectory_viewer import build_trajectory_capsule_markers
from roboimi.vla.eval_utils import execute_policy_action
sys.path.append(os.getcwd())
@@ -41,10 +42,8 @@ def _configure_headless_mujoco_gl(eval_cfg: DictConfig) -> None:
return
if os.environ.get('MUJOCO_GL'):
return
if os.environ.get('DISPLAY'):
return
os.environ['MUJOCO_GL'] = 'egl'
log.info('headless eval detected; set MUJOCO_GL=egl')
def make_sim_env(task_name: str, headless: bool = False):
@@ -204,10 +203,12 @@ def _resolve_artifact_paths(eval_cfg: DictConfig) -> dict[str, Optional[str]]:
save_trajectory = bool(
eval_cfg.get('save_trajectory', False) or eval_cfg.get('save_trajectory_npz', False)
)
save_trajectory_image = bool(eval_cfg.get('save_trajectory_image', False))
wants_artifacts = any([
bool(eval_cfg.get('save_artifacts', False)),
save_timing,
save_trajectory,
save_trajectory_image,
bool(eval_cfg.get('record_video', False)),
])
output_dir: Optional[Path] = None
@@ -233,6 +234,22 @@ def _resolve_artifact_paths(eval_cfg: DictConfig) -> dict[str, Optional[str]]:
else:
raise ValueError('record_video=true requires eval.video_camera_name or a non-empty eval.camera_names')
trajectory_image_camera_name = None
if save_trajectory_image:
configured_camera_name = eval_cfg.get('trajectory_image_camera_name', None)
if configured_camera_name is None:
configured_camera_name = eval_cfg.get('trajectory_image_camera', None)
if configured_camera_name is not None:
trajectory_image_camera_name = str(configured_camera_name)
elif eval_cfg.get('camera_names'):
camera_names = [str(name) for name in eval_cfg.camera_names]
trajectory_image_camera_name = 'front' if 'front' in camera_names else camera_names[0]
else:
raise ValueError(
'save_trajectory_image=true requires eval.trajectory_image_camera_name '
'or a non-empty eval.camera_names'
)
return {
'output_dir': str(output_dir) if output_dir is not None else None,
'summary_json': (
@@ -257,6 +274,7 @@ def _resolve_artifact_paths(eval_cfg: DictConfig) -> dict[str, Optional[str]]:
else None
),
'video_camera_name': video_camera_name,
'trajectory_image_camera_name': trajectory_image_camera_name,
}
@@ -285,6 +303,109 @@ def _open_video_writer(output_path: str, frame_size: tuple[int, int], fps: int):
return writer
def _episode_trajectory_image_path(
artifact_paths: dict[str, Optional[str]],
episode_idx: int,
) -> Optional[str]:
output_dir = artifact_paths.get('output_dir')
camera_name = artifact_paths.get('trajectory_image_camera_name')
if output_dir is None or camera_name is None:
return None
return str(Path(output_dir) / f'rollout_{camera_name}_ep{episode_idx + 1:02d}_trajectory.png')
def _build_action_trajectory_positions(raw_actions: list[np.ndarray]) -> dict[str, np.ndarray]:
if not raw_actions:
empty = np.zeros((0, 3), dtype=np.float32)
return {'left': empty, 'right': empty}
raw_action_array = np.asarray(raw_actions, dtype=np.float32)
return {
'left': raw_action_array[:, :3].astype(np.float32, copy=True),
'right': raw_action_array[:, 7:10].astype(np.float32, copy=True),
}
def _append_capsule_markers_to_scene(scene, markers: list[dict]) -> None:
import mujoco
for marker in markers:
if scene.ngeom >= scene.maxgeom:
break
geom = scene.geoms[scene.ngeom]
mujoco.mjv_initGeom(
geom,
mujoco.mjtGeom.mjGEOM_CAPSULE,
np.zeros(3, dtype=np.float64),
np.zeros(3, dtype=np.float64),
np.eye(3, dtype=np.float64).reshape(-1),
np.asarray(marker['rgba'], dtype=np.float32),
)
mujoco.mjv_connector(
geom,
mujoco.mjtGeom.mjGEOM_CAPSULE,
float(marker['radius']),
np.asarray(marker['from'], dtype=np.float64),
np.asarray(marker['to'], dtype=np.float64),
)
scene.ngeom += 1
def _save_rollout_trajectory_image(
env,
output_path: Optional[str],
raw_actions: list[np.ndarray],
camera_name: Optional[str],
*,
line_radius: float = 0.004,
max_markers: int = 1500,
) -> Optional[str]:
if output_path is None or camera_name is None:
return None
output_path = str(output_path)
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
frame = None
owned_renderer = None
positions = _build_action_trajectory_positions(raw_actions)
markers = build_trajectory_capsule_markers(
positions,
max_markers=max_markers,
radius=line_radius,
)
try:
renderer = None
if callable(getattr(env, '_get_or_create_offscreen_renderer', None)):
renderer = env._get_or_create_offscreen_renderer()
elif hasattr(env, 'mj_model') and hasattr(env, 'mj_data'):
import mujoco
renderer = mujoco.Renderer(env.mj_model, height=480, width=640)
owned_renderer = renderer
if renderer is not None and hasattr(env, 'mj_data'):
renderer.update_scene(env.mj_data, camera=str(camera_name))
if markers:
_append_capsule_markers_to_scene(renderer.scene, markers)
frame = renderer.render()[:, :, ::-1]
finally:
if owned_renderer is not None:
owned_renderer.close()
if frame is None and callable(getattr(env, '_get_image_obs', None)):
obs = env._get_image_obs()
frame = _get_video_frame(obs, str(camera_name))
if frame is None:
return None
import cv2
cv2.imwrite(output_path, frame)
return output_path
class _RolloutVideoRecorder:
def __init__(self, output_path: Optional[str], fps: int):
self.output_path = output_path
@@ -582,6 +703,7 @@ def _run_eval(cfg: DictConfig):
model_forward_flags = []
episode_reward = 0.0
episode_max_reward = float('-inf')
episode_raw_actions: list[np.ndarray] = []
with torch.inference_mode():
for t in tqdm(range(eval_cfg.max_timesteps), desc=f"回合 {episode_idx + 1}"):
@@ -612,6 +734,7 @@ def _run_eval(cfg: DictConfig):
# Convert to numpy
raw_action = _to_numpy_action(action)
episode_raw_actions.append(raw_action.astype(np.float32, copy=True))
# Debug: print the current timestep's action (controlled by config)
if eval_cfg.get('verbose_action', False):
@@ -696,6 +819,12 @@ def _run_eval(cfg: DictConfig):
episode_artifact_paths = {
'video': artifact_paths['video_mp4'],
'trajectory': artifact_paths['trajectory_npz'],
'trajectory_image': _save_rollout_trajectory_image(
env,
_episode_trajectory_image_path(artifact_paths, episode_idx),
episode_raw_actions,
artifact_paths['trajectory_image_camera_name'],
),
'timing': artifact_paths['timing_json'] or artifact_paths['summary_json'],
}

View File

@@ -299,6 +299,45 @@ def _log_to_swanlab(swanlab_module, payload, step=None):
log.warning(f"SwanLab log failed at step {step}: {exc}")
def _log_rollout_trajectory_images_to_swanlab(
swanlab_module,
rollout_stats,
step=None,
context_label: str = 'rollout',
):
if swanlab_module is None or not rollout_stats:
return
image_factory = getattr(swanlab_module, 'Image', None)
if image_factory is None:
return
payload = {}
for fallback_episode_index, episode in enumerate(rollout_stats.get('episodes', [])):
if not isinstance(episode, dict):
continue
artifact_paths = episode.get('artifact_paths', {})
if not isinstance(artifact_paths, dict):
continue
trajectory_image = artifact_paths.get('trajectory_image')
if not trajectory_image:
continue
episode_index = int(episode.get('episode_index', fallback_episode_index))
caption = f'{context_label} trajectory image - episode {episode_index} (front)'
try:
payload[f'rollout/trajectory_image_episode_{episode_index}'] = image_factory(
str(trajectory_image),
caption=caption,
)
except Exception as exc:
log.warning(
f"SwanLab rollout trajectory image upload prep failed at step {step}: {exc}"
)
if payload:
_log_to_swanlab(swanlab_module, payload, step=step)
def _finish_swanlab(swanlab_module):
if swanlab_module is None:
return
@@ -661,6 +700,13 @@ def _run_training(cfg: DictConfig):
rollout_cfg.eval.headless = True
rollout_cfg.eval.device = 'cpu'
rollout_cfg.eval.verbose_action = False
rollout_cfg.eval.record_video = False
rollout_cfg.eval.save_trajectory_image = True
rollout_cfg.eval.trajectory_image_camera_name = 'front'
rollout_cfg.eval.save_summary_json = True
rollout_cfg.eval.artifact_dir = str(
(run_output_dir / 'rollout_artifacts' / checkpoint_path.stem).resolve()
)
log.info(
"🎯 开始 checkpoint rollout 验证: %s (episodes=%s, headless=True)",
@@ -867,6 +913,12 @@ def _run_training(cfg: DictConfig):
},
step=step,
)
_log_rollout_trajectory_images_to_swanlab(
swanlab_module,
rollout_stats,
step=step,
context_label=f'epoch {completed_epoch} rollout',
)
if rollout_avg_reward > best_rollout_reward:
best_rollout_reward = rollout_avg_reward
best_model_path = default_best_model_path

View File

@@ -41,6 +41,9 @@ save_timing: false # whether to save timing.json (per-stage timings)
save_trajectory: false # whether to save trajectory.npz (raw EE action + post-execution EE pose)
save_summary_json: false # whether to save a JSON-friendly rollout summary
save_trajectory_npz: false # whether to save per-step trajectory/timing/EE pose as NPZ
save_trajectory_image: false # whether to save a static PNG with the red EE trajectory overlay
trajectory_image_camera: null # alias for trajectory_image_camera_name
trajectory_image_camera_name: null # camera for trajectory image export; defaults to camera_names[0] when empty
record_video: false # whether to record a rollout mp4 from a single camera stream
video_camera: null # alias for video_camera_name
video_camera_name: null # camera for video recording; defaults to camera_names[0] when empty

View File

@@ -102,8 +102,10 @@ class EvalVLARolloutArtifactsTest(unittest.TestCase):
self.assertIn('artifact_dir', eval_cfg)
self.assertFalse(eval_cfg.save_summary_json)
self.assertFalse(eval_cfg.save_trajectory_npz)
self.assertFalse(eval_cfg.save_trajectory_image)
self.assertFalse(eval_cfg.record_video)
self.assertIsNone(eval_cfg.artifact_dir)
self.assertIsNone(eval_cfg.trajectory_image_camera_name)
self.assertIsNone(eval_cfg.video_camera_name)
self.assertEqual(eval_cfg.video_fps, 30)
@@ -133,6 +135,8 @@ class EvalVLARolloutArtifactsTest(unittest.TestCase):
'artifact_dir': tmpdir,
'save_summary_json': True,
'save_trajectory_npz': True,
'save_trajectory_image': True,
'trajectory_image_camera_name': 'front',
'record_video': True,
'video_camera_name': 'front',
'video_fps': 12,
@@ -176,12 +180,14 @@ class EvalVLARolloutArtifactsTest(unittest.TestCase):
trajectory_path = Path(artifacts['trajectory_npz'])
summary_path = Path(artifacts['summary_json'])
video_path = Path(artifacts['video_mp4'])
trajectory_image_path = Path(summary['episodes'][0]['artifact_paths']['trajectory_image'])
self.assertEqual(Path(artifacts['output_dir']), Path(tmpdir))
self.assertEqual(artifacts['video_camera_name'], 'front')
self.assertTrue(trajectory_path.exists())
self.assertTrue(summary_path.exists())
self.assertTrue(video_path.exists())
self.assertTrue(trajectory_image_path.exists())
rollout_npz = np.load(trajectory_path)
np.testing.assert_array_equal(rollout_npz['episode_index'], np.array([0, 0]))
@@ -218,11 +224,121 @@ class EvalVLARolloutArtifactsTest(unittest.TestCase):
saved_summary = json.load(fh)
self.assertEqual(saved_summary['artifacts']['trajectory_npz'], str(trajectory_path))
self.assertEqual(saved_summary['artifacts']['video_mp4'], str(video_path))
self.assertEqual(
saved_summary['episodes'][0]['artifact_paths']['trajectory_image'],
str(trajectory_image_path),
)
self.assertEqual(saved_summary['episode_rewards'], [3.0])
self.assertAlmostEqual(summary['avg_reward'], 3.0)
self.assertIn('avg_obs_read_time_ms', summary)
self.assertIn('avg_env_step_time_ms', summary)
def test_run_eval_exports_front_trajectory_images_without_video_dependency(self):
actions = [
np.arange(16, dtype=np.float32),
np.arange(16, dtype=np.float32) + 10.0,
np.arange(16, dtype=np.float32) + 100.0,
np.arange(16, dtype=np.float32) + 110.0,
]
fake_agent = _FakeAgent(actions)
fake_env = _FakeEnv()
with tempfile.TemporaryDirectory() as tmpdir:
cfg = OmegaConf.create(
{
'agent': {},
'eval': {
'ckpt_path': 'checkpoints/vla_model_best.pt',
'num_episodes': 2,
'max_timesteps': 2,
'device': 'cpu',
'task_name': 'sim_transfer',
'camera_names': ['top', 'front'],
'use_smoothing': True,
'smooth_alpha': 0.5,
'verbose_action': False,
'headless': True,
'artifact_dir': tmpdir,
'save_trajectory_image': True,
'record_video': False,
},
}
)
trajectory_image_calls = []
def fake_save_rollout_trajectory_image(
env,
output_path,
raw_actions,
camera_name,
*,
line_radius=0.004,
max_markers=1500,
):
del env, line_radius, max_markers
trajectory_image_calls.append(
{
'output_path': output_path,
'camera_name': camera_name,
'raw_actions': [np.array(action, copy=True) for action in raw_actions],
}
)
if output_path is None:
return None
output_path = Path(output_path)
output_path.parent.mkdir(parents=True, exist_ok=True)
output_path.write_bytes(b'fake-png')
return str(output_path)
with mock.patch.object(
eval_vla,
'load_checkpoint',
return_value=(fake_agent, None),
), mock.patch.object(
eval_vla,
'make_sim_env',
return_value=fake_env,
), mock.patch.object(
eval_vla,
'sample_transfer_pose',
return_value=np.array([0.1, 0.2, 0.3], dtype=np.float32),
), mock.patch.object(
eval_vla,
'tqdm',
side_effect=lambda iterable, **kwargs: iterable,
), mock.patch.object(
eval_vla,
'_save_rollout_trajectory_image',
side_effect=fake_save_rollout_trajectory_image,
) as save_trajectory_image_mock, mock.patch.object(
eval_vla,
'_open_video_writer',
) as open_video_writer_mock:
summary = eval_vla._run_eval(cfg)
self.assertEqual(save_trajectory_image_mock.call_count, 2)
open_video_writer_mock.assert_not_called()
self.assertIsNone(summary['artifacts']['video_mp4'])
self.assertEqual(summary['artifacts']['trajectory_image_camera_name'], 'front')
self.assertEqual(
[call['camera_name'] for call in trajectory_image_calls],
['front', 'front'],
)
first_episode_path = Path(summary['episodes'][0]['artifact_paths']['trajectory_image'])
second_episode_path = Path(summary['episodes'][1]['artifact_paths']['trajectory_image'])
self.assertTrue(first_episode_path.exists())
self.assertTrue(second_episode_path.exists())
self.assertNotEqual(first_episode_path, second_episode_path)
self.assertEqual(first_episode_path.parent, Path(tmpdir))
self.assertEqual(second_episode_path.parent, Path(tmpdir))
np.testing.assert_array_equal(trajectory_image_calls[0]['raw_actions'][0], actions[0])
np.testing.assert_array_equal(trajectory_image_calls[0]['raw_actions'][1], actions[1])
np.testing.assert_array_equal(trajectory_image_calls[1]['raw_actions'][0], actions[2])
np.testing.assert_array_equal(trajectory_image_calls[1]['raw_actions'][1], actions[3])
if __name__ == '__main__':
unittest.main()

View File

@@ -234,7 +234,28 @@ class TrainVLARolloutValidationTest(unittest.TestCase):
}
)
agent = _FakeAgent()
rollout_mock = mock.Mock(
side_effect=[
{
'avg_reward': 2.0,
'episodes': [
{
'episode_index': 0,
'artifact_paths': {'trajectory_image': 'artifacts/epoch_49_front.png'},
},
],
},
{
'avg_reward': 1.0,
'episodes': [
{
'episode_index': 0,
'artifact_paths': {'trajectory_image': 'artifacts/epoch_99_front.png'},
},
],
},
]
)
swanlab_log_mock = mock.Mock()
saved_checkpoints = []
@@ -281,17 +302,22 @@ class TrainVLARolloutValidationTest(unittest.TestCase):
self.assertEqual(rollout_mock.call_count, 2)
first_rollout_cfg = rollout_mock.call_args_list[0].args[0]
second_rollout_cfg = rollout_mock.call_args_list[1].args[0]
self.assertTrue(first_rollout_cfg.eval.ckpt_path.endswith('checkpoints/vla_model_step_49.pt'))
self.assertTrue(second_rollout_cfg.eval.ckpt_path.endswith('checkpoints/vla_model_step_99.pt'))
self.assertEqual(first_rollout_cfg.eval.num_episodes, 3)
self.assertTrue(first_rollout_cfg.eval.headless)
self.assertEqual(first_rollout_cfg.eval.device, 'cpu')
self.assertFalse(first_rollout_cfg.eval.verbose_action)
self.assertFalse(first_rollout_cfg.eval.record_video)
self.assertTrue(first_rollout_cfg.eval.save_trajectory_image)
self.assertEqual(first_rollout_cfg.eval.trajectory_image_camera_name, 'front')
self.assertEqual(cfg.eval.ckpt_path, 'unused.pt')
self.assertEqual(cfg.eval.num_episodes, 99)
self.assertFalse(cfg.eval.headless)
self.assertEqual(cfg.eval.device, 'cpu')
self.assertFalse(cfg.eval.verbose_action)
self.assertNotIn('save_trajectory_image', cfg.eval)
self.assertNotIn('trajectory_image_camera_name', cfg.eval)
rollout_reward_logs = [
call.args[1]['rollout/avg_reward']
@@ -769,10 +795,8 @@ class TrainVLARolloutValidationTest(unittest.TestCase):
'dataset_len': 1,
},
)
self.assertEqual(len(saved_checkpoints), 1)
self.assertTrue(saved_checkpoints[0][0].endswith('checkpoints/vla_model_final.pt'))
if __name__ == '__main__': if __name__ == '__main__':

View File

@@ -115,13 +115,15 @@ class FakeAgent(nn.Module):
class FakeSwanLab:
def __init__(self, init_error=None, log_errors=None, finish_error=None, image_errors=None):
self.init_error = init_error
self.log_errors = list(log_errors or [])
self.finish_error = finish_error
self.image_errors = list(image_errors or [])
self.init_calls = []
self.log_calls = []
self.finish_calls = 0
self.image_calls = []
def init(self, project, experiment_name=None, config=None):
self.init_calls.append({ self.init_calls.append({
@@ -138,6 +140,18 @@ class FakeSwanLab:
if self.log_errors:
raise self.log_errors.pop(0)
def Image(self, path, caption=None):
self.image_calls.append({
'path': path,
'caption': caption,
})
if self.image_errors:
raise self.image_errors.pop(0)
return {
'path': path,
'caption': caption,
}
def finish(self):
self.finish_calls += 1
if self.finish_error is not None:
@@ -149,6 +163,119 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):
config_text = _CONFIG_PATH.read_text(encoding='utf-8')
self.assertIn('use_swanlab: false', config_text)
    def test_log_rollout_trajectory_images_to_swanlab_uploads_episode_artifacts(self):
        module = self._load_train_vla_module()
        fake_swanlab = FakeSwanLab()

        module._log_rollout_trajectory_images_to_swanlab(
            fake_swanlab,
            {
                'episodes': [
                    {
                        'episode_index': 0,
                        'artifact_paths': {'trajectory_image': 'artifacts/episode_0_front.png'},
                    },
                    {
                        'episode_index': 3,
                        'artifact_paths': {'trajectory_image': 'artifacts/episode_3_front.png'},
                    },
                    {
                        'episode_index': 7,
                        'artifact_paths': {'trajectory_image': None},
                    },
                    {
                        'episode_index': 8,
                        'artifact_paths': {},
                    },
                ],
            },
            step=12,
            context_label='epoch 1 rollout',
        )

        self.assertEqual(
            fake_swanlab.image_calls,
            [
                {
                    'path': 'artifacts/episode_0_front.png',
                    'caption': 'epoch 1 rollout trajectory image - episode 0 (front)',
                },
                {
                    'path': 'artifacts/episode_3_front.png',
                    'caption': 'epoch 1 rollout trajectory image - episode 3 (front)',
                },
            ],
        )
        self.assertIn(
            (
                {
                    'rollout/trajectory_image_episode_0': {
                        'path': 'artifacts/episode_0_front.png',
                        'caption': 'epoch 1 rollout trajectory image - episode 0 (front)',
                    },
                    'rollout/trajectory_image_episode_3': {
                        'path': 'artifacts/episode_3_front.png',
                        'caption': 'epoch 1 rollout trajectory image - episode 3 (front)',
                    },
                },
                12,
            ),
            fake_swanlab.log_calls,
        )

    def test_log_rollout_trajectory_images_to_swanlab_is_best_effort(self):
        module = self._load_train_vla_module()
        fake_swanlab = FakeSwanLab(image_errors=[RuntimeError('decode failed')])

        with mock.patch.object(module.log, 'warning') as warning_mock:
            module._log_rollout_trajectory_images_to_swanlab(
                fake_swanlab,
                {
                    'episodes': [
                        {
                            'episode_index': 0,
                            'artifact_paths': {'trajectory_image': 'artifacts/bad_episode.png'},
                        },
                        {
                            'episode_index': 1,
                            'artifact_paths': {'trajectory_image': 'artifacts/good_episode.png'},
                        },
                    ],
                },
                step=7,
                context_label='checkpoint rollout',
            )

        self.assertEqual(
            fake_swanlab.image_calls,
            [
                {
                    'path': 'artifacts/bad_episode.png',
                    'caption': 'checkpoint rollout trajectory image - episode 0 (front)',
                },
                {
                    'path': 'artifacts/good_episode.png',
                    'caption': 'checkpoint rollout trajectory image - episode 1 (front)',
                },
            ],
        )
        self.assertIn(
            (
                {
                    'rollout/trajectory_image_episode_1': {
                        'path': 'artifacts/good_episode.png',
                        'caption': 'checkpoint rollout trajectory image - episode 1 (front)',
                    },
                },
                7,
            ),
            fake_swanlab.log_calls,
        )
        warning_messages = [call.args[0] for call in warning_mock.call_args_list]
        self.assertTrue(
            any('SwanLab rollout trajectory image upload prep failed' in message for message in warning_messages)
        )
    def _load_train_vla_module(self):
        hydra_module = types.ModuleType('hydra')
        hydra_utils_module = types.ModuleType('hydra.utils')
@@ -356,8 +483,8 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):
        final_payload, final_step = fake_swanlab.log_calls[-1]
        self.assertEqual(final_step, cfg.train.max_steps)
        self.assertTrue(final_payload['final/checkpoint_path'].endswith('checkpoints/vla_model_final.pt'))
        self.assertTrue(final_payload['final/best_checkpoint_path'].endswith('checkpoints/vla_model_best.pt'))
        self.assertEqual(fake_swanlab.finish_calls, 1)

    def test_run_training_skips_swanlab_when_disabled(self):
@@ -512,10 +639,10 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):
        def fake_torch_load(path, map_location=None):
            del map_location
            path = Path(path).resolve()
            if path == resume_path.resolve():
                return resume_checkpoint_state
            if path == best_path.resolve():
                return best_checkpoint_state
            raise AssertionError(f'unexpected load path: {path}')
@@ -538,8 +665,8 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):
        final_payload, final_step = fake_swanlab.log_calls[-1]
        self.assertEqual(final_step, cfg.train.max_steps)
        self.assertTrue(final_payload['final/best_checkpoint_path'].endswith('checkpoints/vla_model_best.pt'))
        self.assertFalse(any(path.endswith('checkpoints/vla_model_best.pt') for path in saved_paths))

    def test_run_training_resume_ignores_best_checkpoint_without_rollout_metric(self):
        module = self._load_train_vla_module()
@@ -594,10 +721,10 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):
        def fake_torch_load(path, map_location=None):
            del map_location
            path = Path(path).resolve()
            if path == resume_path.resolve():
                return resume_checkpoint_state
            if path == best_path.resolve():
                return stale_best_checkpoint_state
            raise AssertionError(f'unexpected load path: {path}')