feat: add mamba and dynamic chunking related code and test code
This commit is contained in:
102
generate_sine_data.py
Normal file
102
generate_sine_data.py
Normal file
@ -0,0 +1,102 @@
|
|||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
import os
|
||||||
|
|
||||||
|
def generate_sine_wave_data(n_samples=10000, seq_len=200, n_channels=2, save_path='./data/sine_wave/'):
|
||||||
|
"""
|
||||||
|
生成双通道正弦波时序数据
|
||||||
|
|
||||||
|
Args:
|
||||||
|
n_samples: 总样本数
|
||||||
|
seq_len: 每个序列长度
|
||||||
|
n_channels: 通道数 (固定为2)
|
||||||
|
save_path: 保存路径
|
||||||
|
"""
|
||||||
|
|
||||||
|
if not os.path.exists(save_path):
|
||||||
|
os.makedirs(save_path)
|
||||||
|
|
||||||
|
# 生成时间轴
|
||||||
|
t = np.linspace(0, 4*np.pi, seq_len)
|
||||||
|
|
||||||
|
all_data = []
|
||||||
|
|
||||||
|
for i in range(n_samples):
|
||||||
|
# 为每个样本生成不同周期和相位的正弦波
|
||||||
|
# 通道1: 随机周期和相位
|
||||||
|
freq1 = np.random.uniform(0.5, 3.0) # 频率范围
|
||||||
|
phase1 = np.random.uniform(0, 2*np.pi) # 相位
|
||||||
|
amplitude1 = np.random.uniform(0.5, 2.0) # 幅度
|
||||||
|
|
||||||
|
# 通道2: 不同的随机周期和相位
|
||||||
|
freq2 = np.random.uniform(0.3, 2.5)
|
||||||
|
phase2 = np.random.uniform(0, 2*np.pi)
|
||||||
|
amplitude2 = np.random.uniform(0.8, 1.8)
|
||||||
|
|
||||||
|
# 生成正弦波数据
|
||||||
|
channel1 = amplitude1 * np.sin(freq1 * t + phase1)
|
||||||
|
channel2 = amplitude2 * np.sin(freq2 * t + phase2)
|
||||||
|
|
||||||
|
# 添加少量噪声
|
||||||
|
# noise1 = np.random.normal(0, 0.1, seq_len)
|
||||||
|
# noise2 = np.random.normal(0, 0.1, seq_len)
|
||||||
|
#
|
||||||
|
# channel1 += noise1
|
||||||
|
# channel2 += noise2
|
||||||
|
|
||||||
|
# 组合数据: [timestamp, channel1, channel2]
|
||||||
|
timestamp = np.arange(seq_len)
|
||||||
|
sample_data = np.column_stack([timestamp, channel1, channel2])
|
||||||
|
all_data.append(sample_data)
|
||||||
|
|
||||||
|
# 转换为连续的时间序列格式
|
||||||
|
continuous_data = []
|
||||||
|
current_time = 0
|
||||||
|
|
||||||
|
for sample in all_data:
|
||||||
|
sample[:, 0] = current_time + sample[:, 0] # 调整时间戳
|
||||||
|
continuous_data.append(sample)
|
||||||
|
current_time += seq_len
|
||||||
|
|
||||||
|
# 合并所有数据
|
||||||
|
full_data = np.vstack(continuous_data)
|
||||||
|
|
||||||
|
# 创建DataFrame
|
||||||
|
df = pd.DataFrame(full_data, columns=['timestamp', 'channel1', 'channel2'])
|
||||||
|
|
||||||
|
# 按 8:1:1 比例分割训练、验证、测试集
|
||||||
|
total_len = len(df)
|
||||||
|
train_end = int(0.8 * total_len)
|
||||||
|
val_end = int(0.9 * total_len)
|
||||||
|
|
||||||
|
train_df = df[:train_end]
|
||||||
|
val_df = df[train_end:val_end]
|
||||||
|
test_df = df[val_end:]
|
||||||
|
|
||||||
|
# 保存数据
|
||||||
|
train_df.to_csv(os.path.join(save_path, 'train.csv'), index=False)
|
||||||
|
val_df.to_csv(os.path.join(save_path, 'val.csv'), index=False)
|
||||||
|
test_df.to_csv(os.path.join(save_path, 'test.csv'), index=False)
|
||||||
|
|
||||||
|
# 保存完整数据
|
||||||
|
df.to_csv(os.path.join(save_path, 'sine_wave.csv'), index=False)
|
||||||
|
|
||||||
|
print(f"数据已生成并保存到 {save_path}")
|
||||||
|
print(f"训练集: {len(train_df)} 条记录")
|
||||||
|
print(f"验证集: {len(val_df)} 条记录")
|
||||||
|
print(f"测试集: {len(test_df)} 条记录")
|
||||||
|
print(f"总计: {len(df)} 条记录,{n_channels} 个通道")
|
||||||
|
|
||||||
|
return df
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# 生成数据
|
||||||
|
data = generate_sine_wave_data(
|
||||||
|
n_samples=200, # 2000个不同的正弦波样本
|
||||||
|
seq_len=200, # 每个样本200个时间点
|
||||||
|
n_channels=2, # 双通道
|
||||||
|
save_path='./data/sine_wave/'
|
||||||
|
)
|
||||||
|
|
||||||
|
print("\n数据统计信息:")
|
||||||
|
print(data.describe())
|
440
layers/DynamicChunking.py
Normal file
440
layers/DynamicChunking.py
Normal file
@ -0,0 +1,440 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from einops import repeat, rearrange
|
||||||
|
|
||||||
|
from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RoutingModuleOutput:
|
||||||
|
# 路由模块的前向输出:
|
||||||
|
# - boundary_prob: 每个位置为“分隔点/非分隔点”的二分类概率,形状(..., 2)
|
||||||
|
# - boundary_mask: 基于最大概率得到的硬选择(True=分隔点),形状与输入序列相同的前两维
|
||||||
|
# - selected_probs: 对应硬选择类别的概率,形状(..., 1)
|
||||||
|
boundary_prob: torch.Tensor
|
||||||
|
boundary_mask: torch.Tensor
|
||||||
|
selected_probs: torch.Tensor
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RoutingModuleState:
|
||||||
|
"""
|
||||||
|
路由模块的推理状态(用于增量/流式step)
|
||||||
|
|
||||||
|
包含:
|
||||||
|
- has_seen_tokens: (batch_size,) 是否已经见过任意token(用于首token强制边界)
|
||||||
|
- last_hidden_state: (batch_size, d_model) 上一次的隐藏状态(用于与当前token做相邻相似度)
|
||||||
|
"""
|
||||||
|
|
||||||
|
has_seen_tokens: torch.Tensor # (batch_size,)
|
||||||
|
last_hidden_state: torch.Tensor # (batch_size, d_model)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DeChunkState:
|
||||||
|
"""
|
||||||
|
DeChunk 的推理状态(EMA的记忆值)
|
||||||
|
|
||||||
|
包含:
|
||||||
|
- last_value: (batch_size, d_model) EMA反聚合的上一时刻值
|
||||||
|
"""
|
||||||
|
|
||||||
|
last_value: torch.Tensor # (batch_size, d_model)
|
||||||
|
|
||||||
|
|
||||||
|
def get_seq_idx(cu_seqlens, device=None):
|
||||||
|
seq_idx = torch.zeros(cu_seqlens[-1], dtype=torch.long, device=device)
|
||||||
|
seq_idx[cu_seqlens[:-1]] = 1
|
||||||
|
seq_idx = (torch.cumsum(seq_idx, dim=0) - 1).unsqueeze(0).int()
|
||||||
|
|
||||||
|
return seq_idx
|
||||||
|
|
||||||
|
class RoutingModule(nn.Module):
|
||||||
|
"""
|
||||||
|
路由模块:
|
||||||
|
用相邻token的余弦相似度构造“成为分隔点”的概率:
|
||||||
|
p_t = clamp((1 - cos(h_{t-1}, h_t)) / 2, 0, 1)
|
||||||
|
并强制首位置为边界(概率=1)。
|
||||||
|
支持:
|
||||||
|
- 常规batch掩码 mask 模式
|
||||||
|
- packed序列 cu_seqlens 模式(把多序列打包成单条序列的拼接)
|
||||||
|
- 流式推理 step()(维护状态)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, d_model, device=None, dtype=None):
|
||||||
|
self.d_model = d_model
|
||||||
|
factory_kwargs = {"device": device, "dtype": dtype}
|
||||||
|
super().__init__()
|
||||||
|
# 相邻相似度计算前的线性投影(初始化为恒等)
|
||||||
|
self.q_proj_layer = nn.Linear(d_model, d_model, bias=False, **factory_kwargs)
|
||||||
|
self.k_proj_layer = nn.Linear(d_model, d_model, bias=False, **factory_kwargs)
|
||||||
|
with torch.no_grad():
|
||||||
|
self.q_proj_layer.weight.copy_(torch.eye(d_model))
|
||||||
|
self.k_proj_layer.weight.copy_(torch.eye(d_model))
|
||||||
|
# 防止外部权重再初始化
|
||||||
|
self.q_proj_layer.weight._no_reinit = True
|
||||||
|
self.k_proj_layer.weight._no_reinit = True
|
||||||
|
|
||||||
|
def allocate_inference_cache(self, batch_size, max_seqlen, device, dtype=None):
|
||||||
|
# 分配推理cache(用于step)
|
||||||
|
return RoutingModuleState(
|
||||||
|
has_seen_tokens=torch.zeros(batch_size, device=device, dtype=torch.bool),
|
||||||
|
last_hidden_state=torch.zeros(
|
||||||
|
batch_size, self.d_model, device=device, dtype=dtype
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, hidden_states, cu_seqlens=None, mask=None, inference_params=None):
|
||||||
|
"""
|
||||||
|
hidden_states:
|
||||||
|
- 若 cu_seqlens is None: (B, L, D)
|
||||||
|
- 若 cu_seqlens 非 None: 期望 packed 模式 (T, D),这里会临时扩维成 (1, T, D)
|
||||||
|
cu_seqlens: packed模式下每条序列的前缀和下标,形如 [0, len1, len1+len2, ...]
|
||||||
|
mask: (B, L) bool,True=有效(非packed)
|
||||||
|
inference_params: RoutingModuleState,用于prefill时的校验与状态维护
|
||||||
|
"""
|
||||||
|
assert (mask is not None) or (
|
||||||
|
cu_seqlens is not None
|
||||||
|
), "Either mask or cu_seqlens must be provided"
|
||||||
|
|
||||||
|
if inference_params is not None:
|
||||||
|
# prefill阶段必须提供mask,且不允许之前已经见过token
|
||||||
|
assert (
|
||||||
|
mask is not None
|
||||||
|
), "Mask must be provided if inference_params is provided"
|
||||||
|
assert (
|
||||||
|
~inference_params.has_seen_tokens
|
||||||
|
).all(), "Cannot have seen tokens when inference_params is not provided"
|
||||||
|
|
||||||
|
if cu_seqlens is not None:
|
||||||
|
# packed 模式:把 (T, D) 临时变为 (1, T, D)
|
||||||
|
hidden_states = hidden_states.unsqueeze(0)
|
||||||
|
|
||||||
|
# 计算相邻余弦相似度 cos(h_{t-1}, h_t)
|
||||||
|
cos_sim = torch.einsum(
|
||||||
|
"b l d, b l d -> b l",
|
||||||
|
F.normalize(self.q_proj_layer(hidden_states[:, :-1]), dim=-1),
|
||||||
|
F.normalize(self.k_proj_layer(hidden_states[:, 1:]), dim=-1),
|
||||||
|
)
|
||||||
|
# p = ((1 - cos) / 2) ∈ [0,1]
|
||||||
|
boundary_prob = torch.clamp(((1 - cos_sim) / 2), min=0.0, max=1.0)
|
||||||
|
|
||||||
|
# 强制首位置为边界:首位概率=1,补充后长度和输入序列长度想等
|
||||||
|
PAD_PROB = 1.0
|
||||||
|
boundary_prob = F.pad(boundary_prob, (1, 0), "constant", PAD_PROB)
|
||||||
|
|
||||||
|
if cu_seqlens is not None:
|
||||||
|
# packed 模式下,每段序列的第一个位置是 cu_seqlens[:-1]
|
||||||
|
boundary_prob = boundary_prob.squeeze(0)
|
||||||
|
boundary_prob[cu_seqlens[:-1]] = PAD_PROB
|
||||||
|
|
||||||
|
# 组装为二分类概率 [非边界, 边界]
|
||||||
|
boundary_prob = torch.stack(((1 - boundary_prob), boundary_prob), dim=-1)
|
||||||
|
|
||||||
|
# 取最大概率类别
|
||||||
|
selected_idx = torch.argmax(boundary_prob, dim=-1)
|
||||||
|
|
||||||
|
# 硬边界掩码
|
||||||
|
boundary_mask = selected_idx == 1 # 形状与 hidden_states 的前两维一致
|
||||||
|
if mask is not None:
|
||||||
|
# 不允许选择到无效token
|
||||||
|
boundary_mask = boundary_mask & mask
|
||||||
|
|
||||||
|
if inference_params is not None:
|
||||||
|
# 维护路由状态:是否见过token、最后一个有效token的隐藏状态
|
||||||
|
has_mask = mask.any(dim=-1)
|
||||||
|
inference_params.has_seen_tokens.copy_(
|
||||||
|
has_mask | inference_params.has_seen_tokens
|
||||||
|
)
|
||||||
|
last_mask = torch.clamp(mask.sum(dim=-1) - 1, min=0)
|
||||||
|
inference_params.last_hidden_state.copy_(
|
||||||
|
torch.where(
|
||||||
|
has_mask,
|
||||||
|
hidden_states[
|
||||||
|
torch.arange(
|
||||||
|
hidden_states.shape[0], device=hidden_states.device
|
||||||
|
),
|
||||||
|
last_mask,
|
||||||
|
],
|
||||||
|
inference_params.last_hidden_state,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 取硬选择对应的概率(便于可视化/正则)
|
||||||
|
selected_probs = boundary_prob.gather(
|
||||||
|
dim=-1, index=selected_idx.unsqueeze(-1)
|
||||||
|
) # (..., 1)
|
||||||
|
|
||||||
|
return RoutingModuleOutput(
|
||||||
|
boundary_prob=boundary_prob, # (..., 2)
|
||||||
|
boundary_mask=boundary_mask, # (...)
|
||||||
|
selected_probs=selected_probs, # (..., 1)
|
||||||
|
)
|
||||||
|
|
||||||
|
def step(self, hidden_states, inference_params):
|
||||||
|
"""
|
||||||
|
流式单步:
|
||||||
|
hidden_states: (B, 1, D)
|
||||||
|
使用上一步缓存的 last_hidden_state 与当前token计算相邻相似度,得到当前步的边界概率
|
||||||
|
"""
|
||||||
|
# (B, D)
|
||||||
|
hidden_states = hidden_states.squeeze(1)
|
||||||
|
cos_sim = torch.einsum(
|
||||||
|
"b d, b d -> b",
|
||||||
|
F.normalize(self.q_proj_layer(inference_params.last_hidden_state), dim=-1),
|
||||||
|
F.normalize(self.k_proj_layer(hidden_states), dim=-1),
|
||||||
|
)
|
||||||
|
boundary_prob = torch.clamp(((1 - cos_sim) / 2), min=0.0, max=1.0)
|
||||||
|
# 更新最后隐藏状态
|
||||||
|
inference_params.last_hidden_state.copy_(hidden_states)
|
||||||
|
# 首个token前,强制边界
|
||||||
|
boundary_prob = torch.where(
|
||||||
|
inference_params.has_seen_tokens,
|
||||||
|
boundary_prob,
|
||||||
|
torch.ones_like(boundary_prob),
|
||||||
|
)
|
||||||
|
boundary_prob = torch.stack(((1 - boundary_prob), boundary_prob), dim=-1)
|
||||||
|
|
||||||
|
# 标记为已见token
|
||||||
|
inference_params.has_seen_tokens.copy_(
|
||||||
|
torch.ones_like(inference_params.has_seen_tokens)
|
||||||
|
)
|
||||||
|
return RoutingModuleOutput(
|
||||||
|
boundary_prob=boundary_prob, # (B, 2)
|
||||||
|
boundary_mask=boundary_prob[..., 1] > 0.5, # (B,)
|
||||||
|
selected_probs=boundary_prob.max(dim=-1).values.unsqueeze(-1), # (B, 1)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ChunkLayer(nn.Module):
|
||||||
|
"""
|
||||||
|
Chunk层:根据 boundary_mask 将被选中的“边界token”抽取出来,形成下一层序列。
|
||||||
|
支持两种模式:
|
||||||
|
- packed(cu_seqlens 非 None):直接在拼接后的序列上索引
|
||||||
|
- 非packed(mask 非 None):通过排序 trick 把True位置排到前面,并生成 next_mask
|
||||||
|
返回:
|
||||||
|
- next_hidden_states: 选中的token序列(packed: shape=(#selected, D);非packed: (B, M, D))
|
||||||
|
- next_cu_seqlens: packed模式下新序列的cu_seqlens;否则None
|
||||||
|
- next_max_seqlen: packed模式下选中的最大长度;非packed模式返回None
|
||||||
|
- next_mask: 非packed模式下的右侧pad掩码;packed模式下None
|
||||||
|
"""
|
||||||
|
|
||||||
|
def forward(self, hidden_states, boundary_mask, cu_seqlens=None, mask=None):
|
||||||
|
assert (mask is not None) or (
|
||||||
|
cu_seqlens is not None
|
||||||
|
), "Either mask or cu_seqlens must be provided"
|
||||||
|
|
||||||
|
if cu_seqlens is not None:
|
||||||
|
# packed:直接选择True的行,得到拼接后的 selected
|
||||||
|
next_hidden_states = hidden_states[boundary_mask]
|
||||||
|
# 新的cu_seqlens = 对每段最后一个位置(=cu_seqlens[1:]-1)累计True的计数,再前置0
|
||||||
|
next_cu_seqlens = F.pad(
|
||||||
|
boundary_mask.cumsum(dim=0)[cu_seqlens[1:] - 1], (1, 0)
|
||||||
|
)
|
||||||
|
# 新序列的最大段长(仅用于内核/优化)
|
||||||
|
next_max_seqlen = int((next_cu_seqlens[1:] - next_cu_seqlens[:-1]).max())
|
||||||
|
next_mask = None
|
||||||
|
else:
|
||||||
|
# 非packed:对每个batch内,把True位置排到前面(False放到靠后)
|
||||||
|
next_cu_seqlens = None
|
||||||
|
num_tokens = boundary_mask.sum(dim=-1) # 每个样本被选中的数量
|
||||||
|
next_max_seqlen = int(num_tokens.max())
|
||||||
|
|
||||||
|
device = hidden_states.device
|
||||||
|
L = hidden_states.shape[1]
|
||||||
|
# trick:用 (~boundary_mask)*L 把False加大,从而 argsort 后 True 的下标排在前面
|
||||||
|
token_idx = (
|
||||||
|
torch.arange(L, device=device)[None, :] + (~boundary_mask).long() * L
|
||||||
|
)
|
||||||
|
seq_sorted_indices = torch.argsort(token_idx, dim=1)
|
||||||
|
|
||||||
|
# 收集前 next_max_seqlen 个(不足的样本右侧pad)
|
||||||
|
next_hidden_states = torch.gather(
|
||||||
|
hidden_states,
|
||||||
|
dim=1,
|
||||||
|
index=seq_sorted_indices[:, :next_max_seqlen, None].expand(
|
||||||
|
-1, -1, hidden_states.shape[-1]
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# 下游的有效mask(右侧pad无效)
|
||||||
|
next_mask = (
|
||||||
|
torch.arange(next_max_seqlen, device=device)[None, :]
|
||||||
|
< num_tokens[:, None]
|
||||||
|
)
|
||||||
|
# 非packed模式下,不再需要 max_seqlen(返回None)
|
||||||
|
next_max_seqlen = None
|
||||||
|
|
||||||
|
return next_hidden_states, next_cu_seqlens, next_max_seqlen, next_mask
|
||||||
|
|
||||||
|
def step(self, hidden_states, boundary_mask):
|
||||||
|
# 流式step:仅返回当前步被选中的token(用于下一层)
|
||||||
|
return hidden_states[boundary_mask]
|
||||||
|
|
||||||
|
|
||||||
|
class DeChunkLayer(nn.Module):
|
||||||
|
"""
|
||||||
|
DeChunk层:把“被选中的边界token序列”反聚合(EMA)回原始等长序列。
|
||||||
|
实现上复用 Mamba2 的 Triton 扫描核 mamba_chunk_scan_combined:
|
||||||
|
- 将 d_model 切分为 nheads * headdim
|
||||||
|
- 使用参数 A=-1, b=p, c=1 的一阶状态空间/EMA形式进行前向扫描
|
||||||
|
- 最终把扫描输出根据分段索引映射回原位置(plug back)
|
||||||
|
支持:
|
||||||
|
- packed 模式(cu_seqlens)
|
||||||
|
- 非packed(batch+右侧pad)
|
||||||
|
- 流式 step(EMA递推)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
d_model,
|
||||||
|
dtype=torch.bfloat16,
|
||||||
|
block_size=256,
|
||||||
|
headdim=32,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.d_model = d_model
|
||||||
|
|
||||||
|
# 仅为内核要求:使用 bfloat16,块大小与头维拆分
|
||||||
|
self.dtype = dtype
|
||||||
|
self.block_size = block_size
|
||||||
|
self.headdim = headdim
|
||||||
|
assert d_model % self.headdim == 0
|
||||||
|
self.nheads = d_model // self.headdim
|
||||||
|
|
||||||
|
def allocate_inference_cache(self, batch_size, max_seqlen, device, dtype=None):
|
||||||
|
# 分配EMA的last_value缓存
|
||||||
|
return DeChunkState(
|
||||||
|
last_value=torch.zeros(
|
||||||
|
batch_size, self.d_model, device=device, dtype=dtype
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states, # 被选中的token序列(packed: (M, D);非packed: (B, M, D))
|
||||||
|
boundary_mask, # 原序列上的边界掩码((T,) 或 (B, L))
|
||||||
|
boundary_prob, # 原序列上的二分类概率((..., 2))
|
||||||
|
cu_seqlens=None,
|
||||||
|
inference_params=None,
|
||||||
|
mask=None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
核心思路:
|
||||||
|
1) 从 boundary_prob 得到 p = P(boundary) ∈ (1e-4, 1-1e-4)
|
||||||
|
2) 构造 dt = log(1 / (1-p)),并对输入做缩放 x = h / dt
|
||||||
|
3) 用 mamba_chunk_scan_combined 扫描:A=-1, b=p, c=1,对 (B, M, H, P) 进行块扫描
|
||||||
|
4) 将结果根据 cumulative boundary index 回填到原序列位置
|
||||||
|
"""
|
||||||
|
if inference_params is not None:
|
||||||
|
# prefill时必须有mask,且首token必须是边界(保证EMA初始化)
|
||||||
|
assert (
|
||||||
|
mask is not None
|
||||||
|
), "Mask must be provided if inference_params is provided"
|
||||||
|
assert boundary_mask[
|
||||||
|
:, 0
|
||||||
|
].all(), "First token must be a boundary if running prefill"
|
||||||
|
|
||||||
|
# 取边界概率的“边界类”概率 p,并限制在(1e-4, 1-1e-4)内,避免数值不稳
|
||||||
|
p = torch.clamp(boundary_prob[..., -1].float(), min=1e-4, max=1 - (1e-4))
|
||||||
|
|
||||||
|
if cu_seqlens is not None:
|
||||||
|
# packed:从原序列p中取出被选中的位置对应的概率,形状(B=1, M)
|
||||||
|
p = p[boundary_mask].unsqueeze(0)
|
||||||
|
# 为triton核准备packed序列的索引映射
|
||||||
|
seq_idx = get_seq_idx(cu_seqlens, device=hidden_states.device)
|
||||||
|
else:
|
||||||
|
B, L = boundary_mask.shape
|
||||||
|
seq_idx = None
|
||||||
|
# 与ChunkLayer一致的排序 trick,得到选中的顺序(True在前)
|
||||||
|
token_idx = (
|
||||||
|
torch.arange(L, device=hidden_states.device)[None, :]
|
||||||
|
+ (~boundary_mask).long() * L
|
||||||
|
)
|
||||||
|
seq_sorted_indices = torch.argsort(token_idx, dim=1)
|
||||||
|
|
||||||
|
# 取出与 hidden_states 对应长度的 p((B, M))
|
||||||
|
p = torch.gather(
|
||||||
|
p, dim=1, index=seq_sorted_indices[:, : hidden_states.shape[1]]
|
||||||
|
) # (B, M)
|
||||||
|
|
||||||
|
original_dtype = hidden_states.dtype
|
||||||
|
# 构造 EMA 扫描所需变量
|
||||||
|
dt = torch.log(1 / (1 - p)).to(self.dtype) # (B, M)
|
||||||
|
x = (hidden_states / dt[..., None]).to(self.dtype) # (B, M, D) / (B, M, 1)
|
||||||
|
|
||||||
|
# A, b, c 分别对应一阶状态空间/EMA的参数
|
||||||
|
A = -torch.ones(
|
||||||
|
(self.nheads,), device=hidden_states.device, dtype=torch.float32
|
||||||
|
)
|
||||||
|
b = p.to(self.dtype)
|
||||||
|
c = torch.ones_like(b)
|
||||||
|
|
||||||
|
# 调用triton核进行块扫描
|
||||||
|
out = mamba_chunk_scan_combined(
|
||||||
|
rearrange(x, "b l (h p) -> b l h p", p=self.headdim), # (B, M, H, P)
|
||||||
|
repeat(dt, "b l -> b l h", h=self.nheads), # (B, M, H)
|
||||||
|
A, # (H,)
|
||||||
|
rearrange(b, "b l -> b l 1 1"), # (B, M, 1, 1)
|
||||||
|
rearrange(c, "b l -> b l 1 1"), # (B, M, 1, 1)
|
||||||
|
chunk_size=self.block_size,
|
||||||
|
seq_idx=seq_idx, # packed时提供
|
||||||
|
)
|
||||||
|
out = rearrange(out, "b l h p -> b l (h p)") # (B, M, D)
|
||||||
|
|
||||||
|
# 将扫描结果回填(plug back)到原序列位置
|
||||||
|
if cu_seqlens is not None:
|
||||||
|
out = out.squeeze(0) # (M, D)
|
||||||
|
plug_back_idx = boundary_mask.cumsum(dim=0) - 1 # (T,)
|
||||||
|
out = torch.gather(
|
||||||
|
out, dim=0, index=plug_back_idx.unsqueeze(-1).expand(-1, self.d_model)
|
||||||
|
) # (T, D)
|
||||||
|
else:
|
||||||
|
plug_back_idx = torch.cumsum(boundary_mask, dim=1) - 1 # (B, L)
|
||||||
|
out = torch.gather(
|
||||||
|
out,
|
||||||
|
dim=1,
|
||||||
|
index=plug_back_idx.unsqueeze(-1).expand(-1, -1, self.d_model),
|
||||||
|
) # (B, L, D)
|
||||||
|
|
||||||
|
# 更新流式缓存
|
||||||
|
if inference_params is not None:
|
||||||
|
inference_params.last_value.copy_(out[:, -1])
|
||||||
|
|
||||||
|
return out.to(original_dtype)
|
||||||
|
|
||||||
|
def step(self, hidden_states, boundary_mask, boundary_prob, inference_params):
|
||||||
|
"""
|
||||||
|
流式单步 EMA 反聚合:
|
||||||
|
hidden_states: (B', 1, D),其中 B' = 当前步被选中的数量(boundary_mask.sum())
|
||||||
|
boundary_mask: (B,) 当前batch哪些位置被选中为边界
|
||||||
|
boundary_prob: (B, 2) 当前batch各位置的边界概率
|
||||||
|
输出:(B, 1, D),对应对所有位置做了一步 EMA 更新后的值
|
||||||
|
"""
|
||||||
|
B = boundary_mask.shape[0]
|
||||||
|
D = hidden_states.shape[-1]
|
||||||
|
|
||||||
|
# 构造当前步每个位置的 p(未被选中的位置 p=0)
|
||||||
|
p = torch.zeros(B, device=hidden_states.device, dtype=hidden_states.dtype)
|
||||||
|
p[boundary_mask] = boundary_prob[boundary_mask, -1].clamp(
|
||||||
|
min=1e-4, max=1 - (1e-4)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 构造当前被选中的隐藏状态(未选中为0)
|
||||||
|
current_hidden_states = torch.zeros(
|
||||||
|
B, D, device=hidden_states.device, dtype=hidden_states.dtype
|
||||||
|
)
|
||||||
|
current_hidden_states[boundary_mask] = hidden_states.squeeze(1)
|
||||||
|
|
||||||
|
# EMA:result = p * x + (1 - p) * last
|
||||||
|
result = p * current_hidden_states + (1 - p) * inference_params.last_value
|
||||||
|
inference_params.last_value.copy_(result)
|
||||||
|
|
||||||
|
return result.unsqueeze(1)
|
@ -17,25 +17,30 @@ class DSAttention(nn.Module):
|
|||||||
self.output_attention = output_attention
|
self.output_attention = output_attention
|
||||||
self.dropout = nn.Dropout(attention_dropout)
|
self.dropout = nn.Dropout(attention_dropout)
|
||||||
|
|
||||||
def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
|
def forward(self, queries, keys, values, attn_mask, tau=None, delta=None, key_padding_mask=None):
|
||||||
|
"""
|
||||||
|
key_padding_mask: (B, S) bool, True=valid, False=pad(可选,忽略或由上层应用)
|
||||||
|
"""
|
||||||
B, L, H, E = queries.shape
|
B, L, H, E = queries.shape
|
||||||
_, S, _, D = values.shape
|
_, S, _, D = values.shape
|
||||||
scale = self.scale or 1. / sqrt(E)
|
scale = self.scale or 1. / sqrt(E)
|
||||||
|
|
||||||
tau = 1.0 if tau is None else tau.unsqueeze(
|
tau = 1.0 if tau is None else tau.unsqueeze(1).unsqueeze(1) # B x 1 x 1 x 1
|
||||||
1).unsqueeze(1) # B x 1 x 1 x 1
|
delta = 0.0 if delta is None else delta.unsqueeze(1).unsqueeze(1) # B x 1 x 1 x S
|
||||||
delta = 0.0 if delta is None else delta.unsqueeze(
|
|
||||||
1).unsqueeze(1) # B x 1 x 1 x S
|
|
||||||
|
|
||||||
# De-stationary Attention, rescaling pre-softmax score with learned de-stationary factors
|
scores = torch.einsum("blhe,bshe->bhls", queries, keys) * tau + delta # (B,H,L,S)
|
||||||
scores = torch.einsum("blhe,bshe->bhls", queries, keys) * tau + delta
|
|
||||||
|
|
||||||
if self.mask_flag:
|
if self.mask_flag:
|
||||||
if attn_mask is None:
|
if attn_mask is None:
|
||||||
attn_mask = TriangularCausalMask(B, L, device=queries.device)
|
attn_mask = TriangularCausalMask(B, L, device=queries.device)
|
||||||
|
|
||||||
scores.masked_fill_(attn_mask.mask, -np.inf)
|
scores.masked_fill_(attn_mask.mask, -np.inf)
|
||||||
|
|
||||||
|
# 可选:基于key_padding_mask的无效键屏蔽(不改变原行为,默认None)
|
||||||
|
if key_padding_mask is not None:
|
||||||
|
# key_padding_mask: True 表示有效,False为padding
|
||||||
|
invalid_k = (~key_padding_mask).unsqueeze(1).unsqueeze(1) # (B,1,1,S)
|
||||||
|
scores = scores.masked_fill(invalid_k, -np.inf)
|
||||||
|
|
||||||
A = self.dropout(torch.softmax(scale * scores, dim=-1))
|
A = self.dropout(torch.softmax(scale * scores, dim=-1))
|
||||||
V = torch.einsum("bhls,bshd->blhd", A, values)
|
V = torch.einsum("bhls,bshd->blhd", A, values)
|
||||||
|
|
||||||
@ -46,6 +51,12 @@ class DSAttention(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class FullAttention(nn.Module):
|
class FullAttention(nn.Module):
|
||||||
|
"""
|
||||||
|
修正点:
|
||||||
|
- 新增 key_padding_mask 支持,用于屏蔽批内右侧pad的键向量(与DC变长对齐)
|
||||||
|
- key_padding_mask 约定:shape=(B, S),True=有效,False=padding
|
||||||
|
- 其余行为与原实现保持一致
|
||||||
|
"""
|
||||||
def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False):
|
def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False):
|
||||||
super(FullAttention, self).__init__()
|
super(FullAttention, self).__init__()
|
||||||
self.scale = scale
|
self.scale = scale
|
||||||
@ -53,21 +64,33 @@ class FullAttention(nn.Module):
|
|||||||
self.output_attention = output_attention
|
self.output_attention = output_attention
|
||||||
self.dropout = nn.Dropout(attention_dropout)
|
self.dropout = nn.Dropout(attention_dropout)
|
||||||
|
|
||||||
def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
|
def forward(self, queries, keys, values, attn_mask, tau=None, delta=None, key_padding_mask=None):
|
||||||
|
"""
|
||||||
|
queries: (B, L, H, E)
|
||||||
|
keys: (B, S, H, E)
|
||||||
|
values: (B, S, H, D)
|
||||||
|
attn_mask: TriangularCausalMask 或 None
|
||||||
|
key_padding_mask: (B, S) bool,True=有效,False=padding(可选)
|
||||||
|
"""
|
||||||
B, L, H, E = queries.shape
|
B, L, H, E = queries.shape
|
||||||
_, S, _, D = values.shape
|
_, S, _, D = values.shape
|
||||||
scale = self.scale or 1. / sqrt(E)
|
scale = self.scale or 1. / sqrt(E)
|
||||||
|
|
||||||
scores = torch.einsum("blhe,bshe->bhls", queries, keys)
|
scores = torch.einsum("blhe,bshe->bhls", queries, keys) # (B,H,L,S)
|
||||||
|
|
||||||
if self.mask_flag:
|
if self.mask_flag:
|
||||||
if attn_mask is None:
|
if attn_mask is None:
|
||||||
attn_mask = TriangularCausalMask(B, L, device=queries.device)
|
attn_mask = TriangularCausalMask(B, L, device=queries.device)
|
||||||
|
|
||||||
scores.masked_fill_(attn_mask.mask, -np.inf)
|
scores.masked_fill_(attn_mask.mask, -np.inf)
|
||||||
|
|
||||||
A = self.dropout(torch.softmax(scale * scores, dim=-1))
|
# 基于key_padding_mask屏蔽无效键(padding位置不参与注意力)
|
||||||
V = torch.einsum("bhls,bshd->blhd", A, values)
|
if key_padding_mask is not None:
|
||||||
|
# key_padding_mask: True=有效,False=padding
|
||||||
|
invalid_k = (~key_padding_mask).unsqueeze(1).unsqueeze(1) # (B,1,1,S)
|
||||||
|
scores = scores.masked_fill(invalid_k, -np.inf)
|
||||||
|
|
||||||
|
A = self.dropout(torch.softmax(scale * scores, dim=-1)) # (B,H,L,S)
|
||||||
|
V = torch.einsum("bhls,bshd->blhd", A, values) # (B,L,H,D)
|
||||||
|
|
||||||
if self.output_attention:
|
if self.output_attention:
|
||||||
return V.contiguous(), A
|
return V.contiguous(), A
|
||||||
@ -85,100 +108,86 @@ class ProbAttention(nn.Module):
|
|||||||
self.dropout = nn.Dropout(attention_dropout)
|
self.dropout = nn.Dropout(attention_dropout)
|
||||||
|
|
||||||
def _prob_QK(self, Q, K, sample_k, n_top): # n_top: c*ln(L_q)
|
def _prob_QK(self, Q, K, sample_k, n_top): # n_top: c*ln(L_q)
|
||||||
# Q [B, H, L, D]
|
# Q [B, H, L_q, D], K [B, H, L_k, D]
|
||||||
B, H, L_K, E = K.shape
|
B, H, L_K, E = K.shape
|
||||||
_, _, L_Q, _ = Q.shape
|
_, _, L_Q, _ = Q.shape
|
||||||
|
|
||||||
# calculate the sampled Q_K
|
|
||||||
K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E)
|
K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E)
|
||||||
# real U = U_part(factor*ln(L_k))*L_q
|
index_sample = torch.randint(L_K, (L_Q, sample_k), device=Q.device)
|
||||||
index_sample = torch.randint(L_K, (L_Q, sample_k))
|
K_sample = K_expand[:, :, torch.arange(L_Q, device=Q.device).unsqueeze(1), index_sample, :]
|
||||||
K_sample = K_expand[:, :, torch.arange(
|
Q_K_sample = torch.matmul(Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze(-2) # (B,H,L_Q,sample_k)
|
||||||
L_Q).unsqueeze(1), index_sample, :]
|
|
||||||
Q_K_sample = torch.matmul(
|
|
||||||
Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze()
|
|
||||||
|
|
||||||
# find the Top_k query with sparisty measurement
|
M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K) # (B,H,L_Q)
|
||||||
M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K)
|
M_top = M.topk(n_top, sorted=False)[1] # indices
|
||||||
M_top = M.topk(n_top, sorted=False)[1]
|
|
||||||
|
|
||||||
# use the reduced Q to calculate Q_K
|
|
||||||
Q_reduce = Q[torch.arange(B)[:, None, None],
|
Q_reduce = Q[torch.arange(B)[:, None, None],
|
||||||
torch.arange(H)[None, :, None],
|
torch.arange(H)[None, :, None],
|
||||||
M_top, :] # factor*ln(L_q)
|
M_top, :] # (B,H,n_top,D)
|
||||||
Q_K = torch.matmul(Q_reduce, K.transpose(-2, -1)) # factor*ln(L_q)*L_k
|
Q_K = torch.matmul(Q_reduce, K.transpose(-2, -1)) # (B,H,n_top,L_K)
|
||||||
|
|
||||||
return Q_K, M_top
|
return Q_K, M_top
|
||||||
|
|
||||||
def _get_initial_context(self, V, L_Q):
|
def _get_initial_context(self, V, L_Q):
|
||||||
B, H, L_V, D = V.shape
|
B, H, L_V, D = V.shape
|
||||||
if not self.mask_flag:
|
if not self.mask_flag:
|
||||||
# V_sum = V.sum(dim=-2)
|
V_mean = V.mean(dim=-2) # (B,H,D)
|
||||||
V_sum = V.mean(dim=-2)
|
context = V_mean.unsqueeze(-2).expand(B, H, L_Q, D).clone()
|
||||||
contex = V_sum.unsqueeze(-2).expand(B, H,
|
else:
|
||||||
L_Q, V_sum.shape[-1]).clone()
|
assert L_Q == L_V
|
||||||
else: # use mask
|
context = V.cumsum(dim=-2)
|
||||||
# requires that L_Q == L_V, i.e. for self-attention only
|
return context
|
||||||
assert (L_Q == L_V)
|
|
||||||
contex = V.cumsum(dim=-2)
|
|
||||||
return contex
|
|
||||||
|
|
||||||
def _update_context(self, context_in, V, scores, index, L_Q, attn_mask):
|
def _update_context(self, context_in, V, scores, index, L_Q, attn_mask):
|
||||||
B, H, L_V, D = V.shape
|
B, H, L_V, D = V.shape
|
||||||
|
|
||||||
if self.mask_flag:
|
if self.mask_flag:
|
||||||
attn_mask = ProbMask(B, H, L_Q, index, scores, device=V.device)
|
attn_mask = ProbMask(B, H, L_Q, index, scores, device=V.device)
|
||||||
scores.masked_fill_(attn_mask.mask, -np.inf)
|
scores.masked_fill_(attn_mask.mask, -np.inf)
|
||||||
|
attn = torch.softmax(scores, dim=-1)
|
||||||
attn = torch.softmax(scores, dim=-1) # nn.Softmax(dim=-1)(scores)
|
|
||||||
|
|
||||||
context_in[torch.arange(B)[:, None, None],
|
context_in[torch.arange(B)[:, None, None],
|
||||||
torch.arange(H)[None, :, None],
|
torch.arange(H)[None, :, None],
|
||||||
index, :] = torch.matmul(attn, V).type_as(context_in)
|
index, :] = torch.matmul(attn, V).type_as(context_in)
|
||||||
if self.output_attention:
|
if self.output_attention:
|
||||||
attns = (torch.ones([B, H, L_V, L_V]) /
|
attns = (torch.ones([B, H, L_V, L_V], device=attn.device, dtype=attn.dtype) / L_V)
|
||||||
L_V).type_as(attn).to(attn.device)
|
attns[torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :] = attn
|
||||||
attns[torch.arange(B)[:, None, None], torch.arange(H)[
|
|
||||||
None, :, None], index, :] = attn
|
|
||||||
return context_in, attns
|
return context_in, attns
|
||||||
else:
|
else:
|
||||||
return context_in, None
|
return context_in, None
|
||||||
|
|
||||||
def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
|
def forward(self, queries, keys, values, attn_mask, tau=None, delta=None, key_padding_mask=None):
|
||||||
|
"""
|
||||||
|
key_padding_mask 目前未集成到 ProbAttention(如需,可在scores处对无效键置 -inf)
|
||||||
|
"""
|
||||||
B, L_Q, H, D = queries.shape
|
B, L_Q, H, D = queries.shape
|
||||||
_, L_K, _, _ = keys.shape
|
_, L_K, _, _ = keys.shape
|
||||||
|
|
||||||
queries = queries.transpose(2, 1)
|
queries = queries.transpose(2, 1) # (B,H,L_Q,D)
|
||||||
keys = keys.transpose(2, 1)
|
keys = keys.transpose(2, 1) # (B,H,L_K,D)
|
||||||
values = values.transpose(2, 1)
|
values = values.transpose(2, 1) # (B,H,L_K,D)
|
||||||
|
|
||||||
U_part = self.factor * \
|
U_part = self.factor * int(np.ceil(np.log(L_K)))
|
||||||
np.ceil(np.log(L_K)).astype('int').item() # c*ln(L_k)
|
u = self.factor * int(np.ceil(np.log(L_Q)))
|
||||||
u = self.factor * \
|
|
||||||
np.ceil(np.log(L_Q)).astype('int').item() # c*ln(L_q)
|
|
||||||
|
|
||||||
U_part = U_part if U_part < L_K else L_K
|
U_part = min(U_part, L_K)
|
||||||
u = u if u < L_Q else L_Q
|
u = min(u, L_Q)
|
||||||
|
|
||||||
scores_top, index = self._prob_QK(
|
scores_top, index = self._prob_QK(queries, keys, sample_k=U_part, n_top=u)
|
||||||
queries, keys, sample_k=U_part, n_top=u)
|
|
||||||
|
|
||||||
# add scale factor
|
|
||||||
scale = self.scale or 1. / sqrt(D)
|
scale = self.scale or 1. / sqrt(D)
|
||||||
if scale is not None:
|
scores_top = scores_top * scale
|
||||||
scores_top = scores_top * scale
|
|
||||||
# get the context
|
|
||||||
context = self._get_initial_context(values, L_Q)
|
context = self._get_initial_context(values, L_Q)
|
||||||
# update the context with selected top_k queries
|
context, attn = self._update_context(context, values, scores_top, index, L_Q, attn_mask)
|
||||||
context, attn = self._update_context(
|
|
||||||
context, values, scores_top, index, L_Q, attn_mask)
|
|
||||||
|
|
||||||
return context.contiguous(), attn
|
return context.contiguous(), attn
|
||||||
|
|
||||||
|
|
||||||
class AttentionLayer(nn.Module):
|
class AttentionLayer(nn.Module):
|
||||||
def __init__(self, attention, d_model, n_heads, d_keys=None,
|
"""
|
||||||
d_values=None):
|
修正点:
|
||||||
|
- forward 新增 key_padding_mask 参数,并向 inner_attention 透传
|
||||||
|
- 保持与旧调用兼容(不传时默认None)
|
||||||
|
"""
|
||||||
|
def __init__(self, attention, d_model, n_heads, d_keys=None, d_values=None):
|
||||||
super(AttentionLayer, self).__init__()
|
super(AttentionLayer, self).__init__()
|
||||||
|
|
||||||
d_keys = d_keys or (d_model // n_heads)
|
d_keys = d_keys or (d_model // n_heads)
|
||||||
@ -191,7 +200,10 @@ class AttentionLayer(nn.Module):
|
|||||||
self.out_projection = nn.Linear(d_values * n_heads, d_model)
|
self.out_projection = nn.Linear(d_values * n_heads, d_model)
|
||||||
self.n_heads = n_heads
|
self.n_heads = n_heads
|
||||||
|
|
||||||
def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
|
def forward(self, queries, keys, values, attn_mask, tau=None, delta=None, key_padding_mask=None):
|
||||||
|
"""
|
||||||
|
key_padding_mask: (B, S) bool, True=有效,False=padding
|
||||||
|
"""
|
||||||
B, L, _ = queries.shape
|
B, L, _ = queries.shape
|
||||||
_, S, _ = keys.shape
|
_, S, _ = keys.shape
|
||||||
H = self.n_heads
|
H = self.n_heads
|
||||||
@ -206,10 +218,10 @@ class AttentionLayer(nn.Module):
|
|||||||
values,
|
values,
|
||||||
attn_mask,
|
attn_mask,
|
||||||
tau=tau,
|
tau=tau,
|
||||||
delta=delta
|
delta=delta,
|
||||||
|
key_padding_mask=key_padding_mask,
|
||||||
)
|
)
|
||||||
out = out.view(B, L, -1)
|
out = out.view(B, L, -1)
|
||||||
|
|
||||||
return self.out_projection(out), attn
|
return self.out_projection(out), attn
|
||||||
|
|
||||||
|
|
||||||
@ -232,12 +244,11 @@ class ReformerLayer(nn.Module):
|
|||||||
if N % (self.bucket_size * 2) == 0:
|
if N % (self.bucket_size * 2) == 0:
|
||||||
return queries
|
return queries
|
||||||
else:
|
else:
|
||||||
# fill the time series
|
|
||||||
fill_len = (self.bucket_size * 2) - (N % (self.bucket_size * 2))
|
fill_len = (self.bucket_size * 2) - (N % (self.bucket_size * 2))
|
||||||
return torch.cat([queries, torch.zeros([B, fill_len, C]).to(queries.device)], dim=1)
|
return torch.cat([queries, torch.zeros([B, fill_len, C]).to(queries.device)], dim=1)
|
||||||
|
|
||||||
def forward(self, queries, keys, values, attn_mask, tau, delta):
|
def forward(self, queries, keys, values, attn_mask, tau, delta, key_padding_mask=None):
|
||||||
# in Reformer: defalut queries=keys
|
# queries=keys in Reformer
|
||||||
B, N, C = queries.shape
|
B, N, C = queries.shape
|
||||||
queries = self.attn(self.fit_length(queries))[:, :N, :]
|
queries = self.attn(self.fit_length(queries))[:, :N, :]
|
||||||
return queries, None
|
return queries, None
|
||||||
@ -275,23 +286,23 @@ class TwoStageAttentionLayer(nn.Module):
|
|||||||
nn.GELU(),
|
nn.GELU(),
|
||||||
nn.Linear(d_ff, d_model))
|
nn.Linear(d_ff, d_model))
|
||||||
|
|
||||||
def forward(self, x, attn_mask=None, tau=None, delta=None):
|
def forward(self, x, attn_mask=None, tau=None, delta=None, key_padding_mask=None):
|
||||||
# Cross Time Stage: Directly apply MSA to each dimension
|
# Cross Time Stage: Directly apply MSA to each dimension
|
||||||
batch = x.shape[0]
|
batch = x.shape[0]
|
||||||
time_in = rearrange(x, 'b ts_d seg_num d_model -> (b ts_d) seg_num d_model')
|
time_in = rearrange(x, 'b ts_d seg_num d_model -> (b ts_d) seg_num d_model')
|
||||||
time_enc, attn = self.time_attention(
|
time_enc, attn = self.time_attention(
|
||||||
time_in, time_in, time_in, attn_mask=None, tau=None, delta=None
|
time_in, time_in, time_in, attn_mask=None, tau=None, delta=None, key_padding_mask=key_padding_mask
|
||||||
)
|
)
|
||||||
dim_in = time_in + self.dropout(time_enc)
|
dim_in = time_in + self.dropout(time_enc)
|
||||||
dim_in = self.norm1(dim_in)
|
dim_in = self.norm1(dim_in)
|
||||||
dim_in = dim_in + self.dropout(self.MLP1(dim_in))
|
dim_in = dim_in + self.dropout(self.MLP1(dim_in))
|
||||||
dim_in = self.norm2(dim_in)
|
dim_in = self.norm2(dim_in)
|
||||||
|
|
||||||
# Cross Dimension Stage: use a small set of learnable vectors to aggregate and distribute messages to build the D-to-D connection
|
# Cross Dimension Stage
|
||||||
dim_send = rearrange(dim_in, '(b ts_d) seg_num d_model -> (b seg_num) ts_d d_model', b=batch)
|
dim_send = rearrange(dim_in, '(b ts_d) seg_num d_model -> (b seg_num) ts_d d_model', b=batch)
|
||||||
batch_router = repeat(self.router, 'seg_num factor d_model -> (repeat seg_num) factor d_model', repeat=batch)
|
batch_router = repeat(self.router, 'seg_num factor d_model -> (repeat seg_num) factor d_model', repeat=batch)
|
||||||
dim_buffer, attn = self.dim_sender(batch_router, dim_send, dim_send, attn_mask=None, tau=None, delta=None)
|
dim_buffer, _ = self.dim_sender(batch_router, dim_send, dim_send, attn_mask=None, tau=None, delta=None)
|
||||||
dim_receive, attn = self.dim_receiver(dim_send, dim_buffer, dim_buffer, attn_mask=None, tau=None, delta=None)
|
dim_receive, _ = self.dim_receiver(dim_send, dim_buffer, dim_buffer, attn_mask=None, tau=None, delta=None)
|
||||||
dim_enc = dim_send + self.dropout(dim_receive)
|
dim_enc = dim_send + self.dropout(dim_receive)
|
||||||
dim_enc = self.norm3(dim_enc)
|
dim_enc = self.norm3(dim_enc)
|
||||||
dim_enc = dim_enc + self.dropout(self.MLP2(dim_enc))
|
dim_enc = dim_enc + self.dropout(self.MLP2(dim_enc))
|
||||||
|
528
models/DC_PatchTST.py
Normal file
528
models/DC_PatchTST.py
Normal file
@ -0,0 +1,528 @@
|
|||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from layers.SelfAttention_Family import FullAttention, AttentionLayer
|
||||||
|
|
||||||
|
# 需要 Mamba2 作为外层编码器
|
||||||
|
from mamba_ssm.modules.mamba2 import Mamba2
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# -------------------- Routing(余弦路由,和论文一致) --------------------
|
||||||
|
class RoutingModule(nn.Module):
|
||||||
|
def __init__(self, d_model):
|
||||||
|
super().__init__()
|
||||||
|
self.q_proj = nn.Linear(d_model, d_model, bias=False)
|
||||||
|
self.k_proj = nn.Linear(d_model, d_model, bias=False)
|
||||||
|
with torch.no_grad():
|
||||||
|
nn.init.eye_(self.q_proj.weight)
|
||||||
|
nn.init.eye_(self.k_proj.weight)
|
||||||
|
self.q_proj.weight._no_reinit = True
|
||||||
|
self.k_proj.weight._no_reinit = True
|
||||||
|
|
||||||
|
def forward(self, x, mask=None):
|
||||||
|
"""
|
||||||
|
x: (B, L, D)
|
||||||
|
mask: (B, L) bool, True=有效
|
||||||
|
返回:
|
||||||
|
boundary_prob: (B, L, 2)
|
||||||
|
boundary_mask: (B, L) bool
|
||||||
|
selected_probs: (B, L, 1)
|
||||||
|
"""
|
||||||
|
B, L, D = x.shape
|
||||||
|
q = F.normalize(self.q_proj(x[:, :-1]), dim=-1) # (B, L-1, D)
|
||||||
|
k = F.normalize(self.k_proj(x[:, 1:]), dim=-1) # (B, L-1, D)
|
||||||
|
cos_sim = (q * k).sum(dim=-1) # (B, L-1)
|
||||||
|
p = torch.clamp((1 - cos_sim) / 2, 0.0, 1.0) # (B, L-1)
|
||||||
|
p = F.pad(p, (1, 0), value=1.0) # 强制首位是边界
|
||||||
|
|
||||||
|
if mask is not None:
|
||||||
|
p = p * mask.float()
|
||||||
|
p[:, 0] = torch.where(mask[:, 0], torch.ones_like(p[:, 0]), p[:, 0])
|
||||||
|
|
||||||
|
boundary_prob = torch.stack([1 - p, p], dim=-1) # (B, L, 2)
|
||||||
|
selected_idx = boundary_prob.argmax(dim=-1)
|
||||||
|
boundary_mask = (selected_idx == 1)
|
||||||
|
if mask is not None:
|
||||||
|
boundary_mask = boundary_mask & mask
|
||||||
|
selected_probs = boundary_prob.gather(-1, selected_idx.unsqueeze(-1)) # (B, L, 1)
|
||||||
|
return boundary_prob, boundary_mask, selected_probs
|
||||||
|
|
||||||
|
|
||||||
|
# -------------------- 选择并右侧零pad(不丢弃、不重复填充) --------------------
|
||||||
|
def select_and_right_pad(x, boundary_mask):
|
||||||
|
"""
|
||||||
|
内存优化版本:减少临时tensor创建
|
||||||
|
x: (B, L, D), boundary_mask: (B, L) bool
|
||||||
|
返回:
|
||||||
|
x_pad: (B, T_max, D)
|
||||||
|
key_padding_mask: (B, T_max) bool, True=有效
|
||||||
|
lengths: (B,)
|
||||||
|
"""
|
||||||
|
B, L, D = x.shape
|
||||||
|
device = x.device
|
||||||
|
lengths = boundary_mask.sum(dim=1) # (B,)
|
||||||
|
T_max = int(lengths.max().item()) if lengths.max() > 0 else 1
|
||||||
|
|
||||||
|
x_pad = x.new_zeros(B, T_max, D)
|
||||||
|
key_padding_mask = torch.zeros(B, T_max, dtype=torch.bool, device=device)
|
||||||
|
|
||||||
|
# 预创建默认索引tensor避免重复创建
|
||||||
|
default_idx = torch.tensor([0], device=device)
|
||||||
|
|
||||||
|
for b in range(B):
|
||||||
|
mask_b = boundary_mask[b]
|
||||||
|
if mask_b.any():
|
||||||
|
idx = mask_b.nonzero(as_tuple=True)[0] # 更高效的nonzero
|
||||||
|
t = idx.numel()
|
||||||
|
x_pad[b, :t] = x[b, idx]
|
||||||
|
key_padding_mask[b, :t] = True
|
||||||
|
else:
|
||||||
|
# 使用预创建的tensor
|
||||||
|
x_pad[b, 0] = x[b, default_idx]
|
||||||
|
key_padding_mask[b, 0] = True
|
||||||
|
|
||||||
|
return x_pad, key_padding_mask, lengths
|
||||||
|
|
||||||
|
|
||||||
|
# -------------------- Mamba2 堆叠(外层编码器) --------------------
|
||||||
|
class Mamba2Encoder(nn.Module):
|
||||||
|
def __init__(self, d_model, depth=4, dropout=0.0):
|
||||||
|
super().__init__()
|
||||||
|
self.layers = nn.ModuleList([Mamba2(d_model=d_model) for _ in range(depth)])
|
||||||
|
self.norm = nn.LayerNorm(d_model)
|
||||||
|
self.dropout = nn.Dropout(dropout)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
for layer in self.layers:
|
||||||
|
x = layer(x)
|
||||||
|
x = self.norm(x)
|
||||||
|
x = self.dropout(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
# -------------------- 两层Encoder + DC(变长,不丢信息;以比率约束压缩) --------------------
|
||||||
|
class DCEmbedding2StageVarLen(nn.Module):
|
||||||
|
"""
|
||||||
|
- Stage 0: (B*nvars, L, 1) -> Linear(D0) -> Mamba2(D0) -> Routing -> 选择 -> 扩宽到 D1
|
||||||
|
- Stage 1: (B*nvars, L0_sel, D1) -> Mamba2(D1) -> Routing -> 选择
|
||||||
|
输出:
|
||||||
|
enc_out: (B*nvars, T_max, D1)
|
||||||
|
key_padding_mask: (B*nvars, T_max)
|
||||||
|
n_vars: int
|
||||||
|
aux: dict(含两层ratio loss与边界信息)
|
||||||
|
"""
|
||||||
|
def __init__(self, d_model_out, d_model_stage0, depth_enc0=4, depth_enc1=4, dropout=0.0,
|
||||||
|
target_ratio0=0.25, target_ratio1=0.5):
|
||||||
|
super().__init__()
|
||||||
|
assert d_model_out >= d_model_stage0, "要求 D0 <= D1"
|
||||||
|
self.d0 = d_model_stage0
|
||||||
|
self.d1 = d_model_out
|
||||||
|
|
||||||
|
# 标量 -> D0
|
||||||
|
self.input_proj = nn.Linear(1, self.d0)
|
||||||
|
|
||||||
|
# Stage 0
|
||||||
|
self.enc0 = Mamba2Encoder(self.d0, depth=depth_enc0, dropout=dropout)
|
||||||
|
self.router0 = RoutingModule(self.d0)
|
||||||
|
delta = self.d1 - self.d0
|
||||||
|
self.pad_vec = nn.Parameter(torch.zeros(delta)) if delta > 0 else None
|
||||||
|
self.target_ratio0 = target_ratio0
|
||||||
|
|
||||||
|
# Stage 1
|
||||||
|
self.enc1 = Mamba2Encoder(self.d1, depth=depth_enc1, dropout=dropout)
|
||||||
|
self.router1 = RoutingModule(self.d1)
|
||||||
|
self.target_ratio1 = target_ratio1
|
||||||
|
|
||||||
|
def _expand_width(self, x):
|
||||||
|
if self.pad_vec is None:
|
||||||
|
return x
|
||||||
|
B, L, _ = x.shape
|
||||||
|
return torch.cat([x, self.pad_vec.view(1, 1, -1).expand(B, L, -1)], dim=-1)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _ratio_loss(boundary_mask: torch.Tensor, boundary_prob: torch.Tensor, target_ratio: float) -> torch.Tensor:
|
||||||
|
eps = 1e-6
|
||||||
|
F_act = boundary_mask.float().mean(dim=1) # (B,)
|
||||||
|
G_prob = boundary_prob[..., 1].mean(dim=1) # (B,)
|
||||||
|
N = 1.0 / max(target_ratio, eps)
|
||||||
|
loss = N / (N - 1.0 + eps) * (((N - 1.0) * F_act * G_prob) + (1.0 - F_act) * (1.0 - G_prob))
|
||||||
|
return loss.mean()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
"""
|
||||||
|
x: (B, nvars, L)
|
||||||
|
内存优化版本:及时删除中间tensor
|
||||||
|
"""
|
||||||
|
B, nvars, L = x.shape
|
||||||
|
x = x.reshape(B * nvars, L, 1)
|
||||||
|
x = self.input_proj(x) # (B*nvars, L, D0)
|
||||||
|
|
||||||
|
# Stage 0
|
||||||
|
h0 = self.enc0(x) # (B*nvars, L, D0)
|
||||||
|
p0, bm0, _ = self.router0(h0)
|
||||||
|
h0_sel, mask0, len0 = select_and_right_pad(h0, bm0) # (B*nvars, L0_max, D0)
|
||||||
|
|
||||||
|
# 及时删除不需要的tensor
|
||||||
|
del h0
|
||||||
|
|
||||||
|
# h0_sel = self._expand_width(h0_sel) # (B*nvars, L0_max, D1)
|
||||||
|
|
||||||
|
# Stage 1
|
||||||
|
#h1 = self.enc1(h0_sel) # (B*nvars, L0_max, D1)
|
||||||
|
#p1, bm1, _ = self.router1(h1)
|
||||||
|
#bm1 = bm1 & mask0
|
||||||
|
#h1_sel, mask1, len1 = select_and_right_pad(h1, bm1) # (B*nvars, L1_max, D1)
|
||||||
|
|
||||||
|
# 及时删除中间tensor
|
||||||
|
#del h1, h0_sel
|
||||||
|
|
||||||
|
# 计算ratio loss时使用detach避免保存计算图
|
||||||
|
ratio_loss0 = self._ratio_loss(bm0, p0, target_ratio=self.target_ratio0)
|
||||||
|
# ratio_loss1 = self._ratio_loss(bm1, p1, target_ratio=self.target_ratio1)
|
||||||
|
|
||||||
|
# 简化aux字典,只保存必要信息
|
||||||
|
aux = {
|
||||||
|
"stage0": {"boundary_mask": bm0.detach(), "boundary_prob": p0.detach(), "lengths": len0.detach()},
|
||||||
|
# "stage1": {"boundary_mask": bm1.detach(), "boundary_prob": p1.detach(), "lengths": len1.detach()},
|
||||||
|
"ratio_loss0": ratio_loss0,
|
||||||
|
# "ratio_loss1": ratio_loss1,
|
||||||
|
}
|
||||||
|
|
||||||
|
return h0_sel, mask0, nvars, aux
|
||||||
|
|
||||||
|
|
||||||
|
# -------------------- Encoder/EncoderLayer(带 key_padding_mask 透传) --------------------
|
||||||
|
class EncoderLayerWithMask(nn.Module):
|
||||||
|
"""
|
||||||
|
与原EncoderLayer结构一致,但 forward 增加 key_padding_mask,并传入 AttentionLayer。
|
||||||
|
FFN 用简单的 MLP(与常规Transformer一致)。
|
||||||
|
"""
|
||||||
|
def __init__(self, attention: AttentionLayer, d_model, d_ff, dropout=0.1, activation="gelu"):
|
||||||
|
super().__init__()
|
||||||
|
self.attention = attention
|
||||||
|
self.dropout = nn.Dropout(dropout)
|
||||||
|
self.norm1 = nn.LayerNorm(d_model)
|
||||||
|
self.norm2 = nn.LayerNorm(d_model)
|
||||||
|
|
||||||
|
if activation == "relu":
|
||||||
|
act = nn.ReLU()
|
||||||
|
elif activation == "gelu":
|
||||||
|
act = nn.GELU()
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported activation: {activation}")
|
||||||
|
|
||||||
|
self.ffn = nn.Sequential(
|
||||||
|
nn.Linear(d_model, d_ff),
|
||||||
|
act,
|
||||||
|
nn.Dropout(dropout),
|
||||||
|
nn.Linear(d_ff, d_model),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x, attn_mask=None, tau=None, delta=None, key_padding_mask=None):
|
||||||
|
# Multi-head attention with key padding mask
|
||||||
|
attn_out, attn = self.attention(
|
||||||
|
x, x, x, attn_mask, tau=tau, delta=delta, key_padding_mask=key_padding_mask
|
||||||
|
)
|
||||||
|
x = x + self.dropout(attn_out)
|
||||||
|
x = self.norm1(x)
|
||||||
|
|
||||||
|
# FFN
|
||||||
|
y = self.ffn(x)
|
||||||
|
x = x + self.dropout(y)
|
||||||
|
x = self.norm2(x)
|
||||||
|
return x, attn
|
||||||
|
|
||||||
|
|
||||||
|
class EncoderWithMask(nn.Module):
|
||||||
|
"""
|
||||||
|
与原Encoder类似,但 forward 支持 key_padding_mask,并传递给每一层的注意力。
|
||||||
|
"""
|
||||||
|
def __init__(self, attn_layers, norm_layer=None):
|
||||||
|
super().__init__()
|
||||||
|
self.attn_layers = nn.ModuleList(attn_layers)
|
||||||
|
self.norm = norm_layer
|
||||||
|
|
||||||
|
def forward(self, x, attn_mask=None, key_padding_mask=None):
|
||||||
|
attns = []
|
||||||
|
for attn_layer in self.attn_layers:
|
||||||
|
x, attn = attn_layer(x, attn_mask=attn_mask, key_padding_mask=key_padding_mask)
|
||||||
|
attns.append(attn)
|
||||||
|
if self.norm is not None:
|
||||||
|
x = self.norm(x)
|
||||||
|
return x, attns
|
||||||
|
|
||||||
|
|
||||||
|
# -------------------- 门控注意力聚合 + 任务头(不依赖token数;保留信息) --------------------
|
||||||
|
def masked_softmax(logits: torch.Tensor, mask: torch.Tensor, dim: int = -1) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
logits: (..., T)
|
||||||
|
mask: (..., T) bool, True=有效
|
||||||
|
"""
|
||||||
|
neg_inf = torch.finfo(logits.dtype).min
|
||||||
|
logits = logits.masked_fill(~mask, neg_inf)
|
||||||
|
return torch.softmax(logits, dim=dim)
|
||||||
|
|
||||||
|
class GatedAttnAggregator(nn.Module):
|
||||||
|
"""
|
||||||
|
门控注意力聚合器(可学习查询 + mask softmax + 值端门控)
|
||||||
|
输入: x: (B*, T, D), mask: (B*, T) bool
|
||||||
|
输出: slots: (B*, R, D) 其中 R 为聚合插槽数(可配置)
|
||||||
|
"""
|
||||||
|
def __init__(self, d_model: int, num_slots: int = 4, d_att: int = None, dropout: float = 0.1):
|
||||||
|
super().__init__()
|
||||||
|
self.d_model = d_model
|
||||||
|
self.R = num_slots
|
||||||
|
self.d_att = d_att or d_model
|
||||||
|
|
||||||
|
# 可学习查询(R个)
|
||||||
|
self.query = nn.Parameter(torch.randn(self.R, self.d_att) / (self.d_att ** 0.5))
|
||||||
|
|
||||||
|
# 线性投影
|
||||||
|
self.key_proj = nn.Linear(d_model, self.d_att)
|
||||||
|
self.val_proj = nn.Linear(d_model, d_model)
|
||||||
|
|
||||||
|
# 值端门控(逐token标量门)
|
||||||
|
self.gate = nn.Sequential(
|
||||||
|
nn.Linear(d_model, d_model // 2),
|
||||||
|
nn.GELU(),
|
||||||
|
nn.Linear(d_model // 2, 1),
|
||||||
|
nn.Sigmoid()
|
||||||
|
)
|
||||||
|
|
||||||
|
self.dropout = nn.Dropout(dropout)
|
||||||
|
|
||||||
|
def forward(self, x_bt_t_d: torch.Tensor, mask_bt: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
x_bt_t_d: (B*, T, D)
|
||||||
|
mask_bt: (B*, T) bool
|
||||||
|
return: slots (B*, R, D)
|
||||||
|
"""
|
||||||
|
BStar, T, D = x_bt_t_d.shape
|
||||||
|
K = self.key_proj(x_bt_t_d) # (B*, T, d_att)
|
||||||
|
V = self.val_proj(x_bt_t_d) # (B*, T, D)
|
||||||
|
g = self.gate(x_bt_t_d) # (B*, T, 1)
|
||||||
|
Vg = V * g # 门控后的值
|
||||||
|
|
||||||
|
Q = self.query.unsqueeze(0).expand(BStar, -1, -1) # (B*, R, d_att)
|
||||||
|
|
||||||
|
logits = torch.matmul(Q, K.transpose(1, 2)) / (self.d_att ** 0.5) # (B*, R, T)
|
||||||
|
attn_mask = mask_bt.unsqueeze(1) # (B*, 1, T)
|
||||||
|
attn = masked_softmax(logits, attn_mask, dim=-1)
|
||||||
|
attn = self.dropout(attn)
|
||||||
|
|
||||||
|
slots = torch.matmul(attn, Vg) # (B*, R, D)
|
||||||
|
return slots
|
||||||
|
|
||||||
|
class AttnPoolHeadForecast(nn.Module):
|
||||||
|
"""
|
||||||
|
预测任务头:门控注意力聚合到 R 个slots,再映射到 target_window(pred_len)
|
||||||
|
输出:(B, pred_len, nvars)
|
||||||
|
"""
|
||||||
|
def __init__(self, d_model: int, target_window: int, num_slots: int = 4, dropout: float = 0.1):
|
||||||
|
super().__init__()
|
||||||
|
self.agg = GatedAttnAggregator(d_model, num_slots=num_slots, dropout=dropout)
|
||||||
|
self.proj = nn.Sequential(
|
||||||
|
nn.LayerNorm(num_slots * d_model),
|
||||||
|
nn.Linear(num_slots * d_model, target_window),
|
||||||
|
)
|
||||||
|
self.dropout = nn.Dropout(dropout)
|
||||||
|
|
||||||
|
def forward(self, enc_out_bt_t_d: torch.Tensor, key_padding_mask_bt: torch.Tensor, n_vars: int, B: int):
|
||||||
|
slots = self.agg(enc_out_bt_t_d, key_padding_mask_bt) # (B*, R, D)
|
||||||
|
slots = slots.reshape(B, n_vars, -1) # (B, nvars, R*D)
|
||||||
|
out = self.proj(self.dropout(slots)) # (B, nvars, pred_len)
|
||||||
|
return out.permute(0, 2, 1) # (B, pred_len, nvars)
|
||||||
|
|
||||||
|
class AttnPoolHeadSeq(nn.Module):
|
||||||
|
"""
|
||||||
|
序列重建头:门控注意力聚合后映射到 seq_len
|
||||||
|
输出:(B, seq_len, nvars)
|
||||||
|
"""
|
||||||
|
def __init__(self, d_model: int, target_window: int, num_slots: int = 4, dropout: float = 0.1):
|
||||||
|
super().__init__()
|
||||||
|
self.agg = GatedAttnAggregator(d_model, num_slots=num_slots, dropout=dropout)
|
||||||
|
self.proj = nn.Sequential(
|
||||||
|
nn.LayerNorm(num_slots * d_model),
|
||||||
|
nn.Linear(num_slots * d_model, target_window),
|
||||||
|
)
|
||||||
|
self.dropout = nn.Dropout(dropout)
|
||||||
|
|
||||||
|
def forward(self, enc_out_bt_t_d: torch.Tensor, key_padding_mask_bt: torch.Tensor, n_vars: int, B: int):
|
||||||
|
slots = self.agg(enc_out_bt_t_d, key_padding_mask_bt) # (B*, R, D)
|
||||||
|
slots = slots.reshape(B, n_vars, -1) # (B, nvars, R*D)
|
||||||
|
out = self.proj(self.dropout(slots)) # (B, nvars, seq_len)
|
||||||
|
return out.permute(0, 2, 1) # (B, seq_len, nvars)
|
||||||
|
|
||||||
|
class AttnPoolHeadCls(nn.Module):
|
||||||
|
"""
|
||||||
|
分类头:每变量先门控注意力聚合到 R 个slots,拼接所有变量后线性分类。
|
||||||
|
输出:(B, num_class)
|
||||||
|
"""
|
||||||
|
def __init__(self, d_model: int, n_vars: int, num_class: int, num_slots: int = 4, dropout: float = 0.1):
|
||||||
|
super().__init__()
|
||||||
|
self.agg = GatedAttnAggregator(d_model, num_slots=num_slots, dropout=dropout)
|
||||||
|
self.dropout = nn.Dropout(dropout)
|
||||||
|
self.proj = nn.Sequential(
|
||||||
|
nn.LayerNorm(n_vars * num_slots * d_model),
|
||||||
|
nn.Linear(n_vars * num_slots * d_model, num_class),
|
||||||
|
)
|
||||||
|
self.n_vars = n_vars
|
||||||
|
self.num_slots = num_slots
|
||||||
|
self.d_model = d_model
|
||||||
|
|
||||||
|
def forward(self, enc_out_bt_t_d: torch.Tensor, key_padding_mask_bt: torch.Tensor, n_vars: int, B: int):
|
||||||
|
slots = self.agg(enc_out_bt_t_d, key_padding_mask_bt) # (B*, R, D)
|
||||||
|
slots = slots.reshape(B, n_vars, self.num_slots * self.d_model) # (B, nvars, R*D)
|
||||||
|
flat = self.dropout(slots.reshape(B, -1)) # (B, nvars*R*D)
|
||||||
|
return self.proj(flat)
|
||||||
|
|
||||||
|
|
||||||
|
# -------------------- 主模型:两层DC(比率控制) + 带mask的Encoder + 门控聚合头 --------------------
|
||||||
|
class Transpose(nn.Module):
|
||||||
|
def __init__(self, *dims, contiguous=False):
|
||||||
|
super().__init__()
|
||||||
|
self.dims, self.contiguous = dims, contiguous
|
||||||
|
def forward(self, x):
|
||||||
|
return x.transpose(*self.dims).contiguous() if self.contiguous else x.transpose(*self.dims)
|
||||||
|
|
||||||
|
class Model(nn.Module):
|
||||||
|
"""
|
||||||
|
PatchTST with DC and masked attention + gated heads:
|
||||||
|
- 用两层 Mamba2 编码器 + 动态分块 替代 PatchEmbedding
|
||||||
|
- DC 使用 ratio loss(target_ratio0/1)控制压缩强度;随层级加深,序列变短,d_model 变大(D0->D1)
|
||||||
|
- 注意力传入 key_padding_mask 屏蔽pad
|
||||||
|
- 头部使用门控注意力聚合(不依赖token数,信息保留更充分)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, configs,
|
||||||
|
d_model_stage0=None, # D0,默认= d_model // 2
|
||||||
|
depth_enc0=1, depth_enc1=1,
|
||||||
|
target_ratio0=0.25, # 约等于 1/N0
|
||||||
|
target_ratio1=0.5, # 约等于 1/N1
|
||||||
|
agg_slots=4, # 门控聚合的slot数
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.task_name = configs.task_name
|
||||||
|
self.seq_len = configs.seq_len
|
||||||
|
self.pred_len = configs.pred_len
|
||||||
|
self.enc_in = configs.enc_in
|
||||||
|
|
||||||
|
# DC 嵌入
|
||||||
|
D1 = configs.d_model
|
||||||
|
D0 = d_model_stage0 if d_model_stage0 is not None else max(16, D1 // 2)
|
||||||
|
assert D1 >= D0, "要求 D0 <= D1"
|
||||||
|
self.dc_embedding = DCEmbedding2StageVarLen(
|
||||||
|
d_model_out=D1,
|
||||||
|
d_model_stage0=D0,
|
||||||
|
depth_enc0=depth_enc0,
|
||||||
|
depth_enc1=depth_enc1,
|
||||||
|
dropout=configs.dropout,
|
||||||
|
target_ratio0=target_ratio0,
|
||||||
|
target_ratio1=target_ratio1,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 带mask的Encoder
|
||||||
|
attn_layers = [
|
||||||
|
EncoderLayerWithMask(
|
||||||
|
AttentionLayer(
|
||||||
|
FullAttention(False, configs.factor, attention_dropout=configs.dropout, output_attention=False),
|
||||||
|
D1, configs.n_heads
|
||||||
|
),
|
||||||
|
d_model=D1,
|
||||||
|
d_ff=configs.d_ff,
|
||||||
|
dropout=configs.dropout,
|
||||||
|
activation=configs.activation
|
||||||
|
) for _ in range(configs.e_layers)
|
||||||
|
]
|
||||||
|
self.encoder = EncoderWithMask(
|
||||||
|
attn_layers,
|
||||||
|
norm_layer=nn.Sequential(Transpose(1, 2), nn.BatchNorm1d(D1), Transpose(1, 2))
|
||||||
|
)
|
||||||
|
|
||||||
|
# 门控聚合头(与token数无关)
|
||||||
|
if self.task_name in ('long_term_forecast', 'short_term_forecast'):
|
||||||
|
self.head = AttnPoolHeadForecast(D1, self.pred_len, num_slots=agg_slots, dropout=configs.dropout)
|
||||||
|
elif self.task_name in ('imputation', 'anomaly_detection'):
|
||||||
|
self.head = AttnPoolHeadSeq(D1, self.seq_len, num_slots=agg_slots, dropout=configs.dropout)
|
||||||
|
elif self.task_name == 'classification':
|
||||||
|
self.head_cls = AttnPoolHeadCls(D1, n_vars=self.enc_in, num_class=configs.num_class, num_slots=agg_slots, dropout=configs.dropout)
|
||||||
|
|
||||||
|
# --------- 归一化/反归一化 ---------
|
||||||
|
def _pre_norm(self, x):
|
||||||
|
means = x.mean(1, keepdim=True).detach()
|
||||||
|
x = x - means
|
||||||
|
stdev = torch.sqrt(torch.var(x, dim=1, keepdim=True, unbiased=False) + 1e-5)
|
||||||
|
x = x / stdev
|
||||||
|
return x, means, stdev
|
||||||
|
|
||||||
|
def _denorm(self, y, means, stdev, length):
|
||||||
|
return y * (stdev[:, 0, :].unsqueeze(1).repeat(1, length, 1)) + \
|
||||||
|
(means[:, 0, :].unsqueeze(1).repeat(1, length, 1))
|
||||||
|
|
||||||
|
# --------- DC + Transformer Encoder(携带 key_padding_mask) ----------
|
||||||
|
def _embed_and_encode(self, x_enc):
|
||||||
|
"""
|
||||||
|
x_enc: (B, L, C)
|
||||||
|
返回:
|
||||||
|
enc_out: (B*nvars, T_max, D1)
|
||||||
|
n_vars: int
|
||||||
|
key_padding_mask: (B*nvars, T_max)
|
||||||
|
aux: dict
|
||||||
|
"""
|
||||||
|
B, L, C = x_enc.shape
|
||||||
|
x_vars = x_enc.permute(0, 2, 1) # (B, nvars, L)
|
||||||
|
enc_out, key_padding_mask, n_vars, aux = self.dc_embedding(x_vars)
|
||||||
|
enc_out, _ = self.encoder(enc_out, attn_mask=None, key_padding_mask=key_padding_mask)
|
||||||
|
return enc_out, n_vars, key_padding_mask, B, aux
|
||||||
|
|
||||||
|
# --------- 各任务前向 ---------
|
||||||
|
def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
|
||||||
|
x_enc, means, stdev = self._pre_norm(x_enc)
|
||||||
|
enc_out, n_vars, key_padding_mask, B, aux = self._embed_and_encode(x_enc)
|
||||||
|
dec_out = self.head(enc_out, key_padding_mask, n_vars, B) # (B, pred_len, nvars)
|
||||||
|
dec_out = self._denorm(dec_out, means, stdev, self.pred_len)
|
||||||
|
return dec_out, aux
|
||||||
|
|
||||||
|
def imputation(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask):
|
||||||
|
means = torch.sum(x_enc, dim=1) / torch.sum(mask == 1, dim=1)
|
||||||
|
means = means.unsqueeze(1).detach()
|
||||||
|
x = x_enc - means
|
||||||
|
x = x.masked_fill(mask == 0, 0)
|
||||||
|
stdev = torch.sqrt(torch.sum(x * x, dim=1) / torch.sum(mask == 1, dim=1) + 1e-5)
|
||||||
|
stdev = stdev.unsqueeze(1).detach()
|
||||||
|
x = x / stdev
|
||||||
|
|
||||||
|
enc_out, n_vars, key_padding_mask, B, aux = self._embed_and_encode(x)
|
||||||
|
dec_out = self.head(enc_out, key_padding_mask, n_vars, B) # (B, seq_len, nvars)
|
||||||
|
dec_out = self._denorm(dec_out, means, stdev, self.seq_len)
|
||||||
|
return dec_out, aux
|
||||||
|
|
||||||
|
def anomaly_detection(self, x_enc):
|
||||||
|
x_enc, means, stdev = self._pre_norm(x_enc)
|
||||||
|
enc_out, n_vars, key_padding_mask, B, aux = self._embed_and_encode(x_enc)
|
||||||
|
dec_out = self.head(enc_out, key_padding_mask, n_vars, B) # (B, seq_len, nvars)
|
||||||
|
dec_out = self._denorm(dec_out, means, stdev, self.seq_len)
|
||||||
|
return dec_out, aux
|
||||||
|
|
||||||
|
def classification(self, x_enc, x_mark_enc):
|
||||||
|
x_enc, _, _ = self._pre_norm(x_enc)
|
||||||
|
enc_out, n_vars, key_padding_mask, B, aux = self._embed_and_encode(x_enc)
|
||||||
|
logits = self.head_cls(enc_out, key_padding_mask, n_vars, B) # (B, num_class)
|
||||||
|
return logits, aux
|
||||||
|
|
||||||
|
def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
|
||||||
|
if self.task_name in ('long_term_forecast', 'short_term_forecast'):
|
||||||
|
dec_out, aux = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
|
||||||
|
return dec_out[:, -self.pred_len:, :], aux # [B, L, D], aux含ratio losses
|
||||||
|
if self.task_name == 'imputation':
|
||||||
|
dec_out, aux = self.imputation(x_enc, x_mark_enc, x_dec, x_mark_dec, mask)
|
||||||
|
return dec_out, aux
|
||||||
|
if self.task_name == 'anomaly_detection':
|
||||||
|
dec_out, aux = self.anomaly_detection(x_enc)
|
||||||
|
return dec_out, aux
|
||||||
|
if self.task_name == 'classification':
|
||||||
|
logits, aux = self.classification(x_enc, x_mark_enc)
|
||||||
|
return logits, aux
|
||||||
|
return None, None
|
339
models/DC_hnet.py
Normal file
339
models/DC_hnet.py
Normal file
@ -0,0 +1,339 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional, Literal, List
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
# 来自你的代码库(可直接使用)
|
||||||
|
from hnet.modules.dc import RoutingModule, ChunkLayer
|
||||||
|
from hnet.modules.isotropic import Isotropic
|
||||||
|
from hnet.models.config_hnet import HNetConfig, SSMConfig, AttnConfig
|
||||||
|
|
||||||
|
# -------------------- 辅助 --------------------
|
||||||
|
def create_isotropic_encoder(d_model, arch="m", height=4, device=None, dtype=None):
|
||||||
|
"""创建简化的Isotropic编码器"""
|
||||||
|
factory_kwargs = {"device": device, "dtype": dtype}
|
||||||
|
|
||||||
|
# 创建HNetConfig,确保list字段有足够的元素
|
||||||
|
config = HNetConfig(
|
||||||
|
arch_layout=[f"{arch}{height}"],
|
||||||
|
d_model=[d_model],
|
||||||
|
d_intermediate=[d_model * 2],
|
||||||
|
ssm_cfg=SSMConfig(
|
||||||
|
d_conv=4,
|
||||||
|
expand=2,
|
||||||
|
d_state=128,
|
||||||
|
chunk_size=256
|
||||||
|
),
|
||||||
|
attn_cfg=AttnConfig(
|
||||||
|
num_heads=[8], # 确保有至少一个元素
|
||||||
|
rotary_emb_dim=[0], # 确保有至少一个元素
|
||||||
|
window_size=[-1] # 确保有至少一个元素
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return Isotropic(
|
||||||
|
config=config,
|
||||||
|
pos_idx=0,
|
||||||
|
stage_idx=0,
|
||||||
|
**factory_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
def ratio_loss(boundary_mask: torch.Tensor, boundary_prob: torch.Tensor, target_N: int) -> torch.Tensor:
|
||||||
|
F_act = boundary_mask.float().mean(dim=1) # (B,)
|
||||||
|
G_prob = boundary_prob[..., 1].mean(dim=1) # (B,)
|
||||||
|
N = float(target_N)
|
||||||
|
loss = N / (N - 1.0) * (((N - 1.0) * F_act) + (1.0 - F_act) * (1.0 - G_prob))
|
||||||
|
return loss.mean()
|
||||||
|
|
||||||
|
def masked_mean(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
|
||||||
|
mask_f = mask.float().unsqueeze(-1) # (B, L, 1)
|
||||||
|
s = (x * mask_f).sum(dim=1) # (B, D)
|
||||||
|
denom = mask_f.sum(dim=1).clamp_min(1.0)
|
||||||
|
return s / denom
|
||||||
|
|
||||||
|
# -------------------- 多层Encoder(金字塔):每层Mamba2 + 路由下采样,只有最终有主网络 --------------------
|
||||||
|
class PyramidEncoders_NoDechunk(nn.Module):
|
||||||
|
"""
|
||||||
|
层级结构(仅编码器逐层压缩;主网络只在最终一层):
|
||||||
|
输入 x0: (B, L0, 1)
|
||||||
|
- 线性升维 -> D0
|
||||||
|
For s = 0..S-1:
|
||||||
|
Es(Mamba2, D_s) -> h_s (B, L_s, D_s)
|
||||||
|
路由 + 下采样 -> x_{s+1} (B, L_{s+1}, D_s), mask_{s+1}
|
||||||
|
维度扩展 D_s -> D_{s+1}(拼接共享向量)
|
||||||
|
最终 x_S: (B, L_S, D_S) 送入单一主网络 M (Transformer/Mamba)
|
||||||
|
跨尺度融合(不去分块):融合 E^0 的 pooled_enc0 与 主网络 pooled_main
|
||||||
|
"""
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
d_models: List[int], # [D0, D1, ..., D_S] 单调非降
|
||||||
|
encoder_cfg_per_stage: List[dict], # S个编码器配置(必须 arch='m'/'M')
|
||||||
|
main_cfg: dict, # 单一主网络配置(在最压缩序列上工作)
|
||||||
|
fusion_dropout: float = 0.1,
|
||||||
|
dtype: Optional[torch.dtype] = None,
|
||||||
|
device: Optional[torch.device] = None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
factory_kwargs = {"device": device, "dtype": dtype}
|
||||||
|
|
||||||
|
assert len(d_models) >= 1
|
||||||
|
S = len(d_models) - 1
|
||||||
|
assert S == len(encoder_cfg_per_stage), "stage数等于encoder配置数"
|
||||||
|
for i in range(S):
|
||||||
|
assert d_models[i+1] >= d_models[i], "需满足 D_s <= D_{s+1}(宽度单调增加)"
|
||||||
|
assert encoder_cfg_per_stage[i].get("arch", "m") in ("m", "M"), "Encoder必须为Mamba2"
|
||||||
|
|
||||||
|
self.S = S
|
||||||
|
self.d_models = d_models
|
||||||
|
|
||||||
|
# 输入升维到 D0
|
||||||
|
self.input_proj = nn.Linear(1, d_models[0], **factory_kwargs)
|
||||||
|
|
||||||
|
# 每层编码器 + 路由 + 下采样 + 扩宽参数
|
||||||
|
self.encoders = nn.ModuleList()
|
||||||
|
self.routers = nn.ModuleList()
|
||||||
|
self.chunks = nn.ModuleList()
|
||||||
|
self.pad_vectors = nn.ParameterList()
|
||||||
|
for s in range(S):
|
||||||
|
self.encoders.append(
|
||||||
|
create_isotropic_encoder(
|
||||||
|
d_model=d_models[s],
|
||||||
|
**{k: v for k, v in encoder_cfg_per_stage[s].items() if k != "d_model"},
|
||||||
|
**factory_kwargs
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.routers.append(RoutingModule(d_models[s], **factory_kwargs))
|
||||||
|
self.chunks.append(ChunkLayer())
|
||||||
|
delta = d_models[s+1] - d_models[s]
|
||||||
|
self.pad_vectors.append(nn.Parameter(torch.zeros(delta, **factory_kwargs)) if delta > 0 else nn.Parameter(torch.empty(0, **factory_kwargs)))
|
||||||
|
|
||||||
|
# 最终唯一的主网络:在 D_S & L_S 上运行
|
||||||
|
self.main_network = create_isotropic_encoder(
|
||||||
|
d_model=d_models[-1],
|
||||||
|
**{k: v for k, v in main_cfg.items() if k != "d_model"},
|
||||||
|
**factory_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
# 跨尺度融合:将 pooled_enc0(D0) 投到 D_S 并与 pooled_main(D_S) 融合 -> D_S
|
||||||
|
self.proj_enc0_to_DS = nn.Linear(d_models[0], d_models[-1], **factory_kwargs)
|
||||||
|
self.fusion_head = nn.Sequential(
|
||||||
|
nn.Linear(d_models[-1] + d_models[-1], d_models[-1], **factory_kwargs),
|
||||||
|
nn.GELU(),
|
||||||
|
nn.Dropout(fusion_dropout),
|
||||||
|
nn.Linear(d_models[-1], d_models[-1], **factory_kwargs),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _expand_width(self, x: torch.Tensor, pad_vec: nn.Parameter) -> torch.Tensor:
|
||||||
|
if pad_vec.numel() == 0:
|
||||||
|
return x
|
||||||
|
early = x.shape[:-1]
|
||||||
|
return torch.cat([x, pad_vec.expand(*early, -1)], dim=-1)
|
||||||
|
|
||||||
|
def forward(self, x_scalar: torch.Tensor, mask: Optional[torch.Tensor] = None, return_seq: bool = False):
|
||||||
|
"""
|
||||||
|
x_scalar: (B, L) 或 (B, L, 1)
|
||||||
|
mask: (B, L) bool
|
||||||
|
返回:
|
||||||
|
fused_vec: (B, D_S)
|
||||||
|
debug: 可选
|
||||||
|
aux: 包含各层路由信息(供ratio loss)
|
||||||
|
"""
|
||||||
|
if x_scalar.dim() == 2:
|
||||||
|
x_scalar = x_scalar.unsqueeze(-1) # (B, L, 1)
|
||||||
|
B, L, _ = x_scalar.shape
|
||||||
|
device = x_scalar.device
|
||||||
|
if mask is None:
|
||||||
|
mask = torch.ones(B, L, dtype=torch.bool, device=device)
|
||||||
|
|
||||||
|
# 初始升维到 D0
|
||||||
|
x = self.input_proj(x_scalar) # (B, L0, D0)
|
||||||
|
cur_mask = mask
|
||||||
|
|
||||||
|
pooled_enc0 = None
|
||||||
|
aux_per_stage = []
|
||||||
|
seq_debug = [] if return_seq else None
|
||||||
|
|
||||||
|
# 逐层:Encoder(Mamba2)->Routing->Chunk->Expand D
|
||||||
|
for s in range(self.S):
|
||||||
|
d_in = self.d_models[s]
|
||||||
|
# 细粒度编码(未压缩序列)
|
||||||
|
h_enc = self.encoders[s](x, mask=cur_mask) # (B, L_s, D_s)
|
||||||
|
|
||||||
|
if s == 0:
|
||||||
|
pooled_enc0 = masked_mean(h_enc, cur_mask) # (B, D0)
|
||||||
|
|
||||||
|
# 路由 + 下采样(得到更短序列)
|
||||||
|
bpred = self.routers[s](h_enc, mask=cur_mask)
|
||||||
|
x_next, _, _, mask_next = self.chunks[s](h_enc, bpred.boundary_mask, mask=cur_mask) # (B, L_{s+1}, D_s)
|
||||||
|
|
||||||
|
# 扩展宽度 D_s -> D_{s+1}
|
||||||
|
x_next = self._expand_width(x_next, self.pad_vectors[s]) # (B, L_{s+1}, D_{s+1})
|
||||||
|
|
||||||
|
# 推进到下一层
|
||||||
|
x, cur_mask = x_next, mask_next
|
||||||
|
|
||||||
|
aux_per_stage.append({
|
||||||
|
"boundary_mask": bpred.boundary_mask,
|
||||||
|
"boundary_prob": bpred.boundary_prob,
|
||||||
|
"selected_probs": bpred.selected_probs,
|
||||||
|
})
|
||||||
|
if return_seq:
|
||||||
|
seq_debug.append({"stage": s, "seq": x, "mask": cur_mask})
|
||||||
|
|
||||||
|
# 现在 x: (B, L_S, D_S), cur_mask: (B, L_S)
|
||||||
|
# 最终单一主网络在最压缩序列上
|
||||||
|
h_main = self.main_network(x, mask=cur_mask) # (B, L_S, D_S)
|
||||||
|
|
||||||
|
# 主网络池化
|
||||||
|
if cur_mask is None:
|
||||||
|
pooled_main = h_main.mean(dim=1) # (B, D_S)
|
||||||
|
else:
|
||||||
|
pooled_main = (h_main * cur_mask.float().unsqueeze(-1)).sum(dim=1) / \
|
||||||
|
cur_mask.float().sum(dim=1, keepdim=True).clamp_min(1.0)
|
||||||
|
|
||||||
|
# 跨尺度融合:E^0 全局池化 与 主网络池化
|
||||||
|
pooled_enc0_in_DS = self.proj_enc0_to_DS(pooled_enc0) # (B, D_S)
|
||||||
|
fused = torch.cat([pooled_enc0_in_DS, pooled_main], dim=-1) # (B, 2*D_S)
|
||||||
|
fused = self.fusion_head(fused) # (B, D_S)
|
||||||
|
|
||||||
|
aux = {"per_stage": aux_per_stage}
|
||||||
|
if return_seq:
|
||||||
|
return fused, {"stages": seq_debug, "main_seq": h_main, "main_mask": cur_mask}, aux
|
||||||
|
else:
|
||||||
|
return fused, None, aux
|
||||||
|
|
||||||
|
# -------------------- 顶层:多通道融合 + 分类头(仅一个主网络) --------------------
|
||||||
|
@dataclass
|
||||||
|
class HierEncodersSingleMainConfig:
|
||||||
|
num_channels: int
|
||||||
|
d_models: List[int] # [D0, D1, ..., D_S] 单调非降
|
||||||
|
num_classes: int
|
||||||
|
encoder_cfg_per_stage: List[dict] # S个编码器配置(均为Mamba2, height≈4)
|
||||||
|
main_cfg: dict # 单一主网络配置(Transformer或Mamba2),d_model自动用D_S
|
||||||
|
target_compression_N_per_stage: List[int]
|
||||||
|
share_channel: bool = True
|
||||||
|
fusion_across_channels: Literal["mean", "concat"] = "mean"
|
||||||
|
dropout: float = 0.1
|
||||||
|
|
||||||
|
class HierEncodersSingleMainClassifier(nn.Module):
|
||||||
|
def __init__(self, cfg: HierEncodersSingleMainConfig, dtype=None, device=None):
|
||||||
|
super().__init__()
|
||||||
|
self.cfg = cfg
|
||||||
|
factory_kwargs = {"dtype": dtype, "device": device}
|
||||||
|
|
||||||
|
S = len(cfg.d_models) - 1
|
||||||
|
assert S == len(cfg.encoder_cfg_per_stage) == len(cfg.target_compression_N_per_stage), "stage数不一致"
|
||||||
|
|
||||||
|
if cfg.share_channel:
|
||||||
|
self.channel_encoder = PyramidEncoders_NoDechunk(
|
||||||
|
d_models=cfg.d_models,
|
||||||
|
encoder_cfg_per_stage=cfg.encoder_cfg_per_stage,
|
||||||
|
main_cfg=cfg.main_cfg,
|
||||||
|
**factory_kwargs,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.channel_encoder = nn.ModuleList([
|
||||||
|
PyramidEncoders_NoDechunk(
|
||||||
|
d_models=cfg.d_models,
|
||||||
|
encoder_cfg_per_stage=cfg.encoder_cfg_per_stage,
|
||||||
|
main_cfg=cfg.main_cfg,
|
||||||
|
**factory_kwargs,
|
||||||
|
)
|
||||||
|
for _ in range(cfg.num_channels)
|
||||||
|
])
|
||||||
|
|
||||||
|
fusion_dim = (cfg.num_channels * cfg.d_models[-1]) if cfg.fusion_across_channels == "concat" \
|
||||||
|
else cfg.d_models[-1]
|
||||||
|
self.dropout = nn.Dropout(cfg.dropout)
|
||||||
|
self.head = nn.Linear(fusion_dim, cfg.num_classes, **factory_kwargs)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None, return_seq: bool = False):
|
||||||
|
"""
|
||||||
|
x: (B, L, N) 多通道输入
|
||||||
|
mask: (B, L) 时序mask
|
||||||
|
"""
|
||||||
|
B, L, N = x.shape
|
||||||
|
assert N == self.cfg.num_channels
|
||||||
|
|
||||||
|
channel_vecs: List[torch.Tensor] = []
|
||||||
|
ratio_losses = []
|
||||||
|
seq_dbg_all = [] if return_seq else None
|
||||||
|
|
||||||
|
for c in range(N):
|
||||||
|
x_c = x[..., c] # (B, L)
|
||||||
|
if self.cfg.share_channel:
|
||||||
|
vec, seq_dbg, aux = self.channel_encoder(x_c, mask=mask, return_seq=return_seq)
|
||||||
|
else:
|
||||||
|
vec, seq_dbg, aux = self.channel_encoder[c](x_c, mask=mask, return_seq=return_seq)
|
||||||
|
|
||||||
|
# ratio loss 累加(每个encoder stage一项)
|
||||||
|
total_rl = 0.0
|
||||||
|
for s, aux_s in enumerate(aux["per_stage"]):
|
||||||
|
rl = ratio_loss(aux_s["boundary_mask"], aux_s["boundary_prob"], self.cfg.target_compression_N_per_stage[s])
|
||||||
|
total_rl = total_rl + rl
|
||||||
|
ratio_losses.append(total_rl)
|
||||||
|
|
||||||
|
channel_vecs.append(vec)
|
||||||
|
if return_seq:
|
||||||
|
seq_dbg_all.append(seq_dbg)
|
||||||
|
|
||||||
|
if self.cfg.fusion_across_channels == "concat":
|
||||||
|
fused = torch.cat(channel_vecs, dim=-1) # (B, N*D_S)
|
||||||
|
else:
|
||||||
|
fused = torch.stack(channel_vecs, dim=1).mean(dim=1) # (B, D_S)
|
||||||
|
|
||||||
|
fused = self.dropout(fused)
|
||||||
|
logits = self.head(fused)
|
||||||
|
|
||||||
|
aux_all = {"ratio_loss": torch.stack(ratio_losses).mean()}
|
||||||
|
if return_seq:
|
||||||
|
return logits, seq_dbg_all, aux_all
|
||||||
|
else:
|
||||||
|
return logits, None, aux_all
|
||||||
|
|
||||||
|
# -------------------- 使用示例 --------------------
|
||||||
|
if __name__ == "__main__":
|
||||||
|
"""
|
||||||
|
符合要求:
|
||||||
|
- 多层仅增加编码器数量(每层Mamba2 + 动态分块),主网络只有最终一个
|
||||||
|
- 序列长度逐层缩短(由DC决定),通道维度 d_model 单调增大(SpaceByte式共享向量拼接)
|
||||||
|
- 不使用去分块(dechunk);跨尺度融合用 E^0 的全局池化 + 最终主网络池化
|
||||||
|
"""
|
||||||
|
B, L, N = 8, 1024, 6
|
||||||
|
num_classes = 7
|
||||||
|
d_models = [128, 256, 512] # D0 <= D1 <= D2
|
||||||
|
|
||||||
|
encoder_cfg_per_stage = [
|
||||||
|
dict(arch="m", height=4, ssm_cfg=dict(), attn_cfg=dict()), # stage 0 encoder (Mamba2)
|
||||||
|
dict(arch="m", height=4, ssm_cfg=dict(), attn_cfg=dict()), # stage 1 encoder (Mamba2)
|
||||||
|
]
|
||||||
|
main_cfg = dict(
|
||||||
|
arch="T", height=12, ssm_cfg=dict(), attn_cfg=dict(num_heads=8) # 最终主网络(较重)
|
||||||
|
)
|
||||||
|
target_compression_N_per_stage = [4, 4]
|
||||||
|
|
||||||
|
cfg = HierEncodersSingleMainConfig(
|
||||||
|
num_channels=N,
|
||||||
|
d_models=d_models,
|
||||||
|
num_classes=num_classes,
|
||||||
|
encoder_cfg_per_stage=encoder_cfg_per_stage,
|
||||||
|
main_cfg=main_cfg,
|
||||||
|
target_compression_N_per_stage=target_compression_N_per_stage,
|
||||||
|
share_channel=True,
|
||||||
|
fusion_across_channels="mean",
|
||||||
|
dropout=0.1,
|
||||||
|
)
|
||||||
|
|
||||||
|
model = HierEncodersSingleMainClassifier(cfg).cuda().train()
|
||||||
|
x = torch.randn(B, L, N, device="cuda")
|
||||||
|
mask = torch.ones(B, L, dtype=torch.bool, device="cuda")
|
||||||
|
|
||||||
|
logits, _, aux = model(x, mask=mask, return_seq=False)
|
||||||
|
y = torch.randint(0, num_classes, (B,), device="cuda")
|
||||||
|
cls_loss = F.cross_entropy(logits, y)
|
||||||
|
ratio_reg = 0.03 * aux["ratio_loss"]
|
||||||
|
loss = cls_loss + ratio_reg
|
||||||
|
loss.backward()
|
||||||
|
print("logits:", logits.shape, "loss:", float(loss))
|
138
models/vanillaMamba-Copy1.py
Normal file
138
models/vanillaMamba-Copy1.py
Normal file
@ -0,0 +1,138 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from mamba_ssm import Mamba2
|
||||||
|
|
||||||
|
|
||||||
|
class ValueEmbedding(nn.Module):
|
||||||
|
"""
|
||||||
|
对每个时间步的单通道标量做线性投影到 d_model,并可选 Dropout。
|
||||||
|
不包含 temporal embedding 和 positional embedding。
|
||||||
|
"""
|
||||||
|
def __init__(self, in_dim: int, d_model: int, dropout: float = 0.0, bias: bool = True):
|
||||||
|
super().__init__()
|
||||||
|
self.proj = nn.Linear(in_dim, d_model, bias=bias)
|
||||||
|
self.dropout = nn.Dropout(dropout) if dropout and dropout > 0.0 else nn.Identity()
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
# x: [B, L, 1] -> [B, L, d_model]
|
||||||
|
return self.dropout(self.proj(x))
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelMambaBlock(nn.Module):
|
||||||
|
"""
|
||||||
|
针对单个通道的两层 Mamba-2 处理块:
|
||||||
|
- 输入: [B, L, 1],先做投影到 d_model
|
||||||
|
- 两层 Mamba2,且在第一层输出和第二层输出均添加残差连接
|
||||||
|
- 每层后接 LayerNorm
|
||||||
|
- 输出: [B, L, d_model]
|
||||||
|
"""
|
||||||
|
def __init__(self, d_model: int, dropout: float, m2_kwargs: dict):
|
||||||
|
super().__init__()
|
||||||
|
self.embed = ValueEmbedding(in_dim=1, d_model=d_model, dropout=dropout, bias=True)
|
||||||
|
|
||||||
|
# 两层 Mamba-2
|
||||||
|
self.mamba1 = Mamba2(d_model=d_model, **m2_kwargs)
|
||||||
|
self.mamba2 = Mamba2(d_model=d_model, **m2_kwargs)
|
||||||
|
|
||||||
|
# 每层后接的归一化
|
||||||
|
self.ln1 = nn.LayerNorm(d_model)
|
||||||
|
self.ln2 = nn.LayerNorm(d_model)
|
||||||
|
|
||||||
|
def forward(self, x_ch: torch.Tensor) -> torch.Tensor:
|
||||||
|
# x_ch: [B, L, 1]
|
||||||
|
x = self.embed(x_ch) # [B, L, d_model]
|
||||||
|
|
||||||
|
# 第一层 + 残差
|
||||||
|
y1 = self.mamba1(x) # [B, L, d_model]
|
||||||
|
y1 = self.ln1(x + y1) # 残差1 + LN
|
||||||
|
|
||||||
|
# 第二层 + 残差
|
||||||
|
y2 = self.mamba2(y1) # [B, L, d_model]
|
||||||
|
y2 = self.ln2(y1 + y2) # 残差2 + LN
|
||||||
|
|
||||||
|
return y2 # [B, L, d_model]
|
||||||
|
|
||||||
|
|
||||||
|
class Model(nn.Module):
|
||||||
|
"""
|
||||||
|
按通道独立处理的 Mamba-2 分类模型:
|
||||||
|
- 将输入的每个通道拆开,分别使用独立的两层 Mamba2(含两处残差)
|
||||||
|
- 每个通道得到 [B, L, d_model] 输出
|
||||||
|
- 取各通道最后时间步的表示拼接,接分类头
|
||||||
|
输入:
|
||||||
|
- x_enc: [B, L, D] 多变量时间序列
|
||||||
|
输出:
|
||||||
|
- logits: [B, num_class]
|
||||||
|
"""
|
||||||
|
def __init__(self, configs):
|
||||||
|
super().__init__()
|
||||||
|
self.task_name = getattr(configs, 'task_name', 'classification')
|
||||||
|
assert self.task_name == 'classification', "当前模型仅实现 classification 任务"
|
||||||
|
|
||||||
|
# 基本配置
|
||||||
|
self.enc_in = configs.enc_in # 通道数 D
|
||||||
|
self.d_model = configs.d_model # 每通道的模型维度
|
||||||
|
self.num_class = configs.num_class
|
||||||
|
self.dropout = getattr(configs, 'dropout', 0.1)
|
||||||
|
|
||||||
|
# Mamba-2 超参数(按需从 configs 读取)
|
||||||
|
# 注意:此处不再使用 e_layers 的堆叠,而是固定每通道两层以满足“在第一层和第二层输出处添加残差”的要求
|
||||||
|
m2_kwargs = dict(
|
||||||
|
d_state=getattr(configs, 'd_state', 64),
|
||||||
|
d_conv=getattr(configs, 'd_conv', 4),
|
||||||
|
expand=getattr(configs, 'expand', 2),
|
||||||
|
headdim=getattr(configs, 'headdim', 64),
|
||||||
|
d_ssm=getattr(configs, 'd_ssm', None),
|
||||||
|
ngroups=getattr(configs, 'ngroups', 1),
|
||||||
|
A_init_range=getattr(configs, 'A_init_range', (1, 16)),
|
||||||
|
D_has_hdim=getattr(configs, 'D_has_hdim', False),
|
||||||
|
rmsnorm=getattr(configs, 'rmsnorm', True),
|
||||||
|
norm_before_gate=getattr(configs, 'norm_before_gate', False),
|
||||||
|
dt_min=getattr(configs, 'dt_min', 0.001),
|
||||||
|
dt_max=getattr(configs, 'dt_max', 0.1),
|
||||||
|
dt_init_floor=getattr(configs, 'dt_init_floor', 1e-4),
|
||||||
|
dt_limit=getattr(configs, 'dt_limit', (0.0, float("inf"))),
|
||||||
|
bias=getattr(configs, 'bias', False),
|
||||||
|
conv_bias=getattr(configs, 'conv_bias', True),
|
||||||
|
chunk_size=getattr(configs, 'chunk_size', 256),
|
||||||
|
use_mem_eff_path=getattr(configs, 'use_mem_eff_path', True),
|
||||||
|
)
|
||||||
|
|
||||||
|
# 为每个通道构建独立的两层 Mamba2 处理块
|
||||||
|
self.channel_blocks = nn.ModuleList([
|
||||||
|
ChannelMambaBlock(d_model=self.d_model, dropout=self.dropout, m2_kwargs=m2_kwargs)
|
||||||
|
for _ in range(self.enc_in)
|
||||||
|
])
|
||||||
|
|
||||||
|
# 分类头:将各通道最后时间步的表示拼接后 -> GELU -> Dropout -> Linear
|
||||||
|
self.act = nn.GELU()
|
||||||
|
self.head = nn.Sequential(
|
||||||
|
nn.Dropout(self.dropout),
|
||||||
|
nn.Linear(self.d_model * self.enc_in, self.num_class)
|
||||||
|
)
|
||||||
|
|
||||||
|
def classification(self, x_enc: torch.Tensor) -> torch.Tensor:
|
||||||
|
# x_enc: [B, L, D]
|
||||||
|
B, L, D = x_enc.shape
|
||||||
|
assert D == self.enc_in, f"输入通道数 {D} 与 enc_in {self.enc_in} 不一致"
|
||||||
|
|
||||||
|
per_channel_last = []
|
||||||
|
for c in range(D):
|
||||||
|
# 取出单通道序列 [B, L] -> [B, L, 1]
|
||||||
|
x_ch = x_enc[:, :, c].unsqueeze(-1)
|
||||||
|
y_ch = self.channel_blocks[c](x_ch) # [B, L, d_model]
|
||||||
|
per_channel_last.append(y_ch[:, -1, :]) # [B, d_model]
|
||||||
|
|
||||||
|
# 拼接各通道最后时刻的表示 -> [B, D * d_model]
|
||||||
|
h_last = torch.cat(per_channel_last, dim=-1)
|
||||||
|
|
||||||
|
# 分类头
|
||||||
|
h_last = self.act(h_last)
|
||||||
|
logits = self.head(h_last) # [B, num_class]
|
||||||
|
return logits
|
||||||
|
|
||||||
|
# 与 TimesNet 的 forward 签名保持一致;忽略 x_mark_enc / x_dec / x_mark_dec / mask
|
||||||
|
def forward(self, x_enc, x_mark_enc=None, x_dec=None, x_mark_dec=None, mask=None):
|
||||||
|
return self.classification(x_enc)
|
203
models/vanillaMamba.py
Normal file
203
models/vanillaMamba.py
Normal file
@ -0,0 +1,203 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from mamba_ssm import Mamba2
|
||||||
|
|
||||||
|
|
||||||
|
class ValueEmbedding(nn.Module):
|
||||||
|
"""
|
||||||
|
对每个时间步的单通道标量做线性投影到 d_model,并可选 Dropout。
|
||||||
|
不包含 temporal embedding 和 positional embedding。
|
||||||
|
"""
|
||||||
|
def __init__(self, in_dim: int, d_model: int, dropout: float = 0.0, bias: bool = True):
|
||||||
|
super().__init__()
|
||||||
|
self.proj = nn.Linear(in_dim, d_model, bias=bias)
|
||||||
|
self.dropout = nn.Dropout(dropout) if dropout and dropout > 0.0 else nn.Identity()
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
# x: [B, L, 1] -> [B, L, d_model]
|
||||||
|
return self.dropout(self.proj(x))
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelMambaBlock(nn.Module):
|
||||||
|
"""
|
||||||
|
针对单个通道的两层 Mamba-2 处理块:
|
||||||
|
- 输入: [B, L, 1],先做投影到 d_model
|
||||||
|
- 两层 Mamba2,且在第一层输出和第二层输出均添加残差连接
|
||||||
|
- 每层后接 LayerNorm
|
||||||
|
- 输出: [B, L, d_model]
|
||||||
|
"""
|
||||||
|
def __init__(self, d_model: int, dropout: float, m2_kwargs: dict):
|
||||||
|
super().__init__()
|
||||||
|
self.embed = ValueEmbedding(in_dim=1, d_model=d_model, dropout=dropout, bias=True)
|
||||||
|
|
||||||
|
# 两层 Mamba-2
|
||||||
|
self.mamba1 = Mamba2(d_model=d_model, **m2_kwargs)
|
||||||
|
self.mamba2 = Mamba2(d_model=d_model, **m2_kwargs)
|
||||||
|
|
||||||
|
# 每层后接的归一化
|
||||||
|
self.ln1 = nn.LayerNorm(d_model)
|
||||||
|
self.ln2 = nn.LayerNorm(d_model)
|
||||||
|
|
||||||
|
def forward(self, x_ch: torch.Tensor) -> torch.Tensor:
|
||||||
|
# x_ch: [B, L, 1]
|
||||||
|
x = self.embed(x_ch) # [B, L, d_model]
|
||||||
|
|
||||||
|
# 第一层 + 残差
|
||||||
|
y1 = self.mamba1(x) # [B, L, d_model]
|
||||||
|
y1 = self.ln1(x + y1) # 残差1 + LN
|
||||||
|
|
||||||
|
# 第二层 + 残差
|
||||||
|
y2 = self.mamba2(y1) # [B, L, d_model]
|
||||||
|
y2 = self.ln2(y1 + y2) # 残差2 + LN
|
||||||
|
|
||||||
|
return y2 # [B, L, d_model]
|
||||||
|
|
||||||
|
|
||||||
|
class Model(nn.Module):
|
||||||
|
"""
|
||||||
|
按通道独立处理的 Mamba-2 模型,支持:
|
||||||
|
- 分类:各通道独立提取,取最后时刻拼接 -> 分类头
|
||||||
|
- 长/短期预测:各通道独立提取,保留整段序列,经时间维线性映射到目标长度,再投影回标量并拼接
|
||||||
|
注意:预测输出通道数与输入通道数严格相同(逐通道预测)。
|
||||||
|
|
||||||
|
输入:
|
||||||
|
- x_enc: [B, L, D] 多变量时间序列
|
||||||
|
- x_mark_enc, x_dec, x_mark_dec, mask: 兼容接口参数(本模型在分类/预测中未使用这些标注)
|
||||||
|
|
||||||
|
输出:
|
||||||
|
- classification: logits [B, num_class]
|
||||||
|
- forecast: [B, pred_len, D]
|
||||||
|
"""
|
||||||
|
def __init__(self, configs):
|
||||||
|
super().__init__()
|
||||||
|
# 任务类型
|
||||||
|
self.task_name = getattr(configs, 'task_name', 'classification')
|
||||||
|
assert self.task_name in ['classification', 'long_term_forecast', 'short_term_forecast'], \
|
||||||
|
"只支持 classification / long_term_forecast / short_term_forecast"
|
||||||
|
|
||||||
|
# 基本配置
|
||||||
|
self.enc_in = configs.enc_in # 通道数 D
|
||||||
|
self.d_model = configs.d_model # 每通道的模型维度
|
||||||
|
self.num_class = getattr(configs, 'num_class', None)
|
||||||
|
self.dropout = getattr(configs, 'dropout', 0.1)
|
||||||
|
|
||||||
|
# 预测相关
|
||||||
|
self.seq_len = getattr(configs, 'seq_len', None)
|
||||||
|
self.pred_len = getattr(configs, 'pred_len', None)
|
||||||
|
if self.task_name in ['long_term_forecast', 'short_term_forecast']:
|
||||||
|
assert self.seq_len is not None and self.pred_len is not None, "预测任务需要 seq_len 与 pred_len"
|
||||||
|
# 输出通道必须与输入通道一致
|
||||||
|
self.c_out = getattr(configs, 'c_out', self.enc_in)
|
||||||
|
assert self.c_out == self.enc_in, "预测任务要求输出通道 c_out 与输入通道 enc_in 一致"
|
||||||
|
|
||||||
|
# Mamba-2 超参数
|
||||||
|
m2_kwargs = dict(
|
||||||
|
d_state=getattr(configs, 'd_state', 64),
|
||||||
|
d_conv=getattr(configs, 'd_conv', 4),
|
||||||
|
expand=getattr(configs, 'expand', 2),
|
||||||
|
headdim=getattr(configs, 'headdim', 64),
|
||||||
|
d_ssm=getattr(configs, 'd_ssm', None),
|
||||||
|
ngroups=getattr(configs, 'ngroups', 1),
|
||||||
|
A_init_range=getattr(configs, 'A_init_range', (1, 16)),
|
||||||
|
D_has_hdim=getattr(configs, 'D_has_hdim', False),
|
||||||
|
rmsnorm=getattr(configs, 'rmsnorm', True),
|
||||||
|
norm_before_gate=getattr(configs, 'norm_before_gate', False),
|
||||||
|
dt_min=getattr(configs, 'dt_min', 0.001),
|
||||||
|
dt_max=getattr(configs, 'dt_max', 0.1),
|
||||||
|
dt_init_floor=getattr(configs, 'dt_init_floor', 1e-4),
|
||||||
|
dt_limit=getattr(configs, 'dt_limit', (0.0, float("inf"))),
|
||||||
|
bias=getattr(configs, 'bias', False),
|
||||||
|
conv_bias=getattr(configs, 'conv_bias', True),
|
||||||
|
chunk_size=getattr(configs, 'chunk_size', 256),
|
||||||
|
use_mem_eff_path=getattr(configs, 'use_mem_eff_path', True),
|
||||||
|
)
|
||||||
|
|
||||||
|
# 为每个通道构建独立的两层 Mamba2 处理块
|
||||||
|
self.channel_blocks = nn.ModuleList([
|
||||||
|
ChannelMambaBlock(d_model=self.d_model, dropout=self.dropout, m2_kwargs=m2_kwargs)
|
||||||
|
for _ in range(self.enc_in)
|
||||||
|
])
|
||||||
|
|
||||||
|
# 分类头:将各通道最后时间步的表示拼接后 -> GELU -> Dropout -> Linear
|
||||||
|
if self.task_name == 'classification':
|
||||||
|
assert self.num_class is not None, "classification 需要提供 num_class"
|
||||||
|
self.act = nn.GELU()
|
||||||
|
self.head = nn.Sequential(
|
||||||
|
nn.Dropout(self.dropout),
|
||||||
|
nn.Linear(self.d_model * self.enc_in, self.num_class)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 预测头:
|
||||||
|
# - 先对时间维做线性映射: [B, L, d_model] -> [B, pred_len, d_model]
|
||||||
|
# - 再将 d_model 投影为单通道标量: [B, pred_len, d_model] -> [B, pred_len, 1]
|
||||||
|
if self.task_name in ['long_term_forecast', 'short_term_forecast']:
|
||||||
|
self.predict_linear = nn.Linear(self.seq_len, self.pred_len)
|
||||||
|
self.projection = nn.Linear(self.d_model, 1, bias=True)
|
||||||
|
|
||||||
|
def classification(self, x_enc: torch.Tensor) -> torch.Tensor:
|
||||||
|
# x_enc: [B, L, D]
|
||||||
|
B, L, D = x_enc.shape
|
||||||
|
assert D == self.enc_in, f"输入通道数 {D} 与 enc_in {self.enc_in} 不一致"
|
||||||
|
|
||||||
|
per_channel_last = []
|
||||||
|
for c in range(D):
|
||||||
|
# 取出单通道序列 [B, L] -> [B, L, 1]
|
||||||
|
x_ch = x_enc[:, :, c].unsqueeze(-1)
|
||||||
|
y_ch = self.channel_blocks[c](x_ch) # [B, L, d_model]
|
||||||
|
per_channel_last.append(y_ch[:, -1, :]) # [B, d_model]
|
||||||
|
|
||||||
|
# 拼接各通道最后时刻的表示 -> [B, D * d_model]
|
||||||
|
h_last = torch.cat(per_channel_last, dim=-1)
|
||||||
|
|
||||||
|
# 分类头
|
||||||
|
logits = self.head(self.act(h_last)) # [B, num_class]
|
||||||
|
return logits
|
||||||
|
|
||||||
|
def forecast(self, x_enc: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
逐通道预测:
|
||||||
|
- 归一化(时间维),按通道独立提取
|
||||||
|
- 使用整段 Mamba 输出序列,经时间维线性映射到目标长度,再投影为标量
|
||||||
|
- 反归一化
|
||||||
|
返回:
|
||||||
|
dec_out: [B, L+pred_len, D],在 forward 中会取最后 pred_len 段
|
||||||
|
"""
|
||||||
|
B, L, D = x_enc.shape
|
||||||
|
assert L == self.seq_len, f"输入长度 {L} 与配置 seq_len {self.seq_len} 不一致"
|
||||||
|
assert D == self.enc_in, f"输入通道数 {D} 与 enc_in {self.enc_in} 不一致"
|
||||||
|
|
||||||
|
# Normalization (per Non-stationary Transformer)
|
||||||
|
means = x_enc.mean(1, keepdim=True).detach() # [B, 1, D]
|
||||||
|
x = x_enc - means
|
||||||
|
stdev = torch.sqrt(x.var(dim=1, keepdim=True, unbiased=False) + 1e-5) # [B, 1, D]
|
||||||
|
x = x / stdev
|
||||||
|
|
||||||
|
per_channel_seq = []
|
||||||
|
for c in range(D):
|
||||||
|
x_ch = x[:, :, c].unsqueeze(-1) # [B, L, 1]
|
||||||
|
h_ch = self.channel_blocks[c](x_ch) # [B, L, d_model]
|
||||||
|
# 时间维映射到 L + pred_len
|
||||||
|
h_ch = self.predict_linear(h_ch.permute(0, 2, 1)).permute(0, 2, 1) # [B, L+pred_len, d_model]
|
||||||
|
# 投影回单通道
|
||||||
|
y_ch = self.projection(h_ch) # [B, L+pred_len, 1]
|
||||||
|
per_channel_seq.append(y_ch)
|
||||||
|
|
||||||
|
# 拼接通道
|
||||||
|
dec_out = torch.cat(per_channel_seq, dim=-1) # [B, pred_len, D]
|
||||||
|
|
||||||
|
# De-normalization
|
||||||
|
dec_out = dec_out * stdev[:, 0, :].unsqueeze(1) + means[:, 0, :].unsqueeze(1)
|
||||||
|
|
||||||
|
return dec_out
|
||||||
|
|
||||||
|
# 与 TimesNet 的 forward 签名保持一致;忽略 x_mark_enc / x_dec / x_mark_dec / mask
|
||||||
|
def forward(self, x_enc, x_mark_enc=None, x_dec=None, x_mark_dec=None, mask=None):
|
||||||
|
if self.task_name in ['long_term_forecast', 'short_term_forecast']:
|
||||||
|
dec_out = self.forecast(x_enc) # [B, L+pred_len, D]
|
||||||
|
return dec_out[:, -self.pred_len:, :] # 仅返回预测部分 [B, pred_len, D]
|
||||||
|
elif self.task_name == 'classification':
|
||||||
|
return self.classification(x_enc)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f"Unsupported task: {self.task_name}")
|
6
run.py
6
run.py
@ -7,6 +7,7 @@ from exp.exp_imputation import Exp_Imputation
|
|||||||
from exp.exp_short_term_forecasting import Exp_Short_Term_Forecast
|
from exp.exp_short_term_forecasting import Exp_Short_Term_Forecast
|
||||||
from exp.exp_anomaly_detection import Exp_Anomaly_Detection
|
from exp.exp_anomaly_detection import Exp_Anomaly_Detection
|
||||||
from exp.exp_classification import Exp_Classification
|
from exp.exp_classification import Exp_Classification
|
||||||
|
from exp.exp_dc_patchtst_classification import Exp_DC_PatchTST_Classification
|
||||||
from utils.print_args import print_args
|
from utils.print_args import print_args
|
||||||
import random
|
import random
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -191,7 +192,10 @@ if __name__ == '__main__':
|
|||||||
elif args.task_name == 'anomaly_detection':
|
elif args.task_name == 'anomaly_detection':
|
||||||
Exp = Exp_Anomaly_Detection
|
Exp = Exp_Anomaly_Detection
|
||||||
elif args.task_name == 'classification':
|
elif args.task_name == 'classification':
|
||||||
Exp = Exp_Classification
|
if args.model == 'DC_PatchTST':
|
||||||
|
Exp = Exp_DC_PatchTST_Classification
|
||||||
|
else:
|
||||||
|
Exp = Exp_Classification
|
||||||
else:
|
else:
|
||||||
Exp = Exp_Long_Term_Forecast
|
Exp = Exp_Long_Term_Forecast
|
||||||
|
|
||||||
|
142
scripts/classification/DC_PatchTST.sh
Executable file
142
scripts/classification/DC_PatchTST.sh
Executable file
@ -0,0 +1,142 @@
|
|||||||
|
export CUDA_VISIBLE_DEVICES=0
|
||||||
|
|
||||||
|
model_name=DC_PatchTST
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# DC_PatchTST specific parameters
|
||||||
|
d_model_stage0=64 # Stage 0 dimension (D0)
|
||||||
|
depth_enc0=1 # Stage 0 Mamba2 encoder depth
|
||||||
|
depth_enc1=1 # Stage 1 Mamba2 encoder depth
|
||||||
|
target_ratio0=0.25 # Target compression ratio for stage 0
|
||||||
|
target_ratio1=0.25 # Target compression ratio for stage 1
|
||||||
|
|
||||||
|
# EthanolConcentration dataset
|
||||||
|
python -u run.py \
|
||||||
|
--task_name classification \
|
||||||
|
--is_training 1 \
|
||||||
|
--root_path ./dataset/EthanolConcentration/ \
|
||||||
|
--model_id EthanolConcentration \
|
||||||
|
--model $model_name \
|
||||||
|
--data UEA \
|
||||||
|
--e_layers 3 \
|
||||||
|
--batch_size 8 \
|
||||||
|
--d_model 128 \
|
||||||
|
--d_ff 256 \
|
||||||
|
--n_heads 8 \
|
||||||
|
--dropout 0.1 \
|
||||||
|
--activation gelu \
|
||||||
|
--des 'DC_PatchTST_Exp' \
|
||||||
|
--itr 1 \
|
||||||
|
--learning_rate 0.0002 \
|
||||||
|
--train_epochs 100 \
|
||||||
|
--patience 10 \
|
||||||
|
--d_model_stage0 $d_model_stage0 \
|
||||||
|
--depth_enc0 $depth_enc0 \
|
||||||
|
--depth_enc1 $depth_enc1 \
|
||||||
|
--target_ratio0 $target_ratio0 \
|
||||||
|
--target_ratio1 $target_ratio1
|
||||||
|
|
||||||
|
# FaceDetection dataset
|
||||||
|
python -u run.py \
|
||||||
|
--task_name classification \
|
||||||
|
--is_training 1 \
|
||||||
|
--root_path ./dataset/FaceDetection/ \
|
||||||
|
--model_id FaceDetection \
|
||||||
|
--model $model_name \
|
||||||
|
--data UEA \
|
||||||
|
--e_layers 3 \
|
||||||
|
--batch_size 8 \
|
||||||
|
--d_model 128 \
|
||||||
|
--d_ff 256 \
|
||||||
|
--n_heads 8 \
|
||||||
|
--dropout 0.1 \
|
||||||
|
--activation gelu \
|
||||||
|
--des 'DC_PatchTST_Exp' \
|
||||||
|
--itr 1 \
|
||||||
|
--learning_rate 0.001 \
|
||||||
|
--train_epochs 100 \
|
||||||
|
--patience 10 \
|
||||||
|
--d_model_stage0 $d_model_stage0 \
|
||||||
|
--depth_enc0 $depth_enc0 \
|
||||||
|
--depth_enc1 $depth_enc1 \
|
||||||
|
--target_ratio0 $target_ratio0 \
|
||||||
|
--target_ratio1 $target_ratio1
|
||||||
|
|
||||||
|
# Handwriting dataset
|
||||||
|
python -u run.py \
|
||||||
|
--task_name classification \
|
||||||
|
--is_training 1 \
|
||||||
|
--root_path ./dataset/Handwriting/ \
|
||||||
|
--model_id Handwriting \
|
||||||
|
--model $model_name \
|
||||||
|
--data UEA \
|
||||||
|
--e_layers 3 \
|
||||||
|
--batch_size 8 \
|
||||||
|
--d_model 128 \
|
||||||
|
--d_ff 256 \
|
||||||
|
--n_heads 8 \
|
||||||
|
--dropout 0.1 \
|
||||||
|
--activation gelu \
|
||||||
|
--des 'DC_PatchTST_Exp' \
|
||||||
|
--itr 1 \
|
||||||
|
--learning_rate 0.001 \
|
||||||
|
--train_epochs 100 \
|
||||||
|
--patience 10 \
|
||||||
|
--d_model_stage0 $d_model_stage0 \
|
||||||
|
--depth_enc0 $depth_enc0 \
|
||||||
|
--depth_enc1 $depth_enc1 \
|
||||||
|
--target_ratio0 $target_ratio0 \
|
||||||
|
--target_ratio1 $target_ratio1
|
||||||
|
|
||||||
|
# Heartbeat dataset
|
||||||
|
python -u run.py \
|
||||||
|
--task_name classification \
|
||||||
|
--is_training 1 \
|
||||||
|
--root_path ./dataset/Heartbeat/ \
|
||||||
|
--model_id Heartbeat \
|
||||||
|
--model $model_name \
|
||||||
|
--data UEA \
|
||||||
|
--e_layers 3 \
|
||||||
|
--batch_size 8 \
|
||||||
|
--d_model 128 \
|
||||||
|
--d_ff 256 \
|
||||||
|
--n_heads 8 \
|
||||||
|
--dropout 0.1 \
|
||||||
|
--activation gelu \
|
||||||
|
--des 'DC_PatchTST_Exp' \
|
||||||
|
--itr 1 \
|
||||||
|
--learning_rate 0.001 \
|
||||||
|
--train_epochs 100 \
|
||||||
|
--patience 10 \
|
||||||
|
--d_model_stage0 $d_model_stage0 \
|
||||||
|
--depth_enc0 $depth_enc0 \
|
||||||
|
--depth_enc1 $depth_enc1 \
|
||||||
|
--target_ratio0 $target_ratio0 \
|
||||||
|
--target_ratio1 $target_ratio1
|
||||||
|
|
||||||
|
# JapaneseVowels dataset
|
||||||
|
python -u run.py \
|
||||||
|
--task_name classification \
|
||||||
|
--is_training 1 \
|
||||||
|
--root_path ./dataset/JapaneseVowels/ \
|
||||||
|
--model_id JapaneseVowels \
|
||||||
|
--model $model_name \
|
||||||
|
--data UEA \
|
||||||
|
--e_layers 3 \
|
||||||
|
--batch_size 8 \
|
||||||
|
--d_model 128 \
|
||||||
|
--d_ff 256 \
|
||||||
|
--n_heads 8 \
|
||||||
|
--dropout 0.1 \
|
||||||
|
--activation gelu \
|
||||||
|
--des 'DC_PatchTST_Exp' \
|
||||||
|
--itr 1 \
|
||||||
|
--learning_rate 0.001 \
|
||||||
|
--train_epochs 100 \
|
||||||
|
--patience 10 \
|
||||||
|
--d_model_stage0 $d_model_stage0 \
|
||||||
|
--depth_enc0 $depth_enc0 \
|
||||||
|
--depth_enc1 $depth_enc1 \
|
||||||
|
--target_ratio0 $target_ratio0 \
|
||||||
|
--target_ratio1 $target_ratio1
|
259
scripts/classification/vanillaMamba_classification.sh
Normal file
259
scripts/classification/vanillaMamba_classification.sh
Normal file
@ -0,0 +1,259 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
# vanillaMamba Classification Training Script for Multiple Datasets
|
||||||
|
export CUDA_VISIBLE_DEVICES=0
|
||||||
|
|
||||||
|
model_name=vanillaMamba
|
||||||
|
|
||||||
|
# Create results directory if it doesn't exist
|
||||||
|
mkdir -p ./results
|
||||||
|
|
||||||
|
# UWaveGestureLibrary dataset (seq_len=315, enc_in=3) - use Copy1 config
|
||||||
|
python -u run.py \
|
||||||
|
--task_name classification \
|
||||||
|
--is_training 1 \
|
||||||
|
--root_path ./dataset/UWaveGestureLibrary/ \
|
||||||
|
--model_id UWaveGestureLibrary \
|
||||||
|
--model $model_name \
|
||||||
|
--data UEA \
|
||||||
|
--e_layers 2 \
|
||||||
|
--batch_size 64 \
|
||||||
|
--seq_len 315 \
|
||||||
|
--enc_in 3 \
|
||||||
|
--d_model 128 \
|
||||||
|
--d_state 64 \
|
||||||
|
--d_conv 4 \
|
||||||
|
--expand 2 \
|
||||||
|
--headdim 128 \
|
||||||
|
--dropout 0.1 \
|
||||||
|
--des 'vanillaMamba_UWaveGestureLibrary' \
|
||||||
|
--itr 1 \
|
||||||
|
--learning_rate 0.002 \
|
||||||
|
--train_epochs 150 \
|
||||||
|
--patience 30 \
|
||||||
|
--revin 0 | tee ./results/vanillaMamba_UWaveGestureLibrary.log
|
||||||
|
|
||||||
|
# EthanolConcentration dataset (seq_len=1751, enc_in=3) - use Copy1 config
|
||||||
|
python -u run.py \
|
||||||
|
--task_name classification \
|
||||||
|
--is_training 3 \
|
||||||
|
--root_path ./dataset/EthanolConcentration/ \
|
||||||
|
--model_id EthanolConcentration \
|
||||||
|
--model $model_name \
|
||||||
|
--data UEA \
|
||||||
|
--e_layers 2 \
|
||||||
|
--batch_size 64 \
|
||||||
|
--seq_len 1751 \
|
||||||
|
--enc_in 4 \
|
||||||
|
--d_model 128 \
|
||||||
|
--d_state 64 \
|
||||||
|
--d_conv 4 \
|
||||||
|
--expand 2 \
|
||||||
|
--headdim 64 \
|
||||||
|
--dropout 0.1 \
|
||||||
|
--des 'vanillaMamba_EthanolConcentration' \
|
||||||
|
--itr 1 \
|
||||||
|
--learning_rate 0.001 \
|
||||||
|
--train_epochs 200 \
|
||||||
|
--patience 30 \
|
||||||
|
--revin 0 | tee ./results/vanillaMamba_EthanolConcentration.log
|
||||||
|
|
||||||
|
# Handwriting dataset (seq_len=152, enc_in=3) - use Copy1 config
|
||||||
|
python -u run.py \
|
||||||
|
--task_name classification \
|
||||||
|
--is_training 1 \
|
||||||
|
--root_path ./dataset/Handwriting/ \
|
||||||
|
--model_id Handwriting \
|
||||||
|
--model $model_name \
|
||||||
|
--data UEA \
|
||||||
|
--e_layers 4 \
|
||||||
|
--batch_size 64 \
|
||||||
|
--seq_len 152 \
|
||||||
|
--enc_in 3 \
|
||||||
|
--d_model 128 \
|
||||||
|
--d_state 64 \
|
||||||
|
--d_conv 4 \
|
||||||
|
--expand 2 \
|
||||||
|
--headdim 64 \
|
||||||
|
--dropout 0.1 \
|
||||||
|
--des 'vanillaMamba_Handwriting' \
|
||||||
|
--itr 1 \
|
||||||
|
--learning_rate 0.001 \
|
||||||
|
--train_epochs 200 \
|
||||||
|
--patience 30 \
|
||||||
|
--revin 0 | tee ./results/vanillaMamba_Handwriting.log
|
||||||
|
|
||||||
|
# JapaneseVowels dataset (seq_len=29, enc_in=12) - use Copy1 config
|
||||||
|
python -u run.py \
|
||||||
|
--task_name classification \
|
||||||
|
--is_training 1 \
|
||||||
|
--root_path ./dataset/JapaneseVowels/ \
|
||||||
|
--model_id JapaneseVowels \
|
||||||
|
--model $model_name \
|
||||||
|
--data UEA \
|
||||||
|
--e_layers 3 \
|
||||||
|
--batch_size 64 \
|
||||||
|
--seq_len 29 \
|
||||||
|
--enc_in 12 \
|
||||||
|
--d_model 128 \
|
||||||
|
--d_state 64 \
|
||||||
|
--d_conv 4 \
|
||||||
|
--expand 2 \
|
||||||
|
--headdim 64 \
|
||||||
|
--dropout 0.1 \
|
||||||
|
--des 'vanillaMamba_JapaneseVowels' \
|
||||||
|
--itr 1 \
|
||||||
|
--learning_rate 0.0005 \
|
||||||
|
--train_epochs 100 \
|
||||||
|
--patience 30 \
|
||||||
|
--revin 0 | tee ./results/vanillaMamba_JapaneseVowels.log
|
||||||
|
|
||||||
|
# PEMS-SF dataset (seq_len=144, enc_in=963) - use Copy1 config
|
||||||
|
python -u run.py \
|
||||||
|
--task_name classification \
|
||||||
|
--is_training 1 \
|
||||||
|
--root_path ./dataset/PEMS-SF/ \
|
||||||
|
--model_id PEMS-SF \
|
||||||
|
--model $model_name \
|
||||||
|
--data UEA \
|
||||||
|
--e_layers 3 \
|
||||||
|
--batch_size 16 \
|
||||||
|
--seq_len 144 \
|
||||||
|
--enc_in 963 \
|
||||||
|
--d_model 128 \
|
||||||
|
--d_state 64 \
|
||||||
|
--d_conv 4 \
|
||||||
|
--expand 2 \
|
||||||
|
--headdim 64 \
|
||||||
|
--dropout 0.1 \
|
||||||
|
--des 'vanillaMamba_PEMS-SF' \
|
||||||
|
--itr 1 \
|
||||||
|
--learning_rate 0.001 \
|
||||||
|
--train_epochs 150 \
|
||||||
|
--patience 30 \
|
||||||
|
--revin 0 | tee ./results/vanillaMamba_PEMS-SF.log
|
||||||
|
|
||||||
|
# Heartbeat dataset (seq_len=405, enc_in=61) - use original config
|
||||||
|
python -u run.py \
|
||||||
|
--task_name classification \
|
||||||
|
--is_training 1 \
|
||||||
|
--root_path ./dataset/Heartbeat/ \
|
||||||
|
--model_id Heartbeat \
|
||||||
|
--model $model_name \
|
||||||
|
--data UEA \
|
||||||
|
--e_layers 3 \
|
||||||
|
--batch_size 64 \
|
||||||
|
--seq_len 405 \
|
||||||
|
--enc_in 61 \
|
||||||
|
--d_model 128 \
|
||||||
|
--d_state 64 \
|
||||||
|
--d_conv 4 \
|
||||||
|
--expand 2 \
|
||||||
|
--headdim 64 \
|
||||||
|
--dropout 0.1 \
|
||||||
|
--des 'vanillaMamba_Heartbeat' \
|
||||||
|
--itr 1 \
|
||||||
|
--learning_rate 0.001 \
|
||||||
|
--train_epochs 150 \
|
||||||
|
--patience 10 \
|
||||||
|
--revin 0 | tee ./results/vanillaMamba_Heartbeat.log
|
||||||
|
|
||||||
|
# FaceDetection dataset (seq_len=62, enc_in=144) - use original config
|
||||||
|
python -u run.py \
|
||||||
|
--task_name classification \
|
||||||
|
--is_training 1 \
|
||||||
|
--root_path ./dataset/FaceDetection/ \
|
||||||
|
--model_id FaceDetection \
|
||||||
|
--model $model_name \
|
||||||
|
--data UEA \
|
||||||
|
--e_layers 3 \
|
||||||
|
--batch_size 64 \
|
||||||
|
--seq_len 62 \
|
||||||
|
--enc_in 144 \
|
||||||
|
--d_model 128 \
|
||||||
|
--d_state 64 \
|
||||||
|
--d_conv 4 \
|
||||||
|
--expand 2 \
|
||||||
|
--headdim 64 \
|
||||||
|
--dropout 0.1 \
|
||||||
|
--des 'vanillaMamba_FaceDetection' \
|
||||||
|
--itr 1 \
|
||||||
|
--learning_rate 0.001 \
|
||||||
|
--train_epochs 100 \
|
||||||
|
--patience 10 \
|
||||||
|
--revin 0 | tee ./results/vanillaMamba_FaceDetection.log
|
||||||
|
|
||||||
|
# SelfRegulationSCP1 dataset (seq_len=896, enc_in=6) - use original config
|
||||||
|
python -u run.py \
|
||||||
|
--task_name classification \
|
||||||
|
--is_training 1 \
|
||||||
|
--root_path ./dataset/SelfRegulationSCP1/ \
|
||||||
|
--model_id SelfRegulationSCP1 \
|
||||||
|
--model $model_name \
|
||||||
|
--data UEA \
|
||||||
|
--e_layers 3 \
|
||||||
|
--batch_size 64 \
|
||||||
|
--seq_len 896 \
|
||||||
|
--enc_in 6 \
|
||||||
|
--d_model 128 \
|
||||||
|
--d_state 64 \
|
||||||
|
--d_conv 4 \
|
||||||
|
--expand 2 \
|
||||||
|
--headdim 64 \
|
||||||
|
--dropout 0.1 \
|
||||||
|
--des 'vanillaMamba_SelfRegulationSCP1' \
|
||||||
|
--itr 1 \
|
||||||
|
--learning_rate 0.001 \
|
||||||
|
--train_epochs 100 \
|
||||||
|
--patience 10 \
|
||||||
|
--revin 0 | tee ./results/vanillaMamba_SelfRegulationSCP1.log
|
||||||
|
|
||||||
|
# SelfRegulationSCP2 dataset (seq_len=1152, enc_in=7) - use original config
|
||||||
|
python -u run.py \
|
||||||
|
--task_name classification \
|
||||||
|
--is_training 1 \
|
||||||
|
--root_path ./dataset/SelfRegulationSCP2/ \
|
||||||
|
--model_id SelfRegulationSCP2 \
|
||||||
|
--model $model_name \
|
||||||
|
--data UEA \
|
||||||
|
--e_layers 3 \
|
||||||
|
--batch_size 64 \
|
||||||
|
--seq_len 1152 \
|
||||||
|
--enc_in 7 \
|
||||||
|
--d_model 128 \
|
||||||
|
--d_state 64 \
|
||||||
|
--d_conv 4 \
|
||||||
|
--expand 2 \
|
||||||
|
--headdim 64 \
|
||||||
|
--dropout 0.1 \
|
||||||
|
--des 'vanillaMamba_SelfRegulationSCP2' \
|
||||||
|
--itr 1 \
|
||||||
|
--learning_rate 0.001 \
|
||||||
|
--train_epochs 100 \
|
||||||
|
--patience 10 \
|
||||||
|
--revin 0 | tee ./results/vanillaMamba_SelfRegulationSCP2.log
|
||||||
|
|
||||||
|
# SpokenArabicDigits dataset (seq_len=93, enc_in=13) - use original config
|
||||||
|
python -u run.py \
|
||||||
|
--task_name classification \
|
||||||
|
--is_training 1 \
|
||||||
|
--root_path ./dataset/SpokenArabicDigits/ \
|
||||||
|
--model_id SpokenArabicDigits \
|
||||||
|
--model $model_name \
|
||||||
|
--data UEA \
|
||||||
|
--e_layers 3 \
|
||||||
|
--batch_size 64 \
|
||||||
|
--seq_len 93 \
|
||||||
|
--enc_in 13 \
|
||||||
|
--d_model 128 \
|
||||||
|
--d_state 64 \
|
||||||
|
--d_conv 4 \
|
||||||
|
--expand 2 \
|
||||||
|
--headdim 64 \
|
||||||
|
--dropout 0.1 \
|
||||||
|
--des 'vanillaMamba_SpokenArabicDigits' \
|
||||||
|
--itr 1 \
|
||||||
|
--learning_rate 0.001 \
|
||||||
|
--train_epochs 100 \
|
||||||
|
--patience 10 \
|
||||||
|
--revin 0 | tee ./results/vanillaMamba_SpokenArabicDigits.log
|
145
scripts/classification/xPatch_SparseChannel-Copy1.sh
Normal file
145
scripts/classification/xPatch_SparseChannel-Copy1.sh
Normal file
@ -0,0 +1,145 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
# xPatch_SparseChannel Classification Training Script for Multiple Datasets
|
||||||
|
export CUDA_VISIBLE_DEVICES=0
|
||||||
|
|
||||||
|
model_name=xPatch_SparseChannel
|
||||||
|
|
||||||
|
# Create results directory if it doesn't exist
|
||||||
|
mkdir -p ./results
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# UWaveGestureLibrary dataset (seq_len=315, enc_in=3, k_graph=3)
|
||||||
|
python -u run.py \
|
||||||
|
--task_name classification \
|
||||||
|
--is_training 1 \
|
||||||
|
--root_path ./dataset/UWaveGestureLibrary/ \
|
||||||
|
--model_id UWaveGestureLibrary \
|
||||||
|
--model $model_name \
|
||||||
|
--data UEA \
|
||||||
|
--e_layers 2 \
|
||||||
|
--batch_size 64 \
|
||||||
|
--seq_len 315 \
|
||||||
|
--enc_in 3 \
|
||||||
|
--d_model 128 \
|
||||||
|
--d_ff 256 \
|
||||||
|
--n_heads 16 \
|
||||||
|
--patch_len 16 \
|
||||||
|
--stride 8 \
|
||||||
|
--dropout 0.1 \
|
||||||
|
--des 'xPatch_SparseChannel_UWaveGestureLibrary' \
|
||||||
|
--itr 1 \
|
||||||
|
--learning_rate 0.001 \
|
||||||
|
--train_epochs 100 \
|
||||||
|
--patience 30 \
|
||||||
|
--revin 0 \
|
||||||
|
--k_graph 3 | tee ./results/xPatch_SparseChannel_UWaveGestureLibrary.log
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# EthanolConcentration dataset (seq_len=1751, enc_in=3, k_graph=3)
|
||||||
|
python -u run.py \
|
||||||
|
--task_name classification \
|
||||||
|
--is_training 1 \
|
||||||
|
--root_path ./dataset/EthanolConcentration/ \
|
||||||
|
--model_id EthanolConcentration \
|
||||||
|
--model $model_name \
|
||||||
|
--data UEA \
|
||||||
|
--e_layers 2 \
|
||||||
|
--batch_size 64 \
|
||||||
|
--seq_len 1751 \
|
||||||
|
--enc_in 3 \
|
||||||
|
--d_model 128 \
|
||||||
|
--d_ff 256 \
|
||||||
|
--n_heads 16 \
|
||||||
|
--patch_len 16 \
|
||||||
|
--stride 8 \
|
||||||
|
--dropout 0.1 \
|
||||||
|
--des 'xPatch_SparseChannel_EthanolConcentration' \
|
||||||
|
--itr 1 \
|
||||||
|
--learning_rate 0.0005 \
|
||||||
|
--train_epochs 100 \
|
||||||
|
--patience 30 \
|
||||||
|
--revin 0 \
|
||||||
|
--k_graph 3 | tee ./results/xPatch_SparseChannel_EthanolConcentration.log
|
||||||
|
|
||||||
|
# Handwriting dataset (seq_len=152, enc_in=3, k_graph=3)
|
||||||
|
python -u run.py \
|
||||||
|
--task_name classification \
|
||||||
|
--is_training 1 \
|
||||||
|
--root_path ./dataset/Handwriting/ \
|
||||||
|
--model_id Handwriting \
|
||||||
|
--model $model_name \
|
||||||
|
--data UEA \
|
||||||
|
--e_layers 2 \
|
||||||
|
--batch_size 64 \
|
||||||
|
--seq_len 152 \
|
||||||
|
--enc_in 3 \
|
||||||
|
--d_model 128 \
|
||||||
|
--d_ff 256 \
|
||||||
|
--n_heads 16 \
|
||||||
|
--patch_len 16 \
|
||||||
|
--stride 8 \
|
||||||
|
--dropout 0.1 \
|
||||||
|
--des 'xPatch_SparseChannel_Handwriting' \
|
||||||
|
--itr 1 \
|
||||||
|
--learning_rate 0.001 \
|
||||||
|
--train_epochs 100 \
|
||||||
|
--patience 30 \
|
||||||
|
--revin 0 \
|
||||||
|
--k_graph 3 | tee ./results/xPatch_SparseChannel_Handwriting.log
|
||||||
|
|
||||||
|
# JapaneseVowels dataset (seq_len=29, enc_in=12, k_graph=8)
|
||||||
|
python -u run.py \
|
||||||
|
--task_name classification \
|
||||||
|
--is_training 1 \
|
||||||
|
--root_path ./dataset/JapaneseVowels/ \
|
||||||
|
--model_id JapaneseVowels \
|
||||||
|
--model $model_name \
|
||||||
|
--data UEA \
|
||||||
|
--e_layers 2 \
|
||||||
|
--batch_size 64 \
|
||||||
|
--seq_len 29 \
|
||||||
|
--enc_in 12 \
|
||||||
|
--d_model 128 \
|
||||||
|
--d_ff 256 \
|
||||||
|
--n_heads 16 \
|
||||||
|
--patch_len 16 \
|
||||||
|
--stride 8 \
|
||||||
|
--dropout 0.1 \
|
||||||
|
--des 'xPatch_SparseChannel_JapaneseVowels' \
|
||||||
|
--itr 1 \
|
||||||
|
--learning_rate 0.0005 \
|
||||||
|
--train_epochs 100 \
|
||||||
|
--patience 30 \
|
||||||
|
--revin 0 \
|
||||||
|
--k_graph 8 | tee ./results/xPatch_SparseChannel_JapaneseVowels.log
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# PEMS-SF dataset (seq_len=144, enc_in=963, k_graph=8)
|
||||||
|
python -u run.py \
|
||||||
|
--task_name classification \
|
||||||
|
--is_training 1 \
|
||||||
|
--root_path ./dataset/PEMS-SF/ \
|
||||||
|
--model_id PEMS-SF \
|
||||||
|
--model $model_name \
|
||||||
|
--data UEA \
|
||||||
|
--e_layers 2 \
|
||||||
|
--batch_size 16 \
|
||||||
|
--seq_len 144 \
|
||||||
|
--enc_in 963 \
|
||||||
|
--d_model 128 \
|
||||||
|
--d_ff 256 \
|
||||||
|
--n_heads 16 \
|
||||||
|
--patch_len 16 \
|
||||||
|
--stride 8 \
|
||||||
|
--dropout 0.1 \
|
||||||
|
--des 'xPatch_SparseChannel_PEMS-SF' \
|
||||||
|
--itr 1 \
|
||||||
|
--learning_rate 0.0005 \
|
||||||
|
--train_epochs 100 \
|
||||||
|
--patience 30 \
|
||||||
|
--revin 0 \
|
||||||
|
--k_graph 8 | tee ./results/xPatch_SparseChannel_PEMS-SF.log
|
@ -1,10 +1,66 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
|
|
||||||
# xPatch_SparseChannel Classification Training Script for FaceDetection Dataset
|
# xPatch_SparseChannel Classification Training Script for Multiple Datasets
|
||||||
export CUDA_VISIBLE_DEVICES=0
|
export CUDA_VISIBLE_DEVICES=0
|
||||||
|
|
||||||
model_name=xPatch_SparseChannel
|
model_name=xPatch_SparseChannel
|
||||||
|
|
||||||
|
# Create results directory if it doesn't exist
|
||||||
|
mkdir -p ./results
|
||||||
|
|
||||||
|
# Heartbeat dataset (seq_len=405, enc_in=61, k_graph=8)
|
||||||
|
python -u run.py \
|
||||||
|
--task_name classification \
|
||||||
|
--is_training 1 \
|
||||||
|
--root_path ./dataset/Heartbeat/ \
|
||||||
|
--model_id Heartbeat \
|
||||||
|
--model $model_name \
|
||||||
|
--data UEA \
|
||||||
|
--e_layers 2 \
|
||||||
|
--batch_size 64 \
|
||||||
|
--seq_len 405 \
|
||||||
|
--enc_in 61 \
|
||||||
|
--d_model 128 \
|
||||||
|
--d_ff 256 \
|
||||||
|
--n_heads 16 \
|
||||||
|
--patch_len 16 \
|
||||||
|
--stride 8 \
|
||||||
|
--dropout 0.1 \
|
||||||
|
--des 'xPatch_SparseChannel_Heartbeat' \
|
||||||
|
--itr 1 \
|
||||||
|
--learning_rate 0.0005 \
|
||||||
|
--train_epochs 100 \
|
||||||
|
--patience 5 \
|
||||||
|
--revin 0 \
|
||||||
|
--k_graph 8 | tee ./results/xPatch_SparseChannel_Heartbeat.log
|
||||||
|
|
||||||
|
# UWaveGestureLibrary dataset (seq_len=315, enc_in=3, k_graph=3)
|
||||||
|
python -u run.py \
|
||||||
|
--task_name classification \
|
||||||
|
--is_training 1 \
|
||||||
|
--root_path ./dataset/UWaveGestureLibrary/ \
|
||||||
|
--model_id UWaveGestureLibrary \
|
||||||
|
--model $model_name \
|
||||||
|
--data UEA \
|
||||||
|
--e_layers 2 \
|
||||||
|
--batch_size 64 \
|
||||||
|
--seq_len 315 \
|
||||||
|
--enc_in 3 \
|
||||||
|
--d_model 128 \
|
||||||
|
--d_ff 256 \
|
||||||
|
--n_heads 16 \
|
||||||
|
--patch_len 16 \
|
||||||
|
--stride 8 \
|
||||||
|
--dropout 0.1 \
|
||||||
|
--des 'xPatch_SparseChannel_UWaveGestureLibrary' \
|
||||||
|
--itr 1 \
|
||||||
|
--learning_rate 0.001 \
|
||||||
|
--train_epochs 100 \
|
||||||
|
--patience 30 \
|
||||||
|
--revin 0 \
|
||||||
|
--k_graph 3 | tee ./results/xPatch_SparseChannel_UWaveGestureLibrary.log
|
||||||
|
|
||||||
|
# FaceDetection dataset (seq_len=62, enc_in=144, k_graph=8)
|
||||||
python -u run.py \
|
python -u run.py \
|
||||||
--task_name classification \
|
--task_name classification \
|
||||||
--is_training 1 \
|
--is_training 1 \
|
||||||
@ -12,21 +68,202 @@ python -u run.py \
|
|||||||
--model_id FaceDetection \
|
--model_id FaceDetection \
|
||||||
--model $model_name \
|
--model $model_name \
|
||||||
--data UEA \
|
--data UEA \
|
||||||
--e_layers 3 \
|
--e_layers 2 \
|
||||||
--batch_size 64 \
|
--batch_size 64 \
|
||||||
--seq_len 62 \
|
--seq_len 62 \
|
||||||
--enc_in 144 \
|
--enc_in 144 \
|
||||||
--d_model 128 \
|
--d_model 128 \
|
||||||
--d_ff 256 \
|
--d_ff 256 \
|
||||||
--n_heads 8 \
|
--n_heads 16 \
|
||||||
--patch_len 16 \
|
--patch_len 16 \
|
||||||
--stride 8 \
|
--stride 8 \
|
||||||
--moving_avg 25 \
|
|
||||||
--dropout 0.1 \
|
--dropout 0.1 \
|
||||||
--des 'xPatch_SparseChannel_FaceDetection' \
|
--des 'xPatch_SparseChannel_FaceDetection' \
|
||||||
--itr 1 \
|
--itr 1 \
|
||||||
--learning_rate 0.0005 \
|
--learning_rate 0.0005 \
|
||||||
--train_epochs 100 \
|
--train_epochs 100 \
|
||||||
--patience 5 \
|
--patience 5 \
|
||||||
--revin 1 \
|
--revin 0 \
|
||||||
--k_graph 8
|
--k_graph 8 | tee ./results/xPatch_SparseChannel_FaceDetection.log
|
||||||
|
|
||||||
|
# EthanolConcentration dataset (seq_len=1751, enc_in=3, k_graph=3)
|
||||||
|
python -u run.py \
|
||||||
|
--task_name classification \
|
||||||
|
--is_training 1 \
|
||||||
|
--root_path ./dataset/EthanolConcentration/ \
|
||||||
|
--model_id EthanolConcentration \
|
||||||
|
--model $model_name \
|
||||||
|
--data UEA \
|
||||||
|
--e_layers 2 \
|
||||||
|
--batch_size 64 \
|
||||||
|
--seq_len 1751 \
|
||||||
|
--enc_in 3 \
|
||||||
|
--d_model 128 \
|
||||||
|
--d_ff 256 \
|
||||||
|
--n_heads 16 \
|
||||||
|
--patch_len 16 \
|
||||||
|
--stride 8 \
|
||||||
|
--dropout 0.1 \
|
||||||
|
--des 'xPatch_SparseChannel_EthanolConcentration' \
|
||||||
|
--itr 1 \
|
||||||
|
--learning_rate 0.0005 \
|
||||||
|
--train_epochs 100 \
|
||||||
|
--patience 30 \
|
||||||
|
--revin 0 \
|
||||||
|
--k_graph 3 | tee ./results/xPatch_SparseChannel_EthanolConcentration.log
|
||||||
|
|
||||||
|
# Handwriting dataset (seq_len=152, enc_in=3, k_graph=3)
|
||||||
|
python -u run.py \
|
||||||
|
--task_name classification \
|
||||||
|
--is_training 1 \
|
||||||
|
--root_path ./dataset/Handwriting/ \
|
||||||
|
--model_id Handwriting \
|
||||||
|
--model $model_name \
|
||||||
|
--data UEA \
|
||||||
|
--e_layers 2 \
|
||||||
|
--batch_size 64 \
|
||||||
|
--seq_len 152 \
|
||||||
|
--enc_in 3 \
|
||||||
|
--d_model 128 \
|
||||||
|
--d_ff 256 \
|
||||||
|
--n_heads 16 \
|
||||||
|
--patch_len 16 \
|
||||||
|
--stride 8 \
|
||||||
|
--dropout 0.1 \
|
||||||
|
--des 'xPatch_SparseChannel_Handwriting' \
|
||||||
|
--itr 1 \
|
||||||
|
--learning_rate 0.001 \
|
||||||
|
--train_epochs 100 \
|
||||||
|
--patience 30 \
|
||||||
|
--revin 0 \
|
||||||
|
--k_graph 3 | tee ./results/xPatch_SparseChannel_Handwriting.log
|
||||||
|
|
||||||
|
# JapaneseVowels dataset (seq_len=29, enc_in=12, k_graph=8)
|
||||||
|
python -u run.py \
|
||||||
|
--task_name classification \
|
||||||
|
--is_training 1 \
|
||||||
|
--root_path ./dataset/JapaneseVowels/ \
|
||||||
|
--model_id JapaneseVowels \
|
||||||
|
--model $model_name \
|
||||||
|
--data UEA \
|
||||||
|
--e_layers 2 \
|
||||||
|
--batch_size 64 \
|
||||||
|
--seq_len 29 \
|
||||||
|
--enc_in 12 \
|
||||||
|
--d_model 128 \
|
||||||
|
--d_ff 256 \
|
||||||
|
--n_heads 16 \
|
||||||
|
--patch_len 16 \
|
||||||
|
--stride 8 \
|
||||||
|
--dropout 0.1 \
|
||||||
|
--des 'xPatch_SparseChannel_JapaneseVowels' \
|
||||||
|
--itr 1 \
|
||||||
|
--learning_rate 0.0005 \
|
||||||
|
--train_epochs 100 \
|
||||||
|
--patience 30 \
|
||||||
|
--revin 0 \
|
||||||
|
--k_graph 8 | tee ./results/xPatch_SparseChannel_JapaneseVowels.log
|
||||||
|
|
||||||
|
# SelfRegulationSCP1 dataset (seq_len=896, enc_in=6, k_graph=6)
|
||||||
|
python -u run.py \
|
||||||
|
--task_name classification \
|
||||||
|
--is_training 1 \
|
||||||
|
--root_path ./dataset/SelfRegulationSCP1/ \
|
||||||
|
--model_id SelfRegulationSCP1 \
|
||||||
|
--model $model_name \
|
||||||
|
--data UEA \
|
||||||
|
--e_layers 2 \
|
||||||
|
--batch_size 64 \
|
||||||
|
--seq_len 896 \
|
||||||
|
--enc_in 6 \
|
||||||
|
--d_model 128 \
|
||||||
|
--d_ff 256 \
|
||||||
|
--n_heads 16 \
|
||||||
|
--patch_len 16 \
|
||||||
|
--stride 8 \
|
||||||
|
--dropout 0.1 \
|
||||||
|
--des 'xPatch_SparseChannel_SelfRegulationSCP1' \
|
||||||
|
--itr 1 \
|
||||||
|
--learning_rate 0.0005 \
|
||||||
|
--train_epochs 100 \
|
||||||
|
--patience 5 \
|
||||||
|
--revin 0 \
|
||||||
|
--k_graph 6 | tee ./results/xPatch_SparseChannel_SelfRegulationSCP1.log
|
||||||
|
|
||||||
|
# SelfRegulationSCP2 dataset (seq_len=1152, enc_in=7, k_graph=7)
|
||||||
|
python -u run.py \
|
||||||
|
--task_name classification \
|
||||||
|
--is_training 1 \
|
||||||
|
--root_path ./dataset/SelfRegulationSCP2/ \
|
||||||
|
--model_id SelfRegulationSCP2 \
|
||||||
|
--model $model_name \
|
||||||
|
--data UEA \
|
||||||
|
--e_layers 2 \
|
||||||
|
--batch_size 64 \
|
||||||
|
--seq_len 1152 \
|
||||||
|
--enc_in 7 \
|
||||||
|
--d_model 128 \
|
||||||
|
--d_ff 256 \
|
||||||
|
--n_heads 16 \
|
||||||
|
--patch_len 16 \
|
||||||
|
--stride 8 \
|
||||||
|
--dropout 0.1 \
|
||||||
|
--des 'xPatch_SparseChannel_SelfRegulationSCP2' \
|
||||||
|
--itr 1 \
|
||||||
|
--learning_rate 0.0005 \
|
||||||
|
--train_epochs 100 \
|
||||||
|
--patience 5 \
|
||||||
|
--revin 0 \
|
||||||
|
--k_graph 7 | tee ./results/xPatch_SparseChannel_SelfRegulationSCP2.log
|
||||||
|
|
||||||
|
# SpokenArabicDigits dataset (seq_len=93, enc_in=13, k_graph=8)
|
||||||
|
python -u run.py \
|
||||||
|
--task_name classification \
|
||||||
|
--is_training 1 \
|
||||||
|
--root_path ./dataset/SpokenArabicDigits/ \
|
||||||
|
--model_id SpokenArabicDigits \
|
||||||
|
--model $model_name \
|
||||||
|
--data UEA \
|
||||||
|
--e_layers 2 \
|
||||||
|
--batch_size 64 \
|
||||||
|
--seq_len 93 \
|
||||||
|
--enc_in 13 \
|
||||||
|
--d_model 128 \
|
||||||
|
--d_ff 256 \
|
||||||
|
--n_heads 16 \
|
||||||
|
--patch_len 16 \
|
||||||
|
--stride 8 \
|
||||||
|
--dropout 0.1 \
|
||||||
|
--des 'xPatch_SparseChannel_SpokenArabicDigits' \
|
||||||
|
--itr 1 \
|
||||||
|
--learning_rate 0.0005 \
|
||||||
|
--train_epochs 100 \
|
||||||
|
--patience 5 \
|
||||||
|
--revin 0 \
|
||||||
|
--k_graph 8 | tee ./results/xPatch_SparseChannel_SpokenArabicDigits.log
|
||||||
|
|
||||||
|
# PEMS-SF dataset (seq_len=144, enc_in=963, k_graph=8)
|
||||||
|
python -u run.py \
|
||||||
|
--task_name classification \
|
||||||
|
--is_training 1 \
|
||||||
|
--root_path ./dataset/PEMS-SF/ \
|
||||||
|
--model_id PEMS-SF \
|
||||||
|
--model $model_name \
|
||||||
|
--data UEA \
|
||||||
|
--e_layers 2 \
|
||||||
|
--batch_size 16 \
|
||||||
|
--seq_len 144 \
|
||||||
|
--enc_in 963 \
|
||||||
|
--d_model 128 \
|
||||||
|
--d_ff 256 \
|
||||||
|
--n_heads 16 \
|
||||||
|
--patch_len 16 \
|
||||||
|
--stride 8 \
|
||||||
|
--dropout 0.1 \
|
||||||
|
--des 'xPatch_SparseChannel_PEMS-SF' \
|
||||||
|
--itr 1 \
|
||||||
|
--learning_rate 0.0005 \
|
||||||
|
--train_epochs 100 \
|
||||||
|
--patience 30 \
|
||||||
|
--revin 0 \
|
||||||
|
--k_graph 8 | tee ./results/xPatch_SparseChannel_PEMS-SF.log
|
||||||
|
251
scripts/long_term_forecast/vanillaMamba_all.sh
Normal file
251
scripts/long_term_forecast/vanillaMamba_all.sh
Normal file
@ -0,0 +1,251 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
model_name=vanillaMamba
|
||||||
|
|
||||||
|
# ETTm1 dataset
|
||||||
|
for pred_len in 96 192 336 720
|
||||||
|
do
|
||||||
|
python -u run.py \
|
||||||
|
--task_name long_term_forecast \
|
||||||
|
--is_training 1 \
|
||||||
|
--root_path ./dataset/ETT-small/ \
|
||||||
|
--data_path ETTm1.csv \
|
||||||
|
--model_id ETTm1_$pred_len'_'$pred_len \
|
||||||
|
--model $model_name \
|
||||||
|
--data ETTm1 \
|
||||||
|
--features M \
|
||||||
|
--seq_len 96 \
|
||||||
|
--label_len 48 \
|
||||||
|
--pred_len $pred_len \
|
||||||
|
--e_layers 2 \
|
||||||
|
--d_layers 1 \
|
||||||
|
--enc_in 7 \
|
||||||
|
--c_out 7 \
|
||||||
|
--d_model 128 \
|
||||||
|
--expand 2 \
|
||||||
|
--d_conv 4 \
|
||||||
|
--d_state 64 \
|
||||||
|
--headdim 64 \
|
||||||
|
--ngroups 1 \
|
||||||
|
--chunk_size 256 \
|
||||||
|
--dropout 0.1 \
|
||||||
|
--des 'Exp' \
|
||||||
|
--itr 1
|
||||||
|
done
|
||||||
|
|
||||||
|
# ETTm2 dataset
|
||||||
|
for pred_len in 96 192 336 720
|
||||||
|
do
|
||||||
|
python -u run.py \
|
||||||
|
--task_name long_term_forecast \
|
||||||
|
--is_training 1 \
|
||||||
|
--root_path ./dataset/ETT-small/ \
|
||||||
|
--data_path ETTm2.csv \
|
||||||
|
--model_id ETTm2_$pred_len'_'$pred_len \
|
||||||
|
--model $model_name \
|
||||||
|
--data ETTm2 \
|
||||||
|
--features M \
|
||||||
|
--seq_len 96 \
|
||||||
|
--label_len 48 \
|
||||||
|
--pred_len $pred_len \
|
||||||
|
--e_layers 2 \
|
||||||
|
--d_layers 1 \
|
||||||
|
--enc_in 7 \
|
||||||
|
--c_out 7 \
|
||||||
|
--d_model 128 \
|
||||||
|
--expand 2 \
|
||||||
|
--d_conv 4 \
|
||||||
|
--d_state 64 \
|
||||||
|
--headdim 64 \
|
||||||
|
--ngroups 1 \
|
||||||
|
--chunk_size 256 \
|
||||||
|
--dropout 0.1 \
|
||||||
|
--des 'Exp' \
|
||||||
|
--itr 1
|
||||||
|
done
|
||||||
|
|
||||||
|
# ETTh1 dataset
|
||||||
|
for pred_len in 96 192 336 720
|
||||||
|
do
|
||||||
|
python -u run.py \
|
||||||
|
--task_name long_term_forecast \
|
||||||
|
--is_training 1 \
|
||||||
|
--root_path ./dataset/ETT-small/ \
|
||||||
|
--data_path ETTh1.csv \
|
||||||
|
--model_id ETTh1_$pred_len'_'$pred_len \
|
||||||
|
--model $model_name \
|
||||||
|
--data ETTh1 \
|
||||||
|
--features M \
|
||||||
|
--seq_len 96 \
|
||||||
|
--label_len 48 \
|
||||||
|
--pred_len $pred_len \
|
||||||
|
--e_layers 2 \
|
||||||
|
--d_layers 1 \
|
||||||
|
--enc_in 7 \
|
||||||
|
--c_out 7 \
|
||||||
|
--d_model 128 \
|
||||||
|
--expand 2 \
|
||||||
|
--d_conv 4 \
|
||||||
|
--d_state 64 \
|
||||||
|
--headdim 64 \
|
||||||
|
--ngroups 1 \
|
||||||
|
--chunk_size 256 \
|
||||||
|
--dropout 0.1 \
|
||||||
|
--des 'Exp' \
|
||||||
|
--itr 1
|
||||||
|
done
|
||||||
|
|
||||||
|
# ETTh2 dataset
|
||||||
|
for pred_len in 96 192 336 720
|
||||||
|
do
|
||||||
|
python -u run.py \
|
||||||
|
--task_name long_term_forecast \
|
||||||
|
--is_training 1 \
|
||||||
|
--root_path ./dataset/ETT-small/ \
|
||||||
|
--data_path ETTh2.csv \
|
||||||
|
--model_id ETTh2_$pred_len'_'$pred_len \
|
||||||
|
--model $model_name \
|
||||||
|
--data ETTh2 \
|
||||||
|
--features M \
|
||||||
|
--seq_len 96 \
|
||||||
|
--label_len 48 \
|
||||||
|
--pred_len $pred_len \
|
||||||
|
--e_layers 2 \
|
||||||
|
--d_layers 1 \
|
||||||
|
--enc_in 7 \
|
||||||
|
--c_out 7 \
|
||||||
|
--d_model 128 \
|
||||||
|
--expand 2 \
|
||||||
|
--d_conv 4 \
|
||||||
|
--d_state 64 \
|
||||||
|
--headdim 64 \
|
||||||
|
--ngroups 1 \
|
||||||
|
--chunk_size 256 \
|
||||||
|
--dropout 0.1 \
|
||||||
|
--des 'Exp' \
|
||||||
|
--itr 1
|
||||||
|
done
|
||||||
|
|
||||||
|
# Weather dataset
|
||||||
|
for pred_len in 96 192 336 720
|
||||||
|
do
|
||||||
|
python -u run.py \
|
||||||
|
--task_name long_term_forecast \
|
||||||
|
--is_training 1 \
|
||||||
|
--root_path ./dataset/weather/ \
|
||||||
|
--data_path weather.csv \
|
||||||
|
--model_id weather_$pred_len'_'$pred_len \
|
||||||
|
--model $model_name \
|
||||||
|
--data custom \
|
||||||
|
--features M \
|
||||||
|
--seq_len 96 \
|
||||||
|
--label_len 48 \
|
||||||
|
--pred_len $pred_len \
|
||||||
|
--e_layers 2 \
|
||||||
|
--d_layers 1 \
|
||||||
|
--enc_in 21 \
|
||||||
|
--c_out 21 \
|
||||||
|
--d_model 128 \
|
||||||
|
--expand 2 \
|
||||||
|
--d_conv 4 \
|
||||||
|
--d_state 64 \
|
||||||
|
--headdim 64 \
|
||||||
|
--ngroups 1 \
|
||||||
|
--chunk_size 256 \
|
||||||
|
--dropout 0.1 \
|
||||||
|
--des 'Exp' \
|
||||||
|
--itr 1
|
||||||
|
done
|
||||||
|
|
||||||
|
# ECL dataset
|
||||||
|
for pred_len in 96 192 336 720
|
||||||
|
do
|
||||||
|
python -u run.py \
|
||||||
|
--task_name long_term_forecast \
|
||||||
|
--is_training 1 \
|
||||||
|
--root_path ./dataset/electricity/ \
|
||||||
|
--data_path electricity.csv \
|
||||||
|
--model_id ECL_$pred_len'_'$pred_len \
|
||||||
|
--model $model_name \
|
||||||
|
--data custom \
|
||||||
|
--features M \
|
||||||
|
--seq_len 96 \
|
||||||
|
--label_len 48 \
|
||||||
|
--pred_len $pred_len \
|
||||||
|
--e_layers 2 \
|
||||||
|
--d_layers 1 \
|
||||||
|
--enc_in 321 \
|
||||||
|
--c_out 321 \
|
||||||
|
--d_model 128 \
|
||||||
|
--expand 2 \
|
||||||
|
--d_conv 4 \
|
||||||
|
--d_state 64 \
|
||||||
|
--headdim 64 \
|
||||||
|
--ngroups 1 \
|
||||||
|
--chunk_size 256 \
|
||||||
|
--dropout 0.1 \
|
||||||
|
--des 'Exp' \
|
||||||
|
--itr 1
|
||||||
|
done
|
||||||
|
|
||||||
|
# Traffic dataset
|
||||||
|
for pred_len in 96 192 336 720
|
||||||
|
do
|
||||||
|
python -u run.py \
|
||||||
|
--task_name long_term_forecast \
|
||||||
|
--is_training 1 \
|
||||||
|
--root_path ./dataset/traffic/ \
|
||||||
|
--data_path traffic.csv \
|
||||||
|
--model_id traffic_$pred_len'_'$pred_len \
|
||||||
|
--model $model_name \
|
||||||
|
--data custom \
|
||||||
|
--features M \
|
||||||
|
--seq_len 96 \
|
||||||
|
--label_len 48 \
|
||||||
|
--pred_len $pred_len \
|
||||||
|
--e_layers 2 \
|
||||||
|
--d_layers 1 \
|
||||||
|
--enc_in 862 \
|
||||||
|
--c_out 862 \
|
||||||
|
--d_model 128 \
|
||||||
|
--expand 2 \
|
||||||
|
--d_conv 4 \
|
||||||
|
--d_state 64 \
|
||||||
|
--headdim 64 \
|
||||||
|
--ngroups 1 \
|
||||||
|
--chunk_size 256 \
|
||||||
|
--dropout 0.1 \
|
||||||
|
--des 'Exp' \
|
||||||
|
--itr 1
|
||||||
|
done
|
||||||
|
|
||||||
|
# Exchange dataset
|
||||||
|
for pred_len in 96 192 336 720
|
||||||
|
do
|
||||||
|
python -u run.py \
|
||||||
|
--task_name long_term_forecast \
|
||||||
|
--is_training 1 \
|
||||||
|
--root_path ./dataset/exchange_rate/ \
|
||||||
|
--data_path exchange_rate.csv \
|
||||||
|
--model_id Exchange_$pred_len'_'$pred_len \
|
||||||
|
--model $model_name \
|
||||||
|
--data custom \
|
||||||
|
--features M \
|
||||||
|
--seq_len 96 \
|
||||||
|
--label_len 48 \
|
||||||
|
--pred_len $pred_len \
|
||||||
|
--e_layers 2 \
|
||||||
|
--d_layers 1 \
|
||||||
|
--enc_in 8 \
|
||||||
|
--c_out 8 \
|
||||||
|
--d_model 128 \
|
||||||
|
--expand 2 \
|
||||||
|
--d_conv 4 \
|
||||||
|
--d_state 64 \
|
||||||
|
--headdim 64 \
|
||||||
|
--ngroups 1 \
|
||||||
|
--chunk_size 256 \
|
||||||
|
--dropout 0.1 \
|
||||||
|
--des 'Exp' \
|
||||||
|
--itr 1
|
||||||
|
done
|
150
scripts/long_term_forecast/xPatch_SparseChannel_PEMS.sh
Normal file
150
scripts/long_term_forecast/xPatch_SparseChannel_PEMS.sh
Normal file
@ -0,0 +1,150 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
#export CUDA_VISIBLE_DEVICES=0
|
||||||
|
|
||||||
|
model_name=xPatch_SparseChannel
|
||||||
|
|
||||||
|
seq_len=96
|
||||||
|
pred_len=12
|
||||||
|
learning_rate=0.003
|
||||||
|
d_model=128
|
||||||
|
d_ff=256
|
||||||
|
batch_size=128
|
||||||
|
train_epochs=10
|
||||||
|
patience=10
|
||||||
|
|
||||||
|
# PEMS03 dataset
|
||||||
|
python -u run.py \
|
||||||
|
--task_name long_term_forecast \
|
||||||
|
--is_training 1 \
|
||||||
|
--root_path ./dataset/PEMS/ \
|
||||||
|
--data_path PEMS03.npz \
|
||||||
|
--model_id PEMS03 \
|
||||||
|
--model $model_name \
|
||||||
|
--data PEMS \
|
||||||
|
--features M \
|
||||||
|
--seq_len $seq_len \
|
||||||
|
--label_len 0 \
|
||||||
|
--pred_len $pred_len \
|
||||||
|
--e_layers 2 \
|
||||||
|
--d_layers 1 \
|
||||||
|
--enc_in 358 \
|
||||||
|
--dec_in 358 \
|
||||||
|
--c_out 358 \
|
||||||
|
--lradj 'sigmoid' \
|
||||||
|
--d_model $d_model \
|
||||||
|
--d_ff $d_ff \
|
||||||
|
--n_heads 16 \
|
||||||
|
--patch_len 16 \
|
||||||
|
--stride 8 \
|
||||||
|
--k_graph 8 \
|
||||||
|
--dropout 0.1 \
|
||||||
|
--revin 1 \
|
||||||
|
--batch_size $batch_size \
|
||||||
|
--learning_rate $learning_rate \
|
||||||
|
--train_epochs $train_epochs \
|
||||||
|
--patience $patience \
|
||||||
|
--des 'Exp' \
|
||||||
|
--itr 1
|
||||||
|
|
||||||
|
# PEMS04 dataset
|
||||||
|
python -u run.py \
|
||||||
|
--task_name long_term_forecast \
|
||||||
|
--is_training 1 \
|
||||||
|
--root_path ./dataset/PEMS/ \
|
||||||
|
--data_path PEMS04.npz \
|
||||||
|
--model_id PEMS04 \
|
||||||
|
--model $model_name \
|
||||||
|
--data PEMS \
|
||||||
|
--features M \
|
||||||
|
--seq_len $seq_len \
|
||||||
|
--label_len 0 \
|
||||||
|
--pred_len $pred_len \
|
||||||
|
--e_layers 2 \
|
||||||
|
--d_layers 1 \
|
||||||
|
--enc_in 307 \
|
||||||
|
--dec_in 307 \
|
||||||
|
--c_out 307 \
|
||||||
|
--lradj 'sigmoid' \
|
||||||
|
--d_model $d_model \
|
||||||
|
--d_ff $d_ff \
|
||||||
|
--n_heads 16 \
|
||||||
|
--patch_len 16 \
|
||||||
|
--stride 8 \
|
||||||
|
--k_graph 8 \
|
||||||
|
--dropout 0.1 \
|
||||||
|
--revin 1 \
|
||||||
|
--batch_size $batch_size \
|
||||||
|
--learning_rate $learning_rate \
|
||||||
|
--train_epochs $train_epochs \
|
||||||
|
--patience $patience \
|
||||||
|
--des 'Exp' \
|
||||||
|
--itr 1
|
||||||
|
|
||||||
|
# PEMS07 dataset
|
||||||
|
python -u run.py \
|
||||||
|
--task_name long_term_forecast \
|
||||||
|
--is_training 1 \
|
||||||
|
--root_path ./dataset/PEMS/ \
|
||||||
|
--data_path PEMS07.npz \
|
||||||
|
--model_id PEMS07 \
|
||||||
|
--model $model_name \
|
||||||
|
--data PEMS \
|
||||||
|
--features M \
|
||||||
|
--seq_len $seq_len \
|
||||||
|
--label_len 0 \
|
||||||
|
--pred_len $pred_len \
|
||||||
|
--e_layers 2 \
|
||||||
|
--d_layers 1 \
|
||||||
|
--enc_in 883 \
|
||||||
|
--dec_in 883 \
|
||||||
|
--c_out 883 \
|
||||||
|
--lradj 'sigmoid' \
|
||||||
|
--d_model $d_model \
|
||||||
|
--d_ff $d_ff \
|
||||||
|
--n_heads 16 \
|
||||||
|
--patch_len 16 \
|
||||||
|
--stride 8 \
|
||||||
|
--k_graph 8 \
|
||||||
|
--dropout 0.1 \
|
||||||
|
--revin 1 \
|
||||||
|
--batch_size $batch_size \
|
||||||
|
--learning_rate $learning_rate \
|
||||||
|
--train_epochs $train_epochs \
|
||||||
|
--patience $patience \
|
||||||
|
--des 'Exp' \
|
||||||
|
--itr 1
|
||||||
|
|
||||||
|
# PEMS08 dataset
|
||||||
|
python -u run.py \
|
||||||
|
--task_name long_term_forecast \
|
||||||
|
--is_training 1 \
|
||||||
|
--root_path ./dataset/PEMS/ \
|
||||||
|
--data_path PEMS08.npz \
|
||||||
|
--model_id PEMS08 \
|
||||||
|
--model $model_name \
|
||||||
|
--data PEMS \
|
||||||
|
--features M \
|
||||||
|
--seq_len $seq_len \
|
||||||
|
--label_len 0 \
|
||||||
|
--pred_len $pred_len \
|
||||||
|
--e_layers 2 \
|
||||||
|
--d_layers 1 \
|
||||||
|
--enc_in 170 \
|
||||||
|
--dec_in 170 \
|
||||||
|
--c_out 170 \
|
||||||
|
--lradj 'sigmoid' \
|
||||||
|
--d_model $d_model \
|
||||||
|
--d_ff $d_ff \
|
||||||
|
--n_heads 16 \
|
||||||
|
--patch_len 16 \
|
||||||
|
--stride 8 \
|
||||||
|
--k_graph 8 \
|
||||||
|
--dropout 0.1 \
|
||||||
|
--revin 1 \
|
||||||
|
--batch_size $batch_size \
|
||||||
|
--learning_rate $learning_rate \
|
||||||
|
--train_epochs $train_epochs \
|
||||||
|
--patience $patience \
|
||||||
|
--des 'Exp' \
|
||||||
|
--itr 1
|
251
scripts/long_term_forecast/xPatch_SparseChannel_all-Copy1.sh
Normal file
251
scripts/long_term_forecast/xPatch_SparseChannel_all-Copy1.sh
Normal file
@ -0,0 +1,251 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
model_name=xPatch_SparseChannel
|
||||||
|
|
||||||
|
# ETTm1 dataset
|
||||||
|
for pred_len in 96 192 336 720
|
||||||
|
do
|
||||||
|
python -u run.py \
|
||||||
|
--task_name long_term_forecast \
|
||||||
|
--is_training 1 \
|
||||||
|
--root_path ./dataset/ETT-small/ \
|
||||||
|
--data_path ETTm1.csv \
|
||||||
|
--model_id ETTm1_$pred_len'_'$pred_len \
|
||||||
|
--model $model_name \
|
||||||
|
--data ETTm1 \
|
||||||
|
--features M \
|
||||||
|
--seq_len 96 \
|
||||||
|
--label_len 48 \
|
||||||
|
--pred_len $pred_len \
|
||||||
|
--e_layers 2 \
|
||||||
|
--d_layers 1 \
|
||||||
|
--enc_in 7 \
|
||||||
|
--c_out 7 \
|
||||||
|
--d_model 128 \
|
||||||
|
--d_ff 256 \
|
||||||
|
--n_heads 16 \
|
||||||
|
--patch_len 16 \
|
||||||
|
--stride 8 \
|
||||||
|
--k_graph 7 \
|
||||||
|
--dropout 0.1 \
|
||||||
|
--revin 1 \
|
||||||
|
--des 'Exp' \
|
||||||
|
--itr 1
|
||||||
|
done
|
||||||
|
|
||||||
|
# ETTm2 dataset
|
||||||
|
for pred_len in 96 192 336 720
|
||||||
|
do
|
||||||
|
python -u run.py \
|
||||||
|
--task_name long_term_forecast \
|
||||||
|
--is_training 1 \
|
||||||
|
--root_path ./dataset/ETT-small/ \
|
||||||
|
--data_path ETTm2.csv \
|
||||||
|
--model_id ETTm2_$pred_len'_'$pred_len \
|
||||||
|
--model $model_name \
|
||||||
|
--data ETTm2 \
|
||||||
|
--features M \
|
||||||
|
--seq_len 96 \
|
||||||
|
--label_len 48 \
|
||||||
|
--pred_len $pred_len \
|
||||||
|
--e_layers 2 \
|
||||||
|
--d_layers 1 \
|
||||||
|
--enc_in 7 \
|
||||||
|
--c_out 7 \
|
||||||
|
--d_model 128 \
|
||||||
|
--d_ff 256 \
|
||||||
|
--n_heads 16 \
|
||||||
|
--patch_len 16 \
|
||||||
|
--stride 8 \
|
||||||
|
--k_graph 7 \
|
||||||
|
--dropout 0.1 \
|
||||||
|
--revin 1 \
|
||||||
|
--des 'Exp' \
|
||||||
|
--itr 1
|
||||||
|
done
|
||||||
|
|
||||||
|
# ETTh1 dataset
|
||||||
|
for pred_len in 96 192 336 720
|
||||||
|
do
|
||||||
|
python -u run.py \
|
||||||
|
--task_name long_term_forecast \
|
||||||
|
--is_training 1 \
|
||||||
|
--root_path ./dataset/ETT-small/ \
|
||||||
|
--data_path ETTh1.csv \
|
||||||
|
--model_id ETTh1_$pred_len'_'$pred_len \
|
||||||
|
--model $model_name \
|
||||||
|
--data ETTh1 \
|
||||||
|
--features M \
|
||||||
|
--seq_len 96 \
|
||||||
|
--label_len 48 \
|
||||||
|
--pred_len $pred_len \
|
||||||
|
--e_layers 2 \
|
||||||
|
--d_layers 1 \
|
||||||
|
--enc_in 7 \
|
||||||
|
--c_out 7 \
|
||||||
|
--d_model 128 \
|
||||||
|
--d_ff 256 \
|
||||||
|
--n_heads 16 \
|
||||||
|
--patch_len 16 \
|
||||||
|
--stride 8 \
|
||||||
|
--k_graph 7 \
|
||||||
|
--dropout 0.1 \
|
||||||
|
--revin 1 \
|
||||||
|
--des 'Exp' \
|
||||||
|
--itr 1
|
||||||
|
done
|
||||||
|
|
||||||
|
# ETTh2 dataset
|
||||||
|
for pred_len in 96 192 336 720
|
||||||
|
do
|
||||||
|
python -u run.py \
|
||||||
|
--task_name long_term_forecast \
|
||||||
|
--is_training 1 \
|
||||||
|
--root_path ./dataset/ETT-small/ \
|
||||||
|
--data_path ETTh2.csv \
|
||||||
|
--model_id ETTh2_$pred_len'_'$pred_len \
|
||||||
|
--model $model_name \
|
||||||
|
--data ETTh2 \
|
||||||
|
--features M \
|
||||||
|
--seq_len 96 \
|
||||||
|
--label_len 48 \
|
||||||
|
--pred_len $pred_len \
|
||||||
|
--e_layers 2 \
|
||||||
|
--d_layers 1 \
|
||||||
|
--enc_in 7 \
|
||||||
|
--c_out 7 \
|
||||||
|
--d_model 128 \
|
||||||
|
--d_ff 256 \
|
||||||
|
--n_heads 16 \
|
||||||
|
--patch_len 16 \
|
||||||
|
--stride 8 \
|
||||||
|
--k_graph 7 \
|
||||||
|
--dropout 0.1 \
|
||||||
|
--revin 1 \
|
||||||
|
--des 'Exp' \
|
||||||
|
--itr 1
|
||||||
|
done
|
||||||
|
|
||||||
|
# Weather dataset
|
||||||
|
for pred_len in 96 192 336 720
|
||||||
|
do
|
||||||
|
python -u run.py \
|
||||||
|
--task_name long_term_forecast \
|
||||||
|
--is_training 1 \
|
||||||
|
--root_path ./dataset/weather/ \
|
||||||
|
--data_path weather.csv \
|
||||||
|
--model_id weather_$pred_len'_'$pred_len \
|
||||||
|
--model $model_name \
|
||||||
|
--data custom \
|
||||||
|
--features M \
|
||||||
|
--seq_len 96 \
|
||||||
|
--label_len 48 \
|
||||||
|
--pred_len $pred_len \
|
||||||
|
--e_layers 2 \
|
||||||
|
--d_layers 1 \
|
||||||
|
--enc_in 21 \
|
||||||
|
--c_out 21 \
|
||||||
|
--d_model 128 \
|
||||||
|
--d_ff 256 \
|
||||||
|
--n_heads 16 \
|
||||||
|
--patch_len 16 \
|
||||||
|
--stride 8 \
|
||||||
|
--k_graph 8 \
|
||||||
|
--dropout 0.1 \
|
||||||
|
--revin 1 \
|
||||||
|
--des 'Exp' \
|
||||||
|
--itr 1
|
||||||
|
done
|
||||||
|
|
||||||
|
# ECL dataset
|
||||||
|
for pred_len in 96 192 336 720
|
||||||
|
do
|
||||||
|
python -u run.py \
|
||||||
|
--task_name long_term_forecast \
|
||||||
|
--is_training 1 \
|
||||||
|
--root_path ./dataset/electricity/ \
|
||||||
|
--data_path electricity.csv \
|
||||||
|
--model_id ECL_$pred_len'_'$pred_len \
|
||||||
|
--model $model_name \
|
||||||
|
--data custom \
|
||||||
|
--features M \
|
||||||
|
--seq_len 96 \
|
||||||
|
--label_len 48 \
|
||||||
|
--pred_len $pred_len \
|
||||||
|
--e_layers 2 \
|
||||||
|
--d_layers 1 \
|
||||||
|
--enc_in 321 \
|
||||||
|
--c_out 321 \
|
||||||
|
--d_model 128 \
|
||||||
|
--d_ff 256 \
|
||||||
|
--n_heads 16 \
|
||||||
|
--patch_len 16 \
|
||||||
|
--stride 8 \
|
||||||
|
--k_graph 8 \
|
||||||
|
--dropout 0.1 \
|
||||||
|
--revin 1 \
|
||||||
|
--des 'Exp' \
|
||||||
|
--itr 1
|
||||||
|
done
|
||||||
|
|
||||||
|
# Traffic dataset
|
||||||
|
for pred_len in 96 192 336 720
|
||||||
|
do
|
||||||
|
python -u run.py \
|
||||||
|
--task_name long_term_forecast \
|
||||||
|
--is_training 1 \
|
||||||
|
--root_path ./dataset/traffic/ \
|
||||||
|
--data_path traffic.csv \
|
||||||
|
--model_id traffic_$pred_len'_'$pred_len \
|
||||||
|
--model $model_name \
|
||||||
|
--data custom \
|
||||||
|
--features M \
|
||||||
|
--seq_len 96 \
|
||||||
|
--label_len 48 \
|
||||||
|
--pred_len $pred_len \
|
||||||
|
--e_layers 2 \
|
||||||
|
--d_layers 1 \
|
||||||
|
--enc_in 862 \
|
||||||
|
--c_out 862 \
|
||||||
|
--d_model 128 \
|
||||||
|
--d_ff 256 \
|
||||||
|
--n_heads 16 \
|
||||||
|
--patch_len 16 \
|
||||||
|
--stride 8 \
|
||||||
|
--k_graph 8 \
|
||||||
|
--dropout 0.1 \
|
||||||
|
--revin 1 \
|
||||||
|
--des 'Exp' \
|
||||||
|
--itr 1
|
||||||
|
done
|
||||||
|
|
||||||
|
# Exchange dataset
|
||||||
|
for pred_len in 96 192 336 720
|
||||||
|
do
|
||||||
|
python -u run.py \
|
||||||
|
--task_name long_term_forecast \
|
||||||
|
--is_training 1 \
|
||||||
|
--root_path ./dataset/exchange_rate/ \
|
||||||
|
--data_path exchange_rate.csv \
|
||||||
|
--model_id Exchange_$pred_len'_'$pred_len \
|
||||||
|
--model $model_name \
|
||||||
|
--data custom \
|
||||||
|
--features M \
|
||||||
|
--seq_len 96 \
|
||||||
|
--label_len 48 \
|
||||||
|
--pred_len $pred_len \
|
||||||
|
--e_layers 2 \
|
||||||
|
--d_layers 1 \
|
||||||
|
--enc_in 8 \
|
||||||
|
--c_out 8 \
|
||||||
|
--d_model 128 \
|
||||||
|
--d_ff 256 \
|
||||||
|
--n_heads 16 \
|
||||||
|
--patch_len 16 \
|
||||||
|
--stride 8 \
|
||||||
|
--k_graph 8 \
|
||||||
|
--dropout 0.1 \
|
||||||
|
--revin 1 \
|
||||||
|
--des 'Exp' \
|
||||||
|
--itr 1
|
||||||
|
done
|
77
scripts/long_term_forecast/xPatch_SparseChannel_all.sh
Normal file
77
scripts/long_term_forecast/xPatch_SparseChannel_all.sh
Normal file
@ -0,0 +1,77 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
model_name=xPatch_SparseChannel
|
||||||
|
|
||||||
|
# ECL dataset
|
||||||
|
for pred_len in 96 192 336 720
|
||||||
|
do
|
||||||
|
python -u run.py \
|
||||||
|
--task_name long_term_forecast \
|
||||||
|
--is_training 1 \
|
||||||
|
--root_path ./dataset/electricity/ \
|
||||||
|
--data_path electricity.csv \
|
||||||
|
--model_id ECL_$pred_len'_'$pred_len \
|
||||||
|
--model $model_name \
|
||||||
|
--data custom \
|
||||||
|
--features M \
|
||||||
|
--seq_len 96 \
|
||||||
|
--label_len 48 \
|
||||||
|
--pred_len $pred_len \
|
||||||
|
--e_layers 2 \
|
||||||
|
--d_layers 1 \
|
||||||
|
--enc_in 321 \
|
||||||
|
--batch_size 512 \
|
||||||
|
--learning_rate 0.001 \
|
||||||
|
--use_multi_gpu True \
|
||||||
|
--lradj 'sigmoid' \
|
||||||
|
--train_epochs 50 \
|
||||||
|
--patience 5 \
|
||||||
|
--c_out 321 \
|
||||||
|
--d_model 128 \
|
||||||
|
--d_ff 256 \
|
||||||
|
--n_heads 16 \
|
||||||
|
--patch_len 16 \
|
||||||
|
--stride 8 \
|
||||||
|
--k_graph 8 \
|
||||||
|
--dropout 0.1 \
|
||||||
|
--revin 1 \
|
||||||
|
--des 'Exp' \
|
||||||
|
--itr 1
|
||||||
|
done
|
||||||
|
|
||||||
|
# Traffic dataset
|
||||||
|
for pred_len in 96 192 336 720
|
||||||
|
do
|
||||||
|
python -u run.py \
|
||||||
|
--task_name long_term_forecast \
|
||||||
|
--is_training 1 \
|
||||||
|
--root_path ./dataset/traffic/ \
|
||||||
|
--data_path traffic.csv \
|
||||||
|
--model_id traffic_$pred_len'_'$pred_len \
|
||||||
|
--model $model_name \
|
||||||
|
--data custom \
|
||||||
|
--features M \
|
||||||
|
--seq_len 96 \
|
||||||
|
--label_len 48 \
|
||||||
|
--pred_len $pred_len \
|
||||||
|
--e_layers 2 \
|
||||||
|
--d_layers 1 \
|
||||||
|
--enc_in 862 \
|
||||||
|
--batch_size 256 \
|
||||||
|
--learning_rate 0.0005 \
|
||||||
|
--use_multi_gpu True \
|
||||||
|
--lradj 'sigmoid' \
|
||||||
|
--train_epochs 50 \
|
||||||
|
--patience 5 \
|
||||||
|
--c_out 862 \
|
||||||
|
--d_model 128 \
|
||||||
|
--d_ff 256 \
|
||||||
|
--n_heads 16 \
|
||||||
|
--patch_len 16 \
|
||||||
|
--stride 8 \
|
||||||
|
--k_graph 8 \
|
||||||
|
--dropout 0.1 \
|
||||||
|
--revin 1 \
|
||||||
|
--des 'Exp' \
|
||||||
|
--itr 1
|
||||||
|
done
|
165
scripts/short_term_forecast/vanillaMamba_M4.sh
Normal file
165
scripts/short_term_forecast/vanillaMamba_M4.sh
Normal file
@ -0,0 +1,165 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
model_name=vanillaMamba
|
||||||
|
|
||||||
|
# M4 Monthly
|
||||||
|
python -u run.py \
|
||||||
|
--task_name short_term_forecast \
|
||||||
|
--is_training 1 \
|
||||||
|
--root_path ./dataset/m4 \
|
||||||
|
--seasonal_patterns 'Monthly' \
|
||||||
|
--model_id m4_Monthly \
|
||||||
|
--model $model_name \
|
||||||
|
--data m4 \
|
||||||
|
--features M \
|
||||||
|
--e_layers 2 \
|
||||||
|
--enc_in 1 \
|
||||||
|
--c_out 1 \
|
||||||
|
--batch_size 16 \
|
||||||
|
--d_model 128 \
|
||||||
|
--expand 2 \
|
||||||
|
--d_conv 4 \
|
||||||
|
--d_state 64 \
|
||||||
|
--headdim 64 \
|
||||||
|
--ngroups 1 \
|
||||||
|
--chunk_size 256 \
|
||||||
|
--dropout 0.1 \
|
||||||
|
--des 'Exp' \
|
||||||
|
--itr 1 \
|
||||||
|
--learning_rate 0.001 \
|
||||||
|
--loss 'SMAPE'
|
||||||
|
|
||||||
|
# M4 Yearly
|
||||||
|
python -u run.py \
|
||||||
|
--task_name short_term_forecast \
|
||||||
|
--is_training 1 \
|
||||||
|
--root_path ./dataset/m4 \
|
||||||
|
--seasonal_patterns 'Yearly' \
|
||||||
|
--model_id m4_Yearly \
|
||||||
|
--model $model_name \
|
||||||
|
--data m4 \
|
||||||
|
--features M \
|
||||||
|
--e_layers 2 \
|
||||||
|
--enc_in 1 \
|
||||||
|
--c_out 1 \
|
||||||
|
--batch_size 16 \
|
||||||
|
--d_model 128 \
|
||||||
|
--expand 2 \
|
||||||
|
--d_conv 4 \
|
||||||
|
--d_state 64 \
|
||||||
|
--headdim 64 \
|
||||||
|
--ngroups 1 \
|
||||||
|
--chunk_size 256 \
|
||||||
|
--dropout 0.1 \
|
||||||
|
--des 'Exp' \
|
||||||
|
--itr 1 \
|
||||||
|
--learning_rate 0.001 \
|
||||||
|
--loss 'SMAPE'
|
||||||
|
|
||||||
|
# M4 Quarterly
|
||||||
|
python -u run.py \
|
||||||
|
--task_name short_term_forecast \
|
||||||
|
--is_training 1 \
|
||||||
|
--root_path ./dataset/m4 \
|
||||||
|
--seasonal_patterns 'Quarterly' \
|
||||||
|
--model_id m4_Quarterly \
|
||||||
|
--model $model_name \
|
||||||
|
--data m4 \
|
||||||
|
--features M \
|
||||||
|
--e_layers 2 \
|
||||||
|
--enc_in 1 \
|
||||||
|
--c_out 1 \
|
||||||
|
--batch_size 16 \
|
||||||
|
--d_model 128 \
|
||||||
|
--expand 2 \
|
||||||
|
--d_conv 4 \
|
||||||
|
--d_state 64 \
|
||||||
|
--headdim 64 \
|
||||||
|
--ngroups 1 \
|
||||||
|
--chunk_size 256 \
|
||||||
|
--dropout 0.1 \
|
||||||
|
--des 'Exp' \
|
||||||
|
--itr 1 \
|
||||||
|
--learning_rate 0.001 \
|
||||||
|
--loss 'SMAPE'
|
||||||
|
|
||||||
|
# M4 Weekly
|
||||||
|
python -u run.py \
|
||||||
|
--task_name short_term_forecast \
|
||||||
|
--is_training 1 \
|
||||||
|
--root_path ./dataset/m4 \
|
||||||
|
--seasonal_patterns 'Weekly' \
|
||||||
|
--model_id m4_Weekly \
|
||||||
|
--model $model_name \
|
||||||
|
--data m4 \
|
||||||
|
--features M \
|
||||||
|
--e_layers 2 \
|
||||||
|
--enc_in 1 \
|
||||||
|
--c_out 1 \
|
||||||
|
--batch_size 16 \
|
||||||
|
--d_model 128 \
|
||||||
|
--expand 2 \
|
||||||
|
--d_conv 4 \
|
||||||
|
--d_state 64 \
|
||||||
|
--headdim 64 \
|
||||||
|
--ngroups 1 \
|
||||||
|
--chunk_size 256 \
|
||||||
|
--dropout 0.1 \
|
||||||
|
--des 'Exp' \
|
||||||
|
--itr 1 \
|
||||||
|
--learning_rate 0.001 \
|
||||||
|
--loss 'SMAPE'
|
||||||
|
|
||||||
|
# M4 Daily
|
||||||
|
python -u run.py \
|
||||||
|
--task_name short_term_forecast \
|
||||||
|
--is_training 1 \
|
||||||
|
--root_path ./dataset/m4 \
|
||||||
|
--seasonal_patterns 'Daily' \
|
||||||
|
--model_id m4_Daily \
|
||||||
|
--model $model_name \
|
||||||
|
--data m4 \
|
||||||
|
--features M \
|
||||||
|
--e_layers 2 \
|
||||||
|
--enc_in 1 \
|
||||||
|
--c_out 1 \
|
||||||
|
--batch_size 16 \
|
||||||
|
--d_model 128 \
|
||||||
|
--expand 2 \
|
||||||
|
--d_conv 4 \
|
||||||
|
--d_state 64 \
|
||||||
|
--headdim 64 \
|
||||||
|
--ngroups 1 \
|
||||||
|
--chunk_size 256 \
|
||||||
|
--dropout 0.1 \
|
||||||
|
--des 'Exp' \
|
||||||
|
--itr 1 \
|
||||||
|
--learning_rate 0.001 \
|
||||||
|
--loss 'SMAPE'
|
||||||
|
|
||||||
|
# M4 Hourly
|
||||||
|
python -u run.py \
|
||||||
|
--task_name short_term_forecast \
|
||||||
|
--is_training 1 \
|
||||||
|
--root_path ./dataset/m4 \
|
||||||
|
--seasonal_patterns 'Hourly' \
|
||||||
|
--model_id m4_Hourly \
|
||||||
|
--model $model_name \
|
||||||
|
--data m4 \
|
||||||
|
--features M \
|
||||||
|
--e_layers 2 \
|
||||||
|
--enc_in 1 \
|
||||||
|
--c_out 1 \
|
||||||
|
--batch_size 16 \
|
||||||
|
--d_model 128 \
|
||||||
|
--expand 2 \
|
||||||
|
--d_conv 4 \
|
||||||
|
--d_state 64 \
|
||||||
|
--headdim 64 \
|
||||||
|
--ngroups 1 \
|
||||||
|
--chunk_size 256 \
|
||||||
|
--dropout 0.1 \
|
||||||
|
--des 'Exp' \
|
||||||
|
--itr 1 \
|
||||||
|
--learning_rate 0.001 \
|
||||||
|
--loss 'SMAPE'
|
165
scripts/short_term_forecast/xPatch_SparseChannel_M4.sh
Normal file
165
scripts/short_term_forecast/xPatch_SparseChannel_M4.sh
Normal file
@ -0,0 +1,165 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
model_name=xPatch_SparseChannel
|
||||||
|
|
||||||
|
# M4 Monthly
|
||||||
|
python -u run.py \
|
||||||
|
--task_name short_term_forecast \
|
||||||
|
--is_training 1 \
|
||||||
|
--root_path ./dataset/m4 \
|
||||||
|
--seasonal_patterns 'Monthly' \
|
||||||
|
--model_id m4_Monthly \
|
||||||
|
--model $model_name \
|
||||||
|
--data m4 \
|
||||||
|
--features M \
|
||||||
|
--e_layers 2 \
|
||||||
|
--enc_in 1 \
|
||||||
|
--c_out 1 \
|
||||||
|
--batch_size 16 \
|
||||||
|
--d_model 128 \
|
||||||
|
--d_ff 256 \
|
||||||
|
--n_heads 16 \
|
||||||
|
--patch_len 16 \
|
||||||
|
--stride 8 \
|
||||||
|
--k_graph 1 \
|
||||||
|
--dropout 0.1 \
|
||||||
|
--revin 1 \
|
||||||
|
--des 'Exp' \
|
||||||
|
--itr 1 \
|
||||||
|
--learning_rate 0.001 \
|
||||||
|
--loss 'SMAPE'
|
||||||
|
|
||||||
|
# M4 Yearly
|
||||||
|
python -u run.py \
|
||||||
|
--task_name short_term_forecast \
|
||||||
|
--is_training 1 \
|
||||||
|
--root_path ./dataset/m4 \
|
||||||
|
--seasonal_patterns 'Yearly' \
|
||||||
|
--model_id m4_Yearly \
|
||||||
|
--model $model_name \
|
||||||
|
--data m4 \
|
||||||
|
--features M \
|
||||||
|
--e_layers 2 \
|
||||||
|
--enc_in 1 \
|
||||||
|
--c_out 1 \
|
||||||
|
--batch_size 16 \
|
||||||
|
--d_model 128 \
|
||||||
|
--d_ff 256 \
|
||||||
|
--n_heads 16 \
|
||||||
|
--patch_len 16 \
|
||||||
|
--stride 8 \
|
||||||
|
--k_graph 1 \
|
||||||
|
--dropout 0.1 \
|
||||||
|
--revin 1 \
|
||||||
|
--des 'Exp' \
|
||||||
|
--itr 1 \
|
||||||
|
--learning_rate 0.001 \
|
||||||
|
--loss 'SMAPE'
|
||||||
|
|
||||||
|
# M4 Quarterly
|
||||||
|
python -u run.py \
|
||||||
|
--task_name short_term_forecast \
|
||||||
|
--is_training 1 \
|
||||||
|
--root_path ./dataset/m4 \
|
||||||
|
--seasonal_patterns 'Quarterly' \
|
||||||
|
--model_id m4_Quarterly \
|
||||||
|
--model $model_name \
|
||||||
|
--data m4 \
|
||||||
|
--features M \
|
||||||
|
--e_layers 2 \
|
||||||
|
--enc_in 1 \
|
||||||
|
--c_out 1 \
|
||||||
|
--batch_size 16 \
|
||||||
|
--d_model 128 \
|
||||||
|
--d_ff 256 \
|
||||||
|
--n_heads 16 \
|
||||||
|
--patch_len 16 \
|
||||||
|
--stride 8 \
|
||||||
|
--k_graph 1 \
|
||||||
|
--dropout 0.1 \
|
||||||
|
--revin 1 \
|
||||||
|
--des 'Exp' \
|
||||||
|
--itr 1 \
|
||||||
|
--learning_rate 0.001 \
|
||||||
|
--loss 'SMAPE'
|
||||||
|
|
||||||
|
# M4 Weekly
|
||||||
|
python -u run.py \
|
||||||
|
--task_name short_term_forecast \
|
||||||
|
--is_training 1 \
|
||||||
|
--root_path ./dataset/m4 \
|
||||||
|
--seasonal_patterns 'Weekly' \
|
||||||
|
--model_id m4_Weekly \
|
||||||
|
--model $model_name \
|
||||||
|
--data m4 \
|
||||||
|
--features M \
|
||||||
|
--e_layers 2 \
|
||||||
|
--enc_in 1 \
|
||||||
|
--c_out 1 \
|
||||||
|
--batch_size 16 \
|
||||||
|
--d_model 128 \
|
||||||
|
--d_ff 256 \
|
||||||
|
--n_heads 16 \
|
||||||
|
--patch_len 16 \
|
||||||
|
--stride 8 \
|
||||||
|
--k_graph 1 \
|
||||||
|
--dropout 0.1 \
|
||||||
|
--revin 1 \
|
||||||
|
--des 'Exp' \
|
||||||
|
--itr 1 \
|
||||||
|
--learning_rate 0.001 \
|
||||||
|
--loss 'SMAPE'
|
||||||
|
|
||||||
|
# M4 Daily
|
||||||
|
python -u run.py \
|
||||||
|
--task_name short_term_forecast \
|
||||||
|
--is_training 1 \
|
||||||
|
--root_path ./dataset/m4 \
|
||||||
|
--seasonal_patterns 'Daily' \
|
||||||
|
--model_id m4_Daily \
|
||||||
|
--model $model_name \
|
||||||
|
--data m4 \
|
||||||
|
--features M \
|
||||||
|
--e_layers 2 \
|
||||||
|
--enc_in 1 \
|
||||||
|
--c_out 1 \
|
||||||
|
--batch_size 16 \
|
||||||
|
--d_model 128 \
|
||||||
|
--d_ff 256 \
|
||||||
|
--n_heads 16 \
|
||||||
|
--patch_len 16 \
|
||||||
|
--stride 8 \
|
||||||
|
--k_graph 1 \
|
||||||
|
--dropout 0.1 \
|
||||||
|
--revin 1 \
|
||||||
|
--des 'Exp' \
|
||||||
|
--itr 1 \
|
||||||
|
--learning_rate 0.001 \
|
||||||
|
--loss 'SMAPE'
|
||||||
|
|
||||||
|
# M4 Hourly
|
||||||
|
python -u run.py \
|
||||||
|
--task_name short_term_forecast \
|
||||||
|
--is_training 1 \
|
||||||
|
--root_path ./dataset/m4 \
|
||||||
|
--seasonal_patterns 'Hourly' \
|
||||||
|
--model_id m4_Hourly \
|
||||||
|
--model $model_name \
|
||||||
|
--data m4 \
|
||||||
|
--features M \
|
||||||
|
--e_layers 2 \
|
||||||
|
--enc_in 1 \
|
||||||
|
--c_out 1 \
|
||||||
|
--batch_size 16 \
|
||||||
|
--d_model 128 \
|
||||||
|
--d_ff 256 \
|
||||||
|
--n_heads 16 \
|
||||||
|
--patch_len 16 \
|
||||||
|
--stride 8 \
|
||||||
|
--k_graph 1 \
|
||||||
|
--dropout 0.1 \
|
||||||
|
--revin 1 \
|
||||||
|
--des 'Exp' \
|
||||||
|
--itr 1 \
|
||||||
|
--learning_rate 0.001 \
|
||||||
|
--loss 'SMAPE'
|
209
test_DC_hnet.py
Normal file
209
test_DC_hnet.py
Normal file
@ -0,0 +1,209 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
测试DC_hnet模型的脚本
|
||||||
|
用于验证时间序列分类模型能否正常运行并得到期望的输出形状
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
|
||||||
|
# 添加当前目录到Python路径
|
||||||
|
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
|
||||||
|
from models.DC_hnet import HierEncodersSingleMainConfig, HierEncodersSingleMainClassifier
|
||||||
|
|
||||||
|
def test_dc_hnet_model():
|
||||||
|
"""测试DC_hnet时间序列分类模型"""
|
||||||
|
print("=" * 60)
|
||||||
|
print("测试DC_hnet时间序列分类模型")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
# 设置设备
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
print(f"使用设备: {device}")
|
||||||
|
|
||||||
|
# 模型参数配置
|
||||||
|
B, L, N = 8, 512, 6 # batch_size, seq_length, num_channels
|
||||||
|
num_classes = 10 # 分类数
|
||||||
|
d_models = [64, 128, 256] # 各层维度,单调递增
|
||||||
|
|
||||||
|
print(f"输入形状: (B={B}, L={L}, N={N})")
|
||||||
|
print(f"分类数: {num_classes}")
|
||||||
|
print(f"模型维度: {d_models}")
|
||||||
|
|
||||||
|
# 编码器配置(每层都是Mamba)
|
||||||
|
encoder_cfg_per_stage = [
|
||||||
|
dict(arch="m", height=2), # stage 0: Mamba2, 2层
|
||||||
|
dict(arch="m", height=3), # stage 1: Mamba2, 3层
|
||||||
|
]
|
||||||
|
|
||||||
|
# 主网络配置(使用Transformer)
|
||||||
|
main_cfg = dict(
|
||||||
|
arch="T", height=4 # Transformer, 4层
|
||||||
|
)
|
||||||
|
|
||||||
|
# 压缩目标
|
||||||
|
target_compression_N_per_stage = [2, 3] # 每层压缩比例
|
||||||
|
|
||||||
|
# 创建配置
|
||||||
|
cfg = HierEncodersSingleMainConfig(
|
||||||
|
num_channels=N,
|
||||||
|
d_models=d_models,
|
||||||
|
num_classes=num_classes,
|
||||||
|
encoder_cfg_per_stage=encoder_cfg_per_stage,
|
||||||
|
main_cfg=main_cfg,
|
||||||
|
target_compression_N_per_stage=target_compression_N_per_stage,
|
||||||
|
share_channel=True, # 通道间共享编码器
|
||||||
|
fusion_across_channels="mean", # 通道融合方式
|
||||||
|
dropout=0.1,
|
||||||
|
)
|
||||||
|
|
||||||
|
print("配置创建完成")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 创建模型 - 设置正确的dtype以兼容flash attention
|
||||||
|
dtype = torch.bfloat16 if device.type == "cuda" else torch.float32
|
||||||
|
model = HierEncodersSingleMainClassifier(cfg, device=device, dtype=dtype)
|
||||||
|
model = model.to(device)
|
||||||
|
print(f"模型创建成功,参数量: {sum(p.numel() for p in model.parameters()):,}")
|
||||||
|
|
||||||
|
# 创建随机输入数据 - 使用bfloat16以兼容flash attention
|
||||||
|
dtype = torch.bfloat16 if device.type == "cuda" else torch.float32
|
||||||
|
x = torch.randn(B, L, N, device=device, dtype=dtype)
|
||||||
|
mask = torch.ones(B, L, dtype=torch.bool, device=device)
|
||||||
|
|
||||||
|
print(f"输入数据形状: {x.shape}, 数据类型: {x.dtype}")
|
||||||
|
|
||||||
|
# 前向传播测试
|
||||||
|
print("\n开始前向传播测试...")
|
||||||
|
model.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
logits, seq_debug, aux = model(x, mask=mask, return_seq=False)
|
||||||
|
|
||||||
|
print(f"✅ 前向传播成功!")
|
||||||
|
print(f"输出logits形状: {logits.shape}") # 应该是 (B, num_classes)
|
||||||
|
print(f"ratio_loss: {aux['ratio_loss']:.4f}")
|
||||||
|
|
||||||
|
# 验证输出形状
|
||||||
|
expected_shape = (B, num_classes)
|
||||||
|
if logits.shape == expected_shape:
|
||||||
|
print(f"✅ 输出形状正确: {logits.shape}")
|
||||||
|
else:
|
||||||
|
print(f"❌ 输出形状错误: 期望 {expected_shape}, 实际 {logits.shape}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 测试训练模式
|
||||||
|
print("\n开始训练模式测试...")
|
||||||
|
model.train()
|
||||||
|
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
|
||||||
|
|
||||||
|
# 创建目标标签
|
||||||
|
y = torch.randint(0, num_classes, (B,), device=device)
|
||||||
|
|
||||||
|
# 前向传播
|
||||||
|
logits, _, aux = model(x, mask=mask, return_seq=False)
|
||||||
|
|
||||||
|
# 计算损失
|
||||||
|
cls_loss = F.cross_entropy(logits, y)
|
||||||
|
ratio_reg = 0.01 * aux["ratio_loss"] # ratio loss正则化
|
||||||
|
total_loss = cls_loss + ratio_reg
|
||||||
|
|
||||||
|
print(f"分类损失: {cls_loss:.4f}")
|
||||||
|
print(f"ratio损失: {ratio_reg:.4f}")
|
||||||
|
print(f"总损失: {total_loss:.4f}")
|
||||||
|
|
||||||
|
# 反向传播
|
||||||
|
optimizer.zero_grad()
|
||||||
|
total_loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
print("✅ 训练步骤成功!")
|
||||||
|
|
||||||
|
# 测试序列返回功能
|
||||||
|
print("\n测试序列调试信息返回...")
|
||||||
|
with torch.no_grad():
|
||||||
|
logits, seq_debug, aux = model(x, mask=mask, return_seq=True)
|
||||||
|
|
||||||
|
if seq_debug is not None:
|
||||||
|
print(f"✅ 序列调试信息获取成功,包含 {len(seq_debug)} 个通道的信息")
|
||||||
|
else:
|
||||||
|
print("❌ 序列调试信息获取失败")
|
||||||
|
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("🎉 DC_hnet模型测试全部通过!")
|
||||||
|
print("模型可以正常进行时间序列分类任务")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ 测试失败: {str(e)}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
return False
|
||||||
|
|
||||||
|
def test_different_configurations():
|
||||||
|
"""测试不同的模型配置"""
|
||||||
|
print("\n测试不同配置...")
|
||||||
|
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
|
# 测试配置1: 不共享通道
|
||||||
|
cfg1 = HierEncodersSingleMainConfig(
|
||||||
|
num_channels=3,
|
||||||
|
d_models=[32, 64],
|
||||||
|
num_classes=5,
|
||||||
|
encoder_cfg_per_stage=[dict(arch="m", height=2)],
|
||||||
|
main_cfg=dict(arch="m", height=3),
|
||||||
|
target_compression_N_per_stage=[2],
|
||||||
|
share_channel=False, # 不共享通道
|
||||||
|
fusion_across_channels="concat", # 连接融合
|
||||||
|
dropout=0.1,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
dtype1 = torch.bfloat16 if device.type == "cuda" else torch.float32
|
||||||
|
model1 = HierEncodersSingleMainClassifier(cfg1, device=device, dtype=dtype1)
|
||||||
|
x1 = torch.randn(4, 256, 3, device=device, dtype=dtype1)
|
||||||
|
logits1, _, _ = model1(x1)
|
||||||
|
print(f"✅ 配置1 (不共享通道, concat融合): 输出形状 {logits1.shape}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ 配置1测试失败: {str(e)}")
|
||||||
|
|
||||||
|
# 测试配置2: 单层模型
|
||||||
|
cfg2 = HierEncodersSingleMainConfig(
|
||||||
|
num_channels=2,
|
||||||
|
d_models=[128], # 只有一层,没有编码器阶段
|
||||||
|
num_classes=3,
|
||||||
|
encoder_cfg_per_stage=[], # 空的编码器阶段
|
||||||
|
main_cfg=dict(arch="T", height=2),
|
||||||
|
target_compression_N_per_stage=[],
|
||||||
|
share_channel=True,
|
||||||
|
fusion_across_channels="mean",
|
||||||
|
dropout=0.1,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
dtype2 = torch.bfloat16 if device.type == "cuda" else torch.float32
|
||||||
|
model2 = HierEncodersSingleMainClassifier(cfg2, device=device, dtype=dtype2)
|
||||||
|
x2 = torch.randn(2, 128, 2, device=device, dtype=dtype2)
|
||||||
|
logits2, _, _ = model2(x2)
|
||||||
|
print(f"✅ 配置2 (单层模型): 输出形状 {logits2.shape}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ 配置2测试失败: {str(e)}")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# 主测试
|
||||||
|
success = test_dc_hnet_model()
|
||||||
|
|
||||||
|
# 额外配置测试
|
||||||
|
test_different_configurations()
|
||||||
|
|
||||||
|
if success:
|
||||||
|
print("\n🎊 所有测试完成!")
|
||||||
|
sys.exit(0)
|
||||||
|
else:
|
||||||
|
print("\n💥 测试失败!")
|
||||||
|
sys.exit(1)
|
313
test_dc_patchtst.py
Normal file
313
test_dc_patchtst.py
Normal file
@ -0,0 +1,313 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
from torch.utils.data import Dataset, DataLoader
|
||||||
|
import os
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from models.DC_PatchTST import Model
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
"""Configuration class"""
|
||||||
|
def __init__(self):
|
||||||
|
# Basic configuration
|
||||||
|
self.task_name = 'long_term_forecast'
|
||||||
|
self.model = 'DC_PatchTST'
|
||||||
|
|
||||||
|
# Data configuration
|
||||||
|
self.seq_len = 96 # Input sequence length
|
||||||
|
self.pred_len = 24 # Prediction sequence length
|
||||||
|
self.label_len = 48 # Label length
|
||||||
|
self.enc_in = 2 # Input feature dimension (dual channel)
|
||||||
|
self.dec_in = 2 # Decoder input dimension
|
||||||
|
self.c_out = 2 # Output dimension
|
||||||
|
|
||||||
|
# Model configuration
|
||||||
|
self.d_model = 128 # Model dimension
|
||||||
|
self.n_heads = 8 # Number of attention heads
|
||||||
|
self.e_layers = 2 # Number of encoder layers
|
||||||
|
self.d_layers = 1 # Number of decoder layers
|
||||||
|
self.d_ff = 256 # Feed forward dimension
|
||||||
|
self.factor = 1 # Attention factor
|
||||||
|
self.dropout = 0.1 # Dropout rate
|
||||||
|
self.activation = 'gelu'
|
||||||
|
|
||||||
|
# Training configuration
|
||||||
|
self.batch_size = 32
|
||||||
|
self.learning_rate = 0.001
|
||||||
|
self.train_epochs = 50
|
||||||
|
self.patience = 5
|
||||||
|
|
||||||
|
# Other configuration
|
||||||
|
self.use_amp = False
|
||||||
|
self.num_class = 0
|
||||||
|
|
||||||
|
# GPU configuration
|
||||||
|
self.use_gpu = torch.cuda.is_available()
|
||||||
|
self.device = torch.device('cuda' if self.use_gpu else 'cpu')
|
||||||
|
|
||||||
|
class SineWaveDataset(Dataset):
|
||||||
|
"""Sine wave dataset"""
|
||||||
|
def __init__(self, data_path, seq_len=96, pred_len=24, mode='test'):
|
||||||
|
self.seq_len = seq_len
|
||||||
|
self.pred_len = pred_len
|
||||||
|
self.mode = mode
|
||||||
|
|
||||||
|
# Load data
|
||||||
|
if mode == 'train':
|
||||||
|
df = pd.read_csv(os.path.join(data_path, 'train.csv'))
|
||||||
|
elif mode == 'val':
|
||||||
|
df = pd.read_csv(os.path.join(data_path, 'val.csv'))
|
||||||
|
else: # test
|
||||||
|
df = pd.read_csv(os.path.join(data_path, 'test.csv'))
|
||||||
|
|
||||||
|
# Extract feature columns (except timestamp)
|
||||||
|
self.data = df[['channel1', 'channel2']].values.astype(np.float32)
|
||||||
|
|
||||||
|
# Calculate available sample count
|
||||||
|
self.total_len = len(self.data)
|
||||||
|
self.samples_num = max(0, self.total_len - seq_len - pred_len + 1)
|
||||||
|
|
||||||
|
print(f"{mode} dataset: {self.total_len} records, {self.samples_num} samples")
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self.samples_num
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
# Input sequence
|
||||||
|
s_begin = idx
|
||||||
|
s_end = s_begin + self.seq_len
|
||||||
|
|
||||||
|
# Prediction target
|
||||||
|
r_begin = s_end
|
||||||
|
r_end = r_begin + self.pred_len
|
||||||
|
|
||||||
|
seq_x = self.data[s_begin:s_end] # (seq_len, n_vars)
|
||||||
|
seq_y = self.data[r_begin:r_end] # (pred_len, n_vars)
|
||||||
|
|
||||||
|
# Time marks (simple positional encoding)
|
||||||
|
seq_x_mark = np.arange(self.seq_len).reshape(-1, 1).astype(np.float32)
|
||||||
|
seq_y_mark = np.arange(self.pred_len).reshape(-1, 1).astype(np.float32)
|
||||||
|
|
||||||
|
return seq_x, seq_y, seq_x_mark, seq_y_mark, idx
|
||||||
|
|
||||||
|
def load_model(model_path):
|
||||||
|
"""Load saved model"""
|
||||||
|
checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)
|
||||||
|
config = checkpoint['config']
|
||||||
|
|
||||||
|
# Create model
|
||||||
|
model = Model(config)
|
||||||
|
model.load_state_dict(checkpoint['model_state_dict'])
|
||||||
|
|
||||||
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
|
model = model.to(device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
print(f"Model loaded successfully - Epoch: {checkpoint['epoch']}, Val Loss: {checkpoint['val_loss']:.6f}")
|
||||||
|
print(f"Using device: {device}")
|
||||||
|
|
||||||
|
return model, config, device
|
||||||
|
|
||||||
|
def visualize_predictions_with_chunks(model, test_loader, device, save_path, num_samples=5):
|
||||||
|
"""Predict and visualize results, marking chunk points"""
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
# Create save directory
|
||||||
|
vis_dir = os.path.join(save_path, 'visualizations')
|
||||||
|
os.makedirs(vis_dir, exist_ok=True)
|
||||||
|
|
||||||
|
sample_count = 0
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for batch_x, batch_y, batch_x_mark, batch_y_mark, batch_idx in test_loader:
|
||||||
|
if sample_count >= num_samples:
|
||||||
|
break
|
||||||
|
|
||||||
|
batch_x = batch_x.to(device)
|
||||||
|
batch_y = batch_y.to(device)
|
||||||
|
batch_x_mark = batch_x_mark.to(device)
|
||||||
|
batch_y_mark = batch_y_mark.to(device)
|
||||||
|
|
||||||
|
# Construct decoder input
|
||||||
|
dec_inp = torch.zeros_like(batch_y).to(device)
|
||||||
|
|
||||||
|
# Predict
|
||||||
|
outputs, aux = model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
|
||||||
|
|
||||||
|
# Convert to CPU
|
||||||
|
batch_x_cpu = batch_x.cpu().numpy()
|
||||||
|
batch_y_cpu = batch_y.cpu().numpy()
|
||||||
|
outputs_cpu = outputs.cpu().numpy()
|
||||||
|
|
||||||
|
# Process each sample in batch
|
||||||
|
for i in range(min(batch_x.size(0), num_samples - sample_count)):
|
||||||
|
input_seq = batch_x_cpu[i] # (seq_len, 2)
|
||||||
|
true_pred = batch_y_cpu[i] # (pred_len, 2)
|
||||||
|
pred_seq = outputs_cpu[i] # (pred_len, 2)
|
||||||
|
|
||||||
|
# Get first layer chunk information
|
||||||
|
boundary_mask_stage0 = None
|
||||||
|
if aux is not None and 'stage0' in aux:
|
||||||
|
# aux['stage0']['boundary_mask'] shape: (B*nvars, L)
|
||||||
|
# We need to reshape to (B, nvars, L) then take i-th sample
|
||||||
|
stage0_info = aux['stage0']
|
||||||
|
boundary_mask = stage0_info['boundary_mask'].cpu().numpy() # (B*nvars, L)
|
||||||
|
|
||||||
|
B = batch_x.size(0)
|
||||||
|
nvars = batch_x.size(2) # should be 2
|
||||||
|
L = boundary_mask.shape[1]
|
||||||
|
|
||||||
|
# Reshape boundary_mask: (B*nvars, L) -> (B, nvars, L)
|
||||||
|
boundary_mask = boundary_mask.reshape(B, nvars, L)
|
||||||
|
boundary_mask_stage0 = boundary_mask[i] # (nvars, L)
|
||||||
|
|
||||||
|
# Create visualization for each channel
|
||||||
|
fig, axes = plt.subplots(2, 1, figsize=(15, 10))
|
||||||
|
|
||||||
|
for ch in range(2): # dual channel
|
||||||
|
ax = axes[ch]
|
||||||
|
|
||||||
|
# Time axis
|
||||||
|
input_time = np.arange(len(input_seq))
|
||||||
|
pred_time = np.arange(len(input_seq), len(input_seq) + len(true_pred))
|
||||||
|
|
||||||
|
# Plot input sequence
|
||||||
|
ax.plot(input_time, input_seq[:, ch], 'b-', label='Input Sequence', linewidth=1.5)
|
||||||
|
|
||||||
|
# Plot ground truth prediction and model prediction
|
||||||
|
ax.plot(pred_time, true_pred[:, ch], 'g-', label='Ground Truth', linewidth=2)
|
||||||
|
ax.plot(pred_time, pred_seq[:, ch], 'r--', label='Prediction', linewidth=2)
|
||||||
|
|
||||||
|
# Mark first layer chunk points
|
||||||
|
if boundary_mask_stage0 is not None:
|
||||||
|
chunk_points = np.where(boundary_mask_stage0[ch])[0] # Get chunk points for this channel
|
||||||
|
for point in chunk_points:
|
||||||
|
if point < len(input_seq): # Only mark points within input sequence range
|
||||||
|
ax.axvline(x=point, color='orange', linestyle=':', alpha=0.7, linewidth=1)
|
||||||
|
|
||||||
|
# Add chunk points explanation in legend
|
||||||
|
if len(chunk_points) > 0:
|
||||||
|
ax.axvline(x=-1, color='orange', linestyle=':', alpha=0.7,
|
||||||
|
linewidth=1, label=f'Chunk Points (Stage 0)')
|
||||||
|
|
||||||
|
ax.set_title(f'Sample {sample_count + 1} - Channel {ch + 1}')
|
||||||
|
ax.set_xlabel('Time Steps')
|
||||||
|
ax.set_ylabel('Value')
|
||||||
|
ax.legend()
|
||||||
|
ax.grid(True, alpha=0.3)
|
||||||
|
|
||||||
|
# Add boundary line marking input and prediction boundary
|
||||||
|
ax.axvline(x=len(input_seq)-0.5, color='black', linestyle='-', alpha=0.5, linewidth=1)
|
||||||
|
ax.text(len(input_seq)-0.5, ax.get_ylim()[1]*0.9, 'Prediction Start',
|
||||||
|
rotation=90, verticalalignment='top', fontsize=8)
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
|
||||||
|
# Save figure
|
||||||
|
sample_filename = f'sample_{sample_count + 1}_with_chunks.png'
|
||||||
|
sample_path = os.path.join(vis_dir, sample_filename)
|
||||||
|
plt.savefig(sample_path, dpi=300, bbox_inches='tight')
|
||||||
|
plt.close()
|
||||||
|
|
||||||
|
print(f"Sample {sample_count + 1} visualization saved to: {sample_path}")
|
||||||
|
|
||||||
|
# Print chunk statistics
|
||||||
|
if boundary_mask_stage0 is not None:
|
||||||
|
for ch in range(2):
|
||||||
|
chunk_count = np.sum(boundary_mask_stage0[ch])
|
||||||
|
chunk_ratio = chunk_count / len(boundary_mask_stage0[ch])
|
||||||
|
print(f" Channel {ch + 1}: {chunk_count} chunk points ({chunk_ratio:.2%} of sequence)")
|
||||||
|
|
||||||
|
sample_count += 1
|
||||||
|
if sample_count >= num_samples:
|
||||||
|
break
|
||||||
|
|
||||||
|
if sample_count >= num_samples:
|
||||||
|
break
|
||||||
|
|
||||||
|
def evaluate_model(model, test_loader, device):
|
||||||
|
"""Evaluate model performance"""
|
||||||
|
model.eval()
|
||||||
|
total_mse = 0.0
|
||||||
|
total_mae = 0.0
|
||||||
|
batch_count = 0
|
||||||
|
|
||||||
|
all_predictions = []
|
||||||
|
all_ground_truths = []
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for batch_x, batch_y, batch_x_mark, batch_y_mark, _ in test_loader:
|
||||||
|
batch_x = batch_x.to(device)
|
||||||
|
batch_y = batch_y.to(device)
|
||||||
|
batch_x_mark = batch_x_mark.to(device)
|
||||||
|
batch_y_mark = batch_y_mark.to(device)
|
||||||
|
|
||||||
|
dec_inp = torch.zeros_like(batch_y).to(device)
|
||||||
|
|
||||||
|
outputs, _ = model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
|
||||||
|
|
||||||
|
# Calculate loss
|
||||||
|
mse = torch.mean((outputs - batch_y) ** 2)
|
||||||
|
mae = torch.mean(torch.abs(outputs - batch_y))
|
||||||
|
|
||||||
|
total_mse += mse.item()
|
||||||
|
total_mae += mae.item()
|
||||||
|
batch_count += 1
|
||||||
|
|
||||||
|
# Collect for overall statistics
|
||||||
|
all_predictions.append(outputs.cpu().numpy())
|
||||||
|
all_ground_truths.append(batch_y.cpu().numpy())
|
||||||
|
|
||||||
|
avg_mse = total_mse / batch_count
|
||||||
|
avg_mae = total_mae / batch_count
|
||||||
|
|
||||||
|
# Calculate overall RMSE
|
||||||
|
all_predictions = np.concatenate(all_predictions, axis=0)
|
||||||
|
all_ground_truths = np.concatenate(all_ground_truths, axis=0)
|
||||||
|
rmse = np.sqrt(np.mean((all_predictions - all_ground_truths) ** 2))
|
||||||
|
|
||||||
|
print(f"\nTest set evaluation results:")
|
||||||
|
print(f"MSE: {avg_mse:.6f}")
|
||||||
|
print(f"MAE: {avg_mae:.6f}")
|
||||||
|
print(f"RMSE: {rmse:.6f}")
|
||||||
|
|
||||||
|
return avg_mse, avg_mae, rmse
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# Configuration parameters
|
||||||
|
model_path = './results/dc_patchtst_sine_wave/best_model.pth'
|
||||||
|
data_path = './data/sine_wave/'
|
||||||
|
save_path = './results/dc_patchtst_sine_wave/'
|
||||||
|
|
||||||
|
if not os.path.exists(model_path):
|
||||||
|
print(f"Error: Model file not found {model_path}")
|
||||||
|
print("Please run the training script train_dc_patchtst.py first")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Load model
|
||||||
|
model, config, device = load_model(model_path)
|
||||||
|
|
||||||
|
# Load test data
|
||||||
|
test_dataset = SineWaveDataset(data_path, config.seq_len, config.pred_len, 'test')
|
||||||
|
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0) # batch_size=1 for easier visualization
|
||||||
|
|
||||||
|
print(f"\nConfiguration info:")
|
||||||
|
print(f"Sequence length: {config.seq_len}")
|
||||||
|
print(f"Prediction length: {config.pred_len}")
|
||||||
|
print(f"Feature dimension: {config.enc_in}")
|
||||||
|
|
||||||
|
# Evaluate model
|
||||||
|
mse, mae, rmse = evaluate_model(model, test_loader, device)
|
||||||
|
|
||||||
|
# Visualize prediction results and mark chunk points
|
||||||
|
print(f"\nGenerating visualization results...")
|
||||||
|
visualize_predictions_with_chunks(model, test_loader, device, save_path, num_samples=5)
|
||||||
|
|
||||||
|
print(f"\nAll results saved to: {save_path}")
|
||||||
|
print(f"Visualization files located at: {save_path}/visualizations/")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
335
train_dc_patchtst.py
Normal file
335
train_dc_patchtst.py
Normal file
@ -0,0 +1,335 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
from torch.utils.data import Dataset, DataLoader
|
||||||
|
import os
|
||||||
|
import argparse
|
||||||
|
from models.DC_PatchTST import Model
|
||||||
|
import time
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from sklearn.metrics import mean_squared_error, mean_absolute_error
|
||||||
|
|
||||||
|
# 检查并创建结果文件夹
|
||||||
|
def ensure_dir(path):
|
||||||
|
if not os.path.exists(path):
|
||||||
|
os.makedirs(path)
|
||||||
|
|
||||||
|
class SineWaveDataset(Dataset):
|
||||||
|
"""正弦波数据集"""
|
||||||
|
def __init__(self, data_path, seq_len=96, pred_len=24, mode='train'):
|
||||||
|
self.seq_len = seq_len
|
||||||
|
self.pred_len = pred_len
|
||||||
|
self.mode = mode
|
||||||
|
|
||||||
|
# 加载数据
|
||||||
|
if mode == 'train':
|
||||||
|
df = pd.read_csv(os.path.join(data_path, 'train.csv'))
|
||||||
|
elif mode == 'val':
|
||||||
|
df = pd.read_csv(os.path.join(data_path, 'val.csv'))
|
||||||
|
else: # test
|
||||||
|
df = pd.read_csv(os.path.join(data_path, 'test.csv'))
|
||||||
|
|
||||||
|
# 提取特征列(除timestamp外)
|
||||||
|
self.data = df[['channel1', 'channel2']].values.astype(np.float32)
|
||||||
|
|
||||||
|
# 计算可用样本数量
|
||||||
|
self.total_len = len(self.data)
|
||||||
|
self.samples_num = max(0, self.total_len - seq_len - pred_len + 1)
|
||||||
|
|
||||||
|
print(f"{mode} 数据集: {self.total_len} 条记录, {self.samples_num} 个样本")
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self.samples_num
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
# 输入序列
|
||||||
|
s_begin = idx
|
||||||
|
s_end = s_begin + self.seq_len
|
||||||
|
|
||||||
|
# 预测目标
|
||||||
|
r_begin = s_end
|
||||||
|
r_end = r_begin + self.pred_len
|
||||||
|
|
||||||
|
seq_x = self.data[s_begin:s_end] # (seq_len, n_vars)
|
||||||
|
seq_y = self.data[r_begin:r_end] # (pred_len, n_vars)
|
||||||
|
|
||||||
|
# 时间标记(简单的位置编码)
|
||||||
|
seq_x_mark = np.arange(self.seq_len).reshape(-1, 1).astype(np.float32)
|
||||||
|
seq_y_mark = np.arange(self.pred_len).reshape(-1, 1).astype(np.float32)
|
||||||
|
|
||||||
|
return seq_x, seq_y, seq_x_mark, seq_y_mark
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
"""配置类"""
|
||||||
|
def __init__(self):
|
||||||
|
# 基础配置
|
||||||
|
self.task_name = 'long_term_forecast'
|
||||||
|
self.model = 'DC_PatchTST'
|
||||||
|
|
||||||
|
# 数据配置
|
||||||
|
self.seq_len = 96 # 输入序列长度
|
||||||
|
self.pred_len = 24 # 预测序列长度
|
||||||
|
self.label_len = 48 # 标签长度
|
||||||
|
self.enc_in = 2 # 输入特征维度(双通道)
|
||||||
|
self.dec_in = 2 # 解码器输入维度
|
||||||
|
self.c_out = 2 # 输出维度
|
||||||
|
|
||||||
|
# 模型配置
|
||||||
|
self.d_model = 128 # 模型维度
|
||||||
|
self.n_heads = 8 # 注意力头数
|
||||||
|
self.e_layers = 2 # 编码器层数
|
||||||
|
self.d_layers = 1 # 解码器层数
|
||||||
|
self.d_ff = 256 # 前向网络维度
|
||||||
|
self.factor = 1 # 注意力因子
|
||||||
|
self.dropout = 0 # Dropout率
|
||||||
|
self.activation = 'gelu'
|
||||||
|
|
||||||
|
# 训练配置
|
||||||
|
self.batch_size = 1024
|
||||||
|
self.learning_rate = 0.001
|
||||||
|
self.train_epochs = 50
|
||||||
|
self.patience = 5
|
||||||
|
|
||||||
|
# 其他配置
|
||||||
|
self.use_amp = False
|
||||||
|
self.num_class = 0
|
||||||
|
|
||||||
|
# GPU配置
|
||||||
|
self.use_gpu = torch.cuda.is_available()
|
||||||
|
self.device = torch.device('cuda' if self.use_gpu else 'cpu')
|
||||||
|
|
||||||
|
def train_epoch(model, train_loader, criterion, optimizer, device, use_amp=False):
|
||||||
|
"""训练一个epoch"""
|
||||||
|
model.train()
|
||||||
|
total_loss = 0.0
|
||||||
|
batch_count = 0
|
||||||
|
|
||||||
|
for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(train_loader):
|
||||||
|
batch_x = batch_x.to(device)
|
||||||
|
batch_y = batch_y.to(device)
|
||||||
|
batch_x_mark = batch_x_mark.to(device)
|
||||||
|
batch_y_mark = batch_y_mark.to(device)
|
||||||
|
|
||||||
|
# 构造解码器输入
|
||||||
|
dec_inp = torch.zeros_like(batch_y).to(device)
|
||||||
|
|
||||||
|
optimizer.zero_grad()
|
||||||
|
|
||||||
|
if use_amp:
|
||||||
|
with torch.cuda.amp.autocast():
|
||||||
|
outputs, aux = model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
|
||||||
|
loss = criterion(outputs, batch_y)
|
||||||
|
|
||||||
|
# 添加DC的ratio loss
|
||||||
|
if aux is not None and 'ratio_loss0' in aux and 'ratio_loss1' in aux:
|
||||||
|
ratio_loss = aux['ratio_loss0'] + aux['ratio_loss1']
|
||||||
|
loss = loss + 0.0 * ratio_loss # ratio loss权重
|
||||||
|
else:
|
||||||
|
outputs, aux = model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
|
||||||
|
loss = criterion(outputs, batch_y)
|
||||||
|
|
||||||
|
# 添加DC的ratio loss
|
||||||
|
if aux is not None and 'ratio_loss0' in aux and 'ratio_loss1' in aux:
|
||||||
|
ratio_loss = aux['ratio_loss0'] + aux['ratio_loss1']
|
||||||
|
loss = loss + 0.0 * ratio_loss # ratio loss权重
|
||||||
|
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
total_loss += loss.item()
|
||||||
|
batch_count += 1
|
||||||
|
|
||||||
|
if i % 100 == 0:
|
||||||
|
print(f'Batch {i}, Loss: {loss.item():.6f}')
|
||||||
|
|
||||||
|
return total_loss / batch_count
|
||||||
|
|
||||||
|
def validate(model, val_loader, criterion, device):
|
||||||
|
"""验证模型"""
|
||||||
|
model.eval()
|
||||||
|
total_loss = 0.0
|
||||||
|
batch_count = 0
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for batch_x, batch_y, batch_x_mark, batch_y_mark in val_loader:
|
||||||
|
batch_x = batch_x.to(device)
|
||||||
|
batch_y = batch_y.to(device)
|
||||||
|
batch_x_mark = batch_x_mark.to(device)
|
||||||
|
batch_y_mark = batch_y_mark.to(device)
|
||||||
|
|
||||||
|
dec_inp = torch.zeros_like(batch_y).to(device)
|
||||||
|
|
||||||
|
outputs, _ = model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
|
||||||
|
loss = criterion(outputs, batch_y)
|
||||||
|
|
||||||
|
total_loss += loss.item()
|
||||||
|
batch_count += 1
|
||||||
|
|
||||||
|
return total_loss / batch_count
|
||||||
|
|
||||||
|
def test_model(model, test_loader, device, save_path):
|
||||||
|
"""测试模型并可视化结果"""
|
||||||
|
model.eval()
|
||||||
|
predictions = []
|
||||||
|
ground_truths = []
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for batch_x, batch_y, batch_x_mark, batch_y_mark in test_loader:
|
||||||
|
batch_x = batch_x.to(device)
|
||||||
|
batch_y = batch_y.to(device)
|
||||||
|
batch_x_mark = batch_x_mark.to(device)
|
||||||
|
batch_y_mark = batch_y_mark.to(device)
|
||||||
|
|
||||||
|
dec_inp = torch.zeros_like(batch_y).to(device)
|
||||||
|
|
||||||
|
outputs, _ = model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
|
||||||
|
|
||||||
|
predictions.append(outputs.cpu().numpy())
|
||||||
|
ground_truths.append(batch_y.cpu().numpy())
|
||||||
|
|
||||||
|
predictions = np.concatenate(predictions, axis=0)
|
||||||
|
ground_truths = np.concatenate(ground_truths, axis=0)
|
||||||
|
|
||||||
|
# 计算指标
|
||||||
|
mse = mean_squared_error(ground_truths.reshape(-1), predictions.reshape(-1))
|
||||||
|
mae = mean_absolute_error(ground_truths.reshape(-1), predictions.reshape(-1))
|
||||||
|
|
||||||
|
print(f"测试结果 - MSE: {mse:.6f}, MAE: {mae:.6f}")
|
||||||
|
|
||||||
|
# 可视化前几个样本
|
||||||
|
plt.figure(figsize=(15, 10))
|
||||||
|
for i in range(min(4, len(predictions))):
|
||||||
|
for ch in range(2): # 双通道
|
||||||
|
plt.subplot(4, 2, i*2 + ch + 1)
|
||||||
|
plt.plot(ground_truths[i, :, ch], label='Ground Truth', color='blue')
|
||||||
|
plt.plot(predictions[i, :, ch], label='Prediction', color='red')
|
||||||
|
plt.title(f'Sample {i+1}, Channel {ch+1}')
|
||||||
|
plt.legend()
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.savefig(os.path.join(save_path, 'predictions.png'))
|
||||||
|
print(f"预测结果可视化保存到: {save_path}/predictions.png")
|
||||||
|
|
||||||
|
return mse, mae
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# 配置参数
|
||||||
|
config = Config()
|
||||||
|
|
||||||
|
# 创建结果目录
|
||||||
|
results_dir = './results/dc_patchtst_sine_wave'
|
||||||
|
ensure_dir(results_dir)
|
||||||
|
|
||||||
|
print(f"使用设备: {config.device}")
|
||||||
|
print(f"模型配置: seq_len={config.seq_len}, pred_len={config.pred_len}")
|
||||||
|
|
||||||
|
# 加载数据
|
||||||
|
data_path = './data/sine_wave/'
|
||||||
|
train_dataset = SineWaveDataset(data_path, config.seq_len, config.pred_len, 'train')
|
||||||
|
val_dataset = SineWaveDataset(data_path, config.seq_len, config.pred_len, 'val')
|
||||||
|
test_dataset = SineWaveDataset(data_path, config.seq_len, config.pred_len, 'test')
|
||||||
|
|
||||||
|
train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=4)
|
||||||
|
val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False, num_workers=4)
|
||||||
|
test_loader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False, num_workers=4)
|
||||||
|
|
||||||
|
# 创建模型
|
||||||
|
model = Model(config).to(config.device)
|
||||||
|
print(f"模型参数数量: {sum(p.numel() for p in model.parameters())}")
|
||||||
|
|
||||||
|
# 优化器和损失函数
|
||||||
|
optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate, weight_decay=1e-4)
|
||||||
|
criterion = nn.MSELoss()
|
||||||
|
|
||||||
|
# 训练循环
|
||||||
|
best_val_loss = float('inf')
|
||||||
|
patience_counter = 0
|
||||||
|
train_losses = []
|
||||||
|
val_losses = []
|
||||||
|
|
||||||
|
print("开始训练...")
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
for epoch in range(config.train_epochs):
|
||||||
|
epoch_start = time.time()
|
||||||
|
|
||||||
|
# 训练
|
||||||
|
train_loss = train_epoch(model, train_loader, criterion, optimizer, config.device, config.use_amp)
|
||||||
|
|
||||||
|
# 验证
|
||||||
|
val_loss = validate(model, val_loader, criterion, config.device)
|
||||||
|
|
||||||
|
train_losses.append(train_loss)
|
||||||
|
val_losses.append(val_loss)
|
||||||
|
|
||||||
|
epoch_time = time.time() - epoch_start
|
||||||
|
|
||||||
|
print(f'Epoch {epoch+1:2d}/{config.train_epochs} | '
|
||||||
|
f'Train Loss: {train_loss:.6f} | '
|
||||||
|
f'Val Loss: {val_loss:.6f} | '
|
||||||
|
f'Time: {epoch_time:.2f}s')
|
||||||
|
|
||||||
|
# 早停检查
|
||||||
|
if val_loss < best_val_loss:
|
||||||
|
best_val_loss = val_loss
|
||||||
|
patience_counter = 0
|
||||||
|
# 保存最佳模型
|
||||||
|
torch.save({
|
||||||
|
'model_state_dict': model.state_dict(),
|
||||||
|
'optimizer_state_dict': optimizer.state_dict(),
|
||||||
|
'config': config,
|
||||||
|
'epoch': epoch,
|
||||||
|
'val_loss': val_loss
|
||||||
|
}, os.path.join(results_dir, 'best_model.pth'))
|
||||||
|
print(f' -> 保存最佳模型 (val_loss: {val_loss:.6f})')
|
||||||
|
else:
|
||||||
|
patience_counter += 1
|
||||||
|
if patience_counter >= config.patience:
|
||||||
|
print(f'早停触发! 最佳验证损失: {best_val_loss:.6f}')
|
||||||
|
break
|
||||||
|
|
||||||
|
total_time = time.time() - start_time
|
||||||
|
print(f'\n训练完成! 总时间: {total_time/60:.2f} 分钟')
|
||||||
|
|
||||||
|
# 加载最佳模型进行测试
|
||||||
|
checkpoint = torch.load(os.path.join(results_dir, 'best_model.pth'))
|
||||||
|
model.load_state_dict(checkpoint['model_state_dict'])
|
||||||
|
|
||||||
|
print("\n测试最佳模型...")
|
||||||
|
test_mse, test_mae = test_model(model, test_loader, config.device, results_dir)
|
||||||
|
|
||||||
|
# 保存训练历史
|
||||||
|
history = {
|
||||||
|
'train_losses': train_losses,
|
||||||
|
'val_losses': val_losses,
|
||||||
|
'test_mse': test_mse,
|
||||||
|
'test_mae': test_mae
|
||||||
|
}
|
||||||
|
|
||||||
|
# 绘制训练曲线
|
||||||
|
plt.figure(figsize=(12, 4))
|
||||||
|
plt.subplot(1, 2, 1)
|
||||||
|
plt.plot(train_losses, label='Train Loss')
|
||||||
|
plt.plot(val_losses, label='Validation Loss')
|
||||||
|
plt.xlabel('Epoch')
|
||||||
|
plt.ylabel('Loss')
|
||||||
|
plt.legend()
|
||||||
|
plt.title('Training History')
|
||||||
|
|
||||||
|
plt.subplot(1, 2, 2)
|
||||||
|
plt.bar(['MSE', 'MAE'], [test_mse, test_mae])
|
||||||
|
plt.title('Test Metrics')
|
||||||
|
plt.ylabel('Error')
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.savefig(os.path.join(results_dir, 'training_history.png'))
|
||||||
|
|
||||||
|
print(f"\n结果保存在: {results_dir}")
|
||||||
|
print(f"最佳模型: {results_dir}/best_model.pth")
|
||||||
|
print(f"训练历史: {results_dir}/training_history.png")
|
||||||
|
print(f"预测可视化: {results_dir}/predictions.png")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
Reference in New Issue
Block a user