Files
roboimi/roboimi/vla/agent.py
2026-02-05 14:08:43 +08:00

128 lines
4.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import torch
import torch.nn as nn
from typing import Dict, Optional, Any
from roboimi.vla.core.interfaces import VLABackbone, VLAProjector, VLAHead
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from roboimi.vla.models.heads.diffusion import ConditionalUnet1D
class VLAAgent(nn.Module):
    """Vision-Language-Action agent driven by a conditional 1-D diffusion model.

    Multi-camera images and proprioception are flattened and concatenated into
    a global conditioning vector; a ``ConditionalUnet1D`` is trained with the
    DDPM epsilon-prediction objective to denoise action trajectories, and
    inference samples an action sequence by reverse diffusion from Gaussian
    noise.
    """

    def __init__(
        self,
        vision_backbone,      # image encoder exposing `output_dim` (e.g. a ResNet wrapper)
        action_dim,           # robot action dimension (e.g. 7: xyz + rpy + gripper)
        obs_dim,              # proprioception dimension (e.g. joint angles)
        pred_horizon=16,      # number of future action steps to predict
        obs_horizon=4,        # number of past observation steps used
        diffusion_steps=100,  # training diffusion timesteps
        num_cams=2,           # number of camera views
    ):
        super().__init__()
        self.vision_encoder = vision_backbone
        # Store the trajectory hyper-parameters so inference does not have to
        # hard-code them (the original hard-coded (16, 7) in predict_action).
        self.action_dim = action_dim
        self.pred_horizon = pred_horizon
        self.obs_horizon = obs_horizon

        # Conditioning vector size: per-image feature * cameras * history steps,
        # plus flattened proprioception history.
        single_img_feat_dim = self.vision_encoder.output_dim
        total_vision_dim = single_img_feat_dim * num_cams * obs_horizon
        total_prop_dim = obs_dim * obs_horizon
        self.global_cond_dim = total_vision_dim + total_prop_dim

        self.noise_scheduler = DDPMScheduler(
            num_train_timesteps=diffusion_steps,
            beta_schedule='squaredcos_cap_v2',  # schedule commonly used for robot tasks
            clip_sample=True,
            prediction_type='epsilon'  # the network predicts the injected noise
        )
        self.noise_pred_net = ConditionalUnet1D(
            input_dim=action_dim,
            global_cond_dim=self.global_cond_dim
        )

    # ==========================
    # Training
    # ==========================
    def compute_loss(self, batch):
        """Compute the DDPM epsilon-prediction MSE loss for one batch.

        Args:
            batch: dict with
                'images': camera observations consumed by the vision encoder,
                'qpos':   proprioception, reshaped to (B, obs_horizon * obs_dim),
                'action': ground-truth actions, shape (B, pred_horizon, action_dim).

        Returns:
            Scalar MSE loss between injected and predicted noise.
        """
        gt_actions = batch['action']  # (B, Horizon, Action_Dim)
        B = gt_actions.shape[0]
        images = batch['images']
        proprioception = batch['qpos'].view(B, -1)  # (B, obs_horizon * obs_dim)

        # 1. Extract visual features, flattened per sample.
        visual_features = self.vision_encoder(images).view(B, -1)  # (B, vision_dim)

        # 2. Fuse modalities into the global conditioning vector.
        global_cond = torch.cat([visual_features, proprioception], dim=-1)

        # 3. Sample the Gaussian noise to inject.
        noise = torch.randn_like(gt_actions)

        # 4. Sample a random diffusion timestep per batch element.
        timesteps = torch.randint(
            0, self.noise_scheduler.config.num_train_timesteps,
            (B,), device=gt_actions.device
        ).long()

        # 5. Forward diffusion: corrupt the clean actions at those timesteps.
        noisy_actions = self.noise_scheduler.add_noise(
            gt_actions, noise, timesteps
        )

        # 6. Predict the injected noise.
        # NOTE(review): this assumes ConditionalUnet1D consumes and returns
        # (B, T, action_dim), as in the Diffusion Policy reference
        # implementation — confirm against roboimi's ConditionalUnet1D.
        # The original code permuted only the OUTPUT here (giving a (B, C, T)
        # vs (B, T, C) mismatch in the MSE) while permuting only the INPUT at
        # inference, so training and inference disagreed; both paths now use
        # one consistent (B, T, C) layout.
        pred_noise = self.noise_pred_net(
            sample=noisy_actions,
            timestep=timesteps,
            global_cond=global_cond
        )

        # 7. Epsilon-prediction objective: MSE against the injected noise.
        loss = nn.functional.mse_loss(pred_noise, noise)
        return loss

    # ==========================
    # Inference
    # ==========================
    @torch.no_grad()
    def predict_action(self, images, proprioception, num_inference_steps=10):
        """Sample a denoised action trajectory for the current observation.

        Args:
            images: camera observations for a single inference step (batch 1).
            proprioception: robot state, reshaped to (1, obs_horizon * obs_dim).
            num_inference_steps: reverse-diffusion steps; fewer steps trade
                quality for speed (default 10, matching the original).

        Returns:
            Tensor of shape (1, pred_horizon, action_dim): the denoised actions.
        """
        B = 1  # single-sample inference

        # 1. Encode the current observation once; it conditions every step.
        visual_features = self.vision_encoder(images).view(B, -1)
        proprioception = proprioception.view(B, -1)
        global_cond = torch.cat([visual_features, proprioception], dim=-1)

        # 2. Start from pure Gaussian noise, shaped from the stored
        #    hyper-parameters instead of the original hard-coded (16, 7).
        current_actions = torch.randn(
            (B, self.pred_horizon, self.action_dim), device=global_cond.device
        )

        # 3. Reverse diffusion loop.
        self.noise_scheduler.set_timesteps(num_inference_steps)
        for t in self.noise_scheduler.timesteps:
            # Same (B, T, C) layout as training — see the NOTE in compute_loss.
            noise_pred = self.noise_pred_net(
                sample=current_actions,
                timestep=t,
                global_cond=global_cond
            )
            # Remove the predicted noise to obtain the previous sample.
            current_actions = self.noise_scheduler.step(
                noise_pred, t, current_actions
            ).prev_sample

        # 4. Final denoised action sequence.
        return current_actions