feat: add mamba and dynamic chunking related code and test code
This commit is contained in:
209
test_DC_hnet.py
Normal file
209
test_DC_hnet.py
Normal file
@ -0,0 +1,209 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
测试DC_hnet模型的脚本
|
||||
用于验证时间序列分类模型能否正常运行并得到期望的输出形状
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import sys
|
||||
import os
|
||||
|
||||
# Add this script's directory to the Python import path
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from models.DC_hnet import HierEncodersSingleMainConfig, HierEncodersSingleMainClassifier
|
||||
|
||||
def test_dc_hnet_model():
    """Run an end-to-end smoke test of the DC_hnet time-series classifier.

    Builds a configuration with two encoder stages (``arch="m"``) and a main
    network (``arch="T"``), then performs three checks in order:

    1. a forward pass in eval mode under ``torch.no_grad()`` must produce
       logits of shape ``(B, num_classes)``;
    2. one optimisation step (cross-entropy plus ``0.01 * ratio_loss``,
       Adam) must run without raising;
    3. a forward pass with ``return_seq=True`` must return a non-``None``
       sequence-debug object.

    Returns:
        bool: ``True`` when every check passes. ``False`` when the logits
        shape is wrong, when the sequence-debug object is ``None``, or when
        any exception is raised (its traceback is printed).
    """
    print("=" * 60)
    print("测试DC_hnet时间序列分类模型")
    print("=" * 60)

    # Pick the device: CUDA when available, otherwise CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"使用设备: {device}")

    # Model hyper-parameters.
    B, L, N = 8, 512, 6  # batch_size, seq_length, num_channels
    num_classes = 10  # number of target classes
    d_models = [64, 128, 256]  # width of each level, monotonically increasing

    print(f"输入形状: (B={B}, L={L}, N={N})")
    print(f"分类数: {num_classes}")
    print(f"模型维度: {d_models}")

    # Encoder configuration (every stage is Mamba).
    encoder_cfg_per_stage = [
        dict(arch="m", height=2),  # stage 0: Mamba2, 2 layers
        dict(arch="m", height=3),  # stage 1: Mamba2, 3 layers
    ]

    # Main-network configuration: Transformer, 4 layers.
    main_cfg = dict(arch="T", height=4)

    # Target compression ratio for each encoder stage.
    target_compression_N_per_stage = [2, 3]

    cfg = HierEncodersSingleMainConfig(
        num_channels=N,
        d_models=d_models,
        num_classes=num_classes,
        encoder_cfg_per_stage=encoder_cfg_per_stage,
        main_cfg=main_cfg,
        target_compression_N_per_stage=target_compression_N_per_stage,
        share_channel=True,  # share the encoder across channels
        fusion_across_channels="mean",  # how channels are fused
        dropout=0.1,
    )
    print("配置创建完成")

    try:
        # bfloat16 on CUDA for flash-attention compatibility; float32 on CPU.
        # Computed once and reused for both the model and the input tensor.
        dtype = torch.bfloat16 if device.type == "cuda" else torch.float32

        model = HierEncodersSingleMainClassifier(cfg, device=device, dtype=dtype)
        model = model.to(device)
        print(f"模型创建成功,参数量: {sum(p.numel() for p in model.parameters()):,}")

        # Random input batch plus an all-True mask.
        x = torch.randn(B, L, N, device=device, dtype=dtype)
        mask = torch.ones(B, L, dtype=torch.bool, device=device)
        print(f"输入数据形状: {x.shape}, 数据类型: {x.dtype}")

        # --- Check 1: forward pass in eval mode -----------------------------
        print("\n开始前向传播测试...")
        model.eval()
        with torch.no_grad():
            logits, _, aux = model(x, mask=mask, return_seq=False)

        print("✅ 前向传播成功!")
        print(f"输出logits形状: {logits.shape}")  # expected: (B, num_classes)
        print(f"ratio_loss: {aux['ratio_loss']:.4f}")

        expected_shape = (B, num_classes)
        if logits.shape != expected_shape:
            print(f"❌ 输出形状错误: 期望 {expected_shape}, 实际 {logits.shape}")
            return False
        print(f"✅ 输出形状正确: {logits.shape}")

        # --- Check 2: a single training step --------------------------------
        print("\n开始训练模式测试...")
        model.train()
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

        # Random target labels.
        y = torch.randint(0, num_classes, (B,), device=device)

        logits, _, aux = model(x, mask=mask, return_seq=False)

        cls_loss = F.cross_entropy(logits, y)
        ratio_reg = 0.01 * aux["ratio_loss"]  # ratio loss used as a regulariser
        total_loss = cls_loss + ratio_reg

        print(f"分类损失: {cls_loss:.4f}")
        print(f"ratio损失: {ratio_reg:.4f}")
        print(f"总损失: {total_loss:.4f}")

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
        print("✅ 训练步骤成功!")

        # --- Check 3: sequence-debug output ---------------------------------
        print("\n测试序列调试信息返回...")
        with torch.no_grad():
            logits, seq_debug, aux = model(x, mask=mask, return_seq=True)

        if seq_debug is None:
            print("❌ 序列调试信息获取失败")
            # A missing debug payload is a failed check: stop here instead of
            # falling through to the "all tests passed" banner.
            return False
        print(f"✅ 序列调试信息获取成功,包含 {len(seq_debug)} 个通道的信息")

        print("\n" + "=" * 60)
        print("🎉 DC_hnet模型测试全部通过!")
        print("模型可以正常进行时间序列分类任务")
        print("=" * 60)

        return True

    except Exception as e:
        # Top-level boundary of the smoke test: report the error and signal
        # failure through the return value instead of crashing the script.
        print(f"❌ 测试失败: {str(e)}")
        import traceback

        traceback.print_exc()
        return False
