Compare commits

14 Commits

Author SHA1 Message Date
93f14077da feat(model): introduce dynamic training flag for model forward pass 2025-09-13 00:04:01 +08:00
172328a4e6 feat: implement dynamic threshold scheduling for GraphMixer 2025-09-12 17:02:42 +08:00
6a1f9d30f3 refactor(scripts): update xPatch_SparseChannel forecast scripts 2025-09-11 17:11:17 +08:00
204d17086a feat(graph-mixer): implement L0 sparsity with Hard-Concrete gate for channel selection 2025-09-11 16:50:58 +08:00
5fc0da4239 refactor(mamba): add residual connection to Mamba2Encoder layers 2025-09-10 21:23:17 +08:00
598fdaadbc feat(mamba): extract last time step from Mamba2Encoder output 2025-09-10 21:03:35 +08:00
b139f711bc refactor(mamba): adjust Mamba2Encoder layer configuration 2025-09-10 21:00:31 +08:00
9787badd25 feat(mambaseries): allow stacking multiple Mamba2 layers 2025-09-10 20:56:24 +08:00
ff987da4c6 refactor(xPatch): remove redundant encoder argument from season_net call 2025-09-10 16:03:43 +08:00
1044c60fe7 refactor(SeasonPatch): unify encoder and head initialization 2025-09-10 16:00:52 +08:00
908d3a7080 feat(mamba): add Mamba2 encoder option to SeasonPatch 2025-09-10 15:51:28 +08:00
96c40c6ab6 feat: add xpatch_sparsechannel test script 2025-09-10 10:54:50 +08:00
9f7fb24beb refactor(graphmixer): enhance channel graph attention with ST-Gumbel 2025-09-06 00:06:26 +08:00
ef307a57e9 feat: add mamba and dynamic chunking related code and test code 2025-09-04 01:32:13 +00:00
27 changed files with 5101 additions and 210 deletions

View File

@ -11,6 +11,7 @@ import warnings
import numpy as np import numpy as np
from utils.dtw_metric import dtw, accelerated_dtw from utils.dtw_metric import dtw, accelerated_dtw
from utils.augmentation import run_augmentation, run_augmentation_single from utils.augmentation import run_augmentation, run_augmentation_single
import inspect
warnings.filterwarnings('ignore') warnings.filterwarnings('ignore')
@ -18,9 +19,18 @@ warnings.filterwarnings('ignore')
class Exp_Long_Term_Forecast(Exp_Basic): class Exp_Long_Term_Forecast(Exp_Basic):
def __init__(self, args): def __init__(self, args):
super(Exp_Long_Term_Forecast, self).__init__(args) super(Exp_Long_Term_Forecast, self).__init__(args)
self._model_supports_training_flag = False
def _build_model(self): def _build_model(self):
model = self.model_dict[self.args.model].Model(self.args).float() model = self.model_dict[self.args.model].Model(self.args).float()
# 如果模型被 DataParallel 包装,我们需要检查原始模型
model_to_inspect = model
# inspect.signature() 可以获取函数或方法的参数信息
forward_signature = inspect.signature(model_to_inspect.forward)
# 检查'training'是否在参数列表中
if 'training' in forward_signature.parameters:
self._model_supports_training_flag = True
print("Model supports 'training' flag.")
if self.args.use_multi_gpu and self.args.use_gpu: if self.args.use_multi_gpu and self.args.use_gpu:
model = nn.DataParallel(model, device_ids=self.args.device_ids) model = nn.DataParallel(model, device_ids=self.args.device_ids)
@ -63,12 +73,16 @@ class Exp_Long_Term_Forecast(Exp_Basic):
# decoder input # decoder input
dec_inp = torch.zeros_like(batch_y[:, -self.args.pred_len:, :]).float() dec_inp = torch.zeros_like(batch_y[:, -self.args.pred_len:, :]).float()
dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1).float().to(self.device) dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1).float().to(self.device)
# encoder - decoder # --- 修改模型调用部分 ---
model_args = (batch_x, batch_x_mark, dec_inp, batch_y_mark)
model_kwargs = {}
if self._model_supports_training_flag:
model_kwargs['training'] = False # 验证阶段为 False
if self.args.use_amp: if self.args.use_amp:
with torch.cuda.amp.autocast(): with torch.cuda.amp.autocast():
outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark) outputs = self.model(*model_args, **model_kwargs)
else: else:
outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark) outputs = self.model(*model_args, **model_kwargs)
f_dim = -1 if self.args.features == 'MS' else 0 f_dim = -1 if self.args.features == 'MS' else 0
outputs = outputs[:, -self.args.pred_len:, f_dim:] outputs = outputs[:, -self.args.pred_len:, f_dim:]
batch_y = batch_y[:, -self.args.pred_len:, f_dim:].to(self.device) batch_y = batch_y[:, -self.args.pred_len:, f_dim:].to(self.device)
@ -130,19 +144,20 @@ class Exp_Long_Term_Forecast(Exp_Basic):
dec_inp = torch.zeros_like(batch_y[:, -self.args.pred_len:, :]).float() dec_inp = torch.zeros_like(batch_y[:, -self.args.pred_len:, :]).float()
dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1).float().to(self.device) dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1).float().to(self.device)
# encoder - decoder model_args = (batch_x, batch_x_mark, dec_inp, batch_y_mark)
model_kwargs = {}
if self._model_supports_training_flag:
model_kwargs['training'] = True # 训练阶段为 True
if self.args.use_amp: if self.args.use_amp:
with torch.cuda.amp.autocast(): with torch.cuda.amp.autocast():
outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark) outputs = self.model(*model_args, **model_kwargs)
f_dim = -1 if self.args.features == 'MS' else 0 f_dim = -1 if self.args.features == 'MS' else 0
outputs = outputs[:, -self.args.pred_len:, f_dim:] outputs = outputs[:, -self.args.pred_len:, f_dim:]
batch_y = batch_y[:, -self.args.pred_len:, f_dim:].to(self.device) batch_y = batch_y[:, -self.args.pred_len:, f_dim:].to(self.device)
loss = criterion(outputs, batch_y) loss = criterion(outputs, batch_y)
train_loss.append(loss.item()) train_loss.append(loss.item())
else: else:
outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark) outputs = self.model(*model_args, **model_kwargs)
f_dim = -1 if self.args.features == 'MS' else 0 f_dim = -1 if self.args.features == 'MS' else 0
outputs = outputs[:, -self.args.pred_len:, f_dim:] outputs = outputs[:, -self.args.pred_len:, f_dim:]
batch_y = batch_y[:, -self.args.pred_len:, f_dim:].to(self.device) batch_y = batch_y[:, -self.args.pred_len:, f_dim:].to(self.device)
@ -208,12 +223,15 @@ class Exp_Long_Term_Forecast(Exp_Basic):
# decoder input # decoder input
dec_inp = torch.zeros_like(batch_y[:, -self.args.pred_len:, :]).float() dec_inp = torch.zeros_like(batch_y[:, -self.args.pred_len:, :]).float()
dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1).float().to(self.device) dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1).float().to(self.device)
# encoder - decoder model_args = (batch_x, batch_x_mark, dec_inp, batch_y_mark)
model_kwargs = {}
if self._model_supports_training_flag:
model_kwargs['training'] = False # 测试阶段为 False
if self.args.use_amp: if self.args.use_amp:
with torch.cuda.amp.autocast(): with torch.cuda.amp.autocast():
outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark) outputs = self.model(*model_args, **model_kwargs)
else: else:
outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark) outputs = self.model(*model_args, **model_kwargs)
f_dim = -1 if self.args.features == 'MS' else 0 f_dim = -1 if self.args.features == 'MS' else 0
outputs = outputs[:, -self.args.pred_len:, :] outputs = outputs[:, -self.args.pred_len:, :]

102
generate_sine_data.py Normal file
View File

@ -0,0 +1,102 @@
import numpy as np
import pandas as pd
import os
def generate_sine_wave_data(n_samples=10000, seq_len=200, n_channels=2, save_path='./data/sine_wave/'):
"""
生成双通道正弦波时序数据
Args:
n_samples: 总样本数
seq_len: 每个序列长度
n_channels: 通道数 (固定为2)
save_path: 保存路径
"""
if not os.path.exists(save_path):
os.makedirs(save_path)
# 生成时间轴
t = np.linspace(0, 4*np.pi, seq_len)
all_data = []
for i in range(n_samples):
# 为每个样本生成不同周期和相位的正弦波
# 通道1: 随机周期和相位
freq1 = np.random.uniform(0.5, 3.0) # 频率范围
phase1 = np.random.uniform(0, 2*np.pi) # 相位
amplitude1 = np.random.uniform(0.5, 2.0) # 幅度
# 通道2: 不同的随机周期和相位
freq2 = np.random.uniform(0.3, 2.5)
phase2 = np.random.uniform(0, 2*np.pi)
amplitude2 = np.random.uniform(0.8, 1.8)
# 生成正弦波数据
channel1 = amplitude1 * np.sin(freq1 * t + phase1)
channel2 = amplitude2 * np.sin(freq2 * t + phase2)
# 添加少量噪声
# noise1 = np.random.normal(0, 0.1, seq_len)
# noise2 = np.random.normal(0, 0.1, seq_len)
#
# channel1 += noise1
# channel2 += noise2
# 组合数据: [timestamp, channel1, channel2]
timestamp = np.arange(seq_len)
sample_data = np.column_stack([timestamp, channel1, channel2])
all_data.append(sample_data)
# 转换为连续的时间序列格式
continuous_data = []
current_time = 0
for sample in all_data:
sample[:, 0] = current_time + sample[:, 0] # 调整时间戳
continuous_data.append(sample)
current_time += seq_len
# 合并所有数据
full_data = np.vstack(continuous_data)
# 创建DataFrame
df = pd.DataFrame(full_data, columns=['timestamp', 'channel1', 'channel2'])
# 按 8:1:1 比例分割训练、验证、测试集
total_len = len(df)
train_end = int(0.8 * total_len)
val_end = int(0.9 * total_len)
train_df = df[:train_end]
val_df = df[train_end:val_end]
test_df = df[val_end:]
# 保存数据
train_df.to_csv(os.path.join(save_path, 'train.csv'), index=False)
val_df.to_csv(os.path.join(save_path, 'val.csv'), index=False)
test_df.to_csv(os.path.join(save_path, 'test.csv'), index=False)
# 保存完整数据
df.to_csv(os.path.join(save_path, 'sine_wave.csv'), index=False)
print(f"数据已生成并保存到 {save_path}")
print(f"训练集: {len(train_df)} 条记录")
print(f"验证集: {len(val_df)} 条记录")
print(f"测试集: {len(test_df)} 条记录")
print(f"总计: {len(df)} 条记录,{n_channels} 个通道")
return df
if __name__ == "__main__":
# 生成数据
data = generate_sine_wave_data(
n_samples=200, # 2000个不同的正弦波样本
seq_len=200, # 每个样本200个时间点
n_channels=2, # 双通道
save_path='./data/sine_wave/'
)
print("\n数据统计信息:")
print(data.describe())

440
layers/DynamicChunking.py Normal file
View File

@ -0,0 +1,440 @@
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import repeat, rearrange
from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined
@dataclass
class RoutingModuleOutput:
# 路由模块的前向输出:
# - boundary_prob: 每个位置为“分隔点/非分隔点”的二分类概率,形状(..., 2)
# - boundary_mask: 基于最大概率得到的硬选择True=分隔点),形状与输入序列相同的前两维
# - selected_probs: 对应硬选择类别的概率,形状(..., 1)
boundary_prob: torch.Tensor
boundary_mask: torch.Tensor
selected_probs: torch.Tensor
@dataclass
class RoutingModuleState:
"""
路由模块的推理状态(用于增量/流式step
包含:
- has_seen_tokens: (batch_size,) 是否已经见过任意token用于首token强制边界
- last_hidden_state: (batch_size, d_model) 上一次的隐藏状态用于与当前token做相邻相似度
"""
has_seen_tokens: torch.Tensor # (batch_size,)
last_hidden_state: torch.Tensor # (batch_size, d_model)
@dataclass
class DeChunkState:
"""
DeChunk 的推理状态EMA的记忆值
包含:
- last_value: (batch_size, d_model) EMA反聚合的上一时刻值
"""
last_value: torch.Tensor # (batch_size, d_model)
def get_seq_idx(cu_seqlens, device=None):
seq_idx = torch.zeros(cu_seqlens[-1], dtype=torch.long, device=device)
seq_idx[cu_seqlens[:-1]] = 1
seq_idx = (torch.cumsum(seq_idx, dim=0) - 1).unsqueeze(0).int()
return seq_idx
class RoutingModule(nn.Module):
"""
路由模块:
用相邻token的余弦相似度构造“成为分隔点”的概率
p_t = clamp((1 - cos(h_{t-1}, h_t)) / 2, 0, 1)
并强制首位置为边界(概率=1
支持:
- 常规batch掩码 mask 模式
- packed序列 cu_seqlens 模式(把多序列打包成单条序列的拼接)
- 流式推理 step()(维护状态)
"""
def __init__(self, d_model, device=None, dtype=None):
self.d_model = d_model
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
# 相邻相似度计算前的线性投影(初始化为恒等)
self.q_proj_layer = nn.Linear(d_model, d_model, bias=False, **factory_kwargs)
self.k_proj_layer = nn.Linear(d_model, d_model, bias=False, **factory_kwargs)
with torch.no_grad():
self.q_proj_layer.weight.copy_(torch.eye(d_model))
self.k_proj_layer.weight.copy_(torch.eye(d_model))
# 防止外部权重再初始化
self.q_proj_layer.weight._no_reinit = True
self.k_proj_layer.weight._no_reinit = True
def allocate_inference_cache(self, batch_size, max_seqlen, device, dtype=None):
# 分配推理cache用于step
return RoutingModuleState(
has_seen_tokens=torch.zeros(batch_size, device=device, dtype=torch.bool),
last_hidden_state=torch.zeros(
batch_size, self.d_model, device=device, dtype=dtype
),
)
def forward(self, hidden_states, cu_seqlens=None, mask=None, inference_params=None):
"""
hidden_states:
- 若 cu_seqlens is None: (B, L, D)
- 若 cu_seqlens 非 None: 期望 packed 模式 (T, D),这里会临时扩维成 (1, T, D)
cu_seqlens: packed模式下每条序列的前缀和下标形如 [0, len1, len1+len2, ...]
mask: (B, L) boolTrue=有效非packed
inference_params: RoutingModuleState用于prefill时的校验与状态维护
"""
assert (mask is not None) or (
cu_seqlens is not None
), "Either mask or cu_seqlens must be provided"
if inference_params is not None:
# prefill阶段必须提供mask且不允许之前已经见过token
assert (
mask is not None
), "Mask must be provided if inference_params is provided"
assert (
~inference_params.has_seen_tokens
).all(), "Cannot have seen tokens when inference_params is not provided"
if cu_seqlens is not None:
# packed 模式:把 (T, D) 临时变为 (1, T, D)
hidden_states = hidden_states.unsqueeze(0)
# 计算相邻余弦相似度 cos(h_{t-1}, h_t)
cos_sim = torch.einsum(
"b l d, b l d -> b l",
F.normalize(self.q_proj_layer(hidden_states[:, :-1]), dim=-1),
F.normalize(self.k_proj_layer(hidden_states[:, 1:]), dim=-1),
)
# p = ((1 - cos) / 2) ∈ [0,1]
boundary_prob = torch.clamp(((1 - cos_sim) / 2), min=0.0, max=1.0)
# 强制首位置为边界:首位概率=1补充后长度和输入序列长度想等
PAD_PROB = 1.0
boundary_prob = F.pad(boundary_prob, (1, 0), "constant", PAD_PROB)
if cu_seqlens is not None:
# packed 模式下,每段序列的第一个位置是 cu_seqlens[:-1]
boundary_prob = boundary_prob.squeeze(0)
boundary_prob[cu_seqlens[:-1]] = PAD_PROB
# 组装为二分类概率 [非边界, 边界]
boundary_prob = torch.stack(((1 - boundary_prob), boundary_prob), dim=-1)
# 取最大概率类别
selected_idx = torch.argmax(boundary_prob, dim=-1)
# 硬边界掩码
boundary_mask = selected_idx == 1 # 形状与 hidden_states 的前两维一致
if mask is not None:
# 不允许选择到无效token
boundary_mask = boundary_mask & mask
if inference_params is not None:
# 维护路由状态是否见过token、最后一个有效token的隐藏状态
has_mask = mask.any(dim=-1)
inference_params.has_seen_tokens.copy_(
has_mask | inference_params.has_seen_tokens
)
last_mask = torch.clamp(mask.sum(dim=-1) - 1, min=0)
inference_params.last_hidden_state.copy_(
torch.where(
has_mask,
hidden_states[
torch.arange(
hidden_states.shape[0], device=hidden_states.device
),
last_mask,
],
inference_params.last_hidden_state,
)
)
# 取硬选择对应的概率(便于可视化/正则)
selected_probs = boundary_prob.gather(
dim=-1, index=selected_idx.unsqueeze(-1)
) # (..., 1)
return RoutingModuleOutput(
boundary_prob=boundary_prob, # (..., 2)
boundary_mask=boundary_mask, # (...)
selected_probs=selected_probs, # (..., 1)
)
def step(self, hidden_states, inference_params):
"""
流式单步:
hidden_states: (B, 1, D)
使用上一步缓存的 last_hidden_state 与当前token计算相邻相似度得到当前步的边界概率
"""
# (B, D)
hidden_states = hidden_states.squeeze(1)
cos_sim = torch.einsum(
"b d, b d -> b",
F.normalize(self.q_proj_layer(inference_params.last_hidden_state), dim=-1),
F.normalize(self.k_proj_layer(hidden_states), dim=-1),
)
boundary_prob = torch.clamp(((1 - cos_sim) / 2), min=0.0, max=1.0)
# 更新最后隐藏状态
inference_params.last_hidden_state.copy_(hidden_states)
# 首个token前强制边界
boundary_prob = torch.where(
inference_params.has_seen_tokens,
boundary_prob,
torch.ones_like(boundary_prob),
)
boundary_prob = torch.stack(((1 - boundary_prob), boundary_prob), dim=-1)
# 标记为已见token
inference_params.has_seen_tokens.copy_(
torch.ones_like(inference_params.has_seen_tokens)
)
return RoutingModuleOutput(
boundary_prob=boundary_prob, # (B, 2)
boundary_mask=boundary_prob[..., 1] > 0.5, # (B,)
selected_probs=boundary_prob.max(dim=-1).values.unsqueeze(-1), # (B, 1)
)
class ChunkLayer(nn.Module):
"""
Chunk层根据 boundary_mask 将被选中的“边界token”抽取出来形成下一层序列。
支持两种模式:
- packedcu_seqlens 非 None直接在拼接后的序列上索引
- 非packedmask 非 None通过排序 trick 把True位置排到前面并生成 next_mask
返回:
- next_hidden_states: 选中的token序列packed: shape=(#selected, D)非packed: (B, M, D)
- next_cu_seqlens: packed模式下新序列的cu_seqlens否则None
- next_max_seqlen: packed模式下选中的最大长度非packed模式返回None
- next_mask: 非packed模式下的右侧pad掩码packed模式下None
"""
def forward(self, hidden_states, boundary_mask, cu_seqlens=None, mask=None):
assert (mask is not None) or (
cu_seqlens is not None
), "Either mask or cu_seqlens must be provided"
if cu_seqlens is not None:
# packed直接选择True的行得到拼接后的 selected
next_hidden_states = hidden_states[boundary_mask]
# 新的cu_seqlens = 对每段最后一个位置(=cu_seqlens[1:]-1)累计True的计数再前置0
next_cu_seqlens = F.pad(
boundary_mask.cumsum(dim=0)[cu_seqlens[1:] - 1], (1, 0)
)
# 新序列的最大段长(仅用于内核/优化)
next_max_seqlen = int((next_cu_seqlens[1:] - next_cu_seqlens[:-1]).max())
next_mask = None
else:
# 非packed对每个batch内把True位置排到前面False放到靠后
next_cu_seqlens = None
num_tokens = boundary_mask.sum(dim=-1) # 每个样本被选中的数量
next_max_seqlen = int(num_tokens.max())
device = hidden_states.device
L = hidden_states.shape[1]
# trick用 (~boundary_mask)*L 把False加大从而 argsort 后 True 的下标排在前面
token_idx = (
torch.arange(L, device=device)[None, :] + (~boundary_mask).long() * L
)
seq_sorted_indices = torch.argsort(token_idx, dim=1)
# 收集前 next_max_seqlen 个不足的样本右侧pad
next_hidden_states = torch.gather(
hidden_states,
dim=1,
index=seq_sorted_indices[:, :next_max_seqlen, None].expand(
-1, -1, hidden_states.shape[-1]
),
)
# 下游的有效mask右侧pad无效
next_mask = (
torch.arange(next_max_seqlen, device=device)[None, :]
< num_tokens[:, None]
)
# 非packed模式下不再需要 max_seqlen返回None
next_max_seqlen = None
return next_hidden_states, next_cu_seqlens, next_max_seqlen, next_mask
def step(self, hidden_states, boundary_mask):
# 流式step仅返回当前步被选中的token用于下一层
return hidden_states[boundary_mask]
class DeChunkLayer(nn.Module):
"""
DeChunk层把“被选中的边界token序列”反聚合EMA回原始等长序列。
实现上复用 Mamba2 的 Triton 扫描核 mamba_chunk_scan_combined
- 将 d_model 切分为 nheads * headdim
- 使用参数 A=-1, b=p, c=1 的一阶状态空间/EMA形式进行前向扫描
- 最终把扫描输出根据分段索引映射回原位置plug back
支持:
- packed 模式cu_seqlens
- 非packedbatch+右侧pad
- 流式 stepEMA递推
"""
def __init__(
self,
d_model,
dtype=torch.bfloat16,
block_size=256,
headdim=32,
):
super().__init__()
self.d_model = d_model
# 仅为内核要求:使用 bfloat16块大小与头维拆分
self.dtype = dtype
self.block_size = block_size
self.headdim = headdim
assert d_model % self.headdim == 0
self.nheads = d_model // self.headdim
def allocate_inference_cache(self, batch_size, max_seqlen, device, dtype=None):
# 分配EMA的last_value缓存
return DeChunkState(
last_value=torch.zeros(
batch_size, self.d_model, device=device, dtype=dtype
),
)
def forward(
self,
hidden_states, # 被选中的token序列packed: (M, D)非packed: (B, M, D)
boundary_mask, # 原序列上的边界掩码((T,) 或 (B, L)
boundary_prob, # 原序列上的二分类概率((..., 2)
cu_seqlens=None,
inference_params=None,
mask=None,
):
"""
核心思路:
1) 从 boundary_prob 得到 p = P(boundary) ∈ (1e-4, 1-1e-4)
2) 构造 dt = log(1 / (1-p)),并对输入做缩放 x = h / dt
3) 用 mamba_chunk_scan_combined 扫描A=-1, b=p, c=1对 (B, M, H, P) 进行块扫描
4) 将结果根据 cumulative boundary index 回填到原序列位置
"""
if inference_params is not None:
# prefill时必须有mask且首token必须是边界保证EMA初始化
assert (
mask is not None
), "Mask must be provided if inference_params is provided"
assert boundary_mask[
:, 0
].all(), "First token must be a boundary if running prefill"
# 取边界概率的“边界类”概率 p并限制在(1e-4, 1-1e-4)内,避免数值不稳
p = torch.clamp(boundary_prob[..., -1].float(), min=1e-4, max=1 - (1e-4))
if cu_seqlens is not None:
# packed从原序列p中取出被选中的位置对应的概率形状(B=1, M)
p = p[boundary_mask].unsqueeze(0)
# 为triton核准备packed序列的索引映射
seq_idx = get_seq_idx(cu_seqlens, device=hidden_states.device)
else:
B, L = boundary_mask.shape
seq_idx = None
# 与ChunkLayer一致的排序 trick得到选中的顺序True在前
token_idx = (
torch.arange(L, device=hidden_states.device)[None, :]
+ (~boundary_mask).long() * L
)
seq_sorted_indices = torch.argsort(token_idx, dim=1)
# 取出与 hidden_states 对应长度的 p(B, M)
p = torch.gather(
p, dim=1, index=seq_sorted_indices[:, : hidden_states.shape[1]]
) # (B, M)
original_dtype = hidden_states.dtype
# 构造 EMA 扫描所需变量
dt = torch.log(1 / (1 - p)).to(self.dtype) # (B, M)
x = (hidden_states / dt[..., None]).to(self.dtype) # (B, M, D) / (B, M, 1)
# A, b, c 分别对应一阶状态空间/EMA的参数
A = -torch.ones(
(self.nheads,), device=hidden_states.device, dtype=torch.float32
)
b = p.to(self.dtype)
c = torch.ones_like(b)
# 调用triton核进行块扫描
out = mamba_chunk_scan_combined(
rearrange(x, "b l (h p) -> b l h p", p=self.headdim), # (B, M, H, P)
repeat(dt, "b l -> b l h", h=self.nheads), # (B, M, H)
A, # (H,)
rearrange(b, "b l -> b l 1 1"), # (B, M, 1, 1)
rearrange(c, "b l -> b l 1 1"), # (B, M, 1, 1)
chunk_size=self.block_size,
seq_idx=seq_idx, # packed时提供
)
out = rearrange(out, "b l h p -> b l (h p)") # (B, M, D)
# 将扫描结果回填plug back到原序列位置
if cu_seqlens is not None:
out = out.squeeze(0) # (M, D)
plug_back_idx = boundary_mask.cumsum(dim=0) - 1 # (T,)
out = torch.gather(
out, dim=0, index=plug_back_idx.unsqueeze(-1).expand(-1, self.d_model)
) # (T, D)
else:
plug_back_idx = torch.cumsum(boundary_mask, dim=1) - 1 # (B, L)
out = torch.gather(
out,
dim=1,
index=plug_back_idx.unsqueeze(-1).expand(-1, -1, self.d_model),
) # (B, L, D)
# 更新流式缓存
if inference_params is not None:
inference_params.last_value.copy_(out[:, -1])
return out.to(original_dtype)
def step(self, hidden_states, boundary_mask, boundary_prob, inference_params):
"""
流式单步 EMA 反聚合:
hidden_states: (B', 1, D),其中 B' = 当前步被选中的数量boundary_mask.sum()
boundary_mask: (B,) 当前batch哪些位置被选中为边界
boundary_prob: (B, 2) 当前batch各位置的边界概率
输出:(B, 1, D),对应对所有位置做了一步 EMA 更新后的值
"""
B = boundary_mask.shape[0]
D = hidden_states.shape[-1]
# 构造当前步每个位置的 p未被选中的位置 p=0
p = torch.zeros(B, device=hidden_states.device, dtype=hidden_states.dtype)
p[boundary_mask] = boundary_prob[boundary_mask, -1].clamp(
min=1e-4, max=1 - (1e-4)
)
# 构造当前被选中的隐藏状态未选中为0
current_hidden_states = torch.zeros(
B, D, device=hidden_states.device, dtype=hidden_states.dtype
)
current_hidden_states[boundary_mask] = hidden_states.squeeze(1)
# EMAresult = p * x + (1 - p) * last
result = p * current_hidden_states + (1 - p) * inference_params.last_value
inference_params.last_value.copy_(result)
return result.unsqueeze(1)

