chore: 导入gr00t

This commit is contained in:
gouhanke
2026-03-06 11:17:28 +08:00
parent 642d41dd8f
commit ca1716c67f
9 changed files with 922 additions and 166 deletions

142
gr00t/models/dit.py Normal file
View File

@@ -0,0 +1,142 @@
from typing import Optional
from diffusers import ConfigMixin, ModelMixin
from diffusers.configuration_utils import register_to_config
from diffusers.models.embeddings import SinusoidalPositionalEmbedding, TimestepEmbedding, Timesteps
import torch
from torch import nn
import torch.nn.functional as F
class TimestepEncoder(nn.Module):
    """Embed scalar diffusion timesteps into vectors of size ``args.embed_dim``.

    Pipeline: a fixed 256-channel sinusoidal projection followed by a small
    learned MLP (diffusers' ``TimestepEmbedding``) mapping 256 -> embed_dim.
    """

    def __init__(self, args):
        super().__init__()
        target_dim = args.embed_dim
        # Fixed sinusoidal features: 256 channels, cos-before-sin ordering.
        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=target_dim)

    def forward(self, timesteps):
        """Return a (N, embed_dim) embedding for the given timesteps."""
        # Match the sinusoidal features to the module's parameter dtype
        # (e.g. under fp16/bf16 training) before the learned projection.
        param_dtype = next(self.parameters()).dtype
        sinusoid = self.time_proj(timesteps).to(param_dtype)
        return self.timestep_embedder(sinusoid)
class AdaLayerNorm(nn.Module):
    """Adaptive LayerNorm: scale and shift are predicted from a conditioning
    embedding (here, the timestep embedding).

    ``temb`` is passed through SiLU + Linear to produce per-sample
    (scale, shift) pairs, which modulate an affine-free LayerNorm of ``x``.
    """

    def __init__(self, embedding_dim, norm_eps=1e-5, norm_elementwise_affine=False):
        super().__init__()
        self.silu = nn.SiLU()
        # Twice the width: one half for scale, one half for shift.
        self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
        self.norm = nn.LayerNorm(embedding_dim, norm_eps, norm_elementwise_affine)

    def forward(
        self,
        x: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # NOTE(review): temb is annotated Optional but is dereferenced
        # unconditionally — passing None raises. Confirm intended contract.
        scale_shift = self.linear(self.silu(temb))
        scale, shift = scale_shift.chunk(2, dim=1)
        # Broadcast the (N, D) modulation over x's sequence dimension.
        return self.norm(x) * (scale[:, None] + 1) + shift[:, None]
class BasicTransformerBlock(nn.Module):
    """Pre-norm transformer block with AdaLN timestep conditioning.

    With ``use_self_attn=False`` the attention layer cross-attends to an
    external context of width ``crosss_attention_dim`` (sic — the parameter
    name carries a typo, kept because callers pass it by keyword); with
    ``use_self_attn=True`` it self-attends over ``hidden_states``.
    """

    def __init__(self, args, crosss_attention_dim, use_self_attn=False):
        super().__init__()
        dim = args.embed_dim
        # Timestep-conditioned pre-norm before attention.
        self.norm1 = AdaLayerNorm(dim)
        attn_kwargs = dict(
            embed_dim=dim,
            num_heads=args.nheads,
            dropout=args.dropout,
            batch_first=True,
        )
        if not use_self_attn:
            # Cross-attention: keys/values are projected from the context width.
            attn_kwargs["kdim"] = crosss_attention_dim
            attn_kwargs["vdim"] = crosss_attention_dim
        self.attn = nn.MultiheadAttention(**attn_kwargs)
        # Plain (affine-free) pre-norm before the feed-forward MLP.
        self.norm2 = nn.LayerNorm(dim, eps=1e-5, elementwise_affine=False)
        hidden = dim * args.mlp_ratio
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Dropout(args.dropout),
            nn.Linear(hidden, dim),
            nn.Dropout(args.dropout),
        )

    def forward(self, hidden_states, temb, context=None):
        """Apply attention (cross if ``context`` given, else self) and the
        MLP, each wrapped in a residual connection."""
        normed = self.norm1(hidden_states, temb)
        kv = normed if context is None else context
        attn_out, _ = self.attn(normed, kv, kv)
        hidden_states = hidden_states + attn_out
        hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
        return hidden_states
class DiT(nn.Module):
    """Diffusion transformer conditioned on the timestep via AdaLN.

    The stack alternates block types: even-indexed blocks cross-attend to
    ``encoder_hidden_states``; odd-indexed blocks self-attend. The output
    head applies a final AdaLN-style modulation before projecting to
    ``args.hidden_dim``.
    """

    def __init__(self, args, cross_attention_dim):
        super().__init__()
        inner_dim = args.embed_dim
        self.timestep_encoder = TimestepEncoder(args)
        # Even index -> cross-attention block, odd index -> self-attention block.
        self.transformer_blocks = nn.ModuleList(
            BasicTransformerBlock(
                args,
                crosss_attention_dim=None if i % 2 == 1 else cross_attention_dim,
                use_self_attn=(i % 2 == 1),
            )
            for i in range(args.num_layers)
        )
        self.norm_out = nn.LayerNorm(inner_dim, eps=1e-6, elementwise_affine=False)
        # Output head: predicts (shift, scale) from the timestep embedding,
        # then projects to the final width.
        self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
        self.proj_out_2 = nn.Linear(inner_dim, args.hidden_dim)

    def forward(self, hidden_states, timestep, encoder_hidden_states):
        """Run the block stack and output head; returns a tensor whose last
        dimension is ``args.hidden_dim``."""
        temb = self.timestep_encoder(timestep)
        hidden_states = hidden_states.contiguous()
        encoder_hidden_states = encoder_hidden_states.contiguous()
        for i, blk in enumerate(self.transformer_blocks):
            if i % 2 == 1:
                # Self-attention block: no external context.
                hidden_states = blk(hidden_states, temb)
            else:
                hidden_states = blk(hidden_states, temb, context=encoder_hidden_states)
        # Final AdaLN-style modulation driven by the timestep embedding.
        shift, scale = self.proj_out_1(F.silu(temb)).chunk(2, dim=1)
        hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
        return self.proj_out_2(hidden_states)
def build_dit(args, cross_attention_dim):
    """Factory: construct a :class:`DiT` from the config object ``args`` and
    the context feature width ``cross_attention_dim``."""
    model = DiT(args, cross_attention_dim)
    return model