#!/usr/bin/env python3
"""
Training script for the xPatch_SparseChannel model on multiple datasets.
Supports the Weather, Traffic, Electricity, Exchange, and ILI datasets
with sigmoid learning rate adjustment.
"""

import os
import math
import argparse

import torch
import torch.nn as nn

from train.train import train_forecasting_model
from models.xPatch_SparseChannel.xPatch import Model as xPatchSparseChannel

# Dataset configurations
DATASET_CONFIGS = {
    'Weather': {
        'csv_file': 'weather.csv',
        'enc_in': 21,
        'batch_size': 2048,
        'learning_rate': 0.0005,
        'target': 'OT',
        'data_path': 'weather/weather.csv',
        'seq_len': 96,
        'pred_lengths': [96, 192, 336, 720]
    },
    'Traffic': {
        'csv_file': 'traffic.csv',
        'enc_in': 862,
        'batch_size': 32,
        'learning_rate': 0.002,
        'target': 'OT',
        'data_path': 'traffic/traffic.csv',
        'seq_len': 96,
        'pred_lengths': [96, 192, 336, 720]
    },
    'Electricity': {
        'csv_file': 'electricity.csv',
        'enc_in': 321,
        'batch_size': 32,
        'learning_rate': 0.001,
        'target': 'OT',
        'data_path': 'electricity/electricity.csv',
        'seq_len': 96,
        'pred_lengths': [96, 192, 336, 720]
    },
    'Exchange': {
        'csv_file': 'exchange_rate.csv',
        'enc_in': 8,
        'batch_size': 128,
        'learning_rate': 0.00005,
        'target': 'OT',
        'data_path': 'exchange_rate/exchange_rate.csv',
        'seq_len': 96,
        'pred_lengths': [96, 192, 336, 720]
    },
    'ILI': {
        'csv_file': 'national_illness.csv',
        'enc_in': 7,
        'batch_size': 32,
        'learning_rate': 0.01,
        'target': 'OT',  # 'OT' is the target column name used by all benchmark CSVs
        'data_path': 'illness/national_illness.csv',
        'seq_len': 36,
        'pred_lengths': [24, 36, 48, 60]
    }
}


class Args:
    """Configuration class for xPatch_SparseChannel model parameters."""

    def __init__(self, dataset_name, pred_len):
        dataset_config = DATASET_CONFIGS[dataset_name]

        # Model architecture parameters
        self.task_name = 'long_term_forecast'
        self.seq_len = dataset_config['seq_len']  # Use dataset-specific seq_len
        self.label_len = self.seq_len // 2        # Half of seq_len as label length
        self.pred_len = pred_len
        self.enc_in = dataset_config['enc_in']
        self.c_out = dataset_config['enc_in']

        # xPatch-specific parameters from the reference implementation
        self.patch_len = 16         # patch length
        self.stride = 8             # stride
        self.padding_patch = 'end'  # padding at the end

        # Moving average parameters
        self.ma_type = 'ema'  # moving average type
        self.alpha = 0.3      # alpha parameter for EMA
        self.beta = 0.3       # beta parameter for DEMA

        # RevIN normalization
        self.revin = 1  # RevIN; True 1, False 0

        # Time features (not used by xPatch but required by the data loader)
        self.embed = 'timeF'  # Time feature embedding type
        self.freq = 'h'       # Frequency for time features (hourly)

        # Dataset-specific parameters
        self.data = 'custom'
        self.root_path = './data/'
        self.data_path = dataset_config['data_path']
        self.features = 'M'  # Multivariate prediction
        self.target = dataset_config['target']  # Target column
        self.train_only = False

        # Required for dataflow - will be set by config
        self.batch_size = dataset_config['batch_size']
        self.num_workers = 8  # Will be overridden by config

        print(f"xPatch_SparseChannel model configuration for {dataset_name}:")
        print(f"  - Input channels (C): {self.enc_in}")
        print(f"  - Patch length: {self.patch_len}")
        print(f"  - Stride: {self.stride}")
        print(f"  - Sequence length: {self.seq_len}")  # Now dataset-specific
        print(f"  - Prediction length: {pred_len}")
        print(f"  - Moving average type: {self.ma_type}")
        print(f"  - Alpha: {self.alpha}")
        print(f"  - Beta: {self.beta}")
        print(f"  - RevIN: {self.revin}")
        print(f"  - Target: {self.target}")
        print(f"  - Batch size: {self.batch_size}")


