chore: import gr00t

This commit is contained in:
gouhanke
2026-03-06 11:17:28 +08:00
parent 642d41dd8f
commit ca1716c67f
9 changed files with 922 additions and 166 deletions

gr00t/models/__init__.py Normal file

@@ -0,0 +1,3 @@
from .gr00t import build_gr00t_model
__all__ = ['build_gr00t_model']

gr00t/models/backbone.py Normal file

@@ -0,0 +1,168 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Backbone modules.
"""
from collections import OrderedDict
import torch
import torch.nn.functional as F
import torchvision
from torch import nn
from torchvision.models._utils import IntermediateLayerGetter
from typing import Dict, List
from util.misc import NestedTensor, is_main_process
from .position_encoding import build_position_encoding
class FrozenBatchNorm2d(torch.nn.Module):
"""
BatchNorm2d where the batch statistics and the affine parameters are fixed.
Copy-paste from torchvision.misc.ops with added eps before rsqrt,
without which any other models than torchvision.models.resnet[18,34,50,101]
produce nans.
"""
def __init__(self, n):
super(FrozenBatchNorm2d, self).__init__()
self.register_buffer("weight", torch.ones(n))
self.register_buffer("bias", torch.zeros(n))
self.register_buffer("running_mean", torch.zeros(n))
self.register_buffer("running_var", torch.ones(n))
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
num_batches_tracked_key = prefix + 'num_batches_tracked'
if num_batches_tracked_key in state_dict:
del state_dict[num_batches_tracked_key]
super(FrozenBatchNorm2d, self)._load_from_state_dict(
state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs)
def forward(self, x):
# move reshapes to the beginning
# to make it fuser-friendly
w = self.weight.reshape(1, -1, 1, 1)
b = self.bias.reshape(1, -1, 1, 1)
rv = self.running_var.reshape(1, -1, 1, 1)
rm = self.running_mean.reshape(1, -1, 1, 1)
eps = 1e-5
scale = w * (rv + eps).rsqrt()
bias = b - rm * scale
return x * scale + bias
class BackboneBase(nn.Module):
def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool):
super().__init__()
# for name, parameter in backbone.named_parameters(): # only train later layers # TODO do we want this?
# if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
# parameter.requires_grad_(False)
if return_interm_layers:
return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
else:
return_layers = {'layer4': "0"}
self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
self.num_channels = num_channels
def forward(self, tensor):
xs = self.body(tensor)
return xs
# out: Dict[str, NestedTensor] = {}
# for name, x in xs.items():
# m = tensor_list.mask
# assert m is not None
# mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
# out[name] = NestedTensor(x, mask)
# return out
class Backbone(BackboneBase):
"""ResNet backbone with frozen BatchNorm."""
def __init__(self, name: str,
train_backbone: bool,
return_interm_layers: bool,
dilation: bool):
backbone = getattr(torchvision.models, name)(
replace_stride_with_dilation=[False, False, dilation],
pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d) # pretrained # TODO do we want frozen batch_norm??
num_channels = 512 if name in ('resnet18', 'resnet34') else 2048
super().__init__(backbone, train_backbone, num_channels, return_interm_layers)
class DINOv2BackBone(nn.Module):
    def __init__(self, return_interm_layers: bool = False) -> None:
        super().__init__()
        self.body = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
        self.body.eval()
        self.num_channels = 384
        self.return_interm_layers = return_interm_layers
    @torch.no_grad()
    def forward(self, tensor):
        features = self.body.forward_features(tensor)
        xs = features["x_norm_patchtokens"]
        # NOTE: the 22x16 reshape assumes a fixed input resolution whose
        # ViT-S/14 patch grid is 22x16 (e.g. 308x224); other sizes will fail.
        fmap = xs.reshape(xs.shape[0], 22, 16, 384).permute(0, 3, 2, 1)
        od = OrderedDict()
        if self.return_interm_layers:
            # forward_features only exposes the final patch tokens, so all
            # four "intermediate" maps below are the same final feature map.
            for key in ("0", "1", "2", "3"):
                od[key] = fmap
        else:
            od["0"] = fmap
        return od
class Joiner(nn.Sequential):
def __init__(self, backbone, position_embedding):
super().__init__(backbone, position_embedding)
    def forward(self, tensor):
        # BackboneBase and DINOv2BackBone take plain tensors, not NestedTensors
        xs = self[0](tensor)
        out: List[torch.Tensor] = []
pos = []
for name, x in xs.items():
out.append(x)
# position encoding
pos.append(self[1](x).to(x.dtype))
return out, pos
def build_backbone(args):
position_embedding = build_position_encoding(args)
train_backbone = args.lr_backbone > 0
return_interm_layers = args.masks
if args.backbone == 'dino_v2':
backbone = DINOv2BackBone()
else:
assert args.backbone in ['resnet18', 'resnet34']
backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation)
model = Joiner(backbone, position_embedding)
model.num_channels = backbone.num_channels
return model
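A minimal smoke test for build_backbone (a sketch, not part of the commit: the `args` namespace is hypothetical, mirroring the attributes read above, and it assumes `util.misc` is importable and a torchvision version that still accepts the `pretrained=` keyword):

from types import SimpleNamespace
import torch

args = SimpleNamespace(backbone='resnet18', lr_backbone=1e-5, masks=False,
                       dilation=False, position_embedding='sine', hidden_dim=256)
model = build_backbone(args)
feats, pos = model(torch.randn(2, 3, 224, 224))
print(feats[0].shape, pos[0].shape)  # (2, 512, 7, 7) and (1, 256, 7, 7)

Note the positional encoding keeps a batch dimension of 1 and is broadcast by consumers.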

gr00t/models/dit.py Normal file

@@ -0,0 +1,142 @@
from typing import Optional
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
import torch
from torch import nn
import torch.nn.functional as F
class TimestepEncoder(nn.Module):
def __init__(self, args):
super().__init__()
embedding_dim = args.embed_dim
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
def forward(self, timesteps):
dtype = next(self.parameters()).dtype
timesteps_proj = self.time_proj(timesteps).to(dtype)
timesteps_emb = self.timestep_embedder(timesteps_proj) # (N, D)
return timesteps_emb
class AdaLayerNorm(nn.Module):
def __init__(self, embedding_dim, norm_eps=1e-5, norm_elementwise_affine=False):
super().__init__()
output_dim = embedding_dim * 2
self.silu = nn.SiLU()
self.linear = nn.Linear(embedding_dim, output_dim)
self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine)
def forward(
self,
x: torch.Tensor,
temb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
temb = self.linear(self.silu(temb))
scale, shift = temb.chunk(2, dim=1)
x = self.norm(x) * (1 + scale[:, None]) + shift[:, None]
return x
class BasicTransformerBlock(nn.Module):
    def __init__(self, args, cross_attention_dim, use_self_attn=False):
super().__init__()
dim = args.embed_dim
num_heads = args.nheads
mlp_ratio = args.mlp_ratio
dropout = args.dropout
self.norm1 = AdaLayerNorm(dim)
if not use_self_attn:
self.attn = nn.MultiheadAttention(
embed_dim=dim,
num_heads=num_heads,
dropout=dropout,
                kdim=cross_attention_dim,
                vdim=cross_attention_dim,
batch_first=True,
)
else:
self.attn = nn.MultiheadAttention(
embed_dim=dim,
num_heads=num_heads,
dropout=dropout,
batch_first=True,
)
self.norm2 = nn.LayerNorm(dim, eps=1e-5, elementwise_affine=False)
self.mlp = nn.Sequential(
nn.Linear(dim, dim * mlp_ratio),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(dim * mlp_ratio, dim),
nn.Dropout(dropout)
)
def forward(self, hidden_states, temb, context=None):
norm_hidden_states = self.norm1(hidden_states, temb)
attn_output = self.attn(
norm_hidden_states,
context if context is not None else norm_hidden_states,
context if context is not None else norm_hidden_states,
)[0]
hidden_states = attn_output + hidden_states
norm_hidden_states = self.norm2(hidden_states)
ff_output = self.mlp(norm_hidden_states)
hidden_states = ff_output + hidden_states
return hidden_states
class DiT(nn.Module):
def __init__(self, args, cross_attention_dim):
super().__init__()
inner_dim = args.embed_dim
num_layers = args.num_layers
output_dim = args.hidden_dim
self.timestep_encoder = TimestepEncoder(args)
all_blocks = []
for idx in range(num_layers):
use_self_attn = idx % 2 == 1
if use_self_attn:
                block = BasicTransformerBlock(args, cross_attention_dim=None, use_self_attn=True)
else:
                block = BasicTransformerBlock(args, cross_attention_dim=cross_attention_dim, use_self_attn=False)
all_blocks.append(block)
self.transformer_blocks = nn.ModuleList(all_blocks)
self.norm_out = nn.LayerNorm(inner_dim, eps=1e-6, elementwise_affine=False)
self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
self.proj_out_2 = nn.Linear(inner_dim, output_dim)
def forward(self, hidden_states, timestep, encoder_hidden_states):
temb = self.timestep_encoder(timestep)
hidden_states = hidden_states.contiguous()
encoder_hidden_states = encoder_hidden_states.contiguous()
for idx, block in enumerate(self.transformer_blocks):
if idx % 2 == 1:
hidden_states = block(hidden_states, temb)
else:
hidden_states = block(hidden_states, temb, context=encoder_hidden_states)
conditioning = temb
shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
return self.proj_out_2(hidden_states)
def build_dit(args, cross_attention_dim):
return DiT(args, cross_attention_dim)
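A shape-level sketch of the DiT forward pass (hypothetical `args` values mirroring the fields read above; blocks alternate cross-attention at even indices with self-attention at odd indices):

from types import SimpleNamespace
import torch

args = SimpleNamespace(embed_dim=256, nheads=8, mlp_ratio=4, dropout=0.1,
                       num_layers=4, hidden_dim=512)
dit = build_dit(args, cross_attention_dim=384)
tokens = torch.randn(2, 17, 256)      # one state token + 16 action tokens
context = torch.randn(2, 352, 384)    # flattened image patch features
t = torch.randint(0, 1000, (2,))      # discretized timesteps, shape [B]
out = dit(tokens, t, context)
print(out.shape)                      # torch.Size([2, 17, 512]) == hidden_dim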

gr00t/models/gr00t.py Normal file

@@ -0,0 +1,124 @@
from .modules import (
build_action_decoder,
build_action_encoder,
build_state_encoder,
build_time_sampler,
build_noise_scheduler,
)
from .backbone import build_backbone
from .dit import build_dit
import torch
import torch.nn as nn
import torch.nn.functional as F
class gr00t(nn.Module):
def __init__(
self,
backbones,
dit,
state_encoder,
action_encoder,
action_decoder,
time_sampler,
noise_scheduler,
num_queries,
camera_names,
):
super().__init__()
self.num_queries = num_queries
self.camera_names = camera_names
self.dit = dit
self.state_encoder = state_encoder
self.action_encoder = action_encoder
self.action_decoder = action_decoder
self.time_sampler = time_sampler
self.noise_scheduler = noise_scheduler
if backbones is not None:
self.backbones = nn.ModuleList(backbones)
else:
raise NotImplementedError
def forward(self, qpos, image, actions=None, is_pad=None):
is_training = actions is not None # train or val
bs, _ = qpos.shape
all_cam_features = []
for cam_id, cam_name in enumerate(self.camera_names):
# features, pos = self.backbones[0](image[:, cam_id]) # HARDCODED
features, pos = self.backbones[cam_id](image[:, cam_id])
features = features[0] # take the last layer feature
B, C, H, W = features.shape
features_seq = features.permute(0, 2, 3, 1).reshape(B, H * W, C)
all_cam_features.append(features_seq)
encoder_hidden_states = torch.cat(all_cam_features, dim=1)
state_features = self.state_encoder(qpos) # [B, 1, emb_dim]
if is_training:
# training logic
timesteps = self.time_sampler(bs, actions.device, actions.dtype)
noisy_actions, target_velocity = self.noise_scheduler.add_noise(
actions, timesteps
)
t_discretized = (timesteps[:, 0, 0] * 1000).long()
action_features = self.action_encoder(noisy_actions, t_discretized)
sa_embs = torch.cat((state_features, action_features), dim=1)
model_output = self.dit(sa_embs, t_discretized, encoder_hidden_states)
pred = self.action_decoder(model_output)
pred_actions = pred[:, -actions.shape[1] :]
action_loss = F.mse_loss(pred_actions, target_velocity, reduction='none')
return pred_actions, action_loss
        else:
            # Inference: integrate the learned velocity field from Gaussian
            # noise with k Euler steps. NOTE: sampling noise with
            # qpos.shape[-1] assumes action_dim == state_dim.
            actions = torch.randn(bs, self.num_queries, qpos.shape[-1], device=qpos.device, dtype=qpos.dtype)
            k = 5
            dt = 1.0 / k
            for t in range(k):
                t_cont = t / float(k)
                t_discretized = int(t_cont * 1000)
                # Tensor of shape [B], discretized the same way as training
                timesteps = torch.full((bs,), t_discretized, device=qpos.device, dtype=torch.long)
                action_features = self.action_encoder(actions, timesteps)
                sa_embs = torch.cat((state_features, action_features), dim=1)
                model_output = self.dit(sa_embs, timesteps, encoder_hidden_states)
                pred = self.action_decoder(model_output)
                pred_velocity = pred[:, -self.num_queries :]
                actions = actions + pred_velocity * dt
            return actions, None  # no loss at inference time
def build_gr00t_model(args):
state_dim = args.state_dim
action_dim = args.action_dim
backbones = []
for _ in args.camera_names:
backbone = build_backbone(args)
backbones.append(backbone)
cross_attention_dim = backbones[0].num_channels
dit = build_dit(args, cross_attention_dim)
state_encoder = build_state_encoder(args)
action_encoder = build_action_encoder(args)
action_decoder = build_action_decoder(args)
time_sampler = build_time_sampler(args)
noise_scheduler = build_noise_scheduler(args)
model = gr00t(
backbones,
dit,
state_encoder,
action_encoder,
action_decoder,
time_sampler,
noise_scheduler,
args.num_queries,
args.camera_names,
)
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("number of parameters: %.2fM" % (n_parameters/1e6,))
return model
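A hypothetical end-to-end smoke test for the training path (a sketch only: every `args` field is an assumption mirroring what the builders above read, and `state_dim == action_dim` is assumed because the inference path samples noise with `qpos.shape[-1]`):

from types import SimpleNamespace
import torch

args = SimpleNamespace(
    backbone='resnet18', lr_backbone=1e-5, masks=False, dilation=False,
    position_embedding='sine', hidden_dim=512,
    embed_dim=256, nheads=8, mlp_ratio=4, dropout=0.1, num_layers=4,
    state_dim=14, action_dim=14, num_queries=16, camera_names=['top'],
)
model = build_gr00t_model(args)
qpos = torch.randn(2, 14)                    # [B, state_dim]
image = torch.randn(2, 1, 3, 224, 224)       # [B, num_cams, C, H, W]
actions = torch.randn(2, 16, 14)             # [B, num_queries, action_dim]
pred, loss = model(qpos, image, actions)     # training branch
print(pred.shape, loss.shape)                # both torch.Size([2, 16, 14])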

gr00t/models/modules.py Normal file

@@ -0,0 +1,179 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
# ActionEncoder
class SinusoidalPositionalEncoding(nn.Module):
def __init__(self, args):
super().__init__()
self.embed_dim = args.embed_dim
def forward(self, timesteps):
timesteps = timesteps.float()
B, T = timesteps.shape
device = timesteps.device
half_dim = self.embed_dim // 2
exponent = -torch.arange(half_dim, dtype=torch.float, device=device) * (
torch.log(torch.tensor(10000.0)) / half_dim
)
freqs = timesteps.unsqueeze(-1) * exponent.exp()
sin = torch.sin(freqs)
cos = torch.cos(freqs)
enc = torch.cat([sin, cos], dim=-1) # (B, T, w)
return enc
class ActionEncoder(nn.Module):
def __init__(self, args):
super().__init__()
action_dim = args.action_dim
embed_dim = args.embed_dim
self.W1 = nn.Linear(action_dim, embed_dim)
self.W2 = nn.Linear(2 * embed_dim, embed_dim)
self.W3 = nn.Linear(embed_dim, embed_dim)
self.pos_encoder = SinusoidalPositionalEncoding(args)
def forward(self, actions, timesteps):
B, T, _ = actions.shape
        # 1) Expand each batch's single scalar time `tau` across all T steps
        #    so that shape => (B, T). Only shape (B,) is accepted; callers
        #    squeeze/discretize timesteps before calling this encoder.
        if timesteps.dim() == 1 and timesteps.shape[0] == B:
            timesteps = timesteps.unsqueeze(1).expand(-1, T)
        else:
            raise ValueError(
                "Expected `timesteps` to have shape (B,) so we can replicate across T."
            )
# 2) Standard action MLP step for shape => (B, T, w)
a_emb = self.W1(actions)
# 3) Get the sinusoidal encoding (B, T, w)
tau_emb = self.pos_encoder(timesteps).to(dtype=a_emb.dtype)
# 4) Concat along last dim => (B, T, 2w), then W2 => (B, T, w), swish
x = torch.cat([a_emb, tau_emb], dim=-1)
x = F.silu(self.W2(x))
# 5) Finally W3 => (B, T, w)
x = self.W3(x)
return x
def build_action_encoder(args):
return ActionEncoder(args)
# StateEncoder
class StateEncoder(nn.Module):
def __init__(self, args):
super().__init__()
input_dim = args.state_dim
hidden_dim = args.hidden_dim
output_dim = args.embed_dim
self.mlp = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, output_dim),
)
def forward(self, states):
state_emb = self.mlp(states) # [B, emb_dim]
state_emb = state_emb.unsqueeze(1)
return state_emb # [B, 1, emb_dim]
def build_state_encoder(args):
return StateEncoder(args)
# ActionDecoder
class ActionDecoder(nn.Module):
    def __init__(self, args):
super().__init__()
input_dim = args.hidden_dim
hidden_dim = args.hidden_dim
output_dim = args.action_dim
self.num_queries = args.num_queries
self.mlp = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, output_dim),
)
def forward(self, model_output):
pred_actions = self.mlp(model_output)
return pred_actions[:, -self.num_queries:]
def build_action_decoder(args):
return ActionDecoder(args)
# TimeSampler
class TimeSampler(nn.Module):
    def __init__(self, noise_s=0.999, noise_beta_alpha=1.5, noise_beta_beta=1.0):
super().__init__()
self.noise_s = noise_s
self.beta_dist = torch.distributions.Beta(noise_beta_alpha, noise_beta_beta)
def forward(self, batch_size, device, dtype):
sample = self.beta_dist.sample([batch_size]).to(device, dtype=dtype)
sample = (1 - sample) * self.noise_s
return sample[:, None, None]
def build_time_sampler(args):
return TimeSampler()
# NoiseScheduler
class FlowMatchingScheduler(nn.Module):
def __init__(self):
super().__init__()
    # --- Training logic: add noise and compute the velocity target ---
def add_noise(self, actions, timesteps):
noise = torch.randn_like(actions)
noisy_samples = actions * timesteps + noise * (1 - timesteps)
target_velocity = actions - noise
return noisy_samples, target_velocity
    # --- Inference logic: Euler step ---
def step(self, model_output, sample, dt):
prev_sample = sample + model_output * dt
return prev_sample
def build_noise_scheduler(args):
return FlowMatchingScheduler()
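The pair above is rectified-flow-style flow matching: with t in [0, 1], x_t = t*x1 + (1-t)*eps and the regression target is the constant velocity v = x1 - eps, which inference integrates with Euler steps. A self-contained check (shapes are hypothetical) that a perfect velocity prediction recovers the clean actions in one Euler step of size (1 - t):

import torch

scheduler = FlowMatchingScheduler()
sampler = TimeSampler()

actions = torch.randn(4, 16, 14)                   # clean targets x1
t = sampler(4, actions.device, actions.dtype)      # (4, 1, 1), values in (0, 0.999]
x_t, v_target = scheduler.add_noise(actions, t)    # x_t = t*x1 + (1-t)*eps
x1_hat = scheduler.step(v_target, x_t, dt=(1 - t))
assert torch.allclose(x1_hat, actions, atol=1e-5)  # x_t + (1-t)*(x1-eps) == x1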

gr00t/models/position_encoding.py Normal file

@@ -0,0 +1,91 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Various positional encodings for the transformer.
"""
import math
import torch
from torch import nn
from util.misc import NestedTensor
class PositionEmbeddingSine(nn.Module):
"""
This is a more standard version of the position embedding, very similar to the one
used by the Attention is all you need paper, generalized to work on images.
"""
def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
super().__init__()
self.num_pos_feats = num_pos_feats
self.temperature = temperature
self.normalize = normalize
if scale is not None and normalize is False:
raise ValueError("normalize should be True if scale is passed")
if scale is None:
scale = 2 * math.pi
self.scale = scale
    def forward(self, tensor):
        x = tensor
        # The original DETR version read a padding mask from a NestedTensor;
        # this variant assumes dense (unpadded) inputs, hence an all-ones mask.
        not_mask = torch.ones_like(x[0, [0]])
y_embed = not_mask.cumsum(1, dtype=torch.float32)
x_embed = not_mask.cumsum(2, dtype=torch.float32)
if self.normalize:
eps = 1e-6
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats)
pos_x = x_embed[:, :, :, None] / dim_t
pos_y = y_embed[:, :, :, None] / dim_t
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
return pos
class PositionEmbeddingLearned(nn.Module):
"""
Absolute pos embedding, learned.
"""
def __init__(self, num_pos_feats=256):
super().__init__()
self.row_embed = nn.Embedding(50, num_pos_feats)
self.col_embed = nn.Embedding(50, num_pos_feats)
self.reset_parameters()
def reset_parameters(self):
nn.init.uniform_(self.row_embed.weight)
nn.init.uniform_(self.col_embed.weight)
    def forward(self, tensor):
        # accept a plain tensor, matching how Joiner calls position embeddings
        x = tensor
h, w = x.shape[-2:]
i = torch.arange(w, device=x.device)
j = torch.arange(h, device=x.device)
x_emb = self.col_embed(i)
y_emb = self.row_embed(j)
pos = torch.cat([
x_emb.unsqueeze(0).repeat(h, 1, 1),
y_emb.unsqueeze(1).repeat(1, w, 1),
], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
return pos
def build_position_encoding(args):
N_steps = args.hidden_dim // 2
if args.position_embedding in ('v2', 'sine'):
# TODO find a better way of exposing other arguments
position_embedding = PositionEmbeddingSine(N_steps, normalize=True)
elif args.position_embedding in ('v3', 'learned'):
position_embedding = PositionEmbeddingLearned(N_steps)
else:
raise ValueError(f"not supported {args.position_embedding}")
return position_embedding
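A quick shape check for the sine embedding (hypothetical values; note the returned encoding has batch dimension 1 and is broadcast by callers such as Joiner):

import torch

pe = PositionEmbeddingSine(num_pos_feats=128, normalize=True)
feat = torch.randn(2, 512, 7, 7)   # a backbone feature map
pos = pe(feat)
print(pos.shape)                   # torch.Size([1, 256, 7, 7])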