# tsmodel/train/train.py
import os
import time
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import swanlab
from typing import Dict, Any, Optional, Callable, Union, Tuple
from dataflow import data_provider
from layers.ps_loss import PSLoss
from utils.tools import adjust_learning_rate, dotdict
class EarlyStopping:
"""Early stopping to stop training when validation performance doesn't improve."""
def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt'):
"""
Args:
patience (int): How long to wait after last improvement. Default: 7
verbose (bool): If True, prints a message for each improvement. Default: False
delta (float): Minimum change in monitored quantity to qualify as improvement. Default: 0
path (str): Path for the checkpoint to be saved to. Default: 'checkpoint.pt'
"""
self.patience = patience
self.verbose = verbose
self.counter = 0
self.best_score = None
self.early_stop = False
self.val_loss_min = float('inf')
self.delta = delta
self.path = path
def __call__(self, val_loss, model):
score = -val_loss
if self.best_score is None:
self.best_score = score
self.save_checkpoint(val_loss, model)
elif score < self.best_score + self.delta:
self.counter += 1
if self.verbose:
print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
if self.counter >= self.patience:
self.early_stop = True
else:
self.best_score = score
self.save_checkpoint(val_loss, model)
self.counter = 0
def save_checkpoint(self, val_loss, model):
"""Save model when validation loss decreases."""
if self.verbose:
print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model...')
torch.save(model.state_dict(), self.path)
self.val_loss_min = val_loss
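# Usage sketch (illustrative values; `evaluate` is a hypothetical helper): call the
# stopper once per epoch with the current validation loss. It checkpoints the model
# on every improvement and sets `early_stop` after `patience` epochs without one.
#
#   stopper = EarlyStopping(patience=5, verbose=True, path='./checkpoints/demo.pt')
#   for epoch in range(max_epochs):
#       val_loss = evaluate(model, val_loader)
#       stopper(val_loss, model)
#       if stopper.early_stop:
#           break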
class DatasetWrapperWithoutTimeFeatures(torch.utils.data.Dataset):
"""Wrapper to remove time features from dataflow datasets when use_x_mark=False"""
def __init__(self, original_dataset):
self.original_dataset = original_dataset
def __getitem__(self, index):
# Get original data (seq_x, seq_y, seq_x_mark, seq_y_mark)
seq_x, seq_y, seq_x_mark, seq_y_mark = self.original_dataset[index]
# Return only seq_x and seq_y (remove time features)
return seq_x, seq_y
def __len__(self):
return len(self.original_dataset)
def inverse_transform(self, data):
if hasattr(self.original_dataset, 'inverse_transform'):
return self.original_dataset.inverse_transform(data)
return data
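# Usage sketch: wrap a dataflow dataset that yields (seq_x, seq_y, seq_x_mark, seq_y_mark)
# tuples so that downstream code receives plain (seq_x, seq_y) pairs:
#   plain_dataset = DatasetWrapperWithoutTimeFeatures(original_dataset)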
def create_data_loaders_from_dataflow(args, use_x_mark: bool = True) -> Dict[str, DataLoader]:
"""
Create PyTorch DataLoaders using dataflow data_provider
Args:
args: Arguments object containing dataset configuration
Required attributes: data, root_path, data_path, seq_len, label_len, pred_len,
features, target, embed, freq, batch_size, num_workers, train_only
use_x_mark (bool): Whether to use time features (x_mark and y_mark)
Returns:
Dict[str, DataLoader]: Dictionary with train, val, and test DataLoaders
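    Example (hypothetical field values; the exact attribute set is whatever
    dataflow.data_provider expects):

        args = dotdict({
            'data': 'ETTh1', 'root_path': './data/', 'data_path': 'ETTh1.csv',
            'seq_len': 96, 'label_len': 48, 'pred_len': 96,
            'features': 'M', 'target': 'OT', 'embed': 'timeF', 'freq': 'h',
            'batch_size': 32, 'num_workers': 4, 'train_only': False,
        })
        loaders = create_data_loaders_from_dataflow(args, use_x_mark=True)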
"""
# Create datasets and dataloaders for each split
train_data, _ = data_provider(args, flag='train')
val_data, _ = data_provider(args, flag='val')
test_data, _ = data_provider(args, flag='test')
# Wrap datasets to respect use_x_mark parameter
if not use_x_mark:
train_data = DatasetWrapperWithoutTimeFeatures(train_data)
val_data = DatasetWrapperWithoutTimeFeatures(val_data)
test_data = DatasetWrapperWithoutTimeFeatures(test_data)
    # Split-specific loader settings (only the training split is shuffled)
train_shuffle = True
val_shuffle = False
test_shuffle = False
train_drop_last = True
val_drop_last = True
test_drop_last = True
batch_size = args.batch_size
num_workers = args.num_workers
# Create new dataloaders with wrapped datasets
train_loader = DataLoader(
train_data,
batch_size=batch_size,
shuffle=train_shuffle,
num_workers=num_workers,
drop_last=train_drop_last
)
val_loader = DataLoader(
val_data,
batch_size=batch_size,
shuffle=val_shuffle,
num_workers=num_workers,
drop_last=val_drop_last
)
test_loader = DataLoader(
test_data,
batch_size=batch_size,
shuffle=test_shuffle,
num_workers=num_workers,
drop_last=test_drop_last
)
return {
'train': train_loader,
'val': val_loader,
'test': test_loader
}
def create_data_loaders(data_path: str, batch_size: int = 32, use_x_mark: bool = True,
num_workers: int = 4, pin_memory: bool = True,
persistent_workers: bool = True) -> Dict[str, DataLoader]:
"""
Create PyTorch DataLoaders from an NPZ file
Args:
data_path (str): Path to the NPZ file containing the data
batch_size (int): Batch size for the DataLoaders
use_x_mark (bool): Whether to use time features (x_mark) from the data file
num_workers (int): Number of worker processes for data loading
pin_memory (bool): Whether to pin memory for faster GPU transfer
persistent_workers (bool): Whether to keep workers alive between epochs
Returns:
Dict[str, DataLoader]: Dictionary with train, val, and test DataLoaders
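    Example: the NPZ file must contain at least the keys read below; time-feature
    keys ('train_x_mark', 'train_y_mark', ...) are optional and only used when
    use_x_mark=True. Such a file could be produced from hypothetical arrays with:

        np.savez('data/train_data.npz',
                 train_x=train_x, train_y=train_y,
                 val_x=val_x, val_y=val_y,
                 test_x=test_x, test_y=test_y)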
"""
# Load data from NPZ file
data = np.load(data_path, allow_pickle=True)
train_x = data['train_x']
train_y = data['train_y']
val_x = data['val_x']
val_y = data['val_y']
test_x = data['test_x']
test_y = data['test_y']
# Load time features if available and needed
if use_x_mark:
train_x_mark = data.get('train_x_mark', None)
train_y_mark = data.get('train_y_mark', None)
val_x_mark = data.get('val_x_mark', None)
val_y_mark = data.get('val_y_mark', None)
test_x_mark = data.get('test_x_mark', None)
test_y_mark = data.get('test_y_mark', None)
else:
train_x_mark = None
train_y_mark = None
val_x_mark = None
val_y_mark = None
test_x_mark = None
test_y_mark = None
# Convert to PyTorch tensors
train_x = torch.FloatTensor(train_x)
train_y = torch.FloatTensor(train_y)
val_x = torch.FloatTensor(val_x)
val_y = torch.FloatTensor(val_y)
test_x = torch.FloatTensor(test_x)
test_y = torch.FloatTensor(test_y)
# Create datasets based on whether time features are available
if train_x_mark is not None:
train_x_mark = torch.FloatTensor(train_x_mark)
train_y_mark = torch.FloatTensor(train_y_mark)
val_x_mark = torch.FloatTensor(val_x_mark)
val_y_mark = torch.FloatTensor(val_y_mark)
test_x_mark = torch.FloatTensor(test_x_mark)
test_y_mark = torch.FloatTensor(test_y_mark)
train_dataset = TensorDataset(train_x, train_y, train_x_mark, train_y_mark)
val_dataset = TensorDataset(val_x, val_y, val_x_mark, val_y_mark)
test_dataset = TensorDataset(test_x, test_y, test_x_mark, test_y_mark)
else:
train_dataset = TensorDataset(train_x, train_y)
val_dataset = TensorDataset(val_x, val_y)
test_dataset = TensorDataset(test_x, test_y)
# Create dataloaders with performance optimizations
train_loader = DataLoader(
train_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=num_workers,
pin_memory=pin_memory,
persistent_workers=persistent_workers if num_workers > 0 else False,
drop_last=True # Drop incomplete batches for training
)
val_loader = DataLoader(
val_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=num_workers,
pin_memory=pin_memory,
persistent_workers=persistent_workers if num_workers > 0 else False,
drop_last=False
)
test_loader = DataLoader(
test_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=num_workers,
pin_memory=pin_memory,
persistent_workers=persistent_workers if num_workers > 0 else False,
drop_last=False
)
return {
'train': train_loader,
'val': val_loader,
'test': test_loader
}
def train_forecasting_model(
model_constructor: Callable,
data_path: str,
project_name: str,
config: Dict[str, Any],
device: Optional[str] = None,
early_stopping_patience: int = 10,
max_epochs: int = 100,
checkpoint_dir: str = "./checkpoints",
log_interval: int = 10,
use_x_mark: bool = True,
dataset_mode: str = "npz",
dataflow_args = None,
use_ps_loss: bool = False,
ps_lambda: float = 5.0,
patch_len_threshold: int = 64,
use_gdw: bool = True,
lr_adjust_strategy: str = "type1"
) -> Tuple[nn.Module, Dict[str, float]]:
"""
Train a time series forecasting model
Args:
model_constructor (Callable): Function that constructs and returns the model
data_path (str): Path to the NPZ file containing the processed data (for npz mode)
project_name (str): Name of the project for swanlab tracking
config (Dict[str, Any]): Configuration dictionary for the experiment
device (Optional[str]): Device to use for training ('cpu' or 'cuda')
early_stopping_patience (int): Number of epochs to wait before early stopping
max_epochs (int): Maximum number of epochs to train for
checkpoint_dir (str): Directory to save model checkpoints
log_interval (int): How often to log metrics during training
use_x_mark (bool): Whether to use time features (x_mark) from the data file
dataset_mode (str): Dataset construction mode - "npz" or "dataflow"
dataflow_args: Arguments object for dataflow mode (required if dataset_mode="dataflow")
use_ps_loss (bool): Whether to use Patch-wise Structural (PS) loss instead of MSE
ps_lambda (float): Weight for PS loss component when combined with MSE
patch_len_threshold (int): Maximum patch length for adaptive patching
use_gdw (bool): Whether to use Gradient-based Dynamic Weighting
lr_adjust_strategy (str): Learning rate adjustment strategy - 'type1', 'type2', 'type3', 'sigmoid', 'constant', '3', '4', '5', '6'
Returns:
Tuple[nn.Module, Dict[str, float]]: Trained model and dictionary of evaluation metrics
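    Example (illustrative values; `MyForecaster` is a hypothetical model class whose
    forward pass maps the input batch to predictions with the same shape as the targets):

        model, metrics = train_forecasting_model(
            model_constructor=lambda: MyForecaster(),     # hypothetical model class
            data_path='data/train_data.npz',              # hypothetical NPZ file
            project_name='demo_forecasting',
            config={'learning_rate': 1e-3, 'batch_size': 32},
            use_x_mark=False,
            use_ps_loss=True,
            ps_lambda=5.0,
        )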
"""
# Setup device
if device is None:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Initialize swanlab for experiment tracking
swanlab_run = swanlab.init(
project=project_name,
config=config,
)
# Create checkpoint directory if it doesn't exist
os.makedirs(checkpoint_dir, exist_ok=True)
checkpoint_path = os.path.join(checkpoint_dir, f"{project_name}.pt")
# Create data loaders based on dataset_mode
if dataset_mode == "dataflow":
if dataflow_args is None:
raise ValueError("dataflow_args is required when dataset_mode='dataflow'")
dataloaders = create_data_loaders_from_dataflow(
dataflow_args,
use_x_mark=use_x_mark
)
else: # Default to "npz" mode
dataloaders = create_data_loaders(
data_path=data_path,
batch_size=config.get('batch_size', 32),
use_x_mark=use_x_mark,
num_workers=config.get('num_workers', 4),
pin_memory=config.get('pin_memory', True),
persistent_workers=config.get('persistent_workers', True)
)
# Construct the model
model = model_constructor()
model = model.to(device)
# Define loss function and optimizer
if use_ps_loss:
criterion = PSLoss(
patch_len_threshold=patch_len_threshold,
lambda_ps=ps_lambda,
use_gdw=use_gdw
)
else:
criterion = nn.MSELoss()
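    # Note: PSLoss is called as criterion(outputs, targets, model) and returns a
    # (total_loss, component_dict) pair, whereas nn.MSELoss returns a single tensor;
    # both calling conventions are handled explicitly in the loops below.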
optimizer = optim.Adam(
model.parameters(),
lr=config.get('learning_rate', 1e-3),
)
# Create args object for learning rate adjustment
lr_args = dotdict({
'learning_rate': config.get('learning_rate', 1e-3),
'lradj': lr_adjust_strategy
})
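    # lr_args only needs the attributes adjust_learning_rate reads (learning_rate, lradj).
    # In the common Time-Series-Library implementation, 'type1' halves the learning rate
    # every epoch and 'constant' leaves it unchanged; the exact schedule is whatever
    # utils.tools.adjust_learning_rate implements for the chosen strategy.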
# Initialize early stopping
early_stopping = EarlyStopping(
patience=early_stopping_patience,
verbose=True,
path=checkpoint_path
)
# Training loop
best_val_loss = float('inf')
metrics = {}
for epoch in range(max_epochs):
print(f"Epoch {epoch+1}/{max_epochs}")
# Training phase
model.train()
print("1\n")
train_loss = 0.0
# 用于记录 log_interval 期间的损失
interval_loss = 0.0
start_time = time.time()
for batch_idx, batch_data in enumerate(dataloaders['train']):
# Handle both cases: with and without time features
if len(batch_data) == 4: # With time features
inputs, targets, x_mark, y_mark = batch_data
                inputs, targets = inputs.float().to(device), targets.float().to(device)
                x_mark, y_mark = x_mark.float().to(device), y_mark.float().to(device)
else: # Without time features
inputs, targets = batch_data
                inputs, targets = inputs.float().to(device), targets.float().to(device)
x_mark, y_mark = None, None
# Zero the parameter gradients
optimizer.zero_grad()
# Forward pass - handle both cases
if x_mark is not None:
# For TimesNet model with time features
outputs = model(inputs, x_mark)
else:
# For simple models without time features
outputs = model(inputs)
# Calculate loss
if use_ps_loss:
loss, loss_dict = criterion(outputs, targets, model)
else:
loss = criterion(outputs, targets)
# Backward pass and optimize
loss.backward()
optimizer.step()
# Update statistics
train_loss += loss.item()
interval_loss += loss.item()
if (batch_idx + 1) % log_interval == 0:
                if use_ps_loss:
print(f"Batch {batch_idx+1}/{len(dataloaders['train'])}, "
f"Total Loss: {loss.item():.4f}, "
f"MSE: {loss_dict['mse_loss']:.4f}, "
f"PS: {loss_dict['ps_loss']:.4f}")
# Log detailed loss components
swanlab_run.log({
"batch_total_loss": loss.item(),
"batch_mse_loss": loss_dict['mse_loss'],
"batch_ps_loss": loss_dict['ps_loss'],
"batch_corr_loss": loss_dict['corr_loss'],
"batch_var_loss": loss_dict['var_loss'],
"batch_mean_loss": loss_dict['mean_loss'],
"alpha": loss_dict['alpha'],
"beta": loss_dict['beta'],
"gamma": loss_dict['gamma']
})
else:
print(f"Batch {batch_idx+1}/{len(dataloaders['train'])}, Loss: {loss.item():.4f}")
swanlab_run.log({"batch_train_loss": loss.item()})
                # Reset the interval accumulator for the next logging window
interval_loss = 0.0
avg_train_loss = train_loss / len(dataloaders['train'])
epoch_time = time.time() - start_time
# Validation phase
model.eval()
val_loss = 0.0
val_mse = 0.0
val_mse_criterion = nn.MSELoss() # Always use MSE for validation metrics
with torch.no_grad():
for batch_data in dataloaders['val']:
# Handle both cases: with and without time features
if len(batch_data) == 4: # With time features
inputs, targets, x_mark, y_mark = batch_data
inputs, targets = inputs.float().to(device), targets.float().to(device)
x_mark, y_mark = x_mark.float().to(device), y_mark.float().to(device)
else: # Without time features
inputs, targets = batch_data
inputs, targets = inputs.float().to(device), targets.float().to(device)
x_mark, y_mark = None, None
# Forward pass - handle both cases
if x_mark is not None:
# For TimesNet model with time features
outputs = model(inputs, x_mark)
else:
# For simple models without time features
outputs = model(inputs)
# Calculate training loss (PS or MSE)
if use_ps_loss:
loss, _ = criterion(outputs, targets, model)
val_loss += loss.item()
else:
loss = criterion(outputs, targets)
val_loss += loss.item()
# Always calculate MSE for validation metrics
mse_loss = val_mse_criterion(outputs, targets)
val_mse += mse_loss.item()
avg_val_loss = val_loss / len(dataloaders['val'])
avg_val_mse = val_mse / len(dataloaders['val'])
current_lr = optimizer.param_groups[0]['lr']
# Log metrics
metrics_dict = {
"train_loss": avg_train_loss,
"val_loss": avg_val_loss,
"val_mse": avg_val_mse,
"learning_rate": current_lr,
"epoch_time": epoch_time
}
swanlab_run.log(metrics_dict)
print(f"Epoch {epoch+1}/{max_epochs}, "
f"Train Loss: {avg_train_loss:.4f}, "
f"Val Loss: {avg_val_loss:.4f}, "
f"Val MSE: {avg_val_mse:.4f}, "
f"LR: {current_lr:.6f}, "
f"Time: {epoch_time:.2f}s")
# Check if we should save the model
if avg_val_loss < best_val_loss:
best_val_loss = avg_val_loss
metrics = metrics_dict
# Early stopping
early_stopping(avg_val_loss, model)
if early_stopping.early_stop:
print("Early stopping triggered")
break
# Adjust learning rate using utils.tools function
adjust_learning_rate(optimizer, epoch, lr_args)
# Load the best model
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
# Test evaluation on the best model - Always use MSE for final evaluation
model.eval()
test_loss = 0.0
test_mse = 0.0
mse_criterion = nn.MSELoss() # Always use MSE for test evaluation
print("Evaluating on test set...")
with torch.no_grad():
for batch_data in dataloaders['test']:
# Handle both cases: with and without time features
if len(batch_data) == 4: # With time features
inputs, targets, x_mark, y_mark = batch_data
inputs, targets = inputs.float().to(device), targets.float().to(device)
x_mark, y_mark = x_mark.float().to(device), y_mark.float().to(device)
else: # Without time features
inputs, targets = batch_data
inputs, targets = inputs.float().to(device), targets.float().to(device)
x_mark, y_mark = None, None
# Forward pass - handle both cases
if x_mark is not None:
# For TimesNet model with time features
outputs = model(inputs, x_mark)
else:
# For simple models without time features
outputs = model(inputs)
# Always calculate MSE for test evaluation (for fair comparison)
mse_loss = mse_criterion(outputs, targets)
test_loss += mse_loss.item()
test_loss /= len(dataloaders['test'])
print(f"Test evaluation completed!")
print(f"Test Loss (MSE): {test_loss:.6f}")
# Final validation for consistency - Always use MSE for final metrics
model.eval()
final_val_loss = 0.0
with torch.no_grad():
for batch_data in dataloaders['val']:
# Handle both cases: with and without time features
if len(batch_data) == 4: # With time features
inputs, targets, x_mark, y_mark = batch_data
inputs, targets = inputs.float().to(device), targets.float().to(device)
x_mark, y_mark = x_mark.float().to(device), y_mark.float().to(device)
else: # Without time features
inputs, targets = batch_data
inputs, targets = inputs.float().to(device), targets.float().to(device)
x_mark, y_mark = None, None
# Forward pass - handle both cases
if x_mark is not None:
# For TimesNet model with time features
outputs = model(inputs, x_mark)
else:
# For simple models without time features
outputs = model(inputs)
# Always calculate MSE for final validation (for fair comparison)
mse_loss = mse_criterion(outputs, targets)
final_val_loss += mse_loss.item()
final_val_loss /= len(dataloaders['val'])
print(f"Final validation MSE: {final_val_loss:.6f}")
print(f"Final test MSE: {test_loss:.6f}")
if use_ps_loss:
print("Note: Model was trained with PS Loss but evaluated with MSE for fair comparison")
# Log final test results to swanlab
final_metrics = {
"final_test_mse": test_loss,
"final_val_mse": final_val_loss
}
swanlab_run.log(final_metrics)
# Update metrics with final values (always MSE for comparison)
metrics["final_val_loss"] = final_val_loss
metrics["final_test_loss"] = test_loss
metrics["final_val_mse"] = final_val_loss # Same as final_val_loss since we use MSE
metrics["final_test_mse"] = test_loss # Same as final_test_loss since we use MSE
# Finish the swanlab run
swanlab_run.finish()
return model, metrics
def train_classification_model(
model_constructor: Callable,
data_path: str,
project_name: str,
config: Dict[str, Any],
device: Optional[str] = None,
early_stopping_patience: int = 10,
max_epochs: int = 100,
checkpoint_dir: str = "./checkpoints",
log_interval: int = 10,
use_x_mark: bool = True,
dataset_mode: str = "npz",
dataflow_args = None,
lr_adjust_strategy: str = "type1"
) -> Tuple[nn.Module, Dict[str, float]]:
"""
Train a time series classification model
Args:
model_constructor (Callable): Function that constructs and returns the model
data_path (str): Path to the NPZ file containing the processed data (for npz mode)
project_name (str): Name of the project for swanlab tracking
config (Dict[str, Any]): Configuration dictionary for the experiment
device (Optional[str]): Device to use for training ('cpu' or 'cuda')
early_stopping_patience (int): Number of epochs to wait before early stopping
max_epochs (int): Maximum number of epochs to train for
checkpoint_dir (str): Directory to save model checkpoints
log_interval (int): How often to log metrics during training
use_x_mark (bool): Whether to use time features (x_mark) from the data file
dataset_mode (str): Dataset construction mode - "npz" or "dataflow"
dataflow_args: Arguments object for dataflow mode (required if dataset_mode="dataflow")
lr_adjust_strategy (str): Learning rate adjustment strategy - 'type1', 'type2', 'type3', 'sigmoid', 'constant', '3', '4', '5', '6'
Returns:
Tuple[nn.Module, Dict[str, float]]: Trained model and dictionary of evaluation metrics
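    Example (illustrative values; the constructed model must return class logits of
    shape (batch, num_classes) for nn.CrossEntropyLoss; `MyClassifier` is hypothetical):

        model, metrics = train_classification_model(
            model_constructor=lambda: MyClassifier(num_classes=5),   # hypothetical model
            data_path='data/classification_data.npz',                # hypothetical NPZ file
            project_name='demo_classification',
            config={'learning_rate': 1e-3, 'batch_size': 32, 'weight_decay': 1e-4},
            use_x_mark=False,
        )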
"""
# Setup device
if device is None:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Initialize swanlab for experiment tracking
swanlab_run = swanlab.init(
project=project_name,
config=config,
)
# Create checkpoint directory if it doesn't exist
os.makedirs(checkpoint_dir, exist_ok=True)
checkpoint_path = os.path.join(checkpoint_dir, f"{project_name}.pt")
# Create data loaders based on dataset_mode
if dataset_mode == "dataflow":
if dataflow_args is None:
raise ValueError("dataflow_args is required when dataset_mode='dataflow'")
dataloaders = create_data_loaders_from_dataflow(
dataflow_args,
use_x_mark=use_x_mark
)
else: # Default to "npz" mode
dataloaders = create_data_loaders(
data_path=data_path,
batch_size=config.get('batch_size', 32),
use_x_mark=use_x_mark,
num_workers=config.get('num_workers', 4),
pin_memory=config.get('pin_memory', True),
persistent_workers=config.get('persistent_workers', True)
)
# Construct the model
model = model_constructor()
model = model.to(device)
# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(
model.parameters(),
lr=config.get('learning_rate', 1e-3),
weight_decay=config.get('weight_decay', 1e-4)
)
# Create args object for learning rate adjustment
lr_args = dotdict({
'learning_rate': config.get('learning_rate', 1e-3),
'lradj': lr_adjust_strategy
})
# Initialize early stopping
early_stopping = EarlyStopping(
patience=early_stopping_patience,
verbose=True,
path=checkpoint_path
)
# Training loop
best_val_loss = float('inf')
metrics = {}
for epoch in range(max_epochs):
print(f"Epoch {epoch+1}/{max_epochs}")
# Training phase
model.train()
train_loss = 0.0
train_correct = 0
train_total = 0
start_time = time.time()
for batch_idx, batch_data in enumerate(dataloaders['train']):
# Handle both cases: with and without time features
if len(batch_data) == 4: # With time features
inputs, targets, x_mark, y_mark = batch_data
inputs, targets = inputs.float().to(device), targets.float().to(device)
x_mark, y_mark = x_mark.float().to(device), y_mark.float().to(device)
else: # Without time features
inputs, targets = batch_data
inputs, targets = inputs.float().to(device), targets.float().to(device)
x_mark, y_mark = None, None
# Convert targets to long for classification
targets = targets.long()
# Zero the parameter gradients
optimizer.zero_grad()
# Forward pass - handle both cases
if x_mark is not None:
# For TimesNet model with time features
dec_inp = torch.zeros_like(targets).to(device)
outputs = model(inputs, x_mark, dec_inp, y_mark)
else:
# For simple models without time features
outputs = model(inputs)
loss = criterion(outputs, targets)
# Backward pass and optimize
loss.backward()
optimizer.step()
# Update statistics
train_loss += loss.item()
_, predicted = outputs.max(1)
train_total += targets.size(0)
train_correct += predicted.eq(targets).sum().item()
if (batch_idx + 1) % log_interval == 0:
print(f"Batch {batch_idx+1}/{len(dataloaders['train'])}, Loss: {loss.item():.4f}")
avg_train_loss = train_loss / len(dataloaders['train'])
train_accuracy = 100. * train_correct / train_total
epoch_time = time.time() - start_time
# Validation phase
model.eval()
val_loss = 0.0
val_correct = 0
val_total = 0
with torch.no_grad():
for batch_data in dataloaders['val']:
# Handle both cases: with and without time features
if len(batch_data) == 4: # With time features
inputs, targets, x_mark, y_mark = batch_data
inputs, targets = inputs.float().to(device), targets.float().to(device)
x_mark, y_mark = x_mark.float().to(device), y_mark.float().to(device)
else: # Without time features
inputs, targets = batch_data
inputs, targets = inputs.float().to(device), targets.float().to(device)
x_mark, y_mark = None, None
targets = targets.long()
# Forward pass - handle both cases
if x_mark is not None:
# For TimesNet model with time features
dec_inp = torch.zeros_like(targets).to(device)
outputs = model(inputs, x_mark, dec_inp, y_mark)
else:
# For simple models without time features
outputs = model(inputs)
# Calculate loss
loss = criterion(outputs, targets)
val_loss += loss.item()
# Calculate accuracy
_, predicted = outputs.max(1)
val_total += targets.size(0)
val_correct += predicted.eq(targets).sum().item()
avg_val_loss = val_loss / len(dataloaders['val'])
val_accuracy = 100. * val_correct / val_total
current_lr = optimizer.param_groups[0]['lr']
# Log metrics
metrics_dict = {
"train_loss": avg_train_loss,
"val_loss": avg_val_loss,
"val_accuracy": val_accuracy,
"learning_rate": current_lr,
"epoch_time": epoch_time
}
swanlab_run.log(metrics_dict)
print(f"Epoch {epoch+1}/{max_epochs}, "
f"Train Loss: {avg_train_loss:.4f}, "
f"Val Loss: {avg_val_loss:.4f}, "
f"Val Accuracy: {val_accuracy:.2f}%, "
f"LR: {current_lr:.6f}, "
f"Time: {epoch_time:.2f}s")
# Check if we should save the model
if avg_val_loss < best_val_loss:
best_val_loss = avg_val_loss
metrics = metrics_dict
# Early stopping
early_stopping(avg_val_loss, model)
if early_stopping.early_stop:
print("Early stopping triggered")
break
# Adjust learning rate using utils.tools function
adjust_learning_rate(optimizer, epoch, lr_args)
# Load the best model
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
# Final validation
model.eval()
final_val_loss = 0.0
final_val_correct = 0
final_val_total = 0
with torch.no_grad():
for batch_data in dataloaders['val']:
# Handle both cases: with and without time features
if len(batch_data) == 4: # With time features
inputs, targets, x_mark, y_mark = batch_data
inputs, targets = inputs.float().to(device), targets.float().to(device)
x_mark, y_mark = x_mark.float().to(device), y_mark.float().to(device)
else: # Without time features
inputs, targets = batch_data
inputs, targets = inputs.float().to(device), targets.float().to(device)
x_mark, y_mark = None, None
targets = targets.long()
# Forward pass - handle both cases
if x_mark is not None:
# For TimesNet model with time features
dec_inp = torch.zeros_like(targets).to(device)
outputs = model(inputs, x_mark, dec_inp, y_mark)
else:
# For simple models without time features
outputs = model(inputs)
# Calculate loss
loss = criterion(outputs, targets)
final_val_loss += loss.item()
# Calculate accuracy
_, predicted = outputs.max(1)
final_val_total += targets.size(0)
final_val_correct += predicted.eq(targets).sum().item()
final_val_loss /= len(dataloaders['val'])
final_val_accuracy = 100. * final_val_correct / final_val_total
print(f"Final validation loss: {final_val_loss:.4f}")
print(f"Final validation accuracy: {final_val_accuracy:.2f}%")
# Update metrics with final values
metrics["final_val_loss"] = final_val_loss
metrics["final_val_accuracy"] = final_val_accuracy
# Finish the swanlab run
swanlab_run.finish()
return model, metrics
def main():
# Example usage
data_path = 'data/train_data.npz'
project_name = 'TimeSeriesForecasting'
config = {
'learning_rate': 0.001,
'batch_size': 32,
'weight_decay': 1e-4
}
model_constructor = lambda: nn.Sequential(
nn.Linear(10, 50),
nn.ReLU(),
nn.Linear(50, 1)
)
model, metrics = train_forecasting_model(
model_constructor=model_constructor,
data_path=data_path,
project_name=project_name,
config=config
)
if __name__ == "__main__":
main()