402 lines
16 KiB
Python
402 lines
16 KiB
Python
import torch
|
||
import torch.nn as nn
|
||
import numpy as np
|
||
from collections import deque
|
||
from typing import Dict, Optional, Any, Tuple
|
||
from roboimi.vla.core.interfaces import VLABackbone, VLAProjector, VLAHead
|
||
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
||
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
|
||
from roboimi.vla.models.heads.conditional_unet1d import ConditionalUnet1D
|
||
from roboimi.vla.models.normalization import NormalizationModule
|
||
|
||
class VLAAgent(nn.Module):
    """Diffusion-policy VLA agent.

    Encodes multi-camera images and proprioception into a conditioning
    vector, then denoises an action sequence with a UNet or Transformer
    noise-prediction head (DDPM for training, DDIM for fast inference).
    """

    def __init__(
        self,
        vision_backbone,        # vision encoder (e.g. ResNet)
        state_encoder,          # encoder for proprioceptive state (qpos)
        action_encoder,         # encoder applied to actions before diffusion
        head,                   # noise-prediction net (nn.Module instance or Hydra partial)
        action_dim,             # robot action dimension (e.g. 7: xyz + rpy + gripper)
        obs_dim,                # proprioception dimension (e.g. joint angles)
        pred_horizon=16,        # how many future action steps to predict
        obs_horizon=4,          # how many past observation steps to use
        diffusion_steps=100,    # number of DDPM noising steps
        inference_steps=10,     # number of DDIM inference steps
        num_cams=3,             # number of camera views in the visual input
        dataset_stats=None,     # dataset statistics used for normalization
        normalization_type='min_max',  # normalization type: 'gaussian' or 'min_max'
        num_action_steps=8,     # how many predicted steps are actually executed per inference
        head_type='unet',       # policy-head type: 'unet' or 'transformer'
    ):
        super().__init__()
        # Store hyperparameters.
        self.action_dim = action_dim
        self.obs_dim = obs_dim
        self.pred_horizon = pred_horizon
        self.obs_horizon = obs_horizon
        self.num_cams = num_cams
        self.num_action_steps = num_action_steps
        self.inference_steps = inference_steps
        self.head_type = head_type  # 'unet' or 'transformer'

        # Normalization module — unifies the train-time and inference-time
        # normalization logic.
        self.normalization = NormalizationModule(
            stats=dataset_stats,
            normalization_type=normalization_type
        )

        self.vision_encoder = vision_backbone
        single_cam_feat_dim = self.vision_encoder.output_dim
        # global_cond_dim: total flattened conditioning dimension (UNet head).
        total_vision_dim = single_cam_feat_dim * num_cams * obs_horizon
        total_prop_dim = obs_dim * obs_horizon
        self.global_cond_dim = total_vision_dim + total_prop_dim

        # per_step_cond_dim: conditioning dimension per observation step
        # (Transformer head). Not multiplied by obs_horizon because the
        # Transformer consumes the condition as a sequence.
        self.per_step_cond_dim = single_cam_feat_dim * num_cams + obs_dim

        self.noise_scheduler = DDPMScheduler(
            num_train_timesteps=diffusion_steps,
            beta_schedule='squaredcos_cap_v2',  # schedule commonly used for robot tasks
            clip_sample=True,
            prediction_type='epsilon'  # the network predicts the added noise
        )

        # DDIM scheduler used for fast inference.
        self.infer_scheduler = DDIMScheduler(
            num_train_timesteps=diffusion_steps,
            beta_schedule='squaredcos_cap_v2',
            clip_sample=True,
            prediction_type='epsilon'
        )

        # Initialize the head depending on its type.
        if head_type == 'transformer':
            # If the head is already an nn.Module instance use it directly,
            # otherwise it still needs to be constructed.
            if isinstance(head, nn.Module):
                # Already an instantiated module (e.g. passed directly in tests).
                self.noise_pred_net = head
            else:
                # Hydra partially-initialized object — call with the remaining args.
                self.noise_pred_net = head(
                    input_dim=action_dim,
                    output_dim=action_dim,
                    horizon=pred_horizon,
                    n_obs_steps=obs_horizon,
                    cond_dim=self.per_step_cond_dim  # per-step conditioning dimension
                )
        else:  # 'unet' (default)
            # UNet interface: input_dim, global_cond_dim.
            self.noise_pred_net = head(
                input_dim=action_dim,
                global_cond_dim=self.global_cond_dim
            )

        self.state_encoder = state_encoder
        self.action_encoder = action_encoder

        # Initialize the queues (used for online inference).
        self.reset()
