feat: add lewm-conditioned imf training and sigreg loss
This commit is contained in:
@@ -237,6 +237,32 @@ def build_training_optimizer(agent, lr, weight_decay):
|
||||
return AdamW(optim_groups, lr=lr, weight_decay=weight_decay)
|
||||
|
||||
|
||||
def load_state_dict_ignoring_shape_mismatches(module, incoming_state_dict):
|
||||
"""Load only checkpoint tensors whose keys exist locally and whose shapes match."""
|
||||
current_state_dict = module.state_dict()
|
||||
compatible_state_dict = {}
|
||||
mismatched_keys = []
|
||||
missing_keys = []
|
||||
|
||||
for key, value in incoming_state_dict.items():
|
||||
if key not in current_state_dict:
|
||||
missing_keys.append(key)
|
||||
continue
|
||||
if current_state_dict[key].shape != value.shape:
|
||||
mismatched_keys.append(key)
|
||||
continue
|
||||
compatible_state_dict[key] = value
|
||||
|
||||
merged_state_dict = dict(current_state_dict)
|
||||
merged_state_dict.update(compatible_state_dict)
|
||||
module.load_state_dict(merged_state_dict, strict=True)
|
||||
return {
|
||||
'loaded_keys': sorted(compatible_state_dict.keys()),
|
||||
'missing_keys': sorted(missing_keys),
|
||||
'mismatched_keys': sorted(mismatched_keys),
|
||||
}
|
||||
|
||||
|
||||
def _init_swanlab(cfg):
|
||||
"""按需初始化 SwanLab,并在缺少依赖或认证失败时快速失败。"""
|
||||
if not bool(cfg.train.get('use_swanlab', False)):
|
||||
@@ -509,18 +535,23 @@ def _run_training(cfg: DictConfig):
|
||||
try:
|
||||
checkpoint = torch.load(ckpt_path, map_location=cfg.train.device)
|
||||
|
||||
# 只加载模型权重(不加载 optimizer、scheduler)
|
||||
missing_keys, unexpected_keys = agent.load_state_dict(
|
||||
load_info = load_state_dict_ignoring_shape_mismatches(
|
||||
agent,
|
||||
checkpoint['model_state_dict'],
|
||||
strict=False # 允许部分加载(结构不完全匹配时)
|
||||
)
|
||||
|
||||
log.info(f"✅ [Finetune] 模型权重加载成功")
|
||||
|
||||
if missing_keys:
|
||||
log.warning(f"⚠️ [Finetune] 缺少的键 ({len(missing_keys)} 个): {missing_keys[:5]}...")
|
||||
if unexpected_keys:
|
||||
log.warning(f"⚠️ [Finetune] 多余的键 ({len(unexpected_keys)} 个): {unexpected_keys[:5]}...")
|
||||
if load_info['missing_keys']:
|
||||
log.warning(
|
||||
f"⚠️ [Finetune] checkpoint 中存在本地模型没有的键 ({len(load_info['missing_keys'])} 个): "
|
||||
f"{load_info['missing_keys'][:5]}..."
|
||||
)
|
||||
if load_info['mismatched_keys']:
|
||||
log.warning(
|
||||
f"⚠️ [Finetune] 因形状不匹配而跳过的键 ({len(load_info['mismatched_keys'])} 个): "
|
||||
f"{load_info['mismatched_keys'][:5]}..."
|
||||
)
|
||||
|
||||
log.info(f"📊 [Finetune] 预训练信息: 步骤={checkpoint.get('step', 'N/A')}, 损失={checkpoint.get('loss', 'N/A')}")
|
||||
log.info(f"📈 [Finetune] 使用新的训练配置(lr={cfg.train.lr}, max_steps={cfg.train.max_steps})")
|
||||
@@ -652,13 +683,35 @@ def _run_training(cfg: DictConfig):
|
||||
if key in batch_data:
|
||||
images[cam_name] = batch_data[key]
|
||||
|
||||
return {
|
||||
agent_input = {
|
||||
'images': images,
|
||||
'qpos': batch_data['observation.state'], # SimpleRobotDataset 使用 observation.state
|
||||
'action': batch_data['action'],
|
||||
'action_is_pad': batch_data.get('action_is_pad', None) # 传递padding mask
|
||||
}
|
||||
|
||||
lewm_images = {}
|
||||
lewm_future_images = {}
|
||||
for cam_name in cfg.data.camera_names:
|
||||
lewm_obs_key = f"lewm.observation.{cam_name}"
|
||||
if lewm_obs_key in batch_data:
|
||||
lewm_images[cam_name] = batch_data[lewm_obs_key]
|
||||
|
||||
lewm_future_key = f"lewm.future.{cam_name}"
|
||||
if lewm_future_key in batch_data:
|
||||
lewm_future_images[cam_name] = batch_data[lewm_future_key]
|
||||
|
||||
if 'lewm.observation.state' in batch_data:
|
||||
agent_input['lewm_qpos'] = batch_data['lewm.observation.state']
|
||||
if lewm_images:
|
||||
agent_input['lewm_images'] = lewm_images
|
||||
if 'lewm.future.state' in batch_data:
|
||||
agent_input['lewm_future_qpos'] = batch_data['lewm.future.state']
|
||||
if lewm_future_images:
|
||||
agent_input['lewm_future_images'] = lewm_future_images
|
||||
|
||||
return agent_input
|
||||
|
||||
def save_checkpoint(checkpoint_path: Path, step: int, loss_value, val_loss=None, rollout_avg_reward=None):
|
||||
agent_stats = agent.get_normalization_stats()
|
||||
torch.save({
|
||||
@@ -809,6 +862,15 @@ def _run_training(cfg: DictConfig):
|
||||
},
|
||||
step=step,
|
||||
)
|
||||
if hasattr(agent, 'get_last_loss_breakdown'):
|
||||
loss_breakdown = agent.get_last_loss_breakdown()
|
||||
extra_train_metrics = {
|
||||
f"train/{key}": value
|
||||
for key, value in loss_breakdown.items()
|
||||
if value is not None and key != 'loss'
|
||||
}
|
||||
if extra_train_metrics:
|
||||
_log_to_swanlab(swanlab_module, extra_train_metrics, step=step)
|
||||
|
||||
# =====================================================================
|
||||
# 检查点保存与验证
|
||||
|
||||
267
roboimi/scripts/refresh_experiment_suite_status.py
Executable file
267
roboimi/scripts/refresh_experiment_suite_status.py
Executable file
@@ -0,0 +1,267 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import datetime as dt
|
||||
import json
|
||||
import pathlib
|
||||
import re
|
||||
import shlex
|
||||
import subprocess
|
||||
from collections import defaultdict
|
||||
from typing import Any
|
||||
|
||||
|
||||
STEP_PAT = re.compile(r"步骤\s+(\d+)/(\d+)")
|
||||
BAR_PAT = re.compile(r"\|\s*(\d+)/(\d+)")
|
||||
|
||||
|
||||
def normalize_chunks(text: str):
|
||||
for part in re.split(r"[\r\n]+", text):
|
||||
part = part.strip()
|
||||
if part:
|
||||
yield part
|
||||
|
||||
|
||||
def parse_latest_line(text: str) -> tuple[str, int | None]:
|
||||
latest_line = ""
|
||||
latest_step = None
|
||||
for line in normalize_chunks(text):
|
||||
if "步骤" not in line and "训练中:" not in line:
|
||||
continue
|
||||
latest_line = line
|
||||
match = STEP_PAT.search(line) or BAR_PAT.search(line)
|
||||
if match:
|
||||
latest_step = int(match.group(1))
|
||||
return latest_line, latest_step
|
||||
|
||||
|
||||
def now_iso() -> str:
|
||||
return dt.datetime.now(
|
||||
dt.timezone(dt.timedelta(hours=8)),
|
||||
).isoformat(timespec="seconds")
|
||||
|
||||
|
||||
def run_cmd(cmd: list[str], check: bool = True) -> subprocess.CompletedProcess[str]:
|
||||
return subprocess.run(cmd, capture_output=True, text=True, check=check)
|
||||
|
||||
|
||||
def probe_local(run: dict[str, Any]) -> dict[str, Any]:
|
||||
pid = str(run["pid"])
|
||||
ps = run_cmd(["ps", "-p", pid, "-o", "pid=,stat=,etime=,args="], check=False)
|
||||
log_path = pathlib.Path(run["log_path"])
|
||||
latest_line = ""
|
||||
latest_step = None
|
||||
if log_path.exists():
|
||||
latest_line, latest_step = parse_latest_line(log_path.read_text(errors="replace"))
|
||||
return {
|
||||
"alive": bool(ps.stdout.strip()),
|
||||
"ps": ps.stdout.strip(),
|
||||
"log_exists": log_path.exists(),
|
||||
"latest_line": latest_line,
|
||||
"latest_step": latest_step,
|
||||
}
|
||||
|
||||
|
||||
def remote_probe(host: str, remote_user: str, runs: list[dict[str, Any]]) -> dict[str, dict[str, Any]]:
|
||||
payload = [
|
||||
{
|
||||
"run_id": run["run_id"],
|
||||
"pid": str(run["pid"]),
|
||||
"log_path": run["log_path"],
|
||||
}
|
||||
for run in runs
|
||||
]
|
||||
remote_py = r"""
|
||||
import json
|
||||
import pathlib
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
payload = json.loads(sys.argv[1])
|
||||
step_pat = re.compile(r"步骤\s+(\d+)/(\d+)")
|
||||
bar_pat = re.compile(r"\|\s*(\d+)/(\d+)")
|
||||
|
||||
def normalize_chunks(text):
|
||||
for part in re.split(r"[\r\n]+", text):
|
||||
part = part.strip()
|
||||
if part:
|
||||
yield part
|
||||
|
||||
def parse_latest_line(text):
|
||||
latest_line = ""
|
||||
latest_step = None
|
||||
for line in normalize_chunks(text):
|
||||
if "步骤" not in line and "训练中:" not in line:
|
||||
continue
|
||||
latest_line = line
|
||||
match = step_pat.search(line) or bar_pat.search(line)
|
||||
if match:
|
||||
latest_step = int(match.group(1))
|
||||
return latest_line, latest_step
|
||||
|
||||
out = {}
|
||||
for item in payload:
|
||||
try:
|
||||
ps = subprocess.run(
|
||||
["ps", "-p", item["pid"], "-o", "pid=,stat=,etime=,args="],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=False,
|
||||
)
|
||||
log_path = pathlib.Path(item["log_path"])
|
||||
latest_line = ""
|
||||
latest_step = None
|
||||
if log_path.exists():
|
||||
latest_line, latest_step = parse_latest_line(log_path.read_text(errors="replace"))
|
||||
out[item["run_id"]] = {
|
||||
"alive": bool(ps.stdout.strip()),
|
||||
"ps": ps.stdout.strip(),
|
||||
"log_exists": log_path.exists(),
|
||||
"latest_line": latest_line,
|
||||
"latest_step": latest_step,
|
||||
}
|
||||
except Exception as exc:
|
||||
out[item["run_id"]] = {
|
||||
"alive": False,
|
||||
"ps": "",
|
||||
"log_exists": False,
|
||||
"latest_line": "",
|
||||
"latest_step": None,
|
||||
"error": str(exc),
|
||||
}
|
||||
print(json.dumps(out, ensure_ascii=False))
|
||||
"""
|
||||
remote_target = host if "@" in host else f"{remote_user}@{host}"
|
||||
remote_cmd = (
|
||||
f"python3 -c {shlex.quote(remote_py)} "
|
||||
f"{shlex.quote(json.dumps(payload, ensure_ascii=False))}"
|
||||
)
|
||||
try:
|
||||
res = run_cmd(
|
||||
[
|
||||
"ssh",
|
||||
"-F",
|
||||
"/dev/null",
|
||||
"-o",
|
||||
"BatchMode=yes",
|
||||
"-o",
|
||||
"StrictHostKeyChecking=accept-new",
|
||||
remote_target,
|
||||
remote_cmd,
|
||||
]
|
||||
)
|
||||
return json.loads(res.stdout)
|
||||
except subprocess.CalledProcessError as exc:
|
||||
error = (exc.stderr or exc.stdout or str(exc)).strip()
|
||||
return {
|
||||
run["run_id"]: {
|
||||
"alive": False,
|
||||
"ps": "",
|
||||
"log_exists": False,
|
||||
"latest_line": "",
|
||||
"latest_step": None,
|
||||
"error": f"ssh_failed: {error}",
|
||||
}
|
||||
for run in runs
|
||||
}
|
||||
|
||||
|
||||
def append_notes(notes_path: pathlib.Path, snapshot_at: str, runs: list[dict[str, Any]]) -> None:
|
||||
lines = [f"\n## Status snapshot {snapshot_at}"]
|
||||
for run in runs:
|
||||
lines.append(
|
||||
(
|
||||
f"- {run['run_id']}: host={run['host']} gpu={run['gpu']} "
|
||||
f"alive={run.get('alive', False)} step={run.get('latest_step')} "
|
||||
f"pid={run['pid']}"
|
||||
)
|
||||
)
|
||||
if run.get("latest_line"):
|
||||
lines.append(f" - latest_line: `{run['latest_line']}`")
|
||||
if run.get("error"):
|
||||
lines.append(f" - error: `{run['error']}`")
|
||||
with notes_path.open("a", encoding="utf-8") as f:
|
||||
f.write("\n".join(lines) + "\n")
|
||||
|
||||
|
||||
def main() -> int:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("suite_dir", type=pathlib.Path)
|
||||
parser.add_argument("--remote-user", default="droid")
|
||||
parser.add_argument("--append-notes", action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
suite_dir = args.suite_dir.resolve()
|
||||
status_path = suite_dir / "status.json"
|
||||
notes_path = suite_dir / "notes.md"
|
||||
monitor_dir = suite_dir / "monitor_logs"
|
||||
monitor_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
status = json.loads(status_path.read_text(encoding="utf-8"))
|
||||
runs: list[dict[str, Any]] = status["runs"]
|
||||
snapshot_at = now_iso()
|
||||
|
||||
by_host: dict[str, list[dict[str, Any]]] = defaultdict(list)
|
||||
for run in runs:
|
||||
by_host[run["host"]].append(run)
|
||||
|
||||
results: dict[str, dict[str, Any]] = {}
|
||||
for host, host_runs in by_host.items():
|
||||
if host == "local":
|
||||
for run in host_runs:
|
||||
results[run["run_id"]] = probe_local(run)
|
||||
else:
|
||||
results.update(remote_probe(host, args.remote_user, host_runs))
|
||||
|
||||
alive_count = 0
|
||||
for run in runs:
|
||||
result = results[run["run_id"]]
|
||||
run["alive"] = result["alive"]
|
||||
run["ps"] = result["ps"]
|
||||
run["log_exists"] = result["log_exists"]
|
||||
run["latest_line"] = result["latest_line"]
|
||||
run["latest_step"] = result["latest_step"]
|
||||
run["last_verified_at"] = snapshot_at
|
||||
if "error" in result:
|
||||
run["error"] = result["error"]
|
||||
else:
|
||||
run.pop("error", None)
|
||||
run["status"] = "running" if result["alive"] else "stopped"
|
||||
alive_count += int(result["alive"])
|
||||
|
||||
status["last_verified_at"] = snapshot_at
|
||||
status["alive_count"] = alive_count
|
||||
status["total_runs"] = len(runs)
|
||||
|
||||
status_path.write_text(json.dumps(status, ensure_ascii=False, indent=2) + "\n", encoding="utf-8")
|
||||
|
||||
snapshot_payload = {
|
||||
"suite_name": status.get("suite_name"),
|
||||
"snapshot_at": snapshot_at,
|
||||
"alive_count": alive_count,
|
||||
"total_runs": len(runs),
|
||||
"runs": {run["run_id"]: results[run["run_id"]] for run in runs},
|
||||
}
|
||||
timestamp_slug = snapshot_at.replace(":", "").replace("+", "_").replace("-", "")
|
||||
snapshot_path = monitor_dir / f"status-{timestamp_slug}.json"
|
||||
snapshot_path.write_text(
|
||||
json.dumps(snapshot_payload, ensure_ascii=False, indent=2) + "\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
if args.append_notes:
|
||||
append_notes(notes_path, snapshot_at, runs)
|
||||
|
||||
print(json.dumps(snapshot_payload, ensure_ascii=False, indent=2))
|
||||
print(f"\nstatus_json={status_path}")
|
||||
print(f"snapshot_json={snapshot_path}")
|
||||
if args.append_notes:
|
||||
print(f"notes_md={notes_path}")
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
@@ -28,6 +28,7 @@ class VLAAgent(nn.Module):
|
||||
num_action_steps=8, # 每次推理实际执行多少步动作
|
||||
head_type='unet', # Policy head类型: 'unet' 或 'transformer'
|
||||
cond_projector=None, # 可选:将视觉+状态条件投影到head期望维度
|
||||
extra_condition_tokens: int = 0, # 可选:额外条件token数量(例如未来预测embedding)
|
||||
):
|
||||
super().__init__()
|
||||
# 保存参数
|
||||
@@ -39,6 +40,9 @@ class VLAAgent(nn.Module):
|
||||
self.num_action_steps = num_action_steps
|
||||
self.inference_steps = inference_steps
|
||||
self.head_type = head_type # 'unet' 或 'transformer'
|
||||
self.extra_condition_tokens = int(extra_condition_tokens)
|
||||
if self.extra_condition_tokens < 0:
|
||||
raise ValueError(f"extra_condition_tokens must be >= 0, got {self.extra_condition_tokens}")
|
||||
agent_camera_names = tuple(camera_names) if camera_names is not None else None
|
||||
backbone_camera_names = getattr(vision_backbone, 'camera_names', None)
|
||||
backbone_camera_names = tuple(backbone_camera_names) if backbone_camera_names is not None else None
|
||||
@@ -71,11 +75,14 @@ class VLAAgent(nn.Module):
|
||||
stats=dataset_stats,
|
||||
normalization_type=normalization_type
|
||||
)
|
||||
self.dataset_stats = dataset_stats
|
||||
|
||||
self.vision_encoder = vision_backbone
|
||||
self.state_encoder = state_encoder
|
||||
if self.camera_names is not None:
|
||||
self.vision_encoder.camera_names = self.camera_names
|
||||
self.condition_tokens_per_step = int(getattr(self.vision_encoder, 'tokens_per_step', 1))
|
||||
self.state_feature_dim = int(getattr(self.state_encoder, 'output_dim', obs_dim))
|
||||
joint_vision_dim = getattr(self.vision_encoder, 'joint_output_dim', None)
|
||||
if joint_vision_dim is not None:
|
||||
per_token_vision_dim = int(joint_vision_dim)
|
||||
@@ -87,8 +94,11 @@ class VLAAgent(nn.Module):
|
||||
else:
|
||||
per_token_vision_dim = int(single_cam_feat_dim) * int(num_cams)
|
||||
|
||||
self.condition_sequence_length = self.obs_horizon * self.condition_tokens_per_step
|
||||
self.raw_per_step_cond_dim = per_token_vision_dim + obs_dim
|
||||
self.history_condition_sequence_length = self.obs_horizon * self.condition_tokens_per_step
|
||||
self.condition_sequence_length = (
|
||||
self.history_condition_sequence_length + self.extra_condition_tokens
|
||||
)
|
||||
self.raw_per_step_cond_dim = per_token_vision_dim + self.state_feature_dim
|
||||
if cond_projector is None:
|
||||
self.cond_projector = None
|
||||
self.per_step_cond_dim = self.raw_per_step_cond_dim
|
||||
@@ -139,7 +149,6 @@ class VLAAgent(nn.Module):
|
||||
global_cond_dim=self.global_cond_dim
|
||||
)
|
||||
|
||||
self.state_encoder = state_encoder
|
||||
self.action_encoder = action_encoder
|
||||
|
||||
# 初始化队列(用于在线推理)
|
||||
@@ -220,7 +229,7 @@ class VLAAgent(nn.Module):
|
||||
f"条件维度不匹配: got {cond.shape[-1]}, expected {self.per_step_cond_dim}"
|
||||
)
|
||||
cond = cond.reshape(batch_size, obs_steps * token_count, self.per_step_cond_dim)
|
||||
expected_length = self.condition_sequence_length
|
||||
expected_length = self.history_condition_sequence_length
|
||||
if cond.shape[1] != expected_length:
|
||||
raise RuntimeError(
|
||||
f"条件序列长度不匹配: got {cond.shape[1]}, expected {expected_length}"
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import nullcontext
|
||||
from typing import Dict, Optional
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Mapping, Optional, Sequence
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from roboimi.vla.agent import VLAAgent
|
||||
@@ -15,14 +18,59 @@ except ImportError: # pragma: no cover
|
||||
|
||||
|
||||
class IMFVLAAgent(VLAAgent):
|
||||
def __init__(self, *args, inference_steps: int = 1, **kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
inference_steps: int = 1,
|
||||
lewm_history_horizon: Optional[int] = None,
|
||||
lewm_query_offsets: Optional[Sequence[int]] = None,
|
||||
lewm_predictor: Optional[nn.Module] = None,
|
||||
lewm_pred_projector: Optional[nn.Module] = None,
|
||||
lewm_sigreg: Optional[nn.Module] = None,
|
||||
lewm_sigreg_weight: float = 0.09,
|
||||
lewm_loss_weight: float = 0.0,
|
||||
lewm_pretrained_ckpt: Optional[str | Path | Mapping[str, Any]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
if inference_steps != 1:
|
||||
raise ValueError(
|
||||
'IMFVLAAgent only supports one-step inference; '
|
||||
f'inference_steps must be 1, got {inference_steps}.'
|
||||
)
|
||||
lewm_query_offsets = tuple(int(offset) for offset in (lewm_query_offsets or ()))
|
||||
inferred_extra_condition_tokens = len(lewm_query_offsets) if lewm_query_offsets else 0
|
||||
kwargs.setdefault('extra_condition_tokens', inferred_extra_condition_tokens)
|
||||
self.__dict__['lewm_history_horizon'] = int(lewm_history_horizon or kwargs.get('obs_horizon', 1))
|
||||
self.__dict__['lewm_query_offsets'] = lewm_query_offsets
|
||||
self.__dict__['lewm_predictor'] = lewm_predictor
|
||||
self.__dict__['lewm_pred_projector'] = lewm_pred_projector or nn.Identity()
|
||||
self.__dict__['lewm_sigreg'] = lewm_sigreg
|
||||
self.__dict__['lewm_sigreg_weight'] = float(lewm_sigreg_weight)
|
||||
self.__dict__['lewm_loss_weight'] = float(lewm_loss_weight)
|
||||
self.__dict__['_last_loss_breakdown'] = {
|
||||
'action_loss': 0.0,
|
||||
'lewm_pred_loss': 0.0,
|
||||
'lewm_sigreg_loss': 0.0,
|
||||
'lewm_loss': 0.0,
|
||||
'loss': 0.0,
|
||||
}
|
||||
super().__init__(*args, inference_steps=inference_steps, **kwargs)
|
||||
self.inference_steps = 1
|
||||
self.lewm_history_horizon = int(lewm_history_horizon or self.obs_horizon)
|
||||
self.lewm_predictor = lewm_predictor
|
||||
self.lewm_pred_projector = lewm_pred_projector or nn.Identity()
|
||||
self.lewm_sigreg = lewm_sigreg
|
||||
self.lewm_sigreg_weight = float(lewm_sigreg_weight)
|
||||
if self.lewm_predictor is None and self.extra_condition_tokens > 0:
|
||||
raise ValueError(
|
||||
'extra_condition_tokens > 0 requires lewm_predictor to be provided'
|
||||
)
|
||||
if self.lewm_predictor is not None and self.extra_condition_tokens != inferred_extra_condition_tokens:
|
||||
raise ValueError(
|
||||
'extra_condition_tokens must equal len(lewm_query_offsets) when lewm_predictor is enabled'
|
||||
)
|
||||
if lewm_pretrained_ckpt is not None:
|
||||
self.load_lewm_pretrained_components(lewm_pretrained_ckpt)
|
||||
|
||||
@staticmethod
|
||||
def _broadcast_batch_time(value: torch.Tensor, reference: torch.Tensor) -> torch.Tensor:
|
||||
@@ -119,14 +167,181 @@ class IMFVLAAgent(VLAAgent):
|
||||
delta = self._broadcast_batch_time(t - r, z_t)
|
||||
return z_t - delta * u
|
||||
|
||||
def _normalize_qpos_for_lewm(self, qpos: torch.Tensor) -> torch.Tensor:
|
||||
if not self.normalization.enabled:
|
||||
return qpos
|
||||
|
||||
qpos_mean = getattr(self.normalization, 'qpos_mean', None)
|
||||
qpos_std = getattr(self.normalization, 'qpos_std', None)
|
||||
if qpos_mean is not None and qpos_std is not None:
|
||||
return (qpos - qpos_mean) / qpos_std
|
||||
if isinstance(self.dataset_stats, dict):
|
||||
mean = self.dataset_stats.get('qpos_mean', None)
|
||||
std = self.dataset_stats.get('qpos_std', None)
|
||||
if mean is not None and std is not None:
|
||||
mean = torch.as_tensor(mean, dtype=qpos.dtype, device=qpos.device)
|
||||
std = torch.as_tensor(std, dtype=qpos.dtype, device=qpos.device)
|
||||
return (qpos - mean) / std
|
||||
return self.normalization.normalize_qpos(qpos)
|
||||
|
||||
def _project_lewm_future_tokens(self, predicted_tokens: torch.Tensor) -> torch.Tensor:
|
||||
if predicted_tokens.ndim != 3:
|
||||
raise ValueError(
|
||||
f"expected predicted future tokens to be 3D, got rank {predicted_tokens.ndim}"
|
||||
)
|
||||
batch_size, token_count, token_dim = predicted_tokens.shape
|
||||
flattened = predicted_tokens.reshape(batch_size * token_count, token_dim)
|
||||
projected = self.lewm_pred_projector(flattened)
|
||||
if projected.ndim != 2:
|
||||
raise ValueError(
|
||||
f"expected lewm_pred_projector to return rank-2 tensors, got rank {projected.ndim}"
|
||||
)
|
||||
return projected.reshape(batch_size, token_count, projected.shape[-1])
|
||||
|
||||
@staticmethod
|
||||
def _load_checkpoint_payload(
|
||||
checkpoint_or_path: str | Path | Mapping[str, Any],
|
||||
) -> Mapping[str, torch.Tensor]:
|
||||
if isinstance(checkpoint_or_path, (str, Path)):
|
||||
payload = torch.load(Path(checkpoint_or_path), map_location='cpu', weights_only=False)
|
||||
else:
|
||||
payload = checkpoint_or_path
|
||||
state_dict = payload.get('state_dict', payload)
|
||||
if not isinstance(state_dict, Mapping):
|
||||
raise TypeError('checkpoint payload must contain a mapping state_dict')
|
||||
return state_dict
|
||||
|
||||
@staticmethod
|
||||
def _extract_prefixed_state_dict(
|
||||
state_dict: Mapping[str, torch.Tensor],
|
||||
prefix: str,
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
extracted = {
|
||||
key[len(prefix):]: value
|
||||
for key, value in state_dict.items()
|
||||
if key.startswith(prefix)
|
||||
}
|
||||
if not extracted:
|
||||
raise KeyError(f"checkpoint missing parameters with prefix {prefix!r}")
|
||||
return extracted
|
||||
|
||||
@staticmethod
|
||||
def _adapt_and_load_state_dict(
|
||||
module: nn.Module,
|
||||
incoming_state_dict: Mapping[str, torch.Tensor],
|
||||
*,
|
||||
query_key: str = 'query_tokens',
|
||||
pos_key: str = 'pos_embedding',
|
||||
) -> None:
|
||||
current_state_dict = module.state_dict()
|
||||
adapted_state_dict = dict(current_state_dict)
|
||||
for key, current_tensor in current_state_dict.items():
|
||||
if key not in incoming_state_dict:
|
||||
continue
|
||||
source_tensor = incoming_state_dict[key]
|
||||
if source_tensor.shape == current_tensor.shape:
|
||||
adapted_state_dict[key] = source_tensor
|
||||
continue
|
||||
|
||||
if key in {query_key, pos_key} and source_tensor.ndim == current_tensor.ndim:
|
||||
patched = current_tensor.clone()
|
||||
if key == query_key:
|
||||
copy_count = min(source_tensor.shape[1], current_tensor.shape[1])
|
||||
patched[:, :copy_count, ...] = source_tensor[:, :copy_count, ...]
|
||||
if copy_count < current_tensor.shape[1] and copy_count > 0:
|
||||
patched[:, copy_count:, ...] = source_tensor[:, copy_count - 1:copy_count, ...]
|
||||
else:
|
||||
copy_count = min(source_tensor.shape[1], current_tensor.shape[1])
|
||||
patched[:, :copy_count, ...] = source_tensor[:, :copy_count, ...]
|
||||
if copy_count < current_tensor.shape[1] and copy_count > 0:
|
||||
patched[:, copy_count:, ...] = source_tensor[:, copy_count - 1:copy_count, ...]
|
||||
adapted_state_dict[key] = patched
|
||||
|
||||
module.load_state_dict(adapted_state_dict, strict=True)
|
||||
|
||||
def load_lewm_pretrained_components(
|
||||
self,
|
||||
checkpoint_or_path: str | Path | Mapping[str, Any],
|
||||
) -> None:
|
||||
state_dict = self._load_checkpoint_payload(checkpoint_or_path)
|
||||
|
||||
if hasattr(self.vision_encoder, 'load_lewm_checkpoint'):
|
||||
self.vision_encoder.load_lewm_checkpoint({'state_dict': state_dict})
|
||||
else:
|
||||
vision_state_dict = self._extract_prefixed_state_dict(state_dict, 'model.encoder.')
|
||||
self.vision_encoder.load_state_dict(vision_state_dict, strict=True)
|
||||
|
||||
state_encoder_state_dict = self._extract_prefixed_state_dict(state_dict, 'model.state_encoder.')
|
||||
self.state_encoder.load_state_dict(state_encoder_state_dict, strict=True)
|
||||
|
||||
projector_state_dict = self._extract_prefixed_state_dict(state_dict, 'model.projector.proj.')
|
||||
mapped_projector_state_dict = {
|
||||
f'linear.{key}': value
|
||||
for key, value in projector_state_dict.items()
|
||||
}
|
||||
self.cond_projector.load_state_dict(mapped_projector_state_dict, strict=True)
|
||||
|
||||
if self.lewm_predictor is not None:
|
||||
predictor_state_dict = self._extract_prefixed_state_dict(state_dict, 'model.predictor.')
|
||||
self._adapt_and_load_state_dict(self.lewm_predictor, predictor_state_dict)
|
||||
|
||||
if self.lewm_pred_projector is not None:
|
||||
pred_projector_state_dict = self._extract_prefixed_state_dict(state_dict, 'model.pred_proj.')
|
||||
self.lewm_pred_projector.load_state_dict(pred_projector_state_dict, strict=True)
|
||||
|
||||
def _build_full_condition(
|
||||
self,
|
||||
images,
|
||||
proprioception,
|
||||
*,
|
||||
lewm_images=None,
|
||||
lewm_proprioception=None,
|
||||
):
|
||||
normalized_proprioception = self.normalization.normalize_qpos(proprioception)
|
||||
history_cond = self._build_cond(images, normalized_proprioception)
|
||||
predicted_future_tokens = None
|
||||
lewm_history_cond = None
|
||||
|
||||
if self.lewm_predictor is None:
|
||||
return history_cond, predicted_future_tokens, lewm_history_cond
|
||||
|
||||
lewm_images = lewm_images if lewm_images is not None else images
|
||||
lewm_proprioception = (
|
||||
lewm_proprioception if lewm_proprioception is not None else proprioception
|
||||
)
|
||||
lewm_history_cond = self._build_cond(
|
||||
lewm_images,
|
||||
self._normalize_qpos_for_lewm(lewm_proprioception),
|
||||
)
|
||||
predicted_future_tokens = self.lewm_predictor(lewm_history_cond)
|
||||
predicted_future_tokens = self._project_lewm_future_tokens(predicted_future_tokens)
|
||||
cond = torch.cat([history_cond, predicted_future_tokens], dim=1)
|
||||
if cond.shape[1] != self.condition_sequence_length:
|
||||
raise RuntimeError(
|
||||
f"完整条件序列长度不匹配: got {cond.shape[1]}, expected {self.condition_sequence_length}"
|
||||
)
|
||||
if cond.shape[-1] != self.per_step_cond_dim:
|
||||
raise RuntimeError(
|
||||
f"完整条件维度不匹配: got {cond.shape[-1]}, expected {self.per_step_cond_dim}"
|
||||
)
|
||||
return cond, predicted_future_tokens, lewm_history_cond
|
||||
|
||||
@staticmethod
|
||||
def _masked_mse_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
|
||||
return F.mse_loss(pred, target)
|
||||
|
||||
def compute_loss(self, batch):
|
||||
actions, states, images = batch['action'], batch['qpos'], batch['images']
|
||||
action_is_pad = batch.get('action_is_pad', None)
|
||||
batch_size = actions.shape[0]
|
||||
|
||||
states = self.normalization.normalize_qpos(states)
|
||||
actions = self.normalization.normalize_action(actions)
|
||||
cond = self._build_cond(images, states)
|
||||
cond, predicted_future_tokens, lewm_history_cond = self._build_full_condition(
|
||||
images,
|
||||
states,
|
||||
lewm_images=batch.get('lewm_images', None),
|
||||
lewm_proprioception=batch.get('lewm_qpos', None),
|
||||
)
|
||||
|
||||
x = actions
|
||||
e = torch.randn_like(x)
|
||||
@@ -146,16 +361,103 @@ class IMFVLAAgent(VLAAgent):
|
||||
if action_is_pad is not None:
|
||||
mask = (~action_is_pad).unsqueeze(-1).to(loss.dtype)
|
||||
valid_count = mask.sum() * loss.shape[-1]
|
||||
loss = (loss * mask).sum() / valid_count.clamp_min(1.0)
|
||||
action_loss = (loss * mask).sum() / valid_count.clamp_min(1.0)
|
||||
else:
|
||||
loss = loss.mean()
|
||||
return loss
|
||||
action_loss = loss.mean()
|
||||
|
||||
lewm_pred_loss = torch.zeros((), device=action_loss.device, dtype=action_loss.dtype)
|
||||
lewm_sigreg_loss = torch.zeros((), device=action_loss.device, dtype=action_loss.dtype)
|
||||
if predicted_future_tokens is not None:
|
||||
lewm_future_images = batch.get('lewm_future_images', None)
|
||||
lewm_future_qpos = batch.get('lewm_future_qpos', None)
|
||||
if lewm_future_images is not None and lewm_future_qpos is not None:
|
||||
future_target = self._build_cond(
|
||||
lewm_future_images,
|
||||
self._normalize_qpos_for_lewm(lewm_future_qpos),
|
||||
)
|
||||
lewm_pred_loss = self._masked_mse_loss(predicted_future_tokens, future_target)
|
||||
if self.lewm_sigreg is not None and lewm_history_cond is not None:
|
||||
lewm_sigreg_loss = self.lewm_sigreg(lewm_history_cond.transpose(0, 1))
|
||||
|
||||
lewm_loss = lewm_pred_loss + self.lewm_sigreg_weight * lewm_sigreg_loss
|
||||
total_loss = action_loss + self.lewm_loss_weight * lewm_loss
|
||||
self._last_loss_breakdown = {
|
||||
'action_loss': float(action_loss.detach().item()),
|
||||
'lewm_pred_loss': float(lewm_pred_loss.detach().item()),
|
||||
'lewm_sigreg_loss': float(lewm_sigreg_loss.detach().item()),
|
||||
'lewm_loss': float(lewm_loss.detach().item()),
|
||||
'loss': float(total_loss.detach().item()),
|
||||
}
|
||||
return total_loss
|
||||
|
||||
def get_last_loss_breakdown(self) -> Dict[str, float]:
|
||||
return dict(self._last_loss_breakdown)
|
||||
|
||||
def reset(self):
|
||||
super().reset()
|
||||
if self.lewm_predictor is not None:
|
||||
self._queues['lewm_qpos'] = deque(maxlen=self.lewm_history_horizon)
|
||||
self._queues['lewm_images'] = deque(maxlen=self.lewm_history_horizon)
|
||||
|
||||
def _populate_queues(self, observation: Dict[str, torch.Tensor]) -> None:
|
||||
super()._populate_queues(observation)
|
||||
if self.lewm_predictor is None:
|
||||
return
|
||||
if 'qpos' in observation:
|
||||
self._queues['lewm_qpos'].append(observation['qpos'].clone())
|
||||
if 'images' in observation:
|
||||
ordered_images = self._order_images(observation['images'])
|
||||
self._queues['lewm_images'].append({k: v.clone() for k, v in ordered_images.items()})
|
||||
|
||||
def _prepare_observation_batch(self) -> Dict[str, torch.Tensor]:
|
||||
batch = super()._prepare_observation_batch()
|
||||
if self.lewm_predictor is None:
|
||||
return batch
|
||||
|
||||
qpos_list = list(self._queues['lewm_qpos'])
|
||||
images_list = list(self._queues['lewm_images'])
|
||||
if len(qpos_list) == 0 or len(images_list) == 0:
|
||||
raise ValueError("LeWM 观测队列为空,请先调用 _populate_queues 添加观测")
|
||||
while len(qpos_list) < self.lewm_history_horizon:
|
||||
qpos_list.append(qpos_list[-1])
|
||||
while len(images_list) < self.lewm_history_horizon:
|
||||
images_list.append(images_list[-1])
|
||||
|
||||
batch['lewm_qpos'] = torch.stack(qpos_list, dim=0).unsqueeze(0)
|
||||
batch['lewm_images'] = {}
|
||||
camera_names = self.camera_names if self.camera_names is not None else tuple(sorted(images_list[0].keys()))
|
||||
for cam_name in camera_names:
|
||||
batch['lewm_images'][cam_name] = torch.stack(
|
||||
[img[cam_name] for img in images_list],
|
||||
dim=0,
|
||||
).unsqueeze(0)
|
||||
return batch
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_action(self, images, proprioception):
|
||||
def predict_action_chunk(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
|
||||
return self.predict_action(
|
||||
batch['images'],
|
||||
batch['qpos'],
|
||||
lewm_images=batch.get('lewm_images', None),
|
||||
lewm_proprioception=batch.get('lewm_qpos', None),
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_action(
|
||||
self,
|
||||
images,
|
||||
proprioception,
|
||||
*,
|
||||
lewm_images=None,
|
||||
lewm_proprioception=None,
|
||||
):
|
||||
batch_size = proprioception.shape[0]
|
||||
proprioception = self.normalization.normalize_qpos(proprioception)
|
||||
cond = self._build_cond(images, proprioception)
|
||||
cond, _predicted_future_tokens, _lewm_history_cond = self._build_full_condition(
|
||||
images,
|
||||
proprioception,
|
||||
lewm_images=lewm_images,
|
||||
lewm_proprioception=lewm_proprioception,
|
||||
)
|
||||
z_t = torch.randn((batch_size, self.pred_horizon, self.action_dim), device=cond.device, dtype=cond.dtype)
|
||||
action = self._sample_one_step(z_t, cond=cond)
|
||||
return self.normalization.denormalize_action(action)
|
||||
|
||||
77
roboimi/vla/conf/agent/lewm_resnet_query_imf_attnres.yaml
Normal file
77
roboimi/vla/conf/agent/lewm_resnet_query_imf_attnres.yaml
Normal file
@@ -0,0 +1,77 @@
|
||||
# @package agent
|
||||
defaults:
|
||||
- /backbone@vision_backbone: lewm_resnet_query_fusion
|
||||
- /modules@state_encoder: lewm_state_encoder
|
||||
- /modules@action_encoder: identity_action_encoder
|
||||
- /modules@cond_projector: linear_condition_projector
|
||||
- /head: imf_transformer1d
|
||||
- _self_
|
||||
|
||||
_target_: roboimi.vla.agent_imf.IMFVLAAgent
|
||||
|
||||
action_dim: 16
|
||||
obs_dim: 16
|
||||
normalization_type: "min_max"
|
||||
pred_horizon: 8
|
||||
obs_horizon: 2
|
||||
num_action_steps: 8
|
||||
camera_names: ${data.camera_names}
|
||||
num_cams: 3
|
||||
|
||||
vision_backbone:
|
||||
camera_names: ${agent.camera_names}
|
||||
num_views: ${agent.num_cams}
|
||||
|
||||
cond_projector:
|
||||
output_dim: 288
|
||||
|
||||
lewm_history_horizon: 3
|
||||
lewm_query_offsets: [8]
|
||||
extra_condition_tokens: ${len:${agent.lewm_query_offsets}}
|
||||
lewm_loss_weight: 1.0
|
||||
lewm_sigreg_weight: 0.09
|
||||
lewm_pretrained_ckpt: null
|
||||
|
||||
lewm_sigreg:
|
||||
_target_: roboimi.vla.models.backbones.lewm_resnet_query_fusion.SIGReg
|
||||
knots: 17
|
||||
num_proj: 1024
|
||||
|
||||
lewm_predictor:
|
||||
_target_: roboimi.vla.models.backbones.lewm_resnet_query_fusion.QueryTokenPredictor
|
||||
num_frames: ${agent.lewm_history_horizon}
|
||||
query_offsets: ${agent.lewm_query_offsets}
|
||||
input_dim: ${agent.cond_projector.output_dim}
|
||||
hidden_dim: ${agent.cond_projector.output_dim}
|
||||
output_dim: ${agent.cond_projector.output_dim}
|
||||
depth: 6
|
||||
heads: 16
|
||||
mlp_dim: 2048
|
||||
dim_head: 64
|
||||
dropout: 0.1
|
||||
emb_dropout: 0.0
|
||||
|
||||
lewm_pred_projector:
|
||||
_target_: roboimi.vla.models.backbones.lewm_resnet_query_fusion.LeWMProjectorMLP
|
||||
input_dim: ${agent.cond_projector.output_dim}
|
||||
hidden_dim: 2048
|
||||
output_dim: ${agent.cond_projector.output_dim}
|
||||
|
||||
diffusion_steps: 100
|
||||
inference_steps: 1
|
||||
head_type: "transformer"
|
||||
|
||||
head:
|
||||
input_dim: ${agent.action_dim}
|
||||
output_dim: ${agent.action_dim}
|
||||
horizon: ${agent.pred_horizon}
|
||||
n_obs_steps: ${agent.obs_horizon}
|
||||
cond_dim: 288
|
||||
n_emb: 384
|
||||
causal_attn: false
|
||||
time_as_cond: true
|
||||
obs_as_cond: true
|
||||
n_cond_layers: 0
|
||||
backbone_type: attnres_full
|
||||
n_head: 1
|
||||
n_kv_head: 1
|
||||
7
roboimi/vla/conf/backbone/lewm_resnet_query_fusion.yaml
Normal file
7
roboimi/vla/conf/backbone/lewm_resnet_query_fusion.yaml
Normal file
@@ -0,0 +1,7 @@
|
||||
_target_: roboimi.vla.models.backbones.lewm_resnet_query_fusion.LeWMMultiViewResNetBackbone
|
||||
|
||||
view_feature_dim: 96
|
||||
num_views: ${agent.num_cams}
|
||||
view_encoder_mode: separate
|
||||
camera_names: ${agent.camera_names}
|
||||
checkpoint_path: null
|
||||
5
roboimi/vla/conf/modules/lewm_state_encoder.yaml
Normal file
5
roboimi/vla/conf/modules/lewm_state_encoder.yaml
Normal file
@@ -0,0 +1,5 @@
|
||||
_target_: roboimi.vla.modules.encoders.LeWMStateEncoder
|
||||
|
||||
input_dim: ${agent.obs_dim}
|
||||
hidden_dim: 256
|
||||
output_dim: 64
|
||||
@@ -24,6 +24,8 @@ class SimpleRobotDataset(Dataset):
|
||||
camera_names: List[str] = None,
|
||||
image_resize_shape: Optional[Sequence[int]] = (224, 224),
|
||||
max_open_files: int = 64,
|
||||
lewm_history_horizon: Optional[int] = None,
|
||||
lewm_query_offsets: Optional[Sequence[int]] = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
@@ -42,6 +44,13 @@ class SimpleRobotDataset(Dataset):
|
||||
self.obs_horizon = obs_horizon
|
||||
self.pred_horizon = pred_horizon
|
||||
self.camera_names = camera_names or []
|
||||
self.lewm_history_horizon = (
|
||||
int(lewm_history_horizon) if lewm_history_horizon is not None else None
|
||||
)
|
||||
self.lewm_query_offsets = (
|
||||
tuple(int(offset) for offset in lewm_query_offsets)
|
||||
if lewm_query_offsets is not None else ()
|
||||
)
|
||||
self.image_resize_shape = (
|
||||
tuple(int(v) for v in image_resize_shape)
|
||||
if image_resize_shape is not None else None
|
||||
@@ -220,6 +229,60 @@ class SimpleRobotDataset(Dataset):
|
||||
for cam_name in self.camera_names:
|
||||
result[f"observation.{cam_name}"] = torch.stack(observations[f"observation.{cam_name}"])
|
||||
|
||||
if self.lewm_history_horizon is not None and self.lewm_history_horizon > 0:
|
||||
lewm_observations = {
|
||||
"state": [],
|
||||
}
|
||||
for cam_name in self.camera_names:
|
||||
lewm_observations[f"observation.{cam_name}"] = []
|
||||
|
||||
for delta in range(-self.lewm_history_horizon + 1, 1):
|
||||
target_idx = idx + delta
|
||||
if ep_start <= target_idx <= ep_end:
|
||||
target_frame = self._load_frame(target_idx)
|
||||
else:
|
||||
boundary_idx = ep_start if target_idx < ep_start else ep_end
|
||||
target_frame = self._load_frame(boundary_idx)
|
||||
|
||||
lewm_observations["state"].append(target_frame["observation.state"])
|
||||
for cam_name in self.camera_names:
|
||||
lewm_observations[f"observation.{cam_name}"].append(
|
||||
target_frame[f"observation.{cam_name}"]
|
||||
)
|
||||
|
||||
result["lewm.observation.state"] = torch.stack(lewm_observations["state"])
|
||||
for cam_name in self.camera_names:
|
||||
result[f"lewm.observation.{cam_name}"] = torch.stack(
|
||||
lewm_observations[f"observation.{cam_name}"]
|
||||
)
|
||||
|
||||
if self.lewm_query_offsets:
|
||||
lewm_future = {
|
||||
"state": [],
|
||||
}
|
||||
for cam_name in self.camera_names:
|
||||
lewm_future[f"observation.{cam_name}"] = []
|
||||
|
||||
for offset in self.lewm_query_offsets:
|
||||
target_idx = idx + offset
|
||||
if ep_start <= target_idx <= ep_end:
|
||||
target_frame = self._load_frame(target_idx)
|
||||
else:
|
||||
boundary_idx = ep_start if target_idx < ep_start else ep_end
|
||||
target_frame = self._load_frame(boundary_idx)
|
||||
|
||||
lewm_future["state"].append(target_frame["observation.state"])
|
||||
for cam_name in self.camera_names:
|
||||
lewm_future[f"observation.{cam_name}"].append(
|
||||
target_frame[f"observation.{cam_name}"]
|
||||
)
|
||||
|
||||
result["lewm.future.state"] = torch.stack(lewm_future["state"])
|
||||
for cam_name in self.camera_names:
|
||||
result[f"lewm.future.{cam_name}"] = torch.stack(
|
||||
lewm_future[f"observation.{cam_name}"]
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@property
|
||||
|
||||
@@ -1,5 +1,14 @@
|
||||
# Backbone models
|
||||
__all__ = ["LEWMViTBackbone", "ResNetBackbone", "ResNetDiffusionBackbone", "SigLIP2DiffusionBackbone"]
|
||||
__all__ = [
|
||||
"LEWMViTBackbone",
|
||||
"LeWMMultiViewResNetBackbone",
|
||||
"QueryTokenPredictor",
|
||||
"LeWMProjectorMLP",
|
||||
"SIGReg",
|
||||
"ResNetBackbone",
|
||||
"ResNetDiffusionBackbone",
|
||||
"SigLIP2DiffusionBackbone",
|
||||
]
|
||||
|
||||
|
||||
def __getattr__(name):
|
||||
@@ -9,6 +18,19 @@ def __getattr__(name):
|
||||
if name == "SigLIP2DiffusionBackbone":
|
||||
from .siglip2_diffusion_backbone import SigLIP2DiffusionBackbone
|
||||
return SigLIP2DiffusionBackbone
|
||||
if name in {"LeWMMultiViewResNetBackbone", "QueryTokenPredictor", "LeWMProjectorMLP", "SIGReg"}:
|
||||
from .lewm_resnet_query_fusion import (
|
||||
LeWMMultiViewResNetBackbone,
|
||||
QueryTokenPredictor,
|
||||
LeWMProjectorMLP,
|
||||
SIGReg,
|
||||
)
|
||||
return {
|
||||
"LeWMMultiViewResNetBackbone": LeWMMultiViewResNetBackbone,
|
||||
"QueryTokenPredictor": QueryTokenPredictor,
|
||||
"LeWMProjectorMLP": LeWMProjectorMLP,
|
||||
"SIGReg": SIGReg,
|
||||
}[name]
|
||||
if name in {"ResNetBackbone", "ResNetDiffusionBackbone"}:
|
||||
from .resnet_diffusion import ResNetDiffusionBackbone
|
||||
return ResNetDiffusionBackbone
|
||||
|
||||
409
roboimi/vla/models/backbones/lewm_resnet_query_fusion.py
Normal file
409
roboimi/vla/models/backbones/lewm_resnet_query_fusion.py
Normal file
@@ -0,0 +1,409 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Mapping, Optional, Sequence
|
||||
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from torchvision import models
|
||||
|
||||
from roboimi.vla.core.interfaces import VLABackbone
|
||||
|
||||
|
||||
class SpatialSoftmax2D(nn.Module):
|
||||
"""Convert a feature map into expected 2D keypoint coordinates per channel."""
|
||||
|
||||
def forward(self, feature_map):
|
||||
if feature_map.ndim != 4:
|
||||
raise ValueError(
|
||||
f"SpatialSoftmax2D expects a 4D tensor, got rank {feature_map.ndim}"
|
||||
)
|
||||
|
||||
batch, channels, height, width = feature_map.shape
|
||||
scores = feature_map.reshape(batch, channels, height * width)
|
||||
attention = F.softmax(scores, dim=-1)
|
||||
|
||||
ys = torch.linspace(-1.0, 1.0, height, device=feature_map.device, dtype=feature_map.dtype)
|
||||
xs = torch.linspace(-1.0, 1.0, width, device=feature_map.device, dtype=feature_map.dtype)
|
||||
grid_y, grid_x = torch.meshgrid(ys, xs, indexing="ij")
|
||||
grid_x = grid_x.reshape(1, 1, height * width)
|
||||
grid_y = grid_y.reshape(1, 1, height * width)
|
||||
|
||||
expected_x = (attention * grid_x).sum(dim=-1)
|
||||
expected_y = (attention * grid_y).sum(dim=-1)
|
||||
return torch.cat([expected_x, expected_y], dim=-1)
|
||||
|
||||
|
||||
class ResNet18SpatialEncoder(nn.Module):
|
||||
"""Encode one camera view into a fixed-dimensional spatial-softmax embedding."""
|
||||
|
||||
def __init__(self, view_feature_dim=96):
|
||||
super().__init__()
|
||||
if view_feature_dim % 2 != 0:
|
||||
raise ValueError("view_feature_dim must be even for spatial softmax features")
|
||||
|
||||
backbone = models.resnet18(weights=None)
|
||||
if all(
|
||||
hasattr(backbone, name)
|
||||
for name in ("conv1", "bn1", "relu", "maxpool", "layer1", "layer2", "layer3", "layer4")
|
||||
):
|
||||
self.backbone = nn.Sequential(
|
||||
backbone.conv1,
|
||||
backbone.bn1,
|
||||
backbone.relu,
|
||||
backbone.maxpool,
|
||||
backbone.layer1,
|
||||
backbone.layer2,
|
||||
backbone.layer3,
|
||||
backbone.layer4,
|
||||
)
|
||||
feature_channels = 512
|
||||
else:
|
||||
children = list(backbone.children())
|
||||
if len(children) < 1:
|
||||
raise ValueError("resnet18 backbone must expose child modules")
|
||||
truncated = children[:-2] if len(children) > 2 else children
|
||||
self.backbone = nn.Sequential(*truncated)
|
||||
with torch.no_grad():
|
||||
dummy = torch.zeros(1, 3, 16, 16)
|
||||
feature_channels = int(self.backbone(dummy).shape[1])
|
||||
|
||||
self.proj = nn.Conv2d(feature_channels, view_feature_dim // 2, kernel_size=1)
|
||||
self.spatial_softmax = SpatialSoftmax2D()
|
||||
self.output_dim = int(view_feature_dim)
|
||||
|
||||
def forward(self, pixels):
|
||||
if pixels.ndim not in (4, 5):
|
||||
raise ValueError(
|
||||
f"ResNet18SpatialEncoder expects a 4D or 5D tensor, got rank {pixels.ndim}"
|
||||
)
|
||||
|
||||
needs_unflatten = pixels.ndim == 5
|
||||
if needs_unflatten:
|
||||
batch, steps, channels, height, width = pixels.shape
|
||||
pixels = rearrange(pixels, "b t c h w -> (b t) c h w")
|
||||
|
||||
features = self.backbone(pixels.float())
|
||||
features = self.proj(features)
|
||||
embeddings = self.spatial_softmax(features)
|
||||
|
||||
if needs_unflatten:
|
||||
embeddings = rearrange(embeddings, "(b t) d -> b t d", b=batch, t=steps)
|
||||
return embeddings
|
||||
|
||||
|
||||
class LeWMMultiViewResNetBackbone(VLABackbone):
|
||||
"""RoboIMI-side LeWM multiview ResNet spatial-softmax encoder."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
view_feature_dim: int = 96,
|
||||
num_views: int = 3,
|
||||
view_encoder_mode: str = "shared",
|
||||
camera_names: Sequence[str] = ("r_vis", "top", "front"),
|
||||
checkpoint_path: str | Path | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
if view_encoder_mode not in {"shared", "separate"}:
|
||||
raise ValueError(
|
||||
f"view_encoder_mode must be 'shared' or 'separate', got {view_encoder_mode}"
|
||||
)
|
||||
|
||||
self.view_feature_dim = int(view_feature_dim)
|
||||
self.num_views = int(num_views)
|
||||
self.view_encoder_mode = view_encoder_mode
|
||||
self.camera_names = tuple(camera_names)
|
||||
if len(self.camera_names) != self.num_views:
|
||||
raise ValueError(
|
||||
f"camera_names length({len(self.camera_names)}) must equal num_views({self.num_views})"
|
||||
)
|
||||
self.output_dim = self.view_feature_dim * self.num_views
|
||||
self.joint_output_dim = self.output_dim
|
||||
self.tokens_per_step = 1
|
||||
|
||||
if view_encoder_mode == "shared":
|
||||
self.single_view_encoder = ResNet18SpatialEncoder(
|
||||
view_feature_dim=view_feature_dim
|
||||
)
|
||||
self.view_encoders = None
|
||||
else:
|
||||
self.single_view_encoder = None
|
||||
self.view_encoders = nn.ModuleList(
|
||||
[
|
||||
ResNet18SpatialEncoder(view_feature_dim=view_feature_dim)
|
||||
for _ in range(num_views)
|
||||
]
|
||||
)
|
||||
|
||||
if checkpoint_path is not None:
|
||||
self.load_lewm_checkpoint(checkpoint_path)
|
||||
|
||||
@staticmethod
|
||||
def _unwrap_state_dict(payload: Mapping[str, Any]) -> Mapping[str, torch.Tensor]:
|
||||
state_dict = payload.get("state_dict", payload)
|
||||
if not isinstance(state_dict, Mapping):
|
||||
raise TypeError("checkpoint payload must contain a mapping state_dict")
|
||||
return state_dict
|
||||
|
||||
@staticmethod
|
||||
def _extract_prefixed_state_dict(
|
||||
state_dict: Mapping[str, torch.Tensor],
|
||||
prefix: str,
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
extracted = {
|
||||
key[len(prefix):]: value
|
||||
for key, value in state_dict.items()
|
||||
if key.startswith(prefix)
|
||||
}
|
||||
if not extracted:
|
||||
raise KeyError(f"checkpoint missing parameters with prefix {prefix!r}")
|
||||
return extracted
|
||||
|
||||
def load_lewm_checkpoint(self, checkpoint_or_path: str | Path | Mapping[str, Any]) -> None:
|
||||
if isinstance(checkpoint_or_path, (str, Path)):
|
||||
payload = torch.load(Path(checkpoint_or_path), map_location="cpu", weights_only=False)
|
||||
else:
|
||||
payload = checkpoint_or_path
|
||||
state_dict = self._unwrap_state_dict(payload)
|
||||
encoder_state_dict = self._extract_prefixed_state_dict(state_dict, "model.encoder.")
|
||||
self.load_state_dict(encoder_state_dict, strict=True)
|
||||
|
||||
def forward(self, images):
|
||||
missing = [camera_name for camera_name in self.camera_names if camera_name not in images]
|
||||
if missing:
|
||||
raise ValueError(
|
||||
f"image input missing required cameras. missing={missing}, expected={list(self.camera_names)}"
|
||||
)
|
||||
|
||||
first_image = images[self.camera_names[0]]
|
||||
batch_size, steps = first_image.shape[:2]
|
||||
view_embeddings = []
|
||||
if self.view_encoder_mode == "shared":
|
||||
for camera_name in self.camera_names:
|
||||
view_embeddings.append(self.single_view_encoder(images[camera_name]))
|
||||
else:
|
||||
for single_view_encoder, camera_name in zip(self.view_encoders, self.camera_names):
|
||||
view_embeddings.append(single_view_encoder(images[camera_name]))
|
||||
|
||||
embeddings = torch.cat(view_embeddings, dim=-1)
|
||||
return embeddings.reshape(batch_size, steps, self.output_dim)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, hidden_dim, dropout=0.0):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
nn.Dropout(dropout),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, heads=8, dim_head=64, dropout=0.0):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
project_out = not (heads == 1 and dim_head == dim)
|
||||
self.heads = heads
|
||||
self.dropout = dropout
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
|
||||
self.to_out = (
|
||||
nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
|
||||
if project_out
|
||||
else nn.Identity()
|
||||
)
|
||||
|
||||
def forward(self, x, causal=True):
|
||||
x = self.norm(x)
|
||||
drop = self.dropout if self.training else 0.0
|
||||
qkv = self.to_qkv(x).chunk(3, dim=-1)
|
||||
q, k, v = (rearrange(t, "b t (h d) -> b h t d", h=self.heads) for t in qkv)
|
||||
out = F.scaled_dot_product_attention(q, k, v, dropout_p=drop, is_causal=causal)
|
||||
out = rearrange(out, "b h t d -> b t (h d)")
|
||||
return self.to_out(out)
|
||||
|
||||
|
||||
class Block(nn.Module):
|
||||
def __init__(self, dim, heads, dim_head, mlp_dim, dropout=0.0):
|
||||
super().__init__()
|
||||
self.attn = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
|
||||
self.mlp = FeedForward(dim, mlp_dim, dropout=dropout)
|
||||
self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||
|
||||
def forward(self, x):
|
||||
x = x + self.attn(self.norm1(x))
|
||||
x = x + self.mlp(self.norm2(x))
|
||||
return x
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_dim,
|
||||
hidden_dim,
|
||||
output_dim,
|
||||
depth,
|
||||
heads,
|
||||
dim_head,
|
||||
mlp_dim,
|
||||
dropout=0.0,
|
||||
block_class=Block,
|
||||
):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(hidden_dim)
|
||||
self.layers = nn.ModuleList([])
|
||||
|
||||
self.input_proj = (
|
||||
nn.Linear(input_dim, hidden_dim)
|
||||
if input_dim != hidden_dim
|
||||
else nn.Identity()
|
||||
)
|
||||
self.cond_proj = (
|
||||
nn.Linear(input_dim, hidden_dim)
|
||||
if input_dim != hidden_dim
|
||||
else nn.Identity()
|
||||
)
|
||||
self.output_proj = (
|
||||
nn.Linear(hidden_dim, output_dim)
|
||||
if hidden_dim != output_dim
|
||||
else nn.Identity()
|
||||
)
|
||||
|
||||
for _ in range(depth):
|
||||
self.layers.append(block_class(hidden_dim, heads, dim_head, mlp_dim, dropout))
|
||||
|
||||
def forward(self, x, c=None):
|
||||
x = self.input_proj(x)
|
||||
if c is not None:
|
||||
c = self.cond_proj(c)
|
||||
for block in self.layers:
|
||||
x = block(x)
|
||||
x = self.norm(x)
|
||||
return self.output_proj(x)
|
||||
|
||||
|
||||
class QueryTokenPredictor(nn.Module):
|
||||
"""History-only transformer predictor that decodes learned query tokens."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
num_frames,
|
||||
query_offsets,
|
||||
depth,
|
||||
heads,
|
||||
mlp_dim,
|
||||
input_dim,
|
||||
hidden_dim,
|
||||
output_dim=None,
|
||||
dim_head=64,
|
||||
dropout=0.0,
|
||||
emb_dropout=0.0,
|
||||
):
|
||||
super().__init__()
|
||||
if num_frames <= 0:
|
||||
raise ValueError(f"num_frames must be positive, got {num_frames}")
|
||||
|
||||
query_offsets = tuple(query_offsets)
|
||||
if not query_offsets:
|
||||
raise ValueError("query_offsets must contain at least one offset")
|
||||
if any(offset <= 0 for offset in query_offsets):
|
||||
raise ValueError(f"query_offsets must be positive, got {query_offsets}")
|
||||
|
||||
self.num_frames = int(num_frames)
|
||||
self.query_offsets = query_offsets
|
||||
self.num_query_tokens = len(query_offsets)
|
||||
self.pos_embedding = nn.Parameter(
|
||||
torch.randn(1, self.num_frames + self.num_query_tokens, input_dim)
|
||||
)
|
||||
self.query_tokens = nn.Parameter(
|
||||
torch.randn(1, self.num_query_tokens, input_dim)
|
||||
)
|
||||
self.dropout = nn.Dropout(emb_dropout)
|
||||
self.transformer = Transformer(
|
||||
input_dim,
|
||||
hidden_dim,
|
||||
output_dim or input_dim,
|
||||
depth,
|
||||
heads,
|
||||
dim_head,
|
||||
mlp_dim,
|
||||
dropout,
|
||||
block_class=Block,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
if x.ndim != 3:
|
||||
raise ValueError(
|
||||
f"QueryTokenPredictor expects a 3D tensor, got rank {x.ndim}"
|
||||
)
|
||||
|
||||
T = x.size(1)
|
||||
if T > self.num_frames:
|
||||
raise ValueError(
|
||||
f"input sequence length {T} exceeds configured num_frames {self.num_frames}"
|
||||
)
|
||||
|
||||
query_tokens = self.query_tokens.expand(x.size(0), -1, -1)
|
||||
tokens = torch.cat([x, query_tokens], dim=1)
|
||||
tokens = tokens + self.pos_embedding[:, : tokens.size(1)]
|
||||
tokens = self.dropout(tokens)
|
||||
tokens = self.transformer(tokens)
|
||||
return tokens[:, -self.num_query_tokens :]
|
||||
|
||||
|
||||
class LeWMProjectorMLP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_dim: int = 288,
|
||||
hidden_dim: int = 2048,
|
||||
output_dim: int = 288,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.output_dim = int(output_dim)
|
||||
self.net = nn.Sequential(
|
||||
nn.Linear(int(input_dim), int(hidden_dim)),
|
||||
nn.BatchNorm1d(int(hidden_dim)),
|
||||
nn.GELU(),
|
||||
nn.Linear(int(hidden_dim), self.output_dim),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
class SIGReg(nn.Module):
|
||||
"""Sketch Isotropic Gaussian Regularizer, matching the original LeWM design."""
|
||||
|
||||
def __init__(self, knots: int = 17, num_proj: int = 1024) -> None:
|
||||
super().__init__()
|
||||
self.num_proj = int(num_proj)
|
||||
t = torch.linspace(0, 3, int(knots), dtype=torch.float32)
|
||||
dt = 3 / (int(knots) - 1)
|
||||
weights = torch.full((int(knots),), 2 * dt, dtype=torch.float32)
|
||||
weights[[0, -1]] = dt
|
||||
window = torch.exp(-t.square() / 2.0)
|
||||
self.register_buffer("t", t)
|
||||
self.register_buffer("phi", window)
|
||||
self.register_buffer("weights", weights * window)
|
||||
|
||||
def forward(self, proj: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
proj: (T, B, D)
|
||||
"""
|
||||
A = torch.randn(proj.size(-1), self.num_proj, device=proj.device)
|
||||
A = A.div_(A.norm(p=2, dim=0))
|
||||
x_t = (proj @ A).unsqueeze(-1) * self.t
|
||||
err = (x_t.cos().mean(-3) - self.phi).square() + x_t.sin().mean(-3).square()
|
||||
statistic = (err @ self.weights) * proj.size(-2)
|
||||
return statistic.mean()
|
||||
@@ -16,3 +16,23 @@ class IdentityActionEncoder(nn.Module):
|
||||
|
||||
def forward(self, action):
|
||||
return action
|
||||
|
||||
|
||||
class LeWMStateEncoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_dim: int = 16,
|
||||
hidden_dim: int = 256,
|
||||
output_dim: int = 64,
|
||||
):
|
||||
super().__init__()
|
||||
self.output_dim = int(output_dim)
|
||||
self.net = nn.Sequential(
|
||||
nn.Linear(int(input_dim), int(hidden_dim)),
|
||||
nn.LayerNorm(int(hidden_dim)),
|
||||
nn.GELU(),
|
||||
nn.Linear(int(hidden_dim), self.output_dim),
|
||||
)
|
||||
|
||||
def forward(self, state):
|
||||
return self.net(state)
|
||||
|
||||
@@ -376,6 +376,29 @@ class _ForbiddenScheduler:
|
||||
raise AssertionError('IMF inference should not use DDIM scheduler step')
|
||||
|
||||
|
||||
class _StubFutureTokenPredictor(nn.Module):
|
||||
def __init__(self, num_future_tokens=1):
|
||||
super().__init__()
|
||||
self.num_future_tokens = int(num_future_tokens)
|
||||
self.calls = []
|
||||
|
||||
def forward(self, history_tokens):
|
||||
self.calls.append(history_tokens.detach().clone())
|
||||
summary = history_tokens.mean(dim=1, keepdim=True)
|
||||
return summary.repeat(1, self.num_future_tokens, 1)
|
||||
|
||||
|
||||
class _RecordingSigReg(nn.Module):
|
||||
def __init__(self, value=0.5):
|
||||
super().__init__()
|
||||
self.value = float(value)
|
||||
self.calls = []
|
||||
|
||||
def forward(self, embeddings):
|
||||
self.calls.append(embeddings.detach().clone())
|
||||
return embeddings.new_tensor(self.value)
|
||||
|
||||
|
||||
def _make_images(batch_size, obs_horizon, per_camera_fill):
|
||||
return {
|
||||
name: torch.full((batch_size, obs_horizon, 1, 2, 2), fill_value=value, dtype=torch.float32)
|
||||
@@ -501,6 +524,169 @@ class IMFVLAAgentTest(unittest.TestCase):
|
||||
self.assertTrue(torch.allclose(head.calls[0]['t'], torch.ones(2)))
|
||||
self.assertTrue(torch.allclose(head.calls[0]['cond'], expected_cond))
|
||||
|
||||
def test_predict_action_appends_lewm_future_tokens_to_history_conditioning(self):
|
||||
agent_cls, agent_module = _load_imf_agent_class()
|
||||
head = _RecordingLinearIMFHead()
|
||||
future_predictor = _StubFutureTokenPredictor(num_future_tokens=1)
|
||||
agent = agent_cls(
|
||||
vision_backbone=_StubVisionBackbone(),
|
||||
state_encoder=nn.Identity(),
|
||||
action_encoder=nn.Identity(),
|
||||
head=head,
|
||||
action_dim=2,
|
||||
obs_dim=1,
|
||||
pred_horizon=3,
|
||||
obs_horizon=2,
|
||||
diffusion_steps=10,
|
||||
inference_steps=1,
|
||||
num_cams=len(_CAMERA_NAMES),
|
||||
camera_names=_CAMERA_NAMES,
|
||||
num_action_steps=2,
|
||||
head_type='transformer',
|
||||
extra_condition_tokens=1,
|
||||
lewm_history_horizon=3,
|
||||
lewm_query_offsets=[8],
|
||||
lewm_predictor=future_predictor,
|
||||
lewm_pred_projector=nn.Identity(),
|
||||
lewm_loss_weight=0.5,
|
||||
)
|
||||
agent.infer_scheduler = _ForbiddenScheduler()
|
||||
|
||||
images = _make_images(
|
||||
batch_size=1,
|
||||
obs_horizon=2,
|
||||
per_camera_fill={'r_vis': 10.0, 'top': 20.0, 'front': 30.0},
|
||||
)
|
||||
qpos = torch.tensor([[[1.0], [2.0]]], dtype=torch.float32)
|
||||
lewm_images = _make_images(
|
||||
batch_size=1,
|
||||
obs_horizon=3,
|
||||
per_camera_fill={'r_vis': 10.0, 'top': 20.0, 'front': 30.0},
|
||||
)
|
||||
lewm_qpos = torch.tensor([[[0.5], [1.5], [2.5]]], dtype=torch.float32)
|
||||
initial_noise = torch.tensor(
|
||||
[[[1.0, -1.0], [0.0, 2.0], [3.0, -2.0]]],
|
||||
dtype=torch.float32,
|
||||
)
|
||||
|
||||
with mock.patch.object(agent_module.torch, 'randn', return_value=initial_noise):
|
||||
_ = agent.predict_action(
|
||||
images,
|
||||
qpos,
|
||||
lewm_images=lewm_images,
|
||||
lewm_proprioception=lewm_qpos,
|
||||
)
|
||||
|
||||
expected_history = torch.tensor(
|
||||
[[[10.0, 20.0, 30.0, 1.0], [10.0, 20.0, 30.0, 2.0]]],
|
||||
dtype=torch.float32,
|
||||
)
|
||||
expected_future = torch.tensor([[[10.0, 20.0, 30.0, 1.5]]], dtype=torch.float32)
|
||||
expected_cond = torch.cat([expected_history, expected_future], dim=1)
|
||||
|
||||
self.assertEqual(agent.condition_sequence_length, 3)
|
||||
self.assertEqual(agent.per_step_cond_dim, 4)
|
||||
self.assertEqual(len(head.calls), 1)
|
||||
self.assertTrue(torch.allclose(head.calls[0]['cond'], expected_cond))
|
||||
self.assertEqual(len(future_predictor.calls), 1)
|
||||
|
||||
def test_compute_loss_tracks_action_and_lewm_loss_breakdown(self):
|
||||
agent_cls, agent_module = _load_imf_agent_class()
|
||||
head = _RecordingLinearIMFHead()
|
||||
future_predictor = _StubFutureTokenPredictor(num_future_tokens=1)
|
||||
sigreg = _RecordingSigReg(value=0.75)
|
||||
agent = agent_cls(
|
||||
vision_backbone=_StubVisionBackbone(),
|
||||
state_encoder=nn.Identity(),
|
||||
action_encoder=nn.Identity(),
|
||||
head=head,
|
||||
action_dim=2,
|
||||
obs_dim=1,
|
||||
pred_horizon=3,
|
||||
obs_horizon=2,
|
||||
diffusion_steps=10,
|
||||
inference_steps=1,
|
||||
num_cams=len(_CAMERA_NAMES),
|
||||
camera_names=_CAMERA_NAMES,
|
||||
num_action_steps=2,
|
||||
head_type='transformer',
|
||||
extra_condition_tokens=1,
|
||||
lewm_history_horizon=3,
|
||||
lewm_query_offsets=[8],
|
||||
lewm_predictor=future_predictor,
|
||||
lewm_pred_projector=nn.Identity(),
|
||||
lewm_sigreg=sigreg,
|
||||
lewm_sigreg_weight=0.09,
|
||||
lewm_loss_weight=0.25,
|
||||
)
|
||||
|
||||
images = _make_images(
|
||||
batch_size=1,
|
||||
obs_horizon=2,
|
||||
per_camera_fill={'r_vis': 1.0, 'top': 2.0, 'front': 3.0},
|
||||
)
|
||||
qpos = torch.tensor([[[0.25], [0.75]]], dtype=torch.float32)
|
||||
actions = torch.tensor(
|
||||
[[[1.0, -1.0], [0.5, 0.25], [-0.5, 1.5]]],
|
||||
dtype=torch.float32,
|
||||
)
|
||||
lewm_images = _make_images(
|
||||
batch_size=1,
|
||||
obs_horizon=3,
|
||||
per_camera_fill={'r_vis': 1.0, 'top': 2.0, 'front': 3.0},
|
||||
)
|
||||
lewm_qpos = torch.tensor([[[0.1], [0.2], [0.3]]], dtype=torch.float32)
|
||||
lewm_future_images = _make_images(
|
||||
batch_size=1,
|
||||
obs_horizon=1,
|
||||
per_camera_fill={'r_vis': 1.0, 'top': 2.0, 'front': 3.0},
|
||||
)
|
||||
lewm_future_qpos = torch.tensor([[[0.4]]], dtype=torch.float32)
|
||||
noise = torch.tensor(
|
||||
[[[0.2, -0.4], [0.1, 0.3], [0.5, -0.2]]],
|
||||
dtype=torch.float32,
|
||||
)
|
||||
t_sample = torch.tensor([0.8], dtype=torch.float32)
|
||||
r_sample = torch.tensor([0.25], dtype=torch.float32)
|
||||
|
||||
with mock.patch.object(agent_module.torch, 'randn_like', return_value=noise), \
|
||||
mock.patch.object(agent_module.torch, 'rand', side_effect=[t_sample, r_sample]):
|
||||
loss = agent.compute_loss(
|
||||
{
|
||||
'images': images,
|
||||
'qpos': qpos,
|
||||
'action': actions,
|
||||
'lewm_images': lewm_images,
|
||||
'lewm_qpos': lewm_qpos,
|
||||
'lewm_future_images': lewm_future_images,
|
||||
'lewm_future_qpos': lewm_future_qpos,
|
||||
}
|
||||
)
|
||||
|
||||
metrics = agent.get_last_loss_breakdown()
|
||||
self.assertAlmostEqual(loss.item(), metrics['loss'], places=6)
|
||||
self.assertIn('action_loss', metrics)
|
||||
self.assertIn('lewm_pred_loss', metrics)
|
||||
self.assertIn('lewm_sigreg_loss', metrics)
|
||||
self.assertIn('lewm_loss', metrics)
|
||||
self.assertAlmostEqual(metrics['lewm_sigreg_loss'], 0.75, places=6)
|
||||
self.assertAlmostEqual(
|
||||
metrics['lewm_loss'],
|
||||
metrics['lewm_pred_loss'] + 0.09 * metrics['lewm_sigreg_loss'],
|
||||
places=5,
|
||||
)
|
||||
self.assertAlmostEqual(
|
||||
metrics['loss'],
|
||||
metrics['action_loss'] + 0.25 * metrics['lewm_loss'],
|
||||
places=5,
|
||||
)
|
||||
self.assertEqual(len(sigreg.calls), 1)
|
||||
expected_lewm_history = torch.tensor(
|
||||
[[[1.0, 2.0, 3.0, 0.1], [1.0, 2.0, 3.0, 0.2], [1.0, 2.0, 3.0, 0.3]]],
|
||||
dtype=torch.float32,
|
||||
)
|
||||
torch.testing.assert_close(sigreg.calls[0], expected_lewm_history.transpose(0, 1))
|
||||
|
||||
def test_select_action_only_regenerates_when_action_queue_is_empty(self):
|
||||
agent, _head, _agent_module = self._make_agent(pred_horizon=4, obs_horizon=2, num_action_steps=2)
|
||||
observation = {
|
||||
@@ -851,6 +1037,46 @@ class IMFVLAAgentTest(unittest.TestCase):
|
||||
self.assertEqual(agent.vision_encoder.output_dim, 96)
|
||||
self.assertEqual(agent.vision_encoder.eval_image_resize_shape, (256, 256))
|
||||
|
||||
def test_hydra_config_instantiates_lewm_resnet_query_imf_attnres_with_future_tokens(self):
|
||||
cfg = _compose_cfg(
|
||||
overrides=[
|
||||
'agent=lewm_resnet_query_imf_attnres',
|
||||
'agent.head.n_layer=1',
|
||||
'agent.head.n_emb=16',
|
||||
'agent.lewm_query_offsets=[8]',
|
||||
]
|
||||
)
|
||||
|
||||
self.assertEqual(cfg.agent._target_, 'roboimi.vla.agent_imf.IMFVLAAgent')
|
||||
self.assertEqual(
|
||||
cfg.agent.vision_backbone._target_,
|
||||
'roboimi.vla.models.backbones.lewm_resnet_query_fusion.LeWMMultiViewResNetBackbone',
|
||||
)
|
||||
self.assertEqual(
|
||||
cfg.agent.state_encoder._target_,
|
||||
'roboimi.vla.modules.encoders.LeWMStateEncoder',
|
||||
)
|
||||
self.assertEqual(cfg.agent.head.cond_dim, 288)
|
||||
self.assertEqual(cfg.agent.cond_projector.output_dim, 288)
|
||||
self.assertEqual(cfg.agent.extra_condition_tokens, 1)
|
||||
self.assertEqual(
|
||||
cfg.agent.lewm_sigreg._target_,
|
||||
'roboimi.vla.models.backbones.lewm_resnet_query_fusion.SIGReg',
|
||||
)
|
||||
self.assertAlmostEqual(cfg.agent.lewm_sigreg_weight, 0.09)
|
||||
|
||||
with _stub_optional_modules(include_imf_head=True):
|
||||
agent = instantiate(cfg.agent)
|
||||
|
||||
self.assertEqual(agent.per_step_cond_dim, 288)
|
||||
self.assertEqual(agent.condition_sequence_length, agent.obs_horizon + 1)
|
||||
self.assertEqual(agent.noise_pred_net.constructor_kwargs['cond_dim'], 288)
|
||||
self.assertEqual(
|
||||
agent.noise_pred_net.constructor_kwargs['n_obs_steps'],
|
||||
agent.condition_sequence_length,
|
||||
)
|
||||
self.assertIsNotNone(agent.lewm_sigreg)
|
||||
|
||||
|
||||
def test_hydra_config_instantiates_resnet_imf_attnres_multitoken_with_sequence_length_three_times_obs_horizon(self):
|
||||
cfg = _compose_cfg(
|
||||
|
||||
@@ -79,3 +79,24 @@ class SimpleRobotDatasetImageLoadingTest(unittest.TestCase):
|
||||
|
||||
fake_cv2.resize.assert_not_called()
|
||||
self.assertEqual(tuple(sample["observation.front"].shape), (2, 3, 8, 8))
|
||||
|
||||
def test_getitem_can_emit_lewm_history_and_future_observations(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
dataset_dir = Path(tmpdir)
|
||||
self._write_episode(dataset_dir)
|
||||
dataset = SimpleRobotDataset(
|
||||
dataset_dir,
|
||||
obs_horizon=2,
|
||||
pred_horizon=3,
|
||||
camera_names=["front"],
|
||||
image_resize_shape=None,
|
||||
lewm_history_horizon=3,
|
||||
lewm_query_offsets=[1, 2],
|
||||
)
|
||||
|
||||
sample = dataset[1]
|
||||
|
||||
self.assertEqual(tuple(sample["lewm.observation.state"].shape), (3, 4))
|
||||
self.assertEqual(tuple(sample["lewm.observation.front"].shape), (3, 3, 8, 8))
|
||||
self.assertEqual(tuple(sample["lewm.future.state"].shape), (2, 4))
|
||||
self.assertEqual(tuple(sample["lewm.future.front"].shape), (2, 3, 8, 8))
|
||||
|
||||
@@ -114,6 +114,22 @@ class FakeAgent(nn.Module):
|
||||
return {}
|
||||
|
||||
|
||||
class RecordingAgent(FakeAgent):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.seen_inputs = []
|
||||
|
||||
def compute_loss(self, agent_input):
|
||||
self.seen_inputs.append(agent_input)
|
||||
return super().compute_loss(agent_input)
|
||||
|
||||
|
||||
class ShapeMixedFakeAgent(FakeAgent):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.bias = nn.Parameter(torch.zeros(2))
|
||||
|
||||
|
||||
class FakeSwanLab:
|
||||
def __init__(self, init_error=None, log_errors=None, finish_error=None, image_errors=None):
|
||||
self.init_error = init_error
|
||||
@@ -388,6 +404,18 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):
|
||||
'action_is_pad': torch.zeros(1, 1, dtype=torch.bool),
|
||||
}
|
||||
|
||||
def _make_lewm_batch(self):
|
||||
batch = self._make_batch()
|
||||
batch.update(
|
||||
{
|
||||
'lewm.observation.front': torch.ones(1, 3, 2, 2),
|
||||
'lewm.observation.state': torch.ones(1, 4),
|
||||
'lewm.future.front': torch.full((1, 3, 2, 2), 2.0),
|
||||
'lewm.future.state': torch.full((1, 4), 2.0),
|
||||
}
|
||||
)
|
||||
return batch
|
||||
|
||||
def _loader_factory(self):
|
||||
train_batch = self._make_batch()
|
||||
val_batch = self._make_batch()
|
||||
@@ -397,6 +425,15 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):
|
||||
|
||||
return factory
|
||||
|
||||
def _lewm_loader_factory(self):
|
||||
train_batch = self._make_lewm_batch()
|
||||
val_batch = self._make_lewm_batch()
|
||||
|
||||
def factory(_dataset, *, shuffle, **_kwargs):
|
||||
return FakeLoader(train_batch if shuffle else val_batch)
|
||||
|
||||
return factory
|
||||
|
||||
def test_run_training_logs_metrics_and_checkpoint_paths_to_swanlab(self):
|
||||
module = self._load_train_vla_module()
|
||||
run_training = self._get_run_training(module)
|
||||
@@ -487,6 +524,43 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):
|
||||
self.assertTrue(final_payload['final/best_checkpoint_path'].endswith('checkpoints/vla_model_best.pt'))
|
||||
self.assertEqual(fake_swanlab.finish_calls, 1)
|
||||
|
||||
def test_run_training_passes_lewm_history_and_future_batches_into_agent_input(self):
|
||||
module = self._load_train_vla_module()
|
||||
run_training = self._get_run_training(module)
|
||||
cfg = self._make_cfg(use_swanlab=False)
|
||||
cfg.train.max_steps = 1
|
||||
cfg.train.save_freq = 100
|
||||
agent = RecordingAgent()
|
||||
|
||||
def fake_instantiate(config_node, **_kwargs):
|
||||
if config_node is cfg.data:
|
||||
return FakeDataset()
|
||||
if config_node is cfg.agent:
|
||||
return agent
|
||||
raise AssertionError(f'unexpected instantiate config: {config_node!r}')
|
||||
|
||||
with tempfile.TemporaryDirectory() as tempdir:
|
||||
previous_cwd = os.getcwd()
|
||||
try:
|
||||
os.chdir(tempdir)
|
||||
with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
|
||||
mock.patch.object(module, 'DataLoader', side_effect=self._lewm_loader_factory()), \
|
||||
mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=FakeScheduler()), \
|
||||
mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: FakeProgressBar(iterable)), \
|
||||
mock.patch.object(module.torch, 'save', return_value=None):
|
||||
run_training(cfg)
|
||||
finally:
|
||||
os.chdir(previous_cwd)
|
||||
|
||||
self.assertGreaterEqual(len(agent.seen_inputs), 1)
|
||||
first_input = agent.seen_inputs[0]
|
||||
self.assertIn('lewm_images', first_input)
|
||||
self.assertIn('lewm_qpos', first_input)
|
||||
self.assertIn('lewm_future_images', first_input)
|
||||
self.assertIn('lewm_future_qpos', first_input)
|
||||
self.assertIn('front', first_input['lewm_images'])
|
||||
self.assertIn('front', first_input['lewm_future_images'])
|
||||
|
||||
def test_run_training_skips_swanlab_when_disabled(self):
|
||||
module = self._load_train_vla_module()
|
||||
run_training = self._get_run_training(module)
|
||||
@@ -668,6 +742,52 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):
|
||||
self.assertTrue(final_payload['final/best_checkpoint_path'].endswith('checkpoints/vla_model_best.pt'))
|
||||
self.assertFalse(any(path.endswith('checkpoints/vla_model_best.pt') for path in saved_paths))
|
||||
|
||||
def test_run_training_pretrained_ckpt_loads_matching_keys_even_if_some_shapes_mismatch(self):
|
||||
module = self._load_train_vla_module()
|
||||
run_training = self._get_run_training(module)
|
||||
cfg = self._make_cfg(use_swanlab=False)
|
||||
cfg.train.max_steps = 0
|
||||
cfg.train.save_freq = 100
|
||||
cfg.train.pretrained_ckpt = 'pretrained.pt'
|
||||
agent = ShapeMixedFakeAgent()
|
||||
|
||||
def fake_instantiate(config_node, **_kwargs):
|
||||
if config_node is cfg.data:
|
||||
return FakeDataset()
|
||||
if config_node is cfg.agent:
|
||||
return agent
|
||||
raise AssertionError(f'unexpected instantiate config: {config_node!r}')
|
||||
|
||||
def fake_torch_load(path, map_location=None):
|
||||
del map_location
|
||||
if Path(path).name != 'pretrained.pt':
|
||||
raise AssertionError(f'unexpected load path: {path}')
|
||||
return {
|
||||
'model_state_dict': {
|
||||
'weight': torch.tensor(3.0),
|
||||
'bias': torch.tensor([1.0, 2.0, 3.0]),
|
||||
},
|
||||
'step': 123,
|
||||
'loss': 0.5,
|
||||
}
|
||||
|
||||
with tempfile.TemporaryDirectory() as tempdir:
|
||||
previous_cwd = os.getcwd()
|
||||
try:
|
||||
os.chdir(tempdir)
|
||||
Path('pretrained.pt').write_bytes(b'pretend')
|
||||
with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
|
||||
mock.patch.object(module, 'DataLoader', side_effect=self._loader_factory()), \
|
||||
mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=FakeScheduler()), \
|
||||
mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: FakeProgressBar(iterable)), \
|
||||
mock.patch.object(module.torch, 'save', return_value=None), \
|
||||
mock.patch.object(module.torch, 'load', side_effect=fake_torch_load):
|
||||
run_training(cfg)
|
||||
finally:
|
||||
os.chdir(previous_cwd)
|
||||
|
||||
self.assertAlmostEqual(agent.weight.item(), 3.0, places=6)
|
||||
|
||||
def test_run_training_resume_ignores_best_checkpoint_without_rollout_metric(self):
|
||||
module = self._load_train_vla_module()
|
||||
run_training = self._get_run_training(module)
|
||||
|
||||
Reference in New Issue
Block a user