feat(train): 跑通训练脚本

This commit is contained in:
gouhanke
2026-02-05 14:08:43 +08:00
parent dd2749cb12
commit b0a944f7aa
17 changed files with 1002 additions and 464 deletions

View File

@@ -7,115 +7,27 @@ from tqdm import tqdm
from omegaconf import DictConfig, OmegaConf from omegaconf import DictConfig, OmegaConf
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torch.optim import AdamW from torch.optim import AdamW
from pathlib import Path
# 确保导入路径正确 # Ensure correct import path
sys.path.append(os.getcwd()) sys.path.append(os.getcwd())
from roboimi.vla.agent import VLAAgent
from hydra.utils import instantiate from hydra.utils import instantiate
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@hydra.main(version_base=None, config_path="../../vla/conf", config_name="config")
def main(cfg: DictConfig):
print(OmegaConf.to_yaml(cfg))
log.info(f"🚀 Starting VLA Training with Real Data (Device: {cfg.train.device})")
# --- 1. 实例化 Dataset & DataLoader ---
# Hydra 根据 conf/data/custom_hdf5.yaml 实例化类
dataset = instantiate(cfg.data)
dataloader = DataLoader(
dataset,
batch_size=cfg.train.batch_size,
shuffle=True,
num_workers=cfg.train.num_workers,
pin_memory=(cfg.train.device != "cpu")
)
log.info(f"✅ Dataset loaded. Size: {len(dataset)}")
# --- 2. 实例化 Agent ---
agent: VLAAgent = instantiate(cfg.agent)
agent.to(cfg.train.device)
agent.train()
optimizer = AdamW(agent.parameters(), lr=cfg.train.lr)
# --- 3. Training Loop ---
# 使用一个无限迭代器或者 epoch 循环
data_iter = iter(dataloader)
pbar = tqdm(range(cfg.train.max_steps), desc="Training")
for step in pbar:
try:
batch = next(data_iter)
except StopIteration:
#而在 epoch 结束时重新开始
data_iter = iter(dataloader)
batch = next(data_iter)
# Move to device
# 注意:这里需要递归地将字典里的 tensor 移到 GPU
batch = recursive_to_device(batch, cfg.train.device)
# --- 4. Adapter Layer (适配层) ---
# Dataset 返回的是具体的相机 key (如 'agentview_image' 或 'top')
# Agent 期望的是通用的 'image'
# 我们在这里做一个映射,模拟多模态融合前的处理
# 假设我们只用配置里的第一个 key 作为主视觉
# primary_cam_key = cfg.data.obs_keys[0]
# Dataset 返回 shape: (B, Obs_Horizon, C, H, W)
# DebugBackbone 期望: (B, C, H, W) 或者 (B, Seq, Dim)
# 这里我们取 Obs_Horizon 的最后一帧 (Current Frame)
# input_img = batch['obs'][primary_cam_key][:, -1, :, :, :]
# agent_input = {
# "obs": {
# "image": input_img,
# "text": batch["language"] # 传递语言指令
# },
# "actions": batch["actions"] # (B, Chunk, Dim)
# }
agent_input = {
"action": batch["action"],
"qpos": batch["qpos"],
"images": {}
}
for cam_name in cfg.data.camera_names:
key = f"image_{cam_name}"
agent_input["images"][cam_name] = batch[key].squeeze(1)
# --- 5. Forward & Backward ---
outputs = agent(agent_input)
# 处理 Loss 掩码 (如果在真实训练中,需要在这里应用 action_mask)
# 目前 DebugHead 内部直接算了 MSE还没用 mask我们在下一阶段优化 Policy 时加上
loss = outputs['loss']
optimizer.zero_grad()
loss.backward()
optimizer.step()
if step % cfg.train.log_freq == 0:
pbar.set_postfix({"loss": f"{loss.item():.4f}"})
log.info("✅ Training Loop with Real HDF5 Finished!")
# --- 6. Save Checkpoint ---
save_dir = "checkpoints"
os.makedirs(save_dir, exist_ok=True)
save_path = os.path.join(save_dir, "vla_model_final.pt")
# 保存整个 Agent 的 state_dict
torch.save(agent.state_dict(), save_path)
log.info(f"💾 Model saved to {save_path}")
log.info("✅ Training Loop Finished!")
def recursive_to_device(data, device): def recursive_to_device(data, device):
"""
Recursively move nested dictionaries/lists of tensors to specified device.
Args:
data: Dictionary, list, or tensor
device: Target device (e.g., 'cuda', 'cpu')
Returns:
Data structure with all tensors moved to device
"""
if isinstance(data, torch.Tensor): if isinstance(data, torch.Tensor):
return data.to(device) return data.to(device)
elif isinstance(data, dict): elif isinstance(data, dict):
@@ -124,5 +36,193 @@ def recursive_to_device(data, device):
return [recursive_to_device(v, device) for v in data] return [recursive_to_device(v, device) for v in data]
return data return data
@hydra.main(version_base=None, config_path="../../vla/conf", config_name="config")
def main(cfg: DictConfig):
    """
    VLA Training Script with ResNet Backbone and Diffusion Policy.

    This script:
    1. Loads dataset from HDF5 files
    2. Instantiates VLAAgent with ResNet vision encoder
    3. Trains diffusion-based action prediction
    4. Saves checkpoints periodically

    Args:
        cfg: Hydra-composed configuration. Reads ``cfg.data`` (dataset
            target + ``camera_names``), ``cfg.agent`` (VLAAgent target) and
            ``cfg.train.{device, batch_size, num_workers, lr, max_steps,
            log_freq, save_freq}``.
    """
    # Dump the fully-resolved config so every run is reproducible from logs.
    print("=" * 80)
    print("VLA Training Configuration:")
    print("=" * 80)
    print(OmegaConf.to_yaml(cfg))
    print("=" * 80)
    log.info(f"🚀 Starting VLA Training (Device: {cfg.train.device})")
    # Create checkpoint directory (relative path: resolved against the
    # current working directory, which Hydra may change per run).
    checkpoint_dir = Path("checkpoints")
    checkpoint_dir.mkdir(exist_ok=True)
    # =========================================================================
    # 1. Instantiate Dataset & DataLoader
    # =========================================================================
    log.info("📦 Loading dataset...")
    try:
        dataset = instantiate(cfg.data)
        log.info(f"✅ Dataset loaded successfully. Total samples: {len(dataset)}")
    except Exception as e:
        # Re-raise after logging: a missing/corrupt dataset is fatal.
        log.error(f"❌ Failed to load dataset: {e}")
        raise
    dataloader = DataLoader(
        dataset,
        batch_size=cfg.train.batch_size,
        shuffle=True,
        num_workers=cfg.train.num_workers,
        # Pinned host memory speeds up host->GPU copies; pointless on CPU.
        pin_memory=(cfg.train.device != "cpu"),
        drop_last=True  # Drop incomplete batches for stable training
    )
    log.info(f"✅ DataLoader created. Batches per epoch: {len(dataloader)}")
    # =========================================================================
    # 2. Instantiate VLA Agent
    # =========================================================================
    log.info("🤖 Initializing VLA Agent...")
    try:
        agent = instantiate(cfg.agent)
        agent.to(cfg.train.device)
        agent.train()
        log.info(f"✅ Agent initialized and moved to {cfg.train.device}")
        # Count parameters (trainable excludes e.g. a frozen vision backbone).
        total_params = sum(p.numel() for p in agent.parameters())
        trainable_params = sum(p.numel() for p in agent.parameters() if p.requires_grad)
        log.info(f"📊 Total parameters: {total_params:,}")
        log.info(f"📊 Trainable parameters: {trainable_params:,}")
    except Exception as e:
        log.error(f"❌ Failed to initialize agent: {e}")
        raise
    # =========================================================================
    # 3. Setup Optimizer
    # =========================================================================
    # NOTE: weight_decay is hard-coded here rather than read from cfg.train.
    optimizer = AdamW(agent.parameters(), lr=cfg.train.lr, weight_decay=1e-5)
    log.info(f"🔧 Optimizer: AdamW (lr={cfg.train.lr})")
    # =========================================================================
    # 4. Training Loop (step-based, cycling through the dataloader)
    # =========================================================================
    log.info("🏋️ Starting training loop...")
    data_iter = iter(dataloader)
    pbar = tqdm(range(cfg.train.max_steps), desc="Training", ncols=100)
    best_loss = float('inf')
    for step in pbar:
        try:
            batch = next(data_iter)
        except StopIteration:
            # Restart iterator when epoch ends
            data_iter = iter(dataloader)
            batch = next(data_iter)
        # Move batch (nested dict of tensors) to the training device.
        batch = recursive_to_device(batch, cfg.train.device)
        # ---------------------------------------------------------------------
        # Prepare agent input.
        # Dataset returns: {action, qpos, image_<cam_name>, ...}
        # Agent expects:   {images: dict, qpos: tensor, action: tensor}
        # ---------------------------------------------------------------------
        images = {}
        for cam_name in cfg.data.camera_names:
            key = f"image_{cam_name}"
            if key in batch:
                images[cam_name] = batch[key]  # (B, obs_horizon, C, H, W)
        agent_input = {
            'images': images,            # Dict of camera images
            'qpos': batch['qpos'],       # (B, obs_horizon, obs_dim)
            'action': batch['action']    # (B, pred_horizon, action_dim)
        }
        # Forward pass & compute diffusion loss.
        try:
            loss = agent.compute_loss(agent_input)
        except Exception as e:
            log.error(f"❌ Forward pass failed at step {step}: {e}")
            raise
        # Backward pass & optimization.
        optimizer.zero_grad()
        loss.backward()
        # Gradient clipping for stable training
        torch.nn.utils.clip_grad_norm_(agent.parameters(), max_norm=1.0)
        optimizer.step()
        # Logging.
        if step % cfg.train.log_freq == 0:
            pbar.set_postfix({
                "loss": f"{loss.item():.4f}",
                "best_loss": f"{best_loss:.4f}"
            })
            log.info(f"Step {step}/{cfg.train.max_steps} | Loss: {loss.item():.4f}")
        # Periodic checkpoint saving (skip step 0).
        if step > 0 and step % cfg.train.save_freq == 0:
            checkpoint_path = checkpoint_dir / f"vla_model_step_{step}.pt"
            torch.save({
                'step': step,
                'model_state_dict': agent.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss.item(),
            }, checkpoint_path)
            log.info(f"💾 Checkpoint saved: {checkpoint_path}")
        # Save best model whenever the per-batch loss improves.
        # NOTE(review): "best" here is a single-batch training loss, not a
        # validation metric — noisy; confirm this is the intended criterion.
        if loss.item() < best_loss:
            best_loss = loss.item()
            best_model_path = checkpoint_dir / "vla_model_best.pt"
            torch.save({
                'step': step,
                'model_state_dict': agent.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss.item(),
            }, best_model_path)
            log.info(f"🌟 Best model updated: {best_model_path} (loss: {best_loss:.4f})")
    # =========================================================================
    # 5. Save Final Model
    # =========================================================================
    # NOTE(review): `loss` is only defined if the loop ran at least once —
    # this raises NameError when cfg.train.max_steps == 0.
    final_model_path = checkpoint_dir / "vla_model_final.pt"
    torch.save({
        'step': cfg.train.max_steps,
        'model_state_dict': agent.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss.item(),
    }, final_model_path)
    log.info(f"💾 Final model saved: {final_model_path}")
    log.info("✅ Training completed successfully!")
    log.info(f"📊 Final Loss: {loss.item():.4f}")
    log.info(f"📊 Best Loss: {best_loss:.4f}")
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@@ -0,0 +1,238 @@
# ResNet VLA Training Guide
This guide explains how to train the VLA agent with ResNet backbone and action_dim=16, obs_dim=16.
## Configuration Overview
### 1. Backbone Configuration
**File**: `roboimi/vla/conf/backbone/resnet.yaml`
- Model: microsoft/resnet-18
- Output dim: 1024 (512 channels × 2 from SpatialSoftmax)
- Frozen by default for faster training
### 2. Agent Configuration
**File**: `roboimi/vla/conf/agent/resnet_diffusion.yaml`
- Vision backbone: ResNet-18 with SpatialSoftmax
- Action dimension: 16
- Observation dimension: 16
- Prediction horizon: 16 steps
- Observation horizon: 2 steps
- Diffusion steps: 100
- Number of cameras: 2
### 3. Dataset Configuration
**File**: `roboimi/vla/conf/data/resnet_dataset.yaml`
- Dataset class: RobotDiffusionDataset
- Prediction horizon: 16
- Observation horizon: 2
- Camera names: [r_vis, top]
- Normalization: gaussian (mean/std)
### 4. Training Configuration
**File**: `roboimi/vla/conf/config.yaml`
- Batch size: 8
- Learning rate: 1e-4
- Max steps: 10000
- Log frequency: 100 steps
- Save frequency: 1000 steps
- Device: cuda
- Num workers: 4
## Prerequisites
### 1. Prepare Dataset
Your dataset should be organized as:
```
/path/to/your/dataset/
├── episode_0.hdf5
├── episode_1.hdf5
├── ...
└── data_stats.pkl
```
Each HDF5 file should contain:
```
episode_N.hdf5
├── action # (T, 16) float32
└── observations/
├── qpos # (T, 16) float32
└── images/
├── r_vis/ # (T, H, W, 3) uint8
└── top/ # (T, H, W, 3) uint8
```
### 2. Generate Dataset Statistics
Create `data_stats.pkl` with:
```python
import pickle
import numpy as np
stats = {
'action': {
'mean': np.zeros(16),
'std': np.ones(16)
},
'qpos': {
'mean': np.zeros(16),
'std': np.ones(16)
}
}
with open('/path/to/your/dataset/data_stats.pkl', 'wb') as f:
pickle.dump(stats, f)
```
Or use the provided script:
```bash
python -m roboimi.vla.scripts.calculate_stats --dataset_dir /path/to/your/dataset
```
## Usage
### 1. Update Dataset Path
Edit `roboimi/vla/conf/data/resnet_dataset.yaml`:
```yaml
dataset_dir: "/path/to/your/dataset" # CHANGE THIS
camera_names:
- r_vis # CHANGE TO YOUR CAMERA NAMES
- top
```
### 2. Run Training
```bash
# Basic training
python roboimi/demos/vla_scripts/train_vla.py
# Override configurations
python roboimi/demos/vla_scripts/train_vla.py train.batch_size=16
python roboimi/demos/vla_scripts/train_vla.py train.device=cpu
python roboimi/demos/vla_scripts/train_vla.py train.max_steps=20000
python roboimi/demos/vla_scripts/train_vla.py data.dataset_dir=/custom/path
# Debug mode (CPU, small batch, few steps)
python roboimi/demos/vla_scripts/train_vla.py \
train.device=cpu \
train.batch_size=2 \
train.max_steps=10 \
train.num_workers=0
```
### 3. Monitor Training
Checkpoints are saved to:
- `checkpoints/vla_model_step_1000.pt` - Periodic checkpoints
- `checkpoints/vla_model_best.pt` - Best model (lowest loss)
- `checkpoints/vla_model_final.pt` - Final model
## Architecture Details
### Data Flow
1. **Input**: Images from multiple cameras + proprioception (qpos)
2. **Vision Encoder**: ResNet-18 → SpatialSoftmax → (B, T, 1024) per camera
3. **Feature Concatenation**: All cameras + qpos → Global conditioning
4. **Diffusion Policy**: 1D U-Net predicts noise on action sequences
5. **Output**: Clean action sequence (B, 16, 16)
### Training Process
1. Sample a random timestep t from [0, 100) (one of the 100 training diffusion steps)
2. Add noise to ground truth actions
3. Predict noise using vision + proprioception conditioning
4. Compute MSE loss between predicted and actual noise
5. Backpropagate and update weights
### Inference Process
1. Extract visual features from current observation
2. Start with random noise action sequence
3. Iteratively denoise over 10 steps (DDPM scheduler)
4. Return clean action sequence
## Common Issues
### Issue: Out of Memory
**Solution**: Reduce batch size or use CPU
```bash
python train_vla.py train.batch_size=4 train.device=cpu
```
### Issue: Dataset not found
**Solution**: Check dataset_dir path in config
```bash
python train_vla.py data.dataset_dir=/absolute/path/to/dataset
```
### Issue: Camera names mismatch
**Solution**: Update camera_names in data config
```yaml
# roboimi/vla/conf/data/resnet_dataset.yaml
camera_names:
- your_camera_1
- your_camera_2
```
### Issue: data_stats.pkl missing
**Solution**: Generate statistics file
```bash
python -m roboimi.vla.scripts.calculate_stats --dataset_dir /path/to/dataset
```
## Model Files Created
```
roboimi/vla/
├── conf/
│ ├── config.yaml (UPDATED)
│ ├── backbone/
│ │ └── resnet.yaml (NEW)
│ ├── agent/
│ │ └── resnet_diffusion.yaml (NEW)
│ └── data/
│ └── resnet_dataset.yaml (NEW)
├── models/
│ └── backbones/
│ ├── __init__.py (UPDATED - added resnet export)
│ └── resnet.py (EXISTING)
└── demos/vla_scripts/
└── train_vla.py (REWRITTEN)
```
## Next Steps
1. **Prepare your dataset** in the required HDF5 format
2. **Update dataset_dir** in `roboimi/vla/conf/data/resnet_dataset.yaml`
3. **Run training** with `python roboimi/demos/vla_scripts/train_vla.py`
4. **Monitor checkpoints** in `checkpoints/` directory
5. **Evaluate** the trained model using the best checkpoint
## Advanced Configuration
### Use Different ResNet Variant
Edit `roboimi/vla/conf/agent/resnet_diffusion.yaml`:
```yaml
vision_backbone:
model_name: "microsoft/resnet-50" # or resnet-34, resnet-101
```
### Adjust Diffusion Steps
```yaml
# More steps = better quality, slower training
diffusion_steps: 200 # default: 100
```
### Change Horizons
```yaml
pred_horizon: 32 # Predict more future steps
obs_horizon: 4 # Use more history
```
### Multi-GPU Training
```bash
# Use CUDA device 1
python train_vla.py train.device=cuda:1
# For multi-GPU, use torch.distributed (requires code modification)
```
## References
- ResNet Paper: https://arxiv.org/abs/1512.03385
- Diffusion Policy: https://diffusion-policy.cs.columbia.edu/
- VLA Framework Documentation: See CLAUDE.md in project root

