feat(train): 添加warmup学习率调度器
This commit is contained in:
@@ -9,6 +9,7 @@ from tqdm import tqdm
|
|||||||
from omegaconf import DictConfig, OmegaConf
|
from omegaconf import DictConfig, OmegaConf
|
||||||
from torch.utils.data import DataLoader, random_split
|
from torch.utils.data import DataLoader, random_split
|
||||||
from torch.optim import AdamW
|
from torch.optim import AdamW
|
||||||
|
from torch.optim.lr_scheduler import LambdaLR
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
# Ensure correct import path
|
# Ensure correct import path
|
||||||
@@ -43,6 +44,43 @@ def recursive_to_device(data, device):
|
|||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def get_lr_schedule_with_warmup(optimizer, warmup_steps, max_steps, scheduler_type='cosine', min_lr=0):
    """
    Create a learning rate scheduler with linear warmup.

    During the first `warmup_steps` steps the learning rate rises linearly
    from 0 up to the optimizer's base lr. Afterwards it either follows a
    cosine decay down to `min_lr` (reached at `max_steps`) or stays constant
    at the base lr.

    Args:
        optimizer: PyTorch optimizer; the base lr is read from its first
            param group before LambdaLR modifies it.
        warmup_steps: Number of linear warmup steps.
        max_steps: Total training steps (end point of the cosine decay).
        scheduler_type: Schedule after warmup: 'cosine' for cosine decay;
            any other value (e.g. 'constant') keeps the base lr.
        min_lr: Minimum learning rate floor for the cosine decay.

    Returns:
        LambdaLR scheduler implementing the combined warmup + decay schedule.
    """
    import math

    # Capture the initial lr before LambdaLR rescales it; min_lr must be
    # expressed as a ratio of the base lr because LambdaLR multiplies the
    # base lr by the lambda's return value.
    base_lr = optimizer.param_groups[0]['lr']
    min_lr_ratio = min_lr / base_lr if base_lr > 0 else 0.0

    def lr_lambda(step):
        # Warmup phase: linear increase from 0 to 1.
        if step < warmup_steps:
            return float(step) / float(max(1, warmup_steps))

        # Post-warmup phase.
        if scheduler_type == 'cosine':
            # Cosine annealing from 1 down to min_lr_ratio. Clamp progress
            # at 1.0 so that stepping past max_steps holds the lr at the
            # floor instead of the cosine wrapping around and rising again.
            progress = float(step - warmup_steps) / float(max(1, max_steps - warmup_steps))
            progress = min(progress, 1.0)
            cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress))
            return max(min_lr_ratio, cosine_decay)

        # Any other scheduler_type: constant lr after warmup.
        return 1.0

    return LambdaLR(optimizer, lr_lambda)
