import torch
import torch.nn as nn
import torch.nn.functional as F


class PSLoss(nn.Module):
    """
    Patch-wise Structural (PS) Loss for time series forecasting.

    Implements the loss function described in the paper, which combines:
    1. Fourier-based Adaptive Patching (FAP)
    2. Patch-wise Structural Loss (correlation, variance, mean)
    3. Gradient-based Dynamic Weighting (GDW)
    """

    def __init__(self, patch_len_threshold=64, lambda_ps=5.0, use_gdw=True):
        super().__init__()
        self.patch_len_threshold = patch_len_threshold
        self.lambda_ps = lambda_ps
        self.use_gdw = use_gdw
        self.kl_loss = nn.KLDivLoss(reduction='none')
        self.mse_loss = nn.MSELoss()

    def create_patches(self, x, patch_len, stride):
        """
        Create overlapping patches from time series data.

        Args:
            x: Input tensor of shape [B, L, C]
            patch_len: Length of each patch
            stride: Stride between consecutive patches

        Returns:
            patches: Tensor of shape [B, C, N, P], where N = (L - patch_len) // stride + 1
                is the number of patches and P = patch_len
        """
        patches = x.unfold(1, patch_len, stride)  # [B, N, C, P]
        patches = patches.permute(0, 2, 1, 3)  # [B, C, N, P]
        return patches
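
    # Shape check (illustrative numbers, not from the original source): with
    # L = 96, patch_len = 24, and stride = 12, unfold yields
    # N = (96 - 24) // 12 + 1 = 7 patches of length P = 24 per channel.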

    def fourier_based_adaptive_patching(self, true, pred):
        """
        Fourier-based Adaptive Patching (FAP) to determine the optimal patch length.

        Args:
            true: Ground truth tensor [B, L, C]
            pred: Prediction tensor [B, L, C]

        Returns:
            true_patch: Patches from ground truth [B, C, N, P]
            pred_patch: Patches from prediction [B, C, N, P]
        """
        # Get the dominant frequency of the ground truth from its amplitude spectrum
        true_fft = torch.fft.rfft(true, dim=1)
        frequency_list = torch.abs(true_fft).mean(0).mean(-1)
        frequency_list[:1] = 0.0  # Suppress the DC component
        # .item() converts the 0-dim tensor to a Python int; unfold requires int sizes
        top_index = torch.argmax(frequency_list).item()

        # Derive the dominant period and set the patch length to half a period
        period = true.shape[1] // top_index if top_index > 0 else self.patch_len_threshold
        patch_len = min(period // 2, self.patch_len_threshold)
        patch_len = max(patch_len, 4)  # Minimum patch length
        stride = max(patch_len // 2, 1)  # 50% overlap between consecutive patches

        # Create patches with the same layout for both series
        true_patch = self.create_patches(true, patch_len, stride)
        pred_patch = self.create_patches(pred, patch_len, stride)

        return true_patch, pred_patch
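
    # Worked example (illustrative): for a series of length L = 96 whose
    # amplitude spectrum peaks at frequency bin k = 4, the dominant period is
    # 96 // 4 = 24, giving patch_len = min(24 // 2, 64) = 12 and stride = 6.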

    def patch_wise_structural_loss(self, true_patch, pred_patch):
        """
        Calculate the patch-wise structural loss components.

        Args:
            true_patch: Ground truth patches [B, C, N, P]
            pred_patch: Prediction patches [B, C, N, P]

        Returns:
            corr_loss: Correlation loss
            var_loss: Variance loss
            mean_loss: Mean loss
        """
        # Per-patch statistics
        true_patch_mean = torch.mean(true_patch, dim=-1, keepdim=True)  # [B, C, N, 1]
        pred_patch_mean = torch.mean(pred_patch, dim=-1, keepdim=True)  # [B, C, N, 1]

        true_patch_var = torch.var(true_patch, dim=-1, keepdim=True, unbiased=False)
        pred_patch_var = torch.var(pred_patch, dim=-1, keepdim=True, unbiased=False)
        true_patch_std = torch.sqrt(true_patch_var + 1e-8)
        pred_patch_std = torch.sqrt(pred_patch_var + 1e-8)

        # Per-patch covariance
        true_centered = true_patch - true_patch_mean
        pred_centered = pred_patch - pred_patch_mean
        covariance = torch.mean(true_centered * pred_centered, dim=-1, keepdim=True)

        # 1. Correlation loss (based on the Pearson correlation coefficient)
        correlation = covariance / (true_patch_std * pred_patch_std + 1e-8)
        corr_loss = (1.0 - correlation).mean()

        # 2. Variance loss (KL divergence between softmax distributions over each patch)
        true_patch_softmax = F.softmax(true_patch, dim=-1)
        pred_patch_log_softmax = F.log_softmax(pred_patch, dim=-1)
        var_loss = self.kl_loss(pred_patch_log_softmax, true_patch_softmax).sum(dim=-1).mean()

        # 3. Mean loss (MAE between patch means)
        mean_loss = torch.abs(true_patch_mean - pred_patch_mean).mean()

        return corr_loss, var_loss, mean_loss
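
    # Per patch, the correlation term is 1 - cov(y, yhat) / (std(y) * std(yhat)):
    # it is 0 when a predicted patch is a perfectly positively correlated copy of
    # the true patch (e.g. the same shape shifted by a constant offset) and 2
    # when the patches are perfectly anti-correlated.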

    def gradient_based_dynamic_weighting(self, model, corr_loss, var_loss, mean_loss, true, pred):
        """
        Gradient-based Dynamic Weighting (GDW) for balancing the loss components.

        Args:
            model: The neural network model
            corr_loss, var_loss, mean_loss: Loss components
            true: Ground truth tensor [B, L, C]
            pred: Prediction tensor [B, L, C]

        Returns:
            alpha, beta, gamma: Dynamic weights for the three loss components
        """
        if not self.use_gdw or not self.training:
            return 1.0, 1.0, 1.0

        try:
            # Pick the parameters at which per-component gradients are measured
            if hasattr(model, 'projector'):
                params = list(model.projector.parameters())
            elif hasattr(model, 'projection'):
                params = list(model.projection.parameters())
            else:
                # Fall back to the last two parameter tensors, assumed to be
                # the output layer's weight and bias
                params = list(model.parameters())[-2:]

            if not params:
                return 1.0, 1.0, 1.0

            # Gradient of each loss component w.r.t. the chosen parameters
            corr_grad = torch.autograd.grad(corr_loss, params, retain_graph=True, allow_unused=True)
            var_grad = torch.autograd.grad(var_loss, params, retain_graph=True, allow_unused=True)
            mean_grad = torch.autograd.grad(mean_loss, params, retain_graph=True, allow_unused=True)

            # Skip None gradients (unused parameters) and sum the norms
            corr_grad_norm = sum(g.norm().item() for g in corr_grad if g is not None)
            var_grad_norm = sum(g.norm().item() for g in var_grad if g is not None)
            mean_grad_norm = sum(g.norm().item() for g in mean_grad if g is not None)

            if corr_grad_norm == 0 or var_grad_norm == 0 or mean_grad_norm == 0:
                return 1.0, 1.0, 1.0

            # Average gradient magnitude across the three components
            grad_avg = (corr_grad_norm + var_grad_norm + mean_grad_norm) / 3.0

            # Each weight rescales its component's gradient to the average magnitude
            alpha = grad_avg / corr_grad_norm
            beta = grad_avg / var_grad_norm
            gamma = grad_avg / mean_grad_norm

            # Scale gamma by global correlation and variance similarity
            true_flat = true.view(-1, true.shape[-1])
            pred_flat = pred.view(-1, pred.shape[-1])

            true_mean = torch.mean(true_flat, dim=0, keepdim=True)
            pred_mean = torch.mean(pred_flat, dim=0, keepdim=True)
            true_std = torch.std(true_flat, dim=0, keepdim=True) + 1e-8
            pred_std = torch.std(pred_flat, dim=0, keepdim=True) + 1e-8

            covariance = torch.mean((true_flat - true_mean) * (pred_flat - pred_mean), dim=0, keepdim=True)
            correlation_sim = covariance / (true_std * pred_std)
            correlation_sim = (1.0 + correlation_sim.mean()) * 0.5

            variance_sim = (2 * true_std.mean() * pred_std.mean()) / (true_std.mean() ** 2 + pred_std.mean() ** 2)

            # Convert the 0-dim similarity tensors to Python floats so that
            # alpha, beta, and gamma are all returned as plain floats
            c = 0.5 * (1 + correlation_sim.item())
            v = variance_sim.item()
            gamma = gamma * c * v

            return alpha, beta, gamma

        except Exception:
            # Fall back to equal weights if the gradient calculation fails
            return 1.0, 1.0, 1.0
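
    # Weighting intuition (illustrative numbers): if the component gradient
    # norms are g_corr = 3, g_var = 1, g_mean = 2, then grad_avg = 2 and the
    # weights are alpha = 2/3, beta = 2, gamma = 1, equalizing each component's
    # gradient magnitude at the monitored parameters before gamma's extra
    # similarity scaling is applied.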

    def forward(self, pred, true, model=None):
        """
        Forward pass of PS Loss.

        Args:
            pred: Predictions [B, L, C]
            true: Ground truth [B, L, C]
            model: Neural network model (needed for gradient-based weighting)

        Returns:
            total_loss: Combined MSE + PS loss
            loss_dict: Dictionary with individual loss components
        """
        # Standard MSE loss
        mse_loss = self.mse_loss(pred, true)

        # Fourier-based adaptive patching
        true_patch, pred_patch = self.fourier_based_adaptive_patching(true, pred)

        # Patch-wise structural loss
        corr_loss, var_loss, mean_loss = self.patch_wise_structural_loss(true_patch, pred_patch)

        # Gradient-based dynamic weighting
        if model is not None and self.use_gdw:
            alpha, beta, gamma = self.gradient_based_dynamic_weighting(
                model, corr_loss, var_loss, mean_loss, true, pred
            )
        else:
            alpha, beta, gamma = 1.0, 1.0, 1.0

        # Combine the PS loss components
        ps_loss = alpha * corr_loss + beta * var_loss + gamma * mean_loss

        # Total loss: MSE + λ * PS_loss
        total_loss = mse_loss + self.lambda_ps * ps_loss

        loss_dict = {
            'mse_loss': mse_loss.item(),
            'ps_loss': ps_loss.item(),
            'corr_loss': corr_loss.item(),
            'var_loss': var_loss.item(),
            'mean_loss': mean_loss.item(),
            'alpha': alpha,
            'beta': beta,
            'gamma': gamma,
            'total_loss': total_loss.item()
        }

        return total_loss, loss_dict
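

# Minimal usage sketch (illustrative, not part of the original module): one
# training step of a toy linear forecaster with PS Loss. The model, shapes,
# and hyperparameters below are assumptions for demonstration only.
if __name__ == "__main__":
    B, L, C = 8, 96, 3  # batch size, horizon length, number of channels

    model = nn.Linear(L * C, L * C)  # stand-in forecaster
    criterion = PSLoss(patch_len_threshold=64, lambda_ps=5.0, use_gdw=True)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    inp = torch.randn(B, L, C)
    true = torch.randn(B, L, C)

    model.train()
    criterion.train()  # GDW is only active in training mode
    pred = model(inp.reshape(B, -1)).reshape(B, L, C)

    # Passing model= enables gradient-based dynamic weighting
    total_loss, loss_dict = criterion(pred, true, model=model)
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    print({k: round(v, 4) if isinstance(v, float) else v for k, v in loss_dict.items()})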