From b0a944f7aa87c3eaeac0d6e0904c86eb61a5428f Mon Sep 17 00:00:00 2001 From: gouhanke <12219217+gouhanke@user.noreply.gitee.com> Date: Thu, 5 Feb 2026 14:08:43 +0800 Subject: [PATCH] =?UTF-8?q?feat(train):=20=E8=B7=91=E9=80=9A=E8=AE=AD?= =?UTF-8?q?=E7=BB=83=E8=84=9A=E6=9C=AC?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- roboimi/demos/vla_scripts/train_vla.py | 302 ++++++++----- roboimi/vla/RESNET_TRAINING_GUIDE.md | 238 +++++++++++ roboimi/vla/agent.py | 208 ++++----- roboimi/vla/conf/agent/resnet_diffusion.yaml | 22 + roboimi/vla/conf/backbone/resnet.yaml | 10 + roboimi/vla/conf/config.yaml | 17 +- roboimi/vla/conf/data/resnet_dataset.yaml | 18 + roboimi/vla/data/dataset.py | 68 ++- roboimi/vla/models/backbones/__init__.py | 3 +- roboimi/vla/models/backbones/clip.py | 1 - roboimi/vla/models/backbones/debug.py | 30 -- roboimi/vla/models/backbones/dinov2.py | 1 - roboimi/vla/models/backbones/resnet.py | 83 ++++ roboimi/vla/models/heads/__init__.py | 5 +- roboimi/vla/models/heads/act.py | 1 - roboimi/vla/models/heads/debug.py | 33 -- roboimi/vla/models/heads/diffusion.py | 426 ++++++++++++------- 17 files changed, 1002 insertions(+), 464 deletions(-) create mode 100644 roboimi/vla/RESNET_TRAINING_GUIDE.md create mode 100644 roboimi/vla/conf/agent/resnet_diffusion.yaml create mode 100644 roboimi/vla/conf/backbone/resnet.yaml create mode 100644 roboimi/vla/conf/data/resnet_dataset.yaml delete mode 100644 roboimi/vla/models/backbones/clip.py delete mode 100644 roboimi/vla/models/backbones/debug.py delete mode 100644 roboimi/vla/models/backbones/dinov2.py create mode 100644 roboimi/vla/models/backbones/resnet.py delete mode 100644 roboimi/vla/models/heads/act.py delete mode 100644 roboimi/vla/models/heads/debug.py diff --git a/roboimi/demos/vla_scripts/train_vla.py b/roboimi/demos/vla_scripts/train_vla.py index 7faf1a9..c4376f8 100644 --- a/roboimi/demos/vla_scripts/train_vla.py +++ 
b/roboimi/demos/vla_scripts/train_vla.py @@ -7,115 +7,27 @@ from tqdm import tqdm from omegaconf import DictConfig, OmegaConf from torch.utils.data import DataLoader from torch.optim import AdamW +from pathlib import Path -# 确保导入路径正确 +# Ensure correct import path sys.path.append(os.getcwd()) -from roboimi.vla.agent import VLAAgent from hydra.utils import instantiate log = logging.getLogger(__name__) -@hydra.main(version_base=None, config_path="../../vla/conf", config_name="config") -def main(cfg: DictConfig): - print(OmegaConf.to_yaml(cfg)) - log.info(f"🚀 Starting VLA Training with Real Data (Device: {cfg.train.device})") - - # --- 1. 实例化 Dataset & DataLoader --- - # Hydra 根据 conf/data/custom_hdf5.yaml 实例化类 - dataset = instantiate(cfg.data) - - dataloader = DataLoader( - dataset, - batch_size=cfg.train.batch_size, - shuffle=True, - num_workers=cfg.train.num_workers, - pin_memory=(cfg.train.device != "cpu") - ) - log.info(f"✅ Dataset loaded. Size: {len(dataset)}") - - # --- 2. 实例化 Agent --- - agent: VLAAgent = instantiate(cfg.agent) - agent.to(cfg.train.device) - agent.train() - - optimizer = AdamW(agent.parameters(), lr=cfg.train.lr) - - # --- 3. Training Loop --- - # 使用一个无限迭代器或者 epoch 循环 - data_iter = iter(dataloader) - pbar = tqdm(range(cfg.train.max_steps), desc="Training") - - for step in pbar: - try: - batch = next(data_iter) - except StopIteration: - #而在 epoch 结束时重新开始 - data_iter = iter(dataloader) - batch = next(data_iter) - - # Move to device - # 注意:这里需要递归地将字典里的 tensor 移到 GPU - batch = recursive_to_device(batch, cfg.train.device) - - # --- 4. 
def recursive_to_device(data, device):
    """Recursively move nested containers of tensors to the given device.

    Args:
        data: A tensor, or an arbitrarily nested dict/list/tuple of tensors;
            non-tensor leaves are returned unchanged.
        device: Target device (e.g. 'cuda', 'cpu').

    Returns:
        The same structure with every contained tensor moved to ``device``.
    """
    if isinstance(data, torch.Tensor):
        return data.to(device)
    if isinstance(data, dict):
        return {k: recursive_to_device(v, device) for k, v in data.items()}
    if isinstance(data, list):
        return [recursive_to_device(v, device) for v in data]
    if isinstance(data, tuple):
        # Previously tuples fell through unchanged, silently leaving their
        # tensors on the original device; handle them like lists.
        return tuple(recursive_to_device(v, device) for v in data)
    return data
def _save_checkpoint(path, step, agent, optimizer, loss_value):
    """Write a resumable training checkpoint (step, model, optimizer, loss)."""
    torch.save({
        'step': step,
        'model_state_dict': agent.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss_value,
    }, path)


