313 lines
12 KiB
Python
313 lines
12 KiB
Python
#!/usr/bin/env python3
|
|
# -*- coding: utf-8 -*-
|
|
import torch
|
|
import torch.nn as nn
|
|
import pandas as pd
|
|
import numpy as np
|
|
from torch.utils.data import Dataset, DataLoader
|
|
import os
|
|
import matplotlib.pyplot as plt
|
|
from models.DC_PatchTST import Model
|
|
|
|
class Config:
    """Default hyper-parameters for DC_PatchTST on the dual-channel sine-wave data.

    Every setting is a plain instance attribute assigned in ``__init__``;
    nothing is computed lazily, so values can be inspected or overridden
    freely after construction.
    """

    def __init__(self):
        # Probe CUDA once and reuse the result for both device attributes.
        has_cuda = torch.cuda.is_available()

        # What is being run.
        self.task_name = 'long_term_forecast'
        self.model = 'DC_PatchTST'

        # Window geometry and channel counts.
        self.seq_len = 96    # length of the input (look-back) window
        self.pred_len = 24   # length of the forecast window
        self.label_len = 48  # label length
        self.enc_in = 2      # encoder input channels (dual channel)
        self.dec_in = 2      # decoder input channels
        self.c_out = 2       # output channels

        # Network size and regularisation.
        self.d_model = 128   # model (embedding) width
        self.n_heads = 8     # attention heads
        self.e_layers = 2    # encoder layers
        self.d_layers = 1    # decoder layers
        self.d_ff = 256      # feed-forward width
        self.factor = 1      # attention factor
        self.dropout = 0.1   # dropout probability
        self.activation = 'gelu'

        # Optimisation schedule.
        self.batch_size = 32
        self.learning_rate = 0.001
        self.train_epochs = 50
        self.patience = 5

        # Miscellaneous switches.
        self.use_amp = False
        self.num_class = 0

        # Hardware.
        self.use_gpu = has_cuda
        self.device = torch.device('cuda' if has_cuda else 'cpu')
|
|
|
|
class SineWaveDataset(Dataset):
    """Sliding-window dataset over one split of the two-channel sine-wave data.

    Reads ``<data_path>/<mode>.csv`` and exposes every pair of consecutive
    windows taken from the ``channel1`` and ``channel2`` columns: an input
    window of ``seq_len`` rows followed immediately by a target window of
    ``pred_len`` rows.
    """

    # Split name -> CSV file name inside ``data_path``.
    _SPLIT_FILES = {'train': 'train.csv', 'val': 'val.csv', 'test': 'test.csv'}

    def __init__(self, data_path, seq_len=96, pred_len=24, mode='test'):
        """Load one data split into memory.

        Args:
            data_path: Directory containing ``train.csv``, ``val.csv`` and
                ``test.csv``.
            seq_len: Number of rows in each input window.
            pred_len: Number of rows in each target window.
            mode: Split to load: ``'train'``, ``'val'`` or ``'test'``.

        Raises:
            ValueError: If ``mode`` is not one of the known split names
                (an unknown name must not silently fall back to the test set).
        """
        if mode not in self._SPLIT_FILES:
            raise ValueError(
                f"Unknown mode {mode!r}; expected one of {sorted(self._SPLIT_FILES)}"
            )

        self.seq_len = seq_len
        self.pred_len = pred_len
        self.mode = mode

        df = pd.read_csv(os.path.join(data_path, self._SPLIT_FILES[mode]))

        # Only the two signal columns are features; any other column
        # (e.g. a timestamp) is ignored.
        self.data = df[['channel1', 'channel2']].values.astype(np.float32)

        # Number of complete (input, target) window pairs that fit in the split.
        self.total_len = len(self.data)
        self.samples_num = max(0, self.total_len - seq_len - pred_len + 1)

        print(f"{mode} dataset: {self.total_len} records, {self.samples_num} samples")

    def __len__(self):
        """Return the number of complete window pairs."""
        return self.samples_num

    def __getitem__(self, idx):
        """Return ``(seq_x, seq_y, seq_x_mark, seq_y_mark, idx)`` for window ``idx``.

        ``seq_x`` has shape ``(seq_len, 2)`` and ``seq_y`` ``(pred_len, 2)``.
        The marks are the positions ``0..len-1`` as float32 arrays of shape
        ``(len, 1)``. Negative indices count from the end, as for sequences.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
                Without this check, out-of-range indices would silently yield
                truncated or empty windows because NumPy slicing clips.
        """
        if idx < 0:
            idx += self.samples_num
        if not 0 <= idx < self.samples_num:
            raise IndexError(
                f"index out of range for dataset with {self.samples_num} samples"
            )

        # Input window.
        s_begin = idx
        s_end = s_begin + self.seq_len

        # Target window starts right after the input window.
        r_begin = s_end
        r_end = r_begin + self.pred_len

        seq_x = self.data[s_begin:s_end]  # (seq_len, n_vars)
        seq_y = self.data[r_begin:r_end]  # (pred_len, n_vars)

        # Time marks: simple positional indices for each window.
        seq_x_mark = np.arange(self.seq_len).reshape(-1, 1).astype(np.float32)
        seq_y_mark = np.arange(self.pred_len).reshape(-1, 1).astype(np.float32)

        return seq_x, seq_y, seq_x_mark, seq_y_mark, idx
