diff --git a/as_mamba.py b/as_mamba.py index 9574a6f..c6edbe6 100644 --- a/as_mamba.py +++ b/as_mamba.py @@ -152,7 +152,7 @@ class ASMamba(nn.Module): feats, h = self.backbone(x, cond_vec, h) delta = self.delta_head(feats) dt_raw = self.dt_head(feats).squeeze(-1) - dt = torch.clamp(F.softplus(dt_raw), min=self.dt_min, max=self.dt_max) + dt = F.softplus(dt_raw) return delta, dt, h def step(