@hydra.main(version_base=None, config_path="../../vla/conf", config_name="config")
def main(cfg: DictConfig):
    """VLA training entry point.

    Pipeline:
      1. Instantiate the HDF5 dataset and DataLoader from the Hydra config.
      2. Instantiate the VLAAgent (vision encoder + diffusion head).
      3. Run the training loop (diffusion loss, gradient clipping).
      4. Save periodic / best / final checkpoints under ``checkpoints/``.
    """
    # Show the fully-resolved configuration up front for reproducibility.
    print("=" * 80)
    print("VLA Training Configuration:")
    print("=" * 80)
    print(OmegaConf.to_yaml(cfg))
    print("=" * 80)

    log.info(f"🚀 Starting VLA Training (Device: {cfg.train.device})")

    checkpoint_dir = Path("checkpoints")
    checkpoint_dir.mkdir(exist_ok=True)

    # --- 1. Dataset & DataLoader -------------------------------------------
    log.info("📦 Loading dataset...")
    try:
        dataset = instantiate(cfg.data)
        log.info(f"✅ Dataset loaded successfully. Total samples: {len(dataset)}")
    except Exception as e:
        log.error(f"❌ Failed to load dataset: {e}")
        raise

    dataloader = DataLoader(
        dataset,
        batch_size=cfg.train.batch_size,
        shuffle=True,
        num_workers=cfg.train.num_workers,
        pin_memory=(cfg.train.device != "cpu"),
        drop_last=True,  # drop the ragged final batch for stable loss statistics
    )
    log.info(f"✅ DataLoader created. Batches per epoch: {len(dataloader)}")

    # --- 2. Agent ----------------------------------------------------------
    log.info("🤖 Initializing VLA Agent...")
    try:
        agent = instantiate(cfg.agent)
        agent.to(cfg.train.device)
        agent.train()
        log.info(f"✅ Agent initialized and moved to {cfg.train.device}")

        total_params = sum(p.numel() for p in agent.parameters())
        trainable_params = sum(p.numel() for p in agent.parameters() if p.requires_grad)
        log.info(f"📊 Total parameters: {total_params:,}")
        log.info(f"📊 Trainable parameters: {trainable_params:,}")
    except Exception as e:
        log.error(f"❌ Failed to initialize agent: {e}")
        raise

    # --- 3. Optimizer ------------------------------------------------------
    optimizer = AdamW(agent.parameters(), lr=cfg.train.lr, weight_decay=1e-5)
    log.info(f"🔧 Optimizer: AdamW (lr={cfg.train.lr})")

    # --- 4. Training loop --------------------------------------------------
    log.info("🏋️ Starting training loop...")
    data_iter = iter(dataloader)
    pbar = tqdm(range(cfg.train.max_steps), desc="Training", ncols=100)

    best_loss = float('inf')
    loss = None  # guards the final save when max_steps == 0

    for step in pbar:
        try:
            batch = next(data_iter)
        except StopIteration:
            # Restart the iterator when an epoch ends.
            data_iter = iter(dataloader)
            batch = next(data_iter)

        batch = recursive_to_device(batch, cfg.train.device)

        # Adapter layer: the dataset emits flat `image_<cam>` keys, while the
        # agent expects {'images': {cam: tensor}, 'qpos': ..., 'action': ...}.
        images = {}
        for cam_name in cfg.data.camera_names:
            key = f"image_{cam_name}"
            if key in batch:
                images[cam_name] = batch[key]  # (B, obs_horizon, C, H, W)

        agent_input = {
            'images': images,
            'qpos': batch['qpos'],      # (B, obs_horizon, obs_dim)
            'action': batch['action'],  # (B, pred_horizon, action_dim)
        }

        try:
            loss = agent.compute_loss(agent_input)
        except Exception as e:
            log.error(f"❌ Forward pass failed at step {step}: {e}")
            raise

        optimizer.zero_grad()
        loss.backward()
        # Gradient clipping for stable training.
        torch.nn.utils.clip_grad_norm_(agent.parameters(), max_norm=1.0)
        optimizer.step()

        if step % cfg.train.log_freq == 0:
            pbar.set_postfix({
                "loss": f"{loss.item():.4f}",
                "best_loss": f"{best_loss:.4f}",
            })
            log.info(f"Step {step}/{cfg.train.max_steps} | Loss: {loss.item():.4f}")

        # Periodic checkpoint.
        if step > 0 and step % cfg.train.save_freq == 0:
            checkpoint_path = checkpoint_dir / f"vla_model_step_{step}.pt"
            _save_checkpoint(checkpoint_path, step, agent, optimizer, loss.item())
            log.info(f"💾 Checkpoint saved: {checkpoint_path}")

        # Track and save the best (lowest-loss) model.
        if loss.item() < best_loss:
            best_loss = loss.item()
            best_model_path = checkpoint_dir / "vla_model_best.pt"
            _save_checkpoint(best_model_path, step, agent, optimizer, loss.item())
            log.info(f"🌟 Best model updated: {best_model_path} (loss: {best_loss:.4f})")

    # --- 5. Final model ----------------------------------------------------
    # `loss` is None only when the loop never ran (max_steps == 0); the old
    # code crashed with NameError in that case.
    if loss is not None:
        final_model_path = checkpoint_dir / "vla_model_final.pt"
        _save_checkpoint(final_model_path, cfg.train.max_steps, agent, optimizer, loss.item())
        log.info(f"💾 Final model saved: {final_model_path}")
        log.info(f"📊 Final Loss: {loss.item():.4f}")

    log.info("✅ Training completed successfully!")
    log.info(f"📊 Best Loss: {best_loss:.4f}")


if __name__ == "__main__":
    main()