View File

@ -1,23 +1,94 @@
import math
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import math
class HardConcreteGate(nn.Module):
"""
Hard-Concrete gate for L0-style sparsity (Louizos et al., 2017).
Produces z in [0,1] without row-wise normalization.
"""
def __init__(self, shape, temperature=2./3., gamma=-0.1, zeta=1.1, init_log_alpha=-2.0):
super().__init__()
self.log_alpha = nn.Parameter(torch.full(shape, init_log_alpha))
self.temperature = temperature
self.gamma = gamma
self.zeta = zeta
def sample(self, training=True):
if training:
u = torch.rand_like(self.log_alpha)
s = torch.sigmoid((self.log_alpha + torch.log(u) - torch.log(1 - u)) / self.temperature)
else:
# deterministic mean gate at eval
s = torch.sigmoid(self.log_alpha)
s_bar = s * (self.zeta - self.gamma) + self.gamma
z = torch.clamp(s_bar, 0., 1.)
return z
def expected_l0(self):
"""
E[1_{z>0}] closed-form for hard-concrete.
Useful for L0 penalty: lambda * expected_l0.sum()
"""
# s > t0 => z > 0, where t0 = -gamma / (zeta - gamma)
t0 = -self.gamma / (self.zeta - self.gamma)
# logit(t0)
logit_t0 = math.log(t0) - math.log(1 - t0)
# P(x > logit_t0) with x ~ Logistic(loc=log_alpha, scale=temperature)
p_open = torch.sigmoid((self.log_alpha - logit_t0) / self.temperature)
return p_open
class HierarchicalGraphMixer(nn.Module): class HierarchicalGraphMixer(nn.Module):
""" """
分层图混合器,同时考虑宏观通道关系和微观 Patch 级别注意力。 使用 Hard-Concrete 边门控的分层图混合器:
输入 z : 形状为 [B, C, N, D] 的张量 - Level 1: 非归一化、可阈值、可为空的通道图
输出 z_out : 形状同输入 - Level 2: 仅在被选中的边上做 Patch 级别交叉注意力
输入: z [B, C, N, D]
输出: z_out 同形状
""" """
def __init__(self, n_channel: int, dim: int, k: int = 5, tau: float = 0.2): def __init__(
self,
n_channel: int,
dim: int,
max_degree: int = None, # 可选:限制每行最多边数
thr: float = 0.5, # 保留边阈值,例如 0.5/0.7
thr_min: float = None, # 动态阈值起点,不传则用 thr
thr_max: float = None, # 动态阈值终点,不传则用 thr
thr_steps: int = 0, # 从 thr_min -> thr_max 的步数,>0 时启用动态调度
thr_schedule: str = "linear", # "linear" | "cosine" | "exp"
temperature: float = 2./3.,
tau_attn: float = 1.0, # Patch attention 温度(可选)
symmetric: bool = True, # 是否对称化通道图
degree_rescale: str = "none", # "none" | "count" | "count-sqrt" | "sum"
init_log_alpha: float = -2.0
):
super().__init__() super().__init__()
self.k = k self.C = n_channel
self.tau = tau self.dim = dim
self.max_degree = max_degree
self.thr = thr
self.tau_attn = tau_attn
self.symmetric = symmetric
self.degree_rescale = degree_rescale
self.thr_min = thr if (thr_min is None) else float(thr_min)
self.thr_max = thr if (thr_max is None) else float(thr_max)
self.thr_steps = int(thr_steps) if thr_steps is not None else 0
self.thr_schedule = thr_schedule
self._use_dynamic_thr = (self.thr_steps > 0) and (abs(self.thr_max - self.thr_min) > 1e-12)
# 用 buffer 记录已步进次数(不保存到权重里)
self.register_buffer("_thr_step", torch.zeros((), dtype=torch.long), persistent=False)
# Level 1: Channel Graph # Level 1: 非归一化门控
self.A = nn.Parameter(torch.zeros(n_channel, n_channel)) self.gate = HardConcreteGate(
shape=(n_channel, n_channel),
temperature=temperature,
init_log_alpha=init_log_alpha
)
# 可选 SE你原来的 se 可以用来生成样本相关的通道优先级,但这里先保留接口)
self.se = nn.Sequential( self.se = nn.Sequential(
nn.Linear(dim, dim // 4, bias=False), nn.ReLU(), nn.Linear(dim, dim // 4, bias=False), nn.SiLU(),
nn.Linear(dim // 4, 1, bias=False), nn.Sigmoid() nn.Linear(dim // 4, 1, bias=False), nn.Sigmoid()
) )
@ -28,56 +99,133 @@ class HierarchicalGraphMixer(nn.Module):
self.out_proj = nn.Linear(dim, dim) self.out_proj = nn.Linear(dim, dim)
self.norm = nn.LayerNorm(dim) self.norm = nn.LayerNorm(dim)
def _row_sparse(self, logits: torch.Tensor) -> torch.Tensor: def _compute_thr_by_progress(self, progress: float) -> float:
"""Gumbel-Softmax based sparse attention""" # progress in [0,1]
g = -torch.empty_like(logits).exponential_().log() progress = max(0.0, min(1.0, float(progress)))
y = (logits + g) / self.tau if self.thr_schedule == "linear":
probs = F.softmax(y, dim=-1) g = progress
elif self.thr_schedule == "cosine":
# 慢起步,后期加速
import math
g = 0.5 - 0.5 * math.cos(math.pi * progress)
elif self.thr_schedule == "exp":
# 更快从 thr_min 过渡到 thr_max指数式
import math
k = 5.0
g = (math.exp(k * progress) - 1.0) / (math.exp(k) - 1.0)
else:
g = progress
return self.thr_min + (self.thr_max - self.thr_min) * g
def _maybe_update_thr(self, training):
if training and self._use_dynamic_thr:
step = int(self._thr_step.item())
progress = step / float(self.thr_steps)
self.thr = float(self._compute_thr_by_progress(progress))
self._thr_step += 1
# Ensure k doesn't exceed the dimension size def _build_sparse_neighbors(self, z_gate):
k_actual = min(self.k, probs.size(-1)) """
if k_actual <= 0: 基于 z_gate 构造每行的邻接列表按阈值与可选top-k
return torch.zeros_like(probs) 返回:
- idx_list: 长度C的list每项是LongTensor[idx_j]
- w_list: 长度C的list每项是FloatTensor[w_j](非归一化)
"""
C = z_gate.size(0)
# 去对角
z_gate = z_gate.clone()
z_gate.fill_diagonal_(0.0)
topk_val, _ = torch.topk(probs, k_actual, dim=-1) if self.symmetric:
thr = topk_val[..., -1].unsqueeze(-1) z_gate = 0.5 * (z_gate + z_gate.t())
sparse = torch.where(probs >= thr, probs, torch.zeros_like(probs)) z_gate.fill_diagonal_(0.0)
return sparse.detach() + probs - probs.detach()
def forward(self, z): idx_list, w_list = [], []
# z 的形状: [B, C, N, D] for i in range(C):
row = z_gate[i] # [C]
# 阈值筛选
mask = row > self.thr
if mask.any():
vals = row[mask]
idxs = torch.nonzero(mask, as_tuple=False).squeeze(-1)
# 可选最多度数限制
if (self.max_degree is not None) and (idxs.numel() > self.max_degree):
topk = torch.topk(vals, k=self.max_degree, dim=0)
vals = topk.values
idxs = idxs[topk.indices]
else:
idxs = torch.empty((0,), dtype=torch.long, device=row.device)
vals = torch.empty((0,), dtype=row.dtype, device=row.device)
idx_list.append(idxs)
w_list.append(vals)
return idx_list, w_list
def _degree_rescale(self, ctx, w_sel):
"""
非归一化聚合的稳定性处理。可选对聚合值做degree归一化以稳定数值。
ctx: [B, k, N, D]
w_sel: [k]
"""
if self.degree_rescale == "none":
return (ctx * w_sel.view(1, -1, 1, 1)).sum(dim=1)
elif self.degree_rescale == "count":
k = max(1, w_sel.numel())
return (ctx * w_sel.view(1, -1, 1, 1)).sum(dim=1) / float(k)
elif self.degree_rescale == "count-sqrt":
k = max(1, w_sel.numel())
return (ctx * w_sel.view(1, -1, 1, 1)).sum(dim=1) / math.sqrt(k)
elif self.degree_rescale == "sum":
s = float(w_sel.sum().clamp(min=1e-6))
return (ctx * w_sel.view(1, -1, 1, 1)).sum(dim=1) / s
else:
return (ctx * w_sel.view(1, -1, 1, 1)).sum(dim=1)
def l0_loss(self, lam: float = 1e-4):
"""
期望L0正则鼓励稀疏邻接可调强度
"""
return lam * self.gate.expected_l0().sum()
def forward(self, z, is_training):
self._maybe_update_thr(training=is_training)
# z: [B, C, N, D]
B, C, N, D = z.shape B, C, N, D = z.shape
assert C == self.C and D == self.dim
# --- Level 1: 计算宏观权重 --- # Level 1: 采样非归一化门 z_gate ∈ [0,1]
A_sparse = self._row_sparse(self.A) # 通道连接稀疏图 A_sparse: [C, C] z_gate = self.gate.sample(training=is_training) # [C, C]
# --- Level 2: 跨通道 Patch 交互 --- # 构建稀疏邻居(阈值 + 可选 top-k
idx_list, w_list = self._build_sparse_neighbors(z_gate)
# Level 2: 仅对被保留的边做跨通道 Patch 交互
out_z = torch.zeros_like(z) out_z = torch.zeros_like(z)
for i in range(C): # 遍历每个目标通道 i
for i in range(C):
target_z = z[:, i, :, :] # [B, N, D] target_z = z[:, i, :, :] # [B, N, D]
idx = idx_list[i]
if idx.numel() == 0:
# 空邻域:允许“没有相关通道”,仅残差/归一化
out_z[:, i, :, :] = self.norm(target_z)
continue
# 准备聚合来自其他通道的 patch 级别上下文 w_sel = w_list[i] # [k], 非归一化权重,范围[0,1]
aggregated_context = torch.zeros_like(target_z) k_i = idx.numel()
for j in range(C): # 遍历每个源通道 j source_z = z[:, idx, :, :] # [B, k, N, D]
if A_sparse[i, j] != 0:
source_z = z[:, j, :, :] # [B, N, D]
# --- 执行交叉注意力 --- Q = self.q_proj(target_z) # [B, N, D]
Q = self.q_proj(target_z) # Query 来自目标通道 i K = self.k_proj(source_z.reshape(B * k_i, N, D)).reshape(B, k_i, N, D)
K = self.k_proj(source_z) # Key 来自源通道 j V = self.v_proj(source_z.reshape(B * k_i, N, D)).reshape(B, k_i, N, D)
V = self.v_proj(source_z) # Value 来自源通道 j
attn_scores = torch.bmm(Q, K.transpose(1, 2)) / math.sqrt(D) # 跨通道 patch 注意力
attn_probs = F.softmax(attn_scores, dim=-1) # [B, N, N] attn_scores = torch.einsum('bnd,bkmd->bknm', Q, K) / math.sqrt(D)
if self.tau_attn != 1.0:
attn_scores = attn_scores / self.tau_attn
attn_probs = F.softmax(attn_scores, dim=-1) # [B, k, N, N]
context = torch.einsum('bknm,bkmd->bknd', attn_probs, V) # [B, k, N, D]
context = torch.bmm(attn_probs, V) # [B, N, D], 从 j 聚合到 i 的上下文 # 非归一化通道权重聚合 + 可选度归一化(仅数值稳定,不改变“非归一化”的语义)
aggregated_context = self._degree_rescale(context, w_sel) # [B, N, D]
# 加权上下文
weighted_context = A_sparse[i, j] * context
aggregated_context = aggregated_context + weighted_context
# 将聚合后的上下文通过输出层,并与原始目标表示相加(残差连接)
out_z[:, i, :, :] = self.norm(target_z + self.out_proj(aggregated_context)) out_z[:, i, :, :] = self.norm(target_z + self.out_proj(aggregated_context))
return out_z return out_z

69
layers/MambaSeries.py Normal file
View File

@ -0,0 +1,69 @@
import torch
import torch.nn as nn
from mamba_ssm import Mamba2
class Mamba2Encoder(nn.Module):
"""
使用 Mamba2 对 patch 维度进行序列建模:
输入: [bs, nvars, patch_num, patch_len]
映射: patch_len -> d_model
建模: 在 patch_num 维度上用 Mamba2可堆叠多层层间加残差
输出: [bs, nvars, d_model] (仅返回 Mamba 输出的最后一个时间步)
"""
def __init__(
self,
c_in,
patch_num,
patch_len,
d_model=128,
# Mamba2 超参
d_state=64,
d_conv=4,
expand=2,
headdim=64,
# 堆叠层数
n_layers=1,
):
super().__init__()
self.patch_num = patch_num
self.patch_len = patch_len
self.d_model = d_model
self.n_layers = n_layers
# 将 patch_len 投影到 d_model
self.W_P = nn.Linear(patch_len, d_model) # 映射 patch_len -> d_model
# 堆叠 n_layers 层 Mamba2
self.mambas = nn.ModuleList([
Mamba2(
d_model=d_model,
d_state=d_state,
d_conv=d_conv,
expand=expand,
headdim=headdim,
)
for _ in range(n_layers)
])
def forward(self, x):
# x: [bs, nvars, patch_num, patch_len]
bs, n_vars, patch_num, patch_len = x.shape # bs, n_vars, patch_num, patch_len
# 1) 线性映射: patch_len -> d_model
x = self.W_P(x) # x: [bs, nvars, patch_num, d_model]
# 2) 合并 batch 与通道维度,作为 Mamba 的 batch
u = x.reshape(bs * n_vars, patch_num, self.d_model) # u: [bs*nvars, patch_num, d_model]
# 3) 通过 n_layers 层 Mamba2 进行建模(在 patch_num 维度上),并加残差连接
for m in self.mambas:
u = u + m(u) # 残差连接,形状保持 [bs*nvars, patch_num, d_model]
# 4) 仅取最后一个时间步
y_last = u[:, -1, :] # y_last: [bs*nvars, d_model]
# 5) 还原回 (bs, nvars, d_model)
y_last = y_last.view(bs, n_vars, self.d_model) # y_last: [bs, nvars, d_model]
return y_last # [bs, nvars, d_model]

View File

@ -1,11 +1,15 @@
""" """
SeasonPatch = PatchTST (CI) + ChannelGraphMixer + Linear prediction head SeasonPatch = PatchTST (CI) + ChannelGraphMixer + Linear prediction head
Adapted for Time-Series-Library-main style 支持两种编码器:
- Transformer 编码器路径PatchTST + GraphMixer + Head
- Mamba2 编码器路径Mamba2Encoder不使用mixer直接用最后得到的 d_model 走 Head
""" """
import torch import torch
import torch.nn as nn import torch.nn as nn
from layers.TSTEncoder import TSTiEncoder from layers.TSTEncoder import TSTiEncoder
from layers.GraphMixer import HierarchicalGraphMixer from layers.GraphMixer import HierarchicalGraphMixer
from layers.MambaSeries import Mamba2Encoder
class SeasonPatch(nn.Module): class SeasonPatch(nn.Module):
def __init__(self, def __init__(self,
@ -15,53 +19,118 @@ class SeasonPatch(nn.Module):
patch_len: int, patch_len: int,
stride: int, stride: int,
k_graph: int = 8, k_graph: int = 8,
encoder_type: str = "Transformer",
d_model: int = 128, d_model: int = 128,
n_layers: int = 3, n_layers: int = 3,
n_heads: int = 16): n_heads: int = 16,
# Mamba2 相关可选超参
d_state: int = 64,
d_conv: int = 4,
expand: int = 2,
headdim: int = 64,
# Mixergraph 可选超参数
thr_graph: float = 0.5,
thr_graph_min: float = None,
thr_graph_max: float = None,
thr_graph_steps: int = 0,
thr_graph_schedule: str = "linear",
symmetric_graph: bool = True,
degree_rescale: str = "count-sqrt", # "none" | "count" | "count-sqrt" | "sum"
gate_temperature: float = 2./3.,
tau_attn: float = 1.0,
l0_lambda: float = 1e-4):
super().__init__() super().__init__()
# ===== 新增:保存 l0_lambda防止 reg_loss 访问报错 =====
self.l0_lambda = l0_lambda
# Store patch parameters # Store patch parameters
self.patch_len = patch_len self.patch_len = patch_len # patch 长度
self.stride = stride self.stride = stride # patch 步幅
# Calculate patch number # Calculate patch number
patch_num = (seq_len - patch_len) // stride + 1 patch_num = (seq_len - patch_len) // stride + 1 # patch_num: int
# PatchTST encoder (channel independent) self.encoder_type = encoder_type
# 只初始化需要的encoder
if encoder_type == "Transformer":
# Transformer (PatchTST) 编码器channel independent
self.encoder = TSTiEncoder( self.encoder = TSTiEncoder(
c_in=c_in, c_in=c_in, patch_num=patch_num, patch_len=patch_len,
patch_num=patch_num, d_model=d_model, n_layers=n_layers, n_heads=n_heads
patch_len=patch_len,
d_model=d_model,
n_layers=n_layers,
n_heads=n_heads
) )
# 集成新 HierarchicalGraphMixer非归一化
self.mixer = HierarchicalGraphMixer(
n_channel=c_in,
dim=d_model,
max_degree=k_graph,
thr=thr_graph,
thr_min=thr_graph_min,
thr_max=thr_graph_max,
thr_steps=thr_graph_steps,
thr_schedule=thr_graph_schedule,
temperature=gate_temperature,
tau_attn=tau_attn,
symmetric=symmetric_graph,
degree_rescale=degree_rescale
)
# Prediction headTransformer 路径用到,输入维度为 patch_num * d_model
self.head = nn.Sequential(
nn.Linear(patch_num * d_model, patch_num * d_model),
nn.SiLU(),
nn.Linear(patch_num * d_model, pred_len)
)
elif encoder_type == "Mamba2":
# Mamba2 编码器channel independent返回 [B, C, d_model]
self.encoder = Mamba2Encoder(
c_in=c_in, patch_num=patch_num, patch_len=patch_len,
d_model=d_model, d_state=d_state, d_conv=d_conv,
expand=expand, headdim=headdim, n_layers=n_layers
)
# Prediction headMamba2 路径用到,输入维度为 d_model
self.head = nn.Sequential(
nn.Linear(d_model, d_model),
nn.SiLU(),
nn.Linear(d_model, pred_len)
)
else:
raise ValueError(f"Unsupported encoder_type: {encoder_type}")
# Cross-channel mixer def forward(self, x, training):
self.mixer = HierarchicalGraphMixer(c_in, dim=d_model, k=k_graph)
# Prediction head
self.head = nn.Linear(patch_num * d_model, pred_len)
def forward(self, x):
# x: [B, L, C] # x: [B, L, C]
x = x.permute(0, 2, 1) # [B, C, L] x = x.permute(0, 2, 1) # x: [B, C, L]
# Patch the input # Patch the input
x_patch = x.unfold(-1, self.patch_len, self.stride) # [B, C, patch_num, patch_len] x_patch = x.unfold(-1, self.patch_len, self.stride) # x_patch: [B, C, patch_num, patch_len]
# Encode patches if self.encoder_type == "Transformer":
z = self.encoder(x_patch) # [B, C, d_model, patch_num] # Encode patches (PatchTST)
z = self.encoder(x_patch) # z: [B, C, d_model, patch_num]
# z: [B, C, d_model, patch_num] [B, C, patch_num, d_model] # [B, C, d_model, patch_num] -> [B, C, patch_num, d_model]
B, C, D, N = z.shape B, C, D, N = z.shape # B: batch, C: channels, D: d_model, N: patch_num
z = z.permute(0, 1, 3, 2) # [B, C, patch_num, d_model] z = z.permute(0, 1, 3, 2) # z: [B, C, patch_num, d_model]
# Cross-channel mixing # Cross-channel mixing
z_mix = self.mixer(z) # [B, C, patch_num, d_model] z_mix = self.mixer(z, training) # z_mix: [B, C, patch_num, d_model]
# Flatten and predict # Flatten and predict
z_mix = z_mix.view(B, C, N * D) # [B, C, patch_num * d_model] z_mix = z_mix.view(B, C, N * D) # z_mix: [B, C, patch_num * d_model]
y_pred = self.head(z_mix) # [B, C, pred_len] y_pred = self.head(z_mix) # y_pred: [B, C, pred_len]
return y_pred elif self.encoder_type == "Mamba2":
# 使用 Mamba2 编码器(不使用 mixer
z_last = self.encoder(x_patch) # z_last: [B, C, d_model](仅最后一个时间步)
y_pred = self.head(z_last) # y_pred: [B, C, pred_len]
return y_pred # [B, C, pred_len]
def reg_loss(self):
"""
可选:把 L0 正则暴露出去训练时加到总loss。
"""
if self.encoder_type == "Transformer" and hasattr(self, "mixer"):
return self.mixer.l0_loss(self.l0_lambda)
return torch.tensor(0.0, device=self.head[0].weight.device)

View File

@ -17,25 +17,30 @@ class DSAttention(nn.Module):
self.output_attention = output_attention self.output_attention = output_attention
self.dropout = nn.Dropout(attention_dropout) self.dropout = nn.Dropout(attention_dropout)
def forward(self, queries, keys, values, attn_mask, tau=None, delta=None): def forward(self, queries, keys, values, attn_mask, tau=None, delta=None, key_padding_mask=None):
"""
key_padding_mask: (B, S) bool, True=valid, False=pad可选忽略或由上层应用
"""
B, L, H, E = queries.shape B, L, H, E = queries.shape
_, S, _, D = values.shape _, S, _, D = values.shape
scale = self.scale or 1. / sqrt(E) scale = self.scale or 1. / sqrt(E)
tau = 1.0 if tau is None else tau.unsqueeze( tau = 1.0 if tau is None else tau.unsqueeze(1).unsqueeze(1) # B x 1 x 1 x 1
1).unsqueeze(1) # B x 1 x 1 x 1 delta = 0.0 if delta is None else delta.unsqueeze(1).unsqueeze(1) # B x 1 x 1 x S
delta = 0.0 if delta is None else delta.unsqueeze(
1).unsqueeze(1) # B x 1 x 1 x S
# De-stationary Attention, rescaling pre-softmax score with learned de-stationary factors scores = torch.einsum("blhe,bshe->bhls", queries, keys) * tau + delta # (B,H,L,S)
scores = torch.einsum("blhe,bshe->bhls", queries, keys) * tau + delta
if self.mask_flag: if self.mask_flag:
if attn_mask is None: if attn_mask is None:
attn_mask = TriangularCausalMask(B, L, device=queries.device) attn_mask = TriangularCausalMask(B, L, device=queries.device)
scores.masked_fill_(attn_mask.mask, -np.inf) scores.masked_fill_(attn_mask.mask, -np.inf)
# 可选基于key_padding_mask的无效键屏蔽不改变原行为默认None
if key_padding_mask is not None:
# key_padding_mask: True 表示有效False为padding
invalid_k = (~key_padding_mask).unsqueeze(1).unsqueeze(1) # (B,1,1,S)
scores = scores.masked_fill(invalid_k, -np.inf)
A = self.dropout(torch.softmax(scale * scores, dim=-1)) A = self.dropout(torch.softmax(scale * scores, dim=-1))
V = torch.einsum("bhls,bshd->blhd", A, values) V = torch.einsum("bhls,bshd->blhd", A, values)
@ -46,6 +51,12 @@ class DSAttention(nn.Module):
class FullAttention(nn.Module): class FullAttention(nn.Module):
"""
修正点:
- 新增 key_padding_mask 支持用于屏蔽批内右侧pad的键向量与DC变长对齐
- key_padding_mask 约定shape=(B, S)True=有效False=padding
- 其余行为与原实现保持一致
"""
def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False): def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False):
super(FullAttention, self).__init__() super(FullAttention, self).__init__()
self.scale = scale self.scale = scale
@ -53,21 +64,33 @@ class FullAttention(nn.Module):
self.output_attention = output_attention self.output_attention = output_attention
self.dropout = nn.Dropout(attention_dropout) self.dropout = nn.Dropout(attention_dropout)
def forward(self, queries, keys, values, attn_mask, tau=None, delta=None): def forward(self, queries, keys, values, attn_mask, tau=None, delta=None, key_padding_mask=None):
"""
queries: (B, L, H, E)
keys: (B, S, H, E)
values: (B, S, H, D)
attn_mask: TriangularCausalMask 或 None
key_padding_mask: (B, S) boolTrue=有效False=padding可选
"""
B, L, H, E = queries.shape B, L, H, E = queries.shape
_, S, _, D = values.shape _, S, _, D = values.shape
scale = self.scale or 1. / sqrt(E) scale = self.scale or 1. / sqrt(E)
scores = torch.einsum("blhe,bshe->bhls", queries, keys) scores = torch.einsum("blhe,bshe->bhls", queries, keys) # (B,H,L,S)
if self.mask_flag: if self.mask_flag:
if attn_mask is None: if attn_mask is None:
attn_mask = TriangularCausalMask(B, L, device=queries.device) attn_mask = TriangularCausalMask(B, L, device=queries.device)
scores.masked_fill_(attn_mask.mask, -np.inf) scores.masked_fill_(attn_mask.mask, -np.inf)
A = self.dropout(torch.softmax(scale * scores, dim=-1)) # 基于key_padding_mask屏蔽无效键padding位置不参与注意力
V = torch.einsum("bhls,bshd->blhd", A, values) if key_padding_mask is not None:
# key_padding_mask: True=有效False=padding
invalid_k = (~key_padding_mask).unsqueeze(1).unsqueeze(1) # (B,1,1,S)
scores = scores.masked_fill(invalid_k, -np.inf)
A = self.dropout(torch.softmax(scale * scores, dim=-1)) # (B,H,L,S)
V = torch.einsum("bhls,bshd->blhd", A, values) # (B,L,H,D)
if self.output_attention: if self.output_attention:
return V.contiguous(), A return V.contiguous(), A
@ -85,100 +108,86 @@ class ProbAttention(nn.Module):
self.dropout = nn.Dropout(attention_dropout) self.dropout = nn.Dropout(attention_dropout)
def _prob_QK(self, Q, K, sample_k, n_top): # n_top: c*ln(L_q) def _prob_QK(self, Q, K, sample_k, n_top): # n_top: c*ln(L_q)
# Q [B, H, L, D] # Q [B, H, L_q, D], K [B, H, L_k, D]
B, H, L_K, E = K.shape B, H, L_K, E = K.shape
_, _, L_Q, _ = Q.shape _, _, L_Q, _ = Q.shape
# calculate the sampled Q_K
K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E) K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E)
# real U = U_part(factor*ln(L_k))*L_q index_sample = torch.randint(L_K, (L_Q, sample_k), device=Q.device)
index_sample = torch.randint(L_K, (L_Q, sample_k)) K_sample = K_expand[:, :, torch.arange(L_Q, device=Q.device).unsqueeze(1), index_sample, :]
K_sample = K_expand[:, :, torch.arange( Q_K_sample = torch.matmul(Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze(-2) # (B,H,L_Q,sample_k)
L_Q).unsqueeze(1), index_sample, :]
Q_K_sample = torch.matmul(
Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze()
# find the Top_k query with sparisty measurement M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K) # (B,H,L_Q)
M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K) M_top = M.topk(n_top, sorted=False)[1] # indices
M_top = M.topk(n_top, sorted=False)[1]
# use the reduced Q to calculate Q_K
Q_reduce = Q[torch.arange(B)[:, None, None], Q_reduce = Q[torch.arange(B)[:, None, None],
torch.arange(H)[None, :, None], torch.arange(H)[None, :, None],
M_top, :] # factor*ln(L_q) M_top, :] # (B,H,n_top,D)
Q_K = torch.matmul(Q_reduce, K.transpose(-2, -1)) # factor*ln(L_q)*L_k Q_K = torch.matmul(Q_reduce, K.transpose(-2, -1)) # (B,H,n_top,L_K)
return Q_K, M_top return Q_K, M_top
def _get_initial_context(self, V, L_Q): def _get_initial_context(self, V, L_Q):
B, H, L_V, D = V.shape B, H, L_V, D = V.shape
if not self.mask_flag: if not self.mask_flag:
# V_sum = V.sum(dim=-2) V_mean = V.mean(dim=-2) # (B,H,D)
V_sum = V.mean(dim=-2) context = V_mean.unsqueeze(-2).expand(B, H, L_Q, D).clone()
contex = V_sum.unsqueeze(-2).expand(B, H, else:
L_Q, V_sum.shape[-1]).clone() assert L_Q == L_V
else: # use mask context = V.cumsum(dim=-2)
# requires that L_Q == L_V, i.e. for self-attention only return context
assert (L_Q == L_V)
contex = V.cumsum(dim=-2)
return contex
def _update_context(self, context_in, V, scores, index, L_Q, attn_mask): def _update_context(self, context_in, V, scores, index, L_Q, attn_mask):
B, H, L_V, D = V.shape B, H, L_V, D = V.shape
if self.mask_flag: if self.mask_flag:
attn_mask = ProbMask(B, H, L_Q, index, scores, device=V.device) attn_mask = ProbMask(B, H, L_Q, index, scores, device=V.device)
scores.masked_fill_(attn_mask.mask, -np.inf) scores.masked_fill_(attn_mask.mask, -np.inf)
attn = torch.softmax(scores, dim=-1)
attn = torch.softmax(scores, dim=-1) # nn.Softmax(dim=-1)(scores)
context_in[torch.arange(B)[:, None, None], context_in[torch.arange(B)[:, None, None],
torch.arange(H)[None, :, None], torch.arange(H)[None, :, None],
index, :] = torch.matmul(attn, V).type_as(context_in) index, :] = torch.matmul(attn, V).type_as(context_in)
if self.output_attention: if self.output_attention:
attns = (torch.ones([B, H, L_V, L_V]) / attns = (torch.ones([B, H, L_V, L_V], device=attn.device, dtype=attn.dtype) / L_V)
L_V).type_as(attn).to(attn.device) attns[torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :] = attn
attns[torch.arange(B)[:, None, None], torch.arange(H)[
None, :, None], index, :] = attn
return context_in, attns return context_in, attns
else: else:
return context_in, None return context_in, None
def forward(self, queries, keys, values, attn_mask, tau=None, delta=None): def forward(self, queries, keys, values, attn_mask, tau=None, delta=None, key_padding_mask=None):
"""
key_padding_mask 目前未集成到 ProbAttention如需可在scores处对无效键置 -inf
"""
B, L_Q, H, D = queries.shape B, L_Q, H, D = queries.shape
_, L_K, _, _ = keys.shape _, L_K, _, _ = keys.shape
queries = queries.transpose(2, 1) queries = queries.transpose(2, 1) # (B,H,L_Q,D)
keys = keys.transpose(2, 1) keys = keys.transpose(2, 1) # (B,H,L_K,D)
values = values.transpose(2, 1) values = values.transpose(2, 1) # (B,H,L_K,D)
U_part = self.factor * \ U_part = self.factor * int(np.ceil(np.log(L_K)))
np.ceil(np.log(L_K)).astype('int').item() # c*ln(L_k) u = self.factor * int(np.ceil(np.log(L_Q)))
u = self.factor * \
np.ceil(np.log(L_Q)).astype('int').item() # c*ln(L_q)
U_part = U_part if U_part < L_K else L_K U_part = min(U_part, L_K)
u = u if u < L_Q else L_Q u = min(u, L_Q)
scores_top, index = self._prob_QK( scores_top, index = self._prob_QK(queries, keys, sample_k=U_part, n_top=u)
queries, keys, sample_k=U_part, n_top=u)
# add scale factor
scale = self.scale or 1. / sqrt(D) scale = self.scale or 1. / sqrt(D)
if scale is not None:
scores_top = scores_top * scale scores_top = scores_top * scale
# get the context
context = self._get_initial_context(values, L_Q) context = self._get_initial_context(values, L_Q)
# update the context with selected top_k queries context, attn = self._update_context(context, values, scores_top, index, L_Q, attn_mask)
context, attn = self._update_context(
context, values, scores_top, index, L_Q, attn_mask)
return context.contiguous(), attn return context.contiguous(), attn
class AttentionLayer(nn.Module): class AttentionLayer(nn.Module):
def __init__(self, attention, d_model, n_heads, d_keys=None, """
d_values=None): 修正点:
- forward 新增 key_padding_mask 参数,并向 inner_attention 透传
- 保持与旧调用兼容不传时默认None
"""
def __init__(self, attention, d_model, n_heads, d_keys=None, d_values=None):
super(AttentionLayer, self).__init__() super(AttentionLayer, self).__init__()
d_keys = d_keys or (d_model // n_heads) d_keys = d_keys or (d_model // n_heads)
@ -191,7 +200,10 @@ class AttentionLayer(nn.Module):
self.out_projection = nn.Linear(d_values * n_heads, d_model) self.out_projection = nn.Linear(d_values * n_heads, d_model)
self.n_heads = n_heads self.n_heads = n_heads
def forward(self, queries, keys, values, attn_mask, tau=None, delta=None): def forward(self, queries, keys, values, attn_mask, tau=None, delta=None, key_padding_mask=None):
"""
key_padding_mask: (B, S) bool, True=有效False=padding
"""
B, L, _ = queries.shape B, L, _ = queries.shape
_, S, _ = keys.shape _, S, _ = keys.shape
H = self.n_heads H = self.n_heads
@ -206,10 +218,10 @@ class AttentionLayer(nn.Module):
values, values,
attn_mask, attn_mask,
tau=tau, tau=tau,
delta=delta delta=delta,
key_padding_mask=key_padding_mask,
) )
out = out.view(B, L, -1) out = out.view(B, L, -1)
return self.out_projection(out), attn return self.out_projection(out), attn
@ -232,12 +244,11 @@ class ReformerLayer(nn.Module):
if N % (self.bucket_size * 2) == 0: if N % (self.bucket_size * 2) == 0:
return queries return queries
else: else:
# fill the time series
fill_len = (self.bucket_size * 2) - (N % (self.bucket_size * 2)) fill_len = (self.bucket_size * 2) - (N % (self.bucket_size * 2))
return torch.cat([queries, torch.zeros([B, fill_len, C]).to(queries.device)], dim=1) return torch.cat([queries, torch.zeros([B, fill_len, C]).to(queries.device)], dim=1)
def forward(self, queries, keys, values, attn_mask, tau, delta): def forward(self, queries, keys, values, attn_mask, tau, delta, key_padding_mask=None):
# in Reformer: defalut queries=keys # queries=keys in Reformer
B, N, C = queries.shape B, N, C = queries.shape
queries = self.attn(self.fit_length(queries))[:, :N, :] queries = self.attn(self.fit_length(queries))[:, :N, :]
return queries, None return queries, None
@ -275,23 +286,23 @@ class TwoStageAttentionLayer(nn.Module):
nn.GELU(), nn.GELU(),
nn.Linear(d_ff, d_model)) nn.Linear(d_ff, d_model))
def forward(self, x, attn_mask=None, tau=None, delta=None): def forward(self, x, attn_mask=None, tau=None, delta=None, key_padding_mask=None):
# Cross Time Stage: Directly apply MSA to each dimension # Cross Time Stage: Directly apply MSA to each dimension
batch = x.shape[0] batch = x.shape[0]
time_in = rearrange(x, 'b ts_d seg_num d_model -> (b ts_d) seg_num d_model') time_in = rearrange(x, 'b ts_d seg_num d_model -> (b ts_d) seg_num d_model')
time_enc, attn = self.time_attention( time_enc, attn = self.time_attention(
time_in, time_in, time_in, attn_mask=None, tau=None, delta=None time_in, time_in, time_in, attn_mask=None, tau=None, delta=None, key_padding_mask=key_padding_mask
) )
dim_in = time_in + self.dropout(time_enc) dim_in = time_in + self.dropout(time_enc)
dim_in = self.norm1(dim_in) dim_in = self.norm1(dim_in)
dim_in = dim_in + self.dropout(self.MLP1(dim_in)) dim_in = dim_in + self.dropout(self.MLP1(dim_in))
dim_in = self.norm2(dim_in) dim_in = self.norm2(dim_in)
# Cross Dimension Stage: use a small set of learnable vectors to aggregate and distribute messages to build the D-to-D connection # Cross Dimension Stage
dim_send = rearrange(dim_in, '(b ts_d) seg_num d_model -> (b seg_num) ts_d d_model', b=batch) dim_send = rearrange(dim_in, '(b ts_d) seg_num d_model -> (b seg_num) ts_d d_model', b=batch)
batch_router = repeat(self.router, 'seg_num factor d_model -> (repeat seg_num) factor d_model', repeat=batch) batch_router = repeat(self.router, 'seg_num factor d_model -> (repeat seg_num) factor d_model', repeat=batch)
dim_buffer, attn = self.dim_sender(batch_router, dim_send, dim_send, attn_mask=None, tau=None, delta=None) dim_buffer, _ = self.dim_sender(batch_router, dim_send, dim_send, attn_mask=None, tau=None, delta=None)
dim_receive, attn = self.dim_receiver(dim_send, dim_buffer, dim_buffer, attn_mask=None, tau=None, delta=None) dim_receive, _ = self.dim_receiver(dim_send, dim_buffer, dim_buffer, attn_mask=None, tau=None, delta=None)
dim_enc = dim_send + self.dropout(dim_receive) dim_enc = dim_send + self.dropout(dim_receive)
dim_enc = self.norm3(dim_enc) dim_enc = self.norm3(dim_enc)
dim_enc = dim_enc + self.dropout(self.MLP2(dim_enc)) dim_enc = dim_enc + self.dropout(self.MLP2(dim_enc))