|
|
|
|
def load_model(model_path, device=None):
    """Restore a DC_PatchTST model from a training checkpoint.

    The checkpoint must be a dict with ``'config'`` and ``'model_state_dict'``
    entries. ``'epoch'`` and ``'val_loss'`` are optional and only used for the
    status message.

    Args:
        model_path: Path to the ``.pth`` checkpoint file.
        device: Optional ``torch.device`` or device string for the model.
            Defaults to CUDA when available, otherwise CPU.

    Returns:
        ``(model, config, device)``: the model in eval mode on ``device``, the
        config object stored in the checkpoint, and the device actually used.
    """
    # SECURITY: weights_only=False unpickles arbitrary Python objects (the
    # stored config is not a plain tensor), which can execute code on load.
    # Only load checkpoints from trusted sources.
    checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)
    config = checkpoint['config']

    model = Model(config)
    model.load_state_dict(checkpoint['model_state_dict'])

    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    else:
        device = torch.device(device)
    model = model.to(device)
    model.eval()

    # Training metadata is informational only; a missing key must not abort loading.
    epoch = checkpoint.get('epoch', 'N/A')
    val_loss = checkpoint.get('val_loss')
    val_loss_text = f"{val_loss:.6f}" if val_loss is not None else 'N/A'
    print(f"Model loaded successfully - Epoch: {epoch}, Val Loss: {val_loss_text}")
    print(f"Using device: {device}")

    return model, config, device
|
|
|
|
def _stage0_boundary_masks(aux, batch_size, nvars):
    """Return the stage-0 boundary mask reshaped to ``(B, nvars, L)``, or ``None``.

    ``aux['stage0']['boundary_mask']`` is expected as ``(B * nvars, L)`` with
    the channels of one sample stored contiguously. TODO: confirm this layout
    against DC_PatchTST.
    """
    if aux is None or 'stage0' not in aux:
        return None
    mask = aux['stage0']['boundary_mask'].cpu().numpy()
    return mask.reshape(batch_size, nvars, mask.shape[1])


