feat(train): 跑通训练脚本

This commit is contained in:
gouhanke
2026-02-05 14:08:43 +08:00
parent dd2749cb12
commit b0a944f7aa
17 changed files with 1002 additions and 464 deletions

View File

@@ -7,115 +7,27 @@ from tqdm import tqdm
from omegaconf import DictConfig, OmegaConf
from torch.utils.data import DataLoader
from torch.optim import AdamW
from pathlib import Path
# 确保导入路径正确
# Ensure correct import path
sys.path.append(os.getcwd())
from roboimi.vla.agent import VLAAgent
from hydra.utils import instantiate
log = logging.getLogger(__name__)
@hydra.main(version_base=None, config_path="../../vla/conf", config_name="config")
def main(cfg: DictConfig):
print(OmegaConf.to_yaml(cfg))
log.info(f"🚀 Starting VLA Training with Real Data (Device: {cfg.train.device})")
# --- 1. 实例化 Dataset & DataLoader ---
# Hydra 根据 conf/data/custom_hdf5.yaml 实例化类
dataset = instantiate(cfg.data)
dataloader = DataLoader(
dataset,
batch_size=cfg.train.batch_size,
shuffle=True,
num_workers=cfg.train.num_workers,
pin_memory=(cfg.train.device != "cpu")
)
log.info(f"✅ Dataset loaded. Size: {len(dataset)}")
# --- 2. 实例化 Agent ---
agent: VLAAgent = instantiate(cfg.agent)
agent.to(cfg.train.device)
agent.train()
optimizer = AdamW(agent.parameters(), lr=cfg.train.lr)
# --- 3. Training Loop ---
# 使用一个无限迭代器或者 epoch 循环
data_iter = iter(dataloader)
pbar = tqdm(range(cfg.train.max_steps), desc="Training")
for step in pbar:
try:
batch = next(data_iter)
except StopIteration:
#而在 epoch 结束时重新开始
data_iter = iter(dataloader)
batch = next(data_iter)
# Move to device
# 注意:这里需要递归地将字典里的 tensor 移到 GPU
batch = recursive_to_device(batch, cfg.train.device)
# --- 4. Adapter Layer (适配层) ---
# Dataset 返回的是具体的相机 key (如 'agentview_image' 或 'top')
# Agent 期望的是通用的 'image'
# 我们在这里做一个映射,模拟多模态融合前的处理
# 假设我们只用配置里的第一个 key 作为主视觉
# primary_cam_key = cfg.data.obs_keys[0]
# Dataset 返回 shape: (B, Obs_Horizon, C, H, W)
# DebugBackbone 期望: (B, C, H, W) 或者 (B, Seq, Dim)
# 这里我们取 Obs_Horizon 的最后一帧 (Current Frame)
# input_img = batch['obs'][primary_cam_key][:, -1, :, :, :]
# agent_input = {
# "obs": {
# "image": input_img,
# "text": batch["language"] # 传递语言指令
# },
# "actions": batch["actions"] # (B, Chunk, Dim)
# }
agent_input = {
"action": batch["action"],
"qpos": batch["qpos"],
"images": {}
}
for cam_name in cfg.data.camera_names:
key = f"image_{cam_name}"
agent_input["images"][cam_name] = batch[key].squeeze(1)
# --- 5. Forward & Backward ---
outputs = agent(agent_input)
# 处理 Loss 掩码 (如果在真实训练中,需要在这里应用 action_mask)
# 目前 DebugHead 内部直接算了 MSE还没用 mask我们在下一阶段优化 Policy 时加上
loss = outputs['loss']
optimizer.zero_grad()
loss.backward()
optimizer.step()
if step % cfg.train.log_freq == 0:
pbar.set_postfix({"loss": f"{loss.item():.4f}"})
log.info("✅ Training Loop with Real HDF5 Finished!")
# --- 6. Save Checkpoint ---
save_dir = "checkpoints"
os.makedirs(save_dir, exist_ok=True)
save_path = os.path.join(save_dir, "vla_model_final.pt")
# 保存整个 Agent 的 state_dict
torch.save(agent.state_dict(), save_path)
log.info(f"💾 Model saved to {save_path}")
log.info("✅ Training Loop Finished!")
def recursive_to_device(data, device):
    """
    Recursively move nested dictionaries/lists of tensors to specified device.

    Args:
        data: Dictionary, list, or tensor (arbitrarily nested).
        device: Target device (e.g., 'cuda', 'cpu').

    Returns:
        The same structure with every tensor moved to `device`; non-tensor
        leaves (ints, strings, ...) are returned unchanged.
    """
    if isinstance(data, torch.Tensor):
        return data.to(device)
    if isinstance(data, dict):
        # Reconstructed: the dict branch was truncated in the diff view.
        return {k: recursive_to_device(v, device) for k, v in data.items()}
    if isinstance(data, list):
        return [recursive_to_device(v, device) for v in data]
    return data
