# tsmodel/dataflow/tsf.py

import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler
import joblib
from utils.timefeatures import time_features
import os


def get_ett_dataset_borders(dataset_name, data_len, input_len):
    """
    Dataset-specific border computation for the ETT family of datasets.

    Args:
        dataset_name (str): Dataset name (e.g. 'ETTm1', 'ETTh1').
        data_len (int): Total number of rows in the dataset.
        input_len (int): Length of the input sequence.

    Returns:
        tuple: (border1s, border2s) lists of split boundary indices.
    """
    if dataset_name.startswith('ETTm'):
        # ETTm1, ETTm2: 15-minute interval, 96 points per day
        border1s = [0, 12 * 30 * 96 - input_len, 12 * 30 * 96 + 4 * 30 * 96 - input_len]
        border2s = [12 * 30 * 96, 12 * 30 * 96 + 4 * 30 * 96, 12 * 30 * 96 + 8 * 30 * 96]
    elif dataset_name.startswith('ETTh'):
        # ETTh1, ETTh2: hourly interval, 24 points per day
        border1s = [0, 12 * 30 * 24 - input_len, 12 * 30 * 24 + 4 * 30 * 24 - input_len]
        border2s = [12 * 30 * 24, 12 * 30 * 24 + 4 * 30 * 24, 12 * 30 * 24 + 8 * 30 * 24]
    else:
        raise ValueError(f"Unknown ETT dataset: {dataset_name}")
    return border1s, border2s
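
# Usage sketch (illustrative; `df` is a hypothetical DataFrame of the raw CSV).
# The ETT handler ignores data_len and uses a fixed calendar-based split:
#   border1s, border2s = get_ett_dataset_borders('ETTm1', len(df), input_len=96)
#   # border1s == [0, 34464, 45984]
#   # border2s == [34560, 46080, 57600]
# i.e. 12 months of train, 4 of val, 4 of test (30-day months), with the val and
# test slices starting input_len points early so their first windows are complete.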

# Example: handlers for other specific datasets can be added in the same way.
# def get_weather_dataset_borders(dataset_name, data_len, input_len):
#     """
#     Dataset-specific border computation for the Weather dataset.
#     """
#     # Suppose the weather dataset uses a different split strategy,
#     # e.g. first 80% train, next 10% validation, last 10% test.
#     train_end = int(data_len * 0.8)
#     val_end = int(data_len * 0.9)
#
#     border1s = [0, train_end - input_len, val_end - input_len]
#     border2s = [train_end, val_end, data_len]
#
#     return border1s, border2s

# Mapping from dataset name to its border-computation handler.
DATASET_HANDLERS = {
    'ETTm1': get_ett_dataset_borders,
    'ETTm2': get_ett_dataset_borders,
    'ETTh1': get_ett_dataset_borders,
    'ETTh2': get_ett_dataset_borders,
    # Additional dataset handlers can be registered here, e.g.
    # 'weather': get_weather_dataset_borders,
}
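
# Dispatch sketch (the numbers are hypothetical): preprocess_time_series below
# looks a handler up by name and calls it as (dataset_name, data_len, input_len):
#   handler = DATASET_HANDLERS['ETTh1']
#   border1s, border2s = handler('ETTh1', 17420, 96)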


def preprocess_time_series(
    csv_data,
    input_len,
    pred_len,
    slide_step,
    dataset_name=None,           # New: dataset name parameter
    data_path_name='ETTm1.csv',  # Kept for backward compatibility; dataset_name takes precedence
    selected_columns=None,
    date_column='date',
    freq='h',                    # Per the analysis, the original ETTm1/ETTh1 experiments both use 'h'
    split_method='auto',         # 'auto', 'specific', or 'ratio'
    train_ratio=0.7,
    val_ratio=0.1,
    test_ratio=0.2,
    has_time_column=True,        # New: whether the dataset has a time column
):
    """
    Modified version: preprocess time series data following the TimesNet paper's logic.

    1. Supports three split methods: 'auto' (automatic choice), 'specific'
       (dataset-specific borders), and 'ratio' (ratio-based split).
    2. Supports dispatching to dataset-specific handlers by dataset name.
    3. The sliding-window target y has length pred_len (as requested).
    4. Supports datasets without a time column.

    Args:
        csv_data (pd.DataFrame or str): CSV data as DataFrame or path to CSV file.
        input_len (int): Length of input sequence (seq_len in the original paper).
        pred_len (int): Length of prediction sequence.
        slide_step (int): Step size for sliding window.
        dataset_name (str): Dataset name (e.g. 'ETTm1', 'weather'); takes precedence when given.
        data_path_name (str): Data file name (e.g. 'ETTm1.csv'), kept for backward compatibility.
        selected_columns (list): List of column names to use (default: None, uses all).
        date_column (str): Name of the date column (default: 'date').
        freq (str): Frequency for time features ('h' for hourly, 't' for minutely).
        split_method (str): Data split method - 'auto', 'specific', or 'ratio'.
            - 'auto': automatically choose based on dataset_name
            - 'specific': use the dataset-specific split function
            - 'ratio': use a ratio-based split
        train_ratio (float): Training set ratio (only used when split_method='ratio').
        val_ratio (float): Validation set ratio (only used when split_method='ratio').
        test_ratio (float): Test set ratio (only used when split_method='ratio').
        has_time_column (bool): Whether the dataset has a time column.

    Returns:
        dict: Dictionary containing the processed data and the fitted scaler.
    """
    # 1. Load the data
    if isinstance(csv_data, str):
        try:
            data = pd.read_csv(csv_data)
        except FileNotFoundError:
            raise FileNotFoundError(f"CSV file not found: {csv_data}")
        except Exception as e:
            raise Exception(f"Error loading CSV file: {e}")
    else:
        data = csv_data.copy()

    # 2. Extract time features (only when a time column is present)
    if has_time_column and date_column in data.columns:
        date_index = pd.to_datetime(data[date_column])
        if isinstance(date_index, pd.Series):
            date_index = pd.DatetimeIndex(date_index)
        time_stamp = time_features(date_index, freq=freq)
        time_stamp = time_stamp.transpose(1, 0)  # Shape: (n_samples, n_time_features)
    elif has_time_column:
        raise ValueError(f"Date column '{date_column}' not found in data")
    else:
        # No time column: leave the time-stamp array empty
        time_stamp = None

    # 3. Select data columns
    if selected_columns is not None:
        data = data[selected_columns]
    else:
        if has_time_column:
            feature_columns = [col for col in data.columns if col != date_column]
        else:
            feature_columns = list(data.columns)
        data = data[feature_columns]
    # 4. [Core change] Choose the split strategy according to split_method
    # Determine which dataset name to use
    if dataset_name is None:
        # Backward compatibility: derive the dataset name from the file name
        dataset_name = os.path.splitext(data_path_name)[0]

    if split_method == 'auto':
        # Automatic choice: 'specific' for known datasets, 'ratio' otherwise
        if dataset_name in DATASET_HANDLERS:
            split_method = 'specific'
        else:
            split_method = 'ratio'

    if split_method == 'specific':
        # Use the dataset-specific handler
        if dataset_name in DATASET_HANDLERS:
            handler_func = DATASET_HANDLERS[dataset_name]
            border1s, border2s = handler_func(dataset_name, len(data), input_len)
            print(f"Using specific split for dataset '{dataset_name}'")
        else:
            print(f"Warning: No specific handler for dataset '{dataset_name}'. Falling back to ratio split.")
            split_method = 'ratio'

    if split_method == 'ratio':
        # Ratio-based split
        # Validate that the ratios sum to 1
        if abs(train_ratio + val_ratio + test_ratio - 1.0) > 1e-6:
            raise ValueError(f"Ratios must sum to 1.0, got {train_ratio + val_ratio + test_ratio}")
        total_len = len(data)
        num_train = int(total_len * train_ratio)
        num_val = int(total_len * val_ratio)
        num_test = total_len - num_train - num_val  # Ensure every data point is used
        border1s = [0, num_train - input_len, num_train + num_val - input_len]
        border2s = [num_train, num_train + num_val, total_len]
        print(f"Using ratio split for dataset '{dataset_name}': train={train_ratio:.1%}, val={val_ratio:.1%}, test={test_ratio:.1%}")
        print(f"Data points: train={num_train}, val={num_val}, test={num_test}")
    train_data = data.iloc[border1s[0]:border2s[0]].values
    val_data = data.iloc[border1s[1]:border2s[1]].values
    test_data = data.iloc[border1s[2]:border2s[2]].values

    # Slice the time stamps (only when a time column is present)
    if time_stamp is not None:
        train_time_stamp = time_stamp[border1s[0]:border2s[0]]
        val_time_stamp = time_stamp[border1s[1]:border2s[1]]
        test_time_stamp = time_stamp[border1s[2]:border2s[2]]
    else:
        train_time_stamp = None
        val_time_stamp = None
        test_time_stamp = None

    # 5. Normalize (fit on training data only)
    scaler = StandardScaler()
    scaler.fit(train_data)
    train_data_scaled = scaler.transform(train_data)
    val_data_scaled = scaler.transform(val_data)
    test_data_scaled = scaler.transform(test_data)

    # 6. [Core change] Create samples with the sliding-window logic
    train_x, train_y = create_sliding_windows(train_data_scaled, input_len, pred_len, slide_step)
    val_x, val_y = create_sliding_windows(val_data_scaled, input_len, pred_len, slide_step)
    test_x, test_y = create_sliding_windows(test_data_scaled, input_len, pred_len, slide_step)

    # Create the time-mark windows (only when a time column is present)
    if train_time_stamp is not None:
        train_x_mark, train_y_mark = create_sliding_windows(train_time_stamp, input_len, pred_len, slide_step)
        val_x_mark, val_y_mark = create_sliding_windows(val_time_stamp, input_len, pred_len, slide_step)
        test_x_mark, test_y_mark = create_sliding_windows(test_time_stamp, input_len, pred_len, slide_step)
    else:
        train_x_mark, train_y_mark = None, None
        val_x_mark, val_y_mark = None, None
        test_x_mark, test_y_mark = None, None

    return {
        'train_x': train_x, 'train_y': train_y,
        'train_x_mark': train_x_mark, 'train_y_mark': train_y_mark,
        'val_x': val_x, 'val_y': val_y,
        'val_x_mark': val_x_mark, 'val_y_mark': val_y_mark,
        'test_x': test_x, 'test_y': test_y,
        'test_x_mark': test_x_mark, 'test_y_mark': test_y_mark,
        'scaler': scaler
    }
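
# Usage sketch (illustrative only; 'data/ETTm1.csv' and the window sizes are
# assumptions, not defaults beyond what the signature above states):
#   result = preprocess_time_series(
#       csv_data='data/ETTm1.csv',
#       input_len=96,
#       pred_len=96,
#       slide_step=1,
#       dataset_name='ETTm1',
#   )
#   # result['train_x'].shape -> (n_train_windows, 96, n_features)
#   # result['train_y'].shape -> (n_train_windows, 96, n_features)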


def create_sliding_windows(data, input_len, pred_len, slide_step):
    """
    Create sliding windows from time series data.
    Target `y` has length `pred_len`.

    Args:
        data (np.ndarray): Time series data (features or time marks)
        input_len (int): Length of input sequence
        pred_len (int): Length of prediction sequence
        slide_step (int): Step size for sliding window

    Returns:
        tuple: (X, y) where X is input sequences and y is target sequences
    """
    total_window_len = input_len + pred_len
    X, y = [], []
    n_samples = len(data)

    for start_idx in range(0, n_samples, slide_step):
        end_idx = start_idx + total_window_len
        # Skip if there's not enough data for a full window
        if end_idx > n_samples:
            break
        # Split window into input and target
        input_window = data[start_idx : start_idx + input_len]
        target_window = data[start_idx + input_len : end_idx]
        X.append(input_window)
        y.append(target_window)

    return np.array(X), np.array(y)
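
# Worked example (shapes only; the array contents are hypothetical):
#   data = np.zeros((100, 7))    # 100 time steps, 7 features
#   X, y = create_sliding_windows(data, input_len=24, pred_len=12, slide_step=1)
#   # Windows start at 0, 1, ..., 64 (65 windows of total length 36 fit in 100 steps)
#   # X.shape == (65, 24, 7), y.shape == (65, 12, 7)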


def process_and_save_time_series(
    csv_path,
    output_file,
    input_len,
    pred_len,
    slide_step,
    dataset_name=None,      # New: dataset name parameter
    selected_columns=None,
    date_column='date',
    freq='h',
    split_method='auto',
    train_ratio=0.7,
    val_ratio=0.1,
    test_ratio=0.2,
    has_time_column=True,   # New: whether the dataset has a time column
):
    """
    Process time series data and save it as an NPZ file along with the fitted scaler.
    This function now calls the modified preprocess_time_series with flexible split methods.

    Args:
        dataset_name (str): Dataset name (e.g. 'ETTm1', 'weather'); takes precedence when given.
        split_method (str): Data split method - 'auto', 'specific', or 'ratio'.
        train_ratio (float): Training set ratio (only used when split_method='ratio').
        val_ratio (float): Validation set ratio (only used when split_method='ratio').
        test_ratio (float): Test set ratio (only used when split_method='ratio').
        has_time_column (bool): Whether the dataset has a time column.
    """
    # Create the output directory if it doesn't exist
    output_dir = os.path.dirname(os.path.abspath(output_file))
    os.makedirs(output_dir, exist_ok=True)

    # Extract the data file name from the path
    data_path_name = os.path.basename(csv_path)

    # Load and preprocess the time series data using the new logic
    result = preprocess_time_series(
        csv_data=csv_path,
        input_len=input_len,
        pred_len=pred_len,
        slide_step=slide_step,
        dataset_name=dataset_name,
        data_path_name=data_path_name,
        selected_columns=selected_columns,
        date_column=date_column,
        freq=freq,
        split_method=split_method,
        train_ratio=train_ratio,
        val_ratio=val_ratio,
        test_ratio=test_ratio,
        has_time_column=has_time_column,
    )

    # Extract the fitted scaler (popped so it is not written into the .npz file)
    scaler = result.pop('scaler')

    # Save the scaler object separately
    scaler_file = output_file.replace('.npz', '_scaler.gz')
    joblib.dump(scaler, scaler_file)
    print(f"Saved scaler to {scaler_file}")

    # Save the processed data arrays as a .npz file
    np.savez(output_file, **result)
    print(f"Saved processed data to {output_file}")

    return result
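

# Minimal end-to-end sketch. The paths, dataset, and window sizes below are
# assumptions for illustration, not values prescribed by this module.
if __name__ == '__main__':
    processed = process_and_save_time_series(
        csv_path='data/ETTh1.csv',             # assumed input location
        output_file='output/ETTh1_96_96.npz',  # assumed output location
        input_len=96,
        pred_len=96,
        slide_step=1,
        dataset_name='ETTh1',                  # triggers the ETT-specific split via 'auto'
    )
    # Print the shape of each returned array as a quick sanity check
    print({k: (v.shape if hasattr(v, 'shape') else v) for k, v in processed.items()})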