feat(data): add flexible data loading with dataflow integration and test set
This commit is contained in:
7
.gitignore
vendored
7
.gitignore
vendored
@ -1 +1,8 @@
|
||||
.aider*
|
||||
__pycache__/
|
||||
*.pyc
|
||||
*.pyo
|
||||
*.pyd
|
||||
*.npz
|
||||
*.gz
|
||||
|
||||
|
244
train/train.py
244
train/train.py
@ -7,6 +7,7 @@ import torch.optim as optim
|
||||
from torch.utils.data import DataLoader, TensorDataset
|
||||
import swanlab
|
||||
from typing import Dict, Any, Optional, Callable, Union, Tuple
|
||||
from dataflow import data_provider
|
||||
|
||||
class EarlyStopping:
|
||||
"""Early stopping to stop training when validation performance doesn't improve."""
|
||||
@ -51,16 +52,103 @@ class EarlyStopping:
|
||||
torch.save(model.state_dict(), self.path)
|
||||
self.val_loss_min = val_loss
|
||||
|
||||
def create_data_loaders(data_path: str, batch_size: int = 32) -> Dict[str, DataLoader]:
|
||||
class DatasetWrapperWithoutTimeFeatures(torch.utils.data.Dataset):
|
||||
"""Wrapper to remove time features from dataflow datasets when use_x_mark=False"""
|
||||
def __init__(self, original_dataset):
|
||||
self.original_dataset = original_dataset
|
||||
|
||||
def __getitem__(self, index):
|
||||
# Get original data (seq_x, seq_y, seq_x_mark, seq_y_mark)
|
||||
seq_x, seq_y, seq_x_mark, seq_y_mark = self.original_dataset[index]
|
||||
# Return only seq_x and seq_y (remove time features)
|
||||
return seq_x, seq_y
|
||||
|
||||
def __len__(self):
|
||||
return len(self.original_dataset)
|
||||
|
||||
def inverse_transform(self, data):
|
||||
if hasattr(self.original_dataset, 'inverse_transform'):
|
||||
return self.original_dataset.inverse_transform(data)
|
||||
return data
|
||||
|
||||
def create_data_loaders_from_dataflow(args, use_x_mark: bool = True) -> Dict[str, DataLoader]:
|
||||
"""
|
||||
Create PyTorch DataLoaders using dataflow data_provider
|
||||
|
||||
Args:
|
||||
args: Arguments object containing dataset configuration
|
||||
Required attributes: data, root_path, data_path, seq_len, label_len, pred_len,
|
||||
features, target, embed, freq, batch_size, num_workers, train_only
|
||||
use_x_mark (bool): Whether to use time features (x_mark and y_mark)
|
||||
|
||||
Returns:
|
||||
Dict[str, DataLoader]: Dictionary with train, val, and test DataLoaders
|
||||
"""
|
||||
# Create datasets and dataloaders for each split
|
||||
train_data, _ = data_provider(args, flag='train')
|
||||
val_data, _ = data_provider(args, flag='val')
|
||||
test_data, _ = data_provider(args, flag='test')
|
||||
|
||||
# Wrap datasets to respect use_x_mark parameter
|
||||
if not use_x_mark:
|
||||
train_data = DatasetWrapperWithoutTimeFeatures(train_data)
|
||||
val_data = DatasetWrapperWithoutTimeFeatures(val_data)
|
||||
test_data = DatasetWrapperWithoutTimeFeatures(test_data)
|
||||
|
||||
# Determine batch size and other parameters based on flag
|
||||
train_shuffle = True
|
||||
val_shuffle = False
|
||||
test_shuffle = False
|
||||
|
||||
train_drop_last = True
|
||||
val_drop_last = True
|
||||
test_drop_last = True
|
||||
|
||||
batch_size = args.batch_size
|
||||
num_workers = args.num_workers
|
||||
|
||||
# Create new dataloaders with wrapped datasets
|
||||
train_loader = DataLoader(
|
||||
train_data,
|
||||
batch_size=batch_size,
|
||||
shuffle=train_shuffle,
|
||||
num_workers=num_workers,
|
||||
drop_last=train_drop_last
|
||||
)
|
||||
|
||||
val_loader = DataLoader(
|
||||
val_data,
|
||||
batch_size=batch_size,
|
||||
shuffle=val_shuffle,
|
||||
num_workers=num_workers,
|
||||
drop_last=val_drop_last
|
||||
)
|
||||
|
||||
test_loader = DataLoader(
|
||||
test_data,
|
||||
batch_size=batch_size,
|
||||
shuffle=test_shuffle,
|
||||
num_workers=num_workers,
|
||||
drop_last=test_drop_last
|
||||
)
|
||||
|
||||
return {
|
||||
'train': train_loader,
|
||||
'val': val_loader,
|
||||
'test': test_loader
|
||||
}
|
||||
|
||||
def create_data_loaders(data_path: str, batch_size: int = 32, use_x_mark: bool = True) -> Dict[str, DataLoader]:
|
||||
"""
|
||||
Create PyTorch DataLoaders from an NPZ file
|
||||
|
||||
Args:
|
||||
data_path (str): Path to the NPZ file containing the data
|
||||
batch_size (int): Batch size for the DataLoaders
|
||||
use_x_mark (bool): Whether to use time features (x_mark) from the data file
|
||||
|
||||
Returns:
|
||||
Dict[str, DataLoader]: Dictionary with train and val DataLoaders
|
||||
Dict[str, DataLoader]: Dictionary with train, val, and test DataLoaders
|
||||
"""
|
||||
# Load data from NPZ file
|
||||
data = np.load(data_path, allow_pickle=True)
|
||||
@ -68,18 +156,32 @@ def create_data_loaders(data_path: str, batch_size: int = 32) -> Dict[str, DataL
|
||||
train_y = data['train_y']
|
||||
val_x = data['val_x']
|
||||
val_y = data['val_y']
|
||||
test_x = data['test_x']
|
||||
test_y = data['test_y']
|
||||
|
||||
# Load time features if available
|
||||
# Load time features if available and needed
|
||||
if use_x_mark:
|
||||
train_x_mark = data.get('train_x_mark', None)
|
||||
train_y_mark = data.get('train_y_mark', None)
|
||||
val_x_mark = data.get('val_x_mark', None)
|
||||
val_y_mark = data.get('val_y_mark', None)
|
||||
test_x_mark = data.get('test_x_mark', None)
|
||||
test_y_mark = data.get('test_y_mark', None)
|
||||
else:
|
||||
train_x_mark = None
|
||||
train_y_mark = None
|
||||
val_x_mark = None
|
||||
val_y_mark = None
|
||||
test_x_mark = None
|
||||
test_y_mark = None
|
||||
|
||||
# Convert to PyTorch tensors
|
||||
train_x = torch.FloatTensor(train_x)
|
||||
train_y = torch.FloatTensor(train_y)
|
||||
val_x = torch.FloatTensor(val_x)
|
||||
val_y = torch.FloatTensor(val_y)
|
||||
test_x = torch.FloatTensor(test_x)
|
||||
test_y = torch.FloatTensor(test_y)
|
||||
|
||||
# Create datasets based on whether time features are available
|
||||
if train_x_mark is not None:
|
||||
@ -87,20 +189,26 @@ def create_data_loaders(data_path: str, batch_size: int = 32) -> Dict[str, DataL
|
||||
train_y_mark = torch.FloatTensor(train_y_mark)
|
||||
val_x_mark = torch.FloatTensor(val_x_mark)
|
||||
val_y_mark = torch.FloatTensor(val_y_mark)
|
||||
test_x_mark = torch.FloatTensor(test_x_mark)
|
||||
test_y_mark = torch.FloatTensor(test_y_mark)
|
||||
|
||||
train_dataset = TensorDataset(train_x, train_y, train_x_mark, train_y_mark)
|
||||
val_dataset = TensorDataset(val_x, val_y, val_x_mark, val_y_mark)
|
||||
test_dataset = TensorDataset(test_x, test_y, test_x_mark, test_y_mark)
|
||||
else:
|
||||
train_dataset = TensorDataset(train_x, train_y)
|
||||
val_dataset = TensorDataset(val_x, val_y)
|
||||
test_dataset = TensorDataset(test_x, test_y)
|
||||
|
||||
# Create dataloaders
|
||||
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
|
||||
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
|
||||
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
|
||||
|
||||
return {
|
||||
'train': train_loader,
|
||||
'val': val_loader
|
||||
'val': val_loader,
|
||||
'test': test_loader
|
||||
}
|
||||
|
||||
def train_forecasting_model(
|
||||
@ -112,14 +220,17 @@ def train_forecasting_model(
|
||||
early_stopping_patience: int = 10,
|
||||
max_epochs: int = 100,
|
||||
checkpoint_dir: str = "./checkpoints",
|
||||
log_interval: int = 10
|
||||
log_interval: int = 10,
|
||||
use_x_mark: bool = True,
|
||||
dataset_mode: str = "npz",
|
||||
dataflow_args = None
|
||||
) -> Tuple[nn.Module, Dict[str, float]]:
|
||||
"""
|
||||
Train a time series forecasting model
|
||||
|
||||
Args:
|
||||
model_constructor (Callable): Function that constructs and returns the model
|
||||
data_path (str): Path to the NPZ file containing the processed data
|
||||
data_path (str): Path to the NPZ file containing the processed data (for npz mode)
|
||||
project_name (str): Name of the project for swanlab tracking
|
||||
config (Dict[str, Any]): Configuration dictionary for the experiment
|
||||
device (Optional[str]): Device to use for training ('cpu' or 'cuda')
|
||||
@ -127,6 +238,9 @@ def train_forecasting_model(
|
||||
max_epochs (int): Maximum number of epochs to train for
|
||||
checkpoint_dir (str): Directory to save model checkpoints
|
||||
log_interval (int): How often to log metrics during training
|
||||
use_x_mark (bool): Whether to use time features (x_mark) from the data file
|
||||
dataset_mode (str): Dataset construction mode - "npz" or "dataflow"
|
||||
dataflow_args: Arguments object for dataflow mode (required if dataset_mode="dataflow")
|
||||
|
||||
Returns:
|
||||
Tuple[nn.Module, Dict[str, float]]: Trained model and dictionary of evaluation metrics
|
||||
@ -145,10 +259,19 @@ def train_forecasting_model(
|
||||
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||
checkpoint_path = os.path.join(checkpoint_dir, f"{project_name}.pt")
|
||||
|
||||
# Create data loaders
|
||||
# Create data loaders based on dataset_mode
|
||||
if dataset_mode == "dataflow":
|
||||
if dataflow_args is None:
|
||||
raise ValueError("dataflow_args is required when dataset_mode='dataflow'")
|
||||
dataloaders = create_data_loaders_from_dataflow(
|
||||
dataflow_args,
|
||||
use_x_mark=use_x_mark
|
||||
)
|
||||
else: # Default to "npz" mode
|
||||
dataloaders = create_data_loaders(
|
||||
data_path=data_path,
|
||||
batch_size=config.get('batch_size', 32)
|
||||
batch_size=config.get('batch_size', 32),
|
||||
use_x_mark=use_x_mark
|
||||
)
|
||||
|
||||
# Construct the model
|
||||
@ -206,7 +329,6 @@ def train_forecasting_model(
|
||||
if x_mark is not None:
|
||||
# For TimesNet model with time features
|
||||
# Create decoder input (zeros for forecasting)
|
||||
dec_inp = torch.zeros_like(targets).to(device)
|
||||
outputs = model(inputs, x_mark)
|
||||
else:
|
||||
# For simple models without time features
|
||||
@ -244,17 +366,16 @@ def train_forecasting_model(
|
||||
# Handle both cases: with and without time features
|
||||
if len(batch_data) == 4: # With time features
|
||||
inputs, targets, x_mark, y_mark = batch_data
|
||||
inputs, targets = inputs.to(device), targets.to(device)
|
||||
x_mark, y_mark = x_mark.to(device), y_mark.to(device)
|
||||
inputs, targets = inputs.float().to(device), targets.float().to(device)
|
||||
x_mark, y_mark = x_mark.float().to(device), y_mark.float().to(device)
|
||||
else: # Without time features
|
||||
inputs, targets = batch_data
|
||||
inputs, targets = inputs.to(device), targets.to(device)
|
||||
inputs, targets = inputs.float().to(device), targets.float().to(device)
|
||||
x_mark, y_mark = None, None
|
||||
|
||||
# Forward pass - handle both cases
|
||||
if x_mark is not None:
|
||||
# For TimesNet model with time features
|
||||
dec_inp = torch.zeros_like(targets).to(device)
|
||||
outputs = model(inputs, x_mark)
|
||||
else:
|
||||
# For simple models without time features
|
||||
@ -301,7 +422,42 @@ def train_forecasting_model(
|
||||
# Load the best model
|
||||
model.load_state_dict(torch.load(checkpoint_path))
|
||||
|
||||
# Final validation
|
||||
# Test evaluation on the best model
|
||||
model.eval()
|
||||
test_loss = 0.0
|
||||
test_mse = 0.0
|
||||
|
||||
print("Evaluating on test set...")
|
||||
with torch.no_grad():
|
||||
for batch_data in dataloaders['test']:
|
||||
# Handle both cases: with and without time features
|
||||
if len(batch_data) == 4: # With time features
|
||||
inputs, targets, x_mark, y_mark = batch_data
|
||||
inputs, targets = inputs.float().to(device), targets.float().to(device)
|
||||
x_mark, y_mark = x_mark.float().to(device), y_mark.float().to(device)
|
||||
else: # Without time features
|
||||
inputs, targets = batch_data
|
||||
inputs, targets = inputs.float().to(device), targets.float().to(device)
|
||||
x_mark, y_mark = None, None
|
||||
|
||||
# Forward pass - handle both cases
|
||||
if x_mark is not None:
|
||||
# For TimesNet model with time features
|
||||
outputs = model(inputs, x_mark)
|
||||
else:
|
||||
# For simple models without time features
|
||||
outputs = model(inputs)
|
||||
|
||||
# Calculate loss
|
||||
loss = criterion(outputs, targets)
|
||||
test_loss += loss.item()
|
||||
|
||||
test_loss /= len(dataloaders['test'])
|
||||
|
||||
print(f"Test evaluation completed!")
|
||||
print(f"Test Loss (MSE): {test_loss:.6f}")
|
||||
|
||||
# Final validation for consistency
|
||||
model.eval()
|
||||
final_val_loss = 0.0
|
||||
final_val_mse = 0.0
|
||||
@ -311,18 +467,17 @@ def train_forecasting_model(
|
||||
# Handle both cases: with and without time features
|
||||
if len(batch_data) == 4: # With time features
|
||||
inputs, targets, x_mark, y_mark = batch_data
|
||||
inputs, targets = inputs.to(device), targets.to(device)
|
||||
x_mark, y_mark = x_mark.to(device), y_mark.to(device)
|
||||
inputs, targets = inputs.float().to(device), targets.float().to(device)
|
||||
x_mark, y_mark = x_mark.float().to(device), y_mark.float().to(device)
|
||||
else: # Without time features
|
||||
inputs, targets = batch_data
|
||||
inputs, targets = inputs.to(device), targets.to(device)
|
||||
inputs, targets = inputs.float().to(device), targets.float().to(device)
|
||||
x_mark, y_mark = None, None
|
||||
|
||||
# Forward pass - handle both cases
|
||||
if x_mark is not None:
|
||||
# For TimesNet model with time features
|
||||
dec_inp = torch.zeros_like(targets).to(device)
|
||||
outputs = model(inputs, x_mark, dec_inp, y_mark)
|
||||
outputs = model(inputs, x_mark)
|
||||
else:
|
||||
# For simple models without time features
|
||||
outputs = model(inputs)
|
||||
@ -334,10 +489,18 @@ def train_forecasting_model(
|
||||
|
||||
final_val_loss /= len(dataloaders['val'])
|
||||
|
||||
print(f"Final validation loss: {final_val_loss:.4f}")
|
||||
print(f"Final validation loss: {final_val_loss:.6f}")
|
||||
|
||||
# Log final test results to swanlab
|
||||
final_metrics = {
|
||||
"final_test_loss": test_loss,
|
||||
"final_val_loss": final_val_loss
|
||||
}
|
||||
swanlab_run.log(final_metrics)
|
||||
|
||||
# Update metrics with final values
|
||||
metrics["final_val_loss"] = final_val_loss
|
||||
metrics["final_test_loss"] = test_loss
|
||||
|
||||
# Finish the swanlab run
|
||||
swanlab_run.finish()
|
||||
@ -353,14 +516,17 @@ def train_classification_model(
|
||||
early_stopping_patience: int = 10,
|
||||
max_epochs: int = 100,
|
||||
checkpoint_dir: str = "./checkpoints",
|
||||
log_interval: int = 10
|
||||
log_interval: int = 10,
|
||||
use_x_mark: bool = True,
|
||||
dataset_mode: str = "npz",
|
||||
dataflow_args = None
|
||||
) -> Tuple[nn.Module, Dict[str, float]]:
|
||||
"""
|
||||
Train a time series classification model
|
||||
|
||||
Args:
|
||||
model_constructor (Callable): Function that constructs and returns the model
|
||||
data_path (str): Path to the NPZ file containing the processed data
|
||||
data_path (str): Path to the NPZ file containing the processed data (for npz mode)
|
||||
project_name (str): Name of the project for swanlab tracking
|
||||
config (Dict[str, Any]): Configuration dictionary for the experiment
|
||||
device (Optional[str]): Device to use for training ('cpu' or 'cuda')
|
||||
@ -368,6 +534,9 @@ def train_classification_model(
|
||||
max_epochs (int): Maximum number of epochs to train for
|
||||
checkpoint_dir (str): Directory to save model checkpoints
|
||||
log_interval (int): How often to log metrics during training
|
||||
use_x_mark (bool): Whether to use time features (x_mark) from the data file
|
||||
dataset_mode (str): Dataset construction mode - "npz" or "dataflow"
|
||||
dataflow_args: Arguments object for dataflow mode (required if dataset_mode="dataflow")
|
||||
|
||||
Returns:
|
||||
Tuple[nn.Module, Dict[str, float]]: Trained model and dictionary of evaluation metrics
|
||||
@ -386,10 +555,19 @@ def train_classification_model(
|
||||
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||
checkpoint_path = os.path.join(checkpoint_dir, f"{project_name}.pt")
|
||||
|
||||
# Create data loaders
|
||||
# Create data loaders based on dataset_mode
|
||||
if dataset_mode == "dataflow":
|
||||
if dataflow_args is None:
|
||||
raise ValueError("dataflow_args is required when dataset_mode='dataflow'")
|
||||
dataloaders = create_data_loaders_from_dataflow(
|
||||
dataflow_args,
|
||||
use_x_mark=use_x_mark
|
||||
)
|
||||
else: # Default to "npz" mode
|
||||
dataloaders = create_data_loaders(
|
||||
data_path=data_path,
|
||||
batch_size=config.get('batch_size', 32)
|
||||
batch_size=config.get('batch_size', 32),
|
||||
use_x_mark=use_x_mark
|
||||
)
|
||||
|
||||
# Construct the model
|
||||
@ -432,11 +610,11 @@ def train_classification_model(
|
||||
# Handle both cases: with and without time features
|
||||
if len(batch_data) == 4: # With time features
|
||||
inputs, targets, x_mark, y_mark = batch_data
|
||||
inputs, targets = inputs.to(device), targets.to(device)
|
||||
x_mark, y_mark = x_mark.to(device), y_mark.to(device)
|
||||
inputs, targets = inputs.float().to(device), targets.float().to(device)
|
||||
x_mark, y_mark = x_mark.float().to(device), y_mark.float().to(device)
|
||||
else: # Without time features
|
||||
inputs, targets = batch_data
|
||||
inputs, targets = inputs.to(device), targets.to(device)
|
||||
inputs, targets = inputs.float().to(device), targets.float().to(device)
|
||||
x_mark, y_mark = None, None
|
||||
|
||||
# Convert targets to long for classification
|
||||
@ -484,11 +662,11 @@ def train_classification_model(
|
||||
# Handle both cases: with and without time features
|
||||
if len(batch_data) == 4: # With time features
|
||||
inputs, targets, x_mark, y_mark = batch_data
|
||||
inputs, targets = inputs.to(device), targets.to(device)
|
||||
x_mark, y_mark = x_mark.to(device), y_mark.to(device)
|
||||
inputs, targets = inputs.float().to(device), targets.float().to(device)
|
||||
x_mark, y_mark = x_mark.float().to(device), y_mark.float().to(device)
|
||||
else: # Without time features
|
||||
inputs, targets = batch_data
|
||||
inputs, targets = inputs.to(device), targets.to(device)
|
||||
inputs, targets = inputs.float().to(device), targets.float().to(device)
|
||||
x_mark, y_mark = None, None
|
||||
|
||||
targets = targets.long()
|
||||
@ -561,11 +739,11 @@ def train_classification_model(
|
||||
# Handle both cases: with and without time features
|
||||
if len(batch_data) == 4: # With time features
|
||||
inputs, targets, x_mark, y_mark = batch_data
|
||||
inputs, targets = inputs.to(device), targets.to(device)
|
||||
x_mark, y_mark = x_mark.to(device), y_mark.to(device)
|
||||
inputs, targets = inputs.float().to(device), targets.float().to(device)
|
||||
x_mark, y_mark = x_mark.float().to(device), y_mark.float().to(device)
|
||||
else: # Without time features
|
||||
inputs, targets = batch_data
|
||||
inputs, targets = inputs.to(device), targets.to(device)
|
||||
inputs, targets = inputs.float().to(device), targets.float().to(device)
|
||||
x_mark, y_mark = None, None
|
||||
|
||||
targets = targets.long()
|
||||
|
Reference in New Issue
Block a user