10 Commits

Author SHA1 Message Date
gouhanke
23088e5e33 feat: 架构引入DiT 2026-03-06 11:31:37 +08:00
gouhanke
ca1716c67f chore: 导入gr00t 2026-03-06 11:31:37 +08:00
JiajunLI
642d41dd8f feat: 添加resume机制 2026-03-06 11:19:30 +08:00
gouhanke
7d39933a5b feat: 缓存worker内的句柄 2026-03-04 10:49:41 +08:00
gouhanke
8bcad5844e fix: 修复VLA设备与损失计算逻辑,并优化Transformer默认训练参数 2026-03-03 17:56:12 +08:00
gouhanke
cdb887c9bf feat: 添加transformer头 2026-02-28 19:07:27 +08:00
gouhanke
abb4f501e3 chore: 删除unet里的local_cond(未使用) 2026-02-28 10:42:16 +08:00
gouhanke
1d33db0ef0 chore: 缩小物块的大小 2026-02-27 18:23:30 +08:00
gouhanke
f27e397f98 chore: 修改了采数时的一些参数 2026-02-26 17:09:40 +08:00
gouhanke
4e0add4e1d debug: 修复episode首帧图像不正确的问题;修复前2个episode帧重复的问题 2026-02-26 16:17:54 +08:00
28 changed files with 2106 additions and 139 deletions

125
gr00t/main.py Normal file
View File

@@ -0,0 +1,125 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
GR00T (diffusion-based DiT policy) model builder.
This module provides functions to build GR00T models and optimizers
from configuration dictionaries (typically from config.yaml's 'gr00t:' section).
"""
import argparse
from pathlib import Path
import numpy as np
import torch
from .models import build_gr00t_model
def get_args_parser():
"""
Create argument parser for GR00T model configuration.
All parameters can be overridden via args_override dictionary in
build_gr00t_model_and_optimizer(). This allows loading from config.yaml.
"""
parser = argparse.ArgumentParser('GR00T training and evaluation script', add_help=False)
# Training parameters
parser.add_argument('--lr', default=1e-5, type=float,
help='Learning rate for main parameters')
parser.add_argument('--lr_backbone', default=1e-5, type=float,
help='Learning rate for backbone parameters')
parser.add_argument('--weight_decay', default=1e-4, type=float,
help='Weight decay for optimizer')
# GR00T model architecture parameters
parser.add_argument('--embed_dim', default=1536, type=int,
help='Embedding dimension for transformer')
parser.add_argument('--hidden_dim', default=1024, type=int,
help='Hidden dimension for MLP layers')
parser.add_argument('--state_dim', default=16, type=int,
help='State (qpos) dimension')
parser.add_argument('--action_dim', default=16, type=int,
help='Action dimension')
parser.add_argument('--num_queries', default=16, type=int,
help='Number of action queries (chunk size)')
# DiT (Diffusion Transformer) parameters
parser.add_argument('--num_layers', default=16, type=int,
help='Number of transformer layers')
parser.add_argument('--nheads', default=32, type=int,
help='Number of attention heads')
parser.add_argument('--mlp_ratio', default=4, type=float,
help='MLP hidden dimension ratio')
parser.add_argument('--dropout', default=0.2, type=float,
help='Dropout rate')
# Backbone parameters
parser.add_argument('--backbone', default='dino_v2', type=str,
help='Backbone architecture (dino_v2, resnet18, resnet34)')
parser.add_argument('--position_embedding', default='sine', type=str,
choices=('sine', 'learned'),
help='Type of positional encoding')
# Camera configuration
parser.add_argument('--camera_names', default=[], nargs='+',
help='List of camera names for observations')
# Other parameters (not directly used but kept for compatibility)
parser.add_argument('--batch_size', default=15, type=int)
parser.add_argument('--epochs', default=20000, type=int)
parser.add_argument('--masks', action='store_true',
help='Use intermediate layer features')
parser.add_argument('--dilation', action='store_false',
help='Use dilated convolution in backbone')
return parser
def build_gr00t_model_and_optimizer(args_override):
"""
Build GR00T model and optimizer from config dictionary.
This function is designed to work with config.yaml loading:
1. Parse default arguments
2. Override with values from args_override (typically from config['gr00t'])
3. Build model and optimizer
Args:
args_override: Dictionary of config values, typically from config.yaml's 'gr00t:' section
Expected keys: embed_dim, hidden_dim, state_dim, action_dim,
num_queries, nheads, mlp_ratio, dropout, num_layers,
lr, lr_backbone, camera_names, backbone, etc.
Returns:
model: GR00T model on CUDA
optimizer: AdamW optimizer with separate learning rates for backbone and other params
"""
parser = argparse.ArgumentParser('GR00T training and evaluation script',
parents=[get_args_parser()])
args = parser.parse_args()
# Override with config values
for k, v in args_override.items():
setattr(args, k, v)
# Build model
model = build_gr00t_model(args)
model.cuda()
# Create parameter groups with different learning rates
param_dicts = [
{
"params": [p for n, p in model.named_parameters()
if "backbone" not in n and p.requires_grad]
},
{
"params": [p for n, p in model.named_parameters()
if "backbone" in n and p.requires_grad],
"lr": args.lr_backbone,
},
]
optimizer = torch.optim.AdamW(param_dicts,
lr=args.lr,
weight_decay=args.weight_decay)
return model, optimizer

3
gr00t/models/__init__.py Normal file
View File

@@ -0,0 +1,3 @@
from .gr00t import build_gr00t_model
__all__ = ['build_gr00t_model']

168
gr00t/models/backbone.py Normal file
View File

@@ -0,0 +1,168 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Backbone modules.
"""
from collections import OrderedDict
import torch
import torch.nn.functional as F
import torchvision
from torch import nn
from torchvision.models._utils import IntermediateLayerGetter
from typing import Dict, List
from util.misc import NestedTensor, is_main_process
from .position_encoding import build_position_encoding
class FrozenBatchNorm2d(torch.nn.Module):
"""
BatchNorm2d where the batch statistics and the affine parameters are fixed.
Copy-paste from torchvision.misc.ops with added eps before rqsrt,
without which any other policy_models than torchvision.policy_models.resnet[18,34,50,101]
produce nans.
"""
def __init__(self, n):
super(FrozenBatchNorm2d, self).__init__()
self.register_buffer("weight", torch.ones(n))
self.register_buffer("bias", torch.zeros(n))
self.register_buffer("running_mean", torch.zeros(n))
self.register_buffer("running_var", torch.ones(n))
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
num_batches_tracked_key = prefix + 'num_batches_tracked'
if num_batches_tracked_key in state_dict:
del state_dict[num_batches_tracked_key]
super(FrozenBatchNorm2d, self)._load_from_state_dict(
state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs)
def forward(self, x):
# move reshapes to the beginning
# to make it fuser-friendly
w = self.weight.reshape(1, -1, 1, 1)
b = self.bias.reshape(1, -1, 1, 1)
rv = self.running_var.reshape(1, -1, 1, 1)
rm = self.running_mean.reshape(1, -1, 1, 1)
eps = 1e-5
scale = w * (rv + eps).rsqrt()
bias = b - rm * scale
return x * scale + bias
class BackboneBase(nn.Module):
def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool):
super().__init__()
# for name, parameter in backbone.named_parameters(): # only train later layers # TODO do we want this?
# if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
# parameter.requires_grad_(False)
if return_interm_layers:
return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
else:
return_layers = {'layer4': "0"}
self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
self.num_channels = num_channels
def forward(self, tensor):
xs = self.body(tensor)
return xs
# out: Dict[str, NestedTensor] = {}
# for name, x in xs.items():
# m = tensor_list.mask
# assert m is not None
# mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
# out[name] = NestedTensor(x, mask)
# return out
class Backbone(BackboneBase):
"""ResNet backbone with frozen BatchNorm."""
def __init__(self, name: str,
train_backbone: bool,
return_interm_layers: bool,
dilation: bool):
backbone = getattr(torchvision.models, name)(
replace_stride_with_dilation=[False, False, dilation],
pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d) # pretrained # TODO do we want frozen batch_norm??
num_channels = 512 if name in ('resnet18', 'resnet34') else 2048
super().__init__(backbone, train_backbone, num_channels, return_interm_layers)
# class DINOv2BackBone(nn.Module):
# def __init__(self) -> None:
# super().__init__()
# self.body = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
# self.body.eval()
# self.num_channels = 384
# @torch.no_grad()
# def forward(self, tensor):
# xs = self.body.forward_features(tensor)["x_norm_patchtokens"]
# od = OrderedDict()
# od["0"] = xs.reshape(xs.shape[0], 22, 16, 384).permute(0, 3, 2, 1)
# return od
class DINOv2BackBone(nn.Module):
def __init__(self, return_interm_layers: bool = False) -> None:
super().__init__()
self.body = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
self.body.eval()
self.num_channels = 384
self.return_interm_layers = return_interm_layers
@torch.no_grad()
def forward(self, tensor):
features = self.body.forward_features(tensor)
if self.return_interm_layers:
layer1 = features["x_norm_patchtokens"]
layer2 = features["x_norm_patchtokens"]
layer3 = features["x_norm_patchtokens"]
layer4 = features["x_norm_patchtokens"]
od = OrderedDict()
od["0"] = layer1.reshape(layer1.shape[0], 22, 16, 384).permute(0, 3, 2, 1)
od["1"] = layer2.reshape(layer2.shape[0], 22, 16, 384).permute(0, 3, 2, 1)
od["2"] = layer3.reshape(layer3.shape[0], 22, 16, 384).permute(0, 3, 2, 1)
od["3"] = layer4.reshape(layer4.shape[0], 22, 16, 384).permute(0, 3, 2, 1)
return od
else:
xs = features["x_norm_patchtokens"]
od = OrderedDict()
od["0"] = xs.reshape(xs.shape[0], 22, 16, 384).permute(0, 3, 2, 1)
return od
class Joiner(nn.Sequential):
def __init__(self, backbone, position_embedding):
super().__init__(backbone, position_embedding)
def forward(self, tensor_list: NestedTensor):
xs = self[0](tensor_list)
out: List[NestedTensor] = []
pos = []
for name, x in xs.items():
out.append(x)
# position encoding
pos.append(self[1](x).to(x.dtype))
return out, pos
def build_backbone(args):
position_embedding = build_position_encoding(args)
train_backbone = args.lr_backbone > 0
return_interm_layers = args.masks
if args.backbone == 'dino_v2':
backbone = DINOv2BackBone()
else:
assert args.backbone in ['resnet18', 'resnet34']
backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation)
model = Joiner(backbone, position_embedding)
model.num_channels = backbone.num_channels
return model

