# roboimi/roboimi/gr00t/main.py
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
GR00T (diffusion-based DiT policy) model builder.
This module provides functions to build GR00T models and optimizers
from configuration dictionaries (typically from config.yaml's 'gr00t:' section).
"""
import argparse

import torch

from .models import build_gr00t_model
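
# For reference, a minimal sketch of the dictionary this module expects as
# `args_override` (mirroring config.yaml's 'gr00t:' section). The keys come
# from get_args_parser() below; the values are illustrative assumptions, not
# recommended settings:
#
#     gr00t_config = {
#         'embed_dim': 1536,
#         'state_dim': 16,
#         'action_dim': 16,
#         'num_queries': 16,
#         'camera_names': ['top'],
#         'lr': 1e-5,
#         'lr_backbone': 1e-5,
#     }
#     model, optimizer = build_gr00t_model_and_optimizer(gr00t_config)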


def get_args_parser():
    """
    Create the argument parser holding GR00T model configuration defaults.

    All parameters can be overridden via the args_override dictionary passed to
    build_gr00t_model_and_optimizer(), which allows loading them from config.yaml.
    """
    parser = argparse.ArgumentParser('GR00T training and evaluation script', add_help=False)

    # Training parameters
    parser.add_argument('--lr', default=1e-5, type=float,
                        help='Learning rate for main parameters')
    parser.add_argument('--lr_backbone', default=1e-5, type=float,
                        help='Learning rate for backbone parameters')
    parser.add_argument('--weight_decay', default=1e-4, type=float,
                        help='Weight decay for optimizer')

    # GR00T model architecture parameters
    parser.add_argument('--embed_dim', default=1536, type=int,
                        help='Embedding dimension for transformer')
    parser.add_argument('--hidden_dim', default=1024, type=int,
                        help='Hidden dimension for MLP layers')
    parser.add_argument('--state_dim', default=16, type=int,
                        help='State (qpos) dimension')
    parser.add_argument('--action_dim', default=16, type=int,
                        help='Action dimension')
    parser.add_argument('--num_queries', default=16, type=int,
                        help='Number of action queries (chunk size)')

    # DiT (Diffusion Transformer) parameters
    parser.add_argument('--num_layers', default=16, type=int,
                        help='Number of transformer layers')
    parser.add_argument('--nheads', default=32, type=int,
                        help='Number of attention heads')
    parser.add_argument('--mlp_ratio', default=4, type=float,
                        help='MLP hidden dimension ratio')
    parser.add_argument('--dropout', default=0.2, type=float,
                        help='Dropout rate')

    # Backbone parameters
    parser.add_argument('--backbone', default='dino_v2', type=str,
                        help='Backbone architecture (dino_v2, resnet18, resnet34)')
    parser.add_argument('--position_embedding', default='sine', type=str,
                        choices=('sine', 'learned'),
                        help='Type of positional encoding')

    # Camera configuration
    parser.add_argument('--camera_names', default=[], nargs='+',
                        help='List of camera names for observations')

    # Other parameters (not directly used but kept for compatibility)
    parser.add_argument('--batch_size', default=15, type=int)
    parser.add_argument('--epochs', default=20000, type=int)
    parser.add_argument('--masks', action='store_true',
                        help='Use intermediate layer features')
    parser.add_argument('--dilation', action='store_false',
                        help='If set, disable dilated convolution in the backbone '
                             '(enabled by default)')
    return parser
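
# For illustration, the parser's defaults can be inspected in isolation by
# parsing an empty argument list (so sys.argv is never touched); the camera
# names below are illustrative assumptions:
#
#     args = get_args_parser().parse_args([])
#     args.camera_names = ['top', 'wrist']
#     print(args.embed_dim, args.num_queries)  # -> 1536 16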


def build_gr00t_model_and_optimizer(args_override):
    """
    Build a GR00T model and its optimizer from a config dictionary.

    This function is designed to work with config.yaml loading:
      1. Parse the default arguments.
      2. Override them with values from args_override (typically config['gr00t']).
      3. Build the model and optimizer.

    Args:
        args_override: Dictionary of config values, typically from config.yaml's
            'gr00t:' section. Expected keys: embed_dim, hidden_dim, state_dim,
            action_dim, num_queries, nheads, mlp_ratio, dropout, num_layers,
            lr, lr_backbone, camera_names, backbone, etc.

    Returns:
        model: GR00T model moved to CUDA.
        optimizer: AdamW optimizer with a separate learning rate for backbone
            parameters.
    """
    parser = argparse.ArgumentParser('GR00T training and evaluation script',
                                     parents=[get_args_parser()])
    # Parse an empty argument list so only the defaults are used; parsing
    # sys.argv here would clash with the calling script's own arguments.
    args = parser.parse_args([])

    # Override defaults with config values
    for k, v in args_override.items():
        setattr(args, k, v)

    # Build model (CUDA device assumed available)
    model = build_gr00t_model(args)
    model.cuda()

    # Create parameter groups so the backbone can use its own learning rate;
    # the first group falls back to the optimizer-level lr below.
    param_dicts = [
        {
            "params": [p for n, p in model.named_parameters()
                       if "backbone" not in n and p.requires_grad]
        },
        {
            "params": [p for n, p in model.named_parameters()
                       if "backbone" in n and p.requires_grad],
            "lr": args.lr_backbone,
        },
    ]
    optimizer = torch.optim.AdamW(param_dicts,
                                  lr=args.lr,
                                  weight_decay=args.weight_decay)
    return model, optimizer
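

if __name__ == '__main__':
    # Minimal usage sketch, not part of the library API. Assumptions: PyYAML is
    # installed, a config.yaml with a 'gr00t:' section sits in the working
    # directory, and a CUDA device is available (the builder calls model.cuda()).
    import yaml

    with open('config.yaml') as f:
        config = yaml.safe_load(f)
    model, optimizer = build_gr00t_model_and_optimizer(config['gr00t'])
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f'Built GR00T model with {n_params / 1e6:.1f}M trainable parameters')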