1 Commits
work ... ddt

Author SHA1 Message Date
JiajunLI
2376f494d2 feat(policy): add ddt policy 2026-01-28 17:14:28 +08:00
14 changed files with 2465 additions and 2 deletions

112
roboimi/ddt/main.py Normal file
View File

@@ -0,0 +1,112 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
DDT model construction and optimizer configuration.
"""
import argparse
from pathlib import Path
import numpy as np
import torch
from .models import build_DDT_model
def get_args_parser():
"""获取 DDT 模型的参数解析器。"""
parser = argparse.ArgumentParser('DDT model configuration', add_help=False)
# Learning rate
parser.add_argument('--lr', default=1e-4, type=float)
parser.add_argument('--lr_backbone', default=1e-5, type=float)
parser.add_argument('--batch_size', default=2, type=int)
parser.add_argument('--weight_decay', default=1e-4, type=float)
parser.add_argument('--epochs', default=300, type=int)
parser.add_argument('--lr_drop', default=200, type=int)
parser.add_argument('--clip_max_norm', default=0.1, type=float,
help='gradient clipping max norm')
parser.add_argument('--qpos_noise_std', action='store', default=0, type=float)
# Backbone parameters
parser.add_argument('--backbone', default='resnet18', type=str,
help="Name of the convolutional backbone to use")
parser.add_argument('--dilation', action='store_true',
help="If true, replace stride with dilation in the last conv block")
parser.add_argument('--position_embedding', default='sine', type=str,
choices=('sine', 'learned'),
help="Type of positional embedding")
parser.add_argument('--camera_names', default=[], type=list,
help="A list of camera names")
# Transformer parameters
parser.add_argument('--enc_layers', default=4, type=int,
help="Number of encoding layers in the transformer")
parser.add_argument('--dec_layers', default=6, type=int,
help="Number of decoding layers in the transformer")
parser.add_argument('--dim_feedforward', default=2048, type=int,
help="Intermediate size of the feedforward layers")
parser.add_argument('--hidden_dim', default=512, type=int,
help="Size of the embeddings (dimension of the transformer)")
parser.add_argument('--dropout', default=0.1, type=float,
help="Dropout applied in the transformer")
parser.add_argument('--nheads', default=8, type=int,
help="Number of attention heads")
parser.add_argument('--num_queries', default=100, type=int,
help="Number of query slots (action horizon)")
parser.add_argument('--pre_norm', action='store_true')
parser.add_argument('--state_dim', default=14, type=int)
parser.add_argument('--action_dim', default=14, type=int)
# DDT-specific parameters
parser.add_argument('--num_blocks', default=12, type=int,
help="Total number of transformer blocks in DDT")
parser.add_argument('--mlp_ratio', default=4.0, type=float,
help="MLP hidden dimension ratio")
parser.add_argument('--num_inference_steps', default=10, type=int,
help="Number of diffusion inference steps")
# Segmentation (unused)
parser.add_argument('--masks', action='store_true',
help="Train segmentation head if provided")
return parser
def build_DDT_model_and_optimizer(args_override):
"""构建 DDT 模型和优化器。
Args:
args_override: 覆盖默认参数的字典
Returns:
model: DDT 模型
optimizer: AdamW 优化器
"""
parser = argparse.ArgumentParser('DDT training script', parents=[get_args_parser()])
args = parser.parse_args([]) # pass an empty list so real command-line arguments don't interfere
# Apply the argument overrides
for k, v in args_override.items():
setattr(args, k, v)
# Build the model
model = build_DDT_model(args)
model.cuda()
# Configure the optimizer (the backbone uses a smaller learning rate)
param_dicts = [
{
"params": [p for n, p in model.named_parameters()
if "backbone" not in n and p.requires_grad]
},
{
"params": [p for n, p in model.named_parameters()
if "backbone" in n and p.requires_grad],
"lr": args.lr_backbone,
},
]
optimizer = torch.optim.AdamW(
param_dicts,
lr=args.lr,
weight_decay=args.weight_decay
)
return model, optimizer
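For reference, a minimal usage sketch of build_DDT_model_and_optimizer. The override keys are illustrative (they mirror the parser defaults above, not a required set), the import path assumes the repo layout in this commit, and the builder moves the model to CUDA:

```python
# Hypothetical sketch: build the model with a few overrides and inspect the trainable parameter count.
from roboimi.ddt.main import build_DDT_model_and_optimizer

overrides = {
    "camera_names": ["top"],  # one camera
    "num_queries": 16,        # action horizon
    "state_dim": 14,
    "action_dim": 14,
}
model, optimizer = build_DDT_model_and_optimizer(overrides)
print(sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6, "M trainable params")
```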

View File

@@ -0,0 +1,7 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from .model import build_ddt
def build_DDT_model(args):
"""构建 DDT 模型的统一入口。"""
return build_ddt(args)

View File

@@ -0,0 +1,168 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Backbone modules.
"""
from collections import OrderedDict
import torch
import torch.nn.functional as F
import torchvision
from torch import nn
from torchvision.models._utils import IntermediateLayerGetter
from typing import Dict, List
from util.misc import NestedTensor, is_main_process
from .position_encoding import build_position_encoding
class FrozenBatchNorm2d(torch.nn.Module):
"""
BatchNorm2d where the batch statistics and the affine parameters are fixed.
Copy-paste from torchvision.misc.ops with an added eps before rsqrt,
without which any backbone other than torchvision's resnet[18,34,50,101]
produces NaNs.
"""
def __init__(self, n):
super(FrozenBatchNorm2d, self).__init__()
self.register_buffer("weight", torch.ones(n))
self.register_buffer("bias", torch.zeros(n))
self.register_buffer("running_mean", torch.zeros(n))
self.register_buffer("running_var", torch.ones(n))
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
num_batches_tracked_key = prefix + 'num_batches_tracked'
if num_batches_tracked_key in state_dict:
del state_dict[num_batches_tracked_key]
super(FrozenBatchNorm2d, self)._load_from_state_dict(
state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs)
def forward(self, x):
# move reshapes to the beginning
# to make it fuser-friendly
w = self.weight.reshape(1, -1, 1, 1)
b = self.bias.reshape(1, -1, 1, 1)
rv = self.running_var.reshape(1, -1, 1, 1)
rm = self.running_mean.reshape(1, -1, 1, 1)
eps = 1e-5
scale = w * (rv + eps).rsqrt()
bias = b - rm * scale
return x * scale + bias
class BackboneBase(nn.Module):
def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool):
super().__init__()
# for name, parameter in backbone.named_parameters(): # only train later layers # TODO do we want this?
# if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
# parameter.requires_grad_(False)
if return_interm_layers:
return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
else:
return_layers = {'layer4': "0"}
self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
self.num_channels = num_channels
def forward(self, tensor):
xs = self.body(tensor)
return xs
# out: Dict[str, NestedTensor] = {}
# for name, x in xs.items():
# m = tensor_list.mask
# assert m is not None
# mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
# out[name] = NestedTensor(x, mask)
# return out
class Backbone(BackboneBase):
"""ResNet backbone with frozen BatchNorm."""
def __init__(self, name: str,
train_backbone: bool,
return_interm_layers: bool,
dilation: bool):
backbone = getattr(torchvision.models, name)(
replace_stride_with_dilation=[False, False, dilation],
pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d) # pretrained # TODO do we want frozen batch_norm??
num_channels = 512 if name in ('resnet18', 'resnet34') else 2048
super().__init__(backbone, train_backbone, num_channels, return_interm_layers)
# class DINOv2BackBone(nn.Module):
# def __init__(self) -> None:
# super().__init__()
# self.body = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
# self.body.eval()
# self.num_channels = 384
# @torch.no_grad()
# def forward(self, tensor):
# xs = self.body.forward_features(tensor)["x_norm_patchtokens"]
# od = OrderedDict()
# od["0"] = xs.reshape(xs.shape[0], 22, 16, 384).permute(0, 3, 2, 1)
# return od
class DINOv2BackBone(nn.Module):
def __init__(self, return_interm_layers: bool = False) -> None:
super().__init__()
self.body = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
self.body.eval()
self.num_channels = 384
self.return_interm_layers = return_interm_layers
@torch.no_grad()
def forward(self, tensor):
features = self.body.forward_features(tensor)
if self.return_interm_layers:
layer1 = features["x_norm_patchtokens"]
layer2 = features["x_norm_patchtokens"]
layer3 = features["x_norm_patchtokens"]
layer4 = features["x_norm_patchtokens"]
od = OrderedDict()
od["0"] = layer1.reshape(layer1.shape[0], 22, 16, 384).permute(0, 3, 2, 1)
od["1"] = layer2.reshape(layer2.shape[0], 22, 16, 384).permute(0, 3, 2, 1)
od["2"] = layer3.reshape(layer3.shape[0], 22, 16, 384).permute(0, 3, 2, 1)
od["3"] = layer4.reshape(layer4.shape[0], 22, 16, 384).permute(0, 3, 2, 1)
return od
else:
xs = features["x_norm_patchtokens"]
od = OrderedDict()
od["0"] = xs.reshape(xs.shape[0], 22, 16, 384).permute(0, 3, 2, 1)
return od
class Joiner(nn.Sequential):
def __init__(self, backbone, position_embedding):
super().__init__(backbone, position_embedding)
def forward(self, tensor_list: NestedTensor):
xs = self[0](tensor_list)
out: List[NestedTensor] = []
pos = []
for name, x in xs.items():
out.append(x)
# position encoding
pos.append(self[1](x).to(x.dtype))
return out, pos
def build_backbone(args):
position_embedding = build_position_encoding(args)
train_backbone = args.lr_backbone > 0
return_interm_layers = args.masks
if args.backbone == 'dino_v2':
backbone = DINOv2BackBone()
else:
assert args.backbone in ['resnet18', 'resnet34']
backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation)
model = Joiner(backbone, position_embedding)
model.num_channels = backbone.num_channels
return model
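A minimal shape check for the ResNet path, as a sketch: it assumes the default resnet18 backbone (pretrained weights are downloaded), the sine position embedding, a 224x224 input, and argument names taken from get_args_parser above:

```python
# Sketch: run a dummy image through the Joiner and inspect feature / positional-embedding shapes.
import torch
from argparse import Namespace
from roboimi.ddt.models.backbone import build_backbone

args = Namespace(backbone='resnet18', lr_backbone=1e-5, masks=False,
                 dilation=False, position_embedding='sine', hidden_dim=512)
joiner = build_backbone(args)
feats, pos = joiner(torch.randn(1, 3, 224, 224))
print(feats[0].shape, pos[0].shape)  # e.g. [1, 512, 7, 7] for both with a 224x224 input
```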

631
roboimi/ddt/models/ddt.py Normal file
View File

@@ -0,0 +1,631 @@
"""
Action-sequence diffusion Transformer (Action Decoupled Diffusion Transformer).
Adapted from the DDT architecture to generate robot action sequences.
Main changes:
1. 2D RoPE → 1D RoPE (suited to temporal data)
2. LabelEmbedder → ObservationEncoder (observation conditioning)
3. patchify/unpatchify removed (action sequences are already 1D)
"""
import math
from typing import Tuple, Optional
import torch
import torch.nn as nn
from torch.nn.functional import scaled_dot_product_attention
# ============================================================================
# Generic utilities
# ============================================================================
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
"""AdaLN 调制函数。
Args:
x: 输入张量。
shift: 偏移量。
scale: 缩放量。
Returns:
调制后的张量: x * (1 + scale) + shift
"""
return x * (1 + scale) + shift
# ============================================================================
# 1D rotary position embedding (RoPE)
# ============================================================================
def precompute_freqs_cis_1d(dim: int, seq_len: int, theta: float = 10000.0) -> torch.Tensor:
"""预计算 1D 旋转位置编码的复数频率。
用于时序数据(如动作序列)的位置编码,相比 2D RoPE 更简单高效。
Args:
dim: 每个注意力头的维度 (head_dim)。
seq_len: 序列长度。
theta: RoPE 的基础频率,默认 10000.0。
Returns:
复数频率张量,形状为 (seq_len, dim//2)。
"""
# Frequencies: 1 / (theta^(2i/dim))
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim)) # [dim//2]
# Position indices
t = torch.arange(seq_len).float() # [seq_len]
# Outer product gives the position-frequency matrix
freqs = torch.outer(t, freqs) # [seq_len, dim//2]
# Convert to complex form (polar coordinates)
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # [seq_len, dim//2]
return freqs_cis
def apply_rotary_emb_1d(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""应用 1D 旋转位置编码到 Query 和 Key。
Args:
xq: Query 张量,形状为 (B, N, H, Hc)。
xk: Key 张量,形状为 (B, N, H, Hc)。
freqs_cis: 预计算的复数频率,形状为 (N, Hc//2)。
Returns:
应用 RoPE 后的 (xq, xk),形状不变。
"""
# Reshape freqs_cis for broadcasting: [1, N, 1, Hc//2]
freqs_cis = freqs_cis[None, :, None, :]
# View the real tensors as complex: [B, N, H, Hc] -> [B, N, H, Hc//2] (complex)
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
# Complex multiplication performs the rotation
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # [B, N, H, Hc]
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
# ============================================================================
# Basic building blocks
# ============================================================================
class Embed(nn.Module):
"""线性嵌入层,将输入投影到隐藏空间。"""
def __init__(
self,
in_chans: int = 3,
embed_dim: int = 768,
norm_layer: Optional[nn.Module] = None,
bias: bool = True,
):
"""初始化 Embed。
Args:
in_chans: 输入通道数/维度。
embed_dim: 输出嵌入维度。
norm_layer: 可选的归一化层。
bias: 是否使用偏置。
"""
super().__init__()
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.proj(x)
x = self.norm(x)
return x
class TimestepEmbedder(nn.Module):
"""扩散时间步嵌入器。
使用正弦位置编码 + MLP 将标量时间步映射到高维向量。
"""
def __init__(self, hidden_size: int, frequency_embedding_size: int = 256):
"""初始化 TimestepEmbedder。
Args:
hidden_size: 输出嵌入维度。
frequency_embedding_size: 正弦编码的维度。
"""
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size, bias=True),
)
self.frequency_embedding_size = frequency_embedding_size
@staticmethod
def timestep_embedding(t: torch.Tensor, dim: int, max_period: float = 10.0) -> torch.Tensor:
"""生成正弦时间步嵌入。
Args:
t: 时间步张量,形状为 (B,)。
dim: 嵌入维度。
max_period: 最大周期。
Returns:
时间步嵌入,形状为 (B, dim)。
"""
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
)
args = t[..., None].float() * freqs[None, ...]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def forward(self, t: torch.Tensor) -> torch.Tensor:
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
t_emb = self.mlp(t_freq)
return t_emb
class ObservationEncoder(nn.Module):
"""观测状态编码器。
将机器人的观测向量(如关节位置、末端位姿、图像特征等)
编码为条件向量,用于条件扩散生成。
Attributes:
encoder: 两层 MLP 编码器。
Example:
>>> encoder = ObservationEncoder(obs_dim=128, hidden_size=512)
>>> obs = torch.randn(2, 128)
>>> cond = encoder(obs) # [2, 512]
"""
def __init__(self, obs_dim: int, hidden_size: int):
"""初始化 ObservationEncoder。
Args:
obs_dim: 观测向量的维度。
hidden_size: 输出条件向量的维度。
"""
super().__init__()
self.encoder = nn.Sequential(
nn.Linear(obs_dim, hidden_size),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size),
)
def forward(self, obs: torch.Tensor) -> torch.Tensor:
"""前向传播。
Args:
obs: 观测向量,形状为 (B, obs_dim)。
Returns:
条件向量,形状为 (B, hidden_size)。
"""
return self.encoder(obs)
class FinalLayer(nn.Module):
"""最终输出层,使用 AdaLN 调制后输出预测结果。"""
def __init__(self, hidden_size: int, out_channels: int):
"""初始化 FinalLayer。
Args:
hidden_size: 输入隐藏维度。
out_channels: 输出通道数/维度。
"""
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(hidden_size, out_channels, bias=True)
self.adaLN_modulation = nn.Sequential(
nn.Linear(hidden_size, 2 * hidden_size, bias=True)
)
def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
"""前向传播。
Args:
x: 输入张量,形状为 (B, N, hidden_size)。
c: 条件张量,形状为 (B, N, hidden_size) 或 (B, 1, hidden_size)。
Returns:
输出张量,形状为 (B, N, out_channels)。
"""
shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
x = modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
return x
# ============================================================================
# Normalization and feed-forward networks
# ============================================================================
class RMSNorm(nn.Module):
"""Root Mean Square Layer Normalization (RMS 归一化层)。
RMSNorm 是 LayerNorm 的简化版本,去掉了均值中心化操作,只保留缩放。
相比 LayerNorm 计算更快,效果相当,被广泛用于 LLaMA、Mistral 等大模型。
数学公式:
RMSNorm(x) = x / sqrt(mean(x^2) + eps) * weight
"""
def __init__(self, hidden_size: int, eps: float = 1e-6):
"""初始化 RMSNorm。
Args:
hidden_size: 输入特征的维度。
eps: 防止除零的小常数。
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
class FeedForward(nn.Module):
"""SwiGLU 前馈网络 (Feed-Forward Network)。
使用 SwiGLU 门控激活函数的前馈网络,来自 LLaMA 架构。
结构:
output = W2(SiLU(W1(x)) * W3(x))
"""
def __init__(self, dim: int, hidden_dim: int):
"""初始化 FeedForward。
Args:
dim: 输入和输出的特征维度。
hidden_dim: 隐藏层维度(实际使用 2/3 * hidden_dim
"""
super().__init__()
hidden_dim = int(2 * hidden_dim / 3)
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
return x
# ============================================================================
# Attention
# ============================================================================
class RAttention(nn.Module):
"""带有旋转位置编码的多头自注意力 (Rotary Attention)。
集成了以下技术:
- 1D RoPE: 通过复数旋转编码时序位置信息
- QK-Norm: 对 Query 和 Key 进行归一化,稳定训练
- Flash Attention: 使用 scaled_dot_product_attention
"""
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
qk_norm: bool = True,
attn_drop: float = 0.,
proj_drop: float = 0.,
norm_layer: nn.Module = RMSNorm,
) -> None:
"""初始化 RAttention。
Args:
dim: 输入特征维度,必须能被 num_heads 整除。
num_heads: 注意力头数。
qkv_bias: QKV 投影是否使用偏置。
qk_norm: 是否对 Q, K 进行归一化。
attn_drop: 注意力权重的 dropout 率。
proj_drop: 输出投影的 dropout 率。
norm_layer: 归一化层类型。
"""
super().__init__()
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(
self,
x: torch.Tensor,
pos: torch.Tensor,
mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""前向传播。
Args:
x: 输入张量,形状为 (B, N, C)。
pos: 1D RoPE 位置编码,形状为 (N, head_dim//2)。
mask: 可选的注意力掩码。
Returns:
输出张量,形状为 (B, N, C)。
"""
B, N, C = x.shape
# QKV projection
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # [B, N, H, Hc]
# QK-Norm
q = self.q_norm(q)
k = self.k_norm(k)
# Apply 1D RoPE
q, k = apply_rotary_emb_1d(q, k, freqs_cis=pos)
# Rearrange: [B, N, H, Hc] -> [B, H, N, Hc]
q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2)
k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
# Scaled Dot-Product Attention
x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
# Output projection
x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
# ============================================================================
# Transformer Block
# ============================================================================
class ActionDDTBlock(nn.Module):
"""动作 DDT Transformer Block。
结构: Pre-Norm + AdaLN + Attention + FFN
"""
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float = 4.0):
"""初始化 ActionDDTBlock。
Args:
hidden_size: 隐藏层维度。
num_heads: 注意力头数。
mlp_ratio: FFN 隐藏层倍率。
"""
super().__init__()
self.norm1 = RMSNorm(hidden_size, eps=1e-6)
self.attn = RAttention(hidden_size, num_heads=num_heads, qkv_bias=False)
self.norm2 = RMSNorm(hidden_size, eps=1e-6)
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
self.adaLN_modulation = nn.Sequential(
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
def forward(
self,
x: torch.Tensor,
c: torch.Tensor,
pos: torch.Tensor,
mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""前向传播。
Args:
x: 输入张量,形状为 (B, N, hidden_size)。
c: 条件张量,形状为 (B, 1, hidden_size) 或 (B, N, hidden_size)。
pos: 位置编码。
mask: 可选的注意力掩码。
Returns:
输出张量,形状为 (B, N, hidden_size)。
"""
# AdaLN modulation parameters
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = \
self.adaLN_modulation(c).chunk(6, dim=-1)
# Attention branch
x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
# FFN branch
x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
return x
# ============================================================================
# Main model: ActionDDT
# ============================================================================
class ActionDDT(nn.Module):
"""动作序列解耦扩散 Transformer (Action Decoupled Diffusion Transformer)。
基于 DDT 架构,专为机器人动作序列生成设计。
将模型解耦为编码器和解码器两部分,编码器状态可缓存以加速推理。
架构:
- 编码器: 前 num_encoder_blocks 个 block生成状态 s
- 解码器: 剩余 block使用状态 s 对动作序列 x 去噪
- 使用 1D RoPE 进行时序位置编码
- 使用 AdaLN 注入时间步和观测条件
Args:
action_dim: 动作向量维度(如 7-DoF 机械臂为 7
obs_dim: 观测向量维度。
action_horizon: 预测的动作序列长度。
hidden_size: Transformer 隐藏层维度。
num_blocks: Transformer block 总数。
num_encoder_blocks: 编码器 block 数量。
num_heads: 注意力头数。
mlp_ratio: FFN 隐藏层倍率。
输入:
x (Tensor): 带噪声的动作序列,形状为 (B, T, action_dim)。
t (Tensor): 扩散时间步,形状为 (B,),取值范围 [0, 1]。
obs (Tensor): 观测条件,形状为 (B, obs_dim)。
s (Tensor, optional): 缓存的编码器状态。
输出:
x (Tensor): 预测的速度场/噪声,形状为 (B, T, action_dim)。
s (Tensor): 编码器状态,可缓存复用。
Example:
>>> model = ActionDDT(action_dim=7, obs_dim=128, action_horizon=16)
>>> x = torch.randn(2, 16, 7) # noisy action sequence
>>> t = torch.rand(2) # random timesteps
>>> obs = torch.randn(2, 128) # observation condition
>>> out, state = model(x, t, obs)
>>> out.shape
torch.Size([2, 16, 7])
"""
def __init__(
self,
action_dim: int = 7,
obs_dim: int = 128,
action_horizon: int = 16,
hidden_size: int = 512,
num_blocks: int = 12,
num_encoder_blocks: int = 4,
num_heads: int = 8,
mlp_ratio: float = 4.0,
):
super().__init__()
# Save the configuration
self.action_dim = action_dim
self.obs_dim = obs_dim
self.action_horizon = action_horizon
self.hidden_size = hidden_size
self.num_blocks = num_blocks
self.num_encoder_blocks = num_encoder_blocks
self.num_heads = num_heads
# Action embedding layers
self.x_embedder = Embed(action_dim, hidden_size, bias=True)
self.s_embedder = Embed(action_dim, hidden_size, bias=True)
# Condition embeddings
self.t_embedder = TimestepEmbedder(hidden_size)
self.obs_encoder = ObservationEncoder(obs_dim, hidden_size)
# Output layer
self.final_layer = FinalLayer(hidden_size, action_dim)
# Transformer blocks
self.blocks = nn.ModuleList([
ActionDDTBlock(hidden_size, num_heads, mlp_ratio)
for _ in range(num_blocks)
])
# Precompute the 1D position encoding
pos = precompute_freqs_cis_1d(hidden_size // num_heads, action_horizon)
self.register_buffer('pos', pos)
# Initialize weights
self.initialize_weights()
def initialize_weights(self):
"""初始化模型权重。"""
# 嵌入层使用 Xavier 初始化
for embedder in [self.x_embedder, self.s_embedder]:
w = embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
nn.init.constant_(embedder.proj.bias, 0)
# Timestep embedding MLP
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
# Observation encoder
for m in self.obs_encoder.encoder:
if isinstance(m, nn.Linear):
nn.init.normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
# Zero-initialize the output layer (AdaLN-Zero)
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
nn.init.constant_(self.final_layer.linear.weight, 0)
nn.init.constant_(self.final_layer.linear.bias, 0)
def forward(
self,
x: torch.Tensor,
t: torch.Tensor,
obs: torch.Tensor,
s: Optional[torch.Tensor] = None,
mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""前向传播。
Args:
x: 带噪声的动作序列 [B, T, action_dim]
t: 扩散时间步 [B] 或 [B, 1],取值范围 [0, 1]
obs: 观测条件 [B, obs_dim]
s: 可选的编码器状态缓存 [B, T, hidden_size]
mask: 可选的注意力掩码
Returns:
x: 预测的速度场/噪声 [B, T, action_dim]
s: 编码器状态 [B, T, hidden_size],可缓存复用
"""
B, T, _ = x.shape
# 1. Timestep embedding: [B] -> [B, 1, hidden_size]
t_emb = self.t_embedder(t.view(-1)).view(B, 1, self.hidden_size)
# 2. Observation embedding: [B, obs_dim] -> [B, 1, hidden_size]
obs_emb = self.obs_encoder(obs).view(B, 1, self.hidden_size)
# 3. Fuse the conditions: c = SiLU(t + obs)
c = nn.functional.silu(t_emb + obs_emb)
# 4. Encoder: produce the state s
if s is None:
# State embedding: [B, T, action_dim] -> [B, T, hidden_size]
s = self.s_embedder(x)
# Pass through the encoder blocks
for i in range(self.num_encoder_blocks):
s = self.blocks[i](s, c, self.pos, mask)
# Fuse in the timestep information
s = nn.functional.silu(t_emb + s)
# 5. Decoder: denoise
# Input embedding: [B, T, action_dim] -> [B, T, hidden_size]
x = self.x_embedder(x)
# Pass through the decoder blocks, conditioned on s
for i in range(self.num_encoder_blocks, self.num_blocks):
x = self.blocks[i](x, s, self.pos, None)
# 6. Final layer: [B, T, hidden_size] -> [B, T, action_dim]
x = self.final_layer(x, s)
return x, s
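A small sketch of the decoupled interface (shapes follow the docstring example): the first call runs the encoder and returns its state, and a later call can pass that state back in via s instead of re-encoding:

```python
# Sketch: cache and reuse the encoder state s across calls.
import torch
from roboimi.ddt.models.ddt import ActionDDT

model = ActionDDT(action_dim=7, obs_dim=128, action_horizon=16)
x = torch.randn(2, 16, 7)
t = torch.rand(2)
obs = torch.randn(2, 128)

out1, s = model(x, t, obs)        # encoder runs; its state s is returned
out2, _ = model(x, t, obs, s=s)   # encoder blocks are skipped; the cached s is reused
print(out1.shape, s.shape)        # torch.Size([2, 16, 7]) torch.Size([2, 16, 512])
```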

304
roboimi/ddt/models/model.py Normal file
View File

@@ -0,0 +1,304 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
DDT model and criterion classes.
Core assembly file: combines the Backbone, Transformer, and Diffusion components into the complete model.
"""
import torch
import torch.nn.functional as F
from torch import nn
import numpy as np
from .backbone import build_backbone
class SpatialSoftmax(nn.Module):
"""Spatial Softmax 层,将特征图转换为关键点坐标。
来自 Diffusion Policy保留空间位置信息。
对每个通道计算软注意力加权的期望坐标。
Args:
num_kp: 关键点数量(等于输入通道数)
temperature: Softmax 温度参数(可学习)
learnable_temperature: 是否学习温度参数
输入: [B, C, H, W]
输出: [B, C * 2] - 每个通道输出 (x, y) 坐标
"""
def __init__(self, num_kp: int = None, temperature: float = 1.0, learnable_temperature: bool = True):
super().__init__()
self.num_kp = num_kp
if learnable_temperature:
self.temperature = nn.Parameter(torch.ones(1) * temperature)
else:
self.register_buffer('temperature', torch.ones(1) * temperature)
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, C, H, W = x.shape
# Build a normalized coordinate grid in [-1, 1]
pos_x = torch.linspace(-1, 1, W, device=x.device, dtype=x.dtype)
pos_y = torch.linspace(-1, 1, H, device=x.device, dtype=x.dtype)
# Flatten the spatial dimensions
x_flat = x.view(B, C, -1) # [B, C, H*W]
# Softmax gives the attention weights
attention = F.softmax(x_flat / self.temperature, dim=-1) # [B, C, H*W]
# Compute the expected coordinates
# pos_x: [W] -> [1, 1, W] -> repeat -> [1, 1, H*W]
pos_x_grid = pos_x.view(1, 1, 1, W).expand(1, 1, H, W).reshape(1, 1, -1)
pos_y_grid = pos_y.view(1, 1, H, 1).expand(1, 1, H, W).reshape(1, 1, -1)
# Weighted sum gives the expected coordinates
expected_x = (attention * pos_x_grid).sum(dim=-1) # [B, C]
expected_y = (attention * pos_y_grid).sum(dim=-1) # [B, C]
# Concatenate the x and y coordinates
keypoints = torch.cat([expected_x, expected_y], dim=-1) # [B, C * 2]
return keypoints
from .ddt import ActionDDT
def get_sinusoid_encoding_table(n_position, d_hid):
"""生成正弦位置编码表。"""
def get_position_angle_vec(position):
return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
return torch.FloatTensor(sinusoid_table).unsqueeze(0)
class DDT(nn.Module):
"""DDT (Decoupled Diffusion Transformer) 模型。
将视觉 Backbone 和 ActionDDT 扩散模型组合,实现基于图像观测的动作序列生成。
架构:
1. Backbone: 提取多相机图像特征
2. 特征投影: 将图像特征投影到隐藏空间 (Bottleneck 降维)
3. 状态编码: 编码机器人关节状态
4. ActionDDT: 扩散 Transformer 生成动作序列
Args:
backbones: 视觉骨干网络列表(每个相机一个)
state_dim: 机器人状态维度
action_dim: 动作维度
num_queries: 预测的动作序列长度
camera_names: 相机名称列表
hidden_dim: Transformer 隐藏维度
num_blocks: Transformer block 数量
num_encoder_blocks: 编码器 block 数量
num_heads: 注意力头数
num_kp: Spatial Softmax 的关键点数量 (默认 32)
"""
def __init__(
self,
backbones,
state_dim: int,
action_dim: int,
num_queries: int,
camera_names: list,
hidden_dim: int = 512,
num_blocks: int = 12,
num_encoder_blocks: int = 4,
num_heads: int = 8,
mlp_ratio: float = 4.0,
num_kp: int = 32, # [modified] new parameter, default 32
):
super().__init__()
self.num_queries = num_queries
self.camera_names = camera_names
self.hidden_dim = hidden_dim
self.state_dim = state_dim
self.action_dim = action_dim
self.num_kp = num_kp
# Backbone
self.backbones = nn.ModuleList(backbones)
# [modified] Projection layer: ResNet channels -> num_kp (32)
# This is a bottleneck layer that greatly reduces the number of feature channels
self.input_proj = nn.Conv2d(
backbones[0].num_channels, num_kp, kernel_size=1
)
# State encoding (2-layer MLP, as in Diffusion Policy)
# The state is still mapped to hidden_dim (512) to preserve information
self.input_proj_robot_state = nn.Sequential(
nn.Linear(state_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
)
# [modified] Image feature aggregation (SpatialSoftmax)
# Input: [B, num_kp, H, W]
# Output: [B, num_kp * 2] (x, y coordinates for each channel)
self.img_feature_proj = SpatialSoftmax(num_kp=num_kp)
# [modified] Compute the observation dimension: image features + state
# Image part: number of keypoints * 2 (x, y) * number of cameras
img_feature_dim = num_kp * 2 * len(camera_names)
obs_dim = img_feature_dim + hidden_dim
# ActionDDT diffusion model
self.action_ddt = ActionDDT(
action_dim=action_dim,
obs_dim=obs_dim, # use the new, more compact dimension
action_horizon=num_queries,
hidden_size=hidden_dim,
num_blocks=num_blocks,
num_encoder_blocks=num_encoder_blocks,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
)
def encode_observations(self, qpos, image):
"""编码观测(图像 + 状态)为条件向量。
Args:
qpos: 机器人关节状态 [B, state_dim]
image: 多相机图像 [B, num_cam, C, H, W]
Returns:
obs: 观测条件向量 [B, obs_dim]
"""
bs = qpos.shape[0]
# Encode image features
all_cam_features = []
for cam_id, cam_name in enumerate(self.camera_names):
features, pos = self.backbones[cam_id](image[:, cam_id])
features = features[0] # take the last-layer feature map
# [note] input_proj compresses the channels down to num_kp (32)
features = self.input_proj(features) # [B, num_kp, H', W']
# [note] SpatialSoftmax extracts the 32 keypoint coordinates
features = self.img_feature_proj(features) # [B, num_kp * 2]
all_cam_features.append(features)
# Concatenate features from all cameras
img_features = torch.cat(all_cam_features, dim=-1) # [B, num_kp * 2 * num_cam]
# Encode the state
qpos_features = self.input_proj_robot_state(qpos) # [B, hidden_dim]
# Concatenate into the observation
obs = torch.cat([img_features, qpos_features], dim=-1) # [B, obs_dim]
return obs
def forward(
self,
qpos,
image,
env_state,
actions=None,
is_pad=None,
timesteps=None,
):
"""前向传播。
训练时:
输入带噪声的动作序列和时间步,预测噪声/速度场
推理时:
通过扩散采样生成动作序列
Args:
qpos: 机器人关节状态 [B, state_dim]
image: 多相机图像 [B, num_cam, C, H, W]
env_state: 环境状态(未使用)
actions: 动作序列 [B, T, action_dim](训练时为带噪声动作)
is_pad: padding 标记 [B, T](未使用)
timesteps: 扩散时间步 [B](训练时提供)
Returns:
训练时: (noise_pred, encoder_state)
推理时: (action_pred, encoder_state)
"""
# 1. Encode the observation
obs = self.encode_observations(qpos, image)
# 2. Diffusion model forward
if actions is not None and timesteps is not None:
# Training mode: predict the noise / velocity
noise_pred, encoder_state = self.action_ddt(
x=actions,
t=timesteps,
obs=obs,
)
return noise_pred, encoder_state
else:
# Inference mode: diffusion sampling is done in the Policy layer
# Return the encoded observation here for the Policy layer to use
return obs, None
def get_obs_dim(self):
"""返回观测向量的维度。"""
# [修改] 使用 num_kp 重新计算
return self.num_kp * 2 * len(self.camera_names) + self.hidden_dim
def build(args):
"""构建 DDT 模型。
Args:
args: 包含模型配置的参数对象
- state_dim: 状态维度
- action_dim: 动作维度
- camera_names: 相机名称列表
- hidden_dim: 隐藏维度
- num_queries: 动作序列长度
- num_blocks: Transformer block 数量
- enc_layers: 编码器层数
- nheads: 注意力头数
- num_kp: 关键点数量 (可选默认32)
Returns:
model: DDT 模型实例
"""
state_dim = args.state_dim
action_dim = args.action_dim
# Build one backbone per camera
backbones = []
for _ in args.camera_names:
backbone = build_backbone(args)
backbones.append(backbone)
# Build the DDT model
model = DDT(
backbones=backbones,
state_dim=state_dim,
action_dim=action_dim,
num_queries=args.num_queries,
camera_names=args.camera_names,
hidden_dim=args.hidden_dim,
num_blocks=getattr(args, 'num_blocks', 12),
num_encoder_blocks=getattr(args, 'enc_layers', 4),
num_heads=args.nheads,
mlp_ratio=getattr(args, 'mlp_ratio', 4.0),
num_kp=getattr(args, 'num_kp', 32), # [modified] pass num_kp through
)
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("number of parameters: %.2fM" % (n_parameters / 1e6,))
return model
def build_ddt(args):
"""build 的别名,保持接口一致性。"""
return build(args)
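A small sanity sketch of the SpatialSoftmax bottleneck and the resulting observation dimension (numbers follow the defaults above: num_kp=32, hidden_dim=512, one camera; the spatial size of the feature map is illustrative):

```python
# Sketch: SpatialSoftmax turns a [B, num_kp, H, W] map into [B, num_kp * 2] keypoints,
# so with one camera obs_dim = 32 * 2 * 1 + 512 = 576.
import torch
from roboimi.ddt.models.model import SpatialSoftmax

ss = SpatialSoftmax(num_kp=32)
feat = torch.randn(2, 32, 7, 7)   # projected backbone features
kp = ss(feat)
print(kp.shape)                   # torch.Size([2, 64])
print(32 * 2 * 1 + 512)           # 576, the obs_dim passed to ActionDDT
```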

View File

@@ -0,0 +1,91 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Various positional encodings for the transformer.
"""
import math
import torch
from torch import nn
from util.misc import NestedTensor
class PositionEmbeddingSine(nn.Module):
"""
This is a more standard version of the position embedding, very similar to the one
used by the Attention is all you need paper, generalized to work on images.
"""
def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
super().__init__()
self.num_pos_feats = num_pos_feats
self.temperature = temperature
self.normalize = normalize
if scale is not None and normalize is False:
raise ValueError("normalize should be True if scale is passed")
if scale is None:
scale = 2 * math.pi
self.scale = scale
def forward(self, tensor):
x = tensor
# mask = tensor_list.mask
# assert mask is not None
# not_mask = ~mask
not_mask = torch.ones_like(x[0, [0]])
y_embed = not_mask.cumsum(1, dtype=torch.float32)
x_embed = not_mask.cumsum(2, dtype=torch.float32)
if self.normalize:
eps = 1e-6
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
pos_x = x_embed[:, :, :, None] / dim_t
pos_y = y_embed[:, :, :, None] / dim_t
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
return pos
class PositionEmbeddingLearned(nn.Module):
"""
Absolute pos embedding, learned.
"""
def __init__(self, num_pos_feats=256):
super().__init__()
self.row_embed = nn.Embedding(50, num_pos_feats)
self.col_embed = nn.Embedding(50, num_pos_feats)
self.reset_parameters()
def reset_parameters(self):
nn.init.uniform_(self.row_embed.weight)
nn.init.uniform_(self.col_embed.weight)
def forward(self, tensor_list: NestedTensor):
x = tensor_list.tensors
h, w = x.shape[-2:]
i = torch.arange(w, device=x.device)
j = torch.arange(h, device=x.device)
x_emb = self.col_embed(i)
y_emb = self.row_embed(j)
pos = torch.cat([
x_emb.unsqueeze(0).repeat(h, 1, 1),
y_emb.unsqueeze(1).repeat(1, w, 1),
], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
return pos
def build_position_encoding(args):
N_steps = args.hidden_dim // 2
if args.position_embedding in ('v2', 'sine'):
# TODO find a better way of exposing other arguments
position_embedding = PositionEmbeddingSine(N_steps, normalize=True)
elif args.position_embedding in ('v3', 'learned'):
position_embedding = PositionEmbeddingLearned(N_steps)
else:
raise ValueError(f"not supported {args.position_embedding}")
return position_embedding

View File

@@ -0,0 +1,312 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
DETR Transformer class.
Copy-paste from torch.nn.Transformer with modifications:
* positional encodings are passed in MHattention
* extra LN at the end of encoder is removed
* decoder returns a stack of activations from all decoding layers
"""
import copy
from typing import Optional, List
import torch
import torch.nn.functional as F
from torch import nn, Tensor
class Transformer(nn.Module):
def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
activation="relu", normalize_before=False,
return_intermediate_dec=False):
super().__init__()
encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
dropout, activation, normalize_before)
encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward,
dropout, activation, normalize_before)
decoder_norm = nn.LayerNorm(d_model)
self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,
return_intermediate=return_intermediate_dec)
self._reset_parameters()
self.d_model = d_model
self.nhead = nhead
def _reset_parameters(self):
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def forward(self, src, mask, query_embed, pos_embed, latent_input=None, proprio_input=None, additional_pos_embed=None):
# TODO flatten only when input has H and W
if len(src.shape) == 4: # has H and W
# flatten NxCxHxW to HWxNxC
bs, c, h, w = src.shape
src = src.flatten(2).permute(2, 0, 1)
pos_embed = pos_embed.flatten(2).permute(2, 0, 1).repeat(1, bs, 1)
query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
# mask = mask.flatten(1)
additional_pos_embed = additional_pos_embed.unsqueeze(1).repeat(1, bs, 1) # seq, bs, dim
pos_embed = torch.cat([additional_pos_embed, pos_embed], axis=0)
addition_input = torch.stack([latent_input, proprio_input], axis=0)
src = torch.cat([addition_input, src], axis=0)
else:
assert len(src.shape) == 3
# flatten NxHWxC to HWxNxC
bs, hw, c = src.shape
src = src.permute(1, 0, 2)
pos_embed = pos_embed.unsqueeze(1).repeat(1, bs, 1)
query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
tgt = torch.zeros_like(query_embed)
memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
hs = self.decoder(tgt, memory, memory_key_padding_mask=mask,
pos=pos_embed, query_pos=query_embed)
hs = hs.transpose(1, 2)
return hs
class TransformerEncoder(nn.Module):
def __init__(self, encoder_layer, num_layers, norm=None):
super().__init__()
self.layers = _get_clones(encoder_layer, num_layers)
self.num_layers = num_layers
self.norm = norm
def forward(self, src,
mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None):
output = src
for layer in self.layers:
output = layer(output, src_mask=mask,
src_key_padding_mask=src_key_padding_mask, pos=pos)
if self.norm is not None:
output = self.norm(output)
return output
class TransformerDecoder(nn.Module):
def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
super().__init__()
self.layers = _get_clones(decoder_layer, num_layers)
self.num_layers = num_layers
self.norm = norm
self.return_intermediate = return_intermediate
def forward(self, tgt, memory,
tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None):
output = tgt
intermediate = []
for layer in self.layers:
output = layer(output, memory, tgt_mask=tgt_mask,
memory_mask=memory_mask,
tgt_key_padding_mask=tgt_key_padding_mask,
memory_key_padding_mask=memory_key_padding_mask,
pos=pos, query_pos=query_pos)
if self.return_intermediate:
intermediate.append(self.norm(output))
if self.norm is not None:
output = self.norm(output)
if self.return_intermediate:
intermediate.pop()
intermediate.append(output)
if self.return_intermediate:
return torch.stack(intermediate)
return output.unsqueeze(0)
class TransformerEncoderLayer(nn.Module):
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
activation="relu", normalize_before=False):
super().__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
# Implementation of Feedforward model
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.activation = _get_activation_fn(activation)
self.normalize_before = normalize_before
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
return tensor if pos is None else tensor + pos
def forward_post(self,
src,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None):
q = k = self.with_pos_embed(src, pos)
src2 = self.self_attn(q, k, value=src, attn_mask=src_mask,
key_padding_mask=src_key_padding_mask)[0]
src = src + self.dropout1(src2)
src = self.norm1(src)
src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
src = src + self.dropout2(src2)
src = self.norm2(src)
return src
def forward_pre(self, src,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None):
src2 = self.norm1(src)
q = k = self.with_pos_embed(src2, pos)
src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask,
key_padding_mask=src_key_padding_mask)[0]
src = src + self.dropout1(src2)
src2 = self.norm2(src)
src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
src = src + self.dropout2(src2)
return src
def forward(self, src,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None):
if self.normalize_before:
return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
return self.forward_post(src, src_mask, src_key_padding_mask, pos)
class TransformerDecoderLayer(nn.Module):
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
activation="relu", normalize_before=False):
super().__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
# Implementation of Feedforward model
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.dropout3 = nn.Dropout(dropout)
self.activation = _get_activation_fn(activation)
self.normalize_before = normalize_before
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
return tensor if pos is None else tensor + pos
def forward_post(self, tgt, memory,
tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None):
q = k = self.with_pos_embed(tgt, query_pos)
tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
key_padding_mask=tgt_key_padding_mask)[0]
tgt = tgt + self.dropout1(tgt2)
tgt = self.norm1(tgt)
tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
key=self.with_pos_embed(memory, pos),
value=memory, attn_mask=memory_mask,
key_padding_mask=memory_key_padding_mask)[0]
tgt = tgt + self.dropout2(tgt2)
tgt = self.norm2(tgt)
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
tgt = tgt + self.dropout3(tgt2)
tgt = self.norm3(tgt)
return tgt
def forward_pre(self, tgt, memory,
tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None):
tgt2 = self.norm1(tgt)
q = k = self.with_pos_embed(tgt2, query_pos)
tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
key_padding_mask=tgt_key_padding_mask)[0]
tgt = tgt + self.dropout1(tgt2)
tgt2 = self.norm2(tgt)
tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
key=self.with_pos_embed(memory, pos),
value=memory, attn_mask=memory_mask,
key_padding_mask=memory_key_padding_mask)[0]
tgt = tgt + self.dropout2(tgt2)
tgt2 = self.norm3(tgt)
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
tgt = tgt + self.dropout3(tgt2)
return tgt
def forward(self, tgt, memory,
tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None):
if self.normalize_before:
return self.forward_pre(tgt, memory, tgt_mask, memory_mask,
tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
return self.forward_post(tgt, memory, tgt_mask, memory_mask,
tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
def _get_clones(module, N):
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
def build_transformer(args):
return Transformer(
d_model=args.hidden_dim,
dropout=args.dropout,
nhead=args.nheads,
dim_feedforward=args.dim_feedforward,
num_encoder_layers=args.enc_layers,
num_decoder_layers=args.dec_layers,
normalize_before=args.pre_norm,
return_intermediate_dec=True,
)
def _get_activation_fn(activation):
"""Return an activation function given a string"""
if activation == "relu":
return F.relu
if activation == "gelu":
return F.gelu
if activation == "glu":
return F.glu
raise RuntimeError(F"activation should be relu/gelu, not {activation}.")

147
roboimi/ddt/policy.py Normal file
View File

@@ -0,0 +1,147 @@
"""
DDT Policy - diffusion-based action generation policy.
Supports Flow Matching training and inference.
"""
import torch
import torch.nn as nn
from torch.nn import functional as F
import torchvision.transforms as transforms
from torchvision.transforms import v2
import math
from roboimi.ddt.main import build_DDT_model_and_optimizer
class DDTPolicy(nn.Module):
"""DDT (Decoupled Diffusion Transformer) 策略。
使用 Flow Matching 进行训练,支持多步扩散采样推理。
带数据增强,适配 DINOv2 等 ViT backbone。
Args:
args_override: 配置参数字典
- num_inference_steps: 推理时的扩散步数
- qpos_noise_std: qpos 噪声标准差(训练时数据增强)
- patch_h, patch_w: 图像 patch 数量(用于计算目标尺寸)
"""
def __init__(self, args_override):
super().__init__()
model, optimizer = build_DDT_model_and_optimizer(args_override)
self.model = model
self.optimizer = optimizer
self.num_inference_steps = args_override.get('num_inference_steps', 10)
self.qpos_noise_std = args_override.get('qpos_noise_std', 0.0)
# Image size configuration (to match DINOv2)
self.patch_h = args_override.get('patch_h', 16)
self.patch_w = args_override.get('patch_w', 22)
print(f'DDT Policy: {self.num_inference_steps} steps, '
f'image size ({self.patch_h*14}, {self.patch_w*14})')
def __call__(self, qpos, image, actions=None, is_pad=None):
"""前向传播。
训练时: 使用 Flow Matching 损失
推理时: 通过扩散采样生成动作
Args:
qpos: 机器人关节状态 [B, state_dim]
image: 多相机图像 [B, num_cam, C, H, W]
actions: 目标动作序列 [B, T, action_dim](训练时提供)
is_pad: padding 标记 [B, T]
Returns:
训练时: loss_dict
推理时: 预测的动作序列 [B, T, action_dim]
"""
env_state = None
# Image preprocessing
if actions is not None: # training: data augmentation
transform = v2.Compose([
v2.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),
v2.RandomPerspective(distortion_scale=0.5),
v2.RandomAffine(degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1)),
v2.GaussianBlur(kernel_size=(9, 9), sigma=(0.1, 2.0)),
v2.Resize((self.patch_h * 14, self.patch_w * 14)),
v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])
if self.qpos_noise_std > 0:
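# note: the configured value is square-rooted before scaling the Gaussian noise, i.e. it is effectively treated as a variance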
qpos = qpos + (self.qpos_noise_std ** 0.5) * torch.randn_like(qpos)
else: # inference
transform = v2.Compose([
v2.Resize((self.patch_h * 14, self.patch_w * 14)),
v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])
image = transform(image)
if actions is not None:
actions = actions[:, :self.model.num_queries]
is_pad = is_pad[:, :self.model.num_queries]
loss_dict = self._compute_loss(qpos, image, actions, is_pad)
return loss_dict
else:
a_hat = self._sample(qpos, image)
return a_hat
def _compute_loss(self, qpos, image, actions, is_pad):
"""计算 Flow Matching 损失。
Flow Matching 目标: 学习从噪声到数据的向量场
损失: ||v_theta(x_t, t) - (x_1 - x_0)||^2
其中 x_t = (1-t)*x_0 + t*x_1, x_0 是噪声, x_1 是目标动作
"""
B, T, action_dim = actions.shape
device = actions.device
t = torch.rand(B, device=device)
noise = torch.randn_like(actions)
t_expand = t.view(B, 1, 1).expand(B, T, action_dim)
x_t = (1 - t_expand) * noise + t_expand * actions
target_velocity = actions - noise
pred_velocity, _ = self.model(
qpos=qpos,
image=image,
env_state=None,
actions=x_t,
timesteps=t,
)
all_loss = F.mse_loss(pred_velocity, target_velocity, reduction='none')
loss = (all_loss * ~is_pad.unsqueeze(-1)).mean()
return {'flow_loss': loss, 'loss': loss}
@torch.no_grad()
def _sample(self, qpos, image):
"""通过 ODE 求解进行扩散采样。
使用 Euler 方法从 t=0 积分到 t=1:
x_{t+dt} = x_t + v_theta(x_t, t) * dt
"""
B = qpos.shape[0]
T = self.model.num_queries
action_dim = self.model.action_dim
device = qpos.device
x = torch.randn(B, T, action_dim, device=device)
obs = self.model.encode_observations(qpos, image)
dt = 1.0 / self.num_inference_steps
for i in range(self.num_inference_steps):
t = torch.full((B,), i * dt, device=device)
velocity, _ = self.model.action_ddt(x=x, t=t, obs=obs)
x = x + velocity * dt
return x
def configure_optimizers(self):
"""返回优化器。"""
return self.optimizer
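For context, a hedged end-to-end sketch of DDTPolicy: the override keys and tensor shapes below are illustrative (single camera, 14-dim state/action), and a CUDA device is assumed because the model builder moves the network to the GPU:

```python
# Hypothetical sketch: one training call (Flow Matching loss) and one inference call (Euler sampling).
import torch
from roboimi.ddt.policy import DDTPolicy

policy = DDTPolicy({
    'camera_names': ['top'],
    'num_queries': 16,         # action horizon
    'state_dim': 14,
    'action_dim': 14,
    'num_inference_steps': 10,
})
qpos = torch.randn(2, 14).cuda()
image = torch.rand(2, 1, 3, 240, 320).cuda()       # resized internally to (patch_h*14, patch_w*14)
actions = torch.randn(2, 16, 14).cuda()
is_pad = torch.zeros(2, 16, dtype=torch.bool).cuda()

loss_dict = policy(qpos, image, actions, is_pad)   # training path
with torch.no_grad():
    a_hat = policy(qpos, image)                    # inference path
print(loss_dict['loss'].item(), a_hat.shape)       # ..., torch.Size([2, 16, 14])
```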

View File

@@ -0,0 +1 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

View File

@@ -0,0 +1,88 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Utilities for bounding box manipulation and GIoU.
"""
import torch
from torchvision.ops.boxes import box_area
def box_cxcywh_to_xyxy(x):
x_c, y_c, w, h = x.unbind(-1)
b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
(x_c + 0.5 * w), (y_c + 0.5 * h)]
return torch.stack(b, dim=-1)
def box_xyxy_to_cxcywh(x):
x0, y0, x1, y1 = x.unbind(-1)
b = [(x0 + x1) / 2, (y0 + y1) / 2,
(x1 - x0), (y1 - y0)]
return torch.stack(b, dim=-1)
# modified from torchvision to also return the union
def box_iou(boxes1, boxes2):
area1 = box_area(boxes1)
area2 = box_area(boxes2)
lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
wh = (rb - lt).clamp(min=0) # [N,M,2]
inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
union = area1[:, None] + area2 - inter
iou = inter / union
return iou, union
def generalized_box_iou(boxes1, boxes2):
"""
Generalized IoU from https://giou.stanford.edu/
The boxes should be in [x0, y0, x1, y1] format
Returns a [N, M] pairwise matrix, where N = len(boxes1)
and M = len(boxes2)
"""
# degenerate boxes gives inf / nan results
# so do an early check
assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
iou, union = box_iou(boxes1, boxes2)
lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
wh = (rb - lt).clamp(min=0) # [N,M,2]
area = wh[:, :, 0] * wh[:, :, 1]
return iou - (area - union) / area
def masks_to_boxes(masks):
"""Compute the bounding boxes around the provided masks
The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions.
Returns a [N, 4] tensors, with the boxes in xyxy format
"""
if masks.numel() == 0:
return torch.zeros((0, 4), device=masks.device)
h, w = masks.shape[-2:]
y = torch.arange(0, h, dtype=torch.float)
x = torch.arange(0, w, dtype=torch.float)
y, x = torch.meshgrid(y, x)
x_mask = (masks * x.unsqueeze(0))
x_max = x_mask.flatten(1).max(-1)[0]
x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
y_mask = (masks * y.unsqueeze(0))
y_max = y_mask.flatten(1).max(-1)[0]
y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
return torch.stack([x_min, y_min, x_max, y_max], 1)

468
roboimi/ddt/util/misc.py Normal file
View File

@@ -0,0 +1,468 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Misc functions, including distributed helpers.
Mostly copy-paste from torchvision references.
"""
import os
import subprocess
import time
from collections import defaultdict, deque
import datetime
import pickle
from packaging import version
from typing import Optional, List
import torch
import torch.distributed as dist
from torch import Tensor
# needed due to empty tensor bug in pytorch and torchvision 0.5
import torchvision
if version.parse(torchvision.__version__) < version.parse('0.7'):
from torchvision.ops import _new_empty_tensor
from torchvision.ops.misc import _output_size
class SmoothedValue(object):
"""Track a series of values and provide access to smoothed values over a
window or the global series average.
"""
def __init__(self, window_size=20, fmt=None):
if fmt is None:
fmt = "{median:.4f} ({global_avg:.4f})"
self.deque = deque(maxlen=window_size)
self.total = 0.0
self.count = 0
self.fmt = fmt
def update(self, value, n=1):
self.deque.append(value)
self.count += n
self.total += value * n
def synchronize_between_processes(self):
"""
Warning: does not synchronize the deque!
"""
if not is_dist_avail_and_initialized():
return
t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
dist.barrier()
dist.all_reduce(t)
t = t.tolist()
self.count = int(t[0])
self.total = t[1]
@property
def median(self):
d = torch.tensor(list(self.deque))
return d.median().item()
@property
def avg(self):
d = torch.tensor(list(self.deque), dtype=torch.float32)
return d.mean().item()
@property
def global_avg(self):
return self.total / self.count
@property
def max(self):
return max(self.deque)
@property
def value(self):
return self.deque[-1]
def __str__(self):
return self.fmt.format(
median=self.median,
avg=self.avg,
global_avg=self.global_avg,
max=self.max,
value=self.value)
def all_gather(data):
"""
Run all_gather on arbitrary picklable data (not necessarily tensors)
Args:
data: any picklable object
Returns:
list[data]: list of data gathered from each rank
"""
world_size = get_world_size()
if world_size == 1:
return [data]
# serialized to a Tensor
buffer = pickle.dumps(data)
storage = torch.ByteStorage.from_buffer(buffer)
tensor = torch.ByteTensor(storage).to("cuda")
# obtain Tensor size of each rank
local_size = torch.tensor([tensor.numel()], device="cuda")
size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
dist.all_gather(size_list, local_size)
size_list = [int(size.item()) for size in size_list]
max_size = max(size_list)
# receiving Tensor from all ranks
# we pad the tensor because torch all_gather does not support
# gathering tensors of different shapes
tensor_list = []
for _ in size_list:
tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
if local_size != max_size:
padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
tensor = torch.cat((tensor, padding), dim=0)
dist.all_gather(tensor_list, tensor)
data_list = []
for size, tensor in zip(size_list, tensor_list):
buffer = tensor.cpu().numpy().tobytes()[:size]
data_list.append(pickle.loads(buffer))
return data_list
def reduce_dict(input_dict, average=True):
"""
Args:
input_dict (dict): all the values will be reduced
average (bool): whether to do average or sum
Reduce the values in the dictionary from all processes so that all processes
have the averaged results. Returns a dict with the same fields as
input_dict, after reduction.
"""
world_size = get_world_size()
if world_size < 2:
return input_dict
with torch.no_grad():
names = []
values = []
# sort the keys so that they are consistent across processes
for k in sorted(input_dict.keys()):
names.append(k)
values.append(input_dict[k])
values = torch.stack(values, dim=0)
dist.all_reduce(values)
if average:
values /= world_size
reduced_dict = {k: v for k, v in zip(names, values)}
return reduced_dict
class MetricLogger(object):
def __init__(self, delimiter="\t"):
self.meters = defaultdict(SmoothedValue)
self.delimiter = delimiter
def update(self, **kwargs):
for k, v in kwargs.items():
if isinstance(v, torch.Tensor):
v = v.item()
assert isinstance(v, (float, int))
self.meters[k].update(v)
def __getattr__(self, attr):
if attr in self.meters:
return self.meters[attr]
if attr in self.__dict__:
return self.__dict__[attr]
raise AttributeError("'{}' object has no attribute '{}'".format(
type(self).__name__, attr))
def __str__(self):
loss_str = []
for name, meter in self.meters.items():
loss_str.append(
"{}: {}".format(name, str(meter))
)
return self.delimiter.join(loss_str)
def synchronize_between_processes(self):
for meter in self.meters.values():
meter.synchronize_between_processes()
def add_meter(self, name, meter):
self.meters[name] = meter
def log_every(self, iterable, print_freq, header=None):
i = 0
if not header:
header = ''
start_time = time.time()
end = time.time()
iter_time = SmoothedValue(fmt='{avg:.4f}')
data_time = SmoothedValue(fmt='{avg:.4f}')
space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
if torch.cuda.is_available():
log_msg = self.delimiter.join([
header,
'[{0' + space_fmt + '}/{1}]',
'eta: {eta}',
'{meters}',
'time: {time}',
'data: {data}',
'max mem: {memory:.0f}'
])
else:
log_msg = self.delimiter.join([
header,
'[{0' + space_fmt + '}/{1}]',
'eta: {eta}',
'{meters}',
'time: {time}',
'data: {data}'
])
MB = 1024.0 * 1024.0
for obj in iterable:
data_time.update(time.time() - end)
yield obj
iter_time.update(time.time() - end)
if i % print_freq == 0 or i == len(iterable) - 1:
eta_seconds = iter_time.global_avg * (len(iterable) - i)
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
if torch.cuda.is_available():
print(log_msg.format(
i, len(iterable), eta=eta_string,
meters=str(self),
time=str(iter_time), data=str(data_time),
memory=torch.cuda.max_memory_allocated() / MB))
else:
print(log_msg.format(
i, len(iterable), eta=eta_string,
meters=str(self),
time=str(iter_time), data=str(data_time)))
i += 1
end = time.time()
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('{} Total time: {} ({:.4f} s / it)'.format(
header, total_time_str, total_time / len(iterable)))
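# Example usage (a minimal training-loop sketch; `data_loader`, `optimizer` and
# `train_one_step` are hypothetical names, not part of this module):
#   metric_logger = MetricLogger(delimiter="  ")
#   for batch in metric_logger.log_every(data_loader, print_freq=10, header='Epoch: [0]'):
#       loss = train_one_step(batch)
#       metric_logger.update(loss=loss, lr=optimizer.param_groups[0]["lr"])
#   metric_logger.synchronize_between_processes()
#   print("Averaged stats:", metric_logger)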
def get_sha():
cwd = os.path.dirname(os.path.abspath(__file__))
def _run(command):
return subprocess.check_output(command, cwd=cwd).decode('ascii').strip()
sha = 'N/A'
diff = "clean"
branch = 'N/A'
try:
sha = _run(['git', 'rev-parse', 'HEAD'])
subprocess.check_output(['git', 'diff'], cwd=cwd)
diff = _run(['git', 'diff-index', 'HEAD'])
diff = "has uncommited changes" if diff else "clean"
branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
except Exception:
pass
message = f"sha: {sha}, status: {diff}, branch: {branch}"
return message
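# Illustrative return value (placeholders, not real commit data):
#   "sha: 3f1a9c7..., status: clean, branch: main"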
def collate_fn(batch):
batch = list(zip(*batch))
batch[0] = nested_tensor_from_tensor_list(batch[0])
return tuple(batch)
def _max_by_axis(the_list):
# type: (List[List[int]]) -> List[int]
maxes = the_list[0]
for sublist in the_list[1:]:
for index, item in enumerate(sublist):
maxes[index] = max(maxes[index], item)
return maxes
class NestedTensor(object):
def __init__(self, tensors, mask: Optional[Tensor]):
self.tensors = tensors
self.mask = mask
def to(self, device):
# type: (Device) -> NestedTensor # noqa
cast_tensor = self.tensors.to(device)
mask = self.mask
if mask is not None:
assert mask is not None
cast_mask = mask.to(device)
else:
cast_mask = None
return NestedTensor(cast_tensor, cast_mask)
def decompose(self):
return self.tensors, self.mask
def __repr__(self):
return str(self.tensors)
def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
# TODO make this more general
if tensor_list[0].ndim == 3:
if torchvision._is_tracing():
# nested_tensor_from_tensor_list() does not export well to ONNX
# call _onnx_nested_tensor_from_tensor_list() instead
return _onnx_nested_tensor_from_tensor_list(tensor_list)
# TODO make it support different-sized images
max_size = _max_by_axis([list(img.shape) for img in tensor_list])
# min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
batch_shape = [len(tensor_list)] + max_size
b, c, h, w = batch_shape
dtype = tensor_list[0].dtype
device = tensor_list[0].device
tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
for img, pad_img, m in zip(tensor_list, tensor, mask):
pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
m[: img.shape[1], :img.shape[2]] = False
else:
raise ValueError('not supported')
return NestedTensor(tensor, mask)
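# Example (a minimal sketch with illustrative shapes): two RGB images of different
# sizes are zero-padded to a shared (C, H, W); the boolean mask is True over padding:
#   imgs = [torch.rand(3, 480, 640), torch.rand(3, 600, 400)]
#   nt = nested_tensor_from_tensor_list(imgs)
#   nt.tensors.shape  # torch.Size([2, 3, 600, 640])
#   nt.mask.shape     # torch.Size([2, 600, 640])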
# _onnx_nested_tensor_from_tensor_list() is an implementation of
# nested_tensor_from_tensor_list() that is supported by ONNX tracing.
@torch.jit.unused
def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
max_size = []
for i in range(tensor_list[0].dim()):
max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(torch.int64)
max_size.append(max_size_i)
max_size = tuple(max_size)
# work around for
# pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
# m[: img.shape[1], :img.shape[2]] = False
# which is not yet supported in onnx
padded_imgs = []
padded_masks = []
for img in tensor_list:
padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
padded_imgs.append(padded_img)
m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
padded_masks.append(padded_mask.to(torch.bool))
tensor = torch.stack(padded_imgs)
mask = torch.stack(padded_masks)
return NestedTensor(tensor, mask=mask)
def setup_for_distributed(is_master):
"""
This function disables printing when not in master process
"""
import builtins as __builtin__
builtin_print = __builtin__.print
def print(*args, **kwargs):
force = kwargs.pop('force', False)
if is_master or force:
builtin_print(*args, **kwargs)
__builtin__.print = print
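# Example (illustrative): after setup_for_distributed(is_master=False), a plain
# print("...") is suppressed on this rank, while print("...", force=True) still prints.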
def is_dist_avail_and_initialized():
if not dist.is_available():
return False
if not dist.is_initialized():
return False
return True
def get_world_size():
if not is_dist_avail_and_initialized():
return 1
return dist.get_world_size()
def get_rank():
if not is_dist_avail_and_initialized():
return 0
return dist.get_rank()
def is_main_process():
return get_rank() == 0
def save_on_master(*args, **kwargs):
if is_main_process():
torch.save(*args, **kwargs)
def init_distributed_mode(args):
if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
args.rank = int(os.environ["RANK"])
args.world_size = int(os.environ['WORLD_SIZE'])
args.gpu = int(os.environ['LOCAL_RANK'])
elif 'SLURM_PROCID' in os.environ:
args.rank = int(os.environ['SLURM_PROCID'])
args.gpu = args.rank % torch.cuda.device_count()
else:
print('Not using distributed mode')
args.distributed = False
return
args.distributed = True
torch.cuda.set_device(args.gpu)
args.dist_backend = 'nccl'
print('| distributed init (rank {}): {}'.format(
args.rank, args.dist_url), flush=True)
torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
world_size=args.world_size, rank=args.rank)
torch.distributed.barrier()
setup_for_distributed(args.rank == 0)
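# Example usage (a minimal sketch; assumes `args` also carries a `dist_url` field,
# e.g. 'env://'). Launching with `torchrun --nproc_per_node=4 train.py` sets RANK,
# WORLD_SIZE and LOCAL_RANK in the environment, then inside the script:
#   init_distributed_mode(args)   # sets args.gpu / args.rank and calls init_process_group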
@torch.no_grad()
def accuracy(output, target, topk=(1,)):
"""Computes the precision@k for the specified values of k"""
if target.numel() == 0:
return [torch.zeros([], device=output.device)]
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = []
for k in topk:
correct_k = correct[:k].view(-1).float().sum(0)
res.append(correct_k.mul_(100.0 / batch_size))
return res
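# Worked example (illustrative values): 3-class logits for a batch of 2 samples:
#   output = torch.tensor([[0.1, 0.7, 0.2], [0.8, 0.05, 0.15]])
#   target = torch.tensor([1, 2])
#   accuracy(output, target, topk=(1, 2))  # -> [tensor(50.), tensor(100.)]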
def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
# type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
"""
Equivalent to nn.functional.interpolate, but with support for empty batch sizes.
This will eventually be supported natively by PyTorch, and this
function can go away.
"""
if version.parse(torchvision.__version__) < version.parse('0.7'):
if input.numel() > 0:
return torch.nn.functional.interpolate(
input, size, scale_factor, mode, align_corners
)
output_shape = _output_size(2, input, size, scale_factor)
output_shape = list(input.shape[:-2]) + list(output_shape)
return _new_empty_tensor(input, output_shape)
else:
return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners)

View File

@@ -0,0 +1,107 @@
"""
Plotting utilities to visualize training logs.
"""
import torch
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from pathlib import Path, PurePath
def plot_logs(logs, fields=('class_error', 'loss_bbox_unscaled', 'mAP'), ewm_col=0, log_name='log.txt'):
'''
Function to plot specific fields from training log(s). Plots both training and test results.
:: Inputs - logs = list containing Path objects, each pointing to individual dir with a log file
- fields = which results to plot from each log file - plots both training and test for each field.
- ewm_col = optional, which column to use as the exponential weighted smoothing of the plots
- log_name = optional, name of log file if different than default 'log.txt'.
:: Outputs - matplotlib plots of results in fields, color coded for each log file.
- solid lines are training results, dashed lines are test results.
'''
func_name = "plot_utils.py::plot_logs"
# verify logs is a list of Paths (list[Paths]) or single Pathlib object Path,
# convert single Path to list to avoid 'not iterable' error
if not isinstance(logs, list):
if isinstance(logs, PurePath):
logs = [logs]
print(f"{func_name} info: logs param expects a list argument, converted to list[Path].")
else:
raise ValueError(f"{func_name} - invalid argument for logs parameter.\n \
Expect list[Path] or single Path obj, received {type(logs)}")
# Quality checks - verify valid dir(s), that every item in list is Path object, and that log_name exists in each dir
for i, dir in enumerate(logs):
if not isinstance(dir, PurePath):
raise ValueError(f"{func_name} - non-Path object in logs argument of {type(dir)}: \n{dir}")
if not dir.exists():
raise ValueError(f"{func_name} - invalid directory in logs argument:\n{dir}")
# verify log_name exists
fn = Path(dir / log_name)
if not fn.exists():
print(f"-> missing {log_name}. Have you gotten to Epoch 1 in training?")
print(f"--> full path of missing log file: {fn}")
return
# load log file(s) and plot
dfs = [pd.read_json(Path(p) / log_name, lines=True) for p in logs]
fig, axs = plt.subplots(ncols=len(fields), figsize=(16, 5))
for df, color in zip(dfs, sns.color_palette(n_colors=len(logs))):
for j, field in enumerate(fields):
if field == 'mAP':
coco_eval = pd.DataFrame(
np.stack(df.test_coco_eval_bbox.dropna().values)[:, 1]
).ewm(com=ewm_col).mean()
axs[j].plot(coco_eval, c=color)
else:
df.interpolate().ewm(com=ewm_col).mean().plot(
y=[f'train_{field}', f'test_{field}'],
ax=axs[j],
color=[color] * 2,
style=['-', '--']
)
for ax, field in zip(axs, fields):
ax.legend([Path(p).name for p in logs])
ax.set_title(field)
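# Example usage (a minimal sketch; directory names are illustrative and each log.txt
# is assumed to contain train_/test_ columns for the requested fields):
#   log_dirs = [Path('output/exp1'), Path('output/exp2')]
#   plot_logs(log_dirs, fields=('loss', 'loss_ce'), ewm_col=1)
#   plt.show()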
def plot_precision_recall(files, naming_scheme='iter'):
if naming_scheme == 'exp_id':
# name becomes exp_id
names = [f.parts[-3] for f in files]
elif naming_scheme == 'iter':
names = [f.stem for f in files]
else:
raise ValueError(f'not supported {naming_scheme}')
fig, axs = plt.subplots(ncols=2, figsize=(16, 5))
for f, color, name in zip(files, sns.color_palette("Blues", n_colors=len(files)), names):
data = torch.load(f)
# precision is n_iou, n_points, n_cat, n_area, max_det
precision = data['precision']
recall = data['params'].recThrs
scores = data['scores']
# take precision for all classes, all areas and 100 detections
precision = precision[0, :, :, 0, -1].mean(1)
scores = scores[0, :, :, 0, -1].mean(1)
prec = precision.mean()
rec = data['recall'][0, :, 0, -1].mean()
print(f'{naming_scheme} {name}: mAP@50={prec * 100: 05.1f}, ' +
f'score={scores.mean():0.3f}, ' +
f'f1={2 * prec * rec / (prec + rec + 1e-8):0.3f}'
)
axs[0].plot(recall, precision, c=color)
axs[1].plot(recall, scores, c=color)
axs[0].set_title('Precision / Recall')
axs[0].legend(names)
axs[1].set_title('Scores / Recall')
axs[1].legend(names)
return fig, axs
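# Example usage (a minimal sketch; assumes COCO-style eval dumps with 'precision',
# 'recall', 'scores' and 'params' keys saved as .pth files, paths are illustrative):
#   files = sorted(Path('output/eval').glob('*.pth'))
#   fig, axs = plot_precision_recall(files, naming_scheme='iter')
#   fig.savefig('precision_recall.png')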

View File

@@ -9,6 +9,7 @@ temporal_agg: false
# policy_class: "ACT"
# backbone: 'resnet18'
policy_class: "ACTTV"
# policy_class: "DDT"
backbone: 'dino_v2'
seed: 0
@@ -39,7 +40,7 @@ camera_names: [] # leave empty here by default
xml_dir: # leave empty here by default
# transformer settings
batch_size: 15
batch_size: 32
state_dim: 16
action_dim: 16
lr_backbone: 0.00001
@@ -58,3 +59,8 @@ chunk_size: 10
hidden_dim: 512
dim_feedforward: 3200
# DDT 特有参数
num_blocks: 12 # Transformer blocks 数量
mlp_ratio: 4.0 # MLP 维度比例
num_inference_steps: 10 # 扩散推理步数

View File

@@ -2,6 +2,7 @@ import os
import torch
from roboimi.utils.utils import load_data, set_seed
from roboimi.detr.policy import ACTPolicy, CNNMLPPolicy, ACTTVPolicy
from roboimi.ddt.policy import DDTPolicy
class ModelInterface:
def __init__(self, config):
@@ -65,6 +66,24 @@ class ModelInterface:
'num_queries': 1,
'camera_names': self.config['camera_names'],
}
elif self.config['policy_class'] == 'DDT':
self.config['policy_config'] = {
'lr': self.config['lr'],
'lr_backbone': self.config['lr_backbone'],
'backbone': self.config.get('backbone', 'dino_v2'),
'num_queries': self.config['chunk_size'],
'hidden_dim': self.config['hidden_dim'],
'nheads': self.config['nheads'],
'enc_layers': self.config['enc_layers'],
'state_dim': self.config.get('state_dim', 16),
'action_dim': self.config.get('action_dim', 16),
'camera_names': self.config['camera_names'],
'qpos_noise_std': self.config.get('qpos_noise_std', 0),
# DDT 特有参数
'num_blocks': self.config.get('num_blocks', 12),
'mlp_ratio': self.config.get('mlp_ratio', 4.0),
'num_inference_steps': self.config.get('num_inference_steps', 10),
}
else:
raise NotImplementedError
@@ -75,6 +94,8 @@ class ModelInterface:
return ACTTVPolicy(self.config['policy_config'])
elif self.config['policy_class'] == 'CNNMLP':
return CNNMLPPolicy(self.config['policy_config'])
elif self.config['policy_class'] == 'DDT':
return DDTPolicy(self.config['policy_config'])
else:
raise NotImplementedError