View File

@@ -2,111 +2,127 @@ import torch
import torch.nn as nn import torch.nn as nn
from typing import Dict, Optional, Any from typing import Dict, Optional, Any
from roboimi.vla.core.interfaces import VLABackbone, VLAProjector, VLAHead from roboimi.vla.core.interfaces import VLABackbone, VLAProjector, VLAHead
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from roboimi.vla.models.heads.diffusion import ConditionalUnet1D
class VLAAgent(nn.Module): class VLAAgent(nn.Module):
def __init__( def __init__(
self, self,
backbone: VLABackbone, vision_backbone, # 你之前定义的 ResNet 类
projector: VLAProjector, action_dim, # 机器人动作维度 (例如 7: xyz + rpy + gripper)
head: VLAHead, obs_dim, # 本体感知维度 (例如 关节角度)
state_encoder: nn.Module pred_horizon=16, # 预测未来多少步动作
obs_horizon=4, # 使用多少步历史观测
diffusion_steps=100,
num_cams=2, # 视觉输入的摄像头数量
): ):
super().__init__() super().__init__()
self.backbone = backbone self.vision_encoder = vision_backbone
self.projector = projector single_img_feat_dim = self.vision_encoder.output_dim
self.head = head total_vision_dim = single_img_feat_dim * num_cams * obs_horizon
self.state_encoder = state_encoder total_prop_dim = obs_dim * obs_horizon
self.global_cond_dim = total_vision_dim + total_prop_dim
def forward(self, batch: Dict[str, Any]) -> Dict[str, torch.Tensor]: self.noise_scheduler = DDPMScheduler(
num_train_timesteps=diffusion_steps,
action = batch["action"] beta_schedule='squaredcos_cap_v2', # 机器人任务常用的 schedule
state = batch["qpos"] clip_sample=True,
images = batch["images"] prediction_type='epsilon' # 预测噪声
)
state_emb = self.state_encoder(state)
# 2. Project Features
# Shape: (B, Seq, Head_Dim)
embeddings = self.projector(features)
# 3. Compute Action/Loss
# We pass actions if they exist (training mode)
actions = batch.get('actions', None)
outputs = self.head(embeddings=embeddings, actions=actions)
return outputs
# # roboimi/vla/agent.py
# import torch
# import torch.nn as nn
# from typing import Optional, Dict, Union
# class VLAAgent(nn.Module):
# def __init__(self,
# vlm_backbone: nn.Module,
# img_projector: nn.Module,
# action_head: nn.Module,
# state_dim: int,
# embed_dim: int):
# super().__init__()
# self.vlm_backbone = vlm_backbone
# self.img_projector = img_projector
# self.action_head = action_head
# # 简单的状态编码器 (通常不需要复杂的 config直接写在这里即可) self.noise_pred_net = ConditionalUnet1D(
# self.state_encoder = nn.Sequential( input_dim=action_dim,
# nn.Linear(state_dim, embed_dim), global_cond_dim=self.global_cond_dim
# nn.Mish(), )
# nn.Linear(embed_dim, embed_dim)
# )
# def forward(self, # ==========================
# images: torch.Tensor, # 训练阶段 (Training)
# state: torch.Tensor, # ==========================
# text: Optional[Union[str, list]] = None, def compute_loss(self, batch):
# actions: Optional[torch.Tensor] = None) -> Union[torch.Tensor, Dict]: """
# """ batch: 包含 images, qpos (proprioception), action
# Args: """
# images: [Batch, Obs_Horizon, C, H, W] 注意: 这里需要处理时间维度 gt_actions = batch['action'] # Shape: (B, Horizon, Action_Dim)
# state: [Batch, Obs_Horizon, State_Dim] B = gt_actions.shape[0]
# text: Optional text instructions images = batch['images']
# actions: [Batch, Pred_Horizon, Action_Dim] (Training only) proprioception = batch['qpos'].view(B, -1) # (B, obs_horizon * obs_dim)
# 1. 提取视觉特征
visual_features = self.vision_encoder(images).view(B, -1) # (B, vision_dim)
# 2. 融合特征 -> 全局条件 (Global Conditioning)
global_cond = torch.cat([visual_features, proprioception], dim=-1)
# 3. 采样噪声
noise = torch.randn_like(gt_actions)
# 4. 随机采样时间步 (Timesteps)
timesteps = torch.randint(
0, self.noise_scheduler.config.num_train_timesteps,
(B,), device=gt_actions.device
).long()
# 5. 给动作加噪 (Forward Diffusion)
noisy_actions = self.noise_scheduler.add_noise(
gt_actions, noise, timesteps
)
# 6. 网络预测噪声
# 注意U-Net 1D 通常期望 channel 在中间: (B, C, T)
# noisy_actions_inp = noisy_actions.permute(0, 2, 1)
pred_noise = self.noise_pred_net(
sample=noisy_actions,
timestep=timesteps,
global_cond=global_cond
)
# 还原维度 (B, T, C)
pred_noise = pred_noise.permute(0, 2, 1)
# 7. 计算 Loss (MSE)
loss = nn.functional.mse_loss(pred_noise, noise)
return loss
# ==========================
# 推理阶段 (Inference)
# ==========================
@torch.no_grad()
def predict_action(self, images, proprioception):
B = 1 # 假设单次推理
# 1. 提取当前观测特征 (只做一次)
visual_features = self.vision_encoder(images).view(B, -1)
proprioception = proprioception.view(B, -1)
global_cond = torch.cat([visual_features, proprioception], dim=-1)
# 2. 初始化纯高斯噪声动作
# Shape: (B, Horizon, Action_Dim)
current_actions = torch.randn(
(B, 16, 7), device=global_cond.device
)
# 3. 逐步去噪循环 (Reverse Diffusion)
self.noise_scheduler.set_timesteps(10) # 推理时可以用更少步加速 (如 DDIM)
for t in self.noise_scheduler.timesteps:
# 调整输入格式适应 1D CNN
model_input = current_actions.permute(0, 2, 1)
# Returns: # 预测噪声
# Training: Loss scalar noise_pred = self.noise_pred_net(
# Inference: Predicted actions sample=model_input,
# """ timestep=t,
global_cond=global_cond
# B, T, C, H, W = images.shape )
# noise_pred = noise_pred.permute(0, 2, 1)
# # 1. 图像编码 (Flatten time dimension for efficiency)
# # [B*T, C, H, W] -> [B*T, Vision_Dim]
# flat_images = images.view(B * T, C, H, W)
# vision_feats_dict = self.vlm_backbone(flat_images)
# raw_img_emb = vision_feats_dict['image_embeds'] # [B*T, Vision_Dim]
# # 投影并还原时间维度 -> [B, T, Embed_Dim]
# img_emb = self.img_projector(raw_img_emb)
# img_emb = img_emb.view(B, T, -1)
# # 2. 状态编码
# state_emb = self.state_encoder(state) # [B, T, Embed_Dim]
# # 3. 特征融合 (这里做一个简单的 Early Fusion 示例) # 移除噪声,更新 current_actions
# # 将图像特征和状态特征在特征维度拼接,或在时间维度拼接 current_actions = self.noise_scheduler.step(
# # 假设我们只用最近的一帧图像作为 Context或者将所有历史特征作为 Context noise_pred, t, current_actions
# # 这里演示Context = (Image_History + State_History) ).prev_sample
# # [B, T, Embed] + [B, T, Embed] -> [B, 2*T, Embed] (Concat on time)
# context = torch.cat([img_emb, state_emb], dim=1) # 4. 输出最终动作序列
return current_actions # 返回去噪后的干净动作
# # 4. Action Head 分支
# if actions is not None:
# # --- Training Mode ---
# # 必须返回 Loss
# return self.action_head.compute_loss(context, actions)
# else:
# # --- Inference Mode ---
# # 必须返回预测的动作序列
# return self.action_head.predict_action(context)

