Files
roboimi/roboimi/vla/agent.py

447 lines
18 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import torch
import torch.nn as nn
import numpy as np
from collections import deque
from typing import Dict, Optional, Any, Tuple
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from roboimi.vla.models.normalization import NormalizationModule
class VLAAgent(nn.Module):
def __init__(
    self,
    vision_backbone,  # visual encoder (e.g. a ResNet-based multi-camera backbone)
    state_encoder,
    action_encoder,
    head,
    action_dim,  # robot action dimension (e.g. 7: xyz + rpy + gripper)
    obs_dim,  # proprioception dimension (e.g. joint angles)
    pred_horizon=16,  # how many future action steps to predict
    obs_horizon=4,  # how many past observation steps to condition on
    diffusion_steps=100,  # DDPM forward (noising) steps
    inference_steps=10,  # DDIM reverse (denoising) steps at inference
    num_cams=3,  # number of camera views in the visual input
    camera_names: Optional[Tuple[str, ...]] = None,  # explicit camera order for conditioning
    dataset_stats=None,  # dataset statistics used for normalization
    normalization_type='min_max',  # normalization type: 'gaussian' or 'min_max'
    num_action_steps=8,  # how many predicted steps are actually executed per inference
    head_type='unet',  # policy-head type: 'unet' or 'transformer'
):
    """Build a diffusion-policy VLA agent.

    Wires together the vision backbone, state/action encoders and the
    noise-prediction head, sets up the DDPM (training) and DDIM
    (inference) schedulers, and validates that the camera configuration
    is consistent between the agent and the vision backbone.

    Raises:
        ValueError: if ``num_cams`` / ``camera_names`` disagree with the
            vision backbone's own camera configuration, or if
            ``camera_names`` has a length different from ``num_cams``.
    """
    super().__init__()
    # Store configuration.
    self.action_dim = action_dim
    self.obs_dim = obs_dim
    self.pred_horizon = pred_horizon
    self.obs_horizon = obs_horizon
    self.num_cams = num_cams
    self.num_action_steps = num_action_steps
    self.inference_steps = inference_steps
    self.head_type = head_type  # 'unet' or 'transformer'
    # Cross-check camera configuration against the vision backbone.
    agent_camera_names = tuple(camera_names) if camera_names is not None else None
    backbone_camera_names = getattr(vision_backbone, 'camera_names', None)
    backbone_camera_names = tuple(backbone_camera_names) if backbone_camera_names is not None else None
    backbone_num_cameras = getattr(vision_backbone, 'num_cameras', None)
    if backbone_num_cameras is not None and backbone_num_cameras != self.num_cams:
        raise ValueError(
            f"agent.num_cams({self.num_cams}) 与 "
            f"vision_backbone.num_cameras({backbone_num_cameras}) 不一致"
        )
    if (
        agent_camera_names is not None
        and backbone_camera_names is not None
        and agent_camera_names != backbone_camera_names
    ):
        raise ValueError(
            f"agent.camera_names({list(agent_camera_names)}) 与 "
            f"vision_backbone.camera_names({list(backbone_camera_names)}) 不一致"
        )
    # Prefer the explicitly passed camera order; otherwise adopt the backbone's.
    self.camera_names = (
        agent_camera_names if agent_camera_names is not None else backbone_camera_names
    )
    if self.camera_names is not None and len(self.camera_names) != self.num_cams:
        raise ValueError(
            f"camera_names 长度({len(self.camera_names)})与 num_cams({self.num_cams})不一致"
        )
    # Normalization module — unifies the training/inference normalization logic.
    self.normalization = NormalizationModule(
        stats=dataset_stats,
        normalization_type=normalization_type
    )
    self.vision_encoder = vision_backbone
    if self.camera_names is not None:
        # Propagate the resolved camera order back onto the backbone.
        self.vision_encoder.camera_names = self.camera_names
    single_cam_feat_dim = self.vision_encoder.output_dim
    # global_cond_dim: flattened total conditioning dimension (for the UNet head).
    total_vision_dim = single_cam_feat_dim * num_cams * obs_horizon
    total_prop_dim = obs_dim * obs_horizon
    self.global_cond_dim = total_vision_dim + total_prop_dim
    # per_step_cond_dim: per-timestep conditioning dimension (for the Transformer head).
    # Note: NOT multiplied by obs_horizon — the Transformer consumes the sequence form.
    self.per_step_cond_dim = single_cam_feat_dim * num_cams + obs_dim
    # DDPM scheduler used during training (forward noising).
    self.noise_scheduler = DDPMScheduler(
        num_train_timesteps=diffusion_steps,
        beta_schedule='squaredcos_cap_v2',  # schedule commonly used for robot tasks
        clip_sample=True,
        prediction_type='epsilon'  # the network predicts the added noise
    )
    # DDIM scheduler for fast inference.
    self.infer_scheduler = DDIMScheduler(
        num_train_timesteps=diffusion_steps,
        beta_schedule='squaredcos_cap_v2',
        clip_sample=True,
        prediction_type='epsilon'
    )
    # Instantiate the noise-prediction head according to its type.
    if head_type == 'transformer':
        # If `head` is already an nn.Module instance, use it as-is;
        # otherwise treat it as a factory/partial and call it with kwargs.
        if isinstance(head, nn.Module):
            # Already instantiated (e.g. passed in directly by tests).
            self.noise_pred_net = head
        else:
            # Hydra-style partially-initialized object: finish construction here.
            self.noise_pred_net = head(
                input_dim=action_dim,
                output_dim=action_dim,
                horizon=pred_horizon,
                n_obs_steps=obs_horizon,
                cond_dim=self.per_step_cond_dim  # per-step conditioning dimension
            )
    else:  # 'unet' (default)
        # UNet interface: input_dim, global_cond_dim.
        self.noise_pred_net = head(
            input_dim=action_dim,
            global_cond_dim=self.global_cond_dim
        )
    self.state_encoder = state_encoder
    self.action_encoder = action_encoder
    # Initialize the queues used for online inference.
    self.reset()
