fix: remove dt clamping and use raw softplus for step size
This commit is contained in:
@@ -152,7 +152,7 @@ class ASMamba(nn.Module):
|
|||||||
feats, h = self.backbone(x, cond_vec, h)
|
feats, h = self.backbone(x, cond_vec, h)
|
||||||
delta = self.delta_head(feats)
|
delta = self.delta_head(feats)
|
||||||
dt_raw = self.dt_head(feats).squeeze(-1)
|
dt_raw = self.dt_head(feats).squeeze(-1)
|
||||||
dt = torch.clamp(F.softplus(dt_raw), min=self.dt_min, max=self.dt_max)
|
dt = F.softplus(dt_raw)
|
||||||
return delta, dt, h
|
return delta, dt, h
|
||||||
|
|
||||||
def step(
|
def step(
|
||||||
|
|||||||
Reference in New Issue
Block a user