Compare commits
10 Commits
40c40695dd
...
dev
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
23088e5e33 | ||
|
|
ca1716c67f | ||
|
|
642d41dd8f | ||
|
|
7d39933a5b | ||
|
|
8bcad5844e | ||
|
|
cdb887c9bf | ||
|
|
abb4f501e3 | ||
|
|
1d33db0ef0 | ||
|
|
f27e397f98 | ||
|
|
4e0add4e1d |
125
gr00t/main.py
Normal file
125
gr00t/main.py
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||||
|
"""
|
||||||
|
GR00T (diffusion-based DiT policy) model builder.
|
||||||
|
|
||||||
|
This module provides functions to build GR00T models and optimizers
|
||||||
|
from configuration dictionaries (typically from config.yaml's 'gr00t:' section).
|
||||||
|
"""
|
||||||
|
import argparse
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from .models import build_gr00t_model
|
||||||
|
|
||||||
|
|
||||||
|
def get_args_parser():
    """
    Create argument parser for GR00T model configuration.

    All parameters can be overridden via args_override dictionary in
    build_gr00t_model_and_optimizer(). This allows loading from config.yaml.

    Returns:
        argparse.ArgumentParser with add_help disabled, intended to be used
        as a parent parser.
    """
    parser = argparse.ArgumentParser('GR00T training and evaluation script', add_help=False)

    # Training parameters
    parser.add_argument('--lr', default=1e-5, type=float,
                        help='Learning rate for main parameters')
    parser.add_argument('--lr_backbone', default=1e-5, type=float,
                        help='Learning rate for backbone parameters')
    parser.add_argument('--weight_decay', default=1e-4, type=float,
                        help='Weight decay for optimizer')

    # GR00T model architecture parameters
    parser.add_argument('--embed_dim', default=1536, type=int,
                        help='Embedding dimension for transformer')
    parser.add_argument('--hidden_dim', default=1024, type=int,
                        help='Hidden dimension for MLP layers')
    parser.add_argument('--state_dim', default=16, type=int,
                        help='State (qpos) dimension')
    parser.add_argument('--action_dim', default=16, type=int,
                        help='Action dimension')
    parser.add_argument('--num_queries', default=16, type=int,
                        help='Number of action queries (chunk size)')

    # DiT (Diffusion Transformer) parameters
    parser.add_argument('--num_layers', default=16, type=int,
                        help='Number of transformer layers')
    parser.add_argument('--nheads', default=32, type=int,
                        help='Number of attention heads')
    # NOTE(review): parsed as float, but downstream code multiplies embed_dim
    # by this value to size nn.Linear layers, which require int — confirm the
    # consumer casts it (see BasicTransformerBlock).
    parser.add_argument('--mlp_ratio', default=4, type=float,
                        help='MLP hidden dimension ratio')
    parser.add_argument('--dropout', default=0.2, type=float,
                        help='Dropout rate')

    # Backbone parameters
    parser.add_argument('--backbone', default='dino_v2', type=str,
                        help='Backbone architecture (dino_v2, resnet18, resnet34)')
    parser.add_argument('--position_embedding', default='sine', type=str,
                        choices=('sine', 'learned'),
                        help='Type of positional encoding')

    # Camera configuration
    # NOTE(review): default is an empty list; build_gr00t_model builds one
    # backbone per camera, so a non-empty override is effectively required.
    parser.add_argument('--camera_names', default=[], nargs='+',
                        help='List of camera names for observations')

    # Other parameters (not directly used but kept for compatibility)
    parser.add_argument('--batch_size', default=15, type=int)
    parser.add_argument('--epochs', default=20000, type=int)
    parser.add_argument('--masks', action='store_true',
                        help='Use intermediate layer features')
    # NOTE(review): store_false means supplying --dilation DISABLES dilation
    # (default is True); the help text reads inverted — confirm intent.
    parser.add_argument('--dilation', action='store_false',
                        help='Use dilated convolution in backbone')

    return parser
|
||||||
|
|
||||||
|
def build_gr00t_model_and_optimizer(args_override):
    """
    Build GR00T model and optimizer from config dictionary.

    This function is designed to work with config.yaml loading:
    1. Take the parser defaults (no command line involved)
    2. Override with values from args_override (typically from config['gr00t'])
    3. Build model and optimizer

    Args:
        args_override: Dictionary of config values, typically from config.yaml's 'gr00t:' section
                       Expected keys: embed_dim, hidden_dim, state_dim, action_dim,
                       num_queries, nheads, mlp_ratio, dropout, num_layers,
                       lr, lr_backbone, camera_names, backbone, etc.

    Returns:
        model: GR00T model on CUDA
        optimizer: AdamW optimizer with separate learning rates for backbone and other params
    """
    parser = argparse.ArgumentParser('GR00T training and evaluation script',
                                     parents=[get_args_parser()])
    # Parse an EMPTY argv on purpose: parsing sys.argv here would pick up (and
    # possibly fail on) the calling script's own command line, which is
    # unrelated to this builder. All customization comes from args_override.
    args = parser.parse_args([])

    # Override defaults with config values
    for k, v in args_override.items():
        setattr(args, k, v)

    # Build model; this builder assumes a CUDA device is available.
    model = build_gr00t_model(args)
    model.cuda()

    # Two parameter groups: non-backbone params use the optimizer-level lr,
    # backbone params get their own (typically smaller) lr_backbone.
    param_dicts = [
        {
            "params": [p for n, p in model.named_parameters()
                       if "backbone" not in n and p.requires_grad]
        },
        {
            "params": [p for n, p in model.named_parameters()
                       if "backbone" in n and p.requires_grad],
            "lr": args.lr_backbone,
        },
    ]

    optimizer = torch.optim.AdamW(param_dicts,
                                  lr=args.lr,
                                  weight_decay=args.weight_decay)

    return model, optimizer
3
gr00t/models/__init__.py
Normal file
3
gr00t/models/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
from .gr00t import build_gr00t_model
|
||||||
|
|
||||||
|
__all__ = ['build_gr00t_model']
|
||||||
168
gr00t/models/backbone.py
Normal file
168
gr00t/models/backbone.py
Normal file
@@ -0,0 +1,168 @@
|
|||||||
|
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||||
|
"""
|
||||||
|
Backbone modules.
|
||||||
|
"""
|
||||||
|
from collections import OrderedDict
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import torchvision
|
||||||
|
from torch import nn
|
||||||
|
from torchvision.models._utils import IntermediateLayerGetter
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
|
from util.misc import NestedTensor, is_main_process
|
||||||
|
|
||||||
|
from .position_encoding import build_position_encoding
|
||||||
|
|
||||||
|
class FrozenBatchNorm2d(torch.nn.Module):
    """
    BatchNorm2d with frozen running statistics and affine parameters.

    Same idea as torchvision's frozen batch norm, but adds eps before rsqrt;
    without it, backbones other than the stock torchvision resnets
    (resnet[18,34,50,101]) produce NaNs.
    """

    def __init__(self, n):
        super(FrozenBatchNorm2d, self).__init__()
        # All four tensors are buffers (part of the state dict, never trained).
        for buf_name, init in (("weight", torch.ones(n)),
                               ("bias", torch.zeros(n)),
                               ("running_mean", torch.zeros(n)),
                               ("running_var", torch.ones(n))):
            self.register_buffer(buf_name, init)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        # Regular BatchNorm checkpoints carry num_batches_tracked, which this
        # module does not define; drop it so loading does not complain.
        tracked_key = prefix + 'num_batches_tracked'
        if tracked_key in state_dict:
            del state_dict[tracked_key]

        super(FrozenBatchNorm2d, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)

    def forward(self, x):
        # Fold the frozen stats into a single per-channel scale/shift,
        # reshaped up front to broadcast over (N, C, H, W) — fuser-friendly.
        eps = 1e-5
        gamma = self.weight.reshape(1, -1, 1, 1)
        beta = self.bias.reshape(1, -1, 1, 1)
        var = self.running_var.reshape(1, -1, 1, 1)
        mean = self.running_mean.reshape(1, -1, 1, 1)

        scale = gamma * (var + eps).rsqrt()
        shift = beta - mean * scale
        return x * scale + shift
|
||||||
|
|
||||||
|
class BackboneBase(nn.Module):
    """Wraps a torchvision-style backbone and returns its feature maps.

    Uses IntermediateLayerGetter to pull either all four resnet stages
    (return_interm_layers=True) or only layer4.
    """

    def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool):
        super().__init__()
        # NOTE(review): train_backbone is accepted but currently unused — the
        # freezing logic below is commented out. Confirm whether backbone
        # parameters are meant to stay trainable.
        # for name, parameter in backbone.named_parameters(): # only train later layers # TODO do we want this?
        #     if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
        #         parameter.requires_grad_(False)
        if return_interm_layers:
            return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
        else:
            return_layers = {'layer4': "0"}
        self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
        # Channel count of the returned feature maps (supplied by the caller).
        self.num_channels = num_channels

    def forward(self, tensor):
        # Returns an OrderedDict of plain tensors keyed "0".."3"; the original
        # DETR NestedTensor/mask handling is intentionally disabled below.
        xs = self.body(tensor)
        return xs
        # out: Dict[str, NestedTensor] = {}
        # for name, x in xs.items():
        #     m = tensor_list.mask
        #     assert m is not None
        #     mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
        #     out[name] = NestedTensor(x, mask)
        # return out
|
||||||
|
class Backbone(BackboneBase):
    """ResNet backbone with frozen BatchNorm.

    Instantiates a torchvision resnet by name, optionally replacing the last
    stage's stride with dilation, and swaps BatchNorm for FrozenBatchNorm2d.
    """

    def __init__(self, name: str,
                 train_backbone: bool,
                 return_interm_layers: bool,
                 dilation: bool):
        # Pretrained weights are only downloaded on the main process to avoid
        # redundant downloads in distributed runs.
        # NOTE(review): `pretrained=` is the legacy torchvision weights API —
        # confirm the pinned torchvision version still supports it.
        backbone = getattr(torchvision.models, name)(
            replace_stride_with_dilation=[False, False, dilation],
            pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d)  # pretrained # TODO do we want frozen batch_norm??
        # resnet18/34 end with 512 channels; deeper resnets with 2048.
        num_channels = 512 if name in ('resnet18', 'resnet34') else 2048
        super().__init__(backbone, train_backbone, num_channels, return_interm_layers)
|
||||||
|
|
||||||
|
# class DINOv2BackBone(nn.Module):
|
||||||
|
# def __init__(self) -> None:
|
||||||
|
# super().__init__()
|
||||||
|
# self.body = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
|
||||||
|
# self.body.eval()
|
||||||
|
# self.num_channels = 384
|
||||||
|
|
||||||
|
# @torch.no_grad()
|
||||||
|
# def forward(self, tensor):
|
||||||
|
# xs = self.body.forward_features(tensor)["x_norm_patchtokens"]
|
||||||
|
# od = OrderedDict()
|
||||||
|
# od["0"] = xs.reshape(xs.shape[0], 22, 16, 384).permute(0, 3, 2, 1)
|
||||||
|
# return od
|
||||||
|
|
||||||
|
class DINOv2BackBone(nn.Module):
    """Frozen DINOv2 (ViT-S/14) feature extractor loaded from torch.hub.

    Returns an OrderedDict of spatial feature maps keyed like the resnet
    backbones ("0", or "0".."3" when return_interm_layers is True).

    NOTE(review): the reshape below hard-codes a 22x16 patch grid, i.e. a
    fixed input resolution — confirm against the dataloader's image size.
    """

    def __init__(self, return_interm_layers: bool = False) -> None:
        super().__init__()
        self.body = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
        self.body.eval()
        self.num_channels = 384
        self.return_interm_layers = return_interm_layers

    @staticmethod
    def _as_feature_map(tokens):
        # (B, 352, 384) patch tokens -> (B, 384, 16, 22) spatial map.
        return tokens.reshape(tokens.shape[0], 22, 16, 384).permute(0, 3, 2, 1)

    @torch.no_grad()
    def forward(self, tensor):
        features = self.body.forward_features(tensor)
        tokens = features["x_norm_patchtokens"]

        od = OrderedDict()
        if self.return_interm_layers:
            # forward_features only exposes the final patch tokens, so every
            # "layer" entry is the same tensor, reshaped once per key.
            for key in ("0", "1", "2", "3"):
                od[key] = self._as_feature_map(tokens)
        else:
            od["0"] = self._as_feature_map(tokens)
        return od
|
||||||
|
class Joiner(nn.Sequential):
    """Pairs a backbone (index 0) with a position-embedding module (index 1).

    forward returns the backbone's feature maps as a list plus a matching
    list of positional encodings, one per feature map.
    """

    def __init__(self, backbone, position_embedding):
        super().__init__(backbone, position_embedding)

    def forward(self, tensor_list: NestedTensor):
        feature_maps = self[0](tensor_list)

        out: List[NestedTensor] = []
        pos = []
        for _, feat in feature_maps.items():
            out.append(feat)
            # Encode positions per feature map, matching its dtype.
            pos.append(self[1](feat).to(feat.dtype))

        return out, pos
|
||||||
|
|
||||||
|
def build_backbone(args):
    """Assemble Joiner(backbone, position encoding) from parsed args.

    'dino_v2' selects the hub-loaded ViT backbone; otherwise a resnet18/34
    with frozen batch norm is built. The Joiner is tagged with the backbone's
    channel count for downstream consumers.
    """
    pos_enc = build_position_encoding(args)

    if args.backbone == 'dino_v2':
        # The train/interm-layer flags are not used by the DINOv2 path.
        trunk = DINOv2BackBone()
    else:
        assert args.backbone in ['resnet18', 'resnet34']
        trunk = Backbone(args.backbone, args.lr_backbone > 0, args.masks, args.dilation)

    joined = Joiner(trunk, pos_enc)
    joined.num_channels = trunk.num_channels
    return joined
142
gr00t/models/dit.py
Normal file
142
gr00t/models/dit.py
Normal file
@@ -0,0 +1,142 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from diffusers import ConfigMixin, ModelMixin
|
||||||
|
from diffusers.configuration_utils import register_to_config
|
||||||
|
from diffusers.models.embeddings import SinusoidalPositionalEmbedding, TimestepEmbedding, Timesteps
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
class TimestepEncoder(nn.Module):
    """Encode integer diffusion timesteps into embed_dim-wide vectors.

    Sinusoidal projection (256 channels, via diffusers' Timesteps) followed
    by an MLP (diffusers' TimestepEmbedding) up to args.embed_dim.
    """

    def __init__(self, args):
        super().__init__()
        embedding_dim = args.embed_dim
        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)

    def forward(self, timesteps):
        # Cast the sinusoidal projection to the module's parameter dtype so
        # mixed-precision runs don't hit a dtype mismatch in the MLP.
        dtype = next(self.parameters()).dtype
        timesteps_proj = self.time_proj(timesteps).to(dtype)
        timesteps_emb = self.timestep_embedder(timesteps_proj)  # (N, D)
        return timesteps_emb
