def preprocess_time_series(
    csv_data,
    input_len,
    pred_len,
    slide_step,
    train_ratio=0.6,
    test_ratio=0.2,
    val_ratio=0.2,
    selected_columns=None,
    date_column='date',
    freq='T',
):
    """
    Preprocess time series data from CSV for model training, testing and validation.
    Applies global Z-score normalization using only training data statistics.

    The rows are split chronologically (train, then test, then validation),
    scaled with a StandardScaler fitted on the training rows only, and cut
    into (input, target) windows independently inside each split.

    Args:
        csv_data (pd.DataFrame or str or os.PathLike): CSV data as DataFrame or path to CSV file
        input_len (int): Length of input sequence
        pred_len (int): Length of prediction sequence
        slide_step (int): Step size for sliding window
        train_ratio (float): Ratio of data to use for training (default: 0.6)
        test_ratio (float): Ratio of data to use for testing (default: 0.2)
        val_ratio (float): Ratio of data to use for validation (default: 0.2)
        selected_columns (list): List of column names to use (default: None, uses all
            columns except the date column)
        date_column (str): Name of the date column (default: 'date')
        freq (str): Frequency of the time series data (default: 'T' for minutely)

    Returns:
        dict: Dictionary containing, for each split in ('train', 'test', 'val'):
            - <split>_x: Input sequences
            - <split>_y: Target sequences
            - <split>_x_mark: Input time features
            - <split>_y_mark: Target time features
            plus:
            - scaler: Fitted StandardScaler object for inverse transformation

    Raises:
        ValueError: If the ratios do not sum to 1.0 or the date column is missing.
        FileNotFoundError: If ``csv_data`` is a path that does not exist.
        Exception: If the CSV file exists but cannot be parsed.
    """
    import os

    # Validate the cheap arguments first so a bad ratio fails before any file I/O.
    ratio_sum = train_ratio + test_ratio + val_ratio
    if abs(ratio_sum - 1.0) > 1e-6:
        raise ValueError(f"Ratios must sum to 1.0, got {ratio_sum}")

    # Load data if path to CSV is provided
    if isinstance(csv_data, (str, os.PathLike)):
        try:
            data = pd.read_csv(csv_data)
        except FileNotFoundError as err:
            raise FileNotFoundError(f"CSV file not found: {csv_data}") from err
        except Exception as e:
            # The broad type is kept for callers that already catch ``Exception``;
            # chaining preserves the parser's original traceback.
            raise Exception(f"Error loading CSV file: {e}") from e
    else:
        data = csv_data.copy()

    # Extract time features from date column
    if date_column not in data.columns:
        raise ValueError(f"Date column '{date_column}' not found in data")
    date_index = pd.DatetimeIndex(pd.to_datetime(data[date_column]))
    time_stamp = time_features(date_index, freq=freq)
    time_stamp = time_stamp.transpose(1, 0)  # Shape: (n_samples, n_time_features)

    # Select columns if specified, otherwise use everything except the date column
    if selected_columns is not None:
        data = data[selected_columns]
    else:
        data = data[[col for col in data.columns if col != date_column]]

    # Chronological split points; validation receives whatever rows remain.
    total_len = len(data)
    train_end = int(total_len * train_ratio)
    test_end = train_end + int(total_len * test_ratio)
    splits = (
        ('train', slice(0, train_end)),
        ('test', slice(train_end, test_end)),
        ('val', slice(test_end, None)),
    )
    values = data.values

    # Fit only on training data to avoid data leakage
    scaler = StandardScaler()
    scaler.fit(values[splits[0][1]])

    result = {}
    for name, span in splits:
        raw = values[span]
        if name == 'train' or len(raw) > 0:
            x, y = create_sliding_windows(
                scaler.transform(raw), input_len, pred_len, slide_step
            )
            x_mark, y_mark = create_sliding_windows(
                time_stamp[span], input_len, pred_len, slide_step
            )
        else:
            # An empty test/validation split yields empty placeholder arrays.
            x, y = np.array([]), np.array([])
            x_mark, y_mark = np.array([]), np.array([])
        result[f'{name}_x'] = x
        result[f'{name}_y'] = y
        result[f'{name}_x_mark'] = x_mark
        result[f'{name}_y_mark'] = y_mark

    result['scaler'] = scaler
    return result


def create_sliding_windows(data, input_len, pred_len, slide_step):
    """
    Create sliding windows from time series data.

    Args:
        data (np.ndarray): Time series data, sliced along its first axis
        input_len (int): Length of input sequence
        pred_len (int): Length of prediction sequence
        slide_step (int): Step size for sliding window (must be >= 1)

    Returns:
        tuple: (X, y) where X is input sequences and y is target sequences.
            Window ``i`` starts at row ``i * slide_step``; X holds its first
            ``input_len`` rows and y the following ``pred_len`` rows. Both are
            empty arrays of shape (0,) when ``data`` is shorter than
            ``input_len + pred_len``.

    Raises:
        ValueError: If ``slide_step`` is smaller than 1.
    """
    if slide_step < 1:
        raise ValueError(f"slide_step must be a positive integer, got {slide_step}")

    total_len = input_len + pred_len
    # range() already stops at the last start index that fits a full window,
    # so no additional bounds check is needed inside the loop.
    starts = range(0, len(data) - total_len + 1, slide_step)

    X = np.array([data[start:start + input_len] for start in starts])
    y = np.array([data[start + input_len:start + total_len] for start in starts])
    return X, y
def process_and_save_time_series(
    csv_path,
    output_file,
    input_len,
    pred_len,
    slide_step,
    train_ratio=0.6,
    test_ratio=0.2,
    val_ratio=0.2,
    selected_columns=None,
    date_column='date',
    freq='T',
):
    """
    Process time series data and save it as an NPZ file along with the fitted scaler.

    The scaler is written next to the data file as ``<stem>_scaler.gz`` where
    ``<stem>`` is ``output_file`` without a trailing ``.npz`` extension.

    Args:
        csv_path (str): Path to CSV file
        output_file (str): Path to output NPZ file
        input_len (int): Length of input sequence
        pred_len (int): Length of prediction sequence
        slide_step (int): Step size for sliding window
        train_ratio (float): Ratio of data to use for training (default: 0.6)
        test_ratio (float): Ratio of data to use for testing (default: 0.2)
        val_ratio (float): Ratio of data to use for validation (default: 0.2)
        selected_columns (list): List of column names to use (default: None, uses all)
        date_column (str): Name of the date column (default: 'date')
        freq (str): Frequency of the time series data (default: 'T' for minutely)

    Returns:
        dict: Dictionary containing processed data including time features
    """
    import os

    output_file = os.fspath(output_file)

    # Create output directory if it doesn't exist
    output_dir = os.path.dirname(os.path.abspath(output_file))
    os.makedirs(output_dir, exist_ok=True)

    # Load and preprocess the time series data
    result = load_and_split_time_series(
        csv_path=csv_path,
        input_len=input_len,
        pred_len=pred_len,
        slide_step=slide_step,
        train_ratio=train_ratio,
        test_ratio=test_ratio,
        val_ratio=val_ratio,
        selected_columns=selected_columns,
        date_column=date_column,
        freq=freq
    )

    # Derive the scaler path from the extension only. A plain str.replace()
    # would rewrite every '.npz' occurring anywhere in the path and, when the
    # name has no '.npz' suffix, return the path unchanged so the scaler would
    # be written to ``output_file`` itself.
    stem, ext = os.path.splitext(output_file)
    if ext != '.npz':
        stem = output_file
    scaler_file = f"{stem}_scaler.gz"
    joblib.dump(result['scaler'], scaler_file)
    print(f"Saved scaler to {scaler_file}")

    # Save every array (everything except the scaler object) as .npz
    arrays = {key: value for key, value in result.items() if key != 'scaler'}
    np.savez(output_file, **arrays)

    # np.savez appends '.npz' when the name lacks it; report the real file name.
    saved_path = output_file if ext == '.npz' else f"{output_file}.npz"
    print(f"Saved processed data to {saved_path}")

    return result
class TokenEmbedding(nn.Module):
    """
    Value embedding: projects each time step from ``c_in`` channels to
    ``d_model`` channels with a kernel-size-3 circular 1-D convolution.

    Args:
        c_in (int): Number of input channels (last dimension of the input)
        d_model (int): Number of output channels

    Input ``x`` is indexed as [batch, time, c_in]; the output is
    [batch, time, d_model].
    """

    def __init__(self, c_in, d_model):
        super(TokenEmbedding, self).__init__()
        # Compare the version numerically: as plain strings '1.10.0' sorts
        # below '1.5.0', which would select the legacy padding on torch 1.10+.
        padding = 1 if self._torch_at_least(1, 5) else 2
        self.tokenConv = nn.Conv1d(in_channels=c_in, out_channels=d_model,
                                   kernel_size=3, padding=padding, padding_mode='circular', bias=False)
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(
                    m.weight, mode='fan_in', nonlinearity='leaky_relu')

    @staticmethod
    def _torch_at_least(major, minor):
        """Return True when the installed torch version is >= ``major.minor``."""
        release = torch.__version__.split('+')[0].split('.')
        try:
            found = (int(release[0]), int(release[1]))
        except (IndexError, ValueError):
            # Unparseable (e.g. custom build) version string: assume a modern torch.
            return True
        return found >= (major, minor)

    def forward(self, x):
        # Conv1d expects [batch, channels, time]; swap axes in and back out.
        x = self.tokenConv(x.permute(0, 2, 1)).transpose(1, 2)
        return x
