# TSlib/models/vanillaMamba-Copy1.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from mamba_ssm import Mamba2


class ValueEmbedding(nn.Module):
    """
    Linearly projects the per-time-step single-channel scalar to d_model,
    with optional dropout. No temporal or positional embeddings are applied.
    """
    def __init__(self, in_dim: int, d_model: int, dropout: float = 0.0, bias: bool = True):
        super().__init__()
        self.proj = nn.Linear(in_dim, d_model, bias=bias)
        self.dropout = nn.Dropout(dropout) if dropout and dropout > 0.0 else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [B, L, 1] -> [B, L, d_model]
        return self.dropout(self.proj(x))
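

# Usage sketch for ValueEmbedding (shapes only; the sizes below are illustrative
# assumptions, not values used elsewhere in this file):
#
#   emb = ValueEmbedding(in_dim=1, d_model=64, dropout=0.1)
#   y = emb(torch.randn(8, 96, 1))  # -> torch.Size([8, 96, 64])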


class ChannelMambaBlock(nn.Module):
    """
    Two-layer Mamba-2 block for a single channel:
    - Input: [B, L, 1], first projected to d_model
    - Two Mamba2 layers, with a residual connection around each layer's output
    - Each layer is followed by LayerNorm
    - Output: [B, L, d_model]
    """
    def __init__(self, d_model: int, dropout: float, m2_kwargs: dict):
        super().__init__()
        self.embed = ValueEmbedding(in_dim=1, d_model=d_model, dropout=dropout, bias=True)
        # Two Mamba-2 layers
        self.mamba1 = Mamba2(d_model=d_model, **m2_kwargs)
        self.mamba2 = Mamba2(d_model=d_model, **m2_kwargs)
        # Post-layer normalization
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)

    def forward(self, x_ch: torch.Tensor) -> torch.Tensor:
        # x_ch: [B, L, 1]
        x = self.embed(x_ch)    # [B, L, d_model]
        # First layer + residual
        y1 = self.mamba1(x)     # [B, L, d_model]
        y1 = self.ln1(x + y1)   # residual 1 + LN
        # Second layer + residual
        y2 = self.mamba2(y1)    # [B, L, d_model]
        y2 = self.ln2(y1 + y2)  # residual 2 + LN
        return y2               # [B, L, d_model]
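

# Usage sketch for ChannelMambaBlock (hedged: mamba_ssm's Mamba2 kernels generally
# expect a CUDA device, and the hyperparameter values here are assumptions):
#
#   m2_kwargs = dict(d_state=64, d_conv=4, expand=2, headdim=64)
#   block = ChannelMambaBlock(d_model=128, dropout=0.1, m2_kwargs=m2_kwargs).cuda()
#   y = block(torch.randn(8, 96, 1, device="cuda"))  # -> [8, 96, 128]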


class Model(nn.Module):
    """
    Channel-independent Mamba-2 classification model:
    - Splits the input into channels, each processed by its own two-layer
      Mamba2 stack (with two residual connections)
    - Each channel yields a [B, L, d_model] output
    - The last-time-step representations of all channels are concatenated
      and fed to a classification head
    Input:
    - x_enc: [B, L, D] multivariate time series
    Output:
    - logits: [B, num_class]
    """
    def __init__(self, configs):
        super().__init__()
        self.task_name = getattr(configs, 'task_name', 'classification')
        assert self.task_name == 'classification', "This model only implements the classification task"
        # Basic configuration
        self.enc_in = configs.enc_in      # number of channels D
        self.d_model = configs.d_model    # per-channel model dimension
        self.num_class = configs.num_class
        self.dropout = getattr(configs, 'dropout', 0.1)
        # Mamba-2 hyperparameters (read from configs as needed)
        # Note: e_layers stacking is not used here; each channel has a fixed two layers
        # to satisfy the requirement of residuals after both the first and second layer outputs
        m2_kwargs = dict(
            d_state=getattr(configs, 'd_state', 64),
            d_conv=getattr(configs, 'd_conv', 4),
            expand=getattr(configs, 'expand', 2),
            headdim=getattr(configs, 'headdim', 64),
            d_ssm=getattr(configs, 'd_ssm', None),
            ngroups=getattr(configs, 'ngroups', 1),
            A_init_range=getattr(configs, 'A_init_range', (1, 16)),
            D_has_hdim=getattr(configs, 'D_has_hdim', False),
            rmsnorm=getattr(configs, 'rmsnorm', True),
            norm_before_gate=getattr(configs, 'norm_before_gate', False),
            dt_min=getattr(configs, 'dt_min', 0.001),
            dt_max=getattr(configs, 'dt_max', 0.1),
            dt_init_floor=getattr(configs, 'dt_init_floor', 1e-4),
            dt_limit=getattr(configs, 'dt_limit', (0.0, float("inf"))),
            bias=getattr(configs, 'bias', False),
            conv_bias=getattr(configs, 'conv_bias', True),
            chunk_size=getattr(configs, 'chunk_size', 256),
            use_mem_eff_path=getattr(configs, 'use_mem_eff_path', True),
        )
        # Build an independent two-layer Mamba2 block per channel
        self.channel_blocks = nn.ModuleList([
            ChannelMambaBlock(d_model=self.d_model, dropout=self.dropout, m2_kwargs=m2_kwargs)
            for _ in range(self.enc_in)
        ])
        # Classification head: concatenated last-step channel representations -> GELU -> Dropout -> Linear
        self.act = nn.GELU()
        self.head = nn.Sequential(
            nn.Dropout(self.dropout),
            nn.Linear(self.d_model * self.enc_in, self.num_class)
        )

    def classification(self, x_enc: torch.Tensor) -> torch.Tensor:
        # x_enc: [B, L, D]
        B, L, D = x_enc.shape
        assert D == self.enc_in, f"Input channel count {D} does not match enc_in {self.enc_in}"
        per_channel_last = []
        for c in range(D):
            # Slice out one channel's sequence: [B, L] -> [B, L, 1]
            x_ch = x_enc[:, :, c].unsqueeze(-1)
            y_ch = self.channel_blocks[c](x_ch)      # [B, L, d_model]
            per_channel_last.append(y_ch[:, -1, :])  # [B, d_model]
        # Concatenate the last-time-step representation of each channel -> [B, D * d_model]
        h_last = torch.cat(per_channel_last, dim=-1)
        # Classification head
        h_last = self.act(h_last)
        logits = self.head(h_last)  # [B, num_class]
        return logits

    # Keeps the same forward signature as TimesNet; x_mark_enc / x_dec / x_mark_dec / mask are ignored
    def forward(self, x_enc, x_mark_enc=None, x_dec=None, x_mark_dec=None, mask=None):
        return self.classification(x_enc)
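

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the TSlib training pipeline): build a
    # config object whose attribute names match the getattr() lookups above, then run
    # one forward pass. Batch size, sequence length, and hyperparameters below are
    # illustrative assumptions; mamba_ssm's Mamba2 requires a CUDA device.
    from types import SimpleNamespace

    configs = SimpleNamespace(
        task_name='classification',
        enc_in=3,      # number of channels D
        d_model=128,   # with the default expand=2 and headdim=64, d_inner=256 is divisible by headdim
        num_class=5,
        dropout=0.1,
    )
    model = Model(configs).cuda()
    x_enc = torch.randn(8, 96, configs.enc_in, device='cuda')  # [B, L, D]
    logits = model(x_enc)
    print(logits.shape)  # expected: torch.Size([8, 5])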