from .modules import (
    build_action_decoder,
    build_action_encoder,
    build_state_encoder,
    build_time_sampler,
    build_noise_scheduler,
)
from .backbone import build_backbone
from .dit import build_dit

import torch
import torch.nn as nn
import torch.nn.functional as F
class gr00t(nn.Module):
|
|
def __init__(
|
|
self,
|
|
backbones,
|
|
dit,
|
|
state_encoder,
|
|
action_encoder,
|
|
action_decoder,
|
|
time_sampler,
|
|
noise_scheduler,
|
|
num_queries,
|
|
camera_names,
|
|
):
|
|
super().__init__()
|
|
self.num_queries = num_queries
|
|
self.camera_names = camera_names
|
|
self.dit = dit
|
|
self.state_encoder = state_encoder
|
|
self.action_encoder = action_encoder
|
|
self.action_decoder = action_decoder
|
|
self.time_sampler = time_sampler
|
|
self.noise_scheduler = noise_scheduler
|
|
|
|
if backbones is not None:
|
|
self.backbones = nn.ModuleList(backbones)
|
|
else:
|
|
raise NotImplementedError
|
|
|
|
def forward(self, qpos, image, actions=None, is_pad=None):
|
|
is_training = actions is not None # train or val
|
|
bs, _ = qpos.shape
|
|
|
|
all_cam_features = []
|
|
for cam_id, cam_name in enumerate(self.camera_names):
|
|
# features, pos = self.backbones[0](image[:, cam_id]) # HARDCODED
|
|
features, pos = self.backbones[cam_id](image[:, cam_id])
|
|
features = features[0] # take the last layer feature
|
|
B, C, H, W = features.shape
|
|
features_seq = features.permute(0, 2, 3, 1).reshape(B, H * W, C)
|
|
all_cam_features.append(features_seq)
|
|
encoder_hidden_states = torch.cat(all_cam_features, dim=1)
|
|
|
|
state_features = self.state_encoder(qpos) # [B, 1, emb_dim]
|
|
|
|
if is_training:
|
|
# training logic
|
|
|
|
timesteps = self.time_sampler(bs, actions.device, actions.dtype)
|
|
noisy_actions, target_velocity = self.noise_scheduler.add_noise(
|
|
actions, timesteps
|
|
)
|
|
t_discretized = (timesteps[:, 0, 0] * 1000).long()
|
|
action_features = self.action_encoder(noisy_actions, t_discretized)
|
|
sa_embs = torch.cat((state_features, action_features), dim=1)
|
|
model_output = self.dit(sa_embs, t_discretized, encoder_hidden_states)
|
|
pred = self.action_decoder(model_output)
|
|
pred_actions = pred[:, -actions.shape[1] :]
|
|
action_loss = F.mse_loss(pred_actions, target_velocity, reduction='none')
|
|
return pred_actions, action_loss
|
|
else:
|
|
actions = torch.randn(bs, self.num_queries, qpos.shape[-1], device=qpos.device, dtype=qpos.dtype)
|
|
k = 5
|
|
dt = 1.0 / k
|
|
for t in range(k):
|
|
t_cont = t / float(k)
|
|
t_discretized = int(t_cont * 1000)
|
|
timesteps = torch.full((bs,), t_discretized, device=qpos.device, dtype=qpos.dtype)
|
|
action_features = self.action_encoder(actions, timesteps)
|
|
sa_embs = torch.cat((state_features, action_features), dim=1)
|
|
# Create tensor of shape [B] for DiT (consistent with training path)
|
|
model_output = self.dit(sa_embs, timesteps, encoder_hidden_states)
|
|
pred = self.action_decoder(model_output)
|
|
pred_velocity = pred[:, -self.num_queries :]
|
|
actions = actions + pred_velocity * dt
|
|
return actions, _
|
|
def build_gr00t_model(args):
    """Assemble a ``gr00t`` policy from an argparse-style config.

    ``args`` must provide ``camera_names`` and ``num_queries`` plus whatever
    the individual ``build_*`` helpers read.

    Returns:
        The constructed ``gr00t`` module.  The trainable-parameter count is
        printed as a side effect.
    """
    # One backbone per camera, mirroring gr00t's per-camera feature loop.
    # (The original also bound args.state_dim / args.action_dim to unused
    # locals; removed as dead code.)
    backbones = [build_backbone(args) for _ in args.camera_names]

    # The DiT cross-attends over backbone features, so its context width
    # must match the backbone channel count.
    cross_attention_dim = backbones[0].num_channels
    dit = build_dit(args, cross_attention_dim)

    state_encoder = build_state_encoder(args)
    action_encoder = build_action_encoder(args)
    action_decoder = build_action_decoder(args)
    time_sampler = build_time_sampler(args)
    noise_scheduler = build_noise_scheduler(args)

    model = gr00t(
        backbones,
        dit,
        state_encoder,
        action_encoder,
        action_decoder,
        time_sampler,
        noise_scheduler,
        args.num_queries,
        args.camera_names,
    )

    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("number of parameters: %.2fM" % (n_parameters/1e6,))
    return model