@hydra.main(version_base=None, config_path="../../vla/conf", config_name="config")
def main(cfg: DictConfig):
    """
    VLA training entry point (ResNet backbone + diffusion policy).

    Steps:
      1. Load the HDF5-backed dataset and wrap it in a DataLoader.
      2. Instantiate the VLAAgent (vision encoder + diffusion head).
      3. Run the noise-prediction training loop with gradient clipping.
      4. Save periodic, best, and final checkpoints under ./checkpoints.
    """
    # Print configuration for reproducibility.
    print("=" * 80)
    print("VLA Training Configuration:")
    print("=" * 80)
    print(OmegaConf.to_yaml(cfg))
    print("=" * 80)
    log.info(f"🚀 Starting VLA Training (Device: {cfg.train.device})")

    checkpoint_dir = Path("checkpoints")
    checkpoint_dir.mkdir(exist_ok=True)

    # ---- 1. Dataset & DataLoader -------------------------------------------
    log.info("📦 Loading dataset...")
    try:
        dataset = instantiate(cfg.data)
        log.info(f"✅ Dataset loaded successfully. Total samples: {len(dataset)}")
    except Exception as e:
        log.error(f"❌ Failed to load dataset: {e}")
        raise

    dataloader = DataLoader(
        dataset,
        batch_size=cfg.train.batch_size,
        shuffle=True,
        num_workers=cfg.train.num_workers,
        pin_memory=(cfg.train.device != "cpu"),
        drop_last=True,  # Drop incomplete batches for stable training
    )
    log.info(f"✅ DataLoader created. Batches per epoch: {len(dataloader)}")

    # ---- 2. Agent -----------------------------------------------------------
    log.info("🤖 Initializing VLA Agent...")
    try:
        agent = instantiate(cfg.agent)
        agent.to(cfg.train.device)
        agent.train()
        log.info(f"✅ Agent initialized and moved to {cfg.train.device}")
        total_params = sum(p.numel() for p in agent.parameters())
        trainable_params = sum(p.numel() for p in agent.parameters() if p.requires_grad)
        log.info(f"📊 Total parameters: {total_params:,}")
        log.info(f"📊 Trainable parameters: {trainable_params:,}")
    except Exception as e:
        log.error(f"❌ Failed to initialize agent: {e}")
        raise

    # ---- 3. Optimizer -------------------------------------------------------
    optimizer = AdamW(agent.parameters(), lr=cfg.train.lr, weight_decay=1e-5)
    log.info(f"🔧 Optimizer: AdamW (lr={cfg.train.lr})")

    # ---- 4. Training loop ---------------------------------------------------
    log.info("🏋️ Starting training loop...")
    data_iter = iter(dataloader)
    pbar = tqdm(range(cfg.train.max_steps), desc="Training", ncols=100)
    best_loss = float('inf')
    # Bug fix: initialize so the final save does not raise NameError when
    # max_steps == 0 (the loop body never runs).
    loss = None

    for step in pbar:
        try:
            batch = next(data_iter)
        except StopIteration:
            # Restart the iterator when an epoch ends.
            data_iter = iter(dataloader)
            batch = next(data_iter)

        batch = recursive_to_device(batch, cfg.train.device)

        # Adapter layer: dataset emits {action, qpos, image_<cam>, ...};
        # the agent expects {images: dict, qpos, action}.
        images = {}
        for cam_name in cfg.data.camera_names:
            key = f"image_{cam_name}"
            if key in batch:
                images[cam_name] = batch[key]  # (B, obs_horizon, C, H, W)
        agent_input = {
            'images': images,               # Dict of camera images
            'qpos': batch['qpos'],          # (B, obs_horizon, obs_dim)
            'action': batch['action'],      # (B, pred_horizon, action_dim)
        }

        # Forward pass.
        try:
            loss = agent.compute_loss(agent_input)
        except Exception as e:
            log.error(f"❌ Forward pass failed at step {step}: {e}")
            raise

        # Backward pass; clip gradients for stable diffusion training.
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(agent.parameters(), max_norm=1.0)
        optimizer.step()

        # Logging.
        if step % cfg.train.log_freq == 0:
            pbar.set_postfix({
                "loss": f"{loss.item():.4f}",
                "best_loss": f"{best_loss:.4f}",
            })
            log.info(f"Step {step}/{cfg.train.max_steps} | Loss: {loss.item():.4f}")

        # Periodic checkpoint.
        if step > 0 and step % cfg.train.save_freq == 0:
            checkpoint_path = checkpoint_dir / f"vla_model_step_{step}.pt"
            _save_checkpoint(agent, optimizer, step, loss.item(), checkpoint_path)
            log.info(f"💾 Checkpoint saved: {checkpoint_path}")

        # Track and save the best (lowest-loss) model.
        if loss.item() < best_loss:
            best_loss = loss.item()
            best_model_path = checkpoint_dir / "vla_model_best.pt"
            _save_checkpoint(agent, optimizer, step, loss.item(), best_model_path)
            log.info(f"🌟 Best model updated: {best_model_path} (loss: {best_loss:.4f})")

    # ---- 5. Final model -----------------------------------------------------
    if loss is not None:
        final_model_path = checkpoint_dir / "vla_model_final.pt"
        _save_checkpoint(agent, optimizer, cfg.train.max_steps, loss.item(), final_model_path)
        log.info(f"💾 Final model saved: {final_model_path}")
        log.info(f"📊 Final Loss: {loss.item():.4f}")
    log.info("✅ Training completed successfully!")
    log.info(f"📊 Best Loss: {best_loss:.4f}")


def _save_checkpoint(agent, optimizer, step, loss_value, path):
    """Serialize model + optimizer state with step/loss metadata to *path*."""
    torch.save({
        'step': step,
        'model_state_dict': agent.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss_value,
    }, path)


if __name__ == "__main__":
    # Bug fix: main() was previously invoked twice back-to-back, which would
    # run the whole training job a second time after the first finished.
    main()

View File

