chore: 删除unet里的local_cond(未使用) [remove unused local_cond from the UNet]

This commit is contained in:
gouhanke
2026-02-28 10:42:16 +08:00
parent 1d33db0ef0
commit abb4f501e3

View File

@@ -124,7 +124,6 @@ class ConditionalResidualBlock1D(nn.Module):
 class ConditionalUnet1D(nn.Module):
     def __init__(self,
         input_dim,
-        local_cond_dim=None,
         global_cond_dim=None,
         diffusion_step_embed_dim=256,
         down_dims=[256,512,1024],
@@ -149,23 +148,6 @@ class ConditionalUnet1D(nn.Module):
         in_out = list(zip(all_dims[:-1], all_dims[1:]))

-        local_cond_encoder = None
-        if local_cond_dim is not None:
-            _, dim_out = in_out[0]
-            dim_in = local_cond_dim
-            local_cond_encoder = nn.ModuleList([
-                # down encoder
-                ConditionalResidualBlock1D(
-                    dim_in, dim_out, cond_dim=cond_dim,
-                    kernel_size=kernel_size, n_groups=n_groups,
-                    cond_predict_scale=cond_predict_scale),
-                # up encoder
-                ConditionalResidualBlock1D(
-                    dim_in, dim_out, cond_dim=cond_dim,
-                    kernel_size=kernel_size, n_groups=n_groups,
-                    cond_predict_scale=cond_predict_scale)
-            ])
-
         mid_dim = all_dims[-1]
         self.mid_modules = nn.ModuleList([
             ConditionalResidualBlock1D(
@@ -216,7 +198,6 @@ class ConditionalUnet1D(nn.Module):
         )

         self.diffusion_step_encoder = diffusion_step_encoder
-        self.local_cond_encoder = local_cond_encoder
         self.up_modules = up_modules
         self.down_modules = down_modules
         self.final_conv = final_conv
@@ -225,12 +206,11 @@ class ConditionalUnet1D(nn.Module):
     def forward(self,
             sample: torch.Tensor,
             timestep: Union[torch.Tensor, float, int],
-            local_cond=None, global_cond=None,
+            global_cond=None,
             **kwargs):
         """
         x: (B,T,input_dim)
         timestep: (B,) or int, diffusion step
-        local_cond: (B,T,local_cond_dim)
         global_cond: (B,global_cond_dim)
         output: (B,T,input_dim)
         """
@@ -253,22 +233,10 @@ class ConditionalUnet1D(nn.Module):
                 global_feature, global_cond
             ], axis=-1)

-        # encode local features
-        h_local = list()
-        if local_cond is not None:
-            local_cond = einops.rearrange(local_cond, 'b h t -> b t h')
-            resnet, resnet2 = self.local_cond_encoder
-            x = resnet(local_cond, global_feature)
-            h_local.append(x)
-            x = resnet2(local_cond, global_feature)
-            h_local.append(x)
-
         x = sample
         h = []
         for idx, (resnet, resnet2, downsample) in enumerate(self.down_modules):
             x = resnet(x, global_feature)
-            if idx == 0 and len(h_local) > 0:
-                x = x + h_local[0]
             x = resnet2(x, global_feature)
             h.append(x)
             x = downsample(x)
@@ -279,12 +247,6 @@ class ConditionalUnet1D(nn.Module):
         for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules):
             x = torch.cat((x, h.pop()), dim=1)
             x = resnet(x, global_feature)
-            # The correct condition should be:
-            # if idx == (len(self.up_modules)-1) and len(h_local) > 0:
-            # However this change will break compatibility with published checkpoints.
-            # Therefore it is left as a comment.
-            if idx == len(self.up_modules) and len(h_local) > 0:
-                x = x + h_local[1]
             x = resnet2(x, global_feature)
             x = upsample(x)