def _plot_sample(input_seq, true_pred, pred_seq, boundary_mask, sample_number, sample_path):
    """Draw one figure (one subplot per channel) for a sample and save it as PNG.

    ``input_seq`` is ``(seq_len, nvars)``; ``true_pred`` and ``pred_seq`` are
    ``(pred_len, nvars)``; ``boundary_mask`` is ``(nvars, L)`` or ``None``.
    """
    n_channels = input_seq.shape[1]
    seq_len = len(input_seq)
    input_time = np.arange(seq_len)
    pred_time = np.arange(seq_len, seq_len + len(true_pred))

    # squeeze=False keeps ``axes`` 2-D even when there is a single channel.
    fig, axes = plt.subplots(n_channels, 1, figsize=(15, 5 * n_channels), squeeze=False)
    try:
        for ch, ax in enumerate(axes[:, 0]):
            ax.plot(input_time, input_seq[:, ch], 'b-', label='Input Sequence', linewidth=1.5)
            ax.plot(pred_time, true_pred[:, ch], 'g-', label='Ground Truth', linewidth=2)
            ax.plot(pred_time, pred_seq[:, ch], 'r--', label='Prediction', linewidth=2)

            if boundary_mask is not None:
                chunk_points = np.where(boundary_mask[ch])[0]
                # Only positions inside the input window are drawn. This assumes
                # mask positions index input time steps. TODO: confirm.
                label = 'Chunk Points (Stage 0)'
                for point in chunk_points[chunk_points < seq_len]:
                    ax.axvline(x=point, color='orange', linestyle=':', alpha=0.7,
                               linewidth=1, label=label)
                    # Matplotlib omits '_'-prefixed labels from the legend, so
                    # only the first drawn line contributes a legend entry. No
                    # off-screen proxy line is needed (it would stretch the x-axis).
                    label = '_nolegend_'

            ax.set_title(f'Sample {sample_number} - Channel {ch + 1}')
            ax.set_xlabel('Time Steps')
            ax.set_ylabel('Value')
            ax.legend()
            ax.grid(True, alpha=0.3)

            # Separator between the input window and the forecast horizon.
            # The label's y is in axes coordinates, so it stays inside the plot
            # whatever the sign or magnitude of the data.
            boundary_x = seq_len - 0.5
            ax.axvline(x=boundary_x, color='black', linestyle='-', alpha=0.5, linewidth=1)
            ax.text(boundary_x, 0.95, 'Prediction Start', transform=ax.get_xaxis_transform(),
                    rotation=90, verticalalignment='top', fontsize=8)

        fig.tight_layout()
        fig.savefig(sample_path, dpi=300, bbox_inches='tight')
    finally:
        # Always release the figure, even if saving fails.
        plt.close(fig)


def _print_chunk_stats(boundary_mask):
    """Print how many positions each channel's stage-0 mask marks as chunk points."""
    for ch, channel_mask in enumerate(boundary_mask):
        chunk_count = np.sum(channel_mask)
        # Guard against an empty mask so the ratio is always defined.
        chunk_ratio = chunk_count / len(channel_mask) if len(channel_mask) else 0.0
        print(f"  Channel {ch + 1}: {chunk_count} chunk points ({chunk_ratio:.2%} of sequence)")


def visualize_predictions_with_chunks(model, test_loader, device, save_path, num_samples=5):
    """Plot forecasts for up to ``num_samples`` samples and mark stage-0 chunk points.

    For each sample, one PNG is written to
    ``<save_path>/visualizations/sample_<k>_with_chunks.png`` (``k`` starts at
    1). Each PNG has one subplot per channel showing the input window, the
    ground truth and the model prediction. When the model's auxiliary output
    has a ``'stage0'`` entry, its boundary-mask positions are drawn as dotted
    orange lines and per-channel chunk counts are printed.

    Args:
        model: Callable ``model(x, x_mark, dec_inp, y_mark) -> (outputs, aux)``.
        test_loader: Iterable of ``(batch_x, batch_y, batch_x_mark, batch_y_mark, idx)``
            with ``batch_x`` shaped ``(B, seq_len, nvars)``.
        device: Device the batches are moved to before inference.
        save_path: Root results directory; ``visualizations/`` is created inside it.
        num_samples: Maximum number of samples to plot.
    """
    model.eval()

    vis_dir = os.path.join(save_path, 'visualizations')
    os.makedirs(vis_dir, exist_ok=True)

    if num_samples <= 0:
        return

    sample_count = 0
    with torch.no_grad():
        for batch_x, batch_y, batch_x_mark, batch_y_mark, _ in test_loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            batch_x_mark = batch_x_mark.to(device)
            batch_y_mark = batch_y_mark.to(device)

            # All-zero decoder input with the target's shape (already on ``device``).
            dec_inp = torch.zeros_like(batch_y)
            outputs, aux = model(batch_x, batch_x_mark, dec_inp, batch_y_mark)

            inputs_np = batch_x.cpu().numpy()
            targets_np = batch_y.cpu().numpy()
            outputs_np = outputs.cpu().numpy()

            batch_size, _, nvars = batch_x.shape
            # Reshape the mask once per batch rather than once per sample.
            masks = _stage0_boundary_masks(aux, batch_size, nvars)

            for i in range(min(batch_size, num_samples - sample_count)):
                sample_number = sample_count + 1
                sample_mask = None if masks is None else masks[i]
                sample_path = os.path.join(vis_dir, f'sample_{sample_number}_with_chunks.png')

                _plot_sample(inputs_np[i], targets_np[i], outputs_np[i],
                             sample_mask, sample_number, sample_path)
                print(f"Sample {sample_number} visualization saved to: {sample_path}")

                if sample_mask is not None:
                    _print_chunk_stats(sample_mask)

                sample_count = sample_number

            # Stop before pulling another batch from the loader.
            if sample_count >= num_samples:
                break
