#!/usr/bin/env python3
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader
import os
import argparse
from models.DC_PatchTST import Model
import time
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error, mean_absolute_error


def ensure_dir(path):
    """Create the results directory if it does not already exist."""
    if not os.path.exists(path):
        os.makedirs(path)


class SineWaveDataset(Dataset):
    """Sine wave dataset."""

    def __init__(self, data_path, seq_len=96, pred_len=24, mode='train'):
        self.seq_len = seq_len
        self.pred_len = pred_len
        self.mode = mode

        # Load the requested data split
        if mode == 'train':
            df = pd.read_csv(os.path.join(data_path, 'train.csv'))
        elif mode == 'val':
            df = pd.read_csv(os.path.join(data_path, 'val.csv'))
        else:  # test
            df = pd.read_csv(os.path.join(data_path, 'test.csv'))

        # Extract the feature columns (everything except the timestamp)
        self.data = df[['channel1', 'channel2']].values.astype(np.float32)

        # Number of usable samples
        self.total_len = len(self.data)
        self.samples_num = max(0, self.total_len - seq_len - pred_len + 1)

        print(f"{mode} dataset: {self.total_len} records, {self.samples_num} samples")

    def __len__(self):
        return self.samples_num

    def __getitem__(self, idx):
        # Input sequence
        s_begin = idx
        s_end = s_begin + self.seq_len
        # Prediction target
        r_begin = s_end
        r_end = r_begin + self.pred_len

        seq_x = self.data[s_begin:s_end]  # (seq_len, n_vars)
        seq_y = self.data[r_begin:r_end]  # (pred_len, n_vars)

        # Time marks (simple positional encoding)
        seq_x_mark = np.arange(self.seq_len).reshape(-1, 1).astype(np.float32)
        seq_y_mark = np.arange(self.pred_len).reshape(-1, 1).astype(np.float32)

        return seq_x, seq_y, seq_x_mark, seq_y_mark


class Config:
    """Configuration class."""

    def __init__(self):
        # Basic settings
        self.task_name = 'long_term_forecast'
        self.model = 'DC_PatchTST'

        # Data settings
        self.seq_len = 96    # input sequence length
        self.pred_len = 24   # prediction length
        self.label_len = 48  # label length
        self.enc_in = 2      # number of input features (two channels)
        self.dec_in = 2      # decoder input dimension
        self.c_out = 2       # output dimension

        # Model settings
        self.d_model = 128   # model dimension
        self.n_heads = 8     # number of attention heads
        self.e_layers = 2    # number of encoder layers
        self.d_layers = 1    # number of decoder layers
        self.d_ff = 256      # feed-forward dimension
        self.factor = 1      # attention factor
        self.dropout = 0     # dropout rate
        self.activation = 'gelu'

        # Training settings
        self.batch_size = 1024
        self.learning_rate = 0.001
        self.train_epochs = 50
        self.patience = 5

        # Misc settings
        self.use_amp = False
        self.num_class = 0

        # GPU settings
        self.use_gpu = torch.cuda.is_available()
        self.device = torch.device('cuda' if self.use_gpu else 'cpu')


def train_epoch(model, train_loader, criterion, optimizer, device, use_amp=False):
    """Train for one epoch."""
    model.train()
    total_loss = 0.0
    batch_count = 0

    for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(train_loader):
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)
        batch_x_mark = batch_x_mark.to(device)
        batch_y_mark = batch_y_mark.to(device)

        # Build the decoder input
        dec_inp = torch.zeros_like(batch_y).to(device)

        optimizer.zero_grad()

        if use_amp:
            with torch.cuda.amp.autocast():
                outputs, aux = model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
                loss = criterion(outputs, batch_y)
                # Add the DC ratio loss
                if aux is not None and 'ratio_loss0' in aux and 'ratio_loss1' in aux:
                    ratio_loss = aux['ratio_loss0'] + aux['ratio_loss1']
                    loss = loss + 0.0 * ratio_loss  # ratio loss weight
        else:
            outputs, aux = model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
            loss = criterion(outputs, batch_y)
            # Add the DC ratio loss
            if aux is not None and 'ratio_loss0' in aux and 'ratio_loss1' in aux:
                ratio_loss = aux['ratio_loss0'] + aux['ratio_loss1']
                loss = loss + 0.0 * ratio_loss  # ratio loss weight

        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        batch_count += 1

        if i % 100 == 0:
            print(f'Batch {i}, Loss: {loss.item():.6f}')

    return total_loss / batch_count


def validate(model, val_loader, criterion, device):
    """Validate the model."""
    model.eval()
    total_loss = 0.0
    batch_count = 0

    with torch.no_grad():
        for batch_x, batch_y, batch_x_mark, batch_y_mark in val_loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            batch_x_mark = batch_x_mark.to(device)
            batch_y_mark = batch_y_mark.to(device)

            dec_inp = torch.zeros_like(batch_y).to(device)
            outputs, _ = model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
            loss = criterion(outputs, batch_y)

            total_loss += loss.item()
            batch_count += 1

    return total_loss / batch_count


def test_model(model, test_loader, device, save_path):
    """Test the model and visualize the results."""
    model.eval()
    predictions = []
    ground_truths = []

    with torch.no_grad():
        for batch_x, batch_y, batch_x_mark, batch_y_mark in test_loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            batch_x_mark = batch_x_mark.to(device)
            batch_y_mark = batch_y_mark.to(device)

            dec_inp = torch.zeros_like(batch_y).to(device)
            outputs, _ = model(batch_x, batch_x_mark, dec_inp, batch_y_mark)

            predictions.append(outputs.cpu().numpy())
            ground_truths.append(batch_y.cpu().numpy())

    predictions = np.concatenate(predictions, axis=0)
    ground_truths = np.concatenate(ground_truths, axis=0)

    # Compute metrics
    mse = mean_squared_error(ground_truths.reshape(-1), predictions.reshape(-1))
    mae = mean_absolute_error(ground_truths.reshape(-1), predictions.reshape(-1))

    print(f"Test results - MSE: {mse:.6f}, MAE: {mae:.6f}")

    # Visualize the first few samples
    plt.figure(figsize=(15, 10))
    for i in range(min(4, len(predictions))):
        for ch in range(2):  # two channels
            plt.subplot(4, 2, i * 2 + ch + 1)
            plt.plot(ground_truths[i, :, ch], label='Ground Truth', color='blue')
            plt.plot(predictions[i, :, ch], label='Prediction', color='red')
            plt.title(f'Sample {i+1}, Channel {ch+1}')
            plt.legend()

    plt.tight_layout()
    plt.savefig(os.path.join(save_path, 'predictions.png'))
    print(f"Prediction visualization saved to: {save_path}/predictions.png")

    return mse, mae


def main():
    # Configuration
    config = Config()

    # Create the results directory
    results_dir = './results/dc_patchtst_sine_wave'
    ensure_dir(results_dir)

    print(f"Using device: {config.device}")
    print(f"Model config: seq_len={config.seq_len}, pred_len={config.pred_len}")

    # Load the data
    data_path = './data/sine_wave/'
    train_dataset = SineWaveDataset(data_path, config.seq_len, config.pred_len, 'train')
    val_dataset = SineWaveDataset(data_path, config.seq_len, config.pred_len, 'val')
    test_dataset = SineWaveDataset(data_path, config.seq_len, config.pred_len, 'test')

    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False, num_workers=4)
    test_loader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False, num_workers=4)

    # Build the model
    model = Model(config).to(config.device)
    print(f"Number of model parameters: {sum(p.numel() for p in model.parameters())}")

    # Optimizer and loss function
    optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate, weight_decay=1e-4)
    criterion = nn.MSELoss()

    # Training loop
    best_val_loss = float('inf')
    patience_counter = 0
    train_losses = []
    val_losses = []

    print("Starting training...")
    start_time = time.time()

    for epoch in range(config.train_epochs):
        epoch_start = time.time()

        # Train
        train_loss = train_epoch(model, train_loader, criterion, optimizer, config.device, config.use_amp)

        # Validate
        val_loss = validate(model, val_loader, criterion, config.device)

        train_losses.append(train_loss)
        val_losses.append(val_loss)

        epoch_time = time.time() - epoch_start
        print(f'Epoch {epoch+1:2d}/{config.train_epochs} | '
              f'Train Loss: {train_loss:.6f} | '
              f'Val Loss: {val_loss:.6f} | '
              f'Time: {epoch_time:.2f}s')

        # Early stopping check
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            # Save the best model
            torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'config': config,
                'epoch': epoch,
                'val_loss': val_loss
            }, os.path.join(results_dir, 'best_model.pth'))
            print(f' -> Saved best model (val_loss: {val_loss:.6f})')
        else:
            patience_counter += 1
            if patience_counter >= config.patience:
                print(f'Early stopping triggered! Best validation loss: {best_val_loss:.6f}')
                break

    total_time = time.time() - start_time
    print(f'\nTraining complete! Total time: {total_time/60:.2f} minutes')

    # Load the best model for testing
    checkpoint = torch.load(os.path.join(results_dir, 'best_model.pth'))
    model.load_state_dict(checkpoint['model_state_dict'])

    print("\nTesting the best model...")
    test_mse, test_mae = test_model(model, test_loader, config.device, results_dir)

    # Collect the training history
    history = {
        'train_losses': train_losses,
        'val_losses': val_losses,
        'test_mse': test_mse,
        'test_mae': test_mae
    }

    # Plot the training curves
    plt.figure(figsize=(12, 4))
    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.title('Training History')

    plt.subplot(1, 2, 2)
    plt.bar(['MSE', 'MAE'], [test_mse, test_mae])
    plt.title('Test Metrics')
    plt.ylabel('Error')

    plt.tight_layout()
    plt.savefig(os.path.join(results_dir, 'training_history.png'))

    print(f"\nResults saved in: {results_dir}")
    print(f"Best model: {results_dir}/best_model.pth")
    print(f"Training history: {results_dir}/training_history.png")
    print(f"Prediction visualization: {results_dir}/predictions.png")


if __name__ == "__main__":
    main()
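

# ---------------------------------------------------------------------------
# Note (minimal sketch, not part of the training pipeline above): this script
# expects ./data/sine_wave/{train,val,test}.csv with columns `timestamp`,
# `channel1`, `channel2`, but does not create them. The helper below is an
# assumed generator for such files; the signal periods, phase, series length,
# and 70/15/15 chronological split are illustrative assumptions, not taken
# from the original code. It is never called here; run it manually if needed.
# ---------------------------------------------------------------------------
def generate_sine_wave_data(out_dir='./data/sine_wave', n_points=20000):
    """Write assumed two-channel sine-wave CSV splits in the expected schema."""
    ensure_dir(out_dir)
    t = np.arange(n_points)
    df = pd.DataFrame({
        'timestamp': t,
        'channel1': np.sin(2 * np.pi * t / 96.0),        # assumed period of 96 steps
        'channel2': np.sin(2 * np.pi * t / 48.0 + 0.5),  # assumed period and phase
    })
    n_train = int(0.7 * n_points)   # assumed split sizes
    n_val = int(0.15 * n_points)
    df.iloc[:n_train].to_csv(os.path.join(out_dir, 'train.csv'), index=False)
    df.iloc[n_train:n_train + n_val].to_csv(os.path.join(out_dir, 'val.csv'), index=False)
    df.iloc[n_train + n_val:].to_csv(os.path.join(out_dir, 'test.csv'), index=False)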