tsmodel/train_xpatch_sparse_multi_datasets.py
#!/usr/bin/env python3
"""
Training script for xPatch_SparseChannel model on multiple datasets.
Supports Weather, Traffic, Electricity, Exchange, and ILI datasets with sigmoid learning rate adjustment.
"""
import os
import argparse
import torch
from train.train import train_forecasting_model
from models.xPatch_SparseChannel.xPatch import Model as xPatchSparseChannel
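# Note on the sigmoid learning-rate adjustment: the schedule itself is implemented
# inside train.train.train_forecasting_model (selected below via
# lr_adjust_strategy="sigmoid"). A common form of such a schedule, shown here only
# as an illustrative sketch rather than the verified implementation, is
#   lr_epoch = base_lr / (1 + exp(k * (epoch - midpoint)))
# i.e. the rate stays near base_lr in early epochs and decays smoothly afterwards.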
# Dataset configurations
DATASET_CONFIGS = {
'Weather': {
'csv_file': 'weather.csv',
'enc_in': 21,
'batch_size': 2048,
'learning_rate': 0.0005,
'target': 'OT',
'data_path': 'weather/weather.csv',
'seq_len': 96,
'pred_lengths': [96, 192, 336, 720]
},
'Traffic': {
'csv_file': 'traffic.csv',
'enc_in': 862,
'batch_size': 32,
'learning_rate': 0.002,
'target': 'OT',
'data_path': 'traffic/traffic.csv',
'seq_len': 96,
'pred_lengths': [96, 192, 336, 720]
},
'Electricity': {
'csv_file': 'electricity.csv',
'enc_in': 321,
'batch_size': 32,
'learning_rate': 0.001,
'target': 'OT',
'data_path': 'electricity/electricity.csv',
'seq_len': 96,
'pred_lengths': [96, 192, 336, 720]
},
'Exchange': {
'csv_file': 'exchange_rate.csv',
'enc_in': 8,
'batch_size': 128,
'learning_rate': 0.00005,
'target': 'OT',
'data_path': 'exchange_rate/exchange_rate.csv',
'seq_len': 96,
'pred_lengths': [96, 192, 336, 720]
},
'ILI': {
'csv_file': 'national_illness.csv',
'enc_in': 7,
'batch_size': 32,
'learning_rate': 0.01,
        'target': 'OT',
'data_path': 'illness/national_illness.csv',
'seq_len': 36,
'pred_lengths': [24, 36, 48, 60]
}
}
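# The entries above follow the common long-term-forecasting benchmark setup:
# Weather, Traffic, Electricity and Exchange use an input window of 96 with horizons
# {96, 192, 336, 720}, while the weekly ILI data uses a window of 36 with horizons
# {24, 36, 48, 60}. Example lookup:
#   DATASET_CONFIGS['Weather']['enc_in']  # -> 21 input channels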
class Args:
"""Configuration class for xPatch_SparseChannel model parameters."""
def __init__(self, dataset_name, pred_len):
dataset_config = DATASET_CONFIGS[dataset_name]
# Model architecture parameters
self.task_name = 'long_term_forecast'
self.seq_len = dataset_config['seq_len'] # Use dataset-specific seq_len
self.label_len = self.seq_len // 2 # Half of seq_len as label length
self.pred_len = pred_len
self.enc_in = dataset_config['enc_in']
self.c_out = dataset_config['enc_in']
# xPatch specific parameters from reference
self.patch_len = 16 # patch length
self.stride = 8 # stride
self.padding_patch = 'end' # padding on the end
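        # Patch count sketch (assuming the usual PatchTST-style patching, not
        # re-derived here): n_patches = (seq_len - patch_len) // stride + 1,
        # plus one extra patch when padding_patch == 'end'.
        # For seq_len=96: (96 - 16) // 8 + 1 = 11 patches, 12 with end padding;
        # for ILI's seq_len=36: (36 - 16) // 8 + 1 = 3 patches, 4 with end padding.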
# Moving Average parameters
self.ma_type = 'ema' # moving average type
self.alpha = 0.3 # alpha parameter for EMA
self.beta = 0.3 # beta parameter for DEMA
# RevIN normalization
self.revin = 1 # RevIN; True 1 False 0
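        # Decomposition sketch: with ma_type='ema' the trend follows the recurrence
        #   s_t = alpha * x_t + (1 - alpha) * s_{t-1}
        # (larger alpha tracks recent values more closely) and the seasonal stream is
        # the residual x_t - s_t; beta is only consumed when ma_type='dema'.
        # RevIN normalises each input window by its own statistics and de-normalises
        # the forecast, which mitigates distribution shift between windows.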
# Time features (not used by xPatch but required by data loader)
self.embed = 'timeF' # Time feature embedding type
self.freq = 'h' # Frequency for time features (hourly)
# Dataset specific parameters
self.data = 'custom'
self.root_path = './data/'
self.data_path = dataset_config['data_path']
self.features = 'M' # Multivariate prediction
self.target = dataset_config['target'] # Target column
self.train_only = False
# Required for dataflow - will be set by config
self.batch_size = dataset_config['batch_size']
        self.num_workers = 8  # Overridden from the --num_workers command-line argument in train_single_dataset
print(f"xPatch_SparseChannel Model configuration for {dataset_name}:")
print(f" - Input channels (C): {self.enc_in}")
print(f" - Patch length: {self.patch_len}")
print(f" - Stride: {self.stride}")
print(f" - Sequence length: {self.seq_len}") # Now dataset-specific
print(f" - Prediction length: {pred_len}")
print(f" - Moving average type: {self.ma_type}")
print(f" - Alpha: {self.alpha}")
print(f" - Beta: {self.beta}")
print(f" - RevIN: {self.revin}")
print(f" - Target: {self.target}")
print(f" - Batch size: {self.batch_size}")
def create_xpatch_sparse_model(args):
"""Create xPatch_SparseChannel model with given configuration."""
def model_constructor():
return xPatchSparseChannel(args)
return model_constructor
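# Illustrative use of the zero-argument constructor returned above, so the training
# utility can build (or rebuild) the model lazily:
#   build = create_xpatch_sparse_model(Args('Weather', 96))
#   model = build()  # fresh xPatch_SparseChannel instance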
def train_single_dataset(dataset_name, pred_len, model_args, cmd_args, use_ps_loss=True):
"""Train xPatch_SparseChannel on specified dataset with given prediction length."""
dataset_config = DATASET_CONFIGS[dataset_name]
# Update args for current prediction length
model_args.pred_len = pred_len
# Update dataflow parameters from command line args
model_args.num_workers = cmd_args.num_workers
# Create model constructor
model_constructor = create_xpatch_sparse_model(model_args)
# Training configuration with dataset-specific parameters
config = {
'learning_rate': dataset_config['learning_rate'], # Dataset-specific learning rate
'batch_size': dataset_config['batch_size'], # Dataset-specific batch size
'weight_decay': 1e-4,
'dataset': dataset_name,
'pred_len': pred_len,
'seq_len': model_args.seq_len,
'patch_len': model_args.patch_len,
'stride': model_args.stride,
'ma_type': model_args.ma_type,
'use_ps_loss': use_ps_loss,
'num_workers': cmd_args.num_workers,
'pin_memory': True,
'persistent_workers': True
}
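    # Loss sketch (the exact objective lives in train.train, so this is an assumption
    # for orientation only): with use_ps_loss the trainer is expected to minimise
    # MSE + ps_lambda * patch-structure term, otherwise plain MSE.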
# Project name for tracking
loss_suffix = "_PSLoss" if use_ps_loss else "_MSE"
project_name = f"xPatch_SparseChannel_{dataset_name}_pred{pred_len}{loss_suffix}_sigmoid"
print(f"\n{'='*60}")
print(f"Training {dataset_name} with prediction length {pred_len}")
print(f"Model: xPatch_SparseChannel")
print(f"Loss function: {'PS_Loss' if use_ps_loss else 'MSE'}")
print(f"Learning rate: {dataset_config['learning_rate']}")
print(f"Batch size: {dataset_config['batch_size']}")
print(f"Features: {dataset_config['enc_in']}")
print(f"Data path: {model_args.root_path}{model_args.data_path}")
print(f"LR adjustment: sigmoid")
print(f"{'='*60}")
# Train the model
try:
model, metrics = train_forecasting_model(
model_constructor=model_constructor,
data_path=f"{model_args.root_path}{model_args.data_path}",
project_name=project_name,
config=config,
early_stopping_patience=5,
max_epochs=100,
checkpoint_dir="./checkpoints",
log_interval=50,
use_x_mark=False, # xPatch_SparseChannel doesn't use time features
use_ps_loss=use_ps_loss,
ps_lambda=cmd_args.ps_lambda,
patch_len_threshold=64,
use_gdw=True,
dataset_mode="dataflow",
dataflow_args=model_args,
lr_adjust_strategy="sigmoid" # Use sigmoid learning rate adjustment
)
print(f"Training completed for {project_name}")
        mse_key = 'final_val_mse' if use_ps_loss else 'final_val_loss'
        final_mse = metrics.get(mse_key)
        if final_mse is not None:
            print(f"Final validation MSE: {final_mse:.6f}")
        else:
            print("Final validation MSE: N/A")
return model, metrics
except Exception as e:
print(f"Error training {project_name}: {e}")
import traceback
traceback.print_exc()
return None, None
def main():
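    """Train xPatch_SparseChannel across the selected benchmark datasets.

    Illustrative invocations (flags as defined below):
        python train_xpatch_sparse_multi_datasets.py --datasets Weather Exchange
        python train_xpatch_sparse_multi_datasets.py --datasets ILI --ps_lambda 3.0 --num_workers 4
    """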
parser = argparse.ArgumentParser(description='Train xPatch_SparseChannel on multiple datasets with sigmoid LR adjustment')
parser.add_argument('--datasets', nargs='+', type=str,
default=['Weather', 'Traffic', 'Electricity', 'Exchange', 'ILI'],
choices=['Weather', 'Traffic', 'Electricity', 'Exchange', 'ILI'],
help='List of datasets to train on')
    parser.add_argument('--use_ps_loss', action=argparse.BooleanOptionalAction, default=True,
                        help='Use PS_Loss instead of MSE (disable with --no-use_ps_loss)')
parser.add_argument('--ps_lambda', type=float, default=5.0,
help='Weight for PS loss component')
parser.add_argument('--device', type=str, default=None,
help='Device to use for training (cuda/cpu)')
parser.add_argument('--num_workers', type=int, default=8,
help='Number of data loading workers')
args = parser.parse_args()
print("xPatch_SparseChannel Multi-Dataset Training Script with Sigmoid LR Adjustment")
print("=" * 80)
print(f"Datasets: {args.datasets}")
print(f"Use PS_Loss: {args.use_ps_loss}")
print(f"PS_Lambda: {args.ps_lambda}")
print(f"Number of workers: {args.num_workers}")
print(f"Learning rate adjustment: sigmoid")
# Display dataset configurations
print("\nDataset Configurations:")
for dataset in args.datasets:
config = DATASET_CONFIGS[dataset]
print(f" {dataset}:")
print(f" - Features: {config['enc_in']}")
print(f" - Batch size: {config['batch_size']}")
print(f" - Learning rate: {config['learning_rate']}")
print(f" - Sequence length: {config['seq_len']}")
print(f" - Prediction lengths: {config['pred_lengths']}")
print(f" - Data path: {config['data_path']}")
# Check if data files exist
missing_datasets = []
for dataset in args.datasets:
config = DATASET_CONFIGS[dataset]
data_path = f"./data/{config['data_path']}"
if not os.path.exists(data_path):
missing_datasets.append(f"{dataset}: '{data_path}'")
if missing_datasets:
print(f"\nError: The following dataset files were not found:")
for missing in missing_datasets:
print(f" - {missing}")
print("Please ensure all dataset files are available in the data/ directory.")
return
# Set device
if args.device is None:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
else:
device = args.device
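    # Note: `device` is only reported below; the --device flag is not forwarded to
    # train_forecasting_model, which is assumed to handle device placement itself.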
print(f"\nUsing device: {device}")
# Training results storage
all_results = {}
# Train on each dataset
for dataset in args.datasets:
print(f"\n{'#'*80}")
print(f"STARTING TRAINING ON {dataset.upper()} DATASET")
print(f"{'#'*80}")
all_results[dataset] = {}
config = DATASET_CONFIGS[dataset]
# Train on each prediction length for this dataset
for pred_len in config['pred_lengths']: # Use dataset-specific prediction lengths
# Create model configuration for current dataset
model_args = Args(
dataset_name=dataset,
pred_len=pred_len
)
# Train the model
model, metrics = train_single_dataset(
dataset_name=dataset,
pred_len=pred_len,
model_args=model_args,
cmd_args=args,
use_ps_loss=args.use_ps_loss
)
# Store results
all_results[dataset][pred_len] = {
'model': model,
'metrics': metrics,
'data_path': f"./data/{config['data_path']}"
}
# Print comprehensive summary
print("\n" + "=" * 100)
print("COMPREHENSIVE TRAINING SUMMARY")
print("=" * 100)
for dataset in args.datasets:
config = DATASET_CONFIGS[dataset]
print(f"\n{dataset} (Features: {config['enc_in']}, Batch: {config['batch_size']}, LR: {config['learning_rate']}, Seq: {config['seq_len']}):")
for pred_len in all_results[dataset]:
result = all_results[dataset][pred_len]
if result['metrics'] is not None:
if args.use_ps_loss:
mse = result['metrics'].get('final_val_mse', 'N/A')
else:
mse = result['metrics'].get('final_val_loss', 'N/A')
print(f" Pred Length {pred_len}: MSE = {mse}")
else:
print(f" Pred Length {pred_len}: Training failed")
print(f"\nAll models saved in: ./checkpoints/")
print("All datasets training completed!")
if __name__ == "__main__":
main()