d_model): + super(FixedEmbedding, self).__init__() + + w = torch.zeros(c_in, d_model).float() + w.require_grad = False + + position = torch.arange(0, c_in).float().unsqueeze(1) + div_term = (torch.arange(0, d_model, 2).float() + * -(math.log(10000.0) / d_model)).exp() + + w[:, 0::2] = torch.sin(position * div_term) + w[:, 1::2] = torch.cos(position * div_term) + + self.emb = nn.Embedding(c_in, d_model) + self.emb.weight = nn.Parameter(w, requires_grad=False) + + def forward(self, x): + return self.emb(x).detach() + + +class TemporalEmbedding(nn.Module): + def __init__(self, d_model, embed_type='fixed', freq='h'): + super(TemporalEmbedding, self).__init__() + + minute_size = 4 + hour_size = 24 + weekday_size = 7 + day_size = 32 + month_size = 13 + + Embed = FixedEmbedding if embed_type == 'fixed' else nn.Embedding + if freq == 't': + self.minute_embed = Embed(minute_size, d_model) + self.hour_embed = Embed(hour_size, d_model) + self.weekday_embed = Embed(weekday_size, d_model) + self.day_embed = Embed(day_size, d_model) + self.month_embed = Embed(month_size, d_model) + + def forward(self, x): + x = x.long() + minute_x = self.minute_embed(x[:, :, 4]) if hasattr( + self, 'minute_embed') else 0. 
class TimeFeatureEmbedding(nn.Module):
    """
    Linear projection of continuous time features to ``d_model`` channels.

    The number of input features is chosen from the frequency alias:
    h=4, t=5, s=6, ms=7, m=1, a=1, w=2, d=3, b=3.

    Args:
        d_model (int): Output embedding width
        embed_type (str): Accepted for interface compatibility; not used here
        freq (str): Frequency alias. Single-letter aliases are matched
            case-insensitively (e.g. 'T' is treated as 't').

    Raises:
        KeyError: If ``freq`` is not a supported alias.
    """

    def __init__(self, d_model, embed_type='timeF', freq='h'):
        super(TimeFeatureEmbedding, self).__init__()

        # 'ms' mirrors the table used by the TimeMixer++ copy of this class.
        freq_map = {'h': 4, 't': 5, 's': 6, 'ms': 7,
                    'm': 1, 'a': 1, 'w': 2, 'd': 3, 'b': 3}
        key = freq
        # Only single letters are case-folded: folding 'MS' (month start in
        # pandas) would silently turn it into 'ms' (milliseconds).
        if key not in freq_map and isinstance(key, str) and len(key) == 1:
            key = key.lower()
        if key not in freq_map:
            raise KeyError(
                f"Unsupported freq {freq!r}; expected one of {sorted(freq_map)}")
        d_inp = freq_map[key]
        self.embed = nn.Linear(d_inp, d_model, bias=False)

    def forward(self, x):
        return self.embed(x)
