feat(data): add flexible data loading with dataflow integration and test set

Author: game-loader
Date: 2025-08-06 18:36:43 +08:00
Parent: 6ee5c769c4
Commit: f977abeea7
3 changed files with 229 additions and 44 deletions

.gitignore vendored

@@ -1 +1,8 @@
.aider*
__pycache__/
*.pyc
*.pyo
*.pyd
*.npz
*.gz


@@ -7,6 +7,7 @@ import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import swanlab
from typing import Dict, Any, Optional, Callable, Union, Tuple
from dataflow import data_provider
class EarlyStopping:
"""Early stopping to stop training when validation performance doesn't improve."""
@@ -51,16 +52,103 @@ class EarlyStopping:
torch.save(model.state_dict(), self.path)
self.val_loss_min = val_loss
def create_data_loaders(data_path: str, batch_size: int = 32) -> Dict[str, DataLoader]:
class DatasetWrapperWithoutTimeFeatures(torch.utils.data.Dataset):
"""Wrapper to remove time features from dataflow datasets when use_x_mark=False"""
def __init__(self, original_dataset):
self.original_dataset = original_dataset
def __getitem__(self, index):
# Get original data (seq_x, seq_y, seq_x_mark, seq_y_mark)
seq_x, seq_y, seq_x_mark, seq_y_mark = self.original_dataset[index]
# Return only seq_x and seq_y (remove time features)
return seq_x, seq_y
def __len__(self):
return len(self.original_dataset)
def inverse_transform(self, data):
if hasattr(self.original_dataset, 'inverse_transform'):
return self.original_dataset.inverse_transform(data)
return data
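For illustration, a minimal sketch of the wrapper in isolation; the toy _FourTuple dataset below is a hypothetical stand-in for whatever data_provider returns, and the shapes are made up:
import torch
class _FourTuple(torch.utils.data.Dataset):
    """Hypothetical stand-in for a dataflow dataset that yields 4-tuples."""
    def __len__(self):
        return 4
    def __getitem__(self, i):
        seq_x = torch.zeros(96, 7)       # (seq_len, n_features)
        seq_y = torch.zeros(72, 7)       # (label_len + pred_len, n_features)
        seq_x_mark = torch.zeros(96, 4)  # encoder-window time features
        seq_y_mark = torch.zeros(72, 4)  # decoder-window time features
        return seq_x, seq_y, seq_x_mark, seq_y_mark
wrapped = DatasetWrapperWithoutTimeFeatures(_FourTuple())
seq_x, seq_y = wrapped[0]  # time features are stripped
assert len(wrapped) == 4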
def create_data_loaders_from_dataflow(args, use_x_mark: bool = True) -> Dict[str, DataLoader]:
"""
Create PyTorch DataLoaders using dataflow data_provider
Args:
args: Arguments object containing dataset configuration
Required attributes: data, root_path, data_path, seq_len, label_len, pred_len,
features, target, embed, freq, batch_size, num_workers, train_only
use_x_mark (bool): Whether to use time features (x_mark and y_mark)
Returns:
Dict[str, DataLoader]: Dictionary with train, val, and test DataLoaders
"""
# Create datasets and dataloaders for each split
train_data, _ = data_provider(args, flag='train')
val_data, _ = data_provider(args, flag='val')
test_data, _ = data_provider(args, flag='test')
# Wrap datasets to respect use_x_mark parameter
if not use_x_mark:
train_data = DatasetWrapperWithoutTimeFeatures(train_data)
val_data = DatasetWrapperWithoutTimeFeatures(val_data)
test_data = DatasetWrapperWithoutTimeFeatures(test_data)
# Per-split loader settings: shuffle only the training data; every split drops its last incomplete batch
train_shuffle = True
val_shuffle = False
test_shuffle = False
train_drop_last = True
val_drop_last = True
test_drop_last = True
batch_size = args.batch_size
num_workers = args.num_workers
# Create new dataloaders with wrapped datasets
train_loader = DataLoader(
train_data,
batch_size=batch_size,
shuffle=train_shuffle,
num_workers=num_workers,
drop_last=train_drop_last
)
val_loader = DataLoader(
val_data,
batch_size=batch_size,
shuffle=val_shuffle,
num_workers=num_workers,
drop_last=val_drop_last
)
test_loader = DataLoader(
test_data,
batch_size=batch_size,
shuffle=test_shuffle,
num_workers=num_workers,
drop_last=test_drop_last
)
return {
'train': train_loader,
'val': val_loader,
'test': test_loader
}
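As a usage sketch, the args object can be any attribute container. The attribute names below follow the docstring above, but the concrete values (ETTh1 paths, window lengths) are illustrative assumptions; the exact set data_provider expects is defined by the dataflow package:
from types import SimpleNamespace
# Illustrative values; the attribute list mirrors the docstring above.
dataflow_args = SimpleNamespace(
    data='ETTh1', root_path='./data/', data_path='ETTh1.csv',
    seq_len=96, label_len=48, pred_len=24,
    features='M', target='OT', embed='timeF', freq='h',
    batch_size=32, num_workers=0, train_only=False,
)
loaders = create_data_loaders_from_dataflow(dataflow_args, use_x_mark=True)
seq_x, seq_y, seq_x_mark, seq_y_mark = next(iter(loaders['train']))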
def create_data_loaders(data_path: str, batch_size: int = 32, use_x_mark: bool = True) -> Dict[str, DataLoader]:
"""
Create PyTorch DataLoaders from an NPZ file
Args:
data_path (str): Path to the NPZ file containing the data
batch_size (int): Batch size for the DataLoaders
use_x_mark (bool): Whether to use time features (x_mark) from the data file
Returns:
Dict[str, DataLoader]: Dictionary with train and val DataLoaders
Dict[str, DataLoader]: Dictionary with train, val, and test DataLoaders
"""
# Load data from NPZ file
data = np.load(data_path, allow_pickle=True)
@@ -68,18 +156,32 @@ def create_data_loaders(data_path: str, batch_size: int = 32) -> Dict[str, DataLoader]:
train_y = data['train_y']
val_x = data['val_x']
val_y = data['val_y']
test_x = data['test_x']
test_y = data['test_y']
# Load time features if available
# Load time features if available and needed
if use_x_mark:
train_x_mark = data.get('train_x_mark', None)
train_y_mark = data.get('train_y_mark', None)
val_x_mark = data.get('val_x_mark', None)
val_y_mark = data.get('val_y_mark', None)
test_x_mark = data.get('test_x_mark', None)
test_y_mark = data.get('test_y_mark', None)
else:
train_x_mark = None
train_y_mark = None
val_x_mark = None
val_y_mark = None
test_x_mark = None
test_y_mark = None
# Convert to PyTorch tensors
train_x = torch.FloatTensor(train_x)
train_y = torch.FloatTensor(train_y)
val_x = torch.FloatTensor(val_x)
val_y = torch.FloatTensor(val_y)
test_x = torch.FloatTensor(test_x)
test_y = torch.FloatTensor(test_y)
# Create datasets based on whether time features are available
if train_x_mark is not None:
@@ -87,20 +189,26 @@ def create_data_loaders(data_path: str, batch_size: int = 32) -> Dict[str, DataLoader]:
train_y_mark = torch.FloatTensor(train_y_mark)
val_x_mark = torch.FloatTensor(val_x_mark)
val_y_mark = torch.FloatTensor(val_y_mark)
test_x_mark = torch.FloatTensor(test_x_mark)
test_y_mark = torch.FloatTensor(test_y_mark)
train_dataset = TensorDataset(train_x, train_y, train_x_mark, train_y_mark)
val_dataset = TensorDataset(val_x, val_y, val_x_mark, val_y_mark)
test_dataset = TensorDataset(test_x, test_y, test_x_mark, test_y_mark)
else:
train_dataset = TensorDataset(train_x, train_y)
val_dataset = TensorDataset(val_x, val_y)
test_dataset = TensorDataset(test_x, test_y)
# Create dataloaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
return {
'train': train_loader,
'val': val_loader
'val': val_loader,
'test': test_loader
}
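For reference, a sketch of the NPZ layout this loader expects; the file name and array shapes are made-up placeholders. The *_x_mark / *_y_mark arrays are optional, and passing use_x_mark=False ignores them even when present:
import numpy as np
np.savez(
    'processed.npz',  # hypothetical path; shapes are (n_samples, window, n_features)
    train_x=np.zeros((800, 96, 7), dtype=np.float32),
    train_y=np.zeros((800, 24, 7), dtype=np.float32),
    val_x=np.zeros((100, 96, 7), dtype=np.float32),
    val_y=np.zeros((100, 24, 7), dtype=np.float32),
    test_x=np.zeros((100, 96, 7), dtype=np.float32),
    test_y=np.zeros((100, 24, 7), dtype=np.float32),
)
loaders = create_data_loaders('processed.npz', batch_size=32, use_x_mark=False)
x, y = next(iter(loaders['train']))  # plain (x, y) batches, no time features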
def train_forecasting_model(
@@ -112,14 +220,17 @@
early_stopping_patience: int = 10,
max_epochs: int = 100,
checkpoint_dir: str = "./checkpoints",
log_interval: int = 10
log_interval: int = 10,
use_x_mark: bool = True,
dataset_mode: str = "npz",
dataflow_args=None
) -> Tuple[nn.Module, Dict[str, float]]:
"""
Train a time series forecasting model
Args:
model_constructor (Callable): Function that constructs and returns the model
data_path (str): Path to the NPZ file containing the processed data
data_path (str): Path to the NPZ file containing the processed data (for npz mode)
project_name (str): Name of the project for swanlab tracking
config (Dict[str, Any]): Configuration dictionary for the experiment
device (Optional[str]): Device to use for training ('cpu' or 'cuda')
@@ -127,6 +238,9 @@
max_epochs (int): Maximum number of epochs to train for
checkpoint_dir (str): Directory to save model checkpoints
log_interval (int): How often to log metrics during training
use_x_mark (bool): Whether to use time features (x_mark) from the data file
dataset_mode (str): Dataset construction mode - "npz" or "dataflow"
dataflow_args: Arguments object for dataflow mode (required if dataset_mode="dataflow")
Returns:
Tuple[nn.Module, Dict[str, float]]: Trained model and dictionary of evaluation metrics
@@ -145,10 +259,19 @@
os.makedirs(checkpoint_dir, exist_ok=True)
checkpoint_path = os.path.join(checkpoint_dir, f"{project_name}.pt")
# Create data loaders
# Create data loaders based on dataset_mode
if dataset_mode == "dataflow":
if dataflow_args is None:
raise ValueError("dataflow_args is required when dataset_mode='dataflow'")
dataloaders = create_data_loaders_from_dataflow(
dataflow_args,
use_x_mark=use_x_mark
)
else: # Default to "npz" mode
dataloaders = create_data_loaders(
data_path=data_path,
batch_size=config.get('batch_size', 32)
batch_size=config.get('batch_size', 32),
use_x_mark=use_x_mark
)
# Construct the model
@@ -206,7 +329,6 @@
if x_mark is not None:
# For TimesNet model with time features
# Create decoder input (zeros for forecasting)
dec_inp = torch.zeros_like(targets).to(device)
outputs = model(inputs, x_mark)
else:
# For simple models without time features
@@ -244,17 +366,16 @@
# Handle both cases: with and without time features
if len(batch_data) == 4: # With time features
inputs, targets, x_mark, y_mark = batch_data
inputs, targets = inputs.to(device), targets.to(device)
x_mark, y_mark = x_mark.to(device), y_mark.to(device)
inputs, targets = inputs.float().to(device), targets.float().to(device)
x_mark, y_mark = x_mark.float().to(device), y_mark.float().to(device)
else: # Without time features
inputs, targets = batch_data
inputs, targets = inputs.to(device), targets.to(device)
inputs, targets = inputs.float().to(device), targets.float().to(device)
x_mark, y_mark = None, None
# Forward pass - handle both cases
if x_mark is not None:
# For TimesNet model with time features
dec_inp = torch.zeros_like(targets).to(device)
outputs = model(inputs, x_mark)
else:
# For simple models without time features
@@ -301,7 +422,42 @@
# Load the best model
model.load_state_dict(torch.load(checkpoint_path))
# Final validation
# Test evaluation on the best model
model.eval()
test_loss = 0.0
test_mse = 0.0
print("Evaluating on test set...")
with torch.no_grad():
for batch_data in dataloaders['test']:
# Handle both cases: with and without time features
if len(batch_data) == 4: # With time features
inputs, targets, x_mark, y_mark = batch_data
inputs, targets = inputs.float().to(device), targets.float().to(device)
x_mark, y_mark = x_mark.float().to(device), y_mark.float().to(device)
else: # Without time features
inputs, targets = batch_data
inputs, targets = inputs.float().to(device), targets.float().to(device)
x_mark, y_mark = None, None
# Forward pass - handle both cases
if x_mark is not None:
# For TimesNet model with time features
outputs = model(inputs, x_mark)
else:
# For simple models without time features
outputs = model(inputs)
# Calculate loss
loss = criterion(outputs, targets)
test_loss += loss.item()
test_loss /= len(dataloaders['test'])
print(f"Test evaluation completed!")
print(f"Test Loss (MSE): {test_loss:.6f}")
# Final validation for consistency
model.eval()
final_val_loss = 0.0
final_val_mse = 0.0
@@ -311,18 +467,17 @@
# Handle both cases: with and without time features
if len(batch_data) == 4: # With time features
inputs, targets, x_mark, y_mark = batch_data
inputs, targets = inputs.to(device), targets.to(device)
x_mark, y_mark = x_mark.to(device), y_mark.to(device)
inputs, targets = inputs.float().to(device), targets.float().to(device)
x_mark, y_mark = x_mark.float().to(device), y_mark.float().to(device)
else: # Without time features
inputs, targets = batch_data
inputs, targets = inputs.to(device), targets.to(device)
inputs, targets = inputs.float().to(device), targets.float().to(device)
x_mark, y_mark = None, None
# Forward pass - handle both cases
if x_mark is not None:
# For TimesNet model with time features
dec_inp = torch.zeros_like(targets).to(device)
outputs = model(inputs, x_mark, dec_inp, y_mark)
outputs = model(inputs, x_mark)
else:
# For simple models without time features
outputs = model(inputs)
@@ -334,10 +489,18 @@
final_val_loss /= len(dataloaders['val'])
print(f"Final validation loss: {final_val_loss:.4f}")
print(f"Final validation loss: {final_val_loss:.6f}")
# Log final test results to swanlab
final_metrics = {
"final_test_loss": test_loss,
"final_val_loss": final_val_loss
}
swanlab_run.log(final_metrics)
# Update metrics with final values
metrics["final_val_loss"] = final_val_loss
metrics["final_test_loss"] = test_loss
# Finish the swanlab run
swanlab_run.finish()
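Putting it together, a hedged sketch of calling the forecasting trainer in npz mode; LinearForecaster is a hypothetical baseline and the config keys are illustrative, not the full set the trainer reads:
import torch.nn as nn
def build_model():
    """Hypothetical baseline: project a 96-step window onto a 24-step horizon."""
    class LinearForecaster(nn.Module):
        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(96, 24)
        def forward(self, x):  # x: (batch, seq_len, n_features)
            return self.proj(x.transpose(1, 2)).transpose(1, 2)
    return LinearForecaster()
model, metrics = train_forecasting_model(
    model_constructor=build_model,
    data_path='processed.npz',          # hypothetical NPZ from the sketch above
    project_name='etth1_linear_baseline',
    config={'batch_size': 32},
    dataset_mode='npz',                 # or 'dataflow' together with dataflow_args=...
    use_x_mark=False,                   # this baseline takes no time features
)
print(metrics['final_test_loss'])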
@@ -353,14 +516,17 @@ def train_classification_model(
early_stopping_patience: int = 10,
max_epochs: int = 100,
checkpoint_dir: str = "./checkpoints",
log_interval: int = 10
log_interval: int = 10,
use_x_mark: bool = True,
dataset_mode: str = "npz",
dataflow_args=None
) -> Tuple[nn.Module, Dict[str, float]]:
"""
Train a time series classification model
Args:
model_constructor (Callable): Function that constructs and returns the model
data_path (str): Path to the NPZ file containing the processed data
data_path (str): Path to the NPZ file containing the processed data (for npz mode)
project_name (str): Name of the project for swanlab tracking
config (Dict[str, Any]): Configuration dictionary for the experiment
device (Optional[str]): Device to use for training ('cpu' or 'cuda')
@@ -368,6 +534,9 @@ def train_classification_model(
max_epochs (int): Maximum number of epochs to train for
checkpoint_dir (str): Directory to save model checkpoints
log_interval (int): How often to log metrics during training
use_x_mark (bool): Whether to use time features (x_mark) from the data file
dataset_mode (str): Dataset construction mode - "npz" or "dataflow"
dataflow_args: Arguments object for dataflow mode (required if dataset_mode="dataflow")
Returns:
Tuple[nn.Module, Dict[str, float]]: Trained model and dictionary of evaluation metrics
@@ -386,10 +555,19 @@ def train_classification_model(
os.makedirs(checkpoint_dir, exist_ok=True)
checkpoint_path = os.path.join(checkpoint_dir, f"{project_name}.pt")
# Create data loaders
# Create data loaders based on dataset_mode
if dataset_mode == "dataflow":
if dataflow_args is None:
raise ValueError("dataflow_args is required when dataset_mode='dataflow'")
dataloaders = create_data_loaders_from_dataflow(
dataflow_args,
use_x_mark=use_x_mark
)
else: # Default to "npz" mode
dataloaders = create_data_loaders(
data_path=data_path,
batch_size=config.get('batch_size', 32)
batch_size=config.get('batch_size', 32),
use_x_mark=use_x_mark
)
# Construct the model
@@ -432,11 +610,11 @@
# Handle both cases: with and without time features
if len(batch_data) == 4: # With time features
inputs, targets, x_mark, y_mark = batch_data
inputs, targets = inputs.to(device), targets.to(device)
x_mark, y_mark = x_mark.to(device), y_mark.to(device)
inputs, targets = inputs.float().to(device), targets.float().to(device)
x_mark, y_mark = x_mark.float().to(device), y_mark.float().to(device)
else: # Without time features
inputs, targets = batch_data
inputs, targets = inputs.to(device), targets.to(device)
inputs, targets = inputs.float().to(device), targets.float().to(device)
x_mark, y_mark = None, None
# Convert targets to long for classification
@@ -484,11 +662,11 @@
# Handle both cases: with and without time features
if len(batch_data) == 4: # With time features
inputs, targets, x_mark, y_mark = batch_data
inputs, targets = inputs.to(device), targets.to(device)
x_mark, y_mark = x_mark.to(device), y_mark.to(device)
inputs, targets = inputs.float().to(device), targets.float().to(device)
x_mark, y_mark = x_mark.float().to(device), y_mark.float().to(device)
else: # Without time features
inputs, targets = batch_data
inputs, targets = inputs.to(device), targets.to(device)
inputs, targets = inputs.float().to(device), targets.float().to(device)
x_mark, y_mark = None, None
targets = targets.long()
@@ -561,11 +739,11 @@
# Handle both cases: with and without time features
if len(batch_data) == 4: # With time features
inputs, targets, x_mark, y_mark = batch_data
inputs, targets = inputs.to(device), targets.to(device)
x_mark, y_mark = x_mark.to(device), y_mark.to(device)
inputs, targets = inputs.float().to(device), targets.float().to(device)
x_mark, y_mark = x_mark.float().to(device), y_mark.float().to(device)
else: # Without time features
inputs, targets = batch_data
inputs, targets = inputs.to(device), targets.to(device)
inputs, targets = inputs.float().to(device), targets.float().to(device)
x_mark, y_mark = None, None
targets = targets.long()
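The classification entry point mirrors the forecasting one. Below is a hedged sketch with a hypothetical ConvClassifier and illustrative sizes; since the loop casts targets to long, label arrays in the NPZ can be stored as integers:
import torch.nn as nn
def build_classifier():
    """Hypothetical baseline classifier; feature and class counts are illustrative."""
    class ConvClassifier(nn.Module):
        def __init__(self, n_features=7, n_classes=5):
            super().__init__()
            self.net = nn.Sequential(
                nn.Conv1d(n_features, 32, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.AdaptiveAvgPool1d(1),
                nn.Flatten(),
                nn.Linear(32, n_classes),
            )
        def forward(self, x):  # x: (batch, seq_len, n_features) -> (batch, n_classes)
            return self.net(x.transpose(1, 2))
    return ConvClassifier()
model, metrics = train_classification_model(
    model_constructor=build_classifier,
    data_path='processed_cls.npz',  # hypothetical path
    project_name='cls_baseline',
    config={'batch_size': 32},
    use_x_mark=False,
)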