2 Commits
ddt ... main

Author SHA1 Message Date
JiajunLI  fdf4dd8bed  feat(policy): introduce gr00t (DiT)  2026-02-02 17:16:28 +08:00
JiajunLI  fd1bd20c4f  chore(constants): change the cameras used for training and inference  2026-01-28 19:32:56 +08:00
          - Now uses the top camera and the right wrist camera.
12 changed files with 967 additions and 8 deletions

View File

@@ -8,7 +8,7 @@ temporal_agg: false
 # policy_class: "ACT"
 # backbone: 'resnet18'
-policy_class: "ACTTV"
+policy_class: "GR00T"
 backbone: 'dino_v2'
 seed: 0
@@ -51,6 +51,21 @@ nheads: 8
 qpos_noise_std: 0
 DT: 0.02
+gr00t:
+  action_dim: 16
+  state_dim: 16
+  embed_dim: 1536
+  hidden_dim: 1024
+  num_queries: 8
+  nheads: 32
+  mlp_ratio: 4
+  dropout: 0.2
+  num_layers: 16
 # DO NOT CHANGE IF UNNECESSARY
 lr: 0.00001
 kl_weight: 100
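The new gr00t: block is consumed by ModelInterface later in this diff. A minimal sketch of loading and sanity-checking it (the config.yaml path and the PyYAML dependency are assumptions, not part of this change):

import yaml

with open("config.yaml") as f:
    config = yaml.safe_load(f)

gr00t_cfg = config.get("gr00t", {})
# Multi-head attention needs the embedding width to split evenly across heads.
assert gr00t_cfg["embed_dim"] % gr00t_cfg["nheads"] == 0
print("per-head dim:", gr00t_cfg["embed_dim"] // gr00t_cfg["nheads"])  # 1536 // 32 = 48
print("chunk size:", gr00t_cfg["num_queries"])                         # 8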

View File

@@ -71,11 +71,10 @@ def run_episode(config, policy, stats, save_episode,num_rollouts):
         qpos = pre_process(qpos_numpy)
         qpos = torch.from_numpy(qpos).float().cuda().unsqueeze(0)
         curr_image = get_image(env._get_image_obs(), config['camera_names'])
-        if config['policy_class'] == "ACT" or "ACTTV":
+        if config['policy_class'] in ["ACT", "ACTTV", "GR00T"]:
             if t % query_frequency == 0:
                 all_actions = policy(qpos, curr_image)
             raw_action = all_actions[:, t % query_frequency]
-            # raw_action = all_actions[:, t % 1]
             raw_action = raw_action.squeeze(0).cpu().numpy()
         elif config['policy_class'] == "CNNMLP":
             raw_action = policy(qpos, curr_image)
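For context, the changed branch implements chunked open-loop execution: the policy is queried once every query_frequency steps and the returned chunk is consumed one action at a time until the next query. A self-contained sketch of that pattern with dummy stand-ins for the environment and policy (names are illustrative, not from the repo):

import torch

query_frequency = 8      # typically the policy's chunk size (num_queries)
episode_len = 32
action_dim = 16

def policy(qpos, curr_image):                         # stand-in for the real policy call
    return torch.randn(1, query_frequency, action_dim)

all_actions = None
for t in range(episode_len):
    qpos = torch.randn(1, 16)                         # stand-in observation
    curr_image = torch.rand(1, 2, 3, 224, 308)
    if t % query_frequency == 0:
        all_actions = policy(qpos, curr_image)        # query once per chunk
    raw_action = all_actions[:, t % query_frequency]  # step through the current chunk
    raw_action = raw_action.squeeze(0).numpy()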

roboimi/gr00t/main.py (new file, 125 lines)
View File

@@ -0,0 +1,125 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
GR00T (diffusion-based DiT policy) model builder.
This module provides functions to build GR00T models and optimizers
from configuration dictionaries (typically from config.yaml's 'gr00t:' section).
"""
import argparse
from pathlib import Path
import numpy as np
import torch
from .models import build_gr00t_model
def get_args_parser():
"""
Create argument parser for GR00T model configuration.
All parameters can be overridden via args_override dictionary in
build_gr00t_model_and_optimizer(). This allows loading from config.yaml.
"""
parser = argparse.ArgumentParser('GR00T training and evaluation script', add_help=False)
# Training parameters
parser.add_argument('--lr', default=1e-5, type=float,
help='Learning rate for main parameters')
parser.add_argument('--lr_backbone', default=1e-5, type=float,
help='Learning rate for backbone parameters')
parser.add_argument('--weight_decay', default=1e-4, type=float,
help='Weight decay for optimizer')
# GR00T model architecture parameters
parser.add_argument('--embed_dim', default=1536, type=int,
help='Embedding dimension for transformer')
parser.add_argument('--hidden_dim', default=1024, type=int,
help='Hidden dimension for MLP layers')
parser.add_argument('--state_dim', default=16, type=int,
help='State (qpos) dimension')
parser.add_argument('--action_dim', default=16, type=int,
help='Action dimension')
parser.add_argument('--num_queries', default=16, type=int,
help='Number of action queries (chunk size)')
# DiT (Diffusion Transformer) parameters
parser.add_argument('--num_layers', default=16, type=int,
help='Number of transformer layers')
parser.add_argument('--nheads', default=32, type=int,
help='Number of attention heads')
parser.add_argument('--mlp_ratio', default=4, type=float,
help='MLP hidden dimension ratio')
parser.add_argument('--dropout', default=0.2, type=float,
help='Dropout rate')
# Backbone parameters
parser.add_argument('--backbone', default='dino_v2', type=str,
help='Backbone architecture (dino_v2, resnet18, resnet34)')
parser.add_argument('--position_embedding', default='sine', type=str,
choices=('sine', 'learned'),
help='Type of positional encoding')
# Camera configuration
parser.add_argument('--camera_names', default=[], nargs='+',
help='List of camera names for observations')
# Other parameters (not directly used but kept for compatibility)
parser.add_argument('--batch_size', default=15, type=int)
parser.add_argument('--epochs', default=20000, type=int)
parser.add_argument('--masks', action='store_true',
help='Use intermediate layer features')
parser.add_argument('--dilation', action='store_false',
help='Use dilated convolution in backbone')
return parser
def build_gr00t_model_and_optimizer(args_override):
"""
Build GR00T model and optimizer from config dictionary.
This function is designed to work with config.yaml loading:
1. Parse default arguments
2. Override with values from args_override (typically from config['gr00t'])
3. Build model and optimizer
Args:
args_override: Dictionary of config values, typically from config.yaml's 'gr00t:' section
Expected keys: embed_dim, hidden_dim, state_dim, action_dim,
num_queries, nheads, mlp_ratio, dropout, num_layers,
lr, lr_backbone, camera_names, backbone, etc.
Returns:
model: GR00T model on CUDA
optimizer: AdamW optimizer with separate learning rates for backbone and other params
"""
parser = argparse.ArgumentParser('GR00T training and evaluation script',
parents=[get_args_parser()])
args = parser.parse_args()
# Override with config values
for k, v in args_override.items():
setattr(args, k, v)
# Build model
model = build_gr00t_model(args)
model.cuda()
# Create parameter groups with different learning rates
param_dicts = [
{
"params": [p for n, p in model.named_parameters()
if "backbone" not in n and p.requires_grad]
},
{
"params": [p for n, p in model.named_parameters()
if "backbone" in n and p.requires_grad],
"lr": args.lr_backbone,
},
]
optimizer = torch.optim.AdamW(param_dicts,
lr=args.lr,
weight_decay=args.weight_decay)
return model, optimizer
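A minimal usage sketch of build_gr00t_model_and_optimizer, assuming the dictionary mirrors config.yaml's gr00t: section plus the camera list; it needs CUDA and downloads the DINOv2 hub weights, so treat it as illustrative:

from roboimi.gr00t.main import build_gr00t_model_and_optimizer

gr00t_cfg = {
    "embed_dim": 1536, "hidden_dim": 1024,
    "state_dim": 16, "action_dim": 16, "num_queries": 8,
    "num_layers": 16, "nheads": 32, "mlp_ratio": 4, "dropout": 0.2,
    "lr": 1e-5, "lr_backbone": 1e-5,
    "backbone": "dino_v2",
    "camera_names": ["top", "r_vis"],
}
model, optimizer = build_gr00t_model_and_optimizer(gr00t_cfg)
print(sum(p.numel() for p in model.parameters() if p.requires_grad))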

View File

@@ -0,0 +1,3 @@
from .gr00t import build_gr00t_model
__all__ = ['build_gr00t_model']

View File

@@ -0,0 +1,168 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Backbone modules.
"""
from collections import OrderedDict
import torch
import torch.nn.functional as F
import torchvision
from torch import nn
from torchvision.models._utils import IntermediateLayerGetter
from typing import Dict, List
from util.misc import NestedTensor, is_main_process
from .position_encoding import build_position_encoding
class FrozenBatchNorm2d(torch.nn.Module):
"""
BatchNorm2d where the batch statistics and the affine parameters are fixed.
Copy-paste from torchvision.misc.ops with added eps before rsqrt,
without which any backbone other than torchvision.models.resnet[18,34,50,101]
produces NaNs.
"""
def __init__(self, n):
super(FrozenBatchNorm2d, self).__init__()
self.register_buffer("weight", torch.ones(n))
self.register_buffer("bias", torch.zeros(n))
self.register_buffer("running_mean", torch.zeros(n))
self.register_buffer("running_var", torch.ones(n))
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
num_batches_tracked_key = prefix + 'num_batches_tracked'
if num_batches_tracked_key in state_dict:
del state_dict[num_batches_tracked_key]
super(FrozenBatchNorm2d, self)._load_from_state_dict(
state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs)
def forward(self, x):
# move reshapes to the beginning
# to make it fuser-friendly
w = self.weight.reshape(1, -1, 1, 1)
b = self.bias.reshape(1, -1, 1, 1)
rv = self.running_var.reshape(1, -1, 1, 1)
rm = self.running_mean.reshape(1, -1, 1, 1)
eps = 1e-5
scale = w * (rv + eps).rsqrt()
bias = b - rm * scale
return x * scale + bias
class BackboneBase(nn.Module):
def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool):
super().__init__()
# for name, parameter in backbone.named_parameters(): # only train later layers # TODO do we want this?
# if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
# parameter.requires_grad_(False)
if return_interm_layers:
return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
else:
return_layers = {'layer4': "0"}
self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
self.num_channels = num_channels
def forward(self, tensor):
xs = self.body(tensor)
return xs
# out: Dict[str, NestedTensor] = {}
# for name, x in xs.items():
# m = tensor_list.mask
# assert m is not None
# mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
# out[name] = NestedTensor(x, mask)
# return out
class Backbone(BackboneBase):
"""ResNet backbone with frozen BatchNorm."""
def __init__(self, name: str,
train_backbone: bool,
return_interm_layers: bool,
dilation: bool):
backbone = getattr(torchvision.models, name)(
replace_stride_with_dilation=[False, False, dilation],
pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d) # pretrained # TODO do we want frozen batch_norm??
num_channels = 512 if name in ('resnet18', 'resnet34') else 2048
super().__init__(backbone, train_backbone, num_channels, return_interm_layers)
# class DINOv2BackBone(nn.Module):
# def __init__(self) -> None:
# super().__init__()
# self.body = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
# self.body.eval()
# self.num_channels = 384
# @torch.no_grad()
# def forward(self, tensor):
# xs = self.body.forward_features(tensor)["x_norm_patchtokens"]
# od = OrderedDict()
# od["0"] = xs.reshape(xs.shape[0], 22, 16, 384).permute(0, 3, 2, 1)
# return od
class DINOv2BackBone(nn.Module):
def __init__(self, return_interm_layers: bool = False) -> None:
super().__init__()
self.body = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
self.body.eval()
self.num_channels = 384
self.return_interm_layers = return_interm_layers
@torch.no_grad()
def forward(self, tensor):
features = self.body.forward_features(tensor)
if self.return_interm_layers:
layer1 = features["x_norm_patchtokens"]
layer2 = features["x_norm_patchtokens"]
layer3 = features["x_norm_patchtokens"]
layer4 = features["x_norm_patchtokens"]
od = OrderedDict()
od["0"] = layer1.reshape(layer1.shape[0], 22, 16, 384).permute(0, 3, 2, 1)
od["1"] = layer2.reshape(layer2.shape[0], 22, 16, 384).permute(0, 3, 2, 1)
od["2"] = layer3.reshape(layer3.shape[0], 22, 16, 384).permute(0, 3, 2, 1)
od["3"] = layer4.reshape(layer4.shape[0], 22, 16, 384).permute(0, 3, 2, 1)
return od
else:
xs = features["x_norm_patchtokens"]
od = OrderedDict()
od["0"] = xs.reshape(xs.shape[0], 22, 16, 384).permute(0, 3, 2, 1)
return od
class Joiner(nn.Sequential):
def __init__(self, backbone, position_embedding):
super().__init__(backbone, position_embedding)
def forward(self, tensor_list: NestedTensor):
xs = self[0](tensor_list)
out: List[NestedTensor] = []
pos = []
for name, x in xs.items():
out.append(x)
# position encoding
pos.append(self[1](x).to(x.dtype))
return out, pos
def build_backbone(args):
position_embedding = build_position_encoding(args)
train_backbone = args.lr_backbone > 0
return_interm_layers = args.masks
if args.backbone == 'dino_v2':
backbone = DINOv2BackBone()
else:
assert args.backbone in ['resnet18', 'resnet34']
backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation)
model = Joiner(backbone, position_embedding)
model.num_channels = backbone.num_channels
return model
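A quick shape check for the DINOv2 path, assuming 224x308 inputs (a 16x22 grid of 14-pixel patches) and that the torch.hub weights can be downloaded; the import path follows this diff:

import argparse
import torch
from roboimi.gr00t.models.backbone import build_backbone

args = argparse.Namespace(
    backbone="dino_v2", position_embedding="sine",
    hidden_dim=1024, lr_backbone=1e-5, masks=False, dilation=False,
)
model = build_backbone(args)
feats, pos = model(torch.randn(2, 3, 224, 308))  # (H, W) = (16*14, 22*14)
print(feats[0].shape)  # torch.Size([2, 384, 16, 22]) - patch tokens as a feature map
print(pos[0].shape)    # torch.Size([1, 1024, 16, 22]) - sine encoding (unused by gr00t.forward)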

roboimi/gr00t/models/dit.py (new file, 142 lines)
View File

@@ -0,0 +1,142 @@
from typing import Optional
from diffusers import ConfigMixin, ModelMixin
from diffusers.configuration_utils import register_to_config
from diffusers.models.embeddings import SinusoidalPositionalEmbedding, TimestepEmbedding, Timesteps
import torch
from torch import nn
import torch.nn.functional as F
class TimestepEncoder(nn.Module):
def __init__(self, args):
super().__init__()
embedding_dim = args.embed_dim
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
def forward(self, timesteps):
dtype = next(self.parameters()).dtype
timesteps_proj = self.time_proj(timesteps).to(dtype)
timesteps_emb = self.timestep_embedder(timesteps_proj) # (N, D)
return timesteps_emb
class AdaLayerNorm(nn.Module):
def __init__(self, embedding_dim, norm_eps=1e-5, norm_elementwise_affine=False):
super().__init__()
output_dim = embedding_dim * 2
self.silu = nn.SiLU()
self.linear = nn.Linear(embedding_dim, output_dim)
self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine)
def forward(
self,
x: torch.Tensor,
temb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
temb = self.linear(self.silu(temb))
scale, shift = temb.chunk(2, dim=1)
x = self.norm(x) * (1 + scale[:, None]) + shift[:, None]
return x
class BasicTransformerBlock(nn.Module):
def __init__(self, args, cross_attention_dim, use_self_attn=False):
super().__init__()
dim = args.embed_dim
num_heads = args.nheads
mlp_ratio = args.mlp_ratio
dropout = args.dropout
self.norm1 = AdaLayerNorm(dim)
if not use_self_attn:
self.attn = nn.MultiheadAttention(
embed_dim=dim,
num_heads=num_heads,
dropout=dropout,
kdim=cross_attention_dim,
vdim=cross_attention_dim,
batch_first=True,
)
else:
self.attn = nn.MultiheadAttention(
embed_dim=dim,
num_heads=num_heads,
dropout=dropout,
batch_first=True,
)
self.norm2 = nn.LayerNorm(dim, eps=1e-5, elementwise_affine=False)
self.mlp = nn.Sequential(
nn.Linear(dim, int(dim * mlp_ratio)),  # cast: mlp_ratio is parsed as a float
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(int(dim * mlp_ratio), dim),
nn.Dropout(dropout)
)
def forward(self, hidden_states, temb, context=None):
norm_hidden_states = self.norm1(hidden_states, temb)
attn_output = self.attn(
norm_hidden_states,
context if context is not None else norm_hidden_states,
context if context is not None else norm_hidden_states,
)[0]
hidden_states = attn_output + hidden_states
norm_hidden_states = self.norm2(hidden_states)
ff_output = self.mlp(norm_hidden_states)
hidden_states = ff_output + hidden_states
return hidden_states
class DiT(nn.Module):
def __init__(self, args, cross_attention_dim):
super().__init__()
inner_dim = args.embed_dim
num_layers = args.num_layers
output_dim = args.hidden_dim
self.timestep_encoder = TimestepEncoder(args)
all_blocks = []
for idx in range(num_layers):
use_self_attn = idx % 2 == 1
if use_self_attn:
block = BasicTransformerBlock(args, cross_attention_dim=None, use_self_attn=True)
else:
block = BasicTransformerBlock(args, cross_attention_dim=cross_attention_dim, use_self_attn=False)
all_blocks.append(block)
self.transformer_blocks = nn.ModuleList(all_blocks)
self.norm_out = nn.LayerNorm(inner_dim, eps=1e-6, elementwise_affine=False)
self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
self.proj_out_2 = nn.Linear(inner_dim, output_dim)
def forward(self, hidden_states, timestep, encoder_hidden_states):
temb = self.timestep_encoder(timestep)
hidden_states = hidden_states.contiguous()
encoder_hidden_states = encoder_hidden_states.contiguous()
for idx, block in enumerate(self.transformer_blocks):
if idx % 2 == 1:
hidden_states = block(hidden_states, temb)
else:
hidden_states = block(hidden_states, temb, context=encoder_hidden_states)
conditioning = temb
shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
return self.proj_out_2(hidden_states)
def build_dit(args, cross_attention_dim):
return DiT(args, cross_attention_dim)
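A shape-level sketch of the DiT forward pass with reduced dimensions for a quick check (argparse.Namespace stands in for the config; cross_attention_dim matches the DINOv2 channel count):

import argparse
import torch
from roboimi.gr00t.models.dit import build_dit  # path per this diff; requires diffusers

args = argparse.Namespace(embed_dim=64, hidden_dim=32, nheads=4,
                          mlp_ratio=4, dropout=0.0, num_layers=4)
dit = build_dit(args, cross_attention_dim=384)

B, T, S = 2, 9, 704                      # 9 = 1 state token + 8 action tokens; 704 = 2 cameras x 352 patches
tokens = torch.randn(B, T, 64)           # state/action embeddings
timestep = torch.randint(0, 1000, (B,))  # discretized flow-matching time
context = torch.randn(B, S, 384)         # image patch features
out = dit(tokens, timestep, context)
print(out.shape)                         # torch.Size([2, 9, 32]) - projected to hidden_dim for the decoder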

View File

@@ -0,0 +1,124 @@
from .modules import (
build_action_decoder,
build_action_encoder,
build_state_encoder,
build_time_sampler,
build_noise_scheduler,
)
from .backbone import build_backbone
from .dit import build_dit
import torch
import torch.nn as nn
import torch.nn.functional as F
class gr00t(nn.Module):
def __init__(
self,
backbones,
dit,
state_encoder,
action_encoder,
action_decoder,
time_sampler,
noise_scheduler,
num_queries,
camera_names,
):
super().__init__()
self.num_queries = num_queries
self.camera_names = camera_names
self.dit = dit
self.state_encoder = state_encoder
self.action_encoder = action_encoder
self.action_decoder = action_decoder
self.time_sampler = time_sampler
self.noise_scheduler = noise_scheduler
if backbones is not None:
self.backbones = nn.ModuleList(backbones)
else:
raise NotImplementedError
def forward(self, qpos, image, actions=None, is_pad=None):
is_training = actions is not None # train or val
bs, _ = qpos.shape
all_cam_features = []
for cam_id, cam_name in enumerate(self.camera_names):
# features, pos = self.backbones[0](image[:, cam_id]) # HARDCODED
features, pos = self.backbones[cam_id](image[:, cam_id])
features = features[0] # take the last layer feature
B, C, H, W = features.shape
features_seq = features.permute(0, 2, 3, 1).reshape(B, H * W, C)
all_cam_features.append(features_seq)
encoder_hidden_states = torch.cat(all_cam_features, dim=1)
state_features = self.state_encoder(qpos) # [B, 1, emb_dim]
if is_training:
# training logic
timesteps = self.time_sampler(bs, actions.device, actions.dtype)
noisy_actions, target_velocity = self.noise_scheduler.add_noise(
actions, timesteps
)
t_discretized = (timesteps[:, 0, 0] * 1000).long()
action_features = self.action_encoder(noisy_actions, t_discretized)
sa_embs = torch.cat((state_features, action_features), dim=1)
model_output = self.dit(sa_embs, t_discretized, encoder_hidden_states)
pred = self.action_decoder(model_output)
pred_actions = pred[:, -actions.shape[1] :]
action_loss = F.mse_loss(pred_actions, target_velocity, reduction='none')
return pred_actions, action_loss
else:
actions = torch.randn(bs, self.num_queries, qpos.shape[-1], device=qpos.device, dtype=qpos.dtype)
k = 5
dt = 1.0 / k
for t in range(k):
t_cont = t / float(k)
t_discretized = int(t_cont * 1000)
timesteps = torch.full((bs,), t_discretized, device=qpos.device, dtype=qpos.dtype)
action_features = self.action_encoder(actions, timesteps)
sa_embs = torch.cat((state_features, action_features), dim=1)
# Create tensor of shape [B] for DiT (consistent with training path)
model_output = self.dit(sa_embs, timesteps, encoder_hidden_states)
pred = self.action_decoder(model_output)
pred_velocity = pred[:, -self.num_queries :]
actions = actions + pred_velocity * dt
return actions, None  # no loss at inference time
def build_gr00t_model(args):
state_dim = args.state_dim
action_dim = args.action_dim
backbones = []
for _ in args.camera_names:
backbone = build_backbone(args)
backbones.append(backbone)
cross_attention_dim = backbones[0].num_channels
dit = build_dit(args, cross_attention_dim)
state_encoder = build_state_encoder(args)
action_encoder = build_action_encoder(args)
action_decoder = build_action_decoder(args)
time_sampler = build_time_sampler(args)
noise_scheduler = build_noise_scheduler(args)
model = gr00t(
backbones,
dit,
state_encoder,
action_encoder,
action_decoder,
time_sampler,
noise_scheduler,
args.num_queries,
args.camera_names,
)
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("number of parameters: %.2fM" % (n_parameters/1e6,))
return model
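The inference branch above integrates the predicted velocity field with k = 5 fixed Euler steps. The same integration, written standalone with a stand-in velocity function f (illustrative, not the repo's API):

import torch

def euler_sample(f, shape, k=5):
    x = torch.randn(shape)       # start from pure noise (t = 0 in this scheduler's convention)
    dt = 1.0 / k
    for i in range(k):
        t = i / k                # current flow time in [0, 1)
        x = x + f(x, t) * dt     # one Euler step toward the data sample at t = 1
    return x

# With the exact single-target velocity field (target - x) / (1 - t), the steps follow the straight
# interpolation path and land on the target after k steps.
target = torch.randn(2, 8, 16)
sample = euler_sample(lambda x, t: (target - x) / (1 - t), target.shape)
print(torch.allclose(sample, target, atol=1e-5))  # True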

View File

@@ -0,0 +1,179 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
# ActionEncoder
class SinusoidalPositionalEncoding(nn.Module):
def __init__(self, args):
super().__init__()
self.embed_dim = args.embed_dim
def forward(self, timesteps):
timesteps = timesteps.float()
B, T = timesteps.shape
device = timesteps.device
half_dim = self.embed_dim // 2
exponent = -torch.arange(half_dim, dtype=torch.float, device=device) * (
torch.log(torch.tensor(10000.0)) / half_dim
)
freqs = timesteps.unsqueeze(-1) * exponent.exp()
sin = torch.sin(freqs)
cos = torch.cos(freqs)
enc = torch.cat([sin, cos], dim=-1) # (B, T, w)
return enc
class ActionEncoder(nn.Module):
def __init__(self, args):
super().__init__()
action_dim = args.action_dim
embed_dim = args.embed_dim
self.W1 = nn.Linear(action_dim, embed_dim)
self.W2 = nn.Linear(2 * embed_dim, embed_dim)
self.W3 = nn.Linear(embed_dim, embed_dim)
self.pos_encoder = SinusoidalPositionalEncoding(args)
def forward(self, actions, timesteps):
B, T, _ = actions.shape
# 1) Expand each batch's single scalar time 'tau' across all T steps
# so that shape => (B, T)
# Handle different input shapes: (B,), (B, 1), (B, 1, 1)
# Reshape to (B,) then expand to (B, T)
# if timesteps.dim() == 3:
# # Shape (B, 1, 1) or (B, T, 1) -> (B,)
# timesteps = timesteps[:, 0, 0]
# elif timesteps.dim() == 2:
# # Shape (B, 1) or (B, T) -> take first element if needed
# if timesteps.shape[1] == 1:
# timesteps = timesteps[:, 0]
# # else: already (B, T), use as is
# elif timesteps.dim() != 1:
# raise ValueError(
# f"Expected `timesteps` to have shape (B,), (B, 1), or (B, 1, 1), got {timesteps.shape}"
# )
# Now timesteps should be (B,), expand to (B, T)
if timesteps.dim() == 1 and timesteps.shape[0] == B:
timesteps = timesteps.unsqueeze(1).expand(-1, T)
else:
raise ValueError(
"Expected `timesteps` to have shape (B,) so we can replicate across T."
)
# 2) Standard action MLP step for shape => (B, T, w)
a_emb = self.W1(actions)
# 3) Get the sinusoidal encoding (B, T, w)
tau_emb = self.pos_encoder(timesteps).to(dtype=a_emb.dtype)
# 4) Concat along last dim => (B, T, 2w), then W2 => (B, T, w), swish
x = torch.cat([a_emb, tau_emb], dim=-1)
x = F.silu(self.W2(x))
# 5) Finally W3 => (B, T, w)
x = self.W3(x)
return x
def build_action_encoder(args):
return ActionEncoder(args)
# StateEncoder
class StateEncoder(nn.Module):
def __init__(self, args):
super().__init__()
input_dim = args.state_dim
hidden_dim = args.hidden_dim
output_dim = args.embed_dim
self.mlp = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, output_dim),
)
def forward(self, states):
state_emb = self.mlp(states) # [B, emb_dim]
state_emb = state_emb.unsqueeze(1)
return state_emb # [B, 1, emb_dim]
def build_state_encoder(args):
return StateEncoder(args)
# ActionDecoder
class ActionDecoder(nn.Module):
def __init__(self,args):
super().__init__()
input_dim = args.hidden_dim
hidden_dim = args.hidden_dim
output_dim = args.action_dim
self.num_queries = args.num_queries
self.mlp = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, output_dim),
)
def forward(self, model_output):
pred_actions = self.mlp(model_output)
return pred_actions[:, -self.num_queries:]
def build_action_decoder(args):
return ActionDecoder(args)
# TimeSampler
class TimeSampler(nn.Module):
def __init__(self, noise_s = 0.999, noise_beta_alpha=1.5, noise_beta_beta=1.0):
super().__init__()
self.noise_s = noise_s
self.beta_dist = torch.distributions.Beta(noise_beta_alpha, noise_beta_beta)
def forward(self, batch_size, device, dtype):
sample = self.beta_dist.sample([batch_size]).to(device, dtype=dtype)
sample = (1 - sample) * self.noise_s
return sample[:, None, None]
def build_time_sampler(args):
return TimeSampler()
# NoiseScheduler
class FlowMatchingScheduler(nn.Module):
def __init__(self):
super().__init__()
# --- Training logic: add noise and compute the regression target ---
def add_noise(self, actions, timesteps):
noise = torch.randn_like(actions)
noisy_samples = actions * timesteps + noise * (1 - timesteps)
target_velocity = actions - noise
return noisy_samples, target_velocity
# --- Inference logic: Euler step ---
def step(self, model_output, sample, dt):
prev_sample = sample + model_output * dt
return prev_sample
def build_noise_scheduler(args):
return FlowMatchingScheduler()
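A small numeric check of the flow-matching parameterization used here: the noisy sample interpolates linearly between noise (t = 0) and the clean action (t = 1), and the regression target is the constant velocity actions - noise, so integrating the exact velocity from t to 1 recovers the clean action. This restates the scheduler's formula inline rather than importing it:

import torch

actions = torch.randn(2, 8, 16)            # clean action chunk
t = torch.full((2, 1, 1), 0.3)             # flow time, broadcast like TimeSampler's output
noise = torch.randn_like(actions)

noisy = actions * t + noise * (1 - t)      # FlowMatchingScheduler.add_noise
velocity = actions - noise                 # regression target

recovered = noisy + velocity * (1 - t)     # exact integration from t to 1
print(torch.allclose(recovered, actions))  # True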

View File

@@ -0,0 +1,91 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Various positional encodings for the transformer.
"""
import math
import torch
from torch import nn
from util.misc import NestedTensor
class PositionEmbeddingSine(nn.Module):
"""
This is a more standard version of the position embedding, very similar to the one
used by the Attention is all you need paper, generalized to work on images.
"""
def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
super().__init__()
self.num_pos_feats = num_pos_feats
self.temperature = temperature
self.normalize = normalize
if scale is not None and normalize is False:
raise ValueError("normalize should be True if scale is passed")
if scale is None:
scale = 2 * math.pi
self.scale = scale
def forward(self, tensor):
x = tensor
# mask = tensor_list.mask
# assert mask is not None
# not_mask = ~mask
not_mask = torch.ones_like(x[0, [0]])
y_embed = not_mask.cumsum(1, dtype=torch.float32)
x_embed = not_mask.cumsum(2, dtype=torch.float32)
if self.normalize:
eps = 1e-6
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
pos_x = x_embed[:, :, :, None] / dim_t
pos_y = y_embed[:, :, :, None] / dim_t
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
return pos
class PositionEmbeddingLearned(nn.Module):
"""
Absolute pos embedding, learned.
"""
def __init__(self, num_pos_feats=256):
super().__init__()
self.row_embed = nn.Embedding(50, num_pos_feats)
self.col_embed = nn.Embedding(50, num_pos_feats)
self.reset_parameters()
def reset_parameters(self):
nn.init.uniform_(self.row_embed.weight)
nn.init.uniform_(self.col_embed.weight)
def forward(self, tensor_list: NestedTensor):
x = tensor_list.tensors
h, w = x.shape[-2:]
i = torch.arange(w, device=x.device)
j = torch.arange(h, device=x.device)
x_emb = self.col_embed(i)
y_emb = self.row_embed(j)
pos = torch.cat([
x_emb.unsqueeze(0).repeat(h, 1, 1),
y_emb.unsqueeze(1).repeat(1, w, 1),
], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
return pos
def build_position_encoding(args):
N_steps = args.hidden_dim // 2
if args.position_embedding in ('v2', 'sine'):
# TODO find a better way of exposing other arguments
position_embedding = PositionEmbeddingSine(N_steps, normalize=True)
elif args.position_embedding in ('v3', 'learned'):
position_embedding = PositionEmbeddingLearned(N_steps)
else:
raise ValueError(f"not supported {args.position_embedding}")
return position_embedding

roboimi/gr00t/policy.py (new file, 90 lines)
View File

@@ -0,0 +1,90 @@
"""
GR00T Policy wrapper for imitation learning.
This module provides the gr00tPolicy class that wraps the GR00T model
for training and evaluation in the imitation learning framework.
"""
import torch.nn as nn
from torch.nn import functional as F
from torchvision.transforms import v2
import torch
from roboimi.gr00t.main import build_gr00t_model_and_optimizer
class gr00tPolicy(nn.Module):
"""
GR00T Policy for action prediction using diffusion-based DiT architecture.
This policy wraps the GR00T model and handles:
- Image resizing to match DINOv2 patch size requirements
- Image normalization (ImageNet stats)
- Training with action chunks and loss computation
- Inference with diffusion sampling
"""
def __init__(self, args_override):
super().__init__()
model, optimizer = build_gr00t_model_and_optimizer(args_override)
self.model = model
self.optimizer = optimizer
# DINOv2 requires image dimensions to be multiples of patch size (14)
# Common sizes: 224x224, 336x336, etc. (14*16=224, 14*24=336)
self.patch_h = 16 # Number of patches vertically
self.patch_w = 22 # Number of patches horizontally
target_size = (self.patch_h * 14, self.patch_w * 14) # (224, 308)
# Training transform with data augmentation
self.train_transform = v2.Compose([
v2.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),
v2.RandomPerspective(distortion_scale=0.5),
v2.RandomAffine(degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1)),
v2.GaussianBlur(kernel_size=(9, 9), sigma=(0.1, 2.0)),
v2.Resize(target_size),
v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])
# Inference transform (no augmentation)
self.inference_transform = v2.Compose([
v2.Resize(target_size),
v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])
def __call__(self, qpos, image, actions=None, is_pad=None):
"""
Forward pass for training or inference.
Args:
qpos: Joint positions [B, state_dim]
image: Camera images [B, num_cameras, C, H, W]
actions: Ground truth actions [B, chunk_size, action_dim] (training only)
is_pad: Padding mask [B, chunk_size] (training only)
Returns:
Training: dict with 'loss' (masked MSE)
Inference: predicted actions [B, num_queries, action_dim]
"""
# Apply transforms (resize + normalization)
if actions is not None: # training time
image = self.train_transform(image)
else: # inference time
image = self.inference_transform(image)
if actions is not None: # training time
actions = actions[:, :self.model.num_queries]
is_pad = is_pad[:, :self.model.num_queries]
_, action_loss = self.model(qpos, image, actions, is_pad)
# Mask out padded positions
mse_loss = (action_loss * ~is_pad.unsqueeze(-1)).mean()
loss_dict = {
'loss': mse_loss
}
return loss_dict
else: # inference time
a_hat, _ = self.model(qpos, image)
return a_hat
def configure_optimizers(self):
"""Return the optimizer for training."""
return self.optimizer
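A usage sketch for gr00tPolicy with the shapes its docstring expects (two cameras, chunk of 8). Constructing it downloads the DINOv2 weights and moves the model to CUDA, so this is illustrative only; the raw image resolution before the internal resize is an assumption:

import torch
from roboimi.gr00t.policy import gr00tPolicy

policy_config = {  # typically produced by ModelInterface from config.yaml
    "embed_dim": 1536, "hidden_dim": 1024, "state_dim": 16, "action_dim": 16,
    "num_queries": 8, "num_layers": 16, "nheads": 32, "mlp_ratio": 4, "dropout": 0.2,
    "lr": 1e-5, "lr_backbone": 1e-5, "backbone": "dino_v2", "camera_names": ["top", "r_vis"],
}
policy = gr00tPolicy(policy_config)              # the builder moves the model to CUDA

qpos = torch.randn(4, 16).cuda()                 # [B, state_dim]
image = torch.rand(4, 2, 3, 480, 640).cuda()     # [B, num_cameras, C, H, W]; resized to 224x308 inside
actions = torch.randn(4, 8, 16).cuda()           # [B, chunk, action_dim]
is_pad = torch.zeros(4, 8, dtype=torch.bool).cuda()

loss_dict = policy(qpos, image, actions, is_pad) # training: {'loss': masked MSE}
a_hat = policy(qpos, image)                      # inference: [B, num_queries, action_dim]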

View File

@@ -20,7 +20,7 @@ SIM_TASK_CONFIGS = {
         'dataset_dir': DATASET_DIR + '/sim_transfer',
         'num_episodes': 7,
         'episode_len': 700,
-        'camera_names': ['angle','r_vis'],
+        'camera_names': ['top','r_vis'],
         'xml_dir': HOME_PATH + '/assets'
     },

View File

@@ -1,7 +1,8 @@
 import os
 import torch
 from roboimi.utils.utils import load_data, set_seed
-from roboimi.detr.policy import ACTPolicy, CNNMLPPolicy,ACTTVPolicy
+from roboimi.detr.policy import ACTPolicy, CNNMLPPolicy, ACTTVPolicy
+from roboimi.gr00t.policy import gr00tPolicy

 class ModelInterface:
     def __init__(self, config):
@@ -59,12 +60,32 @@ class ModelInterface:
             }
         elif self.config['policy_class'] == 'CNNMLP':
             self.config['policy_config'] = {
                 'lr': self.config['lr'],
                 'lr_backbone': self.config['lr_backbone'],
                 'backbone': self.config['backbone'],
                 'num_queries': 1,
                 'camera_names': self.config['camera_names'],
             }
+        elif self.config['policy_class'] == 'GR00T':
+            # GR00T uses its own config section from config.yaml
+            gr00t_config = self.config.get('gr00t', {})
+            self.config['policy_config'] = {
+                'lr': gr00t_config.get('lr', self.config['lr']),
+                'lr_backbone': gr00t_config.get('lr_backbone', self.config['lr_backbone']),
+                'weight_decay': gr00t_config.get('weight_decay', 1e-4),
+                'embed_dim': gr00t_config.get('embed_dim', 1536),
+                'hidden_dim': gr00t_config.get('hidden_dim', 1024),
+                'state_dim': gr00t_config.get('state_dim', 16),
+                'action_dim': gr00t_config.get('action_dim', 16),
+                'num_queries': gr00t_config.get('num_queries', 16),
+                'num_layers': gr00t_config.get('num_layers', 16),
+                'nheads': gr00t_config.get('nheads', 32),
+                'mlp_ratio': gr00t_config.get('mlp_ratio', 4),
+                'dropout': gr00t_config.get('dropout', 0.2),
+                'backbone': gr00t_config.get('backbone', 'dino_v2'),
+                'position_embedding': gr00t_config.get('position_embedding', 'sine'),
+                'camera_names': self.config['camera_names'],
+            }
         else:
             raise NotImplementedError
@@ -75,6 +96,8 @@ class ModelInterface:
             return ACTTVPolicy(self.config['policy_config'])
         elif self.config['policy_class'] == 'CNNMLP':
             return CNNMLPPolicy(self.config['policy_config'])
+        elif self.config['policy_class'] == 'GR00T':
+            return gr00tPolicy(self.config['policy_config'])
         else:
             raise NotImplementedError
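Note that the .get fallbacks above differ from config.yaml's gr00t: block in places (for example num_queries falls back to 16 while the YAML sets 8), so the defaults only matter when a key is missing. A minimal check of that behavior with an illustrative dictionary:

gr00t_config = {"num_queries": 8}           # as in config.yaml
print(gr00t_config.get("num_queries", 16))  # 8  - key present, the YAML value wins
print(gr00t_config.get("nheads", 32))       # 32 - key absent, the fallback is used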