暂存
This commit is contained in:
@@ -5,13 +5,16 @@ from typing import Dict, Optional, Any
|
||||
from roboimi.vla.core.interfaces import VLABackbone, VLAProjector, VLAHead
|
||||
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
||||
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
|
||||
from roboimi.vla.models.heads.diffusion import ConditionalUnet1D
|
||||
from roboimi.vla.models.heads.conditional_unet1d import ConditionalUnet1D
|
||||
|
||||
class VLAAgent(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vision_backbone, # 你之前定义的 ResNet 类
|
||||
state_encoder,
|
||||
action_encoder,
|
||||
head,
|
||||
action_dim, # 机器人动作维度 (例如 7: xyz + rpy + gripper)
|
||||
obs_dim, # 本体感知维度 (例如 关节角度)
|
||||
pred_horizon=16, # 预测未来多少步动作
|
||||
@@ -32,6 +35,7 @@ class VLAAgent(nn.Module):
|
||||
total_vision_dim = single_img_feat_dim * num_cams * obs_horizon
|
||||
total_prop_dim = obs_dim * obs_horizon
|
||||
self.global_cond_dim = total_vision_dim + total_prop_dim
|
||||
# self.global_cond_dim = total_vision_dim
|
||||
|
||||
self.noise_scheduler = DDPMScheduler(
|
||||
num_train_timesteps=diffusion_steps,
|
||||
@@ -48,11 +52,16 @@ class VLAAgent(nn.Module):
|
||||
prediction_type='epsilon'
|
||||
)
|
||||
|
||||
self.noise_pred_net = ConditionalUnet1D(
|
||||
self.noise_pred_net = head(
|
||||
input_dim=action_dim,
|
||||
# input_dim = action_dim + obs_dim,
|
||||
global_cond_dim=self.global_cond_dim
|
||||
)
|
||||
|
||||
self.state_encoder = state_encoder
|
||||
self.action_encoder = action_encoder
|
||||
|
||||
|
||||
# ==========================
|
||||
# 训练阶段 (Training)
|
||||
# ==========================
|
||||
@@ -60,37 +69,35 @@ class VLAAgent(nn.Module):
|
||||
"""
|
||||
batch: 包含 images, qpos (proprioception), action
|
||||
"""
|
||||
gt_actions = batch['action'] # Shape: (B, Horizon, Action_Dim)
|
||||
B = gt_actions.shape[0]
|
||||
images = batch['images']
|
||||
proprioception = batch['qpos'].view(B, -1) # (B, obs_horizon * obs_dim)
|
||||
actions, states, images = batch['action'], batch['qpos'], batch['images']
|
||||
B = actions.shape[0]
|
||||
|
||||
state_features = self.state_encoder(states)
|
||||
|
||||
# 1. 提取视觉特征
|
||||
visual_features = self.vision_encoder(images).view(B, -1) # (B, vision_dim)
|
||||
|
||||
# 2. 融合特征 -> 全局条件 (Global Conditioning)
|
||||
global_cond = torch.cat([visual_features, proprioception], dim=-1)
|
||||
visual_features = self.vision_encoder(images) # (B, obs_horizon, vision_dim)
|
||||
action_features = self.action_encoder(actions)
|
||||
|
||||
# 3. 采样噪声
|
||||
noise = torch.randn_like(gt_actions)
|
||||
noise = torch.randn_like(action_features)
|
||||
|
||||
# 4. 随机采样时间步 (Timesteps)
|
||||
timesteps = torch.randint(
|
||||
0, self.noise_scheduler.config.num_train_timesteps,
|
||||
(B,), device=gt_actions.device
|
||||
(B,), device=action_features.device
|
||||
).long()
|
||||
|
||||
# 5. 给动作加噪 (Forward Diffusion)
|
||||
noisy_actions = self.noise_scheduler.add_noise(
|
||||
gt_actions, noise, timesteps
|
||||
action_features, noise, timesteps
|
||||
)
|
||||
|
||||
# 6. 网络预测噪声
|
||||
pred_noise = self.noise_pred_net(
|
||||
sample=noisy_actions,
|
||||
timestep=timesteps,
|
||||
global_cond=global_cond
|
||||
visual_features=visual_features,
|
||||
proprioception=state_features
|
||||
)
|
||||
|
||||
# 7. 计算 Loss (MSE)
|
||||
@@ -102,17 +109,17 @@ class VLAAgent(nn.Module):
|
||||
# ==========================
|
||||
@torch.no_grad()
|
||||
def predict_action(self, images, proprioception):
|
||||
B = 1 # 假设单次推理
|
||||
B = proprioception.shape[0]
|
||||
|
||||
# 1. 提取当前观测特征 (只做一次)
|
||||
visual_features = self.vision_encoder(images).view(B, -1)
|
||||
proprioception = proprioception.view(B, -1)
|
||||
global_cond = torch.cat([visual_features, proprioception], dim=-1)
|
||||
visual_features = self.vision_encoder(images)
|
||||
state_features = self.state_encoder(proprioception)
|
||||
|
||||
# 2. 初始化纯高斯噪声动作
|
||||
# Shape: (B, pred_horizon, action_dim)
|
||||
device = visual_features.device
|
||||
current_actions = torch.randn(
|
||||
(B, self.pred_horizon, self.action_dim), device=global_cond.device
|
||||
(B, self.pred_horizon, self.action_dim), device=device
|
||||
)
|
||||
|
||||
# 3. 逐步去噪循环 (Reverse Diffusion)
|
||||
@@ -125,7 +132,8 @@ class VLAAgent(nn.Module):
|
||||
noise_pred = self.noise_pred_net(
|
||||
sample=model_input,
|
||||
timestep=t,
|
||||
global_cond=global_cond
|
||||
visual_features=visual_features,
|
||||
proprioception=state_features
|
||||
)
|
||||
|
||||
# 移除噪声,更新 current_actions
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
# @package agent
|
||||
_target_: roboimi.vla.agent.VLAAgent
|
||||
defaults:
|
||||
- /backbone@vision_backbone: resnet
|
||||
- /modules@state_encoder: identity_state_encoder
|
||||
- /modules@action_encoder: identity_action_encoder
|
||||
- /head: conditional_unet1d
|
||||
- _self_
|
||||
|
||||
# Vision Backbone: ResNet-18 with SpatialSoftmax
|
||||
vision_backbone:
|
||||
_target_: roboimi.vla.models.backbones.resnet.ResNetBackbone
|
||||
model_name: "microsoft/resnet-18"
|
||||
freeze: true
|
||||
_target_: roboimi.vla.agent.VLAAgent
|
||||
|
||||
# Action and Observation Dimensions
|
||||
action_dim: 16
|
||||
@@ -16,7 +17,7 @@ pred_horizon: 16
|
||||
obs_horizon: 2
|
||||
|
||||
# Diffusion Parameters
|
||||
diffusion_steps: 100 # Number of diffusion timesteps for training
|
||||
# diffusion_steps: 100 (这些参数应该移到 head 配置中,或者通过变量传递)
|
||||
|
||||
# Camera Configuration
|
||||
num_cams: ${len:${data.camera_names}} # 自动从 data.camera_names 列表长度获取
|
||||
@@ -1,4 +1,3 @@
|
||||
# @package agent.backbone
|
||||
_target_: roboimi.vla.models.backbones.resnet.ResNetBackbone
|
||||
|
||||
model_name: "microsoft/resnet-18"
|
||||
|
||||
5
roboimi/vla/conf/head/conditional_unet1d.yaml
Normal file
5
roboimi/vla/conf/head/conditional_unet1d.yaml
Normal file
@@ -0,0 +1,5 @@
|
||||
_target_: roboimi.vla.models.heads.conditional_unet1d.ConditionalUnet1D
|
||||
_partial_: true
|
||||
|
||||
kernel_size: 3
|
||||
cond_predict_scale: false
|
||||
@@ -1,8 +0,0 @@
|
||||
_target_: roboimi.vla.models.heads.DiffusionActionHead
|
||||
|
||||
# 显式声明必填参数
|
||||
input_dim: ??? # 等待 agent/default.yaml 填充
|
||||
action_dim: 7
|
||||
obs_horizon: 2
|
||||
pred_horizon: 16
|
||||
denoising_steps: 100
|
||||
1
roboimi/vla/conf/modules/identity_action_encoder.yaml
Normal file
1
roboimi/vla/conf/modules/identity_action_encoder.yaml
Normal file
@@ -0,0 +1 @@
|
||||
_target_: roboimi.vla.modules.encoders.IdentityActionEncoder
|
||||
1
roboimi/vla/conf/modules/identity_state_encoder.yaml
Normal file
1
roboimi/vla/conf/modules/identity_state_encoder.yaml
Normal file
@@ -0,0 +1 @@
|
||||
_target_: roboimi.vla.modules.encoders.IdentityStateEncoder
|
||||
@@ -1,4 +1,4 @@
|
||||
# # Action Head models
|
||||
from .diffusion import ConditionalUnet1D
|
||||
from .conditional_unet1d import ConditionalUnet1D
|
||||
|
||||
__all__ = ["ConditionalUnet1D"]
|
||||
|
||||
@@ -225,14 +225,27 @@ class ConditionalUnet1D(nn.Module):
|
||||
def forward(self,
|
||||
sample: torch.Tensor,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
local_cond=None, global_cond=None, **kwargs):
|
||||
local_cond=None, global_cond=None,
|
||||
visual_features=None, proprioception=None,
|
||||
**kwargs):
|
||||
"""
|
||||
x: (B,T,input_dim)
|
||||
timestep: (B,) or int, diffusion step
|
||||
local_cond: (B,T,local_cond_dim)
|
||||
global_cond: (B,global_cond_dim)
|
||||
visual_features: (B, T_obs, D_vis)
|
||||
proprioception: (B, T_obs, D_prop)
|
||||
output: (B,T,input_dim)
|
||||
"""
|
||||
if global_cond is None:
|
||||
conds = []
|
||||
if visual_features is not None:
|
||||
conds.append(visual_features.flatten(start_dim=1))
|
||||
if proprioception is not None:
|
||||
conds.append(proprioception.flatten(start_dim=1))
|
||||
if len(conds) > 0:
|
||||
global_cond = torch.cat(conds, dim=-1)
|
||||
|
||||
sample = einops.rearrange(sample, 'b h t -> b t h')
|
||||
|
||||
# 1. time
|
||||
@@ -291,4 +304,3 @@ class ConditionalUnet1D(nn.Module):
|
||||
|
||||
x = einops.rearrange(x, 'b t h -> b h t')
|
||||
return x
|
||||
|
||||
18
roboimi/vla/modules/encoders.py
Normal file
18
roboimi/vla/modules/encoders.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from torch import nn
|
||||
|
||||
class IdentityStateEncoder(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, state):
|
||||
return state
|
||||
|
||||
|
||||
class IdentityActionEncoder(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, action):
|
||||
return action
|
||||
Reference in New Issue
Block a user