528
models/DC_PatchTST.py Normal file
View File

@ -0,0 +1,528 @@
import torch
from torch import nn
import torch.nn.functional as F
from layers.SelfAttention_Family import FullAttention, AttentionLayer
# 需要 Mamba2 作为外层编码器
from mamba_ssm.modules.mamba2 import Mamba2
# -------------------- Routing余弦路由和论文一致 --------------------
class RoutingModule(nn.Module):
def __init__(self, d_model):
super().__init__()
self.q_proj = nn.Linear(d_model, d_model, bias=False)
self.k_proj = nn.Linear(d_model, d_model, bias=False)
with torch.no_grad():
nn.init.eye_(self.q_proj.weight)
nn.init.eye_(self.k_proj.weight)
self.q_proj.weight._no_reinit = True
self.k_proj.weight._no_reinit = True
def forward(self, x, mask=None):
"""
x: (B, L, D)
mask: (B, L) bool, True=有效
返回:
boundary_prob: (B, L, 2)
boundary_mask: (B, L) bool
selected_probs: (B, L, 1)
"""
B, L, D = x.shape
q = F.normalize(self.q_proj(x[:, :-1]), dim=-1) # (B, L-1, D)
k = F.normalize(self.k_proj(x[:, 1:]), dim=-1) # (B, L-1, D)
cos_sim = (q * k).sum(dim=-1) # (B, L-1)
p = torch.clamp((1 - cos_sim) / 2, 0.0, 1.0) # (B, L-1)
p = F.pad(p, (1, 0), value=1.0) # 强制首位是边界
if mask is not None:
p = p * mask.float()
p[:, 0] = torch.where(mask[:, 0], torch.ones_like(p[:, 0]), p[:, 0])
boundary_prob = torch.stack([1 - p, p], dim=-1) # (B, L, 2)
selected_idx = boundary_prob.argmax(dim=-1)
boundary_mask = (selected_idx == 1)
if mask is not None:
boundary_mask = boundary_mask & mask
selected_probs = boundary_prob.gather(-1, selected_idx.unsqueeze(-1)) # (B, L, 1)
return boundary_prob, boundary_mask, selected_probs
# -------------------- 选择并右侧零pad不丢弃、不重复填充 --------------------
def select_and_right_pad(x, boundary_mask):
"""
内存优化版本减少临时tensor创建
x: (B, L, D), boundary_mask: (B, L) bool
返回:
x_pad: (B, T_max, D)
key_padding_mask: (B, T_max) bool, True=有效
lengths: (B,)
"""
B, L, D = x.shape
device = x.device
lengths = boundary_mask.sum(dim=1) # (B,)
T_max = int(lengths.max().item()) if lengths.max() > 0 else 1
x_pad = x.new_zeros(B, T_max, D)
key_padding_mask = torch.zeros(B, T_max, dtype=torch.bool, device=device)
# 预创建默认索引tensor避免重复创建
default_idx = torch.tensor([0], device=device)
for b in range(B):
mask_b = boundary_mask[b]
if mask_b.any():
idx = mask_b.nonzero(as_tuple=True)[0] # 更高效的nonzero
t = idx.numel()
x_pad[b, :t] = x[b, idx]
key_padding_mask[b, :t] = True
else:
# 使用预创建的tensor
x_pad[b, 0] = x[b, default_idx]
key_padding_mask[b, 0] = True
return x_pad, key_padding_mask, lengths
# -------------------- Mamba2 堆叠(外层编码器) --------------------
class Mamba2Encoder(nn.Module):
def __init__(self, d_model, depth=4, dropout=0.0):
super().__init__()
self.layers = nn.ModuleList([Mamba2(d_model=d_model) for _ in range(depth)])
self.norm = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
for layer in self.layers:
x = layer(x)
x = self.norm(x)
x = self.dropout(x)
return x
# -------------------- 两层Encoder + DC变长不丢信息以比率约束压缩 --------------------
class DCEmbedding2StageVarLen(nn.Module):
"""
- Stage 0: (B*nvars, L, 1) -> Linear(D0) -> Mamba2(D0) -> Routing -> 选择 -> 扩宽到 D1
- Stage 1: (B*nvars, L0_sel, D1) -> Mamba2(D1) -> Routing -> 选择
输出:
enc_out: (B*nvars, T_max, D1)
key_padding_mask: (B*nvars, T_max)
n_vars: int
aux: dict含两层ratio loss与边界信息
"""
def __init__(self, d_model_out, d_model_stage0, depth_enc0=4, depth_enc1=4, dropout=0.0,
target_ratio0=0.25, target_ratio1=0.5):
super().__init__()
assert d_model_out >= d_model_stage0, "要求 D0 <= D1"
self.d0 = d_model_stage0
self.d1 = d_model_out
# 标量 -> D0
self.input_proj = nn.Linear(1, self.d0)
# Stage 0
self.enc0 = Mamba2Encoder(self.d0, depth=depth_enc0, dropout=dropout)
self.router0 = RoutingModule(self.d0)
delta = self.d1 - self.d0
self.pad_vec = nn.Parameter(torch.zeros(delta)) if delta > 0 else None
self.target_ratio0 = target_ratio0
# Stage 1
self.enc1 = Mamba2Encoder(self.d1, depth=depth_enc1, dropout=dropout)
self.router1 = RoutingModule(self.d1)
self.target_ratio1 = target_ratio1
def _expand_width(self, x):
if self.pad_vec is None:
return x
B, L, _ = x.shape
return torch.cat([x, self.pad_vec.view(1, 1, -1).expand(B, L, -1)], dim=-1)
@staticmethod
def _ratio_loss(boundary_mask: torch.Tensor, boundary_prob: torch.Tensor, target_ratio: float) -> torch.Tensor:
eps = 1e-6
F_act = boundary_mask.float().mean(dim=1) # (B,)
G_prob = boundary_prob[..., 1].mean(dim=1) # (B,)
N = 1.0 / max(target_ratio, eps)
loss = N / (N - 1.0 + eps) * (((N - 1.0) * F_act * G_prob) + (1.0 - F_act) * (1.0 - G_prob))
return loss.mean()
def forward(self, x):
"""
x: (B, nvars, L)
内存优化版本及时删除中间tensor
"""
B, nvars, L = x.shape
x = x.reshape(B * nvars, L, 1)
x = self.input_proj(x) # (B*nvars, L, D0)
# Stage 0
h0 = self.enc0(x) # (B*nvars, L, D0)
p0, bm0, _ = self.router0(h0)
h0_sel, mask0, len0 = select_and_right_pad(h0, bm0) # (B*nvars, L0_max, D0)
# 及时删除不需要的tensor
del h0
# h0_sel = self._expand_width(h0_sel) # (B*nvars, L0_max, D1)
# Stage 1
#h1 = self.enc1(h0_sel) # (B*nvars, L0_max, D1)
#p1, bm1, _ = self.router1(h1)
#bm1 = bm1 & mask0
#h1_sel, mask1, len1 = select_and_right_pad(h1, bm1) # (B*nvars, L1_max, D1)
# 及时删除中间tensor
#del h1, h0_sel
# 计算ratio loss时使用detach避免保存计算图
ratio_loss0 = self._ratio_loss(bm0, p0, target_ratio=self.target_ratio0)
# ratio_loss1 = self._ratio_loss(bm1, p1, target_ratio=self.target_ratio1)
# 简化aux字典只保存必要信息
aux = {
"stage0": {"boundary_mask": bm0.detach(), "boundary_prob": p0.detach(), "lengths": len0.detach()},
# "stage1": {"boundary_mask": bm1.detach(), "boundary_prob": p1.detach(), "lengths": len1.detach()},
"ratio_loss0": ratio_loss0,
# "ratio_loss1": ratio_loss1,
}
return h0_sel, mask0, nvars, aux
# -------------------- Encoder/EncoderLayer带 key_padding_mask 透传) --------------------
class EncoderLayerWithMask(nn.Module):
"""
与原EncoderLayer结构一致但 forward 增加 key_padding_mask并传入 AttentionLayer。
FFN 用简单的 MLP与常规Transformer一致
"""
def __init__(self, attention: AttentionLayer, d_model, d_ff, dropout=0.1, activation="gelu"):
super().__init__()
self.attention = attention
self.dropout = nn.Dropout(dropout)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
if activation == "relu":
act = nn.ReLU()
elif activation == "gelu":
act = nn.GELU()
else:
raise ValueError(f"Unsupported activation: {activation}")
self.ffn = nn.Sequential(
nn.Linear(d_model, d_ff),
act,
nn.Dropout(dropout),
nn.Linear(d_ff, d_model),
)
def forward(self, x, attn_mask=None, tau=None, delta=None, key_padding_mask=None):
# Multi-head attention with key padding mask
attn_out, attn = self.attention(
x, x, x, attn_mask, tau=tau, delta=delta, key_padding_mask=key_padding_mask
)
x = x + self.dropout(attn_out)
x = self.norm1(x)
# FFN
y = self.ffn(x)
x = x + self.dropout(y)
x = self.norm2(x)
return x, attn
class EncoderWithMask(nn.Module):
"""
与原Encoder类似但 forward 支持 key_padding_mask并传递给每一层的注意力。
"""
def __init__(self, attn_layers, norm_layer=None):
super().__init__()
self.attn_layers = nn.ModuleList(attn_layers)
self.norm = norm_layer
def forward(self, x, attn_mask=None, key_padding_mask=None):
attns = []
for attn_layer in self.attn_layers:
x, attn = attn_layer(x, attn_mask=attn_mask, key_padding_mask=key_padding_mask)
attns.append(attn)
if self.norm is not None:
x = self.norm(x)
return x, attns
# -------------------- 门控注意力聚合 + 任务头不依赖token数保留信息 --------------------
def masked_softmax(logits: torch.Tensor, mask: torch.Tensor, dim: int = -1) -> torch.Tensor:
"""
logits: (..., T)
mask: (..., T) bool, True=有效
"""
neg_inf = torch.finfo(logits.dtype).min
logits = logits.masked_fill(~mask, neg_inf)
return torch.softmax(logits, dim=dim)
class GatedAttnAggregator(nn.Module):
"""
门控注意力聚合器(可学习查询 + mask softmax + 值端门控)
输入: x: (B*, T, D), mask: (B*, T) bool
输出: slots: (B*, R, D) 其中 R 为聚合插槽数(可配置)
"""
def __init__(self, d_model: int, num_slots: int = 4, d_att: int = None, dropout: float = 0.1):
super().__init__()
self.d_model = d_model
self.R = num_slots
self.d_att = d_att or d_model
# 可学习查询R个
self.query = nn.Parameter(torch.randn(self.R, self.d_att) / (self.d_att ** 0.5))
# 线性投影
self.key_proj = nn.Linear(d_model, self.d_att)
self.val_proj = nn.Linear(d_model, d_model)
# 值端门控逐token标量门
self.gate = nn.Sequential(
nn.Linear(d_model, d_model // 2),
nn.GELU(),
nn.Linear(d_model // 2, 1),
nn.Sigmoid()
)
self.dropout = nn.Dropout(dropout)
def forward(self, x_bt_t_d: torch.Tensor, mask_bt: torch.Tensor) -> torch.Tensor:
"""
x_bt_t_d: (B*, T, D)
mask_bt: (B*, T) bool
return: slots (B*, R, D)
"""
BStar, T, D = x_bt_t_d.shape
K = self.key_proj(x_bt_t_d) # (B*, T, d_att)
V = self.val_proj(x_bt_t_d) # (B*, T, D)
g = self.gate(x_bt_t_d) # (B*, T, 1)
Vg = V * g # 门控后的值
Q = self.query.unsqueeze(0).expand(BStar, -1, -1) # (B*, R, d_att)
logits = torch.matmul(Q, K.transpose(1, 2)) / (self.d_att ** 0.5) # (B*, R, T)
attn_mask = mask_bt.unsqueeze(1) # (B*, 1, T)
attn = masked_softmax(logits, attn_mask, dim=-1)
attn = self.dropout(attn)
slots = torch.matmul(attn, Vg) # (B*, R, D)
return slots
class AttnPoolHeadForecast(nn.Module):
"""
预测任务头:门控注意力聚合到 R 个slots再映射到 target_windowpred_len
输出:(B, pred_len, nvars)
"""
def __init__(self, d_model: int, target_window: int, num_slots: int = 4, dropout: float = 0.1):
super().__init__()
self.agg = GatedAttnAggregator(d_model, num_slots=num_slots, dropout=dropout)
self.proj = nn.Sequential(
nn.LayerNorm(num_slots * d_model),
nn.Linear(num_slots * d_model, target_window),
)
self.dropout = nn.Dropout(dropout)
def forward(self, enc_out_bt_t_d: torch.Tensor, key_padding_mask_bt: torch.Tensor, n_vars: int, B: int):
slots = self.agg(enc_out_bt_t_d, key_padding_mask_bt) # (B*, R, D)
slots = slots.reshape(B, n_vars, -1) # (B, nvars, R*D)
out = self.proj(self.dropout(slots)) # (B, nvars, pred_len)
return out.permute(0, 2, 1) # (B, pred_len, nvars)
class AttnPoolHeadSeq(nn.Module):
"""
序列重建头:门控注意力聚合后映射到 seq_len
输出:(B, seq_len, nvars)
"""
def __init__(self, d_model: int, target_window: int, num_slots: int = 4, dropout: float = 0.1):
super().__init__()
self.agg = GatedAttnAggregator(d_model, num_slots=num_slots, dropout=dropout)
self.proj = nn.Sequential(
nn.LayerNorm(num_slots * d_model),
nn.Linear(num_slots * d_model, target_window),
)
self.dropout = nn.Dropout(dropout)
def forward(self, enc_out_bt_t_d: torch.Tensor, key_padding_mask_bt: torch.Tensor, n_vars: int, B: int):
slots = self.agg(enc_out_bt_t_d, key_padding_mask_bt) # (B*, R, D)
slots = slots.reshape(B, n_vars, -1) # (B, nvars, R*D)
out = self.proj(self.dropout(slots)) # (B, nvars, seq_len)
return out.permute(0, 2, 1) # (B, seq_len, nvars)
class AttnPoolHeadCls(nn.Module):
"""
分类头:每变量先门控注意力聚合到 R 个slots拼接所有变量后线性分类。
输出:(B, num_class)
"""
def __init__(self, d_model: int, n_vars: int, num_class: int, num_slots: int = 4, dropout: float = 0.1):
super().__init__()
self.agg = GatedAttnAggregator(d_model, num_slots=num_slots, dropout=dropout)
self.dropout = nn.Dropout(dropout)
self.proj = nn.Sequential(
nn.LayerNorm(n_vars * num_slots * d_model),
nn.Linear(n_vars * num_slots * d_model, num_class),
)
self.n_vars = n_vars
self.num_slots = num_slots
self.d_model = d_model
def forward(self, enc_out_bt_t_d: torch.Tensor, key_padding_mask_bt: torch.Tensor, n_vars: int, B: int):
slots = self.agg(enc_out_bt_t_d, key_padding_mask_bt) # (B*, R, D)
slots = slots.reshape(B, n_vars, self.num_slots * self.d_model) # (B, nvars, R*D)
flat = self.dropout(slots.reshape(B, -1)) # (B, nvars*R*D)
return self.proj(flat)
# -------------------- 主模型两层DC(比率控制) + 带mask的Encoder + 门控聚合头 --------------------
class Transpose(nn.Module):
def __init__(self, *dims, contiguous=False):
super().__init__()
self.dims, self.contiguous = dims, contiguous
def forward(self, x):
return x.transpose(*self.dims).contiguous() if self.contiguous else x.transpose(*self.dims)
class Model(nn.Module):
"""
PatchTST with DC and masked attention + gated heads:
- 用两层 Mamba2 编码器 + 动态分块 替代 PatchEmbedding
- DC 使用 ratio losstarget_ratio0/1控制压缩强度随层级加深序列变短d_model 变大D0->D1
- 注意力传入 key_padding_mask 屏蔽pad
- 头部使用门控注意力聚合不依赖token数信息保留更充分
"""
def __init__(
self, configs,
d_model_stage0=None, # D0默认= d_model // 2
depth_enc0=1, depth_enc1=1,
target_ratio0=0.25, # 约等于 1/N0
target_ratio1=0.5, # 约等于 1/N1
agg_slots=4, # 门控聚合的slot数
):
super().__init__()
self.task_name = configs.task_name
self.seq_len = configs.seq_len
self.pred_len = configs.pred_len
self.enc_in = configs.enc_in
# DC 嵌入
D1 = configs.d_model
D0 = d_model_stage0 if d_model_stage0 is not None else max(16, D1 // 2)
assert D1 >= D0, "要求 D0 <= D1"
self.dc_embedding = DCEmbedding2StageVarLen(
d_model_out=D1,
d_model_stage0=D0,
depth_enc0=depth_enc0,
depth_enc1=depth_enc1,
dropout=configs.dropout,
target_ratio0=target_ratio0,
target_ratio1=target_ratio1,
)
# 带mask的Encoder
attn_layers = [
EncoderLayerWithMask(
AttentionLayer(
FullAttention(False, configs.factor, attention_dropout=configs.dropout, output_attention=False),
D1, configs.n_heads
),
d_model=D1,
d_ff=configs.d_ff,
dropout=configs.dropout,
activation=configs.activation
) for _ in range(configs.e_layers)
]
self.encoder = EncoderWithMask(
attn_layers,
norm_layer=nn.Sequential(Transpose(1, 2), nn.BatchNorm1d(D1), Transpose(1, 2))
)
# 门控聚合头与token数无关
if self.task_name in ('long_term_forecast', 'short_term_forecast'):
self.head = AttnPoolHeadForecast(D1, self.pred_len, num_slots=agg_slots, dropout=configs.dropout)
elif self.task_name in ('imputation', 'anomaly_detection'):
self.head = AttnPoolHeadSeq(D1, self.seq_len, num_slots=agg_slots, dropout=configs.dropout)
elif self.task_name == 'classification':
self.head_cls = AttnPoolHeadCls(D1, n_vars=self.enc_in, num_class=configs.num_class, num_slots=agg_slots, dropout=configs.dropout)
# --------- 归一化/反归一化 ---------
def _pre_norm(self, x):
means = x.mean(1, keepdim=True).detach()
x = x - means
stdev = torch.sqrt(torch.var(x, dim=1, keepdim=True, unbiased=False) + 1e-5)
x = x / stdev
return x, means, stdev
def _denorm(self, y, means, stdev, length):
return y * (stdev[:, 0, :].unsqueeze(1).repeat(1, length, 1)) + \
(means[:, 0, :].unsqueeze(1).repeat(1, length, 1))
# --------- DC + Transformer Encoder携带 key_padding_mask ----------
def _embed_and_encode(self, x_enc):
"""
x_enc: (B, L, C)
返回:
enc_out: (B*nvars, T_max, D1)
n_vars: int
key_padding_mask: (B*nvars, T_max)
aux: dict
"""
B, L, C = x_enc.shape
x_vars = x_enc.permute(0, 2, 1) # (B, nvars, L)
enc_out, key_padding_mask, n_vars, aux = self.dc_embedding(x_vars)
enc_out, _ = self.encoder(enc_out, attn_mask=None, key_padding_mask=key_padding_mask)
return enc_out, n_vars, key_padding_mask, B, aux
# --------- 各任务前向 ---------
def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
x_enc, means, stdev = self._pre_norm(x_enc)
enc_out, n_vars, key_padding_mask, B, aux = self._embed_and_encode(x_enc)
dec_out = self.head(enc_out, key_padding_mask, n_vars, B) # (B, pred_len, nvars)
dec_out = self._denorm(dec_out, means, stdev, self.pred_len)
return dec_out, aux
def imputation(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask):
means = torch.sum(x_enc, dim=1) / torch.sum(mask == 1, dim=1)
means = means.unsqueeze(1).detach()
x = x_enc - means
x = x.masked_fill(mask == 0, 0)
stdev = torch.sqrt(torch.sum(x * x, dim=1) / torch.sum(mask == 1, dim=1) + 1e-5)
stdev = stdev.unsqueeze(1).detach()
x = x / stdev
enc_out, n_vars, key_padding_mask, B, aux = self._embed_and_encode(x)
dec_out = self.head(enc_out, key_padding_mask, n_vars, B) # (B, seq_len, nvars)
dec_out = self._denorm(dec_out, means, stdev, self.seq_len)
return dec_out, aux
def anomaly_detection(self, x_enc):
x_enc, means, stdev = self._pre_norm(x_enc)
enc_out, n_vars, key_padding_mask, B, aux = self._embed_and_encode(x_enc)
dec_out = self.head(enc_out, key_padding_mask, n_vars, B) # (B, seq_len, nvars)
dec_out = self._denorm(dec_out, means, stdev, self.seq_len)
return dec_out, aux
def classification(self, x_enc, x_mark_enc):
x_enc, _, _ = self._pre_norm(x_enc)
enc_out, n_vars, key_padding_mask, B, aux = self._embed_and_encode(x_enc)
logits = self.head_cls(enc_out, key_padding_mask, n_vars, B) # (B, num_class)
return logits, aux
def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
if self.task_name in ('long_term_forecast', 'short_term_forecast'):
dec_out, aux = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
return dec_out[:, -self.pred_len:, :], aux # [B, L, D], aux含ratio losses
if self.task_name == 'imputation':
dec_out, aux = self.imputation(x_enc, x_mark_enc, x_dec, x_mark_dec, mask)
return dec_out, aux
if self.task_name == 'anomaly_detection':
dec_out, aux = self.anomaly_detection(x_enc)
return dec_out, aux
if self.task_name == 'classification':
logits, aux = self.classification(x_enc, x_mark_enc)
return logits, aux
return None, None

339
models/DC_hnet.py Normal file
View File

@ -0,0 +1,339 @@
from dataclasses import dataclass
from typing import Optional, Literal, List
import torch
import torch.nn as nn
import torch.nn.functional as F
# 来自你的代码库(可直接使用)
from hnet.modules.dc import RoutingModule, ChunkLayer
from hnet.modules.isotropic import Isotropic
from hnet.models.config_hnet import HNetConfig, SSMConfig, AttnConfig
# -------------------- 辅助 --------------------
def create_isotropic_encoder(d_model, arch="m", height=4, device=None, dtype=None):
"""创建简化的Isotropic编码器"""
factory_kwargs = {"device": device, "dtype": dtype}
# 创建HNetConfig确保list字段有足够的元素
config = HNetConfig(
arch_layout=[f"{arch}{height}"],
d_model=[d_model],
d_intermediate=[d_model * 2],
ssm_cfg=SSMConfig(
d_conv=4,
expand=2,
d_state=128,
chunk_size=256
),
attn_cfg=AttnConfig(
num_heads=[8], # 确保有至少一个元素
rotary_emb_dim=[0], # 确保有至少一个元素
window_size=[-1] # 确保有至少一个元素
)
)
return Isotropic(
config=config,
pos_idx=0,
stage_idx=0,
**factory_kwargs
)
def ratio_loss(boundary_mask: torch.Tensor, boundary_prob: torch.Tensor, target_N: int) -> torch.Tensor:
F_act = boundary_mask.float().mean(dim=1) # (B,)
G_prob = boundary_prob[..., 1].mean(dim=1) # (B,)
N = float(target_N)
loss = N / (N - 1.0) * (((N - 1.0) * F_act) + (1.0 - F_act) * (1.0 - G_prob))
return loss.mean()
def masked_mean(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
mask_f = mask.float().unsqueeze(-1) # (B, L, 1)
s = (x * mask_f).sum(dim=1) # (B, D)
denom = mask_f.sum(dim=1).clamp_min(1.0)
return s / denom
# -------------------- 多层Encoder金字塔每层Mamba2 + 路由下采样,只有最终有主网络 --------------------
class PyramidEncoders_NoDechunk(nn.Module):
"""
层级结构(仅编码器逐层压缩;主网络只在最终一层):
输入 x0: (B, L0, 1)
- 线性升维 -> D0
For s = 0..S-1:
Es(Mamba2, D_s) -> h_s (B, L_s, D_s)
路由 + 下采样 -> x_{s+1} (B, L_{s+1}, D_s), mask_{s+1}
维度扩展 D_s -> D_{s+1}(拼接共享向量)
最终 x_S: (B, L_S, D_S) 送入单一主网络 M (Transformer/Mamba)
跨尺度融合(不去分块):融合 E^0 的 pooled_enc0 与 主网络 pooled_main
"""
def __init__(
self,
d_models: List[int], # [D0, D1, ..., D_S] 单调非降
encoder_cfg_per_stage: List[dict], # S个编码器配置必须 arch='m'/'M'
main_cfg: dict, # 单一主网络配置(在最压缩序列上工作)
fusion_dropout: float = 0.1,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
):
super().__init__()
factory_kwargs = {"device": device, "dtype": dtype}
assert len(d_models) >= 1
S = len(d_models) - 1
assert S == len(encoder_cfg_per_stage), "stage数等于encoder配置数"
for i in range(S):
assert d_models[i+1] >= d_models[i], "需满足 D_s <= D_{s+1}(宽度单调增加)"
assert encoder_cfg_per_stage[i].get("arch", "m") in ("m", "M"), "Encoder必须为Mamba2"
self.S = S
self.d_models = d_models
# 输入升维到 D0
self.input_proj = nn.Linear(1, d_models[0], **factory_kwargs)
# 每层编码器 + 路由 + 下采样 + 扩宽参数
self.encoders = nn.ModuleList()
self.routers = nn.ModuleList()
self.chunks = nn.ModuleList()
self.pad_vectors = nn.ParameterList()
for s in range(S):
self.encoders.append(
create_isotropic_encoder(
d_model=d_models[s],
**{k: v for k, v in encoder_cfg_per_stage[s].items() if k != "d_model"},
**factory_kwargs
)
)
self.routers.append(RoutingModule(d_models[s], **factory_kwargs))
self.chunks.append(ChunkLayer())
delta = d_models[s+1] - d_models[s]
self.pad_vectors.append(nn.Parameter(torch.zeros(delta, **factory_kwargs)) if delta > 0 else nn.Parameter(torch.empty(0, **factory_kwargs)))
# 最终唯一的主网络:在 D_S & L_S 上运行
self.main_network = create_isotropic_encoder(
d_model=d_models[-1],
**{k: v for k, v in main_cfg.items() if k != "d_model"},
**factory_kwargs
)
# 跨尺度融合:将 pooled_enc0(D0) 投到 D_S 并与 pooled_main(D_S) 融合 -> D_S
self.proj_enc0_to_DS = nn.Linear(d_models[0], d_models[-1], **factory_kwargs)
self.fusion_head = nn.Sequential(
nn.Linear(d_models[-1] + d_models[-1], d_models[-1], **factory_kwargs),
nn.GELU(),
nn.Dropout(fusion_dropout),
nn.Linear(d_models[-1], d_models[-1], **factory_kwargs),
)
def _expand_width(self, x: torch.Tensor, pad_vec: nn.Parameter) -> torch.Tensor:
if pad_vec.numel() == 0:
return x
early = x.shape[:-1]
return torch.cat([x, pad_vec.expand(*early, -1)], dim=-1)
def forward(self, x_scalar: torch.Tensor, mask: Optional[torch.Tensor] = None, return_seq: bool = False):
"""
x_scalar: (B, L) 或 (B, L, 1)
mask: (B, L) bool
返回:
fused_vec: (B, D_S)
debug: 可选
aux: 包含各层路由信息供ratio loss
"""
if x_scalar.dim() == 2:
x_scalar = x_scalar.unsqueeze(-1) # (B, L, 1)
B, L, _ = x_scalar.shape
device = x_scalar.device
if mask is None:
mask = torch.ones(B, L, dtype=torch.bool, device=device)
# 初始升维到 D0
x = self.input_proj(x_scalar) # (B, L0, D0)
cur_mask = mask
pooled_enc0 = None
aux_per_stage = []
seq_debug = [] if return_seq else None
# 逐层Encoder(Mamba2)->Routing->Chunk->Expand D
for s in range(self.S):
d_in = self.d_models[s]
# 细粒度编码(未压缩序列)
h_enc = self.encoders[s](x, mask=cur_mask) # (B, L_s, D_s)
if s == 0:
pooled_enc0 = masked_mean(h_enc, cur_mask) # (B, D0)
# 路由 + 下采样(得到更短序列)
bpred = self.routers[s](h_enc, mask=cur_mask)
x_next, _, _, mask_next = self.chunks[s](h_enc, bpred.boundary_mask, mask=cur_mask) # (B, L_{s+1}, D_s)
# 扩展宽度 D_s -> D_{s+1}
x_next = self._expand_width(x_next, self.pad_vectors[s]) # (B, L_{s+1}, D_{s+1})
# 推进到下一层
x, cur_mask = x_next, mask_next
aux_per_stage.append({
"boundary_mask": bpred.boundary_mask,
"boundary_prob": bpred.boundary_prob,
"selected_probs": bpred.selected_probs,
})
if return_seq:
seq_debug.append({"stage": s, "seq": x, "mask": cur_mask})
# 现在 x: (B, L_S, D_S), cur_mask: (B, L_S)
# 最终单一主网络在最压缩序列上
h_main = self.main_network(x, mask=cur_mask) # (B, L_S, D_S)
# 主网络池化
if cur_mask is None:
pooled_main = h_main.mean(dim=1) # (B, D_S)
else:
pooled_main = (h_main * cur_mask.float().unsqueeze(-1)).sum(dim=1) / \
cur_mask.float().sum(dim=1, keepdim=True).clamp_min(1.0)
# 跨尺度融合E^0 全局池化 与 主网络池化
pooled_enc0_in_DS = self.proj_enc0_to_DS(pooled_enc0) # (B, D_S)
fused = torch.cat([pooled_enc0_in_DS, pooled_main], dim=-1) # (B, 2*D_S)
fused = self.fusion_head(fused) # (B, D_S)
aux = {"per_stage": aux_per_stage}
if return_seq:
return fused, {"stages": seq_debug, "main_seq": h_main, "main_mask": cur_mask}, aux
else:
return fused, None, aux
# -------------------- 顶层:多通道融合 + 分类头(仅一个主网络) --------------------
@dataclass
class HierEncodersSingleMainConfig:
num_channels: int
d_models: List[int] # [D0, D1, ..., D_S] 单调非降
num_classes: int
encoder_cfg_per_stage: List[dict] # S个编码器配置均为Mamba2, height≈4
main_cfg: dict # 单一主网络配置Transformer或Mamba2d_model自动用D_S
target_compression_N_per_stage: List[int]
share_channel: bool = True
fusion_across_channels: Literal["mean", "concat"] = "mean"
dropout: float = 0.1
class HierEncodersSingleMainClassifier(nn.Module):
def __init__(self, cfg: HierEncodersSingleMainConfig, dtype=None, device=None):
super().__init__()
self.cfg = cfg
factory_kwargs = {"dtype": dtype, "device": device}
S = len(cfg.d_models) - 1
assert S == len(cfg.encoder_cfg_per_stage) == len(cfg.target_compression_N_per_stage), "stage数不一致"
if cfg.share_channel:
self.channel_encoder = PyramidEncoders_NoDechunk(
d_models=cfg.d_models,
encoder_cfg_per_stage=cfg.encoder_cfg_per_stage,
main_cfg=cfg.main_cfg,
**factory_kwargs,
)
else:
self.channel_encoder = nn.ModuleList([
PyramidEncoders_NoDechunk(
d_models=cfg.d_models,
encoder_cfg_per_stage=cfg.encoder_cfg_per_stage,
main_cfg=cfg.main_cfg,
**factory_kwargs,
)
for _ in range(cfg.num_channels)
])
fusion_dim = (cfg.num_channels * cfg.d_models[-1]) if cfg.fusion_across_channels == "concat" \
else cfg.d_models[-1]
self.dropout = nn.Dropout(cfg.dropout)
self.head = nn.Linear(fusion_dim, cfg.num_classes, **factory_kwargs)
def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None, return_seq: bool = False):
"""
x: (B, L, N) 多通道输入
mask: (B, L) 时序mask
"""
B, L, N = x.shape
assert N == self.cfg.num_channels
channel_vecs: List[torch.Tensor] = []
ratio_losses = []
seq_dbg_all = [] if return_seq else None
for c in range(N):
x_c = x[..., c] # (B, L)
if self.cfg.share_channel:
vec, seq_dbg, aux = self.channel_encoder(x_c, mask=mask, return_seq=return_seq)
else:
vec, seq_dbg, aux = self.channel_encoder[c](x_c, mask=mask, return_seq=return_seq)
# ratio loss 累加每个encoder stage一项
total_rl = 0.0
for s, aux_s in enumerate(aux["per_stage"]):
rl = ratio_loss(aux_s["boundary_mask"], aux_s["boundary_prob"], self.cfg.target_compression_N_per_stage[s])
total_rl = total_rl + rl
ratio_losses.append(total_rl)
channel_vecs.append(vec)
if return_seq:
seq_dbg_all.append(seq_dbg)
if self.cfg.fusion_across_channels == "concat":
fused = torch.cat(channel_vecs, dim=-1) # (B, N*D_S)
else:
fused = torch.stack(channel_vecs, dim=1).mean(dim=1) # (B, D_S)
fused = self.dropout(fused)
logits = self.head(fused)
aux_all = {"ratio_loss": torch.stack(ratio_losses).mean()}
if return_seq:
return logits, seq_dbg_all, aux_all
else:
return logits, None, aux_all
# -------------------- 使用示例 --------------------
if __name__ == "__main__":
"""
符合要求:
- 多层仅增加编码器数量每层Mamba2 + 动态分块),主网络只有最终一个
- 序列长度逐层缩短由DC决定通道维度 d_model 单调增大SpaceByte式共享向量拼接
- 不使用去分块dechunk跨尺度融合用 E^0 的全局池化 + 最终主网络池化
"""
B, L, N = 8, 1024, 6
num_classes = 7
d_models = [128, 256, 512] # D0 <= D1 <= D2
encoder_cfg_per_stage = [
dict(arch="m", height=4, ssm_cfg=dict(), attn_cfg=dict()), # stage 0 encoder (Mamba2)
dict(arch="m", height=4, ssm_cfg=dict(), attn_cfg=dict()), # stage 1 encoder (Mamba2)
]
main_cfg = dict(
arch="T", height=12, ssm_cfg=dict(), attn_cfg=dict(num_heads=8) # 最终主网络(较重)
)
target_compression_N_per_stage = [4, 4]
cfg = HierEncodersSingleMainConfig(
num_channels=N,
d_models=d_models,
num_classes=num_classes,
encoder_cfg_per_stage=encoder_cfg_per_stage,
main_cfg=main_cfg,
target_compression_N_per_stage=target_compression_N_per_stage,
share_channel=True,
fusion_across_channels="mean",
dropout=0.1,
)
model = HierEncodersSingleMainClassifier(cfg).cuda().train()
x = torch.randn(B, L, N, device="cuda")
mask = torch.ones(B, L, dtype=torch.bool, device="cuda")
logits, _, aux = model(x, mask=mask, return_seq=False)
y = torch.randint(0, num_classes, (B,), device="cuda")
cls_loss = F.cross_entropy(logits, y)
ratio_reg = 0.03 * aux["ratio_loss"]
loss = cls_loss + ratio_reg
loss.backward()
print("logits:", logits.shape, "loss:", float(loss))

View File

@ -0,0 +1,138 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from mamba_ssm import Mamba2
class ValueEmbedding(nn.Module):
"""
对每个时间步的单通道标量做线性投影到 d_model并可选 Dropout。
不包含 temporal embedding 和 positional embedding。
"""
def __init__(self, in_dim: int, d_model: int, dropout: float = 0.0, bias: bool = True):
super().__init__()
self.proj = nn.Linear(in_dim, d_model, bias=bias)
self.dropout = nn.Dropout(dropout) if dropout and dropout > 0.0 else nn.Identity()
def forward(self, x: torch.Tensor) -> torch.Tensor:
# x: [B, L, 1] -> [B, L, d_model]
return self.dropout(self.proj(x))
class ChannelMambaBlock(nn.Module):
"""
针对单个通道的两层 Mamba-2 处理块:
- 输入: [B, L, 1],先做投影到 d_model
- 两层 Mamba2且在第一层输出和第二层输出均添加残差连接
- 每层后接 LayerNorm
- 输出: [B, L, d_model]
"""
def __init__(self, d_model: int, dropout: float, m2_kwargs: dict):
super().__init__()
self.embed = ValueEmbedding(in_dim=1, d_model=d_model, dropout=dropout, bias=True)
# 两层 Mamba-2
self.mamba1 = Mamba2(d_model=d_model, **m2_kwargs)
self.mamba2 = Mamba2(d_model=d_model, **m2_kwargs)
# 每层后接的归一化
self.ln1 = nn.LayerNorm(d_model)
self.ln2 = nn.LayerNorm(d_model)
def forward(self, x_ch: torch.Tensor) -> torch.Tensor:
# x_ch: [B, L, 1]
x = self.embed(x_ch) # [B, L, d_model]
# 第一层 + 残差
y1 = self.mamba1(x) # [B, L, d_model]
y1 = self.ln1(x + y1) # 残差1 + LN
# 第二层 + 残差
y2 = self.mamba2(y1) # [B, L, d_model]
y2 = self.ln2(y1 + y2) # 残差2 + LN
return y2 # [B, L, d_model]
class Model(nn.Module):
"""
按通道独立处理的 Mamba-2 分类模型:
- 将输入的每个通道拆开,分别使用独立的两层 Mamba2含两处残差
- 每个通道得到 [B, L, d_model] 输出
- 取各通道最后时间步的表示拼接,接分类头
输入:
- x_enc: [B, L, D] 多变量时间序列
输出:
- logits: [B, num_class]
"""
def __init__(self, configs):
super().__init__()
self.task_name = getattr(configs, 'task_name', 'classification')
assert self.task_name == 'classification', "当前模型仅实现 classification 任务"
# 基本配置
self.enc_in = configs.enc_in # 通道数 D
self.d_model = configs.d_model # 每通道的模型维度
self.num_class = configs.num_class
self.dropout = getattr(configs, 'dropout', 0.1)
# Mamba-2 超参数(按需从 configs 读取)
# 注意:此处不再使用 e_layers 的堆叠,而是固定每通道两层以满足“在第一层和第二层输出处添加残差”的要求
m2_kwargs = dict(
d_state=getattr(configs, 'd_state', 64),
d_conv=getattr(configs, 'd_conv', 4),
expand=getattr(configs, 'expand', 2),
headdim=getattr(configs, 'headdim', 64),
d_ssm=getattr(configs, 'd_ssm', None),
ngroups=getattr(configs, 'ngroups', 1),
A_init_range=getattr(configs, 'A_init_range', (1, 16)),
D_has_hdim=getattr(configs, 'D_has_hdim', False),
rmsnorm=getattr(configs, 'rmsnorm', True),
norm_before_gate=getattr(configs, 'norm_before_gate', False),
dt_min=getattr(configs, 'dt_min', 0.001),
dt_max=getattr(configs, 'dt_max', 0.1),
dt_init_floor=getattr(configs, 'dt_init_floor', 1e-4),
dt_limit=getattr(configs, 'dt_limit', (0.0, float("inf"))),
bias=getattr(configs, 'bias', False),
conv_bias=getattr(configs, 'conv_bias', True),
chunk_size=getattr(configs, 'chunk_size', 256),
use_mem_eff_path=getattr(configs, 'use_mem_eff_path', True),
)
# 为每个通道构建独立的两层 Mamba2 处理块
self.channel_blocks = nn.ModuleList([
ChannelMambaBlock(d_model=self.d_model, dropout=self.dropout, m2_kwargs=m2_kwargs)
for _ in range(self.enc_in)
])
# 分类头:将各通道最后时间步的表示拼接后 -> GELU -> Dropout -> Linear
self.act = nn.GELU()
self.head = nn.Sequential(
nn.Dropout(self.dropout),
nn.Linear(self.d_model * self.enc_in, self.num_class)
)
def classification(self, x_enc: torch.Tensor) -> torch.Tensor:
# x_enc: [B, L, D]
B, L, D = x_enc.shape
assert D == self.enc_in, f"输入通道数 {D} 与 enc_in {self.enc_in} 不一致"
per_channel_last = []
for c in range(D):
# 取出单通道序列 [B, L] -> [B, L, 1]
x_ch = x_enc[:, :, c].unsqueeze(-1)
y_ch = self.channel_blocks[c](x_ch) # [B, L, d_model]
per_channel_last.append(y_ch[:, -1, :]) # [B, d_model]
# 拼接各通道最后时刻的表示 -> [B, D * d_model]
h_last = torch.cat(per_channel_last, dim=-1)
# 分类头
h_last = self.act(h_last)
logits = self.head(h_last) # [B, num_class]
return logits
# 与 TimesNet 的 forward 签名保持一致;忽略 x_mark_enc / x_dec / x_mark_dec / mask
def forward(self, x_enc, x_mark_enc=None, x_dec=None, x_mark_dec=None, mask=None):
return self.classification(x_enc)

203
models/vanillaMamba.py Normal file
View File

@ -0,0 +1,203 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from mamba_ssm import Mamba2
class ValueEmbedding(nn.Module):
"""
对每个时间步的单通道标量做线性投影到 d_model并可选 Dropout。
不包含 temporal embedding 和 positional embedding。
"""
def __init__(self, in_dim: int, d_model: int, dropout: float = 0.0, bias: bool = True):
super().__init__()
self.proj = nn.Linear(in_dim, d_model, bias=bias)
self.dropout = nn.Dropout(dropout) if dropout and dropout > 0.0 else nn.Identity()
def forward(self, x: torch.Tensor) -> torch.Tensor:
# x: [B, L, 1] -> [B, L, d_model]
return self.dropout(self.proj(x))
class ChannelMambaBlock(nn.Module):
"""
针对单个通道的两层 Mamba-2 处理块:
- 输入: [B, L, 1],先做投影到 d_model
- 两层 Mamba2且在第一层输出和第二层输出均添加残差连接
- 每层后接 LayerNorm
- 输出: [B, L, d_model]
"""
def __init__(self, d_model: int, dropout: float, m2_kwargs: dict):
super().__init__()
self.embed = ValueEmbedding(in_dim=1, d_model=d_model, dropout=dropout, bias=True)
# 两层 Mamba-2
self.mamba1 = Mamba2(d_model=d_model, **m2_kwargs)
self.mamba2 = Mamba2(d_model=d_model, **m2_kwargs)
# 每层后接的归一化
self.ln1 = nn.LayerNorm(d_model)
self.ln2 = nn.LayerNorm(d_model)
def forward(self, x_ch: torch.Tensor) -> torch.Tensor:
# x_ch: [B, L, 1]
x = self.embed(x_ch) # [B, L, d_model]
# 第一层 + 残差
y1 = self.mamba1(x) # [B, L, d_model]
y1 = self.ln1(x + y1) # 残差1 + LN
# 第二层 + 残差
y2 = self.mamba2(y1) # [B, L, d_model]
y2 = self.ln2(y1 + y2) # 残差2 + LN
return y2 # [B, L, d_model]
class Model(nn.Module):
"""
按通道独立处理的 Mamba-2 模型,支持:
- 分类:各通道独立提取,取最后时刻拼接 -> 分类头
- 长/短期预测:各通道独立提取,保留整段序列,经时间维线性映射到目标长度,再投影回标量并拼接
注意:预测输出通道数与输入通道数严格相同(逐通道预测)。
输入:
- x_enc: [B, L, D] 多变量时间序列
- x_mark_enc, x_dec, x_mark_dec, mask: 兼容接口参数(本模型在分类/预测中未使用这些标注)
输出:
- classification: logits [B, num_class]
- forecast: [B, pred_len, D]
"""
def __init__(self, configs):
super().__init__()
# 任务类型
self.task_name = getattr(configs, 'task_name', 'classification')
assert self.task_name in ['classification', 'long_term_forecast', 'short_term_forecast'], \
"只支持 classification / long_term_forecast / short_term_forecast"
# 基本配置
self.enc_in = configs.enc_in # 通道数 D
self.d_model = configs.d_model # 每通道的模型维度
self.num_class = getattr(configs, 'num_class', None)
self.dropout = getattr(configs, 'dropout', 0.1)
# 预测相关
self.seq_len = getattr(configs, 'seq_len', None)
self.pred_len = getattr(configs, 'pred_len', None)
if self.task_name in ['long_term_forecast', 'short_term_forecast']:
assert self.seq_len is not None and self.pred_len is not None, "预测任务需要 seq_len 与 pred_len"
# 输出通道必须与输入通道一致
self.c_out = getattr(configs, 'c_out', self.enc_in)
assert self.c_out == self.enc_in, "预测任务要求输出通道 c_out 与输入通道 enc_in 一致"
# Mamba-2 超参数
m2_kwargs = dict(
d_state=getattr(configs, 'd_state', 64),
d_conv=getattr(configs, 'd_conv', 4),
expand=getattr(configs, 'expand', 2),
headdim=getattr(configs, 'headdim', 64),
d_ssm=getattr(configs, 'd_ssm', None),
ngroups=getattr(configs, 'ngroups', 1),
A_init_range=getattr(configs, 'A_init_range', (1, 16)),
D_has_hdim=getattr(configs, 'D_has_hdim', False),
rmsnorm=getattr(configs, 'rmsnorm', True),
norm_before_gate=getattr(configs, 'norm_before_gate', False),
dt_min=getattr(configs, 'dt_min', 0.001),
dt_max=getattr(configs, 'dt_max', 0.1),
dt_init_floor=getattr(configs, 'dt_init_floor', 1e-4),
dt_limit=getattr(configs, 'dt_limit', (0.0, float("inf"))),
bias=getattr(configs, 'bias', False),
conv_bias=getattr(configs, 'conv_bias', True),
chunk_size=getattr(configs, 'chunk_size', 256),
use_mem_eff_path=getattr(configs, 'use_mem_eff_path', True),
)
# 为每个通道构建独立的两层 Mamba2 处理块
self.channel_blocks = nn.ModuleList([
ChannelMambaBlock(d_model=self.d_model, dropout=self.dropout, m2_kwargs=m2_kwargs)
for _ in range(self.enc_in)
])
# 分类头:将各通道最后时间步的表示拼接后 -> GELU -> Dropout -> Linear
if self.task_name == 'classification':
assert self.num_class is not None, "classification 需要提供 num_class"
self.act = nn.GELU()
self.head = nn.Sequential(
nn.Dropout(self.dropout),
nn.Linear(self.d_model * self.enc_in, self.num_class)
)
# 预测头:
# - 先对时间维做线性映射: [B, L, d_model] -> [B, pred_len, d_model]
# - 再将 d_model 投影为单通道标量: [B, pred_len, d_model] -> [B, pred_len, 1]
if self.task_name in ['long_term_forecast', 'short_term_forecast']:
self.predict_linear = nn.Linear(self.seq_len, self.pred_len)
self.projection = nn.Linear(self.d_model, 1, bias=True)
def classification(self, x_enc: torch.Tensor) -> torch.Tensor:
# x_enc: [B, L, D]
B, L, D = x_enc.shape
assert D == self.enc_in, f"输入通道数 {D} 与 enc_in {self.enc_in} 不一致"
per_channel_last = []
for c in range(D):
# 取出单通道序列 [B, L] -> [B, L, 1]
x_ch = x_enc[:, :, c].unsqueeze(-1)
y_ch = self.channel_blocks[c](x_ch) # [B, L, d_model]
per_channel_last.append(y_ch[:, -1, :]) # [B, d_model]
# 拼接各通道最后时刻的表示 -> [B, D * d_model]
h_last = torch.cat(per_channel_last, dim=-1)
# 分类头
logits = self.head(self.act(h_last)) # [B, num_class]
return logits
def forecast(self, x_enc: torch.Tensor) -> torch.Tensor:
"""
逐通道预测:
- 归一化(时间维),按通道独立提取
- 使用整段 Mamba 输出序列,经时间维线性映射到目标长度,再投影为标量
- 反归一化
返回:
dec_out: [B, L+pred_len, D],在 forward 中会取最后 pred_len 段
"""
B, L, D = x_enc.shape
assert L == self.seq_len, f"输入长度 {L} 与配置 seq_len {self.seq_len} 不一致"
assert D == self.enc_in, f"输入通道数 {D} 与 enc_in {self.enc_in} 不一致"
# Normalization (per Non-stationary Transformer)
means = x_enc.mean(1, keepdim=True).detach() # [B, 1, D]
x = x_enc - means
stdev = torch.sqrt(x.var(dim=1, keepdim=True, unbiased=False) + 1e-5) # [B, 1, D]
x = x / stdev
per_channel_seq = []
for c in range(D):
x_ch = x[:, :, c].unsqueeze(-1) # [B, L, 1]
h_ch = self.channel_blocks[c](x_ch) # [B, L, d_model]
# 时间维映射到 L + pred_len
h_ch = self.predict_linear(h_ch.permute(0, 2, 1)).permute(0, 2, 1) # [B, L+pred_len, d_model]
# 投影回单通道
y_ch = self.projection(h_ch) # [B, L+pred_len, 1]
per_channel_seq.append(y_ch)
# 拼接通道
dec_out = torch.cat(per_channel_seq, dim=-1) # [B, pred_len, D]
# De-normalization
dec_out = dec_out * stdev[:, 0, :].unsqueeze(1) + means[:, 0, :].unsqueeze(1)
return dec_out
# 与 TimesNet 的 forward 签名保持一致;忽略 x_mark_enc / x_dec / x_mark_dec / mask
def forward(self, x_enc, x_mark_enc=None, x_dec=None, x_mark_dec=None, mask=None):
if self.task_name in ['long_term_forecast', 'short_term_forecast']:
dec_out = self.forecast(x_enc) # [B, L+pred_len, D]
return dec_out[:, -self.pred_len:, :] # 仅返回预测部分 [B, pred_len, D]
elif self.task_name == 'classification':
return self.classification(x_enc)
else:
raise NotImplementedError(f"Unsupported task: {self.task_name}")

View File

@ -22,7 +22,7 @@ class Model(nn.Module):
self.pred_len = configs.pred_len self.pred_len = configs.pred_len
self.enc_in = configs.enc_in self.enc_in = configs.enc_in
# Model parameters # Patch parameters
self.patch_len = getattr(configs, 'patch_len', 16) self.patch_len = getattr(configs, 'patch_len', 16)
self.stride = getattr(configs, 'stride', 8) self.stride = getattr(configs, 'stride', 8)
@ -37,17 +37,37 @@ class Model(nn.Module):
beta = getattr(configs, 'beta', torch.tensor(0.1)) beta = getattr(configs, 'beta', torch.tensor(0.1))
self.decomp = DECOMP(ma_type, alpha, beta) self.decomp = DECOMP(ma_type, alpha, beta)
# Season network (PatchTST + Graph Mixer) # Season network (PatchTST/Mamba2 + Graph Mixer)
# 透传新版 SeasonPatch 的参数(其中 GraphMixer 替换为非归一化 Hard-Concrete 门控)
self.season_net = SeasonPatch( self.season_net = SeasonPatch(
c_in=self.enc_in, c_in=self.enc_in,
seq_len=self.seq_len, seq_len=self.seq_len,
pred_len=self.pred_len, pred_len=self.pred_len,
patch_len=self.patch_len, patch_len=self.patch_len,
stride=self.stride, stride=self.stride,
k_graph=getattr(configs, 'k_graph', 8), # 编码器类型:'Transformer' or 'Mamba2'
encoder_type=getattr(configs, 'season_encoder', 'Transformer'),
# Patch相关
d_model=getattr(configs, 'd_model', 128), d_model=getattr(configs, 'd_model', 128),
n_layers=getattr(configs, 'e_layers', 3), n_layers=getattr(configs, 'e_layers', 3),
n_heads=getattr(configs, 'n_heads', 16) n_heads=getattr(configs, 'n_heads', 16),
# GraphMixer相关非归一化
k_graph=getattr(configs, 'k_graph', 8), # -> max_degree
thr_graph=getattr(configs, 'thr_graph', 0.5),
thr_graph_min=getattr(configs, 'thr_graph_min', None),
thr_graph_max=getattr(configs, 'thr_graph_max', None),
thr_graph_steps=getattr(configs, 'thr_graph_steps', 0),
thr_graph_schedule=getattr(configs, 'thr_graph_schedule', 'linear'),
symmetric_graph=getattr(configs, 'symmetric_graph', True),
degree_rescale=getattr(configs, 'degree_rescale', 'count-sqrt'), # 'none' | 'count' | 'count-sqrt' | 'sum'
gate_temperature=getattr(configs, 'gate_temperature', 2.0/3.0),
tau_attn=getattr(configs, 'tau_attn', 1.0),
l0_lambda=getattr(configs, 'season_l0_lambda', 0.0),
# Mamba2相关
d_state=getattr(configs, 'd_state', 64),
d_conv=getattr(configs, 'd_conv', 4),
expand=getattr(configs, 'expand', 2),
headdim=getattr(configs, 'headdim', 64),
) )
# Trend network (MLP) # Trend network (MLP)
@ -80,7 +100,7 @@ class Model(nn.Module):
nn.Linear(128, configs.num_class) nn.Linear(128, configs.num_class)
) )
def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec): def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec, training):
"""Long-term forecasting""" """Long-term forecasting"""
# Normalization # Normalization
if self.revin: if self.revin:
@ -90,7 +110,7 @@ class Model(nn.Module):
seasonal_init, trend_init = self.decomp(x_enc) seasonal_init, trend_init = self.decomp(x_enc)
# Season stream # Season stream
y_season = self.season_net(seasonal_init) # [B, C, pred_len] y_season = self.season_net(seasonal_init, training) # [B, C, pred_len]
# Trend stream # Trend stream
B, L, C = trend_init.shape B, L, C = trend_init.shape
@ -117,17 +137,12 @@ class Model(nn.Module):
def classification(self, x_enc, x_mark_enc): def classification(self, x_enc, x_mark_enc):
"""Classification task""" """Classification task"""
# Normalization # Decomposition分类任务通常可不做 RevIN如需可自行打开
#if self.revin:
# x_enc = self.revin_layer(x_enc, 'norm')
# Decomposition
seasonal_init, trend_init = self.decomp(x_enc) seasonal_init, trend_init = self.decomp(x_enc)
# Season stream # Season stream
y_season = self.season_net(seasonal_init) # [B, C, pred_len] y_season = self.season_net(seasonal_init) # [B, C, pred_len]
# print("shape:", trend_init.shape)
# Trend stream # Trend stream
B, L, C = trend_init.shape B, L, C = trend_init.shape
trend = trend_init.permute(0, 2, 1).reshape(B * C, L) # [B*C, L] trend = trend_init.permute(0, 2, 1).reshape(B * C, L) # [B*C, L]
@ -144,7 +159,7 @@ class Model(nn.Module):
season_attn_weights = torch.softmax(y_season, dim=-1) season_attn_weights = torch.softmax(y_season, dim=-1)
season_pooled = (y_season * season_attn_weights).sum(dim=-1) # [B, C] season_pooled = (y_season * season_attn_weights).sum(dim=-1) # [B, C]
trend_attn_weights = torch.softmax(y_trend, dim=-1) # 时间维 trend_attn_weights = torch.softmax(y_trend, dim=-1)
trend_pooled = (y_trend * trend_attn_weights).sum(dim=-1) # [B, C] trend_pooled = (y_trend * trend_attn_weights).sum(dim=-1) # [B, C]
# Combine features # Combine features
@ -154,13 +169,22 @@ class Model(nn.Module):
logits = self.classifier(features) # [B, num_classes] logits = self.classifier(features) # [B, num_classes]
return logits return logits
def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None): def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None, training=True):
"""Forward pass dispatching to task-specific methods""" """Forward pass dispatching to task-specific methods"""
if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast': if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec) dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec, training)
return dec_out[:, -self.pred_len:, :] # [B, L, D] return dec_out[:, -self.pred_len:, :] # [B, L, D]
elif self.task_name == 'classification': elif self.task_name == 'classification':
dec_out = self.classification(x_enc, x_mark_enc) dec_out = self.classification(x_enc, x_mark_enc)
return dec_out # [B, N] return dec_out # [B, N]
else: else:
raise ValueError(f'Task {self.task_name} not supported by xPatch_SparseChannel') raise ValueError(f'Task {self.task_name} not supported by xPatch_SparseChannel')
def reg_loss(self):
"""
L0 正则项(仅在 Transformer 路径启用 GraphMixer 时非零)。
训练时total_loss = main_loss + model.reg_loss()
"""
if hasattr(self, "season_net") and hasattr(self.season_net, "reg_loss"):
return self.season_net.reg_loss()
return torch.tensor(0.0, device=next(self.parameters()).device)

