feat(data): add flexible data loading with dataflow integration and test set

game-loader
2025-08-06 18:36:43 +08:00
parent 6ee5c769c4
commit f977abeea7
3 changed files with 229 additions and 44 deletions

.gitignore
@@ -1 +1,8 @@
 .aider*
+__pycache__/
+*.pyc
+*.pyo
+*.pyd
+*.npz
+*.gz


@@ -7,6 +7,7 @@ import torch.optim as optim
 from torch.utils.data import DataLoader, TensorDataset
 import swanlab
 from typing import Dict, Any, Optional, Callable, Union, Tuple
+from dataflow import data_provider

 class EarlyStopping:
     """Early stopping to stop training when validation performance doesn't improve."""
@@ -51,16 +52,103 @@ class EarlyStopping:
         torch.save(model.state_dict(), self.path)
         self.val_loss_min = val_loss

-def create_data_loaders(data_path: str, batch_size: int = 32) -> Dict[str, DataLoader]:
+class DatasetWrapperWithoutTimeFeatures(torch.utils.data.Dataset):
+    """Wrapper to remove time features from dataflow datasets when use_x_mark=False"""
+
+    def __init__(self, original_dataset):
+        self.original_dataset = original_dataset
+
+    def __getitem__(self, index):
+        # Get original data (seq_x, seq_y, seq_x_mark, seq_y_mark)
+        seq_x, seq_y, seq_x_mark, seq_y_mark = self.original_dataset[index]
+        # Return only seq_x and seq_y (remove time features)
+        return seq_x, seq_y
+
+    def __len__(self):
+        return len(self.original_dataset)
+
+    def inverse_transform(self, data):
+        if hasattr(self.original_dataset, 'inverse_transform'):
+            return self.original_dataset.inverse_transform(data)
+        return data
+
+
+def create_data_loaders_from_dataflow(args, use_x_mark: bool = True) -> Dict[str, DataLoader]:
+    """
+    Create PyTorch DataLoaders using the dataflow data_provider
+
+    Args:
+        args: Arguments object containing the dataset configuration.
+            Required attributes: data, root_path, data_path, seq_len, label_len, pred_len,
+            features, target, embed, freq, batch_size, num_workers, train_only
+        use_x_mark (bool): Whether to use time features (x_mark and y_mark)
+
+    Returns:
+        Dict[str, DataLoader]: Dictionary with train, val, and test DataLoaders
+    """
+    # Create datasets for each split
+    train_data, _ = data_provider(args, flag='train')
+    val_data, _ = data_provider(args, flag='val')
+    test_data, _ = data_provider(args, flag='test')
+
+    # Wrap datasets to respect the use_x_mark parameter
+    if not use_x_mark:
+        train_data = DatasetWrapperWithoutTimeFeatures(train_data)
+        val_data = DatasetWrapperWithoutTimeFeatures(val_data)
+        test_data = DatasetWrapperWithoutTimeFeatures(test_data)
+
+    # Shuffle/drop_last settings and batch parameters per split
+    train_shuffle = True
+    val_shuffle = False
+    test_shuffle = False
+    train_drop_last = True
+    val_drop_last = True
+    test_drop_last = True
+    batch_size = args.batch_size
+    num_workers = args.num_workers
+
+    # Create new dataloaders with the (possibly wrapped) datasets
+    train_loader = DataLoader(
+        train_data,
+        batch_size=batch_size,
+        shuffle=train_shuffle,
+        num_workers=num_workers,
+        drop_last=train_drop_last
+    )
+    val_loader = DataLoader(
+        val_data,
+        batch_size=batch_size,
+        shuffle=val_shuffle,
+        num_workers=num_workers,
+        drop_last=val_drop_last
+    )
+    test_loader = DataLoader(
+        test_data,
+        batch_size=batch_size,
+        shuffle=test_shuffle,
+        num_workers=num_workers,
+        drop_last=test_drop_last
+    )
+
+    return {
+        'train': train_loader,
+        'val': val_loader,
+        'test': test_loader
+    }
+
+
+def create_data_loaders(data_path: str, batch_size: int = 32, use_x_mark: bool = True) -> Dict[str, DataLoader]:
     """
     Create PyTorch DataLoaders from an NPZ file

     Args:
         data_path (str): Path to the NPZ file containing the data
         batch_size (int): Batch size for the DataLoaders
+        use_x_mark (bool): Whether to use time features (x_mark) from the data file

     Returns:
-        Dict[str, DataLoader]: Dictionary with train and val DataLoaders
+        Dict[str, DataLoader]: Dictionary with train, val, and test DataLoaders
     """
     # Load data from NPZ file
     data = np.load(data_path, allow_pickle=True)
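Note: a minimal usage sketch of the new dataflow path. The attribute names follow the docstring above; the concrete values (ETTh1 paths, window sizes) are illustrative assumptions, not values taken from this commit.

```python
from types import SimpleNamespace

# Hypothetical configuration object; values are placeholders.
dataflow_args = SimpleNamespace(
    data='ETTh1',           # dataset key understood by data_provider (assumed)
    root_path='./data/',    # directory containing the raw file (assumed)
    data_path='ETTh1.csv',  # file name inside root_path (assumed)
    seq_len=96, label_len=48, pred_len=24,
    features='M', target='OT',
    embed='timeF', freq='h',
    batch_size=32, num_workers=0,
    train_only=False,
)

loaders = create_data_loaders_from_dataflow(dataflow_args, use_x_mark=False)
for seq_x, seq_y in loaders['train']:
    # With use_x_mark=False the wrapper yields only (seq_x, seq_y)
    break
```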
@@ -68,18 +156,32 @@ def create_data_loaders(data_path: str, batch_size: int = 32) -> Dict[str, DataLoader]:
     train_y = data['train_y']
     val_x = data['val_x']
     val_y = data['val_y']
+    test_x = data['test_x']
+    test_y = data['test_y']

-    # Load time features if available
-    train_x_mark = data.get('train_x_mark', None)
-    train_y_mark = data.get('train_y_mark', None)
-    val_x_mark = data.get('val_x_mark', None)
-    val_y_mark = data.get('val_y_mark', None)
+    # Load time features if available and needed
+    if use_x_mark:
+        train_x_mark = data.get('train_x_mark', None)
+        train_y_mark = data.get('train_y_mark', None)
+        val_x_mark = data.get('val_x_mark', None)
+        val_y_mark = data.get('val_y_mark', None)
+        test_x_mark = data.get('test_x_mark', None)
+        test_y_mark = data.get('test_y_mark', None)
+    else:
+        train_x_mark = None
+        train_y_mark = None
+        val_x_mark = None
+        val_y_mark = None
+        test_x_mark = None
+        test_y_mark = None

     # Convert to PyTorch tensors
     train_x = torch.FloatTensor(train_x)
     train_y = torch.FloatTensor(train_y)
     val_x = torch.FloatTensor(val_x)
     val_y = torch.FloatTensor(val_y)
+    test_x = torch.FloatTensor(test_x)
+    test_y = torch.FloatTensor(test_y)

     # Create datasets based on whether time features are available
     if train_x_mark is not None:
@@ -87,20 +189,26 @@ def create_data_loaders(data_path: str, batch_size: int = 32) -> Dict[str, DataLoader]:
         train_y_mark = torch.FloatTensor(train_y_mark)
         val_x_mark = torch.FloatTensor(val_x_mark)
         val_y_mark = torch.FloatTensor(val_y_mark)
+        test_x_mark = torch.FloatTensor(test_x_mark)
+        test_y_mark = torch.FloatTensor(test_y_mark)

         train_dataset = TensorDataset(train_x, train_y, train_x_mark, train_y_mark)
         val_dataset = TensorDataset(val_x, val_y, val_x_mark, val_y_mark)
+        test_dataset = TensorDataset(test_x, test_y, test_x_mark, test_y_mark)
     else:
         train_dataset = TensorDataset(train_x, train_y)
         val_dataset = TensorDataset(val_x, val_y)
+        test_dataset = TensorDataset(test_x, test_y)

     # Create dataloaders
     train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
     val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
+    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

     return {
         'train': train_loader,
-        'val': val_loader
+        'val': val_loader,
+        'test': test_loader
     }
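Note: the npz path expects the six split arrays under fixed keys, with the `*_mark` time-feature arrays optional (absent keys fall back to the no-time-features branch via `data.get`). A sketch of producing a compatible file; the shapes are illustrative assumptions.

```python
import numpy as np

# Keys match what create_data_loaders reads; shapes are placeholders.
np.savez(
    'dataset.npz',
    train_x=np.zeros((1000, 96, 7), dtype=np.float32),
    train_y=np.zeros((1000, 24, 7), dtype=np.float32),
    val_x=np.zeros((200, 96, 7), dtype=np.float32),
    val_y=np.zeros((200, 24, 7), dtype=np.float32),
    test_x=np.zeros((200, 96, 7), dtype=np.float32),
    test_y=np.zeros((200, 24, 7), dtype=np.float32),
    # Optionally: train_x_mark, train_y_mark, val_x_mark, val_y_mark,
    # test_x_mark, test_y_mark for models that consume time features.
)

loaders = create_data_loaders('dataset.npz', batch_size=32, use_x_mark=False)
```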
@@ -112,14 +220,17 @@ def train_forecasting_model(
     early_stopping_patience: int = 10,
     max_epochs: int = 100,
     checkpoint_dir: str = "./checkpoints",
-    log_interval: int = 10
+    log_interval: int = 10,
+    use_x_mark: bool = True,
+    dataset_mode: str = "npz",
+    dataflow_args=None
 ) -> Tuple[nn.Module, Dict[str, float]]:
     """
     Train a time series forecasting model

     Args:
         model_constructor (Callable): Function that constructs and returns the model
-        data_path (str): Path to the NPZ file containing the processed data
+        data_path (str): Path to the NPZ file containing the processed data (for npz mode)
         project_name (str): Name of the project for swanlab tracking
         config (Dict[str, Any]): Configuration dictionary for the experiment
         device (Optional[str]): Device to use for training ('cpu' or 'cuda')
@@ -127,6 +238,9 @@ def train_forecasting_model(
         max_epochs (int): Maximum number of epochs to train for
         checkpoint_dir (str): Directory to save model checkpoints
         log_interval (int): How often to log metrics during training
+        use_x_mark (bool): Whether to use time features (x_mark) from the data file
+        dataset_mode (str): Dataset construction mode - "npz" or "dataflow"
+        dataflow_args: Arguments object for dataflow mode (required if dataset_mode="dataflow")

     Returns:
         Tuple[nn.Module, Dict[str, float]]: Trained model and dictionary of evaluation metrics
@@ -145,11 +259,20 @@ def train_forecasting_model(
     os.makedirs(checkpoint_dir, exist_ok=True)
     checkpoint_path = os.path.join(checkpoint_dir, f"{project_name}.pt")

-    # Create data loaders
-    dataloaders = create_data_loaders(
-        data_path=data_path,
-        batch_size=config.get('batch_size', 32)
-    )
+    # Create data loaders based on dataset_mode
+    if dataset_mode == "dataflow":
+        if dataflow_args is None:
+            raise ValueError("dataflow_args is required when dataset_mode='dataflow'")
+        dataloaders = create_data_loaders_from_dataflow(
+            dataflow_args,
+            use_x_mark=use_x_mark
+        )
+    else:  # Default to "npz" mode
+        dataloaders = create_data_loaders(
+            data_path=data_path,
+            batch_size=config.get('batch_size', 32),
+            use_x_mark=use_x_mark
+        )

     # Construct the model
     model = model_constructor()
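Note: hypothetical call sites for the two modes. `MyForecastModel`, the config values, and the file names are placeholders, not part of this commit.

```python
# NPZ mode (default): data is read from the given file.
model, metrics = train_forecasting_model(
    model_constructor=lambda: MyForecastModel(),  # assumed user-defined model
    data_path='dataset.npz',
    project_name='forecast-npz',
    config={'batch_size': 32},
    dataset_mode='npz',
    use_x_mark=False,
)

# Dataflow mode: data_path is unused; dataflow_args drives data_provider.
model, metrics = train_forecasting_model(
    model_constructor=lambda: MyForecastModel(),
    data_path='',
    project_name='forecast-dataflow',
    config={'batch_size': 32},
    dataset_mode='dataflow',
    dataflow_args=dataflow_args,  # e.g. the SimpleNamespace sketched earlier
)
```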
@@ -206,7 +329,6 @@ def train_forecasting_model(
             if x_mark is not None:
                 # For TimesNet model with time features
                 # Create decoder input (zeros for forecasting)
-                dec_inp = torch.zeros_like(targets).to(device)
                 outputs = model(inputs, x_mark)
             else:
                 # For simple models without time features
@@ -244,17 +366,16 @@ def train_forecasting_model(
             # Handle both cases: with and without time features
             if len(batch_data) == 4:  # With time features
                 inputs, targets, x_mark, y_mark = batch_data
-                inputs, targets = inputs.to(device), targets.to(device)
-                x_mark, y_mark = x_mark.to(device), y_mark.to(device)
+                inputs, targets = inputs.float().to(device), targets.float().to(device)
+                x_mark, y_mark = x_mark.float().to(device), y_mark.float().to(device)
             else:  # Without time features
                 inputs, targets = batch_data
-                inputs, targets = inputs.to(device), targets.to(device)
+                inputs, targets = inputs.float().to(device), targets.float().to(device)
                 x_mark, y_mark = None, None

             # Forward pass - handle both cases
             if x_mark is not None:
                 # For TimesNet model with time features
-                dec_inp = torch.zeros_like(targets).to(device)
                 outputs = model(inputs, x_mark)
             else:
                 # For simple models without time features
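Note: the `.to(device)` → `.float().to(device)` changes in these loops guard against double-precision inputs: datasets built from NumPy arrays often yield float64 tensors, while model weights default to float32. A minimal illustration of the mismatch this avoids:

```python
import numpy as np
import torch

layer = torch.nn.Linear(7, 7)               # weights are float32
batch = torch.from_numpy(np.zeros((4, 7)))  # float64 by default

# layer(batch) would raise a dtype-mismatch RuntimeError;
# casting first, as the training loops now do, avoids it:
out = layer(batch.float())
```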
@@ -301,7 +422,42 @@ def train_forecasting_model(
     # Load the best model
     model.load_state_dict(torch.load(checkpoint_path))

-    # Final validation
+    # Test evaluation on the best model
+    model.eval()
+    test_loss = 0.0
+    test_mse = 0.0
+    print("Evaluating on test set...")
+    with torch.no_grad():
+        for batch_data in dataloaders['test']:
+            # Handle both cases: with and without time features
+            if len(batch_data) == 4:  # With time features
+                inputs, targets, x_mark, y_mark = batch_data
+                inputs, targets = inputs.float().to(device), targets.float().to(device)
+                x_mark, y_mark = x_mark.float().to(device), y_mark.float().to(device)
+            else:  # Without time features
+                inputs, targets = batch_data
+                inputs, targets = inputs.float().to(device), targets.float().to(device)
+                x_mark, y_mark = None, None
+
+            # Forward pass - handle both cases
+            if x_mark is not None:
+                # For TimesNet model with time features
+                outputs = model(inputs, x_mark)
+            else:
+                # For simple models without time features
+                outputs = model(inputs)
+
+            # Calculate loss
+            loss = criterion(outputs, targets)
+            test_loss += loss.item()
+
+    test_loss /= len(dataloaders['test'])
+    print("Test evaluation completed!")
+    print(f"Test Loss (MSE): {test_loss:.6f}")
+
+    # Final validation for consistency
     model.eval()
     final_val_loss = 0.0
     final_val_mse = 0.0
@@ -311,18 +467,17 @@ def train_forecasting_model(
            # Handle both cases: with and without time features
            if len(batch_data) == 4:  # With time features
                inputs, targets, x_mark, y_mark = batch_data
-               inputs, targets = inputs.to(device), targets.to(device)
-               x_mark, y_mark = x_mark.to(device), y_mark.to(device)
+               inputs, targets = inputs.float().to(device), targets.float().to(device)
+               x_mark, y_mark = x_mark.float().to(device), y_mark.float().to(device)
            else:  # Without time features
                inputs, targets = batch_data
-               inputs, targets = inputs.to(device), targets.to(device)
+               inputs, targets = inputs.float().to(device), targets.float().to(device)
                x_mark, y_mark = None, None

            # Forward pass - handle both cases
            if x_mark is not None:
                # For TimesNet model with time features
-               dec_inp = torch.zeros_like(targets).to(device)
-               outputs = model(inputs, x_mark, dec_inp, y_mark)
+               outputs = model(inputs, x_mark)
            else:
                # For simple models without time features
                outputs = model(inputs)
@@ -334,10 +489,18 @@ def train_forecasting_model(
     final_val_loss /= len(dataloaders['val'])
-    print(f"Final validation loss: {final_val_loss:.4f}")
+    print(f"Final validation loss: {final_val_loss:.6f}")
+
+    # Log final test results to swanlab
+    final_metrics = {
+        "final_test_loss": test_loss,
+        "final_val_loss": final_val_loss
+    }
+    swanlab_run.log(final_metrics)

     # Update metrics with final values
     metrics["final_val_loss"] = final_val_loss
+    metrics["final_test_loss"] = test_loss

     # Finish the swanlab run
     swanlab_run.finish()
@@ -353,14 +516,17 @@ def train_classification_model(
     early_stopping_patience: int = 10,
     max_epochs: int = 100,
     checkpoint_dir: str = "./checkpoints",
-    log_interval: int = 10
+    log_interval: int = 10,
+    use_x_mark: bool = True,
+    dataset_mode: str = "npz",
+    dataflow_args=None
 ) -> Tuple[nn.Module, Dict[str, float]]:
     """
     Train a time series classification model

     Args:
         model_constructor (Callable): Function that constructs and returns the model
-        data_path (str): Path to the NPZ file containing the processed data
+        data_path (str): Path to the NPZ file containing the processed data (for npz mode)
         project_name (str): Name of the project for swanlab tracking
         config (Dict[str, Any]): Configuration dictionary for the experiment
         device (Optional[str]): Device to use for training ('cpu' or 'cuda')
@@ -368,6 +534,9 @@ def train_classification_model(
         max_epochs (int): Maximum number of epochs to train for
         checkpoint_dir (str): Directory to save model checkpoints
         log_interval (int): How often to log metrics during training
+        use_x_mark (bool): Whether to use time features (x_mark) from the data file
+        dataset_mode (str): Dataset construction mode - "npz" or "dataflow"
+        dataflow_args: Arguments object for dataflow mode (required if dataset_mode="dataflow")

     Returns:
         Tuple[nn.Module, Dict[str, float]]: Trained model and dictionary of evaluation metrics
@@ -386,11 +555,20 @@ def train_classification_model(
     os.makedirs(checkpoint_dir, exist_ok=True)
     checkpoint_path = os.path.join(checkpoint_dir, f"{project_name}.pt")

-    # Create data loaders
-    dataloaders = create_data_loaders(
-        data_path=data_path,
-        batch_size=config.get('batch_size', 32)
-    )
+    # Create data loaders based on dataset_mode
+    if dataset_mode == "dataflow":
+        if dataflow_args is None:
+            raise ValueError("dataflow_args is required when dataset_mode='dataflow'")
+        dataloaders = create_data_loaders_from_dataflow(
+            dataflow_args,
+            use_x_mark=use_x_mark
+        )
+    else:  # Default to "npz" mode
+        dataloaders = create_data_loaders(
+            data_path=data_path,
+            batch_size=config.get('batch_size', 32),
+            use_x_mark=use_x_mark
+        )

     # Construct the model
     model = model_constructor()
@@ -432,11 +610,11 @@ def train_classification_model(
             # Handle both cases: with and without time features
             if len(batch_data) == 4:  # With time features
                 inputs, targets, x_mark, y_mark = batch_data
-                inputs, targets = inputs.to(device), targets.to(device)
-                x_mark, y_mark = x_mark.to(device), y_mark.to(device)
+                inputs, targets = inputs.float().to(device), targets.float().to(device)
+                x_mark, y_mark = x_mark.float().to(device), y_mark.float().to(device)
             else:  # Without time features
                 inputs, targets = batch_data
-                inputs, targets = inputs.to(device), targets.to(device)
+                inputs, targets = inputs.float().to(device), targets.float().to(device)
                 x_mark, y_mark = None, None

             # Convert targets to long for classification
@@ -484,11 +662,11 @@ def train_classification_model(
             # Handle both cases: with and without time features
             if len(batch_data) == 4:  # With time features
                 inputs, targets, x_mark, y_mark = batch_data
-                inputs, targets = inputs.to(device), targets.to(device)
-                x_mark, y_mark = x_mark.to(device), y_mark.to(device)
+                inputs, targets = inputs.float().to(device), targets.float().to(device)
+                x_mark, y_mark = x_mark.float().to(device), y_mark.float().to(device)
             else:  # Without time features
                 inputs, targets = batch_data
-                inputs, targets = inputs.to(device), targets.to(device)
+                inputs, targets = inputs.float().to(device), targets.float().to(device)
                 x_mark, y_mark = None, None

             targets = targets.long()
@@ -561,11 +739,11 @@ def train_classification_model(
             # Handle both cases: with and without time features
             if len(batch_data) == 4:  # With time features
                 inputs, targets, x_mark, y_mark = batch_data
-                inputs, targets = inputs.to(device), targets.to(device)
-                x_mark, y_mark = x_mark.to(device), y_mark.to(device)
+                inputs, targets = inputs.float().to(device), targets.float().to(device)
+                x_mark, y_mark = x_mark.float().to(device), y_mark.float().to(device)
             else:  # Without time features
                 inputs, targets = batch_data
-                inputs, targets = inputs.to(device), targets.to(device)
+                inputs, targets = inputs.float().to(device), targets.float().to(device)
                 x_mark, y_mark = None, None

             targets = targets.long()


@@ -145,4 +145,4 @@ def time_features_from_frequency_str(freq_str: str) -> List[TimeFeature]:

 def time_features(dates, freq='h'):
     return np.vstack([feat(dates) for feat in time_features_from_frequency_str(freq)])
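Note: for context, a sketch of how `time_features` is commonly called to build the `*_mark` arrays consumed above. The transpose to `(seq_len, n_features)` reflects typical usage in such pipelines and is an assumption, not something shown in this file.

```python
import pandas as pd

# Hypothetical usage: hourly timestamps -> stacked time features.
dates = pd.date_range('2025-01-01', periods=96, freq='h')
feats = time_features(dates, freq='h')  # shape: (n_features, 96)
x_mark = feats.transpose(1, 0)          # (seq_len, n_features), the usual x_mark layout
```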