feat(train): 添加warmup学习率调度器
This commit is contained in:
@@ -9,6 +9,7 @@ from tqdm import tqdm
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from torch.utils.data import DataLoader, random_split
|
||||
from torch.optim import AdamW
|
||||
from torch.optim.lr_scheduler import LambdaLR
|
||||
from pathlib import Path
|
||||
|
||||
# Ensure correct import path
|
||||
@@ -43,6 +44,43 @@ def recursive_to_device(data, device):
|
||||
return data
|
||||
|
||||
|
||||
def get_lr_schedule_with_warmup(optimizer, warmup_steps, max_steps, scheduler_type='cosine', min_lr=0):
    """
    Create a learning rate scheduler with linear warmup.

    During the first ``warmup_steps`` steps the lr multiplier rises linearly
    from 0 to 1; afterwards it either stays constant or follows a cosine
    curve down to ``min_lr``.

    Args:
        optimizer: PyTorch optimizer (its current lr is taken as the peak lr)
        warmup_steps: Number of linear warmup steps
        max_steps: Total training steps (the cosine reaches its floor here)
        scheduler_type: Post-warmup schedule, 'cosine' or 'constant'
        min_lr: Minimum learning rate floor (used by the cosine decay)

    Returns:
        torch.optim.lr_scheduler.LambdaLR scheduler
    """
    import math

    # Capture the initial lr before LambdaLR starts rescaling it, so the
    # min_lr floor can be expressed as a ratio of the peak lr.
    base_lr = optimizer.param_groups[0]['lr']
    min_lr_ratio = min_lr / base_lr if base_lr > 0 else 0.0

    def lr_lambda(step):
        # Warmup phase: linear increase from 0 to 1.
        if step < warmup_steps:
            return float(step) / float(max(1, warmup_steps))

        # Post-warmup phase
        if scheduler_type == 'cosine':
            # Cosine annealing from 1 down to min_lr_ratio.
            progress = float(step - warmup_steps) / float(max(1, max_steps - warmup_steps))
            # BUGFIX: clamp progress at 1.0 — without this, stepping past
            # max_steps walks back UP the cosine (lr returns to peak at
            # progress == 2) instead of holding at the min_lr floor.
            progress = min(1.0, progress)
            cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress))
            return max(min_lr_ratio, cosine_decay)

        # 'constant': hold the peak lr after warmup.
        return 1.0

    return LambdaLR(optimizer, lr_lambda)
