feat(train): 添加warmup学习率调度器

This commit is contained in:
gouhanke
2026-02-06 22:54:34 +08:00
parent 456056347f
commit 4332530a5f
2 changed files with 69 additions and 3 deletions

View File

@@ -9,6 +9,7 @@ from tqdm import tqdm
from omegaconf import DictConfig, OmegaConf from omegaconf import DictConfig, OmegaConf
from torch.utils.data import DataLoader, random_split from torch.utils.data import DataLoader, random_split
from torch.optim import AdamW from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
from pathlib import Path from pathlib import Path
# Ensure correct import path # Ensure correct import path
@@ -43,6 +44,43 @@ def recursive_to_device(data, device):
return data return data
def get_lr_schedule_with_warmup(optimizer, warmup_steps, max_steps, scheduler_type='cosine', min_lr=0):
    """
    Create a learning rate scheduler with linear warmup.

    Args:
        optimizer: PyTorch optimizer; the base lr is read from param_groups[0].
        warmup_steps: Number of warmup steps (lr ramps linearly 0 -> base lr).
        max_steps: Total training steps; the cosine decay ends here.
        scheduler_type: Schedule after warmup: 'cosine' for cosine annealing,
            anything else holds the base lr constant.
        min_lr: Minimum learning rate floor (for cosine decay).

    Returns:
        LambdaLR scheduler that scales the base lr by a step-dependent factor.
    """
    import math

    # Capture the base lr before LambdaLR wraps the optimizer; min_lr is
    # converted to a fraction of it so lr_lambda can return a pure ratio.
    base_lr = optimizer.param_groups[0]['lr']
    min_lr_ratio = min_lr / base_lr if base_lr > 0 else 0.0

    def lr_lambda(step):
        # Warmup phase: linear increase from 0 to 1.
        if step < warmup_steps:
            return float(step) / float(max(1, warmup_steps))
        # Post-warmup phase.
        if scheduler_type == 'cosine':
            # Cosine annealing from 1 down to min_lr_ratio.
            progress = float(step - warmup_steps) / float(max(1, max_steps - warmup_steps))
            # Clamp progress at 1.0: without this, any step past max_steps
            # (an extra scheduler.step(), or resumed training) would walk
            # back up the cosine curve and RAISE the lr again.
            progress = min(1.0, progress)
            cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress))
            return max(min_lr_ratio, cosine_decay)
        # Constant learning rate after warmup.
        return 1.0

    return LambdaLR(optimizer, lr_lambda)
