Files
roboimi/roboimi/vla/agent_gr00t_dit.py
2026-03-06 11:31:37 +08:00

218 lines
7.7 KiB
Python

import torch
import torch.nn as nn
from collections import deque
from typing import Dict
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from roboimi.vla.models.normalization import NormalizationModule
class VLAAgentGr00tDiT(nn.Module):
    """
    VLA agent variant that swaps the Transformer1D head for the gr00t DiT head.

    Other components (vision backbone, state/action encoders, DDPM/DDIM
    schedulers, and the observation/action queue logic) stay aligned with the
    existing VLAAgent implementation.
    """
def __init__(
self,
vision_backbone,
state_encoder,
action_encoder,
head,
action_dim,
obs_dim,
pred_horizon=16,
obs_horizon=4,
diffusion_steps=100,
inference_steps=10,
num_cams=3,
dataset_stats=None,
normalization_type="min_max",
num_action_steps=8,
):
super().__init__()
self.action_dim = action_dim
self.obs_dim = obs_dim
self.pred_horizon = pred_horizon
self.obs_horizon = obs_horizon
self.num_cams = num_cams
self.num_action_steps = num_action_steps
self.inference_steps = inference_steps
self.normalization = NormalizationModule(
stats=dataset_stats,
normalization_type=normalization_type,
)
self.vision_encoder = vision_backbone
single_cam_feat_dim = self.vision_encoder.output_dim
self.per_step_cond_dim = single_cam_feat_dim * num_cams + obs_dim
self.noise_scheduler = DDPMScheduler(
num_train_timesteps=diffusion_steps,
beta_schedule="squaredcos_cap_v2",
clip_sample=True,
prediction_type="epsilon",
)
self.infer_scheduler = DDIMScheduler(
num_train_timesteps=diffusion_steps,
beta_schedule="squaredcos_cap_v2",
clip_sample=True,
prediction_type="epsilon",
)
if isinstance(head, nn.Module):
self.noise_pred_net = head
else:
self.noise_pred_net = head(
input_dim=action_dim,
output_dim=action_dim,
horizon=pred_horizon,
n_obs_steps=obs_horizon,
cond_dim=self.per_step_cond_dim,
)
self.state_encoder = state_encoder
self.action_encoder = action_encoder
self.reset()
def _get_model_device(self) -> torch.device:
return next(self.parameters()).device
def _move_to_device(self, data, device: torch.device):
if torch.is_tensor(data):
return data.to(device)
if isinstance(data, dict):
return {k: self._move_to_device(v, device) for k, v in data.items()}
if isinstance(data, list):
return [self._move_to_device(v, device) for v in data]
if isinstance(data, tuple):
return tuple(self._move_to_device(v, device) for v in data)
return data
def _build_cond(self, images: Dict[str, torch.Tensor], states: torch.Tensor) -> torch.Tensor:
visual_features = self.vision_encoder(images)
state_features = self.state_encoder(states)
return torch.cat([visual_features, state_features], dim=-1)
def compute_loss(self, batch):
actions, states, images = batch["action"], batch["qpos"], batch["images"]
action_is_pad = batch.get("action_is_pad", None)
bsz = actions.shape[0]
states = self.normalization.normalize_qpos(states)
actions = self.normalization.normalize_action(actions)
action_features = self.action_encoder(actions)
cond = self._build_cond(images, states)
noise = torch.randn_like(action_features)
timesteps = torch.randint(
0,
self.noise_scheduler.config.num_train_timesteps,
(bsz,),
device=action_features.device,
).long()
noisy_actions = self.noise_scheduler.add_noise(action_features, noise, timesteps)
pred_noise = self.noise_pred_net(
sample=noisy_actions,
timestep=timesteps,
cond=cond,
)
loss = nn.functional.mse_loss(pred_noise, noise, reduction="none")
if action_is_pad is not None:
mask = (~action_is_pad).unsqueeze(-1).to(loss.dtype)
valid_count = mask.sum() * loss.shape[-1]
loss = (loss * mask).sum() / valid_count.clamp_min(1.0)
else:
loss = loss.mean()
return loss
def reset(self):
self._queues = {
"qpos": deque(maxlen=self.obs_horizon),
"images": deque(maxlen=self.obs_horizon),
"action": deque(maxlen=self.pred_horizon - self.obs_horizon + 1),
}
def _populate_queues(self, observation: Dict[str, torch.Tensor]) -> None:
if "qpos" in observation:
self._queues["qpos"].append(observation["qpos"].clone())
if "images" in observation:
self._queues["images"].append({k: v.clone() for k, v in observation["images"].items()})
def _prepare_observation_batch(self) -> Dict[str, torch.Tensor]:
qpos_list = list(self._queues["qpos"])
if len(qpos_list) == 0:
raise ValueError("observation queue is empty.")
while len(qpos_list) < self.obs_horizon:
qpos_list.append(qpos_list[-1])
batch_qpos = torch.stack(qpos_list, dim=0).unsqueeze(0)
images_list = list(self._queues["images"])
if len(images_list) == 0:
raise ValueError("image queue is empty.")
while len(images_list) < self.obs_horizon:
images_list.append(images_list[-1])
batch_images = {}
for cam_name in images_list[0].keys():
batch_images[cam_name] = torch.stack(
[img[cam_name] for img in images_list], dim=0
).unsqueeze(0)
return {"qpos": batch_qpos, "images": batch_images}
@torch.no_grad()
def select_action(self, observation: Dict[str, torch.Tensor]) -> torch.Tensor:
device = self._get_model_device()
observation = self._move_to_device(observation, device)
self._populate_queues(observation)
if len(self._queues["action"]) == 0:
batch = self._prepare_observation_batch()
actions = self.predict_action_chunk(batch)
start = self.obs_horizon - 1
end = start + self.num_action_steps
executable_actions = actions[:, start:end]
for i in range(executable_actions.shape[1]):
self._queues["action"].append(executable_actions[:, i].squeeze(0))
return self._queues["action"].popleft()
@torch.no_grad()
def predict_action_chunk(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
return self.predict_action(batch["images"], batch["qpos"])
@torch.no_grad()
def predict_action(self, images, proprioception):
bsz = proprioception.shape[0]
proprioception = self.normalization.normalize_qpos(proprioception)
cond = self._build_cond(images, proprioception)
device = cond.device
current_actions = torch.randn((bsz, self.pred_horizon, self.action_dim), device=device)
self.infer_scheduler.set_timesteps(self.inference_steps)
for t in self.infer_scheduler.timesteps:
noise_pred = self.noise_pred_net(
sample=current_actions,
timestep=t,
cond=cond,
)
current_actions = self.infer_scheduler.step(
noise_pred, t, current_actions
).prev_sample
return self.normalization.denormalize_action(current_actions)
def get_normalization_stats(self):
return self.normalization.get_stats()