Training Configuration +**File**: `roboimi/vla/conf/config.yaml` +- Batch size: 8 +- Learning rate: 1e-4 +- Max steps: 10000 +- Log frequency: 100 steps +- Save frequency: 1000 steps +- Device: cuda +- Num workers: 4 + +## Prerequisites + +### 1. Prepare Dataset +Your dataset should be organized as: +``` +/path/to/your/dataset/ +├── episode_0.hdf5 +├── episode_1.hdf5 +├── ... +└── data_stats.pkl +``` + +Each HDF5 file should contain: +``` +episode_N.hdf5 +├── action # (T, 16) float32 +└── observations/ + ├── qpos # (T, 16) float32 + └── images/ + ├── r_vis/ # (T, H, W, 3) uint8 + └── top/ # (T, H, W, 3) uint8 +``` + +### 2. Generate Dataset Statistics +Create `data_stats.pkl` with: +```python +import pickle +import numpy as np + +stats = { + 'action': { + 'mean': np.zeros(16), + 'std': np.ones(16) + }, + 'qpos': { + 'mean': np.zeros(16), + 'std': np.ones(16) + } +} + +with open('/path/to/your/dataset/data_stats.pkl', 'wb') as f: + pickle.dump(stats, f) +``` + +Or use the provided script: +```bash +python -m roboimi.vla.scripts.calculate_stats --dataset_dir /path/to/your/dataset +``` + +## Usage + +### 1. Update Dataset Path +Edit `roboimi/vla/conf/data/resnet_dataset.yaml`: +```yaml +dataset_dir: "/path/to/your/dataset" # CHANGE THIS +camera_names: + - r_vis # CHANGE TO YOUR CAMERA NAMES + - top +``` + +### 2. Run Training +```bash +# Basic training +python roboimi/demos/vla_scripts/train_vla.py + +# Override configurations +python roboimi/demos/vla_scripts/train_vla.py train.batch_size=16 +python roboimi/demos/vla_scripts/train_vla.py train.device=cpu +python roboimi/demos/vla_scripts/train_vla.py train.max_steps=20000 +python roboimi/demos/vla_scripts/train_vla.py data.dataset_dir=/custom/path + +# Debug mode (CPU, small batch, few steps) +python roboimi/demos/vla_scripts/train_vla.py \ + train.device=cpu \ + train.batch_size=2 \ + train.max_steps=10 \ + train.num_workers=0 +``` + +### 3. 
Monitor Training +Checkpoints are saved to: +- `checkpoints/vla_model_step_1000.pt` - Periodic checkpoints +- `checkpoints/vla_model_best.pt` - Best model (lowest loss) +- `checkpoints/vla_model_final.pt` - Final model + +## Architecture Details + +### Data Flow +1. **Input**: Images from multiple cameras + proprioception (qpos) +2. **Vision Encoder**: ResNet-18 → SpatialSoftmax → (B, T, 1024) per camera +3. **Feature Concatenation**: All cameras + qpos → Global conditioning +4. **Diffusion Policy**: 1D U-Net predicts noise on action sequences +5. **Output**: Clean action sequence (B, 16, 16) + +### Training Process +1. Sample random timestep t from [0, 100] +2. Add noise to ground truth actions +3. Predict noise using vision + proprioception conditioning +4. Compute MSE loss between predicted and actual noise +5. Backpropagate and update weights + +### Inference Process +1. Extract visual features from current observation +2. Start with random noise action sequence +3. Iteratively denoise over 10 steps (DDPM scheduler) +4. 
Return clean action sequence + +## Common Issues + +### Issue: Out of Memory +**Solution**: Reduce batch size or use CPU +```bash +python train_vla.py train.batch_size=4 train.device=cpu +``` + +### Issue: Dataset not found +**Solution**: Check dataset_dir path in config +```bash +python train_vla.py data.dataset_dir=/absolute/path/to/dataset +``` + +### Issue: Camera names mismatch +**Solution**: Update camera_names in data config +```yaml +# roboimi/vla/conf/data/resnet_dataset.yaml +camera_names: + - your_camera_1 + - your_camera_2 +``` + +### Issue: data_stats.pkl missing +**Solution**: Generate statistics file +```bash +python -m roboimi.vla.scripts.calculate_stats --dataset_dir /path/to/dataset +``` + +## Model Files Created + +``` +roboimi/vla/ +├── conf/ +│ ├── config.yaml (UPDATED) +│ ├── backbone/ +│ │ └── resnet.yaml (NEW) +│ ├── agent/ +│ │ └── resnet_diffusion.yaml (NEW) +│ └── data/ +│ └── resnet_dataset.yaml (NEW) +├── models/ +│ └── backbones/ +│ ├── __init__.py (UPDATED - added resnet export) +│ └── resnet.py (EXISTING) +└── demos/vla_scripts/ + └── train_vla.py (REWRITTEN) +``` + +## Next Steps + +1. **Prepare your dataset** in the required HDF5 format +2. **Update dataset_dir** in `roboimi/vla/conf/data/resnet_dataset.yaml` +3. **Run training** with `python roboimi/demos/vla_scripts/train_vla.py` +4. **Monitor checkpoints** in `checkpoints/` directory +5. 
class VLAAgent(nn.Module):
    """Diffusion-policy VLA agent: vision encoder + conditional 1D U-Net.

    Training predicts the noise added to ground-truth action chunks (epsilon
    prediction); inference starts from Gaussian noise and iteratively denoises
    with a DDPM scheduler, conditioned on vision + proprioception features.
    """

    def __init__(
        self,
        vision_backbone,       # encoder exposing `.output_dim`, called on a dict of camera tensors
        action_dim,            # robot action dimension (e.g. 7: xyz + rpy + gripper)
        obs_dim,               # proprioception dimension (e.g. joint angles)
        pred_horizon=16,       # number of future action steps to predict
        obs_horizon=4,         # number of past observation steps consumed
        diffusion_steps=100,   # diffusion timesteps used during training
        num_cams=2,            # number of camera streams
    ):
        super().__init__()
        self.vision_encoder = vision_backbone

        # Store dimensions so inference does not rely on hard-coded constants
        # (the previous predict_action hard-coded (B, 16, 7) regardless of
        # the configured action_dim/pred_horizon).
        self.action_dim = action_dim
        self.obs_dim = obs_dim
        self.pred_horizon = pred_horizon
        self.obs_horizon = obs_horizon

        single_img_feat_dim = self.vision_encoder.output_dim
        total_vision_dim = single_img_feat_dim * num_cams * obs_horizon
        total_prop_dim = obs_dim * obs_horizon
        self.global_cond_dim = total_vision_dim + total_prop_dim

        self.noise_scheduler = DDPMScheduler(
            num_train_timesteps=diffusion_steps,
            beta_schedule='squaredcos_cap_v2',  # schedule commonly used for robot tasks
            clip_sample=True,
            prediction_type='epsilon'           # network predicts the added noise
        )

        self.noise_pred_net = ConditionalUnet1D(
            input_dim=action_dim,
            global_cond_dim=self.global_cond_dim
        )

    # ==========================
    # Training
    # ==========================
    def compute_loss(self, batch):
        """Diffusion (epsilon-prediction) MSE loss on one training batch.

        Args:
            batch: dict with 'images' (dict of (B, To, C, H, W) tensors),
                'qpos' (B, To, obs_dim), and 'action' (B, Tp, action_dim).

        Returns:
            Scalar MSE loss between predicted and sampled noise.
        """
        gt_actions = batch['action']                   # (B, pred_horizon, action_dim)
        B = gt_actions.shape[0]
        images = batch['images']
        proprioception = batch['qpos'].reshape(B, -1)  # (B, obs_horizon * obs_dim)

        # 1. Visual features -> flattened per-sample vector.
        visual_features = self.vision_encoder(images).reshape(B, -1)

        # 2. Fuse into the global conditioning vector.
        global_cond = torch.cat([visual_features, proprioception], dim=-1)

        # 3. Sample noise and random timesteps.
        noise = torch.randn_like(gt_actions)
        timesteps = torch.randint(
            0, self.noise_scheduler.config.num_train_timesteps,
            (B,), device=gt_actions.device
        ).long()

        # 4. Forward diffusion: corrupt the ground-truth actions.
        noisy_actions = self.noise_scheduler.add_noise(gt_actions, noise, timesteps)

        # 5. Predict the noise.
        # NOTE(review): assumes ConditionalUnet1D consumes and returns
        # (B, T, action_dim) — the diffusion-policy convention. The previous
        # code transposed only the *output* here (and only the *input* in
        # predict_action); with pred_horizon == action_dim that raised no
        # shape error but silently misaligned the loss. Confirm the net's
        # layout if it is channels-first.
        pred_noise = self.noise_pred_net(
            sample=noisy_actions,
            timestep=timesteps,
            global_cond=global_cond
        )

        return nn.functional.mse_loss(pred_noise, noise)

    # ==========================
    # Inference
    # ==========================
    @torch.no_grad()
    def predict_action(self, images, proprioception, num_inference_steps=10):
        """Denoise a fresh action sequence conditioned on current observations.

        Args:
            images: dict of camera tensors shaped (B, To, C, H, W).
            proprioception: (B, To, obs_dim) or any shape flattening to
                (B, obs_horizon * obs_dim).
            num_inference_steps: denoising steps (fewer than training for speed).

        Returns:
            (B, pred_horizon, action_dim) denoised action sequence.
        """
        # 1. Encode the current observation once; derive B from the features
        #    instead of assuming a single sample.
        visual_features = self.vision_encoder(images)
        B = visual_features.shape[0]
        visual_features = visual_features.reshape(B, -1)
        proprioception = proprioception.reshape(B, -1)
        global_cond = torch.cat([visual_features, proprioception], dim=-1)

        # 2. Start from pure Gaussian noise with the configured shape.
        current_actions = torch.randn(
            (B, self.pred_horizon, self.action_dim), device=global_cond.device
        )

        # 3. Reverse diffusion loop.
        self.noise_scheduler.set_timesteps(num_inference_steps)
        for t in self.noise_scheduler.timesteps:
            noise_pred = self.noise_pred_net(
                sample=current_actions,   # same (B, T, C) layout as training
                timestep=t,
                global_cond=global_cond
            )
            current_actions = self.noise_scheduler.step(
                noise_pred, t, current_actions
            ).prev_sample

        # 4. Return the denoised action sequence.
        return current_actions
