# VLA training script (ResNet backbone + diffusion policy).
import sys
|
|
import os
|
|
import logging
|
|
import json
|
|
import pickle
|
|
import hydra
|
|
import torch
|
|
from tqdm import tqdm
|
|
from omegaconf import DictConfig, OmegaConf
|
|
from torch.utils.data import DataLoader
|
|
from torch.optim import AdamW
|
|
from pathlib import Path
|
|
|
|
# Ensure correct import path
|
|
sys.path.append(os.getcwd())
|
|
|
|
from hydra.utils import instantiate
|
|
|
|
log = logging.getLogger(__name__)
|
|
|
|
|
|
def recursive_to_device(data, device):
    """
    Recursively move nested containers of tensors to the specified device.

    Args:
        data: A ``torch.Tensor``, or an arbitrarily nested ``dict`` / ``list`` /
            ``tuple`` containing tensors (non-tensor leaves are allowed).
        device: Target device (e.g. ``'cuda'``, ``'cpu'``, or a ``torch.device``).

    Returns:
        The same structure with every tensor moved to ``device``. Non-tensor
        leaves are returned unchanged; dict/list/tuple container kinds are
        preserved (a namedtuple is returned as a plain tuple).
    """
    if isinstance(data, torch.Tensor):
        return data.to(device)
    if isinstance(data, dict):
        return {k: recursive_to_device(v, device) for k, v in data.items()}
    if isinstance(data, list):
        return [recursive_to_device(v, device) for v in data]
    if isinstance(data, tuple):
        # Generalization: tuples (e.g. produced by collate functions) were
        # previously passed through with their tensors left on the old device.
        return tuple(recursive_to_device(v, device) for v in data)
    # Scalars, strings, None, etc. pass through untouched.
    return data