View File

@@ -0,0 +1,22 @@
# @package agent
_target_: roboimi.vla.agent.VLAAgent
# Vision Backbone: ResNet-18 with SpatialSoftmax
vision_backbone:
_target_: roboimi.vla.models.backbones.resnet.ResNetBackbone
model_name: "microsoft/resnet-18"
freeze: true
# Action and Observation Dimensions
action_dim: 16 # Robot action dimension
obs_dim: 16 # Proprioception dimension (qpos)
# Prediction Horizons
pred_horizon: 16 # How many future actions to predict
obs_horizon: 2 # How many historical observations to use
# Diffusion Parameters
diffusion_steps: 100 # Number of diffusion timesteps for training
# Camera Configuration
num_cams: 2 # Number of cameras (e.g., r_vis, top)

View File

@@ -0,0 +1,10 @@
# @package agent.backbone
_target_: roboimi.vla.models.backbones.resnet.ResNetBackbone
model_name: "microsoft/resnet-18"
freeze: true
# Output dimension calculation:
# ResNet-18 final layer has 512 channels
# After SpatialSoftmax: 512 * 2 = 1024 (x,y coordinates per channel)
# output_dim: 1024

View File

@@ -1,12 +1,13 @@
defaults: defaults:
- _self_ - _self_
- agent: base_siglip - agent: resnet_diffusion
- data: custom_hdf5 # 新增这一行,激活数据配置 - data: resnet_dataset
train: train:
batch_size: 4 # 减小 batch size 方便调试 batch_size: 8 # Batch size for training
lr: 1e-4 lr: 1e-4 # Learning rate
max_steps: 10 max_steps: 10000 # Maximum training steps
log_freq: 10 log_freq: 100 # Log frequency (steps)
device: "cpu" save_freq: 1000 # Save checkpoint frequency (steps)
num_workers: 0 # 调试设为0验证通过后改为 2 或 4 device: "cuda" # Device: "cuda" or "cpu"
num_workers: 8 # DataLoader workers (set to 0 for debugging, 8 for production)

