# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
GR00T (diffusion-based DiT policy) model builder.

This module provides functions to build GR00T models and optimizers
from configuration dictionaries (typically from config.yaml's 'gr00t:' section).
"""
import argparse
from pathlib import Path

import numpy as np
import torch

from .models import build_gr00t_model


def get_args_parser():
    """
    Create argument parser for GR00T model configuration.

    All parameters can be overridden via the args_override dictionary in
    build_gr00t_model_and_optimizer(). This allows loading from config.yaml.

    Returns:
        argparse.ArgumentParser: parser holding every GR00T hyperparameter
        with its default value (add_help=False so it can be used as a parent).
    """
    parser = argparse.ArgumentParser('GR00T training and evaluation script', add_help=False)

    # Training parameters
    parser.add_argument('--lr', default=1e-5, type=float,
                        help='Learning rate for main parameters')
    parser.add_argument('--lr_backbone', default=1e-5, type=float,
                        help='Learning rate for backbone parameters')
    parser.add_argument('--weight_decay', default=1e-4, type=float,
                        help='Weight decay for optimizer')

    # GR00T model architecture parameters
    parser.add_argument('--embed_dim', default=1536, type=int,
                        help='Embedding dimension for transformer')
    parser.add_argument('--hidden_dim', default=1024, type=int,
                        help='Hidden dimension for MLP layers')
    parser.add_argument('--state_dim', default=16, type=int,
                        help='State (qpos) dimension')
    parser.add_argument('--action_dim', default=16, type=int,
                        help='Action dimension')
    parser.add_argument('--num_queries', default=16, type=int,
                        help='Number of action queries (chunk size)')

    # DiT (Diffusion Transformer) parameters
    parser.add_argument('--num_layers', default=16, type=int,
                        help='Number of transformer layers')
    parser.add_argument('--nheads', default=32, type=int,
                        help='Number of attention heads')
    parser.add_argument('--mlp_ratio', default=4, type=float,
                        help='MLP hidden dimension ratio')
    parser.add_argument('--dropout', default=0.2, type=float,
                        help='Dropout rate')

    # Backbone parameters
    parser.add_argument('--backbone', default='dino_v2', type=str,
                        help='Backbone architecture (dino_v2, resnet18, resnet34)')
    parser.add_argument('--position_embedding', default='sine', type=str,
                        choices=('sine', 'learned'),
                        help='Type of positional encoding')

    # Camera configuration
    parser.add_argument('--camera_names', default=[], nargs='+',
                        help='List of camera names for observations')

    # Other parameters (not directly used but kept for compatibility)
    parser.add_argument('--batch_size', default=15, type=int)
    parser.add_argument('--epochs', default=20000, type=int)
    parser.add_argument('--masks', action='store_true',
                        help='Use intermediate layer features')
    parser.add_argument('--dilation', action='store_false',
                        help='Use dilated convolution in backbone')

    return parser


def build_gr00t_model_and_optimizer(args_override):
    """
    Build GR00T model and optimizer from a config dictionary.

    This function is designed to work with config.yaml loading:
    1. Take the parser's default arguments (the CLI is deliberately NOT
       consulted -- see note below)
    2. Override with values from args_override (typically config['gr00t'])
    3. Build the model and optimizer

    Args:
        args_override: Dictionary of config values, typically from
            config.yaml's 'gr00t:' section. Expected keys: embed_dim,
            hidden_dim, state_dim, action_dim, num_queries, nheads,
            mlp_ratio, dropout, num_layers, lr, lr_backbone,
            camera_names, backbone, etc.

    Returns:
        model: GR00T model moved to CUDA
        optimizer: AdamW optimizer with separate learning rates for
            backbone and non-backbone parameters
    """
    parser = argparse.ArgumentParser('GR00T training and evaluation script',
                                     parents=[get_args_parser()])
    # Parse an empty argv instead of sys.argv: this builder is called from
    # host scripts whose own CLI flags argparse would reject as
    # "unrecognized arguments". Defaults + args_override is the full,
    # documented configuration path.
    args = parser.parse_args([])

    # Override with config values
    for k, v in args_override.items():
        setattr(args, k, v)

    # Build model
    model = build_gr00t_model(args)
    model.cuda()

    # Create parameter groups with different learning rates:
    # backbone parameters train at lr_backbone, everything else at lr.
    param_dicts = [
        {
            "params": [p for n, p in model.named_parameters()
                       if "backbone" not in n and p.requires_grad]
        },
        {
            "params": [p for n, p in model.named_parameters()
                       if "backbone" in n and p.requires_grad],
            "lr": args.lr_backbone,
        },
    ]
    optimizer = torch.optim.AdamW(param_dicts, lr=args.lr,
                                  weight_decay=args.weight_decay)

    return model, optimizer