Files
tsmodel/test.py
2025-07-30 21:18:46 +08:00

93 lines
3.4 KiB
Python

#!/usr/bin/env python3
"""
Test script for processing ETT datasets with different prediction lengths.
Processes ETTm1.csv and ETTm2.csv with prediction lengths of 96, 192, 336, 720.
"""
import os
import sys
from dataflow import process_and_save_time_series
def _read_feature_columns(csv_path):
    """Return all column names of *csv_path* except the first one.

    The first column is treated as the date column and is dropped. Only the
    header row is parsed (``nrows=0``), so the cost does not depend on how
    many data rows the file holds.
    """
    # Imported lazily so a run that finds no input files never needs pandas.
    import pandas as pd

    header_only = pd.read_csv(csv_path, nrows=0)
    return header_only.columns[1:].tolist()


def main():
    """Process every configured ETT dataset for every prediction length.

    For each dataset whose CSV exists under ``data/ETT-small``, calls
    ``process_and_save_time_series`` once per prediction length, passing an
    output path under ``processed_data``, and prints the shapes of the
    returned arrays. Missing CSV files are skipped with a warning; an error
    for one (dataset, prediction length) pair is printed and does not stop
    the remaining work.

    Returns:
        None.
    """
    # Configuration
    datasets = ['ETTm1', 'ETTm2']
    input_len = 96
    pred_lengths = [96, 192, 336, 720]
    slide_step = 1

    # Split ratios (train:test:val = 6:2:2)
    train_ratio = 0.6
    test_ratio = 0.2
    val_ratio = 0.2

    # Base paths
    data_dir = 'data/ETT-small'
    output_dir = 'processed_data'

    # Create output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)

    print("Starting ETT dataset processing...")
    print(f"Input length: {input_len}")
    print(f"Split ratios - Train: {train_ratio}, Test: {test_ratio}, Val: {val_ratio}")
    print("-" * 60)

    # Labels printed for each split, paired with the key prefix used to look
    # the arrays up in the result mapping.
    splits = (("Train", "train"), ("Test", "test"), ("Val", "val"))

    for dataset in datasets:
        csv_path = os.path.join(data_dir, f"{dataset}.csv")

        if not os.path.exists(csv_path):
            print(f"Warning: {csv_path} not found, skipping...")
            continue

        print(f"\nProcessing {dataset}...")

        # The column list depends only on the dataset, so read the header
        # once here instead of re-parsing the whole CSV per prediction length.
        try:
            feature_columns = _read_feature_columns(csv_path)
        except Exception as e:
            # Top-level script boundary: report and move on to the next dataset.
            print(f" Error reading columns of {csv_path}: {e}")
            continue

        for pred_len in pred_lengths:
            output_file = os.path.join(output_dir, f"{dataset}_input{input_len}_pred{pred_len}.npz")
            print(f" - Prediction length {pred_len} -> {output_file}")
            print(f" Features: {feature_columns} (excluding date column)")

            try:
                # NOTE(review): the ETTm* files are presumably the 15-minute
                # ETT variants; confirm that freq='h' is what
                # process_and_save_time_series expects for them.
                result = process_and_save_time_series(
                    csv_path=csv_path,
                    output_file=output_file,
                    input_len=input_len,
                    pred_len=pred_len,
                    slide_step=slide_step,
                    train_ratio=train_ratio,
                    test_ratio=test_ratio,
                    val_ratio=val_ratio,
                    selected_columns=feature_columns,
                    date_column='date',
                    freq='h',
                )

                # Print dataset shapes for verification. Kept inside the try
                # so a missing key in the result is reported like any other
                # processing error rather than aborting the whole run.
                for label, key in splits:
                    x_shape = result[key + '_x'].shape
                    y_shape = result[key + '_y'].shape
                    print(f" {label}: {x_shape} -> {y_shape}")
                for label, key in splits:
                    x_mark_shape = result[key + '_x_mark'].shape
                    y_mark_shape = result[key + '_y_mark'].shape
                    print(f" {label} time marks: {x_mark_shape} -> {y_mark_shape}")
            except Exception as e:
                print(f" Error processing {dataset} with pred_len {pred_len}: {e}")

    print("\n" + "=" * 60)
    print("Processing completed!")
    print(f"Output files saved in: {os.path.abspath(output_dir)}")
# Run the processing only when executed as a script, not when imported.
if __name__ == "__main__":
    main()