View File

@@ -0,0 +1,18 @@
# @package data
_target_: roboimi.vla.data.dataset.RobotDiffusionDataset
# Dataset Directory (CHANGE THIS TO YOUR DATA PATH)
dataset_dir: "roboimi/demos/dataset/sim_transfer" # Path to your dataset directory
# Horizon Parameters
pred_horizon: 16 # Prediction horizon (matches agent.pred_horizon)
obs_horizon: 2 # Observation horizon (matches agent.obs_horizon)
action_horizon: 8 # Action execution horizon (used during evaluation)
# Camera Names (CHANGE THIS TO MATCH YOUR CAMERAS)
camera_names:
- r_vis
- top
# Normalization Type: 'gaussian' (mean/std) or 'min_max' ([-1, 1])
normalization_type: gaussian

View File

@@ -90,52 +90,48 @@ class RobotDiffusionDataset(Dataset):
# 如果 obs_horizon=2, current_ts=0 -> indices=[0, 0] (Padding) # 如果 obs_horizon=2, current_ts=0 -> indices=[0, 0] (Padding)
# 如果 obs_horizon=2, current_ts=5 -> indices=[4, 5] # 如果 obs_horizon=2, current_ts=5 -> indices=[4, 5]
indices = [] start_idx_raw = start_ts - (self.obs_horizon - 1)
for i in range(self.obs_horizon): start_idx = max(start_idx_raw, 0)
# t - (To - 1) + i end_idx = start_ts + 1
query_ts = start_ts - (self.obs_horizon - 1) + i pad_len = max(0, -start_idx_raw)
# 边界处理 (Padding first frame)
query_ts = max(query_ts, 0)
indices.append(query_ts)
# 读取 qpos (proprioception)
qpos_data = root['observations/qpos']
qpos = qpos_data[indices] # smart indexing
if self.stats:
qpos = self._normalize_data(qpos, self.stats['qpos'])
# 读取 Images # Qpos
# 你有三个视角: angle, r_vis, top qpos_data = root['observations/qpos']
# 建议将它们分开返回,或者在 Dataset 里 Concat qpos_val = qpos_data[start_idx:end_idx]
if pad_len > 0:
first_frame = qpos_val[0]
padding = np.repeat(first_frame[np.newaxis, :], pad_len, axis=0)
qpos_val = np.concatenate([padding, qpos_val], axis=0)
if self.stats:
qpos_val = self._normalize_data(qpos_val, self.stats['qpos'])
# Images
image_dict = {} image_dict = {}
for cam_name in self.camera_names: for cam_name in self.camera_names:
# HDF5 dataset
img_dset = root['observations']['images'][cam_name] img_dset = root['observations']['images'][cam_name]
imgs_np = img_dset[start_idx:end_idx] # (T, H, W, C)
imgs = [] if pad_len > 0:
for t in indices: first_frame = imgs_np[0]
img = img_dset[t] # (480, 640, 3) uint8 padding = np.repeat(first_frame[np.newaxis, ...], pad_len, axis=0)
img = torch.from_numpy(img).permute(2, 0, 1).float() / 255.0 # (C, H, W) imgs_np = np.concatenate([padding, imgs_np], axis=0)
imgs.append(img)
# Stack time dimension: (obs_horizon, 3, H, W) # 转换为 Tensor: (T, H, W, C) -> (T, C, H, W)
image_dict[cam_name] = torch.stack(imgs) imgs_tensor = torch.from_numpy(imgs_np).float() / 255.0
imgs_tensor = torch.einsum('thwc->tchw', imgs_tensor)
image_dict[cam_name] = imgs_tensor
# ----------------------------- # ==============================
# 4. 组装 Batch # 3. 组装 Batch
# ----------------------------- # ==============================
data_batch = { data_batch = {
'action': torch.from_numpy(actions).float(), # (Tp, 16) 'action': torch.from_numpy(actions).float(),
'qpos': torch.from_numpy(qpos).float(), # (To, 16) 'qpos': torch.from_numpy(qpos_val).float(),
} }
# 将图像放入 batch
for cam_name, img_tensor in image_dict.items(): for cam_name, img_tensor in image_dict.items():
data_batch[f'image_{cam_name}'] = img_tensor # (To, 3, H, W) data_batch[f'image_{cam_name}'] = img_tensor
# TODO: 添加 Language Instruction
# 如果所有 episode 共享任务,这里可以是固定 embedding
# 如果每个 episode 任务不同,你需要一个额外的 meta json 来映射 file_path -> text
# data_batch['lang_text'] = "pick up the red cube"
return data_batch return data_batch

