feat(vla): align transformer training stack and rollout validation
@@ -229,6 +229,11 @@ dependencies:
   - python-xxhash=3.6.0
   - python_abi=3.10
   - pytorch=2.4.0
+  - hydra-core=1.3.2
+  - omegaconf=2.3.0
+  - einops=0.8.2
+  - diffusers=0.36.0
+  - torchvision=0.19.0
   - pytz=2024.1
   - pyyaml=6.0.3
   - qhull=2020.2
@@ -321,12 +326,10 @@ dependencies:
   - datasets==4.5.0
   - decorator==5.2.1
   - deepdiff==8.6.1
-  - diffusers==0.30.0
   - dill==0.4.0
   - docstring_parser==0.17.0
   - draccus==0.10.0
   - eigenpy==3.10.3
-  - einops==0.8.1
   - etils==1.7.0
   - evdev==1.9.2
   - exceptiongroup==1.3.1
@@ -350,7 +353,6 @@ dependencies:
   - httpcore==1.0.9
   - httpx==0.28.1
   - huggingface_hub==1.3.2
-  - hydra-core==1.3.2
   - imageio==2.35.1
   - imageio-ffmpeg==0.6.0
   - importlib_metadata==8.7.1
@@ -380,22 +382,6 @@ dependencies:
   - networkx==3.4.2
   - numcodecs==0.13.1
   - numpy==2.2.6
-  - nvidia-cublas-cu12==12.4.5.8
-  - nvidia-cuda-cupti-cu12==12.4.127
-  - nvidia-cuda-nvrtc-cu12==12.4.127
-  - nvidia-cuda-runtime-cu12==12.4.127
-  - nvidia-cudnn-cu12==9.1.0.70
-  - nvidia-cufft-cu12==11.2.1.3
-  - nvidia-cufile-cu12==1.11.1.6
-  - nvidia-curand-cu12==10.3.5.147
-  - nvidia-cusolver-cu12==11.6.1.9
-  - nvidia-cusparse-cu12==12.3.1.170
-  - nvidia-cusparselt-cu12==0.6.3
-  - nvidia-nccl-cu12==2.21.5
-  - nvidia-nvjitlink-cu12==12.4.127
-  - nvidia-nvshmem-cu12==3.3.20
-  - nvidia-nvtx-cu12==12.4.127
-  - omegaconf==2.3.0
   - opencv-contrib-python==4.10.0.84
   - opencv-python==4.13.0.90
   - orderly-set==5.5.0
@@ -431,7 +417,7 @@ dependencies:
   - regex==2026.1.15
   - requests==2.32.5
   - rerun-sdk==0.26.2
-  - rich==14.2.0
+  - rich==13.9.4
   - ruckig==0.9.2
   - safehttpx==0.1.7
   - safetensors==0.7.0
@@ -443,18 +429,16 @@ dependencies:
   - stack-data==0.6.3
   - starlette==0.50.0
   - sympy==1.13.1
+  - swanlab==0.7.13
   - termcolor==3.3.0
   - timm==1.0.24
   - toml==0.10.2
   - tomli==2.4.0
   - tomlkit==0.13.3
-  - torch==2.5.0
   - torchcodec==0.5
   - torchmetrics==1.8.2
-  - torchvision==0.20.0
   - tqdm==4.67.1
   - traitlets==5.14.3
-  - triton==3.1.0
   - typer==0.21.1
   - typer-slim==0.21.1
   - typeshed_client==2.8.2
@@ -1,8 +1,46 @@
 import mujoco
 import numpy as np
+from pathlib import Path
 from roboimi.utils.KDL_utils import KDL_utils
+
+
+def resolve_robot_asset_path(asset_path):
+    if asset_path is None:
+        return None
+
+    raw_path = Path(asset_path).expanduser()
+    if raw_path.is_absolute():
+        return str(raw_path.resolve())
+
+    current_dir = Path(__file__).resolve().parent
+    package_root = current_dir.parents[1]
+    repo_root = current_dir.parents[2]
+
+    candidates = []
+    if raw_path.parts and raw_path.parts[0] == 'roboimi':
+        candidates.append(repo_root / raw_path)
+
+    candidates.extend([
+        current_dir / raw_path,
+        package_root / raw_path,
+        repo_root / raw_path,
+    ])
+
+    normalized_candidates = []
+    seen = set()
+    for candidate in candidates:
+        resolved = candidate.resolve()
+        if resolved not in seen:
+            normalized_candidates.append(resolved)
+            seen.add(resolved)
+
+    for candidate in normalized_candidates:
+        if candidate.exists():
+            return str(candidate)
+
+    return str(normalized_candidates[0])
+
+
 class ArmBase(object):
     def __init__(self,
                  name=None,
@@ -11,8 +49,8 @@ class ArmBase(object):
                  gripper=None
                  ):
         self.name = name
-        self.urdf_path = urdf_path
-        self.xml_path = xml_path
+        self.urdf_path = resolve_robot_asset_path(urdf_path)
+        self.xml_path = resolve_robot_asset_path(xml_path)
         self.gripper = gripper
         self.robot_model = mujoco.MjModel.from_xml_path(filename=self.xml_path, assets=None)
         self.robot_data = mujoco.MjData(self.robot_model)
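
Note: a minimal sketch of the resolution order implemented by resolve_robot_asset_path above (the concrete asset paths here are hypothetical, and the function is assumed to be in scope in the same module):

    # None passes straight through.
    assert resolve_robot_asset_path(None) is None

    # Absolute paths (after ~ expansion) are returned resolved.
    resolve_robot_asset_path('~/robots/arm.xml')

    # Relative paths are tried in order: repo root (only for 'roboimi/...' paths),
    # then the module directory, the package root, and the repo root; the first
    # existing candidate wins, otherwise the first candidate is returned as a best guess.
    resolve_robot_asset_path('roboimi/assets/robots/arm.xml')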
[One file diff suppressed because it is too large.]
@@ -213,7 +213,9 @@ class DualDianaMed(MujocoEnv):

     def camera_viewer(self):
         img_renderer = mj.Renderer(self.mj_model,height=480,width=640)
-        cv2.namedWindow('Cam view',cv2.WINDOW_NORMAL)
+        show_gui = self.is_render
+        if show_gui:
+            cv2.namedWindow('Cam view',cv2.WINDOW_NORMAL)
         while not self.exit_flag:
             img_renderer.update_scene(self.mj_data,camera="rs_cam_right")
             self.r_vis = img_renderer.render()
@@ -230,9 +232,10 @@ class DualDianaMed(MujocoEnv):
             img_renderer.update_scene(self.mj_data,camera="front")
             self.front = img_renderer.render()
             self.front = self.front[:, :, ::-1]
-            if self.cam_view is not None:
-                cv2.imshow('Cam view', self.cam_view)
-                cv2.waitKey(1)
+            if show_gui:
+                if self.cam_view is not None:
+                    cv2.imshow('Cam view', self.cam_view)
+                cv2.waitKey(1)


     def cam_start(self):
@@ -133,12 +133,12 @@ class DualDianaMed_Pos_Ctrl(DualDianaMed):
         return reward


-def make_sim_env(task_name):
+def make_sim_env(task_name, headless=False):
     if 'sim_transfer' in task_name:
         from roboimi.assets.robots.diana_med import BiDianaMed
         env = DualDianaMed_Pos_Ctrl(
             robot=BiDianaMed(),
-            is_render=True,
+            is_render=not headless,
             control_freq=30,
             is_interpolate=True,
             cam_view='angle'
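
Note: a short usage sketch of the new headless switch, following the 'sim_transfer' branch above:

    # Automated rollout validation: no MuJoCo viewer, and (with the
    # camera_viewer change above) no cv2 'Cam view' window either.
    env = make_sim_env('sim_transfer', headless=True)

    # Interactive debugging keeps the previous behavior (headless defaults to False).
    env = make_sim_env('sim_transfer')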
@@ -3,10 +3,8 @@ import torch.nn as nn
 import numpy as np
 from collections import deque
 from typing import Dict, Optional, Any, Tuple
-from roboimi.vla.core.interfaces import VLABackbone, VLAProjector, VLAHead
 from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
 from diffusers.schedulers.scheduling_ddim import DDIMScheduler
-from roboimi.vla.models.heads.conditional_unet1d import ConditionalUnet1D
 from roboimi.vla.models.normalization import NormalizationModule

 class VLAAgent(nn.Module):
@@ -24,6 +22,7 @@ class VLAAgent(nn.Module):
                  diffusion_steps=100,  # number of DDPM noising steps
                  inference_steps=10,  # number of DDIM inference steps
                  num_cams=3,  # number of camera inputs
+                 camera_names: Optional[Tuple[str, ...]] = None,  # conditioning camera order
                  dataset_stats=None,  # dataset statistics, used for normalization
                  normalization_type='min_max',  # normalization type: 'gaussian' or 'min_max'
                  num_action_steps=8,  # number of action steps actually executed per inference
@@ -39,6 +38,31 @@ class VLAAgent(nn.Module):
         self.num_action_steps = num_action_steps
         self.inference_steps = inference_steps
         self.head_type = head_type  # 'unet' or 'transformer'
+        agent_camera_names = tuple(camera_names) if camera_names is not None else None
+        backbone_camera_names = getattr(vision_backbone, 'camera_names', None)
+        backbone_camera_names = tuple(backbone_camera_names) if backbone_camera_names is not None else None
+        backbone_num_cameras = getattr(vision_backbone, 'num_cameras', None)
+        if backbone_num_cameras is not None and backbone_num_cameras != self.num_cams:
+            raise ValueError(
+                f"agent.num_cams({self.num_cams}) does not match "
+                f"vision_backbone.num_cameras({backbone_num_cameras})"
+            )
+        if (
+            agent_camera_names is not None
+            and backbone_camera_names is not None
+            and agent_camera_names != backbone_camera_names
+        ):
+            raise ValueError(
+                f"agent.camera_names({list(agent_camera_names)}) does not match "
+                f"vision_backbone.camera_names({list(backbone_camera_names)})"
+            )
+        self.camera_names = (
+            agent_camera_names if agent_camera_names is not None else backbone_camera_names
+        )
+        if self.camera_names is not None and len(self.camera_names) != self.num_cams:
+            raise ValueError(
+                f"camera_names length ({len(self.camera_names)}) does not match num_cams ({self.num_cams})"
+            )


         # Normalization module - unifies the normalization logic for training and inference
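
Note: a sketch of how the fail-fast checks above behave at construction time (other constructor arguments are elided, and the keyword values are illustrative):

    backbone = ResNetDiffusionBackbone(num_cameras=3, camera_names=('r_vis', 'top', 'front'))

    # A conflicting order raises immediately instead of silently permuting features:
    # ValueError: agent.camera_names(['top', 'r_vis', 'front']) does not match
    #             vision_backbone.camera_names(['r_vis', 'top', 'front'])
    agent = VLAAgent(vision_backbone=backbone, num_cams=3,
                     camera_names=('top', 'r_vis', 'front'))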
@@ -48,6 +72,8 @@ class VLAAgent(nn.Module):
         )

         self.vision_encoder = vision_backbone
+        if self.camera_names is not None:
+            self.vision_encoder.camera_names = self.camera_names
         single_cam_feat_dim = self.vision_encoder.output_dim
         # global_cond_dim: total flattened dimension (used by the UNet)
         total_vision_dim = single_cam_feat_dim * num_cams * obs_horizon
@@ -117,6 +143,34 @@ class VLAAgent(nn.Module):
             return tuple(self._move_to_device(v, device) for v in data)
         return data

+    def _order_images(self, images: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+        """Return the image dict in the explicitly configured camera order."""
+        if self.camera_names is None:
+            camera_names = tuple(sorted(images.keys()))
+            if len(camera_names) != self.num_cams:
+                raise ValueError(
+                    f"number of conditioning cameras ({len(camera_names)}) does not match num_cams({self.num_cams})"
+                )
+            return {cam_name: images[cam_name] for cam_name in camera_names}
+
+        missing = [cam_name for cam_name in self.camera_names if cam_name not in images]
+        if missing:
+            raise ValueError(
+                f"image condition is missing required cameras. missing={missing}, expected={list(self.camera_names)}"
+            )
+        return {cam_name: images[cam_name] for cam_name in self.camera_names}
+
+    def _build_cond(self, images: Dict[str, torch.Tensor], states: torch.Tensor) -> torch.Tensor:
+        """Build the per-step condition with a stable image-condition order."""
+        ordered_images = self._order_images(images)
+        visual_features = self.vision_encoder(ordered_images)
+        state_features = self.state_encoder(states)
+        cond = torch.cat([visual_features, state_features], dim=-1)
+        if cond.shape[-1] != self.per_step_cond_dim:
+            raise RuntimeError(
+                f"condition dimension mismatch: got {cond.shape[-1]}, expected {self.per_step_cond_dim}"
+            )
+        return cond
+
     # ==========================
     # Training phase
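
Note: a small sketch of the _order_images contract (dummy tensors; assumes an agent constructed with camera_names=('r_vis', 'top', 'front')):

    import torch

    images = {name: torch.zeros(1, 2, 3, 224, 224) for name in ('front', 'top', 'r_vis')}

    # Re-keyed into the configured order; without camera_names it would fall back
    # to sorted() order ('front', 'r_vis', 'top'), which is exactly the silent
    # reordering this change is designed to prevent.
    ordered = agent._order_images(images)
    assert list(ordered) == ['r_vis', 'top', 'front']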
@@ -136,10 +190,8 @@ class VLAAgent(nn.Module):
         states = self.normalization.normalize_qpos(states)
         actions = self.normalization.normalize_action(actions)

-        state_features = self.state_encoder(states)
-
         # 1. Extract visual features
-        visual_features = self.vision_encoder(images)  # (B, obs_horizon, vision_dim)
+        per_step_cond = self._build_cond(images, states)
         action_features = self.action_encoder(actions)

         # 2. Sample noise
@@ -157,21 +209,16 @@ class VLAAgent(nn.Module):
         )

         # Concatenate the global condition and flatten
-        # visual_features: (B, obs_horizon, vision_dim)
-        # state_features: (B, obs_horizon, obs_dim)
-        # after concatenation, flattened to (B, obs_horizon * (vision_dim + obs_dim))
-        global_cond = torch.cat([visual_features, state_features], dim=-1)
-        global_cond = global_cond.flatten(start_dim=1)
+        # per_step_cond: (B, obs_horizon, vision_dim * num_cams + obs_dim)
+        # flattened form feeds the UNet; the full sequence feeds the Transformer
+        global_cond = per_step_cond.flatten(start_dim=1)

         # 5. Predict the noise with the network (interface chosen by head type)
         if self.head_type == 'transformer':
-            # The Transformer needs the condition in sequence format: (B, obs_horizon, cond_dim_per_step)
-            # Reshape the flattened global_cond back into sequence format
-            cond = global_cond.reshape(B, self.obs_horizon, self.per_step_cond_dim)
             pred_noise = self.noise_pred_net(
                 sample=noisy_actions,
                 timestep=timesteps,
-                cond=cond
+                cond=per_step_cond
             )
         else:  # 'unet'
             pred_noise = self.noise_pred_net(
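
Note: both heads now consume the same per-step condition; only the layout differs. A shape walkthrough using the config values below (64 features per camera, 3 cameras, obs_dim 16, obs_horizon 2):

    import torch

    B, obs_horizon = 8, 2
    per_step_cond_dim = 64 * 3 + 16                   # 208, matches head cond_dim below
    per_step_cond = torch.zeros(B, obs_horizon, per_step_cond_dim)

    cond = per_step_cond                              # Transformer: (B, 2, 208), one token per obs step
    global_cond = per_step_cond.flatten(start_dim=1)  # UNet: (B, 416), a single global vector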
@@ -218,7 +265,8 @@ class VLAAgent(nn.Module):

         # Append the images
         if 'images' in observation:
-            self._queues['images'].append({k: v.clone() for k, v in observation['images'].items()})
+            ordered_images = self._order_images(observation['images'])
+            self._queues['images'].append({k: v.clone() for k, v in ordered_images.items()})

     def _prepare_observation_batch(self) -> Dict[str, torch.Tensor]:
         """
@@ -246,7 +294,8 @@ class VLAAgent(nn.Module):
                 images_list.append(images_list[-1])

         batch_images = {}
-        for cam_name in images_list[0].keys():
+        camera_names = self.camera_names if self.camera_names is not None else tuple(sorted(images_list[0].keys()))
+        for cam_name in camera_names:
             batch_images[cam_name] = torch.stack([img[cam_name] for img in images_list], dim=0).unsqueeze(0)

         return {'qpos': batch_qpos, 'images': batch_images}
@@ -346,22 +395,18 @@ class VLAAgent(nn.Module):
         proprioception = self.normalization.normalize_qpos(proprioception)

         # 1. Extract the current observation features (computed only once)
-        visual_features = self.vision_encoder(images)
-        state_features = self.state_encoder(proprioception)
+        per_step_cond = self._build_cond(images, proprioception)

         # Concatenate the condition (computed only once)
-        # visual_features: (B, obs_horizon, vision_dim)
-        # state_features: (B, obs_horizon, obs_dim)
-        global_cond = torch.cat([visual_features, state_features], dim=-1)
-        global_cond_flat = global_cond.flatten(start_dim=1)
+        global_cond_flat = per_step_cond.flatten(start_dim=1)
         if self.head_type == 'transformer':
-            cond = global_cond.reshape(B, self.obs_horizon, self.per_step_cond_dim)
+            cond = per_step_cond
         else:
             cond = None

         # 2. Initialize the actions as pure Gaussian noise
         # shape: (B, pred_horizon, action_dim)
-        device = visual_features.device
+        device = per_step_cond.device
         current_actions = torch.randn(
             (B, self.pred_horizon, self.action_dim), device=device
         )
@@ -29,8 +29,13 @@ num_action_steps: 8  # number of action steps actually executed per inference (should be <= p
# ====================
# Camera configuration
# ====================
+camera_names: ${data.camera_names}  # conditioning camera order is fixed to r_vis, top, front
 num_cams: 3  # number of cameras (r_vis, top, front)

+vision_backbone:
+  num_cameras: ${agent.num_cams}
+  camera_names: ${agent.camera_names}
+
# ====================
# Diffusion process configuration
# ====================
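
Note: the `${...}` values are OmegaConf interpolations, resolved lazily against the composed config, so the camera count and order are defined once and propagate; a minimal self-contained sketch:

    from omegaconf import OmegaConf

    cfg = OmegaConf.create({
        'data': {'camera_names': ['r_vis', 'top', 'front']},
        'agent': {
            'num_cams': 3,
            'camera_names': '${data.camera_names}',
            'vision_backbone': {
                'num_cameras': '${agent.num_cams}',
                'camera_names': '${agent.camera_names}',
            },
        },
    })

    assert cfg.agent.vision_backbone.num_cameras == 3
    assert list(cfg.agent.vision_backbone.camera_names) == ['r_vis', 'top', 'front']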
@@ -52,3 +57,6 @@ head:
# ResNet18 + SpatialSoftmax(32 keypoints) = 64 dims per camera
# computed as: per-camera features (64) * number of cameras (3) + obs_dim (16) = 208
 cond_dim: 208
+causal_attn: false
+time_as_cond: true
+obs_as_cond: true
@@ -9,19 +9,25 @@ defaults:
# ====================
 train:
   # Basic training parameters
-  batch_size: 8  # batch size
-  lr: 5e-5  # learning rate (smaller recommended for the Transformer)
+  batch_size: 16  # batch size
+  lr: 1e-4  # learning rate
   max_steps: 100000  # maximum number of training steps
   device: "cuda"  # device: "cuda" or "cpu"

   # Data loading
-  num_workers: 8  # number of DataLoader worker processes (set to 0 for debugging, 8 in production)
-  val_split: 0.1  # validation split fraction
+  num_workers: 12  # number of DataLoader worker processes (set to 0 for debugging)
+  val_split: 0.0  # validation split fraction; train on the full dataset by default
   seed: 42  # random seed (used for the data split)

   # Logging and checkpoints
   log_freq: 100  # logging frequency (steps)
   save_freq: 2000  # checkpoint saving frequency (steps)
+  use_swanlab: false  # whether to enable SwanLab scalar logging
+  swanlab_project: "roboimi-vla"  # SwanLab project name
+  swanlab_run_name: null  # optional SwanLab run name
+  rollout_val_freq_epochs: 50  # run rollout validation every N epochs
+  rollout_validate_on_checkpoint: false  # whether to run rollout validation right after saving a checkpoint
+  rollout_num_episodes: 3  # number of episodes per rollout validation

   # Learning-rate scheduler (with warmup)
   warmup_steps: 2000  # warmup steps (longer recommended for the Transformer)
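
Note: the training loop consuming these keys is not part of this diff; a hedged sketch of the cadence they presumably imply (the helper name is hypothetical):

    def should_run_rollout(epoch: int, freq_epochs: int) -> bool:
        """Rollout validation fires every freq_epochs epochs (epochs counted from 0)."""
        return freq_epochs > 0 and (epoch + 1) % freq_epochs == 0

    assert should_run_rollout(epoch=49, freq_epochs=50)       # end of the 50th epoch
    assert not should_run_rollout(epoch=50, freq_epochs=50)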
@@ -5,7 +5,7 @@ _partial_: true
# ====================
# Transformer architecture configuration
# ====================
-n_layer: 4  # number of Transformer layers (start small for convergence stability)
+n_layer: 4  # number of Transformer layers (keep the current small-model configuration)
 n_head: 4  # number of attention heads
 n_emb: 128  # embedding dimension
 p_drop_emb: 0.05  # Embedding dropout
@@ -14,9 +14,10 @@ p_drop_attn: 0.05 # Attention dropout
# ====================
# Conditioning configuration
# ====================
-causal_attn: false  # whether to use causal attention (autoregressive generation)
-obs_as_cond: true  # observations as condition (determined by cond_dim > 0)
-n_cond_layers: 1  # number of condition-encoder layers (1 layer for stable fusion)
+causal_attn: false  # matches the full-attention / non-causal variant of the external TransformerForDiffusion
+time_as_cond: true  # consistent with the external implementation: the timestep is a condition token
+obs_as_cond: true  # API alignment; whether it is actually used is determined by cond_dim > 0
+n_cond_layers: 1  # number of condition-encoder layers (keep the current configuration)

# ====================
# Notes
|||||||
@@ -105,7 +105,7 @@ class SimpleRobotDataset(Dataset):
             self._file_cache[key] = f
         return f

-    def _load_frame(self, idx: int) -> Dict:
+    def _load_frame(self, idx: int, *, load_images: bool = True) -> Dict:
         """Lazily load a single frame from the HDF5 file"""
         meta = self.frame_meta[idx]
         f = self._get_h5_file(meta["hdf5_path"])
@@ -118,21 +118,22 @@ class SimpleRobotDataset(Dataset):
         }

         # Load image data: observations/images/{cam_name} -> observation.{cam_name}
-        for cam_name in self.camera_names:
-            h5_path = f'observations/images/{cam_name}'
-            if h5_path in f:
-                img = f[h5_path][meta["frame_idx"]]
-                # Resize the image to 224x224 (reduces memory and I/O load)
-                import cv2
-                img = cv2.resize(img, (224, 224), interpolation=cv2.INTER_LINEAR)
-                # Convert to float and normalize to [0, 1]
-                img = torch.from_numpy(img).float() / 255.0
-                frame[f"observation.{cam_name}"] = img.permute(2, 0, 1)  # HWC -> CHW
+        if load_images:
+            for cam_name in self.camera_names:
+                h5_path = f'observations/images/{cam_name}'
+                if h5_path in f:
+                    img = f[h5_path][meta["frame_idx"]]
+                    # Resize the image to 224x224 (reduces memory and I/O load)
+                    import cv2
+                    img = cv2.resize(img, (224, 224), interpolation=cv2.INTER_LINEAR)
+                    # Convert to float and normalize to [0, 1]
+                    img = torch.from_numpy(img).float() / 255.0
+                    frame[f"observation.{cam_name}"] = img.permute(2, 0, 1)  # HWC -> CHW

         return frame

     def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
-        frame = self._load_frame(idx)
+        frame = self._load_frame(idx, load_images=False)
         ep_idx = frame["episode_index"]

         # Get the frame-index range of the current episode
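
Note: load_images=False matters because __getitem__ touches many frames per sample but only needs their low-dimensional action vectors; an illustrative access pattern (assumes a constructed SimpleRobotDataset):

    frame = dataset._load_frame(idx, load_images=False)                     # qpos/action metadata only
    action = dataset._load_frame(idx + delta, load_images=False)["action"]  # one future action
    # No HDF5 image reads, no cv2.resize: per-sample cost stays proportional to
    # the action horizon rather than horizon x num_cameras full image decodes.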
@@ -186,10 +187,10 @@ class SimpleRobotDataset(Dataset):
             target_idx = idx + delta

             if target_idx <= ep_end:
-                actions.append(self._load_frame(target_idx)["action"])
+                actions.append(self._load_frame(target_idx, load_images=False)["action"])
                 action_is_pad.append(False)
             else:
-                actions.append(self._load_frame(ep_end)["action"])
+                actions.append(self._load_frame(ep_end, load_images=False)["action"])
                 action_is_pad.append(True)

         # ============================================
roboimi/vla/eval_utils.py (new file, 3 lines)
@@ -0,0 +1,3 @@
+def execute_policy_action(env, action):
+    """Execute policy outputs using EE-action semantics."""
+    env.step(action)
@@ -178,12 +178,18 @@ class ResNetDiffusionBackbone(VLABackbone):
         spatial_softmax_num_keypoints: int = 32,
         use_separate_rgb_encoder_per_camera: bool = False,  # new: whether to use a separate encoder per camera
         num_cameras: int = 1,  # new: number of cameras (only used in separate-encoder mode)
+        camera_names: Optional[Tuple[str, ...]] = None,  # explicit camera order
         freeze_backbone: bool = True,  # new: whether to freeze the ResNet backbone (True recommended)
     ):
         super().__init__()

         self.use_separate_rgb_encoder_per_camera = use_separate_rgb_encoder_per_camera
         self.num_cameras = num_cameras
+        self.camera_names = tuple(camera_names) if camera_names is not None else None
+        if self.camera_names is not None and len(self.camera_names) != self.num_cameras:
+            raise ValueError(
+                f"camera_names length ({len(self.camera_names)}) does not match num_cameras ({self.num_cameras})"
+            )

         if use_separate_rgb_encoder_per_camera:
             # Separate-encoder mode: create an independent encoder per camera
@@ -217,6 +223,22 @@ class ResNetDiffusionBackbone(VLABackbone):
             )
             self.feature_dim = self.rgb_encoder.feature_dim

+    def _ordered_camera_names(self, images) -> Tuple[str, ...]:
+        if self.camera_names is None:
+            camera_names = tuple(sorted(images.keys()))
+            if len(camera_names) != self.num_cameras:
+                raise ValueError(
+                    f"number of input cameras ({len(camera_names)}) does not match num_cameras ({self.num_cameras})"
+                )
+            return camera_names
+
+        missing = [cam_name for cam_name in self.camera_names if cam_name not in images]
+        if missing:
+            raise ValueError(
+                f"image input is missing required cameras. missing={missing}, expected={list(self.camera_names)}"
+            )
+        return self.camera_names
+
     def forward(self, images):
         """
         Args:
@@ -228,7 +250,7 @@ class ResNetDiffusionBackbone(VLABackbone):
         """
         any_tensor = next(iter(images.values()))
         B, T = any_tensor.shape[:2]
-        cam_names = sorted(images.keys())
+        cam_names = self._ordered_camera_names(images)

         if self.use_separate_rgb_encoder_per_camera:
             # Separate-encoder mode: each camera uses its own encoder
@@ -236,7 +258,7 @@ class ResNetDiffusionBackbone(VLABackbone):
             for cam_idx, cam_name in enumerate(cam_names):
                 img = images[cam_name]
                 encoder = self.rgb_encoder[cam_idx]
-                features = encoder.forward_single_image(img.view(B * T, *img.shape[2:]))
+                features = encoder.forward_single_image(img.reshape(B * T, *img.shape[2:]))
                 features_all.append(features)
             return torch.cat(features_all, dim=1).view(B, T, -1)
         else:
@@ -244,7 +266,7 @@ class ResNetDiffusionBackbone(VLABackbone):
             features_all = []
             for cam_name in cam_names:
                 img = images[cam_name]
-                features = self.rgb_encoder.forward_single_image(img.view(B * T, *img.shape[2:]))
+                features = self.rgb_encoder.forward_single_image(img.reshape(B * T, *img.shape[2:]))
                 features_all.append(features)
             return torch.cat(features_all, dim=1).view(B, T, -1)
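
Note: swapping view for reshape guards against non-contiguous camera tensors (e.g. after a permute or slice upstream); reshape copies only when needed, while view requires contiguous memory. A minimal repro:

    import torch

    img = torch.randn(2, 4, 224, 224, 3).permute(0, 1, 4, 2, 3)  # (B, T, C, H, W), non-contiguous
    try:
        img.view(2 * 4, 3, 224, 224)        # RuntimeError: view needs compatible strides
    except RuntimeError:
        pass
    flat = img.reshape(2 * 4, 3, 224, 224)  # works; copies only if necessary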
@@ -1,19 +1,35 @@
-"""
-Transformer-based Diffusion Policy Head
-
-Uses a Transformer (encoder-decoder) architecture instead of a UNet for noise prediction.
-Supports injecting the global condition (observation features) via cross-attention.
-"""
+"""Transformer-based diffusion head aligned with diffusion_policy's TransformerForDiffusion."""

+from __future__ import annotations
+
+import logging
 import math
+from typing import Optional, Tuple, Union
+
 import torch
 import torch.nn as nn
-from typing import Optional
+
+logger = logging.getLogger(__name__)
+
+
+class ModuleAttrMixin(nn.Module):
+    """Minimal local copy of diffusion_policy's ModuleAttrMixin for state-dict parity."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self._dummy_variable = nn.Parameter()
+
+    @property
+    def device(self):
+        return next(iter(self.parameters())).device
+
+    @property
+    def dtype(self):
+        return next(iter(self.parameters())).dtype


 class SinusoidalPosEmb(nn.Module):
-    """Sinusoidal positional encoding (used for the timestep embedding)"""
-    def __init__(self, dim: int):
+    def __init__(self, dim: int) -> None:
         super().__init__()
         self.dim = dim
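
Note: the forward body of SinusoidalPosEmb is unchanged and elided by this diff; for reference, a self-contained sketch of the standard embedding it is assumed to compute (mirroring diffusion_policy):

    import math
    import torch

    def sinusoidal_pos_emb(t: torch.Tensor, dim: int) -> torch.Tensor:
        half = dim // 2
        # Geometrically spaced frequencies from 1 down to 1/10000.
        freqs = torch.exp(torch.arange(half) * (-math.log(10000) / (half - 1)))
        args = t[:, None].float() * freqs[None, :]
        return torch.cat([args.sin(), args.cos()], dim=-1)  # (B, dim)

    emb = sinusoidal_pos_emb(torch.tensor([0, 10, 99]), 128)
    assert emb.shape == (3, 128)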
@@ -27,35 +43,13 @@ class SinusoidalPosEmb(nn.Module):
         return emb


-class Transformer1D(nn.Module):
-    """
-    Transformer-based 1D Diffusion Model
-
-    Uses an encoder-decoder architecture:
-    - Encoder: processes the condition (observations + timestep)
-    - Decoder: predicts the noise via cross-attention
-
-    Args:
-        input_dim: input action dimension
-        output_dim: output action dimension
-        horizon: prediction horizon length
-        n_obs_steps: number of observation steps
-        cond_dim: condition dimension
-        n_layer: number of Transformer layers
-        n_head: number of attention heads
-        n_emb: embedding dimension
-        p_drop_emb: embedding dropout
-        p_drop_attn: attention dropout
-        causal_attn: whether to use causal attention (autoregressive)
-        n_cond_layers: number of encoder layers (0 means use an MLP)
-    """
-
+class Transformer1D(ModuleAttrMixin):
     def __init__(
         self,
         input_dim: int,
         output_dim: int,
         horizon: int,
-        n_obs_steps: int = None,
+        n_obs_steps: Optional[int] = None,
         cond_dim: int = 0,
         n_layer: int = 8,
         n_head: int = 8,
@@ -63,57 +57,42 @@ class Transformer1D(nn.Module):
         p_drop_emb: float = 0.1,
         p_drop_attn: float = 0.1,
         causal_attn: bool = False,
+        time_as_cond: bool = True,
         obs_as_cond: bool = False,
-        n_cond_layers: int = 0
-    ):
+        n_cond_layers: int = 0,
+    ) -> None:
         super().__init__()

-        # Compute the sequence lengths
         if n_obs_steps is None:
             n_obs_steps = horizon

         T = horizon
-        T_cond = 1  # number of timestep tokens
-
-        # Decide whether to use observations as the condition
+        T_cond = 1
+        if not time_as_cond:
+            T += 1
+            T_cond -= 1
         obs_as_cond = cond_dim > 0
         if obs_as_cond:
+            assert time_as_cond
             T_cond += n_obs_steps

-        # Save the configuration
-        self.T = T
-        self.T_cond = T_cond
-        self.horizon = horizon
-        self.obs_as_cond = obs_as_cond
-        self.input_dim = input_dim
-        self.output_dim = output_dim
-
-        # ==================== Input embedding ====================
         self.input_emb = nn.Linear(input_dim, n_emb)
         self.pos_emb = nn.Parameter(torch.zeros(1, T, n_emb))
         self.drop = nn.Dropout(p_drop_emb)

-        # ==================== Condition encoding ====================
-        # Timestep embedding
         self.time_emb = SinusoidalPosEmb(n_emb)

-        # Observation condition embedding (optional)
         self.cond_obs_emb = None
         if obs_as_cond:
             self.cond_obs_emb = nn.Linear(cond_dim, n_emb)

-        # Condition position embedding
         self.cond_pos_emb = None
+        self.encoder = None
+        self.decoder = None
+        encoder_only = False
+
         if T_cond > 0:
             self.cond_pos_emb = nn.Parameter(torch.zeros(1, T_cond, n_emb))
-
-        # ==================== Encoder ====================
-        self.encoder = None
-        self.encoder_only = False
-
-        if T_cond > 0:
             if n_cond_layers > 0:
-                # Use a Transformer encoder
                 encoder_layer = nn.TransformerEncoderLayer(
                     d_model=n_emb,
                     nhead=n_head,
@@ -121,61 +100,19 @@ class Transformer1D(nn.Module):
                     dropout=p_drop_attn,
                     activation='gelu',
                     batch_first=True,
-                    norm_first=True  # Pre-LN is more stable
+                    norm_first=True,
                 )
                 self.encoder = nn.TransformerEncoder(
                     encoder_layer=encoder_layer,
-                    num_layers=n_cond_layers
+                    num_layers=n_cond_layers,
                 )
             else:
-                # Use a simple MLP
                 self.encoder = nn.Sequential(
                     nn.Linear(n_emb, 4 * n_emb),
                     nn.Mish(),
-                    nn.Linear(4 * n_emb, n_emb)
+                    nn.Linear(4 * n_emb, n_emb),
                 )
-        else:
-            # Encoder-only mode (BERT style)
-            self.encoder_only = True
-            encoder_layer = nn.TransformerEncoderLayer(
-                d_model=n_emb,
-                nhead=n_head,
-                dim_feedforward=4 * n_emb,
-                dropout=p_drop_attn,
-                activation='gelu',
-                batch_first=True,
-                norm_first=True
-            )
-            self.encoder = nn.TransformerEncoder(
-                encoder_layer=encoder_layer,
-                num_layers=n_layer
-            )
-
-        # ==================== Attention mask ====================
-        self.mask = None
-        self.memory_mask = None
-
-        if causal_attn:
-            # Causal mask: only attend to the left
-            sz = T
-            mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
-            mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
-            self.register_buffer("mask", mask)
-
-            if obs_as_cond:
-                # Cross-attention mask
-                S = T_cond
-                t, s = torch.meshgrid(
-                    torch.arange(T),
-                    torch.arange(S),
-                    indexing='ij'
-                )
-                mask = t >= (s - 1)
-                mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
-                self.register_buffer('memory_mask', mask)
-
-        # ==================== Decoder ====================
-        if not self.encoder_only:
+
             decoder_layer = nn.TransformerDecoderLayer(
                 d_model=n_emb,
                 nhead=n_head,
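
Note: the relocated memory_mask (see the next hunk) keeps the semantics of t >= (s - 1): every action token may attend to the time token (s = 0) and to observation tokens up to its own step. A tiny worked example:

    import torch

    T, S = 4, 3  # 4 action tokens; 1 time token + 2 obs tokens
    t, s = torch.meshgrid(torch.arange(T), torch.arange(S), indexing='ij')
    allowed = t >= (s - 1)  # True = attend; False becomes -inf in the float mask
    print(allowed.int())
    # tensor([[1, 1, 0],
    #         [1, 1, 1],
    #         [1, 1, 1],
    #         [1, 1, 1]])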
@@ -183,136 +120,199 @@ class Transformer1D(nn.Module):
                 dropout=p_drop_attn,
                 activation='gelu',
                 batch_first=True,
-                norm_first=True
+                norm_first=True,
             )
             self.decoder = nn.TransformerDecoder(
                 decoder_layer=decoder_layer,
-                num_layers=n_layer
+                num_layers=n_layer,
+            )
+        else:
+            encoder_only = True
+            encoder_layer = nn.TransformerEncoderLayer(
+                d_model=n_emb,
+                nhead=n_head,
+                dim_feedforward=4 * n_emb,
+                dropout=p_drop_attn,
+                activation='gelu',
+                batch_first=True,
+                norm_first=True,
+            )
+            self.encoder = nn.TransformerEncoder(
+                encoder_layer=encoder_layer,
+                num_layers=n_layer,
             )

-        # ==================== Output head ====================
+        if causal_attn:
+            sz = T
+            mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
+            mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
+            self.register_buffer('mask', mask)
+
+            if time_as_cond and obs_as_cond:
+                S = T_cond
+                t, s = torch.meshgrid(torch.arange(T), torch.arange(S), indexing='ij')
+                mask = t >= (s - 1)
+                mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
+                self.register_buffer('memory_mask', mask)
+            else:
+                self.memory_mask = None
+        else:
+            self.mask = None
+            self.memory_mask = None
+
         self.ln_f = nn.LayerNorm(n_emb)
         self.head = nn.Linear(n_emb, output_dim)

-        # ==================== Initialization ====================
-        self.apply(self._init_weights)
+        self.T = T
+        self.T_cond = T_cond
+        self.horizon = horizon
+        self.time_as_cond = time_as_cond
+        self.obs_as_cond = obs_as_cond
+        self.encoder_only = encoder_only

-        # Print the parameter count
-        total_params = sum(p.numel() for p in self.parameters())
-        print(f"Transformer1D parameters: {total_params:,}")
+        self.apply(self._init_weights)
+        logger.info('number of parameters: %e', sum(p.numel() for p in self.parameters()))

     def _init_weights(self, module):
-        """Initialize the weights"""
+        ignore_types = (
+            nn.Dropout,
+            SinusoidalPosEmb,
+            nn.TransformerEncoderLayer,
+            nn.TransformerDecoderLayer,
+            nn.TransformerEncoder,
+            nn.TransformerDecoder,
+            nn.ModuleList,
+            nn.Mish,
+            nn.Sequential,
+        )
         if isinstance(module, (nn.Linear, nn.Embedding)):
             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
             if isinstance(module, nn.Linear) and module.bias is not None:
                 torch.nn.init.zeros_(module.bias)
         elif isinstance(module, nn.MultiheadAttention):
-            # Weight initialization for MultiheadAttention
-            for name in ['in_proj_weight', 'q_proj_weight', 'k_proj_weight', 'v_proj_weight']:
-                weight = getattr(module, name, None)
+            for name in ('in_proj_weight', 'q_proj_weight', 'k_proj_weight', 'v_proj_weight'):
+                weight = getattr(module, name)
                 if weight is not None:
                     torch.nn.init.normal_(weight, mean=0.0, std=0.02)

-            for name in ['in_proj_bias', 'bias_k', 'bias_v']:
-                bias = getattr(module, name, None)
+            for name in ('in_proj_bias', 'bias_k', 'bias_v'):
+                bias = getattr(module, name)
                 if bias is not None:
                     torch.nn.init.zeros_(bias)
         elif isinstance(module, nn.LayerNorm):
             torch.nn.init.zeros_(module.bias)
             torch.nn.init.ones_(module.weight)
         elif isinstance(module, Transformer1D):
-            # Position embedding initialization
-            torch.nn.init.normal_(self.pos_emb, mean=0.0, std=0.02)
-            if self.cond_pos_emb is not None:
-                torch.nn.init.normal_(self.cond_pos_emb, mean=0.0, std=0.02)
+            torch.nn.init.normal_(module.pos_emb, mean=0.0, std=0.02)
+            if module.cond_obs_emb is not None:
+                torch.nn.init.normal_(module.cond_pos_emb, mean=0.0, std=0.02)
+        elif isinstance(module, ignore_types):
+            pass
+        else:
+            raise RuntimeError(f'Unaccounted module {module}')
+
+    def get_optim_groups(self, weight_decay: float = 1e-3):
+        decay = set()
+        no_decay = set()
+        whitelist_weight_modules = (torch.nn.Linear, torch.nn.MultiheadAttention)
+        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
+
+        for module_name, module in self.named_modules():
+            for param_name, _ in module.named_parameters():
+                full_param_name = f'{module_name}.{param_name}' if module_name else param_name
+
+                if param_name.endswith('bias'):
+                    no_decay.add(full_param_name)
+                elif param_name.startswith('bias'):
+                    no_decay.add(full_param_name)
+                elif param_name.endswith('weight') and isinstance(module, whitelist_weight_modules):
+                    decay.add(full_param_name)
+                elif param_name.endswith('weight') and isinstance(module, blacklist_weight_modules):
+                    no_decay.add(full_param_name)
+
+        no_decay.add('pos_emb')
+        no_decay.add('_dummy_variable')
+        if self.cond_pos_emb is not None:
+            no_decay.add('cond_pos_emb')
+
+        param_dict = {name: param for name, param in self.named_parameters()}
+        inter_params = decay & no_decay
+        union_params = decay | no_decay
+        assert len(inter_params) == 0, f'parameters {inter_params} made it into both decay/no_decay sets!'
+        assert len(param_dict.keys() - union_params) == 0, (
+            f'parameters {param_dict.keys() - union_params} were not separated into either decay/no_decay sets!'
+        )
+
+        return [
+            {
+                'params': [param_dict[name] for name in sorted(decay)],
+                'weight_decay': weight_decay,
+            },
+            {
+                'params': [param_dict[name] for name in sorted(no_decay)],
+                'weight_decay': 0.0,
+            },
+        ]
+
+    def configure_optimizers(
+        self,
+        learning_rate: float = 1e-4,
+        weight_decay: float = 1e-3,
+        betas: Tuple[float, float] = (0.9, 0.95),
+    ):
+        optim_groups = self.get_optim_groups(weight_decay=weight_decay)
+        return torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas)

     def forward(
         self,
         sample: torch.Tensor,
-        timestep: torch.Tensor,
+        timestep: Union[torch.Tensor, float, int],
         cond: Optional[torch.Tensor] = None,
-        **kwargs
+        **kwargs,
     ):
-        """
-        Forward pass
-
-        Args:
-            sample: (B, T, input_dim) input sequence (noised actions)
-            timestep: (B,) timesteps
-            cond: (B, T', cond_dim) condition sequence (observation features)
-
-        Returns:
-            (B, T, output_dim) predicted noise
-        """
-        # ==================== Handle the timestep ====================
         timesteps = timestep
         if not torch.is_tensor(timesteps):
             timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
         elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
             timesteps = timesteps[None].to(sample.device)
-
-        # Expand to the batch dimension
         timesteps = timesteps.expand(sample.shape[0])
-        time_emb = self.time_emb(timesteps).unsqueeze(1)  # (B, 1, n_emb)
+        time_emb = self.time_emb(timesteps).unsqueeze(1)

-        # ==================== Handle the input ====================
-        input_emb = self.input_emb(sample)  # (B, T, n_emb)
+        input_emb = self.input_emb(sample)

-        # ==================== Encoder-decoder mode ====================
-        if not self.encoder_only:
-            # --- Encoder: process the condition ---
+        if self.encoder_only:
+            token_embeddings = torch.cat([time_emb, input_emb], dim=1)
+            t = token_embeddings.shape[1]
+            position_embeddings = self.pos_emb[:, :t, :]
+            x = self.drop(token_embeddings + position_embeddings)
+            x = self.encoder(src=x, mask=self.mask)
+            x = x[:, 1:, :]
+        else:
             cond_embeddings = time_emb
-            if self.obs_as_cond and cond is not None:
-                # Append the observation condition
-                cond_obs_emb = self.cond_obs_emb(cond)  # (B, T_cond-1, n_emb)
+            if self.obs_as_cond:
+                cond_obs_emb = self.cond_obs_emb(cond)
                 cond_embeddings = torch.cat([cond_embeddings, cond_obs_emb], dim=1)
-
-            # Add the position embedding
             tc = cond_embeddings.shape[1]
-            pos_emb = self.cond_pos_emb[:, :tc, :]
-            x = self.drop(cond_embeddings + pos_emb)
-
-            # Run the encoder
-            memory = self.encoder(x)  # (B, T_cond, n_emb)
-
-            # --- Decoder: predict the noise ---
-            # Add the position embedding to the input
+            position_embeddings = self.cond_pos_emb[:, :tc, :]
+            x = self.drop(cond_embeddings + position_embeddings)
+            memory = self.encoder(x)
+
             token_embeddings = input_emb
             t = token_embeddings.shape[1]
-            pos_emb = self.pos_emb[:, :t, :]
-            x = self.drop(token_embeddings + pos_emb)
-
-            # Cross-attention: queries from the input, keys/values from the memory
+            position_embeddings = self.pos_emb[:, :t, :]
+            x = self.drop(token_embeddings + position_embeddings)
             x = self.decoder(
                 tgt=x,
                 memory=memory,
                 tgt_mask=self.mask,
-                memory_mask=self.memory_mask
+                memory_mask=self.memory_mask,
             )

-        # ==================== Encoder-only mode ====================
-        else:
-            # BERT style: the timestep as a special token
-            token_embeddings = torch.cat([time_emb, input_emb], dim=1)
-            t = token_embeddings.shape[1]
-            pos_emb = self.pos_emb[:, :t, :]
-            x = self.drop(token_embeddings + pos_emb)
-
-            x = self.encoder(src=x, mask=self.mask)
-            x = x[:, 1:, :]  # remove the timestep token
-
-        # ==================== Output head ====================
         x = self.ln_f(x)
-        x = self.head(x)  # (B, T, output_dim)
+        x = self.head(x)

         return x
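
Note: get_optim_groups/configure_optimizers mirror diffusion_policy's TransformerForDiffusion: biases, LayerNorm/Embedding weights, and the positional embeddings are excluded from weight decay. A usage sketch (architecture values taken from the configs above; the horizon and action dims are illustrative):

    model = Transformer1D(input_dim=16, output_dim=16, horizon=16,
                          n_obs_steps=2, cond_dim=208,
                          n_layer=4, n_head=4, n_emb=128)
    optimizer = model.configure_optimizers(learning_rate=1e-4, weight_decay=1e-3)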
@@ -322,26 +322,9 @@ def create_transformer1d(
     n_layer: int = 8,
     n_head: int = 8,
     n_emb: int = 256,
-    **kwargs
+    **kwargs,
 ) -> Transformer1D:
-    """
-    Convenience function for creating a Transformer1D model
-
-    Args:
-        input_dim: input action dimension
-        output_dim: output action dimension
-        horizon: prediction horizon
-        n_obs_steps: number of observation steps
-        cond_dim: condition dimension
-        n_layer: number of Transformer layers
-        n_head: number of attention heads
-        n_emb: embedding dimension
-        **kwargs: additional arguments
-
-    Returns:
-        A Transformer1D model
-    """
-    model = Transformer1D(
+    return Transformer1D(
         input_dim=input_dim,
         output_dim=output_dim,
         horizon=horizon,
@@ -350,47 +333,5 @@ def create_transformer1d(
         n_layer=n_layer,
         n_head=n_head,
         n_emb=n_emb,
-        **kwargs
+        **kwargs,
     )
-    return model
-
-
-if __name__ == "__main__":
-    print("=" * 80)
-    print("Testing Transformer1D")
-    print("=" * 80)
-
-    # Configuration
-    B = 4
-    T = 16
-    action_dim = 16
-    obs_horizon = 2
-    cond_dim = 416  # vision + state feature dimension
-
-    # Create the model
-    model = Transformer1D(
-        input_dim=action_dim,
-        output_dim=action_dim,
-        horizon=T,
-        n_obs_steps=obs_horizon,
-        cond_dim=cond_dim,
-        n_layer=4,
-        n_head=8,
-        n_emb=256,
-        causal_attn=False
-    )
-
-    # Test the forward pass
-    sample = torch.randn(B, T, action_dim)
-    timestep = torch.randint(0, 100, (B,))
-    cond = torch.randn(B, obs_horizon, cond_dim)
-
-    output = model(sample, timestep, cond)
-
-    print(f"\nInputs:")
-    print(f"  sample: {sample.shape}")
-    print(f"  timestep: {timestep.shape}")
-    print(f"  cond: {cond.shape}")
-    print(f"\nOutputs:")
-    print(f"  output: {output.shape}")
-    print(f"\n✅ Test passed!")
@@ -1,8 +1,16 @@
+import argparse
+import glob
+import os
+import pickle
+from pathlib import Path
+
 import h5py
 import numpy as np
-import os
-import glob
-import pickle
+
+DEFAULT_DATASET_DIR = str(
+    Path(__file__).resolve().parents[2] / "demos" / "dataset" / "sim_transfer"
+)
+

 def get_data_stats(dataset_dir):
     """
@@ -23,6 +31,11 @@ def get_data_stats(dataset_dir):
     files = sorted(glob.glob(os.path.join(dataset_dir, 'episode_*.hdf5')))
     print(f"Found {len(files)} episodes in {dataset_dir}")

+    if not files:
+        raise ValueError(
+            f"No episode_*.hdf5 files found in dataset_dir: {dataset_dir}"
+        )
+
     all_actions = []
     all_qpos = []

@@ -70,18 +83,32 @@ def get_data_stats(dataset_dir):
     }
     return stats_flat

-if __name__ == "__main__":
-    DATASET_DIR = 'roboimi/demos/dataset/sim_transfer'
-    OUTPUT_PATH = DATASET_DIR + "/dataset_stats.pkl"

-    stats_flat = get_data_stats(DATASET_DIR)
+def write_dataset_stats(dataset_dir):
+    output_path = os.path.join(dataset_dir, "dataset_stats.pkl")
+    stats_flat = get_data_stats(dataset_dir)

     # Print for inspection
     print("\n--- Stats Computed ---")
     print(f"Action Mean shape: {stats_flat['action_mean'].shape}")
     print(f"Action Std shape: {stats_flat['action_std'].shape}")

-    # Save
-    with open(OUTPUT_PATH, 'wb') as f:
+    with open(output_path, 'wb') as f:
         pickle.dump(stats_flat, f)
-    print(f"\nStats saved to {OUTPUT_PATH}")
+    print(f"\nStats saved to {output_path}")
+
+    return output_path
+
+
+def main(argv=None):
+    parser = argparse.ArgumentParser(description="Calculate dataset statistics.")
+    parser.add_argument(
+        "--dataset_dir",
+        default=DEFAULT_DATASET_DIR,
+        help="Directory containing episode_*.hdf5 files.",
+    )
+    args = parser.parse_args(argv)
+    write_dataset_stats(args.dataset_dir)
+
+
+if __name__ == "__main__":
+    main()
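
Note: the script is now import-safe and scriptable; both entry points below are exercised by the new CLI test (the dataset directory shown is a placeholder):

    from roboimi.vla.scripts import calculate_stats

    # Programmatic use: writes <dataset_dir>/dataset_stats.pkl and returns its path.
    stats_path = calculate_stats.write_dataset_stats('/data/sim_transfer')

    # CLI-style use, equivalent to:
    #   python -m roboimi.vla.scripts.calculate_stats --dataset_dir /data/sim_transfer
    calculate_stats.main(['--dataset_dir', '/data/sim_transfer'])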
tests/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
+
88
tests/test_calculate_stats_cli.py
Normal file
88
tests/test_calculate_stats_cli.py
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
+import pickle
+import tempfile
+import unittest
+from pathlib import Path
+
+import h5py
+import numpy as np
+
+from roboimi.vla.scripts import calculate_stats
+
+
+class CalculateStatsCliTest(unittest.TestCase):
+    def test_default_dataset_dir_is_absolute_and_package_relative(self):
+        expected = (
+            Path(calculate_stats.__file__).resolve().parents[2]
+            / "demos"
+            / "dataset"
+            / "sim_transfer"
+        )
+
+        self.assertEqual(Path(calculate_stats.DEFAULT_DATASET_DIR), expected)
+        self.assertTrue(Path(calculate_stats.DEFAULT_DATASET_DIR).is_absolute())
+
+    def test_main_writes_dataset_stats_pkl_to_dataset_dir(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            dataset_dir = Path(tmpdir)
+            episode_path = dataset_dir / "episode_0.hdf5"
+
+            with h5py.File(episode_path, "w") as root:
+                root.create_dataset(
+                    "action",
+                    data=np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32),
+                )
+                observations = root.create_group("observations")
+                observations.create_dataset(
+                    "qpos",
+                    data=np.array([[5.0, 6.0], [7.0, 8.0]], dtype=np.float32),
+                )
+
+            calculate_stats.main(["--dataset_dir", str(dataset_dir)])
+
+            stats_path = dataset_dir / "dataset_stats.pkl"
+            self.assertTrue(stats_path.exists())
+
+            with stats_path.open("rb") as f:
+                stats = pickle.load(f)
+
+            self.assertEqual(
+                set(stats),
+                {
+                    "action_mean",
+                    "action_std",
+                    "action_min",
+                    "action_max",
+                    "qpos_mean",
+                    "qpos_std",
+                    "qpos_min",
+                    "qpos_max",
+                },
+            )
+            np.testing.assert_allclose(stats["action_mean"], np.array([2.0, 3.0]))
+            np.testing.assert_allclose(stats["qpos_mean"], np.array([6.0, 7.0]))
+
+    def test_main_raises_clear_error_for_empty_dataset_dir(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            dataset_dir = Path(tmpdir)
+
+            with self.assertRaisesRegex(
+                ValueError, r"No episode_\*\.hdf5 files found"
+            ) as ctx:
+                calculate_stats.main(["--dataset_dir", str(dataset_dir)])
+
+            self.assertIn(str(dataset_dir), str(ctx.exception))
+
+    def test_main_raises_clear_error_for_missing_dataset_dir(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            dataset_dir = Path(tmpdir) / "missing"
+
+            with self.assertRaisesRegex(
+                ValueError, r"No episode_\*\.hdf5 files found"
+            ) as ctx:
+                calculate_stats.main(["--dataset_dir", str(dataset_dir)])
+
+            self.assertIn(str(dataset_dir), str(ctx.exception))
+
+
+if __name__ == "__main__":
+    unittest.main()
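The tests above pin down the contract of the stats file: eight flat keys with per-dimension reductions over all timesteps. A reference sketch of that reduction under the fixture's numbers (not the shipped implementation):

    import numpy as np

    def reference_stats(action, qpos):
        # Per-dimension stats over the time axis, as the fixture implies:
        # action [[1, 2], [3, 4]] -> mean [2, 3]; qpos [[5, 6], [7, 8]] -> mean [6, 7].
        return {
            "action_mean": action.mean(axis=0), "action_std": action.std(axis=0),
            "action_min": action.min(axis=0), "action_max": action.max(axis=0),
            "qpos_mean": qpos.mean(axis=0), "qpos_std": qpos.std(axis=0),
            "qpos_min": qpos.min(axis=0), "qpos_max": qpos.max(axis=0),
        }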
28 tests/test_eval_vla_execution.py Normal file
@@ -0,0 +1,28 @@
+import unittest
+
+from roboimi.vla.eval_utils import execute_policy_action
+
+
+class _FakeEnv:
+    def __init__(self):
+        self.calls = []
+
+    def step(self, action):
+        self.calls.append(("step", action))
+
+    def step_jnt(self, action):
+        self.calls.append(("step_jnt", action))
+
+
+class EvalVLAExecutionTest(unittest.TestCase):
+    def test_execute_policy_action_uses_ee_step(self):
+        env = _FakeEnv()
+        action = [1, 2, 3]
+
+        execute_policy_action(env, action)
+
+        self.assertEqual(env.calls, [("step", action)])
+
+
+if __name__ == "__main__":
+    unittest.main()
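The behaviour this test enforces, sketched (assuming execute_policy_action stays this thin): the eval helper drives the end-effector interface and deliberately avoids joint control.

    def execute_policy_action_sketch(env, action):
        # End-effector step only; env.step_jnt must not be called.
        env.step(action)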
259 tests/test_eval_vla_headless.py Normal file
@@ -0,0 +1,259 @@
+import unittest
+from pathlib import Path
+from unittest import mock
+
+import numpy as np
+import torch
+from omegaconf import OmegaConf
+
+from roboimi.demos.vla_scripts import eval_vla
+from roboimi.envs.double_base import DualDianaMed
+from roboimi.envs.double_pos_ctrl_env import make_sim_env
+
+
+class _FakeAgent:
+    def __init__(self):
+        self.reset_calls = 0
+        self.last_observation = None
+
+    def eval(self):
+        return self
+
+    def to(self, _device):
+        return self
+
+    def reset(self):
+        self.reset_calls += 1
+
+    def select_action(self, observation):
+        self.last_observation = observation
+        return torch.zeros(16)
+
+
+class _FakeEnv:
+    def __init__(self):
+        self.image_obs_calls = 0
+        self.render_calls = 0
+        self.reset_calls = []
+
+    def reset(self, box_pos):
+        self.reset_calls.append(np.array(box_pos))
+
+    def _get_image_obs(self):
+        self.image_obs_calls += 1
+        return {
+            "images": {
+                "front": np.zeros((8, 8, 3), dtype=np.uint8),
+            }
+        }
+
+    def _get_qpos_obs(self):
+        return {"qpos": np.zeros(16, dtype=np.float32)}
+
+    def render(self):
+        self.render_calls += 1
+        raise AssertionError("env.render() should be skipped when eval.headless=true")
+
+
+class _RewardTrackingEnv(_FakeEnv):
+    def __init__(self, reward_sequences):
+        super().__init__()
+        self.reward_sequences = reward_sequences
+        self.episode_index = -1
+        self.step_index = 0
+        self.rew = 0.0
+
+    def reset(self, box_pos):
+        super().reset(box_pos)
+        self.episode_index += 1
+        self.step_index = 0
+
+
+class _FakeRenderer:
+    def __init__(self, env):
+        self._env = env
+        self._frames = [
+            np.full((4, 4, 3), fill_value=index, dtype=np.uint8)
+            for index in range(5)
+        ]
+        self._index = 0
+
+    def update_scene(self, _mj_data, camera=None):
+        self._camera = camera
+
+    def render(self):
+        frame = self._frames[self._index]
+        self._index += 1
+        if self._index >= len(self._frames):
+            self._env.exit_flag = True
+        return frame
+
+
+class EvalVLAHeadlessTest(unittest.TestCase):
+    def test_eval_config_exposes_headless_default(self):
+        eval_cfg = OmegaConf.load(Path("roboimi/vla/conf/eval/eval.yaml"))
+
+        self.assertIn("headless", eval_cfg)
+        self.assertFalse(eval_cfg.headless)
+
+    def test_make_sim_env_accepts_headless_and_disables_render(self):
+        fake_env = object()
+
+        with mock.patch(
+            "roboimi.assets.robots.diana_med.BiDianaMed",
+            return_value="robot",
+        ), mock.patch(
+            "roboimi.envs.double_pos_ctrl_env.DualDianaMed_Pos_Ctrl",
+            return_value=fake_env,
+        ) as env_cls:
+            env = make_sim_env("sim_transfer", headless=True)
+
+        self.assertIs(env, fake_env)
+        env_cls.assert_called_once_with(
+            robot="robot",
+            is_render=False,
+            control_freq=30,
+            is_interpolate=True,
+            cam_view="angle",
+        )
+
+    def test_camera_viewer_headless_updates_images_without_gui_calls(self):
+        env = DualDianaMed.__new__(DualDianaMed)
+        env.mj_model = object()
+        env.mj_data = object()
+        env.exit_flag = False
+        env.is_render = False
+        env.cam = "angle"
+        env.r_vis = None
+        env.l_vis = None
+        env.top = None
+        env.angle = None
+        env.front = None
+
+        with mock.patch(
+            "roboimi.envs.double_base.mj.Renderer",
+            side_effect=lambda *args, **kwargs: _FakeRenderer(env),
+        ), mock.patch("roboimi.envs.double_base.cv2.namedWindow") as named_window, mock.patch(
+            "roboimi.envs.double_base.cv2.imshow"
+        ) as imshow, mock.patch("roboimi.envs.double_base.cv2.waitKey") as wait_key:
+            env.camera_viewer()
+
+        named_window.assert_not_called()
+        imshow.assert_not_called()
+        wait_key.assert_not_called()
+        self.assertIsNotNone(env.r_vis)
+        self.assertIsNotNone(env.l_vis)
+        self.assertIsNotNone(env.top)
+        self.assertIsNotNone(env.angle)
+        self.assertIsNotNone(env.front)
+
+    def test_eval_main_headless_skips_render_and_still_executes_policy(self):
+        fake_env = _FakeEnv()
+        fake_agent = _FakeAgent()
+        cfg = OmegaConf.create(
+            {
+                "agent": {},
+                "eval": {
+                    "ckpt_path": "checkpoints/vla_model_best.pt",
+                    "num_episodes": 1,
+                    "max_timesteps": 1,
+                    "device": "cpu",
+                    "task_name": "sim_transfer",
+                    "camera_names": ["front"],
+                    "use_smoothing": False,
+                    "smooth_alpha": 0.3,
+                    "verbose_action": False,
+                    "headless": True,
+                },
+            }
+        )
+
+        with mock.patch.object(
+            eval_vla,
+            "load_checkpoint",
+            return_value=(fake_agent, None),
+        ), mock.patch.object(
+            eval_vla,
+            "make_sim_env",
+            return_value=fake_env,
+        ) as make_env, mock.patch.object(
+            eval_vla,
+            "sample_transfer_pose",
+            return_value=np.array([0.1, 0.2, 0.3]),
+        ), mock.patch.object(
+            eval_vla,
+            "execute_policy_action",
+        ) as execute_policy_action, mock.patch.object(
+            eval_vla,
+            "tqdm",
+            side_effect=lambda iterable, **kwargs: iterable,
+        ):
+            eval_vla.main.__wrapped__(cfg)
+
+        make_env.assert_called_once_with("sim_transfer", headless=True)
+        execute_policy_action.assert_called_once()
+        self.assertEqual(fake_env.image_obs_calls, 1)
+        self.assertEqual(fake_env.render_calls, 0)
+        self.assertIsNotNone(fake_agent.last_observation)
+        self.assertIn("front", fake_agent.last_observation["images"])
+
+    def test_run_eval_returns_average_reward_summary(self):
+        reward_sequences = [
+            [1.0, 2.0],
+            [0.5, 4.0],
+        ]
+        fake_env = _RewardTrackingEnv(reward_sequences)
+        fake_agent = _FakeAgent()
+        cfg = OmegaConf.create(
+            {
+                "agent": {},
+                "eval": {
+                    "ckpt_path": "checkpoints/vla_model_best.pt",
+                    "num_episodes": 2,
+                    "max_timesteps": 2,
+                    "device": "cpu",
+                    "task_name": "sim_transfer",
+                    "camera_names": ["front"],
+                    "use_smoothing": False,
+                    "smooth_alpha": 0.3,
+                    "verbose_action": False,
+                    "headless": True,
+                },
+            }
+        )
+
+        def fake_execute_policy_action(env, action):
+            del action
+            env.rew = env.reward_sequences[env.episode_index][env.step_index]
+            env.step_index += 1
+
+        with mock.patch.object(
+            eval_vla,
+            "load_checkpoint",
+            return_value=(fake_agent, None),
+        ), mock.patch.object(
+            eval_vla,
+            "make_sim_env",
+            return_value=fake_env,
+        ), mock.patch.object(
+            eval_vla,
+            "sample_transfer_pose",
+            return_value=np.array([0.1, 0.2, 0.3]),
+        ), mock.patch.object(
+            eval_vla,
+            "execute_policy_action",
+            side_effect=fake_execute_policy_action,
+        ), mock.patch.object(
+            eval_vla,
+            "tqdm",
+            side_effect=lambda iterable, **kwargs: iterable,
+        ):
+            summary = eval_vla._run_eval(cfg)
+
+        self.assertEqual(summary["episode_rewards"], [3.0, 4.5])
+        self.assertAlmostEqual(summary["avg_reward"], 3.75)
+        self.assertEqual(summary["num_episodes"], 2)
+
+
+if __name__ == "__main__":
+    unittest.main()
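For reference, the shape of the summary dict that _run_eval returns, as pinned by the last test (values illustrative):

    summary = {
        "episode_rewards": [3.0, 4.5],  # per-episode sums of env.rew over steps
        "avg_reward": 3.75,             # mean over episode_rewards
        "num_episodes": 2,
    }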
387 tests/test_resnet_transformer_agent_wiring.py Normal file
@@ -0,0 +1,387 @@
+import contextlib
+import sys
+import types
+import unittest
+from pathlib import Path
+
+import torch
+from hydra import compose, initialize_config_dir
+from hydra.errors import InstantiationException
+from hydra.core.global_hydra import GlobalHydra
+from hydra.utils import instantiate
+from omegaconf import OmegaConf
+
+
+_REPO_ROOT = Path(__file__).resolve().parents[1]
+_CONFIG_DIR = str((_REPO_ROOT / 'roboimi/vla/conf').resolve())
+_EXPECTED_CAMERA_NAMES = ['r_vis', 'top', 'front']
+_MISSING = object()
+
+
+class _FakeScheduler:
+    def __init__(self, num_train_timesteps=100, **kwargs):
+        self.config = types.SimpleNamespace(num_train_timesteps=num_train_timesteps)
+        self.timesteps = []
+
+    def add_noise(self, sample, noise, timestep):
+        return sample + noise
+
+    def set_timesteps(self, num_inference_steps):
+        self.timesteps = list(range(num_inference_steps - 1, -1, -1))
+
+    def step(self, noise_pred, timestep, sample):
+        return types.SimpleNamespace(prev_sample=sample)
+
+
+class _IdentityCrop:
+    def __init__(self, size):
+        self.size = size
+
+    def __call__(self, x):
+        return x
+
+
+class _FakeResNet(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1)
+        self.relu1 = torch.nn.ReLU()
+        self.conv2 = torch.nn.Conv2d(8, 16, kernel_size=3, padding=1, stride=2)
+        self.relu2 = torch.nn.ReLU()
+        self.avgpool = torch.nn.AdaptiveAvgPool2d((1, 1))
+        self.fc = torch.nn.Linear(16, 16)
+
+    def forward(self, x):
+        x = self.relu1(self.conv1(x))
+        x = self.relu2(self.conv2(x))
+        x = self.avgpool(x)
+        x = torch.flatten(x, start_dim=1)
+        return self.fc(x)
+
+
+class _FakeRearrange(torch.nn.Module):
+    def __init__(self, *args, **kwargs):
+        super().__init__()
+
+    def forward(self, x):
+        return x
+
+
+class _CondCapturingHead(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.last_cond = None
+
+    def forward(self, sample, timestep, cond):
+        self.last_cond = cond.detach().clone()
+        return torch.zeros_like(sample)
+
+
+@contextlib.contextmanager
+def _stub_optional_modules():
+    previous_modules = {}
+
+    def inject(name, module):
+        if name not in previous_modules:
+            previous_modules[name] = sys.modules.get(name, _MISSING)
+        sys.modules[name] = module
+
+    diffusers_module = types.ModuleType('diffusers')
+    schedulers_module = types.ModuleType('diffusers.schedulers')
+    ddpm_module = types.ModuleType('diffusers.schedulers.scheduling_ddpm')
+    ddim_module = types.ModuleType('diffusers.schedulers.scheduling_ddim')
+    ddpm_module.DDPMScheduler = _FakeScheduler
+    ddim_module.DDIMScheduler = _FakeScheduler
+    diffusers_module.DDPMScheduler = _FakeScheduler
+    diffusers_module.DDIMScheduler = _FakeScheduler
+    diffusers_module.schedulers = schedulers_module
+    schedulers_module.scheduling_ddpm = ddpm_module
+    schedulers_module.scheduling_ddim = ddim_module
+
+    torchvision_module = types.ModuleType('torchvision')
+    models_module = types.ModuleType('torchvision.models')
+    transforms_module = types.ModuleType('torchvision.transforms')
+    models_module.resnet18 = lambda weights=None: _FakeResNet()
+    transforms_module.CenterCrop = _IdentityCrop
+    transforms_module.RandomCrop = _IdentityCrop
+    torchvision_module.models = models_module
+    torchvision_module.transforms = transforms_module
+
+    einops_module = types.ModuleType('einops')
+    einops_module.rearrange = lambda x, *args, **kwargs: x
+    einops_layers_module = types.ModuleType('einops.layers')
+    einops_layers_torch_module = types.ModuleType('einops.layers.torch')
+    einops_layers_torch_module.Rearrange = _FakeRearrange
+    einops_module.layers = einops_layers_module
+    einops_layers_module.torch = einops_layers_torch_module
+
+    try:
+        inject('diffusers', diffusers_module)
+        inject('diffusers.schedulers', schedulers_module)
+        inject('diffusers.schedulers.scheduling_ddpm', ddpm_module)
+        inject('diffusers.schedulers.scheduling_ddim', ddim_module)
+        inject('torchvision', torchvision_module)
+        inject('torchvision.models', models_module)
+        inject('torchvision.transforms', transforms_module)
+        inject('einops', einops_module)
+        inject('einops.layers', einops_layers_module)
+        inject('einops.layers.torch', einops_layers_torch_module)
+        yield
+    finally:
+        for name, previous in reversed(list(previous_modules.items())):
+            if previous is _MISSING:
+                sys.modules.pop(name, None)
+            else:
+                sys.modules[name] = previous
+
+
+def _compose_cfg(overrides=None):
+    if not OmegaConf.has_resolver('len'):
+        OmegaConf.register_new_resolver('len', lambda x: len(x))
+
+    GlobalHydra.instance().clear()
+    with initialize_config_dir(version_base=None, config_dir=_CONFIG_DIR):
+        return compose(config_name='config', overrides=list(overrides or []))
+
+
+def _make_images(batch_size, obs_horizon, image_shape, per_camera_fill=None):
+    channels, height, width = image_shape
+    per_camera_fill = per_camera_fill or {
+        'front': 30.0,
+        'top': 20.0,
+        'r_vis': 10.0,
+    }
+    return {
+        name: torch.full(
+            (batch_size, obs_horizon, channels, height, width),
+            fill_value=fill_value,
+            dtype=torch.float32,
+        )
+        for name, fill_value in per_camera_fill.items()
+    }
+
+
+def _patch_backbone_for_order_tracking(backbone):
+    feature_dim = backbone.output_dim
+
+    def encode_mean(image_batch):
+        mean_feature = image_batch.mean(dim=(1, 2, 3)).unsqueeze(-1)
+        return mean_feature.repeat(1, feature_dim)
+
+    if backbone.use_separate_rgb_encoder_per_camera:
+        for encoder in backbone.rgb_encoder:
+            encoder.forward_single_image = encode_mean
+    else:
+        backbone.rgb_encoder.forward_single_image = encode_mean
+
+
+def _extract_camera_markers(cond, feature_dim, num_cams):
+    camera_block = cond[0, 0, : feature_dim * num_cams].view(num_cams, feature_dim)
+    return camera_block[:, 0]
+
+
+class ResNetTransformerAgentWiringTest(unittest.TestCase):
+    def test_hydra_wiring_uses_required_three_camera_transformer_conditioning_in_agent_order_and_ignores_extra_keys(self):
+        cfg = _compose_cfg(
+            overrides=[
+                'agent.vision_backbone.pretrained_backbone_weights=null',
+                'agent.vision_backbone.input_shape=[3,16,16]',
+                'agent.inference_steps=1',
+                'agent.head.n_layer=1',
+                'agent.head.n_cond_layers=0',
+                'agent.head.n_emb=32',
+                'agent.head.n_head=4',
+            ]
+        )
+
+        self.assertEqual(list(cfg.data.camera_names), _EXPECTED_CAMERA_NAMES)
+        self.assertEqual(list(cfg.eval.camera_names), _EXPECTED_CAMERA_NAMES)
+        self.assertEqual(list(cfg.agent.camera_names), _EXPECTED_CAMERA_NAMES)
+        self.assertEqual(list(cfg.agent.vision_backbone.camera_names), _EXPECTED_CAMERA_NAMES)
+        self.assertEqual(cfg.agent.head_type, 'transformer')
+        self.assertEqual(cfg.agent.num_cams, 3)
+        self.assertTrue(cfg.agent.head.obs_as_cond)
+        self.assertFalse(cfg.agent.head.causal_attn)
+
+        with _stub_optional_modules():
+            agent = instantiate(cfg.agent)
+            expected_cond_dim = agent.vision_encoder.output_dim * agent.num_cams + agent.obs_dim
+            self.assertEqual(cfg.agent.head.cond_dim, expected_cond_dim)
+            self.assertEqual(agent.per_step_cond_dim, expected_cond_dim)
+            self.assertEqual(agent.noise_pred_net.cond_obs_emb.in_features, expected_cond_dim)
+
+            batch_size = 2
+            image_shape = tuple(cfg.agent.vision_backbone.input_shape)
+            images = _make_images(
+                batch_size,
+                cfg.agent.obs_horizon,
+                image_shape,
+                per_camera_fill={
+                    'front': 30.0,
+                    'top': 20.0,
+                    'r_vis': 10.0,
+                    'left_wrist': 99.0,
+                },
+            )
+            proprioception = torch.randn(batch_size, cfg.agent.obs_horizon, cfg.agent.obs_dim)
+            _patch_backbone_for_order_tracking(agent.vision_encoder)
+            capturing_head = _CondCapturingHead()
+            agent.noise_pred_net = capturing_head
+            predicted_actions = agent.predict_action(images, proprioception)
+            self.assertEqual(
+                predicted_actions.shape,
+                (batch_size, cfg.agent.pred_horizon, cfg.agent.action_dim),
+            )
+            self.assertIsNotNone(capturing_head.last_cond)
+            self.assertEqual(capturing_head.last_cond.shape[-1], expected_cond_dim)
+            camera_markers = _extract_camera_markers(
+                capturing_head.last_cond,
+                agent.vision_encoder.output_dim,
+                agent.num_cams,
+            )
+            self.assertTrue(torch.allclose(camera_markers, torch.tensor([10.0, 20.0, 30.0])))
+
+            missing_images = dict(images)
+            missing_images.pop('top')
+            with self.assertRaisesRegex(ValueError, 'missing=.*top'):
+                agent.predict_action(missing_images, proprioception)
+
+    def test_agent_rejects_conflicting_explicit_backbone_camera_names(self):
+        cfg = _compose_cfg(
+            overrides=[
+                'agent.vision_backbone.pretrained_backbone_weights=null',
+                'agent.vision_backbone.input_shape=[3,16,16]',
+            ]
+        )
+        cfg.agent.vision_backbone.camera_names = ['front', 'top', 'r_vis']
+
+        with _stub_optional_modules():
+            with self.assertRaisesRegex(InstantiationException, 'camera_names'):
+                instantiate(cfg.agent)
+
+    def test_backbone_uses_sorted_fallback_order_when_camera_names_unset(self):
+        cfg = _compose_cfg(
+            overrides=[
+                'agent.vision_backbone.pretrained_backbone_weights=null',
+                'agent.vision_backbone.input_shape=[3,16,16]',
+            ]
+        )
+        cfg.agent.vision_backbone.camera_names = None
+
+        with _stub_optional_modules():
+            backbone = instantiate(cfg.agent.vision_backbone)
+            _patch_backbone_for_order_tracking(backbone)
+            images = _make_images(
+                batch_size=1,
+                obs_horizon=cfg.agent.obs_horizon,
+                image_shape=tuple(cfg.agent.vision_backbone.input_shape),
+                per_camera_fill={
+                    'top': 20.0,
+                    'front': 30.0,
+                    'r_vis': 10.0,
+                },
+            )
+            ordered_features = backbone(images)
+            camera_markers = _extract_camera_markers(
+                ordered_features,
+                backbone.output_dim,
+                len(images),
+            )
+            self.assertTrue(torch.allclose(camera_markers, torch.tensor([30.0, 10.0, 20.0])))
+
+    def test_agent_queue_fallback_order_is_deterministic_when_camera_names_unset(self):
+        cfg = _compose_cfg(
+            overrides=[
+                'agent.vision_backbone.pretrained_backbone_weights=null',
+                'agent.vision_backbone.input_shape=[3,16,16]',
+            ]
+        )
+        cfg.agent.camera_names = None
+        cfg.agent.vision_backbone.camera_names = None
+
+        with _stub_optional_modules():
+            agent = instantiate(cfg.agent)
+            observation = {
+                'qpos': torch.randn(cfg.agent.obs_dim),
+                'images': {
+                    'top': torch.full(tuple(cfg.agent.vision_backbone.input_shape), 20.0),
+                    'front': torch.full(tuple(cfg.agent.vision_backbone.input_shape), 30.0),
+                    'r_vis': torch.full(tuple(cfg.agent.vision_backbone.input_shape), 10.0),
+                },
+            }
+            agent._populate_queues(observation)
+            batch = agent._prepare_observation_batch()
+            self.assertEqual(list(batch['images'].keys()), ['front', 'r_vis', 'top'])
+
+    def test_backbone_rejects_camera_count_mismatch_when_camera_names_unset(self):
+        cfg = _compose_cfg(
+            overrides=[
+                'agent.vision_backbone.pretrained_backbone_weights=null',
+                'agent.vision_backbone.input_shape=[3,16,16]',
+            ]
+        )
+        cfg.agent.vision_backbone.camera_names = None
+
+        with _stub_optional_modules():
+            backbone = instantiate(cfg.agent.vision_backbone)
+            images = _make_images(
+                batch_size=1,
+                obs_horizon=cfg.agent.obs_horizon,
+                image_shape=tuple(cfg.agent.vision_backbone.input_shape),
+                per_camera_fill={
+                    'front': 30.0,
+                    'r_vis': 10.0,
+                },
+            )
+            with self.assertRaisesRegex(ValueError, 'num_cameras'):
+                backbone(images)
+
+    def test_agent_rejects_camera_count_mismatch_when_camera_names_unset(self):
+        cfg = _compose_cfg(
+            overrides=[
+                'agent.vision_backbone.pretrained_backbone_weights=null',
+                'agent.vision_backbone.input_shape=[3,16,16]',
+                'agent.inference_steps=1',
+                'agent.head.n_layer=1',
+                'agent.head.n_cond_layers=0',
+                'agent.head.n_emb=32',
+                'agent.head.n_head=4',
+            ]
+        )
+        cfg.agent.camera_names = None
+        cfg.agent.vision_backbone.camera_names = None
+
+        with _stub_optional_modules():
+            agent = instantiate(cfg.agent)
+            images = _make_images(
+                batch_size=1,
+                obs_horizon=cfg.agent.obs_horizon,
+                image_shape=tuple(cfg.agent.vision_backbone.input_shape),
+                per_camera_fill={
+                    'front': 30.0,
+                    'r_vis': 10.0,
+                },
+            )
+            proprioception = torch.randn(1, cfg.agent.obs_horizon, cfg.agent.obs_dim)
+            with self.assertRaisesRegex(ValueError, 'num_cams'):
+                agent.predict_action(images, proprioception)
+
+    def test_agent_rejects_num_cams_mismatch_with_backbone_when_camera_names_unset(self):
+        cfg = _compose_cfg(
+            overrides=[
+                'agent.vision_backbone.pretrained_backbone_weights=null',
+                'agent.vision_backbone.input_shape=[3,16,16]',
+            ]
+        )
+        cfg.agent.camera_names = None
+        cfg.agent.vision_backbone.camera_names = None
+        cfg.agent.num_cams = 2
+        cfg.agent.vision_backbone.num_cameras = 3
+
+        with _stub_optional_modules():
+            with self.assertRaisesRegex(InstantiationException, 'num_cams'):
+                instantiate(cfg.agent)
+
+
+if __name__ == '__main__':
+    unittest.main()
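The conditioning layout these wiring tests enforce, sketched with illustrative dimensions: camera features are concatenated in cfg.agent.camera_names order, then proprioception is appended, and unknown image keys are ignored.

    # cond[b, t] = [feat(r_vis) | feat(top) | feat(front) | qpos]
    feature_dim, num_cams, obs_dim = 16, 3, 16  # illustrative values
    per_step_cond_dim = feature_dim * num_cams + obs_dim  # = 64; must equal head.cond_dim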
63 tests/test_robot_asset_paths.py Normal file
@@ -0,0 +1,63 @@
+import os
+import tempfile
+import unittest
+from pathlib import Path
+from unittest import mock
+
+from roboimi.assets.robots.diana_med import BiDianaMed
+
+
+class _FakeKDL:
+    init_calls = []
+    reset_calls = []
+
+    def __init__(self, urdf_path):
+        self.__class__.init_calls.append(urdf_path)
+
+    def resetChain(self, base, end):
+        self.__class__.reset_calls.append((base, end))
+
+
+class RobotAssetPathResolutionTest(unittest.TestCase):
+    def setUp(self):
+        _FakeKDL.init_calls = []
+        _FakeKDL.reset_calls = []
+
+    def test_bidianamed_resolves_robot_asset_paths_independent_of_cwd(self):
+        repo_root = Path(__file__).resolve().parents[1]
+        expected_xml = repo_root / 'roboimi/assets/models/manipulators/DianaMed/bi_diana_transfer_ee.xml'
+        expected_urdf = repo_root / 'roboimi/assets/models/manipulators/DianaMed/DualDianaMed.urdf'
+        xml_calls = []
+
+        def fake_from_xml_path(*, filename, assets=None):
+            xml_calls.append((filename, assets))
+            return object()
+
+        with tempfile.TemporaryDirectory() as tempdir:
+            previous_cwd = os.getcwd()
+            try:
+                os.chdir(tempdir)
+                with mock.patch(
+                    'roboimi.assets.robots.arm_base.mujoco.MjModel.from_xml_path',
+                    side_effect=fake_from_xml_path,
+                ), mock.patch(
+                    'roboimi.assets.robots.arm_base.mujoco.MjData',
+                    return_value=object(),
+                ), mock.patch(
+                    'roboimi.assets.robots.arm_base.KDL_utils',
+                    _FakeKDL,
+                ):
+                    BiDianaMed()
+            finally:
+                os.chdir(previous_cwd)
+
+        self.assertEqual(len(xml_calls), 1)
+        self.assertEqual(Path(xml_calls[0][0]), expected_xml)
+        self.assertTrue(Path(xml_calls[0][0]).is_absolute())
+        self.assertGreaterEqual(len(_FakeKDL.init_calls), 2)
+        self.assertEqual({Path(path) for path in _FakeKDL.init_calls}, {expected_urdf})
+        self.assertTrue(all(Path(path).is_absolute() for path in _FakeKDL.init_calls))
+
+
+if __name__ == '__main__':
+    unittest.main()
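The resolution pattern this test enforces, sketched (the anchoring module below is an assumption): asset files are resolved relative to the package that ships them, never relative to os.getcwd().

    from pathlib import Path

    # e.g. inside roboimi/assets/robots/diana_med.py (hypothetical placement):
    _ASSET_DIR = Path(__file__).resolve().parent.parent / "models" / "manipulators" / "DianaMed"
    XML_PATH = _ASSET_DIR / "bi_diana_transfer_ee.xml"   # absolute regardless of cwd
    URDF_PATH = _ASSET_DIR / "DualDianaMed.urdf"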
58 tests/test_simple_robot_dataset_image_loading.py Normal file
@@ -0,0 +1,58 @@
+import sys
+import tempfile
+import types
+import unittest
+from pathlib import Path
+from unittest import mock
+
+import h5py
+import numpy as np
+
+from roboimi.vla.data.simpe_robot_dataset import SimpleRobotDataset
+
+
+class SimpleRobotDatasetImageLoadingTest(unittest.TestCase):
+    def _write_episode(self, dataset_dir: Path) -> None:
+        episode_path = dataset_dir / "episode_0.hdf5"
+        with h5py.File(episode_path, "w") as root:
+            root.create_dataset("action", data=np.arange(8, dtype=np.float32).reshape(4, 2))
+            root.create_dataset(
+                "observations/qpos",
+                data=np.arange(16, dtype=np.float32).reshape(4, 4),
+            )
+            root.create_dataset("task", data=np.array([b"sim_transfer"]))
+            root.create_dataset(
+                "observations/images/front",
+                data=np.arange(4 * 8 * 8 * 3, dtype=np.uint8).reshape(4, 8, 8, 3),
+            )
+
+    def test_getitem_only_resizes_observation_horizon_images(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            dataset_dir = Path(tmpdir)
+            self._write_episode(dataset_dir)
+            dataset = SimpleRobotDataset(
+                dataset_dir,
+                obs_horizon=2,
+                pred_horizon=3,
+                camera_names=["front"],
+            )
+
+            resize_calls = []
+
+            def fake_resize(image, size, interpolation=None):
+                resize_calls.append(
+                    {
+                        "shape": tuple(image.shape),
+                        "size": size,
+                        "interpolation": interpolation,
+                    }
+                )
+                return image
+
+            fake_cv2 = types.SimpleNamespace(INTER_LINEAR=1, resize=fake_resize)
+
+            with mock.patch.dict(sys.modules, {"cv2": fake_cv2}):
+                sample = dataset[1]
+
+        self.assertEqual(len(resize_calls), 2)
+        self.assertEqual(tuple(sample["observation.front"].shape), (2, 3, 8, 8))
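What this dataset test pins down, sketched with the fixture's numbers: only the obs_horizon frames that feed the policy are resized; the pred_horizon action targets are untouched. Index math below is illustrative.

    import numpy as np

    episode_images = np.zeros((4, 8, 8, 3), dtype=np.uint8)  # 4 timesteps of HWC frames
    obs_horizon, start = 2, 1                                 # sample index 1, horizon 2
    obs_frames = episode_images[start : start + obs_horizon]  # the only frames resized
    assert obs_frames.shape[0] == 2                           # matches len(resize_calls)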
779 tests/test_train_vla_rollout_validation.py Normal file
@@ -0,0 +1,779 @@
import os
|
||||||
|
import tempfile
|
||||||
|
import unittest
|
||||||
|
from copy import deepcopy
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from roboimi.demos.vla_scripts import eval_vla, train_vla
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeDataset:
|
||||||
|
def __len__(self):
|
||||||
|
return 4
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeLoader:
|
||||||
|
def __init__(self, batch, length=1):
|
||||||
|
self._batches = [batch] * length
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self._batches)
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
return iter(self._batches)
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeOptimizer:
|
||||||
|
def __init__(self, lr=1e-3):
|
||||||
|
self.param_groups = [{'lr': lr}]
|
||||||
|
|
||||||
|
def zero_grad(self):
|
||||||
|
return None
|
||||||
|
|
||||||
|
def step(self):
|
||||||
|
return None
|
||||||
|
|
||||||
|
def state_dict(self):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def load_state_dict(self, state_dict):
|
||||||
|
del state_dict
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeScheduler:
|
||||||
|
def __init__(self):
|
||||||
|
self.step_calls = 0
|
||||||
|
|
||||||
|
def step(self):
|
||||||
|
self.step_calls += 1
|
||||||
|
|
||||||
|
def state_dict(self):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def load_state_dict(self, state_dict):
|
||||||
|
del state_dict
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeProgressBar:
|
||||||
|
def __init__(self, iterable):
|
||||||
|
self._items = list(iterable)
|
||||||
|
self.postfix_calls = []
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
return iter(self._items)
|
||||||
|
|
||||||
|
def set_postfix(self, values):
|
||||||
|
self.postfix_calls.append(values)
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeAgent(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.weight = nn.Parameter(torch.tensor(0.0))
|
||||||
|
|
||||||
|
def to(self, device):
|
||||||
|
del device
|
||||||
|
return self
|
||||||
|
|
||||||
|
def compute_loss(self, agent_input):
|
||||||
|
del agent_input
|
||||||
|
return (self.weight - torch.tensor(0.5)).pow(2)
|
||||||
|
|
||||||
|
def get_normalization_stats(self):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
class _SequentialLossAgent(nn.Module):
|
||||||
|
def __init__(self, losses):
|
||||||
|
super().__init__()
|
||||||
|
self.weight = nn.Parameter(torch.tensor(0.0))
|
||||||
|
self._losses = list(losses)
|
||||||
|
self._index = 0
|
||||||
|
|
||||||
|
def to(self, device):
|
||||||
|
del device
|
||||||
|
return self
|
||||||
|
|
||||||
|
def compute_loss(self, agent_input):
|
||||||
|
del agent_input
|
||||||
|
loss_value = self._losses[self._index]
|
||||||
|
self._index += 1
|
||||||
|
return (self.weight * 0) + torch.tensor(float(loss_value))
|
||||||
|
|
||||||
|
def get_normalization_stats(self):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeEvalAgent:
|
||||||
|
def __init__(self):
|
||||||
|
self.reset_calls = 0
|
||||||
|
|
||||||
|
def eval(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def to(self, device):
|
||||||
|
del device
|
||||||
|
return self
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self.reset_calls += 1
|
||||||
|
|
||||||
|
def select_action(self, observation):
|
||||||
|
del observation
|
||||||
|
return torch.zeros(2)
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeEvalEnv:
|
||||||
|
def reset(self, box_pos):
|
||||||
|
self.box_pos = box_pos
|
||||||
|
|
||||||
|
def _get_image_obs(self):
|
||||||
|
return {
|
||||||
|
'images': {
|
||||||
|
'front': np.zeros((8, 8, 3), dtype=np.uint8),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
def _get_qpos_obs(self):
|
||||||
|
return {'qpos': np.zeros(4, dtype=np.float32)}
|
||||||
|
|
||||||
|
def render(self):
|
||||||
|
raise AssertionError('render should not be called in this helper delegation test')
|
||||||
|
|
||||||
|
|
||||||
|
class TrainVLARolloutValidationTest(unittest.TestCase):
|
||||||
|
def test_default_train_config_uses_full_dataset_and_epoch_rollout_validation(self):
|
||||||
|
cfg = OmegaConf.load(Path('roboimi/vla/conf/config.yaml'))
|
||||||
|
|
||||||
|
self.assertEqual(cfg.train.val_split, 0.0)
|
||||||
|
self.assertGreater(cfg.train.batch_size, 8)
|
||||||
|
self.assertGreater(float(cfg.train.lr), 5e-5)
|
||||||
|
self.assertGreater(cfg.train.num_workers, 8)
|
||||||
|
self.assertEqual(cfg.train.rollout_val_freq_epochs, 50)
|
||||||
|
|
||||||
|
def test_eval_main_delegates_to_plain_run_eval_helper(self):
|
||||||
|
cfg = OmegaConf.create(
|
||||||
|
{
|
||||||
|
'agent': {},
|
||||||
|
'eval': {
|
||||||
|
'ckpt_path': 'checkpoints/vla_model_step_1.pt',
|
||||||
|
'num_episodes': 1,
|
||||||
|
'max_timesteps': 1,
|
||||||
|
'device': 'cpu',
|
||||||
|
'task_name': 'sim_transfer',
|
||||||
|
'camera_names': ['front'],
|
||||||
|
'use_smoothing': False,
|
||||||
|
'smooth_alpha': 0.3,
|
||||||
|
'verbose_action': False,
|
||||||
|
'headless': True,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
run_eval_mock = mock.Mock()
|
||||||
|
|
||||||
|
with mock.patch.object(eval_vla, '_run_eval', run_eval_mock, create=True), \
|
||||||
|
mock.patch.object(eval_vla, 'load_checkpoint', return_value=(_FakeEvalAgent(), None)), \
|
||||||
|
mock.patch.object(eval_vla, 'make_sim_env', return_value=_FakeEvalEnv()), \
|
||||||
|
mock.patch.object(eval_vla, 'sample_transfer_pose', return_value=np.zeros(3)), \
|
||||||
|
mock.patch.object(eval_vla, 'execute_policy_action'), \
|
||||||
|
mock.patch.object(eval_vla, 'tqdm', side_effect=lambda iterable, **kwargs: iterable):
|
||||||
|
eval_vla.main.__wrapped__(cfg)
|
||||||
|
|
||||||
|
run_eval_mock.assert_called_once_with(cfg)
|
||||||
|
|
||||||
|
def test_run_training_rollout_validation_runs_every_50_epochs_and_uses_avg_reward_metric(self):
|
||||||
|
cfg = OmegaConf.create(
|
||||||
|
{
|
||||||
|
'train': {
|
||||||
|
'device': 'cpu',
|
||||||
|
'batch_size': 1,
|
||||||
|
'num_workers': 0,
|
||||||
|
'val_split': 0.0,
|
||||||
|
'seed': 0,
|
||||||
|
'lr': 1e-3,
|
||||||
|
'max_steps': 100,
|
||||||
|
'log_freq': 1,
|
||||||
|
'save_freq': 1000,
|
||||||
|
'warmup_steps': 1,
|
||||||
|
'scheduler_type': 'constant',
|
||||||
|
'min_lr': 0.0,
|
||||||
|
'grad_clip': 1.0,
|
||||||
|
'weight_decay': 0.0,
|
||||||
|
'pretrained_ckpt': None,
|
||||||
|
'resume_ckpt': None,
|
||||||
|
'use_swanlab': False,
|
||||||
|
'rollout_val_freq_epochs': 50,
|
||||||
|
'rollout_num_episodes': 3,
|
||||||
|
},
|
||||||
|
'data': {
|
||||||
|
'camera_names': ['front'],
|
||||||
|
},
|
||||||
|
'agent': {
|
||||||
|
'_target_': 'fake.agent',
|
||||||
|
},
|
||||||
|
'eval': {
|
||||||
|
'ckpt_path': 'unused.pt',
|
||||||
|
'num_episodes': 99,
|
||||||
|
'max_timesteps': 1,
|
||||||
|
'device': 'cpu',
|
||||||
|
'task_name': 'sim_transfer',
|
||||||
|
'camera_names': ['front'],
|
||||||
|
'use_smoothing': False,
|
||||||
|
'smooth_alpha': 0.3,
|
||||||
|
'verbose_action': False,
|
||||||
|
'headless': False,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
agent = _FakeAgent()
|
||||||
|
rollout_mock = mock.Mock(side_effect=[{'avg_reward': 2.0}, {'avg_reward': 1.0}])
|
||||||
|
swanlab_log_mock = mock.Mock()
|
||||||
|
saved_checkpoints = []
|
||||||
|
|
||||||
|
def fake_instantiate(config_node, **_kwargs):
|
||||||
|
if config_node is cfg.data:
|
||||||
|
return _FakeDataset()
|
||||||
|
if config_node is cfg.agent:
|
||||||
|
return agent
|
||||||
|
raise AssertionError(f'unexpected instantiate config: {config_node!r}')
|
||||||
|
|
||||||
|
def fake_dataloader(_dataset, *, shuffle, **_kwargs):
|
||||||
|
del shuffle, _kwargs
|
||||||
|
return _FakeLoader(
|
||||||
|
{
|
||||||
|
'observation.front': torch.zeros(1, 3, 2, 2),
|
||||||
|
'observation.state': torch.zeros(1, 4),
|
||||||
|
'action': torch.zeros(1, 2),
|
||||||
|
'action_is_pad': torch.zeros(1, 1, dtype=torch.bool),
|
||||||
|
},
|
||||||
|
length=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
def fake_torch_save(payload, path):
|
||||||
|
saved_checkpoints.append((str(path), deepcopy(payload)))
|
||||||
|
return None
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tempdir:
|
||||||
|
previous_cwd = os.getcwd()
|
||||||
|
try:
|
||||||
|
os.chdir(tempdir)
|
||||||
|
with mock.patch.object(train_vla, 'instantiate', side_effect=fake_instantiate), \
|
||||||
|
mock.patch.object(train_vla, 'DataLoader', side_effect=fake_dataloader), \
|
||||||
|
mock.patch.object(train_vla, 'build_training_optimizer', return_value=_FakeOptimizer(cfg.train.lr)), \
|
||||||
|
mock.patch.object(train_vla, 'get_lr_schedule_with_warmup', return_value=_FakeScheduler()), \
|
||||||
|
mock.patch.object(train_vla, 'tqdm', side_effect=lambda iterable, **kwargs: _FakeProgressBar(iterable)), \
|
||||||
|
mock.patch.object(train_vla, '_log_to_swanlab', swanlab_log_mock), \
|
||||||
|
mock.patch.object(train_vla.torch, 'save', side_effect=fake_torch_save), \
|
||||||
|
mock.patch.object(eval_vla, '_run_eval', rollout_mock, create=True), \
|
||||||
|
mock.patch.object(eval_vla.main, '__wrapped__', side_effect=AssertionError('training hook should call eval_vla._run_eval')):
|
||||||
|
train_vla._run_training(cfg)
|
||||||
|
finally:
|
||||||
|
os.chdir(previous_cwd)
|
||||||
|
|
||||||
|
self.assertEqual(rollout_mock.call_count, 2)
|
||||||
|
first_rollout_cfg = rollout_mock.call_args_list[0].args[0]
|
||||||
|
second_rollout_cfg = rollout_mock.call_args_list[1].args[0]
|
||||||
|
self.assertEqual(first_rollout_cfg.eval.ckpt_path, 'checkpoints/vla_model_step_49.pt')
|
||||||
|
self.assertEqual(second_rollout_cfg.eval.ckpt_path, 'checkpoints/vla_model_step_99.pt')
|
||||||
|
self.assertEqual(first_rollout_cfg.eval.num_episodes, 3)
|
||||||
|
self.assertTrue(first_rollout_cfg.eval.headless)
|
||||||
|
self.assertEqual(first_rollout_cfg.eval.device, 'cpu')
|
||||||
|
self.assertFalse(first_rollout_cfg.eval.verbose_action)
|
||||||
|
self.assertEqual(cfg.eval.ckpt_path, 'unused.pt')
|
||||||
|
self.assertEqual(cfg.eval.num_episodes, 99)
|
||||||
|
self.assertFalse(cfg.eval.headless)
|
||||||
|
self.assertEqual(cfg.eval.device, 'cpu')
|
||||||
|
self.assertFalse(cfg.eval.verbose_action)
|
||||||
|
|
||||||
|
rollout_reward_logs = [
|
||||||
|
call.args[1]['rollout/avg_reward']
|
||||||
|
for call in swanlab_log_mock.call_args_list
|
||||||
|
if len(call.args) >= 2 and 'rollout/avg_reward' in call.args[1]
|
||||||
|
]
|
||||||
|
self.assertEqual(rollout_reward_logs, [2.0, 1.0])
|
||||||
|
|
||||||
|
best_model_saves = [
|
||||||
|
payload for path, payload in saved_checkpoints
|
||||||
|
if path.endswith('checkpoints/vla_model_best.pt')
|
||||||
|
]
|
||||||
|
self.assertEqual(len(best_model_saves), 1)
|
||||||
|
self.assertEqual(best_model_saves[0]['rollout_avg_reward'], 2.0)
|
||||||
|
|
||||||
|
def test_run_training_keeps_loss_based_best_checkpoint_until_first_rollout_metric_exists(self):
|
||||||
|
cfg = OmegaConf.create(
|
||||||
|
{
|
||||||
|
'train': {
|
||||||
|
'device': 'cpu',
|
||||||
|
'batch_size': 1,
|
||||||
|
'num_workers': 0,
|
||||||
|
'val_split': 0.0,
|
||||||
|
'seed': 0,
|
||||||
|
'lr': 1e-3,
|
||||||
|
'max_steps': 5,
|
||||||
|
'log_freq': 1,
|
||||||
|
'save_freq': 2,
|
||||||
|
'warmup_steps': 1,
|
||||||
|
'scheduler_type': 'constant',
|
||||||
|
'min_lr': 0.0,
|
||||||
|
'grad_clip': 1.0,
|
||||||
|
'weight_decay': 0.0,
|
||||||
|
'pretrained_ckpt': None,
|
||||||
|
'resume_ckpt': None,
|
||||||
|
'use_swanlab': False,
|
||||||
|
'rollout_val_freq_epochs': 50,
|
||||||
|
'rollout_num_episodes': 3,
|
||||||
|
},
|
||||||
|
'data': {
|
||||||
|
'camera_names': ['front'],
|
||||||
|
},
|
||||||
|
'agent': {
|
||||||
|
'_target_': 'fake.agent',
|
||||||
|
},
|
||||||
|
'eval': {
|
||||||
|
'ckpt_path': 'unused.pt',
|
||||||
|
'num_episodes': 99,
|
||||||
|
'max_timesteps': 1,
|
||||||
|
'device': 'cpu',
|
||||||
|
'task_name': 'sim_transfer',
|
||||||
|
'camera_names': ['front'],
|
||||||
|
'use_smoothing': False,
|
||||||
|
'smooth_alpha': 0.3,
|
||||||
|
'verbose_action': False,
|
||||||
|
'headless': False,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
saved_checkpoints = []
|
||||||
|
rollout_mock = mock.Mock()
|
||||||
|
|
||||||
|
def fake_instantiate(config_node, **_kwargs):
|
||||||
|
if config_node is cfg.data:
|
||||||
|
return _FakeDataset()
|
||||||
|
if config_node is cfg.agent:
|
||||||
|
return _FakeAgent()
|
||||||
|
raise AssertionError(f'unexpected instantiate config: {config_node!r}')
|
||||||
|
|
||||||
|
def fake_dataloader(_dataset, *, shuffle, **_kwargs):
|
||||||
|
del shuffle, _kwargs
|
||||||
|
return _FakeLoader(
|
||||||
|
{
|
||||||
|
'observation.front': torch.zeros(1, 3, 2, 2),
|
||||||
|
'observation.state': torch.zeros(1, 4),
|
||||||
|
'action': torch.zeros(1, 2),
|
||||||
|
'action_is_pad': torch.zeros(1, 1, dtype=torch.bool),
|
||||||
|
},
|
||||||
|
length=5,
|
||||||
|
)
|
||||||
|
|
||||||
|
def fake_torch_save(payload, path):
|
||||||
|
saved_checkpoints.append((str(path), deepcopy(payload)))
|
||||||
|
return None
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tempdir:
|
||||||
|
previous_cwd = os.getcwd()
|
||||||
|
try:
|
||||||
|
os.chdir(tempdir)
|
||||||
|
with mock.patch.object(train_vla, 'instantiate', side_effect=fake_instantiate), \
|
||||||
|
mock.patch.object(train_vla, 'DataLoader', side_effect=fake_dataloader), \
|
||||||
|
mock.patch.object(train_vla, 'build_training_optimizer', return_value=_FakeOptimizer(cfg.train.lr)), \
|
||||||
|
mock.patch.object(train_vla, 'get_lr_schedule_with_warmup', return_value=_FakeScheduler()), \
|
||||||
|
mock.patch.object(train_vla, 'tqdm', side_effect=lambda iterable, **kwargs: _FakeProgressBar(iterable)), \
|
||||||
|
mock.patch.object(train_vla.torch, 'save', side_effect=fake_torch_save), \
|
||||||
|
mock.patch.object(eval_vla, '_run_eval', rollout_mock, create=True):
|
||||||
|
train_vla._run_training(cfg)
|
||||||
|
finally:
|
||||||
|
os.chdir(previous_cwd)
|
||||||
|
|
||||||
|
self.assertEqual(rollout_mock.call_count, 0)
|
||||||
|
best_model_saves = [
|
||||||
|
payload for path, payload in saved_checkpoints
|
||||||
|
if path.endswith('checkpoints/vla_model_best.pt')
|
||||||
|
]
|
||||||
|
self.assertEqual(len(best_model_saves), 1)
|
||||||
|
self.assertIsNone(best_model_saves[0]['rollout_avg_reward'])
|
||||||
|
|
||||||
|
def test_run_training_disables_drop_last_when_train_set_is_smaller_than_batch_size(self):
|
||||||
|
cfg = OmegaConf.create(
|
||||||
|
{
|
||||||
|
'train': {
|
||||||
|
'device': 'cpu',
|
||||||
|
'batch_size': 8,
|
||||||
|
'num_workers': 0,
|
||||||
|
'val_split': 0.0,
|
||||||
|
'seed': 0,
|
||||||
|
'lr': 1e-3,
|
||||||
|
'max_steps': 1,
|
||||||
|
'log_freq': 1,
|
||||||
|
'save_freq': 10,
|
||||||
|
'warmup_steps': 1,
|
||||||
|
'scheduler_type': 'constant',
|
||||||
|
'min_lr': 0.0,
|
||||||
|
'grad_clip': 1.0,
|
||||||
|
'weight_decay': 0.0,
|
||||||
|
'pretrained_ckpt': None,
|
||||||
|
'resume_ckpt': None,
|
||||||
|
'use_swanlab': False,
|
||||||
|
'rollout_val_freq_epochs': 50,
|
||||||
|
'rollout_num_episodes': 3,
|
||||||
|
},
|
||||||
|
'data': {
|
||||||
|
'camera_names': ['front'],
|
||||||
|
},
|
||||||
|
'agent': {
|
||||||
|
'_target_': 'fake.agent',
|
||||||
|
},
|
||||||
|
'eval': {
|
||||||
|
'ckpt_path': 'unused.pt',
|
||||||
|
'num_episodes': 99,
|
||||||
|
'max_timesteps': 1,
|
||||||
|
'device': 'cpu',
|
||||||
|
'task_name': 'sim_transfer',
|
||||||
|
'camera_names': ['front'],
|
||||||
|
'use_smoothing': False,
|
||||||
|
'smooth_alpha': 0.3,
|
||||||
|
'verbose_action': False,
|
||||||
|
'headless': False,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
dataloader_calls = []
|
||||||
|
|
||||||
|
def fake_instantiate(config_node, **_kwargs):
|
||||||
|
if config_node is cfg.data:
|
||||||
|
return _FakeDataset()
|
||||||
|
if config_node is cfg.agent:
|
||||||
|
return _FakeAgent()
|
||||||
|
raise AssertionError(f'unexpected instantiate config: {config_node!r}')
|
||||||
|
|
||||||
|
def fake_dataloader(dataset, *, shuffle, drop_last, **_kwargs):
|
||||||
|
dataloader_calls.append({
|
||||||
|
'shuffle': shuffle,
|
||||||
|
'drop_last': drop_last,
|
||||||
|
'dataset_len': len(dataset),
|
||||||
|
})
|
||||||
|
return _FakeLoader(
|
||||||
|
{
|
||||||
|
'observation.front': torch.zeros(1, 3, 2, 2),
|
||||||
|
'observation.state': torch.zeros(1, 4),
|
||||||
|
'action': torch.zeros(1, 2),
|
||||||
|
'action_is_pad': torch.zeros(1, 1, dtype=torch.bool),
|
||||||
|
},
|
||||||
|
length=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tempdir:
|
||||||
|
previous_cwd = os.getcwd()
|
||||||
|
try:
|
||||||
|
os.chdir(tempdir)
|
||||||
|
with mock.patch.object(train_vla, 'instantiate', side_effect=fake_instantiate), \
|
||||||
|
mock.patch.object(train_vla, 'DataLoader', side_effect=fake_dataloader), \
|
||||||
|
mock.patch.object(train_vla, 'build_training_optimizer', return_value=_FakeOptimizer(cfg.train.lr)), \
|
||||||
|
mock.patch.object(train_vla, 'get_lr_schedule_with_warmup', return_value=_FakeScheduler()), \
|
||||||
|
mock.patch.object(train_vla, 'tqdm', side_effect=lambda iterable, **kwargs: _FakeProgressBar(iterable)), \
|
||||||
|
mock.patch.object(train_vla.torch, 'save', return_value=None):
|
||||||
|
train_vla._run_training(cfg)
|
||||||
|
finally:
|
||||||
|
os.chdir(previous_cwd)
|
||||||
|
|
||||||
|
train_loader_calls = [call for call in dataloader_calls if call['shuffle']]
|
||||||
|
self.assertEqual(len(train_loader_calls), 1)
|
||||||
|
self.assertFalse(train_loader_calls[0]['drop_last'])
|
||||||
|
|
||||||
|
def test_run_training_disables_persistent_workers_for_train_and_val_loaders(self):
|
||||||
|
cfg = OmegaConf.create(
|
||||||
|
{
|
||||||
|
'train': {
|
||||||
|
'device': 'cpu',
|
||||||
|
'batch_size': 2,
|
||||||
|
'num_workers': 2,
|
||||||
|
'val_split': 0.25,
|
||||||
|
'seed': 0,
|
||||||
|
'lr': 1e-3,
|
||||||
|
'max_steps': 1,
|
||||||
|
'log_freq': 1,
|
||||||
|
'save_freq': 10,
|
||||||
|
'warmup_steps': 1,
|
||||||
|
'scheduler_type': 'constant',
|
||||||
|
'min_lr': 0.0,
|
||||||
|
'grad_clip': 1.0,
|
||||||
|
'weight_decay': 0.0,
|
||||||
|
'pretrained_ckpt': None,
|
||||||
|
'resume_ckpt': None,
|
||||||
|
'use_swanlab': False,
|
||||||
|
'rollout_val_freq_epochs': 50,
|
||||||
|
'rollout_num_episodes': 3,
|
||||||
|
},
|
||||||
|
'data': {
|
||||||
|
'camera_names': ['front'],
|
||||||
|
},
|
||||||
|
'agent': {
|
||||||
|
'_target_': 'fake.agent',
|
||||||
|
},
|
||||||
|
'eval': {
|
||||||
|
'ckpt_path': 'unused.pt',
|
||||||
|
'num_episodes': 99,
|
||||||
|
'max_timesteps': 1,
|
||||||
|
'device': 'cpu',
|
||||||
|
'task_name': 'sim_transfer',
|
||||||
|
'camera_names': ['front'],
|
||||||
|
'use_smoothing': False,
|
||||||
|
'smooth_alpha': 0.3,
|
||||||
|
'verbose_action': False,
|
||||||
|
'headless': False,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
dataloader_calls = []
|
||||||
|
|
||||||
|
def fake_instantiate(config_node, **_kwargs):
|
||||||
|
if config_node is cfg.data:
|
||||||
|
return _FakeDataset()
|
||||||
|
if config_node is cfg.agent:
|
||||||
|
return _FakeAgent()
|
||||||
|
raise AssertionError(f'unexpected instantiate config: {config_node!r}')
|
||||||
|
|
||||||
|
def fake_dataloader(_dataset, *, shuffle, persistent_workers, num_workers, **_kwargs):
|
||||||
|
dataloader_calls.append({
|
||||||
|
'shuffle': shuffle,
|
||||||
|
'num_workers': num_workers,
|
||||||
|
'persistent_workers': persistent_workers,
|
||||||
|
})
|
||||||
|
return _FakeLoader(
|
||||||
|
{
|
||||||
|
'observation.front': torch.zeros(1, 3, 2, 2),
|
||||||
|
'observation.state': torch.zeros(1, 4),
|
||||||
|
'action': torch.zeros(1, 2),
|
||||||
|
'action_is_pad': torch.zeros(1, 1, dtype=torch.bool),
|
||||||
|
},
|
||||||
|
length=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tempdir:
|
||||||
|
previous_cwd = os.getcwd()
|
||||||
|
try:
|
||||||
|
os.chdir(tempdir)
|
||||||
|
with mock.patch.object(train_vla, 'instantiate', side_effect=fake_instantiate), \
|
||||||
|
mock.patch.object(train_vla, 'DataLoader', side_effect=fake_dataloader), \
|
||||||
|
mock.patch.object(train_vla, 'build_training_optimizer', return_value=_FakeOptimizer(cfg.train.lr)), \
|
||||||
|
mock.patch.object(train_vla, 'get_lr_schedule_with_warmup', return_value=_FakeScheduler()), \
|
||||||
|
mock.patch.object(train_vla, 'tqdm', side_effect=lambda iterable, **kwargs: _FakeProgressBar(iterable)), \
|
||||||
|
mock.patch.object(train_vla.torch, 'save', return_value=None):
|
||||||
|
train_vla._run_training(cfg)
|
||||||
|
finally:
|
||||||
|
os.chdir(previous_cwd)
|
||||||
|
|
||||||
|
self.assertEqual(len(dataloader_calls), 2)
|
||||||
|
self.assertEqual([call['shuffle'] for call in dataloader_calls], [True, False])
|
||||||
|
self.assertTrue(all(call['num_workers'] == 2 for call in dataloader_calls))
|
||||||
|
self.assertTrue(all(call['persistent_workers'] is False for call in dataloader_calls))
|
||||||
|
|
||||||
|

    def test_run_training_uses_loss_best_until_first_rollout_then_prefers_rollout_reward(self):
        cfg = OmegaConf.create(
            {
                'train': {
                    'device': 'cpu',
                    'batch_size': 1,
                    'num_workers': 0,
                    'val_split': 0.0,
                    'seed': 0,
                    'lr': 1e-3,
                    'max_steps': 6,
                    'log_freq': 1,
                    'save_freq': 1,
                    'warmup_steps': 1,
                    'scheduler_type': 'constant',
                    'min_lr': 0.0,
                    'grad_clip': 1.0,
                    'weight_decay': 0.0,
                    'pretrained_ckpt': None,
                    'resume_ckpt': None,
                    'use_swanlab': False,
                    'rollout_val_freq_epochs': 2,
                    'rollout_num_episodes': 1,
                },
                'data': {
                    'camera_names': ['front'],
                },
                'agent': {
                    '_target_': 'fake.agent',
                },
                'eval': {
                    'ckpt_path': 'unused.pt',
                    'num_episodes': 99,
                    'max_timesteps': 1,
                    'device': 'cpu',
                    'task_name': 'sim_transfer',
                    'camera_names': ['front'],
                    'use_smoothing': False,
                    'smooth_alpha': 0.3,
                    'verbose_action': False,
                    'headless': False,
                },
            }
        )
        agent = _SequentialLossAgent([10, 9, 8, 7, 6, 5])
        rollout_mock = mock.Mock(return_value={'avg_reward': 1.0})
        saved_checkpoints = []

        def fake_instantiate(config_node, **_kwargs):
            if config_node is cfg.data:
                return _FakeDataset()
            if config_node is cfg.agent:
                return agent
            raise AssertionError(f'unexpected instantiate config: {config_node!r}')

        def fake_dataloader(_dataset, *, shuffle, **_kwargs):
            del _kwargs
            return _FakeLoader(
                {
                    'observation.front': torch.zeros(1, 3, 2, 2),
                    'observation.state': torch.zeros(1, 4),
                    'action': torch.zeros(1, 2),
                    'action_is_pad': torch.zeros(1, 1, dtype=torch.bool),
                },
                length=2 if shuffle else 1,
            )

        def fake_torch_save(payload, path):
            saved_checkpoints.append((str(path), deepcopy(payload)))
            return None

        with tempfile.TemporaryDirectory() as tempdir:
            previous_cwd = os.getcwd()
            try:
                os.chdir(tempdir)
                with mock.patch.object(train_vla, 'instantiate', side_effect=fake_instantiate), \
                        mock.patch.object(train_vla, 'DataLoader', side_effect=fake_dataloader), \
                        mock.patch.object(train_vla, 'build_training_optimizer', return_value=_FakeOptimizer(cfg.train.lr)), \
                        mock.patch.object(train_vla, 'get_lr_schedule_with_warmup', return_value=_FakeScheduler()), \
                        mock.patch.object(train_vla, 'tqdm', side_effect=lambda iterable, **kwargs: _FakeProgressBar(iterable)), \
                        mock.patch.object(train_vla.torch, 'save', side_effect=fake_torch_save), \
                        mock.patch.object(eval_vla, '_run_eval', rollout_mock, create=True):
                    train_vla._run_training(cfg)
            finally:
                os.chdir(previous_cwd)

        best_model_saves = [
            (payload['step'], payload['rollout_avg_reward'])
            for path, payload in saved_checkpoints
            if path.endswith('checkpoints/vla_model_best.pt')
        ]
        self.assertEqual(
            best_model_saves,
            [
                (1, None),
                (2, None),
                (3, None),
                (3, 1.0),
            ],
        )
        self.assertEqual(rollout_mock.call_count, 1)
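
    # The assertions above pin down a two-phase "best checkpoint" rule. A minimal
    # sketch of the selection logic they imply (names here are illustrative, not
    # the actual train_vla internals): loss decides "best" only until the first
    # rollout is scored; from then on, only rollout reward can dethrone the best.
    #
    #     def is_new_best(val_loss, rollout_reward, best_loss, best_reward):
    #         if rollout_reward is not None or best_reward is not None:
    #             return rollout_reward is not None and (
    #                 best_reward is None or rollout_reward > best_reward)
    #         return val_loss < best_loss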

    def test_run_training_keeps_tiny_train_dataset_batch_when_batch_size_is_larger(self):
        cfg = OmegaConf.create(
            {
                'train': {
                    'device': 'cpu',
                    'batch_size': 8,
                    'num_workers': 0,
                    'val_split': 0.0,
                    'seed': 0,
                    'lr': 1e-3,
                    'max_steps': 1,
                    'log_freq': 1,
                    'save_freq': 1000,
                    'warmup_steps': 1,
                    'scheduler_type': 'constant',
                    'min_lr': 0.0,
                    'grad_clip': 1.0,
                    'weight_decay': 0.0,
                    'pretrained_ckpt': None,
                    'resume_ckpt': None,
                    'use_swanlab': False,
                    'rollout_val_freq_epochs': 0,
                },
                'data': {
                    'camera_names': ['front'],
                },
                'agent': {
                    '_target_': 'fake.agent',
                },
            }
        )
        agent = _FakeAgent()
        dataloader_calls = []
        saved_checkpoints = []

        class _TinyDataset:
            def __len__(self):
                return 1

        def fake_instantiate(config_node, **_kwargs):
            if config_node is cfg.data:
                return _TinyDataset()
            if config_node is cfg.agent:
                return agent
            raise AssertionError(f'unexpected instantiate config: {config_node!r}')

        def fake_dataloader(dataset, *, drop_last, shuffle, **_kwargs):
            del _kwargs
            dataloader_calls.append(
                {
                    'shuffle': shuffle,
                    'drop_last': drop_last,
                    'dataset_len': len(dataset),
                }
            )
            loader_length = 0 if drop_last and len(dataset) < cfg.train.batch_size else 1
            return _FakeLoader(
                {
                    'observation.front': torch.zeros(1, 3, 2, 2),
                    'observation.state': torch.zeros(1, 4),
                    'action': torch.zeros(1, 2),
                    'action_is_pad': torch.zeros(1, 1, dtype=torch.bool),
                },
                length=loader_length,
            )

        def fake_torch_save(payload, path):
            saved_checkpoints.append((str(path), deepcopy(payload)))
            return None

        with tempfile.TemporaryDirectory() as tempdir:
            previous_cwd = os.getcwd()
            try:
                os.chdir(tempdir)
                with mock.patch.object(train_vla, 'instantiate', side_effect=fake_instantiate), \
                        mock.patch.object(train_vla, 'DataLoader', side_effect=fake_dataloader), \
                        mock.patch.object(train_vla, 'build_training_optimizer', return_value=_FakeOptimizer(cfg.train.lr)), \
                        mock.patch.object(train_vla, 'get_lr_schedule_with_warmup', return_value=_FakeScheduler()), \
                        mock.patch.object(train_vla, 'tqdm', side_effect=lambda iterable, **kwargs: _FakeProgressBar(iterable)), \
                        mock.patch.object(train_vla.torch, 'save', side_effect=fake_torch_save):
                    train_vla._run_training(cfg)
            finally:
                os.chdir(previous_cwd)

        self.assertEqual(
            dataloader_calls[0],
            {
                'shuffle': True,
                'drop_last': False,
                'dataset_len': 1,
            },
        )
        self.assertEqual(
            [path for path, _payload in saved_checkpoints],
            ['checkpoints/vla_model_final.pt'],
        )
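
    # Note: with a one-sample dataset and batch_size=8, `drop_last=True` would
    # yield an empty loader and zero optimizer steps. The `drop_last: False`
    # assertion above is what keeps tiny smoke-test datasets trainable.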


if __name__ == '__main__':
    unittest.main()
tests/test_train_vla_swanlab_logging.py (new file, 699 lines)
@@ -0,0 +1,699 @@
import importlib
import importlib.util
import os
import sys
import tempfile
import types
import unittest
from pathlib import Path
from unittest import mock

import torch
from torch import nn


_REPO_ROOT = Path(__file__).resolve().parents[1]
_TRAIN_VLA_PATH = _REPO_ROOT / 'roboimi/demos/vla_scripts/train_vla.py'
_CONFIG_PATH = _REPO_ROOT / 'roboimi/vla/conf/config.yaml'


class AttrDict(dict):
    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError as exc:
            raise AttributeError(name) from exc

    def __setattr__(self, name, value):
        self[name] = value


def _to_attrdict(value):
    if isinstance(value, dict):
        return AttrDict({key: _to_attrdict(item) for key, item in value.items()})
    if isinstance(value, list):
        return [_to_attrdict(item) for item in value]
    return value


class FakeDataset:
    def __len__(self):
        return 4


class FakeLoader:
    def __init__(self, batch):
        self.batch = batch

    def __len__(self):
        return 1

    def __iter__(self):
        return iter((self.batch,))


class FakeScheduler:
    def __init__(self):
        self.step_calls = 0

    def step(self):
        self.step_calls += 1

    def state_dict(self):
        return {}

    def load_state_dict(self, state_dict):
        return None


class FakeOptimizer:
    def __init__(self, lr=1e-3):
        self.param_groups = [{'lr': lr}]
        self.loaded_state_dict = None

    def zero_grad(self):
        return None

    def step(self):
        return None

    def state_dict(self):
        return {}

    def load_state_dict(self, state_dict):
        self.loaded_state_dict = state_dict
        return None


class FakeProgressBar:
    def __init__(self, iterable):
        self._items = list(iterable)
        self.postfix_calls = []

    def __iter__(self):
        return iter(self._items)

    def set_postfix(self, values):
        self.postfix_calls.append(values)


class FakeAgent(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.tensor(0.0))

    def to(self, device):
        return self

    def compute_loss(self, agent_input):
        del agent_input
        target = torch.tensor(0.25 if self.training else 0.1)
        return (self.weight - target).pow(2)

    def get_normalization_stats(self):
        return {}


class FakeSwanLab:
    def __init__(self, init_error=None, log_errors=None, finish_error=None):
        self.init_error = init_error
        self.log_errors = list(log_errors or [])
        self.finish_error = finish_error
        self.init_calls = []
        self.log_calls = []
        self.finish_calls = 0

    def init(self, project, experiment_name=None, config=None):
        self.init_calls.append({
            'project': project,
            'experiment_name': experiment_name,
            'config': config,
        })
        if self.init_error is not None:
            raise self.init_error
        return object()

    def log(self, payload, step=None):
        self.log_calls.append((dict(payload), step))
        if self.log_errors:
            raise self.log_errors.pop(0)

    def finish(self):
        self.finish_calls += 1
        if self.finish_error is not None:
            raise self.finish_error
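

# The tests below rely on train_vla importing swanlab lazily via
# `importlib.import_module('swanlab')` rather than a top-level import; that is
# why patching `module.importlib.import_module` is enough to substitute
# FakeSwanLab. A minimal sketch of the assumed call site (illustrative only,
# not the actual train_vla source):
#
#     if cfg.train.use_swanlab:
#         swanlab = importlib.import_module('swanlab')
#         run = swanlab.init(project=..., experiment_name=..., config=...)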
class TrainVLASwanLabLoggingTest(unittest.TestCase):
    def test_default_config_keeps_swanlab_opt_in(self):
        config_text = _CONFIG_PATH.read_text(encoding='utf-8')
        self.assertIn('use_swanlab: false', config_text)

    def _load_train_vla_module(self):
        hydra_module = types.ModuleType('hydra')
        hydra_utils_module = types.ModuleType('hydra.utils')
        hydra_utils_module.instantiate = lambda *args, **kwargs: None

        def hydra_main(**_kwargs):
            def decorator(func):
                return func
            return decorator

        hydra_module.main = hydra_main
        hydra_module.utils = hydra_utils_module

        class OmegaConfStub:
            _resolvers = {}

            @classmethod
            def has_resolver(cls, name):
                return name in cls._resolvers

            @classmethod
            def register_new_resolver(cls, name, resolver):
                cls._resolvers[name] = resolver

            @staticmethod
            def to_yaml(_cfg):
                return 'stub-config'

            @staticmethod
            def to_container(cfg, resolve=False):
                del resolve
                return dict(cfg)

            @staticmethod
            def create(cfg):
                return _to_attrdict(cfg)

        omegaconf_module = types.ModuleType('omegaconf')
        omegaconf_module.DictConfig = dict
        omegaconf_module.OmegaConf = OmegaConfStub

        module_name = 'train_vla_swanlab_test_module'
        spec = importlib.util.spec_from_file_location(module_name, _TRAIN_VLA_PATH)
        module = importlib.util.module_from_spec(spec)
        with mock.patch.dict(
            sys.modules,
            {
                'hydra': hydra_module,
                'hydra.utils': hydra_utils_module,
                'omegaconf': omegaconf_module,
            },
        ):
            assert spec.loader is not None
            spec.loader.exec_module(module)
        return module
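
    # Loading the script via importlib with stubbed `hydra`/`omegaconf` entries
    # in sys.modules lets these tests exercise train_vla.py without real Hydra
    # config resolution. The same pattern works for any decorator-wrapped entry
    # point (illustrative):
    #
    #     with mock.patch.dict(sys.modules, {'hydra': hydra_stub}):
    #         spec.loader.exec_module(module)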

    def _make_cfg(self, *, use_swanlab=True, swanlab_run_name='smoke-run'):
        return AttrDict(
            train=AttrDict(
                device='cpu',
                batch_size=2,
                num_workers=0,
                val_split=0.25,
                seed=0,
                lr=1e-3,
                max_steps=2,
                log_freq=1,
                save_freq=1,
                warmup_steps=1,
                scheduler_type='constant',
                min_lr=0.0,
                grad_clip=1.0,
                weight_decay=0.0,
                pretrained_ckpt=None,
                resume_ckpt=None,
                use_swanlab=use_swanlab,
                swanlab_project='roboimi-vla-tests',
                swanlab_run_name=swanlab_run_name,
            ),
            data=AttrDict(
                camera_names=('front',),
            ),
            agent=AttrDict(
                _target_='fake.agent',
            ),
            eval=AttrDict(
                ckpt_path='unused.pt',
                num_episodes=1,
                max_timesteps=1,
                device='cpu',
                task_name='sim_transfer',
                camera_names=('front',),
                use_smoothing=False,
                smooth_alpha=0.3,
                verbose_action=False,
                headless=False,
            ),
        )

    def _get_run_training(self, module):
        run_training = getattr(module, '_run_training', None)
        self.assertIsNotNone(run_training, 'Expected train_vla.py to expose a _run_training(cfg) helper')
        return run_training

    def _make_batch(self):
        return {
            'observation.front': torch.zeros(1, 3, 2, 2),
            'observation.state': torch.zeros(1, 4),
            'action': torch.zeros(1, 2),
            'action_is_pad': torch.zeros(1, 1, dtype=torch.bool),
        }

    def _loader_factory(self):
        train_batch = self._make_batch()
        val_batch = self._make_batch()

        def factory(_dataset, *, shuffle, **_kwargs):
            return FakeLoader(train_batch if shuffle else val_batch)

        return factory

    def test_run_training_logs_metrics_and_checkpoint_paths_to_swanlab(self):
        module = self._load_train_vla_module()
        run_training = self._get_run_training(module)
        cfg = self._make_cfg()
        agent = FakeAgent()
        fake_swanlab = FakeSwanLab()
        real_import_module = importlib.import_module

        def fake_instantiate(config_node, **_kwargs):
            if config_node is cfg.data:
                return FakeDataset()
            if config_node is cfg.agent:
                return agent
            raise AssertionError(f'unexpected instantiate config: {config_node!r}')

        def fake_import_module(name, package=None):
            if name == 'swanlab':
                return fake_swanlab
            return real_import_module(name, package)

        with tempfile.TemporaryDirectory() as tempdir:
            previous_cwd = os.getcwd()
            try:
                os.chdir(tempdir)
                with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
                        mock.patch.object(module, 'DataLoader', side_effect=self._loader_factory()), \
                        mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=FakeScheduler()), \
                        mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: FakeProgressBar(iterable)), \
                        mock.patch.object(module.torch, 'save', return_value=None), \
                        mock.patch.object(module.importlib, 'import_module', side_effect=fake_import_module):
                    run_training(cfg)
            finally:
                os.chdir(previous_cwd)

        self.assertEqual(
            fake_swanlab.init_calls,
            [{
                'project': 'roboimi-vla-tests',
                'experiment_name': 'smoke-run',
                'config': {
                    'train': {
                        'device': 'cpu',
                        'batch_size': 2,
                        'num_workers': 0,
                        'val_split': 0.25,
                        'seed': 0,
                        'lr': 1e-3,
                        'max_steps': 2,
                        'log_freq': 1,
                        'save_freq': 1,
                        'warmup_steps': 1,
                        'scheduler_type': 'constant',
                        'min_lr': 0.0,
                        'grad_clip': 1.0,
                        'weight_decay': 0.0,
                        'pretrained_ckpt': None,
                        'resume_ckpt': None,
                        'use_swanlab': True,
                        'swanlab_project': 'roboimi-vla-tests',
                        'swanlab_run_name': 'smoke-run',
                    },
                    'data': {
                        'camera_names': ('front',),
                    },
                    'agent': {
                        '_target_': 'fake.agent',
                    },
                },
            }],
        )

        logged_keys = set().union(*(payload.keys() for payload, _step in fake_swanlab.log_calls))
        self.assertTrue(
            {
                'train/loss',
                'train/lr',
                'train/best_loss',
                'train/step',
                'val/loss',
                'final/checkpoint_path',
                'final/best_checkpoint_path',
            }.issubset(logged_keys)
        )

        final_payload, final_step = fake_swanlab.log_calls[-1]
        self.assertEqual(final_step, cfg.train.max_steps)
        self.assertEqual(final_payload['final/checkpoint_path'], 'checkpoints/vla_model_final.pt')
        self.assertEqual(final_payload['final/best_checkpoint_path'], 'checkpoints/vla_model_best.pt')
        self.assertEqual(fake_swanlab.finish_calls, 1)
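
    # `set().union(*...)` above tolerates an empty log history: with no log
    # calls the generator unpacks to zero arguments, `union()` returns an empty
    # set, and the subset assertion fails with a readable diff instead of a
    # TypeError.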

    def test_run_training_skips_swanlab_when_disabled(self):
        module = self._load_train_vla_module()
        run_training = self._get_run_training(module)
        cfg = self._make_cfg(use_swanlab=False)
        agent = FakeAgent()

        def fake_instantiate(config_node, **_kwargs):
            if config_node is cfg.data:
                return FakeDataset()
            if config_node is cfg.agent:
                return agent
            raise AssertionError(f'unexpected instantiate config: {config_node!r}')

        with tempfile.TemporaryDirectory() as tempdir:
            previous_cwd = os.getcwd()
            try:
                os.chdir(tempdir)
                with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
                        mock.patch.object(module, 'DataLoader', side_effect=self._loader_factory()), \
                        mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=FakeScheduler()), \
                        mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: FakeProgressBar(iterable)), \
                        mock.patch.object(module.torch, 'save', return_value=None), \
                        mock.patch.object(module.importlib, 'import_module', side_effect=AssertionError('swanlab import should not run')):
                    run_training(cfg)
            finally:
                os.chdir(previous_cwd)

    def test_run_training_finishes_swanlab_when_exception_happens_after_init(self):
        module = self._load_train_vla_module()
        run_training = self._get_run_training(module)
        cfg = self._make_cfg()
        fake_swanlab = FakeSwanLab()
        real_import_module = importlib.import_module

        def fake_import_module(name, package=None):
            if name == 'swanlab':
                return fake_swanlab
            return real_import_module(name, package)

        with tempfile.TemporaryDirectory() as tempdir:
            previous_cwd = os.getcwd()
            try:
                os.chdir(tempdir)
                with mock.patch.object(module, 'instantiate', side_effect=RuntimeError('dataset boom')), \
                        mock.patch.object(module.importlib, 'import_module', side_effect=fake_import_module):
                    with self.assertRaisesRegex(RuntimeError, 'dataset boom'):
                        run_training(cfg)
            finally:
                os.chdir(previous_cwd)

        self.assertEqual(fake_swanlab.finish_calls, 1)

    def test_run_training_warns_and_continues_when_swanlab_log_and_finish_fail(self):
        module = self._load_train_vla_module()
        run_training = self._get_run_training(module)
        cfg = self._make_cfg()
        agent = FakeAgent()
        fake_swanlab = FakeSwanLab(
            log_errors=[RuntimeError('log backend hiccup')],
            finish_error=RuntimeError('finish backend hiccup'),
        )
        real_import_module = importlib.import_module

        def fake_instantiate(config_node, **_kwargs):
            if config_node is cfg.data:
                return FakeDataset()
            if config_node is cfg.agent:
                return agent
            raise AssertionError(f'unexpected instantiate config: {config_node!r}')

        def fake_import_module(name, package=None):
            if name == 'swanlab':
                return fake_swanlab
            return real_import_module(name, package)

        with tempfile.TemporaryDirectory() as tempdir:
            previous_cwd = os.getcwd()
            try:
                os.chdir(tempdir)
                with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
                        mock.patch.object(module, 'DataLoader', side_effect=self._loader_factory()), \
                        mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=FakeScheduler()), \
                        mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: FakeProgressBar(iterable)), \
                        mock.patch.object(module.torch, 'save', return_value=None), \
                        mock.patch.object(module.importlib, 'import_module', side_effect=fake_import_module), \
                        mock.patch.object(module.log, 'warning') as warning_mock:
                    run_training(cfg)
            finally:
                os.chdir(previous_cwd)

        warning_messages = [call.args[0] for call in warning_mock.call_args_list]
        self.assertTrue(any('SwanLab log failed' in message for message in warning_messages))
        self.assertTrue(any('SwanLab finish failed' in message for message in warning_messages))
        self.assertEqual(fake_swanlab.finish_calls, 1)
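
    # Design note: these tests treat experiment logging as best-effort
    # telemetry. A flaky backend during log/finish downgrades to a warning
    # rather than killing a long training run, while init failures (tested
    # further below) still fail fast, since a run the user explicitly asked to
    # track should not proceed untracked.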

    def test_run_training_resume_restores_best_rollout_baseline_from_best_checkpoint(self):
        module = self._load_train_vla_module()
        run_training = self._get_run_training(module)
        cfg = self._make_cfg()
        cfg.train.max_steps = 2
        cfg.train.save_freq = 1
        cfg.train.rollout_validate_on_checkpoint = True
        fake_swanlab = FakeSwanLab()
        fake_optimizer = FakeOptimizer(lr=cfg.train.lr)
        fake_scheduler = FakeScheduler()
        real_import_module = importlib.import_module
        saved_paths = []

        def fake_instantiate(config_node, **_kwargs):
            if config_node is cfg.data:
                return FakeDataset()
            if config_node is cfg.agent:
                return FakeAgent()
            raise AssertionError(f'unexpected instantiate config: {config_node!r}')

        def fake_import_module(name, package=None):
            if name == 'swanlab':
                return fake_swanlab
            return real_import_module(name, package)

        with tempfile.TemporaryDirectory() as tempdir:
            previous_cwd = os.getcwd()
            try:
                os.chdir(tempdir)
                checkpoint_dir = Path('checkpoints')
                checkpoint_dir.mkdir()
                resume_path = checkpoint_dir / 'vla_model_step_0.pt'
                resume_path.write_bytes(b'resume')
                best_path = checkpoint_dir / 'vla_model_best.pt'
                best_path.write_bytes(b'best')
                cfg.train.resume_ckpt = str(resume_path)

                resume_checkpoint_state = {
                    'step': 0,
                    'model_state_dict': FakeAgent().state_dict(),
                    'optimizer_state_dict': {},
                    'scheduler_state_dict': {},
                    'loss': 0.5,
                    'val_loss': 0.25,
                }
                best_checkpoint_state = {
                    'step': 0,
                    'model_state_dict': FakeAgent().state_dict(),
                    'optimizer_state_dict': {},
                    'scheduler_state_dict': {},
                    'loss': 0.5,
                    'val_loss': 0.25,
                    'rollout_avg_reward': 5.0,
                }

                def fake_torch_load(path, map_location=None):
                    del map_location
                    path = Path(path)
                    if path == resume_path:
                        return resume_checkpoint_state
                    if path == best_path:
                        return best_checkpoint_state
                    raise AssertionError(f'unexpected load path: {path}')

                def fake_torch_save(payload, path):
                    saved_paths.append(str(path))
                    return None

                with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
                        mock.patch.object(module, 'DataLoader', side_effect=self._loader_factory()), \
                        mock.patch.object(module, 'build_training_optimizer', return_value=fake_optimizer), \
                        mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=fake_scheduler), \
                        mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: FakeProgressBar(iterable)), \
                        mock.patch.object(module.torch, 'save', side_effect=fake_torch_save), \
                        mock.patch.object(module.torch, 'load', side_effect=fake_torch_load), \
                        mock.patch.object(module.importlib, 'import_module', side_effect=fake_import_module), \
                        mock.patch('roboimi.demos.vla_scripts.eval_vla._run_eval', return_value={'avg_reward': 3.0}):
                    run_training(cfg)
            finally:
                os.chdir(previous_cwd)

        final_payload, final_step = fake_swanlab.log_calls[-1]
        self.assertEqual(final_step, cfg.train.max_steps)
        self.assertEqual(final_payload['final/best_checkpoint_path'], 'checkpoints/vla_model_best.pt')
        self.assertNotIn('checkpoints/vla_model_best.pt', saved_paths)

    def test_run_training_resume_ignores_best_checkpoint_without_rollout_metric(self):
        module = self._load_train_vla_module()
        run_training = self._get_run_training(module)
        cfg = self._make_cfg()
        cfg.train.max_steps = 1
        fake_swanlab = FakeSwanLab()
        fake_optimizer = FakeOptimizer(lr=cfg.train.lr)
        fake_scheduler = FakeScheduler()
        real_import_module = importlib.import_module

        def fake_instantiate(config_node, **_kwargs):
            if config_node is cfg.data:
                return FakeDataset()
            if config_node is cfg.agent:
                return FakeAgent()
            raise AssertionError(f'unexpected instantiate config: {config_node!r}')

        def fake_import_module(name, package=None):
            if name == 'swanlab':
                return fake_swanlab
            return real_import_module(name, package)

        with tempfile.TemporaryDirectory() as tempdir:
            previous_cwd = os.getcwd()
            try:
                os.chdir(tempdir)
                checkpoint_dir = Path('checkpoints')
                checkpoint_dir.mkdir()
                resume_path = checkpoint_dir / 'vla_model_step_0.pt'
                resume_path.write_bytes(b'resume')
                best_path = checkpoint_dir / 'vla_model_best.pt'
                best_path.write_bytes(b'stale')
                cfg.train.resume_ckpt = str(resume_path)

                resume_checkpoint_state = {
                    'step': 0,
                    'model_state_dict': FakeAgent().state_dict(),
                    'optimizer_state_dict': {},
                    'scheduler_state_dict': {},
                    'loss': 0.5,
                    'val_loss': 0.25,
                }
                stale_best_checkpoint_state = {
                    'step': 0,
                    'model_state_dict': FakeAgent().state_dict(),
                    'optimizer_state_dict': {},
                    'scheduler_state_dict': {},
                    'loss': 0.4,
                    'val_loss': 0.2,
                }

                def fake_torch_load(path, map_location=None):
                    del map_location
                    path = Path(path)
                    if path == resume_path:
                        return resume_checkpoint_state
                    if path == best_path:
                        return stale_best_checkpoint_state
                    raise AssertionError(f'unexpected load path: {path}')

                with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
                        mock.patch.object(module, 'DataLoader', side_effect=self._loader_factory()), \
                        mock.patch.object(module, 'build_training_optimizer', return_value=fake_optimizer), \
                        mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=fake_scheduler), \
                        mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: FakeProgressBar(iterable)), \
                        mock.patch.object(module.torch, 'save', return_value=None), \
                        mock.patch.object(module.torch, 'load', side_effect=fake_torch_load), \
                        mock.patch.object(module.importlib, 'import_module', side_effect=fake_import_module):
                    run_training(cfg)
            finally:
                os.chdir(previous_cwd)

        final_payload, final_step = fake_swanlab.log_calls[-1]
        self.assertEqual(final_step, cfg.train.max_steps)
        self.assertEqual(final_payload['final/best_checkpoint_path'], 'checkpoints/vla_model_step_0.pt')

    def test_run_training_ignores_stale_best_checkpoint_file_on_fresh_non_resume_run(self):
        module = self._load_train_vla_module()
        run_training = self._get_run_training(module)
        cfg = self._make_cfg()
        cfg.train.max_steps = 1
        fake_swanlab = FakeSwanLab()
        real_import_module = importlib.import_module

        def fake_instantiate(config_node, **_kwargs):
            if config_node is cfg.data:
                return FakeDataset()
            if config_node is cfg.agent:
                return FakeAgent()
            raise AssertionError(f'unexpected instantiate config: {config_node!r}')

        def fake_import_module(name, package=None):
            if name == 'swanlab':
                return fake_swanlab
            return real_import_module(name, package)

        with tempfile.TemporaryDirectory() as tempdir:
            previous_cwd = os.getcwd()
            try:
                os.chdir(tempdir)
                checkpoint_dir = Path('checkpoints')
                checkpoint_dir.mkdir()
                (checkpoint_dir / 'vla_model_best.pt').write_bytes(b'stale-best')

                with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
                        mock.patch.object(module, 'DataLoader', side_effect=self._loader_factory()), \
                        mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=FakeScheduler()), \
                        mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: FakeProgressBar(iterable)), \
                        mock.patch.object(module.torch, 'save', return_value=None), \
                        mock.patch.object(module.importlib, 'import_module', side_effect=fake_import_module):
                    run_training(cfg)
            finally:
                os.chdir(previous_cwd)

        final_payload, final_step = fake_swanlab.log_calls[-1]
        self.assertEqual(final_step, cfg.train.max_steps)
        self.assertEqual(final_payload['final/best_checkpoint_path'], '')

    def test_run_training_fails_fast_when_swanlab_import_is_unavailable(self):
        module = self._load_train_vla_module()
        run_training = self._get_run_training(module)
        cfg = self._make_cfg()
        real_import_module = importlib.import_module

        def fake_import_module(name, package=None):
            if name == 'swanlab':
                raise ImportError('missing swanlab')
            return real_import_module(name, package)

        with mock.patch.object(module, 'instantiate', side_effect=AssertionError('instantiate should not run')), \
                mock.patch.object(module.importlib, 'import_module', side_effect=fake_import_module):
            with self.assertRaisesRegex(RuntimeError, 'SwanLab'):
                run_training(cfg)

    def test_run_training_fails_fast_when_swanlab_init_fails(self):
        module = self._load_train_vla_module()
        run_training = self._get_run_training(module)
        cfg = self._make_cfg()
        fake_swanlab = FakeSwanLab(init_error=RuntimeError('not logged in'))
        real_import_module = importlib.import_module

        def fake_import_module(name, package=None):
            if name == 'swanlab':
                return fake_swanlab
            return real_import_module(name, package)

        with mock.patch.object(module, 'instantiate', side_effect=AssertionError('instantiate should not run')), \
                mock.patch.object(module.importlib, 'import_module', side_effect=fake_import_module):
            with self.assertRaisesRegex(RuntimeError, 'not logged in'):
                run_training(cfg)

        self.assertEqual(fake_swanlab.finish_calls, 0)


if __name__ == '__main__':
    unittest.main()
tests/test_train_vla_transformer_optimizer.py (new file, 310 lines)
@@ -0,0 +1,310 @@
import importlib.util
import os
import sys
import tempfile
import types
import unittest
from pathlib import Path
from unittest import mock

import torch
from torch import nn


_REPO_ROOT = Path(__file__).resolve().parents[1]
_TRAIN_VLA_PATH = _REPO_ROOT / 'roboimi/demos/vla_scripts/train_vla.py'


class AttrDict(dict):
    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError as exc:
            raise AttributeError(name) from exc

    def __setattr__(self, name, value):
        self[name] = value


class FakeDataset:
    def __len__(self):
        return 4


class FakeLoader:
    def __len__(self):
        return 1

    def __iter__(self):
        return iter(())


class FakeScheduler:
    def state_dict(self):
        return {}

    def load_state_dict(self, state_dict):
        return None


class RecordingAdamW:
    created = []

    def __init__(self, params, lr, weight_decay):
        self.lr = lr
        self.weight_decay = weight_decay
        self.param_groups = self._normalize_param_groups(params, lr, weight_decay)
        RecordingAdamW.created.append(self)

    @staticmethod
    def _normalize_param_groups(params, lr, weight_decay):
        if isinstance(params, (list, tuple)) and params and isinstance(params[0], dict):
            groups = []
            for group in params:
                normalized = dict(group)
                normalized['params'] = list(group['params'])
                normalized.setdefault('lr', lr)
                groups.append(normalized)
            return groups

        return [{
            'params': list(params),
            'lr': lr,
            'weight_decay': weight_decay,
        }]

    def state_dict(self):
        return {}

    def load_state_dict(self, state_dict):
        return None


class RecordingTransformerHead(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)
        self.norm = nn.LayerNorm(4)
        self.optim_group_calls = []

    def get_optim_groups(self, weight_decay):
        self.optim_group_calls.append(weight_decay)
        return [
            {
                'params': [self.proj.weight],
                'weight_decay': weight_decay,
            },
            {
                'params': [self.proj.bias, self.norm.weight, self.norm.bias],
                'weight_decay': 0.0,
            },
        ]


class FakeTransformerAgent(nn.Module):
    def __init__(self):
        super().__init__()
        self.head_type = 'transformer'
        self.noise_pred_net = RecordingTransformerHead()
        self.backbone = nn.Linear(4, 3)
        self.adapter = nn.Linear(3, 2, bias=False)
        self.frozen = nn.Linear(2, 2)
        for param in self.frozen.parameters():
            param.requires_grad = False

    def to(self, device):
        return self

    def get_normalization_stats(self):
        return {}
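

# A minimal sketch of the optimizer-construction rule the tests below pin down
# (the real helper lives in train_vla.py; this version is illustrative only):
# a transformer head contributes its own decay/no-decay groups via
# get_optim_groups(), the remaining trainable params get the global weight
# decay, frozen params are filtered out, and no param is added twice.
#
#     def build_training_optimizer_sketch(agent, lr, weight_decay):
#         groups = list(agent.noise_pred_net.get_optim_groups(weight_decay))
#         grouped = {id(p) for g in groups for p in g['params']}
#         rest = [p for p in agent.parameters()
#                 if p.requires_grad and id(p) not in grouped]
#         if rest:
#             groups.append({'params': rest, 'weight_decay': weight_decay})
#         groups = [{**g, 'params': [p for p in g['params'] if p.requires_grad]}
#                   for g in groups]
#         return AdamW(groups, lr=lr, weight_decay=weight_decay)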
class TrainVLATransformerOptimizerTest(unittest.TestCase):
    def setUp(self):
        RecordingAdamW.created = []

    def _load_train_vla_module(self):
        hydra_module = types.ModuleType('hydra')
        hydra_utils_module = types.ModuleType('hydra.utils')
        hydra_utils_module.instantiate = lambda *args, **kwargs: None

        def hydra_main(**_kwargs):
            def decorator(func):
                return func
            return decorator

        hydra_module.main = hydra_main
        hydra_module.utils = hydra_utils_module

        class OmegaConfStub:
            _resolvers = {}

            @classmethod
            def has_resolver(cls, name):
                return name in cls._resolvers

            @classmethod
            def register_new_resolver(cls, name, resolver):
                cls._resolvers[name] = resolver

            @staticmethod
            def to_yaml(_cfg):
                return 'stub-config'

        omegaconf_module = types.ModuleType('omegaconf')
        omegaconf_module.DictConfig = dict
        omegaconf_module.OmegaConf = OmegaConfStub

        module_name = 'train_vla_optimizer_test_module'
        spec = importlib.util.spec_from_file_location(module_name, _TRAIN_VLA_PATH)
        module = importlib.util.module_from_spec(spec)
        with mock.patch.dict(
            sys.modules,
            {
                'hydra': hydra_module,
                'hydra.utils': hydra_utils_module,
                'omegaconf': omegaconf_module,
            },
        ):
            assert spec.loader is not None
            spec.loader.exec_module(module)
        return module

    def _make_cfg(self):
        return AttrDict(
            train=AttrDict(
                device='cpu',
                batch_size=2,
                num_workers=0,
                val_split=0,
                seed=0,
                lr=1e-4,
                max_steps=0,
                log_freq=1,
                save_freq=100,
                warmup_steps=1,
                scheduler_type='constant',
                min_lr=0.0,
                grad_clip=1.0,
                weight_decay=0.123,
                pretrained_ckpt=None,
                resume_ckpt=None,
            ),
            data=AttrDict(
                camera_names=('front',),
            ),
            agent=AttrDict(
                _target_='fake.agent',
            ),
        )

    def _group_names(self, agent, optimizer):
        names_by_param_id = {id(param): name for name, param in agent.named_parameters()}
        return [
            {names_by_param_id[id(param)] for param in group['params']}
            for group in optimizer.param_groups
        ]

    def test_transformer_training_prefers_head_optim_groups_and_keeps_remaining_trainable_params(self):
        module = self._load_train_vla_module()
        agent = FakeTransformerAgent()
        cfg = self._make_cfg()

        def fake_instantiate(config_node, **_kwargs):
            if config_node is cfg.data:
                return FakeDataset()
            if config_node is cfg.agent:
                return agent
            raise AssertionError(f'unexpected instantiate config: {config_node!r}')

        with tempfile.TemporaryDirectory() as tempdir:
            previous_cwd = os.getcwd()
            try:
                os.chdir(tempdir)
                with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
                        mock.patch.object(module, 'DataLoader', side_effect=lambda *args, **kwargs: FakeLoader()), \
                        mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=FakeScheduler()), \
                        mock.patch.object(module, 'AdamW', RecordingAdamW), \
                        mock.patch.object(module.torch, 'save', return_value=None), \
                        mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: iterable):
                    module.main(cfg)
            finally:
                os.chdir(previous_cwd)

        self.assertEqual(agent.noise_pred_net.optim_group_calls, [cfg.train.weight_decay])

        optimizer = RecordingAdamW.created[-1]
        trainable_names = {
            name for name, param in agent.named_parameters() if param.requires_grad
        }
        grouped_names = self._group_names(agent, optimizer)
        optimizer_names = set().union(*grouped_names)
        expected_head_names = {
            'noise_pred_net.proj.weight',
            'noise_pred_net.proj.bias',
            'noise_pred_net.norm.weight',
            'noise_pred_net.norm.bias',
        }
        expected_non_head_names = {
            'backbone.weight',
            'backbone.bias',
            'adapter.weight',
        }

        self.assertEqual(grouped_names[0], {'noise_pred_net.proj.weight'})
        self.assertEqual(grouped_names[1], expected_head_names - {'noise_pred_net.proj.weight'})
        self.assertEqual(grouped_names[2], expected_non_head_names)
        self.assertEqual(optimizer.param_groups[0]['weight_decay'], cfg.train.weight_decay)
        self.assertEqual(optimizer.param_groups[1]['weight_decay'], 0.0)
        self.assertEqual(optimizer.param_groups[2]['weight_decay'], cfg.train.weight_decay)
        self.assertEqual(optimizer_names, trainable_names)

        flattened_param_ids = [
            id(param)
            for group in optimizer.param_groups
            for param in group['params']
        ]
        self.assertEqual(len(flattened_param_ids), len(set(flattened_param_ids)))
        self.assertNotIn('frozen.weight', optimizer_names)
        self.assertNotIn('frozen.bias', optimizer_names)

    def test_transformer_optimizer_ignores_frozen_head_params_returned_by_head_groups(self):
        module = self._load_train_vla_module()
        agent = FakeTransformerAgent()
        agent.noise_pred_net.norm.bias.requires_grad = False
        cfg = self._make_cfg()

        def fake_instantiate(config_node, **_kwargs):
            if config_node is cfg.data:
                return FakeDataset()
            if config_node is cfg.agent:
                return agent
            raise AssertionError(f'unexpected instantiate config: {config_node!r}')

        with tempfile.TemporaryDirectory() as tempdir:
            previous_cwd = os.getcwd()
            try:
                os.chdir(tempdir)
                with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
                        mock.patch.object(module, 'DataLoader', side_effect=lambda *args, **kwargs: FakeLoader()), \
                        mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=FakeScheduler()), \
                        mock.patch.object(module, 'AdamW', RecordingAdamW), \
                        mock.patch.object(module.torch, 'save', return_value=None), \
                        mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: iterable):
                    module.main(cfg)
            finally:
                os.chdir(previous_cwd)

        optimizer = RecordingAdamW.created[-1]
        optimizer_names = set().union(*self._group_names(agent, optimizer))
        trainable_names = {
            name for name, param in agent.named_parameters() if param.requires_grad
        }

        self.assertEqual(agent.noise_pred_net.optim_group_calls, [cfg.train.weight_decay])
        self.assertEqual(optimizer_names, trainable_names)
        self.assertNotIn('noise_pred_net.norm.bias', optimizer_names)


if __name__ == '__main__':
    unittest.main()
tests/test_transformer1d_external_alignment.py (new file, 262 lines)
@@ -0,0 +1,262 @@
import contextlib
import importlib.util
import inspect
import sys
import types
import unittest
import warnings
from pathlib import Path

import torch


_REPO_ROOT = Path(__file__).resolve().parents[1]
_LOCAL_MODULE_PATH = _REPO_ROOT / 'roboimi/vla/models/heads/transformer1d.py'
_EXTERNAL_CHECKOUT_ROOT = _REPO_ROOT.parent / 'diffusion_policy'
_TRANSFORMER_WARNING_MESSAGE = (
    r'enable_nested_tensor is True, but self.use_nested_tensor is False '
    r'because encoder_layer\.norm_first was True'
)
_MISSING = object()


def _load_module_from_path(name: str, path: Path, *, register: bool = False):
    spec = importlib.util.spec_from_file_location(name, path)
    module = importlib.util.module_from_spec(spec)
    assert spec.loader is not None
    if register:
        sys.modules[name] = module
    spec.loader.exec_module(module)
    return module


def _resolve_external_module_paths(external_checkout_root: Path):
    diffusion_policy_root = external_checkout_root / 'diffusion_policy'
    paths = {
        'positional_embedding': diffusion_policy_root / 'model/diffusion/positional_embedding.py',
        'module_attr_mixin': diffusion_policy_root / 'model/common/module_attr_mixin.py',
        'transformer_for_diffusion': diffusion_policy_root / 'model/diffusion/transformer_for_diffusion.py',
    }
    if not all(path.exists() for path in paths.values()):
        return None
    return paths


@contextlib.contextmanager
def _temporary_registered_modules():
    previous_modules = {}

    def remember(name: str) -> None:
        if name not in previous_modules:
            previous_modules[name] = sys.modules.get(name, _MISSING)

    def ensure_package(name: str) -> None:
        if not name or name in sys.modules:
            return
        remember(name)
        package = types.ModuleType(name)
        package.__path__ = []
        sys.modules[name] = package

    def load(name: str, path: Path):
        package_parts = name.split('.')[:-1]
        for idx in range(1, len(package_parts) + 1):
            ensure_package('.'.join(package_parts[:idx]))

        remember(name)
        return _load_module_from_path(name, path, register=True)

    try:
        yield load
    finally:
        for name, previous in reversed(list(previous_modules.items())):
            if previous is _MISSING:
                sys.modules.pop(name, None)
            else:
                sys.modules[name] = previous
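

# Typical use of the loader above — synthesize the package chain, import the
# leaf module from an explicit file path, then restore sys.modules on exit
# (the paths below are illustrative):
#
#     with _temporary_registered_modules() as load:
#         mod = load('diffusion_policy.model.diffusion.positional_embedding',
#                    some_checkout / 'model/diffusion/positional_embedding.py')
#     # after the block, sys.modules is back to its prior state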


@contextlib.contextmanager
def _suppress_nested_tensor_warning():
    with warnings.catch_warnings():
        warnings.filterwarnings(
            'ignore',
            message=_TRANSFORMER_WARNING_MESSAGE,
            category=UserWarning,
            module=r'torch\.nn\.modules\.transformer',
        )
        yield


def _load_local_module():
    return _load_module_from_path('local_transformer1d_alignment', _LOCAL_MODULE_PATH)


class Transformer1DExternalAlignmentTest(unittest.TestCase):
    def _load_transformer_classes_or_skip(self):
        external_paths = _resolve_external_module_paths(_EXTERNAL_CHECKOUT_ROOT)
        if external_paths is None:
            self.skipTest(f'external diffusion_policy checkout unavailable under {_EXTERNAL_CHECKOUT_ROOT}')

        local_module = _load_local_module()
        with _temporary_registered_modules() as load_external:
            load_external(
                'diffusion_policy.model.diffusion.positional_embedding',
                external_paths['positional_embedding'],
            )
            load_external(
                'diffusion_policy.model.common.module_attr_mixin',
                external_paths['module_attr_mixin'],
            )
            external_module = load_external(
                'diffusion_policy.model.diffusion.transformer_for_diffusion',
                external_paths['transformer_for_diffusion'],
            )

        return local_module.Transformer1D, local_module.create_transformer1d, external_module.TransformerForDiffusion

    def _optim_group_names(self, model, groups):
        names_by_param = {id(param): name for name, param in model.named_parameters()}
        return [
            {names_by_param[id(param)] for param in group['params']}
            for group in groups
        ]

    def test_missing_external_checkout_resolution_returns_none(self):
        self.assertIsNone(_resolve_external_module_paths(_REPO_ROOT / '__missing_diffusion_policy_checkout__'))

    def test_external_loader_restores_injected_sys_modules(self):
        external_paths = _resolve_external_module_paths(_EXTERNAL_CHECKOUT_ROOT)
        if external_paths is None:
            self.skipTest(f'external diffusion_policy checkout unavailable under {_EXTERNAL_CHECKOUT_ROOT}')

        watched_names = [
            'diffusion_policy',
            'diffusion_policy.model',
            'diffusion_policy.model.common',
            'diffusion_policy.model.common.module_attr_mixin',
            'diffusion_policy.model.diffusion',
            'diffusion_policy.model.diffusion.positional_embedding',
            'diffusion_policy.model.diffusion.transformer_for_diffusion',
        ]
        before = {name: sys.modules.get(name, _MISSING) for name in watched_names}

        with _temporary_registered_modules() as load_external:
            load_external(
                'diffusion_policy.model.diffusion.positional_embedding',
                external_paths['positional_embedding'],
            )
            load_external(
                'diffusion_policy.model.common.module_attr_mixin',
                external_paths['module_attr_mixin'],
            )
            load_external(
                'diffusion_policy.model.diffusion.transformer_for_diffusion',
                external_paths['transformer_for_diffusion'],
            )

        after = {name: sys.modules.get(name, _MISSING) for name in watched_names}
        self.assertEqual(after, before)

    def test_transformer1d_preserves_local_direct_call_defaults(self):
        local_module = _load_local_module()
        ctor = inspect.signature(local_module.Transformer1D.__init__).parameters
        helper = inspect.signature(local_module.create_transformer1d).parameters

        self.assertEqual(ctor['n_layer'].default, 8)
        self.assertEqual(ctor['n_head'].default, 8)
        self.assertEqual(ctor['n_emb'].default, 256)
        self.assertEqual(helper['n_layer'].default, 8)
        self.assertEqual(helper['n_head'].default, 8)
        self.assertEqual(helper['n_emb'].default, 256)

    def test_time_as_cond_false_token_accounting_matches_external(self):
        Transformer1D, _, TransformerForDiffusion = self._load_transformer_classes_or_skip()
        self.assertIn('time_as_cond', inspect.signature(Transformer1D.__init__).parameters)

        config = dict(
            input_dim=4,
            output_dim=4,
            horizon=6,
            n_obs_steps=3,
            cond_dim=0,
            n_layer=2,
            n_head=2,
            n_emb=8,
            p_drop_emb=0.0,
            p_drop_attn=0.0,
            causal_attn=False,
            time_as_cond=False,
            obs_as_cond=False,
            n_cond_layers=0,
        )

        torch.manual_seed(5)
        with _suppress_nested_tensor_warning():
            external_model = TransformerForDiffusion(**config)
            local_model = Transformer1D(**config)
        external_model.eval()
        local_model.eval()

        self.assertEqual(local_model.T, external_model.T)
        self.assertEqual(local_model.T_cond, external_model.T_cond)
        self.assertEqual(local_model.time_as_cond, external_model.time_as_cond)
        self.assertEqual(local_model.obs_as_cond, external_model.obs_as_cond)
        self.assertEqual(local_model.encoder_only, external_model.encoder_only)

    def test_nocausal_state_dict_forward_and_optim_groups_match_external(self):
        Transformer1D, _, TransformerForDiffusion = self._load_transformer_classes_or_skip()
        config = dict(
            input_dim=4,
            output_dim=4,
            horizon=6,
            n_obs_steps=3,
            cond_dim=5,
            n_layer=2,
            n_head=2,
            n_emb=8,
            p_drop_emb=0.0,
            p_drop_attn=0.0,
            causal_attn=False,
            obs_as_cond=True,
            n_cond_layers=1,
        )

        torch.manual_seed(7)
        with _suppress_nested_tensor_warning():
            external_model = TransformerForDiffusion(**config)
            local_model = Transformer1D(**config)
        external_model.eval()
        local_model.eval()

        external_state_dict = external_model.state_dict()
        self.assertEqual(set(local_model.state_dict().keys()), set(external_state_dict.keys()))
        local_model.load_state_dict(external_state_dict, strict=True)

        batch_size = 2
        sample = torch.randn(batch_size, config['horizon'], config['input_dim'])
        cond = torch.randn(batch_size, config['n_obs_steps'], config['cond_dim'])
        timestep = torch.tensor([11, 17], dtype=torch.long)

        with torch.no_grad():
            external_out = external_model(sample=sample, timestep=timestep, cond=cond)
            local_out = local_model(sample=sample, timestep=timestep, cond=cond)

        self.assertEqual(local_out.shape, (batch_size, config['horizon'], config['output_dim']))
        self.assertEqual(local_out.shape, external_out.shape)
        self.assertTrue(torch.allclose(local_out, external_out, atol=1e-6, rtol=1e-5))

        weight_decay = 0.123
        external_groups = external_model.get_optim_groups(weight_decay=weight_decay)
        local_groups = local_model.get_optim_groups(weight_decay=weight_decay)

        self.assertEqual(len(local_groups), len(external_groups))
        self.assertEqual([group['weight_decay'] for group in local_groups], [weight_decay, 0.0])
        self.assertEqual(
            self._optim_group_names(local_model, local_groups),
            self._optim_group_names(external_model, external_groups),
        )


if __name__ == '__main__':
    unittest.main()