|
|
|
|
def evaluate_model(model, test_loader, device):
    """Compute MSE, MAE and RMSE of ``model`` over every batch in ``test_loader``.

    Errors are accumulated element-wise over the whole loader. This keeps the
    three metrics mutually consistent (``rmse == sqrt(mse)``) and weights a
    smaller final batch by its actual size instead of as a full batch.

    Args:
        model: Callable ``model(x, x_mark, dec_inp, y_mark) -> (outputs, aux)``
            whose ``outputs`` matches the shape of the target batch.
        test_loader: Iterable of ``(batch_x, batch_y, batch_x_mark, batch_y_mark, idx)``.
        device: Device the batches are moved to before inference.

    Returns:
        ``(mse, mae, rmse)`` as floats. All three are ``nan`` when the loader
        yields no elements.
    """
    model.eval()
    sum_sq_err = 0.0
    sum_abs_err = 0.0
    n_elements = 0

    with torch.no_grad():
        for batch_x, batch_y, batch_x_mark, batch_y_mark, _ in test_loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            batch_x_mark = batch_x_mark.to(device)
            batch_y_mark = batch_y_mark.to(device)

            # All-zero decoder input with the target's shape (already on ``device``).
            dec_inp = torch.zeros_like(batch_y)
            outputs, _ = model(batch_x, batch_x_mark, dec_inp, batch_y_mark)

            # Sum in float64 so precision does not degrade over many batches.
            err = (outputs - batch_y).double()
            sum_sq_err += torch.sum(err ** 2).item()
            sum_abs_err += torch.sum(torch.abs(err)).item()
            n_elements += err.numel()

    if n_elements == 0:
        # Metrics over zero samples are undefined; report NaN instead of crashing.
        print("\nTest set evaluation results: no samples to evaluate")
        nan = float('nan')
        return nan, nan, nan

    avg_mse = sum_sq_err / n_elements
    avg_mae = sum_abs_err / n_elements
    rmse = float(np.sqrt(avg_mse))

    print("\nTest set evaluation results:")
    print(f"MSE: {avg_mse:.6f}")
    print(f"MAE: {avg_mae:.6f}")
    print(f"RMSE: {rmse:.6f}")

    return avg_mse, avg_mae, rmse
|
|
|
|
def main(model_path='./results/dc_patchtst_sine_wave/best_model.pth',
         data_path='./data/sine_wave/',
         save_path='./results/dc_patchtst_sine_wave/'):
    """Evaluate a trained DC_PatchTST checkpoint and save chunk-annotated plots.

    Args:
        model_path: Checkpoint written by the training script.
        data_path: Directory holding the sine-wave CSV splits.
        save_path: Results directory; plots go to ``<save_path>/visualizations/``.
    """
    if not os.path.isfile(model_path):
        print(f"Error: Model file not found {model_path}")
        print("Please run the training script train_dc_patchtst.py first")
        return

    model, config, device = load_model(model_path)

    test_dataset = SineWaveDataset(data_path, config.seq_len, config.pred_len, 'test')
    # batch_size=1 keeps one sample per batch, which simplifies visualization.
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0)

    print("\nConfiguration info:")
    print(f"Sequence length: {config.seq_len}")
    print(f"Prediction length: {config.pred_len}")
    print(f"Feature dimension: {config.enc_in}")

    # evaluate_model prints the metrics itself; the returned values are not needed here.
    evaluate_model(model, test_loader, device)

    print("\nGenerating visualization results...")
    visualize_predictions_with_chunks(model, test_loader, device, save_path, num_samples=5)

    vis_dir = os.path.join(save_path, 'visualizations', '')
    print(f"\nAll results saved to: {save_path}")
    print(f"Visualization files located at: {vis_dir}")
|
|
|
|
if __name__ == "__main__":
    # Run the evaluation and visualization pipeline only when executed as a script.
    main()