feat(policy): 引入gr00t(DiT)
This commit is contained in:
@@ -8,7 +8,7 @@ temporal_agg: false
|
||||
|
||||
# policy_class: "ACT"
|
||||
# backbone: 'resnet18'
|
||||
policy_class: "ACTTV"
|
||||
policy_class: "GR00T"
|
||||
backbone: 'dino_v2'
|
||||
|
||||
seed: 0
|
||||
@@ -51,6 +51,21 @@ nheads: 8
|
||||
qpos_noise_std: 0
|
||||
DT: 0.02
|
||||
|
||||
gr00t:
|
||||
action_dim: 16
|
||||
state_dim: 16
|
||||
embed_dim: 1536
|
||||
hidden_dim: 1024
|
||||
num_queries: 8
|
||||
|
||||
nheads: 32
|
||||
mlp_ratio: 4
|
||||
dropout: 0.2
|
||||
|
||||
num_layers: 16
|
||||
|
||||
|
||||
|
||||
# DO NOT CHANGE IF UNNECESSARY
|
||||
lr: 0.00001
|
||||
kl_weight: 100
|
||||
|
||||
@@ -71,11 +71,10 @@ def run_episode(config, policy, stats, save_episode,num_rollouts):
|
||||
qpos = pre_process(qpos_numpy)
|
||||
qpos = torch.from_numpy(qpos).float().cuda().unsqueeze(0)
|
||||
curr_image = get_image(env._get_image_obs(), config['camera_names'])
|
||||
if config['policy_class'] == "ACT" or "ACTTV":
|
||||
if config['policy_class'] in ["ACT", "ACTTV", "GR00T"]:
|
||||
if t % query_frequency == 0:
|
||||
all_actions = policy(qpos, curr_image)
|
||||
raw_action = all_actions[:, t % query_frequency]
|
||||
# raw_action = all_actions[:, t % 1]
|
||||
raw_action = raw_action.squeeze(0).cpu().numpy()
|
||||
elif config['policy_class'] == "CNNMLP":
|
||||
raw_action = policy(qpos, curr_image)
|
||||
|
||||
125
roboimi/gr00t/main.py
Normal file
125
roboimi/gr00t/main.py
Normal file
@@ -0,0 +1,125 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
"""
|
||||
GR00T (diffusion-based DiT policy) model builder.
|
||||
|
||||
This module provides functions to build GR00T models and optimizers
|
||||
from configuration dictionaries (typically from config.yaml's 'gr00t:' section).
|
||||
"""
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from .models import build_gr00t_model
|
||||
|
||||
|
||||
def get_args_parser():
|
||||
"""
|
||||
Create argument parser for GR00T model configuration.
|
||||
|
||||
All parameters can be overridden via args_override dictionary in
|
||||
build_gr00t_model_and_optimizer(). This allows loading from config.yaml.
|
||||
"""
|
||||
parser = argparse.ArgumentParser('GR00T training and evaluation script', add_help=False)
|
||||
|
||||
# Training parameters
|
||||
parser.add_argument('--lr', default=1e-5, type=float,
|
||||
help='Learning rate for main parameters')
|
||||
parser.add_argument('--lr_backbone', default=1e-5, type=float,
|
||||
help='Learning rate for backbone parameters')
|
||||
parser.add_argument('--weight_decay', default=1e-4, type=float,
|
||||
help='Weight decay for optimizer')
|
||||
|
||||
# GR00T model architecture parameters
|
||||
parser.add_argument('--embed_dim', default=1536, type=int,
|
||||
help='Embedding dimension for transformer')
|
||||
parser.add_argument('--hidden_dim', default=1024, type=int,
|
||||
help='Hidden dimension for MLP layers')
|
||||
parser.add_argument('--state_dim', default=16, type=int,
|
||||
help='State (qpos) dimension')
|
||||
parser.add_argument('--action_dim', default=16, type=int,
|
||||
help='Action dimension')
|
||||
parser.add_argument('--num_queries', default=16, type=int,
|
||||
help='Number of action queries (chunk size)')
|
||||
|
||||
# DiT (Diffusion Transformer) parameters
|
||||
parser.add_argument('--num_layers', default=16, type=int,
|
||||
help='Number of transformer layers')
|
||||
parser.add_argument('--nheads', default=32, type=int,
|
||||
help='Number of attention heads')
|
||||
parser.add_argument('--mlp_ratio', default=4, type=float,
|
||||
help='MLP hidden dimension ratio')
|
||||
parser.add_argument('--dropout', default=0.2, type=float,
|
||||
help='Dropout rate')
|
||||
|
||||
# Backbone parameters
|
||||
parser.add_argument('--backbone', default='dino_v2', type=str,
|
||||
help='Backbone architecture (dino_v2, resnet18, resnet34)')
|
||||
parser.add_argument('--position_embedding', default='sine', type=str,
|
||||
choices=('sine', 'learned'),
|
||||
help='Type of positional encoding')
|
||||
|
||||
# Camera configuration
|
||||
parser.add_argument('--camera_names', default=[], nargs='+',
|
||||
help='List of camera names for observations')
|
||||
|
||||
# Other parameters (not directly used but kept for compatibility)
|
||||
parser.add_argument('--batch_size', default=15, type=int)
|
||||
parser.add_argument('--epochs', default=20000, type=int)
|
||||
parser.add_argument('--masks', action='store_true',
|
||||
help='Use intermediate layer features')
|
||||
parser.add_argument('--dilation', action='store_false',
|
||||
help='Use dilated convolution in backbone')
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def build_gr00t_model_and_optimizer(args_override):
|
||||
"""
|
||||
Build GR00T model and optimizer from config dictionary.
|
||||
|
||||
This function is designed to work with config.yaml loading:
|
||||
1. Parse default arguments
|
||||
2. Override with values from args_override (typically from config['gr00t'])
|
||||
3. Build model and optimizer
|
||||
|
||||
Args:
|
||||
args_override: Dictionary of config values, typically from config.yaml's 'gr00t:' section
|
||||
Expected keys: embed_dim, hidden_dim, state_dim, action_dim,
|
||||
num_queries, nheads, mlp_ratio, dropout, num_layers,
|
||||
lr, lr_backbone, camera_names, backbone, etc.
|
||||
|
||||
Returns:
|
||||
model: GR00T model on CUDA
|
||||
optimizer: AdamW optimizer with separate learning rates for backbone and other params
|
||||
"""
|
||||
parser = argparse.ArgumentParser('GR00T training and evaluation script',
|
||||
parents=[get_args_parser()])
|
||||
args = parser.parse_args()
|
||||
|
||||
# Override with config values
|
||||
for k, v in args_override.items():
|
||||
setattr(args, k, v)
|
||||
|
||||
# Build model
|
||||
model = build_gr00t_model(args)
|
||||
model.cuda()
|
||||
|
||||
# Create parameter groups with different learning rates
|
||||
param_dicts = [
|
||||
{
|
||||
"params": [p for n, p in model.named_parameters()
|
||||
if "backbone" not in n and p.requires_grad]
|
||||
},
|
||||
{
|
||||
"params": [p for n, p in model.named_parameters()
|
||||
if "backbone" in n and p.requires_grad],
|
||||
"lr": args.lr_backbone,
|
||||
},
|
||||
]
|
||||
|
||||
optimizer = torch.optim.AdamW(param_dicts,
|
||||
lr=args.lr,
|
||||
weight_decay=args.weight_decay)
|
||||
|
||||
return model, optimizer
|
||||
3
roboimi/gr00t/models/__init__.py
Normal file
3
roboimi/gr00t/models/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .gr00t import build_gr00t_model
|
||||
|
||||
__all__ = ['build_gr00t_model']
|
||||
168
roboimi/gr00t/models/backbone.py
Normal file
168
roboimi/gr00t/models/backbone.py
Normal file
@@ -0,0 +1,168 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
"""
|
||||
Backbone modules.
|
||||
"""
|
||||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torchvision
|
||||
from torch import nn
|
||||
from torchvision.models._utils import IntermediateLayerGetter
|
||||
from typing import Dict, List
|
||||
|
||||
from util.misc import NestedTensor, is_main_process
|
||||
|
||||
from .position_encoding import build_position_encoding
|
||||
|
||||
class FrozenBatchNorm2d(torch.nn.Module):
|
||||
"""
|
||||
BatchNorm2d where the batch statistics and the affine parameters are fixed.
|
||||
|
||||
Copy-paste from torchvision.misc.ops with added eps before rqsrt,
|
||||
without which any other policy_models than torchvision.policy_models.resnet[18,34,50,101]
|
||||
produce nans.
|
||||
"""
|
||||
|
||||
def __init__(self, n):
|
||||
super(FrozenBatchNorm2d, self).__init__()
|
||||
self.register_buffer("weight", torch.ones(n))
|
||||
self.register_buffer("bias", torch.zeros(n))
|
||||
self.register_buffer("running_mean", torch.zeros(n))
|
||||
self.register_buffer("running_var", torch.ones(n))
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
|
||||
missing_keys, unexpected_keys, error_msgs):
|
||||
num_batches_tracked_key = prefix + 'num_batches_tracked'
|
||||
if num_batches_tracked_key in state_dict:
|
||||
del state_dict[num_batches_tracked_key]
|
||||
|
||||
super(FrozenBatchNorm2d, self)._load_from_state_dict(
|
||||
state_dict, prefix, local_metadata, strict,
|
||||
missing_keys, unexpected_keys, error_msgs)
|
||||
|
||||
def forward(self, x):
|
||||
# move reshapes to the beginning
|
||||
# to make it fuser-friendly
|
||||
w = self.weight.reshape(1, -1, 1, 1)
|
||||
b = self.bias.reshape(1, -1, 1, 1)
|
||||
rv = self.running_var.reshape(1, -1, 1, 1)
|
||||
rm = self.running_mean.reshape(1, -1, 1, 1)
|
||||
eps = 1e-5
|
||||
scale = w * (rv + eps).rsqrt()
|
||||
bias = b - rm * scale
|
||||
return x * scale + bias
|
||||
|
||||
|
||||
class BackboneBase(nn.Module):
|
||||
|
||||
def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool):
|
||||
super().__init__()
|
||||
# for name, parameter in backbone.named_parameters(): # only train later layers # TODO do we want this?
|
||||
# if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
|
||||
# parameter.requires_grad_(False)
|
||||
if return_interm_layers:
|
||||
return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
|
||||
else:
|
||||
return_layers = {'layer4': "0"}
|
||||
self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
|
||||
self.num_channels = num_channels
|
||||
|
||||
def forward(self, tensor):
|
||||
xs = self.body(tensor)
|
||||
return xs
|
||||
# out: Dict[str, NestedTensor] = {}
|
||||
# for name, x in xs.items():
|
||||
# m = tensor_list.mask
|
||||
# assert m is not None
|
||||
# mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
|
||||
# out[name] = NestedTensor(x, mask)
|
||||
# return out
|
||||
|
||||
|
||||
class Backbone(BackboneBase):
|
||||
"""ResNet backbone with frozen BatchNorm."""
|
||||
def __init__(self, name: str,
|
||||
train_backbone: bool,
|
||||
return_interm_layers: bool,
|
||||
dilation: bool):
|
||||
backbone = getattr(torchvision.models, name)(
|
||||
replace_stride_with_dilation=[False, False, dilation],
|
||||
pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d) # pretrained # TODO do we want frozen batch_norm??
|
||||
num_channels = 512 if name in ('resnet18', 'resnet34') else 2048
|
||||
super().__init__(backbone, train_backbone, num_channels, return_interm_layers)
|
||||
|
||||
|
||||
# class DINOv2BackBone(nn.Module):
|
||||
# def __init__(self) -> None:
|
||||
# super().__init__()
|
||||
# self.body = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
|
||||
# self.body.eval()
|
||||
# self.num_channels = 384
|
||||
|
||||
# @torch.no_grad()
|
||||
# def forward(self, tensor):
|
||||
# xs = self.body.forward_features(tensor)["x_norm_patchtokens"]
|
||||
# od = OrderedDict()
|
||||
# od["0"] = xs.reshape(xs.shape[0], 22, 16, 384).permute(0, 3, 2, 1)
|
||||
# return od
|
||||
|
||||
class DINOv2BackBone(nn.Module):
|
||||
def __init__(self, return_interm_layers: bool = False) -> None:
|
||||
super().__init__()
|
||||
self.body = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
|
||||
self.body.eval()
|
||||
self.num_channels = 384
|
||||
self.return_interm_layers = return_interm_layers
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, tensor):
|
||||
features = self.body.forward_features(tensor)
|
||||
|
||||
if self.return_interm_layers:
|
||||
|
||||
layer1 = features["x_norm_patchtokens"]
|
||||
layer2 = features["x_norm_patchtokens"]
|
||||
layer3 = features["x_norm_patchtokens"]
|
||||
layer4 = features["x_norm_patchtokens"]
|
||||
|
||||
od = OrderedDict()
|
||||
od["0"] = layer1.reshape(layer1.shape[0], 22, 16, 384).permute(0, 3, 2, 1)
|
||||
od["1"] = layer2.reshape(layer2.shape[0], 22, 16, 384).permute(0, 3, 2, 1)
|
||||
od["2"] = layer3.reshape(layer3.shape[0], 22, 16, 384).permute(0, 3, 2, 1)
|
||||
od["3"] = layer4.reshape(layer4.shape[0], 22, 16, 384).permute(0, 3, 2, 1)
|
||||
return od
|
||||
else:
|
||||
xs = features["x_norm_patchtokens"]
|
||||
od = OrderedDict()
|
||||
od["0"] = xs.reshape(xs.shape[0], 22, 16, 384).permute(0, 3, 2, 1)
|
||||
return od
|
||||
|
||||
class Joiner(nn.Sequential):
|
||||
def __init__(self, backbone, position_embedding):
|
||||
super().__init__(backbone, position_embedding)
|
||||
|
||||
def forward(self, tensor_list: NestedTensor):
|
||||
xs = self[0](tensor_list)
|
||||
out: List[NestedTensor] = []
|
||||
pos = []
|
||||
for name, x in xs.items():
|
||||
out.append(x)
|
||||
# position encoding
|
||||
pos.append(self[1](x).to(x.dtype))
|
||||
|
||||
return out, pos
|
||||
|
||||
|
||||
def build_backbone(args):
|
||||
position_embedding = build_position_encoding(args)
|
||||
train_backbone = args.lr_backbone > 0
|
||||
return_interm_layers = args.masks
|
||||
if args.backbone == 'dino_v2':
|
||||
backbone = DINOv2BackBone()
|
||||
else:
|
||||
assert args.backbone in ['resnet18', 'resnet34']
|
||||
backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation)
|
||||
model = Joiner(backbone, position_embedding)
|
||||
model.num_channels = backbone.num_channels
|
||||
return model
|
||||
142
roboimi/gr00t/models/dit.py
Normal file
142
roboimi/gr00t/models/dit.py
Normal file
@@ -0,0 +1,142 @@
|
||||
from typing import Optional
|
||||
|
||||
from diffusers import ConfigMixin, ModelMixin
|
||||
from diffusers.configuration_utils import register_to_config
|
||||
from diffusers.models.embeddings import SinusoidalPositionalEmbedding, TimestepEmbedding, Timesteps
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
class TimestepEncoder(nn.Module):
|
||||
def __init__(self, args):
|
||||
super().__init__()
|
||||
embedding_dim = args.embed_dim
|
||||
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
|
||||
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
||||
|
||||
def forward(self, timesteps):
|
||||
dtype = next(self.parameters()).dtype
|
||||
timesteps_proj = self.time_proj(timesteps).to(dtype)
|
||||
timesteps_emb = self.timestep_embedder(timesteps_proj) # (N, D)
|
||||
return timesteps_emb
|
||||
|
||||
|
||||
class AdaLayerNorm(nn.Module):
|
||||
def __init__(self, embedding_dim, norm_eps=1e-5, norm_elementwise_affine=False):
|
||||
super().__init__()
|
||||
|
||||
output_dim = embedding_dim * 2
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = nn.Linear(embedding_dim, output_dim)
|
||||
self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
temb = self.linear(self.silu(temb))
|
||||
scale, shift = temb.chunk(2, dim=1)
|
||||
x = self.norm(x) * (1 + scale[:, None]) + shift[:, None]
|
||||
return x
|
||||
|
||||
|
||||
class BasicTransformerBlock(nn.Module):
|
||||
def __init__(self, args, crosss_attention_dim, use_self_attn=False):
|
||||
super().__init__()
|
||||
dim = args.embed_dim
|
||||
num_heads = args.nheads
|
||||
mlp_ratio = args.mlp_ratio
|
||||
dropout = args.dropout
|
||||
self.norm1 = AdaLayerNorm(dim)
|
||||
|
||||
if not use_self_attn:
|
||||
self.attn = nn.MultiheadAttention(
|
||||
embed_dim=dim,
|
||||
num_heads=num_heads,
|
||||
dropout=dropout,
|
||||
kdim=crosss_attention_dim,
|
||||
vdim=crosss_attention_dim,
|
||||
batch_first=True,
|
||||
)
|
||||
else:
|
||||
self.attn = nn.MultiheadAttention(
|
||||
embed_dim=dim,
|
||||
num_heads=num_heads,
|
||||
dropout=dropout,
|
||||
batch_first=True,
|
||||
)
|
||||
|
||||
self.norm2 = nn.LayerNorm(dim, eps=1e-5, elementwise_affine=False)
|
||||
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(dim, dim * mlp_ratio),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(dim * mlp_ratio, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, hidden_states, temb, context=None):
|
||||
norm_hidden_states = self.norm1(hidden_states, temb)
|
||||
|
||||
attn_output = self.attn(
|
||||
norm_hidden_states,
|
||||
context if context is not None else norm_hidden_states,
|
||||
context if context is not None else norm_hidden_states,
|
||||
)[0]
|
||||
|
||||
hidden_states = attn_output + hidden_states
|
||||
|
||||
norm_hidden_states = self.norm2(hidden_states)
|
||||
|
||||
ff_output = self.mlp(norm_hidden_states)
|
||||
|
||||
hidden_states = ff_output + hidden_states
|
||||
|
||||
return hidden_states
|
||||
|
||||
class DiT(nn.Module):
|
||||
def __init__(self, args, cross_attention_dim):
|
||||
super().__init__()
|
||||
inner_dim = args.embed_dim
|
||||
num_layers = args.num_layers
|
||||
output_dim = args.hidden_dim
|
||||
|
||||
self.timestep_encoder = TimestepEncoder(args)
|
||||
|
||||
all_blocks = []
|
||||
for idx in range(num_layers):
|
||||
use_self_attn = idx % 2 == 1
|
||||
if use_self_attn:
|
||||
block = BasicTransformerBlock(args, crosss_attention_dim=None, use_self_attn=True)
|
||||
else:
|
||||
block = BasicTransformerBlock(args, crosss_attention_dim=cross_attention_dim, use_self_attn=False)
|
||||
all_blocks.append(block)
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(all_blocks)
|
||||
|
||||
self.norm_out = nn.LayerNorm(inner_dim, eps=1e-6, elementwise_affine=False)
|
||||
self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
|
||||
self.proj_out_2 = nn.Linear(inner_dim, output_dim)
|
||||
|
||||
def forward(self, hidden_states, timestep, encoder_hidden_states):
|
||||
temb = self.timestep_encoder(timestep)
|
||||
|
||||
hidden_states = hidden_states.contiguous()
|
||||
encoder_hidden_states = encoder_hidden_states.contiguous()
|
||||
|
||||
for idx, block in enumerate(self.transformer_blocks):
|
||||
if idx % 2 == 1:
|
||||
hidden_states = block(hidden_states, temb)
|
||||
else:
|
||||
hidden_states = block(hidden_states, temb, context=encoder_hidden_states)
|
||||
|
||||
conditioning = temb
|
||||
shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
|
||||
hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
|
||||
return self.proj_out_2(hidden_states)
|
||||
|
||||
|
||||
def build_dit(args, cross_attention_dim):
|
||||
return DiT(args, cross_attention_dim)
|
||||
124
roboimi/gr00t/models/gr00t.py
Normal file
124
roboimi/gr00t/models/gr00t.py
Normal file
@@ -0,0 +1,124 @@
|
||||
|
||||
from .modules import (
|
||||
build_action_decoder,
|
||||
build_action_encoder,
|
||||
build_state_encoder,
|
||||
build_time_sampler,
|
||||
build_noise_scheduler,
|
||||
)
|
||||
from .backbone import build_backbone
|
||||
from .dit import build_dit
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
class gr00t(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
backbones,
|
||||
dit,
|
||||
state_encoder,
|
||||
action_encoder,
|
||||
action_decoder,
|
||||
time_sampler,
|
||||
noise_scheduler,
|
||||
num_queries,
|
||||
camera_names,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_queries = num_queries
|
||||
self.camera_names = camera_names
|
||||
self.dit = dit
|
||||
self.state_encoder = state_encoder
|
||||
self.action_encoder = action_encoder
|
||||
self.action_decoder = action_decoder
|
||||
self.time_sampler = time_sampler
|
||||
self.noise_scheduler = noise_scheduler
|
||||
|
||||
if backbones is not None:
|
||||
self.backbones = nn.ModuleList(backbones)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def forward(self, qpos, image, actions=None, is_pad=None):
|
||||
is_training = actions is not None # train or val
|
||||
bs, _ = qpos.shape
|
||||
|
||||
all_cam_features = []
|
||||
for cam_id, cam_name in enumerate(self.camera_names):
|
||||
# features, pos = self.backbones[0](image[:, cam_id]) # HARDCODED
|
||||
features, pos = self.backbones[cam_id](image[:, cam_id])
|
||||
features = features[0] # take the last layer feature
|
||||
B, C, H, W = features.shape
|
||||
features_seq = features.permute(0, 2, 3, 1).reshape(B, H * W, C)
|
||||
all_cam_features.append(features_seq)
|
||||
encoder_hidden_states = torch.cat(all_cam_features, dim=1)
|
||||
|
||||
state_features = self.state_encoder(qpos) # [B, 1, emb_dim]
|
||||
|
||||
if is_training:
|
||||
# training logic
|
||||
|
||||
timesteps = self.time_sampler(bs, actions.device, actions.dtype)
|
||||
noisy_actions, target_velocity = self.noise_scheduler.add_noise(
|
||||
actions, timesteps
|
||||
)
|
||||
t_discretized = (timesteps[:, 0, 0] * 1000).long()
|
||||
action_features = self.action_encoder(noisy_actions, t_discretized)
|
||||
sa_embs = torch.cat((state_features, action_features), dim=1)
|
||||
model_output = self.dit(sa_embs, t_discretized, encoder_hidden_states)
|
||||
pred = self.action_decoder(model_output)
|
||||
pred_actions = pred[:, -actions.shape[1] :]
|
||||
action_loss = F.mse_loss(pred_actions, target_velocity, reduction='none')
|
||||
return pred_actions, action_loss
|
||||
else:
|
||||
actions = torch.randn(bs, self.num_queries, qpos.shape[-1], device=qpos.device, dtype=qpos.dtype)
|
||||
k = 5
|
||||
dt = 1.0 / k
|
||||
for t in range(k):
|
||||
t_cont = t / float(k)
|
||||
t_discretized = int(t_cont * 1000)
|
||||
timesteps = torch.full((bs,), t_discretized, device=qpos.device, dtype=qpos.dtype)
|
||||
action_features = self.action_encoder(actions, timesteps)
|
||||
sa_embs = torch.cat((state_features, action_features), dim=1)
|
||||
# Create tensor of shape [B] for DiT (consistent with training path)
|
||||
model_output = self.dit(sa_embs, timesteps, encoder_hidden_states)
|
||||
pred = self.action_decoder(model_output)
|
||||
pred_velocity = pred[:, -self.num_queries :]
|
||||
actions = actions + pred_velocity * dt
|
||||
return actions, _
|
||||
def build_gr00t_model(args):
|
||||
state_dim = args.state_dim
|
||||
action_dim = args.action_dim
|
||||
|
||||
backbones = []
|
||||
for _ in args.camera_names:
|
||||
backbone = build_backbone(args)
|
||||
backbones.append(backbone)
|
||||
|
||||
cross_attention_dim = backbones[0].num_channels
|
||||
|
||||
dit = build_dit(args, cross_attention_dim)
|
||||
|
||||
state_encoder = build_state_encoder(args)
|
||||
action_encoder = build_action_encoder(args)
|
||||
action_decoder = build_action_decoder(args)
|
||||
time_sampler = build_time_sampler(args)
|
||||
noise_scheduler = build_noise_scheduler(args)
|
||||
model = gr00t(
|
||||
backbones,
|
||||
dit,
|
||||
state_encoder,
|
||||
action_encoder,
|
||||
action_decoder,
|
||||
time_sampler,
|
||||
noise_scheduler,
|
||||
args.num_queries,
|
||||
args.camera_names,
|
||||
)
|
||||
|
||||
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
print("number of parameters: %.2fM" % (n_parameters/1e6,))
|
||||
return model
|
||||
|
||||
|
||||
179
roboimi/gr00t/models/modules.py
Normal file
179
roboimi/gr00t/models/modules.py
Normal file
@@ -0,0 +1,179 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
# ActionEncoder
|
||||
class SinusoidalPositionalEncoding(nn.Module):
|
||||
def __init__(self, args):
|
||||
super().__init__()
|
||||
self.embed_dim = args.embed_dim
|
||||
|
||||
def forward(self, timesteps):
|
||||
timesteps = timesteps.float()
|
||||
B, T = timesteps.shape
|
||||
device = timesteps.device
|
||||
|
||||
half_dim = self.embed_dim // 2
|
||||
|
||||
exponent = -torch.arange(half_dim, dtype=torch.float, device=device) * (
|
||||
torch.log(torch.tensor(10000.0)) / half_dim
|
||||
)
|
||||
|
||||
freqs = timesteps.unsqueeze(-1) * exponent.exp()
|
||||
|
||||
sin = torch.sin(freqs)
|
||||
cos = torch.cos(freqs)
|
||||
enc = torch.cat([sin, cos], dim=-1) # (B, T, w)
|
||||
|
||||
return enc
|
||||
|
||||
|
||||
class ActionEncoder(nn.Module):
|
||||
def __init__(self, args):
|
||||
super().__init__()
|
||||
action_dim = args.action_dim
|
||||
embed_dim = args.embed_dim
|
||||
|
||||
self.W1 = nn.Linear(action_dim, embed_dim)
|
||||
self.W2 = nn.Linear(2 * embed_dim, embed_dim)
|
||||
self.W3 = nn.Linear(embed_dim, embed_dim)
|
||||
|
||||
self.pos_encoder = SinusoidalPositionalEncoding(args)
|
||||
|
||||
def forward(self, actions, timesteps):
|
||||
B, T, _ = actions.shape
|
||||
|
||||
# 1) Expand each batch's single scalar time 'tau' across all T steps
|
||||
# so that shape => (B, T)
|
||||
# Handle different input shapes: (B,), (B, 1), (B, 1, 1)
|
||||
# Reshape to (B,) then expand to (B, T)
|
||||
# if timesteps.dim() == 3:
|
||||
# # Shape (B, 1, 1) or (B, T, 1) -> (B,)
|
||||
# timesteps = timesteps[:, 0, 0]
|
||||
# elif timesteps.dim() == 2:
|
||||
# # Shape (B, 1) or (B, T) -> take first element if needed
|
||||
# if timesteps.shape[1] == 1:
|
||||
# timesteps = timesteps[:, 0]
|
||||
# # else: already (B, T), use as is
|
||||
# elif timesteps.dim() != 1:
|
||||
# raise ValueError(
|
||||
# f"Expected `timesteps` to have shape (B,), (B, 1), or (B, 1, 1), got {timesteps.shape}"
|
||||
# )
|
||||
|
||||
# Now timesteps should be (B,), expand to (B, T)
|
||||
if timesteps.dim() == 1 and timesteps.shape[0] == B:
|
||||
timesteps = timesteps.unsqueeze(1).expand(-1, T)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Expected `timesteps` to have shape (B,) so we can replicate across T."
|
||||
)
|
||||
|
||||
# 2) Standard action MLP step for shape => (B, T, w)
|
||||
a_emb = self.W1(actions)
|
||||
|
||||
# 3) Get the sinusoidal encoding (B, T, w)
|
||||
tau_emb = self.pos_encoder(timesteps).to(dtype=a_emb.dtype)
|
||||
|
||||
# 4) Concat along last dim => (B, T, 2w), then W2 => (B, T, w), swish
|
||||
x = torch.cat([a_emb, tau_emb], dim=-1)
|
||||
x = F.silu(self.W2(x))
|
||||
|
||||
# 5) Finally W3 => (B, T, w)
|
||||
x = self.W3(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def build_action_encoder(args):
|
||||
return ActionEncoder(args)
|
||||
|
||||
|
||||
# StateEncoder
|
||||
class StateEncoder(nn.Module):
|
||||
def __init__(self, args):
|
||||
super().__init__()
|
||||
input_dim = args.state_dim
|
||||
hidden_dim = args.hidden_dim
|
||||
output_dim = args.embed_dim
|
||||
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(input_dim, hidden_dim),
|
||||
nn.ReLU(),
|
||||
nn.Linear(hidden_dim, output_dim),
|
||||
)
|
||||
|
||||
def forward(self, states):
|
||||
state_emb = self.mlp(states) # [B, emb_dim]
|
||||
state_emb = state_emb.unsqueeze(1)
|
||||
return state_emb # [B, 1, emb_dim]
|
||||
|
||||
|
||||
def build_state_encoder(args):
|
||||
return StateEncoder(args)
|
||||
|
||||
|
||||
# ActionDecoder
|
||||
class ActionDecoder(nn.Module):
|
||||
def __init__(self,args):
|
||||
super().__init__()
|
||||
input_dim = args.hidden_dim
|
||||
hidden_dim = args.hidden_dim
|
||||
output_dim = args.action_dim
|
||||
|
||||
self.num_queries = args.num_queries
|
||||
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(input_dim, hidden_dim),
|
||||
nn.ReLU(),
|
||||
nn.Linear(hidden_dim, output_dim),
|
||||
)
|
||||
|
||||
def forward(self, model_output):
|
||||
pred_actions = self.mlp(model_output)
|
||||
return pred_actions[:, -self.num_queries:]
|
||||
|
||||
|
||||
def build_action_decoder(args):
|
||||
return ActionDecoder(args)
|
||||
|
||||
|
||||
# TimeSampler
|
||||
class TimeSampler(nn.Module):
|
||||
def __init__(self, noise_s = 0.999, noise_beta_alpha=1.5, noise_beta_beta=1.0):
|
||||
super().__init__()
|
||||
self.noise_s = noise_s
|
||||
self.beta_dist = torch.distributions.Beta(noise_beta_alpha, noise_beta_beta)
|
||||
|
||||
def forward(self, batch_size, device, dtype):
|
||||
sample = self.beta_dist.sample([batch_size]).to(device, dtype=dtype)
|
||||
sample = (1 - sample) * self.noise_s
|
||||
return sample[:, None, None]
|
||||
|
||||
|
||||
def build_time_sampler(args):
|
||||
return TimeSampler()
|
||||
|
||||
|
||||
# NoiseScheduler
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
class FlowMatchingScheduler(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
# --- 训练逻辑:加噪并计算目标 ---
|
||||
def add_noise(self, actions, timesteps):
|
||||
noise = torch.randn_like(actions)
|
||||
noisy_samples = actions * timesteps + noise * (1 - timesteps)
|
||||
target_velocity = actions - noise
|
||||
|
||||
return noisy_samples, target_velocity
|
||||
|
||||
# --- 推理逻辑:欧拉步 (Euler Step) ---
|
||||
def step(self, model_output, sample, dt):
|
||||
prev_sample = sample + model_output * dt
|
||||
return prev_sample
|
||||
|
||||
def build_noise_scheduler(args):
|
||||
return FlowMatchingScheduler()
|
||||
91
roboimi/gr00t/models/position_encoding.py
Normal file
91
roboimi/gr00t/models/position_encoding.py
Normal file
@@ -0,0 +1,91 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
"""
|
||||
Various positional encodings for the transformer.
|
||||
"""
|
||||
import math
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from util.misc import NestedTensor
|
||||
|
||||
|
||||
class PositionEmbeddingSine(nn.Module):
|
||||
"""
|
||||
This is a more standard version of the position embedding, very similar to the one
|
||||
used by the Attention is all you need paper, generalized to work on images.
|
||||
"""
|
||||
def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
|
||||
super().__init__()
|
||||
self.num_pos_feats = num_pos_feats
|
||||
self.temperature = temperature
|
||||
self.normalize = normalize
|
||||
if scale is not None and normalize is False:
|
||||
raise ValueError("normalize should be True if scale is passed")
|
||||
if scale is None:
|
||||
scale = 2 * math.pi
|
||||
self.scale = scale
|
||||
|
||||
def forward(self, tensor):
|
||||
x = tensor
|
||||
# mask = tensor_list.mask
|
||||
# assert mask is not None
|
||||
# not_mask = ~mask
|
||||
|
||||
not_mask = torch.ones_like(x[0, [0]])
|
||||
y_embed = not_mask.cumsum(1, dtype=torch.float32)
|
||||
x_embed = not_mask.cumsum(2, dtype=torch.float32)
|
||||
if self.normalize:
|
||||
eps = 1e-6
|
||||
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
|
||||
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
|
||||
|
||||
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
|
||||
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
|
||||
|
||||
pos_x = x_embed[:, :, :, None] / dim_t
|
||||
pos_y = y_embed[:, :, :, None] / dim_t
|
||||
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
|
||||
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
|
||||
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
|
||||
return pos
|
||||
|
||||
|
||||
class PositionEmbeddingLearned(nn.Module):
|
||||
"""
|
||||
Absolute pos embedding, learned.
|
||||
"""
|
||||
def __init__(self, num_pos_feats=256):
|
||||
super().__init__()
|
||||
self.row_embed = nn.Embedding(50, num_pos_feats)
|
||||
self.col_embed = nn.Embedding(50, num_pos_feats)
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
nn.init.uniform_(self.row_embed.weight)
|
||||
nn.init.uniform_(self.col_embed.weight)
|
||||
|
||||
def forward(self, tensor_list: NestedTensor):
|
||||
x = tensor_list.tensors
|
||||
h, w = x.shape[-2:]
|
||||
i = torch.arange(w, device=x.device)
|
||||
j = torch.arange(h, device=x.device)
|
||||
x_emb = self.col_embed(i)
|
||||
y_emb = self.row_embed(j)
|
||||
pos = torch.cat([
|
||||
x_emb.unsqueeze(0).repeat(h, 1, 1),
|
||||
y_emb.unsqueeze(1).repeat(1, w, 1),
|
||||
], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
|
||||
return pos
|
||||
|
||||
|
||||
def build_position_encoding(args):
|
||||
N_steps = args.hidden_dim // 2
|
||||
if args.position_embedding in ('v2', 'sine'):
|
||||
# TODO find a better way of exposing other arguments
|
||||
position_embedding = PositionEmbeddingSine(N_steps, normalize=True)
|
||||
elif args.position_embedding in ('v3', 'learned'):
|
||||
position_embedding = PositionEmbeddingLearned(N_steps)
|
||||
else:
|
||||
raise ValueError(f"not supported {args.position_embedding}")
|
||||
|
||||
return position_embedding
|
||||
90
roboimi/gr00t/policy.py
Normal file
90
roboimi/gr00t/policy.py
Normal file
@@ -0,0 +1,90 @@
|
||||
"""
|
||||
GR00T Policy wrapper for imitation learning.
|
||||
|
||||
This module provides the gr00tPolicy class that wraps the GR00T model
|
||||
for training and evaluation in the imitation learning framework.
|
||||
"""
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
from torchvision.transforms import v2
|
||||
import torch
|
||||
from roboimi.gr00t.main import build_gr00t_model_and_optimizer
|
||||
|
||||
|
||||
class gr00tPolicy(nn.Module):
|
||||
"""
|
||||
GR00T Policy for action prediction using diffusion-based DiT architecture.
|
||||
|
||||
This policy wraps the GR00T model and handles:
|
||||
- Image resizing to match DINOv2 patch size requirements
|
||||
- Image normalization (ImageNet stats)
|
||||
- Training with action chunks and loss computation
|
||||
- Inference with diffusion sampling
|
||||
"""
|
||||
def __init__(self, args_override):
|
||||
super().__init__()
|
||||
model, optimizer = build_gr00t_model_and_optimizer(args_override)
|
||||
self.model = model
|
||||
self.optimizer = optimizer
|
||||
|
||||
# DINOv2 requires image dimensions to be multiples of patch size (14)
|
||||
# Common sizes: 224x224, 336x336, etc. (14*16=224, 14*24=336)
|
||||
self.patch_h = 16 # Number of patches vertically
|
||||
self.patch_w = 22 # Number of patches horizontally
|
||||
target_size = (self.patch_h * 14, self.patch_w * 14) # (224, 308)
|
||||
|
||||
# Training transform with data augmentation
|
||||
self.train_transform = v2.Compose([
|
||||
v2.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),
|
||||
v2.RandomPerspective(distortion_scale=0.5),
|
||||
v2.RandomAffine(degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1)),
|
||||
v2.GaussianBlur(kernel_size=(9, 9), sigma=(0.1, 2.0)),
|
||||
v2.Resize(target_size),
|
||||
v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
|
||||
])
|
||||
|
||||
# Inference transform (no augmentation)
|
||||
self.inference_transform = v2.Compose([
|
||||
v2.Resize(target_size),
|
||||
v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
|
||||
])
|
||||
|
||||
def __call__(self, qpos, image, actions=None, is_pad=None):
|
||||
"""
|
||||
Forward pass for training or inference.
|
||||
|
||||
Args:
|
||||
qpos: Joint positions [B, state_dim]
|
||||
image: Camera images [B, num_cameras, C, H, W]
|
||||
actions: Ground truth actions [B, chunk_size, action_dim] (training only)
|
||||
is_pad: Padding mask [B, chunk_size] (training only)
|
||||
|
||||
Returns:
|
||||
Training: dict with 'mse' loss
|
||||
Inference: predicted actions [B, num_queries, action_dim]
|
||||
"""
|
||||
# Apply transforms (resize + normalization)
|
||||
if actions is not None: # training time
|
||||
image = self.train_transform(image)
|
||||
else: # inference time
|
||||
image = self.inference_transform(image)
|
||||
|
||||
if actions is not None: # training time
|
||||
actions = actions[:, :self.model.num_queries]
|
||||
is_pad = is_pad[:, :self.model.num_queries]
|
||||
_, action_loss = self.model(qpos, image, actions, is_pad)
|
||||
|
||||
# Mask out padded positions
|
||||
mse_loss = (action_loss * ~is_pad.unsqueeze(-1)).mean()
|
||||
|
||||
loss_dict = {
|
||||
'loss': mse_loss
|
||||
}
|
||||
return loss_dict
|
||||
else: # inference time
|
||||
a_hat, _ = self.model(qpos, image)
|
||||
return a_hat
|
||||
|
||||
def configure_optimizers(self):
|
||||
"""Return the optimizer for training."""
|
||||
return self.optimizer
|
||||
@@ -2,6 +2,7 @@ import os
|
||||
import torch
|
||||
from roboimi.utils.utils import load_data, set_seed
|
||||
from roboimi.detr.policy import ACTPolicy, CNNMLPPolicy, ACTTVPolicy
|
||||
from roboimi.gr00t.policy import gr00tPolicy
|
||||
|
||||
class ModelInterface:
|
||||
def __init__(self, config):
|
||||
@@ -65,6 +66,26 @@ class ModelInterface:
|
||||
'num_queries': 1,
|
||||
'camera_names': self.config['camera_names'],
|
||||
}
|
||||
elif self.config['policy_class'] == 'GR00T':
|
||||
# GR00T uses its own config section from config.yaml
|
||||
gr00t_config = self.config.get('gr00t', {})
|
||||
self.config['policy_config'] = {
|
||||
'lr': gr00t_config.get('lr', self.config['lr']),
|
||||
'lr_backbone': gr00t_config.get('lr_backbone', self.config['lr_backbone']),
|
||||
'weight_decay': gr00t_config.get('weight_decay', 1e-4),
|
||||
'embed_dim': gr00t_config.get('embed_dim', 1536),
|
||||
'hidden_dim': gr00t_config.get('hidden_dim', 1024),
|
||||
'state_dim': gr00t_config.get('state_dim', 16),
|
||||
'action_dim': gr00t_config.get('action_dim', 16),
|
||||
'num_queries': gr00t_config.get('num_queries', 16),
|
||||
'num_layers': gr00t_config.get('num_layers', 16),
|
||||
'nheads': gr00t_config.get('nheads', 32),
|
||||
'mlp_ratio': gr00t_config.get('mlp_ratio', 4),
|
||||
'dropout': gr00t_config.get('dropout', 0.2),
|
||||
'backbone': gr00t_config.get('backbone', 'dino_v2'),
|
||||
'position_embedding': gr00t_config.get('position_embedding', 'sine'),
|
||||
'camera_names': self.config['camera_names'],
|
||||
}
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -75,6 +96,8 @@ class ModelInterface:
|
||||
return ACTTVPolicy(self.config['policy_config'])
|
||||
elif self.config['policy_class'] == 'CNNMLP':
|
||||
return CNNMLPPolicy(self.config['policy_config'])
|
||||
elif self.config['policy_class'] == 'GR00T':
|
||||
return gr00tPolicy(self.config['policy_config'])
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
Reference in New Issue
Block a user