from .modules import (
    build_action_decoder,
    build_action_encoder,
    build_state_encoder,
    build_time_sampler,
    build_noise_scheduler,
)
from .backbone import build_backbone
from .dit import build_dit

import torch
import torch.nn as nn
import torch.nn.functional as F


class gr00t(nn.Module):
    """Flow-matching action policy.

    Per-camera visual backbones produce token sequences that condition a
    DiT via cross-attention. During training the model predicts the
    velocity field of noised actions; at inference it integrates that
    velocity with fixed-step Euler updates starting from Gaussian noise.
    """

    # Number of Euler integration steps used at inference.
    NUM_INFERENCE_STEPS = 5
    # Continuous time in [0, 1] is discretized onto this many bins before
    # being fed to the timestep embedders.
    TIME_DISCRETIZATION = 1000

    def __init__(
        self,
        backbones,
        dit,
        state_encoder,
        action_encoder,
        action_decoder,
        time_sampler,
        noise_scheduler,
        num_queries,
        camera_names,
    ):
        """Store sub-modules; one backbone is expected per camera name."""
        super().__init__()
        self.num_queries = num_queries
        self.camera_names = camera_names
        self.dit = dit
        self.state_encoder = state_encoder
        self.action_encoder = action_encoder
        self.action_decoder = action_decoder
        self.time_sampler = time_sampler
        self.noise_scheduler = noise_scheduler
        if backbones is None:
            # Vision-free operation is not supported.
            raise NotImplementedError
        self.backbones = nn.ModuleList(backbones)

    def forward(self, qpos, image, actions=None, is_pad=None):
        """Run a training step (actions given) or sample actions (actions=None).

        Args:
            qpos: [B, state_dim] proprioceptive state.
            image: [B, num_cams, C, H, W] camera images, one slice per
                entry of ``camera_names`` — TODO confirm layout against caller.
            actions: [B, T, action_dim] ground-truth actions, or None.
            is_pad: unused here; kept for interface compatibility.

        Returns:
            (pred_actions, per-element MSE loss) when training,
            (sampled actions, None) at inference.
        """
        is_training = actions is not None  # train or val
        bs, _ = qpos.shape

        # One backbone per camera; flatten spatial features to a token
        # sequence and concatenate across cameras for cross-attention.
        all_cam_features = []
        for cam_id, _cam_name in enumerate(self.camera_names):
            features, pos = self.backbones[cam_id](image[:, cam_id])
            features = features[0]  # take the last layer feature
            B, C, H, W = features.shape
            features_seq = features.permute(0, 2, 3, 1).reshape(B, H * W, C)
            all_cam_features.append(features_seq)
        encoder_hidden_states = torch.cat(all_cam_features, dim=1)

        state_features = self.state_encoder(qpos)  # [B, 1, emb_dim]

        if is_training:
            return self._forward_train(
                actions, state_features, encoder_hidden_states, bs
            )
        return self._forward_sample(
            qpos, state_features, encoder_hidden_states, bs
        )

    def _forward_train(self, actions, state_features, encoder_hidden_states, bs):
        """Flow-matching training: predict velocity of noised actions."""
        # time_sampler presumably returns [B, 1, 1] times in [0, 1] — the
        # [:, 0, 0] squeeze below depends on that shape.
        timesteps = self.time_sampler(bs, actions.device, actions.dtype)
        noisy_actions, target_velocity = self.noise_scheduler.add_noise(
            actions, timesteps
        )
        t_discretized = (timesteps[:, 0, 0] * self.TIME_DISCRETIZATION).long()

        action_features = self.action_encoder(noisy_actions, t_discretized)
        sa_embs = torch.cat((state_features, action_features), dim=1)
        model_output = self.dit(sa_embs, t_discretized, encoder_hidden_states)
        pred = self.action_decoder(model_output)

        # Only the trailing action tokens carry velocity predictions; the
        # leading state token(s) are discarded.
        pred_actions = pred[:, -actions.shape[1]:]
        action_loss = F.mse_loss(pred_actions, target_velocity, reduction="none")
        return pred_actions, action_loss

    def _forward_sample(self, qpos, state_features, encoder_hidden_states, bs):
        """Euler integration of the learned velocity field from pure noise."""
        actions = torch.randn(
            bs, self.num_queries, qpos.shape[-1],
            device=qpos.device, dtype=qpos.dtype,
        )
        k = self.NUM_INFERENCE_STEPS
        dt = 1.0 / k
        for t in range(k):
            t_cont = t / float(k)
            t_discretized = int(t_cont * self.TIME_DISCRETIZATION)
            # [B] long tensor — matches the dtype fed to the DiT and
            # action encoder during training (the original passed a float
            # tensor here, inconsistent with the training path).
            timesteps = torch.full(
                (bs,), t_discretized, device=qpos.device, dtype=torch.long
            )
            action_features = self.action_encoder(actions, timesteps)
            sa_embs = torch.cat((state_features, action_features), dim=1)
            model_output = self.dit(sa_embs, timesteps, encoder_hidden_states)
            pred = self.action_decoder(model_output)
            pred_velocity = pred[:, -self.num_queries:]
            actions = actions + pred_velocity * dt
        # No loss at inference; the original `return actions, _` raised
        # NameError because `_` was never bound in this branch.
        return actions, None


def build_gr00t_model(args):
    """Construct a gr00t policy from parsed CLI/config ``args``.

    Builds one backbone per camera, sizes the DiT cross-attention to the
    backbone channel count, and reports the trainable parameter count.
    """
    backbones = [build_backbone(args) for _ in args.camera_names]

    # All backbones share a config, so the first one's channel count is
    # representative for the cross-attention width.
    cross_attention_dim = backbones[0].num_channels
    dit = build_dit(args, cross_attention_dim)

    state_encoder = build_state_encoder(args)
    action_encoder = build_action_encoder(args)
    action_decoder = build_action_decoder(args)
    time_sampler = build_time_sampler(args)
    noise_scheduler = build_noise_scheduler(args)

    model = gr00t(
        backbones,
        dit,
        state_encoder,
        action_encoder,
        action_decoder,
        time_sampler,
        noise_scheduler,
        args.num_queries,
        args.camera_names,
    )

    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("number of parameters: %.2fM" % (n_parameters / 1e6,))
    return model