# roboimi/roboimi/vla/agent.py
import torch
import torch.nn as nn
import numpy as np
from collections import deque
from typing import Dict, Optional, Any, Tuple
from roboimi.vla.core.interfaces import VLABackbone, VLAProjector, VLAHead
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from roboimi.vla.models.heads.conditional_unet1d import ConditionalUnet1D
from roboimi.vla.models.normalization import NormalizationModule


class VLAAgent(nn.Module):
    def __init__(
        self,
        vision_backbone,               # visual encoder (e.g. ResNet)
        state_encoder,
        action_encoder,
        head,
        action_dim,                    # robot action dimension (e.g. 7: xyz + rpy + gripper)
        obs_dim,                       # proprioception dimension (e.g. joint angles)
        pred_horizon=16,               # how many future action steps to predict
        obs_horizon=4,                 # how many past observation steps to condition on
        diffusion_steps=100,           # number of DDPM noising steps
        inference_steps=10,            # number of DDIM inference steps
        num_cams=3,                    # number of cameras in the visual input
        dataset_stats=None,            # dataset statistics used for normalization
        normalization_type='min_max',  # normalization type: 'gaussian' or 'min_max'
        num_action_steps=8,            # how many action steps are actually executed per inference
        head_type='unet',              # policy head type: 'unet' or 'transformer'
    ):
super().__init__()
        # Save configuration
self.action_dim = action_dim
self.obs_dim = obs_dim
self.pred_horizon = pred_horizon
self.obs_horizon = obs_horizon
self.num_cams = num_cams
self.num_action_steps = num_action_steps
self.inference_steps = inference_steps
        self.head_type = head_type  # 'unet' or 'transformer'
        # Normalization module: unified normalization logic for training and inference
self.normalization = NormalizationModule(
stats=dataset_stats,
normalization_type=normalization_type
)
self.vision_encoder = vision_backbone
single_cam_feat_dim = self.vision_encoder.output_dim
        # global_cond_dim: total flattened conditioning dimension (used by the UNet head)
        total_vision_dim = single_cam_feat_dim * num_cams * obs_horizon
        total_prop_dim = obs_dim * obs_horizon
        self.global_cond_dim = total_vision_dim + total_prop_dim
        # per_step_cond_dim: per-step conditioning dimension (used by the Transformer head)
        # Note: not multiplied by obs_horizon, since the Transformer consumes a sequence
        self.per_step_cond_dim = single_cam_feat_dim * num_cams + obs_dim
        self.noise_scheduler = DDPMScheduler(
            num_train_timesteps=diffusion_steps,
            beta_schedule='squaredcos_cap_v2',  # schedule commonly used for robot tasks
            clip_sample=True,
            prediction_type='epsilon'  # predict the added noise
        )
        # DDIM scheduler for fast inference
        self.infer_scheduler = DDIMScheduler(
            num_train_timesteps=diffusion_steps,
            beta_schedule='squaredcos_cap_v2',
            clip_sample=True,
            prediction_type='epsilon'
        )
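
        # Note: the DDPM scheduler defines the forward (noising) process over
        # `diffusion_steps` training timesteps; the DDIM scheduler later strides
        # over those same timesteps using only `inference_steps` steps
        # (e.g. 100 training steps sampled in 10 DDIM steps) for fast inference.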
        # Initialize the noise prediction network according to the head type
        if head_type == 'transformer':
            # If the head is already an nn.Module instance, use it directly;
            # otherwise it still needs to be constructed
            if isinstance(head, nn.Module):
                # Already an instantiated module (e.g. passed in directly in tests)
                self.noise_pred_net = head
            else:
                # Hydra partially-initialized object: call it with the remaining arguments
                self.noise_pred_net = head(
                    input_dim=action_dim,
                    output_dim=action_dim,
                    horizon=pred_horizon,
                    n_obs_steps=obs_horizon,
                    cond_dim=self.per_step_cond_dim  # per-step conditioning dimension
                )
        else:  # 'unet' (default)
            # UNet interface: input_dim, global_cond_dim
            self.noise_pred_net = head(
                input_dim=action_dim,
                global_cond_dim=self.global_cond_dim
            )
self.state_encoder = state_encoder
self.action_encoder = action_encoder
        # Initialize the queues (used for online inference)
self.reset()
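
    # A minimal construction sketch (hypothetical encoders/stats; the vision
    # backbone must expose `.output_dim`, and `head` may be an nn.Module or a
    # partial constructor, as handled above):
    #
    #   import functools
    #   agent = VLAAgent(
    #       vision_backbone=my_resnet_encoder,
    #       state_encoder=my_state_mlp,
    #       action_encoder=my_action_mlp,
    #       head=functools.partial(ConditionalUnet1D),
    #       action_dim=7,
    #       obs_dim=14,
    #       dataset_stats=dataset_stats,
    #   )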

    def _get_model_device(self) -> torch.device:
        """Return the device the model currently lives on."""
        return next(self.parameters()).device

    def _move_to_device(self, data, device: torch.device):
        """Recursively move tensor data to the given device."""
if torch.is_tensor(data):
return data.to(device)
if isinstance(data, dict):
return {k: self._move_to_device(v, device) for k, v in data.items()}
if isinstance(data, list):
return [self._move_to_device(v, device) for v in data]
if isinstance(data, tuple):
return tuple(self._move_to_device(v, device) for v in data)
return data

    # ==========================
    # Training
    # ==========================
    def compute_loss(self, batch):
        """
        Compute the training loss.
        Args:
            batch: dict with keys 'images', 'qpos' (proprioception), 'action',
                and optionally 'action_is_pad'
        """
        actions, states, images = batch['action'], batch['qpos'], batch['images']
        action_is_pad = batch.get('action_is_pad', None)  # padding mask
        B = actions.shape[0]
        # Normalize states (qpos) and actions
        states = self.normalization.normalize_qpos(states)
        actions = self.normalization.normalize_action(actions)
        state_features = self.state_encoder(states)
        # 1. Extract visual features
        visual_features = self.vision_encoder(images)  # (B, obs_horizon, vision_dim)
        action_features = self.action_encoder(actions)
        # 2. Sample noise
        noise = torch.randn_like(action_features)
        # 3. Sample random diffusion timesteps
        timesteps = torch.randint(
            0, self.noise_scheduler.config.num_train_timesteps,
            (B,), device=action_features.device
        ).long()
        # 4. Add noise to the actions (forward diffusion)
        noisy_actions = self.noise_scheduler.add_noise(
            action_features, noise, timesteps
        )
        # Concatenate the global conditioning and flatten
        # visual_features: (B, obs_horizon, vision_dim)
        # state_features:  (B, obs_horizon, obs_dim)
        # After concatenation, flatten to (B, obs_horizon * (vision_dim + obs_dim))
        global_cond = torch.cat([visual_features, state_features], dim=-1)
        global_cond = global_cond.flatten(start_dim=1)
        # 5. Predict the noise with the network, dispatching on the head type
        if self.head_type == 'transformer':
            # The Transformer expects sequence-shaped conditioning: (B, obs_horizon, per_step_cond_dim)
            # Reshape the flattened global_cond back into sequence form
            cond = global_cond.reshape(B, self.obs_horizon, self.per_step_cond_dim)
            pred_noise = self.noise_pred_net(
                sample=noisy_actions,
                timestep=timesteps,
                cond=cond
            )
        else:  # 'unet'
            pred_noise = self.noise_pred_net(
                sample=noisy_actions,
                timestep=timesteps,
                global_cond=global_cond
            )
        # 6. Compute the MSE loss, with optional padding-mask support
        loss = nn.functional.mse_loss(pred_noise, noise, reduction='none')
        # If action_is_pad is provided, mask out the padded positions
        if action_is_pad is not None:
            # action_is_pad: (B, pred_horizon); expand to (B, pred_horizon, action_dim)
            mask = (~action_is_pad).unsqueeze(-1).to(loss.dtype)  # 1.0 marks valid data
            valid_count = mask.sum() * loss.shape[-1]
            loss = (loss * mask).sum() / valid_count.clamp_min(1.0)
        else:
            loss = loss.mean()
        return loss
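
    # A minimal training-step sketch (hypothetical `dataloader` and `optimizer`;
    # the batch layout follows compute_loss above):
    #
    #   agent = VLAAgent(...).to('cuda')
    #   optimizer = torch.optim.AdamW(agent.parameters(), lr=1e-4)
    #   for batch in dataloader:
    #       batch = agent._move_to_device(batch, agent._get_model_device())
    #       loss = agent.compute_loss(batch)  # masked diffusion MSE
    #       optimizer.zero_grad()
    #       loss.backward()
    #       optimizer.step()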

    # ==========================
    # Queue Management
    # ==========================
    def reset(self):
        """Clear the observation and action queues. Call this on env.reset()."""
        self._queues = {
            'qpos': deque(maxlen=self.obs_horizon),
            'images': deque(maxlen=self.obs_horizon),
            'action': deque(maxlen=self.pred_horizon - self.obs_horizon + 1),  # cache of executable actions
        }

    def _populate_queues(self, observation: Dict[str, torch.Tensor]) -> None:
        """
        Append a new observation to the queues.
        Args:
            observation: dict with 'qpos' and 'images'
        """
        # Append proprioception
        if 'qpos' in observation:
            self._queues['qpos'].append(observation['qpos'].clone())
        # Append images
        if 'images' in observation:
            self._queues['images'].append({k: v.clone() for k, v in observation['images'].items()})

    def _prepare_observation_batch(self) -> Dict[str, torch.Tensor]:
        """
        Assemble a batched observation from the queues for inference.
        If a queue is not yet full (on the first calls), pad it by repeating
        the most recent observation.
        Returns:
            batch: dict with the stacked observation history
        """
        # Stack the proprioception history
        qpos_list = list(self._queues['qpos'])
        if len(qpos_list) == 0:
            raise ValueError("Observation queue is empty; call _populate_queues first")
        # If the queue is not full, pad with the last observation
        while len(qpos_list) < self.obs_horizon:
            qpos_list.append(qpos_list[-1])
        batch_qpos = torch.stack(qpos_list, dim=0).unsqueeze(0)  # (1, obs_horizon, obs_dim)
        # Stack the image history
        images_list = list(self._queues['images'])
        if len(images_list) == 0:
            raise ValueError("Image queue is empty; call _populate_queues first")
        # If the queue is not full, pad with the last observation
        while len(images_list) < self.obs_horizon:
            images_list.append(images_list[-1])
        batch_images = {}
        for cam_name in images_list[0].keys():
            batch_images[cam_name] = torch.stack([img[cam_name] for img in images_list], dim=0).unsqueeze(0)
        return {'qpos': batch_qpos, 'images': batch_images}
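
    # Queue-padding sketch (hypothetical shapes and camera name): with
    # obs_horizon=4 and a single observation in the queues, the last entry is
    # repeated to fill the history:
    #
    #   agent.reset()
    #   agent._populate_queues({'qpos': torch.zeros(14),
    #                           'images': {'cam_high': torch.zeros(3, 224, 224)}})
    #   batch = agent._prepare_observation_batch()
    #   batch['qpos'].shape  # torch.Size([1, 4, 14])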

    # ==========================
    # Online Inference
    # ==========================
    @torch.no_grad()
    def select_action(self, observation: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Select a single action given the current observation.
        This method maintains caches of the observation history and the
        generated action trajectory. Workflow:
        - cache `obs_horizon` steps of observation history
        - the diffusion model generates `pred_horizon` steps of actions
        - `num_action_steps` of those actions are actually executed
        Schematic:
        --------------------------------------------------------------
        (legend: o=obs_horizon, h=pred_horizon, a=num_action_steps)
        | time step        | 0   | 1   | ... | o-1 | o   | ... | h-1 |
        | observation used | yes | yes | yes | yes | no  | no  | no  |
        | action generated | yes | yes | yes | yes | yes | yes | yes |
        | action executed  | no  | no  | no  | no  | yes | yes | yes |
        --------------------------------------------------------------
        Args:
            observation: dict with 'qpos' and 'images'
        Returns:
            action: (action_dim,) a single action
        """
        # Use the model's current device as the single source of truth and move
        # the inputs to it; this avoids mistakenly moving the model back to CPU
        # just because the observations arrive on CPU.
        device = self._get_model_device()
        observation = self._move_to_device(observation, device)
        # Append the new observation to the queues
        self._populate_queues(observation)
        # If the action queue is empty, generate a new action sequence
        if len(self._queues['action']) == 0:
            # Assemble a batched observation from the queues
            batch = self._prepare_observation_batch()
            # Generate an action chunk
            actions = self.predict_action_chunk(batch)  # (1, pred_horizon, action_dim)
            # Extract the executable part of the chunk.
            # Start at obs_horizon-1, since earlier actions correspond to past observations.
            start = self.obs_horizon - 1
            end = start + self.num_action_steps
            executable_actions = actions[:, start:end]  # (1, num_action_steps, action_dim)
            # Push the actions onto the queue
            for i in range(executable_actions.shape[1]):
                self._queues['action'].append(executable_actions[:, i].squeeze(0))  # (action_dim,)
        # Pop one action from the queue
        action = self._queues['action'].popleft()  # (action_dim,)
        return action
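
    # A minimal online control-loop sketch (hypothetical gym-style `env`; the
    # observation dict must carry 'qpos' and per-camera 'images' tensors):
    #
    #   agent.eval()
    #   agent.reset()                        # clear observation/action queues
    #   obs, done = env.reset(), False
    #   while not done:
    #       action = agent.select_action(obs)            # (action_dim,)
    #       obs, reward, done, info = env.step(action.cpu().numpy())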

    @torch.no_grad()
    def predict_action_chunk(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Predict one action chunk (for online inference).
        Args:
            batch: dict with 'qpos' and 'images'
                - qpos: (B, obs_horizon, obs_dim)
                - images: Dict[str, (B, obs_horizon, C, H, W)]
        Returns:
            actions: (B, pred_horizon, action_dim) predicted action sequence
        """
        return self.predict_action(batch['images'], batch['qpos'])

    # ==========================
    # Batch Inference (original method)
    # ==========================
    @torch.no_grad()
    def predict_action(self, images, proprioception):
        """
        Predict action sequences in batch (for training and offline evaluation).
        Args:
            images: dict of image observations
            proprioception: proprioceptive observations (qpos)
        Returns:
            denormalized_actions: denormalized action sequence
        """
        B = proprioception.shape[0]
        # Normalize proprioception (qpos)
        proprioception = self.normalization.normalize_qpos(proprioception)
        # 1. Extract features from the current observations (computed only once)
        visual_features = self.vision_encoder(images)
        state_features = self.state_encoder(proprioception)
        # Concatenate the conditioning (computed only once)
        # visual_features: (B, obs_horizon, vision_dim)
        # state_features:  (B, obs_horizon, obs_dim)
        global_cond = torch.cat([visual_features, state_features], dim=-1)
        global_cond_flat = global_cond.flatten(start_dim=1)
        if self.head_type == 'transformer':
            cond = global_cond.reshape(B, self.obs_horizon, self.per_step_cond_dim)
        else:
            cond = None
        # 2. Initialize the actions as pure Gaussian noise
        # shape: (B, pred_horizon, action_dim)
        device = visual_features.device
        current_actions = torch.randn(
            (B, self.pred_horizon, self.action_dim), device=device
        )
        # 3. Iterative denoising loop (reverse diffusion)
        self.infer_scheduler.set_timesteps(self.inference_steps)  # DDIM inference steps
        for t in self.infer_scheduler.timesteps:
            model_input = current_actions
            # Predict the noise, dispatching on the head type
            if self.head_type == 'transformer':
                noise_pred = self.noise_pred_net(
                    sample=model_input,
                    timestep=t,
                    cond=cond
                )
            else:  # 'unet'
                noise_pred = self.noise_pred_net(
                    sample=model_input,
                    timestep=t,
                    global_cond=global_cond_flat
                )
            # Remove the predicted noise and update current_actions
            current_actions = self.infer_scheduler.step(
                noise_pred, t, current_actions
            ).prev_sample
        # 4. Denormalize the action sequence
        denormalized_actions = self.normalization.denormalize_action(current_actions)
        return denormalized_actions
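
    # A minimal batch-inference sketch (hypothetical shapes and camera name;
    # the vision encoder is assumed to accept this image dict):
    #
    #   qpos = torch.randn(2, agent.obs_horizon, agent.obs_dim)
    #   images = {'cam_high': torch.randn(2, agent.obs_horizon, 3, 224, 224)}
    #   actions = agent.predict_action(images, qpos)  # (2, pred_horizon, action_dim)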

    def get_normalization_stats(self):
        """Return the normalization statistics (for saving to a checkpoint)."""
        return self.normalization.get_stats()
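
# A minimal checkpoint-saving sketch (hypothetical file name; assumes the stats
# returned by get_normalization_stats are serializable by torch.save):
#
#   torch.save({
#       'state_dict': agent.state_dict(),
#       'dataset_stats': agent.get_normalization_stats(),  # needed to rebuild normalization
#   }, 'vla_agent.ckpt')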