def create_xpatch_sparse_model(args):
    """Create an xPatch_SparseChannel model constructor for the given configuration."""
    def model_constructor():
        return xPatchSparseChannel(args)
    return model_constructor


def train_single_dataset(dataset_name, pred_len, model_args, cmd_args, use_ps_loss=True):
    """Train xPatch_SparseChannel on the specified dataset with the given prediction length."""
    dataset_config = DATASET_CONFIGS[dataset_name]

    # Update args for the current prediction length
    model_args.pred_len = pred_len

    # Update dataflow parameters from command-line args
    model_args.num_workers = cmd_args.num_workers

    # Create model constructor
    model_constructor = create_xpatch_sparse_model(model_args)

    # Training configuration with dataset-specific parameters
    config = {
        'learning_rate': dataset_config['learning_rate'],  # Dataset-specific learning rate
        'batch_size': dataset_config['batch_size'],        # Dataset-specific batch size
        'weight_decay': 1e-4,
        'dataset': dataset_name,
        'pred_len': pred_len,
        'seq_len': model_args.seq_len,
        'patch_len': model_args.patch_len,
        'stride': model_args.stride,
        'ma_type': model_args.ma_type,
        'use_ps_loss': use_ps_loss,
        'num_workers': cmd_args.num_workers,
        'pin_memory': True,
        'persistent_workers': True
    }

    # Project name for tracking
    loss_suffix = "_PSLoss" if use_ps_loss else "_MSE"
    project_name = f"xPatch_SparseChannel_{dataset_name}_pred{pred_len}{loss_suffix}_sigmoid"

    print(f"\n{'=' * 60}")
    print(f"Training {dataset_name} with prediction length {pred_len}")
    print("Model: xPatch_SparseChannel")
    print(f"Loss function: {'PS_Loss' if use_ps_loss else 'MSE'}")
    print(f"Learning rate: {dataset_config['learning_rate']}")
    print(f"Batch size: {dataset_config['batch_size']}")
    print(f"Features: {dataset_config['enc_in']}")
    print(f"Data path: {model_args.root_path}{model_args.data_path}")
    print("LR adjustment: sigmoid")
    print(f"{'=' * 60}")

    # Train the model
    try:
        model, metrics = train_forecasting_model(
            model_constructor=model_constructor,
            data_path=f"{model_args.root_path}{model_args.data_path}",
            project_name=project_name,
            config=config,
            early_stopping_patience=5,
            max_epochs=100,
            checkpoint_dir="./checkpoints",
            log_interval=50,
            use_x_mark=False,  # xPatch_SparseChannel doesn't use time features
            use_ps_loss=use_ps_loss,
            ps_lambda=cmd_args.ps_lambda,
            patch_len_threshold=64,
            use_gdw=True,
            dataset_mode="dataflow",
            dataflow_args=model_args,
            lr_adjust_strategy="sigmoid"  # Use sigmoid learning rate adjustment
        )

        print(f"Training completed for {project_name}")

        # Guard the numeric formatting: if the metric key is missing, formatting
        # a fallback string with ':.6f' would raise a ValueError.
        metric_key = 'final_val_mse' if use_ps_loss else 'final_val_loss'
        final_val_mse = metrics.get(metric_key)
        if final_val_mse is not None:
            print(f"Final validation MSE: {final_val_mse:.6f}")
        else:
            print("Final validation MSE: N/A")

        return model, metrics

    except Exception as e:
        print(f"Error training {project_name}: {e}")
        import traceback
        traceback.print_exc()
        return None, None