4
run.py
View File

@ -7,6 +7,7 @@ from exp.exp_imputation import Exp_Imputation
from exp.exp_short_term_forecasting import Exp_Short_Term_Forecast from exp.exp_short_term_forecasting import Exp_Short_Term_Forecast
from exp.exp_anomaly_detection import Exp_Anomaly_Detection from exp.exp_anomaly_detection import Exp_Anomaly_Detection
from exp.exp_classification import Exp_Classification from exp.exp_classification import Exp_Classification
from exp.exp_dc_patchtst_classification import Exp_DC_PatchTST_Classification
from utils.print_args import print_args from utils.print_args import print_args
import random import random
import numpy as np import numpy as np
@ -191,6 +192,9 @@ if __name__ == '__main__':
elif args.task_name == 'anomaly_detection': elif args.task_name == 'anomaly_detection':
Exp = Exp_Anomaly_Detection Exp = Exp_Anomaly_Detection
elif args.task_name == 'classification': elif args.task_name == 'classification':
if args.model == 'DC_PatchTST':
Exp = Exp_DC_PatchTST_Classification
else:
Exp = Exp_Classification Exp = Exp_Classification
else: else:
Exp = Exp_Long_Term_Forecast Exp = Exp_Long_Term_Forecast

View File

@ -0,0 +1,142 @@
export CUDA_VISIBLE_DEVICES=0
model_name=DC_PatchTST
# DC_PatchTST specific parameters
d_model_stage0=64 # Stage 0 dimension (D0)
depth_enc0=1 # Stage 0 Mamba2 encoder depth
depth_enc1=1 # Stage 1 Mamba2 encoder depth
target_ratio0=0.25 # Target compression ratio for stage 0
target_ratio1=0.25 # Target compression ratio for stage 1
# EthanolConcentration dataset
python -u run.py \
--task_name classification \
--is_training 1 \
--root_path ./dataset/EthanolConcentration/ \
--model_id EthanolConcentration \
--model $model_name \
--data UEA \
--e_layers 3 \
--batch_size 8 \
--d_model 128 \
--d_ff 256 \
--n_heads 8 \
--dropout 0.1 \
--activation gelu \
--des 'DC_PatchTST_Exp' \
--itr 1 \
--learning_rate 0.0002 \
--train_epochs 100 \
--patience 10 \
--d_model_stage0 $d_model_stage0 \
--depth_enc0 $depth_enc0 \
--depth_enc1 $depth_enc1 \
--target_ratio0 $target_ratio0 \
--target_ratio1 $target_ratio1
# FaceDetection dataset
python -u run.py \
--task_name classification \
--is_training 1 \
--root_path ./dataset/FaceDetection/ \
--model_id FaceDetection \
--model $model_name \
--data UEA \
--e_layers 3 \
--batch_size 8 \
--d_model 128 \
--d_ff 256 \
--n_heads 8 \
--dropout 0.1 \
--activation gelu \
--des 'DC_PatchTST_Exp' \
--itr 1 \
--learning_rate 0.001 \
--train_epochs 100 \
--patience 10 \
--d_model_stage0 $d_model_stage0 \
--depth_enc0 $depth_enc0 \
--depth_enc1 $depth_enc1 \
--target_ratio0 $target_ratio0 \
--target_ratio1 $target_ratio1
# Handwriting dataset
python -u run.py \
--task_name classification \
--is_training 1 \
--root_path ./dataset/Handwriting/ \
--model_id Handwriting \
--model $model_name \
--data UEA \
--e_layers 3 \
--batch_size 8 \
--d_model 128 \
--d_ff 256 \
--n_heads 8 \
--dropout 0.1 \
--activation gelu \
--des 'DC_PatchTST_Exp' \
--itr 1 \
--learning_rate 0.001 \
--train_epochs 100 \
--patience 10 \
--d_model_stage0 $d_model_stage0 \
--depth_enc0 $depth_enc0 \
--depth_enc1 $depth_enc1 \
--target_ratio0 $target_ratio0 \
--target_ratio1 $target_ratio1
# Heartbeat dataset
python -u run.py \
--task_name classification \
--is_training 1 \
--root_path ./dataset/Heartbeat/ \
--model_id Heartbeat \
--model $model_name \
--data UEA \
--e_layers 3 \
--batch_size 8 \
--d_model 128 \
--d_ff 256 \
--n_heads 8 \
--dropout 0.1 \
--activation gelu \
--des 'DC_PatchTST_Exp' \
--itr 1 \
--learning_rate 0.001 \
--train_epochs 100 \
--patience 10 \
--d_model_stage0 $d_model_stage0 \
--depth_enc0 $depth_enc0 \
--depth_enc1 $depth_enc1 \
--target_ratio0 $target_ratio0 \
--target_ratio1 $target_ratio1
# JapaneseVowels dataset
python -u run.py \
--task_name classification \
--is_training 1 \
--root_path ./dataset/JapaneseVowels/ \
--model_id JapaneseVowels \
--model $model_name \
--data UEA \
--e_layers 3 \
--batch_size 8 \
--d_model 128 \
--d_ff 256 \
--n_heads 8 \
--dropout 0.1 \
--activation gelu \
--des 'DC_PatchTST_Exp' \
--itr 1 \
--learning_rate 0.001 \
--train_epochs 100 \
--patience 10 \
--d_model_stage0 $d_model_stage0 \
--depth_enc0 $depth_enc0 \
--depth_enc1 $depth_enc1 \
--target_ratio0 $target_ratio0 \
--target_ratio1 $target_ratio1

