From d84bc6876eb43932fe9a616ee7ce5137e3f15df0 Mon Sep 17 00:00:00 2001 From: Logic Date: Tue, 31 Mar 2026 15:39:20 +0800 Subject: [PATCH] feat(vla): align transformer training stack and rollout validation --- environment.yml | 30 +- roboimi/assets/robots/arm_base.py | 42 +- roboimi/demos/vla_scripts/train_vla.py | 1068 +++++++++++------ roboimi/envs/double_base.py | 13 +- roboimi/envs/double_pos_ctrl_env.py | 6 +- roboimi/vla/agent.py | 93 +- .../vla/conf/agent/resnet_transformer.yaml | 8 + roboimi/vla/conf/config.yaml | 14 +- roboimi/vla/conf/head/transformer1d.yaml | 9 +- roboimi/vla/data/simpe_robot_dataset.py | 29 +- roboimi/vla/eval_utils.py | 3 + .../vla/models/backbones/resnet_diffusion.py | 30 +- roboimi/vla/models/heads/transformer1d.py | 423 +++---- roboimi/vla/scripts/calculate_stats.py | 47 +- tests/__init__.py | 1 + tests/test_calculate_stats_cli.py | 88 ++ tests/test_eval_vla_execution.py | 28 + tests/test_eval_vla_headless.py | 259 ++++ tests/test_resnet_transformer_agent_wiring.py | 387 ++++++ tests/test_robot_asset_paths.py | 63 + ...test_simple_robot_dataset_image_loading.py | 58 + tests/test_train_vla_rollout_validation.py | 779 ++++++++++++ tests/test_train_vla_swanlab_logging.py | 699 +++++++++++ tests/test_train_vla_transformer_optimizer.py | 310 +++++ .../test_transformer1d_external_alignment.py | 262 ++++ 25 files changed, 4043 insertions(+), 706 deletions(-) create mode 100644 roboimi/vla/eval_utils.py create mode 100644 tests/__init__.py create mode 100644 tests/test_calculate_stats_cli.py create mode 100644 tests/test_eval_vla_execution.py create mode 100644 tests/test_eval_vla_headless.py create mode 100644 tests/test_resnet_transformer_agent_wiring.py create mode 100644 tests/test_robot_asset_paths.py create mode 100644 tests/test_simple_robot_dataset_image_loading.py create mode 100644 tests/test_train_vla_rollout_validation.py create mode 100644 tests/test_train_vla_swanlab_logging.py create mode 100644 tests/test_train_vla_transformer_optimizer.py create mode 100644 tests/test_transformer1d_external_alignment.py diff --git a/environment.yml b/environment.yml index 944a238..7f1c879 100644 --- a/environment.yml +++ b/environment.yml @@ -229,6 +229,11 @@ dependencies: - python-xxhash=3.6.0 - python_abi=3.10 - pytorch=2.4.0 + - hydra-core=1.3.2 + - omegaconf=2.3.0 + - einops=0.8.2 + - diffusers=0.36.0 + - torchvision=0.19.0 - pytz=2024.1 - pyyaml=6.0.3 - qhull=2020.2 @@ -321,12 +326,10 @@ dependencies: - datasets==4.5.0 - decorator==5.2.1 - deepdiff==8.6.1 - - diffusers==0.30.0 - dill==0.4.0 - docstring_parser==0.17.0 - draccus==0.10.0 - eigenpy==3.10.3 - - einops==0.8.1 - etils==1.7.0 - evdev==1.9.2 - exceptiongroup==1.3.1 @@ -350,7 +353,6 @@ dependencies: - httpcore==1.0.9 - httpx==0.28.1 - huggingface_hub==1.3.2 - - hydra-core==1.3.2 - imageio==2.35.1 - imageio-ffmpeg==0.6.0 - importlib_metadata==8.7.1 @@ -380,22 +382,6 @@ dependencies: - networkx==3.4.2 - numcodecs==0.13.1 - numpy==2.2.6 - - nvidia-cublas-cu12==12.4.5.8 - - nvidia-cuda-cupti-cu12==12.4.127 - - nvidia-cuda-nvrtc-cu12==12.4.127 - - nvidia-cuda-runtime-cu12==12.4.127 - - nvidia-cudnn-cu12==9.1.0.70 - - nvidia-cufft-cu12==11.2.1.3 - - nvidia-cufile-cu12==1.11.1.6 - - nvidia-curand-cu12==10.3.5.147 - - nvidia-cusolver-cu12==11.6.1.9 - - nvidia-cusparse-cu12==12.3.1.170 - - nvidia-cusparselt-cu12==0.6.3 - - nvidia-nccl-cu12==2.21.5 - - nvidia-nvjitlink-cu12==12.4.127 - - nvidia-nvshmem-cu12==3.3.20 - - nvidia-nvtx-cu12==12.4.127 - - omegaconf==2.3.0 - opencv-contrib-python==4.10.0.84 - opencv-python==4.13.0.90 - orderly-set==5.5.0 @@ -431,7 +417,7 @@ dependencies: - regex==2026.1.15 - requests==2.32.5 - rerun-sdk==0.26.2 - - rich==14.2.0 + - rich==13.9.4 - ruckig==0.9.2 - safehttpx==0.1.7 - safetensors==0.7.0 @@ -443,18 +429,16 @@ dependencies: - stack-data==0.6.3 - starlette==0.50.0 - sympy==1.13.1 + - swanlab==0.7.13 - termcolor==3.3.0 - timm==1.0.24 - toml==0.10.2 - tomli==2.4.0 - tomlkit==0.13.3 - - torch==2.5.0 - torchcodec==0.5 - torchmetrics==1.8.2 - - torchvision==0.20.0 - tqdm==4.67.1 - traitlets==5.14.3 - - triton==3.1.0 - typer==0.21.1 - typer-slim==0.21.1 - typeshed_client==2.8.2 diff --git a/roboimi/assets/robots/arm_base.py b/roboimi/assets/robots/arm_base.py index 5cf94bd..0e80f7b 100644 --- a/roboimi/assets/robots/arm_base.py +++ b/roboimi/assets/robots/arm_base.py @@ -1,8 +1,46 @@ import mujoco import numpy as np +from pathlib import Path from roboimi.utils.KDL_utils import KDL_utils +def resolve_robot_asset_path(asset_path): + if asset_path is None: + return None + + raw_path = Path(asset_path).expanduser() + if raw_path.is_absolute(): + return str(raw_path.resolve()) + + current_dir = Path(__file__).resolve().parent + package_root = current_dir.parents[1] + repo_root = current_dir.parents[2] + + candidates = [] + if raw_path.parts and raw_path.parts[0] == 'roboimi': + candidates.append(repo_root / raw_path) + + candidates.extend([ + current_dir / raw_path, + package_root / raw_path, + repo_root / raw_path, + ]) + + normalized_candidates = [] + seen = set() + for candidate in candidates: + resolved = candidate.resolve() + if resolved not in seen: + normalized_candidates.append(resolved) + seen.add(resolved) + + for candidate in normalized_candidates: + if candidate.exists(): + return str(candidate) + + return str(normalized_candidates[0]) + + class ArmBase(object): def __init__(self, name=None, @@ -11,8 +49,8 @@ class ArmBase(object): gripper=None ): self.name = name - self.urdf_path = urdf_path - self.xml_path = xml_path + self.urdf_path = resolve_robot_asset_path(urdf_path) + self.xml_path = resolve_robot_asset_path(xml_path) self.gripper = gripper self.robot_model = mujoco.MjModel.from_xml_path(filename=self.xml_path, assets=None) self.robot_data = mujoco.MjData(self.robot_model) diff --git a/roboimi/demos/vla_scripts/train_vla.py b/roboimi/demos/vla_scripts/train_vla.py index 058776e..8b3e787 100644 --- a/roboimi/demos/vla_scripts/train_vla.py +++ b/roboimi/demos/vla_scripts/train_vla.py @@ -3,6 +3,7 @@ import os import logging import json import pickle +import importlib import hydra import torch import re @@ -111,8 +112,134 @@ def get_lr_schedule_with_warmup(optimizer, warmup_steps, max_steps, scheduler_ty return LambdaLR(optimizer, lr_lambda) -@hydra.main(version_base=None, config_path="../../vla/conf", config_name="config") -def main(cfg: DictConfig): +def build_training_optimizer(agent, lr, weight_decay): + """为训练脚本构建优化器,优先复用 transformer head 自带的参数分组。""" + trainable_params = [param for param in agent.parameters() if param.requires_grad] + noise_pred_net = getattr(agent, 'noise_pred_net', None) + get_optim_groups = getattr(noise_pred_net, 'get_optim_groups', None) + use_head_groups = ( + getattr(agent, 'head_type', None) == 'transformer' + and callable(get_optim_groups) + ) + + if not use_head_groups: + return AdamW(trainable_params, lr=lr, weight_decay=weight_decay) + + head_groups = [] + grouped_param_ids = set() + for group in get_optim_groups(weight_decay=weight_decay): + params = [param for param in group['params'] if param.requires_grad] + if not params: + continue + normalized_group = dict(group) + normalized_group['params'] = params + head_groups.append(normalized_group) + + for param in params: + param_id = id(param) + if param_id in grouped_param_ids: + raise ValueError('Transformer optimizer groups contain duplicate parameters') + grouped_param_ids.add(param_id) + + head_trainable_param_ids = { + id(param) for param in noise_pred_net.parameters() if param.requires_grad + } + missing_head_param_ids = head_trainable_param_ids - grouped_param_ids + if missing_head_param_ids: + raise ValueError('Transformer optimizer groups missed trainable head parameters') + + remaining_params = [ + param for param in trainable_params + if id(param) not in grouped_param_ids + ] + + optim_groups = head_groups + if remaining_params: + optim_groups = optim_groups + [{ + 'params': remaining_params, + 'weight_decay': weight_decay, + }] + grouped_param_ids.update(id(param) for param in remaining_params) + + all_trainable_param_ids = {id(param) for param in trainable_params} + if grouped_param_ids != all_trainable_param_ids: + raise ValueError('Optimizer parameter groups must include each trainable parameter exactly once') + + return AdamW(optim_groups, lr=lr, weight_decay=weight_decay) + + +def _init_swanlab(cfg): + """按需初始化 SwanLab,并在缺少依赖或认证失败时快速失败。""" + if not bool(cfg.train.get('use_swanlab', False)): + return None + + try: + swanlab = importlib.import_module("swanlab") + except ImportError as exc: + raise RuntimeError( + "SwanLab logging is enabled, but the 'swanlab' package could not be imported." + ) from exc + + def _to_plain_config(value): + if isinstance(value, dict): + return {key: _to_plain_config(val) for key, val in value.items()} + if isinstance(value, list): + return [_to_plain_config(item) for item in value] + if isinstance(value, tuple): + return tuple(_to_plain_config(item) for item in value) + + items_method = getattr(value, 'items', None) + if callable(items_method): + try: + return {key: _to_plain_config(val) for key, val in items_method()} + except Exception: + pass + + return value + + swanlab_config = { + key: _to_plain_config(cfg[key]) + for key in ('train', 'data', 'agent') + if key in cfg + } + + init_kwargs = { + 'project': cfg.train.get('swanlab_project', 'roboimi-vla'), + 'config': swanlab_config, + } + run_name = cfg.train.get('swanlab_run_name', None) + if run_name: + init_kwargs['experiment_name'] = run_name + + try: + swanlab.init(**init_kwargs) + except Exception as exc: + raise RuntimeError( + f"SwanLab logging is enabled, but SwanLab init/login failed: {exc}" + ) from exc + + return swanlab + + +def _log_to_swanlab(swanlab_module, payload, step=None): + if swanlab_module is None: + return + try: + swanlab_module.log(payload, step=step) + except Exception as exc: + log.warning(f"SwanLab log failed at step {step}: {exc}") + + +def _finish_swanlab(swanlab_module): + if swanlab_module is None: + return + try: + swanlab_module.finish() + except Exception as exc: + log.warning(f"SwanLab finish failed: {exc}") + + +def _run_training(cfg: DictConfig): """ VLA 训练脚本(ResNet 骨干网络 + Diffusion 策略) @@ -131,401 +258,598 @@ def main(cfg: DictConfig): print("=" * 80) log.info(f"🚀 开始 VLA 训练 (设备: {cfg.train.device})") - - # 创建检查点目录 - checkpoint_dir = Path("checkpoints") - checkpoint_dir.mkdir(exist_ok=True) - - # ========================================================================= - # 1. 实例化数据集与 DataLoader - # ========================================================================= - log.info("📦 加载数据集...") + swanlab_module = _init_swanlab(cfg) try: - dataset = instantiate(cfg.data) - log.info(f"✅ 数据集加载成功。总样本数: {len(dataset)}") - except Exception as e: - log.error(f"❌ 数据集加载失败: {e}") - raise + # 创建检查点目录 + checkpoint_dir = Path("checkpoints") + checkpoint_dir.mkdir(exist_ok=True) + default_best_model_path = checkpoint_dir / "vla_model_best.pt" - # 训练/验证集划分 - val_split = float(cfg.train.get('val_split', 0.1)) - seed = int(cfg.train.get('seed', 42)) - val_size = int(len(dataset) * val_split) - train_size = len(dataset) - val_size - if val_size > 0: - train_dataset, val_dataset = random_split( - dataset, - [train_size, val_size], - generator=torch.Generator().manual_seed(seed) - ) - log.info(f"✅ 数据集划分: 训练集={train_size}, 验证集={val_size} (验证比例={val_split})") - else: - train_dataset, val_dataset = dataset, None - log.info("✅ 数据集划分: 全部用于训练, 验证集=0 (验证比例=0)") - - train_loader = DataLoader( - train_dataset, - batch_size=cfg.train.batch_size, - shuffle=True, - num_workers=cfg.train.num_workers, - pin_memory=(cfg.train.device != "cpu"), - persistent_workers=(cfg.train.num_workers > 0), - drop_last=True # 丢弃不完整批次以稳定训练 - ) - - val_loader = None - if val_dataset is not None: - val_loader = DataLoader( - val_dataset, - batch_size=cfg.train.batch_size, - shuffle=False, - num_workers=cfg.train.num_workers, - pin_memory=(cfg.train.device != "cpu"), - persistent_workers=(cfg.train.num_workers > 0), - drop_last=False - ) - - log.info(f"✅ 训练加载器每轮批次数: {len(train_loader)}") - if val_loader is not None: - log.info(f"✅ 验证加载器每轮批次数: {len(val_loader)}") - - # ========================================================================= - # 2. 加载数据集统计信息(将传递给 agent) - # ========================================================================= - log.info("💾 加载数据集统计信息...") - dataset_stats = None - try: - dataset_dir = cfg.data.get('dataset_dir', 'roboimi/demos/dataset/sim_transfer') - stats_path = Path(dataset_dir) / 'dataset_stats.pkl' - - if stats_path.exists(): - with open(stats_path, 'rb') as f: - stats = pickle.load(f) - - # 扁平化stats字典(嵌套结构→扁平结构)以匹配NormalizationModule的期望格式 - dataset_stats = { - 'action_mean': stats['action_mean'].tolist(), - 'action_std': stats['action_std'].tolist(), - 'action_min': stats['action_min'].tolist(), - 'action_max': stats['action_max'].tolist(), - 'qpos_mean': stats['qpos_mean'].tolist(), - 'qpos_std': stats['qpos_std'].tolist(), - 'qpos_min': stats['qpos_min'].tolist(), - 'qpos_max': stats['qpos_max'].tolist(), - } - log.info(f"✅ 数据集统计信息加载完成 (归一化: {cfg.agent.normalization_type})") - else: - log.warning(f"⚠️ 统计文件未找到: {stats_path}") - log.warning("⚠️ 推理时动作将无法反归一化!") - - except Exception as e: - log.warning(f"⚠️ 统计信息加载失败: {e}") - log.warning("⚠️ 训练将继续,但推理可能无法正常工作") - - # ========================================================================= - # 3. 实例化 VLA Agent - # ========================================================================= - log.info("🤖 初始化 VLA Agent...") - try: - # 将 dataset_stats 和 normalization_type 传递给 agent - agent = instantiate(cfg.agent, dataset_stats=dataset_stats) - agent.to(cfg.train.device) - agent.train() - log.info(f"✅ Agent 初始化完成并已移至 {cfg.train.device}") - - # 统计参数量 - total_params = sum(p.numel() for p in agent.parameters()) - trainable_params = sum(p.numel() for p in agent.parameters() if p.requires_grad) - log.info(f"📊 总参数量: {total_params:,}") - log.info(f"📊 可训练参数量: {trainable_params:,}") - - except Exception as e: - log.error(f"❌ Agent 初始化失败: {e}") - raise - - # ========================================================================= - # 3.1 从预训练 checkpoint 加载权重(微调) - # ========================================================================= - pretrained_ckpt = cfg.train.get('pretrained_ckpt', None) - if pretrained_ckpt is not None: - ckpt_path = Path(pretrained_ckpt) - if ckpt_path.exists(): - log.info(f"🔄 [Finetune] 从预训练 checkpoint 加载权重: {ckpt_path}") - try: - checkpoint = torch.load(ckpt_path, map_location=cfg.train.device) - - # 只加载模型权重(不加载 optimizer、scheduler) - missing_keys, unexpected_keys = agent.load_state_dict( - checkpoint['model_state_dict'], - strict=False # 允许部分加载(结构不完全匹配时) - ) - - log.info(f"✅ [Finetune] 模型权重加载成功") - - if missing_keys: - log.warning(f"⚠️ [Finetune] 缺少的键 ({len(missing_keys)} 个): {missing_keys[:5]}...") - if unexpected_keys: - log.warning(f"⚠️ [Finetune] 多余的键 ({len(unexpected_keys)} 个): {unexpected_keys[:5]}...") - - log.info(f"📊 [Finetune] 预训练信息: 步骤={checkpoint.get('step', 'N/A')}, 损失={checkpoint.get('loss', 'N/A')}") - log.info(f"📈 [Finetune] 使用新的训练配置(lr={cfg.train.lr}, max_steps={cfg.train.max_steps})") - - except Exception as e: - log.error(f"❌ [Finetune] 加载 checkpoint 失败: {e}") - log.warning("⚠️ 将从头开始训练") - else: - log.error(f"❌ [Finetune] Checkpoint 文件不存在: {ckpt_path}") - log.warning("⚠️ 将从头开始训练") - - # ========================================================================= - # 4. 设置优化器与学习率调度器 - # ========================================================================= - weight_decay = float(cfg.train.get('weight_decay', 1e-5)) - grad_clip = float(cfg.train.get('grad_clip', 1.0)) - - optimizer = AdamW(agent.parameters(), lr=cfg.train.lr, weight_decay=weight_decay) - log.info(f"🔧 优化器: AdamW (学习率={cfg.train.lr}, weight_decay={weight_decay})") - - # 设置带预热的学習率调度器 - warmup_steps = int(cfg.train.get('warmup_steps', 500)) - scheduler_type = cfg.train.get('scheduler_type', 'cosine') - min_lr = float(cfg.train.get('min_lr', 1e-6)) - - scheduler = get_lr_schedule_with_warmup( - optimizer, - warmup_steps=warmup_steps, - max_steps=cfg.train.max_steps, - scheduler_type=scheduler_type, - min_lr=min_lr - ) - log.info(f"📈 学习率调度器: {scheduler_type},{warmup_steps} 步预热 (最小学习率={min_lr})") - - # ========================================================================= - # 4.1 断点续训(恢复模型、优化器、调度器、步数) - # ========================================================================= - start_step = 0 - resume_loss = None - resume_best_loss = float('inf') - - resume_ckpt = cfg.train.get('resume_ckpt', None) - resume_path = resolve_resume_checkpoint(resume_ckpt, checkpoint_dir) - if resume_ckpt is not None: - if pretrained_ckpt is not None: - log.warning("⚠️ [Resume] 同时设置了 pretrained_ckpt 与 resume_ckpt,将优先使用 resume_ckpt 进行断点续训") - if resume_path is None: - log.warning("⚠️ [Resume] 未找到可恢复的 checkpoint,将从头开始训练") - elif not resume_path.exists(): - log.error(f"❌ [Resume] Checkpoint 文件不存在: {resume_path}") - log.warning("⚠️ 将从头开始训练") - else: - log.info(f"🔄 [Resume] 从 checkpoint 恢复训练: {resume_path}") - try: - checkpoint = torch.load(resume_path, map_location=cfg.train.device) - - agent.load_state_dict(checkpoint['model_state_dict'], strict=True) - optimizer.load_state_dict(checkpoint['optimizer_state_dict']) - scheduler.load_state_dict(checkpoint['scheduler_state_dict']) - - resume_step = int(checkpoint['step']) - start_step = resume_step + 1 - - loaded_loss = checkpoint.get('loss', None) - loaded_val_loss = checkpoint.get('val_loss', None) - resume_loss = float(loaded_loss) if loaded_loss is not None else None - if loaded_val_loss is not None: - resume_best_loss = float(loaded_val_loss) - elif loaded_loss is not None: - resume_best_loss = float(loaded_loss) - - log.info(f"✅ [Resume] 恢复成功: 上次步骤={resume_step}, 本次从步骤 {start_step} 开始") - log.info(f"📈 [Resume] 当前学习率: {optimizer.param_groups[0]['lr']:.2e}") - except Exception as e: - log.error(f"❌ [Resume] 恢复失败: {e}") - log.warning("⚠️ 将从头开始训练") - start_step = 0 - resume_loss = None - resume_best_loss = float('inf') - - # ========================================================================= - # 5. 训练循环 - # ========================================================================= - log.info("🏋️ 开始训练循环...") - - def build_agent_input(batch_data): - """构建 agent 输入格式""" - images = {} - # SimpleRobotDataset 返回 observation.{cam_name} 格式 - for cam_name in cfg.data.camera_names: - key = f"observation.{cam_name}" - if key in batch_data: - images[cam_name] = batch_data[key] - - return { - 'images': images, - 'qpos': batch_data['observation.state'], # SimpleRobotDataset 使用 observation.state - 'action': batch_data['action'], - 'action_is_pad': batch_data.get('action_is_pad', None) # 传递padding mask - } - - def run_validation(): - """运行验证""" - if val_loader is None: - return None - agent.eval() - - # 设置确定性种子以获得可重现的损失 - # 这确保验证损失在不同步骤之间可比较 - torch.manual_seed(42) - if torch.cuda.is_available(): - torch.cuda.manual_seed(42) - - total_loss = 0.0 - num_batches = 0 - with torch.no_grad(): - for val_batch in val_loader: - val_batch = recursive_to_device(val_batch, cfg.train.device) - val_input = build_agent_input(val_batch) - val_loss = agent.compute_loss(val_input) - total_loss += val_loss.item() - num_batches += 1 - agent.train() - return total_loss / max(num_batches, 1) - - data_iter = iter(train_loader) - pbar = tqdm(range(start_step, cfg.train.max_steps), desc="训练中", ncols=100) - - best_loss = resume_best_loss - last_loss = resume_loss - - if start_step >= cfg.train.max_steps: - log.warning( - f"⚠️ [Resume] start_step={start_step} 已达到/超过 max_steps={cfg.train.max_steps},跳过训练循环" - ) - - for step in pbar: + # ========================================================================= + # 1. 实例化数据集与 DataLoader + # ========================================================================= + log.info("📦 加载数据集...") try: - batch = next(data_iter) - except StopIteration: - # 轮次结束时重启迭代器 - data_iter = iter(train_loader) - batch = next(data_iter) - - # ===================================================================== - # 将批次移至设备 - # ===================================================================== - batch = recursive_to_device(batch, cfg.train.device) - - # ===================================================================== - # 准备 agent 输入 - # ===================================================================== - # 数据集返回: {action, qpos, image_, ...} - # Agent 期望: {images: dict, qpos: tensor, action: tensor} - - # 准备 agent 输入 - agent_input = build_agent_input(batch) - - # ===================================================================== - # 前向传播与损失计算 - # ===================================================================== - try: - loss = agent.compute_loss(agent_input) + dataset = instantiate(cfg.data) + log.info(f"✅ 数据集加载成功。总样本数: {len(dataset)}") except Exception as e: - log.error(f"❌ 步骤 {step} 前向传播失败: {e}") + log.error(f"❌ 数据集加载失败: {e}") raise - last_loss = loss.item() + # 训练/验证集划分 + val_split = float(cfg.train.get('val_split', 0.1)) + seed = int(cfg.train.get('seed', 42)) + val_size = int(len(dataset) * val_split) + train_size = len(dataset) - val_size + if val_size > 0: + train_dataset, val_dataset = random_split( + dataset, + [train_size, val_size], + generator=torch.Generator().manual_seed(seed) + ) + log.info(f"✅ 数据集划分: 训练集={train_size}, 验证集={val_size} (验证比例={val_split})") + else: + train_dataset, val_dataset = dataset, None + log.info("✅ 数据集划分: 全部用于训练, 验证集=0 (验证比例=0)") - # ===================================================================== - # 反向传播与优化 - # ===================================================================== - optimizer.zero_grad() - loss.backward() + train_batch_size = int(cfg.train.batch_size) + train_drop_last = len(train_dataset) >= train_batch_size + if not train_drop_last: + log.warning( + "⚠️ 训练集样本数 (%s) 小于 batch_size (%s),将保留最后一个不完整批次以避免空训练加载器", + len(train_dataset), + train_batch_size, + ) - # 梯度裁剪以稳定训练 - torch.nn.utils.clip_grad_norm_(agent.parameters(), max_norm=grad_clip) + train_loader = DataLoader( + train_dataset, + batch_size=train_batch_size, + shuffle=True, + num_workers=cfg.train.num_workers, + pin_memory=(cfg.train.device != "cpu"), + persistent_workers=False, + drop_last=train_drop_last + ) - optimizer.step() - scheduler.step() + val_loader = None + if val_dataset is not None: + val_loader = DataLoader( + val_dataset, + batch_size=train_batch_size, + shuffle=False, + num_workers=cfg.train.num_workers, + pin_memory=(cfg.train.device != "cpu"), + persistent_workers=False, + drop_last=False + ) - # ===================================================================== - # 日志记录 - # ===================================================================== - if step % cfg.train.log_freq == 0: - current_lr = optimizer.param_groups[0]['lr'] - pbar.set_postfix({ - "loss": f"{loss.item():.4f}", - "lr": f"{current_lr:.2e}", - "best_loss": f"{best_loss:.4f}" - }) - log.info(f"步骤 {step}/{cfg.train.max_steps} | 损失: {loss.item():.4f} | 学习率: {current_lr:.2e}") + log.info(f"✅ 训练加载器每轮批次数: {len(train_loader)}") + if val_loader is not None: + log.info(f"✅ 验证加载器每轮批次数: {len(val_loader)}") - # ===================================================================== - # 检查点保存与验证 - # ===================================================================== - if step > 0 and step % cfg.train.save_freq == 0: - # 运行验证 - val_loss = run_validation() - if val_loss is not None: - log.info(f"步骤 {step}/{cfg.train.max_steps} | 验证损失: {val_loss:.4f}") + # ========================================================================= + # 2. 加载数据集统计信息(将传递给 agent) + # ========================================================================= + log.info("💾 加载数据集统计信息...") + dataset_stats = None + try: + dataset_dir = cfg.data.get('dataset_dir', 'roboimi/demos/dataset/sim_transfer') + stats_path = Path(dataset_dir) / 'dataset_stats.pkl' - checkpoint_path = checkpoint_dir / f"vla_model_step_{step}.pt" - # 使用agent的归一化统计信息(包含normalization_type) + if stats_path.exists(): + with open(stats_path, 'rb') as f: + stats = pickle.load(f) + + # 扁平化stats字典(嵌套结构→扁平结构)以匹配NormalizationModule的期望格式 + dataset_stats = { + 'action_mean': stats['action_mean'].tolist(), + 'action_std': stats['action_std'].tolist(), + 'action_min': stats['action_min'].tolist(), + 'action_max': stats['action_max'].tolist(), + 'qpos_mean': stats['qpos_mean'].tolist(), + 'qpos_std': stats['qpos_std'].tolist(), + 'qpos_min': stats['qpos_min'].tolist(), + 'qpos_max': stats['qpos_max'].tolist(), + } + log.info(f"✅ 数据集统计信息加载完成 (归一化: {cfg.agent.normalization_type})") + else: + log.warning(f"⚠️ 统计文件未找到: {stats_path}") + log.warning("⚠️ 推理时动作将无法反归一化!") + + except Exception as e: + log.warning(f"⚠️ 统计信息加载失败: {e}") + log.warning("⚠️ 训练将继续,但推理可能无法正常工作") + + # ========================================================================= + # 3. 实例化 VLA Agent + # ========================================================================= + log.info("🤖 初始化 VLA Agent...") + try: + # 将 dataset_stats 和 normalization_type 传递给 agent + agent = instantiate(cfg.agent, dataset_stats=dataset_stats) + agent.to(cfg.train.device) + agent.train() + log.info(f"✅ Agent 初始化完成并已移至 {cfg.train.device}") + + # 统计参数量 + total_params = sum(p.numel() for p in agent.parameters()) + trainable_params = sum(p.numel() for p in agent.parameters() if p.requires_grad) + log.info(f"📊 总参数量: {total_params:,}") + log.info(f"📊 可训练参数量: {trainable_params:,}") + + except Exception as e: + log.error(f"❌ Agent 初始化失败: {e}") + raise + + # ========================================================================= + # 3.1 从预训练 checkpoint 加载权重(微调) + # ========================================================================= + pretrained_ckpt = cfg.train.get('pretrained_ckpt', None) + if pretrained_ckpt is not None: + ckpt_path = Path(pretrained_ckpt) + if ckpt_path.exists(): + log.info(f"🔄 [Finetune] 从预训练 checkpoint 加载权重: {ckpt_path}") + try: + checkpoint = torch.load(ckpt_path, map_location=cfg.train.device) + + # 只加载模型权重(不加载 optimizer、scheduler) + missing_keys, unexpected_keys = agent.load_state_dict( + checkpoint['model_state_dict'], + strict=False # 允许部分加载(结构不完全匹配时) + ) + + log.info(f"✅ [Finetune] 模型权重加载成功") + + if missing_keys: + log.warning(f"⚠️ [Finetune] 缺少的键 ({len(missing_keys)} 个): {missing_keys[:5]}...") + if unexpected_keys: + log.warning(f"⚠️ [Finetune] 多余的键 ({len(unexpected_keys)} 个): {unexpected_keys[:5]}...") + + log.info(f"📊 [Finetune] 预训练信息: 步骤={checkpoint.get('step', 'N/A')}, 损失={checkpoint.get('loss', 'N/A')}") + log.info(f"📈 [Finetune] 使用新的训练配置(lr={cfg.train.lr}, max_steps={cfg.train.max_steps})") + + except Exception as e: + log.error(f"❌ [Finetune] 加载 checkpoint 失败: {e}") + log.warning("⚠️ 将从头开始训练") + else: + log.error(f"❌ [Finetune] Checkpoint 文件不存在: {ckpt_path}") + log.warning("⚠️ 将从头开始训练") + + # ========================================================================= + # 4. 设置优化器与学习率调度器 + # ========================================================================= + weight_decay = float(cfg.train.get('weight_decay', 1e-5)) + grad_clip = float(cfg.train.get('grad_clip', 1.0)) + + optimizer = build_training_optimizer(agent, lr=cfg.train.lr, weight_decay=weight_decay) + log.info(f"🔧 优化器: AdamW (学习率={cfg.train.lr}, weight_decay={weight_decay})") + + # 设置带预热的学習率调度器 + warmup_steps = int(cfg.train.get('warmup_steps', 500)) + scheduler_type = cfg.train.get('scheduler_type', 'cosine') + min_lr = float(cfg.train.get('min_lr', 1e-6)) + + scheduler = get_lr_schedule_with_warmup( + optimizer, + warmup_steps=warmup_steps, + max_steps=cfg.train.max_steps, + scheduler_type=scheduler_type, + min_lr=min_lr + ) + log.info(f"📈 学习率调度器: {scheduler_type},{warmup_steps} 步预热 (最小学习率={min_lr})") + + # ========================================================================= + # 4.1 断点续训(恢复模型、优化器、调度器、步数) + # ========================================================================= + def extract_checkpoint_metric_baseline(checkpoint): + checkpoint_loss = checkpoint.get('loss', None) + checkpoint_val_loss = checkpoint.get('val_loss', None) + checkpoint_rollout_reward = checkpoint.get('rollout_avg_reward', None) + + baseline_loss = float('inf') + baseline_rollout_reward = float('-inf') + if checkpoint_rollout_reward is not None: + baseline_rollout_reward = float(checkpoint_rollout_reward) + if checkpoint_val_loss is not None: + baseline_loss = float(checkpoint_val_loss) + elif checkpoint_loss is not None: + baseline_loss = float(checkpoint_loss) + return baseline_loss, baseline_rollout_reward + + start_step = 0 + resume_loss = None + resume_best_loss = float('inf') + resume_best_rollout_reward = float('-inf') + best_model_path = None + + resume_ckpt = cfg.train.get('resume_ckpt', None) + resume_path = resolve_resume_checkpoint(resume_ckpt, checkpoint_dir) + if resume_ckpt is not None: + if pretrained_ckpt is not None: + log.warning("⚠️ [Resume] 同时设置了 pretrained_ckpt 与 resume_ckpt,将优先使用 resume_ckpt 进行断点续训") + if resume_path is None: + log.warning("⚠️ [Resume] 未找到可恢复的 checkpoint,将从头开始训练") + elif not resume_path.exists(): + log.error(f"❌ [Resume] Checkpoint 文件不存在: {resume_path}") + log.warning("⚠️ 将从头开始训练") + else: + log.info(f"🔄 [Resume] 从 checkpoint 恢复训练: {resume_path}") + try: + checkpoint = torch.load(resume_path, map_location=cfg.train.device) + + agent.load_state_dict(checkpoint['model_state_dict'], strict=True) + optimizer.load_state_dict(checkpoint['optimizer_state_dict']) + scheduler.load_state_dict(checkpoint['scheduler_state_dict']) + + resume_step = int(checkpoint['step']) + start_step = resume_step + 1 + + loaded_loss = checkpoint.get('loss', None) + resume_loss = float(loaded_loss) if loaded_loss is not None else None + resume_best_loss, resume_best_rollout_reward = extract_checkpoint_metric_baseline(checkpoint) + if ( + resume_best_rollout_reward != float('-inf') + or resume_best_loss != float('inf') + ): + best_model_path = resume_path + + if default_best_model_path.exists(): + try: + best_checkpoint = torch.load(default_best_model_path, map_location=cfg.train.device) + _, best_checkpoint_rollout_reward = ( + extract_checkpoint_metric_baseline(best_checkpoint) + ) + if best_checkpoint_rollout_reward != float('-inf'): + resume_best_rollout_reward = best_checkpoint_rollout_reward + best_model_path = default_best_model_path + log.info( + "📈 [Resume] 从最佳 checkpoint 恢复最佳 rollout 基线: %s", + default_best_model_path, + ) + except Exception as e: + log.warning( + f"⚠️ [Resume] 读取最佳 checkpoint 失败,将回退到恢复 checkpoint 的验证基线: {e}" + ) + + log.info(f"✅ [Resume] 恢复成功: 上次步骤={resume_step}, 本次从步骤 {start_step} 开始") + log.info(f"📈 [Resume] 当前学习率: {optimizer.param_groups[0]['lr']:.2e}") + except Exception as e: + log.error(f"❌ [Resume] 恢复失败: {e}") + log.warning("⚠️ 将从头开始训练") + start_step = 0 + resume_loss = None + resume_best_loss = float('inf') + resume_best_rollout_reward = float('-inf') + + # ========================================================================= + # 5. 训练循环 + # ========================================================================= + log.info("🏋️ 开始训练循环...") + + def build_agent_input(batch_data): + """构建 agent 输入格式""" + images = {} + # SimpleRobotDataset 返回 observation.{cam_name} 格式 + for cam_name in cfg.data.camera_names: + key = f"observation.{cam_name}" + if key in batch_data: + images[cam_name] = batch_data[key] + + return { + 'images': images, + 'qpos': batch_data['observation.state'], # SimpleRobotDataset 使用 observation.state + 'action': batch_data['action'], + 'action_is_pad': batch_data.get('action_is_pad', None) # 传递padding mask + } + + def save_checkpoint(checkpoint_path: Path, step: int, loss_value, val_loss=None, rollout_avg_reward=None): agent_stats = agent.get_normalization_stats() torch.save({ 'step': step, 'model_state_dict': agent.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'scheduler_state_dict': scheduler.state_dict(), - 'loss': loss.item(), + 'loss': loss_value, 'val_loss': val_loss, + 'rollout_avg_reward': rollout_avg_reward, 'dataset_stats': agent_stats, # 保存agent的统计信息 'current_lr': optimizer.param_groups[0]['lr'], }, checkpoint_path) - log.info(f"💾 检查点已保存: {checkpoint_path}") + return checkpoint_path - # 根据验证损失保存最佳模型 - eval_loss = val_loss if val_loss is not None else loss.item() - if eval_loss < best_loss: - best_loss = eval_loss - best_model_path = checkpoint_dir / "vla_model_best.pt" - agent_stats = agent.get_normalization_stats() - torch.save({ - 'step': step, - 'model_state_dict': agent.state_dict(), - 'optimizer_state_dict': optimizer.state_dict(), - 'scheduler_state_dict': scheduler.state_dict(), - 'loss': loss.item(), - 'val_loss': val_loss, - 'dataset_stats': agent_stats, # 保存agent的统计信息 - 'current_lr': optimizer.param_groups[0]['lr'], - }, best_model_path) - log.info(f"🌟 最佳模型已更新: {best_model_path} (验证损失: {best_loss:.4f})") + def run_validation(): + """运行验证""" + if val_loader is None: + return None + agent.eval() - # ========================================================================= - # 6. 保存最终模型 - # ========================================================================= - final_model_path = checkpoint_dir / "vla_model_final.pt" - agent_stats = agent.get_normalization_stats() - torch.save({ - 'step': cfg.train.max_steps, - 'model_state_dict': agent.state_dict(), - 'optimizer_state_dict': optimizer.state_dict(), - 'scheduler_state_dict': scheduler.state_dict(), - 'loss': last_loss, - 'dataset_stats': agent_stats, # 保存agent的统计信息 - 'current_lr': optimizer.param_groups[0]['lr'], - }, final_model_path) - log.info(f"💾 最终模型已保存: {final_model_path}") + # 设置确定性种子以获得可重现的损失 + # 这确保验证损失在不同步骤之间可比较 + torch.manual_seed(42) + if torch.cuda.is_available(): + torch.cuda.manual_seed(42) - log.info("✅ 训练成功完成!") - if last_loss is not None: - log.info(f"📊 最终损失: {last_loss:.4f}") - else: - log.info("📊 最终损失: N/A(未执行训练步)") - if best_loss != float('inf'): - log.info(f"📊 最佳损失: {best_loss:.4f}") - else: - log.info("📊 最佳损失: N/A(无有效验证/训练损失)") + total_loss = 0.0 + num_batches = 0 + with torch.no_grad(): + for val_batch in val_loader: + val_batch = recursive_to_device(val_batch, cfg.train.device) + val_input = build_agent_input(val_batch) + val_loss = agent.compute_loss(val_input) + total_loss += val_loss.item() + num_batches += 1 + agent.train() + return total_loss / max(num_batches, 1) + + def run_rollout_validation(checkpoint_path: Path): + from roboimi.demos.vla_scripts import eval_vla + + rollout_cfg = OmegaConf.create(OmegaConf.to_container(cfg, resolve=False)) + rollout_cfg.eval.ckpt_path = str(checkpoint_path) + rollout_cfg.eval.num_episodes = int(cfg.train.get('rollout_num_episodes', 1)) + rollout_cfg.eval.headless = True + rollout_cfg.eval.device = 'cpu' + rollout_cfg.eval.verbose_action = False + + log.info( + "🎯 开始 checkpoint rollout 验证: %s (episodes=%s, headless=True)", + checkpoint_path, + rollout_cfg.eval.num_episodes, + ) + return eval_vla._run_eval(rollout_cfg) + + def run_checkpoint_rollout_validation(checkpoint_path: Path): + if not bool(cfg.train.get('rollout_validate_on_checkpoint', False)): + return None + return run_rollout_validation(checkpoint_path) + + data_iter = iter(train_loader) + pbar = tqdm(range(start_step, cfg.train.max_steps), desc="训练中", ncols=100) + + steps_per_epoch = len(train_loader) + rollout_val_freq_epochs = int(cfg.train.get('rollout_val_freq_epochs', 0) or 0) + rollout_validation_enabled = rollout_val_freq_epochs > 0 + best_loss = resume_best_loss + best_rollout_reward = resume_best_rollout_reward + last_loss = resume_loss + + if start_step >= cfg.train.max_steps: + log.warning( + f"⚠️ [Resume] start_step={start_step} 已达到/超过 max_steps={cfg.train.max_steps},跳过训练循环" + ) + + for step in pbar: + try: + batch = next(data_iter) + except StopIteration: + # 轮次结束时重启迭代器 + data_iter = iter(train_loader) + batch = next(data_iter) + + # ===================================================================== + # 将批次移至设备 + # ===================================================================== + batch = recursive_to_device(batch, cfg.train.device) + + # ===================================================================== + # 准备 agent 输入 + # ===================================================================== + # 数据集返回: {action, qpos, image_, ...} + # Agent 期望: {images: dict, qpos: tensor, action: tensor} + + # 准备 agent 输入 + agent_input = build_agent_input(batch) + + # ===================================================================== + # 前向传播与损失计算 + # ===================================================================== + try: + loss = agent.compute_loss(agent_input) + except Exception as e: + log.error(f"❌ 步骤 {step} 前向传播失败: {e}") + raise + + last_loss = loss.item() + + # ===================================================================== + # 反向传播与优化 + # ===================================================================== + optimizer.zero_grad() + loss.backward() + + # 梯度裁剪以稳定训练 + torch.nn.utils.clip_grad_norm_(agent.parameters(), max_norm=grad_clip) + + optimizer.step() + scheduler.step() + + # ===================================================================== + # 日志记录 + # ===================================================================== + if step % cfg.train.log_freq == 0: + current_lr = optimizer.param_groups[0]['lr'] + best_loss_to_log = best_loss if best_loss != float('inf') else loss.item() + pbar.set_postfix({ + "loss": f"{loss.item():.4f}", + "lr": f"{current_lr:.2e}", + "best_loss": f"{best_loss_to_log:.4f}" + }) + log.info(f"步骤 {step}/{cfg.train.max_steps} | 损失: {loss.item():.4f} | 学习率: {current_lr:.2e}") + _log_to_swanlab( + swanlab_module, + { + 'train/loss': loss.item(), + 'train/lr': current_lr, + 'train/best_loss': best_loss_to_log, + 'train/step': step, + }, + step=step, + ) + + # ===================================================================== + # 检查点保存与验证 + # ===================================================================== + checkpoint_path = None + val_loss = None + if step > 0 and step % cfg.train.save_freq == 0: + # 运行验证 + val_loss = run_validation() + if val_loss is not None: + log.info(f"步骤 {step}/{cfg.train.max_steps} | 验证损失: {val_loss:.4f}") + _log_to_swanlab( + swanlab_module, + {'val/loss': val_loss}, + step=step, + ) + + checkpoint_path = checkpoint_dir / f"vla_model_step_{step}.pt" + save_checkpoint( + checkpoint_path, + step, + loss.item(), + val_loss=val_loss, + ) + log.info(f"💾 检查点已保存: {checkpoint_path}") + + # 在首次拿到 rollout 平均奖励之前,使用损失作为最佳模型回退指标 + if best_rollout_reward == float('-inf'): + eval_loss = val_loss if val_loss is not None else loss.item() + if eval_loss < best_loss: + best_loss = eval_loss + best_model_path = default_best_model_path + save_checkpoint( + best_model_path, + step, + loss.item(), + val_loss=val_loss, + ) + log.info(f"🌟 最佳模型已更新: {best_model_path} (验证损失: {best_loss:.4f})") + + checkpoint_rollout_stats = run_checkpoint_rollout_validation(checkpoint_path) + checkpoint_rollout_avg_reward = ( + checkpoint_rollout_stats.get('avg_reward') + if checkpoint_rollout_stats is not None else None + ) + if checkpoint_rollout_avg_reward is not None: + log.info( + f"步骤 {step}/{cfg.train.max_steps} | checkpoint rollout 平均奖励: " + f"{checkpoint_rollout_avg_reward:.4f}" + ) + _log_to_swanlab( + swanlab_module, + {'rollout/avg_reward': checkpoint_rollout_avg_reward}, + step=step, + ) + if checkpoint_rollout_avg_reward > best_rollout_reward: + best_rollout_reward = checkpoint_rollout_avg_reward + best_model_path = default_best_model_path + save_checkpoint( + best_model_path, + step, + loss.item(), + val_loss=val_loss, + rollout_avg_reward=checkpoint_rollout_avg_reward, + ) + log.info( + f"🌟 最佳模型已更新: {best_model_path} " + f"(checkpoint rollout 平均奖励: {best_rollout_reward:.4f})" + ) + + completed_steps = step + 1 + completed_epoch = ( + completed_steps // steps_per_epoch + if steps_per_epoch > 0 else 0 + ) + should_run_epoch_rollout = ( + rollout_validation_enabled + and steps_per_epoch > 0 + and completed_steps % steps_per_epoch == 0 + and completed_epoch > 0 + and completed_epoch % rollout_val_freq_epochs == 0 + ) + if should_run_epoch_rollout: + if checkpoint_path is None: + checkpoint_path = checkpoint_dir / f"vla_model_step_{step}.pt" + save_checkpoint( + checkpoint_path, + step, + loss.item(), + val_loss=val_loss, + ) + log.info(f"💾 Epoch rollout 验证前检查点已保存: {checkpoint_path}") + + rollout_stats = run_rollout_validation(checkpoint_path) + rollout_avg_reward = ( + rollout_stats.get('avg_reward') + if rollout_stats is not None else None + ) + if rollout_avg_reward is not None: + log.info( + f"步骤 {step}/{cfg.train.max_steps} | Epoch {completed_epoch} " + f"rollout 平均奖励: {rollout_avg_reward:.4f}" + ) + _log_to_swanlab( + swanlab_module, + { + 'rollout/avg_reward': rollout_avg_reward, + 'rollout/epoch': completed_epoch, + }, + step=step, + ) + if rollout_avg_reward > best_rollout_reward: + best_rollout_reward = rollout_avg_reward + best_model_path = default_best_model_path + save_checkpoint( + best_model_path, + step, + loss.item(), + val_loss=val_loss, + rollout_avg_reward=rollout_avg_reward, + ) + log.info( + f"🌟 最佳模型已更新: {best_model_path} " + f"(Epoch {completed_epoch} rollout 平均奖励: {best_rollout_reward:.4f})" + ) + + # ========================================================================= + # 6. 保存最终模型 + # ========================================================================= + final_model_path = checkpoint_dir / "vla_model_final.pt" + save_checkpoint( + final_model_path, + cfg.train.max_steps, + last_loss, + ) + log.info(f"💾 最终模型已保存: {final_model_path}") + _log_to_swanlab( + swanlab_module, + { + 'final/checkpoint_path': str(final_model_path), + 'final/best_checkpoint_path': ( + str(best_model_path) if best_model_path is not None else '' + ), + }, + step=cfg.train.max_steps, + ) + + log.info("✅ 训练成功完成!") + if last_loss is not None: + log.info(f"📊 最终损失: {last_loss:.4f}") + else: + log.info("📊 最终损失: N/A(未执行训练步)") + if best_rollout_reward != float('-inf'): + log.info(f"📊 最佳 rollout 平均奖励: {best_rollout_reward:.4f}") + elif best_loss != float('inf'): + log.info(f"📊 最佳损失: {best_loss:.4f}") + else: + log.info("📊 最佳验证指标: N/A(无有效 rollout/验证损失)") + finally: + _finish_swanlab(swanlab_module) + + +@hydra.main(version_base=None, config_path="../../vla/conf", config_name="config") +def main(cfg: DictConfig): + _run_training(cfg) if __name__ == "__main__": diff --git a/roboimi/envs/double_base.py b/roboimi/envs/double_base.py index d84de3d..1089d3a 100644 --- a/roboimi/envs/double_base.py +++ b/roboimi/envs/double_base.py @@ -213,7 +213,9 @@ class DualDianaMed(MujocoEnv): def camera_viewer(self): img_renderer = mj.Renderer(self.mj_model,height=480,width=640) - cv2.namedWindow('Cam view',cv2.WINDOW_NORMAL) + show_gui = self.is_render + if show_gui: + cv2.namedWindow('Cam view',cv2.WINDOW_NORMAL) while not self.exit_flag: img_renderer.update_scene(self.mj_data,camera="rs_cam_right") self.r_vis = img_renderer.render() @@ -230,9 +232,10 @@ class DualDianaMed(MujocoEnv): img_renderer.update_scene(self.mj_data,camera="front") self.front = img_renderer.render() self.front = self.front[:, :, ::-1] - if self.cam_view is not None: - cv2.imshow('Cam view', self.cam_view) - cv2.waitKey(1) + if show_gui: + if self.cam_view is not None: + cv2.imshow('Cam view', self.cam_view) + cv2.waitKey(1) def cam_start(self): @@ -300,4 +303,4 @@ if __name__ == "__main__": # print("quat_right =",quat_right,"\n") if env.is_render: env.render() - \ No newline at end of file + diff --git a/roboimi/envs/double_pos_ctrl_env.py b/roboimi/envs/double_pos_ctrl_env.py index 2189b44..78cb1a6 100644 --- a/roboimi/envs/double_pos_ctrl_env.py +++ b/roboimi/envs/double_pos_ctrl_env.py @@ -133,12 +133,12 @@ class DualDianaMed_Pos_Ctrl(DualDianaMed): return reward -def make_sim_env(task_name): +def make_sim_env(task_name, headless=False): if 'sim_transfer' in task_name: from roboimi.assets.robots.diana_med import BiDianaMed env = DualDianaMed_Pos_Ctrl( robot=BiDianaMed(), - is_render=True, + is_render=not headless, control_freq=30, is_interpolate=True, cam_view='angle' @@ -167,4 +167,4 @@ if __name__ == "__main__": env.step(action) if env.is_render: env.render() - \ No newline at end of file + diff --git a/roboimi/vla/agent.py b/roboimi/vla/agent.py index b35d568..12f8a26 100644 --- a/roboimi/vla/agent.py +++ b/roboimi/vla/agent.py @@ -3,10 +3,8 @@ import torch.nn as nn import numpy as np from collections import deque from typing import Dict, Optional, Any, Tuple -from roboimi.vla.core.interfaces import VLABackbone, VLAProjector, VLAHead from diffusers.schedulers.scheduling_ddpm import DDPMScheduler from diffusers.schedulers.scheduling_ddim import DDIMScheduler -from roboimi.vla.models.heads.conditional_unet1d import ConditionalUnet1D from roboimi.vla.models.normalization import NormalizationModule class VLAAgent(nn.Module): @@ -24,6 +22,7 @@ class VLAAgent(nn.Module): diffusion_steps=100, # DDPM 加噪步数 inference_steps=10, # DDIM 推理步数 num_cams=3, # 视觉输入的摄像头数量 + camera_names: Optional[Tuple[str, ...]] = None, # 条件相机顺序 dataset_stats=None, # 数据集统计信息,用于归一化 normalization_type='min_max', # 归一化类型: 'gaussian' 或 'min_max' num_action_steps=8, # 每次推理实际执行多少步动作 @@ -39,6 +38,31 @@ class VLAAgent(nn.Module): self.num_action_steps = num_action_steps self.inference_steps = inference_steps self.head_type = head_type # 'unet' 或 'transformer' + agent_camera_names = tuple(camera_names) if camera_names is not None else None + backbone_camera_names = getattr(vision_backbone, 'camera_names', None) + backbone_camera_names = tuple(backbone_camera_names) if backbone_camera_names is not None else None + backbone_num_cameras = getattr(vision_backbone, 'num_cameras', None) + if backbone_num_cameras is not None and backbone_num_cameras != self.num_cams: + raise ValueError( + f"agent.num_cams({self.num_cams}) 与 " + f"vision_backbone.num_cameras({backbone_num_cameras}) 不一致" + ) + if ( + agent_camera_names is not None + and backbone_camera_names is not None + and agent_camera_names != backbone_camera_names + ): + raise ValueError( + f"agent.camera_names({list(agent_camera_names)}) 与 " + f"vision_backbone.camera_names({list(backbone_camera_names)}) 不一致" + ) + self.camera_names = ( + agent_camera_names if agent_camera_names is not None else backbone_camera_names + ) + if self.camera_names is not None and len(self.camera_names) != self.num_cams: + raise ValueError( + f"camera_names 长度({len(self.camera_names)})与 num_cams({self.num_cams})不一致" + ) # 归一化模块 - 统一训练和推理的归一化逻辑 @@ -48,6 +72,8 @@ class VLAAgent(nn.Module): ) self.vision_encoder = vision_backbone + if self.camera_names is not None: + self.vision_encoder.camera_names = self.camera_names single_cam_feat_dim = self.vision_encoder.output_dim # global_cond_dim: 展平后的总维度(用于UNet) total_vision_dim = single_cam_feat_dim * num_cams * obs_horizon @@ -117,6 +143,34 @@ class VLAAgent(nn.Module): return tuple(self._move_to_device(v, device) for v in data) return data + def _order_images(self, images: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """按显式配置的相机顺序返回图像字典。""" + if self.camera_names is None: + camera_names = tuple(sorted(images.keys())) + if len(camera_names) != self.num_cams: + raise ValueError( + f"图像条件相机数量({len(camera_names)})与 num_cams({self.num_cams})不一致" + ) + return {cam_name: images[cam_name] for cam_name in camera_names} + + missing = [cam_name for cam_name in self.camera_names if cam_name not in images] + if missing: + raise ValueError( + f"图像条件缺少必需相机。missing={missing}, expected={list(self.camera_names)}" + ) + return {cam_name: images[cam_name] for cam_name in self.camera_names} + + def _build_cond(self, images: Dict[str, torch.Tensor], states: torch.Tensor) -> torch.Tensor: + """构造每步条件,确保图像条件顺序稳定。""" + ordered_images = self._order_images(images) + visual_features = self.vision_encoder(ordered_images) + state_features = self.state_encoder(states) + cond = torch.cat([visual_features, state_features], dim=-1) + if cond.shape[-1] != self.per_step_cond_dim: + raise RuntimeError( + f"条件维度不匹配: got {cond.shape[-1]}, expected {self.per_step_cond_dim}" + ) + return cond # ========================== # 训练阶段 (Training) @@ -136,10 +190,8 @@ class VLAAgent(nn.Module): states = self.normalization.normalize_qpos(states) actions = self.normalization.normalize_action(actions) - state_features = self.state_encoder(states) - # 1. 提取视觉特征 - visual_features = self.vision_encoder(images) # (B, obs_horizon, vision_dim) + per_step_cond = self._build_cond(images, states) action_features = self.action_encoder(actions) # 2. 采样噪声 @@ -157,21 +209,16 @@ class VLAAgent(nn.Module): ) # 拼接全局条件并展平 - # visual_features: (B, obs_horizon, vision_dim) - # state_features: (B, obs_horizon, obs_dim) - # 拼接后展平为 (B, obs_horizon * (vision_dim + obs_dim)) - global_cond = torch.cat([visual_features, state_features], dim=-1) - global_cond = global_cond.flatten(start_dim=1) + # per_step_cond: (B, obs_horizon, vision_dim * num_cams + obs_dim) + # 展平后用于 UNet,全序列形式用于 Transformer + global_cond = per_step_cond.flatten(start_dim=1) # 5. 网络预测噪声(根据head类型选择接口) if self.head_type == 'transformer': - # Transformer需要序列格式的条件: (B, obs_horizon, cond_dim_per_step) - # 将展平的global_cond reshape回序列格式 - cond = global_cond.reshape(B, self.obs_horizon, self.per_step_cond_dim) pred_noise = self.noise_pred_net( sample=noisy_actions, timestep=timesteps, - cond=cond + cond=per_step_cond ) else: # 'unet' pred_noise = self.noise_pred_net( @@ -218,7 +265,8 @@ class VLAAgent(nn.Module): # 添加图像 if 'images' in observation: - self._queues['images'].append({k: v.clone() for k, v in observation['images'].items()}) + ordered_images = self._order_images(observation['images']) + self._queues['images'].append({k: v.clone() for k, v in ordered_images.items()}) def _prepare_observation_batch(self) -> Dict[str, torch.Tensor]: """ @@ -246,7 +294,8 @@ class VLAAgent(nn.Module): images_list.append(images_list[-1]) batch_images = {} - for cam_name in images_list[0].keys(): + camera_names = self.camera_names if self.camera_names is not None else tuple(sorted(images_list[0].keys())) + for cam_name in camera_names: batch_images[cam_name] = torch.stack([img[cam_name] for img in images_list], dim=0).unsqueeze(0) return {'qpos': batch_qpos, 'images': batch_images} @@ -346,22 +395,18 @@ class VLAAgent(nn.Module): proprioception = self.normalization.normalize_qpos(proprioception) # 1. 提取当前观测特征(只提取一次) - visual_features = self.vision_encoder(images) - state_features = self.state_encoder(proprioception) + per_step_cond = self._build_cond(images, proprioception) # 拼接条件(只计算一次) - # visual_features: (B, obs_horizon, vision_dim) - # state_features: (B, obs_horizon, obs_dim) - global_cond = torch.cat([visual_features, state_features], dim=-1) - global_cond_flat = global_cond.flatten(start_dim=1) + global_cond_flat = per_step_cond.flatten(start_dim=1) if self.head_type == 'transformer': - cond = global_cond.reshape(B, self.obs_horizon, self.per_step_cond_dim) + cond = per_step_cond else: cond = None # 2. 初始化纯高斯噪声动作 # 形状: (B, pred_horizon, action_dim) - device = visual_features.device + device = per_step_cond.device current_actions = torch.randn( (B, self.pred_horizon, self.action_dim), device=device ) diff --git a/roboimi/vla/conf/agent/resnet_transformer.yaml b/roboimi/vla/conf/agent/resnet_transformer.yaml index fd306a1..5b129fc 100644 --- a/roboimi/vla/conf/agent/resnet_transformer.yaml +++ b/roboimi/vla/conf/agent/resnet_transformer.yaml @@ -29,8 +29,13 @@ num_action_steps: 8 # 每次推理实际执行多少步动作(应 <= p # ==================== # 相机配置 # ==================== +camera_names: ${data.camera_names} # 条件相机顺序固定为 r_vis, top, front num_cams: 3 # 摄像头数量 (r_vis, top, front) +vision_backbone: + num_cameras: ${agent.num_cams} + camera_names: ${agent.camera_names} + # ==================== # 扩散过程配置 # ==================== @@ -52,3 +57,6 @@ head: # ResNet18 + SpatialSoftmax(32 keypoints) = 64维/相机 # 计算方式:单相机特征(64) * 相机数(3) + obs_dim(16) = 208 cond_dim: 208 + causal_attn: false + time_as_cond: true + obs_as_cond: true diff --git a/roboimi/vla/conf/config.yaml b/roboimi/vla/conf/config.yaml index 00b0b5f..6eef43f 100644 --- a/roboimi/vla/conf/config.yaml +++ b/roboimi/vla/conf/config.yaml @@ -9,19 +9,25 @@ defaults: # ==================== train: # 基础训练参数 - batch_size: 8 # 批次大小 - lr: 5e-5 # 学习率(Transformer建议更小) + batch_size: 16 # 批次大小 + lr: 1e-4 # 学习率 max_steps: 100000 # 最大训练步数 device: "cuda" # 设备: "cuda" 或 "cpu" # 数据加载 - num_workers: 8 # DataLoader 工作进程数(调试时设为 0,生产环境用 8) - val_split: 0.1 # 验证集比例 + num_workers: 12 # DataLoader 工作进程数(调试时设为 0) + val_split: 0.0 # 验证集比例;默认使用全量数据训练 seed: 42 # 随机种子(用于数据划分) # 日志和检查点 log_freq: 100 # 日志记录频率(步数) save_freq: 2000 # 保存检查点频率(步数) + use_swanlab: false # 是否启用 SwanLab 标量日志 + swanlab_project: "roboimi-vla" # SwanLab project 名称 + swanlab_run_name: null # 可选的 SwanLab 运行名 + rollout_val_freq_epochs: 50 # 每隔多少个 epoch 执行一次 rollout 验证 + rollout_validate_on_checkpoint: false # 是否在保存 checkpoint 后立即运行 rollout 验证 + rollout_num_episodes: 3 # rollout 验证的回合数 # 学习率调度器(带预热) warmup_steps: 2000 # 预热步数(Transformer建议更长) diff --git a/roboimi/vla/conf/head/transformer1d.yaml b/roboimi/vla/conf/head/transformer1d.yaml index 73b4527..4c9cc78 100644 --- a/roboimi/vla/conf/head/transformer1d.yaml +++ b/roboimi/vla/conf/head/transformer1d.yaml @@ -5,7 +5,7 @@ _partial_: true # ==================== # Transformer 架构配置 # ==================== -n_layer: 4 # Transformer层数(先用小模型提高收敛稳定性) +n_layer: 4 # Transformer层数(保持当前小模型配置) n_head: 4 # 注意力头数 n_emb: 128 # 嵌入维度 p_drop_emb: 0.05 # Embedding dropout @@ -14,9 +14,10 @@ p_drop_attn: 0.05 # Attention dropout # ==================== # 条件配置 # ==================== -causal_attn: false # 是否使用因果注意力(自回归生成) -obs_as_cond: true # 观测作为条件(由cond_dim > 0决定) -n_cond_layers: 1 # 条件编码器层数(1层先做稳定融合) +causal_attn: false # 对齐 external TransformerForDiffusion 的 full-attention / nocausal 变体 +time_as_cond: true # 与 external 实现一致:时间步作为条件 token +obs_as_cond: true # API 对齐;实际是否启用由 cond_dim > 0 决定 +n_cond_layers: 1 # 条件编码器层数(保留当前配置) # ==================== # 注意事项 diff --git a/roboimi/vla/data/simpe_robot_dataset.py b/roboimi/vla/data/simpe_robot_dataset.py index 83c995f..b55ab85 100644 --- a/roboimi/vla/data/simpe_robot_dataset.py +++ b/roboimi/vla/data/simpe_robot_dataset.py @@ -105,7 +105,7 @@ class SimpleRobotDataset(Dataset): self._file_cache[key] = f return f - def _load_frame(self, idx: int) -> Dict: + def _load_frame(self, idx: int, *, load_images: bool = True) -> Dict: """从 HDF5 文件懒加载单帧数据""" meta = self.frame_meta[idx] f = self._get_h5_file(meta["hdf5_path"]) @@ -118,21 +118,22 @@ class SimpleRobotDataset(Dataset): } # 加载图像数据: observations/images/{cam_name} -> observation.{cam_name} - for cam_name in self.camera_names: - h5_path = f'observations/images/{cam_name}' - if h5_path in f: - img = f[h5_path][meta["frame_idx"]] - # Resize图像到224x224(减少内存和I/O负担) - import cv2 - img = cv2.resize(img, (224, 224), interpolation=cv2.INTER_LINEAR) - # 转换为float并归一化到 [0, 1] - img = torch.from_numpy(img).float() / 255.0 - frame[f"observation.{cam_name}"] = img.permute(2, 0, 1) # HWC -> CHW + if load_images: + for cam_name in self.camera_names: + h5_path = f'observations/images/{cam_name}' + if h5_path in f: + img = f[h5_path][meta["frame_idx"]] + # Resize图像到224x224(减少内存和I/O负担) + import cv2 + img = cv2.resize(img, (224, 224), interpolation=cv2.INTER_LINEAR) + # 转换为float并归一化到 [0, 1] + img = torch.from_numpy(img).float() / 255.0 + frame[f"observation.{cam_name}"] = img.permute(2, 0, 1) # HWC -> CHW return frame def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: - frame = self._load_frame(idx) + frame = self._load_frame(idx, load_images=False) ep_idx = frame["episode_index"] # 获取当前 episode 的帧索引范围 @@ -186,10 +187,10 @@ class SimpleRobotDataset(Dataset): target_idx = idx + delta if target_idx <= ep_end: - actions.append(self._load_frame(target_idx)["action"]) + actions.append(self._load_frame(target_idx, load_images=False)["action"]) action_is_pad.append(False) else: - actions.append(self._load_frame(ep_end)["action"]) + actions.append(self._load_frame(ep_end, load_images=False)["action"]) action_is_pad.append(True) # ============================================ diff --git a/roboimi/vla/eval_utils.py b/roboimi/vla/eval_utils.py new file mode 100644 index 0000000..73cb05d --- /dev/null +++ b/roboimi/vla/eval_utils.py @@ -0,0 +1,3 @@ +def execute_policy_action(env, action): + """Execute policy outputs using EE-action semantics.""" + env.step(action) diff --git a/roboimi/vla/models/backbones/resnet_diffusion.py b/roboimi/vla/models/backbones/resnet_diffusion.py index b5c898f..726c504 100644 --- a/roboimi/vla/models/backbones/resnet_diffusion.py +++ b/roboimi/vla/models/backbones/resnet_diffusion.py @@ -178,12 +178,18 @@ class ResNetDiffusionBackbone(VLABackbone): spatial_softmax_num_keypoints: int = 32, use_separate_rgb_encoder_per_camera: bool = False, # 新增:是否为每个摄像头使用独立编码器 num_cameras: int = 1, # 新增:摄像头数量(仅在独立编码器模式下使用) + camera_names: Optional[Tuple[str, ...]] = None, # 显式相机顺序 freeze_backbone: bool = True, # 新增:是否冻结ResNet backbone(推荐True) ): super().__init__() self.use_separate_rgb_encoder_per_camera = use_separate_rgb_encoder_per_camera self.num_cameras = num_cameras + self.camera_names = tuple(camera_names) if camera_names is not None else None + if self.camera_names is not None and len(self.camera_names) != self.num_cameras: + raise ValueError( + f"camera_names 长度({len(self.camera_names)})与 num_cameras({self.num_cameras})不一致" + ) if use_separate_rgb_encoder_per_camera: # 独立编码器模式:为每个摄像头创建独立的编码器 @@ -217,6 +223,22 @@ class ResNetDiffusionBackbone(VLABackbone): ) self.feature_dim = self.rgb_encoder.feature_dim + def _ordered_camera_names(self, images) -> Tuple[str, ...]: + if self.camera_names is None: + camera_names = tuple(sorted(images.keys())) + if len(camera_names) != self.num_cameras: + raise ValueError( + f"图像输入相机数量({len(camera_names)})与 num_cameras({self.num_cameras})不一致" + ) + return camera_names + + missing = [cam_name for cam_name in self.camera_names if cam_name not in images] + if missing: + raise ValueError( + f"图像输入缺少必需相机。missing={missing}, expected={list(self.camera_names)}" + ) + return self.camera_names + def forward(self, images): """ Args: @@ -228,7 +250,7 @@ class ResNetDiffusionBackbone(VLABackbone): """ any_tensor = next(iter(images.values())) B, T = any_tensor.shape[:2] - cam_names = sorted(images.keys()) + cam_names = self._ordered_camera_names(images) if self.use_separate_rgb_encoder_per_camera: # 独立编码器模式:每个摄像头使用对应的编码器 @@ -236,7 +258,7 @@ class ResNetDiffusionBackbone(VLABackbone): for cam_idx, cam_name in enumerate(cam_names): img = images[cam_name] encoder = self.rgb_encoder[cam_idx] - features = encoder.forward_single_image(img.view(B * T, *img.shape[2:])) + features = encoder.forward_single_image(img.reshape(B * T, *img.shape[2:])) features_all.append(features) return torch.cat(features_all, dim=1).view(B, T, -1) else: @@ -244,7 +266,7 @@ class ResNetDiffusionBackbone(VLABackbone): features_all = [] for cam_name in cam_names: img = images[cam_name] - features = self.rgb_encoder.forward_single_image(img.view(B * T, *img.shape[2:])) + features = self.rgb_encoder.forward_single_image(img.reshape(B * T, *img.shape[2:])) features_all.append(features) return torch.cat(features_all, dim=1).view(B, T, -1) @@ -369,4 +391,4 @@ if __name__ == "__main__": print("\n" + "=" * 60) print("🎉 All tests completed successfully!") - print("=" * 60) \ No newline at end of file + print("=" * 60) diff --git a/roboimi/vla/models/heads/transformer1d.py b/roboimi/vla/models/heads/transformer1d.py index 8d517d8..2b0752a 100644 --- a/roboimi/vla/models/heads/transformer1d.py +++ b/roboimi/vla/models/heads/transformer1d.py @@ -1,19 +1,35 @@ -""" -Transformer-based Diffusion Policy Head +"""Transformer-based diffusion head aligned with diffusion_policy's TransformerForDiffusion.""" -使用Transformer架构(Encoder-Decoder)替代UNet进行噪声预测。 -支持通过Cross-Attention注入全局条件(观测特征)。 -""" +from __future__ import annotations +import logging import math +from typing import Optional, Tuple, Union + import torch import torch.nn as nn -from typing import Optional + +logger = logging.getLogger(__name__) + + +class ModuleAttrMixin(nn.Module): + """Minimal local copy of diffusion_policy's ModuleAttrMixin for state-dict parity.""" + + def __init__(self) -> None: + super().__init__() + self._dummy_variable = nn.Parameter() + + @property + def device(self): + return next(iter(self.parameters())).device + + @property + def dtype(self): + return next(iter(self.parameters())).dtype class SinusoidalPosEmb(nn.Module): - """正弦位置编码(用于时间步嵌入)""" - def __init__(self, dim: int): + def __init__(self, dim: int) -> None: super().__init__() self.dim = dim @@ -27,35 +43,13 @@ class SinusoidalPosEmb(nn.Module): return emb -class Transformer1D(nn.Module): - """ - Transformer-based 1D Diffusion Model - - 使用Encoder-Decoder架构: - - Encoder: 处理条件(观测 + 时间步) - - Decoder: 通过Cross-Attention预测噪声 - - Args: - input_dim: 输入动作维度 - output_dim: 输出动作维度 - horizon: 预测horizon长度 - n_obs_steps: 观测步数 - cond_dim: 条件维度 - n_layer: Transformer层数 - n_head: 注意力头数 - n_emb: 嵌入维度 - p_drop_emb: Embedding dropout - p_drop_attn: Attention dropout - causal_attn: 是否使用因果注意力(自回归) - n_cond_layers: Encoder层数(0表示使用MLP) - """ - +class Transformer1D(ModuleAttrMixin): def __init__( self, input_dim: int, output_dim: int, horizon: int, - n_obs_steps: int = None, + n_obs_steps: Optional[int] = None, cond_dim: int = 0, n_layer: int = 8, n_head: int = 8, @@ -63,57 +57,42 @@ class Transformer1D(nn.Module): p_drop_emb: float = 0.1, p_drop_attn: float = 0.1, causal_attn: bool = False, + time_as_cond: bool = True, obs_as_cond: bool = False, - n_cond_layers: int = 0 - ): + n_cond_layers: int = 0, + ) -> None: super().__init__() - # 计算序列长度 if n_obs_steps is None: n_obs_steps = horizon T = horizon - T_cond = 1 # 时间步token数量 - - # 确定是否使用观测作为条件 + T_cond = 1 + if not time_as_cond: + T += 1 + T_cond -= 1 obs_as_cond = cond_dim > 0 if obs_as_cond: + assert time_as_cond T_cond += n_obs_steps - # 保存配置 - self.T = T - self.T_cond = T_cond - self.horizon = horizon - self.obs_as_cond = obs_as_cond - self.input_dim = input_dim - self.output_dim = output_dim - - # ==================== 输入嵌入 ==================== self.input_emb = nn.Linear(input_dim, n_emb) self.pos_emb = nn.Parameter(torch.zeros(1, T, n_emb)) self.drop = nn.Dropout(p_drop_emb) - # ==================== 条件编码 ==================== - # 时间步嵌入 self.time_emb = SinusoidalPosEmb(n_emb) - - # 观测条件嵌入(可选) self.cond_obs_emb = None if obs_as_cond: self.cond_obs_emb = nn.Linear(cond_dim, n_emb) - # 条件位置编码 self.cond_pos_emb = None + self.encoder = None + self.decoder = None + encoder_only = False + if T_cond > 0: self.cond_pos_emb = nn.Parameter(torch.zeros(1, T_cond, n_emb)) - - # ==================== Encoder ==================== - self.encoder = None - self.encoder_only = False - - if T_cond > 0: if n_cond_layers > 0: - # 使用Transformer Encoder encoder_layer = nn.TransformerEncoderLayer( d_model=n_emb, nhead=n_head, @@ -121,61 +100,19 @@ class Transformer1D(nn.Module): dropout=p_drop_attn, activation='gelu', batch_first=True, - norm_first=True # Pre-LN更稳定 + norm_first=True, ) self.encoder = nn.TransformerEncoder( encoder_layer=encoder_layer, - num_layers=n_cond_layers + num_layers=n_cond_layers, ) else: - # 使用简单的MLP self.encoder = nn.Sequential( nn.Linear(n_emb, 4 * n_emb), nn.Mish(), - nn.Linear(4 * n_emb, n_emb) + nn.Linear(4 * n_emb, n_emb), ) - else: - # Encoder-only模式(BERT风格) - self.encoder_only = True - encoder_layer = nn.TransformerEncoderLayer( - d_model=n_emb, - nhead=n_head, - dim_feedforward=4 * n_emb, - dropout=p_drop_attn, - activation='gelu', - batch_first=True, - norm_first=True - ) - self.encoder = nn.TransformerEncoder( - encoder_layer=encoder_layer, - num_layers=n_layer - ) - # ==================== Attention Mask ==================== - self.mask = None - self.memory_mask = None - - if causal_attn: - # 因果mask:确保只关注左侧 - sz = T - mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) - mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0)) - self.register_buffer("mask", mask) - - if obs_as_cond: - # 交叉注意力mask - S = T_cond - t, s = torch.meshgrid( - torch.arange(T), - torch.arange(S), - indexing='ij' - ) - mask = t >= (s - 1) - mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0)) - self.register_buffer('memory_mask', mask) - - # ==================== Decoder ==================== - if not self.encoder_only: decoder_layer = nn.TransformerDecoderLayer( d_model=n_emb, nhead=n_head, @@ -183,136 +120,199 @@ class Transformer1D(nn.Module): dropout=p_drop_attn, activation='gelu', batch_first=True, - norm_first=True + norm_first=True, ) self.decoder = nn.TransformerDecoder( decoder_layer=decoder_layer, - num_layers=n_layer + num_layers=n_layer, + ) + else: + encoder_only = True + encoder_layer = nn.TransformerEncoderLayer( + d_model=n_emb, + nhead=n_head, + dim_feedforward=4 * n_emb, + dropout=p_drop_attn, + activation='gelu', + batch_first=True, + norm_first=True, + ) + self.encoder = nn.TransformerEncoder( + encoder_layer=encoder_layer, + num_layers=n_layer, ) - # ==================== 输出头 ==================== + if causal_attn: + sz = T + mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) + mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0)) + self.register_buffer('mask', mask) + + if time_as_cond and obs_as_cond: + S = T_cond + t, s = torch.meshgrid(torch.arange(T), torch.arange(S), indexing='ij') + mask = t >= (s - 1) + mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0)) + self.register_buffer('memory_mask', mask) + else: + self.memory_mask = None + else: + self.mask = None + self.memory_mask = None + self.ln_f = nn.LayerNorm(n_emb) self.head = nn.Linear(n_emb, output_dim) - # ==================== 初始化 ==================== - self.apply(self._init_weights) + self.T = T + self.T_cond = T_cond + self.horizon = horizon + self.time_as_cond = time_as_cond + self.obs_as_cond = obs_as_cond + self.encoder_only = encoder_only - # 打印参数量 - total_params = sum(p.numel() for p in self.parameters()) - print(f"Transformer1D parameters: {total_params:,}") + self.apply(self._init_weights) + logger.info('number of parameters: %e', sum(p.numel() for p in self.parameters())) def _init_weights(self, module): - """初始化权重""" + ignore_types = ( + nn.Dropout, + SinusoidalPosEmb, + nn.TransformerEncoderLayer, + nn.TransformerDecoderLayer, + nn.TransformerEncoder, + nn.TransformerDecoder, + nn.ModuleList, + nn.Mish, + nn.Sequential, + ) if isinstance(module, (nn.Linear, nn.Embedding)): torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) if isinstance(module, nn.Linear) and module.bias is not None: torch.nn.init.zeros_(module.bias) elif isinstance(module, nn.MultiheadAttention): - # MultiheadAttention的权重初始化 - for name in ['in_proj_weight', 'q_proj_weight', 'k_proj_weight', 'v_proj_weight']: - weight = getattr(module, name, None) + for name in ('in_proj_weight', 'q_proj_weight', 'k_proj_weight', 'v_proj_weight'): + weight = getattr(module, name) if weight is not None: torch.nn.init.normal_(weight, mean=0.0, std=0.02) - for name in ['in_proj_bias', 'bias_k', 'bias_v']: - bias = getattr(module, name, None) + for name in ('in_proj_bias', 'bias_k', 'bias_v'): + bias = getattr(module, name) if bias is not None: torch.nn.init.zeros_(bias) elif isinstance(module, nn.LayerNorm): torch.nn.init.zeros_(module.bias) torch.nn.init.ones_(module.weight) elif isinstance(module, Transformer1D): - # 位置编码初始化 - torch.nn.init.normal_(self.pos_emb, mean=0.0, std=0.02) - if self.cond_pos_emb is not None: - torch.nn.init.normal_(self.cond_pos_emb, mean=0.0, std=0.02) + torch.nn.init.normal_(module.pos_emb, mean=0.0, std=0.02) + if module.cond_obs_emb is not None: + torch.nn.init.normal_(module.cond_pos_emb, mean=0.0, std=0.02) + elif isinstance(module, ignore_types): + pass + else: + raise RuntimeError(f'Unaccounted module {module}') + + def get_optim_groups(self, weight_decay: float = 1e-3): + decay = set() + no_decay = set() + whitelist_weight_modules = (torch.nn.Linear, torch.nn.MultiheadAttention) + blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding) + + for module_name, module in self.named_modules(): + for param_name, _ in module.named_parameters(): + full_param_name = f'{module_name}.{param_name}' if module_name else param_name + + if param_name.endswith('bias'): + no_decay.add(full_param_name) + elif param_name.startswith('bias'): + no_decay.add(full_param_name) + elif param_name.endswith('weight') and isinstance(module, whitelist_weight_modules): + decay.add(full_param_name) + elif param_name.endswith('weight') and isinstance(module, blacklist_weight_modules): + no_decay.add(full_param_name) + + no_decay.add('pos_emb') + no_decay.add('_dummy_variable') + if self.cond_pos_emb is not None: + no_decay.add('cond_pos_emb') + + param_dict = {name: param for name, param in self.named_parameters()} + inter_params = decay & no_decay + union_params = decay | no_decay + assert len(inter_params) == 0, f'parameters {inter_params} made it into both decay/no_decay sets!' + assert len(param_dict.keys() - union_params) == 0, ( + f'parameters {param_dict.keys() - union_params} were not separated into either decay/no_decay sets!' + ) + + return [ + { + 'params': [param_dict[name] for name in sorted(decay)], + 'weight_decay': weight_decay, + }, + { + 'params': [param_dict[name] for name in sorted(no_decay)], + 'weight_decay': 0.0, + }, + ] + + def configure_optimizers( + self, + learning_rate: float = 1e-4, + weight_decay: float = 1e-3, + betas: Tuple[float, float] = (0.9, 0.95), + ): + optim_groups = self.get_optim_groups(weight_decay=weight_decay) + return torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas) def forward( self, sample: torch.Tensor, - timestep: torch.Tensor, + timestep: Union[torch.Tensor, float, int], cond: Optional[torch.Tensor] = None, - **kwargs + **kwargs, ): - """ - 前向传播 - - Args: - sample: (B, T, input_dim) 输入序列(加噪动作) - timestep: (B,) 时间步 - cond: (B, T', cond_dim) 条件序列(观测特征) - - Returns: - (B, T, output_dim) 预测的噪声 - """ - # ==================== 处理时间步 ==================== timesteps = timestep if not torch.is_tensor(timesteps): timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device) elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) - - # 扩展到batch维度 timesteps = timesteps.expand(sample.shape[0]) - time_emb = self.time_emb(timesteps).unsqueeze(1) # (B, 1, n_emb) + time_emb = self.time_emb(timesteps).unsqueeze(1) - # ==================== 处理输入 ==================== - input_emb = self.input_emb(sample) # (B, T, n_emb) + input_emb = self.input_emb(sample) - # ==================== Encoder-Decoder模式 ==================== - if not self.encoder_only: - # --- Encoder: 处理条件 --- + if self.encoder_only: + token_embeddings = torch.cat([time_emb, input_emb], dim=1) + t = token_embeddings.shape[1] + position_embeddings = self.pos_emb[:, :t, :] + x = self.drop(token_embeddings + position_embeddings) + x = self.encoder(src=x, mask=self.mask) + x = x[:, 1:, :] + else: cond_embeddings = time_emb - - if self.obs_as_cond and cond is not None: - # 添加观测条件 - cond_obs_emb = self.cond_obs_emb(cond) # (B, T_cond-1, n_emb) + if self.obs_as_cond: + cond_obs_emb = self.cond_obs_emb(cond) cond_embeddings = torch.cat([cond_embeddings, cond_obs_emb], dim=1) - - # 添加位置编码 tc = cond_embeddings.shape[1] - pos_emb = self.cond_pos_emb[:, :tc, :] - x = self.drop(cond_embeddings + pos_emb) + position_embeddings = self.cond_pos_emb[:, :tc, :] + x = self.drop(cond_embeddings + position_embeddings) + memory = self.encoder(x) - # 通过encoder - memory = self.encoder(x) # (B, T_cond, n_emb) - - # --- Decoder: 预测噪声 --- - # 添加位置编码到输入 token_embeddings = input_emb t = token_embeddings.shape[1] - pos_emb = self.pos_emb[:, :t, :] - x = self.drop(token_embeddings + pos_emb) - - # Cross-Attention: Query来自输入,Key/Value来自memory + position_embeddings = self.pos_emb[:, :t, :] + x = self.drop(token_embeddings + position_embeddings) x = self.decoder( tgt=x, memory=memory, tgt_mask=self.mask, - memory_mask=self.memory_mask + memory_mask=self.memory_mask, ) - # ==================== Encoder-Only模式 ==================== - else: - # BERT风格:时间步作为特殊token - token_embeddings = torch.cat([time_emb, input_emb], dim=1) - t = token_embeddings.shape[1] - pos_emb = self.pos_emb[:, :t, :] - x = self.drop(token_embeddings + pos_emb) - - x = self.encoder(src=x, mask=self.mask) - x = x[:, 1:, :] # 移除时间步token - - # ==================== 输出头 ==================== x = self.ln_f(x) - x = self.head(x) # (B, T, output_dim) - + x = self.head(x) return x -# ============================================================================ -# 便捷函数:创建Transformer1D模型 -# ============================================================================ def create_transformer1d( input_dim: int, output_dim: int, @@ -322,26 +322,9 @@ def create_transformer1d( n_layer: int = 8, n_head: int = 8, n_emb: int = 256, - **kwargs + **kwargs, ) -> Transformer1D: - """ - 创建Transformer1D模型的便捷函数 - - Args: - input_dim: 输入动作维度 - output_dim: 输出动作维度 - horizon: 预测horizon - n_obs_steps: 观测步数 - cond_dim: 条件维度 - n_layer: Transformer层数 - n_head: 注意力头数 - n_emb: 嵌入维度 - **kwargs: 其他参数 - - Returns: - Transformer1D模型 - """ - model = Transformer1D( + return Transformer1D( input_dim=input_dim, output_dim=output_dim, horizon=horizon, @@ -350,47 +333,5 @@ def create_transformer1d( n_layer=n_layer, n_head=n_head, n_emb=n_emb, - **kwargs + **kwargs, ) - return model - - -if __name__ == "__main__": - print("=" * 80) - print("Testing Transformer1D") - print("=" * 80) - - # 配置 - B = 4 - T = 16 - action_dim = 16 - obs_horizon = 2 - cond_dim = 416 # vision + state特征维度 - - # 创建模型 - model = Transformer1D( - input_dim=action_dim, - output_dim=action_dim, - horizon=T, - n_obs_steps=obs_horizon, - cond_dim=cond_dim, - n_layer=4, - n_head=8, - n_emb=256, - causal_attn=False - ) - - # 测试前向传播 - sample = torch.randn(B, T, action_dim) - timestep = torch.randint(0, 100, (B,)) - cond = torch.randn(B, obs_horizon, cond_dim) - - output = model(sample, timestep, cond) - - print(f"\n输入:") - print(f" sample: {sample.shape}") - print(f" timestep: {timestep.shape}") - print(f" cond: {cond.shape}") - print(f"\n输出:") - print(f" output: {output.shape}") - print(f"\n✅ 测试通过!") diff --git a/roboimi/vla/scripts/calculate_stats.py b/roboimi/vla/scripts/calculate_stats.py index 5fece0e..072f4bf 100644 --- a/roboimi/vla/scripts/calculate_stats.py +++ b/roboimi/vla/scripts/calculate_stats.py @@ -1,8 +1,16 @@ +import argparse +import glob +import os +import pickle +from pathlib import Path + import h5py import numpy as np -import os -import glob -import pickle + + +DEFAULT_DATASET_DIR = str( + Path(__file__).resolve().parents[2] / "demos" / "dataset" / "sim_transfer" +) def get_data_stats(dataset_dir): """ @@ -23,6 +31,11 @@ def get_data_stats(dataset_dir): files = sorted(glob.glob(os.path.join(dataset_dir, 'episode_*.hdf5'))) print(f"Found {len(files)} episodes in {dataset_dir}") + if not files: + raise ValueError( + f"No episode_*.hdf5 files found in dataset_dir: {dataset_dir}" + ) + all_actions = [] all_qpos = [] @@ -70,18 +83,32 @@ def get_data_stats(dataset_dir): } return stats_flat -if __name__ == "__main__": - DATASET_DIR = 'roboimi/demos/dataset/sim_transfer' - OUTPUT_PATH = DATASET_DIR + "/dataset_stats.pkl" - stats_flat = get_data_stats(DATASET_DIR) +def write_dataset_stats(dataset_dir): + output_path = os.path.join(dataset_dir, "dataset_stats.pkl") + stats_flat = get_data_stats(dataset_dir) # 打印检查 print("\n--- Stats Computed ---") print(f"Action Mean shape: {stats_flat['action_mean'].shape}") print(f"Action Std shape: {stats_flat['action_std'].shape}") - # 保存 - with open(OUTPUT_PATH, 'wb') as f: + with open(output_path, 'wb') as f: pickle.dump(stats_flat, f) - print(f"\nStats saved to {OUTPUT_PATH}") \ No newline at end of file + print(f"\nStats saved to {output_path}") + + return output_path + + +def main(argv=None): + parser = argparse.ArgumentParser(description="Calculate dataset statistics.") + parser.add_argument( + "--dataset_dir", + default=DEFAULT_DATASET_DIR, + help="Directory containing episode_*.hdf5 files.", + ) + args = parser.parse_args(argv) + write_dataset_stats(args.dataset_dir) + +if __name__ == "__main__": + main() diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/test_calculate_stats_cli.py b/tests/test_calculate_stats_cli.py new file mode 100644 index 0000000..a298422 --- /dev/null +++ b/tests/test_calculate_stats_cli.py @@ -0,0 +1,88 @@ +import pickle +import tempfile +import unittest +from pathlib import Path + +import h5py +import numpy as np + +from roboimi.vla.scripts import calculate_stats + + +class CalculateStatsCliTest(unittest.TestCase): + def test_default_dataset_dir_is_absolute_and_package_relative(self): + expected = ( + Path(calculate_stats.__file__).resolve().parents[2] + / "demos" + / "dataset" + / "sim_transfer" + ) + + self.assertEqual(Path(calculate_stats.DEFAULT_DATASET_DIR), expected) + self.assertTrue(Path(calculate_stats.DEFAULT_DATASET_DIR).is_absolute()) + + def test_main_writes_dataset_stats_pkl_to_dataset_dir(self): + with tempfile.TemporaryDirectory() as tmpdir: + dataset_dir = Path(tmpdir) + episode_path = dataset_dir / "episode_0.hdf5" + + with h5py.File(episode_path, "w") as root: + root.create_dataset( + "action", + data=np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32), + ) + observations = root.create_group("observations") + observations.create_dataset( + "qpos", + data=np.array([[5.0, 6.0], [7.0, 8.0]], dtype=np.float32), + ) + + calculate_stats.main(["--dataset_dir", str(dataset_dir)]) + + stats_path = dataset_dir / "dataset_stats.pkl" + self.assertTrue(stats_path.exists()) + + with stats_path.open("rb") as f: + stats = pickle.load(f) + + self.assertEqual( + set(stats), + { + "action_mean", + "action_std", + "action_min", + "action_max", + "qpos_mean", + "qpos_std", + "qpos_min", + "qpos_max", + }, + ) + np.testing.assert_allclose(stats["action_mean"], np.array([2.0, 3.0])) + np.testing.assert_allclose(stats["qpos_mean"], np.array([6.0, 7.0])) + + def test_main_raises_clear_error_for_empty_dataset_dir(self): + with tempfile.TemporaryDirectory() as tmpdir: + dataset_dir = Path(tmpdir) + + with self.assertRaisesRegex( + ValueError, r"No episode_\*\.hdf5 files found" + ) as ctx: + calculate_stats.main(["--dataset_dir", str(dataset_dir)]) + + self.assertIn(str(dataset_dir), str(ctx.exception)) + + def test_main_raises_clear_error_for_missing_dataset_dir(self): + with tempfile.TemporaryDirectory() as tmpdir: + dataset_dir = Path(tmpdir) / "missing" + + with self.assertRaisesRegex( + ValueError, r"No episode_\*\.hdf5 files found" + ) as ctx: + calculate_stats.main(["--dataset_dir", str(dataset_dir)]) + + self.assertIn(str(dataset_dir), str(ctx.exception)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_eval_vla_execution.py b/tests/test_eval_vla_execution.py new file mode 100644 index 0000000..6a468ac --- /dev/null +++ b/tests/test_eval_vla_execution.py @@ -0,0 +1,28 @@ +import unittest + +from roboimi.vla.eval_utils import execute_policy_action + + +class _FakeEnv: + def __init__(self): + self.calls = [] + + def step(self, action): + self.calls.append(("step", action)) + + def step_jnt(self, action): + self.calls.append(("step_jnt", action)) + + +class EvalVLAExecutionTest(unittest.TestCase): + def test_execute_policy_action_uses_ee_step(self): + env = _FakeEnv() + action = [1, 2, 3] + + execute_policy_action(env, action) + + self.assertEqual(env.calls, [("step", action)]) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_eval_vla_headless.py b/tests/test_eval_vla_headless.py new file mode 100644 index 0000000..e6f4abb --- /dev/null +++ b/tests/test_eval_vla_headless.py @@ -0,0 +1,259 @@ +import unittest +from pathlib import Path +from unittest import mock + +import numpy as np +import torch +from omegaconf import OmegaConf + +from roboimi.demos.vla_scripts import eval_vla +from roboimi.envs.double_base import DualDianaMed +from roboimi.envs.double_pos_ctrl_env import make_sim_env + + +class _FakeAgent: + def __init__(self): + self.reset_calls = 0 + self.last_observation = None + + def eval(self): + return self + + def to(self, _device): + return self + + def reset(self): + self.reset_calls += 1 + + def select_action(self, observation): + self.last_observation = observation + return torch.zeros(16) + + +class _FakeEnv: + def __init__(self): + self.image_obs_calls = 0 + self.render_calls = 0 + self.reset_calls = [] + + def reset(self, box_pos): + self.reset_calls.append(np.array(box_pos)) + + def _get_image_obs(self): + self.image_obs_calls += 1 + return { + "images": { + "front": np.zeros((8, 8, 3), dtype=np.uint8), + } + } + + def _get_qpos_obs(self): + return {"qpos": np.zeros(16, dtype=np.float32)} + + def render(self): + self.render_calls += 1 + raise AssertionError("env.render() should be skipped when eval.headless=true") + + +class _RewardTrackingEnv(_FakeEnv): + def __init__(self, reward_sequences): + super().__init__() + self.reward_sequences = reward_sequences + self.episode_index = -1 + self.step_index = 0 + self.rew = 0.0 + + def reset(self, box_pos): + super().reset(box_pos) + self.episode_index += 1 + self.step_index = 0 + + +class _FakeRenderer: + def __init__(self, env): + self._env = env + self._frames = [ + np.full((4, 4, 3), fill_value=index, dtype=np.uint8) + for index in range(5) + ] + self._index = 0 + + def update_scene(self, _mj_data, camera=None): + self._camera = camera + + def render(self): + frame = self._frames[self._index] + self._index += 1 + if self._index >= len(self._frames): + self._env.exit_flag = True + return frame + + +class EvalVLAHeadlessTest(unittest.TestCase): + def test_eval_config_exposes_headless_default(self): + eval_cfg = OmegaConf.load(Path("roboimi/vla/conf/eval/eval.yaml")) + + self.assertIn("headless", eval_cfg) + self.assertFalse(eval_cfg.headless) + + def test_make_sim_env_accepts_headless_and_disables_render(self): + fake_env = object() + + with mock.patch( + "roboimi.assets.robots.diana_med.BiDianaMed", + return_value="robot", + ), mock.patch( + "roboimi.envs.double_pos_ctrl_env.DualDianaMed_Pos_Ctrl", + return_value=fake_env, + ) as env_cls: + env = make_sim_env("sim_transfer", headless=True) + + self.assertIs(env, fake_env) + env_cls.assert_called_once_with( + robot="robot", + is_render=False, + control_freq=30, + is_interpolate=True, + cam_view="angle", + ) + + def test_camera_viewer_headless_updates_images_without_gui_calls(self): + env = DualDianaMed.__new__(DualDianaMed) + env.mj_model = object() + env.mj_data = object() + env.exit_flag = False + env.is_render = False + env.cam = "angle" + env.r_vis = None + env.l_vis = None + env.top = None + env.angle = None + env.front = None + + with mock.patch( + "roboimi.envs.double_base.mj.Renderer", + side_effect=lambda *args, **kwargs: _FakeRenderer(env), + ), mock.patch("roboimi.envs.double_base.cv2.namedWindow") as named_window, mock.patch( + "roboimi.envs.double_base.cv2.imshow" + ) as imshow, mock.patch("roboimi.envs.double_base.cv2.waitKey") as wait_key: + env.camera_viewer() + + named_window.assert_not_called() + imshow.assert_not_called() + wait_key.assert_not_called() + self.assertIsNotNone(env.r_vis) + self.assertIsNotNone(env.l_vis) + self.assertIsNotNone(env.top) + self.assertIsNotNone(env.angle) + self.assertIsNotNone(env.front) + + def test_eval_main_headless_skips_render_and_still_executes_policy(self): + fake_env = _FakeEnv() + fake_agent = _FakeAgent() + cfg = OmegaConf.create( + { + "agent": {}, + "eval": { + "ckpt_path": "checkpoints/vla_model_best.pt", + "num_episodes": 1, + "max_timesteps": 1, + "device": "cpu", + "task_name": "sim_transfer", + "camera_names": ["front"], + "use_smoothing": False, + "smooth_alpha": 0.3, + "verbose_action": False, + "headless": True, + }, + } + ) + + with mock.patch.object( + eval_vla, + "load_checkpoint", + return_value=(fake_agent, None), + ), mock.patch.object( + eval_vla, + "make_sim_env", + return_value=fake_env, + ) as make_env, mock.patch.object( + eval_vla, + "sample_transfer_pose", + return_value=np.array([0.1, 0.2, 0.3]), + ), mock.patch.object( + eval_vla, + "execute_policy_action", + ) as execute_policy_action, mock.patch.object( + eval_vla, + "tqdm", + side_effect=lambda iterable, **kwargs: iterable, + ): + eval_vla.main.__wrapped__(cfg) + + make_env.assert_called_once_with("sim_transfer", headless=True) + execute_policy_action.assert_called_once() + self.assertEqual(fake_env.image_obs_calls, 1) + self.assertEqual(fake_env.render_calls, 0) + self.assertIsNotNone(fake_agent.last_observation) + self.assertIn("front", fake_agent.last_observation["images"]) + + def test_run_eval_returns_average_reward_summary(self): + reward_sequences = [ + [1.0, 2.0], + [0.5, 4.0], + ] + fake_env = _RewardTrackingEnv(reward_sequences) + fake_agent = _FakeAgent() + cfg = OmegaConf.create( + { + "agent": {}, + "eval": { + "ckpt_path": "checkpoints/vla_model_best.pt", + "num_episodes": 2, + "max_timesteps": 2, + "device": "cpu", + "task_name": "sim_transfer", + "camera_names": ["front"], + "use_smoothing": False, + "smooth_alpha": 0.3, + "verbose_action": False, + "headless": True, + }, + } + ) + + def fake_execute_policy_action(env, action): + del action + env.rew = env.reward_sequences[env.episode_index][env.step_index] + env.step_index += 1 + + with mock.patch.object( + eval_vla, + "load_checkpoint", + return_value=(fake_agent, None), + ), mock.patch.object( + eval_vla, + "make_sim_env", + return_value=fake_env, + ), mock.patch.object( + eval_vla, + "sample_transfer_pose", + return_value=np.array([0.1, 0.2, 0.3]), + ), mock.patch.object( + eval_vla, + "execute_policy_action", + side_effect=fake_execute_policy_action, + ), mock.patch.object( + eval_vla, + "tqdm", + side_effect=lambda iterable, **kwargs: iterable, + ): + summary = eval_vla._run_eval(cfg) + + self.assertEqual(summary["episode_rewards"], [3.0, 4.5]) + self.assertAlmostEqual(summary["avg_reward"], 3.75) + self.assertEqual(summary["num_episodes"], 2) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_resnet_transformer_agent_wiring.py b/tests/test_resnet_transformer_agent_wiring.py new file mode 100644 index 0000000..cdd862e --- /dev/null +++ b/tests/test_resnet_transformer_agent_wiring.py @@ -0,0 +1,387 @@ +import contextlib +import sys +import types +import unittest +from pathlib import Path + +import torch +from hydra import compose, initialize_config_dir +from hydra.errors import InstantiationException +from hydra.core.global_hydra import GlobalHydra +from hydra.utils import instantiate +from omegaconf import OmegaConf + + +_REPO_ROOT = Path(__file__).resolve().parents[1] +_CONFIG_DIR = str((_REPO_ROOT / 'roboimi/vla/conf').resolve()) +_EXPECTED_CAMERA_NAMES = ['r_vis', 'top', 'front'] +_MISSING = object() + + +class _FakeScheduler: + def __init__(self, num_train_timesteps=100, **kwargs): + self.config = types.SimpleNamespace(num_train_timesteps=num_train_timesteps) + self.timesteps = [] + + def add_noise(self, sample, noise, timestep): + return sample + noise + + def set_timesteps(self, num_inference_steps): + self.timesteps = list(range(num_inference_steps - 1, -1, -1)) + + def step(self, noise_pred, timestep, sample): + return types.SimpleNamespace(prev_sample=sample) + + +class _IdentityCrop: + def __init__(self, size): + self.size = size + + def __call__(self, x): + return x + + +class _FakeResNet(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv1 = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1) + self.relu1 = torch.nn.ReLU() + self.conv2 = torch.nn.Conv2d(8, 16, kernel_size=3, padding=1, stride=2) + self.relu2 = torch.nn.ReLU() + self.avgpool = torch.nn.AdaptiveAvgPool2d((1, 1)) + self.fc = torch.nn.Linear(16, 16) + + def forward(self, x): + x = self.relu1(self.conv1(x)) + x = self.relu2(self.conv2(x)) + x = self.avgpool(x) + x = torch.flatten(x, start_dim=1) + return self.fc(x) + + +class _FakeRearrange(torch.nn.Module): + def __init__(self, *args, **kwargs): + super().__init__() + + def forward(self, x): + return x + + +class _CondCapturingHead(torch.nn.Module): + def __init__(self): + super().__init__() + self.last_cond = None + + def forward(self, sample, timestep, cond): + self.last_cond = cond.detach().clone() + return torch.zeros_like(sample) + + +@contextlib.contextmanager +def _stub_optional_modules(): + previous_modules = {} + + def inject(name, module): + if name not in previous_modules: + previous_modules[name] = sys.modules.get(name, _MISSING) + sys.modules[name] = module + + diffusers_module = types.ModuleType('diffusers') + schedulers_module = types.ModuleType('diffusers.schedulers') + ddpm_module = types.ModuleType('diffusers.schedulers.scheduling_ddpm') + ddim_module = types.ModuleType('diffusers.schedulers.scheduling_ddim') + ddpm_module.DDPMScheduler = _FakeScheduler + ddim_module.DDIMScheduler = _FakeScheduler + diffusers_module.DDPMScheduler = _FakeScheduler + diffusers_module.DDIMScheduler = _FakeScheduler + diffusers_module.schedulers = schedulers_module + schedulers_module.scheduling_ddpm = ddpm_module + schedulers_module.scheduling_ddim = ddim_module + + torchvision_module = types.ModuleType('torchvision') + models_module = types.ModuleType('torchvision.models') + transforms_module = types.ModuleType('torchvision.transforms') + models_module.resnet18 = lambda weights=None: _FakeResNet() + transforms_module.CenterCrop = _IdentityCrop + transforms_module.RandomCrop = _IdentityCrop + torchvision_module.models = models_module + torchvision_module.transforms = transforms_module + + einops_module = types.ModuleType('einops') + einops_module.rearrange = lambda x, *args, **kwargs: x + einops_layers_module = types.ModuleType('einops.layers') + einops_layers_torch_module = types.ModuleType('einops.layers.torch') + einops_layers_torch_module.Rearrange = _FakeRearrange + einops_module.layers = einops_layers_module + einops_layers_module.torch = einops_layers_torch_module + + try: + inject('diffusers', diffusers_module) + inject('diffusers.schedulers', schedulers_module) + inject('diffusers.schedulers.scheduling_ddpm', ddpm_module) + inject('diffusers.schedulers.scheduling_ddim', ddim_module) + inject('torchvision', torchvision_module) + inject('torchvision.models', models_module) + inject('torchvision.transforms', transforms_module) + inject('einops', einops_module) + inject('einops.layers', einops_layers_module) + inject('einops.layers.torch', einops_layers_torch_module) + yield + finally: + for name, previous in reversed(list(previous_modules.items())): + if previous is _MISSING: + sys.modules.pop(name, None) + else: + sys.modules[name] = previous + + +def _compose_cfg(overrides=None): + if not OmegaConf.has_resolver('len'): + OmegaConf.register_new_resolver('len', lambda x: len(x)) + + GlobalHydra.instance().clear() + with initialize_config_dir(version_base=None, config_dir=_CONFIG_DIR): + return compose(config_name='config', overrides=list(overrides or [])) + + +def _make_images(batch_size, obs_horizon, image_shape, per_camera_fill=None): + channels, height, width = image_shape + per_camera_fill = per_camera_fill or { + 'front': 30.0, + 'top': 20.0, + 'r_vis': 10.0, + } + return { + name: torch.full( + (batch_size, obs_horizon, channels, height, width), + fill_value=fill_value, + dtype=torch.float32, + ) + for name, fill_value in per_camera_fill.items() + } + + +def _patch_backbone_for_order_tracking(backbone): + feature_dim = backbone.output_dim + + def encode_mean(image_batch): + mean_feature = image_batch.mean(dim=(1, 2, 3)).unsqueeze(-1) + return mean_feature.repeat(1, feature_dim) + + if backbone.use_separate_rgb_encoder_per_camera: + for encoder in backbone.rgb_encoder: + encoder.forward_single_image = encode_mean + else: + backbone.rgb_encoder.forward_single_image = encode_mean + + +def _extract_camera_markers(cond, feature_dim, num_cams): + camera_block = cond[0, 0, : feature_dim * num_cams].view(num_cams, feature_dim) + return camera_block[:, 0] + + +class ResNetTransformerAgentWiringTest(unittest.TestCase): + def test_hydra_wiring_uses_required_three_camera_transformer_conditioning_in_agent_order_and_ignores_extra_keys(self): + cfg = _compose_cfg( + overrides=[ + 'agent.vision_backbone.pretrained_backbone_weights=null', + 'agent.vision_backbone.input_shape=[3,16,16]', + 'agent.inference_steps=1', + 'agent.head.n_layer=1', + 'agent.head.n_cond_layers=0', + 'agent.head.n_emb=32', + 'agent.head.n_head=4', + ] + ) + + self.assertEqual(list(cfg.data.camera_names), _EXPECTED_CAMERA_NAMES) + self.assertEqual(list(cfg.eval.camera_names), _EXPECTED_CAMERA_NAMES) + self.assertEqual(list(cfg.agent.camera_names), _EXPECTED_CAMERA_NAMES) + self.assertEqual(list(cfg.agent.vision_backbone.camera_names), _EXPECTED_CAMERA_NAMES) + self.assertEqual(cfg.agent.head_type, 'transformer') + self.assertEqual(cfg.agent.num_cams, 3) + self.assertTrue(cfg.agent.head.obs_as_cond) + self.assertFalse(cfg.agent.head.causal_attn) + + with _stub_optional_modules(): + agent = instantiate(cfg.agent) + expected_cond_dim = agent.vision_encoder.output_dim * agent.num_cams + agent.obs_dim + self.assertEqual(cfg.agent.head.cond_dim, expected_cond_dim) + self.assertEqual(agent.per_step_cond_dim, expected_cond_dim) + self.assertEqual(agent.noise_pred_net.cond_obs_emb.in_features, expected_cond_dim) + + batch_size = 2 + image_shape = tuple(cfg.agent.vision_backbone.input_shape) + images = _make_images( + batch_size, + cfg.agent.obs_horizon, + image_shape, + per_camera_fill={ + 'front': 30.0, + 'top': 20.0, + 'r_vis': 10.0, + 'left_wrist': 99.0, + }, + ) + proprioception = torch.randn(batch_size, cfg.agent.obs_horizon, cfg.agent.obs_dim) + _patch_backbone_for_order_tracking(agent.vision_encoder) + capturing_head = _CondCapturingHead() + agent.noise_pred_net = capturing_head + predicted_actions = agent.predict_action(images, proprioception) + self.assertEqual( + predicted_actions.shape, + (batch_size, cfg.agent.pred_horizon, cfg.agent.action_dim), + ) + self.assertIsNotNone(capturing_head.last_cond) + self.assertEqual(capturing_head.last_cond.shape[-1], expected_cond_dim) + camera_markers = _extract_camera_markers( + capturing_head.last_cond, + agent.vision_encoder.output_dim, + agent.num_cams, + ) + self.assertTrue(torch.allclose(camera_markers, torch.tensor([10.0, 20.0, 30.0]))) + + missing_images = dict(images) + missing_images.pop('top') + with self.assertRaisesRegex(ValueError, 'missing=.*top'): + agent.predict_action(missing_images, proprioception) + + def test_agent_rejects_conflicting_explicit_backbone_camera_names(self): + cfg = _compose_cfg( + overrides=[ + 'agent.vision_backbone.pretrained_backbone_weights=null', + 'agent.vision_backbone.input_shape=[3,16,16]', + ] + ) + cfg.agent.vision_backbone.camera_names = ['front', 'top', 'r_vis'] + + with _stub_optional_modules(): + with self.assertRaisesRegex(InstantiationException, 'camera_names'): + instantiate(cfg.agent) + + def test_backbone_uses_sorted_fallback_order_when_camera_names_unset(self): + cfg = _compose_cfg( + overrides=[ + 'agent.vision_backbone.pretrained_backbone_weights=null', + 'agent.vision_backbone.input_shape=[3,16,16]', + ] + ) + cfg.agent.vision_backbone.camera_names = None + + with _stub_optional_modules(): + backbone = instantiate(cfg.agent.vision_backbone) + _patch_backbone_for_order_tracking(backbone) + images = _make_images( + batch_size=1, + obs_horizon=cfg.agent.obs_horizon, + image_shape=tuple(cfg.agent.vision_backbone.input_shape), + per_camera_fill={ + 'top': 20.0, + 'front': 30.0, + 'r_vis': 10.0, + }, + ) + ordered_features = backbone(images) + camera_markers = _extract_camera_markers( + ordered_features, + backbone.output_dim, + len(images), + ) + self.assertTrue(torch.allclose(camera_markers, torch.tensor([30.0, 10.0, 20.0]))) + + def test_agent_queue_fallback_order_is_deterministic_when_camera_names_unset(self): + cfg = _compose_cfg( + overrides=[ + 'agent.vision_backbone.pretrained_backbone_weights=null', + 'agent.vision_backbone.input_shape=[3,16,16]', + ] + ) + cfg.agent.camera_names = None + cfg.agent.vision_backbone.camera_names = None + + with _stub_optional_modules(): + agent = instantiate(cfg.agent) + observation = { + 'qpos': torch.randn(cfg.agent.obs_dim), + 'images': { + 'top': torch.full(tuple(cfg.agent.vision_backbone.input_shape), 20.0), + 'front': torch.full(tuple(cfg.agent.vision_backbone.input_shape), 30.0), + 'r_vis': torch.full(tuple(cfg.agent.vision_backbone.input_shape), 10.0), + }, + } + agent._populate_queues(observation) + batch = agent._prepare_observation_batch() + self.assertEqual(list(batch['images'].keys()), ['front', 'r_vis', 'top']) + + def test_backbone_rejects_camera_count_mismatch_when_camera_names_unset(self): + cfg = _compose_cfg( + overrides=[ + 'agent.vision_backbone.pretrained_backbone_weights=null', + 'agent.vision_backbone.input_shape=[3,16,16]', + ] + ) + cfg.agent.vision_backbone.camera_names = None + + with _stub_optional_modules(): + backbone = instantiate(cfg.agent.vision_backbone) + images = _make_images( + batch_size=1, + obs_horizon=cfg.agent.obs_horizon, + image_shape=tuple(cfg.agent.vision_backbone.input_shape), + per_camera_fill={ + 'front': 30.0, + 'r_vis': 10.0, + }, + ) + with self.assertRaisesRegex(ValueError, 'num_cameras'): + backbone(images) + + def test_agent_rejects_camera_count_mismatch_when_camera_names_unset(self): + cfg = _compose_cfg( + overrides=[ + 'agent.vision_backbone.pretrained_backbone_weights=null', + 'agent.vision_backbone.input_shape=[3,16,16]', + 'agent.inference_steps=1', + 'agent.head.n_layer=1', + 'agent.head.n_cond_layers=0', + 'agent.head.n_emb=32', + 'agent.head.n_head=4', + ] + ) + cfg.agent.camera_names = None + cfg.agent.vision_backbone.camera_names = None + + with _stub_optional_modules(): + agent = instantiate(cfg.agent) + images = _make_images( + batch_size=1, + obs_horizon=cfg.agent.obs_horizon, + image_shape=tuple(cfg.agent.vision_backbone.input_shape), + per_camera_fill={ + 'front': 30.0, + 'r_vis': 10.0, + }, + ) + proprioception = torch.randn(1, cfg.agent.obs_horizon, cfg.agent.obs_dim) + with self.assertRaisesRegex(ValueError, 'num_cams'): + agent.predict_action(images, proprioception) + + def test_agent_rejects_num_cams_mismatch_with_backbone_when_camera_names_unset(self): + cfg = _compose_cfg( + overrides=[ + 'agent.vision_backbone.pretrained_backbone_weights=null', + 'agent.vision_backbone.input_shape=[3,16,16]', + ] + ) + cfg.agent.camera_names = None + cfg.agent.vision_backbone.camera_names = None + cfg.agent.num_cams = 2 + cfg.agent.vision_backbone.num_cameras = 3 + + with _stub_optional_modules(): + with self.assertRaisesRegex(InstantiationException, 'num_cams'): + instantiate(cfg.agent) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_robot_asset_paths.py b/tests/test_robot_asset_paths.py new file mode 100644 index 0000000..8412192 --- /dev/null +++ b/tests/test_robot_asset_paths.py @@ -0,0 +1,63 @@ +import os +import tempfile +import unittest +from pathlib import Path +from unittest import mock + +from roboimi.assets.robots.diana_med import BiDianaMed + + +class _FakeKDL: + init_calls = [] + reset_calls = [] + + def __init__(self, urdf_path): + self.__class__.init_calls.append(urdf_path) + + def resetChain(self, base, end): + self.__class__.reset_calls.append((base, end)) + + +class RobotAssetPathResolutionTest(unittest.TestCase): + def setUp(self): + _FakeKDL.init_calls = [] + _FakeKDL.reset_calls = [] + + def test_bidianamed_resolves_robot_asset_paths_independent_of_cwd(self): + repo_root = Path(__file__).resolve().parents[1] + expected_xml = repo_root / 'roboimi/assets/models/manipulators/DianaMed/bi_diana_transfer_ee.xml' + expected_urdf = repo_root / 'roboimi/assets/models/manipulators/DianaMed/DualDianaMed.urdf' + xml_calls = [] + + def fake_from_xml_path(*, filename, assets=None): + xml_calls.append((filename, assets)) + return object() + + with tempfile.TemporaryDirectory() as tempdir: + previous_cwd = os.getcwd() + try: + os.chdir(tempdir) + with mock.patch( + 'roboimi.assets.robots.arm_base.mujoco.MjModel.from_xml_path', + side_effect=fake_from_xml_path, + ), mock.patch( + 'roboimi.assets.robots.arm_base.mujoco.MjData', + return_value=object(), + ), mock.patch( + 'roboimi.assets.robots.arm_base.KDL_utils', + _FakeKDL, + ): + BiDianaMed() + finally: + os.chdir(previous_cwd) + + self.assertEqual(len(xml_calls), 1) + self.assertEqual(Path(xml_calls[0][0]), expected_xml) + self.assertTrue(Path(xml_calls[0][0]).is_absolute()) + self.assertGreaterEqual(len(_FakeKDL.init_calls), 2) + self.assertEqual({Path(path) for path in _FakeKDL.init_calls}, {expected_urdf}) + self.assertTrue(all(Path(path).is_absolute() for path in _FakeKDL.init_calls)) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_simple_robot_dataset_image_loading.py b/tests/test_simple_robot_dataset_image_loading.py new file mode 100644 index 0000000..04c2f3e --- /dev/null +++ b/tests/test_simple_robot_dataset_image_loading.py @@ -0,0 +1,58 @@ +import sys +import tempfile +import types +import unittest +from pathlib import Path +from unittest import mock + +import h5py +import numpy as np + +from roboimi.vla.data.simpe_robot_dataset import SimpleRobotDataset + + +class SimpleRobotDatasetImageLoadingTest(unittest.TestCase): + def _write_episode(self, dataset_dir: Path) -> None: + episode_path = dataset_dir / "episode_0.hdf5" + with h5py.File(episode_path, "w") as root: + root.create_dataset("action", data=np.arange(8, dtype=np.float32).reshape(4, 2)) + root.create_dataset( + "observations/qpos", + data=np.arange(16, dtype=np.float32).reshape(4, 4), + ) + root.create_dataset("task", data=np.array([b"sim_transfer"])) + root.create_dataset( + "observations/images/front", + data=np.arange(4 * 8 * 8 * 3, dtype=np.uint8).reshape(4, 8, 8, 3), + ) + + def test_getitem_only_resizes_observation_horizon_images(self): + with tempfile.TemporaryDirectory() as tmpdir: + dataset_dir = Path(tmpdir) + self._write_episode(dataset_dir) + dataset = SimpleRobotDataset( + dataset_dir, + obs_horizon=2, + pred_horizon=3, + camera_names=["front"], + ) + + resize_calls = [] + + def fake_resize(image, size, interpolation=None): + resize_calls.append( + { + "shape": tuple(image.shape), + "size": size, + "interpolation": interpolation, + } + ) + return image + + fake_cv2 = types.SimpleNamespace(INTER_LINEAR=1, resize=fake_resize) + + with mock.patch.dict(sys.modules, {"cv2": fake_cv2}): + sample = dataset[1] + + self.assertEqual(len(resize_calls), 2) + self.assertEqual(tuple(sample["observation.front"].shape), (2, 3, 8, 8)) diff --git a/tests/test_train_vla_rollout_validation.py b/tests/test_train_vla_rollout_validation.py new file mode 100644 index 0000000..4fdc06b --- /dev/null +++ b/tests/test_train_vla_rollout_validation.py @@ -0,0 +1,779 @@ +import os +import tempfile +import unittest +from copy import deepcopy +from pathlib import Path +from unittest import mock + +import numpy as np +import torch +from omegaconf import OmegaConf +from torch import nn + +from roboimi.demos.vla_scripts import eval_vla, train_vla + + +class _FakeDataset: + def __len__(self): + return 4 + + +class _FakeLoader: + def __init__(self, batch, length=1): + self._batches = [batch] * length + + def __len__(self): + return len(self._batches) + + def __iter__(self): + return iter(self._batches) + + +class _FakeOptimizer: + def __init__(self, lr=1e-3): + self.param_groups = [{'lr': lr}] + + def zero_grad(self): + return None + + def step(self): + return None + + def state_dict(self): + return {} + + def load_state_dict(self, state_dict): + del state_dict + return None + + +class _FakeScheduler: + def __init__(self): + self.step_calls = 0 + + def step(self): + self.step_calls += 1 + + def state_dict(self): + return {} + + def load_state_dict(self, state_dict): + del state_dict + return None + + +class _FakeProgressBar: + def __init__(self, iterable): + self._items = list(iterable) + self.postfix_calls = [] + + def __iter__(self): + return iter(self._items) + + def set_postfix(self, values): + self.postfix_calls.append(values) + + +class _FakeAgent(nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.tensor(0.0)) + + def to(self, device): + del device + return self + + def compute_loss(self, agent_input): + del agent_input + return (self.weight - torch.tensor(0.5)).pow(2) + + def get_normalization_stats(self): + return {} + + +class _SequentialLossAgent(nn.Module): + def __init__(self, losses): + super().__init__() + self.weight = nn.Parameter(torch.tensor(0.0)) + self._losses = list(losses) + self._index = 0 + + def to(self, device): + del device + return self + + def compute_loss(self, agent_input): + del agent_input + loss_value = self._losses[self._index] + self._index += 1 + return (self.weight * 0) + torch.tensor(float(loss_value)) + + def get_normalization_stats(self): + return {} + + +class _FakeEvalAgent: + def __init__(self): + self.reset_calls = 0 + + def eval(self): + return self + + def to(self, device): + del device + return self + + def reset(self): + self.reset_calls += 1 + + def select_action(self, observation): + del observation + return torch.zeros(2) + + +class _FakeEvalEnv: + def reset(self, box_pos): + self.box_pos = box_pos + + def _get_image_obs(self): + return { + 'images': { + 'front': np.zeros((8, 8, 3), dtype=np.uint8), + } + } + + def _get_qpos_obs(self): + return {'qpos': np.zeros(4, dtype=np.float32)} + + def render(self): + raise AssertionError('render should not be called in this helper delegation test') + + +class TrainVLARolloutValidationTest(unittest.TestCase): + def test_default_train_config_uses_full_dataset_and_epoch_rollout_validation(self): + cfg = OmegaConf.load(Path('roboimi/vla/conf/config.yaml')) + + self.assertEqual(cfg.train.val_split, 0.0) + self.assertGreater(cfg.train.batch_size, 8) + self.assertGreater(float(cfg.train.lr), 5e-5) + self.assertGreater(cfg.train.num_workers, 8) + self.assertEqual(cfg.train.rollout_val_freq_epochs, 50) + + def test_eval_main_delegates_to_plain_run_eval_helper(self): + cfg = OmegaConf.create( + { + 'agent': {}, + 'eval': { + 'ckpt_path': 'checkpoints/vla_model_step_1.pt', + 'num_episodes': 1, + 'max_timesteps': 1, + 'device': 'cpu', + 'task_name': 'sim_transfer', + 'camera_names': ['front'], + 'use_smoothing': False, + 'smooth_alpha': 0.3, + 'verbose_action': False, + 'headless': True, + }, + } + ) + run_eval_mock = mock.Mock() + + with mock.patch.object(eval_vla, '_run_eval', run_eval_mock, create=True), \ + mock.patch.object(eval_vla, 'load_checkpoint', return_value=(_FakeEvalAgent(), None)), \ + mock.patch.object(eval_vla, 'make_sim_env', return_value=_FakeEvalEnv()), \ + mock.patch.object(eval_vla, 'sample_transfer_pose', return_value=np.zeros(3)), \ + mock.patch.object(eval_vla, 'execute_policy_action'), \ + mock.patch.object(eval_vla, 'tqdm', side_effect=lambda iterable, **kwargs: iterable): + eval_vla.main.__wrapped__(cfg) + + run_eval_mock.assert_called_once_with(cfg) + + def test_run_training_rollout_validation_runs_every_50_epochs_and_uses_avg_reward_metric(self): + cfg = OmegaConf.create( + { + 'train': { + 'device': 'cpu', + 'batch_size': 1, + 'num_workers': 0, + 'val_split': 0.0, + 'seed': 0, + 'lr': 1e-3, + 'max_steps': 100, + 'log_freq': 1, + 'save_freq': 1000, + 'warmup_steps': 1, + 'scheduler_type': 'constant', + 'min_lr': 0.0, + 'grad_clip': 1.0, + 'weight_decay': 0.0, + 'pretrained_ckpt': None, + 'resume_ckpt': None, + 'use_swanlab': False, + 'rollout_val_freq_epochs': 50, + 'rollout_num_episodes': 3, + }, + 'data': { + 'camera_names': ['front'], + }, + 'agent': { + '_target_': 'fake.agent', + }, + 'eval': { + 'ckpt_path': 'unused.pt', + 'num_episodes': 99, + 'max_timesteps': 1, + 'device': 'cpu', + 'task_name': 'sim_transfer', + 'camera_names': ['front'], + 'use_smoothing': False, + 'smooth_alpha': 0.3, + 'verbose_action': False, + 'headless': False, + }, + } + ) + agent = _FakeAgent() + rollout_mock = mock.Mock(side_effect=[{'avg_reward': 2.0}, {'avg_reward': 1.0}]) + swanlab_log_mock = mock.Mock() + saved_checkpoints = [] + + def fake_instantiate(config_node, **_kwargs): + if config_node is cfg.data: + return _FakeDataset() + if config_node is cfg.agent: + return agent + raise AssertionError(f'unexpected instantiate config: {config_node!r}') + + def fake_dataloader(_dataset, *, shuffle, **_kwargs): + del shuffle, _kwargs + return _FakeLoader( + { + 'observation.front': torch.zeros(1, 3, 2, 2), + 'observation.state': torch.zeros(1, 4), + 'action': torch.zeros(1, 2), + 'action_is_pad': torch.zeros(1, 1, dtype=torch.bool), + }, + length=1, + ) + + def fake_torch_save(payload, path): + saved_checkpoints.append((str(path), deepcopy(payload))) + return None + + with tempfile.TemporaryDirectory() as tempdir: + previous_cwd = os.getcwd() + try: + os.chdir(tempdir) + with mock.patch.object(train_vla, 'instantiate', side_effect=fake_instantiate), \ + mock.patch.object(train_vla, 'DataLoader', side_effect=fake_dataloader), \ + mock.patch.object(train_vla, 'build_training_optimizer', return_value=_FakeOptimizer(cfg.train.lr)), \ + mock.patch.object(train_vla, 'get_lr_schedule_with_warmup', return_value=_FakeScheduler()), \ + mock.patch.object(train_vla, 'tqdm', side_effect=lambda iterable, **kwargs: _FakeProgressBar(iterable)), \ + mock.patch.object(train_vla, '_log_to_swanlab', swanlab_log_mock), \ + mock.patch.object(train_vla.torch, 'save', side_effect=fake_torch_save), \ + mock.patch.object(eval_vla, '_run_eval', rollout_mock, create=True), \ + mock.patch.object(eval_vla.main, '__wrapped__', side_effect=AssertionError('training hook should call eval_vla._run_eval')): + train_vla._run_training(cfg) + finally: + os.chdir(previous_cwd) + + self.assertEqual(rollout_mock.call_count, 2) + first_rollout_cfg = rollout_mock.call_args_list[0].args[0] + second_rollout_cfg = rollout_mock.call_args_list[1].args[0] + self.assertEqual(first_rollout_cfg.eval.ckpt_path, 'checkpoints/vla_model_step_49.pt') + self.assertEqual(second_rollout_cfg.eval.ckpt_path, 'checkpoints/vla_model_step_99.pt') + self.assertEqual(first_rollout_cfg.eval.num_episodes, 3) + self.assertTrue(first_rollout_cfg.eval.headless) + self.assertEqual(first_rollout_cfg.eval.device, 'cpu') + self.assertFalse(first_rollout_cfg.eval.verbose_action) + self.assertEqual(cfg.eval.ckpt_path, 'unused.pt') + self.assertEqual(cfg.eval.num_episodes, 99) + self.assertFalse(cfg.eval.headless) + self.assertEqual(cfg.eval.device, 'cpu') + self.assertFalse(cfg.eval.verbose_action) + + rollout_reward_logs = [ + call.args[1]['rollout/avg_reward'] + for call in swanlab_log_mock.call_args_list + if len(call.args) >= 2 and 'rollout/avg_reward' in call.args[1] + ] + self.assertEqual(rollout_reward_logs, [2.0, 1.0]) + + best_model_saves = [ + payload for path, payload in saved_checkpoints + if path.endswith('checkpoints/vla_model_best.pt') + ] + self.assertEqual(len(best_model_saves), 1) + self.assertEqual(best_model_saves[0]['rollout_avg_reward'], 2.0) + + def test_run_training_keeps_loss_based_best_checkpoint_until_first_rollout_metric_exists(self): + cfg = OmegaConf.create( + { + 'train': { + 'device': 'cpu', + 'batch_size': 1, + 'num_workers': 0, + 'val_split': 0.0, + 'seed': 0, + 'lr': 1e-3, + 'max_steps': 5, + 'log_freq': 1, + 'save_freq': 2, + 'warmup_steps': 1, + 'scheduler_type': 'constant', + 'min_lr': 0.0, + 'grad_clip': 1.0, + 'weight_decay': 0.0, + 'pretrained_ckpt': None, + 'resume_ckpt': None, + 'use_swanlab': False, + 'rollout_val_freq_epochs': 50, + 'rollout_num_episodes': 3, + }, + 'data': { + 'camera_names': ['front'], + }, + 'agent': { + '_target_': 'fake.agent', + }, + 'eval': { + 'ckpt_path': 'unused.pt', + 'num_episodes': 99, + 'max_timesteps': 1, + 'device': 'cpu', + 'task_name': 'sim_transfer', + 'camera_names': ['front'], + 'use_smoothing': False, + 'smooth_alpha': 0.3, + 'verbose_action': False, + 'headless': False, + }, + } + ) + saved_checkpoints = [] + rollout_mock = mock.Mock() + + def fake_instantiate(config_node, **_kwargs): + if config_node is cfg.data: + return _FakeDataset() + if config_node is cfg.agent: + return _FakeAgent() + raise AssertionError(f'unexpected instantiate config: {config_node!r}') + + def fake_dataloader(_dataset, *, shuffle, **_kwargs): + del shuffle, _kwargs + return _FakeLoader( + { + 'observation.front': torch.zeros(1, 3, 2, 2), + 'observation.state': torch.zeros(1, 4), + 'action': torch.zeros(1, 2), + 'action_is_pad': torch.zeros(1, 1, dtype=torch.bool), + }, + length=5, + ) + + def fake_torch_save(payload, path): + saved_checkpoints.append((str(path), deepcopy(payload))) + return None + + with tempfile.TemporaryDirectory() as tempdir: + previous_cwd = os.getcwd() + try: + os.chdir(tempdir) + with mock.patch.object(train_vla, 'instantiate', side_effect=fake_instantiate), \ + mock.patch.object(train_vla, 'DataLoader', side_effect=fake_dataloader), \ + mock.patch.object(train_vla, 'build_training_optimizer', return_value=_FakeOptimizer(cfg.train.lr)), \ + mock.patch.object(train_vla, 'get_lr_schedule_with_warmup', return_value=_FakeScheduler()), \ + mock.patch.object(train_vla, 'tqdm', side_effect=lambda iterable, **kwargs: _FakeProgressBar(iterable)), \ + mock.patch.object(train_vla.torch, 'save', side_effect=fake_torch_save), \ + mock.patch.object(eval_vla, '_run_eval', rollout_mock, create=True): + train_vla._run_training(cfg) + finally: + os.chdir(previous_cwd) + + self.assertEqual(rollout_mock.call_count, 0) + best_model_saves = [ + payload for path, payload in saved_checkpoints + if path.endswith('checkpoints/vla_model_best.pt') + ] + self.assertEqual(len(best_model_saves), 1) + self.assertIsNone(best_model_saves[0]['rollout_avg_reward']) + + def test_run_training_disables_drop_last_when_train_set_is_smaller_than_batch_size(self): + cfg = OmegaConf.create( + { + 'train': { + 'device': 'cpu', + 'batch_size': 8, + 'num_workers': 0, + 'val_split': 0.0, + 'seed': 0, + 'lr': 1e-3, + 'max_steps': 1, + 'log_freq': 1, + 'save_freq': 10, + 'warmup_steps': 1, + 'scheduler_type': 'constant', + 'min_lr': 0.0, + 'grad_clip': 1.0, + 'weight_decay': 0.0, + 'pretrained_ckpt': None, + 'resume_ckpt': None, + 'use_swanlab': False, + 'rollout_val_freq_epochs': 50, + 'rollout_num_episodes': 3, + }, + 'data': { + 'camera_names': ['front'], + }, + 'agent': { + '_target_': 'fake.agent', + }, + 'eval': { + 'ckpt_path': 'unused.pt', + 'num_episodes': 99, + 'max_timesteps': 1, + 'device': 'cpu', + 'task_name': 'sim_transfer', + 'camera_names': ['front'], + 'use_smoothing': False, + 'smooth_alpha': 0.3, + 'verbose_action': False, + 'headless': False, + }, + } + ) + dataloader_calls = [] + + def fake_instantiate(config_node, **_kwargs): + if config_node is cfg.data: + return _FakeDataset() + if config_node is cfg.agent: + return _FakeAgent() + raise AssertionError(f'unexpected instantiate config: {config_node!r}') + + def fake_dataloader(dataset, *, shuffle, drop_last, **_kwargs): + dataloader_calls.append({ + 'shuffle': shuffle, + 'drop_last': drop_last, + 'dataset_len': len(dataset), + }) + return _FakeLoader( + { + 'observation.front': torch.zeros(1, 3, 2, 2), + 'observation.state': torch.zeros(1, 4), + 'action': torch.zeros(1, 2), + 'action_is_pad': torch.zeros(1, 1, dtype=torch.bool), + }, + length=1, + ) + + with tempfile.TemporaryDirectory() as tempdir: + previous_cwd = os.getcwd() + try: + os.chdir(tempdir) + with mock.patch.object(train_vla, 'instantiate', side_effect=fake_instantiate), \ + mock.patch.object(train_vla, 'DataLoader', side_effect=fake_dataloader), \ + mock.patch.object(train_vla, 'build_training_optimizer', return_value=_FakeOptimizer(cfg.train.lr)), \ + mock.patch.object(train_vla, 'get_lr_schedule_with_warmup', return_value=_FakeScheduler()), \ + mock.patch.object(train_vla, 'tqdm', side_effect=lambda iterable, **kwargs: _FakeProgressBar(iterable)), \ + mock.patch.object(train_vla.torch, 'save', return_value=None): + train_vla._run_training(cfg) + finally: + os.chdir(previous_cwd) + + train_loader_calls = [call for call in dataloader_calls if call['shuffle']] + self.assertEqual(len(train_loader_calls), 1) + self.assertFalse(train_loader_calls[0]['drop_last']) + + def test_run_training_disables_persistent_workers_for_train_and_val_loaders(self): + cfg = OmegaConf.create( + { + 'train': { + 'device': 'cpu', + 'batch_size': 2, + 'num_workers': 2, + 'val_split': 0.25, + 'seed': 0, + 'lr': 1e-3, + 'max_steps': 1, + 'log_freq': 1, + 'save_freq': 10, + 'warmup_steps': 1, + 'scheduler_type': 'constant', + 'min_lr': 0.0, + 'grad_clip': 1.0, + 'weight_decay': 0.0, + 'pretrained_ckpt': None, + 'resume_ckpt': None, + 'use_swanlab': False, + 'rollout_val_freq_epochs': 50, + 'rollout_num_episodes': 3, + }, + 'data': { + 'camera_names': ['front'], + }, + 'agent': { + '_target_': 'fake.agent', + }, + 'eval': { + 'ckpt_path': 'unused.pt', + 'num_episodes': 99, + 'max_timesteps': 1, + 'device': 'cpu', + 'task_name': 'sim_transfer', + 'camera_names': ['front'], + 'use_smoothing': False, + 'smooth_alpha': 0.3, + 'verbose_action': False, + 'headless': False, + }, + } + ) + dataloader_calls = [] + + def fake_instantiate(config_node, **_kwargs): + if config_node is cfg.data: + return _FakeDataset() + if config_node is cfg.agent: + return _FakeAgent() + raise AssertionError(f'unexpected instantiate config: {config_node!r}') + + def fake_dataloader(_dataset, *, shuffle, persistent_workers, num_workers, **_kwargs): + dataloader_calls.append({ + 'shuffle': shuffle, + 'num_workers': num_workers, + 'persistent_workers': persistent_workers, + }) + return _FakeLoader( + { + 'observation.front': torch.zeros(1, 3, 2, 2), + 'observation.state': torch.zeros(1, 4), + 'action': torch.zeros(1, 2), + 'action_is_pad': torch.zeros(1, 1, dtype=torch.bool), + }, + length=1, + ) + + with tempfile.TemporaryDirectory() as tempdir: + previous_cwd = os.getcwd() + try: + os.chdir(tempdir) + with mock.patch.object(train_vla, 'instantiate', side_effect=fake_instantiate), \ + mock.patch.object(train_vla, 'DataLoader', side_effect=fake_dataloader), \ + mock.patch.object(train_vla, 'build_training_optimizer', return_value=_FakeOptimizer(cfg.train.lr)), \ + mock.patch.object(train_vla, 'get_lr_schedule_with_warmup', return_value=_FakeScheduler()), \ + mock.patch.object(train_vla, 'tqdm', side_effect=lambda iterable, **kwargs: _FakeProgressBar(iterable)), \ + mock.patch.object(train_vla.torch, 'save', return_value=None): + train_vla._run_training(cfg) + finally: + os.chdir(previous_cwd) + + self.assertEqual(len(dataloader_calls), 2) + self.assertEqual([call['shuffle'] for call in dataloader_calls], [True, False]) + self.assertTrue(all(call['num_workers'] == 2 for call in dataloader_calls)) + self.assertTrue(all(call['persistent_workers'] is False for call in dataloader_calls)) + + def test_run_training_uses_loss_best_until_first_rollout_then_prefers_rollout_reward(self): + cfg = OmegaConf.create( + { + 'train': { + 'device': 'cpu', + 'batch_size': 1, + 'num_workers': 0, + 'val_split': 0.0, + 'seed': 0, + 'lr': 1e-3, + 'max_steps': 6, + 'log_freq': 1, + 'save_freq': 1, + 'warmup_steps': 1, + 'scheduler_type': 'constant', + 'min_lr': 0.0, + 'grad_clip': 1.0, + 'weight_decay': 0.0, + 'pretrained_ckpt': None, + 'resume_ckpt': None, + 'use_swanlab': False, + 'rollout_val_freq_epochs': 2, + 'rollout_num_episodes': 1, + }, + 'data': { + 'camera_names': ['front'], + }, + 'agent': { + '_target_': 'fake.agent', + }, + 'eval': { + 'ckpt_path': 'unused.pt', + 'num_episodes': 99, + 'max_timesteps': 1, + 'device': 'cpu', + 'task_name': 'sim_transfer', + 'camera_names': ['front'], + 'use_smoothing': False, + 'smooth_alpha': 0.3, + 'verbose_action': False, + 'headless': False, + }, + } + ) + agent = _SequentialLossAgent([10, 9, 8, 7, 6, 5]) + rollout_mock = mock.Mock(return_value={'avg_reward': 1.0}) + saved_checkpoints = [] + + def fake_instantiate(config_node, **_kwargs): + if config_node is cfg.data: + return _FakeDataset() + if config_node is cfg.agent: + return agent + raise AssertionError(f'unexpected instantiate config: {config_node!r}') + + def fake_dataloader(_dataset, *, shuffle, **_kwargs): + del _kwargs + return _FakeLoader( + { + 'observation.front': torch.zeros(1, 3, 2, 2), + 'observation.state': torch.zeros(1, 4), + 'action': torch.zeros(1, 2), + 'action_is_pad': torch.zeros(1, 1, dtype=torch.bool), + }, + length=2 if shuffle else 1, + ) + + def fake_torch_save(payload, path): + saved_checkpoints.append((str(path), deepcopy(payload))) + return None + + with tempfile.TemporaryDirectory() as tempdir: + previous_cwd = os.getcwd() + try: + os.chdir(tempdir) + with mock.patch.object(train_vla, 'instantiate', side_effect=fake_instantiate), \ + mock.patch.object(train_vla, 'DataLoader', side_effect=fake_dataloader), \ + mock.patch.object(train_vla, 'build_training_optimizer', return_value=_FakeOptimizer(cfg.train.lr)), \ + mock.patch.object(train_vla, 'get_lr_schedule_with_warmup', return_value=_FakeScheduler()), \ + mock.patch.object(train_vla, 'tqdm', side_effect=lambda iterable, **kwargs: _FakeProgressBar(iterable)), \ + mock.patch.object(train_vla.torch, 'save', side_effect=fake_torch_save), \ + mock.patch.object(eval_vla, '_run_eval', rollout_mock, create=True): + train_vla._run_training(cfg) + finally: + os.chdir(previous_cwd) + + best_model_saves = [ + (payload['step'], payload['rollout_avg_reward']) + for path, payload in saved_checkpoints + if path.endswith('checkpoints/vla_model_best.pt') + ] + self.assertEqual( + best_model_saves, + [ + (1, None), + (2, None), + (3, None), + (3, 1.0), + ], + ) + self.assertEqual(rollout_mock.call_count, 1) + + def test_run_training_keeps_tiny_train_dataset_batch_when_batch_size_is_larger(self): + cfg = OmegaConf.create( + { + 'train': { + 'device': 'cpu', + 'batch_size': 8, + 'num_workers': 0, + 'val_split': 0.0, + 'seed': 0, + 'lr': 1e-3, + 'max_steps': 1, + 'log_freq': 1, + 'save_freq': 1000, + 'warmup_steps': 1, + 'scheduler_type': 'constant', + 'min_lr': 0.0, + 'grad_clip': 1.0, + 'weight_decay': 0.0, + 'pretrained_ckpt': None, + 'resume_ckpt': None, + 'use_swanlab': False, + 'rollout_val_freq_epochs': 0, + }, + 'data': { + 'camera_names': ['front'], + }, + 'agent': { + '_target_': 'fake.agent', + }, + } + ) + agent = _FakeAgent() + dataloader_calls = [] + saved_checkpoints = [] + + class _TinyDataset: + def __len__(self): + return 1 + + def fake_instantiate(config_node, **_kwargs): + if config_node is cfg.data: + return _TinyDataset() + if config_node is cfg.agent: + return agent + raise AssertionError(f'unexpected instantiate config: {config_node!r}') + + def fake_dataloader(dataset, *, drop_last, shuffle, **_kwargs): + del _kwargs + dataloader_calls.append( + { + 'shuffle': shuffle, + 'drop_last': drop_last, + 'dataset_len': len(dataset), + } + ) + loader_length = 0 if drop_last and len(dataset) < cfg.train.batch_size else 1 + return _FakeLoader( + { + 'observation.front': torch.zeros(1, 3, 2, 2), + 'observation.state': torch.zeros(1, 4), + 'action': torch.zeros(1, 2), + 'action_is_pad': torch.zeros(1, 1, dtype=torch.bool), + }, + length=loader_length, + ) + + def fake_torch_save(payload, path): + saved_checkpoints.append((str(path), deepcopy(payload))) + return None + + with tempfile.TemporaryDirectory() as tempdir: + previous_cwd = os.getcwd() + try: + os.chdir(tempdir) + with mock.patch.object(train_vla, 'instantiate', side_effect=fake_instantiate), \ + mock.patch.object(train_vla, 'DataLoader', side_effect=fake_dataloader), \ + mock.patch.object(train_vla, 'build_training_optimizer', return_value=_FakeOptimizer(cfg.train.lr)), \ + mock.patch.object(train_vla, 'get_lr_schedule_with_warmup', return_value=_FakeScheduler()), \ + mock.patch.object(train_vla, 'tqdm', side_effect=lambda iterable, **kwargs: _FakeProgressBar(iterable)), \ + mock.patch.object(train_vla.torch, 'save', side_effect=fake_torch_save): + train_vla._run_training(cfg) + finally: + os.chdir(previous_cwd) + + self.assertEqual( + dataloader_calls[0], + { + 'shuffle': True, + 'drop_last': False, + 'dataset_len': 1, + }, + ) + self.assertEqual( + [path for path, _payload in saved_checkpoints], + ['checkpoints/vla_model_final.pt'], + ) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_train_vla_swanlab_logging.py b/tests/test_train_vla_swanlab_logging.py new file mode 100644 index 0000000..2e6e1da --- /dev/null +++ b/tests/test_train_vla_swanlab_logging.py @@ -0,0 +1,699 @@ +import importlib +import importlib.util +import os +import sys +import tempfile +import types +import unittest +from pathlib import Path +from unittest import mock + +import torch +from torch import nn + + +_REPO_ROOT = Path(__file__).resolve().parents[1] +_TRAIN_VLA_PATH = _REPO_ROOT / 'roboimi/demos/vla_scripts/train_vla.py' +_CONFIG_PATH = _REPO_ROOT / 'roboimi/vla/conf/config.yaml' + + +class AttrDict(dict): + def __getattr__(self, name): + try: + return self[name] + except KeyError as exc: + raise AttributeError(name) from exc + + def __setattr__(self, name, value): + self[name] = value + + +def _to_attrdict(value): + if isinstance(value, dict): + return AttrDict({key: _to_attrdict(item) for key, item in value.items()}) + if isinstance(value, list): + return [_to_attrdict(item) for item in value] + return value + + +class FakeDataset: + def __len__(self): + return 4 + + +class FakeLoader: + def __init__(self, batch): + self.batch = batch + + def __len__(self): + return 1 + + def __iter__(self): + return iter((self.batch,)) + + +class FakeScheduler: + def __init__(self): + self.step_calls = 0 + + def step(self): + self.step_calls += 1 + + def state_dict(self): + return {} + + def load_state_dict(self, state_dict): + return None + + +class FakeOptimizer: + def __init__(self, lr=1e-3): + self.param_groups = [{'lr': lr}] + self.loaded_state_dict = None + + def zero_grad(self): + return None + + def step(self): + return None + + def state_dict(self): + return {} + + def load_state_dict(self, state_dict): + self.loaded_state_dict = state_dict + return None + + +class FakeProgressBar: + def __init__(self, iterable): + self._items = list(iterable) + self.postfix_calls = [] + + def __iter__(self): + return iter(self._items) + + def set_postfix(self, values): + self.postfix_calls.append(values) + + +class FakeAgent(nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.tensor(0.0)) + + def to(self, device): + return self + + def compute_loss(self, agent_input): + del agent_input + target = torch.tensor(0.25 if self.training else 0.1) + return (self.weight - target).pow(2) + + def get_normalization_stats(self): + return {} + + +class FakeSwanLab: + def __init__(self, init_error=None, log_errors=None, finish_error=None): + self.init_error = init_error + self.log_errors = list(log_errors or []) + self.finish_error = finish_error + self.init_calls = [] + self.log_calls = [] + self.finish_calls = 0 + + def init(self, project, experiment_name=None, config=None): + self.init_calls.append({ + 'project': project, + 'experiment_name': experiment_name, + 'config': config, + }) + if self.init_error is not None: + raise self.init_error + return object() + + def log(self, payload, step=None): + self.log_calls.append((dict(payload), step)) + if self.log_errors: + raise self.log_errors.pop(0) + + def finish(self): + self.finish_calls += 1 + if self.finish_error is not None: + raise self.finish_error + + +class TrainVLASwanLabLoggingTest(unittest.TestCase): + def test_default_config_keeps_swanlab_opt_in(self): + config_text = _CONFIG_PATH.read_text(encoding='utf-8') + self.assertIn('use_swanlab: false', config_text) + + def _load_train_vla_module(self): + hydra_module = types.ModuleType('hydra') + hydra_utils_module = types.ModuleType('hydra.utils') + hydra_utils_module.instantiate = lambda *args, **kwargs: None + + def hydra_main(**_kwargs): + def decorator(func): + return func + return decorator + + hydra_module.main = hydra_main + hydra_module.utils = hydra_utils_module + + class OmegaConfStub: + _resolvers = {} + + @classmethod + def has_resolver(cls, name): + return name in cls._resolvers + + @classmethod + def register_new_resolver(cls, name, resolver): + cls._resolvers[name] = resolver + + @staticmethod + def to_yaml(_cfg): + return 'stub-config' + + @staticmethod + def to_container(cfg, resolve=False): + del resolve + return dict(cfg) + + @staticmethod + def create(cfg): + return _to_attrdict(cfg) + + omegaconf_module = types.ModuleType('omegaconf') + omegaconf_module.DictConfig = dict + omegaconf_module.OmegaConf = OmegaConfStub + + module_name = 'train_vla_swanlab_test_module' + spec = importlib.util.spec_from_file_location(module_name, _TRAIN_VLA_PATH) + module = importlib.util.module_from_spec(spec) + with mock.patch.dict( + sys.modules, + { + 'hydra': hydra_module, + 'hydra.utils': hydra_utils_module, + 'omegaconf': omegaconf_module, + }, + ): + assert spec.loader is not None + spec.loader.exec_module(module) + return module + + def _make_cfg(self, *, use_swanlab=True, swanlab_run_name='smoke-run'): + return AttrDict( + train=AttrDict( + device='cpu', + batch_size=2, + num_workers=0, + val_split=0.25, + seed=0, + lr=1e-3, + max_steps=2, + log_freq=1, + save_freq=1, + warmup_steps=1, + scheduler_type='constant', + min_lr=0.0, + grad_clip=1.0, + weight_decay=0.0, + pretrained_ckpt=None, + resume_ckpt=None, + use_swanlab=use_swanlab, + swanlab_project='roboimi-vla-tests', + swanlab_run_name=swanlab_run_name, + ), + data=AttrDict( + camera_names=('front',), + ), + agent=AttrDict( + _target_='fake.agent', + ), + eval=AttrDict( + ckpt_path='unused.pt', + num_episodes=1, + max_timesteps=1, + device='cpu', + task_name='sim_transfer', + camera_names=('front',), + use_smoothing=False, + smooth_alpha=0.3, + verbose_action=False, + headless=False, + ), + ) + + def _get_run_training(self, module): + run_training = getattr(module, '_run_training', None) + self.assertIsNotNone(run_training, 'Expected train_vla.py to expose a _run_training(cfg) helper') + return run_training + + def _make_batch(self): + return { + 'observation.front': torch.zeros(1, 3, 2, 2), + 'observation.state': torch.zeros(1, 4), + 'action': torch.zeros(1, 2), + 'action_is_pad': torch.zeros(1, 1, dtype=torch.bool), + } + + def _loader_factory(self): + train_batch = self._make_batch() + val_batch = self._make_batch() + + def factory(_dataset, *, shuffle, **_kwargs): + return FakeLoader(train_batch if shuffle else val_batch) + + return factory + + def test_run_training_logs_metrics_and_checkpoint_paths_to_swanlab(self): + module = self._load_train_vla_module() + run_training = self._get_run_training(module) + cfg = self._make_cfg() + agent = FakeAgent() + fake_swanlab = FakeSwanLab() + real_import_module = importlib.import_module + + def fake_instantiate(config_node, **_kwargs): + if config_node is cfg.data: + return FakeDataset() + if config_node is cfg.agent: + return agent + raise AssertionError(f'unexpected instantiate config: {config_node!r}') + + def fake_import_module(name, package=None): + if name == 'swanlab': + return fake_swanlab + return real_import_module(name, package) + + with tempfile.TemporaryDirectory() as tempdir: + previous_cwd = os.getcwd() + try: + os.chdir(tempdir) + with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \ + mock.patch.object(module, 'DataLoader', side_effect=self._loader_factory()), \ + mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=FakeScheduler()), \ + mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: FakeProgressBar(iterable)), \ + mock.patch.object(module.torch, 'save', return_value=None), \ + mock.patch.object(module.importlib, 'import_module', side_effect=fake_import_module): + run_training(cfg) + finally: + os.chdir(previous_cwd) + + self.assertEqual( + fake_swanlab.init_calls, + [{ + 'project': 'roboimi-vla-tests', + 'experiment_name': 'smoke-run', + 'config': { + 'train': { + 'device': 'cpu', + 'batch_size': 2, + 'num_workers': 0, + 'val_split': 0.25, + 'seed': 0, + 'lr': 1e-3, + 'max_steps': 2, + 'log_freq': 1, + 'save_freq': 1, + 'warmup_steps': 1, + 'scheduler_type': 'constant', + 'min_lr': 0.0, + 'grad_clip': 1.0, + 'weight_decay': 0.0, + 'pretrained_ckpt': None, + 'resume_ckpt': None, + 'use_swanlab': True, + 'swanlab_project': 'roboimi-vla-tests', + 'swanlab_run_name': 'smoke-run', + }, + 'data': { + 'camera_names': ('front',), + }, + 'agent': { + '_target_': 'fake.agent', + }, + }, + }], + ) + + logged_keys = set().union(*(payload.keys() for payload, _step in fake_swanlab.log_calls)) + self.assertTrue( + { + 'train/loss', + 'train/lr', + 'train/best_loss', + 'train/step', + 'val/loss', + 'final/checkpoint_path', + 'final/best_checkpoint_path', + }.issubset(logged_keys) + ) + + final_payload, final_step = fake_swanlab.log_calls[-1] + self.assertEqual(final_step, cfg.train.max_steps) + self.assertEqual(final_payload['final/checkpoint_path'], 'checkpoints/vla_model_final.pt') + self.assertEqual(final_payload['final/best_checkpoint_path'], 'checkpoints/vla_model_best.pt') + self.assertEqual(fake_swanlab.finish_calls, 1) + + def test_run_training_skips_swanlab_when_disabled(self): + module = self._load_train_vla_module() + run_training = self._get_run_training(module) + cfg = self._make_cfg(use_swanlab=False) + agent = FakeAgent() + + def fake_instantiate(config_node, **_kwargs): + if config_node is cfg.data: + return FakeDataset() + if config_node is cfg.agent: + return agent + raise AssertionError(f'unexpected instantiate config: {config_node!r}') + + with tempfile.TemporaryDirectory() as tempdir: + previous_cwd = os.getcwd() + try: + os.chdir(tempdir) + with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \ + mock.patch.object(module, 'DataLoader', side_effect=self._loader_factory()), \ + mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=FakeScheduler()), \ + mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: FakeProgressBar(iterable)), \ + mock.patch.object(module.torch, 'save', return_value=None), \ + mock.patch.object(module.importlib, 'import_module', side_effect=AssertionError('swanlab import should not run')): + run_training(cfg) + finally: + os.chdir(previous_cwd) + + def test_run_training_finishes_swanlab_when_exception_happens_after_init(self): + module = self._load_train_vla_module() + run_training = self._get_run_training(module) + cfg = self._make_cfg() + fake_swanlab = FakeSwanLab() + real_import_module = importlib.import_module + + def fake_import_module(name, package=None): + if name == 'swanlab': + return fake_swanlab + return real_import_module(name, package) + + with tempfile.TemporaryDirectory() as tempdir: + previous_cwd = os.getcwd() + try: + os.chdir(tempdir) + with mock.patch.object(module, 'instantiate', side_effect=RuntimeError('dataset boom')), \ + mock.patch.object(module.importlib, 'import_module', side_effect=fake_import_module): + with self.assertRaisesRegex(RuntimeError, 'dataset boom'): + run_training(cfg) + finally: + os.chdir(previous_cwd) + + self.assertEqual(fake_swanlab.finish_calls, 1) + + def test_run_training_warns_and_continues_when_swanlab_log_and_finish_fail(self): + module = self._load_train_vla_module() + run_training = self._get_run_training(module) + cfg = self._make_cfg() + agent = FakeAgent() + fake_swanlab = FakeSwanLab( + log_errors=[RuntimeError('log backend hiccup')], + finish_error=RuntimeError('finish backend hiccup'), + ) + real_import_module = importlib.import_module + + def fake_instantiate(config_node, **_kwargs): + if config_node is cfg.data: + return FakeDataset() + if config_node is cfg.agent: + return agent + raise AssertionError(f'unexpected instantiate config: {config_node!r}') + + def fake_import_module(name, package=None): + if name == 'swanlab': + return fake_swanlab + return real_import_module(name, package) + + with tempfile.TemporaryDirectory() as tempdir: + previous_cwd = os.getcwd() + try: + os.chdir(tempdir) + with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \ + mock.patch.object(module, 'DataLoader', side_effect=self._loader_factory()), \ + mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=FakeScheduler()), \ + mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: FakeProgressBar(iterable)), \ + mock.patch.object(module.torch, 'save', return_value=None), \ + mock.patch.object(module.importlib, 'import_module', side_effect=fake_import_module), \ + mock.patch.object(module.log, 'warning') as warning_mock: + run_training(cfg) + finally: + os.chdir(previous_cwd) + + warning_messages = [call.args[0] for call in warning_mock.call_args_list] + self.assertTrue(any('SwanLab log failed' in message for message in warning_messages)) + self.assertTrue(any('SwanLab finish failed' in message for message in warning_messages)) + self.assertEqual(fake_swanlab.finish_calls, 1) + + def test_run_training_resume_restores_best_rollout_baseline_from_best_checkpoint(self): + module = self._load_train_vla_module() + run_training = self._get_run_training(module) + cfg = self._make_cfg() + cfg.train.max_steps = 2 + cfg.train.save_freq = 1 + cfg.train.rollout_validate_on_checkpoint = True + fake_swanlab = FakeSwanLab() + fake_optimizer = FakeOptimizer(lr=cfg.train.lr) + fake_scheduler = FakeScheduler() + real_import_module = importlib.import_module + saved_paths = [] + + def fake_instantiate(config_node, **_kwargs): + if config_node is cfg.data: + return FakeDataset() + if config_node is cfg.agent: + return FakeAgent() + raise AssertionError(f'unexpected instantiate config: {config_node!r}') + + def fake_import_module(name, package=None): + if name == 'swanlab': + return fake_swanlab + return real_import_module(name, package) + + with tempfile.TemporaryDirectory() as tempdir: + previous_cwd = os.getcwd() + try: + os.chdir(tempdir) + checkpoint_dir = Path('checkpoints') + checkpoint_dir.mkdir() + resume_path = checkpoint_dir / 'vla_model_step_0.pt' + resume_path.write_bytes(b'resume') + best_path = checkpoint_dir / 'vla_model_best.pt' + best_path.write_bytes(b'best') + cfg.train.resume_ckpt = str(resume_path) + + resume_checkpoint_state = { + 'step': 0, + 'model_state_dict': FakeAgent().state_dict(), + 'optimizer_state_dict': {}, + 'scheduler_state_dict': {}, + 'loss': 0.5, + 'val_loss': 0.25, + } + best_checkpoint_state = { + 'step': 0, + 'model_state_dict': FakeAgent().state_dict(), + 'optimizer_state_dict': {}, + 'scheduler_state_dict': {}, + 'loss': 0.5, + 'val_loss': 0.25, + 'rollout_avg_reward': 5.0, + } + + def fake_torch_load(path, map_location=None): + del map_location + path = Path(path) + if path == resume_path: + return resume_checkpoint_state + if path == best_path: + return best_checkpoint_state + raise AssertionError(f'unexpected load path: {path}') + + def fake_torch_save(payload, path): + saved_paths.append(str(path)) + return None + + with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \ + mock.patch.object(module, 'DataLoader', side_effect=self._loader_factory()), \ + mock.patch.object(module, 'build_training_optimizer', return_value=fake_optimizer), \ + mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=fake_scheduler), \ + mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: FakeProgressBar(iterable)), \ + mock.patch.object(module.torch, 'save', side_effect=fake_torch_save), \ + mock.patch.object(module.torch, 'load', side_effect=fake_torch_load), \ + mock.patch.object(module.importlib, 'import_module', side_effect=fake_import_module), \ + mock.patch('roboimi.demos.vla_scripts.eval_vla._run_eval', return_value={'avg_reward': 3.0}): + run_training(cfg) + finally: + os.chdir(previous_cwd) + + final_payload, final_step = fake_swanlab.log_calls[-1] + self.assertEqual(final_step, cfg.train.max_steps) + self.assertEqual(final_payload['final/best_checkpoint_path'], 'checkpoints/vla_model_best.pt') + self.assertNotIn('checkpoints/vla_model_best.pt', saved_paths) + + def test_run_training_resume_ignores_best_checkpoint_without_rollout_metric(self): + module = self._load_train_vla_module() + run_training = self._get_run_training(module) + cfg = self._make_cfg() + cfg.train.max_steps = 1 + fake_swanlab = FakeSwanLab() + fake_optimizer = FakeOptimizer(lr=cfg.train.lr) + fake_scheduler = FakeScheduler() + real_import_module = importlib.import_module + + def fake_instantiate(config_node, **_kwargs): + if config_node is cfg.data: + return FakeDataset() + if config_node is cfg.agent: + return FakeAgent() + raise AssertionError(f'unexpected instantiate config: {config_node!r}') + + def fake_import_module(name, package=None): + if name == 'swanlab': + return fake_swanlab + return real_import_module(name, package) + + with tempfile.TemporaryDirectory() as tempdir: + previous_cwd = os.getcwd() + try: + os.chdir(tempdir) + checkpoint_dir = Path('checkpoints') + checkpoint_dir.mkdir() + resume_path = checkpoint_dir / 'vla_model_step_0.pt' + resume_path.write_bytes(b'resume') + best_path = checkpoint_dir / 'vla_model_best.pt' + best_path.write_bytes(b'stale') + cfg.train.resume_ckpt = str(resume_path) + + resume_checkpoint_state = { + 'step': 0, + 'model_state_dict': FakeAgent().state_dict(), + 'optimizer_state_dict': {}, + 'scheduler_state_dict': {}, + 'loss': 0.5, + 'val_loss': 0.25, + } + stale_best_checkpoint_state = { + 'step': 0, + 'model_state_dict': FakeAgent().state_dict(), + 'optimizer_state_dict': {}, + 'scheduler_state_dict': {}, + 'loss': 0.4, + 'val_loss': 0.2, + } + + def fake_torch_load(path, map_location=None): + del map_location + path = Path(path) + if path == resume_path: + return resume_checkpoint_state + if path == best_path: + return stale_best_checkpoint_state + raise AssertionError(f'unexpected load path: {path}') + + with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \ + mock.patch.object(module, 'DataLoader', side_effect=self._loader_factory()), \ + mock.patch.object(module, 'build_training_optimizer', return_value=fake_optimizer), \ + mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=fake_scheduler), \ + mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: FakeProgressBar(iterable)), \ + mock.patch.object(module.torch, 'save', return_value=None), \ + mock.patch.object(module.torch, 'load', side_effect=fake_torch_load), \ + mock.patch.object(module.importlib, 'import_module', side_effect=fake_import_module): + run_training(cfg) + finally: + os.chdir(previous_cwd) + + final_payload, final_step = fake_swanlab.log_calls[-1] + self.assertEqual(final_step, cfg.train.max_steps) + self.assertEqual(final_payload['final/best_checkpoint_path'], 'checkpoints/vla_model_step_0.pt') + + def test_run_training_ignores_stale_best_checkpoint_file_on_fresh_non_resume_run(self): + module = self._load_train_vla_module() + run_training = self._get_run_training(module) + cfg = self._make_cfg() + cfg.train.max_steps = 1 + fake_swanlab = FakeSwanLab() + real_import_module = importlib.import_module + + def fake_instantiate(config_node, **_kwargs): + if config_node is cfg.data: + return FakeDataset() + if config_node is cfg.agent: + return FakeAgent() + raise AssertionError(f'unexpected instantiate config: {config_node!r}') + + def fake_import_module(name, package=None): + if name == 'swanlab': + return fake_swanlab + return real_import_module(name, package) + + with tempfile.TemporaryDirectory() as tempdir: + previous_cwd = os.getcwd() + try: + os.chdir(tempdir) + checkpoint_dir = Path('checkpoints') + checkpoint_dir.mkdir() + (checkpoint_dir / 'vla_model_best.pt').write_bytes(b'stale-best') + + with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \ + mock.patch.object(module, 'DataLoader', side_effect=self._loader_factory()), \ + mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=FakeScheduler()), \ + mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: FakeProgressBar(iterable)), \ + mock.patch.object(module.torch, 'save', return_value=None), \ + mock.patch.object(module.importlib, 'import_module', side_effect=fake_import_module): + run_training(cfg) + finally: + os.chdir(previous_cwd) + + final_payload, final_step = fake_swanlab.log_calls[-1] + self.assertEqual(final_step, cfg.train.max_steps) + self.assertEqual(final_payload['final/best_checkpoint_path'], '') + + def test_run_training_fails_fast_when_swanlab_import_is_unavailable(self): + module = self._load_train_vla_module() + run_training = self._get_run_training(module) + cfg = self._make_cfg() + real_import_module = importlib.import_module + + def fake_import_module(name, package=None): + if name == 'swanlab': + raise ImportError('missing swanlab') + return real_import_module(name, package) + + with mock.patch.object(module, 'instantiate', side_effect=AssertionError('instantiate should not run')), \ + mock.patch.object(module.importlib, 'import_module', side_effect=fake_import_module): + with self.assertRaisesRegex(RuntimeError, 'SwanLab'): + run_training(cfg) + + def test_run_training_fails_fast_when_swanlab_init_fails(self): + module = self._load_train_vla_module() + run_training = self._get_run_training(module) + cfg = self._make_cfg() + fake_swanlab = FakeSwanLab(init_error=RuntimeError('not logged in')) + real_import_module = importlib.import_module + + def fake_import_module(name, package=None): + if name == 'swanlab': + return fake_swanlab + return real_import_module(name, package) + + with mock.patch.object(module, 'instantiate', side_effect=AssertionError('instantiate should not run')), \ + mock.patch.object(module.importlib, 'import_module', side_effect=fake_import_module): + with self.assertRaisesRegex(RuntimeError, 'not logged in'): + run_training(cfg) + + self.assertEqual(fake_swanlab.finish_calls, 0) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_train_vla_transformer_optimizer.py b/tests/test_train_vla_transformer_optimizer.py new file mode 100644 index 0000000..204014d --- /dev/null +++ b/tests/test_train_vla_transformer_optimizer.py @@ -0,0 +1,310 @@ +import importlib.util +import os +import sys +import tempfile +import types +import unittest +from pathlib import Path +from unittest import mock + +import torch +from torch import nn + + +_REPO_ROOT = Path(__file__).resolve().parents[1] +_TRAIN_VLA_PATH = _REPO_ROOT / 'roboimi/demos/vla_scripts/train_vla.py' + + +class AttrDict(dict): + def __getattr__(self, name): + try: + return self[name] + except KeyError as exc: + raise AttributeError(name) from exc + + def __setattr__(self, name, value): + self[name] = value + + +class FakeDataset: + def __len__(self): + return 4 + + +class FakeLoader: + def __len__(self): + return 1 + + def __iter__(self): + return iter(()) + + +class FakeScheduler: + def state_dict(self): + return {} + + def load_state_dict(self, state_dict): + return None + + +class RecordingAdamW: + created = [] + + def __init__(self, params, lr, weight_decay): + self.lr = lr + self.weight_decay = weight_decay + self.param_groups = self._normalize_param_groups(params, lr, weight_decay) + RecordingAdamW.created.append(self) + + @staticmethod + def _normalize_param_groups(params, lr, weight_decay): + if isinstance(params, (list, tuple)) and params and isinstance(params[0], dict): + groups = [] + for group in params: + normalized = dict(group) + normalized['params'] = list(group['params']) + normalized.setdefault('lr', lr) + groups.append(normalized) + return groups + + return [{ + 'params': list(params), + 'lr': lr, + 'weight_decay': weight_decay, + }] + + def state_dict(self): + return {} + + def load_state_dict(self, state_dict): + return None + + +class RecordingTransformerHead(nn.Module): + def __init__(self): + super().__init__() + self.proj = nn.Linear(4, 4) + self.norm = nn.LayerNorm(4) + self.optim_group_calls = [] + + def get_optim_groups(self, weight_decay): + self.optim_group_calls.append(weight_decay) + return [ + { + 'params': [self.proj.weight], + 'weight_decay': weight_decay, + }, + { + 'params': [self.proj.bias, self.norm.weight, self.norm.bias], + 'weight_decay': 0.0, + }, + ] + + +class FakeTransformerAgent(nn.Module): + def __init__(self): + super().__init__() + self.head_type = 'transformer' + self.noise_pred_net = RecordingTransformerHead() + self.backbone = nn.Linear(4, 3) + self.adapter = nn.Linear(3, 2, bias=False) + self.frozen = nn.Linear(2, 2) + for param in self.frozen.parameters(): + param.requires_grad = False + + def to(self, device): + return self + + def get_normalization_stats(self): + return {} + + +class TrainVLATransformerOptimizerTest(unittest.TestCase): + def setUp(self): + RecordingAdamW.created = [] + + def _load_train_vla_module(self): + hydra_module = types.ModuleType('hydra') + hydra_utils_module = types.ModuleType('hydra.utils') + hydra_utils_module.instantiate = lambda *args, **kwargs: None + + def hydra_main(**_kwargs): + def decorator(func): + return func + return decorator + + hydra_module.main = hydra_main + hydra_module.utils = hydra_utils_module + + class OmegaConfStub: + _resolvers = {} + + @classmethod + def has_resolver(cls, name): + return name in cls._resolvers + + @classmethod + def register_new_resolver(cls, name, resolver): + cls._resolvers[name] = resolver + + @staticmethod + def to_yaml(_cfg): + return 'stub-config' + + omegaconf_module = types.ModuleType('omegaconf') + omegaconf_module.DictConfig = dict + omegaconf_module.OmegaConf = OmegaConfStub + + module_name = 'train_vla_optimizer_test_module' + spec = importlib.util.spec_from_file_location(module_name, _TRAIN_VLA_PATH) + module = importlib.util.module_from_spec(spec) + with mock.patch.dict( + sys.modules, + { + 'hydra': hydra_module, + 'hydra.utils': hydra_utils_module, + 'omegaconf': omegaconf_module, + }, + ): + assert spec.loader is not None + spec.loader.exec_module(module) + return module + + def _make_cfg(self): + return AttrDict( + train=AttrDict( + device='cpu', + batch_size=2, + num_workers=0, + val_split=0, + seed=0, + lr=1e-4, + max_steps=0, + log_freq=1, + save_freq=100, + warmup_steps=1, + scheduler_type='constant', + min_lr=0.0, + grad_clip=1.0, + weight_decay=0.123, + pretrained_ckpt=None, + resume_ckpt=None, + ), + data=AttrDict( + camera_names=('front',), + ), + agent=AttrDict( + _target_='fake.agent', + ), + ) + + def _group_names(self, agent, optimizer): + names_by_param_id = {id(param): name for name, param in agent.named_parameters()} + return [ + {names_by_param_id[id(param)] for param in group['params']} + for group in optimizer.param_groups + ] + + def test_transformer_training_prefers_head_optim_groups_and_keeps_remaining_trainable_params(self): + module = self._load_train_vla_module() + agent = FakeTransformerAgent() + cfg = self._make_cfg() + + def fake_instantiate(config_node, **_kwargs): + if config_node is cfg.data: + return FakeDataset() + if config_node is cfg.agent: + return agent + raise AssertionError(f'unexpected instantiate config: {config_node!r}') + + with tempfile.TemporaryDirectory() as tempdir: + previous_cwd = os.getcwd() + try: + os.chdir(tempdir) + with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \ + mock.patch.object(module, 'DataLoader', side_effect=lambda *args, **kwargs: FakeLoader()), \ + mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=FakeScheduler()), \ + mock.patch.object(module, 'AdamW', RecordingAdamW), \ + mock.patch.object(module.torch, 'save', return_value=None), \ + mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: iterable): + module.main(cfg) + finally: + os.chdir(previous_cwd) + + self.assertEqual(agent.noise_pred_net.optim_group_calls, [cfg.train.weight_decay]) + + optimizer = RecordingAdamW.created[-1] + trainable_names = { + name for name, param in agent.named_parameters() if param.requires_grad + } + grouped_names = self._group_names(agent, optimizer) + optimizer_names = set().union(*grouped_names) + expected_head_names = { + 'noise_pred_net.proj.weight', + 'noise_pred_net.proj.bias', + 'noise_pred_net.norm.weight', + 'noise_pred_net.norm.bias', + } + expected_non_head_names = { + 'backbone.weight', + 'backbone.bias', + 'adapter.weight', + } + + self.assertEqual(grouped_names[0], {'noise_pred_net.proj.weight'}) + self.assertEqual(grouped_names[1], expected_head_names - {'noise_pred_net.proj.weight'}) + self.assertEqual(grouped_names[2], expected_non_head_names) + self.assertEqual(optimizer.param_groups[0]['weight_decay'], cfg.train.weight_decay) + self.assertEqual(optimizer.param_groups[1]['weight_decay'], 0.0) + self.assertEqual(optimizer.param_groups[2]['weight_decay'], cfg.train.weight_decay) + self.assertEqual(optimizer_names, trainable_names) + + flattened_param_ids = [ + id(param) + for group in optimizer.param_groups + for param in group['params'] + ] + self.assertEqual(len(flattened_param_ids), len(set(flattened_param_ids))) + self.assertNotIn('frozen.weight', optimizer_names) + self.assertNotIn('frozen.bias', optimizer_names) + + def test_transformer_optimizer_ignores_frozen_head_params_returned_by_head_groups(self): + module = self._load_train_vla_module() + agent = FakeTransformerAgent() + agent.noise_pred_net.norm.bias.requires_grad = False + cfg = self._make_cfg() + + def fake_instantiate(config_node, **_kwargs): + if config_node is cfg.data: + return FakeDataset() + if config_node is cfg.agent: + return agent + raise AssertionError(f'unexpected instantiate config: {config_node!r}') + + with tempfile.TemporaryDirectory() as tempdir: + previous_cwd = os.getcwd() + try: + os.chdir(tempdir) + with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \ + mock.patch.object(module, 'DataLoader', side_effect=lambda *args, **kwargs: FakeLoader()), \ + mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=FakeScheduler()), \ + mock.patch.object(module, 'AdamW', RecordingAdamW), \ + mock.patch.object(module.torch, 'save', return_value=None), \ + mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: iterable): + module.main(cfg) + finally: + os.chdir(previous_cwd) + + optimizer = RecordingAdamW.created[-1] + optimizer_names = set().union(*self._group_names(agent, optimizer)) + trainable_names = { + name for name, param in agent.named_parameters() if param.requires_grad + } + + self.assertEqual(agent.noise_pred_net.optim_group_calls, [cfg.train.weight_decay]) + self.assertEqual(optimizer_names, trainable_names) + self.assertNotIn('noise_pred_net.norm.bias', optimizer_names) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_transformer1d_external_alignment.py b/tests/test_transformer1d_external_alignment.py new file mode 100644 index 0000000..f3b199c --- /dev/null +++ b/tests/test_transformer1d_external_alignment.py @@ -0,0 +1,262 @@ +import contextlib +import importlib.util +import inspect +import sys +import types +import unittest +import warnings +from pathlib import Path + +import torch + + +_REPO_ROOT = Path(__file__).resolve().parents[1] +_LOCAL_MODULE_PATH = _REPO_ROOT / 'roboimi/vla/models/heads/transformer1d.py' +_EXTERNAL_CHECKOUT_ROOT = _REPO_ROOT.parent / 'diffusion_policy' +_TRANSFORMER_WARNING_MESSAGE = ( + r'enable_nested_tensor is True, but self.use_nested_tensor is False ' + r'because encoder_layer\.norm_first was True' +) +_MISSING = object() + + +def _load_module_from_path(name: str, path: Path, *, register: bool = False): + spec = importlib.util.spec_from_file_location(name, path) + module = importlib.util.module_from_spec(spec) + assert spec.loader is not None + if register: + sys.modules[name] = module + spec.loader.exec_module(module) + return module + + +def _resolve_external_module_paths(external_checkout_root: Path): + diffusion_policy_root = external_checkout_root / 'diffusion_policy' + paths = { + 'positional_embedding': diffusion_policy_root / 'model/diffusion/positional_embedding.py', + 'module_attr_mixin': diffusion_policy_root / 'model/common/module_attr_mixin.py', + 'transformer_for_diffusion': diffusion_policy_root / 'model/diffusion/transformer_for_diffusion.py', + } + if not all(path.exists() for path in paths.values()): + return None + return paths + + +@contextlib.contextmanager +def _temporary_registered_modules(): + previous_modules = {} + + def remember(name: str) -> None: + if name not in previous_modules: + previous_modules[name] = sys.modules.get(name, _MISSING) + + def ensure_package(name: str) -> None: + if not name or name in sys.modules: + return + remember(name) + package = types.ModuleType(name) + package.__path__ = [] + sys.modules[name] = package + + def load(name: str, path: Path): + package_parts = name.split('.')[:-1] + for idx in range(1, len(package_parts) + 1): + ensure_package('.'.join(package_parts[:idx])) + + remember(name) + return _load_module_from_path(name, path, register=True) + + try: + yield load + finally: + for name, previous in reversed(list(previous_modules.items())): + if previous is _MISSING: + sys.modules.pop(name, None) + else: + sys.modules[name] = previous + + +@contextlib.contextmanager +def _suppress_nested_tensor_warning(): + with warnings.catch_warnings(): + warnings.filterwarnings( + 'ignore', + message=_TRANSFORMER_WARNING_MESSAGE, + category=UserWarning, + module=r'torch\.nn\.modules\.transformer', + ) + yield + + +def _load_local_module(): + return _load_module_from_path('local_transformer1d_alignment', _LOCAL_MODULE_PATH) + + +class Transformer1DExternalAlignmentTest(unittest.TestCase): + def _load_transformer_classes_or_skip(self): + external_paths = _resolve_external_module_paths(_EXTERNAL_CHECKOUT_ROOT) + if external_paths is None: + self.skipTest(f'external diffusion_policy checkout unavailable under {_EXTERNAL_CHECKOUT_ROOT}') + + local_module = _load_local_module() + with _temporary_registered_modules() as load_external: + load_external( + 'diffusion_policy.model.diffusion.positional_embedding', + external_paths['positional_embedding'], + ) + load_external( + 'diffusion_policy.model.common.module_attr_mixin', + external_paths['module_attr_mixin'], + ) + external_module = load_external( + 'diffusion_policy.model.diffusion.transformer_for_diffusion', + external_paths['transformer_for_diffusion'], + ) + + return local_module.Transformer1D, local_module.create_transformer1d, external_module.TransformerForDiffusion + + def _optim_group_names(self, model, groups): + names_by_param = {id(param): name for name, param in model.named_parameters()} + return [ + {names_by_param[id(param)] for param in group['params']} + for group in groups + ] + + def test_missing_external_checkout_resolution_returns_none(self): + self.assertIsNone(_resolve_external_module_paths(_REPO_ROOT / '__missing_diffusion_policy_checkout__')) + + def test_external_loader_restores_injected_sys_modules(self): + external_paths = _resolve_external_module_paths(_EXTERNAL_CHECKOUT_ROOT) + if external_paths is None: + self.skipTest(f'external diffusion_policy checkout unavailable under {_EXTERNAL_CHECKOUT_ROOT}') + + watched_names = [ + 'diffusion_policy', + 'diffusion_policy.model', + 'diffusion_policy.model.common', + 'diffusion_policy.model.common.module_attr_mixin', + 'diffusion_policy.model.diffusion', + 'diffusion_policy.model.diffusion.positional_embedding', + 'diffusion_policy.model.diffusion.transformer_for_diffusion', + ] + before = {name: sys.modules.get(name, _MISSING) for name in watched_names} + + with _temporary_registered_modules() as load_external: + load_external( + 'diffusion_policy.model.diffusion.positional_embedding', + external_paths['positional_embedding'], + ) + load_external( + 'diffusion_policy.model.common.module_attr_mixin', + external_paths['module_attr_mixin'], + ) + load_external( + 'diffusion_policy.model.diffusion.transformer_for_diffusion', + external_paths['transformer_for_diffusion'], + ) + + after = {name: sys.modules.get(name, _MISSING) for name in watched_names} + self.assertEqual(after, before) + + def test_transformer1d_preserves_local_direct_call_defaults(self): + local_module = _load_local_module() + ctor = inspect.signature(local_module.Transformer1D.__init__).parameters + helper = inspect.signature(local_module.create_transformer1d).parameters + + self.assertEqual(ctor['n_layer'].default, 8) + self.assertEqual(ctor['n_head'].default, 8) + self.assertEqual(ctor['n_emb'].default, 256) + self.assertEqual(helper['n_layer'].default, 8) + self.assertEqual(helper['n_head'].default, 8) + self.assertEqual(helper['n_emb'].default, 256) + + def test_time_as_cond_false_token_accounting_matches_external(self): + Transformer1D, _, TransformerForDiffusion = self._load_transformer_classes_or_skip() + self.assertIn('time_as_cond', inspect.signature(Transformer1D.__init__).parameters) + + config = dict( + input_dim=4, + output_dim=4, + horizon=6, + n_obs_steps=3, + cond_dim=0, + n_layer=2, + n_head=2, + n_emb=8, + p_drop_emb=0.0, + p_drop_attn=0.0, + causal_attn=False, + time_as_cond=False, + obs_as_cond=False, + n_cond_layers=0, + ) + + torch.manual_seed(5) + with _suppress_nested_tensor_warning(): + external_model = TransformerForDiffusion(**config) + local_model = Transformer1D(**config) + external_model.eval() + local_model.eval() + + self.assertEqual(local_model.T, external_model.T) + self.assertEqual(local_model.T_cond, external_model.T_cond) + self.assertEqual(local_model.time_as_cond, external_model.time_as_cond) + self.assertEqual(local_model.obs_as_cond, external_model.obs_as_cond) + self.assertEqual(local_model.encoder_only, external_model.encoder_only) + + def test_nocausal_state_dict_forward_and_optim_groups_match_external(self): + Transformer1D, _, TransformerForDiffusion = self._load_transformer_classes_or_skip() + config = dict( + input_dim=4, + output_dim=4, + horizon=6, + n_obs_steps=3, + cond_dim=5, + n_layer=2, + n_head=2, + n_emb=8, + p_drop_emb=0.0, + p_drop_attn=0.0, + causal_attn=False, + obs_as_cond=True, + n_cond_layers=1, + ) + + torch.manual_seed(7) + with _suppress_nested_tensor_warning(): + external_model = TransformerForDiffusion(**config) + local_model = Transformer1D(**config) + external_model.eval() + local_model.eval() + + external_state_dict = external_model.state_dict() + self.assertEqual(set(local_model.state_dict().keys()), set(external_state_dict.keys())) + local_model.load_state_dict(external_state_dict, strict=True) + + batch_size = 2 + sample = torch.randn(batch_size, config['horizon'], config['input_dim']) + cond = torch.randn(batch_size, config['n_obs_steps'], config['cond_dim']) + timestep = torch.tensor([11, 17], dtype=torch.long) + + with torch.no_grad(): + external_out = external_model(sample=sample, timestep=timestep, cond=cond) + local_out = local_model(sample=sample, timestep=timestep, cond=cond) + + self.assertEqual(local_out.shape, (batch_size, config['horizon'], config['output_dim'])) + self.assertEqual(local_out.shape, external_out.shape) + self.assertTrue(torch.allclose(local_out, external_out, atol=1e-6, rtol=1e-5)) + + weight_decay = 0.123 + external_groups = external_model.get_optim_groups(weight_decay=weight_decay) + local_groups = local_model.get_optim_groups(weight_decay=weight_decay) + + self.assertEqual(len(local_groups), len(external_groups)) + self.assertEqual([group['weight_decay'] for group in local_groups], [weight_decay, 0.0]) + self.assertEqual( + self._optim_group_names(local_model, local_groups), + self._optim_group_names(external_model, external_groups), + ) + + +if __name__ == '__main__': + unittest.main()