# tsmodel/train/train_diffusion.py

import os
import time
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import swanlab
from typing import Dict, Any, Optional, Callable, Union, Tuple
from dataflow import data_provider


class EarlyStopping:
    """Early stopping to stop training when validation performance doesn't improve."""

    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt'):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = float('inf')
        self.delta = delta
        self.path = path

    def __call__(self, val_loss, model):
        score = -val_loss
        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            if self.verbose:
                print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        """Save model when validation loss decreases."""
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss
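
# Usage sketch (illustrative only; train_diffusion_model below wires this class
# up for real):
#
#   stopper = EarlyStopping(patience=5, verbose=True, path='best.pt')
#   for epoch in range(max_epochs):
#       val_loss = ...  # run validation for this epoch
#       stopper(val_loss, model)
#       if stopper.early_stop:
#           break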


class DatasetWrapperWithoutTimeFeatures(torch.utils.data.Dataset):
    """Wrapper to remove time features from dataflow datasets when use_x_mark=False"""

    def __init__(self, original_dataset):
        self.original_dataset = original_dataset

    def __getitem__(self, index):
        seq_x, seq_y, seq_x_mark, seq_y_mark = self.original_dataset[index]
        return seq_x, seq_y

    def __len__(self):
        return len(self.original_dataset)

    def inverse_transform(self, data):
        if hasattr(self.original_dataset, 'inverse_transform'):
            return self.original_dataset.inverse_transform(data)
        return data


def create_data_loaders_from_dataflow(args, use_x_mark: bool = True) -> Dict[str, DataLoader]:
    """Create PyTorch DataLoaders using dataflow data_provider"""
    train_data, _ = data_provider(args, flag='train')
    val_data, _ = data_provider(args, flag='val')
    test_data, _ = data_provider(args, flag='test')

    if not use_x_mark:
        train_data = DatasetWrapperWithoutTimeFeatures(train_data)
        val_data = DatasetWrapperWithoutTimeFeatures(val_data)
        test_data = DatasetWrapperWithoutTimeFeatures(test_data)

    train_shuffle = True
    val_shuffle = False
    test_shuffle = False
    train_drop_last = True
    val_drop_last = True
    test_drop_last = True

    batch_size = args.batch_size
    num_workers = args.num_workers

    train_loader = DataLoader(
        train_data, batch_size=batch_size, shuffle=train_shuffle,
        num_workers=num_workers, drop_last=train_drop_last
    )
    val_loader = DataLoader(
        val_data, batch_size=batch_size, shuffle=val_shuffle,
        num_workers=num_workers, drop_last=val_drop_last
    )
    test_loader = DataLoader(
        test_data, batch_size=batch_size, shuffle=test_shuffle,
        num_workers=num_workers, drop_last=test_drop_last
    )
    return {'train': train_loader, 'val': val_loader, 'test': test_loader}
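
# Illustrative call (hypothetical values): `args` must at least provide the
# `batch_size` and `num_workers` read above, plus whatever fields
# dataflow.data_provider expects for the chosen dataset.
#
#   from argparse import Namespace
#   args = Namespace(batch_size=32, num_workers=4, ...)
#   loaders = create_data_loaders_from_dataflow(args, use_x_mark=False)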


def create_data_loaders(data_path: str, batch_size: int = 32, use_x_mark: bool = True) -> Dict[str, DataLoader]:
    """
    Create PyTorch DataLoaders from an NPZ file

    Args:
        data_path (str): Path to the NPZ file containing the data
        batch_size (int): Batch size for the DataLoaders
        use_x_mark (bool): Whether to use time features (x_mark) from the data file

    Returns:
        Dict[str, DataLoader]: Dictionary with train, val, and test DataLoaders
    """
    # Load data from NPZ file
    data = np.load(data_path, allow_pickle=True)
    train_x = data['train_x']
    train_y = data['train_y']
    val_x = data['val_x']
    val_y = data['val_y']
    test_x = data['test_x']
    test_y = data['test_y']

    # Load time features if available and needed
    if use_x_mark:
        train_x_mark = data.get('train_x_mark', None)
        train_y_mark = data.get('train_y_mark', None)
        val_x_mark = data.get('val_x_mark', None)
        val_y_mark = data.get('val_y_mark', None)
        test_x_mark = data.get('test_x_mark', None)
        test_y_mark = data.get('test_y_mark', None)
    else:
        train_x_mark = None
        train_y_mark = None
        val_x_mark = None
        val_y_mark = None
        test_x_mark = None
        test_y_mark = None

    # Convert to PyTorch tensors
    train_x = torch.FloatTensor(train_x)
    train_y = torch.FloatTensor(train_y)
    val_x = torch.FloatTensor(val_x)
    val_y = torch.FloatTensor(val_y)
    test_x = torch.FloatTensor(test_x)
    test_y = torch.FloatTensor(test_y)

    # Create datasets based on whether time features are available
    if train_x_mark is not None:
        train_x_mark = torch.FloatTensor(train_x_mark)
        train_y_mark = torch.FloatTensor(train_y_mark)
        val_x_mark = torch.FloatTensor(val_x_mark)
        val_y_mark = torch.FloatTensor(val_y_mark)
        test_x_mark = torch.FloatTensor(test_x_mark)
        test_y_mark = torch.FloatTensor(test_y_mark)
        train_dataset = TensorDataset(train_x, train_y, train_x_mark, train_y_mark)
        val_dataset = TensorDataset(val_x, val_y, val_x_mark, val_y_mark)
        test_dataset = TensorDataset(test_x, test_y, test_x_mark, test_y_mark)
    else:
        train_dataset = TensorDataset(train_x, train_y)
        val_dataset = TensorDataset(val_x, val_y)
        test_dataset = TensorDataset(test_x, test_y)

    # Create dataloaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return {
        'train': train_loader,
        'val': val_loader,
        'test': test_loader
    }
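
# The NPZ file is expected to contain the keys read above: 'train_x', 'train_y',
# 'val_x', 'val_y', 'test_x', 'test_y', plus the optional '*_x_mark'/'*_y_mark'
# time-feature arrays. A minimal sketch of producing such a file (the shapes in
# the comments are typical assumptions, not requirements enforced by this loader):
#
#   np.savez(
#       'dataset.npz',
#       train_x=train_x,  # e.g. (num_samples, seq_len, num_features)
#       train_y=train_y,  # e.g. (num_samples, pred_len, num_features)
#       val_x=val_x, val_y=val_y,
#       test_x=test_x, test_y=test_y,
#   )
#   loaders = create_data_loaders('dataset.npz', batch_size=32, use_x_mark=False)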


def train_diffusion_model(
    model_constructor: Callable,
    data_path: str,
    project_name: str,
    config: Dict[str, Any],
    device: Optional[str] = None,
    early_stopping_patience: int = 10,
    max_epochs: int = 100,
    checkpoint_dir: str = "./checkpoints",
    log_interval: int = 10,
) -> Tuple[nn.Module, Dict[str, float]]:
    """
    Train a Diffusion time series forecasting model using NPZ data loading
    """
    # Setup device
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Initialize swanlab for experiment tracking
    swanlab_run = swanlab.init(
        project=project_name,
        config=config,
    )

    # Create checkpoint directory if it doesn't exist
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_path = os.path.join(checkpoint_dir, f"{project_name}.pt")

    # Create data loaders using NPZ files (following other models' pattern)
    dataloaders = create_data_loaders(
        data_path=data_path,
        batch_size=config.get('batch_size', 32),
        use_x_mark=False  # DiffusionTimeSeries doesn't use time features
    )

    # Construct the model
    model = model_constructor()
    model = model.to(device)
    print(f"Model created with {model.get_num_params():,} parameters")

    # Define optimizer for diffusion training
    optimizer = optim.Adam(
        model.parameters(),
        lr=config.get('learning_rate', 1e-4),
        weight_decay=config.get('weight_decay', 1e-4)
    )

    # Learning rate scheduler
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', patience=5, factor=0.5
    )

    # Initialize early stopping
    early_stopping = EarlyStopping(
        patience=early_stopping_patience,
        verbose=True,
        path=checkpoint_path
    )

    # MSE criterion, defined once so both validation and the final test loop can use it
    criterion = nn.MSELoss()

    # Training loop
    best_val_loss = float('inf')
    metrics = {}
    for epoch in range(max_epochs):
        print(f"\nEpoch {epoch+1}/{max_epochs}")
        print("-" * 50)

        # Training phase
        model.train()
        train_loss = 0.0
        train_samples = 0
        interval_loss = 0.0
        start_time = time.time()

        for batch_idx, (seq_x, seq_y) in enumerate(dataloaders['train']):
            seq_x, seq_y = seq_x.to(device), seq_y.to(device)
            optimizer.zero_grad()

            # Diffusion training: model returns loss directly when y is provided
            loss = model(seq_x, seq_y)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            interval_loss += loss.item()
            train_samples += 1

            # Log at intervals
            if (batch_idx + 1) % log_interval == 0:
                elapsed_time = time.time() - start_time
                avg_interval_loss = interval_loss / log_interval
                print(f' Batch [{batch_idx+1}/{len(dataloaders["train"])}] '
                      f'Loss: {avg_interval_loss:.6f} '
                      f'Time: {elapsed_time:.2f}s')
                # Log to swanlab
                swanlab.log({
                    'batch_loss': avg_interval_loss,
                    'batch': epoch * len(dataloaders['train']) + batch_idx,
                    'learning_rate': optimizer.param_groups[0]['lr']
                })
                interval_loss = 0.0
                start_time = time.time()

        avg_train_loss = train_loss / train_samples
        # Validation phase - Use faster sampling for validation
        model.eval()
        val_loss = 0.0
        val_samples = 0
        print(" Validating...")

        with torch.no_grad():
            # Temporarily reduce diffusion steps for faster validation
            original_timesteps = model.diffusion.num_timesteps
            model.diffusion.num_timesteps = 200  # Much faster validation

            for batch_idx, (seq_x, seq_y) in enumerate(dataloaders['val']):
                seq_x, seq_y = seq_x.to(device), seq_y.to(device)

                # Generate predictions (inference mode with reduced steps)
                pred = model(seq_x)

                # Compute MSE loss for validation
                loss = criterion(pred, seq_y)
                val_loss += loss.item()
                val_samples += 1

                # Print validation progress for first epoch
                if epoch == 0 and (batch_idx + 1) % 50 == 0:
                    print(f" Val Batch [{batch_idx+1}/{len(dataloaders['val'])}]")

                # Early break: only validate on the first 100 batches of the first epoch to speed it up
                if epoch == 0 and batch_idx >= 100:
                    break

            # Restore original timesteps
            model.diffusion.num_timesteps = original_timesteps

        avg_val_loss = val_loss / val_samples

        # Learning rate scheduling
        scheduler.step(avg_val_loss)
        current_lr = optimizer.param_groups[0]['lr']

        print(f" Train Loss: {avg_train_loss:.6f}")
        print(f" Val Loss: {avg_val_loss:.6f}")
        print(f" Learning Rate: {current_lr:.2e}")

        # Log to swanlab
        swanlab.log({
            'epoch': epoch + 1,
            'train_loss': avg_train_loss,
            'val_loss': avg_val_loss,
            'learning_rate': current_lr
        })

        # Early stopping check
        early_stopping(avg_val_loss, model)
        if early_stopping.early_stop:
            print(f"Early stopping at epoch {epoch + 1}")
            break
    # Load best model
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))

    # Final evaluation on test set
    print("\nEvaluating on test set...")
    model.eval()
    test_loss = 0.0
    test_samples = 0
    all_preds = []
    all_targets = []

    with torch.no_grad():
        # Use reduced timesteps for faster testing
        original_timesteps = model.diffusion.num_timesteps
        model.diffusion.num_timesteps = 200  # Faster but still good quality

        for batch_idx, (seq_x, seq_y) in enumerate(dataloaders['test']):
            seq_x, seq_y = seq_x.to(device), seq_y.to(device)
            pred = model(seq_x)
            loss = criterion(pred, seq_y)
            test_loss += loss.item()
            test_samples += 1
            all_preds.append(pred.cpu().numpy())
            all_targets.append(seq_y.cpu().numpy())

            # Print progress every 50 batches
            if (batch_idx + 1) % 50 == 0:
                print(f" Test Batch [{batch_idx+1}/{len(dataloaders['test'])}]")

        # Restore original timesteps
        model.diffusion.num_timesteps = original_timesteps

    avg_test_loss = test_loss / test_samples

    # Calculate additional metrics
    all_preds = np.concatenate(all_preds, axis=0)
    all_targets = np.concatenate(all_targets, axis=0)
    mse = np.mean((all_preds - all_targets) ** 2)
    mae = np.mean(np.abs(all_preds - all_targets))
    rmse = np.sqrt(mse)

    metrics = {
        'test_mse': mse,
        'test_mae': mae,
        'test_rmse': rmse,
        'test_loss': avg_test_loss
    }

    print("Test Results:")
    print(f" MSE: {mse:.6f}")
    print(f" MAE: {mae:.6f}")
    print(f" RMSE: {rmse:.6f}")

    # Log final results
    swanlab.log(metrics)
    swanlab.finish()

    return model, metrics
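

# Example invocation (illustrative; the model import path, data path, and config
# values are assumptions -- substitute this repository's actual diffusion model,
# which must return the training loss from forward(x, y), predictions from
# forward(x), and expose get_num_params() and diffusion.num_timesteps):
#
#   from tsmodel.models import DiffusionTimeSeries  # hypothetical import path
#
#   model, metrics = train_diffusion_model(
#       model_constructor=lambda: DiffusionTimeSeries(...),
#       data_path='./data/dataset.npz',
#       project_name='diffusion_forecasting',
#       config={'batch_size': 32, 'learning_rate': 1e-4, 'weight_decay': 1e-4},
#       max_epochs=50,
#   )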