feat: introduce DiT into the architecture

gouhanke
2026-03-06 11:17:54 +08:00
parent ca1716c67f
commit 23088e5e33
4 changed files with 422 additions and 0 deletions


@@ -0,0 +1,146 @@
import torch
import torch.nn as nn
from types import SimpleNamespace
from typing import Optional, Union
from pathlib import Path
import importlib.util


def _load_gr00t_dit():
    """Load the gr00t DiT class from gr00t/models/dit.py, resolved relative to the repo root."""
    repo_root = Path(__file__).resolve().parents[4]
    dit_path = repo_root / "gr00t" / "models" / "dit.py"
    spec = importlib.util.spec_from_file_location("gr00t_dit_standalone", dit_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Unable to load DiT from {dit_path}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module.DiT


DiT = _load_gr00t_dit()


class Gr00tDiT1D(nn.Module):
    """
    Adapter that wraps gr00t DiT with the same call signature used by VLA heads.

    Expected forward interface:
        - sample: (B, T_action, input_dim)
        - timestep: (B,) or scalar diffusion timestep
        - cond: (B, T_obs, cond_dim)
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        horizon: int,
        n_obs_steps: int,
        cond_dim: int,
        n_layer: int = 8,
        n_head: int = 8,
        n_emb: int = 256,
        hidden_dim: int = 256,
        mlp_ratio: int = 4,
        dropout: float = 0.1,
        add_action_pos_emb: bool = True,
        add_cond_pos_emb: bool = True,
    ):
        super().__init__()
        if cond_dim <= 0:
            raise ValueError("Gr00tDiT1D requires cond_dim > 0.")
        self.horizon = horizon
        self.n_obs_steps = n_obs_steps

        # Project action and condition features into the DiT embedding space,
        # and map the DiT output back to the action dimension.
        self.input_proj = nn.Linear(input_dim, n_emb)
        self.cond_proj = nn.Linear(cond_dim, n_emb)
        self.output_proj = nn.Linear(hidden_dim, output_dim)

        # Optional learned positional embeddings over the action horizon and
        # the observation steps.
        self.action_pos_emb = (
            nn.Parameter(torch.zeros(1, horizon, n_emb))
            if add_action_pos_emb
            else None
        )
        self.cond_pos_emb = (
            nn.Parameter(torch.zeros(1, n_obs_steps, n_emb))
            if add_cond_pos_emb
            else None
        )

        args = SimpleNamespace(
            embed_dim=n_emb,
            nheads=n_head,
            mlp_ratio=mlp_ratio,
            dropout=dropout,
            num_layers=n_layer,
            hidden_dim=hidden_dim,
        )
        self.dit = DiT(args, cross_attention_dim=n_emb)
        self._init_weights()

    def _init_weights(self):
        if self.action_pos_emb is not None:
            nn.init.normal_(self.action_pos_emb, mean=0.0, std=0.02)
        if self.cond_pos_emb is not None:
            nn.init.normal_(self.cond_pos_emb, mean=0.0, std=0.02)

    def _normalize_timesteps(
        self,
        timestep: Union[torch.Tensor, float, int],
        batch_size: int,
        device: torch.device,
    ) -> torch.Tensor:
        """Coerce a scalar or tensor timestep into a (batch_size,) long tensor on `device`."""
        if not torch.is_tensor(timestep):
            timesteps = torch.tensor([timestep], device=device)
        else:
            timesteps = timestep.to(device)
        if timesteps.ndim == 0:
            timesteps = timesteps[None]
        if timesteps.shape[0] != batch_size:
            timesteps = timesteps.expand(batch_size)
        return timesteps.long()

    def forward(
        self,
        sample: torch.Tensor,
        timestep: Union[torch.Tensor, float, int],
        cond: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        if cond is None:
            raise ValueError("`cond` is required for Gr00tDiT1D forward.")

        bsz, t_act, _ = sample.shape
        if t_act > self.horizon:
            raise ValueError(
                f"sample length {t_act} exceeds configured horizon {self.horizon}"
            )

        # Embed the noisy action sequence and add positional information.
        hidden_states = self.input_proj(sample)
        if self.action_pos_emb is not None:
            hidden_states = hidden_states + self.action_pos_emb[:, :t_act, :]

        # Embed the observation/condition tokens consumed via cross-attention.
        encoder_hidden_states = self.cond_proj(cond)
        if self.cond_pos_emb is not None:
            t_obs = encoder_hidden_states.shape[1]
            if t_obs > self.n_obs_steps:
                raise ValueError(
                    f"cond length {t_obs} exceeds configured n_obs_steps {self.n_obs_steps}"
                )
            encoder_hidden_states = (
                encoder_hidden_states + self.cond_pos_emb[:, :t_obs, :]
            )

        timesteps = self._normalize_timesteps(
            timestep, batch_size=bsz, device=sample.device
        )
        dit_output = self.dit(
            hidden_states=hidden_states,
            timestep=timesteps,
            encoder_hidden_states=encoder_hidden_states,
        )
        return self.output_proj(dit_output)
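
A minimal usage sketch of the adapter follows. It assumes the upstream gr00t/models/dit.py accepts the SimpleNamespace fields and keyword arguments used above; the dimensions below are illustrative, not values taken from the repository.

import torch

# Illustrative configuration: 7-dim actions, 16-step horizon, 2 observation steps,
# and a 512-dim condition stream from an upstream encoder (values are hypothetical).
model = Gr00tDiT1D(
    input_dim=7,
    output_dim=7,
    horizon=16,
    n_obs_steps=2,
    cond_dim=512,
)

sample = torch.randn(4, 16, 7)           # (B, T_action, input_dim) noisy action sequence
cond = torch.randn(4, 2, 512)            # (B, T_obs, cond_dim) observation features
timestep = torch.randint(0, 1000, (4,))  # per-sample diffusion timestep

out = model(sample=sample, timestep=timestep, cond=cond)
print(out.shape)  # expected (4, 16, 7), assuming the DiT returns one hidden vector per action token

Note that output_proj consumes features of size hidden_dim, so the sketch keeps the default hidden_dim (256), which here equals n_emb.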