feat: add mamba and dynamic chunking related code and test code
335  train_dc_patchtst.py  Normal file
@@ -0,0 +1,335 @@
#!/usr/bin/env python3
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader
import os
import argparse
from models.DC_PatchTST import Model
import time
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error, mean_absolute_error


# Create the results directory if it does not exist
def ensure_dir(path):
    if not os.path.exists(path):
        os.makedirs(path)


class SineWaveDataset(Dataset):
    """Sine-wave dataset."""
    def __init__(self, data_path, seq_len=96, pred_len=24, mode='train'):
        self.seq_len = seq_len
        self.pred_len = pred_len
        self.mode = mode

        # Load data
        if mode == 'train':
            df = pd.read_csv(os.path.join(data_path, 'train.csv'))
        elif mode == 'val':
            df = pd.read_csv(os.path.join(data_path, 'val.csv'))
        else:  # test
            df = pd.read_csv(os.path.join(data_path, 'test.csv'))

        # Extract the feature columns (everything except timestamp)
        self.data = df[['channel1', 'channel2']].values.astype(np.float32)

        # Number of usable sliding-window samples
        self.total_len = len(self.data)
        self.samples_num = max(0, self.total_len - seq_len - pred_len + 1)

        print(f"{mode} dataset: {self.total_len} records, {self.samples_num} samples")

    def __len__(self):
        return self.samples_num

    def __getitem__(self, idx):
        # Input window
        s_begin = idx
        s_end = s_begin + self.seq_len

        # Prediction target window
        r_begin = s_end
        r_end = r_begin + self.pred_len

        seq_x = self.data[s_begin:s_end]  # (seq_len, n_vars)
        seq_y = self.data[r_begin:r_end]  # (pred_len, n_vars)

        # Time marks (simple positional encoding)
        seq_x_mark = np.arange(self.seq_len).reshape(-1, 1).astype(np.float32)
        seq_y_mark = np.arange(self.pred_len).reshape(-1, 1).astype(np.float32)

        return seq_x, seq_y, seq_x_mark, seq_y_mark

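# Window-indexing sketch (derived from __getitem__ above): with seq_len=96 and
# pred_len=24, sample idx=0 uses rows 0..95 as seq_x and rows 96..119 as seq_y,
# so a file with total_len rows yields total_len - seq_len - pred_len + 1
# samples, e.g. 1000 rows -> 881 samples.
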
class Config:
    """Configuration."""
    def __init__(self):
        # Basic settings
        self.task_name = 'long_term_forecast'
        self.model = 'DC_PatchTST'

        # Data settings
        self.seq_len = 96    # input sequence length
        self.pred_len = 24   # prediction horizon
        self.label_len = 48  # label length
        self.enc_in = 2      # number of input features (two channels)
        self.dec_in = 2      # decoder input dimension
        self.c_out = 2       # output dimension

        # Model settings
        self.d_model = 128   # model dimension
        self.n_heads = 8     # number of attention heads
        self.e_layers = 2    # number of encoder layers
        self.d_layers = 1    # number of decoder layers
        self.d_ff = 256      # feed-forward dimension
        self.factor = 1      # attention factor
        self.dropout = 0     # dropout rate
        self.activation = 'gelu'

        # Training settings
        self.batch_size = 1024
        self.learning_rate = 0.001
        self.train_epochs = 50
        self.patience = 5

        # Misc settings
        self.use_amp = False
        self.num_class = 0

        # GPU settings
        self.use_gpu = torch.cuda.is_available()
        self.device = torch.device('cuda' if self.use_gpu else 'cpu')

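# A minimal sketch of how the imported-but-unused argparse module could
# override the defaults above. The flag names and this helper are assumptions;
# it is not called anywhere in this script.
def parse_overrides(config):
    parser = argparse.ArgumentParser(description='Train DC_PatchTST on the sine-wave data')
    parser.add_argument('--seq_len', type=int, default=config.seq_len)
    parser.add_argument('--pred_len', type=int, default=config.pred_len)
    parser.add_argument('--batch_size', type=int, default=config.batch_size)
    parser.add_argument('--learning_rate', type=float, default=config.learning_rate)
    args = parser.parse_args()
    # Copy the parsed values back onto the config object
    for k, v in vars(args).items():
        setattr(config, k, v)
    return config
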
def train_epoch(model, train_loader, criterion, optimizer, device, use_amp=False):
    """Train for one epoch."""
    model.train()
    total_loss = 0.0
    batch_count = 0

    for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(train_loader):
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)
        batch_x_mark = batch_x_mark.to(device)
        batch_y_mark = batch_y_mark.to(device)

        # Build the decoder input
        dec_inp = torch.zeros_like(batch_y).to(device)

        optimizer.zero_grad()

        if use_amp:
            with torch.cuda.amp.autocast():
                outputs, aux = model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
                loss = criterion(outputs, batch_y)

                # Add the dynamic-chunking ratio loss
                if aux is not None and 'ratio_loss0' in aux and 'ratio_loss1' in aux:
                    ratio_loss = aux['ratio_loss0'] + aux['ratio_loss1']
                    loss = loss + 0.0 * ratio_loss  # ratio-loss weight (currently 0, i.e. disabled)
        else:
            outputs, aux = model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
            loss = criterion(outputs, batch_y)

            # Add the dynamic-chunking ratio loss
            if aux is not None and 'ratio_loss0' in aux and 'ratio_loss1' in aux:
                ratio_loss = aux['ratio_loss0'] + aux['ratio_loss1']
                loss = loss + 0.0 * ratio_loss  # ratio-loss weight (currently 0, i.e. disabled)

        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        batch_count += 1

        if i % 100 == 0:
            print(f'Batch {i}, Loss: {loss.item():.6f}')

    return total_loss / batch_count

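# Note: the use_amp branch above calls backward() without a gradient scaler.
# Mixed precision is usually paired with torch.cuda.amp.GradScaler; the helper
# below only sketches that pattern (create the scaler once with
# scaler = torch.cuda.amp.GradScaler()) and is not called by this script.
def amp_train_step(model, criterion, optimizer, scaler,
                   batch_x, batch_x_mark, dec_inp, batch_y, batch_y_mark):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        outputs, _ = model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
        loss = criterion(outputs, batch_y)
    scaler.scale(loss).backward()  # scale the loss to avoid fp16 gradient underflow
    scaler.step(optimizer)         # unscales gradients, then runs optimizer.step()
    scaler.update()
    return loss.item()
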
def validate(model, val_loader, criterion, device):
    """Evaluate on the validation set."""
    model.eval()
    total_loss = 0.0
    batch_count = 0

    with torch.no_grad():
        for batch_x, batch_y, batch_x_mark, batch_y_mark in val_loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            batch_x_mark = batch_x_mark.to(device)
            batch_y_mark = batch_y_mark.to(device)

            dec_inp = torch.zeros_like(batch_y).to(device)

            outputs, _ = model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
            loss = criterion(outputs, batch_y)

            total_loss += loss.item()
            batch_count += 1

    return total_loss / batch_count

def test_model(model, test_loader, device, save_path):
    """Evaluate on the test set and visualise the results."""
    model.eval()
    predictions = []
    ground_truths = []

    with torch.no_grad():
        for batch_x, batch_y, batch_x_mark, batch_y_mark in test_loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            batch_x_mark = batch_x_mark.to(device)
            batch_y_mark = batch_y_mark.to(device)

            dec_inp = torch.zeros_like(batch_y).to(device)

            outputs, _ = model(batch_x, batch_x_mark, dec_inp, batch_y_mark)

            predictions.append(outputs.cpu().numpy())
            ground_truths.append(batch_y.cpu().numpy())

    predictions = np.concatenate(predictions, axis=0)
    ground_truths = np.concatenate(ground_truths, axis=0)

    # Compute metrics over the flattened arrays
    mse = mean_squared_error(ground_truths.reshape(-1), predictions.reshape(-1))
    mae = mean_absolute_error(ground_truths.reshape(-1), predictions.reshape(-1))

    print(f"Test results - MSE: {mse:.6f}, MAE: {mae:.6f}")

    # Visualise the first few samples
    plt.figure(figsize=(15, 10))
    for i in range(min(4, len(predictions))):
        for ch in range(2):  # two channels
            plt.subplot(4, 2, i*2 + ch + 1)
            plt.plot(ground_truths[i, :, ch], label='Ground Truth', color='blue')
            plt.plot(predictions[i, :, ch], label='Prediction', color='red')
            plt.title(f'Sample {i+1}, Channel {ch+1}')
            plt.legend()

    plt.tight_layout()
    plt.savefig(os.path.join(save_path, 'predictions.png'))
    print(f"Prediction plots saved to: {save_path}/predictions.png")

    return mse, mae

def main():
    # Configuration
    config = Config()

    # Create the results directory
    results_dir = './results/dc_patchtst_sine_wave'
    ensure_dir(results_dir)

    print(f"Using device: {config.device}")
    print(f"Model config: seq_len={config.seq_len}, pred_len={config.pred_len}")

    # Load data
    data_path = './data/sine_wave/'
    train_dataset = SineWaveDataset(data_path, config.seq_len, config.pred_len, 'train')
    val_dataset = SineWaveDataset(data_path, config.seq_len, config.pred_len, 'val')
    test_dataset = SineWaveDataset(data_path, config.seq_len, config.pred_len, 'test')

    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False, num_workers=4)
    test_loader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False, num_workers=4)

    # Build the model
    model = Model(config).to(config.device)
    print(f"Number of model parameters: {sum(p.numel() for p in model.parameters())}")

    # Optimizer and loss function
    optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate, weight_decay=1e-4)
    criterion = nn.MSELoss()

    # Training loop
    best_val_loss = float('inf')
    patience_counter = 0
    train_losses = []
    val_losses = []

    print("Starting training...")
    start_time = time.time()

    for epoch in range(config.train_epochs):
        epoch_start = time.time()

        # Train
        train_loss = train_epoch(model, train_loader, criterion, optimizer, config.device, config.use_amp)

        # Validate
        val_loss = validate(model, val_loader, criterion, config.device)

        train_losses.append(train_loss)
        val_losses.append(val_loss)

        epoch_time = time.time() - epoch_start

        print(f'Epoch {epoch+1:2d}/{config.train_epochs} | '
              f'Train Loss: {train_loss:.6f} | '
              f'Val Loss: {val_loss:.6f} | '
              f'Time: {epoch_time:.2f}s')

        # Early-stopping check
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            # Save the best model
            torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'config': config,
                'epoch': epoch,
                'val_loss': val_loss
            }, os.path.join(results_dir, 'best_model.pth'))
            print(f'  -> Saved best model (val_loss: {val_loss:.6f})')
        else:
            patience_counter += 1
            if patience_counter >= config.patience:
                print(f'Early stopping triggered! Best validation loss: {best_val_loss:.6f}')
                break

    total_time = time.time() - start_time
    print(f'\nTraining finished! Total time: {total_time/60:.2f} minutes')

    # Load the best model for testing
    checkpoint = torch.load(os.path.join(results_dir, 'best_model.pth'))
    model.load_state_dict(checkpoint['model_state_dict'])

    print("\nTesting the best model...")
    test_mse, test_mae = test_model(model, test_loader, config.device, results_dir)

    # Keep the training history (held in memory; this script does not write it to disk)
    history = {
        'train_losses': train_losses,
        'val_losses': val_losses,
        'test_mse': test_mse,
        'test_mae': test_mae
    }

    # Plot the training curves
    plt.figure(figsize=(12, 4))
    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.title('Training History')

    plt.subplot(1, 2, 2)
    plt.bar(['MSE', 'MAE'], [test_mse, test_mae])
    plt.title('Test Metrics')
    plt.ylabel('Error')

    plt.tight_layout()
    plt.savefig(os.path.join(results_dir, 'training_history.png'))

    print(f"\nResults saved in: {results_dir}")
    print(f"Best model: {results_dir}/best_model.pth")
    print(f"Training history: {results_dir}/training_history.png")
    print(f"Prediction plots: {results_dir}/predictions.png")


if __name__ == "__main__":
    main()
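# Usage (an assumption based on the paths referenced above): place
# ./data/sine_wave/{train,val,test}.csv with columns timestamp, channel1,
# channel2, then run:
#   python train_dc_patchtst.py
# Outputs are written to ./results/dc_patchtst_sine_wave/.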