From c2000b55331708c2b7a3764c426d485f0f36bdbf Mon Sep 17 00:00:00 2001
From: Logic
Date: Wed, 1 Apr 2026 23:35:31 +0800
Subject: [PATCH] feat: add IMF AttnRes policy training path

---
 roboimi/demos/vla_scripts/train_vla.py        |  32 +-
 roboimi/vla/agent_imf.py                      | 161 +++++++
 .../vla/conf/agent/resnet_imf_attnres.yaml    |  40 ++
 roboimi/vla/conf/config.yaml                  |   1 +
 roboimi/vla/conf/head/imf_transformer1d.yaml  |  22 +
 .../heads/attnres_transformer_components.py   | 249 ++++++++++
 roboimi/vla/models/heads/imf_transformer1d.py | 379 ++++++++++++++++
 ...st_imf_transformer1d_external_alignment.py | 196 ++++++++
 tests/test_imf_vla_agent.py                   | 427 ++++++++++++++++++
 tests/test_train_vla_transformer_optimizer.py |  70 ++-
 10 files changed, 1566 insertions(+), 11 deletions(-)
 create mode 100644 roboimi/vla/agent_imf.py
 create mode 100644 roboimi/vla/conf/agent/resnet_imf_attnres.yaml
 create mode 100644 roboimi/vla/conf/head/imf_transformer1d.yaml
 create mode 100644 roboimi/vla/models/heads/attnres_transformer_components.py
 create mode 100644 roboimi/vla/models/heads/imf_transformer1d.py
 create mode 100644 tests/test_imf_transformer1d_external_alignment.py
 create mode 100644 tests/test_imf_vla_agent.py

diff --git a/roboimi/demos/vla_scripts/train_vla.py b/roboimi/demos/vla_scripts/train_vla.py
index 8b3e787..cd5da37 100644
--- a/roboimi/demos/vla_scripts/train_vla.py
+++ b/roboimi/demos/vla_scripts/train_vla.py
@@ -14,8 +14,17 @@ from torch.optim import AdamW
 from torch.optim.lr_scheduler import LambdaLR
 from pathlib import Path
 
-# 确保正确的导入路径
-sys.path.append(os.getcwd())
+# Ensure the correct import path (must not depend on cwd, since Hydra switches cwd at runtime)
+def _ensure_repo_root_on_syspath():
+    repo_root = Path(__file__).resolve().parents[3]
+    repo_root_str = str(repo_root)
+    if repo_root_str in sys.path:
+        sys.path.remove(repo_root_str)
+    sys.path.insert(0, repo_root_str)
+    return repo_root
+
+
+_REPO_ROOT = _ensure_repo_root_on_syspath()
 
 from hydra.utils import instantiate
 
@@ -26,6 +35,13 @@ if not OmegaConf.has_resolver("len"):
     OmegaConf.register_new_resolver("len", lambda x: len(x))
 
 
+def _configure_cuda_runtime(cfg):
+    """Apply process-level CUDA runtime switches required by this environment."""
+    if str(cfg.train.device).startswith('cuda') and bool(cfg.train.get('disable_cudnn', False)):
+        torch.backends.cudnn.enabled = False
+        log.warning('⚠️ cuDNN disabled per config; GPU convolutions will fall back to non-cuDNN implementations')
+
+
 def recursive_to_device(data, device):
     """
     递归地将嵌套字典/列表中的张量移动到指定设备。
@@ -113,14 +129,11 @@ def get_lr_schedule_with_warmup(optimizer, warmup_steps, max_steps, scheduler_ty
 
 
 def build_training_optimizer(agent, lr, weight_decay):
-    """为训练脚本构建优化器,优先复用 transformer head 自带的参数分组。"""
+    """Build the training-script optimizer, preferring the parameter groups any head supplies itself."""
     trainable_params = [param for param in agent.parameters() if param.requires_grad]
     noise_pred_net = getattr(agent, 'noise_pred_net', None)
     get_optim_groups = getattr(noise_pred_net, 'get_optim_groups', None)
-    use_head_groups = (
-        getattr(agent, 'head_type', None) == 'transformer'
-        and callable(get_optim_groups)
-    )
+    use_head_groups = callable(get_optim_groups)
 
     if not use_head_groups:
         return AdamW(trainable_params, lr=lr, weight_decay=weight_decay)
@@ -138,7 +151,7 @@
     for param in params:
         param_id = id(param)
         if param_id in grouped_param_ids:
-            raise ValueError('Transformer optimizer groups contain duplicate parameters')
+            raise ValueError('Head optimizer groups contain duplicate parameters')
         grouped_param_ids.add(param_id)
 
     head_trainable_param_ids = {
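Reviewer note: the gate above is now capability-based rather than keyed on `head_type`. A minimal sketch of the contract `build_training_optimizer` expects from a head; `ToyHead` is hypothetical and not part of this patch:

    from torch import nn

    class ToyHead(nn.Module):
        """Hypothetical head exposing the optional grouping hook."""

        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(4, 4)

        def get_optim_groups(self, weight_decay):
            # weight decay on matmul weights, none on biases
            return [
                {'params': [self.proj.weight], 'weight_decay': weight_decay},
                {'params': [self.proj.bias], 'weight_decay': 0.0},
            ]

    # Any agent whose `noise_pred_net` exposes a callable `get_optim_groups`
    # gets those groups reused verbatim; the remaining trainable parameters
    # are collected into one extra group, which the checks in these hunks
    # validate for duplicates and omissions.

@@ -146,7 +159,7 @@ def build_training_optimizer(agent,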
lr, weight_decay): } missing_head_param_ids = head_trainable_param_ids - grouped_param_ids if missing_head_param_ids: - raise ValueError('Transformer optimizer groups missed trainable head parameters') + raise ValueError('Head optimizer groups missed trainable head parameters') remaining_params = [ param for param in trainable_params @@ -258,6 +271,7 @@ def _run_training(cfg: DictConfig): print("=" * 80) log.info(f"🚀 开始 VLA 训练 (设备: {cfg.train.device})") + _configure_cuda_runtime(cfg) swanlab_module = _init_swanlab(cfg) try: # 创建检查点目录 diff --git a/roboimi/vla/agent_imf.py b/roboimi/vla/agent_imf.py new file mode 100644 index 0000000..6dfc307 --- /dev/null +++ b/roboimi/vla/agent_imf.py @@ -0,0 +1,161 @@ +from __future__ import annotations + +from contextlib import nullcontext +from typing import Dict, Optional + +import torch +import torch.nn.functional as F + +from roboimi.vla.agent import VLAAgent + +try: + from torch.func import jvp as TORCH_FUNC_JVP +except ImportError: # pragma: no cover + TORCH_FUNC_JVP = None + + +class IMFVLAAgent(VLAAgent): + def __init__(self, *args, inference_steps: int = 1, **kwargs): + if inference_steps != 1: + raise ValueError( + 'IMFVLAAgent only supports one-step inference; ' + f'inference_steps must be 1, got {inference_steps}.' + ) + super().__init__(*args, inference_steps=inference_steps, **kwargs) + self.inference_steps = 1 + + @staticmethod + def _broadcast_batch_time(value: torch.Tensor, reference: torch.Tensor) -> torch.Tensor: + while value.ndim < reference.ndim: + value = value.unsqueeze(-1) + return value + + @staticmethod + def _apply_conditioning( + trajectory: torch.Tensor, + condition_data: Optional[torch.Tensor] = None, + condition_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if condition_data is None or condition_mask is None: + return trajectory + conditioned = trajectory.clone() + conditioned[condition_mask] = condition_data[condition_mask] + return conditioned + + @staticmethod + def _jvp_math_sdp_context(z_t: torch.Tensor): + if z_t.is_cuda: + return torch.backends.cuda.sdp_kernel( + enable_flash=False, + enable_math=True, + enable_mem_efficient=False, + enable_cudnn=False, + ) + return nullcontext() + + @staticmethod + def _jvp_tangents(v: torch.Tensor, r: torch.Tensor, t: torch.Tensor): + return v.detach(), torch.zeros_like(r), torch.ones_like(t) + + def fn(self, z: torch.Tensor, r: torch.Tensor, t: torch.Tensor, cond=None) -> torch.Tensor: + return self.noise_pred_net(z, r, t, cond=cond) + + def _compute_u_and_du_dt( + self, + z_t: torch.Tensor, + r: torch.Tensor, + t: torch.Tensor, + cond, + v: torch.Tensor, + condition_data: Optional[torch.Tensor] = None, + condition_mask: Optional[torch.Tensor] = None, + ): + tangents = self._jvp_tangents(v, r, t) + + def g(z, r_value, t_value): + conditioned_z = self._apply_conditioning(z, condition_data, condition_mask) + return self.fn(conditioned_z, r_value, t_value, cond=cond) + + with self._jvp_math_sdp_context(z_t): + if TORCH_FUNC_JVP is not None: + try: + return TORCH_FUNC_JVP(g, (z_t, r, t), tangents) + except (RuntimeError, TypeError, NotImplementedError): + pass + + u = g(z_t, r, t) + _, du_dt = torch.autograd.functional.jvp( + g, + (z_t, r, t), + tangents, + create_graph=False, + strict=False, + ) + return u, du_dt + + def _compound_velocity( + self, + u: torch.Tensor, + du_dt: torch.Tensor, + r: torch.Tensor, + t: torch.Tensor, + ) -> torch.Tensor: + delta = self._broadcast_batch_time(t - r, u) + return u + delta * du_dt.detach() + + def _sample_one_step( + self, + z_t: 
torch.Tensor,
+        r: Optional[torch.Tensor] = None,
+        t: Optional[torch.Tensor] = None,
+        cond=None,
+    ) -> torch.Tensor:
+        batch_size = z_t.shape[0]
+        if t is None:
+            t = torch.ones(batch_size, device=z_t.device, dtype=z_t.dtype)
+        if r is None:
+            r = torch.zeros(batch_size, device=z_t.device, dtype=z_t.dtype)
+        u = self.fn(z_t, r, t, cond=cond)
+        delta = self._broadcast_batch_time(t - r, z_t)
+        return z_t - delta * u
+
+    def compute_loss(self, batch):
+        actions, states, images = batch['action'], batch['qpos'], batch['images']
+        action_is_pad = batch.get('action_is_pad', None)
+        batch_size = actions.shape[0]
+
+        states = self.normalization.normalize_qpos(states)
+        actions = self.normalization.normalize_action(actions)
+        cond = self._build_cond(images, states)
+
+        x = actions
+        e = torch.randn_like(x)
+        t = torch.rand(batch_size, device=x.device, dtype=x.dtype)
+        r = torch.rand(batch_size, device=x.device, dtype=x.dtype)
+        t, r = torch.maximum(t, r), torch.minimum(t, r)
+
+        t_broadcast = self._broadcast_batch_time(t, x)
+        z_t = (1 - t_broadcast) * x + t_broadcast * e
+
+        v = self.fn(z_t, t, t, cond=cond)
+        u, du_dt = self._compute_u_and_du_dt(z_t, r, t, cond=cond, v=v)
+        V = self._compound_velocity(u, du_dt, r, t)
+        target = e - x
+
+        loss = F.mse_loss(V, target, reduction='none')
+        if action_is_pad is not None:
+            mask = (~action_is_pad).unsqueeze(-1).to(loss.dtype)
+            valid_count = mask.sum() * loss.shape[-1]
+            loss = (loss * mask).sum() / valid_count.clamp_min(1.0)
+        else:
+            loss = loss.mean()
+        return loss
+
+    @torch.no_grad()
+    def predict_action(self, images, proprioception):
+        batch_size = proprioception.shape[0]
+        proprioception = self.normalization.normalize_qpos(proprioception)
+        cond = self._build_cond(images, proprioception)
+        z_t = torch.randn((batch_size, self.pred_horizon, self.action_dim), device=cond.device, dtype=cond.dtype)
+        action = self._sample_one_step(z_t, cond=cond)
+        return self.normalization.denormalize_action(action)
diff --git a/roboimi/vla/conf/agent/resnet_imf_attnres.yaml b/roboimi/vla/conf/agent/resnet_imf_attnres.yaml
new file mode 100644
index 0000000..e04bfb4
--- /dev/null
+++ b/roboimi/vla/conf/agent/resnet_imf_attnres.yaml
@@ -0,0 +1,40 @@
+# @package agent
+defaults:
+  - /backbone@vision_backbone: resnet_diffusion
+  - /modules@state_encoder: identity_state_encoder
+  - /modules@action_encoder: identity_action_encoder
+  - /head: imf_transformer1d
+  - _self_
+
+_target_: roboimi.vla.agent_imf.IMFVLAAgent
+
+action_dim: 16
+obs_dim: 16
+normalization_type: "min_max"
+pred_horizon: 16
+obs_horizon: 2
+num_action_steps: 8
+camera_names: ${data.camera_names}
+num_cams: 3
+
+vision_backbone:
+  num_cameras: ${agent.num_cams}
+  camera_names: ${agent.camera_names}
+
+diffusion_steps: 100
+inference_steps: 1
+head_type: "transformer"
+
+head:
+  input_dim: ${agent.action_dim}
+  output_dim: ${agent.action_dim}
+  horizon: ${agent.pred_horizon}
+  n_obs_steps: ${agent.obs_horizon}
+  cond_dim: 208
+  causal_attn: false
+  time_as_cond: true
+  obs_as_cond: true
+  n_cond_layers: 0
+  backbone_type: attnres_full
+  n_head: 1
+  n_kv_head: 1
diff --git a/roboimi/vla/conf/config.yaml b/roboimi/vla/conf/config.yaml
index 6eef43f..7f991e0 100644
--- a/roboimi/vla/conf/config.yaml
+++ b/roboimi/vla/conf/config.yaml
@@ -13,6 +13,7 @@ train:
   lr: 1e-4  # 学习率
   max_steps: 100000  # 最大训练步数
   device: "cuda"  # 设备: "cuda" 或 "cpu"
+  disable_cudnn: false  # set to true if this machine hits cuDNN compatibility issues
 
   # 数据加载
   num_workers: 12  # DataLoader 工作进程数(调试时设为 0)
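Reviewer note: written out, the objective in `compute_loss` above is the IMF compound-velocity regression. With clean actions x and noise e ~ N(0, I), the interpolant is z_t = (1 - t) * x + t * e; the head predicts u(z_t, r, t); du/dt comes from a JVP with tangent (v, 0, 1), where v = u(z_t, t, t); and V = u + (t - r) * du/dt (the JVP term detached) is regressed onto the target e - x. Inference is the single Euler step in `_sample_one_step`: x_hat = z_1 - (1 - 0) * u(z_1, r=0, t=1). A launch sketch, assuming this script is the repo's usual Hydra entry point:

    python roboimi/demos/vla_scripts/train_vla.py agent=resnet_imf_attnres train.disable_cudnn=true

diff --git a/roboimi/vla/conf/head/imf_transformer1d.yaml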
b/roboimi/vla/conf/head/imf_transformer1d.yaml new file mode 100644 index 0000000..92f7054 --- /dev/null +++ b/roboimi/vla/conf/head/imf_transformer1d.yaml @@ -0,0 +1,22 @@ +_target_: roboimi.vla.models.heads.imf_transformer1d.IMFTransformer1D +_partial_: true + +input_dim: ${agent.action_dim} +output_dim: ${agent.action_dim} +horizon: ${agent.pred_horizon} +n_obs_steps: ${agent.obs_horizon} +cond_dim: 208 +n_layer: 12 +n_head: 1 +n_emb: 768 +p_drop_emb: 0.1 +p_drop_attn: 0.1 +causal_attn: false +time_as_cond: true +obs_as_cond: true +n_cond_layers: 0 +backbone_type: attnres_full +n_kv_head: 1 +attn_res_ffn_mult: 2.667 +attn_res_eps: 1.0e-6 +attn_res_rope_theta: 10000.0 diff --git a/roboimi/vla/models/heads/attnres_transformer_components.py b/roboimi/vla/models/heads/attnres_transformer_components.py new file mode 100644 index 0000000..013eb60 --- /dev/null +++ b/roboimi/vla/models/heads/attnres_transformer_components.py @@ -0,0 +1,249 @@ +from __future__ import annotations + +import math +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor + + +class RMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-6) -> None: + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def forward(self, x: Tensor) -> Tensor: + rms = torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps) + return (x.float() * rms).to(x.dtype) * self.weight + + +class RMSNormNoWeight(nn.Module): + def __init__(self, eps: float = 1e-6) -> None: + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + rms = torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps) + return (x.float() * rms).to(x.dtype) + + +def precompute_rope_freqs( + dim: int, + max_seq_len: int, + theta: float = 10000.0, + device: Optional[torch.device] = None, +) -> Tensor: + if dim % 2 != 0: + raise ValueError(f'RoPE requires an even head dimension, got {dim}.') + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, device=device).float() / dim)) + positions = torch.arange(max_seq_len, device=device).float() + angles = torch.outer(positions, freqs) + return torch.polar(torch.ones_like(angles), angles) + + +def apply_rope(x: Tensor, freqs: Tensor) -> Tensor: + x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) + freqs = freqs.unsqueeze(0).unsqueeze(2) + x_rotated = x_complex * freqs + return torch.view_as_real(x_rotated).reshape_as(x).to(x.dtype) + + +class GroupedQuerySelfAttention(nn.Module): + def __init__( + self, + d_model: int, + n_heads: int, + n_kv_heads: int, + dropout: float = 0.0, + ) -> None: + super().__init__() + if d_model % n_heads != 0: + raise ValueError(f'd_model={d_model} must be divisible by n_heads={n_heads}.') + if n_heads % n_kv_heads != 0: + raise ValueError(f'n_heads={n_heads} must be divisible by n_kv_heads={n_kv_heads}.') + + self.d_model = d_model + self.n_heads = n_heads + self.n_kv_heads = n_kv_heads + self.n_kv_groups = n_heads // n_kv_heads + self.d_head = d_model // n_heads + self.attn_dropout = nn.Dropout(dropout) + self.out_dropout = nn.Dropout(dropout) + + self.w_q = nn.Linear(d_model, n_heads * self.d_head, bias=False) + self.w_k = nn.Linear(d_model, n_kv_heads * self.d_head, bias=False) + self.w_v = nn.Linear(d_model, n_kv_heads * self.d_head, bias=False) + self.w_o = nn.Linear(n_heads * self.d_head, d_model, bias=False) + + def forward( + self, + x: Tensor, + rope_freqs: Tensor, + mask: Optional[Tensor] = None, + ) -> Tensor: + 
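+        # Shape flow: q -> (B, S, H, Dh), k/v -> (B, S, Hkv, Dh); RoPE is applied
+        # per position, then each kv head is expanded across its n_kv_groups
+        # query heads before standard scaled dot-product attention.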
batch_size, seq_len, _ = x.shape + + q = self.w_q(x).view(batch_size, seq_len, self.n_heads, self.d_head) + k = self.w_k(x).view(batch_size, seq_len, self.n_kv_heads, self.d_head) + v = self.w_v(x).view(batch_size, seq_len, self.n_kv_heads, self.d_head) + + q = apply_rope(q, rope_freqs) + k = apply_rope(k, rope_freqs) + + if self.n_kv_heads != self.n_heads: + k = k.unsqueeze(3).expand( + batch_size, seq_len, self.n_kv_heads, self.n_kv_groups, self.d_head + ) + k = k.reshape(batch_size, seq_len, self.n_heads, self.d_head) + v = v.unsqueeze(3).expand( + batch_size, seq_len, self.n_kv_heads, self.n_kv_groups, self.d_head + ) + v = v.reshape(batch_size, seq_len, self.n_heads, self.d_head) + + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + scale = 1.0 / math.sqrt(self.d_head) + attn_weights = torch.matmul(q, k.transpose(-2, -1)) * scale + if mask is not None: + attn_weights = attn_weights + mask + attn_weights = F.softmax(attn_weights, dim=-1) + attn_weights = self.attn_dropout(attn_weights) + + out = torch.matmul(attn_weights, v) + out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model) + return self.out_dropout(self.w_o(out)) + + +class SwiGLUFFN(nn.Module): + def __init__(self, d_model: int, dropout: float = 0.0, mult: float = 2.667) -> None: + super().__init__() + raw = int(mult * d_model) + d_ff = ((raw + 7) // 8) * 8 + self.w_gate = nn.Linear(d_model, d_ff, bias=False) + self.w_up = nn.Linear(d_model, d_ff, bias=False) + self.w_down = nn.Linear(d_ff, d_model, bias=False) + self.dropout = nn.Dropout(dropout) + + def forward(self, x: Tensor) -> Tensor: + return self.dropout(self.w_down(F.silu(self.w_gate(x)) * self.w_up(x))) + + +class AttnResOperator(nn.Module): + def __init__(self, d_model: int, eps: float = 1e-6) -> None: + super().__init__() + self.pseudo_query = nn.Parameter(torch.zeros(d_model)) + self.key_norm = RMSNormNoWeight(eps=eps) + + def forward(self, sources: Tensor) -> Tensor: + keys = self.key_norm(sources) + logits = torch.einsum('d,nbtd->nbt', self.pseudo_query, keys) + weights = F.softmax(logits, dim=0) + return torch.einsum('nbt,nbtd->btd', weights, sources) + + +class AttnResSubLayer(nn.Module): + def __init__( + self, + d_model: int, + n_heads: int, + n_kv_heads: int, + dropout: float, + ffn_mult: float, + eps: float, + is_attention: bool, + ) -> None: + super().__init__() + self.norm = RMSNorm(d_model, eps=eps) + self.attn_res = AttnResOperator(d_model, eps=eps) + self.is_attention = is_attention + if self.is_attention: + self.fn = GroupedQuerySelfAttention( + d_model=d_model, + n_heads=n_heads, + n_kv_heads=n_kv_heads, + dropout=dropout, + ) + else: + self.fn = SwiGLUFFN(d_model=d_model, dropout=dropout, mult=ffn_mult) + + def forward(self, sources: Tensor, rope_freqs: Tensor, mask: Optional[Tensor] = None) -> Tensor: + h = self.attn_res(sources) + normed = self.norm(h) + if self.is_attention: + return self.fn(normed, rope_freqs, mask) + return self.fn(normed) + + +class AttnResTransformerBackbone(nn.Module): + def __init__( + self, + d_model: int, + n_blocks: int, + n_heads: int, + n_kv_heads: int, + max_seq_len: int, + dropout: float = 0.0, + ffn_mult: float = 2.667, + eps: float = 1e-6, + rope_theta: float = 10000.0, + causal_attn: bool = False, + ) -> None: + super().__init__() + self.causal_attn = causal_attn + self.layers = nn.ModuleList() + for _ in range(n_blocks): + self.layers.append( + AttnResSubLayer( + d_model=d_model, + n_heads=n_heads, + n_kv_heads=n_kv_heads, + dropout=dropout, + ffn_mult=ffn_mult, + 
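+                    # attention sub-layer for this block; its paired SwiGLU FFN
+                    # sub-layer (is_attention=False) is appended immediately below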
eps=eps, + is_attention=True, + ) + ) + self.layers.append( + AttnResSubLayer( + d_model=d_model, + n_heads=n_heads, + n_kv_heads=n_kv_heads, + dropout=dropout, + ffn_mult=ffn_mult, + eps=eps, + is_attention=False, + ) + ) + + rope_freqs = precompute_rope_freqs( + dim=d_model // n_heads, + max_seq_len=max_seq_len, + theta=rope_theta, + ) + self.register_buffer('rope_freqs', rope_freqs, persistent=False) + + @staticmethod + def _build_causal_mask(seq_len: int, device: torch.device) -> Tensor: + mask = torch.full((seq_len, seq_len), float('-inf'), device=device) + mask = torch.triu(mask, diagonal=1) + return mask.unsqueeze(0).unsqueeze(0) + + def forward(self, x: Tensor) -> Tensor: + seq_len = x.shape[1] + rope_freqs = self.rope_freqs[:seq_len] + mask = None + if self.causal_attn: + mask = self._build_causal_mask(seq_len, x.device) + + layer_outputs = [x] + for layer in self.layers: + sources = torch.stack(layer_outputs, dim=0) + output = layer(sources, rope_freqs, mask) + layer_outputs.append(output) + + return torch.stack(layer_outputs, dim=0).sum(dim=0) diff --git a/roboimi/vla/models/heads/imf_transformer1d.py b/roboimi/vla/models/heads/imf_transformer1d.py new file mode 100644 index 0000000..ae605c1 --- /dev/null +++ b/roboimi/vla/models/heads/imf_transformer1d.py @@ -0,0 +1,379 @@ +"""Local IMF-AttnRes transformer head aligned with diffusion_policy@185ed659.""" + +from __future__ import annotations + +import logging +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn + +from .attnres_transformer_components import ( + AttnResOperator, + AttnResSubLayer, + AttnResTransformerBackbone, + GroupedQuerySelfAttention, + RMSNorm, + RMSNormNoWeight, + SwiGLUFFN, +) +from .transformer1d import ModuleAttrMixin, SinusoidalPosEmb + +logger = logging.getLogger(__name__) + + +class IMFTransformer1D(ModuleAttrMixin): + def __init__( + self, + input_dim: int, + output_dim: int, + horizon: int, + n_obs_steps: Optional[int] = None, + cond_dim: int = 0, + n_layer: int = 12, + n_head: int = 1, + n_emb: int = 768, + p_drop_emb: float = 0.1, + p_drop_attn: float = 0.1, + causal_attn: bool = False, + time_as_cond: bool = True, + obs_as_cond: bool = False, + n_cond_layers: int = 0, + backbone_type: str = 'attnres_full', + n_kv_head: int = 1, + attn_res_ffn_mult: float = 2.667, + attn_res_eps: float = 1e-6, + attn_res_rope_theta: float = 10000.0, + ) -> None: + super().__init__() + + if n_head != 1: + raise AssertionError('IMFTransformer1D currently supports single-head attention only.') + if n_obs_steps is None: + n_obs_steps = horizon + + self.backbone_type = backbone_type + + T = horizon + T_cond = 2 + if not time_as_cond: + T += 2 + T_cond -= 2 + obs_as_cond = cond_dim > 0 + if obs_as_cond: + assert time_as_cond + T_cond += n_obs_steps + + self.input_emb = nn.Linear(input_dim, n_emb) + self.drop = nn.Dropout(p_drop_emb) + self.time_emb = SinusoidalPosEmb(n_emb) + self.cond_obs_emb = nn.Linear(cond_dim, n_emb) if obs_as_cond else None + self.time_token_proj = None + self.cond_pos_emb = None + self.pos_emb = None + self.encoder = None + self.decoder = None + self.attnres_backbone = None + encoder_only = False + + if backbone_type == 'attnres_full': + if not time_as_cond: + raise ValueError('attnres_full backbone requires time_as_cond=True.') + if n_cond_layers != 0: + raise ValueError('attnres_full backbone does not support n_cond_layers > 0.') + + self.time_token_proj = nn.Linear(n_emb, n_emb) + self.attnres_backbone = AttnResTransformerBackbone( + d_model=n_emb, + 
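+                # hyper-connection backbone: every sub-layer's AttnRes operator
+                # softmax-mixes the outputs of all preceding sub-layers before
+                # applying attention or the SwiGLU FFN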
n_blocks=n_layer, + n_heads=n_head, + n_kv_heads=n_kv_head, + max_seq_len=T + T_cond, + dropout=p_drop_attn, + ffn_mult=attn_res_ffn_mult, + eps=attn_res_eps, + rope_theta=attn_res_rope_theta, + causal_attn=causal_attn, + ) + self.ln_f = RMSNorm(n_emb, eps=attn_res_eps) + else: + self.pos_emb = nn.Parameter(torch.zeros(1, T, n_emb)) + if T_cond > 0: + self.cond_pos_emb = nn.Parameter(torch.zeros(1, T_cond, n_emb)) + if n_cond_layers > 0: + encoder_layer = nn.TransformerEncoderLayer( + d_model=n_emb, + nhead=n_head, + dim_feedforward=4 * n_emb, + dropout=p_drop_attn, + activation='gelu', + batch_first=True, + norm_first=True, + ) + self.encoder = nn.TransformerEncoder( + encoder_layer=encoder_layer, + num_layers=n_cond_layers, + ) + else: + self.encoder = nn.Sequential( + nn.Linear(n_emb, 4 * n_emb), + nn.Mish(), + nn.Linear(4 * n_emb, n_emb), + ) + + decoder_layer = nn.TransformerDecoderLayer( + d_model=n_emb, + nhead=n_head, + dim_feedforward=4 * n_emb, + dropout=p_drop_attn, + activation='gelu', + batch_first=True, + norm_first=True, + ) + self.decoder = nn.TransformerDecoder( + decoder_layer=decoder_layer, + num_layers=n_layer, + ) + else: + encoder_only = True + encoder_layer = nn.TransformerEncoderLayer( + d_model=n_emb, + nhead=n_head, + dim_feedforward=4 * n_emb, + dropout=p_drop_attn, + activation='gelu', + batch_first=True, + norm_first=True, + ) + self.encoder = nn.TransformerEncoder( + encoder_layer=encoder_layer, + num_layers=n_layer, + ) + + self.ln_f = nn.LayerNorm(n_emb) + + if causal_attn and backbone_type != 'attnres_full': + sz = T + mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) + mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0)) + self.register_buffer('mask', mask) + + if time_as_cond and obs_as_cond: + S = T_cond + t_idx, s_idx = torch.meshgrid( + torch.arange(T), + torch.arange(S), + indexing='ij', + ) + mask = t_idx >= (s_idx - 2) + mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0)) + self.register_buffer('memory_mask', mask) + else: + self.memory_mask = None + else: + self.mask = None + self.memory_mask = None + + self.head = nn.Linear(n_emb, output_dim) + + self.T = T + self.T_cond = T_cond + self.horizon = horizon + self.time_as_cond = time_as_cond + self.obs_as_cond = obs_as_cond + self.encoder_only = encoder_only + + self.apply(self._init_weights) + logger.info('number of parameters: %e', sum(p.numel() for p in self.parameters())) + + def _init_weights(self, module): + ignore_types = ( + nn.Dropout, + SinusoidalPosEmb, + nn.TransformerEncoderLayer, + nn.TransformerDecoderLayer, + nn.TransformerEncoder, + nn.TransformerDecoder, + nn.ModuleList, + nn.Mish, + nn.Sequential, + AttnResTransformerBackbone, + AttnResSubLayer, + GroupedQuerySelfAttention, + SwiGLUFFN, + RMSNormNoWeight, + ) + if isinstance(module, (nn.Linear, nn.Embedding)): + torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) + if isinstance(module, nn.Linear) and module.bias is not None: + torch.nn.init.zeros_(module.bias) + elif isinstance(module, nn.MultiheadAttention): + for name in ('in_proj_weight', 'q_proj_weight', 'k_proj_weight', 'v_proj_weight'): + weight = getattr(module, name) + if weight is not None: + torch.nn.init.normal_(weight, mean=0.0, std=0.02) + + for name in ('in_proj_bias', 'bias_k', 'bias_v'): + bias = getattr(module, name) + if bias is not None: + torch.nn.init.zeros_(bias) + elif isinstance(module, (nn.LayerNorm, RMSNorm)): + if getattr(module, 'bias', None) is not None: + 
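+                # RMSNorm registers only `weight`; the getattr guard keeps this
+                # branch valid for both LayerNorm (has bias) and RMSNorm (no bias)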
torch.nn.init.zeros_(module.bias) + torch.nn.init.ones_(module.weight) + elif isinstance(module, AttnResOperator): + torch.nn.init.zeros_(module.pseudo_query) + elif isinstance(module, IMFTransformer1D): + if module.pos_emb is not None: + torch.nn.init.normal_(module.pos_emb, mean=0.0, std=0.02) + if module.cond_pos_emb is not None: + torch.nn.init.normal_(module.cond_pos_emb, mean=0.0, std=0.02) + elif isinstance(module, ignore_types): + pass + else: + raise RuntimeError(f'Unaccounted module {module}') + + def get_optim_groups(self, weight_decay: float = 1e-3): + decay = set() + no_decay = set() + whitelist_weight_modules = (torch.nn.Linear, torch.nn.MultiheadAttention) + blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding, RMSNorm) + for mn, m in self.named_modules(): + for pn, _ in m.named_parameters(recurse=False): + fpn = f'{mn}.{pn}' if mn else pn + + if pn.endswith('bias'): + no_decay.add(fpn) + elif pn.startswith('bias'): + no_decay.add(fpn) + elif pn == 'pseudo_query': + no_decay.add(fpn) + elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules): + decay.add(fpn) + elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules): + no_decay.add(fpn) + + if self.pos_emb is not None: + no_decay.add('pos_emb') + no_decay.add('_dummy_variable') + if self.cond_pos_emb is not None: + no_decay.add('cond_pos_emb') + + param_dict = {pn: p for pn, p in self.named_parameters()} + inter_params = decay & no_decay + union_params = decay | no_decay + assert len(inter_params) == 0, f'parameters {inter_params} made it into both decay/no_decay sets!' + assert len(param_dict.keys() - union_params) == 0, ( + f'parameters {param_dict.keys() - union_params} were not separated into either decay/no_decay sets!' + ) + + return [ + { + 'params': [param_dict[pn] for pn in sorted(list(decay))], + 'weight_decay': weight_decay, + }, + { + 'params': [param_dict[pn] for pn in sorted(list(no_decay))], + 'weight_decay': 0.0, + }, + ] + + def configure_optimizers( + self, + learning_rate: float = 1e-4, + weight_decay: float = 1e-3, + betas: Tuple[float, float] = (0.9, 0.95), + ): + optim_groups = self.get_optim_groups(weight_decay=weight_decay) + return torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas) + + def _prepare_time_input(self, value: Union[torch.Tensor, float, int], sample: torch.Tensor) -> torch.Tensor: + if not torch.is_tensor(value): + value = torch.tensor([value], dtype=sample.dtype, device=sample.device) + elif value.ndim == 0: + value = value[None].to(device=sample.device, dtype=sample.dtype) + else: + value = value.to(device=sample.device, dtype=sample.dtype) + return value.expand(sample.shape[0]) + + def _forward_attnres_full( + self, + sample: torch.Tensor, + r: torch.Tensor, + t: torch.Tensor, + cond: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + sample_tokens = self.input_emb(sample) + token_parts = [ + self.time_token_proj(self.time_emb(r)).unsqueeze(1), + self.time_token_proj(self.time_emb(t)).unsqueeze(1), + ] + if self.obs_as_cond: + if cond is None: + raise ValueError('cond is required when obs_as_cond=True for attnres_full backbone.') + token_parts.append(self.cond_obs_emb(cond)) + token_parts.append(sample_tokens) + x = torch.cat(token_parts, dim=1) + x = self.drop(x) + x = self.attnres_backbone(x) + x = x[:, -sample_tokens.shape[1]:, :] + return x + + def _forward_vanilla( + self, + sample: torch.Tensor, + r: torch.Tensor, + t: torch.Tensor, + cond: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + r_emb = 
self.time_emb(r).unsqueeze(1) + t_emb = self.time_emb(t).unsqueeze(1) + input_emb = self.input_emb(sample) + + if self.encoder_only: + token_embeddings = torch.cat([r_emb, t_emb, input_emb], dim=1) + token_count = token_embeddings.shape[1] + position_embeddings = self.pos_emb[:, :token_count, :] + x = self.drop(token_embeddings + position_embeddings) + x = self.encoder(src=x, mask=self.mask) + x = x[:, 2:, :] + else: + cond_embeddings = torch.cat([r_emb, t_emb], dim=1) + if self.obs_as_cond: + cond_embeddings = torch.cat([cond_embeddings, self.cond_obs_emb(cond)], dim=1) + token_count = cond_embeddings.shape[1] + position_embeddings = self.cond_pos_emb[:, :token_count, :] + x = self.drop(cond_embeddings + position_embeddings) + x = self.encoder(x) + memory = x + + token_embeddings = input_emb + token_count = token_embeddings.shape[1] + position_embeddings = self.pos_emb[:, :token_count, :] + x = self.drop(token_embeddings + position_embeddings) + x = self.decoder( + tgt=x, + memory=memory, + tgt_mask=self.mask, + memory_mask=self.memory_mask, + ) + return x + + def forward( + self, + sample: torch.Tensor, + r: Union[torch.Tensor, float, int], + t: Union[torch.Tensor, float, int], + cond: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + r = self._prepare_time_input(r, sample) + t = self._prepare_time_input(t, sample) + + if self.backbone_type == 'attnres_full': + x = self._forward_attnres_full(sample, r, t, cond=cond) + else: + x = self._forward_vanilla(sample, r, t, cond=cond) + + x = self.ln_f(x) + x = self.head(x) + return x diff --git a/tests/test_imf_transformer1d_external_alignment.py b/tests/test_imf_transformer1d_external_alignment.py new file mode 100644 index 0000000..c3dff34 --- /dev/null +++ b/tests/test_imf_transformer1d_external_alignment.py @@ -0,0 +1,196 @@ +import contextlib +import importlib +import inspect +import subprocess +import sys +import types +import unittest +from pathlib import Path + +import torch + + +_REPO_ROOT = Path(__file__).resolve().parents[1] +if str(_REPO_ROOT) not in sys.path: + sys.path.insert(0, str(_REPO_ROOT)) + +_EXTERNAL_COMMIT = '185ed659' +_LOCAL_MODULE_NAME = 'roboimi.vla.models.heads.imf_transformer1d' +_MISSING = object() + + +def _find_external_checkout_root() -> Path | None: + for ancestor in (_REPO_ROOT, *_REPO_ROOT.parents): + candidate = ancestor / 'diffusion_policy' + if (candidate / '.git').exists(): + return candidate + return None + + +_EXTERNAL_CHECKOUT_ROOT = _find_external_checkout_root() +_EXTERNAL_MODULE_PATHS = { + 'diffusion_policy.model.common.module_attr_mixin': 'diffusion_policy/model/common/module_attr_mixin.py', + 'diffusion_policy.model.diffusion.positional_embedding': 'diffusion_policy/model/diffusion/positional_embedding.py', + 'diffusion_policy.model.diffusion.attnres_transformer_components': 'diffusion_policy/model/diffusion/attnres_transformer_components.py', + 'diffusion_policy.model.diffusion.imf_transformer_for_diffusion': 'diffusion_policy/model/diffusion/imf_transformer_for_diffusion.py', +} + + +@contextlib.contextmanager +def _temporary_registered_modules(): + previous_modules = {} + + def remember(name: str) -> None: + if name not in previous_modules: + previous_modules[name] = sys.modules.get(name, _MISSING) + + def ensure_package(name: str) -> None: + if not name or name in sys.modules: + return + remember(name) + package = types.ModuleType(name) + package.__path__ = [] + sys.modules[name] = package + + def load(name: str, source: str, origin: str): + package_parts = name.split('.')[:-1] + for idx 
in range(1, len(package_parts) + 1): + ensure_package('.'.join(package_parts[:idx])) + + remember(name) + module = types.ModuleType(name) + module.__file__ = origin + module.__package__ = name.rpartition('.')[0] + sys.modules[name] = module + exec(compile(source, origin, 'exec'), module.__dict__) + return module + + try: + yield load + finally: + for name, previous in reversed(list(previous_modules.items())): + if previous is _MISSING: + sys.modules.pop(name, None) + else: + sys.modules[name] = previous + + +def _git_show(repo_root: Path, commit: str, relative_path: str) -> str: + result = subprocess.run( + ['git', '-C', str(repo_root), 'show', f'{commit}:{relative_path}'], + check=True, + capture_output=True, + text=True, + ) + return result.stdout + + +@contextlib.contextmanager +def _load_external_module_or_skip(test_case: unittest.TestCase): + if _EXTERNAL_CHECKOUT_ROOT is None: + test_case.skipTest('external diffusion_policy checkout unavailable') + + try: + sources = { + name: _git_show(_EXTERNAL_CHECKOUT_ROOT, _EXTERNAL_COMMIT, relative_path) + for name, relative_path in _EXTERNAL_MODULE_PATHS.items() + } + except subprocess.CalledProcessError as exc: + test_case.skipTest( + f'external diffusion_policy commit {_EXTERNAL_COMMIT} is unavailable: {exc.stderr.strip() or exc}' + ) + + with _temporary_registered_modules() as load_external: + for name, relative_path in _EXTERNAL_MODULE_PATHS.items(): + load_external( + name, + sources[name], + origin=f'{_EXTERNAL_CHECKOUT_ROOT}:{_EXTERNAL_COMMIT}:{relative_path}', + ) + yield sys.modules['diffusion_policy.model.diffusion.imf_transformer_for_diffusion'] + + +def _load_local_module(): + importlib.invalidate_caches() + sys.modules.pop(_LOCAL_MODULE_NAME, None) + return importlib.import_module(_LOCAL_MODULE_NAME) + + +class IMFTransformer1DExternalAlignmentTest(unittest.TestCase): + def _optim_group_names(self, model, groups): + names_by_param = {id(param): name for name, param in model.named_parameters()} + return [ + {names_by_param[id(param)] for param in group['params']} + for group in groups + ] + + def test_local_defaults_preserve_supported_attnres_config(self): + local_module = _load_local_module() + ctor = inspect.signature(local_module.IMFTransformer1D.__init__).parameters + + self.assertEqual(ctor['backbone_type'].default, 'attnres_full') + self.assertEqual(ctor['n_head'].default, 1) + self.assertEqual(ctor['n_kv_head'].default, 1) + self.assertEqual(ctor['n_cond_layers'].default, 0) + self.assertTrue(ctor['time_as_cond'].default) + self.assertFalse(ctor['causal_attn'].default) + + def test_attnres_full_state_dict_forward_and_optim_groups_match_external(self): + local_module = _load_local_module() + with _load_external_module_or_skip(self) as external_module: + config = dict( + input_dim=4, + output_dim=4, + horizon=6, + n_obs_steps=3, + cond_dim=5, + n_layer=2, + n_head=1, + n_emb=16, + p_drop_emb=0.0, + p_drop_attn=0.0, + causal_attn=False, + time_as_cond=True, + n_cond_layers=0, + backbone_type='attnres_full', + n_kv_head=1, + ) + + torch.manual_seed(7) + external_model = external_module.IMFTransformerForDiffusion(**config) + local_model = local_module.IMFTransformer1D(**config) + external_model.eval() + local_model.eval() + + external_state_dict = external_model.state_dict() + self.assertEqual(set(local_model.state_dict().keys()), set(external_state_dict.keys())) + local_model.load_state_dict(external_state_dict, strict=True) + + batch_size = 2 + sample = torch.randn(batch_size, config['horizon'], config['input_dim']) + r = 
torch.tensor([0.1, 0.4], dtype=torch.float32) + t = torch.tensor([0.7, 0.9], dtype=torch.float32) + cond = torch.randn(batch_size, config['n_obs_steps'], config['cond_dim']) + + with torch.no_grad(): + external_out = external_model(sample=sample, r=r, t=t, cond=cond) + local_out = local_model(sample=sample, r=r, t=t, cond=cond) + + self.assertEqual(local_out.shape, (batch_size, config['horizon'], config['output_dim'])) + self.assertEqual(local_out.shape, external_out.shape) + self.assertTrue(torch.allclose(local_out, external_out, atol=1e-6, rtol=1e-5)) + + weight_decay = 0.123 + external_groups = external_model.get_optim_groups(weight_decay=weight_decay) + local_groups = local_model.get_optim_groups(weight_decay=weight_decay) + + self.assertEqual(len(local_groups), len(external_groups)) + self.assertEqual([group['weight_decay'] for group in local_groups], [weight_decay, 0.0]) + self.assertEqual( + self._optim_group_names(local_model, local_groups), + self._optim_group_names(external_model, external_groups), + ) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_imf_vla_agent.py b/tests/test_imf_vla_agent.py new file mode 100644 index 0000000..0050c9c --- /dev/null +++ b/tests/test_imf_vla_agent.py @@ -0,0 +1,427 @@ +import contextlib +import importlib +import sys +import types +import unittest +from pathlib import Path +from unittest import mock + +import torch +from hydra import compose, initialize_config_dir +from hydra.core.global_hydra import GlobalHydra +from hydra.utils import instantiate +from omegaconf import OmegaConf +from torch import nn + + +_REPO_ROOT = Path(__file__).resolve().parents[1] +_CONFIG_DIR = str((_REPO_ROOT / 'roboimi/vla/conf').resolve()) +_MISSING = object() +_CAMERA_NAMES = ('r_vis', 'top', 'front') + + +class _FakeScheduler: + def __init__(self, num_train_timesteps=100, **kwargs): + self.config = types.SimpleNamespace(num_train_timesteps=num_train_timesteps) + self.timesteps = [] + + def add_noise(self, sample, noise, timestep): + return sample + noise + + def set_timesteps(self, num_inference_steps): + self.timesteps = list(range(num_inference_steps - 1, -1, -1)) + + def step(self, noise_pred, timestep, sample): + return types.SimpleNamespace(prev_sample=sample) + + +class _IdentityCrop: + def __init__(self, size): + self.size = size + + def __call__(self, x): + return x + + +class _FakeResNet(nn.Module): + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(3, 8, kernel_size=3, padding=1) + self.relu1 = nn.ReLU() + self.conv2 = nn.Conv2d(8, 16, kernel_size=3, padding=1, stride=2) + self.relu2 = nn.ReLU() + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.fc = nn.Linear(16, 16) + + def forward(self, x): + x = self.relu1(self.conv1(x)) + x = self.relu2(self.conv2(x)) + x = self.avgpool(x) + x = torch.flatten(x, start_dim=1) + return self.fc(x) + + +class _FakeRearrange(nn.Module): + def __init__(self, *args, **kwargs): + super().__init__() + + def forward(self, x): + return x + + +class _StubIMFHead(nn.Module): + def __init__( + self, + input_dim, + output_dim, + horizon, + n_obs_steps, + cond_dim, + **kwargs, + ): + super().__init__() + self.constructor_kwargs = { + 'input_dim': input_dim, + 'output_dim': output_dim, + 'horizon': horizon, + 'n_obs_steps': n_obs_steps, + 'cond_dim': cond_dim, + **kwargs, + } + self.proj = nn.Linear(input_dim, output_dim) + self.cond_obs_emb = nn.Linear(cond_dim, max(cond_dim, 1)) + + def forward(self, sample, r, t, cond=None): + return torch.zeros_like(sample) + + def 
get_optim_groups(self, weight_decay): + return [ + {'params': [self.proj.weight], 'weight_decay': weight_decay}, + {'params': [self.proj.bias, self.cond_obs_emb.weight, self.cond_obs_emb.bias], 'weight_decay': 0.0}, + ] + + +@contextlib.contextmanager +def _stub_optional_modules(include_imf_head=False): + previous_modules = {} + + def inject(name, module): + if name not in previous_modules: + previous_modules[name] = sys.modules.get(name, _MISSING) + sys.modules[name] = module + + diffusers_module = types.ModuleType('diffusers') + schedulers_module = types.ModuleType('diffusers.schedulers') + ddpm_module = types.ModuleType('diffusers.schedulers.scheduling_ddpm') + ddim_module = types.ModuleType('diffusers.schedulers.scheduling_ddim') + ddpm_module.DDPMScheduler = _FakeScheduler + ddim_module.DDIMScheduler = _FakeScheduler + diffusers_module.DDPMScheduler = _FakeScheduler + diffusers_module.DDIMScheduler = _FakeScheduler + diffusers_module.schedulers = schedulers_module + schedulers_module.scheduling_ddpm = ddpm_module + schedulers_module.scheduling_ddim = ddim_module + + torchvision_module = types.ModuleType('torchvision') + models_module = types.ModuleType('torchvision.models') + transforms_module = types.ModuleType('torchvision.transforms') + models_module.resnet18 = lambda weights=None: _FakeResNet() + transforms_module.CenterCrop = _IdentityCrop + transforms_module.RandomCrop = _IdentityCrop + torchvision_module.models = models_module + torchvision_module.transforms = transforms_module + + einops_module = types.ModuleType('einops') + einops_module.rearrange = lambda x, *args, **kwargs: x + einops_layers_module = types.ModuleType('einops.layers') + einops_layers_torch_module = types.ModuleType('einops.layers.torch') + einops_layers_torch_module.Rearrange = _FakeRearrange + einops_module.layers = einops_layers_module + einops_layers_module.torch = einops_layers_torch_module + + try: + inject('diffusers', diffusers_module) + inject('diffusers.schedulers', schedulers_module) + inject('diffusers.schedulers.scheduling_ddpm', ddpm_module) + inject('diffusers.schedulers.scheduling_ddim', ddim_module) + inject('torchvision', torchvision_module) + inject('torchvision.models', models_module) + inject('torchvision.transforms', transforms_module) + inject('einops', einops_module) + inject('einops.layers', einops_layers_module) + inject('einops.layers.torch', einops_layers_torch_module) + + if include_imf_head: + import roboimi.vla.models.heads as heads_package + + imf_head_module = types.ModuleType('roboimi.vla.models.heads.imf_transformer1d') + imf_head_module.IMFTransformer1D = _StubIMFHead + inject('roboimi.vla.models.heads.imf_transformer1d', imf_head_module) + setattr(heads_package, 'imf_transformer1d', imf_head_module) + + yield + finally: + for name, previous in reversed(list(previous_modules.items())): + if previous is _MISSING: + sys.modules.pop(name, None) + else: + sys.modules[name] = previous + + +def _compose_cfg(overrides=None): + if not OmegaConf.has_resolver('len'): + OmegaConf.register_new_resolver('len', lambda x: len(x)) + + GlobalHydra.instance().clear() + with initialize_config_dir(version_base=None, config_dir=_CONFIG_DIR): + return compose(config_name='config', overrides=list(overrides or [])) + + +def _load_imf_agent_class(): + with _stub_optional_modules(): + sys.modules.pop('roboimi.vla.agent_imf', None) + module = importlib.import_module('roboimi.vla.agent_imf') + return module.IMFVLAAgent, module + + +class _StubVisionBackbone(nn.Module): + output_dim = 1 + + def 
__init__(self, camera_names=_CAMERA_NAMES): + super().__init__() + self.camera_names = tuple(camera_names) + self.num_cameras = len(self.camera_names) + + def forward(self, images): + per_camera_features = [] + for camera_name in self.camera_names: + image_batch = images[camera_name] + per_camera_features.append(image_batch.mean(dim=(2, 3, 4), keepdim=False).unsqueeze(-1)) + return torch.cat(per_camera_features, dim=-1) + + +class _RecordingLinearIMFHead(nn.Module): + def __init__(self): + super().__init__() + self.scale = nn.Parameter(torch.tensor(0.5)) + self.calls = [] + + @staticmethod + def _broadcast_batch_time(value, reference): + while value.ndim < reference.ndim: + value = value.unsqueeze(-1) + return value + + def forward(self, sample, r, t, cond=None): + record = { + 'sample': sample.detach().clone(), + 'r': r.detach().clone(), + 't': t.detach().clone(), + 'cond': None if cond is None else cond.detach().clone(), + } + self.calls.append(record) + cond_term = 0.0 + if cond is not None: + cond_term = cond.mean(dim=(1, 2), keepdim=True) + r_b = self._broadcast_batch_time(r, sample) + t_b = self._broadcast_batch_time(t, sample) + return self.scale * sample + r_b + 2.0 * t_b + cond_term + + +class _ForbiddenScheduler: + def set_timesteps(self, *args, **kwargs): # pragma: no cover - only runs on regression + raise AssertionError('IMF inference should not use DDIM scheduler set_timesteps') + + def step(self, *args, **kwargs): # pragma: no cover - only runs on regression + raise AssertionError('IMF inference should not use DDIM scheduler step') + + +def _make_images(batch_size, obs_horizon, per_camera_fill): + return { + name: torch.full((batch_size, obs_horizon, 1, 2, 2), fill_value=value, dtype=torch.float32) + for name, value in per_camera_fill.items() + } + + +class IMFVLAAgentTest(unittest.TestCase): + def _make_agent(self, pred_horizon=3, obs_horizon=2, num_action_steps=2): + agent_cls, agent_module = _load_imf_agent_class() + head = _RecordingLinearIMFHead() + agent = agent_cls( + vision_backbone=_StubVisionBackbone(), + state_encoder=nn.Identity(), + action_encoder=nn.Identity(), + head=head, + action_dim=2, + obs_dim=1, + pred_horizon=pred_horizon, + obs_horizon=obs_horizon, + diffusion_steps=10, + inference_steps=1, + num_cams=len(_CAMERA_NAMES), + camera_names=_CAMERA_NAMES, + num_action_steps=num_action_steps, + head_type='transformer', + ) + return agent, head, agent_module + + def test_compute_loss_matches_imf_objective_and_masks_padded_actions(self): + agent, head, agent_module = self._make_agent(pred_horizon=3, obs_horizon=2) + images = _make_images( + batch_size=1, + obs_horizon=2, + per_camera_fill={'r_vis': 1.0, 'top': 2.0, 'front': 3.0}, + ) + qpos = torch.tensor([[[0.25], [0.75]]], dtype=torch.float32) + actions = torch.tensor( + [[[1.0, -1.0], [0.5, 0.25], [-0.5, 1.5]]], + dtype=torch.float32, + ) + action_is_pad = torch.tensor([[False, False, True]]) + noise = torch.tensor( + [[[0.2, -0.4], [0.1, 0.3], [0.5, -0.2]]], + dtype=torch.float32, + ) + t_sample = torch.tensor([0.8], dtype=torch.float32) + r_sample = torch.tensor([0.25], dtype=torch.float32) + + with mock.patch.object(agent_module.torch, 'randn_like', return_value=noise), \ + mock.patch.object(agent_module.torch, 'rand', side_effect=[t_sample, r_sample]): + loss = agent.compute_loss( + { + 'images': images, + 'qpos': qpos, + 'action': actions, + 'action_is_pad': action_is_pad, + } + ) + + cond = torch.tensor([[[1.0, 2.0, 3.0, 0.25], [1.0, 2.0, 3.0, 0.75]]], dtype=torch.float32) + cond_term = 
cond.mean(dim=(1, 2), keepdim=True) + t = t_sample + r = r_sample + z_t = (1 - t.view(1, 1, 1)) * actions + t.view(1, 1, 1) * noise + scale = head.scale.detach() + u = scale * z_t + r.view(1, 1, 1) + 2.0 * t.view(1, 1, 1) + cond_term + v = scale * z_t + 3.0 * t.view(1, 1, 1) + cond_term + du_dt = scale * v + 2.0 + compound_velocity = u + (t - r).view(1, 1, 1) * du_dt + target = noise - actions + elementwise_loss = (compound_velocity - target) ** 2 + mask = (~action_is_pad).unsqueeze(-1).to(elementwise_loss.dtype) + expected_loss = (elementwise_loss * mask).sum() / (mask.sum() * elementwise_loss.shape[-1]) + + self.assertAlmostEqual(loss.item(), expected_loss.item(), places=6) + self.assertEqual(len(head.calls), 2) + self.assertTrue(torch.allclose(head.calls[0]['r'], t_sample)) + self.assertTrue(torch.allclose(head.calls[0]['t'], t_sample)) + self.assertTrue(torch.allclose(head.calls[0]['cond'], cond)) + + def test_predict_action_uses_one_step_imf_sampling_and_image_conditioning(self): + agent, head, agent_module = self._make_agent(pred_horizon=3, obs_horizon=2) + agent.infer_scheduler = _ForbiddenScheduler() + + images = _make_images( + batch_size=2, + obs_horizon=2, + per_camera_fill={'r_vis': 10.0, 'top': 20.0, 'front': 30.0}, + ) + qpos = torch.tensor( + [ + [[1.0], [2.0]], + [[3.0], [4.0]], + ], + dtype=torch.float32, + ) + initial_noise = torch.tensor( + [ + [[1.0, -1.0], [0.0, 2.0], [3.0, -2.0]], + [[-1.0, 1.0], [2.0, -3.0], [0.5, 0.25]], + ], + dtype=torch.float32, + ) + + with mock.patch.object(agent_module.torch, 'randn', return_value=initial_noise): + predicted_actions = agent.predict_action(images, qpos) + + expected_cond = torch.tensor( + [ + [[10.0, 20.0, 30.0, 1.0], [10.0, 20.0, 30.0, 2.0]], + [[10.0, 20.0, 30.0, 3.0], [10.0, 20.0, 30.0, 4.0]], + ], + dtype=torch.float32, + ) + cond_term = expected_cond.mean(dim=(1, 2), keepdim=True) + expected_actions = 0.5 * initial_noise - 2.0 - cond_term + + self.assertEqual(predicted_actions.shape, (2, 3, 2)) + self.assertTrue(torch.allclose(predicted_actions, expected_actions)) + self.assertEqual(len(head.calls), 1) + self.assertTrue(torch.allclose(head.calls[0]['r'], torch.zeros(2))) + self.assertTrue(torch.allclose(head.calls[0]['t'], torch.ones(2))) + self.assertTrue(torch.allclose(head.calls[0]['cond'], expected_cond)) + + def test_select_action_only_regenerates_when_action_queue_is_empty(self): + agent, _head, _agent_module = self._make_agent(pred_horizon=4, obs_horizon=2, num_action_steps=2) + observation = { + 'qpos': torch.tensor([0.25], dtype=torch.float32), + 'images': { + 'front': torch.full((1, 2, 2), 3.0, dtype=torch.float32), + 'top': torch.full((1, 2, 2), 2.0, dtype=torch.float32), + 'r_vis': torch.full((1, 2, 2), 1.0, dtype=torch.float32), + }, + } + first_chunk = torch.tensor( + [[[10.0, 11.0], [12.0, 13.0], [14.0, 15.0], [16.0, 17.0]]], + dtype=torch.float32, + ) + second_chunk = torch.tensor( + [[[20.0, 21.0], [22.0, 23.0], [24.0, 25.0], [26.0, 27.0]]], + dtype=torch.float32, + ) + + with mock.patch.object(agent, 'predict_action_chunk', side_effect=[first_chunk, second_chunk]) as mock_predict_chunk: + first_action = agent.select_action(observation) + second_action = agent.select_action(observation) + third_action = agent.select_action(observation) + + self.assertTrue(torch.equal(first_action, first_chunk[0, 1])) + self.assertTrue(torch.equal(second_action, first_chunk[0, 2])) + self.assertTrue(torch.equal(third_action, second_chunk[0, 1])) + self.assertEqual(mock_predict_chunk.call_count, 2) + + def 
test_hydra_config_instantiates_resnet_imf_attnres_with_stub_head(self): + cfg = _compose_cfg( + overrides=[ + 'agent=resnet_imf_attnres', + 'agent.vision_backbone.pretrained_backbone_weights=null', + 'agent.vision_backbone.input_shape=[3,16,16]', + 'agent.vision_backbone.freeze_backbone=false', + 'agent.head.n_layer=1', + 'agent.head.n_emb=16', + ] + ) + + self.assertEqual(cfg.agent._target_, 'roboimi.vla.agent_imf.IMFVLAAgent') + self.assertEqual(cfg.agent.head._target_, 'roboimi.vla.models.heads.imf_transformer1d.IMFTransformer1D') + self.assertEqual(cfg.agent.head.backbone_type, 'attnres_full') + self.assertEqual(cfg.agent.head.n_head, 1) + self.assertEqual(cfg.agent.head.n_kv_head, 1) + self.assertEqual(cfg.agent.head.n_cond_layers, 0) + self.assertTrue(cfg.agent.head.time_as_cond) + self.assertFalse(cfg.agent.head.causal_attn) + self.assertEqual(cfg.agent.inference_steps, 1) + self.assertEqual(list(cfg.agent.camera_names), list(_CAMERA_NAMES)) + + with _stub_optional_modules(include_imf_head=True): + agent = instantiate(cfg.agent) + + self.assertEqual(agent.head_type, 'transformer') + self.assertEqual(agent.per_step_cond_dim, agent.vision_encoder.output_dim * agent.num_cams + agent.obs_dim) + self.assertIsInstance(agent.noise_pred_net, _StubIMFHead) + self.assertEqual(agent.noise_pred_net.constructor_kwargs['cond_dim'], agent.per_step_cond_dim) + self.assertEqual(agent.noise_pred_net.constructor_kwargs['backbone_type'], 'attnres_full') + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_train_vla_transformer_optimizer.py b/tests/test_train_vla_transformer_optimizer.py index 204014d..bee12bd 100644 --- a/tests/test_train_vla_transformer_optimizer.py +++ b/tests/test_train_vla_transformer_optimizer.py @@ -101,10 +101,19 @@ class RecordingTransformerHead(nn.Module): ] -class FakeTransformerAgent(nn.Module): +class FakeIMFAgent(nn.Module): def __init__(self): super().__init__() - self.head_type = 'transformer' + self.head_type = 'imf_transformer' + self.noise_pred_net = RecordingTransformerHead() + self.backbone = nn.Linear(4, 3) + self.adapter = nn.Linear(3, 2, bias=False) + + +class FakeTransformerAgent(nn.Module): + def __init__(self, *, head_type='transformer'): + super().__init__() + self.head_type = head_type self.noise_pred_net = RecordingTransformerHead() self.backbone = nn.Linear(4, 3) self.adapter = nn.Linear(3, 2, bias=False) @@ -205,6 +214,47 @@ class TrainVLATransformerOptimizerTest(unittest.TestCase): for group in optimizer.param_groups ] + def test_configure_cuda_runtime_can_disable_cudnn_for_training(self): + module = self._load_train_vla_module() + cfg = AttrDict(train=AttrDict(device='cuda', disable_cudnn=True)) + + original = module.torch.backends.cudnn.enabled + try: + module.torch.backends.cudnn.enabled = True + module._configure_cuda_runtime(cfg) + self.assertFalse(module.torch.backends.cudnn.enabled) + finally: + module.torch.backends.cudnn.enabled = original + + + def test_train_script_uses_file_based_repo_root_on_sys_path(self): + module = self._load_train_vla_module() + + fake_sys_path = ['/tmp/site-packages', '/another/path'] + with mock.patch.object(module.sys, 'path', fake_sys_path): + repo_root = module._ensure_repo_root_on_syspath() + + self.assertEqual(Path(repo_root).resolve(), _REPO_ROOT.resolve()) + self.assertEqual(Path(fake_sys_path[0]).resolve(), _REPO_ROOT.resolve()) + + + def test_non_transformer_head_with_get_optim_groups_still_uses_custom_groups(self): + module = self._load_train_vla_module() + agent = FakeIMFAgent() + + 
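+        # FakeIMFAgent reports head_type='imf_transformer', so this exercises the
+        # capability-based gate: the head's custom groups must be reused even
+        # though the old head_type == 'transformer' check would have rejected it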
optimizer = module.build_training_optimizer(agent, lr=1e-4, weight_decay=0.123) + + self.assertEqual(agent.noise_pred_net.optim_group_calls, [0.123]) + group_names = self._group_names(agent, optimizer) + self.assertEqual(group_names[0], {'noise_pred_net.proj.weight'}) + self.assertEqual(group_names[1], { + 'noise_pred_net.proj.bias', + 'noise_pred_net.norm.weight', + 'noise_pred_net.norm.bias', + }) + self.assertEqual(group_names[2], {'backbone.weight', 'backbone.bias', 'adapter.weight'}) + + def test_transformer_training_prefers_head_optim_groups_and_keeps_remaining_trainable_params(self): module = self._load_train_vla_module() agent = FakeTransformerAgent() @@ -268,6 +318,22 @@ class TrainVLATransformerOptimizerTest(unittest.TestCase): self.assertNotIn('frozen.weight', optimizer_names) self.assertNotIn('frozen.bias', optimizer_names) + def test_any_head_with_get_optim_groups_uses_custom_groups_even_without_transformer_head_type(self): + module = self._load_train_vla_module() + agent = FakeTransformerAgent(head_type='imf') + + with mock.patch.object(module, 'AdamW', RecordingAdamW): + optimizer = module.build_training_optimizer(agent, lr=1e-4, weight_decay=0.123) + + self.assertEqual(agent.noise_pred_net.optim_group_calls, [0.123]) + grouped_names = self._group_names(agent, optimizer) + self.assertEqual(grouped_names[0], {'noise_pred_net.proj.weight'}) + self.assertEqual( + grouped_names[1], + {'noise_pred_net.proj.bias', 'noise_pred_net.norm.weight', 'noise_pred_net.norm.bias'}, + ) + self.assertEqual(grouped_names[2], {'backbone.weight', 'backbone.bias', 'adapter.weight'}) + def test_transformer_optimizer_ignores_frozen_head_params_returned_by_head_groups(self): module = self._load_train_vla_module() agent = FakeTransformerAgent()
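Reviewer note: one way to exercise the new paths locally (plain unittest; the module paths are the ones added in this patch):

    python -m unittest tests.test_imf_vla_agent tests.test_imf_transformer1d_external_alignment tests.test_train_vla_transformer_optimizer -v

The external-alignment test skips itself unless a diffusion_policy git checkout containing commit 185ed659 is found at or above the repo root; the agent suite stubs out diffusers/torchvision/einops, so it should run without a GPU.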