From 74f4963613aa5621a99433cfce8de9bf1b57eb97 Mon Sep 17 00:00:00 2001
From: Logic
Date: Fri, 17 Apr 2026 18:46:02 +0800
Subject: [PATCH] feat: add lewm-conditioned imf training and sigreg loss

---
 roboimi/demos/vla_scripts/train_vla.py        |  78 +++-
 .../refresh_experiment_suite_status.py        | 267 ++++++++++++
 roboimi/vla/agent.py                          |  17 +-
 roboimi/vla/agent_imf.py                      | 322 +++++++++++++-
 .../agent/lewm_resnet_query_imf_attnres.yaml  |  77 ++++
 .../backbone/lewm_resnet_query_fusion.yaml    |   7 +
 .../vla/conf/modules/lewm_state_encoder.yaml  |   5 +
 roboimi/vla/data/simpe_robot_dataset.py       |  63 +++
 roboimi/vla/models/backbones/__init__.py      |  24 +-
 .../backbones/lewm_resnet_query_fusion.py     | 409 ++++++++++++++++++
 roboimi/vla/modules/encoders.py               |  22 +-
 tests/test_imf_vla_agent.py                   | 226 ++++++++++
 ...test_simple_robot_dataset_image_loading.py |  21 +
 tests/test_train_vla_swanlab_logging.py       | 120 +++++
 14 files changed, 1634 insertions(+), 24 deletions(-)
 create mode 100755 roboimi/scripts/refresh_experiment_suite_status.py
 create mode 100644 roboimi/vla/conf/agent/lewm_resnet_query_imf_attnres.yaml
 create mode 100644 roboimi/vla/conf/backbone/lewm_resnet_query_fusion.yaml
 create mode 100644 roboimi/vla/conf/modules/lewm_state_encoder.yaml
 create mode 100644 roboimi/vla/models/backbones/lewm_resnet_query_fusion.py

diff --git a/roboimi/demos/vla_scripts/train_vla.py b/roboimi/demos/vla_scripts/train_vla.py
index e4a063c..afbd23f 100644
--- a/roboimi/demos/vla_scripts/train_vla.py
+++ b/roboimi/demos/vla_scripts/train_vla.py
@@ -237,6 +237,32 @@ def build_training_optimizer(agent, lr, weight_decay):
     return AdamW(optim_groups, lr=lr, weight_decay=weight_decay)
 
 
+def load_state_dict_ignoring_shape_mismatches(module, incoming_state_dict):
+    """Load only checkpoint tensors whose keys exist locally and whose shapes match."""
+    current_state_dict = module.state_dict()
+    compatible_state_dict = {}
+    mismatched_keys = []
+    missing_keys = []  # present in the checkpoint but absent from the local module
+
+    for key, value in incoming_state_dict.items():
+        if key not in current_state_dict:
+            missing_keys.append(key)
+            continue
+        if current_state_dict[key].shape != value.shape:
+            mismatched_keys.append(key)
+            continue
+        compatible_state_dict[key] = value
+
+    merged_state_dict = dict(current_state_dict)
+    merged_state_dict.update(compatible_state_dict)
+    module.load_state_dict(merged_state_dict, strict=True)
+    return {
+        'loaded_keys': sorted(compatible_state_dict.keys()),
+        'missing_keys': sorted(missing_keys),
+        'mismatched_keys': sorted(mismatched_keys),
+    }
+
+
 def _init_swanlab(cfg):
     """按需初始化 SwanLab,并在缺少依赖或认证失败时快速失败。"""
     if not bool(cfg.train.get('use_swanlab', False)):
@@ -509,18 +535,23 @@ def _run_training(cfg: DictConfig):
 
         try:
             checkpoint = torch.load(ckpt_path, map_location=cfg.train.device)
 
-            # 只加载模型权重(不加载 optimizer、scheduler)
-            missing_keys, unexpected_keys = agent.load_state_dict(
+            load_info = load_state_dict_ignoring_shape_mismatches(
+                agent,
                 checkpoint['model_state_dict'],
-                strict=False  # 允许部分加载(结构不完全匹配时)
             )
 
             log.info(f"✅ [Finetune] 模型权重加载成功")
-            if missing_keys:
-                log.warning(f"⚠️ [Finetune] 缺少的键 ({len(missing_keys)} 个): {missing_keys[:5]}...")
-            if unexpected_keys:
-                log.warning(f"⚠️ [Finetune] 多余的键 ({len(unexpected_keys)} 个): {unexpected_keys[:5]}...")
+            if load_info['missing_keys']:
+                log.warning(
+                    f"⚠️ [Finetune] checkpoint keys absent from the local model ({len(load_info['missing_keys'])}): "
+                    f"{load_info['missing_keys'][:5]}..."
+                )
+            if load_info['mismatched_keys']:
+                log.warning(
+                    f"⚠️ [Finetune] keys skipped due to shape mismatch ({len(load_info['mismatched_keys'])}): "
+                    f"{load_info['mismatched_keys'][:5]}..."
+                )
 
             log.info(f"📊 [Finetune] 预训练信息: 步骤={checkpoint.get('step', 'N/A')}, 损失={checkpoint.get('loss', 'N/A')}")
             log.info(f"📈 [Finetune] 使用新的训练配置(lr={cfg.train.lr}, max_steps={cfg.train.max_steps})")
@@ -652,13 +683,35 @@ def _run_training(cfg: DictConfig):
                 if key in batch_data:
                     images[cam_name] = batch_data[key]
 
-            return {
+            agent_input = {
                 'images': images,
                 'qpos': batch_data['observation.state'],  # SimpleRobotDataset 使用 observation.state
                 'action': batch_data['action'],
                 'action_is_pad': batch_data.get('action_is_pad', None)  # 传递padding mask
             }
 
+            lewm_images = {}
+            lewm_future_images = {}
+            for cam_name in cfg.data.camera_names:
+                lewm_obs_key = f"lewm.observation.{cam_name}"
+                if lewm_obs_key in batch_data:
+                    lewm_images[cam_name] = batch_data[lewm_obs_key]
+
+                lewm_future_key = f"lewm.future.{cam_name}"
+                if lewm_future_key in batch_data:
+                    lewm_future_images[cam_name] = batch_data[lewm_future_key]
+
+            if 'lewm.observation.state' in batch_data:
+                agent_input['lewm_qpos'] = batch_data['lewm.observation.state']
+            if lewm_images:
+                agent_input['lewm_images'] = lewm_images
+            if 'lewm.future.state' in batch_data:
+                agent_input['lewm_future_qpos'] = batch_data['lewm.future.state']
+            if lewm_future_images:
+                agent_input['lewm_future_images'] = lewm_future_images
+
+            return agent_input
+
         def save_checkpoint(checkpoint_path: Path, step: int, loss_value, val_loss=None, rollout_avg_reward=None):
             agent_stats = agent.get_normalization_stats()
             torch.save({
@@ -809,6 +862,15 @@ def _run_training(cfg: DictConfig):
                     },
                     step=step,
                 )
+                if hasattr(agent, 'get_last_loss_breakdown'):
+                    loss_breakdown = agent.get_last_loss_breakdown()
+                    extra_train_metrics = {
+                        f"train/{key}": value
+                        for key, value in loss_breakdown.items()
+                        if value is not None and key != 'loss'
+                    }
+                    if extra_train_metrics:
+                        _log_to_swanlab(swanlab_module, extra_train_metrics, step=step)
 
             # =====================================================================
             # 检查点保存与验证
diff --git a/roboimi/scripts/refresh_experiment_suite_status.py b/roboimi/scripts/refresh_experiment_suite_status.py
new file mode 100755
index 0000000..0ddbb9d
--- /dev/null
+++ b/roboimi/scripts/refresh_experiment_suite_status.py
@@ -0,0 +1,267 @@
+#!/usr/bin/env python3
+
+from __future__ import annotations
+
+import argparse
+import datetime as dt
+import json
+import pathlib
+import re
+import shlex
+import subprocess
+from collections import defaultdict
+from typing import Any
+
+
+# Training-progress patterns: "步骤 X/Y" lines and tqdm-style "| X/Y" bars.
+STEP_PAT = re.compile(r"步骤\s+(\d+)/(\d+)")
+BAR_PAT = re.compile(r"\|\s*(\d+)/(\d+)")
+
+
+def normalize_chunks(text: str):
+    for part in re.split(r"[\r\n]+", text):
+        part = part.strip()
+        if part:
+            yield part
+
+
+def parse_latest_line(text: str) -> tuple[str, int | None]:
+    latest_line = ""
+    latest_step = None
+    for line in normalize_chunks(text):
+        if "步骤" not in line and "训练中:" not in line:
+            continue
+        latest_line = line
+        match = STEP_PAT.search(line) or BAR_PAT.search(line)
+        if match:
+            latest_step = int(match.group(1))
+    return latest_line, latest_step
+
+
+def now_iso() -> str:
+    return dt.datetime.now(
+        dt.timezone(dt.timedelta(hours=8)),
+    ).isoformat(timespec="seconds")
+
+
+def run_cmd(cmd: list[str], check: bool = True) -> subprocess.CompletedProcess[str]:
+    return subprocess.run(cmd, capture_output=True, text=True, check=check)
+
+
+def probe_local(run: dict[str, Any]) -> dict[str, Any]:
+    pid = str(run["pid"])
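+    # `ps -p` prints nothing for a PID that no longer exists, so empty stdout
+    # below is treated as "not alive".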
str(run["pid"]) + ps = run_cmd(["ps", "-p", pid, "-o", "pid=,stat=,etime=,args="], check=False) + log_path = pathlib.Path(run["log_path"]) + latest_line = "" + latest_step = None + if log_path.exists(): + latest_line, latest_step = parse_latest_line(log_path.read_text(errors="replace")) + return { + "alive": bool(ps.stdout.strip()), + "ps": ps.stdout.strip(), + "log_exists": log_path.exists(), + "latest_line": latest_line, + "latest_step": latest_step, + } + + +def remote_probe(host: str, remote_user: str, runs: list[dict[str, Any]]) -> dict[str, dict[str, Any]]: + payload = [ + { + "run_id": run["run_id"], + "pid": str(run["pid"]), + "log_path": run["log_path"], + } + for run in runs + ] + remote_py = r""" +import json +import pathlib +import re +import subprocess +import sys + +payload = json.loads(sys.argv[1]) +step_pat = re.compile(r"步骤\s+(\d+)/(\d+)") +bar_pat = re.compile(r"\|\s*(\d+)/(\d+)") + +def normalize_chunks(text): + for part in re.split(r"[\r\n]+", text): + part = part.strip() + if part: + yield part + +def parse_latest_line(text): + latest_line = "" + latest_step = None + for line in normalize_chunks(text): + if "步骤" not in line and "训练中:" not in line: + continue + latest_line = line + match = step_pat.search(line) or bar_pat.search(line) + if match: + latest_step = int(match.group(1)) + return latest_line, latest_step + +out = {} +for item in payload: + try: + ps = subprocess.run( + ["ps", "-p", item["pid"], "-o", "pid=,stat=,etime=,args="], + capture_output=True, + text=True, + check=False, + ) + log_path = pathlib.Path(item["log_path"]) + latest_line = "" + latest_step = None + if log_path.exists(): + latest_line, latest_step = parse_latest_line(log_path.read_text(errors="replace")) + out[item["run_id"]] = { + "alive": bool(ps.stdout.strip()), + "ps": ps.stdout.strip(), + "log_exists": log_path.exists(), + "latest_line": latest_line, + "latest_step": latest_step, + } + except Exception as exc: + out[item["run_id"]] = { + "alive": False, + "ps": "", + "log_exists": False, + "latest_line": "", + "latest_step": None, + "error": str(exc), + } +print(json.dumps(out, ensure_ascii=False)) +""" + remote_target = host if "@" in host else f"{remote_user}@{host}" + remote_cmd = ( + f"python3 -c {shlex.quote(remote_py)} " + f"{shlex.quote(json.dumps(payload, ensure_ascii=False))}" + ) + try: + res = run_cmd( + [ + "ssh", + "-F", + "/dev/null", + "-o", + "BatchMode=yes", + "-o", + "StrictHostKeyChecking=accept-new", + remote_target, + remote_cmd, + ] + ) + return json.loads(res.stdout) + except subprocess.CalledProcessError as exc: + error = (exc.stderr or exc.stdout or str(exc)).strip() + return { + run["run_id"]: { + "alive": False, + "ps": "", + "log_exists": False, + "latest_line": "", + "latest_step": None, + "error": f"ssh_failed: {error}", + } + for run in runs + } + + +def append_notes(notes_path: pathlib.Path, snapshot_at: str, runs: list[dict[str, Any]]) -> None: + lines = [f"\n## Status snapshot {snapshot_at}"] + for run in runs: + lines.append( + ( + f"- {run['run_id']}: host={run['host']} gpu={run['gpu']} " + f"alive={run.get('alive', False)} step={run.get('latest_step')} " + f"pid={run['pid']}" + ) + ) + if run.get("latest_line"): + lines.append(f" - latest_line: `{run['latest_line']}`") + if run.get("error"): + lines.append(f" - error: `{run['error']}`") + with notes_path.open("a", encoding="utf-8") as f: + f.write("\n".join(lines) + "\n") + + +def main() -> int: + parser = argparse.ArgumentParser() + parser.add_argument("suite_dir", type=pathlib.Path) + 
parser.add_argument("--remote-user", default="droid") + parser.add_argument("--append-notes", action="store_true") + args = parser.parse_args() + + suite_dir = args.suite_dir.resolve() + status_path = suite_dir / "status.json" + notes_path = suite_dir / "notes.md" + monitor_dir = suite_dir / "monitor_logs" + monitor_dir.mkdir(parents=True, exist_ok=True) + + status = json.loads(status_path.read_text(encoding="utf-8")) + runs: list[dict[str, Any]] = status["runs"] + snapshot_at = now_iso() + + by_host: dict[str, list[dict[str, Any]]] = defaultdict(list) + for run in runs: + by_host[run["host"]].append(run) + + results: dict[str, dict[str, Any]] = {} + for host, host_runs in by_host.items(): + if host == "local": + for run in host_runs: + results[run["run_id"]] = probe_local(run) + else: + results.update(remote_probe(host, args.remote_user, host_runs)) + + alive_count = 0 + for run in runs: + result = results[run["run_id"]] + run["alive"] = result["alive"] + run["ps"] = result["ps"] + run["log_exists"] = result["log_exists"] + run["latest_line"] = result["latest_line"] + run["latest_step"] = result["latest_step"] + run["last_verified_at"] = snapshot_at + if "error" in result: + run["error"] = result["error"] + else: + run.pop("error", None) + run["status"] = "running" if result["alive"] else "stopped" + alive_count += int(result["alive"]) + + status["last_verified_at"] = snapshot_at + status["alive_count"] = alive_count + status["total_runs"] = len(runs) + + status_path.write_text(json.dumps(status, ensure_ascii=False, indent=2) + "\n", encoding="utf-8") + + snapshot_payload = { + "suite_name": status.get("suite_name"), + "snapshot_at": snapshot_at, + "alive_count": alive_count, + "total_runs": len(runs), + "runs": {run["run_id"]: results[run["run_id"]] for run in runs}, + } + timestamp_slug = snapshot_at.replace(":", "").replace("+", "_").replace("-", "") + snapshot_path = monitor_dir / f"status-{timestamp_slug}.json" + snapshot_path.write_text( + json.dumps(snapshot_payload, ensure_ascii=False, indent=2) + "\n", + encoding="utf-8", + ) + + if args.append_notes: + append_notes(notes_path, snapshot_at, runs) + + print(json.dumps(snapshot_payload, ensure_ascii=False, indent=2)) + print(f"\nstatus_json={status_path}") + print(f"snapshot_json={snapshot_path}") + if args.append_notes: + print(f"notes_md={notes_path}") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/roboimi/vla/agent.py b/roboimi/vla/agent.py index 3578ac3..77a7cde 100644 --- a/roboimi/vla/agent.py +++ b/roboimi/vla/agent.py @@ -28,6 +28,7 @@ class VLAAgent(nn.Module): num_action_steps=8, # 每次推理实际执行多少步动作 head_type='unet', # Policy head类型: 'unet' 或 'transformer' cond_projector=None, # 可选:将视觉+状态条件投影到head期望维度 + extra_condition_tokens: int = 0, # 可选:额外条件token数量(例如未来预测embedding) ): super().__init__() # 保存参数 @@ -39,6 +40,9 @@ class VLAAgent(nn.Module): self.num_action_steps = num_action_steps self.inference_steps = inference_steps self.head_type = head_type # 'unet' 或 'transformer' + self.extra_condition_tokens = int(extra_condition_tokens) + if self.extra_condition_tokens < 0: + raise ValueError(f"extra_condition_tokens must be >= 0, got {self.extra_condition_tokens}") agent_camera_names = tuple(camera_names) if camera_names is not None else None backbone_camera_names = getattr(vision_backbone, 'camera_names', None) backbone_camera_names = tuple(backbone_camera_names) if backbone_camera_names is not None else None @@ -71,11 +75,14 @@ class VLAAgent(nn.Module): stats=dataset_stats, 
+        self.dataset_stats = dataset_stats
 
         self.vision_encoder = vision_backbone
+        self.state_encoder = state_encoder
         if self.camera_names is not None:
             self.vision_encoder.camera_names = self.camera_names
         self.condition_tokens_per_step = int(getattr(self.vision_encoder, 'tokens_per_step', 1))
+        self.state_feature_dim = int(getattr(self.state_encoder, 'output_dim', obs_dim))
         joint_vision_dim = getattr(self.vision_encoder, 'joint_output_dim', None)
         if joint_vision_dim is not None:
             per_token_vision_dim = int(joint_vision_dim)
@@ -87,8 +94,11 @@ class VLAAgent(nn.Module):
         else:
             per_token_vision_dim = int(single_cam_feat_dim) * int(num_cams)
 
-        self.condition_sequence_length = self.obs_horizon * self.condition_tokens_per_step
-        self.raw_per_step_cond_dim = per_token_vision_dim + obs_dim
+        self.history_condition_sequence_length = self.obs_horizon * self.condition_tokens_per_step
+        self.condition_sequence_length = (
+            self.history_condition_sequence_length + self.extra_condition_tokens
+        )
+        self.raw_per_step_cond_dim = per_token_vision_dim + self.state_feature_dim
         if cond_projector is None:
             self.cond_projector = None
             self.per_step_cond_dim = self.raw_per_step_cond_dim
@@ -139,7 +149,6 @@ class VLAAgent(nn.Module):
             global_cond_dim=self.global_cond_dim
         )
 
-        self.state_encoder = state_encoder
         self.action_encoder = action_encoder
 
         # 初始化队列(用于在线推理)
@@ -220,7 +229,7 @@ class VLAAgent(nn.Module):
                 f"条件维度不匹配: got {cond.shape[-1]}, expected {self.per_step_cond_dim}"
             )
         cond = cond.reshape(batch_size, obs_steps * token_count, self.per_step_cond_dim)
-        expected_length = self.condition_sequence_length
+        expected_length = self.history_condition_sequence_length
         if cond.shape[1] != expected_length:
             raise RuntimeError(
                 f"条件序列长度不匹配: got {cond.shape[1]}, expected {expected_length}"
diff --git a/roboimi/vla/agent_imf.py b/roboimi/vla/agent_imf.py
index 6dfc307..557de7d 100644
--- a/roboimi/vla/agent_imf.py
+++ b/roboimi/vla/agent_imf.py
@@ -1,9 +1,12 @@
 from __future__ import annotations
 
 from contextlib import nullcontext
-from typing import Dict, Optional
+from collections import deque
+from pathlib import Path
+from typing import Any, Dict, Mapping, Optional, Sequence
 
 import torch
+import torch.nn as nn
 import torch.nn.functional as F
 
 from roboimi.vla.agent import VLAAgent
@@ -15,14 +18,59 @@ except ImportError:  # pragma: no cover
 
 
 class IMFVLAAgent(VLAAgent):
-    def __init__(self, *args, inference_steps: int = 1, **kwargs):
+    def __init__(
+        self,
+        *args,
+        inference_steps: int = 1,
+        lewm_history_horizon: Optional[int] = None,
+        lewm_query_offsets: Optional[Sequence[int]] = None,
+        lewm_predictor: Optional[nn.Module] = None,
+        lewm_pred_projector: Optional[nn.Module] = None,
+        lewm_sigreg: Optional[nn.Module] = None,
+        lewm_sigreg_weight: float = 0.09,
+        lewm_loss_weight: float = 0.0,
+        lewm_pretrained_ckpt: Optional[str | Path | Mapping[str, Any]] = None,
+        **kwargs,
+    ):
         if inference_steps != 1:
             raise ValueError(
                 'IMFVLAAgent only supports one-step inference; '
                 f'inference_steps must be 1, got {inference_steps}.'
             )
+        lewm_query_offsets = tuple(int(offset) for offset in (lewm_query_offsets or ()))
+        inferred_extra_condition_tokens = len(lewm_query_offsets) if lewm_query_offsets else 0
+        kwargs.setdefault('extra_condition_tokens', inferred_extra_condition_tokens)
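+        # Seeded via __dict__ so the attributes exist while super().__init__() runs;
+        # they are assigned again (registering any modules) right after it returns.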
+        self.__dict__['lewm_history_horizon'] = int(lewm_history_horizon or kwargs.get('obs_horizon', 1))
+        self.__dict__['lewm_query_offsets'] = lewm_query_offsets
+        self.__dict__['lewm_predictor'] = lewm_predictor
+        self.__dict__['lewm_pred_projector'] = lewm_pred_projector or nn.Identity()
+        self.__dict__['lewm_sigreg'] = lewm_sigreg
+        self.__dict__['lewm_sigreg_weight'] = float(lewm_sigreg_weight)
+        self.__dict__['lewm_loss_weight'] = float(lewm_loss_weight)
+        self.__dict__['_last_loss_breakdown'] = {
+            'action_loss': 0.0,
+            'lewm_pred_loss': 0.0,
+            'lewm_sigreg_loss': 0.0,
+            'lewm_loss': 0.0,
+            'loss': 0.0,
+        }
         super().__init__(*args, inference_steps=inference_steps, **kwargs)
         self.inference_steps = 1
+        self.lewm_history_horizon = int(lewm_history_horizon or self.obs_horizon)
+        self.lewm_predictor = lewm_predictor
+        self.lewm_pred_projector = lewm_pred_projector or nn.Identity()
+        self.lewm_sigreg = lewm_sigreg
+        self.lewm_sigreg_weight = float(lewm_sigreg_weight)
+        if self.lewm_predictor is None and self.extra_condition_tokens > 0:
+            raise ValueError(
+                'extra_condition_tokens > 0 requires lewm_predictor to be provided'
+            )
+        if self.lewm_predictor is not None and self.extra_condition_tokens != inferred_extra_condition_tokens:
+            raise ValueError(
+                'extra_condition_tokens must equal len(lewm_query_offsets) when lewm_predictor is enabled'
+            )
+        if lewm_pretrained_ckpt is not None:
+            self.load_lewm_pretrained_components(lewm_pretrained_ckpt)
 
     @staticmethod
     def _broadcast_batch_time(value: torch.Tensor, reference: torch.Tensor) -> torch.Tensor:
@@ -119,14 +167,181 @@ class IMFVLAAgent(VLAAgent):
         delta = self._broadcast_batch_time(t - r, z_t)
         return z_t - delta * u
 
+    def _normalize_qpos_for_lewm(self, qpos: torch.Tensor) -> torch.Tensor:
+        if not self.normalization.enabled:
+            return qpos
+
+        qpos_mean = getattr(self.normalization, 'qpos_mean', None)
+        qpos_std = getattr(self.normalization, 'qpos_std', None)
+        if qpos_mean is not None and qpos_std is not None:
+            return (qpos - qpos_mean) / qpos_std
+        if isinstance(self.dataset_stats, dict):
+            mean = self.dataset_stats.get('qpos_mean', None)
+            std = self.dataset_stats.get('qpos_std', None)
+            if mean is not None and std is not None:
+                mean = torch.as_tensor(mean, dtype=qpos.dtype, device=qpos.device)
+                std = torch.as_tensor(std, dtype=qpos.dtype, device=qpos.device)
+                return (qpos - mean) / std
+        return self.normalization.normalize_qpos(qpos)
+
+    def _project_lewm_future_tokens(self, predicted_tokens: torch.Tensor) -> torch.Tensor:
+        if predicted_tokens.ndim != 3:
+            raise ValueError(
+                f"expected predicted future tokens to be 3D, got rank {predicted_tokens.ndim}"
+            )
+        batch_size, token_count, token_dim = predicted_tokens.shape
+        flattened = predicted_tokens.reshape(batch_size * token_count, token_dim)
+        projected = self.lewm_pred_projector(flattened)
+        if projected.ndim != 2:
+            raise ValueError(
+                f"expected lewm_pred_projector to return rank-2 tensors, got rank {projected.ndim}"
+            )
+        return projected.reshape(batch_size, token_count, projected.shape[-1])
+
+    @staticmethod
+    def _load_checkpoint_payload(
+        checkpoint_or_path: str | Path | Mapping[str, Any],
+    ) -> Mapping[str, torch.Tensor]:
+        if isinstance(checkpoint_or_path, (str, Path)):
+            payload = torch.load(Path(checkpoint_or_path), map_location='cpu', weights_only=False)
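+            # weights_only=False because LeWM checkpoints carry a full pickled
+            # payload; only load checkpoints from trusted sources.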
+        else:
+            payload = checkpoint_or_path
+        state_dict = payload.get('state_dict', payload)
+        if not isinstance(state_dict, Mapping):
+            raise TypeError('checkpoint payload must contain a mapping state_dict')
+        return state_dict
+
+    @staticmethod
+    def _extract_prefixed_state_dict(
+        state_dict: Mapping[str, torch.Tensor],
+        prefix: str,
+    ) -> Dict[str, torch.Tensor]:
+        extracted = {
+            key[len(prefix):]: value
+            for key, value in state_dict.items()
+            if key.startswith(prefix)
+        }
+        if not extracted:
+            raise KeyError(f"checkpoint missing parameters with prefix {prefix!r}")
+        return extracted
+
+    @staticmethod
+    def _adapt_and_load_state_dict(
+        module: nn.Module,
+        incoming_state_dict: Mapping[str, torch.Tensor],
+        *,
+        query_key: str = 'query_tokens',
+        pos_key: str = 'pos_embedding',
+    ) -> None:
+        current_state_dict = module.state_dict()
+        adapted_state_dict = dict(current_state_dict)
+        for key, current_tensor in current_state_dict.items():
+            if key not in incoming_state_dict:
+                continue
+            source_tensor = incoming_state_dict[key]
+            if source_tensor.shape == current_tensor.shape:
+                adapted_state_dict[key] = source_tensor
+                continue
+
+            if key in {query_key, pos_key} and source_tensor.ndim == current_tensor.ndim:
+                patched = current_tensor.clone()
+                copy_count = min(source_tensor.shape[1], current_tensor.shape[1])
+                patched[:, :copy_count, ...] = source_tensor[:, :copy_count, ...]
+                if copy_count < current_tensor.shape[1] and copy_count > 0:
+                    patched[:, copy_count:, ...] = source_tensor[:, copy_count - 1:copy_count, ...]
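+                # Slots beyond the checkpoint's token count are filled by repeating
+                # its last slot, so enlarged query/pos tables start from pretrained values.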
+                adapted_state_dict[key] = patched
+
+        module.load_state_dict(adapted_state_dict, strict=True)
+
+    def load_lewm_pretrained_components(
+        self,
+        checkpoint_or_path: str | Path | Mapping[str, Any],
+    ) -> None:
+        state_dict = self._load_checkpoint_payload(checkpoint_or_path)
+
+        if hasattr(self.vision_encoder, 'load_lewm_checkpoint'):
+            self.vision_encoder.load_lewm_checkpoint({'state_dict': state_dict})
+        else:
+            vision_state_dict = self._extract_prefixed_state_dict(state_dict, 'model.encoder.')
+            self.vision_encoder.load_state_dict(vision_state_dict, strict=True)
+
+        state_encoder_state_dict = self._extract_prefixed_state_dict(state_dict, 'model.state_encoder.')
+        self.state_encoder.load_state_dict(state_encoder_state_dict, strict=True)
+
+        projector_state_dict = self._extract_prefixed_state_dict(state_dict, 'model.projector.proj.')
+        mapped_projector_state_dict = {
+            f'linear.{key}': value
+            for key, value in projector_state_dict.items()
+        }
+        self.cond_projector.load_state_dict(mapped_projector_state_dict, strict=True)
+
+        if self.lewm_predictor is not None:
+            predictor_state_dict = self._extract_prefixed_state_dict(state_dict, 'model.predictor.')
+            self._adapt_and_load_state_dict(self.lewm_predictor, predictor_state_dict)
+
+        if self.lewm_pred_projector is not None:
+            pred_projector_state_dict = self._extract_prefixed_state_dict(state_dict, 'model.pred_proj.')
+            self.lewm_pred_projector.load_state_dict(pred_projector_state_dict, strict=True)
+
+    def _build_full_condition(
+        self,
+        images,
+        proprioception,
+        *,
+        lewm_images=None,
+        lewm_proprioception=None,
+    ):
+        normalized_proprioception = self.normalization.normalize_qpos(proprioception)
+        history_cond = self._build_cond(images, normalized_proprioception)
+        predicted_future_tokens = None
+        lewm_history_cond = None
+
+        if self.lewm_predictor is None:
+            return history_cond, predicted_future_tokens, lewm_history_cond
+
+        lewm_images = lewm_images if lewm_images is not None else images
+        lewm_proprioception = (
+            lewm_proprioception if lewm_proprioception is not None else proprioception
+        )
+        lewm_history_cond = self._build_cond(
+            lewm_images,
+            self._normalize_qpos_for_lewm(lewm_proprioception),
+        )
+        predicted_future_tokens = self.lewm_predictor(lewm_history_cond)
+        predicted_future_tokens = self._project_lewm_future_tokens(predicted_future_tokens)
+        cond = torch.cat([history_cond, predicted_future_tokens], dim=1)
+        if cond.shape[1] != self.condition_sequence_length:
+            raise RuntimeError(
+                f"full condition sequence length mismatch: got {cond.shape[1]}, expected {self.condition_sequence_length}"
+            )
+        if cond.shape[-1] != self.per_step_cond_dim:
+            raise RuntimeError(
+                f"full condition dim mismatch: got {cond.shape[-1]}, expected {self.per_step_cond_dim}"
+            )
+        return cond, predicted_future_tokens, lewm_history_cond
+
+    @staticmethod
+    def _masked_mse_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+        # NOTE: plain MSE for now; no padding mask is applied to the future tokens.
+        return F.mse_loss(pred, target)
+
     def compute_loss(self, batch):
         actions, states, images = batch['action'], batch['qpos'], batch['images']
         action_is_pad = batch.get('action_is_pad', None)
         batch_size = actions.shape[0]
-        states = self.normalization.normalize_qpos(states)
         actions = self.normalization.normalize_action(actions)
-        cond = self._build_cond(images, states)
+        cond, predicted_future_tokens, lewm_history_cond = self._build_full_condition(
+            images,
+            states,
+            lewm_images=batch.get('lewm_images', None),
+            lewm_proprioception=batch.get('lewm_qpos', None),
+        )
 
         x = actions
         e = torch.randn_like(x)
@@ -146,16 +361,103 @@ class IMFVLAAgent(VLAAgent):
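+        # Frames flagged in action_is_pad are padding past the episode end and are
+        # excluded from the action regression loss below.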
         if action_is_pad is not None:
             mask = (~action_is_pad).unsqueeze(-1).to(loss.dtype)
             valid_count = mask.sum() * loss.shape[-1]
-            loss = (loss * mask).sum() / valid_count.clamp_min(1.0)
+            action_loss = (loss * mask).sum() / valid_count.clamp_min(1.0)
         else:
-            loss = loss.mean()
-        return loss
+            action_loss = loss.mean()
+
+        lewm_pred_loss = torch.zeros((), device=action_loss.device, dtype=action_loss.dtype)
+        lewm_sigreg_loss = torch.zeros((), device=action_loss.device, dtype=action_loss.dtype)
+        if predicted_future_tokens is not None:
+            lewm_future_images = batch.get('lewm_future_images', None)
+            lewm_future_qpos = batch.get('lewm_future_qpos', None)
+            if lewm_future_images is not None and lewm_future_qpos is not None:
+                future_target = self._build_cond(
+                    lewm_future_images,
+                    self._normalize_qpos_for_lewm(lewm_future_qpos),
+                )
+                lewm_pred_loss = self._masked_mse_loss(predicted_future_tokens, future_target)
+            if self.lewm_sigreg is not None and lewm_history_cond is not None:
+                lewm_sigreg_loss = self.lewm_sigreg(lewm_history_cond.transpose(0, 1))
+
+        lewm_loss = lewm_pred_loss + self.lewm_sigreg_weight * lewm_sigreg_loss
+        total_loss = action_loss + self.lewm_loss_weight * lewm_loss
+        self._last_loss_breakdown = {
+            'action_loss': float(action_loss.detach().item()),
+            'lewm_pred_loss': float(lewm_pred_loss.detach().item()),
+            'lewm_sigreg_loss': float(lewm_sigreg_loss.detach().item()),
+            'lewm_loss': float(lewm_loss.detach().item()),
+            'loss': float(total_loss.detach().item()),
+        }
+        return total_loss
+
+    def get_last_loss_breakdown(self) -> Dict[str, float]:
+        return dict(self._last_loss_breakdown)
+
+    def reset(self):
+        super().reset()
+        if self.lewm_predictor is not None:
+            self._queues['lewm_qpos'] = deque(maxlen=self.lewm_history_horizon)
+            self._queues['lewm_images'] = deque(maxlen=self.lewm_history_horizon)
+
+    def _populate_queues(self, observation: Dict[str, torch.Tensor]) -> None:
+        super()._populate_queues(observation)
+        if self.lewm_predictor is None:
+            return
+        if 'qpos' in observation:
+            self._queues['lewm_qpos'].append(observation['qpos'].clone())
+        if 'images' in observation:
+            ordered_images = self._order_images(observation['images'])
+            self._queues['lewm_images'].append({k: v.clone() for k, v in ordered_images.items()})
+
+    def _prepare_observation_batch(self) -> Dict[str, torch.Tensor]:
+        batch = super()._prepare_observation_batch()
+        if self.lewm_predictor is None:
+            return batch
+
+        qpos_list = list(self._queues['lewm_qpos'])
+        images_list = list(self._queues['lewm_images'])
+        if len(qpos_list) == 0 or len(images_list) == 0:
+            raise ValueError("LeWM observation queues are empty; call _populate_queues with an observation first")
+        while len(qpos_list) < self.lewm_history_horizon:
+            qpos_list.append(qpos_list[-1])
+        while len(images_list) < self.lewm_history_horizon:
+            images_list.append(images_list[-1])
+
+        batch['lewm_qpos'] = torch.stack(qpos_list, dim=0).unsqueeze(0)
+        batch['lewm_images'] = {}
+        camera_names = self.camera_names if self.camera_names is not None else tuple(sorted(images_list[0].keys()))
+        for cam_name in camera_names:
+            batch['lewm_images'][cam_name] = torch.stack(
+                [img[cam_name] for img in images_list],
+                dim=0,
+            ).unsqueeze(0)
+        return batch
 
     @torch.no_grad()
-    def predict_action(self, images, proprioception):
+    def predict_action_chunk(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
+        return self.predict_action(
+            batch['images'],
+            batch['qpos'],
+            lewm_images=batch.get('lewm_images', None),
+            lewm_proprioception=batch.get('lewm_qpos', None),
+        )
+
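+    # Inference sketch (hypothetical shapes; see the unit tests for concrete values):
+    #   images = {camera_name: (B, obs_horizon, C, H, W) tensor, ...}
+    #   qpos   = (B, obs_horizon, obs_dim) tensor
+    #   agent.predict_action(images, qpos)  # LeWM inputs default to the history inputs
+    #   agent.predict_action(images, qpos, lewm_images=..., lewm_proprioception=...)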
+    @torch.no_grad()
+    def predict_action(
+        self,
+        images,
+        proprioception,
+        *,
+        lewm_images=None,
+        lewm_proprioception=None,
+    ):
         batch_size = proprioception.shape[0]
-        proprioception = self.normalization.normalize_qpos(proprioception)
-        cond = self._build_cond(images, proprioception)
+        cond, _predicted_future_tokens, _lewm_history_cond = self._build_full_condition(
+            images,
+            proprioception,
+            lewm_images=lewm_images,
+            lewm_proprioception=lewm_proprioception,
+        )
         z_t = torch.randn((batch_size, self.pred_horizon, self.action_dim), device=cond.device, dtype=cond.dtype)
         action = self._sample_one_step(z_t, cond=cond)
         return self.normalization.denormalize_action(action)
diff --git a/roboimi/vla/conf/agent/lewm_resnet_query_imf_attnres.yaml b/roboimi/vla/conf/agent/lewm_resnet_query_imf_attnres.yaml
new file mode 100644
index 0000000..d436719
--- /dev/null
+++ b/roboimi/vla/conf/agent/lewm_resnet_query_imf_attnres.yaml
@@ -0,0 +1,77 @@
+# @package agent
+defaults:
+  - /backbone@vision_backbone: lewm_resnet_query_fusion
+  - /modules@state_encoder: lewm_state_encoder
+  - /modules@action_encoder: identity_action_encoder
+  - /modules@cond_projector: linear_condition_projector
+  - /head: imf_transformer1d
+  - _self_
+
+_target_: roboimi.vla.agent_imf.IMFVLAAgent
+
+action_dim: 16
+obs_dim: 16
+normalization_type: "min_max"
+pred_horizon: 8
+obs_horizon: 2
+num_action_steps: 8
+camera_names: ${data.camera_names}
+num_cams: 3
+
+vision_backbone:
+  camera_names: ${agent.camera_names}
+  num_views: ${agent.num_cams}
+
+cond_projector:
+  output_dim: 288
+
+lewm_history_horizon: 3
+lewm_query_offsets: [8]
+extra_condition_tokens: ${len:${agent.lewm_query_offsets}}
+lewm_loss_weight: 1.0
+lewm_sigreg_weight: 0.09
+lewm_pretrained_ckpt: null
+
+lewm_sigreg:
+  _target_: roboimi.vla.models.backbones.lewm_resnet_query_fusion.SIGReg
+  knots: 17
+  num_proj: 1024
+
+lewm_predictor:
+  _target_: roboimi.vla.models.backbones.lewm_resnet_query_fusion.QueryTokenPredictor
+  num_frames: ${agent.lewm_history_horizon}
+  query_offsets: ${agent.lewm_query_offsets}
+  input_dim: ${agent.cond_projector.output_dim}
+  hidden_dim: ${agent.cond_projector.output_dim}
+  output_dim: ${agent.cond_projector.output_dim}
+  depth: 6
+  heads: 16
+  mlp_dim: 2048
+  dim_head: 64
+  dropout: 0.1
+  emb_dropout: 0.0
+
+lewm_pred_projector:
+  _target_: roboimi.vla.models.backbones.lewm_resnet_query_fusion.LeWMProjectorMLP
+  input_dim: ${agent.cond_projector.output_dim}
+  hidden_dim: 2048
+  output_dim: ${agent.cond_projector.output_dim}
+
+diffusion_steps: 100
+inference_steps: 1
+head_type: "transformer"
+
+head:
+  input_dim: ${agent.action_dim}
+  output_dim: ${agent.action_dim}
+  horizon: ${agent.pred_horizon}
+  n_obs_steps: ${agent.obs_horizon}
+  cond_dim: 288
+  n_emb: 384
+  causal_attn: false
+  time_as_cond: true
+  obs_as_cond: true
+  n_cond_layers: 0
+  backbone_type: attnres_full
+  n_head: 1
+  n_kv_head: 1
diff --git a/roboimi/vla/conf/backbone/lewm_resnet_query_fusion.yaml b/roboimi/vla/conf/backbone/lewm_resnet_query_fusion.yaml
new file mode 100644
index 0000000..481ca23
--- /dev/null
+++ b/roboimi/vla/conf/backbone/lewm_resnet_query_fusion.yaml
@@ -0,0 +1,7 @@
+_target_: roboimi.vla.models.backbones.lewm_resnet_query_fusion.LeWMMultiViewResNetBackbone
+
+view_feature_dim: 96
+num_views: ${agent.num_cams}
+view_encoder_mode: separate
+camera_names: ${agent.camera_names}
+checkpoint_path: null
diff --git a/roboimi/vla/conf/modules/lewm_state_encoder.yaml b/roboimi/vla/conf/modules/lewm_state_encoder.yaml
new file mode 100644
index 0000000..8e8fd01
--- /dev/null
+++ b/roboimi/vla/conf/modules/lewm_state_encoder.yaml
@@ -0,0 +1,5 @@
+_target_: roboimi.vla.modules.encoders.LeWMStateEncoder
+
+input_dim: ${agent.obs_dim}
+hidden_dim: 256
+output_dim: 64
diff --git a/roboimi/vla/data/simpe_robot_dataset.py b/roboimi/vla/data/simpe_robot_dataset.py
index c1c5a93..94156fc 100644
--- a/roboimi/vla/data/simpe_robot_dataset.py
+++ b/roboimi/vla/data/simpe_robot_dataset.py
@@ -24,6 +24,8 @@ class SimpleRobotDataset(Dataset):
         camera_names: List[str] = None,
         image_resize_shape: Optional[Sequence[int]] = (224, 224),
         max_open_files: int = 64,
+        lewm_history_horizon: Optional[int] = None,
+        lewm_query_offsets: Optional[Sequence[int]] = None,
     ):
         """
         Args:
@@ -42,6 +44,13 @@
         self.obs_horizon = obs_horizon
         self.pred_horizon = pred_horizon
         self.camera_names = camera_names or []
+        self.lewm_history_horizon = (
+            int(lewm_history_horizon) if lewm_history_horizon is not None else None
+        )
+        self.lewm_query_offsets = (
+            tuple(int(offset) for offset in lewm_query_offsets)
+            if lewm_query_offsets is not None else ()
+        )
         self.image_resize_shape = (
             tuple(int(v) for v in image_resize_shape)
             if image_resize_shape is not None else None
@@ -220,6 +229,60 @@ class SimpleRobotDataset(Dataset):
         for cam_name in self.camera_names:
             result[f"observation.{cam_name}"] = torch.stack(observations[f"observation.{cam_name}"])
 
+        if self.lewm_history_horizon is not None and self.lewm_history_horizon > 0:
+            lewm_observations = {
+                "state": [],
+            }
+            for cam_name in self.camera_names:
+                lewm_observations[f"observation.{cam_name}"] = []
+
+            for delta in range(-self.lewm_history_horizon + 1, 1):
+                target_idx = idx + delta
+                if ep_start <= target_idx <= ep_end:
+                    target_frame = self._load_frame(target_idx)
+                else:
+                    boundary_idx = ep_start if target_idx < ep_start else ep_end
+                    target_frame = self._load_frame(boundary_idx)
+
+                lewm_observations["state"].append(target_frame["observation.state"])
+                for cam_name in self.camera_names:
+                    lewm_observations[f"observation.{cam_name}"].append(
+                        target_frame[f"observation.{cam_name}"]
+                    )
+
+            result["lewm.observation.state"] = torch.stack(lewm_observations["state"])
+            for cam_name in self.camera_names:
+                result[f"lewm.observation.{cam_name}"] = torch.stack(
+                    lewm_observations[f"observation.{cam_name}"]
+                )
+
+        if self.lewm_query_offsets:
+            lewm_future = {
+                "state": [],
+            }
+            for cam_name in self.camera_names:
+                lewm_future[f"observation.{cam_name}"] = []
+
+            for offset in self.lewm_query_offsets:
+                target_idx = idx + offset
+                if ep_start <= target_idx <= ep_end:
+                    target_frame = self._load_frame(target_idx)
+                else:
+                    boundary_idx = ep_start if target_idx < ep_start else ep_end
+                    target_frame = self._load_frame(boundary_idx)
+
+                lewm_future["state"].append(target_frame["observation.state"])
+                for cam_name in self.camera_names:
+                    lewm_future[f"observation.{cam_name}"].append(
+                        target_frame[f"observation.{cam_name}"]
+                    )
+
+            result["lewm.future.state"] = torch.stack(lewm_future["state"])
+            for cam_name in self.camera_names:
+                result[f"lewm.future.{cam_name}"] = torch.stack(
+                    lewm_future[f"observation.{cam_name}"]
+                )
+
         return result
 
     @property
diff --git a/roboimi/vla/models/backbones/__init__.py b/roboimi/vla/models/backbones/__init__.py
index c5544b5..2df41ec 100644
--- a/roboimi/vla/models/backbones/__init__.py
+++ b/roboimi/vla/models/backbones/__init__.py
@@ -1,5 +1,14 @@
 # Backbone models
-__all__ = ["LEWMViTBackbone", "ResNetBackbone", "ResNetDiffusionBackbone", "SigLIP2DiffusionBackbone"]
+__all__ = [
+    "LEWMViTBackbone",
+    "LeWMMultiViewResNetBackbone",
+    "QueryTokenPredictor",
+    "LeWMProjectorMLP",
+    "SIGReg",
+    "ResNetBackbone",
+    "ResNetDiffusionBackbone",
+    "SigLIP2DiffusionBackbone",
+]
 
 
 def __getattr__(name):
@@ -9,6 +18,19 @@ def __getattr__(name):
     if name == "SigLIP2DiffusionBackbone":
         from .siglip2_diffusion_backbone import SigLIP2DiffusionBackbone
         return SigLIP2DiffusionBackbone
+    if name in {"LeWMMultiViewResNetBackbone", "QueryTokenPredictor", "LeWMProjectorMLP", "SIGReg"}:
+        from .lewm_resnet_query_fusion import (
+            LeWMMultiViewResNetBackbone,
+            QueryTokenPredictor,
+            LeWMProjectorMLP,
+            SIGReg,
+        )
+        return {
+            "LeWMMultiViewResNetBackbone": LeWMMultiViewResNetBackbone,
+            "QueryTokenPredictor": QueryTokenPredictor,
+            "LeWMProjectorMLP": LeWMProjectorMLP,
+            "SIGReg": SIGReg,
+        }[name]
     if name in {"ResNetBackbone", "ResNetDiffusionBackbone"}:
         from .resnet_diffusion import ResNetDiffusionBackbone
         return ResNetDiffusionBackbone
diff --git a/roboimi/vla/models/backbones/lewm_resnet_query_fusion.py b/roboimi/vla/models/backbones/lewm_resnet_query_fusion.py
new file mode 100644
index 0000000..bd6c0ea
--- /dev/null
+++ b/roboimi/vla/models/backbones/lewm_resnet_query_fusion.py
@@ -0,0 +1,409 @@
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Any, Dict, Mapping, Optional, Sequence
+
+import torch
+from einops import rearrange
+from torch import nn
+import torch.nn.functional as F
+from torchvision import models
+
+from roboimi.vla.core.interfaces import VLABackbone
+
+
+class SpatialSoftmax2D(nn.Module):
+    """Convert a feature map into expected 2D keypoint coordinates per channel."""
+
+    def forward(self, feature_map):
+        if feature_map.ndim != 4:
+            raise ValueError(
+                f"SpatialSoftmax2D expects a 4D tensor, got rank {feature_map.ndim}"
+            )
+
+        batch, channels, height, width = feature_map.shape
+        scores = feature_map.reshape(batch, channels, height * width)
+        attention = F.softmax(scores, dim=-1)
+
+        ys = torch.linspace(-1.0, 1.0, height, device=feature_map.device, dtype=feature_map.dtype)
+        xs = torch.linspace(-1.0, 1.0, width, device=feature_map.device, dtype=feature_map.dtype)
+        grid_y, grid_x = torch.meshgrid(ys, xs, indexing="ij")
+        grid_x = grid_x.reshape(1, 1, height * width)
+        grid_y = grid_y.reshape(1, 1, height * width)
+
+        expected_x = (attention * grid_x).sum(dim=-1)
+        expected_y = (attention * grid_y).sum(dim=-1)
+        return torch.cat([expected_x, expected_y], dim=-1)
+
+
+class ResNet18SpatialEncoder(nn.Module):
+    """Encode one camera view into a fixed-dimensional spatial-softmax embedding."""
+
+    def __init__(self, view_feature_dim=96):
+        super().__init__()
+        if view_feature_dim % 2 != 0:
+            raise ValueError("view_feature_dim must be even for spatial softmax features")
+
+        backbone = models.resnet18(weights=None)
+        if all(
+            hasattr(backbone, name)
+            for name in ("conv1", "bn1", "relu", "maxpool", "layer1", "layer2", "layer3", "layer4")
+        ):
+            self.backbone = nn.Sequential(
+                backbone.conv1,
+                backbone.bn1,
+                backbone.relu,
+                backbone.maxpool,
+                backbone.layer1,
+                backbone.layer2,
+                backbone.layer3,
+                backbone.layer4,
+            )
+            feature_channels = 512
+        else:
+            children = list(backbone.children())
+            if len(children) < 1:
+                raise ValueError("resnet18 backbone must expose child modules")
+            truncated = children[:-2] if len(children) > 2 else children
+            self.backbone = nn.Sequential(*truncated)
+            with torch.no_grad():
+                dummy = torch.zeros(1, 3, 16, 16)
+                feature_channels = int(self.backbone(dummy).shape[1])
+
+        self.proj = nn.Conv2d(feature_channels, view_feature_dim // 2, kernel_size=1)
+        self.spatial_softmax = SpatialSoftmax2D()
+        self.output_dim = int(view_feature_dim)
+
+    def forward(self, pixels):
+        if pixels.ndim not in (4, 5):
+            raise ValueError(
+                f"ResNet18SpatialEncoder expects a 4D or 5D tensor, got rank {pixels.ndim}"
+            )
+
+        needs_unflatten = pixels.ndim == 5
+        if needs_unflatten:
+            batch, steps, channels, height, width = pixels.shape
+            pixels = rearrange(pixels, "b t c h w -> (b t) c h w")
+
+        features = self.backbone(pixels.float())
+        features = self.proj(features)
+        embeddings = self.spatial_softmax(features)
+
+        if needs_unflatten:
+            embeddings = rearrange(embeddings, "(b t) d -> b t d", b=batch, t=steps)
+        return embeddings
+
+
+class LeWMMultiViewResNetBackbone(VLABackbone):
+    """RoboIMI-side LeWM multiview ResNet spatial-softmax encoder."""
+
+    def __init__(
+        self,
+        view_feature_dim: int = 96,
+        num_views: int = 3,
+        view_encoder_mode: str = "shared",
+        camera_names: Sequence[str] = ("r_vis", "top", "front"),
+        checkpoint_path: str | Path | None = None,
+    ) -> None:
+        super().__init__()
+        if view_encoder_mode not in {"shared", "separate"}:
+            raise ValueError(
+                f"view_encoder_mode must be 'shared' or 'separate', got {view_encoder_mode}"
+            )
+
+        self.view_feature_dim = int(view_feature_dim)
+        self.num_views = int(num_views)
+        self.view_encoder_mode = view_encoder_mode
+        self.camera_names = tuple(camera_names)
+        if len(self.camera_names) != self.num_views:
+            raise ValueError(
+                f"camera_names length({len(self.camera_names)}) must equal num_views({self.num_views})"
+            )
+        self.output_dim = self.view_feature_dim * self.num_views
+        self.joint_output_dim = self.output_dim
+        self.tokens_per_step = 1
+
+        if view_encoder_mode == "shared":
+            self.single_view_encoder = ResNet18SpatialEncoder(
+                view_feature_dim=view_feature_dim
+            )
+            self.view_encoders = None
+        else:
+            self.single_view_encoder = None
+            self.view_encoders = nn.ModuleList(
+                [
+                    ResNet18SpatialEncoder(view_feature_dim=view_feature_dim)
+                    for _ in range(num_views)
+                ]
+            )
+
+        if checkpoint_path is not None:
+            self.load_lewm_checkpoint(checkpoint_path)
+
+    @staticmethod
+    def _unwrap_state_dict(payload: Mapping[str, Any]) -> Mapping[str, torch.Tensor]:
+        state_dict = payload.get("state_dict", payload)
+        if not isinstance(state_dict, Mapping):
+            raise TypeError("checkpoint payload must contain a mapping state_dict")
+        return state_dict
+
+    @staticmethod
+    def _extract_prefixed_state_dict(
+        state_dict: Mapping[str, torch.Tensor],
+        prefix: str,
+    ) -> Dict[str, torch.Tensor]:
+        extracted = {
+            key[len(prefix):]: value
+            for key, value in state_dict.items()
+            if key.startswith(prefix)
+        }
+        if not extracted:
+            raise KeyError(f"checkpoint missing parameters with prefix {prefix!r}")
+        return extracted
+
+    def load_lewm_checkpoint(self, checkpoint_or_path: str | Path | Mapping[str, Any]) -> None:
+        if isinstance(checkpoint_or_path, (str, Path)):
+            payload = torch.load(Path(checkpoint_or_path), map_location="cpu", weights_only=False)
+        else:
+            payload = checkpoint_or_path
+        state_dict = self._unwrap_state_dict(payload)
+        encoder_state_dict = self._extract_prefixed_state_dict(state_dict, "model.encoder.")
+        self.load_state_dict(encoder_state_dict, strict=True)
+
+    def forward(self, images):
+        missing = [camera_name for camera_name in self.camera_names if camera_name not in images]
+        if missing:
+            raise ValueError(
+                f"image input missing required cameras. missing={missing}, expected={list(self.camera_names)}"
+            )
+
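+        # images maps camera_name -> (B, T, C, H, W); per-view embeddings are
+        # concatenated to give (B, T, view_feature_dim * num_views).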
+        first_image = images[self.camera_names[0]]
+        batch_size, steps = first_image.shape[:2]
+        view_embeddings = []
+        if self.view_encoder_mode == "shared":
+            for camera_name in self.camera_names:
+                view_embeddings.append(self.single_view_encoder(images[camera_name]))
+        else:
+            for single_view_encoder, camera_name in zip(self.view_encoders, self.camera_names):
+                view_embeddings.append(single_view_encoder(images[camera_name]))
+
+        embeddings = torch.cat(view_embeddings, dim=-1)
+        return embeddings.reshape(batch_size, steps, self.output_dim)
+
+
+class FeedForward(nn.Module):
+    def __init__(self, dim, hidden_dim, dropout=0.0):
+        super().__init__()
+        self.net = nn.Sequential(
+            nn.LayerNorm(dim),
+            nn.Linear(dim, hidden_dim),
+            nn.GELU(),
+            nn.Dropout(dropout),
+            nn.Linear(hidden_dim, dim),
+            nn.Dropout(dropout),
+        )
+
+    def forward(self, x):
+        return self.net(x)
+
+
+class Attention(nn.Module):
+    def __init__(self, dim, heads=8, dim_head=64, dropout=0.0):
+        super().__init__()
+        inner_dim = dim_head * heads
+        project_out = not (heads == 1 and dim_head == dim)
+        self.heads = heads
+        self.dropout = dropout
+        self.norm = nn.LayerNorm(dim)
+        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
+        self.to_out = (
+            nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
+            if project_out
+            else nn.Identity()
+        )
+
+    def forward(self, x, causal=True):
+        x = self.norm(x)
+        drop = self.dropout if self.training else 0.0
+        qkv = self.to_qkv(x).chunk(3, dim=-1)
+        q, k, v = (rearrange(t, "b t (h d) -> b h t d", h=self.heads) for t in qkv)
+        out = F.scaled_dot_product_attention(q, k, v, dropout_p=drop, is_causal=causal)
+        out = rearrange(out, "b h t d -> b t (h d)")
+        return self.to_out(out)
+
+
+class Block(nn.Module):
+    def __init__(self, dim, heads, dim_head, mlp_dim, dropout=0.0):
+        super().__init__()
+        self.attn = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
+        self.mlp = FeedForward(dim, mlp_dim, dropout=dropout)
+        self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+
+    def forward(self, x):
+        x = x + self.attn(self.norm1(x))
+        x = x + self.mlp(self.norm2(x))
+        return x
+
+
+class Transformer(nn.Module):
+    def __init__(
+        self,
+        input_dim,
+        hidden_dim,
+        output_dim,
+        depth,
+        heads,
+        dim_head,
+        mlp_dim,
+        dropout=0.0,
+        block_class=Block,
+    ):
+        super().__init__()
+        self.norm = nn.LayerNorm(hidden_dim)
+        self.layers = nn.ModuleList([])
+
+        self.input_proj = (
+            nn.Linear(input_dim, hidden_dim)
+            if input_dim != hidden_dim
+            else nn.Identity()
+        )
+        self.cond_proj = (
+            nn.Linear(input_dim, hidden_dim)
+            if input_dim != hidden_dim
+            else nn.Identity()
+        )
+        self.output_proj = (
+            nn.Linear(hidden_dim, output_dim)
+            if hidden_dim != output_dim
+            else nn.Identity()
+        )
+
+        for _ in range(depth):
+            self.layers.append(block_class(hidden_dim, heads, dim_head, mlp_dim, dropout))
+
+    def forward(self, x, c=None):
+        x = self.input_proj(x)
+        if c is not None:
+            c = self.cond_proj(c)
+        for block in self.layers:
+            x = block(x)
+        x = self.norm(x)
+        return self.output_proj(x)
+
+
+class QueryTokenPredictor(nn.Module):
+    """History-only transformer predictor that decodes learned query tokens."""
+
+    def __init__(
+        self,
+        *,
+        num_frames,
+        query_offsets,
+        depth,
+        heads,
+        mlp_dim,
+        input_dim,
+        hidden_dim,
+        output_dim=None,
+        dim_head=64,
+        dropout=0.0,
+        emb_dropout=0.0,
+    ):
+        super().__init__()
+        if num_frames <= 0:
+            raise ValueError(f"num_frames must be positive, got {num_frames}")
+
+        query_offsets = tuple(query_offsets)
+        if not query_offsets:
+            raise ValueError("query_offsets must contain at least one offset")
+        if any(offset <= 0 for offset in query_offsets):
+            raise ValueError(f"query_offsets must be positive, got {query_offsets}")
+
+        self.num_frames = int(num_frames)
+        self.query_offsets = query_offsets
+        self.num_query_tokens = len(query_offsets)
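+        # One learned query token per requested future offset (e.g. query_offsets=(8,)
+        # asks the predictor for a single embedding 8 steps ahead).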
ValueError(f"num_frames must be positive, got {num_frames}") + + query_offsets = tuple(query_offsets) + if not query_offsets: + raise ValueError("query_offsets must contain at least one offset") + if any(offset <= 0 for offset in query_offsets): + raise ValueError(f"query_offsets must be positive, got {query_offsets}") + + self.num_frames = int(num_frames) + self.query_offsets = query_offsets + self.num_query_tokens = len(query_offsets) + self.pos_embedding = nn.Parameter( + torch.randn(1, self.num_frames + self.num_query_tokens, input_dim) + ) + self.query_tokens = nn.Parameter( + torch.randn(1, self.num_query_tokens, input_dim) + ) + self.dropout = nn.Dropout(emb_dropout) + self.transformer = Transformer( + input_dim, + hidden_dim, + output_dim or input_dim, + depth, + heads, + dim_head, + mlp_dim, + dropout, + block_class=Block, + ) + + def forward(self, x): + if x.ndim != 3: + raise ValueError( + f"QueryTokenPredictor expects a 3D tensor, got rank {x.ndim}" + ) + + T = x.size(1) + if T > self.num_frames: + raise ValueError( + f"input sequence length {T} exceeds configured num_frames {self.num_frames}" + ) + + query_tokens = self.query_tokens.expand(x.size(0), -1, -1) + tokens = torch.cat([x, query_tokens], dim=1) + tokens = tokens + self.pos_embedding[:, : tokens.size(1)] + tokens = self.dropout(tokens) + tokens = self.transformer(tokens) + return tokens[:, -self.num_query_tokens :] + + +class LeWMProjectorMLP(nn.Module): + def __init__( + self, + input_dim: int = 288, + hidden_dim: int = 2048, + output_dim: int = 288, + ) -> None: + super().__init__() + self.output_dim = int(output_dim) + self.net = nn.Sequential( + nn.Linear(int(input_dim), int(hidden_dim)), + nn.BatchNorm1d(int(hidden_dim)), + nn.GELU(), + nn.Linear(int(hidden_dim), self.output_dim), + ) + + def forward(self, x): + return self.net(x) + + +class SIGReg(nn.Module): + """Sketch Isotropic Gaussian Regularizer, matching the original LeWM design.""" + + def __init__(self, knots: int = 17, num_proj: int = 1024) -> None: + super().__init__() + self.num_proj = int(num_proj) + t = torch.linspace(0, 3, int(knots), dtype=torch.float32) + dt = 3 / (int(knots) - 1) + weights = torch.full((int(knots),), 2 * dt, dtype=torch.float32) + weights[[0, -1]] = dt + window = torch.exp(-t.square() / 2.0) + self.register_buffer("t", t) + self.register_buffer("phi", window) + self.register_buffer("weights", weights * window) + + def forward(self, proj: torch.Tensor) -> torch.Tensor: + """ + proj: (T, B, D) + """ + A = torch.randn(proj.size(-1), self.num_proj, device=proj.device) + A = A.div_(A.norm(p=2, dim=0)) + x_t = (proj @ A).unsqueeze(-1) * self.t + err = (x_t.cos().mean(-3) - self.phi).square() + x_t.sin().mean(-3).square() + statistic = (err @ self.weights) * proj.size(-2) + return statistic.mean() diff --git a/roboimi/vla/modules/encoders.py b/roboimi/vla/modules/encoders.py index 0fa0970..5ecb5b6 100644 --- a/roboimi/vla/modules/encoders.py +++ b/roboimi/vla/modules/encoders.py @@ -15,4 +15,24 @@ class IdentityActionEncoder(nn.Module): super().__init__() def forward(self, action): - return action \ No newline at end of file + return action + + +class LeWMStateEncoder(nn.Module): + def __init__( + self, + input_dim: int = 16, + hidden_dim: int = 256, + output_dim: int = 64, + ): + super().__init__() + self.output_dim = int(output_dim) + self.net = nn.Sequential( + nn.Linear(int(input_dim), int(hidden_dim)), + nn.LayerNorm(int(hidden_dim)), + nn.GELU(), + nn.Linear(int(hidden_dim), self.output_dim), + ) + + def forward(self, state): + 
+        return self.net(state)
diff --git a/tests/test_imf_vla_agent.py b/tests/test_imf_vla_agent.py
index dfccdff..d611009 100644
--- a/tests/test_imf_vla_agent.py
+++ b/tests/test_imf_vla_agent.py
@@ -376,6 +376,29 @@ class _ForbiddenScheduler:
         raise AssertionError('IMF inference should not use DDIM scheduler step')
 
 
+class _StubFutureTokenPredictor(nn.Module):
+    def __init__(self, num_future_tokens=1):
+        super().__init__()
+        self.num_future_tokens = int(num_future_tokens)
+        self.calls = []
+
+    def forward(self, history_tokens):
+        self.calls.append(history_tokens.detach().clone())
+        summary = history_tokens.mean(dim=1, keepdim=True)
+        return summary.repeat(1, self.num_future_tokens, 1)
+
+
+class _RecordingSigReg(nn.Module):
+    def __init__(self, value=0.5):
+        super().__init__()
+        self.value = float(value)
+        self.calls = []
+
+    def forward(self, embeddings):
+        self.calls.append(embeddings.detach().clone())
+        return embeddings.new_tensor(self.value)
+
+
 def _make_images(batch_size, obs_horizon, per_camera_fill):
     return {
         name: torch.full((batch_size, obs_horizon, 1, 2, 2), fill_value=value, dtype=torch.float32)
@@ -501,6 +524,169 @@ class IMFVLAAgentTest(unittest.TestCase):
         self.assertTrue(torch.allclose(head.calls[0]['t'], torch.ones(2)))
         self.assertTrue(torch.allclose(head.calls[0]['cond'], expected_cond))
 
+    def test_predict_action_appends_lewm_future_tokens_to_history_conditioning(self):
+        agent_cls, agent_module = _load_imf_agent_class()
+        head = _RecordingLinearIMFHead()
+        future_predictor = _StubFutureTokenPredictor(num_future_tokens=1)
+        agent = agent_cls(
+            vision_backbone=_StubVisionBackbone(),
+            state_encoder=nn.Identity(),
+            action_encoder=nn.Identity(),
+            head=head,
+            action_dim=2,
+            obs_dim=1,
+            pred_horizon=3,
+            obs_horizon=2,
+            diffusion_steps=10,
+            inference_steps=1,
+            num_cams=len(_CAMERA_NAMES),
+            camera_names=_CAMERA_NAMES,
+            num_action_steps=2,
+            head_type='transformer',
+            extra_condition_tokens=1,
+            lewm_history_horizon=3,
+            lewm_query_offsets=[8],
+            lewm_predictor=future_predictor,
+            lewm_pred_projector=nn.Identity(),
+            lewm_loss_weight=0.5,
+        )
+        agent.infer_scheduler = _ForbiddenScheduler()
+
+        images = _make_images(
+            batch_size=1,
+            obs_horizon=2,
+            per_camera_fill={'r_vis': 10.0, 'top': 20.0, 'front': 30.0},
+        )
+        qpos = torch.tensor([[[1.0], [2.0]]], dtype=torch.float32)
+        lewm_images = _make_images(
+            batch_size=1,
+            obs_horizon=3,
+            per_camera_fill={'r_vis': 10.0, 'top': 20.0, 'front': 30.0},
+        )
+        lewm_qpos = torch.tensor([[[0.5], [1.5], [2.5]]], dtype=torch.float32)
+        initial_noise = torch.tensor(
+            [[[1.0, -1.0], [0.0, 2.0], [3.0, -2.0]]],
+            dtype=torch.float32,
+        )
+
+        with mock.patch.object(agent_module.torch, 'randn', return_value=initial_noise):
+            _ = agent.predict_action(
+                images,
+                qpos,
+                lewm_images=lewm_images,
+                lewm_proprioception=lewm_qpos,
+            )
+
+        expected_history = torch.tensor(
+            [[[10.0, 20.0, 30.0, 1.0], [10.0, 20.0, 30.0, 2.0]]],
+            dtype=torch.float32,
+        )
+        expected_future = torch.tensor([[[10.0, 20.0, 30.0, 1.5]]], dtype=torch.float32)
+        expected_cond = torch.cat([expected_history, expected_future], dim=1)
+
+        self.assertEqual(agent.condition_sequence_length, 3)
+        self.assertEqual(agent.per_step_cond_dim, 4)
+        self.assertEqual(len(head.calls), 1)
+        self.assertTrue(torch.allclose(head.calls[0]['cond'], expected_cond))
+        self.assertEqual(len(future_predictor.calls), 1)
+
+    def test_compute_loss_tracks_action_and_lewm_loss_breakdown(self):
+        agent_cls, agent_module = _load_imf_agent_class()
+        head = _RecordingLinearIMFHead()
+        future_predictor = _StubFutureTokenPredictor(num_future_tokens=1)
+        sigreg = _RecordingSigReg(value=0.75)
+        agent = agent_cls(
+            vision_backbone=_StubVisionBackbone(),
+            state_encoder=nn.Identity(),
+            action_encoder=nn.Identity(),
+            head=head,
+            action_dim=2,
+            obs_dim=1,
+            pred_horizon=3,
+            obs_horizon=2,
+            diffusion_steps=10,
+            inference_steps=1,
+            num_cams=len(_CAMERA_NAMES),
+            camera_names=_CAMERA_NAMES,
+            num_action_steps=2,
+            head_type='transformer',
+            extra_condition_tokens=1,
+            lewm_history_horizon=3,
+            lewm_query_offsets=[8],
+            lewm_predictor=future_predictor,
+            lewm_pred_projector=nn.Identity(),
+            lewm_sigreg=sigreg,
+            lewm_sigreg_weight=0.09,
+            lewm_loss_weight=0.25,
+        )
+
+        images = _make_images(
+            batch_size=1,
+            obs_horizon=2,
+            per_camera_fill={'r_vis': 1.0, 'top': 2.0, 'front': 3.0},
+        )
+        qpos = torch.tensor([[[0.25], [0.75]]], dtype=torch.float32)
+        actions = torch.tensor(
+            [[[1.0, -1.0], [0.5, 0.25], [-0.5, 1.5]]],
+            dtype=torch.float32,
+        )
+        lewm_images = _make_images(
+            batch_size=1,
+            obs_horizon=3,
+            per_camera_fill={'r_vis': 1.0, 'top': 2.0, 'front': 3.0},
+        )
+        lewm_qpos = torch.tensor([[[0.1], [0.2], [0.3]]], dtype=torch.float32)
+        lewm_future_images = _make_images(
+            batch_size=1,
+            obs_horizon=1,
+            per_camera_fill={'r_vis': 1.0, 'top': 2.0, 'front': 3.0},
+        )
+        lewm_future_qpos = torch.tensor([[[0.4]]], dtype=torch.float32)
+        noise = torch.tensor(
+            [[[0.2, -0.4], [0.1, 0.3], [0.5, -0.2]]],
+            dtype=torch.float32,
+        )
+        t_sample = torch.tensor([0.8], dtype=torch.float32)
+        r_sample = torch.tensor([0.25], dtype=torch.float32)
+
+        with mock.patch.object(agent_module.torch, 'randn_like', return_value=noise), \
+                mock.patch.object(agent_module.torch, 'rand', side_effect=[t_sample, r_sample]):
+            loss = agent.compute_loss(
+                {
+                    'images': images,
+                    'qpos': qpos,
+                    'action': actions,
+                    'lewm_images': lewm_images,
+                    'lewm_qpos': lewm_qpos,
+                    'lewm_future_images': lewm_future_images,
+                    'lewm_future_qpos': lewm_future_qpos,
+                }
+            )
+
+        metrics = agent.get_last_loss_breakdown()
+        self.assertAlmostEqual(loss.item(), metrics['loss'], places=6)
+        self.assertIn('action_loss', metrics)
+        self.assertIn('lewm_pred_loss', metrics)
+        self.assertIn('lewm_sigreg_loss', metrics)
+        self.assertIn('lewm_loss', metrics)
+        self.assertAlmostEqual(metrics['lewm_sigreg_loss'], 0.75, places=6)
+        self.assertAlmostEqual(
+            metrics['lewm_loss'],
+            metrics['lewm_pred_loss'] + 0.09 * metrics['lewm_sigreg_loss'],
+            places=5,
+        )
+        self.assertAlmostEqual(
+            metrics['loss'],
+            metrics['action_loss'] + 0.25 * metrics['lewm_loss'],
+            places=5,
+        )
+        self.assertEqual(len(sigreg.calls), 1)
+        expected_lewm_history = torch.tensor(
+            [[[1.0, 2.0, 3.0, 0.1], [1.0, 2.0, 3.0, 0.2], [1.0, 2.0, 3.0, 0.3]]],
+            dtype=torch.float32,
+        )
+        torch.testing.assert_close(sigreg.calls[0], expected_lewm_history.transpose(0, 1))
+
     def test_select_action_only_regenerates_when_action_queue_is_empty(self):
         agent, _head, _agent_module = self._make_agent(pred_horizon=4, obs_horizon=2, num_action_steps=2)
         observation = {
@@ -851,6 +1037,46 @@ class IMFVLAAgentTest(unittest.TestCase):
         self.assertEqual(agent.vision_encoder.output_dim, 96)
         self.assertEqual(agent.vision_encoder.eval_image_resize_shape, (256, 256))
 
+    def test_hydra_config_instantiates_lewm_resnet_query_imf_attnres_with_future_tokens(self):
+        cfg = _compose_cfg(
+            overrides=[
+                'agent=lewm_resnet_query_imf_attnres',
+                'agent.head.n_layer=1',
+                'agent.head.n_emb=16',
+                'agent.lewm_query_offsets=[8]',
+            ]
+        )
+
+        self.assertEqual(cfg.agent._target_, 'roboimi.vla.agent_imf.IMFVLAAgent')
+        self.assertEqual(
+            cfg.agent.vision_backbone._target_,
+            'roboimi.vla.models.backbones.lewm_resnet_query_fusion.LeWMMultiViewResNetBackbone',
+        )
+        self.assertEqual(
+            cfg.agent.state_encoder._target_,
+            'roboimi.vla.modules.encoders.LeWMStateEncoder',
+        )
+        self.assertEqual(cfg.agent.head.cond_dim, 288)
+        self.assertEqual(cfg.agent.cond_projector.output_dim, 288)
+        self.assertEqual(cfg.agent.extra_condition_tokens, 1)
+        self.assertEqual(
+            cfg.agent.lewm_sigreg._target_,
+            'roboimi.vla.models.backbones.lewm_resnet_query_fusion.SIGReg',
+        )
+        self.assertAlmostEqual(cfg.agent.lewm_sigreg_weight, 0.09)
+
+        with _stub_optional_modules(include_imf_head=True):
+            agent = instantiate(cfg.agent)
+
+        self.assertEqual(agent.per_step_cond_dim, 288)
+        self.assertEqual(agent.condition_sequence_length, agent.obs_horizon + 1)
+        self.assertEqual(agent.noise_pred_net.constructor_kwargs['cond_dim'], 288)
+        self.assertEqual(
+            agent.noise_pred_net.constructor_kwargs['n_obs_steps'],
+            agent.condition_sequence_length,
+        )
+        self.assertIsNotNone(agent.lewm_sigreg)
+
     def test_hydra_config_instantiates_resnet_imf_attnres_multitoken_with_sequence_length_three_times_obs_horizon(self):
         cfg = _compose_cfg(
diff --git a/tests/test_simple_robot_dataset_image_loading.py b/tests/test_simple_robot_dataset_image_loading.py
index b305275..2dd5321 100644
--- a/tests/test_simple_robot_dataset_image_loading.py
+++ b/tests/test_simple_robot_dataset_image_loading.py
@@ -79,3 +79,24 @@ class SimpleRobotDatasetImageLoadingTest(unittest.TestCase):
             fake_cv2.resize.assert_not_called()
 
         self.assertEqual(tuple(sample["observation.front"].shape), (2, 3, 8, 8))
+
+    def test_getitem_can_emit_lewm_history_and_future_observations(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            dataset_dir = Path(tmpdir)
+            self._write_episode(dataset_dir)
+            dataset = SimpleRobotDataset(
+                dataset_dir,
+                obs_horizon=2,
+                pred_horizon=3,
+                camera_names=["front"],
+                image_resize_shape=None,
+                lewm_history_horizon=3,
+                lewm_query_offsets=[1, 2],
+            )
+
+            sample = dataset[1]
+
+            self.assertEqual(tuple(sample["lewm.observation.state"].shape), (3, 4))
+            self.assertEqual(tuple(sample["lewm.observation.front"].shape), (3, 3, 8, 8))
+            self.assertEqual(tuple(sample["lewm.future.state"].shape), (2, 4))
+            self.assertEqual(tuple(sample["lewm.future.front"].shape), (2, 3, 8, 8))
diff --git a/tests/test_train_vla_swanlab_logging.py b/tests/test_train_vla_swanlab_logging.py
index 3918e9b..a191e5e 100644
--- a/tests/test_train_vla_swanlab_logging.py
+++ b/tests/test_train_vla_swanlab_logging.py
@@ -114,6 +114,22 @@ class FakeAgent(nn.Module):
         return {}
 
 
+class RecordingAgent(FakeAgent):
+    def __init__(self):
+        super().__init__()
+        self.seen_inputs = []
+
+    def compute_loss(self, agent_input):
+        self.seen_inputs.append(agent_input)
+        return super().compute_loss(agent_input)
+
+
+class ShapeMixedFakeAgent(FakeAgent):
+    def __init__(self):
+        super().__init__()
+        self.bias = nn.Parameter(torch.zeros(2))
+
+
 class FakeSwanLab:
     def __init__(self, init_error=None, log_errors=None, finish_error=None, image_errors=None):
         self.init_error = init_error
@@ -388,6 +404,18 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):
             'action_is_pad': torch.zeros(1, 1, dtype=torch.bool),
         }
 
+    def _make_lewm_batch(self):
+        batch = self._make_batch()
+        batch.update(
+            {
+                'lewm.observation.front': torch.ones(1, 3, 2, 2),
+                'lewm.observation.state': torch.ones(1, 4),
+                'lewm.future.front': torch.full((1, 3, 2, 2), 2.0),
+                'lewm.future.state': torch.full((1, 4), 2.0),
+            }
+        )
+        return batch
+
     def _loader_factory(self):
         train_batch = self._make_batch()
         val_batch = self._make_batch()
@@ -397,6 +425,15 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):
 
         return factory
 
+    def _lewm_loader_factory(self):
+        train_batch = self._make_lewm_batch()
+        val_batch = self._make_lewm_batch()
+
+        def factory(_dataset, *, shuffle, **_kwargs):
+            return FakeLoader(train_batch if shuffle else val_batch)
+
+        return factory
+
     def test_run_training_logs_metrics_and_checkpoint_paths_to_swanlab(self):
         module = self._load_train_vla_module()
         run_training = self._get_run_training(module)
@@ -487,6 +524,43 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):
         self.assertTrue(final_payload['final/best_checkpoint_path'].endswith('checkpoints/vla_model_best.pt'))
         self.assertEqual(fake_swanlab.finish_calls, 1)
 
+    def test_run_training_passes_lewm_history_and_future_batches_into_agent_input(self):
+        module = self._load_train_vla_module()
+        run_training = self._get_run_training(module)
+        cfg = self._make_cfg(use_swanlab=False)
+        cfg.train.max_steps = 1
+        cfg.train.save_freq = 100
+        agent = RecordingAgent()
+
+        def fake_instantiate(config_node, **_kwargs):
+            if config_node is cfg.data:
+                return FakeDataset()
+            if config_node is cfg.agent:
+                return agent
+            raise AssertionError(f'unexpected instantiate config: {config_node!r}')
+
+        with tempfile.TemporaryDirectory() as tempdir:
+            previous_cwd = os.getcwd()
+            try:
+                os.chdir(tempdir)
+                with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
+                        mock.patch.object(module, 'DataLoader', side_effect=self._lewm_loader_factory()), \
+                        mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=FakeScheduler()), \
+                        mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: FakeProgressBar(iterable)), \
+                        mock.patch.object(module.torch, 'save', return_value=None):
+                    run_training(cfg)
+            finally:
+                os.chdir(previous_cwd)
+
+        self.assertGreaterEqual(len(agent.seen_inputs), 1)
+        first_input = agent.seen_inputs[0]
+        self.assertIn('lewm_images', first_input)
+        self.assertIn('lewm_qpos', first_input)
+        self.assertIn('lewm_future_images', first_input)
+        self.assertIn('lewm_future_qpos', first_input)
+        self.assertIn('front', first_input['lewm_images'])
+        self.assertIn('front', first_input['lewm_future_images'])
+
     def test_run_training_skips_swanlab_when_disabled(self):
         module = self._load_train_vla_module()
         run_training = self._get_run_training(module)
@@ -668,6 +742,52 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):
         self.assertTrue(final_payload['final/best_checkpoint_path'].endswith('checkpoints/vla_model_best.pt'))
         self.assertFalse(any(path.endswith('checkpoints/vla_model_best.pt') for path in saved_paths))
 
+    def test_run_training_pretrained_ckpt_loads_matching_keys_even_if_some_shapes_mismatch(self):
+        module = self._load_train_vla_module()
+        run_training = self._get_run_training(module)
+        cfg = self._make_cfg(use_swanlab=False)
+        cfg.train.max_steps = 0
+        cfg.train.save_freq = 100
+        cfg.train.pretrained_ckpt = 'pretrained.pt'
+        agent = ShapeMixedFakeAgent()
+
+        def fake_instantiate(config_node, **_kwargs):
+            if config_node is cfg.data:
+                return FakeDataset()
+            if config_node is cfg.agent:
+                return agent
+            raise AssertionError(f'unexpected instantiate config: {config_node!r}')
+
+        def fake_torch_load(path, map_location=None):
+            del map_location
+            if Path(path).name != 'pretrained.pt':
+                raise AssertionError(f'unexpected load path: {path}')
+            return {
+                'model_state_dict': {
+                    'weight': torch.tensor(3.0),
+                    'bias': torch.tensor([1.0, 2.0, 3.0]),
+                },
+                'step': 123,
+                'loss': 0.5,
+            }
+
+        with tempfile.TemporaryDirectory() as tempdir:
+            previous_cwd = os.getcwd()
+            try:
+                os.chdir(tempdir)
+                Path('pretrained.pt').write_bytes(b'pretend')
+                with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
+                        mock.patch.object(module, 'DataLoader', side_effect=self._loader_factory()), \
+                        mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=FakeScheduler()), \
+                        mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: FakeProgressBar(iterable)), \
+                        mock.patch.object(module.torch, 'save', return_value=None), \
+                        mock.patch.object(module.torch, 'load', side_effect=fake_torch_load):
+                    run_training(cfg)
+            finally:
+                os.chdir(previous_cwd)
+
+        self.assertAlmostEqual(agent.weight.item(), 3.0, places=6)
+
     def test_run_training_resume_ignores_best_checkpoint_without_rollout_metric(self):
         module = self._load_train_vla_module()
         run_training = self._get_run_training(module)
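Reviewer note (not part of the diff): the assertions above pin down two contracts that are easy to lose in the mock plumbing. First, the LeWM-predicted future token is appended after the per-step history tokens in the head conditioning. Second, the training loss decomposes as action_loss + lewm_loss_weight * (lewm_pred_loss + lewm_sigreg_weight * lewm_sigreg_loss), exposed through get_last_loss_breakdown(). The sketch below restates those two contracts in isolation; it is a minimal sketch, not the agent_imf.py implementation — predictor, projector, sigreg, and _last_breakdown are illustrative names, and only the arithmetic mirrors what the tests assert.

import torch
import torch.nn as nn
import torch.nn.functional as F


class LeWMContractSketch(nn.Module):
    """Minimal restatement of the contract asserted by the tests above;
    the attribute names are illustrative, not the real agent internals."""

    def __init__(self, predictor, projector, sigreg,
                 lewm_loss_weight=0.25, lewm_sigreg_weight=0.09):
        super().__init__()
        self.predictor = predictor    # history tokens -> predicted future tokens
        self.projector = projector    # predicted tokens -> conditioning space
        self.sigreg = sigreg          # scalar regularizer over LeWM embeddings
        self.lewm_loss_weight = lewm_loss_weight
        self.lewm_sigreg_weight = lewm_sigreg_weight
        self._last_breakdown = {}

    def build_condition(self, history_tokens):
        # (B, T_hist, D) -> (B, T_hist + T_future, D): the predicted future
        # token(s) are concatenated after the per-step history tokens, which
        # is what expected_cond = cat([expected_history, expected_future],
        # dim=1) encodes in the predict_action test.
        future_tokens = self.projector(self.predictor(history_tokens))
        return torch.cat([history_tokens, future_tokens], dim=1)

    def compute_loss(self, action_loss, lewm_history, lewm_future_target):
        # lewm_pred_loss supervises the predicted future tokens against the
        # encoded future observations; sigreg regularizes the embeddings.
        predicted = self.projector(self.predictor(lewm_history))
        lewm_pred_loss = F.mse_loss(predicted, lewm_future_target)
        lewm_sigreg_loss = self.sigreg(lewm_history)
        lewm_loss = lewm_pred_loss + self.lewm_sigreg_weight * lewm_sigreg_loss
        loss = action_loss + self.lewm_loss_weight * lewm_loss
        # The breakdown carries exactly the keys the tests assert:
        # loss, action_loss, lewm_pred_loss, lewm_sigreg_loss, lewm_loss.
        self._last_breakdown = {
            'loss': loss.item(),
            'action_loss': action_loss.item(),
            'lewm_pred_loss': lewm_pred_loss.item(),
            'lewm_sigreg_loss': lewm_sigreg_loss.item(),
            'lewm_loss': lewm_loss.item(),
        }
        return loss

    def get_last_loss_breakdown(self):
        return dict(self._last_breakdown)

Instantiated with _StubFutureTokenPredictor, nn.Identity() as the projector, and _RecordingSigReg(value=0.75), this reproduces the assertAlmostEqual chain in test_compute_loss_tracks_action_and_lewm_loss_breakdown: lewm_loss = lewm_pred_loss + 0.09 * 0.75 and loss = action_loss + 0.25 * lewm_loss.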