"""
GR00T Policy wrapper for imitation learning.

This module provides the gr00tPolicy class that wraps the GR00T model
for training and evaluation in the imitation learning framework.
"""
import torch.nn as nn
from torch.nn import functional as F
from torchvision.transforms import v2
import torch
from roboimi.gr00t.main import build_gr00t_model_and_optimizer
class gr00tPolicy(nn.Module):
    """
    GR00T Policy for action prediction using diffusion-based DiT architecture.

    This policy wraps the GR00T model and handles:
    - Image resizing to match DINOv2 patch size requirements
    - Image normalization (ImageNet stats)
    - Training with action chunks and masked loss computation
    - Inference with diffusion sampling
    """

    def __init__(self, args_override):
        """
        Build the wrapped GR00T model, its optimizer, and image transforms.

        Args:
            args_override: Configuration forwarded verbatim to
                build_gr00t_model_and_optimizer.
        """
        super().__init__()
        model, optimizer = build_gr00t_model_and_optimizer(args_override)
        self.model = model
        self.optimizer = optimizer

        # DINOv2 requires image dimensions to be multiples of patch size (14).
        # Common sizes: 224x224, 336x336, etc. (14*16=224, 14*24=336)
        self.patch_h = 16  # number of patches vertically
        self.patch_w = 22  # number of patches horizontally
        target_size = (self.patch_h * 14, self.patch_w * 14)  # (224, 308)

        # ImageNet normalization stats, shared by both transform pipelines.
        # NOTE(review): v2.Normalize expects a float tensor input — presumably
        # images arrive already scaled to [0, 1]; confirm against the caller.
        normalize = v2.Normalize(mean=(0.485, 0.456, 0.406),
                                 std=(0.229, 0.224, 0.225))

        # Training transform with data augmentation.
        self.train_transform = v2.Compose([
            v2.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),
            v2.RandomPerspective(distortion_scale=0.5),
            v2.RandomAffine(degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1)),
            v2.GaussianBlur(kernel_size=(9, 9), sigma=(0.1, 2.0)),
            v2.Resize(target_size),
            normalize,
        ])

        # Inference transform (no augmentation).
        self.inference_transform = v2.Compose([
            v2.Resize(target_size),
            normalize,
        ])

    def forward(self, qpos, image, actions=None, is_pad=None):
        """
        Forward pass for training or inference.

        Defined as ``forward`` (rather than overriding ``__call__``) so that
        nn.Module hooks fire; calling the policy instance directly still
        dispatches here via nn.Module.__call__.

        Args:
            qpos: Joint positions [B, state_dim]
            image: Camera images [B, num_cameras, C, H, W]
            actions: Ground truth actions [B, chunk_size, action_dim]
                (training only)
            is_pad: Padding mask [B, chunk_size] (training only)

        Returns:
            Training (``actions`` given): dict with key ``'loss'`` holding the
            padding-masked MSE loss.
            Inference: predicted actions [B, num_queries, action_dim]
        """
        if actions is not None:  # training time
            # Resize + normalize with augmentation.
            image = self.train_transform(image)

            # Truncate targets to the model's action-chunk length.
            actions = actions[:, :self.model.num_queries]
            is_pad = is_pad[:, :self.model.num_queries]
            _, action_loss = self.model(qpos, image, actions, is_pad)

            # Zero out padded positions before averaging.
            # NOTE(review): the mean divides by ALL elements, including padded
            # ones, so heavily padded chunks are down-weighted — confirm this
            # is intentional before normalizing by the unpadded count.
            mse_loss = (action_loss * ~is_pad.unsqueeze(-1)).mean()
            return {'loss': mse_loss}

        # Inference time: resize + normalize only, then diffusion sampling.
        image = self.inference_transform(image)
        a_hat, _ = self.model(qpos, image)
        return a_hat

    def configure_optimizers(self):
        """Return the optimizer built alongside the model."""
        return self.optimizer