# VLA diffusion-policy agent: denoises robot action sequences conditioned
# on fused visual and proprioceptive observations.
from typing import Any, Dict, Optional

import torch
import torch.nn as nn
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler

from roboimi.vla.core.interfaces import VLABackbone, VLAHead, VLAProjector
from roboimi.vla.models.heads.diffusion import ConditionalUnet1D

class VLAAgent(nn.Module):
    """Vision-Language-Action agent based on a diffusion policy.

    A 1D denoising U-Net is conditioned on a global feature vector (flattened
    multi-camera visual features plus proprioception history) and trained with
    the standard DDPM epsilon-prediction objective. At inference time an
    action sequence is sampled by iterative reverse diffusion.
    """

    def __init__(
        self,
        vision_backbone,          # vision encoder exposing `.output_dim` (e.g. a ResNet wrapper)
        action_dim,               # robot action dimensionality (e.g. 7: xyz + rpy + gripper)
        obs_dim,                  # proprioception dimensionality (e.g. joint angles)
        pred_horizon=16,          # number of future action steps to predict
        obs_horizon=4,            # number of past observation steps used as context
        diffusion_steps=100,      # diffusion timesteps used during training
        num_cams=2,               # number of camera views
        num_inference_steps=10,   # denoising steps at inference (fewer than training for speed)
    ):
        super().__init__()
        self.vision_encoder = vision_backbone
        # Stored so predict_action() does not hard-code horizon/action sizes.
        self.action_dim = action_dim
        self.pred_horizon = pred_horizon
        self.obs_horizon = obs_horizon
        self.num_inference_steps = num_inference_steps

        # Global conditioning dim = flattened vision features over all
        # cameras and history steps + flattened proprioception history.
        single_img_feat_dim = self.vision_encoder.output_dim
        total_vision_dim = single_img_feat_dim * num_cams * obs_horizon
        total_prop_dim = obs_dim * obs_horizon
        self.global_cond_dim = total_vision_dim + total_prop_dim

        self.noise_scheduler = DDPMScheduler(
            num_train_timesteps=diffusion_steps,
            beta_schedule='squaredcos_cap_v2',  # schedule commonly used for robot tasks
            clip_sample=True,
            prediction_type='epsilon',          # the network predicts the added noise
        )

        self.noise_pred_net = ConditionalUnet1D(
            input_dim=action_dim,
            global_cond_dim=self.global_cond_dim,
        )

    # ==========================
    # Training
    # ==========================
    def compute_loss(self, batch):
        """DDPM training loss: MSE between sampled and predicted noise.

        batch: dict with
            'images': visual observations accepted by ``vision_encoder``
            'qpos':   proprioception, reshaped to (B, obs_horizon * obs_dim)
            'action': ground-truth actions, (B, pred_horizon, action_dim)
        Returns a scalar loss tensor.
        """
        gt_actions = batch['action']                # (B, Horizon, Action_Dim)
        B = gt_actions.shape[0]
        images = batch['images']
        proprioception = batch['qpos'].view(B, -1)  # (B, obs_horizon * obs_dim)

        # 1. Extract visual features, flattened per sample.
        visual_features = self.vision_encoder(images).view(B, -1)  # (B, vision_dim)

        # 2. Fuse features into the global conditioning vector.
        global_cond = torch.cat([visual_features, proprioception], dim=-1)

        # 3. Sample Gaussian noise to add to the actions.
        noise = torch.randn_like(gt_actions)

        # 4. Sample a random diffusion timestep per batch element.
        timesteps = torch.randint(
            0, self.noise_scheduler.config.num_train_timesteps,
            (B,), device=gt_actions.device
        ).long()

        # 5. Forward diffusion: noise the ground-truth actions.
        noisy_actions = self.noise_scheduler.add_noise(
            gt_actions, noise, timesteps
        )

        # 6. Predict the noise. The 1D U-Net is channels-first, so permute
        #    (B, T, C) -> (B, C, T) on the way in and back on the way out.
        #    (Previously the input permute was missing while the output one
        #    was applied, so pred_noise could not match `noise`'s layout.)
        #    NOTE(review): assumes ConditionalUnet1D takes/returns (B, C, T),
        #    consistent with the inference path — confirm against its forward().
        pred_noise = self.noise_pred_net(
            sample=noisy_actions.permute(0, 2, 1),
            timestep=timesteps,
            global_cond=global_cond,
        ).permute(0, 2, 1)                          # back to (B, T, C)

        # 7. Epsilon-prediction MSE loss.
        return nn.functional.mse_loss(pred_noise, noise)

    # ==========================
    # Inference
    # ==========================
    @torch.no_grad()
    def predict_action(self, images, proprioception):
        """Sample a clean action sequence by reverse diffusion.

        Returns a tensor of shape (B, pred_horizon, action_dim).
        NOTE(review): batch size is assumed to be 1 (single-env rollout);
        generalize B if batched inference is needed.
        """
        B = 1

        # 1. Encode the current observation once; the conditioning vector is
        #    fixed for the entire denoising loop.
        visual_features = self.vision_encoder(images).view(B, -1)
        proprioception = proprioception.view(B, -1)
        global_cond = torch.cat([visual_features, proprioception], dim=-1)

        # 2. Start from pure Gaussian noise.
        #    (Previously hard-coded as (B, 16, 7); now derived from config.)
        current_actions = torch.randn(
            (B, self.pred_horizon, self.action_dim), device=global_cond.device
        )

        # 3. Reverse diffusion loop with a reduced step count for speed.
        self.noise_scheduler.set_timesteps(self.num_inference_steps)

        for t in self.noise_scheduler.timesteps:
            # Channels-first for the 1D CNN, then permute back so the
            # prediction matches the scheduler's (B, T, C) sample layout.
            # (Previously the output permute was missing here.)
            noise_pred = self.noise_pred_net(
                sample=current_actions.permute(0, 2, 1),
                timestep=t,
                global_cond=global_cond,
            ).permute(0, 2, 1)

            # One reverse-diffusion step: remove predicted noise.
            current_actions = self.noise_scheduler.step(
                noise_pred, t, current_actions
            ).prev_sample

        # 4. Final denoised action sequence.
        return current_actions