class PatchEmbedding(nn.Module):
    """
    Cut each series into overlapping patches and embed every patch.

    Args:
        d_model: Embedding width of each patch
        patch_len: Number of time steps per patch
        stride: Step between the starts of consecutive patches
        padding: Number of replicated values appended to the end of the series
        dropout: Dropout probability applied to the embedded patches
    """

    def __init__(self, d_model, patch_len, stride, padding, dropout):
        super().__init__()
        self.patch_len = patch_len
        self.stride = stride
        # Right-pad only, by repeating the last value
        self.padding_patch_layer = nn.ReplicationPad1d((0, padding))
        # Linear projection of one patch onto a d_model-dimensional vector
        self.value_embedding = nn.Linear(patch_len, d_model, bias=False)
        self.position_embedding = PositionalEmbedding(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return ``(embedded_patches, n_vars)`` where ``n_vars`` is ``x.shape[1]``."""
        n_vars = x.shape[1]
        padded = self.padding_patch_layer(x)
        patches = padded.unfold(dimension=-1, size=self.patch_len, step=self.stride)
        # Merge the first two axes so every variate is embedded independently
        lead, variates, num_patches, width = patches.shape
        patches = torch.reshape(patches, (lead * variates, num_patches, width))
        embedded = self.value_embedding(patches) + self.position_embedding(patches)
        return self.dropout(embedded), n_vars
class moving_avg(nn.Module):
    """
    Moving average block to highlight the trend of time series

    The input is indexed as [batch, time, channels]. Both ends of the time
    axis are padded by repeating the first/last step ``(kernel_size - 1) // 2``
    times before average pooling.
    """

    def __init__(self, kernel_size, stride):
        super(moving_avg, self).__init__()
        self.kernel_size = kernel_size
        self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)

    def forward(self, x):
        # padding on the both ends of time series
        pad = (self.kernel_size - 1) // 2
        front = x[:, 0:1, :].repeat(1, pad, 1)
        end = x[:, -1:, :].repeat(1, pad, 1)
        x = torch.cat([front, x, end], dim=1)
        # AvgPool1d pools over the last axis, so move time there and back
        x = self.avg(x.permute(0, 2, 1))
        x = x.permute(0, 2, 1)
        return x


class series_decomp(nn.Module):
    """
    Series decomposition block

    Splits ``x`` into ``(residual, moving_mean)`` with
    ``residual = x - moving_mean``.
    """

    def __init__(self, kernel_size):
        super(series_decomp, self).__init__()
        self.moving_avg = moving_avg(kernel_size, stride=1)

    def forward(self, x):
        moving_mean = self.moving_avg(x)
        res = x - moving_mean
        return res, moving_mean


class series_decomp_multi(nn.Module):
    """
    Multiple Series decomposition block from FEDformer

    Runs one ``series_decomp`` per entry of ``kernel_size`` and averages the
    seasonal parts and the trend parts across kernels.

    Args:
        kernel_size (iterable of int): Moving-average window sizes
    """

    def __init__(self, kernel_size):
        super(series_decomp_multi, self).__init__()
        self.kernel_size = kernel_size
        # ModuleList (not a plain list) so the sub-blocks are registered as
        # children and follow .to()/.train()/.eval() of the parent module.
        self.series_decomp = nn.ModuleList(
            [series_decomp(kernel) for kernel in kernel_size])

    def forward(self, x):
        seasonal_parts = []
        trend_parts = []
        for decomp in self.series_decomp:
            seasonal, trend = decomp(x)
            seasonal_parts.append(seasonal)
            trend_parts.append(trend)

        sea = sum(seasonal_parts) / len(seasonal_parts)
        moving_mean = sum(trend_parts) / len(trend_parts)
        return sea, moving_mean
class Encoder(nn.Module):
    """
    Autoformer encoder

    Applies the attention layers in order, optionally interleaved with conv
    layers, then an optional final norm. Returns the encoded tensor together
    with the list of attention outputs collected from every attention layer.
    """

    def __init__(self, attn_layers, conv_layers=None, norm_layer=None):
        super().__init__()
        self.attn_layers = nn.ModuleList(attn_layers)
        self.conv_layers = None if conv_layers is None else nn.ModuleList(conv_layers)
        self.norm = norm_layer

    def forward(self, x, attn_mask=None):
        collected = []

        if self.conv_layers is None:
            for layer in self.attn_layers:
                x, weights = layer(x, attn_mask=attn_mask)
                collected.append(weights)
        else:
            # Each conv layer is paired with the attention layer in front of it.
            for layer, conv in zip(self.attn_layers, self.conv_layers):
                x, weights = layer(x, attn_mask=attn_mask)
                x = conv(x)
                collected.append(weights)
            # The last attention layer runs after the pairs and is called
            # without the attention mask.
            x, weights = self.attn_layers[-1](x)
            collected.append(weights)

        if self.norm is not None:
            x = self.norm(x)

        return x, collected
class Decoder(nn.Module):
    """
    Autoformer decoder

    Runs the decoder layers in order, accumulating the trend component each
    layer returns, then applies the optional norm and projection to ``x``.

    Args:
        layers: Iterable of decoder layers; each is called as
            ``layer(x, cross, x_mask=..., cross_mask=...)`` and must return
            ``(x, residual_trend)``
        norm_layer: Optional module applied to ``x`` after the last layer
        projection: Optional module applied to ``x`` after the norm
    """

    def __init__(self, layers, norm_layer=None, projection=None):
        super(Decoder, self).__init__()
        self.layers = nn.ModuleList(layers)
        self.norm = norm_layer
        self.projection = projection

    def forward(self, x, cross, x_mask=None, cross_mask=None, trend=None):
        """
        Returns:
            tuple: ``(x, trend)`` where ``trend`` is the initial trend plus the
            residual trends of all layers. With ``trend=None`` the
            accumulation starts from the first layer's residual trend.
        """
        for layer in self.layers:
            x, residual_trend = layer(x, cross, x_mask=x_mask, cross_mask=cross_mask)
            # The default ``trend=None`` cannot be added to a tensor, so treat
            # it as "no initial trend" instead of raising a TypeError.
            trend = residual_trend if trend is None else trend + residual_trend

        if self.norm is not None:
            x = self.norm(x)

        if self.projection is not None:
            x = self.projection(x)
        return x, trend
class FixedEmbedding(nn.Module):
    """
    Non-trainable sinusoidal lookup table.

    Row ``p`` of the table holds ``sin(p * w_k)`` in the even columns and
    ``cos(p * w_k)`` in the odd columns, with
    ``w_k = exp(-2k * ln(10000) / d_model)``.

    Args:
        c_in (int): Number of rows (distinct indices) in the table
        d_model (int): Embedding width; odd widths are supported
    """

    def __init__(self, c_in, d_model):
        super(FixedEmbedding, self).__init__()

        w = torch.zeros(c_in, d_model).float()

        position = torch.arange(0, c_in).float().unsqueeze(1)
        div_term = (torch.arange(0, d_model, 2).float()
                    * -(math.log(10000.0) / d_model)).exp()

        w[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer cosine column than sine columns.
        w[:, 1::2] = torch.cos(position * div_term[: d_model // 2])

        self.emb = nn.Embedding(c_in, d_model)
        # requires_grad=False on the Parameter is what actually freezes the table.
        self.emb.weight = nn.Parameter(w, requires_grad=False)

    def forward(self, x):
        return self.emb(x).detach()
class TimeFeatureEmbedding(nn.Module):
    """
    Linear projection of continuous time features to ``d_model`` channels.

    The number of input features is chosen from the frequency alias:
    h=4, t=5, s=6, ms=7, m=1, a=1, w=2, d=3, b=3.

    Args:
        d_model (int): Output embedding width
        embed_type (str): Accepted for interface compatibility; not used here
        freq (str): Frequency alias. Single-letter aliases are matched
            case-insensitively (e.g. 'T' is treated as 't').

    Raises:
        KeyError: If ``freq`` is not a supported alias.
    """

    def __init__(self, d_model, embed_type='timeF', freq='h'):
        super(TimeFeatureEmbedding, self).__init__()

        freq_map = {'h': 4, 't': 5, 's': 6, 'ms': 7,
                    'm': 1, 'a': 1, 'w': 2, 'd': 3, 'b': 3}
        key = freq
        # Only single letters are case-folded: folding 'MS' (month start in
        # pandas) would silently turn it into 'ms' (milliseconds).
        if key not in freq_map and isinstance(key, str) and len(key) == 1:
            key = key.lower()
        if key not in freq_map:
            raise KeyError(
                f"Unsupported freq {freq!r}; expected one of {sorted(freq_map)}")
        d_inp = freq_map[key]
        self.embed = nn.Linear(d_inp, d_model, bias=False)

    def forward(self, x):
        return self.embed(x)
class DataEmbedding_ms(nn.Module):
    """
    Channel-independent data embedding: every variate is embedded on its own
    with a shared single-channel ``TokenEmbedding``.

    Args:
        c_in (int): Accepted for interface compatibility; the value embedding
            always uses one input channel
        d_model (int): Embedding width
        embed_type (str): 'timeF' selects ``TimeFeatureEmbedding``, anything
            else ``TemporalEmbedding``
        freq (str): Frequency alias forwarded to the temporal embedding
        dropout (float): Dropout probability applied to the result

    ``forward`` maps ``x`` of shape [B, T, N] to [B, T, N, d_model].
    """

    def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1):
        super(DataEmbedding_ms, self).__init__()

        self.value_embedding = TokenEmbedding(c_in=1, d_model=d_model)
        # Not used in forward(); kept so the module's buffers/state_dict layout
        # stays unchanged.
        self.position_embedding = PositionalEmbedding(d_model=d_model)
        self.temporal_embedding = TemporalEmbedding(d_model=d_model, embed_type=embed_type,
                                                    freq=freq) if embed_type != 'timeF' else TimeFeatureEmbedding(
            d_model=d_model, embed_type=embed_type, freq=freq)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, x_mark):
        B, T, N = x.shape
        # [B, T, N] -> [B, N, T] -> [B*N, T, 1]: one single-channel series per
        # variate. This must be a permute; reshape(0, 2, 1) asks for a
        # zero-element tensor and fails on any non-empty input.
        series = x.permute(0, 2, 1).reshape(B * N, T).unsqueeze(-1)
        # [B*N, T, d] -> [B, N, T, d] -> [B, T, N, d]
        x1 = self.value_embedding(series).reshape(B, N, T, -1).permute(0, 2, 1, 3)
        if x_mark is None:
            x = x1
        else:
            # The temporal embedding is [B, T, d]; add a variate axis so it
            # broadcasts across N instead of mis-aligning with [B, T, N, d].
            x = x1 + self.temporal_embedding(x_mark).unsqueeze(2)
        return self.dropout(x)
class Normalize(nn.Module):
    def __init__(self, num_features: int, eps=1e-5, affine=False, subtract_last=False, non_norm=False):
        """
        :param num_features: the number of features or channels
        :param eps: a value added for numerical stability
        :param affine: if True, RevIN has learnable affine parameters
        :param subtract_last: if True, centre on the last time step instead of the mean
        :param non_norm: if True, both 'norm' and 'denorm' return the input unchanged
        """
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.affine = affine
        self.subtract_last = subtract_last
        self.non_norm = non_norm
        if self.affine:
            self._init_params()

    def forward(self, x, mode: str):
        # 'norm' records the statistics of x; 'denorm' reuses the stored ones.
        if mode == 'norm':
            self._get_statistics(x)
            return self._normalize(x)
        if mode == 'denorm':
            return self._denormalize(x)
        raise NotImplementedError

    def _init_params(self):
        # initialize RevIN params: (C,)
        self.affine_weight = nn.Parameter(torch.ones(self.num_features))
        self.affine_bias = nn.Parameter(torch.zeros(self.num_features))

    def _get_statistics(self, x):
        # Reduce over every axis except the first and the last one.
        reduce_dims = tuple(range(1, x.ndim - 1))
        if self.subtract_last:
            self.last = x[:, -1, :].unsqueeze(1)
        else:
            self.mean = torch.mean(x, dim=reduce_dims, keepdim=True).detach()
        variance = torch.var(x, dim=reduce_dims, keepdim=True, unbiased=False)
        self.stdev = torch.sqrt(variance + self.eps).detach()

    def _normalize(self, x):
        if self.non_norm:
            return x
        centre = self.last if self.subtract_last else self.mean
        x = x - centre
        x = x / self.stdev
        if not self.affine:
            return x
        x = x * self.affine_weight
        return x + self.affine_bias

    def _denormalize(self, x):
        if self.non_norm:
            return x
        if self.affine:
            x = x - self.affine_bias
            # eps**2 keeps the division finite when a learned weight reaches zero
            x = x / (self.affine_weight + self.eps * self.eps)
        x = x * self.stdev
        offset = self.last if self.subtract_last else self.mean
        return x + offset
class DFT_series_decomp(nn.Module):
    """
    Series decomposition block based on the real FFT of the last axis.

    Spectral bins whose magnitude is strictly above the smallest of the
    per-row top-k magnitudes form the seasonal part; the remainder of the
    signal (including the DC component) is returned as the trend.
    """

    def __init__(self, top_k=5):
        super(DFT_series_decomp, self).__init__()
        self.top_k = top_k

    def forward(self, x):
        """Return ``(x_season, x_trend)``, both shaped like ``x``."""
        xf = torch.fft.rfft(x)
        freq = xf.abs()
        # The FFT runs over the LAST axis, so the zero-frequency bin is
        # ``[..., 0]``.  ``freq[0] = 0`` wiped the first batch entry instead
        # (both forms coincide for 1-D input).
        freq[..., 0] = 0
        # Never ask topk for more bins than exist.
        k = min(self.top_k, freq.shape[-1])
        top_k_freq, _ = torch.topk(freq, k)
        xf[freq <= top_k_freq.min()] = 0
        # Pass the original length so odd-length inputs round-trip to the same
        # shape (the default would return 2 * (bins - 1) samples).
        x_season = torch.fft.irfft(xf, n=x.shape[-1])
        x_trend = x - x_season
        return x_season, x_trend
class PastDecomposableMixing(nn.Module):
    """Decompose every scale into season/trend, mix across scales, recombine.

    ``forward`` takes a list with one tensor per scale and returns a list of
    the same length; each output is cropped to the time length of its input.
    """

    def __init__(self, configs):
        super().__init__()
        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len
        self.down_sampling_window = configs.down_sampling_window

        self.layer_norm = nn.LayerNorm(configs.d_model)
        self.dropout = nn.Dropout(configs.dropout)
        self.channel_independence = configs.channel_independence

        # Season/trend splitter selected by name.
        method = configs.decomp_method
        if method == 'moving_avg':
            self.decompsition = series_decomp(configs.moving_avg)
        elif method == "dft_decomp":
            self.decompsition = DFT_series_decomp(configs.top_k)
        else:
            raise ValueError('decompsition is error')

        # Feature-mixing MLP, only built when channels are modelled jointly.
        if configs.channel_independence == 0:
            self.cross_layer = nn.Sequential(
                nn.Linear(in_features=configs.d_model, out_features=configs.d_ff),
                nn.GELU(),
                nn.Linear(in_features=configs.d_ff, out_features=configs.d_model),
            )

        # Bottom-up seasonal mixing and top-down trend mixing.
        self.mixing_multi_scale_season = MultiScaleSeasonMixing(configs)
        self.mixing_multi_scale_trend = MultiScaleTrendMixing(configs)

        self.out_cross_layer = nn.Sequential(
            nn.Linear(in_features=configs.d_model, out_features=configs.d_ff),
            nn.GELU(),
            nn.Linear(in_features=configs.d_ff, out_features=configs.d_model),
        )

    def forward(self, x_list):
        # Time length of each scale (inputs must be 3-D), used to crop outputs.
        length_list = [steps for _, steps, _ in (x.size() for x in x_list)]

        # Split every scale and move time to the last axis for the mixers.
        season_list, trend_list = [], []
        for x in x_list:
            season, trend = self.decompsition(x)
            if self.channel_independence == 0:
                season = self.cross_layer(season)
                trend = self.cross_layer(trend)
            season_list.append(season.permute(0, 2, 1))
            trend_list.append(trend.permute(0, 2, 1))

        mixed_season = self.mixing_multi_scale_season(season_list)
        mixed_trend = self.mixing_multi_scale_trend(trend_list)

        out_list = []
        for ori, season, trend, length in zip(x_list, mixed_season, mixed_trend, length_list):
            out = season + trend
            if self.channel_independence:
                # Residual connection around the output MLP.
                out = ori + self.out_cross_layer(out)
            out_list.append(out[:, :length, :])
        return out_list
def out_projection(self, dec_out, i, out_res):
    """Project ``dec_out`` and add the regressed residual of scale ``i``.

    ``out_res`` is permuted so its time axis is last, passed through
    ``out_res_layers[i]`` and ``regression_layers[i]``, permuted back and
    added to ``projection_layer(dec_out)``.
    """
    projected = self.projection_layer(dec_out)
    residual = self.out_res_layers[i](out_res.permute(0, 2, 1))
    residual = self.regression_layers[i](residual).permute(0, 2, 1)
    return projected + residual
out1_list = [] + out2_list = [] + for x in x_list: + x_1, x_2 = self.preprocess(x) + out1_list.append(x_1) + out2_list.append(x_2) + return (out1_list, out2_list) + + def __multi_scale_process_inputs(self, x_enc, x_mark_enc): + if self.configs.down_sampling_method == 'max': + down_pool = torch.nn.MaxPool1d(self.configs.down_sampling_window, return_indices=False) + elif self.configs.down_sampling_method == 'avg': + down_pool = torch.nn.AvgPool1d(self.configs.down_sampling_window) + elif self.configs.down_sampling_method == 'conv': + padding = 1 if torch.__version__ >= '1.5.0' else 2 + down_pool = nn.Conv1d(in_channels=self.configs.enc_in, out_channels=self.configs.enc_in, + kernel_size=3, padding=padding, + stride=self.configs.down_sampling_window, + padding_mode='circular', + bias=False) + else: + return x_enc, x_mark_enc + # B,T,C -> B,C,T + x_enc = x_enc.permute(0, 2, 1) + + x_enc_ori = x_enc + x_mark_enc_mark_ori = x_mark_enc + + x_enc_sampling_list = [] + x_mark_sampling_list = [] + x_enc_sampling_list.append(x_enc.permute(0, 2, 1)) + x_mark_sampling_list.append(x_mark_enc) + + for i in range(self.configs.down_sampling_layers): + x_enc_sampling = down_pool(x_enc_ori) + + x_enc_sampling_list.append(x_enc_sampling.permute(0, 2, 1)) + x_enc_ori = x_enc_sampling + + if x_mark_enc_mark_ori is not None: + x_mark_sampling_list.append(x_mark_enc_mark_ori[:, ::self.configs.down_sampling_window, :]) + x_mark_enc_mark_ori = x_mark_enc_mark_ori[:, ::self.configs.down_sampling_window, :] + + x_enc = x_enc_sampling_list + if x_mark_enc_mark_ori is not None: + x_mark_enc = x_mark_sampling_list + else: + x_mark_enc = x_mark_enc + + return x_enc, x_mark_enc + + def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec): + + if self.use_future_temporal_feature: + if self.channel_independence == 1: + B, T, N = x_enc.size() + x_mark_dec = x_mark_dec.repeat(N, 1, 1) + self.x_mark_dec = self.enc_embedding(None, x_mark_dec) + else: + self.x_mark_dec = self.enc_embedding(None, 
def future_multi_mixing(self, B, enc_out_list, x_list):
    """Turn each scale's encoder output into a forecast of length ``pred_len``.

    ``x_list`` is the pair produced by ``pre_enc``.  With channel
    independence only ``x_list[0]`` is used (to bound the number of scales)
    and each projected output is reshaped to ``(B, pred_len, c_out)``;
    otherwise ``x_list[1]`` supplies the residual handed to ``out_projection``.
    Returns one tensor per scale.
    """

    def to_horizon(idx, enc_out):
        # Map the time axis of scale ``idx`` onto the prediction horizon.
        return self.predict_layers[idx](enc_out.permute(0, 2, 1)).permute(0, 2, 1)

    if self.channel_independence == 1:
        scales = x_list[0]
        outputs = []
        for idx, enc_out in zip(range(len(scales)), enc_out_list):
            dec_out = to_horizon(idx, enc_out)
            if self.use_future_temporal_feature:
                dec_out = dec_out + self.x_mark_dec
            dec_out = self.projection_layer(dec_out)
            dec_out = dec_out.reshape(B, self.configs.c_out, self.pred_len)
            outputs.append(dec_out.permute(0, 2, 1).contiguous())
        return outputs

    return [
        self.out_projection(to_horizon(idx, enc_out), idx, out_res)
        for idx, enc_out, out_res in zip(range(len(x_list[0])), enc_out_list, x_list[1])
    ]
= self.normalize_layers[0](dec_out, 'denorm') + return dec_out + + def imputation(self, x_enc, x_mark_enc, mask): + means = torch.sum(x_enc, dim=1) / torch.sum(mask == 1, dim=1) + means = means.unsqueeze(1).detach() + x_enc = x_enc - means + x_enc = x_enc.masked_fill(mask == 0, 0) + stdev = torch.sqrt(torch.sum(x_enc * x_enc, dim=1) / + torch.sum(mask == 1, dim=1) + 1e-5) + stdev = stdev.unsqueeze(1).detach() + x_enc /= stdev + + B, T, N = x_enc.size() + x_enc, x_mark_enc = self.__multi_scale_process_inputs(x_enc, x_mark_enc) + + x_list = [] + x_mark_list = [] + if x_mark_enc is not None: + for i, x, x_mark in zip(range(len(x_enc)), x_enc, x_mark_enc): + B, T, N = x.size() + if self.channel_independence == 1: + x = x.permute(0, 2, 1).contiguous().reshape(B * N, T, 1) + x_list.append(x) + x_mark = x_mark.repeat(N, 1, 1) + x_mark_list.append(x_mark) + else: + for i, x in zip(range(len(x_enc)), x_enc, ): + B, T, N = x.size() + if self.channel_independence == 1: + x = x.permute(0, 2, 1).contiguous().reshape(B * N, T, 1) + x_list.append(x) + + # embedding + enc_out_list = [] + for x in x_list: + enc_out = self.enc_embedding(x, None) # [B,T,C] + enc_out_list.append(enc_out) + + # MultiScale-CrissCrossAttention as encoder for past + for i in range(self.layer): + enc_out_list = self.pdm_blocks[i](enc_out_list) + + dec_out = self.projection_layer(enc_out_list[0]) + dec_out = dec_out.reshape(B, self.configs.c_out, -1).permute(0, 2, 1).contiguous() + + dec_out = dec_out * \ + (stdev[:, 0, :].unsqueeze(1).repeat(1, self.seq_len, 1)) + dec_out = dec_out + \ + (means[:, 0, :].unsqueeze(1).repeat(1, self.seq_len, 1)) + return dec_out + + def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None): + if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast': + dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec) + return dec_out + if self.task_name == 'imputation': + dec_out = self.imputation(x_enc, x_mark_enc, mask) + return dec_out # 
def FFT_for_Period(x, k=2):
    """Find the ``k`` dominant periods of ``x`` along its second axis.

    Args:
        x: Tensor indexed as [B, T, C].
        k: Number of frequencies to keep.

    Returns:
        ``(period, weight)``: ``period`` is a NumPy integer array holding
        ``T // frequency_index`` for the ``k`` strongest non-DC frequencies
        (strength averaged over batch and channels); ``weight`` is a ``[B, k]``
        tensor with each sample's channel-averaged amplitude at those
        frequencies.
    """
    amplitude = abs(torch.fft.rfft(x, dim=1))
    # Rank frequencies by amplitude averaged over batch and channels.
    strength = amplitude.mean(0).mean(-1)
    strength[0] = 0  # ignore the DC component
    top_list = torch.topk(strength, k)[1].detach().cpu().numpy()
    period = x.shape[1] // top_list
    return period, amplitude.mean(-1)[:, top_list]
class Model(nn.Module):
    """
    TimesNet model with task-specific heads selected by ``configs.task_name``.

    Paper link: https://openreview.net/pdf?id=ju_Uqw384Oq
    """

    def __init__(self, configs):
        super(Model, self).__init__()
        self.configs = configs
        self.task_name = configs.task_name
        self.seq_len = configs.seq_len
        self.label_len = configs.label_len
        self.pred_len = configs.pred_len
        self.model = nn.ModuleList([TimesBlock(configs)
                                    for _ in range(configs.e_layers)])
        self.enc_embedding = DataEmbedding(configs.enc_in, configs.d_model, configs.embed, configs.freq,
                                           configs.dropout)
        self.layer = configs.e_layers
        self.layer_norm = nn.LayerNorm(configs.d_model)
        if self.task_name in ('long_term_forecast', 'short_term_forecast'):
            # Extends the time axis from seq_len to seq_len + pred_len.
            self.predict_linear = nn.Linear(
                self.seq_len, self.pred_len + self.seq_len)
            self.projection = nn.Linear(
                configs.d_model, configs.c_out, bias=True)
        if self.task_name in ('imputation', 'anomaly_detection'):
            self.projection = nn.Linear(
                configs.d_model, configs.c_out, bias=True)
        if self.task_name == 'classification':
            self.act = F.gelu
            self.dropout = nn.Dropout(configs.dropout)
            self.projection = nn.Linear(
                configs.d_model * configs.seq_len, configs.num_class)

    def _run_blocks(self, enc_out):
        """Apply every TimesBlock, each followed by the shared layer norm."""
        for i in range(self.layer):
            enc_out = self.layer_norm(self.model[i](enc_out))
        return enc_out

    def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
        """Return a de-normalised series of length ``seq_len + pred_len``.

        ``x_dec`` and ``x_mark_dec`` are accepted for interface
        compatibility and are not used.
        """
        # Normalization from Non-stationary Transformer
        means = x_enc.mean(1, keepdim=True).detach()
        x_enc = x_enc.sub(means)
        stdev = torch.sqrt(
            torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
        x_enc = x_enc.div(stdev)

        # embedding
        enc_out = self.enc_embedding(x_enc, x_mark_enc)  # [B,T,C]
        enc_out = self.predict_linear(enc_out.permute(0, 2, 1)).permute(
            0, 2, 1)  # align temporal dimension

        dec_out = self.projection(self._run_blocks(enc_out))

        # De-normalisation: the (B, 1, C) statistics broadcast over time.
        return dec_out * stdev + means

    def imputation(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask):
        """Reconstruct ``x_enc``, standardising with the count of ``mask == 1``."""
        # Normalization from Non-stationary Transformer
        observed = torch.sum(mask == 1, dim=1)
        means = torch.sum(x_enc, dim=1) / observed
        means = means.unsqueeze(1).detach()
        x_enc = x_enc.sub(means)
        x_enc = x_enc.masked_fill(mask == 0, 0)
        stdev = torch.sqrt(torch.sum(x_enc * x_enc, dim=1) / observed + 1e-5)
        stdev = stdev.unsqueeze(1).detach()
        x_enc = x_enc.div(stdev)

        enc_out = self.enc_embedding(x_enc, x_mark_enc)  # [B,T,C]
        dec_out = self.projection(self._run_blocks(enc_out))

        # The output keeps the input length here (no predict_linear), so the
        # statistics are broadcast instead of being repeated to
        # pred_len + seq_len, which only matched when pred_len == 0.
        return dec_out * stdev + means

    def anomaly_detection(self, x_enc):
        """Reconstruct ``x_enc`` without time marks."""
        # Normalization from Non-stationary Transformer
        means = x_enc.mean(1, keepdim=True).detach()
        x_enc = x_enc.sub(means)
        stdev = torch.sqrt(
            torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
        x_enc = x_enc.div(stdev)

        enc_out = self.enc_embedding(x_enc, None)  # [B,T,C]
        dec_out = self.projection(self._run_blocks(enc_out))

        # Same length as the input; broadcast the statistics (see imputation).
        return dec_out * stdev + means

    def classification(self, x_enc, x_mark_enc):
        """Return class logits; ``x_mark_enc`` multiplies the features as a padding mask."""
        enc_out = self._run_blocks(self.enc_embedding(x_enc, None))  # [B,T,C]

        # Output
        # the output transformer encoder/decoder embeddings don't include non-linearity
        output = self.act(enc_out)
        output = self.dropout(output)
        # zero-out padding embeddings
        output = output * x_mark_enc.unsqueeze(-1)
        # (batch_size, seq_length * d_model)
        output = output.reshape(output.shape[0], -1)
        return self.projection(output)  # (batch_size, num_classes)

    def forward(self, x_enc, x_mark_enc=None, x_dec=None, x_mark_dec=None, mask=None):
        """Dispatch on ``task_name``; returns ``None`` for an unrecognised task."""
        if self.task_name in ('long_term_forecast', 'short_term_forecast'):
            dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
            return dec_out[:, -self.pred_len:, :]  # [B, L, D]
        if self.task_name == 'imputation':
            return self.imputation(
                x_enc, x_mark_enc, x_dec, x_mark_dec, mask)  # [B, L, D]
        if self.task_name == 'anomaly_detection':
            return self.anomaly_detection(x_enc)  # [B, L, D]
        if self.task_name == 'classification':
            return self.classification(x_enc, x_mark_enc)  # [B, N]
        return None
def main():
    """Pre-process ETTm1/ETTm2 for every prediction length and save NPZ files.

    For each dataset found under ``data/ETT-small`` one output file per
    prediction length is written to ``processed_data``.  A failure on one
    combination is reported and the remaining ones are still processed.
    """
    import pandas as pd

    # Configuration
    datasets = ['ETTm1', 'ETTm2']
    input_len = 96
    pred_lengths = [96, 192, 336, 720]
    slide_step = 1

    # Split ratios (train:test:val = 6:2:2)
    train_ratio = 0.6
    test_ratio = 0.2
    val_ratio = 0.2

    # Base paths
    data_dir = 'data/ETT-small'
    output_dir = 'processed_data'

    # Create output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)

    print("Starting ETT dataset processing...")
    print(f"Input length: {input_len}")
    print(f"Split ratios - Train: {train_ratio}, Test: {test_ratio}, Val: {val_ratio}")
    print("-" * 60)

    for dataset in datasets:
        csv_path = os.path.join(data_dir, f"{dataset}.csv")

        if not os.path.exists(csv_path):
            print(f"Warning: {csv_path} not found, skipping...")
            continue

        print(f"\nProcessing {dataset}...")

        # The column names do not depend on the prediction length: read only
        # the header, once per dataset, instead of parsing the whole CSV again
        # for every prediction length.
        try:
            header = pd.read_csv(csv_path, nrows=0)
        except Exception as e:
            print(f" Error reading columns from {csv_path}: {e}")
            continue
        # Every column except the first one (the date column).
        feature_columns = header.columns[1:].tolist()

        for pred_len in pred_lengths:
            output_file = os.path.join(output_dir, f"{dataset}_input{input_len}_pred{pred_len}.npz")

            print(f" - Prediction length {pred_len} -> {output_file}")

            try:
                print(f" Features: {feature_columns} (excluding date column)")

                # NOTE(review): freq='h' is passed although the file names
                # suggest minute-level data - confirm the intended frequency.
                result = process_and_save_time_series(
                    csv_path=csv_path,
                    output_file=output_file,
                    input_len=input_len,
                    pred_len=pred_len,
                    slide_step=slide_step,
                    train_ratio=train_ratio,
                    test_ratio=test_ratio,
                    val_ratio=val_ratio,
                    selected_columns=feature_columns,
                    date_column='date',
                    freq='h'
                )

                # Print dataset shapes for verification
                print(f" Train: {result['train_x'].shape} -> {result['train_y'].shape}")
                print(f" Test: {result['test_x'].shape} -> {result['test_y'].shape}")
                print(f" Val: {result['val_x'].shape} -> {result['val_y'].shape}")
                print(f" Train time marks: {result['train_x_mark'].shape} -> {result['train_y_mark'].shape}")
                print(f" Test time marks: {result['test_x_mark'].shape} -> {result['test_y_mark'].shape}")
                print(f" Val time marks: {result['val_x_mark'].shape} -> {result['val_y_mark'].shape}")

            except Exception as e:
                # Best effort: report and move on to the next prediction length.
                print(f" Error processing {dataset} with pred_len {pred_len}: {e}")

    print("\n" + "=" * 60)
    print("Processing completed!")
    print(f"Output files saved in: {os.path.abspath(output_dir)}")
def create_data_loaders(data_path: str, batch_size: int = 32) -> Dict[str, DataLoader]:
    """
    Create PyTorch DataLoaders from an NPZ file

    The archive must contain ``train_x``, ``train_y``, ``val_x`` and ``val_y``.
    When all four of ``train_x_mark``, ``train_y_mark``, ``val_x_mark`` and
    ``val_y_mark`` are present, each batch is ``(x, y, x_mark, y_mark)``;
    otherwise each batch is ``(x, y)``.

    Args:
        data_path (str): Path to the NPZ file containing the data
        batch_size (int): Batch size for the DataLoaders

    Returns:
        Dict[str, DataLoader]: Dictionary with ``'train'`` (shuffled) and
        ``'val'`` (in order) DataLoaders

    Raises:
        KeyError: If one of the four required arrays is missing.
    """
    mark_keys = ('train_x_mark', 'train_y_mark', 'val_x_mark', 'val_y_mark')

    # SECURITY NOTE: allow_pickle=True lets a crafted archive execute arbitrary
    # code when an object array is read.  Only load files produced by this
    # project; switch to allow_pickle=False if no object arrays are stored.
    # The context manager closes the archive's file handle once the arrays
    # have been copied into tensors.
    with np.load(data_path, allow_pickle=True) as data:
        available = set(data.files)

        def as_tensor(key):
            return torch.FloatTensor(data[key])

        train_tensors = [as_tensor('train_x'), as_tensor('train_y')]
        val_tensors = [as_tensor('val_x'), as_tensor('val_y')]

        # ``data.files`` is used instead of ``NpzFile.get``, which older NumPy
        # releases do not provide.  Time marks are only used when the full set
        # is present, so a partial archive cannot crash the conversion.
        if all(key in available for key in mark_keys):
            train_tensors += [as_tensor('train_x_mark'), as_tensor('train_y_mark')]
            val_tensors += [as_tensor('val_x_mark'), as_tensor('val_y_mark')]

    train_dataset = TensorDataset(*train_tensors)
    val_dataset = TensorDataset(*val_tensors)

    return {
        'train': DataLoader(train_dataset, batch_size=batch_size, shuffle=True),
        'val': DataLoader(val_dataset, batch_size=batch_size, shuffle=False),
    }
project=project_name, + config=config, + ) + + # Create checkpoint directory if it doesn't exist + os.makedirs(checkpoint_dir, exist_ok=True) + checkpoint_path = os.path.join(checkpoint_dir, f"{project_name}.pt") + + # Create data loaders + dataloaders = create_data_loaders( + data_path=data_path, + batch_size=config.get('batch_size', 32) + ) + + # Construct the model + model = model_constructor() + model = model.to(device) + + # Define loss function and optimizer + criterion = nn.MSELoss() + optimizer = optim.Adam( + model.parameters(), + lr=config.get('learning_rate', 1e-3), + ) + + # Add learning rate scheduler to halve LR after each epoch + scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5) + + # Initialize early stopping + early_stopping = EarlyStopping( + patience=early_stopping_patience, + verbose=True, + path=checkpoint_path + ) + + # Training loop + best_val_loss = float('inf') + metrics = {} + + for epoch in range(max_epochs): + print(f"Epoch {epoch+1}/{max_epochs}") + + # Training phase + model.train() + print("1\n") + train_loss = 0.0 + + # 用于记录 log_interval 期间的损失 + interval_loss = 0.0 + start_time = time.time() + + for batch_idx, batch_data in enumerate(dataloaders['train']): + # Handle both cases: with and without time features + if len(batch_data) == 4: # With time features + inputs, targets, x_mark, y_mark = batch_data + inputs, targets = inputs.to(device), targets.to(device) + x_mark, y_mark = x_mark.to(device), y_mark.to(device) + else: # Without time features + inputs, targets = batch_data + inputs, targets = inputs.to(device), targets.to(device) + x_mark, y_mark = None, None + + # Zero the parameter gradients + optimizer.zero_grad() + + # Forward pass - handle both cases + if x_mark is not None: + # For TimesNet model with time features + # Create decoder input (zeros for forecasting) + dec_inp = torch.zeros_like(targets).to(device) + outputs = model(inputs, x_mark) + else: + # For simple models without time features + 
outputs = model(inputs) + + loss = criterion(outputs, targets) + + # Backward pass and optimize + loss.backward() + optimizer.step() + + # Update statistics + train_loss += loss.item() + interval_loss += loss.item() + + if (batch_idx + 1) % log_interval == 0: + print(f"Batch {batch_idx+1}/{len(dataloaders['train'])}, Loss: {loss.item():.4f}") + # 计算这一个 interval 的平均损失并记录 + avg_interval_loss = interval_loss / log_interval + swanlab_run.log({"batch_train_loss": avg_interval_loss}) + + # 重置 interval loss 以进行下一次计算 + interval_loss = 0.0 + + avg_train_loss = train_loss / len(dataloaders['train']) + epoch_time = time.time() - start_time + + # Validation phase + model.eval() + val_loss = 0.0 + val_mse = 0.0 + + with torch.no_grad(): + for batch_data in dataloaders['val']: + # Handle both cases: with and without time features + if len(batch_data) == 4: # With time features + inputs, targets, x_mark, y_mark = batch_data + inputs, targets = inputs.to(device), targets.to(device) + x_mark, y_mark = x_mark.to(device), y_mark.to(device) + else: # Without time features + inputs, targets = batch_data + inputs, targets = inputs.to(device), targets.to(device) + x_mark, y_mark = None, None + + # Forward pass - handle both cases + if x_mark is not None: + # For TimesNet model with time features + dec_inp = torch.zeros_like(targets).to(device) + outputs = model(inputs, x_mark) + else: + # For simple models without time features + outputs = model(inputs) + + # Calculate loss + loss = criterion(outputs, targets) + val_loss += loss.item() + + + avg_val_loss = val_loss / len(dataloaders['val']) + current_lr = optimizer.param_groups[0]['lr'] + + # Log metrics + metrics_dict = { + "train_loss": avg_train_loss, + "val_loss": avg_val_loss, + "learning_rate": current_lr, + "epoch_time": epoch_time + } + + swanlab_run.log(metrics_dict) + + print(f"Epoch {epoch+1}/{max_epochs}, " + f"Train Loss: {avg_train_loss:.4f}, " + f"Val Loss: {avg_val_loss:.4f}, " + f"LR: {current_lr:.6f}, " + f"Time: 
{epoch_time:.2f}s") + + # Check if we should save the model + if avg_val_loss < best_val_loss: + best_val_loss = avg_val_loss + metrics = metrics_dict + + # Early stopping + early_stopping(avg_val_loss, model) + if early_stopping.early_stop: + print("Early stopping triggered") + break + + # Step the learning rate scheduler + scheduler.step() + + # Load the best model + model.load_state_dict(torch.load(checkpoint_path)) + + # Final validation + model.eval() + final_val_loss = 0.0 + final_val_mse = 0.0 + + with torch.no_grad(): + for batch_data in dataloaders['val']: + # Handle both cases: with and without time features + if len(batch_data) == 4: # With time features + inputs, targets, x_mark, y_mark = batch_data + inputs, targets = inputs.to(device), targets.to(device) + x_mark, y_mark = x_mark.to(device), y_mark.to(device) + else: # Without time features + inputs, targets = batch_data + inputs, targets = inputs.to(device), targets.to(device) + x_mark, y_mark = None, None + + # Forward pass - handle both cases + if x_mark is not None: + # For TimesNet model with time features + dec_inp = torch.zeros_like(targets).to(device) + outputs = model(inputs, x_mark, dec_inp, y_mark) + else: + # For simple models without time features + outputs = model(inputs) + + # Calculate loss + loss = criterion(outputs, targets) + final_val_loss += loss.item() + + + final_val_loss /= len(dataloaders['val']) + + print(f"Final validation loss: {final_val_loss:.4f}") + + # Update metrics with final values + metrics["final_val_loss"] = final_val_loss + + # Finish the swanlab run + swanlab_run.finish() + + return model, metrics + +def train_classification_model( + model_constructor: Callable, + data_path: str, + project_name: str, + config: Dict[str, Any], + device: Optional[str] = None, + early_stopping_patience: int = 10, + max_epochs: int = 100, + checkpoint_dir: str = "./checkpoints", + log_interval: int = 10 +) -> Tuple[nn.Module, Dict[str, float]]: + """ + Train a time series 
def _classification_unpack_batch(batch_data, device):
    """Move one batch to ``device``; return ``(inputs, targets, x_mark, y_mark)``.

    Targets are cast to ``long`` as required by ``nn.CrossEntropyLoss``. The
    two mark tensors are ``None`` when the dataset carries no time features.
    """
    if len(batch_data) == 4:
        inputs, targets, x_mark, y_mark = batch_data
        x_mark, y_mark = x_mark.to(device), y_mark.to(device)
    else:
        inputs, targets = batch_data
        x_mark, y_mark = None, None
    return inputs.to(device), targets.to(device).long(), x_mark, y_mark


def _classification_forward(model, inputs, targets, x_mark, y_mark):
    """Call the model with or without time features, as the original loops did."""
    if x_mark is not None:
        # Zero placeholder shaped like the targets, passed as decoder input.
        dec_inp = torch.zeros_like(targets)
        return model(inputs, x_mark, dec_inp, y_mark)
    return model(inputs)


def _classification_evaluate(model, loader, criterion, device):
    """Return ``(mean batch loss, accuracy in percent)`` over ``loader``."""
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_data in loader:
            inputs, targets, x_mark, y_mark = _classification_unpack_batch(batch_data, device)
            outputs = _classification_forward(model, inputs, targets, x_mark, y_mark)
            loss_sum += criterion(outputs, targets).item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    return loss_sum / len(loader), 100. * correct / total


def train_classification_model(
    model_constructor: Callable,
    data_path: str,
    project_name: str,
    config: Dict[str, Any],
    device: Optional[str] = None,
    early_stopping_patience: int = 10,
    max_epochs: int = 100,
    checkpoint_dir: str = "./checkpoints",
    log_interval: int = 10
) -> Tuple[nn.Module, Dict[str, float]]:
    """
    Train a time series classification model

    Uses cross-entropy loss with Adam (with weight decay), halves the learning
    rate after every epoch and checkpoints the best validation loss through
    ``EarlyStopping``. After training, the best checkpoint is reloaded and
    re-evaluated on the validation loader.

    Args:
        model_constructor (Callable): Function that constructs and returns the model
        data_path (str): Path to the NPZ file containing the processed data
        project_name (str): Name of the project for swanlab tracking; also the checkpoint file stem
        config (Dict[str, Any]): Configuration dictionary for the experiment
            ('batch_size', 'learning_rate' and 'weight_decay' are read here)
        device (Optional[str]): Device to use for training ('cpu' or 'cuda');
            auto-detected when None
        early_stopping_patience (int): Number of epochs to wait before early stopping
        max_epochs (int): Maximum number of epochs to train for
        checkpoint_dir (str): Directory to save model checkpoints
        log_interval (int): Number of batches between batch-level progress prints

    Returns:
        Tuple[nn.Module, Dict[str, float]]: Trained model (best checkpoint loaded)
        and the metrics of the best epoch plus 'final_val_loss' and
        'final_val_accuracy'
    """
    # Setup device
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Initialize swanlab for experiment tracking
    swanlab_run = swanlab.init(
        project=project_name,
        config=config,
    )

    # try/finally so the tracking run is closed even when training fails.
    try:
        # Create checkpoint directory if it doesn't exist
        os.makedirs(checkpoint_dir, exist_ok=True)
        checkpoint_path = os.path.join(checkpoint_dir, f"{project_name}.pt")

        dataloaders = create_data_loaders(
            data_path=data_path,
            batch_size=config.get('batch_size', 32)
        )
        train_loader = dataloaders['train']
        val_loader = dataloaders['val']

        model = model_constructor()
        model = model.to(device)

        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(
            model.parameters(),
            lr=config.get('learning_rate', 1e-3),
            weight_decay=config.get('weight_decay', 1e-4)
        )

        # Halve the learning rate after each epoch
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)

        early_stopping = EarlyStopping(
            patience=early_stopping_patience,
            verbose=True,
            path=checkpoint_path
        )

        best_val_loss = float('inf')
        metrics = {}

        for epoch in range(max_epochs):
            print(f"Epoch {epoch+1}/{max_epochs}")

            # Training phase
            model.train()
            train_loss = 0.0
            train_correct = 0
            train_total = 0
            start_time = time.time()

            for batch_idx, batch_data in enumerate(train_loader):
                inputs, targets, x_mark, y_mark = _classification_unpack_batch(batch_data, device)

                optimizer.zero_grad()
                outputs = _classification_forward(model, inputs, targets, x_mark, y_mark)
                loss = criterion(outputs, targets)
                loss.backward()
                optimizer.step()

                train_loss += loss.item()
                _, predicted = outputs.max(1)
                train_total += targets.size(0)
                train_correct += predicted.eq(targets).sum().item()

                if (batch_idx + 1) % log_interval == 0:
                    print(f"Batch {batch_idx+1}/{len(train_loader)}, Loss: {loss.item():.4f}")

            avg_train_loss = train_loss / len(train_loader)
            train_accuracy = 100. * train_correct / train_total
            # Epoch time covers the training phase only, not validation
            epoch_time = time.time() - start_time

            # Validation phase
            avg_val_loss, val_accuracy = _classification_evaluate(
                model, val_loader, criterion, device
            )
            current_lr = optimizer.param_groups[0]['lr']

            # train_accuracy used to be computed and then dropped; it is now
            # logged next to the validation accuracy.
            metrics_dict = {
                "train_loss": avg_train_loss,
                "train_accuracy": train_accuracy,
                "val_loss": avg_val_loss,
                "val_accuracy": val_accuracy,
                "learning_rate": current_lr,
                "epoch_time": epoch_time
            }
            swanlab_run.log(metrics_dict)

            print(f"Epoch {epoch+1}/{max_epochs}, "
                  f"Train Loss: {avg_train_loss:.4f}, "
                  f"Train Accuracy: {train_accuracy:.2f}%, "
                  f"Val Loss: {avg_val_loss:.4f}, "
                  f"Val Accuracy: {val_accuracy:.2f}%, "
                  f"LR: {current_lr:.6f}, "
                  f"Time: {epoch_time:.2f}s")

            # Remember the metrics of the best epoch for the return value
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                metrics = dict(metrics_dict)

            # Early stopping (also writes the checkpoint on improvement)
            early_stopping(avg_val_loss, model)
            if early_stopping.early_stop:
                print("Early stopping triggered")
                break

            scheduler.step()

        # Load the best model; map_location keeps the load valid on the
        # device actually in use.
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))

        # Final validation of the reloaded checkpoint
        final_val_loss, final_val_accuracy = _classification_evaluate(
            model, val_loader, criterion, device
        )

        print(f"Final validation loss: {final_val_loss:.4f}")
        print(f"Final validation accuracy: {final_val_accuracy:.2f}%")

        metrics["final_val_loss"] = final_val_loss
        metrics["final_val_accuracy"] = final_val_accuracy
    finally:
        # Finish the swanlab run
        swanlab_run.finish()

    return model, metrics
