import os
import time
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import swanlab
from typing import Dict, Any, Optional, Callable, Union, Tuple
from dataflow import data_provider


class EarlyStopping:
    """Early stopping to stop training when validation performance doesn't improve.

    Each call compares the new validation loss with the best one seen so far.
    When the loss improves by at least ``delta`` (or on the very first call) the
    model's ``state_dict`` is saved to ``path`` and the patience counter is reset;
    otherwise the counter is incremented and ``early_stop`` becomes True once it
    reaches ``patience``.
    """

    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt'):
        """
        Args:
            patience (int): How long to wait after last improvement. Default: 7
            verbose (bool): If True, prints a message for each improvement. Default: False
            delta (float): Minimum change in monitored quantity to qualify as improvement. Default: 0
            path (str): Path for the checkpoint to be saved to. Default: 'checkpoint.pt'
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0           # consecutive calls without improvement
        self.best_score = None     # negated best validation loss; None until the first call
        self.early_stop = False
        self.val_loss_min = float('inf')
        self.delta = delta
        self.path = path

    def __call__(self, val_loss, model):
        """Record one epoch's ``val_loss`` and checkpoint ``model`` if it improved."""
        # Negate the loss so that a higher score is better.
        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            if self.verbose:
                print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        """Save the model's state_dict to ``self.path`` and remember ``val_loss``."""
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss


class DatasetWrapperWithoutTimeFeatures(torch.utils.data.Dataset):
    """Wrapper to remove time features from dataflow datasets when use_x_mark=False.

    The wrapped dataset is expected to return ``(seq_x, seq_y, seq_x_mark, seq_y_mark)``;
    this wrapper returns only ``(seq_x, seq_y)``.
    """

    def __init__(self, original_dataset):
        self.original_dataset = original_dataset

    def __getitem__(self, index):
        # Unpack the 4-tuple and drop the time-feature tensors.
        seq_x, seq_y, _seq_x_mark, _seq_y_mark = self.original_dataset[index]
        return seq_x, seq_y

    def __len__(self):
        return len(self.original_dataset)

    def inverse_transform(self, data):
        """Delegate to the wrapped dataset's ``inverse_transform`` if it has one."""
        if hasattr(self.original_dataset, 'inverse_transform'):
            return self.original_dataset.inverse_transform(data)
        return data


def create_data_loaders_from_dataflow(args, use_x_mark: bool = True) -> Dict[str, DataLoader]:
    """
    Create PyTorch DataLoaders using dataflow data_provider

    Args:
        args: Arguments object containing dataset configuration
              Required attributes: data, root_path, data_path, seq_len, label_len, pred_len,
                                   features, target, embed, freq, batch_size, num_workers, train_only
        use_x_mark (bool): Whether to use time features (x_mark and y_mark)

    Returns:
        Dict[str, DataLoader]: Dictionary with train, val, and test DataLoaders.
        Only the train loader shuffles and drops its final partial batch; val/test
        keep every sample so evaluation metrics cover the whole split.
    """
    # data_provider also builds its own DataLoader; only the dataset is kept so the
    # loaders below can be rebuilt around the (optionally wrapped) datasets.
    datasets = {split: data_provider(args, flag=split)[0] for split in ('train', 'val', 'test')}

    if not use_x_mark:
        datasets = {
            split: DatasetWrapperWithoutTimeFeatures(dataset)
            for split, dataset in datasets.items()
        }

    loaders = {}
    for split, dataset in datasets.items():
        is_train = split == 'train'
        loaders[split] = DataLoader(
            dataset,
            batch_size=args.batch_size,
            shuffle=is_train,
            num_workers=args.num_workers,
            # Dropping the last partial batch on val/test would silently exclude
            # samples from the reported metrics, and would yield zero batches
            # (division by zero downstream) when a split is smaller than batch_size.
            drop_last=is_train,
        )
    return loaders