def _get_model_device(self) -> torch.device:
"""获取模型当前所在设备。"""
return next(self.parameters()).device
def _move_to_device(self, data, device: torch.device):
"""递归地将张量数据移动到指定设备。"""
if torch.is_tensor(data):
return data.to(device)
if isinstance(data, dict):
return {k: self._move_to_device(v, device) for k, v in data.items()}
if isinstance(data, list):
return [self._move_to_device(v, device) for v in data]
if isinstance(data, tuple):
return tuple(self._move_to_device(v, device) for v in data)
return data
def _order_images(self, images: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""按显式配置的相机顺序返回图像字典。"""
if self.camera_names is None:
camera_names = tuple(sorted(images.keys()))
if len(camera_names) != self.num_cams:
raise ValueError(
f"图像条件相机数量({len(camera_names)})与 num_cams({self.num_cams})不一致"
)
return {cam_name: images[cam_name] for cam_name in camera_names}
missing = [cam_name for cam_name in self.camera_names if cam_name not in images]
if missing:
raise ValueError(
f"图像条件缺少必需相机。missing={missing}, expected={list(self.camera_names)}"
)
return {cam_name: images[cam_name] for cam_name in self.camera_names}
def _build_cond(self, images: Dict[str, torch.Tensor], states: torch.Tensor) -> torch.Tensor:
"""构造每步条件,确保图像条件顺序稳定。"""
ordered_images = self._order_images(images)
visual_features = self.vision_encoder(ordered_images)
state_features = self.state_encoder(states)
cond = torch.cat([visual_features, state_features], dim=-1)
if cond.shape[-1] != self.per_step_cond_dim:
raise RuntimeError(
f"条件维度不匹配: got {cond.shape[-1]}, expected {self.per_step_cond_dim}"
)
return cond
# ==========================
# Training
# ==========================
def compute_loss(self, batch):
    """Compute the diffusion (noise-prediction) training loss.

    Args:
        batch: dict with keys 'images', 'qpos' (proprioception),
            'action', and optionally 'action_is_pad' (bool mask;
            True marks padded action steps, judging from the `~` below).

    Returns:
        Scalar MSE loss between predicted and sampled noise, averaged
        over valid (non-padded) elements.
    """
    actions, states, images = batch['action'], batch['qpos'], batch['images']
    action_is_pad = batch.get('action_is_pad', None)  # optional padding mask
    B = actions.shape[0]
    # Normalize states (qpos) and actions into the model's training range.
    states = self.normalization.normalize_qpos(states)
    actions = self.normalization.normalize_action(actions)
    # 1. Extract per-step conditioning features (vision + proprioception).
    per_step_cond = self._build_cond(images, states)
    action_features = self.action_encoder(actions)
    # 2. Sample Gaussian noise matching the encoded action shape.
    noise = torch.randn_like(action_features)
    # 3. Sample one random diffusion timestep per batch element.
    timesteps = torch.randint(
        0, self.noise_scheduler.config.num_train_timesteps,
        (B,), device=action_features.device
    ).long()
    # 4. Add noise to the actions (forward diffusion).
    noisy_actions = self.noise_scheduler.add_noise(
        action_features, noise, timesteps
    )
    # Flatten the conditioning for the UNet head; the Transformer head
    # consumes the unflattened per-step sequence instead.
    # per_step_cond: (B, obs_horizon, vision_dim * num_cams + obs_dim)
    global_cond = per_step_cond.flatten(start_dim=1)
    # 5. Predict the noise, dispatching on the head's interface.
    if self.head_type == 'transformer':
        pred_noise = self.noise_pred_net(
            sample=noisy_actions,
            timestep=timesteps,
            cond=per_step_cond
        )
    else:  # 'unet'
        pred_noise = self.noise_pred_net(
            sample=noisy_actions,
            timestep=timesteps,
            global_cond=global_cond
        )
    # 6. MSE loss, with optional masking of padded steps.
    loss = nn.functional.mse_loss(pred_noise, noise, reduction='none')
    if action_is_pad is not None:
        # action_is_pad: (B, pred_horizon); broadcast to (B, pred_horizon, action_dim).
        mask = (~action_is_pad).unsqueeze(-1).to(loss.dtype)  # 1.0 marks valid data
        # Average only over valid elements; the clamp guards against
        # division by zero when every step in the batch is padding.
        valid_count = mask.sum() * loss.shape[-1]
        loss = (loss * mask).sum() / valid_count.clamp_min(1.0)
    else:
        loss = loss.mean()
    return loss
# ==========================
# 队列管理 (Queue Management)
# ==========================
def reset(self):
"""清空观测和动作队列。应在 env.reset() 时调用"""
self._queues = {
'qpos': deque(maxlen=self.obs_horizon),
'images': deque(maxlen=self.obs_horizon),
'action': deque(maxlen=self.pred_horizon - self.obs_horizon + 1), # 可执行的动作缓存
}
def _populate_queues(self, observation: Dict[str, torch.Tensor]) -> None:
"""
将新的观测添加到队列中。
Args:
observation: 包含 'qpos''images' 的字典
"""
# 添加本体感知
if 'qpos' in observation:
self._queues['qpos'].append(observation['qpos'].clone())
# 添加图像
if 'images' in observation:
ordered_images = self._order_images(observation['images'])
self._queues['images'].append({k: v.clone() for k, v in ordered_images.items()})
def _prepare_observation_batch(self) -> Dict[str, torch.Tensor]:
"""
从队列中准备用于推理的批量观测。
如果队列未满(首次调用时),用最新观测重复填充。
Returns:
batch: 包含堆叠后的历史观测的字典
"""
# 堆叠历史本体感知
qpos_list = list(self._queues['qpos'])
if len(qpos_list) == 0:
raise ValueError("观测队列为空,请先调用 _populate_queues 添加观测")
# 如果队列未满,用最后一个观测填充
while len(qpos_list) < self.obs_horizon:
qpos_list.append(qpos_list[-1])
batch_qpos = torch.stack(qpos_list, dim=0).unsqueeze(0) # (1, obs_horizon, obs_dim)
# 堆叠历史图像
images_list = list(self._queues['images'])
if len(images_list) == 0:
raise ValueError("图像队列为空,请先调用 _populate_queues 添加观测")
# 如果队列未满,用最后一个观测填充
while len(images_list) < self.obs_horizon:
images_list.append(images_list[-1])
batch_images = {}
camera_names = self.camera_names if self.camera_names is not None else tuple(sorted(images_list[0].keys()))
for cam_name in camera_names:
batch_images[cam_name] = torch.stack([img[cam_name] for img in images_list], dim=0).unsqueeze(0)
return {'qpos': batch_qpos, 'images': batch_images}
# ==========================
# Online Inference
# ==========================
@torch.no_grad()
def select_action(self, observation: Dict[str, torch.Tensor]) -> torch.Tensor:
    """Select a single action for the current observation.

    Maintains caches of past observations and of the generated action
    trajectory. Workflow:
    - cache `obs_horizon` steps of observation history
    - the diffusion model generates `pred_horizon` steps of actions
    - `num_action_steps` of those actions are actually executed

    Diagram:
    --------------------------------------------------------------
    (legend: o=obs_horizon, h=pred_horizon, a=num_action_steps)
    | timestep         | 0   | 1   | ... | o-1 | o   | ... | h-1 |
    | obs used         | yes | yes | yes | yes | no  | no  | no  |
    | action generated | yes | yes | yes | yes | yes | yes | yes |
    | action executed  | no  | no  | no  | no  | yes | yes | yes |
    --------------------------------------------------------------
    Args:
        observation: dict containing 'qpos' and 'images'
    Returns:
        action: a single (action_dim,) action
    """
    # Use the model's own device as the single source of truth and move
    # inputs onto it — avoids dragging the model back to CPU because of
    # CPU-resident observations.
    device = self._get_model_device()
    observation = self._move_to_device(observation, device)
    # Push the new observation into the history queues.
    self._populate_queues(observation)
    # When the action cache is exhausted, generate a fresh action chunk.
    if len(self._queues['action']) == 0:
        # Assemble a batch from the cached observation history.
        batch = self._prepare_observation_batch()
        # Generate the action chunk.
        actions = self.predict_action_chunk(batch)  # (1, pred_horizon, action_dim)
        # Keep only the executable slice: steps before obs_horizon-1
        # correspond to observations that already happened.
        start = self.obs_horizon - 1
        end = start + self.num_action_steps
        executable_actions = actions[:, start:end]  # (1, num_action_steps, action_dim)
        # Enqueue the executable actions one step at a time.
        for i in range(executable_actions.shape[1]):
            self._queues['action'].append(executable_actions[:, i].squeeze(0))  # (action_dim,)
    # Pop and return the next cached action.
    action = self._queues['action'].popleft()  # (action_dim,)
    return action
@torch.no_grad()
def predict_action_chunk(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
"""
预测一个动作块(用于在线推理)。
Args:
batch: 包含 'qpos''images' 的字典
- qpos: (B, obs_horizon, obs_dim)
- images: Dict[str, (B, obs_horizon, C, H, W)]
Returns:
actions: (B, pred_horizon, action_dim) 预测的动作序列
"""
return self.predict_action(batch['images'], batch['qpos'])
# ==========================
# Batch Inference (original method)
# ==========================
@torch.no_grad()
def predict_action(self, images, proprioception):
    """Predict a denormalized action sequence via DDIM sampling.

    Used for online chunk prediction and offline/batch evaluation.

    Args:
        images: dict of per-camera image observations.
        proprioception: proprioceptive observations (qpos); first
            dimension is the batch size.

    Returns:
        denormalized_actions: (B, pred_horizon, action_dim) actions in
            the robot's original (un-normalized) range.
    """
    B = proprioception.shape[0]
    # Normalize proprioception (qpos).
    proprioception = self.normalization.normalize_qpos(proprioception)
    # 1. Extract observation features (computed once, reused every step).
    per_step_cond = self._build_cond(images, proprioception)
    # Flattened conditioning for the UNet head (also computed once).
    global_cond_flat = per_step_cond.flatten(start_dim=1)
    if self.head_type == 'transformer':
        cond = per_step_cond
    else:
        cond = None
    # 2. Start from pure Gaussian noise.
    # Shape: (B, pred_horizon, action_dim)
    device = per_step_cond.device
    current_actions = torch.randn(
        (B, self.pred_horizon, self.action_dim), device=device
    )
    # 3. Step-by-step denoising loop (reverse diffusion).
    self.infer_scheduler.set_timesteps(self.inference_steps)  # DDIM step count
    for t in self.infer_scheduler.timesteps:
        model_input = current_actions
        # Predict the noise, dispatching on the head's interface.
        if self.head_type == 'transformer':
            noise_pred = self.noise_pred_net(
                sample=model_input,
                timestep=t,
                cond=cond
            )
        else:  # 'unet'
            noise_pred = self.noise_pred_net(
                sample=model_input,
                timestep=t,
                global_cond=global_cond_flat
            )
        # Remove the predicted noise; step to the previous sample.
        current_actions = self.infer_scheduler.step(
            noise_pred, t, current_actions
        ).prev_sample
    # 4. Denormalize the sampled action sequence.
    denormalized_actions = self.normalization.denormalize_action(current_actions)
    return denormalized_actions
def get_normalization_stats(self):
"""获取归一化统计信息(用于保存到 checkpoint"""
return self.normalization.get_stats()