View File

@@ -1,9 +1,10 @@
# Backbone models # Backbone models
from .siglip import SigLIPBackbone from .siglip import SigLIPBackbone
from .resnet import ResNetBackbone
# from .clip import CLIPBackbone # from .clip import CLIPBackbone
# from .dinov2 import DinoV2Backbone # from .dinov2 import DinoV2Backbone
__all__ = ["SigLIPBackbone"] __all__ = ["SigLIPBackbone", "ResNetBackbone"]
# from .debug import DebugBackbone # from .debug import DebugBackbone
# __all__ = ["DebugBackbone"] # __all__ = ["DebugBackbone"]

View File

@@ -1 +0,0 @@
# CLIP Backbone 实现

View File

@@ -1,30 +0,0 @@
import torch
import torch.nn as nn
from typing import Dict
from roboimi.vla.core.interfaces import VLABackbone
class DebugBackbone(VLABackbone):
    """
    Stand-in backbone that emits random embeddings.

    Intended for wiring/architecture checks only: the output is pure noise,
    but a trainable dummy parameter is mixed in so gradients can flow end to
    end through the rest of the pipeline.
    """
    def __init__(self, embed_dim: int = 768, seq_len: int = 10):
        super().__init__()
        self._embed_dim = embed_dim
        self.seq_len = seq_len
        # Single trainable scalar; its only job is to keep this module on
        # the autograd graph.
        self.dummy_param = nn.Parameter(torch.zeros(1))
    def forward(self, obs: Dict[str, torch.Tensor]) -> torch.Tensor:
        image = obs['image']
        # Random "features" of shape (B, seq_len, embed_dim) on the same
        # device as the input image.
        noise = torch.randn(
            image.shape[0], self.seq_len, self._embed_dim, device=image.device
        )
        # Adding the zero-valued parameter leaves the values unchanged while
        # linking the output to dummy_param in the computation graph.
        return noise + self.dummy_param
    @property
    def embed_dim(self) -> int:
        return self._embed_dim

View File

@@ -1 +0,0 @@
# DinoV2 Backbone 实现

View File

