import os
import time
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import swanlab
from typing import Dict, Any, Optional, Callable, Union, Tuple

from dataflow import data_provider


class EarlyStopping:
    """Early stopping to stop training when validation performance doesn't improve."""

    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt'):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = float('inf')
        self.delta = delta
        self.path = path

    def __call__(self, val_loss, model):
        score = -val_loss
        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            if self.verbose:
                print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        """Save model when validation loss decreases."""
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss


class DatasetWrapperWithoutTimeFeatures(torch.utils.data.Dataset):
    """Wrapper to remove time features from dataflow datasets when use_x_mark=False."""

    def __init__(self, original_dataset):
        self.original_dataset = original_dataset

    def __getitem__(self, index):
        seq_x, seq_y, seq_x_mark, seq_y_mark = self.original_dataset[index]
        return seq_x, seq_y

    def __len__(self):
        return len(self.original_dataset)

    def inverse_transform(self, data):
        if hasattr(self.original_dataset, 'inverse_transform'):
            return self.original_dataset.inverse_transform(data)
        return data


def create_data_loaders_from_dataflow(args, use_x_mark: bool = True) -> Dict[str, DataLoader]:
    """Create PyTorch DataLoaders using the dataflow data_provider."""
    train_data, _ = data_provider(args, flag='train')
    val_data, _ = data_provider(args, flag='val')
    test_data, _ = data_provider(args, flag='test')

    if not use_x_mark:
        train_data = DatasetWrapperWithoutTimeFeatures(train_data)
        val_data = DatasetWrapperWithoutTimeFeatures(val_data)
        test_data = DatasetWrapperWithoutTimeFeatures(test_data)

    train_shuffle = True
    val_shuffle = False
    test_shuffle = False
    train_drop_last = True
    val_drop_last = True
    test_drop_last = True
    batch_size = args.batch_size
    num_workers = args.num_workers

    train_loader = DataLoader(
        train_data,
        batch_size=batch_size,
        shuffle=train_shuffle,
        num_workers=num_workers,
        drop_last=train_drop_last
    )
    val_loader = DataLoader(
        val_data,
        batch_size=batch_size,
        shuffle=val_shuffle,
        num_workers=num_workers,
        drop_last=val_drop_last
    )
    test_loader = DataLoader(
        test_data,
        batch_size=batch_size,
        shuffle=test_shuffle,
        num_workers=num_workers,
        drop_last=test_drop_last
    )

    return {'train': train_loader, 'val': val_loader, 'test': test_loader}


def create_data_loaders(data_path: str, batch_size: int = 32, use_x_mark: bool = True) -> Dict[str, DataLoader]:
    """
    Create PyTorch DataLoaders from an NPZ file.

    Args:
        data_path (str): Path to the NPZ file containing the data
        batch_size (int): Batch size for the DataLoaders
        use_x_mark (bool): Whether to use time features (x_mark) from the data file

    Returns:
        Dict[str, DataLoader]: Dictionary with train, val, and test DataLoaders
    """
    # Load data from NPZ file
    data = np.load(data_path, allow_pickle=True)
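    # Assumed NPZ layout (inferred from how the arrays are used below, not
    # guaranteed by this module): *_x arrays are [num_samples, seq_len, num_features]
    # input windows, *_y arrays are [num_samples, horizon, num_features] targets
    # (the horizon may include a label/overlap segment depending on how the NPZ
    # was exported), and optional *_x_mark / *_y_mark arrays hold time features.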
    train_x = data['train_x']
    train_y = data['train_y']
    val_x = data['val_x']
    val_y = data['val_y']
    test_x = data['test_x']
    test_y = data['test_y']

    # Load time features if available and needed
    if use_x_mark:
        train_x_mark = data.get('train_x_mark', None)
        train_y_mark = data.get('train_y_mark', None)
        val_x_mark = data.get('val_x_mark', None)
        val_y_mark = data.get('val_y_mark', None)
        test_x_mark = data.get('test_x_mark', None)
        test_y_mark = data.get('test_y_mark', None)
    else:
        train_x_mark = None
        train_y_mark = None
        val_x_mark = None
        val_y_mark = None
        test_x_mark = None
        test_y_mark = None

    # Convert to PyTorch tensors
    train_x = torch.FloatTensor(train_x)
    train_y = torch.FloatTensor(train_y)
    val_x = torch.FloatTensor(val_x)
    val_y = torch.FloatTensor(val_y)
    test_x = torch.FloatTensor(test_x)
    test_y = torch.FloatTensor(test_y)

    # Create datasets based on whether time features are available
    if train_x_mark is not None:
        train_x_mark = torch.FloatTensor(train_x_mark)
        train_y_mark = torch.FloatTensor(train_y_mark)
        val_x_mark = torch.FloatTensor(val_x_mark)
        val_y_mark = torch.FloatTensor(val_y_mark)
        test_x_mark = torch.FloatTensor(test_x_mark)
        test_y_mark = torch.FloatTensor(test_y_mark)

        train_dataset = TensorDataset(train_x, train_y, train_x_mark, train_y_mark)
        val_dataset = TensorDataset(val_x, val_y, val_x_mark, val_y_mark)
        test_dataset = TensorDataset(test_x, test_y, test_x_mark, test_y_mark)
    else:
        train_dataset = TensorDataset(train_x, train_y)
        val_dataset = TensorDataset(val_x, val_y)
        test_dataset = TensorDataset(test_x, test_y)

    # Create dataloaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return {
        'train': train_loader,
        'val': val_loader,
        'test': test_loader
    }


def train_diffusion_model(
    model_constructor: Callable,
    data_path: str,
    project_name: str,
    config: Dict[str, Any],
    device: Optional[str] = None,
    early_stopping_patience: int = 10,
    max_epochs: int = 100,
    checkpoint_dir: str = "./checkpoints",
    log_interval: int = 10,
) -> Tuple[nn.Module, Dict[str, float]]:
    """Train a Diffusion time series forecasting model using NPZ data loading."""
    # Setup device
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Initialize swanlab for experiment tracking
    swanlab_run = swanlab.init(
        project=project_name,
        config=config,
    )

    # Create checkpoint directory if it doesn't exist
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_path = os.path.join(checkpoint_dir, f"{project_name}.pt")

    # Create data loaders from NPZ files (following the other models' pattern)
    dataloaders = create_data_loaders(
        data_path=data_path,
        batch_size=config.get('batch_size', 32),
        use_x_mark=False  # DiffusionTimeSeries doesn't use time features
    )

    # Construct the model
    model = model_constructor()
    model = model.to(device)
    print(f"Model created with {model.get_num_params():,} parameters")

    # Define optimizer for diffusion training
    optimizer = optim.Adam(
        model.parameters(),
        lr=config.get('learning_rate', 1e-4),
        weight_decay=config.get('weight_decay', 1e-4)
    )

    # Learning rate scheduler
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', patience=5, factor=0.5
    )

    # Initialize early stopping
    early_stopping = EarlyStopping(
        patience=early_stopping_patience,
        verbose=True,
        path=checkpoint_path
    )

    # Training loop
    best_val_loss = float('inf')
    metrics = {}
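    # Training protocol (summarizing the loop below): each epoch runs a
    # diffusion-loss pass over the train loader, then scores a reduced-step
    # sampling pass over the val loader with MSE; ReduceLROnPlateau and
    # EarlyStopping both key off that validation loss.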
{epoch+1}/{max_epochs}") print("-" * 50) # Training phase model.train() train_loss = 0.0 train_samples = 0 interval_loss = 0.0 start_time = time.time() for batch_idx, (seq_x, seq_y) in enumerate(dataloaders['train']): seq_x, seq_y = seq_x.to(device), seq_y.to(device) optimizer.zero_grad() # Diffusion training: model returns loss directly when y is provided loss = model(seq_x, seq_y) loss.backward() optimizer.step() train_loss += loss.item() interval_loss += loss.item() train_samples += 1 # Log at intervals if (batch_idx + 1) % log_interval == 0: elapsed_time = time.time() - start_time avg_interval_loss = interval_loss / log_interval print(f' Batch [{batch_idx+1}/{len(dataloaders["train"])}] ' f'Loss: {avg_interval_loss:.6f} ' f'Time: {elapsed_time:.2f}s') # Log to swanlab swanlab.log({ 'batch_loss': avg_interval_loss, 'batch': epoch * len(dataloaders['train']) + batch_idx, 'learning_rate': optimizer.param_groups[0]['lr'] }) interval_loss = 0.0 start_time = time.time() avg_train_loss = train_loss / train_samples # Validation phase - Use faster sampling for validation model.eval() val_loss = 0.0 val_samples = 0 criterion = nn.MSELoss() print(" Validating...") with torch.no_grad(): # Temporarily reduce diffusion steps for faster validation original_timesteps = model.diffusion.num_timesteps model.diffusion.num_timesteps = 200# Much faster validation for batch_idx, (seq_x, seq_y) in enumerate(dataloaders['val']): seq_x, seq_y = seq_x.to(device), seq_y.to(device) # Generate predictions (inference mode with reduced steps) pred = model(seq_x) # Compute MSE loss for validation loss = criterion(pred, seq_y) val_loss += loss.item() val_samples += 1 # Print validation progress for first epoch if epoch == 0 and (batch_idx + 1) % 50 == 0: print(f" Val Batch [{batch_idx+1}/{len(dataloaders['val'])}]") # Early break for very first epoch to speed up if epoch == 0 and batch_idx >= 100: # Only validate on first 100 batches for first epoch break # Restore original timesteps model.diffusion.num_timesteps = original_timesteps avg_val_loss = val_loss / val_samples # Learning rate scheduling scheduler.step(avg_val_loss) current_lr = optimizer.param_groups[0]['lr'] print(f" Train Loss: {avg_train_loss:.6f}") print(f" Val Loss: {avg_val_loss:.6f}") print(f" Learning Rate: {current_lr:.2e}") # Log to swanlab swanlab.log({ 'epoch': epoch + 1, 'train_loss': avg_train_loss, 'val_loss': avg_val_loss, 'learning_rate': current_lr }) # Early stopping check early_stopping(avg_val_loss, model) if early_stopping.early_stop: print(f"Early stopping at epoch {epoch + 1}") break # Load best model model.load_state_dict(torch.load(checkpoint_path, map_location=device)) # Final evaluation on test set print("\nEvaluating on test set...") model.eval() test_loss = 0.0 test_samples = 0 all_preds = [] all_targets = [] with torch.no_grad(): # Use reduced timesteps for faster testing original_timesteps = model.diffusion.num_timesteps model.diffusion.num_timesteps = 200 # Faster but still good quality for batch_idx, (seq_x, seq_y) in enumerate(dataloaders['test']): seq_x, seq_y = seq_x.to(device), seq_y.to(device) pred = model(seq_x) loss = criterion(pred, seq_y) test_loss += loss.item() test_samples += 1 all_preds.append(pred.cpu().numpy()) all_targets.append(seq_y.cpu().numpy()) # Print progress every 50 batches if (batch_idx + 1) % 50 == 0: print(f" Test Batch [{batch_idx+1}/{len(dataloaders['test'])}]") # Restore original timesteps model.diffusion.num_timesteps = original_timesteps avg_test_loss = test_loss / test_samples # Calculate 
    # Calculate additional metrics
    all_preds = np.concatenate(all_preds, axis=0)
    all_targets = np.concatenate(all_targets, axis=0)

    mse = np.mean((all_preds - all_targets) ** 2)
    mae = np.mean(np.abs(all_preds - all_targets))
    rmse = np.sqrt(mse)

    metrics = {
        'test_mse': mse,
        'test_mae': mae,
        'test_rmse': rmse,
        'test_loss': avg_test_loss
    }

    print("Test Results:")
    print(f"  MSE: {mse:.6f}")
    print(f"  MAE: {mae:.6f}")
    print(f"  RMSE: {rmse:.6f}")

    # Log final results
    swanlab.log(metrics)
    swanlab.finish()

    return model, metrics
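
if __name__ == "__main__":
    # Illustrative usage sketch only. The import path, model class, data path,
    # and project name below are hypothetical placeholders, not part of this
    # module. The constructor passed to train_diffusion_model must return a
    # module that (a) returns the diffusion training loss when called as
    # model(seq_x, seq_y), (b) returns predictions when called as model(seq_x),
    # and (c) exposes get_num_params() and diffusion.num_timesteps, as assumed above.
    from models.diffusion import DiffusionTimeSeries  # hypothetical import path

    config = {
        'batch_size': 32,        # consumed via config.get(...) in train_diffusion_model
        'learning_rate': 1e-4,
        'weight_decay': 1e-4,
    }

    model, metrics = train_diffusion_model(
        model_constructor=lambda: DiffusionTimeSeries(),  # hypothetical constructor
        data_path="./data/dataset.npz",                   # hypothetical NPZ file
        project_name="diffusion-forecasting",
        config=config,
        early_stopping_patience=10,
        max_epochs=100,
    )
    print(f"Final test metrics: {metrics}")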