From ef307a57e9b7a525b1f6882145b0dd2db9fa279b Mon Sep 17 00:00:00 2001
From: gameloader
Date: Thu, 4 Sep 2025 01:32:13 +0000
Subject: [PATCH] feat: add mamba and dynamic chunking related code and test
 code

---
 generate_sine_data.py                         | 102 ++++
 layers/DynamicChunking.py                     | 440 +++++++++++++++
 layers/SelfAttention_Family.py                | 169 +++---
 models/DC_PatchTST.py                         | 528 ++++++++++++++++++
 models/DC_hnet.py                             | 339 +++++++++++
 models/vanillaMamba-Copy1.py                  | 138 +++++
 models/vanillaMamba.py                        | 203 +++++++
 run.py                                        |   6 +-
 scripts/classification/DC_PatchTST.sh         | 142 +++++
 .../vanillaMamba_classification.sh            | 259 +++++++++
 .../xPatch_SparseChannel-Copy1.sh             | 145 +++++
 scripts/classification/xPatch_SparseChannel.sh | 249 ++++++++-
 scripts/long_term_forecast/vanillaMamba_all.sh | 251 +++++++++
 .../xPatch_SparseChannel_PEMS.sh              | 150 +++++
 .../xPatch_SparseChannel_all-Copy1.sh         | 251 +++++++++
 .../xPatch_SparseChannel_all.sh               |  77 +++
 scripts/short_term_forecast/vanillaMamba_M4.sh | 165 ++++++
 .../xPatch_SparseChannel_M4.sh                | 165 ++++++
 test_DC_hnet.py                               | 209 +++++++
 test_dc_patchtst.py                           | 313 +++++++++++
 train_dc_patchtst.py                          | 335 +++++++++++
 21 files changed, 4550 insertions(+), 86 deletions(-)
 create mode 100644 generate_sine_data.py
 create mode 100644 layers/DynamicChunking.py
 create mode 100644 models/DC_PatchTST.py
 create mode 100644 models/DC_hnet.py
 create mode 100644 models/vanillaMamba-Copy1.py
 create mode 100644 models/vanillaMamba.py
 create mode 100755 scripts/classification/DC_PatchTST.sh
 create mode 100644 scripts/classification/vanillaMamba_classification.sh
 create mode 100644 scripts/classification/xPatch_SparseChannel-Copy1.sh
 create mode 100644 scripts/long_term_forecast/vanillaMamba_all.sh
 create mode 100644 scripts/long_term_forecast/xPatch_SparseChannel_PEMS.sh
 create mode 100644 scripts/long_term_forecast/xPatch_SparseChannel_all-Copy1.sh
 create mode 100644 scripts/long_term_forecast/xPatch_SparseChannel_all.sh
 create mode 100644 scripts/short_term_forecast/vanillaMamba_M4.sh
 create mode 100644 scripts/short_term_forecast/xPatch_SparseChannel_M4.sh
 create mode 100644 test_DC_hnet.py
 create mode 100644 test_dc_patchtst.py
 create mode 100644 train_dc_patchtst.py

diff --git a/generate_sine_data.py b/generate_sine_data.py
new file mode 100644
index 0000000..5d8a015
--- /dev/null
+++ b/generate_sine_data.py
@@ -0,0 +1,102 @@
+import numpy as np
+import pandas as pd
+import os
+
+def generate_sine_wave_data(n_samples=10000, seq_len=200, n_channels=2, save_path='./data/sine_wave/'):
+    """
+    Generate two-channel sine-wave time-series data.
+
+    Args:
+        n_samples: total number of samples
+        seq_len: length of each sequence
+        n_channels: number of channels (fixed at 2)
+        save_path: output directory
+    """
+
+    if not os.path.exists(save_path):
+        os.makedirs(save_path)
+
+    # Time axis
+    t = np.linspace(0, 4*np.pi, seq_len)
+
+    all_data = []
+
+    for i in range(n_samples):
+        # Give every sample its own period and phase
+        # Channel 1: random frequency, phase, and amplitude
+        freq1 = np.random.uniform(0.5, 3.0)       # frequency range
+        phase1 = np.random.uniform(0, 2*np.pi)    # phase
+        amplitude1 = np.random.uniform(0.5, 2.0)  # amplitude
+
+        # Channel 2: a different random frequency and phase
+        freq2 = np.random.uniform(0.3, 2.5)
+        phase2 = np.random.uniform(0, 2*np.pi)
+        amplitude2 = np.random.uniform(0.8, 1.8)
+
+        # Generate the sine waves
+        channel1 = amplitude1 * np.sin(freq1 * t + phase1)
+        channel2 = amplitude2 * np.sin(freq2 * t + phase2)
+
+        # Optionally add a little noise
+        # noise1 = np.random.normal(0, 0.1, seq_len)
+        # noise2 = np.random.normal(0, 0.1, seq_len)
+        #
+        # channel1 += noise1
+        # channel2 += noise2
+
+        # Assemble the sample: [timestamp, channel1, channel2]
+        timestamp = np.arange(seq_len)
+        sample_data = np.column_stack([timestamp, channel1, channel2])
+        all_data.append(sample_data)
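+        # (sketch) each sample_data has shape (seq_len, 3): [timestamp, channel1, channel2],
+        # so with the defaults all_data collects n_samples arrays of shape (200, 3)
+        # before they are stitched into one continuous series below.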
+
+    # Convert to a continuous time-series layout
+    continuous_data = []
+    current_time = 0
+
+    for sample in all_data:
+        sample[:, 0] = current_time + sample[:, 0]  # shift timestamps
+        continuous_data.append(sample)
+        current_time += seq_len
+
+    # Concatenate all samples
+    full_data = np.vstack(continuous_data)
+
+    # Build the DataFrame
+    df = pd.DataFrame(full_data, columns=['timestamp', 'channel1', 'channel2'])
+
+    # Split into train/val/test at an 8:1:1 ratio
+    total_len = len(df)
+    train_end = int(0.8 * total_len)
+    val_end = int(0.9 * total_len)
+
+    train_df = df[:train_end]
+    val_df = df[train_end:val_end]
+    test_df = df[val_end:]
+
+    # Save the splits
+    train_df.to_csv(os.path.join(save_path, 'train.csv'), index=False)
+    val_df.to_csv(os.path.join(save_path, 'val.csv'), index=False)
+    test_df.to_csv(os.path.join(save_path, 'test.csv'), index=False)
+
+    # Save the full series as well
+    df.to_csv(os.path.join(save_path, 'sine_wave.csv'), index=False)
+
+    print(f"Data generated and saved to {save_path}")
+    print(f"Train: {len(train_df)} rows")
+    print(f"Val:   {len(val_df)} rows")
+    print(f"Test:  {len(test_df)} rows")
+    print(f"Total: {len(df)} rows, {n_channels} channels")
+
+    return df
+
+if __name__ == "__main__":
+    # Generate the data
+    data = generate_sine_wave_data(
+        n_samples=200,   # 200 distinct sine-wave samples
+        seq_len=200,     # 200 time steps per sample
+        n_channels=2,    # two channels
+        save_path='./data/sine_wave/'
+    )
+
+    print("\nData statistics:")
+    print(data.describe())
diff --git a/layers/DynamicChunking.py b/layers/DynamicChunking.py
new file mode 100644
index 0000000..8824521
--- /dev/null
+++ b/layers/DynamicChunking.py
@@ -0,0 +1,440 @@
+from dataclasses import dataclass
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from einops import repeat, rearrange
+
+from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined
+
+
+@dataclass
+class RoutingModuleOutput:
+    # Forward output of the routing module:
+    # - boundary_prob: per-position two-class probability of being a boundary, shape (..., 2)
+    # - boundary_mask: hard selection from the argmax (True = boundary), matching
+    #   the first two dims of the input sequence
+    # - selected_probs: probability of the selected class, shape (..., 1)
+    boundary_prob: torch.Tensor
+    boundary_mask: torch.Tensor
+    selected_probs: torch.Tensor
+
+
+@dataclass
+class RoutingModuleState:
+    """
+    Inference state of the routing module (for incremental / streaming step()).
+
+    Contains:
+    - has_seen_tokens: (batch_size,) whether any token has been seen yet (used to
+      force a boundary at the first token)
+    - last_hidden_state: (batch_size, d_model) previous hidden state (used for the
+      adjacent similarity with the current token)
+    """
+
+    has_seen_tokens: torch.Tensor  # (batch_size,)
+    last_hidden_state: torch.Tensor  # (batch_size, d_model)
+
+
+@dataclass
+class DeChunkState:
+    """
+    Inference state of the DeChunk layer (the EMA memory).
+
+    Contains:
+    - last_value: (batch_size, d_model) previous value of the EMA de-aggregation
+    """
+
+    last_value: torch.Tensor  # (batch_size, d_model)
+
+
+def get_seq_idx(cu_seqlens, device=None):
+    seq_idx = torch.zeros(cu_seqlens[-1], dtype=torch.long, device=device)
+    seq_idx[cu_seqlens[:-1]] = 1
+    seq_idx = (torch.cumsum(seq_idx, dim=0) - 1).unsqueeze(0).int()
+
+    return seq_idx
+
+class RoutingModule(nn.Module):
+    """
+    Routing module:
+    builds the probability of "being a boundary" from the cosine similarity of
+    adjacent tokens:
+        p_t = clamp((1 - cos(h_{t-1}, h_t)) / 2, 0, 1)
+    and forces the first position to be a boundary (probability = 1).
+    Supports:
+    - regular batched mode with a boolean mask
+    - packed mode with cu_seqlens (multiple sequences concatenated into one)
+    - streaming inference via step() (stateful)
+    """
+
+    def __init__(self, d_model, device=None, dtype=None):
+        factory_kwargs = {"device": device, "dtype": dtype}
+        super().__init__()
+        self.d_model = d_model
+        # Linear projections applied before the adjacent similarity (identity init)
+        self.q_proj_layer = nn.Linear(d_model, d_model, bias=False, **factory_kwargs)
+        self.k_proj_layer = nn.Linear(d_model, d_model, bias=False, **factory_kwargs)
+        with torch.no_grad():
+            self.q_proj_layer.weight.copy_(torch.eye(d_model))
+            self.k_proj_layer.weight.copy_(torch.eye(d_model))
+        # Prevent external weight re-initialization
+        self.q_proj_layer.weight._no_reinit = True
+        self.k_proj_layer.weight._no_reinit = True
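+
+    # A tiny numeric illustration of the routing rule (illustration only):
+    #   cos(h_{t-1}, h_t) =  1.0 -> p = 0.0   identical neighbours, no boundary
+    #   cos(h_{t-1}, h_t) =  0.0 -> p = 0.5
+    #   cos(h_{t-1}, h_t) = -1.0 -> p = 1.0   maximal change, certain boundary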
+
+    def allocate_inference_cache(self, batch_size, max_seqlen, device, dtype=None):
+        # Allocate the inference cache (used by step())
+        return RoutingModuleState(
+            has_seen_tokens=torch.zeros(batch_size, device=device, dtype=torch.bool),
+            last_hidden_state=torch.zeros(
+                batch_size, self.d_model, device=device, dtype=dtype
+            ),
+        )
+
+    def forward(self, hidden_states, cu_seqlens=None, mask=None, inference_params=None):
+        """
+        hidden_states:
+            - if cu_seqlens is None: (B, L, D)
+            - if cu_seqlens is not None: packed mode (T, D); temporarily expanded to (1, T, D)
+        cu_seqlens: prefix-sum offsets of the packed sequences, e.g. [0, len1, len1+len2, ...]
+        mask: (B, L) bool, True = valid (non-packed mode)
+        inference_params: RoutingModuleState, validated and updated during prefill
+        """
+        assert (mask is not None) or (
+            cu_seqlens is not None
+        ), "Either mask or cu_seqlens must be provided"
+
+        if inference_params is not None:
+            # Prefill requires a mask, and no token may have been seen yet
+            assert (
+                mask is not None
+            ), "Mask must be provided if inference_params is provided"
+            assert (
+                ~inference_params.has_seen_tokens
+            ).all(), "Cannot have seen tokens when running prefill"
+
+        if cu_seqlens is not None:
+            # Packed mode: view (T, D) as (1, T, D)
+            hidden_states = hidden_states.unsqueeze(0)
+
+        # Adjacent cosine similarity cos(h_{t-1}, h_t)
+        cos_sim = torch.einsum(
+            "b l d, b l d -> b l",
+            F.normalize(self.q_proj_layer(hidden_states[:, :-1]), dim=-1),
+            F.normalize(self.k_proj_layer(hidden_states[:, 1:]), dim=-1),
+        )
+        # p = ((1 - cos) / 2) ∈ [0, 1]
+        boundary_prob = torch.clamp(((1 - cos_sim) / 2), min=0.0, max=1.0)
+
+        # Force the first position to be a boundary: prepend probability 1, so the
+        # padded result has the same length as the input sequence
+        PAD_PROB = 1.0
+        boundary_prob = F.pad(boundary_prob, (1, 0), "constant", PAD_PROB)
+
+        if cu_seqlens is not None:
+            # In packed mode each sequence starts at cu_seqlens[:-1]
+            boundary_prob = boundary_prob.squeeze(0)
+            boundary_prob[cu_seqlens[:-1]] = PAD_PROB
+
+        # Assemble the two-class probability [not boundary, boundary]
+        boundary_prob = torch.stack(((1 - boundary_prob), boundary_prob), dim=-1)
+
+        # Argmax class
+        selected_idx = torch.argmax(boundary_prob, dim=-1)
+
+        # Hard boundary mask
+        boundary_mask = selected_idx == 1  # matches the first two dims of hidden_states
+        if mask is not None:
+            # Never select invalid tokens
+            boundary_mask = boundary_mask & mask
+
+        if inference_params is not None:
+            # Update the routing state: whether any token was seen, and the hidden
+            # state of the last valid token
+            has_mask = mask.any(dim=-1)
+            inference_params.has_seen_tokens.copy_(
+                has_mask | inference_params.has_seen_tokens
+            )
+            last_mask = torch.clamp(mask.sum(dim=-1) - 1, min=0)
+            inference_params.last_hidden_state.copy_(
+                torch.where(
+                    has_mask,
+                    hidden_states[
+                        torch.arange(
+                            hidden_states.shape[0], device=hidden_states.device
+                        ),
+                        last_mask,
+                    ],
+                    inference_params.last_hidden_state,
+                )
+            )
+
+        # Probability of the hard selection (useful for visualization / regularization)
+        selected_probs = boundary_prob.gather(
+            dim=-1, index=selected_idx.unsqueeze(-1)
+        )  # (..., 1)
+
+        return RoutingModuleOutput(
+            boundary_prob=boundary_prob,  # (..., 2)
+            boundary_mask=boundary_mask,  # (...)
+            selected_probs=selected_probs,  # (..., 1)
+        )
+
+    def step(self, hidden_states, inference_params):
+        """
+        Streaming single step:
+        hidden_states: (B, 1, D)
+        Uses the cached last_hidden_state and the current token to compute the
+        adjacent similarity, yielding this step's boundary probability.
+        """
+        # (B, D)
+        hidden_states = hidden_states.squeeze(1)
+        cos_sim = torch.einsum(
+            "b d, b d -> b",
+            F.normalize(self.q_proj_layer(inference_params.last_hidden_state), dim=-1),
+            F.normalize(self.k_proj_layer(hidden_states), dim=-1),
+        )
+        boundary_prob = torch.clamp(((1 - cos_sim) / 2), min=0.0, max=1.0)
+        # Update the last hidden state
+        inference_params.last_hidden_state.copy_(hidden_states)
+        # Before the first token, force a boundary
+        boundary_prob = torch.where(
+            inference_params.has_seen_tokens,
+            boundary_prob,
+            torch.ones_like(boundary_prob),
+        )
+        boundary_prob = torch.stack(((1 - boundary_prob), boundary_prob), dim=-1)
+
+        # Mark tokens as seen
+        inference_params.has_seen_tokens.copy_(
+            torch.ones_like(inference_params.has_seen_tokens)
+        )
+        return RoutingModuleOutput(
+            boundary_prob=boundary_prob,  # (B, 2)
+            boundary_mask=boundary_prob[..., 1] > 0.5,  # (B,)
+            selected_probs=boundary_prob.max(dim=-1).values.unsqueeze(-1),  # (B, 1)
+        )
+
+
+class ChunkLayer(nn.Module):
+    """
+    Chunk layer: extracts the tokens selected by boundary_mask to form the
+    next-level sequence.
+    Two modes:
+    - packed (cu_seqlens is not None): index directly into the concatenated sequence
+    - non-packed (mask is not None): a sorting trick moves True positions to the
+      front and produces next_mask
+    Returns:
+    - next_hidden_states: selected tokens (packed: (#selected, D); non-packed: (B, M, D))
+    - next_cu_seqlens: cu_seqlens of the new sequence in packed mode; otherwise None
+    - next_max_seqlen: max selected length in packed mode; None in non-packed mode
+    - next_mask: right-padding mask in non-packed mode; None in packed mode
+    """
+
+    def forward(self, hidden_states, boundary_mask, cu_seqlens=None, mask=None):
+        assert (mask is not None) or (
+            cu_seqlens is not None
+        ), "Either mask or cu_seqlens must be provided"
+
+        if cu_seqlens is not None:
+            # Packed: select the True rows directly
+            next_hidden_states = hidden_states[boundary_mask]
+            # New cu_seqlens = cumulative True counts at each segment's last position
+            # (cu_seqlens[1:] - 1), with a leading 0
+            next_cu_seqlens = F.pad(
+                boundary_mask.cumsum(dim=0)[cu_seqlens[1:] - 1], (1, 0)
+            )
+            # Longest new segment (only used by kernels / optimizations)
+            next_max_seqlen = int((next_cu_seqlens[1:] - next_cu_seqlens[:-1]).max())
+            next_mask = None
+        else:
+            # Non-packed: within each batch row, move True positions to the front
+            # (False positions go to the back)
+            next_cu_seqlens = None
+            num_tokens = boundary_mask.sum(dim=-1)  # selected count per sample
+            next_max_seqlen = int(num_tokens.max())
+
+            device = hidden_states.device
+            L = hidden_states.shape[1]
+            # Trick: adding (~boundary_mask) * L inflates False indices, so after
+            # argsort the True indices come first
+            token_idx = (
+                torch.arange(L, device=device)[None, :] + (~boundary_mask).long() * L
+            )
+            seq_sorted_indices = torch.argsort(token_idx, dim=1)
+
+            # Gather the first next_max_seqlen tokens (shorter samples are right-padded)
+            next_hidden_states = torch.gather(
+                hidden_states,
+                dim=1,
+                index=seq_sorted_indices[:, :next_max_seqlen, None].expand(
+                    -1, -1, hidden_states.shape[-1]
+                ),
+            )
+
+            # Validity mask for downstream layers (right padding is invalid)
+            next_mask = (
+                torch.arange(next_max_seqlen, device=device)[None, :]
+                < num_tokens[:, None]
+            )
+            # max_seqlen is not needed in non-packed mode (return None)
+            next_max_seqlen = None
+
+        return next_hidden_states, next_cu_seqlens, next_max_seqlen, next_mask
+
+    def step(self, hidden_states, boundary_mask):
+        # Streaming step: return only the tokens selected at this step (for the next level)
+        return hidden_states[boundary_mask]
+
+
+class DeChunkLayer(nn.Module):
+    """
+    DeChunk layer: de-aggregates the selected boundary-token sequence back to the
+    original full-length sequence via an EMA.
+    The implementation reuses the Mamba2 Triton scan kernel mamba_chunk_scan_combined:
+    - d_model is split into nheads * headdim
+    - a first-order state-space / EMA forward scan with A = -1, b = p, c = 1
+    - the scan output is mapped back to the original positions ("plug back")
+    Supports:
+    - packed mode (cu_seqlens)
+    - non-packed mode (batch + right padding)
+    - streaming step() (EMA recurrence)
+    """
+
+    def __init__(
+        self,
+        d_model,
+        dtype=torch.bfloat16,
+        block_size=256,
+        headdim=32,
+    ):
+        super().__init__()
+        self.d_model = d_model
+
+        # Kernel requirements only: bfloat16, block size, and head split
+        self.dtype = dtype
+        self.block_size = block_size
+        self.headdim = headdim
+        assert d_model % self.headdim == 0
+        self.nheads = d_model // self.headdim
+
+    def allocate_inference_cache(self, batch_size, max_seqlen, device, dtype=None):
+        # Allocate the EMA last_value cache
+        return DeChunkState(
+            last_value=torch.zeros(
+                batch_size, self.d_model, device=device, dtype=dtype
+            ),
+        )
+
+    def forward(
+        self,
+        hidden_states,  # selected tokens (packed: (M, D); non-packed: (B, M, D))
+        boundary_mask,  # boundary mask on the original sequence ((T,) or (B, L))
+        boundary_prob,  # two-class probability on the original sequence ((..., 2))
+        cu_seqlens=None,
+        inference_params=None,
+        mask=None,
+    ):
+        """
+        Core idea:
+        1) take p = P(boundary) from boundary_prob, clamped to (1e-4, 1 - 1e-4)
+        2) build dt = log(1 / (1 - p)) and scale the input: x = h / dt
+        3) run mamba_chunk_scan_combined with A = -1, b = p, c = 1 over (B, M, H, P)
+        4) plug the result back to the original positions via the cumulative
+           boundary index
+        """
+        if inference_params is not None:
+            # Prefill requires a mask, and the first token must be a boundary
+            # (this initializes the EMA)
+            assert (
+                mask is not None
+            ), "Mask must be provided if inference_params is provided"
+            assert boundary_mask[
+                :, 0
+            ].all(), "First token must be a boundary if running prefill"
+
+        # Boundary-class probability p, clamped to (1e-4, 1 - 1e-4) for numerical stability
+        p = torch.clamp(boundary_prob[..., -1].float(), min=1e-4, max=1 - (1e-4))
+
+        if cu_seqlens is not None:
+            # Packed: gather the probabilities of the selected positions, shape (B=1, M)
+            p = p[boundary_mask].unsqueeze(0)
+            # Segment index map of the packed sequence for the Triton kernel
+            seq_idx = get_seq_idx(cu_seqlens, device=hidden_states.device)
+        else:
+            B, L = boundary_mask.shape
+            seq_idx = None
+            # Same sorting trick as ChunkLayer: selected (True) positions first
+            token_idx = (
+                torch.arange(L, device=hidden_states.device)[None, :]
+                + (~boundary_mask).long() * L
+            )
+            seq_sorted_indices = torch.argsort(token_idx, dim=1)
+
+            # Gather p for the selected tokens, matching hidden_states' length ((B, M))
+            p = torch.gather(
+                p, dim=1, index=seq_sorted_indices[:, : hidden_states.shape[1]]
+            )  # (B, M)
+
+        original_dtype = hidden_states.dtype
+        # Variables for the EMA scan
+        dt = torch.log(1 / (1 - p)).to(self.dtype)  # (B, M)
+        x = (hidden_states / dt[..., None]).to(self.dtype)  # (B, M, D) / (B, M, 1)
+
+        # A, b, c are the parameters of the first-order state space / EMA
+        A = -torch.ones(
+            (self.nheads,), device=hidden_states.device, dtype=torch.float32
+        )
+        b = p.to(self.dtype)
+        c = torch.ones_like(b)
+
+        # Run the Triton chunked scan
+        out = mamba_chunk_scan_combined(
+            rearrange(x, "b l (h p) -> b l h p", p=self.headdim),  # (B, M, H, P)
+            repeat(dt, "b l -> b l h", h=self.nheads),  # (B, M, H)
+            A,  # (H,)
+            rearrange(b, "b l -> b l 1 1"),  # (B, M, 1, 1)
+            rearrange(c, "b l -> b l 1 1"),  # (B, M, 1, 1)
+            chunk_size=self.block_size,
+            seq_idx=seq_idx,  # provided in packed mode
+        )
+        out = rearrange(out, "b l h p -> b l (h p)")  # (B, M, D)
+
+        # Plug the scan result back to the original positions
+        if cu_seqlens is not None:
+            out = out.squeeze(0)  # (M, D)
+            plug_back_idx = boundary_mask.cumsum(dim=0) - 1  # (T,)
+            out = torch.gather(
+                out, dim=0, index=plug_back_idx.unsqueeze(-1).expand(-1, self.d_model)
+            )  # (T, D)
+        else:
+            plug_back_idx = torch.cumsum(boundary_mask, dim=1) - 1  # (B, L)
+            out = torch.gather(
+                out,
+                dim=1,
+                index=plug_back_idx.unsqueeze(-1).expand(-1, -1, self.d_model),
+            )  # (B, L, D)
+
+        # Update the streaming cache
+        if inference_params is not None:
+            inference_params.last_value.copy_(out[:, -1])
+
+        return out.to(original_dtype)
+
+    def step(self, hidden_states, boundary_mask, boundary_prob, inference_params):
+        """
+        Streaming single-step EMA de-aggregation:
+        hidden_states: (B', 1, D), where B' = number selected this step (boundary_mask.sum())
+        boundary_mask: (B,) which batch positions are selected as boundaries this step
+        boundary_prob: (B, 2) boundary probabilities for this step
+        Returns (B, 1, D): all positions after one EMA update.
+        """
+        B = boundary_mask.shape[0]
+        D = hidden_states.shape[-1]
+
+        # Per-position p for this step (unselected positions get p = 0)
+        p = torch.zeros(B, device=hidden_states.device, dtype=hidden_states.dtype)
+        p[boundary_mask] = boundary_prob[boundary_mask, -1].clamp(
+            min=1e-4, max=1 - (1e-4)
+        )
+        p = p.unsqueeze(-1)  # (B, 1), so it broadcasts against (B, D)
+
+        # Current selected hidden states (zeros where unselected)
+        current_hidden_states = torch.zeros(
+            B, D, device=hidden_states.device, dtype=hidden_states.dtype
+        )
+        current_hidden_states[boundary_mask] = hidden_states.squeeze(1)
+
+        # EMA: result = p * x + (1 - p) * last
+        result = p * current_hidden_states + (1 - p) * inference_params.last_value
+        inference_params.last_value.copy_(result)
+
+        return result.unsqueeze(1)
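+
+# Reference semantics of the DeChunk EMA (a sketch, not executed at runtime):
+# for selected inputs x_k with boundary probabilities p_k, the scan computes
+#     h_k = (1 - p_k) * h_{k-1} + p_k * x_k
+# which is what mamba_chunk_scan_combined evaluates with A = -1, b = p, c = 1
+# and dt = log(1 / (1 - p)): exp(A * dt) = 1 - p, and dt * b * (x / dt) = p * x.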
diff --git a/layers/SelfAttention_Family.py b/layers/SelfAttention_Family.py
index b151bff..240e27b 100644
--- a/layers/SelfAttention_Family.py
+++ b/layers/SelfAttention_Family.py
@@ -17,25 +17,30 @@ class DSAttention(nn.Module):
         self.output_attention = output_attention
         self.dropout = nn.Dropout(attention_dropout)
 
-    def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
+    def forward(self, queries, keys, values, attn_mask, tau=None, delta=None, key_padding_mask=None):
+        """
+        key_padding_mask: (B, S) bool, True = valid, False = pad (optional; may be
+        ignored here or applied by the caller)
+        """
         B, L, H, E = queries.shape
         _, S, _, D = values.shape
         scale = self.scale or 1. / sqrt(E)
 
-        tau = 1.0 if tau is None else tau.unsqueeze(
-            1).unsqueeze(1)  # B x 1 x 1 x 1
-        delta = 0.0 if delta is None else delta.unsqueeze(
-            1).unsqueeze(1)  # B x 1 x 1 x S
+        tau = 1.0 if tau is None else tau.unsqueeze(1).unsqueeze(1)  # B x 1 x 1 x 1
+        delta = 0.0 if delta is None else delta.unsqueeze(1).unsqueeze(1)  # B x 1 x 1 x S
 
-        # De-stationary Attention, rescaling pre-softmax score with learned de-stationary factors
-        scores = torch.einsum("blhe,bshe->bhls", queries, keys) * tau + delta
+        scores = torch.einsum("blhe,bshe->bhls", queries, keys) * tau + delta  # (B,H,L,S)
 
         if self.mask_flag:
             if attn_mask is None:
                 attn_mask = TriangularCausalMask(B, L, device=queries.device)
             scores.masked_fill_(attn_mask.mask, -np.inf)
 
+        # Optional: mask invalid keys via key_padding_mask (the default None keeps
+        # the original behaviour)
+        if key_padding_mask is not None:
+            # key_padding_mask: True = valid, False = padding
+            invalid_k = (~key_padding_mask).unsqueeze(1).unsqueeze(1)  # (B,1,1,S)
+            scores = scores.masked_fill(invalid_k, -np.inf)
+
         A = self.dropout(torch.softmax(scale * scores, dim=-1))
         V = torch.einsum("bhls,bshd->blhd", A, values)
 
@@ -46,6 +51,12 @@ class DSAttention(nn.Module):
 
 
 class FullAttention(nn.Module):
+    """
+    Changes:
+    - adds key_padding_mask support, masking right-padded key vectors within the
+      batch (aligns with the variable-length DC output)
+    - key_padding_mask convention: shape (B, S), True = valid, False = padding
+    - everything else matches the original implementation
+    """
     def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False):
         super(FullAttention, self).__init__()
         self.scale = scale
@@ -53,21 +64,33 @@ class FullAttention(nn.Module):
         self.output_attention = output_attention
         self.dropout = nn.Dropout(attention_dropout)
 
-    def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
+    def forward(self, queries, keys, values, attn_mask, tau=None, delta=None, key_padding_mask=None):
+        """
+        queries: (B, L, H, E)
+        keys: (B, S, H, E)
+        values: (B, S, H, D)
+        attn_mask: TriangularCausalMask or None
+        key_padding_mask: (B, S) bool, True = valid, False = padding (optional)
+        """
         B, L, H, E = queries.shape
         _, S, _, D = values.shape
         scale = self.scale or 1. / sqrt(E)
 
-        scores = torch.einsum("blhe,bshe->bhls", queries, keys)
+        scores = torch.einsum("blhe,bshe->bhls", queries, keys)  # (B,H,L,S)
 
         if self.mask_flag:
             if attn_mask is None:
                 attn_mask = TriangularCausalMask(B, L, device=queries.device)
-
             scores.masked_fill_(attn_mask.mask, -np.inf)
 
-        A = self.dropout(torch.softmax(scale * scores, dim=-1))
-        V = torch.einsum("bhls,bshd->blhd", A, values)
+        # Mask invalid keys via key_padding_mask (padding positions take no part
+        # in the attention)
+        if key_padding_mask is not None:
+            # key_padding_mask: True = valid, False = padding
+            invalid_k = (~key_padding_mask).unsqueeze(1).unsqueeze(1)  # (B,1,1,S)
+            scores = scores.masked_fill(invalid_k, -np.inf)
+
+        A = self.dropout(torch.softmax(scale * scores, dim=-1))  # (B,H,L,S)
+        V = torch.einsum("bhls,bshd->blhd", A, values)  # (B,L,H,D)
 
         if self.output_attention:
             return V.contiguous(), A
@@ -85,100 +108,86 @@ class ProbAttention(nn.Module):
         self.dropout = nn.Dropout(attention_dropout)
 
     def _prob_QK(self, Q, K, sample_k, n_top):  # n_top: c*ln(L_q)
-        # Q [B, H, L, D]
+        # Q [B, H, L_q, D], K [B, H, L_k, D]
         B, H, L_K, E = K.shape
         _, _, L_Q, _ = Q.shape
 
         K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E)
-        # real U = U_part(factor*ln(L_k))*L_q
-        index_sample = torch.randint(L_K, (L_Q, sample_k))
-        K_sample = K_expand[:, :, torch.arange(
-            L_Q).unsqueeze(1), index_sample, :]
-        Q_K_sample = torch.matmul(
-            Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze()
+        index_sample = torch.randint(L_K, (L_Q, sample_k), device=Q.device)
+        K_sample = K_expand[:, :, torch.arange(L_Q, device=Q.device).unsqueeze(1), index_sample, :]
+        Q_K_sample = torch.matmul(Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze(-2)  # (B,H,L_Q,sample_k)
 
-        # find the Top_k query with sparisty measurement
-        M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K)
-        M_top = M.topk(n_top, sorted=False)[1]
+        M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K)  # (B,H,L_Q)
+        M_top = M.topk(n_top, sorted=False)[1]  # indices
 
-        # use the reduced Q to calculate Q_K
         Q_reduce = Q[torch.arange(B)[:, None, None],
-                   torch.arange(H)[None, :, None],
-                   M_top, :]  # factor*ln(L_q)
-        Q_K = torch.matmul(Q_reduce, K.transpose(-2, -1))  # factor*ln(L_q)*L_k
-
+                     torch.arange(H)[None, :, None],
+                     M_top, :]  # (B,H,n_top,D)
+        Q_K = torch.matmul(Q_reduce, K.transpose(-2, -1))  # (B,H,n_top,L_K)
         return Q_K, M_top
 
     def _get_initial_context(self, V, L_Q):
         B, H, L_V, D = V.shape
         if not self.mask_flag:
-            # V_sum = V.sum(dim=-2)
-            V_sum = V.mean(dim=-2)
-            contex = V_sum.unsqueeze(-2).expand(B, H,
-                                                L_Q, V_sum.shape[-1]).clone()
-        else:  # use mask
-            # requires that L_Q == L_V, i.e. for self-attention only
-            assert (L_Q == L_V)
-            contex = V.cumsum(dim=-2)
-        return contex
+            V_mean = V.mean(dim=-2)  # (B,H,D)
+            context = V_mean.unsqueeze(-2).expand(B, H, L_Q, D).clone()
+        else:
+            # requires L_Q == L_V, i.e. self-attention only
+            assert L_Q == L_V
+            context = V.cumsum(dim=-2)
+        return context
 
     def _update_context(self, context_in, V, scores, index, L_Q, attn_mask):
         B, H, L_V, D = V.shape
-
         if self.mask_flag:
             attn_mask = ProbMask(B, H, L_Q, index, scores, device=V.device)
             scores.masked_fill_(attn_mask.mask, -np.inf)
-
-        attn = torch.softmax(scores, dim=-1)  # nn.Softmax(dim=-1)(scores)
+        attn = torch.softmax(scores, dim=-1)
 
         context_in[torch.arange(B)[:, None, None],
-                   torch.arange(H)[None, :, None],
-                   index, :] = torch.matmul(attn, V).type_as(context_in)
+                   torch.arange(H)[None, :, None],
+                   index, :] = torch.matmul(attn, V).type_as(context_in)
         if self.output_attention:
-            attns = (torch.ones([B, H, L_V, L_V]) /
-                     L_V).type_as(attn).to(attn.device)
-            attns[torch.arange(B)[:, None, None], torch.arange(H)[
-                                                  None, :, None], index, :] = attn
+            attns = (torch.ones([B, H, L_V, L_V], device=attn.device, dtype=attn.dtype) / L_V)
+            attns[torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :] = attn
             return context_in, attns
         else:
             return context_in, None
 
-    def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
+    def forward(self, queries, keys, values, attn_mask, tau=None, delta=None, key_padding_mask=None):
+        """
+        key_padding_mask is not yet wired into ProbAttention (if needed, fill the
+        scores of invalid keys with -inf)
+        """
         B, L_Q, H, D = queries.shape
         _, L_K, _, _ = keys.shape
 
-        queries = queries.transpose(2, 1)
-        keys = keys.transpose(2, 1)
-        values = values.transpose(2, 1)
+        queries = queries.transpose(2, 1)  # (B,H,L_Q,D)
+        keys = keys.transpose(2, 1)  # (B,H,L_K,D)
+        values = values.transpose(2, 1)  # (B,H,L_K,D)
 
-        U_part = self.factor * \
-                 np.ceil(np.log(L_K)).astype('int').item()  # c*ln(L_k)
-        u = self.factor * \
-            np.ceil(np.log(L_Q)).astype('int').item()  # c*ln(L_q)
+        U_part = self.factor * int(np.ceil(np.log(L_K)))
+        u = self.factor * int(np.ceil(np.log(L_Q)))
 
-        U_part = U_part if U_part < L_K else L_K
-        u = u if u < L_Q else L_Q
+        U_part = min(U_part, L_K)
+        u = min(u, L_Q)
 
-        scores_top, index = self._prob_QK(
-            queries, keys, sample_k=U_part, n_top=u)
+        scores_top, index = self._prob_QK(queries, keys, sample_k=U_part, n_top=u)
 
-        # add scale factor
         scale = self.scale or 1. / sqrt(D)
-        if scale is not None:
-            scores_top = scores_top * scale
-        # get the context
+        scores_top = scores_top * scale
+
         context = self._get_initial_context(values, L_Q)
-        # update the context with selected top_k queries
-        context, attn = self._update_context(
-            context, values, scores_top, index, L_Q, attn_mask)
+        context, attn = self._update_context(context, values, scores_top, index, L_Q, attn_mask)
 
         return context.contiguous(), attn
 
 
 class AttentionLayer(nn.Module):
-    def __init__(self, attention, d_model, n_heads, d_keys=None,
-                 d_values=None):
+    """
+    Changes:
+    - forward gains a key_padding_mask argument that is passed through to
+      inner_attention
+    - stays compatible with old call sites (defaults to None)
+    """
+    def __init__(self, attention, d_model, n_heads, d_keys=None, d_values=None):
         super(AttentionLayer, self).__init__()
 
         d_keys = d_keys or (d_model // n_heads)
@@ -191,7 +200,10 @@ class AttentionLayer(nn.Module):
         self.out_projection = nn.Linear(d_values * n_heads, d_model)
         self.n_heads = n_heads
 
-    def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
+    def forward(self, queries, keys, values, attn_mask, tau=None, delta=None, key_padding_mask=None):
+        """
+        key_padding_mask: (B, S) bool, True = valid, False = padding
+        """
         B, L, _ = queries.shape
         _, S, _ = keys.shape
         H = self.n_heads
@@ -206,10 +218,10 @@ class AttentionLayer(nn.Module):
             values,
             attn_mask,
             tau=tau,
-            delta=delta
+            delta=delta,
+            key_padding_mask=key_padding_mask,
         )
         out = out.view(B, L, -1)
-
         return self.out_projection(out), attn
 
 
@@ -232,12 +244,11 @@ class ReformerLayer(nn.Module):
         if N % (self.bucket_size * 2) == 0:
             return queries
         else:
-            # fill the time series
             fill_len = (self.bucket_size * 2) - (N % (self.bucket_size * 2))
             return torch.cat([queries, torch.zeros([B, fill_len, C]).to(queries.device)], dim=1)
 
-    def forward(self, queries, keys, values, attn_mask, tau, delta):
-        # in Reformer: defalut queries=keys
+    def forward(self, queries, keys, values, attn_mask, tau, delta, key_padding_mask=None):
+        # queries=keys in Reformer
         B, N, C = queries.shape
         queries = self.attn(self.fit_length(queries))[:, :N, :]
         return queries, None
@@ -275,23 +286,23 @@ class TwoStageAttentionLayer(nn.Module):
                                   nn.GELU(),
                                   nn.Linear(d_ff, d_model))
 
-    def forward(self, x, attn_mask=None, tau=None, delta=None):
+    def forward(self, x, attn_mask=None, tau=None, delta=None, key_padding_mask=None):
         # Cross Time Stage: Directly apply MSA to each dimension
         batch = x.shape[0]
         time_in = rearrange(x, 'b ts_d seg_num d_model -> (b ts_d) seg_num d_model')
         time_enc, attn = self.time_attention(
-            time_in, time_in, time_in, attn_mask=None, tau=None, delta=None
+            time_in, time_in, time_in, attn_mask=None, tau=None, delta=None, key_padding_mask=key_padding_mask
        )
         dim_in = time_in + self.dropout(time_enc)
         dim_in = self.norm1(dim_in)
         dim_in = dim_in + self.dropout(self.MLP1(dim_in))
         dim_in = self.norm2(dim_in)
 
-        # Cross Dimension Stage: use a small set of learnable vectors to aggregate and distribute messages to build the D-to-D connection
+        # Cross Dimension Stage
         dim_send = rearrange(dim_in, '(b ts_d) seg_num d_model -> (b seg_num) ts_d d_model', b=batch)
         batch_router = repeat(self.router, 'seg_num factor d_model -> (repeat seg_num) factor d_model', repeat=batch)
-        dim_buffer, attn = self.dim_sender(batch_router, dim_send, dim_send, attn_mask=None, tau=None, delta=None)
-        dim_receive, attn = self.dim_receiver(dim_send, dim_buffer, dim_buffer, attn_mask=None, tau=None, delta=None)
+        dim_buffer, _ = self.dim_sender(batch_router, dim_send, dim_send, attn_mask=None, tau=None, delta=None)
+        dim_receive, _ = self.dim_receiver(dim_send, dim_buffer, dim_buffer, attn_mask=None, tau=None, delta=None)
         dim_enc = dim_send + self.dropout(dim_receive)
         dim_enc = self.norm3(dim_enc)
         dim_enc = dim_enc + self.dropout(self.MLP2(dim_enc))
diff --git a/models/DC_PatchTST.py b/models/DC_PatchTST.py
new file mode 100644
index 0000000..6330312
--- /dev/null
+++ b/models/DC_PatchTST.py
@@ -0,0 +1,528 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+from layers.SelfAttention_Family import FullAttention, AttentionLayer
+
+# Mamba2 is required as the outer encoder
+from mamba_ssm.modules.mamba2 import Mamba2
+
+
+# -------------------- Routing (cosine routing, as in the paper) --------------------
+class RoutingModule(nn.Module):
+    def __init__(self, d_model):
+        super().__init__()
+        self.q_proj = nn.Linear(d_model, d_model, bias=False)
+        self.k_proj = nn.Linear(d_model, d_model, bias=False)
+        with torch.no_grad():
+            nn.init.eye_(self.q_proj.weight)
+            nn.init.eye_(self.k_proj.weight)
+        self.q_proj.weight._no_reinit = True
+        self.k_proj.weight._no_reinit = True
+
+    def forward(self, x, mask=None):
+        """
+        x: (B, L, D)
+        mask: (B, L) bool, True = valid
+        Returns:
+            boundary_prob: (B, L, 2)
+            boundary_mask: (B, L) bool
+            selected_probs: (B, L, 1)
+        """
+        B, L, D = x.shape
+        q = F.normalize(self.q_proj(x[:, :-1]), dim=-1)  # (B, L-1, D)
+        k = F.normalize(self.k_proj(x[:, 1:]), dim=-1)   # (B, L-1, D)
+        cos_sim = (q * k).sum(dim=-1)                    # (B, L-1)
+        p = torch.clamp((1 - cos_sim) / 2, 0.0, 1.0)     # (B, L-1)
+        p = F.pad(p, (1, 0), value=1.0)                  # force a boundary at position 0
+
+        if mask is not None:
+            p = p * mask.float()
+            p[:, 0] = torch.where(mask[:, 0], torch.ones_like(p[:, 0]), p[:, 0])
+
+        boundary_prob = torch.stack([1 - p, p], dim=-1)  # (B, L, 2)
+        selected_idx = boundary_prob.argmax(dim=-1)
+        boundary_mask = (selected_idx == 1)
+        if mask is not None:
+            boundary_mask = boundary_mask & mask
+        selected_probs = boundary_prob.gather(-1, selected_idx.unsqueeze(-1))  # (B, L, 1)
+        return boundary_prob, boundary_mask, selected_probs
+
+
+# -------------------- Select and right-pad with zeros (nothing dropped, nothing duplicated) --------------------
+def select_and_right_pad(x, boundary_mask):
+    """
+    Memory-optimized version: avoids creating extra temporary tensors.
+    x: (B, L, D), boundary_mask: (B, L) bool
+    Returns:
+        x_pad: (B, T_max, D)
+        key_padding_mask: (B, T_max) bool, True = valid
+        lengths: (B,)
+    """
+    B, L, D = x.shape
+    device = x.device
+    lengths = boundary_mask.sum(dim=1)  # (B,)
+    T_max = int(lengths.max().item()) if lengths.max() > 0 else 1
+
+    x_pad = x.new_zeros(B, T_max, D)
+    key_padding_mask = torch.zeros(B, T_max, dtype=torch.bool, device=device)
+
+    # Pre-create the fallback index tensor instead of rebuilding it per sample
+    default_idx = torch.tensor([0], device=device)
+
+    for b in range(B):
+        mask_b = boundary_mask[b]
+        if mask_b.any():
+            idx = mask_b.nonzero(as_tuple=True)[0]  # efficient nonzero
+            t = idx.numel()
+            x_pad[b, :t] = x[b, idx]
+            key_padding_mask[b, :t] = True
+        else:
+            # Fall back to the first position so every row keeps at least one token
+            x_pad[b, 0] = x[b, default_idx]
+            key_padding_mask[b, 0] = True
+
+    return x_pad, key_padding_mask, lengths
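+
+# Worked example (a sketch): with boundary_mask = [[True, False, True, False]],
+# select_and_right_pad keeps positions 0 and 2, so x_pad[0] = x[0, [0, 2]],
+# key_padding_mask = [[True, True]] and lengths = [2].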
+
+
+# -------------------- Stacked Mamba2 (outer encoder) --------------------
+class Mamba2Encoder(nn.Module):
+    def __init__(self, d_model, depth=4, dropout=0.0):
+        super().__init__()
+        self.layers = nn.ModuleList([Mamba2(d_model=d_model) for _ in range(depth)])
+        self.norm = nn.LayerNorm(d_model)
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(self, x):
+        for layer in self.layers:
+            x = layer(x)
+        x = self.norm(x)
+        x = self.dropout(x)
+        return x
+
+
+# -------------------- Two-stage encoder + DC (variable length, lossless; compression constrained by a ratio) --------------------
+class DCEmbedding2StageVarLen(nn.Module):
+    """
+    - Stage 0: (B*nvars, L, 1) -> Linear(D0) -> Mamba2(D0) -> Routing -> select -> widen to D1
+    - Stage 1: (B*nvars, L0_sel, D1) -> Mamba2(D1) -> Routing -> select
+    Outputs:
+        enc_out: (B*nvars, T_max, D1)
+        key_padding_mask: (B*nvars, T_max)
+        n_vars: int
+        aux: dict (ratio losses of both stages plus boundary info)
+    """
+    def __init__(self, d_model_out, d_model_stage0, depth_enc0=4, depth_enc1=4, dropout=0.0,
+                 target_ratio0=0.25, target_ratio1=0.5):
+        super().__init__()
+        assert d_model_out >= d_model_stage0, "requires D0 <= D1"
+        self.d0 = d_model_stage0
+        self.d1 = d_model_out
+
+        # scalar -> D0
+        self.input_proj = nn.Linear(1, self.d0)
+
+        # Stage 0
+        self.enc0 = Mamba2Encoder(self.d0, depth=depth_enc0, dropout=dropout)
+        self.router0 = RoutingModule(self.d0)
+        delta = self.d1 - self.d0
+        self.pad_vec = nn.Parameter(torch.zeros(delta)) if delta > 0 else None
+        self.target_ratio0 = target_ratio0
+
+        # Stage 1
+        self.enc1 = Mamba2Encoder(self.d1, depth=depth_enc1, dropout=dropout)
+        self.router1 = RoutingModule(self.d1)
+        self.target_ratio1 = target_ratio1
+
+    def _expand_width(self, x):
+        if self.pad_vec is None:
+            return x
+        B, L, _ = x.shape
+        return torch.cat([x, self.pad_vec.view(1, 1, -1).expand(B, L, -1)], dim=-1)
+
+    @staticmethod
+    def _ratio_loss(boundary_mask: torch.Tensor, boundary_prob: torch.Tensor, target_ratio: float) -> torch.Tensor:
+        eps = 1e-6
+        F_act = boundary_mask.float().mean(dim=1)   # (B,)
+        G_prob = boundary_prob[..., 1].mean(dim=1)  # (B,)
+        N = 1.0 / max(target_ratio, eps)
+        loss = N / (N - 1.0 + eps) * (((N - 1.0) * F_act * G_prob) + (1.0 - F_act) * (1.0 - G_prob))
+        return loss.mean()
+
+    def forward(self, x):
+        """
+        x: (B, nvars, L)
+        Memory-optimized version: frees intermediate tensors eagerly.
+        """
+        B, nvars, L = x.shape
+        x = x.reshape(B * nvars, L, 1)
+        x = self.input_proj(x)  # (B*nvars, L, D0)
+
+        # Stage 0
+        h0 = self.enc0(x)  # (B*nvars, L, D0)
+        p0, bm0, _ = self.router0(h0)
+        h0_sel, mask0, len0 = select_and_right_pad(h0, bm0)  # (B*nvars, L0_max, D0)
+
+        # Free tensors we no longer need
+        del h0
+
+        # Widen D0 -> D1 so the masked encoder (built with D1) receives matching width
+        h0_sel = self._expand_width(h0_sel)  # (B*nvars, L0_max, D1)
+
+        # Stage 1 (currently disabled)
+        #h1 = self.enc1(h0_sel)  # (B*nvars, L0_max, D1)
+        #p1, bm1, _ = self.router1(h1)
+        #bm1 = bm1 & mask0
+        #h1_sel, mask1, len1 = select_and_right_pad(h1, bm1)  # (B*nvars, L1_max, D1)
+
+        # Free intermediates eagerly
+        #del h1, h0_sel
+
+        # Detach in aux below so the ratio losses do not retain the graph twice
+        ratio_loss0 = self._ratio_loss(bm0, p0, target_ratio=self.target_ratio0)
+        # ratio_loss1 = self._ratio_loss(bm1, p1, target_ratio=self.target_ratio1)
+
+        # Keep aux small: only what downstream logging needs
+        aux = {
+            "stage0": {"boundary_mask": bm0.detach(), "boundary_prob": p0.detach(), "lengths": len0.detach()},
+            # "stage1": {"boundary_mask": bm1.detach(), "boundary_prob": p1.detach(), "lengths": len1.detach()},
+            "ratio_loss0": ratio_loss0,
+            # "ratio_loss1": ratio_loss1,
+        }
+
+        return h0_sel, mask0, nvars, aux
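+
+# A quick check of the ratio loss above (sketch): with F = G = target_ratio = 1/N,
+# loss = N/(N-1) * ((N-1)/N^2 + (N-1)^2/N^2) = 1, which is its minimum, so the
+# router is pulled toward keeping roughly a 1/N fraction of the positions.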
+
+
+# -------------------- Encoder/EncoderLayer (with key_padding_mask pass-through) --------------------
+class EncoderLayerWithMask(nn.Module):
+    """
+    Same structure as the original EncoderLayer, but forward takes a
+    key_padding_mask and hands it to the AttentionLayer.
+    The FFN is a plain MLP (as in a standard Transformer).
+    """
+    def __init__(self, attention: AttentionLayer, d_model, d_ff, dropout=0.1, activation="gelu"):
+        super().__init__()
+        self.attention = attention
+        self.dropout = nn.Dropout(dropout)
+        self.norm1 = nn.LayerNorm(d_model)
+        self.norm2 = nn.LayerNorm(d_model)
+
+        if activation == "relu":
+            act = nn.ReLU()
+        elif activation == "gelu":
+            act = nn.GELU()
+        else:
+            raise ValueError(f"Unsupported activation: {activation}")
+
+        self.ffn = nn.Sequential(
+            nn.Linear(d_model, d_ff),
+            act,
+            nn.Dropout(dropout),
+            nn.Linear(d_ff, d_model),
+        )
+
+    def forward(self, x, attn_mask=None, tau=None, delta=None, key_padding_mask=None):
+        # Multi-head attention with key padding mask
+        attn_out, attn = self.attention(
+            x, x, x, attn_mask, tau=tau, delta=delta, key_padding_mask=key_padding_mask
+        )
+        x = x + self.dropout(attn_out)
+        x = self.norm1(x)
+
+        # FFN
+        y = self.ffn(x)
+        x = x + self.dropout(y)
+        x = self.norm2(x)
+        return x, attn
+
+
+class EncoderWithMask(nn.Module):
+    """
+    Like the original Encoder, but forward accepts a key_padding_mask and passes
+    it to every attention layer.
+    """
+    def __init__(self, attn_layers, norm_layer=None):
+        super().__init__()
+        self.attn_layers = nn.ModuleList(attn_layers)
+        self.norm = norm_layer
+
+    def forward(self, x, attn_mask=None, key_padding_mask=None):
+        attns = []
+        for attn_layer in self.attn_layers:
+            x, attn = attn_layer(x, attn_mask=attn_mask, key_padding_mask=key_padding_mask)
+            attns.append(attn)
+        if self.norm is not None:
+            x = self.norm(x)
+        return x, attns
+
+
+# -------------------- Gated attention aggregation + task heads (token-count independent; information preserving) --------------------
+def masked_softmax(logits: torch.Tensor, mask: torch.Tensor, dim: int = -1) -> torch.Tensor:
+    """
+    logits: (..., T)
+    mask: (..., T) bool, True = valid
+    """
+    neg_inf = torch.finfo(logits.dtype).min
+    logits = logits.masked_fill(~mask, neg_inf)
+    return torch.softmax(logits, dim=dim)
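+
+# Hedged sketch of the masking semantics: with mask = [[True, True, False]],
+# the last logit is filled with the dtype minimum, so softmax gives it ~0 weight
+# and the distribution renormalizes over the two valid positions.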
+
+class GatedAttnAggregator(nn.Module):
+    """
+    Gated attention aggregator (learnable queries + masked softmax + value gating).
+    Input:  x: (B*, T, D), mask: (B*, T) bool
+    Output: slots: (B*, R, D) where R is the (configurable) number of slots
+    """
+    def __init__(self, d_model: int, num_slots: int = 4, d_att: int = None, dropout: float = 0.1):
+        super().__init__()
+        self.d_model = d_model
+        self.R = num_slots
+        self.d_att = d_att or d_model
+
+        # Learnable queries (R of them)
+        self.query = nn.Parameter(torch.randn(self.R, self.d_att) / (self.d_att ** 0.5))
+
+        # Linear projections
+        self.key_proj = nn.Linear(d_model, self.d_att)
+        self.val_proj = nn.Linear(d_model, d_model)
+
+        # Value-side gating (a scalar gate per token)
+        self.gate = nn.Sequential(
+            nn.Linear(d_model, d_model // 2),
+            nn.GELU(),
+            nn.Linear(d_model // 2, 1),
+            nn.Sigmoid()
+        )
+
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(self, x_bt_t_d: torch.Tensor, mask_bt: torch.Tensor) -> torch.Tensor:
+        """
+        x_bt_t_d: (B*, T, D)
+        mask_bt: (B*, T) bool
+        return: slots (B*, R, D)
+        """
+        BStar, T, D = x_bt_t_d.shape
+        K = self.key_proj(x_bt_t_d)   # (B*, T, d_att)
+        V = self.val_proj(x_bt_t_d)   # (B*, T, D)
+        g = self.gate(x_bt_t_d)       # (B*, T, 1)
+        Vg = V * g                    # gated values
+
+        Q = self.query.unsqueeze(0).expand(BStar, -1, -1)  # (B*, R, d_att)
+
+        logits = torch.matmul(Q, K.transpose(1, 2)) / (self.d_att ** 0.5)  # (B*, R, T)
+        attn_mask = mask_bt.unsqueeze(1)  # (B*, 1, T)
+        attn = masked_softmax(logits, attn_mask, dim=-1)
+        attn = self.dropout(attn)
+
+        slots = torch.matmul(attn, Vg)  # (B*, R, D)
+        return slots
+
+class AttnPoolHeadForecast(nn.Module):
+    """
+    Forecasting head: gated attention pooling into R slots, then a projection to
+    target_window (pred_len).
+    Output: (B, pred_len, nvars)
+    """
+    def __init__(self, d_model: int, target_window: int, num_slots: int = 4, dropout: float = 0.1):
+        super().__init__()
+        self.agg = GatedAttnAggregator(d_model, num_slots=num_slots, dropout=dropout)
+        self.proj = nn.Sequential(
+            nn.LayerNorm(num_slots * d_model),
+            nn.Linear(num_slots * d_model, target_window),
+        )
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(self, enc_out_bt_t_d: torch.Tensor, key_padding_mask_bt: torch.Tensor, n_vars: int, B: int):
+        slots = self.agg(enc_out_bt_t_d, key_padding_mask_bt)  # (B*, R, D)
+        slots = slots.reshape(B, n_vars, -1)                   # (B, nvars, R*D)
+        out = self.proj(self.dropout(slots))                   # (B, nvars, pred_len)
+        return out.permute(0, 2, 1)                            # (B, pred_len, nvars)
+
+class AttnPoolHeadSeq(nn.Module):
+    """
+    Sequence-reconstruction head: gated attention pooling, then a projection to seq_len.
+    Output: (B, seq_len, nvars)
+    """
+    def __init__(self, d_model: int, target_window: int, num_slots: int = 4, dropout: float = 0.1):
+        super().__init__()
+        self.agg = GatedAttnAggregator(d_model, num_slots=num_slots, dropout=dropout)
+        self.proj = nn.Sequential(
+            nn.LayerNorm(num_slots * d_model),
+            nn.Linear(num_slots * d_model, target_window),
+        )
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(self, enc_out_bt_t_d: torch.Tensor, key_padding_mask_bt: torch.Tensor, n_vars: int, B: int):
+        slots = self.agg(enc_out_bt_t_d, key_padding_mask_bt)  # (B*, R, D)
+        slots = slots.reshape(B, n_vars, -1)                   # (B, nvars, R*D)
+        out = self.proj(self.dropout(slots))                   # (B, nvars, seq_len)
+        return out.permute(0, 2, 1)                            # (B, seq_len, nvars)
+
+class AttnPoolHeadCls(nn.Module):
+    """
+    Classification head: per-variable gated attention pooling into R slots, then a
+    linear classifier over the concatenation of all variables.
+    Output: (B, num_class)
+    """
+    def __init__(self, d_model: int, n_vars: int, num_class: int, num_slots: int = 4, dropout: float = 0.1):
+        super().__init__()
+        self.agg = GatedAttnAggregator(d_model, num_slots=num_slots, dropout=dropout)
+        self.dropout = nn.Dropout(dropout)
+        self.proj = nn.Sequential(
+            nn.LayerNorm(n_vars * num_slots * d_model),
+            nn.Linear(n_vars * num_slots * d_model, num_class),
+        )
+        self.n_vars = n_vars
+        self.num_slots = num_slots
+        self.d_model = d_model
+
+    def forward(self, enc_out_bt_t_d: torch.Tensor, key_padding_mask_bt: torch.Tensor, n_vars: int, B: int):
+        slots = self.agg(enc_out_bt_t_d, key_padding_mask_bt)             # (B*, R, D)
+        slots = slots.reshape(B, n_vars, self.num_slots * self.d_model)   # (B, nvars, R*D)
+        flat = self.dropout(slots.reshape(B, -1))                         # (B, nvars*R*D)
+        return self.proj(flat)
+
+
+# -------------------- Main model: two-stage DC (ratio-controlled) + masked encoder + gated pooling heads --------------------
+class Transpose(nn.Module):
+    def __init__(self, *dims, contiguous=False):
+        super().__init__()
+        self.dims, self.contiguous = dims, contiguous
+    def forward(self, x):
+        return x.transpose(*self.dims).contiguous() if self.contiguous else x.transpose(*self.dims)
+
+class Model(nn.Module):
+    """
+    PatchTST with DC and masked attention + gated heads:
+    - replaces PatchEmbedding with a two-stage Mamba2 encoder + dynamic chunking
+    - DC compression strength is steered by ratio losses (target_ratio0/1); deeper
+      stages shorten the sequence while d_model grows (D0 -> D1)
+    - the attention receives a key_padding_mask to ignore padding
+    - the heads use gated attention pooling (token-count independent, information preserving)
+    """
+
+    def __init__(
+        self, configs,
+        d_model_stage0=None,     # D0, defaults to d_model // 2
+        depth_enc0=1, depth_enc1=1,
+        target_ratio0=0.25,      # roughly 1/N0
+        target_ratio1=0.5,       # roughly 1/N1
+        agg_slots=4,             # number of gated-pooling slots
+    ):
+        super().__init__()
+        self.task_name = configs.task_name
+        self.seq_len = configs.seq_len
+        self.pred_len = configs.pred_len
+        self.enc_in = configs.enc_in
+
+        # DC embedding
+        D1 = configs.d_model
+        D0 = d_model_stage0 if d_model_stage0 is not None else max(16, D1 // 2)
+        assert D1 >= D0, "requires D0 <= D1"
+        self.dc_embedding = DCEmbedding2StageVarLen(
+            d_model_out=D1,
+            d_model_stage0=D0,
+            depth_enc0=depth_enc0,
+            depth_enc1=depth_enc1,
+            dropout=configs.dropout,
+            target_ratio0=target_ratio0,
+            target_ratio1=target_ratio1,
+        )
+
+        # Masked encoder
+        attn_layers = [
+            EncoderLayerWithMask(
+                AttentionLayer(
+                    FullAttention(False, configs.factor, attention_dropout=configs.dropout, output_attention=False),
+                    D1, configs.n_heads
+                ),
+                d_model=D1,
+                d_ff=configs.d_ff,
+                dropout=configs.dropout,
+                activation=configs.activation
+            ) for _ in range(configs.e_layers)
+        ]
+        self.encoder = EncoderWithMask(
+            attn_layers,
+            norm_layer=nn.Sequential(Transpose(1, 2), nn.BatchNorm1d(D1), Transpose(1, 2))
+        )
+
+        # Gated pooling heads (independent of token count)
+        if self.task_name in ('long_term_forecast', 'short_term_forecast'):
+            self.head = AttnPoolHeadForecast(D1, self.pred_len, num_slots=agg_slots, dropout=configs.dropout)
+        elif self.task_name in ('imputation', 'anomaly_detection'):
+            self.head = AttnPoolHeadSeq(D1, self.seq_len, num_slots=agg_slots, dropout=configs.dropout)
+        elif self.task_name == 'classification':
+            self.head_cls = AttnPoolHeadCls(D1, n_vars=self.enc_in, num_class=configs.num_class, num_slots=agg_slots, dropout=configs.dropout)
+
+    # --------- Normalization / de-normalization ---------
+    def _pre_norm(self, x):
+        means = x.mean(1, keepdim=True).detach()
+        x = x - means
+        stdev = torch.sqrt(torch.var(x, dim=1, keepdim=True, unbiased=False) + 1e-5)
+        x = x / stdev
+        return x, means, stdev
+
+    def _denorm(self, y, means, stdev, length):
+        return y * (stdev[:, 0, :].unsqueeze(1).repeat(1, length, 1)) + \
+               (means[:, 0, :].unsqueeze(1).repeat(1, length, 1))
+
+    # --------- DC + Transformer encoder (carrying the key_padding_mask) ----------
+    def _embed_and_encode(self, x_enc):
+        """
+        x_enc: (B, L, C)
+        Returns:
+            enc_out: (B*nvars, T_max, D1)
+            n_vars: int
+            key_padding_mask: (B*nvars, T_max)
+            aux: dict
+        """
+        B, L, C = x_enc.shape
+        x_vars = x_enc.permute(0, 2, 1)  # (B, nvars, L)
+        enc_out, key_padding_mask, n_vars, aux = self.dc_embedding(x_vars)
+        enc_out, _ = self.encoder(enc_out, attn_mask=None, key_padding_mask=key_padding_mask)
+        return enc_out, n_vars, key_padding_mask, B, aux
+
+    # --------- Per-task forwards ---------
+    def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
+        x_enc, means, stdev = self._pre_norm(x_enc)
+        enc_out, n_vars, key_padding_mask, B, aux = self._embed_and_encode(x_enc)
+        dec_out = self.head(enc_out, key_padding_mask, n_vars, B)  # (B, pred_len, nvars)
+        dec_out = self._denorm(dec_out, means, stdev, self.pred_len)
+        return dec_out, aux
+
+    def imputation(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask):
+        means = torch.sum(x_enc, dim=1) / torch.sum(mask == 1, dim=1)
+        means = means.unsqueeze(1).detach()
+        x = x_enc - means
+        x = x.masked_fill(mask == 0, 0)
+        stdev = torch.sqrt(torch.sum(x * x, dim=1) / torch.sum(mask == 1, dim=1) + 1e-5)
+        stdev = stdev.unsqueeze(1).detach()
+        x = x / stdev
+
+        enc_out, n_vars, key_padding_mask, B, aux = self._embed_and_encode(x)
+        dec_out = self.head(enc_out, key_padding_mask, n_vars, B)  # (B, seq_len, nvars)
+        dec_out = self._denorm(dec_out, means, stdev, self.seq_len)
+        return dec_out, aux
+
+    def anomaly_detection(self, x_enc):
+        x_enc, means, stdev = self._pre_norm(x_enc)
+        enc_out, n_vars, key_padding_mask, B, aux = self._embed_and_encode(x_enc)
+        dec_out = self.head(enc_out, key_padding_mask, n_vars, B)  # (B, seq_len, nvars)
+        dec_out = self._denorm(dec_out, means, stdev, self.seq_len)
+        return dec_out, aux
+
+    def classification(self, x_enc, x_mark_enc):
+        x_enc, _, _ = self._pre_norm(x_enc)
+        enc_out, n_vars, key_padding_mask, B, aux = self._embed_and_encode(x_enc)
+        logits = self.head_cls(enc_out, key_padding_mask, n_vars, B)  # (B, num_class)
+        return logits, aux
+
+    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
+        if self.task_name in ('long_term_forecast', 'short_term_forecast'):
+            dec_out, aux = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
+            return dec_out[:, -self.pred_len:, :], aux  # [B, L, D]; aux carries the ratio losses
+        if self.task_name == 'imputation':
+            dec_out, aux = self.imputation(x_enc, x_mark_enc, x_dec, x_mark_dec, mask)
+            return dec_out, aux
+        if self.task_name == 'anomaly_detection':
+            dec_out, aux = self.anomaly_detection(x_enc)
+            return dec_out, aux
+        if self.task_name == 'classification':
+            logits, aux = self.classification(x_enc, x_mark_enc)
+            return logits, aux
+        return None, None
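+
+# Hedged usage sketch (the config names follow this repo's `configs` object):
+#   model = Model(configs)
+#   out, aux = model(x_enc, x_mark_enc, x_dec, x_mark_dec)
+#   loss = criterion(out, batch_y) + lambda_ratio * aux["ratio_loss0"]
+# where lambda_ratio is a hyperparameter to tune (e.g. ~0.03 as in DC_hnet below).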
diff --git a/models/DC_hnet.py b/models/DC_hnet.py
new file mode 100644
index 0000000..6526f7a
--- /dev/null
+++ b/models/DC_hnet.py
@@ -0,0 +1,339 @@
+from dataclasses import dataclass
+from typing import Optional, Literal, List
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+# From the local codebase (can be used directly)
+from hnet.modules.dc import RoutingModule, ChunkLayer
+from hnet.modules.isotropic import Isotropic
+from hnet.models.config_hnet import HNetConfig, SSMConfig, AttnConfig
+
+# -------------------- Helpers --------------------
+def create_isotropic_encoder(d_model, arch="m", height=4, device=None, dtype=None):
+    """Create a simplified Isotropic encoder."""
+    factory_kwargs = {"device": device, "dtype": dtype}
+
+    # Build an HNetConfig; the list fields need at least one element each
+    config = HNetConfig(
+        arch_layout=[f"{arch}{height}"],
+        d_model=[d_model],
+        d_intermediate=[d_model * 2],
+        ssm_cfg=SSMConfig(
+            d_conv=4,
+            expand=2,
+            d_state=128,
+            chunk_size=256
+        ),
+        attn_cfg=AttnConfig(
+            num_heads=[8],         # at least one element
+            rotary_emb_dim=[0],    # at least one element
+            window_size=[-1]       # at least one element
+        )
+    )
+
+    return Isotropic(
+        config=config,
+        pos_idx=0,
+        stage_idx=0,
+        **factory_kwargs
+    )
+
+def ratio_loss(boundary_mask: torch.Tensor, boundary_prob: torch.Tensor, target_N: int) -> torch.Tensor:
+    F_act = boundary_mask.float().mean(dim=1)   # (B,)
+    G_prob = boundary_prob[..., 1].mean(dim=1)  # (B,)
+    N = float(target_N)
+    # L = N/(N-1) * ((N-1)*F*G + (1-F)*(1-G)), minimized at F = G = 1/N
+    # (matches the ratio loss used in models/DC_PatchTST.py)
+    loss = N / (N - 1.0) * (((N - 1.0) * F_act * G_prob) + (1.0 - F_act) * (1.0 - G_prob))
+    return loss.mean()
+
+def masked_mean(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
+    mask_f = mask.float().unsqueeze(-1)  # (B, L, 1)
+    s = (x * mask_f).sum(dim=1)          # (B, D)
+    denom = mask_f.sum(dim=1).clamp_min(1.0)
+    return s / denom
+
+# -------------------- Multi-stage encoder pyramid: per-stage Mamba2 + routed downsampling; a single main network at the top --------------------
+class PyramidEncoders_NoDechunk(nn.Module):
+    """
+    Hierarchy (encoders compress stage by stage; only the final stage feeds a main network):
+    input x0: (B, L0, 1)
+    - linear projection -> D0
+    For s = 0..S-1:
+        Es(Mamba2, D_s) -> h_s (B, L_s, D_s)
+        routing + downsampling -> x_{s+1} (B, L_{s+1}, D_s), mask_{s+1}
+        width expansion D_s -> D_{s+1} (concatenating a shared vector)
+    The final x_S: (B, L_S, D_S) feeds a single main network M (Transformer/Mamba).
+    Cross-scale fusion (no dechunking): fuse E^0's pooled_enc0 with the main
+    network's pooled_main.
+    """
+    def __init__(
+        self,
+        d_models: List[int],                  # [D0, D1, ..., D_S], non-decreasing
+        encoder_cfg_per_stage: List[dict],    # S encoder configs (arch must be 'm'/'M')
+        main_cfg: dict,                       # single main-network config (runs on the most compressed sequence)
+        fusion_dropout: float = 0.1,
+        dtype: Optional[torch.dtype] = None,
+        device: Optional[torch.device] = None,
+    ):
+        super().__init__()
+        factory_kwargs = {"device": device, "dtype": dtype}
+
+        assert len(d_models) >= 1
+        S = len(d_models) - 1
+        assert S == len(encoder_cfg_per_stage), "number of stages must equal number of encoder configs"
+        for i in range(S):
+            assert d_models[i+1] >= d_models[i], "requires D_s <= D_{s+1} (widths grow monotonically)"
+            assert encoder_cfg_per_stage[i].get("arch", "m") in ("m", "M"), "encoders must be Mamba2"
+
+        self.S = S
+        self.d_models = d_models
+
+        # Project the input up to D0
+        self.input_proj = nn.Linear(1, d_models[0], **factory_kwargs)
+
+        # Per-stage encoder + routing + downsampling + widening parameters
+        self.encoders = nn.ModuleList()
+        self.routers = nn.ModuleList()
+        self.chunks = nn.ModuleList()
+        self.pad_vectors = nn.ParameterList()
+        for s in range(S):
+            self.encoders.append(
+                create_isotropic_encoder(
+                    d_model=d_models[s],
+                    **{k: v for k, v in encoder_cfg_per_stage[s].items() if k != "d_model"},
+                    **factory_kwargs
+                )
+            )
+            self.routers.append(RoutingModule(d_models[s], **factory_kwargs))
+            self.chunks.append(ChunkLayer())
+            delta = d_models[s+1] - d_models[s]
+            self.pad_vectors.append(nn.Parameter(torch.zeros(delta, **factory_kwargs)) if delta > 0 else nn.Parameter(torch.empty(0, **factory_kwargs)))
+
+        # The single final main network: runs at width D_S over length L_S
+        self.main_network = create_isotropic_encoder(
+            d_model=d_models[-1],
+            **{k: v for k, v in main_cfg.items() if k != "d_model"},
+            **factory_kwargs
+        )
+
+        # Cross-scale fusion: project pooled_enc0 (D0) to D_S and fuse with pooled_main (D_S) -> D_S
+        self.proj_enc0_to_DS = nn.Linear(d_models[0], d_models[-1], **factory_kwargs)
+        self.fusion_head = nn.Sequential(
+            nn.Linear(d_models[-1] + d_models[-1], d_models[-1], **factory_kwargs),
+            nn.GELU(),
+            nn.Dropout(fusion_dropout),
+            nn.Linear(d_models[-1], d_models[-1], **factory_kwargs),
+        )
+
+    def _expand_width(self, x: torch.Tensor, pad_vec: nn.Parameter) -> torch.Tensor:
+        if pad_vec.numel() == 0:
+            return x
+        early = x.shape[:-1]
+        return torch.cat([x, pad_vec.expand(*early, -1)], dim=-1)
+
+    def forward(self, x_scalar: torch.Tensor, mask: Optional[torch.Tensor] = None, return_seq: bool = False):
+        """
+        x_scalar: (B, L) or (B, L, 1)
+        mask: (B, L) bool
+        Returns:
+            fused_vec: (B, D_S)
+            debug: optional
+            aux: per-stage routing info (for the ratio loss)
+        """
+        if x_scalar.dim() == 2:
+            x_scalar = x_scalar.unsqueeze(-1)  # (B, L, 1)
+        B, L, _ = x_scalar.shape
+        device = x_scalar.device
+        if mask is None:
+            mask = torch.ones(B, L, dtype=torch.bool, device=device)
+
+        # Initial projection to D0
+        x = self.input_proj(x_scalar)  # (B, L0, D0)
+        cur_mask = mask
+
+        pooled_enc0 = None
+        aux_per_stage = []
+        seq_debug = [] if return_seq else None
+
+        # Stage by stage: Encoder (Mamba2) -> Routing -> Chunk -> widen D
+        for s in range(self.S):
+            d_in = self.d_models[s]
+            # Fine-grained encoding (uncompressed sequence)
+            h_enc = self.encoders[s](x, mask=cur_mask)  # (B, L_s, D_s)
+
+            if s == 0:
+                pooled_enc0 = masked_mean(h_enc, cur_mask)  # (B, D0)
+
+            # Routing + downsampling (produces a shorter sequence)
+            bpred = self.routers[s](h_enc, mask=cur_mask)
+            x_next, _, _, mask_next = self.chunks[s](h_enc, bpred.boundary_mask, mask=cur_mask)  # (B, L_{s+1}, D_s)
+
+            # Widen D_s -> D_{s+1}
+            x_next = self._expand_width(x_next, self.pad_vectors[s])  # (B, L_{s+1}, D_{s+1})
+
+            # Advance to the next stage
+            x, cur_mask = x_next, mask_next
+
+            aux_per_stage.append({
+                "boundary_mask": bpred.boundary_mask,
+                "boundary_prob": bpred.boundary_prob,
+                "selected_probs": bpred.selected_probs,
+            })
+            if return_seq:
+                seq_debug.append({"stage": s, "seq": x, "mask": cur_mask})
+
+        # Now x: (B, L_S, D_S), cur_mask: (B, L_S)
+        # The single main network runs on the most compressed sequence
+        h_main = self.main_network(x, mask=cur_mask)  # (B, L_S, D_S)
+
+        # Pool the main network output
+        if cur_mask is None:
+            pooled_main = h_main.mean(dim=1)  # (B, D_S)
+        else:
+            pooled_main = (h_main * cur_mask.float().unsqueeze(-1)).sum(dim=1) / \
+                          cur_mask.float().sum(dim=1, keepdim=True).clamp_min(1.0)
+
+        # Cross-scale fusion: E^0 global pooling with main-network pooling
+        pooled_enc0_in_DS = self.proj_enc0_to_DS(pooled_enc0)        # (B, D_S)
+        fused = torch.cat([pooled_enc0_in_DS, pooled_main], dim=-1)  # (B, 2*D_S)
+        fused = self.fusion_head(fused)                              # (B, D_S)
+
+        aux = {"per_stage": aux_per_stage}
+        if return_seq:
+            return fused, {"stages": seq_debug, "main_seq": h_main, "main_mask": cur_mask}, aux
+        else:
+            return fused, None, aux
+
+# -------------------- Top level: multi-channel fusion + classification head (single main network) --------------------
+@dataclass
+class HierEncodersSingleMainConfig:
+    num_channels: int
+    d_models: List[int]                      # [D0, D1, ..., D_S], non-decreasing
+    num_classes: int
+    encoder_cfg_per_stage: List[dict]        # S encoder configs (all Mamba2, height ~ 4)
+    main_cfg: dict                           # single main-network config (Transformer or Mamba2); d_model is D_S automatically
+    target_compression_N_per_stage: List[int]
+    share_channel: bool = True
+    fusion_across_channels: Literal["mean", "concat"] = "mean"
+    dropout: float = 0.1
+
+class HierEncodersSingleMainClassifier(nn.Module):
+    def __init__(self, cfg: HierEncodersSingleMainConfig, dtype=None, device=None):
+        super().__init__()
+        self.cfg = cfg
+        factory_kwargs = {"dtype": dtype, "device": device}
+
+        S = len(cfg.d_models) - 1
+        assert S == len(cfg.encoder_cfg_per_stage) == len(cfg.target_compression_N_per_stage), "stage counts are inconsistent"
+
+        if cfg.share_channel:
+            self.channel_encoder = PyramidEncoders_NoDechunk(
+                d_models=cfg.d_models,
+                encoder_cfg_per_stage=cfg.encoder_cfg_per_stage,
+                main_cfg=cfg.main_cfg,
+                **factory_kwargs,
+            )
+        else:
+            self.channel_encoder = nn.ModuleList([
+                PyramidEncoders_NoDechunk(
+                    d_models=cfg.d_models,
+                    encoder_cfg_per_stage=cfg.encoder_cfg_per_stage,
+                    main_cfg=cfg.main_cfg,
+                    **factory_kwargs,
+                )
+                for _ in range(cfg.num_channels)
+            ])
+
+        fusion_dim = (cfg.num_channels * cfg.d_models[-1]) if cfg.fusion_across_channels == "concat" \
+            else cfg.d_models[-1]
+        self.dropout = nn.Dropout(cfg.dropout)
+        self.head = nn.Linear(fusion_dim, cfg.num_classes, **factory_kwargs)
+
+    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None, return_seq: bool = False):
+        """
+        x: (B, L, N) multi-channel input
+        mask: (B, L) temporal mask
+        """
+        B, L, N = x.shape
+        assert N == self.cfg.num_channels
+
+        channel_vecs: List[torch.Tensor] = []
+        ratio_losses = []
+        seq_dbg_all = [] if return_seq else None
+
+        for c in range(N):
+            x_c = x[..., c]  # (B, L)
+            if self.cfg.share_channel:
+                vec, seq_dbg, aux = self.channel_encoder(x_c, mask=mask, return_seq=return_seq)
+            else:
+                vec, seq_dbg, aux = self.channel_encoder[c](x_c, mask=mask, return_seq=return_seq)
+
+            # Accumulate the ratio loss (one term per encoder stage)
+            total_rl = 0.0
+            for s, aux_s in enumerate(aux["per_stage"]):
+                rl = ratio_loss(aux_s["boundary_mask"], aux_s["boundary_prob"], self.cfg.target_compression_N_per_stage[s])
+                total_rl = total_rl + rl
+            ratio_losses.append(total_rl)
+
+            channel_vecs.append(vec)
+            if return_seq:
+                seq_dbg_all.append(seq_dbg)
+
+        if self.cfg.fusion_across_channels == "concat":
+            fused = torch.cat(channel_vecs, dim=-1)               # (B, N*D_S)
+        else:
+            fused = torch.stack(channel_vecs, dim=1).mean(dim=1)  # (B, D_S)
+
+        fused = self.dropout(fused)
+        logits = self.head(fused)
+
+        aux_all = {"ratio_loss": torch.stack(ratio_losses).mean()}
+        if return_seq:
+            return logits, seq_dbg_all, aux_all
+        else:
+            return logits, None, aux_all
+
+# -------------------- Usage example --------------------
+if __name__ == "__main__":
+    """
+    Matches the requirements:
+    - multiple stages only add encoders (Mamba2 + dynamic chunking per stage);
+      there is a single main network at the top
+    - sequence length shrinks per stage (decided by DC) while d_model grows
+      monotonically (SpaceByte-style shared-vector concatenation)
+    - no dechunking; cross-scale fusion = E^0 global pooling + main-network pooling
+    """
+    B, L, N = 8, 1024, 6
+    num_classes = 7
+    d_models = [128, 256, 512]  # D0 <= D1 <= D2
+
+    encoder_cfg_per_stage = [
+        dict(arch="m", height=4, ssm_cfg=dict(), attn_cfg=dict()),  # stage 0 encoder (Mamba2)
+        dict(arch="m", height=4, ssm_cfg=dict(), attn_cfg=dict()),  # stage 1 encoder (Mamba2)
+    ]
+    main_cfg = dict(
+        arch="T", height=12, ssm_cfg=dict(), attn_cfg=dict(num_heads=8)  # final (heavier) main network
+    )
+    target_compression_N_per_stage = [4, 4]
+
+    cfg = HierEncodersSingleMainConfig(
+        num_channels=N,
+        d_models=d_models,
+        num_classes=num_classes,
+        encoder_cfg_per_stage=encoder_cfg_per_stage,
+        main_cfg=main_cfg,
+        target_compression_N_per_stage=target_compression_N_per_stage,
+        share_channel=True,
+        fusion_across_channels="mean",
+        dropout=0.1,
+    )
+
+    model = HierEncodersSingleMainClassifier(cfg).cuda().train()
+    x = torch.randn(B, L, N, device="cuda")
+    mask = torch.ones(B, L, dtype=torch.bool, device="cuda")
+
+    logits, _, aux = model(x, mask=mask, return_seq=False)
+    y = torch.randint(0, num_classes, (B,), device="cuda")
+    cls_loss = F.cross_entropy(logits, y)
+    ratio_reg = 0.03 * aux["ratio_loss"]
+    loss = cls_loss + ratio_reg
+    loss.backward()
+    print("logits:", logits.shape, "loss:", float(loss))
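+
+    # (sketch) a typical evaluation call on the same inputs:
+    #   model.eval()
+    #   with torch.no_grad():
+    #       logits, _, _ = model(x, mask=mask, return_seq=False)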
device="cuda") + mask = torch.ones(B, L, dtype=torch.bool, device="cuda") + + logits, _, aux = model(x, mask=mask, return_seq=False) + y = torch.randint(0, num_classes, (B,), device="cuda") + cls_loss = F.cross_entropy(logits, y) + ratio_reg = 0.03 * aux["ratio_loss"] + loss = cls_loss + ratio_reg + loss.backward() + print("logits:", logits.shape, "loss:", float(loss)) diff --git a/models/vanillaMamba-Copy1.py b/models/vanillaMamba-Copy1.py new file mode 100644 index 0000000..66f9272 --- /dev/null +++ b/models/vanillaMamba-Copy1.py @@ -0,0 +1,138 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from mamba_ssm import Mamba2 + + +class ValueEmbedding(nn.Module): + """ + 对每个时间步的单通道标量做线性投影到 d_model,并可选 Dropout。 + 不包含 temporal embedding 和 positional embedding。 + """ + def __init__(self, in_dim: int, d_model: int, dropout: float = 0.0, bias: bool = True): + super().__init__() + self.proj = nn.Linear(in_dim, d_model, bias=bias) + self.dropout = nn.Dropout(dropout) if dropout and dropout > 0.0 else nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # x: [B, L, 1] -> [B, L, d_model] + return self.dropout(self.proj(x)) + + +class ChannelMambaBlock(nn.Module): + """ + 针对单个通道的两层 Mamba-2 处理块: + - 输入: [B, L, 1],先做投影到 d_model + - 两层 Mamba2,且在第一层输出和第二层输出均添加残差连接 + - 每层后接 LayerNorm + - 输出: [B, L, d_model] + """ + def __init__(self, d_model: int, dropout: float, m2_kwargs: dict): + super().__init__() + self.embed = ValueEmbedding(in_dim=1, d_model=d_model, dropout=dropout, bias=True) + + # 两层 Mamba-2 + self.mamba1 = Mamba2(d_model=d_model, **m2_kwargs) + self.mamba2 = Mamba2(d_model=d_model, **m2_kwargs) + + # 每层后接的归一化 + self.ln1 = nn.LayerNorm(d_model) + self.ln2 = nn.LayerNorm(d_model) + + def forward(self, x_ch: torch.Tensor) -> torch.Tensor: + # x_ch: [B, L, 1] + x = self.embed(x_ch) # [B, L, d_model] + + # 第一层 + 残差 + y1 = self.mamba1(x) # [B, L, d_model] + y1 = self.ln1(x + y1) # 残差1 + LN + + # 第二层 + 残差 + y2 = self.mamba2(y1) # [B, L, d_model] + y2 = self.ln2(y1 + y2) # 残差2 + LN + + return y2 # [B, L, d_model] + + +class Model(nn.Module): + """ + 按通道独立处理的 Mamba-2 分类模型: + - 将输入的每个通道拆开,分别使用独立的两层 Mamba2(含两处残差) + - 每个通道得到 [B, L, d_model] 输出 + - 取各通道最后时间步的表示拼接,接分类头 + 输入: + - x_enc: [B, L, D] 多变量时间序列 + 输出: + - logits: [B, num_class] + """ + def __init__(self, configs): + super().__init__() + self.task_name = getattr(configs, 'task_name', 'classification') + assert self.task_name == 'classification', "当前模型仅实现 classification 任务" + + # 基本配置 + self.enc_in = configs.enc_in # 通道数 D + self.d_model = configs.d_model # 每通道的模型维度 + self.num_class = configs.num_class + self.dropout = getattr(configs, 'dropout', 0.1) + + # Mamba-2 超参数(按需从 configs 读取) + # 注意:此处不再使用 e_layers 的堆叠,而是固定每通道两层以满足“在第一层和第二层输出处添加残差”的要求 + m2_kwargs = dict( + d_state=getattr(configs, 'd_state', 64), + d_conv=getattr(configs, 'd_conv', 4), + expand=getattr(configs, 'expand', 2), + headdim=getattr(configs, 'headdim', 64), + d_ssm=getattr(configs, 'd_ssm', None), + ngroups=getattr(configs, 'ngroups', 1), + A_init_range=getattr(configs, 'A_init_range', (1, 16)), + D_has_hdim=getattr(configs, 'D_has_hdim', False), + rmsnorm=getattr(configs, 'rmsnorm', True), + norm_before_gate=getattr(configs, 'norm_before_gate', False), + dt_min=getattr(configs, 'dt_min', 0.001), + dt_max=getattr(configs, 'dt_max', 0.1), + dt_init_floor=getattr(configs, 'dt_init_floor', 1e-4), + dt_limit=getattr(configs, 'dt_limit', (0.0, float("inf"))), + bias=getattr(configs, 'bias', False), + conv_bias=getattr(configs, 'conv_bias', 
+class Model(nn.Module):
+    """
+    Channel-independent Mamba-2 classification model:
+    - Split the input into channels; each channel runs its own two-layer Mamba2 block (with two residual connections)
+    - Each channel produces a [B, L, d_model] output
+    - Concatenate the last-time-step representation of every channel and feed the classification head
+    Inputs:
+        - x_enc: [B, L, D] multivariate time series
+    Outputs:
+        - logits: [B, num_class]
+    """
+    def __init__(self, configs):
+        super().__init__()
+        self.task_name = getattr(configs, 'task_name', 'classification')
+        assert self.task_name == 'classification', "this model only implements the classification task"
+
+        # Basic configuration
+        self.enc_in = configs.enc_in      # number of channels D
+        self.d_model = configs.d_model    # per-channel model width
+        self.num_class = configs.num_class
+        self.dropout = getattr(configs, 'dropout', 0.1)
+
+        # Mamba-2 hyperparameters (read from configs as needed)
+        # Note: instead of stacking e_layers, each channel uses a fixed two layers to satisfy the
+        # requirement of "a residual connection at the output of layer 1 and layer 2"
+        m2_kwargs = dict(
+            d_state=getattr(configs, 'd_state', 64),
+            d_conv=getattr(configs, 'd_conv', 4),
+            expand=getattr(configs, 'expand', 2),
+            headdim=getattr(configs, 'headdim', 64),
+            d_ssm=getattr(configs, 'd_ssm', None),
+            ngroups=getattr(configs, 'ngroups', 1),
+            A_init_range=getattr(configs, 'A_init_range', (1, 16)),
+            D_has_hdim=getattr(configs, 'D_has_hdim', False),
+            rmsnorm=getattr(configs, 'rmsnorm', True),
+            norm_before_gate=getattr(configs, 'norm_before_gate', False),
+            dt_min=getattr(configs, 'dt_min', 0.001),
+            dt_max=getattr(configs, 'dt_max', 0.1),
+            dt_init_floor=getattr(configs, 'dt_init_floor', 1e-4),
+            dt_limit=getattr(configs, 'dt_limit', (0.0, float("inf"))),
+            bias=getattr(configs, 'bias', False),
+            conv_bias=getattr(configs, 'conv_bias', True),
+            chunk_size=getattr(configs, 'chunk_size', 256),
+            use_mem_eff_path=getattr(configs, 'use_mem_eff_path', True),
+        )
+
+        # Build an independent two-layer Mamba2 block for each channel
+        self.channel_blocks = nn.ModuleList([
+            ChannelMambaBlock(d_model=self.d_model, dropout=self.dropout, m2_kwargs=m2_kwargs)
+            for _ in range(self.enc_in)
+        ])
+
+        # Classification head: concatenate the last-step representations of all channels -> GELU -> Dropout -> Linear
+        self.act = nn.GELU()
+        self.head = nn.Sequential(
+            nn.Dropout(self.dropout),
+            nn.Linear(self.d_model * self.enc_in, self.num_class)
+        )
+
+    def classification(self, x_enc: torch.Tensor) -> torch.Tensor:
+        # x_enc: [B, L, D]
+        B, L, D = x_enc.shape
+        assert D == self.enc_in, f"input channel count {D} does not match enc_in {self.enc_in}"
+
+        per_channel_last = []
+        for c in range(D):
+            # Extract the single-channel sequence [B, L] -> [B, L, 1]
+            x_ch = x_enc[:, :, c].unsqueeze(-1)
+            y_ch = self.channel_blocks[c](x_ch)      # [B, L, d_model]
+            per_channel_last.append(y_ch[:, -1, :])  # [B, d_model]
+
+        # Concatenate the last-step representations of all channels -> [B, D * d_model]
+        h_last = torch.cat(per_channel_last, dim=-1)
+
+        # Classification head
+        h_last = self.act(h_last)
+        logits = self.head(h_last)  # [B, num_class]
+        return logits
+
+    # Keep the same forward signature as TimesNet; x_mark_enc / x_dec / x_mark_dec / mask are ignored
+    def forward(self, x_enc, x_mark_enc=None, x_dec=None, x_mark_dec=None, mask=None):
+        return self.classification(x_enc)
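A shape-only sketch (hypothetical sizes) of the pooling this classifier uses: the last time step of each channel's encoding is concatenated, so the head's input width is D * d_model.

import torch
import torch.nn as nn

B, L, D, d_model, num_class = 4, 96, 7, 128, 10
per_channel = [torch.randn(B, L, d_model) for _ in range(D)]    # one encoded sequence per channel

h_last = torch.cat([y[:, -1, :] for y in per_channel], dim=-1)  # (B, D*d_model)
head = nn.Sequential(nn.GELU(), nn.Dropout(0.1), nn.Linear(D * d_model, num_class))
assert head(h_last).shape == (B, num_class)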
diff --git a/models/vanillaMamba.py b/models/vanillaMamba.py
new file mode 100644
index 0000000..02414d6
--- /dev/null
+++ b/models/vanillaMamba.py
@@ -0,0 +1,203 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from mamba_ssm import Mamba2
+
+
+class ValueEmbedding(nn.Module):
+    """
+    Linearly project the single-channel scalar at each time step to d_model, with optional Dropout.
+    No temporal embedding and no positional embedding.
+    """
+    def __init__(self, in_dim: int, d_model: int, dropout: float = 0.0, bias: bool = True):
+        super().__init__()
+        self.proj = nn.Linear(in_dim, d_model, bias=bias)
+        self.dropout = nn.Dropout(dropout) if dropout and dropout > 0.0 else nn.Identity()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # x: [B, L, 1] -> [B, L, d_model]
+        return self.dropout(self.proj(x))
+
+
+class ChannelMambaBlock(nn.Module):
+    """
+    Two-layer Mamba-2 block for a single channel:
+    - Input: [B, L, 1], first projected to d_model
+    - Two Mamba2 layers, with a residual connection at the output of both the first and the second layer
+    - LayerNorm after each layer
+    - Output: [B, L, d_model]
+    """
+    def __init__(self, d_model: int, dropout: float, m2_kwargs: dict):
+        super().__init__()
+        self.embed = ValueEmbedding(in_dim=1, d_model=d_model, dropout=dropout, bias=True)
+
+        # Two Mamba-2 layers
+        self.mamba1 = Mamba2(d_model=d_model, **m2_kwargs)
+        self.mamba2 = Mamba2(d_model=d_model, **m2_kwargs)
+
+        # Normalization after each layer
+        self.ln1 = nn.LayerNorm(d_model)
+        self.ln2 = nn.LayerNorm(d_model)
+
+    def forward(self, x_ch: torch.Tensor) -> torch.Tensor:
+        # x_ch: [B, L, 1]
+        x = self.embed(x_ch)    # [B, L, d_model]
+
+        # First layer + residual
+        y1 = self.mamba1(x)     # [B, L, d_model]
+        y1 = self.ln1(x + y1)   # residual 1 + LN
+
+        # Second layer + residual
+        y2 = self.mamba2(y1)    # [B, L, d_model]
+        y2 = self.ln2(y1 + y2)  # residual 2 + LN
+
+        return y2  # [B, L, d_model]
+
+
+class Model(nn.Module):
+    """
+    Channel-independent Mamba-2 model supporting:
+    - Classification: encode each channel independently, concatenate the last time steps -> classification head
+    - Long/short-term forecasting: encode each channel independently, keep the full sequence, map the time
+      dimension linearly to the target length, then project back to a scalar and concatenate
+    Note: the number of forecast output channels strictly equals the number of input channels (per-channel forecasting).
+
+    Inputs:
+        - x_enc: [B, L, D] multivariate time series
+        - x_mark_enc, x_dec, x_mark_dec, mask: interface-compatibility arguments (unused by this model)
+
+    Outputs:
+        - classification: logits [B, num_class]
+        - forecast: [B, pred_len, D]
+    """
+    def __init__(self, configs):
+        super().__init__()
+        # Task type
+        self.task_name = getattr(configs, 'task_name', 'classification')
+        assert self.task_name in ['classification', 'long_term_forecast', 'short_term_forecast'], \
+            "only classification / long_term_forecast / short_term_forecast are supported"
+
+        # Basic configuration
+        self.enc_in = configs.enc_in      # number of channels D
+        self.d_model = configs.d_model    # per-channel model width
+        self.num_class = getattr(configs, 'num_class', None)
+        self.dropout = getattr(configs, 'dropout', 0.1)
+
+        # Forecasting-related
+        self.seq_len = getattr(configs, 'seq_len', None)
+        self.pred_len = getattr(configs, 'pred_len', None)
+        if self.task_name in ['long_term_forecast', 'short_term_forecast']:
+            assert self.seq_len is not None and self.pred_len is not None, "forecasting tasks require seq_len and pred_len"
+            # Output channels must match input channels
+            self.c_out = getattr(configs, 'c_out', self.enc_in)
+            assert self.c_out == self.enc_in, "forecasting requires the output channels c_out to equal the input channels enc_in"
+
+        # Mamba-2 hyperparameters
+        m2_kwargs = dict(
+            d_state=getattr(configs, 'd_state', 64),
+            d_conv=getattr(configs, 'd_conv', 4),
+            expand=getattr(configs, 'expand', 2),
+            headdim=getattr(configs, 'headdim', 64),
+            d_ssm=getattr(configs, 'd_ssm', None),
+            ngroups=getattr(configs, 'ngroups', 1),
+            A_init_range=getattr(configs, 'A_init_range', (1, 16)),
+            D_has_hdim=getattr(configs, 'D_has_hdim', False),
+            rmsnorm=getattr(configs, 'rmsnorm', True),
+            norm_before_gate=getattr(configs, 'norm_before_gate', False),
+            dt_min=getattr(configs, 'dt_min', 0.001),
+            dt_max=getattr(configs, 'dt_max', 0.1),
+            dt_init_floor=getattr(configs, 'dt_init_floor', 1e-4),
+            dt_limit=getattr(configs, 'dt_limit', (0.0, float("inf"))),
+            bias=getattr(configs, 'bias', False),
+            conv_bias=getattr(configs, 'conv_bias', True),
+            chunk_size=getattr(configs, 'chunk_size', 256),
+            use_mem_eff_path=getattr(configs, 'use_mem_eff_path', True),
+        )
+
+        # Build an independent two-layer Mamba2 block for each channel
+        self.channel_blocks = nn.ModuleList([
+            ChannelMambaBlock(d_model=self.d_model, dropout=self.dropout, m2_kwargs=m2_kwargs)
+            for _ in range(self.enc_in)
+        ])
+
+        # Classification head: concatenate the last-step representations of all channels -> GELU -> Dropout -> Linear
+        if self.task_name == 'classification':
+            assert self.num_class is not None, "classification requires num_class"
+            self.act = nn.GELU()
+            self.head = nn.Sequential(
+                nn.Dropout(self.dropout),
+                nn.Linear(self.d_model * self.enc_in, self.num_class)
+            )
+
+        # Forecast head:
+        # - first map the time dimension linearly: [B, L, d_model] -> [B, pred_len, d_model]
+        # - then project d_model down to a single scalar channel: [B, pred_len, d_model] -> [B, pred_len, 1]
+        if self.task_name in ['long_term_forecast', 'short_term_forecast']:
+            self.predict_linear = nn.Linear(self.seq_len, self.pred_len)
+            self.projection = nn.Linear(self.d_model, 1, bias=True)
+
+    def classification(self, x_enc: torch.Tensor) -> torch.Tensor:
+        # x_enc: [B, L, D]
+        B, L, D = x_enc.shape
+        assert D == self.enc_in, f"input channel count {D} does not match enc_in {self.enc_in}"
+
+        per_channel_last = []
+        for c in range(D):
+            # Extract the single-channel sequence [B, L] -> [B, L, 1]
+            x_ch = x_enc[:, :, c].unsqueeze(-1)
+            y_ch = self.channel_blocks[c](x_ch)      # [B, L, d_model]
+            per_channel_last.append(y_ch[:, -1, :])  # [B, d_model]
+
+        # Concatenate the last-step representations of all channels -> [B, D * d_model]
+        h_last = torch.cat(per_channel_last, dim=-1)
+
+        # Classification head
+        logits = self.head(self.act(h_last))  # [B, num_class]
+        return logits
+
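The forecast path below normalizes each series over the time dimension (as in the Non-stationary Transformer) and de-normalizes the output; a self-contained sketch of that round trip:

import torch

x_enc = torch.randn(4, 96, 7)                 # (B, L, D)
means = x_enc.mean(1, keepdim=True).detach()  # (B, 1, D)
stdev = torch.sqrt((x_enc - means).var(dim=1, keepdim=True, unbiased=False) + 1e-5)
x = (x_enc - means) / stdev                   # model runs in normalized space

y = x                                         # stand-in for the model output
y_denorm = y * stdev + means                  # de-normalize
assert torch.allclose(y_denorm, x_enc, atol=1e-4)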
+    def forecast(self, x_enc: torch.Tensor) -> torch.Tensor:
+        """
+        Per-channel forecasting:
+        - Normalize over the time dimension; encode each channel independently
+        - Use the full Mamba output sequence, map the time dimension linearly to the target length, then project to a scalar
+        - De-normalize
+        Returns:
+            dec_out: [B, pred_len, D]; forward then slices the last pred_len steps (a no-op here)
+        """
+        B, L, D = x_enc.shape
+        assert L == self.seq_len, f"input length {L} does not match the configured seq_len {self.seq_len}"
+        assert D == self.enc_in, f"input channel count {D} does not match enc_in {self.enc_in}"
+
+        # Normalization (per Non-stationary Transformer)
+        means = x_enc.mean(1, keepdim=True).detach()  # [B, 1, D]
+        x = x_enc - means
+        stdev = torch.sqrt(x.var(dim=1, keepdim=True, unbiased=False) + 1e-5)  # [B, 1, D]
+        x = x / stdev
+
+        per_channel_seq = []
+        for c in range(D):
+            x_ch = x[:, :, c].unsqueeze(-1)      # [B, L, 1]
+            h_ch = self.channel_blocks[c](x_ch)  # [B, L, d_model]
+            # Map the time dimension seq_len -> pred_len
+            h_ch = self.predict_linear(h_ch.permute(0, 2, 1)).permute(0, 2, 1)  # [B, pred_len, d_model]
+            # Project back to a single channel
+            y_ch = self.projection(h_ch)         # [B, pred_len, 1]
+            per_channel_seq.append(y_ch)
+
+        # Concatenate the channels
+        dec_out = torch.cat(per_channel_seq, dim=-1)  # [B, pred_len, D]
+
+        # De-normalization
+        dec_out = dec_out * stdev[:, 0, :].unsqueeze(1) + means[:, 0, :].unsqueeze(1)
+
+        return dec_out
+
+    # Keep the same forward signature as TimesNet; x_mark_enc / x_dec / x_mark_dec / mask are ignored
+    def forward(self, x_enc, x_mark_enc=None, x_dec=None, x_mark_dec=None, mask=None):
+        if self.task_name in ['long_term_forecast', 'short_term_forecast']:
+            dec_out = self.forecast(x_enc)         # [B, pred_len, D]
+            return dec_out[:, -self.pred_len:, :]  # return the forecast segment [B, pred_len, D] (a no-op slice here)
+        elif self.task_name == 'classification':
+            return self.classification(x_enc)
+        else:
+            raise NotImplementedError(f"Unsupported task: {self.task_name}")
diff --git a/run.py b/run.py
index d3d4ca1..ee47d42 100644
--- a/run.py
+++ b/run.py
@@ -7,6 +7,7 @@ from exp.exp_imputation import Exp_Imputation
 from exp.exp_short_term_forecasting import Exp_Short_Term_Forecast
 from exp.exp_anomaly_detection import Exp_Anomaly_Detection
 from exp.exp_classification import Exp_Classification
+from exp.exp_dc_patchtst_classification import Exp_DC_PatchTST_Classification
 from utils.print_args import print_args
 import random
 import numpy as np
@@ -191,7 +192,10 @@ if __name__ == '__main__':
     elif args.task_name == 'anomaly_detection':
         Exp = Exp_Anomaly_Detection
     elif args.task_name == 'classification':
-        Exp = Exp_Classification
+        if args.model == 'DC_PatchTST':
+            Exp = Exp_DC_PatchTST_Classification
+        else:
+            Exp = Exp_Classification
     else:
         Exp = Exp_Long_Term_Forecast
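For orientation before the script below: if the two target ratios compound multiplicatively across the stages (an assumption about how the dynamic-chunking targets interact, not something the script asserts), the sequence reaching the final encoder is roughly 16x shorter than the input.

# target_ratio0 = 0.25, target_ratio1 = 0.25
# stage 0: L -> ~0.25 * L;  stage 1: ~0.25 * L -> ~0.0625 * L
# e.g. a 512-step series -> ~128 tokens after stage 0 -> ~32 tokens at the main network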
diff --git a/scripts/classification/DC_PatchTST.sh b/scripts/classification/DC_PatchTST.sh
new file mode 100755
index 0000000..78d580f
--- /dev/null
+++ b/scripts/classification/DC_PatchTST.sh
@@ -0,0 +1,142 @@
+export CUDA_VISIBLE_DEVICES=0
+
+model_name=DC_PatchTST
+
+
+
+# DC_PatchTST specific parameters
+d_model_stage0=64    # Stage 0 dimension (D0)
+depth_enc0=1         # Stage 0 Mamba2 encoder depth
+depth_enc1=1         # Stage 1 Mamba2 encoder depth
+target_ratio0=0.25   # Target compression ratio for stage 0
+target_ratio1=0.25   # Target compression ratio for stage 1
+
+# EthanolConcentration dataset
+python -u run.py \
+  --task_name classification \
+  --is_training 1 \
+  --root_path ./dataset/EthanolConcentration/ \
+  --model_id EthanolConcentration \
+  --model $model_name \
+  --data UEA \
+  --e_layers 3 \
+  --batch_size 8 \
+  --d_model 128 \
+  --d_ff 256 \
+  --n_heads 8 \
+  --dropout 0.1 \
+  --activation gelu \
+  --des 'DC_PatchTST_Exp' \
+  --itr 1 \
+  --learning_rate 0.0002 \
+  --train_epochs 100 \
+  --patience 10 \
+  --d_model_stage0 $d_model_stage0 \
+  --depth_enc0 $depth_enc0 \
+  --depth_enc1 $depth_enc1 \
+  --target_ratio0 $target_ratio0 \
+  --target_ratio1 $target_ratio1
+
+# FaceDetection dataset
+python -u run.py \
+  --task_name classification \
+  --is_training 1 \
+  --root_path ./dataset/FaceDetection/ \
+  --model_id FaceDetection \
+  --model $model_name \
+  --data UEA \
+  --e_layers 3 \
+  --batch_size 8 \
+  --d_model 128 \
+  --d_ff 256 \
+  --n_heads 8 \
+  --dropout 0.1 \
+  --activation gelu \
+  --des 'DC_PatchTST_Exp' \
+  --itr 1 \
+  --learning_rate 0.001 \
+  --train_epochs 100 \
+  --patience 10 \
+  --d_model_stage0 $d_model_stage0 \
+  --depth_enc0 $depth_enc0 \
+  --depth_enc1 $depth_enc1 \
+  --target_ratio0 $target_ratio0 \
+  --target_ratio1 $target_ratio1
+
+# Handwriting dataset
+python -u run.py \
+  --task_name classification \
+  --is_training 1 \
+  --root_path ./dataset/Handwriting/ \
+  --model_id Handwriting \
+  --model $model_name \
+  --data UEA \
+  --e_layers 3 \
+  --batch_size 8 \
+  --d_model 128 \
+  --d_ff 256 \
+  --n_heads 8 \
+  --dropout 0.1 \
+  --activation gelu \
+  --des 'DC_PatchTST_Exp' \
+  --itr 1 \
+  --learning_rate 0.001 \
+  --train_epochs 100 \
+  --patience 10 \
+  --d_model_stage0 $d_model_stage0 \
+  --depth_enc0 $depth_enc0 \
+  --depth_enc1 $depth_enc1 \
+  --target_ratio0 $target_ratio0 \
+  --target_ratio1 $target_ratio1
+
+# Heartbeat dataset
+python -u run.py \
+  --task_name classification \
+  --is_training 1 \
+  --root_path ./dataset/Heartbeat/ \
+  --model_id Heartbeat \
+  --model $model_name \
+  --data UEA \
+  --e_layers 3 \
+  --batch_size 8 \
+  --d_model 128 \
+  --d_ff 256 \
+  --n_heads 8 \
+  --dropout 0.1 \
+  --activation gelu \
+  --des 'DC_PatchTST_Exp' \
+  --itr 1 \
+  --learning_rate 0.001 \
+  --train_epochs 100 \
+  --patience 10 \
+  --d_model_stage0 $d_model_stage0 \
+  --depth_enc0 $depth_enc0 \
+  --depth_enc1 $depth_enc1 \
+  --target_ratio0 $target_ratio0 \
+  --target_ratio1 $target_ratio1
+
+# JapaneseVowels dataset
+python -u run.py \
+  --task_name classification \
+  --is_training 1 \
+  --root_path ./dataset/JapaneseVowels/ \
+  --model_id JapaneseVowels \
+  --model $model_name \
+  --data UEA \
+  --e_layers 3 \
+  --batch_size 8 \
+  --d_model 128 \
+  --d_ff 256 \
+  --n_heads 8 \
+  --dropout 0.1 \
+  --activation gelu \
+  --des 'DC_PatchTST_Exp' \
+  --itr 1 \
+  --learning_rate 0.001 \
+  --train_epochs 100 \
+  --patience 10 \
+  --d_model_stage0 $d_model_stage0 \
+  --depth_enc0 $depth_enc0 \
+  --depth_enc1 $depth_enc1 \
+  --target_ratio0 $target_ratio0 \
+  --target_ratio1 $target_ratio1
\ No newline at end of file
diff --git a/scripts/classification/vanillaMamba_classification.sh b/scripts/classification/vanillaMamba_classification.sh
new file mode 100644
index 0000000..c2f61b5
--- /dev/null
+++ b/scripts/classification/vanillaMamba_classification.sh
@@ -0,0 +1,259 @@
+#!/bin/bash
+
+# vanillaMamba Classification Training Script for Multiple Datasets
+export CUDA_VISIBLE_DEVICES=0
+
+model_name=vanillaMamba
+
+# Create results directory if it doesn't exist
+mkdir -p ./results
+
+# UWaveGestureLibrary dataset (seq_len=315, enc_in=3) - use Copy1 config
+python -u run.py \
+  --task_name classification \
+  --is_training 1 \
+  --root_path ./dataset/UWaveGestureLibrary/ \
+  --model_id UWaveGestureLibrary \
+  --model $model_name \
+  --data UEA \
+  --e_layers 2 \
+  --batch_size 64 \
+  --seq_len 315 \
+  --enc_in 3 \
+  --d_model 128 \
+  --d_state 64 \
+  --d_conv 4 \
+  --expand 2 \
+  --headdim 128 \
+  --dropout 0.1 \
+  --des 'vanillaMamba_UWaveGestureLibrary' \
+  --itr 1 \
+  --learning_rate 0.002 \
+  --train_epochs 150 \
+  --patience 30 \
+  --revin 0 | tee ./results/vanillaMamba_UWaveGestureLibrary.log
+
+# EthanolConcentration dataset (seq_len=1751, enc_in=3) - use Copy1 config
+python -u run.py \
+  --task_name classification \
+  --is_training 1 \
+  --root_path ./dataset/EthanolConcentration/ \
+  --model_id EthanolConcentration \
+  --model $model_name \
+  --data UEA \
+  --e_layers 2 \
+  --batch_size 64 \
+  --seq_len 1751 \
+  --enc_in 3 \
+  --d_model 128 \
+  --d_state 64 \
+  --d_conv 4 \
+  --expand 2 \
+  --headdim 64 \
--dropout 0.1 \ + --des 'vanillaMamba_EthanolConcentration' \ + --itr 1 \ + --learning_rate 0.001 \ + --train_epochs 200 \ + --patience 30 \ + --revin 0 | tee ./results/vanillaMamba_EthanolConcentration.log + +# Handwriting dataset (seq_len=152, enc_in=3) - use Copy1 config +python -u run.py \ + --task_name classification \ + --is_training 1 \ + --root_path ./dataset/Handwriting/ \ + --model_id Handwriting \ + --model $model_name \ + --data UEA \ + --e_layers 4 \ + --batch_size 64 \ + --seq_len 152 \ + --enc_in 3 \ + --d_model 128 \ + --d_state 64 \ + --d_conv 4 \ + --expand 2 \ + --headdim 64 \ + --dropout 0.1 \ + --des 'vanillaMamba_Handwriting' \ + --itr 1 \ + --learning_rate 0.001 \ + --train_epochs 200 \ + --patience 30 \ + --revin 0 | tee ./results/vanillaMamba_Handwriting.log + +# JapaneseVowels dataset (seq_len=29, enc_in=12) - use Copy1 config +python -u run.py \ + --task_name classification \ + --is_training 1 \ + --root_path ./dataset/JapaneseVowels/ \ + --model_id JapaneseVowels \ + --model $model_name \ + --data UEA \ + --e_layers 3 \ + --batch_size 64 \ + --seq_len 29 \ + --enc_in 12 \ + --d_model 128 \ + --d_state 64 \ + --d_conv 4 \ + --expand 2 \ + --headdim 64 \ + --dropout 0.1 \ + --des 'vanillaMamba_JapaneseVowels' \ + --itr 1 \ + --learning_rate 0.0005 \ + --train_epochs 100 \ + --patience 30 \ + --revin 0 | tee ./results/vanillaMamba_JapaneseVowels.log + +# PEMS-SF dataset (seq_len=144, enc_in=963) - use Copy1 config +python -u run.py \ + --task_name classification \ + --is_training 1 \ + --root_path ./dataset/PEMS-SF/ \ + --model_id PEMS-SF \ + --model $model_name \ + --data UEA \ + --e_layers 3 \ + --batch_size 16 \ + --seq_len 144 \ + --enc_in 963 \ + --d_model 128 \ + --d_state 64 \ + --d_conv 4 \ + --expand 2 \ + --headdim 64 \ + --dropout 0.1 \ + --des 'vanillaMamba_PEMS-SF' \ + --itr 1 \ + --learning_rate 0.001 \ + --train_epochs 150 \ + --patience 30 \ + --revin 0 | tee ./results/vanillaMamba_PEMS-SF.log + +# Heartbeat dataset (seq_len=405, enc_in=61) - use original config +python -u run.py \ + --task_name classification \ + --is_training 1 \ + --root_path ./dataset/Heartbeat/ \ + --model_id Heartbeat \ + --model $model_name \ + --data UEA \ + --e_layers 3 \ + --batch_size 64 \ + --seq_len 405 \ + --enc_in 61 \ + --d_model 128 \ + --d_state 64 \ + --d_conv 4 \ + --expand 2 \ + --headdim 64 \ + --dropout 0.1 \ + --des 'vanillaMamba_Heartbeat' \ + --itr 1 \ + --learning_rate 0.001 \ + --train_epochs 150 \ + --patience 10 \ + --revin 0 | tee ./results/vanillaMamba_Heartbeat.log + +# FaceDetection dataset (seq_len=62, enc_in=144) - use original config +python -u run.py \ + --task_name classification \ + --is_training 1 \ + --root_path ./dataset/FaceDetection/ \ + --model_id FaceDetection \ + --model $model_name \ + --data UEA \ + --e_layers 3 \ + --batch_size 64 \ + --seq_len 62 \ + --enc_in 144 \ + --d_model 128 \ + --d_state 64 \ + --d_conv 4 \ + --expand 2 \ + --headdim 64 \ + --dropout 0.1 \ + --des 'vanillaMamba_FaceDetection' \ + --itr 1 \ + --learning_rate 0.001 \ + --train_epochs 100 \ + --patience 10 \ + --revin 0 | tee ./results/vanillaMamba_FaceDetection.log + +# SelfRegulationSCP1 dataset (seq_len=896, enc_in=6) - use original config +python -u run.py \ + --task_name classification \ + --is_training 1 \ + --root_path ./dataset/SelfRegulationSCP1/ \ + --model_id SelfRegulationSCP1 \ + --model $model_name \ + --data UEA \ + --e_layers 3 \ + --batch_size 64 \ + --seq_len 896 \ + --enc_in 6 \ + --d_model 128 \ + --d_state 64 \ + --d_conv 4 \ + --expand 2 \ 
+ --headdim 64 \ + --dropout 0.1 \ + --des 'vanillaMamba_SelfRegulationSCP1' \ + --itr 1 \ + --learning_rate 0.001 \ + --train_epochs 100 \ + --patience 10 \ + --revin 0 | tee ./results/vanillaMamba_SelfRegulationSCP1.log + +# SelfRegulationSCP2 dataset (seq_len=1152, enc_in=7) - use original config +python -u run.py \ + --task_name classification \ + --is_training 1 \ + --root_path ./dataset/SelfRegulationSCP2/ \ + --model_id SelfRegulationSCP2 \ + --model $model_name \ + --data UEA \ + --e_layers 3 \ + --batch_size 64 \ + --seq_len 1152 \ + --enc_in 7 \ + --d_model 128 \ + --d_state 64 \ + --d_conv 4 \ + --expand 2 \ + --headdim 64 \ + --dropout 0.1 \ + --des 'vanillaMamba_SelfRegulationSCP2' \ + --itr 1 \ + --learning_rate 0.001 \ + --train_epochs 100 \ + --patience 10 \ + --revin 0 | tee ./results/vanillaMamba_SelfRegulationSCP2.log + +# SpokenArabicDigits dataset (seq_len=93, enc_in=13) - use original config +python -u run.py \ + --task_name classification \ + --is_training 1 \ + --root_path ./dataset/SpokenArabicDigits/ \ + --model_id SpokenArabicDigits \ + --model $model_name \ + --data UEA \ + --e_layers 3 \ + --batch_size 64 \ + --seq_len 93 \ + --enc_in 13 \ + --d_model 128 \ + --d_state 64 \ + --d_conv 4 \ + --expand 2 \ + --headdim 64 \ + --dropout 0.1 \ + --des 'vanillaMamba_SpokenArabicDigits' \ + --itr 1 \ + --learning_rate 0.001 \ + --train_epochs 100 \ + --patience 10 \ + --revin 0 | tee ./results/vanillaMamba_SpokenArabicDigits.log \ No newline at end of file diff --git a/scripts/classification/xPatch_SparseChannel-Copy1.sh b/scripts/classification/xPatch_SparseChannel-Copy1.sh new file mode 100644 index 0000000..3605560 --- /dev/null +++ b/scripts/classification/xPatch_SparseChannel-Copy1.sh @@ -0,0 +1,145 @@ +#!/bin/bash + +# xPatch_SparseChannel Classification Training Script for Multiple Datasets +export CUDA_VISIBLE_DEVICES=0 + +model_name=xPatch_SparseChannel + +# Create results directory if it doesn't exist +mkdir -p ./results + + + +# UWaveGestureLibrary dataset (seq_len=315, enc_in=3, k_graph=3) +python -u run.py \ + --task_name classification \ + --is_training 1 \ + --root_path ./dataset/UWaveGestureLibrary/ \ + --model_id UWaveGestureLibrary \ + --model $model_name \ + --data UEA \ + --e_layers 2 \ + --batch_size 64 \ + --seq_len 315 \ + --enc_in 3 \ + --d_model 128 \ + --d_ff 256 \ + --n_heads 16 \ + --patch_len 16 \ + --stride 8 \ + --dropout 0.1 \ + --des 'xPatch_SparseChannel_UWaveGestureLibrary' \ + --itr 1 \ + --learning_rate 0.001 \ + --train_epochs 100 \ + --patience 30 \ + --revin 0 \ + --k_graph 3 | tee ./results/xPatch_SparseChannel_UWaveGestureLibrary.log + + + +# EthanolConcentration dataset (seq_len=1751, enc_in=3, k_graph=3) +python -u run.py \ + --task_name classification \ + --is_training 1 \ + --root_path ./dataset/EthanolConcentration/ \ + --model_id EthanolConcentration \ + --model $model_name \ + --data UEA \ + --e_layers 2 \ + --batch_size 64 \ + --seq_len 1751 \ + --enc_in 3 \ + --d_model 128 \ + --d_ff 256 \ + --n_heads 16 \ + --patch_len 16 \ + --stride 8 \ + --dropout 0.1 \ + --des 'xPatch_SparseChannel_EthanolConcentration' \ + --itr 1 \ + --learning_rate 0.0005 \ + --train_epochs 100 \ + --patience 30 \ + --revin 0 \ + --k_graph 3 | tee ./results/xPatch_SparseChannel_EthanolConcentration.log + +# Handwriting dataset (seq_len=152, enc_in=3, k_graph=3) +python -u run.py \ + --task_name classification \ + --is_training 1 \ + --root_path ./dataset/Handwriting/ \ + --model_id Handwriting \ + --model $model_name \ + --data UEA \ + 
--e_layers 2 \ + --batch_size 64 \ + --seq_len 152 \ + --enc_in 3 \ + --d_model 128 \ + --d_ff 256 \ + --n_heads 16 \ + --patch_len 16 \ + --stride 8 \ + --dropout 0.1 \ + --des 'xPatch_SparseChannel_Handwriting' \ + --itr 1 \ + --learning_rate 0.001 \ + --train_epochs 100 \ + --patience 30 \ + --revin 0 \ + --k_graph 3 | tee ./results/xPatch_SparseChannel_Handwriting.log + +# JapaneseVowels dataset (seq_len=29, enc_in=12, k_graph=8) +python -u run.py \ + --task_name classification \ + --is_training 1 \ + --root_path ./dataset/JapaneseVowels/ \ + --model_id JapaneseVowels \ + --model $model_name \ + --data UEA \ + --e_layers 2 \ + --batch_size 64 \ + --seq_len 29 \ + --enc_in 12 \ + --d_model 128 \ + --d_ff 256 \ + --n_heads 16 \ + --patch_len 16 \ + --stride 8 \ + --dropout 0.1 \ + --des 'xPatch_SparseChannel_JapaneseVowels' \ + --itr 1 \ + --learning_rate 0.0005 \ + --train_epochs 100 \ + --patience 30 \ + --revin 0 \ + --k_graph 8 | tee ./results/xPatch_SparseChannel_JapaneseVowels.log + + + +# PEMS-SF dataset (seq_len=144, enc_in=963, k_graph=8) +python -u run.py \ + --task_name classification \ + --is_training 1 \ + --root_path ./dataset/PEMS-SF/ \ + --model_id PEMS-SF \ + --model $model_name \ + --data UEA \ + --e_layers 2 \ + --batch_size 16 \ + --seq_len 144 \ + --enc_in 963 \ + --d_model 128 \ + --d_ff 256 \ + --n_heads 16 \ + --patch_len 16 \ + --stride 8 \ + --dropout 0.1 \ + --des 'xPatch_SparseChannel_PEMS-SF' \ + --itr 1 \ + --learning_rate 0.0005 \ + --train_epochs 100 \ + --patience 30 \ + --revin 0 \ + --k_graph 8 | tee ./results/xPatch_SparseChannel_PEMS-SF.log diff --git a/scripts/classification/xPatch_SparseChannel.sh b/scripts/classification/xPatch_SparseChannel.sh index 544a4a5..e7d4006 100644 --- a/scripts/classification/xPatch_SparseChannel.sh +++ b/scripts/classification/xPatch_SparseChannel.sh @@ -1,10 +1,66 @@ #!/bin/bash -# xPatch_SparseChannel Classification Training Script for FaceDetection Dataset +# xPatch_SparseChannel Classification Training Script for Multiple Datasets export CUDA_VISIBLE_DEVICES=0 model_name=xPatch_SparseChannel +# Create results directory if it doesn't exist +mkdir -p ./results + +# Heartbeat dataset (seq_len=405, enc_in=61, k_graph=8) +python -u run.py \ + --task_name classification \ + --is_training 1 \ + --root_path ./dataset/Heartbeat/ \ + --model_id Heartbeat \ + --model $model_name \ + --data UEA \ + --e_layers 2 \ + --batch_size 64 \ + --seq_len 405 \ + --enc_in 61 \ + --d_model 128 \ + --d_ff 256 \ + --n_heads 16 \ + --patch_len 16 \ + --stride 8 \ + --dropout 0.1 \ + --des 'xPatch_SparseChannel_Heartbeat' \ + --itr 1 \ + --learning_rate 0.0005 \ + --train_epochs 100 \ + --patience 5 \ + --revin 0 \ + --k_graph 8 | tee ./results/xPatch_SparseChannel_Heartbeat.log + +# UWaveGestureLibrary dataset (seq_len=315, enc_in=3, k_graph=3) +python -u run.py \ + --task_name classification \ + --is_training 1 \ + --root_path ./dataset/UWaveGestureLibrary/ \ + --model_id UWaveGestureLibrary \ + --model $model_name \ + --data UEA \ + --e_layers 2 \ + --batch_size 64 \ + --seq_len 315 \ + --enc_in 3 \ + --d_model 128 \ + --d_ff 256 \ + --n_heads 16 \ + --patch_len 16 \ + --stride 8 \ + --dropout 0.1 \ + --des 'xPatch_SparseChannel_UWaveGestureLibrary' \ + --itr 1 \ + --learning_rate 0.001 \ + --train_epochs 100 \ + --patience 30 \ + --revin 0 \ + --k_graph 3 | tee ./results/xPatch_SparseChannel_UWaveGestureLibrary.log + +# FaceDetection dataset (seq_len=62, enc_in=144, k_graph=8) python -u run.py \ --task_name classification \ --is_training 1 
\ @@ -12,21 +68,202 @@ python -u run.py \ --model_id FaceDetection \ --model $model_name \ --data UEA \ - --e_layers 3 \ + --e_layers 2 \ --batch_size 64 \ --seq_len 62 \ --enc_in 144 \ --d_model 128 \ --d_ff 256 \ - --n_heads 8 \ + --n_heads 16 \ --patch_len 16 \ --stride 8 \ - --moving_avg 25 \ --dropout 0.1 \ --des 'xPatch_SparseChannel_FaceDetection' \ --itr 1 \ --learning_rate 0.0005 \ --train_epochs 100 \ --patience 5 \ - --revin 1 \ - --k_graph 8 \ No newline at end of file + --revin 0 \ + --k_graph 8 | tee ./results/xPatch_SparseChannel_FaceDetection.log + +# EthanolConcentration dataset (seq_len=1751, enc_in=3, k_graph=3) +python -u run.py \ + --task_name classification \ + --is_training 1 \ + --root_path ./dataset/EthanolConcentration/ \ + --model_id EthanolConcentration \ + --model $model_name \ + --data UEA \ + --e_layers 2 \ + --batch_size 64 \ + --seq_len 1751 \ + --enc_in 3 \ + --d_model 128 \ + --d_ff 256 \ + --n_heads 16 \ + --patch_len 16 \ + --stride 8 \ + --dropout 0.1 \ + --des 'xPatch_SparseChannel_EthanolConcentration' \ + --itr 1 \ + --learning_rate 0.0005 \ + --train_epochs 100 \ + --patience 30 \ + --revin 0 \ + --k_graph 3 | tee ./results/xPatch_SparseChannel_EthanolConcentration.log + +# Handwriting dataset (seq_len=152, enc_in=3, k_graph=3) +python -u run.py \ + --task_name classification \ + --is_training 1 \ + --root_path ./dataset/Handwriting/ \ + --model_id Handwriting \ + --model $model_name \ + --data UEA \ + --e_layers 2 \ + --batch_size 64 \ + --seq_len 152 \ + --enc_in 3 \ + --d_model 128 \ + --d_ff 256 \ + --n_heads 16 \ + --patch_len 16 \ + --stride 8 \ + --dropout 0.1 \ + --des 'xPatch_SparseChannel_Handwriting' \ + --itr 1 \ + --learning_rate 0.001 \ + --train_epochs 100 \ + --patience 30 \ + --revin 0 \ + --k_graph 3 | tee ./results/xPatch_SparseChannel_Handwriting.log + +# JapaneseVowels dataset (seq_len=29, enc_in=12, k_graph=8) +python -u run.py \ + --task_name classification \ + --is_training 1 \ + --root_path ./dataset/JapaneseVowels/ \ + --model_id JapaneseVowels \ + --model $model_name \ + --data UEA \ + --e_layers 2 \ + --batch_size 64 \ + --seq_len 29 \ + --enc_in 12 \ + --d_model 128 \ + --d_ff 256 \ + --n_heads 16 \ + --patch_len 16 \ + --stride 8 \ + --dropout 0.1 \ + --des 'xPatch_SparseChannel_JapaneseVowels' \ + --itr 1 \ + --learning_rate 0.0005 \ + --train_epochs 100 \ + --patience 30 \ + --revin 0 \ + --k_graph 8 | tee ./results/xPatch_SparseChannel_JapaneseVowels.log + +# SelfRegulationSCP1 dataset (seq_len=896, enc_in=6, k_graph=6) +python -u run.py \ + --task_name classification \ + --is_training 1 \ + --root_path ./dataset/SelfRegulationSCP1/ \ + --model_id SelfRegulationSCP1 \ + --model $model_name \ + --data UEA \ + --e_layers 2 \ + --batch_size 64 \ + --seq_len 896 \ + --enc_in 6 \ + --d_model 128 \ + --d_ff 256 \ + --n_heads 16 \ + --patch_len 16 \ + --stride 8 \ + --dropout 0.1 \ + --des 'xPatch_SparseChannel_SelfRegulationSCP1' \ + --itr 1 \ + --learning_rate 0.0005 \ + --train_epochs 100 \ + --patience 5 \ + --revin 0 \ + --k_graph 6 | tee ./results/xPatch_SparseChannel_SelfRegulationSCP1.log + +# SelfRegulationSCP2 dataset (seq_len=1152, enc_in=7, k_graph=7) +python -u run.py \ + --task_name classification \ + --is_training 1 \ + --root_path ./dataset/SelfRegulationSCP2/ \ + --model_id SelfRegulationSCP2 \ + --model $model_name \ + --data UEA \ + --e_layers 2 \ + --batch_size 64 \ + --seq_len 1152 \ + --enc_in 7 \ + --d_model 128 \ + --d_ff 256 \ + --n_heads 16 \ + --patch_len 16 \ + --stride 8 \ + --dropout 0.1 \ + 
--des 'xPatch_SparseChannel_SelfRegulationSCP2' \ + --itr 1 \ + --learning_rate 0.0005 \ + --train_epochs 100 \ + --patience 5 \ + --revin 0 \ + --k_graph 7 | tee ./results/xPatch_SparseChannel_SelfRegulationSCP2.log + +# SpokenArabicDigits dataset (seq_len=93, enc_in=13, k_graph=8) +python -u run.py \ + --task_name classification \ + --is_training 1 \ + --root_path ./dataset/SpokenArabicDigits/ \ + --model_id SpokenArabicDigits \ + --model $model_name \ + --data UEA \ + --e_layers 2 \ + --batch_size 64 \ + --seq_len 93 \ + --enc_in 13 \ + --d_model 128 \ + --d_ff 256 \ + --n_heads 16 \ + --patch_len 16 \ + --stride 8 \ + --dropout 0.1 \ + --des 'xPatch_SparseChannel_SpokenArabicDigits' \ + --itr 1 \ + --learning_rate 0.0005 \ + --train_epochs 100 \ + --patience 5 \ + --revin 0 \ + --k_graph 8 | tee ./results/xPatch_SparseChannel_SpokenArabicDigits.log + +# PEMS-SF dataset (seq_len=144, enc_in=963, k_graph=8) +python -u run.py \ + --task_name classification \ + --is_training 1 \ + --root_path ./dataset/PEMS-SF/ \ + --model_id PEMS-SF \ + --model $model_name \ + --data UEA \ + --e_layers 2 \ + --batch_size 16 \ + --seq_len 144 \ + --enc_in 963 \ + --d_model 128 \ + --d_ff 256 \ + --n_heads 16 \ + --patch_len 16 \ + --stride 8 \ + --dropout 0.1 \ + --des 'xPatch_SparseChannel_PEMS-SF' \ + --itr 1 \ + --learning_rate 0.0005 \ + --train_epochs 100 \ + --patience 30 \ + --revin 0 \ + --k_graph 8 | tee ./results/xPatch_SparseChannel_PEMS-SF.log diff --git a/scripts/long_term_forecast/vanillaMamba_all.sh b/scripts/long_term_forecast/vanillaMamba_all.sh new file mode 100644 index 0000000..d07e474 --- /dev/null +++ b/scripts/long_term_forecast/vanillaMamba_all.sh @@ -0,0 +1,251 @@ +#!/bin/bash + +model_name=vanillaMamba + +# ETTm1 dataset +for pred_len in 96 192 336 720 +do +python -u run.py \ + --task_name long_term_forecast \ + --is_training 1 \ + --root_path ./dataset/ETT-small/ \ + --data_path ETTm1.csv \ + --model_id ETTm1_$pred_len'_'$pred_len \ + --model $model_name \ + --data ETTm1 \ + --features M \ + --seq_len 96 \ + --label_len 48 \ + --pred_len $pred_len \ + --e_layers 2 \ + --d_layers 1 \ + --enc_in 7 \ + --c_out 7 \ + --d_model 128 \ + --expand 2 \ + --d_conv 4 \ + --d_state 64 \ + --headdim 64 \ + --ngroups 1 \ + --chunk_size 256 \ + --dropout 0.1 \ + --des 'Exp' \ + --itr 1 +done + +# ETTm2 dataset +for pred_len in 96 192 336 720 +do +python -u run.py \ + --task_name long_term_forecast \ + --is_training 1 \ + --root_path ./dataset/ETT-small/ \ + --data_path ETTm2.csv \ + --model_id ETTm2_$pred_len'_'$pred_len \ + --model $model_name \ + --data ETTm2 \ + --features M \ + --seq_len 96 \ + --label_len 48 \ + --pred_len $pred_len \ + --e_layers 2 \ + --d_layers 1 \ + --enc_in 7 \ + --c_out 7 \ + --d_model 128 \ + --expand 2 \ + --d_conv 4 \ + --d_state 64 \ + --headdim 64 \ + --ngroups 1 \ + --chunk_size 256 \ + --dropout 0.1 \ + --des 'Exp' \ + --itr 1 +done + +# ETTh1 dataset +for pred_len in 96 192 336 720 +do +python -u run.py \ + --task_name long_term_forecast \ + --is_training 1 \ + --root_path ./dataset/ETT-small/ \ + --data_path ETTh1.csv \ + --model_id ETTh1_$pred_len'_'$pred_len \ + --model $model_name \ + --data ETTh1 \ + --features M \ + --seq_len 96 \ + --label_len 48 \ + --pred_len $pred_len \ + --e_layers 2 \ + --d_layers 1 \ + --enc_in 7 \ + --c_out 7 \ + --d_model 128 \ + --expand 2 \ + --d_conv 4 \ + --d_state 64 \ + --headdim 64 \ + --ngroups 1 \ + --chunk_size 256 \ + --dropout 0.1 \ + --des 'Exp' \ + --itr 1 +done + +# ETTh2 dataset +for pred_len in 96 192 336 720 
+do +python -u run.py \ + --task_name long_term_forecast \ + --is_training 1 \ + --root_path ./dataset/ETT-small/ \ + --data_path ETTh2.csv \ + --model_id ETTh2_$pred_len'_'$pred_len \ + --model $model_name \ + --data ETTh2 \ + --features M \ + --seq_len 96 \ + --label_len 48 \ + --pred_len $pred_len \ + --e_layers 2 \ + --d_layers 1 \ + --enc_in 7 \ + --c_out 7 \ + --d_model 128 \ + --expand 2 \ + --d_conv 4 \ + --d_state 64 \ + --headdim 64 \ + --ngroups 1 \ + --chunk_size 256 \ + --dropout 0.1 \ + --des 'Exp' \ + --itr 1 +done + +# Weather dataset +for pred_len in 96 192 336 720 +do +python -u run.py \ + --task_name long_term_forecast \ + --is_training 1 \ + --root_path ./dataset/weather/ \ + --data_path weather.csv \ + --model_id weather_$pred_len'_'$pred_len \ + --model $model_name \ + --data custom \ + --features M \ + --seq_len 96 \ + --label_len 48 \ + --pred_len $pred_len \ + --e_layers 2 \ + --d_layers 1 \ + --enc_in 21 \ + --c_out 21 \ + --d_model 128 \ + --expand 2 \ + --d_conv 4 \ + --d_state 64 \ + --headdim 64 \ + --ngroups 1 \ + --chunk_size 256 \ + --dropout 0.1 \ + --des 'Exp' \ + --itr 1 +done + +# ECL dataset +for pred_len in 96 192 336 720 +do +python -u run.py \ + --task_name long_term_forecast \ + --is_training 1 \ + --root_path ./dataset/electricity/ \ + --data_path electricity.csv \ + --model_id ECL_$pred_len'_'$pred_len \ + --model $model_name \ + --data custom \ + --features M \ + --seq_len 96 \ + --label_len 48 \ + --pred_len $pred_len \ + --e_layers 2 \ + --d_layers 1 \ + --enc_in 321 \ + --c_out 321 \ + --d_model 128 \ + --expand 2 \ + --d_conv 4 \ + --d_state 64 \ + --headdim 64 \ + --ngroups 1 \ + --chunk_size 256 \ + --dropout 0.1 \ + --des 'Exp' \ + --itr 1 +done + +# Traffic dataset +for pred_len in 96 192 336 720 +do +python -u run.py \ + --task_name long_term_forecast \ + --is_training 1 \ + --root_path ./dataset/traffic/ \ + --data_path traffic.csv \ + --model_id traffic_$pred_len'_'$pred_len \ + --model $model_name \ + --data custom \ + --features M \ + --seq_len 96 \ + --label_len 48 \ + --pred_len $pred_len \ + --e_layers 2 \ + --d_layers 1 \ + --enc_in 862 \ + --c_out 862 \ + --d_model 128 \ + --expand 2 \ + --d_conv 4 \ + --d_state 64 \ + --headdim 64 \ + --ngroups 1 \ + --chunk_size 256 \ + --dropout 0.1 \ + --des 'Exp' \ + --itr 1 +done + +# Exchange dataset +for pred_len in 96 192 336 720 +do +python -u run.py \ + --task_name long_term_forecast \ + --is_training 1 \ + --root_path ./dataset/exchange_rate/ \ + --data_path exchange_rate.csv \ + --model_id Exchange_$pred_len'_'$pred_len \ + --model $model_name \ + --data custom \ + --features M \ + --seq_len 96 \ + --label_len 48 \ + --pred_len $pred_len \ + --e_layers 2 \ + --d_layers 1 \ + --enc_in 8 \ + --c_out 8 \ + --d_model 128 \ + --expand 2 \ + --d_conv 4 \ + --d_state 64 \ + --headdim 64 \ + --ngroups 1 \ + --chunk_size 256 \ + --dropout 0.1 \ + --des 'Exp' \ + --itr 1 +done \ No newline at end of file diff --git a/scripts/long_term_forecast/xPatch_SparseChannel_PEMS.sh b/scripts/long_term_forecast/xPatch_SparseChannel_PEMS.sh new file mode 100644 index 0000000..942f123 --- /dev/null +++ b/scripts/long_term_forecast/xPatch_SparseChannel_PEMS.sh @@ -0,0 +1,150 @@ +#!/bin/bash + +#export CUDA_VISIBLE_DEVICES=0 + +model_name=xPatch_SparseChannel + +seq_len=96 +pred_len=12 +learning_rate=0.003 +d_model=128 +d_ff=256 +batch_size=128 +train_epochs=10 +patience=10 + +# PEMS03 dataset +python -u run.py \ + --task_name long_term_forecast \ + --is_training 1 \ + --root_path ./dataset/PEMS/ \ + 
--data_path PEMS03.npz \ + --model_id PEMS03 \ + --model $model_name \ + --data PEMS \ + --features M \ + --seq_len $seq_len \ + --label_len 0 \ + --pred_len $pred_len \ + --e_layers 2 \ + --d_layers 1 \ + --enc_in 358 \ + --dec_in 358 \ + --c_out 358 \ + --lradj 'sigmoid' \ + --d_model $d_model \ + --d_ff $d_ff \ + --n_heads 16 \ + --patch_len 16 \ + --stride 8 \ + --k_graph 8 \ + --dropout 0.1 \ + --revin 1 \ + --batch_size $batch_size \ + --learning_rate $learning_rate \ + --train_epochs $train_epochs \ + --patience $patience \ + --des 'Exp' \ + --itr 1 + +# PEMS04 dataset +python -u run.py \ + --task_name long_term_forecast \ + --is_training 1 \ + --root_path ./dataset/PEMS/ \ + --data_path PEMS04.npz \ + --model_id PEMS04 \ + --model $model_name \ + --data PEMS \ + --features M \ + --seq_len $seq_len \ + --label_len 0 \ + --pred_len $pred_len \ + --e_layers 2 \ + --d_layers 1 \ + --enc_in 307 \ + --dec_in 307 \ + --c_out 307 \ + --lradj 'sigmoid' \ + --d_model $d_model \ + --d_ff $d_ff \ + --n_heads 16 \ + --patch_len 16 \ + --stride 8 \ + --k_graph 8 \ + --dropout 0.1 \ + --revin 1 \ + --batch_size $batch_size \ + --learning_rate $learning_rate \ + --train_epochs $train_epochs \ + --patience $patience \ + --des 'Exp' \ + --itr 1 + +# PEMS07 dataset +python -u run.py \ + --task_name long_term_forecast \ + --is_training 1 \ + --root_path ./dataset/PEMS/ \ + --data_path PEMS07.npz \ + --model_id PEMS07 \ + --model $model_name \ + --data PEMS \ + --features M \ + --seq_len $seq_len \ + --label_len 0 \ + --pred_len $pred_len \ + --e_layers 2 \ + --d_layers 1 \ + --enc_in 883 \ + --dec_in 883 \ + --c_out 883 \ + --lradj 'sigmoid' \ + --d_model $d_model \ + --d_ff $d_ff \ + --n_heads 16 \ + --patch_len 16 \ + --stride 8 \ + --k_graph 8 \ + --dropout 0.1 \ + --revin 1 \ + --batch_size $batch_size \ + --learning_rate $learning_rate \ + --train_epochs $train_epochs \ + --patience $patience \ + --des 'Exp' \ + --itr 1 + +# PEMS08 dataset +python -u run.py \ + --task_name long_term_forecast \ + --is_training 1 \ + --root_path ./dataset/PEMS/ \ + --data_path PEMS08.npz \ + --model_id PEMS08 \ + --model $model_name \ + --data PEMS \ + --features M \ + --seq_len $seq_len \ + --label_len 0 \ + --pred_len $pred_len \ + --e_layers 2 \ + --d_layers 1 \ + --enc_in 170 \ + --dec_in 170 \ + --c_out 170 \ + --lradj 'sigmoid' \ + --d_model $d_model \ + --d_ff $d_ff \ + --n_heads 16 \ + --patch_len 16 \ + --stride 8 \ + --k_graph 8 \ + --dropout 0.1 \ + --revin 1 \ + --batch_size $batch_size \ + --learning_rate $learning_rate \ + --train_epochs $train_epochs \ + --patience $patience \ + --des 'Exp' \ + --itr 1 \ No newline at end of file diff --git a/scripts/long_term_forecast/xPatch_SparseChannel_all-Copy1.sh b/scripts/long_term_forecast/xPatch_SparseChannel_all-Copy1.sh new file mode 100644 index 0000000..ae2c04d --- /dev/null +++ b/scripts/long_term_forecast/xPatch_SparseChannel_all-Copy1.sh @@ -0,0 +1,251 @@ +#!/bin/bash + +model_name=xPatch_SparseChannel + +# ETTm1 dataset +for pred_len in 96 192 336 720 +do +python -u run.py \ + --task_name long_term_forecast \ + --is_training 1 \ + --root_path ./dataset/ETT-small/ \ + --data_path ETTm1.csv \ + --model_id ETTm1_$pred_len'_'$pred_len \ + --model $model_name \ + --data ETTm1 \ + --features M \ + --seq_len 96 \ + --label_len 48 \ + --pred_len $pred_len \ + --e_layers 2 \ + --d_layers 1 \ + --enc_in 7 \ + --c_out 7 \ + --d_model 128 \ + --d_ff 256 \ + --n_heads 16 \ + --patch_len 16 \ + --stride 8 \ + --k_graph 7 \ + --dropout 0.1 \ + --revin 1 \ + --des 
'Exp' \ + --itr 1 +done + +# ETTm2 dataset +for pred_len in 96 192 336 720 +do +python -u run.py \ + --task_name long_term_forecast \ + --is_training 1 \ + --root_path ./dataset/ETT-small/ \ + --data_path ETTm2.csv \ + --model_id ETTm2_$pred_len'_'$pred_len \ + --model $model_name \ + --data ETTm2 \ + --features M \ + --seq_len 96 \ + --label_len 48 \ + --pred_len $pred_len \ + --e_layers 2 \ + --d_layers 1 \ + --enc_in 7 \ + --c_out 7 \ + --d_model 128 \ + --d_ff 256 \ + --n_heads 16 \ + --patch_len 16 \ + --stride 8 \ + --k_graph 7 \ + --dropout 0.1 \ + --revin 1 \ + --des 'Exp' \ + --itr 1 +done + +# ETTh1 dataset +for pred_len in 96 192 336 720 +do +python -u run.py \ + --task_name long_term_forecast \ + --is_training 1 \ + --root_path ./dataset/ETT-small/ \ + --data_path ETTh1.csv \ + --model_id ETTh1_$pred_len'_'$pred_len \ + --model $model_name \ + --data ETTh1 \ + --features M \ + --seq_len 96 \ + --label_len 48 \ + --pred_len $pred_len \ + --e_layers 2 \ + --d_layers 1 \ + --enc_in 7 \ + --c_out 7 \ + --d_model 128 \ + --d_ff 256 \ + --n_heads 16 \ + --patch_len 16 \ + --stride 8 \ + --k_graph 7 \ + --dropout 0.1 \ + --revin 1 \ + --des 'Exp' \ + --itr 1 +done + +# ETTh2 dataset +for pred_len in 96 192 336 720 +do +python -u run.py \ + --task_name long_term_forecast \ + --is_training 1 \ + --root_path ./dataset/ETT-small/ \ + --data_path ETTh2.csv \ + --model_id ETTh2_$pred_len'_'$pred_len \ + --model $model_name \ + --data ETTh2 \ + --features M \ + --seq_len 96 \ + --label_len 48 \ + --pred_len $pred_len \ + --e_layers 2 \ + --d_layers 1 \ + --enc_in 7 \ + --c_out 7 \ + --d_model 128 \ + --d_ff 256 \ + --n_heads 16 \ + --patch_len 16 \ + --stride 8 \ + --k_graph 7 \ + --dropout 0.1 \ + --revin 1 \ + --des 'Exp' \ + --itr 1 +done + +# Weather dataset +for pred_len in 96 192 336 720 +do +python -u run.py \ + --task_name long_term_forecast \ + --is_training 1 \ + --root_path ./dataset/weather/ \ + --data_path weather.csv \ + --model_id weather_$pred_len'_'$pred_len \ + --model $model_name \ + --data custom \ + --features M \ + --seq_len 96 \ + --label_len 48 \ + --pred_len $pred_len \ + --e_layers 2 \ + --d_layers 1 \ + --enc_in 21 \ + --c_out 21 \ + --d_model 128 \ + --d_ff 256 \ + --n_heads 16 \ + --patch_len 16 \ + --stride 8 \ + --k_graph 8 \ + --dropout 0.1 \ + --revin 1 \ + --des 'Exp' \ + --itr 1 +done + +# ECL dataset +for pred_len in 96 192 336 720 +do +python -u run.py \ + --task_name long_term_forecast \ + --is_training 1 \ + --root_path ./dataset/electricity/ \ + --data_path electricity.csv \ + --model_id ECL_$pred_len'_'$pred_len \ + --model $model_name \ + --data custom \ + --features M \ + --seq_len 96 \ + --label_len 48 \ + --pred_len $pred_len \ + --e_layers 2 \ + --d_layers 1 \ + --enc_in 321 \ + --c_out 321 \ + --d_model 128 \ + --d_ff 256 \ + --n_heads 16 \ + --patch_len 16 \ + --stride 8 \ + --k_graph 8 \ + --dropout 0.1 \ + --revin 1 \ + --des 'Exp' \ + --itr 1 +done + +# Traffic dataset +for pred_len in 96 192 336 720 +do +python -u run.py \ + --task_name long_term_forecast \ + --is_training 1 \ + --root_path ./dataset/traffic/ \ + --data_path traffic.csv \ + --model_id traffic_$pred_len'_'$pred_len \ + --model $model_name \ + --data custom \ + --features M \ + --seq_len 96 \ + --label_len 48 \ + --pred_len $pred_len \ + --e_layers 2 \ + --d_layers 1 \ + --enc_in 862 \ + --c_out 862 \ + --d_model 128 \ + --d_ff 256 \ + --n_heads 16 \ + --patch_len 16 \ + --stride 8 \ + --k_graph 8 \ + --dropout 0.1 \ + --revin 1 \ + --des 'Exp' \ + --itr 1 +done + +# Exchange 
dataset +for pred_len in 96 192 336 720 +do +python -u run.py \ + --task_name long_term_forecast \ + --is_training 1 \ + --root_path ./dataset/exchange_rate/ \ + --data_path exchange_rate.csv \ + --model_id Exchange_$pred_len'_'$pred_len \ + --model $model_name \ + --data custom \ + --features M \ + --seq_len 96 \ + --label_len 48 \ + --pred_len $pred_len \ + --e_layers 2 \ + --d_layers 1 \ + --enc_in 8 \ + --c_out 8 \ + --d_model 128 \ + --d_ff 256 \ + --n_heads 16 \ + --patch_len 16 \ + --stride 8 \ + --k_graph 8 \ + --dropout 0.1 \ + --revin 1 \ + --des 'Exp' \ + --itr 1 +done \ No newline at end of file diff --git a/scripts/long_term_forecast/xPatch_SparseChannel_all.sh b/scripts/long_term_forecast/xPatch_SparseChannel_all.sh new file mode 100644 index 0000000..f870afe --- /dev/null +++ b/scripts/long_term_forecast/xPatch_SparseChannel_all.sh @@ -0,0 +1,77 @@ +#!/bin/bash + +model_name=xPatch_SparseChannel + +# ECL dataset +for pred_len in 96 192 336 720 +do +python -u run.py \ + --task_name long_term_forecast \ + --is_training 1 \ + --root_path ./dataset/electricity/ \ + --data_path electricity.csv \ + --model_id ECL_$pred_len'_'$pred_len \ + --model $model_name \ + --data custom \ + --features M \ + --seq_len 96 \ + --label_len 48 \ + --pred_len $pred_len \ + --e_layers 2 \ + --d_layers 1 \ + --enc_in 321 \ + --batch_size 512 \ + --learning_rate 0.001 \ + --use_multi_gpu True \ + --lradj 'sigmoid' \ + --train_epochs 50 \ + --patience 5 \ + --c_out 321 \ + --d_model 128 \ + --d_ff 256 \ + --n_heads 16 \ + --patch_len 16 \ + --stride 8 \ + --k_graph 8 \ + --dropout 0.1 \ + --revin 1 \ + --des 'Exp' \ + --itr 1 +done + +# Traffic dataset +for pred_len in 96 192 336 720 +do +python -u run.py \ + --task_name long_term_forecast \ + --is_training 1 \ + --root_path ./dataset/traffic/ \ + --data_path traffic.csv \ + --model_id traffic_$pred_len'_'$pred_len \ + --model $model_name \ + --data custom \ + --features M \ + --seq_len 96 \ + --label_len 48 \ + --pred_len $pred_len \ + --e_layers 2 \ + --d_layers 1 \ + --enc_in 862 \ + --batch_size 256 \ + --learning_rate 0.0005 \ + --use_multi_gpu True \ + --lradj 'sigmoid' \ + --train_epochs 50 \ + --patience 5 \ + --c_out 862 \ + --d_model 128 \ + --d_ff 256 \ + --n_heads 16 \ + --patch_len 16 \ + --stride 8 \ + --k_graph 8 \ + --dropout 0.1 \ + --revin 1 \ + --des 'Exp' \ + --itr 1 +done diff --git a/scripts/short_term_forecast/vanillaMamba_M4.sh b/scripts/short_term_forecast/vanillaMamba_M4.sh new file mode 100644 index 0000000..a8e8406 --- /dev/null +++ b/scripts/short_term_forecast/vanillaMamba_M4.sh @@ -0,0 +1,165 @@ +#!/bin/bash + +model_name=vanillaMamba + +# M4 Monthly +python -u run.py \ + --task_name short_term_forecast \ + --is_training 1 \ + --root_path ./dataset/m4 \ + --seasonal_patterns 'Monthly' \ + --model_id m4_Monthly \ + --model $model_name \ + --data m4 \ + --features M \ + --e_layers 2 \ + --enc_in 1 \ + --c_out 1 \ + --batch_size 16 \ + --d_model 128 \ + --expand 2 \ + --d_conv 4 \ + --d_state 64 \ + --headdim 64 \ + --ngroups 1 \ + --chunk_size 256 \ + --dropout 0.1 \ + --des 'Exp' \ + --itr 1 \ + --learning_rate 0.001 \ + --loss 'SMAPE' + +# M4 Yearly +python -u run.py \ + --task_name short_term_forecast \ + --is_training 1 \ + --root_path ./dataset/m4 \ + --seasonal_patterns 'Yearly' \ + --model_id m4_Yearly \ + --model $model_name \ + --data m4 \ + --features M \ + --e_layers 2 \ + --enc_in 1 \ + --c_out 1 \ + --batch_size 16 \ + --d_model 128 \ + --expand 2 \ + --d_conv 4 \ + --d_state 64 \ + --headdim 64 \ + --ngroups 
1 \ + --chunk_size 256 \ + --dropout 0.1 \ + --des 'Exp' \ + --itr 1 \ + --learning_rate 0.001 \ + --loss 'SMAPE' + +# M4 Quarterly +python -u run.py \ + --task_name short_term_forecast \ + --is_training 1 \ + --root_path ./dataset/m4 \ + --seasonal_patterns 'Quarterly' \ + --model_id m4_Quarterly \ + --model $model_name \ + --data m4 \ + --features M \ + --e_layers 2 \ + --enc_in 1 \ + --c_out 1 \ + --batch_size 16 \ + --d_model 128 \ + --expand 2 \ + --d_conv 4 \ + --d_state 64 \ + --headdim 64 \ + --ngroups 1 \ + --chunk_size 256 \ + --dropout 0.1 \ + --des 'Exp' \ + --itr 1 \ + --learning_rate 0.001 \ + --loss 'SMAPE' + +# M4 Weekly +python -u run.py \ + --task_name short_term_forecast \ + --is_training 1 \ + --root_path ./dataset/m4 \ + --seasonal_patterns 'Weekly' \ + --model_id m4_Weekly \ + --model $model_name \ + --data m4 \ + --features M \ + --e_layers 2 \ + --enc_in 1 \ + --c_out 1 \ + --batch_size 16 \ + --d_model 128 \ + --expand 2 \ + --d_conv 4 \ + --d_state 64 \ + --headdim 64 \ + --ngroups 1 \ + --chunk_size 256 \ + --dropout 0.1 \ + --des 'Exp' \ + --itr 1 \ + --learning_rate 0.001 \ + --loss 'SMAPE' + +# M4 Daily +python -u run.py \ + --task_name short_term_forecast \ + --is_training 1 \ + --root_path ./dataset/m4 \ + --seasonal_patterns 'Daily' \ + --model_id m4_Daily \ + --model $model_name \ + --data m4 \ + --features M \ + --e_layers 2 \ + --enc_in 1 \ + --c_out 1 \ + --batch_size 16 \ + --d_model 128 \ + --expand 2 \ + --d_conv 4 \ + --d_state 64 \ + --headdim 64 \ + --ngroups 1 \ + --chunk_size 256 \ + --dropout 0.1 \ + --des 'Exp' \ + --itr 1 \ + --learning_rate 0.001 \ + --loss 'SMAPE' + +# M4 Hourly +python -u run.py \ + --task_name short_term_forecast \ + --is_training 1 \ + --root_path ./dataset/m4 \ + --seasonal_patterns 'Hourly' \ + --model_id m4_Hourly \ + --model $model_name \ + --data m4 \ + --features M \ + --e_layers 2 \ + --enc_in 1 \ + --c_out 1 \ + --batch_size 16 \ + --d_model 128 \ + --expand 2 \ + --d_conv 4 \ + --d_state 64 \ + --headdim 64 \ + --ngroups 1 \ + --chunk_size 256 \ + --dropout 0.1 \ + --des 'Exp' \ + --itr 1 \ + --learning_rate 0.001 \ + --loss 'SMAPE' \ No newline at end of file diff --git a/scripts/short_term_forecast/xPatch_SparseChannel_M4.sh b/scripts/short_term_forecast/xPatch_SparseChannel_M4.sh new file mode 100644 index 0000000..59b778a --- /dev/null +++ b/scripts/short_term_forecast/xPatch_SparseChannel_M4.sh @@ -0,0 +1,165 @@ +#!/bin/bash + +model_name=xPatch_SparseChannel + +# M4 Monthly +python -u run.py \ + --task_name short_term_forecast \ + --is_training 1 \ + --root_path ./dataset/m4 \ + --seasonal_patterns 'Monthly' \ + --model_id m4_Monthly \ + --model $model_name \ + --data m4 \ + --features M \ + --e_layers 2 \ + --enc_in 1 \ + --c_out 1 \ + --batch_size 16 \ + --d_model 128 \ + --d_ff 256 \ + --n_heads 16 \ + --patch_len 16 \ + --stride 8 \ + --k_graph 1 \ + --dropout 0.1 \ + --revin 1 \ + --des 'Exp' \ + --itr 1 \ + --learning_rate 0.001 \ + --loss 'SMAPE' + +# M4 Yearly +python -u run.py \ + --task_name short_term_forecast \ + --is_training 1 \ + --root_path ./dataset/m4 \ + --seasonal_patterns 'Yearly' \ + --model_id m4_Yearly \ + --model $model_name \ + --data m4 \ + --features M \ + --e_layers 2 \ + --enc_in 1 \ + --c_out 1 \ + --batch_size 16 \ + --d_model 128 \ + --d_ff 256 \ + --n_heads 16 \ + --patch_len 16 \ + --stride 8 \ + --k_graph 1 \ + --dropout 0.1 \ + --revin 1 \ + --des 'Exp' \ + --itr 1 \ + --learning_rate 0.001 \ + --loss 'SMAPE' + +# M4 Quarterly +python -u run.py \ + --task_name 
short_term_forecast \
+  --is_training 1 \
+  --root_path ./dataset/m4 \
+  --seasonal_patterns 'Quarterly' \
+  --model_id m4_Quarterly \
+  --model $model_name \
+  --data m4 \
+  --features M \
+  --e_layers 2 \
+  --enc_in 1 \
+  --c_out 1 \
+  --batch_size 16 \
+  --d_model 128 \
+  --d_ff 256 \
+  --n_heads 16 \
+  --patch_len 16 \
+  --stride 8 \
+  --k_graph 1 \
+  --dropout 0.1 \
+  --revin 1 \
+  --des 'Exp' \
+  --itr 1 \
+  --learning_rate 0.001 \
+  --loss 'SMAPE'
+
+# M4 Weekly
+python -u run.py \
+  --task_name short_term_forecast \
+  --is_training 1 \
+  --root_path ./dataset/m4 \
+  --seasonal_patterns 'Weekly' \
+  --model_id m4_Weekly \
+  --model $model_name \
+  --data m4 \
+  --features M \
+  --e_layers 2 \
+  --enc_in 1 \
+  --c_out 1 \
+  --batch_size 16 \
+  --d_model 128 \
+  --d_ff 256 \
+  --n_heads 16 \
+  --patch_len 16 \
+  --stride 8 \
+  --k_graph 1 \
+  --dropout 0.1 \
+  --revin 1 \
+  --des 'Exp' \
+  --itr 1 \
+  --learning_rate 0.001 \
+  --loss 'SMAPE'
+
+# M4 Daily
+python -u run.py \
+  --task_name short_term_forecast \
+  --is_training 1 \
+  --root_path ./dataset/m4 \
+  --seasonal_patterns 'Daily' \
+  --model_id m4_Daily \
+  --model $model_name \
+  --data m4 \
+  --features M \
+  --e_layers 2 \
+  --enc_in 1 \
+  --c_out 1 \
+  --batch_size 16 \
+  --d_model 128 \
+  --d_ff 256 \
+  --n_heads 16 \
+  --patch_len 16 \
+  --stride 8 \
+  --k_graph 1 \
+  --dropout 0.1 \
+  --revin 1 \
+  --des 'Exp' \
+  --itr 1 \
+  --learning_rate 0.001 \
+  --loss 'SMAPE'
+
+# M4 Hourly
+python -u run.py \
+  --task_name short_term_forecast \
+  --is_training 1 \
+  --root_path ./dataset/m4 \
+  --seasonal_patterns 'Hourly' \
+  --model_id m4_Hourly \
+  --model $model_name \
+  --data m4 \
+  --features M \
+  --e_layers 2 \
+  --enc_in 1 \
+  --c_out 1 \
+  --batch_size 16 \
+  --d_model 128 \
+  --d_ff 256 \
+  --n_heads 16 \
+  --patch_len 16 \
+  --stride 8 \
+  --k_graph 1 \
+  --dropout 0.1 \
+  --revin 1 \
+  --des 'Exp' \
+  --itr 1 \
+  --learning_rate 0.001 \
+  --loss 'SMAPE'
\ No newline at end of file
diff --git a/test_DC_hnet.py b/test_DC_hnet.py
new file mode 100644
index 0000000..25ecfa0
--- /dev/null
+++ b/test_DC_hnet.py
@@ -0,0 +1,209 @@
+#!/usr/bin/env python3
+"""
+Test script for the DC_hnet model.
+Verifies that the time-series classification model runs and produces outputs of the expected shapes.
+"""
+
+import torch
+import torch.nn.functional as F
+import sys
+import os
+
+# Add the current directory to the Python path
+sys.path.append(os.path.dirname(os.path.abspath(__file__)))
+
+from models.DC_hnet import HierEncodersSingleMainConfig, HierEncodersSingleMainClassifier
+
+def test_dc_hnet_model():
+    """Test the DC_hnet time-series classification model"""
+    print("=" * 60)
+    print("Testing the DC_hnet time-series classification model")
+    print("=" * 60)
+
+    # Select device
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    print(f"Using device: {device}")
+
+    # Model configuration
+    B, L, N = 8, 512, 6  # batch_size, seq_length, num_channels
+    num_classes = 10     # number of classes
+    d_models = [64, 128, 256]  # per-stage widths, monotonically increasing
+
+    print(f"Input shape: (B={B}, L={L}, N={N})")
+    print(f"Number of classes: {num_classes}")
+    print(f"Model widths: {d_models}")
+
+    # Encoder configs (every stage is Mamba)
+    encoder_cfg_per_stage = [
+        dict(arch="m", height=2),  # stage 0: Mamba2, 2 layers
+        dict(arch="m", height=3),  # stage 1: Mamba2, 3 layers
+    ]
+
+    # Main-network config (Transformer)
+    main_cfg = dict(
+        arch="T", height=4  # Transformer, 4 layers
+    )
+
+    # Compression targets
+    target_compression_N_per_stage = [2, 3]  # per-stage compression factors
+
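With these stage targets (reading each entry of target_compression_N_per_stage as an approximate per-stage shortening factor, an assumption about the DC routing that this test does not assert), the L=512 input should shrink roughly as follows:

L = 512
for n in [2, 3]:  # target_compression_N_per_stage
    L = L // n    # ~256 after stage 0, ~85 at the main network
print(L)          # -> 85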
+    # Create the config
+    cfg = HierEncodersSingleMainConfig(
+        num_channels=N,
+        d_models=d_models,
+        num_classes=num_classes,
+        encoder_cfg_per_stage=encoder_cfg_per_stage,
+        main_cfg=main_cfg,
+        target_compression_N_per_stage=target_compression_N_per_stage,
+        share_channel=True,             # share the encoder across channels
+        fusion_across_channels="mean",  # channel fusion mode
+        dropout=0.1,
+    )
+
+    print("Config created")
+
+    try:
+        # Create the model - set the dtype correctly for flash-attention compatibility
+        dtype = torch.bfloat16 if device.type == "cuda" else torch.float32
+        model = HierEncodersSingleMainClassifier(cfg, device=device, dtype=dtype)
+        model = model.to(device)
+        print(f"Model created; parameter count: {sum(p.numel() for p in model.parameters()):,}")
+
+        # Create random input data - bfloat16 for flash-attention compatibility
+        dtype = torch.bfloat16 if device.type == "cuda" else torch.float32
+        x = torch.randn(B, L, N, device=device, dtype=dtype)
+        mask = torch.ones(B, L, dtype=torch.bool, device=device)
+
+        print(f"Input shape: {x.shape}, dtype: {x.dtype}")
+
+        # Forward-pass test
+        print("\nRunning forward-pass test...")
+        model.eval()
+        with torch.no_grad():
+            logits, seq_debug, aux = model(x, mask=mask, return_seq=False)
+
+        print("✅ Forward pass succeeded!")
+        print(f"Output logits shape: {logits.shape}")  # should be (B, num_classes)
+        print(f"ratio_loss: {aux['ratio_loss']:.4f}")
+
+        # Check the output shape
+        expected_shape = (B, num_classes)
+        if logits.shape == expected_shape:
+            print(f"✅ Output shape is correct: {logits.shape}")
+        else:
+            print(f"❌ Wrong output shape: expected {expected_shape}, got {logits.shape}")
+            return False
+
+        # Training-mode test
+        print("\nRunning training-mode test...")
+        model.train()
+        optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
+
+        # Create target labels
+        y = torch.randint(0, num_classes, (B,), device=device)
+
+        # Forward pass
+        logits, _, aux = model(x, mask=mask, return_seq=False)
+
+        # Compute the loss
+        cls_loss = F.cross_entropy(logits, y)
+        ratio_reg = 0.01 * aux["ratio_loss"]  # ratio-loss regularization
+        total_loss = cls_loss + ratio_reg
+
+        print(f"Classification loss: {cls_loss:.4f}")
+        print(f"Ratio loss: {ratio_reg:.4f}")
+        print(f"Total loss: {total_loss:.4f}")
+
+        # Backward pass
+        optimizer.zero_grad()
+        total_loss.backward()
+        optimizer.step()
+
+        print("✅ Training step succeeded!")
+
+        # Test the sequence-return feature
+        print("\nTesting the sequence debug-info return...")
+        with torch.no_grad():
+            logits, seq_debug, aux = model(x, mask=mask, return_seq=True)
+
+        if seq_debug is not None:
+            print(f"✅ Sequence debug info retrieved; contains info for {len(seq_debug)} channels")
+        else:
+            print("❌ Failed to retrieve sequence debug info")
+
+        print("\n" + "=" * 60)
+        print("🎉 All DC_hnet model tests passed!")
+        print("The model can run time-series classification as expected")
+        print("=" * 60)
+
+        return True
+
+    except Exception as e:
+        print(f"❌ Test failed: {str(e)}")
+        import traceback
+        traceback.print_exc()
+        return False
+
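The dtype choice above is repeated verbatim for every model and input tensor; a tiny hypothetical helper (an editorial sketch, not part of the patch) that centralizes the flash-attention-friendly selection:

import torch

def pick_dtype(device: torch.device) -> torch.dtype:
    # bfloat16 on CUDA for flash-attention compatibility, float32 elsewhere
    return torch.bfloat16 if device.type == "cuda" else torch.float32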
+        logits2, _, _ = model2(x2)
+        print(f"✅ Config 2 (single-stage model): output shape {logits2.shape}")
+    except Exception as e:
+        print(f"❌ Config 2 test failed: {str(e)}")
+
+if __name__ == "__main__":
+    # Main test
+    success = test_dc_hnet_model()
+    
+    # Additional configuration tests
+    test_different_configurations()
+    
+    if success:
+        print("\n🎊 All tests completed!")
+        sys.exit(0)
+    else:
+        print("\n💥 Tests failed!")
+        sys.exit(1)
\ No newline at end of file
diff --git a/test_dc_patchtst.py b/test_dc_patchtst.py
new file mode 100644
index 0000000..c86548b
--- /dev/null
+++ b/test_dc_patchtst.py
@@ -0,0 +1,313 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+import torch
+import torch.nn as nn
+import pandas as pd
+import numpy as np
+from torch.utils.data import Dataset, DataLoader
+import os
+import matplotlib.pyplot as plt
+from models.DC_PatchTST import Model
+
+class Config:
+    """Configuration class"""
+    def __init__(self):
+        # Basic configuration
+        self.task_name = 'long_term_forecast'
+        self.model = 'DC_PatchTST'
+        
+        # Data configuration
+        self.seq_len = 96    # Input sequence length
+        self.pred_len = 24   # Prediction sequence length
+        self.label_len = 48  # Label length
+        self.enc_in = 2      # Input feature dimension (dual channel)
+        self.dec_in = 2      # Decoder input dimension
+        self.c_out = 2       # Output dimension
+        
+        # Model configuration
+        self.d_model = 128   # Model dimension
+        self.n_heads = 8     # Number of attention heads
+        self.e_layers = 2    # Number of encoder layers
+        self.d_layers = 1    # Number of decoder layers
+        self.d_ff = 256      # Feed-forward dimension
+        self.factor = 1      # Attention factor
+        self.dropout = 0.1   # Dropout rate
+        self.activation = 'gelu'
+        
+        # Training configuration
+        self.batch_size = 32
+        self.learning_rate = 0.001
+        self.train_epochs = 50
+        self.patience = 5
+        
+        # Other configuration
+        self.use_amp = False
+        self.num_class = 0
+        
+        # GPU configuration
+        self.use_gpu = torch.cuda.is_available()
+        self.device = torch.device('cuda' if self.use_gpu else 'cpu')
+
+class SineWaveDataset(Dataset):
+    """Sine wave dataset"""
+    def __init__(self, data_path, seq_len=96, pred_len=24, mode='test'):
+        self.seq_len = seq_len
+        self.pred_len = pred_len
+        self.mode = mode
+        
+        # Load data
+        if mode == 'train':
+            df = pd.read_csv(os.path.join(data_path, 'train.csv'))
+        elif mode == 'val':
+            df = pd.read_csv(os.path.join(data_path, 'val.csv'))
+        else:  # test
+            df = pd.read_csv(os.path.join(data_path, 'test.csv'))
+        
+        # Extract feature columns (except timestamp)
+        self.data = df[['channel1', 'channel2']].values.astype(np.float32)
+        
+        # Calculate available sample count
+        self.total_len = len(self.data)
+        self.samples_num = max(0, self.total_len - seq_len - pred_len + 1)
+        
+        print(f"{mode} dataset: {self.total_len} records, {self.samples_num} samples")
+    
+    def __len__(self):
+        return self.samples_num
+    
+    def __getitem__(self, idx):
+        # Input sequence
+        s_begin = idx
+        s_end = s_begin + self.seq_len
+        
+        # Prediction target
+        r_begin = s_end
+        r_end = r_begin + self.pred_len
+        
+        seq_x = self.data[s_begin:s_end]  # (seq_len, n_vars)
+        seq_y = self.data[r_begin:r_end]  # (pred_len, n_vars)
+        
+        # Time marks (simple positional encoding)
+        seq_x_mark = np.arange(self.seq_len).reshape(-1, 1).astype(np.float32)
+        seq_y_mark = np.arange(self.pred_len).reshape(-1, 1).astype(np.float32)
+        
+        return seq_x, seq_y, seq_x_mark, seq_y_mark, idx
+
+def load_model(model_path):
+    """Load saved model"""
+    checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)
+    config = checkpoint['config']
+    
+    # Create model
+    model = Model(config)
+
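+    # The checkpoint stores the training-time Config, so the architecture built
+    # here matches the saved weights exactly
+    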
model.load_state_dict(checkpoint['model_state_dict']) + + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + model = model.to(device) + model.eval() + + print(f"Model loaded successfully - Epoch: {checkpoint['epoch']}, Val Loss: {checkpoint['val_loss']:.6f}") + print(f"Using device: {device}") + + return model, config, device + +def visualize_predictions_with_chunks(model, test_loader, device, save_path, num_samples=5): + """Predict and visualize results, marking chunk points""" + model.eval() + + # Create save directory + vis_dir = os.path.join(save_path, 'visualizations') + os.makedirs(vis_dir, exist_ok=True) + + sample_count = 0 + + with torch.no_grad(): + for batch_x, batch_y, batch_x_mark, batch_y_mark, batch_idx in test_loader: + if sample_count >= num_samples: + break + + batch_x = batch_x.to(device) + batch_y = batch_y.to(device) + batch_x_mark = batch_x_mark.to(device) + batch_y_mark = batch_y_mark.to(device) + + # Construct decoder input + dec_inp = torch.zeros_like(batch_y).to(device) + + # Predict + outputs, aux = model(batch_x, batch_x_mark, dec_inp, batch_y_mark) + + # Convert to CPU + batch_x_cpu = batch_x.cpu().numpy() + batch_y_cpu = batch_y.cpu().numpy() + outputs_cpu = outputs.cpu().numpy() + + # Process each sample in batch + for i in range(min(batch_x.size(0), num_samples - sample_count)): + input_seq = batch_x_cpu[i] # (seq_len, 2) + true_pred = batch_y_cpu[i] # (pred_len, 2) + pred_seq = outputs_cpu[i] # (pred_len, 2) + + # Get first layer chunk information + boundary_mask_stage0 = None + if aux is not None and 'stage0' in aux: + # aux['stage0']['boundary_mask'] shape: (B*nvars, L) + # We need to reshape to (B, nvars, L) then take i-th sample + stage0_info = aux['stage0'] + boundary_mask = stage0_info['boundary_mask'].cpu().numpy() # (B*nvars, L) + + B = batch_x.size(0) + nvars = batch_x.size(2) # should be 2 + L = boundary_mask.shape[1] + + # Reshape boundary_mask: (B*nvars, L) -> (B, nvars, L) + boundary_mask = boundary_mask.reshape(B, nvars, L) + boundary_mask_stage0 = boundary_mask[i] # (nvars, L) + + # Create visualization for each channel + fig, axes = plt.subplots(2, 1, figsize=(15, 10)) + + for ch in range(2): # dual channel + ax = axes[ch] + + # Time axis + input_time = np.arange(len(input_seq)) + pred_time = np.arange(len(input_seq), len(input_seq) + len(true_pred)) + + # Plot input sequence + ax.plot(input_time, input_seq[:, ch], 'b-', label='Input Sequence', linewidth=1.5) + + # Plot ground truth prediction and model prediction + ax.plot(pred_time, true_pred[:, ch], 'g-', label='Ground Truth', linewidth=2) + ax.plot(pred_time, pred_seq[:, ch], 'r--', label='Prediction', linewidth=2) + + # Mark first layer chunk points + if boundary_mask_stage0 is not None: + chunk_points = np.where(boundary_mask_stage0[ch])[0] # Get chunk points for this channel + for point in chunk_points: + if point < len(input_seq): # Only mark points within input sequence range + ax.axvline(x=point, color='orange', linestyle=':', alpha=0.7, linewidth=1) + + # Add chunk points explanation in legend + if len(chunk_points) > 0: + ax.axvline(x=-1, color='orange', linestyle=':', alpha=0.7, + linewidth=1, label=f'Chunk Points (Stage 0)') + + ax.set_title(f'Sample {sample_count + 1} - Channel {ch + 1}') + ax.set_xlabel('Time Steps') + ax.set_ylabel('Value') + ax.legend() + ax.grid(True, alpha=0.3) + + # Add boundary line marking input and prediction boundary + ax.axvline(x=len(input_seq)-0.5, color='black', linestyle='-', alpha=0.5, linewidth=1) + 
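+                # Label the input/forecast boundary drawn above so each panel reads
+                # as "history | prediction"
+                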
ax.text(len(input_seq)-0.5, ax.get_ylim()[1]*0.9, 'Prediction Start', + rotation=90, verticalalignment='top', fontsize=8) + + plt.tight_layout() + + # Save figure + sample_filename = f'sample_{sample_count + 1}_with_chunks.png' + sample_path = os.path.join(vis_dir, sample_filename) + plt.savefig(sample_path, dpi=300, bbox_inches='tight') + plt.close() + + print(f"Sample {sample_count + 1} visualization saved to: {sample_path}") + + # Print chunk statistics + if boundary_mask_stage0 is not None: + for ch in range(2): + chunk_count = np.sum(boundary_mask_stage0[ch]) + chunk_ratio = chunk_count / len(boundary_mask_stage0[ch]) + print(f" Channel {ch + 1}: {chunk_count} chunk points ({chunk_ratio:.2%} of sequence)") + + sample_count += 1 + if sample_count >= num_samples: + break + + if sample_count >= num_samples: + break + +def evaluate_model(model, test_loader, device): + """Evaluate model performance""" + model.eval() + total_mse = 0.0 + total_mae = 0.0 + batch_count = 0 + + all_predictions = [] + all_ground_truths = [] + + with torch.no_grad(): + for batch_x, batch_y, batch_x_mark, batch_y_mark, _ in test_loader: + batch_x = batch_x.to(device) + batch_y = batch_y.to(device) + batch_x_mark = batch_x_mark.to(device) + batch_y_mark = batch_y_mark.to(device) + + dec_inp = torch.zeros_like(batch_y).to(device) + + outputs, _ = model(batch_x, batch_x_mark, dec_inp, batch_y_mark) + + # Calculate loss + mse = torch.mean((outputs - batch_y) ** 2) + mae = torch.mean(torch.abs(outputs - batch_y)) + + total_mse += mse.item() + total_mae += mae.item() + batch_count += 1 + + # Collect for overall statistics + all_predictions.append(outputs.cpu().numpy()) + all_ground_truths.append(batch_y.cpu().numpy()) + + avg_mse = total_mse / batch_count + avg_mae = total_mae / batch_count + + # Calculate overall RMSE + all_predictions = np.concatenate(all_predictions, axis=0) + all_ground_truths = np.concatenate(all_ground_truths, axis=0) + rmse = np.sqrt(np.mean((all_predictions - all_ground_truths) ** 2)) + + print(f"\nTest set evaluation results:") + print(f"MSE: {avg_mse:.6f}") + print(f"MAE: {avg_mae:.6f}") + print(f"RMSE: {rmse:.6f}") + + return avg_mse, avg_mae, rmse + +def main(): + # Configuration parameters + model_path = './results/dc_patchtst_sine_wave/best_model.pth' + data_path = './data/sine_wave/' + save_path = './results/dc_patchtst_sine_wave/' + + if not os.path.exists(model_path): + print(f"Error: Model file not found {model_path}") + print("Please run the training script train_dc_patchtst.py first") + return + + # Load model + model, config, device = load_model(model_path) + + # Load test data + test_dataset = SineWaveDataset(data_path, config.seq_len, config.pred_len, 'test') + test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0) # batch_size=1 for easier visualization + + print(f"\nConfiguration info:") + print(f"Sequence length: {config.seq_len}") + print(f"Prediction length: {config.pred_len}") + print(f"Feature dimension: {config.enc_in}") + + # Evaluate model + mse, mae, rmse = evaluate_model(model, test_loader, device) + + # Visualize prediction results and mark chunk points + print(f"\nGenerating visualization results...") + visualize_predictions_with_chunks(model, test_loader, device, save_path, num_samples=5) + + print(f"\nAll results saved to: {save_path}") + print(f"Visualization files located at: {save_path}/visualizations/") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/train_dc_patchtst.py b/train_dc_patchtst.py new file 
mode 100644
index 0000000..11ed4bf
--- /dev/null
+++ b/train_dc_patchtst.py
@@ -0,0 +1,335 @@
+#!/usr/bin/env python3
+import torch
+import torch.nn as nn
+import pandas as pd
+import numpy as np
+from torch.utils.data import Dataset, DataLoader
+import os
+import argparse
+from models.DC_PatchTST import Model
+import time
+import matplotlib.pyplot as plt
+from sklearn.metrics import mean_squared_error, mean_absolute_error
+
+# Create the results directory if it does not exist
+def ensure_dir(path):
+    if not os.path.exists(path):
+        os.makedirs(path)
+
+class SineWaveDataset(Dataset):
+    """Sine wave dataset"""
+    def __init__(self, data_path, seq_len=96, pred_len=24, mode='train'):
+        self.seq_len = seq_len
+        self.pred_len = pred_len
+        self.mode = mode
+        
+        # Load data
+        if mode == 'train':
+            df = pd.read_csv(os.path.join(data_path, 'train.csv'))
+        elif mode == 'val':
+            df = pd.read_csv(os.path.join(data_path, 'val.csv'))
+        else:  # test
+            df = pd.read_csv(os.path.join(data_path, 'test.csv'))
+        
+        # Extract feature columns (except timestamp)
+        self.data = df[['channel1', 'channel2']].values.astype(np.float32)
+        
+        # Calculate available sample count
+        self.total_len = len(self.data)
+        self.samples_num = max(0, self.total_len - seq_len - pred_len + 1)
+        
+        print(f"{mode} dataset: {self.total_len} records, {self.samples_num} samples")
+    
+    def __len__(self):
+        return self.samples_num
+    
+    def __getitem__(self, idx):
+        # Input sequence
+        s_begin = idx
+        s_end = s_begin + self.seq_len
+        
+        # Prediction target
+        r_begin = s_end
+        r_end = r_begin + self.pred_len
+        
+        seq_x = self.data[s_begin:s_end]  # (seq_len, n_vars)
+        seq_y = self.data[r_begin:r_end]  # (pred_len, n_vars)
+        
+        # Time marks (simple positional encoding)
+        seq_x_mark = np.arange(self.seq_len).reshape(-1, 1).astype(np.float32)
+        seq_y_mark = np.arange(self.pred_len).reshape(-1, 1).astype(np.float32)
+        
+        return seq_x, seq_y, seq_x_mark, seq_y_mark
+
+class Config:
+    """Configuration class"""
+    def __init__(self):
+        # Basic configuration
+        self.task_name = 'long_term_forecast'
+        self.model = 'DC_PatchTST'
+        
+        # Data configuration
+        self.seq_len = 96    # Input sequence length
+        self.pred_len = 24   # Prediction sequence length
+        self.label_len = 48  # Label length
+        self.enc_in = 2      # Input feature dimension (dual channel)
+        self.dec_in = 2      # Decoder input dimension
+        self.c_out = 2       # Output dimension
+        
+        # Model configuration
+        self.d_model = 128   # Model dimension
+        self.n_heads = 8     # Number of attention heads
+        self.e_layers = 2    # Number of encoder layers
+        self.d_layers = 1    # Number of decoder layers
+        self.d_ff = 256      # Feed-forward dimension
+        self.factor = 1      # Attention factor
+        self.dropout = 0     # Dropout rate (disabled here)
+        self.activation = 'gelu'
+        
+        # Training configuration
+        self.batch_size = 1024
+        self.learning_rate = 0.001
+        self.train_epochs = 50
+        self.patience = 5
+        
+        # Other configuration
+        self.use_amp = False
+        self.num_class = 0
+        
+        # GPU configuration
+        self.use_gpu = torch.cuda.is_available()
+        self.device = torch.device('cuda' if self.use_gpu else 'cpu')
+
+def train_epoch(model, train_loader, criterion, optimizer, device, use_amp=False):
+    """Train for one epoch"""
+    model.train()
+    total_loss = 0.0
+    batch_count = 0
+    
+    for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(train_loader):
+        batch_x = batch_x.to(device)
+        batch_y = batch_y.to(device)
+        batch_x_mark = batch_x_mark.to(device)
+        batch_y_mark = batch_y_mark.to(device)
+        
+        # Construct the decoder input
+        dec_inp = torch.zeros_like(batch_y).to(device)
+        
+        optimizer.zero_grad()
+        
+        if use_amp:
+            with torch.cuda.amp.autocast():
+                outputs, aux = model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
+                loss = criterion(outputs, batch_y)
+                
+                # Add the DC ratio loss
+                if aux is not None and 'ratio_loss0' in aux and 'ratio_loss1' in aux:
+                    ratio_loss = aux['ratio_loss0'] + aux['ratio_loss1']
+                    loss = loss + 0.0 * ratio_loss  # ratio-loss weight (0.0 disables the term)
+        else:
+            outputs, aux = model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
+            loss = criterion(outputs, batch_y)
+            
+            # Add the DC ratio loss
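+            # (ratio_loss0/ratio_loss1 come from the dynamic-chunking stages; the
+            # 0.0 weight below keeps the term a no-op - raise it to actually
+            # regularize the chunking ratio)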
+            if aux is not None and 'ratio_loss0' in aux and 'ratio_loss1' in aux:
+                ratio_loss = aux['ratio_loss0'] + aux['ratio_loss1']
+                loss = loss + 0.0 * ratio_loss  # ratio-loss weight (0.0 disables the term)
+        
+        # NOTE: no GradScaler is used, so use_amp=True would train unscaled;
+        # the default config keeps use_amp=False
+        loss.backward()
+        optimizer.step()
+        
+        total_loss += loss.item()
+        batch_count += 1
+        
+        if i % 100 == 0:
+            print(f'Batch {i}, Loss: {loss.item():.6f}')
+    
+    return total_loss / batch_count
+
+def validate(model, val_loader, criterion, device):
+    """Validate the model"""
+    model.eval()
+    total_loss = 0.0
+    batch_count = 0
+    
+    with torch.no_grad():
+        for batch_x, batch_y, batch_x_mark, batch_y_mark in val_loader:
+            batch_x = batch_x.to(device)
+            batch_y = batch_y.to(device)
+            batch_x_mark = batch_x_mark.to(device)
+            batch_y_mark = batch_y_mark.to(device)
+            
+            dec_inp = torch.zeros_like(batch_y).to(device)
+            
+            outputs, _ = model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
+            loss = criterion(outputs, batch_y)
+            
+            total_loss += loss.item()
+            batch_count += 1
+    
+    return total_loss / batch_count
+
+def test_model(model, test_loader, device, save_path):
+    """Test the model and visualize the results"""
+    model.eval()
+    predictions = []
+    ground_truths = []
+    
+    with torch.no_grad():
+        for batch_x, batch_y, batch_x_mark, batch_y_mark in test_loader:
+            batch_x = batch_x.to(device)
+            batch_y = batch_y.to(device)
+            batch_x_mark = batch_x_mark.to(device)
+            batch_y_mark = batch_y_mark.to(device)
+            
+            dec_inp = torch.zeros_like(batch_y).to(device)
+            
+            outputs, _ = model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
+            
+            predictions.append(outputs.cpu().numpy())
+            ground_truths.append(batch_y.cpu().numpy())
+    
+    predictions = np.concatenate(predictions, axis=0)
+    ground_truths = np.concatenate(ground_truths, axis=0)
+    
+    # Compute metrics
+    mse = mean_squared_error(ground_truths.reshape(-1), predictions.reshape(-1))
+    mae = mean_absolute_error(ground_truths.reshape(-1), predictions.reshape(-1))
+    
+    print(f"Test results - MSE: {mse:.6f}, MAE: {mae:.6f}")
+    
+    # Visualize the first few samples
+    plt.figure(figsize=(15, 10))
+    for i in range(min(4, len(predictions))):
+        for ch in range(2):  # dual channel
+            plt.subplot(4, 2, i*2 + ch + 1)
+            plt.plot(ground_truths[i, :, ch], label='Ground Truth', color='blue')
+            plt.plot(predictions[i, :, ch], label='Prediction', color='red')
+            plt.title(f'Sample {i+1}, Channel {ch+1}')
+            plt.legend()
+    
+    plt.tight_layout()
+    plt.savefig(os.path.join(save_path, 'predictions.png'))
+    print(f"Prediction visualization saved to: {save_path}/predictions.png")
+    
+    return mse, mae
+
+def main():
+    # Configuration
+    config = Config()
+    
+    # Create the results directory
+    results_dir = './results/dc_patchtst_sine_wave'
+    ensure_dir(results_dir)
+    
+    print(f"Using device: {config.device}")
+    print(f"Model configuration: seq_len={config.seq_len}, pred_len={config.pred_len}")
+    
+    # Load data
+    data_path = './data/sine_wave/'
+    train_dataset = SineWaveDataset(data_path, config.seq_len, config.pred_len, 'train')
+    val_dataset = SineWaveDataset(data_path, config.seq_len, config.pred_len, 'val')
+    test_dataset = SineWaveDataset(data_path, config.seq_len, config.pred_len, 'test')
+    
+    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=4)
+    val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False, num_workers=4)
+    test_loader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False, num_workers=4)
+    
+    # Create the model
+    model = Model(config).to(config.device)
+    print(f"Number of model parameters: {sum(p.numel() for p in model.parameters())}")
+    
+    # Optimizer and loss function
+    optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate, weight_decay=1e-4)
+    criterion = nn.MSELoss()
+    
+    # Training loop
+    best_val_loss = float('inf')
+    patience_counter = 0
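+    # Track per-epoch losses so the training-history curves can be plotted at the end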
+    train_losses = []
+    val_losses = []
+    
+    print("Starting training...")
+    start_time = time.time()
+    
+    for epoch in range(config.train_epochs):
+        epoch_start = time.time()
+        
+        # Train
+        train_loss = train_epoch(model, train_loader, criterion, optimizer, config.device, config.use_amp)
+        
+        # Validate
+        val_loss = validate(model, val_loader, criterion, config.device)
+        
+        train_losses.append(train_loss)
+        val_losses.append(val_loss)
+        
+        epoch_time = time.time() - epoch_start
+        
+        print(f'Epoch {epoch+1:2d}/{config.train_epochs} | '
+              f'Train Loss: {train_loss:.6f} | '
+              f'Val Loss: {val_loss:.6f} | '
+              f'Time: {epoch_time:.2f}s')
+        
+        # Early-stopping check
+        if val_loss < best_val_loss:
+            best_val_loss = val_loss
+            patience_counter = 0
+            # Save the best model
+            torch.save({
+                'model_state_dict': model.state_dict(),
+                'optimizer_state_dict': optimizer.state_dict(),
+                'config': config,
+                'epoch': epoch,
+                'val_loss': val_loss
+            }, os.path.join(results_dir, 'best_model.pth'))
+            print(f'  -> Saved best model (val_loss: {val_loss:.6f})')
+        else:
+            patience_counter += 1
+            if patience_counter >= config.patience:
+                print(f'Early stopping triggered! Best validation loss: {best_val_loss:.6f}')
+                break
+    
+    total_time = time.time() - start_time
+    print(f'\nTraining finished! Total time: {total_time/60:.2f} minutes')
+    
+    # Load the best model for testing; weights_only=False because the checkpoint
+    # also pickles the Config object (matches load_model in test_dc_patchtst.py)
+    checkpoint = torch.load(os.path.join(results_dir, 'best_model.pth'), weights_only=False)
+    model.load_state_dict(checkpoint['model_state_dict'])
+    
+    print("\nTesting the best model...")
+    test_mse, test_mae = test_model(model, test_loader, config.device, results_dir)
+    
+    # Collect the training history
+    history = {
+        'train_losses': train_losses,
+        'val_losses': val_losses,
+        'test_mse': test_mse,
+        'test_mae': test_mae
+    }
+    
+    # Plot the training curves
+    plt.figure(figsize=(12, 4))
+    plt.subplot(1, 2, 1)
+    plt.plot(train_losses, label='Train Loss')
+    plt.plot(val_losses, label='Validation Loss')
+    plt.xlabel('Epoch')
+    plt.ylabel('Loss')
+    plt.legend()
+    plt.title('Training History')
+    
+    plt.subplot(1, 2, 2)
+    plt.bar(['MSE', 'MAE'], [test_mse, test_mae])
+    plt.title('Test Metrics')
+    plt.ylabel('Error')
+    
+    plt.tight_layout()
+    plt.savefig(os.path.join(results_dir, 'training_history.png'))
+    
+    print(f"\nResults saved in: {results_dir}")
+    print(f"Best model: {results_dir}/best_model.pth")
+    print(f"Training history: {results_dir}/training_history.png")
+    print(f"Prediction visualization: {results_dir}/predictions.png")
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file