feat: add mamba and dynamic chunking related code and test code
models/vanillaMamba-Copy1.py (new file, 138 lines)
@@ -0,0 +1,138 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from mamba_ssm import Mamba2


class ValueEmbedding(nn.Module):
    """
    Linearly projects the single-channel scalar at each time step to d_model, with optional dropout.
    Does not include temporal embedding or positional embedding.
    """
    def __init__(self, in_dim: int, d_model: int, dropout: float = 0.0, bias: bool = True):
        super().__init__()
        self.proj = nn.Linear(in_dim, d_model, bias=bias)
        self.dropout = nn.Dropout(dropout) if dropout and dropout > 0.0 else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [B, L, 1] -> [B, L, d_model]
        return self.dropout(self.proj(x))


class ChannelMambaBlock(nn.Module):
    """
    Two-layer Mamba-2 processing block for a single channel:
    - Input: [B, L, 1], first projected to d_model
    - Two Mamba2 layers, with residual connections added at the outputs of both the first and second layers
    - LayerNorm after each layer
    - Output: [B, L, d_model]
    """
    def __init__(self, d_model: int, dropout: float, m2_kwargs: dict):
        super().__init__()
        self.embed = ValueEmbedding(in_dim=1, d_model=d_model, dropout=dropout, bias=True)

        # Two Mamba-2 layers
        self.mamba1 = Mamba2(d_model=d_model, **m2_kwargs)
        self.mamba2 = Mamba2(d_model=d_model, **m2_kwargs)

        # Normalization after each layer
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)

    def forward(self, x_ch: torch.Tensor) -> torch.Tensor:
        # x_ch: [B, L, 1]
        x = self.embed(x_ch)  # [B, L, d_model]

        # First layer + residual
        y1 = self.mamba1(x)    # [B, L, d_model]
        y1 = self.ln1(x + y1)  # residual 1 + LN

        # Second layer + residual
        y2 = self.mamba2(y1)    # [B, L, d_model]
        y2 = self.ln2(y1 + y2)  # residual 2 + LN

        return y2  # [B, L, d_model]


class Model(nn.Module):
    """
    Channel-independent Mamba-2 classification model:
    - Splits the input by channel and processes each channel with its own two-layer Mamba2 (with the two residual connections)
    - Each channel produces a [B, L, d_model] output
    - The last-time-step representations of all channels are concatenated and fed to the classification head
    Input:
    - x_enc: [B, L, D] multivariate time series
    Output:
    - logits: [B, num_class]
    """
    def __init__(self, configs):
        super().__init__()
        self.task_name = getattr(configs, 'task_name', 'classification')
        assert self.task_name == 'classification', "This model only implements the classification task"

        # Basic configuration
        self.enc_in = configs.enc_in      # number of channels D
        self.d_model = configs.d_model    # per-channel model dimension
        self.num_class = configs.num_class
        self.dropout = getattr(configs, 'dropout', 0.1)

        # Mamba-2 hyperparameters (read from configs as needed)
        # Note: e_layers stacking is not used here; each channel is fixed at two layers to satisfy
        # the requirement of adding residual connections at the outputs of the first and second layers
        m2_kwargs = dict(
            d_state=getattr(configs, 'd_state', 64),
            d_conv=getattr(configs, 'd_conv', 4),
            expand=getattr(configs, 'expand', 2),
            headdim=getattr(configs, 'headdim', 64),
            d_ssm=getattr(configs, 'd_ssm', None),
            ngroups=getattr(configs, 'ngroups', 1),
            A_init_range=getattr(configs, 'A_init_range', (1, 16)),
            D_has_hdim=getattr(configs, 'D_has_hdim', False),
            rmsnorm=getattr(configs, 'rmsnorm', True),
            norm_before_gate=getattr(configs, 'norm_before_gate', False),
            dt_min=getattr(configs, 'dt_min', 0.001),
            dt_max=getattr(configs, 'dt_max', 0.1),
            dt_init_floor=getattr(configs, 'dt_init_floor', 1e-4),
            dt_limit=getattr(configs, 'dt_limit', (0.0, float("inf"))),
            bias=getattr(configs, 'bias', False),
            conv_bias=getattr(configs, 'conv_bias', True),
            chunk_size=getattr(configs, 'chunk_size', 256),
            use_mem_eff_path=getattr(configs, 'use_mem_eff_path', True),
        )

        # Build an independent two-layer Mamba2 block for each channel
        self.channel_blocks = nn.ModuleList([
            ChannelMambaBlock(d_model=self.d_model, dropout=self.dropout, m2_kwargs=m2_kwargs)
            for _ in range(self.enc_in)
        ])

        # Classification head: concatenated last-time-step representations of all channels -> GELU -> Dropout -> Linear
        self.act = nn.GELU()
        self.head = nn.Sequential(
            nn.Dropout(self.dropout),
            nn.Linear(self.d_model * self.enc_in, self.num_class)
        )

    def classification(self, x_enc: torch.Tensor) -> torch.Tensor:
        # x_enc: [B, L, D]
        B, L, D = x_enc.shape
        assert D == self.enc_in, f"Input channel count {D} does not match enc_in {self.enc_in}"

        per_channel_last = []
        for c in range(D):
            # Extract the single-channel sequence [B, L] -> [B, L, 1]
            x_ch = x_enc[:, :, c].unsqueeze(-1)
            y_ch = self.channel_blocks[c](x_ch)      # [B, L, d_model]
            per_channel_last.append(y_ch[:, -1, :])  # [B, d_model]

        # Concatenate the last-time-step representations of all channels -> [B, D * d_model]
        h_last = torch.cat(per_channel_last, dim=-1)

        # Classification head
        h_last = self.act(h_last)
        logits = self.head(h_last)  # [B, num_class]
        return logits

    # Keep the forward signature consistent with TimesNet; x_mark_enc / x_dec / x_mark_dec / mask are ignored
    def forward(self, x_enc, x_mark_enc=None, x_dec=None, x_mark_dec=None, mask=None):
        return self.classification(x_enc)
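
A minimal usage sketch (not part of the committed file): it builds a hypothetical config via types.SimpleNamespace containing only the fields the model reads directly (enc_in, d_model, num_class), relies on the getattr defaults for everything else, and runs one forward pass. It assumes mamba_ssm is installed and a CUDA device is available.

# Hypothetical usage sketch; field names mirror the getattr() defaults above.
from types import SimpleNamespace
import torch

configs = SimpleNamespace(
    task_name='classification',
    enc_in=3,       # number of channels D
    d_model=64,     # per-channel model dimension
    num_class=5,
    dropout=0.1,
)

model = Model(configs).cuda()
x_enc = torch.randn(8, 256, configs.enc_in, device='cuda')  # [B, L, D]
logits = model(x_enc)   # [B, num_class]
print(logits.shape)     # torch.Size([8, 5])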