fix(dataflow): correct target sequence definition in data loader
This commit is contained in:
@ -79,15 +79,21 @@ class Dataset_ETT_hour(Dataset):
|
|||||||
self.data_stamp = data_stamp
|
self.data_stamp = data_stamp
|
||||||
|
|
||||||
def __getitem__(self, index):
|
def __getitem__(self, index):
|
||||||
|
# 1. 定义输入序列 seq_x 的起止位置
|
||||||
s_begin = index
|
s_begin = index
|
||||||
s_end = s_begin + self.seq_len
|
s_end = s_begin + self.seq_len
|
||||||
r_begin = s_end - self.label_len
|
# 2. 定义目标序列 seq_y 的起止位置
|
||||||
r_end = r_begin + self.label_len + self.pred_len
|
# seq_y 的开始 (r_begin) 就是 seq_x 的结束 (s_end)
|
||||||
|
r_begin = s_end
|
||||||
|
# seq_y 的结束 (r_end) 是其开始位置加上预测长度 (pred_len)
|
||||||
|
r_end = r_begin + self.pred_len
|
||||||
|
# 3. 根据起止位置切片数据
|
||||||
seq_x = self.data_x[s_begin:s_end]
|
seq_x = self.data_x[s_begin:s_end]
|
||||||
seq_y = self.data_y[r_begin:r_end]
|
seq_y = self.data_y[r_begin:r_end]
|
||||||
seq_x_mark = self.data_stamp[s_begin:s_end]
|
seq_x_mark = self.data_stamp[s_begin:s_end]
|
||||||
seq_y_mark = self.data_stamp[r_begin:r_end]
|
seq_y_mark = self.data_stamp[r_begin:r_end]
|
||||||
|
seq_x = seq_x.astype('float32')
|
||||||
|
seq_y = seq_y.astype('float32')
|
||||||
|
|
||||||
return seq_x, seq_y, seq_x_mark, seq_y_mark
|
return seq_x, seq_y, seq_x_mark, seq_y_mark
|
||||||
|
|
||||||
@ -169,15 +175,22 @@ class Dataset_ETT_minute(Dataset):
|
|||||||
self.data_stamp = data_stamp
|
self.data_stamp = data_stamp
|
||||||
|
|
||||||
def __getitem__(self, index):
|
def __getitem__(self, index):
|
||||||
|
|
||||||
|
# 1. 定义输入序列 seq_x 的起止位置
|
||||||
s_begin = index
|
s_begin = index
|
||||||
s_end = s_begin + self.seq_len
|
s_end = s_begin + self.seq_len
|
||||||
r_begin = s_end - self.label_len
|
# 2. 定义目标序列 seq_y 的起止位置
|
||||||
r_end = r_begin + self.label_len + self.pred_len
|
# seq_y 的开始 (r_begin) 就是 seq_x 的结束 (s_end)
|
||||||
|
r_begin = s_end
|
||||||
|
# seq_y 的结束 (r_end) 是其开始位置加上预测长度 (pred_len)
|
||||||
|
r_end = r_begin + self.pred_len
|
||||||
|
# 3. 根据起止位置切片数据
|
||||||
seq_x = self.data_x[s_begin:s_end]
|
seq_x = self.data_x[s_begin:s_end]
|
||||||
seq_y = self.data_y[r_begin:r_end]
|
seq_y = self.data_y[r_begin:r_end]
|
||||||
seq_x_mark = self.data_stamp[s_begin:s_end]
|
seq_x_mark = self.data_stamp[s_begin:s_end]
|
||||||
seq_y_mark = self.data_stamp[r_begin:r_end]
|
seq_y_mark = self.data_stamp[r_begin:r_end]
|
||||||
|
seq_x = seq_x.astype('float32')
|
||||||
|
seq_y = seq_y.astype('float32')
|
||||||
|
|
||||||
return seq_x, seq_y, seq_x_mark, seq_y_mark
|
return seq_x, seq_y, seq_x_mark, seq_y_mark
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user