feat: introduce DiT into the architecture
roboimi/vla/models/heads/gr00t_dit1d.py (new file, 146 lines)
@@ -0,0 +1,146 @@
import importlib.util
from pathlib import Path
from types import SimpleNamespace
from typing import Optional, Union

import torch
import torch.nn as nn


def _load_gr00t_dit():
    # Load the standalone DiT implementation straight from the gr00t checkout
    # at <repo_root>/gr00t/models/dit.py, without importing the gr00t package.
    repo_root = Path(__file__).resolve().parents[4]
    dit_path = repo_root / "gr00t" / "models" / "dit.py"
    spec = importlib.util.spec_from_file_location("gr00t_dit_standalone", dit_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Unable to load DiT from {dit_path}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module.DiT


DiT = _load_gr00t_dit()


class Gr00tDiT1D(nn.Module):
    """
    Adapter that wraps the gr00t DiT with the same call signature used by VLA heads.

    Expected forward interface:
    - sample: (B, T_action, input_dim) noisy action sequence
    - timestep: (B,) tensor or scalar diffusion timestep
    - cond: (B, T_obs, cond_dim) conditioning features
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        horizon: int,
        n_obs_steps: int,
        cond_dim: int,
        n_layer: int = 8,
        n_head: int = 8,
        n_emb: int = 256,
        hidden_dim: int = 256,
        mlp_ratio: int = 4,
        dropout: float = 0.1,
        add_action_pos_emb: bool = True,
        add_cond_pos_emb: bool = True,
    ):
        super().__init__()
        if cond_dim <= 0:
            raise ValueError("Gr00tDiT1D requires cond_dim > 0.")

        self.horizon = horizon
        self.n_obs_steps = n_obs_steps

        # Project actions and conditioning into the DiT embedding space; the
        # DiT output (hidden_dim) is projected back to the action dimension.
        self.input_proj = nn.Linear(input_dim, n_emb)
        self.cond_proj = nn.Linear(cond_dim, n_emb)
        self.output_proj = nn.Linear(hidden_dim, output_dim)

        # Optional learned positional embeddings over action and observation steps.
        self.action_pos_emb = (
            nn.Parameter(torch.zeros(1, horizon, n_emb))
            if add_action_pos_emb
            else None
        )
        self.cond_pos_emb = (
            nn.Parameter(torch.zeros(1, n_obs_steps, n_emb))
            if add_cond_pos_emb
            else None
        )

        # The standalone gr00t DiT takes its hyperparameters as an
        # attribute-style config object.
        args = SimpleNamespace(
            embed_dim=n_emb,
            nheads=n_head,
            mlp_ratio=mlp_ratio,
            dropout=dropout,
            num_layers=n_layer,
            hidden_dim=hidden_dim,
        )
        self.dit = DiT(args, cross_attention_dim=n_emb)

        self._init_weights()

    def _init_weights(self):
        if self.action_pos_emb is not None:
            nn.init.normal_(self.action_pos_emb, mean=0.0, std=0.02)
        if self.cond_pos_emb is not None:
            nn.init.normal_(self.cond_pos_emb, mean=0.0, std=0.02)

    def _normalize_timesteps(
        self,
        timestep: Union[torch.Tensor, float, int],
        batch_size: int,
        device: torch.device,
    ) -> torch.Tensor:
        # Accept a Python scalar, a 0-d tensor, or a 1-element tensor and
        # broadcast it to one timestep per batch element,
        # e.g. 10 with batch_size=4 -> tensor([10, 10, 10, 10]).
        if not torch.is_tensor(timestep):
            timesteps = torch.tensor([timestep], device=device)
        else:
            timesteps = timestep.to(device)

        if timesteps.ndim == 0:
            timesteps = timesteps[None]
        if timesteps.shape[0] != batch_size:
            timesteps = timesteps.expand(batch_size)

        return timesteps.long()

    def forward(
        self,
        sample: torch.Tensor,
        timestep: Union[torch.Tensor, float, int],
        cond: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        if cond is None:
            raise ValueError("`cond` is required for Gr00tDiT1D forward.")

        bsz, t_act, _ = sample.shape
        if t_act > self.horizon:
            raise ValueError(
                f"sample length {t_act} exceeds configured horizon {self.horizon}"
            )

        hidden_states = self.input_proj(sample)
        if self.action_pos_emb is not None:
            hidden_states = hidden_states + self.action_pos_emb[:, :t_act, :]

        encoder_hidden_states = self.cond_proj(cond)
        if self.cond_pos_emb is not None:
            t_obs = encoder_hidden_states.shape[1]
            if t_obs > self.n_obs_steps:
                raise ValueError(
                    f"cond length {t_obs} exceeds configured n_obs_steps {self.n_obs_steps}"
                )
            encoder_hidden_states = (
                encoder_hidden_states + self.cond_pos_emb[:, :t_obs, :]
            )

        timesteps = self._normalize_timesteps(
            timestep, batch_size=bsz, device=sample.device
        )
        dit_output = self.dit(
            hidden_states=hidden_states,
            timestep=timesteps,
            encoder_hidden_states=encoder_hidden_states,
        )
        return self.output_proj(dit_output)
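
A minimal usage sketch (tensor sizes are illustrative; it assumes the gr00t checkout is present so `_load_gr00t_dit()` resolves, and that the gr00t DiT preserves the sequence length while returning hidden_dim features):

    head = Gr00tDiT1D(
        input_dim=7, output_dim=7, horizon=16,
        n_obs_steps=2, cond_dim=512,
    )
    sample = torch.randn(4, 16, 7)    # (B, T_action, input_dim)
    cond = torch.randn(4, 2, 512)     # (B, T_obs, cond_dim)
    out = head(sample, timestep=10, cond=cond)
    print(out.shape)                  # expected: torch.Size([4, 16, 7])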