# File: TSlib/test_dc_patchtst.py (Python, 313 lines, 12 KiB)

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader
import os
import matplotlib.pyplot as plt
from models.DC_PatchTST import Model
class Config:
    """Default hyper-parameters for the DC_PatchTST long-term forecasting setup.

    All settings are plain instance attributes so they can be read as
    ``config.seq_len``, ``config.d_model`` and so on.

    NOTE(review): ``load_model`` reads a pickled ``config`` object from the
    checkpoint; presumably this class has to stay defined here (as a plain
    class with these attribute names) so that object can be unpickled --
    confirm against the training script.
    """

    def __init__(self):
        defaults = {
            # Basic configuration
            'task_name': 'long_term_forecast',
            'model': 'DC_PatchTST',
            # Data configuration
            'seq_len': 96,       # input sequence length
            'pred_len': 24,      # prediction sequence length
            'label_len': 48,     # label length
            'enc_in': 2,         # input feature dimension (dual channel)
            'dec_in': 2,         # decoder input dimension
            'c_out': 2,          # output dimension
            # Model configuration
            'd_model': 128,      # model dimension
            'n_heads': 8,        # number of attention heads
            'e_layers': 2,       # number of encoder layers
            'd_layers': 1,       # number of decoder layers
            'd_ff': 256,         # feed-forward dimension
            'factor': 1,         # attention factor
            'dropout': 0.1,      # dropout rate
            'activation': 'gelu',
            # Training configuration
            'batch_size': 32,
            'learning_rate': 0.001,
            'train_epochs': 50,
            'patience': 5,
            # Other configuration
            'use_amp': False,
            'num_class': 0,
        }
        for attr_name, attr_value in defaults.items():
            setattr(self, attr_name, attr_value)

        # GPU configuration: prefer CUDA whenever torch reports it as available.
        cuda_available = torch.cuda.is_available()
        self.use_gpu = cuda_available
        self.device = torch.device('cuda' if cuda_available else 'cpu')
class SineWaveDataset(Dataset):
    """Sliding-window dataset built from a two-channel CSV split.

    The constructor reads ``train.csv``, ``val.csv`` or ``test.csv`` from
    ``data_path`` and keeps the ``channel1`` / ``channel2`` columns as a
    float32 array. Each item is an input window of ``seq_len`` rows followed
    directly by a target window of ``pred_len`` rows.
    """

    def __init__(self, data_path, seq_len=96, pred_len=24, mode='test'):
        """Load the split selected by ``mode`` and count the available windows.

        Args:
            data_path: Directory containing the split CSV files.
            seq_len: Number of rows in each input window.
            pred_len: Number of rows in each target window.
            mode: ``'train'`` or ``'val'``; any other value selects the test split.
        """
        self.seq_len = seq_len
        self.pred_len = pred_len
        self.mode = mode

        # Pick the file first, then read it once.
        if mode == 'train':
            csv_name = 'train.csv'
        elif mode == 'val':
            csv_name = 'val.csv'
        else:  # test
            csv_name = 'test.csv'
        frame = pd.read_csv(os.path.join(data_path, csv_name))

        # Only the two signal columns are kept; any other column is ignored.
        self.data = frame[['channel1', 'channel2']].values.astype(np.float32)

        # A window needs seq_len + pred_len consecutive rows; clamp at zero
        # when the file is too short to hold even one.
        self.total_len = len(self.data)
        window_count = self.total_len - seq_len - pred_len + 1
        self.samples_num = window_count if window_count > 0 else 0
        print(f"{mode} dataset: {self.total_len} records, {self.samples_num} samples")

    def __len__(self):
        """Return the number of sliding windows available."""
        return self.samples_num

    def __getitem__(self, idx):
        """Return ``(seq_x, seq_y, seq_x_mark, seq_y_mark, idx)`` for window ``idx``.

        ``seq_x`` is ``data[idx : idx + seq_len]`` and ``seq_y`` is the
        ``pred_len`` rows that follow it. The marks are float32 column vectors
        holding the positions ``0..len-1`` inside each window.
        """
        split_at = idx + self.seq_len
        history = self.data[idx:split_at]                       # (seq_len, n_vars)
        target = self.data[split_at:split_at + self.pred_len]   # (pred_len, n_vars)

        # Time marks are plain position indices within each window.
        history_mark = np.arange(self.seq_len).reshape(-1, 1).astype(np.float32)
        target_mark = np.arange(self.pred_len).reshape(-1, 1).astype(np.float32)
        return history, target, history_mark, target_mark, idx
