chore: 删除unet里的local_cond(未使用)
This commit is contained in:
@@ -124,7 +124,6 @@ class ConditionalResidualBlock1D(nn.Module):
|
|||||||
class ConditionalUnet1D(nn.Module):
|
class ConditionalUnet1D(nn.Module):
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
input_dim,
|
input_dim,
|
||||||
local_cond_dim=None,
|
|
||||||
global_cond_dim=None,
|
global_cond_dim=None,
|
||||||
diffusion_step_embed_dim=256,
|
diffusion_step_embed_dim=256,
|
||||||
down_dims=[256,512,1024],
|
down_dims=[256,512,1024],
|
||||||
@@ -149,23 +148,6 @@ class ConditionalUnet1D(nn.Module):
|
|||||||
|
|
||||||
in_out = list(zip(all_dims[:-1], all_dims[1:]))
|
in_out = list(zip(all_dims[:-1], all_dims[1:]))
|
||||||
|
|
||||||
local_cond_encoder = None
|
|
||||||
if local_cond_dim is not None:
|
|
||||||
_, dim_out = in_out[0]
|
|
||||||
dim_in = local_cond_dim
|
|
||||||
local_cond_encoder = nn.ModuleList([
|
|
||||||
# down encoder
|
|
||||||
ConditionalResidualBlock1D(
|
|
||||||
dim_in, dim_out, cond_dim=cond_dim,
|
|
||||||
kernel_size=kernel_size, n_groups=n_groups,
|
|
||||||
cond_predict_scale=cond_predict_scale),
|
|
||||||
# up encoder
|
|
||||||
ConditionalResidualBlock1D(
|
|
||||||
dim_in, dim_out, cond_dim=cond_dim,
|
|
||||||
kernel_size=kernel_size, n_groups=n_groups,
|
|
||||||
cond_predict_scale=cond_predict_scale)
|
|
||||||
])
|
|
||||||
|
|
||||||
mid_dim = all_dims[-1]
|
mid_dim = all_dims[-1]
|
||||||
self.mid_modules = nn.ModuleList([
|
self.mid_modules = nn.ModuleList([
|
||||||
ConditionalResidualBlock1D(
|
ConditionalResidualBlock1D(
|
||||||
@@ -216,7 +198,6 @@ class ConditionalUnet1D(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.diffusion_step_encoder = diffusion_step_encoder
|
self.diffusion_step_encoder = diffusion_step_encoder
|
||||||
self.local_cond_encoder = local_cond_encoder
|
|
||||||
self.up_modules = up_modules
|
self.up_modules = up_modules
|
||||||
self.down_modules = down_modules
|
self.down_modules = down_modules
|
||||||
self.final_conv = final_conv
|
self.final_conv = final_conv
|
||||||
@@ -225,12 +206,11 @@ class ConditionalUnet1D(nn.Module):
|
|||||||
def forward(self,
|
def forward(self,
|
||||||
sample: torch.Tensor,
|
sample: torch.Tensor,
|
||||||
timestep: Union[torch.Tensor, float, int],
|
timestep: Union[torch.Tensor, float, int],
|
||||||
local_cond=None, global_cond=None,
|
global_cond=None,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
"""
|
"""
|
||||||
x: (B,T,input_dim)
|
x: (B,T,input_dim)
|
||||||
timestep: (B,) or int, diffusion step
|
timestep: (B,) or int, diffusion step
|
||||||
local_cond: (B,T,local_cond_dim)
|
|
||||||
global_cond: (B,global_cond_dim)
|
global_cond: (B,global_cond_dim)
|
||||||
output: (B,T,input_dim)
|
output: (B,T,input_dim)
|
||||||
"""
|
"""
|
||||||
@@ -253,22 +233,10 @@ class ConditionalUnet1D(nn.Module):
|
|||||||
global_feature, global_cond
|
global_feature, global_cond
|
||||||
], axis=-1)
|
], axis=-1)
|
||||||
|
|
||||||
# encode local features
|
|
||||||
h_local = list()
|
|
||||||
if local_cond is not None:
|
|
||||||
local_cond = einops.rearrange(local_cond, 'b h t -> b t h')
|
|
||||||
resnet, resnet2 = self.local_cond_encoder
|
|
||||||
x = resnet(local_cond, global_feature)
|
|
||||||
h_local.append(x)
|
|
||||||
x = resnet2(local_cond, global_feature)
|
|
||||||
h_local.append(x)
|
|
||||||
|
|
||||||
x = sample
|
x = sample
|
||||||
h = []
|
h = []
|
||||||
for idx, (resnet, resnet2, downsample) in enumerate(self.down_modules):
|
for idx, (resnet, resnet2, downsample) in enumerate(self.down_modules):
|
||||||
x = resnet(x, global_feature)
|
x = resnet(x, global_feature)
|
||||||
if idx == 0 and len(h_local) > 0:
|
|
||||||
x = x + h_local[0]
|
|
||||||
x = resnet2(x, global_feature)
|
x = resnet2(x, global_feature)
|
||||||
h.append(x)
|
h.append(x)
|
||||||
x = downsample(x)
|
x = downsample(x)
|
||||||
@@ -279,12 +247,6 @@ class ConditionalUnet1D(nn.Module):
|
|||||||
for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules):
|
for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules):
|
||||||
x = torch.cat((x, h.pop()), dim=1)
|
x = torch.cat((x, h.pop()), dim=1)
|
||||||
x = resnet(x, global_feature)
|
x = resnet(x, global_feature)
|
||||||
# The correct condition should be:
|
|
||||||
# if idx == (len(self.up_modules)-1) and len(h_local) > 0:
|
|
||||||
# However this change will break compatibility with published checkpoints.
|
|
||||||
# Therefore it is left as a comment.
|
|
||||||
if idx == len(self.up_modules) and len(h_local) > 0:
|
|
||||||
x = x + h_local[1]
|
|
||||||
x = resnet2(x, global_feature)
|
x = resnet2(x, global_feature)
|
||||||
x = upsample(x)
|
x = upsample(x)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user