View File

@ -0,0 +1,259 @@
#!/bin/bash
# vanillaMamba Classification Training Script for Multiple Datasets
export CUDA_VISIBLE_DEVICES=0
model_name=vanillaMamba
# Create results directory if it doesn't exist
mkdir -p ./results
# UWaveGestureLibrary dataset (seq_len=315, enc_in=3) - use Copy1 config
python -u run.py \
--task_name classification \
--is_training 1 \
--root_path ./dataset/UWaveGestureLibrary/ \
--model_id UWaveGestureLibrary \
--model $model_name \
--data UEA \
--e_layers 2 \
--batch_size 64 \
--seq_len 315 \
--enc_in 3 \
--d_model 128 \
--d_state 64 \
--d_conv 4 \
--expand 2 \
--headdim 128 \
--dropout 0.1 \
--des 'vanillaMamba_UWaveGestureLibrary' \
--itr 1 \
--learning_rate 0.002 \
--train_epochs 150 \
--patience 30 \
--revin 0 | tee ./results/vanillaMamba_UWaveGestureLibrary.log
# EthanolConcentration dataset (seq_len=1751, enc_in=3) - use Copy1 config
python -u run.py \
--task_name classification \
--is_training 3 \
--root_path ./dataset/EthanolConcentration/ \
--model_id EthanolConcentration \
--model $model_name \
--data UEA \
--e_layers 2 \
--batch_size 64 \
--seq_len 1751 \
--enc_in 4 \
--d_model 128 \
--d_state 64 \
--d_conv 4 \
--expand 2 \
--headdim 64 \
--dropout 0.1 \
--des 'vanillaMamba_EthanolConcentration' \
--itr 1 \
--learning_rate 0.001 \
--train_epochs 200 \
--patience 30 \
--revin 0 | tee ./results/vanillaMamba_EthanolConcentration.log
# Handwriting dataset (seq_len=152, enc_in=3) - use Copy1 config
python -u run.py \
--task_name classification \
--is_training 1 \
--root_path ./dataset/Handwriting/ \
--model_id Handwriting \
--model $model_name \
--data UEA \
--e_layers 4 \
--batch_size 64 \
--seq_len 152 \
--enc_in 3 \
--d_model 128 \
--d_state 64 \
--d_conv 4 \
--expand 2 \
--headdim 64 \
--dropout 0.1 \
--des 'vanillaMamba_Handwriting' \
--itr 1 \
--learning_rate 0.001 \
--train_epochs 200 \
--patience 30 \
--revin 0 | tee ./results/vanillaMamba_Handwriting.log
# JapaneseVowels dataset (seq_len=29, enc_in=12) - use Copy1 config
python -u run.py \
--task_name classification \
--is_training 1 \
--root_path ./dataset/JapaneseVowels/ \
--model_id JapaneseVowels \
--model $model_name \
--data UEA \
--e_layers 3 \
--batch_size 64 \
--seq_len 29 \
--enc_in 12 \
--d_model 128 \
--d_state 64 \
--d_conv 4 \
--expand 2 \
--headdim 64 \
--dropout 0.1 \
--des 'vanillaMamba_JapaneseVowels' \
--itr 1 \
--learning_rate 0.0005 \
--train_epochs 100 \
--patience 30 \
--revin 0 | tee ./results/vanillaMamba_JapaneseVowels.log
# PEMS-SF dataset (seq_len=144, enc_in=963) - use Copy1 config
python -u run.py \
--task_name classification \
--is_training 1 \
--root_path ./dataset/PEMS-SF/ \
--model_id PEMS-SF \
--model $model_name \
--data UEA \
--e_layers 3 \
--batch_size 16 \
--seq_len 144 \
--enc_in 963 \
--d_model 128 \
--d_state 64 \
--d_conv 4 \
--expand 2 \
--headdim 64 \
--dropout 0.1 \
--des 'vanillaMamba_PEMS-SF' \
--itr 1 \
--learning_rate 0.001 \
--train_epochs 150 \
--patience 30 \
--revin 0 | tee ./results/vanillaMamba_PEMS-SF.log
# Heartbeat dataset (seq_len=405, enc_in=61) - use original config
python -u run.py \
--task_name classification \
--is_training 1 \
--root_path ./dataset/Heartbeat/ \
--model_id Heartbeat \
--model $model_name \
--data UEA \
--e_layers 3 \
--batch_size 64 \
--seq_len 405 \
--enc_in 61 \
--d_model 128 \
--d_state 64 \
--d_conv 4 \
--expand 2 \
--headdim 64 \
--dropout 0.1 \
--des 'vanillaMamba_Heartbeat' \
--itr 1 \
--learning_rate 0.001 \
--train_epochs 150 \
--patience 10 \
--revin 0 | tee ./results/vanillaMamba_Heartbeat.log
# FaceDetection dataset (seq_len=62, enc_in=144) - use original config
python -u run.py \
--task_name classification \
--is_training 1 \
--root_path ./dataset/FaceDetection/ \
--model_id FaceDetection \
--model $model_name \
--data UEA \
--e_layers 3 \
--batch_size 64 \
--seq_len 62 \
--enc_in 144 \
--d_model 128 \
--d_state 64 \
--d_conv 4 \
--expand 2 \
--headdim 64 \
--dropout 0.1 \
--des 'vanillaMamba_FaceDetection' \
--itr 1 \
--learning_rate 0.001 \
--train_epochs 100 \
--patience 10 \
--revin 0 | tee ./results/vanillaMamba_FaceDetection.log
# SelfRegulationSCP1 dataset (seq_len=896, enc_in=6) - use original config
python -u run.py \
--task_name classification \
--is_training 1 \
--root_path ./dataset/SelfRegulationSCP1/ \
--model_id SelfRegulationSCP1 \
--model $model_name \
--data UEA \
--e_layers 3 \
--batch_size 64 \
--seq_len 896 \
--enc_in 6 \
--d_model 128 \
--d_state 64 \
--d_conv 4 \
--expand 2 \
--headdim 64 \
--dropout 0.1 \
--des 'vanillaMamba_SelfRegulationSCP1' \
--itr 1 \
--learning_rate 0.001 \
--train_epochs 100 \
--patience 10 \
--revin 0 | tee ./results/vanillaMamba_SelfRegulationSCP1.log
# SelfRegulationSCP2 dataset (seq_len=1152, enc_in=7) - use original config
python -u run.py \
--task_name classification \
--is_training 1 \
--root_path ./dataset/SelfRegulationSCP2/ \
--model_id SelfRegulationSCP2 \
--model $model_name \
--data UEA \
--e_layers 3 \
--batch_size 64 \
--seq_len 1152 \
--enc_in 7 \
--d_model 128 \
--d_state 64 \
--d_conv 4 \
--expand 2 \
--headdim 64 \
--dropout 0.1 \
--des 'vanillaMamba_SelfRegulationSCP2' \
--itr 1 \
--learning_rate 0.001 \
--train_epochs 100 \
--patience 10 \
--revin 0 | tee ./results/vanillaMamba_SelfRegulationSCP2.log
# SpokenArabicDigits dataset (seq_len=93, enc_in=13) - use original config
python -u run.py \
--task_name classification \
--is_training 1 \
--root_path ./dataset/SpokenArabicDigits/ \
--model_id SpokenArabicDigits \
--model $model_name \
--data UEA \
--e_layers 3 \
--batch_size 64 \
--seq_len 93 \
--enc_in 13 \
--d_model 128 \
--d_state 64 \
--d_conv 4 \
--expand 2 \
--headdim 64 \
--dropout 0.1 \
--des 'vanillaMamba_SpokenArabicDigits' \
--itr 1 \
--learning_rate 0.001 \
--train_epochs 100 \
--patience 10 \
--revin 0 | tee ./results/vanillaMamba_SpokenArabicDigits.log