@@ -0,0 +1,83 @@
from roboimi.vla.core.interfaces import VLABackbone
from transformers import ResNetModel
from torchvision import transforms
import torch
import torch.nn as nn
class ResNetBackbone(VLABackbone):
    """
    Frozen (by default) pretrained ResNet feature extractor followed by a
    SpatialSoftmax head that turns each feature map into 2-D keypoint
    coordinates.

    NOTE(review): SpatialSoftmax is built with a fixed 12x12 grid, which
    matches ResNet-18's final feature map for 384x384 inputs (384/32 = 12).
    Confirm this still holds before switching `model_name` or the resize
    target.
    """
    def __init__(
        self,
        model_name = "microsoft/resnet-18",
        freeze: bool = True,
    ):
        super().__init__()
        self.model = ResNetModel.from_pretrained(model_name)
        # Channel count of the deepest ResNet stage (512 for resnet-18).
        self.out_channels = self.model.config.hidden_sizes[-1]
        # Resize + ImageNet normalization; expects float tensors already
        # scaled to [0, 1] — presumably done by the dataset (TODO confirm).
        self.transform = transforms.Compose([
            transforms.Resize((384, 384)),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        self.spatial_softmax = SpatialSoftmax(num_rows=12, num_cols=12)
        if freeze:
            self._freeze_parameters()
    def _freeze_parameters(self):
        # Disable gradients for all ResNet weights and switch to eval mode
        # (fixes BatchNorm running stats).
        print("❄️ Freezing ResNet Backbone parameters")
        for param in self.model.parameters():
            param.requires_grad = False
        self.model.eval()
    def forward_single_image(self, image):
        """Encode one camera stream (B, T, C, H, W) -> (B*T, out_channels*2)."""
        B, T, C, H, W = image.shape
        # Fold time into the batch dimension so the CNN sees 4-D input.
        image = image.view(B * T, C, H, W)
        image = self.transform(image)
        feature_map = self.model(image).last_hidden_state  # (B*T, D, H', W')
        features = self.spatial_softmax(feature_map)  # (B*T, D*2)
        return features
    def forward(self, images):
        """Encode a dict of camera streams into (B, T, num_cams * out_channels * 2).

        Cameras are processed in sorted key order so the feature layout is
        deterministic regardless of dict insertion order.
        """
        # All cameras are assumed to share the same (B, T) leading dims.
        any_tensor = next(iter(images.values()))
        B, T = any_tensor.shape[:2]
        features_all = []
        sorted_cam_names = sorted(images.keys())
        for cam_name in sorted_cam_names:
            img = images[cam_name]
            features = self.forward_single_image(img)  # (B*T, D*2)
            features_all.append(features)
        combined_features = torch.cat(features_all, dim=1)  # (B*T, Num_Cams*D*2)
        return combined_features.view(B, T, -1)
    @property
    def output_dim(self):
        """Output dimension after spatial softmax: out_channels * 2"""
        return self.out_channels * 2
class SpatialSoftmax(nn.Module):
    """
    Convert a feature map (N, C, H, W) into expected 2-D keypoint
    coordinates (N, C*2), one (x, y) pair per channel, via a spatial
    softmax over each channel's H*W activations.

    Args:
        num_rows: Height of the expected feature map (must equal H).
        num_cols: Width of the expected feature map (must equal W).
        temperature: Softmax temperature. ``None`` (default) keeps the
            original behavior — a learnable temperature initialized to 1.
            A number fixes the temperature as a non-trainable buffer.
            (Bug fix: this argument was previously accepted but ignored.)
    """
    def __init__(self, num_rows, num_cols, temperature=None):
        super().__init__()
        if temperature is None:
            # Learnable temperature, initialized to 1 (original behavior).
            self.temperature = nn.Parameter(torch.ones(1))
        else:
            # Fixed temperature: registered as a buffer so it moves with
            # .to(device) but receives no gradient updates.
            self.register_buffer('temperature', torch.full((1,), float(temperature)))
        # Normalized coordinate grid in [-1, 1]; with indexing='ij',
        # pos_x varies along rows and pos_y along columns.
        pos_x, pos_y = torch.meshgrid(
            torch.linspace(-1, 1, num_rows),
            torch.linspace(-1, 1, num_cols),
            indexing='ij'
        )
        self.register_buffer('pos_x', pos_x.reshape(-1))
        self.register_buffer('pos_y', pos_y.reshape(-1))
    def forward(self, x):
        """Map (N, C, H, W) -> (N, C*2) expected coordinates.

        Raises a shape error if H*W != num_rows*num_cols.
        """
        N, C, H, W = x.shape
        x = x.view(N, C, -1)  # (N, C, H*W)
        # Per-channel attention map over spatial positions.
        softmax_attention = torch.nn.functional.softmax(x / self.temperature, dim=2)
        # Expected (x, y) coordinate under each attention map.
        expected_x = torch.sum(softmax_attention * self.pos_x, dim=2, keepdim=True)
        expected_y = torch.sum(softmax_attention * self.pos_y, dim=2, keepdim=True)
        # Concatenate and flatten -> (N, C*2)
        return torch.cat([expected_x, expected_y], dim=2).reshape(N, -1)

View File

@@ -1,9 +1,8 @@
# # Action Head models # # Action Head models
from .diffusion import DiffusionHead from .diffusion import ConditionalUnet1D
# from .act import ACTHead # from .act import ACTHead
__all__ = ["DiffusionHead"] __all__ = ["ConditionalUnet1D"]
# from .debug import DebugHead # from .debug import DebugHead
# __all__ = ["DebugHead"] # __all__ = ["DebugHead"]

View File

@@ -1 +0,0 @@
# ACT-VAE Action Head 实现

View File

@@ -1,33 +0,0 @@
import torch
import torch.nn as nn
from typing import Dict, Optional
from roboimi.vla.core.interfaces import VLAHead
class DebugHead(VLAHead):
"""
A fake Action Head using MSE Loss.
Replaces complex Diffusion/ACT policies for architecture verification.
"""
def __init__(self, input_dim: int, action_dim: int, chunk_size: int = 16):
super().__init__()
# Simple regression from embedding -> action chunk
self.regressor = nn.Linear(input_dim, chunk_size * action_dim)
self.action_dim = action_dim
self.chunk_size = chunk_size
self.loss_fn = nn.MSELoss()
def forward(self, embeddings: torch.Tensor, actions: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
# Simple pooling over sequence dimension to get (B, Hidden)
pooled_embed = embeddings.mean(dim=1)
# Predict actions: (B, Chunk * Act_Dim) -> (B, Chunk, Act_Dim)
pred_flat = self.regressor(pooled_embed)
pred_actions = pred_flat.view(-1, self.chunk_size, self.action_dim)
output = {"pred_actions": pred_actions}
if actions is not None:
# Calculate MSE Loss against ground truth
output["loss"] = self.loss_fn(pred_actions, actions)
return output

View File

@@ -5,170 +5,290 @@ from typing import Dict, Optional
from diffusers import DDPMScheduler from diffusers import DDPMScheduler
from roboimi.vla.core.interfaces import VLAHead from roboimi.vla.core.interfaces import VLAHead
class DiffusionHead(VLAHead): from typing import Union
def __init__( import logging
self, import torch
input_dim: int, # 来自 Projector 的维度 (e.g. 384) import torch.nn as nn
action_dim: int, # 动作维度 (e.g. 16) import einops
chunk_size: int, # 预测视界 (e.g. 16)
n_timesteps: int = 100, # 扩散步数 import torch
hidden_dim: int = 256 import torch.nn as nn
): import torch.nn.functional as F
from einops.layers.torch import Rearrange
import math
class SinusoidalPosEmb(nn.Module):
def __init__(self, dim):
super().__init__() super().__init__()
self.action_dim = action_dim self.dim = dim
self.chunk_size = chunk_size
# 1. 噪声调度器 (DDPM)
self.scheduler = DDPMScheduler(
num_train_timesteps=n_timesteps,
beta_schedule='squaredcos_cap_v2', # 现代 Diffusion 常用调度
clip_sample=True,
prediction_type='epsilon' # 预测噪声
)
# 2. 噪声预测网络 (Noise Predictor Network) def forward(self, x):
# 输入: Noisy Action + Time Embedding + Image Embedding device = x.device
# 这是一个简单的 Conditional MLP/ResNet 结构 half_dim = self.dim // 2
self.time_emb = nn.Sequential( emb = math.log(10000) / (half_dim - 1)
nn.Linear(1, hidden_dim), emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
emb = x[:, None] * emb[None, :]
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb
class Downsample1d(nn.Module):
def __init__(self, dim):
super().__init__()
self.conv = nn.Conv1d(dim, dim, 3, 2, 1)
def forward(self, x):
return self.conv(x)
class Upsample1d(nn.Module):
def __init__(self, dim):
super().__init__()
self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1)
def forward(self, x):
return self.conv(x)
class Conv1dBlock(nn.Module):
'''
Conv1d --> GroupNorm --> Mish
'''
def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
super().__init__()
self.block = nn.Sequential(
nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
# Rearrange('batch channels horizon -> batch channels 1 horizon'),
nn.GroupNorm(n_groups, out_channels),
# Rearrange('batch channels 1 horizon -> batch channels horizon'),
nn.Mish(), nn.Mish(),
nn.Linear(hidden_dim, hidden_dim)
) )
self.cond_proj = nn.Linear(input_dim, hidden_dim) # 把图像特征投影一下 def forward(self, x):
return self.block(x)
# 主干网络 (由几个 Residual Block 组成)
self.mid_layers = nn.ModuleList([ class ConditionalResidualBlock1D(nn.Module):
nn.Sequential( def __init__(self,
nn.Linear(hidden_dim + action_dim * chunk_size, hidden_dim), in_channels,
nn.LayerNorm(hidden_dim), out_channels,
nn.Mish(), cond_dim,
nn.Linear(hidden_dim, hidden_dim + action_dim * chunk_size) # 简单的残差 kernel_size=3,
) for _ in range(3) n_groups=8,
cond_predict_scale=False):
super().__init__()
self.blocks = nn.ModuleList([
Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups),
Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups),
]) ])
# 输出层: 预测噪声 (Shape 与 Action 相同)
self.final_layer = nn.Linear(hidden_dim + action_dim * chunk_size, action_dim * chunk_size)
def forward(self, embeddings: torch.Tensor, actions: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
"""
Unified interface for Training and Inference.
"""
device = embeddings.device
# --- 1. 处理条件 (Conditioning) ---
# embeddings: (B, Seq, Dim). 我们这里做一个简化,做 Average Pooling 变成 (B, Dim)
# 如果你想做更复杂的 Cross-Attention可以在这里改
global_cond = embeddings.mean(dim=1)
cond_feat = self.cond_proj(global_cond) # (B, Hidden)
# =========================================
# 分支 A: 训练模式 (Training)
# =========================================
if actions is not None:
batch_size = actions.shape[0]
# 1.1 准备数据 (Flatten: B, Chunk, ActDim -> B, Chunk*ActDim)
actions_flat = actions.view(batch_size, -1)
# 1.2 采样噪声和时间步
noise = torch.randn_like(actions_flat)
timesteps = torch.randint(
0, self.scheduler.config.num_train_timesteps,
(batch_size,), device=device
).long()
# 1.3 加噪 (Forward Diffusion)
noisy_actions = self.scheduler.add_noise(actions_flat, noise, timesteps)
# 1.4 预测噪声 (Network Forward)
pred_noise = self._predict_noise(noisy_actions, timesteps, cond_feat)
# 1.5 计算 Loss (MSE between actual noise and predicted noise)
loss = nn.functional.mse_loss(pred_noise, noise)
return {"loss": loss}
# ========================================= cond_channels = out_channels
# 分支 B: 推理模式 (Inference) if cond_predict_scale:
# ========================================= cond_channels = out_channels * 2
self.cond_predict_scale = cond_predict_scale
self.out_channels = out_channels
self.cond_encoder = nn.Sequential(
nn.Mish(),
nn.Linear(cond_dim, cond_channels),
Rearrange('batch t -> batch t 1'),
)
# make sure dimensions compatible
self.residual_conv = nn.Conv1d(in_channels, out_channels, 1) \
if in_channels != out_channels else nn.Identity()
def forward(self, x, cond):
'''
x : [ batch_size x in_channels x horizon ]
cond : [ batch_size x cond_dim]
returns:
out : [ batch_size x out_channels x horizon ]
'''
out = self.blocks[0](x)
embed = self.cond_encoder(cond)
if self.cond_predict_scale:
embed = embed.reshape(
embed.shape[0], 2, self.out_channels, 1)
scale = embed[:,0,...]
bias = embed[:,1,...]
out = scale * out + bias
else: else:
batch_size = embeddings.shape[0] out = out + embed
out = self.blocks[1](out)
# 2.1 从纯高斯噪声开始 out = out + self.residual_conv(x)
noisy_actions = torch.randn( return out
batch_size, self.chunk_size * self.action_dim,
device=device
)
# 2.2 逐步去噪 (Reverse Diffusion Loop)
# 使用 scheduler.timesteps 自动处理步长
self.scheduler.set_timesteps(self.scheduler.config.num_train_timesteps)
for t in self.scheduler.timesteps:
# 构造 batch 的 t
timesteps = torch.tensor([t], device=device).repeat(batch_size)
# 预测噪声
# 注意diffusers 的 step 需要 model_output
model_output = self._predict_noise(noisy_actions, timesteps, cond_feat)
# 移除噪声 (Step)
noisy_actions = self.scheduler.step(
model_output, t, noisy_actions
).prev_sample
# 2.3 Reshape 回 (B, Chunk, ActDim)
pred_actions = noisy_actions.view(batch_size, self.chunk_size, self.action_dim)
return {"pred_actions": pred_actions}
def _predict_noise(self, noisy_actions, timesteps, cond_feat): class ConditionalUnet1D(nn.Module):
"""内部辅助函数:运行简单的 MLP 网络""" def __init__(self,
# Time Embed input_dim,
t_emb = self.time_emb(timesteps.float().unsqueeze(-1)) # (B, Hidden) local_cond_dim=None,
global_cond_dim=None,
diffusion_step_embed_dim=256,
down_dims=[256,512,1024],
kernel_size=3,
n_groups=8,
cond_predict_scale=False
):
super().__init__()
all_dims = [input_dim] + list(down_dims)
start_dim = down_dims[0]
dsed = diffusion_step_embed_dim
diffusion_step_encoder = nn.Sequential(
SinusoidalPosEmb(dsed),
nn.Linear(dsed, dsed * 4),
nn.Mish(),
nn.Linear(dsed * 4, dsed),
)
cond_dim = dsed
if global_cond_dim is not None:
cond_dim += global_cond_dim
in_out = list(zip(all_dims[:-1], all_dims[1:]))
local_cond_encoder = None
if local_cond_dim is not None:
_, dim_out = in_out[0]
dim_in = local_cond_dim
local_cond_encoder = nn.ModuleList([
# down encoder
ConditionalResidualBlock1D(
dim_in, dim_out, cond_dim=cond_dim,
kernel_size=kernel_size, n_groups=n_groups,
cond_predict_scale=cond_predict_scale),
# up encoder
ConditionalResidualBlock1D(
dim_in, dim_out, cond_dim=cond_dim,
kernel_size=kernel_size, n_groups=n_groups,
cond_predict_scale=cond_predict_scale)
])
mid_dim = all_dims[-1]
self.mid_modules = nn.ModuleList([
ConditionalResidualBlock1D(
mid_dim, mid_dim, cond_dim=cond_dim,
kernel_size=kernel_size, n_groups=n_groups,
cond_predict_scale=cond_predict_scale
),
ConditionalResidualBlock1D(
mid_dim, mid_dim, cond_dim=cond_dim,
kernel_size=kernel_size, n_groups=n_groups,
cond_predict_scale=cond_predict_scale
),
])
down_modules = nn.ModuleList([])
for ind, (dim_in, dim_out) in enumerate(in_out):
is_last = ind >= (len(in_out) - 1)
down_modules.append(nn.ModuleList([
ConditionalResidualBlock1D(
dim_in, dim_out, cond_dim=cond_dim,
kernel_size=kernel_size, n_groups=n_groups,
cond_predict_scale=cond_predict_scale),
ConditionalResidualBlock1D(
dim_out, dim_out, cond_dim=cond_dim,
kernel_size=kernel_size, n_groups=n_groups,
cond_predict_scale=cond_predict_scale),
Downsample1d(dim_out) if not is_last else nn.Identity()
]))
up_modules = nn.ModuleList([])
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
is_last = ind >= (len(in_out) - 1)
up_modules.append(nn.ModuleList([
ConditionalResidualBlock1D(
dim_out*2, dim_in, cond_dim=cond_dim,
kernel_size=kernel_size, n_groups=n_groups,
cond_predict_scale=cond_predict_scale),
ConditionalResidualBlock1D(
dim_in, dim_in, cond_dim=cond_dim,
kernel_size=kernel_size, n_groups=n_groups,
cond_predict_scale=cond_predict_scale),
Upsample1d(dim_in) if not is_last else nn.Identity()
]))
# Fusion: Concat Action + (Condition * Time) final_conv = nn.Sequential(
# 这里用简单的相加融合,实际可以更复杂 Conv1dBlock(start_dim, start_dim, kernel_size=kernel_size),
fused_feat = cond_feat + t_emb nn.Conv1d(start_dim, input_dim, 1),
)
self.diffusion_step_encoder = diffusion_step_encoder
self.local_cond_encoder = local_cond_encoder
self.up_modules = up_modules
self.down_modules = down_modules
self.final_conv = final_conv
def forward(self,
sample: torch.Tensor,
timestep: Union[torch.Tensor, float, int],
local_cond=None, global_cond=None, **kwargs):
"""
x: (B,T,input_dim)
timestep: (B,) or int, diffusion step
local_cond: (B,T,local_cond_dim)
global_cond: (B,global_cond_dim)
output: (B,T,input_dim)
"""
sample = einops.rearrange(sample, 'b h t -> b t h')
# 1. time
timesteps = timestep
if not torch.is_tensor(timesteps):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps.expand(sample.shape[0])
global_feature = self.diffusion_step_encoder(timesteps)
if global_cond is not None:
global_feature = torch.cat([
global_feature, global_cond
], axis=-1)
# Concat input # encode local features
x = torch.cat([noisy_actions, fused_feat], dim=-1) # 注意这里维度需要对齐,或者用 MLP 映射 h_local = list()
if local_cond is not None:
local_cond = einops.rearrange(local_cond, 'b h t -> b t h')
resnet, resnet2 = self.local_cond_encoder
x = resnet(local_cond, global_feature)
h_local.append(x)
x = resnet2(local_cond, global_feature)
h_local.append(x)
# 修正:上面的 concat 维度可能不对,为了简化代码,我们用一种更简单的方式: x = sample
# 将 cond_feat 加到 input 里需要维度匹配。 h = []
# 这里重写一个极简的 Forward: for idx, (resnet, resnet2, downsample) in enumerate(self.down_modules):
x = resnet(x, global_feature)
# 正确做法:先将 x 映射到 hidden再加 t_emb 和 cond_feat if idx == 0 and len(h_local) > 0:
# 但为了复用 self.mid_layers 定义的 Linear(Hidden + Input)... x = x + h_local[0]
# 我们用最傻瓜的方式Input = ActionCondition 直接拼接到每一层或者只拼输入 x = resnet2(x, global_feature)
h.append(x)
# 让我们修正一下网络结构逻辑,确保不报错: x = downsample(x)
# Input: NoisyAction (Dim_A)
# Cond: Hidden (Dim_H) for mid_module in self.mid_modules:
x = mid_module(x, global_feature)
# 这种临时写的 MLP 容易维度不匹配,我们改用一个极其稳健的计算流:
# x = Action for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules):
# h = Cond + Time x = torch.cat((x, h.pop()), dim=1)
# input = cat([x, h]) -> Linear -> Output x = resnet(x, global_feature)
# The correct condition should be:
# 重新定义 _predict_noise 的逻辑依赖于 __init__ 里的定义。 # if idx == (len(self.up_modules)-1) and len(h_local) > 0:
# 为了保证一次跑通,我使用动态 cat: # However this change will break compatibility with published checkpoints.
# Therefore it is left as a comment.
x = noisy_actions if idx == len(self.up_modules) and len(h_local) > 0:
# 假设 mid_layers 的输入是 hidden_dim + action_flat_dim x = x + h_local[1]
# 我们把 condition 映射成 hidden_dim然后 concat x = resnet2(x, global_feature)
x = upsample(x)
# 真正的计算流:
h = cond_feat + t_emb # (B, Hidden) x = self.final_conv(x)
# 把 h 拼接到 x 上 (前提是 x 是 action flat) x = einops.rearrange(x, 'b t h -> b h t')
# Linear 输入维度是 Hidden + ActFlat return x
model_input = torch.cat([h, x], dim=-1)
for layer in self.mid_layers:
# Residual connection mechanism
out = layer(model_input)
model_input = out + model_input # Simple ResNet
return self.final_layer(model_input)