From 4129832f987c398782c0b850b6133cd6ab7030b0 Mon Sep 17 00:00:00 2001 From: game-loader Date: Wed, 27 Aug 2025 15:58:04 +0800 Subject: [PATCH] fix(dataflow): correct target sequence definition in data loader --- dataflow/data_loader.py | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/dataflow/data_loader.py b/dataflow/data_loader.py index 215d986..eb14059 100644 --- a/dataflow/data_loader.py +++ b/dataflow/data_loader.py @@ -79,15 +79,21 @@ class Dataset_ETT_hour(Dataset): self.data_stamp = data_stamp def __getitem__(self, index): + # 1. 定义输入序列 seq_x 的起止位置 s_begin = index s_end = s_begin + self.seq_len - r_begin = s_end - self.label_len - r_end = r_begin + self.label_len + self.pred_len - + # 2. 定义目标序列 seq_y 的起止位置 + # seq_y 的开始 (r_begin) 就是 seq_x 的结束 (s_end) + r_begin = s_end + # seq_y 的结束 (r_end) 是其开始位置加上预测长度 (pred_len) + r_end = r_begin + self.pred_len + # 3. 根据起止位置切片数据 seq_x = self.data_x[s_begin:s_end] seq_y = self.data_y[r_begin:r_end] seq_x_mark = self.data_stamp[s_begin:s_end] seq_y_mark = self.data_stamp[r_begin:r_end] + seq_x = seq_x.astype('float32') + seq_y = seq_y.astype('float32') return seq_x, seq_y, seq_x_mark, seq_y_mark @@ -169,15 +175,22 @@ class Dataset_ETT_minute(Dataset): self.data_stamp = data_stamp def __getitem__(self, index): + + # 1. 定义输入序列 seq_x 的起止位置 s_begin = index s_end = s_begin + self.seq_len - r_begin = s_end - self.label_len - r_end = r_begin + self.label_len + self.pred_len - + # 2. 定义目标序列 seq_y 的起止位置 + # seq_y 的开始 (r_begin) 就是 seq_x 的结束 (s_end) + r_begin = s_end + # seq_y 的结束 (r_end) 是其开始位置加上预测长度 (pred_len) + r_end = r_begin + self.pred_len + # 3. 根据起止位置切片数据 seq_x = self.data_x[s_begin:s_end] seq_y = self.data_y[r_begin:r_end] seq_x_mark = self.data_stamp[s_begin:s_end] seq_y_mark = self.data_stamp[r_begin:r_end] + seq_x = seq_x.astype('float32') + seq_y = seq_y.astype('float32') return seq_x, seq_y, seq_x_mark, seq_y_mark