|
||||
|
||||
def _smoke_test_config(name, description, cfg, input_shape, device):
    """Build a classifier from ``cfg`` and run one forward pass on random input.

    Args:
        name: Label used in the printed messages (e.g. ``"配置1"``).
        description: Short description printed next to the label on success.
        cfg: A ``HierEncodersSingleMainConfig`` instance.
        input_shape: Shape tuple passed to ``torch.randn`` for the input.
        device: ``torch.device`` on which the model and input are created.

    Any exception is caught, reported together with its traceback, and not
    re-raised, so a failing configuration does not stop the ones after it.
    """
    try:
        # bfloat16 on CUDA, float32 on CPU.
        dtype = torch.bfloat16 if device.type == "cuda" else torch.float32
        model = HierEncodersSingleMainClassifier(cfg, device=device, dtype=dtype)
        x = torch.randn(*input_shape, device=device, dtype=dtype)
        logits, _, _ = model(x)
        print(f"✅ {name} ({description}): 输出形状 {logits.shape}")
    except Exception as e:
        print(f"❌ {name}测试失败: {str(e)}")
        # Print the full traceback so the failure can be diagnosed.
        import traceback

        traceback.print_exc()


def test_different_configurations():
    """Smoke-test two additional model configurations.

    Configuration 1 uses per-channel encoders (``share_channel=False``) with
    ``"concat"`` fusion. Configuration 2 has a single level: ``d_models`` has
    one entry and there are no encoder stages.

    Each configuration is built and run through one forward pass. Failures
    are printed with a traceback and do not raise, so this function always
    returns ``None``.
    """
    print("\n测试不同配置...")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Configuration 1: channels do not share an encoder.
    cfg1 = HierEncodersSingleMainConfig(
        num_channels=3,
        d_models=[32, 64],
        num_classes=5,
        encoder_cfg_per_stage=[dict(arch="m", height=2)],
        main_cfg=dict(arch="m", height=3),
        target_compression_N_per_stage=[2],
        share_channel=False,  # separate encoder per channel
        fusion_across_channels="concat",  # fuse channels by concatenation
        dropout=0.1,
    )
    _smoke_test_config("配置1", "不共享通道, concat融合", cfg1, (4, 256, 3), device)

    # Configuration 2: single-level model.
    cfg2 = HierEncodersSingleMainConfig(
        num_channels=2,
        d_models=[128],  # one level only, so no encoder stages
        num_classes=3,
        encoder_cfg_per_stage=[],  # empty list of encoder stages
        main_cfg=dict(arch="T", height=2),
        target_compression_N_per_stage=[],
        share_channel=True,
        fusion_across_channels="mean",
        dropout=0.1,
    )
    _smoke_test_config("配置2", "单层模型", cfg2, (2, 128, 2), device)
|
||||
|
||||
if __name__ == "__main__":
    # The main smoke test alone decides the process exit status.
    main_ok = test_dc_hnet_model()

    # The extra configurations are run for their printed report only.
    test_different_configurations()

    if not main_ok:
        print("\n💥 测试失败!")
        sys.exit(1)

    print("\n🎊 所有测试完成!")
    sys.exit(0)
|
Reference in New Issue
Block a user