""" GR00T Policy wrapper for imitation learning. This module provides the gr00tPolicy class that wraps the GR00T model for training and evaluation in the imitation learning framework. """ import torch.nn as nn from torch.nn import functional as F from torchvision.transforms import v2 import torch from roboimi.gr00t.main import build_gr00t_model_and_optimizer class gr00tPolicy(nn.Module): """ GR00T Policy for action prediction using diffusion-based DiT architecture. This policy wraps the GR00T model and handles: - Image resizing to match DINOv2 patch size requirements - Image normalization (ImageNet stats) - Training with action chunks and loss computation - Inference with diffusion sampling """ def __init__(self, args_override): super().__init__() model, optimizer = build_gr00t_model_and_optimizer(args_override) self.model = model self.optimizer = optimizer # DINOv2 requires image dimensions to be multiples of patch size (14) # Common sizes: 224x224, 336x336, etc. (14*16=224, 14*24=336) self.patch_h = 16 # Number of patches vertically self.patch_w = 22 # Number of patches horizontally target_size = (self.patch_h * 14, self.patch_w * 14) # (224, 308) # Training transform with data augmentation self.train_transform = v2.Compose([ v2.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5), v2.RandomPerspective(distortion_scale=0.5), v2.RandomAffine(degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1)), v2.GaussianBlur(kernel_size=(9, 9), sigma=(0.1, 2.0)), v2.Resize(target_size), v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), ]) # Inference transform (no augmentation) self.inference_transform = v2.Compose([ v2.Resize(target_size), v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), ]) def __call__(self, qpos, image, actions=None, is_pad=None): """ Forward pass for training or inference. Args: qpos: Joint positions [B, state_dim] image: Camera images [B, num_cameras, C, H, W] actions: Ground truth actions [B, chunk_size, action_dim] (training only) is_pad: Padding mask [B, chunk_size] (training only) Returns: Training: dict with 'mse' loss Inference: predicted actions [B, num_queries, action_dim] """ # Apply transforms (resize + normalization) if actions is not None: # training time image = self.train_transform(image) else: # inference time image = self.inference_transform(image) if actions is not None: # training time actions = actions[:, :self.model.num_queries] is_pad = is_pad[:, :self.model.num_queries] _, action_loss = self.model(qpos, image, actions, is_pad) # Mask out padded positions mse_loss = (action_loss * ~is_pad.unsqueeze(-1)).mean() loss_dict = { 'loss': mse_loss } return loss_dict else: # inference time a_hat, _ = self.model(qpos, image) return a_hat def configure_optimizers(self): """Return the optimizer for training.""" return self.optimizer