import torch import torch.nn as nn import torch.nn.functional as F from mamba_ssm import Mamba2 class ValueEmbedding(nn.Module): """ 对每个时间步的单通道标量做线性投影到 d_model,并可选 Dropout。 不包含 temporal embedding 和 positional embedding。 """ def __init__(self, in_dim: int, d_model: int, dropout: float = 0.0, bias: bool = True): super().__init__() self.proj = nn.Linear(in_dim, d_model, bias=bias) self.dropout = nn.Dropout(dropout) if dropout and dropout > 0.0 else nn.Identity() def forward(self, x: torch.Tensor) -> torch.Tensor: # x: [B, L, 1] -> [B, L, d_model] return self.dropout(self.proj(x)) class ChannelMambaBlock(nn.Module): """ 针对单个通道的两层 Mamba-2 处理块: - 输入: [B, L, 1],先做投影到 d_model - 两层 Mamba2,且在第一层输出和第二层输出均添加残差连接 - 每层后接 LayerNorm - 输出: [B, L, d_model] """ def __init__(self, d_model: int, dropout: float, m2_kwargs: dict): super().__init__() self.embed = ValueEmbedding(in_dim=1, d_model=d_model, dropout=dropout, bias=True) # 两层 Mamba-2 self.mamba1 = Mamba2(d_model=d_model, **m2_kwargs) self.mamba2 = Mamba2(d_model=d_model, **m2_kwargs) # 每层后接的归一化 self.ln1 = nn.LayerNorm(d_model) self.ln2 = nn.LayerNorm(d_model) def forward(self, x_ch: torch.Tensor) -> torch.Tensor: # x_ch: [B, L, 1] x = self.embed(x_ch) # [B, L, d_model] # 第一层 + 残差 y1 = self.mamba1(x) # [B, L, d_model] y1 = self.ln1(x + y1) # 残差1 + LN # 第二层 + 残差 y2 = self.mamba2(y1) # [B, L, d_model] y2 = self.ln2(y1 + y2) # 残差2 + LN return y2 # [B, L, d_model] class Model(nn.Module): """ 按通道独立处理的 Mamba-2 模型,支持: - 分类:各通道独立提取,取最后时刻拼接 -> 分类头 - 长/短期预测:各通道独立提取,保留整段序列,经时间维线性映射到目标长度,再投影回标量并拼接 注意:预测输出通道数与输入通道数严格相同(逐通道预测)。 输入: - x_enc: [B, L, D] 多变量时间序列 - x_mark_enc, x_dec, x_mark_dec, mask: 兼容接口参数(本模型在分类/预测中未使用这些标注) 输出: - classification: logits [B, num_class] - forecast: [B, pred_len, D] """ def __init__(self, configs): super().__init__() # 任务类型 self.task_name = getattr(configs, 'task_name', 'classification') assert self.task_name in ['classification', 'long_term_forecast', 'short_term_forecast'], \ "只支持 classification / long_term_forecast / short_term_forecast" # 基本配置 self.enc_in = configs.enc_in # 通道数 D self.d_model = configs.d_model # 每通道的模型维度 self.num_class = getattr(configs, 'num_class', None) self.dropout = getattr(configs, 'dropout', 0.1) # 预测相关 self.seq_len = getattr(configs, 'seq_len', None) self.pred_len = getattr(configs, 'pred_len', None) if self.task_name in ['long_term_forecast', 'short_term_forecast']: assert self.seq_len is not None and self.pred_len is not None, "预测任务需要 seq_len 与 pred_len" # 输出通道必须与输入通道一致 self.c_out = getattr(configs, 'c_out', self.enc_in) assert self.c_out == self.enc_in, "预测任务要求输出通道 c_out 与输入通道 enc_in 一致" # Mamba-2 超参数 m2_kwargs = dict( d_state=getattr(configs, 'd_state', 64), d_conv=getattr(configs, 'd_conv', 4), expand=getattr(configs, 'expand', 2), headdim=getattr(configs, 'headdim', 64), d_ssm=getattr(configs, 'd_ssm', None), ngroups=getattr(configs, 'ngroups', 1), A_init_range=getattr(configs, 'A_init_range', (1, 16)), D_has_hdim=getattr(configs, 'D_has_hdim', False), rmsnorm=getattr(configs, 'rmsnorm', True), norm_before_gate=getattr(configs, 'norm_before_gate', False), dt_min=getattr(configs, 'dt_min', 0.001), dt_max=getattr(configs, 'dt_max', 0.1), dt_init_floor=getattr(configs, 'dt_init_floor', 1e-4), dt_limit=getattr(configs, 'dt_limit', (0.0, float("inf"))), bias=getattr(configs, 'bias', False), conv_bias=getattr(configs, 'conv_bias', True), chunk_size=getattr(configs, 'chunk_size', 256), use_mem_eff_path=getattr(configs, 'use_mem_eff_path', True), ) # 为每个通道构建独立的两层 Mamba2 处理块 self.channel_blocks = nn.ModuleList([ ChannelMambaBlock(d_model=self.d_model, dropout=self.dropout, m2_kwargs=m2_kwargs) for _ in range(self.enc_in) ]) # 分类头:将各通道最后时间步的表示拼接后 -> GELU -> Dropout -> Linear if self.task_name == 'classification': assert self.num_class is not None, "classification 需要提供 num_class" self.act = nn.GELU() self.head = nn.Sequential( nn.Dropout(self.dropout), nn.Linear(self.d_model * self.enc_in, self.num_class) ) # 预测头: # - 先对时间维做线性映射: [B, L, d_model] -> [B, pred_len, d_model] # - 再将 d_model 投影为单通道标量: [B, pred_len, d_model] -> [B, pred_len, 1] if self.task_name in ['long_term_forecast', 'short_term_forecast']: self.predict_linear = nn.Linear(self.seq_len, self.pred_len) self.projection = nn.Linear(self.d_model, 1, bias=True) def classification(self, x_enc: torch.Tensor) -> torch.Tensor: # x_enc: [B, L, D] B, L, D = x_enc.shape assert D == self.enc_in, f"输入通道数 {D} 与 enc_in {self.enc_in} 不一致" per_channel_last = [] for c in range(D): # 取出单通道序列 [B, L] -> [B, L, 1] x_ch = x_enc[:, :, c].unsqueeze(-1) y_ch = self.channel_blocks[c](x_ch) # [B, L, d_model] per_channel_last.append(y_ch[:, -1, :]) # [B, d_model] # 拼接各通道最后时刻的表示 -> [B, D * d_model] h_last = torch.cat(per_channel_last, dim=-1) # 分类头 logits = self.head(self.act(h_last)) # [B, num_class] return logits def forecast(self, x_enc: torch.Tensor) -> torch.Tensor: """ 逐通道预测: - 归一化(时间维),按通道独立提取 - 使用整段 Mamba 输出序列,经时间维线性映射到目标长度,再投影为标量 - 反归一化 返回: dec_out: [B, L+pred_len, D],在 forward 中会取最后 pred_len 段 """ B, L, D = x_enc.shape assert L == self.seq_len, f"输入长度 {L} 与配置 seq_len {self.seq_len} 不一致" assert D == self.enc_in, f"输入通道数 {D} 与 enc_in {self.enc_in} 不一致" # Normalization (per Non-stationary Transformer) means = x_enc.mean(1, keepdim=True).detach() # [B, 1, D] x = x_enc - means stdev = torch.sqrt(x.var(dim=1, keepdim=True, unbiased=False) + 1e-5) # [B, 1, D] x = x / stdev per_channel_seq = [] for c in range(D): x_ch = x[:, :, c].unsqueeze(-1) # [B, L, 1] h_ch = self.channel_blocks[c](x_ch) # [B, L, d_model] # 时间维映射到 L + pred_len h_ch = self.predict_linear(h_ch.permute(0, 2, 1)).permute(0, 2, 1) # [B, L+pred_len, d_model] # 投影回单通道 y_ch = self.projection(h_ch) # [B, L+pred_len, 1] per_channel_seq.append(y_ch) # 拼接通道 dec_out = torch.cat(per_channel_seq, dim=-1) # [B, pred_len, D] # De-normalization dec_out = dec_out * stdev[:, 0, :].unsqueeze(1) + means[:, 0, :].unsqueeze(1) return dec_out # 与 TimesNet 的 forward 签名保持一致;忽略 x_mark_enc / x_dec / x_mark_dec / mask def forward(self, x_enc, x_mark_enc=None, x_dec=None, x_mark_dec=None, mask=None): if self.task_name in ['long_term_forecast', 'short_term_forecast']: dec_out = self.forecast(x_enc) # [B, L+pred_len, D] return dec_out[:, -self.pred_len:, :] # 仅返回预测部分 [B, pred_len, D] elif self.task_name == 'classification': return self.classification(x_enc) else: raise NotImplementedError(f"Unsupported task: {self.task_name}")