diff --git a/roboimi/ddt/main.py b/roboimi/ddt/main.py new file mode 100644 index 0000000..91ad9e9 --- /dev/null +++ b/roboimi/ddt/main.py @@ -0,0 +1,112 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +DDT 模型构建和优化器配置。 +""" +import argparse +from pathlib import Path + +import numpy as np +import torch +from .models import build_DDT_model + + +def get_args_parser(): + """获取 DDT 模型的参数解析器。""" + parser = argparse.ArgumentParser('DDT model configuration', add_help=False) + + # 学习率 + parser.add_argument('--lr', default=1e-4, type=float) + parser.add_argument('--lr_backbone', default=1e-5, type=float) + parser.add_argument('--batch_size', default=2, type=int) + parser.add_argument('--weight_decay', default=1e-4, type=float) + parser.add_argument('--epochs', default=300, type=int) + parser.add_argument('--lr_drop', default=200, type=int) + parser.add_argument('--clip_max_norm', default=0.1, type=float, + help='gradient clipping max norm') + parser.add_argument('--qpos_noise_std', action='store', default=0, type=float) + + # Backbone 参数 + parser.add_argument('--backbone', default='resnet18', type=str, + help="Name of the convolutional backbone to use") + parser.add_argument('--dilation', action='store_true', + help="If true, replace stride with dilation in the last conv block") + parser.add_argument('--position_embedding', default='sine', type=str, + choices=('sine', 'learned'), + help="Type of positional embedding") + parser.add_argument('--camera_names', default=[], type=list, + help="A list of camera names") + + # Transformer 参数 + parser.add_argument('--enc_layers', default=4, type=int, + help="Number of encoding layers in the transformer") + parser.add_argument('--dec_layers', default=6, type=int, + help="Number of decoding layers in the transformer") + parser.add_argument('--dim_feedforward', default=2048, type=int, + help="Intermediate size of the feedforward layers") + parser.add_argument('--hidden_dim', default=512, type=int, + help="Size of the embeddings (dimension of the transformer)") + parser.add_argument('--dropout', default=0.1, type=float, + help="Dropout applied in the transformer") + parser.add_argument('--nheads', default=8, type=int, + help="Number of attention heads") + parser.add_argument('--num_queries', default=100, type=int, + help="Number of query slots (action horizon)") + parser.add_argument('--pre_norm', action='store_true') + parser.add_argument('--state_dim', default=14, type=int) + parser.add_argument('--action_dim', default=14, type=int) + + # DDT 特有参数 + parser.add_argument('--num_blocks', default=12, type=int, + help="Total number of transformer blocks in DDT") + parser.add_argument('--mlp_ratio', default=4.0, type=float, + help="MLP hidden dimension ratio") + parser.add_argument('--num_inference_steps', default=10, type=int, + help="Number of diffusion inference steps") + + # Segmentation (未使用) + parser.add_argument('--masks', action='store_true', + help="Train segmentation head if provided") + + return parser + + +def build_DDT_model_and_optimizer(args_override): + """构建 DDT 模型和优化器。 + + Args: + args_override: 覆盖默认参数的字典 + + Returns: + model: DDT 模型 + optimizer: AdamW 优化器 + """ + parser = argparse.ArgumentParser('DDT training script', parents=[get_args_parser()]) + args = parser.parse_args([]) # 空列表避免命令行参数干扰 + + # 应用参数覆盖 + for k, v in args_override.items(): + setattr(args, k, v) + + # 构建模型 + model = build_DDT_model(args) + model.cuda() + + # 配置优化器(backbone 使用较小学习率) + param_dicts = [ + { + "params": [p for n, p in model.named_parameters() + if "backbone" not in n and p.requires_grad] + }, + { + "params": [p for n, p in model.named_parameters() + if "backbone" in n and p.requires_grad], + "lr": args.lr_backbone, + }, + ] + optimizer = torch.optim.AdamW( + param_dicts, + lr=args.lr, + weight_decay=args.weight_decay + ) + + return model, optimizer diff --git a/roboimi/ddt/models/__init__.py b/roboimi/ddt/models/__init__.py new file mode 100644 index 0000000..fbc513e --- /dev/null +++ b/roboimi/ddt/models/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +from .model import build as build_ddt +from .model import build_ddt + +def build_DDT_model(args): + """构建 DDT 模型的统一入口。""" + return build_ddt(args) diff --git a/roboimi/ddt/models/backbone.py b/roboimi/ddt/models/backbone.py new file mode 100644 index 0000000..759bfb5 --- /dev/null +++ b/roboimi/ddt/models/backbone.py @@ -0,0 +1,168 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +Backbone modules. +""" +from collections import OrderedDict + +import torch +import torch.nn.functional as F +import torchvision +from torch import nn +from torchvision.models._utils import IntermediateLayerGetter +from typing import Dict, List + +from util.misc import NestedTensor, is_main_process + +from .position_encoding import build_position_encoding + +class FrozenBatchNorm2d(torch.nn.Module): + """ + BatchNorm2d where the batch statistics and the affine parameters are fixed. + + Copy-paste from torchvision.misc.ops with added eps before rqsrt, + without which any other policy_models than torchvision.policy_models.resnet[18,34,50,101] + produce nans. + """ + + def __init__(self, n): + super(FrozenBatchNorm2d, self).__init__() + self.register_buffer("weight", torch.ones(n)) + self.register_buffer("bias", torch.zeros(n)) + self.register_buffer("running_mean", torch.zeros(n)) + self.register_buffer("running_var", torch.ones(n)) + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + num_batches_tracked_key = prefix + 'num_batches_tracked' + if num_batches_tracked_key in state_dict: + del state_dict[num_batches_tracked_key] + + super(FrozenBatchNorm2d, self)._load_from_state_dict( + state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs) + + def forward(self, x): + # move reshapes to the beginning + # to make it fuser-friendly + w = self.weight.reshape(1, -1, 1, 1) + b = self.bias.reshape(1, -1, 1, 1) + rv = self.running_var.reshape(1, -1, 1, 1) + rm = self.running_mean.reshape(1, -1, 1, 1) + eps = 1e-5 + scale = w * (rv + eps).rsqrt() + bias = b - rm * scale + return x * scale + bias + + +class BackboneBase(nn.Module): + + def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool): + super().__init__() + # for name, parameter in backbone.named_parameters(): # only train later layers # TODO do we want this? + # if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name: + # parameter.requires_grad_(False) + if return_interm_layers: + return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"} + else: + return_layers = {'layer4': "0"} + self.body = IntermediateLayerGetter(backbone, return_layers=return_layers) + self.num_channels = num_channels + + def forward(self, tensor): + xs = self.body(tensor) + return xs + # out: Dict[str, NestedTensor] = {} + # for name, x in xs.items(): + # m = tensor_list.mask + # assert m is not None + # mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] + # out[name] = NestedTensor(x, mask) + # return out + + +class Backbone(BackboneBase): + """ResNet backbone with frozen BatchNorm.""" + def __init__(self, name: str, + train_backbone: bool, + return_interm_layers: bool, + dilation: bool): + backbone = getattr(torchvision.models, name)( + replace_stride_with_dilation=[False, False, dilation], + pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d) # pretrained # TODO do we want frozen batch_norm?? + num_channels = 512 if name in ('resnet18', 'resnet34') else 2048 + super().__init__(backbone, train_backbone, num_channels, return_interm_layers) + + +# class DINOv2BackBone(nn.Module): +# def __init__(self) -> None: +# super().__init__() +# self.body = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14') +# self.body.eval() +# self.num_channels = 384 + +# @torch.no_grad() +# def forward(self, tensor): +# xs = self.body.forward_features(tensor)["x_norm_patchtokens"] +# od = OrderedDict() +# od["0"] = xs.reshape(xs.shape[0], 22, 16, 384).permute(0, 3, 2, 1) +# return od + +class DINOv2BackBone(nn.Module): + def __init__(self, return_interm_layers: bool = False) -> None: + super().__init__() + self.body = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14') + self.body.eval() + self.num_channels = 384 + self.return_interm_layers = return_interm_layers + + @torch.no_grad() + def forward(self, tensor): + features = self.body.forward_features(tensor) + + if self.return_interm_layers: + + layer1 = features["x_norm_patchtokens"] + layer2 = features["x_norm_patchtokens"] + layer3 = features["x_norm_patchtokens"] + layer4 = features["x_norm_patchtokens"] + + od = OrderedDict() + od["0"] = layer1.reshape(layer1.shape[0], 22, 16, 384).permute(0, 3, 2, 1) + od["1"] = layer2.reshape(layer2.shape[0], 22, 16, 384).permute(0, 3, 2, 1) + od["2"] = layer3.reshape(layer3.shape[0], 22, 16, 384).permute(0, 3, 2, 1) + od["3"] = layer4.reshape(layer4.shape[0], 22, 16, 384).permute(0, 3, 2, 1) + return od + else: + xs = features["x_norm_patchtokens"] + od = OrderedDict() + od["0"] = xs.reshape(xs.shape[0], 22, 16, 384).permute(0, 3, 2, 1) + return od + +class Joiner(nn.Sequential): + def __init__(self, backbone, position_embedding): + super().__init__(backbone, position_embedding) + + def forward(self, tensor_list: NestedTensor): + xs = self[0](tensor_list) + out: List[NestedTensor] = [] + pos = [] + for name, x in xs.items(): + out.append(x) + # position encoding + pos.append(self[1](x).to(x.dtype)) + + return out, pos + + +def build_backbone(args): + position_embedding = build_position_encoding(args) + train_backbone = args.lr_backbone > 0 + return_interm_layers = args.masks + if args.backbone == 'dino_v2': + backbone = DINOv2BackBone() + else: + assert args.backbone in ['resnet18', 'resnet34'] + backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation) + model = Joiner(backbone, position_embedding) + model.num_channels = backbone.num_channels + return model diff --git a/roboimi/ddt/models/ddt.py b/roboimi/ddt/models/ddt.py new file mode 100644 index 0000000..8efdd42 --- /dev/null +++ b/roboimi/ddt/models/ddt.py @@ -0,0 +1,631 @@ +""" +动作序列扩散 Transformer (Action Decoupled Diffusion Transformer) + +基于 DDT 架构修改,用于生成机器人动作序列。 +主要改动: + 1. 2D RoPE → 1D RoPE (适配时序数据) + 2. LabelEmbedder → ObservationEncoder (观测条件) + 3. 去除 patchify/unpatchify (动作序列已是 1D) +""" + +import math +from typing import Tuple, Optional + +import torch +import torch.nn as nn +from torch.nn.functional import scaled_dot_product_attention + + +# ============================================================================ +# 通用工具函数 +# ============================================================================ + +def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: + """AdaLN 调制函数。 + + Args: + x: 输入张量。 + shift: 偏移量。 + scale: 缩放量。 + + Returns: + 调制后的张量: x * (1 + scale) + shift + """ + return x * (1 + scale) + shift + + +# ============================================================================ +# 1D 旋转位置编码 (RoPE) +# ============================================================================ + +def precompute_freqs_cis_1d(dim: int, seq_len: int, theta: float = 10000.0) -> torch.Tensor: + """预计算 1D 旋转位置编码的复数频率。 + + 用于时序数据(如动作序列)的位置编码,相比 2D RoPE 更简单高效。 + + Args: + dim: 每个注意力头的维度 (head_dim)。 + seq_len: 序列长度。 + theta: RoPE 的基础频率,默认 10000.0。 + + Returns: + 复数频率张量,形状为 (seq_len, dim//2)。 + """ + # 计算频率: 1 / (theta^(2i/dim)) + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim)) # [dim//2] + # 位置索引 + t = torch.arange(seq_len).float() # [seq_len] + # 外积得到位置-频率矩阵 + freqs = torch.outer(t, freqs) # [seq_len, dim//2] + # 转换为复数形式 (极坐标) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # [seq_len, dim//2] + return freqs_cis + + +def apply_rotary_emb_1d( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """应用 1D 旋转位置编码到 Query 和 Key。 + + Args: + xq: Query 张量,形状为 (B, N, H, Hc)。 + xk: Key 张量,形状为 (B, N, H, Hc)。 + freqs_cis: 预计算的复数频率,形状为 (N, Hc//2)。 + + Returns: + 应用 RoPE 后的 (xq, xk),形状不变。 + """ + # 调整 freqs_cis 形状以便广播: [1, N, 1, Hc//2] + freqs_cis = freqs_cis[None, :, None, :] + + # 将实数张量视为复数: [B, N, H, Hc] -> [B, N, H, Hc//2] (复数) + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + + # 复数乘法实现旋转 + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # [B, N, H, Hc] + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + + return xq_out.type_as(xq), xk_out.type_as(xk) + + +# ============================================================================ +# 基础组件 +# ============================================================================ + +class Embed(nn.Module): + """线性嵌入层,将输入投影到隐藏空间。""" + + def __init__( + self, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer: Optional[nn.Module] = None, + bias: bool = True, + ): + """初始化 Embed。 + + Args: + in_chans: 输入通道数/维度。 + embed_dim: 输出嵌入维度。 + norm_layer: 可选的归一化层。 + bias: 是否使用偏置。 + """ + super().__init__() + self.in_chans = in_chans + self.embed_dim = embed_dim + self.proj = nn.Linear(in_chans, embed_dim, bias=bias) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + x = self.norm(x) + return x + + +class TimestepEmbedder(nn.Module): + """扩散时间步嵌入器。 + + 使用正弦位置编码 + MLP 将标量时间步映射到高维向量。 + """ + + def __init__(self, hidden_size: int, frequency_embedding_size: int = 256): + """初始化 TimestepEmbedder。 + + Args: + hidden_size: 输出嵌入维度。 + frequency_embedding_size: 正弦编码的维度。 + """ + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t: torch.Tensor, dim: int, max_period: float = 10.0) -> torch.Tensor: + """生成正弦时间步嵌入。 + + Args: + t: 时间步张量,形状为 (B,)。 + dim: 嵌入维度。 + max_period: 最大周期。 + + Returns: + 时间步嵌入,形状为 (B, dim)。 + """ + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half + ) + args = t[..., None].float() * freqs[None, ...] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t: torch.Tensor) -> torch.Tensor: + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb + + +class ObservationEncoder(nn.Module): + """观测状态编码器。 + + 将机器人的观测向量(如关节位置、末端位姿、图像特征等) + 编码为条件向量,用于条件扩散生成。 + + Attributes: + encoder: 两层 MLP 编码器。 + + Example: + >>> encoder = ObservationEncoder(obs_dim=128, hidden_size=512) + >>> obs = torch.randn(2, 128) + >>> cond = encoder(obs) # [2, 512] + """ + + def __init__(self, obs_dim: int, hidden_size: int): + """初始化 ObservationEncoder。 + + Args: + obs_dim: 观测向量的维度。 + hidden_size: 输出条件向量的维度。 + """ + super().__init__() + self.encoder = nn.Sequential( + nn.Linear(obs_dim, hidden_size), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size), + ) + + def forward(self, obs: torch.Tensor) -> torch.Tensor: + """前向传播。 + + Args: + obs: 观测向量,形状为 (B, obs_dim)。 + + Returns: + 条件向量,形状为 (B, hidden_size)。 + """ + return self.encoder(obs) + + +class FinalLayer(nn.Module): + """最终输出层,使用 AdaLN 调制后输出预测结果。""" + + def __init__(self, hidden_size: int, out_channels: int): + """初始化 FinalLayer。 + + Args: + hidden_size: 输入隐藏维度。 + out_channels: 输出通道数/维度。 + """ + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, out_channels, bias=True) + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, 2 * hidden_size, bias=True) + ) + + def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor: + """前向传播。 + + Args: + x: 输入张量,形状为 (B, N, hidden_size)。 + c: 条件张量,形状为 (B, N, hidden_size) 或 (B, 1, hidden_size)。 + + Returns: + 输出张量,形状为 (B, N, out_channels)。 + """ + shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + + +# ============================================================================ +# 归一化和前馈网络 +# ============================================================================ + +class RMSNorm(nn.Module): + """Root Mean Square Layer Normalization (RMS 归一化层)。 + + RMSNorm 是 LayerNorm 的简化版本,去掉了均值中心化操作,只保留缩放。 + 相比 LayerNorm 计算更快,效果相当,被广泛用于 LLaMA、Mistral 等大模型。 + + 数学公式: + RMSNorm(x) = x / sqrt(mean(x^2) + eps) * weight + """ + + def __init__(self, hidden_size: int, eps: float = 1e-6): + """初始化 RMSNorm。 + + Args: + hidden_size: 输入特征的维度。 + eps: 防止除零的小常数。 + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +class FeedForward(nn.Module): + """SwiGLU 前馈网络 (Feed-Forward Network)。 + + 使用 SwiGLU 门控激活函数的前馈网络,来自 LLaMA 架构。 + + 结构: + output = W2(SiLU(W1(x)) * W3(x)) + """ + + def __init__(self, dim: int, hidden_dim: int): + """初始化 FeedForward。 + + Args: + dim: 输入和输出的特征维度。 + hidden_dim: 隐藏层维度(实际使用 2/3 * hidden_dim)。 + """ + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x)) + return x + + +# ============================================================================ +# 注意力机制 +# ============================================================================ + +class RAttention(nn.Module): + """带有旋转位置编码的多头自注意力 (Rotary Attention)。 + + 集成了以下技术: + - 1D RoPE: 通过复数旋转编码时序位置信息 + - QK-Norm: 对 Query 和 Key 进行归一化,稳定训练 + - Flash Attention: 使用 scaled_dot_product_attention + """ + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_norm: bool = True, + attn_drop: float = 0., + proj_drop: float = 0., + norm_layer: nn.Module = RMSNorm, + ) -> None: + """初始化 RAttention。 + + Args: + dim: 输入特征维度,必须能被 num_heads 整除。 + num_heads: 注意力头数。 + qkv_bias: QKV 投影是否使用偏置。 + qk_norm: 是否对 Q, K 进行归一化。 + attn_drop: 注意力权重的 dropout 率。 + proj_drop: 输出投影的 dropout 率。 + norm_layer: 归一化层类型。 + """ + super().__init__() + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward( + self, + x: torch.Tensor, + pos: torch.Tensor, + mask: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """前向传播。 + + Args: + x: 输入张量,形状为 (B, N, C)。 + pos: 1D RoPE 位置编码,形状为 (N, head_dim//2)。 + mask: 可选的注意力掩码。 + + Returns: + 输出张量,形状为 (B, N, C)。 + """ + B, N, C = x.shape + + # QKV 投影 + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # [B, N, H, Hc] + + # QK-Norm + q = self.q_norm(q) + k = self.k_norm(k) + + # 应用 1D RoPE + q, k = apply_rotary_emb_1d(q, k, freqs_cis=pos) + + # 调整维度: [B, N, H, Hc] -> [B, H, N, Hc] + q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) + k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() + v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() + + # Scaled Dot-Product Attention + x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) + + # 输出投影 + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +# ============================================================================ +# Transformer Block +# ============================================================================ + +class ActionDDTBlock(nn.Module): + """动作 DDT Transformer Block。 + + 结构: Pre-Norm + AdaLN + Attention + FFN + """ + + def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float = 4.0): + """初始化 ActionDDTBlock。 + + Args: + hidden_size: 隐藏层维度。 + num_heads: 注意力头数。 + mlp_ratio: FFN 隐藏层倍率。 + """ + super().__init__() + self.norm1 = RMSNorm(hidden_size, eps=1e-6) + self.attn = RAttention(hidden_size, num_heads=num_heads, qkv_bias=False) + self.norm2 = RMSNorm(hidden_size, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.mlp = FeedForward(hidden_size, mlp_hidden_dim) + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, 6 * hidden_size, bias=True) + ) + + def forward( + self, + x: torch.Tensor, + c: torch.Tensor, + pos: torch.Tensor, + mask: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """前向传播。 + + Args: + x: 输入张量,形状为 (B, N, hidden_size)。 + c: 条件张量,形状为 (B, 1, hidden_size) 或 (B, N, hidden_size)。 + pos: 位置编码。 + mask: 可选的注意力掩码。 + + Returns: + 输出张量,形状为 (B, N, hidden_size)。 + """ + # AdaLN 调制参数 + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = \ + self.adaLN_modulation(c).chunk(6, dim=-1) + + # Attention 分支 + x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask) + # FFN 分支 + x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + + return x + + +# ============================================================================ +# 主模型: ActionDDT +# ============================================================================ + +class ActionDDT(nn.Module): + """动作序列解耦扩散 Transformer (Action Decoupled Diffusion Transformer)。 + + 基于 DDT 架构,专为机器人动作序列生成设计。 + 将模型解耦为编码器和解码器两部分,编码器状态可缓存以加速推理。 + + 架构: + - 编码器: 前 num_encoder_blocks 个 block,生成状态 s + - 解码器: 剩余 block,使用状态 s 对动作序列 x 去噪 + - 使用 1D RoPE 进行时序位置编码 + - 使用 AdaLN 注入时间步和观测条件 + + Args: + action_dim: 动作向量维度(如 7-DoF 机械臂为 7)。 + obs_dim: 观测向量维度。 + action_horizon: 预测的动作序列长度。 + hidden_size: Transformer 隐藏层维度。 + num_blocks: Transformer block 总数。 + num_encoder_blocks: 编码器 block 数量。 + num_heads: 注意力头数。 + mlp_ratio: FFN 隐藏层倍率。 + + 输入: + x (Tensor): 带噪声的动作序列,形状为 (B, T, action_dim)。 + t (Tensor): 扩散时间步,形状为 (B,),取值范围 [0, 1]。 + obs (Tensor): 观测条件,形状为 (B, obs_dim)。 + s (Tensor, optional): 缓存的编码器状态。 + + 输出: + x (Tensor): 预测的速度场/噪声,形状为 (B, T, action_dim)。 + s (Tensor): 编码器状态,可缓存复用。 + + Example: + >>> model = ActionDDT(action_dim=7, obs_dim=128, action_horizon=16) + >>> x = torch.randn(2, 16, 7) # 带噪声的动作序列 + >>> t = torch.rand(2) # 随机时间步 + >>> obs = torch.randn(2, 128) # 观测条件 + >>> out, state = model(x, t, obs) + >>> out.shape + torch.Size([2, 16, 7]) + """ + + def __init__( + self, + action_dim: int = 7, + obs_dim: int = 128, + action_horizon: int = 16, + hidden_size: int = 512, + num_blocks: int = 12, + num_encoder_blocks: int = 4, + num_heads: int = 8, + mlp_ratio: float = 4.0, + ): + super().__init__() + + # 保存配置 + self.action_dim = action_dim + self.obs_dim = obs_dim + self.action_horizon = action_horizon + self.hidden_size = hidden_size + self.num_blocks = num_blocks + self.num_encoder_blocks = num_encoder_blocks + self.num_heads = num_heads + + # 动作嵌入层 + self.x_embedder = Embed(action_dim, hidden_size, bias=True) + self.s_embedder = Embed(action_dim, hidden_size, bias=True) + + # 条件嵌入 + self.t_embedder = TimestepEmbedder(hidden_size) + self.obs_encoder = ObservationEncoder(obs_dim, hidden_size) + + # 输出层 + self.final_layer = FinalLayer(hidden_size, action_dim) + + # Transformer blocks + self.blocks = nn.ModuleList([ + ActionDDTBlock(hidden_size, num_heads, mlp_ratio) + for _ in range(num_blocks) + ]) + + # 预计算 1D 位置编码 + pos = precompute_freqs_cis_1d(hidden_size // num_heads, action_horizon) + self.register_buffer('pos', pos) + + # 初始化权重 + self.initialize_weights() + + def initialize_weights(self): + """初始化模型权重。""" + # 嵌入层使用 Xavier 初始化 + for embedder in [self.x_embedder, self.s_embedder]: + w = embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(embedder.proj.bias, 0) + + # 时间步嵌入 MLP + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # 观测编码器 + for m in self.obs_encoder.encoder: + if isinstance(m, nn.Linear): + nn.init.normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + # 输出层零初始化 (AdaLN-Zero) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + def forward( + self, + x: torch.Tensor, + t: torch.Tensor, + obs: torch.Tensor, + s: Optional[torch.Tensor] = None, + mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """前向传播。 + + Args: + x: 带噪声的动作序列 [B, T, action_dim] + t: 扩散时间步 [B] 或 [B, 1],取值范围 [0, 1] + obs: 观测条件 [B, obs_dim] + s: 可选的编码器状态缓存 [B, T, hidden_size] + mask: 可选的注意力掩码 + + Returns: + x: 预测的速度场/噪声 [B, T, action_dim] + s: 编码器状态 [B, T, hidden_size],可缓存复用 + """ + B, T, _ = x.shape + + # 1. 时间步嵌入: [B] -> [B, 1, hidden_size] + t_emb = self.t_embedder(t.view(-1)).view(B, 1, self.hidden_size) + + # 2. 观测条件嵌入: [B, obs_dim] -> [B, 1, hidden_size] + obs_emb = self.obs_encoder(obs).view(B, 1, self.hidden_size) + + # 3. 融合条件: c = SiLU(t + obs) + c = nn.functional.silu(t_emb + obs_emb) + + # 4. 编码器部分: 生成状态 s + if s is None: + # 状态嵌入: [B, T, action_dim] -> [B, T, hidden_size] + s = self.s_embedder(x) + # 通过编码器 blocks + for i in range(self.num_encoder_blocks): + s = self.blocks[i](s, c, self.pos, mask) + # 融合时间信息 + s = nn.functional.silu(t_emb + s) + + # 5. 解码器部分: 去噪 + # 输入嵌入: [B, T, action_dim] -> [B, T, hidden_size] + x = self.x_embedder(x) + # 通过解码器 blocks,以 s 作为条件 + for i in range(self.num_encoder_blocks, self.num_blocks): + x = self.blocks[i](x, s, self.pos, None) + + # 6. 最终层: [B, T, hidden_size] -> [B, T, action_dim] + x = self.final_layer(x, s) + + return x, s \ No newline at end of file diff --git a/roboimi/ddt/models/model.py b/roboimi/ddt/models/model.py new file mode 100644 index 0000000..5176c38 --- /dev/null +++ b/roboimi/ddt/models/model.py @@ -0,0 +1,304 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +DDT model and criterion classes. + +核心组装文件,将 Backbone、Transformer、Diffusion 组件组装为完整模型。 +""" +import torch +import torch.nn.functional as F +from torch import nn +import numpy as np + +from .backbone import build_backbone + + +class SpatialSoftmax(nn.Module): + """Spatial Softmax 层,将特征图转换为关键点坐标。 + + 来自 Diffusion Policy,保留空间位置信息。 + 对每个通道计算软注意力加权的期望坐标。 + + Args: + num_kp: 关键点数量(等于输入通道数) + temperature: Softmax 温度参数(可学习) + learnable_temperature: 是否学习温度参数 + + 输入: [B, C, H, W] + 输出: [B, C * 2] - 每个通道输出 (x, y) 坐标 + """ + + def __init__(self, num_kp: int = None, temperature: float = 1.0, learnable_temperature: bool = True): + super().__init__() + self.num_kp = num_kp + if learnable_temperature: + self.temperature = nn.Parameter(torch.ones(1) * temperature) + else: + self.register_buffer('temperature', torch.ones(1) * temperature) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, C, H, W = x.shape + + # 生成归一化坐标网格 [-1, 1] + pos_x = torch.linspace(-1, 1, W, device=x.device, dtype=x.dtype) + pos_y = torch.linspace(-1, 1, H, device=x.device, dtype=x.dtype) + + # 展平空间维度 + x_flat = x.view(B, C, -1) # [B, C, H*W] + + # Softmax 得到注意力权重 + attention = F.softmax(x_flat / self.temperature, dim=-1) # [B, C, H*W] + + # 计算期望坐标 + # pos_x: [W] -> [1, 1, W] -> repeat -> [1, 1, H*W] + pos_x_grid = pos_x.view(1, 1, 1, W).expand(1, 1, H, W).reshape(1, 1, -1) + pos_y_grid = pos_y.view(1, 1, H, 1).expand(1, 1, H, W).reshape(1, 1, -1) + + # 加权求和得到期望坐标 + expected_x = (attention * pos_x_grid).sum(dim=-1) # [B, C] + expected_y = (attention * pos_y_grid).sum(dim=-1) # [B, C] + + # 拼接 x, y 坐标 + keypoints = torch.cat([expected_x, expected_y], dim=-1) # [B, C * 2] + + return keypoints + +from .ddt import ActionDDT + + +def get_sinusoid_encoding_table(n_position, d_hid): + """生成正弦位置编码表。""" + def get_position_angle_vec(position): + return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)] + + sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)]) + sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i + sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 + + return torch.FloatTensor(sinusoid_table).unsqueeze(0) + + +class DDT(nn.Module): + """DDT (Decoupled Diffusion Transformer) 模型。 + + 将视觉 Backbone 和 ActionDDT 扩散模型组合,实现基于图像观测的动作序列生成。 + + 架构: + 1. Backbone: 提取多相机图像特征 + 2. 特征投影: 将图像特征投影到隐藏空间 (Bottleneck 降维) + 3. 状态编码: 编码机器人关节状态 + 4. ActionDDT: 扩散 Transformer 生成动作序列 + + Args: + backbones: 视觉骨干网络列表(每个相机一个) + state_dim: 机器人状态维度 + action_dim: 动作维度 + num_queries: 预测的动作序列长度 + camera_names: 相机名称列表 + hidden_dim: Transformer 隐藏维度 + num_blocks: Transformer block 数量 + num_encoder_blocks: 编码器 block 数量 + num_heads: 注意力头数 + num_kp: Spatial Softmax 的关键点数量 (默认 32) + """ + + def __init__( + self, + backbones, + state_dim: int, + action_dim: int, + num_queries: int, + camera_names: list, + hidden_dim: int = 512, + num_blocks: int = 12, + num_encoder_blocks: int = 4, + num_heads: int = 8, + mlp_ratio: float = 4.0, + num_kp: int = 32, # [修改] 新增参数,默认 32 + ): + super().__init__() + + self.num_queries = num_queries + self.camera_names = camera_names + self.hidden_dim = hidden_dim + self.state_dim = state_dim + self.action_dim = action_dim + self.num_kp = num_kp + + # Backbone 相关 + self.backbones = nn.ModuleList(backbones) + + # [修改] 投影层: ResNet Channels -> num_kp (32) + # 这是一个 Bottleneck 层,大幅减少特征通道数 + self.input_proj = nn.Conv2d( + backbones[0].num_channels, num_kp, kernel_size=1 + ) + + # 状态编码 (2层 MLP,与 Diffusion Policy 一致) + # 状态依然映射到 hidden_dim (512),保持信息量 + self.input_proj_robot_state = nn.Sequential( + nn.Linear(state_dim, hidden_dim), + nn.ReLU(), + nn.Linear(hidden_dim, hidden_dim), + ) + + # [修改] 图像特征聚合 (SpatialSoftmax) + # 输入: [B, num_kp, H, W] + # 输出: [B, num_kp * 2] (每个通道的 x, y 坐标) + self.img_feature_proj = SpatialSoftmax(num_kp=num_kp) + + # [修改] 计算观测维度: 图像特征 + 状态 + # 图像部分: 关键点数量 * 2(x,y) * 摄像头数量 + img_feature_dim = num_kp * 2 * len(camera_names) + obs_dim = img_feature_dim + hidden_dim + + # ActionDDT 扩散模型 + self.action_ddt = ActionDDT( + action_dim=action_dim, + obs_dim=obs_dim, # 使用新的、更紧凑的维度 + action_horizon=num_queries, + hidden_size=hidden_dim, + num_blocks=num_blocks, + num_encoder_blocks=num_encoder_blocks, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + ) + + def encode_observations(self, qpos, image): + """编码观测(图像 + 状态)为条件向量。 + + Args: + qpos: 机器人关节状态 [B, state_dim] + image: 多相机图像 [B, num_cam, C, H, W] + + Returns: + obs: 观测条件向量 [B, obs_dim] + """ + bs = qpos.shape[0] + + # 编码图像特征 + all_cam_features = [] + for cam_id, cam_name in enumerate(self.camera_names): + features, pos = self.backbones[cam_id](image[:, cam_id]) + features = features[0] # 取最后一层特征 + + # [说明] 这里的 input_proj 现在会将通道压缩到 32 + features = self.input_proj(features) # [B, num_kp, H', W'] + + # [说明] SpatialSoftmax 提取 32 个关键点坐标 + features = self.img_feature_proj(features) # [B, num_kp * 2] + + all_cam_features.append(features) + + # 拼接所有相机特征 + img_features = torch.cat(all_cam_features, dim=-1) # [B, num_kp * 2 * num_cam] + + # 编码状态 + qpos_features = self.input_proj_robot_state(qpos) # [B, hidden_dim] + + # 拼接观测 + obs = torch.cat([img_features, qpos_features], dim=-1) # [B, obs_dim] + + return obs + + def forward( + self, + qpos, + image, + env_state, + actions=None, + is_pad=None, + timesteps=None, + ): + """前向传播。 + + 训练时: + 输入带噪声的动作序列和时间步,预测噪声/速度场 + 推理时: + 通过扩散采样生成动作序列 + + Args: + qpos: 机器人关节状态 [B, state_dim] + image: 多相机图像 [B, num_cam, C, H, W] + env_state: 环境状态(未使用) + actions: 动作序列 [B, T, action_dim](训练时为带噪声动作) + is_pad: padding 标记 [B, T](未使用) + timesteps: 扩散时间步 [B](训练时提供) + + Returns: + 训练时: (noise_pred, encoder_state) + 推理时: (action_pred, encoder_state) + """ + # 1. 编码观测 + obs = self.encode_observations(qpos, image) + + # 2. 扩散模型前向 + if actions is not None and timesteps is not None: + # 训练模式: 预测噪声 + noise_pred, encoder_state = self.action_ddt( + x=actions, + t=timesteps, + obs=obs, + ) + return noise_pred, encoder_state + else: + # 推理模式: 需要在 Policy 层进行扩散采样 + # 这里返回编码的观测,供 Policy 层使用 + return obs, None + + def get_obs_dim(self): + """返回观测向量的维度。""" + # [修改] 使用 num_kp 重新计算 + return self.num_kp * 2 * len(self.camera_names) + self.hidden_dim + + +def build(args): + """构建 DDT 模型。 + + Args: + args: 包含模型配置的参数对象 + - state_dim: 状态维度 + - action_dim: 动作维度 + - camera_names: 相机名称列表 + - hidden_dim: 隐藏维度 + - num_queries: 动作序列长度 + - num_blocks: Transformer block 数量 + - enc_layers: 编码器层数 + - nheads: 注意力头数 + - num_kp: 关键点数量 (可选,默认32) + + Returns: + model: DDT 模型实例 + """ + state_dim = args.state_dim + action_dim = args.action_dim + + # 构建 Backbone(每个相机一个) + backbones = [] + for _ in args.camera_names: + backbone = build_backbone(args) + backbones.append(backbone) + + # 构建 DDT 模型 + model = DDT( + backbones=backbones, + state_dim=state_dim, + action_dim=action_dim, + num_queries=args.num_queries, + camera_names=args.camera_names, + hidden_dim=args.hidden_dim, + num_blocks=getattr(args, 'num_blocks', 12), + num_encoder_blocks=getattr(args, 'enc_layers', 4), + num_heads=args.nheads, + mlp_ratio=getattr(args, 'mlp_ratio', 4.0), + num_kp=getattr(args, 'num_kp', 32), # [修改] 传递 num_kp 参数 + ) + + n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) + print("number of parameters: %.2fM" % (n_parameters / 1e6,)) + + return model + + +def build_ddt(args): + """build 的别名,保持接口一致性。""" + return build(args) \ No newline at end of file diff --git a/roboimi/ddt/models/position_encoding.py b/roboimi/ddt/models/position_encoding.py new file mode 100644 index 0000000..c75733e --- /dev/null +++ b/roboimi/ddt/models/position_encoding.py @@ -0,0 +1,91 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +Various positional encodings for the transformer. +""" +import math +import torch +from torch import nn + +from util.misc import NestedTensor + + +class PositionEmbeddingSine(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one + used by the Attention is all you need paper, generalized to work on images. + """ + def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): + super().__init__() + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + def forward(self, tensor): + x = tensor + # mask = tensor_list.mask + # assert mask is not None + # not_mask = ~mask + + not_mask = torch.ones_like(x[0, [0]]) + y_embed = not_mask.cumsum(1, dtype=torch.float32) + x_embed = not_mask.cumsum(2, dtype=torch.float32) + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos + + +class PositionEmbeddingLearned(nn.Module): + """ + Absolute pos embedding, learned. + """ + def __init__(self, num_pos_feats=256): + super().__init__() + self.row_embed = nn.Embedding(50, num_pos_feats) + self.col_embed = nn.Embedding(50, num_pos_feats) + self.reset_parameters() + + def reset_parameters(self): + nn.init.uniform_(self.row_embed.weight) + nn.init.uniform_(self.col_embed.weight) + + def forward(self, tensor_list: NestedTensor): + x = tensor_list.tensors + h, w = x.shape[-2:] + i = torch.arange(w, device=x.device) + j = torch.arange(h, device=x.device) + x_emb = self.col_embed(i) + y_emb = self.row_embed(j) + pos = torch.cat([ + x_emb.unsqueeze(0).repeat(h, 1, 1), + y_emb.unsqueeze(1).repeat(1, w, 1), + ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1) + return pos + + +def build_position_encoding(args): + N_steps = args.hidden_dim // 2 + if args.position_embedding in ('v2', 'sine'): + # TODO find a better way of exposing other arguments + position_embedding = PositionEmbeddingSine(N_steps, normalize=True) + elif args.position_embedding in ('v3', 'learned'): + position_embedding = PositionEmbeddingLearned(N_steps) + else: + raise ValueError(f"not supported {args.position_embedding}") + + return position_embedding diff --git a/roboimi/ddt/models/transformer.py b/roboimi/ddt/models/transformer.py new file mode 100644 index 0000000..2306ab2 --- /dev/null +++ b/roboimi/ddt/models/transformer.py @@ -0,0 +1,312 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +DETR Transformer class. + +Copy-paste from torch.nn.Transformer with modifications: + * positional encodings are passed in MHattention + * extra LN at the end of encoder is removed + * decoder returns a stack of activations from all decoding layers +""" +import copy +from typing import Optional, List + +import torch +import torch.nn.functional as F +from torch import nn, Tensor + + +class Transformer(nn.Module): + + def __init__(self, d_model=512, nhead=8, num_encoder_layers=6, + num_decoder_layers=6, dim_feedforward=2048, dropout=0.1, + activation="relu", normalize_before=False, + return_intermediate_dec=False): + super().__init__() + + encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, + dropout, activation, normalize_before) + encoder_norm = nn.LayerNorm(d_model) if normalize_before else None + self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) + + decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, + dropout, activation, normalize_before) + decoder_norm = nn.LayerNorm(d_model) + self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm, + return_intermediate=return_intermediate_dec) + + self._reset_parameters() + + self.d_model = d_model + self.nhead = nhead + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, src, mask, query_embed, pos_embed, latent_input=None, proprio_input=None, additional_pos_embed=None): + # TODO flatten only when input has H and W + if len(src.shape) == 4: # has H and W + # flatten NxCxHxW to HWxNxC + bs, c, h, w = src.shape + src = src.flatten(2).permute(2, 0, 1) + pos_embed = pos_embed.flatten(2).permute(2, 0, 1).repeat(1, bs, 1) + query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) + # mask = mask.flatten(1) + + additional_pos_embed = additional_pos_embed.unsqueeze(1).repeat(1, bs, 1) # seq, bs, dim + pos_embed = torch.cat([additional_pos_embed, pos_embed], axis=0) + + addition_input = torch.stack([latent_input, proprio_input], axis=0) + src = torch.cat([addition_input, src], axis=0) + else: + assert len(src.shape) == 3 + # flatten NxHWxC to HWxNxC + bs, hw, c = src.shape + src = src.permute(1, 0, 2) + pos_embed = pos_embed.unsqueeze(1).repeat(1, bs, 1) + query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) + + tgt = torch.zeros_like(query_embed) + memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) + hs = self.decoder(tgt, memory, memory_key_padding_mask=mask, + pos=pos_embed, query_pos=query_embed) + hs = hs.transpose(1, 2) + return hs + +class TransformerEncoder(nn.Module): + + def __init__(self, encoder_layer, num_layers, norm=None): + super().__init__() + self.layers = _get_clones(encoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + + def forward(self, src, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + output = src + + for layer in self.layers: + output = layer(output, src_mask=mask, + src_key_padding_mask=src_key_padding_mask, pos=pos) + + if self.norm is not None: + output = self.norm(output) + + return output + + +class TransformerDecoder(nn.Module): + + def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False): + super().__init__() + self.layers = _get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + self.return_intermediate = return_intermediate + + def forward(self, tgt, memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + output = tgt + + intermediate = [] + + for layer in self.layers: + output = layer(output, memory, tgt_mask=tgt_mask, + memory_mask=memory_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + pos=pos, query_pos=query_pos) + if self.return_intermediate: + intermediate.append(self.norm(output)) + + if self.norm is not None: + output = self.norm(output) + if self.return_intermediate: + intermediate.pop() + intermediate.append(output) + + if self.return_intermediate: + return torch.stack(intermediate) + + return output.unsqueeze(0) + + +class TransformerEncoderLayer(nn.Module): + + def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, + activation="relu", normalize_before=False): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post(self, + src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + q = k = self.with_pos_embed(src, pos) + src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, + key_padding_mask=src_key_padding_mask)[0] + src = src + self.dropout1(src2) + src = self.norm1(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) + src = src + self.dropout2(src2) + src = self.norm2(src) + return src + + def forward_pre(self, src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + src2 = self.norm1(src) + q = k = self.with_pos_embed(src2, pos) + src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask, + key_padding_mask=src_key_padding_mask)[0] + src = src + self.dropout1(src2) + src2 = self.norm2(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) + src = src + self.dropout2(src2) + return src + + def forward(self, src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + if self.normalize_before: + return self.forward_pre(src, src_mask, src_key_padding_mask, pos) + return self.forward_post(src, src_mask, src_key_padding_mask, pos) + + +class TransformerDecoderLayer(nn.Module): + + def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, + activation="relu", normalize_before=False): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post(self, tgt, memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + q = k = self.with_pos_embed(tgt, query_pos) + tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout1(tgt2) + tgt = self.norm1(tgt) + tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask)[0] + tgt = tgt + self.dropout2(tgt2) + tgt = self.norm2(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) + tgt = tgt + self.dropout3(tgt2) + tgt = self.norm3(tgt) + return tgt + + def forward_pre(self, tgt, memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + tgt2 = self.norm1(tgt) + q = k = self.with_pos_embed(tgt2, query_pos) + tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout1(tgt2) + tgt2 = self.norm2(tgt) + tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask)[0] + tgt = tgt + self.dropout2(tgt2) + tgt2 = self.norm3(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) + tgt = tgt + self.dropout3(tgt2) + return tgt + + def forward(self, tgt, memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + if self.normalize_before: + return self.forward_pre(tgt, memory, tgt_mask, memory_mask, + tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) + return self.forward_post(tgt, memory, tgt_mask, memory_mask, + tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) + + +def _get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + +def build_transformer(args): + return Transformer( + d_model=args.hidden_dim, + dropout=args.dropout, + nhead=args.nheads, + dim_feedforward=args.dim_feedforward, + num_encoder_layers=args.enc_layers, + num_decoder_layers=args.dec_layers, + normalize_before=args.pre_norm, + return_intermediate_dec=True, + ) + + +def _get_activation_fn(activation): + """Return an activation function given a string""" + if activation == "relu": + return F.relu + if activation == "gelu": + return F.gelu + if activation == "glu": + return F.glu + raise RuntimeError(F"activation should be relu/gelu, not {activation}.") diff --git a/roboimi/ddt/policy.py b/roboimi/ddt/policy.py new file mode 100644 index 0000000..3b79b09 --- /dev/null +++ b/roboimi/ddt/policy.py @@ -0,0 +1,147 @@ +""" +DDT Policy - 基于扩散模型的动作生成策略。 + +支持 Flow Matching 训练和推理。 +""" +import torch +import torch.nn as nn +from torch.nn import functional as F +import torchvision.transforms as transforms +from torchvision.transforms import v2 +import math + +from roboimi.ddt.main import build_DDT_model_and_optimizer + + +class DDTPolicy(nn.Module): + """DDT (Decoupled Diffusion Transformer) 策略。 + + 使用 Flow Matching 进行训练,支持多步扩散采样推理。 + 带数据增强,适配 DINOv2 等 ViT backbone。 + + Args: + args_override: 配置参数字典 + - num_inference_steps: 推理时的扩散步数 + - qpos_noise_std: qpos 噪声标准差(训练时数据增强) + - patch_h, patch_w: 图像 patch 数量(用于计算目标尺寸) + """ + + def __init__(self, args_override): + super().__init__() + model, optimizer = build_DDT_model_and_optimizer(args_override) + self.model = model + self.optimizer = optimizer + + self.num_inference_steps = args_override.get('num_inference_steps', 10) + self.qpos_noise_std = args_override.get('qpos_noise_std', 0.0) + + # 图像尺寸配置 (适配 DINOv2) + self.patch_h = args_override.get('patch_h', 16) + self.patch_w = args_override.get('patch_w', 22) + + print(f'DDT Policy: {self.num_inference_steps} steps, ' + f'image size ({self.patch_h*14}, {self.patch_w*14})') + + def __call__(self, qpos, image, actions=None, is_pad=None): + """前向传播。 + + 训练时: 使用 Flow Matching 损失 + 推理时: 通过扩散采样生成动作 + + Args: + qpos: 机器人关节状态 [B, state_dim] + image: 多相机图像 [B, num_cam, C, H, W] + actions: 目标动作序列 [B, T, action_dim](训练时提供) + is_pad: padding 标记 [B, T] + + Returns: + 训练时: loss_dict + 推理时: 预测的动作序列 [B, T, action_dim] + """ + env_state = None + + # 图像预处理 + if actions is not None: # 训练时:数据增强 + transform = v2.Compose([ + v2.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5), + v2.RandomPerspective(distortion_scale=0.5), + v2.RandomAffine(degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1)), + v2.GaussianBlur(kernel_size=(9, 9), sigma=(0.1, 2.0)), + v2.Resize((self.patch_h * 14, self.patch_w * 14)), + v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), + ]) + if self.qpos_noise_std > 0: + qpos = qpos + (self.qpos_noise_std ** 0.5) * torch.randn_like(qpos) + else: # 推理时 + transform = v2.Compose([ + v2.Resize((self.patch_h * 14, self.patch_w * 14)), + v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), + ]) + + image = transform(image) + + if actions is not None: + actions = actions[:, :self.model.num_queries] + is_pad = is_pad[:, :self.model.num_queries] + loss_dict = self._compute_loss(qpos, image, actions, is_pad) + return loss_dict + else: + a_hat = self._sample(qpos, image) + return a_hat + + def _compute_loss(self, qpos, image, actions, is_pad): + """计算 Flow Matching 损失。 + + Flow Matching 目标: 学习从噪声到数据的向量场 + 损失: ||v_theta(x_t, t) - (x_1 - x_0)||^2 + 其中 x_t = (1-t)*x_0 + t*x_1, x_0 是噪声, x_1 是目标动作 + """ + B, T, action_dim = actions.shape + device = actions.device + + t = torch.rand(B, device=device) + noise = torch.randn_like(actions) + + t_expand = t.view(B, 1, 1).expand(B, T, action_dim) + x_t = (1 - t_expand) * noise + t_expand * actions + target_velocity = actions - noise + + pred_velocity, _ = self.model( + qpos=qpos, + image=image, + env_state=None, + actions=x_t, + timesteps=t, + ) + + all_loss = F.mse_loss(pred_velocity, target_velocity, reduction='none') + loss = (all_loss * ~is_pad.unsqueeze(-1)).mean() + + return {'flow_loss': loss, 'loss': loss} + + @torch.no_grad() + def _sample(self, qpos, image): + """通过 ODE 求解进行扩散采样。 + + 使用 Euler 方法从 t=0 积分到 t=1: + x_{t+dt} = x_t + v_theta(x_t, t) * dt + """ + B = qpos.shape[0] + T = self.model.num_queries + action_dim = self.model.action_dim + device = qpos.device + + x = torch.randn(B, T, action_dim, device=device) + obs = self.model.encode_observations(qpos, image) + + dt = 1.0 / self.num_inference_steps + for i in range(self.num_inference_steps): + t = torch.full((B,), i * dt, device=device) + velocity, _ = self.model.action_ddt(x=x, t=t, obs=obs) + x = x + velocity * dt + + return x + + def configure_optimizers(self): + """返回优化器。""" + return self.optimizer diff --git a/roboimi/ddt/util/__init__.py b/roboimi/ddt/util/__init__.py new file mode 100644 index 0000000..168f997 --- /dev/null +++ b/roboimi/ddt/util/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved diff --git a/roboimi/ddt/util/box_ops.py b/roboimi/ddt/util/box_ops.py new file mode 100644 index 0000000..9c088e5 --- /dev/null +++ b/roboimi/ddt/util/box_ops.py @@ -0,0 +1,88 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +Utilities for bounding box manipulation and GIoU. +""" +import torch +from torchvision.ops.boxes import box_area + + +def box_cxcywh_to_xyxy(x): + x_c, y_c, w, h = x.unbind(-1) + b = [(x_c - 0.5 * w), (y_c - 0.5 * h), + (x_c + 0.5 * w), (y_c + 0.5 * h)] + return torch.stack(b, dim=-1) + + +def box_xyxy_to_cxcywh(x): + x0, y0, x1, y1 = x.unbind(-1) + b = [(x0 + x1) / 2, (y0 + y1) / 2, + (x1 - x0), (y1 - y0)] + return torch.stack(b, dim=-1) + + +# modified from torchvision to also return the union +def box_iou(boxes1, boxes2): + area1 = box_area(boxes1) + area2 = box_area(boxes2) + + lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] + rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] + + wh = (rb - lt).clamp(min=0) # [N,M,2] + inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] + + union = area1[:, None] + area2 - inter + + iou = inter / union + return iou, union + + +def generalized_box_iou(boxes1, boxes2): + """ + Generalized IoU from https://giou.stanford.edu/ + + The boxes should be in [x0, y0, x1, y1] format + + Returns a [N, M] pairwise matrix, where N = len(boxes1) + and M = len(boxes2) + """ + # degenerate boxes gives inf / nan results + # so do an early check + assert (boxes1[:, 2:] >= boxes1[:, :2]).all() + assert (boxes2[:, 2:] >= boxes2[:, :2]).all() + iou, union = box_iou(boxes1, boxes2) + + lt = torch.min(boxes1[:, None, :2], boxes2[:, :2]) + rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) + + wh = (rb - lt).clamp(min=0) # [N,M,2] + area = wh[:, :, 0] * wh[:, :, 1] + + return iou - (area - union) / area + + +def masks_to_boxes(masks): + """Compute the bounding boxes around the provided masks + + The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions. + + Returns a [N, 4] tensors, with the boxes in xyxy format + """ + if masks.numel() == 0: + return torch.zeros((0, 4), device=masks.device) + + h, w = masks.shape[-2:] + + y = torch.arange(0, h, dtype=torch.float) + x = torch.arange(0, w, dtype=torch.float) + y, x = torch.meshgrid(y, x) + + x_mask = (masks * x.unsqueeze(0)) + x_max = x_mask.flatten(1).max(-1)[0] + x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] + + y_mask = (masks * y.unsqueeze(0)) + y_max = y_mask.flatten(1).max(-1)[0] + y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] + + return torch.stack([x_min, y_min, x_max, y_max], 1) diff --git a/roboimi/ddt/util/misc.py b/roboimi/ddt/util/misc.py new file mode 100644 index 0000000..dfa9fb5 --- /dev/null +++ b/roboimi/ddt/util/misc.py @@ -0,0 +1,468 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +Misc functions, including distributed helpers. + +Mostly copy-paste from torchvision references. +""" +import os +import subprocess +import time +from collections import defaultdict, deque +import datetime +import pickle +from packaging import version +from typing import Optional, List + +import torch +import torch.distributed as dist +from torch import Tensor + +# needed due to empty tensor bug in pytorch and torchvision 0.5 +import torchvision +if version.parse(torchvision.__version__) < version.parse('0.7'): + from torchvision.ops import _new_empty_tensor + from torchvision.ops.misc import _output_size + + +class SmoothedValue(object): + """Track a series of values and provide access to smoothed values over a + window or the global series average. + """ + + def __init__(self, window_size=20, fmt=None): + if fmt is None: + fmt = "{median:.4f} ({global_avg:.4f})" + self.deque = deque(maxlen=window_size) + self.total = 0.0 + self.count = 0 + self.fmt = fmt + + def update(self, value, n=1): + self.deque.append(value) + self.count += n + self.total += value * n + + def synchronize_between_processes(self): + """ + Warning: does not synchronize the deque! + """ + if not is_dist_avail_and_initialized(): + return + t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') + dist.barrier() + dist.all_reduce(t) + t = t.tolist() + self.count = int(t[0]) + self.total = t[1] + + @property + def median(self): + d = torch.tensor(list(self.deque)) + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque), dtype=torch.float32) + return d.mean().item() + + @property + def global_avg(self): + return self.total / self.count + + @property + def max(self): + return max(self.deque) + + @property + def value(self): + return self.deque[-1] + + def __str__(self): + return self.fmt.format( + median=self.median, + avg=self.avg, + global_avg=self.global_avg, + max=self.max, + value=self.value) + + +def all_gather(data): + """ + Run all_gather on arbitrary picklable data (not necessarily tensors) + Args: + data: any picklable object + Returns: + list[data]: list of data gathered from each rank + """ + world_size = get_world_size() + if world_size == 1: + return [data] + + # serialized to a Tensor + buffer = pickle.dumps(data) + storage = torch.ByteStorage.from_buffer(buffer) + tensor = torch.ByteTensor(storage).to("cuda") + + # obtain Tensor size of each rank + local_size = torch.tensor([tensor.numel()], device="cuda") + size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)] + dist.all_gather(size_list, local_size) + size_list = [int(size.item()) for size in size_list] + max_size = max(size_list) + + # receiving Tensor from all ranks + # we pad the tensor because torch all_gather does not support + # gathering tensors of different shapes + tensor_list = [] + for _ in size_list: + tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda")) + if local_size != max_size: + padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda") + tensor = torch.cat((tensor, padding), dim=0) + dist.all_gather(tensor_list, tensor) + + data_list = [] + for size, tensor in zip(size_list, tensor_list): + buffer = tensor.cpu().numpy().tobytes()[:size] + data_list.append(pickle.loads(buffer)) + + return data_list + + +def reduce_dict(input_dict, average=True): + """ + Args: + input_dict (dict): all the values will be reduced + average (bool): whether to do average or sum + Reduce the values in the dictionary from all processes so that all processes + have the averaged results. Returns a dict with the same fields as + input_dict, after reduction. + """ + world_size = get_world_size() + if world_size < 2: + return input_dict + with torch.no_grad(): + names = [] + values = [] + # sort the keys so that they are consistent across processes + for k in sorted(input_dict.keys()): + names.append(k) + values.append(input_dict[k]) + values = torch.stack(values, dim=0) + dist.all_reduce(values) + if average: + values /= world_size + reduced_dict = {k: v for k, v in zip(names, values)} + return reduced_dict + + +class MetricLogger(object): + def __init__(self, delimiter="\t"): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + + def update(self, **kwargs): + for k, v in kwargs.items(): + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError("'{}' object has no attribute '{}'".format( + type(self).__name__, attr)) + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append( + "{}: {}".format(name, str(meter)) + ) + return self.delimiter.join(loss_str) + + def synchronize_between_processes(self): + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def add_meter(self, name, meter): + self.meters[name] = meter + + def log_every(self, iterable, print_freq, header=None): + i = 0 + if not header: + header = '' + start_time = time.time() + end = time.time() + iter_time = SmoothedValue(fmt='{avg:.4f}') + data_time = SmoothedValue(fmt='{avg:.4f}') + space_fmt = ':' + str(len(str(len(iterable)))) + 'd' + if torch.cuda.is_available(): + log_msg = self.delimiter.join([ + header, + '[{0' + space_fmt + '}/{1}]', + 'eta: {eta}', + '{meters}', + 'time: {time}', + 'data: {data}', + 'max mem: {memory:.0f}' + ]) + else: + log_msg = self.delimiter.join([ + header, + '[{0' + space_fmt + '}/{1}]', + 'eta: {eta}', + '{meters}', + 'time: {time}', + 'data: {data}' + ]) + MB = 1024.0 * 1024.0 + for obj in iterable: + data_time.update(time.time() - end) + yield obj + iter_time.update(time.time() - end) + if i % print_freq == 0 or i == len(iterable) - 1: + eta_seconds = iter_time.global_avg * (len(iterable) - i) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + if torch.cuda.is_available(): + print(log_msg.format( + i, len(iterable), eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB)) + else: + print(log_msg.format( + i, len(iterable), eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time))) + i += 1 + end = time.time() + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('{} Total time: {} ({:.4f} s / it)'.format( + header, total_time_str, total_time / len(iterable))) + + +def get_sha(): + cwd = os.path.dirname(os.path.abspath(__file__)) + + def _run(command): + return subprocess.check_output(command, cwd=cwd).decode('ascii').strip() + sha = 'N/A' + diff = "clean" + branch = 'N/A' + try: + sha = _run(['git', 'rev-parse', 'HEAD']) + subprocess.check_output(['git', 'diff'], cwd=cwd) + diff = _run(['git', 'diff-index', 'HEAD']) + diff = "has uncommited changes" if diff else "clean" + branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD']) + except Exception: + pass + message = f"sha: {sha}, status: {diff}, branch: {branch}" + return message + + +def collate_fn(batch): + batch = list(zip(*batch)) + batch[0] = nested_tensor_from_tensor_list(batch[0]) + return tuple(batch) + + +def _max_by_axis(the_list): + # type: (List[List[int]]) -> List[int] + maxes = the_list[0] + for sublist in the_list[1:]: + for index, item in enumerate(sublist): + maxes[index] = max(maxes[index], item) + return maxes + + +class NestedTensor(object): + def __init__(self, tensors, mask: Optional[Tensor]): + self.tensors = tensors + self.mask = mask + + def to(self, device): + # type: (Device) -> NestedTensor # noqa + cast_tensor = self.tensors.to(device) + mask = self.mask + if mask is not None: + assert mask is not None + cast_mask = mask.to(device) + else: + cast_mask = None + return NestedTensor(cast_tensor, cast_mask) + + def decompose(self): + return self.tensors, self.mask + + def __repr__(self): + return str(self.tensors) + + +def nested_tensor_from_tensor_list(tensor_list: List[Tensor]): + # TODO make this more general + if tensor_list[0].ndim == 3: + if torchvision._is_tracing(): + # nested_tensor_from_tensor_list() does not export well to ONNX + # call _onnx_nested_tensor_from_tensor_list() instead + return _onnx_nested_tensor_from_tensor_list(tensor_list) + + # TODO make it support different-sized images + max_size = _max_by_axis([list(img.shape) for img in tensor_list]) + # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list])) + batch_shape = [len(tensor_list)] + max_size + b, c, h, w = batch_shape + dtype = tensor_list[0].dtype + device = tensor_list[0].device + tensor = torch.zeros(batch_shape, dtype=dtype, device=device) + mask = torch.ones((b, h, w), dtype=torch.bool, device=device) + for img, pad_img, m in zip(tensor_list, tensor, mask): + pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) + m[: img.shape[1], :img.shape[2]] = False + else: + raise ValueError('not supported') + return NestedTensor(tensor, mask) + + +# _onnx_nested_tensor_from_tensor_list() is an implementation of +# nested_tensor_from_tensor_list() that is supported by ONNX tracing. +@torch.jit.unused +def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor: + max_size = [] + for i in range(tensor_list[0].dim()): + max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(torch.int64) + max_size.append(max_size_i) + max_size = tuple(max_size) + + # work around for + # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) + # m[: img.shape[1], :img.shape[2]] = False + # which is not yet supported in onnx + padded_imgs = [] + padded_masks = [] + for img in tensor_list: + padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))] + padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0])) + padded_imgs.append(padded_img) + + m = torch.zeros_like(img[0], dtype=torch.int, device=img.device) + padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1) + padded_masks.append(padded_mask.to(torch.bool)) + + tensor = torch.stack(padded_imgs) + mask = torch.stack(padded_masks) + + return NestedTensor(tensor, mask=mask) + + +def setup_for_distributed(is_master): + """ + This function disables printing when not in master process + """ + import builtins as __builtin__ + builtin_print = __builtin__.print + + def print(*args, **kwargs): + force = kwargs.pop('force', False) + if is_master or force: + builtin_print(*args, **kwargs) + + __builtin__.print = print + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def is_main_process(): + return get_rank() == 0 + + +def save_on_master(*args, **kwargs): + if is_main_process(): + torch.save(*args, **kwargs) + + +def init_distributed_mode(args): + if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: + args.rank = int(os.environ["RANK"]) + args.world_size = int(os.environ['WORLD_SIZE']) + args.gpu = int(os.environ['LOCAL_RANK']) + elif 'SLURM_PROCID' in os.environ: + args.rank = int(os.environ['SLURM_PROCID']) + args.gpu = args.rank % torch.cuda.device_count() + else: + print('Not using distributed mode') + args.distributed = False + return + + args.distributed = True + + torch.cuda.set_device(args.gpu) + args.dist_backend = 'nccl' + print('| distributed init (rank {}): {}'.format( + args.rank, args.dist_url), flush=True) + torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, + world_size=args.world_size, rank=args.rank) + torch.distributed.barrier() + setup_for_distributed(args.rank == 0) + + +@torch.no_grad() +def accuracy(output, target, topk=(1,)): + """Computes the precision@k for the specified values of k""" + if target.numel() == 0: + return [torch.zeros([], device=output.device)] + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].view(-1).float().sum(0) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None): + # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor + """ + Equivalent to nn.functional.interpolate, but with support for empty batch sizes. + This will eventually be supported natively by PyTorch, and this + class can go away. + """ + if version.parse(torchvision.__version__) < version.parse('0.7'): + if input.numel() > 0: + return torch.nn.functional.interpolate( + input, size, scale_factor, mode, align_corners + ) + + output_shape = _output_size(2, input, size, scale_factor) + output_shape = list(input.shape[:-2]) + list(output_shape) + return _new_empty_tensor(input, output_shape) + else: + return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners) diff --git a/roboimi/ddt/util/plot_utils.py b/roboimi/ddt/util/plot_utils.py new file mode 100644 index 0000000..0f24bed --- /dev/null +++ b/roboimi/ddt/util/plot_utils.py @@ -0,0 +1,107 @@ +""" +Plotting utilities to visualize training logs. +""" +import torch +import pandas as pd +import numpy as np +import seaborn as sns +import matplotlib.pyplot as plt + +from pathlib import Path, PurePath + + +def plot_logs(logs, fields=('class_error', 'loss_bbox_unscaled', 'mAP'), ewm_col=0, log_name='log.txt'): + ''' + Function to plot specific fields from training log(s). Plots both training and test results. + + :: Inputs - logs = list containing Path objects, each pointing to individual dir with a log file + - fields = which results to plot from each log file - plots both training and test for each field. + - ewm_col = optional, which column to use as the exponential weighted smoothing of the plots + - log_name = optional, name of log file if different than default 'log.txt'. + + :: Outputs - matplotlib plots of results in fields, color coded for each log file. + - solid lines are training results, dashed lines are test results. + + ''' + func_name = "plot_utils.py::plot_logs" + + # verify logs is a list of Paths (list[Paths]) or single Pathlib object Path, + # convert single Path to list to avoid 'not iterable' error + + if not isinstance(logs, list): + if isinstance(logs, PurePath): + logs = [logs] + print(f"{func_name} info: logs param expects a list argument, converted to list[Path].") + else: + raise ValueError(f"{func_name} - invalid argument for logs parameter.\n \ + Expect list[Path] or single Path obj, received {type(logs)}") + + # Quality checks - verify valid dir(s), that every item in list is Path object, and that log_name exists in each dir + for i, dir in enumerate(logs): + if not isinstance(dir, PurePath): + raise ValueError(f"{func_name} - non-Path object in logs argument of {type(dir)}: \n{dir}") + if not dir.exists(): + raise ValueError(f"{func_name} - invalid directory in logs argument:\n{dir}") + # verify log_name exists + fn = Path(dir / log_name) + if not fn.exists(): + print(f"-> missing {log_name}. Have you gotten to Epoch 1 in training?") + print(f"--> full path of missing log file: {fn}") + return + + # load log file(s) and plot + dfs = [pd.read_json(Path(p) / log_name, lines=True) for p in logs] + + fig, axs = plt.subplots(ncols=len(fields), figsize=(16, 5)) + + for df, color in zip(dfs, sns.color_palette(n_colors=len(logs))): + for j, field in enumerate(fields): + if field == 'mAP': + coco_eval = pd.DataFrame( + np.stack(df.test_coco_eval_bbox.dropna().values)[:, 1] + ).ewm(com=ewm_col).mean() + axs[j].plot(coco_eval, c=color) + else: + df.interpolate().ewm(com=ewm_col).mean().plot( + y=[f'train_{field}', f'test_{field}'], + ax=axs[j], + color=[color] * 2, + style=['-', '--'] + ) + for ax, field in zip(axs, fields): + ax.legend([Path(p).name for p in logs]) + ax.set_title(field) + + +def plot_precision_recall(files, naming_scheme='iter'): + if naming_scheme == 'exp_id': + # name becomes exp_id + names = [f.parts[-3] for f in files] + elif naming_scheme == 'iter': + names = [f.stem for f in files] + else: + raise ValueError(f'not supported {naming_scheme}') + fig, axs = plt.subplots(ncols=2, figsize=(16, 5)) + for f, color, name in zip(files, sns.color_palette("Blues", n_colors=len(files)), names): + data = torch.load(f) + # precision is n_iou, n_points, n_cat, n_area, max_det + precision = data['precision'] + recall = data['params'].recThrs + scores = data['scores'] + # take precision for all classes, all areas and 100 detections + precision = precision[0, :, :, 0, -1].mean(1) + scores = scores[0, :, :, 0, -1].mean(1) + prec = precision.mean() + rec = data['recall'][0, :, 0, -1].mean() + print(f'{naming_scheme} {name}: mAP@50={prec * 100: 05.1f}, ' + + f'score={scores.mean():0.3f}, ' + + f'f1={2 * prec * rec / (prec + rec + 1e-8):0.3f}' + ) + axs[0].plot(recall, precision, c=color) + axs[1].plot(recall, scores, c=color) + + axs[0].set_title('Precision / Recall') + axs[0].legend(names) + axs[1].set_title('Scores / Recall') + axs[1].legend(names) + return fig, axs diff --git a/roboimi/demos/config.yaml b/roboimi/demos/config.yaml index 16ab129..31d6aff 100644 --- a/roboimi/demos/config.yaml +++ b/roboimi/demos/config.yaml @@ -9,6 +9,7 @@ temporal_agg: false # policy_class: "ACT" # backbone: 'resnet18' policy_class: "ACTTV" +# policy_class: "DDT" backbone: 'dino_v2' seed: 0 @@ -39,7 +40,7 @@ camera_names: [] # leave empty here by default xml_dir: # leave empty here by default # transformer settings -batch_size: 15 +batch_size: 32 state_dim: 16 action_dim: 16 lr_backbone: 0.00001 @@ -56,5 +57,10 @@ lr: 0.00001 kl_weight: 100 chunk_size: 10 hidden_dim: 512 -dim_feedforward: 3200 +dim_feedforward: 3200 + +# DDT 特有参数 +num_blocks: 12 # Transformer blocks 数量 +mlp_ratio: 4.0 # MLP 维度比例 +num_inference_steps: 10 # 扩散推理步数 diff --git a/roboimi/utils/model_interface.py b/roboimi/utils/model_interface.py index 007b0f7..d5c7821 100644 --- a/roboimi/utils/model_interface.py +++ b/roboimi/utils/model_interface.py @@ -2,6 +2,7 @@ import os import torch from roboimi.utils.utils import load_data, set_seed from roboimi.detr.policy import ACTPolicy, CNNMLPPolicy,ACTTVPolicy +from roboimi.ddt.policy import DDTPolicy class ModelInterface: def __init__(self, config): @@ -65,6 +66,24 @@ class ModelInterface: 'num_queries': 1, 'camera_names': self.config['camera_names'], } + elif self.config['policy_class'] == 'DDT': + self.config['policy_config'] = { + 'lr': self.config['lr'], + 'lr_backbone': self.config['lr_backbone'], + 'backbone': self.config.get('backbone', 'dino_v2'), + 'num_queries': self.config['chunk_size'], + 'hidden_dim': self.config['hidden_dim'], + 'nheads': self.config['nheads'], + 'enc_layers': self.config['enc_layers'], + 'state_dim': self.config.get('state_dim', 16), + 'action_dim': self.config.get('action_dim', 16), + 'camera_names': self.config['camera_names'], + 'qpos_noise_std': self.config.get('qpos_noise_std', 0), + # DDT 特有参数 + 'num_blocks': self.config.get('num_blocks', 12), + 'mlp_ratio': self.config.get('mlp_ratio', 4.0), + 'num_inference_steps': self.config.get('num_inference_steps', 10), + } else: raise NotImplementedError @@ -75,6 +94,8 @@ class ModelInterface: return ACTTVPolicy(self.config['policy_config']) elif self.config['policy_class'] == 'CNNMLP': return CNNMLPPolicy(self.config['policy_config']) + elif self.config['policy_class'] == 'DDT': + return DDTPolicy(self.config['policy_config']) else: raise NotImplementedError