Files
TSlib/models/vanillaMamba.py

204 lines
8.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
import torch.nn as nn
import torch.nn.functional as F
from mamba_ssm import Mamba2
class ValueEmbedding(nn.Module):
    """Project the value(s) of every time step to ``d_model`` features.

    The embedding is a single linear layer followed by optional dropout.
    No temporal embedding and no positional embedding are added.
    """

    def __init__(self, in_dim: int, d_model: int, dropout: float = 0.0, bias: bool = True):
        """Build the projection.

        Args:
            in_dim: Size of the last input dimension.
            d_model: Size of the last output dimension.
            dropout: Dropout probability; a falsy or non-positive value
                disables dropout entirely.
            bias: Whether the linear projection has a bias term.
        """
        super().__init__()
        self.proj = nn.Linear(in_dim, d_model, bias=bias)
        # nn.Identity stands in for "no dropout" so forward() needs no branch.
        self.dropout = nn.Dropout(dropout) if dropout and dropout > 0.0 else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed ``x``: [B, L, in_dim] -> [B, L, d_model]."""
        return self.dropout(self.proj(x))
class ChannelMambaBlock(nn.Module):
    """Two-layer Mamba-2 feature extractor for a single channel.

    Pipeline:
      1. Embed the scalar series [B, L, 1] into [B, L, d_model].
      2. Apply two Mamba2 layers; each layer's output is added to its own
         input (residual connection) and then passed through LayerNorm.

    The result has shape [B, L, d_model].
    """

    def __init__(self, d_model: int, dropout: float, m2_kwargs: dict):
        """Build the embedding, the two Mamba2 layers and their norms.

        Args:
            d_model: Feature width used by the embedding, both Mamba2
                layers and both LayerNorms.
            dropout: Dropout probability forwarded to the value embedding.
            m2_kwargs: Extra keyword arguments passed unchanged to both
                ``Mamba2`` constructors.
        """
        super().__init__()
        self.embed = ValueEmbedding(in_dim=1, d_model=d_model, dropout=dropout, bias=True)
        # Two stacked Mamba-2 layers with identical hyper-parameters.
        self.mamba1 = Mamba2(d_model=d_model, **m2_kwargs)
        self.mamba2 = Mamba2(d_model=d_model, **m2_kwargs)
        # One post-residual normalization per layer.
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)

    def forward(self, x_ch: torch.Tensor) -> torch.Tensor:
        """Extract features for one channel: [B, L, 1] -> [B, L, d_model]."""
        x = self.embed(x_ch)  # [B, L, d_model]
        # Layer 1: residual around the Mamba2 layer, then LayerNorm.
        y1 = self.ln1(x + self.mamba1(x))
        # Layer 2: same pattern, with the normalized layer-1 output as input.
        y2 = self.ln2(y1 + self.mamba2(y1))
        return y2
class Model(nn.Module):
    """Channel-independent Mamba-2 model for classification and forecasting.

    Every input channel is processed by its own ``ChannelMambaBlock``.

    Supported tasks:
      - classification: the last-time-step feature of every channel is
        concatenated and fed to the classification head.
      - long/short-term forecast: the full feature sequence of every channel
        is mapped along the time axis from ``seq_len`` to ``pred_len`` by a
        linear layer, projected back to a scalar, and the channels are
        concatenated. The number of output channels always equals the
        number of input channels.

    Inputs:
      - x_enc: [B, L, D] multivariate time series.
      - x_mark_enc, x_dec, x_mark_dec, mask: accepted for interface
        compatibility only; this model ignores them.

    Outputs:
      - classification: logits of shape [B, num_class].
      - forecast: tensor of shape [B, pred_len, D].
    """

    # Task names that are routed through forecast().
    _FORECAST_TASKS = ('long_term_forecast', 'short_term_forecast')

    def __init__(self, configs):
        """Build the per-channel blocks and the task-specific head.

        Attributes read from ``configs``: ``enc_in`` and ``d_model`` are
        required; ``task_name``, ``num_class``, ``dropout``, ``seq_len``,
        ``pred_len``, ``c_out`` and every Mamba-2 hyper-parameter listed in
        the body below are optional and fall back to the defaults shown
        there.
        """
        super().__init__()
        # Task type
        self.task_name = getattr(configs, 'task_name', 'classification')
        assert self.task_name in ['classification', 'long_term_forecast', 'short_term_forecast'], \
            "只支持 classification / long_term_forecast / short_term_forecast"
        is_forecast = self.task_name in self._FORECAST_TASKS

        # Basic configuration
        self.enc_in = configs.enc_in    # number of channels D
        self.d_model = configs.d_model  # feature width per channel
        self.num_class = getattr(configs, 'num_class', None)
        self.dropout = getattr(configs, 'dropout', 0.1)

        # Forecast-related configuration
        self.seq_len = getattr(configs, 'seq_len', None)
        self.pred_len = getattr(configs, 'pred_len', None)
        self.c_out = getattr(configs, 'c_out', self.enc_in)
        if is_forecast:
            assert self.seq_len is not None and self.pred_len is not None, "预测任务需要 seq_len 与 pred_len"
            # Forecasting is per channel, so output and input channel
            # counts must match.
            assert self.c_out == self.enc_in, "预测任务要求输出通道 c_out 与输入通道 enc_in 一致"

        # Mamba-2 hyper-parameters, shared by every layer of every channel.
        m2_kwargs = dict(
            d_state=getattr(configs, 'd_state', 64),
            d_conv=getattr(configs, 'd_conv', 4),
            expand=getattr(configs, 'expand', 2),
            headdim=getattr(configs, 'headdim', 64),
            d_ssm=getattr(configs, 'd_ssm', None),
            ngroups=getattr(configs, 'ngroups', 1),
            A_init_range=getattr(configs, 'A_init_range', (1, 16)),
            D_has_hdim=getattr(configs, 'D_has_hdim', False),
            rmsnorm=getattr(configs, 'rmsnorm', True),
            norm_before_gate=getattr(configs, 'norm_before_gate', False),
            dt_min=getattr(configs, 'dt_min', 0.001),
            dt_max=getattr(configs, 'dt_max', 0.1),
            dt_init_floor=getattr(configs, 'dt_init_floor', 1e-4),
            dt_limit=getattr(configs, 'dt_limit', (0.0, float("inf"))),
            bias=getattr(configs, 'bias', False),
            conv_bias=getattr(configs, 'conv_bias', True),
            chunk_size=getattr(configs, 'chunk_size', 256),
            use_mem_eff_path=getattr(configs, 'use_mem_eff_path', True),
        )

        # One independent two-layer Mamba-2 block per channel.
        self.channel_blocks = nn.ModuleList([
            ChannelMambaBlock(d_model=self.d_model, dropout=self.dropout, m2_kwargs=m2_kwargs)
            for _ in range(self.enc_in)
        ])

        # Classification head: concatenated last-step features
        # -> GELU -> Dropout -> Linear.
        if self.task_name == 'classification':
            assert self.num_class is not None, "classification 需要提供 num_class"
            self.act = nn.GELU()
            self.head = nn.Sequential(
                nn.Dropout(self.dropout),
                nn.Linear(self.d_model * self.enc_in, self.num_class)
            )

        # Forecast head (both layers are shared by all channels):
        #   - linear map along time: [B, L, d_model] -> [B, pred_len, d_model]
        #   - projection to a scalar: [B, pred_len, d_model] -> [B, pred_len, 1]
        if is_forecast:
            self.predict_linear = nn.Linear(self.seq_len, self.pred_len)
            self.projection = nn.Linear(self.d_model, 1, bias=True)

    def classification(self, x_enc: torch.Tensor) -> torch.Tensor:
        """Classify ``x_enc`` of shape [B, L, D]; returns logits [B, num_class]."""
        _, _, D = x_enc.shape
        assert D == self.enc_in, f"输入通道数 {D} 与 enc_in {self.enc_in} 不一致"

        # For each channel: [B, L, 1] -> block -> [B, L, d_model], keeping
        # only the final time step [B, d_model].
        per_channel_last = [
            block(x_enc[:, :, c:c + 1])[:, -1, :]
            for c, block in enumerate(self.channel_blocks)
        ]
        # Concatenate across channels -> [B, D * d_model]
        h_last = torch.cat(per_channel_last, dim=-1)
        return self.head(self.act(h_last))  # [B, num_class]

    def forecast(self, x_enc: torch.Tensor) -> torch.Tensor:
        """Forecast every channel independently.

        Steps:
          - normalize each series along the time axis;
          - run each channel through its own block, map the whole feature
            sequence from ``seq_len`` to ``pred_len`` steps along time, and
            project the features to a scalar;
          - undo the normalization.

        Returns:
            Tensor of shape [B, pred_len, D].
        """
        _, L, D = x_enc.shape
        assert L == self.seq_len, f"输入长度 {L} 与配置 seq_len {self.seq_len} 不一致"
        assert D == self.enc_in, f"输入通道数 {D} 与 enc_in {self.enc_in} 不一致"

        # Per-series normalization (as in Non-stationary Transformer).
        means = x_enc.mean(1, keepdim=True).detach()  # [B, 1, D]
        x = x_enc - means
        stdev = torch.sqrt(x.var(dim=1, keepdim=True, unbiased=False) + 1e-5)  # [B, 1, D]
        x = x / stdev

        per_channel_seq = []
        for c, block in enumerate(self.channel_blocks):
            h_ch = block(x[:, :, c:c + 1])  # [B, L, d_model]
            # nn.Linear acts on the last axis, so move time there and back:
            # [B, d_model, L] -> [B, d_model, pred_len] -> [B, pred_len, d_model]
            h_ch = self.predict_linear(h_ch.permute(0, 2, 1)).permute(0, 2, 1)
            per_channel_seq.append(self.projection(h_ch))  # [B, pred_len, 1]

        dec_out = torch.cat(per_channel_seq, dim=-1)  # [B, pred_len, D]
        # De-normalization; the [B, 1, D] statistics broadcast over time.
        return dec_out * stdev + means

    # Signature kept compatible with TimesNet's forward;
    # x_mark_enc / x_dec / x_mark_dec / mask are ignored.
    def forward(self, x_enc, x_mark_enc=None, x_dec=None, x_mark_dec=None, mask=None):
        """Dispatch to the configured task.

        Returns:
            [B, pred_len, D] for forecasting tasks, or logits
            [B, num_class] for classification.
        """
        if self.task_name in self._FORECAST_TASKS:
            # forecast() already yields exactly pred_len steps.
            return self.forecast(x_enc)
        if self.task_name == 'classification':
            return self.classification(x_enc)
        raise NotImplementedError(f"Unsupported task: {self.task_name}")