210 lines
7.3 KiB
Python
210 lines
7.3 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Training script for TimesNet model on ETT datasets.
|
|
"""
|
|
|
|
import os
|
|
import math
|
|
import argparse
|
|
import torch
|
|
import torch.nn as nn
|
|
from train.train import train_forecasting_model
|
|
from models.TimesNet.TimesNet import Model as TimesNet
|
|
|
|
class Args:
    """Configuration namespace handed to the TimesNet ``Model`` constructor.

    Instances are plain attribute bags: ``TimesNet(args)`` reads the fields
    set here. The attribute names presumably match what
    ``models/TimesNet/TimesNet.py`` expects (verify there).

    Args:
        seq_len: Input (look-back) window length.
        pred_len: Forecast horizon.
        enc_in: Number of input channels/features (C).
        c_out: Number of output channels.
        d_model: Model width. Defaults to 64, the fixed value this script
            has always used. Pass ``None`` to use the formula-derived value
            ``min(max(2 * ceil(log2 C), d_min), d_max)`` instead.
        d_ff: Feed-forward width. Defaults to 64.
    """

    def __init__(self, seq_len, pred_len, enc_in, c_out, d_model=64, d_ff=64):
        # Task / shape parameters.
        self.task_name = 'long_term_forecast'
        self.seq_len = seq_len
        self.label_len = seq_len // 2  # half of seq_len used as label length
        self.pred_len = pred_len
        self.enc_in = enc_in
        self.c_out = c_out

        # TimesNet-specific hyper-parameters ("as specified" by the experiment setup).
        self.top_k = 5
        self.e_layers = 2
        self.d_min = 32
        self.d_max = 512

        # Formula-derived width: min{max{2*ceil(log2 C), d_min}, d_max}.
        # log2(1) == 0, so a single channel (or C <= 1) falls back to 1.
        log_c = math.ceil(math.log2(enc_in)) if enc_in > 1 else 1
        self.d_model_formula = min(max(2 * log_c, self.d_min), self.d_max)
        # The fixed default (64) is used unless the caller opts into the formula.
        self.d_model = self.d_model_formula if d_model is None else d_model

        # Other model parameters.
        self.d_ff = d_ff  # equals d_model by default (not the usual 4x transformer ratio)
        self.num_kernels = 6  # for Inception blocks
        self.embed = 'timeF'  # time-feature embedding type
        # NOTE(review): an earlier comment said "minutely for ETT", which does not
        # match 'h' (conventionally the hourly code). Kept as-is; confirm intent.
        self.freq = 'h'
        self.dropout = 0.1

        source = "formula" if d_model is None else "fixed"
        print("Model configuration:")
        print(f" - Input channels (C): {enc_in}")
        print(
            f" - d_model: {self.d_model} ({source}; formula "
            f"min(max(2*⌈log₂({enc_in})⌉, {self.d_min}), {self.d_max}) = {self.d_model_formula})"
        )
        print(f" - Sequence length: {seq_len}")
        print(f" - Prediction length: {pred_len}")
        print(f" - Top-k: {self.top_k}")
        print(f" - Layers: {self.e_layers}")
|
|
|
|
def create_timesnet_model(args):
    """Build a zero-argument factory for a TimesNet model.

    Construction is deferred until the returned callable is invoked (the
    trainer decides when). The factory closes over ``args`` by reference, so
    any later mutation of ``args`` (e.g. ``pred_len``) is seen at build time.

    Args:
        args: Configuration object passed verbatim to ``TimesNet``.

    Returns:
        A callable taking no arguments that returns ``TimesNet(args)``.
    """
    return lambda: TimesNet(args)
|
|
|
|
def train_single_dataset(data_path, dataset_name, pred_len, args):
    """Train TimesNet on a single dataset / prediction-length combination.

    Args:
        data_path: Path to the NPZ file for this configuration.
        dataset_name: Dataset label (e.g. ``'ETTm1'``), recorded in the
            config and used in the tracking project name.
        pred_len: Forecast horizon; written onto ``args.pred_len``.
        args: Model configuration passed to the TimesNet constructor.
            Mutated in place (``pred_len``).

    Returns:
        ``(model, metrics)`` as returned by ``train_forecasting_model``, or
        ``(None, None)`` if training raised an exception.
    """
    # Keep the model configuration in sync with the requested horizon.
    args.pred_len = pred_len

    model_constructor = create_timesnet_model(args)

    config = {
        # NOTE(review): the spec comment said "LR = 10^-4", but the value used
        # is 1e-5. Kept unchanged to preserve results; confirm which is intended.
        'learning_rate': 1e-5,
        'batch_size': 32,  # BatchSize 32 as specified
        'weight_decay': 1e-4,
        'dataset': dataset_name,
        'pred_len': pred_len,
        'seq_len': args.seq_len,
        'd_model': args.d_model,
        'top_k': args.top_k,
        'e_layers': args.e_layers
    }

    # Project name for tracking.
    project_name = f"TimesNet_{dataset_name}_pred{pred_len}"

    print(f"\n{'='*60}")
    print(f"Training {dataset_name} with prediction length {pred_len}")
    print(f"Data path: {data_path}")
    print(f"{'='*60}")

    # Only the training call is guarded: one failing configuration should not
    # abort the sweep in main(), but reporting a result must never be able to
    # turn a successful run into a reported failure.
    try:
        model, metrics = train_forecasting_model(
            model_constructor=model_constructor,
            data_path=data_path,
            project_name=project_name,
            config=config,
            # NOTE(review): patience equals max_epochs, so early stopping
            # likely never triggers; confirm this is intended.
            early_stopping_patience=10,
            max_epochs=10,  # epochs 10 as specified
            checkpoint_dir="./checkpoints",
            log_interval=50
        )
    except Exception as e:
        import traceback

        print(f"Error training {project_name}: {e}")
        traceback.print_exc()
        return None, None

    print(f"Training completed for {project_name}")

    # The metric may be missing or non-numeric; the previous ":.6f" on the
    # 'N/A' fallback raised ValueError and discarded the trained model.
    final_val = metrics.get('final_val_loss') if hasattr(metrics, 'get') else None
    try:
        shown = f"{final_val:.6f}"
    except (TypeError, ValueError):
        shown = 'N/A' if final_val is None else str(final_val)
    print(f"Final validation MSE: {shown}")

    return model, metrics
|
|
|
|
def main():
    """Command-line entry point: train TimesNet over datasets x horizons.

    For every ``(dataset, pred_len)`` pair it looks for
    ``<data_dir>/<dataset>_input<seq_len>_pred<pred_len>.npz``. It infers the
    channel count from the last axis of that file's ``train_x`` array,
    trains via ``train_single_dataset``, and finally prints a summary.
    Missing files (or files without ``train_x``) are skipped with a warning.
    """
    import numpy as np  # only needed here; imported once, not per iteration

    parser = argparse.ArgumentParser(description='Train TimesNet on ETT datasets')
    parser.add_argument('--data_dir', type=str, default='processed_data',
                        help='Directory containing processed NPZ files')
    parser.add_argument('--datasets', nargs='+', default=['ETTm1', 'ETTm2'],
                        help='List of datasets to train on')
    parser.add_argument('--pred_lengths', nargs='+', type=int, default=[96, 192, 336, 720],
                        help='List of prediction lengths to train on')
    parser.add_argument('--seq_len', type=int, default=96,
                        help='Input sequence length')
    parser.add_argument('--device', type=str, default=None,
                        help='Device to use for training (cuda/cpu)')

    args = parser.parse_args()

    print("TimesNet Training Script")
    print("=" * 50)
    print(f"Datasets: {args.datasets}")
    print(f"Prediction lengths: {args.pred_lengths}")
    print(f"Input sequence length: {args.seq_len}")
    print(f"Data directory: {args.data_dir}")

    # isdir (not exists): a regular file at this path cannot hold the data.
    if not os.path.isdir(args.data_dir):
        print(f"Error: Data directory '{args.data_dir}' not found!")
        return

    if args.device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    else:
        device = args.device
    # NOTE(review): `device` is only reported; it is never passed to
    # train_forecasting_model, so --device has no effect on training here.
    print(f"Using device: {device}")

    # Only metrics/paths are kept: holding every trained model would keep all
    # of them (and their parameters) alive until the sweep finishes.
    all_results = {}

    for dataset in args.datasets:
        all_results[dataset] = {}

        for pred_len in args.pred_lengths:
            data_file = f"{dataset}_input{args.seq_len}_pred{pred_len}.npz"
            data_path = os.path.join(args.data_dir, data_file)

            if not os.path.exists(data_path):
                print(f"Warning: Data file '{data_path}' not found, skipping...")
                continue

            # np.load on an .npz returns a lazily-read archive that holds an
            # open file handle; the context manager closes it.
            # NOTE(review): allow_pickle=True permits unpickling object arrays,
            # which can execute arbitrary code; only load trusted files.
            with np.load(data_path, allow_pickle=True) as data:
                if 'train_x' not in data.files:
                    print(f"Warning: Data file '{data_path}' has no 'train_x' array, skipping...")
                    continue
                enc_in = data['train_x'].shape[-1]  # number of features/channels

            print("输入数据通道数:", enc_in)  # "number of input channels"
            c_out = enc_in  # forecast the same channels that are input

            model_args = Args(
                seq_len=args.seq_len,
                pred_len=pred_len,
                enc_in=enc_in,
                c_out=c_out
            )

            _model, metrics = train_single_dataset(
                data_path=data_path,
                dataset_name=dataset,
                pred_len=pred_len,
                args=model_args
            )

            all_results[dataset][pred_len] = {
                'metrics': metrics,
                'data_path': data_path
            }

    print("\n" + "=" * 80)
    print("TRAINING SUMMARY")
    print("=" * 80)

    for dataset, per_len in all_results.items():
        print(f"\n{dataset}:")
        for pred_len, result in per_len.items():
            metrics = result['metrics']
            if metrics is None:
                print(f" Pred Length {pred_len}: Training failed")
                continue
            # NOTE(review): train_single_dataset reads 'final_val_loss' while
            # this summary read 'final_val_mse'; prefer the latter and fall
            # back to the former. Confirm which key the trainer returns.
            mse = metrics.get('final_val_mse', metrics.get('final_val_loss', 'N/A'))
            print(f" Pred Length {pred_len}: MSE = {mse}")

    print(f"\nAll models saved in: ./checkpoints/")
    print("Training completed!")


if __name__ == "__main__":
    main()
|