feat: add DiffusionTimeSeries and iTransformer models, introduce xPatch_SparseChannel

This commit is contained in:
game-loader
2025-08-26 20:53:35 +08:00
parent 44bd5c8f29
commit c3713f5c0b
11 changed files with 1528 additions and 41 deletions


@@ -8,6 +8,8 @@ from torch.utils.data import DataLoader, TensorDataset
import swanlab
from typing import Dict, Any, Optional, Callable, Union, Tuple
from dataflow import data_provider
from layers.ps_loss import PSLoss
from utils.tools import adjust_learning_rate, dotdict
class EarlyStopping:
"""Early stopping to stop training when validation performance doesn't improve."""
@@ -138,7 +140,9 @@ def create_data_loaders_from_dataflow(args, use_x_mark: bool = True) -> Dict[str
'test': test_loader
}
def create_data_loaders(data_path: str, batch_size: int = 32, use_x_mark: bool = True) -> Dict[str, DataLoader]:
def create_data_loaders(data_path: str, batch_size: int = 32, use_x_mark: bool = True,
num_workers: int = 4, pin_memory: bool = True,
persistent_workers: bool = True) -> Dict[str, DataLoader]:
"""
Create PyTorch DataLoaders from an NPZ file
@@ -146,6 +150,9 @@ def create_data_loaders(data_path: str, batch_size: int = 32, use_x_mark: bool =
data_path (str): Path to the NPZ file containing the data
batch_size (int): Batch size for the DataLoaders
use_x_mark (bool): Whether to use time features (x_mark) from the data file
num_workers (int): Number of worker processes for data loading
pin_memory (bool): Whether to pin memory for faster GPU transfer
persistent_workers (bool): Whether to keep workers alive between epochs
Returns:
Dict[str, DataLoader]: Dictionary with train, val, and test DataLoaders
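For context, a minimal usage sketch of the extended signature; the NPZ path and the values shown are placeholders, not taken from this repository:

loaders = create_data_loaders(
    data_path="./data/dataset.npz",   # placeholder path to a preprocessed NPZ file
    batch_size=64,
    use_x_mark=True,
    num_workers=4,                    # parallel worker processes for loading
    pin_memory=True,                  # page-locked host memory for faster GPU transfers
    persistent_workers=True,          # only takes effect when num_workers > 0
)
train_loader = loaders['train']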
@@ -200,10 +207,34 @@ def create_data_loaders(data_path: str, batch_size: int = 32, use_x_mark: bool =
val_dataset = TensorDataset(val_x, val_y)
test_dataset = TensorDataset(test_x, test_y)
# Create dataloaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
# Create dataloaders with performance optimizations
train_loader = DataLoader(
train_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=num_workers,
pin_memory=pin_memory,
persistent_workers=persistent_workers if num_workers > 0 else False,
drop_last=True # Drop incomplete batches for training
)
val_loader = DataLoader(
val_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=num_workers,
pin_memory=pin_memory,
persistent_workers=persistent_workers if num_workers > 0 else False,
drop_last=False
)
test_loader = DataLoader(
test_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=num_workers,
pin_memory=pin_memory,
persistent_workers=persistent_workers if num_workers > 0 else False,
drop_last=False
)
return {
'train': train_loader,
@@ -223,7 +254,12 @@ def train_forecasting_model(
log_interval: int = 10,
use_x_mark: bool = True,
dataset_mode: str = "npz",
dataflow_args = None
dataflow_args = None,
use_ps_loss: bool = False,
ps_lambda: float = 5.0,
patch_len_threshold: int = 64,
use_gdw: bool = True,
lr_adjust_strategy: str = "type1"
) -> Tuple[nn.Module, Dict[str, float]]:
"""
Train a time series forecasting model
@@ -241,6 +277,11 @@ def train_forecasting_model(
use_x_mark (bool): Whether to use time features (x_mark) from the data file
dataset_mode (str): Dataset construction mode - "npz" or "dataflow"
dataflow_args: Arguments object for dataflow mode (required if dataset_mode="dataflow")
use_ps_loss (bool): Whether to use Patch-wise Structural (PS) loss instead of MSE
ps_lambda (float): Weight for PS loss component when combined with MSE
patch_len_threshold (int): Maximum patch length for adaptive patching
use_gdw (bool): Whether to use Gradient-based Dynamic Weighting
lr_adjust_strategy (str): Learning rate adjustment strategy - 'type1', 'type2', 'type3', 'sigmoid', 'constant', '3', '4', '5', '6'
Returns:
Tuple[nn.Module, Dict[str, float]]: Trained model and dictionary of evaluation metrics
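A hedged sketch of how the new options might be passed; the leading arguments (model constructor, data path, project name, config) are assumed from the surrounding trainer code, and all concrete values below are placeholders:

model, metrics = train_forecasting_model(
    model_constructor=make_model,          # hypothetical callable returning an nn.Module
    data_path="./data/dataset.npz",        # placeholder NPZ path
    project_name="ps_loss_experiment",     # placeholder swanlab project name
    config={'batch_size': 32, 'learning_rate': 1e-3, 'num_workers': 4},
    use_ps_loss=True,                      # train with PSLoss instead of plain MSE
    ps_lambda=5.0,                         # weight of the PS component
    patch_len_threshold=64,                # cap for adaptive patch length
    use_gdw=True,                          # gradient-based dynamic weighting
    lr_adjust_strategy='type1',            # handled by utils.tools.adjust_learning_rate
)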
@@ -271,7 +312,10 @@ def train_forecasting_model(
dataloaders = create_data_loaders(
data_path=data_path,
batch_size=config.get('batch_size', 32),
use_x_mark=use_x_mark
use_x_mark=use_x_mark,
num_workers=config.get('num_workers', 4),
pin_memory=config.get('pin_memory', True),
persistent_workers=config.get('persistent_workers', True)
)
# Construct the model
@@ -279,14 +323,24 @@ def train_forecasting_model(
model = model.to(device)
# Define loss function and optimizer
criterion = nn.MSELoss()
if use_ps_loss:
criterion = PSLoss(
patch_len_threshold=patch_len_threshold,
lambda_ps=ps_lambda,
use_gdw=use_gdw
)
else:
criterion = nn.MSELoss()
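# Note: unlike nn.MSELoss, PSLoss is invoked as criterion(outputs, targets, model)
# in the training loop below and returns (total_loss, loss_components_dict).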
optimizer = optim.Adam(
model.parameters(),
lr=config.get('learning_rate', 1e-3),
)
# Add learning rate scheduler to halve LR after each epoch
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
# Create args object for learning rate adjustment
lr_args = dotdict({
'learning_rate': config.get('learning_rate', 1e-3),
'lradj': lr_adjust_strategy
})
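For reference only, a sketch of the conventional 'type1' behaviour of a Time-Series-Library-style adjust_learning_rate helper; this is an assumption, since utils/tools.py is not part of this diff:

def adjust_learning_rate_type1_sketch(optimizer, epoch, args):
    # Assumed 'type1' rule: halve the base learning rate every epoch (epoch taken as 1-indexed).
    lr = args.learning_rate * (0.5 ** (epoch - 1))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    print(f'Updating learning rate to {lr}')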
# Initialize early stopping
early_stopping = EarlyStopping(
@@ -334,7 +388,11 @@ def train_forecasting_model(
# For simple models without time features
outputs = model(inputs)
loss = criterion(outputs, targets)
# Calculate loss
if use_ps_loss:
loss, loss_dict = criterion(outputs, targets, model)
else:
loss = criterion(outputs, targets)
# Backward pass and optimize
loss.backward()
@@ -345,10 +403,26 @@ def train_forecasting_model(
interval_loss += loss.item()
if (batch_idx + 1) % log_interval == 0:
print(f"Batch {batch_idx+1}/{len(dataloaders['train'])}, Loss: {loss.item():.4f}")
# Compute the average loss over this interval and log it
avg_interval_loss = interval_loss / log_interval
swanlab_run.log({"batch_train_loss": avg_interval_loss})
if use_ps_loss and 'loss_dict' in locals():
print(f"Batch {batch_idx+1}/{len(dataloaders['train'])}, "
f"Total Loss: {loss.item():.4f}, "
f"MSE: {loss_dict['mse_loss']:.4f}, "
f"PS: {loss_dict['ps_loss']:.4f}")
# Log detailed loss components
swanlab_run.log({
"batch_total_loss": loss.item(),
"batch_mse_loss": loss_dict['mse_loss'],
"batch_ps_loss": loss_dict['ps_loss'],
"batch_corr_loss": loss_dict['corr_loss'],
"batch_var_loss": loss_dict['var_loss'],
"batch_mean_loss": loss_dict['mean_loss'],
"alpha": loss_dict['alpha'],
"beta": loss_dict['beta'],
"gamma": loss_dict['gamma']
})
else:
print(f"Batch {batch_idx+1}/{len(dataloaders['train'])}, Loss: {loss.item():.4f}")
swanlab_run.log({"batch_train_loss": loss.item()})
# Reset the interval loss for the next calculation
interval_loss = 0.0
@@ -360,6 +434,7 @@ def train_forecasting_model(
model.eval()
val_loss = 0.0
val_mse = 0.0
val_mse_criterion = nn.MSELoss() # Always use MSE for validation metrics
with torch.no_grad():
for batch_data in dataloaders['val']:
@@ -381,18 +456,28 @@ def train_forecasting_model(
# For simple models without time features
outputs = model(inputs)
# Calculate loss
loss = criterion(outputs, targets)
val_loss += loss.item()
# Calculate training loss (PS or MSE)
if use_ps_loss:
loss, _ = criterion(outputs, targets, model)
val_loss += loss.item()
else:
loss = criterion(outputs, targets)
val_loss += loss.item()
# Always calculate MSE for validation metrics
mse_loss = val_mse_criterion(outputs, targets)
val_mse += mse_loss.item()
avg_val_loss = val_loss / len(dataloaders['val'])
avg_val_mse = val_mse / len(dataloaders['val'])
current_lr = optimizer.param_groups[0]['lr']
# Log metrics
metrics_dict = {
"train_loss": avg_train_loss,
"val_loss": avg_val_loss,
"val_mse": avg_val_mse,
"learning_rate": current_lr,
"epoch_time": epoch_time
}
@@ -402,6 +487,7 @@ def train_forecasting_model(
print(f"Epoch {epoch+1}/{max_epochs}, "
f"Train Loss: {avg_train_loss:.4f}, "
f"Val Loss: {avg_val_loss:.4f}, "
f"Val MSE: {avg_val_mse:.4f}, "
f"LR: {current_lr:.6f}, "
f"Time: {epoch_time:.2f}s")
@@ -416,16 +502,17 @@ def train_forecasting_model(
print("Early stopping triggered")
break
# Step the learning rate scheduler
scheduler.step()
# Adjust learning rate using utils.tools function
adjust_learning_rate(optimizer, epoch, lr_args)
# Load the best model
model.load_state_dict(torch.load(checkpoint_path))
# Test evaluation on the best model
# Test evaluation on the best model - Always use MSE for final evaluation
model.eval()
test_loss = 0.0
test_mse = 0.0
mse_criterion = nn.MSELoss() # Always use MSE for test evaluation
print("Evaluating on test set...")
with torch.no_grad():
@@ -448,16 +535,16 @@ def train_forecasting_model(
# For simple models without time features
outputs = model(inputs)
# Calculate loss
loss = criterion(outputs, targets)
test_loss += loss.item()
# Always calculate MSE for test evaluation (for fair comparison)
mse_loss = mse_criterion(outputs, targets)
test_loss += mse_loss.item()
test_loss /= len(dataloaders['test'])
print(f"Test evaluation completed!")
print(f"Test Loss (MSE): {test_loss:.6f}")
# Final validation for consistency
# Final validation for consistency - Always use MSE for final metrics
model.eval()
final_val_loss = 0.0
final_val_mse = 0.0
@@ -482,25 +569,31 @@ def train_forecasting_model(
# For simple models without time features
outputs = model(inputs)
# Calculate loss
loss = criterion(outputs, targets)
final_val_loss += loss.item()
# Always calculate MSE for final validation (for fair comparison)
mse_loss = mse_criterion(outputs, targets)
final_val_loss += mse_loss.item()
final_val_loss /= len(dataloaders['val'])
print(f"Final validation loss: {final_val_loss:.6f}")
print(f"Final validation MSE: {final_val_loss:.6f}")
print(f"Final test MSE: {test_loss:.6f}")
if use_ps_loss:
print("Note: Model was trained with PS Loss but evaluated with MSE for fair comparison")
# Log final test results to swanlab
final_metrics = {
"final_test_loss": test_loss,
"final_val_loss": final_val_loss
"final_test_mse": test_loss,
"final_val_mse": final_val_loss
}
swanlab_run.log(final_metrics)
# Update metrics with final values
# Update metrics with final values (always MSE for comparison)
metrics["final_val_loss"] = final_val_loss
metrics["final_test_loss"] = test_loss
metrics["final_val_mse"] = final_val_loss # Same as final_val_loss since we use MSE
metrics["final_test_mse"] = test_loss # Same as final_test_loss since we use MSE
# Finish the swanlab run
swanlab_run.finish()
@@ -519,7 +612,8 @@ def train_classification_model(
log_interval: int = 10,
use_x_mark: bool = True,
dataset_mode: str = "npz",
dataflow_args = None
dataflow_args = None,
lr_adjust_strategy: str = "type1"
) -> Tuple[nn.Module, Dict[str, float]]:
"""
Train a time series classification model
@@ -537,6 +631,7 @@ def train_classification_model(
use_x_mark (bool): Whether to use time features (x_mark) from the data file
dataset_mode (str): Dataset construction mode - "npz" or "dataflow"
dataflow_args: Arguments object for dataflow mode (required if dataset_mode="dataflow")
lr_adjust_strategy (str): Learning rate adjustment strategy - 'type1', 'type2', 'type3', 'sigmoid', 'constant', '3', '4', '5', '6'
Returns:
Tuple[nn.Module, Dict[str, float]]: Trained model and dictionary of evaluation metrics
@@ -567,7 +662,10 @@ def train_classification_model(
dataloaders = create_data_loaders(
data_path=data_path,
batch_size=config.get('batch_size', 32),
use_x_mark=use_x_mark
use_x_mark=use_x_mark,
num_workers=config.get('num_workers', 4),
pin_memory=config.get('pin_memory', True),
persistent_workers=config.get('persistent_workers', True)
)
# Construct the model
@@ -582,8 +680,11 @@ def train_classification_model(
weight_decay=config.get('weight_decay', 1e-4)
)
# Add learning rate scheduler to halve LR after each epoch
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)
# Create args object for learning rate adjustment
lr_args = dotdict({
'learning_rate': config.get('learning_rate', 1e-3),
'lradj': lr_adjust_strategy
})
# Initialize early stopping
early_stopping = EarlyStopping(
@@ -722,8 +823,8 @@ def train_classification_model(
print("Early stopping triggered")
break
# Step the learning rate scheduler
scheduler.step()
# Adjust learning rate using utils.tools function
adjust_learning_rate(optimizer, epoch, lr_args)
# Load the best model
model.load_state_dict(torch.load(checkpoint_path))

train/train_diffusion.py (new file, 408 lines)

@@ -0,0 +1,408 @@
import os
import time
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import swanlab
from typing import Dict, Any, Optional, Callable, Union, Tuple
from dataflow import data_provider
class EarlyStopping:
"""Early stopping to stop training when validation performance doesn't improve."""
def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt'):
self.patience = patience
self.verbose = verbose
self.counter = 0
self.best_score = None
self.early_stop = False
self.val_loss_min = float('inf')
self.delta = delta
self.path = path
def __call__(self, val_loss, model):
score = -val_loss
if self.best_score is None:
self.best_score = score
self.save_checkpoint(val_loss, model)
elif score < self.best_score + self.delta:
self.counter += 1
if self.verbose:
print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
if self.counter >= self.patience:
self.early_stop = True
else:
self.best_score = score
self.save_checkpoint(val_loss, model)
self.counter = 0
def save_checkpoint(self, val_loss, model):
"""Save model when validation loss decreases."""
if self.verbose:
print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model...')
torch.save(model.state_dict(), self.path)
self.val_loss_min = val_loss
class DatasetWrapperWithoutTimeFeatures(torch.utils.data.Dataset):
"""Wrapper to remove time features from dataflow datasets when use_x_mark=False"""
def __init__(self, original_dataset):
self.original_dataset = original_dataset
def __getitem__(self, index):
seq_x, seq_y, seq_x_mark, seq_y_mark = self.original_dataset[index]
return seq_x, seq_y
def __len__(self):
return len(self.original_dataset)
def inverse_transform(self, data):
if hasattr(self.original_dataset, 'inverse_transform'):
return self.original_dataset.inverse_transform(data)
return data
def create_data_loaders_from_dataflow(args, use_x_mark: bool = True) -> Dict[str, DataLoader]:
"""Create PyTorch DataLoaders using dataflow data_provider"""
train_data, _ = data_provider(args, flag='train')
val_data, _ = data_provider(args, flag='val')
test_data, _ = data_provider(args, flag='test')
if not use_x_mark:
train_data = DatasetWrapperWithoutTimeFeatures(train_data)
val_data = DatasetWrapperWithoutTimeFeatures(val_data)
test_data = DatasetWrapperWithoutTimeFeatures(test_data)
train_shuffle = True
val_shuffle = False
test_shuffle = False
train_drop_last = True
val_drop_last = True
test_drop_last = True
batch_size = args.batch_size
num_workers = args.num_workers
train_loader = DataLoader(
train_data, batch_size=batch_size, shuffle=train_shuffle,
num_workers=num_workers, drop_last=train_drop_last
)
val_loader = DataLoader(
val_data, batch_size=batch_size, shuffle=val_shuffle,
num_workers=num_workers, drop_last=val_drop_last
)
test_loader = DataLoader(
test_data, batch_size=batch_size, shuffle=test_shuffle,
num_workers=num_workers, drop_last=test_drop_last
)
return {'train': train_loader, 'val': val_loader, 'test': test_loader}
def create_data_loaders(data_path: str, batch_size: int = 32, use_x_mark: bool = True) -> Dict[str, DataLoader]:
"""
Create PyTorch DataLoaders from an NPZ file
Args:
data_path (str): Path to the NPZ file containing the data
batch_size (int): Batch size for the DataLoaders
use_x_mark (bool): Whether to use time features (x_mark) from the data file
Returns:
Dict[str, DataLoader]: Dictionary with train, val, and test DataLoaders
"""
# Load data from NPZ file
data = np.load(data_path, allow_pickle=True)
train_x = data['train_x']
train_y = data['train_y']
val_x = data['val_x']
val_y = data['val_y']
test_x = data['test_x']
test_y = data['test_y']
# Load time features if available and needed
if use_x_mark:
train_x_mark = data.get('train_x_mark', None)
train_y_mark = data.get('train_y_mark', None)
val_x_mark = data.get('val_x_mark', None)
val_y_mark = data.get('val_y_mark', None)
test_x_mark = data.get('test_x_mark', None)
test_y_mark = data.get('test_y_mark', None)
else:
train_x_mark = None
train_y_mark = None
val_x_mark = None
val_y_mark = None
test_x_mark = None
test_y_mark = None
# Convert to PyTorch tensors
train_x = torch.FloatTensor(train_x)
train_y = torch.FloatTensor(train_y)
val_x = torch.FloatTensor(val_x)
val_y = torch.FloatTensor(val_y)
test_x = torch.FloatTensor(test_x)
test_y = torch.FloatTensor(test_y)
# Create datasets based on whether time features are available
if train_x_mark is not None:
train_x_mark = torch.FloatTensor(train_x_mark)
train_y_mark = torch.FloatTensor(train_y_mark)
val_x_mark = torch.FloatTensor(val_x_mark)
val_y_mark = torch.FloatTensor(val_y_mark)
test_x_mark = torch.FloatTensor(test_x_mark)
test_y_mark = torch.FloatTensor(test_y_mark)
train_dataset = TensorDataset(train_x, train_y, train_x_mark, train_y_mark)
val_dataset = TensorDataset(val_x, val_y, val_x_mark, val_y_mark)
test_dataset = TensorDataset(test_x, test_y, test_x_mark, test_y_mark)
else:
train_dataset = TensorDataset(train_x, train_y)
val_dataset = TensorDataset(val_x, val_y)
test_dataset = TensorDataset(test_x, test_y)
# Create dataloaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
return {
'train': train_loader,
'val': val_loader,
'test': test_loader
}
def train_diffusion_model(
model_constructor: Callable,
data_path: str,
project_name: str,
config: Dict[str, Any],
device: Optional[str] = None,
early_stopping_patience: int = 10,
max_epochs: int = 100,
checkpoint_dir: str = "./checkpoints",
log_interval: int = 10,
) -> Tuple[nn.Module, Dict[str, float]]:
"""
Train a Diffusion time series forecasting model using NPZ data loading
"""
# Setup device
if device is None:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Initialize swanlab for experiment tracking
swanlab_run = swanlab.init(
project=project_name,
config=config,
)
# Create checkpoint directory if it doesn't exist
os.makedirs(checkpoint_dir, exist_ok=True)
checkpoint_path = os.path.join(checkpoint_dir, f"{project_name}.pt")
# Create data loaders using NPZ files (following other models' pattern)
dataloaders = create_data_loaders(
data_path=data_path,
batch_size=config.get('batch_size', 32),
use_x_mark=False # DiffusionTimeSeries doesn't use time features
)
# Construct the model
model = model_constructor()
model = model.to(device)
print(f"Model created with {model.get_num_params():,} parameters")
# Define optimizer for diffusion training
optimizer = optim.Adam(
model.parameters(),
lr=config.get('learning_rate', 1e-4),
weight_decay=config.get('weight_decay', 1e-4)
)
# Learning rate scheduler
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
optimizer, mode='min', patience=5, factor=0.5
)
# Initialize early stopping
early_stopping = EarlyStopping(
patience=early_stopping_patience,
verbose=True,
path=checkpoint_path
)
# Training loop
best_val_loss = float('inf')
metrics = {}
for epoch in range(max_epochs):
print(f"\nEpoch {epoch+1}/{max_epochs}")
print("-" * 50)
# Training phase
model.train()
train_loss = 0.0
train_samples = 0
interval_loss = 0.0
start_time = time.time()
for batch_idx, (seq_x, seq_y) in enumerate(dataloaders['train']):
seq_x, seq_y = seq_x.to(device), seq_y.to(device)
optimizer.zero_grad()
# Diffusion training: model returns loss directly when y is provided
loss = model(seq_x, seq_y)
loss.backward()
optimizer.step()
train_loss += loss.item()
interval_loss += loss.item()
train_samples += 1
# Log at intervals
if (batch_idx + 1) % log_interval == 0:
elapsed_time = time.time() - start_time
avg_interval_loss = interval_loss / log_interval
print(f' Batch [{batch_idx+1}/{len(dataloaders["train"])}] '
f'Loss: {avg_interval_loss:.6f} '
f'Time: {elapsed_time:.2f}s')
# Log to swanlab
swanlab.log({
'batch_loss': avg_interval_loss,
'batch': epoch * len(dataloaders['train']) + batch_idx,
'learning_rate': optimizer.param_groups[0]['lr']
})
interval_loss = 0.0
start_time = time.time()
avg_train_loss = train_loss / train_samples
# Validation phase - Use faster sampling for validation
model.eval()
val_loss = 0.0
val_samples = 0
criterion = nn.MSELoss()
print(" Validating...")
with torch.no_grad():
# Temporarily reduce diffusion steps for faster validation
original_timesteps = model.diffusion.num_timesteps
model.diffusion.num_timesteps = 200  # Much faster validation
for batch_idx, (seq_x, seq_y) in enumerate(dataloaders['val']):
seq_x, seq_y = seq_x.to(device), seq_y.to(device)
# Generate predictions (inference mode with reduced steps)
pred = model(seq_x)
# Compute MSE loss for validation
loss = criterion(pred, seq_y)
val_loss += loss.item()
val_samples += 1
# Print validation progress for first epoch
if epoch == 0 and (batch_idx + 1) % 50 == 0:
print(f" Val Batch [{batch_idx+1}/{len(dataloaders['val'])}]")
# Early break for very first epoch to speed up
if epoch == 0 and batch_idx >= 100: # Only validate on first 100 batches for first epoch
break
# Restore original timesteps
model.diffusion.num_timesteps = original_timesteps
avg_val_loss = val_loss / val_samples
# Learning rate scheduling
scheduler.step(avg_val_loss)
current_lr = optimizer.param_groups[0]['lr']
print(f" Train Loss: {avg_train_loss:.6f}")
print(f" Val Loss: {avg_val_loss:.6f}")
print(f" Learning Rate: {current_lr:.2e}")
# Log to swanlab
swanlab.log({
'epoch': epoch + 1,
'train_loss': avg_train_loss,
'val_loss': avg_val_loss,
'learning_rate': current_lr
})
# Early stopping check
early_stopping(avg_val_loss, model)
if early_stopping.early_stop:
print(f"Early stopping at epoch {epoch + 1}")
break
# Load best model
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
# Final evaluation on test set
print("\nEvaluating on test set...")
model.eval()
test_loss = 0.0
test_samples = 0
all_preds = []
all_targets = []
with torch.no_grad():
# Use reduced timesteps for faster testing
original_timesteps = model.diffusion.num_timesteps
model.diffusion.num_timesteps = 200 # Faster but still good quality
for batch_idx, (seq_x, seq_y) in enumerate(dataloaders['test']):
seq_x, seq_y = seq_x.to(device), seq_y.to(device)
pred = model(seq_x)
loss = criterion(pred, seq_y)
test_loss += loss.item()
test_samples += 1
all_preds.append(pred.cpu().numpy())
all_targets.append(seq_y.cpu().numpy())
# Print progress every 50 batches
if (batch_idx + 1) % 50 == 0:
print(f" Test Batch [{batch_idx+1}/{len(dataloaders['test'])}]")
# Restore original timesteps
model.diffusion.num_timesteps = original_timesteps
avg_test_loss = test_loss / test_samples
# Calculate additional metrics
all_preds = np.concatenate(all_preds, axis=0)
all_targets = np.concatenate(all_targets, axis=0)
mse = np.mean((all_preds - all_targets) ** 2)
mae = np.mean(np.abs(all_preds - all_targets))
rmse = np.sqrt(mse)
metrics = {
'test_mse': mse,
'test_mae': mae,
'test_rmse': rmse,
'test_loss': avg_test_loss
}
print(f"Test Results:")
print(f" MSE: {mse:.6f}")
print(f" MAE: {mae:.6f}")
print(f" RMSE: {rmse:.6f}")
# Log final results
swanlab.log(metrics)
swanlab.finish()
return model, metrics
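To tie the pieces together, a hedged usage sketch: the model factory and its arguments are placeholders, but the trainer above does require the returned module to expose get_num_params(), a forward(x, y) call that returns the diffusion training loss, a forward(x) call that returns predictions, and a diffusion.num_timesteps attribute:

model, metrics = train_diffusion_model(
    model_constructor=lambda: build_diffusion_model(),    # hypothetical factory for the DiffusionTimeSeries model
    data_path="./data/dataset.npz",                       # placeholder NPZ path
    project_name="diffusion_forecast",                    # swanlab project name
    config={'batch_size': 32, 'learning_rate': 1e-4, 'weight_decay': 1e-4},
    max_epochs=100,
)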