@@ -0,0 +1,238 @@
# ResNet VLA Training Guide
This guide explains how to train the VLA agent with ResNet backbone and action_dim=16, obs_dim=16.
## Configuration Overview
### 1. Backbone Configuration
**File**: `roboimi/vla/conf/backbone/resnet.yaml`
- Model: microsoft/resnet-18
- Output dim: 1024 (512 channels × 2 from SpatialSoftmax)
- Frozen by default for faster training
### 2. Agent Configuration
**File**: `roboimi/vla/conf/agent/resnet_diffusion.yaml`
- Vision backbone: ResNet-18 with SpatialSoftmax
- Action dimension: 16
- Observation dimension: 16
- Prediction horizon: 16 steps
- Observation horizon: 2 steps
- Diffusion steps: 100
- Number of cameras: 2
### 3. Dataset Configuration
**File**: `roboimi/vla/conf/data/resnet_dataset.yaml`
- Dataset class: RobotDiffusionDataset
- Prediction horizon: 16
- Observation horizon: 2
- Camera names: [r_vis, top]
- Normalization: gaussian (mean/std)
### 4. Training Configuration
**File**: `roboimi/vla/conf/config.yaml`
- Batch size: 8
- Learning rate: 1e-4
- Max steps: 10000
- Log frequency: 100 steps
- Save frequency: 1000 steps
- Device: cuda
- Num workers: 4
## Prerequisites
### 1. Prepare Dataset
Your dataset should be organized as:
```
/path/to/your/dataset/
├── episode_0.hdf5
├── episode_1.hdf5
├── ...
└── data_stats.pkl
```
Each HDF5 file should contain:
```
episode_N.hdf5
├── action # (T, 16) float32
└── observations/
├── qpos # (T, 16) float32
└── images/
├── r_vis/ # (T, H, W, 3) uint8
└── top/ # (T, H, W, 3) uint8
```
### 2. Generate Dataset Statistics
Create `data_stats.pkl` with:
```python
import pickle
import numpy as np
stats = {
'action': {
'mean': np.zeros(16),
'std': np.ones(16)
},
'qpos': {
'mean': np.zeros(16),
'std': np.ones(16)
}
}
with open('/path/to/your/dataset/data_stats.pkl', 'wb') as f:
pickle.dump(stats, f)
```
Or use the provided script:
```bash
python -m roboimi.vla.scripts.calculate_stats --dataset_dir /path/to/your/dataset
```
## Usage
### 1. Update Dataset Path
Edit `roboimi/vla/conf/data/resnet_dataset.yaml`:
```yaml
dataset_dir: "/path/to/your/dataset" # CHANGE THIS
camera_names:
- r_vis # CHANGE TO YOUR CAMERA NAMES
- top
```
### 2. Run Training
```bash
# Basic training
python roboimi/demos/vla_scripts/train_vla.py
# Override configurations
python roboimi/demos/vla_scripts/train_vla.py train.batch_size=16
python roboimi/demos/vla_scripts/train_vla.py train.device=cpu
python roboimi/demos/vla_scripts/train_vla.py train.max_steps=20000
python roboimi/demos/vla_scripts/train_vla.py data.dataset_dir=/custom/path
# Debug mode (CPU, small batch, few steps)
python roboimi/demos/vla_scripts/train_vla.py \
train.device=cpu \
train.batch_size=2 \
train.max_steps=10 \
train.num_workers=0
```
### 3. Monitor Training
Checkpoints are saved to:
- `checkpoints/vla_model_step_1000.pt` - Periodic checkpoints
- `checkpoints/vla_model_best.pt` - Best model (lowest loss)
- `checkpoints/vla_model_final.pt` - Final model
## Architecture Details
### Data Flow
1. **Input**: Images from multiple cameras + proprioception (qpos)
2. **Vision Encoder**: ResNet-18 → SpatialSoftmax → (B, T, 1024) per camera
3. **Feature Concatenation**: All cameras + qpos → Global conditioning
4. **Diffusion Policy**: 1D U-Net predicts noise on action sequences
5. **Output**: Clean action sequence (B, 16, 16)
### Training Process
1. Sample a random timestep t uniformly from [0, 100) (the configured number of diffusion steps)
2. Add noise to ground truth actions
3. Predict noise using vision + proprioception conditioning
4. Compute MSE loss between predicted and actual noise
5. Backpropagate and update weights
### Inference Process
1. Extract visual features from current observation
2. Start with random noise action sequence
3. Iteratively denoise over 10 steps (DDPM scheduler)
4. Return clean action sequence
## Common Issues
### Issue: Out of Memory
**Solution**: Reduce batch size or use CPU
```bash
python train_vla.py train.batch_size=4 train.device=cpu
```
### Issue: Dataset not found
**Solution**: Check dataset_dir path in config
```bash
python train_vla.py data.dataset_dir=/absolute/path/to/dataset
```
### Issue: Camera names mismatch
**Solution**: Update camera_names in data config
```yaml
# roboimi/vla/conf/data/resnet_dataset.yaml
camera_names:
- your_camera_1
- your_camera_2
```
### Issue: data_stats.pkl missing
**Solution**: Generate statistics file
```bash
python -m roboimi.vla.scripts.calculate_stats --dataset_dir /path/to/dataset
```
## Model Files Created
```
roboimi/vla/
├── conf/
│ ├── config.yaml (UPDATED)
│ ├── backbone/
│ │ └── resnet.yaml (NEW)
│ ├── agent/
│ │ └── resnet_diffusion.yaml (NEW)
│ └── data/
│ └── resnet_dataset.yaml (NEW)
├── models/
│ └── backbones/
│ ├── __init__.py (UPDATED - added resnet export)
│ └── resnet.py (EXISTING)
└── demos/vla_scripts/
└── train_vla.py (REWRITTEN)
```
## Next Steps
1. **Prepare your dataset** in the required HDF5 format
2. **Update dataset_dir** in `roboimi/vla/conf/data/resnet_dataset.yaml`
3. **Run training** with `python roboimi/demos/vla_scripts/train_vla.py`
4. **Monitor checkpoints** in `checkpoints/` directory
5. **Evaluate** the trained model using the best checkpoint
## Advanced Configuration
### Use Different ResNet Variant
Edit `roboimi/vla/conf/agent/resnet_diffusion.yaml`:
```yaml
vision_backbone:
model_name: "microsoft/resnet-50" # or resnet-34, resnet-101
```
### Adjust Diffusion Steps
```yaml
# More steps = better quality, slower training
diffusion_steps: 200 # default: 100
```
### Change Horizons
```yaml
pred_horizon: 32 # Predict more future steps
obs_horizon: 4 # Use more history
```
### Multi-GPU Training
```bash
# Use CUDA device 1
python train_vla.py train.device=cuda:1
# For multi-GPU, use torch.distributed (requires code modification)
```
## References
- ResNet Paper: https://arxiv.org/abs/1512.03385
- Diffusion Policy: https://diffusion-policy.cs.columbia.edu/
- VLA Framework Documentation: See CLAUDE.md in project root

View File