def create_data_loaders(data_path: str, batch_size: int = 32, use_x_mark: bool = True) -> Dict[str, DataLoader]:
    """
    Create PyTorch DataLoaders from an NPZ file

    The file must contain ``{train,val,test}_{x,y}``. When ``use_x_mark`` is True and
    the file contains ``train_x_mark``, it must also contain every
    ``{train,val,test}_{x,y}_mark`` array; the datasets then yield 4-tuples
    ``(x, y, x_mark, y_mark)``, otherwise 2-tuples ``(x, y)``.

    Args:
        data_path (str): Path to the NPZ file containing the data
        batch_size (int): Batch size for the DataLoaders
        use_x_mark (bool): Whether to use time features (x_mark) from the data file

    Returns:
        Dict[str, DataLoader]: Dictionary with train, val, and test DataLoaders

    Raises:
        ValueError: If ``train_x_mark`` is present (and requested) but another
            time-feature array is missing.
    """
    splits = ('train', 'val', 'test')

    def to_tensor(array) -> torch.Tensor:
        return torch.as_tensor(array, dtype=torch.float32)

    # NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which can
    # execute arbitrary code -- only load NPZ files from trusted sources.
    with np.load(data_path, allow_pickle=True) as data:
        tensors = {}
        for split in splits:
            tensors[f'{split}_x'] = to_tensor(data[f'{split}_x'])
            tensors[f'{split}_y'] = to_tensor(data[f'{split}_y'])

        # The presence of train_x_mark decides whether time features are used.
        use_marks = use_x_mark and 'train_x_mark' in data.files
        if use_marks:
            mark_keys = [f'{split}_{kind}' for split in splits for kind in ('x_mark', 'y_mark')]
            missing = [key for key in mark_keys if key not in data.files]
            if missing:
                raise ValueError(
                    f"{data_path} contains 'train_x_mark' but is missing "
                    f"time-feature arrays: {', '.join(missing)}"
                )
            for key in mark_keys:
                tensors[key] = to_tensor(data[key])

    loaders = {}
    for split in splits:
        fields = [tensors[f'{split}_x'], tensors[f'{split}_y']]
        if use_marks:
            fields += [tensors[f'{split}_x_mark'], tensors[f'{split}_y_mark']]
        loaders[split] = DataLoader(
            TensorDataset(*fields),
            batch_size=batch_size,
            shuffle=(split == 'train'),
        )
    return loaders


def _build_dataloaders(data_path, config, use_x_mark, dataset_mode, dataflow_args) -> Dict[str, DataLoader]:
    """Build train/val/test loaders for ``dataset_mode`` and reject empty splits.

    Any mode other than "dataflow" falls back to the NPZ loader.
    """
    if dataset_mode == "dataflow":
        if dataflow_args is None:
            raise ValueError("dataflow_args is required when dataset_mode='dataflow'")
        dataloaders = create_data_loaders_from_dataflow(dataflow_args, use_x_mark=use_x_mark)
    else:
        dataloaders = create_data_loaders(
            data_path=data_path,
            batch_size=config.get('batch_size', 32),
            use_x_mark=use_x_mark,
        )

    # An empty loader would otherwise surface later as a ZeroDivisionError when
    # averaging per-batch losses.
    for split, loader in dataloaders.items():
        if len(loader) == 0:
            raise ValueError(
                f"The '{split}' split produced no batches; "
                f"check the dataset size against the batch size."
            )
    return dataloaders


def _unpack_batch(batch_data, device):
    """Move a batch to ``device`` as float32 tensors.

    Returns ``(inputs, targets, x_mark, y_mark)``; both marks are None when the
    batch has only two items (no time features).
    """
    if len(batch_data) == 4:
        inputs, targets, x_mark, y_mark = batch_data
        x_mark = x_mark.float().to(device)
        y_mark = y_mark.float().to(device)
    else:
        inputs, targets = batch_data
        x_mark, y_mark = None, None
    return inputs.float().to(device), targets.float().to(device), x_mark, y_mark


def _forecast_forward(model, inputs, x_mark):
    """Run a forecasting model, passing encoder time features when available."""
    if x_mark is not None:
        return model(inputs, x_mark)
    return model(inputs)


def _evaluate_forecasting(model, loader, criterion, device) -> float:
    """Return the per-sample average of ``criterion`` over ``loader`` (no gradients).

    Each batch loss is weighted by its batch size so a smaller final batch does
    not skew the average. Returns NaN for an empty loader.
    """
    model.eval()
    loss_sum = 0.0
    n_samples = 0
    with torch.no_grad():
        for batch_data in loader:
            inputs, targets, x_mark, _ = _unpack_batch(batch_data, device)
            outputs = _forecast_forward(model, inputs, x_mark)
            batch_size = targets.size(0)
            loss_sum += criterion(outputs, targets).item() * batch_size
            n_samples += batch_size
    return loss_sum / n_samples if n_samples else float('nan')