View File

@ -0,0 +1,145 @@
#!/bin/bash
# xPatch_SparseChannel Classification Training Script for Multiple Datasets
export CUDA_VISIBLE_DEVICES=0
model_name=xPatch_SparseChannel
# Create results directory if it doesn't exist
mkdir -p ./results
# UWaveGestureLibrary dataset (seq_len=315, enc_in=3, k_graph=3)
python -u run.py \
--task_name classification \
--is_training 1 \
--root_path ./dataset/UWaveGestureLibrary/ \
--model_id UWaveGestureLibrary \
--model $model_name \
--data UEA \
--e_layers 2 \
--batch_size 64 \
--seq_len 315 \
--enc_in 3 \
--d_model 128 \
--d_ff 256 \
--n_heads 16 \
--patch_len 16 \
--stride 8 \
--dropout 0.1 \
--des 'xPatch_SparseChannel_UWaveGestureLibrary' \
--itr 1 \
--learning_rate 0.001 \
--train_epochs 100 \
--patience 30 \
--revin 0 \
--k_graph 3 | tee ./results/xPatch_SparseChannel_UWaveGestureLibrary.log
# EthanolConcentration dataset (seq_len=1751, enc_in=3, k_graph=3)
python -u run.py \
--task_name classification \
--is_training 1 \
--root_path ./dataset/EthanolConcentration/ \
--model_id EthanolConcentration \
--model $model_name \
--data UEA \
--e_layers 2 \
--batch_size 64 \
--seq_len 1751 \
--enc_in 3 \
--d_model 128 \
--d_ff 256 \
--n_heads 16 \
--patch_len 16 \
--stride 8 \
--dropout 0.1 \
--des 'xPatch_SparseChannel_EthanolConcentration' \
--itr 1 \
--learning_rate 0.0005 \
--train_epochs 100 \
--patience 30 \
--revin 0 \
--k_graph 3 | tee ./results/xPatch_SparseChannel_EthanolConcentration.log
# Handwriting dataset (seq_len=152, enc_in=3, k_graph=3)
python -u run.py \
--task_name classification \
--is_training 1 \
--root_path ./dataset/Handwriting/ \
--model_id Handwriting \
--model $model_name \
--data UEA \
--e_layers 2 \
--batch_size 64 \
--seq_len 152 \
--enc_in 3 \
--d_model 128 \
--d_ff 256 \
--n_heads 16 \
--patch_len 16 \
--stride 8 \
--dropout 0.1 \
--des 'xPatch_SparseChannel_Handwriting' \
--itr 1 \
--learning_rate 0.001 \
--train_epochs 100 \
--patience 30 \
--revin 0 \
--k_graph 3 | tee ./results/xPatch_SparseChannel_Handwriting.log
# JapaneseVowels dataset (seq_len=29, enc_in=12, k_graph=8)
python -u run.py \
--task_name classification \
--is_training 1 \
--root_path ./dataset/JapaneseVowels/ \
--model_id JapaneseVowels \
--model $model_name \
--data UEA \
--e_layers 2 \
--batch_size 64 \
--seq_len 29 \
--enc_in 12 \
--d_model 128 \
--d_ff 256 \
--n_heads 16 \
--patch_len 16 \
--stride 8 \
--dropout 0.1 \
--des 'xPatch_SparseChannel_JapaneseVowels' \
--itr 1 \
--learning_rate 0.0005 \
--train_epochs 100 \
--patience 30 \
--revin 0 \
--k_graph 8 | tee ./results/xPatch_SparseChannel_JapaneseVowels.log
# PEMS-SF dataset (seq_len=144, enc_in=963, k_graph=8)
python -u run.py \
--task_name classification \
--is_training 1 \
--root_path ./dataset/PEMS-SF/ \
--model_id PEMS-SF \
--model $model_name \
--data UEA \
--e_layers 2 \
--batch_size 16 \
--seq_len 144 \
--enc_in 963 \
--d_model 128 \
--d_ff 256 \
--n_heads 16 \
--patch_len 16 \
--stride 8 \
--dropout 0.1 \
--des 'xPatch_SparseChannel_PEMS-SF' \
--itr 1 \
--learning_rate 0.0005 \
--train_epochs 100 \
--patience 30 \
--revin 0 \
--k_graph 8 | tee ./results/xPatch_SparseChannel_PEMS-SF.log

View File

@ -1,10 +1,66 @@
#!/bin/bash #!/bin/bash
# xPatch_SparseChannel Classification Training Script for FaceDetection Dataset # xPatch_SparseChannel Classification Training Script for Multiple Datasets
export CUDA_VISIBLE_DEVICES=0 export CUDA_VISIBLE_DEVICES=0
model_name=xPatch_SparseChannel model_name=xPatch_SparseChannel
# Create results directory if it doesn't exist
mkdir -p ./results
# Heartbeat dataset (seq_len=405, enc_in=61, k_graph=8)
python -u run.py \
--task_name classification \
--is_training 1 \
--root_path ./dataset/Heartbeat/ \
--model_id Heartbeat \
--model $model_name \
--data UEA \
--e_layers 2 \
--batch_size 64 \
--seq_len 405 \
--enc_in 61 \
--d_model 128 \
--d_ff 256 \
--n_heads 16 \
--patch_len 16 \
--stride 8 \
--dropout 0.1 \
--des 'xPatch_SparseChannel_Heartbeat' \
--itr 1 \
--learning_rate 0.0005 \
--train_epochs 100 \
--patience 5 \
--revin 0 \
--k_graph 8 | tee ./results/xPatch_SparseChannel_Heartbeat.log
# UWaveGestureLibrary dataset (seq_len=315, enc_in=3, k_graph=3)
python -u run.py \
--task_name classification \
--is_training 1 \
--root_path ./dataset/UWaveGestureLibrary/ \
--model_id UWaveGestureLibrary \
--model $model_name \
--data UEA \
--e_layers 2 \
--batch_size 64 \
--seq_len 315 \
--enc_in 3 \
--d_model 128 \
--d_ff 256 \
--n_heads 16 \
--patch_len 16 \
--stride 8 \
--dropout 0.1 \
--des 'xPatch_SparseChannel_UWaveGestureLibrary' \
--itr 1 \
--learning_rate 0.001 \
--train_epochs 100 \
--patience 30 \
--revin 0 \
--k_graph 3 | tee ./results/xPatch_SparseChannel_UWaveGestureLibrary.log
# FaceDetection dataset (seq_len=62, enc_in=144, k_graph=8)
python -u run.py \ python -u run.py \
--task_name classification \ --task_name classification \
--is_training 1 \ --is_training 1 \
@ -12,21 +68,202 @@ python -u run.py \
--model_id FaceDetection \ --model_id FaceDetection \
--model $model_name \ --model $model_name \
--data UEA \ --data UEA \
--e_layers 3 \ --e_layers 2 \
--batch_size 64 \ --batch_size 64 \
--seq_len 62 \ --seq_len 62 \
--enc_in 144 \ --enc_in 144 \
--d_model 128 \ --d_model 128 \
--d_ff 256 \ --d_ff 256 \
--n_heads 8 \ --n_heads 16 \
--patch_len 16 \ --patch_len 16 \
--stride 8 \ --stride 8 \
--moving_avg 25 \
--dropout 0.1 \ --dropout 0.1 \
--des 'xPatch_SparseChannel_FaceDetection' \ --des 'xPatch_SparseChannel_FaceDetection' \
--itr 1 \ --itr 1 \
--learning_rate 0.0005 \ --learning_rate 0.0005 \
--train_epochs 100 \ --train_epochs 100 \
--patience 5 \ --patience 5 \
--revin 1 \ --revin 0 \
--k_graph 8 --k_graph 8 | tee ./results/xPatch_SparseChannel_FaceDetection.log
# EthanolConcentration dataset (seq_len=1751, enc_in=3, k_graph=3)
python -u run.py \
--task_name classification \
--is_training 1 \
--root_path ./dataset/EthanolConcentration/ \
--model_id EthanolConcentration \
--model $model_name \
--data UEA \
--e_layers 2 \
--batch_size 64 \
--seq_len 1751 \
--enc_in 3 \
--d_model 128 \
--d_ff 256 \
--n_heads 16 \
--patch_len 16 \
--stride 8 \
--dropout 0.1 \
--des 'xPatch_SparseChannel_EthanolConcentration' \
--itr 1 \
--learning_rate 0.0005 \
--train_epochs 100 \
--patience 30 \
--revin 0 \
--k_graph 3 | tee ./results/xPatch_SparseChannel_EthanolConcentration.log
# Handwriting dataset (seq_len=152, enc_in=3, k_graph=3)
python -u run.py \
--task_name classification \
--is_training 1 \
--root_path ./dataset/Handwriting/ \
--model_id Handwriting \
--model $model_name \
--data UEA \
--e_layers 2 \
--batch_size 64 \
--seq_len 152 \
--enc_in 3 \
--d_model 128 \
--d_ff 256 \
--n_heads 16 \
--patch_len 16 \
--stride 8 \
--dropout 0.1 \
--des 'xPatch_SparseChannel_Handwriting' \
--itr 1 \
--learning_rate 0.001 \
--train_epochs 100 \
--patience 30 \
--revin 0 \
--k_graph 3 | tee ./results/xPatch_SparseChannel_Handwriting.log
# JapaneseVowels dataset (seq_len=29, enc_in=12, k_graph=8)
python -u run.py \
--task_name classification \
--is_training 1 \
--root_path ./dataset/JapaneseVowels/ \
--model_id JapaneseVowels \
--model $model_name \
--data UEA \
--e_layers 2 \
--batch_size 64 \
--seq_len 29 \
--enc_in 12 \
--d_model 128 \
--d_ff 256 \
--n_heads 16 \
--patch_len 16 \
--stride 8 \
--dropout 0.1 \
--des 'xPatch_SparseChannel_JapaneseVowels' \
--itr 1 \
--learning_rate 0.0005 \
--train_epochs 100 \
--patience 30 \
--revin 0 \
--k_graph 8 | tee ./results/xPatch_SparseChannel_JapaneseVowels.log
# SelfRegulationSCP1 dataset (seq_len=896, enc_in=6, k_graph=6)
python -u run.py \
--task_name classification \
--is_training 1 \
--root_path ./dataset/SelfRegulationSCP1/ \
--model_id SelfRegulationSCP1 \
--model $model_name \
--data UEA \
--e_layers 2 \
--batch_size 64 \
--seq_len 896 \
--enc_in 6 \
--d_model 128 \
--d_ff 256 \
--n_heads 16 \
--patch_len 16 \
--stride 8 \
--dropout 0.1 \
--des 'xPatch_SparseChannel_SelfRegulationSCP1' \
--itr 1 \
--learning_rate 0.0005 \
--train_epochs 100 \
--patience 5 \
--revin 0 \
--k_graph 6 | tee ./results/xPatch_SparseChannel_SelfRegulationSCP1.log
# SelfRegulationSCP2 dataset (seq_len=1152, enc_in=7, k_graph=7)
python -u run.py \
--task_name classification \
--is_training 1 \
--root_path ./dataset/SelfRegulationSCP2/ \
--model_id SelfRegulationSCP2 \
--model $model_name \
--data UEA \
--e_layers 2 \
--batch_size 64 \
--seq_len 1152 \
--enc_in 7 \
--d_model 128 \
--d_ff 256 \
--n_heads 16 \
--patch_len 16 \
--stride 8 \
--dropout 0.1 \
--des 'xPatch_SparseChannel_SelfRegulationSCP2' \
--itr 1 \
--learning_rate 0.0005 \
--train_epochs 100 \
--patience 5 \
--revin 0 \
--k_graph 7 | tee ./results/xPatch_SparseChannel_SelfRegulationSCP2.log
# SpokenArabicDigits dataset (seq_len=93, enc_in=13, k_graph=8)
python -u run.py \
--task_name classification \
--is_training 1 \
--root_path ./dataset/SpokenArabicDigits/ \
--model_id SpokenArabicDigits \
--model $model_name \
--data UEA \
--e_layers 2 \
--batch_size 64 \
--seq_len 93 \
--enc_in 13 \
--d_model 128 \
--d_ff 256 \
--n_heads 16 \
--patch_len 16 \
--stride 8 \
--dropout 0.1 \
--des 'xPatch_SparseChannel_SpokenArabicDigits' \
--itr 1 \
--learning_rate 0.0005 \
--train_epochs 100 \
--patience 5 \
--revin 0 \
--k_graph 8 | tee ./results/xPatch_SparseChannel_SpokenArabicDigits.log
# PEMS-SF dataset (seq_len=144, enc_in=963, k_graph=8)
python -u run.py \
--task_name classification \
--is_training 1 \
--root_path ./dataset/PEMS-SF/ \
--model_id PEMS-SF \
--model $model_name \
--data UEA \
--e_layers 2 \
--batch_size 16 \
--seq_len 144 \
--enc_in 963 \
--d_model 128 \
--d_ff 256 \
--n_heads 16 \
--patch_len 16 \
--stride 8 \
--dropout 0.1 \
--des 'xPatch_SparseChannel_PEMS-SF' \
--itr 1 \
--learning_rate 0.0005 \
--train_epochs 100 \
--patience 30 \
--revin 0 \
--k_graph 8 | tee ./results/xPatch_SparseChannel_PEMS-SF.log

View File

@ -0,0 +1,40 @@
#!/bin/bash
model_name=xPatch_SparseChannel
# Traffic dataset testing
for pred_len in 96 192 336 720
do
echo "Testing Traffic dataset with prediction length: $pred_len"
python -u run.py \
--task_name long_term_forecast \
--is_training 0 \
--root_path ./dataset/traffic/ \
--data_path traffic.csv \
--model_id traffic_$pred_len'_'$pred_len \
--model $model_name \
--data custom \
--features M \
--seq_len 96 \
--label_len 48 \
--pred_len $pred_len \
--e_layers 2 \
--d_layers 1 \
--enc_in 862 \
--batch_size 256 \
--c_out 862 \
--d_model 128 \
--d_ff 256 \
--n_heads 16 \
--patch_len 16 \
--stride 8 \
--k_graph 8 \
--dropout 0.1 \
--revin 1 \
--des 'Exp' \
--itr 1
echo "Finished testing for prediction length: $pred_len"
echo "================================"
done
echo "All Traffic dataset testing completed!"

View File

