feat: 架构引入DiT (feat: introduce the DiT architecture)
This commit is contained in:
217
roboimi/vla/agent_gr00t_dit.py
Normal file
217
roboimi/vla/agent_gr00t_dit.py
Normal file
@@ -0,0 +1,217 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from collections import deque
|
||||
from typing import Dict
|
||||
|
||||
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
||||
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
|
||||
|
||||
from roboimi.vla.models.normalization import NormalizationModule
|
||||
|
||||
|
||||
class VLAAgentGr00tDiT(nn.Module):
    """
    VLA Agent variant that swaps Transformer1D head with gr00t DiT head.
    Other components (backbone/encoders/scheduler/queue logic) stay aligned
    with the existing VLAAgent implementation.

    Training denoises with a DDPM scheduler (epsilon prediction); inference
    samples with a DDIM scheduler using fewer steps (`inference_steps`).
    Online control follows a receding-horizon scheme: a chunk of
    `pred_horizon` actions is predicted, `num_action_steps` of them are
    queued for execution, then the model re-plans.
    """

    def __init__(
        self,
        vision_backbone,
        state_encoder,
        action_encoder,
        head,
        action_dim,
        obs_dim,
        pred_horizon=16,
        obs_horizon=4,
        diffusion_steps=100,
        inference_steps=10,
        num_cams=3,
        dataset_stats=None,
        normalization_type="min_max",
        num_action_steps=8,
    ):
        """
        Args:
            vision_backbone: image encoder; must expose `output_dim`, the
                per-camera feature dimension.
            state_encoder: encodes proprioceptive state ("qpos").
            action_encoder: encodes action sequences before noise is added.
            head: either an already-built noise-prediction nn.Module, or a
                callable/partial that constructs one from the dims below.
            action_dim: dimensionality of a single action vector.
            obs_dim: dimensionality of the proprioceptive state.
            pred_horizon: number of action steps predicted per chunk.
            obs_horizon: number of past observations conditioned on.
            diffusion_steps: number of training diffusion timesteps.
            inference_steps: number of DDIM denoising steps at inference.
            num_cams: number of camera streams concatenated into the condition.
            dataset_stats: statistics consumed by NormalizationModule.
            normalization_type: normalization scheme, e.g. "min_max".
            num_action_steps: actions executed from each predicted chunk
                before re-planning.
        """
        super().__init__()
        self.action_dim = action_dim
        self.obs_dim = obs_dim
        self.pred_horizon = pred_horizon
        self.obs_horizon = obs_horizon
        self.num_cams = num_cams
        self.num_action_steps = num_action_steps
        self.inference_steps = inference_steps

        self.normalization = NormalizationModule(
            stats=dataset_stats,
            normalization_type=normalization_type,
        )

        self.vision_encoder = vision_backbone
        single_cam_feat_dim = self.vision_encoder.output_dim
        # Per-step conditioning vector: all camera features concatenated,
        # plus the proprioceptive state.
        self.per_step_cond_dim = single_cam_feat_dim * num_cams + obs_dim

        # Training scheduler. clip_sample=True clamps predicted samples,
        # which presumes actions are normalized into the scheduler's clip
        # range (diffusers default [-1, 1]) — consistent with normalizing
        # actions before noising in compute_loss().
        self.noise_scheduler = DDPMScheduler(
            num_train_timesteps=diffusion_steps,
            beta_schedule="squaredcos_cap_v2",
            clip_sample=True,
            prediction_type="epsilon",
        )
        # Inference scheduler: DDIM shares the same noise schedule but
        # allows far fewer denoising steps than DDPM training steps.
        self.infer_scheduler = DDIMScheduler(
            num_train_timesteps=diffusion_steps,
            beta_schedule="squaredcos_cap_v2",
            clip_sample=True,
            prediction_type="epsilon",
        )

        # `head` may arrive pre-instantiated (plain module) or as a factory
        # (e.g. a Hydra partial) that still needs the model dimensions.
        if isinstance(head, nn.Module):
            self.noise_pred_net = head
        else:
            self.noise_pred_net = head(
                input_dim=action_dim,
                output_dim=action_dim,
                horizon=pred_horizon,
                n_obs_steps=obs_horizon,
                cond_dim=self.per_step_cond_dim,
            )

        self.state_encoder = state_encoder
        self.action_encoder = action_encoder
        self.reset()

    def _get_model_device(self) -> torch.device:
        """Return the device of this module's (first) parameter."""
        return next(self.parameters()).device

    def _move_to_device(self, data, device: torch.device):
        """Recursively move tensors in nested dict/list/tuple containers to `device`.

        Non-tensor leaves are returned unchanged.
        """
        if torch.is_tensor(data):
            return data.to(device)
        if isinstance(data, dict):
            return {k: self._move_to_device(v, device) for k, v in data.items()}
        if isinstance(data, list):
            return [self._move_to_device(v, device) for v in data]
        if isinstance(data, tuple):
            return tuple(self._move_to_device(v, device) for v in data)
        return data

    def _build_cond(self, images: Dict[str, torch.Tensor], states: torch.Tensor) -> torch.Tensor:
        """Build the conditioning tensor: visual features ++ state features.

        Concatenates along the last dim, so the two encoders must produce
        matching leading dims (presumably batch [, obs steps] — confirm
        against the encoder implementations).
        """
        visual_features = self.vision_encoder(images)
        state_features = self.state_encoder(states)
        return torch.cat([visual_features, state_features], dim=-1)

    def compute_loss(self, batch):
        """Compute the epsilon-prediction diffusion training loss.

        Expects `batch` with keys "action", "qpos", "images", and optionally
        "action_is_pad" (bool mask of padded action steps to exclude from
        the loss).

        Returns:
            Scalar MSE loss between predicted and true noise, averaged over
            non-padded elements.
        """
        actions, states, images = batch["action"], batch["qpos"], batch["images"]
        action_is_pad = batch.get("action_is_pad", None)
        bsz = actions.shape[0]

        # Normalize inputs/targets before encoding and noising.
        states = self.normalization.normalize_qpos(states)
        actions = self.normalization.normalize_action(actions)

        action_features = self.action_encoder(actions)
        cond = self._build_cond(images, states)

        # Standard DDPM training: add noise at a random timestep per sample,
        # then predict that noise.
        noise = torch.randn_like(action_features)
        timesteps = torch.randint(
            0,
            self.noise_scheduler.config.num_train_timesteps,
            (bsz,),
            device=action_features.device,
        ).long()
        noisy_actions = self.noise_scheduler.add_noise(action_features, noise, timesteps)

        pred_noise = self.noise_pred_net(
            sample=noisy_actions,
            timestep=timesteps,
            cond=cond,
        )
        loss = nn.functional.mse_loss(pred_noise, noise, reduction="none")

        if action_is_pad is not None:
            # Zero out padded steps; normalize by the count of valid scalar
            # elements (mask entries x feature dim), guarded against zero.
            mask = (~action_is_pad).unsqueeze(-1).to(loss.dtype)
            valid_count = mask.sum() * loss.shape[-1]
            loss = (loss * mask).sum() / valid_count.clamp_min(1.0)
        else:
            loss = loss.mean()

        return loss

    def reset(self):
        """Reset the rollout queues; call at the start of every episode."""
        self._queues = {
            "qpos": deque(maxlen=self.obs_horizon),
            "images": deque(maxlen=self.obs_horizon),
            "action": deque(maxlen=self.pred_horizon - self.obs_horizon + 1),
        }

    def _populate_queues(self, observation: Dict[str, torch.Tensor]) -> None:
        """Append the latest observation to the qpos/images history queues.

        Tensors are cloned so later in-place edits by the caller cannot
        corrupt the queued history.
        """
        if "qpos" in observation:
            self._queues["qpos"].append(observation["qpos"].clone())
        if "images" in observation:
            self._queues["images"].append({k: v.clone() for k, v in observation["images"].items()})

    def _prepare_observation_batch(self) -> Dict[str, torch.Tensor]:
        """Stack queued observations into a batch of size 1.

        If fewer than `obs_horizon` observations have been seen (episode
        start), the most recent entry is repeated to pad the history.

        Returns:
            Dict with "qpos" of shape (1, obs_horizon, obs_dim) and "images"
            mapping camera name -> (1, obs_horizon, ...) tensors.

        Raises:
            ValueError: if either queue is empty.
        """
        qpos_list = list(self._queues["qpos"])
        if len(qpos_list) == 0:
            raise ValueError("observation queue is empty.")
        while len(qpos_list) < self.obs_horizon:
            qpos_list.append(qpos_list[-1])
        batch_qpos = torch.stack(qpos_list, dim=0).unsqueeze(0)

        images_list = list(self._queues["images"])
        if len(images_list) == 0:
            raise ValueError("image queue is empty.")
        while len(images_list) < self.obs_horizon:
            images_list.append(images_list[-1])

        batch_images = {}
        for cam_name in images_list[0].keys():
            batch_images[cam_name] = torch.stack(
                [img[cam_name] for img in images_list], dim=0
            ).unsqueeze(0)

        return {"qpos": batch_qpos, "images": batch_images}

    @torch.no_grad()
    def select_action(self, observation: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Return the next action for online control (receding horizon).

        Re-plans only when the action queue is exhausted; otherwise pops the
        next queued action.
        """
        device = self._get_model_device()
        observation = self._move_to_device(observation, device)
        self._populate_queues(observation)

        if len(self._queues["action"]) == 0:
            batch = self._prepare_observation_batch()
            actions = self.predict_action_chunk(batch)
            # Skip the first obs_horizon-1 predicted steps — presumably they
            # overlap the already-observed past (diffusion-policy-style
            # alignment); TODO confirm against the head's horizon convention.
            start = self.obs_horizon - 1
            end = start + self.num_action_steps
            executable_actions = actions[:, start:end]
            for i in range(executable_actions.shape[1]):
                # squeeze(0) drops the batch dim (batch size is 1 here).
                self._queues["action"].append(executable_actions[:, i].squeeze(0))

        return self._queues["action"].popleft()

    @torch.no_grad()
    def predict_action_chunk(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Predict a full action chunk from a prepared observation batch."""
        return self.predict_action(batch["images"], batch["qpos"])

    @torch.no_grad()
    def predict_action(self, images, proprioception):
        """Sample an action chunk via DDIM denoising from Gaussian noise.

        Args:
            images: camera name -> image tensor dict (raw, un-normalized).
            proprioception: state tensor; batch size is taken from dim 0.

        Returns:
            De-normalized actions of shape (bsz, pred_horizon, action_dim).
        """
        bsz = proprioception.shape[0]
        proprioception = self.normalization.normalize_qpos(proprioception)
        cond = self._build_cond(images, proprioception)

        device = cond.device
        # Start from pure noise and iteratively denoise.
        current_actions = torch.randn((bsz, self.pred_horizon, self.action_dim), device=device)
        self.infer_scheduler.set_timesteps(self.inference_steps)

        for t in self.infer_scheduler.timesteps:
            noise_pred = self.noise_pred_net(
                sample=current_actions,
                timestep=t,
                cond=cond,
            )
            current_actions = self.infer_scheduler.step(
                noise_pred, t, current_actions
            ).prev_sample

        return self.normalization.denormalize_action(current_actions)

    def get_normalization_stats(self):
        """Expose the normalization stats (e.g. for checkpointing/export)."""
        return self.normalization.get_stats()
|
||||
|
||||
37
roboimi/vla/conf/agent/resnet_gr00t_dit.yaml
Normal file
37
roboimi/vla/conf/agent/resnet_gr00t_dit.yaml
Normal file
@@ -0,0 +1,37 @@
|
||||
# @package agent
# Hydra config for the VLA agent variant with the gr00t DiT head.
defaults:
  - /backbone@vision_backbone: resnet_diffusion
  - /modules@state_encoder: identity_state_encoder
  - /modules@action_encoder: identity_action_encoder
  - /head: gr00t_dit1d
  - _self_

_target_: roboimi.vla.agent_gr00t_dit.VLAAgentGr00tDiT

# Model dimensions
action_dim: 16
obs_dim: 16

# Normalization
normalization_type: "min_max"

# Horizons
pred_horizon: 16
obs_horizon: 2
num_action_steps: 8

# Cameras
num_cams: 3

# Diffusion
diffusion_steps: 100
inference_steps: 10

# Head overrides
# NOTE: the agent also passes these dims when instantiating a partial head;
# the values here must stay consistent with the agent fields above.
head:
  input_dim: ${agent.action_dim}
  output_dim: ${agent.action_dim}
  horizon: ${agent.pred_horizon}
  n_obs_steps: ${agent.obs_horizon}
  # Must equal the agent's per_step_cond_dim =
  # vision_backbone.output_dim * num_cams + obs_dim.
  # 208 presumably corresponds to 64 * 3 + 16 — confirm against the
  # resnet_diffusion backbone's output_dim if either changes.
  cond_dim: 208
|
||||
|
||||
22
roboimi/vla/conf/head/gr00t_dit1d.yaml
Normal file
22
roboimi/vla/conf/head/gr00t_dit1d.yaml
Normal file
@@ -0,0 +1,22 @@
|
||||
# Hydra config for the gr00t DiT noise-prediction head.
_target_: roboimi.vla.models.heads.gr00t_dit1d.Gr00tDiT1D
# _partial_: the agent finishes instantiation with the dims listed below.
_partial_: true

# DiT architecture
n_layer: 6
n_head: 8
n_emb: 256
hidden_dim: 256
mlp_ratio: 4
dropout: 0.1

# Positional embeddings (learned, for action and condition tokens)
add_action_pos_emb: true
add_cond_pos_emb: true

# Supplied by agent interpolation:
# - input_dim
# - output_dim
# - horizon
# - n_obs_steps
# - cond_dim
|
||||
|
||||
146
roboimi/vla/models/heads/gr00t_dit1d.py
Normal file
146
roboimi/vla/models/heads/gr00t_dit1d.py
Normal file
@@ -0,0 +1,146 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from types import SimpleNamespace
|
||||
from typing import Optional, Union
|
||||
from pathlib import Path
|
||||
import importlib.util
|
||||
|
||||
|
||||
def _load_gr00t_dit():
    """Load the standalone gr00t DiT class directly from its source file.

    The gr00t code is vendored alongside this repository rather than
    installed as a package, so ``gr00t/models/dit.py`` is imported by file
    path instead of by module name.

    Returns:
        The ``DiT`` class defined in ``gr00t/models/dit.py``.

    Raises:
        ImportError: if the file is missing or cannot be loaded as a module.
    """
    # NOTE(review): parents[4] assumes this file sits at
    # <repo>/roboimi/vla/models/heads/gr00t_dit1d.py — revisit if moved.
    repo_root = Path(__file__).resolve().parents[4]
    dit_path = repo_root / "gr00t" / "models" / "dit.py"
    # Fail early with a clear message. Without this check, a missing file
    # only surfaces later as an opaque FileNotFoundError inside exec_module,
    # because spec_from_file_location does not verify the path exists.
    if not dit_path.is_file():
        raise ImportError(f"Unable to load DiT from {dit_path}")
    spec = importlib.util.spec_from_file_location("gr00t_dit_standalone", dit_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Unable to load DiT from {dit_path}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module.DiT


# Resolved once at import time; every Gr00tDiT1D instance reuses this class.
DiT = _load_gr00t_dit()
|
||||
|
||||
|
||||
class Gr00tDiT1D(nn.Module):
    """
    Adapter that wraps gr00t DiT with the same call signature used by VLA heads.

    Expected forward interface:
        - sample: (B, T_action, input_dim)
        - timestep: (B,) or scalar diffusion timestep
        - cond: (B, T_obs, cond_dim)

    Input and condition are projected into the DiT embedding space, optional
    learned positional embeddings are added, and the DiT output is projected
    back to the action dimension.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        horizon: int,
        n_obs_steps: int,
        cond_dim: int,
        n_layer: int = 8,
        n_head: int = 8,
        n_emb: int = 256,
        hidden_dim: int = 256,
        mlp_ratio: int = 4,
        dropout: float = 0.1,
        add_action_pos_emb: bool = True,
        add_cond_pos_emb: bool = True,
    ):
        """
        Args:
            input_dim: per-step action feature dim of the noisy sample.
            output_dim: per-step dim of the predicted noise.
            horizon: maximum action sequence length supported.
            n_obs_steps: maximum condition sequence length supported.
            cond_dim: per-step dim of the conditioning features (must be > 0).
            n_layer: number of DiT blocks.
            n_head: number of attention heads.
            n_emb: token embedding dim fed into the DiT.
            hidden_dim: DiT hidden/output dim (input of the final projection).
            mlp_ratio: MLP expansion ratio inside DiT blocks.
            dropout: dropout probability inside the DiT.
            add_action_pos_emb: add a learned positional embedding to action tokens.
            add_cond_pos_emb: add a learned positional embedding to condition tokens.

        Raises:
            ValueError: if cond_dim <= 0.
        """
        super().__init__()
        if cond_dim <= 0:
            raise ValueError("Gr00tDiT1D requires cond_dim > 0.")

        self.horizon = horizon
        self.n_obs_steps = n_obs_steps

        # Project action/cond streams into the DiT embedding space, and the
        # DiT output back to the action dimension.
        self.input_proj = nn.Linear(input_dim, n_emb)
        self.cond_proj = nn.Linear(cond_dim, n_emb)
        self.output_proj = nn.Linear(hidden_dim, output_dim)

        # Learned positional embeddings; zero-initialized here, then given a
        # small normal init in _init_weights().
        self.action_pos_emb = (
            nn.Parameter(torch.zeros(1, horizon, n_emb))
            if add_action_pos_emb
            else None
        )
        self.cond_pos_emb = (
            nn.Parameter(torch.zeros(1, n_obs_steps, n_emb))
            if add_cond_pos_emb
            else None
        )

        # Minimal config namespace for the vendored gr00t DiT — presumably it
        # reads exactly these attribute names; confirm against gr00t/models/dit.py.
        args = SimpleNamespace(
            embed_dim=n_emb,
            nheads=n_head,
            mlp_ratio=mlp_ratio,
            dropout=dropout,
            num_layers=n_layer,
            hidden_dim=hidden_dim,
        )
        self.dit = DiT(args, cross_attention_dim=n_emb)

        self._init_weights()

    def _init_weights(self):
        """Initialize positional embeddings with a small normal distribution."""
        if self.action_pos_emb is not None:
            nn.init.normal_(self.action_pos_emb, mean=0.0, std=0.02)
        if self.cond_pos_emb is not None:
            nn.init.normal_(self.cond_pos_emb, mean=0.0, std=0.02)

    def _normalize_timesteps(
        self,
        timestep: Union[torch.Tensor, float, int],
        batch_size: int,
        device: torch.device,
    ) -> torch.Tensor:
        """Coerce `timestep` to a long tensor of shape (batch_size,) on `device`.

        Accepts a Python scalar, a 0-d tensor, a length-1 tensor (broadcast
        to the batch), or an already batch-sized tensor.
        """
        if not torch.is_tensor(timestep):
            timesteps = torch.tensor([timestep], device=device)
        else:
            timesteps = timestep.to(device)

        if timesteps.ndim == 0:
            timesteps = timesteps[None]
        if timesteps.shape[0] != batch_size:
            # Broadcast a single timestep across the batch. (expand requires
            # the existing size to be 1; other mismatches raise here.)
            timesteps = timesteps.expand(batch_size)

        return timesteps.long()

    def forward(
        self,
        sample: torch.Tensor,
        timestep: Union[torch.Tensor, float, int],
        cond: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Predict noise for `sample` conditioned on `cond` at `timestep`.

        Args:
            sample: noisy actions, (B, T_action, input_dim), T_action <= horizon.
            timestep: diffusion timestep(s); scalar or (B,).
            cond: conditioning features, (B, T_obs, cond_dim); required.
            **kwargs: ignored; absorbs extra keys from generic head callers.

        Returns:
            Tensor of shape (B, T_action, output_dim).

        Raises:
            ValueError: if cond is None, or a sequence exceeds its configured
                maximum length.
        """
        if cond is None:
            raise ValueError("`cond` is required for Gr00tDiT1D forward.")

        bsz, t_act, _ = sample.shape
        if t_act > self.horizon:
            raise ValueError(
                f"sample length {t_act} exceeds configured horizon {self.horizon}"
            )

        hidden_states = self.input_proj(sample)
        if self.action_pos_emb is not None:
            # Slice so shorter-than-horizon sequences still line up.
            hidden_states = hidden_states + self.action_pos_emb[:, :t_act, :]

        encoder_hidden_states = self.cond_proj(cond)
        if self.cond_pos_emb is not None:
            t_obs = encoder_hidden_states.shape[1]
            if t_obs > self.n_obs_steps:
                raise ValueError(
                    f"cond length {t_obs} exceeds configured n_obs_steps {self.n_obs_steps}"
                )
            encoder_hidden_states = (
                encoder_hidden_states + self.cond_pos_emb[:, :t_obs, :]
            )

        timesteps = self._normalize_timesteps(
            timestep, batch_size=bsz, device=sample.device
        )
        dit_output = self.dit(
            hidden_states=hidden_states,
            timestep=timesteps,
            encoder_hidden_states=encoder_hidden_states,
        )
        return self.output_proj(dit_output)
|
||||
Reference in New Issue
Block a user