feat: add mamba and dynamic chunking related code and test code
This commit is contained in:
313
test_dc_patchtst.py
Normal file
313
test_dc_patchtst.py
Normal file
@ -0,0 +1,313 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
import os
|
||||
import matplotlib.pyplot as plt
|
||||
from models.DC_PatchTST import Model
|
||||
|
||||
class Config:
    """Settings bundle for the DC_PatchTST long-term forecasting run.

    Every value is a fixed default assigned in ``__init__``; the only
    environment-dependent fields are ``use_gpu`` and ``device``, which
    reflect whether CUDA is available when the object is created.
    """

    def __init__(self):
        # Task / model identity
        self.task_name = 'long_term_forecast'
        self.model = 'DC_PatchTST'

        # Window geometry: input length, forecast horizon, label length
        self.seq_len, self.pred_len, self.label_len = 96, 24, 48
        # Two channels everywhere: encoder input, decoder input and output
        self.enc_in = self.dec_in = self.c_out = 2

        # Network architecture
        self.d_model = 128        # model (embedding) dimension
        self.n_heads = 8          # attention heads
        self.e_layers = 2         # encoder layers
        self.d_layers = 1         # decoder layers
        self.d_ff = 256           # feed-forward dimension
        self.factor = 1           # attention factor
        self.dropout = 0.1        # dropout rate
        self.activation = 'gelu'

        # Optimisation schedule
        self.batch_size = 32
        self.learning_rate = 0.001
        self.train_epochs = 50
        self.patience = 5

        # Miscellaneous switches
        self.use_amp = False
        self.num_class = 0

        # Hardware selection: prefer CUDA when it is available
        cuda_available = torch.cuda.is_available()
        self.use_gpu = cuda_available
        self.device = torch.device('cuda') if cuda_available else torch.device('cpu')