@ -0,0 +1,251 @@
#!/bin/bash
model_name=vanillaMamba
# ETTm1 dataset
for pred_len in 96 192 336 720
do
python -u run.py \
--task_name long_term_forecast \
--is_training 1 \
--root_path ./dataset/ETT-small/ \
--data_path ETTm1.csv \
--model_id ETTm1_$pred_len'_'$pred_len \
--model $model_name \
--data ETTm1 \
--features M \
--seq_len 96 \
--label_len 48 \
--pred_len $pred_len \
--e_layers 2 \
--d_layers 1 \
--enc_in 7 \
--c_out 7 \
--d_model 128 \
--expand 2 \
--d_conv 4 \
--d_state 64 \
--headdim 64 \
--ngroups 1 \
--chunk_size 256 \
--dropout 0.1 \
--des 'Exp' \
--itr 1
done
# ETTm2 dataset
for pred_len in 96 192 336 720
do
python -u run.py \
--task_name long_term_forecast \
--is_training 1 \
--root_path ./dataset/ETT-small/ \
--data_path ETTm2.csv \
--model_id ETTm2_$pred_len'_'$pred_len \
--model $model_name \
--data ETTm2 \
--features M \
--seq_len 96 \
--label_len 48 \
--pred_len $pred_len \
--e_layers 2 \
--d_layers 1 \
--enc_in 7 \
--c_out 7 \
--d_model 128 \
--expand 2 \
--d_conv 4 \
--d_state 64 \
--headdim 64 \
--ngroups 1 \
--chunk_size 256 \
--dropout 0.1 \
--des 'Exp' \
--itr 1
done
# ETTh1 dataset
for pred_len in 96 192 336 720
do
python -u run.py \
--task_name long_term_forecast \
--is_training 1 \
--root_path ./dataset/ETT-small/ \
--data_path ETTh1.csv \
--model_id ETTh1_$pred_len'_'$pred_len \
--model $model_name \
--data ETTh1 \
--features M \
--seq_len 96 \
--label_len 48 \
--pred_len $pred_len \
--e_layers 2 \
--d_layers 1 \
--enc_in 7 \
--c_out 7 \
--d_model 128 \
--expand 2 \
--d_conv 4 \
--d_state 64 \
--headdim 64 \
--ngroups 1 \
--chunk_size 256 \
--dropout 0.1 \
--des 'Exp' \
--itr 1
done
# ETTh2 dataset
for pred_len in 96 192 336 720
do
python -u run.py \
--task_name long_term_forecast \
--is_training 1 \
--root_path ./dataset/ETT-small/ \
--data_path ETTh2.csv \
--model_id ETTh2_$pred_len'_'$pred_len \
--model $model_name \
--data ETTh2 \
--features M \
--seq_len 96 \
--label_len 48 \
--pred_len $pred_len \
--e_layers 2 \
--d_layers 1 \
--enc_in 7 \
--c_out 7 \
--d_model 128 \
--expand 2 \
--d_conv 4 \
--d_state 64 \
--headdim 64 \
--ngroups 1 \
--chunk_size 256 \
--dropout 0.1 \
--des 'Exp' \
--itr 1
done
# Weather dataset
for pred_len in 96 192 336 720
do
python -u run.py \
--task_name long_term_forecast \
--is_training 1 \
--root_path ./dataset/weather/ \
--data_path weather.csv \
--model_id weather_$pred_len'_'$pred_len \
--model $model_name \
--data custom \
--features M \
--seq_len 96 \
--label_len 48 \
--pred_len $pred_len \
--e_layers 2 \
--d_layers 1 \
--enc_in 21 \
--c_out 21 \
--d_model 128 \
--expand 2 \
--d_conv 4 \
--d_state 64 \
--headdim 64 \
--ngroups 1 \
--chunk_size 256 \
--dropout 0.1 \
--des 'Exp' \
--itr 1
done
# ECL dataset
for pred_len in 96 192 336 720
do
python -u run.py \
--task_name long_term_forecast \
--is_training 1 \
--root_path ./dataset/electricity/ \
--data_path electricity.csv \
--model_id ECL_$pred_len'_'$pred_len \
--model $model_name \
--data custom \
--features M \
--seq_len 96 \
--label_len 48 \
--pred_len $pred_len \
--e_layers 2 \
--d_layers 1 \
--enc_in 321 \
--c_out 321 \
--d_model 128 \
--expand 2 \
--d_conv 4 \
--d_state 64 \
--headdim 64 \
--ngroups 1 \
--chunk_size 256 \
--dropout 0.1 \
--des 'Exp' \
--itr 1
done
# Traffic dataset
for pred_len in 96 192 336 720
do
python -u run.py \
--task_name long_term_forecast \
--is_training 1 \
--root_path ./dataset/traffic/ \
--data_path traffic.csv \
--model_id traffic_$pred_len'_'$pred_len \
--model $model_name \
--data custom \
--features M \
--seq_len 96 \
--label_len 48 \
--pred_len $pred_len \
--e_layers 2 \
--d_layers 1 \
--enc_in 862 \
--c_out 862 \
--d_model 128 \
--expand 2 \
--d_conv 4 \
--d_state 64 \
--headdim 64 \
--ngroups 1 \
--chunk_size 256 \
--dropout 0.1 \
--des 'Exp' \
--itr 1
done
# Exchange dataset
for pred_len in 96 192 336 720
do
python -u run.py \
--task_name long_term_forecast \
--is_training 1 \
--root_path ./dataset/exchange_rate/ \
--data_path exchange_rate.csv \
--model_id Exchange_$pred_len'_'$pred_len \
--model $model_name \
--data custom \
--features M \
--seq_len 96 \
--label_len 48 \
--pred_len $pred_len \
--e_layers 2 \
--d_layers 1 \
--enc_in 8 \
--c_out 8 \
--d_model 128 \
--expand 2 \
--d_conv 4 \
--d_state 64 \
--headdim 64 \
--ngroups 1 \
--chunk_size 256 \
--dropout 0.1 \
--des 'Exp' \
--itr 1
done

View File

@ -0,0 +1,150 @@
#!/bin/bash
#export CUDA_VISIBLE_DEVICES=0
model_name=xPatch_SparseChannel
seq_len=96
pred_len=12
learning_rate=0.003
d_model=128
d_ff=256
batch_size=128
train_epochs=10
patience=10
# PEMS03 dataset
python -u run.py \
--task_name long_term_forecast \
--is_training 1 \
--root_path ./dataset/PEMS/ \
--data_path PEMS03.npz \
--model_id PEMS03 \
--model $model_name \
--data PEMS \
--features M \
--seq_len $seq_len \
--label_len 0 \
--pred_len $pred_len \
--e_layers 2 \
--d_layers 1 \
--enc_in 358 \
--dec_in 358 \
--c_out 358 \
--lradj 'sigmoid' \
--d_model $d_model \
--d_ff $d_ff \
--n_heads 16 \
--patch_len 16 \
--stride 8 \
--k_graph 8 \
--dropout 0.1 \
--revin 1 \
--batch_size $batch_size \
--learning_rate $learning_rate \
--train_epochs $train_epochs \
--patience $patience \
--des 'Exp' \
--itr 1
# PEMS04 dataset
python -u run.py \
--task_name long_term_forecast \
--is_training 1 \
--root_path ./dataset/PEMS/ \
--data_path PEMS04.npz \
--model_id PEMS04 \
--model $model_name \
--data PEMS \
--features M \
--seq_len $seq_len \
--label_len 0 \
--pred_len $pred_len \
--e_layers 2 \
--d_layers 1 \
--enc_in 307 \
--dec_in 307 \
--c_out 307 \
--lradj 'sigmoid' \
--d_model $d_model \
--d_ff $d_ff \
--n_heads 16 \
--patch_len 16 \
--stride 8 \
--k_graph 8 \
--dropout 0.1 \
--revin 1 \
--batch_size $batch_size \
--learning_rate $learning_rate \
--train_epochs $train_epochs \
--patience $patience \
--des 'Exp' \
--itr 1
# PEMS07 dataset
python -u run.py \
--task_name long_term_forecast \
--is_training 1 \
--root_path ./dataset/PEMS/ \
--data_path PEMS07.npz \
--model_id PEMS07 \
--model $model_name \
--data PEMS \
--features M \
--seq_len $seq_len \
--label_len 0 \
--pred_len $pred_len \
--e_layers 2 \
--d_layers 1 \
--enc_in 883 \
--dec_in 883 \
--c_out 883 \
--lradj 'sigmoid' \
--d_model $d_model \
--d_ff $d_ff \
--n_heads 16 \
--patch_len 16 \
--stride 8 \
--k_graph 8 \
--dropout 0.1 \
--revin 1 \
--batch_size $batch_size \
--learning_rate $learning_rate \
--train_epochs $train_epochs \
--patience $patience \
--des 'Exp' \
--itr 1
# PEMS08 dataset
python -u run.py \
--task_name long_term_forecast \
--is_training 1 \
--root_path ./dataset/PEMS/ \
--data_path PEMS08.npz \
--model_id PEMS08 \
--model $model_name \
--data PEMS \
--features M \
--seq_len $seq_len \
--label_len 0 \
--pred_len $pred_len \
--e_layers 2 \
--d_layers 1 \
--enc_in 170 \
--dec_in 170 \
--c_out 170 \
--lradj 'sigmoid' \
--d_model $d_model \
--d_ff $d_ff \
--n_heads 16 \
--patch_len 16 \
--stride 8 \
--k_graph 8 \
--dropout 0.1 \
--revin 1 \
--batch_size $batch_size \
--learning_rate $learning_rate \
--train_epochs $train_epochs \
--patience $patience \
--des 'Exp' \
--itr 1

View File

@ -0,0 +1,302 @@
#!/bin/bash
model_name=xPatch_SparseChannel
# Weather dataset
for pred_len in 96 192 336 720
do
python -u run.py \
--task_name long_term_forecast \
--is_training 1 \
--root_path ./dataset/weather/ \
--data_path weather.csv \
--model_id weather_$pred_len'_'$pred_len \
--model $model_name \
--data custom \
--features M \
--seq_len 96 \
--label_len 48 \
--pred_len $pred_len \
--e_layers 2 \
--d_layers 1 \
--enc_in 21 \
--c_out 21 \
--d_model 128 \
--lradj 'sigmoid' \
--train_epochs 20 \
--d_ff 256 \
--n_heads 16 \
--patch_len 16 \
--stride 8 \
--k_graph 8 \
--dropout 0.1 \
--revin 1 \
--des 'Exp' \
--itr 1 \
--season_encoder 'Transformer' \
--thr_graph 0.6 \
--symmetric_graph 1 \
--degree_rescale 'none' \
--gate_temperature 0.6667 \
--tau_attn 1.0 \
--season_l0_lambda 0.0000 \
--thr_graph_min 0.1 \
--thr_graph_max 0.6 \
--thr_graph_steps 1000 \
--thr_graph_schedule 'cosine'
done
# Exchange dataset
for pred_len in 96 192 336 720
do
python -u run.py \
--task_name long_term_forecast \
--is_training 1 \
--root_path ./dataset/exchange_rate/ \
--data_path exchange_rate.csv \
--model_id Exchange_$pred_len'_'$pred_len \
--model $model_name \
--data custom \
--features M \
--seq_len 96 \
--label_len 48 \
--pred_len $pred_len \
--e_layers 2 \
--d_layers 1 \
--enc_in 8 \
--c_out 8 \
--d_model 128 \
--lradj 'sigmoid' \
--d_ff 256 \
--n_heads 16 \
--patch_len 16 \
--stride 8 \
--k_graph 7 \
--dropout 0.1 \
--revin 1 \
--des 'Exp' \
--itr 1 \
--season_encoder 'Transformer' \
--thr_graph 0.6 \
--symmetric_graph 1 \
--degree_rescale 'none' \
--gate_temperature 0.6667 \
--tau_attn 1.0 \
--season_l0_lambda 0.0000 \
--thr_graph_min 0.1 \
--thr_graph_max 0.6 \
--thr_graph_steps 1000 \
--thr_graph_schedule 'cosine'
done
# ETTm1 dataset
for pred_len in 96 192 336 720
do
python -u run.py \
--task_name long_term_forecast \
--is_training 1 \
--root_path ./dataset/ETT-small/ \
--data_path ETTm1.csv \
--model_id ETTm1_$pred_len'_'$pred_len \
--model $model_name \
--data ETTm1 \
--features M \
--seq_len 96 \
--label_len 48 \
--pred_len $pred_len \
--e_layers 2 \
--d_layers 1 \
--enc_in 7 \
--c_out 7 \
--d_model 128 \
--lradj 'sigmoid' \
--d_ff 256 \
--n_heads 16 \
--patch_len 16 \
--stride 8 \
--k_graph 5 \
--dropout 0.1 \
--revin 1 \
--des 'Exp' \
--itr 1 \
--season_encoder 'Transformer' \
--thr_graph 0.6 \
--symmetric_graph 1 \
--degree_rescale 'none' \
--gate_temperature 0.6667 \
--tau_attn 1.0 \
--season_l0_lambda 0.0000 \
--thr_graph_min 0.1 \
--thr_graph_max 0.6 \
--thr_graph_steps 1000 \
--thr_graph_schedule 'cosine'
done
# ETTm2 dataset
for pred_len in 96 192 336 720
do
python -u run.py \
--task_name long_term_forecast \
--is_training 1 \
--root_path ./dataset/ETT-small/ \
--data_path ETTm2.csv \
--model_id ETTm2_$pred_len'_'$pred_len \
--model $model_name \
--data ETTm2 \
--features M \
--seq_len 96 \
--label_len 48 \
--pred_len $pred_len \
--e_layers 2 \
--d_layers 1 \
--enc_in 7 \
--c_out 7 \
--d_model 128 \
--lradj 'sigmoid' \
--d_ff 256 \
--n_heads 16 \
--patch_len 16 \
--stride 8 \
--k_graph 7 \
--dropout 0.1 \
--revin 1 \
--des 'Exp' \
--itr 1 \
--season_encoder 'Transformer' \
--thr_graph 0.6 \
--symmetric_graph 1 \
--degree_rescale 'none' \
--gate_temperature 0.6667 \
--tau_attn 1.0 \
--season_l0_lambda 0.0000
done
# ETTh1 dataset
for pred_len in 96 192 336 720
do
python -u run.py \
--task_name long_term_forecast \
--is_training 1 \
--root_path ./dataset/ETT-small/ \
--data_path ETTh1.csv \
--model_id ETTh1_$pred_len'_'$pred_len \
--model $model_name \
--data ETTh1 \
--features M \
--seq_len 96 \
--label_len 48 \
--pred_len $pred_len \
--e_layers 2 \
--d_layers 1 \
--enc_in 7 \
--c_out 7 \
--d_model 128 \
--lradj 'sigmoid' \
--d_ff 256 \
--n_heads 16 \
--patch_len 16 \
--stride 8 \
--k_graph 7 \
--dropout 0.1 \
--revin 1 \
--des 'Exp' \
--itr 1
done
# ETTh2 dataset
for pred_len in 96 192 336 720
do
python -u run.py \
--task_name long_term_forecast \
--is_training 1 \
--root_path ./dataset/ETT-small/ \
--data_path ETTh2.csv \
--model_id ETTh2_$pred_len'_'$pred_len \
--model $model_name \
--data ETTh2 \
--features M \
--seq_len 96 \
--label_len 48 \
--pred_len $pred_len \
--e_layers 2 \
--d_layers 1 \
--enc_in 7 \
--c_out 7 \
--d_model 128 \
--lradj 'sigmoid' \
--d_ff 256 \
--n_heads 16 \
--patch_len 16 \
--stride 8 \
--k_graph 7 \
--dropout 0.1 \
--revin 1 \
--des 'Exp' \
--itr 1
done
# ECL dataset
for pred_len in 96 192 336 720
do
python -u run.py \
--task_name long_term_forecast \
--is_training 1 \
--root_path ./dataset/electricity/ \
--data_path electricity.csv \
--model_id ECL_$pred_len'_'$pred_len \
--model $model_name \
--data custom \
--features M \
--seq_len 96 \
--label_len 48 \
--pred_len $pred_len \
--e_layers 2 \
--d_layers 1 \
--enc_in 321 \
--c_out 321 \
--d_model 128 \
--d_ff 256 \
--n_heads 16 \
--patch_len 16 \
--stride 8 \
--k_graph 8 \
--dropout 0.1 \
--revin 1 \
--des 'Exp' \
--itr 1
done
# Traffic dataset
for pred_len in 96 192 336 720
do
python -u run.py \
--task_name long_term_forecast \
--is_training 1 \
--root_path ./dataset/traffic/ \
--data_path traffic.csv \
--model_id traffic_$pred_len'_'$pred_len \
--model $model_name \
--data custom \
--features M \
--seq_len 96 \
--label_len 48 \
--pred_len $pred_len \
--e_layers 2 \
--d_layers 1 \
--enc_in 862 \
--c_out 862 \
--d_model 128 \
--d_ff 256 \
--n_heads 16 \
--patch_len 16 \
--stride 8 \
--k_graph 8 \
--dropout 0.1 \
--revin 1 \
--des 'Exp' \
--itr 1
done

View File

@ -0,0 +1,85 @@
#!/bin/bash
model_name=xPatch_SparseChannel
# Traffic dataset
for pred_len in 96 192 336 720
do
python -u run.py \
--task_name long_term_forecast \
--is_training 1 \
--root_path ./dataset/traffic/ \
--data_path traffic.csv \
--model_id traffic_$pred_len'_'$pred_len \
--model $model_name \
--data custom \
--features M \
--seq_len 96 \
--label_len 48 \
--pred_len $pred_len \
--e_layers 2 \
--d_layers 1 \
--enc_in 862 \
--batch_size 256 \
--learning_rate 0.0005 \
--use_multi_gpu True \
--lradj 'sigmoid' \
--train_epochs 50 \
--patience 5 \
--c_out 862 \
--d_model 128 \
--d_ff 256 \
--n_heads 16 \
--patch_len 16 \
--stride 8 \
--k_graph 10 \
--dropout 0.1 \
--revin 1 \
--des 'Exp' \
--itr 1 \
--season_encoder 'Transformer' \
--thr_graph 0.6 \
--symmetric_graph 1 \
--degree_rescale 'none' \
--gate_temperature 0.6667 \
--tau_attn 1.0 \
--season_l0_lambda 0.0000
done
# ECL dataset
for pred_len in 96 192 336 720
do
python -u run.py \
--task_name long_term_forecast \
--is_training 1 \
--root_path ./dataset/electricity/ \
--data_path electricity.csv \
--model_id ECL_$pred_len'_'$pred_len \
--model $model_name \
--data custom \
--features M \
--seq_len 96 \
--label_len 48 \
--pred_len $pred_len \
--e_layers 2 \
--d_layers 1 \
--enc_in 321 \
--batch_size 512 \
--learning_rate 0.001 \
--use_multi_gpu True \
--lradj 'sigmoid' \
--train_epochs 50 \
--patience 5 \
--c_out 321 \
--d_model 128 \
--d_ff 256 \
--n_heads 16 \
--patch_len 16 \
--stride 8 \
--k_graph 8 \
--dropout 0.1 \
--revin 1 \
--des 'Exp' \
--itr 1
done

View File

@ -0,0 +1,165 @@
#!/bin/bash
model_name=vanillaMamba
# M4 Monthly
python -u run.py \
--task_name short_term_forecast \
--is_training 1 \
--root_path ./dataset/m4 \
--seasonal_patterns 'Monthly' \
--model_id m4_Monthly \
--model $model_name \
--data m4 \
--features M \
--e_layers 2 \
--enc_in 1 \
--c_out 1 \
--batch_size 16 \
--d_model 128 \
--expand 2 \
--d_conv 4 \
--d_state 64 \
--headdim 64 \
--ngroups 1 \
--chunk_size 256 \
--dropout 0.1 \
--des 'Exp' \
--itr 1 \
--learning_rate 0.001 \
--loss 'SMAPE'
# M4 Yearly
python -u run.py \
--task_name short_term_forecast \
--is_training 1 \
--root_path ./dataset/m4 \
--seasonal_patterns 'Yearly' \
--model_id m4_Yearly \
--model $model_name \
--data m4 \
--features M \
--e_layers 2 \
--enc_in 1 \
--c_out 1 \
--batch_size 16 \
--d_model 128 \
--expand 2 \
--d_conv 4 \
--d_state 64 \
--headdim 64 \
--ngroups 1 \
--chunk_size 256 \
--dropout 0.1 \
--des 'Exp' \
--itr 1 \
--learning_rate 0.001 \
--loss 'SMAPE'
# M4 Quarterly
python -u run.py \
--task_name short_term_forecast \
--is_training 1 \
--root_path ./dataset/m4 \
--seasonal_patterns 'Quarterly' \
--model_id m4_Quarterly \
--model $model_name \
--data m4 \
--features M \
--e_layers 2 \
--enc_in 1 \
--c_out 1 \
--batch_size 16 \
--d_model 128 \
--expand 2 \
--d_conv 4 \
--d_state 64 \
--headdim 64 \
--ngroups 1 \
--chunk_size 256 \
--dropout 0.1 \
--des 'Exp' \
--itr 1 \
--learning_rate 0.001 \
--loss 'SMAPE'
# M4 Weekly
python -u run.py \
--task_name short_term_forecast \
--is_training 1 \
--root_path ./dataset/m4 \
--seasonal_patterns 'Weekly' \
--model_id m4_Weekly \
--model $model_name \
--data m4 \
--features M \
--e_layers 2 \
--enc_in 1 \
--c_out 1 \
--batch_size 16 \
--d_model 128 \
--expand 2 \
--d_conv 4 \
--d_state 64 \
--headdim 64 \
--ngroups 1 \
--chunk_size 256 \
--dropout 0.1 \
--des 'Exp' \
--itr 1 \
--learning_rate 0.001 \
--loss 'SMAPE'
# M4 Daily
python -u run.py \
--task_name short_term_forecast \
--is_training 1 \
--root_path ./dataset/m4 \
--seasonal_patterns 'Daily' \
--model_id m4_Daily \
--model $model_name \
--data m4 \
--features M \
--e_layers 2 \
--enc_in 1 \
--c_out 1 \
--batch_size 16 \
--d_model 128 \
--expand 2 \
--d_conv 4 \
--d_state 64 \
--headdim 64 \
--ngroups 1 \
--chunk_size 256 \
--dropout 0.1 \
--des 'Exp' \
--itr 1 \
--learning_rate 0.001 \
--loss 'SMAPE'
# M4 Hourly
python -u run.py \
--task_name short_term_forecast \
--is_training 1 \
--root_path ./dataset/m4 \
--seasonal_patterns 'Hourly' \
--model_id m4_Hourly \
--model $model_name \
--data m4 \
--features M \
--e_layers 2 \
--enc_in 1 \
--c_out 1 \
--batch_size 16 \
--d_model 128 \
--expand 2 \
--d_conv 4 \
--d_state 64 \
--headdim 64 \
--ngroups 1 \
--chunk_size 256 \
--dropout 0.1 \
--des 'Exp' \
--itr 1 \
--learning_rate 0.001 \
--loss 'SMAPE'

View File

@ -0,0 +1,165 @@
#!/bin/bash
model_name=xPatch_SparseChannel
# M4 Monthly
python -u run.py \
--task_name short_term_forecast \
--is_training 1 \
--root_path ./dataset/m4 \
--seasonal_patterns 'Monthly' \
--model_id m4_Monthly \
--model $model_name \
--data m4 \
--features M \
--e_layers 2 \
--enc_in 1 \
--c_out 1 \
--batch_size 16 \
--d_model 128 \
--d_ff 256 \
--n_heads 16 \
--patch_len 16 \
--stride 8 \
--k_graph 1 \
--dropout 0.1 \
--revin 1 \
--des 'Exp' \
--itr 1 \
--learning_rate 0.001 \
--loss 'SMAPE'
# M4 Yearly
python -u run.py \
--task_name short_term_forecast \
--is_training 1 \
--root_path ./dataset/m4 \
--seasonal_patterns 'Yearly' \
--model_id m4_Yearly \
--model $model_name \
--data m4 \
--features M \
--e_layers 2 \
--enc_in 1 \
--c_out 1 \
--batch_size 16 \
--d_model 128 \
--d_ff 256 \
--n_heads 16 \
--patch_len 16 \
--stride 8 \
--k_graph 1 \
--dropout 0.1 \
--revin 1 \
--des 'Exp' \
--itr 1 \
--learning_rate 0.001 \
--loss 'SMAPE'
# M4 Quarterly
python -u run.py \
--task_name short_term_forecast \
--is_training 1 \
--root_path ./dataset/m4 \
--seasonal_patterns 'Quarterly' \
--model_id m4_Quarterly \
--model $model_name \
--data m4 \
--features M \
--e_layers 2 \
--enc_in 1 \
--c_out 1 \
--batch_size 16 \
--d_model 128 \
--d_ff 256 \
--n_heads 16 \
--patch_len 16 \
--stride 8 \
--k_graph 1 \
--dropout 0.1 \
--revin 1 \
--des 'Exp' \
--itr 1 \
--learning_rate 0.001 \
--loss 'SMAPE'
# M4 Weekly
python -u run.py \
--task_name short_term_forecast \
--is_training 1 \
--root_path ./dataset/m4 \
--seasonal_patterns 'Weekly' \
--model_id m4_Weekly \
--model $model_name \
--data m4 \
--features M \
--e_layers 2 \
--enc_in 1 \
--c_out 1 \
--batch_size 16 \
--d_model 128 \
--d_ff 256 \
--n_heads 16 \
--patch_len 16 \
--stride 8 \
--k_graph 1 \
--dropout 0.1 \
--revin 1 \
--des 'Exp' \
--itr 1 \
--learning_rate 0.001 \
--loss 'SMAPE'
# M4 Daily
python -u run.py \
--task_name short_term_forecast \
--is_training 1 \
--root_path ./dataset/m4 \
--seasonal_patterns 'Daily' \
--model_id m4_Daily \
--model $model_name \
--data m4 \
--features M \
--e_layers 2 \
--enc_in 1 \
--c_out 1 \
--batch_size 16 \
--d_model 128 \
--d_ff 256 \
--n_heads 16 \
--patch_len 16 \
--stride 8 \
--k_graph 1 \
--dropout 0.1 \
--revin 1 \
--des 'Exp' \
--itr 1 \
--learning_rate 0.001 \
--loss 'SMAPE'
# M4 Hourly
python -u run.py \
--task_name short_term_forecast \
--is_training 1 \
--root_path ./dataset/m4 \
--seasonal_patterns 'Hourly' \
--model_id m4_Hourly \
--model $model_name \
--data m4 \
--features M \
--e_layers 2 \
--enc_in 1 \
--c_out 1 \
--batch_size 16 \
--d_model 128 \
--d_ff 256 \
--n_heads 16 \
--patch_len 16 \
--stride 8 \
--k_graph 1 \
--dropout 0.1 \
--revin 1 \
--des 'Exp' \
--itr 1 \
--learning_rate 0.001 \
--loss 'SMAPE'

209
test_DC_hnet.py Normal file
View File

