Files
TSlib/train_dc_patchtst.py

335 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader
import os
import argparse
from models.DC_PatchTST import Model
import time
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error, mean_absolute_error
# Make sure the results folder exists before anything is written into it.
def ensure_dir(path):
    """Create directory ``path`` (including missing parents) if it is absent.

    ``exist_ok=True`` lets ``os.makedirs`` handle the "already there" case
    itself, which avoids the race between a separate ``os.path.exists``
    check and the subsequent creation.

    Args:
        path: Directory path to create.

    Raises:
        FileExistsError: If ``path`` exists but is not a directory.
    """
    os.makedirs(path, exist_ok=True)
class SineWaveDataset(Dataset):
    """Sliding-window dataset over a two-channel sine-wave CSV.

    The split file (``train.csv``, ``val.csv`` or ``test.csv``) is read from
    ``data_path`` and its ``channel1`` / ``channel2`` columns are kept as a
    float32 array. Sample ``i`` is the window of ``seq_len`` rows starting at
    row ``i`` together with the ``pred_len`` rows that immediately follow it.
    """

    def __init__(self, data_path, seq_len=96, pred_len=24, mode='train'):
        """Load one split and compute how many windows it provides.

        Args:
            data_path: Directory containing the split CSV files.
            seq_len: Number of input time steps per sample.
            pred_len: Number of target time steps per sample.
            mode: ``'train'`` or ``'val'``; any other value selects the
                test split.
        """
        self.seq_len = seq_len
        self.pred_len = pred_len
        self.mode = mode

        # Load the data for the requested split.
        if mode == 'train':
            csv_name = 'train.csv'
        elif mode == 'val':
            csv_name = 'val.csv'
        else:  # test
            csv_name = 'test.csv'
        df = pd.read_csv(os.path.join(data_path, csv_name))

        # Keep only the feature columns (everything except the timestamp).
        self.data = df[['channel1', 'channel2']].values.astype(np.float32)

        # Number of complete (input, target) windows; zero if the file is
        # shorter than one window.
        self.total_len = len(self.data)
        self.samples_num = max(0, self.total_len - seq_len - pred_len + 1)
        print(f"{mode} 数据集: {self.total_len} 条记录, {self.samples_num} 个样本")

    def __len__(self):
        """Return the number of available windows."""
        return self.samples_num

    def __getitem__(self, idx):
        """Return ``(seq_x, seq_y, seq_x_mark, seq_y_mark)`` for window ``idx``.

        Negative indices count from the end, as for a list.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
                Without this check slicing would silently return truncated
                windows, and plain iteration over the dataset would never stop.
        """
        if idx < 0:
            idx += self.samples_num
        if not 0 <= idx < self.samples_num:
            raise IndexError(
                f'sample index out of range for dataset of size {self.samples_num}'
            )

        # Input window.
        s_begin = idx
        s_end = s_begin + self.seq_len
        # Prediction target starts right after the input window.
        r_begin = s_end
        r_end = r_begin + self.pred_len

        seq_x = self.data[s_begin:s_end]  # (seq_len, n_vars)
        seq_y = self.data[r_begin:r_end]  # (pred_len, n_vars)

        # Time marks: a simple position index within each window.
        seq_x_mark = np.arange(self.seq_len).reshape(-1, 1).astype(np.float32)
        seq_y_mark = np.arange(self.pred_len).reshape(-1, 1).astype(np.float32)
        return seq_x, seq_y, seq_x_mark, seq_y_mark
