chore: import gr00t
This commit is contained in:
142
gr00t/models/dit.py
Normal file
142
gr00t/models/dit.py
Normal file
@@ -0,0 +1,142 @@
|
||||
from typing import Optional
|
||||
|
||||
from diffusers import ConfigMixin, ModelMixin
|
||||
from diffusers.configuration_utils import register_to_config
|
||||
from diffusers.models.embeddings import SinusoidalPositionalEmbedding, TimestepEmbedding, Timesteps
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
class TimestepEncoder(nn.Module):
    """Embed scalar diffusion timesteps into a dense conditioning vector.

    Timesteps are expanded through a fixed 256-channel sinusoidal basis and
    then mapped by a learned MLP to ``args.embed_dim`` dimensions.
    """

    def __init__(self, args):
        super().__init__()
        # Fixed sinusoidal features followed by a learned projection to the
        # model width (args.embed_dim).
        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=args.embed_dim)

    def forward(self, timesteps):
        # Cast the sinusoidal features to the module's parameter dtype so the
        # embedder runs in a consistent precision (e.g. under fp16/bf16).
        param_dtype = next(self.parameters()).dtype
        projected = self.time_proj(timesteps).to(param_dtype)
        return self.timestep_embedder(projected)  # (N, D)
|
||||
|
||||
|
||||
class AdaLayerNorm(nn.Module):
    """LayerNorm whose scale and shift are regressed from a conditioning embedding.

    The conditioning vector ``temb`` is mapped (SiLU -> Linear) to a pair of
    per-channel modulation vectors; the normalized input is then rescaled as
    ``norm(x) * (1 + scale) + shift``.
    """

    def __init__(self, embedding_dim, norm_eps=1e-5, norm_elementwise_affine=False):
        super().__init__()
        self.silu = nn.SiLU()
        # A single linear layer emits both scale and shift (2 * embedding_dim).
        self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
        self.norm = nn.LayerNorm(embedding_dim, norm_eps, norm_elementwise_affine)

    def forward(
        self,
        x: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        modulation = self.linear(self.silu(temb))
        scale, shift = modulation.chunk(2, dim=1)
        # Broadcast the (B, D) modulation over the sequence dimension of x.
        return self.norm(x) * (scale[:, None] + 1) + shift[:, None]
|
||||
|
||||
|
||||
class BasicTransformerBlock(nn.Module):
    """Single DiT block: attention (cross- or self-) followed by an MLP.

    The pre-attention normalization is timestep-conditioned (AdaLayerNorm);
    the pre-MLP normalization is a plain, affine-free LayerNorm. Both
    sub-layers use residual connections.

    NOTE(review): ``crosss_attention_dim`` (triple-s) is a typo, but it is
    referenced by keyword from callers, so the name is kept for compatibility.
    """

    def __init__(self, args, crosss_attention_dim, use_self_attn=False):
        super().__init__()
        dim = args.embed_dim

        self.norm1 = AdaLayerNorm(dim)

        # Self-attention keys/values share the query width; cross-attention
        # keys/values come from the context stream (crosss_attention_dim).
        attn_kwargs = dict(
            embed_dim=dim,
            num_heads=args.nheads,
            dropout=args.dropout,
            batch_first=True,
        )
        if not use_self_attn:
            attn_kwargs["kdim"] = crosss_attention_dim
            attn_kwargs["vdim"] = crosss_attention_dim
        self.attn = nn.MultiheadAttention(**attn_kwargs)

        self.norm2 = nn.LayerNorm(dim, eps=1e-5, elementwise_affine=False)
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * args.mlp_ratio),
            nn.GELU(),
            nn.Dropout(args.dropout),
            nn.Linear(dim * args.mlp_ratio, dim),
            nn.Dropout(args.dropout),
        )

    def forward(self, hidden_states, temb, context=None):
        # Attention sub-layer: queries are timestep-modulated; keys/values
        # come from the context when given, otherwise it self-attends.
        normed = self.norm1(hidden_states, temb)
        kv = normed if context is None else context
        hidden_states = hidden_states + self.attn(normed, kv, kv)[0]

        # Feed-forward sub-layer with residual connection.
        hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
        return hidden_states
|
||||
|
||||
class DiT(nn.Module):
    """Diffusion Transformer with alternating cross-/self-attention blocks.

    Even-indexed blocks cross-attend to ``encoder_hidden_states``;
    odd-indexed blocks self-attend. The output is modulated AdaLN-style by
    the timestep embedding and projected to ``args.hidden_dim``.
    """

    def __init__(self, args, cross_attention_dim):
        super().__init__()
        inner_dim = args.embed_dim

        self.timestep_encoder = TimestepEncoder(args)

        # Alternate cross-attention (even idx) and self-attention (odd idx).
        blocks = []
        for idx in range(args.num_layers):
            if idx % 2 == 1:
                blocks.append(BasicTransformerBlock(args, crosss_attention_dim=None, use_self_attn=True))
            else:
                blocks.append(BasicTransformerBlock(args, crosss_attention_dim=cross_attention_dim, use_self_attn=False))
        self.transformer_blocks = nn.ModuleList(blocks)

        # Output head: affine-free norm, AdaLN modulation, final projection.
        self.norm_out = nn.LayerNorm(inner_dim, eps=1e-6, elementwise_affine=False)
        self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
        self.proj_out_2 = nn.Linear(inner_dim, args.hidden_dim)

    def forward(self, hidden_states, timestep, encoder_hidden_states):
        temb = self.timestep_encoder(timestep)

        hidden_states = hidden_states.contiguous()
        encoder_hidden_states = encoder_hidden_states.contiguous()

        # Mirror the construction pattern: odd blocks self-attend, even
        # blocks cross-attend to the encoder stream.
        for idx, block in enumerate(self.transformer_blocks):
            ctx = None if idx % 2 == 1 else encoder_hidden_states
            hidden_states = block(hidden_states, temb, context=ctx)

        # Final AdaLN-style modulation, then project to the output width.
        shift, scale = self.proj_out_1(F.silu(temb)).chunk(2, dim=1)
        hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
        return self.proj_out_2(hidden_states)
|
||||
|
||||
|
||||
def build_dit(args, cross_attention_dim):
    """Factory: construct a DiT model from the config ``args``."""
    return DiT(args, cross_attention_dim=cross_attention_dim)
|
||||
Reference in New Issue
Block a user