|
||||||
|
|
||||||
|
class AdaLayerNorm(nn.Module):
    """LayerNorm whose scale and shift are predicted from a conditioning vector.

    temb -> SiLU -> Linear yields (scale, shift); the input is modulated as
    norm(x) * (1 + scale) + shift (adaptive layer norm).
    """

    def __init__(self, embedding_dim, norm_eps=1e-5, norm_elementwise_affine=False):
        super().__init__()
        # Linear outputs two embedding_dim-wide halves: scale and shift.
        out_features = embedding_dim * 2
        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, out_features)
        self.norm = nn.LayerNorm(out_features // 2, norm_eps, norm_elementwise_affine)

    def forward(
        self,
        x: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        modulation = self.linear(self.silu(temb))
        scale, shift = modulation.chunk(2, dim=1)
        # Broadcast the (B, D) modulation over x's sequence dimension.
        return self.norm(x) * (1 + scale[:, None]) + shift[:, None]
|
||||||
|
|
||||||
|
class BasicTransformerBlock(nn.Module):
    """One DiT block: adaLN -> (self- or cross-)attention -> LayerNorm -> MLP.

    Args:
        args: namespace providing embed_dim, nheads, mlp_ratio, dropout.
        crosss_attention_dim: key/value width for cross-attention; unused when
            use_self_attn is True. (Misspelled name kept — callers pass it by
            keyword.)
        use_self_attn: if True, attention keys/values come from the hidden
            states themselves; otherwise from the encoder context.
    """

    def __init__(self, args, crosss_attention_dim, use_self_attn=False):
        super().__init__()
        dim = args.embed_dim
        num_heads = args.nheads
        # BUGFIX: mlp_ratio is parsed with type=float; nn.Linear requires int
        # feature sizes, so dim * mlp_ratio must be cast explicitly.
        mlp_hidden = int(dim * args.mlp_ratio)
        dropout = args.dropout
        self.norm1 = AdaLayerNorm(dim)

        if not use_self_attn:
            # Cross-attention: keys/values projected from context width.
            self.attn = nn.MultiheadAttention(
                embed_dim=dim,
                num_heads=num_heads,
                dropout=dropout,
                kdim=crosss_attention_dim,
                vdim=crosss_attention_dim,
                batch_first=True,
            )
        else:
            self.attn = nn.MultiheadAttention(
                embed_dim=dim,
                num_heads=num_heads,
                dropout=dropout,
                batch_first=True,
            )

        self.norm2 = nn.LayerNorm(dim, eps=1e-5, elementwise_affine=False)

        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_hidden, dim),
            nn.Dropout(dropout)
        )

    def forward(self, hidden_states, temb, context=None):
        """Apply the block; `context` (if given) supplies the attention K/V.

        Args:
            hidden_states: (B, S, embed_dim) token sequence.
            temb: (B, embed_dim) conditioning for the adaptive norm.
            context: optional encoder states for cross-attention.

        Returns:
            (B, S, embed_dim) updated hidden states (residual connections
            around both the attention and the MLP).
        """
        norm_hidden_states = self.norm1(hidden_states, temb)

        kv = context if context is not None else norm_hidden_states
        attn_output = self.attn(norm_hidden_states, kv, kv)[0]

        hidden_states = attn_output + hidden_states

        ff_output = self.mlp(self.norm2(hidden_states))

        return ff_output + hidden_states
|
||||||
|
class DiT(nn.Module):
    """Diffusion Transformer over state/action tokens, conditioned on timestep.

    Layers alternate: even-indexed blocks cross-attend to encoder features,
    odd-indexed blocks self-attend. Output is projected from embed_dim down
    to hidden_dim via a final adaLN-style modulation.
    """

    def __init__(self, args, cross_attention_dim):
        super().__init__()
        inner_dim = args.embed_dim
        num_layers = args.num_layers
        output_dim = args.hidden_dim

        self.timestep_encoder = TimestepEncoder(args)

        all_blocks = []
        for idx in range(num_layers):
            # Odd layers: self-attention; even layers: cross-attention.
            use_self_attn = idx % 2 == 1
            if use_self_attn:
                block = BasicTransformerBlock(args, crosss_attention_dim=None, use_self_attn=True)
            else:
                block = BasicTransformerBlock(args, crosss_attention_dim=cross_attention_dim, use_self_attn=False)
            all_blocks.append(block)

        self.transformer_blocks = nn.ModuleList(all_blocks)

        self.norm_out = nn.LayerNorm(inner_dim, eps=1e-6, elementwise_affine=False)
        self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
        self.proj_out_2 = nn.Linear(inner_dim, output_dim)

    def forward(self, hidden_states, timestep, encoder_hidden_states):
        """hidden_states: (B, S, embed_dim); timestep: (B,);
        encoder_hidden_states: (B, S_ctx, cross_attention_dim).
        Returns (B, S, hidden_dim)."""
        temb = self.timestep_encoder(timestep)

        hidden_states = hidden_states.contiguous()
        encoder_hidden_states = encoder_hidden_states.contiguous()

        # NOTE(review): this parity test duplicates the use_self_attn logic
        # from __init__; keep the two in sync if the layer pattern changes.
        for idx, block in enumerate(self.transformer_blocks):
            if idx % 2 == 1:
                hidden_states = block(hidden_states, temb)
            else:
                hidden_states = block(hidden_states, temb, context=encoder_hidden_states)

        # Final adaLN-style projection: timestep embedding modulates the
        # normalized output before the down-projection to hidden_dim.
        conditioning = temb
        shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
        hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
        return self.proj_out_2(hidden_states)
|
||||||
|
|
||||||
|
def build_dit(args, cross_attention_dim):
    """Factory: DiT whose cross-attention K/V width matches the backbone."""
    return DiT(args, cross_attention_dim)
||||||
124
gr00t/models/gr00t.py
Normal file
124
gr00t/models/gr00t.py
Normal file
@@ -0,0 +1,124 @@
|
|||||||
|
|
||||||
|
from .modules import (
|
||||||
|
build_action_decoder,
|
||||||
|
build_action_encoder,
|
||||||
|
build_state_encoder,
|
||||||
|
build_time_sampler,
|
||||||
|
build_noise_scheduler,
|
||||||
|
)
|
||||||
|
from .backbone import build_backbone
|
||||||
|
from .dit import build_dit
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
class gr00t(nn.Module):
    """GR00T policy: per-camera backbones feed a DiT that denoises action chunks.

    Training: flow-matching — noisy actions in, predicted velocity out, MSE
    loss against (actions - noise). Inference: k Euler integration steps from
    pure noise to an action chunk.
    """

    def __init__(
        self,
        backbones,
        dit,
        state_encoder,
        action_encoder,
        action_decoder,
        time_sampler,
        noise_scheduler,
        num_queries,
        camera_names,
    ):
        super().__init__()
        self.num_queries = num_queries        # action chunk length
        self.camera_names = camera_names      # one backbone per camera
        self.dit = dit
        self.state_encoder = state_encoder
        self.action_encoder = action_encoder
        self.action_decoder = action_decoder
        self.time_sampler = time_sampler
        self.noise_scheduler = noise_scheduler

        if backbones is not None:
            self.backbones = nn.ModuleList(backbones)
        else:
            # A backbone-free variant is not implemented.
            raise NotImplementedError

    def forward(self, qpos, image, actions=None, is_pad=None):
        """qpos: (B, state_dim); image: (B, n_cams, C, H, W).
        Training (actions given): returns (pred_actions, per-element MSE loss).
        Inference: returns (denoised action chunk, unused second value).
        NOTE(review): is_pad is accepted but never used."""
        is_training = actions is not None # train or val
        bs, _ = qpos.shape

        # Encode each camera view and flatten spatial dims to a token sequence.
        all_cam_features = []
        for cam_id, cam_name in enumerate(self.camera_names):
            # features, pos = self.backbones[0](image[:, cam_id]) # HARDCODED
            features, pos = self.backbones[cam_id](image[:, cam_id])
            features = features[0] # take the last layer feature
            B, C, H, W = features.shape
            features_seq = features.permute(0, 2, 3, 1).reshape(B, H * W, C)
            all_cam_features.append(features_seq)
        encoder_hidden_states = torch.cat(all_cam_features, dim=1)

        state_features = self.state_encoder(qpos)  # [B, 1, emb_dim]

        if is_training:
            # training logic
            # Sample continuous times (B,1,1), blend noise into the actions,
            # and regress the target velocity (actions - noise).
            timesteps = self.time_sampler(bs, actions.device, actions.dtype)
            noisy_actions, target_velocity = self.noise_scheduler.add_noise(
                actions, timesteps
            )
            # Discretize continuous time in [0,1) to integer buckets [0,1000).
            t_discretized = (timesteps[:, 0, 0] * 1000).long()
            action_features = self.action_encoder(noisy_actions, t_discretized)
            sa_embs = torch.cat((state_features, action_features), dim=1)
            model_output = self.dit(sa_embs, t_discretized, encoder_hidden_states)
            pred = self.action_decoder(model_output)
            pred_actions = pred[:, -actions.shape[1] :]
            action_loss = F.mse_loss(pred_actions, target_velocity, reduction='none')
            return pred_actions, action_loss
        else:
            # NOTE(review): uses qpos.shape[-1] as the action dimension, i.e.
            # assumes state_dim == action_dim — confirm for asymmetric robots.
            actions = torch.randn(bs, self.num_queries, qpos.shape[-1], device=qpos.device, dtype=qpos.dtype)
            k = 5
            dt = 1.0 / k
            # Euler integration of the learned velocity field, t = 0 .. (k-1)/k.
            for t in range(k):
                t_cont = t / float(k)
                t_discretized = int(t_cont * 1000)
                # NOTE(review): here timesteps is float (qpos.dtype) while the
                # training path passes a long tensor — confirm both are handled
                # identically by the encoders.
                timesteps = torch.full((bs,), t_discretized, device=qpos.device, dtype=qpos.dtype)
                action_features = self.action_encoder(actions, timesteps)
                sa_embs = torch.cat((state_features, action_features), dim=1)
                # Create tensor of shape [B] for DiT (consistent with training path)
                model_output = self.dit(sa_embs, timesteps, encoder_hidden_states)
                pred = self.action_decoder(model_output)
                pred_velocity = pred[:, -self.num_queries :]
                actions = actions + pred_velocity * dt
            # NOTE(review): `_` is the leftover from `bs, _ = qpos.shape`
            # above — an accidental second return value, not a meaningful one.
            return actions, _
|
def build_gr00t_model(args):
    """Construct the full GR00T policy from parsed args.

    Args:
        args: namespace with camera_names, num_queries and the backbone/DiT
            hyper-parameters consumed by the sub-builders.

    Returns:
        gr00t module (device placement is left to the caller).

    Raises:
        ValueError: if args.camera_names is empty — one backbone is built per
            camera and the DiT cross-attention width is read from the first.
    """
    if not args.camera_names:
        # Previously this fell through to an opaque IndexError on backbones[0].
        raise ValueError(
            "args.camera_names is empty: at least one camera is required "
            "(one backbone is built per camera)."
        )

    # One backbone per camera view.
    backbones = [build_backbone(args) for _ in args.camera_names]

    # The DiT cross-attends over backbone features, so K/V width is the
    # backbone channel count.
    cross_attention_dim = backbones[0].num_channels

    dit = build_dit(args, cross_attention_dim)

    state_encoder = build_state_encoder(args)
    action_encoder = build_action_encoder(args)
    action_decoder = build_action_decoder(args)
    time_sampler = build_time_sampler(args)
    noise_scheduler = build_noise_scheduler(args)
    model = gr00t(
        backbones,
        dit,
        state_encoder,
        action_encoder,
        action_decoder,
        time_sampler,
        noise_scheduler,
        args.num_queries,
        args.camera_names,
    )

    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("number of parameters: %.2fM" % (n_parameters / 1e6,))
    return model
|
||||||
|
|
||||||
179
gr00t/models/modules.py
Normal file
179
gr00t/models/modules.py
Normal file
@@ -0,0 +1,179 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
# ActionEncoder
|
||||||
|
class SinusoidalPositionalEncoding(nn.Module):
    """Transformer-style sin/cos encoding of (possibly fractional) timesteps.

    Output layout is [sin(angles) | cos(angles)] along the last dim; the two
    halves each span embed_dim // 2 channels.
    """

    def __init__(self, args):
        super().__init__()
        self.embed_dim = args.embed_dim

    def forward(self, timesteps):
        """timesteps: (B, T) -> encoding of shape (B, T, embed_dim)."""
        timesteps = timesteps.float()
        B, T = timesteps.shape
        half = self.embed_dim // 2

        # Geometric frequency ladder: 10000^(-i/half) for i in [0, half).
        idx = torch.arange(half, dtype=torch.float, device=timesteps.device)
        inv_freq = torch.exp(-idx * (torch.log(torch.tensor(10000.0)) / half))

        angles = timesteps.unsqueeze(-1) * inv_freq
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
|
||||||
|
|
||||||
|
class ActionEncoder(nn.Module):
    """Embed noisy action chunks together with their diffusion time.

    Per step: W1 projects the action, the (replicated) timestep is
    sinusoidally encoded, the two are concatenated, mixed by W2 + SiLU and
    projected by W3. Output shape: (B, T, embed_dim).
    """

    def __init__(self, args):
        super().__init__()
        action_dim = args.action_dim
        embed_dim = args.embed_dim

        self.W1 = nn.Linear(action_dim, embed_dim)
        self.W2 = nn.Linear(2 * embed_dim, embed_dim)
        self.W3 = nn.Linear(embed_dim, embed_dim)

        self.pos_encoder = SinusoidalPositionalEncoding(args)

    def forward(self, actions, timesteps):
        """
        Args:
            actions: (B, T, action_dim) noisy action chunk.
            timesteps: (B,) one scalar time per batch element; it is
                replicated across all T steps before encoding.

        Returns:
            (B, T, embed_dim) action-time embedding.

        Raises:
            ValueError: if timesteps is not a (B,)-shaped tensor.
        """
        B, T, _ = actions.shape

        # Replicate each batch element's scalar time across the T steps.
        # (Earlier revisions also accepted (B,1)/(B,1,1); the live contract
        # is strictly (B,).)
        if timesteps.dim() == 1 and timesteps.shape[0] == B:
            timesteps = timesteps.unsqueeze(1).expand(-1, T)
        else:
            raise ValueError(
                "Expected `timesteps` to have shape (B,) so we can replicate across T."
            )

        # Project actions to (B, T, embed_dim).
        a_emb = self.W1(actions)

        # Sinusoidal time encoding, cast to match the action embedding dtype.
        tau_emb = self.pos_encoder(timesteps).to(dtype=a_emb.dtype)

        # Fuse: concat -> W2 -> SiLU -> W3.
        x = torch.cat([a_emb, tau_emb], dim=-1)
        x = F.silu(self.W2(x))
        return self.W3(x)
|
||||||
|
|
||||||
|
def build_action_encoder(args):
    """Factory: ActionEncoder from parsed args."""
    return ActionEncoder(args)
||||||
|
|
||||||
|
|
||||||
|
# StateEncoder
|
||||||
|
class StateEncoder(nn.Module):
    """Project the robot state (qpos) into a single token of width embed_dim."""

    def __init__(self, args):
        super().__init__()
        # state_dim -> hidden_dim -> embed_dim, two-layer MLP with ReLU.
        self.mlp = nn.Sequential(
            nn.Linear(args.state_dim, args.hidden_dim),
            nn.ReLU(),
            nn.Linear(args.hidden_dim, args.embed_dim),
        )

    def forward(self, states):
        """states: (B, state_dim) -> (B, 1, embed_dim) single state token."""
        return self.mlp(states)[:, None, :]
|
|
||||||
|
|
||||||
|
def build_state_encoder(args):
    """Factory: StateEncoder from parsed args."""
    return StateEncoder(args)
||||||
|
|
||||||
|
|
||||||
|
# ActionDecoder
|
||||||
|
class ActionDecoder(nn.Module):
    """Map DiT outputs to action space, keeping only the last num_queries steps."""

    def __init__(self, args):
        super().__init__()
        self.num_queries = args.num_queries
        # hidden_dim -> hidden_dim -> action_dim, two-layer MLP with ReLU.
        self.mlp = nn.Sequential(
            nn.Linear(args.hidden_dim, args.hidden_dim),
            nn.ReLU(),
            nn.Linear(args.hidden_dim, args.action_dim),
        )

    def forward(self, model_output):
        """model_output: (B, S, hidden_dim) -> (B, num_queries, action_dim)."""
        decoded = self.mlp(model_output)
        # The action tokens sit at the tail of the sequence.
        return decoded[:, -self.num_queries:]
|
|
||||||
|
|
||||||
|
def build_action_decoder(args):
    """Factory: ActionDecoder from parsed args."""
    return ActionDecoder(args)
