# tsmodel/layers/ps_loss.py

import torch
import torch.nn as nn
import torch.nn.functional as F

class PSLoss(nn.Module):
    """
    Patch-wise Structural (PS) Loss for time series forecasting.

    Implements the loss function described in the paper, which combines:
      1. Fourier-based Adaptive Patching (FAP)
      2. Patch-wise Structural Loss (correlation, variance, mean)
      3. Gradient-based Dynamic Weighting (GDW)
    """

    def __init__(self, patch_len_threshold=64, lambda_ps=5.0, use_gdw=True):
        super(PSLoss, self).__init__()
        self.patch_len_threshold = patch_len_threshold
        self.lambda_ps = lambda_ps
        self.use_gdw = use_gdw
        self.kl_loss = nn.KLDivLoss(reduction='none')
        self.mse_loss = nn.MSELoss()

    def create_patches(self, x, patch_len, stride):
        """
        Create overlapping patches from time series data.

        Args:
            x: Input tensor of shape [B, L, C]
            patch_len: Length of each patch
            stride: Stride between consecutive patches

        Returns:
            patches: Tensor of shape [B, C, N, P], where
                N = (L - patch_len) // stride + 1 is the number of patches
                and P = patch_len.
        """
        patches = x.unfold(1, patch_len, stride)  # [B, N, C, P]
        patches = patches.permute(0, 2, 1, 3)     # [B, C, N, P]
        return patches
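
    # Shape walk-through (illustrative example, not from the paper): with
    # L = 96, patch_len = 12, and stride = 6, unfold yields
    # N = (96 - 12) // 6 + 1 = 15 patches, so an input of shape [32, 96, 7]
    # becomes patches of shape [32, 7, 15, 12] after the permute.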

    def fourier_based_adaptive_patching(self, true, pred):
        """
        Fourier-based Adaptive Patching (FAP): derive the patch length from
        the dominant period of the ground truth.

        Args:
            true: Ground truth tensor [B, L, C]
            pred: Prediction tensor [B, L, C]

        Returns:
            true_patch: Patches from ground truth [B, C, N, P]
            pred_patch: Patches from prediction [B, C, N, P]
        """
        # Dominant frequency of the ground truth, averaged over batch and channels
        true_fft = torch.fft.rfft(true, dim=1)
        frequency_list = torch.abs(true_fft).mean(0).mean(-1)
        frequency_list[:1] = 0.0  # Suppress the DC component
        top_index = int(torch.argmax(frequency_list))  # plain int, so unfold() accepts the derived sizes
        # Derive period and patch length; fall back to the threshold when no
        # non-DC peak exists
        period = true.shape[1] // top_index if top_index > 0 else self.patch_len_threshold
        patch_len = min(period // 2, self.patch_len_threshold)
        patch_len = max(patch_len, 4)  # Enforce a minimum patch length
        stride = max(patch_len // 2, 1)  # Half-overlapping patches
        # Patch both series with the same length and stride
        true_patch = self.create_patches(true, patch_len, stride)
        pred_patch = self.create_patches(pred, patch_len, stride)
        return true_patch, pred_patch
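
    # Worked example (illustrative numbers, not from the paper): for a horizon
    # of L = 192 whose strongest non-DC rFFT bin is index 8, the estimated
    # period is 192 // 8 = 24, giving patch_len = min(24 // 2, 64) = 12 and
    # stride = 6, i.e. half-overlapping patches aligned to the dominant cycle.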

    def patch_wise_structural_loss(self, true_patch, pred_patch):
        """
        Calculate the patch-wise structural loss components.

        Args:
            true_patch: Ground truth patches [B, C, N, P]
            pred_patch: Prediction patches [B, C, N, P]

        Returns:
            corr_loss: Correlation loss
            var_loss: Variance loss
            mean_loss: Mean loss
        """
        # Per-patch statistics
        true_patch_mean = torch.mean(true_patch, dim=-1, keepdim=True)  # [B, C, N, 1]
        pred_patch_mean = torch.mean(pred_patch, dim=-1, keepdim=True)  # [B, C, N, 1]
        true_patch_var = torch.var(true_patch, dim=-1, keepdim=True, unbiased=False)
        pred_patch_var = torch.var(pred_patch, dim=-1, keepdim=True, unbiased=False)
        true_patch_std = torch.sqrt(true_patch_var + 1e-8)
        pred_patch_std = torch.sqrt(pred_patch_var + 1e-8)
        # Per-patch covariance between truth and prediction
        true_centered = true_patch - true_patch_mean
        pred_centered = pred_patch - pred_patch_mean
        covariance = torch.mean(true_centered * pred_centered, dim=-1, keepdim=True)
        # 1. Correlation loss (based on the Pearson correlation coefficient)
        correlation = covariance / (true_patch_std * pred_patch_std + 1e-8)
        corr_loss = (1.0 - correlation).mean()
        # 2. Variance loss (KL divergence between softmax distributions over each patch)
        true_patch_softmax = F.softmax(true_patch, dim=-1)
        pred_patch_log_softmax = F.log_softmax(pred_patch, dim=-1)
        var_loss = self.kl_loss(pred_patch_log_softmax, true_patch_softmax).sum(dim=-1).mean()
        # 3. Mean loss (MAE between patch means)
        mean_loss = torch.abs(true_patch_mean - pred_patch_mean).mean()
        return corr_loss, var_loss, mean_loss
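
    # Note on the correlation term: per patch it computes the Pearson
    # coefficient r = cov(y, y_hat) / (std(y) * std(y_hat) + eps), so
    # corr_loss = mean(1 - r) is 0 for perfectly correlated patches and
    # approaches 2 for perfectly anti-correlated ones.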

    def gradient_based_dynamic_weighting(self, model, corr_loss, var_loss, mean_loss, true, pred):
        """
        Gradient-based Dynamic Weighting (GDW) for balancing loss components.

        Args:
            model: The neural network model
            corr_loss, var_loss, mean_loss: Loss components
            true: Ground truth tensor [B, L, C]
            pred: Prediction tensor [B, L, C]

        Returns:
            alpha, beta, gamma: Dynamic weights for the three loss components
        """
        if not self.use_gdw or not self.training:
            return 1.0, 1.0, 1.0
        try:
            # Pick the parameters to measure gradients against, preferring the
            # model's output projection
            if hasattr(model, 'projector'):
                params = list(model.projector.parameters())
            elif hasattr(model, 'projection'):
                params = list(model.projection.parameters())
            else:
                # Fall back to the last two parameter tensors (typically the
                # final linear layer's weight and bias)
                params = list(model.parameters())[-2:]
            if not params:
                return 1.0, 1.0, 1.0
            # Gradient of each loss component w.r.t. the chosen parameters
            corr_grad = torch.autograd.grad(corr_loss, params, retain_graph=True, allow_unused=True)
            var_grad = torch.autograd.grad(var_loss, params, retain_graph=True, allow_unused=True)
            mean_grad = torch.autograd.grad(mean_loss, params, retain_graph=True, allow_unused=True)
            # Skip None gradients and accumulate norms as Python floats
            corr_grad_norm = sum(g.norm().item() for g in corr_grad if g is not None)
            var_grad_norm = sum(g.norm().item() for g in var_grad if g is not None)
            mean_grad_norm = sum(g.norm().item() for g in mean_grad if g is not None)
            if corr_grad_norm == 0 or var_grad_norm == 0 or mean_grad_norm == 0:
                return 1.0, 1.0, 1.0
            # Inverse-norm weighting: components with smaller gradients get
            # larger weights, so all three pull with comparable force
            grad_avg = (corr_grad_norm + var_grad_norm + mean_grad_norm) / 3.0
            alpha = grad_avg / corr_grad_norm
            beta = grad_avg / var_grad_norm
            gamma = grad_avg / mean_grad_norm
            # Calculate scaling factors for gamma from global correlation and
            # variance similarity
            true_flat = true.reshape(-1, true.shape[-1])
            pred_flat = pred.reshape(-1, pred.shape[-1])
            true_mean = torch.mean(true_flat, dim=0, keepdim=True)
            pred_mean = torch.mean(pred_flat, dim=0, keepdim=True)
            true_std = torch.std(true_flat, dim=0, keepdim=True) + 1e-8
            pred_std = torch.std(pred_flat, dim=0, keepdim=True) + 1e-8
            covariance = torch.mean((true_flat - true_mean) * (pred_flat - pred_mean), dim=0, keepdim=True)
            correlation_sim = covariance / (true_std * pred_std)
            correlation_sim = (1.0 + correlation_sim.mean()) * 0.5  # map [-1, 1] to [0, 1]
            variance_sim = (2 * true_std.mean() * pred_std.mean()) / (true_std.mean() ** 2 + pred_std.mean() ** 2)
            c = 0.5 * (1 + correlation_sim)
            v = variance_sim
            gamma = float(gamma * c * v)  # keep all three weights as plain Python floats
            return alpha, beta, gamma
        except Exception:
            # Fall back to equal weights if gradient calculation fails
            return 1.0, 1.0, 1.0
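
    # Weighting example (illustrative numbers): gradient norms of 0.3, 0.1,
    # and 0.2 give grad_avg = 0.2 and weights alpha ≈ 0.67, beta = 2.0,
    # gamma = 1.0 (before the similarity scaling of gamma), so the component
    # with the weakest gradient is boosted the most.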

    def forward(self, pred, true, model=None):
        """
        Forward pass of PS Loss.

        Args:
            pred: Predictions [B, L, C]
            true: Ground truth [B, L, C]
            model: Neural network model (needed for gradient-based weighting)

        Returns:
            total_loss: Combined MSE + PS loss
            loss_dict: Dictionary with individual loss components
        """
        # Standard MSE loss
        mse_loss = self.mse_loss(pred, true)
        # Fourier-based adaptive patching
        true_patch, pred_patch = self.fourier_based_adaptive_patching(true, pred)
        # Patch-wise structural loss components
        corr_loss, var_loss, mean_loss = self.patch_wise_structural_loss(true_patch, pred_patch)
        # Gradient-based dynamic weighting
        if model is not None and self.use_gdw:
            alpha, beta, gamma = self.gradient_based_dynamic_weighting(
                model, corr_loss, var_loss, mean_loss, true, pred
            )
        else:
            alpha, beta, gamma = 1.0, 1.0, 1.0
        # Combine the PS loss components
        ps_loss = alpha * corr_loss + beta * var_loss + gamma * mean_loss
        # Total loss: MSE + λ * PS_loss
        total_loss = mse_loss + self.lambda_ps * ps_loss
        loss_dict = {
            'mse_loss': mse_loss.item(),
            'ps_loss': ps_loss.item(),
            'corr_loss': corr_loss.item(),
            'var_loss': var_loss.item(),
            'mean_loss': mean_loss.item(),
            'alpha': alpha,
            'beta': beta,
            'gamma': gamma,
            'total_loss': total_loss.item(),
        }
        return total_loss, loss_dict
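

# Minimal usage sketch (not part of the original module): exercises the loss on
# random tensors with model=None, so GDW falls back to equal weights. Shapes
# and hyperparameters here are illustrative assumptions, not values from the
# paper.
if __name__ == "__main__":
    torch.manual_seed(0)
    B, L, C = 8, 96, 7  # batch size, horizon, channels (assumed)
    criterion = PSLoss(patch_len_threshold=64, lambda_ps=5.0, use_gdw=False)
    pred = torch.randn(B, L, C, requires_grad=True)
    true = torch.randn(B, L, C)
    total_loss, loss_dict = criterion(pred, true, model=None)
    total_loss.backward()  # the combined loss is differentiable end to end
    print({k: round(v, 4) if isinstance(v, float) else v for k, v in loss_dict.items()})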