Files
roboimi/gr00t/models/gr00t.py
2026-03-06 11:31:37 +08:00

125 lines
4.4 KiB
Python

from .modules import (
build_action_decoder,
build_action_encoder,
build_state_encoder,
build_time_sampler,
build_noise_scheduler,
)
from .backbone import build_backbone
from .dit import build_dit
import torch
import torch.nn as nn
import torch.nn.functional as F
class gr00t(nn.Module):
    """Flow-matching action policy: per-camera image tokens + state/action DiT.

    Each camera's feature map is flattened into tokens and all cameras are
    concatenated into one sequence passed to the DiT as
    ``encoder_hidden_states``.  The state embedding followed by the action
    embeddings forms the DiT input sequence; the decoder output for the
    trailing action tokens is trained against the noise scheduler's target
    velocity and, at inference, integrated with a fixed-step Euler loop.

    Args:
        backbones: one image backbone per camera, indexed in the order of
            ``camera_names``.  ``None`` is not supported.
        dit: called as ``dit(sa_embs, timesteps, encoder_hidden_states)``.
        state_encoder: maps ``qpos`` to state tokens.
        action_encoder: called as ``action_encoder(actions, timesteps)``.
        action_decoder: maps DiT output to per-token action/velocity values.
        time_sampler: called as ``time_sampler(batch_size, device, dtype)``
            during training.
        noise_scheduler: provides ``add_noise(actions, timesteps)`` returning
            ``(noisy_actions, target_velocity)``.
        num_queries: number of action steps sampled at inference.
        camera_names: camera identifiers; ``image[:, i]`` belongs to
            ``camera_names[i]``.
        action_dim: width of one action vector.  ``None`` (default, kept for
            backward compatibility) falls back to ``qpos.shape[-1]`` at
            inference, which is only correct when state and action dims match.
        num_inference_steps: Euler integration steps at inference
            (default 5, the previously hard-coded value).

    Raises:
        NotImplementedError: if ``backbones`` is ``None``.
        ValueError: if ``num_inference_steps`` < 1.
    """

    def __init__(
        self,
        backbones,
        dit,
        state_encoder,
        action_encoder,
        action_decoder,
        time_sampler,
        noise_scheduler,
        num_queries,
        camera_names,
        action_dim=None,
        num_inference_steps=5,
    ):
        super().__init__()
        if backbones is None:
            raise NotImplementedError
        if num_inference_steps < 1:
            raise ValueError(
                f"num_inference_steps must be >= 1, got {num_inference_steps}"
            )
        self.num_queries = num_queries
        self.camera_names = camera_names
        self.action_dim = action_dim
        self.num_inference_steps = num_inference_steps
        self.dit = dit
        self.state_encoder = state_encoder
        self.action_encoder = action_encoder
        self.action_decoder = action_decoder
        self.time_sampler = time_sampler
        self.noise_scheduler = noise_scheduler
        self.backbones = nn.ModuleList(backbones)

    def forward(self, qpos, image, actions=None, is_pad=None):
        """Run a training/validation step when ``actions`` is given, else sample.

        Args:
            qpos: robot state; unpacked as ``(batch, state_dim)``, so it must
                be 2-D.
            image: camera images; ``image[:, i]`` is fed to ``backbones[i]``.
            actions: ground-truth action chunk; selects the training path.
            is_pad: accepted for interface compatibility but not used here —
                any padding mask must be applied by the caller to the loss.

        Returns:
            Training: ``(pred_actions, action_loss)``; ``action_loss`` is the
            unreduced element-wise MSE against the scheduler's target velocity.
            Inference: ``(sampled_actions, None)``.
        """
        bs, state_dim = qpos.shape
        encoder_hidden_states = self._encode_images(image)
        state_features = self.state_encoder(qpos)  # [B, 1, emb_dim] per original note
        if actions is not None:  # train or val
            return self._training_step(
                bs, actions, state_features, encoder_hidden_states
            )
        sampled = self._sample_actions(
            bs, state_dim, qpos, state_features, encoder_hidden_states
        )
        # Previously the second slot leaked the integer state_dim bound to `_`.
        return sampled, None

    def _encode_images(self, image):
        """Flatten each camera's feature map to (B, H*W, C) tokens and concat them."""
        cam_tokens = []
        for cam_id, _cam_name in enumerate(self.camera_names):
            features, _pos = self.backbones[cam_id](image[:, cam_id])
            # Original note: element 0 is taken as the last-layer feature map.
            features = features[0]
            b, c, h, w = features.shape
            cam_tokens.append(features.permute(0, 2, 3, 1).reshape(b, h * w, c))
        return torch.cat(cam_tokens, dim=1)

    def _training_step(self, bs, actions, state_features, encoder_hidden_states):
        """Predict the velocity for noised actions; return (prediction, raw MSE)."""
        timesteps = self.time_sampler(bs, actions.device, actions.dtype)
        noisy_actions, target_velocity = self.noise_scheduler.add_noise(
            actions, timesteps
        )
        # timesteps is indexed as [B, 1, 1]; the continuous value is scaled by
        # 1000 onto the integer grid the encoder and DiT receive.
        t_discretized = (timesteps[:, 0, 0] * 1000).long()
        action_features = self.action_encoder(noisy_actions, t_discretized)
        sa_embs = torch.cat((state_features, action_features), dim=1)
        model_output = self.dit(sa_embs, t_discretized, encoder_hidden_states)
        pred = self.action_decoder(model_output)
        # Drop the state token(s): keep only the trailing action positions.
        pred_actions = pred[:, -actions.shape[1]:]
        action_loss = F.mse_loss(pred_actions, target_velocity, reduction='none')
        return pred_actions, action_loss

    def _sample_actions(self, bs, state_dim, qpos, state_features, encoder_hidden_states):
        """Integrate the predicted velocity from Gaussian noise with fixed Euler steps."""
        action_dim = self.action_dim if self.action_dim is not None else state_dim
        actions = torch.randn(
            bs, self.num_queries, action_dim, device=qpos.device, dtype=qpos.dtype
        )
        num_steps = self.num_inference_steps
        dt = 1.0 / num_steps
        for step in range(num_steps):
            t_discretized = int(step / float(num_steps) * 1000)
            # Integer timesteps, matching the `.long()` values used in training.
            timesteps = torch.full(
                (bs,), t_discretized, device=qpos.device, dtype=torch.long
            )
            action_features = self.action_encoder(actions, timesteps)
            sa_embs = torch.cat((state_features, action_features), dim=1)
            model_output = self.dit(sa_embs, timesteps, encoder_hidden_states)
            pred_velocity = self.action_decoder(model_output)[:, -self.num_queries:]
            actions = actions + pred_velocity * dt
        return actions


