chore: import gr00t

This commit is contained in:
gouhanke
2026-03-06 11:17:28 +08:00
parent 642d41dd8f
commit ca1716c67f
9 changed files with 922 additions and 166 deletions

90
gr00t/policy.py Normal file
View File

@@ -0,0 +1,90 @@
"""
GR00T Policy wrapper for imitation learning.
This module provides the gr00tPolicy class that wraps the GR00T model
for training and evaluation in the imitation learning framework.
"""
import torch.nn as nn
from torch.nn import functional as F
from torchvision.transforms import v2
import torch
from roboimi.gr00t.main import build_gr00t_model_and_optimizer
class gr00tPolicy(nn.Module):
    """
    GR00T Policy for action prediction using a diffusion-based DiT architecture.

    This policy wraps the GR00T model and handles:
    - Image resizing to match DINOv2 patch-size requirements
    - Image normalization (ImageNet statistics)
    - Training with action chunks and masked loss computation
    - Inference with diffusion sampling
    """

    # DINOv2 requires image dimensions to be multiples of its patch size.
    PATCH_SIZE = 14

    def __init__(self, args_override):
        """
        Build the GR00T model/optimizer pair and the image pipelines.

        Args:
            args_override: configuration object forwarded verbatim to
                ``build_gr00t_model_and_optimizer``.
        """
        super().__init__()
        model, optimizer = build_gr00t_model_and_optimizer(args_override)
        self.model = model
        self.optimizer = optimizer
        # Common DINOv2-compatible sizes: 224x224, 336x336 (14*16=224, 14*24=336).
        self.patch_h = 16  # number of patches vertically
        self.patch_w = 22  # number of patches horizontally
        target_size = (self.patch_h * self.PATCH_SIZE,
                       self.patch_w * self.PATCH_SIZE)  # (224, 308)
        # Training transform with data augmentation.
        self.train_transform = v2.Compose([
            v2.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),
            v2.RandomPerspective(distortion_scale=0.5),
            v2.RandomAffine(degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1)),
            v2.GaussianBlur(kernel_size=(9, 9), sigma=(0.1, 2.0)),
            v2.Resize(target_size),
            v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ])
        # Inference transform (no augmentation).
        self.inference_transform = v2.Compose([
            v2.Resize(target_size),
            v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ])

    # NOTE: defined as ``forward`` (not ``__call__``) so that
    # ``nn.Module.__call__`` dispatches here and forward/backward hooks work;
    # callers still invoke the policy as ``policy(qpos, image, ...)``.
    def forward(self, qpos, image, actions=None, is_pad=None):
        """
        Forward pass for training or inference.

        Args:
            qpos: Joint positions [B, state_dim]
            image: Camera images [B, num_cameras, C, H, W]
            actions: Ground truth actions [B, chunk_size, action_dim] (training only)
            is_pad: Padding mask [B, chunk_size] (training only)

        Returns:
            Training (``actions`` given): dict with key ``'loss'`` (masked MSE).
            Inference: predicted actions [B, num_queries, action_dim].
        """
        training = actions is not None
        # Apply transforms (resize + normalization); augmentation only in training.
        image = self.train_transform(image) if training else self.inference_transform(image)
        if training:
            # Truncate the chunk to the model's query horizon.
            actions = actions[:, :self.model.num_queries]
            is_pad = is_pad[:, :self.model.num_queries]
            _, action_loss = self.model(qpos, image, actions, is_pad)
            # Zero out padded positions before averaging.
            # NOTE(review): ``.mean()`` still divides by the full element count
            # including padded slots (kept to preserve original behavior) —
            # confirm this is intended rather than a mean over valid entries.
            mse_loss = (action_loss * ~is_pad.unsqueeze(-1)).mean()
            return {'loss': mse_loss}
        # Inference: diffusion sampling, loss output ignored.
        a_hat, _ = self.model(qpos, image)
        return a_hat

    def configure_optimizers(self):
        """Return the optimizer built alongside the model for training."""
        return self.optimizer