@@ -2,111 +2,127 @@ import torch
import torch.nn as nn
from typing import Dict, Optional, Any
from roboimi.vla.core.interfaces import VLABackbone, VLAProjector, VLAHead
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from roboimi.vla.models.heads.diffusion import ConditionalUnet1D
class VLAAgent(nn.Module):
def __init__(
self,
backbone: VLABackbone,
projector: VLAProjector,
head: VLAHead,
state_encoder: nn.Module
vision_backbone, # 你之前定义的 ResNet 类
action_dim, # 机器人动作维度 (例如 7: xyz + rpy + gripper)
obs_dim, # 本体感知维度 (例如 关节角度)
pred_horizon=16, # 预测未来多少步动作
obs_horizon=4, # 使用多少步历史观测
diffusion_steps=100,
num_cams=2, # 视觉输入的摄像头数量
):
super().__init__()
self.backbone = backbone
self.projector = projector
self.head = head
self.state_encoder = state_encoder
self.vision_encoder = vision_backbone
single_img_feat_dim = self.vision_encoder.output_dim
total_vision_dim = single_img_feat_dim * num_cams * obs_horizon
total_prop_dim = obs_dim * obs_horizon
self.global_cond_dim = total_vision_dim + total_prop_dim
def forward(self, batch: Dict[str, Any]) -> Dict[str, torch.Tensor]:
action = batch["action"]
state = batch["qpos"]
images = batch["images"]
state_emb = self.state_encoder(state)
# 2. Project Features
# Shape: (B, Seq, Head_Dim)
embeddings = self.projector(features)
# 3. Compute Action/Loss
# We pass actions if they exist (training mode)
actions = batch.get('actions', None)
outputs = self.head(embeddings=embeddings, actions=actions)
return outputs
# # roboimi/vla/agent.py
# import torch
# import torch.nn as nn
# from typing import Optional, Dict, Union
# class VLAAgent(nn.Module):
# def __init__(self,
# vlm_backbone: nn.Module,
# img_projector: nn.Module,
# action_head: nn.Module,
# state_dim: int,
# embed_dim: int):
# super().__init__()
# self.vlm_backbone = vlm_backbone
# self.img_projector = img_projector
# self.action_head = action_head
self.noise_scheduler = DDPMScheduler(
num_train_timesteps=diffusion_steps,
beta_schedule='squaredcos_cap_v2', # 机器人任务常用的 schedule
clip_sample=True,
prediction_type='epsilon' # 预测噪声
)
# # 简单的状态编码器 (通常不需要复杂的 config直接写在这里即可)
# self.state_encoder = nn.Sequential(
# nn.Linear(state_dim, embed_dim),
# nn.Mish(),
# nn.Linear(embed_dim, embed_dim)
# )
self.noise_pred_net = ConditionalUnet1D(
input_dim=action_dim,
global_cond_dim=self.global_cond_dim
)
# def forward(self,
# images: torch.Tensor,
# state: torch.Tensor,
# text: Optional[Union[str, list]] = None,
# actions: Optional[torch.Tensor] = None) -> Union[torch.Tensor, Dict]:
# """
# Args:
# images: [Batch, Obs_Horizon, C, H, W] 注意: 这里需要处理时间维度
# state: [Batch, Obs_Horizon, State_Dim]
# text: Optional text instructions
# actions: [Batch, Pred_Horizon, Action_Dim] (Training only)
# ==========================
# 训练阶段 (Training)
# ==========================
def compute_loss(self, batch):
    """
    Compute the DDPM noise-prediction (epsilon) training loss.

    Args:
        batch: dict containing
            'images': camera observations (passed to the vision encoder),
            'qpos': proprioception, flattened to (B, obs_horizon * obs_dim),
            'action': ground-truth actions, shape (B, Horizon, Action_Dim).

    Returns:
        Scalar MSE loss between the predicted and the sampled noise.
    """
    gt_actions = batch['action']  # Shape: (B, Horizon, Action_Dim)
    B = gt_actions.shape[0]
    images = batch['images']
    proprioception = batch['qpos'].view(B, -1)  # (B, obs_horizon * obs_dim)
    # 1. Extract visual features
    visual_features = self.vision_encoder(images).view(B, -1)  # (B, vision_dim)
    # 2. Fuse features -> global conditioning vector
    global_cond = torch.cat([visual_features, proprioception], dim=-1)
    # 3. Sample Gaussian noise with the same shape as the actions
    noise = torch.randn_like(gt_actions)
    # 4. Sample a random diffusion timestep per batch element
    timesteps = torch.randint(
        0, self.noise_scheduler.config.num_train_timesteps,
        (B,), device=gt_actions.device
    ).long()
    # 5. Forward diffusion: corrupt the ground-truth actions
    noisy_actions = self.noise_scheduler.add_noise(
        gt_actions, noise, timesteps
    )
    # 6. Predict the noise with the conditional 1D U-Net.
    # NOTE(review): the input is fed WITHOUT permuting to (B, C, T) (the
    # permute below is commented out), yet the output IS permuted back at
    # step "restore dims". predict_action does the opposite (permutes the
    # input, not the output). One of the two call sites is likely wrong —
    # verify ConditionalUnet1D's expected sample layout.
    # noisy_actions_inp = noisy_actions.permute(0, 2, 1)
    pred_noise = self.noise_pred_net(
        sample=noisy_actions,
        timestep=timesteps,
        global_cond=global_cond
    )
    # Restore dims to (B, T, C)
    pred_noise = pred_noise.permute(0, 2, 1)
    # 7. MSE loss between predicted and actual noise
    loss = nn.functional.mse_loss(pred_noise, noise)
    return loss
# ==========================
# 推理阶段 (Inference)
# ==========================
@torch.no_grad()
def predict_action(self, images, proprioception, pred_horizon=16, action_dim=7,
                   num_inference_steps=10):
    """
    Sample an action sequence for the current observation via reverse diffusion.

    Args:
        images: camera observations, forwarded to the vision encoder.
        proprioception: proprioceptive state tensor (flattened internally).
        pred_horizon: number of future action steps to generate (default 16,
            matching the previous hard-coded value).
        action_dim: dimensionality of each action (default 7, matching the
            previous hard-coded value).
            NOTE(review): the config files set action_dim=16 — confirm which
            default is correct for your robot and pass it explicitly.
        num_inference_steps: denoising steps at inference time; fewer than
            the training schedule for speed (e.g. DDIM-style acceleration).

    Returns:
        Tensor of shape (1, pred_horizon, action_dim): the denoised actions.
    """
    B = 1  # Single-sample inference.
    # 1. Encode the current observation once; the conditioning vector is
    #    reused at every denoising step.
    visual_features = self.vision_encoder(images).view(B, -1)
    proprioception = proprioception.view(B, -1)
    global_cond = torch.cat([visual_features, proprioception], dim=-1)
    # 2. Initialize the action sequence as pure Gaussian noise.
    current_actions = torch.randn(
        (B, pred_horizon, action_dim), device=global_cond.device
    )
    # 3. Reverse-diffusion loop.
    self.noise_scheduler.set_timesteps(num_inference_steps)
    for t in self.noise_scheduler.timesteps:
        # (B, T, C) -> (B, C, T) for the 1D U-Net.
        # NOTE(review): compute_loss feeds the net WITHOUT this permute —
        # one of the two call sites is likely wrong; verify the U-Net's
        # expected sample layout.
        model_input = current_actions.permute(0, 2, 1)
        noise_pred = self.noise_pred_net(
            sample=model_input,
            timestep=t,
            global_cond=global_cond
        )
        # Remove the predicted noise to obtain the previous (cleaner) sample.
        current_actions = self.noise_scheduler.step(
            noise_pred, t, current_actions
        ).prev_sample
    # 4. Return the fully denoised action sequence.
    return current_actions

