Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fdf4dd8bed | ||
|
|
fd1bd20c4f |
@@ -1,112 +0,0 @@
|
|||||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
|
||||||
"""
|
|
||||||
DDT 模型构建和优化器配置。
|
|
||||||
"""
|
|
||||||
import argparse
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from .models import build_DDT_model
|
|
||||||
|
|
||||||
|
|
||||||
def get_args_parser():
|
|
||||||
"""获取 DDT 模型的参数解析器。"""
|
|
||||||
parser = argparse.ArgumentParser('DDT model configuration', add_help=False)
|
|
||||||
|
|
||||||
# 学习率
|
|
||||||
parser.add_argument('--lr', default=1e-4, type=float)
|
|
||||||
parser.add_argument('--lr_backbone', default=1e-5, type=float)
|
|
||||||
parser.add_argument('--batch_size', default=2, type=int)
|
|
||||||
parser.add_argument('--weight_decay', default=1e-4, type=float)
|
|
||||||
parser.add_argument('--epochs', default=300, type=int)
|
|
||||||
parser.add_argument('--lr_drop', default=200, type=int)
|
|
||||||
parser.add_argument('--clip_max_norm', default=0.1, type=float,
|
|
||||||
help='gradient clipping max norm')
|
|
||||||
parser.add_argument('--qpos_noise_std', action='store', default=0, type=float)
|
|
||||||
|
|
||||||
# Backbone 参数
|
|
||||||
parser.add_argument('--backbone', default='resnet18', type=str,
|
|
||||||
help="Name of the convolutional backbone to use")
|
|
||||||
parser.add_argument('--dilation', action='store_true',
|
|
||||||
help="If true, replace stride with dilation in the last conv block")
|
|
||||||
parser.add_argument('--position_embedding', default='sine', type=str,
|
|
||||||
choices=('sine', 'learned'),
|
|
||||||
help="Type of positional embedding")
|
|
||||||
parser.add_argument('--camera_names', default=[], type=list,
|
|
||||||
help="A list of camera names")
|
|
||||||
|
|
||||||
# Transformer 参数
|
|
||||||
parser.add_argument('--enc_layers', default=4, type=int,
|
|
||||||
help="Number of encoding layers in the transformer")
|
|
||||||
parser.add_argument('--dec_layers', default=6, type=int,
|
|
||||||
help="Number of decoding layers in the transformer")
|
|
||||||
parser.add_argument('--dim_feedforward', default=2048, type=int,
|
|
||||||
help="Intermediate size of the feedforward layers")
|
|
||||||
parser.add_argument('--hidden_dim', default=512, type=int,
|
|
||||||
help="Size of the embeddings (dimension of the transformer)")
|
|
||||||
parser.add_argument('--dropout', default=0.1, type=float,
|
|
||||||
help="Dropout applied in the transformer")
|
|
||||||
parser.add_argument('--nheads', default=8, type=int,
|
|
||||||
help="Number of attention heads")
|
|
||||||
parser.add_argument('--num_queries', default=100, type=int,
|
|
||||||
help="Number of query slots (action horizon)")
|
|
||||||
parser.add_argument('--pre_norm', action='store_true')
|
|
||||||
parser.add_argument('--state_dim', default=14, type=int)
|
|
||||||
parser.add_argument('--action_dim', default=14, type=int)
|
|
||||||
|
|
||||||
# DDT 特有参数
|
|
||||||
parser.add_argument('--num_blocks', default=12, type=int,
|
|
||||||
help="Total number of transformer blocks in DDT")
|
|
||||||
parser.add_argument('--mlp_ratio', default=4.0, type=float,
|
|
||||||
help="MLP hidden dimension ratio")
|
|
||||||
parser.add_argument('--num_inference_steps', default=10, type=int,
|
|
||||||
help="Number of diffusion inference steps")
|
|
||||||
|
|
||||||
# Segmentation (未使用)
|
|
||||||
parser.add_argument('--masks', action='store_true',
|
|
||||||
help="Train segmentation head if provided")
|
|
||||||
|
|
||||||
return parser
|
|
||||||
|
|
||||||
|
|
||||||
def build_DDT_model_and_optimizer(args_override):
|
|
||||||
"""构建 DDT 模型和优化器。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
args_override: 覆盖默认参数的字典
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
model: DDT 模型
|
|
||||||
optimizer: AdamW 优化器
|
|
||||||
"""
|
|
||||||
parser = argparse.ArgumentParser('DDT training script', parents=[get_args_parser()])
|
|
||||||
args = parser.parse_args([]) # 空列表避免命令行参数干扰
|
|
||||||
|
|
||||||
# 应用参数覆盖
|
|
||||||
for k, v in args_override.items():
|
|
||||||
setattr(args, k, v)
|
|
||||||
|
|
||||||
# 构建模型
|
|
||||||
model = build_DDT_model(args)
|
|
||||||
model.cuda()
|
|
||||||
|
|
||||||
# 配置优化器(backbone 使用较小学习率)
|
|
||||||
param_dicts = [
|
|
||||||
{
|
|
||||||
"params": [p for n, p in model.named_parameters()
|
|
||||||
if "backbone" not in n and p.requires_grad]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"params": [p for n, p in model.named_parameters()
|
|
||||||
if "backbone" in n and p.requires_grad],
|
|
||||||
"lr": args.lr_backbone,
|
|
||||||
},
|
|
||||||
]
|
|
||||||
optimizer = torch.optim.AdamW(
|
|
||||||
param_dicts,
|
|
||||||
lr=args.lr,
|
|
||||||
weight_decay=args.weight_decay
|
|
||||||
)
|
|
||||||
|
|
||||||
return model, optimizer
|
|
||||||
@@ -1,7 +0,0 @@
|
|||||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
|
||||||
from .model import build as build_ddt
|
|
||||||
from .model import build_ddt
|
|
||||||
|
|
||||||
def build_DDT_model(args):
|
|
||||||
"""构建 DDT 模型的统一入口。"""
|
|
||||||
return build_ddt(args)
|
|
||||||
@@ -1,631 +0,0 @@
|
|||||||
"""
|
|
||||||
动作序列扩散 Transformer (Action Decoupled Diffusion Transformer)
|
|
||||||
|
|
||||||
基于 DDT 架构修改,用于生成机器人动作序列。
|
|
||||||
主要改动:
|
|
||||||
1. 2D RoPE → 1D RoPE (适配时序数据)
|
|
||||||
2. LabelEmbedder → ObservationEncoder (观测条件)
|
|
||||||
3. 去除 patchify/unpatchify (动作序列已是 1D)
|
|
||||||
"""
|
|
||||||
|
|
||||||
import math
|
|
||||||
from typing import Tuple, Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
from torch.nn.functional import scaled_dot_product_attention
|
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
|
||||||
# 通用工具函数
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
|
|
||||||
"""AdaLN 调制函数。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
x: 输入张量。
|
|
||||||
shift: 偏移量。
|
|
||||||
scale: 缩放量。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
调制后的张量: x * (1 + scale) + shift
|
|
||||||
"""
|
|
||||||
return x * (1 + scale) + shift
|
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
|
||||||
# 1D 旋转位置编码 (RoPE)
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
def precompute_freqs_cis_1d(dim: int, seq_len: int, theta: float = 10000.0) -> torch.Tensor:
|
|
||||||
"""预计算 1D 旋转位置编码的复数频率。
|
|
||||||
|
|
||||||
用于时序数据(如动作序列)的位置编码,相比 2D RoPE 更简单高效。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
dim: 每个注意力头的维度 (head_dim)。
|
|
||||||
seq_len: 序列长度。
|
|
||||||
theta: RoPE 的基础频率,默认 10000.0。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
复数频率张量,形状为 (seq_len, dim//2)。
|
|
||||||
"""
|
|
||||||
# 计算频率: 1 / (theta^(2i/dim))
|
|
||||||
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim)) # [dim//2]
|
|
||||||
# 位置索引
|
|
||||||
t = torch.arange(seq_len).float() # [seq_len]
|
|
||||||
# 外积得到位置-频率矩阵
|
|
||||||
freqs = torch.outer(t, freqs) # [seq_len, dim//2]
|
|
||||||
# 转换为复数形式 (极坐标)
|
|
||||||
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # [seq_len, dim//2]
|
|
||||||
return freqs_cis
|
|
||||||
|
|
||||||
|
|
||||||
def apply_rotary_emb_1d(
|
|
||||||
xq: torch.Tensor,
|
|
||||||
xk: torch.Tensor,
|
|
||||||
freqs_cis: torch.Tensor,
|
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
||||||
"""应用 1D 旋转位置编码到 Query 和 Key。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
xq: Query 张量,形状为 (B, N, H, Hc)。
|
|
||||||
xk: Key 张量,形状为 (B, N, H, Hc)。
|
|
||||||
freqs_cis: 预计算的复数频率,形状为 (N, Hc//2)。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
应用 RoPE 后的 (xq, xk),形状不变。
|
|
||||||
"""
|
|
||||||
# 调整 freqs_cis 形状以便广播: [1, N, 1, Hc//2]
|
|
||||||
freqs_cis = freqs_cis[None, :, None, :]
|
|
||||||
|
|
||||||
# 将实数张量视为复数: [B, N, H, Hc] -> [B, N, H, Hc//2] (复数)
|
|
||||||
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
|
|
||||||
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
|
|
||||||
|
|
||||||
# 复数乘法实现旋转
|
|
||||||
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # [B, N, H, Hc]
|
|
||||||
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
|
|
||||||
|
|
||||||
return xq_out.type_as(xq), xk_out.type_as(xk)
|
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
|
||||||
# 基础组件
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
class Embed(nn.Module):
|
|
||||||
"""线性嵌入层,将输入投影到隐藏空间。"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
in_chans: int = 3,
|
|
||||||
embed_dim: int = 768,
|
|
||||||
norm_layer: Optional[nn.Module] = None,
|
|
||||||
bias: bool = True,
|
|
||||||
):
|
|
||||||
"""初始化 Embed。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
in_chans: 输入通道数/维度。
|
|
||||||
embed_dim: 输出嵌入维度。
|
|
||||||
norm_layer: 可选的归一化层。
|
|
||||||
bias: 是否使用偏置。
|
|
||||||
"""
|
|
||||||
super().__init__()
|
|
||||||
self.in_chans = in_chans
|
|
||||||
self.embed_dim = embed_dim
|
|
||||||
self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
|
|
||||||
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
||||||
x = self.proj(x)
|
|
||||||
x = self.norm(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class TimestepEmbedder(nn.Module):
|
|
||||||
"""扩散时间步嵌入器。
|
|
||||||
|
|
||||||
使用正弦位置编码 + MLP 将标量时间步映射到高维向量。
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, hidden_size: int, frequency_embedding_size: int = 256):
|
|
||||||
"""初始化 TimestepEmbedder。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
hidden_size: 输出嵌入维度。
|
|
||||||
frequency_embedding_size: 正弦编码的维度。
|
|
||||||
"""
|
|
||||||
super().__init__()
|
|
||||||
self.mlp = nn.Sequential(
|
|
||||||
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
|
|
||||||
nn.SiLU(),
|
|
||||||
nn.Linear(hidden_size, hidden_size, bias=True),
|
|
||||||
)
|
|
||||||
self.frequency_embedding_size = frequency_embedding_size
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def timestep_embedding(t: torch.Tensor, dim: int, max_period: float = 10.0) -> torch.Tensor:
|
|
||||||
"""生成正弦时间步嵌入。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
t: 时间步张量,形状为 (B,)。
|
|
||||||
dim: 嵌入维度。
|
|
||||||
max_period: 最大周期。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
时间步嵌入,形状为 (B, dim)。
|
|
||||||
"""
|
|
||||||
half = dim // 2
|
|
||||||
freqs = torch.exp(
|
|
||||||
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
|
|
||||||
)
|
|
||||||
args = t[..., None].float() * freqs[None, ...]
|
|
||||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
|
||||||
if dim % 2:
|
|
||||||
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
|
||||||
return embedding
|
|
||||||
|
|
||||||
def forward(self, t: torch.Tensor) -> torch.Tensor:
|
|
||||||
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
|
|
||||||
t_emb = self.mlp(t_freq)
|
|
||||||
return t_emb
|
|
||||||
|
|
||||||
|
|
||||||
class ObservationEncoder(nn.Module):
|
|
||||||
"""观测状态编码器。
|
|
||||||
|
|
||||||
将机器人的观测向量(如关节位置、末端位姿、图像特征等)
|
|
||||||
编码为条件向量,用于条件扩散生成。
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
encoder: 两层 MLP 编码器。
|
|
||||||
|
|
||||||
Example:
|
|
||||||
>>> encoder = ObservationEncoder(obs_dim=128, hidden_size=512)
|
|
||||||
>>> obs = torch.randn(2, 128)
|
|
||||||
>>> cond = encoder(obs) # [2, 512]
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, obs_dim: int, hidden_size: int):
|
|
||||||
"""初始化 ObservationEncoder。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
obs_dim: 观测向量的维度。
|
|
||||||
hidden_size: 输出条件向量的维度。
|
|
||||||
"""
|
|
||||||
super().__init__()
|
|
||||||
self.encoder = nn.Sequential(
|
|
||||||
nn.Linear(obs_dim, hidden_size),
|
|
||||||
nn.SiLU(),
|
|
||||||
nn.Linear(hidden_size, hidden_size),
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, obs: torch.Tensor) -> torch.Tensor:
|
|
||||||
"""前向传播。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
obs: 观测向量,形状为 (B, obs_dim)。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
条件向量,形状为 (B, hidden_size)。
|
|
||||||
"""
|
|
||||||
return self.encoder(obs)
|
|
||||||
|
|
||||||
|
|
||||||
class FinalLayer(nn.Module):
|
|
||||||
"""最终输出层,使用 AdaLN 调制后输出预测结果。"""
|
|
||||||
|
|
||||||
def __init__(self, hidden_size: int, out_channels: int):
|
|
||||||
"""初始化 FinalLayer。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
hidden_size: 输入隐藏维度。
|
|
||||||
out_channels: 输出通道数/维度。
|
|
||||||
"""
|
|
||||||
super().__init__()
|
|
||||||
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
|
||||||
self.linear = nn.Linear(hidden_size, out_channels, bias=True)
|
|
||||||
self.adaLN_modulation = nn.Sequential(
|
|
||||||
nn.Linear(hidden_size, 2 * hidden_size, bias=True)
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
|
|
||||||
"""前向传播。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
x: 输入张量,形状为 (B, N, hidden_size)。
|
|
||||||
c: 条件张量,形状为 (B, N, hidden_size) 或 (B, 1, hidden_size)。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
输出张量,形状为 (B, N, out_channels)。
|
|
||||||
"""
|
|
||||||
shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
|
|
||||||
x = modulate(self.norm_final(x), shift, scale)
|
|
||||||
x = self.linear(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
|
||||||
# 归一化和前馈网络
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
class RMSNorm(nn.Module):
|
|
||||||
"""Root Mean Square Layer Normalization (RMS 归一化层)。
|
|
||||||
|
|
||||||
RMSNorm 是 LayerNorm 的简化版本,去掉了均值中心化操作,只保留缩放。
|
|
||||||
相比 LayerNorm 计算更快,效果相当,被广泛用于 LLaMA、Mistral 等大模型。
|
|
||||||
|
|
||||||
数学公式:
|
|
||||||
RMSNorm(x) = x / sqrt(mean(x^2) + eps) * weight
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, hidden_size: int, eps: float = 1e-6):
|
|
||||||
"""初始化 RMSNorm。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
hidden_size: 输入特征的维度。
|
|
||||||
eps: 防止除零的小常数。
|
|
||||||
"""
|
|
||||||
super().__init__()
|
|
||||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
|
||||||
self.variance_epsilon = eps
|
|
||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
|
||||||
input_dtype = hidden_states.dtype
|
|
||||||
hidden_states = hidden_states.to(torch.float32)
|
|
||||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
|
||||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
|
||||||
return self.weight * hidden_states.to(input_dtype)
|
|
||||||
|
|
||||||
|
|
||||||
class FeedForward(nn.Module):
|
|
||||||
"""SwiGLU 前馈网络 (Feed-Forward Network)。
|
|
||||||
|
|
||||||
使用 SwiGLU 门控激活函数的前馈网络,来自 LLaMA 架构。
|
|
||||||
|
|
||||||
结构:
|
|
||||||
output = W2(SiLU(W1(x)) * W3(x))
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, dim: int, hidden_dim: int):
|
|
||||||
"""初始化 FeedForward。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
dim: 输入和输出的特征维度。
|
|
||||||
hidden_dim: 隐藏层维度(实际使用 2/3 * hidden_dim)。
|
|
||||||
"""
|
|
||||||
super().__init__()
|
|
||||||
hidden_dim = int(2 * hidden_dim / 3)
|
|
||||||
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
|
|
||||||
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
|
|
||||||
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
||||||
x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
|
||||||
# 注意力机制
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
class RAttention(nn.Module):
|
|
||||||
"""带有旋转位置编码的多头自注意力 (Rotary Attention)。
|
|
||||||
|
|
||||||
集成了以下技术:
|
|
||||||
- 1D RoPE: 通过复数旋转编码时序位置信息
|
|
||||||
- QK-Norm: 对 Query 和 Key 进行归一化,稳定训练
|
|
||||||
- Flash Attention: 使用 scaled_dot_product_attention
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
dim: int,
|
|
||||||
num_heads: int = 8,
|
|
||||||
qkv_bias: bool = False,
|
|
||||||
qk_norm: bool = True,
|
|
||||||
attn_drop: float = 0.,
|
|
||||||
proj_drop: float = 0.,
|
|
||||||
norm_layer: nn.Module = RMSNorm,
|
|
||||||
) -> None:
|
|
||||||
"""初始化 RAttention。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
dim: 输入特征维度,必须能被 num_heads 整除。
|
|
||||||
num_heads: 注意力头数。
|
|
||||||
qkv_bias: QKV 投影是否使用偏置。
|
|
||||||
qk_norm: 是否对 Q, K 进行归一化。
|
|
||||||
attn_drop: 注意力权重的 dropout 率。
|
|
||||||
proj_drop: 输出投影的 dropout 率。
|
|
||||||
norm_layer: 归一化层类型。
|
|
||||||
"""
|
|
||||||
super().__init__()
|
|
||||||
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
|
|
||||||
|
|
||||||
self.dim = dim
|
|
||||||
self.num_heads = num_heads
|
|
||||||
self.head_dim = dim // num_heads
|
|
||||||
self.scale = self.head_dim ** -0.5
|
|
||||||
|
|
||||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
|
||||||
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
|
||||||
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
|
||||||
self.attn_drop = nn.Dropout(attn_drop)
|
|
||||||
self.proj = nn.Linear(dim, dim)
|
|
||||||
self.proj_drop = nn.Dropout(proj_drop)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
x: torch.Tensor,
|
|
||||||
pos: torch.Tensor,
|
|
||||||
mask: Optional[torch.Tensor] = None
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""前向传播。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
x: 输入张量,形状为 (B, N, C)。
|
|
||||||
pos: 1D RoPE 位置编码,形状为 (N, head_dim//2)。
|
|
||||||
mask: 可选的注意力掩码。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
输出张量,形状为 (B, N, C)。
|
|
||||||
"""
|
|
||||||
B, N, C = x.shape
|
|
||||||
|
|
||||||
# QKV 投影
|
|
||||||
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
|
|
||||||
q, k, v = qkv[0], qkv[1], qkv[2] # [B, N, H, Hc]
|
|
||||||
|
|
||||||
# QK-Norm
|
|
||||||
q = self.q_norm(q)
|
|
||||||
k = self.k_norm(k)
|
|
||||||
|
|
||||||
# 应用 1D RoPE
|
|
||||||
q, k = apply_rotary_emb_1d(q, k, freqs_cis=pos)
|
|
||||||
|
|
||||||
# 调整维度: [B, N, H, Hc] -> [B, H, N, Hc]
|
|
||||||
q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2)
|
|
||||||
k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
|
|
||||||
v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
|
|
||||||
|
|
||||||
# Scaled Dot-Product Attention
|
|
||||||
x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
|
|
||||||
|
|
||||||
# 输出投影
|
|
||||||
x = x.transpose(1, 2).reshape(B, N, C)
|
|
||||||
x = self.proj(x)
|
|
||||||
x = self.proj_drop(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
|
||||||
# Transformer Block
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
class ActionDDTBlock(nn.Module):
|
|
||||||
"""动作 DDT Transformer Block。
|
|
||||||
|
|
||||||
结构: Pre-Norm + AdaLN + Attention + FFN
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float = 4.0):
|
|
||||||
"""初始化 ActionDDTBlock。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
hidden_size: 隐藏层维度。
|
|
||||||
num_heads: 注意力头数。
|
|
||||||
mlp_ratio: FFN 隐藏层倍率。
|
|
||||||
"""
|
|
||||||
super().__init__()
|
|
||||||
self.norm1 = RMSNorm(hidden_size, eps=1e-6)
|
|
||||||
self.attn = RAttention(hidden_size, num_heads=num_heads, qkv_bias=False)
|
|
||||||
self.norm2 = RMSNorm(hidden_size, eps=1e-6)
|
|
||||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
|
||||||
self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
|
|
||||||
self.adaLN_modulation = nn.Sequential(
|
|
||||||
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
x: torch.Tensor,
|
|
||||||
c: torch.Tensor,
|
|
||||||
pos: torch.Tensor,
|
|
||||||
mask: Optional[torch.Tensor] = None
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""前向传播。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
x: 输入张量,形状为 (B, N, hidden_size)。
|
|
||||||
c: 条件张量,形状为 (B, 1, hidden_size) 或 (B, N, hidden_size)。
|
|
||||||
pos: 位置编码。
|
|
||||||
mask: 可选的注意力掩码。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
输出张量,形状为 (B, N, hidden_size)。
|
|
||||||
"""
|
|
||||||
# AdaLN 调制参数
|
|
||||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = \
|
|
||||||
self.adaLN_modulation(c).chunk(6, dim=-1)
|
|
||||||
|
|
||||||
# Attention 分支
|
|
||||||
x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
|
|
||||||
# FFN 分支
|
|
||||||
x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
|
||||||
# 主模型: ActionDDT
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
class ActionDDT(nn.Module):
|
|
||||||
"""动作序列解耦扩散 Transformer (Action Decoupled Diffusion Transformer)。
|
|
||||||
|
|
||||||
基于 DDT 架构,专为机器人动作序列生成设计。
|
|
||||||
将模型解耦为编码器和解码器两部分,编码器状态可缓存以加速推理。
|
|
||||||
|
|
||||||
架构:
|
|
||||||
- 编码器: 前 num_encoder_blocks 个 block,生成状态 s
|
|
||||||
- 解码器: 剩余 block,使用状态 s 对动作序列 x 去噪
|
|
||||||
- 使用 1D RoPE 进行时序位置编码
|
|
||||||
- 使用 AdaLN 注入时间步和观测条件
|
|
||||||
|
|
||||||
Args:
|
|
||||||
action_dim: 动作向量维度(如 7-DoF 机械臂为 7)。
|
|
||||||
obs_dim: 观测向量维度。
|
|
||||||
action_horizon: 预测的动作序列长度。
|
|
||||||
hidden_size: Transformer 隐藏层维度。
|
|
||||||
num_blocks: Transformer block 总数。
|
|
||||||
num_encoder_blocks: 编码器 block 数量。
|
|
||||||
num_heads: 注意力头数。
|
|
||||||
mlp_ratio: FFN 隐藏层倍率。
|
|
||||||
|
|
||||||
输入:
|
|
||||||
x (Tensor): 带噪声的动作序列,形状为 (B, T, action_dim)。
|
|
||||||
t (Tensor): 扩散时间步,形状为 (B,),取值范围 [0, 1]。
|
|
||||||
obs (Tensor): 观测条件,形状为 (B, obs_dim)。
|
|
||||||
s (Tensor, optional): 缓存的编码器状态。
|
|
||||||
|
|
||||||
输出:
|
|
||||||
x (Tensor): 预测的速度场/噪声,形状为 (B, T, action_dim)。
|
|
||||||
s (Tensor): 编码器状态,可缓存复用。
|
|
||||||
|
|
||||||
Example:
|
|
||||||
>>> model = ActionDDT(action_dim=7, obs_dim=128, action_horizon=16)
|
|
||||||
>>> x = torch.randn(2, 16, 7) # 带噪声的动作序列
|
|
||||||
>>> t = torch.rand(2) # 随机时间步
|
|
||||||
>>> obs = torch.randn(2, 128) # 观测条件
|
|
||||||
>>> out, state = model(x, t, obs)
|
|
||||||
>>> out.shape
|
|
||||||
torch.Size([2, 16, 7])
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
action_dim: int = 7,
|
|
||||||
obs_dim: int = 128,
|
|
||||||
action_horizon: int = 16,
|
|
||||||
hidden_size: int = 512,
|
|
||||||
num_blocks: int = 12,
|
|
||||||
num_encoder_blocks: int = 4,
|
|
||||||
num_heads: int = 8,
|
|
||||||
mlp_ratio: float = 4.0,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
# 保存配置
|
|
||||||
self.action_dim = action_dim
|
|
||||||
self.obs_dim = obs_dim
|
|
||||||
self.action_horizon = action_horizon
|
|
||||||
self.hidden_size = hidden_size
|
|
||||||
self.num_blocks = num_blocks
|
|
||||||
self.num_encoder_blocks = num_encoder_blocks
|
|
||||||
self.num_heads = num_heads
|
|
||||||
|
|
||||||
# 动作嵌入层
|
|
||||||
self.x_embedder = Embed(action_dim, hidden_size, bias=True)
|
|
||||||
self.s_embedder = Embed(action_dim, hidden_size, bias=True)
|
|
||||||
|
|
||||||
# 条件嵌入
|
|
||||||
self.t_embedder = TimestepEmbedder(hidden_size)
|
|
||||||
self.obs_encoder = ObservationEncoder(obs_dim, hidden_size)
|
|
||||||
|
|
||||||
# 输出层
|
|
||||||
self.final_layer = FinalLayer(hidden_size, action_dim)
|
|
||||||
|
|
||||||
# Transformer blocks
|
|
||||||
self.blocks = nn.ModuleList([
|
|
||||||
ActionDDTBlock(hidden_size, num_heads, mlp_ratio)
|
|
||||||
for _ in range(num_blocks)
|
|
||||||
])
|
|
||||||
|
|
||||||
# 预计算 1D 位置编码
|
|
||||||
pos = precompute_freqs_cis_1d(hidden_size // num_heads, action_horizon)
|
|
||||||
self.register_buffer('pos', pos)
|
|
||||||
|
|
||||||
# 初始化权重
|
|
||||||
self.initialize_weights()
|
|
||||||
|
|
||||||
def initialize_weights(self):
|
|
||||||
"""初始化模型权重。"""
|
|
||||||
# 嵌入层使用 Xavier 初始化
|
|
||||||
for embedder in [self.x_embedder, self.s_embedder]:
|
|
||||||
w = embedder.proj.weight.data
|
|
||||||
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
|
||||||
nn.init.constant_(embedder.proj.bias, 0)
|
|
||||||
|
|
||||||
# 时间步嵌入 MLP
|
|
||||||
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
|
|
||||||
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
|
|
||||||
|
|
||||||
# 观测编码器
|
|
||||||
for m in self.obs_encoder.encoder:
|
|
||||||
if isinstance(m, nn.Linear):
|
|
||||||
nn.init.normal_(m.weight, std=0.02)
|
|
||||||
if m.bias is not None:
|
|
||||||
nn.init.constant_(m.bias, 0)
|
|
||||||
|
|
||||||
# 输出层零初始化 (AdaLN-Zero)
|
|
||||||
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
|
|
||||||
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
|
|
||||||
nn.init.constant_(self.final_layer.linear.weight, 0)
|
|
||||||
nn.init.constant_(self.final_layer.linear.bias, 0)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
x: torch.Tensor,
|
|
||||||
t: torch.Tensor,
|
|
||||||
obs: torch.Tensor,
|
|
||||||
s: Optional[torch.Tensor] = None,
|
|
||||||
mask: Optional[torch.Tensor] = None,
|
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
||||||
"""前向传播。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
x: 带噪声的动作序列 [B, T, action_dim]
|
|
||||||
t: 扩散时间步 [B] 或 [B, 1],取值范围 [0, 1]
|
|
||||||
obs: 观测条件 [B, obs_dim]
|
|
||||||
s: 可选的编码器状态缓存 [B, T, hidden_size]
|
|
||||||
mask: 可选的注意力掩码
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
x: 预测的速度场/噪声 [B, T, action_dim]
|
|
||||||
s: 编码器状态 [B, T, hidden_size],可缓存复用
|
|
||||||
"""
|
|
||||||
B, T, _ = x.shape
|
|
||||||
|
|
||||||
# 1. 时间步嵌入: [B] -> [B, 1, hidden_size]
|
|
||||||
t_emb = self.t_embedder(t.view(-1)).view(B, 1, self.hidden_size)
|
|
||||||
|
|
||||||
# 2. 观测条件嵌入: [B, obs_dim] -> [B, 1, hidden_size]
|
|
||||||
obs_emb = self.obs_encoder(obs).view(B, 1, self.hidden_size)
|
|
||||||
|
|
||||||
# 3. 融合条件: c = SiLU(t + obs)
|
|
||||||
c = nn.functional.silu(t_emb + obs_emb)
|
|
||||||
|
|
||||||
# 4. 编码器部分: 生成状态 s
|
|
||||||
if s is None:
|
|
||||||
# 状态嵌入: [B, T, action_dim] -> [B, T, hidden_size]
|
|
||||||
s = self.s_embedder(x)
|
|
||||||
# 通过编码器 blocks
|
|
||||||
for i in range(self.num_encoder_blocks):
|
|
||||||
s = self.blocks[i](s, c, self.pos, mask)
|
|
||||||
# 融合时间信息
|
|
||||||
s = nn.functional.silu(t_emb + s)
|
|
||||||
|
|
||||||
# 5. 解码器部分: 去噪
|
|
||||||
# 输入嵌入: [B, T, action_dim] -> [B, T, hidden_size]
|
|
||||||
x = self.x_embedder(x)
|
|
||||||
# 通过解码器 blocks,以 s 作为条件
|
|
||||||
for i in range(self.num_encoder_blocks, self.num_blocks):
|
|
||||||
x = self.blocks[i](x, s, self.pos, None)
|
|
||||||
|
|
||||||
# 6. 最终层: [B, T, hidden_size] -> [B, T, action_dim]
|
|
||||||
x = self.final_layer(x, s)
|
|
||||||
|
|
||||||
return x, s
|
|
||||||
@@ -1,304 +0,0 @@
|
|||||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
|
||||||
"""
|
|
||||||
DDT model and criterion classes.
|
|
||||||
|
|
||||||
核心组装文件,将 Backbone、Transformer、Diffusion 组件组装为完整模型。
|
|
||||||
"""
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from torch import nn
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from .backbone import build_backbone
|
|
||||||
|
|
||||||
|
|
||||||
class SpatialSoftmax(nn.Module):
|
|
||||||
"""Spatial Softmax 层,将特征图转换为关键点坐标。
|
|
||||||
|
|
||||||
来自 Diffusion Policy,保留空间位置信息。
|
|
||||||
对每个通道计算软注意力加权的期望坐标。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
num_kp: 关键点数量(等于输入通道数)
|
|
||||||
temperature: Softmax 温度参数(可学习)
|
|
||||||
learnable_temperature: 是否学习温度参数
|
|
||||||
|
|
||||||
输入: [B, C, H, W]
|
|
||||||
输出: [B, C * 2] - 每个通道输出 (x, y) 坐标
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, num_kp: int = None, temperature: float = 1.0, learnable_temperature: bool = True):
|
|
||||||
super().__init__()
|
|
||||||
self.num_kp = num_kp
|
|
||||||
if learnable_temperature:
|
|
||||||
self.temperature = nn.Parameter(torch.ones(1) * temperature)
|
|
||||||
else:
|
|
||||||
self.register_buffer('temperature', torch.ones(1) * temperature)
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
||||||
B, C, H, W = x.shape
|
|
||||||
|
|
||||||
# 生成归一化坐标网格 [-1, 1]
|
|
||||||
pos_x = torch.linspace(-1, 1, W, device=x.device, dtype=x.dtype)
|
|
||||||
pos_y = torch.linspace(-1, 1, H, device=x.device, dtype=x.dtype)
|
|
||||||
|
|
||||||
# 展平空间维度
|
|
||||||
x_flat = x.view(B, C, -1) # [B, C, H*W]
|
|
||||||
|
|
||||||
# Softmax 得到注意力权重
|
|
||||||
attention = F.softmax(x_flat / self.temperature, dim=-1) # [B, C, H*W]
|
|
||||||
|
|
||||||
# 计算期望坐标
|
|
||||||
# pos_x: [W] -> [1, 1, W] -> repeat -> [1, 1, H*W]
|
|
||||||
pos_x_grid = pos_x.view(1, 1, 1, W).expand(1, 1, H, W).reshape(1, 1, -1)
|
|
||||||
pos_y_grid = pos_y.view(1, 1, H, 1).expand(1, 1, H, W).reshape(1, 1, -1)
|
|
||||||
|
|
||||||
# 加权求和得到期望坐标
|
|
||||||
expected_x = (attention * pos_x_grid).sum(dim=-1) # [B, C]
|
|
||||||
expected_y = (attention * pos_y_grid).sum(dim=-1) # [B, C]
|
|
||||||
|
|
||||||
# 拼接 x, y 坐标
|
|
||||||
keypoints = torch.cat([expected_x, expected_y], dim=-1) # [B, C * 2]
|
|
||||||
|
|
||||||
return keypoints
|
|
||||||
|
|
||||||
from .ddt import ActionDDT
|
|
||||||
|
|
||||||
|
|
||||||
def get_sinusoid_encoding_table(n_position, d_hid):
|
|
||||||
"""生成正弦位置编码表。"""
|
|
||||||
def get_position_angle_vec(position):
|
|
||||||
return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
|
|
||||||
|
|
||||||
sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
|
|
||||||
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
|
|
||||||
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
|
|
||||||
|
|
||||||
return torch.FloatTensor(sinusoid_table).unsqueeze(0)
|
|
||||||
|
|
||||||
|
|
||||||
class DDT(nn.Module):
|
|
||||||
"""DDT (Decoupled Diffusion Transformer) 模型。
|
|
||||||
|
|
||||||
将视觉 Backbone 和 ActionDDT 扩散模型组合,实现基于图像观测的动作序列生成。
|
|
||||||
|
|
||||||
架构:
|
|
||||||
1. Backbone: 提取多相机图像特征
|
|
||||||
2. 特征投影: 将图像特征投影到隐藏空间 (Bottleneck 降维)
|
|
||||||
3. 状态编码: 编码机器人关节状态
|
|
||||||
4. ActionDDT: 扩散 Transformer 生成动作序列
|
|
||||||
|
|
||||||
Args:
|
|
||||||
backbones: 视觉骨干网络列表(每个相机一个)
|
|
||||||
state_dim: 机器人状态维度
|
|
||||||
action_dim: 动作维度
|
|
||||||
num_queries: 预测的动作序列长度
|
|
||||||
camera_names: 相机名称列表
|
|
||||||
hidden_dim: Transformer 隐藏维度
|
|
||||||
num_blocks: Transformer block 数量
|
|
||||||
num_encoder_blocks: 编码器 block 数量
|
|
||||||
num_heads: 注意力头数
|
|
||||||
num_kp: Spatial Softmax 的关键点数量 (默认 32)
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
backbones,
|
|
||||||
state_dim: int,
|
|
||||||
action_dim: int,
|
|
||||||
num_queries: int,
|
|
||||||
camera_names: list,
|
|
||||||
hidden_dim: int = 512,
|
|
||||||
num_blocks: int = 12,
|
|
||||||
num_encoder_blocks: int = 4,
|
|
||||||
num_heads: int = 8,
|
|
||||||
mlp_ratio: float = 4.0,
|
|
||||||
num_kp: int = 32, # [修改] 新增参数,默认 32
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.num_queries = num_queries
|
|
||||||
self.camera_names = camera_names
|
|
||||||
self.hidden_dim = hidden_dim
|
|
||||||
self.state_dim = state_dim
|
|
||||||
self.action_dim = action_dim
|
|
||||||
self.num_kp = num_kp
|
|
||||||
|
|
||||||
# Backbone 相关
|
|
||||||
self.backbones = nn.ModuleList(backbones)
|
|
||||||
|
|
||||||
# [修改] 投影层: ResNet Channels -> num_kp (32)
|
|
||||||
# 这是一个 Bottleneck 层,大幅减少特征通道数
|
|
||||||
self.input_proj = nn.Conv2d(
|
|
||||||
backbones[0].num_channels, num_kp, kernel_size=1
|
|
||||||
)
|
|
||||||
|
|
||||||
# 状态编码 (2层 MLP,与 Diffusion Policy 一致)
|
|
||||||
# 状态依然映射到 hidden_dim (512),保持信息量
|
|
||||||
self.input_proj_robot_state = nn.Sequential(
|
|
||||||
nn.Linear(state_dim, hidden_dim),
|
|
||||||
nn.ReLU(),
|
|
||||||
nn.Linear(hidden_dim, hidden_dim),
|
|
||||||
)
|
|
||||||
|
|
||||||
# [修改] 图像特征聚合 (SpatialSoftmax)
|
|
||||||
# 输入: [B, num_kp, H, W]
|
|
||||||
# 输出: [B, num_kp * 2] (每个通道的 x, y 坐标)
|
|
||||||
self.img_feature_proj = SpatialSoftmax(num_kp=num_kp)
|
|
||||||
|
|
||||||
# [修改] 计算观测维度: 图像特征 + 状态
|
|
||||||
# 图像部分: 关键点数量 * 2(x,y) * 摄像头数量
|
|
||||||
img_feature_dim = num_kp * 2 * len(camera_names)
|
|
||||||
obs_dim = img_feature_dim + hidden_dim
|
|
||||||
|
|
||||||
# ActionDDT 扩散模型
|
|
||||||
self.action_ddt = ActionDDT(
|
|
||||||
action_dim=action_dim,
|
|
||||||
obs_dim=obs_dim, # 使用新的、更紧凑的维度
|
|
||||||
action_horizon=num_queries,
|
|
||||||
hidden_size=hidden_dim,
|
|
||||||
num_blocks=num_blocks,
|
|
||||||
num_encoder_blocks=num_encoder_blocks,
|
|
||||||
num_heads=num_heads,
|
|
||||||
mlp_ratio=mlp_ratio,
|
|
||||||
)
|
|
||||||
|
|
||||||
def encode_observations(self, qpos, image):
|
|
||||||
"""编码观测(图像 + 状态)为条件向量。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
qpos: 机器人关节状态 [B, state_dim]
|
|
||||||
image: 多相机图像 [B, num_cam, C, H, W]
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
obs: 观测条件向量 [B, obs_dim]
|
|
||||||
"""
|
|
||||||
bs = qpos.shape[0]
|
|
||||||
|
|
||||||
# 编码图像特征
|
|
||||||
all_cam_features = []
|
|
||||||
for cam_id, cam_name in enumerate(self.camera_names):
|
|
||||||
features, pos = self.backbones[cam_id](image[:, cam_id])
|
|
||||||
features = features[0] # 取最后一层特征
|
|
||||||
|
|
||||||
# [说明] 这里的 input_proj 现在会将通道压缩到 32
|
|
||||||
features = self.input_proj(features) # [B, num_kp, H', W']
|
|
||||||
|
|
||||||
# [说明] SpatialSoftmax 提取 32 个关键点坐标
|
|
||||||
features = self.img_feature_proj(features) # [B, num_kp * 2]
|
|
||||||
|
|
||||||
all_cam_features.append(features)
|
|
||||||
|
|
||||||
# 拼接所有相机特征
|
|
||||||
img_features = torch.cat(all_cam_features, dim=-1) # [B, num_kp * 2 * num_cam]
|
|
||||||
|
|
||||||
# 编码状态
|
|
||||||
qpos_features = self.input_proj_robot_state(qpos) # [B, hidden_dim]
|
|
||||||
|
|
||||||
# 拼接观测
|
|
||||||
obs = torch.cat([img_features, qpos_features], dim=-1) # [B, obs_dim]
|
|
||||||
|
|
||||||
return obs
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
qpos,
|
|
||||||
image,
|
|
||||||
env_state,
|
|
||||||
actions=None,
|
|
||||||
is_pad=None,
|
|
||||||
timesteps=None,
|
|
||||||
):
|
|
||||||
"""前向传播。
|
|
||||||
|
|
||||||
训练时:
|
|
||||||
输入带噪声的动作序列和时间步,预测噪声/速度场
|
|
||||||
推理时:
|
|
||||||
通过扩散采样生成动作序列
|
|
||||||
|
|
||||||
Args:
|
|
||||||
qpos: 机器人关节状态 [B, state_dim]
|
|
||||||
image: 多相机图像 [B, num_cam, C, H, W]
|
|
||||||
env_state: 环境状态(未使用)
|
|
||||||
actions: 动作序列 [B, T, action_dim](训练时为带噪声动作)
|
|
||||||
is_pad: padding 标记 [B, T](未使用)
|
|
||||||
timesteps: 扩散时间步 [B](训练时提供)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
训练时: (noise_pred, encoder_state)
|
|
||||||
推理时: (action_pred, encoder_state)
|
|
||||||
"""
|
|
||||||
# 1. 编码观测
|
|
||||||
obs = self.encode_observations(qpos, image)
|
|
||||||
|
|
||||||
# 2. 扩散模型前向
|
|
||||||
if actions is not None and timesteps is not None:
|
|
||||||
# 训练模式: 预测噪声
|
|
||||||
noise_pred, encoder_state = self.action_ddt(
|
|
||||||
x=actions,
|
|
||||||
t=timesteps,
|
|
||||||
obs=obs,
|
|
||||||
)
|
|
||||||
return noise_pred, encoder_state
|
|
||||||
else:
|
|
||||||
# 推理模式: 需要在 Policy 层进行扩散采样
|
|
||||||
# 这里返回编码的观测,供 Policy 层使用
|
|
||||||
return obs, None
|
|
||||||
|
|
||||||
def get_obs_dim(self):
|
|
||||||
"""返回观测向量的维度。"""
|
|
||||||
# [修改] 使用 num_kp 重新计算
|
|
||||||
return self.num_kp * 2 * len(self.camera_names) + self.hidden_dim
|
|
||||||
|
|
||||||
|
|
||||||
def build(args):
|
|
||||||
"""构建 DDT 模型。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
args: 包含模型配置的参数对象
|
|
||||||
- state_dim: 状态维度
|
|
||||||
- action_dim: 动作维度
|
|
||||||
- camera_names: 相机名称列表
|
|
||||||
- hidden_dim: 隐藏维度
|
|
||||||
- num_queries: 动作序列长度
|
|
||||||
- num_blocks: Transformer block 数量
|
|
||||||
- enc_layers: 编码器层数
|
|
||||||
- nheads: 注意力头数
|
|
||||||
- num_kp: 关键点数量 (可选,默认32)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
model: DDT 模型实例
|
|
||||||
"""
|
|
||||||
state_dim = args.state_dim
|
|
||||||
action_dim = args.action_dim
|
|
||||||
|
|
||||||
# 构建 Backbone(每个相机一个)
|
|
||||||
backbones = []
|
|
||||||
for _ in args.camera_names:
|
|
||||||
backbone = build_backbone(args)
|
|
||||||
backbones.append(backbone)
|
|
||||||
|
|
||||||
# 构建 DDT 模型
|
|
||||||
model = DDT(
|
|
||||||
backbones=backbones,
|
|
||||||
state_dim=state_dim,
|
|
||||||
action_dim=action_dim,
|
|
||||||
num_queries=args.num_queries,
|
|
||||||
camera_names=args.camera_names,
|
|
||||||
hidden_dim=args.hidden_dim,
|
|
||||||
num_blocks=getattr(args, 'num_blocks', 12),
|
|
||||||
num_encoder_blocks=getattr(args, 'enc_layers', 4),
|
|
||||||
num_heads=args.nheads,
|
|
||||||
mlp_ratio=getattr(args, 'mlp_ratio', 4.0),
|
|
||||||
num_kp=getattr(args, 'num_kp', 32), # [修改] 传递 num_kp 参数
|
|
||||||
)
|
|
||||||
|
|
||||||
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
|
||||||
print("number of parameters: %.2fM" % (n_parameters / 1e6,))
|
|
||||||
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def build_ddt(args):
|
|
||||||
"""build 的别名,保持接口一致性。"""
|
|
||||||
return build(args)
|
|
||||||
@@ -1,312 +0,0 @@
|
|||||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
|
||||||
"""
|
|
||||||
DETR Transformer class.
|
|
||||||
|
|
||||||
Copy-paste from torch.nn.Transformer with modifications:
|
|
||||||
* positional encodings are passed in MHattention
|
|
||||||
* extra LN at the end of encoder is removed
|
|
||||||
* decoder returns a stack of activations from all decoding layers
|
|
||||||
"""
|
|
||||||
import copy
|
|
||||||
from typing import Optional, List
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from torch import nn, Tensor
|
|
||||||
|
|
||||||
|
|
||||||
class Transformer(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
|
|
||||||
num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
|
|
||||||
activation="relu", normalize_before=False,
|
|
||||||
return_intermediate_dec=False):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
|
|
||||||
dropout, activation, normalize_before)
|
|
||||||
encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
|
|
||||||
self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
|
|
||||||
|
|
||||||
decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward,
|
|
||||||
dropout, activation, normalize_before)
|
|
||||||
decoder_norm = nn.LayerNorm(d_model)
|
|
||||||
self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,
|
|
||||||
return_intermediate=return_intermediate_dec)
|
|
||||||
|
|
||||||
self._reset_parameters()
|
|
||||||
|
|
||||||
self.d_model = d_model
|
|
||||||
self.nhead = nhead
|
|
||||||
|
|
||||||
def _reset_parameters(self):
|
|
||||||
for p in self.parameters():
|
|
||||||
if p.dim() > 1:
|
|
||||||
nn.init.xavier_uniform_(p)
|
|
||||||
|
|
||||||
def forward(self, src, mask, query_embed, pos_embed, latent_input=None, proprio_input=None, additional_pos_embed=None):
|
|
||||||
# TODO flatten only when input has H and W
|
|
||||||
if len(src.shape) == 4: # has H and W
|
|
||||||
# flatten NxCxHxW to HWxNxC
|
|
||||||
bs, c, h, w = src.shape
|
|
||||||
src = src.flatten(2).permute(2, 0, 1)
|
|
||||||
pos_embed = pos_embed.flatten(2).permute(2, 0, 1).repeat(1, bs, 1)
|
|
||||||
query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
|
|
||||||
# mask = mask.flatten(1)
|
|
||||||
|
|
||||||
additional_pos_embed = additional_pos_embed.unsqueeze(1).repeat(1, bs, 1) # seq, bs, dim
|
|
||||||
pos_embed = torch.cat([additional_pos_embed, pos_embed], axis=0)
|
|
||||||
|
|
||||||
addition_input = torch.stack([latent_input, proprio_input], axis=0)
|
|
||||||
src = torch.cat([addition_input, src], axis=0)
|
|
||||||
else:
|
|
||||||
assert len(src.shape) == 3
|
|
||||||
# flatten NxHWxC to HWxNxC
|
|
||||||
bs, hw, c = src.shape
|
|
||||||
src = src.permute(1, 0, 2)
|
|
||||||
pos_embed = pos_embed.unsqueeze(1).repeat(1, bs, 1)
|
|
||||||
query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
|
|
||||||
|
|
||||||
tgt = torch.zeros_like(query_embed)
|
|
||||||
memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
|
|
||||||
hs = self.decoder(tgt, memory, memory_key_padding_mask=mask,
|
|
||||||
pos=pos_embed, query_pos=query_embed)
|
|
||||||
hs = hs.transpose(1, 2)
|
|
||||||
return hs
|
|
||||||
|
|
||||||
class TransformerEncoder(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self, encoder_layer, num_layers, norm=None):
|
|
||||||
super().__init__()
|
|
||||||
self.layers = _get_clones(encoder_layer, num_layers)
|
|
||||||
self.num_layers = num_layers
|
|
||||||
self.norm = norm
|
|
||||||
|
|
||||||
def forward(self, src,
|
|
||||||
mask: Optional[Tensor] = None,
|
|
||||||
src_key_padding_mask: Optional[Tensor] = None,
|
|
||||||
pos: Optional[Tensor] = None):
|
|
||||||
output = src
|
|
||||||
|
|
||||||
for layer in self.layers:
|
|
||||||
output = layer(output, src_mask=mask,
|
|
||||||
src_key_padding_mask=src_key_padding_mask, pos=pos)
|
|
||||||
|
|
||||||
if self.norm is not None:
|
|
||||||
output = self.norm(output)
|
|
||||||
|
|
||||||
return output
|
|
||||||
|
|
||||||
|
|
||||||
class TransformerDecoder(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
|
|
||||||
super().__init__()
|
|
||||||
self.layers = _get_clones(decoder_layer, num_layers)
|
|
||||||
self.num_layers = num_layers
|
|
||||||
self.norm = norm
|
|
||||||
self.return_intermediate = return_intermediate
|
|
||||||
|
|
||||||
def forward(self, tgt, memory,
|
|
||||||
tgt_mask: Optional[Tensor] = None,
|
|
||||||
memory_mask: Optional[Tensor] = None,
|
|
||||||
tgt_key_padding_mask: Optional[Tensor] = None,
|
|
||||||
memory_key_padding_mask: Optional[Tensor] = None,
|
|
||||||
pos: Optional[Tensor] = None,
|
|
||||||
query_pos: Optional[Tensor] = None):
|
|
||||||
output = tgt
|
|
||||||
|
|
||||||
intermediate = []
|
|
||||||
|
|
||||||
for layer in self.layers:
|
|
||||||
output = layer(output, memory, tgt_mask=tgt_mask,
|
|
||||||
memory_mask=memory_mask,
|
|
||||||
tgt_key_padding_mask=tgt_key_padding_mask,
|
|
||||||
memory_key_padding_mask=memory_key_padding_mask,
|
|
||||||
pos=pos, query_pos=query_pos)
|
|
||||||
if self.return_intermediate:
|
|
||||||
intermediate.append(self.norm(output))
|
|
||||||
|
|
||||||
if self.norm is not None:
|
|
||||||
output = self.norm(output)
|
|
||||||
if self.return_intermediate:
|
|
||||||
intermediate.pop()
|
|
||||||
intermediate.append(output)
|
|
||||||
|
|
||||||
if self.return_intermediate:
|
|
||||||
return torch.stack(intermediate)
|
|
||||||
|
|
||||||
return output.unsqueeze(0)
|
|
||||||
|
|
||||||
|
|
||||||
class TransformerEncoderLayer(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
|
|
||||||
activation="relu", normalize_before=False):
|
|
||||||
super().__init__()
|
|
||||||
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
|
||||||
# Implementation of Feedforward model
|
|
||||||
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
|
||||||
self.dropout = nn.Dropout(dropout)
|
|
||||||
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
|
||||||
|
|
||||||
self.norm1 = nn.LayerNorm(d_model)
|
|
||||||
self.norm2 = nn.LayerNorm(d_model)
|
|
||||||
self.dropout1 = nn.Dropout(dropout)
|
|
||||||
self.dropout2 = nn.Dropout(dropout)
|
|
||||||
|
|
||||||
self.activation = _get_activation_fn(activation)
|
|
||||||
self.normalize_before = normalize_before
|
|
||||||
|
|
||||||
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
|
|
||||||
return tensor if pos is None else tensor + pos
|
|
||||||
|
|
||||||
def forward_post(self,
|
|
||||||
src,
|
|
||||||
src_mask: Optional[Tensor] = None,
|
|
||||||
src_key_padding_mask: Optional[Tensor] = None,
|
|
||||||
pos: Optional[Tensor] = None):
|
|
||||||
q = k = self.with_pos_embed(src, pos)
|
|
||||||
src2 = self.self_attn(q, k, value=src, attn_mask=src_mask,
|
|
||||||
key_padding_mask=src_key_padding_mask)[0]
|
|
||||||
src = src + self.dropout1(src2)
|
|
||||||
src = self.norm1(src)
|
|
||||||
src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
|
|
||||||
src = src + self.dropout2(src2)
|
|
||||||
src = self.norm2(src)
|
|
||||||
return src
|
|
||||||
|
|
||||||
def forward_pre(self, src,
|
|
||||||
src_mask: Optional[Tensor] = None,
|
|
||||||
src_key_padding_mask: Optional[Tensor] = None,
|
|
||||||
pos: Optional[Tensor] = None):
|
|
||||||
src2 = self.norm1(src)
|
|
||||||
q = k = self.with_pos_embed(src2, pos)
|
|
||||||
src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask,
|
|
||||||
key_padding_mask=src_key_padding_mask)[0]
|
|
||||||
src = src + self.dropout1(src2)
|
|
||||||
src2 = self.norm2(src)
|
|
||||||
src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
|
|
||||||
src = src + self.dropout2(src2)
|
|
||||||
return src
|
|
||||||
|
|
||||||
def forward(self, src,
|
|
||||||
src_mask: Optional[Tensor] = None,
|
|
||||||
src_key_padding_mask: Optional[Tensor] = None,
|
|
||||||
pos: Optional[Tensor] = None):
|
|
||||||
if self.normalize_before:
|
|
||||||
return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
|
|
||||||
return self.forward_post(src, src_mask, src_key_padding_mask, pos)
|
|
||||||
|
|
||||||
|
|
||||||
class TransformerDecoderLayer(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
|
|
||||||
activation="relu", normalize_before=False):
|
|
||||||
super().__init__()
|
|
||||||
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
|
||||||
self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
|
||||||
# Implementation of Feedforward model
|
|
||||||
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
|
||||||
self.dropout = nn.Dropout(dropout)
|
|
||||||
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
|
||||||
|
|
||||||
self.norm1 = nn.LayerNorm(d_model)
|
|
||||||
self.norm2 = nn.LayerNorm(d_model)
|
|
||||||
self.norm3 = nn.LayerNorm(d_model)
|
|
||||||
self.dropout1 = nn.Dropout(dropout)
|
|
||||||
self.dropout2 = nn.Dropout(dropout)
|
|
||||||
self.dropout3 = nn.Dropout(dropout)
|
|
||||||
|
|
||||||
self.activation = _get_activation_fn(activation)
|
|
||||||
self.normalize_before = normalize_before
|
|
||||||
|
|
||||||
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
|
|
||||||
return tensor if pos is None else tensor + pos
|
|
||||||
|
|
||||||
def forward_post(self, tgt, memory,
|
|
||||||
tgt_mask: Optional[Tensor] = None,
|
|
||||||
memory_mask: Optional[Tensor] = None,
|
|
||||||
tgt_key_padding_mask: Optional[Tensor] = None,
|
|
||||||
memory_key_padding_mask: Optional[Tensor] = None,
|
|
||||||
pos: Optional[Tensor] = None,
|
|
||||||
query_pos: Optional[Tensor] = None):
|
|
||||||
q = k = self.with_pos_embed(tgt, query_pos)
|
|
||||||
tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
|
|
||||||
key_padding_mask=tgt_key_padding_mask)[0]
|
|
||||||
tgt = tgt + self.dropout1(tgt2)
|
|
||||||
tgt = self.norm1(tgt)
|
|
||||||
tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
|
|
||||||
key=self.with_pos_embed(memory, pos),
|
|
||||||
value=memory, attn_mask=memory_mask,
|
|
||||||
key_padding_mask=memory_key_padding_mask)[0]
|
|
||||||
tgt = tgt + self.dropout2(tgt2)
|
|
||||||
tgt = self.norm2(tgt)
|
|
||||||
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
|
|
||||||
tgt = tgt + self.dropout3(tgt2)
|
|
||||||
tgt = self.norm3(tgt)
|
|
||||||
return tgt
|
|
||||||
|
|
||||||
def forward_pre(self, tgt, memory,
|
|
||||||
tgt_mask: Optional[Tensor] = None,
|
|
||||||
memory_mask: Optional[Tensor] = None,
|
|
||||||
tgt_key_padding_mask: Optional[Tensor] = None,
|
|
||||||
memory_key_padding_mask: Optional[Tensor] = None,
|
|
||||||
pos: Optional[Tensor] = None,
|
|
||||||
query_pos: Optional[Tensor] = None):
|
|
||||||
tgt2 = self.norm1(tgt)
|
|
||||||
q = k = self.with_pos_embed(tgt2, query_pos)
|
|
||||||
tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
|
|
||||||
key_padding_mask=tgt_key_padding_mask)[0]
|
|
||||||
tgt = tgt + self.dropout1(tgt2)
|
|
||||||
tgt2 = self.norm2(tgt)
|
|
||||||
tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
|
|
||||||
key=self.with_pos_embed(memory, pos),
|
|
||||||
value=memory, attn_mask=memory_mask,
|
|
||||||
key_padding_mask=memory_key_padding_mask)[0]
|
|
||||||
tgt = tgt + self.dropout2(tgt2)
|
|
||||||
tgt2 = self.norm3(tgt)
|
|
||||||
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
|
|
||||||
tgt = tgt + self.dropout3(tgt2)
|
|
||||||
return tgt
|
|
||||||
|
|
||||||
def forward(self, tgt, memory,
|
|
||||||
tgt_mask: Optional[Tensor] = None,
|
|
||||||
memory_mask: Optional[Tensor] = None,
|
|
||||||
tgt_key_padding_mask: Optional[Tensor] = None,
|
|
||||||
memory_key_padding_mask: Optional[Tensor] = None,
|
|
||||||
pos: Optional[Tensor] = None,
|
|
||||||
query_pos: Optional[Tensor] = None):
|
|
||||||
if self.normalize_before:
|
|
||||||
return self.forward_pre(tgt, memory, tgt_mask, memory_mask,
|
|
||||||
tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
|
|
||||||
return self.forward_post(tgt, memory, tgt_mask, memory_mask,
|
|
||||||
tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
|
|
||||||
|
|
||||||
|
|
||||||
def _get_clones(module, N):
|
|
||||||
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
|
|
||||||
|
|
||||||
|
|
||||||
def build_transformer(args):
|
|
||||||
return Transformer(
|
|
||||||
d_model=args.hidden_dim,
|
|
||||||
dropout=args.dropout,
|
|
||||||
nhead=args.nheads,
|
|
||||||
dim_feedforward=args.dim_feedforward,
|
|
||||||
num_encoder_layers=args.enc_layers,
|
|
||||||
num_decoder_layers=args.dec_layers,
|
|
||||||
normalize_before=args.pre_norm,
|
|
||||||
return_intermediate_dec=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _get_activation_fn(activation):
|
|
||||||
"""Return an activation function given a string"""
|
|
||||||
if activation == "relu":
|
|
||||||
return F.relu
|
|
||||||
if activation == "gelu":
|
|
||||||
return F.gelu
|
|
||||||
if activation == "glu":
|
|
||||||
return F.glu
|
|
||||||
raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
|
|
||||||
@@ -1,147 +0,0 @@
|
|||||||
"""
|
|
||||||
DDT Policy - 基于扩散模型的动作生成策略。
|
|
||||||
|
|
||||||
支持 Flow Matching 训练和推理。
|
|
||||||
"""
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
from torch.nn import functional as F
|
|
||||||
import torchvision.transforms as transforms
|
|
||||||
from torchvision.transforms import v2
|
|
||||||
import math
|
|
||||||
|
|
||||||
from roboimi.ddt.main import build_DDT_model_and_optimizer
|
|
||||||
|
|
||||||
|
|
||||||
class DDTPolicy(nn.Module):
|
|
||||||
"""DDT (Decoupled Diffusion Transformer) 策略。
|
|
||||||
|
|
||||||
使用 Flow Matching 进行训练,支持多步扩散采样推理。
|
|
||||||
带数据增强,适配 DINOv2 等 ViT backbone。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
args_override: 配置参数字典
|
|
||||||
- num_inference_steps: 推理时的扩散步数
|
|
||||||
- qpos_noise_std: qpos 噪声标准差(训练时数据增强)
|
|
||||||
- patch_h, patch_w: 图像 patch 数量(用于计算目标尺寸)
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, args_override):
|
|
||||||
super().__init__()
|
|
||||||
model, optimizer = build_DDT_model_and_optimizer(args_override)
|
|
||||||
self.model = model
|
|
||||||
self.optimizer = optimizer
|
|
||||||
|
|
||||||
self.num_inference_steps = args_override.get('num_inference_steps', 10)
|
|
||||||
self.qpos_noise_std = args_override.get('qpos_noise_std', 0.0)
|
|
||||||
|
|
||||||
# 图像尺寸配置 (适配 DINOv2)
|
|
||||||
self.patch_h = args_override.get('patch_h', 16)
|
|
||||||
self.patch_w = args_override.get('patch_w', 22)
|
|
||||||
|
|
||||||
print(f'DDT Policy: {self.num_inference_steps} steps, '
|
|
||||||
f'image size ({self.patch_h*14}, {self.patch_w*14})')
|
|
||||||
|
|
||||||
def __call__(self, qpos, image, actions=None, is_pad=None):
|
|
||||||
"""前向传播。
|
|
||||||
|
|
||||||
训练时: 使用 Flow Matching 损失
|
|
||||||
推理时: 通过扩散采样生成动作
|
|
||||||
|
|
||||||
Args:
|
|
||||||
qpos: 机器人关节状态 [B, state_dim]
|
|
||||||
image: 多相机图像 [B, num_cam, C, H, W]
|
|
||||||
actions: 目标动作序列 [B, T, action_dim](训练时提供)
|
|
||||||
is_pad: padding 标记 [B, T]
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
训练时: loss_dict
|
|
||||||
推理时: 预测的动作序列 [B, T, action_dim]
|
|
||||||
"""
|
|
||||||
env_state = None
|
|
||||||
|
|
||||||
# 图像预处理
|
|
||||||
if actions is not None: # 训练时:数据增强
|
|
||||||
transform = v2.Compose([
|
|
||||||
v2.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),
|
|
||||||
v2.RandomPerspective(distortion_scale=0.5),
|
|
||||||
v2.RandomAffine(degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1)),
|
|
||||||
v2.GaussianBlur(kernel_size=(9, 9), sigma=(0.1, 2.0)),
|
|
||||||
v2.Resize((self.patch_h * 14, self.patch_w * 14)),
|
|
||||||
v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
|
|
||||||
])
|
|
||||||
if self.qpos_noise_std > 0:
|
|
||||||
qpos = qpos + (self.qpos_noise_std ** 0.5) * torch.randn_like(qpos)
|
|
||||||
else: # 推理时
|
|
||||||
transform = v2.Compose([
|
|
||||||
v2.Resize((self.patch_h * 14, self.patch_w * 14)),
|
|
||||||
v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
|
|
||||||
])
|
|
||||||
|
|
||||||
image = transform(image)
|
|
||||||
|
|
||||||
if actions is not None:
|
|
||||||
actions = actions[:, :self.model.num_queries]
|
|
||||||
is_pad = is_pad[:, :self.model.num_queries]
|
|
||||||
loss_dict = self._compute_loss(qpos, image, actions, is_pad)
|
|
||||||
return loss_dict
|
|
||||||
else:
|
|
||||||
a_hat = self._sample(qpos, image)
|
|
||||||
return a_hat
|
|
||||||
|
|
||||||
def _compute_loss(self, qpos, image, actions, is_pad):
|
|
||||||
"""计算 Flow Matching 损失。
|
|
||||||
|
|
||||||
Flow Matching 目标: 学习从噪声到数据的向量场
|
|
||||||
损失: ||v_theta(x_t, t) - (x_1 - x_0)||^2
|
|
||||||
其中 x_t = (1-t)*x_0 + t*x_1, x_0 是噪声, x_1 是目标动作
|
|
||||||
"""
|
|
||||||
B, T, action_dim = actions.shape
|
|
||||||
device = actions.device
|
|
||||||
|
|
||||||
t = torch.rand(B, device=device)
|
|
||||||
noise = torch.randn_like(actions)
|
|
||||||
|
|
||||||
t_expand = t.view(B, 1, 1).expand(B, T, action_dim)
|
|
||||||
x_t = (1 - t_expand) * noise + t_expand * actions
|
|
||||||
target_velocity = actions - noise
|
|
||||||
|
|
||||||
pred_velocity, _ = self.model(
|
|
||||||
qpos=qpos,
|
|
||||||
image=image,
|
|
||||||
env_state=None,
|
|
||||||
actions=x_t,
|
|
||||||
timesteps=t,
|
|
||||||
)
|
|
||||||
|
|
||||||
all_loss = F.mse_loss(pred_velocity, target_velocity, reduction='none')
|
|
||||||
loss = (all_loss * ~is_pad.unsqueeze(-1)).mean()
|
|
||||||
|
|
||||||
return {'flow_loss': loss, 'loss': loss}
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def _sample(self, qpos, image):
|
|
||||||
"""通过 ODE 求解进行扩散采样。
|
|
||||||
|
|
||||||
使用 Euler 方法从 t=0 积分到 t=1:
|
|
||||||
x_{t+dt} = x_t + v_theta(x_t, t) * dt
|
|
||||||
"""
|
|
||||||
B = qpos.shape[0]
|
|
||||||
T = self.model.num_queries
|
|
||||||
action_dim = self.model.action_dim
|
|
||||||
device = qpos.device
|
|
||||||
|
|
||||||
x = torch.randn(B, T, action_dim, device=device)
|
|
||||||
obs = self.model.encode_observations(qpos, image)
|
|
||||||
|
|
||||||
dt = 1.0 / self.num_inference_steps
|
|
||||||
for i in range(self.num_inference_steps):
|
|
||||||
t = torch.full((B,), i * dt, device=device)
|
|
||||||
velocity, _ = self.model.action_ddt(x=x, t=t, obs=obs)
|
|
||||||
x = x + velocity * dt
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
def configure_optimizers(self):
|
|
||||||
"""返回优化器。"""
|
|
||||||
return self.optimizer
|
|
||||||
@@ -1 +0,0 @@
|
|||||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
|
||||||
@@ -1,88 +0,0 @@
|
|||||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
|
||||||
"""
|
|
||||||
Utilities for bounding box manipulation and GIoU.
|
|
||||||
"""
|
|
||||||
import torch
|
|
||||||
from torchvision.ops.boxes import box_area
|
|
||||||
|
|
||||||
|
|
||||||
def box_cxcywh_to_xyxy(x):
|
|
||||||
x_c, y_c, w, h = x.unbind(-1)
|
|
||||||
b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
|
|
||||||
(x_c + 0.5 * w), (y_c + 0.5 * h)]
|
|
||||||
return torch.stack(b, dim=-1)
|
|
||||||
|
|
||||||
|
|
||||||
def box_xyxy_to_cxcywh(x):
|
|
||||||
x0, y0, x1, y1 = x.unbind(-1)
|
|
||||||
b = [(x0 + x1) / 2, (y0 + y1) / 2,
|
|
||||||
(x1 - x0), (y1 - y0)]
|
|
||||||
return torch.stack(b, dim=-1)
|
|
||||||
|
|
||||||
|
|
||||||
# modified from torchvision to also return the union
|
|
||||||
def box_iou(boxes1, boxes2):
|
|
||||||
area1 = box_area(boxes1)
|
|
||||||
area2 = box_area(boxes2)
|
|
||||||
|
|
||||||
lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
|
|
||||||
rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
|
|
||||||
|
|
||||||
wh = (rb - lt).clamp(min=0) # [N,M,2]
|
|
||||||
inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
|
|
||||||
|
|
||||||
union = area1[:, None] + area2 - inter
|
|
||||||
|
|
||||||
iou = inter / union
|
|
||||||
return iou, union
|
|
||||||
|
|
||||||
|
|
||||||
def generalized_box_iou(boxes1, boxes2):
|
|
||||||
"""
|
|
||||||
Generalized IoU from https://giou.stanford.edu/
|
|
||||||
|
|
||||||
The boxes should be in [x0, y0, x1, y1] format
|
|
||||||
|
|
||||||
Returns a [N, M] pairwise matrix, where N = len(boxes1)
|
|
||||||
and M = len(boxes2)
|
|
||||||
"""
|
|
||||||
# degenerate boxes gives inf / nan results
|
|
||||||
# so do an early check
|
|
||||||
assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
|
|
||||||
assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
|
|
||||||
iou, union = box_iou(boxes1, boxes2)
|
|
||||||
|
|
||||||
lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
|
|
||||||
rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
|
|
||||||
|
|
||||||
wh = (rb - lt).clamp(min=0) # [N,M,2]
|
|
||||||
area = wh[:, :, 0] * wh[:, :, 1]
|
|
||||||
|
|
||||||
return iou - (area - union) / area
|
|
||||||
|
|
||||||
|
|
||||||
def masks_to_boxes(masks):
|
|
||||||
"""Compute the bounding boxes around the provided masks
|
|
||||||
|
|
||||||
The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions.
|
|
||||||
|
|
||||||
Returns a [N, 4] tensors, with the boxes in xyxy format
|
|
||||||
"""
|
|
||||||
if masks.numel() == 0:
|
|
||||||
return torch.zeros((0, 4), device=masks.device)
|
|
||||||
|
|
||||||
h, w = masks.shape[-2:]
|
|
||||||
|
|
||||||
y = torch.arange(0, h, dtype=torch.float)
|
|
||||||
x = torch.arange(0, w, dtype=torch.float)
|
|
||||||
y, x = torch.meshgrid(y, x)
|
|
||||||
|
|
||||||
x_mask = (masks * x.unsqueeze(0))
|
|
||||||
x_max = x_mask.flatten(1).max(-1)[0]
|
|
||||||
x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
|
|
||||||
|
|
||||||
y_mask = (masks * y.unsqueeze(0))
|
|
||||||
y_max = y_mask.flatten(1).max(-1)[0]
|
|
||||||
y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
|
|
||||||
|
|
||||||
return torch.stack([x_min, y_min, x_max, y_max], 1)
|
|
||||||
@@ -1,468 +0,0 @@
|
|||||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
|
||||||
"""
|
|
||||||
Misc functions, including distributed helpers.
|
|
||||||
|
|
||||||
Mostly copy-paste from torchvision references.
|
|
||||||
"""
|
|
||||||
import os
|
|
||||||
import subprocess
|
|
||||||
import time
|
|
||||||
from collections import defaultdict, deque
|
|
||||||
import datetime
|
|
||||||
import pickle
|
|
||||||
from packaging import version
|
|
||||||
from typing import Optional, List
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.distributed as dist
|
|
||||||
from torch import Tensor
|
|
||||||
|
|
||||||
# needed due to empty tensor bug in pytorch and torchvision 0.5
|
|
||||||
import torchvision
|
|
||||||
if version.parse(torchvision.__version__) < version.parse('0.7'):
|
|
||||||
from torchvision.ops import _new_empty_tensor
|
|
||||||
from torchvision.ops.misc import _output_size
|
|
||||||
|
|
||||||
|
|
||||||
class SmoothedValue(object):
|
|
||||||
"""Track a series of values and provide access to smoothed values over a
|
|
||||||
window or the global series average.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, window_size=20, fmt=None):
|
|
||||||
if fmt is None:
|
|
||||||
fmt = "{median:.4f} ({global_avg:.4f})"
|
|
||||||
self.deque = deque(maxlen=window_size)
|
|
||||||
self.total = 0.0
|
|
||||||
self.count = 0
|
|
||||||
self.fmt = fmt
|
|
||||||
|
|
||||||
def update(self, value, n=1):
|
|
||||||
self.deque.append(value)
|
|
||||||
self.count += n
|
|
||||||
self.total += value * n
|
|
||||||
|
|
||||||
def synchronize_between_processes(self):
|
|
||||||
"""
|
|
||||||
Warning: does not synchronize the deque!
|
|
||||||
"""
|
|
||||||
if not is_dist_avail_and_initialized():
|
|
||||||
return
|
|
||||||
t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
|
|
||||||
dist.barrier()
|
|
||||||
dist.all_reduce(t)
|
|
||||||
t = t.tolist()
|
|
||||||
self.count = int(t[0])
|
|
||||||
self.total = t[1]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def median(self):
|
|
||||||
d = torch.tensor(list(self.deque))
|
|
||||||
return d.median().item()
|
|
||||||
|
|
||||||
@property
|
|
||||||
def avg(self):
|
|
||||||
d = torch.tensor(list(self.deque), dtype=torch.float32)
|
|
||||||
return d.mean().item()
|
|
||||||
|
|
||||||
@property
|
|
||||||
def global_avg(self):
|
|
||||||
return self.total / self.count
|
|
||||||
|
|
||||||
@property
|
|
||||||
def max(self):
|
|
||||||
return max(self.deque)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def value(self):
|
|
||||||
return self.deque[-1]
|
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
return self.fmt.format(
|
|
||||||
median=self.median,
|
|
||||||
avg=self.avg,
|
|
||||||
global_avg=self.global_avg,
|
|
||||||
max=self.max,
|
|
||||||
value=self.value)
|
|
||||||
|
|
||||||
|
|
||||||
def all_gather(data):
|
|
||||||
"""
|
|
||||||
Run all_gather on arbitrary picklable data (not necessarily tensors)
|
|
||||||
Args:
|
|
||||||
data: any picklable object
|
|
||||||
Returns:
|
|
||||||
list[data]: list of data gathered from each rank
|
|
||||||
"""
|
|
||||||
world_size = get_world_size()
|
|
||||||
if world_size == 1:
|
|
||||||
return [data]
|
|
||||||
|
|
||||||
# serialized to a Tensor
|
|
||||||
buffer = pickle.dumps(data)
|
|
||||||
storage = torch.ByteStorage.from_buffer(buffer)
|
|
||||||
tensor = torch.ByteTensor(storage).to("cuda")
|
|
||||||
|
|
||||||
# obtain Tensor size of each rank
|
|
||||||
local_size = torch.tensor([tensor.numel()], device="cuda")
|
|
||||||
size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
|
|
||||||
dist.all_gather(size_list, local_size)
|
|
||||||
size_list = [int(size.item()) for size in size_list]
|
|
||||||
max_size = max(size_list)
|
|
||||||
|
|
||||||
# receiving Tensor from all ranks
|
|
||||||
# we pad the tensor because torch all_gather does not support
|
|
||||||
# gathering tensors of different shapes
|
|
||||||
tensor_list = []
|
|
||||||
for _ in size_list:
|
|
||||||
tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
|
|
||||||
if local_size != max_size:
|
|
||||||
padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
|
|
||||||
tensor = torch.cat((tensor, padding), dim=0)
|
|
||||||
dist.all_gather(tensor_list, tensor)
|
|
||||||
|
|
||||||
data_list = []
|
|
||||||
for size, tensor in zip(size_list, tensor_list):
|
|
||||||
buffer = tensor.cpu().numpy().tobytes()[:size]
|
|
||||||
data_list.append(pickle.loads(buffer))
|
|
||||||
|
|
||||||
return data_list
|
|
||||||
|
|
||||||
|
|
||||||
def reduce_dict(input_dict, average=True):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
input_dict (dict): all the values will be reduced
|
|
||||||
average (bool): whether to do average or sum
|
|
||||||
Reduce the values in the dictionary from all processes so that all processes
|
|
||||||
have the averaged results. Returns a dict with the same fields as
|
|
||||||
input_dict, after reduction.
|
|
||||||
"""
|
|
||||||
world_size = get_world_size()
|
|
||||||
if world_size < 2:
|
|
||||||
return input_dict
|
|
||||||
with torch.no_grad():
|
|
||||||
names = []
|
|
||||||
values = []
|
|
||||||
# sort the keys so that they are consistent across processes
|
|
||||||
for k in sorted(input_dict.keys()):
|
|
||||||
names.append(k)
|
|
||||||
values.append(input_dict[k])
|
|
||||||
values = torch.stack(values, dim=0)
|
|
||||||
dist.all_reduce(values)
|
|
||||||
if average:
|
|
||||||
values /= world_size
|
|
||||||
reduced_dict = {k: v for k, v in zip(names, values)}
|
|
||||||
return reduced_dict
|
|
||||||
|
|
||||||
|
|
||||||
class MetricLogger(object):
|
|
||||||
def __init__(self, delimiter="\t"):
|
|
||||||
self.meters = defaultdict(SmoothedValue)
|
|
||||||
self.delimiter = delimiter
|
|
||||||
|
|
||||||
def update(self, **kwargs):
|
|
||||||
for k, v in kwargs.items():
|
|
||||||
if isinstance(v, torch.Tensor):
|
|
||||||
v = v.item()
|
|
||||||
assert isinstance(v, (float, int))
|
|
||||||
self.meters[k].update(v)
|
|
||||||
|
|
||||||
def __getattr__(self, attr):
|
|
||||||
if attr in self.meters:
|
|
||||||
return self.meters[attr]
|
|
||||||
if attr in self.__dict__:
|
|
||||||
return self.__dict__[attr]
|
|
||||||
raise AttributeError("'{}' object has no attribute '{}'".format(
|
|
||||||
type(self).__name__, attr))
|
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
loss_str = []
|
|
||||||
for name, meter in self.meters.items():
|
|
||||||
loss_str.append(
|
|
||||||
"{}: {}".format(name, str(meter))
|
|
||||||
)
|
|
||||||
return self.delimiter.join(loss_str)
|
|
||||||
|
|
||||||
def synchronize_between_processes(self):
|
|
||||||
for meter in self.meters.values():
|
|
||||||
meter.synchronize_between_processes()
|
|
||||||
|
|
||||||
def add_meter(self, name, meter):
|
|
||||||
self.meters[name] = meter
|
|
||||||
|
|
||||||
def log_every(self, iterable, print_freq, header=None):
|
|
||||||
i = 0
|
|
||||||
if not header:
|
|
||||||
header = ''
|
|
||||||
start_time = time.time()
|
|
||||||
end = time.time()
|
|
||||||
iter_time = SmoothedValue(fmt='{avg:.4f}')
|
|
||||||
data_time = SmoothedValue(fmt='{avg:.4f}')
|
|
||||||
space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
log_msg = self.delimiter.join([
|
|
||||||
header,
|
|
||||||
'[{0' + space_fmt + '}/{1}]',
|
|
||||||
'eta: {eta}',
|
|
||||||
'{meters}',
|
|
||||||
'time: {time}',
|
|
||||||
'data: {data}',
|
|
||||||
'max mem: {memory:.0f}'
|
|
||||||
])
|
|
||||||
else:
|
|
||||||
log_msg = self.delimiter.join([
|
|
||||||
header,
|
|
||||||
'[{0' + space_fmt + '}/{1}]',
|
|
||||||
'eta: {eta}',
|
|
||||||
'{meters}',
|
|
||||||
'time: {time}',
|
|
||||||
'data: {data}'
|
|
||||||
])
|
|
||||||
MB = 1024.0 * 1024.0
|
|
||||||
for obj in iterable:
|
|
||||||
data_time.update(time.time() - end)
|
|
||||||
yield obj
|
|
||||||
iter_time.update(time.time() - end)
|
|
||||||
if i % print_freq == 0 or i == len(iterable) - 1:
|
|
||||||
eta_seconds = iter_time.global_avg * (len(iterable) - i)
|
|
||||||
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
print(log_msg.format(
|
|
||||||
i, len(iterable), eta=eta_string,
|
|
||||||
meters=str(self),
|
|
||||||
time=str(iter_time), data=str(data_time),
|
|
||||||
memory=torch.cuda.max_memory_allocated() / MB))
|
|
||||||
else:
|
|
||||||
print(log_msg.format(
|
|
||||||
i, len(iterable), eta=eta_string,
|
|
||||||
meters=str(self),
|
|
||||||
time=str(iter_time), data=str(data_time)))
|
|
||||||
i += 1
|
|
||||||
end = time.time()
|
|
||||||
total_time = time.time() - start_time
|
|
||||||
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
|
||||||
print('{} Total time: {} ({:.4f} s / it)'.format(
|
|
||||||
header, total_time_str, total_time / len(iterable)))
|
|
||||||
|
|
||||||
|
|
||||||
def get_sha():
|
|
||||||
cwd = os.path.dirname(os.path.abspath(__file__))
|
|
||||||
|
|
||||||
def _run(command):
|
|
||||||
return subprocess.check_output(command, cwd=cwd).decode('ascii').strip()
|
|
||||||
sha = 'N/A'
|
|
||||||
diff = "clean"
|
|
||||||
branch = 'N/A'
|
|
||||||
try:
|
|
||||||
sha = _run(['git', 'rev-parse', 'HEAD'])
|
|
||||||
subprocess.check_output(['git', 'diff'], cwd=cwd)
|
|
||||||
diff = _run(['git', 'diff-index', 'HEAD'])
|
|
||||||
diff = "has uncommited changes" if diff else "clean"
|
|
||||||
branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
message = f"sha: {sha}, status: {diff}, branch: {branch}"
|
|
||||||
return message
|
|
||||||
|
|
||||||
|
|
||||||
def collate_fn(batch):
|
|
||||||
batch = list(zip(*batch))
|
|
||||||
batch[0] = nested_tensor_from_tensor_list(batch[0])
|
|
||||||
return tuple(batch)
|
|
||||||
|
|
||||||
|
|
||||||
def _max_by_axis(the_list):
|
|
||||||
# type: (List[List[int]]) -> List[int]
|
|
||||||
maxes = the_list[0]
|
|
||||||
for sublist in the_list[1:]:
|
|
||||||
for index, item in enumerate(sublist):
|
|
||||||
maxes[index] = max(maxes[index], item)
|
|
||||||
return maxes
|
|
||||||
|
|
||||||
|
|
||||||
class NestedTensor(object):
|
|
||||||
def __init__(self, tensors, mask: Optional[Tensor]):
|
|
||||||
self.tensors = tensors
|
|
||||||
self.mask = mask
|
|
||||||
|
|
||||||
def to(self, device):
|
|
||||||
# type: (Device) -> NestedTensor # noqa
|
|
||||||
cast_tensor = self.tensors.to(device)
|
|
||||||
mask = self.mask
|
|
||||||
if mask is not None:
|
|
||||||
assert mask is not None
|
|
||||||
cast_mask = mask.to(device)
|
|
||||||
else:
|
|
||||||
cast_mask = None
|
|
||||||
return NestedTensor(cast_tensor, cast_mask)
|
|
||||||
|
|
||||||
def decompose(self):
|
|
||||||
return self.tensors, self.mask
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return str(self.tensors)
|
|
||||||
|
|
||||||
|
|
||||||
def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
|
|
||||||
# TODO make this more general
|
|
||||||
if tensor_list[0].ndim == 3:
|
|
||||||
if torchvision._is_tracing():
|
|
||||||
# nested_tensor_from_tensor_list() does not export well to ONNX
|
|
||||||
# call _onnx_nested_tensor_from_tensor_list() instead
|
|
||||||
return _onnx_nested_tensor_from_tensor_list(tensor_list)
|
|
||||||
|
|
||||||
# TODO make it support different-sized images
|
|
||||||
max_size = _max_by_axis([list(img.shape) for img in tensor_list])
|
|
||||||
# min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
|
|
||||||
batch_shape = [len(tensor_list)] + max_size
|
|
||||||
b, c, h, w = batch_shape
|
|
||||||
dtype = tensor_list[0].dtype
|
|
||||||
device = tensor_list[0].device
|
|
||||||
tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
|
|
||||||
mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
|
|
||||||
for img, pad_img, m in zip(tensor_list, tensor, mask):
|
|
||||||
pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
|
|
||||||
m[: img.shape[1], :img.shape[2]] = False
|
|
||||||
else:
|
|
||||||
raise ValueError('not supported')
|
|
||||||
return NestedTensor(tensor, mask)
|
|
||||||
|
|
||||||
|
|
||||||
# _onnx_nested_tensor_from_tensor_list() is an implementation of
|
|
||||||
# nested_tensor_from_tensor_list() that is supported by ONNX tracing.
|
|
||||||
@torch.jit.unused
|
|
||||||
def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
|
|
||||||
max_size = []
|
|
||||||
for i in range(tensor_list[0].dim()):
|
|
||||||
max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(torch.int64)
|
|
||||||
max_size.append(max_size_i)
|
|
||||||
max_size = tuple(max_size)
|
|
||||||
|
|
||||||
# work around for
|
|
||||||
# pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
|
|
||||||
# m[: img.shape[1], :img.shape[2]] = False
|
|
||||||
# which is not yet supported in onnx
|
|
||||||
padded_imgs = []
|
|
||||||
padded_masks = []
|
|
||||||
for img in tensor_list:
|
|
||||||
padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
|
|
||||||
padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
|
|
||||||
padded_imgs.append(padded_img)
|
|
||||||
|
|
||||||
m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
|
|
||||||
padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
|
|
||||||
padded_masks.append(padded_mask.to(torch.bool))
|
|
||||||
|
|
||||||
tensor = torch.stack(padded_imgs)
|
|
||||||
mask = torch.stack(padded_masks)
|
|
||||||
|
|
||||||
return NestedTensor(tensor, mask=mask)
|
|
||||||
|
|
||||||
|
|
||||||
def setup_for_distributed(is_master):
|
|
||||||
"""
|
|
||||||
This function disables printing when not in master process
|
|
||||||
"""
|
|
||||||
import builtins as __builtin__
|
|
||||||
builtin_print = __builtin__.print
|
|
||||||
|
|
||||||
def print(*args, **kwargs):
|
|
||||||
force = kwargs.pop('force', False)
|
|
||||||
if is_master or force:
|
|
||||||
builtin_print(*args, **kwargs)
|
|
||||||
|
|
||||||
__builtin__.print = print
|
|
||||||
|
|
||||||
|
|
||||||
def is_dist_avail_and_initialized():
|
|
||||||
if not dist.is_available():
|
|
||||||
return False
|
|
||||||
if not dist.is_initialized():
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
def get_world_size():
|
|
||||||
if not is_dist_avail_and_initialized():
|
|
||||||
return 1
|
|
||||||
return dist.get_world_size()
|
|
||||||
|
|
||||||
|
|
||||||
def get_rank():
|
|
||||||
if not is_dist_avail_and_initialized():
|
|
||||||
return 0
|
|
||||||
return dist.get_rank()
|
|
||||||
|
|
||||||
|
|
||||||
def is_main_process():
|
|
||||||
return get_rank() == 0
|
|
||||||
|
|
||||||
|
|
||||||
def save_on_master(*args, **kwargs):
|
|
||||||
if is_main_process():
|
|
||||||
torch.save(*args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def init_distributed_mode(args):
|
|
||||||
if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
|
|
||||||
args.rank = int(os.environ["RANK"])
|
|
||||||
args.world_size = int(os.environ['WORLD_SIZE'])
|
|
||||||
args.gpu = int(os.environ['LOCAL_RANK'])
|
|
||||||
elif 'SLURM_PROCID' in os.environ:
|
|
||||||
args.rank = int(os.environ['SLURM_PROCID'])
|
|
||||||
args.gpu = args.rank % torch.cuda.device_count()
|
|
||||||
else:
|
|
||||||
print('Not using distributed mode')
|
|
||||||
args.distributed = False
|
|
||||||
return
|
|
||||||
|
|
||||||
args.distributed = True
|
|
||||||
|
|
||||||
torch.cuda.set_device(args.gpu)
|
|
||||||
args.dist_backend = 'nccl'
|
|
||||||
print('| distributed init (rank {}): {}'.format(
|
|
||||||
args.rank, args.dist_url), flush=True)
|
|
||||||
torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
|
|
||||||
world_size=args.world_size, rank=args.rank)
|
|
||||||
torch.distributed.barrier()
|
|
||||||
setup_for_distributed(args.rank == 0)
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def accuracy(output, target, topk=(1,)):
|
|
||||||
"""Computes the precision@k for the specified values of k"""
|
|
||||||
if target.numel() == 0:
|
|
||||||
return [torch.zeros([], device=output.device)]
|
|
||||||
maxk = max(topk)
|
|
||||||
batch_size = target.size(0)
|
|
||||||
|
|
||||||
_, pred = output.topk(maxk, 1, True, True)
|
|
||||||
pred = pred.t()
|
|
||||||
correct = pred.eq(target.view(1, -1).expand_as(pred))
|
|
||||||
|
|
||||||
res = []
|
|
||||||
for k in topk:
|
|
||||||
correct_k = correct[:k].view(-1).float().sum(0)
|
|
||||||
res.append(correct_k.mul_(100.0 / batch_size))
|
|
||||||
return res
|
|
||||||
|
|
||||||
|
|
||||||
def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
|
|
||||||
# type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
|
|
||||||
"""
|
|
||||||
Equivalent to nn.functional.interpolate, but with support for empty batch sizes.
|
|
||||||
This will eventually be supported natively by PyTorch, and this
|
|
||||||
class can go away.
|
|
||||||
"""
|
|
||||||
if version.parse(torchvision.__version__) < version.parse('0.7'):
|
|
||||||
if input.numel() > 0:
|
|
||||||
return torch.nn.functional.interpolate(
|
|
||||||
input, size, scale_factor, mode, align_corners
|
|
||||||
)
|
|
||||||
|
|
||||||
output_shape = _output_size(2, input, size, scale_factor)
|
|
||||||
output_shape = list(input.shape[:-2]) + list(output_shape)
|
|
||||||
return _new_empty_tensor(input, output_shape)
|
|
||||||
else:
|
|
||||||
return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners)
|
|
||||||
@@ -1,107 +0,0 @@
|
|||||||
"""
|
|
||||||
Plotting utilities to visualize training logs.
|
|
||||||
"""
|
|
||||||
import torch
|
|
||||||
import pandas as pd
|
|
||||||
import numpy as np
|
|
||||||
import seaborn as sns
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
|
|
||||||
from pathlib import Path, PurePath
|
|
||||||
|
|
||||||
|
|
||||||
def plot_logs(logs, fields=('class_error', 'loss_bbox_unscaled', 'mAP'), ewm_col=0, log_name='log.txt'):
|
|
||||||
'''
|
|
||||||
Function to plot specific fields from training log(s). Plots both training and test results.
|
|
||||||
|
|
||||||
:: Inputs - logs = list containing Path objects, each pointing to individual dir with a log file
|
|
||||||
- fields = which results to plot from each log file - plots both training and test for each field.
|
|
||||||
- ewm_col = optional, which column to use as the exponential weighted smoothing of the plots
|
|
||||||
- log_name = optional, name of log file if different than default 'log.txt'.
|
|
||||||
|
|
||||||
:: Outputs - matplotlib plots of results in fields, color coded for each log file.
|
|
||||||
- solid lines are training results, dashed lines are test results.
|
|
||||||
|
|
||||||
'''
|
|
||||||
func_name = "plot_utils.py::plot_logs"
|
|
||||||
|
|
||||||
# verify logs is a list of Paths (list[Paths]) or single Pathlib object Path,
|
|
||||||
# convert single Path to list to avoid 'not iterable' error
|
|
||||||
|
|
||||||
if not isinstance(logs, list):
|
|
||||||
if isinstance(logs, PurePath):
|
|
||||||
logs = [logs]
|
|
||||||
print(f"{func_name} info: logs param expects a list argument, converted to list[Path].")
|
|
||||||
else:
|
|
||||||
raise ValueError(f"{func_name} - invalid argument for logs parameter.\n \
|
|
||||||
Expect list[Path] or single Path obj, received {type(logs)}")
|
|
||||||
|
|
||||||
# Quality checks - verify valid dir(s), that every item in list is Path object, and that log_name exists in each dir
|
|
||||||
for i, dir in enumerate(logs):
|
|
||||||
if not isinstance(dir, PurePath):
|
|
||||||
raise ValueError(f"{func_name} - non-Path object in logs argument of {type(dir)}: \n{dir}")
|
|
||||||
if not dir.exists():
|
|
||||||
raise ValueError(f"{func_name} - invalid directory in logs argument:\n{dir}")
|
|
||||||
# verify log_name exists
|
|
||||||
fn = Path(dir / log_name)
|
|
||||||
if not fn.exists():
|
|
||||||
print(f"-> missing {log_name}. Have you gotten to Epoch 1 in training?")
|
|
||||||
print(f"--> full path of missing log file: {fn}")
|
|
||||||
return
|
|
||||||
|
|
||||||
# load log file(s) and plot
|
|
||||||
dfs = [pd.read_json(Path(p) / log_name, lines=True) for p in logs]
|
|
||||||
|
|
||||||
fig, axs = plt.subplots(ncols=len(fields), figsize=(16, 5))
|
|
||||||
|
|
||||||
for df, color in zip(dfs, sns.color_palette(n_colors=len(logs))):
|
|
||||||
for j, field in enumerate(fields):
|
|
||||||
if field == 'mAP':
|
|
||||||
coco_eval = pd.DataFrame(
|
|
||||||
np.stack(df.test_coco_eval_bbox.dropna().values)[:, 1]
|
|
||||||
).ewm(com=ewm_col).mean()
|
|
||||||
axs[j].plot(coco_eval, c=color)
|
|
||||||
else:
|
|
||||||
df.interpolate().ewm(com=ewm_col).mean().plot(
|
|
||||||
y=[f'train_{field}', f'test_{field}'],
|
|
||||||
ax=axs[j],
|
|
||||||
color=[color] * 2,
|
|
||||||
style=['-', '--']
|
|
||||||
)
|
|
||||||
for ax, field in zip(axs, fields):
|
|
||||||
ax.legend([Path(p).name for p in logs])
|
|
||||||
ax.set_title(field)
|
|
||||||
|
|
||||||
|
|
||||||
def plot_precision_recall(files, naming_scheme='iter'):
|
|
||||||
if naming_scheme == 'exp_id':
|
|
||||||
# name becomes exp_id
|
|
||||||
names = [f.parts[-3] for f in files]
|
|
||||||
elif naming_scheme == 'iter':
|
|
||||||
names = [f.stem for f in files]
|
|
||||||
else:
|
|
||||||
raise ValueError(f'not supported {naming_scheme}')
|
|
||||||
fig, axs = plt.subplots(ncols=2, figsize=(16, 5))
|
|
||||||
for f, color, name in zip(files, sns.color_palette("Blues", n_colors=len(files)), names):
|
|
||||||
data = torch.load(f)
|
|
||||||
# precision is n_iou, n_points, n_cat, n_area, max_det
|
|
||||||
precision = data['precision']
|
|
||||||
recall = data['params'].recThrs
|
|
||||||
scores = data['scores']
|
|
||||||
# take precision for all classes, all areas and 100 detections
|
|
||||||
precision = precision[0, :, :, 0, -1].mean(1)
|
|
||||||
scores = scores[0, :, :, 0, -1].mean(1)
|
|
||||||
prec = precision.mean()
|
|
||||||
rec = data['recall'][0, :, 0, -1].mean()
|
|
||||||
print(f'{naming_scheme} {name}: mAP@50={prec * 100: 05.1f}, ' +
|
|
||||||
f'score={scores.mean():0.3f}, ' +
|
|
||||||
f'f1={2 * prec * rec / (prec + rec + 1e-8):0.3f}'
|
|
||||||
)
|
|
||||||
axs[0].plot(recall, precision, c=color)
|
|
||||||
axs[1].plot(recall, scores, c=color)
|
|
||||||
|
|
||||||
axs[0].set_title('Precision / Recall')
|
|
||||||
axs[0].legend(names)
|
|
||||||
axs[1].set_title('Scores / Recall')
|
|
||||||
axs[1].legend(names)
|
|
||||||
return fig, axs
|
|
||||||
@@ -8,8 +8,7 @@ temporal_agg: false
|
|||||||
|
|
||||||
# policy_class: "ACT"
|
# policy_class: "ACT"
|
||||||
# backbone: 'resnet18'
|
# backbone: 'resnet18'
|
||||||
policy_class: "ACTTV"
|
policy_class: "GR00T"
|
||||||
# policy_class: "DDT"
|
|
||||||
backbone: 'dino_v2'
|
backbone: 'dino_v2'
|
||||||
|
|
||||||
seed: 0
|
seed: 0
|
||||||
@@ -40,7 +39,7 @@ camera_names: [] # leave empty here by default
|
|||||||
xml_dir: # leave empty here by default
|
xml_dir: # leave empty here by default
|
||||||
|
|
||||||
# transformer settings
|
# transformer settings
|
||||||
batch_size: 32
|
batch_size: 15
|
||||||
state_dim: 16
|
state_dim: 16
|
||||||
action_dim: 16
|
action_dim: 16
|
||||||
lr_backbone: 0.00001
|
lr_backbone: 0.00001
|
||||||
@@ -52,6 +51,21 @@ nheads: 8
|
|||||||
qpos_noise_std: 0
|
qpos_noise_std: 0
|
||||||
DT: 0.02
|
DT: 0.02
|
||||||
|
|
||||||
|
gr00t:
|
||||||
|
action_dim: 16
|
||||||
|
state_dim: 16
|
||||||
|
embed_dim: 1536
|
||||||
|
hidden_dim: 1024
|
||||||
|
num_queries: 8
|
||||||
|
|
||||||
|
nheads: 32
|
||||||
|
mlp_ratio: 4
|
||||||
|
dropout: 0.2
|
||||||
|
|
||||||
|
num_layers: 16
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# DO NOT CHANGE IF UNNECESSARY
|
# DO NOT CHANGE IF UNNECESSARY
|
||||||
lr: 0.00001
|
lr: 0.00001
|
||||||
kl_weight: 100
|
kl_weight: 100
|
||||||
@@ -59,8 +73,3 @@ chunk_size: 10
|
|||||||
hidden_dim: 512
|
hidden_dim: 512
|
||||||
dim_feedforward: 3200
|
dim_feedforward: 3200
|
||||||
|
|
||||||
# DDT 特有参数
|
|
||||||
num_blocks: 12 # Transformer blocks 数量
|
|
||||||
mlp_ratio: 4.0 # MLP 维度比例
|
|
||||||
num_inference_steps: 10 # 扩散推理步数
|
|
||||||
|
|
||||||
|
|||||||
@@ -71,11 +71,10 @@ def run_episode(config, policy, stats, save_episode,num_rollouts):
|
|||||||
qpos = pre_process(qpos_numpy)
|
qpos = pre_process(qpos_numpy)
|
||||||
qpos = torch.from_numpy(qpos).float().cuda().unsqueeze(0)
|
qpos = torch.from_numpy(qpos).float().cuda().unsqueeze(0)
|
||||||
curr_image = get_image(env._get_image_obs(), config['camera_names'])
|
curr_image = get_image(env._get_image_obs(), config['camera_names'])
|
||||||
if config['policy_class'] == "ACT" or "ACTTV":
|
if config['policy_class'] in ["ACT", "ACTTV", "GR00T"]:
|
||||||
if t % query_frequency == 0:
|
if t % query_frequency == 0:
|
||||||
all_actions = policy(qpos, curr_image)
|
all_actions = policy(qpos, curr_image)
|
||||||
raw_action = all_actions[:, t % query_frequency]
|
raw_action = all_actions[:, t % query_frequency]
|
||||||
# raw_action = all_actions[:, t % 1]
|
|
||||||
raw_action = raw_action.squeeze(0).cpu().numpy()
|
raw_action = raw_action.squeeze(0).cpu().numpy()
|
||||||
elif config['policy_class'] == "CNNMLP":
|
elif config['policy_class'] == "CNNMLP":
|
||||||
raw_action = policy(qpos, curr_image)
|
raw_action = policy(qpos, curr_image)
|
||||||
|
|||||||
125
roboimi/gr00t/main.py
Normal file
125
roboimi/gr00t/main.py
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||||
|
"""
|
||||||
|
GR00T (diffusion-based DiT policy) model builder.
|
||||||
|
|
||||||
|
This module provides functions to build GR00T models and optimizers
|
||||||
|
from configuration dictionaries (typically from config.yaml's 'gr00t:' section).
|
||||||
|
"""
|
||||||
|
import argparse
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from .models import build_gr00t_model
|
||||||
|
|
||||||
|
|
||||||
|
def get_args_parser():
|
||||||
|
"""
|
||||||
|
Create argument parser for GR00T model configuration.
|
||||||
|
|
||||||
|
All parameters can be overridden via args_override dictionary in
|
||||||
|
build_gr00t_model_and_optimizer(). This allows loading from config.yaml.
|
||||||
|
"""
|
||||||
|
parser = argparse.ArgumentParser('GR00T training and evaluation script', add_help=False)
|
||||||
|
|
||||||
|
# Training parameters
|
||||||
|
parser.add_argument('--lr', default=1e-5, type=float,
|
||||||
|
help='Learning rate for main parameters')
|
||||||
|
parser.add_argument('--lr_backbone', default=1e-5, type=float,
|
||||||
|
help='Learning rate for backbone parameters')
|
||||||
|
parser.add_argument('--weight_decay', default=1e-4, type=float,
|
||||||
|
help='Weight decay for optimizer')
|
||||||
|
|
||||||
|
# GR00T model architecture parameters
|
||||||
|
parser.add_argument('--embed_dim', default=1536, type=int,
|
||||||
|
help='Embedding dimension for transformer')
|
||||||
|
parser.add_argument('--hidden_dim', default=1024, type=int,
|
||||||
|
help='Hidden dimension for MLP layers')
|
||||||
|
parser.add_argument('--state_dim', default=16, type=int,
|
||||||
|
help='State (qpos) dimension')
|
||||||
|
parser.add_argument('--action_dim', default=16, type=int,
|
||||||
|
help='Action dimension')
|
||||||
|
parser.add_argument('--num_queries', default=16, type=int,
|
||||||
|
help='Number of action queries (chunk size)')
|
||||||
|
|
||||||
|
# DiT (Diffusion Transformer) parameters
|
||||||
|
parser.add_argument('--num_layers', default=16, type=int,
|
||||||
|
help='Number of transformer layers')
|
||||||
|
parser.add_argument('--nheads', default=32, type=int,
|
||||||
|
help='Number of attention heads')
|
||||||
|
parser.add_argument('--mlp_ratio', default=4, type=float,
|
||||||
|
help='MLP hidden dimension ratio')
|
||||||
|
parser.add_argument('--dropout', default=0.2, type=float,
|
||||||
|
help='Dropout rate')
|
||||||
|
|
||||||
|
# Backbone parameters
|
||||||
|
parser.add_argument('--backbone', default='dino_v2', type=str,
|
||||||
|
help='Backbone architecture (dino_v2, resnet18, resnet34)')
|
||||||
|
parser.add_argument('--position_embedding', default='sine', type=str,
|
||||||
|
choices=('sine', 'learned'),
|
||||||
|
help='Type of positional encoding')
|
||||||
|
|
||||||
|
# Camera configuration
|
||||||
|
parser.add_argument('--camera_names', default=[], nargs='+',
|
||||||
|
help='List of camera names for observations')
|
||||||
|
|
||||||
|
# Other parameters (not directly used but kept for compatibility)
|
||||||
|
parser.add_argument('--batch_size', default=15, type=int)
|
||||||
|
parser.add_argument('--epochs', default=20000, type=int)
|
||||||
|
parser.add_argument('--masks', action='store_true',
|
||||||
|
help='Use intermediate layer features')
|
||||||
|
parser.add_argument('--dilation', action='store_false',
|
||||||
|
help='Use dilated convolution in backbone')
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def build_gr00t_model_and_optimizer(args_override):
|
||||||
|
"""
|
||||||
|
Build GR00T model and optimizer from config dictionary.
|
||||||
|
|
||||||
|
This function is designed to work with config.yaml loading:
|
||||||
|
1. Parse default arguments
|
||||||
|
2. Override with values from args_override (typically from config['gr00t'])
|
||||||
|
3. Build model and optimizer
|
||||||
|
|
||||||
|
Args:
|
||||||
|
args_override: Dictionary of config values, typically from config.yaml's 'gr00t:' section
|
||||||
|
Expected keys: embed_dim, hidden_dim, state_dim, action_dim,
|
||||||
|
num_queries, nheads, mlp_ratio, dropout, num_layers,
|
||||||
|
lr, lr_backbone, camera_names, backbone, etc.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
model: GR00T model on CUDA
|
||||||
|
optimizer: AdamW optimizer with separate learning rates for backbone and other params
|
||||||
|
"""
|
||||||
|
parser = argparse.ArgumentParser('GR00T training and evaluation script',
|
||||||
|
parents=[get_args_parser()])
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Override with config values
|
||||||
|
for k, v in args_override.items():
|
||||||
|
setattr(args, k, v)
|
||||||
|
|
||||||
|
# Build model
|
||||||
|
model = build_gr00t_model(args)
|
||||||
|
model.cuda()
|
||||||
|
|
||||||
|
# Create parameter groups with different learning rates
|
||||||
|
param_dicts = [
|
||||||
|
{
|
||||||
|
"params": [p for n, p in model.named_parameters()
|
||||||
|
if "backbone" not in n and p.requires_grad]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"params": [p for n, p in model.named_parameters()
|
||||||
|
if "backbone" in n and p.requires_grad],
|
||||||
|
"lr": args.lr_backbone,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
optimizer = torch.optim.AdamW(param_dicts,
|
||||||
|
lr=args.lr,
|
||||||
|
weight_decay=args.weight_decay)
|
||||||
|
|
||||||
|
return model, optimizer
|
||||||
3
roboimi/gr00t/models/__init__.py
Normal file
3
roboimi/gr00t/models/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
from .gr00t import build_gr00t_model
|
||||||
|
|
||||||
|
__all__ = ['build_gr00t_model']
|
||||||
142
roboimi/gr00t/models/dit.py
Normal file
142
roboimi/gr00t/models/dit.py
Normal file
@@ -0,0 +1,142 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from diffusers import ConfigMixin, ModelMixin
|
||||||
|
from diffusers.configuration_utils import register_to_config
|
||||||
|
from diffusers.models.embeddings import SinusoidalPositionalEmbedding, TimestepEmbedding, Timesteps
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
class TimestepEncoder(nn.Module):
|
||||||
|
def __init__(self, args):
|
||||||
|
super().__init__()
|
||||||
|
embedding_dim = args.embed_dim
|
||||||
|
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
|
||||||
|
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
||||||
|
|
||||||
|
def forward(self, timesteps):
|
||||||
|
dtype = next(self.parameters()).dtype
|
||||||
|
timesteps_proj = self.time_proj(timesteps).to(dtype)
|
||||||
|
timesteps_emb = self.timestep_embedder(timesteps_proj) # (N, D)
|
||||||
|
return timesteps_emb
|
||||||
|
|
||||||
|
|
||||||
|
class AdaLayerNorm(nn.Module):
|
||||||
|
def __init__(self, embedding_dim, norm_eps=1e-5, norm_elementwise_affine=False):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
output_dim = embedding_dim * 2
|
||||||
|
self.silu = nn.SiLU()
|
||||||
|
self.linear = nn.Linear(embedding_dim, output_dim)
|
||||||
|
self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
temb: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
temb = self.linear(self.silu(temb))
|
||||||
|
scale, shift = temb.chunk(2, dim=1)
|
||||||
|
x = self.norm(x) * (1 + scale[:, None]) + shift[:, None]
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class BasicTransformerBlock(nn.Module):
|
||||||
|
def __init__(self, args, crosss_attention_dim, use_self_attn=False):
|
||||||
|
super().__init__()
|
||||||
|
dim = args.embed_dim
|
||||||
|
num_heads = args.nheads
|
||||||
|
mlp_ratio = args.mlp_ratio
|
||||||
|
dropout = args.dropout
|
||||||
|
self.norm1 = AdaLayerNorm(dim)
|
||||||
|
|
||||||
|
if not use_self_attn:
|
||||||
|
self.attn = nn.MultiheadAttention(
|
||||||
|
embed_dim=dim,
|
||||||
|
num_heads=num_heads,
|
||||||
|
dropout=dropout,
|
||||||
|
kdim=crosss_attention_dim,
|
||||||
|
vdim=crosss_attention_dim,
|
||||||
|
batch_first=True,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.attn = nn.MultiheadAttention(
|
||||||
|
embed_dim=dim,
|
||||||
|
num_heads=num_heads,
|
||||||
|
dropout=dropout,
|
||||||
|
batch_first=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.norm2 = nn.LayerNorm(dim, eps=1e-5, elementwise_affine=False)
|
||||||
|
|
||||||
|
self.mlp = nn.Sequential(
|
||||||
|
nn.Linear(dim, dim * mlp_ratio),
|
||||||
|
nn.GELU(),
|
||||||
|
nn.Dropout(dropout),
|
||||||
|
nn.Linear(dim * mlp_ratio, dim),
|
||||||
|
nn.Dropout(dropout)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, hidden_states, temb, context=None):
|
||||||
|
norm_hidden_states = self.norm1(hidden_states, temb)
|
||||||
|
|
||||||
|
attn_output = self.attn(
|
||||||
|
norm_hidden_states,
|
||||||
|
context if context is not None else norm_hidden_states,
|
||||||
|
context if context is not None else norm_hidden_states,
|
||||||
|
)[0]
|
||||||
|
|
||||||
|
hidden_states = attn_output + hidden_states
|
||||||
|
|
||||||
|
norm_hidden_states = self.norm2(hidden_states)
|
||||||
|
|
||||||
|
ff_output = self.mlp(norm_hidden_states)
|
||||||
|
|
||||||
|
hidden_states = ff_output + hidden_states
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
class DiT(nn.Module):
|
||||||
|
def __init__(self, args, cross_attention_dim):
|
||||||
|
super().__init__()
|
||||||
|
inner_dim = args.embed_dim
|
||||||
|
num_layers = args.num_layers
|
||||||
|
output_dim = args.hidden_dim
|
||||||
|
|
||||||
|
self.timestep_encoder = TimestepEncoder(args)
|
||||||
|
|
||||||
|
all_blocks = []
|
||||||
|
for idx in range(num_layers):
|
||||||
|
use_self_attn = idx % 2 == 1
|
||||||
|
if use_self_attn:
|
||||||
|
block = BasicTransformerBlock(args, crosss_attention_dim=None, use_self_attn=True)
|
||||||
|
else:
|
||||||
|
block = BasicTransformerBlock(args, crosss_attention_dim=cross_attention_dim, use_self_attn=False)
|
||||||
|
all_blocks.append(block)
|
||||||
|
|
||||||
|
self.transformer_blocks = nn.ModuleList(all_blocks)
|
||||||
|
|
||||||
|
self.norm_out = nn.LayerNorm(inner_dim, eps=1e-6, elementwise_affine=False)
|
||||||
|
self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
|
||||||
|
self.proj_out_2 = nn.Linear(inner_dim, output_dim)
|
||||||
|
|
||||||
|
def forward(self, hidden_states, timestep, encoder_hidden_states):
|
||||||
|
temb = self.timestep_encoder(timestep)
|
||||||
|
|
||||||
|
hidden_states = hidden_states.contiguous()
|
||||||
|
encoder_hidden_states = encoder_hidden_states.contiguous()
|
||||||
|
|
||||||
|
for idx, block in enumerate(self.transformer_blocks):
|
||||||
|
if idx % 2 == 1:
|
||||||
|
hidden_states = block(hidden_states, temb)
|
||||||
|
else:
|
||||||
|
hidden_states = block(hidden_states, temb, context=encoder_hidden_states)
|
||||||
|
|
||||||
|
conditioning = temb
|
||||||
|
shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
|
||||||
|
hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
|
||||||
|
return self.proj_out_2(hidden_states)
|
||||||
|
|
||||||
|
|
||||||
|
def build_dit(args, cross_attention_dim):
|
||||||
|
return DiT(args, cross_attention_dim)
|
||||||
124
roboimi/gr00t/models/gr00t.py
Normal file
124
roboimi/gr00t/models/gr00t.py
Normal file
@@ -0,0 +1,124 @@
|
|||||||
|
|
||||||
|
from .modules import (
|
||||||
|
build_action_decoder,
|
||||||
|
build_action_encoder,
|
||||||
|
build_state_encoder,
|
||||||
|
build_time_sampler,
|
||||||
|
build_noise_scheduler,
|
||||||
|
)
|
||||||
|
from .backbone import build_backbone
|
||||||
|
from .dit import build_dit
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
class gr00t(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
backbones,
|
||||||
|
dit,
|
||||||
|
state_encoder,
|
||||||
|
action_encoder,
|
||||||
|
action_decoder,
|
||||||
|
time_sampler,
|
||||||
|
noise_scheduler,
|
||||||
|
num_queries,
|
||||||
|
camera_names,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.num_queries = num_queries
|
||||||
|
self.camera_names = camera_names
|
||||||
|
self.dit = dit
|
||||||
|
self.state_encoder = state_encoder
|
||||||
|
self.action_encoder = action_encoder
|
||||||
|
self.action_decoder = action_decoder
|
||||||
|
self.time_sampler = time_sampler
|
||||||
|
self.noise_scheduler = noise_scheduler
|
||||||
|
|
||||||
|
if backbones is not None:
|
||||||
|
self.backbones = nn.ModuleList(backbones)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def forward(self, qpos, image, actions=None, is_pad=None):
|
||||||
|
is_training = actions is not None # train or val
|
||||||
|
bs, _ = qpos.shape
|
||||||
|
|
||||||
|
all_cam_features = []
|
||||||
|
for cam_id, cam_name in enumerate(self.camera_names):
|
||||||
|
# features, pos = self.backbones[0](image[:, cam_id]) # HARDCODED
|
||||||
|
features, pos = self.backbones[cam_id](image[:, cam_id])
|
||||||
|
features = features[0] # take the last layer feature
|
||||||
|
B, C, H, W = features.shape
|
||||||
|
features_seq = features.permute(0, 2, 3, 1).reshape(B, H * W, C)
|
||||||
|
all_cam_features.append(features_seq)
|
||||||
|
encoder_hidden_states = torch.cat(all_cam_features, dim=1)
|
||||||
|
|
||||||
|
state_features = self.state_encoder(qpos) # [B, 1, emb_dim]
|
||||||
|
|
||||||
|
if is_training:
|
||||||
|
# training logic
|
||||||
|
|
||||||
|
timesteps = self.time_sampler(bs, actions.device, actions.dtype)
|
||||||
|
noisy_actions, target_velocity = self.noise_scheduler.add_noise(
|
||||||
|
actions, timesteps
|
||||||
|
)
|
||||||
|
t_discretized = (timesteps[:, 0, 0] * 1000).long()
|
||||||
|
action_features = self.action_encoder(noisy_actions, t_discretized)
|
||||||
|
sa_embs = torch.cat((state_features, action_features), dim=1)
|
||||||
|
model_output = self.dit(sa_embs, t_discretized, encoder_hidden_states)
|
||||||
|
pred = self.action_decoder(model_output)
|
||||||
|
pred_actions = pred[:, -actions.shape[1] :]
|
||||||
|
action_loss = F.mse_loss(pred_actions, target_velocity, reduction='none')
|
||||||
|
return pred_actions, action_loss
|
||||||
|
else:
|
||||||
|
actions = torch.randn(bs, self.num_queries, qpos.shape[-1], device=qpos.device, dtype=qpos.dtype)
|
||||||
|
k = 5
|
||||||
|
dt = 1.0 / k
|
||||||
|
for t in range(k):
|
||||||
|
t_cont = t / float(k)
|
||||||
|
t_discretized = int(t_cont * 1000)
|
||||||
|
timesteps = torch.full((bs,), t_discretized, device=qpos.device, dtype=qpos.dtype)
|
||||||
|
action_features = self.action_encoder(actions, timesteps)
|
||||||
|
sa_embs = torch.cat((state_features, action_features), dim=1)
|
||||||
|
# Create tensor of shape [B] for DiT (consistent with training path)
|
||||||
|
model_output = self.dit(sa_embs, timesteps, encoder_hidden_states)
|
||||||
|
pred = self.action_decoder(model_output)
|
||||||
|
pred_velocity = pred[:, -self.num_queries :]
|
||||||
|
actions = actions + pred_velocity * dt
|
||||||
|
return actions, _
|
||||||
|
def build_gr00t_model(args):
|
||||||
|
state_dim = args.state_dim
|
||||||
|
action_dim = args.action_dim
|
||||||
|
|
||||||
|
backbones = []
|
||||||
|
for _ in args.camera_names:
|
||||||
|
backbone = build_backbone(args)
|
||||||
|
backbones.append(backbone)
|
||||||
|
|
||||||
|
cross_attention_dim = backbones[0].num_channels
|
||||||
|
|
||||||
|
dit = build_dit(args, cross_attention_dim)
|
||||||
|
|
||||||
|
state_encoder = build_state_encoder(args)
|
||||||
|
action_encoder = build_action_encoder(args)
|
||||||
|
action_decoder = build_action_decoder(args)
|
||||||
|
time_sampler = build_time_sampler(args)
|
||||||
|
noise_scheduler = build_noise_scheduler(args)
|
||||||
|
model = gr00t(
|
||||||
|
backbones,
|
||||||
|
dit,
|
||||||
|
state_encoder,
|
||||||
|
action_encoder,
|
||||||
|
action_decoder,
|
||||||
|
time_sampler,
|
||||||
|
noise_scheduler,
|
||||||
|
args.num_queries,
|
||||||
|
args.camera_names,
|
||||||
|
)
|
||||||
|
|
||||||
|
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||||
|
print("number of parameters: %.2fM" % (n_parameters/1e6,))
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
179
roboimi/gr00t/models/modules.py
Normal file
179
roboimi/gr00t/models/modules.py
Normal file
@@ -0,0 +1,179 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
# ActionEncoder
|
||||||
|
class SinusoidalPositionalEncoding(nn.Module):
|
||||||
|
def __init__(self, args):
|
||||||
|
super().__init__()
|
||||||
|
self.embed_dim = args.embed_dim
|
||||||
|
|
||||||
|
def forward(self, timesteps):
|
||||||
|
timesteps = timesteps.float()
|
||||||
|
B, T = timesteps.shape
|
||||||
|
device = timesteps.device
|
||||||
|
|
||||||
|
half_dim = self.embed_dim // 2
|
||||||
|
|
||||||
|
exponent = -torch.arange(half_dim, dtype=torch.float, device=device) * (
|
||||||
|
torch.log(torch.tensor(10000.0)) / half_dim
|
||||||
|
)
|
||||||
|
|
||||||
|
freqs = timesteps.unsqueeze(-1) * exponent.exp()
|
||||||
|
|
||||||
|
sin = torch.sin(freqs)
|
||||||
|
cos = torch.cos(freqs)
|
||||||
|
enc = torch.cat([sin, cos], dim=-1) # (B, T, w)
|
||||||
|
|
||||||
|
return enc
|
||||||
|
|
||||||
|
|
||||||
|
class ActionEncoder(nn.Module):
|
||||||
|
def __init__(self, args):
|
||||||
|
super().__init__()
|
||||||
|
action_dim = args.action_dim
|
||||||
|
embed_dim = args.embed_dim
|
||||||
|
|
||||||
|
self.W1 = nn.Linear(action_dim, embed_dim)
|
||||||
|
self.W2 = nn.Linear(2 * embed_dim, embed_dim)
|
||||||
|
self.W3 = nn.Linear(embed_dim, embed_dim)
|
||||||
|
|
||||||
|
self.pos_encoder = SinusoidalPositionalEncoding(args)
|
||||||
|
|
||||||
|
def forward(self, actions, timesteps):
|
||||||
|
B, T, _ = actions.shape
|
||||||
|
|
||||||
|
# 1) Expand each batch's single scalar time 'tau' across all T steps
|
||||||
|
# so that shape => (B, T)
|
||||||
|
# Handle different input shapes: (B,), (B, 1), (B, 1, 1)
|
||||||
|
# Reshape to (B,) then expand to (B, T)
|
||||||
|
# if timesteps.dim() == 3:
|
||||||
|
# # Shape (B, 1, 1) or (B, T, 1) -> (B,)
|
||||||
|
# timesteps = timesteps[:, 0, 0]
|
||||||
|
# elif timesteps.dim() == 2:
|
||||||
|
# # Shape (B, 1) or (B, T) -> take first element if needed
|
||||||
|
# if timesteps.shape[1] == 1:
|
||||||
|
# timesteps = timesteps[:, 0]
|
||||||
|
# # else: already (B, T), use as is
|
||||||
|
# elif timesteps.dim() != 1:
|
||||||
|
# raise ValueError(
|
||||||
|
# f"Expected `timesteps` to have shape (B,), (B, 1), or (B, 1, 1), got {timesteps.shape}"
|
||||||
|
# )
|
||||||
|
|
||||||
|
# Now timesteps should be (B,), expand to (B, T)
|
||||||
|
if timesteps.dim() == 1 and timesteps.shape[0] == B:
|
||||||
|
timesteps = timesteps.unsqueeze(1).expand(-1, T)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"Expected `timesteps` to have shape (B,) so we can replicate across T."
|
||||||
|
)
|
||||||
|
|
||||||
|
# 2) Standard action MLP step for shape => (B, T, w)
|
||||||
|
a_emb = self.W1(actions)
|
||||||
|
|
||||||
|
# 3) Get the sinusoidal encoding (B, T, w)
|
||||||
|
tau_emb = self.pos_encoder(timesteps).to(dtype=a_emb.dtype)
|
||||||
|
|
||||||
|
# 4) Concat along last dim => (B, T, 2w), then W2 => (B, T, w), swish
|
||||||
|
x = torch.cat([a_emb, tau_emb], dim=-1)
|
||||||
|
x = F.silu(self.W2(x))
|
||||||
|
|
||||||
|
# 5) Finally W3 => (B, T, w)
|
||||||
|
x = self.W3(x)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def build_action_encoder(args):
|
||||||
|
return ActionEncoder(args)
|
||||||
|
|
||||||
|
|
||||||
|
# StateEncoder
|
||||||
|
class StateEncoder(nn.Module):
|
||||||
|
def __init__(self, args):
|
||||||
|
super().__init__()
|
||||||
|
input_dim = args.state_dim
|
||||||
|
hidden_dim = args.hidden_dim
|
||||||
|
output_dim = args.embed_dim
|
||||||
|
|
||||||
|
self.mlp = nn.Sequential(
|
||||||
|
nn.Linear(input_dim, hidden_dim),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Linear(hidden_dim, output_dim),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, states):
|
||||||
|
state_emb = self.mlp(states) # [B, emb_dim]
|
||||||
|
state_emb = state_emb.unsqueeze(1)
|
||||||
|
return state_emb # [B, 1, emb_dim]
|
||||||
|
|
||||||
|
|
||||||
|
def build_state_encoder(args):
|
||||||
|
return StateEncoder(args)
|
||||||
|
|
||||||
|
|
||||||
|
# ActionDecoder
|
||||||
|
class ActionDecoder(nn.Module):
|
||||||
|
def __init__(self,args):
|
||||||
|
super().__init__()
|
||||||
|
input_dim = args.hidden_dim
|
||||||
|
hidden_dim = args.hidden_dim
|
||||||
|
output_dim = args.action_dim
|
||||||
|
|
||||||
|
self.num_queries = args.num_queries
|
||||||
|
|
||||||
|
self.mlp = nn.Sequential(
|
||||||
|
nn.Linear(input_dim, hidden_dim),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Linear(hidden_dim, output_dim),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, model_output):
|
||||||
|
pred_actions = self.mlp(model_output)
|
||||||
|
return pred_actions[:, -self.num_queries:]
|
||||||
|
|
||||||
|
|
||||||
|
def build_action_decoder(args):
|
||||||
|
return ActionDecoder(args)
|
||||||
|
|
||||||
|
|
||||||
|
# TimeSampler
|
||||||
|
class TimeSampler(nn.Module):
|
||||||
|
def __init__(self, noise_s = 0.999, noise_beta_alpha=1.5, noise_beta_beta=1.0):
|
||||||
|
super().__init__()
|
||||||
|
self.noise_s = noise_s
|
||||||
|
self.beta_dist = torch.distributions.Beta(noise_beta_alpha, noise_beta_beta)
|
||||||
|
|
||||||
|
def forward(self, batch_size, device, dtype):
|
||||||
|
sample = self.beta_dist.sample([batch_size]).to(device, dtype=dtype)
|
||||||
|
sample = (1 - sample) * self.noise_s
|
||||||
|
return sample[:, None, None]
|
||||||
|
|
||||||
|
|
||||||
|
def build_time_sampler(args):
|
||||||
|
return TimeSampler()
|
||||||
|
|
||||||
|
|
||||||
|
# NoiseScheduler
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
class FlowMatchingScheduler(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# --- 训练逻辑:加噪并计算目标 ---
|
||||||
|
def add_noise(self, actions, timesteps):
|
||||||
|
noise = torch.randn_like(actions)
|
||||||
|
noisy_samples = actions * timesteps + noise * (1 - timesteps)
|
||||||
|
target_velocity = actions - noise
|
||||||
|
|
||||||
|
return noisy_samples, target_velocity
|
||||||
|
|
||||||
|
# --- 推理逻辑:欧拉步 (Euler Step) ---
|
||||||
|
def step(self, model_output, sample, dt):
|
||||||
|
prev_sample = sample + model_output * dt
|
||||||
|
return prev_sample
|
||||||
|
|
||||||
|
def build_noise_scheduler(args):
|
||||||
|
return FlowMatchingScheduler()
|
||||||
90
roboimi/gr00t/policy.py
Normal file
90
roboimi/gr00t/policy.py
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
"""
|
||||||
|
GR00T Policy wrapper for imitation learning.
|
||||||
|
|
||||||
|
This module provides the gr00tPolicy class that wraps the GR00T model
|
||||||
|
for training and evaluation in the imitation learning framework.
|
||||||
|
"""
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.nn import functional as F
|
||||||
|
from torchvision.transforms import v2
|
||||||
|
import torch
|
||||||
|
from roboimi.gr00t.main import build_gr00t_model_and_optimizer
|
||||||
|
|
||||||
|
|
||||||
|
class gr00tPolicy(nn.Module):
|
||||||
|
"""
|
||||||
|
GR00T Policy for action prediction using diffusion-based DiT architecture.
|
||||||
|
|
||||||
|
This policy wraps the GR00T model and handles:
|
||||||
|
- Image resizing to match DINOv2 patch size requirements
|
||||||
|
- Image normalization (ImageNet stats)
|
||||||
|
- Training with action chunks and loss computation
|
||||||
|
- Inference with diffusion sampling
|
||||||
|
"""
|
||||||
|
def __init__(self, args_override):
|
||||||
|
super().__init__()
|
||||||
|
model, optimizer = build_gr00t_model_and_optimizer(args_override)
|
||||||
|
self.model = model
|
||||||
|
self.optimizer = optimizer
|
||||||
|
|
||||||
|
# DINOv2 requires image dimensions to be multiples of patch size (14)
|
||||||
|
# Common sizes: 224x224, 336x336, etc. (14*16=224, 14*24=336)
|
||||||
|
self.patch_h = 16 # Number of patches vertically
|
||||||
|
self.patch_w = 22 # Number of patches horizontally
|
||||||
|
target_size = (self.patch_h * 14, self.patch_w * 14) # (224, 308)
|
||||||
|
|
||||||
|
# Training transform with data augmentation
|
||||||
|
self.train_transform = v2.Compose([
|
||||||
|
v2.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),
|
||||||
|
v2.RandomPerspective(distortion_scale=0.5),
|
||||||
|
v2.RandomAffine(degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1)),
|
||||||
|
v2.GaussianBlur(kernel_size=(9, 9), sigma=(0.1, 2.0)),
|
||||||
|
v2.Resize(target_size),
|
||||||
|
v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
|
||||||
|
])
|
||||||
|
|
||||||
|
# Inference transform (no augmentation)
|
||||||
|
self.inference_transform = v2.Compose([
|
||||||
|
v2.Resize(target_size),
|
||||||
|
v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
|
||||||
|
])
|
||||||
|
|
||||||
|
def __call__(self, qpos, image, actions=None, is_pad=None):
|
||||||
|
"""
|
||||||
|
Forward pass for training or inference.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
qpos: Joint positions [B, state_dim]
|
||||||
|
image: Camera images [B, num_cameras, C, H, W]
|
||||||
|
actions: Ground truth actions [B, chunk_size, action_dim] (training only)
|
||||||
|
is_pad: Padding mask [B, chunk_size] (training only)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Training: dict with 'mse' loss
|
||||||
|
Inference: predicted actions [B, num_queries, action_dim]
|
||||||
|
"""
|
||||||
|
# Apply transforms (resize + normalization)
|
||||||
|
if actions is not None: # training time
|
||||||
|
image = self.train_transform(image)
|
||||||
|
else: # inference time
|
||||||
|
image = self.inference_transform(image)
|
||||||
|
|
||||||
|
if actions is not None: # training time
|
||||||
|
actions = actions[:, :self.model.num_queries]
|
||||||
|
is_pad = is_pad[:, :self.model.num_queries]
|
||||||
|
_, action_loss = self.model(qpos, image, actions, is_pad)
|
||||||
|
|
||||||
|
# Mask out padded positions
|
||||||
|
mse_loss = (action_loss * ~is_pad.unsqueeze(-1)).mean()
|
||||||
|
|
||||||
|
loss_dict = {
|
||||||
|
'loss': mse_loss
|
||||||
|
}
|
||||||
|
return loss_dict
|
||||||
|
else: # inference time
|
||||||
|
a_hat, _ = self.model(qpos, image)
|
||||||
|
return a_hat
|
||||||
|
|
||||||
|
def configure_optimizers(self):
|
||||||
|
"""Return the optimizer for training."""
|
||||||
|
return self.optimizer
|
||||||
@@ -20,7 +20,7 @@ SIM_TASK_CONFIGS = {
|
|||||||
'dataset_dir': DATASET_DIR + '/sim_transfer',
|
'dataset_dir': DATASET_DIR + '/sim_transfer',
|
||||||
'num_episodes': 7,
|
'num_episodes': 7,
|
||||||
'episode_len': 700,
|
'episode_len': 700,
|
||||||
'camera_names': ['angle','r_vis'],
|
'camera_names': ['top','r_vis'],
|
||||||
'xml_dir': HOME_PATH + '/assets'
|
'xml_dir': HOME_PATH + '/assets'
|
||||||
},
|
},
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
import os
|
import os
|
||||||
import torch
|
import torch
|
||||||
from roboimi.utils.utils import load_data, set_seed
|
from roboimi.utils.utils import load_data, set_seed
|
||||||
from roboimi.detr.policy import ACTPolicy, CNNMLPPolicy,ACTTVPolicy
|
from roboimi.detr.policy import ACTPolicy, CNNMLPPolicy, ACTTVPolicy
|
||||||
from roboimi.ddt.policy import DDTPolicy
|
from roboimi.gr00t.policy import gr00tPolicy
|
||||||
|
|
||||||
class ModelInterface:
|
class ModelInterface:
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
@@ -66,23 +66,25 @@ class ModelInterface:
|
|||||||
'num_queries': 1,
|
'num_queries': 1,
|
||||||
'camera_names': self.config['camera_names'],
|
'camera_names': self.config['camera_names'],
|
||||||
}
|
}
|
||||||
elif self.config['policy_class'] == 'DDT':
|
elif self.config['policy_class'] == 'GR00T':
|
||||||
|
# GR00T uses its own config section from config.yaml
|
||||||
|
gr00t_config = self.config.get('gr00t', {})
|
||||||
self.config['policy_config'] = {
|
self.config['policy_config'] = {
|
||||||
'lr': self.config['lr'],
|
'lr': gr00t_config.get('lr', self.config['lr']),
|
||||||
'lr_backbone': self.config['lr_backbone'],
|
'lr_backbone': gr00t_config.get('lr_backbone', self.config['lr_backbone']),
|
||||||
'backbone': self.config.get('backbone', 'dino_v2'),
|
'weight_decay': gr00t_config.get('weight_decay', 1e-4),
|
||||||
'num_queries': self.config['chunk_size'],
|
'embed_dim': gr00t_config.get('embed_dim', 1536),
|
||||||
'hidden_dim': self.config['hidden_dim'],
|
'hidden_dim': gr00t_config.get('hidden_dim', 1024),
|
||||||
'nheads': self.config['nheads'],
|
'state_dim': gr00t_config.get('state_dim', 16),
|
||||||
'enc_layers': self.config['enc_layers'],
|
'action_dim': gr00t_config.get('action_dim', 16),
|
||||||
'state_dim': self.config.get('state_dim', 16),
|
'num_queries': gr00t_config.get('num_queries', 16),
|
||||||
'action_dim': self.config.get('action_dim', 16),
|
'num_layers': gr00t_config.get('num_layers', 16),
|
||||||
|
'nheads': gr00t_config.get('nheads', 32),
|
||||||
|
'mlp_ratio': gr00t_config.get('mlp_ratio', 4),
|
||||||
|
'dropout': gr00t_config.get('dropout', 0.2),
|
||||||
|
'backbone': gr00t_config.get('backbone', 'dino_v2'),
|
||||||
|
'position_embedding': gr00t_config.get('position_embedding', 'sine'),
|
||||||
'camera_names': self.config['camera_names'],
|
'camera_names': self.config['camera_names'],
|
||||||
'qpos_noise_std': self.config.get('qpos_noise_std', 0),
|
|
||||||
# DDT 特有参数
|
|
||||||
'num_blocks': self.config.get('num_blocks', 12),
|
|
||||||
'mlp_ratio': self.config.get('mlp_ratio', 4.0),
|
|
||||||
'num_inference_steps': self.config.get('num_inference_steps', 10),
|
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
@@ -94,8 +96,8 @@ class ModelInterface:
|
|||||||
return ACTTVPolicy(self.config['policy_config'])
|
return ACTTVPolicy(self.config['policy_config'])
|
||||||
elif self.config['policy_class'] == 'CNNMLP':
|
elif self.config['policy_class'] == 'CNNMLP':
|
||||||
return CNNMLPPolicy(self.config['policy_config'])
|
return CNNMLPPolicy(self.config['policy_config'])
|
||||||
elif self.config['policy_class'] == 'DDT':
|
elif self.config['policy_class'] == 'GR00T':
|
||||||
return DDTPolicy(self.config['policy_config'])
|
return gr00tPolicy(self.config['policy_config'])
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user