def load_model(model_path):
    """Restore a trained model from a checkpoint file.

    The checkpoint is expected to be a mapping with the keys ``'config'``,
    ``'model_state_dict'``, ``'epoch'`` and ``'val_loss'``.

    Args:
        model_path: Path to the checkpoint file.

    Returns:
        A ``(model, config, device)`` tuple: the model in eval mode on the
        chosen device, the config object stored in the checkpoint, and the
        device (CUDA when available, otherwise CPU).
    """
    # weights_only=False runs the full pickle loader so the stored config
    # object can be restored; only use this on checkpoints you trust.
    ckpt = torch.load(model_path, map_location='cpu', weights_only=False)
    config = ckpt['config']

    target_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Build the network, restore its weights on CPU, then move it over.
    net = Model(config)
    net.load_state_dict(ckpt['model_state_dict'])
    net = net.to(target_device)
    net.eval()

    print(f"Model loaded successfully - Epoch: {ckpt['epoch']}, Val Loss: {ckpt['val_loss']:.6f}")
    print(f"Using device: {target_device}")
    return net, config, target_device
def _stage0_boundary_masks(aux, batch_size, n_vars):
    """Return the stage-0 boundary mask as a (batch, n_vars, L) array, or None if absent."""
    if aux is None or 'stage0' not in aux:
        return None
    # NOTE(review): the mask is assumed to arrive flattened as (B * n_vars, L)
    # with samples as the outer index -- confirm against the model.
    flat_mask = aux['stage0']['boundary_mask'].cpu().numpy()
    return flat_mask.reshape(batch_size, n_vars, flat_mask.shape[1])


def _plot_channel(ax, input_seq, true_pred, pred_seq, channel, channel_mask, title):
    """Draw one channel: input history, ground truth, forecast and chunk markers."""
    seq_len = len(input_seq)
    input_time = np.arange(seq_len)
    pred_time = np.arange(seq_len, seq_len + len(true_pred))

    ax.plot(input_time, input_seq[:, channel], 'b-', label='Input Sequence', linewidth=1.5)
    ax.plot(pred_time, true_pred[:, channel], 'g-', label='Ground Truth', linewidth=2)
    ax.plot(pred_time, pred_seq[:, channel], 'r--', label='Prediction', linewidth=2)

    if channel_mask is not None:
        chunk_points = np.where(channel_mask)[0]
        for point in chunk_points:
            if point < seq_len:  # only mark points inside the input window
                ax.axvline(x=point, color='orange', linestyle=':', alpha=0.7, linewidth=1)
        if len(chunk_points) > 0:
            # Legend entry via an empty-data proxy line. A dummy axvline at
            # x=-1 would be drawn inside the visible axes and look like a
            # real chunk point.
            ax.plot([], [], color='orange', linestyle=':', alpha=0.7,
                    linewidth=1, label='Chunk Points (Stage 0)')

    ax.set_title(title)
    ax.set_xlabel('Time Steps')
    ax.set_ylabel('Value')
    ax.legend()
    ax.grid(True, alpha=0.3)

    # Separator between the input window and the forecast horizon.
    ax.axvline(x=seq_len - 0.5, color='black', linestyle='-', alpha=0.5, linewidth=1)
    # x is in data coordinates, y is an axes fraction, so the label stays near
    # the top of the plot whatever the sign or scale of the y-limits.
    ax.text(seq_len - 0.5, 0.95, 'Prediction Start',
            rotation=90, verticalalignment='top', fontsize=8,
            transform=ax.get_xaxis_transform())


