fix: remove dt clamping and use raw softplus for step size

This commit is contained in:
gameloader
2026-01-22 14:41:02 +08:00
parent 444f5fc109
commit 913740266b

View File

@@ -152,7 +152,7 @@ class ASMamba(nn.Module):
feats, h = self.backbone(x, cond_vec, h) feats, h = self.backbone(x, cond_vec, h)
delta = self.delta_head(feats) delta = self.delta_head(feats)
dt_raw = self.dt_head(feats).squeeze(-1) dt_raw = self.dt_head(feats).squeeze(-1)
dt = torch.clamp(F.softplus(dt_raw), min=self.dt_min, max=self.dt_max) dt = F.softplus(dt_raw)
return delta, dt, h return delta, dt, h
def step( def step(