# roboimi/roboimi/demos/vla_scripts/train_vla.py
import sys
import os
import logging
import pickle

import hydra
import torch
from tqdm import tqdm
from omegaconf import DictConfig, OmegaConf
from torch.utils.data import DataLoader
from torch.optim import AdamW
from pathlib import Path

# Make the repository root importable so that hydra.utils.instantiate can
# resolve project-local _target_ paths when the script is run from the repo root.
sys.path.append(os.getcwd())
from hydra.utils import instantiate

log = logging.getLogger(__name__)

def recursive_to_device(data, device):
    """
    Recursively move nested dictionaries/lists of tensors to the specified device.

    Args:
        data: Dictionary, list, or tensor.
        device: Target device (e.g., 'cuda', 'cpu').

    Returns:
        The same structure with all tensors moved to the device.
    """
    if isinstance(data, torch.Tensor):
        return data.to(device)
    elif isinstance(data, dict):
        return {k: recursive_to_device(v, device) for k, v in data.items()}
    elif isinstance(data, list):
        return [recursive_to_device(v, device) for v in data]
    return data
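

# Example (hypothetical shapes): a nested VLA batch moves in one call, e.g.
#   batch = {"qpos": torch.zeros(8, 2, 14),
#            "images": {"top": torch.zeros(8, 2, 3, 224, 224)}}
#   batch = recursive_to_device(batch, "cuda")
# Non-tensor leaves (strings, ints) pass through unchanged.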

@hydra.main(version_base=None, config_path="../../vla/conf", config_name="config")
def main(cfg: DictConfig):
    """
    VLA Training Script with ResNet Backbone and Diffusion Policy.

    This script:
        1. Loads dataset from HDF5 files
        2. Instantiates VLAAgent with ResNet vision encoder
        3. Trains diffusion-based action prediction
        4. Saves checkpoints periodically
    """
    # Print configuration
    print("=" * 80)
    print("VLA Training Configuration:")
    print("=" * 80)
    print(OmegaConf.to_yaml(cfg))
    print("=" * 80)

    log.info(f"🚀 Starting VLA Training (Device: {cfg.train.device})")

    # Create checkpoint directory
    checkpoint_dir = Path("checkpoints")
    checkpoint_dir.mkdir(exist_ok=True)
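    # NOTE: with version_base=None, Hydra runs in 1.1-compatibility mode and
    # changes the working directory to the run's output folder, so this
    # relative path resolves inside Hydra's run directory rather than the
    # repo root (assuming the default hydra.job.chdir behavior). The same
    # applies to relative dataset paths used below; hydra.utils.to_absolute_path
    # can resolve them against the original cwd if needed.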

    # =========================================================================
    # 1. Instantiate Dataset & DataLoader
    # =========================================================================
    log.info("📦 Loading dataset...")
    try:
        dataset = instantiate(cfg.data)
        log.info(f"✅ Dataset loaded successfully. Total samples: {len(dataset)}")
    except Exception as e:
        log.error(f"❌ Failed to load dataset: {e}")
        raise

    dataloader = DataLoader(
        dataset,
        batch_size=cfg.train.batch_size,
        shuffle=True,
        num_workers=cfg.train.num_workers,
        pin_memory=(cfg.train.device != "cpu"),
        drop_last=True,  # Drop incomplete batches for stable training
    )
    log.info(f"✅ DataLoader created. Batches per epoch: {len(dataloader)}")

    # =========================================================================
    # 2. Instantiate VLA Agent
    # =========================================================================
    log.info("🤖 Initializing VLA Agent...")
    try:
        agent = instantiate(cfg.agent)
        agent.to(cfg.train.device)
        agent.train()
        log.info(f"✅ Agent initialized and moved to {cfg.train.device}")

        # Count parameters
        total_params = sum(p.numel() for p in agent.parameters())
        trainable_params = sum(p.numel() for p in agent.parameters() if p.requires_grad)
        log.info(f"📊 Total parameters: {total_params:,}")
        log.info(f"📊 Trainable parameters: {trainable_params:,}")
    except Exception as e:
        log.error(f"❌ Failed to initialize agent: {e}")
        raise
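
    # The instantiated agent is assumed to be an nn.Module exposing
    # compute_loss(batch) -> scalar tensor; that is the only interface the
    # training loop below relies on.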

    # =========================================================================
    # 2.5. Load Dataset Statistics (will be saved into checkpoints)
    # =========================================================================
    log.info("💾 Loading dataset statistics...")
    dataset_stats = None
    try:
        dataset_dir = cfg.data.get('dataset_dir', 'roboimi/demos/dataset/sim_transfer')
        stats_path = Path(dataset_dir) / 'data_stats.pkl'
        if stats_path.exists():
            with open(stats_path, 'rb') as f:
                stats = pickle.load(f)
            dataset_stats = {
                'normalization_type': cfg.data.get('normalization_type', 'gaussian'),
                'action_mean': stats['action']['mean'].tolist(),
                'action_std': stats['action']['std'].tolist(),
                'action_min': stats['action']['min'].tolist(),
                'action_max': stats['action']['max'].tolist(),
                'qpos_mean': stats['qpos']['mean'].tolist(),
                'qpos_std': stats['qpos']['std'].tolist(),
                'qpos_min': stats['qpos']['min'].tolist(),
                'qpos_max': stats['qpos']['max'].tolist(),
            }
            log.info(f"✅ Dataset statistics loaded (normalization: {dataset_stats['normalization_type']})")
        else:
            log.warning(f"⚠️ Statistics file not found: {stats_path}")
            log.warning("⚠️ Actions will not be denormalized during inference!")
    except Exception as e:
        log.warning(f"⚠️ Failed to load statistics: {e}")
        log.warning("⚠️ Training will continue, but inference may not work correctly")

    # =========================================================================
    # 3. Setup Optimizer
    # =========================================================================
    optimizer = AdamW(agent.parameters(), lr=cfg.train.lr, weight_decay=1e-5)
    log.info(f"🔧 Optimizer: AdamW (lr={cfg.train.lr})")

    # =========================================================================
    # 4. Training Loop
    # =========================================================================
    log.info("🏋️ Starting training loop...")

    data_iter = iter(dataloader)
    pbar = tqdm(range(cfg.train.max_steps), desc="Training", ncols=100)
    best_loss = float('inf')

    for step in pbar:
        try:
            batch = next(data_iter)
        except StopIteration:
            # Restart iterator when epoch ends
            data_iter = iter(dataloader)
            batch = next(data_iter)
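        # The loop is step-based rather than epoch-based: recreating the
        # iterator above reshuffles the data (shuffle=True), so epochs blend
        # seamlessly into a fixed budget of cfg.train.max_steps updates.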

        # =====================================================================
        # Move batch to device
        # =====================================================================
        batch = recursive_to_device(batch, cfg.train.device)

        # =====================================================================
        # Prepare agent input
        # =====================================================================
        # Dataset returns: {action, qpos, image_<cam_name>, ...}
        # Agent expects:   {images: dict, qpos: tensor, action: tensor}

        # Extract images into a dictionary (cameras missing from the batch
        # are skipped silently)
        images = {}
        for cam_name in cfg.data.camera_names:
            key = f"image_{cam_name}"
            if key in batch:
                images[cam_name] = batch[key]  # (B, obs_horizon, C, H, W)

        agent_input = {
            'images': images,           # Dict of camera images
            'qpos': batch['qpos'],      # (B, obs_horizon, obs_dim)
            'action': batch['action'],  # (B, pred_horizon, action_dim)
        }

        # =====================================================================
        # Forward pass & compute loss
        # =====================================================================
        try:
            loss = agent.compute_loss(agent_input)
        except Exception as e:
            log.error(f"❌ Forward pass failed at step {step}: {e}")
            raise

        # =====================================================================
        # Backward pass & optimization
        # =====================================================================
        optimizer.zero_grad()
        loss.backward()
        # Gradient clipping for stable training
        torch.nn.utils.clip_grad_norm_(agent.parameters(), max_norm=1.0)
        optimizer.step()
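        # Ordering matters here: backward() populates .grad, clip_grad_norm_
        # rescales those gradients in place, and only then does step() apply
        # the update.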

        # =====================================================================
        # Logging
        # =====================================================================
        if step % cfg.train.log_freq == 0:
            pbar.set_postfix({
                "loss": f"{loss.item():.4f}",
                "best_loss": f"{best_loss:.4f}",
            })
            log.info(f"Step {step}/{cfg.train.max_steps} | Loss: {loss.item():.4f}")

        # =====================================================================
        # Checkpoint saving
        # =====================================================================
        if step > 0 and step % cfg.train.save_freq == 0:
            checkpoint_path = checkpoint_dir / f"vla_model_step_{step}.pt"
            torch.save({
                'step': step,
                'model_state_dict': agent.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss.item(),
                'dataset_stats': dataset_stats,
            }, checkpoint_path)
            log.info(f"💾 Checkpoint saved: {checkpoint_path}")

        # Save best model (checked every step, not just at save_freq)
        if loss.item() < best_loss:
            best_loss = loss.item()
            best_model_path = checkpoint_dir / "vla_model_best.pt"
            torch.save({
                'step': step,
                'model_state_dict': agent.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss.item(),
                'dataset_stats': dataset_stats,
            }, best_model_path)
            log.info(f"🌟 Best model updated: {best_model_path} (loss: {best_loss:.4f})")

    # =========================================================================
    # 5. Save Final Model
    # =========================================================================
    final_model_path = checkpoint_dir / "vla_model_final.pt"
    torch.save({
        'step': cfg.train.max_steps,
        'model_state_dict': agent.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss.item(),
        'dataset_stats': dataset_stats,
    }, final_model_path)
    log.info(f"💾 Final model saved: {final_model_path}")

    log.info("✅ Training completed successfully!")
    log.info(f"📊 Final Loss: {loss.item():.4f}")
    log.info(f"📊 Best Loss: {best_loss:.4f}")


if __name__ == "__main__":
    main()
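
# Example invocation (from the repo root, since sys.path is extended with the
# current working directory; any config key can be overridden via Hydra's
# key=value syntax):
#   python roboimi/roboimi/demos/vla_scripts/train_vla.py train.device=cuda train.batch_size=64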