import torch
import torch.nn as nn
from collections import deque
from typing import Dict

from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusers.schedulers.scheduling_ddim import DDIMScheduler

from roboimi.vla.models.normalization import NormalizationModule


class VLAAgent(nn.Module):
    def __init__(
        self,
        vision_backbone,               # vision encoder (e.g. ResNet)
        state_encoder,
        action_encoder,
        head,
        action_dim,                    # robot action dimension (e.g. 7: xyz + rpy + gripper)
        obs_dim,                       # proprioception dimension (e.g. joint angles)
        pred_horizon=16,               # number of future action steps to predict
        obs_horizon=4,                 # number of past observation steps to condition on
        diffusion_steps=100,           # number of DDPM noising steps (training)
        inference_steps=10,            # number of DDIM denoising steps (inference)
        num_cams=3,                    # number of camera views
        dataset_stats=None,            # dataset statistics used for normalization
        normalization_type='min_max',  # normalization type: 'gaussian' or 'min_max'
        num_action_steps=8,            # number of actions actually executed per inference
        head_type='unet',              # policy head type: 'unet' or 'transformer'
    ):
        super().__init__()

        # Save hyperparameters
        self.action_dim = action_dim
        self.obs_dim = obs_dim
        self.pred_horizon = pred_horizon
        self.obs_horizon = obs_horizon
        self.num_cams = num_cams
        self.num_action_steps = num_action_steps
        self.inference_steps = inference_steps
        self.head_type = head_type  # 'unet' or 'transformer'

        # Normalization module - unifies the normalization logic used in
        # training and inference
        self.normalization = NormalizationModule(
            stats=dataset_stats,
            normalization_type=normalization_type
        )

        self.vision_encoder = vision_backbone
        single_cam_feat_dim = self.vision_encoder.output_dim

        # global_cond_dim: total flattened conditioning dimension (UNet head)
        total_vision_dim = single_cam_feat_dim * num_cams * obs_horizon
        total_prop_dim = obs_dim * obs_horizon
        self.global_cond_dim = total_vision_dim + total_prop_dim

        # per_step_cond_dim: conditioning dimension per time step (Transformer head).
        # Note: not multiplied by obs_horizon, since the Transformer consumes
        # the conditioning as a sequence.
        self.per_step_cond_dim = single_cam_feat_dim * num_cams + obs_dim

        # DDPM scheduler for training (forward diffusion)
        self.noise_scheduler = DDPMScheduler(
            num_train_timesteps=diffusion_steps,
            beta_schedule='squaredcos_cap_v2',  # schedule commonly used for robot tasks
            clip_sample=True,
            prediction_type='epsilon'  # the network predicts the added noise
        )

        # DDIM scheduler for fast inference
        self.infer_scheduler = DDIMScheduler(
            num_train_timesteps=diffusion_steps,
            beta_schedule='squaredcos_cap_v2',
            clip_sample=True,
            prediction_type='epsilon'
        )

        # Initialize the noise-prediction network according to the head type
        if head_type == 'transformer':
            # If head is already an nn.Module instance, use it directly;
            # otherwise it still needs to be instantiated
            if isinstance(head, nn.Module):
                # Already an instantiated module (e.g. passed directly in tests)
                self.noise_pred_net = head
            else:
                # Hydra partially-initialized object; finish construction here
                self.noise_pred_net = head(
                    input_dim=action_dim,
                    output_dim=action_dim,
                    horizon=pred_horizon,
                    n_obs_steps=obs_horizon,
                    cond_dim=self.per_step_cond_dim  # per-step conditioning dim
                )
        else:  # 'unet' (default)
            # UNet interface: input_dim, global_cond_dim
            self.noise_pred_net = head(
                input_dim=action_dim,
                global_cond_dim=self.global_cond_dim
            )

        self.state_encoder = state_encoder
        self.action_encoder = action_encoder

        # Initialize the queues used for online inference
        self.reset()

    def _get_model_device(self) -> torch.device:
        """Return the device the model currently lives on."""
        return next(self.parameters()).device

    def _move_to_device(self, data, device: torch.device):
        """Recursively move tensor data to the given device."""
        if torch.is_tensor(data):
            return data.to(device)
        if isinstance(data, dict):
            return {k: self._move_to_device(v, device) for k, v in data.items()}
        if isinstance(data, list):
            return [self._move_to_device(v, device) for v in data]
        if isinstance(data, tuple):
            return tuple(self._move_to_device(v, device) for v in data)
        return data
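    # ------------------------------------------------------------------
    # Reference note (standard DDPM, matching prediction_type='epsilon'
    # above): the forward process applied by add_noise() in compute_loss
    # computes
    #     a_t = sqrt(alpha_bar_t) * a_0 + sqrt(1 - alpha_bar_t) * eps,
    # with eps ~ N(0, I), and the network is trained to recover eps, so
    # the training objective below reduces to MSE(pred_noise, noise).
    # ------------------------------------------------------------------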
    # ==========================
    # Training
    # ==========================
    def compute_loss(self, batch):
        """
        Compute the training loss.

        Args:
            batch: dict containing images, qpos (proprioception), action,
                and optionally action_is_pad.
        """
        actions, states, images = batch['action'], batch['qpos'], batch['images']
        action_is_pad = batch.get('action_is_pad', None)  # padding mask
        B = actions.shape[0]

        # Normalize states (qpos) and actions
        states = self.normalization.normalize_qpos(states)
        actions = self.normalization.normalize_action(actions)

        # 1. Encode observations and actions
        state_features = self.state_encoder(states)
        visual_features = self.vision_encoder(images)  # (B, obs_horizon, vision_dim)
        action_features = self.action_encoder(actions)

        # 2. Sample noise
        noise = torch.randn_like(action_features)

        # 3. Sample random diffusion timesteps
        timesteps = torch.randint(
            0, self.noise_scheduler.config.num_train_timesteps,
            (B,), device=action_features.device
        ).long()

        # 4. Add noise to the actions (forward diffusion)
        noisy_actions = self.noise_scheduler.add_noise(
            action_features, noise, timesteps
        )

        # Concatenate and flatten the global conditioning:
        # visual_features: (B, obs_horizon, vision_dim)
        # state_features:  (B, obs_horizon, obs_dim)
        # concatenated, then flattened to (B, obs_horizon * (vision_dim + obs_dim))
        global_cond = torch.cat([visual_features, state_features], dim=-1)
        global_cond = global_cond.flatten(start_dim=1)

        # 5. Predict the noise (interface depends on the head type)
        if self.head_type == 'transformer':
            # The Transformer expects sequence-form conditioning:
            # (B, obs_horizon, cond_dim_per_step). Reshape the flattened
            # global_cond back into sequence form.
            cond = global_cond.reshape(B, self.obs_horizon, self.per_step_cond_dim)
            pred_noise = self.noise_pred_net(
                sample=noisy_actions,
                timestep=timesteps,
                cond=cond
            )
        else:  # 'unet'
            pred_noise = self.noise_pred_net(
                sample=noisy_actions,
                timestep=timesteps,
                global_cond=global_cond
            )

        # 6. MSE loss, with optional padding mask
        loss = nn.functional.mse_loss(pred_noise, noise, reduction='none')

        # If action_is_pad is provided, mask out the padded positions
        if action_is_pad is not None:
            # action_is_pad: (B, pred_horizon); expand to (B, pred_horizon, action_dim)
            mask = (~action_is_pad).unsqueeze(-1).to(loss.dtype)  # 1.0 marks valid data
            valid_count = mask.sum() * loss.shape[-1]
            loss = (loss * mask).sum() / valid_count.clamp_min(1.0)
        else:
            loss = loss.mean()

        return loss

    # ==========================
    # Queue Management
    # ==========================
    def reset(self):
        """Clear the observation and action queues. Call on env.reset()."""
        self._queues = {
            'qpos': deque(maxlen=self.obs_horizon),
            'images': deque(maxlen=self.obs_horizon),
            # Cache of executable actions (maximum usable slice of a chunk)
            'action': deque(maxlen=self.pred_horizon - self.obs_horizon + 1),
        }

    def _populate_queues(self, observation: Dict[str, torch.Tensor]) -> None:
        """
        Append a new observation to the queues.

        Args:
            observation: dict containing 'qpos' and 'images'
        """
        # Proprioception
        if 'qpos' in observation:
            self._queues['qpos'].append(observation['qpos'].clone())

        # Images
        if 'images' in observation:
            self._queues['images'].append(
                {k: v.clone() for k, v in observation['images'].items()}
            )

    def _prepare_observation_batch(self) -> Dict[str, torch.Tensor]:
        """
        Build a batched observation from the queues for inference.

        If the queues are not yet full (on the first calls), the latest
        observation is repeated to pad the history.

        Returns:
            batch: dict containing the stacked observation history
        """
        # Stack the proprioception history
        qpos_list = list(self._queues['qpos'])
        if len(qpos_list) == 0:
            raise ValueError(
                "Observation queue is empty; call _populate_queues first"
            )

        # Pad with the latest observation if the queue is not full
        while len(qpos_list) < self.obs_horizon:
            qpos_list.append(qpos_list[-1])
        batch_qpos = torch.stack(qpos_list, dim=0).unsqueeze(0)  # (1, obs_horizon, obs_dim)

        # Stack the image history
        images_list = list(self._queues['images'])
        if len(images_list) == 0:
            raise ValueError(
                "Image queue is empty; call _populate_queues first"
            )

        # Pad with the latest observation if the queue is not full
        while len(images_list) < self.obs_horizon:
            images_list.append(images_list[-1])

        batch_images = {}
        for cam_name in images_list[0].keys():
            batch_images[cam_name] = torch.stack(
                [img[cam_name] for img in images_list], dim=0
            ).unsqueeze(0)

        return {'qpos': batch_qpos, 'images': batch_images}
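    # ------------------------------------------------------------------
    # Worked example of the receding-horizon bookkeeping used below,
    # with illustrative numbers only (the defaults: obs_horizon=4,
    # pred_horizon=16, num_action_steps=8): the model predicts actions
    # for steps 0..15; actions 0..2 line up with already-past
    # observations, so select_action() keeps the slice
    # [obs_horizon-1 : obs_horizon-1 + num_action_steps] = [3:11],
    # i.e. 8 actions, then requests a fresh chunk once they are used up.
    # ------------------------------------------------------------------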
    # ==========================
    # Online Inference
    # ==========================
    @torch.no_grad()
    def select_action(self, observation: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Select a single action given the current observation.

        This method maintains a cache of past observations and of the
        generated action trajectory. The workflow is:
        - cache `obs_horizon` steps of observation history
        - the diffusion model generates `pred_horizon` steps of actions
        - only `num_action_steps` of them are actually executed

        Schematic (legend: o = obs_horizon, h = pred_horizon, a = num_action_steps):
        ------------------------------------------------------------------
        | time step        | 0 .. o-2 | o-1 | o .. o+a-2 | o+a-1 .. h-1 |
        | observation used |   yes    | yes |     no     |      no      |
        | action generated |   yes    | yes |    yes     |     yes      |
        | action executed  |    no    | yes |    yes     |      no      |
        ------------------------------------------------------------------

        Args:
            observation: dict containing 'qpos' and 'images'

        Returns:
            action: (action_dim,) a single action
        """
        # Use the model's current device as the single source of truth and
        # move the inputs onto it; this avoids accidentally dragging the
        # model back to CPU because of CPU-resident observations.
        device = self._get_model_device()
        observation = self._move_to_device(observation, device)

        # Append the new observation to the queues
        self._populate_queues(observation)

        # If the action queue is empty, generate a new action sequence
        if len(self._queues['action']) == 0:
            # Build a batched observation from the queues
            batch = self._prepare_observation_batch()

            # Generate an action chunk
            actions = self.predict_action_chunk(batch)  # (1, pred_horizon, action_dim)

            # Extract the executable part of the chunk. Start at
            # obs_horizon - 1, since the earlier actions correspond to
            # past observations.
            start = self.obs_horizon - 1
            end = start + self.num_action_steps
            executable_actions = actions[:, start:end]  # (1, num_action_steps, action_dim)

            # Push the actions onto the queue
            for i in range(executable_actions.shape[1]):
                self._queues['action'].append(
                    executable_actions[:, i].squeeze(0)  # (action_dim,)
                )

        # Pop one action from the queue
        action = self._queues['action'].popleft()  # (action_dim,)
        return action

    @torch.no_grad()
    def predict_action_chunk(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Predict an action chunk (used for online inference).

        Args:
            batch: dict containing 'qpos' and 'images'
                - qpos: (B, obs_horizon, obs_dim)
                - images: Dict[str, (B, obs_horizon, C, H, W)]

        Returns:
            actions: (B, pred_horizon, action_dim) predicted action sequence
        """
        return self.predict_action(batch['images'], batch['qpos'])
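    # ------------------------------------------------------------------
    # Note on the DDIM speed-up (behavior of diffusers' DDIMScheduler):
    # set_timesteps(inference_steps) in predict_action() below makes the
    # sampler visit only `inference_steps` of the `diffusion_steps`
    # training timesteps (10 of 100 with the defaults), trading a small
    # amount of sample quality for ~10x fewer denoising passes.
    # ------------------------------------------------------------------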
    # ==========================
    # Batch Inference (original method)
    # ==========================
    @torch.no_grad()
    def predict_action(self, images, proprioception):
        """
        Predict action sequences for a batch (used for training and
        offline evaluation).

        Args:
            images: dict of image observations
            proprioception: proprioceptive observations (qpos)

        Returns:
            denormalized_actions: denormalized action sequence
        """
        B = proprioception.shape[0]

        # Normalize proprioception (qpos)
        proprioception = self.normalization.normalize_qpos(proprioception)

        # 1. Encode the current observations (computed once)
        visual_features = self.vision_encoder(images)
        state_features = self.state_encoder(proprioception)

        # Concatenate the conditioning (computed once):
        # visual_features: (B, obs_horizon, vision_dim)
        # state_features:  (B, obs_horizon, obs_dim)
        global_cond = torch.cat([visual_features, state_features], dim=-1)
        global_cond_flat = global_cond.flatten(start_dim=1)

        if self.head_type == 'transformer':
            cond = global_cond.reshape(B, self.obs_horizon, self.per_step_cond_dim)
        else:
            cond = None

        # 2. Start from pure Gaussian noise, shape (B, pred_horizon, action_dim)
        device = visual_features.device
        current_actions = torch.randn(
            (B, self.pred_horizon, self.action_dim), device=device
        )

        # 3. Iterative denoising loop (reverse diffusion)
        self.infer_scheduler.set_timesteps(self.inference_steps)  # DDIM steps
        for t in self.infer_scheduler.timesteps:
            # Predict the noise (interface depends on the head type)
            if self.head_type == 'transformer':
                noise_pred = self.noise_pred_net(
                    sample=current_actions,
                    timestep=t,
                    cond=cond
                )
            else:  # 'unet'
                noise_pred = self.noise_pred_net(
                    sample=current_actions,
                    timestep=t,
                    global_cond=global_cond_flat
                )

            # Remove the predicted noise and update current_actions
            current_actions = self.infer_scheduler.step(
                noise_pred, t, current_actions
            ).prev_sample

        # 4. Denormalize the action sequence
        denormalized_actions = self.normalization.denormalize_action(current_actions)
        return denormalized_actions

    def get_normalization_stats(self):
        """Return the normalization statistics (for saving to a checkpoint)."""
        return self.normalization.get_stats()
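
# ----------------------------------------------------------------------
# Usage sketch (illustrative only). The encoder and head below are
# hypothetical stand-ins that satisfy just the interfaces VLAAgent relies
# on; in the real pipeline these would be the project's ResNet backbone,
# state/action encoders, and ConditionalUnet1D, and dataset_stats would
# hold real statistics (this sketch assumes NormalizationModule tolerates
# stats=None and acts as the identity).
# ----------------------------------------------------------------------
class _DummyVisionEncoder(nn.Module):
    """Hypothetical multi-camera encoder: pools each image and projects
    the pooled RGB values to `output_dim`, concatenating cameras."""

    def __init__(self, output_dim=64):
        super().__init__()
        self.output_dim = output_dim
        self.proj = nn.Linear(3, output_dim)

    def forward(self, images):
        # images: Dict[str, (B, obs_horizon, C, H, W)]
        feats = []
        for cam in sorted(images):
            pooled = images[cam].mean(dim=(-2, -1))  # (B, obs_horizon, C)
            feats.append(self.proj(pooled))
        return torch.cat(feats, dim=-1)  # (B, obs_horizon, output_dim * num_cams)


class _DummyUNet(nn.Module):
    """Hypothetical noise-prediction head matching the 'unet' call
    signature used above: forward(sample, timestep, global_cond)."""

    def __init__(self, input_dim, global_cond_dim):
        super().__init__()
        self.net = nn.Linear(input_dim + global_cond_dim + 1, input_dim)

    def forward(self, sample, timestep, global_cond):
        B, T, _ = sample.shape
        t = timestep if torch.is_tensor(timestep) else torch.tensor(timestep)
        t = t.to(sample.device).float().view(-1, 1, 1).expand(B, T, 1)
        cond = global_cond.unsqueeze(1).expand(B, T, -1)
        return self.net(torch.cat([sample, cond, t], dim=-1))


if __name__ == '__main__':
    agent = VLAAgent(
        vision_backbone=_DummyVisionEncoder(output_dim=64),
        state_encoder=nn.Identity(),
        action_encoder=nn.Identity(),
        head=_DummyUNet,  # VLAAgent instantiates it with input_dim/global_cond_dim
        action_dim=7,
        obs_dim=7,
        num_cams=2,
        dataset_stats=None,  # assumption: None behaves as identity normalization
    )

    # Training step: compute the diffusion loss on a random batch
    B = 2
    batch = {
        'action': torch.randn(B, agent.pred_horizon, 7),
        'qpos': torch.randn(B, agent.obs_horizon, 7),
        'images': {f'cam{i}': torch.randn(B, agent.obs_horizon, 3, 96, 96)
                   for i in range(2)},
    }
    print('loss:', agent.compute_loss(batch).item())

    # Online rollout: feed per-step observations, get per-step actions
    agent.reset()
    obs = {
        'qpos': torch.randn(7),
        'images': {f'cam{i}': torch.randn(3, 96, 96) for i in range(2)},
    }
    print('action shape:', agent.select_action(obs).shape)  # torch.Size([7])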