142
gr00t/models/dit.py Normal file
View File

@@ -0,0 +1,142 @@
from typing import Optional
from diffusers import ConfigMixin, ModelMixin
from diffusers.configuration_utils import register_to_config
from diffusers.models.embeddings import SinusoidalPositionalEmbedding, TimestepEmbedding, Timesteps
import torch
from torch import nn
import torch.nn.functional as F
class TimestepEncoder(nn.Module):
def __init__(self, args):
super().__init__()
embedding_dim = args.embed_dim
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
def forward(self, timesteps):
dtype = next(self.parameters()).dtype
timesteps_proj = self.time_proj(timesteps).to(dtype)
timesteps_emb = self.timestep_embedder(timesteps_proj) # (N, D)
return timesteps_emb
class AdaLayerNorm(nn.Module):
def __init__(self, embedding_dim, norm_eps=1e-5, norm_elementwise_affine=False):
super().__init__()
output_dim = embedding_dim * 2
self.silu = nn.SiLU()
self.linear = nn.Linear(embedding_dim, output_dim)
self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine)
def forward(
self,
x: torch.Tensor,
temb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
temb = self.linear(self.silu(temb))
scale, shift = temb.chunk(2, dim=1)
x = self.norm(x) * (1 + scale[:, None]) + shift[:, None]
return x
class BasicTransformerBlock(nn.Module):
def __init__(self, args, crosss_attention_dim, use_self_attn=False):
super().__init__()
dim = args.embed_dim
num_heads = args.nheads
mlp_ratio = args.mlp_ratio
dropout = args.dropout
self.norm1 = AdaLayerNorm(dim)
if not use_self_attn:
self.attn = nn.MultiheadAttention(
embed_dim=dim,
num_heads=num_heads,
dropout=dropout,
kdim=crosss_attention_dim,
vdim=crosss_attention_dim,
batch_first=True,
)
else:
self.attn = nn.MultiheadAttention(
embed_dim=dim,
num_heads=num_heads,
dropout=dropout,
batch_first=True,
)
self.norm2 = nn.LayerNorm(dim, eps=1e-5, elementwise_affine=False)
self.mlp = nn.Sequential(
nn.Linear(dim, dim * mlp_ratio),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(dim * mlp_ratio, dim),
nn.Dropout(dropout)
)
def forward(self, hidden_states, temb, context=None):
norm_hidden_states = self.norm1(hidden_states, temb)
attn_output = self.attn(
norm_hidden_states,
context if context is not None else norm_hidden_states,
context if context is not None else norm_hidden_states,
)[0]
hidden_states = attn_output + hidden_states
norm_hidden_states = self.norm2(hidden_states)
ff_output = self.mlp(norm_hidden_states)
hidden_states = ff_output + hidden_states
return hidden_states
class DiT(nn.Module):
def __init__(self, args, cross_attention_dim):
super().__init__()
inner_dim = args.embed_dim
num_layers = args.num_layers
output_dim = args.hidden_dim
self.timestep_encoder = TimestepEncoder(args)
all_blocks = []
for idx in range(num_layers):
use_self_attn = idx % 2 == 1
if use_self_attn:
block = BasicTransformerBlock(args, crosss_attention_dim=None, use_self_attn=True)
else:
block = BasicTransformerBlock(args, crosss_attention_dim=cross_attention_dim, use_self_attn=False)
all_blocks.append(block)
self.transformer_blocks = nn.ModuleList(all_blocks)
self.norm_out = nn.LayerNorm(inner_dim, eps=1e-6, elementwise_affine=False)
self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
self.proj_out_2 = nn.Linear(inner_dim, output_dim)
def forward(self, hidden_states, timestep, encoder_hidden_states):
temb = self.timestep_encoder(timestep)
hidden_states = hidden_states.contiguous()
encoder_hidden_states = encoder_hidden_states.contiguous()
for idx, block in enumerate(self.transformer_blocks):
if idx % 2 == 1:
hidden_states = block(hidden_states, temb)
else:
hidden_states = block(hidden_states, temb, context=encoder_hidden_states)
conditioning = temb
shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
return self.proj_out_2(hidden_states)
def build_dit(args, cross_attention_dim):
return DiT(args, cross_attention_dim)

124
gr00t/models/gr00t.py Normal file
View File

@@ -0,0 +1,124 @@
from .modules import (
build_action_decoder,
build_action_encoder,
build_state_encoder,
build_time_sampler,
build_noise_scheduler,
)
from .backbone import build_backbone
from .dit import build_dit
import torch
import torch.nn as nn
import torch.nn.functional as F
class gr00t(nn.Module):
def __init__(
self,
backbones,
dit,
state_encoder,
action_encoder,
action_decoder,
time_sampler,
noise_scheduler,
num_queries,
camera_names,
):
super().__init__()
self.num_queries = num_queries
self.camera_names = camera_names
self.dit = dit
self.state_encoder = state_encoder
self.action_encoder = action_encoder
self.action_decoder = action_decoder
self.time_sampler = time_sampler
self.noise_scheduler = noise_scheduler
if backbones is not None:
self.backbones = nn.ModuleList(backbones)
else:
raise NotImplementedError
def forward(self, qpos, image, actions=None, is_pad=None):
is_training = actions is not None # train or val
bs, _ = qpos.shape
all_cam_features = []
for cam_id, cam_name in enumerate(self.camera_names):
# features, pos = self.backbones[0](image[:, cam_id]) # HARDCODED
features, pos = self.backbones[cam_id](image[:, cam_id])
features = features[0] # take the last layer feature
B, C, H, W = features.shape
features_seq = features.permute(0, 2, 3, 1).reshape(B, H * W, C)
all_cam_features.append(features_seq)
encoder_hidden_states = torch.cat(all_cam_features, dim=1)
state_features = self.state_encoder(qpos) # [B, 1, emb_dim]
if is_training:
# training logic
timesteps = self.time_sampler(bs, actions.device, actions.dtype)
noisy_actions, target_velocity = self.noise_scheduler.add_noise(
actions, timesteps
)
t_discretized = (timesteps[:, 0, 0] * 1000).long()
action_features = self.action_encoder(noisy_actions, t_discretized)
sa_embs = torch.cat((state_features, action_features), dim=1)
model_output = self.dit(sa_embs, t_discretized, encoder_hidden_states)
pred = self.action_decoder(model_output)
pred_actions = pred[:, -actions.shape[1] :]
action_loss = F.mse_loss(pred_actions, target_velocity, reduction='none')
return pred_actions, action_loss
else:
actions = torch.randn(bs, self.num_queries, qpos.shape[-1], device=qpos.device, dtype=qpos.dtype)
k = 5
dt = 1.0 / k
for t in range(k):
t_cont = t / float(k)
t_discretized = int(t_cont * 1000)
timesteps = torch.full((bs,), t_discretized, device=qpos.device, dtype=qpos.dtype)
action_features = self.action_encoder(actions, timesteps)
sa_embs = torch.cat((state_features, action_features), dim=1)
# Create tensor of shape [B] for DiT (consistent with training path)
model_output = self.dit(sa_embs, timesteps, encoder_hidden_states)
pred = self.action_decoder(model_output)
pred_velocity = pred[:, -self.num_queries :]
actions = actions + pred_velocity * dt
return actions, _
def build_gr00t_model(args):
state_dim = args.state_dim
action_dim = args.action_dim
backbones = []
for _ in args.camera_names:
backbone = build_backbone(args)
backbones.append(backbone)
cross_attention_dim = backbones[0].num_channels
dit = build_dit(args, cross_attention_dim)
state_encoder = build_state_encoder(args)
action_encoder = build_action_encoder(args)
action_decoder = build_action_decoder(args)
time_sampler = build_time_sampler(args)
noise_scheduler = build_noise_scheduler(args)
model = gr00t(
backbones,
dit,
state_encoder,
action_encoder,
action_decoder,
time_sampler,
noise_scheduler,
args.num_queries,
args.camera_names,
)
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("number of parameters: %.2fM" % (n_parameters/1e6,))
return model

179
gr00t/models/modules.py Normal file
View File

@@ -0,0 +1,179 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
# ActionEncoder
class SinusoidalPositionalEncoding(nn.Module):
def __init__(self, args):
super().__init__()
self.embed_dim = args.embed_dim
def forward(self, timesteps):
timesteps = timesteps.float()
B, T = timesteps.shape
device = timesteps.device
half_dim = self.embed_dim // 2
exponent = -torch.arange(half_dim, dtype=torch.float, device=device) * (
torch.log(torch.tensor(10000.0)) / half_dim
)
freqs = timesteps.unsqueeze(-1) * exponent.exp()
sin = torch.sin(freqs)
cos = torch.cos(freqs)
enc = torch.cat([sin, cos], dim=-1) # (B, T, w)
return enc
class ActionEncoder(nn.Module):
def __init__(self, args):
super().__init__()
action_dim = args.action_dim
embed_dim = args.embed_dim
self.W1 = nn.Linear(action_dim, embed_dim)
self.W2 = nn.Linear(2 * embed_dim, embed_dim)
self.W3 = nn.Linear(embed_dim, embed_dim)
self.pos_encoder = SinusoidalPositionalEncoding(args)
def forward(self, actions, timesteps):
B, T, _ = actions.shape
# 1) Expand each batch's single scalar time 'tau' across all T steps
# so that shape => (B, T)
# Handle different input shapes: (B,), (B, 1), (B, 1, 1)
# Reshape to (B,) then expand to (B, T)
# if timesteps.dim() == 3:
# # Shape (B, 1, 1) or (B, T, 1) -> (B,)
# timesteps = timesteps[:, 0, 0]
# elif timesteps.dim() == 2:
# # Shape (B, 1) or (B, T) -> take first element if needed
# if timesteps.shape[1] == 1:
# timesteps = timesteps[:, 0]
# # else: already (B, T), use as is
# elif timesteps.dim() != 1:
# raise ValueError(
# f"Expected `timesteps` to have shape (B,), (B, 1), or (B, 1, 1), got {timesteps.shape}"
# )
# Now timesteps should be (B,), expand to (B, T)
if timesteps.dim() == 1 and timesteps.shape[0] == B:
timesteps = timesteps.unsqueeze(1).expand(-1, T)
else:
raise ValueError(
"Expected `timesteps` to have shape (B,) so we can replicate across T."
)
# 2) Standard action MLP step for shape => (B, T, w)
a_emb = self.W1(actions)
# 3) Get the sinusoidal encoding (B, T, w)
tau_emb = self.pos_encoder(timesteps).to(dtype=a_emb.dtype)
# 4) Concat along last dim => (B, T, 2w), then W2 => (B, T, w), swish
x = torch.cat([a_emb, tau_emb], dim=-1)
x = F.silu(self.W2(x))
# 5) Finally W3 => (B, T, w)
x = self.W3(x)
return x
def build_action_encoder(args):
return ActionEncoder(args)
# StateEncoder
class StateEncoder(nn.Module):
def __init__(self, args):
super().__init__()
input_dim = args.state_dim
hidden_dim = args.hidden_dim
output_dim = args.embed_dim
self.mlp = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, output_dim),
)
def forward(self, states):
state_emb = self.mlp(states) # [B, emb_dim]
state_emb = state_emb.unsqueeze(1)
return state_emb # [B, 1, emb_dim]
def build_state_encoder(args):
return StateEncoder(args)
# ActionDecoder
class ActionDecoder(nn.Module):
def __init__(self,args):
super().__init__()
input_dim = args.hidden_dim
hidden_dim = args.hidden_dim
output_dim = args.action_dim
self.num_queries = args.num_queries
self.mlp = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, output_dim),
)
def forward(self, model_output):
pred_actions = self.mlp(model_output)
return pred_actions[:, -self.num_queries:]
def build_action_decoder(args):
return ActionDecoder(args)
# TimeSampler
class TimeSampler(nn.Module):
def __init__(self, noise_s = 0.999, noise_beta_alpha=1.5, noise_beta_beta=1.0):
super().__init__()
self.noise_s = noise_s
self.beta_dist = torch.distributions.Beta(noise_beta_alpha, noise_beta_beta)
def forward(self, batch_size, device, dtype):
sample = self.beta_dist.sample([batch_size]).to(device, dtype=dtype)
sample = (1 - sample) * self.noise_s
return sample[:, None, None]
def build_time_sampler(args):
return TimeSampler()
# NoiseScheduler
import torch
import torch.nn as nn
class FlowMatchingScheduler(nn.Module):
def __init__(self):
super().__init__()
# --- 训练逻辑:加噪并计算目标 ---
def add_noise(self, actions, timesteps):
noise = torch.randn_like(actions)
noisy_samples = actions * timesteps + noise * (1 - timesteps)
target_velocity = actions - noise
return noisy_samples, target_velocity
# --- 推理逻辑:欧拉步 (Euler Step) ---
def step(self, model_output, sample, dt):
prev_sample = sample + model_output * dt
return prev_sample
def build_noise_scheduler(args):
return FlowMatchingScheduler()

