feat: add DiffusionTimeSeries and iTransformer models, introduce xPatch_SparseChannel

This commit is contained in:
game-loader
2025-08-26 20:53:35 +08:00
parent 44bd5c8f29
commit c3713f5c0b
11 changed files with 1528 additions and 41 deletions

@@ -8,6 +8,8 @@ from torch.utils.data import DataLoader, TensorDataset
import swanlab
from typing import Dict, Any, Optional, Callable, Union, Tuple
from dataflow import data_provider
from layers.ps_loss import PSLoss
from utils.tools import adjust_learning_rate, dotdict
class EarlyStopping:
"""Early stopping to stop training when validation performance doesn't improve."""
@@ -138,7 +140,9 @@ def create_data_loaders_from_dataflow(args, use_x_mark: bool = True) -> Dict[str
'test': test_loader
}
def create_data_loaders(data_path: str, batch_size: int = 32, use_x_mark: bool = True) -> Dict[str, DataLoader]:
def create_data_loaders(data_path: str, batch_size: int = 32, use_x_mark: bool = True,
num_workers: int = 4, pin_memory: bool = True,
persistent_workers: bool = True) -> Dict[str, DataLoader]:
"""
Create PyTorch DataLoaders from an NPZ file
@@ -146,6 +150,9 @@ def create_data_loaders(data_path: str, batch_size: int = 32, use_x_mark: bool =
data_path (str): Path to the NPZ file containing the data
batch_size (int): Batch size for the DataLoaders
use_x_mark (bool): Whether to use time features (x_mark) from the data file
num_workers (int): Number of worker processes for data loading
pin_memory (bool): Whether to pin memory for faster GPU transfer
persistent_workers (bool): Whether to keep workers alive between epochs
Returns:
Dict[str, DataLoader]: Dictionary with train, val, and test DataLoaders
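For reference, a minimal usage sketch of the extended signature (the NPZ path is hypothetical, and persistent_workers is only honoured when num_workers > 0, matching the guard added in the loader construction below):

loaders = create_data_loaders(
    data_path="data/prepared.npz",  # hypothetical path to a prepared NPZ file
    batch_size=64,
    use_x_mark=True,
    num_workers=4,
    pin_memory=True,          # speeds up host-to-GPU copies
    persistent_workers=True,  # keep worker processes alive across epochs
)
train_loader = loaders['train']  # remaining keys: 'val', 'test'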
@@ -200,10 +207,34 @@ def create_data_loaders(data_path: str, batch_size: int = 32, use_x_mark: bool =
val_dataset = TensorDataset(val_x, val_y)
test_dataset = TensorDataset(test_x, test_y)
# Create dataloaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
# Create dataloaders with performance optimizations
train_loader = DataLoader(
train_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=num_workers,
pin_memory=pin_memory,
persistent_workers=persistent_workers if num_workers > 0 else False,
drop_last=True # Drop incomplete batches for training
)
val_loader = DataLoader(
val_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=num_workers,
pin_memory=pin_memory,
persistent_workers=persistent_workers if num_workers > 0 else False,
drop_last=False
)
test_loader = DataLoader(
test_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=num_workers,
pin_memory=pin_memory,
persistent_workers=persistent_workers if num_workers > 0 else False,
drop_last=False
)
return {
'train': train_loader,
@@ -223,7 +254,12 @@ def train_forecasting_model(
log_interval: int = 10,
use_x_mark: bool = True,
dataset_mode: str = "npz",
dataflow_args = None
dataflow_args = None,
use_ps_loss: bool = False,
ps_lambda: float = 5.0,
patch_len_threshold: int = 64,
use_gdw: bool = True,
lr_adjust_strategy: str = "type1"
) -> Tuple[nn.Module, Dict[str, float]]:
"""
Train a time series forecasting model
@@ -241,6 +277,11 @@ def train_forecasting_model(
use_x_mark (bool): Whether to use time features (x_mark) from the data file
dataset_mode (str): Dataset construction mode - "npz" or "dataflow"
dataflow_args: Arguments object for dataflow mode (required if dataset_mode="dataflow")
use_ps_loss (bool): Whether to use Patch-wise Structural (PS) loss instead of MSE
ps_lambda (float): Weight for PS loss component when combined with MSE
patch_len_threshold (int): Maximum patch length for adaptive patching
use_gdw (bool): Whether to use Gradient-based Dynamic Weighting
lr_adjust_strategy (str): Learning rate adjustment strategy - 'type1', 'type2', 'type3', 'sigmoid', 'constant', '3', '4', '5', '6'
Returns:
Tuple[nn.Module, Dict[str, float]]: Trained model and dictionary of evaluation metrics
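As a hedged call sketch, the new keywords slot into an existing call site unchanged; the elided arguments (model construction, config, data_path, and so on) keep the signature already in place:

model, metrics = train_forecasting_model(
    # ...existing arguments (model constructor, config, data_path, ...) stay as they are
    use_ps_loss=True,            # train with PS loss instead of plain MSE
    ps_lambda=5.0,               # weight of the PS component relative to MSE
    patch_len_threshold=64,      # cap for the adaptive patch length
    use_gdw=True,                # gradient-based dynamic weighting of the loss terms
    lr_adjust_strategy="type1",  # schedule handed to adjust_learning_rate each epoch
)

Whichever criterion is used for training, the validation and test numbers reported later in this file stay in MSE, so runs remain directly comparable.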
@@ -271,7 +312,10 @@ def train_forecasting_model(
dataloaders = create_data_loaders(
data_path=data_path,
batch_size=config.get('batch_size', 32),
use_x_mark=use_x_mark
use_x_mark=use_x_mark,
num_workers=config.get('num_workers', 4),
pin_memory=config.get('pin_memory', True),
persistent_workers=config.get('persistent_workers', True)
)
# Construct the model
@@ -279,14 +323,24 @@ def train_forecasting_model(
model = model.to(device)
# Define loss function and optimizer
criterion = nn.MSELoss()
if use_ps_loss:
criterion = PSLoss(
patch_len_threshold=patch_len_threshold,
lambda_ps=ps_lambda,
use_gdw=use_gdw
)
else:
criterion = nn.MSELoss()
optimizer = optim.Adam(
model.parameters(),
lr=config.get('learning_rate', 1e-3),
)
# Add learning rate scheduler to halve LR after each epoch
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
# Create args object for learning rate adjustment
lr_args = dotdict({
'learning_rate': config.get('learning_rate', 1e-3),
'lradj': lr_adjust_strategy
})
# Initialize early stopping
early_stopping = EarlyStopping(
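For context, adjust_learning_rate and dotdict come from utils.tools; the sketch below shows the behaviour this call assumes ('type1' halves the learning rate every epoch), in the form such helpers commonly take in time-series repositories. It is an illustrative assumption, not a copy of this repository's utils/tools.py:

class dotdict(dict):
    # dict with attribute-style access, so lr_args.lradj and lr_args.learning_rate work
    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

def adjust_learning_rate(optimizer, epoch, args):
    # Sketch of the usual strategies; assumes 1-indexed epochs as in the Informer-style original.
    if args.lradj == 'type1':
        lr_adjust = {epoch: args.learning_rate * (0.5 ** ((epoch - 1) // 1))}
    elif args.lradj == 'type2':
        lr_adjust = {2: 5e-5, 4: 1e-5, 6: 5e-6, 8: 1e-6, 10: 5e-7, 15: 1e-7, 20: 5e-8}
    elif args.lradj == 'constant':
        lr_adjust = {epoch: args.learning_rate}
    else:
        lr_adjust = {}
    if epoch in lr_adjust:
        lr = lr_adjust[epoch]
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        print(f'Updating learning rate to {lr}')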
@@ -334,7 +388,11 @@ def train_forecasting_model(
# For simple models without time features
outputs = model(inputs)
loss = criterion(outputs, targets)
# Calculate loss
if use_ps_loss:
loss, loss_dict = criterion(outputs, targets, model)
else:
loss = criterion(outputs, targets)
# Backward pass and optimize
loss.backward()
@@ -345,10 +403,26 @@ def train_forecasting_model(
interval_loss += loss.item()
if (batch_idx + 1) % log_interval == 0:
print(f"Batch {batch_idx+1}/{len(dataloaders['train'])}, Loss: {loss.item():.4f}")
# Compute the average loss over this interval and log it
avg_interval_loss = interval_loss / log_interval
swanlab_run.log({"batch_train_loss": avg_interval_loss})
if use_ps_loss and 'loss_dict' in locals():
print(f"Batch {batch_idx+1}/{len(dataloaders['train'])}, "
f"Total Loss: {loss.item():.4f}, "
f"MSE: {loss_dict['mse_loss']:.4f}, "
f"PS: {loss_dict['ps_loss']:.4f}")
# Log detailed loss components
swanlab_run.log({
"batch_total_loss": loss.item(),
"batch_mse_loss": loss_dict['mse_loss'],
"batch_ps_loss": loss_dict['ps_loss'],
"batch_corr_loss": loss_dict['corr_loss'],
"batch_var_loss": loss_dict['var_loss'],
"batch_mean_loss": loss_dict['mean_loss'],
"alpha": loss_dict['alpha'],
"beta": loss_dict['beta'],
"gamma": loss_dict['gamma']
})
else:
print(f"Batch {batch_idx+1}/{len(dataloaders['train'])}, Loss: {loss.item():.4f}")
swanlab_run.log({"batch_train_loss": loss.item()})
# Reset interval_loss for the next interval
interval_loss = 0.0
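Taken together, the calls above pin down the PSLoss interface this trainer relies on; the summary below is inferred from this diff only (layers/ps_loss.py holds the actual implementation):

criterion = PSLoss(patch_len_threshold=64, lambda_ps=5.0, use_gdw=True)
loss, loss_dict = criterion(outputs, targets, model)
# `loss` is the scalar that gets backpropagated; `loss_dict` reports the components
# logged above: 'mse_loss', 'ps_loss', 'corr_loss', 'var_loss', 'mean_loss', plus the
# dynamic weights 'alpha', 'beta', 'gamma' produced by gradient-based dynamic weighting.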
@@ -360,6 +434,7 @@ def train_forecasting_model(
model.eval()
val_loss = 0.0
val_mse = 0.0
val_mse_criterion = nn.MSELoss() # Always use MSE for validation metrics
with torch.no_grad():
for batch_data in dataloaders['val']:
@@ -381,18 +456,28 @@ def train_forecasting_model(
# For simple models without time features
outputs = model(inputs)
# Calculate loss
loss = criterion(outputs, targets)
val_loss += loss.item()
# Compute the validation loss with the training criterion (PS or MSE)
if use_ps_loss:
loss, _ = criterion(outputs, targets, model)
val_loss += loss.item()
else:
loss = criterion(outputs, targets)
val_loss += loss.item()
# Always calculate MSE for validation metrics
mse_loss = val_mse_criterion(outputs, targets)
val_mse += mse_loss.item()
avg_val_loss = val_loss / len(dataloaders['val'])
avg_val_mse = val_mse / len(dataloaders['val'])
current_lr = optimizer.param_groups[0]['lr']
# Log metrics
metrics_dict = {
"train_loss": avg_train_loss,
"val_loss": avg_val_loss,
"val_mse": avg_val_mse,
"learning_rate": current_lr,
"epoch_time": epoch_time
}
@@ -402,6 +487,7 @@ def train_forecasting_model(
print(f"Epoch {epoch+1}/{max_epochs}, "
f"Train Loss: {avg_train_loss:.4f}, "
f"Val Loss: {avg_val_loss:.4f}, "
f"Val MSE: {avg_val_mse:.4f}, "
f"LR: {current_lr:.6f}, "
f"Time: {epoch_time:.2f}s")
@@ -416,16 +502,17 @@ def train_forecasting_model(
print("Early stopping triggered")
break
# Step the learning rate scheduler
scheduler.step()
# Adjust learning rate using utils.tools function
adjust_learning_rate(optimizer, epoch, lr_args)
# Load the best model
model.load_state_dict(torch.load(checkpoint_path))
# Test evaluation on the best model
# Test evaluation on the best model - Always use MSE for final evaluation
model.eval()
test_loss = 0.0
test_mse = 0.0
mse_criterion = nn.MSELoss() # Always use MSE for test evaluation
print("Evaluating on test set...")
with torch.no_grad():
@@ -448,16 +535,16 @@ def train_forecasting_model(
# For simple models without time features
outputs = model(inputs)
# Calculate loss
loss = criterion(outputs, targets)
test_loss += loss.item()
# Always calculate MSE for test evaluation (for fair comparison)
mse_loss = mse_criterion(outputs, targets)
test_loss += mse_loss.item()
test_loss /= len(dataloaders['test'])
print(f"Test evaluation completed!")
print(f"Test Loss (MSE): {test_loss:.6f}")
# Final validation for consistency
# Final validation for consistency - Always use MSE for final metrics
model.eval()
final_val_loss = 0.0
final_val_mse = 0.0
@@ -482,25 +569,31 @@ def train_forecasting_model(
# For simple models without time features
outputs = model(inputs)
# Calculate loss
loss = criterion(outputs, targets)
final_val_loss += loss.item()
# Always calculate MSE for final validation (for fair comparison)
mse_loss = mse_criterion(outputs, targets)
final_val_loss += mse_loss.item()
final_val_loss /= len(dataloaders['val'])
print(f"Final validation loss: {final_val_loss:.6f}")
print(f"Final validation MSE: {final_val_loss:.6f}")
print(f"Final test MSE: {test_loss:.6f}")
if use_ps_loss:
print("Note: Model was trained with PS Loss but evaluated with MSE for fair comparison")
# Log final test results to swanlab
final_metrics = {
"final_test_loss": test_loss,
"final_val_loss": final_val_loss
"final_test_mse": test_loss,
"final_val_mse": final_val_loss
}
swanlab_run.log(final_metrics)
# Update metrics with final values
# Update metrics with final values (always MSE for comparison)
metrics["final_val_loss"] = final_val_loss
metrics["final_test_loss"] = test_loss
metrics["final_val_mse"] = final_val_loss # Same as final_val_loss since we use MSE
metrics["final_test_mse"] = test_loss # Same as final_test_loss since we use MSE
# Finish the swanlab run
swanlab_run.finish()
@@ -519,7 +612,8 @@ def train_classification_model(
log_interval: int = 10,
use_x_mark: bool = True,
dataset_mode: str = "npz",
dataflow_args = None
dataflow_args = None,
lr_adjust_strategy: str = "type1"
) -> Tuple[nn.Module, Dict[str, float]]:
"""
Train a time series classification model
@@ -537,6 +631,7 @@ def train_classification_model(
use_x_mark (bool): Whether to use time features (x_mark) from the data file
dataset_mode (str): Dataset construction mode - "npz" or "dataflow"
dataflow_args: Arguments object for dataflow mode (required if dataset_mode="dataflow")
lr_adjust_strategy (str): Learning rate adjustment strategy - 'type1', 'type2', 'type3', 'sigmoid', 'constant', '3', '4', '5', '6'
Returns:
Tuple[nn.Module, Dict[str, float]]: Trained model and dictionary of evaluation metrics
@@ -567,7 +662,10 @@ def train_classification_model(
dataloaders = create_data_loaders(
data_path=data_path,
batch_size=config.get('batch_size', 32),
use_x_mark=use_x_mark
use_x_mark=use_x_mark,
num_workers=config.get('num_workers', 4),
pin_memory=config.get('pin_memory', True),
persistent_workers=config.get('persistent_workers', True)
)
# Construct the model
@@ -582,8 +680,11 @@ def train_classification_model(
weight_decay=config.get('weight_decay', 1e-4)
)
# Add learning rate scheduler to halve LR after each epoch
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)
# Create args object for learning rate adjustment
lr_args = dotdict({
'learning_rate': config.get('learning_rate', 1e-3),
'lradj': lr_adjust_strategy
})
# Initialize early stopping
early_stopping = EarlyStopping(
@@ -722,8 +823,8 @@ def train_classification_model(
print("Early stopping triggered")
break
# Step the learning rate scheduler
scheduler.step()
# Adjust learning rate using utils.tools function
adjust_learning_rate(optimizer, epoch, lr_args)
# Load the best model
model.load_state_dict(torch.load(checkpoint_path))
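The classification trainer picks up the same mechanics, so an existing call only needs the new keyword; a hedged sketch, where the config keys shown are exactly the ones this diff reads via config.get and the remaining arguments keep the existing signature:

config = {
    'batch_size': 32,
    'learning_rate': 1e-3,
    'weight_decay': 1e-4,
    'num_workers': 4,           # forwarded to create_data_loaders
    'pin_memory': True,
    'persistent_workers': True,
}
model, metrics = train_classification_model(
    # ...existing arguments (model, data_path, config=config, ...) stay as they are
    lr_adjust_strategy="type2",  # any strategy listed in the docstring
)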