feat: add IMF AttnRes policy training path
This commit is contained in:
@@ -14,8 +14,17 @@ from torch.optim import AdamW
|
|||||||
from torch.optim.lr_scheduler import LambdaLR
|
from torch.optim.lr_scheduler import LambdaLR
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
# 确保正确的导入路径
|
# 确保正确的导入路径(不能依赖 cwd,因为 Hydra 会在运行时切换 cwd)
|
||||||
sys.path.append(os.getcwd())
|
def _ensure_repo_root_on_syspath():
|
||||||
|
repo_root = Path(__file__).resolve().parents[3]
|
||||||
|
repo_root_str = str(repo_root)
|
||||||
|
if repo_root_str in sys.path:
|
||||||
|
sys.path.remove(repo_root_str)
|
||||||
|
sys.path.insert(0, repo_root_str)
|
||||||
|
return repo_root
|
||||||
|
|
||||||
|
|
||||||
|
_REPO_ROOT = _ensure_repo_root_on_syspath()
|
||||||
|
|
||||||
from hydra.utils import instantiate
|
from hydra.utils import instantiate
|
||||||
|
|
||||||
@@ -26,6 +35,13 @@ if not OmegaConf.has_resolver("len"):
|
|||||||
OmegaConf.register_new_resolver("len", lambda x: len(x))
|
OmegaConf.register_new_resolver("len", lambda x: len(x))
|
||||||
|
|
||||||
|
|
||||||
|
def _configure_cuda_runtime(cfg):
|
||||||
|
"""Apply process-level CUDA runtime switches required by this environment."""
|
||||||
|
if str(cfg.train.device).startswith('cuda') and bool(cfg.train.get('disable_cudnn', False)):
|
||||||
|
torch.backends.cudnn.enabled = False
|
||||||
|
log.warning('⚠️ 已按配置禁用 cuDNN;GPU 卷积将回退到非-cuDNN 实现')
|
||||||
|
|
||||||
|
|
||||||
def recursive_to_device(data, device):
|
def recursive_to_device(data, device):
|
||||||
"""
|
"""
|
||||||
递归地将嵌套字典/列表中的张量移动到指定设备。
|
递归地将嵌套字典/列表中的张量移动到指定设备。
|
||||||
@@ -113,14 +129,11 @@ def get_lr_schedule_with_warmup(optimizer, warmup_steps, max_steps, scheduler_ty
|
|||||||
|
|
||||||
|
|
||||||
def build_training_optimizer(agent, lr, weight_decay):
|
def build_training_optimizer(agent, lr, weight_decay):
|
||||||
"""为训练脚本构建优化器,优先复用 transformer head 自带的参数分组。"""
|
"""为训练脚本构建优化器,优先复用任意 head 自带的参数分组。"""
|
||||||
trainable_params = [param for param in agent.parameters() if param.requires_grad]
|
trainable_params = [param for param in agent.parameters() if param.requires_grad]
|
||||||
noise_pred_net = getattr(agent, 'noise_pred_net', None)
|
noise_pred_net = getattr(agent, 'noise_pred_net', None)
|
||||||
get_optim_groups = getattr(noise_pred_net, 'get_optim_groups', None)
|
get_optim_groups = getattr(noise_pred_net, 'get_optim_groups', None)
|
||||||
use_head_groups = (
|
use_head_groups = callable(get_optim_groups)
|
||||||
getattr(agent, 'head_type', None) == 'transformer'
|
|
||||||
and callable(get_optim_groups)
|
|
||||||
)
|
|
||||||
|
|
||||||
if not use_head_groups:
|
if not use_head_groups:
|
||||||
return AdamW(trainable_params, lr=lr, weight_decay=weight_decay)
|
return AdamW(trainable_params, lr=lr, weight_decay=weight_decay)
|
||||||
@@ -138,7 +151,7 @@ def build_training_optimizer(agent, lr, weight_decay):
|
|||||||
for param in params:
|
for param in params:
|
||||||
param_id = id(param)
|
param_id = id(param)
|
||||||
if param_id in grouped_param_ids:
|
if param_id in grouped_param_ids:
|
||||||
raise ValueError('Transformer optimizer groups contain duplicate parameters')
|
raise ValueError('Head optimizer groups contain duplicate parameters')
|
||||||
grouped_param_ids.add(param_id)
|
grouped_param_ids.add(param_id)
|
||||||
|
|
||||||
head_trainable_param_ids = {
|
head_trainable_param_ids = {
|
||||||
@@ -146,7 +159,7 @@ def build_training_optimizer(agent, lr, weight_decay):
|
|||||||
}
|
}
|
||||||
missing_head_param_ids = head_trainable_param_ids - grouped_param_ids
|
missing_head_param_ids = head_trainable_param_ids - grouped_param_ids
|
||||||
if missing_head_param_ids:
|
if missing_head_param_ids:
|
||||||
raise ValueError('Transformer optimizer groups missed trainable head parameters')
|
raise ValueError('Head optimizer groups missed trainable head parameters')
|
||||||
|
|
||||||
remaining_params = [
|
remaining_params = [
|
||||||
param for param in trainable_params
|
param for param in trainable_params
|
||||||
@@ -258,6 +271,7 @@ def _run_training(cfg: DictConfig):
|
|||||||
print("=" * 80)
|
print("=" * 80)
|
||||||
|
|
||||||
log.info(f"🚀 开始 VLA 训练 (设备: {cfg.train.device})")
|
log.info(f"🚀 开始 VLA 训练 (设备: {cfg.train.device})")
|
||||||
|
_configure_cuda_runtime(cfg)
|
||||||
swanlab_module = _init_swanlab(cfg)
|
swanlab_module = _init_swanlab(cfg)
|
||||||
try:
|
try:
|
||||||
# 创建检查点目录
|
# 创建检查点目录
|
||||||
|
|||||||
161
roboimi/vla/agent_imf.py
Normal file
161
roboimi/vla/agent_imf.py
Normal file
@@ -0,0 +1,161 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from contextlib import nullcontext
|
||||||
|
from typing import Dict, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from roboimi.vla.agent import VLAAgent
|
||||||
|
|
||||||
|
try:
|
||||||
|
from torch.func import jvp as TORCH_FUNC_JVP
|
||||||
|
except ImportError: # pragma: no cover
|
||||||
|
TORCH_FUNC_JVP = None
|
||||||
|
|
||||||
|
|
||||||
|
class IMFVLAAgent(VLAAgent):
|
||||||
|
def __init__(self, *args, inference_steps: int = 1, **kwargs):
|
||||||
|
if inference_steps != 1:
|
||||||
|
raise ValueError(
|
||||||
|
'IMFVLAAgent only supports one-step inference; '
|
||||||
|
f'inference_steps must be 1, got {inference_steps}.'
|
||||||
|
)
|
||||||
|
super().__init__(*args, inference_steps=inference_steps, **kwargs)
|
||||||
|
self.inference_steps = 1
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _broadcast_batch_time(value: torch.Tensor, reference: torch.Tensor) -> torch.Tensor:
|
||||||
|
while value.ndim < reference.ndim:
|
||||||
|
value = value.unsqueeze(-1)
|
||||||
|
return value
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _apply_conditioning(
|
||||||
|
trajectory: torch.Tensor,
|
||||||
|
condition_data: Optional[torch.Tensor] = None,
|
||||||
|
condition_mask: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
if condition_data is None or condition_mask is None:
|
||||||
|
return trajectory
|
||||||
|
conditioned = trajectory.clone()
|
||||||
|
conditioned[condition_mask] = condition_data[condition_mask]
|
||||||
|
return conditioned
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _jvp_math_sdp_context(z_t: torch.Tensor):
|
||||||
|
if z_t.is_cuda:
|
||||||
|
return torch.backends.cuda.sdp_kernel(
|
||||||
|
enable_flash=False,
|
||||||
|
enable_math=True,
|
||||||
|
enable_mem_efficient=False,
|
||||||
|
enable_cudnn=False,
|
||||||
|
)
|
||||||
|
return nullcontext()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _jvp_tangents(v: torch.Tensor, r: torch.Tensor, t: torch.Tensor):
|
||||||
|
return v.detach(), torch.zeros_like(r), torch.ones_like(t)
|
||||||
|
|
||||||
|
def fn(self, z: torch.Tensor, r: torch.Tensor, t: torch.Tensor, cond=None) -> torch.Tensor:
|
||||||
|
return self.noise_pred_net(z, r, t, cond=cond)
|
||||||
|
|
||||||
|
def _compute_u_and_du_dt(
|
||||||
|
self,
|
||||||
|
z_t: torch.Tensor,
|
||||||
|
r: torch.Tensor,
|
||||||
|
t: torch.Tensor,
|
||||||
|
cond,
|
||||||
|
v: torch.Tensor,
|
||||||
|
condition_data: Optional[torch.Tensor] = None,
|
||||||
|
condition_mask: Optional[torch.Tensor] = None,
|
||||||
|
):
|
||||||
|
tangents = self._jvp_tangents(v, r, t)
|
||||||
|
|
||||||
|
def g(z, r_value, t_value):
|
||||||
|
conditioned_z = self._apply_conditioning(z, condition_data, condition_mask)
|
||||||
|
return self.fn(conditioned_z, r_value, t_value, cond=cond)
|
||||||
|
|
||||||
|
with self._jvp_math_sdp_context(z_t):
|
||||||
|
if TORCH_FUNC_JVP is not None:
|
||||||
|
try:
|
||||||
|
return TORCH_FUNC_JVP(g, (z_t, r, t), tangents)
|
||||||
|
except (RuntimeError, TypeError, NotImplementedError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
u = g(z_t, r, t)
|
||||||
|
_, du_dt = torch.autograd.functional.jvp(
|
||||||
|
g,
|
||||||
|
(z_t, r, t),
|
||||||
|
tangents,
|
||||||
|
create_graph=False,
|
||||||
|
strict=False,
|
||||||
|
)
|
||||||
|
return u, du_dt
|
||||||
|
|
||||||
|
def _compound_velocity(
|
||||||
|
self,
|
||||||
|
u: torch.Tensor,
|
||||||
|
du_dt: torch.Tensor,
|
||||||
|
r: torch.Tensor,
|
||||||
|
t: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
delta = self._broadcast_batch_time(t - r, u)
|
||||||
|
return u + delta * du_dt.detach()
|
||||||
|
|
||||||
|
def _sample_one_step(
|
||||||
|
self,
|
||||||
|
z_t: torch.Tensor,
|
||||||
|
r: Optional[torch.Tensor] = None,
|
||||||
|
t: Optional[torch.Tensor] = None,
|
||||||
|
cond=None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
batch_size = z_t.shape[0]
|
||||||
|
if t is None:
|
||||||
|
t = torch.ones(batch_size, device=z_t.device, dtype=z_t.dtype)
|
||||||
|
if r is None:
|
||||||
|
r = torch.zeros(batch_size, device=z_t.device, dtype=z_t.dtype)
|
||||||
|
u = self.fn(z_t, r, t, cond=cond)
|
||||||
|
delta = self._broadcast_batch_time(t - r, z_t)
|
||||||
|
return z_t - delta * u
|
||||||
|
|
||||||
|
def compute_loss(self, batch):
|
||||||
|
actions, states, images = batch['action'], batch['qpos'], batch['images']
|
||||||
|
action_is_pad = batch.get('action_is_pad', None)
|
||||||
|
batch_size = actions.shape[0]
|
||||||
|
|
||||||
|
states = self.normalization.normalize_qpos(states)
|
||||||
|
actions = self.normalization.normalize_action(actions)
|
||||||
|
cond = self._build_cond(images, states)
|
||||||
|
|
||||||
|
x = actions
|
||||||
|
e = torch.randn_like(x)
|
||||||
|
t = torch.rand(batch_size, device=x.device, dtype=x.dtype)
|
||||||
|
r = torch.rand(batch_size, device=x.device, dtype=x.dtype)
|
||||||
|
t, r = torch.maximum(t, r), torch.minimum(t, r)
|
||||||
|
|
||||||
|
t_broadcast = self._broadcast_batch_time(t, x)
|
||||||
|
z_t = (1 - t_broadcast) * x + t_broadcast * e
|
||||||
|
|
||||||
|
v = self.fn(z_t, t, t, cond=cond)
|
||||||
|
u, du_dt = self._compute_u_and_du_dt(z_t, r, t, cond=cond, v=v)
|
||||||
|
V = self._compound_velocity(u, du_dt, r, t)
|
||||||
|
target = e - x
|
||||||
|
|
||||||
|
loss = F.mse_loss(V, target, reduction='none')
|
||||||
|
if action_is_pad is not None:
|
||||||
|
mask = (~action_is_pad).unsqueeze(-1).to(loss.dtype)
|
||||||
|
valid_count = mask.sum() * loss.shape[-1]
|
||||||
|
loss = (loss * mask).sum() / valid_count.clamp_min(1.0)
|
||||||
|
else:
|
||||||
|
loss = loss.mean()
|
||||||
|
return loss
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def predict_action(self, images, proprioception):
|
||||||
|
batch_size = proprioception.shape[0]
|
||||||
|
proprioception = self.normalization.normalize_qpos(proprioception)
|
||||||
|
cond = self._build_cond(images, proprioception)
|
||||||
|
z_t = torch.randn((batch_size, self.pred_horizon, self.action_dim), device=cond.device, dtype=cond.dtype)
|
||||||
|
action = self._sample_one_step(z_t, cond=cond)
|
||||||
|
return self.normalization.denormalize_action(action)
|
||||||
40
roboimi/vla/conf/agent/resnet_imf_attnres.yaml
Normal file
40
roboimi/vla/conf/agent/resnet_imf_attnres.yaml
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
# @package agent
|
||||||
|
defaults:
|
||||||
|
- /backbone@vision_backbone: resnet_diffusion
|
||||||
|
- /modules@state_encoder: identity_state_encoder
|
||||||
|
- /modules@action_encoder: identity_action_encoder
|
||||||
|
- /head: imf_transformer1d
|
||||||
|
- _self_
|
||||||
|
|
||||||
|
_target_: roboimi.vla.agent_imf.IMFVLAAgent
|
||||||
|
|
||||||
|
action_dim: 16
|
||||||
|
obs_dim: 16
|
||||||
|
normalization_type: "min_max"
|
||||||
|
pred_horizon: 16
|
||||||
|
obs_horizon: 2
|
||||||
|
num_action_steps: 8
|
||||||
|
camera_names: ${data.camera_names}
|
||||||
|
num_cams: 3
|
||||||
|
|
||||||
|
vision_backbone:
|
||||||
|
num_cameras: ${agent.num_cams}
|
||||||
|
camera_names: ${agent.camera_names}
|
||||||
|
|
||||||
|
diffusion_steps: 100
|
||||||
|
inference_steps: 1
|
||||||
|
head_type: "transformer"
|
||||||
|
|
||||||
|
head:
|
||||||
|
input_dim: ${agent.action_dim}
|
||||||
|
output_dim: ${agent.action_dim}
|
||||||
|
horizon: ${agent.pred_horizon}
|
||||||
|
n_obs_steps: ${agent.obs_horizon}
|
||||||
|
cond_dim: 208
|
||||||
|
causal_attn: false
|
||||||
|
time_as_cond: true
|
||||||
|
obs_as_cond: true
|
||||||
|
n_cond_layers: 0
|
||||||
|
backbone_type: attnres_full
|
||||||
|
n_head: 1
|
||||||
|
n_kv_head: 1
|
||||||
@@ -13,6 +13,7 @@ train:
|
|||||||
lr: 1e-4 # 学习率
|
lr: 1e-4 # 学习率
|
||||||
max_steps: 100000 # 最大训练步数
|
max_steps: 100000 # 最大训练步数
|
||||||
device: "cuda" # 设备: "cuda" 或 "cpu"
|
device: "cuda" # 设备: "cuda" 或 "cpu"
|
||||||
|
disable_cudnn: false # 遇到当前机器的 cuDNN 兼容性问题时可置 true
|
||||||
|
|
||||||
# 数据加载
|
# 数据加载
|
||||||
num_workers: 12 # DataLoader 工作进程数(调试时设为 0)
|
num_workers: 12 # DataLoader 工作进程数(调试时设为 0)
|
||||||
|
|||||||
22
roboimi/vla/conf/head/imf_transformer1d.yaml
Normal file
22
roboimi/vla/conf/head/imf_transformer1d.yaml
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
_target_: roboimi.vla.models.heads.imf_transformer1d.IMFTransformer1D
|
||||||
|
_partial_: true
|
||||||
|
|
||||||
|
input_dim: ${agent.action_dim}
|
||||||
|
output_dim: ${agent.action_dim}
|
||||||
|
horizon: ${agent.pred_horizon}
|
||||||
|
n_obs_steps: ${agent.obs_horizon}
|
||||||
|
cond_dim: 208
|
||||||
|
n_layer: 12
|
||||||
|
n_head: 1
|
||||||
|
n_emb: 768
|
||||||
|
p_drop_emb: 0.1
|
||||||
|
p_drop_attn: 0.1
|
||||||
|
causal_attn: false
|
||||||
|
time_as_cond: true
|
||||||
|
obs_as_cond: true
|
||||||
|
n_cond_layers: 0
|
||||||
|
backbone_type: attnres_full
|
||||||
|
n_kv_head: 1
|
||||||
|
attn_res_ffn_mult: 2.667
|
||||||
|
attn_res_eps: 1.0e-6
|
||||||
|
attn_res_rope_theta: 10000.0
|
||||||
249
roboimi/vla/models/heads/attnres_transformer_components.py
Normal file
249
roboimi/vla/models/heads/attnres_transformer_components.py
Normal file
@@ -0,0 +1,249 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import math
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
|
||||||
|
class RMSNorm(nn.Module):
|
||||||
|
def __init__(self, dim: int, eps: float = 1e-6) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.eps = eps
|
||||||
|
self.weight = nn.Parameter(torch.ones(dim))
|
||||||
|
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
rms = torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps)
|
||||||
|
return (x.float() * rms).to(x.dtype) * self.weight
|
||||||
|
|
||||||
|
|
||||||
|
class RMSNormNoWeight(nn.Module):
|
||||||
|
def __init__(self, eps: float = 1e-6) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.eps = eps
|
||||||
|
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
rms = torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps)
|
||||||
|
return (x.float() * rms).to(x.dtype)
|
||||||
|
|
||||||
|
|
||||||
|
def precompute_rope_freqs(
|
||||||
|
dim: int,
|
||||||
|
max_seq_len: int,
|
||||||
|
theta: float = 10000.0,
|
||||||
|
device: Optional[torch.device] = None,
|
||||||
|
) -> Tensor:
|
||||||
|
if dim % 2 != 0:
|
||||||
|
raise ValueError(f'RoPE requires an even head dimension, got {dim}.')
|
||||||
|
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, device=device).float() / dim))
|
||||||
|
positions = torch.arange(max_seq_len, device=device).float()
|
||||||
|
angles = torch.outer(positions, freqs)
|
||||||
|
return torch.polar(torch.ones_like(angles), angles)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_rope(x: Tensor, freqs: Tensor) -> Tensor:
|
||||||
|
x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
|
||||||
|
freqs = freqs.unsqueeze(0).unsqueeze(2)
|
||||||
|
x_rotated = x_complex * freqs
|
||||||
|
return torch.view_as_real(x_rotated).reshape_as(x).to(x.dtype)
|
||||||
|
|
||||||
|
|
||||||
|
class GroupedQuerySelfAttention(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
d_model: int,
|
||||||
|
n_heads: int,
|
||||||
|
n_kv_heads: int,
|
||||||
|
dropout: float = 0.0,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
if d_model % n_heads != 0:
|
||||||
|
raise ValueError(f'd_model={d_model} must be divisible by n_heads={n_heads}.')
|
||||||
|
if n_heads % n_kv_heads != 0:
|
||||||
|
raise ValueError(f'n_heads={n_heads} must be divisible by n_kv_heads={n_kv_heads}.')
|
||||||
|
|
||||||
|
self.d_model = d_model
|
||||||
|
self.n_heads = n_heads
|
||||||
|
self.n_kv_heads = n_kv_heads
|
||||||
|
self.n_kv_groups = n_heads // n_kv_heads
|
||||||
|
self.d_head = d_model // n_heads
|
||||||
|
self.attn_dropout = nn.Dropout(dropout)
|
||||||
|
self.out_dropout = nn.Dropout(dropout)
|
||||||
|
|
||||||
|
self.w_q = nn.Linear(d_model, n_heads * self.d_head, bias=False)
|
||||||
|
self.w_k = nn.Linear(d_model, n_kv_heads * self.d_head, bias=False)
|
||||||
|
self.w_v = nn.Linear(d_model, n_kv_heads * self.d_head, bias=False)
|
||||||
|
self.w_o = nn.Linear(n_heads * self.d_head, d_model, bias=False)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: Tensor,
|
||||||
|
rope_freqs: Tensor,
|
||||||
|
mask: Optional[Tensor] = None,
|
||||||
|
) -> Tensor:
|
||||||
|
batch_size, seq_len, _ = x.shape
|
||||||
|
|
||||||
|
q = self.w_q(x).view(batch_size, seq_len, self.n_heads, self.d_head)
|
||||||
|
k = self.w_k(x).view(batch_size, seq_len, self.n_kv_heads, self.d_head)
|
||||||
|
v = self.w_v(x).view(batch_size, seq_len, self.n_kv_heads, self.d_head)
|
||||||
|
|
||||||
|
q = apply_rope(q, rope_freqs)
|
||||||
|
k = apply_rope(k, rope_freqs)
|
||||||
|
|
||||||
|
if self.n_kv_heads != self.n_heads:
|
||||||
|
k = k.unsqueeze(3).expand(
|
||||||
|
batch_size, seq_len, self.n_kv_heads, self.n_kv_groups, self.d_head
|
||||||
|
)
|
||||||
|
k = k.reshape(batch_size, seq_len, self.n_heads, self.d_head)
|
||||||
|
v = v.unsqueeze(3).expand(
|
||||||
|
batch_size, seq_len, self.n_kv_heads, self.n_kv_groups, self.d_head
|
||||||
|
)
|
||||||
|
v = v.reshape(batch_size, seq_len, self.n_heads, self.d_head)
|
||||||
|
|
||||||
|
q = q.transpose(1, 2)
|
||||||
|
k = k.transpose(1, 2)
|
||||||
|
v = v.transpose(1, 2)
|
||||||
|
|
||||||
|
scale = 1.0 / math.sqrt(self.d_head)
|
||||||
|
attn_weights = torch.matmul(q, k.transpose(-2, -1)) * scale
|
||||||
|
if mask is not None:
|
||||||
|
attn_weights = attn_weights + mask
|
||||||
|
attn_weights = F.softmax(attn_weights, dim=-1)
|
||||||
|
attn_weights = self.attn_dropout(attn_weights)
|
||||||
|
|
||||||
|
out = torch.matmul(attn_weights, v)
|
||||||
|
out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
|
||||||
|
return self.out_dropout(self.w_o(out))
|
||||||
|
|
||||||
|
|
||||||
|
class SwiGLUFFN(nn.Module):
|
||||||
|
def __init__(self, d_model: int, dropout: float = 0.0, mult: float = 2.667) -> None:
|
||||||
|
super().__init__()
|
||||||
|
raw = int(mult * d_model)
|
||||||
|
d_ff = ((raw + 7) // 8) * 8
|
||||||
|
self.w_gate = nn.Linear(d_model, d_ff, bias=False)
|
||||||
|
self.w_up = nn.Linear(d_model, d_ff, bias=False)
|
||||||
|
self.w_down = nn.Linear(d_ff, d_model, bias=False)
|
||||||
|
self.dropout = nn.Dropout(dropout)
|
||||||
|
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
return self.dropout(self.w_down(F.silu(self.w_gate(x)) * self.w_up(x)))
|
||||||
|
|
||||||
|
|
||||||
|
class AttnResOperator(nn.Module):
|
||||||
|
def __init__(self, d_model: int, eps: float = 1e-6) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.pseudo_query = nn.Parameter(torch.zeros(d_model))
|
||||||
|
self.key_norm = RMSNormNoWeight(eps=eps)
|
||||||
|
|
||||||
|
def forward(self, sources: Tensor) -> Tensor:
|
||||||
|
keys = self.key_norm(sources)
|
||||||
|
logits = torch.einsum('d,nbtd->nbt', self.pseudo_query, keys)
|
||||||
|
weights = F.softmax(logits, dim=0)
|
||||||
|
return torch.einsum('nbt,nbtd->btd', weights, sources)
|
||||||
|
|
||||||
|
|
||||||
|
class AttnResSubLayer(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
d_model: int,
|
||||||
|
n_heads: int,
|
||||||
|
n_kv_heads: int,
|
||||||
|
dropout: float,
|
||||||
|
ffn_mult: float,
|
||||||
|
eps: float,
|
||||||
|
is_attention: bool,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.norm = RMSNorm(d_model, eps=eps)
|
||||||
|
self.attn_res = AttnResOperator(d_model, eps=eps)
|
||||||
|
self.is_attention = is_attention
|
||||||
|
if self.is_attention:
|
||||||
|
self.fn = GroupedQuerySelfAttention(
|
||||||
|
d_model=d_model,
|
||||||
|
n_heads=n_heads,
|
||||||
|
n_kv_heads=n_kv_heads,
|
||||||
|
dropout=dropout,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.fn = SwiGLUFFN(d_model=d_model, dropout=dropout, mult=ffn_mult)
|
||||||
|
|
||||||
|
def forward(self, sources: Tensor, rope_freqs: Tensor, mask: Optional[Tensor] = None) -> Tensor:
|
||||||
|
h = self.attn_res(sources)
|
||||||
|
normed = self.norm(h)
|
||||||
|
if self.is_attention:
|
||||||
|
return self.fn(normed, rope_freqs, mask)
|
||||||
|
return self.fn(normed)
|
||||||
|
|
||||||
|
|
||||||
|
class AttnResTransformerBackbone(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
d_model: int,
|
||||||
|
n_blocks: int,
|
||||||
|
n_heads: int,
|
||||||
|
n_kv_heads: int,
|
||||||
|
max_seq_len: int,
|
||||||
|
dropout: float = 0.0,
|
||||||
|
ffn_mult: float = 2.667,
|
||||||
|
eps: float = 1e-6,
|
||||||
|
rope_theta: float = 10000.0,
|
||||||
|
causal_attn: bool = False,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.causal_attn = causal_attn
|
||||||
|
self.layers = nn.ModuleList()
|
||||||
|
for _ in range(n_blocks):
|
||||||
|
self.layers.append(
|
||||||
|
AttnResSubLayer(
|
||||||
|
d_model=d_model,
|
||||||
|
n_heads=n_heads,
|
||||||
|
n_kv_heads=n_kv_heads,
|
||||||
|
dropout=dropout,
|
||||||
|
ffn_mult=ffn_mult,
|
||||||
|
eps=eps,
|
||||||
|
is_attention=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.layers.append(
|
||||||
|
AttnResSubLayer(
|
||||||
|
d_model=d_model,
|
||||||
|
n_heads=n_heads,
|
||||||
|
n_kv_heads=n_kv_heads,
|
||||||
|
dropout=dropout,
|
||||||
|
ffn_mult=ffn_mult,
|
||||||
|
eps=eps,
|
||||||
|
is_attention=False,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
rope_freqs = precompute_rope_freqs(
|
||||||
|
dim=d_model // n_heads,
|
||||||
|
max_seq_len=max_seq_len,
|
||||||
|
theta=rope_theta,
|
||||||
|
)
|
||||||
|
self.register_buffer('rope_freqs', rope_freqs, persistent=False)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _build_causal_mask(seq_len: int, device: torch.device) -> Tensor:
|
||||||
|
mask = torch.full((seq_len, seq_len), float('-inf'), device=device)
|
||||||
|
mask = torch.triu(mask, diagonal=1)
|
||||||
|
return mask.unsqueeze(0).unsqueeze(0)
|
||||||
|
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
seq_len = x.shape[1]
|
||||||
|
rope_freqs = self.rope_freqs[:seq_len]
|
||||||
|
mask = None
|
||||||
|
if self.causal_attn:
|
||||||
|
mask = self._build_causal_mask(seq_len, x.device)
|
||||||
|
|
||||||
|
layer_outputs = [x]
|
||||||
|
for layer in self.layers:
|
||||||
|
sources = torch.stack(layer_outputs, dim=0)
|
||||||
|
output = layer(sources, rope_freqs, mask)
|
||||||
|
layer_outputs.append(output)
|
||||||
|
|
||||||
|
return torch.stack(layer_outputs, dim=0).sum(dim=0)
|
||||||
379
roboimi/vla/models/heads/imf_transformer1d.py
Normal file
379
roboimi/vla/models/heads/imf_transformer1d.py
Normal file
@@ -0,0 +1,379 @@
|
|||||||
|
"""Local IMF-AttnRes transformer head aligned with diffusion_policy@185ed659."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from .attnres_transformer_components import (
|
||||||
|
AttnResOperator,
|
||||||
|
AttnResSubLayer,
|
||||||
|
AttnResTransformerBackbone,
|
||||||
|
GroupedQuerySelfAttention,
|
||||||
|
RMSNorm,
|
||||||
|
RMSNormNoWeight,
|
||||||
|
SwiGLUFFN,
|
||||||
|
)
|
||||||
|
from .transformer1d import ModuleAttrMixin, SinusoidalPosEmb
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class IMFTransformer1D(ModuleAttrMixin):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
input_dim: int,
|
||||||
|
output_dim: int,
|
||||||
|
horizon: int,
|
||||||
|
n_obs_steps: Optional[int] = None,
|
||||||
|
cond_dim: int = 0,
|
||||||
|
n_layer: int = 12,
|
||||||
|
n_head: int = 1,
|
||||||
|
n_emb: int = 768,
|
||||||
|
p_drop_emb: float = 0.1,
|
||||||
|
p_drop_attn: float = 0.1,
|
||||||
|
causal_attn: bool = False,
|
||||||
|
time_as_cond: bool = True,
|
||||||
|
obs_as_cond: bool = False,
|
||||||
|
n_cond_layers: int = 0,
|
||||||
|
backbone_type: str = 'attnres_full',
|
||||||
|
n_kv_head: int = 1,
|
||||||
|
attn_res_ffn_mult: float = 2.667,
|
||||||
|
attn_res_eps: float = 1e-6,
|
||||||
|
attn_res_rope_theta: float = 10000.0,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
if n_head != 1:
|
||||||
|
raise AssertionError('IMFTransformer1D currently supports single-head attention only.')
|
||||||
|
if n_obs_steps is None:
|
||||||
|
n_obs_steps = horizon
|
||||||
|
|
||||||
|
self.backbone_type = backbone_type
|
||||||
|
|
||||||
|
T = horizon
|
||||||
|
T_cond = 2
|
||||||
|
if not time_as_cond:
|
||||||
|
T += 2
|
||||||
|
T_cond -= 2
|
||||||
|
obs_as_cond = cond_dim > 0
|
||||||
|
if obs_as_cond:
|
||||||
|
assert time_as_cond
|
||||||
|
T_cond += n_obs_steps
|
||||||
|
|
||||||
|
self.input_emb = nn.Linear(input_dim, n_emb)
|
||||||
|
self.drop = nn.Dropout(p_drop_emb)
|
||||||
|
self.time_emb = SinusoidalPosEmb(n_emb)
|
||||||
|
self.cond_obs_emb = nn.Linear(cond_dim, n_emb) if obs_as_cond else None
|
||||||
|
self.time_token_proj = None
|
||||||
|
self.cond_pos_emb = None
|
||||||
|
self.pos_emb = None
|
||||||
|
self.encoder = None
|
||||||
|
self.decoder = None
|
||||||
|
self.attnres_backbone = None
|
||||||
|
encoder_only = False
|
||||||
|
|
||||||
|
if backbone_type == 'attnres_full':
|
||||||
|
if not time_as_cond:
|
||||||
|
raise ValueError('attnres_full backbone requires time_as_cond=True.')
|
||||||
|
if n_cond_layers != 0:
|
||||||
|
raise ValueError('attnres_full backbone does not support n_cond_layers > 0.')
|
||||||
|
|
||||||
|
self.time_token_proj = nn.Linear(n_emb, n_emb)
|
||||||
|
self.attnres_backbone = AttnResTransformerBackbone(
|
||||||
|
d_model=n_emb,
|
||||||
|
n_blocks=n_layer,
|
||||||
|
n_heads=n_head,
|
||||||
|
n_kv_heads=n_kv_head,
|
||||||
|
max_seq_len=T + T_cond,
|
||||||
|
dropout=p_drop_attn,
|
||||||
|
ffn_mult=attn_res_ffn_mult,
|
||||||
|
eps=attn_res_eps,
|
||||||
|
rope_theta=attn_res_rope_theta,
|
||||||
|
causal_attn=causal_attn,
|
||||||
|
)
|
||||||
|
self.ln_f = RMSNorm(n_emb, eps=attn_res_eps)
|
||||||
|
else:
|
||||||
|
self.pos_emb = nn.Parameter(torch.zeros(1, T, n_emb))
|
||||||
|
if T_cond > 0:
|
||||||
|
self.cond_pos_emb = nn.Parameter(torch.zeros(1, T_cond, n_emb))
|
||||||
|
if n_cond_layers > 0:
|
||||||
|
encoder_layer = nn.TransformerEncoderLayer(
|
||||||
|
d_model=n_emb,
|
||||||
|
nhead=n_head,
|
||||||
|
dim_feedforward=4 * n_emb,
|
||||||
|
dropout=p_drop_attn,
|
||||||
|
activation='gelu',
|
||||||
|
batch_first=True,
|
||||||
|
norm_first=True,
|
||||||
|
)
|
||||||
|
self.encoder = nn.TransformerEncoder(
|
||||||
|
encoder_layer=encoder_layer,
|
||||||
|
num_layers=n_cond_layers,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.encoder = nn.Sequential(
|
||||||
|
nn.Linear(n_emb, 4 * n_emb),
|
||||||
|
nn.Mish(),
|
||||||
|
nn.Linear(4 * n_emb, n_emb),
|
||||||
|
)
|
||||||
|
|
||||||
|
decoder_layer = nn.TransformerDecoderLayer(
|
||||||
|
d_model=n_emb,
|
||||||
|
nhead=n_head,
|
||||||
|
dim_feedforward=4 * n_emb,
|
||||||
|
dropout=p_drop_attn,
|
||||||
|
activation='gelu',
|
||||||
|
batch_first=True,
|
||||||
|
norm_first=True,
|
||||||
|
)
|
||||||
|
self.decoder = nn.TransformerDecoder(
|
||||||
|
decoder_layer=decoder_layer,
|
||||||
|
num_layers=n_layer,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
encoder_only = True
|
||||||
|
encoder_layer = nn.TransformerEncoderLayer(
|
||||||
|
d_model=n_emb,
|
||||||
|
nhead=n_head,
|
||||||
|
dim_feedforward=4 * n_emb,
|
||||||
|
dropout=p_drop_attn,
|
||||||
|
activation='gelu',
|
||||||
|
batch_first=True,
|
||||||
|
norm_first=True,
|
||||||
|
)
|
||||||
|
self.encoder = nn.TransformerEncoder(
|
||||||
|
encoder_layer=encoder_layer,
|
||||||
|
num_layers=n_layer,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.ln_f = nn.LayerNorm(n_emb)
|
||||||
|
|
||||||
|
if causal_attn and backbone_type != 'attnres_full':
|
||||||
|
sz = T
|
||||||
|
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
|
||||||
|
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
|
||||||
|
self.register_buffer('mask', mask)
|
||||||
|
|
||||||
|
if time_as_cond and obs_as_cond:
|
||||||
|
S = T_cond
|
||||||
|
t_idx, s_idx = torch.meshgrid(
|
||||||
|
torch.arange(T),
|
||||||
|
torch.arange(S),
|
||||||
|
indexing='ij',
|
||||||
|
)
|
||||||
|
mask = t_idx >= (s_idx - 2)
|
||||||
|
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
|
||||||
|
self.register_buffer('memory_mask', mask)
|
||||||
|
else:
|
||||||
|
self.memory_mask = None
|
||||||
|
else:
|
||||||
|
self.mask = None
|
||||||
|
self.memory_mask = None
|
||||||
|
|
||||||
|
self.head = nn.Linear(n_emb, output_dim)
|
||||||
|
|
||||||
|
self.T = T
|
||||||
|
self.T_cond = T_cond
|
||||||
|
self.horizon = horizon
|
||||||
|
self.time_as_cond = time_as_cond
|
||||||
|
self.obs_as_cond = obs_as_cond
|
||||||
|
self.encoder_only = encoder_only
|
||||||
|
|
||||||
|
self.apply(self._init_weights)
|
||||||
|
logger.info('number of parameters: %e', sum(p.numel() for p in self.parameters()))
|
||||||
|
|
||||||
|
def _init_weights(self, module):
|
||||||
|
ignore_types = (
|
||||||
|
nn.Dropout,
|
||||||
|
SinusoidalPosEmb,
|
||||||
|
nn.TransformerEncoderLayer,
|
||||||
|
nn.TransformerDecoderLayer,
|
||||||
|
nn.TransformerEncoder,
|
||||||
|
nn.TransformerDecoder,
|
||||||
|
nn.ModuleList,
|
||||||
|
nn.Mish,
|
||||||
|
nn.Sequential,
|
||||||
|
AttnResTransformerBackbone,
|
||||||
|
AttnResSubLayer,
|
||||||
|
GroupedQuerySelfAttention,
|
||||||
|
SwiGLUFFN,
|
||||||
|
RMSNormNoWeight,
|
||||||
|
)
|
||||||
|
if isinstance(module, (nn.Linear, nn.Embedding)):
|
||||||
|
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
||||||
|
if isinstance(module, nn.Linear) and module.bias is not None:
|
||||||
|
torch.nn.init.zeros_(module.bias)
|
||||||
|
elif isinstance(module, nn.MultiheadAttention):
|
||||||
|
for name in ('in_proj_weight', 'q_proj_weight', 'k_proj_weight', 'v_proj_weight'):
|
||||||
|
weight = getattr(module, name)
|
||||||
|
if weight is not None:
|
||||||
|
torch.nn.init.normal_(weight, mean=0.0, std=0.02)
|
||||||
|
|
||||||
|
for name in ('in_proj_bias', 'bias_k', 'bias_v'):
|
||||||
|
bias = getattr(module, name)
|
||||||
|
if bias is not None:
|
||||||
|
torch.nn.init.zeros_(bias)
|
||||||
|
elif isinstance(module, (nn.LayerNorm, RMSNorm)):
|
||||||
|
if getattr(module, 'bias', None) is not None:
|
||||||
|
torch.nn.init.zeros_(module.bias)
|
||||||
|
torch.nn.init.ones_(module.weight)
|
||||||
|
elif isinstance(module, AttnResOperator):
|
||||||
|
torch.nn.init.zeros_(module.pseudo_query)
|
||||||
|
elif isinstance(module, IMFTransformer1D):
|
||||||
|
if module.pos_emb is not None:
|
||||||
|
torch.nn.init.normal_(module.pos_emb, mean=0.0, std=0.02)
|
||||||
|
if module.cond_pos_emb is not None:
|
||||||
|
torch.nn.init.normal_(module.cond_pos_emb, mean=0.0, std=0.02)
|
||||||
|
elif isinstance(module, ignore_types):
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
raise RuntimeError(f'Unaccounted module {module}')
|
||||||
|
|
||||||
|
def get_optim_groups(self, weight_decay: float = 1e-3):
    """Partition parameters into weight-decay and no-decay optimizer groups.

    Linear/attention weights receive ``weight_decay``; all biases, norm
    weights, embeddings, positional embeddings and the AttnRes
    ``pseudo_query`` are excluded from decay (GPT-style grouping).
    """
    decay = set()
    no_decay = set()
    whitelist_weight_modules = (torch.nn.Linear, torch.nn.MultiheadAttention)
    blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding, RMSNorm)
    for mn, m in self.named_modules():
        for pn, _ in m.named_parameters(recurse=False):
            fpn = f'{mn}.{pn}' if mn else pn  # full dotted parameter name

            if pn.endswith('bias'):
                no_decay.add(fpn)
            elif pn.startswith('bias'):
                # Covers MultiheadAttention's bias_k / bias_v parameters.
                no_decay.add(fpn)
            elif pn == 'pseudo_query':
                no_decay.add(fpn)
            elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                decay.add(fpn)
            elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                no_decay.add(fpn)

    # Root-level (non-module) parameters are handled explicitly.
    if self.pos_emb is not None:
        no_decay.add('pos_emb')
    # NOTE(review): assumed unconditional, matching the upstream
    # diffusion_policy TransformerForDiffusion grouping — confirm indentation.
    no_decay.add('_dummy_variable')
    if self.cond_pos_emb is not None:
        no_decay.add('cond_pos_emb')

    # Sanity-check: the two sets are disjoint and jointly cover everything.
    param_dict = {pn: p for pn, p in self.named_parameters()}
    inter_params = decay & no_decay
    union_params = decay | no_decay
    assert len(inter_params) == 0, f'parameters {inter_params} made it into both decay/no_decay sets!'
    assert len(param_dict.keys() - union_params) == 0, (
        f'parameters {param_dict.keys() - union_params} were not separated into either decay/no_decay sets!'
    )

    return [
        {
            'params': [param_dict[pn] for pn in sorted(list(decay))],
            'weight_decay': weight_decay,
        },
        {
            'params': [param_dict[pn] for pn in sorted(list(no_decay))],
            'weight_decay': 0.0,
        },
    ]
||||||
|
def configure_optimizers(
    self,
    learning_rate: float = 1e-4,
    weight_decay: float = 1e-3,
    betas: Tuple[float, float] = (0.9, 0.95),
):
    """Build an AdamW optimizer over this model's decay/no-decay groups."""
    groups = self.get_optim_groups(weight_decay=weight_decay)
    return torch.optim.AdamW(groups, lr=learning_rate, betas=betas)
|
||||||
|
|
||||||
|
def _prepare_time_input(self, value: Union[torch.Tensor, float, int], sample: torch.Tensor) -> torch.Tensor:
|
||||||
|
if not torch.is_tensor(value):
|
||||||
|
value = torch.tensor([value], dtype=sample.dtype, device=sample.device)
|
||||||
|
elif value.ndim == 0:
|
||||||
|
value = value[None].to(device=sample.device, dtype=sample.dtype)
|
||||||
|
else:
|
||||||
|
value = value.to(device=sample.device, dtype=sample.dtype)
|
||||||
|
return value.expand(sample.shape[0])
|
||||||
|
|
||||||
|
def _forward_attnres_full(
    self,
    sample: torch.Tensor,
    r: torch.Tensor,
    t: torch.Tensor,
    cond: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Run the AttnRes backbone over [r-token, t-token, cond tokens, sample tokens].

    Only the trailing sample positions are returned, so the output aligns
    one-to-one with the input action sequence.
    """
    sample_tokens = self.input_emb(sample)
    # Each IMF time (r and t) becomes a single conditioning token.
    token_parts = [
        self.time_token_proj(self.time_emb(r)).unsqueeze(1),
        self.time_token_proj(self.time_emb(t)).unsqueeze(1),
    ]
    if self.obs_as_cond:
        if cond is None:
            raise ValueError('cond is required when obs_as_cond=True for attnres_full backbone.')
        token_parts.append(self.cond_obs_emb(cond))
    token_parts.append(sample_tokens)
    x = torch.cat(token_parts, dim=1)
    x = self.drop(x)
    x = self.attnres_backbone(x)
    # Discard time/cond positions; keep only the sample tokens at the tail.
    x = x[:, -sample_tokens.shape[1]:, :]
    return x
|
||||||
|
|
||||||
|
def _forward_vanilla(
    self,
    sample: torch.Tensor,
    r: torch.Tensor,
    t: torch.Tensor,
    cond: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Vanilla nn.Transformer path: encoder-only, or encoder-decoder with cond."""
    r_emb = self.time_emb(r).unsqueeze(1)
    t_emb = self.time_emb(t).unsqueeze(1)
    input_emb = self.input_emb(sample)

    if self.encoder_only:
        # BERT-style: [r, t, sample...] through a single encoder.
        token_embeddings = torch.cat([r_emb, t_emb, input_emb], dim=1)
        token_count = token_embeddings.shape[1]
        position_embeddings = self.pos_emb[:, :token_count, :]
        x = self.drop(token_embeddings + position_embeddings)
        x = self.encoder(src=x, mask=self.mask)
        # Drop the two leading time tokens so output aligns with `sample`.
        x = x[:, 2:, :]
    else:
        # Encoder consumes [r, t] (plus obs cond); decoder consumes the sample.
        cond_embeddings = torch.cat([r_emb, t_emb], dim=1)
        if self.obs_as_cond:
            cond_embeddings = torch.cat([cond_embeddings, self.cond_obs_emb(cond)], dim=1)
        token_count = cond_embeddings.shape[1]
        position_embeddings = self.cond_pos_emb[:, :token_count, :]
        x = self.drop(cond_embeddings + position_embeddings)
        x = self.encoder(x)
        memory = x

        token_embeddings = input_emb
        token_count = token_embeddings.shape[1]
        position_embeddings = self.pos_emb[:, :token_count, :]
        x = self.drop(token_embeddings + position_embeddings)
        x = self.decoder(
            tgt=x,
            memory=memory,
            tgt_mask=self.mask,
            memory_mask=self.memory_mask,
        )
    return x
|
||||||
|
|
||||||
|
def forward(
    self,
    sample: torch.Tensor,
    r: Union[torch.Tensor, float, int],
    t: Union[torch.Tensor, float, int],
    cond: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Predict the IMF velocity field for ``sample`` at times ``(r, t)``.

    ``r`` and ``t`` may be python scalars or tensors; both are broadcast to
    the batch size before embedding. Dispatches to the AttnRes or the
    vanilla transformer backbone based on ``self.backbone_type``.
    """
    r = self._prepare_time_input(r, sample)
    t = self._prepare_time_input(t, sample)

    if self.backbone_type == 'attnres_full':
        x = self._forward_attnres_full(sample, r, t, cond=cond)
    else:
        x = self._forward_vanilla(sample, r, t, cond=cond)

    # Final norm + linear projection back to the output (action) dimension.
    x = self.ln_f(x)
    x = self.head(x)
    return x
|
||||||
196
tests/test_imf_transformer1d_external_alignment.py
Normal file
196
tests/test_imf_transformer1d_external_alignment.py
Normal file
@@ -0,0 +1,196 @@
|
|||||||
|
import contextlib
|
||||||
|
import importlib
|
||||||
|
import inspect
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
import types
|
||||||
|
import unittest
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
_REPO_ROOT = Path(__file__).resolve().parents[1]
|
||||||
|
if str(_REPO_ROOT) not in sys.path:
|
||||||
|
sys.path.insert(0, str(_REPO_ROOT))
|
||||||
|
|
||||||
|
_EXTERNAL_COMMIT = '185ed659'
|
||||||
|
_LOCAL_MODULE_NAME = 'roboimi.vla.models.heads.imf_transformer1d'
|
||||||
|
_MISSING = object()
|
||||||
|
|
||||||
|
|
||||||
|
def _find_external_checkout_root() -> Path | None:
    """Locate a sibling ``diffusion_policy`` git checkout, searching upward.

    Starts at the repo root and walks every ancestor directory; returns the
    first ``diffusion_policy`` directory that is itself a git checkout, or
    None when no such checkout exists.
    """
    search_roots = (_REPO_ROOT, *_REPO_ROOT.parents)
    for root in search_roots:
        checkout = root / 'diffusion_policy'
        if (checkout / '.git').exists():
            return checkout
    return None
|
||||||
|
|
||||||
|
|
||||||
|
_EXTERNAL_CHECKOUT_ROOT = _find_external_checkout_root()
|
||||||
|
_EXTERNAL_MODULE_PATHS = {
|
||||||
|
'diffusion_policy.model.common.module_attr_mixin': 'diffusion_policy/model/common/module_attr_mixin.py',
|
||||||
|
'diffusion_policy.model.diffusion.positional_embedding': 'diffusion_policy/model/diffusion/positional_embedding.py',
|
||||||
|
'diffusion_policy.model.diffusion.attnres_transformer_components': 'diffusion_policy/model/diffusion/attnres_transformer_components.py',
|
||||||
|
'diffusion_policy.model.diffusion.imf_transformer_for_diffusion': 'diffusion_policy/model/diffusion/imf_transformer_for_diffusion.py',
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@contextlib.contextmanager
def _temporary_registered_modules():
    """Yield a loader that execs source into fresh sys.modules entries.

    Every module (and synthesized parent package) registered through the
    loader is restored to its previous state — or removed — on exit, so the
    external code never leaks into later tests.
    """
    previous_modules = {}

    def remember(name: str) -> None:
        # Record only the first observed value so restoration is exact.
        if name not in previous_modules:
            previous_modules[name] = sys.modules.get(name, _MISSING)

    def ensure_package(name: str) -> None:
        # Create an empty namespace package for each missing ancestor.
        if not name or name in sys.modules:
            return
        remember(name)
        package = types.ModuleType(name)
        package.__path__ = []
        sys.modules[name] = package

    def load(name: str, source: str, origin: str):
        package_parts = name.split('.')[:-1]
        for idx in range(1, len(package_parts) + 1):
            ensure_package('.'.join(package_parts[:idx]))

        remember(name)
        module = types.ModuleType(name)
        module.__file__ = origin
        module.__package__ = name.rpartition('.')[0]
        # Register before exec so intra-package imports resolve during exec.
        sys.modules[name] = module
        exec(compile(source, origin, 'exec'), module.__dict__)
        return module

    try:
        yield load
    finally:
        # Restore in reverse registration order (children before parents).
        for name, previous in reversed(list(previous_modules.items())):
            if previous is _MISSING:
                sys.modules.pop(name, None)
            else:
                sys.modules[name] = previous
|
||||||
|
|
||||||
|
def _git_show(repo_root: Path, commit: str, relative_path: str) -> str:
    """Return the content of ``relative_path`` at ``commit`` in the given checkout.

    Raises ``subprocess.CalledProcessError`` when the commit or path is
    unavailable (callers turn that into a test skip).
    """
    argv = ['git', '-C', str(repo_root), 'show', f'{commit}:{relative_path}']
    completed = subprocess.run(argv, check=True, capture_output=True, text=True)
    return completed.stdout
|
||||||
|
|
||||||
|
|
||||||
|
@contextlib.contextmanager
def _load_external_module_or_skip(test_case: unittest.TestCase):
    """Load the pinned external IMF transformer module, or skip the test.

    Skips when no external diffusion_policy checkout exists or the pinned
    commit cannot be read; otherwise yields the loaded
    ``imf_transformer_for_diffusion`` module with all modules registered
    only for the duration of the context.
    """
    if _EXTERNAL_CHECKOUT_ROOT is None:
        test_case.skipTest('external diffusion_policy checkout unavailable')

    try:
        # Read every required file from the pinned commit up front so a
        # missing file skips cleanly before any module is registered.
        sources = {
            name: _git_show(_EXTERNAL_CHECKOUT_ROOT, _EXTERNAL_COMMIT, relative_path)
            for name, relative_path in _EXTERNAL_MODULE_PATHS.items()
        }
    except subprocess.CalledProcessError as exc:
        test_case.skipTest(
            f'external diffusion_policy commit {_EXTERNAL_COMMIT} is unavailable: {exc.stderr.strip() or exc}'
        )

    with _temporary_registered_modules() as load_external:
        for name, relative_path in _EXTERNAL_MODULE_PATHS.items():
            load_external(
                name,
                sources[name],
                origin=f'{_EXTERNAL_CHECKOUT_ROOT}:{_EXTERNAL_COMMIT}:{relative_path}',
            )
        yield sys.modules['diffusion_policy.model.diffusion.imf_transformer_for_diffusion']
|
||||||
|
|
||||||
|
def _load_local_module():
    """Import the local IMF head module fresh, bypassing any cached copy."""
    importlib.invalidate_caches()
    # Drop a stale entry so the import below re-executes the module body.
    sys.modules.pop(_LOCAL_MODULE_NAME, None)
    return importlib.import_module(_LOCAL_MODULE_NAME)
|
||||||
|
|
||||||
|
|
||||||
|
class IMFTransformer1DExternalAlignmentTest(unittest.TestCase):
    """Verify the local IMFTransformer1D matches the pinned external model."""

    def _optim_group_names(self, model, groups):
        # Map each parameter object back to its dotted name so the optimizer
        # groups of two different model instances can be compared structurally.
        names_by_param = {id(param): name for name, param in model.named_parameters()}
        return [
            {names_by_param[id(param)] for param in group['params']}
            for group in groups
        ]

    def test_local_defaults_preserve_supported_attnres_config(self):
        """Constructor defaults must stay on the supported AttnRes configuration."""
        local_module = _load_local_module()
        ctor = inspect.signature(local_module.IMFTransformer1D.__init__).parameters

        self.assertEqual(ctor['backbone_type'].default, 'attnres_full')
        self.assertEqual(ctor['n_head'].default, 1)
        self.assertEqual(ctor['n_kv_head'].default, 1)
        self.assertEqual(ctor['n_cond_layers'].default, 0)
        self.assertTrue(ctor['time_as_cond'].default)
        self.assertFalse(ctor['causal_attn'].default)

    def test_attnres_full_state_dict_forward_and_optim_groups_match_external(self):
        """State-dict keys, forward outputs and optim groups must all align."""
        local_module = _load_local_module()
        with _load_external_module_or_skip(self) as external_module:
            # Small configuration so the comparison runs fast.
            config = dict(
                input_dim=4,
                output_dim=4,
                horizon=6,
                n_obs_steps=3,
                cond_dim=5,
                n_layer=2,
                n_head=1,
                n_emb=16,
                p_drop_emb=0.0,
                p_drop_attn=0.0,
                causal_attn=False,
                time_as_cond=True,
                n_cond_layers=0,
                backbone_type='attnres_full',
                n_kv_head=1,
            )

            torch.manual_seed(7)
            external_model = external_module.IMFTransformerForDiffusion(**config)
            local_model = local_module.IMFTransformer1D(**config)
            external_model.eval()
            local_model.eval()

            # Identical key sets allow loading external weights verbatim.
            external_state_dict = external_model.state_dict()
            self.assertEqual(set(local_model.state_dict().keys()), set(external_state_dict.keys()))
            local_model.load_state_dict(external_state_dict, strict=True)

            batch_size = 2
            sample = torch.randn(batch_size, config['horizon'], config['input_dim'])
            r = torch.tensor([0.1, 0.4], dtype=torch.float32)
            t = torch.tensor([0.7, 0.9], dtype=torch.float32)
            cond = torch.randn(batch_size, config['n_obs_steps'], config['cond_dim'])

            with torch.no_grad():
                external_out = external_model(sample=sample, r=r, t=t, cond=cond)
                local_out = local_model(sample=sample, r=r, t=t, cond=cond)

            self.assertEqual(local_out.shape, (batch_size, config['horizon'], config['output_dim']))
            self.assertEqual(local_out.shape, external_out.shape)
            self.assertTrue(torch.allclose(local_out, external_out, atol=1e-6, rtol=1e-5))

            # Optimizer grouping must also match by parameter name.
            weight_decay = 0.123
            external_groups = external_model.get_optim_groups(weight_decay=weight_decay)
            local_groups = local_model.get_optim_groups(weight_decay=weight_decay)

            self.assertEqual(len(local_groups), len(external_groups))
            self.assertEqual([group['weight_decay'] for group in local_groups], [weight_decay, 0.0])
            self.assertEqual(
                self._optim_group_names(local_model, local_groups),
                self._optim_group_names(external_model, external_groups),
            )


if __name__ == '__main__':
    unittest.main()
|
||||||
427
tests/test_imf_vla_agent.py
Normal file
427
tests/test_imf_vla_agent.py
Normal file
@@ -0,0 +1,427 @@
|
|||||||
|
import contextlib
|
||||||
|
import importlib
|
||||||
|
import sys
|
||||||
|
import types
|
||||||
|
import unittest
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from hydra import compose, initialize_config_dir
|
||||||
|
from hydra.core.global_hydra import GlobalHydra
|
||||||
|
from hydra.utils import instantiate
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
|
||||||
|
_REPO_ROOT = Path(__file__).resolve().parents[1]
|
||||||
|
_CONFIG_DIR = str((_REPO_ROOT / 'roboimi/vla/conf').resolve())
|
||||||
|
_MISSING = object()
|
||||||
|
_CAMERA_NAMES = ('r_vis', 'top', 'front')
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeScheduler:
|
||||||
|
def __init__(self, num_train_timesteps=100, **kwargs):
|
||||||
|
self.config = types.SimpleNamespace(num_train_timesteps=num_train_timesteps)
|
||||||
|
self.timesteps = []
|
||||||
|
|
||||||
|
def add_noise(self, sample, noise, timestep):
|
||||||
|
return sample + noise
|
||||||
|
|
||||||
|
def set_timesteps(self, num_inference_steps):
|
||||||
|
self.timesteps = list(range(num_inference_steps - 1, -1, -1))
|
||||||
|
|
||||||
|
def step(self, noise_pred, timestep, sample):
|
||||||
|
return types.SimpleNamespace(prev_sample=sample)
|
||||||
|
|
||||||
|
|
||||||
|
class _IdentityCrop:
|
||||||
|
def __init__(self, size):
|
||||||
|
self.size = size
|
||||||
|
|
||||||
|
def __call__(self, x):
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeResNet(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.conv1 = nn.Conv2d(3, 8, kernel_size=3, padding=1)
|
||||||
|
self.relu1 = nn.ReLU()
|
||||||
|
self.conv2 = nn.Conv2d(8, 16, kernel_size=3, padding=1, stride=2)
|
||||||
|
self.relu2 = nn.ReLU()
|
||||||
|
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
||||||
|
self.fc = nn.Linear(16, 16)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.relu1(self.conv1(x))
|
||||||
|
x = self.relu2(self.conv2(x))
|
||||||
|
x = self.avgpool(x)
|
||||||
|
x = torch.flatten(x, start_dim=1)
|
||||||
|
return self.fc(x)
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeRearrange(nn.Module):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class _StubIMFHead(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
input_dim,
|
||||||
|
output_dim,
|
||||||
|
horizon,
|
||||||
|
n_obs_steps,
|
||||||
|
cond_dim,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.constructor_kwargs = {
|
||||||
|
'input_dim': input_dim,
|
||||||
|
'output_dim': output_dim,
|
||||||
|
'horizon': horizon,
|
||||||
|
'n_obs_steps': n_obs_steps,
|
||||||
|
'cond_dim': cond_dim,
|
||||||
|
**kwargs,
|
||||||
|
}
|
||||||
|
self.proj = nn.Linear(input_dim, output_dim)
|
||||||
|
self.cond_obs_emb = nn.Linear(cond_dim, max(cond_dim, 1))
|
||||||
|
|
||||||
|
def forward(self, sample, r, t, cond=None):
|
||||||
|
return torch.zeros_like(sample)
|
||||||
|
|
||||||
|
def get_optim_groups(self, weight_decay):
|
||||||
|
return [
|
||||||
|
{'params': [self.proj.weight], 'weight_decay': weight_decay},
|
||||||
|
{'params': [self.proj.bias, self.cond_obs_emb.weight, self.cond_obs_emb.bias], 'weight_decay': 0.0},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@contextlib.contextmanager
def _stub_optional_modules(include_imf_head=False):
    """Temporarily register fake diffusers / torchvision / einops modules.

    Lets the agent module import without its heavy optional dependencies.
    With ``include_imf_head=True`` the real IMF head module is additionally
    replaced by ``_StubIMFHead``. All sys.modules entries are restored on
    exit.
    """
    previous_modules = {}

    def inject(name, module):
        # Remember only the first observed value so restoration is exact.
        if name not in previous_modules:
            previous_modules[name] = sys.modules.get(name, _MISSING)
        sys.modules[name] = module

    # --- fake diffusers with both scheduler submodules -------------------
    diffusers_module = types.ModuleType('diffusers')
    schedulers_module = types.ModuleType('diffusers.schedulers')
    ddpm_module = types.ModuleType('diffusers.schedulers.scheduling_ddpm')
    ddim_module = types.ModuleType('diffusers.schedulers.scheduling_ddim')
    ddpm_module.DDPMScheduler = _FakeScheduler
    ddim_module.DDIMScheduler = _FakeScheduler
    diffusers_module.DDPMScheduler = _FakeScheduler
    diffusers_module.DDIMScheduler = _FakeScheduler
    diffusers_module.schedulers = schedulers_module
    schedulers_module.scheduling_ddpm = ddpm_module
    schedulers_module.scheduling_ddim = ddim_module

    # --- fake torchvision (tiny resnet + identity crops) -----------------
    torchvision_module = types.ModuleType('torchvision')
    models_module = types.ModuleType('torchvision.models')
    transforms_module = types.ModuleType('torchvision.transforms')
    models_module.resnet18 = lambda weights=None: _FakeResNet()
    transforms_module.CenterCrop = _IdentityCrop
    transforms_module.RandomCrop = _IdentityCrop
    torchvision_module.models = models_module
    torchvision_module.transforms = transforms_module

    # --- fake einops (identity rearrange) --------------------------------
    einops_module = types.ModuleType('einops')
    einops_module.rearrange = lambda x, *args, **kwargs: x
    einops_layers_module = types.ModuleType('einops.layers')
    einops_layers_torch_module = types.ModuleType('einops.layers.torch')
    einops_layers_torch_module.Rearrange = _FakeRearrange
    einops_module.layers = einops_layers_module
    einops_layers_module.torch = einops_layers_torch_module

    try:
        inject('diffusers', diffusers_module)
        inject('diffusers.schedulers', schedulers_module)
        inject('diffusers.schedulers.scheduling_ddpm', ddpm_module)
        inject('diffusers.schedulers.scheduling_ddim', ddim_module)
        inject('torchvision', torchvision_module)
        inject('torchvision.models', models_module)
        inject('torchvision.transforms', transforms_module)
        inject('einops', einops_module)
        inject('einops.layers', einops_layers_module)
        inject('einops.layers.torch', einops_layers_torch_module)

        if include_imf_head:
            import roboimi.vla.models.heads as heads_package

            # Replace the real IMF head module with the recording stub and
            # also patch the attribute on the package object, which Hydra's
            # attribute-walking import resolution uses.
            imf_head_module = types.ModuleType('roboimi.vla.models.heads.imf_transformer1d')
            imf_head_module.IMFTransformer1D = _StubIMFHead
            inject('roboimi.vla.models.heads.imf_transformer1d', imf_head_module)
            setattr(heads_package, 'imf_transformer1d', imf_head_module)

        yield
    finally:
        # Restore in reverse registration order (children before parents).
        for name, previous in reversed(list(previous_modules.items())):
            if previous is _MISSING:
                sys.modules.pop(name, None)
            else:
                sys.modules[name] = previous
|
||||||
|
|
||||||
|
|
||||||
|
def _compose_cfg(overrides=None):
    """Compose the Hydra config from the repo's conf dir with optional overrides."""
    # The config uses a ${len:...} interpolation; register it once.
    if not OmegaConf.has_resolver('len'):
        OmegaConf.register_new_resolver('len', lambda x: len(x))

    # Clear any Hydra state left behind by a previous test.
    GlobalHydra.instance().clear()
    override_list = [] if overrides is None else list(overrides)
    with initialize_config_dir(version_base=None, config_dir=_CONFIG_DIR):
        return compose(config_name='config', overrides=override_list)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_imf_agent_class():
    """Freshly import the IMF agent module with heavy optional deps stubbed."""
    with _stub_optional_modules():
        # Drop a cached copy so the module body re-executes under the stubs.
        sys.modules.pop('roboimi.vla.agent_imf', None)
        agent_module = importlib.import_module('roboimi.vla.agent_imf')
        return agent_module.IMFVLAAgent, agent_module
|
||||||
|
|
||||||
|
|
||||||
|
class _StubVisionBackbone(nn.Module):
    """Vision-encoder stand-in: one scalar feature per camera (the image mean)."""

    # Per-camera feature dimension the agent reads to size its cond vector.
    output_dim = 1

    def __init__(self, camera_names=_CAMERA_NAMES):
        super().__init__()
        self.camera_names = tuple(camera_names)
        self.num_cameras = len(self.camera_names)

    def forward(self, images):
        # Reduce each (B, T, C, H, W) camera tensor to a (B, T, 1) mean
        # feature, then concatenate cameras along the last axis.
        features = [
            images[name].mean(dim=(2, 3, 4), keepdim=False).unsqueeze(-1)
            for name in self.camera_names
        ]
        return torch.cat(features, dim=-1)
|
||||||
|
|
||||||
|
|
||||||
|
class _RecordingLinearIMFHead(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.scale = nn.Parameter(torch.tensor(0.5))
|
||||||
|
self.calls = []
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _broadcast_batch_time(value, reference):
|
||||||
|
while value.ndim < reference.ndim:
|
||||||
|
value = value.unsqueeze(-1)
|
||||||
|
return value
|
||||||
|
|
||||||
|
def forward(self, sample, r, t, cond=None):
|
||||||
|
record = {
|
||||||
|
'sample': sample.detach().clone(),
|
||||||
|
'r': r.detach().clone(),
|
||||||
|
't': t.detach().clone(),
|
||||||
|
'cond': None if cond is None else cond.detach().clone(),
|
||||||
|
}
|
||||||
|
self.calls.append(record)
|
||||||
|
cond_term = 0.0
|
||||||
|
if cond is not None:
|
||||||
|
cond_term = cond.mean(dim=(1, 2), keepdim=True)
|
||||||
|
r_b = self._broadcast_batch_time(r, sample)
|
||||||
|
t_b = self._broadcast_batch_time(t, sample)
|
||||||
|
return self.scale * sample + r_b + 2.0 * t_b + cond_term
|
||||||
|
|
||||||
|
|
||||||
|
class _ForbiddenScheduler:
|
||||||
|
def set_timesteps(self, *args, **kwargs): # pragma: no cover - only runs on regression
|
||||||
|
raise AssertionError('IMF inference should not use DDIM scheduler set_timesteps')
|
||||||
|
|
||||||
|
def step(self, *args, **kwargs): # pragma: no cover - only runs on regression
|
||||||
|
raise AssertionError('IMF inference should not use DDIM scheduler step')
|
||||||
|
|
||||||
|
|
||||||
|
def _make_images(batch_size, obs_horizon, per_camera_fill):
|
||||||
|
return {
|
||||||
|
name: torch.full((batch_size, obs_horizon, 1, 2, 2), fill_value=value, dtype=torch.float32)
|
||||||
|
for name, value in per_camera_fill.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class IMFVLAAgentTest(unittest.TestCase):
    """End-to-end checks of the IMF VLA agent's loss, sampling and queueing."""

    def _make_agent(self, pred_horizon=3, obs_horizon=2, num_action_steps=2):
        # Build the agent with stub vision/head components so outputs are
        # analytically predictable.
        agent_cls, agent_module = _load_imf_agent_class()
        head = _RecordingLinearIMFHead()
        agent = agent_cls(
            vision_backbone=_StubVisionBackbone(),
            state_encoder=nn.Identity(),
            action_encoder=nn.Identity(),
            head=head,
            action_dim=2,
            obs_dim=1,
            pred_horizon=pred_horizon,
            obs_horizon=obs_horizon,
            diffusion_steps=10,
            inference_steps=1,
            num_cams=len(_CAMERA_NAMES),
            camera_names=_CAMERA_NAMES,
            num_action_steps=num_action_steps,
            head_type='transformer',
        )
        return agent, head, agent_module

    def test_compute_loss_matches_imf_objective_and_masks_padded_actions(self):
        """Loss must equal the hand-computed IMF compound-velocity objective."""
        agent, head, agent_module = self._make_agent(pred_horizon=3, obs_horizon=2)
        images = _make_images(
            batch_size=1,
            obs_horizon=2,
            per_camera_fill={'r_vis': 1.0, 'top': 2.0, 'front': 3.0},
        )
        qpos = torch.tensor([[[0.25], [0.75]]], dtype=torch.float32)
        actions = torch.tensor(
            [[[1.0, -1.0], [0.5, 0.25], [-0.5, 1.5]]],
            dtype=torch.float32,
        )
        action_is_pad = torch.tensor([[False, False, True]])
        noise = torch.tensor(
            [[[0.2, -0.4], [0.1, 0.3], [0.5, -0.2]]],
            dtype=torch.float32,
        )
        t_sample = torch.tensor([0.8], dtype=torch.float32)
        r_sample = torch.tensor([0.25], dtype=torch.float32)

        # Pin the agent's random draws so the expected loss is deterministic.
        with mock.patch.object(agent_module.torch, 'randn_like', return_value=noise), \
                mock.patch.object(agent_module.torch, 'rand', side_effect=[t_sample, r_sample]):
            loss = agent.compute_loss(
                {
                    'images': images,
                    'qpos': qpos,
                    'action': actions,
                    'action_is_pad': action_is_pad,
                }
            )

        # Re-derive the IMF objective by hand from the stub head's formula.
        cond = torch.tensor([[[1.0, 2.0, 3.0, 0.25], [1.0, 2.0, 3.0, 0.75]]], dtype=torch.float32)
        cond_term = cond.mean(dim=(1, 2), keepdim=True)
        t = t_sample
        r = r_sample
        z_t = (1 - t.view(1, 1, 1)) * actions + t.view(1, 1, 1) * noise
        scale = head.scale.detach()
        u = scale * z_t + r.view(1, 1, 1) + 2.0 * t.view(1, 1, 1) + cond_term
        v = scale * z_t + 3.0 * t.view(1, 1, 1) + cond_term
        du_dt = scale * v + 2.0
        compound_velocity = u + (t - r).view(1, 1, 1) * du_dt
        target = noise - actions
        elementwise_loss = (compound_velocity - target) ** 2
        # Padded action steps are masked out of the mean.
        mask = (~action_is_pad).unsqueeze(-1).to(elementwise_loss.dtype)
        expected_loss = (elementwise_loss * mask).sum() / (mask.sum() * elementwise_loss.shape[-1])

        self.assertAlmostEqual(loss.item(), expected_loss.item(), places=6)
        self.assertEqual(len(head.calls), 2)
        self.assertTrue(torch.allclose(head.calls[0]['r'], t_sample))
        self.assertTrue(torch.allclose(head.calls[0]['t'], t_sample))
        self.assertTrue(torch.allclose(head.calls[0]['cond'], cond))

    def test_predict_action_uses_one_step_imf_sampling_and_image_conditioning(self):
        """Inference must do a single IMF step (r=0, t=1), never DDIM stepping."""
        agent, head, agent_module = self._make_agent(pred_horizon=3, obs_horizon=2)
        # Any DDIM scheduler usage would raise via the forbidden scheduler.
        agent.infer_scheduler = _ForbiddenScheduler()

        images = _make_images(
            batch_size=2,
            obs_horizon=2,
            per_camera_fill={'r_vis': 10.0, 'top': 20.0, 'front': 30.0},
        )
        qpos = torch.tensor(
            [
                [[1.0], [2.0]],
                [[3.0], [4.0]],
            ],
            dtype=torch.float32,
        )
        initial_noise = torch.tensor(
            [
                [[1.0, -1.0], [0.0, 2.0], [3.0, -2.0]],
                [[-1.0, 1.0], [2.0, -3.0], [0.5, 0.25]],
            ],
            dtype=torch.float32,
        )

        with mock.patch.object(agent_module.torch, 'randn', return_value=initial_noise):
            predicted_actions = agent.predict_action(images, qpos)

        # Expected cond: per-camera mean features concatenated with qpos.
        expected_cond = torch.tensor(
            [
                [[10.0, 20.0, 30.0, 1.0], [10.0, 20.0, 30.0, 2.0]],
                [[10.0, 20.0, 30.0, 3.0], [10.0, 20.0, 30.0, 4.0]],
            ],
            dtype=torch.float32,
        )
        cond_term = expected_cond.mean(dim=(1, 2), keepdim=True)
        # One-step IMF: x0 = noise - u(noise, r=0, t=1).
        expected_actions = 0.5 * initial_noise - 2.0 - cond_term

        self.assertEqual(predicted_actions.shape, (2, 3, 2))
        self.assertTrue(torch.allclose(predicted_actions, expected_actions))
        self.assertEqual(len(head.calls), 1)
        self.assertTrue(torch.allclose(head.calls[0]['r'], torch.zeros(2)))
        self.assertTrue(torch.allclose(head.calls[0]['t'], torch.ones(2)))
        self.assertTrue(torch.allclose(head.calls[0]['cond'], expected_cond))

    def test_select_action_only_regenerates_when_action_queue_is_empty(self):
        """A new chunk is predicted only after the queued actions run out."""
        agent, _head, _agent_module = self._make_agent(pred_horizon=4, obs_horizon=2, num_action_steps=2)
        observation = {
            'qpos': torch.tensor([0.25], dtype=torch.float32),
            'images': {
                'front': torch.full((1, 2, 2), 3.0, dtype=torch.float32),
                'top': torch.full((1, 2, 2), 2.0, dtype=torch.float32),
                'r_vis': torch.full((1, 2, 2), 1.0, dtype=torch.float32),
            },
        }
        first_chunk = torch.tensor(
            [[[10.0, 11.0], [12.0, 13.0], [14.0, 15.0], [16.0, 17.0]]],
            dtype=torch.float32,
        )
        second_chunk = torch.tensor(
            [[[20.0, 21.0], [22.0, 23.0], [24.0, 25.0], [26.0, 27.0]]],
            dtype=torch.float32,
        )

        with mock.patch.object(agent, 'predict_action_chunk', side_effect=[first_chunk, second_chunk]) as mock_predict_chunk:
            first_action = agent.select_action(observation)
            second_action = agent.select_action(observation)
            third_action = agent.select_action(observation)

        # NOTE(review): actions are consumed starting at chunk index 1 —
        # confirm this offset matches the agent's queueing convention.
        self.assertTrue(torch.equal(first_action, first_chunk[0, 1]))
        self.assertTrue(torch.equal(second_action, first_chunk[0, 2]))
        self.assertTrue(torch.equal(third_action, second_chunk[0, 1]))
        self.assertEqual(mock_predict_chunk.call_count, 2)

    def test_hydra_config_instantiates_resnet_imf_attnres_with_stub_head(self):
        """The resnet_imf_attnres config must wire the IMF agent and head correctly."""
        cfg = _compose_cfg(
            overrides=[
                'agent=resnet_imf_attnres',
                'agent.vision_backbone.pretrained_backbone_weights=null',
                'agent.vision_backbone.input_shape=[3,16,16]',
                'agent.vision_backbone.freeze_backbone=false',
                'agent.head.n_layer=1',
                'agent.head.n_emb=16',
            ]
        )

        self.assertEqual(cfg.agent._target_, 'roboimi.vla.agent_imf.IMFVLAAgent')
        self.assertEqual(cfg.agent.head._target_, 'roboimi.vla.models.heads.imf_transformer1d.IMFTransformer1D')
        self.assertEqual(cfg.agent.head.backbone_type, 'attnres_full')
        self.assertEqual(cfg.agent.head.n_head, 1)
        self.assertEqual(cfg.agent.head.n_kv_head, 1)
        self.assertEqual(cfg.agent.head.n_cond_layers, 0)
        self.assertTrue(cfg.agent.head.time_as_cond)
        self.assertFalse(cfg.agent.head.causal_attn)
        self.assertEqual(cfg.agent.inference_steps, 1)
        self.assertEqual(list(cfg.agent.camera_names), list(_CAMERA_NAMES))

        # Instantiate with the stubbed head to avoid heavy dependencies.
        with _stub_optional_modules(include_imf_head=True):
            agent = instantiate(cfg.agent)

        self.assertEqual(agent.head_type, 'transformer')
        self.assertEqual(agent.per_step_cond_dim, agent.vision_encoder.output_dim * agent.num_cams + agent.obs_dim)
        self.assertIsInstance(agent.noise_pred_net, _StubIMFHead)
        self.assertEqual(agent.noise_pred_net.constructor_kwargs['cond_dim'], agent.per_step_cond_dim)
        self.assertEqual(agent.noise_pred_net.constructor_kwargs['backbone_type'], 'attnres_full')


if __name__ == '__main__':
    unittest.main()
|
||||||
@@ -101,10 +101,19 @@ class RecordingTransformerHead(nn.Module):
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class FakeTransformerAgent(nn.Module):
|
class FakeIMFAgent(nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.head_type = 'transformer'
|
self.head_type = 'imf_transformer'
|
||||||
|
self.noise_pred_net = RecordingTransformerHead()
|
||||||
|
self.backbone = nn.Linear(4, 3)
|
||||||
|
self.adapter = nn.Linear(3, 2, bias=False)
|
||||||
|
|
||||||
|
|
||||||
|
class FakeTransformerAgent(nn.Module):
|
||||||
|
def __init__(self, *, head_type='transformer'):
|
||||||
|
super().__init__()
|
||||||
|
self.head_type = head_type
|
||||||
self.noise_pred_net = RecordingTransformerHead()
|
self.noise_pred_net = RecordingTransformerHead()
|
||||||
self.backbone = nn.Linear(4, 3)
|
self.backbone = nn.Linear(4, 3)
|
||||||
self.adapter = nn.Linear(3, 2, bias=False)
|
self.adapter = nn.Linear(3, 2, bias=False)
|
||||||
@@ -205,6 +214,47 @@ class TrainVLATransformerOptimizerTest(unittest.TestCase):
|
|||||||
for group in optimizer.param_groups
|
for group in optimizer.param_groups
|
||||||
]
|
]
|
||||||
|
|
||||||
|
def test_configure_cuda_runtime_can_disable_cudnn_for_training(self):
|
||||||
|
module = self._load_train_vla_module()
|
||||||
|
cfg = AttrDict(train=AttrDict(device='cuda', disable_cudnn=True))
|
||||||
|
|
||||||
|
original = module.torch.backends.cudnn.enabled
|
||||||
|
try:
|
||||||
|
module.torch.backends.cudnn.enabled = True
|
||||||
|
module._configure_cuda_runtime(cfg)
|
||||||
|
self.assertFalse(module.torch.backends.cudnn.enabled)
|
||||||
|
finally:
|
||||||
|
module.torch.backends.cudnn.enabled = original
|
||||||
|
|
||||||
|
|
||||||
|
def test_train_script_uses_file_based_repo_root_on_sys_path(self):
|
||||||
|
module = self._load_train_vla_module()
|
||||||
|
|
||||||
|
fake_sys_path = ['/tmp/site-packages', '/another/path']
|
||||||
|
with mock.patch.object(module.sys, 'path', fake_sys_path):
|
||||||
|
repo_root = module._ensure_repo_root_on_syspath()
|
||||||
|
|
||||||
|
self.assertEqual(Path(repo_root).resolve(), _REPO_ROOT.resolve())
|
||||||
|
self.assertEqual(Path(fake_sys_path[0]).resolve(), _REPO_ROOT.resolve())
|
||||||
|
|
||||||
|
|
||||||
|
def test_non_transformer_head_with_get_optim_groups_still_uses_custom_groups(self):
|
||||||
|
module = self._load_train_vla_module()
|
||||||
|
agent = FakeIMFAgent()
|
||||||
|
|
||||||
|
optimizer = module.build_training_optimizer(agent, lr=1e-4, weight_decay=0.123)
|
||||||
|
|
||||||
|
self.assertEqual(agent.noise_pred_net.optim_group_calls, [0.123])
|
||||||
|
group_names = self._group_names(agent, optimizer)
|
||||||
|
self.assertEqual(group_names[0], {'noise_pred_net.proj.weight'})
|
||||||
|
self.assertEqual(group_names[1], {
|
||||||
|
'noise_pred_net.proj.bias',
|
||||||
|
'noise_pred_net.norm.weight',
|
||||||
|
'noise_pred_net.norm.bias',
|
||||||
|
})
|
||||||
|
self.assertEqual(group_names[2], {'backbone.weight', 'backbone.bias', 'adapter.weight'})
|
||||||
|
|
||||||
|
|
||||||
def test_transformer_training_prefers_head_optim_groups_and_keeps_remaining_trainable_params(self):
|
def test_transformer_training_prefers_head_optim_groups_and_keeps_remaining_trainable_params(self):
|
||||||
module = self._load_train_vla_module()
|
module = self._load_train_vla_module()
|
||||||
agent = FakeTransformerAgent()
|
agent = FakeTransformerAgent()
|
||||||
@@ -268,6 +318,22 @@ class TrainVLATransformerOptimizerTest(unittest.TestCase):
|
|||||||
self.assertNotIn('frozen.weight', optimizer_names)
|
self.assertNotIn('frozen.weight', optimizer_names)
|
||||||
self.assertNotIn('frozen.bias', optimizer_names)
|
self.assertNotIn('frozen.bias', optimizer_names)
|
||||||
|
|
||||||
|
def test_any_head_with_get_optim_groups_uses_custom_groups_even_without_transformer_head_type(self):
|
||||||
|
module = self._load_train_vla_module()
|
||||||
|
agent = FakeTransformerAgent(head_type='imf')
|
||||||
|
|
||||||
|
with mock.patch.object(module, 'AdamW', RecordingAdamW):
|
||||||
|
optimizer = module.build_training_optimizer(agent, lr=1e-4, weight_decay=0.123)
|
||||||
|
|
||||||
|
self.assertEqual(agent.noise_pred_net.optim_group_calls, [0.123])
|
||||||
|
grouped_names = self._group_names(agent, optimizer)
|
||||||
|
self.assertEqual(grouped_names[0], {'noise_pred_net.proj.weight'})
|
||||||
|
self.assertEqual(
|
||||||
|
grouped_names[1],
|
||||||
|
{'noise_pred_net.proj.bias', 'noise_pred_net.norm.weight', 'noise_pred_net.norm.bias'},
|
||||||
|
)
|
||||||
|
self.assertEqual(grouped_names[2], {'backbone.weight', 'backbone.bias', 'adapter.weight'})
|
||||||
|
|
||||||
def test_transformer_optimizer_ignores_frozen_head_params_returned_by_head_groups(self):
|
def test_transformer_optimizer_ignores_frozen_head_params_returned_by_head_groups(self):
|
||||||
module = self._load_train_vla_module()
|
module = self._load_train_vla_module()
|
||||||
agent = FakeTransformerAgent()
|
agent = FakeTransformerAgent()
|
||||||
|
|||||||
Reference in New Issue
Block a user