first timesnet try

This commit is contained in:
game-loader
2025-07-30 21:18:46 +08:00
parent dc8c9f1f09
commit 6ee5c769c4
17 changed files with 2918 additions and 0 deletions

92
test.py Normal file
View File

@ -0,0 +1,92 @@
#!/usr/bin/env python3
"""
Test script for processing ETT datasets with different prediction lengths.
Processes ETTm1.csv and ETTm2.csv with prediction lengths of 96, 192, 336, 720.
"""
import os
import sys
from dataflow import process_and_save_time_series
def main():
    """Process the ETT datasets with several prediction lengths.

    For each dataset in ``datasets`` and each horizon in ``pred_lengths``,
    reads ``data/ETT-small/<dataset>.csv`` and writes a windowed split to
    ``processed_data/<dataset>_input<input_len>_pred<pred_len>.npz`` via
    ``process_and_save_time_series``. Missing CSVs are skipped with a
    warning; a failure for one prediction length is reported and does not
    stop the remaining runs.
    """
    # Hoisted out of the inner loop (the original re-ran this import once
    # per prediction length); kept function-local to avoid changing the
    # module's import-time behavior.
    import pandas as pd

    # Configuration
    datasets = ['ETTm1', 'ETTm2']
    input_len = 96
    pred_lengths = [96, 192, 336, 720]
    slide_step = 1

    # Split ratios (train:test:val = 6:2:2)
    train_ratio = 0.6
    test_ratio = 0.2
    val_ratio = 0.2

    # Base paths
    data_dir = 'data/ETT-small'
    output_dir = 'processed_data'

    # Create output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)

    print("Starting ETT dataset processing...")
    print(f"Input length: {input_len}")
    print(f"Split ratios - Train: {train_ratio}, Test: {test_ratio}, Val: {val_ratio}")
    print("-" * 60)

    # Process each dataset
    for dataset in datasets:
        csv_path = os.path.join(data_dir, f"{dataset}.csv")

        # Skip datasets whose CSV is absent rather than crashing.
        if not os.path.exists(csv_path):
            print(f"Warning: {csv_path} not found, skipping...")
            continue

        print(f"\nProcessing {dataset}...")

        # Read only the header row to discover the feature columns. The
        # original loaded the full CSV once per prediction length just to
        # list the columns, which is wasteful for large files; the column
        # set is a property of the file, so resolve it once per dataset.
        header = pd.read_csv(csv_path, nrows=0)
        # All columns except the first one (the date column) are features.
        feature_columns = header.columns[1:].tolist()

        # Process each prediction length
        for pred_len in pred_lengths:
            output_file = os.path.join(
                output_dir, f"{dataset}_input{input_len}_pred{pred_len}.npz"
            )
            print(f"  - Prediction length {pred_len} -> {output_file}")
            print(f"    Features: {feature_columns} (excluding date column)")
            try:
                result = process_and_save_time_series(
                    csv_path=csv_path,
                    output_file=output_file,
                    input_len=input_len,
                    pred_len=pred_len,
                    slide_step=slide_step,
                    train_ratio=train_ratio,
                    test_ratio=test_ratio,
                    val_ratio=val_ratio,
                    selected_columns=feature_columns,
                    date_column='date',
                    freq='h'
                )

                # Print dataset shapes for verification
                print(f"    Train: {result['train_x'].shape} -> {result['train_y'].shape}")
                print(f"    Test: {result['test_x'].shape} -> {result['test_y'].shape}")
                print(f"    Val: {result['val_x'].shape} -> {result['val_y'].shape}")
                print(f"    Train time marks: {result['train_x_mark'].shape} -> {result['train_y_mark'].shape}")
                print(f"    Test time marks: {result['test_x_mark'].shape} -> {result['test_y_mark'].shape}")
                print(f"    Val time marks: {result['val_x_mark'].shape} -> {result['val_y_mark'].shape}")
            except Exception as e:
                # Best-effort batch job: report the failure and move on to
                # the next prediction length.
                print(f"    Error processing {dataset} with pred_len {pred_len}: {e}")
                continue

    print("\n" + "=" * 60)
    print("Processing completed!")
    print(f"Output files saved in: {os.path.abspath(output_dir)}")


if __name__ == "__main__":
    main()