feat(train): 跑通训练脚本
This commit is contained in:
@@ -7,115 +7,27 @@ from tqdm import tqdm
|
|||||||
from omegaconf import DictConfig, OmegaConf
|
from omegaconf import DictConfig, OmegaConf
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from torch.optim import AdamW
|
from torch.optim import AdamW
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
# 确保导入路径正确
|
# Ensure correct import path
|
||||||
sys.path.append(os.getcwd())
|
sys.path.append(os.getcwd())
|
||||||
|
|
||||||
from roboimi.vla.agent import VLAAgent
|
|
||||||
from hydra.utils import instantiate
|
from hydra.utils import instantiate
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
@hydra.main(version_base=None, config_path="../../vla/conf", config_name="config")
|
|
||||||
def main(cfg: DictConfig):
|
|
||||||
print(OmegaConf.to_yaml(cfg))
|
|
||||||
log.info(f"🚀 Starting VLA Training with Real Data (Device: {cfg.train.device})")
|
|
||||||
|
|
||||||
# --- 1. 实例化 Dataset & DataLoader ---
|
|
||||||
# Hydra 根据 conf/data/custom_hdf5.yaml 实例化类
|
|
||||||
dataset = instantiate(cfg.data)
|
|
||||||
|
|
||||||
dataloader = DataLoader(
|
|
||||||
dataset,
|
|
||||||
batch_size=cfg.train.batch_size,
|
|
||||||
shuffle=True,
|
|
||||||
num_workers=cfg.train.num_workers,
|
|
||||||
pin_memory=(cfg.train.device != "cpu")
|
|
||||||
)
|
|
||||||
log.info(f"✅ Dataset loaded. Size: {len(dataset)}")
|
|
||||||
|
|
||||||
# --- 2. 实例化 Agent ---
|
|
||||||
agent: VLAAgent = instantiate(cfg.agent)
|
|
||||||
agent.to(cfg.train.device)
|
|
||||||
agent.train()
|
|
||||||
|
|
||||||
optimizer = AdamW(agent.parameters(), lr=cfg.train.lr)
|
|
||||||
|
|
||||||
# --- 3. Training Loop ---
|
|
||||||
# 使用一个无限迭代器或者 epoch 循环
|
|
||||||
data_iter = iter(dataloader)
|
|
||||||
pbar = tqdm(range(cfg.train.max_steps), desc="Training")
|
|
||||||
|
|
||||||
for step in pbar:
|
|
||||||
try:
|
|
||||||
batch = next(data_iter)
|
|
||||||
except StopIteration:
|
|
||||||
#而在 epoch 结束时重新开始
|
|
||||||
data_iter = iter(dataloader)
|
|
||||||
batch = next(data_iter)
|
|
||||||
|
|
||||||
# Move to device
|
|
||||||
# 注意:这里需要递归地将字典里的 tensor 移到 GPU
|
|
||||||
batch = recursive_to_device(batch, cfg.train.device)
|
|
||||||
|
|
||||||
# --- 4. Adapter Layer (适配层) ---
|
|
||||||
# Dataset 返回的是具体的相机 key (如 'agentview_image' 或 'top')
|
|
||||||
# Agent 期望的是通用的 'image'
|
|
||||||
# 我们在这里做一个映射,模拟多模态融合前的处理
|
|
||||||
|
|
||||||
# 假设我们只用配置里的第一个 key 作为主视觉
|
|
||||||
# primary_cam_key = cfg.data.obs_keys[0]
|
|
||||||
|
|
||||||
# Dataset 返回 shape: (B, Obs_Horizon, C, H, W)
|
|
||||||
# DebugBackbone 期望: (B, C, H, W) 或者 (B, Seq, Dim)
|
|
||||||
# 这里我们取 Obs_Horizon 的最后一帧 (Current Frame)
|
|
||||||
# input_img = batch['obs'][primary_cam_key][:, -1, :, :, :]
|
|
||||||
|
|
||||||
# agent_input = {
|
|
||||||
# "obs": {
|
|
||||||
# "image": input_img,
|
|
||||||
# "text": batch["language"] # 传递语言指令
|
|
||||||
# },
|
|
||||||
# "actions": batch["actions"] # (B, Chunk, Dim)
|
|
||||||
# }
|
|
||||||
agent_input = {
|
|
||||||
"action": batch["action"],
|
|
||||||
"qpos": batch["qpos"],
|
|
||||||
"images": {}
|
|
||||||
}
|
|
||||||
|
|
||||||
for cam_name in cfg.data.camera_names:
|
|
||||||
key = f"image_{cam_name}"
|
|
||||||
agent_input["images"][cam_name] = batch[key].squeeze(1)
|
|
||||||
|
|
||||||
# --- 5. Forward & Backward ---
|
|
||||||
outputs = agent(agent_input)
|
|
||||||
|
|
||||||
# 处理 Loss 掩码 (如果在真实训练中,需要在这里应用 action_mask)
|
|
||||||
# 目前 DebugHead 内部直接算了 MSE,还没用 mask,我们在下一阶段优化 Policy 时加上
|
|
||||||
loss = outputs['loss']
|
|
||||||
|
|
||||||
optimizer.zero_grad()
|
|
||||||
loss.backward()
|
|
||||||
optimizer.step()
|
|
||||||
|
|
||||||
if step % cfg.train.log_freq == 0:
|
|
||||||
pbar.set_postfix({"loss": f"{loss.item():.4f}"})
|
|
||||||
|
|
||||||
log.info("✅ Training Loop with Real HDF5 Finished!")
|
|
||||||
|
|
||||||
# --- 6. Save Checkpoint ---
|
|
||||||
save_dir = "checkpoints"
|
|
||||||
os.makedirs(save_dir, exist_ok=True)
|
|
||||||
save_path = os.path.join(save_dir, "vla_model_final.pt")
|
|
||||||
|
|
||||||
# 保存整个 Agent 的 state_dict
|
|
||||||
torch.save(agent.state_dict(), save_path)
|
|
||||||
log.info(f"💾 Model saved to {save_path}")
|
|
||||||
|
|
||||||
log.info("✅ Training Loop Finished!")
|
|
||||||
|
|
||||||
def recursive_to_device(data, device):
|
def recursive_to_device(data, device):
|
||||||
|
"""
|
||||||
|
Recursively move nested dictionaries/lists of tensors to specified device.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: Dictionary, list, or tensor
|
||||||
|
device: Target device (e.g., 'cuda', 'cpu')
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Data structure with all tensors moved to device
|
||||||
|
"""
|
||||||
if isinstance(data, torch.Tensor):
|
if isinstance(data, torch.Tensor):
|
||||||
return data.to(device)
|
return data.to(device)
|
||||||
elif isinstance(data, dict):
|
elif isinstance(data, dict):
|
||||||
@@ -124,5 +36,193 @@ def recursive_to_device(data, device):
|
|||||||
return [recursive_to_device(v, device) for v in data]
|
return [recursive_to_device(v, device) for v in data]
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
@hydra.main(version_base=None, config_path="../../vla/conf", config_name="config")
|
||||||
|
def main(cfg: DictConfig):
|
||||||
|
"""
|
||||||
|
VLA Training Script with ResNet Backbone and Diffusion Policy.
|
||||||
|
|
||||||
|
This script:
|
||||||
|
1. Loads dataset from HDF5 files
|
||||||
|
2. Instantiates VLAAgent with ResNet vision encoder
|
||||||
|
3. Trains diffusion-based action prediction
|
||||||
|
4. Saves checkpoints periodically
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Print configuration
|
||||||
|
print("=" * 80)
|
||||||
|
print("VLA Training Configuration:")
|
||||||
|
print("=" * 80)
|
||||||
|
print(OmegaConf.to_yaml(cfg))
|
||||||
|
print("=" * 80)
|
||||||
|
|
||||||
|
log.info(f"🚀 Starting VLA Training (Device: {cfg.train.device})")
|
||||||
|
|
||||||
|
# Create checkpoint directory
|
||||||
|
checkpoint_dir = Path("checkpoints")
|
||||||
|
checkpoint_dir.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# 1. Instantiate Dataset & DataLoader
|
||||||
|
# =========================================================================
|
||||||
|
log.info("📦 Loading dataset...")
|
||||||
|
try:
|
||||||
|
dataset = instantiate(cfg.data)
|
||||||
|
log.info(f"✅ Dataset loaded successfully. Total samples: {len(dataset)}")
|
||||||
|
except Exception as e:
|
||||||
|
log.error(f"❌ Failed to load dataset: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
dataloader = DataLoader(
|
||||||
|
dataset,
|
||||||
|
batch_size=cfg.train.batch_size,
|
||||||
|
shuffle=True,
|
||||||
|
num_workers=cfg.train.num_workers,
|
||||||
|
pin_memory=(cfg.train.device != "cpu"),
|
||||||
|
drop_last=True # Drop incomplete batches for stable training
|
||||||
|
)
|
||||||
|
log.info(f"✅ DataLoader created. Batches per epoch: {len(dataloader)}")
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# 2. Instantiate VLA Agent
|
||||||
|
# =========================================================================
|
||||||
|
log.info("🤖 Initializing VLA Agent...")
|
||||||
|
try:
|
||||||
|
agent = instantiate(cfg.agent)
|
||||||
|
agent.to(cfg.train.device)
|
||||||
|
agent.train()
|
||||||
|
log.info(f"✅ Agent initialized and moved to {cfg.train.device}")
|
||||||
|
|
||||||
|
# Count parameters
|
||||||
|
total_params = sum(p.numel() for p in agent.parameters())
|
||||||
|
trainable_params = sum(p.numel() for p in agent.parameters() if p.requires_grad)
|
||||||
|
log.info(f"📊 Total parameters: {total_params:,}")
|
||||||
|
log.info(f"📊 Trainable parameters: {trainable_params:,}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
log.error(f"❌ Failed to initialize agent: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# 3. Setup Optimizer
|
||||||
|
# =========================================================================
|
||||||
|
optimizer = AdamW(agent.parameters(), lr=cfg.train.lr, weight_decay=1e-5)
|
||||||
|
log.info(f"🔧 Optimizer: AdamW (lr={cfg.train.lr})")
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# 4. Training Loop
|
||||||
|
# =========================================================================
|
||||||
|
log.info("🏋️ Starting training loop...")
|
||||||
|
|
||||||
|
data_iter = iter(dataloader)
|
||||||
|
pbar = tqdm(range(cfg.train.max_steps), desc="Training", ncols=100)
|
||||||
|
|
||||||
|
best_loss = float('inf')
|
||||||
|
|
||||||
|
for step in pbar:
|
||||||
|
try:
|
||||||
|
batch = next(data_iter)
|
||||||
|
except StopIteration:
|
||||||
|
# Restart iterator when epoch ends
|
||||||
|
data_iter = iter(dataloader)
|
||||||
|
batch = next(data_iter)
|
||||||
|
|
||||||
|
# =====================================================================
|
||||||
|
# Move batch to device
|
||||||
|
# =====================================================================
|
||||||
|
batch = recursive_to_device(batch, cfg.train.device)
|
||||||
|
|
||||||
|
# =====================================================================
|
||||||
|
# Prepare agent input
|
||||||
|
# =====================================================================
|
||||||
|
# Dataset returns: {action, qpos, image_<cam_name>, ...}
|
||||||
|
# Agent expects: {images: dict, qpos: tensor, action: tensor}
|
||||||
|
|
||||||
|
# Extract images into a dictionary
|
||||||
|
images = {}
|
||||||
|
for cam_name in cfg.data.camera_names:
|
||||||
|
key = f"image_{cam_name}"
|
||||||
|
if key in batch:
|
||||||
|
images[cam_name] = batch[key] # (B, obs_horizon, C, H, W)
|
||||||
|
|
||||||
|
# Prepare agent input
|
||||||
|
agent_input = {
|
||||||
|
'images': images, # Dict of camera images
|
||||||
|
'qpos': batch['qpos'], # (B, obs_horizon, obs_dim)
|
||||||
|
'action': batch['action'] # (B, pred_horizon, action_dim)
|
||||||
|
}
|
||||||
|
|
||||||
|
# =====================================================================
|
||||||
|
# Forward pass & compute loss
|
||||||
|
# =====================================================================
|
||||||
|
try:
|
||||||
|
loss = agent.compute_loss(agent_input)
|
||||||
|
except Exception as e:
|
||||||
|
log.error(f"❌ Forward pass failed at step {step}: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
# =====================================================================
|
||||||
|
# Backward pass & optimization
|
||||||
|
# =====================================================================
|
||||||
|
optimizer.zero_grad()
|
||||||
|
loss.backward()
|
||||||
|
|
||||||
|
# Gradient clipping for stable training
|
||||||
|
torch.nn.utils.clip_grad_norm_(agent.parameters(), max_norm=1.0)
|
||||||
|
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
# =====================================================================
|
||||||
|
# Logging
|
||||||
|
# =====================================================================
|
||||||
|
if step % cfg.train.log_freq == 0:
|
||||||
|
pbar.set_postfix({
|
||||||
|
"loss": f"{loss.item():.4f}",
|
||||||
|
"best_loss": f"{best_loss:.4f}"
|
||||||
|
})
|
||||||
|
log.info(f"Step {step}/{cfg.train.max_steps} | Loss: {loss.item():.4f}")
|
||||||
|
|
||||||
|
# =====================================================================
|
||||||
|
# Checkpoint saving
|
||||||
|
# =====================================================================
|
||||||
|
if step > 0 and step % cfg.train.save_freq == 0:
|
||||||
|
checkpoint_path = checkpoint_dir / f"vla_model_step_{step}.pt"
|
||||||
|
torch.save({
|
||||||
|
'step': step,
|
||||||
|
'model_state_dict': agent.state_dict(),
|
||||||
|
'optimizer_state_dict': optimizer.state_dict(),
|
||||||
|
'loss': loss.item(),
|
||||||
|
}, checkpoint_path)
|
||||||
|
log.info(f"💾 Checkpoint saved: {checkpoint_path}")
|
||||||
|
|
||||||
|
# Save best model
|
||||||
|
if loss.item() < best_loss:
|
||||||
|
best_loss = loss.item()
|
||||||
|
best_model_path = checkpoint_dir / "vla_model_best.pt"
|
||||||
|
torch.save({
|
||||||
|
'step': step,
|
||||||
|
'model_state_dict': agent.state_dict(),
|
||||||
|
'optimizer_state_dict': optimizer.state_dict(),
|
||||||
|
'loss': loss.item(),
|
||||||
|
}, best_model_path)
|
||||||
|
log.info(f"🌟 Best model updated: {best_model_path} (loss: {best_loss:.4f})")
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# 5. Save Final Model
|
||||||
|
# =========================================================================
|
||||||
|
final_model_path = checkpoint_dir / "vla_model_final.pt"
|
||||||
|
torch.save({
|
||||||
|
'step': cfg.train.max_steps,
|
||||||
|
'model_state_dict': agent.state_dict(),
|
||||||
|
'optimizer_state_dict': optimizer.state_dict(),
|
||||||
|
'loss': loss.item(),
|
||||||
|
}, final_model_path)
|
||||||
|
log.info(f"💾 Final model saved: {final_model_path}")
|
||||||
|
|
||||||
|
log.info("✅ Training completed successfully!")
|
||||||
|
log.info(f"📊 Final Loss: {loss.item():.4f}")
|
||||||
|
log.info(f"📊 Best Loss: {best_loss:.4f}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
238
roboimi/vla/RESNET_TRAINING_GUIDE.md
Normal file
238
roboimi/vla/RESNET_TRAINING_GUIDE.md
Normal file
@@ -0,0 +1,238 @@
|
|||||||
|
# ResNet VLA Training Guide
|
||||||
|
|
||||||
|
This guide explains how to train the VLA agent with ResNet backbone and action_dim=16, obs_dim=16.
|
||||||
|
|
||||||
|
## Configuration Overview
|
||||||
|
|
||||||
|
### 1. Backbone Configuration
|
||||||
|
**File**: `roboimi/vla/conf/backbone/resnet.yaml`
|
||||||
|
- Model: microsoft/resnet-18
|
||||||
|
- Output dim: 1024 (512 channels × 2 from SpatialSoftmax)
|
||||||
|
- Frozen by default for faster training
|
||||||
|
|
||||||
|
### 2. Agent Configuration
|
||||||
|
**File**: `roboimi/vla/conf/agent/resnet_diffusion.yaml`
|
||||||
|
- Vision backbone: ResNet-18 with SpatialSoftmax
|
||||||
|
- Action dimension: 16
|
||||||
|
- Observation dimension: 16
|
||||||
|
- Prediction horizon: 16 steps
|
||||||
|
- Observation horizon: 2 steps
|
||||||
|
- Diffusion steps: 100
|
||||||
|
- Number of cameras: 2
|
||||||
|
|
||||||
|
### 3. Dataset Configuration
|
||||||
|
**File**: `roboimi/vla/conf/data/resnet_dataset.yaml`
|
||||||
|
- Dataset class: RobotDiffusionDataset
|
||||||
|
- Prediction horizon: 16
|
||||||
|
- Observation horizon: 2
|
||||||
|
- Camera names: [r_vis, top]
|
||||||
|
- Normalization: gaussian (mean/std)
|
||||||
|
|
||||||
|
### 4. Training Configuration
|
||||||
|
**File**: `roboimi/vla/conf/config.yaml`
|
||||||
|
- Batch size: 8
|
||||||
|
- Learning rate: 1e-4
|
||||||
|
- Max steps: 10000
|
||||||
|
- Log frequency: 100 steps
|
||||||
|
- Save frequency: 1000 steps
|
||||||
|
- Device: cuda
|
||||||
|
- Num workers: 4
|
||||||
|
|
||||||
|
## Prerequisites
|
||||||
|
|
||||||
|
### 1. Prepare Dataset
|
||||||
|
Your dataset should be organized as:
|
||||||
|
```
|
||||||
|
/path/to/your/dataset/
|
||||||
|
├── episode_0.hdf5
|
||||||
|
├── episode_1.hdf5
|
||||||
|
├── ...
|
||||||
|
└── data_stats.pkl
|
||||||
|
```
|
||||||
|
|
||||||
|
Each HDF5 file should contain:
|
||||||
|
```
|
||||||
|
episode_N.hdf5
|
||||||
|
├── action # (T, 16) float32
|
||||||
|
└── observations/
|
||||||
|
├── qpos # (T, 16) float32
|
||||||
|
└── images/
|
||||||
|
├── r_vis/ # (T, H, W, 3) uint8
|
||||||
|
└── top/ # (T, H, W, 3) uint8
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. Generate Dataset Statistics
|
||||||
|
Create `data_stats.pkl` with:
|
||||||
|
```python
|
||||||
|
import pickle
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
stats = {
|
||||||
|
'action': {
|
||||||
|
'mean': np.zeros(16),
|
||||||
|
'std': np.ones(16)
|
||||||
|
},
|
||||||
|
'qpos': {
|
||||||
|
'mean': np.zeros(16),
|
||||||
|
'std': np.ones(16)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
with open('/path/to/your/dataset/data_stats.pkl', 'wb') as f:
|
||||||
|
pickle.dump(stats, f)
|
||||||
|
```
|
||||||
|
|
||||||
|
Or use the provided script:
|
||||||
|
```bash
|
||||||
|
python -m roboimi.vla.scripts.calculate_stats --dataset_dir /path/to/your/dataset
|
||||||
|
```
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
### 1. Update Dataset Path
|
||||||
|
Edit `roboimi/vla/conf/data/resnet_dataset.yaml`:
|
||||||
|
```yaml
|
||||||
|
dataset_dir: "/path/to/your/dataset" # CHANGE THIS
|
||||||
|
camera_names:
|
||||||
|
- r_vis # CHANGE TO YOUR CAMERA NAMES
|
||||||
|
- top
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. Run Training
|
||||||
|
```bash
|
||||||
|
# Basic training
|
||||||
|
python roboimi/demos/vla_scripts/train_vla.py
|
||||||
|
|
||||||
|
# Override configurations
|
||||||
|
python roboimi/demos/vla_scripts/train_vla.py train.batch_size=16
|
||||||
|
python roboimi/demos/vla_scripts/train_vla.py train.device=cpu
|
||||||
|
python roboimi/demos/vla_scripts/train_vla.py train.max_steps=20000
|
||||||
|
python roboimi/demos/vla_scripts/train_vla.py data.dataset_dir=/custom/path
|
||||||
|
|
||||||
|
# Debug mode (CPU, small batch, few steps)
|
||||||
|
python roboimi/demos/vla_scripts/train_vla.py \
|
||||||
|
train.device=cpu \
|
||||||
|
train.batch_size=2 \
|
||||||
|
train.max_steps=10 \
|
||||||
|
train.num_workers=0
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. Monitor Training
|
||||||
|
Checkpoints are saved to:
|
||||||
|
- `checkpoints/vla_model_step_1000.pt` - Periodic checkpoints
|
||||||
|
- `checkpoints/vla_model_best.pt` - Best model (lowest loss)
|
||||||
|
- `checkpoints/vla_model_final.pt` - Final model
|
||||||
|
|
||||||
|
## Architecture Details
|
||||||
|
|
||||||
|
### Data Flow
|
||||||
|
1. **Input**: Images from multiple cameras + proprioception (qpos)
|
||||||
|
2. **Vision Encoder**: ResNet-18 → SpatialSoftmax → (B, T, 1024) per camera
|
||||||
|
3. **Feature Concatenation**: All cameras + qpos → Global conditioning
|
||||||
|
4. **Diffusion Policy**: 1D U-Net predicts noise on action sequences
|
||||||
|
5. **Output**: Clean action sequence (B, 16, 16)
|
||||||
|
|
||||||
|
### Training Process
|
||||||
|
1. Sample random timestep t from [0, 100]
|
||||||
|
2. Add noise to ground truth actions
|
||||||
|
3. Predict noise using vision + proprioception conditioning
|
||||||
|
4. Compute MSE loss between predicted and actual noise
|
||||||
|
5. Backpropagate and update weights
|
||||||
|
|
||||||
|
### Inference Process
|
||||||
|
1. Extract visual features from current observation
|
||||||
|
2. Start with random noise action sequence
|
||||||
|
3. Iteratively denoise over 10 steps (DDPM scheduler)
|
||||||
|
4. Return clean action sequence
|
||||||
|
|
||||||
|
## Common Issues
|
||||||
|
|
||||||
|
### Issue: Out of Memory
|
||||||
|
**Solution**: Reduce batch size or use CPU
|
||||||
|
```bash
|
||||||
|
python train_vla.py train.batch_size=4 train.device=cpu
|
||||||
|
```
|
||||||
|
|
||||||
|
### Issue: Dataset not found
|
||||||
|
**Solution**: Check dataset_dir path in config
|
||||||
|
```bash
|
||||||
|
python train_vla.py data.dataset_dir=/absolute/path/to/dataset
|
||||||
|
```
|
||||||
|
|
||||||
|
### Issue: Camera names mismatch
|
||||||
|
**Solution**: Update camera_names in data config
|
||||||
|
```yaml
|
||||||
|
# roboimi/vla/conf/data/resnet_dataset.yaml
|
||||||
|
camera_names:
|
||||||
|
- your_camera_1
|
||||||
|
- your_camera_2
|
||||||
|
```
|
||||||
|
|
||||||
|
### Issue: data_stats.pkl missing
|
||||||
|
**Solution**: Generate statistics file
|
||||||
|
```bash
|
||||||
|
python -m roboimi.vla.scripts.calculate_stats --dataset_dir /path/to/dataset
|
||||||
|
```
|
||||||
|
|
||||||
|
## Model Files Created
|
||||||
|
|
||||||
|
```
|
||||||
|
roboimi/vla/
|
||||||
|
├── conf/
|
||||||
|
│ ├── config.yaml (UPDATED)
|
||||||
|
│ ├── backbone/
|
||||||
|
│ │ └── resnet.yaml (NEW)
|
||||||
|
│ ├── agent/
|
||||||
|
│ │ └── resnet_diffusion.yaml (NEW)
|
||||||
|
│ └── data/
|
||||||
|
│ └── resnet_dataset.yaml (NEW)
|
||||||
|
├── models/
|
||||||
|
│ └── backbones/
|
||||||
|
│ ├── __init__.py (UPDATED - added resnet export)
|
||||||
|
│ └── resnet.py (EXISTING)
|
||||||
|
└── demos/vla_scripts/
|
||||||
|
└── train_vla.py (REWRITTEN)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Next Steps
|
||||||
|
|
||||||
|
1. **Prepare your dataset** in the required HDF5 format
|
||||||
|
2. **Update dataset_dir** in `roboimi/vla/conf/data/resnet_dataset.yaml`
|
||||||
|
3. **Run training** with `python roboimi/demos/vla_scripts/train_vla.py`
|
||||||
|
4. **Monitor checkpoints** in `checkpoints/` directory
|
||||||
|
5. **Evaluate** the trained model using the best checkpoint
|
||||||
|
|
||||||
|
## Advanced Configuration
|
||||||
|
|
||||||
|
### Use Different ResNet Variant
|
||||||
|
Edit `roboimi/vla/conf/agent/resnet_diffusion.yaml`:
|
||||||
|
```yaml
|
||||||
|
vision_backbone:
|
||||||
|
model_name: "microsoft/resnet-50" # or resnet-34, resnet-101
|
||||||
|
```
|
||||||
|
|
||||||
|
### Adjust Diffusion Steps
|
||||||
|
```yaml
|
||||||
|
# More steps = better quality, slower training
|
||||||
|
diffusion_steps: 200 # default: 100
|
||||||
|
```
|
||||||
|
|
||||||
|
### Change Horizons
|
||||||
|
```yaml
|
||||||
|
pred_horizon: 32 # Predict more future steps
|
||||||
|
obs_horizon: 4 # Use more history
|
||||||
|
```
|
||||||
|
|
||||||
|
### Multi-GPU Training
|
||||||
|
```bash
|
||||||
|
# Use CUDA device 1
|
||||||
|
python train_vla.py train.device=cuda:1
|
||||||
|
|
||||||
|
# For multi-GPU, use torch.distributed (requires code modification)
|
||||||
|
```
|
||||||
|
|
||||||
|
## References
|
||||||
|
|
||||||
|
- ResNet Paper: https://arxiv.org/abs/1512.03385
|
||||||
|
- Diffusion Policy: https://diffusion-policy.cs.columbia.edu/
|
||||||
|
- VLA Framework Documentation: See CLAUDE.md in project root
|
||||||
@@ -2,111 +2,127 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from typing import Dict, Optional, Any
|
from typing import Dict, Optional, Any
|
||||||
from roboimi.vla.core.interfaces import VLABackbone, VLAProjector, VLAHead
|
from roboimi.vla.core.interfaces import VLABackbone, VLAProjector, VLAHead
|
||||||
|
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
||||||
|
from roboimi.vla.models.heads.diffusion import ConditionalUnet1D
|
||||||
|
|
||||||
class VLAAgent(nn.Module):
|
class VLAAgent(nn.Module):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
backbone: VLABackbone,
|
vision_backbone, # 你之前定义的 ResNet 类
|
||||||
projector: VLAProjector,
|
action_dim, # 机器人动作维度 (例如 7: xyz + rpy + gripper)
|
||||||
head: VLAHead,
|
obs_dim, # 本体感知维度 (例如 关节角度)
|
||||||
state_encoder: nn.Module
|
pred_horizon=16, # 预测未来多少步动作
|
||||||
|
obs_horizon=4, # 使用多少步历史观测
|
||||||
|
diffusion_steps=100,
|
||||||
|
num_cams=2, # 视觉输入的摄像头数量
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.backbone = backbone
|
self.vision_encoder = vision_backbone
|
||||||
self.projector = projector
|
single_img_feat_dim = self.vision_encoder.output_dim
|
||||||
self.head = head
|
total_vision_dim = single_img_feat_dim * num_cams * obs_horizon
|
||||||
self.state_encoder = state_encoder
|
total_prop_dim = obs_dim * obs_horizon
|
||||||
|
self.global_cond_dim = total_vision_dim + total_prop_dim
|
||||||
|
|
||||||
def forward(self, batch: Dict[str, Any]) -> Dict[str, torch.Tensor]:
|
self.noise_scheduler = DDPMScheduler(
|
||||||
|
num_train_timesteps=diffusion_steps,
|
||||||
|
beta_schedule='squaredcos_cap_v2', # 机器人任务常用的 schedule
|
||||||
|
clip_sample=True,
|
||||||
|
prediction_type='epsilon' # 预测噪声
|
||||||
|
)
|
||||||
|
|
||||||
action = batch["action"]
|
self.noise_pred_net = ConditionalUnet1D(
|
||||||
state = batch["qpos"]
|
input_dim=action_dim,
|
||||||
images = batch["images"]
|
global_cond_dim=self.global_cond_dim
|
||||||
|
)
|
||||||
|
|
||||||
state_emb = self.state_encoder(state)
|
# ==========================
|
||||||
|
# 训练阶段 (Training)
|
||||||
|
# ==========================
|
||||||
|
def compute_loss(self, batch):
|
||||||
|
"""
|
||||||
|
batch: 包含 images, qpos (proprioception), action
|
||||||
|
"""
|
||||||
|
gt_actions = batch['action'] # Shape: (B, Horizon, Action_Dim)
|
||||||
|
B = gt_actions.shape[0]
|
||||||
|
images = batch['images']
|
||||||
|
proprioception = batch['qpos'].view(B, -1) # (B, obs_horizon * obs_dim)
|
||||||
|
|
||||||
# 2. Project Features
|
|
||||||
# Shape: (B, Seq, Head_Dim)
|
|
||||||
embeddings = self.projector(features)
|
|
||||||
|
|
||||||
# 3. Compute Action/Loss
|
# 1. 提取视觉特征
|
||||||
# We pass actions if they exist (training mode)
|
visual_features = self.vision_encoder(images).view(B, -1) # (B, vision_dim)
|
||||||
actions = batch.get('actions', None)
|
|
||||||
outputs = self.head(embeddings=embeddings, actions=actions)
|
|
||||||
|
|
||||||
return outputs
|
# 2. 融合特征 -> 全局条件 (Global Conditioning)
|
||||||
|
global_cond = torch.cat([visual_features, proprioception], dim=-1)
|
||||||
|
|
||||||
# # roboimi/vla/agent.py
|
# 3. 采样噪声
|
||||||
|
noise = torch.randn_like(gt_actions)
|
||||||
|
|
||||||
# import torch
|
# 4. 随机采样时间步 (Timesteps)
|
||||||
# import torch.nn as nn
|
timesteps = torch.randint(
|
||||||
# from typing import Optional, Dict, Union
|
0, self.noise_scheduler.config.num_train_timesteps,
|
||||||
|
(B,), device=gt_actions.device
|
||||||
|
).long()
|
||||||
|
|
||||||
# class VLAAgent(nn.Module):
|
# 5. 给动作加噪 (Forward Diffusion)
|
||||||
# def __init__(self,
|
noisy_actions = self.noise_scheduler.add_noise(
|
||||||
# vlm_backbone: nn.Module,
|
gt_actions, noise, timesteps
|
||||||
# img_projector: nn.Module,
|
)
|
||||||
# action_head: nn.Module,
|
|
||||||
# state_dim: int,
|
|
||||||
# embed_dim: int):
|
|
||||||
# super().__init__()
|
|
||||||
# self.vlm_backbone = vlm_backbone
|
|
||||||
# self.img_projector = img_projector
|
|
||||||
# self.action_head = action_head
|
|
||||||
|
|
||||||
# # 简单的状态编码器 (通常不需要复杂的 config,直接写在这里即可)
|
# 6. 网络预测噪声
|
||||||
# self.state_encoder = nn.Sequential(
|
# 注意:U-Net 1D 通常期望 channel 在中间: (B, C, T)
|
||||||
# nn.Linear(state_dim, embed_dim),
|
# noisy_actions_inp = noisy_actions.permute(0, 2, 1)
|
||||||
# nn.Mish(),
|
|
||||||
# nn.Linear(embed_dim, embed_dim)
|
|
||||||
# )
|
|
||||||
|
|
||||||
# def forward(self,
|
pred_noise = self.noise_pred_net(
|
||||||
# images: torch.Tensor,
|
sample=noisy_actions,
|
||||||
# state: torch.Tensor,
|
timestep=timesteps,
|
||||||
# text: Optional[Union[str, list]] = None,
|
global_cond=global_cond
|
||||||
# actions: Optional[torch.Tensor] = None) -> Union[torch.Tensor, Dict]:
|
)
|
||||||
# """
|
|
||||||
# Args:
|
|
||||||
# images: [Batch, Obs_Horizon, C, H, W] 注意: 这里需要处理时间维度
|
|
||||||
# state: [Batch, Obs_Horizon, State_Dim]
|
|
||||||
# text: Optional text instructions
|
|
||||||
# actions: [Batch, Pred_Horizon, Action_Dim] (Training only)
|
|
||||||
|
|
||||||
# Returns:
|
# 还原维度 (B, T, C)
|
||||||
# Training: Loss scalar
|
pred_noise = pred_noise.permute(0, 2, 1)
|
||||||
# Inference: Predicted actions
|
|
||||||
# """
|
|
||||||
|
|
||||||
# B, T, C, H, W = images.shape
|
# 7. 计算 Loss (MSE)
|
||||||
|
loss = nn.functional.mse_loss(pred_noise, noise)
|
||||||
|
return loss
|
||||||
|
|
||||||
# # 1. 图像编码 (Flatten time dimension for efficiency)
|
# ==========================
|
||||||
# # [B*T, C, H, W] -> [B*T, Vision_Dim]
|
# 推理阶段 (Inference)
|
||||||
# flat_images = images.view(B * T, C, H, W)
|
# ==========================
|
||||||
# vision_feats_dict = self.vlm_backbone(flat_images)
|
@torch.no_grad()
|
||||||
# raw_img_emb = vision_feats_dict['image_embeds'] # [B*T, Vision_Dim]
|
def predict_action(self, images, proprioception):
|
||||||
|
B = 1 # 假设单次推理
|
||||||
|
|
||||||
# # 投影并还原时间维度 -> [B, T, Embed_Dim]
|
# 1. 提取当前观测特征 (只做一次)
|
||||||
# img_emb = self.img_projector(raw_img_emb)
|
visual_features = self.vision_encoder(images).view(B, -1)
|
||||||
# img_emb = img_emb.view(B, T, -1)
|
proprioception = proprioception.view(B, -1)
|
||||||
|
global_cond = torch.cat([visual_features, proprioception], dim=-1)
|
||||||
|
|
||||||
# # 2. 状态编码
|
# 2. 初始化纯高斯噪声动作
|
||||||
# state_emb = self.state_encoder(state) # [B, T, Embed_Dim]
|
# Shape: (B, Horizon, Action_Dim)
|
||||||
|
current_actions = torch.randn(
|
||||||
|
(B, 16, 7), device=global_cond.device
|
||||||
|
)
|
||||||
|
|
||||||
# # 3. 特征融合 (这里做一个简单的 Early Fusion 示例)
|
# 3. 逐步去噪循环 (Reverse Diffusion)
|
||||||
# # 将图像特征和状态特征在特征维度拼接,或在时间维度拼接
|
self.noise_scheduler.set_timesteps(10) # 推理时可以用更少步加速 (如 DDIM)
|
||||||
# # 假设我们只用最近的一帧图像作为 Context,或者将所有历史特征作为 Context
|
|
||||||
# # 这里演示:Context = (Image_History + State_History)
|
|
||||||
# # [B, T, Embed] + [B, T, Embed] -> [B, 2*T, Embed] (Concat on time)
|
|
||||||
# context = torch.cat([img_emb, state_emb], dim=1)
|
|
||||||
|
|
||||||
# # 4. Action Head 分支
|
for t in self.noise_scheduler.timesteps:
|
||||||
# if actions is not None:
|
# 调整输入格式适应 1D CNN
|
||||||
# # --- Training Mode ---
|
model_input = current_actions.permute(0, 2, 1)
|
||||||
# # 必须返回 Loss
|
|
||||||
# return self.action_head.compute_loss(context, actions)
|
# 预测噪声
|
||||||
# else:
|
noise_pred = self.noise_pred_net(
|
||||||
# # --- Inference Mode ---
|
sample=model_input,
|
||||||
# # 必须返回预测的动作序列
|
timestep=t,
|
||||||
# return self.action_head.predict_action(context)
|
global_cond=global_cond
|
||||||
|
)
|
||||||
|
# noise_pred = noise_pred.permute(0, 2, 1)
|
||||||
|
|
||||||
|
# 移除噪声,更新 current_actions
|
||||||
|
current_actions = self.noise_scheduler.step(
|
||||||
|
noise_pred, t, current_actions
|
||||||
|
).prev_sample
|
||||||
|
|
||||||
|
# 4. 输出最终动作序列
|
||||||
|
return current_actions # 返回去噪后的干净动作
|
||||||
22
roboimi/vla/conf/agent/resnet_diffusion.yaml
Normal file
22
roboimi/vla/conf/agent/resnet_diffusion.yaml
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
# @package agent
|
||||||
|
_target_: roboimi.vla.agent.VLAAgent
|
||||||
|
|
||||||
|
# Vision Backbone: ResNet-18 with SpatialSoftmax
|
||||||
|
vision_backbone:
|
||||||
|
_target_: roboimi.vla.models.backbones.resnet.ResNetBackbone
|
||||||
|
model_name: "microsoft/resnet-18"
|
||||||
|
freeze: true
|
||||||
|
|
||||||
|
# Action and Observation Dimensions
|
||||||
|
action_dim: 16 # Robot action dimension
|
||||||
|
obs_dim: 16 # Proprioception dimension (qpos)
|
||||||
|
|
||||||
|
# Prediction Horizons
|
||||||
|
pred_horizon: 16 # How many future actions to predict
|
||||||
|
obs_horizon: 2 # How many historical observations to use
|
||||||
|
|
||||||
|
# Diffusion Parameters
|
||||||
|
diffusion_steps: 100 # Number of diffusion timesteps for training
|
||||||
|
|
||||||
|
# Camera Configuration
|
||||||
|
num_cams: 2 # Number of cameras (e.g., r_vis, top)
|
||||||
10
roboimi/vla/conf/backbone/resnet.yaml
Normal file
10
roboimi/vla/conf/backbone/resnet.yaml
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
# @package agent.backbone
|
||||||
|
_target_: roboimi.vla.models.backbones.resnet.ResNetBackbone
|
||||||
|
|
||||||
|
model_name: "microsoft/resnet-18"
|
||||||
|
freeze: true
|
||||||
|
|
||||||
|
# Output dimension calculation:
|
||||||
|
# ResNet-18 final layer has 512 channels
|
||||||
|
# After SpatialSoftmax: 512 * 2 = 1024 (x,y coordinates per channel)
|
||||||
|
# output_dim: 1024
|
||||||
@@ -1,12 +1,13 @@
|
|||||||
defaults:
|
defaults:
|
||||||
- _self_
|
- _self_
|
||||||
- agent: base_siglip
|
- agent: resnet_diffusion
|
||||||
- data: custom_hdf5 # 新增这一行,激活数据配置
|
- data: resnet_dataset
|
||||||
|
|
||||||
train:
|
train:
|
||||||
batch_size: 4 # 减小 batch size 方便调试
|
batch_size: 8 # Batch size for training
|
||||||
lr: 1e-4
|
lr: 1e-4 # Learning rate
|
||||||
max_steps: 10
|
max_steps: 10000 # Maximum training steps
|
||||||
log_freq: 10
|
log_freq: 100 # Log frequency (steps)
|
||||||
device: "cpu"
|
save_freq: 1000 # Save checkpoint frequency (steps)
|
||||||
num_workers: 0 # 调试设为0,验证通过后改为 2 或 4
|
device: "cuda" # Device: "cuda" or "cpu"
|
||||||
|
num_workers: 8 # DataLoader workers (set to 0 for debugging, 8 for production)
|
||||||
18
roboimi/vla/conf/data/resnet_dataset.yaml
Normal file
18
roboimi/vla/conf/data/resnet_dataset.yaml
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
# @package data
|
||||||
|
_target_: roboimi.vla.data.dataset.RobotDiffusionDataset
|
||||||
|
|
||||||
|
# Dataset Directory (CHANGE THIS TO YOUR DATA PATH)
|
||||||
|
dataset_dir: "roboimi/demos/dataset/sim_transfer" # Path to your dataset directory
|
||||||
|
|
||||||
|
# Horizon Parameters
|
||||||
|
pred_horizon: 16 # Prediction horizon (matches agent.pred_horizon)
|
||||||
|
obs_horizon: 2 # Observation horizon (matches agent.obs_horizon)
|
||||||
|
action_horizon: 8 # Action execution horizon (used during evaluation)
|
||||||
|
|
||||||
|
# Camera Names (CHANGE THIS TO MATCH YOUR CAMERAS)
|
||||||
|
camera_names:
|
||||||
|
- r_vis
|
||||||
|
- top
|
||||||
|
|
||||||
|
# Normalization Type: 'gaussian' (mean/std) or 'min_max' ([-1, 1])
|
||||||
|
normalization_type: gaussian
|
||||||
@@ -90,52 +90,48 @@ class RobotDiffusionDataset(Dataset):
|
|||||||
# 如果 obs_horizon=2, current_ts=0 -> indices=[0, 0] (Padding)
|
# 如果 obs_horizon=2, current_ts=0 -> indices=[0, 0] (Padding)
|
||||||
# 如果 obs_horizon=2, current_ts=5 -> indices=[4, 5]
|
# 如果 obs_horizon=2, current_ts=5 -> indices=[4, 5]
|
||||||
|
|
||||||
indices = []
|
start_idx_raw = start_ts - (self.obs_horizon - 1)
|
||||||
for i in range(self.obs_horizon):
|
start_idx = max(start_idx_raw, 0)
|
||||||
# t - (To - 1) + i
|
end_idx = start_ts + 1
|
||||||
query_ts = start_ts - (self.obs_horizon - 1) + i
|
pad_len = max(0, -start_idx_raw)
|
||||||
# 边界处理 (Padding first frame)
|
|
||||||
query_ts = max(query_ts, 0)
|
|
||||||
indices.append(query_ts)
|
|
||||||
|
|
||||||
# 读取 qpos (proprioception)
|
# Qpos
|
||||||
qpos_data = root['observations/qpos']
|
qpos_data = root['observations/qpos']
|
||||||
qpos = qpos_data[indices] # smart indexing
|
qpos_val = qpos_data[start_idx:end_idx]
|
||||||
if self.stats:
|
|
||||||
qpos = self._normalize_data(qpos, self.stats['qpos'])
|
|
||||||
|
|
||||||
# 读取 Images
|
if pad_len > 0:
|
||||||
# 你有三个视角: angle, r_vis, top
|
first_frame = qpos_val[0]
|
||||||
# 建议将它们分开返回,或者在 Dataset 里 Concat
|
padding = np.repeat(first_frame[np.newaxis, :], pad_len, axis=0)
|
||||||
|
qpos_val = np.concatenate([padding, qpos_val], axis=0)
|
||||||
|
|
||||||
|
if self.stats:
|
||||||
|
qpos_val = self._normalize_data(qpos_val, self.stats['qpos'])
|
||||||
|
|
||||||
|
# Images
|
||||||
image_dict = {}
|
image_dict = {}
|
||||||
for cam_name in self.camera_names:
|
for cam_name in self.camera_names:
|
||||||
# HDF5 dataset
|
|
||||||
img_dset = root['observations']['images'][cam_name]
|
img_dset = root['observations']['images'][cam_name]
|
||||||
|
imgs_np = img_dset[start_idx:end_idx] # (T, H, W, C)
|
||||||
|
|
||||||
imgs = []
|
if pad_len > 0:
|
||||||
for t in indices:
|
first_frame = imgs_np[0]
|
||||||
img = img_dset[t] # (480, 640, 3) uint8
|
padding = np.repeat(first_frame[np.newaxis, ...], pad_len, axis=0)
|
||||||
img = torch.from_numpy(img).permute(2, 0, 1).float() / 255.0 # (C, H, W)
|
imgs_np = np.concatenate([padding, imgs_np], axis=0)
|
||||||
imgs.append(img)
|
|
||||||
|
|
||||||
# Stack time dimension: (obs_horizon, 3, H, W)
|
# 转换为 Tensor: (T, H, W, C) -> (T, C, H, W)
|
||||||
image_dict[cam_name] = torch.stack(imgs)
|
imgs_tensor = torch.from_numpy(imgs_np).float() / 255.0
|
||||||
|
imgs_tensor = torch.einsum('thwc->tchw', imgs_tensor)
|
||||||
|
image_dict[cam_name] = imgs_tensor
|
||||||
|
|
||||||
# -----------------------------
|
# ==============================
|
||||||
# 4. 组装 Batch
|
# 3. 组装 Batch
|
||||||
# -----------------------------
|
# ==============================
|
||||||
data_batch = {
|
data_batch = {
|
||||||
'action': torch.from_numpy(actions).float(), # (Tp, 16)
|
'action': torch.from_numpy(actions).float(),
|
||||||
'qpos': torch.from_numpy(qpos).float(), # (To, 16)
|
'qpos': torch.from_numpy(qpos_val).float(),
|
||||||
}
|
}
|
||||||
# 将图像放入 batch
|
|
||||||
for cam_name, img_tensor in image_dict.items():
|
for cam_name, img_tensor in image_dict.items():
|
||||||
data_batch[f'image_{cam_name}'] = img_tensor # (To, 3, H, W)
|
data_batch[f'image_{cam_name}'] = img_tensor
|
||||||
|
|
||||||
# TODO: 添加 Language Instruction
|
|
||||||
# 如果所有 episode 共享任务,这里可以是固定 embedding
|
|
||||||
# 如果每个 episode 任务不同,你需要一个额外的 meta json 来映射 file_path -> text
|
|
||||||
# data_batch['lang_text'] = "pick up the red cube"
|
|
||||||
|
|
||||||
return data_batch
|
return data_batch
|
||||||
|
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
# Backbone models
|
# Backbone models
|
||||||
from .siglip import SigLIPBackbone
|
from .siglip import SigLIPBackbone
|
||||||
|
from .resnet import ResNetBackbone
|
||||||
# from .clip import CLIPBackbone
|
# from .clip import CLIPBackbone
|
||||||
# from .dinov2 import DinoV2Backbone
|
# from .dinov2 import DinoV2Backbone
|
||||||
|
|
||||||
__all__ = ["SigLIPBackbone"]
|
__all__ = ["SigLIPBackbone", "ResNetBackbone"]
|
||||||
|
|
||||||
# from .debug import DebugBackbone
|
# from .debug import DebugBackbone
|
||||||
# __all__ = ["DebugBackbone"]
|
# __all__ = ["DebugBackbone"]
|
||||||
@@ -1 +0,0 @@
|
|||||||
# CLIP Backbone 实现
|
|
||||||
@@ -1,30 +0,0 @@
|
|||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
from typing import Dict
|
|
||||||
from roboimi.vla.core.interfaces import VLABackbone
|
|
||||||
|
|
||||||
class DebugBackbone(VLABackbone):
|
|
||||||
"""
|
|
||||||
A fake backbone that outputs random tensors.
|
|
||||||
"""
|
|
||||||
def __init__(self, embed_dim: int = 768, seq_len: int = 10):
|
|
||||||
super().__init__()
|
|
||||||
self._embed_dim = embed_dim
|
|
||||||
self.seq_len = seq_len
|
|
||||||
# A dummy trainable parameter
|
|
||||||
self.dummy_param = nn.Parameter(torch.zeros(1))
|
|
||||||
|
|
||||||
def forward(self, obs: Dict[str, torch.Tensor]) -> torch.Tensor:
|
|
||||||
batch_size = obs['image'].shape[0]
|
|
||||||
|
|
||||||
# 1. Generate random noise
|
|
||||||
noise = torch.randn(batch_size, self.seq_len, self._embed_dim, device=obs['image'].device)
|
|
||||||
|
|
||||||
# 2. CRITICAL FIX: Add the dummy parameter to the noise.
|
|
||||||
# This connects 'noise' to 'self.dummy_param' in the computation graph.
|
|
||||||
# The value doesn't change (since param is 0), but the gradient path is established.
|
|
||||||
return noise + self.dummy_param
|
|
||||||
|
|
||||||
@property
|
|
||||||
def embed_dim(self) -> int:
|
|
||||||
return self._embed_dim
|
|
||||||
@@ -1 +0,0 @@
|
|||||||
# DinoV2 Backbone 实现
|
|
||||||
83
roboimi/vla/models/backbones/resnet.py
Normal file
83
roboimi/vla/models/backbones/resnet.py
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
from roboimi.vla.core.interfaces import VLABackbone
|
||||||
|
from transformers import ResNetModel
|
||||||
|
from torchvision import transforms
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
class ResNetBackbone(VLABackbone):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_name = "microsoft/resnet-18",
|
||||||
|
freeze: bool = True,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.model = ResNetModel.from_pretrained(model_name)
|
||||||
|
self.out_channels = self.model.config.hidden_sizes[-1]
|
||||||
|
self.transform = transforms.Compose([
|
||||||
|
transforms.Resize((384, 384)),
|
||||||
|
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||||
|
])
|
||||||
|
self.spatial_softmax = SpatialSoftmax(num_rows=12, num_cols=12)
|
||||||
|
if freeze:
|
||||||
|
self._freeze_parameters()
|
||||||
|
|
||||||
|
def _freeze_parameters(self):
|
||||||
|
print("❄️ Freezing ResNet Backbone parameters")
|
||||||
|
for param in self.model.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
self.model.eval()
|
||||||
|
|
||||||
|
def forward_single_image(self, image):
|
||||||
|
B, T, C, H, W = image.shape
|
||||||
|
image = image.view(B * T, C, H, W)
|
||||||
|
image = self.transform(image)
|
||||||
|
feature_map = self.model(image).last_hidden_state # (B*T, D, H', W')
|
||||||
|
features = self.spatial_softmax(feature_map) # (B*T, D*2)
|
||||||
|
return features
|
||||||
|
|
||||||
|
def forward(self, images):
|
||||||
|
any_tensor = next(iter(images.values()))
|
||||||
|
B, T = any_tensor.shape[:2]
|
||||||
|
features_all = []
|
||||||
|
sorted_cam_names = sorted(images.keys())
|
||||||
|
for cam_name in sorted_cam_names:
|
||||||
|
img = images[cam_name]
|
||||||
|
features = self.forward_single_image(img) # (B*T, D*2)
|
||||||
|
features_all.append(features)
|
||||||
|
combined_features = torch.cat(features_all, dim=1) # (B*T, Num_Cams*D*2)
|
||||||
|
return combined_features.view(B, T, -1)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def output_dim(self):
|
||||||
|
"""Output dimension after spatial softmax: out_channels * 2"""
|
||||||
|
return self.out_channels * 2
|
||||||
|
|
||||||
|
class SpatialSoftmax(nn.Module):
|
||||||
|
"""
|
||||||
|
将特征图 (N, C, H, W) 转换为坐标特征 (N, C*2)
|
||||||
|
"""
|
||||||
|
def __init__(self, num_rows, num_cols, temperature=None):
|
||||||
|
super().__init__()
|
||||||
|
self.temperature = nn.Parameter(torch.ones(1))
|
||||||
|
# 创建网格坐标
|
||||||
|
pos_x, pos_y = torch.meshgrid(
|
||||||
|
torch.linspace(-1, 1, num_rows),
|
||||||
|
torch.linspace(-1, 1, num_cols),
|
||||||
|
indexing='ij'
|
||||||
|
)
|
||||||
|
self.register_buffer('pos_x', pos_x.reshape(-1))
|
||||||
|
self.register_buffer('pos_y', pos_y.reshape(-1))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
N, C, H, W = x.shape
|
||||||
|
x = x.view(N, C, -1) # (N, C, H*W)
|
||||||
|
|
||||||
|
# 计算 Softmax 注意力图
|
||||||
|
softmax_attention = torch.nn.functional.softmax(x / self.temperature, dim=2)
|
||||||
|
|
||||||
|
# 计算期望坐标 (x, y)
|
||||||
|
expected_x = torch.sum(softmax_attention * self.pos_x, dim=2, keepdim=True)
|
||||||
|
expected_y = torch.sum(softmax_attention * self.pos_y, dim=2, keepdim=True)
|
||||||
|
|
||||||
|
# 拼接并展平 -> (N, C*2)
|
||||||
|
return torch.cat([expected_x, expected_y], dim=2).reshape(N, -1)
|
||||||
@@ -1,9 +1,8 @@
|
|||||||
# # Action Head models
|
# # Action Head models
|
||||||
from .diffusion import DiffusionHead
|
from .diffusion import ConditionalUnet1D
|
||||||
# from .act import ACTHead
|
# from .act import ACTHead
|
||||||
|
|
||||||
__all__ = ["DiffusionHead"]
|
__all__ = ["ConditionalUnet1D"]
|
||||||
|
|
||||||
# from .debug import DebugHead
|
# from .debug import DebugHead
|
||||||
|
|
||||||
# __all__ = ["DebugHead"]
|
# __all__ = ["DebugHead"]
|
||||||
@@ -1 +0,0 @@
|
|||||||
# ACT-VAE Action Head 实现
|
|
||||||
@@ -1,33 +0,0 @@
|
|||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
from typing import Dict, Optional
|
|
||||||
from roboimi.vla.core.interfaces import VLAHead
|
|
||||||
|
|
||||||
class DebugHead(VLAHead):
|
|
||||||
"""
|
|
||||||
A fake Action Head using MSE Loss.
|
|
||||||
Replaces complex Diffusion/ACT policies for architecture verification.
|
|
||||||
"""
|
|
||||||
def __init__(self, input_dim: int, action_dim: int, chunk_size: int = 16):
|
|
||||||
super().__init__()
|
|
||||||
# Simple regression from embedding -> action chunk
|
|
||||||
self.regressor = nn.Linear(input_dim, chunk_size * action_dim)
|
|
||||||
self.action_dim = action_dim
|
|
||||||
self.chunk_size = chunk_size
|
|
||||||
self.loss_fn = nn.MSELoss()
|
|
||||||
|
|
||||||
def forward(self, embeddings: torch.Tensor, actions: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
|
|
||||||
# Simple pooling over sequence dimension to get (B, Hidden)
|
|
||||||
pooled_embed = embeddings.mean(dim=1)
|
|
||||||
|
|
||||||
# Predict actions: (B, Chunk * Act_Dim) -> (B, Chunk, Act_Dim)
|
|
||||||
pred_flat = self.regressor(pooled_embed)
|
|
||||||
pred_actions = pred_flat.view(-1, self.chunk_size, self.action_dim)
|
|
||||||
|
|
||||||
output = {"pred_actions": pred_actions}
|
|
||||||
|
|
||||||
if actions is not None:
|
|
||||||
# Calculate MSE Loss against ground truth
|
|
||||||
output["loss"] = self.loss_fn(pred_actions, actions)
|
|
||||||
|
|
||||||
return output
|
|
||||||
@@ -5,170 +5,290 @@ from typing import Dict, Optional
|
|||||||
from diffusers import DDPMScheduler
|
from diffusers import DDPMScheduler
|
||||||
from roboimi.vla.core.interfaces import VLAHead
|
from roboimi.vla.core.interfaces import VLAHead
|
||||||
|
|
||||||
class DiffusionHead(VLAHead):
|
from typing import Union
|
||||||
def __init__(
|
import logging
|
||||||
self,
|
import torch
|
||||||
input_dim: int, # 来自 Projector 的维度 (e.g. 384)
|
import torch.nn as nn
|
||||||
action_dim: int, # 动作维度 (e.g. 16)
|
import einops
|
||||||
chunk_size: int, # 预测视界 (e.g. 16)
|
|
||||||
n_timesteps: int = 100, # 扩散步数
|
import torch
|
||||||
hidden_dim: int = 256
|
import torch.nn as nn
|
||||||
):
|
import torch.nn.functional as F
|
||||||
|
from einops.layers.torch import Rearrange
|
||||||
|
import math
|
||||||
|
|
||||||
|
|
||||||
|
class SinusoidalPosEmb(nn.Module):
|
||||||
|
def __init__(self, dim):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.action_dim = action_dim
|
self.dim = dim
|
||||||
self.chunk_size = chunk_size
|
|
||||||
|
|
||||||
# 1. 噪声调度器 (DDPM)
|
def forward(self, x):
|
||||||
self.scheduler = DDPMScheduler(
|
device = x.device
|
||||||
num_train_timesteps=n_timesteps,
|
half_dim = self.dim // 2
|
||||||
beta_schedule='squaredcos_cap_v2', # 现代 Diffusion 常用调度
|
emb = math.log(10000) / (half_dim - 1)
|
||||||
clip_sample=True,
|
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
|
||||||
prediction_type='epsilon' # 预测噪声
|
emb = x[:, None] * emb[None, :]
|
||||||
|
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
|
||||||
|
return emb
|
||||||
|
|
||||||
|
class Downsample1d(nn.Module):
|
||||||
|
def __init__(self, dim):
|
||||||
|
super().__init__()
|
||||||
|
self.conv = nn.Conv1d(dim, dim, 3, 2, 1)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.conv(x)
|
||||||
|
|
||||||
|
class Upsample1d(nn.Module):
|
||||||
|
def __init__(self, dim):
|
||||||
|
super().__init__()
|
||||||
|
self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.conv(x)
|
||||||
|
|
||||||
|
class Conv1dBlock(nn.Module):
|
||||||
|
'''
|
||||||
|
Conv1d --> GroupNorm --> Mish
|
||||||
|
'''
|
||||||
|
|
||||||
|
def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.block = nn.Sequential(
|
||||||
|
nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
|
||||||
|
# Rearrange('batch channels horizon -> batch channels 1 horizon'),
|
||||||
|
nn.GroupNorm(n_groups, out_channels),
|
||||||
|
# Rearrange('batch channels 1 horizon -> batch channels horizon'),
|
||||||
|
nn.Mish(),
|
||||||
)
|
)
|
||||||
|
|
||||||
# 2. 噪声预测网络 (Noise Predictor Network)
|
def forward(self, x):
|
||||||
# 输入: Noisy Action + Time Embedding + Image Embedding
|
return self.block(x)
|
||||||
# 这是一个简单的 Conditional MLP/ResNet 结构
|
|
||||||
self.time_emb = nn.Sequential(
|
|
||||||
nn.Linear(1, hidden_dim),
|
|
||||||
nn.Mish(),
|
|
||||||
nn.Linear(hidden_dim, hidden_dim)
|
|
||||||
)
|
|
||||||
|
|
||||||
self.cond_proj = nn.Linear(input_dim, hidden_dim) # 把图像特征投影一下
|
class ConditionalResidualBlock1D(nn.Module):
|
||||||
|
def __init__(self,
|
||||||
# 主干网络 (由几个 Residual Block 组成)
|
in_channels,
|
||||||
self.mid_layers = nn.ModuleList([
|
out_channels,
|
||||||
nn.Sequential(
|
cond_dim,
|
||||||
nn.Linear(hidden_dim + action_dim * chunk_size, hidden_dim),
|
kernel_size=3,
|
||||||
nn.LayerNorm(hidden_dim),
|
n_groups=8,
|
||||||
nn.Mish(),
|
cond_predict_scale=False):
|
||||||
nn.Linear(hidden_dim, hidden_dim + action_dim * chunk_size) # 简单的残差
|
super().__init__()
|
||||||
) for _ in range(3)
|
self.blocks = nn.ModuleList([
|
||||||
|
Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups),
|
||||||
|
Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups),
|
||||||
])
|
])
|
||||||
|
|
||||||
# 输出层: 预测噪声 (Shape 与 Action 相同)
|
|
||||||
self.final_layer = nn.Linear(hidden_dim + action_dim * chunk_size, action_dim * chunk_size)
|
|
||||||
|
|
||||||
def forward(self, embeddings: torch.Tensor, actions: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
|
|
||||||
"""
|
|
||||||
Unified interface for Training and Inference.
|
|
||||||
"""
|
|
||||||
device = embeddings.device
|
|
||||||
|
|
||||||
# --- 1. 处理条件 (Conditioning) ---
|
cond_channels = out_channels
|
||||||
# embeddings: (B, Seq, Dim). 我们这里做一个简化,做 Average Pooling 变成 (B, Dim)
|
if cond_predict_scale:
|
||||||
# 如果你想做更复杂的 Cross-Attention,可以在这里改
|
cond_channels = out_channels * 2
|
||||||
global_cond = embeddings.mean(dim=1)
|
self.cond_predict_scale = cond_predict_scale
|
||||||
cond_feat = self.cond_proj(global_cond) # (B, Hidden)
|
self.out_channels = out_channels
|
||||||
|
self.cond_encoder = nn.Sequential(
|
||||||
# =========================================
|
nn.Mish(),
|
||||||
# 分支 A: 训练模式 (Training)
|
nn.Linear(cond_dim, cond_channels),
|
||||||
# =========================================
|
Rearrange('batch t -> batch t 1'),
|
||||||
if actions is not None:
|
|
||||||
batch_size = actions.shape[0]
|
|
||||||
|
|
||||||
# 1.1 准备数据 (Flatten: B, Chunk, ActDim -> B, Chunk*ActDim)
|
|
||||||
actions_flat = actions.view(batch_size, -1)
|
|
||||||
|
|
||||||
# 1.2 采样噪声和时间步
|
|
||||||
noise = torch.randn_like(actions_flat)
|
|
||||||
timesteps = torch.randint(
|
|
||||||
0, self.scheduler.config.num_train_timesteps,
|
|
||||||
(batch_size,), device=device
|
|
||||||
).long()
|
|
||||||
|
|
||||||
# 1.3 加噪 (Forward Diffusion)
|
|
||||||
noisy_actions = self.scheduler.add_noise(actions_flat, noise, timesteps)
|
|
||||||
|
|
||||||
# 1.4 预测噪声 (Network Forward)
|
|
||||||
pred_noise = self._predict_noise(noisy_actions, timesteps, cond_feat)
|
|
||||||
|
|
||||||
# 1.5 计算 Loss (MSE between actual noise and predicted noise)
|
|
||||||
loss = nn.functional.mse_loss(pred_noise, noise)
|
|
||||||
|
|
||||||
return {"loss": loss}
|
|
||||||
|
|
||||||
# =========================================
|
|
||||||
# 分支 B: 推理模式 (Inference)
|
|
||||||
# =========================================
|
|
||||||
else:
|
|
||||||
batch_size = embeddings.shape[0]
|
|
||||||
|
|
||||||
# 2.1 从纯高斯噪声开始
|
|
||||||
noisy_actions = torch.randn(
|
|
||||||
batch_size, self.chunk_size * self.action_dim,
|
|
||||||
device=device
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 2.2 逐步去噪 (Reverse Diffusion Loop)
|
# make sure dimensions compatible
|
||||||
# 使用 scheduler.timesteps 自动处理步长
|
self.residual_conv = nn.Conv1d(in_channels, out_channels, 1) \
|
||||||
self.scheduler.set_timesteps(self.scheduler.config.num_train_timesteps)
|
if in_channels != out_channels else nn.Identity()
|
||||||
|
|
||||||
for t in self.scheduler.timesteps:
|
def forward(self, x, cond):
|
||||||
# 构造 batch 的 t
|
'''
|
||||||
timesteps = torch.tensor([t], device=device).repeat(batch_size)
|
x : [ batch_size x in_channels x horizon ]
|
||||||
|
cond : [ batch_size x cond_dim]
|
||||||
|
|
||||||
# 预测噪声
|
returns:
|
||||||
# 注意:diffusers 的 step 需要 model_output
|
out : [ batch_size x out_channels x horizon ]
|
||||||
model_output = self._predict_noise(noisy_actions, timesteps, cond_feat)
|
'''
|
||||||
|
out = self.blocks[0](x)
|
||||||
|
embed = self.cond_encoder(cond)
|
||||||
|
if self.cond_predict_scale:
|
||||||
|
embed = embed.reshape(
|
||||||
|
embed.shape[0], 2, self.out_channels, 1)
|
||||||
|
scale = embed[:,0,...]
|
||||||
|
bias = embed[:,1,...]
|
||||||
|
out = scale * out + bias
|
||||||
|
else:
|
||||||
|
out = out + embed
|
||||||
|
out = self.blocks[1](out)
|
||||||
|
out = out + self.residual_conv(x)
|
||||||
|
return out
|
||||||
|
|
||||||
# 移除噪声 (Step)
|
|
||||||
noisy_actions = self.scheduler.step(
|
|
||||||
model_output, t, noisy_actions
|
|
||||||
).prev_sample
|
|
||||||
|
|
||||||
# 2.3 Reshape 回 (B, Chunk, ActDim)
|
class ConditionalUnet1D(nn.Module):
|
||||||
pred_actions = noisy_actions.view(batch_size, self.chunk_size, self.action_dim)
|
def __init__(self,
|
||||||
|
input_dim,
|
||||||
|
local_cond_dim=None,
|
||||||
|
global_cond_dim=None,
|
||||||
|
diffusion_step_embed_dim=256,
|
||||||
|
down_dims=[256,512,1024],
|
||||||
|
kernel_size=3,
|
||||||
|
n_groups=8,
|
||||||
|
cond_predict_scale=False
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
all_dims = [input_dim] + list(down_dims)
|
||||||
|
start_dim = down_dims[0]
|
||||||
|
|
||||||
return {"pred_actions": pred_actions}
|
dsed = diffusion_step_embed_dim
|
||||||
|
diffusion_step_encoder = nn.Sequential(
|
||||||
|
SinusoidalPosEmb(dsed),
|
||||||
|
nn.Linear(dsed, dsed * 4),
|
||||||
|
nn.Mish(),
|
||||||
|
nn.Linear(dsed * 4, dsed),
|
||||||
|
)
|
||||||
|
cond_dim = dsed
|
||||||
|
if global_cond_dim is not None:
|
||||||
|
cond_dim += global_cond_dim
|
||||||
|
|
||||||
def _predict_noise(self, noisy_actions, timesteps, cond_feat):
|
in_out = list(zip(all_dims[:-1], all_dims[1:]))
|
||||||
"""内部辅助函数:运行简单的 MLP 网络"""
|
|
||||||
# Time Embed
|
|
||||||
t_emb = self.time_emb(timesteps.float().unsqueeze(-1)) # (B, Hidden)
|
|
||||||
|
|
||||||
# Fusion: Concat Action + (Condition * Time)
|
local_cond_encoder = None
|
||||||
# 这里用简单的相加融合,实际可以更复杂
|
if local_cond_dim is not None:
|
||||||
fused_feat = cond_feat + t_emb
|
_, dim_out = in_out[0]
|
||||||
|
dim_in = local_cond_dim
|
||||||
|
local_cond_encoder = nn.ModuleList([
|
||||||
|
# down encoder
|
||||||
|
ConditionalResidualBlock1D(
|
||||||
|
dim_in, dim_out, cond_dim=cond_dim,
|
||||||
|
kernel_size=kernel_size, n_groups=n_groups,
|
||||||
|
cond_predict_scale=cond_predict_scale),
|
||||||
|
# up encoder
|
||||||
|
ConditionalResidualBlock1D(
|
||||||
|
dim_in, dim_out, cond_dim=cond_dim,
|
||||||
|
kernel_size=kernel_size, n_groups=n_groups,
|
||||||
|
cond_predict_scale=cond_predict_scale)
|
||||||
|
])
|
||||||
|
|
||||||
# Concat input
|
mid_dim = all_dims[-1]
|
||||||
x = torch.cat([noisy_actions, fused_feat], dim=-1) # 注意这里维度需要对齐,或者用 MLP 映射
|
self.mid_modules = nn.ModuleList([
|
||||||
|
ConditionalResidualBlock1D(
|
||||||
|
mid_dim, mid_dim, cond_dim=cond_dim,
|
||||||
|
kernel_size=kernel_size, n_groups=n_groups,
|
||||||
|
cond_predict_scale=cond_predict_scale
|
||||||
|
),
|
||||||
|
ConditionalResidualBlock1D(
|
||||||
|
mid_dim, mid_dim, cond_dim=cond_dim,
|
||||||
|
kernel_size=kernel_size, n_groups=n_groups,
|
||||||
|
cond_predict_scale=cond_predict_scale
|
||||||
|
),
|
||||||
|
])
|
||||||
|
|
||||||
# 修正:上面的 concat 维度可能不对,为了简化代码,我们用一种更简单的方式:
|
down_modules = nn.ModuleList([])
|
||||||
# 将 cond_feat 加到 input 里需要维度匹配。
|
for ind, (dim_in, dim_out) in enumerate(in_out):
|
||||||
# 这里重写一个极简的 Forward:
|
is_last = ind >= (len(in_out) - 1)
|
||||||
|
down_modules.append(nn.ModuleList([
|
||||||
|
ConditionalResidualBlock1D(
|
||||||
|
dim_in, dim_out, cond_dim=cond_dim,
|
||||||
|
kernel_size=kernel_size, n_groups=n_groups,
|
||||||
|
cond_predict_scale=cond_predict_scale),
|
||||||
|
ConditionalResidualBlock1D(
|
||||||
|
dim_out, dim_out, cond_dim=cond_dim,
|
||||||
|
kernel_size=kernel_size, n_groups=n_groups,
|
||||||
|
cond_predict_scale=cond_predict_scale),
|
||||||
|
Downsample1d(dim_out) if not is_last else nn.Identity()
|
||||||
|
]))
|
||||||
|
|
||||||
# 正确做法:先将 x 映射到 hidden,再加 t_emb 和 cond_feat
|
up_modules = nn.ModuleList([])
|
||||||
# 但为了复用 self.mid_layers 定义的 Linear(Hidden + Input)...
|
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
|
||||||
# 我们用最傻瓜的方式:Input = Action,Condition 直接拼接到每一层或者只拼输入
|
is_last = ind >= (len(in_out) - 1)
|
||||||
|
up_modules.append(nn.ModuleList([
|
||||||
|
ConditionalResidualBlock1D(
|
||||||
|
dim_out*2, dim_in, cond_dim=cond_dim,
|
||||||
|
kernel_size=kernel_size, n_groups=n_groups,
|
||||||
|
cond_predict_scale=cond_predict_scale),
|
||||||
|
ConditionalResidualBlock1D(
|
||||||
|
dim_in, dim_in, cond_dim=cond_dim,
|
||||||
|
kernel_size=kernel_size, n_groups=n_groups,
|
||||||
|
cond_predict_scale=cond_predict_scale),
|
||||||
|
Upsample1d(dim_in) if not is_last else nn.Identity()
|
||||||
|
]))
|
||||||
|
|
||||||
# 让我们修正一下网络结构逻辑,确保不报错:
|
final_conv = nn.Sequential(
|
||||||
# Input: NoisyAction (Dim_A)
|
Conv1dBlock(start_dim, start_dim, kernel_size=kernel_size),
|
||||||
# Cond: Hidden (Dim_H)
|
nn.Conv1d(start_dim, input_dim, 1),
|
||||||
|
)
|
||||||
|
|
||||||
# 这种临时写的 MLP 容易维度不匹配,我们改用一个极其稳健的计算流:
|
self.diffusion_step_encoder = diffusion_step_encoder
|
||||||
# x = Action
|
self.local_cond_encoder = local_cond_encoder
|
||||||
# h = Cond + Time
|
self.up_modules = up_modules
|
||||||
# input = cat([x, h]) -> Linear -> Output
|
self.down_modules = down_modules
|
||||||
|
self.final_conv = final_conv
|
||||||
|
|
||||||
# 重新定义 _predict_noise 的逻辑依赖于 __init__ 里的定义。
|
|
||||||
# 为了保证一次跑通,我使用动态 cat:
|
|
||||||
|
|
||||||
x = noisy_actions
|
def forward(self,
|
||||||
# 假设 mid_layers 的输入是 hidden_dim + action_flat_dim
|
sample: torch.Tensor,
|
||||||
# 我们把 condition 映射成 hidden_dim,然后 concat
|
timestep: Union[torch.Tensor, float, int],
|
||||||
|
local_cond=None, global_cond=None, **kwargs):
|
||||||
|
"""
|
||||||
|
x: (B,T,input_dim)
|
||||||
|
timestep: (B,) or int, diffusion step
|
||||||
|
local_cond: (B,T,local_cond_dim)
|
||||||
|
global_cond: (B,global_cond_dim)
|
||||||
|
output: (B,T,input_dim)
|
||||||
|
"""
|
||||||
|
sample = einops.rearrange(sample, 'b h t -> b t h')
|
||||||
|
|
||||||
# 真正的计算流:
|
# 1. time
|
||||||
h = cond_feat + t_emb # (B, Hidden)
|
timesteps = timestep
|
||||||
|
if not torch.is_tensor(timesteps):
|
||||||
|
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
||||||
|
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
|
||||||
|
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
|
||||||
|
timesteps = timesteps[None].to(sample.device)
|
||||||
|
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||||
|
timesteps = timesteps.expand(sample.shape[0])
|
||||||
|
|
||||||
# 把 h 拼接到 x 上 (前提是 x 是 action flat)
|
global_feature = self.diffusion_step_encoder(timesteps)
|
||||||
# Linear 输入维度是 Hidden + ActFlat
|
|
||||||
model_input = torch.cat([h, x], dim=-1)
|
|
||||||
|
|
||||||
for layer in self.mid_layers:
|
if global_cond is not None:
|
||||||
# Residual connection mechanism
|
global_feature = torch.cat([
|
||||||
out = layer(model_input)
|
global_feature, global_cond
|
||||||
model_input = out + model_input # Simple ResNet
|
], axis=-1)
|
||||||
|
|
||||||
|
# encode local features
|
||||||
|
h_local = list()
|
||||||
|
if local_cond is not None:
|
||||||
|
local_cond = einops.rearrange(local_cond, 'b h t -> b t h')
|
||||||
|
resnet, resnet2 = self.local_cond_encoder
|
||||||
|
x = resnet(local_cond, global_feature)
|
||||||
|
h_local.append(x)
|
||||||
|
x = resnet2(local_cond, global_feature)
|
||||||
|
h_local.append(x)
|
||||||
|
|
||||||
|
x = sample
|
||||||
|
h = []
|
||||||
|
for idx, (resnet, resnet2, downsample) in enumerate(self.down_modules):
|
||||||
|
x = resnet(x, global_feature)
|
||||||
|
if idx == 0 and len(h_local) > 0:
|
||||||
|
x = x + h_local[0]
|
||||||
|
x = resnet2(x, global_feature)
|
||||||
|
h.append(x)
|
||||||
|
x = downsample(x)
|
||||||
|
|
||||||
|
for mid_module in self.mid_modules:
|
||||||
|
x = mid_module(x, global_feature)
|
||||||
|
|
||||||
|
for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules):
|
||||||
|
x = torch.cat((x, h.pop()), dim=1)
|
||||||
|
x = resnet(x, global_feature)
|
||||||
|
# The correct condition should be:
|
||||||
|
# if idx == (len(self.up_modules)-1) and len(h_local) > 0:
|
||||||
|
# However this change will break compatibility with published checkpoints.
|
||||||
|
# Therefore it is left as a comment.
|
||||||
|
if idx == len(self.up_modules) and len(h_local) > 0:
|
||||||
|
x = x + h_local[1]
|
||||||
|
x = resnet2(x, global_feature)
|
||||||
|
x = upsample(x)
|
||||||
|
|
||||||
|
x = self.final_conv(x)
|
||||||
|
|
||||||
|
x = einops.rearrange(x, 'b t h -> b h t')
|
||||||
|
return x
|
||||||
|
|
||||||
return self.final_layer(model_input)
|
|
||||||
Reference in New Issue
Block a user