@hydra.main(version_base=None, config_path="../../vla/conf", config_name="config") @hydra.main(version_base=None, config_path="../../vla/conf", config_name="config")
def main(cfg: DictConfig): def main(cfg: DictConfig):
""" """
@@ -173,11 +211,25 @@ def main(cfg: DictConfig):
log.warning("⚠️ Training will continue, but inference may not work correctly") log.warning("⚠️ Training will continue, but inference may not work correctly")
# ========================================================================= # =========================================================================
# 3. Setup Optimizer # 3. Setup Optimizer & LR Scheduler
# ========================================================================= # =========================================================================
optimizer = AdamW(agent.parameters(), lr=cfg.train.lr, weight_decay=1e-5) optimizer = AdamW(agent.parameters(), lr=cfg.train.lr, weight_decay=1e-5)
log.info(f"🔧 Optimizer: AdamW (lr={cfg.train.lr})") log.info(f"🔧 Optimizer: AdamW (lr={cfg.train.lr})")
# Setup learning rate scheduler with warmup
warmup_steps = int(cfg.train.get('warmup_steps', 500))
scheduler_type = cfg.train.get('scheduler_type', 'cosine')
min_lr = float(cfg.train.get('min_lr', 1e-6))
scheduler = get_lr_schedule_with_warmup(
optimizer,
warmup_steps=warmup_steps,
max_steps=cfg.train.max_steps,
scheduler_type=scheduler_type,
min_lr=min_lr
)
log.info(f"📈 LR Scheduler: {scheduler_type} with {warmup_steps} warmup steps (min_lr={min_lr})")
# ========================================================================= # =========================================================================
# 4. Training Loop # 4. Training Loop
# ========================================================================= # =========================================================================
@@ -265,16 +317,19 @@ def main(cfg: DictConfig):
torch.nn.utils.clip_grad_norm_(agent.parameters(), max_norm=1.0) torch.nn.utils.clip_grad_norm_(agent.parameters(), max_norm=1.0)
optimizer.step() optimizer.step()
scheduler.step()
# ===================================================================== # =====================================================================
# Logging # Logging
# ===================================================================== # =====================================================================
if step % cfg.train.log_freq == 0: if step % cfg.train.log_freq == 0:
current_lr = optimizer.param_groups[0]['lr']
pbar.set_postfix({ pbar.set_postfix({
"loss": f"{loss.item():.4f}", "loss": f"{loss.item():.4f}",
"lr": f"{current_lr:.2e}",
"best_loss": f"{best_loss:.4f}" "best_loss": f"{best_loss:.4f}"
}) })
log.info(f"Step {step}/{cfg.train.max_steps} | Loss: {loss.item():.4f}") log.info(f"Step {step}/{cfg.train.max_steps} | Loss: {loss.item():.4f} | LR: {current_lr:.2e}")
# ===================================================================== # =====================================================================
# Checkpoint saving & Validation # Checkpoint saving & Validation
@@ -290,9 +345,11 @@ def main(cfg: DictConfig):
'step': step, 'step': step,
'model_state_dict': agent.state_dict(), 'model_state_dict': agent.state_dict(),
'optimizer_state_dict': optimizer.state_dict(), 'optimizer_state_dict': optimizer.state_dict(),
'scheduler_state_dict': scheduler.state_dict(),
'loss': loss.item(), 'loss': loss.item(),
'val_loss': val_loss, 'val_loss': val_loss,
'dataset_stats': dataset_stats, 'dataset_stats': dataset_stats,
'current_lr': optimizer.param_groups[0]['lr'],
}, checkpoint_path) }, checkpoint_path)
log.info(f"💾 Checkpoint saved: {checkpoint_path}") log.info(f"💾 Checkpoint saved: {checkpoint_path}")
@@ -305,9 +362,11 @@ def main(cfg: DictConfig):
'step': step, 'step': step,
'model_state_dict': agent.state_dict(), 'model_state_dict': agent.state_dict(),
'optimizer_state_dict': optimizer.state_dict(), 'optimizer_state_dict': optimizer.state_dict(),
'scheduler_state_dict': scheduler.state_dict(),
'loss': loss.item(), 'loss': loss.item(),
'val_loss': val_loss, 'val_loss': val_loss,
'dataset_stats': dataset_stats, 'dataset_stats': dataset_stats,
'current_lr': optimizer.param_groups[0]['lr'],
}, best_model_path) }, best_model_path)
log.info(f"🌟 Best model updated: {best_model_path} (val_loss: {best_loss:.4f})") log.info(f"🌟 Best model updated: {best_model_path} (val_loss: {best_loss:.4f})")
@@ -319,8 +378,10 @@ def main(cfg: DictConfig):
'step': cfg.train.max_steps, 'step': cfg.train.max_steps,
'model_state_dict': agent.state_dict(), 'model_state_dict': agent.state_dict(),
'optimizer_state_dict': optimizer.state_dict(), 'optimizer_state_dict': optimizer.state_dict(),
'scheduler_state_dict': scheduler.state_dict(),
'loss': loss.item(), 'loss': loss.item(),
'dataset_stats': dataset_stats, 'dataset_stats': dataset_stats,
'current_lr': optimizer.param_groups[0]['lr'],
}, final_model_path) }, final_model_path)
log.info(f"💾 Final model saved: {final_model_path}") log.info(f"💾 Final model saved: {final_model_path}")

View File

@@ -11,4 +11,9 @@ train:
log_freq: 100 # Log frequency (steps) log_freq: 100 # Log frequency (steps)
save_freq: 2000 # Save checkpoint frequency (steps) save_freq: 2000 # Save checkpoint frequency (steps)
device: "cuda" # Device: "cuda" or "cpu" device: "cuda" # Device: "cuda" or "cpu"
num_workers: 8 # DataLoader workers (set to 0 for debugging, 8 for production) num_workers: 8 # DataLoader workers (set to 0 for debugging, 8 for production)
# Learning rate scheduler with warmup
warmup_steps: 500 # Number of warmup steps
scheduler_type: "cosine" # Scheduler after warmup: "constant" or "cosine"
min_lr: 1e-6 # Minimum learning rate (for cosine decay)