import torch
import torch.nn as nn
import torch.nn.functional as F


class PSLoss(nn.Module):
    """
    Patch-wise Structural (PS) Loss for time series forecasting.

    Implements the loss function described in the paper, which combines:
      1. Fourier-based Adaptive Patching (FAP)
      2. Patch-wise structural loss (correlation, variance, mean)
      3. Gradient-based Dynamic Weighting (GDW)
    """

    def __init__(self, patch_len_threshold=64, lambda_ps=5.0, use_gdw=True):
        super().__init__()
        self.patch_len_threshold = patch_len_threshold
        self.lambda_ps = lambda_ps
        self.use_gdw = use_gdw
        self.kl_loss = nn.KLDivLoss(reduction='none')
        self.mse_loss = nn.MSELoss()

    def create_patches(self, x, patch_len, stride):
        """
        Slice a time series into (possibly overlapping) patches.

        Args:
            x: Input tensor of shape [B, L, C].
            patch_len: Length of each patch.
            stride: Stride between consecutive patches.

        Returns:
            patches: Tensor of shape [B, C, N, P], where
                N = (L - patch_len) // stride + 1 is the number of patches
                and P is the patch length.
        """
        patches = x.unfold(1, patch_len, stride)  # [B, N, C, P]
        patches = patches.permute(0, 2, 1, 3)     # [B, C, N, P]
        return patches

    def fourier_based_adaptive_patching(self, true, pred):
        """
        Fourier-based Adaptive Patching (FAP): derive the patch length from the
        dominant period of the ground truth, then patch both series with 50% overlap.

        Args:
            true: Ground truth tensor [B, L, C].
            pred: Prediction tensor [B, L, C].

        Returns:
            true_patch: Patches from the ground truth [B, C, N, P].
            pred_patch: Patches from the prediction [B, C, N, P].
        """
        # Find the dominant frequency of the ground truth via the amplitude spectrum.
        true_fft = torch.fft.rfft(true, dim=1)
        frequency_list = torch.abs(true_fft).mean(0).mean(-1)
        frequency_list[:1] = 0.0  # Zero out the DC component.
        top_index = torch.argmax(frequency_list).item()  # .item(): unfold needs a Python int.

        # Patch length is half the dominant period, clipped to a sane range.
        period = true.shape[1] // top_index if top_index > 0 else self.patch_len_threshold
        patch_len = min(period // 2, self.patch_len_threshold)
        patch_len = max(patch_len, 4)              # Lower bound on the patch length.
        patch_len = min(patch_len, true.shape[1])  # Guard against patches longer than the series.
        stride = max(patch_len // 2, 1)            # 50% overlap between patches.

        true_patch = self.create_patches(true, patch_len, stride)
        pred_patch = self.create_patches(pred, patch_len, stride)
        return true_patch, pred_patch
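
    # A worked example of the patching arithmetic above, with hypothetical
    # numbers: for L = 96 and a dominant FFT bin at index 4, the period is
    # 96 // 4 = 24, so patch_len = min(24 // 2, 64) = 12 and stride = 6,
    # giving N = (96 - 12) // 6 + 1 = 15 patches per channel.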

    def patch_wise_structural_loss(self, true_patch, pred_patch):
        """
        Compute the three patch-wise structural loss components.

        Args:
            true_patch: Ground truth patches [B, C, N, P].
            pred_patch: Prediction patches [B, C, N, P].

        Returns:
            corr_loss: Correlation loss.
            var_loss: Variance loss.
            mean_loss: Mean loss.
        """
        # Per-patch statistics.
        true_patch_mean = torch.mean(true_patch, dim=-1, keepdim=True)  # [B, C, N, 1]
        pred_patch_mean = torch.mean(pred_patch, dim=-1, keepdim=True)  # [B, C, N, 1]
        true_patch_var = torch.var(true_patch, dim=-1, keepdim=True, unbiased=False)
        pred_patch_var = torch.var(pred_patch, dim=-1, keepdim=True, unbiased=False)
        true_patch_std = torch.sqrt(true_patch_var + 1e-8)
        pred_patch_std = torch.sqrt(pred_patch_var + 1e-8)

        # Per-patch covariance between prediction and ground truth.
        true_centered = true_patch - true_patch_mean
        pred_centered = pred_patch - pred_patch_mean
        covariance = torch.mean(true_centered * pred_centered, dim=-1, keepdim=True)

        # 1. Correlation loss (based on the Pearson correlation coefficient).
        correlation = covariance / (true_patch_std * pred_patch_std + 1e-8)
        corr_loss = (1.0 - correlation).mean()

        # 2. Variance loss (KL divergence between softmax distributions over each patch).
        true_patch_softmax = F.softmax(true_patch, dim=-1)
        pred_patch_log_softmax = F.log_softmax(pred_patch, dim=-1)
        var_loss = self.kl_loss(pred_patch_log_softmax, true_patch_softmax).sum(dim=-1).mean()

        # 3. Mean loss (MAE between patch means).
        mean_loss = torch.abs(true_patch_mean - pred_patch_mean).mean()

        return corr_loss, var_loss, mean_loss

    def gradient_based_dynamic_weighting(self, model, corr_loss, var_loss, mean_loss, true, pred):
        """
        Gradient-based Dynamic Weighting (GDW) for balancing the loss components.

        Args:
            model: The neural network model.
            corr_loss, var_loss, mean_loss: The three loss components.
            true: Ground truth tensor [B, L, C].
            pred: Prediction tensor [B, L, C].

        Returns:
            alpha, beta, gamma: Dynamic weights for the three loss components.
        """
        if not self.use_gdw or not self.training:
            return 1.0, 1.0, 1.0

        try:
            # Use the output-projection parameters as gradient probes.
            if hasattr(model, 'projector'):
                params = list(model.projector.parameters())
            elif hasattr(model, 'projection'):
                params = list(model.projection.parameters())
            else:
                params = list(model.parameters())[-2:]  # Fall back to the last linear layer.

            if not params:
                return 1.0, 1.0, 1.0

            # Gradient of each loss component w.r.t. the probe parameters.
            corr_grad = torch.autograd.grad(corr_loss, params, retain_graph=True, allow_unused=True)
            var_grad = torch.autograd.grad(var_loss, params, retain_graph=True, allow_unused=True)
            mean_grad = torch.autograd.grad(mean_loss, params, retain_graph=True, allow_unused=True)

            # Skip None gradients (unused parameters) and accumulate the norms.
            corr_grad_norm = sum(g.norm().item() for g in corr_grad if g is not None)
            var_grad_norm = sum(g.norm().item() for g in var_grad if g is not None)
            mean_grad_norm = sum(g.norm().item() for g in mean_grad if g is not None)

            if corr_grad_norm == 0 or var_grad_norm == 0 or mean_grad_norm == 0:
                return 1.0, 1.0, 1.0

            # Weight each component inversely to its gradient magnitude so that
            # all three contribute comparable gradients.
            grad_avg = (corr_grad_norm + var_grad_norm + mean_grad_norm) / 3.0
            alpha = grad_avg / corr_grad_norm
            beta = grad_avg / var_grad_norm
            gamma = grad_avg / mean_grad_norm

            # Scale gamma by global correlation and variance similarity.
            true_flat = true.view(-1, true.shape[-1])
            pred_flat = pred.view(-1, pred.shape[-1])
            true_mean = torch.mean(true_flat, dim=0, keepdim=True)
            pred_mean = torch.mean(pred_flat, dim=0, keepdim=True)
            true_std = torch.std(true_flat, dim=0, keepdim=True) + 1e-8
            pred_std = torch.std(pred_flat, dim=0, keepdim=True) + 1e-8
            covariance = torch.mean((true_flat - true_mean) * (pred_flat - pred_mean),
                                    dim=0, keepdim=True)

            correlation_sim = covariance / (true_std * pred_std)
            correlation_sim = (1.0 + correlation_sim.mean()) * 0.5  # Map [-1, 1] to [0, 1].
            variance_sim = (2 * true_std.mean() * pred_std.mean()) / (
                true_std.mean() ** 2 + pred_std.mean() ** 2
            )

            c = 0.5 * (1 + correlation_sim)
            v = variance_sim
            gamma = gamma * (c * v).item()  # alpha, beta, gamma are plain floats.

            return alpha, beta, gamma
        except Exception:
            # Fall back to equal weights if the gradient probe fails.
            return 1.0, 1.0, 1.0
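
    # A worked example of the weighting above, with hypothetical gradient
    # norms (0.2, 0.5, 0.8): grad_avg = 0.5, so alpha = 2.5, beta = 1.0 and
    # gamma = 0.625 before the correlation/variance scaling, i.e. the
    # component with the weakest gradient receives the largest weight.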

    def forward(self, pred, true, model=None):
        """
        Forward pass of PS Loss.

        Args:
            pred: Predictions [B, L, C].
            true: Ground truth [B, L, C].
            model: Neural network model (needed for gradient-based weighting).

        Returns:
            total_loss: Combined MSE + PS loss.
            loss_dict: Dictionary with the individual loss components.
        """
        # Standard point-wise MSE loss.
        mse_loss = self.mse_loss(pred, true)

        # Fourier-based adaptive patching.
        true_patch, pred_patch = self.fourier_based_adaptive_patching(true, pred)

        # Patch-wise structural loss components.
        corr_loss, var_loss, mean_loss = self.patch_wise_structural_loss(true_patch, pred_patch)

        # Gradient-based dynamic weighting.
        if model is not None and self.use_gdw:
            alpha, beta, gamma = self.gradient_based_dynamic_weighting(
                model, corr_loss, var_loss, mean_loss, true, pred
            )
        else:
            alpha, beta, gamma = 1.0, 1.0, 1.0

        # Combine the PS loss components.
        ps_loss = alpha * corr_loss + beta * var_loss + gamma * mean_loss

        # Total loss: MSE + lambda * PS loss.
        total_loss = mse_loss + self.lambda_ps * ps_loss

        loss_dict = {
            'mse_loss': mse_loss.item(),
            'ps_loss': ps_loss.item(),
            'corr_loss': corr_loss.item(),
            'var_loss': var_loss.item(),
            'mean_loss': mean_loss.item(),
            'alpha': alpha,
            'beta': beta,
            'gamma': gamma,
            'total_loss': total_loss.item(),
        }
        return total_loss, loss_dict
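
# Minimal usage sketch, assuming synthetic data and a hypothetical
# TinyForecaster model; the `projector` attribute only exists so that GDW
# finds an output layer to probe.
if __name__ == "__main__":
    torch.manual_seed(0)

    class TinyForecaster(nn.Module):
        """Hypothetical linear model mapping a 96-step input to a 96-step forecast."""

        def __init__(self, seq_len=96):
            super().__init__()
            self.projector = nn.Linear(seq_len, seq_len)

        def forward(self, x):  # x: [B, L, C]
            return self.projector(x.transpose(1, 2)).transpose(1, 2)

    model = TinyForecaster()
    criterion = PSLoss(patch_len_threshold=64, lambda_ps=5.0, use_gdw=True)
    criterion.train()  # GDW is only active in training mode.

    x = torch.randn(8, 96, 7)
    # Periodic target so FAP has a dominant frequency to lock onto.
    true = torch.sin(torch.linspace(0, 8 * torch.pi, 96)).view(1, 96, 1) \
        + 0.1 * torch.randn(8, 96, 7)
    pred = model(x)

    total_loss, loss_dict = criterion(pred, true, model=model)
    total_loss.backward()
    print({k: round(v, 4) for k, v in loss_dict.items()})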