diff --git a/roboimi/demos/vla_scripts/eval_vla.py b/roboimi/demos/vla_scripts/eval_vla.py
index 8264b28..a87e991 100644
--- a/roboimi/demos/vla_scripts/eval_vla.py
+++ b/roboimi/demos/vla_scripts/eval_vla.py
@@ -32,6 +32,10 @@
 sys.path.append(os.getcwd())
 
 log = logging.getLogger(__name__)
 
+# Register resolver for list length in configs (e.g., ${len:${data.camera_names}})
+if not OmegaConf.has_resolver("len"):
+    OmegaConf.register_new_resolver("len", lambda x: len(x))
+
 
 class VLAEvaluator:
     """
diff --git a/roboimi/demos/vla_scripts/train_vla.py b/roboimi/demos/vla_scripts/train_vla.py
index 348d8fd..f7c8e57 100644
--- a/roboimi/demos/vla_scripts/train_vla.py
+++ b/roboimi/demos/vla_scripts/train_vla.py
@@ -7,7 +7,7 @@
 import hydra
 import torch
 from tqdm import tqdm
 from omegaconf import DictConfig, OmegaConf
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, random_split
 from torch.optim import AdamW
 from pathlib import Path
@@ -18,6 +18,10 @@
 from hydra.utils import instantiate
 
 log = logging.getLogger(__name__)
 
+# Register resolver for list length in configs (e.g., ${len:${data.camera_names}})
+if not OmegaConf.has_resolver("len"):
+    OmegaConf.register_new_resolver("len", lambda x: len(x))
+
 
 def recursive_to_device(data, device):
     """
@@ -75,15 +79,45 @@
         log.error(f"❌ Failed to load dataset: {e}")
         raise
 
-    dataloader = DataLoader(
-        dataset,
+    # Train/Val split
+    val_split = float(cfg.train.get('val_split', 0.1))
+    seed = int(cfg.train.get('seed', 42))
+    val_size = int(len(dataset) * val_split)
+    train_size = len(dataset) - val_size
+    if val_size > 0:
+        train_dataset, val_dataset = random_split(
+            dataset,
+            [train_size, val_size],
+            generator=torch.Generator().manual_seed(seed)
+        )
+        log.info(f"✅ Dataset split: train={train_size}, val={val_size} (val_split={val_split})")
+    else:
+        train_dataset, val_dataset = dataset, None
+        log.info("✅ Dataset split: train=all, val=0 (val_split=0)")
+
+    train_loader = DataLoader(
+        train_dataset,
         batch_size=cfg.train.batch_size,
         shuffle=True,
         num_workers=cfg.train.num_workers,
         pin_memory=(cfg.train.device != "cpu"),
         drop_last=True  # Drop incomplete batches for stable training
     )
-    log.info(f"✅ DataLoader created. Batches per epoch: {len(dataloader)}")
+
+    val_loader = None
+    if val_dataset is not None:
+        val_loader = DataLoader(
+            val_dataset,
+            batch_size=cfg.train.batch_size,
+            shuffle=False,
+            num_workers=cfg.train.num_workers,
+            pin_memory=(cfg.train.device != "cpu"),
+            drop_last=False
+        )
+
+    log.info(f"✅ Train loader batches per epoch: {len(train_loader)}")
+    if val_loader is not None:
+        log.info(f"✅ Val loader batches per epoch: {len(val_loader)}")
 
     # =========================================================================
     # 2. Instantiate VLA Agent
@@ -149,7 +183,36 @@
     # =========================================================================
     log.info("🏋️ Starting training loop...")
 
-    data_iter = iter(dataloader)
+    def build_agent_input(batch_data):
+        images = {}
+        for cam_name in cfg.data.camera_names:
+            key = f"image_{cam_name}"
+            if key in batch_data:
+                images[cam_name] = batch_data[key]
+
+        return {
+            'images': images,
+            'qpos': batch_data['qpos'],
+            'action': batch_data['action']
+        }
+
+    def run_validation():
+        if val_loader is None:
+            return None
+        agent.eval()
+        total_loss = 0.0
+        num_batches = 0
+        with torch.no_grad():
+            for val_batch in val_loader:
+                val_batch = recursive_to_device(val_batch, cfg.train.device)
+                val_input = build_agent_input(val_batch)
+                val_loss = agent.compute_loss(val_input)
+                total_loss += val_loss.item()
+                num_batches += 1
+        agent.train()
+        return total_loss / max(num_batches, 1)
+
+    data_iter = iter(train_loader)
     pbar = tqdm(range(cfg.train.max_steps), desc="Training", ncols=100)
 
     best_loss = float('inf')
@@ -159,7 +222,7 @@
             batch = next(data_iter)
         except StopIteration:
             # Restart iterator when epoch ends
-            data_iter = iter(dataloader)
+            data_iter = iter(train_loader)
             batch = next(data_iter)
 
         # =====================================================================
@@ -173,19 +236,8 @@
         # Dataset returns: {action, qpos, image_, ...}
         # Agent expects: {images: dict, qpos: tensor, action: tensor}
 
-        # Extract images into a dictionary
-        images = {}
-        for cam_name in cfg.data.camera_names:
-            key = f"image_{cam_name}"
-            if key in batch:
-                images[cam_name] = batch[key]  # (B, obs_horizon, C, H, W)
-
         # Prepare agent input
-        agent_input = {
-            'images': images,  # Dict of camera images
-            'qpos': batch['qpos'],  # (B, obs_horizon, obs_dim)
-            'action': batch['action']  # (B, pred_horizon, action_dim)
-        }
+        agent_input = build_agent_input(batch)
 
         # =====================================================================
         # Forward pass & compute loss
@@ -217,6 +269,15 @@
             })
             log.info(f"Step {step}/{cfg.train.max_steps} | Loss: {loss.item():.4f}")
 
+        # =====================================================================
+        # Validation
+        # =====================================================================
+        val_freq = int(cfg.train.get('val_freq', cfg.train.log_freq))
+        if val_loader is not None and val_freq > 0 and step % val_freq == 0:
+            val_loss = run_validation()
+            if val_loss is not None:
+                log.info(f"Step {step}/{cfg.train.max_steps} | Val Loss: {val_loss:.4f}")
+
         # =====================================================================
         # Checkpoint saving
         # =====================================================================
diff --git a/roboimi/vla/conf/agent/resnet_diffusion.yaml b/roboimi/vla/conf/agent/resnet_diffusion.yaml
index b1b3d8f..0ab1a0c 100644
--- a/roboimi/vla/conf/agent/resnet_diffusion.yaml
+++ b/roboimi/vla/conf/agent/resnet_diffusion.yaml
@@ -19,4 +19,4 @@
 obs_horizon: 2
 diffusion_steps: 100  # Number of diffusion timesteps for training
 
 # Camera Configuration
-num_cams: ${oc.len:data.camera_names} # 自动从 data.camera_names 列表长度获取
\ No newline at end of file
+num_cams: ${len:${data.camera_names}} # 自动从 data.camera_names 列表长度获取
\ No newline at end of file