||||||
|
|
||||||
|
|
||||||
|
# TimeSampler
|
||||||
|
class TimeSampler(nn.Module):
    """Sample flow-matching times in (0, noise_s) via a Beta distribution.

    Draws u ~ Beta(alpha, beta) and returns noise_s * (1 - u), shaped
    (B, 1, 1) so it broadcasts over (B, T, action_dim) action tensors.
    """

    def __init__(self, noise_s=0.999, noise_beta_alpha=1.5, noise_beta_beta=1.0):
        super().__init__()
        self.noise_s = noise_s
        self.beta_dist = torch.distributions.Beta(noise_beta_alpha, noise_beta_beta)

    def forward(self, batch_size, device, dtype):
        """Return a (batch_size, 1, 1) tensor of sampled times on device/dtype."""
        u = self.beta_dist.sample([batch_size]).to(device, dtype=dtype)
        t = (1 - u) * self.noise_s
        return t[:, None, None]
||||||
|
|
||||||
|
|
||||||
|
def build_time_sampler(args):
    """Factory: TimeSampler with default Beta/noise parameters (args unused)."""
    return TimeSampler()
||||||
|
|
||||||
|
|
||||||
|
# NoiseScheduler
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
class FlowMatchingScheduler(nn.Module):
    """Linear flow matching: x_t = t * x1 + (1 - t) * noise, target v = x1 - noise."""

    def __init__(self):
        super().__init__()

    # --- Training: blend in noise and compute the regression target ---
    def add_noise(self, actions, timesteps):
        """Return (noisy_samples, target_velocity) for the given times.

        timesteps broadcasts against actions (e.g. (B, 1, 1) vs (B, T, D));
        target_velocity is actions - noise, the flow-matching target.
        """
        noise = torch.randn_like(actions)
        noisy_samples = actions * timesteps + noise * (1 - timesteps)
        target_velocity = actions - noise
        return noisy_samples, target_velocity

    # --- Inference: explicit Euler step along the predicted velocity ---
    def step(self, model_output, sample, dt):
        """One Euler integration step: sample + model_output * dt."""
        return sample + model_output * dt
||||||
|
|
||||||
|
def build_noise_scheduler(args):
    """Factory: FlowMatchingScheduler (args unused, kept for a uniform API)."""
    return FlowMatchingScheduler()
||||||
91
gr00t/models/position_encoding.py
Normal file
91
gr00t/models/position_encoding.py
Normal file
@@ -0,0 +1,91 @@
|
|||||||
|
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||||
|
"""
|
||||||
|
Various positional encodings for the transformer.
|
||||||
|
"""
|
||||||
|
import math
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from util.misc import NestedTensor
|
||||||
|
|
||||||
|
|
||||||
|
class PositionEmbeddingSine(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to the one
    used by the Attention is all you need paper, generalized to work on images.
    """
    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
        super().__init__()
        self.num_pos_feats = num_pos_feats    # channels per axis (x and y)
        self.temperature = temperature        # frequency base
        self.normalize = normalize            # rescale coords to [0, scale]
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, tensor):
        # NOTE(review): takes a plain (B, C, H, W) tensor; the original DETR
        # NestedTensor/mask handling is disabled below, so padding is ignored.
        x = tensor
        # mask = tensor_list.mask
        # assert mask is not None
        # not_mask = ~mask
        
        # All positions treated as valid: ones of shape (1, H, W).
        not_mask = torch.ones_like(x[0, [0]])
        # Cumulative sums give 1-based row/column coordinates.
        y_embed = not_mask.cumsum(1, dtype=torch.float32)
        x_embed = not_mask.cumsum(2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        # Frequency ladder: temperature^(2*floor(i/2)/num_pos_feats).
        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        # Interleave sin (even channels) and cos (odd channels), then flatten.
        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
        # Concatenate y- then x-encodings and move channels first:
        # (1, 2*num_pos_feats, H, W).
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        return pos
||||||
|
|
||||||
|
|
||||||
|
class PositionEmbeddingLearned(nn.Module):
    """
    Absolute pos embedding, learned.

    Maintains two 50-entry embedding tables (rows / columns) and tiles them
    into a (B, 2 * num_pos_feats, H, W) positional map.
    """
    def __init__(self, num_pos_feats=256):
        super().__init__()
        # Creation order (row first, then col) is kept so that parameter
        # initialization consumes the RNG stream in the original order.
        self.row_embed = nn.Embedding(50, num_pos_feats)
        self.col_embed = nn.Embedding(50, num_pos_feats)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise both embedding tables uniformly at random."""
        for table in (self.row_embed, self.col_embed):
            nn.init.uniform_(table.weight)

    def forward(self, tensor_list: NestedTensor):
        """Return a learned 2D positional map matching the batched images."""
        images = tensor_list.tensors
        height, width = images.shape[-2:]
        col_idx = torch.arange(width, device=images.device)
        row_idx = torch.arange(height, device=images.device)
        col_feats = self.col_embed(col_idx)  # (W, num_pos_feats)
        row_feats = self.row_embed(row_idx)  # (H, num_pos_feats)
        # Tile column features down the rows and row features across the
        # columns, then concatenate along the feature axis.
        grid = torch.cat(
            [
                col_feats.unsqueeze(0).repeat(height, 1, 1),
                row_feats.unsqueeze(1).repeat(1, width, 1),
            ],
            dim=-1,
        )
        # (H, W, 2F) -> (2F, H, W) -> (B, 2F, H, W)
        pos = grid.permute(2, 0, 1).unsqueeze(0).repeat(images.shape[0], 1, 1, 1)
        return pos
|
||||||
|
|
||||||
|
|
||||||
|
def build_position_encoding(args):
    """Select and build the positional-encoding module from parsed args.

    Args:
        args: Namespace providing ``hidden_dim`` and ``position_embedding``
            ('v2'/'sine' for sinusoidal, 'v3'/'learned' for learned).

    Returns:
        The constructed positional-encoding ``nn.Module``.

    Raises:
        ValueError: If ``args.position_embedding`` names an unsupported kind.
    """
    N_steps = args.hidden_dim // 2
    kind = args.position_embedding
    if kind in ('v2', 'sine'):
        # TODO find a better way of exposing other arguments
        return PositionEmbeddingSine(N_steps, normalize=True)
    if kind in ('v3', 'learned'):
        return PositionEmbeddingLearned(N_steps)
    raise ValueError(f"not supported {args.position_embedding}")
|
||||||
90
gr00t/policy.py
Normal file
90
gr00t/policy.py
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
"""
|
||||||
|
GR00T Policy wrapper for imitation learning.
|
||||||
|
|
||||||
|
This module provides the gr00tPolicy class that wraps the GR00T model
|
||||||
|
for training and evaluation in the imitation learning framework.
|
||||||
|
"""
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.nn import functional as F
|
||||||
|
from torchvision.transforms import v2
|
||||||
|
import torch
|
||||||
|
from roboimi.gr00t.main import build_gr00t_model_and_optimizer
|
||||||
|
|
||||||
|
|
||||||
|
class gr00tPolicy(nn.Module):
    """
    GR00T Policy for action prediction using diffusion-based DiT architecture.

    This policy wraps the GR00T model and handles:
    - Image resizing to match DINOv2 patch size requirements
    - Image normalization (ImageNet stats)
    - Training with action chunks and masked loss computation
    - Inference with diffusion sampling
    """

    def __init__(self, args_override):
        """Build the underlying GR00T model/optimizer and image transforms.

        Args:
            args_override: Hyper-parameter overrides forwarded to
                build_gr00t_model_and_optimizer (typically from config.yaml).
        """
        super().__init__()
        model, optimizer = build_gr00t_model_and_optimizer(args_override)
        self.model = model
        self.optimizer = optimizer

        # DINOv2 requires image dimensions to be multiples of patch size (14)
        # Common sizes: 224x224, 336x336, etc. (14*16=224, 14*24=336)
        self.patch_h = 16  # Number of patches vertically
        self.patch_w = 22  # Number of patches horizontally
        target_size = (self.patch_h * 14, self.patch_w * 14)  # (224, 308)

        # Training transform with data augmentation.
        # NOTE(review): assumes images arrive as float tensors already scaled
        # to [0, 1] — confirm against the dataloader.
        self.train_transform = v2.Compose([
            v2.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),
            v2.RandomPerspective(distortion_scale=0.5),
            v2.RandomAffine(degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1)),
            v2.GaussianBlur(kernel_size=(9, 9), sigma=(0.1, 2.0)),
            v2.Resize(target_size),
            v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ])

        # Inference transform (no augmentation)
        self.inference_transform = v2.Compose([
            v2.Resize(target_size),
            v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ])

    def __call__(self, qpos, image, actions=None, is_pad=None):
        """
        Forward pass for training or inference.

        Args:
            qpos: Joint positions [B, state_dim]
            image: Camera images [B, num_cameras, C, H, W]
            actions: Ground truth actions [B, chunk_size, action_dim] (training only)
            is_pad: Padding mask [B, chunk_size]; True marks padded steps (training only)

        Returns:
            Training: dict with 'loss' (masked MSE over valid steps)
            Inference: predicted actions [B, num_queries, action_dim]
        """
        if actions is not None:  # training time
            image = self.train_transform(image)

            # Truncate chunk to the model's prediction horizon.
            actions = actions[:, :self.model.num_queries]
            is_pad = is_pad[:, :self.model.num_queries]
            _, action_loss = self.model(qpos, image, actions, is_pad)

            # Masked mean over NON-padded elements only.  The previous plain
            # .mean() also divided by padded positions, silently scaling the
            # loss down for heavily padded chunks (same bug fixed in
            # VLAAgent's masked-loss path).
            mask = (~is_pad).unsqueeze(-1).to(action_loss.dtype)
            valid_count = (mask.sum() * action_loss.shape[-1]).clamp_min(1.0)
            mse_loss = (action_loss * mask).sum() / valid_count

            return {'loss': mse_loss}

        # Inference time: deterministic transform, no augmentation.
        image = self.inference_transform(image)
        a_hat, _ = self.model(qpos, image)
        return a_hat

    def configure_optimizers(self):
        """Return the optimizer built alongside the model."""
        return self.optimizer