View File

@@ -0,0 +1,22 @@
# @package agent
_target_: roboimi.vla.agent.VLAAgent
# Vision Backbone: ResNet-18 with SpatialSoftmax
vision_backbone:
_target_: roboimi.vla.models.backbones.resnet.ResNetBackbone
model_name: "microsoft/resnet-18"
freeze: true
# Action and Observation Dimensions
action_dim: 16 # Robot action dimension
obs_dim: 16 # Proprioception dimension (qpos)
# Prediction Horizons
pred_horizon: 16 # How many future actions to predict
obs_horizon: 2 # How many historical observations to use
# Diffusion Parameters
diffusion_steps: 100 # Number of diffusion timesteps for training
# Camera Configuration
num_cams: 2 # Number of cameras (e.g., r_vis, top)

View File

@@ -0,0 +1,10 @@
# @package agent.backbone
_target_: roboimi.vla.models.backbones.resnet.ResNetBackbone
model_name: "microsoft/resnet-18"
freeze: true
# Output dimension calculation:
# ResNet-18 final layer has 512 channels
# After SpatialSoftmax: 512 * 2 = 1024 (x,y coordinates per channel)
# output_dim: 1024

View File

@@ -1,12 +1,13 @@
defaults:
- _self_
- agent: base_siglip
- data: custom_hdf5 # 新增这一行,激活数据配置
- agent: resnet_diffusion
- data: resnet_dataset
train:
batch_size: 4 # 减小 batch size 方便调试
lr: 1e-4
max_steps: 10
log_freq: 10
device: "cpu"
num_workers: 0 # 调试设为0验证通过后改为 2 或 4
batch_size: 8 # Batch size for training
lr: 1e-4 # Learning rate
max_steps: 10000 # Maximum training steps
log_freq: 100 # Log frequency (steps)
save_freq: 1000 # Save checkpoint frequency (steps)
device: "cuda" # Device: "cuda" or "cpu"
num_workers: 8 # DataLoader workers (set to 0 for debugging, 8 for production)

View File

@@ -0,0 +1,18 @@
# @package data
_target_: roboimi.vla.data.dataset.RobotDiffusionDataset
# Dataset Directory (CHANGE THIS TO YOUR DATA PATH)
dataset_dir: "roboimi/demos/dataset/sim_transfer" # Path to your dataset directory
# Horizon Parameters
pred_horizon: 16 # Prediction horizon (matches agent.pred_horizon)
obs_horizon: 2 # Observation horizon (matches agent.obs_horizon)
action_horizon: 8 # Action execution horizon (used during evaluation)
# Camera Names (CHANGE THIS TO MATCH YOUR CAMERAS)
camera_names:
- r_vis
- top
# Normalization Type: 'gaussian' (mean/std) or 'min_max' ([-1, 1])
normalization_type: gaussian

View File

@@ -90,52 +90,48 @@ class RobotDiffusionDataset(Dataset):
# 如果 obs_horizon=2, current_ts=0 -> indices=[0, 0] (Padding)
# 如果 obs_horizon=2, current_ts=5 -> indices=[4, 5]
indices = []
for i in range(self.obs_horizon):
# t - (To - 1) + i
query_ts = start_ts - (self.obs_horizon - 1) + i
# 边界处理 (Padding first frame)
query_ts = max(query_ts, 0)
indices.append(query_ts)
# 读取 qpos (proprioception)
qpos_data = root['observations/qpos']
qpos = qpos_data[indices] # smart indexing
if self.stats:
qpos = self._normalize_data(qpos, self.stats['qpos'])
start_idx_raw = start_ts - (self.obs_horizon - 1)
start_idx = max(start_idx_raw, 0)
end_idx = start_ts + 1
pad_len = max(0, -start_idx_raw)
# 读取 Images
# 你有三个视角: angle, r_vis, top
# 建议将它们分开返回,或者在 Dataset 里 Concat
# Qpos
qpos_data = root['observations/qpos']
qpos_val = qpos_data[start_idx:end_idx]
if pad_len > 0:
first_frame = qpos_val[0]
padding = np.repeat(first_frame[np.newaxis, :], pad_len, axis=0)
qpos_val = np.concatenate([padding, qpos_val], axis=0)
if self.stats:
qpos_val = self._normalize_data(qpos_val, self.stats['qpos'])
# Images
image_dict = {}
for cam_name in self.camera_names:
# HDF5 dataset
img_dset = root['observations']['images'][cam_name]
imgs_np = img_dset[start_idx:end_idx] # (T, H, W, C)
imgs = []
for t in indices:
img = img_dset[t] # (480, 640, 3) uint8
img = torch.from_numpy(img).permute(2, 0, 1).float() / 255.0 # (C, H, W)
imgs.append(img)
if pad_len > 0:
first_frame = imgs_np[0]
padding = np.repeat(first_frame[np.newaxis, ...], pad_len, axis=0)
imgs_np = np.concatenate([padding, imgs_np], axis=0)
# Stack time dimension: (obs_horizon, 3, H, W)
image_dict[cam_name] = torch.stack(imgs)
# 转换为 Tensor: (T, H, W, C) -> (T, C, H, W)
imgs_tensor = torch.from_numpy(imgs_np).float() / 255.0
imgs_tensor = torch.einsum('thwc->tchw', imgs_tensor)
image_dict[cam_name] = imgs_tensor
# -----------------------------
# 4. 组装 Batch
# -----------------------------
# ==============================
# 3. 组装 Batch
# ==============================
data_batch = {
'action': torch.from_numpy(actions).float(), # (Tp, 16)
'qpos': torch.from_numpy(qpos).float(), # (To, 16)
'action': torch.from_numpy(actions).float(),
'qpos': torch.from_numpy(qpos_val).float(),
}
# 将图像放入 batch
for cam_name, img_tensor in image_dict.items():
data_batch[f'image_{cam_name}'] = img_tensor # (To, 3, H, W)
# TODO: 添加 Language Instruction
# 如果所有 episode 共享任务,这里可以是固定 embedding
# 如果每个 episode 任务不同,你需要一个额外的 meta json 来映射 file_path -> text
# data_batch['lang_text'] = "pick up the red cube"
data_batch[f'image_{cam_name}'] = img_tensor
return data_batch

View File