|
|
|
|
|
|
@hydra.main(version_base=None, config_path="../../vla/conf", config_name="config")
def main(cfg: DictConfig):
    """
    VLA Training Script with ResNet Backbone and Diffusion Policy.

    This script:
    1. Loads dataset from HDF5 files
    2. Instantiates VLAAgent with ResNet vision encoder
    3. Trains diffusion-based action prediction
    4. Saves checkpoints periodically

    Expected config keys (read below):
        cfg.train: device, batch_size, num_workers, lr, max_steps,
            log_freq, save_freq, and optionally weight_decay.
        cfg.data: hydra-instantiable dataset config; also camera_names,
            and optionally dataset_dir / normalization_type.
        cfg.agent: hydra-instantiable agent exposing ``compute_loss(batch)``.

    Side effects: creates ``checkpoints/`` (relative to hydra's run dir),
    writes periodic / best / final ``.pt`` checkpoints, logs progress.
    """
    # Print configuration so the exact run settings are captured in stdout logs.
    print("=" * 80)
    print("VLA Training Configuration:")
    print("=" * 80)
    print(OmegaConf.to_yaml(cfg))
    print("=" * 80)

    log.info(f"🚀 Starting VLA Training (Device: {cfg.train.device})")

    # Create checkpoint directory (relative to hydra's working directory).
    checkpoint_dir = Path("checkpoints")
    checkpoint_dir.mkdir(exist_ok=True)

    # =========================================================================
    # 1. Instantiate Dataset & DataLoader
    # =========================================================================
    log.info("📦 Loading dataset...")
    try:
        dataset = instantiate(cfg.data)
        log.info(f"✅ Dataset loaded successfully. Total samples: {len(dataset)}")
    except Exception as e:
        log.error(f"❌ Failed to load dataset: {e}")
        raise

    dataloader = DataLoader(
        dataset,
        batch_size=cfg.train.batch_size,
        shuffle=True,
        num_workers=cfg.train.num_workers,
        # pin_memory speeds up host-to-GPU copies; pointless on CPU.
        pin_memory=(cfg.train.device != "cpu"),
        drop_last=True  # Drop incomplete batches for stable training
    )
    log.info(f"✅ DataLoader created. Batches per epoch: {len(dataloader)}")

    # =========================================================================
    # 2. Instantiate VLA Agent
    # =========================================================================
    log.info("🤖 Initializing VLA Agent...")
    try:
        agent = instantiate(cfg.agent)
        agent.to(cfg.train.device)
        agent.train()
        log.info(f"✅ Agent initialized and moved to {cfg.train.device}")

        # Count parameters (informational only).
        total_params = sum(p.numel() for p in agent.parameters())
        trainable_params = sum(p.numel() for p in agent.parameters() if p.requires_grad)
        log.info(f"📊 Total parameters: {total_params:,}")
        log.info(f"📊 Trainable parameters: {trainable_params:,}")

    except Exception as e:
        log.error(f"❌ Failed to initialize agent: {e}")
        raise

    # =========================================================================
    # 2.5. Load Dataset Statistics (will be saved into checkpoints)
    # =========================================================================
    # Stats are only needed at inference time to denormalize actions, so a
    # missing/broken stats file is a warning rather than a fatal error.
    log.info("💾 Loading dataset statistics...")
    dataset_stats = None
    try:
        dataset_dir = cfg.data.get('dataset_dir', 'roboimi/demos/dataset/sim_transfer')
        stats_path = Path(dataset_dir) / 'data_stats.pkl'

        if stats_path.exists():
            with open(stats_path, 'rb') as f:
                # SECURITY NOTE: pickle.load executes arbitrary code from the
                # file — only load stats files produced by trusted tooling.
                stats = pickle.load(f)

            # Flatten to plain lists so the checkpoint stays loadable without
            # the original stats object's types.
            # Assumes stats['action'] / stats['qpos'] each carry
            # mean/std/min/max arrays — TODO confirm against the stats writer.
            dataset_stats = {
                'normalization_type': cfg.data.get('normalization_type', 'gaussian'),
                'action_mean': stats['action']['mean'].tolist(),
                'action_std': stats['action']['std'].tolist(),
                'action_min': stats['action']['min'].tolist(),
                'action_max': stats['action']['max'].tolist(),
                'qpos_mean': stats['qpos']['mean'].tolist(),
                'qpos_std': stats['qpos']['std'].tolist(),
                'qpos_min': stats['qpos']['min'].tolist(),
                'qpos_max': stats['qpos']['max'].tolist(),
            }
            log.info(f"✅ Dataset statistics loaded (normalization: {dataset_stats['normalization_type']})")
        else:
            log.warning(f"⚠️ Statistics file not found: {stats_path}")
            log.warning("⚠️ Actions will not be denormalized during inference!")

    except Exception as e:
        log.warning(f"⚠️ Failed to load statistics: {e}")
        log.warning("⚠️ Training will continue, but inference may not work correctly")

    # =========================================================================
    # 3. Setup Optimizer
    # =========================================================================
    # weight_decay is now configurable; default preserves the previous
    # hard-coded value of 1e-5.
    optimizer = AdamW(
        agent.parameters(),
        lr=cfg.train.lr,
        weight_decay=cfg.train.get("weight_decay", 1e-5),
    )
    log.info(f"🔧 Optimizer: AdamW (lr={cfg.train.lr})")

    # =========================================================================
    # 4. Training Loop
    # =========================================================================
    # Step-based (not epoch-based): the dataloader iterator is restarted
    # whenever it is exhausted, until max_steps optimizer updates are done.
    log.info("🏋️ Starting training loop...")

    data_iter = iter(dataloader)
    pbar = tqdm(range(cfg.train.max_steps), desc="Training", ncols=100)

    best_loss = float('inf')

    for step in pbar:
        try:
            batch = next(data_iter)
        except StopIteration:
            # Restart iterator when epoch ends
            data_iter = iter(dataloader)
            batch = next(data_iter)

        # =====================================================================
        # Move batch to device
        # =====================================================================
        batch = recursive_to_device(batch, cfg.train.device)

        # =====================================================================
        # Prepare agent input
        # =====================================================================
        # Dataset returns: {action, qpos, image_<cam_name>, ...}
        # Agent expects: {images: dict, qpos: tensor, action: tensor}

        # Extract images into a dictionary; cameras missing from the batch
        # are silently skipped.
        images = {}
        for cam_name in cfg.data.camera_names:
            key = f"image_{cam_name}"
            if key in batch:
                images[cam_name] = batch[key]  # (B, obs_horizon, C, H, W)

        # Prepare agent input
        agent_input = {
            'images': images,          # Dict of camera images
            'qpos': batch['qpos'],     # (B, obs_horizon, obs_dim)
            'action': batch['action']  # (B, pred_horizon, action_dim)
        }

        # =====================================================================
        # Forward pass & compute loss
        # =====================================================================
        try:
            loss = agent.compute_loss(agent_input)
        except Exception as e:
            log.error(f"❌ Forward pass failed at step {step}: {e}")
            raise

        # =====================================================================
        # Backward pass & optimization
        # =====================================================================
        optimizer.zero_grad()
        loss.backward()

        # Gradient clipping for stable training
        torch.nn.utils.clip_grad_norm_(agent.parameters(), max_norm=1.0)

        optimizer.step()

        # =====================================================================
        # Logging
        # =====================================================================
        if step % cfg.train.log_freq == 0:
            pbar.set_postfix({
                "loss": f"{loss.item():.4f}",
                "best_loss": f"{best_loss:.4f}"
            })
            log.info(f"Step {step}/{cfg.train.max_steps} | Loss: {loss.item():.4f}")

        # =====================================================================
        # Checkpoint saving
        # =====================================================================
        if step > 0 and step % cfg.train.save_freq == 0:
            checkpoint_path = checkpoint_dir / f"vla_model_step_{step}.pt"
            torch.save({
                'step': step,
                'model_state_dict': agent.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss.item(),
                'dataset_stats': dataset_stats,
            }, checkpoint_path)
            log.info(f"💾 Checkpoint saved: {checkpoint_path}")

        # Save best model (per-batch loss — noisy, but cheap to track).
        if loss.item() < best_loss:
            best_loss = loss.item()
            best_model_path = checkpoint_dir / "vla_model_best.pt"
            torch.save({
                'step': step,
                'model_state_dict': agent.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss.item(),
                'dataset_stats': dataset_stats,
            }, best_model_path)
            log.info(f"🌟 Best model updated: {best_model_path} (loss: {best_loss:.4f})")

    # Close the progress bar explicitly so its handle/terminal state is released.
    pbar.close()

    # =========================================================================
    # 5. Save Final Model
    # =========================================================================
    # NOTE(review): `loss` is defined only if the loop ran — with
    # cfg.train.max_steps == 0 this save would raise NameError.
    final_model_path = checkpoint_dir / "vla_model_final.pt"
    torch.save({
        'step': cfg.train.max_steps,
        'model_state_dict': agent.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss.item(),
        'dataset_stats': dataset_stats,
    }, final_model_path)
    log.info(f"💾 Final model saved: {final_model_path}")

    log.info("✅ Training completed successfully!")
    log.info(f"📊 Final Loss: {loss.item():.4f}")
    log.info(f"📊 Best Loss: {best_loss:.4f}")
|
|
|
|
|
|
# Hydra entry point: composes the config (plus any CLI overrides) and runs main().
if __name__ == "__main__":
    main()
|