chore: import gr00t
This commit is contained in:
90
gr00t/policy.py
Normal file
90
gr00t/policy.py
Normal file
@@ -0,0 +1,90 @@
|
||||
"""
|
||||
GR00T Policy wrapper for imitation learning.
|
||||
|
||||
This module provides the gr00tPolicy class that wraps the GR00T model
|
||||
for training and evaluation in the imitation learning framework.
|
||||
"""
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
from torchvision.transforms import v2
|
||||
import torch
|
||||
from roboimi.gr00t.main import build_gr00t_model_and_optimizer
|
||||
|
||||
|
||||
class gr00tPolicy(nn.Module):
    """
    GR00T Policy for action prediction using a diffusion-based DiT architecture.

    Wraps the GR00T model built by ``build_gr00t_model_and_optimizer`` and
    handles:
    - Image resizing to dimensions divisible by the DINOv2 patch size (14)
    - Image normalization with ImageNet statistics
    - Training with action chunks and masked loss computation
    - Inference through the underlying model's sampling path
    """

    def __init__(self, args_override):
        """Build the GR00T model/optimizer pair and the image transforms.

        Args:
            args_override: Configuration forwarded verbatim to
                ``build_gr00t_model_and_optimizer``.
        """
        super().__init__()
        model, optimizer = build_gr00t_model_and_optimizer(args_override)
        self.model = model
        self.optimizer = optimizer

        # DINOv2 requires image dimensions to be multiples of its patch
        # size (14). 16 x 22 patches -> a (224, 308) input resolution.
        self.patch_h = 16  # number of patches vertically
        self.patch_w = 22  # number of patches horizontally
        target_size = (self.patch_h * 14, self.patch_w * 14)  # (224, 308)

        # Training transform with data augmentation.
        # NOTE(review): v2.Normalize with ImageNet stats assumes float
        # images already scaled to [0, 1] — confirm against the dataloader.
        self.train_transform = v2.Compose([
            v2.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),
            v2.RandomPerspective(distortion_scale=0.5),
            v2.RandomAffine(degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1)),
            v2.GaussianBlur(kernel_size=(9, 9), sigma=(0.1, 2.0)),
            v2.Resize(target_size),
            v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ])

        # Inference transform (no augmentation).
        self.inference_transform = v2.Compose([
            v2.Resize(target_size),
            v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ])

    def forward(self, qpos, image, actions=None, is_pad=None):
        """Forward pass for training or inference.

        Defined as ``forward`` (instead of overriding ``__call__`` directly,
        as before) so ``nn.Module.__call__`` dispatches here and registered
        module hooks still fire; callers keep invoking the policy as
        ``policy(qpos, image, ...)``.

        Args:
            qpos: Joint positions [B, state_dim].
            image: Camera images [B, num_cameras, C, H, W].
            actions: Ground-truth actions [B, chunk_size, action_dim]
                (training only).
            is_pad: Boolean padding mask [B, chunk_size] (training only).

        Returns:
            Training (``actions`` given): dict with key ``'loss'`` holding the
            masked mean of the model's per-element action loss.
            Inference: predicted actions [B, num_queries, action_dim].
        """
        if actions is not None:  # training time
            image = self.train_transform(image)

            # Truncate the chunk to the model's prediction horizon.
            actions = actions[:, :self.model.num_queries]
            is_pad = is_pad[:, :self.model.num_queries]
            _, action_loss = self.model(qpos, image, actions, is_pad)

            # Zero out padded positions before averaging. NOTE(review): the
            # mean still divides by the total element count (padded elements
            # included), matching the ACT-style convention — confirm this
            # scaling is intended rather than dividing by the valid count.
            mse_loss = (action_loss * ~is_pad.unsqueeze(-1)).mean()
            return {'loss': mse_loss}

        # Inference time: no augmentation, sample predicted actions.
        image = self.inference_transform(image)
        a_hat, _ = self.model(qpos, image)
        return a_hat

    def configure_optimizers(self):
        """Return the optimizer created alongside the model."""
        return self.optimizer
|
||||
Reference in New Issue
Block a user