class Config:
    """Plain attribute container holding every setting used by this script.

    All values are set as instance attributes in ``__init__``; the only
    computed ones are ``use_gpu`` and ``device``, which depend on whether
    CUDA is available.
    """

    def __init__(self):
        cuda_ok = torch.cuda.is_available()
        settings = {
            # --- basic ---
            'task_name': 'long_term_forecast',
            'model': 'DC_PatchTST',
            # --- data ---
            'seq_len': 96,     # input sequence length
            'pred_len': 24,    # prediction length
            'label_len': 48,   # label length
            'enc_in': 2,       # input feature dimension (two channels)
            'dec_in': 2,       # decoder input dimension
            'c_out': 2,        # output dimension
            # --- model ---
            'd_model': 128,    # model width
            'n_heads': 8,      # number of attention heads
            'e_layers': 2,     # encoder layers
            'd_layers': 1,     # decoder layers
            'd_ff': 256,       # feed-forward width
            'factor': 1,       # attention factor
            'dropout': 0,      # dropout rate
            'activation': 'gelu',
            # --- training ---
            'batch_size': 1024,
            'learning_rate': 0.001,
            'train_epochs': 50,
            'patience': 5,
            # --- misc ---
            'use_amp': False,
            'num_class': 0,
            # --- GPU ---
            'use_gpu': cuda_ok,
            'device': torch.device('cuda' if cuda_ok else 'cpu'),
        }
        # Assign in table order so the instance's attribute order is stable.
        for attr_name, attr_value in settings.items():
            setattr(self, attr_name, attr_value)
def _forecast_loss(model, criterion, batch_x, batch_x_mark, dec_inp,
                   batch_y, batch_y_mark, ratio_loss_weight):
    """Run one forward pass and return the criterion loss plus weighted ratio loss."""
    outputs, aux = model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
    loss = criterion(outputs, batch_y)
    # Add the DC ratio loss when the model reports both terms. The term is
    # added even for a zero weight so the autograd graph is the same as
    # with a non-zero weight.
    if aux is not None and 'ratio_loss0' in aux and 'ratio_loss1' in aux:
        ratio_loss = aux['ratio_loss0'] + aux['ratio_loss1']
        loss = loss + ratio_loss_weight * ratio_loss
    return loss


def train_epoch(model, train_loader, criterion, optimizer, device,
                use_amp=False, ratio_loss_weight=0.0):
    """Train ``model`` for one pass over ``train_loader``.

    Args:
        model: Module called as ``model(x, x_mark, dec_inp, y_mark)`` and
            returning ``(outputs, aux)``; ``aux`` may be ``None`` or a
            mapping that contains ``'ratio_loss0'`` and ``'ratio_loss1'``.
        train_loader: Iterable of ``(batch_x, batch_y, batch_x_mark,
            batch_y_mark)`` tensor tuples.
        criterion: Loss callable applied to ``(outputs, batch_y)``.
        optimizer: Optimizer stepped once per batch.
        device: Device every batch tensor is moved to.
        use_amp: If true, run the forward pass under
            ``torch.cuda.amp.autocast()``.
        ratio_loss_weight: Weight of the summed ratio losses in the total
            loss. Defaults to ``0.0``, the value that used to be hard-coded.

    Returns:
        The mean of the per-batch loss values as a float.

    Raises:
        ValueError: If ``train_loader`` yields no batches.
    """
    model.train()
    total_loss = 0.0
    batch_count = 0
    for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(train_loader):
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)
        batch_x_mark = batch_x_mark.to(device)
        batch_y_mark = batch_y_mark.to(device)

        # Decoder input: zeros with the target's shape, dtype and device.
        dec_inp = torch.zeros_like(batch_y)

        optimizer.zero_grad()
        if use_amp:
            # NOTE(review): no GradScaler is used with autocast here, so
            # float16 gradients may underflow -- confirm before enabling AMP.
            with torch.cuda.amp.autocast():
                loss = _forecast_loss(model, criterion, batch_x, batch_x_mark,
                                      dec_inp, batch_y, batch_y_mark,
                                      ratio_loss_weight)
        else:
            loss = _forecast_loss(model, criterion, batch_x, batch_x_mark,
                                  dec_inp, batch_y, batch_y_mark,
                                  ratio_loss_weight)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        total_loss += batch_loss
        batch_count += 1
        if i % 100 == 0:
            print(f'Batch {i}, Loss: {batch_loss:.6f}')

    if batch_count == 0:
        raise ValueError('train_loader yielded no batches; cannot compute a mean loss')
    return total_loss / batch_count