def build_gr00t_model(args):
    """Build a ``gr00t`` policy from a config namespace.

    One backbone is built per entry of ``args.camera_names``; the DiT's
    cross-attention width is the first backbone's ``num_channels``.  Reads
    ``args.action_dim`` and ``args.num_queries`` directly, and the optional
    ``args.num_inference_steps`` (default 5); the ``build_*`` factories read
    whatever else they need from ``args``.  Prints the trainable parameter
    count.

    Raises:
        ValueError: if ``args.camera_names`` is empty.
    """
    if not args.camera_names:
        raise ValueError("build_gr00t_model requires at least one camera in args.camera_names")
    backbones = [build_backbone(args) for _ in args.camera_names]
    cross_attention_dim = backbones[0].num_channels
    dit = build_dit(args, cross_attention_dim)
    # Sub-modules are built in the same order as before (arguments evaluate
    # left to right), so parameter initialisation consumes RNG identically.
    model = gr00t(
        backbones,
        dit,
        build_state_encoder(args),
        build_action_encoder(args),
        build_action_decoder(args),
        build_time_sampler(args),
        build_noise_scheduler(args),
        args.num_queries,
        args.camera_names,
        action_dim=args.action_dim,
        num_inference_steps=getattr(args, "num_inference_steps", 5),
    )
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("number of parameters: %.2fM" % (n_parameters / 1e6,))
    return model