View File

@@ -0,0 +1,91 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Various positional encodings for the transformer.
"""
import math
import torch
from torch import nn
from util.misc import NestedTensor
class PositionEmbeddingSine(nn.Module):
"""
This is a more standard version of the position embedding, very similar to the one
used by the Attention is all you need paper, generalized to work on images.
"""
def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
super().__init__()
self.num_pos_feats = num_pos_feats
self.temperature = temperature
self.normalize = normalize
if scale is not None and normalize is False:
raise ValueError("normalize should be True if scale is passed")
if scale is None:
scale = 2 * math.pi
self.scale = scale
def forward(self, tensor):
x = tensor
# mask = tensor_list.mask
# assert mask is not None
# not_mask = ~mask
not_mask = torch.ones_like(x[0, [0]])
y_embed = not_mask.cumsum(1, dtype=torch.float32)
x_embed = not_mask.cumsum(2, dtype=torch.float32)
if self.normalize:
eps = 1e-6
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
pos_x = x_embed[:, :, :, None] / dim_t
pos_y = y_embed[:, :, :, None] / dim_t
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
return pos
class PositionEmbeddingLearned(nn.Module):
"""
Absolute pos embedding, learned.
"""
def __init__(self, num_pos_feats=256):
super().__init__()
self.row_embed = nn.Embedding(50, num_pos_feats)
self.col_embed = nn.Embedding(50, num_pos_feats)
self.reset_parameters()
def reset_parameters(self):
nn.init.uniform_(self.row_embed.weight)
nn.init.uniform_(self.col_embed.weight)
def forward(self, tensor_list: NestedTensor):
x = tensor_list.tensors
h, w = x.shape[-2:]
i = torch.arange(w, device=x.device)
j = torch.arange(h, device=x.device)
x_emb = self.col_embed(i)
y_emb = self.row_embed(j)
pos = torch.cat([
x_emb.unsqueeze(0).repeat(h, 1, 1),
y_emb.unsqueeze(1).repeat(1, w, 1),
], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
return pos
def build_position_encoding(args):
N_steps = args.hidden_dim // 2
if args.position_embedding in ('v2', 'sine'):
# TODO find a better way of exposing other arguments
position_embedding = PositionEmbeddingSine(N_steps, normalize=True)
elif args.position_embedding in ('v3', 'learned'):
position_embedding = PositionEmbeddingLearned(N_steps)
else:
raise ValueError(f"not supported {args.position_embedding}")
return position_embedding

90
gr00t/policy.py Normal file
View File

@@ -0,0 +1,90 @@
"""
GR00T Policy wrapper for imitation learning.
This module provides the gr00tPolicy class that wraps the GR00T model
for training and evaluation in the imitation learning framework.
"""
import torch.nn as nn
from torch.nn import functional as F
from torchvision.transforms import v2
import torch
from roboimi.gr00t.main import build_gr00t_model_and_optimizer
class gr00tPolicy(nn.Module):
"""
GR00T Policy for action prediction using diffusion-based DiT architecture.
This policy wraps the GR00T model and handles:
- Image resizing to match DINOv2 patch size requirements
- Image normalization (ImageNet stats)
- Training with action chunks and loss computation
- Inference with diffusion sampling
"""
def __init__(self, args_override):
super().__init__()
model, optimizer = build_gr00t_model_and_optimizer(args_override)
self.model = model
self.optimizer = optimizer
# DINOv2 requires image dimensions to be multiples of patch size (14)
# Common sizes: 224x224, 336x336, etc. (14*16=224, 14*24=336)
self.patch_h = 16 # Number of patches vertically
self.patch_w = 22 # Number of patches horizontally
target_size = (self.patch_h * 14, self.patch_w * 14) # (224, 308)
# Training transform with data augmentation
self.train_transform = v2.Compose([
v2.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),
v2.RandomPerspective(distortion_scale=0.5),
v2.RandomAffine(degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1)),
v2.GaussianBlur(kernel_size=(9, 9), sigma=(0.1, 2.0)),
v2.Resize(target_size),
v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])
# Inference transform (no augmentation)
self.inference_transform = v2.Compose([
v2.Resize(target_size),
v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])
def __call__(self, qpos, image, actions=None, is_pad=None):
"""
Forward pass for training or inference.
Args:
qpos: Joint positions [B, state_dim]
image: Camera images [B, num_cameras, C, H, W]
actions: Ground truth actions [B, chunk_size, action_dim] (training only)
is_pad: Padding mask [B, chunk_size] (training only)
Returns:
Training: dict with 'mse' loss
Inference: predicted actions [B, num_queries, action_dim]
"""
# Apply transforms (resize + normalization)
if actions is not None: # training time
image = self.train_transform(image)
else: # inference time
image = self.inference_transform(image)
if actions is not None: # training time
actions = actions[:, :self.model.num_queries]
is_pad = is_pad[:, :self.model.num_queries]
_, action_loss = self.model(qpos, image, actions, is_pad)
# Mask out padded positions
mse_loss = (action_loss * ~is_pad.unsqueeze(-1)).mean()
loss_dict = {
'loss': mse_loss
}
return loss_dict
else: # inference time
a_hat, _ = self.model(qpos, image)
return a_hat
def configure_optimizers(self):
"""Return the optimizer for training."""
return self.optimizer

View File

@@ -3,7 +3,7 @@
<body name="box" pos="0.2 1.0 0.47">
<joint name="red_box_joint" type="free" frictionloss="0.01" />
<inertial pos="0 0 0" mass="0.05" diaginertia="0.002 0.002 0.002" />
<geom contype="1" conaffinity="1" condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0 0 0" size="0.02 0.02 0.02" type="box" name="red_box" rgba="1 0 0 1" />
<geom contype="1" conaffinity="1" condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0 0 0" size="0.018 0.018 0.02" type="box" name="red_box" rgba="1 0 0 1" />
</body>
</worldbody>
</mujoco>

View File

@@ -8,6 +8,6 @@
</body>
<camera name="top" pos="0.0 1.0 2.0" fovy="44" mode="targetbody" target="table"/>
<camera name="angle" pos="0.0 0.0 2.0" fovy="37" mode="targetbody" target="table"/>
<camera name="front" pos="0 0 0.9" fovy="55" mode="fixed" quat="0.7071 0.7071 0 0"/>
<camera name="front" pos="0 0 0.8" fovy="65" mode="fixed" quat="0.7071 0.7071 0 0"/>
</worldbody>
</mujoco>

View File

@@ -104,8 +104,8 @@ class TestPickAndTransferPolicy(PolicyBase):
{"t": 1, "xyz": init_mocap_pose_right[:3], "quat": init_mocap_pose_right[3:], "gripper": -100}, # sleep
{"t": 75, "xyz": np.array([(0.8+box_xyz[0])*0.5,(1.0+box_xyz[1])*0.5,init_mocap_pose_right[2]]), "quat": gripper_approach_quat.elements, "gripper": 100},
{"t": 225, "xyz": box_xyz + np.array([0, 0, 0.3]), "quat": gripper_pick_quat.elements, "gripper": 100}, # approach the cube
{"t": 275, "xyz": box_xyz + np.array([0, 0, 0.12]), "quat": gripper_pick_quat.elements, "gripper": 100}, # go down
{"t": 280, "xyz": box_xyz + np.array([0, 0, 0.12]), "quat": gripper_pick_quat.elements, "gripper": -100}, # close gripper
{"t": 275, "xyz": box_xyz + np.array([0, 0, 0.11]), "quat": gripper_pick_quat.elements, "gripper": 100}, # go down
{"t": 280, "xyz": box_xyz + np.array([0, 0, 0.11]), "quat": gripper_pick_quat.elements, "gripper": -100}, # close gripper
{"t": 450, "xyz": init_mocap_pose_right[:3], "quat": init_mocap_pose_right[3:], "gripper": -100},# approach wait position
{"t": 500, "xyz": meet_xyz + np.array([0.1, 0, 0.0]), "quat": meet_right_quat.elements, "gripper": -100},# approach meet position
{"t": 510, "xyz": meet_xyz + np.array([0.1, 0, 0.0]), "quat": meet_right_quat.elements, "gripper": 100}, # open gripper
@@ -116,8 +116,8 @@ class TestPickAndTransferPolicy(PolicyBase):
self.left_trajectory = [
{"t": 1, "xyz": init_mocap_pose_left[:3], "quat": init_mocap_pose_left[3:], "gripper": -100},# sleep
{"t": 250, "xyz": meet_xyz + np.array([-0.5, 0, 0.0]), "quat": meet_left_quat.elements, "gripper": 100}, # approach meet position
{"t": 500, "xyz": meet_xyz + np.array([-0.15, 0, 0.0]), "quat": meet_left_quat.elements, "gripper": 100}, # move to meet position
{"t": 505, "xyz": meet_xyz + np.array([-0.15, 0, 0.0]), "quat": meet_left_quat.elements, "gripper": -100}, # close gripper
{"t": 500, "xyz": meet_xyz + np.array([-0.14, 0, 0.0]), "quat": meet_left_quat.elements, "gripper": 100}, # move to meet position
{"t": 505, "xyz": meet_xyz + np.array([-0.14, 0, 0.0]), "quat": meet_left_quat.elements, "gripper": -100}, # close gripper
{"t": 675, "xyz": meet_xyz + np.array([-0.3, 0, 0.0]), "quat": meet_left_quat.elements, "gripper": -100}, # move left
{"t": 700, "xyz": meet_xyz + np.array([-0.3, 0, 0.0]), "quat": meet_left_quat.elements, "gripper": -100}, # stay
]

View File

@@ -32,6 +32,12 @@ def main():
env = make_sim_env(task_name)
policy = TestPickAndTransferPolicy(inject_noise)
# 等待osmesa完全启动后再开始收集数据
print("等待osmesa线程启动...")
time.sleep(60)
print("osmesa已就绪开始收集数据...")
for episode_idx in range(num_episodes):
obs = []
reward_ee = []

View File

@@ -5,6 +5,7 @@ import json
import pickle
import hydra
import torch
import re
from tqdm import tqdm
from omegaconf import DictConfig, OmegaConf
from torch.utils.data import DataLoader, random_split
@@ -44,6 +45,35 @@ def recursive_to_device(data, device):
return data
def resolve_resume_checkpoint(resume_ckpt, checkpoint_dir):
"""
解析恢复训练用的 checkpoint 路径。
Args:
resume_ckpt: 配置中的 resume_ckpt支持路径或 "auto"
checkpoint_dir: 默认检查点目录
Returns:
Path 或 None
"""
if resume_ckpt is None:
return None
if str(resume_ckpt).lower() != "auto":
return Path(resume_ckpt)
pattern = re.compile(r"vla_model_step_(\d+)\.pt$")
candidates = []
for ckpt_path in checkpoint_dir.glob("vla_model_step_*.pt"):
match = pattern.search(ckpt_path.name)
if match:
candidates.append((int(match.group(1)), ckpt_path))
if not candidates:
return None
return max(candidates, key=lambda x: x[0])[1]
def get_lr_schedule_with_warmup(optimizer, warmup_steps, max_steps, scheduler_type='cosine', min_lr=0):
"""
创建带预热的学习率调度器。
@@ -139,6 +169,7 @@ def main(cfg: DictConfig):
shuffle=True,
num_workers=cfg.train.num_workers,
pin_memory=(cfg.train.device != "cpu"),
persistent_workers=(cfg.train.num_workers > 0),
drop_last=True # 丢弃不完整批次以稳定训练
)
@@ -150,6 +181,7 @@ def main(cfg: DictConfig):
shuffle=False,
num_workers=cfg.train.num_workers,
pin_memory=(cfg.train.device != "cpu"),
persistent_workers=(cfg.train.num_workers > 0),
drop_last=False
)
@@ -248,8 +280,11 @@ def main(cfg: DictConfig):
# =========================================================================
# 4. 设置优化器与学习率调度器
# =========================================================================
optimizer = AdamW(agent.parameters(), lr=cfg.train.lr, weight_decay=1e-5)
log.info(f"🔧 优化器: AdamW (学习率={cfg.train.lr})")
weight_decay = float(cfg.train.get('weight_decay', 1e-5))
grad_clip = float(cfg.train.get('grad_clip', 1.0))
optimizer = AdamW(agent.parameters(), lr=cfg.train.lr, weight_decay=weight_decay)
log.info(f"🔧 优化器: AdamW (学习率={cfg.train.lr}, weight_decay={weight_decay})")
# 设置带预热的学習率调度器
warmup_steps = int(cfg.train.get('warmup_steps', 500))
@@ -265,6 +300,52 @@ def main(cfg: DictConfig):
)
log.info(f"📈 学习率调度器: {scheduler_type}{warmup_steps} 步预热 (最小学习率={min_lr})")
# =========================================================================
# 4.1 断点续训(恢复模型、优化器、调度器、步数)
# =========================================================================
start_step = 0
resume_loss = None
resume_best_loss = float('inf')
resume_ckpt = cfg.train.get('resume_ckpt', None)
resume_path = resolve_resume_checkpoint(resume_ckpt, checkpoint_dir)
if resume_ckpt is not None:
if pretrained_ckpt is not None:
log.warning("⚠️ [Resume] 同时设置了 pretrained_ckpt 与 resume_ckpt将优先使用 resume_ckpt 进行断点续训")
if resume_path is None:
log.warning("⚠️ [Resume] 未找到可恢复的 checkpoint将从头开始训练")
elif not resume_path.exists():
log.error(f"❌ [Resume] Checkpoint 文件不存在: {resume_path}")
log.warning("⚠️ 将从头开始训练")
else:
log.info(f"🔄 [Resume] 从 checkpoint 恢复训练: {resume_path}")
try:
checkpoint = torch.load(resume_path, map_location=cfg.train.device)
agent.load_state_dict(checkpoint['model_state_dict'], strict=True)
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
resume_step = int(checkpoint['step'])
start_step = resume_step + 1
loaded_loss = checkpoint.get('loss', None)
loaded_val_loss = checkpoint.get('val_loss', None)
resume_loss = float(loaded_loss) if loaded_loss is not None else None
if loaded_val_loss is not None:
resume_best_loss = float(loaded_val_loss)
elif loaded_loss is not None:
resume_best_loss = float(loaded_loss)
log.info(f"✅ [Resume] 恢复成功: 上次步骤={resume_step}, 本次从步骤 {start_step} 开始")
log.info(f"📈 [Resume] 当前学习率: {optimizer.param_groups[0]['lr']:.2e}")
except Exception as e:
log.error(f"❌ [Resume] 恢复失败: {e}")
log.warning("⚠️ 将从头开始训练")
start_step = 0
resume_loss = None
resume_best_loss = float('inf')
# =========================================================================
# 5. 训练循环
# =========================================================================
@@ -311,9 +392,15 @@ def main(cfg: DictConfig):
return total_loss / max(num_batches, 1)
data_iter = iter(train_loader)
pbar = tqdm(range(cfg.train.max_steps), desc="训练中", ncols=100)
pbar = tqdm(range(start_step, cfg.train.max_steps), desc="训练中", ncols=100)
best_loss = float('inf')
best_loss = resume_best_loss
last_loss = resume_loss
if start_step >= cfg.train.max_steps:
log.warning(
f"⚠️ [Resume] start_step={start_step} 已达到/超过 max_steps={cfg.train.max_steps},跳过训练循环"
)
for step in pbar:
try:
@@ -346,6 +433,8 @@ def main(cfg: DictConfig):
log.error(f"❌ 步骤 {step} 前向传播失败: {e}")
raise
last_loss = loss.item()
# =====================================================================
# 反向传播与优化
# =====================================================================
@@ -353,7 +442,7 @@ def main(cfg: DictConfig):
loss.backward()
# 梯度裁剪以稳定训练
torch.nn.utils.clip_grad_norm_(agent.parameters(), max_norm=1.0)
torch.nn.utils.clip_grad_norm_(agent.parameters(), max_norm=grad_clip)
optimizer.step()
scheduler.step()
@@ -422,15 +511,21 @@ def main(cfg: DictConfig):
'model_state_dict': agent.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'scheduler_state_dict': scheduler.state_dict(),
'loss': loss.item(),
'loss': last_loss,
'dataset_stats': agent_stats, # 保存agent的统计信息
'current_lr': optimizer.param_groups[0]['lr'],
}, final_model_path)
log.info(f"💾 最终模型已保存: {final_model_path}")
log.info("✅ 训练成功完成!")
log.info(f"📊 最终损失: {loss.item():.4f}")
if last_loss is not None:
log.info(f"📊 最终损失: {last_loss:.4f}")
else:
log.info("📊 最终损失: N/A未执行训练步")
if best_loss != float('inf'):
log.info(f"📊 最佳损失: {best_loss:.4f}")
else:
log.info("📊 最佳损失: N/A无有效验证/训练损失)")
if __name__ == "__main__":