|
||||||
|
|
||||||
|
|
||||||
@hydra.main(version_base=None, config_path="../../vla/conf", config_name="config")
|
@hydra.main(version_base=None, config_path="../../vla/conf", config_name="config")
|
||||||
def main(cfg: DictConfig):
|
def main(cfg: DictConfig):
|
||||||
"""
|
"""
|
||||||
@@ -173,11 +211,25 @@ def main(cfg: DictConfig):
|
|||||||
log.warning("⚠️ Training will continue, but inference may not work correctly")
|
log.warning("⚠️ Training will continue, but inference may not work correctly")
|
||||||
|
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
# 3. Setup Optimizer
|
# 3. Setup Optimizer & LR Scheduler
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
optimizer = AdamW(agent.parameters(), lr=cfg.train.lr, weight_decay=1e-5)
|
optimizer = AdamW(agent.parameters(), lr=cfg.train.lr, weight_decay=1e-5)
|
||||||
log.info(f"🔧 Optimizer: AdamW (lr={cfg.train.lr})")
|
log.info(f"🔧 Optimizer: AdamW (lr={cfg.train.lr})")
|
||||||
|
|
||||||
|
# Setup learning rate scheduler with warmup
|
||||||
|
warmup_steps = int(cfg.train.get('warmup_steps', 500))
|
||||||
|
scheduler_type = cfg.train.get('scheduler_type', 'cosine')
|
||||||
|
min_lr = float(cfg.train.get('min_lr', 1e-6))
|
||||||
|
|
||||||
|
scheduler = get_lr_schedule_with_warmup(
|
||||||
|
optimizer,
|
||||||
|
warmup_steps=warmup_steps,
|
||||||
|
max_steps=cfg.train.max_steps,
|
||||||
|
scheduler_type=scheduler_type,
|
||||||
|
min_lr=min_lr
|
||||||
|
)
|
||||||
|
log.info(f"📈 LR Scheduler: {scheduler_type} with {warmup_steps} warmup steps (min_lr={min_lr})")
|
||||||
|
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
# 4. Training Loop
|
# 4. Training Loop
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
@@ -265,16 +317,19 @@ def main(cfg: DictConfig):
|
|||||||
torch.nn.utils.clip_grad_norm_(agent.parameters(), max_norm=1.0)
|
torch.nn.utils.clip_grad_norm_(agent.parameters(), max_norm=1.0)
|
||||||
|
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
scheduler.step()
|
||||||
|
|
||||||
# =====================================================================
|
# =====================================================================
|
||||||
# Logging
|
# Logging
|
||||||
# =====================================================================
|
# =====================================================================
|
||||||
if step % cfg.train.log_freq == 0:
|
if step % cfg.train.log_freq == 0:
|
||||||
|
current_lr = optimizer.param_groups[0]['lr']
|
||||||
pbar.set_postfix({
|
pbar.set_postfix({
|
||||||
"loss": f"{loss.item():.4f}",
|
"loss": f"{loss.item():.4f}",
|
||||||
|
"lr": f"{current_lr:.2e}",
|
||||||
"best_loss": f"{best_loss:.4f}"
|
"best_loss": f"{best_loss:.4f}"
|
||||||
})
|
})
|
||||||
log.info(f"Step {step}/{cfg.train.max_steps} | Loss: {loss.item():.4f}")
|
log.info(f"Step {step}/{cfg.train.max_steps} | Loss: {loss.item():.4f} | LR: {current_lr:.2e}")
|
||||||
|
|
||||||
# =====================================================================
|
# =====================================================================
|
||||||
# Checkpoint saving & Validation
|
# Checkpoint saving & Validation
|
||||||
@@ -290,9 +345,11 @@ def main(cfg: DictConfig):
|
|||||||
'step': step,
|
'step': step,
|
||||||
'model_state_dict': agent.state_dict(),
|
'model_state_dict': agent.state_dict(),
|
||||||
'optimizer_state_dict': optimizer.state_dict(),
|
'optimizer_state_dict': optimizer.state_dict(),
|
||||||
|
'scheduler_state_dict': scheduler.state_dict(),
|
||||||
'loss': loss.item(),
|
'loss': loss.item(),
|
||||||
'val_loss': val_loss,
|
'val_loss': val_loss,
|
||||||
'dataset_stats': dataset_stats,
|
'dataset_stats': dataset_stats,
|
||||||
|
'current_lr': optimizer.param_groups[0]['lr'],
|
||||||
}, checkpoint_path)
|
}, checkpoint_path)
|
||||||
log.info(f"💾 Checkpoint saved: {checkpoint_path}")
|
log.info(f"💾 Checkpoint saved: {checkpoint_path}")
|
||||||
|
|
||||||
@@ -305,9 +362,11 @@ def main(cfg: DictConfig):
|
|||||||
'step': step,
|
'step': step,
|
||||||
'model_state_dict': agent.state_dict(),
|
'model_state_dict': agent.state_dict(),
|
||||||
'optimizer_state_dict': optimizer.state_dict(),
|
'optimizer_state_dict': optimizer.state_dict(),
|
||||||
|
'scheduler_state_dict': scheduler.state_dict(),
|
||||||
'loss': loss.item(),
|
'loss': loss.item(),
|
||||||
'val_loss': val_loss,
|
'val_loss': val_loss,
|
||||||
'dataset_stats': dataset_stats,
|
'dataset_stats': dataset_stats,
|
||||||
|
'current_lr': optimizer.param_groups[0]['lr'],
|
||||||
}, best_model_path)
|
}, best_model_path)
|
||||||
log.info(f"🌟 Best model updated: {best_model_path} (val_loss: {best_loss:.4f})")
|
log.info(f"🌟 Best model updated: {best_model_path} (val_loss: {best_loss:.4f})")
|
||||||
|
|
||||||
@@ -319,8 +378,10 @@ def main(cfg: DictConfig):
|
|||||||
'step': cfg.train.max_steps,
|
'step': cfg.train.max_steps,
|
||||||
'model_state_dict': agent.state_dict(),
|
'model_state_dict': agent.state_dict(),
|
||||||
'optimizer_state_dict': optimizer.state_dict(),
|
'optimizer_state_dict': optimizer.state_dict(),
|
||||||
|
'scheduler_state_dict': scheduler.state_dict(),
|
||||||
'loss': loss.item(),
|
'loss': loss.item(),
|
||||||
'dataset_stats': dataset_stats,
|
'dataset_stats': dataset_stats,
|
||||||
|
'current_lr': optimizer.param_groups[0]['lr'],
|
||||||
}, final_model_path)
|
}, final_model_path)
|
||||||
log.info(f"💾 Final model saved: {final_model_path}")
|
log.info(f"💾 Final model saved: {final_model_path}")
|
||||||
|
|
||||||
|
|||||||
@@ -12,3 +12,8 @@ train:
|
|||||||
save_freq: 2000 # Save checkpoint frequency (steps)
|
save_freq: 2000 # Save checkpoint frequency (steps)
|
||||||
device: "cuda" # Device: "cuda" or "cpu"
|
device: "cuda" # Device: "cuda" or "cpu"
|
||||||
num_workers: 8 # DataLoader workers (set to 0 for debugging, 8 for production)
|
num_workers: 8 # DataLoader workers (set to 0 for debugging, 8 for production)
|
||||||
|
|
||||||
|
# Learning rate scheduler with warmup
|
||||||
|
warmup_steps: 500 # Number of warmup steps
|
||||||
|
scheduler_type: "cosine" # Scheduler after warmup: "constant" or "cosine"
|
||||||
|
min_lr: 1e-6 # Minimum learning rate (for cosine decay)
|
||||||
Reference in New Issue
Block a user