def _classification_forward(model, inputs, targets, x_mark, y_mark):
    """Run a classification model, using the 4-argument call when time features exist."""
    if x_mark is not None:
        # All-zeros placeholder shaped like the targets for the 4-argument model call.
        dec_inp = torch.zeros_like(targets)
        return model(inputs, x_mark, dec_inp, y_mark)
    return model(inputs)


def _evaluate_classification(model, loader, criterion, device) -> Tuple[float, float]:
    """Return ``(per-sample average loss, accuracy in percent)`` over ``loader``.

    Returns ``(nan, nan)`` for an empty loader.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_data in loader:
            inputs, targets, x_mark, y_mark = _unpack_batch(batch_data, device)
            targets = targets.long()  # CrossEntropyLoss expects integer class indices
            outputs = _classification_forward(model, inputs, targets, x_mark, y_mark)
            batch_size = targets.size(0)
            loss_sum += criterion(outputs, targets).item() * batch_size
            _, predicted = outputs.max(1)
            correct += predicted.eq(targets).sum().item()
            total += batch_size
    if total == 0:
        return float('nan'), float('nan')
    return loss_sum / total, 100. * correct / total


def train_forecasting_model(
    model_constructor: Callable,
    data_path: str,
    project_name: str,
    config: Dict[str, Any],
    device: Optional[str] = None,
    early_stopping_patience: int = 10,
    max_epochs: int = 100,
    checkpoint_dir: str = "./checkpoints",
    log_interval: int = 10,
    use_x_mark: bool = True,
    dataset_mode: str = "npz",
    dataflow_args=None
) -> Tuple[nn.Module, Dict[str, float]]:
    """
    Train a time series forecasting model

    Args:
        model_constructor (Callable): Function that constructs and returns the model
        data_path (str): Path to the NPZ file containing the processed data (for npz mode)
        project_name (str): Name of the project for swanlab tracking
        config (Dict[str, Any]): Configuration dictionary for the experiment. Keys read:
            'learning_rate' (default 1e-3), 'batch_size' (npz mode only, default 32),
            'lr_step_size' (default 5) and 'lr_gamma' (default 0.5) for the StepLR scheduler
        device (Optional[str]): Device to use for training ('cpu' or 'cuda')
        early_stopping_patience (int): Number of epochs to wait before early stopping
        max_epochs (int): Maximum number of epochs to train for
        checkpoint_dir (str): Directory to save model checkpoints
        log_interval (int): How often (in batches) to log metrics during training;
            values <= 0 disable batch-level logging
        use_x_mark (bool): Whether to use time features (x_mark) from the data file
        dataset_mode (str): Dataset construction mode - "npz" or "dataflow"
        dataflow_args: Arguments object for dataflow mode (required if dataset_mode="dataflow")

    Returns:
        Tuple[nn.Module, Dict[str, float]]: Trained model (best checkpoint reloaded) and
        the metrics of the checkpointed epoch plus 'final_val_loss' and 'final_test_loss'
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    swanlab_run = swanlab.init(
        project=project_name,
        config=config,
    )

    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_path = os.path.join(checkpoint_dir, f"{project_name}.pt")

    dataloaders = _build_dataloaders(data_path, config, use_x_mark, dataset_mode, dataflow_args)
    train_loader = dataloaders['train']
    num_train_batches = len(train_loader)

    model = model_constructor().to(device)

    criterion = nn.MSELoss()
    optimizer = optim.Adam(
        model.parameters(),
        lr=config.get('learning_rate', 1e-3),
    )
    # By default, halve the learning rate every 5 epochs.
    scheduler = optim.lr_scheduler.StepLR(
        optimizer,
        step_size=config.get('lr_step_size', 5),
        gamma=config.get('lr_gamma', 0.5),
    )
    early_stopping = EarlyStopping(
        patience=early_stopping_patience,
        verbose=True,
        path=checkpoint_path
    )

    metrics = {}
    for epoch in range(max_epochs):
        print(f"Epoch {epoch+1}/{max_epochs}")

        # ---- Training phase ----
        model.train()
        train_loss = 0.0
        interval_loss = 0.0  # loss accumulated since the last batch-level log
        start_time = time.time()

        for batch_idx, batch_data in enumerate(train_loader):
            inputs, targets, x_mark, _ = _unpack_batch(batch_data, device)

            optimizer.zero_grad()
            outputs = _forecast_forward(model, inputs, x_mark)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            train_loss += batch_loss
            interval_loss += batch_loss

            if log_interval > 0 and (batch_idx + 1) % log_interval == 0:
                print(f"Batch {batch_idx+1}/{num_train_batches}, Loss: {batch_loss:.4f}")
                # Log the mean loss over this interval, then start a new interval.
                swanlab_run.log({"batch_train_loss": interval_loss / log_interval})
                interval_loss = 0.0

        avg_train_loss = train_loss / num_train_batches
        epoch_time = time.time() - start_time

        # ---- Validation phase ----
        avg_val_loss = _evaluate_forecasting(model, dataloaders['val'], criterion, device)
        current_lr = optimizer.param_groups[0]['lr']

        metrics_dict = {
            "train_loss": avg_train_loss,
            "val_loss": avg_val_loss,
            "learning_rate": current_lr,
            "epoch_time": epoch_time
        }
        swanlab_run.log(metrics_dict)

        print(f"Epoch {epoch+1}/{max_epochs}, "
              f"Train Loss: {avg_train_loss:.4f}, "
              f"Val Loss: {avg_val_loss:.4f}, "
              f"LR: {current_lr:.6f}, "
              f"Time: {epoch_time:.2f}s")

        early_stopping(avg_val_loss, model)
        # counter == 0 right after the call means this epoch was just checkpointed,
        # so these metrics describe the model that is reloaded after training.
        if early_stopping.counter == 0:
            metrics = dict(metrics_dict)
        if early_stopping.early_stop:
            print("Early stopping triggered")
            break

        scheduler.step()

    # Reload the best checkpoint. best_score is still None only if no epoch ran,
    # in which case any file at checkpoint_path would be stale or missing.
    if early_stopping.best_score is not None:
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))

    print("Evaluating on test set...")
    test_loss = _evaluate_forecasting(model, dataloaders['test'], criterion, device)
    print("Test evaluation completed!")
    print(f"Test Loss (MSE): {test_loss:.6f}")

    # Re-evaluate validation loss with the reloaded best model for consistency.
    final_val_loss = _evaluate_forecasting(model, dataloaders['val'], criterion, device)
    print(f"Final validation loss: {final_val_loss:.6f}")

    swanlab_run.log({
        "final_test_loss": test_loss,
        "final_val_loss": final_val_loss
    })

    metrics["final_val_loss"] = final_val_loss
    metrics["final_test_loss"] = test_loss

    swanlab_run.finish()

    return model, metrics