def visualize_predictions_with_chunks(model, test_loader, device, save_path, num_samples=5):
    """Plot forecasts for the first ``num_samples`` samples and mark stage-0 chunk points.

    For each sample one figure is written to ``<save_path>/visualizations/``
    as ``sample_<k>_with_chunks.png``, with one subplot per channel showing
    the input window, the ground truth and the model prediction. When the
    model's auxiliary output contains ``aux['stage0']['boundary_mask']``, the
    non-zero mask positions that fall inside the input window are drawn as
    vertical dotted lines and per-channel chunk statistics are printed.

    Args:
        model: Callable returning ``(outputs, aux)`` for
            ``(batch_x, batch_x_mark, dec_inp, batch_y_mark)``.
        test_loader: Iterable of
            ``(batch_x, batch_y, batch_x_mark, batch_y_mark, idx)`` batches.
        device: Device the batches are moved to before inference.
        save_path: Base output directory; ``visualizations`` is created inside it.
        num_samples: Maximum number of samples to plot.
    """
    model.eval()

    vis_dir = os.path.join(save_path, 'visualizations')
    os.makedirs(vis_dir, exist_ok=True)

    sample_count = 0
    with torch.no_grad():
        for batch_x, batch_y, batch_x_mark, batch_y_mark, _ in test_loader:
            if sample_count >= num_samples:
                break

            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            batch_x_mark = batch_x_mark.to(device)
            batch_y_mark = batch_y_mark.to(device)

            # Decoder input is all zeros with the target's shape (and device).
            dec_inp = torch.zeros_like(batch_y)
            outputs, aux = model(batch_x, batch_x_mark, dec_inp, batch_y_mark)

            batch_x_cpu = batch_x.cpu().numpy()
            batch_y_cpu = batch_y.cpu().numpy()
            outputs_cpu = outputs.cpu().numpy()

            # Convert and reshape the mask once per batch, not once per sample.
            batch_masks = _stage0_boundary_masks(aux, batch_x.size(0), batch_x.size(2))

            remaining = num_samples - sample_count
            for i in range(min(batch_x.size(0), remaining)):
                input_seq = batch_x_cpu[i]   # (seq_len, n_channels)
                true_pred = batch_y_cpu[i]   # (pred_len, n_channels)
                pred_seq = outputs_cpu[i]    # (pred_len, n_channels)
                sample_mask = None if batch_masks is None else batch_masks[i]

                # One subplot per channel; squeeze=False keeps `axes` 2-D even
                # for a single channel.
                n_channels = input_seq.shape[1]
                fig, axes = plt.subplots(n_channels, 1, figsize=(15, 5 * n_channels),
                                         squeeze=False)
                for ch in range(n_channels):
                    _plot_channel(
                        axes[ch, 0], input_seq, true_pred, pred_seq, ch,
                        None if sample_mask is None else sample_mask[ch],
                        f'Sample {sample_count + 1} - Channel {ch + 1}',
                    )
                fig.tight_layout()

                sample_filename = f'sample_{sample_count + 1}_with_chunks.png'
                sample_path = os.path.join(vis_dir, sample_filename)
                fig.savefig(sample_path, dpi=300, bbox_inches='tight')
                plt.close(fig)  # release this specific figure
                print(f"Sample {sample_count + 1} visualization saved to: {sample_path}")

                if sample_mask is not None:
                    for ch in range(n_channels):
                        chunk_count = np.sum(sample_mask[ch])
                        chunk_ratio = chunk_count / len(sample_mask[ch])
                        print(f" Channel {ch + 1}: {chunk_count} chunk points ({chunk_ratio:.2%} of sequence)")

                sample_count += 1
