first timesnet try
630 train/train.py Normal file
@@ -0,0 +1,630 @@
import os
import time
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import swanlab
from typing import Dict, Any, Optional, Callable, Union, Tuple


class EarlyStopping:
    """Early stopping to stop training when validation performance doesn't improve."""

    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt'):
        """
        Args:
            patience (int): How long to wait after last improvement. Default: 7
            verbose (bool): If True, prints a message for each improvement. Default: False
            delta (float): Minimum change in monitored quantity to qualify as improvement. Default: 0
            path (str): Path for the checkpoint to be saved to. Default: 'checkpoint.pt'
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = float('inf')
        self.delta = delta
        self.path = path

    def __call__(self, val_loss, model):
        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            if self.verbose:
                print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        """Save model when validation loss decreases."""
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

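# Minimal usage sketch of EarlyStopping on its own: the model and the validation
# losses below are made up for illustration only (the function is never called here).
def _early_stopping_example():
    dummy_model = nn.Linear(4, 1)
    stopper = EarlyStopping(patience=2, verbose=True, path='example_checkpoint.pt')
    for val_loss in [0.90, 0.80, 0.85, 0.84, 0.83]:  # hypothetical validation losses
        stopper(val_loss, dummy_model)  # saves a checkpoint on improvement, otherwise counts up
        if stopper.early_stop:
            break
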
def create_data_loaders(data_path: str, batch_size: int = 32) -> Dict[str, DataLoader]:
    """
    Create PyTorch DataLoaders from an NPZ file

    Args:
        data_path (str): Path to the NPZ file containing the data
        batch_size (int): Batch size for the DataLoaders

    Returns:
        Dict[str, DataLoader]: Dictionary with train and val DataLoaders
    """
    # Load data from NPZ file
    data = np.load(data_path, allow_pickle=True)
    train_x = data['train_x']
    train_y = data['train_y']
    val_x = data['val_x']
    val_y = data['val_y']

    # Load time features if available
    train_x_mark = data.get('train_x_mark', None)
    train_y_mark = data.get('train_y_mark', None)
    val_x_mark = data.get('val_x_mark', None)
    val_y_mark = data.get('val_y_mark', None)

    # Convert to PyTorch tensors
    train_x = torch.FloatTensor(train_x)
    train_y = torch.FloatTensor(train_y)
    val_x = torch.FloatTensor(val_x)
    val_y = torch.FloatTensor(val_y)

    # Create datasets based on whether time features are available
    if train_x_mark is not None:
        train_x_mark = torch.FloatTensor(train_x_mark)
        train_y_mark = torch.FloatTensor(train_y_mark)
        val_x_mark = torch.FloatTensor(val_x_mark)
        val_y_mark = torch.FloatTensor(val_y_mark)

        train_dataset = TensorDataset(train_x, train_y, train_x_mark, train_y_mark)
        val_dataset = TensorDataset(val_x, val_y, val_x_mark, val_y_mark)
    else:
        train_dataset = TensorDataset(train_x, train_y)
        val_dataset = TensorDataset(val_x, val_y)

    # Create dataloaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    return {
        'train': train_loader,
        'val': val_loader
    }

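# The loader above expects an NPZ file with keys 'train_x', 'train_y', 'val_x', 'val_y',
# and optionally 'train_x_mark', 'train_y_mark', 'val_x_mark', 'val_y_mark' for time
# features. A rough, hypothetical sketch of writing a compatible file; the shapes are
# placeholders, not values from the original pipeline.
def _write_example_npz(path='example_data.npz'):
    rng = np.random.default_rng(0)
    np.savez(
        path,
        train_x=rng.normal(size=(128, 96, 7)).astype(np.float32),  # (samples, seq_len, features)
        train_y=rng.normal(size=(128, 24, 7)).astype(np.float32),  # (samples, pred_len, features)
        val_x=rng.normal(size=(32, 96, 7)).astype(np.float32),
        val_y=rng.normal(size=(32, 24, 7)).astype(np.float32),
    )
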
def train_forecasting_model(
    model_constructor: Callable,
    data_path: str,
    project_name: str,
    config: Dict[str, Any],
    device: Optional[str] = None,
    early_stopping_patience: int = 10,
    max_epochs: int = 100,
    checkpoint_dir: str = "./checkpoints",
    log_interval: int = 10
) -> Tuple[nn.Module, Dict[str, float]]:
    """
    Train a time series forecasting model

    Args:
        model_constructor (Callable): Function that constructs and returns the model
        data_path (str): Path to the NPZ file containing the processed data
        project_name (str): Name of the project for swanlab tracking
        config (Dict[str, Any]): Configuration dictionary for the experiment
        device (Optional[str]): Device to use for training ('cpu' or 'cuda')
        early_stopping_patience (int): Number of epochs to wait before early stopping
        max_epochs (int): Maximum number of epochs to train for
        checkpoint_dir (str): Directory to save model checkpoints
        log_interval (int): How often to log metrics during training

    Returns:
        Tuple[nn.Module, Dict[str, float]]: Trained model and dictionary of evaluation metrics
    """
    # Setup device
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Initialize swanlab for experiment tracking
    swanlab_run = swanlab.init(
        project=project_name,
        config=config,
    )

    # Create checkpoint directory if it doesn't exist
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_path = os.path.join(checkpoint_dir, f"{project_name}.pt")

    # Create data loaders
    dataloaders = create_data_loaders(
        data_path=data_path,
        batch_size=config.get('batch_size', 32)
    )

    # Construct the model
    model = model_constructor()
    model = model.to(device)

    # Define loss function and optimizer
    criterion = nn.MSELoss()
    optimizer = optim.Adam(
        model.parameters(),
        lr=config.get('learning_rate', 1e-3),
    )

    # Learning rate scheduler: halve the LR every 5 epochs
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

    # Initialize early stopping
    early_stopping = EarlyStopping(
        patience=early_stopping_patience,
        verbose=True,
        path=checkpoint_path
    )

    # Training loop
    best_val_loss = float('inf')
    metrics = {}

    for epoch in range(max_epochs):
        print(f"Epoch {epoch+1}/{max_epochs}")

        # Training phase
        model.train()
        train_loss = 0.0

        # Accumulate the loss over each log_interval window
        interval_loss = 0.0
        start_time = time.time()

        for batch_idx, batch_data in enumerate(dataloaders['train']):
            # Handle both cases: with and without time features
            if len(batch_data) == 4:  # With time features
                inputs, targets, x_mark, y_mark = batch_data
                inputs, targets = inputs.to(device), targets.to(device)
                x_mark, y_mark = x_mark.to(device), y_mark.to(device)
            else:  # Without time features
                inputs, targets = batch_data
                inputs, targets = inputs.to(device), targets.to(device)
                x_mark, y_mark = None, None

            # Zero the parameter gradients
            optimizer.zero_grad()

            # Forward pass - handle both cases
            if x_mark is not None:
                # For TimesNet model with time features
                # Create decoder input (zeros for forecasting) and pass it through,
                # matching the call used in the final validation below
                dec_inp = torch.zeros_like(targets).to(device)
                outputs = model(inputs, x_mark, dec_inp, y_mark)
            else:
                # For simple models without time features
                outputs = model(inputs)

            loss = criterion(outputs, targets)

            # Backward pass and optimize
            loss.backward()
            optimizer.step()

            # Update statistics
            train_loss += loss.item()
            interval_loss += loss.item()

            if (batch_idx + 1) % log_interval == 0:
                print(f"Batch {batch_idx+1}/{len(dataloaders['train'])}, Loss: {loss.item():.4f}")
                # Compute and log the average loss over this interval
                avg_interval_loss = interval_loss / log_interval
                swanlab_run.log({"batch_train_loss": avg_interval_loss})

                # Reset the interval loss for the next window
                interval_loss = 0.0

        avg_train_loss = train_loss / len(dataloaders['train'])
        epoch_time = time.time() - start_time

        # Validation phase
        model.eval()
        val_loss = 0.0

        with torch.no_grad():
            for batch_data in dataloaders['val']:
                # Handle both cases: with and without time features
                if len(batch_data) == 4:  # With time features
                    inputs, targets, x_mark, y_mark = batch_data
                    inputs, targets = inputs.to(device), targets.to(device)
                    x_mark, y_mark = x_mark.to(device), y_mark.to(device)
                else:  # Without time features
                    inputs, targets = batch_data
                    inputs, targets = inputs.to(device), targets.to(device)
                    x_mark, y_mark = None, None

                # Forward pass - handle both cases
                if x_mark is not None:
                    # For TimesNet model with time features
                    dec_inp = torch.zeros_like(targets).to(device)
                    outputs = model(inputs, x_mark, dec_inp, y_mark)
                else:
                    # For simple models without time features
                    outputs = model(inputs)

                # Calculate loss
                loss = criterion(outputs, targets)
                val_loss += loss.item()

        avg_val_loss = val_loss / len(dataloaders['val'])
        current_lr = optimizer.param_groups[0]['lr']

        # Log metrics
        metrics_dict = {
            "train_loss": avg_train_loss,
            "val_loss": avg_val_loss,
            "learning_rate": current_lr,
            "epoch_time": epoch_time
        }

        swanlab_run.log(metrics_dict)

        print(f"Epoch {epoch+1}/{max_epochs}, "
              f"Train Loss: {avg_train_loss:.4f}, "
              f"Val Loss: {avg_val_loss:.4f}, "
              f"LR: {current_lr:.6f}, "
              f"Time: {epoch_time:.2f}s")

        # Track the best epoch's metrics
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            metrics = metrics_dict

        # Early stopping
        early_stopping(avg_val_loss, model)
        if early_stopping.early_stop:
            print("Early stopping triggered")
            break

        # Step the learning rate scheduler
        scheduler.step()

    # Load the best model
    model.load_state_dict(torch.load(checkpoint_path))

    # Final validation
    model.eval()
    final_val_loss = 0.0

    with torch.no_grad():
        for batch_data in dataloaders['val']:
            # Handle both cases: with and without time features
            if len(batch_data) == 4:  # With time features
                inputs, targets, x_mark, y_mark = batch_data
                inputs, targets = inputs.to(device), targets.to(device)
                x_mark, y_mark = x_mark.to(device), y_mark.to(device)
            else:  # Without time features
                inputs, targets = batch_data
                inputs, targets = inputs.to(device), targets.to(device)
                x_mark, y_mark = None, None

            # Forward pass - handle both cases
            if x_mark is not None:
                # For TimesNet model with time features
                dec_inp = torch.zeros_like(targets).to(device)
                outputs = model(inputs, x_mark, dec_inp, y_mark)
            else:
                # For simple models without time features
                outputs = model(inputs)

            # Calculate loss
            loss = criterion(outputs, targets)
            final_val_loss += loss.item()

    final_val_loss /= len(dataloaders['val'])

    print(f"Final validation loss: {final_val_loss:.4f}")

    # Update metrics with final values
    metrics["final_val_loss"] = final_val_loss

    # Finish the swanlab run
    swanlab_run.finish()

    return model, metrics

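# The forecasting loops above call model(inputs, x_mark, dec_inp, y_mark) when time
# features are present, i.e. they assume a TimesNet-style forward signature
# forward(x_enc, x_mark_enc, x_dec, x_mark_dec). A minimal stand-in that satisfies
# that interface, purely for illustration (not the real TimesNet; the default
# sequence lengths are assumptions):
class _DummyForecaster(nn.Module):
    def __init__(self, seq_len=96, pred_len=24):
        super().__init__()
        self.proj = nn.Linear(seq_len, pred_len)  # project along the time axis

    def forward(self, x_enc, x_mark_enc=None, x_dec=None, x_mark_dec=None):
        # x_enc: (batch, seq_len, features); time marks and decoder input are ignored here
        return self.proj(x_enc.transpose(1, 2)).transpose(1, 2)  # (batch, pred_len, features)
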
def train_classification_model(
    model_constructor: Callable,
    data_path: str,
    project_name: str,
    config: Dict[str, Any],
    device: Optional[str] = None,
    early_stopping_patience: int = 10,
    max_epochs: int = 100,
    checkpoint_dir: str = "./checkpoints",
    log_interval: int = 10
) -> Tuple[nn.Module, Dict[str, float]]:
    """
    Train a time series classification model

    Args:
        model_constructor (Callable): Function that constructs and returns the model
        data_path (str): Path to the NPZ file containing the processed data
        project_name (str): Name of the project for swanlab tracking
        config (Dict[str, Any]): Configuration dictionary for the experiment
        device (Optional[str]): Device to use for training ('cpu' or 'cuda')
        early_stopping_patience (int): Number of epochs to wait before early stopping
        max_epochs (int): Maximum number of epochs to train for
        checkpoint_dir (str): Directory to save model checkpoints
        log_interval (int): How often to log metrics during training

    Returns:
        Tuple[nn.Module, Dict[str, float]]: Trained model and dictionary of evaluation metrics
    """
    # Setup device
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Initialize swanlab for experiment tracking
    swanlab_run = swanlab.init(
        project=project_name,
        config=config,
    )

    # Create checkpoint directory if it doesn't exist
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_path = os.path.join(checkpoint_dir, f"{project_name}.pt")

    # Create data loaders
    dataloaders = create_data_loaders(
        data_path=data_path,
        batch_size=config.get('batch_size', 32)
    )

    # Construct the model
    model = model_constructor()
    model = model.to(device)

    # Define loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(
        model.parameters(),
        lr=config.get('learning_rate', 1e-3),
        weight_decay=config.get('weight_decay', 1e-4)
    )

    # Learning rate scheduler: halve the LR after each epoch
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)

    # Initialize early stopping
    early_stopping = EarlyStopping(
        patience=early_stopping_patience,
        verbose=True,
        path=checkpoint_path
    )

    # Training loop
    best_val_loss = float('inf')
    metrics = {}

    for epoch in range(max_epochs):
        print(f"Epoch {epoch+1}/{max_epochs}")

        # Training phase
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0
        start_time = time.time()

        for batch_idx, batch_data in enumerate(dataloaders['train']):
            # Handle both cases: with and without time features
            if len(batch_data) == 4:  # With time features
                inputs, targets, x_mark, y_mark = batch_data
                inputs, targets = inputs.to(device), targets.to(device)
                x_mark, y_mark = x_mark.to(device), y_mark.to(device)
            else:  # Without time features
                inputs, targets = batch_data
                inputs, targets = inputs.to(device), targets.to(device)
                x_mark, y_mark = None, None

            # Convert targets to long for classification
            targets = targets.long()

            # Zero the parameter gradients
            optimizer.zero_grad()

            # Forward pass - handle both cases
            if x_mark is not None:
                # For TimesNet model with time features
                dec_inp = torch.zeros_like(targets).to(device)
                outputs = model(inputs, x_mark, dec_inp, y_mark)
            else:
                # For simple models without time features
                outputs = model(inputs)

            loss = criterion(outputs, targets)

            # Backward pass and optimize
            loss.backward()
            optimizer.step()

            # Update statistics
            train_loss += loss.item()
            _, predicted = outputs.max(1)
            train_total += targets.size(0)
            train_correct += predicted.eq(targets).sum().item()

            if (batch_idx + 1) % log_interval == 0:
                print(f"Batch {batch_idx+1}/{len(dataloaders['train'])}, Loss: {loss.item():.4f}")

        avg_train_loss = train_loss / len(dataloaders['train'])
        train_accuracy = 100. * train_correct / train_total
        epoch_time = time.time() - start_time

        # Validation phase
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0

        with torch.no_grad():
            for batch_data in dataloaders['val']:
                # Handle both cases: with and without time features
                if len(batch_data) == 4:  # With time features
                    inputs, targets, x_mark, y_mark = batch_data
                    inputs, targets = inputs.to(device), targets.to(device)
                    x_mark, y_mark = x_mark.to(device), y_mark.to(device)
                else:  # Without time features
                    inputs, targets = batch_data
                    inputs, targets = inputs.to(device), targets.to(device)
                    x_mark, y_mark = None, None

                targets = targets.long()

                # Forward pass - handle both cases
                if x_mark is not None:
                    # For TimesNet model with time features
                    dec_inp = torch.zeros_like(targets).to(device)
                    outputs = model(inputs, x_mark, dec_inp, y_mark)
                else:
                    # For simple models without time features
                    outputs = model(inputs)

                # Calculate loss
                loss = criterion(outputs, targets)
                val_loss += loss.item()

                # Calculate accuracy
                _, predicted = outputs.max(1)
                val_total += targets.size(0)
                val_correct += predicted.eq(targets).sum().item()

        avg_val_loss = val_loss / len(dataloaders['val'])
        val_accuracy = 100. * val_correct / val_total
        current_lr = optimizer.param_groups[0]['lr']

        # Log metrics
        metrics_dict = {
            "train_loss": avg_train_loss,
            "val_loss": avg_val_loss,
            "val_accuracy": val_accuracy,
            "learning_rate": current_lr,
            "epoch_time": epoch_time
        }

        swanlab_run.log(metrics_dict)

        print(f"Epoch {epoch+1}/{max_epochs}, "
              f"Train Loss: {avg_train_loss:.4f}, "
              f"Val Loss: {avg_val_loss:.4f}, "
              f"Val Accuracy: {val_accuracy:.2f}%, "
              f"LR: {current_lr:.6f}, "
              f"Time: {epoch_time:.2f}s")

        # Track the best epoch's metrics
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            metrics = metrics_dict

        # Early stopping
        early_stopping(avg_val_loss, model)
        if early_stopping.early_stop:
            print("Early stopping triggered")
            break

        # Step the learning rate scheduler
        scheduler.step()

    # Load the best model
    model.load_state_dict(torch.load(checkpoint_path))

    # Final validation
    model.eval()
    final_val_loss = 0.0
    final_val_correct = 0
    final_val_total = 0

    with torch.no_grad():
        for batch_data in dataloaders['val']:
            # Handle both cases: with and without time features
            if len(batch_data) == 4:  # With time features
                inputs, targets, x_mark, y_mark = batch_data
                inputs, targets = inputs.to(device), targets.to(device)
                x_mark, y_mark = x_mark.to(device), y_mark.to(device)
            else:  # Without time features
                inputs, targets = batch_data
                inputs, targets = inputs.to(device), targets.to(device)
                x_mark, y_mark = None, None

            targets = targets.long()

            # Forward pass - handle both cases
            if x_mark is not None:
                # For TimesNet model with time features
                dec_inp = torch.zeros_like(targets).to(device)
                outputs = model(inputs, x_mark, dec_inp, y_mark)
            else:
                # For simple models without time features
                outputs = model(inputs)

            # Calculate loss
            loss = criterion(outputs, targets)
            final_val_loss += loss.item()

            # Calculate accuracy
            _, predicted = outputs.max(1)
            final_val_total += targets.size(0)
            final_val_correct += predicted.eq(targets).sum().item()

    final_val_loss /= len(dataloaders['val'])
    final_val_accuracy = 100. * final_val_correct / final_val_total

    print(f"Final validation loss: {final_val_loss:.4f}")
    print(f"Final validation accuracy: {final_val_accuracy:.2f}%")

    # Update metrics with final values
    metrics["final_val_loss"] = final_val_loss
    metrics["final_val_accuracy"] = final_val_accuracy

    # Finish the swanlab run
    swanlab_run.finish()

    return model, metrics

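# A hedged usage sketch for the classification trainer; main() below only exercises the
# forecasting path. The data path, project name, class count, and input size here are
# placeholders, not values from the original experiment.
def _classification_example():
    clf_constructor = lambda: nn.Sequential(
        nn.Flatten(),            # (batch, seq_len, features) -> (batch, seq_len * features)
        nn.Linear(96 * 7, 64),   # assumed flattened input size
        nn.ReLU(),
        nn.Linear(64, 5),        # assumed number of classes
    )
    return train_classification_model(
        model_constructor=clf_constructor,
        data_path='data/classification_data.npz',  # hypothetical file
        project_name='TimeSeriesClassification',
        config={'learning_rate': 1e-3, 'batch_size': 32, 'weight_decay': 1e-4},
        max_epochs=20,
    )
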
def main():
    # Example usage
    data_path = 'data/train_data.npz'
    project_name = 'TimeSeriesForecasting'
    config = {
        'learning_rate': 0.001,
        'batch_size': 32,
        'weight_decay': 1e-4
    }

    model_constructor = lambda: nn.Sequential(
        nn.Linear(10, 50),
        nn.ReLU(),
        nn.Linear(50, 1)
    )

    model, metrics = train_forecasting_model(
        model_constructor=model_constructor,
        data_path=data_path,
        project_name=project_name,
        config=config
    )


if __name__ == "__main__":
    main()