diff --git a/roboimi/demos/config.yaml b/roboimi/demos/config.yaml index 16ab129..3b16eb1 100644 --- a/roboimi/demos/config.yaml +++ b/roboimi/demos/config.yaml @@ -8,7 +8,7 @@ temporal_agg: false # policy_class: "ACT" # backbone: 'resnet18' -policy_class: "ACTTV" +policy_class: "GR00T" backbone: 'dino_v2' seed: 0 @@ -51,6 +51,21 @@ nheads: 8 qpos_noise_std: 0 DT: 0.02 +gr00t: + action_dim: 16 + state_dim: 16 + embed_dim: 1536 + hidden_dim: 1024 + num_queries: 8 + + nheads: 32 + mlp_ratio: 4 + dropout: 0.2 + + num_layers: 16 + + + # DO NOT CHANGE IF UNNECESSARY lr: 0.00001 kl_weight: 100 diff --git a/roboimi/demos/diana_eval.py b/roboimi/demos/diana_eval.py index 1c85258..a5e71e5 100644 --- a/roboimi/demos/diana_eval.py +++ b/roboimi/demos/diana_eval.py @@ -71,11 +71,10 @@ def run_episode(config, policy, stats, save_episode,num_rollouts): qpos = pre_process(qpos_numpy) qpos = torch.from_numpy(qpos).float().cuda().unsqueeze(0) curr_image = get_image(env._get_image_obs(), config['camera_names']) - if config['policy_class'] == "ACT" or "ACTTV": + if config['policy_class'] in ["ACT", "ACTTV", "GR00T"]: if t % query_frequency == 0: all_actions = policy(qpos, curr_image) raw_action = all_actions[:, t % query_frequency] - # raw_action = all_actions[:, t % 1] raw_action = raw_action.squeeze(0).cpu().numpy() elif config['policy_class'] == "CNNMLP": raw_action = policy(qpos, curr_image) diff --git a/roboimi/gr00t/main.py b/roboimi/gr00t/main.py new file mode 100644 index 0000000..c56b359 --- /dev/null +++ b/roboimi/gr00t/main.py @@ -0,0 +1,125 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +GR00T (diffusion-based DiT policy) model builder. + +This module provides functions to build GR00T models and optimizers +from configuration dictionaries (typically from config.yaml's 'gr00t:' section). +""" +import argparse +from pathlib import Path + +import numpy as np +import torch +from .models import build_gr00t_model + + +def get_args_parser(): + """ + Create argument parser for GR00T model configuration. + + All parameters can be overridden via args_override dictionary in + build_gr00t_model_and_optimizer(). This allows loading from config.yaml. 
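+
+    Values supplied via args_override take precedence over the parser defaults below;
+    for example, config.yaml's 'gr00t:' section sets num_queries to 8 while the
+    parser default is 16.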
+ """ + parser = argparse.ArgumentParser('GR00T training and evaluation script', add_help=False) + + # Training parameters + parser.add_argument('--lr', default=1e-5, type=float, + help='Learning rate for main parameters') + parser.add_argument('--lr_backbone', default=1e-5, type=float, + help='Learning rate for backbone parameters') + parser.add_argument('--weight_decay', default=1e-4, type=float, + help='Weight decay for optimizer') + + # GR00T model architecture parameters + parser.add_argument('--embed_dim', default=1536, type=int, + help='Embedding dimension for transformer') + parser.add_argument('--hidden_dim', default=1024, type=int, + help='Hidden dimension for MLP layers') + parser.add_argument('--state_dim', default=16, type=int, + help='State (qpos) dimension') + parser.add_argument('--action_dim', default=16, type=int, + help='Action dimension') + parser.add_argument('--num_queries', default=16, type=int, + help='Number of action queries (chunk size)') + + # DiT (Diffusion Transformer) parameters + parser.add_argument('--num_layers', default=16, type=int, + help='Number of transformer layers') + parser.add_argument('--nheads', default=32, type=int, + help='Number of attention heads') + parser.add_argument('--mlp_ratio', default=4, type=float, + help='MLP hidden dimension ratio') + parser.add_argument('--dropout', default=0.2, type=float, + help='Dropout rate') + + # Backbone parameters + parser.add_argument('--backbone', default='dino_v2', type=str, + help='Backbone architecture (dino_v2, resnet18, resnet34)') + parser.add_argument('--position_embedding', default='sine', type=str, + choices=('sine', 'learned'), + help='Type of positional encoding') + + # Camera configuration + parser.add_argument('--camera_names', default=[], nargs='+', + help='List of camera names for observations') + + # Other parameters (not directly used but kept for compatibility) + parser.add_argument('--batch_size', default=15, type=int) + parser.add_argument('--epochs', default=20000, type=int) + parser.add_argument('--masks', action='store_true', + help='Use intermediate layer features') + parser.add_argument('--dilation', action='store_false', + help='Use dilated convolution in backbone') + + return parser + + +def build_gr00t_model_and_optimizer(args_override): + """ + Build GR00T model and optimizer from config dictionary. + + This function is designed to work with config.yaml loading: + 1. Parse default arguments + 2. Override with values from args_override (typically from config['gr00t']) + 3. Build model and optimizer + + Args: + args_override: Dictionary of config values, typically from config.yaml's 'gr00t:' section + Expected keys: embed_dim, hidden_dim, state_dim, action_dim, + num_queries, nheads, mlp_ratio, dropout, num_layers, + lr, lr_backbone, camera_names, backbone, etc. 
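+
+    Example (illustrative sketch; assumes config.yaml has already been loaded into
+    `config` via yaml.safe_load and that a CUDA device is available):
+        overrides = dict(config['gr00t'])
+        overrides['camera_names'] = config['camera_names']
+        model, optimizer = build_gr00t_model_and_optimizer(overrides)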
+ + Returns: + model: GR00T model on CUDA + optimizer: AdamW optimizer with separate learning rates for backbone and other params + """ + parser = argparse.ArgumentParser('GR00T training and evaluation script', + parents=[get_args_parser()]) + args = parser.parse_args() + + # Override with config values + for k, v in args_override.items(): + setattr(args, k, v) + + # Build model + model = build_gr00t_model(args) + model.cuda() + + # Create parameter groups with different learning rates + param_dicts = [ + { + "params": [p for n, p in model.named_parameters() + if "backbone" not in n and p.requires_grad] + }, + { + "params": [p for n, p in model.named_parameters() + if "backbone" in n and p.requires_grad], + "lr": args.lr_backbone, + }, + ] + + optimizer = torch.optim.AdamW(param_dicts, + lr=args.lr, + weight_decay=args.weight_decay) + + return model, optimizer diff --git a/roboimi/gr00t/models/__init__.py b/roboimi/gr00t/models/__init__.py new file mode 100644 index 0000000..327396a --- /dev/null +++ b/roboimi/gr00t/models/__init__.py @@ -0,0 +1,3 @@ +from .gr00t import build_gr00t_model + +__all__ = ['build_gr00t_model'] diff --git a/roboimi/gr00t/models/backbone.py b/roboimi/gr00t/models/backbone.py new file mode 100644 index 0000000..759bfb5 --- /dev/null +++ b/roboimi/gr00t/models/backbone.py @@ -0,0 +1,168 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +Backbone modules. +""" +from collections import OrderedDict + +import torch +import torch.nn.functional as F +import torchvision +from torch import nn +from torchvision.models._utils import IntermediateLayerGetter +from typing import Dict, List + +from util.misc import NestedTensor, is_main_process + +from .position_encoding import build_position_encoding + +class FrozenBatchNorm2d(torch.nn.Module): + """ + BatchNorm2d where the batch statistics and the affine parameters are fixed. + + Copy-paste from torchvision.misc.ops with added eps before rqsrt, + without which any other policy_models than torchvision.policy_models.resnet[18,34,50,101] + produce nans. + """ + + def __init__(self, n): + super(FrozenBatchNorm2d, self).__init__() + self.register_buffer("weight", torch.ones(n)) + self.register_buffer("bias", torch.zeros(n)) + self.register_buffer("running_mean", torch.zeros(n)) + self.register_buffer("running_var", torch.ones(n)) + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + num_batches_tracked_key = prefix + 'num_batches_tracked' + if num_batches_tracked_key in state_dict: + del state_dict[num_batches_tracked_key] + + super(FrozenBatchNorm2d, self)._load_from_state_dict( + state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs) + + def forward(self, x): + # move reshapes to the beginning + # to make it fuser-friendly + w = self.weight.reshape(1, -1, 1, 1) + b = self.bias.reshape(1, -1, 1, 1) + rv = self.running_var.reshape(1, -1, 1, 1) + rm = self.running_mean.reshape(1, -1, 1, 1) + eps = 1e-5 + scale = w * (rv + eps).rsqrt() + bias = b - rm * scale + return x * scale + bias + + +class BackboneBase(nn.Module): + + def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool): + super().__init__() + # for name, parameter in backbone.named_parameters(): # only train later layers # TODO do we want this? 
+ # if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name: + # parameter.requires_grad_(False) + if return_interm_layers: + return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"} + else: + return_layers = {'layer4': "0"} + self.body = IntermediateLayerGetter(backbone, return_layers=return_layers) + self.num_channels = num_channels + + def forward(self, tensor): + xs = self.body(tensor) + return xs + # out: Dict[str, NestedTensor] = {} + # for name, x in xs.items(): + # m = tensor_list.mask + # assert m is not None + # mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] + # out[name] = NestedTensor(x, mask) + # return out + + +class Backbone(BackboneBase): + """ResNet backbone with frozen BatchNorm.""" + def __init__(self, name: str, + train_backbone: bool, + return_interm_layers: bool, + dilation: bool): + backbone = getattr(torchvision.models, name)( + replace_stride_with_dilation=[False, False, dilation], + pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d) # pretrained # TODO do we want frozen batch_norm?? + num_channels = 512 if name in ('resnet18', 'resnet34') else 2048 + super().__init__(backbone, train_backbone, num_channels, return_interm_layers) + + +# class DINOv2BackBone(nn.Module): +# def __init__(self) -> None: +# super().__init__() +# self.body = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14') +# self.body.eval() +# self.num_channels = 384 + +# @torch.no_grad() +# def forward(self, tensor): +# xs = self.body.forward_features(tensor)["x_norm_patchtokens"] +# od = OrderedDict() +# od["0"] = xs.reshape(xs.shape[0], 22, 16, 384).permute(0, 3, 2, 1) +# return od + +class DINOv2BackBone(nn.Module): + def __init__(self, return_interm_layers: bool = False) -> None: + super().__init__() + self.body = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14') + self.body.eval() + self.num_channels = 384 + self.return_interm_layers = return_interm_layers + + @torch.no_grad() + def forward(self, tensor): + features = self.body.forward_features(tensor) + + if self.return_interm_layers: + + layer1 = features["x_norm_patchtokens"] + layer2 = features["x_norm_patchtokens"] + layer3 = features["x_norm_patchtokens"] + layer4 = features["x_norm_patchtokens"] + + od = OrderedDict() + od["0"] = layer1.reshape(layer1.shape[0], 22, 16, 384).permute(0, 3, 2, 1) + od["1"] = layer2.reshape(layer2.shape[0], 22, 16, 384).permute(0, 3, 2, 1) + od["2"] = layer3.reshape(layer3.shape[0], 22, 16, 384).permute(0, 3, 2, 1) + od["3"] = layer4.reshape(layer4.shape[0], 22, 16, 384).permute(0, 3, 2, 1) + return od + else: + xs = features["x_norm_patchtokens"] + od = OrderedDict() + od["0"] = xs.reshape(xs.shape[0], 22, 16, 384).permute(0, 3, 2, 1) + return od + +class Joiner(nn.Sequential): + def __init__(self, backbone, position_embedding): + super().__init__(backbone, position_embedding) + + def forward(self, tensor_list: NestedTensor): + xs = self[0](tensor_list) + out: List[NestedTensor] = [] + pos = [] + for name, x in xs.items(): + out.append(x) + # position encoding + pos.append(self[1](x).to(x.dtype)) + + return out, pos + + +def build_backbone(args): + position_embedding = build_position_encoding(args) + train_backbone = args.lr_backbone > 0 + return_interm_layers = args.masks + if args.backbone == 'dino_v2': + backbone = DINOv2BackBone() + else: + assert args.backbone in ['resnet18', 'resnet34'] + backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation) + model = 
Joiner(backbone, position_embedding) + model.num_channels = backbone.num_channels + return model diff --git a/roboimi/gr00t/models/dit.py b/roboimi/gr00t/models/dit.py new file mode 100644 index 0000000..ad8cede --- /dev/null +++ b/roboimi/gr00t/models/dit.py @@ -0,0 +1,142 @@ +from typing import Optional + +from diffusers import ConfigMixin, ModelMixin +from diffusers.configuration_utils import register_to_config +from diffusers.models.embeddings import SinusoidalPositionalEmbedding, TimestepEmbedding, Timesteps +import torch +from torch import nn +import torch.nn.functional as F + +class TimestepEncoder(nn.Module): + def __init__(self, args): + super().__init__() + embedding_dim = args.embed_dim + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + + def forward(self, timesteps): + dtype = next(self.parameters()).dtype + timesteps_proj = self.time_proj(timesteps).to(dtype) + timesteps_emb = self.timestep_embedder(timesteps_proj) # (N, D) + return timesteps_emb + + +class AdaLayerNorm(nn.Module): + def __init__(self, embedding_dim, norm_eps=1e-5, norm_elementwise_affine=False): + super().__init__() + + output_dim = embedding_dim * 2 + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, output_dim) + self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine) + + def forward( + self, + x: torch.Tensor, + temb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + temb = self.linear(self.silu(temb)) + scale, shift = temb.chunk(2, dim=1) + x = self.norm(x) * (1 + scale[:, None]) + shift[:, None] + return x + + +class BasicTransformerBlock(nn.Module): + def __init__(self, args, crosss_attention_dim, use_self_attn=False): + super().__init__() + dim = args.embed_dim + num_heads = args.nheads + mlp_ratio = args.mlp_ratio + dropout = args.dropout + self.norm1 = AdaLayerNorm(dim) + + if not use_self_attn: + self.attn = nn.MultiheadAttention( + embed_dim=dim, + num_heads=num_heads, + dropout=dropout, + kdim=crosss_attention_dim, + vdim=crosss_attention_dim, + batch_first=True, + ) + else: + self.attn = nn.MultiheadAttention( + embed_dim=dim, + num_heads=num_heads, + dropout=dropout, + batch_first=True, + ) + + self.norm2 = nn.LayerNorm(dim, eps=1e-5, elementwise_affine=False) + + self.mlp = nn.Sequential( + nn.Linear(dim, dim * mlp_ratio), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(dim * mlp_ratio, dim), + nn.Dropout(dropout) + ) + + def forward(self, hidden_states, temb, context=None): + norm_hidden_states = self.norm1(hidden_states, temb) + + attn_output = self.attn( + norm_hidden_states, + context if context is not None else norm_hidden_states, + context if context is not None else norm_hidden_states, + )[0] + + hidden_states = attn_output + hidden_states + + norm_hidden_states = self.norm2(hidden_states) + + ff_output = self.mlp(norm_hidden_states) + + hidden_states = ff_output + hidden_states + + return hidden_states + +class DiT(nn.Module): + def __init__(self, args, cross_attention_dim): + super().__init__() + inner_dim = args.embed_dim + num_layers = args.num_layers + output_dim = args.hidden_dim + + self.timestep_encoder = TimestepEncoder(args) + + all_blocks = [] + for idx in range(num_layers): + use_self_attn = idx % 2 == 1 + if use_self_attn: + block = BasicTransformerBlock(args, crosss_attention_dim=None, use_self_attn=True) + else: + block = BasicTransformerBlock(args, crosss_attention_dim=cross_attention_dim, 
use_self_attn=False) + all_blocks.append(block) + + self.transformer_blocks = nn.ModuleList(all_blocks) + + self.norm_out = nn.LayerNorm(inner_dim, eps=1e-6, elementwise_affine=False) + self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim) + self.proj_out_2 = nn.Linear(inner_dim, output_dim) + + def forward(self, hidden_states, timestep, encoder_hidden_states): + temb = self.timestep_encoder(timestep) + + hidden_states = hidden_states.contiguous() + encoder_hidden_states = encoder_hidden_states.contiguous() + + for idx, block in enumerate(self.transformer_blocks): + if idx % 2 == 1: + hidden_states = block(hidden_states, temb) + else: + hidden_states = block(hidden_states, temb, context=encoder_hidden_states) + + conditioning = temb + shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1) + hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None] + return self.proj_out_2(hidden_states) + + +def build_dit(args, cross_attention_dim): + return DiT(args, cross_attention_dim) \ No newline at end of file diff --git a/roboimi/gr00t/models/gr00t.py b/roboimi/gr00t/models/gr00t.py new file mode 100644 index 0000000..7ed9cb4 --- /dev/null +++ b/roboimi/gr00t/models/gr00t.py @@ -0,0 +1,124 @@ + +from .modules import ( + build_action_decoder, + build_action_encoder, + build_state_encoder, + build_time_sampler, + build_noise_scheduler, +) +from .backbone import build_backbone +from .dit import build_dit +import torch +import torch.nn as nn +import torch.nn.functional as F + +class gr00t(nn.Module): + def __init__( + self, + backbones, + dit, + state_encoder, + action_encoder, + action_decoder, + time_sampler, + noise_scheduler, + num_queries, + camera_names, + ): + super().__init__() + self.num_queries = num_queries + self.camera_names = camera_names + self.dit = dit + self.state_encoder = state_encoder + self.action_encoder = action_encoder + self.action_decoder = action_decoder + self.time_sampler = time_sampler + self.noise_scheduler = noise_scheduler + + if backbones is not None: + self.backbones = nn.ModuleList(backbones) + else: + raise NotImplementedError + + def forward(self, qpos, image, actions=None, is_pad=None): + is_training = actions is not None # train or val + bs, _ = qpos.shape + + all_cam_features = [] + for cam_id, cam_name in enumerate(self.camera_names): + # features, pos = self.backbones[0](image[:, cam_id]) # HARDCODED + features, pos = self.backbones[cam_id](image[:, cam_id]) + features = features[0] # take the last layer feature + B, C, H, W = features.shape + features_seq = features.permute(0, 2, 3, 1).reshape(B, H * W, C) + all_cam_features.append(features_seq) + encoder_hidden_states = torch.cat(all_cam_features, dim=1) + + state_features = self.state_encoder(qpos) # [B, 1, emb_dim] + + if is_training: + # training logic + + timesteps = self.time_sampler(bs, actions.device, actions.dtype) + noisy_actions, target_velocity = self.noise_scheduler.add_noise( + actions, timesteps + ) + t_discretized = (timesteps[:, 0, 0] * 1000).long() + action_features = self.action_encoder(noisy_actions, t_discretized) + sa_embs = torch.cat((state_features, action_features), dim=1) + model_output = self.dit(sa_embs, t_discretized, encoder_hidden_states) + pred = self.action_decoder(model_output) + pred_actions = pred[:, -actions.shape[1] :] + action_loss = F.mse_loss(pred_actions, target_velocity, reduction='none') + return pred_actions, action_loss + else: + actions = torch.randn(bs, self.num_queries, qpos.shape[-1], device=qpos.device, 
dtype=qpos.dtype) + k = 5 + dt = 1.0 / k + for t in range(k): + t_cont = t / float(k) + t_discretized = int(t_cont * 1000) + timesteps = torch.full((bs,), t_discretized, device=qpos.device, dtype=qpos.dtype) + action_features = self.action_encoder(actions, timesteps) + sa_embs = torch.cat((state_features, action_features), dim=1) + # Create tensor of shape [B] for DiT (consistent with training path) + model_output = self.dit(sa_embs, timesteps, encoder_hidden_states) + pred = self.action_decoder(model_output) + pred_velocity = pred[:, -self.num_queries :] + actions = actions + pred_velocity * dt + return actions, _ +def build_gr00t_model(args): + state_dim = args.state_dim + action_dim = args.action_dim + + backbones = [] + for _ in args.camera_names: + backbone = build_backbone(args) + backbones.append(backbone) + + cross_attention_dim = backbones[0].num_channels + + dit = build_dit(args, cross_attention_dim) + + state_encoder = build_state_encoder(args) + action_encoder = build_action_encoder(args) + action_decoder = build_action_decoder(args) + time_sampler = build_time_sampler(args) + noise_scheduler = build_noise_scheduler(args) + model = gr00t( + backbones, + dit, + state_encoder, + action_encoder, + action_decoder, + time_sampler, + noise_scheduler, + args.num_queries, + args.camera_names, + ) + + n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) + print("number of parameters: %.2fM" % (n_parameters/1e6,)) + return model + + diff --git a/roboimi/gr00t/models/modules.py b/roboimi/gr00t/models/modules.py new file mode 100644 index 0000000..727cee3 --- /dev/null +++ b/roboimi/gr00t/models/modules.py @@ -0,0 +1,179 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +# ActionEncoder +class SinusoidalPositionalEncoding(nn.Module): + def __init__(self, args): + super().__init__() + self.embed_dim = args.embed_dim + + def forward(self, timesteps): + timesteps = timesteps.float() + B, T = timesteps.shape + device = timesteps.device + + half_dim = self.embed_dim // 2 + + exponent = -torch.arange(half_dim, dtype=torch.float, device=device) * ( + torch.log(torch.tensor(10000.0)) / half_dim + ) + + freqs = timesteps.unsqueeze(-1) * exponent.exp() + + sin = torch.sin(freqs) + cos = torch.cos(freqs) + enc = torch.cat([sin, cos], dim=-1) # (B, T, w) + + return enc + + +class ActionEncoder(nn.Module): + def __init__(self, args): + super().__init__() + action_dim = args.action_dim + embed_dim = args.embed_dim + + self.W1 = nn.Linear(action_dim, embed_dim) + self.W2 = nn.Linear(2 * embed_dim, embed_dim) + self.W3 = nn.Linear(embed_dim, embed_dim) + + self.pos_encoder = SinusoidalPositionalEncoding(args) + + def forward(self, actions, timesteps): + B, T, _ = actions.shape + + # 1) Expand each batch's single scalar time 'tau' across all T steps + # so that shape => (B, T) + # Handle different input shapes: (B,), (B, 1), (B, 1, 1) + # Reshape to (B,) then expand to (B, T) + # if timesteps.dim() == 3: + # # Shape (B, 1, 1) or (B, T, 1) -> (B,) + # timesteps = timesteps[:, 0, 0] + # elif timesteps.dim() == 2: + # # Shape (B, 1) or (B, T) -> take first element if needed + # if timesteps.shape[1] == 1: + # timesteps = timesteps[:, 0] + # # else: already (B, T), use as is + # elif timesteps.dim() != 1: + # raise ValueError( + # f"Expected `timesteps` to have shape (B,), (B, 1), or (B, 1, 1), got {timesteps.shape}" + # ) + + # Now timesteps should be (B,), expand to (B, T) + if timesteps.dim() == 1 and timesteps.shape[0] == B: + timesteps = 
timesteps.unsqueeze(1).expand(-1, T)
+        else:
+            raise ValueError(
+                "Expected `timesteps` to have shape (B,) so we can replicate across T."
+            )
+
+        # 2) Standard action MLP step for shape => (B, T, w)
+        a_emb = self.W1(actions)
+
+        # 3) Get the sinusoidal encoding (B, T, w)
+        tau_emb = self.pos_encoder(timesteps).to(dtype=a_emb.dtype)
+
+        # 4) Concat along last dim => (B, T, 2w), then W2 => (B, T, w), swish
+        x = torch.cat([a_emb, tau_emb], dim=-1)
+        x = F.silu(self.W2(x))
+
+        # 5) Finally W3 => (B, T, w)
+        x = self.W3(x)
+
+        return x
+
+
+def build_action_encoder(args):
+    return ActionEncoder(args)
+
+
+# StateEncoder
+class StateEncoder(nn.Module):
+    def __init__(self, args):
+        super().__init__()
+        input_dim = args.state_dim
+        hidden_dim = args.hidden_dim
+        output_dim = args.embed_dim
+
+        self.mlp = nn.Sequential(
+            nn.Linear(input_dim, hidden_dim),
+            nn.ReLU(),
+            nn.Linear(hidden_dim, output_dim),
+        )
+
+    def forward(self, states):
+        state_emb = self.mlp(states)  # [B, emb_dim]
+        state_emb = state_emb.unsqueeze(1)
+        return state_emb  # [B, 1, emb_dim]
+
+
+def build_state_encoder(args):
+    return StateEncoder(args)
+
+
+# ActionDecoder
+class ActionDecoder(nn.Module):
+    def __init__(self, args):
+        super().__init__()
+        input_dim = args.hidden_dim
+        hidden_dim = args.hidden_dim
+        output_dim = args.action_dim
+
+        self.num_queries = args.num_queries
+
+        self.mlp = nn.Sequential(
+            nn.Linear(input_dim, hidden_dim),
+            nn.ReLU(),
+            nn.Linear(hidden_dim, output_dim),
+        )
+
+    def forward(self, model_output):
+        pred_actions = self.mlp(model_output)
+        return pred_actions[:, -self.num_queries:]
+
+
+def build_action_decoder(args):
+    return ActionDecoder(args)
+
+
+# TimeSampler
+class TimeSampler(nn.Module):
+    def __init__(self, noise_s=0.999, noise_beta_alpha=1.5, noise_beta_beta=1.0):
+        super().__init__()
+        self.noise_s = noise_s
+        self.beta_dist = torch.distributions.Beta(noise_beta_alpha, noise_beta_beta)
+
+    def forward(self, batch_size, device, dtype):
+        sample = self.beta_dist.sample([batch_size]).to(device, dtype=dtype)
+        sample = (1 - sample) * self.noise_s
+        return sample[:, None, None]
+
+
+def build_time_sampler(args):
+    return TimeSampler()
+
+
+# NoiseScheduler
+class FlowMatchingScheduler(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    # --- Training logic: add noise and compute the flow-matching target ---
+    def add_noise(self, actions, timesteps):
+        noise = torch.randn_like(actions)
+        noisy_samples = actions * timesteps + noise * (1 - timesteps)
+        target_velocity = actions - noise
+
+        return noisy_samples, target_velocity
+
+    # --- Inference logic: Euler step ---
+    def step(self, model_output, sample, dt):
+        prev_sample = sample + model_output * dt
+        return prev_sample
+
+def build_noise_scheduler(args):
+    return FlowMatchingScheduler()
diff --git a/roboimi/gr00t/models/position_encoding.py b/roboimi/gr00t/models/position_encoding.py
new file mode 100644
index 0000000..c75733e
--- /dev/null
+++ b/roboimi/gr00t/models/position_encoding.py
@@ -0,0 +1,91 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+"""
+Various positional encodings for the transformer.
+"""
+import math
+import torch
+from torch import nn
+
+from util.misc import NestedTensor
+
+
+class PositionEmbeddingSine(nn.Module):
+    """
+    This is a more standard version of the position embedding, very similar to the one
+    used by the Attention is all you need paper, generalized to work on images.
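+
+    For each spatial axis, a position p and channel pair index i are encoded as
+    sin(p / T^(2i/d)) and cos(p / T^(2i/d)), where T is `temperature` and d is
+    `num_pos_feats`. Unlike the original DETR version, forward() here takes a plain
+    image feature tensor rather than a NestedTensor with a padding mask.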
+ """ + def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): + super().__init__() + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + def forward(self, tensor): + x = tensor + # mask = tensor_list.mask + # assert mask is not None + # not_mask = ~mask + + not_mask = torch.ones_like(x[0, [0]]) + y_embed = not_mask.cumsum(1, dtype=torch.float32) + x_embed = not_mask.cumsum(2, dtype=torch.float32) + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos + + +class PositionEmbeddingLearned(nn.Module): + """ + Absolute pos embedding, learned. + """ + def __init__(self, num_pos_feats=256): + super().__init__() + self.row_embed = nn.Embedding(50, num_pos_feats) + self.col_embed = nn.Embedding(50, num_pos_feats) + self.reset_parameters() + + def reset_parameters(self): + nn.init.uniform_(self.row_embed.weight) + nn.init.uniform_(self.col_embed.weight) + + def forward(self, tensor_list: NestedTensor): + x = tensor_list.tensors + h, w = x.shape[-2:] + i = torch.arange(w, device=x.device) + j = torch.arange(h, device=x.device) + x_emb = self.col_embed(i) + y_emb = self.row_embed(j) + pos = torch.cat([ + x_emb.unsqueeze(0).repeat(h, 1, 1), + y_emb.unsqueeze(1).repeat(1, w, 1), + ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1) + return pos + + +def build_position_encoding(args): + N_steps = args.hidden_dim // 2 + if args.position_embedding in ('v2', 'sine'): + # TODO find a better way of exposing other arguments + position_embedding = PositionEmbeddingSine(N_steps, normalize=True) + elif args.position_embedding in ('v3', 'learned'): + position_embedding = PositionEmbeddingLearned(N_steps) + else: + raise ValueError(f"not supported {args.position_embedding}") + + return position_embedding diff --git a/roboimi/gr00t/policy.py b/roboimi/gr00t/policy.py new file mode 100644 index 0000000..83416d4 --- /dev/null +++ b/roboimi/gr00t/policy.py @@ -0,0 +1,90 @@ +""" +GR00T Policy wrapper for imitation learning. + +This module provides the gr00tPolicy class that wraps the GR00T model +for training and evaluation in the imitation learning framework. +""" +import torch.nn as nn +from torch.nn import functional as F +from torchvision.transforms import v2 +import torch +from roboimi.gr00t.main import build_gr00t_model_and_optimizer + + +class gr00tPolicy(nn.Module): + """ + GR00T Policy for action prediction using diffusion-based DiT architecture. 
+ + This policy wraps the GR00T model and handles: + - Image resizing to match DINOv2 patch size requirements + - Image normalization (ImageNet stats) + - Training with action chunks and loss computation + - Inference with diffusion sampling + """ + def __init__(self, args_override): + super().__init__() + model, optimizer = build_gr00t_model_and_optimizer(args_override) + self.model = model + self.optimizer = optimizer + + # DINOv2 requires image dimensions to be multiples of patch size (14) + # Common sizes: 224x224, 336x336, etc. (14*16=224, 14*24=336) + self.patch_h = 16 # Number of patches vertically + self.patch_w = 22 # Number of patches horizontally + target_size = (self.patch_h * 14, self.patch_w * 14) # (224, 308) + + # Training transform with data augmentation + self.train_transform = v2.Compose([ + v2.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5), + v2.RandomPerspective(distortion_scale=0.5), + v2.RandomAffine(degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1)), + v2.GaussianBlur(kernel_size=(9, 9), sigma=(0.1, 2.0)), + v2.Resize(target_size), + v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), + ]) + + # Inference transform (no augmentation) + self.inference_transform = v2.Compose([ + v2.Resize(target_size), + v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), + ]) + + def __call__(self, qpos, image, actions=None, is_pad=None): + """ + Forward pass for training or inference. + + Args: + qpos: Joint positions [B, state_dim] + image: Camera images [B, num_cameras, C, H, W] + actions: Ground truth actions [B, chunk_size, action_dim] (training only) + is_pad: Padding mask [B, chunk_size] (training only) + + Returns: + Training: dict with 'mse' loss + Inference: predicted actions [B, num_queries, action_dim] + """ + # Apply transforms (resize + normalization) + if actions is not None: # training time + image = self.train_transform(image) + else: # inference time + image = self.inference_transform(image) + + if actions is not None: # training time + actions = actions[:, :self.model.num_queries] + is_pad = is_pad[:, :self.model.num_queries] + _, action_loss = self.model(qpos, image, actions, is_pad) + + # Mask out padded positions + mse_loss = (action_loss * ~is_pad.unsqueeze(-1)).mean() + + loss_dict = { + 'loss': mse_loss + } + return loss_dict + else: # inference time + a_hat, _ = self.model(qpos, image) + return a_hat + + def configure_optimizers(self): + """Return the optimizer for training.""" + return self.optimizer diff --git a/roboimi/utils/model_interface.py b/roboimi/utils/model_interface.py index 007b0f7..fe00697 100644 --- a/roboimi/utils/model_interface.py +++ b/roboimi/utils/model_interface.py @@ -1,7 +1,8 @@ import os import torch from roboimi.utils.utils import load_data, set_seed -from roboimi.detr.policy import ACTPolicy, CNNMLPPolicy,ACTTVPolicy +from roboimi.detr.policy import ACTPolicy, CNNMLPPolicy, ACTTVPolicy +from roboimi.gr00t.policy import gr00tPolicy class ModelInterface: def __init__(self, config): @@ -59,12 +60,32 @@ class ModelInterface: } elif self.config['policy_class'] == 'CNNMLP': self.config['policy_config'] = { - 'lr': self.config['lr'], - 'lr_backbone': self.config['lr_backbone'], - 'backbone': self.config['backbone'], + 'lr': self.config['lr'], + 'lr_backbone': self.config['lr_backbone'], + 'backbone': self.config['backbone'], 'num_queries': 1, 'camera_names': self.config['camera_names'], } + elif self.config['policy_class'] == 'GR00T': + # GR00T uses its own config section from 
config.yaml + gr00t_config = self.config.get('gr00t', {}) + self.config['policy_config'] = { + 'lr': gr00t_config.get('lr', self.config['lr']), + 'lr_backbone': gr00t_config.get('lr_backbone', self.config['lr_backbone']), + 'weight_decay': gr00t_config.get('weight_decay', 1e-4), + 'embed_dim': gr00t_config.get('embed_dim', 1536), + 'hidden_dim': gr00t_config.get('hidden_dim', 1024), + 'state_dim': gr00t_config.get('state_dim', 16), + 'action_dim': gr00t_config.get('action_dim', 16), + 'num_queries': gr00t_config.get('num_queries', 16), + 'num_layers': gr00t_config.get('num_layers', 16), + 'nheads': gr00t_config.get('nheads', 32), + 'mlp_ratio': gr00t_config.get('mlp_ratio', 4), + 'dropout': gr00t_config.get('dropout', 0.2), + 'backbone': gr00t_config.get('backbone', 'dino_v2'), + 'position_embedding': gr00t_config.get('position_embedding', 'sine'), + 'camera_names': self.config['camera_names'], + } else: raise NotImplementedError @@ -75,6 +96,8 @@ class ModelInterface: return ACTTVPolicy(self.config['policy_config']) elif self.config['policy_class'] == 'CNNMLP': return CNNMLPPolicy(self.config['policy_config']) + elif self.config['policy_class'] == 'GR00T': + return gr00tPolicy(self.config['policy_config']) else: raise NotImplementedError