|
||||||
@@ -3,7 +3,7 @@
|
|||||||
<body name="box" pos="0.2 1.0 0.47">
|
<body name="box" pos="0.2 1.0 0.47">
|
||||||
<joint name="red_box_joint" type="free" frictionloss="0.01" />
|
<joint name="red_box_joint" type="free" frictionloss="0.01" />
|
||||||
<inertial pos="0 0 0" mass="0.05" diaginertia="0.002 0.002 0.002" />
|
<inertial pos="0 0 0" mass="0.05" diaginertia="0.002 0.002 0.002" />
|
||||||
<geom contype="1" conaffinity="1" condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0 0 0" size="0.02 0.02 0.02" type="box" name="red_box" rgba="1 0 0 1" />
|
<geom contype="1" conaffinity="1" condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0 0 0" size="0.018 0.018 0.02" type="box" name="red_box" rgba="1 0 0 1" />
|
||||||
</body>
|
</body>
|
||||||
</worldbody>
|
</worldbody>
|
||||||
</mujoco>
|
</mujoco>
|
||||||
|
|||||||
@@ -8,6 +8,6 @@
|
|||||||
</body>
|
</body>
|
||||||
<camera name="top" pos="0.0 1.0 2.0" fovy="44" mode="targetbody" target="table"/>
|
<camera name="top" pos="0.0 1.0 2.0" fovy="44" mode="targetbody" target="table"/>
|
||||||
<camera name="angle" pos="0.0 0.0 2.0" fovy="37" mode="targetbody" target="table"/>
|
<camera name="angle" pos="0.0 0.0 2.0" fovy="37" mode="targetbody" target="table"/>
|
||||||
<camera name="front" pos="0 0 0.9" fovy="55" mode="fixed" quat="0.7071 0.7071 0 0"/>
|
<camera name="front" pos="0 0 0.8" fovy="65" mode="fixed" quat="0.7071 0.7071 0 0"/>
|
||||||
</worldbody>
|
</worldbody>
|
||||||
</mujoco>
|
</mujoco>
|
||||||
|
|||||||
@@ -104,8 +104,8 @@ class TestPickAndTransferPolicy(PolicyBase):
|
|||||||
{"t": 1, "xyz": init_mocap_pose_right[:3], "quat": init_mocap_pose_right[3:], "gripper": -100}, # sleep
|
{"t": 1, "xyz": init_mocap_pose_right[:3], "quat": init_mocap_pose_right[3:], "gripper": -100}, # sleep
|
||||||
{"t": 75, "xyz": np.array([(0.8+box_xyz[0])*0.5,(1.0+box_xyz[1])*0.5,init_mocap_pose_right[2]]), "quat": gripper_approach_quat.elements, "gripper": 100},
|
{"t": 75, "xyz": np.array([(0.8+box_xyz[0])*0.5,(1.0+box_xyz[1])*0.5,init_mocap_pose_right[2]]), "quat": gripper_approach_quat.elements, "gripper": 100},
|
||||||
{"t": 225, "xyz": box_xyz + np.array([0, 0, 0.3]), "quat": gripper_pick_quat.elements, "gripper": 100}, # approach the cube
|
{"t": 225, "xyz": box_xyz + np.array([0, 0, 0.3]), "quat": gripper_pick_quat.elements, "gripper": 100}, # approach the cube
|
||||||
{"t": 275, "xyz": box_xyz + np.array([0, 0, 0.12]), "quat": gripper_pick_quat.elements, "gripper": 100}, # go down
|
{"t": 275, "xyz": box_xyz + np.array([0, 0, 0.11]), "quat": gripper_pick_quat.elements, "gripper": 100}, # go down
|
||||||
{"t": 280, "xyz": box_xyz + np.array([0, 0, 0.12]), "quat": gripper_pick_quat.elements, "gripper": -100}, # close gripper
|
{"t": 280, "xyz": box_xyz + np.array([0, 0, 0.11]), "quat": gripper_pick_quat.elements, "gripper": -100}, # close gripper
|
||||||
{"t": 450, "xyz": init_mocap_pose_right[:3], "quat": init_mocap_pose_right[3:], "gripper": -100},# approach wait position
|
{"t": 450, "xyz": init_mocap_pose_right[:3], "quat": init_mocap_pose_right[3:], "gripper": -100},# approach wait position
|
||||||
{"t": 500, "xyz": meet_xyz + np.array([0.1, 0, 0.0]), "quat": meet_right_quat.elements, "gripper": -100},# approach meet position
|
{"t": 500, "xyz": meet_xyz + np.array([0.1, 0, 0.0]), "quat": meet_right_quat.elements, "gripper": -100},# approach meet position
|
||||||
{"t": 510, "xyz": meet_xyz + np.array([0.1, 0, 0.0]), "quat": meet_right_quat.elements, "gripper": 100}, # open gripper
|
{"t": 510, "xyz": meet_xyz + np.array([0.1, 0, 0.0]), "quat": meet_right_quat.elements, "gripper": 100}, # open gripper
|
||||||
@@ -116,8 +116,8 @@ class TestPickAndTransferPolicy(PolicyBase):
|
|||||||
self.left_trajectory = [
|
self.left_trajectory = [
|
||||||
{"t": 1, "xyz": init_mocap_pose_left[:3], "quat": init_mocap_pose_left[3:], "gripper": -100},# sleep
|
{"t": 1, "xyz": init_mocap_pose_left[:3], "quat": init_mocap_pose_left[3:], "gripper": -100},# sleep
|
||||||
{"t": 250, "xyz": meet_xyz + np.array([-0.5, 0, 0.0]), "quat": meet_left_quat.elements, "gripper": 100}, # approach meet position
|
{"t": 250, "xyz": meet_xyz + np.array([-0.5, 0, 0.0]), "quat": meet_left_quat.elements, "gripper": 100}, # approach meet position
|
||||||
{"t": 500, "xyz": meet_xyz + np.array([-0.15, 0, 0.0]), "quat": meet_left_quat.elements, "gripper": 100}, # move to meet position
|
{"t": 500, "xyz": meet_xyz + np.array([-0.14, 0, 0.0]), "quat": meet_left_quat.elements, "gripper": 100}, # move to meet position
|
||||||
{"t": 505, "xyz": meet_xyz + np.array([-0.15, 0, 0.0]), "quat": meet_left_quat.elements, "gripper": -100}, # close gripper
|
{"t": 505, "xyz": meet_xyz + np.array([-0.14, 0, 0.0]), "quat": meet_left_quat.elements, "gripper": -100}, # close gripper
|
||||||
{"t": 675, "xyz": meet_xyz + np.array([-0.3, 0, 0.0]), "quat": meet_left_quat.elements, "gripper": -100}, # move left
|
{"t": 675, "xyz": meet_xyz + np.array([-0.3, 0, 0.0]), "quat": meet_left_quat.elements, "gripper": -100}, # move left
|
||||||
{"t": 700, "xyz": meet_xyz + np.array([-0.3, 0, 0.0]), "quat": meet_left_quat.elements, "gripper": -100}, # stay
|
{"t": 700, "xyz": meet_xyz + np.array([-0.3, 0, 0.0]), "quat": meet_left_quat.elements, "gripper": -100}, # stay
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -32,6 +32,12 @@ def main():
|
|||||||
|
|
||||||
env = make_sim_env(task_name)
|
env = make_sim_env(task_name)
|
||||||
policy = TestPickAndTransferPolicy(inject_noise)
|
policy = TestPickAndTransferPolicy(inject_noise)
|
||||||
|
|
||||||
|
# 等待osmesa完全启动后再开始收集数据
|
||||||
|
print("等待osmesa线程启动...")
|
||||||
|
time.sleep(60)
|
||||||
|
print("osmesa已就绪,开始收集数据...")
|
||||||
|
|
||||||
for episode_idx in range(num_episodes):
|
for episode_idx in range(num_episodes):
|
||||||
obs = []
|
obs = []
|
||||||
reward_ee = []
|
reward_ee = []
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import json
|
|||||||
import pickle
|
import pickle
|
||||||
import hydra
|
import hydra
|
||||||
import torch
|
import torch
|
||||||
|
import re
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from omegaconf import DictConfig, OmegaConf
|
from omegaconf import DictConfig, OmegaConf
|
||||||
from torch.utils.data import DataLoader, random_split
|
from torch.utils.data import DataLoader, random_split
|
||||||
@@ -44,6 +45,35 @@ def recursive_to_device(data, device):
|
|||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_resume_checkpoint(resume_ckpt, checkpoint_dir):
|
||||||
|
"""
|
||||||
|
解析恢复训练用的 checkpoint 路径。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
resume_ckpt: 配置中的 resume_ckpt,支持路径或 "auto"
|
||||||
|
checkpoint_dir: 默认检查点目录
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path 或 None
|
||||||
|
"""
|
||||||
|
if resume_ckpt is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if str(resume_ckpt).lower() != "auto":
|
||||||
|
return Path(resume_ckpt)
|
||||||
|
|
||||||
|
pattern = re.compile(r"vla_model_step_(\d+)\.pt$")
|
||||||
|
candidates = []
|
||||||
|
for ckpt_path in checkpoint_dir.glob("vla_model_step_*.pt"):
|
||||||
|
match = pattern.search(ckpt_path.name)
|
||||||
|
if match:
|
||||||
|
candidates.append((int(match.group(1)), ckpt_path))
|
||||||
|
|
||||||
|
if not candidates:
|
||||||
|
return None
|
||||||
|
return max(candidates, key=lambda x: x[0])[1]
|
||||||
|
|
||||||
|
|
||||||
def get_lr_schedule_with_warmup(optimizer, warmup_steps, max_steps, scheduler_type='cosine', min_lr=0):
|
def get_lr_schedule_with_warmup(optimizer, warmup_steps, max_steps, scheduler_type='cosine', min_lr=0):
|
||||||
"""
|
"""
|
||||||
创建带预热的学习率调度器。
|
创建带预热的学习率调度器。
|
||||||
@@ -139,6 +169,7 @@ def main(cfg: DictConfig):
|
|||||||
shuffle=True,
|
shuffle=True,
|
||||||
num_workers=cfg.train.num_workers,
|
num_workers=cfg.train.num_workers,
|
||||||
pin_memory=(cfg.train.device != "cpu"),
|
pin_memory=(cfg.train.device != "cpu"),
|
||||||
|
persistent_workers=(cfg.train.num_workers > 0),
|
||||||
drop_last=True # 丢弃不完整批次以稳定训练
|
drop_last=True # 丢弃不完整批次以稳定训练
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -150,6 +181,7 @@ def main(cfg: DictConfig):
|
|||||||
shuffle=False,
|
shuffle=False,
|
||||||
num_workers=cfg.train.num_workers,
|
num_workers=cfg.train.num_workers,
|
||||||
pin_memory=(cfg.train.device != "cpu"),
|
pin_memory=(cfg.train.device != "cpu"),
|
||||||
|
persistent_workers=(cfg.train.num_workers > 0),
|
||||||
drop_last=False
|
drop_last=False
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -248,8 +280,11 @@ def main(cfg: DictConfig):
|
|||||||
# =========================================================================
|
# =========================================================================
|
||||||
# 4. 设置优化器与学习率调度器
|
# 4. 设置优化器与学习率调度器
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
optimizer = AdamW(agent.parameters(), lr=cfg.train.lr, weight_decay=1e-5)
|
weight_decay = float(cfg.train.get('weight_decay', 1e-5))
|
||||||
log.info(f"🔧 优化器: AdamW (学习率={cfg.train.lr})")
|
grad_clip = float(cfg.train.get('grad_clip', 1.0))
|
||||||
|
|
||||||
|
optimizer = AdamW(agent.parameters(), lr=cfg.train.lr, weight_decay=weight_decay)
|
||||||
|
log.info(f"🔧 优化器: AdamW (学习率={cfg.train.lr}, weight_decay={weight_decay})")
|
||||||
|
|
||||||
# 设置带预热的学習率调度器
|
# 设置带预热的学習率调度器
|
||||||
warmup_steps = int(cfg.train.get('warmup_steps', 500))
|
warmup_steps = int(cfg.train.get('warmup_steps', 500))
|
||||||
@@ -265,6 +300,52 @@ def main(cfg: DictConfig):
|
|||||||
)
|
)
|
||||||
log.info(f"📈 学习率调度器: {scheduler_type},{warmup_steps} 步预热 (最小学习率={min_lr})")
|
log.info(f"📈 学习率调度器: {scheduler_type},{warmup_steps} 步预热 (最小学习率={min_lr})")
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# 4.1 断点续训(恢复模型、优化器、调度器、步数)
|
||||||
|
# =========================================================================
|
||||||
|
start_step = 0
|
||||||
|
resume_loss = None
|
||||||
|
resume_best_loss = float('inf')
|
||||||
|
|
||||||
|
resume_ckpt = cfg.train.get('resume_ckpt', None)
|
||||||
|
resume_path = resolve_resume_checkpoint(resume_ckpt, checkpoint_dir)
|
||||||
|
if resume_ckpt is not None:
|
||||||
|
if pretrained_ckpt is not None:
|
||||||
|
log.warning("⚠️ [Resume] 同时设置了 pretrained_ckpt 与 resume_ckpt,将优先使用 resume_ckpt 进行断点续训")
|
||||||
|
if resume_path is None:
|
||||||
|
log.warning("⚠️ [Resume] 未找到可恢复的 checkpoint,将从头开始训练")
|
||||||
|
elif not resume_path.exists():
|
||||||
|
log.error(f"❌ [Resume] Checkpoint 文件不存在: {resume_path}")
|
||||||
|
log.warning("⚠️ 将从头开始训练")
|
||||||
|
else:
|
||||||
|
log.info(f"🔄 [Resume] 从 checkpoint 恢复训练: {resume_path}")
|
||||||
|
try:
|
||||||
|
checkpoint = torch.load(resume_path, map_location=cfg.train.device)
|
||||||
|
|
||||||
|
agent.load_state_dict(checkpoint['model_state_dict'], strict=True)
|
||||||
|
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
||||||
|
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
|
||||||
|
|
||||||
|
resume_step = int(checkpoint['step'])
|
||||||
|
start_step = resume_step + 1
|
||||||
|
|
||||||
|
loaded_loss = checkpoint.get('loss', None)
|
||||||
|
loaded_val_loss = checkpoint.get('val_loss', None)
|
||||||
|
resume_loss = float(loaded_loss) if loaded_loss is not None else None
|
||||||
|
if loaded_val_loss is not None:
|
||||||
|
resume_best_loss = float(loaded_val_loss)
|
||||||
|
elif loaded_loss is not None:
|
||||||
|
resume_best_loss = float(loaded_loss)
|
||||||
|
|
||||||
|
log.info(f"✅ [Resume] 恢复成功: 上次步骤={resume_step}, 本次从步骤 {start_step} 开始")
|
||||||
|
log.info(f"📈 [Resume] 当前学习率: {optimizer.param_groups[0]['lr']:.2e}")
|
||||||
|
except Exception as e:
|
||||||
|
log.error(f"❌ [Resume] 恢复失败: {e}")
|
||||||
|
log.warning("⚠️ 将从头开始训练")
|
||||||
|
start_step = 0
|
||||||
|
resume_loss = None
|
||||||
|
resume_best_loss = float('inf')
|
||||||
|
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
# 5. 训练循环
|
# 5. 训练循环
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
@@ -311,9 +392,15 @@ def main(cfg: DictConfig):
|
|||||||
return total_loss / max(num_batches, 1)
|
return total_loss / max(num_batches, 1)
|
||||||
|
|
||||||
data_iter = iter(train_loader)
|
data_iter = iter(train_loader)
|
||||||
pbar = tqdm(range(cfg.train.max_steps), desc="训练中", ncols=100)
|
pbar = tqdm(range(start_step, cfg.train.max_steps), desc="训练中", ncols=100)
|
||||||
|
|
||||||
best_loss = float('inf')
|
best_loss = resume_best_loss
|
||||||
|
last_loss = resume_loss
|
||||||
|
|
||||||
|
if start_step >= cfg.train.max_steps:
|
||||||
|
log.warning(
|
||||||
|
f"⚠️ [Resume] start_step={start_step} 已达到/超过 max_steps={cfg.train.max_steps},跳过训练循环"
|
||||||
|
)
|
||||||
|
|
||||||
for step in pbar:
|
for step in pbar:
|
||||||
try:
|
try:
|
||||||
@@ -346,6 +433,8 @@ def main(cfg: DictConfig):
|
|||||||
log.error(f"❌ 步骤 {step} 前向传播失败: {e}")
|
log.error(f"❌ 步骤 {step} 前向传播失败: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
last_loss = loss.item()
|
||||||
|
|
||||||
# =====================================================================
|
# =====================================================================
|
||||||
# 反向传播与优化
|
# 反向传播与优化
|
||||||
# =====================================================================
|
# =====================================================================
|
||||||
@@ -353,7 +442,7 @@ def main(cfg: DictConfig):
|
|||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
# 梯度裁剪以稳定训练
|
# 梯度裁剪以稳定训练
|
||||||
torch.nn.utils.clip_grad_norm_(agent.parameters(), max_norm=1.0)
|
torch.nn.utils.clip_grad_norm_(agent.parameters(), max_norm=grad_clip)
|
||||||
|
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
scheduler.step()
|
scheduler.step()
|
||||||
@@ -422,15 +511,21 @@ def main(cfg: DictConfig):
|
|||||||
'model_state_dict': agent.state_dict(),
|
'model_state_dict': agent.state_dict(),
|
||||||
'optimizer_state_dict': optimizer.state_dict(),
|
'optimizer_state_dict': optimizer.state_dict(),
|
||||||
'scheduler_state_dict': scheduler.state_dict(),
|
'scheduler_state_dict': scheduler.state_dict(),
|
||||||
'loss': loss.item(),
|
'loss': last_loss,
|
||||||
'dataset_stats': agent_stats, # 保存agent的统计信息
|
'dataset_stats': agent_stats, # 保存agent的统计信息
|
||||||
'current_lr': optimizer.param_groups[0]['lr'],
|
'current_lr': optimizer.param_groups[0]['lr'],
|
||||||
}, final_model_path)
|
}, final_model_path)
|
||||||
log.info(f"💾 最终模型已保存: {final_model_path}")
|
log.info(f"💾 最终模型已保存: {final_model_path}")
|
||||||
|
|
||||||
log.info("✅ 训练成功完成!")
|
log.info("✅ 训练成功完成!")
|
||||||
log.info(f"📊 最终损失: {loss.item():.4f}")
|
if last_loss is not None:
|
||||||
log.info(f"📊 最佳损失: {best_loss:.4f}")
|
log.info(f"📊 最终损失: {last_loss:.4f}")
|
||||||
|
else:
|
||||||
|
log.info("📊 最终损失: N/A(未执行训练步)")
|
||||||
|
if best_loss != float('inf'):
|
||||||
|
log.info(f"📊 最佳损失: {best_loss:.4f}")
|
||||||
|
else:
|
||||||
|
log.info("📊 最佳损失: N/A(无有效验证/训练损失)")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -230,7 +230,8 @@ class DualDianaMed(MujocoEnv):
|
|||||||
img_renderer.update_scene(self.mj_data,camera="front")
|
img_renderer.update_scene(self.mj_data,camera="front")
|
||||||
self.front = img_renderer.render()
|
self.front = img_renderer.render()
|
||||||
self.front = self.front[:, :, ::-1]
|
self.front = self.front[:, :, ::-1]
|
||||||
cv2.imshow('Cam view', self.cam_view)
|
if self.cam_view is not None:
|
||||||
|
cv2.imshow('Cam view', self.cam_view)
|
||||||
cv2.waitKey(1)
|
cv2.waitKey(1)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -72,6 +72,10 @@ class DualDianaMed_Pos_Ctrl(DualDianaMed):
|
|||||||
self.mj_data.joint('red_box_joint').qpos[5] = 0.0
|
self.mj_data.joint('red_box_joint').qpos[5] = 0.0
|
||||||
self.mj_data.joint('red_box_joint').qpos[6] = 0.0
|
self.mj_data.joint('red_box_joint').qpos[6] = 0.0
|
||||||
super().reset()
|
super().reset()
|
||||||
|
self.top = None
|
||||||
|
self.angle = None
|
||||||
|
self.r_vis = None
|
||||||
|
self.front = None
|
||||||
self.cam_flage = True
|
self.cam_flage = True
|
||||||
t=0
|
t=0
|
||||||
while self.cam_flage:
|
while self.cam_flage:
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ class VLAAgent(nn.Module):
|
|||||||
dataset_stats=None, # 数据集统计信息,用于归一化
|
dataset_stats=None, # 数据集统计信息,用于归一化
|
||||||
normalization_type='min_max', # 归一化类型: 'gaussian' 或 'min_max'
|
normalization_type='min_max', # 归一化类型: 'gaussian' 或 'min_max'
|
||||||
num_action_steps=8, # 每次推理实际执行多少步动作
|
num_action_steps=8, # 每次推理实际执行多少步动作
|
||||||
|
head_type='unet', # Policy head类型: 'unet' 或 'transformer'
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# 保存参数
|
# 保存参数
|
||||||
@@ -37,6 +38,7 @@ class VLAAgent(nn.Module):
|
|||||||
self.num_cams = num_cams
|
self.num_cams = num_cams
|
||||||
self.num_action_steps = num_action_steps
|
self.num_action_steps = num_action_steps
|
||||||
self.inference_steps = inference_steps
|
self.inference_steps = inference_steps
|
||||||
|
self.head_type = head_type # 'unet' 或 'transformer'
|
||||||
|
|
||||||
|
|
||||||
# 归一化模块 - 统一训练和推理的归一化逻辑
|
# 归一化模块 - 统一训练和推理的归一化逻辑
|
||||||
@@ -47,10 +49,15 @@ class VLAAgent(nn.Module):
|
|||||||
|
|
||||||
self.vision_encoder = vision_backbone
|
self.vision_encoder = vision_backbone
|
||||||
single_cam_feat_dim = self.vision_encoder.output_dim
|
single_cam_feat_dim = self.vision_encoder.output_dim
|
||||||
|
# global_cond_dim: 展平后的总维度(用于UNet)
|
||||||
total_vision_dim = single_cam_feat_dim * num_cams * obs_horizon
|
total_vision_dim = single_cam_feat_dim * num_cams * obs_horizon
|
||||||
total_prop_dim = obs_dim * obs_horizon
|
total_prop_dim = obs_dim * obs_horizon
|
||||||
self.global_cond_dim = total_vision_dim + total_prop_dim
|
self.global_cond_dim = total_vision_dim + total_prop_dim
|
||||||
|
|
||||||
|
# per_step_cond_dim: 每步的条件维度(用于Transformer)
|
||||||
|
# 注意:这里不乘以obs_horizon,因为Transformer的输入是序列形式
|
||||||
|
self.per_step_cond_dim = single_cam_feat_dim * num_cams + obs_dim
|
||||||
|
|
||||||
self.noise_scheduler = DDPMScheduler(
|
self.noise_scheduler = DDPMScheduler(
|
||||||
num_train_timesteps=diffusion_steps,
|
num_train_timesteps=diffusion_steps,
|
||||||
beta_schedule='squaredcos_cap_v2', # 机器人任务常用的 schedule
|
beta_schedule='squaredcos_cap_v2', # 机器人任务常用的 schedule
|
||||||
@@ -66,11 +73,27 @@ class VLAAgent(nn.Module):
|
|||||||
prediction_type='epsilon'
|
prediction_type='epsilon'
|
||||||
)
|
)
|
||||||
|
|
||||||
self.noise_pred_net = head(
|
# 根据head类型初始化不同的参数
|
||||||
input_dim=action_dim,
|
if head_type == 'transformer':
|
||||||
# input_dim = action_dim + obs_dim, # 备选:包含观测维度
|
# 如果head已经是nn.Module实例,直接使用;否则需要初始化
|
||||||
global_cond_dim=self.global_cond_dim
|
if isinstance(head, nn.Module):
|
||||||
)
|
# 已经是实例化的模块(测试时直接传入<E4BCA0><E585A5>
|
||||||
|
self.noise_pred_net = head
|
||||||
|
else:
|
||||||
|
# Hydra部分初始化的对象,调用时传入参数
|
||||||
|
self.noise_pred_net = head(
|
||||||
|
input_dim=action_dim,
|
||||||
|
output_dim=action_dim,
|
||||||
|
horizon=pred_horizon,
|
||||||
|
n_obs_steps=obs_horizon,
|
||||||
|
cond_dim=self.per_step_cond_dim # 每步的条件维度
|
||||||
|
)
|
||||||
|
else: # 'unet' (default)
|
||||||
|
# UNet接口: input_dim, global_cond_dim
|
||||||
|
self.noise_pred_net = head(
|
||||||
|
input_dim=action_dim,
|
||||||
|
global_cond_dim=self.global_cond_dim
|
||||||
|
)
|
||||||
|
|
||||||
self.state_encoder = state_encoder
|
self.state_encoder = state_encoder
|
||||||
self.action_encoder = action_encoder
|
self.action_encoder = action_encoder
|
||||||
@@ -78,6 +101,22 @@ class VLAAgent(nn.Module):
|
|||||||
# 初始化队列(用于在线推理)
|
# 初始化队列(用于在线推理)
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
|
def _get_model_device(self) -> torch.device:
|
||||||
|
"""获取模型当前所在设备。"""
|
||||||
|
return next(self.parameters()).device
|
||||||
|
|
||||||
|
def _move_to_device(self, data, device: torch.device):
|
||||||
|
"""递归地将张量数据移动到指定设备。"""
|
||||||
|
if torch.is_tensor(data):
|
||||||
|
return data.to(device)
|
||||||
|
if isinstance(data, dict):
|
||||||
|
return {k: self._move_to_device(v, device) for k, v in data.items()}
|
||||||
|
if isinstance(data, list):
|
||||||
|
return [self._move_to_device(v, device) for v in data]
|
||||||
|
if isinstance(data, tuple):
|
||||||
|
return tuple(self._move_to_device(v, device) for v in data)
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
# ==========================
|
# ==========================
|
||||||
# 训练阶段 (Training)
|
# 训练阶段 (Training)
|
||||||
@@ -124,13 +163,22 @@ class VLAAgent(nn.Module):
|
|||||||
global_cond = torch.cat([visual_features, state_features], dim=-1)
|
global_cond = torch.cat([visual_features, state_features], dim=-1)
|
||||||
global_cond = global_cond.flatten(start_dim=1)
|
global_cond = global_cond.flatten(start_dim=1)
|
||||||
|
|
||||||
|
# 5. 网络预测噪声(根据head类型选择接口)
|
||||||
# 5. 网络预测噪声
|
if self.head_type == 'transformer':
|
||||||
pred_noise = self.noise_pred_net(
|
# Transformer需要序列格式的条件: (B, obs_horizon, cond_dim_per_step)
|
||||||
sample=noisy_actions,
|
# 将展平的global_cond reshape回序列格式
|
||||||
timestep=timesteps,
|
cond = global_cond.reshape(B, self.obs_horizon, self.per_step_cond_dim)
|
||||||
global_cond=global_cond
|
pred_noise = self.noise_pred_net(
|
||||||
)
|
sample=noisy_actions,
|
||||||
|
timestep=timesteps,
|
||||||
|
cond=cond
|
||||||
|
)
|
||||||
|
else: # 'unet'
|
||||||
|
pred_noise = self.noise_pred_net(
|
||||||
|
sample=noisy_actions,
|
||||||
|
timestep=timesteps,
|
||||||
|
global_cond=global_cond
|
||||||
|
)
|
||||||
|
|
||||||
# 6. 计算 Loss (MSE),支持 padding mask
|
# 6. 计算 Loss (MSE),支持 padding mask
|
||||||
loss = nn.functional.mse_loss(pred_noise, noise, reduction='none')
|
loss = nn.functional.mse_loss(pred_noise, noise, reduction='none')
|
||||||
@@ -138,8 +186,9 @@ class VLAAgent(nn.Module):
|
|||||||
# 如果提供了 action_is_pad,对padding位置进行mask
|
# 如果提供了 action_is_pad,对padding位置进行mask
|
||||||
if action_is_pad is not None:
|
if action_is_pad is not None:
|
||||||
# action_is_pad: (B, pred_horizon),扩展到 (B, pred_horizon, action_dim)
|
# action_is_pad: (B, pred_horizon),扩展到 (B, pred_horizon, action_dim)
|
||||||
mask = ~action_is_pad.unsqueeze(-1) # True表示有效数据
|
mask = (~action_is_pad).unsqueeze(-1).to(loss.dtype) # 1.0表示有效数据
|
||||||
loss = (loss * mask).sum() / mask.sum() # 只对有效位置计算平均
|
valid_count = mask.sum() * loss.shape[-1]
|
||||||
|
loss = (loss * mask).sum() / valid_count.clamp_min(1.0)
|
||||||
else:
|
else:
|
||||||
loss = loss.mean()
|
loss = loss.mean()
|
||||||
|
|
||||||
@@ -230,33 +279,10 @@ class VLAAgent(nn.Module):
|
|||||||
Returns:
|
Returns:
|
||||||
action: (action_dim,) 单个动作
|
action: (action_dim,) 单个动作
|
||||||
"""
|
"""
|
||||||
# 检测设备并确保所有组件在同一设备上
|
# 使用模型当前设备作为唯一真值,将输入移动到模型设备
|
||||||
# 尝试从观测中获取设备
|
# 避免根据CPU观测把模型错误搬回CPU。
|
||||||
device = None
|
device = self._get_model_device()
|
||||||
for v in observation.values():
|
observation = self._move_to_device(observation, device)
|
||||||
if isinstance(v, torch.Tensor):
|
|
||||||
device = v.device
|
|
||||||
break
|
|
||||||
|
|
||||||
if device is not None and self.normalization.enabled:
|
|
||||||
# 确保归一化参数在同一设备上
|
|
||||||
# 根据归一化类型获取正确的属性
|
|
||||||
if self.normalization.normalization_type == 'gaussian':
|
|
||||||
norm_device = self.normalization.qpos_mean.device
|
|
||||||
else: # min_max
|
|
||||||
norm_device = self.normalization.qpos_min.device
|
|
||||||
|
|
||||||
if device != norm_device:
|
|
||||||
self.normalization.to(device)
|
|
||||||
# 同时确保其他模块也在正确设备
|
|
||||||
self.vision_encoder.to(device)
|
|
||||||
self.state_encoder.to(device)
|
|
||||||
self.action_encoder.to(device)
|
|
||||||
self.noise_pred_net.to(device)
|
|
||||||
|
|
||||||
# 将所有 observation 移到正确设备
|
|
||||||
observation = {k: v.to(device) if isinstance(v, torch.Tensor) else v
|
|
||||||
for k, v in observation.items()}
|
|
||||||
|
|
||||||
# 将新观测添加到队列
|
# 将新观测添加到队列
|
||||||
self._populate_queues(observation)
|
self._populate_queues(observation)
|
||||||
@@ -323,6 +349,16 @@ class VLAAgent(nn.Module):
|
|||||||
visual_features = self.vision_encoder(images)
|
visual_features = self.vision_encoder(images)
|
||||||
state_features = self.state_encoder(proprioception)
|
state_features = self.state_encoder(proprioception)
|
||||||
|
|
||||||
|
# 拼接条件(只计算一次)
|
||||||
|
# visual_features: (B, obs_horizon, vision_dim)
|
||||||
|
# state_features: (B, obs_horizon, obs_dim)
|
||||||
|
global_cond = torch.cat([visual_features, state_features], dim=-1)
|
||||||
|
global_cond_flat = global_cond.flatten(start_dim=1)
|
||||||
|
if self.head_type == 'transformer':
|
||||||
|
cond = global_cond.reshape(B, self.obs_horizon, self.per_step_cond_dim)
|
||||||
|
else:
|
||||||
|
cond = None
|
||||||
|
|
||||||
# 2. 初始化纯高斯噪声动作
|
# 2. 初始化纯高斯噪声动作
|
||||||
# 形状: (B, pred_horizon, action_dim)
|
# 形状: (B, pred_horizon, action_dim)
|
||||||
device = visual_features.device
|
device = visual_features.device
|
||||||
@@ -336,19 +372,19 @@ class VLAAgent(nn.Module):
|
|||||||
for t in self.infer_scheduler.timesteps:
|
for t in self.infer_scheduler.timesteps:
|
||||||
model_input = current_actions
|
model_input = current_actions
|
||||||
|
|
||||||
# 拼接全局条件并展平
|
# 预测噪声(根据head类型选择接口)
|
||||||
# visual_features: (B, obs_horizon, vision_dim)
|
if self.head_type == 'transformer':
|
||||||
# state_features: (B, obs_horizon, obs_dim)
|
noise_pred = self.noise_pred_net(
|
||||||
# 拼接后展平为 (B, obs_horizon * (vision_dim + obs_dim))
|
sample=model_input,
|
||||||
global_cond = torch.cat([visual_features, state_features], dim=-1)
|
timestep=t,
|
||||||
global_cond = global_cond.flatten(start_dim=1)
|
cond=cond
|
||||||
|
)
|
||||||
# 预测噪声
|
else: # 'unet'
|
||||||
noise_pred = self.noise_pred_net(
|
noise_pred = self.noise_pred_net(
|
||||||
sample=model_input,
|
sample=model_input,
|
||||||
timestep=t,
|
timestep=t,
|
||||||
global_cond=global_cond
|
global_cond=global_cond_flat
|
||||||
)
|
)
|
||||||
|
|
||||||
# 移除噪声,更新 current_actions
|
# 移除噪声,更新 current_actions
|
||||||
current_actions = self.infer_scheduler.step(
|
current_actions = self.infer_scheduler.step(
|
||||||
|
|||||||
217
roboimi/vla/agent_gr00t_dit.py
Normal file
217
roboimi/vla/agent_gr00t_dit.py
Normal file
@@ -0,0 +1,217 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from collections import deque
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
||||||
|
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
|
||||||
|
|
||||||
|
from roboimi.vla.models.normalization import NormalizationModule
|
||||||
|
|
||||||
|
|
||||||
|
class VLAAgentGr00tDiT(nn.Module):
|
||||||
|
"""
|
||||||
|
VLA Agent variant that swaps Transformer1D head with gr00t DiT head.
|
||||||
|
Other components (backbone/encoders/scheduler/queue logic) stay aligned
|
||||||
|
with the existing VLAAgent implementation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vision_backbone,
|
||||||
|
state_encoder,
|
||||||
|
action_encoder,
|
||||||
|
head,
|
||||||
|
action_dim,
|
||||||
|
obs_dim,
|
||||||
|
pred_horizon=16,
|
||||||
|
obs_horizon=4,
|
||||||
|
diffusion_steps=100,
|
||||||
|
inference_steps=10,
|
||||||
|
num_cams=3,
|
||||||
|
dataset_stats=None,
|
||||||
|
normalization_type="min_max",
|
||||||
|
num_action_steps=8,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.action_dim = action_dim
|
||||||
|
self.obs_dim = obs_dim
|
||||||
|
self.pred_horizon = pred_horizon
|
||||||
|
self.obs_horizon = obs_horizon
|
||||||
|
self.num_cams = num_cams
|
||||||
|
self.num_action_steps = num_action_steps
|
||||||
|
self.inference_steps = inference_steps
|
||||||
|
|
||||||
|
self.normalization = NormalizationModule(
|
||||||
|
stats=dataset_stats,
|
||||||
|
normalization_type=normalization_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.vision_encoder = vision_backbone
|
||||||
|
single_cam_feat_dim = self.vision_encoder.output_dim
|
||||||
|
self.per_step_cond_dim = single_cam_feat_dim * num_cams + obs_dim
|
||||||
|
|
||||||
|
self.noise_scheduler = DDPMScheduler(
|
||||||
|
num_train_timesteps=diffusion_steps,
|
||||||
|
beta_schedule="squaredcos_cap_v2",
|
||||||
|
clip_sample=True,
|
||||||
|
prediction_type="epsilon",
|
||||||
|
)
|
||||||
|
self.infer_scheduler = DDIMScheduler(
|
||||||
|
num_train_timesteps=diffusion_steps,
|
||||||
|
beta_schedule="squaredcos_cap_v2",
|
||||||
|
clip_sample=True,
|
||||||
|
prediction_type="epsilon",
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(head, nn.Module):
|
||||||
|
self.noise_pred_net = head
|
||||||
|
else:
|
||||||
|
self.noise_pred_net = head(
|
||||||
|
input_dim=action_dim,
|
||||||
|
output_dim=action_dim,
|
||||||
|
horizon=pred_horizon,
|
||||||
|
n_obs_steps=obs_horizon,
|
||||||
|
cond_dim=self.per_step_cond_dim,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.state_encoder = state_encoder
|
||||||
|
self.action_encoder = action_encoder
|
||||||
|
self.reset()
|
||||||
|
|
||||||
|
def _get_model_device(self) -> torch.device:
|
||||||
|
return next(self.parameters()).device
|
||||||
|
|
||||||
|
def _move_to_device(self, data, device: torch.device):
|
||||||
|
if torch.is_tensor(data):
|
||||||
|
return data.to(device)
|
||||||
|
if isinstance(data, dict):
|
||||||
|
return {k: self._move_to_device(v, device) for k, v in data.items()}
|
||||||
|
if isinstance(data, list):
|
||||||
|
return [self._move_to_device(v, device) for v in data]
|
||||||
|
if isinstance(data, tuple):
|
||||||
|
return tuple(self._move_to_device(v, device) for v in data)
|
||||||
|
return data
|
||||||
|
|
||||||
|
def _build_cond(self, images: Dict[str, torch.Tensor], states: torch.Tensor) -> torch.Tensor:
|
||||||
|
visual_features = self.vision_encoder(images)
|
||||||
|
state_features = self.state_encoder(states)
|
||||||
|
return torch.cat([visual_features, state_features], dim=-1)
|
||||||
|
|
||||||
|
def compute_loss(self, batch):
|
||||||
|
actions, states, images = batch["action"], batch["qpos"], batch["images"]
|
||||||
|
action_is_pad = batch.get("action_is_pad", None)
|
||||||
|
bsz = actions.shape[0]
|
||||||
|
|
||||||
|
states = self.normalization.normalize_qpos(states)
|
||||||
|
actions = self.normalization.normalize_action(actions)
|
||||||
|
|
||||||
|
action_features = self.action_encoder(actions)
|
||||||
|
cond = self._build_cond(images, states)
|
||||||
|
|
||||||
|
noise = torch.randn_like(action_features)
|
||||||
|
timesteps = torch.randint(
|
||||||
|
0,
|
||||||
|
self.noise_scheduler.config.num_train_timesteps,
|
||||||
|
(bsz,),
|
||||||
|
device=action_features.device,
|
||||||
|
).long()
|
||||||
|
noisy_actions = self.noise_scheduler.add_noise(action_features, noise, timesteps)
|
||||||
|
|
||||||
|
pred_noise = self.noise_pred_net(
|
||||||
|
sample=noisy_actions,
|
||||||
|
timestep=timesteps,
|
||||||
|
cond=cond,
|
||||||
|
)
|
||||||
|
loss = nn.functional.mse_loss(pred_noise, noise, reduction="none")
|
||||||
|
|
||||||
|
if action_is_pad is not None:
|
||||||
|
mask = (~action_is_pad).unsqueeze(-1).to(loss.dtype)
|
||||||
|
valid_count = mask.sum() * loss.shape[-1]
|
||||||
|
loss = (loss * mask).sum() / valid_count.clamp_min(1.0)
|
||||||
|
else:
|
||||||
|
loss = loss.mean()
|
||||||
|
|
||||||
|
return loss
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self._queues = {
|
||||||
|
"qpos": deque(maxlen=self.obs_horizon),
|
||||||
|
"images": deque(maxlen=self.obs_horizon),
|
||||||
|
"action": deque(maxlen=self.pred_horizon - self.obs_horizon + 1),
|
||||||
|
}
|
||||||
|
|
||||||
|
def _populate_queues(self, observation: Dict[str, torch.Tensor]) -> None:
|
||||||
|
if "qpos" in observation:
|
||||||
|
self._queues["qpos"].append(observation["qpos"].clone())
|
||||||
|
if "images" in observation:
|
||||||
|
self._queues["images"].append({k: v.clone() for k, v in observation["images"].items()})
|
||||||
|
|
||||||
|
def _prepare_observation_batch(self) -> Dict[str, torch.Tensor]:
|
||||||
|
qpos_list = list(self._queues["qpos"])
|
||||||
|
if len(qpos_list) == 0:
|
||||||
|
raise ValueError("observation queue is empty.")
|
||||||
|
while len(qpos_list) < self.obs_horizon:
|
||||||
|
qpos_list.append(qpos_list[-1])
|
||||||
|
batch_qpos = torch.stack(qpos_list, dim=0).unsqueeze(0)
|
||||||
|
|
||||||
|
images_list = list(self._queues["images"])
|
||||||
|
if len(images_list) == 0:
|
||||||
|
raise ValueError("image queue is empty.")
|
||||||
|
while len(images_list) < self.obs_horizon:
|
||||||
|
images_list.append(images_list[-1])
|
||||||
|
|
||||||
|
batch_images = {}
|
||||||
|
for cam_name in images_list[0].keys():
|
||||||
|
batch_images[cam_name] = torch.stack(
|
||||||
|
[img[cam_name] for img in images_list], dim=0
|
||||||
|
).unsqueeze(0)
|
||||||
|
|
||||||
|
return {"qpos": batch_qpos, "images": batch_images}
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def select_action(self, observation: Dict[str, torch.Tensor]) -> torch.Tensor:
|
||||||
|
device = self._get_model_device()
|
||||||
|
observation = self._move_to_device(observation, device)
|
||||||
|
self._populate_queues(observation)
|
||||||
|
|
||||||
|
if len(self._queues["action"]) == 0:
|
||||||
|
batch = self._prepare_observation_batch()
|
||||||
|
actions = self.predict_action_chunk(batch)
|
||||||
|
start = self.obs_horizon - 1
|
||||||
|
end = start + self.num_action_steps
|
||||||
|
executable_actions = actions[:, start:end]
|
||||||
|
for i in range(executable_actions.shape[1]):
|
||||||
|
self._queues["action"].append(executable_actions[:, i].squeeze(0))
|
||||||
|
|
||||||
|
return self._queues["action"].popleft()
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def predict_action_chunk(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
|
||||||
|
return self.predict_action(batch["images"], batch["qpos"])
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def predict_action(self, images, proprioception):
|
||||||
|
bsz = proprioception.shape[0]
|
||||||
|
proprioception = self.normalization.normalize_qpos(proprioception)
|
||||||
|
cond = self._build_cond(images, proprioception)
|
||||||
|
|
||||||
|
device = cond.device
|
||||||
|
current_actions = torch.randn((bsz, self.pred_horizon, self.action_dim), device=device)
|
||||||
|
self.infer_scheduler.set_timesteps(self.inference_steps)
|
||||||
|
|
||||||
|
for t in self.infer_scheduler.timesteps:
|
||||||
|
noise_pred = self.noise_pred_net(
|
||||||
|
sample=current_actions,
|
||||||
|
timestep=t,
|
||||||
|
cond=cond,
|
||||||
|
)
|
||||||
|
current_actions = self.infer_scheduler.step(
|
||||||
|
noise_pred, t, current_actions
|
||||||
|
).prev_sample
|
||||||
|
|
||||||
|
return self.normalization.denormalize_action(current_actions)
|
||||||
|
|
||||||
|
def get_normalization_stats(self):
|
||||||
|
return self.normalization.get_stats()
|
||||||
|
|
||||||
@@ -25,7 +25,7 @@ normalization_type: "min_max" # "min_max" or "gaussian"
|
|||||||
# ====================
|
# ====================
|
||||||
pred_horizon: 16 # 预测未来多少步动作
|
pred_horizon: 16 # 预测未来多少步动作
|
||||||
obs_horizon: 2 # 使用多少步历史观测
|
obs_horizon: 2 # 使用多少步历史观测
|
||||||
num_action_steps: 16 # 每次推理实际执行多少步动作(应 <= pred_horizon - obs_horizon + 1)
|
num_action_steps: 8 # 每次推理实际执行多少步动作(应 <= pred_horizon - obs_horizon + 1)
|
||||||
|
|
||||||
# ====================
|
# ====================
|
||||||
# 相机配置
|
# 相机配置
|
||||||
|
|||||||
37
roboimi/vla/conf/agent/resnet_gr00t_dit.yaml
Normal file
37
roboimi/vla/conf/agent/resnet_gr00t_dit.yaml
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
# @package agent
|
||||||
|
defaults:
|
||||||
|
- /backbone@vision_backbone: resnet_diffusion
|
||||||
|
- /modules@state_encoder: identity_state_encoder
|
||||||
|
- /modules@action_encoder: identity_action_encoder
|
||||||
|
- /head: gr00t_dit1d
|
||||||
|
- _self_
|
||||||
|
|
||||||
|
_target_: roboimi.vla.agent_gr00t_dit.VLAAgentGr00tDiT
|
||||||
|
|
||||||
|
# Model dimensions
|
||||||
|
action_dim: 16
|
||||||
|
obs_dim: 16
|
||||||
|
|
||||||
|
# Normalization
|
||||||
|
normalization_type: "min_max"
|
||||||
|
|
||||||
|
# Horizons
|
||||||
|
pred_horizon: 16
|
||||||
|
obs_horizon: 2
|
||||||
|
num_action_steps: 8
|
||||||
|
|
||||||
|
# Cameras
|
||||||
|
num_cams: 3
|
||||||
|
|
||||||
|
# Diffusion
|
||||||
|
diffusion_steps: 100
|
||||||
|
inference_steps: 10
|
||||||
|
|
||||||
|
# Head overrides
|
||||||
|
head:
|
||||||
|
input_dim: ${agent.action_dim}
|
||||||
|
output_dim: ${agent.action_dim}
|
||||||
|
horizon: ${agent.pred_horizon}
|
||||||
|
n_obs_steps: ${agent.obs_horizon}
|
||||||
|
cond_dim: 208
|
||||||
|
|
||||||
54
roboimi/vla/conf/agent/resnet_transformer.yaml
Normal file
54
roboimi/vla/conf/agent/resnet_transformer.yaml
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
# @package agent
|
||||||
|
defaults:
|
||||||
|
- /backbone@vision_backbone: resnet_diffusion
|
||||||
|
- /modules@state_encoder: identity_state_encoder
|
||||||
|
- /modules@action_encoder: identity_action_encoder
|
||||||
|
- /head: transformer1d
|
||||||
|
- _self_
|
||||||
|
|
||||||
|
_target_: roboimi.vla.agent.VLAAgent
|
||||||
|
|
||||||
|
# ====================
|
||||||
|
# 模型维度配置
|
||||||
|
# ====================
|
||||||
|
action_dim: 16 # 动作维度(机器人关节数)
|
||||||
|
obs_dim: 16 # 本体感知维度(关节位置)
|
||||||
|
|
||||||
|
# ====================
|
||||||
|
# 归一化配置
|
||||||
|
# ====================
|
||||||
|
normalization_type: "min_max" # "min_max" or "gaussian"
|
||||||
|
|
||||||
|
# ====================
|
||||||
|
# 时间步配置
|
||||||
|
# ====================
|
||||||
|
pred_horizon: 16 # 预测未来多少步动作
|
||||||
|
obs_horizon: 2 # 使用多少步历史观测
|
||||||
|
num_action_steps: 8 # 每次推理实际执行多少步动作(应 <= pred_horizon - obs_horizon + 1)
|
||||||
|
|
||||||
|
# ====================
|
||||||
|
# 相机配置
|
||||||
|
# ====================
|
||||||
|
num_cams: 3 # 摄像头数量 (r_vis, top, front)
|
||||||
|
|
||||||
|
# ====================
|
||||||
|
# 扩散过程配置
|
||||||
|
# ====================
|
||||||
|
diffusion_steps: 100 # 扩散训练步数(DDPM)
|
||||||
|
inference_steps: 10 # 推理时的去噪步数(DDIM,<4D><EFBC8C>定为 10)
|
||||||
|
|
||||||
|
# ====================
|
||||||
|
# Head 类型标识(用于VLAAgent选择调用方式)
|
||||||
|
# ====================
|
||||||
|
head_type: "transformer" # "unet" 或 "transformer"
|
||||||
|
|
||||||
|
# Head 参数覆盖
|
||||||
|
head:
|
||||||
|
input_dim: ${agent.action_dim}
|
||||||
|
output_dim: ${agent.action_dim}
|
||||||
|
horizon: ${agent.pred_horizon}
|
||||||
|
n_obs_steps: ${agent.obs_horizon}
|
||||||
|
# Transformer的cond_dim是每步的维度
|
||||||
|
# ResNet18 + SpatialSoftmax(32 keypoints) = 64维/相机
|
||||||
|
# 计算方式:单相机特征(64) * 相机数(3) + obs_dim(16) = 208
|
||||||
|
cond_dim: 208
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
defaults:
|
defaults:
|
||||||
- agent: resnet_diffusion
|
- agent: resnet_transformer
|
||||||
- data: simpe_robot_dataset
|
- data: simpe_robot_dataset
|
||||||
- eval: eval
|
- eval: eval
|
||||||
- _self_
|
- _self_
|
||||||
@@ -10,7 +10,7 @@ defaults:
|
|||||||
train:
|
train:
|
||||||
# 基础训练参数
|
# 基础训练参数
|
||||||
batch_size: 8 # 批次大小
|
batch_size: 8 # 批次大小
|
||||||
lr: 1e-4 # 学习率
|
lr: 5e-5 # 学习率(Transformer建议更小)
|
||||||
max_steps: 100000 # 最大训练步数
|
max_steps: 100000 # 最大训练步数
|
||||||
device: "cuda" # 设备: "cuda" 或 "cpu"
|
device: "cuda" # 设备: "cuda" 或 "cpu"
|
||||||
|
|
||||||
@@ -24,7 +24,7 @@ train:
|
|||||||
save_freq: 2000 # 保存检查点频率(步数)
|
save_freq: 2000 # 保存检查点频率(步数)
|
||||||
|
|
||||||
# 学习率调度器(带预热)
|
# 学习率调度器(带预热)
|
||||||
warmup_steps: 500 # 预热步数
|
warmup_steps: 2000 # 预热步数(Transformer建议更长)
|
||||||
scheduler_type: "cosine" # 预热后的调度器: "constant" 或 "cosine"
|
scheduler_type: "cosine" # 预热后的调度器: "constant" 或 "cosine"
|
||||||
min_lr: 1e-6 # 最小学习率(用于余弦退火)
|
min_lr: 1e-6 # 最小学习率(用于余弦退火)
|
||||||
|
|
||||||
|
|||||||
22
roboimi/vla/conf/head/gr00t_dit1d.yaml
Normal file
22
roboimi/vla/conf/head/gr00t_dit1d.yaml
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
_target_: roboimi.vla.models.heads.gr00t_dit1d.Gr00tDiT1D
|
||||||
|
_partial_: true
|
||||||
|
|
||||||
|
# DiT architecture
|
||||||
|
n_layer: 6
|
||||||
|
n_head: 8
|
||||||
|
n_emb: 256
|
||||||
|
hidden_dim: 256
|
||||||
|
mlp_ratio: 4
|
||||||
|
dropout: 0.1
|
||||||
|
|
||||||
|
# Positional embeddings
|
||||||
|
add_action_pos_emb: true
|
||||||
|
add_cond_pos_emb: true
|
||||||
|
|
||||||
|
# Supplied by agent interpolation:
|
||||||
|
# - input_dim
|
||||||
|
# - output_dim
|
||||||
|
# - horizon
|
||||||
|
# - n_obs_steps
|
||||||
|
# - cond_dim
|
||||||
|
|
||||||
29
roboimi/vla/conf/head/transformer1d.yaml
Normal file
29
roboimi/vla/conf/head/transformer1d.yaml
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
# Transformer-based Diffusion Policy Head
|
||||||
|
_target_: roboimi.vla.models.heads.transformer1d.Transformer1D
|
||||||
|
_partial_: true
|
||||||
|
|
||||||
|
# ====================
|
||||||
|
# Transformer 架构配置
|
||||||
|
# ====================
|
||||||
|
n_layer: 4 # Transformer层数(先用小模型提高收敛稳定性)
|
||||||
|
n_head: 4 # 注意力头数
|
||||||
|
n_emb: 128 # 嵌入维度
|
||||||
|
p_drop_emb: 0.05 # Embedding dropout
|
||||||
|
p_drop_attn: 0.05 # Attention dropout
|
||||||
|
|
||||||
|
# ====================
|
||||||
|
# 条件配置
|
||||||
|
# ====================
|
||||||
|
causal_attn: false # 是否使用因果注意力(自回归生成)
|
||||||
|
obs_as_cond: true # 观测作为条件(由cond_dim > 0决定)
|
||||||
|
n_cond_layers: 1 # 条件编码器层数(1层先做稳定融合)
|
||||||
|
|
||||||
|
# ====================
|
||||||
|
# 注意事项
|
||||||
|
# ====================
|
||||||
|
# 以下参数将在agent配置中通过interpolation提供:
|
||||||
|
# - input_dim: ${agent.action_dim}
|
||||||
|
# - output_dim: ${agent.action_dim}
|
||||||
|
# - horizon: ${agent.pred_horizon}
|
||||||
|
# - n_obs_steps: ${agent.obs_horizon}
|
||||||
|
# - cond_dim: 通过agent中的global_cond_dim计算
|
||||||
@@ -3,6 +3,7 @@ import h5py
|
|||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
from typing import List, Dict, Union
|
from typing import List, Dict, Union
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from collections import OrderedDict
|
||||||
|
|
||||||
|
|
||||||
class SimpleRobotDataset(Dataset):
|
class SimpleRobotDataset(Dataset):
|
||||||
@@ -21,6 +22,7 @@ class SimpleRobotDataset(Dataset):
|
|||||||
obs_horizon: int = 2,
|
obs_horizon: int = 2,
|
||||||
pred_horizon: int = 8,
|
pred_horizon: int = 8,
|
||||||
camera_names: List[str] = None,
|
camera_names: List[str] = None,
|
||||||
|
max_open_files: int = 64,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@@ -28,6 +30,7 @@ class SimpleRobotDataset(Dataset):
|
|||||||
obs_horizon: 观察过去多少帧
|
obs_horizon: 观察过去多少帧
|
||||||
pred_horizon: 预测未来多少帧动作
|
pred_horizon: 预测未来多少帧动作
|
||||||
camera_names: 相机名称列表,如 ["r_vis", "top", "front"]
|
camera_names: 相机名称列表,如 ["r_vis", "top", "front"]
|
||||||
|
max_open_files: 每个 worker 最多缓存的 HDF5 文件句柄数
|
||||||
|
|
||||||
HDF5 文件格式:
|
HDF5 文件格式:
|
||||||
- action: [T, action_dim]
|
- action: [T, action_dim]
|
||||||
@@ -37,6 +40,8 @@ class SimpleRobotDataset(Dataset):
|
|||||||
self.obs_horizon = obs_horizon
|
self.obs_horizon = obs_horizon
|
||||||
self.pred_horizon = pred_horizon
|
self.pred_horizon = pred_horizon
|
||||||
self.camera_names = camera_names or []
|
self.camera_names = camera_names or []
|
||||||
|
self.max_open_files = max(1, int(max_open_files))
|
||||||
|
self._file_cache: "OrderedDict[str, h5py.File]" = OrderedDict()
|
||||||
|
|
||||||
self.dataset_dir = Path(dataset_dir)
|
self.dataset_dir = Path(dataset_dir)
|
||||||
if not self.dataset_dir.exists():
|
if not self.dataset_dir.exists():
|
||||||
@@ -69,29 +74,60 @@ class SimpleRobotDataset(Dataset):
|
|||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.frame_meta)
|
return len(self.frame_meta)
|
||||||
|
|
||||||
|
def _close_all_files(self) -> None:
|
||||||
|
"""关闭当前 worker 内缓存的所有 HDF5 文件句柄。"""
|
||||||
|
for f in self._file_cache.values():
|
||||||
|
try:
|
||||||
|
f.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
self._file_cache.clear()
|
||||||
|
|
||||||
|
def _get_h5_file(self, hdf5_path: Union[str, Path]) -> h5py.File:
|
||||||
|
"""
|
||||||
|
获取 HDF5 文件句柄(worker 内 LRU 缓存)。
|
||||||
|
注意:缓存的是文件句柄,不是帧数据本身。
|
||||||
|
"""
|
||||||
|
key = str(hdf5_path)
|
||||||
|
if key in self._file_cache:
|
||||||
|
self._file_cache.move_to_end(key)
|
||||||
|
return self._file_cache[key]
|
||||||
|
|
||||||
|
# 超过上限时淘汰最久未使用的句柄
|
||||||
|
if len(self._file_cache) >= self.max_open_files:
|
||||||
|
_, old_file = self._file_cache.popitem(last=False)
|
||||||
|
try:
|
||||||
|
old_file.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
f = h5py.File(key, 'r')
|
||||||
|
self._file_cache[key] = f
|
||||||
|
return f
|
||||||
|
|
||||||
def _load_frame(self, idx: int) -> Dict:
|
def _load_frame(self, idx: int) -> Dict:
|
||||||
"""从 HDF5 文件懒加载单帧数据"""
|
"""从 HDF5 文件懒加载单帧数据"""
|
||||||
meta = self.frame_meta[idx]
|
meta = self.frame_meta[idx]
|
||||||
with h5py.File(meta["hdf5_path"], 'r') as f:
|
f = self._get_h5_file(meta["hdf5_path"])
|
||||||
frame = {
|
frame = {
|
||||||
"episode_index": meta["ep_idx"],
|
"episode_index": meta["ep_idx"],
|
||||||
"frame_index": meta["frame_idx"],
|
"frame_index": meta["frame_idx"],
|
||||||
"task": f.get('task', [b"unknown"])[0].decode() if 'task' in f else "unknown",
|
"task": f.get('task', [b"unknown"])[0].decode() if 'task' in f else "unknown",
|
||||||
"observation.state": torch.from_numpy(f['observations/qpos'][meta["frame_idx"]]).float(),
|
"observation.state": torch.from_numpy(f['observations/qpos'][meta["frame_idx"]]).float(),
|
||||||
"action": torch.from_numpy(f['action'][meta["frame_idx"]]).float(),
|
"action": torch.from_numpy(f['action'][meta["frame_idx"]]).float(),
|
||||||
}
|
}
|
||||||
|
|
||||||
# 加载图像数据: observations/images/{cam_name} -> observation.{cam_name}
|
# 加载图像数据: observations/images/{cam_name} -> observation.{cam_name}
|
||||||
for cam_name in self.camera_names:
|
for cam_name in self.camera_names:
|
||||||
h5_path = f'observations/images/{cam_name}'
|
h5_path = f'observations/images/{cam_name}'
|
||||||
if h5_path in f:
|
if h5_path in f:
|
||||||
img = f[h5_path][meta["frame_idx"]]
|
img = f[h5_path][meta["frame_idx"]]
|
||||||
# Resize图像到224x224(减少内存和I/O负担)
|
# Resize图像到224x224(减少内存和I/O负担)
|
||||||
import cv2
|
import cv2
|
||||||
img = cv2.resize(img, (224, 224), interpolation=cv2.INTER_LINEAR)
|
img = cv2.resize(img, (224, 224), interpolation=cv2.INTER_LINEAR)
|
||||||
# 转换为float并归一化到 [0, 1]
|
# 转换为float并归一化到 [0, 1]
|
||||||
img = torch.from_numpy(img).float() / 255.0
|
img = torch.from_numpy(img).float() / 255.0
|
||||||
frame[f"observation.{cam_name}"] = img.permute(2, 0, 1) # HWC -> CHW
|
frame[f"observation.{cam_name}"] = img.permute(2, 0, 1) # HWC -> CHW
|
||||||
|
|
||||||
return frame
|
return frame
|
||||||
|
|
||||||
@@ -201,3 +237,6 @@ class SimpleRobotDataset(Dataset):
|
|||||||
"dtype": str(sample[key].dtype),
|
"dtype": str(sample[key].dtype),
|
||||||
}
|
}
|
||||||
return info
|
return info
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
self._close_all_files()
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
# # Action Head models
|
# Action Head models
|
||||||
from .conditional_unet1d import ConditionalUnet1D
|
from .conditional_unet1d import ConditionalUnet1D
|
||||||
|
from .transformer1d import Transformer1D
|
||||||
|
|
||||||
__all__ = ["ConditionalUnet1D"]
|
__all__ = ["ConditionalUnet1D", "Transformer1D"]
|
||||||
|
|||||||
@@ -124,7 +124,6 @@ class ConditionalResidualBlock1D(nn.Module):
|
|||||||
class ConditionalUnet1D(nn.Module):
|
class ConditionalUnet1D(nn.Module):
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
input_dim,
|
input_dim,
|
||||||
local_cond_dim=None,
|
|
||||||
global_cond_dim=None,
|
global_cond_dim=None,
|
||||||
diffusion_step_embed_dim=256,
|
diffusion_step_embed_dim=256,
|
||||||
down_dims=[256,512,1024],
|
down_dims=[256,512,1024],
|
||||||
@@ -149,23 +148,6 @@ class ConditionalUnet1D(nn.Module):
|
|||||||
|
|
||||||
in_out = list(zip(all_dims[:-1], all_dims[1:]))
|
in_out = list(zip(all_dims[:-1], all_dims[1:]))
|
||||||
|
|
||||||
local_cond_encoder = None
|
|
||||||
if local_cond_dim is not None:
|
|
||||||
_, dim_out = in_out[0]
|
|
||||||
dim_in = local_cond_dim
|
|
||||||
local_cond_encoder = nn.ModuleList([
|
|
||||||
# down encoder
|
|
||||||
ConditionalResidualBlock1D(
|
|
||||||
dim_in, dim_out, cond_dim=cond_dim,
|
|
||||||
kernel_size=kernel_size, n_groups=n_groups,
|
|
||||||
cond_predict_scale=cond_predict_scale),
|
|
||||||
# up encoder
|
|
||||||
ConditionalResidualBlock1D(
|
|
||||||
dim_in, dim_out, cond_dim=cond_dim,
|
|
||||||
kernel_size=kernel_size, n_groups=n_groups,
|
|
||||||
cond_predict_scale=cond_predict_scale)
|
|
||||||
])
|
|
||||||
|
|
||||||
mid_dim = all_dims[-1]
|
mid_dim = all_dims[-1]
|
||||||
self.mid_modules = nn.ModuleList([
|
self.mid_modules = nn.ModuleList([
|
||||||
ConditionalResidualBlock1D(
|
ConditionalResidualBlock1D(
|
||||||
@@ -216,7 +198,6 @@ class ConditionalUnet1D(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.diffusion_step_encoder = diffusion_step_encoder
|
self.diffusion_step_encoder = diffusion_step_encoder
|
||||||
self.local_cond_encoder = local_cond_encoder
|
|
||||||
self.up_modules = up_modules
|
self.up_modules = up_modules
|
||||||
self.down_modules = down_modules
|
self.down_modules = down_modules
|
||||||
self.final_conv = final_conv
|
self.final_conv = final_conv
|
||||||
@@ -225,12 +206,11 @@ class ConditionalUnet1D(nn.Module):
|
|||||||
def forward(self,
|
def forward(self,
|
||||||
sample: torch.Tensor,
|
sample: torch.Tensor,
|
||||||
timestep: Union[torch.Tensor, float, int],
|
timestep: Union[torch.Tensor, float, int],
|
||||||
local_cond=None, global_cond=None,
|
global_cond=None,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
"""
|
"""
|
||||||
x: (B,T,input_dim)
|
x: (B,T,input_dim)
|
||||||
timestep: (B,) or int, diffusion step
|
timestep: (B,) or int, diffusion step
|
||||||
local_cond: (B,T,local_cond_dim)
|
|
||||||
global_cond: (B,global_cond_dim)
|
global_cond: (B,global_cond_dim)
|
||||||
output: (B,T,input_dim)
|
output: (B,T,input_dim)
|
||||||
"""
|
"""
|
||||||
@@ -253,22 +233,10 @@ class ConditionalUnet1D(nn.Module):
|
|||||||
global_feature, global_cond
|
global_feature, global_cond
|
||||||
], axis=-1)
|
], axis=-1)
|
||||||
|
|
||||||
# encode local features
|
|
||||||
h_local = list()
|
|
||||||
if local_cond is not None:
|
|
||||||
local_cond = einops.rearrange(local_cond, 'b h t -> b t h')
|
|
||||||
resnet, resnet2 = self.local_cond_encoder
|
|
||||||
x = resnet(local_cond, global_feature)
|
|
||||||
h_local.append(x)
|
|
||||||
x = resnet2(local_cond, global_feature)
|
|
||||||
h_local.append(x)
|
|
||||||
|
|
||||||
x = sample
|
x = sample
|
||||||
h = []
|
h = []
|
||||||
for idx, (resnet, resnet2, downsample) in enumerate(self.down_modules):
|
for idx, (resnet, resnet2, downsample) in enumerate(self.down_modules):
|
||||||
x = resnet(x, global_feature)
|
x = resnet(x, global_feature)
|
||||||
if idx == 0 and len(h_local) > 0:
|
|
||||||
x = x + h_local[0]
|
|
||||||
x = resnet2(x, global_feature)
|
x = resnet2(x, global_feature)
|
||||||
h.append(x)
|
h.append(x)
|
||||||
x = downsample(x)
|
x = downsample(x)
|
||||||
@@ -279,12 +247,6 @@ class ConditionalUnet1D(nn.Module):
|
|||||||
for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules):
|
for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules):
|
||||||
x = torch.cat((x, h.pop()), dim=1)
|
x = torch.cat((x, h.pop()), dim=1)
|
||||||
x = resnet(x, global_feature)
|
x = resnet(x, global_feature)
|
||||||
# The correct condition should be:
|
|
||||||
# if idx == (len(self.up_modules)-1) and len(h_local) > 0:
|
|
||||||
# However this change will break compatibility with published checkpoints.
|
|
||||||
# Therefore it is left as a comment.
|
|
||||||
if idx == len(self.up_modules) and len(h_local) > 0:
|
|
||||||
x = x + h_local[1]
|
|
||||||
x = resnet2(x, global_feature)
|
x = resnet2(x, global_feature)
|
||||||
x = upsample(x)
|
x = upsample(x)
|
||||||
|
|
||||||
|
|||||||
146
roboimi/vla/models/heads/gr00t_dit1d.py
Normal file
146
roboimi/vla/models/heads/gr00t_dit1d.py
Normal file
@@ -0,0 +1,146 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from typing import Optional, Union
|
||||||
|
from pathlib import Path
|
||||||
|
import importlib.util
|
||||||
|
|
||||||
|
|
||||||
|
def _load_gr00t_dit():
|
||||||
|
repo_root = Path(__file__).resolve().parents[4]
|
||||||
|
dit_path = repo_root / "gr00t" / "models" / "dit.py"
|
||||||
|
spec = importlib.util.spec_from_file_location("gr00t_dit_standalone", dit_path)
|
||||||
|
if spec is None or spec.loader is None:
|
||||||
|
raise ImportError(f"Unable to load DiT from {dit_path}")
|
||||||
|
module = importlib.util.module_from_spec(spec)
|
||||||
|
spec.loader.exec_module(module)
|
||||||
|
return module.DiT
|
||||||
|
|
||||||
|
|
||||||
|
DiT = _load_gr00t_dit()
|
||||||
|
|
||||||
|
|
||||||
|
class Gr00tDiT1D(nn.Module):
|
||||||
|
"""
|
||||||
|
Adapter that wraps gr00t DiT with the same call signature used by VLA heads.
|
||||||
|
|
||||||
|
Expected forward interface:
|
||||||
|
- sample: (B, T_action, input_dim)
|
||||||
|
- timestep: (B,) or scalar diffusion timestep
|
||||||
|
- cond: (B, T_obs, cond_dim)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
input_dim: int,
|
||||||
|
output_dim: int,
|
||||||
|
horizon: int,
|
||||||
|
n_obs_steps: int,
|
||||||
|
cond_dim: int,
|
||||||
|
n_layer: int = 8,
|
||||||
|
n_head: int = 8,
|
||||||
|
n_emb: int = 256,
|
||||||
|
hidden_dim: int = 256,
|
||||||
|
mlp_ratio: int = 4,
|
||||||
|
dropout: float = 0.1,
|
||||||
|
add_action_pos_emb: bool = True,
|
||||||
|
add_cond_pos_emb: bool = True,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
if cond_dim <= 0:
|
||||||
|
raise ValueError("Gr00tDiT1D requires cond_dim > 0.")
|
||||||
|
|
||||||
|
self.horizon = horizon
|
||||||
|
self.n_obs_steps = n_obs_steps
|
||||||
|
|
||||||
|
self.input_proj = nn.Linear(input_dim, n_emb)
|
||||||
|
self.cond_proj = nn.Linear(cond_dim, n_emb)
|
||||||
|
self.output_proj = nn.Linear(hidden_dim, output_dim)
|
||||||
|
|
||||||
|
self.action_pos_emb = (
|
||||||
|
nn.Parameter(torch.zeros(1, horizon, n_emb))
|
||||||
|
if add_action_pos_emb
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
self.cond_pos_emb = (
|
||||||
|
nn.Parameter(torch.zeros(1, n_obs_steps, n_emb))
|
||||||
|
if add_cond_pos_emb
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
args = SimpleNamespace(
|
||||||
|
embed_dim=n_emb,
|
||||||
|
nheads=n_head,
|
||||||
|
mlp_ratio=mlp_ratio,
|
||||||
|
dropout=dropout,
|
||||||
|
num_layers=n_layer,
|
||||||
|
hidden_dim=hidden_dim,
|
||||||
|
)
|
||||||
|
self.dit = DiT(args, cross_attention_dim=n_emb)
|
||||||
|
|
||||||
|
self._init_weights()
|
||||||
|
|
||||||
|
def _init_weights(self):
|
||||||
|
if self.action_pos_emb is not None:
|
||||||
|
nn.init.normal_(self.action_pos_emb, mean=0.0, std=0.02)
|
||||||
|
if self.cond_pos_emb is not None:
|
||||||
|
nn.init.normal_(self.cond_pos_emb, mean=0.0, std=0.02)
|
||||||
|
|
||||||
|
def _normalize_timesteps(
|
||||||
|
self,
|
||||||
|
timestep: Union[torch.Tensor, float, int],
|
||||||
|
batch_size: int,
|
||||||
|
device: torch.device,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
if not torch.is_tensor(timestep):
|
||||||
|
timesteps = torch.tensor([timestep], device=device)
|
||||||
|
else:
|
||||||
|
timesteps = timestep.to(device)
|
||||||
|
|
||||||
|
if timesteps.ndim == 0:
|
||||||
|
timesteps = timesteps[None]
|
||||||
|
if timesteps.shape[0] != batch_size:
|
||||||
|
timesteps = timesteps.expand(batch_size)
|
||||||
|
|
||||||
|
return timesteps.long()
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
sample: torch.Tensor,
|
||||||
|
timestep: Union[torch.Tensor, float, int],
|
||||||
|
cond: Optional[torch.Tensor] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
if cond is None:
|
||||||
|
raise ValueError("`cond` is required for Gr00tDiT1D forward.")
|
||||||
|
|
||||||
|
bsz, t_act, _ = sample.shape
|
||||||
|
if t_act > self.horizon:
|
||||||
|
raise ValueError(
|
||||||
|
f"sample length {t_act} exceeds configured horizon {self.horizon}"
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = self.input_proj(sample)
|
||||||
|
if self.action_pos_emb is not None:
|
||||||
|
hidden_states = hidden_states + self.action_pos_emb[:, :t_act, :]
|
||||||
|
|
||||||
|
encoder_hidden_states = self.cond_proj(cond)
|
||||||
|
if self.cond_pos_emb is not None:
|
||||||
|
t_obs = encoder_hidden_states.shape[1]
|
||||||
|
if t_obs > self.n_obs_steps:
|
||||||
|
raise ValueError(
|
||||||
|
f"cond length {t_obs} exceeds configured n_obs_steps {self.n_obs_steps}"
|
||||||
|
)
|
||||||
|
encoder_hidden_states = (
|
||||||
|
encoder_hidden_states + self.cond_pos_emb[:, :t_obs, :]
|
||||||
|
)
|
||||||
|
|
||||||
|
timesteps = self._normalize_timesteps(
|
||||||
|
timestep, batch_size=bsz, device=sample.device
|
||||||
|
)
|
||||||
|
dit_output = self.dit(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
timestep=timesteps,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
)
|
||||||
|
return self.output_proj(dit_output)
|
||||||
396
roboimi/vla/models/heads/transformer1d.py
Normal file
396
roboimi/vla/models/heads/transformer1d.py
Normal file
@@ -0,0 +1,396 @@
|
|||||||
|
"""
|
||||||
|
Transformer-based Diffusion Policy Head
|
||||||
|
|
||||||
|
使用Transformer架构(Encoder-Decoder)替代UNet进行噪声预测。
|
||||||
|
支持通过Cross-Attention注入全局条件(观测特征)。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import math
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
class SinusoidalPosEmb(nn.Module):
    """Sinusoidal positional encoding (used for diffusion timestep embedding)."""

    def __init__(self, dim: int):
        super().__init__()
        # Output embedding dimensionality; should be even so sin/cos halves match.
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a (B,) tensor of timesteps to (B, dim) sin/cos embeddings."""
        device = x.device
        half_dim = self.dim // 2
        # Geometric frequency ladder spanning 1 .. 1/10000.
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb


class Transformer1D(nn.Module):
    """
    Transformer-based 1D diffusion model.

    Encoder-decoder architecture:
    - Encoder: processes the conditioning (timestep token + observations)
    - Decoder: predicts noise via cross-attention into the encoder memory

    Args:
        input_dim: input action dimension
        output_dim: output action dimension
        horizon: prediction horizon length
        n_obs_steps: number of observation steps (defaults to ``horizon``)
        cond_dim: conditioning feature dimension; 0 disables obs conditioning
        n_layer: number of transformer (decoder) layers
        n_head: number of attention heads
        n_emb: embedding dimension
        p_drop_emb: embedding dropout probability
        p_drop_attn: attention dropout probability
        causal_attn: use causal (autoregressive) attention masks
        obs_as_cond: accepted for API compatibility but ignored — the
            effective value is always derived from ``cond_dim > 0``
        n_cond_layers: encoder layer count (0 means use an MLP encoder)
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        horizon: int,
        n_obs_steps: Optional[int] = None,
        cond_dim: int = 0,
        n_layer: int = 8,
        n_head: int = 8,
        n_emb: int = 256,
        p_drop_emb: float = 0.1,
        p_drop_attn: float = 0.1,
        causal_attn: bool = False,
        obs_as_cond: bool = False,
        n_cond_layers: int = 0
    ):
        super().__init__()

        # Compute sequence lengths.
        if n_obs_steps is None:
            n_obs_steps = horizon

        T = horizon
        T_cond = 1  # number of timestep tokens

        # Decide whether observations are used as conditioning.
        # NOTE: this intentionally overrides the `obs_as_cond` argument.
        obs_as_cond = cond_dim > 0
        if obs_as_cond:
            T_cond += n_obs_steps

        # Save configuration.
        self.T = T
        self.T_cond = T_cond
        self.horizon = horizon
        self.obs_as_cond = obs_as_cond
        self.input_dim = input_dim
        self.output_dim = output_dim

        # ==================== input embedding ====================
        self.input_emb = nn.Linear(input_dim, n_emb)
        self.pos_emb = nn.Parameter(torch.zeros(1, T, n_emb))
        self.drop = nn.Dropout(p_drop_emb)

        # ==================== condition encoding ====================
        # Timestep embedding.
        self.time_emb = SinusoidalPosEmb(n_emb)

        # Observation conditioning embedding (optional).
        self.cond_obs_emb = None
        if obs_as_cond:
            self.cond_obs_emb = nn.Linear(cond_dim, n_emb)

        # Positional encoding for the condition tokens.
        self.cond_pos_emb = None
        if T_cond > 0:
            self.cond_pos_emb = nn.Parameter(torch.zeros(1, T_cond, n_emb))

        # ==================== encoder ====================
        self.encoder = None
        self.encoder_only = False

        if T_cond > 0:
            if n_cond_layers > 0:
                # Transformer encoder over the condition tokens.
                encoder_layer = nn.TransformerEncoderLayer(
                    d_model=n_emb,
                    nhead=n_head,
                    dim_feedforward=4 * n_emb,
                    dropout=p_drop_attn,
                    activation='gelu',
                    batch_first=True,
                    norm_first=True  # pre-LN is more stable
                )
                self.encoder = nn.TransformerEncoder(
                    encoder_layer=encoder_layer,
                    num_layers=n_cond_layers
                )
            else:
                # Simple MLP encoder.
                self.encoder = nn.Sequential(
                    nn.Linear(n_emb, 4 * n_emb),
                    nn.Mish(),
                    nn.Linear(4 * n_emb, n_emb)
                )
        else:
            # Encoder-only (BERT-style) mode.
            # NOTE(review): unreachable with the current code since T_cond
            # starts at 1; kept for parity with the upstream implementation.
            self.encoder_only = True
            encoder_layer = nn.TransformerEncoderLayer(
                d_model=n_emb,
                nhead=n_head,
                dim_feedforward=4 * n_emb,
                dropout=p_drop_attn,
                activation='gelu',
                batch_first=True,
                norm_first=True
            )
            self.encoder = nn.TransformerEncoder(
                encoder_layer=encoder_layer,
                num_layers=n_layer
            )

        # ==================== attention masks ====================
        # BUGFIX: the masks used to be pre-assigned as plain attributes
        # (`self.mask = None`) *before* `register_buffer("mask", ...)`.
        # nn.Module.register_buffer raises KeyError("attribute 'mask' already
        # exists") when a non-buffer attribute of that name is present, so any
        # construction with causal_attn=True crashed. Register the buffer OR
        # assign None — never both.
        if causal_attn:
            # Causal mask: each position may only attend to itself and the left.
            sz = T
            mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
            mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
            self.register_buffer("mask", mask)
        else:
            self.mask = None

        if causal_attn and obs_as_cond:
            # Cross-attention mask: action step t may attend condition token s
            # only when t >= s - 1 (timestep token always visible).
            S = T_cond
            t, s = torch.meshgrid(
                torch.arange(T),
                torch.arange(S),
                indexing='ij'
            )
            mask = t >= (s - 1)
            mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
            self.register_buffer('memory_mask', mask)
        else:
            self.memory_mask = None

        # ==================== decoder ====================
        if not self.encoder_only:
            decoder_layer = nn.TransformerDecoderLayer(
                d_model=n_emb,
                nhead=n_head,
                dim_feedforward=4 * n_emb,
                dropout=p_drop_attn,
                activation='gelu',
                batch_first=True,
                norm_first=True
            )
            self.decoder = nn.TransformerDecoder(
                decoder_layer=decoder_layer,
                num_layers=n_layer
            )

        # ==================== output head ====================
        self.ln_f = nn.LayerNorm(n_emb)
        self.head = nn.Linear(n_emb, output_dim)

        # ==================== initialization ====================
        self.apply(self._init_weights)

        # Report the parameter count.
        total_params = sum(p.numel() for p in self.parameters())
        print(f"Transformer1D parameters: {total_params:,}")

    def _init_weights(self, module):
        """Initialize weights (normal(0, 0.02) for linears/embeddings)."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.MultiheadAttention):
            # Weight initialization for MultiheadAttention projections.
            for name in ['in_proj_weight', 'q_proj_weight', 'k_proj_weight', 'v_proj_weight']:
                weight = getattr(module, name, None)
                if weight is not None:
                    torch.nn.init.normal_(weight, mean=0.0, std=0.02)

            for name in ['in_proj_bias', 'bias_k', 'bias_v']:
                bias = getattr(module, name, None)
                if bias is not None:
                    torch.nn.init.zeros_(bias)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)
        elif isinstance(module, Transformer1D):
            # Positional-encoding initialization (runs once, on the root module).
            torch.nn.init.normal_(self.pos_emb, mean=0.0, std=0.02)
            if self.cond_pos_emb is not None:
                torch.nn.init.normal_(self.cond_pos_emb, mean=0.0, std=0.02)

    def forward(
        self,
        sample: torch.Tensor,
        timestep: torch.Tensor,
        cond: Optional[torch.Tensor] = None,
        **kwargs
    ):
        """
        Forward pass.

        Args:
            sample: (B, T, input_dim) input sequence (noised actions)
            timestep: (B,) diffusion timesteps (scalar/int accepted)
            cond: (B, T', cond_dim) conditioning sequence (observation features)

        Returns:
            (B, T, output_dim) predicted noise
        """
        # ==================== timestep handling ====================
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
        elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # Broadcast to the batch dimension.
        timesteps = timesteps.expand(sample.shape[0])
        time_emb = self.time_emb(timesteps).unsqueeze(1)  # (B, 1, n_emb)

        # ==================== input embedding ====================
        input_emb = self.input_emb(sample)  # (B, T, n_emb)

        # ==================== encoder-decoder mode ====================
        if not self.encoder_only:
            # --- Encoder: process the conditioning ---
            cond_embeddings = time_emb

            if self.obs_as_cond and cond is not None:
                # Append observation conditioning tokens.
                cond_obs_emb = self.cond_obs_emb(cond)  # (B, T_cond-1, n_emb)
                cond_embeddings = torch.cat([cond_embeddings, cond_obs_emb], dim=1)

            # Add positional encoding.
            tc = cond_embeddings.shape[1]
            pos_emb = self.cond_pos_emb[:, :tc, :]
            x = self.drop(cond_embeddings + pos_emb)

            # Run the encoder.
            memory = self.encoder(x)  # (B, T_cond, n_emb)

            # --- Decoder: predict noise ---
            # Add positional encoding to the input tokens.
            token_embeddings = input_emb
            t = token_embeddings.shape[1]
            pos_emb = self.pos_emb[:, :t, :]
            x = self.drop(token_embeddings + pos_emb)

            # Cross-attention: queries from the input, keys/values from memory.
            x = self.decoder(
                tgt=x,
                memory=memory,
                tgt_mask=self.mask,
                memory_mask=self.memory_mask
            )

        # ==================== encoder-only mode ====================
        else:
            # BERT-style: prepend the timestep as a special token.
            token_embeddings = torch.cat([time_emb, input_emb], dim=1)
            t = token_embeddings.shape[1]
            pos_emb = self.pos_emb[:, :t, :]
            x = self.drop(token_embeddings + pos_emb)

            x = self.encoder(src=x, mask=self.mask)
            x = x[:, 1:, :]  # drop the timestep token

        # ==================== output head ====================
        x = self.ln_f(x)
        x = self.head(x)  # (B, T, output_dim)

        return x
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# 便捷函数:创建Transformer1D模型
|
||||||
|
# ============================================================================
|
||||||
|
def create_transformer1d(
    input_dim: int,
    output_dim: int,
    horizon: int,
    n_obs_steps: int,
    cond_dim: int,
    n_layer: int = 8,
    n_head: int = 8,
    n_emb: int = 256,
    **kwargs
) -> Transformer1D:
    """Convenience constructor for :class:`Transformer1D`.

    Args:
        input_dim: input action dimension
        output_dim: output action dimension
        horizon: prediction horizon
        n_obs_steps: number of observation steps
        cond_dim: conditioning feature dimension
        n_layer: number of transformer layers
        n_head: number of attention heads
        n_emb: embedding dimension
        **kwargs: forwarded verbatim to the ``Transformer1D`` constructor

    Returns:
        A freshly constructed ``Transformer1D`` model.
    """
    return Transformer1D(
        input_dim=input_dim,
        output_dim=output_dim,
        horizon=horizon,
        n_obs_steps=n_obs_steps,
        cond_dim=cond_dim,
        n_layer=n_layer,
        n_head=n_head,
        n_emb=n_emb,
        **kwargs
    )
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Smoke test: build a small Transformer1D and run one forward pass.
    print("=" * 80)
    print("Testing Transformer1D")
    print("=" * 80)

    # Test configuration.
    B = 4
    T = 16
    action_dim = 16
    obs_horizon = 2
    cond_dim = 416  # vision + state feature dimension

    # Build the model.
    model = Transformer1D(
        input_dim=action_dim,
        output_dim=action_dim,
        horizon=T,
        n_obs_steps=obs_horizon,
        cond_dim=cond_dim,
        n_layer=4,
        n_head=8,
        n_emb=256,
        causal_attn=False
    )

    # Exercise the forward pass with random inputs.
    sample = torch.randn(B, T, action_dim)
    timestep = torch.randint(0, 100, (B,))
    cond = torch.randn(B, obs_horizon, cond_dim)

    output = model(sample, timestep, cond)

    print(f"\n输入:")
    print(f" sample: {sample.shape}")
    print(f" timestep: {timestep.shape}")
    print(f" cond: {cond.shape}")
    print(f"\n输出:")
    print(f" output: {output.shape}")
    print(f"\n✅ 测试通过!")
|
||||||
Reference in New Issue
Block a user