feat(data): add flexible data loading with dataflow integration and test set
 .gitignore     |   7
 train/train.py | 244

--- a/.gitignore
+++ b/.gitignore
@@ -1 +1,8 @@
 .aider*
+__pycache__/
+*.pyc
+*.pyo
+*.pyd
+*.npz
+*.gz
+
--- a/train/train.py
+++ b/train/train.py
@@ -7,6 +7,7 @@ import torch.optim as optim
 from torch.utils.data import DataLoader, TensorDataset
 import swanlab
 from typing import Dict, Any, Optional, Callable, Union, Tuple
+from dataflow import data_provider
 
 class EarlyStopping:
     """Early stopping to stop training when validation performance doesn't improve."""
@@ -51,16 +52,103 @@ class EarlyStopping:
         torch.save(model.state_dict(), self.path)
         self.val_loss_min = val_loss
 
-def create_data_loaders(data_path: str, batch_size: int = 32) -> Dict[str, DataLoader]:
+class DatasetWrapperWithoutTimeFeatures(torch.utils.data.Dataset):
+    """Wrapper to remove time features from dataflow datasets when use_x_mark=False"""
+    def __init__(self, original_dataset):
+        self.original_dataset = original_dataset
+
+    def __getitem__(self, index):
+        # Get original data (seq_x, seq_y, seq_x_mark, seq_y_mark)
+        seq_x, seq_y, seq_x_mark, seq_y_mark = self.original_dataset[index]
+        # Return only seq_x and seq_y (remove time features)
+        return seq_x, seq_y
+
+    def __len__(self):
+        return len(self.original_dataset)
+
+    def inverse_transform(self, data):
+        if hasattr(self.original_dataset, 'inverse_transform'):
+            return self.original_dataset.inverse_transform(data)
+        return data
+
+
+def create_data_loaders_from_dataflow(args, use_x_mark: bool = True) -> Dict[str, DataLoader]:
+    """
+    Create PyTorch DataLoaders using dataflow data_provider
+
+    Args:
+        args: Arguments object containing dataset configuration.
+            Required attributes: data, root_path, data_path, seq_len, label_len, pred_len,
+            features, target, embed, freq, batch_size, num_workers, train_only
+        use_x_mark (bool): Whether to use time features (x_mark and y_mark)
+
+    Returns:
+        Dict[str, DataLoader]: Dictionary with train, val, and test DataLoaders
+    """
+    # Create the dataset for each split (data_provider also returns a loader,
+    # which is discarded and rebuilt below so the wrapper can be applied)
+    train_data, _ = data_provider(args, flag='train')
+    val_data, _ = data_provider(args, flag='val')
+    test_data, _ = data_provider(args, flag='test')
+
+    # Wrap datasets to respect the use_x_mark parameter
+    if not use_x_mark:
+        train_data = DatasetWrapperWithoutTimeFeatures(train_data)
+        val_data = DatasetWrapperWithoutTimeFeatures(val_data)
+        test_data = DatasetWrapperWithoutTimeFeatures(test_data)
+
+    # Per-split shuffle and drop_last settings
+    train_shuffle = True
+    val_shuffle = False
+    test_shuffle = False
+
+    train_drop_last = True
+    val_drop_last = True
+    test_drop_last = True
+
+    batch_size = args.batch_size
+    num_workers = args.num_workers
+
+    # Create new dataloaders with the (possibly wrapped) datasets
+    train_loader = DataLoader(
+        train_data,
+        batch_size=batch_size,
+        shuffle=train_shuffle,
+        num_workers=num_workers,
+        drop_last=train_drop_last
+    )
+
+    val_loader = DataLoader(
+        val_data,
+        batch_size=batch_size,
+        shuffle=val_shuffle,
+        num_workers=num_workers,
+        drop_last=val_drop_last
+    )
+
+    test_loader = DataLoader(
+        test_data,
+        batch_size=batch_size,
+        shuffle=test_shuffle,
+        num_workers=num_workers,
+        drop_last=test_drop_last
+    )
+
+    return {
+        'train': train_loader,
+        'val': val_loader,
+        'test': test_loader
+    }
+
+
+def create_data_loaders(data_path: str, batch_size: int = 32, use_x_mark: bool = True) -> Dict[str, DataLoader]:
     """
     Create PyTorch DataLoaders from an NPZ file
 
     Args:
         data_path (str): Path to the NPZ file containing the data
         batch_size (int): Batch size for the DataLoaders
+        use_x_mark (bool): Whether to use time features (x_mark) from the data file
 
     Returns:
-        Dict[str, DataLoader]: Dictionary with train and val DataLoaders
+        Dict[str, DataLoader]: Dictionary with train, val, and test DataLoaders
     """
     # Load data from NPZ file
     data = np.load(data_path, allow_pickle=True)
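
A note on the new dataflow entry point: the hunk above only shows its docstring contract, so a minimal usage sketch follows. It is illustrative and not part of the commit; the Namespace field values are assumptions about what dataflow's data_provider accepts.

# Illustrative only (not part of this commit): driving
# create_data_loaders_from_dataflow with an argparse-style args object.
# Field names follow the docstring above; the values are assumptions.
from argparse import Namespace

dataflow_args = Namespace(
    data='ETTh1',           # dataset key understood by data_provider (assumed)
    root_path='./data',     # directory containing the raw file
    data_path='ETTh1.csv',  # file name inside root_path
    seq_len=96, label_len=48, pred_len=24,
    features='M', target='OT',
    embed='timeF', freq='h',
    batch_size=32, num_workers=0, train_only=False,
)

loaders = create_data_loaders_from_dataflow(dataflow_args, use_x_mark=False)
# With use_x_mark=False, DatasetWrapperWithoutTimeFeatures strips the mark
# tensors, so every batch is a plain (seq_x, seq_y) pair.
seq_x, seq_y = next(iter(loaders['train']))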
@@ -68,18 +156,32 @@ def create_data_loaders(data_path: str, batch_size: int = 32) -> Dict[str, DataLoader]:
     train_y = data['train_y']
     val_x = data['val_x']
     val_y = data['val_y']
+    test_x = data['test_x']
+    test_y = data['test_y']
 
-    # Load time features if available
-    train_x_mark = data.get('train_x_mark', None)
-    train_y_mark = data.get('train_y_mark', None)
-    val_x_mark = data.get('val_x_mark', None)
-    val_y_mark = data.get('val_y_mark', None)
+    # Load time features if available and needed
+    if use_x_mark:
+        train_x_mark = data.get('train_x_mark', None)
+        train_y_mark = data.get('train_y_mark', None)
+        val_x_mark = data.get('val_x_mark', None)
+        val_y_mark = data.get('val_y_mark', None)
+        test_x_mark = data.get('test_x_mark', None)
+        test_y_mark = data.get('test_y_mark', None)
+    else:
+        train_x_mark = None
+        train_y_mark = None
+        val_x_mark = None
+        val_y_mark = None
+        test_x_mark = None
+        test_y_mark = None
 
     # Convert to PyTorch tensors
     train_x = torch.FloatTensor(train_x)
     train_y = torch.FloatTensor(train_y)
     val_x = torch.FloatTensor(val_x)
     val_y = torch.FloatTensor(val_y)
+    test_x = torch.FloatTensor(test_x)
+    test_y = torch.FloatTensor(test_y)
 
     # Create datasets based on whether time features are available
     if train_x_mark is not None:
@@ -87,20 +189,26 @@ def create_data_loaders(data_path: str, batch_size: int = 32) -> Dict[str, DataLoader]:
         train_y_mark = torch.FloatTensor(train_y_mark)
         val_x_mark = torch.FloatTensor(val_x_mark)
         val_y_mark = torch.FloatTensor(val_y_mark)
+        test_x_mark = torch.FloatTensor(test_x_mark)
+        test_y_mark = torch.FloatTensor(test_y_mark)
 
         train_dataset = TensorDataset(train_x, train_y, train_x_mark, train_y_mark)
         val_dataset = TensorDataset(val_x, val_y, val_x_mark, val_y_mark)
+        test_dataset = TensorDataset(test_x, test_y, test_x_mark, test_y_mark)
     else:
         train_dataset = TensorDataset(train_x, train_y)
         val_dataset = TensorDataset(val_x, val_y)
+        test_dataset = TensorDataset(test_x, test_y)
 
     # Create dataloaders
     train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
     val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
+    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
 
     return {
         'train': train_loader,
-        'val': val_loader
+        'val': val_loader,
+        'test': test_loader
     }
 
 def train_forecasting_model(
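
For the npz path, the two hunks above imply a fixed key schema that now includes a test split. As a sketch under stated assumptions (array shapes are placeholders; the *_mark arrays are optional and only consulted when use_x_mark=True), a file the loader accepts could be produced like this:

# Hypothetical writer for the schema create_data_loaders expects; shapes are
# placeholders: (windows, seq_len, features) for *_x and
# (windows, pred_len, features) for *_y. Optional keys: *_x_mark, *_y_mark.
import numpy as np

rng = np.random.default_rng(0)
splits = {'train': 100, 'val': 20, 'test': 20}
arrays = {}
for name, n in splits.items():
    arrays[f'{name}_x'] = rng.normal(size=(n, 96, 7))
    arrays[f'{name}_y'] = rng.normal(size=(n, 24, 7))
np.savez('dataset.npz', **arrays)

loaders = create_data_loaders('dataset.npz', batch_size=32, use_x_mark=False)
print({split: len(loader) for split, loader in loaders.items()})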
@@ -112,14 +220,17 @@ def train_forecasting_model(
     early_stopping_patience: int = 10,
     max_epochs: int = 100,
     checkpoint_dir: str = "./checkpoints",
-    log_interval: int = 10
+    log_interval: int = 10,
+    use_x_mark: bool = True,
+    dataset_mode: str = "npz",
+    dataflow_args = None
 ) -> Tuple[nn.Module, Dict[str, float]]:
     """
     Train a time series forecasting model
 
     Args:
         model_constructor (Callable): Function that constructs and returns the model
-        data_path (str): Path to the NPZ file containing the processed data
+        data_path (str): Path to the NPZ file containing the processed data (for npz mode)
         project_name (str): Name of the project for swanlab tracking
         config (Dict[str, Any]): Configuration dictionary for the experiment
         device (Optional[str]): Device to use for training ('cpu' or 'cuda')
@@ -127,6 +238,9 @@ def train_forecasting_model(
         max_epochs (int): Maximum number of epochs to train for
         checkpoint_dir (str): Directory to save model checkpoints
         log_interval (int): How often to log metrics during training
+        use_x_mark (bool): Whether to use time features (x_mark) from the data file
+        dataset_mode (str): Dataset construction mode - "npz" or "dataflow"
+        dataflow_args: Arguments object for dataflow mode (required if dataset_mode="dataflow")
 
     Returns:
         Tuple[nn.Module, Dict[str, float]]: Trained model and dictionary of evaluation metrics
@@ -145,10 +259,19 @@ def train_forecasting_model(
     os.makedirs(checkpoint_dir, exist_ok=True)
     checkpoint_path = os.path.join(checkpoint_dir, f"{project_name}.pt")
 
-    # Create data loaders
-    dataloaders = create_data_loaders(
-        data_path=data_path,
-        batch_size=config.get('batch_size', 32)
-    )
+    # Create data loaders based on dataset_mode
+    if dataset_mode == "dataflow":
+        if dataflow_args is None:
+            raise ValueError("dataflow_args is required when dataset_mode='dataflow'")
+        dataloaders = create_data_loaders_from_dataflow(
+            dataflow_args,
+            use_x_mark=use_x_mark
+        )
+    else:  # Default to "npz" mode
+        dataloaders = create_data_loaders(
+            data_path=data_path,
+            batch_size=config.get('batch_size', 32),
+            use_x_mark=use_x_mark
+        )
 
     # Construct the model
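
The dispatch above keeps the training API unchanged for npz callers while letting dataflow callers opt in at the call site. A hedged call sketch, where build_model and the config values are assumptions and data_path is ignored on the dataflow branch:

# Hypothetical call (build_model, project name, and config are assumptions).
model, metrics = train_forecasting_model(
    model_constructor=build_model,
    data_path='',                 # unused when dataset_mode='dataflow'
    project_name='etth1_forecast',
    config={'batch_size': 32},
    dataset_mode='dataflow',
    dataflow_args=dataflow_args,  # e.g. the Namespace sketched earlier
    use_x_mark=True,
)
print(metrics['final_val_loss'], metrics['final_test_loss'])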
@@ -206,7 +329,6 @@ def train_forecasting_model(
             if x_mark is not None:
                 # For TimesNet model with time features
                 # Create decoder input (zeros for forecasting)
-                dec_inp = torch.zeros_like(targets).to(device)
                 outputs = model(inputs, x_mark)
             else:
                 # For simple models without time features
@@ -244,17 +366,16 @@ def train_forecasting_model(
             # Handle both cases: with and without time features
             if len(batch_data) == 4:  # With time features
                 inputs, targets, x_mark, y_mark = batch_data
-                inputs, targets = inputs.to(device), targets.to(device)
-                x_mark, y_mark = x_mark.to(device), y_mark.to(device)
+                inputs, targets = inputs.float().to(device), targets.float().to(device)
+                x_mark, y_mark = x_mark.float().to(device), y_mark.float().to(device)
             else:  # Without time features
                 inputs, targets = batch_data
-                inputs, targets = inputs.to(device), targets.to(device)
+                inputs, targets = inputs.float().to(device), targets.float().to(device)
                 x_mark, y_mark = None, None
 
             # Forward pass - handle both cases
             if x_mark is not None:
                 # For TimesNet model with time features
-                dec_inp = torch.zeros_like(targets).to(device)
                 outputs = model(inputs, x_mark)
             else:
                 # For simple models without time features
@@ -301,7 +422,42 @@ def train_forecasting_model(
     # Load the best model
     model.load_state_dict(torch.load(checkpoint_path))
 
-    # Final validation
+    # Test evaluation on the best model
+    model.eval()
+    test_loss = 0.0
+    test_mse = 0.0
+
+    print("Evaluating on test set...")
+    with torch.no_grad():
+        for batch_data in dataloaders['test']:
+            # Handle both cases: with and without time features
+            if len(batch_data) == 4:  # With time features
+                inputs, targets, x_mark, y_mark = batch_data
+                inputs, targets = inputs.float().to(device), targets.float().to(device)
+                x_mark, y_mark = x_mark.float().to(device), y_mark.float().to(device)
+            else:  # Without time features
+                inputs, targets = batch_data
+                inputs, targets = inputs.float().to(device), targets.float().to(device)
+                x_mark, y_mark = None, None
+
+            # Forward pass - handle both cases
+            if x_mark is not None:
+                # For TimesNet model with time features
+                outputs = model(inputs, x_mark)
+            else:
+                # For simple models without time features
+                outputs = model(inputs)
+
+            # Calculate loss
+            loss = criterion(outputs, targets)
+            test_loss += loss.item()
+
+    test_loss /= len(dataloaders['test'])
+
+    print(f"Test evaluation completed!")
+    print(f"Test Loss (MSE): {test_loss:.6f}")
+
+    # Final validation for consistency
     model.eval()
     final_val_loss = 0.0
     final_val_mse = 0.0
@@ -311,18 +467,17 @@ def train_forecasting_model(
             # Handle both cases: with and without time features
             if len(batch_data) == 4:  # With time features
                 inputs, targets, x_mark, y_mark = batch_data
-                inputs, targets = inputs.to(device), targets.to(device)
-                x_mark, y_mark = x_mark.to(device), y_mark.to(device)
+                inputs, targets = inputs.float().to(device), targets.float().to(device)
+                x_mark, y_mark = x_mark.float().to(device), y_mark.float().to(device)
             else:  # Without time features
                 inputs, targets = batch_data
-                inputs, targets = inputs.to(device), targets.to(device)
+                inputs, targets = inputs.float().to(device), targets.float().to(device)
                 x_mark, y_mark = None, None
 
             # Forward pass - handle both cases
             if x_mark is not None:
                 # For TimesNet model with time features
-                dec_inp = torch.zeros_like(targets).to(device)
-                outputs = model(inputs, x_mark, dec_inp, y_mark)
+                outputs = model(inputs, x_mark)
             else:
                 # For simple models without time features
                 outputs = model(inputs)
@@ -334,10 +489,18 @@ def train_forecasting_model(
 
     final_val_loss /= len(dataloaders['val'])
 
-    print(f"Final validation loss: {final_val_loss:.4f}")
+    print(f"Final validation loss: {final_val_loss:.6f}")
+
+    # Log final test results to swanlab
+    final_metrics = {
+        "final_test_loss": test_loss,
+        "final_val_loss": final_val_loss
+    }
+    swanlab_run.log(final_metrics)
 
     # Update metrics with final values
     metrics["final_val_loss"] = final_val_loss
+    metrics["final_test_loss"] = test_loss
 
     # Finish the swanlab run
     swanlab_run.finish()
@@ -353,14 +516,17 @@ def train_classification_model(
     early_stopping_patience: int = 10,
     max_epochs: int = 100,
     checkpoint_dir: str = "./checkpoints",
-    log_interval: int = 10
+    log_interval: int = 10,
+    use_x_mark: bool = True,
+    dataset_mode: str = "npz",
+    dataflow_args = None
 ) -> Tuple[nn.Module, Dict[str, float]]:
     """
     Train a time series classification model
 
     Args:
         model_constructor (Callable): Function that constructs and returns the model
-        data_path (str): Path to the NPZ file containing the processed data
+        data_path (str): Path to the NPZ file containing the processed data (for npz mode)
         project_name (str): Name of the project for swanlab tracking
         config (Dict[str, Any]): Configuration dictionary for the experiment
         device (Optional[str]): Device to use for training ('cpu' or 'cuda')
@@ -368,6 +534,9 @@ def train_classification_model(
         max_epochs (int): Maximum number of epochs to train for
         checkpoint_dir (str): Directory to save model checkpoints
         log_interval (int): How often to log metrics during training
+        use_x_mark (bool): Whether to use time features (x_mark) from the data file
+        dataset_mode (str): Dataset construction mode - "npz" or "dataflow"
+        dataflow_args: Arguments object for dataflow mode (required if dataset_mode="dataflow")
 
     Returns:
         Tuple[nn.Module, Dict[str, float]]: Trained model and dictionary of evaluation metrics
@@ -386,10 +555,19 @@ def train_classification_model(
     os.makedirs(checkpoint_dir, exist_ok=True)
     checkpoint_path = os.path.join(checkpoint_dir, f"{project_name}.pt")
 
-    # Create data loaders
-    dataloaders = create_data_loaders(
-        data_path=data_path,
-        batch_size=config.get('batch_size', 32)
-    )
+    # Create data loaders based on dataset_mode
+    if dataset_mode == "dataflow":
+        if dataflow_args is None:
+            raise ValueError("dataflow_args is required when dataset_mode='dataflow'")
+        dataloaders = create_data_loaders_from_dataflow(
+            dataflow_args,
+            use_x_mark=use_x_mark
+        )
+    else:  # Default to "npz" mode
+        dataloaders = create_data_loaders(
+            data_path=data_path,
+            batch_size=config.get('batch_size', 32),
+            use_x_mark=use_x_mark
+        )
 
     # Construct the model
@@ -432,11 +610,11 @@ def train_classification_model(
             # Handle both cases: with and without time features
             if len(batch_data) == 4:  # With time features
                 inputs, targets, x_mark, y_mark = batch_data
-                inputs, targets = inputs.to(device), targets.to(device)
-                x_mark, y_mark = x_mark.to(device), y_mark.to(device)
+                inputs, targets = inputs.float().to(device), targets.float().to(device)
+                x_mark, y_mark = x_mark.float().to(device), y_mark.float().to(device)
             else:  # Without time features
                 inputs, targets = batch_data
-                inputs, targets = inputs.to(device), targets.to(device)
+                inputs, targets = inputs.float().to(device), targets.float().to(device)
                 x_mark, y_mark = None, None
 
             # Convert targets to long for classification
@@ -484,11 +662,11 @@ def train_classification_model(
             # Handle both cases: with and without time features
             if len(batch_data) == 4:  # With time features
                 inputs, targets, x_mark, y_mark = batch_data
-                inputs, targets = inputs.to(device), targets.to(device)
-                x_mark, y_mark = x_mark.to(device), y_mark.to(device)
+                inputs, targets = inputs.float().to(device), targets.float().to(device)
+                x_mark, y_mark = x_mark.float().to(device), y_mark.float().to(device)
             else:  # Without time features
                 inputs, targets = batch_data
-                inputs, targets = inputs.to(device), targets.to(device)
+                inputs, targets = inputs.float().to(device), targets.float().to(device)
                 x_mark, y_mark = None, None
 
             targets = targets.long()
@@ -561,11 +739,11 @@ def train_classification_model(
             # Handle both cases: with and without time features
             if len(batch_data) == 4:  # With time features
                 inputs, targets, x_mark, y_mark = batch_data
-                inputs, targets = inputs.to(device), targets.to(device)
-                x_mark, y_mark = x_mark.to(device), y_mark.to(device)
+                inputs, targets = inputs.float().to(device), targets.float().to(device)
+                x_mark, y_mark = x_mark.float().to(device), y_mark.float().to(device)
             else:  # Without time features
                 inputs, targets = batch_data
-                inputs, targets = inputs.to(device), targets.to(device)
+                inputs, targets = inputs.float().to(device), targets.float().to(device)
                 x_mark, y_mark = None, None
 
             targets = targets.long()