View File

@@ -230,6 +230,7 @@ class DualDianaMed(MujocoEnv):
img_renderer.update_scene(self.mj_data,camera="front")
self.front = img_renderer.render()
self.front = self.front[:, :, ::-1]
if self.cam_view is not None:
cv2.imshow('Cam view', self.cam_view)
cv2.waitKey(1)

View File

@@ -72,6 +72,10 @@ class DualDianaMed_Pos_Ctrl(DualDianaMed):
self.mj_data.joint('red_box_joint').qpos[5] = 0.0
self.mj_data.joint('red_box_joint').qpos[6] = 0.0
super().reset()
self.top = None
self.angle = None
self.r_vis = None
self.front = None
self.cam_flage = True
t=0
while self.cam_flage:

View File

@@ -27,6 +27,7 @@ class VLAAgent(nn.Module):
dataset_stats=None, # 数据集统计信息,用于归一化
normalization_type='min_max', # 归一化类型: 'gaussian' 或 'min_max'
num_action_steps=8, # 每次推理实际执行多少步动作
head_type='unet', # Policy head类型: 'unet' 或 'transformer'
):
super().__init__()
# 保存参数
@@ -37,6 +38,7 @@ class VLAAgent(nn.Module):
self.num_cams = num_cams
self.num_action_steps = num_action_steps
self.inference_steps = inference_steps
self.head_type = head_type # 'unet' 或 'transformer'
# 归一化模块 - 统一训练和推理的归一化逻辑
@@ -47,10 +49,15 @@ class VLAAgent(nn.Module):
self.vision_encoder = vision_backbone
single_cam_feat_dim = self.vision_encoder.output_dim
# global_cond_dim: 展平后的总维度用于UNet
total_vision_dim = single_cam_feat_dim * num_cams * obs_horizon
total_prop_dim = obs_dim * obs_horizon
self.global_cond_dim = total_vision_dim + total_prop_dim
# per_step_cond_dim: 每步的条件维度用于Transformer
# 注意这里不乘以obs_horizon因为Transformer的输入是序列形式
self.per_step_cond_dim = single_cam_feat_dim * num_cams + obs_dim
self.noise_scheduler = DDPMScheduler(
num_train_timesteps=diffusion_steps,
beta_schedule='squaredcos_cap_v2', # 机器人任务常用的 schedule
@@ -66,9 +73,25 @@ class VLAAgent(nn.Module):
prediction_type='epsilon'
)
# 根据head类型初始化不同的参数
if head_type == 'transformer':
# 如果head已经是nn.Module实例直接使用否则需要初始化
if isinstance(head, nn.Module):
# 已经是实例化的模块测试时直接传入<E4BCA0><E585A5>
self.noise_pred_net = head
else:
# Hydra部分初始化的对象调用时传入参数
self.noise_pred_net = head(
input_dim=action_dim,
output_dim=action_dim,
horizon=pred_horizon,
n_obs_steps=obs_horizon,
cond_dim=self.per_step_cond_dim # 每步的条件维度
)
else: # 'unet' (default)
# UNet接口: input_dim, global_cond_dim
self.noise_pred_net = head(
input_dim=action_dim,
# input_dim = action_dim + obs_dim, # 备选:包含观测维度
global_cond_dim=self.global_cond_dim
)
@@ -78,6 +101,22 @@ class VLAAgent(nn.Module):
# 初始化队列(用于在线推理)
self.reset()
def _get_model_device(self) -> torch.device:
"""获取模型当前所在设备。"""
return next(self.parameters()).device
def _move_to_device(self, data, device: torch.device):
"""递归地将张量数据移动到指定设备。"""
if torch.is_tensor(data):
return data.to(device)
if isinstance(data, dict):
return {k: self._move_to_device(v, device) for k, v in data.items()}
if isinstance(data, list):
return [self._move_to_device(v, device) for v in data]
if isinstance(data, tuple):
return tuple(self._move_to_device(v, device) for v in data)
return data
# ==========================
# 训练阶段 (Training)
@@ -124,8 +163,17 @@ class VLAAgent(nn.Module):
global_cond = torch.cat([visual_features, state_features], dim=-1)
global_cond = global_cond.flatten(start_dim=1)
# 5. 网络预测噪声
# 5. 网络预测噪声根据head类型选择接口
if self.head_type == 'transformer':
# Transformer需要序列格式的条件: (B, obs_horizon, cond_dim_per_step)
# 将展平的global_cond reshape回序列格式
cond = global_cond.reshape(B, self.obs_horizon, self.per_step_cond_dim)
pred_noise = self.noise_pred_net(
sample=noisy_actions,
timestep=timesteps,
cond=cond
)
else: # 'unet'
pred_noise = self.noise_pred_net(
sample=noisy_actions,
timestep=timesteps,
@@ -138,8 +186,9 @@ class VLAAgent(nn.Module):
# 如果提供了 action_is_pad对padding位置进行mask
if action_is_pad is not None:
# action_is_pad: (B, pred_horizon),扩展到 (B, pred_horizon, action_dim)
mask = ~action_is_pad.unsqueeze(-1) # True表示有效数据
loss = (loss * mask).sum() / mask.sum() # 只对有效位置计算平均
mask = (~action_is_pad).unsqueeze(-1).to(loss.dtype) # 1.0表示有效数据
valid_count = mask.sum() * loss.shape[-1]
loss = (loss * mask).sum() / valid_count.clamp_min(1.0)
else:
loss = loss.mean()
@@ -230,33 +279,10 @@ class VLAAgent(nn.Module):
Returns:
action: (action_dim,) 单个动作
"""
# 检测设备并确保所有组件在同一设备
# 尝试从观测中获取设备
device = None
for v in observation.values():
if isinstance(v, torch.Tensor):
device = v.device
break
if device is not None and self.normalization.enabled:
# 确保归一化参数在同一设备上
# 根据归一化类型获取正确的属性
if self.normalization.normalization_type == 'gaussian':
norm_device = self.normalization.qpos_mean.device
else: # min_max
norm_device = self.normalization.qpos_min.device
if device != norm_device:
self.normalization.to(device)
# 同时确保其他模块也在正确设备
self.vision_encoder.to(device)
self.state_encoder.to(device)
self.action_encoder.to(device)
self.noise_pred_net.to(device)
# 将所有 observation 移到正确设备
observation = {k: v.to(device) if isinstance(v, torch.Tensor) else v
for k, v in observation.items()}
# 使用模型当前设备作为唯一真值,将输入移动到模型设备
# 避免根据CPU观测把模型错误搬回CPU。
device = self._get_model_device()
observation = self._move_to_device(observation, device)
# 将新观测添加到队列
self._populate_queues(observation)
@@ -323,6 +349,16 @@ class VLAAgent(nn.Module):
visual_features = self.vision_encoder(images)
state_features = self.state_encoder(proprioception)
# 拼接条件(只计算一次)
# visual_features: (B, obs_horizon, vision_dim)
# state_features: (B, obs_horizon, obs_dim)
global_cond = torch.cat([visual_features, state_features], dim=-1)
global_cond_flat = global_cond.flatten(start_dim=1)
if self.head_type == 'transformer':
cond = global_cond.reshape(B, self.obs_horizon, self.per_step_cond_dim)
else:
cond = None
# 2. 初始化纯高斯噪声动作
# 形状: (B, pred_horizon, action_dim)
device = visual_features.device
@@ -336,18 +372,18 @@ class VLAAgent(nn.Module):
for t in self.infer_scheduler.timesteps:
model_input = current_actions
# 拼接全局条件并展平
# visual_features: (B, obs_horizon, vision_dim)
# state_features: (B, obs_horizon, obs_dim)
# 拼接后展平为 (B, obs_horizon * (vision_dim + obs_dim))
global_cond = torch.cat([visual_features, state_features], dim=-1)
global_cond = global_cond.flatten(start_dim=1)
# 预测噪声
# 预测噪声根据head类型选择接口
if self.head_type == 'transformer':
noise_pred = self.noise_pred_net(
sample=model_input,
timestep=t,
global_cond=global_cond
cond=cond
)
else: # 'unet'
noise_pred = self.noise_pred_net(
sample=model_input,
timestep=t,
global_cond=global_cond_flat
)
# 移除噪声,更新 current_actions

View File

@@ -0,0 +1,217 @@
import torch
import torch.nn as nn
from collections import deque
from typing import Dict
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from roboimi.vla.models.normalization import NormalizationModule
class VLAAgentGr00tDiT(nn.Module):
"""
VLA Agent variant that swaps Transformer1D head with gr00t DiT head.
Other components (backbone/encoders/scheduler/queue logic) stay aligned
with the existing VLAAgent implementation.
"""
def __init__(
self,
vision_backbone,
state_encoder,
action_encoder,
head,
action_dim,
obs_dim,
pred_horizon=16,
obs_horizon=4,
diffusion_steps=100,
inference_steps=10,
num_cams=3,
dataset_stats=None,
normalization_type="min_max",
num_action_steps=8,
):
super().__init__()
self.action_dim = action_dim
self.obs_dim = obs_dim
self.pred_horizon = pred_horizon
self.obs_horizon = obs_horizon
self.num_cams = num_cams
self.num_action_steps = num_action_steps
self.inference_steps = inference_steps
self.normalization = NormalizationModule(
stats=dataset_stats,
normalization_type=normalization_type,
)
self.vision_encoder = vision_backbone
single_cam_feat_dim = self.vision_encoder.output_dim
self.per_step_cond_dim = single_cam_feat_dim * num_cams + obs_dim
self.noise_scheduler = DDPMScheduler(
num_train_timesteps=diffusion_steps,
beta_schedule="squaredcos_cap_v2",
clip_sample=True,
prediction_type="epsilon",
)
self.infer_scheduler = DDIMScheduler(
num_train_timesteps=diffusion_steps,
beta_schedule="squaredcos_cap_v2",
clip_sample=True,
prediction_type="epsilon",
)
if isinstance(head, nn.Module):
self.noise_pred_net = head
else:
self.noise_pred_net = head(
input_dim=action_dim,
output_dim=action_dim,
horizon=pred_horizon,
n_obs_steps=obs_horizon,
cond_dim=self.per_step_cond_dim,
)
self.state_encoder = state_encoder
self.action_encoder = action_encoder
self.reset()
def _get_model_device(self) -> torch.device:
return next(self.parameters()).device
def _move_to_device(self, data, device: torch.device):
if torch.is_tensor(data):
return data.to(device)
if isinstance(data, dict):
return {k: self._move_to_device(v, device) for k, v in data.items()}
if isinstance(data, list):
return [self._move_to_device(v, device) for v in data]
if isinstance(data, tuple):
return tuple(self._move_to_device(v, device) for v in data)
return data
def _build_cond(self, images: Dict[str, torch.Tensor], states: torch.Tensor) -> torch.Tensor:
visual_features = self.vision_encoder(images)
state_features = self.state_encoder(states)
return torch.cat([visual_features, state_features], dim=-1)
def compute_loss(self, batch):
actions, states, images = batch["action"], batch["qpos"], batch["images"]
action_is_pad = batch.get("action_is_pad", None)
bsz = actions.shape[0]
states = self.normalization.normalize_qpos(states)
actions = self.normalization.normalize_action(actions)
action_features = self.action_encoder(actions)
cond = self._build_cond(images, states)
noise = torch.randn_like(action_features)
timesteps = torch.randint(
0,
self.noise_scheduler.config.num_train_timesteps,
(bsz,),
device=action_features.device,
).long()
noisy_actions = self.noise_scheduler.add_noise(action_features, noise, timesteps)
pred_noise = self.noise_pred_net(
sample=noisy_actions,
timestep=timesteps,
cond=cond,
)
loss = nn.functional.mse_loss(pred_noise, noise, reduction="none")
if action_is_pad is not None:
mask = (~action_is_pad).unsqueeze(-1).to(loss.dtype)
valid_count = mask.sum() * loss.shape[-1]
loss = (loss * mask).sum() / valid_count.clamp_min(1.0)
else:
loss = loss.mean()
return loss
def reset(self):
self._queues = {
"qpos": deque(maxlen=self.obs_horizon),
"images": deque(maxlen=self.obs_horizon),
"action": deque(maxlen=self.pred_horizon - self.obs_horizon + 1),
}
def _populate_queues(self, observation: Dict[str, torch.Tensor]) -> None:
if "qpos" in observation:
self._queues["qpos"].append(observation["qpos"].clone())
if "images" in observation:
self._queues["images"].append({k: v.clone() for k, v in observation["images"].items()})
def _prepare_observation_batch(self) -> Dict[str, torch.Tensor]:
qpos_list = list(self._queues["qpos"])
if len(qpos_list) == 0:
raise ValueError("observation queue is empty.")
while len(qpos_list) < self.obs_horizon:
qpos_list.append(qpos_list[-1])
batch_qpos = torch.stack(qpos_list, dim=0).unsqueeze(0)
images_list = list(self._queues["images"])
if len(images_list) == 0:
raise ValueError("image queue is empty.")
while len(images_list) < self.obs_horizon:
images_list.append(images_list[-1])
batch_images = {}
for cam_name in images_list[0].keys():
batch_images[cam_name] = torch.stack(
[img[cam_name] for img in images_list], dim=0
).unsqueeze(0)
return {"qpos": batch_qpos, "images": batch_images}
@torch.no_grad()
def select_action(self, observation: Dict[str, torch.Tensor]) -> torch.Tensor:
device = self._get_model_device()
observation = self._move_to_device(observation, device)
self._populate_queues(observation)
if len(self._queues["action"]) == 0:
batch = self._prepare_observation_batch()
actions = self.predict_action_chunk(batch)
start = self.obs_horizon - 1
end = start + self.num_action_steps
executable_actions = actions[:, start:end]
for i in range(executable_actions.shape[1]):
self._queues["action"].append(executable_actions[:, i].squeeze(0))
return self._queues["action"].popleft()
@torch.no_grad()
def predict_action_chunk(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
return self.predict_action(batch["images"], batch["qpos"])
@torch.no_grad()
def predict_action(self, images, proprioception):
bsz = proprioception.shape[0]
proprioception = self.normalization.normalize_qpos(proprioception)
cond = self._build_cond(images, proprioception)
device = cond.device
current_actions = torch.randn((bsz, self.pred_horizon, self.action_dim), device=device)
self.infer_scheduler.set_timesteps(self.inference_steps)
for t in self.infer_scheduler.timesteps:
noise_pred = self.noise_pred_net(
sample=current_actions,
timestep=t,
cond=cond,
)
current_actions = self.infer_scheduler.step(
noise_pred, t, current_actions
).prev_sample
return self.normalization.denormalize_action(current_actions)
def get_normalization_stats(self):
return self.normalization.get_stats()

View File

@@ -25,7 +25,7 @@ normalization_type: "min_max" # "min_max" or "gaussian"
# ====================
pred_horizon: 16 # 预测未来多少步动作
obs_horizon: 2 # 使用多少步历史观测
num_action_steps: 16 # 每次推理实际执行多少步动作(应 <= pred_horizon - obs_horizon + 1
num_action_steps: 8 # 每次推理实际执行多少步动作(应 <= pred_horizon - obs_horizon + 1
# ====================
# 相机配置

View File

@@ -0,0 +1,37 @@
# @package agent
defaults:
- /backbone@vision_backbone: resnet_diffusion
- /modules@state_encoder: identity_state_encoder
- /modules@action_encoder: identity_action_encoder
- /head: gr00t_dit1d
- _self_
_target_: roboimi.vla.agent_gr00t_dit.VLAAgentGr00tDiT
# Model dimensions
action_dim: 16
obs_dim: 16
# Normalization
normalization_type: "min_max"
# Horizons
pred_horizon: 16
obs_horizon: 2
num_action_steps: 8
# Cameras
num_cams: 3
# Diffusion
diffusion_steps: 100
inference_steps: 10
# Head overrides
head:
input_dim: ${agent.action_dim}
output_dim: ${agent.action_dim}
horizon: ${agent.pred_horizon}
n_obs_steps: ${agent.obs_horizon}
cond_dim: 208

View File

@@ -0,0 +1,54 @@
# @package agent
defaults:
- /backbone@vision_backbone: resnet_diffusion
- /modules@state_encoder: identity_state_encoder
- /modules@action_encoder: identity_action_encoder
- /head: transformer1d
- _self_
_target_: roboimi.vla.agent.VLAAgent
# ====================
# 模型维度配置
# ====================
action_dim: 16 # 动作维度(机器人关节数)
obs_dim: 16 # 本体感知维度(关节位置)
# ====================
# 归一化配置
# ====================
normalization_type: "min_max" # "min_max" or "gaussian"
# ====================
# 时间步配置
# ====================
pred_horizon: 16 # 预测未来多少步动作
obs_horizon: 2 # 使用多少步历史观测
num_action_steps: 8 # 每次推理实际执行多少步动作(应 <= pred_horizon - obs_horizon + 1
# ====================
# 相机配置
# ====================
num_cams: 3 # 摄像头数量 (r_vis, top, front)
# ====================
# 扩散过程配置
# ====================
diffusion_steps: 100 # 扩散训练步数DDPM
inference_steps: 10 # 推理时的去噪步数DDIM<4D><EFBC8C>定为 10
# ====================
# Head 类型标识用于VLAAgent选择调用方式
# ====================
head_type: "transformer" # "unet" 或 "transformer"
# Head 参数覆盖
head:
input_dim: ${agent.action_dim}
output_dim: ${agent.action_dim}
horizon: ${agent.pred_horizon}
n_obs_steps: ${agent.obs_horizon}
# Transformer的cond_dim是每步的维度
# ResNet18 + SpatialSoftmax(32 keypoints) = 64维/相机
# 计算方式:单相机特征(64) * 相机数(3) + obs_dim(16) = 208
cond_dim: 208

View File

@@ -1,5 +1,5 @@
defaults:
- agent: resnet_diffusion
- agent: resnet_transformer
- data: simpe_robot_dataset
- eval: eval
- _self_
@@ -10,7 +10,7 @@ defaults:
train:
# 基础训练参数
batch_size: 8 # 批次大小
lr: 1e-4 # 学习率
lr: 5e-5 # 学习率Transformer建议更小
max_steps: 100000 # 最大训练步数
device: "cuda" # 设备: "cuda" 或 "cpu"
@@ -24,7 +24,7 @@ train:
save_freq: 2000 # 保存检查点频率(步数)
# 学习率调度器(带预热)
warmup_steps: 500 # 预热步数
warmup_steps: 2000 # 预热步数Transformer建议更长
scheduler_type: "cosine" # 预热后的调度器: "constant" 或 "cosine"
min_lr: 1e-6 # 最小学习率(用于余弦退火)

View File

@@ -0,0 +1,22 @@
_target_: roboimi.vla.models.heads.gr00t_dit1d.Gr00tDiT1D
_partial_: true
# DiT architecture
n_layer: 6
n_head: 8
n_emb: 256
hidden_dim: 256
mlp_ratio: 4
dropout: 0.1
# Positional embeddings
add_action_pos_emb: true
add_cond_pos_emb: true
# Supplied by agent interpolation:
# - input_dim
# - output_dim
# - horizon
# - n_obs_steps
# - cond_dim

View File

@@ -0,0 +1,29 @@
# Transformer-based Diffusion Policy Head
_target_: roboimi.vla.models.heads.transformer1d.Transformer1D
_partial_: true
# ====================
# Transformer 架构配置
# ====================
n_layer: 4 # Transformer层数先用小模型提高收敛稳定性
n_head: 4 # 注意力头数
n_emb: 128 # 嵌入维度
p_drop_emb: 0.05 # Embedding dropout
p_drop_attn: 0.05 # Attention dropout
# ====================
# 条件配置
# ====================
causal_attn: false # 是否使用因果注意力(自回归生成)
obs_as_cond: true # 观测作为条件由cond_dim > 0决定
n_cond_layers: 1 # 条件编码器层数1层先做稳定融合
# ====================
# 注意事项
# ====================
# 以下参数将在agent配置中通过interpolation提供
# - input_dim: ${agent.action_dim}
# - output_dim: ${agent.action_dim}
# - horizon: ${agent.pred_horizon}
# - n_obs_steps: ${agent.obs_horizon}
# - cond_dim: 通过agent中的global_cond_dim计算

View File

@@ -3,6 +3,7 @@ import h5py
from torch.utils.data import Dataset
from typing import List, Dict, Union
from pathlib import Path
from collections import OrderedDict
class SimpleRobotDataset(Dataset):
@@ -21,6 +22,7 @@ class SimpleRobotDataset(Dataset):
obs_horizon: int = 2,
pred_horizon: int = 8,
camera_names: List[str] = None,
max_open_files: int = 64,
):
"""
Args:
@@ -28,6 +30,7 @@ class SimpleRobotDataset(Dataset):
obs_horizon: 观察过去多少帧
pred_horizon: 预测未来多少帧动作
camera_names: 相机名称列表,如 ["r_vis", "top", "front"]
max_open_files: 每个 worker 最多缓存的 HDF5 文件句柄数
HDF5 文件格式:
- action: [T, action_dim]
@@ -37,6 +40,8 @@ class SimpleRobotDataset(Dataset):
self.obs_horizon = obs_horizon
self.pred_horizon = pred_horizon
self.camera_names = camera_names or []
self.max_open_files = max(1, int(max_open_files))
self._file_cache: "OrderedDict[str, h5py.File]" = OrderedDict()
self.dataset_dir = Path(dataset_dir)
if not self.dataset_dir.exists():
@@ -69,10 +74,41 @@ class SimpleRobotDataset(Dataset):
def __len__(self):
return len(self.frame_meta)
def _close_all_files(self) -> None:
"""关闭当前 worker 内缓存的所有 HDF5 文件句柄。"""
for f in self._file_cache.values():
try:
f.close()
except Exception:
pass
self._file_cache.clear()
def _get_h5_file(self, hdf5_path: Union[str, Path]) -> h5py.File:
"""
获取 HDF5 文件句柄worker 内 LRU 缓存)。
注意:缓存的是文件句柄,不是帧数据本身。
"""
key = str(hdf5_path)
if key in self._file_cache:
self._file_cache.move_to_end(key)
return self._file_cache[key]
# 超过上限时淘汰最久未使用的句柄
if len(self._file_cache) >= self.max_open_files:
_, old_file = self._file_cache.popitem(last=False)
try:
old_file.close()
except Exception:
pass
f = h5py.File(key, 'r')
self._file_cache[key] = f
return f
def _load_frame(self, idx: int) -> Dict:
"""从 HDF5 文件懒加载单帧数据"""
meta = self.frame_meta[idx]
with h5py.File(meta["hdf5_path"], 'r') as f:
f = self._get_h5_file(meta["hdf5_path"])
frame = {
"episode_index": meta["ep_idx"],
"frame_index": meta["frame_idx"],
@@ -201,3 +237,6 @@ class SimpleRobotDataset(Dataset):
"dtype": str(sample[key].dtype),
}
return info
def __del__(self):
self._close_all_files()

View File

@@ -1,4 +1,5 @@
# # Action Head models
# Action Head models
from .conditional_unet1d import ConditionalUnet1D
from .transformer1d import Transformer1D
__all__ = ["ConditionalUnet1D"]
__all__ = ["ConditionalUnet1D", "Transformer1D"]

View File

@@ -124,7 +124,6 @@ class ConditionalResidualBlock1D(nn.Module):
class ConditionalUnet1D(nn.Module):
def __init__(self,
input_dim,
local_cond_dim=None,
global_cond_dim=None,
diffusion_step_embed_dim=256,
down_dims=[256,512,1024],
@@ -149,23 +148,6 @@ class ConditionalUnet1D(nn.Module):
in_out = list(zip(all_dims[:-1], all_dims[1:]))
local_cond_encoder = None
if local_cond_dim is not None:
_, dim_out = in_out[0]
dim_in = local_cond_dim
local_cond_encoder = nn.ModuleList([
# down encoder
ConditionalResidualBlock1D(
dim_in, dim_out, cond_dim=cond_dim,
kernel_size=kernel_size, n_groups=n_groups,
cond_predict_scale=cond_predict_scale),
# up encoder
ConditionalResidualBlock1D(
dim_in, dim_out, cond_dim=cond_dim,
kernel_size=kernel_size, n_groups=n_groups,
cond_predict_scale=cond_predict_scale)
])
mid_dim = all_dims[-1]
self.mid_modules = nn.ModuleList([
ConditionalResidualBlock1D(
@@ -216,7 +198,6 @@ class ConditionalUnet1D(nn.Module):
)
self.diffusion_step_encoder = diffusion_step_encoder
self.local_cond_encoder = local_cond_encoder
self.up_modules = up_modules
self.down_modules = down_modules
self.final_conv = final_conv
@@ -225,12 +206,11 @@ class ConditionalUnet1D(nn.Module):
def forward(self,
sample: torch.Tensor,
timestep: Union[torch.Tensor, float, int],
local_cond=None, global_cond=None,
global_cond=None,
**kwargs):
"""
x: (B,T,input_dim)
timestep: (B,) or int, diffusion step
local_cond: (B,T,local_cond_dim)
global_cond: (B,global_cond_dim)
output: (B,T,input_dim)
"""
@@ -253,22 +233,10 @@ class ConditionalUnet1D(nn.Module):
global_feature, global_cond
], axis=-1)
# encode local features
h_local = list()
if local_cond is not None:
local_cond = einops.rearrange(local_cond, 'b h t -> b t h')
resnet, resnet2 = self.local_cond_encoder
x = resnet(local_cond, global_feature)
h_local.append(x)
x = resnet2(local_cond, global_feature)
h_local.append(x)
x = sample
h = []
for idx, (resnet, resnet2, downsample) in enumerate(self.down_modules):
x = resnet(x, global_feature)
if idx == 0 and len(h_local) > 0:
x = x + h_local[0]
x = resnet2(x, global_feature)
h.append(x)
x = downsample(x)
@@ -279,12 +247,6 @@ class ConditionalUnet1D(nn.Module):
for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules):
x = torch.cat((x, h.pop()), dim=1)
x = resnet(x, global_feature)
# The correct condition should be:
# if idx == (len(self.up_modules)-1) and len(h_local) > 0:
# However this change will break compatibility with published checkpoints.
# Therefore it is left as a comment.
if idx == len(self.up_modules) and len(h_local) > 0:
x = x + h_local[1]
x = resnet2(x, global_feature)
x = upsample(x)

View File

@@ -0,0 +1,146 @@
import torch
import torch.nn as nn
from types import SimpleNamespace
from typing import Optional, Union
from pathlib import Path
import importlib.util
def _load_gr00t_dit():
repo_root = Path(__file__).resolve().parents[4]
dit_path = repo_root / "gr00t" / "models" / "dit.py"
spec = importlib.util.spec_from_file_location("gr00t_dit_standalone", dit_path)
if spec is None or spec.loader is None:
raise ImportError(f"Unable to load DiT from {dit_path}")
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
return module.DiT
DiT = _load_gr00t_dit()
class Gr00tDiT1D(nn.Module):
"""
Adapter that wraps gr00t DiT with the same call signature used by VLA heads.
Expected forward interface:
- sample: (B, T_action, input_dim)
- timestep: (B,) or scalar diffusion timestep
- cond: (B, T_obs, cond_dim)
"""
def __init__(
self,
input_dim: int,
output_dim: int,
horizon: int,
n_obs_steps: int,
cond_dim: int,
n_layer: int = 8,
n_head: int = 8,
n_emb: int = 256,
hidden_dim: int = 256,
mlp_ratio: int = 4,
dropout: float = 0.1,
add_action_pos_emb: bool = True,
add_cond_pos_emb: bool = True,
):
super().__init__()
if cond_dim <= 0:
raise ValueError("Gr00tDiT1D requires cond_dim > 0.")
self.horizon = horizon
self.n_obs_steps = n_obs_steps
self.input_proj = nn.Linear(input_dim, n_emb)
self.cond_proj = nn.Linear(cond_dim, n_emb)
self.output_proj = nn.Linear(hidden_dim, output_dim)
self.action_pos_emb = (
nn.Parameter(torch.zeros(1, horizon, n_emb))
if add_action_pos_emb
else None
)
self.cond_pos_emb = (
nn.Parameter(torch.zeros(1, n_obs_steps, n_emb))
if add_cond_pos_emb
else None
)
args = SimpleNamespace(
embed_dim=n_emb,
nheads=n_head,
mlp_ratio=mlp_ratio,
dropout=dropout,
num_layers=n_layer,
hidden_dim=hidden_dim,
)
self.dit = DiT(args, cross_attention_dim=n_emb)
self._init_weights()
def _init_weights(self):
if self.action_pos_emb is not None:
nn.init.normal_(self.action_pos_emb, mean=0.0, std=0.02)
if self.cond_pos_emb is not None:
nn.init.normal_(self.cond_pos_emb, mean=0.0, std=0.02)
def _normalize_timesteps(
self,
timestep: Union[torch.Tensor, float, int],
batch_size: int,
device: torch.device,
) -> torch.Tensor:
if not torch.is_tensor(timestep):
timesteps = torch.tensor([timestep], device=device)
else:
timesteps = timestep.to(device)
if timesteps.ndim == 0:
timesteps = timesteps[None]
if timesteps.shape[0] != batch_size:
timesteps = timesteps.expand(batch_size)
return timesteps.long()
def forward(
self,
sample: torch.Tensor,
timestep: Union[torch.Tensor, float, int],
cond: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
if cond is None:
raise ValueError("`cond` is required for Gr00tDiT1D forward.")
bsz, t_act, _ = sample.shape
if t_act > self.horizon:
raise ValueError(
f"sample length {t_act} exceeds configured horizon {self.horizon}"
)
hidden_states = self.input_proj(sample)
if self.action_pos_emb is not None:
hidden_states = hidden_states + self.action_pos_emb[:, :t_act, :]
encoder_hidden_states = self.cond_proj(cond)
if self.cond_pos_emb is not None:
t_obs = encoder_hidden_states.shape[1]
if t_obs > self.n_obs_steps:
raise ValueError(
f"cond length {t_obs} exceeds configured n_obs_steps {self.n_obs_steps}"
)
encoder_hidden_states = (
encoder_hidden_states + self.cond_pos_emb[:, :t_obs, :]
)
timesteps = self._normalize_timesteps(
timestep, batch_size=bsz, device=sample.device
)
dit_output = self.dit(
hidden_states=hidden_states,
timestep=timesteps,
encoder_hidden_states=encoder_hidden_states,
)
return self.output_proj(dit_output)

View File

@@ -0,0 +1,396 @@
"""
Transformer-based Diffusion Policy Head
使用Transformer架构Encoder-Decoder替代UNet进行噪声预测。
支持通过Cross-Attention注入全局条件观测特征
"""
import math
import torch
import torch.nn as nn
from typing import Optional
class SinusoidalPosEmb(nn.Module):
"""正弦位置编码(用于时间步嵌入)"""
def __init__(self, dim: int):
super().__init__()
self.dim = dim
def forward(self, x: torch.Tensor) -> torch.Tensor:
device = x.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
emb = x[:, None] * emb[None, :]
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb
class Transformer1D(nn.Module):
"""
Transformer-based 1D Diffusion Model
使用Encoder-Decoder架构
- Encoder: 处理条件(观测 + 时间步)
- Decoder: 通过Cross-Attention预测噪声
Args:
input_dim: 输入动作维度
output_dim: 输出动作维度
horizon: 预测horizon长度
n_obs_steps: 观测步数
cond_dim: 条件维度
n_layer: Transformer层数
n_head: 注意力头数
n_emb: 嵌入维度
p_drop_emb: Embedding dropout
p_drop_attn: Attention dropout
causal_attn: 是否使用因果注意力(自回归)
n_cond_layers: Encoder层数0表示使用MLP
"""
def __init__(
self,
input_dim: int,
output_dim: int,
horizon: int,
n_obs_steps: int = None,
cond_dim: int = 0,
n_layer: int = 8,
n_head: int = 8,
n_emb: int = 256,
p_drop_emb: float = 0.1,
p_drop_attn: float = 0.1,
causal_attn: bool = False,
obs_as_cond: bool = False,
n_cond_layers: int = 0
):
super().__init__()
# 计算序列长度
if n_obs_steps is None:
n_obs_steps = horizon
T = horizon
T_cond = 1 # 时间步token数量
# 确定是否使用观测作为条件
obs_as_cond = cond_dim > 0
if obs_as_cond:
T_cond += n_obs_steps
# 保存配置
self.T = T
self.T_cond = T_cond
self.horizon = horizon
self.obs_as_cond = obs_as_cond
self.input_dim = input_dim
self.output_dim = output_dim
# ==================== 输入嵌入 ====================
self.input_emb = nn.Linear(input_dim, n_emb)
self.pos_emb = nn.Parameter(torch.zeros(1, T, n_emb))
self.drop = nn.Dropout(p_drop_emb)
# ==================== 条件编码 ====================
# 时间步嵌入
self.time_emb = SinusoidalPosEmb(n_emb)
# 观测条件嵌入(可选)
self.cond_obs_emb = None
if obs_as_cond:
self.cond_obs_emb = nn.Linear(cond_dim, n_emb)
# 条件位置编码
self.cond_pos_emb = None
if T_cond > 0:
self.cond_pos_emb = nn.Parameter(torch.zeros(1, T_cond, n_emb))
# ==================== Encoder ====================
self.encoder = None
self.encoder_only = False
if T_cond > 0:
if n_cond_layers > 0:
# 使用Transformer Encoder
encoder_layer = nn.TransformerEncoderLayer(
d_model=n_emb,
nhead=n_head,
dim_feedforward=4 * n_emb,
dropout=p_drop_attn,
activation='gelu',
batch_first=True,
norm_first=True # Pre-LN更稳定
)
self.encoder = nn.TransformerEncoder(
encoder_layer=encoder_layer,
num_layers=n_cond_layers
)
else:
# 使用简单的MLP
self.encoder = nn.Sequential(
nn.Linear(n_emb, 4 * n_emb),
nn.Mish(),
nn.Linear(4 * n_emb, n_emb)
)
else:
# Encoder-only模式BERT风格
self.encoder_only = True
encoder_layer = nn.TransformerEncoderLayer(
d_model=n_emb,
nhead=n_head,
dim_feedforward=4 * n_emb,
dropout=p_drop_attn,
activation='gelu',
batch_first=True,
norm_first=True
)
self.encoder = nn.TransformerEncoder(
encoder_layer=encoder_layer,
num_layers=n_layer
)
# ==================== Attention Mask ====================
self.mask = None
self.memory_mask = None
if causal_attn:
# 因果mask确保只关注左侧
sz = T
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
self.register_buffer("mask", mask)
if obs_as_cond:
# 交叉注意力mask
S = T_cond
t, s = torch.meshgrid(
torch.arange(T),
torch.arange(S),
indexing='ij'
)
mask = t >= (s - 1)
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
self.register_buffer('memory_mask', mask)
# ==================== Decoder ====================
if not self.encoder_only:
decoder_layer = nn.TransformerDecoderLayer(
d_model=n_emb,
nhead=n_head,
dim_feedforward=4 * n_emb,
dropout=p_drop_attn,
activation='gelu',
batch_first=True,
norm_first=True
)
self.decoder = nn.TransformerDecoder(
decoder_layer=decoder_layer,
num_layers=n_layer
)
# ==================== 输出头 ====================
self.ln_f = nn.LayerNorm(n_emb)
self.head = nn.Linear(n_emb, output_dim)
# ==================== 初始化 ====================
self.apply(self._init_weights)
# 打印参数量
total_params = sum(p.numel() for p in self.parameters())
print(f"Transformer1D parameters: {total_params:,}")
def _init_weights(self, module):
"""初始化权重"""
if isinstance(module, (nn.Linear, nn.Embedding)):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
if isinstance(module, nn.Linear) and module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.MultiheadAttention):
# MultiheadAttention的权重初始化
for name in ['in_proj_weight', 'q_proj_weight', 'k_proj_weight', 'v_proj_weight']:
weight = getattr(module, name, None)
if weight is not None:
torch.nn.init.normal_(weight, mean=0.0, std=0.02)
for name in ['in_proj_bias', 'bias_k', 'bias_v']:
bias = getattr(module, name, None)
if bias is not None:
torch.nn.init.zeros_(bias)
elif isinstance(module, nn.LayerNorm):
torch.nn.init.zeros_(module.bias)
torch.nn.init.ones_(module.weight)
elif isinstance(module, Transformer1D):
# 位置编码初始化
torch.nn.init.normal_(self.pos_emb, mean=0.0, std=0.02)
if self.cond_pos_emb is not None:
torch.nn.init.normal_(self.cond_pos_emb, mean=0.0, std=0.02)
def forward(
self,
sample: torch.Tensor,
timestep: torch.Tensor,
cond: Optional[torch.Tensor] = None,
**kwargs
):
"""
前向传播
Args:
sample: (B, T, input_dim) 输入序列(加噪动作)
timestep: (B,) 时间步
cond: (B, T', cond_dim) 条件序列(观测特征)
Returns:
(B, T, output_dim) 预测的噪声
"""
# ==================== 处理时间步 ====================
timesteps = timestep
if not torch.is_tensor(timesteps):
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
# 扩展到batch维度
timesteps = timesteps.expand(sample.shape[0])
time_emb = self.time_emb(timesteps).unsqueeze(1) # (B, 1, n_emb)
# ==================== 处理输入 ====================
input_emb = self.input_emb(sample) # (B, T, n_emb)
# ==================== Encoder-Decoder模式 ====================
if not self.encoder_only:
# --- Encoder: 处理条件 ---
cond_embeddings = time_emb
if self.obs_as_cond and cond is not None:
# 添加观测条件
cond_obs_emb = self.cond_obs_emb(cond) # (B, T_cond-1, n_emb)
cond_embeddings = torch.cat([cond_embeddings, cond_obs_emb], dim=1)
# 添加位置编码
tc = cond_embeddings.shape[1]
pos_emb = self.cond_pos_emb[:, :tc, :]
x = self.drop(cond_embeddings + pos_emb)
# 通过encoder
memory = self.encoder(x) # (B, T_cond, n_emb)
# --- Decoder: 预测噪声 ---
# 添加位置编码到输入
token_embeddings = input_emb
t = token_embeddings.shape[1]
pos_emb = self.pos_emb[:, :t, :]
x = self.drop(token_embeddings + pos_emb)
# Cross-Attention: Query来自输入Key/Value来自memory
x = self.decoder(
tgt=x,
memory=memory,
tgt_mask=self.mask,
memory_mask=self.memory_mask
)
# ==================== Encoder-Only模式 ====================
else:
# BERT风格时间步作为特殊token
token_embeddings = torch.cat([time_emb, input_emb], dim=1)
t = token_embeddings.shape[1]
pos_emb = self.pos_emb[:, :t, :]
x = self.drop(token_embeddings + pos_emb)
x = self.encoder(src=x, mask=self.mask)
x = x[:, 1:, :] # 移除时间步token
# ==================== 输出头 ====================
x = self.ln_f(x)
x = self.head(x) # (B, T, output_dim)
return x
# ============================================================================
# 便捷函数创建Transformer1D模型
# ============================================================================
def create_transformer1d(
input_dim: int,
output_dim: int,
horizon: int,
n_obs_steps: int,
cond_dim: int,
n_layer: int = 8,
n_head: int = 8,
n_emb: int = 256,
**kwargs
) -> Transformer1D:
"""
创建Transformer1D模型的便捷函数
Args:
input_dim: 输入动作维度
output_dim: 输出动作维度
horizon: 预测horizon
n_obs_steps: 观测步数
cond_dim: 条件维度
n_layer: Transformer层数
n_head: 注意力头数
n_emb: 嵌入维度
**kwargs: 其他参数
Returns:
Transformer1D模型
"""
model = Transformer1D(
input_dim=input_dim,
output_dim=output_dim,
horizon=horizon,
n_obs_steps=n_obs_steps,
cond_dim=cond_dim,
n_layer=n_layer,
n_head=n_head,
n_emb=n_emb,
**kwargs
)
return model
if __name__ == "__main__":
print("=" * 80)
print("Testing Transformer1D")
print("=" * 80)
# 配置
B = 4
T = 16
action_dim = 16
obs_horizon = 2
cond_dim = 416 # vision + state特征维度
# 创建模型
model = Transformer1D(
input_dim=action_dim,
output_dim=action_dim,
horizon=T,
n_obs_steps=obs_horizon,
cond_dim=cond_dim,
n_layer=4,
n_head=8,
n_emb=256,
causal_attn=False
)
# 测试前向传播
sample = torch.randn(B, T, action_dim)
timestep = torch.randint(0, 100, (B,))
cond = torch.randn(B, obs_horizon, cond_dim)
output = model(sample, timestep, cond)
print(f"\n输入:")
print(f" sample: {sample.shape}")
print(f" timestep: {timestep.shape}")
print(f" cond: {cond.shape}")
print(f"\n输出:")
print(f" output: {output.shape}")
print(f"\n✅ 测试通过!")