import os
import time
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import swanlab
from typing import Dict, Any, Optional, Callable, Union, Tuple


class EarlyStopping:
    """Early stopping to stop training when validation performance doesn't improve.

    Each call compares a new validation loss with the best one seen so far.
    An improvement saves the model's ``state_dict`` to ``path`` and resets
    the patience counter. Anything else increments the counter. Once the
    counter reaches ``patience``, ``early_stop`` becomes True.
    """

    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt'):
        """
        Args:
            patience (int): How long to wait after last improvement. Default: 7
            verbose (bool): If True, prints a message for each improvement. Default: False
            delta (float): Minimum change in monitored quantity to qualify as improvement. Default: 0
            path (str): Path for the checkpoint to be saved to. Default: 'checkpoint.pt'
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None  # best (negated) validation loss seen so far
        self.early_stop = False
        self.val_loss_min = float('inf')
        self.delta = delta
        self.path = path

    def __call__(self, val_loss, model):
        """Record ``val_loss`` for the current epoch and checkpoint ``model`` on improvement."""
        # Higher score is better, so the loss is negated.
        score = -val_loss
        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score >= self.best_score + self.delta:
            # Written as a positive ">=" test so that a NaN loss (for which
            # every comparison is False) counts as "no improvement" instead of
            # overwriting the checkpoint and resetting patience forever.
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0
        else:
            self.counter += 1
            if self.verbose:
                print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True

    def save_checkpoint(self, val_loss, model):
        """Save model when validation loss decreases."""
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss


def _to_float_tensor(array) -> torch.Tensor:
    """Copy an array-like into a new float32 tensor."""
    return torch.tensor(np.asarray(array), dtype=torch.float32)


def create_data_loaders(data_path: str, batch_size: int = 32) -> Dict[str, DataLoader]:
    """
    Create PyTorch DataLoaders from an NPZ file.

    The file must contain ``train_x``, ``train_y``, ``val_x`` and ``val_y``.
    If ``train_x_mark`` is present, ``train_y_mark``, ``val_x_mark`` and
    ``val_y_mark`` must be present too. Each batch then yields
    ``(x, y, x_mark, y_mark)``. Otherwise, each batch yields ``(x, y)``.

    Args:
        data_path (str): Path to the NPZ file containing the data
        batch_size (int): Batch size for the DataLoaders

    Returns:
        Dict[str, DataLoader]: Dictionary with 'train' (shuffled) and 'val'
        (unshuffled) DataLoaders

    Raises:
        ValueError: if ``train_x_mark`` is present but any of the other
            time-feature arrays is missing.
    """
    mark_keys = ('train_x_mark', 'train_y_mark', 'val_x_mark', 'val_y_mark')

    # NOTE(review): allow_pickle=True lets np.load unpickle object arrays,
    # which can execute arbitrary code. Only load NPZ files from trusted
    # sources. It is kept enabled for compatibility with existing data files.
    # The context manager closes the underlying zip file once arrays are read.
    with np.load(data_path, allow_pickle=True) as data:
        train_x = data['train_x']
        train_y = data['train_y']
        val_x = data['val_x']
        val_y = data['val_y']
        # Time features are optional.
        marks = {key: data.get(key, None) for key in mark_keys}

    train_x = _to_float_tensor(train_x)
    train_y = _to_float_tensor(train_y)
    val_x = _to_float_tensor(val_x)
    val_y = _to_float_tensor(val_y)

    # The presence of train_x_mark decides whether time features are used.
    if marks['train_x_mark'] is not None:
        missing = [key for key in mark_keys if marks[key] is None]
        if missing:
            raise ValueError(
                f"{data_path} contains 'train_x_mark' but is missing time-feature arrays: {missing}"
            )
        train_x_mark, train_y_mark, val_x_mark, val_y_mark = (
            _to_float_tensor(marks[key]) for key in mark_keys
        )
        train_dataset = TensorDataset(train_x, train_y, train_x_mark, train_y_mark)
        val_dataset = TensorDataset(val_x, val_y, val_x_mark, val_y_mark)
    else:
        train_dataset = TensorDataset(train_x, train_y)
        val_dataset = TensorDataset(val_x, val_y)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    return {
        'train': train_loader,
        'val': val_loader
    }


def _resolve_device(device: Optional[str]) -> str:
    """Return ``device``, or 'cuda' when available (else 'cpu') if it is None."""
    if device is None:
        return 'cuda' if torch.cuda.is_available() else 'cpu'
    return device


def _require_batches(dataloaders: Dict[str, DataLoader]) -> None:
    """Fail fast if any split yields no batches (losses are averaged per batch)."""
    for split, loader in dataloaders.items():
        if len(loader) == 0:
            raise ValueError(f"The '{split}' split of the dataset is empty")


def _unpack_batch(batch_data, device):
    """Move a batch to ``device`` and split it into (inputs, targets, x_mark, y_mark).

    Batches built by ``create_data_loaders`` hold 4 tensors when time features
    are available and 2 otherwise. In the 2-tensor case both marks are None.
    """
    if len(batch_data) == 4:
        inputs, targets, x_mark, y_mark = (t.to(device) for t in batch_data)
    else:
        inputs, targets = batch_data
        inputs, targets = inputs.to(device), targets.to(device)
        x_mark, y_mark = None, None
    return inputs, targets, x_mark, y_mark


def _forecast_forward(model, inputs, x_mark):
    """Run a forecasting model, passing encoder time features when present.

    Training, validation and the final evaluation all go through this one
    call. The reported final loss therefore measures the model exactly as it
    was trained and checkpointed.
    """
    if x_mark is not None:
        return model(inputs, x_mark)
    return model(inputs)


def _classify_forward(model, inputs, targets, x_mark, y_mark):
    """Run a classification model.

    With time features, the model is called TimesNet-style as
    ``model(inputs, x_mark, dec_inp, y_mark)``, where ``dec_inp`` is zeros
    shaped like ``targets``. Otherwise it is called as ``model(inputs)``.
    """
    if x_mark is not None:
        dec_inp = torch.zeros_like(targets)
        return model(inputs, x_mark, dec_inp, y_mark)
    return model(inputs)


def _evaluate_forecasting(model, loader, criterion, device) -> float:
    """Return the mean per-batch ``criterion`` loss of ``model`` over ``loader``."""
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for batch_data in loader:
            inputs, targets, x_mark, _ = _unpack_batch(batch_data, device)
            outputs = _forecast_forward(model, inputs, x_mark)
            total_loss += criterion(outputs, targets).item()
    return total_loss / len(loader)


def _evaluate_classification(model, loader, criterion, device) -> Tuple[float, float]:
    """Return (mean per-batch loss, accuracy in percent) of ``model`` over ``loader``."""
    model.eval()
    total_loss = 0.0
    correct = 0
    seen = 0
    with torch.no_grad():
        for batch_data in loader:
            inputs, targets, x_mark, y_mark = _unpack_batch(batch_data, device)
            targets = targets.long()
            outputs = _classify_forward(model, inputs, targets, x_mark, y_mark)
            total_loss += criterion(outputs, targets).item()
            _, predicted = outputs.max(1)
            seen += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    return total_loss / len(loader), 100. * correct / seen


def _load_best_checkpoint(model, checkpoint_path: str, device: str) -> None:
    """Load the weights saved by EarlyStopping into ``model``, if a checkpoint exists."""
    if not os.path.exists(checkpoint_path):
        # Only possible when no epoch ran (e.g. max_epochs == 0).
        print(f"No checkpoint found at {checkpoint_path}; keeping current weights")
        return
    # map_location lets a checkpoint written on another device (e.g. GPU) load here.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))


def train_forecasting_model(
    model_constructor: Callable,
    data_path: str,
    project_name: str,
    config: Dict[str, Any],
    device: Optional[str] = None,
    early_stopping_patience: int = 10,
    max_epochs: int = 100,
    checkpoint_dir: str = "./checkpoints",
    log_interval: int = 10
) -> Tuple[nn.Module, Dict[str, float]]:
    """
    Train a time series forecasting model with MSE loss and Adam.

    Config keys read (all optional): 'batch_size' (32), 'learning_rate'
    (1e-3), 'weight_decay' (0.0), 'lr_step_size' (5) and 'lr_gamma' (0.5).
    The learning rate is multiplied by 'lr_gamma' every 'lr_step_size' epochs.

    Args:
        model_constructor (Callable): Function that constructs and returns the model
        data_path (str): Path to the NPZ file containing the processed data
        project_name (str): Name of the project for swanlab tracking; also the checkpoint file name
        config (Dict[str, Any]): Configuration dictionary for the experiment
        device (Optional[str]): Device to use for training ('cpu' or 'cuda'); auto-detected if None
        early_stopping_patience (int): Number of epochs to wait before early stopping
        max_epochs (int): Maximum number of epochs to train for
        checkpoint_dir (str): Directory to save model checkpoints
        log_interval (int): Log the mean training loss every this many batches (<= 0 disables)

    Returns:
        Tuple[nn.Module, Dict[str, float]]: Model with the best checkpoint loaded,
        and the metrics of the best epoch plus "final_val_loss"
    """
    device = _resolve_device(device)

    swanlab_run = swanlab.init(
        project=project_name,
        config=config,
    )
    try:
        os.makedirs(checkpoint_dir, exist_ok=True)
        checkpoint_path = os.path.join(checkpoint_dir, f"{project_name}.pt")

        dataloaders = create_data_loaders(
            data_path=data_path,
            batch_size=config.get('batch_size', 32)
        )
        _require_batches(dataloaders)
        train_loader = dataloaders['train']
        val_loader = dataloaders['val']

        model = model_constructor().to(device)

        criterion = nn.MSELoss()
        optimizer = optim.Adam(
            model.parameters(),
            lr=config.get('learning_rate', 1e-3),
            weight_decay=config.get('weight_decay', 0.0),
        )
        # By default, halve the learning rate every 5 epochs.
        scheduler = optim.lr_scheduler.StepLR(
            optimizer,
            step_size=config.get('lr_step_size', 5),
            gamma=config.get('lr_gamma', 0.5),
        )

        early_stopping = EarlyStopping(
            patience=early_stopping_patience,
            verbose=True,
            path=checkpoint_path
        )

        best_val_loss = float('inf')
        metrics = {}
        num_batches = len(train_loader)

        for epoch in range(max_epochs):
            print(f"Epoch {epoch+1}/{max_epochs}")

            # Training phase
            model.train()
            train_loss = 0.0
            interval_loss = 0.0  # summed loss since the last interval log
            start_time = time.time()

            for batch_idx, batch_data in enumerate(train_loader):
                inputs, targets, x_mark, _ = _unpack_batch(batch_data, device)

                optimizer.zero_grad()
                outputs = _forecast_forward(model, inputs, x_mark)
                loss = criterion(outputs, targets)
                loss.backward()
                optimizer.step()

                batch_loss = loss.item()
                train_loss += batch_loss
                interval_loss += batch_loss

                if log_interval > 0 and (batch_idx + 1) % log_interval == 0:
                    print(f"Batch {batch_idx+1}/{num_batches}, Loss: {batch_loss:.4f}")
                    # Mean loss over the interval that just finished.
                    swanlab_run.log({"batch_train_loss": interval_loss / log_interval})
                    interval_loss = 0.0

            avg_train_loss = train_loss / num_batches
            # Measures the training phase only, as before.
            epoch_time = time.time() - start_time

            # Validation phase
            avg_val_loss = _evaluate_forecasting(model, val_loader, criterion, device)
            current_lr = optimizer.param_groups[0]['lr']

            metrics_dict = {
                "train_loss": avg_train_loss,
                "val_loss": avg_val_loss,
                "learning_rate": current_lr,
                "epoch_time": epoch_time
            }
            swanlab_run.log(metrics_dict)

            print(f"Epoch {epoch+1}/{max_epochs}, "
                  f"Train Loss: {avg_train_loss:.4f}, "
                  f"Val Loss: {avg_val_loss:.4f}, "
                  f"LR: {current_lr:.6f}, "
                  f"Time: {epoch_time:.2f}s")

            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                metrics = dict(metrics_dict)

            early_stopping(avg_val_loss, model)
            if early_stopping.early_stop:
                print("Early stopping triggered")
                break

            scheduler.step()

        # Re-evaluate the best checkpoint with the same forward call used in training.
        _load_best_checkpoint(model, checkpoint_path, device)
        final_val_loss = _evaluate_forecasting(model, val_loader, criterion, device)
        print(f"Final validation loss: {final_val_loss:.4f}")

        metrics["final_val_loss"] = final_val_loss
    finally:
        # Always close the tracking run, even if training raised.
        swanlab_run.finish()

    return model, metrics


def train_classification_model(
    model_constructor: Callable,
    data_path: str,
    project_name: str,
    config: Dict[str, Any],
    device: Optional[str] = None,
    early_stopping_patience: int = 10,
    max_epochs: int = 100,
    checkpoint_dir: str = "./checkpoints",
    log_interval: int = 10
) -> Tuple[nn.Module, Dict[str, float]]:
    """
    Train a time series classification model with cross-entropy loss and Adam.

    Targets are cast to ``long`` class indices. Config keys read (all
    optional): 'batch_size' (32), 'learning_rate' (1e-3), 'weight_decay'
    (1e-4), 'lr_step_size' (1) and 'lr_gamma' (0.5).

    Args:
        model_constructor (Callable): Function that constructs and returns the model
        data_path (str): Path to the NPZ file containing the processed data
        project_name (str): Name of the project for swanlab tracking; also the checkpoint file name
        config (Dict[str, Any]): Configuration dictionary for the experiment
        device (Optional[str]): Device to use for training ('cpu' or 'cuda'); auto-detected if None
        early_stopping_patience (int): Number of epochs to wait before early stopping
        max_epochs (int): Maximum number of epochs to train for
        checkpoint_dir (str): Directory to save model checkpoints
        log_interval (int): Print the batch loss every this many batches (<= 0 disables)

    Returns:
        Tuple[nn.Module, Dict[str, float]]: Model with the best checkpoint loaded,
        and the metrics of the best epoch plus "final_val_loss" and
        "final_val_accuracy"
    """
    device = _resolve_device(device)

    swanlab_run = swanlab.init(
        project=project_name,
        config=config,
    )
    try:
        os.makedirs(checkpoint_dir, exist_ok=True)
        checkpoint_path = os.path.join(checkpoint_dir, f"{project_name}.pt")

        dataloaders = create_data_loaders(
            data_path=data_path,
            batch_size=config.get('batch_size', 32)
        )
        _require_batches(dataloaders)
        train_loader = dataloaders['train']
        val_loader = dataloaders['val']

        model = model_constructor().to(device)

        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(
            model.parameters(),
            lr=config.get('learning_rate', 1e-3),
            weight_decay=config.get('weight_decay', 1e-4)
        )
        # By default, halve the learning rate after each epoch.
        scheduler = optim.lr_scheduler.StepLR(
            optimizer,
            step_size=config.get('lr_step_size', 1),
            gamma=config.get('lr_gamma', 0.5),
        )

        early_stopping = EarlyStopping(
            patience=early_stopping_patience,
            verbose=True,
            path=checkpoint_path
        )

        best_val_loss = float('inf')
        metrics = {}
        num_batches = len(train_loader)

        for epoch in range(max_epochs):
            print(f"Epoch {epoch+1}/{max_epochs}")

            # Training phase
            model.train()
            train_loss = 0.0
            train_correct = 0
            train_total = 0
            start_time = time.time()

            for batch_idx, batch_data in enumerate(train_loader):
                inputs, targets, x_mark, y_mark = _unpack_batch(batch_data, device)
                # CrossEntropyLoss expects integer class indices.
                targets = targets.long()

                optimizer.zero_grad()
                outputs = _classify_forward(model, inputs, targets, x_mark, y_mark)
                loss = criterion(outputs, targets)
                loss.backward()
                optimizer.step()

                train_loss += loss.item()
                _, predicted = outputs.max(1)
                train_total += targets.size(0)
                train_correct += predicted.eq(targets).sum().item()

                if log_interval > 0 and (batch_idx + 1) % log_interval == 0:
                    print(f"Batch {batch_idx+1}/{num_batches}, Loss: {loss.item():.4f}")

            avg_train_loss = train_loss / num_batches
            train_accuracy = 100. * train_correct / train_total
            # Measures the training phase only, as before.
            epoch_time = time.time() - start_time

            # Validation phase
            avg_val_loss, val_accuracy = _evaluate_classification(
                model, val_loader, criterion, device
            )
            current_lr = optimizer.param_groups[0]['lr']

            metrics_dict = {
                "train_loss": avg_train_loss,
                "train_accuracy": train_accuracy,
                "val_loss": avg_val_loss,
                "val_accuracy": val_accuracy,
                "learning_rate": current_lr,
                "epoch_time": epoch_time
            }
            swanlab_run.log(metrics_dict)

            print(f"Epoch {epoch+1}/{max_epochs}, "
                  f"Train Loss: {avg_train_loss:.4f}, "
                  f"Val Loss: {avg_val_loss:.4f}, "
                  f"Val Accuracy: {val_accuracy:.2f}%, "
                  f"LR: {current_lr:.6f}, "
                  f"Time: {epoch_time:.2f}s")

            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                metrics = dict(metrics_dict)

            early_stopping(avg_val_loss, model)
            if early_stopping.early_stop:
                print("Early stopping triggered")
                break

            scheduler.step()

        # Re-evaluate the best checkpoint.
        _load_best_checkpoint(model, checkpoint_path, device)
        final_val_loss, final_val_accuracy = _evaluate_classification(
            model, val_loader, criterion, device
        )
        print(f"Final validation loss: {final_val_loss:.4f}")
        print(f"Final validation accuracy: {final_val_accuracy:.2f}%")

        metrics["final_val_loss"] = final_val_loss
        metrics["final_val_accuracy"] = final_val_accuracy
    finally:
        # Always close the tracking run, even if training raised.
        swanlab_run.finish()

    return model, metrics


def main():
    """Example usage: train a small MLP forecaster on 'data/train_data.npz'."""
    data_path = 'data/train_data.npz'
    project_name = 'TimeSeriesForecasting'
    config = {
        'learning_rate': 0.001,
        'batch_size': 32,
        'weight_decay': 1e-4
    }

    # Maps 10 input features to 1 output.
    model_constructor = lambda: nn.Sequential(
        nn.Linear(10, 50),
        nn.ReLU(),
        nn.Linear(50, 1)
    )

    model, metrics = train_forecasting_model(
        model_constructor=model_constructor,
        data_path=data_path,
        project_name=project_name,
        config=config
    )


if __name__ == "__main__":
    main()