feat: add mamba and dynamic chunking related code and test code

This commit is contained in:
gameloader
2025-09-04 01:32:13 +00:00
parent 12cb7652cf
commit ef307a57e9
21 changed files with 4550 additions and 86 deletions

203
models/vanillaMamba.py Normal file
View File

@ -0,0 +1,203 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from mamba_ssm import Mamba2
class ValueEmbedding(nn.Module):
"""
对每个时间步的单通道标量做线性投影到 d_model并可选 Dropout。
不包含 temporal embedding 和 positional embedding。
"""
def __init__(self, in_dim: int, d_model: int, dropout: float = 0.0, bias: bool = True):
super().__init__()
self.proj = nn.Linear(in_dim, d_model, bias=bias)
self.dropout = nn.Dropout(dropout) if dropout and dropout > 0.0 else nn.Identity()
def forward(self, x: torch.Tensor) -> torch.Tensor:
# x: [B, L, 1] -> [B, L, d_model]
return self.dropout(self.proj(x))
class ChannelMambaBlock(nn.Module):
"""
针对单个通道的两层 Mamba-2 处理块:
- 输入: [B, L, 1],先做投影到 d_model
- 两层 Mamba2且在第一层输出和第二层输出均添加残差连接
- 每层后接 LayerNorm
- 输出: [B, L, d_model]
"""
def __init__(self, d_model: int, dropout: float, m2_kwargs: dict):
super().__init__()
self.embed = ValueEmbedding(in_dim=1, d_model=d_model, dropout=dropout, bias=True)
# 两层 Mamba-2
self.mamba1 = Mamba2(d_model=d_model, **m2_kwargs)
self.mamba2 = Mamba2(d_model=d_model, **m2_kwargs)
# 每层后接的归一化
self.ln1 = nn.LayerNorm(d_model)
self.ln2 = nn.LayerNorm(d_model)
def forward(self, x_ch: torch.Tensor) -> torch.Tensor:
# x_ch: [B, L, 1]
x = self.embed(x_ch) # [B, L, d_model]
# 第一层 + 残差
y1 = self.mamba1(x) # [B, L, d_model]
y1 = self.ln1(x + y1) # 残差1 + LN
# 第二层 + 残差
y2 = self.mamba2(y1) # [B, L, d_model]
y2 = self.ln2(y1 + y2) # 残差2 + LN
return y2 # [B, L, d_model]
class Model(nn.Module):
"""
按通道独立处理的 Mamba-2 模型,支持:
- 分类:各通道独立提取,取最后时刻拼接 -> 分类头
- 长/短期预测:各通道独立提取,保留整段序列,经时间维线性映射到目标长度,再投影回标量并拼接
注意:预测输出通道数与输入通道数严格相同(逐通道预测)。
输入:
- x_enc: [B, L, D] 多变量时间序列
- x_mark_enc, x_dec, x_mark_dec, mask: 兼容接口参数(本模型在分类/预测中未使用这些标注)
输出:
- classification: logits [B, num_class]
- forecast: [B, pred_len, D]
"""
def __init__(self, configs):
super().__init__()
# 任务类型
self.task_name = getattr(configs, 'task_name', 'classification')
assert self.task_name in ['classification', 'long_term_forecast', 'short_term_forecast'], \
"只支持 classification / long_term_forecast / short_term_forecast"
# 基本配置
self.enc_in = configs.enc_in # 通道数 D
self.d_model = configs.d_model # 每通道的模型维度
self.num_class = getattr(configs, 'num_class', None)
self.dropout = getattr(configs, 'dropout', 0.1)
# 预测相关
self.seq_len = getattr(configs, 'seq_len', None)
self.pred_len = getattr(configs, 'pred_len', None)
if self.task_name in ['long_term_forecast', 'short_term_forecast']:
assert self.seq_len is not None and self.pred_len is not None, "预测任务需要 seq_len 与 pred_len"
# 输出通道必须与输入通道一致
self.c_out = getattr(configs, 'c_out', self.enc_in)
assert self.c_out == self.enc_in, "预测任务要求输出通道 c_out 与输入通道 enc_in 一致"
# Mamba-2 超参数
m2_kwargs = dict(
d_state=getattr(configs, 'd_state', 64),
d_conv=getattr(configs, 'd_conv', 4),
expand=getattr(configs, 'expand', 2),
headdim=getattr(configs, 'headdim', 64),
d_ssm=getattr(configs, 'd_ssm', None),
ngroups=getattr(configs, 'ngroups', 1),
A_init_range=getattr(configs, 'A_init_range', (1, 16)),
D_has_hdim=getattr(configs, 'D_has_hdim', False),
rmsnorm=getattr(configs, 'rmsnorm', True),
norm_before_gate=getattr(configs, 'norm_before_gate', False),
dt_min=getattr(configs, 'dt_min', 0.001),
dt_max=getattr(configs, 'dt_max', 0.1),
dt_init_floor=getattr(configs, 'dt_init_floor', 1e-4),
dt_limit=getattr(configs, 'dt_limit', (0.0, float("inf"))),
bias=getattr(configs, 'bias', False),
conv_bias=getattr(configs, 'conv_bias', True),
chunk_size=getattr(configs, 'chunk_size', 256),
use_mem_eff_path=getattr(configs, 'use_mem_eff_path', True),
)
# 为每个通道构建独立的两层 Mamba2 处理块
self.channel_blocks = nn.ModuleList([
ChannelMambaBlock(d_model=self.d_model, dropout=self.dropout, m2_kwargs=m2_kwargs)
for _ in range(self.enc_in)
])
# 分类头:将各通道最后时间步的表示拼接后 -> GELU -> Dropout -> Linear
if self.task_name == 'classification':
assert self.num_class is not None, "classification 需要提供 num_class"
self.act = nn.GELU()
self.head = nn.Sequential(
nn.Dropout(self.dropout),
nn.Linear(self.d_model * self.enc_in, self.num_class)
)
# 预测头:
# - 先对时间维做线性映射: [B, L, d_model] -> [B, pred_len, d_model]
# - 再将 d_model 投影为单通道标量: [B, pred_len, d_model] -> [B, pred_len, 1]
if self.task_name in ['long_term_forecast', 'short_term_forecast']:
self.predict_linear = nn.Linear(self.seq_len, self.pred_len)
self.projection = nn.Linear(self.d_model, 1, bias=True)
def classification(self, x_enc: torch.Tensor) -> torch.Tensor:
# x_enc: [B, L, D]
B, L, D = x_enc.shape
assert D == self.enc_in, f"输入通道数 {D} 与 enc_in {self.enc_in} 不一致"
per_channel_last = []
for c in range(D):
# 取出单通道序列 [B, L] -> [B, L, 1]
x_ch = x_enc[:, :, c].unsqueeze(-1)
y_ch = self.channel_blocks[c](x_ch) # [B, L, d_model]
per_channel_last.append(y_ch[:, -1, :]) # [B, d_model]
# 拼接各通道最后时刻的表示 -> [B, D * d_model]
h_last = torch.cat(per_channel_last, dim=-1)
# 分类头
logits = self.head(self.act(h_last)) # [B, num_class]
return logits
def forecast(self, x_enc: torch.Tensor) -> torch.Tensor:
"""
逐通道预测:
- 归一化(时间维),按通道独立提取
- 使用整段 Mamba 输出序列,经时间维线性映射到目标长度,再投影为标量
- 反归一化
返回:
dec_out: [B, L+pred_len, D],在 forward 中会取最后 pred_len 段
"""
B, L, D = x_enc.shape
assert L == self.seq_len, f"输入长度 {L} 与配置 seq_len {self.seq_len} 不一致"
assert D == self.enc_in, f"输入通道数 {D} 与 enc_in {self.enc_in} 不一致"
# Normalization (per Non-stationary Transformer)
means = x_enc.mean(1, keepdim=True).detach() # [B, 1, D]
x = x_enc - means
stdev = torch.sqrt(x.var(dim=1, keepdim=True, unbiased=False) + 1e-5) # [B, 1, D]
x = x / stdev
per_channel_seq = []
for c in range(D):
x_ch = x[:, :, c].unsqueeze(-1) # [B, L, 1]
h_ch = self.channel_blocks[c](x_ch) # [B, L, d_model]
# 时间维映射到 L + pred_len
h_ch = self.predict_linear(h_ch.permute(0, 2, 1)).permute(0, 2, 1) # [B, L+pred_len, d_model]
# 投影回单通道
y_ch = self.projection(h_ch) # [B, L+pred_len, 1]
per_channel_seq.append(y_ch)
# 拼接通道
dec_out = torch.cat(per_channel_seq, dim=-1) # [B, pred_len, D]
# De-normalization
dec_out = dec_out * stdev[:, 0, :].unsqueeze(1) + means[:, 0, :].unsqueeze(1)
return dec_out
# 与 TimesNet 的 forward 签名保持一致;忽略 x_mark_enc / x_dec / x_mark_dec / mask
def forward(self, x_enc, x_mark_enc=None, x_dec=None, x_mark_dec=None, mask=None):
if self.task_name in ['long_term_forecast', 'short_term_forecast']:
dec_out = self.forecast(x_enc) # [B, L+pred_len, D]
return dec_out[:, -self.pred_len:, :] # 仅返回预测部分 [B, pred_len, D]
elif self.task_name == 'classification':
return self.classification(x_enc)
else:
raise NotImplementedError(f"Unsupported task: {self.task_name}")