def train_classification_model(
    model_constructor: Callable,
    data_path: str,
    project_name: str,
    config: Dict[str, Any],
    device: Optional[str] = None,
    early_stopping_patience: int = 10,
    max_epochs: int = 100,
    checkpoint_dir: str = "./checkpoints",
    log_interval: int = 10,
    use_x_mark: bool = True,
    dataset_mode: str = "npz",
    dataflow_args=None
) -> Tuple[nn.Module, Dict[str, float]]:
    """
    Train a time series classification model

    Args:
        model_constructor (Callable): Function that constructs and returns the model
        data_path (str): Path to the NPZ file containing the processed data (for npz mode)
        project_name (str): Name of the project for swanlab tracking
        config (Dict[str, Any]): Configuration dictionary for the experiment. Keys read:
            'learning_rate' (default 1e-3), 'weight_decay' (default 1e-4),
            'batch_size' (npz mode only, default 32), 'lr_step_size' (default 1)
            and 'lr_gamma' (default 0.5) for the StepLR scheduler
        device (Optional[str]): Device to use for training ('cpu' or 'cuda')
        early_stopping_patience (int): Number of epochs to wait before early stopping
        max_epochs (int): Maximum number of epochs to train for
        checkpoint_dir (str): Directory to save model checkpoints
        log_interval (int): How often (in batches) to print the training loss;
            values <= 0 disable batch-level printing
        use_x_mark (bool): Whether to use time features (x_mark) from the data file
        dataset_mode (str): Dataset construction mode - "npz" or "dataflow"
        dataflow_args: Arguments object for dataflow mode (required if dataset_mode="dataflow")

    Returns:
        Tuple[nn.Module, Dict[str, float]]: Trained model (best checkpoint reloaded) and
        the metrics of the checkpointed epoch plus 'final_val_loss' and 'final_val_accuracy'
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    swanlab_run = swanlab.init(
        project=project_name,
        config=config,
    )

    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_path = os.path.join(checkpoint_dir, f"{project_name}.pt")

    dataloaders = _build_dataloaders(data_path, config, use_x_mark, dataset_mode, dataflow_args)
    train_loader = dataloaders['train']
    num_train_batches = len(train_loader)

    model = model_constructor().to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(
        model.parameters(),
        lr=config.get('learning_rate', 1e-3),
        weight_decay=config.get('weight_decay', 1e-4)
    )
    # By default, halve the learning rate after every epoch.
    scheduler = optim.lr_scheduler.StepLR(
        optimizer,
        step_size=config.get('lr_step_size', 1),
        gamma=config.get('lr_gamma', 0.5),
    )
    early_stopping = EarlyStopping(
        patience=early_stopping_patience,
        verbose=True,
        path=checkpoint_path
    )

    metrics = {}
    for epoch in range(max_epochs):
        print(f"Epoch {epoch+1}/{max_epochs}")

        # ---- Training phase ----
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0
        start_time = time.time()

        for batch_idx, batch_data in enumerate(train_loader):
            inputs, targets, x_mark, y_mark = _unpack_batch(batch_data, device)
            targets = targets.long()  # CrossEntropyLoss expects integer class indices

            optimizer.zero_grad()
            outputs = _classification_forward(model, inputs, targets, x_mark, y_mark)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            train_loss += batch_loss
            _, predicted = outputs.max(1)
            train_total += targets.size(0)
            train_correct += predicted.eq(targets).sum().item()

            if log_interval > 0 and (batch_idx + 1) % log_interval == 0:
                print(f"Batch {batch_idx+1}/{num_train_batches}, Loss: {batch_loss:.4f}")

        avg_train_loss = train_loss / num_train_batches
        train_accuracy = 100. * train_correct / train_total
        epoch_time = time.time() - start_time

        # ---- Validation phase ----
        avg_val_loss, val_accuracy = _evaluate_classification(
            model, dataloaders['val'], criterion, device
        )
        current_lr = optimizer.param_groups[0]['lr']

        metrics_dict = {
            "train_loss": avg_train_loss,
            "train_accuracy": train_accuracy,
            "val_loss": avg_val_loss,
            "val_accuracy": val_accuracy,
            "learning_rate": current_lr,
            "epoch_time": epoch_time
        }
        swanlab_run.log(metrics_dict)

        print(f"Epoch {epoch+1}/{max_epochs}, "
              f"Train Loss: {avg_train_loss:.4f}, "
              f"Val Loss: {avg_val_loss:.4f}, "
              f"Val Accuracy: {val_accuracy:.2f}%, "
              f"LR: {current_lr:.6f}, "
              f"Time: {epoch_time:.2f}s")

        early_stopping(avg_val_loss, model)
        # counter == 0 right after the call means this epoch was just checkpointed,
        # so these metrics describe the model that is reloaded after training.
        if early_stopping.counter == 0:
            metrics = dict(metrics_dict)
        if early_stopping.early_stop:
            print("Early stopping triggered")
            break

        scheduler.step()

    # Reload the best checkpoint. best_score is still None only if no epoch ran,
    # in which case any file at checkpoint_path would be stale or missing.
    if early_stopping.best_score is not None:
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))

    final_val_loss, final_val_accuracy = _evaluate_classification(
        model, dataloaders['val'], criterion, device
    )
    print(f"Final validation loss: {final_val_loss:.4f}")
    print(f"Final validation accuracy: {final_val_accuracy:.2f}%")

    swanlab_run.log({
        "final_val_loss": final_val_loss,
        "final_val_accuracy": final_val_accuracy
    })

    metrics["final_val_loss"] = final_val_loss
    metrics["final_val_accuracy"] = final_val_accuracy

    swanlab_run.finish()

    return model, metrics


def main():
    """Example usage: train a small MLP forecaster on 'data/train_data.npz'."""
    data_path = 'data/train_data.npz'
    project_name = 'TimeSeriesForecasting'

    config = {
        'learning_rate': 0.001,
        'batch_size': 32,
        'weight_decay': 1e-4
    }

    # Assumes the input's last dimension is 10 -- adjust to match the data file.
    model_constructor = lambda: nn.Sequential(
        nn.Linear(10, 50),
        nn.ReLU(),
        nn.Linear(50, 1)
    )

    model, metrics = train_forecasting_model(
        model_constructor=model_constructor,
        data_path=data_path,
        project_name=project_name,
        config=config
    )


if __name__ == "__main__":
    main()