diff --git a/roboimi/vla/agent_gr00t_dit.py b/roboimi/vla/agent_gr00t_dit.py
new file mode 100644
index 0000000..eadfad8
--- /dev/null
+++ b/roboimi/vla/agent_gr00t_dit.py
@@ -0,0 +1,217 @@
+import torch
+import torch.nn as nn
+from collections import deque
+from typing import Dict
+
+from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
+from diffusers.schedulers.scheduling_ddim import DDIMScheduler
+
+from roboimi.vla.models.normalization import NormalizationModule
+
+
+class VLAAgentGr00tDiT(nn.Module):
+    """
+    VLA Agent variant that swaps the Transformer1D head for the gr00t DiT head.
+    Other components (backbone/encoders/scheduler/queue logic) stay aligned
+    with the existing VLAAgent implementation.
+    """
+
+    def __init__(
+        self,
+        vision_backbone,
+        state_encoder,
+        action_encoder,
+        head,
+        action_dim,
+        obs_dim,
+        pred_horizon=16,
+        obs_horizon=4,
+        diffusion_steps=100,
+        inference_steps=10,
+        num_cams=3,
+        dataset_stats=None,
+        normalization_type="min_max",
+        num_action_steps=8,
+    ):
+        super().__init__()
+        self.action_dim = action_dim
+        self.obs_dim = obs_dim
+        self.pred_horizon = pred_horizon
+        self.obs_horizon = obs_horizon
+        self.num_cams = num_cams
+        self.num_action_steps = num_action_steps
+        self.inference_steps = inference_steps
+
+        self.normalization = NormalizationModule(
+            stats=dataset_stats,
+            normalization_type=normalization_type,
+        )
+
+        self.vision_encoder = vision_backbone
+        single_cam_feat_dim = self.vision_encoder.output_dim
+        self.per_step_cond_dim = single_cam_feat_dim * num_cams + obs_dim
+
+        self.noise_scheduler = DDPMScheduler(
+            num_train_timesteps=diffusion_steps,
+            beta_schedule="squaredcos_cap_v2",
+            clip_sample=True,
+            prediction_type="epsilon",
+        )
+        self.infer_scheduler = DDIMScheduler(
+            num_train_timesteps=diffusion_steps,
+            beta_schedule="squaredcos_cap_v2",
+            clip_sample=True,
+            prediction_type="epsilon",
+        )
+
+        if isinstance(head, nn.Module):
+            self.noise_pred_net = head
+        else:
+            self.noise_pred_net = head(
+                input_dim=action_dim,
+                output_dim=action_dim,
+                horizon=pred_horizon,
+                n_obs_steps=obs_horizon,
+                cond_dim=self.per_step_cond_dim,
+            )
+
+        self.state_encoder = state_encoder
+        self.action_encoder = action_encoder
+        self.reset()
+
+    def _get_model_device(self) -> torch.device:
+        return next(self.parameters()).device
+
+    def _move_to_device(self, data, device: torch.device):
+        if torch.is_tensor(data):
+            return data.to(device)
+        if isinstance(data, dict):
+            return {k: self._move_to_device(v, device) for k, v in data.items()}
+        if isinstance(data, list):
+            return [self._move_to_device(v, device) for v in data]
+        if isinstance(data, tuple):
+            return tuple(self._move_to_device(v, device) for v in data)
+        return data
+
+    def _build_cond(self, images: Dict[str, torch.Tensor], states: torch.Tensor) -> torch.Tensor:
+        visual_features = self.vision_encoder(images)
+        state_features = self.state_encoder(states)
+        return torch.cat([visual_features, state_features], dim=-1)
+
+    def compute_loss(self, batch):
+        actions, states, images = batch["action"], batch["qpos"], batch["images"]
+        action_is_pad = batch.get("action_is_pad", None)
+        bsz = actions.shape[0]
+
+        states = self.normalization.normalize_qpos(states)
+        actions = self.normalization.normalize_action(actions)
+
+        action_features = self.action_encoder(actions)
+        cond = self._build_cond(images, states)
+
+        noise = torch.randn_like(action_features)
+        timesteps = torch.randint(
+            0,
+            self.noise_scheduler.config.num_train_timesteps,
+            (bsz,),
+            device=action_features.device,
+        ).long()
+        noisy_actions = self.noise_scheduler.add_noise(action_features, noise, timesteps)
+
+        pred_noise = self.noise_pred_net(
+            sample=noisy_actions,
+            timestep=timesteps,
+            cond=cond,
+        )
+        loss = nn.functional.mse_loss(pred_noise, noise, reduction="none")
+
+        if action_is_pad is not None:
+            mask = (~action_is_pad).unsqueeze(-1).to(loss.dtype)
+            valid_count = mask.sum() * loss.shape[-1]
+            loss = (loss * mask).sum() / valid_count.clamp_min(1.0)
+        else:
+            loss = loss.mean()
+
+        return loss
+
+    def reset(self):
+        self._queues = {
+            "qpos": deque(maxlen=self.obs_horizon),
+            "images": deque(maxlen=self.obs_horizon),
+            "action": deque(maxlen=self.pred_horizon - self.obs_horizon + 1),
+        }
+
+    def _populate_queues(self, observation: Dict[str, torch.Tensor]) -> None:
+        if "qpos" in observation:
+            self._queues["qpos"].append(observation["qpos"].clone())
+        if "images" in observation:
+            self._queues["images"].append({k: v.clone() for k, v in observation["images"].items()})
+
+    def _prepare_observation_batch(self) -> Dict[str, torch.Tensor]:
+        qpos_list = list(self._queues["qpos"])
+        if len(qpos_list) == 0:
+            raise ValueError("observation queue is empty.")
+        while len(qpos_list) < self.obs_horizon:
+            qpos_list.append(qpos_list[-1])
+        batch_qpos = torch.stack(qpos_list, dim=0).unsqueeze(0)
+
+        images_list = list(self._queues["images"])
+        if len(images_list) == 0:
+            raise ValueError("image queue is empty.")
+        while len(images_list) < self.obs_horizon:
+            images_list.append(images_list[-1])
+
+        batch_images = {}
+        for cam_name in images_list[0].keys():
+            batch_images[cam_name] = torch.stack(
+                [img[cam_name] for img in images_list], dim=0
+            ).unsqueeze(0)
+
+        return {"qpos": batch_qpos, "images": batch_images}
+
+    @torch.no_grad()
+    def select_action(self, observation: Dict[str, torch.Tensor]) -> torch.Tensor:
+        device = self._get_model_device()
+        observation = self._move_to_device(observation, device)
+        self._populate_queues(observation)
+
+        if len(self._queues["action"]) == 0:
+            batch = self._prepare_observation_batch()
+            actions = self.predict_action_chunk(batch)
+            start = self.obs_horizon - 1
+            end = start + self.num_action_steps
+            executable_actions = actions[:, start:end]
+            for i in range(executable_actions.shape[1]):
+                self._queues["action"].append(executable_actions[:, i].squeeze(0))
+
+        return self._queues["action"].popleft()
+
+    @torch.no_grad()
+    def predict_action_chunk(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
+        return self.predict_action(batch["images"], batch["qpos"])
+
+    @torch.no_grad()
+    def predict_action(self, images, proprioception):
+        bsz = proprioception.shape[0]
+        proprioception = self.normalization.normalize_qpos(proprioception)
+        cond = self._build_cond(images, proprioception)
+
+        device = cond.device
+        current_actions = torch.randn((bsz, self.pred_horizon, self.action_dim), device=device)
+        self.infer_scheduler.set_timesteps(self.inference_steps)
+
+        for t in self.infer_scheduler.timesteps:
+            noise_pred = self.noise_pred_net(
+                sample=current_actions,
+                timestep=t,
+                cond=cond,
+            )
+            current_actions = self.infer_scheduler.step(
+                noise_pred, t, current_actions
+            ).prev_sample
+
+        return self.normalization.denormalize_action(current_actions)
+
+    def get_normalization_stats(self):
+        return self.normalization.get_stats()
+
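Usage sketch (outside the diff): a minimal rollout loop for the queue/chunking logic in select_action above. It assumes an `agent` already instantiated from conf/agent/resnet_gr00t_dit.yaml (e.g. via Hydra); the camera names and image resolution are placeholders, and the per-step shapes follow _populate_queues / _prepare_observation_batch.

import torch

cam_names = ["cam_high", "cam_left_wrist", "cam_right_wrist"]  # hypothetical names

agent.eval()
agent.reset()  # clears the qpos / images / action deques

for step in range(100):
    # Per-step observation: qpos is (obs_dim,), each image is (C, H, W);
    # _prepare_observation_batch stacks the time axis and adds the batch dim.
    obs = {
        "qpos": torch.zeros(16),
        "images": {name: torch.zeros(3, 224, 224) for name in cam_names},
    }
    # An empty action queue triggers one DDIM denoising pass that caches
    # num_action_steps actions; subsequent calls just pop from the queue.
    action = agent.select_action(obs)  # (action_dim,)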
diff --git a/roboimi/vla/conf/agent/resnet_gr00t_dit.yaml b/roboimi/vla/conf/agent/resnet_gr00t_dit.yaml
new file mode 100644
index 0000000..e21f39f
--- /dev/null
+++ b/roboimi/vla/conf/agent/resnet_gr00t_dit.yaml
@@ -0,0 +1,37 @@
+# @package agent
+defaults:
+  - /backbone@vision_backbone: resnet_diffusion
+  - /modules@state_encoder: identity_state_encoder
+  - /modules@action_encoder: identity_action_encoder
+  - /head: gr00t_dit1d
+  - _self_
+
+_target_: roboimi.vla.agent_gr00t_dit.VLAAgentGr00tDiT
+
+# Model dimensions
+action_dim: 16
+obs_dim: 16
+
+# Normalization
+normalization_type: "min_max"
+
+# Horizons
+pred_horizon: 16
+obs_horizon: 2
+num_action_steps: 8
+
+# Cameras
+num_cams: 3
+
+# Diffusion
+diffusion_steps: 100
+inference_steps: 10
+
+# Head overrides (cond_dim corresponds to the agent's per_step_cond_dim = vision_feat_dim * num_cams + obs_dim)
+head:
+  input_dim: ${agent.action_dim}
+  output_dim: ${agent.action_dim}
+  horizon: ${agent.pred_horizon}
+  n_obs_steps: ${agent.obs_horizon}
+  cond_dim: 208
+
diff --git a/roboimi/vla/conf/head/gr00t_dit1d.yaml b/roboimi/vla/conf/head/gr00t_dit1d.yaml
new file mode 100644
index 0000000..acd0ba7
--- /dev/null
+++ b/roboimi/vla/conf/head/gr00t_dit1d.yaml
@@ -0,0 +1,22 @@
+_target_: roboimi.vla.models.heads.gr00t_dit1d.Gr00tDiT1D
+_partial_: true
+
+# DiT architecture
+n_layer: 6
+n_head: 8
+n_emb: 256
+hidden_dim: 256
+mlp_ratio: 4
+dropout: 0.1
+
+# Positional embeddings
+add_action_pos_emb: true
+add_cond_pos_emb: true
+
+# Supplied by agent interpolation:
+# - input_dim
+# - output_dim
+# - horizon
+# - n_obs_steps
+# - cond_dim
+
diff --git a/roboimi/vla/models/heads/gr00t_dit1d.py b/roboimi/vla/models/heads/gr00t_dit1d.py
new file mode 100644
index 0000000..f6d6a85
--- /dev/null
+++ b/roboimi/vla/models/heads/gr00t_dit1d.py
@@ -0,0 +1,146 @@
+import torch
+import torch.nn as nn
+from types import SimpleNamespace
+from typing import Optional, Union
+from pathlib import Path
+import importlib.util
+
+
+def _load_gr00t_dit():
+    repo_root = Path(__file__).resolve().parents[4]
+    dit_path = repo_root / "gr00t" / "models" / "dit.py"
+    spec = importlib.util.spec_from_file_location("gr00t_dit_standalone", dit_path)
+    if spec is None or spec.loader is None:
+        raise ImportError(f"Unable to load DiT from {dit_path}")
+    module = importlib.util.module_from_spec(spec)
+    spec.loader.exec_module(module)
+    return module.DiT
+
+
+DiT = _load_gr00t_dit()
+
+
+class Gr00tDiT1D(nn.Module):
+    """
+    Adapter that wraps gr00t DiT with the same call signature used by VLA heads.
+
+    Expected forward interface:
+    - sample: (B, T_action, input_dim)
+    - timestep: (B,) or scalar diffusion timestep
+    - cond: (B, T_obs, cond_dim)
+    """
+
+    def __init__(
+        self,
+        input_dim: int,
+        output_dim: int,
+        horizon: int,
+        n_obs_steps: int,
+        cond_dim: int,
+        n_layer: int = 8,
+        n_head: int = 8,
+        n_emb: int = 256,
+        hidden_dim: int = 256,
+        mlp_ratio: int = 4,
+        dropout: float = 0.1,
+        add_action_pos_emb: bool = True,
+        add_cond_pos_emb: bool = True,
+    ):
+        super().__init__()
+        if cond_dim <= 0:
+            raise ValueError("Gr00tDiT1D requires cond_dim > 0.")
+
+        self.horizon = horizon
+        self.n_obs_steps = n_obs_steps
+
+        self.input_proj = nn.Linear(input_dim, n_emb)
+        self.cond_proj = nn.Linear(cond_dim, n_emb)
+        self.output_proj = nn.Linear(hidden_dim, output_dim)
+
+        self.action_pos_emb = (
+            nn.Parameter(torch.zeros(1, horizon, n_emb))
+            if add_action_pos_emb
+            else None
+        )
+        self.cond_pos_emb = (
+            nn.Parameter(torch.zeros(1, n_obs_steps, n_emb))
+            if add_cond_pos_emb
+            else None
+        )
+
+        args = SimpleNamespace(
+            embed_dim=n_emb,
+            nheads=n_head,
+            mlp_ratio=mlp_ratio,
+            dropout=dropout,
+            num_layers=n_layer,
+            hidden_dim=hidden_dim,
+        )
+        self.dit = DiT(args, cross_attention_dim=n_emb)
+
+        self._init_weights()
+
+    def _init_weights(self):
+        if self.action_pos_emb is not None:
+            nn.init.normal_(self.action_pos_emb, mean=0.0, std=0.02)
+        if self.cond_pos_emb is not None:
+            nn.init.normal_(self.cond_pos_emb, mean=0.0, std=0.02)
+
+    def _normalize_timesteps(
+        self,
+        timestep: Union[torch.Tensor, float, int],
+        batch_size: int,
+        device: torch.device,
+    ) -> torch.Tensor:
+        if not torch.is_tensor(timestep):
+            timesteps = torch.tensor([timestep], device=device)
+        else:
+            timesteps = timestep.to(device)
+
+        if timesteps.ndim == 0:
+            timesteps = timesteps[None]
+        if timesteps.shape[0] != batch_size:
+            timesteps = timesteps.expand(batch_size)
+
+        return timesteps.long()
+
+    def forward(
+        self,
+        sample: torch.Tensor,
+        timestep: Union[torch.Tensor, float, int],
+        cond: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> torch.Tensor:
+        if cond is None:
+            raise ValueError("`cond` is required for Gr00tDiT1D forward.")
+
+        bsz, t_act, _ = sample.shape
+        if t_act > self.horizon:
+            raise ValueError(
+                f"sample length {t_act} exceeds configured horizon {self.horizon}"
+            )
+
+        hidden_states = self.input_proj(sample)
+        if self.action_pos_emb is not None:
+            hidden_states = hidden_states + self.action_pos_emb[:, :t_act, :]
+
+        encoder_hidden_states = self.cond_proj(cond)
+        if self.cond_pos_emb is not None:
+            t_obs = encoder_hidden_states.shape[1]
+            if t_obs > self.n_obs_steps:
+                raise ValueError(
+                    f"cond length {t_obs} exceeds configured n_obs_steps {self.n_obs_steps}"
+                )
+            encoder_hidden_states = (
+                encoder_hidden_states + self.cond_pos_emb[:, :t_obs, :]
+            )
+
+        timesteps = self._normalize_timesteps(
+            timestep, batch_size=bsz, device=sample.device
+        )
+        dit_output = self.dit(
+            hidden_states=hidden_states,
+            timestep=timesteps,
+            encoder_hidden_states=encoder_hidden_states,
+        )
+        return self.output_proj(dit_output)
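Usage sketch (outside the diff): a minimal shape check for the "Expected forward interface" documented on Gr00tDiT1D. Importing the head assumes the gr00t DiT source is present at <repo_root>/gr00t/models/dit.py, as _load_gr00t_dit expects; the dimensions mirror conf/agent/resnet_gr00t_dit.yaml, and the output shape assumes the wrapped DiT preserves the (B, T, hidden_dim) sequence layout.

import torch
from roboimi.vla.models.heads.gr00t_dit1d import Gr00tDiT1D

head = Gr00tDiT1D(
    input_dim=16, output_dim=16, horizon=16, n_obs_steps=2, cond_dim=208,
    n_layer=6, n_head=8, n_emb=256, hidden_dim=256,
)

bsz = 2
sample = torch.randn(bsz, 16, 16)         # (B, T_action, input_dim), T_action <= horizon
timestep = torch.randint(0, 100, (bsz,))  # (B,) diffusion timesteps
cond = torch.randn(bsz, 2, 208)           # (B, T_obs, cond_dim), T_obs <= n_obs_steps

with torch.no_grad():
    eps = head(sample=sample, timestep=timestep, cond=cond)
print(eps.shape)  # expected (2, 16, 16), matching the layout of `sample`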