This commit is contained in:
gouhanke
2026-02-09 14:41:35 +08:00
parent f833c6d9f1
commit 8b700b6d99
10 changed files with 76 additions and 39 deletions

View File

@@ -5,13 +5,16 @@ from typing import Dict, Optional, Any
from roboimi.vla.core.interfaces import VLABackbone, VLAProjector, VLAHead
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from roboimi.vla.models.heads.diffusion import ConditionalUnet1D
from roboimi.vla.models.heads.conditional_unet1d import ConditionalUnet1D
class VLAAgent(nn.Module):
def __init__(
self,
vision_backbone, # 你之前定义的 ResNet 类
state_encoder,
action_encoder,
head,
action_dim, # 机器人动作维度 (例如 7: xyz + rpy + gripper)
obs_dim, # 本体感知维度 (例如 关节角度)
pred_horizon=16, # 预测未来多少步动作
@@ -32,6 +35,7 @@ class VLAAgent(nn.Module):
total_vision_dim = single_img_feat_dim * num_cams * obs_horizon
total_prop_dim = obs_dim * obs_horizon
self.global_cond_dim = total_vision_dim + total_prop_dim
# self.global_cond_dim = total_vision_dim
self.noise_scheduler = DDPMScheduler(
num_train_timesteps=diffusion_steps,
@@ -48,11 +52,16 @@ class VLAAgent(nn.Module):
prediction_type='epsilon'
)
self.noise_pred_net = ConditionalUnet1D(
self.noise_pred_net = head(
input_dim=action_dim,
# input_dim = action_dim + obs_dim,
global_cond_dim=self.global_cond_dim
)
self.state_encoder = state_encoder
self.action_encoder = action_encoder
# ==========================
# 训练阶段 (Training)
# ==========================
@@ -60,37 +69,35 @@ class VLAAgent(nn.Module):
"""
batch: 包含 images, qpos (proprioception), action
"""
gt_actions = batch['action'] # Shape: (B, Horizon, Action_Dim)
B = gt_actions.shape[0]
images = batch['images']
proprioception = batch['qpos'].view(B, -1) # (B, obs_horizon * obs_dim)
actions, states, images = batch['action'], batch['qpos'], batch['images']
B = actions.shape[0]
state_features = self.state_encoder(states)
# 1. 提取视觉特征
visual_features = self.vision_encoder(images).view(B, -1) # (B, vision_dim)
# 2. 融合特征 -> 全局条件 (Global Conditioning)
global_cond = torch.cat([visual_features, proprioception], dim=-1)
visual_features = self.vision_encoder(images) # (B, obs_horizon, vision_dim)
action_features = self.action_encoder(actions)
# 3. 采样噪声
noise = torch.randn_like(gt_actions)
noise = torch.randn_like(action_features)
# 4. 随机采样时间步 (Timesteps)
timesteps = torch.randint(
0, self.noise_scheduler.config.num_train_timesteps,
(B,), device=gt_actions.device
(B,), device=action_features.device
).long()
# 5. 给动作加噪 (Forward Diffusion)
noisy_actions = self.noise_scheduler.add_noise(
gt_actions, noise, timesteps
action_features, noise, timesteps
)
# 6. 网络预测噪声
pred_noise = self.noise_pred_net(
sample=noisy_actions,
timestep=timesteps,
global_cond=global_cond
visual_features=visual_features,
proprioception=state_features
)
# 7. 计算 Loss (MSE)
@@ -102,17 +109,17 @@ class VLAAgent(nn.Module):
# ==========================
@torch.no_grad()
def predict_action(self, images, proprioception):
B = 1 # 假设单次推理
B = proprioception.shape[0]
# 1. 提取当前观测特征 (只做一次)
visual_features = self.vision_encoder(images).view(B, -1)
proprioception = proprioception.view(B, -1)
global_cond = torch.cat([visual_features, proprioception], dim=-1)
visual_features = self.vision_encoder(images)
state_features = self.state_encoder(proprioception)
# 2. 初始化纯高斯噪声动作
# Shape: (B, pred_horizon, action_dim)
device = visual_features.device
current_actions = torch.randn(
(B, self.pred_horizon, self.action_dim), device=global_cond.device
(B, self.pred_horizon, self.action_dim), device=device
)
# 3. 逐步去噪循环 (Reverse Diffusion)
@@ -125,7 +132,8 @@ class VLAAgent(nn.Module):
noise_pred = self.noise_pred_net(
sample=model_input,
timestep=t,
global_cond=global_cond
visual_features=visual_features,
proprioception=state_features
)
# 移除噪声,更新 current_actions

View File

@@ -1,11 +1,12 @@
# @package agent
_target_: roboimi.vla.agent.VLAAgent
defaults:
- /backbone@vision_backbone: resnet
- /modules@state_encoder: identity_state_encoder
- /modules@action_encoder: identity_action_encoder
- /head: conditional_unet1d
- _self_
# Vision Backbone: ResNet-18 with SpatialSoftmax
vision_backbone:
_target_: roboimi.vla.models.backbones.resnet.ResNetBackbone
model_name: "microsoft/resnet-18"
freeze: true
_target_: roboimi.vla.agent.VLAAgent
# Action and Observation Dimensions
action_dim: 16
@@ -16,7 +17,7 @@ pred_horizon: 16
obs_horizon: 2
# Diffusion Parameters
diffusion_steps: 100 # Number of diffusion timesteps for training
# diffusion_steps: 100 (这些参数应该移到 head 配置中,或者通过变量传递)
# Camera Configuration
num_cams: ${len:${data.camera_names}} # 自动从 data.camera_names 列表长度获取

View File

@@ -1,4 +1,3 @@
# @package agent.backbone
_target_: roboimi.vla.models.backbones.resnet.ResNetBackbone
model_name: "microsoft/resnet-18"

View File

@@ -0,0 +1,5 @@
# Hydra config for the diffusion action head (1D conditional U-Net).
_target_: roboimi.vla.models.heads.conditional_unet1d.ConditionalUnet1D
# Partial instantiation: the agent finishes construction by supplying
# input_dim / global_cond_dim at runtime — presumably via `head(...)`;
# NOTE(review): confirm against the agent's __init__.
_partial_: true
# Convolution kernel size for the U-Net's 1D conv blocks.
kernel_size: 3
# When false, conditioning does not predict per-channel scale (FiLM-style
# scale modulation disabled) — TODO confirm against ConditionalUnet1D.
cond_predict_scale: false

View File

@@ -1,8 +0,0 @@
_target_: roboimi.vla.models.heads.DiffusionActionHead
# 显式声明必填参数
input_dim: ??? # 等待 agent/default.yaml 填充
action_dim: 7
obs_horizon: 2
pred_horizon: 16
denoising_steps: 100

View File

@@ -0,0 +1 @@
_target_: roboimi.vla.modules.encoders.IdentityActionEncoder

View File

@@ -0,0 +1 @@
_target_: roboimi.vla.modules.encoders.IdentityStateEncoder

View File

@@ -1,4 +1,4 @@
# # Action Head models
from .diffusion import ConditionalUnet1D
from .conditional_unet1d import ConditionalUnet1D
__all__ = ["ConditionalUnet1D"]

View File

@@ -225,14 +225,27 @@ class ConditionalUnet1D(nn.Module):
def forward(self,
sample: torch.Tensor,
timestep: Union[torch.Tensor, float, int],
local_cond=None, global_cond=None, **kwargs):
local_cond=None, global_cond=None,
visual_features=None, proprioception=None,
**kwargs):
"""
x: (B,T,input_dim)
timestep: (B,) or int, diffusion step
local_cond: (B,T,local_cond_dim)
global_cond: (B,global_cond_dim)
visual_features: (B, T_obs, D_vis)
proprioception: (B, T_obs, D_prop)
output: (B,T,input_dim)
"""
if global_cond is None:
conds = []
if visual_features is not None:
conds.append(visual_features.flatten(start_dim=1))
if proprioception is not None:
conds.append(proprioception.flatten(start_dim=1))
if len(conds) > 0:
global_cond = torch.cat(conds, dim=-1)
sample = einops.rearrange(sample, 'b h t -> b t h')
# 1. time
@@ -291,4 +304,3 @@ class ConditionalUnet1D(nn.Module):
x = einops.rearrange(x, 'b t h -> b h t')
return x

View File

@@ -0,0 +1,18 @@
from torch import nn
class IdentityStateEncoder(nn.Module):
    """No-op state encoder: forwards its input unchanged.

    Serves as a plug-in default so the agent can always be composed with
    a state-encoder module even when no learned encoding is desired.
    """

    # No custom __init__ needed: nn.Module's constructor suffices,
    # since this module holds no parameters or buffers.

    def forward(self, state):
        # Identity mapping — the input is returned as-is.
        return state
class IdentityActionEncoder(nn.Module):
    """No-op action encoder: forwards its input unchanged.

    Counterpart to the identity state encoder — keeps the agent's
    encoder interface uniform when raw actions are used directly.
    """

    # nn.Module's default constructor is sufficient: this module has
    # no learnable parameters.

    def forward(self, action):
        # Identity mapping — the input is returned as-is.
        return action