#!/usr/bin/env python3
"""
Test script for the DC_hnet model.

Verifies that the time-series classification model runs end to end and
produces outputs with the expected shapes.
"""

import torch
import torch.nn.functional as F
import sys
import os

# Add the current directory to the Python path so `models` can be imported
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from models.DC_hnet import HierEncodersSingleMainConfig, HierEncodersSingleMainClassifier


def test_dc_hnet_model():
    """Test the DC_hnet time-series classification model."""
    print("=" * 60)
    print("Testing the DC_hnet time-series classification model")
    print("=" * 60)

    # Select a device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Model hyperparameters
    B, L, N = 8, 512, 6  # batch_size, seq_length, num_channels
    num_classes = 10  # number of classes
    d_models = [64, 128, 256]  # per-stage dimensions, monotonically increasing

    print(f"Input shape: (B={B}, L={L}, N={N})")
    print(f"Number of classes: {num_classes}")
    print(f"Model dimensions: {d_models}")

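    # Hierarchy under test (inferred from the settings below): two encoder
    # stages feed one main network, widening 64 -> 128 -> 256 as the
    # sequence is compressed stage by stage.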
    # Encoder configuration (Mamba at every stage)
    encoder_cfg_per_stage = [
        dict(arch="m", height=2),  # stage 0: Mamba2, 2 layers
        dict(arch="m", height=3),  # stage 1: Mamba2, 3 layers
    ]

    # Main network configuration (Transformer)
    main_cfg = dict(
        arch="T", height=4  # Transformer, 4 layers
    )

    # Compression targets
    target_compression_N_per_stage = [2, 3]  # compression ratio per stage

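    # Invariants inferred from the configs in this file: d_models holds one
    # entry per encoder stage plus one for the main network and increases
    # monotonically; target_compression_N_per_stage has exactly one entry
    # per encoder stage.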
    # Build the configuration
    cfg = HierEncodersSingleMainConfig(
        num_channels=N,
        d_models=d_models,
        num_classes=num_classes,
        encoder_cfg_per_stage=encoder_cfg_per_stage,
        main_cfg=main_cfg,
        target_compression_N_per_stage=target_compression_N_per_stage,
        share_channel=True,  # share encoders across channels
        fusion_across_channels="mean",  # channel fusion strategy
        dropout=0.1,
    )

    print("Configuration created")

    try:
        # Create the model -- flash attention requires half precision on CUDA,
        # so use bfloat16 there and fall back to float32 on CPU
        dtype = torch.bfloat16 if device.type == "cuda" else torch.float32
        model = HierEncodersSingleMainClassifier(cfg, device=device, dtype=dtype)
        model = model.to(device)
        print(f"Model created, parameter count: {sum(p.numel() for p in model.parameters()):,}")

        # Create random input data, reusing the flash-attention-compatible dtype
        x = torch.randn(B, L, N, device=device, dtype=dtype)
        mask = torch.ones(B, L, dtype=torch.bool, device=device)

        print(f"Input shape: {x.shape}, dtype: {x.dtype}")

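        # The forward call below is model(x, mask=mask, return_seq=...) and
        # returns (logits, seq_debug, aux): mask presumably flags valid
        # (non-padded) timesteps, and aux is a dict carrying at least
        # "ratio_loss".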
        # Forward-pass test
        print("\nRunning forward-pass test...")
        model.eval()
        with torch.no_grad():
            logits, seq_debug, aux = model(x, mask=mask, return_seq=False)

        print("✅ Forward pass succeeded!")
        print(f"Output logits shape: {logits.shape}")  # expected: (B, num_classes)
        print(f"ratio_loss: {aux['ratio_loss']:.4f}")

        # Verify the output shape
        expected_shape = (B, num_classes)
        if logits.shape == expected_shape:
            print(f"✅ Output shape is correct: {logits.shape}")
        else:
            print(f"❌ Wrong output shape: expected {expected_shape}, got {logits.shape}")
            return False

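        # The training objective below combines cross-entropy with a small
        # weight on "ratio_loss"; the 0.01 coefficient is arbitrary, and the
        # ratio term presumably steers the learned chunking toward
        # target_compression_N_per_stage.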
        # Training-mode test
        print("\nRunning training-mode test...")
        model.train()
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

        # Create random target labels
        y = torch.randint(0, num_classes, (B,), device=device)

        # Forward pass
        logits, _, aux = model(x, mask=mask, return_seq=False)

        # Compute the loss
        cls_loss = F.cross_entropy(logits, y)
        ratio_reg = 0.01 * aux["ratio_loss"]  # ratio-loss regularization
        total_loss = cls_loss + ratio_reg

        print(f"Classification loss: {cls_loss:.4f}")
        print(f"Ratio loss: {ratio_reg:.4f}")
        print(f"Total loss: {total_loss:.4f}")

        # Backward pass
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        print("✅ Training step succeeded!")

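        # With return_seq=True the model is expected to also return per-channel
        # sequence debug info (presumably the chunk boundaries per stage).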
        # Test returning sequence debug info
        print("\nTesting sequence debug info...")
        with torch.no_grad():
            logits, seq_debug, aux = model(x, mask=mask, return_seq=True)

        if seq_debug is not None:
            print(f"✅ Got sequence debug info for {len(seq_debug)} channels")
        else:
            print("❌ Failed to get sequence debug info")

        print("\n" + "=" * 60)
        print("🎉 All DC_hnet model tests passed!")
        print("The model can run time-series classification tasks")
        print("=" * 60)

        return True

    except Exception as e:
        print(f"❌ Test failed: {str(e)}")
        import traceback
        traceback.print_exc()
        return False


def test_different_configurations():
    """Test alternative model configurations."""
    print("\nTesting alternative configurations...")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Configuration 1: unshared channel encoders, concat fusion
    cfg1 = HierEncodersSingleMainConfig(
        num_channels=3,
        d_models=[32, 64],
        num_classes=5,
        encoder_cfg_per_stage=[dict(arch="m", height=2)],
        main_cfg=dict(arch="m", height=3),
        target_compression_N_per_stage=[2],
        share_channel=False,  # do not share encoders across channels
        fusion_across_channels="concat",  # concatenation fusion
        dropout=0.1,
    )

    try:
        dtype1 = torch.bfloat16 if device.type == "cuda" else torch.float32
        model1 = HierEncodersSingleMainClassifier(cfg1, device=device, dtype=dtype1)
        x1 = torch.randn(4, 256, 3, device=device, dtype=dtype1)
        logits1, _, _ = model1(x1)
        print(f"✅ Config 1 (unshared channels, concat fusion): output shape {logits1.shape}")
    except Exception as e:
        print(f"❌ Config 1 test failed: {str(e)}")

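    # With encoder_cfg_per_stage empty, the model presumably reduces to the
    # main network operating directly on the embedded input sequence.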
    # Configuration 2: single-stage model
    cfg2 = HierEncodersSingleMainConfig(
        num_channels=2,
        d_models=[128],  # a single level, no encoder stages
        num_classes=3,
        encoder_cfg_per_stage=[],  # empty encoder stages
        main_cfg=dict(arch="T", height=2),
        target_compression_N_per_stage=[],
        share_channel=True,
        fusion_across_channels="mean",
        dropout=0.1,
    )

    try:
        dtype2 = torch.bfloat16 if device.type == "cuda" else torch.float32
        model2 = HierEncodersSingleMainClassifier(cfg2, device=device, dtype=dtype2)
        x2 = torch.randn(2, 128, 2, device=device, dtype=dtype2)
        logits2, _, _ = model2(x2)
        print(f"✅ Config 2 (single-stage model): output shape {logits2.shape}")
    except Exception as e:
        print(f"❌ Config 2 test failed: {str(e)}")


if __name__ == "__main__":
    # Main test
    success = test_dc_hnet_model()

    # Additional configuration tests
    test_different_configurations()

    if success:
        print("\n🎊 All tests finished!")
        sys.exit(0)
    else:
        print("\n💥 Tests failed!")
        sys.exit(1)