feat(model): introduce dynamic training flag for model forward pass

This commit is contained in:
gameloader
2025-09-13 00:04:01 +08:00
parent 172328a4e6
commit 93f14077da
4 changed files with 40 additions and 22 deletions

View File

@ -98,7 +98,7 @@ class SeasonPatch(nn.Module):
else:
raise ValueError(f"Unsupported encoder_type: {encoder_type}")
def forward(self, x ):
def forward(self, x, training):
# x: [B, L, C]
x = x.permute(0, 2, 1) # x: [B, C, L]
@ -114,7 +114,7 @@ class SeasonPatch(nn.Module):
z = z.permute(0, 1, 3, 2) # z: [B, C, patch_num, d_model]
# Cross-channel mixing
z_mix = self.mixer(z) # z_mix: [B, C, patch_num, d_model]
z_mix = self.mixer(z, training) # z_mix: [B, C, patch_num, d_model]
# Flatten and predict
z_mix = z_mix.view(B, C, N * D) # z_mix: [B, C, patch_num * d_model]