feat: add lewm-conditioned imf training and sigreg loss
This commit is contained in:
@@ -237,6 +237,32 @@ def build_training_optimizer(agent, lr, weight_decay):
|
|||||||
return AdamW(optim_groups, lr=lr, weight_decay=weight_decay)
|
return AdamW(optim_groups, lr=lr, weight_decay=weight_decay)
|
||||||
|
|
||||||
|
|
||||||
|
def load_state_dict_ignoring_shape_mismatches(module, incoming_state_dict):
|
||||||
|
"""Load only checkpoint tensors whose keys exist locally and whose shapes match."""
|
||||||
|
current_state_dict = module.state_dict()
|
||||||
|
compatible_state_dict = {}
|
||||||
|
mismatched_keys = []
|
||||||
|
missing_keys = []
|
||||||
|
|
||||||
|
for key, value in incoming_state_dict.items():
|
||||||
|
if key not in current_state_dict:
|
||||||
|
missing_keys.append(key)
|
||||||
|
continue
|
||||||
|
if current_state_dict[key].shape != value.shape:
|
||||||
|
mismatched_keys.append(key)
|
||||||
|
continue
|
||||||
|
compatible_state_dict[key] = value
|
||||||
|
|
||||||
|
merged_state_dict = dict(current_state_dict)
|
||||||
|
merged_state_dict.update(compatible_state_dict)
|
||||||
|
module.load_state_dict(merged_state_dict, strict=True)
|
||||||
|
return {
|
||||||
|
'loaded_keys': sorted(compatible_state_dict.keys()),
|
||||||
|
'missing_keys': sorted(missing_keys),
|
||||||
|
'mismatched_keys': sorted(mismatched_keys),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def _init_swanlab(cfg):
|
def _init_swanlab(cfg):
|
||||||
"""按需初始化 SwanLab,并在缺少依赖或认证失败时快速失败。"""
|
"""按需初始化 SwanLab,并在缺少依赖或认证失败时快速失败。"""
|
||||||
if not bool(cfg.train.get('use_swanlab', False)):
|
if not bool(cfg.train.get('use_swanlab', False)):
|
||||||
@@ -509,18 +535,23 @@ def _run_training(cfg: DictConfig):
|
|||||||
try:
|
try:
|
||||||
checkpoint = torch.load(ckpt_path, map_location=cfg.train.device)
|
checkpoint = torch.load(ckpt_path, map_location=cfg.train.device)
|
||||||
|
|
||||||
# 只加载模型权重(不加载 optimizer、scheduler)
|
load_info = load_state_dict_ignoring_shape_mismatches(
|
||||||
missing_keys, unexpected_keys = agent.load_state_dict(
|
agent,
|
||||||
checkpoint['model_state_dict'],
|
checkpoint['model_state_dict'],
|
||||||
strict=False # 允许部分加载(结构不完全匹配时)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
log.info(f"✅ [Finetune] 模型权重加载成功")
|
log.info(f"✅ [Finetune] 模型权重加载成功")
|
||||||
|
|
||||||
if missing_keys:
|
if load_info['missing_keys']:
|
||||||
log.warning(f"⚠️ [Finetune] 缺少的键 ({len(missing_keys)} 个): {missing_keys[:5]}...")
|
log.warning(
|
||||||
if unexpected_keys:
|
f"⚠️ [Finetune] checkpoint 中存在本地模型没有的键 ({len(load_info['missing_keys'])} 个): "
|
||||||
log.warning(f"⚠️ [Finetune] 多余的键 ({len(unexpected_keys)} 个): {unexpected_keys[:5]}...")
|
f"{load_info['missing_keys'][:5]}..."
|
||||||
|
)
|
||||||
|
if load_info['mismatched_keys']:
|
||||||
|
log.warning(
|
||||||
|
f"⚠️ [Finetune] 因形状不匹配而跳过的键 ({len(load_info['mismatched_keys'])} 个): "
|
||||||
|
f"{load_info['mismatched_keys'][:5]}..."
|
||||||
|
)
|
||||||
|
|
||||||
log.info(f"📊 [Finetune] 预训练信息: 步骤={checkpoint.get('step', 'N/A')}, 损失={checkpoint.get('loss', 'N/A')}")
|
log.info(f"📊 [Finetune] 预训练信息: 步骤={checkpoint.get('step', 'N/A')}, 损失={checkpoint.get('loss', 'N/A')}")
|
||||||
log.info(f"📈 [Finetune] 使用新的训练配置(lr={cfg.train.lr}, max_steps={cfg.train.max_steps})")
|
log.info(f"📈 [Finetune] 使用新的训练配置(lr={cfg.train.lr}, max_steps={cfg.train.max_steps})")
|
||||||
@@ -652,13 +683,35 @@ def _run_training(cfg: DictConfig):
|
|||||||
if key in batch_data:
|
if key in batch_data:
|
||||||
images[cam_name] = batch_data[key]
|
images[cam_name] = batch_data[key]
|
||||||
|
|
||||||
return {
|
agent_input = {
|
||||||
'images': images,
|
'images': images,
|
||||||
'qpos': batch_data['observation.state'], # SimpleRobotDataset 使用 observation.state
|
'qpos': batch_data['observation.state'], # SimpleRobotDataset 使用 observation.state
|
||||||
'action': batch_data['action'],
|
'action': batch_data['action'],
|
||||||
'action_is_pad': batch_data.get('action_is_pad', None) # 传递padding mask
|
'action_is_pad': batch_data.get('action_is_pad', None) # 传递padding mask
|
||||||
}
|
}
|
||||||
|
|
||||||
|
lewm_images = {}
|
||||||
|
lewm_future_images = {}
|
||||||
|
for cam_name in cfg.data.camera_names:
|
||||||
|
lewm_obs_key = f"lewm.observation.{cam_name}"
|
||||||
|
if lewm_obs_key in batch_data:
|
||||||
|
lewm_images[cam_name] = batch_data[lewm_obs_key]
|
||||||
|
|
||||||
|
lewm_future_key = f"lewm.future.{cam_name}"
|
||||||
|
if lewm_future_key in batch_data:
|
||||||
|
lewm_future_images[cam_name] = batch_data[lewm_future_key]
|
||||||
|
|
||||||
|
if 'lewm.observation.state' in batch_data:
|
||||||
|
agent_input['lewm_qpos'] = batch_data['lewm.observation.state']
|
||||||
|
if lewm_images:
|
||||||
|
agent_input['lewm_images'] = lewm_images
|
||||||
|
if 'lewm.future.state' in batch_data:
|
||||||
|
agent_input['lewm_future_qpos'] = batch_data['lewm.future.state']
|
||||||
|
if lewm_future_images:
|
||||||
|
agent_input['lewm_future_images'] = lewm_future_images
|
||||||
|
|
||||||
|
return agent_input
|
||||||
|
|
||||||
def save_checkpoint(checkpoint_path: Path, step: int, loss_value, val_loss=None, rollout_avg_reward=None):
|
def save_checkpoint(checkpoint_path: Path, step: int, loss_value, val_loss=None, rollout_avg_reward=None):
|
||||||
agent_stats = agent.get_normalization_stats()
|
agent_stats = agent.get_normalization_stats()
|
||||||
torch.save({
|
torch.save({
|
||||||
@@ -809,6 +862,15 @@ def _run_training(cfg: DictConfig):
|
|||||||
},
|
},
|
||||||
step=step,
|
step=step,
|
||||||
)
|
)
|
||||||
|
if hasattr(agent, 'get_last_loss_breakdown'):
|
||||||
|
loss_breakdown = agent.get_last_loss_breakdown()
|
||||||
|
extra_train_metrics = {
|
||||||
|
f"train/{key}": value
|
||||||
|
for key, value in loss_breakdown.items()
|
||||||
|
if value is not None and key != 'loss'
|
||||||
|
}
|
||||||
|
if extra_train_metrics:
|
||||||
|
_log_to_swanlab(swanlab_module, extra_train_metrics, step=step)
|
||||||
|
|
||||||
# =====================================================================
|
# =====================================================================
|
||||||
# 检查点保存与验证
|
# 检查点保存与验证
|
||||||
|
|||||||
267
roboimi/scripts/refresh_experiment_suite_status.py
Executable file
267
roboimi/scripts/refresh_experiment_suite_status.py
Executable file
@@ -0,0 +1,267 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import datetime as dt
|
||||||
|
import json
|
||||||
|
import pathlib
|
||||||
|
import re
|
||||||
|
import shlex
|
||||||
|
import subprocess
|
||||||
|
from collections import defaultdict
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
STEP_PAT = re.compile(r"步骤\s+(\d+)/(\d+)")
|
||||||
|
BAR_PAT = re.compile(r"\|\s*(\d+)/(\d+)")
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_chunks(text: str):
|
||||||
|
for part in re.split(r"[\r\n]+", text):
|
||||||
|
part = part.strip()
|
||||||
|
if part:
|
||||||
|
yield part
|
||||||
|
|
||||||
|
|
||||||
|
def parse_latest_line(text: str) -> tuple[str, int | None]:
|
||||||
|
latest_line = ""
|
||||||
|
latest_step = None
|
||||||
|
for line in normalize_chunks(text):
|
||||||
|
if "步骤" not in line and "训练中:" not in line:
|
||||||
|
continue
|
||||||
|
latest_line = line
|
||||||
|
match = STEP_PAT.search(line) or BAR_PAT.search(line)
|
||||||
|
if match:
|
||||||
|
latest_step = int(match.group(1))
|
||||||
|
return latest_line, latest_step
|
||||||
|
|
||||||
|
|
||||||
|
def now_iso() -> str:
|
||||||
|
return dt.datetime.now(
|
||||||
|
dt.timezone(dt.timedelta(hours=8)),
|
||||||
|
).isoformat(timespec="seconds")
|
||||||
|
|
||||||
|
|
||||||
|
def run_cmd(cmd: list[str], check: bool = True) -> subprocess.CompletedProcess[str]:
|
||||||
|
return subprocess.run(cmd, capture_output=True, text=True, check=check)
|
||||||
|
|
||||||
|
|
||||||
|
def probe_local(run: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
pid = str(run["pid"])
|
||||||
|
ps = run_cmd(["ps", "-p", pid, "-o", "pid=,stat=,etime=,args="], check=False)
|
||||||
|
log_path = pathlib.Path(run["log_path"])
|
||||||
|
latest_line = ""
|
||||||
|
latest_step = None
|
||||||
|
if log_path.exists():
|
||||||
|
latest_line, latest_step = parse_latest_line(log_path.read_text(errors="replace"))
|
||||||
|
return {
|
||||||
|
"alive": bool(ps.stdout.strip()),
|
||||||
|
"ps": ps.stdout.strip(),
|
||||||
|
"log_exists": log_path.exists(),
|
||||||
|
"latest_line": latest_line,
|
||||||
|
"latest_step": latest_step,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def remote_probe(host: str, remote_user: str, runs: list[dict[str, Any]]) -> dict[str, dict[str, Any]]:
|
||||||
|
payload = [
|
||||||
|
{
|
||||||
|
"run_id": run["run_id"],
|
||||||
|
"pid": str(run["pid"]),
|
||||||
|
"log_path": run["log_path"],
|
||||||
|
}
|
||||||
|
for run in runs
|
||||||
|
]
|
||||||
|
remote_py = r"""
|
||||||
|
import json
|
||||||
|
import pathlib
|
||||||
|
import re
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
|
||||||
|
payload = json.loads(sys.argv[1])
|
||||||
|
step_pat = re.compile(r"步骤\s+(\d+)/(\d+)")
|
||||||
|
bar_pat = re.compile(r"\|\s*(\d+)/(\d+)")
|
||||||
|
|
||||||
|
def normalize_chunks(text):
|
||||||
|
for part in re.split(r"[\r\n]+", text):
|
||||||
|
part = part.strip()
|
||||||
|
if part:
|
||||||
|
yield part
|
||||||
|
|
||||||
|
def parse_latest_line(text):
|
||||||
|
latest_line = ""
|
||||||
|
latest_step = None
|
||||||
|
for line in normalize_chunks(text):
|
||||||
|
if "步骤" not in line and "训练中:" not in line:
|
||||||
|
continue
|
||||||
|
latest_line = line
|
||||||
|
match = step_pat.search(line) or bar_pat.search(line)
|
||||||
|
if match:
|
||||||
|
latest_step = int(match.group(1))
|
||||||
|
return latest_line, latest_step
|
||||||
|
|
||||||
|
out = {}
|
||||||
|
for item in payload:
|
||||||
|
try:
|
||||||
|
ps = subprocess.run(
|
||||||
|
["ps", "-p", item["pid"], "-o", "pid=,stat=,etime=,args="],
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
check=False,
|
||||||
|
)
|
||||||
|
log_path = pathlib.Path(item["log_path"])
|
||||||
|
latest_line = ""
|
||||||
|
latest_step = None
|
||||||
|
if log_path.exists():
|
||||||
|
latest_line, latest_step = parse_latest_line(log_path.read_text(errors="replace"))
|
||||||
|
out[item["run_id"]] = {
|
||||||
|
"alive": bool(ps.stdout.strip()),
|
||||||
|
"ps": ps.stdout.strip(),
|
||||||
|
"log_exists": log_path.exists(),
|
||||||
|
"latest_line": latest_line,
|
||||||
|
"latest_step": latest_step,
|
||||||
|
}
|
||||||
|
except Exception as exc:
|
||||||
|
out[item["run_id"]] = {
|
||||||
|
"alive": False,
|
||||||
|
"ps": "",
|
||||||
|
"log_exists": False,
|
||||||
|
"latest_line": "",
|
||||||
|
"latest_step": None,
|
||||||
|
"error": str(exc),
|
||||||
|
}
|
||||||
|
print(json.dumps(out, ensure_ascii=False))
|
||||||
|
"""
|
||||||
|
remote_target = host if "@" in host else f"{remote_user}@{host}"
|
||||||
|
remote_cmd = (
|
||||||
|
f"python3 -c {shlex.quote(remote_py)} "
|
||||||
|
f"{shlex.quote(json.dumps(payload, ensure_ascii=False))}"
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
res = run_cmd(
|
||||||
|
[
|
||||||
|
"ssh",
|
||||||
|
"-F",
|
||||||
|
"/dev/null",
|
||||||
|
"-o",
|
||||||
|
"BatchMode=yes",
|
||||||
|
"-o",
|
||||||
|
"StrictHostKeyChecking=accept-new",
|
||||||
|
remote_target,
|
||||||
|
remote_cmd,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
return json.loads(res.stdout)
|
||||||
|
except subprocess.CalledProcessError as exc:
|
||||||
|
error = (exc.stderr or exc.stdout or str(exc)).strip()
|
||||||
|
return {
|
||||||
|
run["run_id"]: {
|
||||||
|
"alive": False,
|
||||||
|
"ps": "",
|
||||||
|
"log_exists": False,
|
||||||
|
"latest_line": "",
|
||||||
|
"latest_step": None,
|
||||||
|
"error": f"ssh_failed: {error}",
|
||||||
|
}
|
||||||
|
for run in runs
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def append_notes(notes_path: pathlib.Path, snapshot_at: str, runs: list[dict[str, Any]]) -> None:
|
||||||
|
lines = [f"\n## Status snapshot {snapshot_at}"]
|
||||||
|
for run in runs:
|
||||||
|
lines.append(
|
||||||
|
(
|
||||||
|
f"- {run['run_id']}: host={run['host']} gpu={run['gpu']} "
|
||||||
|
f"alive={run.get('alive', False)} step={run.get('latest_step')} "
|
||||||
|
f"pid={run['pid']}"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if run.get("latest_line"):
|
||||||
|
lines.append(f" - latest_line: `{run['latest_line']}`")
|
||||||
|
if run.get("error"):
|
||||||
|
lines.append(f" - error: `{run['error']}`")
|
||||||
|
with notes_path.open("a", encoding="utf-8") as f:
|
||||||
|
f.write("\n".join(lines) + "\n")
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> int:
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("suite_dir", type=pathlib.Path)
|
||||||
|
parser.add_argument("--remote-user", default="droid")
|
||||||
|
parser.add_argument("--append-notes", action="store_true")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
suite_dir = args.suite_dir.resolve()
|
||||||
|
status_path = suite_dir / "status.json"
|
||||||
|
notes_path = suite_dir / "notes.md"
|
||||||
|
monitor_dir = suite_dir / "monitor_logs"
|
||||||
|
monitor_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
status = json.loads(status_path.read_text(encoding="utf-8"))
|
||||||
|
runs: list[dict[str, Any]] = status["runs"]
|
||||||
|
snapshot_at = now_iso()
|
||||||
|
|
||||||
|
by_host: dict[str, list[dict[str, Any]]] = defaultdict(list)
|
||||||
|
for run in runs:
|
||||||
|
by_host[run["host"]].append(run)
|
||||||
|
|
||||||
|
results: dict[str, dict[str, Any]] = {}
|
||||||
|
for host, host_runs in by_host.items():
|
||||||
|
if host == "local":
|
||||||
|
for run in host_runs:
|
||||||
|
results[run["run_id"]] = probe_local(run)
|
||||||
|
else:
|
||||||
|
results.update(remote_probe(host, args.remote_user, host_runs))
|
||||||
|
|
||||||
|
alive_count = 0
|
||||||
|
for run in runs:
|
||||||
|
result = results[run["run_id"]]
|
||||||
|
run["alive"] = result["alive"]
|
||||||
|
run["ps"] = result["ps"]
|
||||||
|
run["log_exists"] = result["log_exists"]
|
||||||
|
run["latest_line"] = result["latest_line"]
|
||||||
|
run["latest_step"] = result["latest_step"]
|
||||||
|
run["last_verified_at"] = snapshot_at
|
||||||
|
if "error" in result:
|
||||||
|
run["error"] = result["error"]
|
||||||
|
else:
|
||||||
|
run.pop("error", None)
|
||||||
|
run["status"] = "running" if result["alive"] else "stopped"
|
||||||
|
alive_count += int(result["alive"])
|
||||||
|
|
||||||
|
status["last_verified_at"] = snapshot_at
|
||||||
|
status["alive_count"] = alive_count
|
||||||
|
status["total_runs"] = len(runs)
|
||||||
|
|
||||||
|
status_path.write_text(json.dumps(status, ensure_ascii=False, indent=2) + "\n", encoding="utf-8")
|
||||||
|
|
||||||
|
snapshot_payload = {
|
||||||
|
"suite_name": status.get("suite_name"),
|
||||||
|
"snapshot_at": snapshot_at,
|
||||||
|
"alive_count": alive_count,
|
||||||
|
"total_runs": len(runs),
|
||||||
|
"runs": {run["run_id"]: results[run["run_id"]] for run in runs},
|
||||||
|
}
|
||||||
|
timestamp_slug = snapshot_at.replace(":", "").replace("+", "_").replace("-", "")
|
||||||
|
snapshot_path = monitor_dir / f"status-{timestamp_slug}.json"
|
||||||
|
snapshot_path.write_text(
|
||||||
|
json.dumps(snapshot_payload, ensure_ascii=False, indent=2) + "\n",
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.append_notes:
|
||||||
|
append_notes(notes_path, snapshot_at, runs)
|
||||||
|
|
||||||
|
print(json.dumps(snapshot_payload, ensure_ascii=False, indent=2))
|
||||||
|
print(f"\nstatus_json={status_path}")
|
||||||
|
print(f"snapshot_json={snapshot_path}")
|
||||||
|
if args.append_notes:
|
||||||
|
print(f"notes_md={notes_path}")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
raise SystemExit(main())
|
||||||
@@ -28,6 +28,7 @@ class VLAAgent(nn.Module):
|
|||||||
num_action_steps=8, # 每次推理实际执行多少步动作
|
num_action_steps=8, # 每次推理实际执行多少步动作
|
||||||
head_type='unet', # Policy head类型: 'unet' 或 'transformer'
|
head_type='unet', # Policy head类型: 'unet' 或 'transformer'
|
||||||
cond_projector=None, # 可选:将视觉+状态条件投影到head期望维度
|
cond_projector=None, # 可选:将视觉+状态条件投影到head期望维度
|
||||||
|
extra_condition_tokens: int = 0, # 可选:额外条件token数量(例如未来预测embedding)
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# 保存参数
|
# 保存参数
|
||||||
@@ -39,6 +40,9 @@ class VLAAgent(nn.Module):
|
|||||||
self.num_action_steps = num_action_steps
|
self.num_action_steps = num_action_steps
|
||||||
self.inference_steps = inference_steps
|
self.inference_steps = inference_steps
|
||||||
self.head_type = head_type # 'unet' 或 'transformer'
|
self.head_type = head_type # 'unet' 或 'transformer'
|
||||||
|
self.extra_condition_tokens = int(extra_condition_tokens)
|
||||||
|
if self.extra_condition_tokens < 0:
|
||||||
|
raise ValueError(f"extra_condition_tokens must be >= 0, got {self.extra_condition_tokens}")
|
||||||
agent_camera_names = tuple(camera_names) if camera_names is not None else None
|
agent_camera_names = tuple(camera_names) if camera_names is not None else None
|
||||||
backbone_camera_names = getattr(vision_backbone, 'camera_names', None)
|
backbone_camera_names = getattr(vision_backbone, 'camera_names', None)
|
||||||
backbone_camera_names = tuple(backbone_camera_names) if backbone_camera_names is not None else None
|
backbone_camera_names = tuple(backbone_camera_names) if backbone_camera_names is not None else None
|
||||||
@@ -71,11 +75,14 @@ class VLAAgent(nn.Module):
|
|||||||
stats=dataset_stats,
|
stats=dataset_stats,
|
||||||
normalization_type=normalization_type
|
normalization_type=normalization_type
|
||||||
)
|
)
|
||||||
|
self.dataset_stats = dataset_stats
|
||||||
|
|
||||||
self.vision_encoder = vision_backbone
|
self.vision_encoder = vision_backbone
|
||||||
|
self.state_encoder = state_encoder
|
||||||
if self.camera_names is not None:
|
if self.camera_names is not None:
|
||||||
self.vision_encoder.camera_names = self.camera_names
|
self.vision_encoder.camera_names = self.camera_names
|
||||||
self.condition_tokens_per_step = int(getattr(self.vision_encoder, 'tokens_per_step', 1))
|
self.condition_tokens_per_step = int(getattr(self.vision_encoder, 'tokens_per_step', 1))
|
||||||
|
self.state_feature_dim = int(getattr(self.state_encoder, 'output_dim', obs_dim))
|
||||||
joint_vision_dim = getattr(self.vision_encoder, 'joint_output_dim', None)
|
joint_vision_dim = getattr(self.vision_encoder, 'joint_output_dim', None)
|
||||||
if joint_vision_dim is not None:
|
if joint_vision_dim is not None:
|
||||||
per_token_vision_dim = int(joint_vision_dim)
|
per_token_vision_dim = int(joint_vision_dim)
|
||||||
@@ -87,8 +94,11 @@ class VLAAgent(nn.Module):
|
|||||||
else:
|
else:
|
||||||
per_token_vision_dim = int(single_cam_feat_dim) * int(num_cams)
|
per_token_vision_dim = int(single_cam_feat_dim) * int(num_cams)
|
||||||
|
|
||||||
self.condition_sequence_length = self.obs_horizon * self.condition_tokens_per_step
|
self.history_condition_sequence_length = self.obs_horizon * self.condition_tokens_per_step
|
||||||
self.raw_per_step_cond_dim = per_token_vision_dim + obs_dim
|
self.condition_sequence_length = (
|
||||||
|
self.history_condition_sequence_length + self.extra_condition_tokens
|
||||||
|
)
|
||||||
|
self.raw_per_step_cond_dim = per_token_vision_dim + self.state_feature_dim
|
||||||
if cond_projector is None:
|
if cond_projector is None:
|
||||||
self.cond_projector = None
|
self.cond_projector = None
|
||||||
self.per_step_cond_dim = self.raw_per_step_cond_dim
|
self.per_step_cond_dim = self.raw_per_step_cond_dim
|
||||||
@@ -139,7 +149,6 @@ class VLAAgent(nn.Module):
|
|||||||
global_cond_dim=self.global_cond_dim
|
global_cond_dim=self.global_cond_dim
|
||||||
)
|
)
|
||||||
|
|
||||||
self.state_encoder = state_encoder
|
|
||||||
self.action_encoder = action_encoder
|
self.action_encoder = action_encoder
|
||||||
|
|
||||||
# 初始化队列(用于在线推理)
|
# 初始化队列(用于在线推理)
|
||||||
@@ -220,7 +229,7 @@ class VLAAgent(nn.Module):
|
|||||||
f"条件维度不匹配: got {cond.shape[-1]}, expected {self.per_step_cond_dim}"
|
f"条件维度不匹配: got {cond.shape[-1]}, expected {self.per_step_cond_dim}"
|
||||||
)
|
)
|
||||||
cond = cond.reshape(batch_size, obs_steps * token_count, self.per_step_cond_dim)
|
cond = cond.reshape(batch_size, obs_steps * token_count, self.per_step_cond_dim)
|
||||||
expected_length = self.condition_sequence_length
|
expected_length = self.history_condition_sequence_length
|
||||||
if cond.shape[1] != expected_length:
|
if cond.shape[1] != expected_length:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"条件序列长度不匹配: got {cond.shape[1]}, expected {expected_length}"
|
f"条件序列长度不匹配: got {cond.shape[1]}, expected {expected_length}"
|
||||||
|
|||||||
@@ -1,9 +1,12 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
from typing import Dict, Optional
|
from collections import deque
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, Mapping, Optional, Sequence
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from roboimi.vla.agent import VLAAgent
|
from roboimi.vla.agent import VLAAgent
|
||||||
@@ -15,14 +18,59 @@ except ImportError: # pragma: no cover
|
|||||||
|
|
||||||
|
|
||||||
class IMFVLAAgent(VLAAgent):
|
class IMFVLAAgent(VLAAgent):
|
||||||
def __init__(self, *args, inference_steps: int = 1, **kwargs):
|
def __init__(
|
||||||
|
self,
|
||||||
|
*args,
|
||||||
|
inference_steps: int = 1,
|
||||||
|
lewm_history_horizon: Optional[int] = None,
|
||||||
|
lewm_query_offsets: Optional[Sequence[int]] = None,
|
||||||
|
lewm_predictor: Optional[nn.Module] = None,
|
||||||
|
lewm_pred_projector: Optional[nn.Module] = None,
|
||||||
|
lewm_sigreg: Optional[nn.Module] = None,
|
||||||
|
lewm_sigreg_weight: float = 0.09,
|
||||||
|
lewm_loss_weight: float = 0.0,
|
||||||
|
lewm_pretrained_ckpt: Optional[str | Path | Mapping[str, Any]] = None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
if inference_steps != 1:
|
if inference_steps != 1:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
'IMFVLAAgent only supports one-step inference; '
|
'IMFVLAAgent only supports one-step inference; '
|
||||||
f'inference_steps must be 1, got {inference_steps}.'
|
f'inference_steps must be 1, got {inference_steps}.'
|
||||||
)
|
)
|
||||||
|
lewm_query_offsets = tuple(int(offset) for offset in (lewm_query_offsets or ()))
|
||||||
|
inferred_extra_condition_tokens = len(lewm_query_offsets) if lewm_query_offsets else 0
|
||||||
|
kwargs.setdefault('extra_condition_tokens', inferred_extra_condition_tokens)
|
||||||
|
self.__dict__['lewm_history_horizon'] = int(lewm_history_horizon or kwargs.get('obs_horizon', 1))
|
||||||
|
self.__dict__['lewm_query_offsets'] = lewm_query_offsets
|
||||||
|
self.__dict__['lewm_predictor'] = lewm_predictor
|
||||||
|
self.__dict__['lewm_pred_projector'] = lewm_pred_projector or nn.Identity()
|
||||||
|
self.__dict__['lewm_sigreg'] = lewm_sigreg
|
||||||
|
self.__dict__['lewm_sigreg_weight'] = float(lewm_sigreg_weight)
|
||||||
|
self.__dict__['lewm_loss_weight'] = float(lewm_loss_weight)
|
||||||
|
self.__dict__['_last_loss_breakdown'] = {
|
||||||
|
'action_loss': 0.0,
|
||||||
|
'lewm_pred_loss': 0.0,
|
||||||
|
'lewm_sigreg_loss': 0.0,
|
||||||
|
'lewm_loss': 0.0,
|
||||||
|
'loss': 0.0,
|
||||||
|
}
|
||||||
super().__init__(*args, inference_steps=inference_steps, **kwargs)
|
super().__init__(*args, inference_steps=inference_steps, **kwargs)
|
||||||
self.inference_steps = 1
|
self.inference_steps = 1
|
||||||
|
self.lewm_history_horizon = int(lewm_history_horizon or self.obs_horizon)
|
||||||
|
self.lewm_predictor = lewm_predictor
|
||||||
|
self.lewm_pred_projector = lewm_pred_projector or nn.Identity()
|
||||||
|
self.lewm_sigreg = lewm_sigreg
|
||||||
|
self.lewm_sigreg_weight = float(lewm_sigreg_weight)
|
||||||
|
if self.lewm_predictor is None and self.extra_condition_tokens > 0:
|
||||||
|
raise ValueError(
|
||||||
|
'extra_condition_tokens > 0 requires lewm_predictor to be provided'
|
||||||
|
)
|
||||||
|
if self.lewm_predictor is not None and self.extra_condition_tokens != inferred_extra_condition_tokens:
|
||||||
|
raise ValueError(
|
||||||
|
'extra_condition_tokens must equal len(lewm_query_offsets) when lewm_predictor is enabled'
|
||||||
|
)
|
||||||
|
if lewm_pretrained_ckpt is not None:
|
||||||
|
self.load_lewm_pretrained_components(lewm_pretrained_ckpt)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _broadcast_batch_time(value: torch.Tensor, reference: torch.Tensor) -> torch.Tensor:
|
def _broadcast_batch_time(value: torch.Tensor, reference: torch.Tensor) -> torch.Tensor:
|
||||||
@@ -119,14 +167,181 @@ class IMFVLAAgent(VLAAgent):
|
|||||||
delta = self._broadcast_batch_time(t - r, z_t)
|
delta = self._broadcast_batch_time(t - r, z_t)
|
||||||
return z_t - delta * u
|
return z_t - delta * u
|
||||||
|
|
||||||
|
def _normalize_qpos_for_lewm(self, qpos: torch.Tensor) -> torch.Tensor:
|
||||||
|
if not self.normalization.enabled:
|
||||||
|
return qpos
|
||||||
|
|
||||||
|
qpos_mean = getattr(self.normalization, 'qpos_mean', None)
|
||||||
|
qpos_std = getattr(self.normalization, 'qpos_std', None)
|
||||||
|
if qpos_mean is not None and qpos_std is not None:
|
||||||
|
return (qpos - qpos_mean) / qpos_std
|
||||||
|
if isinstance(self.dataset_stats, dict):
|
||||||
|
mean = self.dataset_stats.get('qpos_mean', None)
|
||||||
|
std = self.dataset_stats.get('qpos_std', None)
|
||||||
|
if mean is not None and std is not None:
|
||||||
|
mean = torch.as_tensor(mean, dtype=qpos.dtype, device=qpos.device)
|
||||||
|
std = torch.as_tensor(std, dtype=qpos.dtype, device=qpos.device)
|
||||||
|
return (qpos - mean) / std
|
||||||
|
return self.normalization.normalize_qpos(qpos)
|
||||||
|
|
||||||
|
def _project_lewm_future_tokens(self, predicted_tokens: torch.Tensor) -> torch.Tensor:
|
||||||
|
if predicted_tokens.ndim != 3:
|
||||||
|
raise ValueError(
|
||||||
|
f"expected predicted future tokens to be 3D, got rank {predicted_tokens.ndim}"
|
||||||
|
)
|
||||||
|
batch_size, token_count, token_dim = predicted_tokens.shape
|
||||||
|
flattened = predicted_tokens.reshape(batch_size * token_count, token_dim)
|
||||||
|
projected = self.lewm_pred_projector(flattened)
|
||||||
|
if projected.ndim != 2:
|
||||||
|
raise ValueError(
|
||||||
|
f"expected lewm_pred_projector to return rank-2 tensors, got rank {projected.ndim}"
|
||||||
|
)
|
||||||
|
return projected.reshape(batch_size, token_count, projected.shape[-1])
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _load_checkpoint_payload(
|
||||||
|
checkpoint_or_path: str | Path | Mapping[str, Any],
|
||||||
|
) -> Mapping[str, torch.Tensor]:
|
||||||
|
if isinstance(checkpoint_or_path, (str, Path)):
|
||||||
|
payload = torch.load(Path(checkpoint_or_path), map_location='cpu', weights_only=False)
|
||||||
|
else:
|
||||||
|
payload = checkpoint_or_path
|
||||||
|
state_dict = payload.get('state_dict', payload)
|
||||||
|
if not isinstance(state_dict, Mapping):
|
||||||
|
raise TypeError('checkpoint payload must contain a mapping state_dict')
|
||||||
|
return state_dict
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _extract_prefixed_state_dict(
|
||||||
|
state_dict: Mapping[str, torch.Tensor],
|
||||||
|
prefix: str,
|
||||||
|
) -> Dict[str, torch.Tensor]:
|
||||||
|
extracted = {
|
||||||
|
key[len(prefix):]: value
|
||||||
|
for key, value in state_dict.items()
|
||||||
|
if key.startswith(prefix)
|
||||||
|
}
|
||||||
|
if not extracted:
|
||||||
|
raise KeyError(f"checkpoint missing parameters with prefix {prefix!r}")
|
||||||
|
return extracted
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _adapt_and_load_state_dict(
|
||||||
|
module: nn.Module,
|
||||||
|
incoming_state_dict: Mapping[str, torch.Tensor],
|
||||||
|
*,
|
||||||
|
query_key: str = 'query_tokens',
|
||||||
|
pos_key: str = 'pos_embedding',
|
||||||
|
) -> None:
|
||||||
|
current_state_dict = module.state_dict()
|
||||||
|
adapted_state_dict = dict(current_state_dict)
|
||||||
|
for key, current_tensor in current_state_dict.items():
|
||||||
|
if key not in incoming_state_dict:
|
||||||
|
continue
|
||||||
|
source_tensor = incoming_state_dict[key]
|
||||||
|
if source_tensor.shape == current_tensor.shape:
|
||||||
|
adapted_state_dict[key] = source_tensor
|
||||||
|
continue
|
||||||
|
|
||||||
|
if key in {query_key, pos_key} and source_tensor.ndim == current_tensor.ndim:
|
||||||
|
patched = current_tensor.clone()
|
||||||
|
if key == query_key:
|
||||||
|
copy_count = min(source_tensor.shape[1], current_tensor.shape[1])
|
||||||
|
patched[:, :copy_count, ...] = source_tensor[:, :copy_count, ...]
|
||||||
|
if copy_count < current_tensor.shape[1] and copy_count > 0:
|
||||||
|
patched[:, copy_count:, ...] = source_tensor[:, copy_count - 1:copy_count, ...]
|
||||||
|
else:
|
||||||
|
copy_count = min(source_tensor.shape[1], current_tensor.shape[1])
|
||||||
|
patched[:, :copy_count, ...] = source_tensor[:, :copy_count, ...]
|
||||||
|
if copy_count < current_tensor.shape[1] and copy_count > 0:
|
||||||
|
patched[:, copy_count:, ...] = source_tensor[:, copy_count - 1:copy_count, ...]
|
||||||
|
adapted_state_dict[key] = patched
|
||||||
|
|
||||||
|
module.load_state_dict(adapted_state_dict, strict=True)
|
||||||
|
|
||||||
|
def load_lewm_pretrained_components(
|
||||||
|
self,
|
||||||
|
checkpoint_or_path: str | Path | Mapping[str, Any],
|
||||||
|
) -> None:
|
||||||
|
state_dict = self._load_checkpoint_payload(checkpoint_or_path)
|
||||||
|
|
||||||
|
if hasattr(self.vision_encoder, 'load_lewm_checkpoint'):
|
||||||
|
self.vision_encoder.load_lewm_checkpoint({'state_dict': state_dict})
|
||||||
|
else:
|
||||||
|
vision_state_dict = self._extract_prefixed_state_dict(state_dict, 'model.encoder.')
|
||||||
|
self.vision_encoder.load_state_dict(vision_state_dict, strict=True)
|
||||||
|
|
||||||
|
state_encoder_state_dict = self._extract_prefixed_state_dict(state_dict, 'model.state_encoder.')
|
||||||
|
self.state_encoder.load_state_dict(state_encoder_state_dict, strict=True)
|
||||||
|
|
||||||
|
projector_state_dict = self._extract_prefixed_state_dict(state_dict, 'model.projector.proj.')
|
||||||
|
mapped_projector_state_dict = {
|
||||||
|
f'linear.{key}': value
|
||||||
|
for key, value in projector_state_dict.items()
|
||||||
|
}
|
||||||
|
self.cond_projector.load_state_dict(mapped_projector_state_dict, strict=True)
|
||||||
|
|
||||||
|
if self.lewm_predictor is not None:
|
||||||
|
predictor_state_dict = self._extract_prefixed_state_dict(state_dict, 'model.predictor.')
|
||||||
|
self._adapt_and_load_state_dict(self.lewm_predictor, predictor_state_dict)
|
||||||
|
|
||||||
|
if self.lewm_pred_projector is not None:
|
||||||
|
pred_projector_state_dict = self._extract_prefixed_state_dict(state_dict, 'model.pred_proj.')
|
||||||
|
self.lewm_pred_projector.load_state_dict(pred_projector_state_dict, strict=True)
|
||||||
|
|
||||||
|
def _build_full_condition(
|
||||||
|
self,
|
||||||
|
images,
|
||||||
|
proprioception,
|
||||||
|
*,
|
||||||
|
lewm_images=None,
|
||||||
|
lewm_proprioception=None,
|
||||||
|
):
|
||||||
|
normalized_proprioception = self.normalization.normalize_qpos(proprioception)
|
||||||
|
history_cond = self._build_cond(images, normalized_proprioception)
|
||||||
|
predicted_future_tokens = None
|
||||||
|
lewm_history_cond = None
|
||||||
|
|
||||||
|
if self.lewm_predictor is None:
|
||||||
|
return history_cond, predicted_future_tokens, lewm_history_cond
|
||||||
|
|
||||||
|
lewm_images = lewm_images if lewm_images is not None else images
|
||||||
|
lewm_proprioception = (
|
||||||
|
lewm_proprioception if lewm_proprioception is not None else proprioception
|
||||||
|
)
|
||||||
|
lewm_history_cond = self._build_cond(
|
||||||
|
lewm_images,
|
||||||
|
self._normalize_qpos_for_lewm(lewm_proprioception),
|
||||||
|
)
|
||||||
|
predicted_future_tokens = self.lewm_predictor(lewm_history_cond)
|
||||||
|
predicted_future_tokens = self._project_lewm_future_tokens(predicted_future_tokens)
|
||||||
|
cond = torch.cat([history_cond, predicted_future_tokens], dim=1)
|
||||||
|
if cond.shape[1] != self.condition_sequence_length:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"完整条件序列长度不匹配: got {cond.shape[1]}, expected {self.condition_sequence_length}"
|
||||||
|
)
|
||||||
|
if cond.shape[-1] != self.per_step_cond_dim:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"完整条件维度不匹配: got {cond.shape[-1]}, expected {self.per_step_cond_dim}"
|
||||||
|
)
|
||||||
|
return cond, predicted_future_tokens, lewm_history_cond
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _masked_mse_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
|
||||||
|
return F.mse_loss(pred, target)
|
||||||
|
|
||||||
def compute_loss(self, batch):
|
def compute_loss(self, batch):
|
||||||
actions, states, images = batch['action'], batch['qpos'], batch['images']
|
actions, states, images = batch['action'], batch['qpos'], batch['images']
|
||||||
action_is_pad = batch.get('action_is_pad', None)
|
action_is_pad = batch.get('action_is_pad', None)
|
||||||
batch_size = actions.shape[0]
|
batch_size = actions.shape[0]
|
||||||
|
|
||||||
states = self.normalization.normalize_qpos(states)
|
|
||||||
actions = self.normalization.normalize_action(actions)
|
actions = self.normalization.normalize_action(actions)
|
||||||
cond = self._build_cond(images, states)
|
cond, predicted_future_tokens, lewm_history_cond = self._build_full_condition(
|
||||||
|
images,
|
||||||
|
states,
|
||||||
|
lewm_images=batch.get('lewm_images', None),
|
||||||
|
lewm_proprioception=batch.get('lewm_qpos', None),
|
||||||
|
)
|
||||||
|
|
||||||
x = actions
|
x = actions
|
||||||
e = torch.randn_like(x)
|
e = torch.randn_like(x)
|
||||||
@@ -146,16 +361,103 @@ class IMFVLAAgent(VLAAgent):
|
|||||||
if action_is_pad is not None:
|
if action_is_pad is not None:
|
||||||
mask = (~action_is_pad).unsqueeze(-1).to(loss.dtype)
|
mask = (~action_is_pad).unsqueeze(-1).to(loss.dtype)
|
||||||
valid_count = mask.sum() * loss.shape[-1]
|
valid_count = mask.sum() * loss.shape[-1]
|
||||||
loss = (loss * mask).sum() / valid_count.clamp_min(1.0)
|
action_loss = (loss * mask).sum() / valid_count.clamp_min(1.0)
|
||||||
else:
|
else:
|
||||||
loss = loss.mean()
|
action_loss = loss.mean()
|
||||||
return loss
|
|
||||||
|
lewm_pred_loss = torch.zeros((), device=action_loss.device, dtype=action_loss.dtype)
|
||||||
|
lewm_sigreg_loss = torch.zeros((), device=action_loss.device, dtype=action_loss.dtype)
|
||||||
|
if predicted_future_tokens is not None:
|
||||||
|
lewm_future_images = batch.get('lewm_future_images', None)
|
||||||
|
lewm_future_qpos = batch.get('lewm_future_qpos', None)
|
||||||
|
if lewm_future_images is not None and lewm_future_qpos is not None:
|
||||||
|
future_target = self._build_cond(
|
||||||
|
lewm_future_images,
|
||||||
|
self._normalize_qpos_for_lewm(lewm_future_qpos),
|
||||||
|
)
|
||||||
|
lewm_pred_loss = self._masked_mse_loss(predicted_future_tokens, future_target)
|
||||||
|
if self.lewm_sigreg is not None and lewm_history_cond is not None:
|
||||||
|
lewm_sigreg_loss = self.lewm_sigreg(lewm_history_cond.transpose(0, 1))
|
||||||
|
|
||||||
|
lewm_loss = lewm_pred_loss + self.lewm_sigreg_weight * lewm_sigreg_loss
|
||||||
|
total_loss = action_loss + self.lewm_loss_weight * lewm_loss
|
||||||
|
self._last_loss_breakdown = {
|
||||||
|
'action_loss': float(action_loss.detach().item()),
|
||||||
|
'lewm_pred_loss': float(lewm_pred_loss.detach().item()),
|
||||||
|
'lewm_sigreg_loss': float(lewm_sigreg_loss.detach().item()),
|
||||||
|
'lewm_loss': float(lewm_loss.detach().item()),
|
||||||
|
'loss': float(total_loss.detach().item()),
|
||||||
|
}
|
||||||
|
return total_loss
|
||||||
|
|
||||||
|
def get_last_loss_breakdown(self) -> Dict[str, float]:
|
||||||
|
return dict(self._last_loss_breakdown)
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
super().reset()
|
||||||
|
if self.lewm_predictor is not None:
|
||||||
|
self._queues['lewm_qpos'] = deque(maxlen=self.lewm_history_horizon)
|
||||||
|
self._queues['lewm_images'] = deque(maxlen=self.lewm_history_horizon)
|
||||||
|
|
||||||
|
def _populate_queues(self, observation: Dict[str, torch.Tensor]) -> None:
|
||||||
|
super()._populate_queues(observation)
|
||||||
|
if self.lewm_predictor is None:
|
||||||
|
return
|
||||||
|
if 'qpos' in observation:
|
||||||
|
self._queues['lewm_qpos'].append(observation['qpos'].clone())
|
||||||
|
if 'images' in observation:
|
||||||
|
ordered_images = self._order_images(observation['images'])
|
||||||
|
self._queues['lewm_images'].append({k: v.clone() for k, v in ordered_images.items()})
|
||||||
|
|
||||||
|
def _prepare_observation_batch(self) -> Dict[str, torch.Tensor]:
|
||||||
|
batch = super()._prepare_observation_batch()
|
||||||
|
if self.lewm_predictor is None:
|
||||||
|
return batch
|
||||||
|
|
||||||
|
qpos_list = list(self._queues['lewm_qpos'])
|
||||||
|
images_list = list(self._queues['lewm_images'])
|
||||||
|
if len(qpos_list) == 0 or len(images_list) == 0:
|
||||||
|
raise ValueError("LeWM 观测队列为空,请先调用 _populate_queues 添加观测")
|
||||||
|
while len(qpos_list) < self.lewm_history_horizon:
|
||||||
|
qpos_list.append(qpos_list[-1])
|
||||||
|
while len(images_list) < self.lewm_history_horizon:
|
||||||
|
images_list.append(images_list[-1])
|
||||||
|
|
||||||
|
batch['lewm_qpos'] = torch.stack(qpos_list, dim=0).unsqueeze(0)
|
||||||
|
batch['lewm_images'] = {}
|
||||||
|
camera_names = self.camera_names if self.camera_names is not None else tuple(sorted(images_list[0].keys()))
|
||||||
|
for cam_name in camera_names:
|
||||||
|
batch['lewm_images'][cam_name] = torch.stack(
|
||||||
|
[img[cam_name] for img in images_list],
|
||||||
|
dim=0,
|
||||||
|
).unsqueeze(0)
|
||||||
|
return batch
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def predict_action(self, images, proprioception):
|
def predict_action_chunk(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
|
||||||
|
return self.predict_action(
|
||||||
|
batch['images'],
|
||||||
|
batch['qpos'],
|
||||||
|
lewm_images=batch.get('lewm_images', None),
|
||||||
|
lewm_proprioception=batch.get('lewm_qpos', None),
|
||||||
|
)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def predict_action(
|
||||||
|
self,
|
||||||
|
images,
|
||||||
|
proprioception,
|
||||||
|
*,
|
||||||
|
lewm_images=None,
|
||||||
|
lewm_proprioception=None,
|
||||||
|
):
|
||||||
batch_size = proprioception.shape[0]
|
batch_size = proprioception.shape[0]
|
||||||
proprioception = self.normalization.normalize_qpos(proprioception)
|
cond, _predicted_future_tokens, _lewm_history_cond = self._build_full_condition(
|
||||||
cond = self._build_cond(images, proprioception)
|
images,
|
||||||
|
proprioception,
|
||||||
|
lewm_images=lewm_images,
|
||||||
|
lewm_proprioception=lewm_proprioception,
|
||||||
|
)
|
||||||
z_t = torch.randn((batch_size, self.pred_horizon, self.action_dim), device=cond.device, dtype=cond.dtype)
|
z_t = torch.randn((batch_size, self.pred_horizon, self.action_dim), device=cond.device, dtype=cond.dtype)
|
||||||
action = self._sample_one_step(z_t, cond=cond)
|
action = self._sample_one_step(z_t, cond=cond)
|
||||||
return self.normalization.denormalize_action(action)
|
return self.normalization.denormalize_action(action)
|
||||||
|
|||||||
77
roboimi/vla/conf/agent/lewm_resnet_query_imf_attnres.yaml
Normal file
77
roboimi/vla/conf/agent/lewm_resnet_query_imf_attnres.yaml
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
# @package agent
|
||||||
|
defaults:
|
||||||
|
- /backbone@vision_backbone: lewm_resnet_query_fusion
|
||||||
|
- /modules@state_encoder: lewm_state_encoder
|
||||||
|
- /modules@action_encoder: identity_action_encoder
|
||||||
|
- /modules@cond_projector: linear_condition_projector
|
||||||
|
- /head: imf_transformer1d
|
||||||
|
- _self_
|
||||||
|
|
||||||
|
_target_: roboimi.vla.agent_imf.IMFVLAAgent
|
||||||
|
|
||||||
|
action_dim: 16
|
||||||
|
obs_dim: 16
|
||||||
|
normalization_type: "min_max"
|
||||||
|
pred_horizon: 8
|
||||||
|
obs_horizon: 2
|
||||||
|
num_action_steps: 8
|
||||||
|
camera_names: ${data.camera_names}
|
||||||
|
num_cams: 3
|
||||||
|
|
||||||
|
vision_backbone:
|
||||||
|
camera_names: ${agent.camera_names}
|
||||||
|
num_views: ${agent.num_cams}
|
||||||
|
|
||||||
|
cond_projector:
|
||||||
|
output_dim: 288
|
||||||
|
|
||||||
|
lewm_history_horizon: 3
|
||||||
|
lewm_query_offsets: [8]
|
||||||
|
extra_condition_tokens: ${len:${agent.lewm_query_offsets}}
|
||||||
|
lewm_loss_weight: 1.0
|
||||||
|
lewm_sigreg_weight: 0.09
|
||||||
|
lewm_pretrained_ckpt: null
|
||||||
|
|
||||||
|
lewm_sigreg:
|
||||||
|
_target_: roboimi.vla.models.backbones.lewm_resnet_query_fusion.SIGReg
|
||||||
|
knots: 17
|
||||||
|
num_proj: 1024
|
||||||
|
|
||||||
|
lewm_predictor:
|
||||||
|
_target_: roboimi.vla.models.backbones.lewm_resnet_query_fusion.QueryTokenPredictor
|
||||||
|
num_frames: ${agent.lewm_history_horizon}
|
||||||
|
query_offsets: ${agent.lewm_query_offsets}
|
||||||
|
input_dim: ${agent.cond_projector.output_dim}
|
||||||
|
hidden_dim: ${agent.cond_projector.output_dim}
|
||||||
|
output_dim: ${agent.cond_projector.output_dim}
|
||||||
|
depth: 6
|
||||||
|
heads: 16
|
||||||
|
mlp_dim: 2048
|
||||||
|
dim_head: 64
|
||||||
|
dropout: 0.1
|
||||||
|
emb_dropout: 0.0
|
||||||
|
|
||||||
|
lewm_pred_projector:
|
||||||
|
_target_: roboimi.vla.models.backbones.lewm_resnet_query_fusion.LeWMProjectorMLP
|
||||||
|
input_dim: ${agent.cond_projector.output_dim}
|
||||||
|
hidden_dim: 2048
|
||||||
|
output_dim: ${agent.cond_projector.output_dim}
|
||||||
|
|
||||||
|
diffusion_steps: 100
|
||||||
|
inference_steps: 1
|
||||||
|
head_type: "transformer"
|
||||||
|
|
||||||
|
head:
|
||||||
|
input_dim: ${agent.action_dim}
|
||||||
|
output_dim: ${agent.action_dim}
|
||||||
|
horizon: ${agent.pred_horizon}
|
||||||
|
n_obs_steps: ${agent.obs_horizon}
|
||||||
|
cond_dim: 288
|
||||||
|
n_emb: 384
|
||||||
|
causal_attn: false
|
||||||
|
time_as_cond: true
|
||||||
|
obs_as_cond: true
|
||||||
|
n_cond_layers: 0
|
||||||
|
backbone_type: attnres_full
|
||||||
|
n_head: 1
|
||||||
|
n_kv_head: 1
|
||||||
7
roboimi/vla/conf/backbone/lewm_resnet_query_fusion.yaml
Normal file
7
roboimi/vla/conf/backbone/lewm_resnet_query_fusion.yaml
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
_target_: roboimi.vla.models.backbones.lewm_resnet_query_fusion.LeWMMultiViewResNetBackbone
|
||||||
|
|
||||||
|
view_feature_dim: 96
|
||||||
|
num_views: ${agent.num_cams}
|
||||||
|
view_encoder_mode: separate
|
||||||
|
camera_names: ${agent.camera_names}
|
||||||
|
checkpoint_path: null
|
||||||
5
roboimi/vla/conf/modules/lewm_state_encoder.yaml
Normal file
5
roboimi/vla/conf/modules/lewm_state_encoder.yaml
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
_target_: roboimi.vla.modules.encoders.LeWMStateEncoder
|
||||||
|
|
||||||
|
input_dim: ${agent.obs_dim}
|
||||||
|
hidden_dim: 256
|
||||||
|
output_dim: 64
|
||||||
@@ -24,6 +24,8 @@ class SimpleRobotDataset(Dataset):
|
|||||||
camera_names: List[str] = None,
|
camera_names: List[str] = None,
|
||||||
image_resize_shape: Optional[Sequence[int]] = (224, 224),
|
image_resize_shape: Optional[Sequence[int]] = (224, 224),
|
||||||
max_open_files: int = 64,
|
max_open_files: int = 64,
|
||||||
|
lewm_history_horizon: Optional[int] = None,
|
||||||
|
lewm_query_offsets: Optional[Sequence[int]] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@@ -42,6 +44,13 @@ class SimpleRobotDataset(Dataset):
|
|||||||
self.obs_horizon = obs_horizon
|
self.obs_horizon = obs_horizon
|
||||||
self.pred_horizon = pred_horizon
|
self.pred_horizon = pred_horizon
|
||||||
self.camera_names = camera_names or []
|
self.camera_names = camera_names or []
|
||||||
|
self.lewm_history_horizon = (
|
||||||
|
int(lewm_history_horizon) if lewm_history_horizon is not None else None
|
||||||
|
)
|
||||||
|
self.lewm_query_offsets = (
|
||||||
|
tuple(int(offset) for offset in lewm_query_offsets)
|
||||||
|
if lewm_query_offsets is not None else ()
|
||||||
|
)
|
||||||
self.image_resize_shape = (
|
self.image_resize_shape = (
|
||||||
tuple(int(v) for v in image_resize_shape)
|
tuple(int(v) for v in image_resize_shape)
|
||||||
if image_resize_shape is not None else None
|
if image_resize_shape is not None else None
|
||||||
@@ -220,6 +229,60 @@ class SimpleRobotDataset(Dataset):
|
|||||||
for cam_name in self.camera_names:
|
for cam_name in self.camera_names:
|
||||||
result[f"observation.{cam_name}"] = torch.stack(observations[f"observation.{cam_name}"])
|
result[f"observation.{cam_name}"] = torch.stack(observations[f"observation.{cam_name}"])
|
||||||
|
|
||||||
|
if self.lewm_history_horizon is not None and self.lewm_history_horizon > 0:
|
||||||
|
lewm_observations = {
|
||||||
|
"state": [],
|
||||||
|
}
|
||||||
|
for cam_name in self.camera_names:
|
||||||
|
lewm_observations[f"observation.{cam_name}"] = []
|
||||||
|
|
||||||
|
for delta in range(-self.lewm_history_horizon + 1, 1):
|
||||||
|
target_idx = idx + delta
|
||||||
|
if ep_start <= target_idx <= ep_end:
|
||||||
|
target_frame = self._load_frame(target_idx)
|
||||||
|
else:
|
||||||
|
boundary_idx = ep_start if target_idx < ep_start else ep_end
|
||||||
|
target_frame = self._load_frame(boundary_idx)
|
||||||
|
|
||||||
|
lewm_observations["state"].append(target_frame["observation.state"])
|
||||||
|
for cam_name in self.camera_names:
|
||||||
|
lewm_observations[f"observation.{cam_name}"].append(
|
||||||
|
target_frame[f"observation.{cam_name}"]
|
||||||
|
)
|
||||||
|
|
||||||
|
result["lewm.observation.state"] = torch.stack(lewm_observations["state"])
|
||||||
|
for cam_name in self.camera_names:
|
||||||
|
result[f"lewm.observation.{cam_name}"] = torch.stack(
|
||||||
|
lewm_observations[f"observation.{cam_name}"]
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.lewm_query_offsets:
|
||||||
|
lewm_future = {
|
||||||
|
"state": [],
|
||||||
|
}
|
||||||
|
for cam_name in self.camera_names:
|
||||||
|
lewm_future[f"observation.{cam_name}"] = []
|
||||||
|
|
||||||
|
for offset in self.lewm_query_offsets:
|
||||||
|
target_idx = idx + offset
|
||||||
|
if ep_start <= target_idx <= ep_end:
|
||||||
|
target_frame = self._load_frame(target_idx)
|
||||||
|
else:
|
||||||
|
boundary_idx = ep_start if target_idx < ep_start else ep_end
|
||||||
|
target_frame = self._load_frame(boundary_idx)
|
||||||
|
|
||||||
|
lewm_future["state"].append(target_frame["observation.state"])
|
||||||
|
for cam_name in self.camera_names:
|
||||||
|
lewm_future[f"observation.{cam_name}"].append(
|
||||||
|
target_frame[f"observation.{cam_name}"]
|
||||||
|
)
|
||||||
|
|
||||||
|
result["lewm.future.state"] = torch.stack(lewm_future["state"])
|
||||||
|
for cam_name in self.camera_names:
|
||||||
|
result[f"lewm.future.{cam_name}"] = torch.stack(
|
||||||
|
lewm_future[f"observation.{cam_name}"]
|
||||||
|
)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|||||||
@@ -1,5 +1,14 @@
|
|||||||
# Backbone models
|
# Backbone models
|
||||||
__all__ = ["LEWMViTBackbone", "ResNetBackbone", "ResNetDiffusionBackbone", "SigLIP2DiffusionBackbone"]
|
__all__ = [
|
||||||
|
"LEWMViTBackbone",
|
||||||
|
"LeWMMultiViewResNetBackbone",
|
||||||
|
"QueryTokenPredictor",
|
||||||
|
"LeWMProjectorMLP",
|
||||||
|
"SIGReg",
|
||||||
|
"ResNetBackbone",
|
||||||
|
"ResNetDiffusionBackbone",
|
||||||
|
"SigLIP2DiffusionBackbone",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def __getattr__(name):
|
def __getattr__(name):
|
||||||
@@ -9,6 +18,19 @@ def __getattr__(name):
|
|||||||
if name == "SigLIP2DiffusionBackbone":
|
if name == "SigLIP2DiffusionBackbone":
|
||||||
from .siglip2_diffusion_backbone import SigLIP2DiffusionBackbone
|
from .siglip2_diffusion_backbone import SigLIP2DiffusionBackbone
|
||||||
return SigLIP2DiffusionBackbone
|
return SigLIP2DiffusionBackbone
|
||||||
|
if name in {"LeWMMultiViewResNetBackbone", "QueryTokenPredictor", "LeWMProjectorMLP", "SIGReg"}:
|
||||||
|
from .lewm_resnet_query_fusion import (
|
||||||
|
LeWMMultiViewResNetBackbone,
|
||||||
|
QueryTokenPredictor,
|
||||||
|
LeWMProjectorMLP,
|
||||||
|
SIGReg,
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"LeWMMultiViewResNetBackbone": LeWMMultiViewResNetBackbone,
|
||||||
|
"QueryTokenPredictor": QueryTokenPredictor,
|
||||||
|
"LeWMProjectorMLP": LeWMProjectorMLP,
|
||||||
|
"SIGReg": SIGReg,
|
||||||
|
}[name]
|
||||||
if name in {"ResNetBackbone", "ResNetDiffusionBackbone"}:
|
if name in {"ResNetBackbone", "ResNetDiffusionBackbone"}:
|
||||||
from .resnet_diffusion import ResNetDiffusionBackbone
|
from .resnet_diffusion import ResNetDiffusionBackbone
|
||||||
return ResNetDiffusionBackbone
|
return ResNetDiffusionBackbone
|
||||||
|
|||||||
409
roboimi/vla/models/backbones/lewm_resnet_query_fusion.py
Normal file
409
roboimi/vla/models/backbones/lewm_resnet_query_fusion.py
Normal file
@@ -0,0 +1,409 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, Mapping, Optional, Sequence
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from einops import rearrange
|
||||||
|
from torch import nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torchvision import models
|
||||||
|
|
||||||
|
from roboimi.vla.core.interfaces import VLABackbone
|
||||||
|
|
||||||
|
|
||||||
|
class SpatialSoftmax2D(nn.Module):
|
||||||
|
"""Convert a feature map into expected 2D keypoint coordinates per channel."""
|
||||||
|
|
||||||
|
def forward(self, feature_map):
|
||||||
|
if feature_map.ndim != 4:
|
||||||
|
raise ValueError(
|
||||||
|
f"SpatialSoftmax2D expects a 4D tensor, got rank {feature_map.ndim}"
|
||||||
|
)
|
||||||
|
|
||||||
|
batch, channels, height, width = feature_map.shape
|
||||||
|
scores = feature_map.reshape(batch, channels, height * width)
|
||||||
|
attention = F.softmax(scores, dim=-1)
|
||||||
|
|
||||||
|
ys = torch.linspace(-1.0, 1.0, height, device=feature_map.device, dtype=feature_map.dtype)
|
||||||
|
xs = torch.linspace(-1.0, 1.0, width, device=feature_map.device, dtype=feature_map.dtype)
|
||||||
|
grid_y, grid_x = torch.meshgrid(ys, xs, indexing="ij")
|
||||||
|
grid_x = grid_x.reshape(1, 1, height * width)
|
||||||
|
grid_y = grid_y.reshape(1, 1, height * width)
|
||||||
|
|
||||||
|
expected_x = (attention * grid_x).sum(dim=-1)
|
||||||
|
expected_y = (attention * grid_y).sum(dim=-1)
|
||||||
|
return torch.cat([expected_x, expected_y], dim=-1)
|
||||||
|
|
||||||
|
|
||||||
|
class ResNet18SpatialEncoder(nn.Module):
|
||||||
|
"""Encode one camera view into a fixed-dimensional spatial-softmax embedding."""
|
||||||
|
|
||||||
|
def __init__(self, view_feature_dim=96):
|
||||||
|
super().__init__()
|
||||||
|
if view_feature_dim % 2 != 0:
|
||||||
|
raise ValueError("view_feature_dim must be even for spatial softmax features")
|
||||||
|
|
||||||
|
backbone = models.resnet18(weights=None)
|
||||||
|
if all(
|
||||||
|
hasattr(backbone, name)
|
||||||
|
for name in ("conv1", "bn1", "relu", "maxpool", "layer1", "layer2", "layer3", "layer4")
|
||||||
|
):
|
||||||
|
self.backbone = nn.Sequential(
|
||||||
|
backbone.conv1,
|
||||||
|
backbone.bn1,
|
||||||
|
backbone.relu,
|
||||||
|
backbone.maxpool,
|
||||||
|
backbone.layer1,
|
||||||
|
backbone.layer2,
|
||||||
|
backbone.layer3,
|
||||||
|
backbone.layer4,
|
||||||
|
)
|
||||||
|
feature_channels = 512
|
||||||
|
else:
|
||||||
|
children = list(backbone.children())
|
||||||
|
if len(children) < 1:
|
||||||
|
raise ValueError("resnet18 backbone must expose child modules")
|
||||||
|
truncated = children[:-2] if len(children) > 2 else children
|
||||||
|
self.backbone = nn.Sequential(*truncated)
|
||||||
|
with torch.no_grad():
|
||||||
|
dummy = torch.zeros(1, 3, 16, 16)
|
||||||
|
feature_channels = int(self.backbone(dummy).shape[1])
|
||||||
|
|
||||||
|
self.proj = nn.Conv2d(feature_channels, view_feature_dim // 2, kernel_size=1)
|
||||||
|
self.spatial_softmax = SpatialSoftmax2D()
|
||||||
|
self.output_dim = int(view_feature_dim)
|
||||||
|
|
||||||
|
def forward(self, pixels):
|
||||||
|
if pixels.ndim not in (4, 5):
|
||||||
|
raise ValueError(
|
||||||
|
f"ResNet18SpatialEncoder expects a 4D or 5D tensor, got rank {pixels.ndim}"
|
||||||
|
)
|
||||||
|
|
||||||
|
needs_unflatten = pixels.ndim == 5
|
||||||
|
if needs_unflatten:
|
||||||
|
batch, steps, channels, height, width = pixels.shape
|
||||||
|
pixels = rearrange(pixels, "b t c h w -> (b t) c h w")
|
||||||
|
|
||||||
|
features = self.backbone(pixels.float())
|
||||||
|
features = self.proj(features)
|
||||||
|
embeddings = self.spatial_softmax(features)
|
||||||
|
|
||||||
|
if needs_unflatten:
|
||||||
|
embeddings = rearrange(embeddings, "(b t) d -> b t d", b=batch, t=steps)
|
||||||
|
return embeddings
|
||||||
|
|
||||||
|
|
||||||
|
class LeWMMultiViewResNetBackbone(VLABackbone):
|
||||||
|
"""RoboIMI-side LeWM multiview ResNet spatial-softmax encoder."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
view_feature_dim: int = 96,
|
||||||
|
num_views: int = 3,
|
||||||
|
view_encoder_mode: str = "shared",
|
||||||
|
camera_names: Sequence[str] = ("r_vis", "top", "front"),
|
||||||
|
checkpoint_path: str | Path | None = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
if view_encoder_mode not in {"shared", "separate"}:
|
||||||
|
raise ValueError(
|
||||||
|
f"view_encoder_mode must be 'shared' or 'separate', got {view_encoder_mode}"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.view_feature_dim = int(view_feature_dim)
|
||||||
|
self.num_views = int(num_views)
|
||||||
|
self.view_encoder_mode = view_encoder_mode
|
||||||
|
self.camera_names = tuple(camera_names)
|
||||||
|
if len(self.camera_names) != self.num_views:
|
||||||
|
raise ValueError(
|
||||||
|
f"camera_names length({len(self.camera_names)}) must equal num_views({self.num_views})"
|
||||||
|
)
|
||||||
|
self.output_dim = self.view_feature_dim * self.num_views
|
||||||
|
self.joint_output_dim = self.output_dim
|
||||||
|
self.tokens_per_step = 1
|
||||||
|
|
||||||
|
if view_encoder_mode == "shared":
|
||||||
|
self.single_view_encoder = ResNet18SpatialEncoder(
|
||||||
|
view_feature_dim=view_feature_dim
|
||||||
|
)
|
||||||
|
self.view_encoders = None
|
||||||
|
else:
|
||||||
|
self.single_view_encoder = None
|
||||||
|
self.view_encoders = nn.ModuleList(
|
||||||
|
[
|
||||||
|
ResNet18SpatialEncoder(view_feature_dim=view_feature_dim)
|
||||||
|
for _ in range(num_views)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
if checkpoint_path is not None:
|
||||||
|
self.load_lewm_checkpoint(checkpoint_path)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _unwrap_state_dict(payload: Mapping[str, Any]) -> Mapping[str, torch.Tensor]:
|
||||||
|
state_dict = payload.get("state_dict", payload)
|
||||||
|
if not isinstance(state_dict, Mapping):
|
||||||
|
raise TypeError("checkpoint payload must contain a mapping state_dict")
|
||||||
|
return state_dict
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _extract_prefixed_state_dict(
|
||||||
|
state_dict: Mapping[str, torch.Tensor],
|
||||||
|
prefix: str,
|
||||||
|
) -> Dict[str, torch.Tensor]:
|
||||||
|
extracted = {
|
||||||
|
key[len(prefix):]: value
|
||||||
|
for key, value in state_dict.items()
|
||||||
|
if key.startswith(prefix)
|
||||||
|
}
|
||||||
|
if not extracted:
|
||||||
|
raise KeyError(f"checkpoint missing parameters with prefix {prefix!r}")
|
||||||
|
return extracted
|
||||||
|
|
||||||
|
def load_lewm_checkpoint(self, checkpoint_or_path: str | Path | Mapping[str, Any]) -> None:
|
||||||
|
if isinstance(checkpoint_or_path, (str, Path)):
|
||||||
|
payload = torch.load(Path(checkpoint_or_path), map_location="cpu", weights_only=False)
|
||||||
|
else:
|
||||||
|
payload = checkpoint_or_path
|
||||||
|
state_dict = self._unwrap_state_dict(payload)
|
||||||
|
encoder_state_dict = self._extract_prefixed_state_dict(state_dict, "model.encoder.")
|
||||||
|
self.load_state_dict(encoder_state_dict, strict=True)
|
||||||
|
|
||||||
|
def forward(self, images):
|
||||||
|
missing = [camera_name for camera_name in self.camera_names if camera_name not in images]
|
||||||
|
if missing:
|
||||||
|
raise ValueError(
|
||||||
|
f"image input missing required cameras. missing={missing}, expected={list(self.camera_names)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
first_image = images[self.camera_names[0]]
|
||||||
|
batch_size, steps = first_image.shape[:2]
|
||||||
|
view_embeddings = []
|
||||||
|
if self.view_encoder_mode == "shared":
|
||||||
|
for camera_name in self.camera_names:
|
||||||
|
view_embeddings.append(self.single_view_encoder(images[camera_name]))
|
||||||
|
else:
|
||||||
|
for single_view_encoder, camera_name in zip(self.view_encoders, self.camera_names):
|
||||||
|
view_embeddings.append(single_view_encoder(images[camera_name]))
|
||||||
|
|
||||||
|
embeddings = torch.cat(view_embeddings, dim=-1)
|
||||||
|
return embeddings.reshape(batch_size, steps, self.output_dim)
|
||||||
|
|
||||||
|
|
||||||
|
class FeedForward(nn.Module):
|
||||||
|
def __init__(self, dim, hidden_dim, dropout=0.0):
|
||||||
|
super().__init__()
|
||||||
|
self.net = nn.Sequential(
|
||||||
|
nn.LayerNorm(dim),
|
||||||
|
nn.Linear(dim, hidden_dim),
|
||||||
|
nn.GELU(),
|
||||||
|
nn.Dropout(dropout),
|
||||||
|
nn.Linear(hidden_dim, dim),
|
||||||
|
nn.Dropout(dropout),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.net(x)
|
||||||
|
|
||||||
|
|
||||||
|
class Attention(nn.Module):
|
||||||
|
def __init__(self, dim, heads=8, dim_head=64, dropout=0.0):
|
||||||
|
super().__init__()
|
||||||
|
inner_dim = dim_head * heads
|
||||||
|
project_out = not (heads == 1 and dim_head == dim)
|
||||||
|
self.heads = heads
|
||||||
|
self.dropout = dropout
|
||||||
|
self.norm = nn.LayerNorm(dim)
|
||||||
|
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
|
||||||
|
self.to_out = (
|
||||||
|
nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
|
||||||
|
if project_out
|
||||||
|
else nn.Identity()
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x, causal=True):
|
||||||
|
x = self.norm(x)
|
||||||
|
drop = self.dropout if self.training else 0.0
|
||||||
|
qkv = self.to_qkv(x).chunk(3, dim=-1)
|
||||||
|
q, k, v = (rearrange(t, "b t (h d) -> b h t d", h=self.heads) for t in qkv)
|
||||||
|
out = F.scaled_dot_product_attention(q, k, v, dropout_p=drop, is_causal=causal)
|
||||||
|
out = rearrange(out, "b h t d -> b t (h d)")
|
||||||
|
return self.to_out(out)
|
||||||
|
|
||||||
|
|
||||||
|
class Block(nn.Module):
|
||||||
|
def __init__(self, dim, heads, dim_head, mlp_dim, dropout=0.0):
|
||||||
|
super().__init__()
|
||||||
|
self.attn = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
|
||||||
|
self.mlp = FeedForward(dim, mlp_dim, dropout=dropout)
|
||||||
|
self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||||
|
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = x + self.attn(self.norm1(x))
|
||||||
|
x = x + self.mlp(self.norm2(x))
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Transformer(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
input_dim,
|
||||||
|
hidden_dim,
|
||||||
|
output_dim,
|
||||||
|
depth,
|
||||||
|
heads,
|
||||||
|
dim_head,
|
||||||
|
mlp_dim,
|
||||||
|
dropout=0.0,
|
||||||
|
block_class=Block,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.norm = nn.LayerNorm(hidden_dim)
|
||||||
|
self.layers = nn.ModuleList([])
|
||||||
|
|
||||||
|
self.input_proj = (
|
||||||
|
nn.Linear(input_dim, hidden_dim)
|
||||||
|
if input_dim != hidden_dim
|
||||||
|
else nn.Identity()
|
||||||
|
)
|
||||||
|
self.cond_proj = (
|
||||||
|
nn.Linear(input_dim, hidden_dim)
|
||||||
|
if input_dim != hidden_dim
|
||||||
|
else nn.Identity()
|
||||||
|
)
|
||||||
|
self.output_proj = (
|
||||||
|
nn.Linear(hidden_dim, output_dim)
|
||||||
|
if hidden_dim != output_dim
|
||||||
|
else nn.Identity()
|
||||||
|
)
|
||||||
|
|
||||||
|
for _ in range(depth):
|
||||||
|
self.layers.append(block_class(hidden_dim, heads, dim_head, mlp_dim, dropout))
|
||||||
|
|
||||||
|
def forward(self, x, c=None):
|
||||||
|
x = self.input_proj(x)
|
||||||
|
if c is not None:
|
||||||
|
c = self.cond_proj(c)
|
||||||
|
for block in self.layers:
|
||||||
|
x = block(x)
|
||||||
|
x = self.norm(x)
|
||||||
|
return self.output_proj(x)
|
||||||
|
|
||||||
|
|
||||||
|
class QueryTokenPredictor(nn.Module):
|
||||||
|
"""History-only transformer predictor that decodes learned query tokens."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
num_frames,
|
||||||
|
query_offsets,
|
||||||
|
depth,
|
||||||
|
heads,
|
||||||
|
mlp_dim,
|
||||||
|
input_dim,
|
||||||
|
hidden_dim,
|
||||||
|
output_dim=None,
|
||||||
|
dim_head=64,
|
||||||
|
dropout=0.0,
|
||||||
|
emb_dropout=0.0,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
if num_frames <= 0:
|
||||||
|
raise ValueError(f"num_frames must be positive, got {num_frames}")
|
||||||
|
|
||||||
|
query_offsets = tuple(query_offsets)
|
||||||
|
if not query_offsets:
|
||||||
|
raise ValueError("query_offsets must contain at least one offset")
|
||||||
|
if any(offset <= 0 for offset in query_offsets):
|
||||||
|
raise ValueError(f"query_offsets must be positive, got {query_offsets}")
|
||||||
|
|
||||||
|
self.num_frames = int(num_frames)
|
||||||
|
self.query_offsets = query_offsets
|
||||||
|
self.num_query_tokens = len(query_offsets)
|
||||||
|
self.pos_embedding = nn.Parameter(
|
||||||
|
torch.randn(1, self.num_frames + self.num_query_tokens, input_dim)
|
||||||
|
)
|
||||||
|
self.query_tokens = nn.Parameter(
|
||||||
|
torch.randn(1, self.num_query_tokens, input_dim)
|
||||||
|
)
|
||||||
|
self.dropout = nn.Dropout(emb_dropout)
|
||||||
|
self.transformer = Transformer(
|
||||||
|
input_dim,
|
||||||
|
hidden_dim,
|
||||||
|
output_dim or input_dim,
|
||||||
|
depth,
|
||||||
|
heads,
|
||||||
|
dim_head,
|
||||||
|
mlp_dim,
|
||||||
|
dropout,
|
||||||
|
block_class=Block,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if x.ndim != 3:
|
||||||
|
raise ValueError(
|
||||||
|
f"QueryTokenPredictor expects a 3D tensor, got rank {x.ndim}"
|
||||||
|
)
|
||||||
|
|
||||||
|
T = x.size(1)
|
||||||
|
if T > self.num_frames:
|
||||||
|
raise ValueError(
|
||||||
|
f"input sequence length {T} exceeds configured num_frames {self.num_frames}"
|
||||||
|
)
|
||||||
|
|
||||||
|
query_tokens = self.query_tokens.expand(x.size(0), -1, -1)
|
||||||
|
tokens = torch.cat([x, query_tokens], dim=1)
|
||||||
|
tokens = tokens + self.pos_embedding[:, : tokens.size(1)]
|
||||||
|
tokens = self.dropout(tokens)
|
||||||
|
tokens = self.transformer(tokens)
|
||||||
|
return tokens[:, -self.num_query_tokens :]
|
||||||
|
|
||||||
|
|
||||||
|
class LeWMProjectorMLP(nn.Module):
    """Two-layer projector MLP: Linear -> BatchNorm1d -> GELU -> Linear.

    The sub-modules are kept inside ``self.net`` (an ``nn.Sequential``) so the
    state-dict keys stay ``net.0`` / ``net.1`` / ``net.3`` for checkpoints.

    NOTE(review): BatchNorm1d requires 2D ``(batch, features)`` input and a
    batch size > 1 in training mode — confirm callers guarantee this.
    """

    def __init__(
        self,
        input_dim: int = 288,
        hidden_dim: int = 2048,
        output_dim: int = 288,
    ) -> None:
        super().__init__()
        in_features = int(input_dim)
        width = int(hidden_dim)
        self.output_dim = int(output_dim)
        stages = [
            nn.Linear(in_features, width),
            nn.BatchNorm1d(width),
            nn.GELU(),
            nn.Linear(width, self.output_dim),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Project ``x`` (last dim = input_dim) through the MLP."""
        return self.net(x)
|
||||||
|
|
||||||
|
|
||||||
|
class SIGReg(nn.Module):
    """Sketch Isotropic Gaussian Regularizer, matching the original LeWM design.

    Compares the empirical characteristic function of random 1-D projections
    of the embeddings against the characteristic function of a standard
    Gaussian, integrated over a fixed knot grid on [0, 3] with trapezoidal
    weights and a Gaussian window.
    """

    def __init__(self, knots: int = 17, num_proj: int = 1024) -> None:
        """
        Args:
            knots: number of quadrature knots on [0, 3]; must be >= 2.
            num_proj: number of random projection directions drawn per call.

        Raises:
            ValueError: if ``knots`` < 2 (the step size divides by knots - 1).
        """
        super().__init__()
        knots = int(knots)
        if knots < 2:
            # dt below divides by (knots - 1); fail with a clear message
            # instead of a ZeroDivisionError.
            raise ValueError(f"knots must be >= 2, got {knots}")
        self.num_proj = int(num_proj)
        t = torch.linspace(0, 3, knots, dtype=torch.float32)
        # Trapezoidal-rule weights: interior knots count double the endpoints.
        dt = 3 / (knots - 1)
        weights = torch.full((knots,), 2 * dt, dtype=torch.float32)
        weights[[0, -1]] = dt
        # Characteristic function of N(0, 1): exp(-t^2 / 2).
        window = torch.exp(-t.square() / 2.0)
        self.register_buffer("t", t)
        self.register_buffer("phi", window)
        self.register_buffer("weights", weights * window)

    def forward(self, proj: torch.Tensor) -> torch.Tensor:
        """Return the SIGReg statistic as a 0-dim tensor.

        Args:
            proj: embeddings of shape (T, B, D); the characteristic-function
                error is averaged over the B axis and scaled by B.
        """
        # Fresh random unit directions each call; match proj's dtype so the
        # regularizer also works under autocast / half precision.
        A = torch.randn(proj.size(-1), self.num_proj, device=proj.device, dtype=proj.dtype)
        A = A.div_(A.norm(p=2, dim=0))
        x_t = (proj @ A).unsqueeze(-1) * self.t
        # Real-part error vs. the Gaussian CF plus the squared imaginary part,
        # each averaged over the batch axis (dim -3 of (T, B, P, K)).
        err = (x_t.cos().mean(-3) - self.phi).square() + x_t.sin().mean(-3).square()
        statistic = (err @ self.weights) * proj.size(-2)
        return statistic.mean()
|
||||||
@@ -15,4 +15,24 @@ class IdentityActionEncoder(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
    def forward(self, action):
        # Pass-through: actions are used as-is, with no learned encoding.
        return action
|
||||||
|
|
||||||
|
|
||||||
|
class LeWMStateEncoder(nn.Module):
    """MLP state encoder: Linear -> LayerNorm -> GELU -> Linear.

    Sub-modules live inside ``self.net`` (an ``nn.Sequential``) so checkpoint
    keys remain ``net.0`` / ``net.1`` / ``net.3``.
    """

    def __init__(
        self,
        input_dim: int = 16,
        hidden_dim: int = 256,
        output_dim: int = 64,
    ):
        super().__init__()
        in_features = int(input_dim)
        width = int(hidden_dim)
        self.output_dim = int(output_dim)
        stages = [
            nn.Linear(in_features, width),
            nn.LayerNorm(width),
            nn.GELU(),
            nn.Linear(width, self.output_dim),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, state):
        """Encode a state tensor whose last dim equals ``input_dim``."""
        return self.net(state)
|
||||||
|
|||||||
@@ -376,6 +376,29 @@ class _ForbiddenScheduler:
|
|||||||
raise AssertionError('IMF inference should not use DDIM scheduler step')
|
raise AssertionError('IMF inference should not use DDIM scheduler step')
|
||||||
|
|
||||||
|
|
||||||
|
class _StubFutureTokenPredictor(nn.Module):
|
||||||
|
def __init__(self, num_future_tokens=1):
|
||||||
|
super().__init__()
|
||||||
|
self.num_future_tokens = int(num_future_tokens)
|
||||||
|
self.calls = []
|
||||||
|
|
||||||
|
def forward(self, history_tokens):
|
||||||
|
self.calls.append(history_tokens.detach().clone())
|
||||||
|
summary = history_tokens.mean(dim=1, keepdim=True)
|
||||||
|
return summary.repeat(1, self.num_future_tokens, 1)
|
||||||
|
|
||||||
|
|
||||||
|
class _RecordingSigReg(nn.Module):
|
||||||
|
def __init__(self, value=0.5):
|
||||||
|
super().__init__()
|
||||||
|
self.value = float(value)
|
||||||
|
self.calls = []
|
||||||
|
|
||||||
|
def forward(self, embeddings):
|
||||||
|
self.calls.append(embeddings.detach().clone())
|
||||||
|
return embeddings.new_tensor(self.value)
|
||||||
|
|
||||||
|
|
||||||
def _make_images(batch_size, obs_horizon, per_camera_fill):
|
def _make_images(batch_size, obs_horizon, per_camera_fill):
|
||||||
return {
|
return {
|
||||||
name: torch.full((batch_size, obs_horizon, 1, 2, 2), fill_value=value, dtype=torch.float32)
|
name: torch.full((batch_size, obs_horizon, 1, 2, 2), fill_value=value, dtype=torch.float32)
|
||||||
@@ -501,6 +524,169 @@ class IMFVLAAgentTest(unittest.TestCase):
|
|||||||
self.assertTrue(torch.allclose(head.calls[0]['t'], torch.ones(2)))
|
self.assertTrue(torch.allclose(head.calls[0]['t'], torch.ones(2)))
|
||||||
self.assertTrue(torch.allclose(head.calls[0]['cond'], expected_cond))
|
self.assertTrue(torch.allclose(head.calls[0]['cond'], expected_cond))
|
||||||
|
|
||||||
|
    def test_predict_action_appends_lewm_future_tokens_to_history_conditioning(self):
        """predict_action should append the LeWM-predicted future token after
        the per-step history conditioning before calling the head."""
        agent_cls, agent_module = _load_imf_agent_class()
        head = _RecordingLinearIMFHead()
        future_predictor = _StubFutureTokenPredictor(num_future_tokens=1)
        agent = agent_cls(
            vision_backbone=_StubVisionBackbone(),
            state_encoder=nn.Identity(),
            action_encoder=nn.Identity(),
            head=head,
            action_dim=2,
            obs_dim=1,
            pred_horizon=3,
            obs_horizon=2,
            diffusion_steps=10,
            inference_steps=1,
            num_cams=len(_CAMERA_NAMES),
            camera_names=_CAMERA_NAMES,
            num_action_steps=2,
            head_type='transformer',
            extra_condition_tokens=1,
            lewm_history_horizon=3,
            lewm_query_offsets=[8],
            lewm_predictor=future_predictor,
            lewm_pred_projector=nn.Identity(),
            lewm_loss_weight=0.5,
        )
        # IMF inference must not step a DDIM scheduler; this raises if it does.
        agent.infer_scheduler = _ForbiddenScheduler()

        images = _make_images(
            batch_size=1,
            obs_horizon=2,
            per_camera_fill={'r_vis': 10.0, 'top': 20.0, 'front': 30.0},
        )
        qpos = torch.tensor([[[1.0], [2.0]]], dtype=torch.float32)
        lewm_images = _make_images(
            batch_size=1,
            obs_horizon=3,
            per_camera_fill={'r_vis': 10.0, 'top': 20.0, 'front': 30.0},
        )
        lewm_qpos = torch.tensor([[[0.5], [1.5], [2.5]]], dtype=torch.float32)
        initial_noise = torch.tensor(
            [[[1.0, -1.0], [0.0, 2.0], [3.0, -2.0]]],
            dtype=torch.float32,
        )

        # Pin the initial diffusion noise so head inputs are deterministic.
        with mock.patch.object(agent_module.torch, 'randn', return_value=initial_noise):
            _ = agent.predict_action(
                images,
                qpos,
                lewm_images=lewm_images,
                lewm_proprioception=lewm_qpos,
            )

        # Per-step conditioning = [camera fills..., qpos]; the future token is
        # the stub predictor's mean over the 3-step LeWM history (qpos 1.5).
        expected_history = torch.tensor(
            [[[10.0, 20.0, 30.0, 1.0], [10.0, 20.0, 30.0, 2.0]]],
            dtype=torch.float32,
        )
        expected_future = torch.tensor([[[10.0, 20.0, 30.0, 1.5]]], dtype=torch.float32)
        expected_cond = torch.cat([expected_history, expected_future], dim=1)

        self.assertEqual(agent.condition_sequence_length, 3)
        self.assertEqual(agent.per_step_cond_dim, 4)
        self.assertEqual(len(head.calls), 1)
        self.assertTrue(torch.allclose(head.calls[0]['cond'], expected_cond))
        self.assertEqual(len(future_predictor.calls), 1)
|
||||||
|
|
||||||
|
    def test_compute_loss_tracks_action_and_lewm_loss_breakdown(self):
        """compute_loss should report action / lewm_pred / lewm_sigreg losses
        and combine them with the configured weights (0.09 sigreg, 0.25 lewm)."""
        agent_cls, agent_module = _load_imf_agent_class()
        head = _RecordingLinearIMFHead()
        future_predictor = _StubFutureTokenPredictor(num_future_tokens=1)
        sigreg = _RecordingSigReg(value=0.75)
        agent = agent_cls(
            vision_backbone=_StubVisionBackbone(),
            state_encoder=nn.Identity(),
            action_encoder=nn.Identity(),
            head=head,
            action_dim=2,
            obs_dim=1,
            pred_horizon=3,
            obs_horizon=2,
            diffusion_steps=10,
            inference_steps=1,
            num_cams=len(_CAMERA_NAMES),
            camera_names=_CAMERA_NAMES,
            num_action_steps=2,
            head_type='transformer',
            extra_condition_tokens=1,
            lewm_history_horizon=3,
            lewm_query_offsets=[8],
            lewm_predictor=future_predictor,
            lewm_pred_projector=nn.Identity(),
            lewm_sigreg=sigreg,
            lewm_sigreg_weight=0.09,
            lewm_loss_weight=0.25,
        )

        images = _make_images(
            batch_size=1,
            obs_horizon=2,
            per_camera_fill={'r_vis': 1.0, 'top': 2.0, 'front': 3.0},
        )
        qpos = torch.tensor([[[0.25], [0.75]]], dtype=torch.float32)
        actions = torch.tensor(
            [[[1.0, -1.0], [0.5, 0.25], [-0.5, 1.5]]],
            dtype=torch.float32,
        )
        lewm_images = _make_images(
            batch_size=1,
            obs_horizon=3,
            per_camera_fill={'r_vis': 1.0, 'top': 2.0, 'front': 3.0},
        )
        lewm_qpos = torch.tensor([[[0.1], [0.2], [0.3]]], dtype=torch.float32)
        lewm_future_images = _make_images(
            batch_size=1,
            obs_horizon=1,
            per_camera_fill={'r_vis': 1.0, 'top': 2.0, 'front': 3.0},
        )
        lewm_future_qpos = torch.tensor([[[0.4]]], dtype=torch.float32)
        noise = torch.tensor(
            [[[0.2, -0.4], [0.1, 0.3], [0.5, -0.2]]],
            dtype=torch.float32,
        )
        t_sample = torch.tensor([0.8], dtype=torch.float32)
        r_sample = torch.tensor([0.25], dtype=torch.float32)

        # Fix the diffusion noise and the two time samples so the loss
        # breakdown is deterministic.
        with mock.patch.object(agent_module.torch, 'randn_like', return_value=noise), \
                mock.patch.object(agent_module.torch, 'rand', side_effect=[t_sample, r_sample]):
            loss = agent.compute_loss(
                {
                    'images': images,
                    'qpos': qpos,
                    'action': actions,
                    'lewm_images': lewm_images,
                    'lewm_qpos': lewm_qpos,
                    'lewm_future_images': lewm_future_images,
                    'lewm_future_qpos': lewm_future_qpos,
                }
            )

        metrics = agent.get_last_loss_breakdown()
        self.assertAlmostEqual(loss.item(), metrics['loss'], places=6)
        self.assertIn('action_loss', metrics)
        self.assertIn('lewm_pred_loss', metrics)
        self.assertIn('lewm_sigreg_loss', metrics)
        self.assertIn('lewm_loss', metrics)
        # The recording sigreg always returns its fixed value.
        self.assertAlmostEqual(metrics['lewm_sigreg_loss'], 0.75, places=6)
        self.assertAlmostEqual(
            metrics['lewm_loss'],
            metrics['lewm_pred_loss'] + 0.09 * metrics['lewm_sigreg_loss'],
            places=5,
        )
        self.assertAlmostEqual(
            metrics['loss'],
            metrics['action_loss'] + 0.25 * metrics['lewm_loss'],
            places=5,
        )
        self.assertEqual(len(sigreg.calls), 1)
        # Sigreg receives the embedded history as (T, B, D) — hence the
        # transpose of the (B, T, D) expected tensor below.
        expected_lewm_history = torch.tensor(
            [[[1.0, 2.0, 3.0, 0.1], [1.0, 2.0, 3.0, 0.2], [1.0, 2.0, 3.0, 0.3]]],
            dtype=torch.float32,
        )
        torch.testing.assert_close(sigreg.calls[0], expected_lewm_history.transpose(0, 1))
|
||||||
|
|
||||||
def test_select_action_only_regenerates_when_action_queue_is_empty(self):
|
def test_select_action_only_regenerates_when_action_queue_is_empty(self):
|
||||||
agent, _head, _agent_module = self._make_agent(pred_horizon=4, obs_horizon=2, num_action_steps=2)
|
agent, _head, _agent_module = self._make_agent(pred_horizon=4, obs_horizon=2, num_action_steps=2)
|
||||||
observation = {
|
observation = {
|
||||||
@@ -851,6 +1037,46 @@ class IMFVLAAgentTest(unittest.TestCase):
|
|||||||
self.assertEqual(agent.vision_encoder.output_dim, 96)
|
self.assertEqual(agent.vision_encoder.output_dim, 96)
|
||||||
self.assertEqual(agent.vision_encoder.eval_image_resize_shape, (256, 256))
|
self.assertEqual(agent.vision_encoder.eval_image_resize_shape, (256, 256))
|
||||||
|
|
||||||
|
    def test_hydra_config_instantiates_lewm_resnet_query_imf_attnres_with_future_tokens(self):
        """The lewm_resnet_query_imf_attnres Hydra config should wire the LeWM
        backbone, state encoder, SIGReg, and one extra condition token."""
        cfg = _compose_cfg(
            overrides=[
                'agent=lewm_resnet_query_imf_attnres',
                'agent.head.n_layer=1',
                'agent.head.n_emb=16',
                'agent.lewm_query_offsets=[8]',
            ]
        )

        self.assertEqual(cfg.agent._target_, 'roboimi.vla.agent_imf.IMFVLAAgent')
        self.assertEqual(
            cfg.agent.vision_backbone._target_,
            'roboimi.vla.models.backbones.lewm_resnet_query_fusion.LeWMMultiViewResNetBackbone',
        )
        self.assertEqual(
            cfg.agent.state_encoder._target_,
            'roboimi.vla.modules.encoders.LeWMStateEncoder',
        )
        self.assertEqual(cfg.agent.head.cond_dim, 288)
        self.assertEqual(cfg.agent.cond_projector.output_dim, 288)
        self.assertEqual(cfg.agent.extra_condition_tokens, 1)
        self.assertEqual(
            cfg.agent.lewm_sigreg._target_,
            'roboimi.vla.models.backbones.lewm_resnet_query_fusion.SIGReg',
        )
        self.assertAlmostEqual(cfg.agent.lewm_sigreg_weight, 0.09)

        # Heavy optional deps are stubbed so instantiate() stays cheap.
        with _stub_optional_modules(include_imf_head=True):
            agent = instantiate(cfg.agent)

        self.assertEqual(agent.per_step_cond_dim, 288)
        # One extra condition token extends the sequence beyond obs_horizon.
        self.assertEqual(agent.condition_sequence_length, agent.obs_horizon + 1)
        self.assertEqual(agent.noise_pred_net.constructor_kwargs['cond_dim'], 288)
        self.assertEqual(
            agent.noise_pred_net.constructor_kwargs['n_obs_steps'],
            agent.condition_sequence_length,
        )
        self.assertIsNotNone(agent.lewm_sigreg)
|
||||||
|
|
||||||
|
|
||||||
def test_hydra_config_instantiates_resnet_imf_attnres_multitoken_with_sequence_length_three_times_obs_horizon(self):
|
def test_hydra_config_instantiates_resnet_imf_attnres_multitoken_with_sequence_length_three_times_obs_horizon(self):
|
||||||
cfg = _compose_cfg(
|
cfg = _compose_cfg(
|
||||||
|
|||||||
@@ -79,3 +79,24 @@ class SimpleRobotDatasetImageLoadingTest(unittest.TestCase):
|
|||||||
|
|
||||||
fake_cv2.resize.assert_not_called()
|
fake_cv2.resize.assert_not_called()
|
||||||
self.assertEqual(tuple(sample["observation.front"].shape), (2, 3, 8, 8))
|
self.assertEqual(tuple(sample["observation.front"].shape), (2, 3, 8, 8))
|
||||||
|
|
||||||
|
    def test_getitem_can_emit_lewm_history_and_future_observations(self):
        """With lewm_history_horizon / lewm_query_offsets set, samples should
        carry lewm.observation.* history and lewm.future.* target tensors."""
        with tempfile.TemporaryDirectory() as tmpdir:
            dataset_dir = Path(tmpdir)
            self._write_episode(dataset_dir)
            dataset = SimpleRobotDataset(
                dataset_dir,
                obs_horizon=2,
                pred_horizon=3,
                camera_names=["front"],
                image_resize_shape=None,
                lewm_history_horizon=3,
                lewm_query_offsets=[1, 2],
            )

            sample = dataset[1]

            # History: 3 steps; futures: one per query offset (2 offsets).
            self.assertEqual(tuple(sample["lewm.observation.state"].shape), (3, 4))
            self.assertEqual(tuple(sample["lewm.observation.front"].shape), (3, 3, 8, 8))
            self.assertEqual(tuple(sample["lewm.future.state"].shape), (2, 4))
            self.assertEqual(tuple(sample["lewm.future.front"].shape), (2, 3, 8, 8))
|
||||||
|
|||||||
@@ -114,6 +114,22 @@ class FakeAgent(nn.Module):
|
|||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
class RecordingAgent(FakeAgent):
    """FakeAgent variant that records every compute_loss input for assertions."""

    def __init__(self):
        super().__init__()
        # One entry per compute_loss call: the exact agent_input dict received.
        self.seen_inputs = []

    def compute_loss(self, agent_input):
        self.seen_inputs.append(agent_input)
        return super().compute_loss(agent_input)
|
||||||
|
|
||||||
|
|
||||||
|
class ShapeMixedFakeAgent(FakeAgent):
    """FakeAgent with an extra 2-element ``bias`` parameter.

    Used to exercise partial checkpoint loading: a checkpoint carrying a
    3-element ``bias`` should be skipped for this key while other matching
    keys still load.
    """

    def __init__(self):
        super().__init__()
        self.bias = nn.Parameter(torch.zeros(2))
|
||||||
|
|
||||||
|
|
||||||
class FakeSwanLab:
|
class FakeSwanLab:
|
||||||
def __init__(self, init_error=None, log_errors=None, finish_error=None, image_errors=None):
|
def __init__(self, init_error=None, log_errors=None, finish_error=None, image_errors=None):
|
||||||
self.init_error = init_error
|
self.init_error = init_error
|
||||||
@@ -388,6 +404,18 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):
|
|||||||
'action_is_pad': torch.zeros(1, 1, dtype=torch.bool),
|
'action_is_pad': torch.zeros(1, 1, dtype=torch.bool),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
    def _make_lewm_batch(self):
        """Return a standard batch augmented with the four LeWM keys the
        trainer is expected to forward into the agent input."""
        batch = self._make_batch()
        batch.update(
            {
                'lewm.observation.front': torch.ones(1, 3, 2, 2),
                'lewm.observation.state': torch.ones(1, 4),
                'lewm.future.front': torch.full((1, 3, 2, 2), 2.0),
                'lewm.future.state': torch.full((1, 4), 2.0),
            }
        )
        return batch
|
||||||
|
|
||||||
def _loader_factory(self):
|
def _loader_factory(self):
|
||||||
train_batch = self._make_batch()
|
train_batch = self._make_batch()
|
||||||
val_batch = self._make_batch()
|
val_batch = self._make_batch()
|
||||||
@@ -397,6 +425,15 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):
|
|||||||
|
|
||||||
return factory
|
return factory
|
||||||
|
|
||||||
|
    def _lewm_loader_factory(self):
        """Return a DataLoader stand-in yielding LeWM-augmented batches.

        ``shuffle=True`` marks the training loader; validation gets its own
        batch so the two can be distinguished if needed.
        """
        train_batch = self._make_lewm_batch()
        val_batch = self._make_lewm_batch()

        def factory(_dataset, *, shuffle, **_kwargs):
            return FakeLoader(train_batch if shuffle else val_batch)

        return factory
|
||||||
|
|
||||||
def test_run_training_logs_metrics_and_checkpoint_paths_to_swanlab(self):
|
def test_run_training_logs_metrics_and_checkpoint_paths_to_swanlab(self):
|
||||||
module = self._load_train_vla_module()
|
module = self._load_train_vla_module()
|
||||||
run_training = self._get_run_training(module)
|
run_training = self._get_run_training(module)
|
||||||
@@ -487,6 +524,43 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):
|
|||||||
self.assertTrue(final_payload['final/best_checkpoint_path'].endswith('checkpoints/vla_model_best.pt'))
|
self.assertTrue(final_payload['final/best_checkpoint_path'].endswith('checkpoints/vla_model_best.pt'))
|
||||||
self.assertEqual(fake_swanlab.finish_calls, 1)
|
self.assertEqual(fake_swanlab.finish_calls, 1)
|
||||||
|
|
||||||
|
    def test_run_training_passes_lewm_history_and_future_batches_into_agent_input(self):
        """run_training should translate lewm.* batch keys into the
        lewm_images / lewm_qpos / lewm_future_* entries of the agent input."""
        module = self._load_train_vla_module()
        run_training = self._get_run_training(module)
        cfg = self._make_cfg(use_swanlab=False)
        cfg.train.max_steps = 1
        cfg.train.save_freq = 100
        agent = RecordingAgent()

        def fake_instantiate(config_node, **_kwargs):
            # Only the data and agent nodes are expected to be instantiated.
            if config_node is cfg.data:
                return FakeDataset()
            if config_node is cfg.agent:
                return agent
            raise AssertionError(f'unexpected instantiate config: {config_node!r}')

        with tempfile.TemporaryDirectory() as tempdir:
            previous_cwd = os.getcwd()
            try:
                # run_training writes checkpoints relative to the cwd.
                os.chdir(tempdir)
                with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
                        mock.patch.object(module, 'DataLoader', side_effect=self._lewm_loader_factory()), \
                        mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=FakeScheduler()), \
                        mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: FakeProgressBar(iterable)), \
                        mock.patch.object(module.torch, 'save', return_value=None):
                    run_training(cfg)
            finally:
                os.chdir(previous_cwd)

        self.assertGreaterEqual(len(agent.seen_inputs), 1)
        first_input = agent.seen_inputs[0]
        self.assertIn('lewm_images', first_input)
        self.assertIn('lewm_qpos', first_input)
        self.assertIn('lewm_future_images', first_input)
        self.assertIn('lewm_future_qpos', first_input)
        self.assertIn('front', first_input['lewm_images'])
        self.assertIn('front', first_input['lewm_future_images'])
||||||
def test_run_training_skips_swanlab_when_disabled(self):
|
def test_run_training_skips_swanlab_when_disabled(self):
|
||||||
module = self._load_train_vla_module()
|
module = self._load_train_vla_module()
|
||||||
run_training = self._get_run_training(module)
|
run_training = self._get_run_training(module)
|
||||||
@@ -668,6 +742,52 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):
|
|||||||
self.assertTrue(final_payload['final/best_checkpoint_path'].endswith('checkpoints/vla_model_best.pt'))
|
self.assertTrue(final_payload['final/best_checkpoint_path'].endswith('checkpoints/vla_model_best.pt'))
|
||||||
self.assertFalse(any(path.endswith('checkpoints/vla_model_best.pt') for path in saved_paths))
|
self.assertFalse(any(path.endswith('checkpoints/vla_model_best.pt') for path in saved_paths))
|
||||||
|
|
||||||
|
    def test_run_training_pretrained_ckpt_loads_matching_keys_even_if_some_shapes_mismatch(self):
        """Loading a pretrained checkpoint should copy shape-compatible keys
        (``weight``) and skip mismatched ones (3-element ``bias`` vs 2)."""
        module = self._load_train_vla_module()
        run_training = self._get_run_training(module)
        cfg = self._make_cfg(use_swanlab=False)
        cfg.train.max_steps = 0
        cfg.train.save_freq = 100
        cfg.train.pretrained_ckpt = 'pretrained.pt'
        agent = ShapeMixedFakeAgent()

        def fake_instantiate(config_node, **_kwargs):
            if config_node is cfg.data:
                return FakeDataset()
            if config_node is cfg.agent:
                return agent
            raise AssertionError(f'unexpected instantiate config: {config_node!r}')

        def fake_torch_load(path, map_location=None):
            # Stand-in checkpoint: 'weight' matches the agent's parameter
            # shape, 'bias' (3 elements) deliberately does not (agent has 2).
            del map_location
            if Path(path).name != 'pretrained.pt':
                raise AssertionError(f'unexpected load path: {path}')
            return {
                'model_state_dict': {
                    'weight': torch.tensor(3.0),
                    'bias': torch.tensor([1.0, 2.0, 3.0]),
                },
                'step': 123,
                'loss': 0.5,
            }

        with tempfile.TemporaryDirectory() as tempdir:
            previous_cwd = os.getcwd()
            try:
                os.chdir(tempdir)
                # The file must exist so the ckpt path check passes; torch.load
                # is mocked, so the contents are irrelevant.
                Path('pretrained.pt').write_bytes(b'pretend')
                with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
                        mock.patch.object(module, 'DataLoader', side_effect=self._loader_factory()), \
                        mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=FakeScheduler()), \
                        mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: FakeProgressBar(iterable)), \
                        mock.patch.object(module.torch, 'save', return_value=None), \
                        mock.patch.object(module.torch, 'load', side_effect=fake_torch_load):
                    run_training(cfg)
            finally:
                os.chdir(previous_cwd)

        # Matching key loaded; mismatched 'bias' silently skipped.
        self.assertAlmostEqual(agent.weight.item(), 3.0, places=6)
|
||||||
|
|
||||||
def test_run_training_resume_ignores_best_checkpoint_without_rollout_metric(self):
|
def test_run_training_resume_ignores_best_checkpoint_without_rollout_metric(self):
|
||||||
module = self._load_train_vla_module()
|
module = self._load_train_vla_module()
|
||||||
run_training = self._get_run_training(module)
|
run_training = self._get_run_training(module)
|
||||||
|
|||||||
Reference in New Issue
Block a user