feat: add mamba and dynamic chunking related code and test code

Author: gameloader
Date: 2025-09-04 01:32:13 +00:00
Parent: 12cb7652cf
Commit: ef307a57e9
21 changed files with 4550 additions and 86 deletions

train_dc_patchtst.py (new file, +335 lines)

@@ -0,0 +1,335 @@
#!/usr/bin/env python3
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader
import os
import argparse
from models.DC_PatchTST import Model
import time
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error, mean_absolute_error

# Create the results directory if it does not exist
def ensure_dir(path):
    if not os.path.exists(path):
        os.makedirs(path)

class SineWaveDataset(Dataset):
    """Sine wave dataset."""

    def __init__(self, data_path, seq_len=96, pred_len=24, mode='train'):
        self.seq_len = seq_len
        self.pred_len = pred_len
        self.mode = mode
        # Load the data
        if mode == 'train':
            df = pd.read_csv(os.path.join(data_path, 'train.csv'))
        elif mode == 'val':
            df = pd.read_csv(os.path.join(data_path, 'val.csv'))
        else:  # test
            df = pd.read_csv(os.path.join(data_path, 'test.csv'))
        # Extract the feature columns (everything except timestamp)
        self.data = df[['channel1', 'channel2']].values.astype(np.float32)
        # Number of usable sliding-window samples
        self.total_len = len(self.data)
        self.samples_num = max(0, self.total_len - seq_len - pred_len + 1)
        print(f"{mode} dataset: {self.total_len} records, {self.samples_num} samples")

    def __len__(self):
        return self.samples_num

    def __getitem__(self, idx):
        # Input sequence
        s_begin = idx
        s_end = s_begin + self.seq_len
        # Prediction target
        r_begin = s_end
        r_end = r_begin + self.pred_len
        seq_x = self.data[s_begin:s_end]  # (seq_len, n_vars)
        seq_y = self.data[r_begin:r_end]  # (pred_len, n_vars)
        # Time marks (simple positional encoding)
        seq_x_mark = np.arange(self.seq_len).reshape(-1, 1).astype(np.float32)
        seq_y_mark = np.arange(self.pred_len).reshape(-1, 1).astype(np.float32)
        return seq_x, seq_y, seq_x_mark, seq_y_mark
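
# The loader above expects ./data/sine_wave/{train,val,test}.csv, each with a
# timestamp column (ignored by the loader) plus two value columns named
# channel1 and channel2. A minimal sketch of generating a compatible file
# (illustrative only; the length and frequencies here are made up, and any
# real data-generation script in the repo may differ):
#
#     t = np.arange(20000)
#     pd.DataFrame({
#         'timestamp': t,
#         'channel1': np.sin(2 * np.pi * t / 50),
#         'channel2': np.sin(2 * np.pi * t / 80),
#     }).to_csv('./data/sine_wave/train.csv', index=False)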

class Config:
    """Configuration class."""

    def __init__(self):
        # Basic settings
        self.task_name = 'long_term_forecast'
        self.model = 'DC_PatchTST'
        # Data settings
        self.seq_len = 96       # input sequence length
        self.pred_len = 24      # prediction horizon
        self.label_len = 48     # label length
        self.enc_in = 2         # number of input features (two channels)
        self.dec_in = 2         # decoder input dimension
        self.c_out = 2          # output dimension
        # Model settings
        self.d_model = 128      # model dimension
        self.n_heads = 8        # number of attention heads
        self.e_layers = 2       # number of encoder layers
        self.d_layers = 1       # number of decoder layers
        self.d_ff = 256         # feed-forward dimension
        self.factor = 1         # attention factor
        self.dropout = 0        # dropout rate
        self.activation = 'gelu'
        # Training settings
        self.batch_size = 1024
        self.learning_rate = 0.001
        self.train_epochs = 50
        self.patience = 5
        # Misc settings
        self.use_amp = False
        self.num_class = 0
        # GPU settings
        self.use_gpu = torch.cuda.is_available()
        self.device = torch.device('cuda' if self.use_gpu else 'cpu')

def train_epoch(model, train_loader, criterion, optimizer, device, use_amp=False):
    """Train for one epoch."""
    model.train()
    total_loss = 0.0
    batch_count = 0
    for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(train_loader):
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)
        batch_x_mark = batch_x_mark.to(device)
        batch_y_mark = batch_y_mark.to(device)
        # Build the decoder input
        dec_inp = torch.zeros_like(batch_y).to(device)
        optimizer.zero_grad()
        if use_amp:
            with torch.cuda.amp.autocast():
                outputs, aux = model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
                loss = criterion(outputs, batch_y)
                # Add the DC (dynamic chunking) ratio loss
                if aux is not None and 'ratio_loss0' in aux and 'ratio_loss1' in aux:
                    ratio_loss = aux['ratio_loss0'] + aux['ratio_loss1']
                    loss = loss + 0.0 * ratio_loss  # ratio loss weight (0.0 disables it)
        else:
            outputs, aux = model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
            loss = criterion(outputs, batch_y)
            # Add the DC (dynamic chunking) ratio loss
            if aux is not None and 'ratio_loss0' in aux and 'ratio_loss1' in aux:
                ratio_loss = aux['ratio_loss0'] + aux['ratio_loss1']
                loss = loss + 0.0 * ratio_loss  # ratio loss weight (0.0 disables it)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        batch_count += 1
        if i % 100 == 0:
            print(f'Batch {i}, Loss: {loss.item():.6f}')
    return total_loss / batch_count
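
# For reference, a standard torch.cuda.amp training step also wraps
# backward()/step() in a GradScaler; the AMP branch above skips this, which is
# harmless here because use_amp defaults to False. A minimal sketch of the
# usual pattern, assuming the same model/criterion/optimizer names as above:
#
#     scaler = torch.cuda.amp.GradScaler()
#     with torch.cuda.amp.autocast():
#         outputs, aux = model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
#         loss = criterion(outputs, batch_y)
#     scaler.scale(loss).backward()
#     scaler.step(optimizer)
#     scaler.update()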

def validate(model, val_loader, criterion, device):
    """Evaluate the model on the validation set."""
    model.eval()
    total_loss = 0.0
    batch_count = 0
    with torch.no_grad():
        for batch_x, batch_y, batch_x_mark, batch_y_mark in val_loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            batch_x_mark = batch_x_mark.to(device)
            batch_y_mark = batch_y_mark.to(device)
            dec_inp = torch.zeros_like(batch_y).to(device)
            outputs, _ = model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
            loss = criterion(outputs, batch_y)
            total_loss += loss.item()
            batch_count += 1
    return total_loss / batch_count

def test_model(model, test_loader, device, save_path):
    """Test the model and visualize the results."""
    model.eval()
    predictions = []
    ground_truths = []
    with torch.no_grad():
        for batch_x, batch_y, batch_x_mark, batch_y_mark in test_loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            batch_x_mark = batch_x_mark.to(device)
            batch_y_mark = batch_y_mark.to(device)
            dec_inp = torch.zeros_like(batch_y).to(device)
            outputs, _ = model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
            predictions.append(outputs.cpu().numpy())
            ground_truths.append(batch_y.cpu().numpy())
    predictions = np.concatenate(predictions, axis=0)
    ground_truths = np.concatenate(ground_truths, axis=0)
    # Compute metrics
    mse = mean_squared_error(ground_truths.reshape(-1), predictions.reshape(-1))
    mae = mean_absolute_error(ground_truths.reshape(-1), predictions.reshape(-1))
    print(f"Test results - MSE: {mse:.6f}, MAE: {mae:.6f}")
    # Visualize the first few samples
    plt.figure(figsize=(15, 10))
    for i in range(min(4, len(predictions))):
        for ch in range(2):  # two channels
            plt.subplot(4, 2, i * 2 + ch + 1)
            plt.plot(ground_truths[i, :, ch], label='Ground Truth', color='blue')
            plt.plot(predictions[i, :, ch], label='Prediction', color='red')
            plt.title(f'Sample {i+1}, Channel {ch+1}')
            plt.legend()
    plt.tight_layout()
    plt.savefig(os.path.join(save_path, 'predictions.png'))
    print(f"Prediction plots saved to: {save_path}/predictions.png")
    return mse, mae

def main():
    # Configuration
    config = Config()
    # Create the results directory
    results_dir = './results/dc_patchtst_sine_wave'
    ensure_dir(results_dir)
    print(f"Using device: {config.device}")
    print(f"Model config: seq_len={config.seq_len}, pred_len={config.pred_len}")
    # Load the data
    data_path = './data/sine_wave/'
    train_dataset = SineWaveDataset(data_path, config.seq_len, config.pred_len, 'train')
    val_dataset = SineWaveDataset(data_path, config.seq_len, config.pred_len, 'val')
    test_dataset = SineWaveDataset(data_path, config.seq_len, config.pred_len, 'test')
    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False, num_workers=4)
    test_loader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False, num_workers=4)
    # Build the model
    model = Model(config).to(config.device)
    print(f"Number of model parameters: {sum(p.numel() for p in model.parameters())}")
    # Optimizer and loss function
    optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate, weight_decay=1e-4)
    criterion = nn.MSELoss()
    # Training loop
    best_val_loss = float('inf')
    patience_counter = 0
    train_losses = []
    val_losses = []
    print("Starting training...")
    start_time = time.time()
    for epoch in range(config.train_epochs):
        epoch_start = time.time()
        # Train
        train_loss = train_epoch(model, train_loader, criterion, optimizer, config.device, config.use_amp)
        # Validate
        val_loss = validate(model, val_loader, criterion, config.device)
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        epoch_time = time.time() - epoch_start
        print(f'Epoch {epoch+1:2d}/{config.train_epochs} | '
              f'Train Loss: {train_loss:.6f} | '
              f'Val Loss: {val_loss:.6f} | '
              f'Time: {epoch_time:.2f}s')
        # Early-stopping check
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            # Save the best model
            torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'config': config,
                'epoch': epoch,
                'val_loss': val_loss
            }, os.path.join(results_dir, 'best_model.pth'))
            print(f'  -> Saved best model (val_loss: {val_loss:.6f})')
        else:
            patience_counter += 1
            if patience_counter >= config.patience:
                print(f'Early stopping triggered! Best validation loss: {best_val_loss:.6f}')
                break
    total_time = time.time() - start_time
    print(f'\nTraining finished! Total time: {total_time/60:.2f} minutes')
    # Load the best model for testing
    checkpoint = torch.load(os.path.join(results_dir, 'best_model.pth'))
    model.load_state_dict(checkpoint['model_state_dict'])
    print("\nTesting the best model...")
    test_mse, test_mae = test_model(model, test_loader, config.device, results_dir)
    # Training history (collected here; not written to disk)
    history = {
        'train_losses': train_losses,
        'val_losses': val_losses,
        'test_mse': test_mse,
        'test_mae': test_mae
    }
    # Plot the training curves
    plt.figure(figsize=(12, 4))
    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.title('Training History')
    plt.subplot(1, 2, 2)
    plt.bar(['MSE', 'MAE'], [test_mse, test_mae])
    plt.title('Test Metrics')
    plt.ylabel('Error')
    plt.tight_layout()
    plt.savefig(os.path.join(results_dir, 'training_history.png'))
    print(f"\nResults saved in: {results_dir}")
    print(f"Best model: {results_dir}/best_model.pth")
    print(f"Training history: {results_dir}/training_history.png")
    print(f"Prediction plots: {results_dir}/predictions.png")


if __name__ == "__main__":
    main()
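
# Typical invocation (assuming ./data/sine_wave/train.csv, val.csv and
# test.csv already exist and models/DC_PatchTST.py is importable from the
# working directory):
#
#     python train_dc_patchtst.py
#
# Checkpoints and plots are written under ./results/dc_patchtst_sine_wave/.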