def main():
    """Example entry point: train a tiny MLP forecaster on ``data/train_data.npz``."""

    def build_model():
        # Minimal 10 -> 50 -> 1 perceptron used purely as a usage example.
        layers = [
            nn.Linear(10, 50),
            nn.ReLU(),
            nn.Linear(50, 1),
        ]
        return nn.Sequential(*layers)

    experiment_config = dict()
    experiment_config['learning_rate'] = 0.001
    experiment_config['batch_size'] = 32
    experiment_config['weight_decay'] = 1e-4

    _model, _metrics = train_forecasting_model(
        model_constructor=build_model,
        data_path='data/train_data.npz',
        project_name='TimeSeriesForecasting',
        config=experiment_config,
    )
class Args:
    """Configuration class for TimesNet model parameters.

    Args:
        seq_len: Input sequence length; ``label_len`` is derived as half of it.
        pred_len: Prediction length.
        enc_in: Number of input channels.
        c_out: Number of output channels.
        d_model: Model width. Defaults to the previously hard-coded 64.
        freq: Frequency code for the 'timeF' embedding. Defaults to the
            previously hard-coded 'h'.
    """

    def __init__(self, seq_len, pred_len, enc_in, c_out, d_model=64, freq='h'):
        # Model architecture parameters
        self.task_name = 'long_term_forecast'
        self.seq_len = seq_len
        self.label_len = seq_len // 2  # Half of seq_len as label length
        self.pred_len = pred_len
        self.enc_in = enc_in
        self.c_out = c_out

        # TimesNet specific parameters
        self.top_k = 5
        self.e_layers = 2
        self.d_min = 32
        self.d_max = 512

        # Reference formula: min{max{2*ceil(log2 C), d_min}, d_max}. It is
        # only reported below; the width actually used is ``d_model``.
        log_c = math.ceil(math.log2(enc_in)) if enc_in > 1 else 1
        formula_d_model = min(max(2 * log_c, self.d_min), self.d_max)
        self.d_model = d_model

        # Other model parameters
        self.d_ff = 64
        self.num_kernels = 6  # For Inception blocks
        self.embed = 'timeF'  # Time feature embedding type
        # NOTE(review): 'h' selects hourly time features, while the
        # preprocessing default elsewhere in this repo is freq='T' - confirm
        # the saved *_mark arrays match this setting.
        self.freq = freq
        self.dropout = 0.1

        print("Model configuration:")
        print(f"  - Input channels (C): {enc_in}")
        print(f"  - d_model: {self.d_model} "
              f"(formula value from 2*⌈log₂({enc_in})⌉ = {2 * log_c} "
              f"clamped to [{self.d_min}, {self.d_max}] would be {formula_d_model})")
        print(f"  - Sequence length: {seq_len}")
        print(f"  - Prediction length: {pred_len}")
        print(f"  - Top-k: {self.top_k}")
        print(f"  - Layers: {self.e_layers}")


