from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn

from diffusers import ConfigMixin, ModelMixin
from diffusers.configuration_utils import register_to_config
from diffusers.models.embeddings import SinusoidalPositionalEmbedding, TimestepEmbedding, Timesteps
class TimestepEncoder(nn.Module):
|
|
def __init__(self, args):
|
|
super().__init__()
|
|
embedding_dim = args.embed_dim
|
|
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
|
|
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
|
|
|
def forward(self, timesteps):
|
|
dtype = next(self.parameters()).dtype
|
|
timesteps_proj = self.time_proj(timesteps).to(dtype)
|
|
timesteps_emb = self.timestep_embedder(timesteps_proj) # (N, D)
|
|
return timesteps_emb
|
|
|
|
|
|
class AdaLayerNorm(nn.Module):
|
|
def __init__(self, embedding_dim, norm_eps=1e-5, norm_elementwise_affine=False):
|
|
super().__init__()
|
|
|
|
output_dim = embedding_dim * 2
|
|
self.silu = nn.SiLU()
|
|
self.linear = nn.Linear(embedding_dim, output_dim)
|
|
self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine)
|
|
|
|
def forward(
|
|
self,
|
|
x: torch.Tensor,
|
|
temb: Optional[torch.Tensor] = None,
|
|
) -> torch.Tensor:
|
|
temb = self.linear(self.silu(temb))
|
|
scale, shift = temb.chunk(2, dim=1)
|
|
x = self.norm(x) * (1 + scale[:, None]) + shift[:, None]
|
|
return x
|
|
|
|
|
|
class BasicTransformerBlock(nn.Module):
|
|
def __init__(self, args, crosss_attention_dim, use_self_attn=False):
|
|
super().__init__()
|
|
dim = args.embed_dim
|
|
num_heads = args.nheads
|
|
mlp_ratio = args.mlp_ratio
|
|
dropout = args.dropout
|
|
self.norm1 = AdaLayerNorm(dim)
|
|
|
|
if not use_self_attn:
|
|
self.attn = nn.MultiheadAttention(
|
|
embed_dim=dim,
|
|
num_heads=num_heads,
|
|
dropout=dropout,
|
|
kdim=crosss_attention_dim,
|
|
vdim=crosss_attention_dim,
|
|
batch_first=True,
|
|
)
|
|
else:
|
|
self.attn = nn.MultiheadAttention(
|
|
embed_dim=dim,
|
|
num_heads=num_heads,
|
|
dropout=dropout,
|
|
batch_first=True,
|
|
)
|
|
|
|
self.norm2 = nn.LayerNorm(dim, eps=1e-5, elementwise_affine=False)
|
|
|
|
self.mlp = nn.Sequential(
|
|
nn.Linear(dim, dim * mlp_ratio),
|
|
nn.GELU(),
|
|
nn.Dropout(dropout),
|
|
nn.Linear(dim * mlp_ratio, dim),
|
|
nn.Dropout(dropout)
|
|
)
|
|
|
|
def forward(self, hidden_states, temb, context=None):
|
|
norm_hidden_states = self.norm1(hidden_states, temb)
|
|
|
|
attn_output = self.attn(
|
|
norm_hidden_states,
|
|
context if context is not None else norm_hidden_states,
|
|
context if context is not None else norm_hidden_states,
|
|
)[0]
|
|
|
|
hidden_states = attn_output + hidden_states
|
|
|
|
norm_hidden_states = self.norm2(hidden_states)
|
|
|
|
ff_output = self.mlp(norm_hidden_states)
|
|
|
|
hidden_states = ff_output + hidden_states
|
|
|
|
return hidden_states
|
|
|
|
class DiT(nn.Module):
|
|
def __init__(self, args, cross_attention_dim):
|
|
super().__init__()
|
|
inner_dim = args.embed_dim
|
|
num_layers = args.num_layers
|
|
output_dim = args.hidden_dim
|
|
|
|
self.timestep_encoder = TimestepEncoder(args)
|
|
|
|
all_blocks = []
|
|
for idx in range(num_layers):
|
|
use_self_attn = idx % 2 == 1
|
|
if use_self_attn:
|
|
block = BasicTransformerBlock(args, crosss_attention_dim=None, use_self_attn=True)
|
|
else:
|
|
block = BasicTransformerBlock(args, crosss_attention_dim=cross_attention_dim, use_self_attn=False)
|
|
all_blocks.append(block)
|
|
|
|
self.transformer_blocks = nn.ModuleList(all_blocks)
|
|
|
|
self.norm_out = nn.LayerNorm(inner_dim, eps=1e-6, elementwise_affine=False)
|
|
self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
|
|
self.proj_out_2 = nn.Linear(inner_dim, output_dim)
|
|
|
|
def forward(self, hidden_states, timestep, encoder_hidden_states):
|
|
temb = self.timestep_encoder(timestep)
|
|
|
|
hidden_states = hidden_states.contiguous()
|
|
encoder_hidden_states = encoder_hidden_states.contiguous()
|
|
|
|
for idx, block in enumerate(self.transformer_blocks):
|
|
if idx % 2 == 1:
|
|
hidden_states = block(hidden_states, temb)
|
|
else:
|
|
hidden_states = block(hidden_states, temb, context=encoder_hidden_states)
|
|
|
|
conditioning = temb
|
|
shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
|
|
hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
|
|
return self.proj_out_2(hidden_states)
|
|
|
|
|
|
def build_dit(args, cross_attention_dim):
|
|
return DiT(args, cross_attention_dim) |