batch_size: 8 # Batch size for training + lr: 1e-4 # Learning rate + max_steps: 10000 # Maximum training steps + log_freq: 100 # Log frequency (steps) + save_freq: 1000 # Save checkpoint frequency (steps) + device: "cuda" # Device: "cuda" or "cpu" + num_workers: 8 # DataLoader workers (set to 0 for debugging, 8 for production) \ No newline at end of file diff --git a/roboimi/vla/conf/data/resnet_dataset.yaml b/roboimi/vla/conf/data/resnet_dataset.yaml new file mode 100644 index 0000000..28145a7 --- /dev/null +++ b/roboimi/vla/conf/data/resnet_dataset.yaml @@ -0,0 +1,18 @@ +# @package data +_target_: roboimi.vla.data.dataset.RobotDiffusionDataset + +# Dataset Directory (CHANGE THIS TO YOUR DATA PATH) +dataset_dir: "roboimi/demos/dataset/sim_transfer" # Path to your dataset directory + +# Horizon Parameters +pred_horizon: 16 # Prediction horizon (matches agent.pred_horizon) +obs_horizon: 2 # Observation horizon (matches agent.obs_horizon) +action_horizon: 8 # Action execution horizon (used during evaluation) + +# Camera Names (CHANGE THIS TO MATCH YOUR CAMERAS) +camera_names: + - r_vis + - top + +# Normalization Type: 'gaussian' (mean/std) or 'min_max' ([-1, 1]) +normalization_type: gaussian diff --git a/roboimi/vla/data/dataset.py b/roboimi/vla/data/dataset.py index 7e286f9..6e9b490 100644 --- a/roboimi/vla/data/dataset.py +++ b/roboimi/vla/data/dataset.py @@ -90,52 +90,48 @@ class RobotDiffusionDataset(Dataset): # 如果 obs_horizon=2, current_ts=0 -> indices=[0, 0] (Padding) # 如果 obs_horizon=2, current_ts=5 -> indices=[4, 5] - indices = [] - for i in range(self.obs_horizon): - # t - (To - 1) + i - query_ts = start_ts - (self.obs_horizon - 1) + i - # 边界处理 (Padding first frame) - query_ts = max(query_ts, 0) - indices.append(query_ts) - - # 读取 qpos (proprioception) - qpos_data = root['observations/qpos'] - qpos = qpos_data[indices] # smart indexing - if self.stats: - qpos = self._normalize_data(qpos, self.stats['qpos']) + start_idx_raw = start_ts - (self.obs_horizon - 1) 
+ start_idx = max(start_idx_raw, 0) + end_idx = start_ts + 1 + pad_len = max(0, -start_idx_raw) - # 读取 Images - # 你有三个视角: angle, r_vis, top - # 建议将它们分开返回,或者在 Dataset 里 Concat + # Qpos + qpos_data = root['observations/qpos'] + qpos_val = qpos_data[start_idx:end_idx] + + if pad_len > 0: + first_frame = qpos_val[0] + padding = np.repeat(first_frame[np.newaxis, :], pad_len, axis=0) + qpos_val = np.concatenate([padding, qpos_val], axis=0) + + if self.stats: + qpos_val = self._normalize_data(qpos_val, self.stats['qpos']) + + # Images image_dict = {} for cam_name in self.camera_names: - # HDF5 dataset img_dset = root['observations']['images'][cam_name] + imgs_np = img_dset[start_idx:end_idx] # (T, H, W, C) - imgs = [] - for t in indices: - img = img_dset[t] # (480, 640, 3) uint8 - img = torch.from_numpy(img).permute(2, 0, 1).float() / 255.0 # (C, H, W) - imgs.append(img) + if pad_len > 0: + first_frame = imgs_np[0] + padding = np.repeat(first_frame[np.newaxis, ...], pad_len, axis=0) + imgs_np = np.concatenate([padding, imgs_np], axis=0) - # Stack time dimension: (obs_horizon, 3, H, W) - image_dict[cam_name] = torch.stack(imgs) + # 转换为 Tensor: (T, H, W, C) -> (T, C, H, W) + imgs_tensor = torch.from_numpy(imgs_np).float() / 255.0 + imgs_tensor = torch.einsum('thwc->tchw', imgs_tensor) + image_dict[cam_name] = imgs_tensor - # ----------------------------- - # 4. 组装 Batch - # ----------------------------- + # ============================== + # 3. 
组装 Batch + # ============================== data_batch = { - 'action': torch.from_numpy(actions).float(), # (Tp, 16) - 'qpos': torch.from_numpy(qpos).float(), # (To, 16) + 'action': torch.from_numpy(actions).float(), + 'qpos': torch.from_numpy(qpos_val).float(), } - # 将图像放入 batch for cam_name, img_tensor in image_dict.items(): - data_batch[f'image_{cam_name}'] = img_tensor # (To, 3, H, W) - - # TODO: 添加 Language Instruction - # 如果所有 episode 共享任务,这里可以是固定 embedding - # 如果每个 episode 任务不同,你需要一个额外的 meta json 来映射 file_path -> text - # data_batch['lang_text'] = "pick up the red cube" + data_batch[f'image_{cam_name}'] = img_tensor return data_batch diff --git a/roboimi/vla/models/backbones/__init__.py b/roboimi/vla/models/backbones/__init__.py index ea22800..2f36dcd 100644 --- a/roboimi/vla/models/backbones/__init__.py +++ b/roboimi/vla/models/backbones/__init__.py @@ -1,9 +1,10 @@ # Backbone models from .siglip import SigLIPBackbone +from .resnet import ResNetBackbone # from .clip import CLIPBackbone # from .dinov2 import DinoV2Backbone -__all__ = ["SigLIPBackbone"] +__all__ = ["SigLIPBackbone", "ResNetBackbone"] # from .debug import DebugBackbone # __all__ = ["DebugBackbone"] \ No newline at end of file diff --git a/roboimi/vla/models/backbones/clip.py b/roboimi/vla/models/backbones/clip.py deleted file mode 100644 index c30ac7f..0000000 --- a/roboimi/vla/models/backbones/clip.py +++ /dev/null @@ -1 +0,0 @@ -# CLIP Backbone 实现 diff --git a/roboimi/vla/models/backbones/debug.py b/roboimi/vla/models/backbones/debug.py deleted file mode 100644 index 4c85b98..0000000 --- a/roboimi/vla/models/backbones/debug.py +++ /dev/null @@ -1,30 +0,0 @@ -import torch -import torch.nn as nn -from typing import Dict -from roboimi.vla.core.interfaces import VLABackbone - -class DebugBackbone(VLABackbone): - """ - A fake backbone that outputs random tensors. 
- """ - def __init__(self, embed_dim: int = 768, seq_len: int = 10): - super().__init__() - self._embed_dim = embed_dim - self.seq_len = seq_len - # A dummy trainable parameter - self.dummy_param = nn.Parameter(torch.zeros(1)) - - def forward(self, obs: Dict[str, torch.Tensor]) -> torch.Tensor: - batch_size = obs['image'].shape[0] - - # 1. Generate random noise - noise = torch.randn(batch_size, self.seq_len, self._embed_dim, device=obs['image'].device) - - # 2. CRITICAL FIX: Add the dummy parameter to the noise. - # This connects 'noise' to 'self.dummy_param' in the computation graph. - # The value doesn't change (since param is 0), but the gradient path is established. - return noise + self.dummy_param - - @property - def embed_dim(self) -> int: - return self._embed_dim \ No newline at end of file diff --git a/roboimi/vla/models/backbones/dinov2.py b/roboimi/vla/models/backbones/dinov2.py deleted file mode 100644 index acba66c..0000000 --- a/roboimi/vla/models/backbones/dinov2.py +++ /dev/null @@ -1 +0,0 @@ -# DinoV2 Backbone 实现 diff --git a/roboimi/vla/models/backbones/resnet.py b/roboimi/vla/models/backbones/resnet.py new file mode 100644 index 0000000..dca2fa1 --- /dev/null +++ b/roboimi/vla/models/backbones/resnet.py @@ -0,0 +1,83 @@ +from roboimi.vla.core.interfaces import VLABackbone +from transformers import ResNetModel +from torchvision import transforms +import torch +import torch.nn as nn + +class ResNetBackbone(VLABackbone): + def __init__( + self, + model_name = "microsoft/resnet-18", + freeze: bool = True, + ): + super().__init__() + self.model = ResNetModel.from_pretrained(model_name) + self.out_channels = self.model.config.hidden_sizes[-1] + self.transform = transforms.Compose([ + transforms.Resize((384, 384)), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ]) + self.spatial_softmax = SpatialSoftmax(num_rows=12, num_cols=12) + if freeze: + self._freeze_parameters() + + def _freeze_parameters(self): + print("❄️ 
Freezing ResNet Backbone parameters") + for param in self.model.parameters(): + param.requires_grad = False + self.model.eval() + + def forward_single_image(self, image): + B, T, C, H, W = image.shape + image = image.view(B * T, C, H, W) + image = self.transform(image) + feature_map = self.model(image).last_hidden_state # (B*T, D, H', W') + features = self.spatial_softmax(feature_map) # (B*T, D*2) + return features + + def forward(self, images): + any_tensor = next(iter(images.values())) + B, T = any_tensor.shape[:2] + features_all = [] + sorted_cam_names = sorted(images.keys()) + for cam_name in sorted_cam_names: + img = images[cam_name] + features = self.forward_single_image(img) # (B*T, D*2) + features_all.append(features) + combined_features = torch.cat(features_all, dim=1) # (B*T, Num_Cams*D*2) + return combined_features.view(B, T, -1) + + @property + def output_dim(self): + """Output dimension after spatial softmax: out_channels * 2""" + return self.out_channels * 2 + +class SpatialSoftmax(nn.Module): + """ + 将特征图 (N, C, H, W) 转换为坐标特征 (N, C*2) + """ + def __init__(self, num_rows, num_cols, temperature=None): + super().__init__() + self.temperature = nn.Parameter(torch.ones(1)) + # 创建网格坐标 + pos_x, pos_y = torch.meshgrid( + torch.linspace(-1, 1, num_rows), + torch.linspace(-1, 1, num_cols), + indexing='ij' + ) + self.register_buffer('pos_x', pos_x.reshape(-1)) + self.register_buffer('pos_y', pos_y.reshape(-1)) + + def forward(self, x): + N, C, H, W = x.shape + x = x.view(N, C, -1) # (N, C, H*W) + + # 计算 Softmax 注意力图 + softmax_attention = torch.nn.functional.softmax(x / self.temperature, dim=2) + + # 计算期望坐标 (x, y) + expected_x = torch.sum(softmax_attention * self.pos_x, dim=2, keepdim=True) + expected_y = torch.sum(softmax_attention * self.pos_y, dim=2, keepdim=True) + + # 拼接并展平 -> (N, C*2) + return torch.cat([expected_x, expected_y], dim=2).reshape(N, -1) \ No newline at end of file diff --git a/roboimi/vla/models/heads/__init__.py 
b/roboimi/vla/models/heads/__init__.py index 42f28b2..4260dba 100644 --- a/roboimi/vla/models/heads/__init__.py +++ b/roboimi/vla/models/heads/__init__.py @@ -1,9 +1,8 @@ # # Action Head models -from .diffusion import DiffusionHead +from .diffusion import ConditionalUnet1D # from .act import ACTHead -__all__ = ["DiffusionHead"] +__all__ = ["ConditionalUnet1D"] # from .debug import DebugHead - # __all__ = ["DebugHead"] \ No newline at end of file diff --git a/roboimi/vla/models/heads/act.py b/roboimi/vla/models/heads/act.py deleted file mode 100644 index 1860fe4..0000000 --- a/roboimi/vla/models/heads/act.py +++ /dev/null @@ -1 +0,0 @@ -# ACT-VAE Action Head 实现 diff --git a/roboimi/vla/models/heads/debug.py b/roboimi/vla/models/heads/debug.py deleted file mode 100644 index 49f0924..0000000 --- a/roboimi/vla/models/heads/debug.py +++ /dev/null @@ -1,33 +0,0 @@ -import torch -import torch.nn as nn -from typing import Dict, Optional -from roboimi.vla.core.interfaces import VLAHead - -class DebugHead(VLAHead): - """ - A fake Action Head using MSE Loss. - Replaces complex Diffusion/ACT policies for architecture verification. 
- """ - def __init__(self, input_dim: int, action_dim: int, chunk_size: int = 16): - super().__init__() - # Simple regression from embedding -> action chunk - self.regressor = nn.Linear(input_dim, chunk_size * action_dim) - self.action_dim = action_dim - self.chunk_size = chunk_size - self.loss_fn = nn.MSELoss() - - def forward(self, embeddings: torch.Tensor, actions: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]: - # Simple pooling over sequence dimension to get (B, Hidden) - pooled_embed = embeddings.mean(dim=1) - - # Predict actions: (B, Chunk * Act_Dim) -> (B, Chunk, Act_Dim) - pred_flat = self.regressor(pooled_embed) - pred_actions = pred_flat.view(-1, self.chunk_size, self.action_dim) - - output = {"pred_actions": pred_actions} - - if actions is not None: - # Calculate MSE Loss against ground truth - output["loss"] = self.loss_fn(pred_actions, actions) - - return output \ No newline at end of file diff --git a/roboimi/vla/models/heads/diffusion.py b/roboimi/vla/models/heads/diffusion.py index adb1e60..6233658 100644 --- a/roboimi/vla/models/heads/diffusion.py +++ b/roboimi/vla/models/heads/diffusion.py @@ -5,170 +5,290 @@ from typing import Dict, Optional from diffusers import DDPMScheduler from roboimi.vla.core.interfaces import VLAHead -class DiffusionHead(VLAHead): - def __init__( - self, - input_dim: int, # 来自 Projector 的维度 (e.g. 384) - action_dim: int, # 动作维度 (e.g. 16) - chunk_size: int, # 预测视界 (e.g. 16) - n_timesteps: int = 100, # 扩散步数 - hidden_dim: int = 256 - ): +from typing import Union +import logging +import torch +import torch.nn as nn +import einops + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops.layers.torch import Rearrange +import math + + +class SinusoidalPosEmb(nn.Module): + def __init__(self, dim): super().__init__() - self.action_dim = action_dim - self.chunk_size = chunk_size - - # 1. 
噪声调度器 (DDPM) - self.scheduler = DDPMScheduler( - num_train_timesteps=n_timesteps, - beta_schedule='squaredcos_cap_v2', # 现代 Diffusion 常用调度 - clip_sample=True, - prediction_type='epsilon' # 预测噪声 - ) + self.dim = dim - # 2. 噪声预测网络 (Noise Predictor Network) - # 输入: Noisy Action + Time Embedding + Image Embedding - # 这是一个简单的 Conditional MLP/ResNet 结构 - self.time_emb = nn.Sequential( - nn.Linear(1, hidden_dim), + def forward(self, x): + device = x.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device) * -emb) + emb = x[:, None] * emb[None, :] + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb + +class Downsample1d(nn.Module): + def __init__(self, dim): + super().__init__() + self.conv = nn.Conv1d(dim, dim, 3, 2, 1) + + def forward(self, x): + return self.conv(x) + +class Upsample1d(nn.Module): + def __init__(self, dim): + super().__init__() + self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1) + + def forward(self, x): + return self.conv(x) + +class Conv1dBlock(nn.Module): + ''' + Conv1d --> GroupNorm --> Mish + ''' + + def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8): + super().__init__() + + self.block = nn.Sequential( + nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2), + # Rearrange('batch channels horizon -> batch channels 1 horizon'), + nn.GroupNorm(n_groups, out_channels), + # Rearrange('batch channels 1 horizon -> batch channels horizon'), nn.Mish(), - nn.Linear(hidden_dim, hidden_dim) ) - - self.cond_proj = nn.Linear(input_dim, hidden_dim) # 把图像特征投影一下 - - # 主干网络 (由几个 Residual Block 组成) - self.mid_layers = nn.ModuleList([ - nn.Sequential( - nn.Linear(hidden_dim + action_dim * chunk_size, hidden_dim), - nn.LayerNorm(hidden_dim), - nn.Mish(), - nn.Linear(hidden_dim, hidden_dim + action_dim * chunk_size) # 简单的残差 - ) for _ in range(3) + + def forward(self, x): + return self.block(x) + +class ConditionalResidualBlock1D(nn.Module): + 
def __init__(self, + in_channels, + out_channels, + cond_dim, + kernel_size=3, + n_groups=8, + cond_predict_scale=False): + super().__init__() + self.blocks = nn.ModuleList([ + Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups), + Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups), ]) - - # 输出层: 预测噪声 (Shape 与 Action 相同) - self.final_layer = nn.Linear(hidden_dim + action_dim * chunk_size, action_dim * chunk_size) - def forward(self, embeddings: torch.Tensor, actions: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]: - """ - Unified interface for Training and Inference. - """ - device = embeddings.device - - # --- 1. 处理条件 (Conditioning) --- - # embeddings: (B, Seq, Dim). 我们这里做一个简化,做 Average Pooling 变成 (B, Dim) - # 如果你想做更复杂的 Cross-Attention,可以在这里改 - global_cond = embeddings.mean(dim=1) - cond_feat = self.cond_proj(global_cond) # (B, Hidden) - # ========================================= - # 分支 A: 训练模式 (Training) - # ========================================= - if actions is not None: - batch_size = actions.shape[0] - - # 1.1 准备数据 (Flatten: B, Chunk, ActDim -> B, Chunk*ActDim) - actions_flat = actions.view(batch_size, -1) - - # 1.2 采样噪声和时间步 - noise = torch.randn_like(actions_flat) - timesteps = torch.randint( - 0, self.scheduler.config.num_train_timesteps, - (batch_size,), device=device - ).long() - - # 1.3 加噪 (Forward Diffusion) - noisy_actions = self.scheduler.add_noise(actions_flat, noise, timesteps) - - # 1.4 预测噪声 (Network Forward) - pred_noise = self._predict_noise(noisy_actions, timesteps, cond_feat) - - # 1.5 计算 Loss (MSE between actual noise and predicted noise) - loss = nn.functional.mse_loss(pred_noise, noise) - - return {"loss": loss} - # ========================================= - # 分支 B: 推理模式 (Inference) - # ========================================= + cond_channels = out_channels + if cond_predict_scale: + cond_channels = out_channels * 2 + self.cond_predict_scale = cond_predict_scale + self.out_channels = 
out_channels + self.cond_encoder = nn.Sequential( + nn.Mish(), + nn.Linear(cond_dim, cond_channels), + Rearrange('batch t -> batch t 1'), + ) + + # make sure dimensions compatible + self.residual_conv = nn.Conv1d(in_channels, out_channels, 1) \ + if in_channels != out_channels else nn.Identity() + + def forward(self, x, cond): + ''' + x : [ batch_size x in_channels x horizon ] + cond : [ batch_size x cond_dim] + + returns: + out : [ batch_size x out_channels x horizon ] + ''' + out = self.blocks[0](x) + embed = self.cond_encoder(cond) + if self.cond_predict_scale: + embed = embed.reshape( + embed.shape[0], 2, self.out_channels, 1) + scale = embed[:,0,...] + bias = embed[:,1,...] + out = scale * out + bias else: - batch_size = embeddings.shape[0] - - # 2.1 从纯高斯噪声开始 - noisy_actions = torch.randn( - batch_size, self.chunk_size * self.action_dim, - device=device - ) - - # 2.2 逐步去噪 (Reverse Diffusion Loop) - # 使用 scheduler.timesteps 自动处理步长 - self.scheduler.set_timesteps(self.scheduler.config.num_train_timesteps) - - for t in self.scheduler.timesteps: - # 构造 batch 的 t - timesteps = torch.tensor([t], device=device).repeat(batch_size) - - # 预测噪声 - # 注意:diffusers 的 step 需要 model_output - model_output = self._predict_noise(noisy_actions, timesteps, cond_feat) - - # 移除噪声 (Step) - noisy_actions = self.scheduler.step( - model_output, t, noisy_actions - ).prev_sample + out = out + embed + out = self.blocks[1](out) + out = out + self.residual_conv(x) + return out - # 2.3 Reshape 回 (B, Chunk, ActDim) - pred_actions = noisy_actions.view(batch_size, self.chunk_size, self.action_dim) - - return {"pred_actions": pred_actions} - def _predict_noise(self, noisy_actions, timesteps, cond_feat): - """内部辅助函数:运行简单的 MLP 网络""" - # Time Embed - t_emb = self.time_emb(timesteps.float().unsqueeze(-1)) # (B, Hidden) +class ConditionalUnet1D(nn.Module): + def __init__(self, + input_dim, + local_cond_dim=None, + global_cond_dim=None, + diffusion_step_embed_dim=256, + down_dims=[256,512,1024], + 
kernel_size=3, + n_groups=8, + cond_predict_scale=False + ): + super().__init__() + all_dims = [input_dim] + list(down_dims) + start_dim = down_dims[0] + + dsed = diffusion_step_embed_dim + diffusion_step_encoder = nn.Sequential( + SinusoidalPosEmb(dsed), + nn.Linear(dsed, dsed * 4), + nn.Mish(), + nn.Linear(dsed * 4, dsed), + ) + cond_dim = dsed + if global_cond_dim is not None: + cond_dim += global_cond_dim + + in_out = list(zip(all_dims[:-1], all_dims[1:])) + + local_cond_encoder = None + if local_cond_dim is not None: + _, dim_out = in_out[0] + dim_in = local_cond_dim + local_cond_encoder = nn.ModuleList([ + # down encoder + ConditionalResidualBlock1D( + dim_in, dim_out, cond_dim=cond_dim, + kernel_size=kernel_size, n_groups=n_groups, + cond_predict_scale=cond_predict_scale), + # up encoder + ConditionalResidualBlock1D( + dim_in, dim_out, cond_dim=cond_dim, + kernel_size=kernel_size, n_groups=n_groups, + cond_predict_scale=cond_predict_scale) + ]) + + mid_dim = all_dims[-1] + self.mid_modules = nn.ModuleList([ + ConditionalResidualBlock1D( + mid_dim, mid_dim, cond_dim=cond_dim, + kernel_size=kernel_size, n_groups=n_groups, + cond_predict_scale=cond_predict_scale + ), + ConditionalResidualBlock1D( + mid_dim, mid_dim, cond_dim=cond_dim, + kernel_size=kernel_size, n_groups=n_groups, + cond_predict_scale=cond_predict_scale + ), + ]) + + down_modules = nn.ModuleList([]) + for ind, (dim_in, dim_out) in enumerate(in_out): + is_last = ind >= (len(in_out) - 1) + down_modules.append(nn.ModuleList([ + ConditionalResidualBlock1D( + dim_in, dim_out, cond_dim=cond_dim, + kernel_size=kernel_size, n_groups=n_groups, + cond_predict_scale=cond_predict_scale), + ConditionalResidualBlock1D( + dim_out, dim_out, cond_dim=cond_dim, + kernel_size=kernel_size, n_groups=n_groups, + cond_predict_scale=cond_predict_scale), + Downsample1d(dim_out) if not is_last else nn.Identity() + ])) + + up_modules = nn.ModuleList([]) + for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])): + 
is_last = ind >= (len(in_out) - 1) + up_modules.append(nn.ModuleList([ + ConditionalResidualBlock1D( + dim_out*2, dim_in, cond_dim=cond_dim, + kernel_size=kernel_size, n_groups=n_groups, + cond_predict_scale=cond_predict_scale), + ConditionalResidualBlock1D( + dim_in, dim_in, cond_dim=cond_dim, + kernel_size=kernel_size, n_groups=n_groups, + cond_predict_scale=cond_predict_scale), + Upsample1d(dim_in) if not is_last else nn.Identity() + ])) - # Fusion: Concat Action + (Condition * Time) - # 这里用简单的相加融合,实际可以更复杂 - fused_feat = cond_feat + t_emb + final_conv = nn.Sequential( + Conv1dBlock(start_dim, start_dim, kernel_size=kernel_size), + nn.Conv1d(start_dim, input_dim, 1), + ) + + self.diffusion_step_encoder = diffusion_step_encoder + self.local_cond_encoder = local_cond_encoder + self.up_modules = up_modules + self.down_modules = down_modules + self.final_conv = final_conv + + + def forward(self, + sample: torch.Tensor, + timestep: Union[torch.Tensor, float, int], + local_cond=None, global_cond=None, **kwargs): + """ + x: (B,T,input_dim) + timestep: (B,) or int, diffusion step + local_cond: (B,T,local_cond_dim) + global_cond: (B,global_cond_dim) + output: (B,T,input_dim) + """ + sample = einops.rearrange(sample, 'b h t -> b t h') + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. 
So try to pass timesteps as tensors if you can + timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device) + elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + global_feature = self.diffusion_step_encoder(timesteps) + + if global_cond is not None: + global_feature = torch.cat([ + global_feature, global_cond + ], axis=-1) - # Concat input - x = torch.cat([noisy_actions, fused_feat], dim=-1) # 注意这里维度需要对齐,或者用 MLP 映射 + # encode local features + h_local = list() + if local_cond is not None: + local_cond = einops.rearrange(local_cond, 'b h t -> b t h') + resnet, resnet2 = self.local_cond_encoder + x = resnet(local_cond, global_feature) + h_local.append(x) + x = resnet2(local_cond, global_feature) + h_local.append(x) - # 修正:上面的 concat 维度可能不对,为了简化代码,我们用一种更简单的方式: - # 将 cond_feat 加到 input 里需要维度匹配。 - # 这里重写一个极简的 Forward: - - # 正确做法:先将 x 映射到 hidden,再加 t_emb 和 cond_feat - # 但为了复用 self.mid_layers 定义的 Linear(Hidden + Input)... 
- # 我们用最傻瓜的方式:Input = Action,Condition 直接拼接到每一层或者只拼输入 - - # 让我们修正一下网络结构逻辑,确保不报错: - # Input: NoisyAction (Dim_A) - # Cond: Hidden (Dim_H) - - # 这种临时写的 MLP 容易维度不匹配,我们改用一个极其稳健的计算流: - # x = Action - # h = Cond + Time - # input = cat([x, h]) -> Linear -> Output - - # 重新定义 _predict_noise 的逻辑依赖于 __init__ 里的定义。 - # 为了保证一次跑通,我使用动态 cat: - - x = noisy_actions - # 假设 mid_layers 的输入是 hidden_dim + action_flat_dim - # 我们把 condition 映射成 hidden_dim,然后 concat - - # 真正的计算流: - h = cond_feat + t_emb # (B, Hidden) - - # 把 h 拼接到 x 上 (前提是 x 是 action flat) - # Linear 输入维度是 Hidden + ActFlat - model_input = torch.cat([h, x], dim=-1) - - for layer in self.mid_layers: - # Residual connection mechanism - out = layer(model_input) - model_input = out + model_input # Simple ResNet - - return self.final_layer(model_input) \ No newline at end of file + x = sample + h = [] + for idx, (resnet, resnet2, downsample) in enumerate(self.down_modules): + x = resnet(x, global_feature) + if idx == 0 and len(h_local) > 0: + x = x + h_local[0] + x = resnet2(x, global_feature) + h.append(x) + x = downsample(x) + + for mid_module in self.mid_modules: + x = mid_module(x, global_feature) + + for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules): + x = torch.cat((x, h.pop()), dim=1) + x = resnet(x, global_feature) + # The correct condition should be: + # if idx == (len(self.up_modules)-1) and len(h_local) > 0: + # However this change will break compatibility with published checkpoints. + # Therefore it is left as a comment. + if idx == len(self.up_modules) and len(h_local) > 0: + x = x + h_local[1] + x = resnet2(x, global_feature) + x = upsample(x) + + x = self.final_conv(x) + + x = einops.rearrange(x, 'b t h -> b h t') + return x +