def create_timesnet_model(args):
    """Return a zero-argument constructor that builds TimesNet from ``args``."""
    def model_constructor():
        return TimesNet(args)
    return model_constructor
def train_single_dataset(data_path, dataset_name, pred_len, args, device=None):
    """Train TimesNet on a single dataset configuration.

    Args:
        data_path: Path to the processed NPZ file.
        dataset_name: Dataset label used in the project name and config.
        pred_len: Prediction length; written into ``args.pred_len``.
        args: Model configuration object passed to the model constructor.
        device: Optional device string forwarded to the trainer; ``None``
            lets the trainer auto-detect.

    Returns:
        ``(model, metrics)`` on success, ``(None, None)`` when training raised.
    """
    # Update args for current prediction length
    args.pred_len = pred_len

    model_constructor = create_timesnet_model(args)

    config = {
        # NOTE(review): the previous comment here said "LR = 10^-4 as
        # specified" while the value is 1e-5 - confirm which one is intended.
        'learning_rate': 1e-5,
        'batch_size': 32,
        'weight_decay': 1e-4,
        'dataset': dataset_name,
        'pred_len': pred_len,
        'seq_len': args.seq_len,
        'd_model': args.d_model,
        'top_k': args.top_k,
        'e_layers': args.e_layers
    }

    # Project name for tracking
    project_name = f"TimesNet_{dataset_name}_pred{pred_len}"

    print(f"\n{'='*60}")
    print(f"Training {dataset_name} with prediction length {pred_len}")
    print(f"Data path: {data_path}")
    print(f"{'='*60}")

    # Only the training call is guarded: a failure in the reporting below must
    # not throw away a successfully trained model.
    try:
        model, metrics = train_forecasting_model(
            model_constructor=model_constructor,
            data_path=data_path,
            project_name=project_name,
            config=config,
            device=device,
            # With patience equal to max_epochs, early stopping cannot fire
            # before the epoch limit is reached.
            early_stopping_patience=10,
            max_epochs=10,
            checkpoint_dir="./checkpoints",
            log_interval=50
        )
    except Exception as e:
        # Best effort: report and let the caller continue with the next run.
        print(f"Error training {project_name}: {e}")
        return None, None

    print(f"Training completed for {project_name}")
    final_loss = metrics.get('final_val_loss')
    if final_loss is None:
        # A ':.6f' format spec cannot be applied to the 'N/A' placeholder.
        print("Final validation MSE: N/A")
    else:
        print(f"Final validation MSE: {final_loss:.6f}")

    return model, metrics


