def get_lr_schedule_with_warmup(optimizer, warmup_steps, max_steps, scheduler_type='cosine', min_lr=0):
    """
    Create a learning rate scheduler with linear warmup.

    Args:
        optimizer: PyTorch optimizer whose first param group carries the base lr.
        warmup_steps: Number of linear warmup steps (lr ramps 0 -> base_lr).
        max_steps: Total training steps; cosine decay reaches its floor here.
        scheduler_type: Schedule after warmup ('cosine' or 'constant').
        min_lr: Minimum learning rate floor (used by cosine decay).

    Returns:
        LambdaLR scheduler whose lambda yields a multiplier on the base lr.
    """
    import math

    # Capture the initial lr before LambdaLR rescales it; the lambda below
    # returns a multiplier relative to this base lr, so min_lr is expressed
    # as a ratio of it.
    base_lr = optimizer.param_groups[0]['lr']
    min_lr_ratio = min_lr / base_lr if base_lr > 0 else 0.0

    def lr_lambda(step):
        # Warmup phase: linear ramp from 0 to 1.
        if step < warmup_steps:
            return float(step) / float(max(1, warmup_steps))

        if scheduler_type == 'cosine':
            # Cosine annealing from 1 down to min_lr_ratio.
            progress = float(step - warmup_steps) / float(max(1, max_steps - warmup_steps))
            # Clamp progress to [0, 1]: without this, stepping past
            # max_steps makes cos(pi * progress) periodic and the lr
            # oscillates back UP instead of staying at the floor.
            progress = min(1.0, progress)
            cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress))
            return max(min_lr_ratio, cosine_decay)

        # Any other scheduler_type: constant lr after warmup.
        return 1.0

    return LambdaLR(optimizer, lr_lambda)
- # 3. Setup Optimizer + # 3. Setup Optimizer & LR Scheduler # ========================================================================= optimizer = AdamW(agent.parameters(), lr=cfg.train.lr, weight_decay=1e-5) log.info(f"🔧 Optimizer: AdamW (lr={cfg.train.lr})") + # Setup learning rate scheduler with warmup + warmup_steps = int(cfg.train.get('warmup_steps', 500)) + scheduler_type = cfg.train.get('scheduler_type', 'cosine') + min_lr = float(cfg.train.get('min_lr', 1e-6)) + + scheduler = get_lr_schedule_with_warmup( + optimizer, + warmup_steps=warmup_steps, + max_steps=cfg.train.max_steps, + scheduler_type=scheduler_type, + min_lr=min_lr + ) + log.info(f"📈 LR Scheduler: {scheduler_type} with {warmup_steps} warmup steps (min_lr={min_lr})") + # ========================================================================= # 4. Training Loop # ========================================================================= @@ -265,16 +317,19 @@ def main(cfg: DictConfig): torch.nn.utils.clip_grad_norm_(agent.parameters(), max_norm=1.0) optimizer.step() + scheduler.step() # ===================================================================== # Logging # ===================================================================== if step % cfg.train.log_freq == 0: + current_lr = optimizer.param_groups[0]['lr'] pbar.set_postfix({ "loss": f"{loss.item():.4f}", + "lr": f"{current_lr:.2e}", "best_loss": f"{best_loss:.4f}" }) - log.info(f"Step {step}/{cfg.train.max_steps} | Loss: {loss.item():.4f}") + log.info(f"Step {step}/{cfg.train.max_steps} | Loss: {loss.item():.4f} | LR: {current_lr:.2e}") # ===================================================================== # Checkpoint saving & Validation @@ -290,9 +345,11 @@ def main(cfg: DictConfig): 'step': step, 'model_state_dict': agent.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), + 'scheduler_state_dict': scheduler.state_dict(), 'loss': loss.item(), 'val_loss': val_loss, 'dataset_stats': dataset_stats, + 'current_lr': 
optimizer.param_groups[0]['lr'], }, checkpoint_path) log.info(f"💾 Checkpoint saved: {checkpoint_path}") @@ -305,9 +362,11 @@ def main(cfg: DictConfig): 'step': step, 'model_state_dict': agent.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), + 'scheduler_state_dict': scheduler.state_dict(), 'loss': loss.item(), 'val_loss': val_loss, 'dataset_stats': dataset_stats, + 'current_lr': optimizer.param_groups[0]['lr'], }, best_model_path) log.info(f"🌟 Best model updated: {best_model_path} (val_loss: {best_loss:.4f})") @@ -319,8 +378,10 @@ def main(cfg: DictConfig): 'step': cfg.train.max_steps, 'model_state_dict': agent.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), + 'scheduler_state_dict': scheduler.state_dict(), 'loss': loss.item(), 'dataset_stats': dataset_stats, + 'current_lr': optimizer.param_groups[0]['lr'], }, final_model_path) log.info(f"💾 Final model saved: {final_model_path}") diff --git a/roboimi/vla/conf/config.yaml b/roboimi/vla/conf/config.yaml index d724b77..f1a9c14 100644 --- a/roboimi/vla/conf/config.yaml +++ b/roboimi/vla/conf/config.yaml @@ -11,4 +11,9 @@ train: log_freq: 100 # Log frequency (steps) save_freq: 2000 # Save checkpoint frequency (steps) device: "cuda" # Device: "cuda" or "cpu" - num_workers: 8 # DataLoader workers (set to 0 for debugging, 8 for production) \ No newline at end of file + num_workers: 8 # DataLoader workers (set to 0 for debugging, 8 for production) + + # Learning rate scheduler with warmup + warmup_steps: 500 # Number of warmup steps + scheduler_type: "cosine" # Scheduler after warmup: "constant" or "cosine" + min_lr: 1e-6 # Minimum learning rate (for cosine decay) \ No newline at end of file