feat: introduce DiT into the architecture

gouhanke
2026-03-06 11:17:54 +08:00
parent ca1716c67f
commit 23088e5e33
4 changed files with 422 additions and 0 deletions

View File

@@ -0,0 +1,217 @@
import torch
import torch.nn as nn
from collections import deque
from typing import Dict
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from roboimi.vla.models.normalization import NormalizationModule
class VLAAgentGr00tDiT(nn.Module):
"""
    VLA Agent variant that swaps the Transformer1D head for the gr00t DiT head.
    Other components (backbone, encoders, schedulers, queue logic) stay aligned
    with the existing VLAAgent implementation.
"""
def __init__(
self,
vision_backbone,
state_encoder,
action_encoder,
head,
action_dim,
obs_dim,
pred_horizon=16,
obs_horizon=4,
diffusion_steps=100,
inference_steps=10,
num_cams=3,
dataset_stats=None,
normalization_type="min_max",
num_action_steps=8,
):
super().__init__()
self.action_dim = action_dim
self.obs_dim = obs_dim
self.pred_horizon = pred_horizon
self.obs_horizon = obs_horizon
self.num_cams = num_cams
self.num_action_steps = num_action_steps
self.inference_steps = inference_steps
self.normalization = NormalizationModule(
stats=dataset_stats,
normalization_type=normalization_type,
)
self.vision_encoder = vision_backbone
single_cam_feat_dim = self.vision_encoder.output_dim
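        # Conditioning per observation step: all camera features concatenated with the state.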
self.per_step_cond_dim = single_cam_feat_dim * num_cams + obs_dim
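        # DDPM schedule for training; DDIM for fast sampling at inference time.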
self.noise_scheduler = DDPMScheduler(
num_train_timesteps=diffusion_steps,
beta_schedule="squaredcos_cap_v2",
clip_sample=True,
prediction_type="epsilon",
)
self.infer_scheduler = DDIMScheduler(
num_train_timesteps=diffusion_steps,
beta_schedule="squaredcos_cap_v2",
clip_sample=True,
prediction_type="epsilon",
)
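        # Accept either a ready-built head module or a factory/partial completed here
        # with the agent's dimensions.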
if isinstance(head, nn.Module):
self.noise_pred_net = head
else:
self.noise_pred_net = head(
input_dim=action_dim,
output_dim=action_dim,
horizon=pred_horizon,
n_obs_steps=obs_horizon,
cond_dim=self.per_step_cond_dim,
)
self.state_encoder = state_encoder
self.action_encoder = action_encoder
self.reset()
def _get_model_device(self) -> torch.device:
return next(self.parameters()).device
def _move_to_device(self, data, device: torch.device):
if torch.is_tensor(data):
return data.to(device)
if isinstance(data, dict):
return {k: self._move_to_device(v, device) for k, v in data.items()}
if isinstance(data, list):
return [self._move_to_device(v, device) for v in data]
if isinstance(data, tuple):
return tuple(self._move_to_device(v, device) for v in data)
return data
def _build_cond(self, images: Dict[str, torch.Tensor], states: torch.Tensor) -> torch.Tensor:
visual_features = self.vision_encoder(images)
state_features = self.state_encoder(states)
return torch.cat([visual_features, state_features], dim=-1)
def compute_loss(self, batch):
actions, states, images = batch["action"], batch["qpos"], batch["images"]
action_is_pad = batch.get("action_is_pad", None)
bsz = actions.shape[0]
states = self.normalization.normalize_qpos(states)
actions = self.normalization.normalize_action(actions)
action_features = self.action_encoder(actions)
cond = self._build_cond(images, states)
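        # DDPM training step: sample Gaussian noise and a random timestep per sample,
        # then corrupt the encoded actions with the forward process.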
noise = torch.randn_like(action_features)
timesteps = torch.randint(
0,
self.noise_scheduler.config.num_train_timesteps,
(bsz,),
device=action_features.device,
).long()
noisy_actions = self.noise_scheduler.add_noise(action_features, noise, timesteps)
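        # Predict the injected noise (epsilon parameterization) from the noisy actions
        # and the observation conditioning.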
pred_noise = self.noise_pred_net(
sample=noisy_actions,
timestep=timesteps,
cond=cond,
)
loss = nn.functional.mse_loss(pred_noise, noise, reduction="none")
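        # Mask padded action steps out of the loss; clamp_min guards against an
        # all-padding batch.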
if action_is_pad is not None:
mask = (~action_is_pad).unsqueeze(-1).to(loss.dtype)
valid_count = mask.sum() * loss.shape[-1]
loss = (loss * mask).sum() / valid_count.clamp_min(1.0)
else:
loss = loss.mean()
return loss
def reset(self):
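        # Rolling buffers: observation queues keep the last obs_horizon steps; the
        # action queue holds not-yet-executed steps of the last predicted chunk.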
self._queues = {
"qpos": deque(maxlen=self.obs_horizon),
"images": deque(maxlen=self.obs_horizon),
"action": deque(maxlen=self.pred_horizon - self.obs_horizon + 1),
}
def _populate_queues(self, observation: Dict[str, torch.Tensor]) -> None:
if "qpos" in observation:
self._queues["qpos"].append(observation["qpos"].clone())
if "images" in observation:
self._queues["images"].append({k: v.clone() for k, v in observation["images"].items()})
def _prepare_observation_batch(self) -> Dict[str, torch.Tensor]:
qpos_list = list(self._queues["qpos"])
if len(qpos_list) == 0:
raise ValueError("observation queue is empty.")
while len(qpos_list) < self.obs_horizon:
qpos_list.append(qpos_list[-1])
batch_qpos = torch.stack(qpos_list, dim=0).unsqueeze(0)
images_list = list(self._queues["images"])
if len(images_list) == 0:
raise ValueError("image queue is empty.")
while len(images_list) < self.obs_horizon:
images_list.append(images_list[-1])
batch_images = {}
for cam_name in images_list[0].keys():
batch_images[cam_name] = torch.stack(
[img[cam_name] for img in images_list], dim=0
).unsqueeze(0)
return {"qpos": batch_qpos, "images": batch_images}
@torch.no_grad()
def select_action(self, observation: Dict[str, torch.Tensor]) -> torch.Tensor:
device = self._get_model_device()
observation = self._move_to_device(observation, device)
self._populate_queues(observation)
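        # Receding-horizon control: denoise a fresh action chunk only when the queue
        # is empty, then execute num_action_steps of it one step at a time.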
if len(self._queues["action"]) == 0:
batch = self._prepare_observation_batch()
actions = self.predict_action_chunk(batch)
start = self.obs_horizon - 1
end = start + self.num_action_steps
executable_actions = actions[:, start:end]
for i in range(executable_actions.shape[1]):
self._queues["action"].append(executable_actions[:, i].squeeze(0))
return self._queues["action"].popleft()
@torch.no_grad()
def predict_action_chunk(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
return self.predict_action(batch["images"], batch["qpos"])
@torch.no_grad()
def predict_action(self, images, proprioception):
bsz = proprioception.shape[0]
proprioception = self.normalization.normalize_qpos(proprioception)
cond = self._build_cond(images, proprioception)
device = cond.device
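        # DDIM reverse process: start from Gaussian noise and iteratively denoise
        # over inference_steps steps.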
current_actions = torch.randn((bsz, self.pred_horizon, self.action_dim), device=device)
self.infer_scheduler.set_timesteps(self.inference_steps)
for t in self.infer_scheduler.timesteps:
noise_pred = self.noise_pred_net(
sample=current_actions,
timestep=t,
cond=cond,
)
current_actions = self.infer_scheduler.step(
noise_pred, t, current_actions
).prev_sample
return self.normalization.denormalize_action(current_actions)
def get_normalization_stats(self):
return self.normalization.get_stats()
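
For orientation, a minimal usage sketch of this agent (not part of the commit): the make_* builders, load_dataset_stats, environment_stream, camera names, and tensor shapes below are illustrative assumptions; in practice the components are assembled from the hydra config that follows.

import torch

# Hypothetical builders standing in for the hydra-instantiated components.
agent = VLAAgentGr00tDiT(
    vision_backbone=make_vision_backbone(),  # assumed to expose .output_dim
    state_encoder=make_state_encoder(),      # identity module in the shipped config
    action_encoder=make_action_encoder(),    # identity module in the shipped config
    head=make_head_partial(),                # e.g. a Gr00tDiT1D partial, completed in __init__
    action_dim=16,
    obs_dim=16,
    obs_horizon=2,
    dataset_stats=load_dataset_stats(),      # per-dimension stats for NormalizationModule
)

# Training: one diffusion loss on a dummy batch (B=8).
batch = {
    "action": torch.randn(8, 16, 16),  # (B, pred_horizon, action_dim)
    "qpos": torch.randn(8, 2, 16),     # (B, obs_horizon, obs_dim)
    "images": {f"cam_{i}": torch.randn(8, 2, 3, 224, 224) for i in range(3)},
    "action_is_pad": torch.zeros(8, 16, dtype=torch.bool),
}
agent.compute_loss(batch).backward()

# Rollout: a new chunk is denoised only when the action queue runs dry,
# so the DiT is invoked once every num_action_steps environment steps.
agent.reset()
for observation in environment_stream():       # hypothetical per-step observation dicts
    action = agent.select_action(observation)  # returns a single (action_dim,) action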

View File

@@ -0,0 +1,37 @@
# @package agent
defaults:
- /backbone@vision_backbone: resnet_diffusion
- /modules@state_encoder: identity_state_encoder
- /modules@action_encoder: identity_action_encoder
- /head: gr00t_dit1d
- _self_
_target_: roboimi.vla.agent_gr00t_dit.VLAAgentGr00tDiT
# Model dimensions
action_dim: 16
obs_dim: 16
# Normalization
normalization_type: "min_max"
# Horizons
pred_horizon: 16
obs_horizon: 2
num_action_steps: 8
# Cameras
num_cams: 3
# Diffusion
diffusion_steps: 100
inference_steps: 10
# Head overrides
head:
input_dim: ${agent.action_dim}
output_dim: ${agent.action_dim}
horizon: ${agent.pred_horizon}
n_obs_steps: ${agent.obs_horizon}
cond_dim: 208
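
A sketch of how a config like this is consumed, assuming a standard hydra entry point (the config_path and config_name below are illustrative): the `# @package agent` directive places this file under `cfg.agent`.

import hydra
from hydra.utils import instantiate

@hydra.main(config_path="configs", config_name="train", version_base=None)
def main(cfg):
    # The defaults list pulls in the backbone, encoders, and head; the
    # ${agent.*} interpolations resolve against the keys above; instantiate()
    # then builds VLAAgentGr00tDiT recursively from _target_.
    agent = instantiate(cfg.agent)

if __name__ == "__main__":
    main()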

View File

@@ -0,0 +1,22 @@
_target_: roboimi.vla.models.heads.gr00t_dit1d.Gr00tDiT1D
_partial_: true
# DiT architecture
n_layer: 6
n_head: 8
n_emb: 256
hidden_dim: 256
mlp_ratio: 4
dropout: 0.1
# Positional embeddings
add_action_pos_emb: true
add_cond_pos_emb: true
# Supplied by the agent config via interpolation:
# - input_dim
# - output_dim
# - horizon
# - n_obs_steps
# - cond_dim
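
Because of `_partial_: true`, instantiating this config yields a functools.partial rather than a module; the agent (or the `head:` overrides in the agent config above) supplies the remaining dimensions. A minimal sketch, with an illustrative file path:

from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/head/gr00t_dit1d.yaml")  # illustrative path
head_partial = instantiate(cfg)  # functools.partial(Gr00tDiT1D, n_layer=6, ...)

# VLAAgentGr00tDiT.__init__ completes the partial; call-time keywords
# override any values already bound on the partial.
head = head_partial(input_dim=16, output_dim=16, horizon=16, n_obs_steps=2, cond_dim=208)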

View File

@@ -0,0 +1,146 @@
import torch
import torch.nn as nn
from types import SimpleNamespace
from typing import Optional, Union
from pathlib import Path
import importlib.util
def _load_gr00t_dit():
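    # Import the gr00t DiT directly from its source file (expected at
    # <repo_root>/gr00t/models/dit.py) so this repo need not install gr00t as a package.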
repo_root = Path(__file__).resolve().parents[4]
dit_path = repo_root / "gr00t" / "models" / "dit.py"
spec = importlib.util.spec_from_file_location("gr00t_dit_standalone", dit_path)
if spec is None or spec.loader is None:
raise ImportError(f"Unable to load DiT from {dit_path}")
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
return module.DiT
DiT = _load_gr00t_dit()
class Gr00tDiT1D(nn.Module):
"""
Adapter that wraps gr00t DiT with the same call signature used by VLA heads.
Expected forward interface:
- sample: (B, T_action, input_dim)
- timestep: (B,) or scalar diffusion timestep
- cond: (B, T_obs, cond_dim)
"""
def __init__(
self,
input_dim: int,
output_dim: int,
horizon: int,
n_obs_steps: int,
cond_dim: int,
n_layer: int = 8,
n_head: int = 8,
n_emb: int = 256,
hidden_dim: int = 256,
mlp_ratio: int = 4,
dropout: float = 0.1,
add_action_pos_emb: bool = True,
add_cond_pos_emb: bool = True,
):
super().__init__()
if cond_dim <= 0:
raise ValueError("Gr00tDiT1D requires cond_dim > 0.")
self.horizon = horizon
self.n_obs_steps = n_obs_steps
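        # Project actions and conditioning into the DiT width; map DiT output back to actions.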
self.input_proj = nn.Linear(input_dim, n_emb)
self.cond_proj = nn.Linear(cond_dim, n_emb)
self.output_proj = nn.Linear(hidden_dim, output_dim)
self.action_pos_emb = (
nn.Parameter(torch.zeros(1, horizon, n_emb))
if add_action_pos_emb
else None
)
self.cond_pos_emb = (
nn.Parameter(torch.zeros(1, n_obs_steps, n_emb))
if add_cond_pos_emb
else None
)
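        # The vendored gr00t DiT reads its hyper-parameters from an args namespace.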
args = SimpleNamespace(
embed_dim=n_emb,
nheads=n_head,
mlp_ratio=mlp_ratio,
dropout=dropout,
num_layers=n_layer,
hidden_dim=hidden_dim,
)
self.dit = DiT(args, cross_attention_dim=n_emb)
self._init_weights()
def _init_weights(self):
if self.action_pos_emb is not None:
nn.init.normal_(self.action_pos_emb, mean=0.0, std=0.02)
if self.cond_pos_emb is not None:
nn.init.normal_(self.cond_pos_emb, mean=0.0, std=0.02)
def _normalize_timesteps(
self,
timestep: Union[torch.Tensor, float, int],
batch_size: int,
device: torch.device,
) -> torch.Tensor:
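        # Accept a scalar, 0-d, or length-1 tensor timestep and broadcast it to (batch_size,).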
if not torch.is_tensor(timestep):
timesteps = torch.tensor([timestep], device=device)
else:
timesteps = timestep.to(device)
if timesteps.ndim == 0:
timesteps = timesteps[None]
if timesteps.shape[0] != batch_size:
timesteps = timesteps.expand(batch_size)
return timesteps.long()
def forward(
self,
sample: torch.Tensor,
timestep: Union[torch.Tensor, float, int],
cond: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
if cond is None:
raise ValueError("`cond` is required for Gr00tDiT1D forward.")
bsz, t_act, _ = sample.shape
if t_act > self.horizon:
raise ValueError(
f"sample length {t_act} exceeds configured horizon {self.horizon}"
)
hidden_states = self.input_proj(sample)
if self.action_pos_emb is not None:
hidden_states = hidden_states + self.action_pos_emb[:, :t_act, :]
encoder_hidden_states = self.cond_proj(cond)
if self.cond_pos_emb is not None:
t_obs = encoder_hidden_states.shape[1]
if t_obs > self.n_obs_steps:
raise ValueError(
f"cond length {t_obs} exceeds configured n_obs_steps {self.n_obs_steps}"
)
encoder_hidden_states = (
encoder_hidden_states + self.cond_pos_emb[:, :t_obs, :]
)
timesteps = self._normalize_timesteps(
timestep, batch_size=bsz, device=sample.device
)
dit_output = self.dit(
hidden_states=hidden_states,
timestep=timesteps,
encoder_hidden_states=encoder_hidden_states,
)
return self.output_proj(dit_output)
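
A shape-level smoke test for the adapter (it requires the sibling gr00t checkout so the dynamic DiT import above succeeds). The batch size and values are arbitrary; the dimensions match the shipped head config, and the final assert holds under the adapter's assumption that the vendored DiT returns per-token features of width hidden_dim.

import torch

head = Gr00tDiT1D(
    input_dim=16, output_dim=16, horizon=16, n_obs_steps=2,
    cond_dim=208, n_layer=6, n_head=8, n_emb=256, hidden_dim=256,
)
sample = torch.randn(2, 16, 16)  # noisy action chunk: (B, T_action, input_dim)
cond = torch.randn(2, 2, 208)    # per-step conditioning: (B, T_obs, cond_dim)
t = torch.randint(0, 100, (2,))  # one diffusion timestep per batch element
eps = head(sample=sample, timestep=t, cond=cond)
assert eps.shape == (2, 16, 16)  # epsilon prediction, same shape as the sample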