def main():
    """Train TimesNet for every requested dataset / prediction-length pair."""
    # Local import: numpy is only needed to peek at the channel count.
    import numpy as np

    parser = argparse.ArgumentParser(description='Train TimesNet on ETT datasets')
    parser.add_argument('--data_dir', type=str, default='processed_data',
                        help='Directory containing processed NPZ files')
    parser.add_argument('--datasets', nargs='+', default=['ETTm1', 'ETTm2'],
                        help='List of datasets to train on')
    parser.add_argument('--pred_lengths', nargs='+', type=int, default=[96, 192, 336, 720],
                        help='List of prediction lengths to train on')
    parser.add_argument('--seq_len', type=int, default=96,
                        help='Input sequence length')
    parser.add_argument('--device', type=str, default=None,
                        help='Device to use for training (cuda/cpu)')

    args = parser.parse_args()

    print("TimesNet Training Script")
    print("=" * 50)
    print(f"Datasets: {args.datasets}")
    print(f"Prediction lengths: {args.pred_lengths}")
    print(f"Input sequence length: {args.seq_len}")
    print(f"Data directory: {args.data_dir}")

    # Check if data directory exists
    if not os.path.exists(args.data_dir):
        print(f"Error: Data directory '{args.data_dir}' not found!")
        return

    # Set device
    if args.device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    else:
        device = args.device
    print(f"Using device: {device}")

    # Training results storage
    all_results = {}

    # Train on each dataset and prediction length combination
    for dataset in args.datasets:
        all_results[dataset] = {}

        for pred_len in args.pred_lengths:
            data_file = f"{dataset}_input{args.seq_len}_pred{pred_len}.npz"
            data_path = os.path.join(args.data_dir, data_file)

            if not os.path.exists(data_path):
                print(f"Warning: Data file '{data_path}' not found, skipping...")
                continue

            # Read the channel count, closing the archive right afterwards.
            with np.load(data_path, allow_pickle=True) as data:
                enc_in = data['train_x'].shape[-1]  # Number of features/channels
            print("输入数据通道数:", enc_in)
            c_out = enc_in  # Output same number of channels

            model_args = Args(
                seq_len=args.seq_len,
                pred_len=pred_len,
                enc_in=enc_in,
                c_out=c_out
            )

            # The selected device is now forwarded to the trainer.
            model, metrics = train_single_dataset(
                data_path=data_path,
                dataset_name=dataset,
                pred_len=pred_len,
                args=model_args,
                device=device
            )

            all_results[dataset][pred_len] = {
                'model': model,
                'metrics': metrics,
                'data_path': data_path
            }

    # Print summary
    print("\n" + "=" * 80)
    print("TRAINING SUMMARY")
    print("=" * 80)

    for dataset, runs in all_results.items():
        print(f"\n{dataset}:")
        for pred_len, result in runs.items():
            if result['metrics'] is not None:
                # The trainer stores this value under 'final_val_loss'.
                mse = result['metrics'].get('final_val_loss', 'N/A')
                print(f"  Pred Length {pred_len}: MSE = {mse}")
            else:
                print(f"  Pred Length {pred_len}: Training failed")

    print("\nAll models saved in: ./checkpoints/")
    print("Training completed!")
