暂存
This commit is contained in:
@@ -5,13 +5,16 @@ from typing import Dict, Optional, Any
|
|||||||
from roboimi.vla.core.interfaces import VLABackbone, VLAProjector, VLAHead
|
from roboimi.vla.core.interfaces import VLABackbone, VLAProjector, VLAHead
|
||||||
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
||||||
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
|
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
|
||||||
from roboimi.vla.models.heads.diffusion import ConditionalUnet1D
|
from roboimi.vla.models.heads.conditional_unet1d import ConditionalUnet1D
|
||||||
|
|
||||||
class VLAAgent(nn.Module):
|
class VLAAgent(nn.Module):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
vision_backbone, # 你之前定义的 ResNet 类
|
vision_backbone, # 你之前定义的 ResNet 类
|
||||||
|
state_encoder,
|
||||||
|
action_encoder,
|
||||||
|
head,
|
||||||
action_dim, # 机器人动作维度 (例如 7: xyz + rpy + gripper)
|
action_dim, # 机器人动作维度 (例如 7: xyz + rpy + gripper)
|
||||||
obs_dim, # 本体感知维度 (例如 关节角度)
|
obs_dim, # 本体感知维度 (例如 关节角度)
|
||||||
pred_horizon=16, # 预测未来多少步动作
|
pred_horizon=16, # 预测未来多少步动作
|
||||||
@@ -32,6 +35,7 @@ class VLAAgent(nn.Module):
|
|||||||
total_vision_dim = single_img_feat_dim * num_cams * obs_horizon
|
total_vision_dim = single_img_feat_dim * num_cams * obs_horizon
|
||||||
total_prop_dim = obs_dim * obs_horizon
|
total_prop_dim = obs_dim * obs_horizon
|
||||||
self.global_cond_dim = total_vision_dim + total_prop_dim
|
self.global_cond_dim = total_vision_dim + total_prop_dim
|
||||||
|
# self.global_cond_dim = total_vision_dim
|
||||||
|
|
||||||
self.noise_scheduler = DDPMScheduler(
|
self.noise_scheduler = DDPMScheduler(
|
||||||
num_train_timesteps=diffusion_steps,
|
num_train_timesteps=diffusion_steps,
|
||||||
@@ -48,11 +52,16 @@ class VLAAgent(nn.Module):
|
|||||||
prediction_type='epsilon'
|
prediction_type='epsilon'
|
||||||
)
|
)
|
||||||
|
|
||||||
self.noise_pred_net = ConditionalUnet1D(
|
self.noise_pred_net = head(
|
||||||
input_dim=action_dim,
|
input_dim=action_dim,
|
||||||
|
# input_dim = action_dim + obs_dim,
|
||||||
global_cond_dim=self.global_cond_dim
|
global_cond_dim=self.global_cond_dim
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.state_encoder = state_encoder
|
||||||
|
self.action_encoder = action_encoder
|
||||||
|
|
||||||
|
|
||||||
# ==========================
|
# ==========================
|
||||||
# 训练阶段 (Training)
|
# 训练阶段 (Training)
|
||||||
# ==========================
|
# ==========================
|
||||||
@@ -60,37 +69,35 @@ class VLAAgent(nn.Module):
|
|||||||
"""
|
"""
|
||||||
batch: 包含 images, qpos (proprioception), action
|
batch: 包含 images, qpos (proprioception), action
|
||||||
"""
|
"""
|
||||||
gt_actions = batch['action'] # Shape: (B, Horizon, Action_Dim)
|
actions, states, images = batch['action'], batch['qpos'], batch['images']
|
||||||
B = gt_actions.shape[0]
|
B = actions.shape[0]
|
||||||
images = batch['images']
|
|
||||||
proprioception = batch['qpos'].view(B, -1) # (B, obs_horizon * obs_dim)
|
|
||||||
|
|
||||||
|
state_features = self.state_encoder(states)
|
||||||
|
|
||||||
# 1. 提取视觉特征
|
# 1. 提取视觉特征
|
||||||
visual_features = self.vision_encoder(images).view(B, -1) # (B, vision_dim)
|
visual_features = self.vision_encoder(images) # (B, obs_horizon, vision_dim)
|
||||||
|
action_features = self.action_encoder(actions)
|
||||||
# 2. 融合特征 -> 全局条件 (Global Conditioning)
|
|
||||||
global_cond = torch.cat([visual_features, proprioception], dim=-1)
|
|
||||||
|
|
||||||
# 3. 采样噪声
|
# 3. 采样噪声
|
||||||
noise = torch.randn_like(gt_actions)
|
noise = torch.randn_like(action_features)
|
||||||
|
|
||||||
# 4. 随机采样时间步 (Timesteps)
|
# 4. 随机采样时间步 (Timesteps)
|
||||||
timesteps = torch.randint(
|
timesteps = torch.randint(
|
||||||
0, self.noise_scheduler.config.num_train_timesteps,
|
0, self.noise_scheduler.config.num_train_timesteps,
|
||||||
(B,), device=gt_actions.device
|
(B,), device=action_features.device
|
||||||
).long()
|
).long()
|
||||||
|
|
||||||
# 5. 给动作加噪 (Forward Diffusion)
|
# 5. 给动作加噪 (Forward Diffusion)
|
||||||
noisy_actions = self.noise_scheduler.add_noise(
|
noisy_actions = self.noise_scheduler.add_noise(
|
||||||
gt_actions, noise, timesteps
|
action_features, noise, timesteps
|
||||||
)
|
)
|
||||||
|
|
||||||
# 6. 网络预测噪声
|
# 6. 网络预测噪声
|
||||||
pred_noise = self.noise_pred_net(
|
pred_noise = self.noise_pred_net(
|
||||||
sample=noisy_actions,
|
sample=noisy_actions,
|
||||||
timestep=timesteps,
|
timestep=timesteps,
|
||||||
global_cond=global_cond
|
visual_features=visual_features,
|
||||||
|
proprioception=state_features
|
||||||
)
|
)
|
||||||
|
|
||||||
# 7. 计算 Loss (MSE)
|
# 7. 计算 Loss (MSE)
|
||||||
@@ -102,17 +109,17 @@ class VLAAgent(nn.Module):
|
|||||||
# ==========================
|
# ==========================
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def predict_action(self, images, proprioception):
|
def predict_action(self, images, proprioception):
|
||||||
B = 1 # 假设单次推理
|
B = proprioception.shape[0]
|
||||||
|
|
||||||
# 1. 提取当前观测特征 (只做一次)
|
# 1. 提取当前观测特征 (只做一次)
|
||||||
visual_features = self.vision_encoder(images).view(B, -1)
|
visual_features = self.vision_encoder(images)
|
||||||
proprioception = proprioception.view(B, -1)
|
state_features = self.state_encoder(proprioception)
|
||||||
global_cond = torch.cat([visual_features, proprioception], dim=-1)
|
|
||||||
|
|
||||||
# 2. 初始化纯高斯噪声动作
|
# 2. 初始化纯高斯噪声动作
|
||||||
# Shape: (B, pred_horizon, action_dim)
|
# Shape: (B, pred_horizon, action_dim)
|
||||||
|
device = visual_features.device
|
||||||
current_actions = torch.randn(
|
current_actions = torch.randn(
|
||||||
(B, self.pred_horizon, self.action_dim), device=global_cond.device
|
(B, self.pred_horizon, self.action_dim), device=device
|
||||||
)
|
)
|
||||||
|
|
||||||
# 3. 逐步去噪循环 (Reverse Diffusion)
|
# 3. 逐步去噪循环 (Reverse Diffusion)
|
||||||
@@ -125,7 +132,8 @@ class VLAAgent(nn.Module):
|
|||||||
noise_pred = self.noise_pred_net(
|
noise_pred = self.noise_pred_net(
|
||||||
sample=model_input,
|
sample=model_input,
|
||||||
timestep=t,
|
timestep=t,
|
||||||
global_cond=global_cond
|
visual_features=visual_features,
|
||||||
|
proprioception=state_features
|
||||||
)
|
)
|
||||||
|
|
||||||
# 移除噪声,更新 current_actions
|
# 移除噪声,更新 current_actions
|
||||||
|
|||||||
@@ -1,11 +1,12 @@
|
|||||||
# @package agent
|
# @package agent
|
||||||
_target_: roboimi.vla.agent.VLAAgent
|
defaults:
|
||||||
|
- /backbone@vision_backbone: resnet
|
||||||
|
- /modules@state_encoder: identity_state_encoder
|
||||||
|
- /modules@action_encoder: identity_action_encoder
|
||||||
|
- /head: conditional_unet1d
|
||||||
|
- _self_
|
||||||
|
|
||||||
# Vision Backbone: ResNet-18 with SpatialSoftmax
|
_target_: roboimi.vla.agent.VLAAgent
|
||||||
vision_backbone:
|
|
||||||
_target_: roboimi.vla.models.backbones.resnet.ResNetBackbone
|
|
||||||
model_name: "microsoft/resnet-18"
|
|
||||||
freeze: true
|
|
||||||
|
|
||||||
# Action and Observation Dimensions
|
# Action and Observation Dimensions
|
||||||
action_dim: 16
|
action_dim: 16
|
||||||
@@ -16,7 +17,7 @@ pred_horizon: 16
|
|||||||
obs_horizon: 2
|
obs_horizon: 2
|
||||||
|
|
||||||
# Diffusion Parameters
|
# Diffusion Parameters
|
||||||
diffusion_steps: 100 # Number of diffusion timesteps for training
|
# diffusion_steps: 100 (这些参数应该移到 head 配置中,或者通过变量传递)
|
||||||
|
|
||||||
# Camera Configuration
|
# Camera Configuration
|
||||||
num_cams: ${len:${data.camera_names}} # 自动从 data.camera_names 列表长度获取
|
num_cams: ${len:${data.camera_names}} # 自动从 data.camera_names 列表长度获取
|
||||||
@@ -1,4 +1,3 @@
|
|||||||
# @package agent.backbone
|
|
||||||
_target_: roboimi.vla.models.backbones.resnet.ResNetBackbone
|
_target_: roboimi.vla.models.backbones.resnet.ResNetBackbone
|
||||||
|
|
||||||
model_name: "microsoft/resnet-18"
|
model_name: "microsoft/resnet-18"
|
||||||
|
|||||||
5
roboimi/vla/conf/head/conditional_unet1d.yaml
Normal file
5
roboimi/vla/conf/head/conditional_unet1d.yaml
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
_target_: roboimi.vla.models.heads.conditional_unet1d.ConditionalUnet1D
|
||||||
|
_partial_: true
|
||||||
|
|
||||||
|
kernel_size: 3
|
||||||
|
cond_predict_scale: false
|
||||||
@@ -1,8 +0,0 @@
|
|||||||
_target_: roboimi.vla.models.heads.DiffusionActionHead
|
|
||||||
|
|
||||||
# 显式声明必填参数
|
|
||||||
input_dim: ??? # 等待 agent/default.yaml 填充
|
|
||||||
action_dim: 7
|
|
||||||
obs_horizon: 2
|
|
||||||
pred_horizon: 16
|
|
||||||
denoising_steps: 100
|
|
||||||
1
roboimi/vla/conf/modules/identity_action_encoder.yaml
Normal file
1
roboimi/vla/conf/modules/identity_action_encoder.yaml
Normal file
@@ -0,0 +1 @@
|
|||||||
|
_target_: roboimi.vla.modules.encoders.IdentityActionEncoder
|
||||||
1
roboimi/vla/conf/modules/identity_state_encoder.yaml
Normal file
1
roboimi/vla/conf/modules/identity_state_encoder.yaml
Normal file
@@ -0,0 +1 @@
|
|||||||
|
_target_: roboimi.vla.modules.encoders.IdentityStateEncoder
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
# # Action Head models
|
# # Action Head models
|
||||||
from .diffusion import ConditionalUnet1D
|
from .conditional_unet1d import ConditionalUnet1D
|
||||||
|
|
||||||
__all__ = ["ConditionalUnet1D"]
|
__all__ = ["ConditionalUnet1D"]
|
||||||
|
|||||||
@@ -225,14 +225,27 @@ class ConditionalUnet1D(nn.Module):
|
|||||||
def forward(self,
|
def forward(self,
|
||||||
sample: torch.Tensor,
|
sample: torch.Tensor,
|
||||||
timestep: Union[torch.Tensor, float, int],
|
timestep: Union[torch.Tensor, float, int],
|
||||||
local_cond=None, global_cond=None, **kwargs):
|
local_cond=None, global_cond=None,
|
||||||
|
visual_features=None, proprioception=None,
|
||||||
|
**kwargs):
|
||||||
"""
|
"""
|
||||||
x: (B,T,input_dim)
|
x: (B,T,input_dim)
|
||||||
timestep: (B,) or int, diffusion step
|
timestep: (B,) or int, diffusion step
|
||||||
local_cond: (B,T,local_cond_dim)
|
local_cond: (B,T,local_cond_dim)
|
||||||
global_cond: (B,global_cond_dim)
|
global_cond: (B,global_cond_dim)
|
||||||
|
visual_features: (B, T_obs, D_vis)
|
||||||
|
proprioception: (B, T_obs, D_prop)
|
||||||
output: (B,T,input_dim)
|
output: (B,T,input_dim)
|
||||||
"""
|
"""
|
||||||
|
if global_cond is None:
|
||||||
|
conds = []
|
||||||
|
if visual_features is not None:
|
||||||
|
conds.append(visual_features.flatten(start_dim=1))
|
||||||
|
if proprioception is not None:
|
||||||
|
conds.append(proprioception.flatten(start_dim=1))
|
||||||
|
if len(conds) > 0:
|
||||||
|
global_cond = torch.cat(conds, dim=-1)
|
||||||
|
|
||||||
sample = einops.rearrange(sample, 'b h t -> b t h')
|
sample = einops.rearrange(sample, 'b h t -> b t h')
|
||||||
|
|
||||||
# 1. time
|
# 1. time
|
||||||
@@ -291,4 +304,3 @@ class ConditionalUnet1D(nn.Module):
|
|||||||
|
|
||||||
x = einops.rearrange(x, 'b t h -> b h t')
|
x = einops.rearrange(x, 'b t h -> b h t')
|
||||||
return x
|
return x
|
||||||
|
|
||||||
18
roboimi/vla/modules/encoders.py
Normal file
18
roboimi/vla/modules/encoders.py
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
from torch import nn
|
||||||
|
|
||||||
|
class IdentityStateEncoder(nn.Module):
    """Pass-through state encoder.

    Has no parameters and applies no transformation: ``forward`` hands back
    the very object it was given.
    """

    def __init__(self):
        super().__init__()

    def forward(self, state):
        """Return ``state`` unchanged.

        Args:
            state: The state input; presumably a proprioception tensor
                (TODO confirm shape/dtype against the caller).

        Returns:
            The same ``state`` object, unmodified.
        """
        return state
|
||||||
|
|
||||||
|
|
||||||
|
class IdentityActionEncoder(nn.Module):
    """Pass-through action encoder.

    Has no parameters and applies no transformation: ``forward`` hands back
    the very object it was given.
    """

    def __init__(self):
        super().__init__()

    def forward(self, action):
        """Return ``action`` unchanged.

        Args:
            action: The action input; presumably an action-sequence tensor
                (TODO confirm shape/dtype against the caller).

        Returns:
            The same ``action`` object, unmodified.
        """
        return action
|
||||||
Reference in New Issue
Block a user