@ -0,0 +1,209 @@
#!/usr/bin/env python3
"""
测试DC_hnet模型的脚本
用于验证时间序列分类模型能否正常运行并得到期望的输出形状
"""
import torch
import torch.nn.functional as F
import sys
import os
# 添加当前目录到Python路径
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from models.DC_hnet import HierEncodersSingleMainConfig, HierEncodersSingleMainClassifier
def test_dc_hnet_model():
"""测试DC_hnet时间序列分类模型"""
print("=" * 60)
print("测试DC_hnet时间序列分类模型")
print("=" * 60)
# 设置设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")
# 模型参数配置
B, L, N = 8, 512, 6 # batch_size, seq_length, num_channels
num_classes = 10 # 分类数
d_models = [64, 128, 256] # 各层维度,单调递增
print(f"输入形状: (B={B}, L={L}, N={N})")
print(f"分类数: {num_classes}")
print(f"模型维度: {d_models}")
# 编码器配置每层都是Mamba
encoder_cfg_per_stage = [
dict(arch="m", height=2), # stage 0: Mamba2, 2层
dict(arch="m", height=3), # stage 1: Mamba2, 3层
]
# 主网络配置使用Transformer
main_cfg = dict(
arch="T", height=4 # Transformer, 4层
)
# 压缩目标
target_compression_N_per_stage = [2, 3] # 每层压缩比例
# 创建配置
cfg = HierEncodersSingleMainConfig(
num_channels=N,
d_models=d_models,
num_classes=num_classes,
encoder_cfg_per_stage=encoder_cfg_per_stage,
main_cfg=main_cfg,
target_compression_N_per_stage=target_compression_N_per_stage,
share_channel=True, # 通道间共享编码器
fusion_across_channels="mean", # 通道融合方式
dropout=0.1,
)
print("配置创建完成")
try:
# 创建模型 - 设置正确的dtype以兼容flash attention
dtype = torch.bfloat16 if device.type == "cuda" else torch.float32
model = HierEncodersSingleMainClassifier(cfg, device=device, dtype=dtype)
model = model.to(device)
print(f"模型创建成功,参数量: {sum(p.numel() for p in model.parameters()):,}")
# 创建随机输入数据 - 使用bfloat16以兼容flash attention
dtype = torch.bfloat16 if device.type == "cuda" else torch.float32
x = torch.randn(B, L, N, device=device, dtype=dtype)
mask = torch.ones(B, L, dtype=torch.bool, device=device)
print(f"输入数据形状: {x.shape}, 数据类型: {x.dtype}")
# 前向传播测试
print("\n开始前向传播测试...")
model.eval()
with torch.no_grad():
logits, seq_debug, aux = model(x, mask=mask, return_seq=False)
print(f"✅ 前向传播成功!")
print(f"输出logits形状: {logits.shape}") # 应该是 (B, num_classes)
print(f"ratio_loss: {aux['ratio_loss']:.4f}")
# 验证输出形状
expected_shape = (B, num_classes)
if logits.shape == expected_shape:
print(f"✅ 输出形状正确: {logits.shape}")
else:
print(f"❌ 输出形状错误: 期望 {expected_shape}, 实际 {logits.shape}")
return False
# 测试训练模式
print("\n开始训练模式测试...")
model.train()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# 创建目标标签
y = torch.randint(0, num_classes, (B,), device=device)
# 前向传播
logits, _, aux = model(x, mask=mask, return_seq=False)
# 计算损失
cls_loss = F.cross_entropy(logits, y)
ratio_reg = 0.01 * aux["ratio_loss"] # ratio loss正则化
total_loss = cls_loss + ratio_reg
print(f"分类损失: {cls_loss:.4f}")
print(f"ratio损失: {ratio_reg:.4f}")
print(f"总损失: {total_loss:.4f}")
# 反向传播
optimizer.zero_grad()
total_loss.backward()
optimizer.step()
print("✅ 训练步骤成功!")
# 测试序列返回功能
print("\n测试序列调试信息返回...")
with torch.no_grad():
logits, seq_debug, aux = model(x, mask=mask, return_seq=True)
if seq_debug is not None:
print(f"✅ 序列调试信息获取成功,包含 {len(seq_debug)} 个通道的信息")
else:
print("❌ 序列调试信息获取失败")
print("\n" + "=" * 60)
print("🎉 DC_hnet模型测试全部通过!")
print("模型可以正常进行时间序列分类任务")
print("=" * 60)
return True
except Exception as e:
print(f"❌ 测试失败: {str(e)}")
import traceback
traceback.print_exc()
return False
def test_different_configurations():
"""测试不同的模型配置"""
print("\n测试不同配置...")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 测试配置1: 不共享通道
cfg1 = HierEncodersSingleMainConfig(
num_channels=3,
d_models=[32, 64],
num_classes=5,
encoder_cfg_per_stage=[dict(arch="m", height=2)],
main_cfg=dict(arch="m", height=3),
target_compression_N_per_stage=[2],
share_channel=False, # 不共享通道
fusion_across_channels="concat", # 连接融合
dropout=0.1,
)
try:
dtype1 = torch.bfloat16 if device.type == "cuda" else torch.float32
model1 = HierEncodersSingleMainClassifier(cfg1, device=device, dtype=dtype1)
x1 = torch.randn(4, 256, 3, device=device, dtype=dtype1)
logits1, _, _ = model1(x1)
print(f"✅ 配置1 (不共享通道, concat融合): 输出形状 {logits1.shape}")
except Exception as e:
print(f"❌ 配置1测试失败: {str(e)}")
# 测试配置2: 单层模型
cfg2 = HierEncodersSingleMainConfig(
num_channels=2,
d_models=[128], # 只有一层,没有编码器阶段
num_classes=3,
encoder_cfg_per_stage=[], # 空的编码器阶段
main_cfg=dict(arch="T", height=2),
target_compression_N_per_stage=[],
share_channel=True,
fusion_across_channels="mean",
dropout=0.1,
)
try:
dtype2 = torch.bfloat16 if device.type == "cuda" else torch.float32
model2 = HierEncodersSingleMainClassifier(cfg2, device=device, dtype=dtype2)
x2 = torch.randn(2, 128, 2, device=device, dtype=dtype2)
logits2, _, _ = model2(x2)
print(f"✅ 配置2 (单层模型): 输出形状 {logits2.shape}")
except Exception as e:
print(f"❌ 配置2测试失败: {str(e)}")
if __name__ == "__main__":
# 主测试
success = test_dc_hnet_model()
# 额外配置测试
test_different_configurations()
if success:
print("\n🎊 所有测试完成!")
sys.exit(0)
else:
print("\n💥 测试失败!")
sys.exit(1)

313
test_dc_patchtst.py Normal file
View File

@ -0,0 +1,313 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader
import os
import matplotlib.pyplot as plt
from models.DC_PatchTST import Model
class Config:
"""Configuration class"""
def __init__(self):
# Basic configuration
self.task_name = 'long_term_forecast'
self.model = 'DC_PatchTST'
# Data configuration
self.seq_len = 96 # Input sequence length
self.pred_len = 24 # Prediction sequence length
self.label_len = 48 # Label length
self.enc_in = 2 # Input feature dimension (dual channel)
self.dec_in = 2 # Decoder input dimension
self.c_out = 2 # Output dimension
# Model configuration
self.d_model = 128 # Model dimension
self.n_heads = 8 # Number of attention heads
self.e_layers = 2 # Number of encoder layers
self.d_layers = 1 # Number of decoder layers
self.d_ff = 256 # Feed forward dimension
self.factor = 1 # Attention factor
self.dropout = 0.1 # Dropout rate
self.activation = 'gelu'
# Training configuration
self.batch_size = 32
self.learning_rate = 0.001
self.train_epochs = 50
self.patience = 5
# Other configuration
self.use_amp = False
self.num_class = 0
# GPU configuration
self.use_gpu = torch.cuda.is_available()
self.device = torch.device('cuda' if self.use_gpu else 'cpu')
class SineWaveDataset(Dataset):
"""Sine wave dataset"""
def __init__(self, data_path, seq_len=96, pred_len=24, mode='test'):
self.seq_len = seq_len
self.pred_len = pred_len
self.mode = mode
# Load data
if mode == 'train':
df = pd.read_csv(os.path.join(data_path, 'train.csv'))
elif mode == 'val':
df = pd.read_csv(os.path.join(data_path, 'val.csv'))
else: # test
df = pd.read_csv(os.path.join(data_path, 'test.csv'))
# Extract feature columns (except timestamp)
self.data = df[['channel1', 'channel2']].values.astype(np.float32)
# Calculate available sample count
self.total_len = len(self.data)
self.samples_num = max(0, self.total_len - seq_len - pred_len + 1)
print(f"{mode} dataset: {self.total_len} records, {self.samples_num} samples")
def __len__(self):
return self.samples_num
def __getitem__(self, idx):
# Input sequence
s_begin = idx
s_end = s_begin + self.seq_len
# Prediction target
r_begin = s_end
r_end = r_begin + self.pred_len
seq_x = self.data[s_begin:s_end] # (seq_len, n_vars)
seq_y = self.data[r_begin:r_end] # (pred_len, n_vars)
# Time marks (simple positional encoding)
seq_x_mark = np.arange(self.seq_len).reshape(-1, 1).astype(np.float32)
seq_y_mark = np.arange(self.pred_len).reshape(-1, 1).astype(np.float32)
return seq_x, seq_y, seq_x_mark, seq_y_mark, idx
def load_model(model_path):
"""Load saved model"""
checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)
config = checkpoint['config']
# Create model
model = Model(config)
model.load_state_dict(checkpoint['model_state_dict'])
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
model.eval()
print(f"Model loaded successfully - Epoch: {checkpoint['epoch']}, Val Loss: {checkpoint['val_loss']:.6f}")
print(f"Using device: {device}")
return model, config, device
def visualize_predictions_with_chunks(model, test_loader, device, save_path, num_samples=5):
"""Predict and visualize results, marking chunk points"""
model.eval()
# Create save directory
vis_dir = os.path.join(save_path, 'visualizations')
os.makedirs(vis_dir, exist_ok=True)
sample_count = 0
with torch.no_grad():
for batch_x, batch_y, batch_x_mark, batch_y_mark, batch_idx in test_loader:
if sample_count >= num_samples:
break
batch_x = batch_x.to(device)
batch_y = batch_y.to(device)
batch_x_mark = batch_x_mark.to(device)
batch_y_mark = batch_y_mark.to(device)
# Construct decoder input
dec_inp = torch.zeros_like(batch_y).to(device)
# Predict
outputs, aux = model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
# Convert to CPU
batch_x_cpu = batch_x.cpu().numpy()
batch_y_cpu = batch_y.cpu().numpy()
outputs_cpu = outputs.cpu().numpy()
# Process each sample in batch
for i in range(min(batch_x.size(0), num_samples - sample_count)):
input_seq = batch_x_cpu[i] # (seq_len, 2)
true_pred = batch_y_cpu[i] # (pred_len, 2)
pred_seq = outputs_cpu[i] # (pred_len, 2)
# Get first layer chunk information
boundary_mask_stage0 = None
if aux is not None and 'stage0' in aux:
# aux['stage0']['boundary_mask'] shape: (B*nvars, L)
# We need to reshape to (B, nvars, L) then take i-th sample
stage0_info = aux['stage0']
boundary_mask = stage0_info['boundary_mask'].cpu().numpy() # (B*nvars, L)
B = batch_x.size(0)
nvars = batch_x.size(2) # should be 2
L = boundary_mask.shape[1]
# Reshape boundary_mask: (B*nvars, L) -> (B, nvars, L)
boundary_mask = boundary_mask.reshape(B, nvars, L)
boundary_mask_stage0 = boundary_mask[i] # (nvars, L)
# Create visualization for each channel
fig, axes = plt.subplots(2, 1, figsize=(15, 10))
for ch in range(2): # dual channel
ax = axes[ch]
# Time axis
input_time = np.arange(len(input_seq))
pred_time = np.arange(len(input_seq), len(input_seq) + len(true_pred))
# Plot input sequence
ax.plot(input_time, input_seq[:, ch], 'b-', label='Input Sequence', linewidth=1.5)
# Plot ground truth prediction and model prediction
ax.plot(pred_time, true_pred[:, ch], 'g-', label='Ground Truth', linewidth=2)
ax.plot(pred_time, pred_seq[:, ch], 'r--', label='Prediction', linewidth=2)
# Mark first layer chunk points
if boundary_mask_stage0 is not None:
chunk_points = np.where(boundary_mask_stage0[ch])[0] # Get chunk points for this channel
for point in chunk_points:
if point < len(input_seq): # Only mark points within input sequence range
ax.axvline(x=point, color='orange', linestyle=':', alpha=0.7, linewidth=1)
# Add chunk points explanation in legend
if len(chunk_points) > 0:
ax.axvline(x=-1, color='orange', linestyle=':', alpha=0.7,
linewidth=1, label=f'Chunk Points (Stage 0)')
ax.set_title(f'Sample {sample_count + 1} - Channel {ch + 1}')
ax.set_xlabel('Time Steps')
ax.set_ylabel('Value')
ax.legend()
ax.grid(True, alpha=0.3)
# Add boundary line marking input and prediction boundary
ax.axvline(x=len(input_seq)-0.5, color='black', linestyle='-', alpha=0.5, linewidth=1)
ax.text(len(input_seq)-0.5, ax.get_ylim()[1]*0.9, 'Prediction Start',
rotation=90, verticalalignment='top', fontsize=8)
plt.tight_layout()
# Save figure
sample_filename = f'sample_{sample_count + 1}_with_chunks.png'
sample_path = os.path.join(vis_dir, sample_filename)
plt.savefig(sample_path, dpi=300, bbox_inches='tight')
plt.close()
print(f"Sample {sample_count + 1} visualization saved to: {sample_path}")
# Print chunk statistics
if boundary_mask_stage0 is not None:
for ch in range(2):
chunk_count = np.sum(boundary_mask_stage0[ch])
chunk_ratio = chunk_count / len(boundary_mask_stage0[ch])
print(f" Channel {ch + 1}: {chunk_count} chunk points ({chunk_ratio:.2%} of sequence)")
sample_count += 1
if sample_count >= num_samples:
break
if sample_count >= num_samples:
break
def evaluate_model(model, test_loader, device):
"""Evaluate model performance"""
model.eval()
total_mse = 0.0
total_mae = 0.0
batch_count = 0
all_predictions = []
all_ground_truths = []
with torch.no_grad():
for batch_x, batch_y, batch_x_mark, batch_y_mark, _ in test_loader:
batch_x = batch_x.to(device)
batch_y = batch_y.to(device)
batch_x_mark = batch_x_mark.to(device)
batch_y_mark = batch_y_mark.to(device)
dec_inp = torch.zeros_like(batch_y).to(device)
outputs, _ = model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
# Calculate loss
mse = torch.mean((outputs - batch_y) ** 2)
mae = torch.mean(torch.abs(outputs - batch_y))
total_mse += mse.item()
total_mae += mae.item()
batch_count += 1
# Collect for overall statistics
all_predictions.append(outputs.cpu().numpy())
all_ground_truths.append(batch_y.cpu().numpy())
avg_mse = total_mse / batch_count
avg_mae = total_mae / batch_count
# Calculate overall RMSE
all_predictions = np.concatenate(all_predictions, axis=0)
all_ground_truths = np.concatenate(all_ground_truths, axis=0)
rmse = np.sqrt(np.mean((all_predictions - all_ground_truths) ** 2))
print(f"\nTest set evaluation results:")
print(f"MSE: {avg_mse:.6f}")
print(f"MAE: {avg_mae:.6f}")
print(f"RMSE: {rmse:.6f}")
return avg_mse, avg_mae, rmse
def main():
# Configuration parameters
model_path = './results/dc_patchtst_sine_wave/best_model.pth'
data_path = './data/sine_wave/'
save_path = './results/dc_patchtst_sine_wave/'
if not os.path.exists(model_path):
print(f"Error: Model file not found {model_path}")
print("Please run the training script train_dc_patchtst.py first")
return
# Load model
model, config, device = load_model(model_path)
# Load test data
test_dataset = SineWaveDataset(data_path, config.seq_len, config.pred_len, 'test')
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0) # batch_size=1 for easier visualization
print(f"\nConfiguration info:")
print(f"Sequence length: {config.seq_len}")
print(f"Prediction length: {config.pred_len}")
print(f"Feature dimension: {config.enc_in}")
# Evaluate model
mse, mae, rmse = evaluate_model(model, test_loader, device)
# Visualize prediction results and mark chunk points
print(f"\nGenerating visualization results...")
visualize_predictions_with_chunks(model, test_loader, device, save_path, num_samples=5)
print(f"\nAll results saved to: {save_path}")
print(f"Visualization files located at: {save_path}/visualizations/")
if __name__ == "__main__":
main()

335
train_dc_patchtst.py Normal file
View File

@ -0,0 +1,335 @@
#!/usr/bin/env python3
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader
import os
import argparse
from models.DC_PatchTST import Model
import time
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error, mean_absolute_error
# 检查并创建结果文件夹
def ensure_dir(path):
if not os.path.exists(path):
os.makedirs(path)
class SineWaveDataset(Dataset):
"""正弦波数据集"""
def __init__(self, data_path, seq_len=96, pred_len=24, mode='train'):
self.seq_len = seq_len
self.pred_len = pred_len
self.mode = mode
# 加载数据
if mode == 'train':
df = pd.read_csv(os.path.join(data_path, 'train.csv'))
elif mode == 'val':
df = pd.read_csv(os.path.join(data_path, 'val.csv'))
else: # test
df = pd.read_csv(os.path.join(data_path, 'test.csv'))
# 提取特征列除timestamp外
self.data = df[['channel1', 'channel2']].values.astype(np.float32)
# 计算可用样本数量
self.total_len = len(self.data)
self.samples_num = max(0, self.total_len - seq_len - pred_len + 1)
print(f"{mode} 数据集: {self.total_len} 条记录, {self.samples_num} 个样本")
def __len__(self):
return self.samples_num
def __getitem__(self, idx):
# 输入序列
s_begin = idx
s_end = s_begin + self.seq_len
# 预测目标
r_begin = s_end
r_end = r_begin + self.pred_len
seq_x = self.data[s_begin:s_end] # (seq_len, n_vars)
seq_y = self.data[r_begin:r_end] # (pred_len, n_vars)
# 时间标记(简单的位置编码)
seq_x_mark = np.arange(self.seq_len).reshape(-1, 1).astype(np.float32)
seq_y_mark = np.arange(self.pred_len).reshape(-1, 1).astype(np.float32)
return seq_x, seq_y, seq_x_mark, seq_y_mark
class Config:
"""配置类"""
def __init__(self):
# 基础配置
self.task_name = 'long_term_forecast'
self.model = 'DC_PatchTST'
# 数据配置
self.seq_len = 96 # 输入序列长度
self.pred_len = 24 # 预测序列长度
self.label_len = 48 # 标签长度
self.enc_in = 2 # 输入特征维度(双通道)
self.dec_in = 2 # 解码器输入维度
self.c_out = 2 # 输出维度
# 模型配置
self.d_model = 128 # 模型维度
self.n_heads = 8 # 注意力头数
self.e_layers = 2 # 编码器层数
self.d_layers = 1 # 解码器层数
self.d_ff = 256 # 前向网络维度
self.factor = 1 # 注意力因子
self.dropout = 0 # Dropout率
self.activation = 'gelu'
# 训练配置
self.batch_size = 1024
self.learning_rate = 0.001
self.train_epochs = 50
self.patience = 5
# 其他配置
self.use_amp = False
self.num_class = 0
# GPU配置
self.use_gpu = torch.cuda.is_available()
self.device = torch.device('cuda' if self.use_gpu else 'cpu')
def train_epoch(model, train_loader, criterion, optimizer, device, use_amp=False):
"""训练一个epoch"""
model.train()
total_loss = 0.0
batch_count = 0
for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(train_loader):
batch_x = batch_x.to(device)
batch_y = batch_y.to(device)
batch_x_mark = batch_x_mark.to(device)
batch_y_mark = batch_y_mark.to(device)
# 构造解码器输入
dec_inp = torch.zeros_like(batch_y).to(device)
optimizer.zero_grad()
if use_amp:
with torch.cuda.amp.autocast():
outputs, aux = model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
loss = criterion(outputs, batch_y)
# 添加DC的ratio loss
if aux is not None and 'ratio_loss0' in aux and 'ratio_loss1' in aux:
ratio_loss = aux['ratio_loss0'] + aux['ratio_loss1']
loss = loss + 0.0 * ratio_loss # ratio loss权重
else:
outputs, aux = model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
loss = criterion(outputs, batch_y)
# 添加DC的ratio loss
if aux is not None and 'ratio_loss0' in aux and 'ratio_loss1' in aux:
ratio_loss = aux['ratio_loss0'] + aux['ratio_loss1']
loss = loss + 0.0 * ratio_loss # ratio loss权重
loss.backward()
optimizer.step()
total_loss += loss.item()
batch_count += 1
if i % 100 == 0:
print(f'Batch {i}, Loss: {loss.item():.6f}')
return total_loss / batch_count
def validate(model, val_loader, criterion, device):
"""验证模型"""
model.eval()
total_loss = 0.0
batch_count = 0
with torch.no_grad():
for batch_x, batch_y, batch_x_mark, batch_y_mark in val_loader:
batch_x = batch_x.to(device)
batch_y = batch_y.to(device)
batch_x_mark = batch_x_mark.to(device)
batch_y_mark = batch_y_mark.to(device)
dec_inp = torch.zeros_like(batch_y).to(device)
outputs, _ = model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
loss = criterion(outputs, batch_y)
total_loss += loss.item()
batch_count += 1
return total_loss / batch_count
def test_model(model, test_loader, device, save_path):
"""测试模型并可视化结果"""
model.eval()
predictions = []
ground_truths = []
with torch.no_grad():
for batch_x, batch_y, batch_x_mark, batch_y_mark in test_loader:
batch_x = batch_x.to(device)
batch_y = batch_y.to(device)
batch_x_mark = batch_x_mark.to(device)
batch_y_mark = batch_y_mark.to(device)
dec_inp = torch.zeros_like(batch_y).to(device)
outputs, _ = model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
predictions.append(outputs.cpu().numpy())
ground_truths.append(batch_y.cpu().numpy())
predictions = np.concatenate(predictions, axis=0)
ground_truths = np.concatenate(ground_truths, axis=0)
# 计算指标
mse = mean_squared_error(ground_truths.reshape(-1), predictions.reshape(-1))
mae = mean_absolute_error(ground_truths.reshape(-1), predictions.reshape(-1))
print(f"测试结果 - MSE: {mse:.6f}, MAE: {mae:.6f}")
# 可视化前几个样本
plt.figure(figsize=(15, 10))
for i in range(min(4, len(predictions))):
for ch in range(2): # 双通道
plt.subplot(4, 2, i*2 + ch + 1)
plt.plot(ground_truths[i, :, ch], label='Ground Truth', color='blue')
plt.plot(predictions[i, :, ch], label='Prediction', color='red')
plt.title(f'Sample {i+1}, Channel {ch+1}')
plt.legend()
plt.tight_layout()
plt.savefig(os.path.join(save_path, 'predictions.png'))
print(f"预测结果可视化保存到: {save_path}/predictions.png")
return mse, mae
def main():
# 配置参数
config = Config()
# 创建结果目录
results_dir = './results/dc_patchtst_sine_wave'
ensure_dir(results_dir)
print(f"使用设备: {config.device}")
print(f"模型配置: seq_len={config.seq_len}, pred_len={config.pred_len}")
# 加载数据
data_path = './data/sine_wave/'
train_dataset = SineWaveDataset(data_path, config.seq_len, config.pred_len, 'train')
val_dataset = SineWaveDataset(data_path, config.seq_len, config.pred_len, 'val')
test_dataset = SineWaveDataset(data_path, config.seq_len, config.pred_len, 'test')
train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False, num_workers=4)
# 创建模型
model = Model(config).to(config.device)
print(f"模型参数数量: {sum(p.numel() for p in model.parameters())}")
# 优化器和损失函数
optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate, weight_decay=1e-4)
criterion = nn.MSELoss()
# 训练循环
best_val_loss = float('inf')
patience_counter = 0
train_losses = []
val_losses = []
print("开始训练...")
start_time = time.time()
for epoch in range(config.train_epochs):
epoch_start = time.time()
# 训练
train_loss = train_epoch(model, train_loader, criterion, optimizer, config.device, config.use_amp)
# 验证
val_loss = validate(model, val_loader, criterion, config.device)
train_losses.append(train_loss)
val_losses.append(val_loss)
epoch_time = time.time() - epoch_start
print(f'Epoch {epoch+1:2d}/{config.train_epochs} | '
f'Train Loss: {train_loss:.6f} | '
f'Val Loss: {val_loss:.6f} | '
f'Time: {epoch_time:.2f}s')
# 早停检查
if val_loss < best_val_loss:
best_val_loss = val_loss
patience_counter = 0
# 保存最佳模型
torch.save({
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'config': config,
'epoch': epoch,
'val_loss': val_loss
}, os.path.join(results_dir, 'best_model.pth'))
print(f' -> 保存最佳模型 (val_loss: {val_loss:.6f})')
else:
patience_counter += 1
if patience_counter >= config.patience:
print(f'早停触发! 最佳验证损失: {best_val_loss:.6f}')
break
total_time = time.time() - start_time
print(f'\n训练完成! 总时间: {total_time/60:.2f} 分钟')
# 加载最佳模型进行测试
checkpoint = torch.load(os.path.join(results_dir, 'best_model.pth'))
model.load_state_dict(checkpoint['model_state_dict'])
print("\n测试最佳模型...")
test_mse, test_mae = test_model(model, test_loader, config.device, results_dir)
# 保存训练历史
history = {
'train_losses': train_losses,
'val_losses': val_losses,
'test_mse': test_mse,
'test_mae': test_mae
}
# 绘制训练曲线
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(train_losses, label='Train Loss')
plt.plot(val_losses, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.title('Training History')
plt.subplot(1, 2, 2)
plt.bar(['MSE', 'MAE'], [test_mse, test_mae])
plt.title('Test Metrics')
plt.ylabel('Error')
plt.tight_layout()
plt.savefig(os.path.join(results_dir, 'training_history.png'))
print(f"\n结果保存在: {results_dir}")
print(f"最佳模型: {results_dir}/best_model.pth")
print(f"训练历史: {results_dir}/training_history.png")
print(f"预测可视化: {results_dir}/predictions.png")
if __name__ == "__main__":
main()