class TimeFeature:
    """Base class for callables that map a DatetimeIndex to one feature row."""

    def __init__(self):
        pass

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        pass

    def __repr__(self):
        return self.__class__.__name__ + "()"


class SecondOfMinute(TimeFeature):
    """Second of minute encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        return index.second / 59.0 - 0.5


class MinuteOfHour(TimeFeature):
    """Minute of hour encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        return index.minute / 59.0 - 0.5


class HourOfDay(TimeFeature):
    """Hour of day encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        return index.hour / 23.0 - 0.5


class DayOfWeek(TimeFeature):
    """Day of week encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        return index.dayofweek / 6.0 - 0.5


class DayOfMonth(TimeFeature):
    """Day of month encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        return (index.day - 1) / 30.0 - 0.5


class DayOfYear(TimeFeature):
    """Day of year encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        return (index.dayofyear - 1) / 365.0 - 0.5


class MonthOfYear(TimeFeature):
    """Month of year encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        return (index.month - 1) / 11.0 - 0.5


class WeekOfYear(TimeFeature):
    """Week of year encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        # isocalendar().week is a nullable-integer Series; convert it to a
        # plain float array so the result matches the annotated return type
        # and stacks cleanly with the other features.
        week = index.isocalendar().week.to_numpy(dtype=float, na_value=np.nan)
        return (week - 1) / 52.0 - 0.5
