First TimesNet try

game-loader
2025-07-30 21:18:46 +08:00
parent dc8c9f1f09
commit 6ee5c769c4
17 changed files with 2918 additions and 0 deletions

train/train.py (new file, 630 lines)

@@ -0,0 +1,630 @@
import os
import time
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import swanlab
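# swanlab is an open-source experiment tracker with a wandb-like API
# (assumed installed, e.g. via pip install swanlab)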
from typing import Dict, Any, Optional, Callable, Union, Tuple
class EarlyStopping:
"""Early stopping to stop training when validation performance doesn't improve."""
def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt'):
"""
Args:
patience (int): How long to wait after last improvement. Default: 7
verbose (bool): If True, prints a message for each improvement. Default: False
delta (float): Minimum change in monitored quantity to qualify as improvement. Default: 0
path (str): Path for the checkpoint to be saved to. Default: 'checkpoint.pt'
"""
self.patience = patience
self.verbose = verbose
self.counter = 0
self.best_score = None
self.early_stop = False
self.val_loss_min = float('inf')
self.delta = delta
self.path = path
def __call__(self, val_loss, model):
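        # Negate the loss so a higher score always means better validation performance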
score = -val_loss
if self.best_score is None:
self.best_score = score
self.save_checkpoint(val_loss, model)
elif score < self.best_score + self.delta:
self.counter += 1
if self.verbose:
print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
if self.counter >= self.patience:
self.early_stop = True
else:
self.best_score = score
self.save_checkpoint(val_loss, model)
self.counter = 0
def save_checkpoint(self, val_loss, model):
"""Save model when validation loss decreases."""
if self.verbose:
print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model...')
torch.save(model.state_dict(), self.path)
self.val_loss_min = val_loss
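# Usage sketch for EarlyStopping (hypothetical values; train/validate elided):
#   stopper = EarlyStopping(patience=5, verbose=True, path='best.pt')
#   for epoch in range(max_epochs):
#       val_loss = ...  # run a validation pass
#       stopper(val_loss, model)
#       if stopper.early_stop:
#           break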
def create_data_loaders(data_path: str, batch_size: int = 32) -> Dict[str, DataLoader]:
"""
Create PyTorch DataLoaders from an NPZ file
Args:
data_path (str): Path to the NPZ file containing the data
batch_size (int): Batch size for the DataLoaders
Returns:
Dict[str, DataLoader]: Dictionary with train and val DataLoaders
"""
# Load data from NPZ file
data = np.load(data_path, allow_pickle=True)
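    # Required keys: train_x, train_y, val_x, val_y; the *_mark time-feature
    # arrays below are optional. The .get() lookups rely on NpzFile's Mapping
    # interface, available in recent NumPy versions.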
train_x = data['train_x']
train_y = data['train_y']
val_x = data['val_x']
val_y = data['val_y']
# Load time features if available
train_x_mark = data.get('train_x_mark', None)
train_y_mark = data.get('train_y_mark', None)
val_x_mark = data.get('val_x_mark', None)
val_y_mark = data.get('val_y_mark', None)
# Convert to PyTorch tensors
train_x = torch.FloatTensor(train_x)
train_y = torch.FloatTensor(train_y)
val_x = torch.FloatTensor(val_x)
val_y = torch.FloatTensor(val_y)
# Create datasets based on whether time features are available
if train_x_mark is not None:
train_x_mark = torch.FloatTensor(train_x_mark)
train_y_mark = torch.FloatTensor(train_y_mark)
val_x_mark = torch.FloatTensor(val_x_mark)
val_y_mark = torch.FloatTensor(val_y_mark)
train_dataset = TensorDataset(train_x, train_y, train_x_mark, train_y_mark)
val_dataset = TensorDataset(val_x, val_y, val_x_mark, val_y_mark)
else:
train_dataset = TensorDataset(train_x, train_y)
val_dataset = TensorDataset(val_x, val_y)
# Create dataloaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
return {
'train': train_loader,
'val': val_loader
}
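# Sketch of a compatible NPZ file (the keyword names are the required keys;
# optionally add train_x_mark / train_y_mark / val_x_mark / val_y_mark):
#   np.savez('data/train_data.npz', train_x=train_x, train_y=train_y,
#            val_x=val_x, val_y=val_y)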
def train_forecasting_model(
model_constructor: Callable,
data_path: str,
project_name: str,
config: Dict[str, Any],
device: Optional[str] = None,
early_stopping_patience: int = 10,
max_epochs: int = 100,
checkpoint_dir: str = "./checkpoints",
log_interval: int = 10
) -> Tuple[nn.Module, Dict[str, float]]:
"""
Train a time series forecasting model
Args:
model_constructor (Callable): Function that constructs and returns the model
data_path (str): Path to the NPZ file containing the processed data
project_name (str): Name of the project for swanlab tracking
config (Dict[str, Any]): Configuration dictionary for the experiment
device (Optional[str]): Device to use for training ('cpu' or 'cuda')
early_stopping_patience (int): Number of epochs to wait before early stopping
max_epochs (int): Maximum number of epochs to train for
checkpoint_dir (str): Directory to save model checkpoints
log_interval (int): How often to log metrics during training
Returns:
Tuple[nn.Module, Dict[str, float]]: Trained model and dictionary of evaluation metrics
"""
# Setup device
if device is None:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Initialize swanlab for experiment tracking
swanlab_run = swanlab.init(
project=project_name,
config=config,
)
# Create checkpoint directory if it doesn't exist
os.makedirs(checkpoint_dir, exist_ok=True)
checkpoint_path = os.path.join(checkpoint_dir, f"{project_name}.pt")
# Create data loaders
dataloaders = create_data_loaders(
data_path=data_path,
batch_size=config.get('batch_size', 32)
)
# Construct the model
model = model_constructor()
model = model.to(device)
# Define loss function and optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(
model.parameters(),
lr=config.get('learning_rate', 1e-3),
)
    # Learning rate scheduler: halve the LR every 5 epochs
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
# Initialize early stopping
early_stopping = EarlyStopping(
patience=early_stopping_patience,
verbose=True,
path=checkpoint_path
)
# Training loop
best_val_loss = float('inf')
metrics = {}
for epoch in range(max_epochs):
print(f"Epoch {epoch+1}/{max_epochs}")
# Training phase
model.train()
print("1\n")
train_loss = 0.0
        # Track the loss accumulated over the current log_interval window
interval_loss = 0.0
start_time = time.time()
for batch_idx, batch_data in enumerate(dataloaders['train']):
# Handle both cases: with and without time features
if len(batch_data) == 4: # With time features
inputs, targets, x_mark, y_mark = batch_data
inputs, targets = inputs.to(device), targets.to(device)
x_mark, y_mark = x_mark.to(device), y_mark.to(device)
else: # Without time features
inputs, targets = batch_data
inputs, targets = inputs.to(device), targets.to(device)
x_mark, y_mark = None, None
# Zero the parameter gradients
optimizer.zero_grad()
# Forward pass - handle both cases
if x_mark is not None:
                # For a TimesNet-style model the forward signature is
                # (x_enc, x_mark_enc, x_dec, x_mark_dec); use a zero decoder
                # input for forecasting so the call matches the final
                # validation below
                dec_inp = torch.zeros_like(targets).to(device)
                outputs = model(inputs, x_mark, dec_inp, y_mark)
else:
# For simple models without time features
outputs = model(inputs)
loss = criterion(outputs, targets)
# Backward pass and optimize
loss.backward()
optimizer.step()
# Update statistics
train_loss += loss.item()
interval_loss += loss.item()
if (batch_idx + 1) % log_interval == 0:
print(f"Batch {batch_idx+1}/{len(dataloaders['train'])}, Loss: {loss.item():.4f}")
                # Log the average loss over this interval
avg_interval_loss = interval_loss / log_interval
swanlab_run.log({"batch_train_loss": avg_interval_loss})
                # Reset the interval loss for the next window
interval_loss = 0.0
avg_train_loss = train_loss / len(dataloaders['train'])
epoch_time = time.time() - start_time
# Validation phase
model.eval()
val_loss = 0.0
with torch.no_grad():
for batch_data in dataloaders['val']:
# Handle both cases: with and without time features
if len(batch_data) == 4: # With time features
inputs, targets, x_mark, y_mark = batch_data
inputs, targets = inputs.to(device), targets.to(device)
x_mark, y_mark = x_mark.to(device), y_mark.to(device)
else: # Without time features
inputs, targets = batch_data
inputs, targets = inputs.to(device), targets.to(device)
x_mark, y_mark = None, None
# Forward pass - handle both cases
if x_mark is not None:
                    # TimesNet-style forward with a zero decoder input,
                    # matching the training call above
                    dec_inp = torch.zeros_like(targets).to(device)
                    outputs = model(inputs, x_mark, dec_inp, y_mark)
else:
# For simple models without time features
outputs = model(inputs)
# Calculate loss
loss = criterion(outputs, targets)
val_loss += loss.item()
avg_val_loss = val_loss / len(dataloaders['val'])
current_lr = optimizer.param_groups[0]['lr']
# Log metrics
metrics_dict = {
"train_loss": avg_train_loss,
"val_loss": avg_val_loss,
"learning_rate": current_lr,
"epoch_time": epoch_time
}
swanlab_run.log(metrics_dict)
print(f"Epoch {epoch+1}/{max_epochs}, "
f"Train Loss: {avg_train_loss:.4f}, "
f"Val Loss: {avg_val_loss:.4f}, "
f"LR: {current_lr:.6f}, "
f"Time: {epoch_time:.2f}s")
        # Track the best epoch's metrics (the checkpoint itself is saved by EarlyStopping)
if avg_val_loss < best_val_loss:
best_val_loss = avg_val_loss
metrics = metrics_dict
# Early stopping
early_stopping(avg_val_loss, model)
if early_stopping.early_stop:
print("Early stopping triggered")
break
# Step the learning rate scheduler
scheduler.step()
# Load the best model
model.load_state_dict(torch.load(checkpoint_path))
# Final validation
model.eval()
final_val_loss = 0.0
with torch.no_grad():
for batch_data in dataloaders['val']:
# Handle both cases: with and without time features
if len(batch_data) == 4: # With time features
inputs, targets, x_mark, y_mark = batch_data
inputs, targets = inputs.to(device), targets.to(device)
x_mark, y_mark = x_mark.to(device), y_mark.to(device)
else: # Without time features
inputs, targets = batch_data
inputs, targets = inputs.to(device), targets.to(device)
x_mark, y_mark = None, None
# Forward pass - handle both cases
if x_mark is not None:
# For TimesNet model with time features
dec_inp = torch.zeros_like(targets).to(device)
outputs = model(inputs, x_mark, dec_inp, y_mark)
else:
# For simple models without time features
outputs = model(inputs)
# Calculate loss
loss = criterion(outputs, targets)
final_val_loss += loss.item()
final_val_loss /= len(dataloaders['val'])
print(f"Final validation loss: {final_val_loss:.4f}")
# Update metrics with final values
metrics["final_val_loss"] = final_val_loss
# Finish the swanlab run
swanlab_run.finish()
return model, metrics
def train_classification_model(
model_constructor: Callable,
data_path: str,
project_name: str,
config: Dict[str, Any],
device: Optional[str] = None,
early_stopping_patience: int = 10,
max_epochs: int = 100,
checkpoint_dir: str = "./checkpoints",
log_interval: int = 10
) -> Tuple[nn.Module, Dict[str, float]]:
"""
Train a time series classification model
Args:
model_constructor (Callable): Function that constructs and returns the model
data_path (str): Path to the NPZ file containing the processed data
project_name (str): Name of the project for swanlab tracking
config (Dict[str, Any]): Configuration dictionary for the experiment
device (Optional[str]): Device to use for training ('cpu' or 'cuda')
early_stopping_patience (int): Number of epochs to wait before early stopping
max_epochs (int): Maximum number of epochs to train for
checkpoint_dir (str): Directory to save model checkpoints
log_interval (int): How often to log metrics during training
Returns:
Tuple[nn.Module, Dict[str, float]]: Trained model and dictionary of evaluation metrics
"""
# Setup device
if device is None:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Initialize swanlab for experiment tracking
swanlab_run = swanlab.init(
project=project_name,
config=config,
)
# Create checkpoint directory if it doesn't exist
os.makedirs(checkpoint_dir, exist_ok=True)
checkpoint_path = os.path.join(checkpoint_dir, f"{project_name}.pt")
# Create data loaders
dataloaders = create_data_loaders(
data_path=data_path,
batch_size=config.get('batch_size', 32)
)
# Construct the model
model = model_constructor()
model = model.to(device)
# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(
model.parameters(),
lr=config.get('learning_rate', 1e-3),
weight_decay=config.get('weight_decay', 1e-4)
)
    # Learning rate scheduler: halve the LR after each epoch
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)
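    # Note: halving every epoch decays the LR quickly; after 10 epochs it is
    # 0.5**10 of the initial value, roughly 1/1000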
# Initialize early stopping
early_stopping = EarlyStopping(
patience=early_stopping_patience,
verbose=True,
path=checkpoint_path
)
# Training loop
best_val_loss = float('inf')
metrics = {}
for epoch in range(max_epochs):
print(f"Epoch {epoch+1}/{max_epochs}")
# Training phase
model.train()
train_loss = 0.0
train_correct = 0
train_total = 0
start_time = time.time()
for batch_idx, batch_data in enumerate(dataloaders['train']):
# Handle both cases: with and without time features
if len(batch_data) == 4: # With time features
inputs, targets, x_mark, y_mark = batch_data
inputs, targets = inputs.to(device), targets.to(device)
x_mark, y_mark = x_mark.to(device), y_mark.to(device)
else: # Without time features
inputs, targets = batch_data
inputs, targets = inputs.to(device), targets.to(device)
x_mark, y_mark = None, None
# Convert targets to long for classification
targets = targets.long()
# Zero the parameter gradients
optimizer.zero_grad()
# Forward pass - handle both cases
if x_mark is not None:
# For TimesNet model with time features
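                # dec_inp here is only a zero placeholder shaped like the class
                # targets; classification forwards are assumed to ignore it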
dec_inp = torch.zeros_like(targets).to(device)
outputs = model(inputs, x_mark, dec_inp, y_mark)
else:
# For simple models without time features
outputs = model(inputs)
loss = criterion(outputs, targets)
# Backward pass and optimize
loss.backward()
optimizer.step()
# Update statistics
train_loss += loss.item()
_, predicted = outputs.max(1)
train_total += targets.size(0)
train_correct += predicted.eq(targets).sum().item()
if (batch_idx + 1) % log_interval == 0:
print(f"Batch {batch_idx+1}/{len(dataloaders['train'])}, Loss: {loss.item():.4f}")
avg_train_loss = train_loss / len(dataloaders['train'])
train_accuracy = 100. * train_correct / train_total
epoch_time = time.time() - start_time
# Validation phase
model.eval()
val_loss = 0.0
val_correct = 0
val_total = 0
with torch.no_grad():
for batch_data in dataloaders['val']:
# Handle both cases: with and without time features
if len(batch_data) == 4: # With time features
inputs, targets, x_mark, y_mark = batch_data
inputs, targets = inputs.to(device), targets.to(device)
x_mark, y_mark = x_mark.to(device), y_mark.to(device)
else: # Without time features
inputs, targets = batch_data
inputs, targets = inputs.to(device), targets.to(device)
x_mark, y_mark = None, None
targets = targets.long()
# Forward pass - handle both cases
if x_mark is not None:
# For TimesNet model with time features
dec_inp = torch.zeros_like(targets).to(device)
outputs = model(inputs, x_mark, dec_inp, y_mark)
else:
# For simple models without time features
outputs = model(inputs)
# Calculate loss
loss = criterion(outputs, targets)
val_loss += loss.item()
# Calculate accuracy
_, predicted = outputs.max(1)
val_total += targets.size(0)
val_correct += predicted.eq(targets).sum().item()
avg_val_loss = val_loss / len(dataloaders['val'])
val_accuracy = 100. * val_correct / val_total
current_lr = optimizer.param_groups[0]['lr']
# Log metrics
        metrics_dict = {
            "train_loss": avg_train_loss,
            "train_accuracy": train_accuracy,
            "val_loss": avg_val_loss,
            "val_accuracy": val_accuracy,
            "learning_rate": current_lr,
            "epoch_time": epoch_time
        }
swanlab_run.log(metrics_dict)
print(f"Epoch {epoch+1}/{max_epochs}, "
f"Train Loss: {avg_train_loss:.4f}, "
f"Val Loss: {avg_val_loss:.4f}, "
f"Val Accuracy: {val_accuracy:.2f}%, "
f"LR: {current_lr:.6f}, "
f"Time: {epoch_time:.2f}s")
        # Track the best epoch's metrics (the checkpoint itself is saved by EarlyStopping)
if avg_val_loss < best_val_loss:
best_val_loss = avg_val_loss
metrics = metrics_dict
# Early stopping
early_stopping(avg_val_loss, model)
if early_stopping.early_stop:
print("Early stopping triggered")
break
# Step the learning rate scheduler
scheduler.step()
# Load the best model
model.load_state_dict(torch.load(checkpoint_path))
# Final validation
model.eval()
final_val_loss = 0.0
final_val_correct = 0
final_val_total = 0
with torch.no_grad():
for batch_data in dataloaders['val']:
# Handle both cases: with and without time features
if len(batch_data) == 4: # With time features
inputs, targets, x_mark, y_mark = batch_data
inputs, targets = inputs.to(device), targets.to(device)
x_mark, y_mark = x_mark.to(device), y_mark.to(device)
else: # Without time features
inputs, targets = batch_data
inputs, targets = inputs.to(device), targets.to(device)
x_mark, y_mark = None, None
targets = targets.long()
# Forward pass - handle both cases
if x_mark is not None:
# For TimesNet model with time features
dec_inp = torch.zeros_like(targets).to(device)
outputs = model(inputs, x_mark, dec_inp, y_mark)
else:
# For simple models without time features
outputs = model(inputs)
# Calculate loss
loss = criterion(outputs, targets)
final_val_loss += loss.item()
# Calculate accuracy
_, predicted = outputs.max(1)
final_val_total += targets.size(0)
final_val_correct += predicted.eq(targets).sum().item()
final_val_loss /= len(dataloaders['val'])
final_val_accuracy = 100. * final_val_correct / final_val_total
print(f"Final validation loss: {final_val_loss:.4f}")
print(f"Final validation accuracy: {final_val_accuracy:.2f}%")
# Update metrics with final values
metrics["final_val_loss"] = final_val_loss
metrics["final_val_accuracy"] = final_val_accuracy
# Finish the swanlab run
swanlab_run.finish()
return model, metrics
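# Classification usage mirrors main() below; a sketch with hypothetical names:
#   model, metrics = train_classification_model(
#       model_constructor=make_classifier,
#       data_path='data/cls_data.npz',
#       project_name='TimeSeriesClassification',
#       config={'learning_rate': 1e-3, 'batch_size': 32})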
def main():
# Example usage
data_path = 'data/train_data.npz'
project_name = 'TimeSeriesForecasting'
config = {
'learning_rate': 0.001,
'batch_size': 32,
'weight_decay': 1e-4
}
model_constructor = lambda: nn.Sequential(
nn.Linear(10, 50),
nn.ReLU(),
nn.Linear(50, 1)
)
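    # A toy MLP stand-in for illustration; a real run would construct a
    # TimesNet-style model here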
model, metrics = train_forecasting_model(
model_constructor=model_constructor,
data_path=data_path,
project_name=project_name,
config=config
)
if __name__ == "__main__":
main()