feat: add DiffusionTimeSeries and iTransformer models, introduce xPatch_SparseChannel

train/train.py | 177

@@ -8,6 +8,8 @@ from torch.utils.data import DataLoader, TensorDataset
import swanlab
from typing import Dict, Any, Optional, Callable, Union, Tuple
from dataflow import data_provider
from layers.ps_loss import PSLoss
from utils.tools import adjust_learning_rate, dotdict

class EarlyStopping:
"""Early stopping to stop training when validation performance doesn't improve."""
@@ -138,7 +140,9 @@ def create_data_loaders_from_dataflow(args, use_x_mark: bool = True) -> Dict[str
'test': test_loader
}

def create_data_loaders(data_path: str, batch_size: int = 32, use_x_mark: bool = True) -> Dict[str, DataLoader]:
def create_data_loaders(data_path: str, batch_size: int = 32, use_x_mark: bool = True,
num_workers: int = 4, pin_memory: bool = True,
persistent_workers: bool = True) -> Dict[str, DataLoader]:
"""
Create PyTorch DataLoaders from an NPZ file

@@ -146,6 +150,9 @@ def create_data_loaders(data_path: str, batch_size: int = 32, use_x_mark: bool =
data_path (str): Path to the NPZ file containing the data
batch_size (int): Batch size for the DataLoaders
use_x_mark (bool): Whether to use time features (x_mark) from the data file
num_workers (int): Number of worker processes for data loading
pin_memory (bool): Whether to pin memory for faster GPU transfer
persistent_workers (bool): Whether to keep workers alive between epochs

Returns:
Dict[str, DataLoader]: Dictionary with train, val, and test DataLoaders
@@ -200,10 +207,34 @@ def create_data_loaders(data_path: str, batch_size: int = 32, use_x_mark: bool =
val_dataset = TensorDataset(val_x, val_y)
test_dataset = TensorDataset(test_x, test_y)

# Create dataloaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
# Create dataloaders with performance optimizations
train_loader = DataLoader(
train_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=num_workers,
pin_memory=pin_memory,
persistent_workers=persistent_workers if num_workers > 0 else False,
drop_last=True # Drop incomplete batches for training
)
val_loader = DataLoader(
val_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=num_workers,
pin_memory=pin_memory,
persistent_workers=persistent_workers if num_workers > 0 else False,
drop_last=False
)
test_loader = DataLoader(
test_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=num_workers,
pin_memory=pin_memory,
persistent_workers=persistent_workers if num_workers > 0 else False,
drop_last=False
)

return {
'train': train_loader,
@@ -223,7 +254,12 @@ def train_forecasting_model(
log_interval: int = 10,
use_x_mark: bool = True,
dataset_mode: str = "npz",
dataflow_args = None
dataflow_args = None,
use_ps_loss: bool = False,
ps_lambda: float = 5.0,
patch_len_threshold: int = 64,
use_gdw: bool = True,
lr_adjust_strategy: str = "type1"
) -> Tuple[nn.Module, Dict[str, float]]:
"""
Train a time series forecasting model
@@ -241,6 +277,11 @@ def train_forecasting_model(
use_x_mark (bool): Whether to use time features (x_mark) from the data file
dataset_mode (str): Dataset construction mode - "npz" or "dataflow"
dataflow_args: Arguments object for dataflow mode (required if dataset_mode="dataflow")
use_ps_loss (bool): Whether to use Patch-wise Structural (PS) loss instead of MSE
ps_lambda (float): Weight for PS loss component when combined with MSE
patch_len_threshold (int): Maximum patch length for adaptive patching
use_gdw (bool): Whether to use Gradient-based Dynamic Weighting
lr_adjust_strategy (str): Learning rate adjustment strategy - 'type1', 'type2', 'type3', 'sigmoid', 'constant', '3', '4', '5', '6'

Returns:
Tuple[nn.Module, Dict[str, float]]: Trained model and dictionary of evaluation metrics
@@ -271,7 +312,10 @@ def train_forecasting_model(
dataloaders = create_data_loaders(
data_path=data_path,
batch_size=config.get('batch_size', 32),
use_x_mark=use_x_mark
use_x_mark=use_x_mark,
num_workers=config.get('num_workers', 4),
pin_memory=config.get('pin_memory', True),
persistent_workers=config.get('persistent_workers', True)
)

# Construct the model
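Note: the new loader options are all read from the run config via config.get defaults. A minimal sketch of a config that exercises them (key names follow this diff; the values and the NPZ path are illustrative only):

config = {
    'batch_size': 32,
    'learning_rate': 1e-3,
    'num_workers': 4,            # DataLoader worker processes
    'pin_memory': True,          # pin host memory for faster GPU transfer
    'persistent_workers': True,  # only takes effect when num_workers > 0
}

dataloaders = create_data_loaders(
    data_path="data/example.npz",  # illustrative path
    batch_size=config.get('batch_size', 32),
    use_x_mark=True,
    num_workers=config.get('num_workers', 4),
    pin_memory=config.get('pin_memory', True),
    persistent_workers=config.get('persistent_workers', True),
)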
@@ -279,14 +323,24 @@ def train_forecasting_model(
model = model.to(device)

# Define loss function and optimizer
criterion = nn.MSELoss()
if use_ps_loss:
criterion = PSLoss(
patch_len_threshold=patch_len_threshold,
lambda_ps=ps_lambda,
use_gdw=use_gdw
)
else:
criterion = nn.MSELoss()
optimizer = optim.Adam(
model.parameters(),
lr=config.get('learning_rate', 1e-3),
)

# Add learning rate scheduler to halve LR after each epoch
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
# Create args object for learning rate adjustment
lr_args = dotdict({
'learning_rate': config.get('learning_rate', 1e-3),
'lradj': lr_adjust_strategy
})

# Initialize early stopping
early_stopping = EarlyStopping(
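dotdict and adjust_learning_rate live in utils.tools, which this commit does not touch. A rough sketch of the assumed interface, for reading the hunk above; the real implementation covers all of the listed strategies and may index epochs differently:

class dotdict(dict):
    # Assumed helper: a dict whose keys can also be read as attributes (lr_args.lradj).
    __getattr__ = dict.get

def adjust_learning_rate(optimizer, epoch, args):
    # Assumed 'type1' behaviour: decay the base learning rate by half per epoch.
    if args.lradj == 'type1':
        lr = args.learning_rate * (0.5 ** epoch)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr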
@@ -334,7 +388,11 @@ def train_forecasting_model(
# For simple models without time features
outputs = model(inputs)

loss = criterion(outputs, targets)
# Calculate loss
if use_ps_loss:
loss, loss_dict = criterion(outputs, targets, model)
else:
loss = criterion(outputs, targets)

# Backward pass and optimize
loss.backward()
@@ -345,10 +403,26 @@ def train_forecasting_model(
interval_loss += loss.item()

if (batch_idx + 1) % log_interval == 0:
print(f"Batch {batch_idx+1}/{len(dataloaders['train'])}, Loss: {loss.item():.4f}")
# Compute the average loss over this interval and log it
avg_interval_loss = interval_loss / log_interval
swanlab_run.log({"batch_train_loss": avg_interval_loss})
if use_ps_loss and 'loss_dict' in locals():
print(f"Batch {batch_idx+1}/{len(dataloaders['train'])}, "
f"Total Loss: {loss.item():.4f}, "
f"MSE: {loss_dict['mse_loss']:.4f}, "
f"PS: {loss_dict['ps_loss']:.4f}")
# Log detailed loss components
swanlab_run.log({
"batch_total_loss": loss.item(),
"batch_mse_loss": loss_dict['mse_loss'],
"batch_ps_loss": loss_dict['ps_loss'],
"batch_corr_loss": loss_dict['corr_loss'],
"batch_var_loss": loss_dict['var_loss'],
"batch_mean_loss": loss_dict['mean_loss'],
"alpha": loss_dict['alpha'],
"beta": loss_dict['beta'],
"gamma": loss_dict['gamma']
})
else:
print(f"Batch {batch_idx+1}/{len(dataloaders['train'])}, Loss: {loss.item():.4f}")
swanlab_run.log({"batch_train_loss": loss.item()})

# Reset the interval loss for the next interval
interval_loss = 0.0
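PSLoss (imported from layers.ps_loss at the top of the file) is called as criterion(outputs, targets, model) and returns a total loss plus a component dictionary whose keys are logged above. A hedged stand-in showing just that call shape; the real layer computes patch-wise correlation, variance and mean terms and, with use_gdw, gradient-based dynamic weights alpha/beta/gamma:

import torch.nn as nn

class PSLossSketch(nn.Module):
    # Illustrative stand-in for layers.ps_loss.PSLoss; mirrors the constructor and
    # return shape used in this diff, with the patch-wise terms left as placeholders.
    def __init__(self, patch_len_threshold=64, lambda_ps=5.0, use_gdw=True):
        super().__init__()
        self.mse = nn.MSELoss()
        self.lambda_ps = lambda_ps

    def forward(self, outputs, targets, model=None):
        mse_loss = self.mse(outputs, targets)
        corr_loss = var_loss = mean_loss = outputs.new_zeros(())  # placeholder patch-wise terms
        alpha = beta = gamma = 1.0                                # placeholder GDW weights
        ps_loss = alpha * corr_loss + beta * var_loss + gamma * mean_loss
        total = mse_loss + self.lambda_ps * ps_loss
        return total, {'mse_loss': mse_loss.item(), 'ps_loss': ps_loss.item(),
                       'corr_loss': corr_loss.item(), 'var_loss': var_loss.item(),
                       'mean_loss': mean_loss.item(), 'alpha': alpha, 'beta': beta,
                       'gamma': gamma}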
@@ -360,6 +434,7 @@ def train_forecasting_model(
model.eval()
val_loss = 0.0
val_mse = 0.0
val_mse_criterion = nn.MSELoss() # Always use MSE for validation metrics

with torch.no_grad():
for batch_data in dataloaders['val']:
@@ -381,18 +456,28 @@ def train_forecasting_model(
# For simple models without time features
outputs = model(inputs)

# Calculate loss
loss = criterion(outputs, targets)
val_loss += loss.item()
# Calculate training loss (PS or MSE)
if use_ps_loss:
loss, _ = criterion(outputs, targets, model)
val_loss += loss.item()
else:
loss = criterion(outputs, targets)
val_loss += loss.item()

# Always calculate MSE for validation metrics
mse_loss = val_mse_criterion(outputs, targets)
val_mse += mse_loss.item()


avg_val_loss = val_loss / len(dataloaders['val'])
avg_val_mse = val_mse / len(dataloaders['val'])
current_lr = optimizer.param_groups[0]['lr']

# Log metrics
metrics_dict = {
"train_loss": avg_train_loss,
"val_loss": avg_val_loss,
"val_mse": avg_val_mse,
"learning_rate": current_lr,
"epoch_time": epoch_time
}
@@ -402,6 +487,7 @@ def train_forecasting_model(
print(f"Epoch {epoch+1}/{max_epochs}, "
f"Train Loss: {avg_train_loss:.4f}, "
f"Val Loss: {avg_val_loss:.4f}, "
f"Val MSE: {avg_val_mse:.4f}, "
f"LR: {current_lr:.6f}, "
f"Time: {epoch_time:.2f}s")

@@ -416,16 +502,17 @@ def train_forecasting_model(
print("Early stopping triggered")
break

# Step the learning rate scheduler
scheduler.step()
# Adjust learning rate using utils.tools function
adjust_learning_rate(optimizer, epoch, lr_args)

# Load the best model
model.load_state_dict(torch.load(checkpoint_path))

# Test evaluation on the best model
# Test evaluation on the best model - Always use MSE for final evaluation
model.eval()
test_loss = 0.0
test_mse = 0.0
mse_criterion = nn.MSELoss() # Always use MSE for test evaluation

print("Evaluating on test set...")
with torch.no_grad():
@@ -448,16 +535,16 @@ def train_forecasting_model(
# For simple models without time features
outputs = model(inputs)

# Calculate loss
loss = criterion(outputs, targets)
test_loss += loss.item()
# Always calculate MSE for test evaluation (for fair comparison)
mse_loss = mse_criterion(outputs, targets)
test_loss += mse_loss.item()

test_loss /= len(dataloaders['test'])

print(f"Test evaluation completed!")
print(f"Test Loss (MSE): {test_loss:.6f}")

# Final validation for consistency
# Final validation for consistency - Always use MSE for final metrics
model.eval()
final_val_loss = 0.0
final_val_mse = 0.0
@@ -482,25 +569,31 @@ def train_forecasting_model(
# For simple models without time features
outputs = model(inputs)

# Calculate loss
loss = criterion(outputs, targets)
final_val_loss += loss.item()
# Always calculate MSE for final validation (for fair comparison)
mse_loss = mse_criterion(outputs, targets)
final_val_loss += mse_loss.item()


final_val_loss /= len(dataloaders['val'])

print(f"Final validation loss: {final_val_loss:.6f}")
print(f"Final validation MSE: {final_val_loss:.6f}")
print(f"Final test MSE: {test_loss:.6f}")

if use_ps_loss:
print("Note: Model was trained with PS Loss but evaluated with MSE for fair comparison")

# Log final test results to swanlab
final_metrics = {
"final_test_loss": test_loss,
"final_val_loss": final_val_loss
"final_test_mse": test_loss,
"final_val_mse": final_val_loss
}
swanlab_run.log(final_metrics)

# Update metrics with final values
# Update metrics with final values (always MSE for comparison)
metrics["final_val_loss"] = final_val_loss
metrics["final_test_loss"] = test_loss
metrics["final_val_mse"] = final_val_loss # Same as final_val_loss since we use MSE
metrics["final_test_mse"] = test_loss # Same as final_test_loss since we use MSE

# Finish the swanlab run
swanlab_run.finish()
@@ -519,7 +612,8 @@ def train_classification_model(
log_interval: int = 10,
use_x_mark: bool = True,
dataset_mode: str = "npz",
dataflow_args = None
dataflow_args = None,
lr_adjust_strategy: str = "type1"
) -> Tuple[nn.Module, Dict[str, float]]:
"""
Train a time series classification model
@@ -537,6 +631,7 @@ def train_classification_model(
use_x_mark (bool): Whether to use time features (x_mark) from the data file
dataset_mode (str): Dataset construction mode - "npz" or "dataflow"
dataflow_args: Arguments object for dataflow mode (required if dataset_mode="dataflow")
lr_adjust_strategy (str): Learning rate adjustment strategy - 'type1', 'type2', 'type3', 'sigmoid', 'constant', '3', '4', '5', '6'

Returns:
Tuple[nn.Module, Dict[str, float]]: Trained model and dictionary of evaluation metrics
@@ -567,7 +662,10 @@ def train_classification_model(
dataloaders = create_data_loaders(
data_path=data_path,
batch_size=config.get('batch_size', 32),
use_x_mark=use_x_mark
use_x_mark=use_x_mark,
num_workers=config.get('num_workers', 4),
pin_memory=config.get('pin_memory', True),
persistent_workers=config.get('persistent_workers', True)
)

# Construct the model
@@ -582,8 +680,11 @@ def train_classification_model(
weight_decay=config.get('weight_decay', 1e-4)
)

# Add learning rate scheduler to halve LR after each epoch
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)
# Create args object for learning rate adjustment
lr_args = dotdict({
'learning_rate': config.get('learning_rate', 1e-3),
'lradj': lr_adjust_strategy
})

# Initialize early stopping
early_stopping = EarlyStopping(
@@ -722,8 +823,8 @@ def train_classification_model(
print("Early stopping triggered")
break

# Step the learning rate scheduler
scheduler.step()
# Adjust learning rate using utils.tools function
adjust_learning_rate(optimizer, epoch, lr_args)

# Load the best model
model.load_state_dict(torch.load(checkpoint_path))
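Putting the train/train.py changes together, a hedged usage sketch: use_ps_loss, ps_lambda, patch_len_threshold, use_gdw, lr_adjust_strategy and the config keys come from this diff, while the leading arguments are assumed to mirror train_diffusion_model below and the model constructor and data path are placeholders.

model, metrics = train_forecasting_model(
    model_constructor=lambda: MyForecaster(),   # placeholder model constructor
    data_path="data/example.npz",               # placeholder NPZ path
    project_name="forecasting_ps_loss",
    config={'batch_size': 32, 'learning_rate': 1e-3,
            'num_workers': 4, 'pin_memory': True, 'persistent_workers': True},
    use_ps_loss=True,            # train with PS loss instead of plain MSE
    ps_lambda=5.0,
    patch_len_threshold=64,
    use_gdw=True,
    lr_adjust_strategy='type1',  # handled by utils.tools.adjust_learning_rate
)
print(metrics['final_test_mse'])  # final metrics are always reported as MSE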

train/train_diffusion.py | 408 (new file)

@@ -0,0 +1,408 @@
import os
import time
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import swanlab
from typing import Dict, Any, Optional, Callable, Union, Tuple
from dataflow import data_provider

class EarlyStopping:
"""Early stopping to stop training when validation performance doesn't improve."""
def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt'):
self.patience = patience
self.verbose = verbose
self.counter = 0
self.best_score = None
self.early_stop = False
self.val_loss_min = float('inf')
self.delta = delta
self.path = path

def __call__(self, val_loss, model):
score = -val_loss

if self.best_score is None:
self.best_score = score
self.save_checkpoint(val_loss, model)
elif score < self.best_score + self.delta:
self.counter += 1
if self.verbose:
print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
if self.counter >= self.patience:
self.early_stop = True
else:
self.best_score = score
self.save_checkpoint(val_loss, model)
self.counter = 0

def save_checkpoint(self, val_loss, model):
"""Save model when validation loss decreases."""
if self.verbose:
print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model...')
torch.save(model.state_dict(), self.path)
self.val_loss_min = val_loss

class DatasetWrapperWithoutTimeFeatures(torch.utils.data.Dataset):
"""Wrapper to remove time features from dataflow datasets when use_x_mark=False"""
def __init__(self, original_dataset):
self.original_dataset = original_dataset

def __getitem__(self, index):
seq_x, seq_y, seq_x_mark, seq_y_mark = self.original_dataset[index]
return seq_x, seq_y

def __len__(self):
return len(self.original_dataset)

def inverse_transform(self, data):
if hasattr(self.original_dataset, 'inverse_transform'):
return self.original_dataset.inverse_transform(data)
return data

def create_data_loaders_from_dataflow(args, use_x_mark: bool = True) -> Dict[str, DataLoader]:
"""Create PyTorch DataLoaders using dataflow data_provider"""
train_data, _ = data_provider(args, flag='train')
val_data, _ = data_provider(args, flag='val')
test_data, _ = data_provider(args, flag='test')

if not use_x_mark:
train_data = DatasetWrapperWithoutTimeFeatures(train_data)
val_data = DatasetWrapperWithoutTimeFeatures(val_data)
test_data = DatasetWrapperWithoutTimeFeatures(test_data)

train_shuffle = True
val_shuffle = False
test_shuffle = False

train_drop_last = True
val_drop_last = True
test_drop_last = True

batch_size = args.batch_size
num_workers = args.num_workers

train_loader = DataLoader(
train_data, batch_size=batch_size, shuffle=train_shuffle,
num_workers=num_workers, drop_last=train_drop_last
)
val_loader = DataLoader(
val_data, batch_size=batch_size, shuffle=val_shuffle,
num_workers=num_workers, drop_last=val_drop_last
)
test_loader = DataLoader(
test_data, batch_size=batch_size, shuffle=test_shuffle,
num_workers=num_workers, drop_last=test_drop_last
)

return {'train': train_loader, 'val': val_loader, 'test': test_loader}

def create_data_loaders(data_path: str, batch_size: int = 32, use_x_mark: bool = True) -> Dict[str, DataLoader]:
"""
Create PyTorch DataLoaders from an NPZ file

Args:
data_path (str): Path to the NPZ file containing the data
batch_size (int): Batch size for the DataLoaders
use_x_mark (bool): Whether to use time features (x_mark) from the data file

Returns:
Dict[str, DataLoader]: Dictionary with train, val, and test DataLoaders
"""
# Load data from NPZ file
data = np.load(data_path, allow_pickle=True)
train_x = data['train_x']
train_y = data['train_y']
val_x = data['val_x']
val_y = data['val_y']
test_x = data['test_x']
test_y = data['test_y']

# Load time features if available and needed
if use_x_mark:
train_x_mark = data.get('train_x_mark', None)
train_y_mark = data.get('train_y_mark', None)
val_x_mark = data.get('val_x_mark', None)
val_y_mark = data.get('val_y_mark', None)
test_x_mark = data.get('test_x_mark', None)
test_y_mark = data.get('test_y_mark', None)
else:
train_x_mark = None
train_y_mark = None
val_x_mark = None
val_y_mark = None
test_x_mark = None
test_y_mark = None

# Convert to PyTorch tensors
train_x = torch.FloatTensor(train_x)
train_y = torch.FloatTensor(train_y)
val_x = torch.FloatTensor(val_x)
val_y = torch.FloatTensor(val_y)
test_x = torch.FloatTensor(test_x)
test_y = torch.FloatTensor(test_y)

# Create datasets based on whether time features are available
if train_x_mark is not None:
train_x_mark = torch.FloatTensor(train_x_mark)
train_y_mark = torch.FloatTensor(train_y_mark)
val_x_mark = torch.FloatTensor(val_x_mark)
val_y_mark = torch.FloatTensor(val_y_mark)
test_x_mark = torch.FloatTensor(test_x_mark)
test_y_mark = torch.FloatTensor(test_y_mark)

train_dataset = TensorDataset(train_x, train_y, train_x_mark, train_y_mark)
val_dataset = TensorDataset(val_x, val_y, val_x_mark, val_y_mark)
test_dataset = TensorDataset(test_x, test_y, test_x_mark, test_y_mark)
else:
train_dataset = TensorDataset(train_x, train_y)
val_dataset = TensorDataset(val_x, val_y)
test_dataset = TensorDataset(test_x, test_y)

# Create dataloaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

return {
'train': train_loader,
'val': val_loader,
'test': test_loader
}

def train_diffusion_model(
model_constructor: Callable,
data_path: str,
project_name: str,
config: Dict[str, Any],
device: Optional[str] = None,
early_stopping_patience: int = 10,
max_epochs: int = 100,
checkpoint_dir: str = "./checkpoints",
log_interval: int = 10,
) -> Tuple[nn.Module, Dict[str, float]]:
"""
Train a Diffusion time series forecasting model using NPZ data loading
"""
# Setup device
if device is None:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Initialize swanlab for experiment tracking
swanlab_run = swanlab.init(
project=project_name,
config=config,
)

# Create checkpoint directory if it doesn't exist
os.makedirs(checkpoint_dir, exist_ok=True)
checkpoint_path = os.path.join(checkpoint_dir, f"{project_name}.pt")

# Create data loaders using NPZ files (following other models' pattern)
dataloaders = create_data_loaders(
data_path=data_path,
batch_size=config.get('batch_size', 32),
use_x_mark=False # DiffusionTimeSeries doesn't use time features
)

# Construct the model
model = model_constructor()
model = model.to(device)

print(f"Model created with {model.get_num_params():,} parameters")

# Define optimizer for diffusion training
optimizer = optim.Adam(
model.parameters(),
lr=config.get('learning_rate', 1e-4),
weight_decay=config.get('weight_decay', 1e-4)
)

# Learning rate scheduler
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
optimizer, mode='min', patience=5, factor=0.5
)

# Initialize early stopping
early_stopping = EarlyStopping(
patience=early_stopping_patience,
verbose=True,
path=checkpoint_path
)

# Training loop
best_val_loss = float('inf')
metrics = {}

for epoch in range(max_epochs):
print(f"\nEpoch {epoch+1}/{max_epochs}")
print("-" * 50)

# Training phase
model.train()
train_loss = 0.0
train_samples = 0

interval_loss = 0.0
start_time = time.time()

for batch_idx, (seq_x, seq_y) in enumerate(dataloaders['train']):
seq_x, seq_y = seq_x.to(device), seq_y.to(device)

optimizer.zero_grad()

# Diffusion training: model returns loss directly when y is provided
loss = model(seq_x, seq_y)

loss.backward()
optimizer.step()

train_loss += loss.item()
interval_loss += loss.item()
train_samples += 1

# Log at intervals
if (batch_idx + 1) % log_interval == 0:
elapsed_time = time.time() - start_time
avg_interval_loss = interval_loss / log_interval

print(f' Batch [{batch_idx+1}/{len(dataloaders["train"])}] '
f'Loss: {avg_interval_loss:.6f} '
f'Time: {elapsed_time:.2f}s')

# Log to swanlab
swanlab.log({
'batch_loss': avg_interval_loss,
'batch': epoch * len(dataloaders['train']) + batch_idx,
'learning_rate': optimizer.param_groups[0]['lr']
})

interval_loss = 0.0
start_time = time.time()

avg_train_loss = train_loss / train_samples

# Validation phase - Use faster sampling for validation
model.eval()
val_loss = 0.0
val_samples = 0
criterion = nn.MSELoss()

print(" Validating...")
with torch.no_grad():
# Temporarily reduce diffusion steps for faster validation
original_timesteps = model.diffusion.num_timesteps
model.diffusion.num_timesteps = 200 # Much faster validation

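The validation and test passes temporarily overwrite model.diffusion.num_timesteps and restore it afterwards. The same idea can be wrapped in a small context manager so the original value is restored even if evaluation raises; this helper is not part of the commit and only assumes the diffusion.num_timesteps attribute used above:

from contextlib import contextmanager

@contextmanager
def reduced_timesteps(model, steps=200):
    # Temporarily shorten the diffusion sampling schedule for faster evaluation.
    original = model.diffusion.num_timesteps
    model.diffusion.num_timesteps = steps
    try:
        yield model
    finally:
        model.diffusion.num_timesteps = original

# Usage sketch:
# with torch.no_grad(), reduced_timesteps(model, steps=200):
#     pred = model(seq_x)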
for batch_idx, (seq_x, seq_y) in enumerate(dataloaders['val']):
seq_x, seq_y = seq_x.to(device), seq_y.to(device)

# Generate predictions (inference mode with reduced steps)
pred = model(seq_x)

# Compute MSE loss for validation
loss = criterion(pred, seq_y)
val_loss += loss.item()
val_samples += 1

# Print validation progress for first epoch
if epoch == 0 and (batch_idx + 1) % 50 == 0:
print(f" Val Batch [{batch_idx+1}/{len(dataloaders['val'])}]")

# Early break for very first epoch to speed up
if epoch == 0 and batch_idx >= 100: # Only validate on first 100 batches for first epoch
break

# Restore original timesteps
model.diffusion.num_timesteps = original_timesteps

avg_val_loss = val_loss / val_samples

# Learning rate scheduling
scheduler.step(avg_val_loss)
current_lr = optimizer.param_groups[0]['lr']

print(f" Train Loss: {avg_train_loss:.6f}")
print(f" Val Loss: {avg_val_loss:.6f}")
print(f" Learning Rate: {current_lr:.2e}")

# Log to swanlab
swanlab.log({
'epoch': epoch + 1,
'train_loss': avg_train_loss,
'val_loss': avg_val_loss,
'learning_rate': current_lr
})

# Early stopping check
early_stopping(avg_val_loss, model)

if early_stopping.early_stop:
print(f"Early stopping at epoch {epoch + 1}")
break

# Load best model
model.load_state_dict(torch.load(checkpoint_path, map_location=device))

# Final evaluation on test set
print("\nEvaluating on test set...")
model.eval()
test_loss = 0.0
test_samples = 0
all_preds = []
all_targets = []

with torch.no_grad():
# Use reduced timesteps for faster testing
original_timesteps = model.diffusion.num_timesteps
model.diffusion.num_timesteps = 200 # Faster but still good quality

for batch_idx, (seq_x, seq_y) in enumerate(dataloaders['test']):
seq_x, seq_y = seq_x.to(device), seq_y.to(device)

pred = model(seq_x)

loss = criterion(pred, seq_y)
test_loss += loss.item()
test_samples += 1

all_preds.append(pred.cpu().numpy())
all_targets.append(seq_y.cpu().numpy())

# Print progress every 50 batches
if (batch_idx + 1) % 50 == 0:
print(f" Test Batch [{batch_idx+1}/{len(dataloaders['test'])}]")

# Restore original timesteps
model.diffusion.num_timesteps = original_timesteps

avg_test_loss = test_loss / test_samples

# Calculate additional metrics
all_preds = np.concatenate(all_preds, axis=0)
all_targets = np.concatenate(all_targets, axis=0)

mse = np.mean((all_preds - all_targets) ** 2)
mae = np.mean(np.abs(all_preds - all_targets))
rmse = np.sqrt(mse)

metrics = {
'test_mse': mse,
'test_mae': mae,
'test_rmse': rmse,
'test_loss': avg_test_loss
}

print(f"Test Results:")
print(f" MSE: {mse:.6f}")
print(f" MAE: {mae:.6f}")
print(f" RMSE: {rmse:.6f}")

# Log final results
swanlab.log(metrics)

swanlab.finish()

return model, metrics
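A hedged usage sketch of the new trainer. train_diffusion_model and its config keys come from this file; it expects the constructed model to return a training loss from model(seq_x, seq_y), predictions from model(seq_x), and to expose get_num_params() and diffusion.num_timesteps. The import path and constructor arguments for DiffusionTimeSeries are assumptions.

def build_model():
    from models import DiffusionTimeSeries  # assumed import path; adjust to the repo layout
    return DiffusionTimeSeries()            # constructor arguments omitted here

model, metrics = train_diffusion_model(
    model_constructor=build_model,
    data_path="data/example.npz",       # placeholder NPZ path
    project_name="diffusion_forecast",
    config={'batch_size': 32, 'learning_rate': 1e-4, 'weight_decay': 1e-4},
    early_stopping_patience=10,
    max_epochs=100,
)
print(metrics)  # {'test_mse': ..., 'test_mae': ..., 'test_rmse': ..., 'test_loss': ...}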