def main():
    parser = argparse.ArgumentParser(
        description='Train xPatch_SparseChannel on multiple datasets with sigmoid LR adjustment')
    parser.add_argument('--datasets', nargs='+', type=str,
                        default=['Weather', 'Traffic', 'Electricity', 'Exchange', 'ILI'],
                        choices=['Weather', 'Traffic', 'Electricity', 'Exchange', 'ILI'],
                        help='List of datasets to train on')
    # Note: with action='store_true' and default=True this flag is effectively always
    # on; call train_single_dataset(..., use_ps_loss=False) to train with plain MSE.
    parser.add_argument('--use_ps_loss', action='store_true', default=True,
                        help='Use PS_Loss instead of MSE')
    parser.add_argument('--ps_lambda', type=float, default=5.0,
                        help='Weight for PS loss component')
    parser.add_argument('--device', type=str, default=None,
                        help='Device to use for training (cuda/cpu)')
    parser.add_argument('--num_workers', type=int, default=8,
                        help='Number of data loading workers')

    args = parser.parse_args()

    print("xPatch_SparseChannel Multi-Dataset Training Script with Sigmoid LR Adjustment")
    print("=" * 80)
    print(f"Datasets: {args.datasets}")
    print(f"Use PS_Loss: {args.use_ps_loss}")
    print(f"PS_Lambda: {args.ps_lambda}")
    print(f"Number of workers: {args.num_workers}")
    print("Learning rate adjustment: sigmoid")

    # Display dataset configurations
    print("\nDataset Configurations:")
    for dataset in args.datasets:
        config = DATASET_CONFIGS[dataset]
        print(f"  {dataset}:")
        print(f"    - Features: {config['enc_in']}")
        print(f"    - Batch size: {config['batch_size']}")
        print(f"    - Learning rate: {config['learning_rate']}")
        print(f"    - Sequence length: {config['seq_len']}")
        print(f"    - Prediction lengths: {config['pred_lengths']}")
        print(f"    - Data path: {config['data_path']}")

    # Check whether all data files exist before starting any training
    missing_datasets = []
    for dataset in args.datasets:
        config = DATASET_CONFIGS[dataset]
        data_path = f"./data/{config['data_path']}"
        if not os.path.exists(data_path):
            missing_datasets.append(f"{dataset}: '{data_path}'")

    if missing_datasets:
        print("\nError: The following dataset files were not found:")
        for missing in missing_datasets:
            print(f"  - {missing}")
        print("Please ensure all dataset files are available in the data/ directory.")
        return

    # Set device (informational here; device placement is assumed to be handled
    # inside train_forecasting_model)
    if args.device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    else:
        device = args.device
    print(f"\nUsing device: {device}")

    # Training results storage
    all_results = {}

    # Train on each dataset
    for dataset in args.datasets:
        print(f"\n{'#' * 80}")
        print(f"STARTING TRAINING ON {dataset.upper()} DATASET")
        print(f"{'#' * 80}")

        all_results[dataset] = {}
        config = DATASET_CONFIGS[dataset]

        # Train on each prediction length for this dataset
        for pred_len in config['pred_lengths']:  # Dataset-specific prediction lengths
            # Create model configuration for the current dataset
            model_args = Args(dataset_name=dataset, pred_len=pred_len)

            # Train the model
            model, metrics = train_single_dataset(
                dataset_name=dataset,
                pred_len=pred_len,
                model_args=model_args,
                cmd_args=args,
                use_ps_loss=args.use_ps_loss
            )

            # Store results
            all_results[dataset][pred_len] = {
                'model': model,
                'metrics': metrics,
                'data_path': f"./data/{config['data_path']}"
            }

    # Print comprehensive summary
    print("\n" + "=" * 100)
    print("COMPREHENSIVE TRAINING SUMMARY")
    print("=" * 100)

    for dataset in args.datasets:
        config = DATASET_CONFIGS[dataset]
        print(f"\n{dataset} (Features: {config['enc_in']}, Batch: {config['batch_size']}, "
              f"LR: {config['learning_rate']}, Seq: {config['seq_len']}):")
        for pred_len in all_results[dataset]:
            result = all_results[dataset][pred_len]
            if result['metrics'] is not None:
                if args.use_ps_loss:
                    mse = result['metrics'].get('final_val_mse', 'N/A')
                else:
                    mse = result['metrics'].get('final_val_loss', 'N/A')
                print(f"  Pred Length {pred_len}: MSE = {mse}")
            else:
                print(f"  Pred Length {pred_len}: Training failed")

    print("\nAll models saved in: ./checkpoints/")
    print("All datasets training completed!")


if __name__ == "__main__":
    main()