From 8b700b6d99df3029f33d2b74a8f5cdfa90f16cf9 Mon Sep 17 00:00:00 2001
From: gouhanke <12219217+gouhanke@user.noreply.gitee.com>
Date: Mon, 9 Feb 2026 14:41:35 +0800
Subject: [PATCH] WIP
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 roboimi/vla/agent.py                          | 48 +++++++++++--------
 roboimi/vla/conf/agent/resnet_diffusion.yaml  | 15 +++---
 roboimi/vla/conf/backbone/resnet.yaml         |  1 -
 roboimi/vla/conf/head/conditional_unet1d.yaml |  5 ++
 roboimi/vla/conf/head/diffusion.yaml          |  8 ----
 .../conf/modules/identity_action_encoder.yaml |  1 +
 .../conf/modules/identity_state_encoder.yaml  |  1 +
 roboimi/vla/models/heads/__init__.py          |  2 +-
 .../{diffusion.py => conditional_unet1d.py}   | 16 ++++++-
 roboimi/vla/modules/encoders.py               | 18 +++++++
 10 files changed, 76 insertions(+), 39 deletions(-)
 create mode 100644 roboimi/vla/conf/head/conditional_unet1d.yaml
 delete mode 100644 roboimi/vla/conf/head/diffusion.yaml
 create mode 100644 roboimi/vla/conf/modules/identity_action_encoder.yaml
 create mode 100644 roboimi/vla/conf/modules/identity_state_encoder.yaml
 rename roboimi/vla/models/heads/{diffusion.py => conditional_unet1d.py} (94%)
 create mode 100644 roboimi/vla/modules/encoders.py

diff --git a/roboimi/vla/agent.py b/roboimi/vla/agent.py
index ac1371e..81ae588 100644
--- a/roboimi/vla/agent.py
+++ b/roboimi/vla/agent.py
@@ -5,13 +5,16 @@ from typing import Dict, Optional, Any
 from roboimi.vla.core.interfaces import VLABackbone, VLAProjector, VLAHead
 from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
 from diffusers.schedulers.scheduling_ddim import DDIMScheduler
-from roboimi.vla.models.heads.diffusion import ConditionalUnet1D
+from roboimi.vla.models.heads.conditional_unet1d import ConditionalUnet1D
 
 
 class VLAAgent(nn.Module):
     def __init__(
         self,
         vision_backbone,        # the ResNet backbone class defined earlier
+        state_encoder,
+        action_encoder,
+        head,
         action_dim,             # robot action dimension (e.g. 7: xyz + rpy + gripper)
         obs_dim,                # proprioception dimension (e.g. joint angles)
         pred_horizon=16,        # number of future action steps to predict
@@ -32,6 +35,7 @@ class VLAAgent(nn.Module):
         total_vision_dim = single_img_feat_dim * num_cams * obs_horizon
         total_prop_dim = obs_dim * obs_horizon
         self.global_cond_dim = total_vision_dim + total_prop_dim
+        # self.global_cond_dim = total_vision_dim
 
         self.noise_scheduler = DDPMScheduler(
             num_train_timesteps=diffusion_steps,
@@ -48,11 +52,16 @@ class VLAAgent(nn.Module):
             prediction_type='epsilon'
         )
 
-        self.noise_pred_net = ConditionalUnet1D(
+        self.noise_pred_net = head(
             input_dim=action_dim,
+            # input_dim = action_dim + obs_dim,
             global_cond_dim=self.global_cond_dim
         )
 
+        self.state_encoder = state_encoder
+        self.action_encoder = action_encoder
+
+
     # ==========================
     # Training phase
     # ==========================
@@ -60,37 +69,35 @@ class VLAAgent(nn.Module):
         """
         batch: contains images, qpos (proprioception), action
         """
-        gt_actions = batch['action']  # Shape: (B, Horizon, Action_Dim)
-        B = gt_actions.shape[0]
-        images = batch['images']
-        proprioception = batch['qpos'].view(B, -1)  # (B, obs_horizon * obs_dim)
+        actions, states, images = batch['action'], batch['qpos'], batch['images']
+        B = actions.shape[0]
+        state_features = self.state_encoder(states)
 
         # 1. Extract visual features
-        visual_features = self.vision_encoder(images).view(B, -1)  # (B, vision_dim)
-
-        # 2. Fuse features -> global conditioning
-        global_cond = torch.cat([visual_features, proprioception], dim=-1)
+        visual_features = self.vision_encoder(images)  # (B, obs_horizon, vision_dim)
+        action_features = self.action_encoder(actions)
 
         # 3. Sample noise
-        noise = torch.randn_like(gt_actions)
+        noise = torch.randn_like(action_features)
 
         # 4. Sample random timesteps
         timesteps = torch.randint(
             0, self.noise_scheduler.config.num_train_timesteps,
-            (B,), device=gt_actions.device
+            (B,), device=action_features.device
         ).long()
 
         # 5. Add noise to the actions (forward diffusion)
         noisy_actions = self.noise_scheduler.add_noise(
-            gt_actions, noise, timesteps
+            action_features, noise, timesteps
         )
 
         # 6. Predict the noise with the network
         pred_noise = self.noise_pred_net(
             sample=noisy_actions,
             timestep=timesteps,
-            global_cond=global_cond
+            visual_features=visual_features,
+            proprioception=state_features
         )
 
         # 7. Compute the loss (MSE)
@@ -102,17 +109,17 @@ class VLAAgent(nn.Module):
     # ==========================
     @torch.no_grad()
     def predict_action(self, images, proprioception):
-        B = 1  # assume a single inference sample
+        B = proprioception.shape[0]
 
         # 1. Extract features of the current observation (done only once)
-        visual_features = self.vision_encoder(images).view(B, -1)
-        proprioception = proprioception.view(B, -1)
-        global_cond = torch.cat([visual_features, proprioception], dim=-1)
+        visual_features = self.vision_encoder(images)
+        state_features = self.state_encoder(proprioception)
 
         # 2. Initialize the actions as pure Gaussian noise
         # Shape: (B, pred_horizon, action_dim)
+        device = visual_features.device
         current_actions = torch.randn(
-            (B, self.pred_horizon, self.action_dim), device=global_cond.device
+            (B, self.pred_horizon, self.action_dim), device=device
        )
 
         # 3. Step-by-step denoising loop (reverse diffusion)
@@ -125,7 +132,8 @@ class VLAAgent(nn.Module):
             noise_pred = self.noise_pred_net(
                 sample=model_input,
                 timestep=t,
-                global_cond=global_cond
+                visual_features=visual_features,
+                proprioception=state_features
             )
 
             # Remove the noise and update current_actions
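Note on the reworked agent: the training path now keeps the visual and proprioceptive features as separate conditioning inputs and lets the head flatten and concatenate them into the global conditioning vector. A rough shape walk-through of that step, using illustrative values only (the batch size, horizons, feature sizes and obs_dim below are assumptions, not taken from this patch):

    # Shape sketch of the conditioning path after this change (illustrative values only).
    import torch

    B, obs_horizon, pred_horizon = 8, 2, 16
    action_dim, obs_dim, num_cams, feat_dim = 16, 14, 2, 512

    visual_features = torch.randn(B, obs_horizon, num_cams * feat_dim)  # stand-in for vision_encoder output
    state_features = torch.randn(B, obs_horizon, obs_dim)               # stand-in for IdentityStateEncoder output

    # The head flattens each observation stream and concatenates them:
    global_cond = torch.cat(
        [visual_features.flatten(start_dim=1), state_features.flatten(start_dim=1)],
        dim=-1,
    )
    # Matches global_cond_dim = single_img_feat_dim * num_cams * obs_horizon + obs_dim * obs_horizon
    assert global_cond.shape == (B, obs_horizon * (num_cams * feat_dim + obs_dim))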
diff --git a/roboimi/vla/conf/agent/resnet_diffusion.yaml b/roboimi/vla/conf/agent/resnet_diffusion.yaml
index 0ab1a0c..2874672 100644
--- a/roboimi/vla/conf/agent/resnet_diffusion.yaml
+++ b/roboimi/vla/conf/agent/resnet_diffusion.yaml
@@ -1,11 +1,12 @@
 # @package agent
-_target_: roboimi.vla.agent.VLAAgent
+defaults:
+  - /backbone@vision_backbone: resnet
+  - /modules@state_encoder: identity_state_encoder
+  - /modules@action_encoder: identity_action_encoder
+  - /head: conditional_unet1d
+  - _self_
 
-# Vision Backbone: ResNet-18 with SpatialSoftmax
-vision_backbone:
-  _target_: roboimi.vla.models.backbones.resnet.ResNetBackbone
-  model_name: "microsoft/resnet-18"
-  freeze: true
+_target_: roboimi.vla.agent.VLAAgent
 
 # Action and Observation Dimensions
 action_dim: 16
@@ -16,7 +17,7 @@ pred_horizon: 16
 obs_horizon: 2
 
 # Diffusion Parameters
-diffusion_steps: 100  # Number of diffusion timesteps for training
+# diffusion_steps: 100  (these parameters should move into the head config, or be passed through as variables)
 
 # Camera Configuration
 num_cams: ${len:${data.camera_names}}  # obtained automatically from the length of the data.camera_names list
\ No newline at end of file
diff --git a/roboimi/vla/conf/backbone/resnet.yaml b/roboimi/vla/conf/backbone/resnet.yaml
index 487577d..4fb178b 100644
--- a/roboimi/vla/conf/backbone/resnet.yaml
+++ b/roboimi/vla/conf/backbone/resnet.yaml
@@ -1,4 +1,3 @@
-# @package agent.backbone
 _target_: roboimi.vla.models.backbones.resnet.ResNetBackbone
 
 model_name: "microsoft/resnet-18"
diff --git a/roboimi/vla/conf/head/conditional_unet1d.yaml b/roboimi/vla/conf/head/conditional_unet1d.yaml
new file mode 100644
index 0000000..fb3cc1a
--- /dev/null
+++ b/roboimi/vla/conf/head/conditional_unet1d.yaml
@@ -0,0 +1,5 @@
+_target_: roboimi.vla.models.heads.conditional_unet1d.ConditionalUnet1D
+_partial_: true
+
+kernel_size: 3
+cond_predict_scale: false
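Note on the head config: `_partial_: true` makes Hydra return a partially-applied constructor rather than a built module, which is what lets VLAAgent.__init__ finish the call with input_dim and global_cond_dim at runtime. A minimal sketch of that flow, assuming the config is built with hydra.utils.instantiate (the dimension values below are placeholders, not taken from this patch):

    # Sketch: a `_partial_: true` target becomes a callable that the agent completes later.
    from functools import partial

    from hydra.utils import instantiate
    from omegaconf import OmegaConf

    head_cfg = OmegaConf.create({
        "_target_": "roboimi.vla.models.heads.conditional_unet1d.ConditionalUnet1D",
        "_partial_": True,
        "kernel_size": 3,
        "cond_predict_scale": False,
    })

    head = instantiate(head_cfg)       # a functools.partial, not yet an nn.Module
    assert isinstance(head, partial)

    # VLAAgent.__init__ then supplies the remaining, runtime-derived arguments:
    noise_pred_net = head(input_dim=16, global_cond_dim=2076)  # placeholder dims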
diff --git a/roboimi/vla/conf/head/diffusion.yaml b/roboimi/vla/conf/head/diffusion.yaml
deleted file mode 100644
index 2934c94..0000000
--- a/roboimi/vla/conf/head/diffusion.yaml
+++ /dev/null
@@ -1,8 +0,0 @@
-_target_: roboimi.vla.models.heads.DiffusionActionHead
-
-# Explicitly declare the required parameters
-input_dim: ???        # to be filled in by agent/default.yaml
-action_dim: 7
-obs_horizon: 2
-pred_horizon: 16
-denoising_steps: 100
\ No newline at end of file
diff --git a/roboimi/vla/conf/modules/identity_action_encoder.yaml b/roboimi/vla/conf/modules/identity_action_encoder.yaml
new file mode 100644
index 0000000..4f18b51
--- /dev/null
+++ b/roboimi/vla/conf/modules/identity_action_encoder.yaml
@@ -0,0 +1 @@
+_target_: roboimi.vla.modules.encoders.IdentityActionEncoder
diff --git a/roboimi/vla/conf/modules/identity_state_encoder.yaml b/roboimi/vla/conf/modules/identity_state_encoder.yaml
new file mode 100644
index 0000000..fba00d5
--- /dev/null
+++ b/roboimi/vla/conf/modules/identity_state_encoder.yaml
@@ -0,0 +1 @@
+_target_: roboimi.vla.modules.encoders.IdentityStateEncoder
diff --git a/roboimi/vla/models/heads/__init__.py b/roboimi/vla/models/heads/__init__.py
index 7a32179..601a467 100644
--- a/roboimi/vla/models/heads/__init__.py
+++ b/roboimi/vla/models/heads/__init__.py
@@ -1,4 +1,4 @@
 #
 # Action Head models
-from .diffusion import ConditionalUnet1D
+from .conditional_unet1d import ConditionalUnet1D
 __all__ = ["ConditionalUnet1D"]
diff --git a/roboimi/vla/models/heads/diffusion.py b/roboimi/vla/models/heads/conditional_unet1d.py
similarity index 94%
rename from roboimi/vla/models/heads/diffusion.py
rename to roboimi/vla/models/heads/conditional_unet1d.py
index 6233658..f468120 100644
--- a/roboimi/vla/models/heads/diffusion.py
+++ b/roboimi/vla/models/heads/conditional_unet1d.py
@@ -225,14 +225,27 @@ class ConditionalUnet1D(nn.Module):
     def forward(self,
             sample: torch.Tensor,
             timestep: Union[torch.Tensor, float, int],
-            local_cond=None, global_cond=None, **kwargs):
+            local_cond=None, global_cond=None,
+            visual_features=None, proprioception=None,
+            **kwargs):
         """
         x: (B,T,input_dim)
         timestep: (B,) or int, diffusion step
         local_cond: (B,T,local_cond_dim)
         global_cond: (B,global_cond_dim)
+        visual_features: (B, T_obs, D_vis)
+        proprioception: (B, T_obs, D_prop)
         output: (B,T,input_dim)
         """
+        if global_cond is None:
+            conds = []
+            if visual_features is not None:
+                conds.append(visual_features.flatten(start_dim=1))
+            if proprioception is not None:
+                conds.append(proprioception.flatten(start_dim=1))
+            if len(conds) > 0:
+                global_cond = torch.cat(conds, dim=-1)
+
         sample = einops.rearrange(sample, 'b h t -> b t h')
 
         # 1. time
@@ -291,4 +304,3 @@ class ConditionalUnet1D(nn.Module):
 
         x = einops.rearrange(x, 'b t h -> b h t')
         return x
-
diff --git a/roboimi/vla/modules/encoders.py b/roboimi/vla/modules/encoders.py
new file mode 100644
index 0000000..0fa0970
--- /dev/null
+++ b/roboimi/vla/modules/encoders.py
@@ -0,0 +1,18 @@
+from torch import nn
+
+class IdentityStateEncoder(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, state):
+        return state
+
+
+class IdentityActionEncoder(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, action):
+        return action
\ No newline at end of file
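Note on the identity encoders: they are pass-through placeholders so that the agent's interface already anticipates learned state and action encoders; any nn.Module with the same single-argument forward can be swapped in through a corresponding modules/*.yaml. A hypothetical learned replacement, for illustration only (the class name and sizes are assumptions, and the agent's global_cond_dim would then need to be computed from the encoder's output size rather than from obs_dim):

    # Hypothetical drop-in replacement for IdentityStateEncoder (illustrative only).
    from torch import nn


    class MLPStateEncoder(nn.Module):
        """Projects raw proprioception (B, T_obs, obs_dim) to (B, T_obs, hidden_dim)."""

        def __init__(self, obs_dim: int = 14, hidden_dim: int = 128):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(obs_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim),
            )

        def forward(self, state):
            return self.net(state)

Such a module would be selected the same way the identity encoder is, via a config whose _target_ points at the new class.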