|
||||
|
||||
class SineWaveDataset(Dataset):
    """Sliding-window forecasting dataset over a two-channel CSV split.

    The split file (``train.csv``, ``val.csv`` or ``test.csv``) is read from
    ``data_path``; any ``mode`` other than ``'train'`` or ``'val'`` selects the
    test split. Only the ``channel1`` and ``channel2`` columns are used.

    Sample ``i`` pairs the input window ``data[i : i + seq_len]`` with the
    target window of ``pred_len`` rows that immediately follows it.
    """

    def __init__(self, data_path, seq_len=96, pred_len=24, mode='test'):
        """Load the requested split and compute the number of full windows.

        Args:
            data_path: Directory containing the split CSV files.
            seq_len: Number of rows in each input window.
            pred_len: Number of rows in each target window.
            mode: ``'train'``, ``'val'`` or anything else for the test split.
        """
        self.seq_len = seq_len
        self.pred_len = pred_len
        self.mode = mode

        # Load data
        if mode == 'train':
            csv_name = 'train.csv'
        elif mode == 'val':
            csv_name = 'val.csv'
        else:  # test
            csv_name = 'test.csv'
        df = pd.read_csv(os.path.join(data_path, csv_name))

        # Keep only the two signal channels (any other column is ignored)
        self.data = df[['channel1', 'channel2']].values.astype(np.float32)

        # A sample needs seq_len + pred_len consecutive rows; clamp at zero
        # so a too-short file yields an empty dataset instead of a negative length.
        self.total_len = len(self.data)
        self.samples_num = max(0, self.total_len - seq_len - pred_len + 1)

        print(f"{mode} dataset: {self.total_len} records, {self.samples_num} samples")

    def __len__(self):
        """Return the number of complete (input, target) windows."""
        return self.samples_num

    def __getitem__(self, idx):
        """Return the ``idx``-th sample.

        Negative indices count from the end, as for a sequence.

        Returns:
            Tuple ``(seq_x, seq_y, seq_x_mark, seq_y_mark, idx)`` where
            ``seq_x`` is ``(seq_len, 2)``, ``seq_y`` is ``(pred_len, 2)``, the
            marks are ``(seq_len, 1)`` / ``(pred_len, 1)`` float32 position
            ramps, and ``idx`` is the (normalized) sample index.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        if idx < 0:
            idx = idx + self.samples_num
        # Without this check numpy slicing would silently return truncated or
        # empty windows, and plain iteration over the dataset would never stop.
        if not 0 <= idx < self.samples_num:
            raise IndexError(
                f"sample index {idx} out of range for dataset with "
                f"{self.samples_num} samples"
            )

        # Input sequence
        s_begin = idx
        s_end = s_begin + self.seq_len

        # Prediction target starts right after the input window
        r_begin = s_end
        r_end = r_begin + self.pred_len

        seq_x = self.data[s_begin:s_end]  # (seq_len, n_vars)
        seq_y = self.data[r_begin:r_end]  # (pred_len, n_vars)

        # Time marks (simple positional encoding: 0, 1, 2, ...)
        seq_x_mark = np.arange(self.seq_len).reshape(-1, 1).astype(np.float32)
        seq_y_mark = np.arange(self.pred_len).reshape(-1, 1).astype(np.float32)

        return seq_x, seq_y, seq_x_mark, seq_y_mark, idx
|
||||
|
||||
def load_model(model_path):
    """Rebuild a trained model from a checkpoint file and put it in eval mode.

    The checkpoint must be a mapping providing ``'config'``,
    ``'model_state_dict'``, ``'epoch'`` and ``'val_loss'``; all four are read
    here.

    Args:
        model_path: Path of the checkpoint file to load.

    Returns:
        Tuple ``(model, config, device)``: the restored model (already moved
        to ``device`` and switched to eval mode), the configuration object
        stored in the checkpoint, and the device that was selected.
    """
    # NOTE(review): weights_only=False unpickles arbitrary Python objects
    # (needed here to restore the stored config) -- only load trusted files.
    ckpt = torch.load(model_path, map_location='cpu', weights_only=False)
    cfg = ckpt['config']

    # Tensors are loaded on CPU first, then moved to CUDA when available.
    target_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    net = Model(cfg)
    net.load_state_dict(ckpt['model_state_dict'])
    net = net.to(target_device)
    net.eval()

    print(f"Model loaded successfully - Epoch: {ckpt['epoch']}, Val Loss: {ckpt['val_loss']:.6f}")
    print(f"Using device: {target_device}")

    return net, cfg, target_device
|
||||
|
||||
def _stage0_boundary_mask(aux, batch_size, nvars):
    """Return the stage-0 boundary mask as a (batch, nvars, L) array, or None.

    NOTE(review): assumes ``aux['stage0']['boundary_mask']`` is laid out as
    (batch * nvars, L) with the channels of one sample stored contiguously --
    confirm against the model implementation.
    """
    if aux is None or 'stage0' not in aux:
        return None
    flat_mask = aux['stage0']['boundary_mask'].cpu().numpy()  # (B*nvars, L)
    return flat_mask.reshape(batch_size, nvars, flat_mask.shape[1])


def _plot_sample_with_chunks(input_seq, true_pred, pred_seq, boundary_mask,
                             sample_number, sample_path):
    """Draw one figure (one subplot per channel) for a sample and save it."""
    seq_len = len(input_seq)
    n_channels = input_seq.shape[1]
    input_time = np.arange(seq_len)
    pred_time = np.arange(seq_len, seq_len + len(true_pred))

    # squeeze=False keeps ``axes`` two-dimensional even for a single channel.
    fig, axes = plt.subplots(n_channels, 1, figsize=(15, 5 * n_channels), squeeze=False)
    try:
        for ch in range(n_channels):
            ax = axes[ch, 0]

            # Input window, then ground truth and model forecast
            ax.plot(input_time, input_seq[:, ch], 'b-', label='Input Sequence', linewidth=1.5)
            ax.plot(pred_time, true_pred[:, ch], 'g-', label='Ground Truth', linewidth=2)
            ax.plot(pred_time, pred_seq[:, ch], 'r--', label='Prediction', linewidth=2)

            # Mark stage-0 chunk points that fall inside the input window.
            if boundary_mask is not None:
                chunk_points = np.where(boundary_mask[ch])[0]
                visible_points = chunk_points[chunk_points < seq_len]
                for k, point in enumerate(visible_points):
                    # Label only the first marker: one legend entry, and no
                    # dummy line has to be drawn just to create it.
                    legend_kw = {'label': 'Chunk Points (Stage 0)'} if k == 0 else {}
                    ax.axvline(x=point, color='orange', linestyle=':', alpha=0.7,
                               linewidth=1, **legend_kw)

            ax.set_title(f'Sample {sample_number} - Channel {ch + 1}')
            ax.set_xlabel('Time Steps')
            ax.set_ylabel('Value')
            ax.legend()
            ax.grid(True, alpha=0.3)

            # Separator between the input window and the forecast horizon
            ax.axvline(x=seq_len - 0.5, color='black', linestyle='-', alpha=0.5, linewidth=1)
            ax.text(seq_len - 0.5, ax.get_ylim()[1] * 0.9, 'Prediction Start',
                    rotation=90, verticalalignment='top', fontsize=8)

        fig.tight_layout()
        fig.savefig(sample_path, dpi=300, bbox_inches='tight')
    finally:
        # Close this specific figure even if plotting or saving failed.
        plt.close(fig)


def _print_chunk_statistics(boundary_mask):
    """Print, per channel, how many positions of the mask are chunk points."""
    for ch, channel_mask in enumerate(boundary_mask):
        chunk_count = np.sum(channel_mask)
        chunk_ratio = chunk_count / len(channel_mask)
        print(f" Channel {ch + 1}: {chunk_count} chunk points ({chunk_ratio:.2%} of sequence)")


def visualize_predictions_with_chunks(model, test_loader, device, save_path, num_samples=5):
    """Predict on the first samples of ``test_loader`` and plot them with chunk points.

    For each of the first ``num_samples`` samples a PNG named
    ``sample_<k>_with_chunks.png`` is written to ``<save_path>/visualizations``
    (created if missing). Each figure has one subplot per input channel
    showing the input window, the ground truth, the prediction and, when the
    model reports them, the stage-0 chunk boundaries inside the input window.

    Args:
        model: Callable invoked as ``model(x, x_mark, dec_inp, y_mark)`` and
            returning ``(outputs, aux)``; ``aux`` may be ``None`` or a mapping
            with a ``'stage0'`` entry holding a ``'boundary_mask'`` tensor.
        test_loader: Iterable of ``(x, y, x_mark, y_mark, idx)`` batches.
        device: Device the batch tensors are moved to before the forward pass.
        save_path: Directory under which ``visualizations/`` is created.
        num_samples: Maximum number of samples to plot.
    """
    model.eval()

    # Create save directory
    vis_dir = os.path.join(save_path, 'visualizations')
    os.makedirs(vis_dir, exist_ok=True)

    sample_count = 0

    with torch.no_grad():
        for batch_x, batch_y, batch_x_mark, batch_y_mark, _ in test_loader:
            if sample_count >= num_samples:
                break

            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            batch_x_mark = batch_x_mark.to(device)
            batch_y_mark = batch_y_mark.to(device)

            # Decoder input is all zeros with the target's shape (and device).
            dec_inp = torch.zeros_like(batch_y)

            outputs, aux = model(batch_x, batch_x_mark, dec_inp, batch_y_mark)

            inputs_np = batch_x.cpu().numpy()
            targets_np = batch_y.cpu().numpy()
            outputs_np = outputs.cpu().numpy()

            # Extract and reshape the mask once per batch, not once per sample.
            batch_masks = _stage0_boundary_mask(aux, batch_x.size(0), batch_x.size(2))

            # The range bound already stops at num_samples, so no inner break is needed.
            for i in range(min(batch_x.size(0), num_samples - sample_count)):
                sample_number = sample_count + 1
                sample_mask = None if batch_masks is None else batch_masks[i]  # (nvars, L)

                sample_filename = f'sample_{sample_number}_with_chunks.png'
                sample_path = os.path.join(vis_dir, sample_filename)
                _plot_sample_with_chunks(inputs_np[i], targets_np[i], outputs_np[i],
                                         sample_mask, sample_number, sample_path)

                print(f"Sample {sample_number} visualization saved to: {sample_path}")

                if sample_mask is not None:
                    _print_chunk_statistics(sample_mask)

                sample_count += 1
|
||||
|
||||
def evaluate_model(model, test_loader, device):
    """Compute MSE, MAE and RMSE of ``model`` over every element in ``test_loader``.

    Errors are accumulated as running sums and divided by the total number of
    predicted elements, so batches of different sizes are weighted correctly
    and ``RMSE == sqrt(MSE)`` holds. No predictions are kept in memory.

    Args:
        model: Callable invoked as ``model(x, x_mark, dec_inp, y_mark)`` and
            returning ``(outputs, aux)``; only ``outputs`` is used.
        test_loader: Iterable of ``(x, y, x_mark, y_mark, idx)`` batches.
        device: Device the batch tensors are moved to before the forward pass.

    Returns:
        Tuple ``(mse, mae, rmse)`` of Python floats.

    Raises:
        ValueError: If ``test_loader`` yields no elements to evaluate.
    """
    model.eval()

    squared_error_sum = 0.0
    absolute_error_sum = 0.0
    element_count = 0

    with torch.no_grad():
        for batch_x, batch_y, batch_x_mark, batch_y_mark, _ in test_loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            batch_x_mark = batch_x_mark.to(device)
            batch_y_mark = batch_y_mark.to(device)

            # Decoder input is all zeros with the target's shape (and device).
            dec_inp = torch.zeros_like(batch_y)

            outputs, _aux = model(batch_x, batch_x_mark, dec_inp, batch_y_mark)

            # Accumulate sums rather than per-batch means so that a smaller
            # final batch does not get the same weight as a full one.
            error = outputs - batch_y
            squared_error_sum += torch.sum(error ** 2).item()
            absolute_error_sum += torch.sum(torch.abs(error)).item()
            element_count += error.numel()

    if element_count == 0:
        raise ValueError("evaluate_model received no samples to evaluate")

    avg_mse = squared_error_sum / element_count
    avg_mae = absolute_error_sum / element_count
    rmse = float(np.sqrt(avg_mse))

    print(f"\nTest set evaluation results:")
    print(f"MSE: {avg_mse:.6f}")
    print(f"MAE: {avg_mae:.6f}")
    print(f"RMSE: {rmse:.6f}")

    return avg_mse, avg_mae, rmse
|
||||
|
||||
def main():
    """Evaluate the saved DC_PatchTST checkpoint on the sine-wave test split.

    Steps: verify the checkpoint exists, restore the model, build the test
    loader, print the metrics and write the chunk-point visualizations. If
    the checkpoint file is missing, a hint is printed and nothing else runs.
    """
    # Fixed locations used by this script
    model_path = './results/dc_patchtst_sine_wave/best_model.pth'
    data_path = './data/sine_wave/'
    save_path = './results/dc_patchtst_sine_wave/'

    # Guard clause: nothing to evaluate without a trained checkpoint
    if not os.path.exists(model_path):
        print(f"Error: Model file not found {model_path}")
        print("Please run the training script train_dc_patchtst.py first")
        return

    model, config, device = load_model(model_path)

    # Window sizes come from the checkpoint's config so they match training.
    test_dataset = SineWaveDataset(
        data_path,
        seq_len=config.seq_len,
        pred_len=config.pred_len,
        mode='test',
    )
    # batch_size=1 for easier visualization
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0)

    print(f"\nConfiguration info:")
    print(f"Sequence length: {config.seq_len}")
    print(f"Prediction length: {config.pred_len}")
    print(f"Feature dimension: {config.enc_in}")

    # Metrics are printed by evaluate_model; the returned values are not needed here.
    evaluate_model(model, test_loader, device)

    print(f"\nGenerating visualization results...")
    visualize_predictions_with_chunks(model, test_loader, device, save_path, num_samples=5)

    print(f"\nAll results saved to: {save_path}")
    print(f"Visualization files located at: {save_path}/visualizations/")


# Run the evaluation only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
Reference in New Issue
Block a user