feat: add DiffusionTimeSeries and iTransformer models, introduce xPatch_SparseChannel
train_xpatch_sparse_multi_datasets.py (new file, 329 lines)
@@ -0,0 +1,329 @@
#!/usr/bin/env python3
"""
Training script for the xPatch_SparseChannel model on multiple datasets.

Supports the Weather, Traffic, Electricity, Exchange, and ILI datasets with
sigmoid learning-rate adjustment.
"""

import argparse
import math
import os

import torch

from models.xPatch_SparseChannel.xPatch import Model as xPatchSparseChannel
from train.train import train_forecasting_model
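
# The "sigmoid" learning-rate strategy is implemented inside train.train and
# selected below via lr_adjust_strategy="sigmoid". As a rough, hedged
# illustration of the idea (an assumption about the shape, not the trainer's
# actual code): the LR holds near its base value early on, then decays
# smoothly through a midpoint epoch.
def _sigmoid_lr_sketch(base_lr: float, epoch: int,
                       midpoint: int = 10, k: float = 0.5) -> float:
    """Illustrative only: lr(epoch) = base_lr / (1 + exp(k * (epoch - midpoint)))."""
    return base_lr / (1.0 + math.exp(k * (epoch - midpoint)))
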
# Dataset configurations
DATASET_CONFIGS = {
    'Weather': {
        'csv_file': 'weather.csv',
        'enc_in': 21,
        'batch_size': 2048,
        'learning_rate': 0.0005,
        'target': 'OT',
        'data_path': 'weather/weather.csv',
        'seq_len': 96,
        'pred_lengths': [96, 192, 336, 720]
    },
    'Traffic': {
        'csv_file': 'traffic.csv',
        'enc_in': 862,
        'batch_size': 32,
        'learning_rate': 0.002,
        'target': 'OT',
        'data_path': 'traffic/traffic.csv',
        'seq_len': 96,
        'pred_lengths': [96, 192, 336, 720]
    },
    'Electricity': {
        'csv_file': 'electricity.csv',
        'enc_in': 321,
        'batch_size': 32,
        'learning_rate': 0.001,
        'target': 'OT',
        'data_path': 'electricity/electricity.csv',
        'seq_len': 96,
        'pred_lengths': [96, 192, 336, 720]
    },
    'Exchange': {
        'csv_file': 'exchange_rate.csv',
        'enc_in': 8,
        'batch_size': 128,
        'learning_rate': 0.00005,
        'target': 'OT',
        'data_path': 'exchange_rate/exchange_rate.csv',
        'seq_len': 96,
        'pred_lengths': [96, 192, 336, 720]
    },
    'ILI': {
        'csv_file': 'national_illness.csv',
        'enc_in': 7,
        'batch_size': 32,
        'learning_rate': 0.01,
        'target': 'OT',
        'data_path': 'illness/national_illness.csv',
        'seq_len': 36,
        'pred_lengths': [24, 36, 48, 60]
    }
}
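
# enc_in is the number of series columns in each CSV (e.g. 862 road-occupancy
# sensors for Traffic, 21 weather variables); batch sizes shrink as the
# channel count grows to keep memory use comparable. ILI is weekly data,
# hence the shorter seq_len and prediction horizons.
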
class Args:
    """Configuration class for xPatch_SparseChannel model parameters."""

    def __init__(self, dataset_name, pred_len):
        dataset_config = DATASET_CONFIGS[dataset_name]

        # Model architecture parameters
        self.task_name = 'long_term_forecast'
        self.seq_len = dataset_config['seq_len']  # dataset-specific seq_len
        self.label_len = self.seq_len // 2        # half of seq_len as label length
        self.pred_len = pred_len
        self.enc_in = dataset_config['enc_in']
        self.c_out = dataset_config['enc_in']

        # xPatch-specific parameters from the reference implementation
        self.patch_len = 16         # patch length
        self.stride = 8             # stride
        self.padding_patch = 'end'  # pad at the end

        # Moving-average (decomposition) parameters
        self.ma_type = 'ema'  # moving average type
        self.alpha = 0.3      # alpha parameter for EMA
        self.beta = 0.3       # beta parameter for DEMA

        # RevIN normalization
        self.revin = 1  # RevIN: 1 = on, 0 = off

        # Time features (not used by xPatch but required by the data loader)
        self.embed = 'timeF'  # time-feature embedding type
        self.freq = 'h'       # frequency for time features (hourly)

        # Dataset-specific parameters
        self.data = 'custom'
        self.root_path = './data/'
        self.data_path = dataset_config['data_path']
        self.features = 'M'  # multivariate prediction
        self.target = dataset_config['target']  # target column
        self.train_only = False

        # Required by the dataflow loader
        self.batch_size = dataset_config['batch_size']
        self.num_workers = 8  # overridden from the command line

        print(f"xPatch_SparseChannel Model configuration for {dataset_name}:")
        print(f"  - Input channels (C): {self.enc_in}")
        print(f"  - Patch length: {self.patch_len}")
        print(f"  - Stride: {self.stride}")
        print(f"  - Sequence length: {self.seq_len}")
        print(f"  - Prediction length: {pred_len}")
        print(f"  - Moving average type: {self.ma_type}")
        print(f"  - Alpha: {self.alpha}")
        print(f"  - Beta: {self.beta}")
        print(f"  - RevIN: {self.revin}")
        print(f"  - Target: {self.target}")
        print(f"  - Batch size: {self.batch_size}")
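
# Hedged sketch of the decomposition the ma_type/alpha settings above refer
# to. The actual smoothing lives inside the xPatch model; this helper only
# illustrates the EMA recurrence (an assumption, not the model's code). For
# ma_type='dema', beta smooths a second pass in the same spirit, and revin=1
# analogously toggles instance normalization (normalize each input window,
# de-normalize the forecast).
def _ema_trend_sketch(x: torch.Tensor, alpha: float = 0.3) -> torch.Tensor:
    """Illustrative EMA trend, s_t = alpha * x_t + (1 - alpha) * s_{t-1}.

    x: [batch, seq_len, channels]; returns a tensor of the same shape.
    """
    s = x[:, 0, :]
    smoothed = [s]
    for t in range(1, x.size(1)):
        s = alpha * x[:, t, :] + (1 - alpha) * s
        smoothed.append(s)
    return torch.stack(smoothed, dim=1)
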
def create_xpatch_sparse_model(args):
    """Create an xPatch_SparseChannel model constructor with the given configuration."""
    def model_constructor():
        return xPatchSparseChannel(args)
    return model_constructor
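
# The zero-argument closure defers model construction to
# train_forecasting_model, which can then build a fresh, identically
# configured model for each run (one per dataset/prediction-length pair)
# without this script managing instantiation itself.
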
def train_single_dataset(dataset_name, pred_len, model_args, cmd_args, use_ps_loss=True):
    """Train xPatch_SparseChannel on the specified dataset with the given prediction length."""

    dataset_config = DATASET_CONFIGS[dataset_name]

    # Update args for the current prediction length
    model_args.pred_len = pred_len

    # Update dataflow parameters from command-line args
    model_args.num_workers = cmd_args.num_workers

    # Create the model constructor
    model_constructor = create_xpatch_sparse_model(model_args)

    # Training configuration with dataset-specific parameters
    config = {
        'learning_rate': dataset_config['learning_rate'],  # dataset-specific
        'batch_size': dataset_config['batch_size'],        # dataset-specific
        'weight_decay': 1e-4,
        'dataset': dataset_name,
        'pred_len': pred_len,
        'seq_len': model_args.seq_len,
        'patch_len': model_args.patch_len,
        'stride': model_args.stride,
        'ma_type': model_args.ma_type,
        'use_ps_loss': use_ps_loss,
        'num_workers': cmd_args.num_workers,
        'pin_memory': True,
        'persistent_workers': True
    }
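    # pin_memory speeds up host-to-GPU transfers, and persistent_workers keeps
    # DataLoader workers alive between epochs (it only takes effect when
    # num_workers > 0).
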
    # Project name for tracking
    loss_suffix = "_PSLoss" if use_ps_loss else "_MSE"
    project_name = f"xPatch_SparseChannel_{dataset_name}_pred{pred_len}{loss_suffix}_sigmoid"

    print(f"\n{'=' * 60}")
    print(f"Training {dataset_name} with prediction length {pred_len}")
    print("Model: xPatch_SparseChannel")
    print(f"Loss function: {'PS_Loss' if use_ps_loss else 'MSE'}")
    print(f"Learning rate: {dataset_config['learning_rate']}")
    print(f"Batch size: {dataset_config['batch_size']}")
    print(f"Features: {dataset_config['enc_in']}")
    print(f"Data path: {model_args.root_path}{model_args.data_path}")
    print("LR adjustment: sigmoid")
    print(f"{'=' * 60}")

    # Train the model
    try:
        model, metrics = train_forecasting_model(
            model_constructor=model_constructor,
            data_path=f"{model_args.root_path}{model_args.data_path}",
            project_name=project_name,
            config=config,
            early_stopping_patience=5,
            max_epochs=100,
            checkpoint_dir="./checkpoints",
            log_interval=50,
            use_x_mark=False,  # xPatch_SparseChannel doesn't use time features
            use_ps_loss=use_ps_loss,
            ps_lambda=cmd_args.ps_lambda,
            patch_len_threshold=64,
            use_gdw=True,
            dataset_mode="dataflow",
            dataflow_args=model_args,
            lr_adjust_strategy="sigmoid"  # sigmoid learning-rate adjustment
        )

        print(f"Training completed for {project_name}")
        # Guard against a missing metric: formatting 'N/A' with :.6f would raise.
        key = 'final_val_mse' if use_ps_loss else 'final_val_loss'
        val_mse = metrics.get(key)
        if val_mse is not None:
            print(f"Final validation MSE: {val_mse:.6f}")
        else:
            print("Final validation MSE: N/A")

        return model, metrics

    except Exception as e:
        print(f"Error training {project_name}: {e}")
        import traceback
        traceback.print_exc()
        return None, None
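
# Hedged sketch of the objective that use_ps_loss/ps_lambda select above. The
# real PS_Loss is implemented in train.train; this only illustrates the common
# "pointwise MSE + lambda * patch-level structural term" combination, with the
# patch statistic chosen here (per-patch means) as an assumption.
def _ps_loss_sketch(pred: torch.Tensor, target: torch.Tensor,
                    ps_lambda: float = 5.0, patch_len: int = 16) -> torch.Tensor:
    """pred/target: [batch, pred_len, channels]."""
    mse = torch.mean((pred - target) ** 2)
    n = (pred.size(1) // patch_len) * patch_len
    if n == 0:  # horizon shorter than one patch: fall back to plain MSE
        return mse
    # Compare per-patch means as a crude patch-structure signal.
    p = pred[:, :n, :].unfold(1, patch_len, patch_len).mean(dim=-1)
    t = target[:, :n, :].unfold(1, patch_len, patch_len).mean(dim=-1)
    return mse + ps_lambda * torch.mean((p - t) ** 2)
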
def main():
    parser = argparse.ArgumentParser(
        description='Train xPatch_SparseChannel on multiple datasets with sigmoid LR adjustment')
    parser.add_argument('--datasets', nargs='+', type=str,
                        default=['Weather', 'Traffic', 'Electricity', 'Exchange', 'ILI'],
                        choices=['Weather', 'Traffic', 'Electricity', 'Exchange', 'ILI'],
                        help='List of datasets to train on')
    # BooleanOptionalAction (Python 3.9+) provides both --use_ps_loss and
    # --no-use_ps_loss, so the default of True can actually be disabled.
    parser.add_argument('--use_ps_loss', action=argparse.BooleanOptionalAction, default=True,
                        help='Use PS_Loss instead of MSE')
    parser.add_argument('--ps_lambda', type=float, default=5.0,
                        help='Weight for the PS loss component')
    parser.add_argument('--device', type=str, default=None,
                        help='Device to use for training (cuda/cpu)')
    parser.add_argument('--num_workers', type=int, default=8,
                        help='Number of data-loading workers')

    args = parser.parse_args()

print("xPatch_SparseChannel Multi-Dataset Training Script with Sigmoid LR Adjustment")
|
||||
print("=" * 80)
|
||||
print(f"Datasets: {args.datasets}")
|
||||
print(f"Use PS_Loss: {args.use_ps_loss}")
|
||||
print(f"PS_Lambda: {args.ps_lambda}")
|
||||
print(f"Number of workers: {args.num_workers}")
|
||||
print(f"Learning rate adjustment: sigmoid")
|
||||
|
||||
# Display dataset configurations
|
||||
print("\nDataset Configurations:")
|
||||
for dataset in args.datasets:
|
||||
config = DATASET_CONFIGS[dataset]
|
||||
print(f" {dataset}:")
|
||||
print(f" - Features: {config['enc_in']}")
|
||||
print(f" - Batch size: {config['batch_size']}")
|
||||
print(f" - Learning rate: {config['learning_rate']}")
|
||||
print(f" - Sequence length: {config['seq_len']}")
|
||||
print(f" - Prediction lengths: {config['pred_lengths']}")
|
||||
print(f" - Data path: {config['data_path']}")
|
||||
|
||||
    # Check that the data files exist
    missing_datasets = []
    for dataset in args.datasets:
        config = DATASET_CONFIGS[dataset]
        data_path = f"./data/{config['data_path']}"
        if not os.path.exists(data_path):
            missing_datasets.append(f"{dataset}: '{data_path}'")

    if missing_datasets:
        print("\nError: the following dataset files were not found:")
        for missing in missing_datasets:
            print(f"  - {missing}")
        print("Please ensure all dataset files are available in the data/ directory.")
        return

    # Set device (note: `device` is only reported here; it is not forwarded
    # to the train_forecasting_model call)
    if args.device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    else:
        device = args.device
    print(f"\nUsing device: {device}")

    # Training results storage
    all_results = {}

    # Train on each dataset
    for dataset in args.datasets:
        print(f"\n{'#' * 80}")
        print(f"STARTING TRAINING ON {dataset.upper()} DATASET")
        print(f"{'#' * 80}")

        all_results[dataset] = {}
        config = DATASET_CONFIGS[dataset]

        # Train on each dataset-specific prediction length
        for pred_len in config['pred_lengths']:
            # Create the model configuration for the current dataset
            model_args = Args(
                dataset_name=dataset,
                pred_len=pred_len
            )

            # Train the model
            model, metrics = train_single_dataset(
                dataset_name=dataset,
                pred_len=pred_len,
                model_args=model_args,
                cmd_args=args,
                use_ps_loss=args.use_ps_loss
            )

            # Store results
            all_results[dataset][pred_len] = {
                'model': model,
                'metrics': metrics,
                'data_path': f"./data/{config['data_path']}"
            }

    # Print a comprehensive summary
    print("\n" + "=" * 100)
    print("COMPREHENSIVE TRAINING SUMMARY")
    print("=" * 100)

    for dataset in args.datasets:
        config = DATASET_CONFIGS[dataset]
        print(f"\n{dataset} (Features: {config['enc_in']}, Batch: {config['batch_size']}, "
              f"LR: {config['learning_rate']}, Seq: {config['seq_len']}):")
        for pred_len in all_results[dataset]:
            result = all_results[dataset][pred_len]
            if result['metrics'] is not None:
                if args.use_ps_loss:
                    mse = result['metrics'].get('final_val_mse', 'N/A')
                else:
                    mse = result['metrics'].get('final_val_loss', 'N/A')
                print(f"  Pred Length {pred_len}: MSE = {mse}")
            else:
                print(f"  Pred Length {pred_len}: Training failed")

    print("\nAll models saved in: ./checkpoints/")
    print("All datasets training completed!")

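# Example invocation (assumes the benchmark CSVs live under ./data/):
#   python train_xpatch_sparse_multi_datasets.py --datasets Weather ILI \
#       --ps_lambda 5.0 --num_workers 4
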
if __name__ == "__main__":
    main()