feat: add lewm-conditioned imf training and sigreg loss

This commit is contained in:
Logic
2026-04-17 18:46:02 +08:00
parent ff7c9c1f2a
commit 74f4963613
14 changed files with 1634 additions and 24 deletions

View File

@@ -237,6 +237,32 @@ def build_training_optimizer(agent, lr, weight_decay):
return AdamW(optim_groups, lr=lr, weight_decay=weight_decay)
def load_state_dict_ignoring_shape_mismatches(module, incoming_state_dict):
"""Load only checkpoint tensors whose keys exist locally and whose shapes match."""
current_state_dict = module.state_dict()
compatible_state_dict = {}
mismatched_keys = []
missing_keys = []
for key, value in incoming_state_dict.items():
if key not in current_state_dict:
missing_keys.append(key)
continue
if current_state_dict[key].shape != value.shape:
mismatched_keys.append(key)
continue
compatible_state_dict[key] = value
merged_state_dict = dict(current_state_dict)
merged_state_dict.update(compatible_state_dict)
module.load_state_dict(merged_state_dict, strict=True)
return {
'loaded_keys': sorted(compatible_state_dict.keys()),
'missing_keys': sorted(missing_keys),
'mismatched_keys': sorted(mismatched_keys),
}
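A minimal sketch (not part of the commit) of how this helper behaves on a checkpoint that mixes matching, shape-mismatched, and unknown keys; the toy module and tensors are illustrative:

```python
import torch
import torch.nn as nn

module = nn.Linear(4, 2)  # state_dict keys: 'weight' (2, 4), 'bias' (2,)
checkpoint = {
    'weight': torch.zeros(2, 4),   # key and shape match -> loaded
    'bias': torch.zeros(3),        # shape mismatch -> skipped, reported
    'extra.scale': torch.ones(1),  # key absent locally -> reported
}
info = load_state_dict_ignoring_shape_mismatches(module, checkpoint)
assert info['loaded_keys'] == ['weight']
assert info['mismatched_keys'] == ['bias']
assert info['missing_keys'] == ['extra.scale']
```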
def _init_swanlab(cfg):
"""按需初始化 SwanLab并在缺少依赖或认证失败时快速失败。"""
if not bool(cfg.train.get('use_swanlab', False)):
@@ -509,18 +535,23 @@ def _run_training(cfg: DictConfig):
try:
checkpoint = torch.load(ckpt_path, map_location=cfg.train.device)
# Load model weights only (optimizer and scheduler state are not restored)
load_info = load_state_dict_ignoring_shape_mismatches(
agent,
checkpoint['model_state_dict'],
)
log.info(f"✅ [Finetune] 模型权重加载成功")
if missing_keys:
log.warning(f"⚠️ [Finetune] 缺少的键 ({len(missing_keys)} 个): {missing_keys[:5]}...")
if unexpected_keys:
log.warning(f"⚠️ [Finetune] 多余的键 ({len(unexpected_keys)} 个): {unexpected_keys[:5]}...")
if load_info['missing_keys']:
log.warning(
f"⚠️ [Finetune] keys in the checkpoint that the local model lacks ({len(load_info['missing_keys'])}): "
f"{load_info['missing_keys'][:5]}..."
)
if load_info['mismatched_keys']:
log.warning(
f"⚠️ [Finetune] keys skipped due to shape mismatch ({len(load_info['mismatched_keys'])}): "
f"{load_info['mismatched_keys'][:5]}..."
)
log.info(f"📊 [Finetune] pretrained checkpoint info: step={checkpoint.get('step', 'N/A')}, loss={checkpoint.get('loss', 'N/A')}")
log.info(f"📈 [Finetune] using new training config: lr={cfg.train.lr}, max_steps={cfg.train.max_steps}")
@@ -652,13 +683,35 @@ def _run_training(cfg: DictConfig):
if key in batch_data:
images[cam_name] = batch_data[key]
agent_input = {
'images': images,
'qpos': batch_data['observation.state'], # SimpleRobotDataset uses observation.state
'action': batch_data['action'],
'action_is_pad': batch_data.get('action_is_pad', None) # pass the padding mask through
}
lewm_images = {}
lewm_future_images = {}
for cam_name in cfg.data.camera_names:
lewm_obs_key = f"lewm.observation.{cam_name}"
if lewm_obs_key in batch_data:
lewm_images[cam_name] = batch_data[lewm_obs_key]
lewm_future_key = f"lewm.future.{cam_name}"
if lewm_future_key in batch_data:
lewm_future_images[cam_name] = batch_data[lewm_future_key]
if 'lewm.observation.state' in batch_data:
agent_input['lewm_qpos'] = batch_data['lewm.observation.state']
if lewm_images:
agent_input['lewm_images'] = lewm_images
if 'lewm.future.state' in batch_data:
agent_input['lewm_future_qpos'] = batch_data['lewm.future.state']
if lewm_future_images:
agent_input['lewm_future_images'] = lewm_future_images
return agent_input
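For reference, a sketch of how the dataset's flat `lewm.*` keys are regrouped into the nested agent input (key names taken from the code above; tensors elided):

```python
# batch_data['lewm.observation.state']  -> agent_input['lewm_qpos']
# batch_data['lewm.observation.<cam>']  -> agent_input['lewm_images'][<cam>]
# batch_data['lewm.future.state']       -> agent_input['lewm_future_qpos']
# batch_data['lewm.future.<cam>']       -> agent_input['lewm_future_images'][<cam>]
```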
def save_checkpoint(checkpoint_path: Path, step: int, loss_value, val_loss=None, rollout_avg_reward=None):
agent_stats = agent.get_normalization_stats()
torch.save({
@@ -809,6 +862,15 @@ def _run_training(cfg: DictConfig):
},
step=step,
)
if hasattr(agent, 'get_last_loss_breakdown'):
loss_breakdown = agent.get_last_loss_breakdown()
extra_train_metrics = {
f"train/{key}": value
for key, value in loss_breakdown.items()
if value is not None and key != 'loss'
}
if extra_train_metrics:
_log_to_swanlab(swanlab_module, extra_train_metrics, step=step)
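What the breakdown logging above emits, sketched with illustrative values (metric names mirror `get_last_loss_breakdown()`, with `loss` itself excluded because the main loop already logs it):

```python
loss_breakdown = {'action_loss': 0.9, 'lewm_pred_loss': 0.2,
                  'lewm_sigreg_loss': 0.75, 'lewm_loss': 0.2675, 'loss': 1.0}
extra_train_metrics = {f"train/{key}": value
                       for key, value in loss_breakdown.items()
                       if value is not None and key != 'loss'}
# {'train/action_loss': 0.9, 'train/lewm_pred_loss': 0.2,
#  'train/lewm_sigreg_loss': 0.75, 'train/lewm_loss': 0.2675}
```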
# =====================================================================
# Checkpoint saving and validation

View File

@@ -0,0 +1,267 @@
#!/usr/bin/env python3
from __future__ import annotations
import argparse
import datetime as dt
import json
import pathlib
import re
import shlex
import subprocess
from collections import defaultdict
from typing import Any
# Progress-line patterns: "步骤 <step>/<total>" (literally "step <step>/<total>")
# and tqdm-style bars such as "| 120/1000". Kept in Chinese to match the
# training logs they parse.
STEP_PAT = re.compile(r"步骤\s+(\d+)/(\d+)")
BAR_PAT = re.compile(r"\|\s*(\d+)/(\d+)")
def normalize_chunks(text: str):
for part in re.split(r"[\r\n]+", text):
part = part.strip()
if part:
yield part
def parse_latest_line(text: str) -> tuple[str, int | None]:
latest_line = ""
latest_step = None
for line in normalize_chunks(text):
if "步骤" not in line and "训练中:" not in line:
continue
latest_line = line
match = STEP_PAT.search(line) or BAR_PAT.search(line)
if match:
latest_step = int(match.group(1))
return latest_line, latest_step
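A quick illustration with a made-up log excerpt (the progress strings are hypothetical): the parser keeps the last matching line and extracts its step counter:

```python
sample = (
    "训练中: 12%|█▏| 120/1000 [00:40<04:50]\n"
    "步骤 121/1000 | loss=0.42"
)
line, step = parse_latest_line(sample)
assert step == 121
assert line == "步骤 121/1000 | loss=0.42"
```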
def now_iso() -> str:
return dt.datetime.now(
dt.timezone(dt.timedelta(hours=8)),
).isoformat(timespec="seconds")
def run_cmd(cmd: list[str], check: bool = True) -> subprocess.CompletedProcess[str]:
return subprocess.run(cmd, capture_output=True, text=True, check=check)
def probe_local(run: dict[str, Any]) -> dict[str, Any]:
pid = str(run["pid"])
ps = run_cmd(["ps", "-p", pid, "-o", "pid=,stat=,etime=,args="], check=False)
log_path = pathlib.Path(run["log_path"])
latest_line = ""
latest_step = None
if log_path.exists():
latest_line, latest_step = parse_latest_line(log_path.read_text(errors="replace"))
return {
"alive": bool(ps.stdout.strip()),
"ps": ps.stdout.strip(),
"log_exists": log_path.exists(),
"latest_line": latest_line,
"latest_step": latest_step,
}
def remote_probe(host: str, remote_user: str, runs: list[dict[str, Any]]) -> dict[str, dict[str, Any]]:
payload = [
{
"run_id": run["run_id"],
"pid": str(run["pid"]),
"log_path": run["log_path"],
}
for run in runs
]
remote_py = r"""
import json
import pathlib
import re
import subprocess
import sys
payload = json.loads(sys.argv[1])
step_pat = re.compile(r"步骤\s+(\d+)/(\d+)")
bar_pat = re.compile(r"\|\s*(\d+)/(\d+)")
def normalize_chunks(text):
for part in re.split(r"[\r\n]+", text):
part = part.strip()
if part:
yield part
def parse_latest_line(text):
latest_line = ""
latest_step = None
for line in normalize_chunks(text):
if "步骤" not in line and "训练中:" not in line:
continue
latest_line = line
match = step_pat.search(line) or bar_pat.search(line)
if match:
latest_step = int(match.group(1))
return latest_line, latest_step
out = {}
for item in payload:
try:
ps = subprocess.run(
["ps", "-p", item["pid"], "-o", "pid=,stat=,etime=,args="],
capture_output=True,
text=True,
check=False,
)
log_path = pathlib.Path(item["log_path"])
latest_line = ""
latest_step = None
if log_path.exists():
latest_line, latest_step = parse_latest_line(log_path.read_text(errors="replace"))
out[item["run_id"]] = {
"alive": bool(ps.stdout.strip()),
"ps": ps.stdout.strip(),
"log_exists": log_path.exists(),
"latest_line": latest_line,
"latest_step": latest_step,
}
except Exception as exc:
out[item["run_id"]] = {
"alive": False,
"ps": "",
"log_exists": False,
"latest_line": "",
"latest_step": None,
"error": str(exc),
}
print(json.dumps(out, ensure_ascii=False))
"""
remote_target = host if "@" in host else f"{remote_user}@{host}"
remote_cmd = (
f"python3 -c {shlex.quote(remote_py)} "
f"{shlex.quote(json.dumps(payload, ensure_ascii=False))}"
)
try:
res = run_cmd(
[
"ssh",
"-F",
"/dev/null",
"-o",
"BatchMode=yes",
"-o",
"StrictHostKeyChecking=accept-new",
remote_target,
remote_cmd,
]
)
return json.loads(res.stdout)
except subprocess.CalledProcessError as exc:
error = (exc.stderr or exc.stdout or str(exc)).strip()
return {
run["run_id"]: {
"alive": False,
"ps": "",
"log_exists": False,
"latest_line": "",
"latest_step": None,
"error": f"ssh_failed: {error}",
}
for run in runs
}
def append_notes(notes_path: pathlib.Path, snapshot_at: str, runs: list[dict[str, Any]]) -> None:
lines = [f"\n## Status snapshot {snapshot_at}"]
for run in runs:
lines.append(
(
f"- {run['run_id']}: host={run['host']} gpu={run['gpu']} "
f"alive={run.get('alive', False)} step={run.get('latest_step')} "
f"pid={run['pid']}"
)
)
if run.get("latest_line"):
lines.append(f" - latest_line: `{run['latest_line']}`")
if run.get("error"):
lines.append(f" - error: `{run['error']}`")
with notes_path.open("a", encoding="utf-8") as f:
f.write("\n".join(lines) + "\n")
def main() -> int:
parser = argparse.ArgumentParser()
parser.add_argument("suite_dir", type=pathlib.Path)
parser.add_argument("--remote-user", default="droid")
parser.add_argument("--append-notes", action="store_true")
args = parser.parse_args()
suite_dir = args.suite_dir.resolve()
status_path = suite_dir / "status.json"
notes_path = suite_dir / "notes.md"
monitor_dir = suite_dir / "monitor_logs"
monitor_dir.mkdir(parents=True, exist_ok=True)
status = json.loads(status_path.read_text(encoding="utf-8"))
runs: list[dict[str, Any]] = status["runs"]
snapshot_at = now_iso()
by_host: dict[str, list[dict[str, Any]]] = defaultdict(list)
for run in runs:
by_host[run["host"]].append(run)
results: dict[str, dict[str, Any]] = {}
for host, host_runs in by_host.items():
if host == "local":
for run in host_runs:
results[run["run_id"]] = probe_local(run)
else:
results.update(remote_probe(host, args.remote_user, host_runs))
alive_count = 0
for run in runs:
result = results[run["run_id"]]
run["alive"] = result["alive"]
run["ps"] = result["ps"]
run["log_exists"] = result["log_exists"]
run["latest_line"] = result["latest_line"]
run["latest_step"] = result["latest_step"]
run["last_verified_at"] = snapshot_at
if "error" in result:
run["error"] = result["error"]
else:
run.pop("error", None)
run["status"] = "running" if result["alive"] else "stopped"
alive_count += int(result["alive"])
status["last_verified_at"] = snapshot_at
status["alive_count"] = alive_count
status["total_runs"] = len(runs)
status_path.write_text(json.dumps(status, ensure_ascii=False, indent=2) + "\n", encoding="utf-8")
snapshot_payload = {
"suite_name": status.get("suite_name"),
"snapshot_at": snapshot_at,
"alive_count": alive_count,
"total_runs": len(runs),
"runs": {run["run_id"]: results[run["run_id"]] for run in runs},
}
timestamp_slug = snapshot_at.replace(":", "").replace("+", "_").replace("-", "")
snapshot_path = monitor_dir / f"status-{timestamp_slug}.json"
snapshot_path.write_text(
json.dumps(snapshot_payload, ensure_ascii=False, indent=2) + "\n",
encoding="utf-8",
)
if args.append_notes:
append_notes(notes_path, snapshot_at, runs)
print(json.dumps(snapshot_payload, ensure_ascii=False, indent=2))
print(f"\nstatus_json={status_path}")
print(f"snapshot_json={snapshot_path}")
if args.append_notes:
print(f"notes_md={notes_path}")
return 0
if __name__ == "__main__":
raise SystemExit(main())
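The script expects `suite_dir/status.json` to already exist; a minimal sketch of the fields it reads (all values here are hypothetical):

```python
example_status = {
    "suite_name": "lewm_imf_suite",   # optional, echoed into snapshots
    "runs": [
        {"run_id": "run-a", "host": "local", "gpu": 0,
         "pid": 12345, "log_path": "/tmp/run-a.log"},
        {"run_id": "run-b", "host": "10.0.0.2", "gpu": 1,
         "pid": 2345, "log_path": "/data/run-b.log"},
    ],
}
```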

View File

@@ -28,6 +28,7 @@ class VLAAgent(nn.Module):
num_action_steps=8, # number of action steps actually executed per inference call
head_type='unet', # policy head type: 'unet' or 'transformer'
cond_projector=None, # optional: project the vision+state condition to the head's expected dim
extra_condition_tokens: int = 0, # optional count of extra condition tokens (e.g. future-prediction embeddings)
):
super().__init__()
# store constructor arguments
@@ -39,6 +40,9 @@ class VLAAgent(nn.Module):
self.num_action_steps = num_action_steps
self.inference_steps = inference_steps
self.head_type = head_type # 'unet' or 'transformer'
self.extra_condition_tokens = int(extra_condition_tokens)
if self.extra_condition_tokens < 0:
raise ValueError(f"extra_condition_tokens must be >= 0, got {self.extra_condition_tokens}")
agent_camera_names = tuple(camera_names) if camera_names is not None else None
backbone_camera_names = getattr(vision_backbone, 'camera_names', None)
backbone_camera_names = tuple(backbone_camera_names) if backbone_camera_names is not None else None
@@ -71,11 +75,14 @@ class VLAAgent(nn.Module):
stats=dataset_stats,
normalization_type=normalization_type
)
self.dataset_stats = dataset_stats
self.vision_encoder = vision_backbone
self.state_encoder = state_encoder
if self.camera_names is not None:
self.vision_encoder.camera_names = self.camera_names
self.condition_tokens_per_step = int(getattr(self.vision_encoder, 'tokens_per_step', 1))
self.state_feature_dim = int(getattr(self.state_encoder, 'output_dim', obs_dim))
joint_vision_dim = getattr(self.vision_encoder, 'joint_output_dim', None)
if joint_vision_dim is not None:
per_token_vision_dim = int(joint_vision_dim)
@@ -87,8 +94,11 @@ class VLAAgent(nn.Module):
else:
per_token_vision_dim = int(single_cam_feat_dim) * int(num_cams)
self.history_condition_sequence_length = self.obs_horizon * self.condition_tokens_per_step
self.condition_sequence_length = (
self.history_condition_sequence_length + self.extra_condition_tokens
)
self.raw_per_step_cond_dim = per_token_vision_dim + self.state_feature_dim
if cond_projector is None:
self.cond_projector = None
self.per_step_cond_dim = self.raw_per_step_cond_dim
@@ -139,7 +149,6 @@ class VLAAgent(nn.Module):
global_cond_dim=self.global_cond_dim
)
self.action_encoder = action_encoder
# initialize queues (used for online inference)
@@ -220,7 +229,7 @@ class VLAAgent(nn.Module):
f"条件维度不匹配: got {cond.shape[-1]}, expected {self.per_step_cond_dim}"
)
cond = cond.reshape(batch_size, obs_steps * token_count, self.per_step_cond_dim)
expected_length = self.history_condition_sequence_length
if cond.shape[1] != expected_length:
raise RuntimeError(
f"条件序列长度不匹配: got {cond.shape[1]}, expected {expected_length}"

View File

@@ -1,9 +1,12 @@
from __future__ import annotations
from contextlib import nullcontext
from collections import deque
from pathlib import Path
from typing import Any, Dict, Mapping, Optional, Sequence
import torch
import torch.nn as nn
import torch.nn.functional as F
from roboimi.vla.agent import VLAAgent
@@ -15,14 +18,59 @@ except ImportError: # pragma: no cover
class IMFVLAAgent(VLAAgent):
def __init__(
self,
*args,
inference_steps: int = 1,
lewm_history_horizon: Optional[int] = None,
lewm_query_offsets: Optional[Sequence[int]] = None,
lewm_predictor: Optional[nn.Module] = None,
lewm_pred_projector: Optional[nn.Module] = None,
lewm_sigreg: Optional[nn.Module] = None,
lewm_sigreg_weight: float = 0.09,
lewm_loss_weight: float = 0.0,
lewm_pretrained_ckpt: Optional[str | Path | Mapping[str, Any]] = None,
**kwargs,
):
if inference_steps != 1:
raise ValueError(
'IMFVLAAgent only supports one-step inference; '
f'inference_steps must be 1, got {inference_steps}.'
)
lewm_query_offsets = tuple(int(offset) for offset in (lewm_query_offsets or ()))
inferred_extra_condition_tokens = len(lewm_query_offsets) if lewm_query_offsets else 0
kwargs.setdefault('extra_condition_tokens', inferred_extra_condition_tokens)
self.__dict__['lewm_history_horizon'] = int(lewm_history_horizon or kwargs.get('obs_horizon', 1))
self.__dict__['lewm_query_offsets'] = lewm_query_offsets
self.__dict__['lewm_predictor'] = lewm_predictor
self.__dict__['lewm_pred_projector'] = lewm_pred_projector or nn.Identity()
self.__dict__['lewm_sigreg'] = lewm_sigreg
self.__dict__['lewm_sigreg_weight'] = float(lewm_sigreg_weight)
self.__dict__['lewm_loss_weight'] = float(lewm_loss_weight)
self.__dict__['_last_loss_breakdown'] = {
'action_loss': 0.0,
'lewm_pred_loss': 0.0,
'lewm_sigreg_loss': 0.0,
'lewm_loss': 0.0,
'loss': 0.0,
}
super().__init__(*args, inference_steps=inference_steps, **kwargs)
self.inference_steps = 1
self.lewm_history_horizon = int(lewm_history_horizon or self.obs_horizon)
self.lewm_predictor = lewm_predictor
self.lewm_pred_projector = lewm_pred_projector or nn.Identity()
self.lewm_sigreg = lewm_sigreg
self.lewm_sigreg_weight = float(lewm_sigreg_weight)
if self.lewm_predictor is None and self.extra_condition_tokens > 0:
raise ValueError(
'extra_condition_tokens > 0 requires lewm_predictor to be provided'
)
if self.lewm_predictor is not None and self.extra_condition_tokens != inferred_extra_condition_tokens:
raise ValueError(
'extra_condition_tokens must equal len(lewm_query_offsets) when lewm_predictor is enabled'
)
if lewm_pretrained_ckpt is not None:
self.load_lewm_pretrained_components(lewm_pretrained_ckpt)
@staticmethod
def _broadcast_batch_time(value: torch.Tensor, reference: torch.Tensor) -> torch.Tensor:
@@ -119,14 +167,181 @@ class IMFVLAAgent(VLAAgent):
delta = self._broadcast_batch_time(t - r, z_t)
return z_t - delta * u
def _normalize_qpos_for_lewm(self, qpos: torch.Tensor) -> torch.Tensor:
if not self.normalization.enabled:
return qpos
qpos_mean = getattr(self.normalization, 'qpos_mean', None)
qpos_std = getattr(self.normalization, 'qpos_std', None)
if qpos_mean is not None and qpos_std is not None:
return (qpos - qpos_mean) / qpos_std
if isinstance(self.dataset_stats, dict):
mean = self.dataset_stats.get('qpos_mean', None)
std = self.dataset_stats.get('qpos_std', None)
if mean is not None and std is not None:
mean = torch.as_tensor(mean, dtype=qpos.dtype, device=qpos.device)
std = torch.as_tensor(std, dtype=qpos.dtype, device=qpos.device)
return (qpos - mean) / std
return self.normalization.normalize_qpos(qpos)
def _project_lewm_future_tokens(self, predicted_tokens: torch.Tensor) -> torch.Tensor:
if predicted_tokens.ndim != 3:
raise ValueError(
f"expected predicted future tokens to be 3D, got rank {predicted_tokens.ndim}"
)
batch_size, token_count, token_dim = predicted_tokens.shape
flattened = predicted_tokens.reshape(batch_size * token_count, token_dim)
projected = self.lewm_pred_projector(flattened)
if projected.ndim != 2:
raise ValueError(
f"expected lewm_pred_projector to return rank-2 tensors, got rank {projected.ndim}"
)
return projected.reshape(batch_size, token_count, projected.shape[-1])
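The flatten/project/reshape dance exists because `LeWMProjectorMLP` (added later in this commit) contains a `BatchNorm1d`, which only accepts rank-2 `(N, C)` inputs. A runnable shape sketch with illustrative sizes:

```python
import torch

projector = LeWMProjectorMLP(input_dim=288, hidden_dim=2048, output_dim=288)
tokens = torch.randn(4, 2, 288)                 # (B, K, D) predicted tokens
flat = tokens.reshape(-1, 288)                  # (B*K, D) for BatchNorm1d
projected = projector(flat).reshape(4, 2, -1)   # back to (B, K, D_out)
assert projected.shape == (4, 2, 288)
```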
@staticmethod
def _load_checkpoint_payload(
checkpoint_or_path: str | Path | Mapping[str, Any],
) -> Mapping[str, torch.Tensor]:
if isinstance(checkpoint_or_path, (str, Path)):
payload = torch.load(Path(checkpoint_or_path), map_location='cpu', weights_only=False)
else:
payload = checkpoint_or_path
state_dict = payload.get('state_dict', payload)
if not isinstance(state_dict, Mapping):
raise TypeError('checkpoint payload must contain a mapping state_dict')
return state_dict
@staticmethod
def _extract_prefixed_state_dict(
state_dict: Mapping[str, torch.Tensor],
prefix: str,
) -> Dict[str, torch.Tensor]:
extracted = {
key[len(prefix):]: value
for key, value in state_dict.items()
if key.startswith(prefix)
}
if not extracted:
raise KeyError(f"checkpoint missing parameters with prefix {prefix!r}")
return extracted
@staticmethod
def _adapt_and_load_state_dict(
module: nn.Module,
incoming_state_dict: Mapping[str, torch.Tensor],
*,
query_key: str = 'query_tokens',
pos_key: str = 'pos_embedding',
) -> None:
current_state_dict = module.state_dict()
adapted_state_dict = dict(current_state_dict)
for key, current_tensor in current_state_dict.items():
if key not in incoming_state_dict:
continue
source_tensor = incoming_state_dict[key]
if source_tensor.shape == current_tensor.shape:
adapted_state_dict[key] = source_tensor
continue
if key in {query_key, pos_key} and source_tensor.ndim == current_tensor.ndim:
patched = current_tensor.clone()
copy_count = min(source_tensor.shape[1], current_tensor.shape[1])
patched[:, :copy_count, ...] = source_tensor[:, :copy_count, ...]
if copy_count < current_tensor.shape[1] and copy_count > 0:
# pad remaining slots by repeating the last pretrained token
patched[:, copy_count:, ...] = source_tensor[:, copy_count - 1:copy_count, ...]
adapted_state_dict[key] = patched
module.load_state_dict(adapted_state_dict, strict=True)
def load_lewm_pretrained_components(
self,
checkpoint_or_path: str | Path | Mapping[str, Any],
) -> None:
state_dict = self._load_checkpoint_payload(checkpoint_or_path)
if hasattr(self.vision_encoder, 'load_lewm_checkpoint'):
self.vision_encoder.load_lewm_checkpoint({'state_dict': state_dict})
else:
vision_state_dict = self._extract_prefixed_state_dict(state_dict, 'model.encoder.')
self.vision_encoder.load_state_dict(vision_state_dict, strict=True)
state_encoder_state_dict = self._extract_prefixed_state_dict(state_dict, 'model.state_encoder.')
self.state_encoder.load_state_dict(state_encoder_state_dict, strict=True)
projector_state_dict = self._extract_prefixed_state_dict(state_dict, 'model.projector.proj.')
mapped_projector_state_dict = {
f'linear.{key}': value
for key, value in projector_state_dict.items()
}
self.cond_projector.load_state_dict(mapped_projector_state_dict, strict=True)
if self.lewm_predictor is not None:
predictor_state_dict = self._extract_prefixed_state_dict(state_dict, 'model.predictor.')
self._adapt_and_load_state_dict(self.lewm_predictor, predictor_state_dict)
if self.lewm_pred_projector is not None:
pred_projector_state_dict = self._extract_prefixed_state_dict(state_dict, 'model.pred_proj.')
self.lewm_pred_projector.load_state_dict(pred_projector_state_dict, strict=True)
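How the prefix extraction maps pretraining-checkpoint keys onto the local submodules, sketched with hypothetical key names:

```python
import torch

state_dict = {
    'model.encoder.proj.weight': torch.zeros(48, 512, 1, 1),
    'model.state_encoder.net.0.weight': torch.zeros(256, 16),
    'model.predictor.query_tokens': torch.zeros(1, 1, 288),
}
encoder_sd = IMFVLAAgent._extract_prefixed_state_dict(state_dict, 'model.encoder.')
assert list(encoder_sd) == ['proj.weight']   # prefix stripped, other keys dropped
```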
def _build_full_condition(
self,
images,
proprioception,
*,
lewm_images=None,
lewm_proprioception=None,
):
normalized_proprioception = self.normalization.normalize_qpos(proprioception)
history_cond = self._build_cond(images, normalized_proprioception)
predicted_future_tokens = None
lewm_history_cond = None
if self.lewm_predictor is None:
return history_cond, predicted_future_tokens, lewm_history_cond
lewm_images = lewm_images if lewm_images is not None else images
lewm_proprioception = (
lewm_proprioception if lewm_proprioception is not None else proprioception
)
lewm_history_cond = self._build_cond(
lewm_images,
self._normalize_qpos_for_lewm(lewm_proprioception),
)
predicted_future_tokens = self.lewm_predictor(lewm_history_cond)
predicted_future_tokens = self._project_lewm_future_tokens(predicted_future_tokens)
cond = torch.cat([history_cond, predicted_future_tokens], dim=1)
if cond.shape[1] != self.condition_sequence_length:
raise RuntimeError(
f"完整条件序列长度不匹配: got {cond.shape[1]}, expected {self.condition_sequence_length}"
)
if cond.shape[-1] != self.per_step_cond_dim:
raise RuntimeError(
f"完整条件维度不匹配: got {cond.shape[-1]}, expected {self.per_step_cond_dim}"
)
return cond, predicted_future_tokens, lewm_history_cond
@staticmethod
def _masked_mse_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
# no masking is applied yet; this is the hook for padding-aware weighting
return F.mse_loss(pred, target)
def compute_loss(self, batch):
actions, states, images = batch['action'], batch['qpos'], batch['images']
action_is_pad = batch.get('action_is_pad', None)
batch_size = actions.shape[0]
actions = self.normalization.normalize_action(actions)
cond, predicted_future_tokens, lewm_history_cond = self._build_full_condition(
images,
states,
lewm_images=batch.get('lewm_images', None),
lewm_proprioception=batch.get('lewm_qpos', None),
)
x = actions
e = torch.randn_like(x)
@@ -146,16 +361,103 @@ class IMFVLAAgent(VLAAgent):
if action_is_pad is not None:
mask = (~action_is_pad).unsqueeze(-1).to(loss.dtype)
valid_count = mask.sum() * loss.shape[-1]
action_loss = (loss * mask).sum() / valid_count.clamp_min(1.0)
else:
action_loss = loss.mean()
lewm_pred_loss = torch.zeros((), device=action_loss.device, dtype=action_loss.dtype)
lewm_sigreg_loss = torch.zeros((), device=action_loss.device, dtype=action_loss.dtype)
if predicted_future_tokens is not None:
lewm_future_images = batch.get('lewm_future_images', None)
lewm_future_qpos = batch.get('lewm_future_qpos', None)
if lewm_future_images is not None and lewm_future_qpos is not None:
future_target = self._build_cond(
lewm_future_images,
self._normalize_qpos_for_lewm(lewm_future_qpos),
)
lewm_pred_loss = self._masked_mse_loss(predicted_future_tokens, future_target)
if self.lewm_sigreg is not None and lewm_history_cond is not None:
lewm_sigreg_loss = self.lewm_sigreg(lewm_history_cond.transpose(0, 1))
lewm_loss = lewm_pred_loss + self.lewm_sigreg_weight * lewm_sigreg_loss
total_loss = action_loss + self.lewm_loss_weight * lewm_loss
self._last_loss_breakdown = {
'action_loss': float(action_loss.detach().item()),
'lewm_pred_loss': float(lewm_pred_loss.detach().item()),
'lewm_sigreg_loss': float(lewm_sigreg_loss.detach().item()),
'lewm_loss': float(lewm_loss.detach().item()),
'loss': float(total_loss.detach().item()),
}
return total_loss
def get_last_loss_breakdown(self) -> Dict[str, float]:
return dict(self._last_loss_breakdown)
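The composition of the returned loss, written out with the weights used in the tests below (lewm_loss_weight=0.25, lewm_sigreg_weight=0.09) and illustrative component values:

```python
action_loss, lewm_pred_loss, lewm_sigreg_loss = 1.0, 0.2, 0.75
lewm_loss = lewm_pred_loss + 0.09 * lewm_sigreg_loss   # 0.2675
total_loss = action_loss + 0.25 * lewm_loss            # 1.066875
```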
def reset(self):
super().reset()
if self.lewm_predictor is not None:
self._queues['lewm_qpos'] = deque(maxlen=self.lewm_history_horizon)
self._queues['lewm_images'] = deque(maxlen=self.lewm_history_horizon)
def _populate_queues(self, observation: Dict[str, torch.Tensor]) -> None:
super()._populate_queues(observation)
if self.lewm_predictor is None:
return
if 'qpos' in observation:
self._queues['lewm_qpos'].append(observation['qpos'].clone())
if 'images' in observation:
ordered_images = self._order_images(observation['images'])
self._queues['lewm_images'].append({k: v.clone() for k, v in ordered_images.items()})
def _prepare_observation_batch(self) -> Dict[str, torch.Tensor]:
batch = super()._prepare_observation_batch()
if self.lewm_predictor is None:
return batch
qpos_list = list(self._queues['lewm_qpos'])
images_list = list(self._queues['lewm_images'])
if len(qpos_list) == 0 or len(images_list) == 0:
raise ValueError("LeWM 观测队列为空,请先调用 _populate_queues 添加观测")
while len(qpos_list) < self.lewm_history_horizon:
qpos_list.append(qpos_list[-1])
while len(images_list) < self.lewm_history_horizon:
images_list.append(images_list[-1])
batch['lewm_qpos'] = torch.stack(qpos_list, dim=0).unsqueeze(0)
batch['lewm_images'] = {}
camera_names = self.camera_names if self.camera_names is not None else tuple(sorted(images_list[0].keys()))
for cam_name in camera_names:
batch['lewm_images'][cam_name] = torch.stack(
[img[cam_name] for img in images_list],
dim=0,
).unsqueeze(0)
return batch
@torch.no_grad()
def predict_action_chunk(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
return self.predict_action(
batch['images'],
batch['qpos'],
lewm_images=batch.get('lewm_images', None),
lewm_proprioception=batch.get('lewm_qpos', None),
)
@torch.no_grad()
def predict_action(
self,
images,
proprioception,
*,
lewm_images=None,
lewm_proprioception=None,
):
batch_size = proprioception.shape[0]
cond, _predicted_future_tokens, _lewm_history_cond = self._build_full_condition(
images,
proprioception,
lewm_images=lewm_images,
lewm_proprioception=lewm_proprioception,
)
z_t = torch.randn((batch_size, self.pred_horizon, self.action_dim), device=cond.device, dtype=cond.dtype)
action = self._sample_one_step(z_t, cond=cond)
return self.normalization.denormalize_action(action)
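A shapes-only usage sketch for online inference, assuming an agent instantiated from the LeWM config below; note the LeWM history window (3 frames here) may be longer than the policy's obs_horizon (2):

```python
import torch

cams = ('r_vis', 'top', 'front')
images = {cam: torch.randn(1, 2, 3, 224, 224) for cam in cams}       # (B, obs_horizon, C, H, W)
qpos = torch.randn(1, 2, 16)
lewm_images = {cam: torch.randn(1, 3, 3, 224, 224) for cam in cams}  # (B, lewm_history_horizon, C, H, W)
lewm_qpos = torch.randn(1, 3, 16)

actions = agent.predict_action(
    images, qpos,
    lewm_images=lewm_images,
    lewm_proprioception=lewm_qpos,
)  # (1, pred_horizon, action_dim), denormalized
```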

View File

@@ -0,0 +1,77 @@
# @package agent
defaults:
- /backbone@vision_backbone: lewm_resnet_query_fusion
- /modules@state_encoder: lewm_state_encoder
- /modules@action_encoder: identity_action_encoder
- /modules@cond_projector: linear_condition_projector
- /head: imf_transformer1d
- _self_
_target_: roboimi.vla.agent_imf.IMFVLAAgent
action_dim: 16
obs_dim: 16
normalization_type: "min_max"
pred_horizon: 8
obs_horizon: 2
num_action_steps: 8
camera_names: ${data.camera_names}
num_cams: 3
vision_backbone:
camera_names: ${agent.camera_names}
num_views: ${agent.num_cams}
cond_projector:
output_dim: 288
lewm_history_horizon: 3
lewm_query_offsets: [8]
extra_condition_tokens: ${len:${agent.lewm_query_offsets}}
lewm_loss_weight: 1.0
lewm_sigreg_weight: 0.09
lewm_pretrained_ckpt: null
lewm_sigreg:
_target_: roboimi.vla.models.backbones.lewm_resnet_query_fusion.SIGReg
knots: 17
num_proj: 1024
lewm_predictor:
_target_: roboimi.vla.models.backbones.lewm_resnet_query_fusion.QueryTokenPredictor
num_frames: ${agent.lewm_history_horizon}
query_offsets: ${agent.lewm_query_offsets}
input_dim: ${agent.cond_projector.output_dim}
hidden_dim: ${agent.cond_projector.output_dim}
output_dim: ${agent.cond_projector.output_dim}
depth: 6
heads: 16
mlp_dim: 2048
dim_head: 64
dropout: 0.1
emb_dropout: 0.0
lewm_pred_projector:
_target_: roboimi.vla.models.backbones.lewm_resnet_query_fusion.LeWMProjectorMLP
input_dim: ${agent.cond_projector.output_dim}
hidden_dim: 2048
output_dim: ${agent.cond_projector.output_dim}
diffusion_steps: 100
inference_steps: 1
head_type: "transformer"
head:
input_dim: ${agent.action_dim}
output_dim: ${agent.action_dim}
horizon: ${agent.pred_horizon}
n_obs_steps: ${agent.obs_horizon}
cond_dim: 288
n_emb: 384
causal_attn: false
time_as_cond: true
obs_as_cond: true
n_cond_layers: 0
backbone_type: attnres_full
n_head: 1
n_kv_head: 1
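A sketch of composing this config, modeled on the test helpers below; the config path and root config name are assumptions about the repo layout:

```python
from hydra import compose, initialize
from hydra.utils import instantiate

with initialize(config_path='../conf', version_base=None):  # path assumed
    cfg = compose(
        config_name='train',                                # name assumed
        overrides=['agent=lewm_resnet_query_imf_attnres'],
    )
agent = instantiate(cfg.agent)
```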

View File

@@ -0,0 +1,7 @@
_target_: roboimi.vla.models.backbones.lewm_resnet_query_fusion.LeWMMultiViewResNetBackbone
view_feature_dim: 96
num_views: ${agent.num_cams}
view_encoder_mode: separate
camera_names: ${agent.camera_names}
checkpoint_path: null

View File

@@ -0,0 +1,5 @@
_target_: roboimi.vla.modules.encoders.LeWMStateEncoder
input_dim: ${agent.obs_dim}
hidden_dim: 256
output_dim: 64

View File

@@ -24,6 +24,8 @@ class SimpleRobotDataset(Dataset):
camera_names: List[str] = None,
image_resize_shape: Optional[Sequence[int]] = (224, 224),
max_open_files: int = 64,
lewm_history_horizon: Optional[int] = None,
lewm_query_offsets: Optional[Sequence[int]] = None,
):
"""
Args:
@@ -42,6 +44,13 @@ class SimpleRobotDataset(Dataset):
self.obs_horizon = obs_horizon
self.pred_horizon = pred_horizon
self.camera_names = camera_names or []
self.lewm_history_horizon = (
int(lewm_history_horizon) if lewm_history_horizon is not None else None
)
self.lewm_query_offsets = (
tuple(int(offset) for offset in lewm_query_offsets)
if lewm_query_offsets is not None else ()
)
self.image_resize_shape = (
tuple(int(v) for v in image_resize_shape)
if image_resize_shape is not None else None
@@ -220,6 +229,60 @@ class SimpleRobotDataset(Dataset):
for cam_name in self.camera_names:
result[f"observation.{cam_name}"] = torch.stack(observations[f"observation.{cam_name}"])
if self.lewm_history_horizon is not None and self.lewm_history_horizon > 0:
lewm_observations = {
"state": [],
}
for cam_name in self.camera_names:
lewm_observations[f"observation.{cam_name}"] = []
for delta in range(-self.lewm_history_horizon + 1, 1):
target_idx = idx + delta
if ep_start <= target_idx <= ep_end:
target_frame = self._load_frame(target_idx)
else:
boundary_idx = ep_start if target_idx < ep_start else ep_end
target_frame = self._load_frame(boundary_idx)
lewm_observations["state"].append(target_frame["observation.state"])
for cam_name in self.camera_names:
lewm_observations[f"observation.{cam_name}"].append(
target_frame[f"observation.{cam_name}"]
)
result["lewm.observation.state"] = torch.stack(lewm_observations["state"])
for cam_name in self.camera_names:
result[f"lewm.observation.{cam_name}"] = torch.stack(
lewm_observations[f"observation.{cam_name}"]
)
if self.lewm_query_offsets:
lewm_future = {
"state": [],
}
for cam_name in self.camera_names:
lewm_future[f"observation.{cam_name}"] = []
for offset in self.lewm_query_offsets:
target_idx = idx + offset
if ep_start <= target_idx <= ep_end:
target_frame = self._load_frame(target_idx)
else:
boundary_idx = ep_start if target_idx < ep_start else ep_end
target_frame = self._load_frame(boundary_idx)
lewm_future["state"].append(target_frame["observation.state"])
for cam_name in self.camera_names:
lewm_future[f"observation.{cam_name}"].append(
target_frame[f"observation.{cam_name}"]
)
result["lewm.future.state"] = torch.stack(lewm_future["state"])
for cam_name in self.camera_names:
result[f"lewm.future.{cam_name}"] = torch.stack(
lewm_future[f"observation.{cam_name}"]
)
return result
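Keys a sample gains when the LeWM options are enabled, with the shapes asserted by the dataset test below (camera_names=["front"], lewm_history_horizon=3, lewm_query_offsets=[1, 2], 4-dim state, 8x8 images):

```python
sample = dataset[1]
sample['lewm.observation.state'].shape   # (3, 4)  history window, oldest first
sample['lewm.observation.front'].shape   # (3, 3, 8, 8)
sample['lewm.future.state'].shape        # (2, 4)  one row per query offset
sample['lewm.future.front'].shape        # (2, 3, 8, 8)
```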
@property

View File

@@ -1,5 +1,14 @@
# Backbone models
__all__ = ["LEWMViTBackbone", "ResNetBackbone", "ResNetDiffusionBackbone", "SigLIP2DiffusionBackbone"]
__all__ = [
"LEWMViTBackbone",
"LeWMMultiViewResNetBackbone",
"QueryTokenPredictor",
"LeWMProjectorMLP",
"SIGReg",
"ResNetBackbone",
"ResNetDiffusionBackbone",
"SigLIP2DiffusionBackbone",
]
def __getattr__(name):
@@ -9,6 +18,19 @@ def __getattr__(name):
if name == "SigLIP2DiffusionBackbone":
from .siglip2_diffusion_backbone import SigLIP2DiffusionBackbone
return SigLIP2DiffusionBackbone
if name in {"LeWMMultiViewResNetBackbone", "QueryTokenPredictor", "LeWMProjectorMLP", "SIGReg"}:
from .lewm_resnet_query_fusion import (
LeWMMultiViewResNetBackbone,
QueryTokenPredictor,
LeWMProjectorMLP,
SIGReg,
)
return {
"LeWMMultiViewResNetBackbone": LeWMMultiViewResNetBackbone,
"QueryTokenPredictor": QueryTokenPredictor,
"LeWMProjectorMLP": LeWMProjectorMLP,
"SIGReg": SIGReg,
}[name]
if name in {"ResNetBackbone", "ResNetDiffusionBackbone"}:
from .resnet_diffusion import ResNetDiffusionBackbone
return ResNetDiffusionBackbone

View File

@@ -0,0 +1,409 @@
from __future__ import annotations
from pathlib import Path
from typing import Any, Dict, Mapping, Optional, Sequence
import torch
from einops import rearrange
from torch import nn
import torch.nn.functional as F
from torchvision import models
from roboimi.vla.core.interfaces import VLABackbone
class SpatialSoftmax2D(nn.Module):
"""Convert a feature map into expected 2D keypoint coordinates per channel."""
def forward(self, feature_map):
if feature_map.ndim != 4:
raise ValueError(
f"SpatialSoftmax2D expects a 4D tensor, got rank {feature_map.ndim}"
)
batch, channels, height, width = feature_map.shape
scores = feature_map.reshape(batch, channels, height * width)
attention = F.softmax(scores, dim=-1)
ys = torch.linspace(-1.0, 1.0, height, device=feature_map.device, dtype=feature_map.dtype)
xs = torch.linspace(-1.0, 1.0, width, device=feature_map.device, dtype=feature_map.dtype)
grid_y, grid_x = torch.meshgrid(ys, xs, indexing="ij")
grid_x = grid_x.reshape(1, 1, height * width)
grid_y = grid_y.reshape(1, 1, height * width)
expected_x = (attention * grid_x).sum(dim=-1)
expected_y = (attention * grid_y).sum(dim=-1)
return torch.cat([expected_x, expected_y], dim=-1)
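A tiny check of the spatial-softmax math: a single hot pixel should produce expected coordinates at that pixel's location, mapped into [-1, 1] (x from the column, y from the row):

```python
import torch

fmap = torch.full((1, 1, 3, 3), -10.0)
fmap[0, 0, 0, 2] = 10.0                 # hot pixel at row 0, column 2
coords = SpatialSoftmax2D()(fmap)
print(coords)                           # ~[[1.0, -1.0]] -> right edge, top edge
```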
class ResNet18SpatialEncoder(nn.Module):
"""Encode one camera view into a fixed-dimensional spatial-softmax embedding."""
def __init__(self, view_feature_dim=96):
super().__init__()
if view_feature_dim % 2 != 0:
raise ValueError("view_feature_dim must be even for spatial softmax features")
backbone = models.resnet18(weights=None)
if all(
hasattr(backbone, name)
for name in ("conv1", "bn1", "relu", "maxpool", "layer1", "layer2", "layer3", "layer4")
):
self.backbone = nn.Sequential(
backbone.conv1,
backbone.bn1,
backbone.relu,
backbone.maxpool,
backbone.layer1,
backbone.layer2,
backbone.layer3,
backbone.layer4,
)
feature_channels = 512
else:
children = list(backbone.children())
if len(children) < 1:
raise ValueError("resnet18 backbone must expose child modules")
truncated = children[:-2] if len(children) > 2 else children
self.backbone = nn.Sequential(*truncated)
with torch.no_grad():
dummy = torch.zeros(1, 3, 16, 16)
feature_channels = int(self.backbone(dummy).shape[1])
self.proj = nn.Conv2d(feature_channels, view_feature_dim // 2, kernel_size=1)
self.spatial_softmax = SpatialSoftmax2D()
self.output_dim = int(view_feature_dim)
def forward(self, pixels):
if pixels.ndim not in (4, 5):
raise ValueError(
f"ResNet18SpatialEncoder expects a 4D or 5D tensor, got rank {pixels.ndim}"
)
needs_unflatten = pixels.ndim == 5
if needs_unflatten:
batch, steps, channels, height, width = pixels.shape
pixels = rearrange(pixels, "b t c h w -> (b t) c h w")
features = self.backbone(pixels.float())
features = self.proj(features)
embeddings = self.spatial_softmax(features)
if needs_unflatten:
embeddings = rearrange(embeddings, "(b t) d -> b t d", b=batch, t=steps)
return embeddings
class LeWMMultiViewResNetBackbone(VLABackbone):
"""RoboIMI-side LeWM multiview ResNet spatial-softmax encoder."""
def __init__(
self,
view_feature_dim: int = 96,
num_views: int = 3,
view_encoder_mode: str = "shared",
camera_names: Sequence[str] = ("r_vis", "top", "front"),
checkpoint_path: str | Path | None = None,
) -> None:
super().__init__()
if view_encoder_mode not in {"shared", "separate"}:
raise ValueError(
f"view_encoder_mode must be 'shared' or 'separate', got {view_encoder_mode}"
)
self.view_feature_dim = int(view_feature_dim)
self.num_views = int(num_views)
self.view_encoder_mode = view_encoder_mode
self.camera_names = tuple(camera_names)
if len(self.camera_names) != self.num_views:
raise ValueError(
f"camera_names length({len(self.camera_names)}) must equal num_views({self.num_views})"
)
self.output_dim = self.view_feature_dim * self.num_views
self.joint_output_dim = self.output_dim
self.tokens_per_step = 1
if view_encoder_mode == "shared":
self.single_view_encoder = ResNet18SpatialEncoder(
view_feature_dim=view_feature_dim
)
self.view_encoders = None
else:
self.single_view_encoder = None
self.view_encoders = nn.ModuleList(
[
ResNet18SpatialEncoder(view_feature_dim=view_feature_dim)
for _ in range(num_views)
]
)
if checkpoint_path is not None:
self.load_lewm_checkpoint(checkpoint_path)
@staticmethod
def _unwrap_state_dict(payload: Mapping[str, Any]) -> Mapping[str, torch.Tensor]:
state_dict = payload.get("state_dict", payload)
if not isinstance(state_dict, Mapping):
raise TypeError("checkpoint payload must contain a mapping state_dict")
return state_dict
@staticmethod
def _extract_prefixed_state_dict(
state_dict: Mapping[str, torch.Tensor],
prefix: str,
) -> Dict[str, torch.Tensor]:
extracted = {
key[len(prefix):]: value
for key, value in state_dict.items()
if key.startswith(prefix)
}
if not extracted:
raise KeyError(f"checkpoint missing parameters with prefix {prefix!r}")
return extracted
def load_lewm_checkpoint(self, checkpoint_or_path: str | Path | Mapping[str, Any]) -> None:
if isinstance(checkpoint_or_path, (str, Path)):
payload = torch.load(Path(checkpoint_or_path), map_location="cpu", weights_only=False)
else:
payload = checkpoint_or_path
state_dict = self._unwrap_state_dict(payload)
encoder_state_dict = self._extract_prefixed_state_dict(state_dict, "model.encoder.")
self.load_state_dict(encoder_state_dict, strict=True)
def forward(self, images):
missing = [camera_name for camera_name in self.camera_names if camera_name not in images]
if missing:
raise ValueError(
f"image input missing required cameras. missing={missing}, expected={list(self.camera_names)}"
)
first_image = images[self.camera_names[0]]
batch_size, steps = first_image.shape[:2]
view_embeddings = []
if self.view_encoder_mode == "shared":
for camera_name in self.camera_names:
view_embeddings.append(self.single_view_encoder(images[camera_name]))
else:
for single_view_encoder, camera_name in zip(self.view_encoders, self.camera_names):
view_embeddings.append(single_view_encoder(images[camera_name]))
embeddings = torch.cat(view_embeddings, dim=-1)
return embeddings.reshape(batch_size, steps, self.output_dim)
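A shape walkthrough under the shipped defaults (three cameras, view_feature_dim=96): each view contributes a 96-dim embedding per timestep, concatenated to 288:

```python
import torch

backbone = LeWMMultiViewResNetBackbone(view_feature_dim=96, num_views=3,
                                       view_encoder_mode='separate')
images = {cam: torch.randn(2, 4, 3, 128, 128)
          for cam in ('r_vis', 'top', 'front')}   # (B, T, C, H, W) per camera
features = backbone(images)
assert features.shape == (2, 4, 288)              # (B, T, 96 * num_views)
```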
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout=0.0):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout),
)
def forward(self, x):
return self.net(x)
class Attention(nn.Module):
def __init__(self, dim, heads=8, dim_head=64, dropout=0.0):
super().__init__()
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
self.heads = heads
self.dropout = dropout
self.norm = nn.LayerNorm(dim)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
self.to_out = (
nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
if project_out
else nn.Identity()
)
def forward(self, x, causal=True):
x = self.norm(x)
drop = self.dropout if self.training else 0.0
qkv = self.to_qkv(x).chunk(3, dim=-1)
q, k, v = (rearrange(t, "b t (h d) -> b h t d", h=self.heads) for t in qkv)
out = F.scaled_dot_product_attention(q, k, v, dropout_p=drop, is_causal=causal)
out = rearrange(out, "b h t d -> b t (h d)")
return self.to_out(out)
class Block(nn.Module):
def __init__(self, dim, heads, dim_head, mlp_dim, dropout=0.0):
super().__init__()
self.attn = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
self.mlp = FeedForward(dim, mlp_dim, dropout=dropout)
self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
def forward(self, x):
x = x + self.attn(self.norm1(x))
x = x + self.mlp(self.norm2(x))
return x
class Transformer(nn.Module):
def __init__(
self,
input_dim,
hidden_dim,
output_dim,
depth,
heads,
dim_head,
mlp_dim,
dropout=0.0,
block_class=Block,
):
super().__init__()
self.norm = nn.LayerNorm(hidden_dim)
self.layers = nn.ModuleList([])
self.input_proj = (
nn.Linear(input_dim, hidden_dim)
if input_dim != hidden_dim
else nn.Identity()
)
self.cond_proj = (
nn.Linear(input_dim, hidden_dim)
if input_dim != hidden_dim
else nn.Identity()
)
self.output_proj = (
nn.Linear(hidden_dim, output_dim)
if hidden_dim != output_dim
else nn.Identity()
)
for _ in range(depth):
self.layers.append(block_class(hidden_dim, heads, dim_head, mlp_dim, dropout))
def forward(self, x, c=None):
x = self.input_proj(x)
if c is not None:
c = self.cond_proj(c)  # projected but not yet consumed by the blocks
for block in self.layers:
x = block(x)
x = self.norm(x)
return self.output_proj(x)
class QueryTokenPredictor(nn.Module):
"""History-only transformer predictor that decodes learned query tokens."""
def __init__(
self,
*,
num_frames,
query_offsets,
depth,
heads,
mlp_dim,
input_dim,
hidden_dim,
output_dim=None,
dim_head=64,
dropout=0.0,
emb_dropout=0.0,
):
super().__init__()
if num_frames <= 0:
raise ValueError(f"num_frames must be positive, got {num_frames}")
query_offsets = tuple(query_offsets)
if not query_offsets:
raise ValueError("query_offsets must contain at least one offset")
if any(offset <= 0 for offset in query_offsets):
raise ValueError(f"query_offsets must be positive, got {query_offsets}")
self.num_frames = int(num_frames)
self.query_offsets = query_offsets
self.num_query_tokens = len(query_offsets)
self.pos_embedding = nn.Parameter(
torch.randn(1, self.num_frames + self.num_query_tokens, input_dim)
)
self.query_tokens = nn.Parameter(
torch.randn(1, self.num_query_tokens, input_dim)
)
self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(
input_dim,
hidden_dim,
output_dim or input_dim,
depth,
heads,
dim_head,
mlp_dim,
dropout,
block_class=Block,
)
def forward(self, x):
if x.ndim != 3:
raise ValueError(
f"QueryTokenPredictor expects a 3D tensor, got rank {x.ndim}"
)
T = x.size(1)
if T > self.num_frames:
raise ValueError(
f"input sequence length {T} exceeds configured num_frames {self.num_frames}"
)
query_tokens = self.query_tokens.expand(x.size(0), -1, -1)
tokens = torch.cat([x, query_tokens], dim=1)
tokens = tokens + self.pos_embedding[:, : tokens.size(1)]
tokens = self.dropout(tokens)
tokens = self.transformer(tokens)
return tokens[:, -self.num_query_tokens :]
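Input/output contract of the predictor, sketched with a small configuration (hyperparameters here are illustrative, not the config's):

```python
import torch

predictor = QueryTokenPredictor(num_frames=3, query_offsets=[8],
                                depth=1, heads=4, mlp_dim=64,
                                input_dim=288, hidden_dim=288)
history = torch.randn(2, 3, 288)    # (B, T <= num_frames, D) history tokens
future = predictor(history)
assert future.shape == (2, 1, 288)  # one predicted token per query offset
```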
class LeWMProjectorMLP(nn.Module):
def __init__(
self,
input_dim: int = 288,
hidden_dim: int = 2048,
output_dim: int = 288,
) -> None:
super().__init__()
self.output_dim = int(output_dim)
self.net = nn.Sequential(
nn.Linear(int(input_dim), int(hidden_dim)),
nn.BatchNorm1d(int(hidden_dim)),
nn.GELU(),
nn.Linear(int(hidden_dim), self.output_dim),
)
def forward(self, x):
return self.net(x)
class SIGReg(nn.Module):
"""Sketch Isotropic Gaussian Regularizer, matching the original LeWM design."""
def __init__(self, knots: int = 17, num_proj: int = 1024) -> None:
super().__init__()
self.num_proj = int(num_proj)
t = torch.linspace(0, 3, int(knots), dtype=torch.float32)
dt = 3 / (int(knots) - 1)
weights = torch.full((int(knots),), 2 * dt, dtype=torch.float32)
weights[[0, -1]] = dt
window = torch.exp(-t.square() / 2.0)
self.register_buffer("t", t)
self.register_buffer("phi", window)
self.register_buffer("weights", weights * window)
def forward(self, proj: torch.Tensor) -> torch.Tensor:
"""
proj: (T, B, D)
"""
A = torch.randn(proj.size(-1), self.num_proj, device=proj.device)
A = A.div_(A.norm(p=2, dim=0))
x_t = (proj @ A).unsqueeze(-1) * self.t
err = (x_t.cos().mean(-3) - self.phi).square() + x_t.sin().mean(-3).square()
statistic = (err @ self.weights) * proj.size(-2)
return statistic.mean()
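Sanity check of the regularizer's direction: embeddings that look like an isotropic standard Gaussian score near zero (their projected characteristic function matches the exp(-t^2/2) target), while collapsed embeddings score high:

```python
import torch

reg = SIGReg(knots=17, num_proj=256)
gaussian = torch.randn(4, 512, 64)   # (T, B, D), roughly isotropic
collapsed = torch.ones(4, 512, 64)   # every embedding identical
print(reg(gaussian).item(), reg(collapsed).item())  # small vs. large
```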

View File

@@ -16,3 +16,23 @@ class IdentityActionEncoder(nn.Module):
def forward(self, action):
return action
class LeWMStateEncoder(nn.Module):
def __init__(
self,
input_dim: int = 16,
hidden_dim: int = 256,
output_dim: int = 64,
):
super().__init__()
self.output_dim = int(output_dim)
self.net = nn.Sequential(
nn.Linear(int(input_dim), int(hidden_dim)),
nn.LayerNorm(int(hidden_dim)),
nn.GELU(),
nn.Linear(int(hidden_dim), self.output_dim),
)
def forward(self, state):
return self.net(state)

View File

@@ -376,6 +376,29 @@ class _ForbiddenScheduler:
raise AssertionError('IMF inference should not use DDIM scheduler step')
class _StubFutureTokenPredictor(nn.Module):
def __init__(self, num_future_tokens=1):
super().__init__()
self.num_future_tokens = int(num_future_tokens)
self.calls = []
def forward(self, history_tokens):
self.calls.append(history_tokens.detach().clone())
summary = history_tokens.mean(dim=1, keepdim=True)
return summary.repeat(1, self.num_future_tokens, 1)
class _RecordingSigReg(nn.Module):
def __init__(self, value=0.5):
super().__init__()
self.value = float(value)
self.calls = []
def forward(self, embeddings):
self.calls.append(embeddings.detach().clone())
return embeddings.new_tensor(self.value)
def _make_images(batch_size, obs_horizon, per_camera_fill):
return {
name: torch.full((batch_size, obs_horizon, 1, 2, 2), fill_value=value, dtype=torch.float32)
@@ -501,6 +524,169 @@ class IMFVLAAgentTest(unittest.TestCase):
self.assertTrue(torch.allclose(head.calls[0]['t'], torch.ones(2)))
self.assertTrue(torch.allclose(head.calls[0]['cond'], expected_cond))
def test_predict_action_appends_lewm_future_tokens_to_history_conditioning(self):
agent_cls, agent_module = _load_imf_agent_class()
head = _RecordingLinearIMFHead()
future_predictor = _StubFutureTokenPredictor(num_future_tokens=1)
agent = agent_cls(
vision_backbone=_StubVisionBackbone(),
state_encoder=nn.Identity(),
action_encoder=nn.Identity(),
head=head,
action_dim=2,
obs_dim=1,
pred_horizon=3,
obs_horizon=2,
diffusion_steps=10,
inference_steps=1,
num_cams=len(_CAMERA_NAMES),
camera_names=_CAMERA_NAMES,
num_action_steps=2,
head_type='transformer',
extra_condition_tokens=1,
lewm_history_horizon=3,
lewm_query_offsets=[8],
lewm_predictor=future_predictor,
lewm_pred_projector=nn.Identity(),
lewm_loss_weight=0.5,
)
agent.infer_scheduler = _ForbiddenScheduler()
images = _make_images(
batch_size=1,
obs_horizon=2,
per_camera_fill={'r_vis': 10.0, 'top': 20.0, 'front': 30.0},
)
qpos = torch.tensor([[[1.0], [2.0]]], dtype=torch.float32)
lewm_images = _make_images(
batch_size=1,
obs_horizon=3,
per_camera_fill={'r_vis': 10.0, 'top': 20.0, 'front': 30.0},
)
lewm_qpos = torch.tensor([[[0.5], [1.5], [2.5]]], dtype=torch.float32)
initial_noise = torch.tensor(
[[[1.0, -1.0], [0.0, 2.0], [3.0, -2.0]]],
dtype=torch.float32,
)
with mock.patch.object(agent_module.torch, 'randn', return_value=initial_noise):
_ = agent.predict_action(
images,
qpos,
lewm_images=lewm_images,
lewm_proprioception=lewm_qpos,
)
expected_history = torch.tensor(
[[[10.0, 20.0, 30.0, 1.0], [10.0, 20.0, 30.0, 2.0]]],
dtype=torch.float32,
)
expected_future = torch.tensor([[[10.0, 20.0, 30.0, 1.5]]], dtype=torch.float32)
expected_cond = torch.cat([expected_history, expected_future], dim=1)
self.assertEqual(agent.condition_sequence_length, 3)
self.assertEqual(agent.per_step_cond_dim, 4)
self.assertEqual(len(head.calls), 1)
self.assertTrue(torch.allclose(head.calls[0]['cond'], expected_cond))
self.assertEqual(len(future_predictor.calls), 1)
def test_compute_loss_tracks_action_and_lewm_loss_breakdown(self):
agent_cls, agent_module = _load_imf_agent_class()
head = _RecordingLinearIMFHead()
future_predictor = _StubFutureTokenPredictor(num_future_tokens=1)
sigreg = _RecordingSigReg(value=0.75)
agent = agent_cls(
vision_backbone=_StubVisionBackbone(),
state_encoder=nn.Identity(),
action_encoder=nn.Identity(),
head=head,
action_dim=2,
obs_dim=1,
pred_horizon=3,
obs_horizon=2,
diffusion_steps=10,
inference_steps=1,
num_cams=len(_CAMERA_NAMES),
camera_names=_CAMERA_NAMES,
num_action_steps=2,
head_type='transformer',
extra_condition_tokens=1,
lewm_history_horizon=3,
lewm_query_offsets=[8],
lewm_predictor=future_predictor,
lewm_pred_projector=nn.Identity(),
lewm_sigreg=sigreg,
lewm_sigreg_weight=0.09,
lewm_loss_weight=0.25,
)
images = _make_images(
batch_size=1,
obs_horizon=2,
per_camera_fill={'r_vis': 1.0, 'top': 2.0, 'front': 3.0},
)
qpos = torch.tensor([[[0.25], [0.75]]], dtype=torch.float32)
actions = torch.tensor(
[[[1.0, -1.0], [0.5, 0.25], [-0.5, 1.5]]],
dtype=torch.float32,
)
lewm_images = _make_images(
batch_size=1,
obs_horizon=3,
per_camera_fill={'r_vis': 1.0, 'top': 2.0, 'front': 3.0},
)
lewm_qpos = torch.tensor([[[0.1], [0.2], [0.3]]], dtype=torch.float32)
lewm_future_images = _make_images(
batch_size=1,
obs_horizon=1,
per_camera_fill={'r_vis': 1.0, 'top': 2.0, 'front': 3.0},
)
lewm_future_qpos = torch.tensor([[[0.4]]], dtype=torch.float32)
noise = torch.tensor(
[[[0.2, -0.4], [0.1, 0.3], [0.5, -0.2]]],
dtype=torch.float32,
)
t_sample = torch.tensor([0.8], dtype=torch.float32)
r_sample = torch.tensor([0.25], dtype=torch.float32)
with mock.patch.object(agent_module.torch, 'randn_like', return_value=noise), \
mock.patch.object(agent_module.torch, 'rand', side_effect=[t_sample, r_sample]):
loss = agent.compute_loss(
{
'images': images,
'qpos': qpos,
'action': actions,
'lewm_images': lewm_images,
'lewm_qpos': lewm_qpos,
'lewm_future_images': lewm_future_images,
'lewm_future_qpos': lewm_future_qpos,
}
)
metrics = agent.get_last_loss_breakdown()
self.assertAlmostEqual(loss.item(), metrics['loss'], places=6)
self.assertIn('action_loss', metrics)
self.assertIn('lewm_pred_loss', metrics)
self.assertIn('lewm_sigreg_loss', metrics)
self.assertIn('lewm_loss', metrics)
self.assertAlmostEqual(metrics['lewm_sigreg_loss'], 0.75, places=6)
self.assertAlmostEqual(
metrics['lewm_loss'],
metrics['lewm_pred_loss'] + 0.09 * metrics['lewm_sigreg_loss'],
places=5,
)
self.assertAlmostEqual(
metrics['loss'],
metrics['action_loss'] + 0.25 * metrics['lewm_loss'],
places=5,
)
self.assertEqual(len(sigreg.calls), 1)
expected_lewm_history = torch.tensor(
[[[1.0, 2.0, 3.0, 0.1], [1.0, 2.0, 3.0, 0.2], [1.0, 2.0, 3.0, 0.3]]],
dtype=torch.float32,
)
torch.testing.assert_close(sigreg.calls[0], expected_lewm_history.transpose(0, 1))
def test_select_action_only_regenerates_when_action_queue_is_empty(self):
agent, _head, _agent_module = self._make_agent(pred_horizon=4, obs_horizon=2, num_action_steps=2)
observation = {
@@ -851,6 +1037,46 @@ class IMFVLAAgentTest(unittest.TestCase):
self.assertEqual(agent.vision_encoder.output_dim, 96)
self.assertEqual(agent.vision_encoder.eval_image_resize_shape, (256, 256))
def test_hydra_config_instantiates_lewm_resnet_query_imf_attnres_with_future_tokens(self):
cfg = _compose_cfg(
overrides=[
'agent=lewm_resnet_query_imf_attnres',
'agent.head.n_layer=1',
'agent.head.n_emb=16',
'agent.lewm_query_offsets=[8]',
]
)
self.assertEqual(cfg.agent._target_, 'roboimi.vla.agent_imf.IMFVLAAgent')
self.assertEqual(
cfg.agent.vision_backbone._target_,
'roboimi.vla.models.backbones.lewm_resnet_query_fusion.LeWMMultiViewResNetBackbone',
)
self.assertEqual(
cfg.agent.state_encoder._target_,
'roboimi.vla.modules.encoders.LeWMStateEncoder',
)
self.assertEqual(cfg.agent.head.cond_dim, 288)
self.assertEqual(cfg.agent.cond_projector.output_dim, 288)
self.assertEqual(cfg.agent.extra_condition_tokens, 1)
self.assertEqual(
cfg.agent.lewm_sigreg._target_,
'roboimi.vla.models.backbones.lewm_resnet_query_fusion.SIGReg',
)
self.assertAlmostEqual(cfg.agent.lewm_sigreg_weight, 0.09)
with _stub_optional_modules(include_imf_head=True):
agent = instantiate(cfg.agent)
self.assertEqual(agent.per_step_cond_dim, 288)
self.assertEqual(agent.condition_sequence_length, agent.obs_horizon + 1)
self.assertEqual(agent.noise_pred_net.constructor_kwargs['cond_dim'], 288)
self.assertEqual(
agent.noise_pred_net.constructor_kwargs['n_obs_steps'],
agent.condition_sequence_length,
)
self.assertIsNotNone(agent.lewm_sigreg)
def test_hydra_config_instantiates_resnet_imf_attnres_multitoken_with_sequence_length_three_times_obs_horizon(self):
cfg = _compose_cfg(

View File

@@ -79,3 +79,24 @@ class SimpleRobotDatasetImageLoadingTest(unittest.TestCase):
fake_cv2.resize.assert_not_called()
self.assertEqual(tuple(sample["observation.front"].shape), (2, 3, 8, 8))
def test_getitem_can_emit_lewm_history_and_future_observations(self):
with tempfile.TemporaryDirectory() as tmpdir:
dataset_dir = Path(tmpdir)
self._write_episode(dataset_dir)
dataset = SimpleRobotDataset(
dataset_dir,
obs_horizon=2,
pred_horizon=3,
camera_names=["front"],
image_resize_shape=None,
lewm_history_horizon=3,
lewm_query_offsets=[1, 2],
)
sample = dataset[1]
self.assertEqual(tuple(sample["lewm.observation.state"].shape), (3, 4))
self.assertEqual(tuple(sample["lewm.observation.front"].shape), (3, 3, 8, 8))
self.assertEqual(tuple(sample["lewm.future.state"].shape), (2, 4))
self.assertEqual(tuple(sample["lewm.future.front"].shape), (2, 3, 8, 8))

View File

@@ -114,6 +114,22 @@ class FakeAgent(nn.Module):
return {}
class RecordingAgent(FakeAgent):
def __init__(self):
super().__init__()
self.seen_inputs = []
def compute_loss(self, agent_input):
self.seen_inputs.append(agent_input)
return super().compute_loss(agent_input)
class ShapeMixedFakeAgent(FakeAgent):
def __init__(self):
super().__init__()
self.bias = nn.Parameter(torch.zeros(2))
class FakeSwanLab:
def __init__(self, init_error=None, log_errors=None, finish_error=None, image_errors=None):
self.init_error = init_error
@@ -388,6 +404,18 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):
'action_is_pad': torch.zeros(1, 1, dtype=torch.bool),
}
def _make_lewm_batch(self):
batch = self._make_batch()
batch.update(
{
'lewm.observation.front': torch.ones(1, 3, 2, 2),
'lewm.observation.state': torch.ones(1, 4),
'lewm.future.front': torch.full((1, 3, 2, 2), 2.0),
'lewm.future.state': torch.full((1, 4), 2.0),
}
)
return batch
def _loader_factory(self):
train_batch = self._make_batch()
val_batch = self._make_batch()
@@ -397,6 +425,15 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):
return factory
def _lewm_loader_factory(self):
train_batch = self._make_lewm_batch()
val_batch = self._make_lewm_batch()
def factory(_dataset, *, shuffle, **_kwargs):
return FakeLoader(train_batch if shuffle else val_batch)
return factory
def test_run_training_logs_metrics_and_checkpoint_paths_to_swanlab(self):
module = self._load_train_vla_module()
run_training = self._get_run_training(module)
@@ -487,6 +524,43 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):
self.assertTrue(final_payload['final/best_checkpoint_path'].endswith('checkpoints/vla_model_best.pt'))
self.assertEqual(fake_swanlab.finish_calls, 1)
def test_run_training_passes_lewm_history_and_future_batches_into_agent_input(self):
module = self._load_train_vla_module()
run_training = self._get_run_training(module)
cfg = self._make_cfg(use_swanlab=False)
cfg.train.max_steps = 1
cfg.train.save_freq = 100
agent = RecordingAgent()
def fake_instantiate(config_node, **_kwargs):
if config_node is cfg.data:
return FakeDataset()
if config_node is cfg.agent:
return agent
raise AssertionError(f'unexpected instantiate config: {config_node!r}')
with tempfile.TemporaryDirectory() as tempdir:
previous_cwd = os.getcwd()
try:
os.chdir(tempdir)
with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
mock.patch.object(module, 'DataLoader', side_effect=self._lewm_loader_factory()), \
mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=FakeScheduler()), \
mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: FakeProgressBar(iterable)), \
mock.patch.object(module.torch, 'save', return_value=None):
run_training(cfg)
finally:
os.chdir(previous_cwd)
self.assertGreaterEqual(len(agent.seen_inputs), 1)
first_input = agent.seen_inputs[0]
self.assertIn('lewm_images', first_input)
self.assertIn('lewm_qpos', first_input)
self.assertIn('lewm_future_images', first_input)
self.assertIn('lewm_future_qpos', first_input)
self.assertIn('front', first_input['lewm_images'])
self.assertIn('front', first_input['lewm_future_images'])
def test_run_training_skips_swanlab_when_disabled(self):
module = self._load_train_vla_module()
run_training = self._get_run_training(module)
@@ -668,6 +742,52 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):
self.assertTrue(final_payload['final/best_checkpoint_path'].endswith('checkpoints/vla_model_best.pt'))
self.assertFalse(any(path.endswith('checkpoints/vla_model_best.pt') for path in saved_paths))
def test_run_training_pretrained_ckpt_loads_matching_keys_even_if_some_shapes_mismatch(self):
module = self._load_train_vla_module()
run_training = self._get_run_training(module)
cfg = self._make_cfg(use_swanlab=False)
cfg.train.max_steps = 0
cfg.train.save_freq = 100
cfg.train.pretrained_ckpt = 'pretrained.pt'
agent = ShapeMixedFakeAgent()
def fake_instantiate(config_node, **_kwargs):
if config_node is cfg.data:
return FakeDataset()
if config_node is cfg.agent:
return agent
raise AssertionError(f'unexpected instantiate config: {config_node!r}')
def fake_torch_load(path, map_location=None):
del map_location
if Path(path).name != 'pretrained.pt':
raise AssertionError(f'unexpected load path: {path}')
return {
'model_state_dict': {
'weight': torch.tensor(3.0),
'bias': torch.tensor([1.0, 2.0, 3.0]),
},
'step': 123,
'loss': 0.5,
}
with tempfile.TemporaryDirectory() as tempdir:
previous_cwd = os.getcwd()
try:
os.chdir(tempdir)
Path('pretrained.pt').write_bytes(b'pretend')
with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
mock.patch.object(module, 'DataLoader', side_effect=self._loader_factory()), \
mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=FakeScheduler()), \
mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: FakeProgressBar(iterable)), \
mock.patch.object(module.torch, 'save', return_value=None), \
mock.patch.object(module.torch, 'load', side_effect=fake_torch_load):
run_training(cfg)
finally:
os.chdir(previous_cwd)
self.assertAlmostEqual(agent.weight.item(), 3.0, places=6)
def test_run_training_resume_ignores_best_checkpoint_without_rollout_metric(self):
module = self._load_train_vla_module()
run_training = self._get_run_training(module)