|
||
|
||
def _get_model_device(self) -> torch.device:
|
||
"""获取模型当前所在设备。"""
|
||
return next(self.parameters()).device
|
||
|
||
def _move_to_device(self, data, device: torch.device):
|
||
"""递归地将张量数据移动到指定设备。"""
|
||
if torch.is_tensor(data):
|
||
return data.to(device)
|
||
if isinstance(data, dict):
|
||
return {k: self._move_to_device(v, device) for k, v in data.items()}
|
||
if isinstance(data, list):
|
||
return [self._move_to_device(v, device) for v in data]
|
||
if isinstance(data, tuple):
|
||
return tuple(self._move_to_device(v, device) for v in data)
|
||
return data
|
||
|
||
|
||
# ==========================
|
||
# 训练阶段 (Training)
|
||
# ==========================
|
||
    def compute_loss(self, batch):
        """Compute the diffusion training loss.

        Args:
            batch: dict containing 'action', 'qpos' (proprioception),
                'images', and optionally 'action_is_pad' (bool padding mask).

        Returns:
            Scalar MSE loss between predicted and true noise, averaged over
            valid (non-padded) action elements.
        """
        actions, states, images = batch['action'], batch['qpos'], batch['images']
        action_is_pad = batch.get('action_is_pad', None)  # optional padding mask
        B = actions.shape[0]

        # Normalize states (qpos) and actions.
        states = self.normalization.normalize_qpos(states)
        actions = self.normalization.normalize_action(actions)

        state_features = self.state_encoder(states)

        # 1. Extract visual features.
        visual_features = self.vision_encoder(images)  # (B, obs_horizon, vision_dim)
        action_features = self.action_encoder(actions)

        # 2. Sample noise.
        noise = torch.randn_like(action_features)

        # 3. Sample random diffusion timesteps.
        timesteps = torch.randint(
            0, self.noise_scheduler.config.num_train_timesteps,
            (B,), device=action_features.device
        ).long()

        # 4. Add noise to the actions (forward diffusion).
        noisy_actions = self.noise_scheduler.add_noise(
            action_features, noise, timesteps
        )

        # Concatenate the global condition and flatten it.
        # visual_features: (B, obs_horizon, vision_dim)
        # state_features: (B, obs_horizon, obs_dim)
        # After concatenation, flatten to (B, obs_horizon * (vision_dim + obs_dim)).
        global_cond = torch.cat([visual_features, state_features], dim=-1)
        global_cond = global_cond.flatten(start_dim=1)

        # 5. Predict the noise (interface depends on head type).
        if self.head_type == 'transformer':
            # The Transformer expects sequence-shaped conditioning:
            # (B, obs_horizon, cond_dim_per_step) — reshape the flat cond back.
            cond = global_cond.reshape(B, self.obs_horizon, self.per_step_cond_dim)
            pred_noise = self.noise_pred_net(
                sample=noisy_actions,
                timestep=timesteps,
                cond=cond
            )
        else:  # 'unet'
            pred_noise = self.noise_pred_net(
                sample=noisy_actions,
                timestep=timesteps,
                global_cond=global_cond
            )

        # 6. Compute the MSE loss, with optional padding-mask support.
        loss = nn.functional.mse_loss(pred_noise, noise, reduction='none')

        # If action_is_pad is provided, mask out padded positions.
        if action_is_pad is not None:
            # action_is_pad: (B, pred_horizon) -> expand to (B, pred_horizon, action_dim).
            mask = (~action_is_pad).unsqueeze(-1).to(loss.dtype)  # 1.0 marks valid data
            valid_count = mask.sum() * loss.shape[-1]
            # clamp_min guards against division by zero when everything is padded.
            loss = (loss * mask).sum() / valid_count.clamp_min(1.0)
        else:
            loss = loss.mean()

        return loss
