import torch
import torch.nn as nn
from collections import deque
from typing import Dict

from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusers.schedulers.scheduling_ddim import DDIMScheduler

from roboimi.vla.models.normalization import NormalizationModule


class VLAAgentGr00tDiT(nn.Module):
    """
    VLA Agent variant that swaps the Transformer1D head for the gr00t DiT head.

    Other components (backbone/encoders/scheduler/queue logic) stay aligned
    with the existing VLAAgent implementation.
    """

    def __init__(
        self,
        vision_backbone,
        state_encoder,
        action_encoder,
        head,
        action_dim,
        obs_dim,
        pred_horizon=16,
        obs_horizon=4,
        diffusion_steps=100,
        inference_steps=10,
        num_cams=3,
        dataset_stats=None,
        normalization_type="min_max",
        num_action_steps=8,
    ):
        super().__init__()
        self.action_dim = action_dim
        self.obs_dim = obs_dim
        self.pred_horizon = pred_horizon
        self.obs_horizon = obs_horizon
        self.num_cams = num_cams
        self.num_action_steps = num_action_steps
        self.inference_steps = inference_steps

        self.normalization = NormalizationModule(
            stats=dataset_stats,
            normalization_type=normalization_type,
        )

        self.vision_encoder = vision_backbone
        single_cam_feat_dim = self.vision_encoder.output_dim
        self.per_step_cond_dim = single_cam_feat_dim * num_cams + obs_dim

        # DDPM scheduler adds noise during training; DDIM scheduler runs the
        # (shorter) denoising loop at inference.
        self.noise_scheduler = DDPMScheduler(
            num_train_timesteps=diffusion_steps,
            beta_schedule="squaredcos_cap_v2",
            clip_sample=True,
            prediction_type="epsilon",
        )
        self.infer_scheduler = DDIMScheduler(
            num_train_timesteps=diffusion_steps,
            beta_schedule="squaredcos_cap_v2",
            clip_sample=True,
            prediction_type="epsilon",
        )

        # Accept either a pre-built head module or a factory callable that builds one.
        if isinstance(head, nn.Module):
            self.noise_pred_net = head
        else:
            self.noise_pred_net = head(
                input_dim=action_dim,
                output_dim=action_dim,
                horizon=pred_horizon,
                n_obs_steps=obs_horizon,
                cond_dim=self.per_step_cond_dim,
            )

        self.state_encoder = state_encoder
        self.action_encoder = action_encoder
        self.reset()

    def _get_model_device(self) -> torch.device:
        return next(self.parameters()).device

    def _move_to_device(self, data, device: torch.device):
        # Recursively move tensors (possibly nested in dicts/lists/tuples) to the target device.
        if torch.is_tensor(data):
            return data.to(device)
        if isinstance(data, dict):
            return {k: self._move_to_device(v, device) for k, v in data.items()}
        if isinstance(data, list):
            return [self._move_to_device(v, device) for v in data]
        if isinstance(data, tuple):
            return tuple(self._move_to_device(v, device) for v in data)
        return data

    def _build_cond(self, images: Dict[str, torch.Tensor], states: torch.Tensor) -> torch.Tensor:
        # Concatenate visual and proprioceptive features into the per-step conditioning vector.
        visual_features = self.vision_encoder(images)
        state_features = self.state_encoder(states)
        return torch.cat([visual_features, state_features], dim=-1)

    def compute_loss(self, batch):
        actions, states, images = batch["action"], batch["qpos"], batch["images"]
        action_is_pad = batch.get("action_is_pad", None)
        bsz = actions.shape[0]

        states = self.normalization.normalize_qpos(states)
        actions = self.normalization.normalize_action(actions)

        action_features = self.action_encoder(actions)
        cond = self._build_cond(images, states)

        # Standard DDPM training objective: corrupt the encoded actions at a
        # random timestep and regress the added noise (epsilon prediction).
        noise = torch.randn_like(action_features)
        timesteps = torch.randint(
            0,
            self.noise_scheduler.config.num_train_timesteps,
            (bsz,),
            device=action_features.device,
        ).long()
        noisy_actions = self.noise_scheduler.add_noise(action_features, noise, timesteps)

        pred_noise = self.noise_pred_net(
            sample=noisy_actions,
            timestep=timesteps,
            cond=cond,
        )
        loss = nn.functional.mse_loss(pred_noise, noise, reduction="none")

        if action_is_pad is not None:
            # Mask out padded action steps so they do not contribute to the loss.
            mask = (~action_is_pad).unsqueeze(-1).to(loss.dtype)
            valid_count = mask.sum() * loss.shape[-1]
            loss = (loss * mask).sum() / valid_count.clamp_min(1.0)
        else:
            loss = loss.mean()

        return loss
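
    # Illustrative training step (a sketch; `dataloader` and `optimizer` are
    # assumptions, not part of this module):
    #
    #   for batch in dataloader:               # batch holds "action", "qpos", "images",
    #       loss = agent.compute_loss(batch)   # and optionally "action_is_pad"
    #       loss.backward()
    #       optimizer.step()
    #       optimizer.zero_grad()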

    def reset(self):
        # Clear the observation/action queues; call this at the start of each rollout.
        self._queues = {
            "qpos": deque(maxlen=self.obs_horizon),
            "images": deque(maxlen=self.obs_horizon),
            "action": deque(maxlen=self.pred_horizon - self.obs_horizon + 1),
        }

    def _populate_queues(self, observation: Dict[str, torch.Tensor]) -> None:
        if "qpos" in observation:
            self._queues["qpos"].append(observation["qpos"].clone())
        if "images" in observation:
            self._queues["images"].append({k: v.clone() for k, v in observation["images"].items()})

    def _prepare_observation_batch(self) -> Dict[str, torch.Tensor]:
        qpos_list = list(self._queues["qpos"])
        if len(qpos_list) == 0:
            raise ValueError("observation queue is empty.")
        # Pad a short history by repeating the latest observation up to obs_horizon.
        while len(qpos_list) < self.obs_horizon:
            qpos_list.append(qpos_list[-1])
        batch_qpos = torch.stack(qpos_list, dim=0).unsqueeze(0)

        images_list = list(self._queues["images"])
        if len(images_list) == 0:
            raise ValueError("image queue is empty.")
        while len(images_list) < self.obs_horizon:
            images_list.append(images_list[-1])

        batch_images = {}
        for cam_name in images_list[0].keys():
            batch_images[cam_name] = torch.stack(
                [img[cam_name] for img in images_list], dim=0
            ).unsqueeze(0)

        return {"qpos": batch_qpos, "images": batch_images}

    @torch.no_grad()
    def select_action(self, observation: Dict[str, torch.Tensor]) -> torch.Tensor:
        device = self._get_model_device()
        observation = self._move_to_device(observation, device)
        self._populate_queues(observation)

        if len(self._queues["action"]) == 0:
            # Action queue exhausted: predict a new chunk and enqueue the
            # num_action_steps actions starting at the current step.
            batch = self._prepare_observation_batch()
            actions = self.predict_action_chunk(batch)
            start = self.obs_horizon - 1
            end = start + self.num_action_steps
            executable_actions = actions[:, start:end]
            for i in range(executable_actions.shape[1]):
                self._queues["action"].append(executable_actions[:, i].squeeze(0))

        return self._queues["action"].popleft()
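
    # Illustrative rollout loop (a sketch; `env` and its reset/step API are
    # assumptions, not part of this module):
    #
    #   agent.reset()
    #   obs = env.reset()   # obs: {"qpos": Tensor, "images": {cam_name: Tensor}}
    #   for _ in range(max_steps):
    #       action = agent.select_action(obs)
    #       obs = env.step(action)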

    @torch.no_grad()
    def predict_action_chunk(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        return self.predict_action(batch["images"], batch["qpos"])

    @torch.no_grad()
    def predict_action(self, images, proprioception):
        bsz = proprioception.shape[0]
        proprioception = self.normalization.normalize_qpos(proprioception)
        cond = self._build_cond(images, proprioception)

        # Start from Gaussian noise and iteratively denoise with the DDIM scheduler.
        device = cond.device
        current_actions = torch.randn((bsz, self.pred_horizon, self.action_dim), device=device)
        self.infer_scheduler.set_timesteps(self.inference_steps)

        for t in self.infer_scheduler.timesteps:
            noise_pred = self.noise_pred_net(
                sample=current_actions,
                timestep=t,
                cond=cond,
            )
            current_actions = self.infer_scheduler.step(
                noise_pred, t, current_actions
            ).prev_sample

        return self.normalization.denormalize_action(current_actions)

    def get_normalization_stats(self):
        return self.normalization.get_stats()