First TimesNet try
test.py (92 lines, Normal file)
@@ -0,0 +1,92 @@
#!/usr/bin/env python3
"""
Test script for processing ETT datasets with different prediction lengths.
Processes ETTm1.csv and ETTm2.csv with prediction lengths of 96, 192, 336, 720.
"""

import os
import sys

import pandas as pd

from dataflow import process_and_save_time_series


def main():
    # Configuration
    datasets = ['ETTm1', 'ETTm2']
    input_len = 96
    pred_lengths = [96, 192, 336, 720]
    slide_step = 1

    # Split ratios (train:test:val = 6:2:2)
    train_ratio = 0.6
    test_ratio = 0.2
    val_ratio = 0.2

    # Base paths
    data_dir = 'data/ETT-small'
    output_dir = 'processed_data'

    # Create output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)

    print("Starting ETT dataset processing...")
    print(f"Input length: {input_len}")
    print(f"Split ratios - Train: {train_ratio}, Test: {test_ratio}, Val: {val_ratio}")
    print("-" * 60)

    # Process each dataset
    for dataset in datasets:
        csv_path = os.path.join(data_dir, f"{dataset}.csv")

        # Check if CSV file exists
        if not os.path.exists(csv_path):
            print(f"Warning: {csv_path} not found, skipping...")
            continue

        print(f"\nProcessing {dataset}...")

        # Process each prediction length
        for pred_len in pred_lengths:
            output_file = os.path.join(output_dir, f"{dataset}_input{input_len}_pred{pred_len}.npz")

            print(f" - Prediction length {pred_len} -> {output_file}")

            try:
                # Read CSV to get column names and exclude the date column
                sample_data = pd.read_csv(csv_path)

                # Get all columns except the first one (date column)
                feature_columns = sample_data.columns[1:].tolist()
                print(f" Features: {feature_columns} (excluding date column)")

                result = process_and_save_time_series(
                    csv_path=csv_path,
                    output_file=output_file,
                    input_len=input_len,
                    pred_len=pred_len,
                    slide_step=slide_step,
                    train_ratio=train_ratio,
                    test_ratio=test_ratio,
                    val_ratio=val_ratio,
                    selected_columns=feature_columns,
                    date_column='date',
                    freq='h'
                )

                # Print dataset shapes for verification
                print(f" Train: {result['train_x'].shape} -> {result['train_y'].shape}")
                print(f" Test: {result['test_x'].shape} -> {result['test_y'].shape}")
                print(f" Val: {result['val_x'].shape} -> {result['val_y'].shape}")
                print(f" Train time marks: {result['train_x_mark'].shape} -> {result['train_y_mark'].shape}")
                print(f" Test time marks: {result['test_x_mark'].shape} -> {result['test_y_mark'].shape}")
                print(f" Val time marks: {result['val_x_mark'].shape} -> {result['val_y_mark'].shape}")

            except Exception as e:
                print(f" Error processing {dataset} with pred_len {pred_len}: {e}")
                continue

    print("\n" + "=" * 60)
    print("Processing completed!")
    print(f"Output files saved in: {os.path.abspath(output_dir)}")


if __name__ == "__main__":
    main()
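For a quick sanity check of the generated files, something like the sketch below can be used after running the script. This is a minimal, hypothetical example, not part of the commit: it assumes the .npz written by process_and_save_time_series stores its arrays under the same keys as the returned result dict (train_x, train_y, train_x_mark, ...), and the file name is just one of the outputs named above.

import numpy as np

# Hypothetical check: load one generated file (name follows the pattern above)
data = np.load('processed_data/ETTm1_input96_pred96.npz')
print(data.files)  # list the stored array keys

# Print input/target shapes per split, assuming keys mirror the result dict
for split in ('train', 'test', 'val'):
    print(split,
          data[f'{split}_x'].shape, '->', data[f'{split}_y'].shape,
          '| marks:', data[f'{split}_x_mark'].shape, '->', data[f'{split}_y_mark'].shape)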