def validate(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` without tracking gradients.

    Args:
        model: Module called as ``model(x, x_mark, dec_inp, y_mark)`` and
            returning ``(outputs, aux)``; ``aux`` is ignored here.
        val_loader: Iterable of ``(batch_x, batch_y, batch_x_mark,
            batch_y_mark)`` tensor tuples.
        criterion: Loss callable applied to ``(outputs, batch_y)``.
        device: Device every batch tensor is moved to.

    Returns:
        The mean of the per-batch loss values as a float.

    Raises:
        ValueError: If ``val_loader`` yields no batches.
    """
    model.eval()
    total_loss = 0.0
    batch_count = 0
    with torch.no_grad():
        for batch_x, batch_y, batch_x_mark, batch_y_mark in val_loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            batch_x_mark = batch_x_mark.to(device)
            batch_y_mark = batch_y_mark.to(device)

            # Decoder input: zeros with the target's shape, dtype and device.
            dec_inp = torch.zeros_like(batch_y)
            outputs, _ = model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
            total_loss += criterion(outputs, batch_y).item()
            batch_count += 1

    if batch_count == 0:
        raise ValueError('val_loader yielded no batches; cannot compute a mean loss')
    return total_loss / batch_count
def test_model(model, test_loader, device, save_path):
    """Evaluate ``model`` on ``test_loader`` and plot the first few forecasts.

    Predictions and targets of all batches are concatenated, MSE and MAE are
    computed over the flattened arrays, and up to four samples are plotted
    (one subplot per sample and channel) into ``predictions.png`` inside
    ``save_path``.

    Args:
        model: Module called as ``model(x, x_mark, dec_inp, y_mark)`` and
            returning ``(outputs, aux)``; ``aux`` is ignored here.
        test_loader: Iterable of ``(batch_x, batch_y, batch_x_mark,
            batch_y_mark)`` tensor tuples.
        device: Device every batch tensor is moved to.
        save_path: Existing directory that receives ``predictions.png``.

    Returns:
        ``(mse, mae)`` over all test samples.

    Raises:
        ValueError: If ``test_loader`` yields no batches.
    """
    model.eval()
    predictions = []
    ground_truths = []
    with torch.no_grad():
        for batch_x, batch_y, batch_x_mark, batch_y_mark in test_loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            batch_x_mark = batch_x_mark.to(device)
            batch_y_mark = batch_y_mark.to(device)

            # Decoder input: zeros with the target's shape, dtype and device.
            dec_inp = torch.zeros_like(batch_y)
            outputs, _ = model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
            predictions.append(outputs.cpu().numpy())
            ground_truths.append(batch_y.cpu().numpy())

    if not predictions:
        raise ValueError('test_loader yielded no batches; nothing to evaluate')

    predictions = np.concatenate(predictions, axis=0)
    ground_truths = np.concatenate(ground_truths, axis=0)

    # Compute the metrics over every element.
    mse = mean_squared_error(ground_truths.reshape(-1), predictions.reshape(-1))
    mae = mean_absolute_error(ground_truths.reshape(-1), predictions.reshape(-1))
    print(f"测试结果 - MSE: {mse:.6f}, MAE: {mae:.6f}")

    # Visualise the first few samples. The grid always has ``max_samples``
    # rows; the column count follows the number of channels in the output
    # instead of assuming exactly two.
    max_samples = 4
    n_channels = predictions.shape[-1]
    fig = plt.figure(figsize=(15, 10))
    try:
        for i in range(min(max_samples, len(predictions))):
            for ch in range(n_channels):
                plt.subplot(max_samples, n_channels, i * n_channels + ch + 1)
                plt.plot(ground_truths[i, :, ch], label='Ground Truth', color='blue')
                plt.plot(predictions[i, :, ch], label='Prediction', color='red')
                plt.title(f'Sample {i+1}, Channel {ch+1}')
                plt.legend()
        plt.tight_layout()
        plt.savefig(os.path.join(save_path, 'predictions.png'))
    finally:
        # pyplot keeps figures alive until they are closed explicitly.
        plt.close(fig)
    print(f"预测结果可视化保存到: {save_path}/predictions.png")
    return mse, mae
def main():
    """Train DC_PatchTST on the sine-wave data, test it, and save the results.

    Steps: build the config, load the train/val/test splits from
    ``./data/sine_wave/``, train with early stopping on the validation loss
    (checkpointing the best model), reload the best checkpoint, evaluate it
    on the test split, and write the loss history and plots into
    ``./results/dc_patchtst_sine_wave``.
    """
    # Local import: only this function needs it.
    import json

    # Configuration.
    config = Config()

    # Results directory.
    results_dir = './results/dc_patchtst_sine_wave'
    ensure_dir(results_dir)
    print(f"使用设备: {config.device}")
    print(f"模型配置: seq_len={config.seq_len}, pred_len={config.pred_len}")

    # Data.
    data_path = './data/sine_wave/'
    train_dataset = SineWaveDataset(data_path, config.seq_len, config.pred_len, 'train')
    val_dataset = SineWaveDataset(data_path, config.seq_len, config.pred_len, 'val')
    test_dataset = SineWaveDataset(data_path, config.seq_len, config.pred_len, 'test')
    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False, num_workers=4)
    test_loader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False, num_workers=4)

    # Model.
    model = Model(config).to(config.device)
    print(f"模型参数数量: {sum(p.numel() for p in model.parameters())}")

    # Optimizer and loss function.
    optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate, weight_decay=1e-4)
    criterion = nn.MSELoss()

    # Training loop with early stopping.
    best_model_path = os.path.join(results_dir, 'best_model.pth')
    best_val_loss = float('inf')
    patience_counter = 0
    train_losses = []
    val_losses = []
    print("开始训练...")
    start_time = time.time()
    for epoch in range(config.train_epochs):
        epoch_start = time.time()
        # Train.
        train_loss = train_epoch(model, train_loader, criterion, optimizer, config.device, config.use_amp)
        # Validate.
        val_loss = validate(model, val_loader, criterion, config.device)
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        epoch_time = time.time() - epoch_start
        print(f'Epoch {epoch+1:2d}/{config.train_epochs} | '
              f'Train Loss: {train_loss:.6f} | '
              f'Val Loss: {val_loss:.6f} | '
              f'Time: {epoch_time:.2f}s')

        # Early-stopping check.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            # Save the best model so far.
            torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'config': config,
                'epoch': epoch,
                'val_loss': val_loss
            }, best_model_path)
            print(f' -> 保存最佳模型 (val_loss: {val_loss:.6f})')
        else:
            patience_counter += 1
            if patience_counter >= config.patience:
                print(f'早停触发! 最佳验证损失: {best_val_loss:.6f}')
                break

    total_time = time.time() - start_time
    print(f'\n训练完成! 总时间: {total_time/60:.2f} 分钟')

    # Reload the best model for testing. The checkpoint contains the pickled
    # Config object, which weights-only loading (the default in recent
    # PyTorch releases) rejects; the file was written by this very run, so
    # full unpickling is acceptable here. ``weights_only`` needs PyTorch 1.13+.
    checkpoint = torch.load(best_model_path, map_location=config.device, weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    print("\n测试最佳模型...")
    test_mse, test_mae = test_model(model, test_loader, config.device, results_dir)

    # Save the training history. Values are converted to plain floats so
    # they serialise regardless of the numeric type returned by the metrics.
    history = {
        'train_losses': [float(v) for v in train_losses],
        'val_losses': [float(v) for v in val_losses],
        'test_mse': float(test_mse),
        'test_mae': float(test_mae)
    }
    with open(os.path.join(results_dir, 'history.json'), 'w', encoding='utf-8') as history_file:
        json.dump(history, history_file, indent=2)

    # Plot the training curves and the test metrics.
    fig = plt.figure(figsize=(12, 4))
    try:
        plt.subplot(1, 2, 1)
        plt.plot(train_losses, label='Train Loss')
        plt.plot(val_losses, label='Validation Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()
        plt.title('Training History')
        plt.subplot(1, 2, 2)
        plt.bar(['MSE', 'MAE'], [test_mse, test_mae])
        plt.title('Test Metrics')
        plt.ylabel('Error')
        plt.tight_layout()
        plt.savefig(os.path.join(results_dir, 'training_history.png'))
    finally:
        # pyplot keeps figures alive until they are closed explicitly.
        plt.close(fig)

    print(f"\n结果保存在: {results_dir}")
    print(f"最佳模型: {results_dir}/best_model.pth")
    print(f"训练历史: {results_dir}/training_history.png")
    print(f"预测可视化: {results_dir}/predictions.png")


if __name__ == "__main__":
    main()