import os import time import numpy as np import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import DataLoader, TensorDataset import swanlab from typing import Dict, Any, Optional, Callable, Union, Tuple from dataflow import data_provider from layers.ps_loss import PSLoss from utils.tools import adjust_learning_rate, dotdict class EarlyStopping: """Early stopping to stop training when validation performance doesn't improve.""" def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt'): """ Args: patience (int): How long to wait after last improvement. Default: 7 verbose (bool): If True, prints a message for each improvement. Default: False delta (float): Minimum change in monitored quantity to qualify as improvement. Default: 0 path (str): Path for the checkpoint to be saved to. Default: 'checkpoint.pt' """ self.patience = patience self.verbose = verbose self.counter = 0 self.best_score = None self.early_stop = False self.val_loss_min = float('inf') self.delta = delta self.path = path def __call__(self, val_loss, model): score = -val_loss if self.best_score is None: self.best_score = score self.save_checkpoint(val_loss, model) elif score < self.best_score + self.delta: self.counter += 1 if self.verbose: print(f'EarlyStopping counter: {self.counter} out of {self.patience}') if self.counter >= self.patience: self.early_stop = True else: self.best_score = score self.save_checkpoint(val_loss, model) self.counter = 0 def save_checkpoint(self, val_loss, model): """Save model when validation loss decreases.""" if self.verbose: print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model...') torch.save(model.state_dict(), self.path) self.val_loss_min = val_loss class DatasetWrapperWithoutTimeFeatures(torch.utils.data.Dataset): """Wrapper to remove time features from dataflow datasets when use_x_mark=False""" def __init__(self, original_dataset): self.original_dataset = original_dataset def __getitem__(self, index): # Get original data (seq_x, seq_y, seq_x_mark, seq_y_mark) seq_x, seq_y, seq_x_mark, seq_y_mark = self.original_dataset[index] # Return only seq_x and seq_y (remove time features) return seq_x, seq_y def __len__(self): return len(self.original_dataset) def inverse_transform(self, data): if hasattr(self.original_dataset, 'inverse_transform'): return self.original_dataset.inverse_transform(data) return data def create_data_loaders_from_dataflow(args, use_x_mark: bool = True) -> Dict[str, DataLoader]: """ Create PyTorch DataLoaders using dataflow data_provider Args: args: Arguments object containing dataset configuration Required attributes: data, root_path, data_path, seq_len, label_len, pred_len, features, target, embed, freq, batch_size, num_workers, train_only use_x_mark (bool): Whether to use time features (x_mark and y_mark) Returns: Dict[str, DataLoader]: Dictionary with train, val, and test DataLoaders """ # Create datasets and dataloaders for each split train_data, _ = data_provider(args, flag='train') val_data, _ = data_provider(args, flag='val') test_data, _ = data_provider(args, flag='test') # Wrap datasets to respect use_x_mark parameter if not use_x_mark: train_data = DatasetWrapperWithoutTimeFeatures(train_data) val_data = DatasetWrapperWithoutTimeFeatures(val_data) test_data = DatasetWrapperWithoutTimeFeatures(test_data) # Determine batch size and other parameters based on flag train_shuffle = True val_shuffle = False test_shuffle = False train_drop_last = True val_drop_last = True test_drop_last = True batch_size = args.batch_size num_workers = args.num_workers # Create new dataloaders with wrapped datasets train_loader = DataLoader( train_data, batch_size=batch_size, shuffle=train_shuffle, num_workers=num_workers, drop_last=train_drop_last ) val_loader = DataLoader( val_data, batch_size=batch_size, shuffle=val_shuffle, num_workers=num_workers, drop_last=val_drop_last ) test_loader = DataLoader( test_data, batch_size=batch_size, shuffle=test_shuffle, num_workers=num_workers, drop_last=test_drop_last ) return { 'train': train_loader, 'val': val_loader, 'test': test_loader } def create_data_loaders(data_path: str, batch_size: int = 32, use_x_mark: bool = True, num_workers: int = 4, pin_memory: bool = True, persistent_workers: bool = True) -> Dict[str, DataLoader]: """ Create PyTorch DataLoaders from an NPZ file Args: data_path (str): Path to the NPZ file containing the data batch_size (int): Batch size for the DataLoaders use_x_mark (bool): Whether to use time features (x_mark) from the data file num_workers (int): Number of worker processes for data loading pin_memory (bool): Whether to pin memory for faster GPU transfer persistent_workers (bool): Whether to keep workers alive between epochs Returns: Dict[str, DataLoader]: Dictionary with train, val, and test DataLoaders """ # Load data from NPZ file data = np.load(data_path, allow_pickle=True) train_x = data['train_x'] train_y = data['train_y'] val_x = data['val_x'] val_y = data['val_y'] test_x = data['test_x'] test_y = data['test_y'] # Load time features if available and needed if use_x_mark: train_x_mark = data.get('train_x_mark', None) train_y_mark = data.get('train_y_mark', None) val_x_mark = data.get('val_x_mark', None) val_y_mark = data.get('val_y_mark', None) test_x_mark = data.get('test_x_mark', None) test_y_mark = data.get('test_y_mark', None) else: train_x_mark = None train_y_mark = None val_x_mark = None val_y_mark = None test_x_mark = None test_y_mark = None # Convert to PyTorch tensors train_x = torch.FloatTensor(train_x) train_y = torch.FloatTensor(train_y) val_x = torch.FloatTensor(val_x) val_y = torch.FloatTensor(val_y) test_x = torch.FloatTensor(test_x) test_y = torch.FloatTensor(test_y) # Create datasets based on whether time features are available if train_x_mark is not None: train_x_mark = torch.FloatTensor(train_x_mark) train_y_mark = torch.FloatTensor(train_y_mark) val_x_mark = torch.FloatTensor(val_x_mark) val_y_mark = torch.FloatTensor(val_y_mark) test_x_mark = torch.FloatTensor(test_x_mark) test_y_mark = torch.FloatTensor(test_y_mark) train_dataset = TensorDataset(train_x, train_y, train_x_mark, train_y_mark) val_dataset = TensorDataset(val_x, val_y, val_x_mark, val_y_mark) test_dataset = TensorDataset(test_x, test_y, test_x_mark, test_y_mark) else: train_dataset = TensorDataset(train_x, train_y) val_dataset = TensorDataset(val_x, val_y) test_dataset = TensorDataset(test_x, test_y) # Create dataloaders with performance optimizations train_loader = DataLoader( train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=pin_memory, persistent_workers=persistent_workers if num_workers > 0 else False, drop_last=True # Drop incomplete batches for training ) val_loader = DataLoader( val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin_memory, persistent_workers=persistent_workers if num_workers > 0 else False, drop_last=False ) test_loader = DataLoader( test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin_memory, persistent_workers=persistent_workers if num_workers > 0 else False, drop_last=False ) return { 'train': train_loader, 'val': val_loader, 'test': test_loader } def train_forecasting_model( model_constructor: Callable, data_path: str, project_name: str, config: Dict[str, Any], device: Optional[str] = None, early_stopping_patience: int = 10, max_epochs: int = 100, checkpoint_dir: str = "./checkpoints", log_interval: int = 10, use_x_mark: bool = True, dataset_mode: str = "npz", dataflow_args = None, use_ps_loss: bool = False, ps_lambda: float = 5.0, patch_len_threshold: int = 64, use_gdw: bool = True, lr_adjust_strategy: str = "type1" ) -> Tuple[nn.Module, Dict[str, float]]: """ Train a time series forecasting model Args: model_constructor (Callable): Function that constructs and returns the model data_path (str): Path to the NPZ file containing the processed data (for npz mode) project_name (str): Name of the project for swanlab tracking config (Dict[str, Any]): Configuration dictionary for the experiment device (Optional[str]): Device to use for training ('cpu' or 'cuda') early_stopping_patience (int): Number of epochs to wait before early stopping max_epochs (int): Maximum number of epochs to train for checkpoint_dir (str): Directory to save model checkpoints log_interval (int): How often to log metrics during training use_x_mark (bool): Whether to use time features (x_mark) from the data file dataset_mode (str): Dataset construction mode - "npz" or "dataflow" dataflow_args: Arguments object for dataflow mode (required if dataset_mode="dataflow") use_ps_loss (bool): Whether to use Patch-wise Structural (PS) loss instead of MSE ps_lambda (float): Weight for PS loss component when combined with MSE patch_len_threshold (int): Maximum patch length for adaptive patching use_gdw (bool): Whether to use Gradient-based Dynamic Weighting lr_adjust_strategy (str): Learning rate adjustment strategy - 'type1', 'type2', 'type3', 'sigmoid', 'constant', '3', '4', '5', '6' Returns: Tuple[nn.Module, Dict[str, float]]: Trained model and dictionary of evaluation metrics """ # Setup device if device is None: device = 'cuda' if torch.cuda.is_available() else 'cpu' # Initialize swanlab for experiment tracking swanlab_run = swanlab.init( project=project_name, config=config, ) # Create checkpoint directory if it doesn't exist os.makedirs(checkpoint_dir, exist_ok=True) checkpoint_path = os.path.join(checkpoint_dir, f"{project_name}.pt") # Create data loaders based on dataset_mode if dataset_mode == "dataflow": if dataflow_args is None: raise ValueError("dataflow_args is required when dataset_mode='dataflow'") dataloaders = create_data_loaders_from_dataflow( dataflow_args, use_x_mark=use_x_mark ) else: # Default to "npz" mode dataloaders = create_data_loaders( data_path=data_path, batch_size=config.get('batch_size', 32), use_x_mark=use_x_mark, num_workers=config.get('num_workers', 4), pin_memory=config.get('pin_memory', True), persistent_workers=config.get('persistent_workers', True) ) # Construct the model model = model_constructor() model = model.to(device) # Define loss function and optimizer if use_ps_loss: criterion = PSLoss( patch_len_threshold=patch_len_threshold, lambda_ps=ps_lambda, use_gdw=use_gdw ) else: criterion = nn.MSELoss() optimizer = optim.Adam( model.parameters(), lr=config.get('learning_rate', 1e-3), ) # Create args object for learning rate adjustment lr_args = dotdict({ 'learning_rate': config.get('learning_rate', 1e-3), 'lradj': lr_adjust_strategy }) # Initialize early stopping early_stopping = EarlyStopping( patience=early_stopping_patience, verbose=True, path=checkpoint_path ) # Training loop best_val_loss = float('inf') metrics = {} for epoch in range(max_epochs): print(f"Epoch {epoch+1}/{max_epochs}") # Training phase model.train() print("1\n") train_loss = 0.0 # 用于记录 log_interval 期间的损失 interval_loss = 0.0 start_time = time.time() for batch_idx, batch_data in enumerate(dataloaders['train']): # Handle both cases: with and without time features if len(batch_data) == 4: # With time features inputs, targets, x_mark, y_mark = batch_data inputs, targets = inputs.to(device), targets.to(device) x_mark, y_mark = x_mark.to(device), y_mark.to(device) else: # Without time features inputs, targets = batch_data inputs, targets = inputs.to(device), targets.to(device) x_mark, y_mark = None, None # Zero the parameter gradients optimizer.zero_grad() # Forward pass - handle both cases if x_mark is not None: # For TimesNet model with time features # Create decoder input (zeros for forecasting) outputs = model(inputs, x_mark) else: # For simple models without time features outputs = model(inputs) # Calculate loss if use_ps_loss: loss, loss_dict = criterion(outputs, targets, model) else: loss = criterion(outputs, targets) # Backward pass and optimize loss.backward() optimizer.step() # Update statistics train_loss += loss.item() interval_loss += loss.item() if (batch_idx + 1) % log_interval == 0: if use_ps_loss and 'loss_dict' in locals(): print(f"Batch {batch_idx+1}/{len(dataloaders['train'])}, " f"Total Loss: {loss.item():.4f}, " f"MSE: {loss_dict['mse_loss']:.4f}, " f"PS: {loss_dict['ps_loss']:.4f}") # Log detailed loss components swanlab_run.log({ "batch_total_loss": loss.item(), "batch_mse_loss": loss_dict['mse_loss'], "batch_ps_loss": loss_dict['ps_loss'], "batch_corr_loss": loss_dict['corr_loss'], "batch_var_loss": loss_dict['var_loss'], "batch_mean_loss": loss_dict['mean_loss'], "alpha": loss_dict['alpha'], "beta": loss_dict['beta'], "gamma": loss_dict['gamma'] }) else: print(f"Batch {batch_idx+1}/{len(dataloaders['train'])}, Loss: {loss.item():.4f}") swanlab_run.log({"batch_train_loss": loss.item()}) # 重置 interval loss 以进行下一次计算 interval_loss = 0.0 avg_train_loss = train_loss / len(dataloaders['train']) epoch_time = time.time() - start_time # Validation phase model.eval() val_loss = 0.0 val_mse = 0.0 val_mse_criterion = nn.MSELoss() # Always use MSE for validation metrics with torch.no_grad(): for batch_data in dataloaders['val']: # Handle both cases: with and without time features if len(batch_data) == 4: # With time features inputs, targets, x_mark, y_mark = batch_data inputs, targets = inputs.float().to(device), targets.float().to(device) x_mark, y_mark = x_mark.float().to(device), y_mark.float().to(device) else: # Without time features inputs, targets = batch_data inputs, targets = inputs.float().to(device), targets.float().to(device) x_mark, y_mark = None, None # Forward pass - handle both cases if x_mark is not None: # For TimesNet model with time features outputs = model(inputs, x_mark) else: # For simple models without time features outputs = model(inputs) # Calculate training loss (PS or MSE) if use_ps_loss: loss, _ = criterion(outputs, targets, model) val_loss += loss.item() else: loss = criterion(outputs, targets) val_loss += loss.item() # Always calculate MSE for validation metrics mse_loss = val_mse_criterion(outputs, targets) val_mse += mse_loss.item() avg_val_loss = val_loss / len(dataloaders['val']) avg_val_mse = val_mse / len(dataloaders['val']) current_lr = optimizer.param_groups[0]['lr'] # Log metrics metrics_dict = { "train_loss": avg_train_loss, "val_loss": avg_val_loss, "val_mse": avg_val_mse, "learning_rate": current_lr, "epoch_time": epoch_time } swanlab_run.log(metrics_dict) print(f"Epoch {epoch+1}/{max_epochs}, " f"Train Loss: {avg_train_loss:.4f}, " f"Val Loss: {avg_val_loss:.4f}, " f"Val MSE: {avg_val_mse:.4f}, " f"LR: {current_lr:.6f}, " f"Time: {epoch_time:.2f}s") # Check if we should save the model if avg_val_loss < best_val_loss: best_val_loss = avg_val_loss metrics = metrics_dict # Early stopping early_stopping(avg_val_loss, model) if early_stopping.early_stop: print("Early stopping triggered") break # Adjust learning rate using utils.tools function adjust_learning_rate(optimizer, epoch, lr_args) # Load the best model model.load_state_dict(torch.load(checkpoint_path)) # Test evaluation on the best model - Always use MSE for final evaluation model.eval() test_loss = 0.0 test_mse = 0.0 mse_criterion = nn.MSELoss() # Always use MSE for test evaluation print("Evaluating on test set...") with torch.no_grad(): for batch_data in dataloaders['test']: # Handle both cases: with and without time features if len(batch_data) == 4: # With time features inputs, targets, x_mark, y_mark = batch_data inputs, targets = inputs.float().to(device), targets.float().to(device) x_mark, y_mark = x_mark.float().to(device), y_mark.float().to(device) else: # Without time features inputs, targets = batch_data inputs, targets = inputs.float().to(device), targets.float().to(device) x_mark, y_mark = None, None # Forward pass - handle both cases if x_mark is not None: # For TimesNet model with time features outputs = model(inputs, x_mark) else: # For simple models without time features outputs = model(inputs) # Always calculate MSE for test evaluation (for fair comparison) mse_loss = mse_criterion(outputs, targets) test_loss += mse_loss.item() test_loss /= len(dataloaders['test']) print(f"Test evaluation completed!") print(f"Test Loss (MSE): {test_loss:.6f}") # Final validation for consistency - Always use MSE for final metrics model.eval() final_val_loss = 0.0 final_val_mse = 0.0 with torch.no_grad(): for batch_data in dataloaders['val']: # Handle both cases: with and without time features if len(batch_data) == 4: # With time features inputs, targets, x_mark, y_mark = batch_data inputs, targets = inputs.float().to(device), targets.float().to(device) x_mark, y_mark = x_mark.float().to(device), y_mark.float().to(device) else: # Without time features inputs, targets = batch_data inputs, targets = inputs.float().to(device), targets.float().to(device) x_mark, y_mark = None, None # Forward pass - handle both cases if x_mark is not None: # For TimesNet model with time features outputs = model(inputs, x_mark) else: # For simple models without time features outputs = model(inputs) # Always calculate MSE for final validation (for fair comparison) mse_loss = mse_criterion(outputs, targets) final_val_loss += mse_loss.item() final_val_loss /= len(dataloaders['val']) print(f"Final validation MSE: {final_val_loss:.6f}") print(f"Final test MSE: {test_loss:.6f}") if use_ps_loss: print("Note: Model was trained with PS Loss but evaluated with MSE for fair comparison") # Log final test results to swanlab final_metrics = { "final_test_mse": test_loss, "final_val_mse": final_val_loss } swanlab_run.log(final_metrics) # Update metrics with final values (always MSE for comparison) metrics["final_val_loss"] = final_val_loss metrics["final_test_loss"] = test_loss metrics["final_val_mse"] = final_val_loss # Same as final_val_loss since we use MSE metrics["final_test_mse"] = test_loss # Same as final_test_loss since we use MSE # Finish the swanlab run swanlab_run.finish() return model, metrics def train_classification_model( model_constructor: Callable, data_path: str, project_name: str, config: Dict[str, Any], device: Optional[str] = None, early_stopping_patience: int = 10, max_epochs: int = 100, checkpoint_dir: str = "./checkpoints", log_interval: int = 10, use_x_mark: bool = True, dataset_mode: str = "npz", dataflow_args = None, lr_adjust_strategy: str = "type1" ) -> Tuple[nn.Module, Dict[str, float]]: """ Train a time series classification model Args: model_constructor (Callable): Function that constructs and returns the model data_path (str): Path to the NPZ file containing the processed data (for npz mode) project_name (str): Name of the project for swanlab tracking config (Dict[str, Any]): Configuration dictionary for the experiment device (Optional[str]): Device to use for training ('cpu' or 'cuda') early_stopping_patience (int): Number of epochs to wait before early stopping max_epochs (int): Maximum number of epochs to train for checkpoint_dir (str): Directory to save model checkpoints log_interval (int): How often to log metrics during training use_x_mark (bool): Whether to use time features (x_mark) from the data file dataset_mode (str): Dataset construction mode - "npz" or "dataflow" dataflow_args: Arguments object for dataflow mode (required if dataset_mode="dataflow") lr_adjust_strategy (str): Learning rate adjustment strategy - 'type1', 'type2', 'type3', 'sigmoid', 'constant', '3', '4', '5', '6' Returns: Tuple[nn.Module, Dict[str, float]]: Trained model and dictionary of evaluation metrics """ # Setup device if device is None: device = 'cuda' if torch.cuda.is_available() else 'cpu' # Initialize swanlab for experiment tracking swanlab_run = swanlab.init( project=project_name, config=config, ) # Create checkpoint directory if it doesn't exist os.makedirs(checkpoint_dir, exist_ok=True) checkpoint_path = os.path.join(checkpoint_dir, f"{project_name}.pt") # Create data loaders based on dataset_mode if dataset_mode == "dataflow": if dataflow_args is None: raise ValueError("dataflow_args is required when dataset_mode='dataflow'") dataloaders = create_data_loaders_from_dataflow( dataflow_args, use_x_mark=use_x_mark ) else: # Default to "npz" mode dataloaders = create_data_loaders( data_path=data_path, batch_size=config.get('batch_size', 32), use_x_mark=use_x_mark, num_workers=config.get('num_workers', 4), pin_memory=config.get('pin_memory', True), persistent_workers=config.get('persistent_workers', True) ) # Construct the model model = model_constructor() model = model.to(device) # Define loss function and optimizer criterion = nn.CrossEntropyLoss() optimizer = optim.Adam( model.parameters(), lr=config.get('learning_rate', 1e-3), weight_decay=config.get('weight_decay', 1e-4) ) # Create args object for learning rate adjustment lr_args = dotdict({ 'learning_rate': config.get('learning_rate', 1e-3), 'lradj': lr_adjust_strategy }) # Initialize early stopping early_stopping = EarlyStopping( patience=early_stopping_patience, verbose=True, path=checkpoint_path ) # Training loop best_val_loss = float('inf') metrics = {} for epoch in range(max_epochs): print(f"Epoch {epoch+1}/{max_epochs}") # Training phase model.train() train_loss = 0.0 train_correct = 0 train_total = 0 start_time = time.time() for batch_idx, batch_data in enumerate(dataloaders['train']): # Handle both cases: with and without time features if len(batch_data) == 4: # With time features inputs, targets, x_mark, y_mark = batch_data inputs, targets = inputs.float().to(device), targets.float().to(device) x_mark, y_mark = x_mark.float().to(device), y_mark.float().to(device) else: # Without time features inputs, targets = batch_data inputs, targets = inputs.float().to(device), targets.float().to(device) x_mark, y_mark = None, None # Convert targets to long for classification targets = targets.long() # Zero the parameter gradients optimizer.zero_grad() # Forward pass - handle both cases if x_mark is not None: # For TimesNet model with time features dec_inp = torch.zeros_like(targets).to(device) outputs = model(inputs, x_mark, dec_inp, y_mark) else: # For simple models without time features outputs = model(inputs) loss = criterion(outputs, targets) # Backward pass and optimize loss.backward() optimizer.step() # Update statistics train_loss += loss.item() _, predicted = outputs.max(1) train_total += targets.size(0) train_correct += predicted.eq(targets).sum().item() if (batch_idx + 1) % log_interval == 0: print(f"Batch {batch_idx+1}/{len(dataloaders['train'])}, Loss: {loss.item():.4f}") avg_train_loss = train_loss / len(dataloaders['train']) train_accuracy = 100. * train_correct / train_total epoch_time = time.time() - start_time # Validation phase model.eval() val_loss = 0.0 val_correct = 0 val_total = 0 with torch.no_grad(): for batch_data in dataloaders['val']: # Handle both cases: with and without time features if len(batch_data) == 4: # With time features inputs, targets, x_mark, y_mark = batch_data inputs, targets = inputs.float().to(device), targets.float().to(device) x_mark, y_mark = x_mark.float().to(device), y_mark.float().to(device) else: # Without time features inputs, targets = batch_data inputs, targets = inputs.float().to(device), targets.float().to(device) x_mark, y_mark = None, None targets = targets.long() # Forward pass - handle both cases if x_mark is not None: # For TimesNet model with time features dec_inp = torch.zeros_like(targets).to(device) outputs = model(inputs, x_mark, dec_inp, y_mark) else: # For simple models without time features outputs = model(inputs) # Calculate loss loss = criterion(outputs, targets) val_loss += loss.item() # Calculate accuracy _, predicted = outputs.max(1) val_total += targets.size(0) val_correct += predicted.eq(targets).sum().item() avg_val_loss = val_loss / len(dataloaders['val']) val_accuracy = 100. * val_correct / val_total current_lr = optimizer.param_groups[0]['lr'] # Log metrics metrics_dict = { "train_loss": avg_train_loss, "val_loss": avg_val_loss, "val_accuracy": val_accuracy, "learning_rate": current_lr, "epoch_time": epoch_time } swanlab_run.log(metrics_dict) print(f"Epoch {epoch+1}/{max_epochs}, " f"Train Loss: {avg_train_loss:.4f}, " f"Val Loss: {avg_val_loss:.4f}, " f"Val Accuracy: {val_accuracy:.2f}%, " f"LR: {current_lr:.6f}, " f"Time: {epoch_time:.2f}s") # Check if we should save the model if avg_val_loss < best_val_loss: best_val_loss = avg_val_loss metrics = metrics_dict # Early stopping early_stopping(avg_val_loss, model) if early_stopping.early_stop: print("Early stopping triggered") break # Adjust learning rate using utils.tools function adjust_learning_rate(optimizer, epoch, lr_args) # Load the best model model.load_state_dict(torch.load(checkpoint_path)) # Final validation model.eval() final_val_loss = 0.0 final_val_correct = 0 final_val_total = 0 with torch.no_grad(): for batch_data in dataloaders['val']: # Handle both cases: with and without time features if len(batch_data) == 4: # With time features inputs, targets, x_mark, y_mark = batch_data inputs, targets = inputs.float().to(device), targets.float().to(device) x_mark, y_mark = x_mark.float().to(device), y_mark.float().to(device) else: # Without time features inputs, targets = batch_data inputs, targets = inputs.float().to(device), targets.float().to(device) x_mark, y_mark = None, None targets = targets.long() # Forward pass - handle both cases if x_mark is not None: # For TimesNet model with time features dec_inp = torch.zeros_like(targets).to(device) outputs = model(inputs, x_mark, dec_inp, y_mark) else: # For simple models without time features outputs = model(inputs) # Calculate loss loss = criterion(outputs, targets) final_val_loss += loss.item() # Calculate accuracy _, predicted = outputs.max(1) final_val_total += targets.size(0) final_val_correct += predicted.eq(targets).sum().item() final_val_loss /= len(dataloaders['val']) final_val_accuracy = 100. * final_val_correct / final_val_total print(f"Final validation loss: {final_val_loss:.4f}") print(f"Final validation accuracy: {final_val_accuracy:.2f}%") # Update metrics with final values metrics["final_val_loss"] = final_val_loss metrics["final_val_accuracy"] = final_val_accuracy # Finish the swanlab run swanlab_run.finish() return model, metrics def main(): # Example usage data_path = 'data/train_data.npz' project_name = 'TimeSeriesForecasting' config = { 'learning_rate': 0.001, 'batch_size': 32, 'weight_decay': 1e-4 } model_constructor = lambda: nn.Sequential( nn.Linear(10, 50), nn.ReLU(), nn.Linear(50, 1) ) model, metrics = train_forecasting_model( model_constructor=model_constructor, data_path=data_path, project_name=project_name, config=config ) if __name__ == "__main__": main()