def evaluate_model(model, test_loader, device):
    """Compute MSE, MAE and RMSE of the model over every element in ``test_loader``.

    Errors are summed over all predicted elements and divided by the total
    element count, so the result does not depend on how the data is batched
    (a smaller final batch is not over-weighted) and ``rmse == sqrt(mse)``.

    Args:
        model: Callable returning ``(outputs, aux)`` for
            ``(batch_x, batch_x_mark, dec_inp, batch_y_mark)``.
        test_loader: Iterable of
            ``(batch_x, batch_y, batch_x_mark, batch_y_mark, idx)`` batches.
        device: Device the batches are moved to before inference.

    Returns:
        ``(mse, mae, rmse)`` as Python floats.

    Raises:
        ValueError: If ``test_loader`` yields no elements to evaluate.
    """
    model.eval()

    squared_error_sum = 0.0
    absolute_error_sum = 0.0
    element_count = 0

    with torch.no_grad():
        for batch_x, batch_y, batch_x_mark, batch_y_mark, _ in test_loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            batch_x_mark = batch_x_mark.to(device)
            batch_y_mark = batch_y_mark.to(device)

            # Decoder input is all zeros with the target's shape (and device).
            dec_inp = torch.zeros_like(batch_y)
            outputs, _aux = model(batch_x, batch_x_mark, dec_inp, batch_y_mark)

            # Accumulate sums rather than per-batch means so every element
            # carries the same weight regardless of batch size.
            error = outputs - batch_y
            squared_error_sum += error.pow(2).sum().item()
            absolute_error_sum += error.abs().sum().item()
            element_count += error.numel()

    if element_count == 0:
        raise ValueError("test_loader produced no samples to evaluate")

    avg_mse = squared_error_sum / element_count
    avg_mae = absolute_error_sum / element_count
    rmse = float(np.sqrt(avg_mse))

    print("\nTest set evaluation results:")
    print(f"MSE: {avg_mse:.6f}")
    print(f"MAE: {avg_mae:.6f}")
    print(f"RMSE: {rmse:.6f}")
    return avg_mse, avg_mae, rmse
def main():
    """Evaluate the saved DC_PatchTST checkpoint on the test split and plot samples.

    Steps: check that the checkpoint exists, restore the model, build the
    test loader, print the key configuration values, report MSE/MAE/RMSE,
    then write prediction plots with chunk markers. Returns ``None``; when
    the checkpoint file is missing it prints a hint and returns early.
    """
    model_path = './results/dc_patchtst_sine_wave/best_model.pth'
    data_path = './data/sine_wave/'
    save_path = './results/dc_patchtst_sine_wave/'

    if not os.path.exists(model_path):
        print(f"Error: Model file not found {model_path}")
        print("Please run the training script train_dc_patchtst.py first")
        return

    model, config, device = load_model(model_path)

    test_dataset = SineWaveDataset(data_path, config.seq_len, config.pred_len, 'test')
    # batch_size=1 keeps one sample per batch, which simplifies visualization.
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0)

    print("\nConfiguration info:")
    print(f"Sequence length: {config.seq_len}")
    print(f"Prediction length: {config.pred_len}")
    print(f"Feature dimension: {config.enc_in}")

    # The metrics are printed by evaluate_model; the return value is not needed here.
    evaluate_model(model, test_loader, device)

    print("\nGenerating visualization results...")
    visualize_predictions_with_chunks(model, test_loader, device, save_path, num_samples=5)

    print(f"\nAll results saved to: {save_path}")
    # os.path.join avoids the doubled slash that plain concatenation produced,
    # since save_path already ends with a separator; the trailing '' keeps the
    # final slash.
    vis_dir = os.path.join(save_path, 'visualizations', '')
    print(f"Visualization files located at: {vis_dir}")


# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()