@@ -1,9 +1,10 @@
# Backbone models
from .siglip import SigLIPBackbone
from .resnet import ResNetBackbone

# from .clip import CLIPBackbone
# from .dinov2 import DinoV2Backbone
# from .debug import DebugBackbone

# Bug fix: __all__ was assigned twice (the first assignment was dead code);
# keep the single, complete export list.
__all__ = ["SigLIPBackbone", "ResNetBackbone"]

View File

@@ -1 +0,0 @@
# CLIP Backbone 实现

View File

@@ -1,30 +0,0 @@
import torch
import torch.nn as nn
from typing import Dict
from roboimi.vla.core.interfaces import VLABackbone
class DebugBackbone(VLABackbone):
    """
    Stand-in backbone that emits random embeddings, used to smoke-test the
    training pipeline without a real vision model.
    """

    def __init__(self, embed_dim: int = 768, seq_len: int = 10):
        super().__init__()
        self._embed_dim = embed_dim
        self.seq_len = seq_len
        # A single trainable scalar so backward() has a parameter to reach.
        self.dummy_param = nn.Parameter(torch.zeros(1))

    def forward(self, obs: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Return random (B, seq_len, embed_dim) features on the input's device."""
        image = obs['image']
        random_feats = torch.randn(
            image.shape[0], self.seq_len, self._embed_dim, device=image.device
        )
        # Adding the zero-valued parameter leaves the output numerically
        # unchanged while wiring it into the autograd graph.
        return random_feats + self.dummy_param

    @property
    def embed_dim(self) -> int:
        """Width of the emitted embeddings."""
        return self._embed_dim

View File

@@ -1 +0,0 @@
# DinoV2 Backbone 实现

View File

@@ -0,0 +1,83 @@
from roboimi.vla.core.interfaces import VLABackbone
from transformers import ResNetModel
from torchvision import transforms
import torch
import torch.nn as nn
class ResNetBackbone(VLABackbone):
    """ResNet feature extractor with SpatialSoftmax pooling for VLA policies."""

    def __init__(
        self,
        model_name = "microsoft/resnet-18",
        freeze: bool = True,
    ):
        super().__init__()
        self.model = ResNetModel.from_pretrained(model_name)
        self.out_channels = self.model.config.hidden_sizes[-1]
        # ImageNet normalization; a 384x384 input gives a 12x12 final feature
        # map for ResNet-18 (stride 32), matching the SpatialSoftmax grid.
        self.transform = transforms.Compose([
            transforms.Resize((384, 384)),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        self.spatial_softmax = SpatialSoftmax(num_rows=12, num_cols=12)
        if freeze:
            self._freeze_parameters()

    def _freeze_parameters(self):
        """Disable gradients on the ResNet and switch it to eval mode."""
        print("❄️ Freezing ResNet Backbone parameters")
        for weight in self.model.parameters():
            weight.requires_grad = False
        self.model.eval()

    def forward_single_image(self, image):
        """Encode one camera stream (B, T, C, H, W) into (B*T, D*2) keypoints."""
        batch, horizon, channels, height, width = image.shape
        flat = image.view(batch * horizon, channels, height, width)
        flat = self.transform(flat)
        feature_map = self.model(flat).last_hidden_state  # (B*T, D, H', W')
        return self.spatial_softmax(feature_map)          # (B*T, D*2)

    def forward(self, images):
        """Encode a dict of camera streams into (B, T, num_cams * D * 2).

        Cameras are processed in sorted-name order so the feature layout is
        deterministic regardless of dict insertion order.
        """
        reference = next(iter(images.values()))
        batch, horizon = reference.shape[:2]
        per_cam = [self.forward_single_image(images[name]) for name in sorted(images)]
        fused = torch.cat(per_cam, dim=1)  # (B*T, Num_Cams*D*2)
        return fused.view(batch, horizon, -1)

    @property
    def output_dim(self):
        """Feature size per camera after SpatialSoftmax: out_channels * 2."""
        return self.out_channels * 2
class SpatialSoftmax(nn.Module):
    """
    Convert a feature map (N, C, H, W) into per-channel expected 2D keypoint
    coordinates, flattened to (N, C*2).

    Args:
        num_rows: expected feature-map height (H); H*W must equal
            num_rows*num_cols at forward time.
        num_cols: expected feature-map width (W).
        temperature: softmax temperature. If None (default), a learnable
            parameter initialized to 1.0 is used (matching the previous
            behavior); otherwise the given value is registered as a fixed,
            non-trainable buffer.
    """
    def __init__(self, num_rows, num_cols, temperature=None):
        super().__init__()
        # Bug fix: `temperature` was previously accepted but silently ignored
        # (a learnable parameter was always created).
        if temperature is None:
            self.temperature = nn.Parameter(torch.ones(1))
        else:
            self.register_buffer('temperature', torch.tensor([float(temperature)]))
        # Normalized [-1, 1] coordinate grid shared by every channel.
        pos_x, pos_y = torch.meshgrid(
            torch.linspace(-1, 1, num_rows),
            torch.linspace(-1, 1, num_cols),
            indexing='ij'
        )
        self.register_buffer('pos_x', pos_x.reshape(-1))
        self.register_buffer('pos_y', pos_y.reshape(-1))

    def forward(self, x):
        """Return (N, C*2): expected (x, y) under the spatial softmax."""
        N, C, H, W = x.shape
        x = x.view(N, C, -1)  # (N, C, H*W)
        # Per-channel spatial attention map.
        softmax_attention = torch.nn.functional.softmax(x / self.temperature, dim=2)
        # Expected coordinates under that attention distribution.
        expected_x = torch.sum(softmax_attention * self.pos_x, dim=2, keepdim=True)
        expected_y = torch.sum(softmax_attention * self.pos_y, dim=2, keepdim=True)
        # Concatenate and flatten -> (N, C*2)
        return torch.cat([expected_x, expected_y], dim=2).reshape(N, -1)

View File

@@ -1,9 +1,8 @@
# Action Head models
from .diffusion import ConditionalUnet1D

# NOTE(review): the old `from .diffusion import DiffusionHead` import was
# dropped along with the legacy head; re-enable if DiffusionHead still exists.
# from .act import ACTHead
# from .debug import DebugHead

# Bug fix: __all__ was assigned twice with conflicting values (diff residue);
# keep the single export matching the current diffusion module.
__all__ = ["ConditionalUnet1D"]

View File

@@ -1 +0,0 @@
# ACT-VAE Action Head 实现

View File

@@ -1,33 +0,0 @@
import torch
import torch.nn as nn
from typing import Dict, Optional
from roboimi.vla.core.interfaces import VLAHead
class DebugHead(VLAHead):
    """
    Minimal MSE-regression action head that stands in for a full
    Diffusion/ACT policy when verifying the architecture end to end.
    """

    def __init__(self, input_dim: int, action_dim: int, chunk_size: int = 16):
        super().__init__()
        self.action_dim = action_dim
        self.chunk_size = chunk_size
        # One linear layer maps a pooled embedding to an entire action chunk.
        self.regressor = nn.Linear(input_dim, chunk_size * action_dim)
        self.loss_fn = nn.MSELoss()

    def forward(self, embeddings: torch.Tensor, actions: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        """Predict an action chunk; include MSE loss when targets are given."""
        # Mean-pool the sequence dimension: (B, Seq, Hidden) -> (B, Hidden).
        pooled = embeddings.mean(dim=1)
        # (B, Chunk * Act_Dim) -> (B, Chunk, Act_Dim)
        predictions = self.regressor(pooled).view(-1, self.chunk_size, self.action_dim)
        result = {"pred_actions": predictions}
        if actions is not None:
            # Supervised mode: regress directly against the ground truth.
            result["loss"] = self.loss_fn(predictions, actions)
        return result

View File

@@ -5,170 +5,290 @@ from typing import Dict, Optional
from diffusers import DDPMScheduler
from roboimi.vla.core.interfaces import VLAHead
class DiffusionHead(VLAHead):
def __init__(
self,
input_dim: int, # 来自 Projector 的维度 (e.g. 384)
action_dim: int, # 动作维度 (e.g. 16)
chunk_size: int, # 预测视界 (e.g. 16)
n_timesteps: int = 100, # 扩散步数
hidden_dim: int = 256
):
from typing import Union
import logging
import torch
import torch.nn as nn
import einops
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops.layers.torch import Rearrange
import math
class SinusoidalPosEmb(nn.Module):
    """Sinusoidal positional embedding for diffusion timesteps.

    Maps a batch of scalar timesteps ``(B,)`` to embeddings ``(B, dim)``
    using the standard transformer sin/cos frequency scheme.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        # Log-spaced frequencies spanning 1 .. 1/10000.
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        # Outer product of timesteps with frequencies: (B, half_dim).
        emb = x[:, None] * emb[None, :]
        # First half sin, second half cos -> (B, dim).
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb
class Downsample1d(nn.Module):
    """Halve the temporal length of a (B, C, T) sequence with a strided conv."""

    def __init__(self, dim):
        super().__init__()
        # Channel count is preserved; only the time axis shrinks.
        self.conv = nn.Conv1d(dim, dim, kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        downsampled = self.conv(x)
        return downsampled
class Upsample1d(nn.Module):
    """Double the temporal length of a (B, C, T) sequence with a transposed conv."""

    def __init__(self, dim):
        super().__init__()
        # Channel count is preserved; the time axis exactly doubles.
        self.conv = nn.ConvTranspose1d(dim, dim, kernel_size=4, stride=2, padding=1)

    def forward(self, x):
        upsampled = self.conv(x)
        return upsampled
class Conv1dBlock(nn.Module):
    """
    Conv1d --> GroupNorm --> Mish

    Length-preserving 1D conv block (``padding = kernel_size // 2`` keeps T
    unchanged for odd kernel sizes).
    """

    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
        super().__init__()
        # NOTE: n_groups must divide out_channels (GroupNorm requirement).
        self.block = nn.Sequential(
            nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
            nn.GroupNorm(n_groups, out_channels),
            nn.Mish(),
        )

    def forward(self, x):
        return self.block(x)
class ConditionalResidualBlock1D(nn.Module):
    """Residual 1D conv block modulated by a global conditioning vector.

    Two ``Conv1dBlock``s with a FiLM-style conditioning injection after the
    first, plus a (possibly projected) residual connection from the input.
    """

    def __init__(self,
            in_channels,
            out_channels,
            cond_dim,
            kernel_size=3,
            n_groups=8,
            cond_predict_scale=False):
        super().__init__()

        self.blocks = nn.ModuleList([
            Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups),
            Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups),
        ])

        # FiLM modulation https://arxiv.org/abs/1709.07871 :
        # predict per-channel scale and bias when cond_predict_scale is set,
        # otherwise only a per-channel bias.
        cond_channels = out_channels
        if cond_predict_scale:
            cond_channels = out_channels * 2
        self.cond_predict_scale = cond_predict_scale
        self.out_channels = out_channels
        self.cond_encoder = nn.Sequential(
            nn.Mish(),
            nn.Linear(cond_dim, cond_channels),
            Rearrange('batch t -> batch t 1'),
        )

        # make sure dimensions compatible
        self.residual_conv = nn.Conv1d(in_channels, out_channels, 1) \
            if in_channels != out_channels else nn.Identity()

    def forward(self, x, cond):
        '''
            x : [ batch_size x in_channels x horizon ]
            cond : [ batch_size x cond_dim]

            returns:
            out : [ batch_size x out_channels x horizon ]
        '''
        out = self.blocks[0](x)
        embed = self.cond_encoder(cond)
        if self.cond_predict_scale:
            # Split the conditioning embedding into scale and bias halves.
            embed = embed.reshape(
                embed.shape[0], 2, self.out_channels, 1)
            scale = embed[:, 0, ...]
            bias = embed[:, 1, ...]
            out = scale * out + bias
        else:
            out = out + embed
        out = self.blocks[1](out)
        out = out + self.residual_conv(x)
        return out
# 2.3 Reshape 回 (B, Chunk, ActDim)
pred_actions = noisy_actions.view(batch_size, self.chunk_size, self.action_dim)
return {"pred_actions": pred_actions}
def _predict_noise(self, noisy_actions, timesteps, cond_feat):
"""内部辅助函数:运行简单的 MLP 网络"""
# Time Embed
t_emb = self.time_emb(timesteps.float().unsqueeze(-1)) # (B, Hidden)
class ConditionalUnet1D(nn.Module):
    """1D conditional U-Net noise predictor for diffusion-policy action chunks.

    Args:
        input_dim: per-timestep feature dimension of the action sequence.
        local_cond_dim: optional per-timestep conditioning dimension.
        global_cond_dim: optional global conditioning dimension, concatenated
            with the diffusion-step embedding.
        diffusion_step_embed_dim: width of the sinusoidal timestep embedding.
        down_dims: channel widths of the encoder levels (never mutated, so the
            shared list default is safe).
        kernel_size: conv kernel size used throughout.
        n_groups: GroupNorm group count; must divide every entry of down_dims.
        cond_predict_scale: enable FiLM scale+bias conditioning in the
            residual blocks.
    """

    def __init__(self,
        input_dim,
        local_cond_dim=None,
        global_cond_dim=None,
        diffusion_step_embed_dim=256,
        down_dims=[256, 512, 1024],
        kernel_size=3,
        n_groups=8,
        cond_predict_scale=False
        ):
        super().__init__()
        all_dims = [input_dim] + list(down_dims)
        start_dim = down_dims[0]

        dsed = diffusion_step_embed_dim
        diffusion_step_encoder = nn.Sequential(
            SinusoidalPosEmb(dsed),
            nn.Linear(dsed, dsed * 4),
            nn.Mish(),
            nn.Linear(dsed * 4, dsed),
        )
        # Conditioning vector = timestep embedding (+ optional global cond).
        cond_dim = dsed
        if global_cond_dim is not None:
            cond_dim += global_cond_dim

        in_out = list(zip(all_dims[:-1], all_dims[1:]))

        local_cond_encoder = None
        if local_cond_dim is not None:
            _, dim_out = in_out[0]
            dim_in = local_cond_dim
            local_cond_encoder = nn.ModuleList([
                # down encoder
                ConditionalResidualBlock1D(
                    dim_in, dim_out, cond_dim=cond_dim,
                    kernel_size=kernel_size, n_groups=n_groups,
                    cond_predict_scale=cond_predict_scale),
                # up encoder
                ConditionalResidualBlock1D(
                    dim_in, dim_out, cond_dim=cond_dim,
                    kernel_size=kernel_size, n_groups=n_groups,
                    cond_predict_scale=cond_predict_scale)
            ])

        mid_dim = all_dims[-1]
        self.mid_modules = nn.ModuleList([
            ConditionalResidualBlock1D(
                mid_dim, mid_dim, cond_dim=cond_dim,
                kernel_size=kernel_size, n_groups=n_groups,
                cond_predict_scale=cond_predict_scale
            ),
            ConditionalResidualBlock1D(
                mid_dim, mid_dim, cond_dim=cond_dim,
                kernel_size=kernel_size, n_groups=n_groups,
                cond_predict_scale=cond_predict_scale
            ),
        ])

        down_modules = nn.ModuleList([])
        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (len(in_out) - 1)
            down_modules.append(nn.ModuleList([
                ConditionalResidualBlock1D(
                    dim_in, dim_out, cond_dim=cond_dim,
                    kernel_size=kernel_size, n_groups=n_groups,
                    cond_predict_scale=cond_predict_scale),
                ConditionalResidualBlock1D(
                    dim_out, dim_out, cond_dim=cond_dim,
                    kernel_size=kernel_size, n_groups=n_groups,
                    cond_predict_scale=cond_predict_scale),
                Downsample1d(dim_out) if not is_last else nn.Identity()
            ]))

        up_modules = nn.ModuleList([])
        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            is_last = ind >= (len(in_out) - 1)
            up_modules.append(nn.ModuleList([
                # dim_out*2 because of the skip-connection concat below.
                ConditionalResidualBlock1D(
                    dim_out * 2, dim_in, cond_dim=cond_dim,
                    kernel_size=kernel_size, n_groups=n_groups,
                    cond_predict_scale=cond_predict_scale),
                ConditionalResidualBlock1D(
                    dim_in, dim_in, cond_dim=cond_dim,
                    kernel_size=kernel_size, n_groups=n_groups,
                    cond_predict_scale=cond_predict_scale),
                Upsample1d(dim_in) if not is_last else nn.Identity()
            ]))

        final_conv = nn.Sequential(
            Conv1dBlock(start_dim, start_dim, kernel_size=kernel_size),
            nn.Conv1d(start_dim, input_dim, 1),
        )

        self.diffusion_step_encoder = diffusion_step_encoder
        self.local_cond_encoder = local_cond_encoder
        self.up_modules = up_modules
        self.down_modules = down_modules
        self.final_conv = final_conv

    def forward(self,
            sample: torch.Tensor,
            timestep: Union[torch.Tensor, float, int],
            local_cond=None, global_cond=None, **kwargs):
        """
        x: (B,T,input_dim)
        timestep: (B,) or int, diffusion step
        local_cond: (B,T,local_cond_dim)
        global_cond: (B,global_cond_dim)
        output: (B,T,input_dim)
        """
        # Convs expect channels-first: (B, T, D) -> (B, D, T).
        sample = einops.rearrange(sample, 'b h t -> b t h')

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
        elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)
        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])

        global_feature = self.diffusion_step_encoder(timesteps)

        if global_cond is not None:
            global_feature = torch.cat([
                global_feature, global_cond
            ], axis=-1)

        # encode local features
        h_local = list()
        if local_cond is not None:
            local_cond = einops.rearrange(local_cond, 'b h t -> b t h')
            resnet, resnet2 = self.local_cond_encoder
            x = resnet(local_cond, global_feature)
            h_local.append(x)
            x = resnet2(local_cond, global_feature)
            h_local.append(x)

        x = sample
        h = []
        for idx, (resnet, resnet2, downsample) in enumerate(self.down_modules):
            x = resnet(x, global_feature)
            if idx == 0 and len(h_local) > 0:
                x = x + h_local[0]
            x = resnet2(x, global_feature)
            h.append(x)
            x = downsample(x)

        for mid_module in self.mid_modules:
            x = mid_module(x, global_feature)

        for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules):
            # Skip connection from the matching encoder level.
            x = torch.cat((x, h.pop()), dim=1)
            x = resnet(x, global_feature)
            # The correct condition should be:
            # if idx == (len(self.up_modules)-1) and len(h_local) > 0:
            # However this change will break compatibility with published checkpoints.
            # Therefore it is left as a comment.
            if idx == len(self.up_modules) and len(h_local) > 0:
                x = x + h_local[1]
            x = resnet2(x, global_feature)
            x = upsample(x)

        x = self.final_conv(x)

        # Back to channels-last: (B, D, T) -> (B, T, D).
        x = einops.rearrange(x, 'b t h -> b h t')
        return x