93 lines
3.4 KiB
Python
93 lines
3.4 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Test script for processing ETT datasets with different prediction lengths.
|
|
Processes ETTm1.csv and ETTm2.csv with prediction lengths of 96, 192, 336, 720.
|
|
"""
|
|
|
|
import os
|
|
import sys
|
|
from dataflow import process_and_save_time_series
|
|
|
|
def main():
    """Process each ETT dataset CSV for every configured prediction length.

    For each dataset whose CSV exists under ``data_dir``, reads the CSV header
    once to determine the feature columns (every column except the date
    column), then calls ``process_and_save_time_series`` once per prediction
    length, writing one ``.npz`` file per (dataset, pred_len) pair into
    ``output_dir``. Missing CSV files are skipped with a warning; an error for
    one job is reported and the remaining jobs still run.

    Returns:
        int: 0 if every attempted job succeeded, 1 if any header read or
        processing job failed (suitable for ``sys.exit``).
    """
    # pandas is only needed here to read each CSV's header row.
    import pandas as pd

    # Configuration
    datasets = ['ETTm1', 'ETTm2']
    input_len = 96
    pred_lengths = [96, 192, 336, 720]
    slide_step = 1

    # Split ratios (train:test:val = 6:2:2)
    train_ratio = 0.6
    test_ratio = 0.2
    val_ratio = 0.2

    # Timestamp column: passed as ``date_column`` and excluded from the
    # feature columns, so both uses stay in sync.
    date_column = 'date'
    # NOTE(review): 'h' is forwarded unchanged as ``freq``. ETTm files are
    # commonly described as 15-minute data -- confirm whether a minute-level
    # frequency is intended for process_and_save_time_series.
    freq = 'h'

    # Base paths
    data_dir = 'data/ETT-small'
    output_dir = 'processed_data'

    # Create output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)

    print("Starting ETT dataset processing...")
    print(f"Input length: {input_len}")
    print(f"Split ratios - Train: {train_ratio}, Test: {test_ratio}, Val: {val_ratio}")
    print("-" * 60)

    failures = 0

    for dataset in datasets:
        csv_path = os.path.join(data_dir, f"{dataset}.csv")

        if not os.path.exists(csv_path):
            print(f"Warning: {csv_path} not found, skipping...")
            continue

        print(f"\nProcessing {dataset}...")

        # The column list does not depend on pred_len, so read only the
        # header row (nrows=0) once per dataset rather than re-reading the
        # entire CSV inside the inner loop.
        try:
            columns = pd.read_csv(csv_path, nrows=0).columns.tolist()
        except Exception as e:
            print(f" Error reading header of {csv_path}: {e}")
            failures += 1
            continue

        feature_columns = [col for col in columns if col != date_column]
        print(f" Features: {feature_columns} (excluding date column)")

        for pred_len in pred_lengths:
            output_file = os.path.join(
                output_dir, f"{dataset}_input{input_len}_pred{pred_len}.npz"
            )
            print(f" - Prediction length {pred_len} -> {output_file}")

            try:
                result = process_and_save_time_series(
                    csv_path=csv_path,
                    output_file=output_file,
                    input_len=input_len,
                    pred_len=pred_len,
                    slide_step=slide_step,
                    train_ratio=train_ratio,
                    test_ratio=test_ratio,
                    val_ratio=val_ratio,
                    selected_columns=feature_columns,
                    date_column=date_column,
                    freq=freq,
                )

                # Print dataset shapes for verification. Kept inside the try
                # so a result missing an expected key is reported like any
                # other per-job failure instead of aborting the run.
                print(f" Train: {result['train_x'].shape} -> {result['train_y'].shape}")
                print(f" Test: {result['test_x'].shape} -> {result['test_y'].shape}")
                print(f" Val: {result['val_x'].shape} -> {result['val_y'].shape}")
                print(f" Train time marks: {result['train_x_mark'].shape} -> {result['train_y_mark'].shape}")
                print(f" Test time marks: {result['test_x_mark'].shape} -> {result['test_y_mark'].shape}")
                print(f" Val time marks: {result['val_x_mark'].shape} -> {result['val_y_mark'].shape}")
            except Exception as e:
                # Script-level boundary: report and continue with the next
                # job; the failure is reflected in the summary and return code.
                print(f" Error processing {dataset} with pred_len {pred_len}: {e}")
                failures += 1

    print("\n" + "=" * 60)
    if failures:
        print(f"Processing completed with {failures} failure(s).")
    else:
        print("Processing completed!")
    print(f"Output files saved in: {os.path.abspath(output_dir)}")

    return 1 if failures else 0
|
|
|
|
if __name__ == "__main__":
    # Use main()'s return value as the process exit status; a None return
    # maps to exit status 0, matching a plain main() call.
    sys.exit(main())
|