# Legacy pandas frequency aliases and the spellings current pandas expects.
_LEGACY_FREQ_ALIASES = {"T": "min", "H": "h", "S": "s"}


def _modernize_freq_alias(freq_str):
    """Rewrite a legacy alias such as "T" or "15T" to its modern spelling."""
    if not isinstance(freq_str, str):
        return freq_str
    alias = freq_str.lstrip("0123456789")
    multiple = freq_str[: len(freq_str) - len(alias)]
    return multiple + _LEGACY_FREQ_ALIASES.get(alias, alias)


def time_features_from_frequency_str(freq_str: str) -> List[TimeFeature]:
    """
    Returns a list of time features that will be appropriate for the given frequency string.
    Parameters
    ----------
    freq_str
        Frequency string of the form [multiple][granularity] such as "12H", "5min", "1D" etc.
        The legacy aliases "T", "H" and "S" are accepted and translated to
        "min", "h" and "s" before parsing.

    Raises
    ------
    RuntimeError
        If the parsed offset matches none of the supported offset types.
    """

    features_by_offsets = {
        offsets.YearEnd: [],
        offsets.QuarterEnd: [MonthOfYear],
        offsets.MonthEnd: [MonthOfYear],
        offsets.Week: [DayOfMonth, WeekOfYear],
        offsets.Day: [DayOfWeek, DayOfMonth, DayOfYear],
        offsets.BusinessDay: [DayOfWeek, DayOfMonth, DayOfYear],
        offsets.Hour: [HourOfDay, DayOfWeek, DayOfMonth, DayOfYear],
        offsets.Minute: [
            MinuteOfHour,
            HourOfDay,
            DayOfWeek,
            DayOfMonth,
            DayOfYear,
        ],
        offsets.Second: [
            SecondOfMinute,
            MinuteOfHour,
            HourOfDay,
            DayOfWeek,
            DayOfMonth,
            DayOfYear,
        ],
    }

    # Recent pandas deprecates the upper-case "T"/"H"/"S" aliases, so they are
    # normalised first; callers may keep passing either spelling.
    offset = to_offset(_modernize_freq_alias(freq_str))

    # The first matching offset type decides the feature set.
    for offset_type, feature_classes in features_by_offsets.items():
        if isinstance(offset, offset_type):
            return [cls() for cls in feature_classes]

    supported_freq_msg = f"""
    Unsupported frequency {freq_str}
    The following frequencies are supported:
        Y   - yearly
            alias: A
        M   - monthly
        W   - weekly
        D   - daily
        B   - business days
        H   - hourly
        T   - minutely
            alias: min
        S   - secondly
    """
    raise RuntimeError(supported_freq_msg)


def time_features(dates, freq='h'):
    """Stack the time features for ``freq`` into one array, one row per feature."""
    return np.vstack([feat(dates) for feat in time_features_from_frequency_str(freq)])