|
||
|
||
# ==========================
|
||
# 队列管理 (Queue Management)
|
||
# ==========================
|
||
def reset(self):
|
||
"""清空观测和动作队列。应在 env.reset() 时调用"""
|
||
self._queues = {
|
||
'qpos': deque(maxlen=self.obs_horizon),
|
||
'images': deque(maxlen=self.obs_horizon),
|
||
'action': deque(maxlen=self.pred_horizon - self.obs_horizon + 1), # 可执行的动作缓存
|
||
}
|
||
|
||
def _populate_queues(self, observation: Dict[str, torch.Tensor]) -> None:
|
||
"""
|
||
将新的观测添加到队列中。
|
||
|
||
Args:
|
||
observation: 包含 'qpos' 和 'images' 的字典
|
||
"""
|
||
# 添加本体感知
|
||
if 'qpos' in observation:
|
||
self._queues['qpos'].append(observation['qpos'].clone())
|
||
|
||
# 添加图像
|
||
if 'images' in observation:
|
||
self._queues['images'].append({k: v.clone() for k, v in observation['images'].items()})
|
||
|
||
def _prepare_observation_batch(self) -> Dict[str, torch.Tensor]:
|
||
"""
|
||
从队列中准备用于推理的批量观测。
|
||
如果队列未满(首次调用时),用最新观测重复填充。
|
||
|
||
Returns:
|
||
batch: 包含堆叠后的历史观测的字典
|
||
"""
|
||
# 堆叠历史本体感知
|
||
qpos_list = list(self._queues['qpos'])
|
||
if len(qpos_list) == 0:
|
||
raise ValueError("观测队列为空,请先调用 _populate_queues 添加观测")
|
||
# 如果队列未满,用最后一个观测填充
|
||
while len(qpos_list) < self.obs_horizon:
|
||
qpos_list.append(qpos_list[-1])
|
||
batch_qpos = torch.stack(qpos_list, dim=0).unsqueeze(0) # (1, obs_horizon, obs_dim)
|
||
|
||
# 堆叠历史图像
|
||
images_list = list(self._queues['images'])
|
||
if len(images_list) == 0:
|
||
raise ValueError("图像队列为空,请先调用 _populate_queues 添加观测")
|
||
# 如果队列未满,用最后一个观测填充
|
||
while len(images_list) < self.obs_horizon:
|
||
images_list.append(images_list[-1])
|
||
|
||
batch_images = {}
|
||
for cam_name in images_list[0].keys():
|
||
batch_images[cam_name] = torch.stack([img[cam_name] for img in images_list], dim=0).unsqueeze(0)
|
||
|
||
return {'qpos': batch_qpos, 'images': batch_images}
|
||
|
||
# ==========================
|
||
# 在线推理 (Online Inference)
|
||
# ==========================
|
||
    @torch.no_grad()
    def select_action(self, observation: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Select a single action for the current observation.

        Maintains caches of past observations and of the generated action
        trajectory. Workflow:
        - cache `obs_horizon` steps of observation history
        - the diffusion model generates `pred_horizon` steps of actions
        - only `num_action_steps` of them are actually executed

        Diagram:
        --------------------------------------------------------------
        (legend: o=obs_horizon, h=pred_horizon, a=num_action_steps)
        |timestep         | 0 | 1 | ... | o-1 | o | ... | h-1 |
        |obs used         | yes | yes | yes | yes | no | no | no |
        |action generated | yes | yes | yes | yes | yes | yes | yes |
        |action executed  | no | no | no | no | yes | yes | yes |
        --------------------------------------------------------------

        Args:
            observation: dict containing 'qpos' and 'images'

        Returns:
            action: (action_dim,) a single action
        """
        # Use the model's current device as the single source of truth and
        # move the inputs onto it — never move the model back to CPU just
        # because the observations happen to arrive on CPU.
        device = self._get_model_device()
        observation = self._move_to_device(observation, device)

        # Push the new observation into the queues.
        self._populate_queues(observation)

        # If the action queue is empty, generate a new action sequence.
        if len(self._queues['action']) == 0:
            # Build a batched observation from the queues.
            batch = self._prepare_observation_batch()

            # Generate an action chunk.
            actions = self.predict_action_chunk(batch)  # (1, pred_horizon, action_dim)

            # Slice out the executable part of the chunk: start at
            # obs_horizon-1 because earlier steps correspond to past
            # observations.
            start = self.obs_horizon - 1
            end = start + self.num_action_steps
            executable_actions = actions[:, start:end]  # (1, num_action_steps, action_dim)

            # Enqueue the actions for execution.
            for i in range(executable_actions.shape[1]):
                self._queues['action'].append(executable_actions[:, i].squeeze(0))  # (action_dim,)

        # Pop one action from the queue.
        action = self._queues['action'].popleft()  # (action_dim,)

        return action
|
||
|
||
@torch.no_grad()
|
||
def predict_action_chunk(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
|
||
"""
|
||
预测一个动作块(用于在线推理)。
|
||
|
||
Args:
|
||
batch: 包含 'qpos' 和 'images' 的字典
|
||
- qpos: (B, obs_horizon, obs_dim)
|
||
- images: Dict[str, (B, obs_horizon, C, H, W)]
|
||
|
||
Returns:
|
||
actions: (B, pred_horizon, action_dim) 预测的动作序列
|
||
"""
|
||
return self.predict_action(batch['images'], batch['qpos'])
|
||
|
||
# ==========================
|
||
# 批量推理 (Batch Inference - 原有方法)
|
||
# ==========================
|
||
    @torch.no_grad()
    def predict_action(self, images, proprioception):
        """Batch-predict an action sequence (training and offline evaluation).

        Args:
            images: dict of image observations
            proprioception: proprioceptive observations (qpos)

        Returns:
            denormalized_actions: denormalized action sequence sampled via
                DDIM denoising (starts from a (B, pred_horizon, action_dim)
                Gaussian sample; assumes denormalize_action is
                shape-preserving — TODO confirm)
        """
        B = proprioception.shape[0]

        # Normalize proprioception (qpos).
        proprioception = self.normalization.normalize_qpos(proprioception)

        # 1. Extract features of the current observation (computed only once).
        visual_features = self.vision_encoder(images)
        state_features = self.state_encoder(proprioception)

        # Concatenate the condition (computed only once).
        # visual_features: (B, obs_horizon, vision_dim)
        # state_features: (B, obs_horizon, obs_dim)
        global_cond = torch.cat([visual_features, state_features], dim=-1)
        global_cond_flat = global_cond.flatten(start_dim=1)
        if self.head_type == 'transformer':
            cond = global_cond.reshape(B, self.obs_horizon, self.per_step_cond_dim)
        else:
            cond = None

        # 2. Initialize the actions as pure Gaussian noise.
        # Shape: (B, pred_horizon, action_dim)
        device = visual_features.device
        current_actions = torch.randn(
            (B, self.pred_horizon, self.action_dim), device=device
        )

        # 3. Iterative denoising loop (reverse diffusion).
        self.infer_scheduler.set_timesteps(self.inference_steps)  # DDIM inference steps

        for t in self.infer_scheduler.timesteps:
            model_input = current_actions

            # Predict the noise (interface depends on head type).
            if self.head_type == 'transformer':
                noise_pred = self.noise_pred_net(
                    sample=model_input,
                    timestep=t,
                    cond=cond
                )
            else:  # 'unet'
                noise_pred = self.noise_pred_net(
                    sample=model_input,
                    timestep=t,
                    global_cond=global_cond_flat
                )

            # Remove the predicted noise and update current_actions.
            current_actions = self.infer_scheduler.step(
                noise_pred, t, current_actions
            ).prev_sample

        # 4. Denormalize the action sequence.
        denormalized_actions = self.normalization.denormalize_action(current_actions)

        return denormalized_actions
|
||
|
||
def get_normalization_stats(self):
|
||
"""获取归一化统计信息(用于保存到 checkpoint)"""
|
||
return self.normalization.get_stats()
|