from typing import Optional

import torch
import torch.nn.functional as F
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from torch import nn


class TimestepEncoder(nn.Module):
    """Maps scalar diffusion timesteps to embedding vectors of size ``args.embed_dim``."""

    def __init__(self, args):
        super().__init__()
        embedding_dim = args.embed_dim
        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)

    def forward(self, timesteps):
        dtype = next(self.parameters()).dtype
        timesteps_proj = self.time_proj(timesteps).to(dtype)  # (N, 256) sinusoidal features
        timesteps_emb = self.timestep_embedder(timesteps_proj)  # (N, D)
        return timesteps_emb


class AdaLayerNorm(nn.Module):
    """LayerNorm whose scale and shift are predicted from the timestep embedding."""

    def __init__(self, embedding_dim, norm_eps=1e-5, norm_elementwise_affine=False):
        super().__init__()
        output_dim = embedding_dim * 2
        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, output_dim)
        self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine)

    def forward(self, x: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
        temb = self.linear(self.silu(temb))
        scale, shift = temb.chunk(2, dim=1)
        x = self.norm(x) * (1 + scale[:, None]) + shift[:, None]
        return x


class BasicTransformerBlock(nn.Module):
    """Pre-norm transformer block: AdaLN -> (self- or cross-) attention -> LayerNorm -> MLP."""

    def __init__(self, args, cross_attention_dim, use_self_attn=False):
        super().__init__()
        dim = args.embed_dim
        num_heads = args.nheads
        mlp_ratio = args.mlp_ratio
        dropout = args.dropout
        self.norm1 = AdaLayerNorm(dim)
        if not use_self_attn:
            # Cross-attention: keys/values come from the conditioning sequence.
            self.attn = nn.MultiheadAttention(
                embed_dim=dim,
                num_heads=num_heads,
                dropout=dropout,
                kdim=cross_attention_dim,
                vdim=cross_attention_dim,
                batch_first=True,
            )
        else:
            self.attn = nn.MultiheadAttention(
                embed_dim=dim,
                num_heads=num_heads,
                dropout=dropout,
                batch_first=True,
            )
        self.norm2 = nn.LayerNorm(dim, eps=1e-5, elementwise_affine=False)
        self.mlp = nn.Sequential(
            nn.Linear(dim, int(dim * mlp_ratio)),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(int(dim * mlp_ratio), dim),
            nn.Dropout(dropout),
        )

    def forward(self, hidden_states, temb, context: Optional[torch.Tensor] = None):
        norm_hidden_states = self.norm1(hidden_states, temb)
        # Self-attention when no context is given, cross-attention otherwise.
        attn_output = self.attn(
            norm_hidden_states,
            context if context is not None else norm_hidden_states,
            context if context is not None else norm_hidden_states,
        )[0]
        hidden_states = attn_output + hidden_states
        norm_hidden_states = self.norm2(hidden_states)
        ff_output = self.mlp(norm_hidden_states)
        hidden_states = ff_output + hidden_states
        return hidden_states


class DiT(nn.Module):
    """Diffusion transformer that alternates cross-attention (even) and self-attention (odd) blocks."""

    def __init__(self, args, cross_attention_dim):
        super().__init__()
        inner_dim = args.embed_dim
        num_layers = args.num_layers
        output_dim = args.hidden_dim
        self.timestep_encoder = TimestepEncoder(args)

        all_blocks = []
        for idx in range(num_layers):
            use_self_attn = idx % 2 == 1
            if use_self_attn:
                block = BasicTransformerBlock(args, cross_attention_dim=None, use_self_attn=True)
            else:
                block = BasicTransformerBlock(args, cross_attention_dim=cross_attention_dim, use_self_attn=False)
            all_blocks.append(block)
        self.transformer_blocks = nn.ModuleList(all_blocks)

        self.norm_out = nn.LayerNorm(inner_dim, eps=1e-6, elementwise_affine=False)
        self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
        self.proj_out_2 = nn.Linear(inner_dim, output_dim)

    def forward(self, hidden_states, timestep, encoder_hidden_states):
        temb = self.timestep_encoder(timestep)
        hidden_states = hidden_states.contiguous()
        encoder_hidden_states = encoder_hidden_states.contiguous()

        for idx, block in enumerate(self.transformer_blocks):
            if idx % 2 == 1:
                hidden_states = block(hidden_states, temb)
            else:
                hidden_states = block(hidden_states, temb, context=encoder_hidden_states)

        # Final AdaLN-style modulation followed by the output projection.
        conditioning = temb
        shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
        hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
        return self.proj_out_2(hidden_states)


def build_dit(args, cross_attention_dim):
    return DiT(args, cross_attention_dim)
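

# Minimal usage sketch (an assumption for illustration, not part of this repo's
# training code): `args` only needs the fields read above (embed_dim, nheads,
# mlp_ratio, dropout, num_layers, hidden_dim); the concrete sizes below are
# arbitrary placeholders.
if __name__ == "__main__":
    from types import SimpleNamespace

    args = SimpleNamespace(
        embed_dim=512, nheads=8, mlp_ratio=4, dropout=0.1,
        num_layers=4, hidden_dim=512,
    )
    model = build_dit(args, cross_attention_dim=768)

    batch, seq_len, ctx_len = 2, 16, 10
    hidden_states = torch.randn(batch, seq_len, args.embed_dim)
    encoder_hidden_states = torch.randn(batch, ctx_len, 768)
    timestep = torch.randint(0, 1000, (batch,))

    out = model(hidden_states, timestep, encoder_hidden_states)
    print(out.shape)  # expected: (batch, seq_len, args.hidden_dim)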