|
||||
|
||||
|
||||
@hydra.main(version_base=None, config_path="../../vla/conf", config_name="config")
|
||||
def main(cfg: DictConfig):
|
||||
"""
|
||||
@@ -173,11 +211,25 @@ def main(cfg: DictConfig):
|
||||
log.warning("⚠️ Training will continue, but inference may not work correctly")
|
||||
|
||||
# =========================================================================
|
||||
# 3. Setup Optimizer
|
||||
# 3. Setup Optimizer & LR Scheduler
|
||||
# =========================================================================
|
||||
optimizer = AdamW(agent.parameters(), lr=cfg.train.lr, weight_decay=1e-5)
|
||||
log.info(f"🔧 Optimizer: AdamW (lr={cfg.train.lr})")
|
||||
|
||||
# Setup learning rate scheduler with warmup
|
||||
warmup_steps = int(cfg.train.get('warmup_steps', 500))
|
||||
scheduler_type = cfg.train.get('scheduler_type', 'cosine')
|
||||
min_lr = float(cfg.train.get('min_lr', 1e-6))
|
||||
|
||||
scheduler = get_lr_schedule_with_warmup(
|
||||
optimizer,
|
||||
warmup_steps=warmup_steps,
|
||||
max_steps=cfg.train.max_steps,
|
||||
scheduler_type=scheduler_type,
|
||||
min_lr=min_lr
|
||||
)
|
||||
log.info(f"📈 LR Scheduler: {scheduler_type} with {warmup_steps} warmup steps (min_lr={min_lr})")
|
||||
|
||||
# =========================================================================
|
||||
# 4. Training Loop
|
||||
# =========================================================================
|
||||
@@ -265,16 +317,19 @@ def main(cfg: DictConfig):
|
||||
torch.nn.utils.clip_grad_norm_(agent.parameters(), max_norm=1.0)
|
||||
|
||||
optimizer.step()
|
||||
scheduler.step()
|
||||
|
||||
# =====================================================================
|
||||
# Logging
|
||||
# =====================================================================
|
||||
if step % cfg.train.log_freq == 0:
|
||||
current_lr = optimizer.param_groups[0]['lr']
|
||||
pbar.set_postfix({
|
||||
"loss": f"{loss.item():.4f}",
|
||||
"lr": f"{current_lr:.2e}",
|
||||
"best_loss": f"{best_loss:.4f}"
|
||||
})
|
||||
log.info(f"Step {step}/{cfg.train.max_steps} | Loss: {loss.item():.4f}")
|
||||
log.info(f"Step {step}/{cfg.train.max_steps} | Loss: {loss.item():.4f} | LR: {current_lr:.2e}")
|
||||
|
||||
# =====================================================================
|
||||
# Checkpoint saving & Validation
|
||||
@@ -290,9 +345,11 @@ def main(cfg: DictConfig):
|
||||
'step': step,
|
||||
'model_state_dict': agent.state_dict(),
|
||||
'optimizer_state_dict': optimizer.state_dict(),
|
||||
'scheduler_state_dict': scheduler.state_dict(),
|
||||
'loss': loss.item(),
|
||||
'val_loss': val_loss,
|
||||
'dataset_stats': dataset_stats,
|
||||
'current_lr': optimizer.param_groups[0]['lr'],
|
||||
}, checkpoint_path)
|
||||
log.info(f"💾 Checkpoint saved: {checkpoint_path}")
|
||||
|
||||
@@ -305,9 +362,11 @@ def main(cfg: DictConfig):
|
||||
'step': step,
|
||||
'model_state_dict': agent.state_dict(),
|
||||
'optimizer_state_dict': optimizer.state_dict(),
|
||||
'scheduler_state_dict': scheduler.state_dict(),
|
||||
'loss': loss.item(),
|
||||
'val_loss': val_loss,
|
||||
'dataset_stats': dataset_stats,
|
||||
'current_lr': optimizer.param_groups[0]['lr'],
|
||||
}, best_model_path)
|
||||
log.info(f"🌟 Best model updated: {best_model_path} (val_loss: {best_loss:.4f})")
|
||||
|
||||
@@ -319,8 +378,10 @@ def main(cfg: DictConfig):
|
||||
'step': cfg.train.max_steps,
|
||||
'model_state_dict': agent.state_dict(),
|
||||
'optimizer_state_dict': optimizer.state_dict(),
|
||||
'scheduler_state_dict': scheduler.state_dict(),
|
||||
'loss': loss.item(),
|
||||
'dataset_stats': dataset_stats,
|
||||
'current_lr': optimizer.param_groups[0]['lr'],
|
||||
}, final_model_path)
|
||||
log.info(f"💾 Final model saved: {final_model_path}")
|
||||
|
||||
|
||||
@@ -12,3 +12,8 @@ train:
|
||||
save_freq: 2000 # Save checkpoint frequency (steps)
|
||||
device: "cuda" # Device: "cuda" or "cpu"
|
||||
num_workers: 8 # DataLoader workers (set to 0 for debugging, 8 for production)
|
||||
|
||||
# Learning rate scheduler with warmup
|
||||
warmup_steps: 500 # Number of warmup steps
|
||||
scheduler_type: "cosine" # Scheduler after warmup: "constant" or "cosine"
|
||||
min_lr: 1e-6 # Minimum learning rate (for cosine decay)
|
||||
Reference in New Issue
Block a user