chore: 删除unet里的local_cond(未使用)
This commit is contained in:
@@ -122,9 +122,8 @@ class ConditionalResidualBlock1D(nn.Module):
|
||||
|
||||
|
||||
class ConditionalUnet1D(nn.Module):
|
||||
def __init__(self,
|
||||
def __init__(self,
|
||||
input_dim,
|
||||
local_cond_dim=None,
|
||||
global_cond_dim=None,
|
||||
diffusion_step_embed_dim=256,
|
||||
down_dims=[256,512,1024],
|
||||
@@ -149,23 +148,6 @@ class ConditionalUnet1D(nn.Module):
|
||||
|
||||
in_out = list(zip(all_dims[:-1], all_dims[1:]))
|
||||
|
||||
local_cond_encoder = None
|
||||
if local_cond_dim is not None:
|
||||
_, dim_out = in_out[0]
|
||||
dim_in = local_cond_dim
|
||||
local_cond_encoder = nn.ModuleList([
|
||||
# down encoder
|
||||
ConditionalResidualBlock1D(
|
||||
dim_in, dim_out, cond_dim=cond_dim,
|
||||
kernel_size=kernel_size, n_groups=n_groups,
|
||||
cond_predict_scale=cond_predict_scale),
|
||||
# up encoder
|
||||
ConditionalResidualBlock1D(
|
||||
dim_in, dim_out, cond_dim=cond_dim,
|
||||
kernel_size=kernel_size, n_groups=n_groups,
|
||||
cond_predict_scale=cond_predict_scale)
|
||||
])
|
||||
|
||||
mid_dim = all_dims[-1]
|
||||
self.mid_modules = nn.ModuleList([
|
||||
ConditionalResidualBlock1D(
|
||||
@@ -216,21 +198,19 @@ class ConditionalUnet1D(nn.Module):
|
||||
)
|
||||
|
||||
self.diffusion_step_encoder = diffusion_step_encoder
|
||||
self.local_cond_encoder = local_cond_encoder
|
||||
self.up_modules = up_modules
|
||||
self.down_modules = down_modules
|
||||
self.final_conv = final_conv
|
||||
|
||||
|
||||
def forward(self,
|
||||
sample: torch.Tensor,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
local_cond=None, global_cond=None,
|
||||
def forward(self,
|
||||
sample: torch.Tensor,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
global_cond=None,
|
||||
**kwargs):
|
||||
"""
|
||||
x: (B,T,input_dim)
|
||||
timestep: (B,) or int, diffusion step
|
||||
local_cond: (B,T,local_cond_dim)
|
||||
global_cond: (B,global_cond_dim)
|
||||
output: (B,T,input_dim)
|
||||
"""
|
||||
@@ -252,23 +232,11 @@ class ConditionalUnet1D(nn.Module):
|
||||
global_feature = torch.cat([
|
||||
global_feature, global_cond
|
||||
], axis=-1)
|
||||
|
||||
# encode local features
|
||||
h_local = list()
|
||||
if local_cond is not None:
|
||||
local_cond = einops.rearrange(local_cond, 'b h t -> b t h')
|
||||
resnet, resnet2 = self.local_cond_encoder
|
||||
x = resnet(local_cond, global_feature)
|
||||
h_local.append(x)
|
||||
x = resnet2(local_cond, global_feature)
|
||||
h_local.append(x)
|
||||
|
||||
|
||||
x = sample
|
||||
h = []
|
||||
for idx, (resnet, resnet2, downsample) in enumerate(self.down_modules):
|
||||
x = resnet(x, global_feature)
|
||||
if idx == 0 and len(h_local) > 0:
|
||||
x = x + h_local[0]
|
||||
x = resnet2(x, global_feature)
|
||||
h.append(x)
|
||||
x = downsample(x)
|
||||
@@ -279,12 +247,6 @@ class ConditionalUnet1D(nn.Module):
|
||||
for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules):
|
||||
x = torch.cat((x, h.pop()), dim=1)
|
||||
x = resnet(x, global_feature)
|
||||
# The correct condition should be:
|
||||
# if idx == (len(self.up_modules)-1) and len(h_local) > 0:
|
||||
# However this change will break compatibility with published checkpoints.
|
||||
# Therefore it is left as a comment.
|
||||
if idx == len(self.up_modules) and len(h_local) > 0:
|
||||
x = x + h_local[1]
|
||||
x = resnet2(x, global_feature)
|
||||
x = upsample(x)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user