first timesnet try
This commit is contained in:
209
train_test.py
Normal file
209
train_test.py
Normal file
@ -0,0 +1,209 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Training script for TimesNet model on ETT datasets.
|
||||
"""
|
||||
|
||||
import os
|
||||
import math
|
||||
import argparse
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from train.train import train_forecasting_model
|
||||
from models.TimesNet.TimesNet import Model as TimesNet
|
||||
|
||||
class Args:
    """Hyper-parameter container handed to the TimesNet model constructor.

    All attributes are plain values assigned once in ``__init__``. A short
    summary of the main settings is printed to stdout on construction.

    Args:
        seq_len: Input sequence length.
        pred_len: Prediction length.
        enc_in: Number of input channels.
        c_out: Number of output channels.
    """

    def __init__(self, seq_len: int, pred_len: int, enc_in: int, c_out: int) -> None:
        # Model architecture parameters
        self.task_name = 'long_term_forecast'
        self.seq_len = seq_len
        self.label_len = seq_len // 2  # Half of seq_len as label length
        self.pred_len = pred_len
        self.enc_in = enc_in
        self.c_out = c_out

        # TimesNet specific parameters
        self.top_k = 5      # k parameter as specified
        self.e_layers = 2   # Number of layers as specified
        self.d_min = 32     # dmin as specified
        self.d_max = 512    # dmax as specified

        # The formula min{max{2*ceil(log2 C), dmin}, dmax} is evaluated for
        # reference only: d_model is deliberately pinned to 64 below, so the
        # formula value is reported in the log but never applied.
        # NOTE(review): confirm whether the intended formula uses
        # 2**ceil(log2 C) rather than 2*ceil(log2 C) before enabling it.
        log_c = math.ceil(math.log2(enc_in)) if enc_in > 1 else 1
        formula_d_model = min(max(2 * log_c, self.d_min), self.d_max)
        self.d_model = 64

        # Other model parameters
        self.d_ff = 64          # Same value as d_model
        self.num_kernels = 6    # For Inception blocks
        self.embed = 'timeF'    # Time feature embedding type
        # NOTE(review): 'h' is passed as the time-feature frequency while the
        # default datasets are ETTm1/ETTm2 -- confirm 'h' is what is intended.
        self.freq = 'h'
        self.dropout = 0.1

        print("Model configuration:")
        print(f" - Input channels (C): {enc_in}")
        # Report d_model truthfully: it is fixed, not derived from enc_in.
        print(
            f" - d_model: {self.d_model} (fixed; formula "
            f"min(max(2*⌈log₂({enc_in})⌉, d_min), d_max) would give "
            f"{formula_d_model})"
        )
        print(f" - Sequence length: {seq_len}")
        print(f" - Prediction length: {pred_len}")
        print(f" - Top-k: {self.top_k}")
        print(f" - Layers: {self.e_layers}")
|
||||
|
||||
def create_timesnet_model(args):
    """Return a zero-argument factory that builds TimesNet from ``args``.

    Nothing is instantiated here: each call of the returned callable
    constructs a new ``TimesNet(args)``. ``args`` is captured by reference,
    so attribute changes made before the factory is invoked are seen by the
    model constructor.
    """
    return lambda: TimesNet(args)
|
||||
|
||||
def _format_final_val_loss(metrics):
    """Render ``metrics['final_val_loss']`` with 6 decimals, or ``'N/A'``.

    Never raises: a missing key, a non-numeric value, or a ``metrics`` object
    without ``.get`` all yield ``'N/A'``.
    """
    value = metrics.get('final_val_loss') if hasattr(metrics, 'get') else None
    try:
        return f"{value:.6f}"
    except (TypeError, ValueError):
        # None and str do not accept the 'f' presentation type.
        return 'N/A'


def train_single_dataset(data_path, dataset_name, pred_len, args):
    """Train TimesNet on a single dataset / prediction-length configuration.

    Args:
        data_path: Forwarded to ``train_forecasting_model`` as ``data_path``.
        dataset_name: Dataset label used in the run config and project name.
        pred_len: Prediction length; also written onto ``args.pred_len``.
        args: Model configuration object. It is mutated (``pred_len`` is
            overwritten) and captured by the model factory.

    Returns:
        The ``(model, metrics)`` pair returned by ``train_forecasting_model``,
        or ``(None, None)`` if that call raised.
    """
    # Update args for current prediction length
    args.pred_len = pred_len

    # Create model constructor
    model_constructor = create_timesnet_model(args)

    # Training configuration
    config = {
        # NOTE(review): the earlier comment said "LR = 10^-4 as specified",
        # but the value in use is 1e-5. The value is left unchanged here;
        # confirm which one is intended.
        'learning_rate': 1e-5,
        'batch_size': 32,  # BatchSize 32 as specified
        'weight_decay': 1e-4,
        'dataset': dataset_name,
        'pred_len': pred_len,
        'seq_len': args.seq_len,
        'd_model': args.d_model,
        'top_k': args.top_k,
        'e_layers': args.e_layers,
    }

    # Project name for tracking
    project_name = f"TimesNet_{dataset_name}_pred{pred_len}"

    print(f"\n{'='*60}")
    print(f"Training {dataset_name} with prediction length {pred_len}")
    print(f"Data path: {data_path}")
    print(f"{'='*60}")

    # Only the training call is guarded. Reporting happens afterwards so that
    # a logging problem can never turn a successful run into (None, None).
    try:
        model, metrics = train_forecasting_model(
            model_constructor=model_constructor,
            data_path=data_path,
            project_name=project_name,
            config=config,
            early_stopping_patience=10,
            max_epochs=10,  # epochs 10 as specified
            checkpoint_dir="./checkpoints",
            log_interval=50,
        )
    except Exception as e:
        # Deliberate best-effort: one failed configuration must not abort the
        # remaining (dataset, pred_len) runs driven by the caller.
        print(f"Error training {project_name}: {e}")
        return None, None

    print(f"Training completed for {project_name}")
    # Previously `'N/A'` was formatted with ':.6f' when the key was missing,
    # which raised ValueError and discarded the trained model.
    print(f"Final validation MSE: {_format_final_val_loss(metrics)}")

    return model, metrics
|
||||
|
||||
def _read_num_channels(data_path):
    """Return the last-axis size of ``train_x`` in the NPZ file at ``data_path``."""
    import numpy as np  # local import keeps this block self-contained

    # NOTE(security): allow_pickle=True lets numpy unpickle object arrays;
    # only point this at NPZ files you produced yourself.
    # The context manager closes the archive's file handle once the shape
    # has been read.
    with np.load(data_path, allow_pickle=True) as data:
        return data['train_x'].shape[-1]


def main():
    """Command-line entry point: train TimesNet on every requested configuration.

    Parses CLI options, then for each (dataset, prediction length) pair whose
    ``<dataset>_input<seq_len>_pred<pred_len>.npz`` file exists in the data
    directory, builds an ``Args`` configuration and calls
    ``train_single_dataset``. A per-configuration summary is printed at the
    end. Returns ``None``; returns early if the data directory is missing.
    """
    parser = argparse.ArgumentParser(description='Train TimesNet on ETT datasets')
    parser.add_argument('--data_dir', type=str, default='processed_data',
                        help='Directory containing processed NPZ files')
    parser.add_argument('--datasets', nargs='+', default=['ETTm1', 'ETTm2'],
                        help='List of datasets to train on')
    parser.add_argument('--pred_lengths', nargs='+', type=int, default=[96, 192, 336, 720],
                        help='List of prediction lengths to train on')
    parser.add_argument('--seq_len', type=int, default=96,
                        help='Input sequence length')
    parser.add_argument('--device', type=str, default=None,
                        help='Device to use for training (cuda/cpu)')

    args = parser.parse_args()

    print("TimesNet Training Script")
    print("=" * 50)
    print(f"Datasets: {args.datasets}")
    print(f"Prediction lengths: {args.pred_lengths}")
    print(f"Input sequence length: {args.seq_len}")
    print(f"Data directory: {args.data_dir}")

    # Check if data directory exists
    if not os.path.exists(args.data_dir):
        print(f"Error: Data directory '{args.data_dir}' not found!")
        return

    # Set device
    # NOTE(review): the chosen device is only reported here; it is not passed
    # on to train_single_dataset, so it does not influence training from this
    # function.
    if args.device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    else:
        device = args.device
    print(f"Using device: {device}")

    # Training results storage: dataset -> pred_len -> {'metrics', 'data_path'}.
    # Trained models are intentionally not retained: the summary below needs
    # only the metrics, and keeping every model alive grows memory with the
    # number of configurations.
    all_results = {}

    # Train on each dataset and prediction length combination
    for dataset in args.datasets:
        all_results[dataset] = {}

        for pred_len in args.pred_lengths:
            # Construct data file path
            data_file = f"{dataset}_input{args.seq_len}_pred{pred_len}.npz"
            data_path = os.path.join(args.data_dir, data_file)

            # Check if data file exists
            if not os.path.exists(data_path):
                print(f"Warning: Data file '{data_path}' not found, skipping...")
                continue

            # Load data to get input dimensions
            enc_in = _read_num_channels(data_path)  # Number of features/channels
            print("输入数据通道数:", enc_in)
            c_out = enc_in  # Output same number of channels

            # Create model configuration
            model_args = Args(
                seq_len=args.seq_len,
                pred_len=pred_len,
                enc_in=enc_in,
                c_out=c_out,
            )

            # Train the model
            _, metrics = train_single_dataset(
                data_path=data_path,
                dataset_name=dataset,
                pred_len=pred_len,
                args=model_args,
            )

            # Store results
            all_results[dataset][pred_len] = {
                'metrics': metrics,
                'data_path': data_path,
            }

    # Print summary
    print("\n" + "=" * 80)
    print("TRAINING SUMMARY")
    print("=" * 80)

    for dataset, per_horizon in all_results.items():
        print(f"\n{dataset}:")
        for pred_len, result in per_horizon.items():
            metrics = result['metrics']
            if metrics is not None:
                # train_single_dataset logs 'final_val_loss' while this summary
                # historically looked up 'final_val_mse'; accept either key so
                # the summary does not show N/A for a metric that exists.
                mse = metrics.get('final_val_mse',
                                  metrics.get('final_val_loss', 'N/A'))
                print(f" Pred Length {pred_len}: MSE = {mse}")
            else:
                print(f" Pred Length {pred_len}: Training failed")

    print("\nAll models saved in: ./checkpoints/")
    print("Training completed!")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Reference in New Issue
Block a user