feat: add pusht image imf transformer

Author: Logic
Date: 2026-03-26 20:41:37 +08:00
Parent: 5e7ae6cfa5 · Commit: 4cd5085b33
5 changed files with 960 additions and 0 deletions

diffusion_policy/model/diffusion/imf_transformer_for_diffusion.py
@@ -0,0 +1,298 @@
from typing import Optional, Tuple, Union
import logging
import torch
import torch.nn as nn
from diffusion_policy.model.common.module_attr_mixin import ModuleAttrMixin
from diffusion_policy.model.diffusion.positional_embedding import SinusoidalPosEmb
logger = logging.getLogger(__name__)
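# Variant of diffusion_policy's TransformerForDiffusion that conditions on two
# flow times (r, t) rather than a single diffusion timestep, so the cond
# sequence reserves two time tokens.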
class IMFTransformerForDiffusion(ModuleAttrMixin):
def __init__(
self,
input_dim: int,
output_dim: int,
horizon: int,
        n_obs_steps: Optional[int] = None,
cond_dim: int = 0,
n_layer: int = 12,
n_head: int = 1,
n_emb: int = 768,
p_drop_emb: float = 0.1,
p_drop_attn: float = 0.1,
causal_attn: bool = False,
time_as_cond: bool = True,
obs_as_cond: bool = False,
n_cond_layers: int = 0,
) -> None:
super().__init__()
assert n_head == 1, 'IMFTransformerForDiffusion currently supports single-head attention only.'
if n_obs_steps is None:
n_obs_steps = horizon
T = horizon
T_cond = 2
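        # Two conditioning tokens are reserved for the (r, t) time embeddings;
        # if time is not treated as cond, both tokens are prepended to the
        # input sequence instead (T += 2) and the cond sequence shrinks to 0.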
if not time_as_cond:
T += 2
T_cond -= 2
obs_as_cond = cond_dim > 0
if obs_as_cond:
assert time_as_cond
T_cond += n_obs_steps
self.input_emb = nn.Linear(input_dim, n_emb)
self.pos_emb = nn.Parameter(torch.zeros(1, T, n_emb))
self.drop = nn.Dropout(p_drop_emb)
self.time_emb = SinusoidalPosEmb(n_emb)
self.cond_obs_emb = None
if obs_as_cond:
self.cond_obs_emb = nn.Linear(cond_dim, n_emb)
self.cond_pos_emb = None
self.encoder = None
self.decoder = None
encoder_only = False
if T_cond > 0:
self.cond_pos_emb = nn.Parameter(torch.zeros(1, T_cond, n_emb))
if n_cond_layers > 0:
encoder_layer = nn.TransformerEncoderLayer(
d_model=n_emb,
nhead=n_head,
dim_feedforward=4 * n_emb,
dropout=p_drop_attn,
activation='gelu',
batch_first=True,
norm_first=True,
)
self.encoder = nn.TransformerEncoder(
encoder_layer=encoder_layer,
num_layers=n_cond_layers,
)
else:
self.encoder = nn.Sequential(
nn.Linear(n_emb, 4 * n_emb),
nn.Mish(),
nn.Linear(4 * n_emb, n_emb),
)
decoder_layer = nn.TransformerDecoderLayer(
d_model=n_emb,
nhead=n_head,
dim_feedforward=4 * n_emb,
dropout=p_drop_attn,
activation='gelu',
batch_first=True,
norm_first=True,
)
self.decoder = nn.TransformerDecoder(
decoder_layer=decoder_layer,
num_layers=n_layer,
)
else:
encoder_only = True
encoder_layer = nn.TransformerEncoderLayer(
d_model=n_emb,
nhead=n_head,
dim_feedforward=4 * n_emb,
dropout=p_drop_attn,
activation='gelu',
batch_first=True,
norm_first=True,
)
self.encoder = nn.TransformerEncoder(
encoder_layer=encoder_layer,
num_layers=n_layer,
)
if causal_attn:
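            # Causal self-attention mask: each position attends only to itself
            # and earlier positions ('-inf' above the diagonal).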
sz = T
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
self.register_buffer('mask', mask)
if time_as_cond and obs_as_cond:
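                # Cross-attention mask over the cond sequence [r, t, obs...]:
                # the two leading time tokens are always visible (hence the -2
                # offset below), and obs token j is visible to target step i
                # only when j <= i.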
S = T_cond
t_idx, s_idx = torch.meshgrid(
torch.arange(T),
torch.arange(S),
indexing='ij',
)
mask = t_idx >= (s_idx - 2)
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
self.register_buffer('memory_mask', mask)
else:
self.memory_mask = None
else:
self.mask = None
self.memory_mask = None
self.ln_f = nn.LayerNorm(n_emb)
self.head = nn.Linear(n_emb, output_dim)
self.T = T
self.T_cond = T_cond
self.horizon = horizon
self.time_as_cond = time_as_cond
self.obs_as_cond = obs_as_cond
self.encoder_only = encoder_only
self.apply(self._init_weights)
logger.info(
'number of parameters: %e',
sum(p.numel() for p in self.parameters()),
)
def _init_weights(self, module):
ignore_types = (
nn.Dropout,
SinusoidalPosEmb,
nn.TransformerEncoderLayer,
nn.TransformerDecoderLayer,
nn.TransformerEncoder,
nn.TransformerDecoder,
nn.ModuleList,
nn.Mish,
nn.Sequential,
)
if isinstance(module, (nn.Linear, nn.Embedding)):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
if isinstance(module, nn.Linear) and module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.MultiheadAttention):
weight_names = ['in_proj_weight', 'q_proj_weight', 'k_proj_weight', 'v_proj_weight']
for name in weight_names:
weight = getattr(module, name)
if weight is not None:
torch.nn.init.normal_(weight, mean=0.0, std=0.02)
bias_names = ['in_proj_bias', 'bias_k', 'bias_v']
for name in bias_names:
bias = getattr(module, name)
if bias is not None:
torch.nn.init.zeros_(bias)
elif isinstance(module, nn.LayerNorm):
torch.nn.init.zeros_(module.bias)
torch.nn.init.ones_(module.weight)
elif isinstance(module, IMFTransformerForDiffusion):
torch.nn.init.normal_(module.pos_emb, mean=0.0, std=0.02)
            if module.cond_pos_emb is not None:
                torch.nn.init.normal_(module.cond_pos_emb, mean=0.0, std=0.02)
elif isinstance(module, ignore_types):
pass
else:
raise RuntimeError(f'Unaccounted module {module}')
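    # GPT-style weight-decay split: linear/attention weights decay; biases,
    # LayerNorm parameters, and positional embeddings do not.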
def get_optim_groups(self, weight_decay: float = 1e-3):
decay = set()
no_decay = set()
whitelist_weight_modules = (torch.nn.Linear, torch.nn.MultiheadAttention)
blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
for mn, m in self.named_modules():
for pn, _ in m.named_parameters():
fpn = '%s.%s' % (mn, pn) if mn else pn
if pn.endswith('bias'):
no_decay.add(fpn)
elif pn.startswith('bias'):
no_decay.add(fpn)
elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
decay.add(fpn)
elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
no_decay.add(fpn)
no_decay.add('pos_emb')
no_decay.add('_dummy_variable')
if self.cond_pos_emb is not None:
no_decay.add('cond_pos_emb')
param_dict = {pn: p for pn, p in self.named_parameters()}
inter_params = decay & no_decay
union_params = decay | no_decay
assert len(inter_params) == 0, f'parameters {inter_params} made it into both decay/no_decay sets!'
assert len(param_dict.keys() - union_params) == 0, (
f'parameters {param_dict.keys() - union_params} were not separated into either decay/no_decay sets!'
)
optim_groups = [
{
'params': [param_dict[pn] for pn in sorted(list(decay))],
'weight_decay': weight_decay,
},
{
'params': [param_dict[pn] for pn in sorted(list(no_decay))],
'weight_decay': 0.0,
},
]
return optim_groups
def configure_optimizers(
self,
learning_rate: float = 1e-4,
weight_decay: float = 1e-3,
betas: Tuple[float, float] = (0.9, 0.95),
):
optim_groups = self.get_optim_groups(weight_decay=weight_decay)
optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas)
return optimizer
def _prepare_time_input(self, value: Union[torch.Tensor, float, int], sample: torch.Tensor) -> torch.Tensor:
if not torch.is_tensor(value):
value = torch.tensor([value], dtype=sample.dtype, device=sample.device)
elif value.ndim == 0:
value = value[None].to(device=sample.device, dtype=sample.dtype)
else:
value = value.to(device=sample.device, dtype=sample.dtype)
return value.expand(sample.shape[0])
def forward(
self,
sample: torch.Tensor,
r: Union[torch.Tensor, float, int],
t: Union[torch.Tensor, float, int],
cond: Optional[torch.Tensor] = None,
):
r = self._prepare_time_input(r, sample)
t = self._prepare_time_input(t, sample)
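        # Each flow time becomes a single conditioning token of shape (B, 1, n_emb).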
r_emb = self.time_emb(r).unsqueeze(1)
t_emb = self.time_emb(t).unsqueeze(1)
input_emb = self.input_emb(sample)
if self.encoder_only:
token_embeddings = torch.cat([r_emb, t_emb, input_emb], dim=1)
token_count = token_embeddings.shape[1]
position_embeddings = self.pos_emb[:, :token_count, :]
x = self.drop(token_embeddings + position_embeddings)
x = self.encoder(src=x, mask=self.mask)
x = x[:, 2:, :]
else:
cond_embeddings = torch.cat([r_emb, t_emb], dim=1)
if self.obs_as_cond:
cond_obs_emb = self.cond_obs_emb(cond)
cond_embeddings = torch.cat([cond_embeddings, cond_obs_emb], dim=1)
token_count = cond_embeddings.shape[1]
position_embeddings = self.cond_pos_emb[:, :token_count, :]
x = self.drop(cond_embeddings + position_embeddings)
x = self.encoder(x)
memory = x
token_embeddings = input_emb
token_count = token_embeddings.shape[1]
position_embeddings = self.pos_emb[:, :token_count, :]
x = self.drop(token_embeddings + position_embeddings)
x = self.decoder(
tgt=x,
memory=memory,
tgt_mask=self.mask,
memory_mask=self.memory_mask,
)
x = self.ln_f(x)
x = self.head(x)
return x
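
A quick shape check of the model's two-time call signature (a standalone sketch, not part of the commit; sizes are hypothetical):

import torch
from diffusion_policy.model.diffusion.imf_transformer_for_diffusion import (
    IMFTransformerForDiffusion,
)

model = IMFTransformerForDiffusion(
    input_dim=2, output_dim=2, horizon=10, n_obs_steps=2,
    cond_dim=66, n_layer=4, n_head=1, n_emb=128,
)
sample = torch.randn(8, 10, 2)   # (B, horizon, input_dim)
cond = torch.randn(8, 2, 66)     # (B, n_obs_steps, cond_dim)
out = model(sample, r=0.0, t=1.0, cond=cond)
assert out.shape == (8, 10, 2)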

@@ -0,0 +1,273 @@
from contextlib import nullcontext
from typing import Dict, Optional
import torch
import torch.nn.functional as F
from einops import reduce
from diffusion_policy.common.pytorch_util import dict_apply
from diffusion_policy.model.diffusion.imf_transformer_for_diffusion import IMFTransformerForDiffusion
from diffusion_policy.policy.diffusion_transformer_hybrid_image_policy import (
DiffusionTransformerHybridImagePolicy,
)
try:
from torch.func import jvp as TORCH_FUNC_JVP
except ImportError: # pragma: no cover - depends on torch version
TORCH_FUNC_JVP = None
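# One-step flow policy: the transformer predicts an average velocity u(z, r, t)
# over the interval [r, t]; sampling takes a single step z_r = z_t - (t - r) * u,
# and training regresses u + (t - r) * du/dt onto the instantaneous velocity.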
class IMFTransformerHybridImagePolicy(DiffusionTransformerHybridImagePolicy):
def __init__(
self,
shape_meta: dict,
noise_scheduler,
horizon,
n_action_steps,
n_obs_steps,
num_inference_steps=None,
crop_shape=(76, 76),
obs_encoder_group_norm=False,
eval_fixed_crop=False,
n_layer=8,
n_cond_layers=0,
n_head=1,
n_emb=256,
p_drop_emb=0.0,
p_drop_attn=0.3,
causal_attn=True,
time_as_cond=True,
obs_as_cond=True,
pred_action_steps_only=False,
**kwargs,
):
if num_inference_steps is None:
num_inference_steps = 1
elif num_inference_steps != 1:
raise ValueError(
'IMFTransformerHybridImagePolicy only supports one-step inference; '
f'num_inference_steps must be 1, got {num_inference_steps}.'
)
super().__init__(
shape_meta=shape_meta,
noise_scheduler=noise_scheduler,
horizon=horizon,
n_action_steps=n_action_steps,
n_obs_steps=n_obs_steps,
num_inference_steps=num_inference_steps,
crop_shape=crop_shape,
obs_encoder_group_norm=obs_encoder_group_norm,
eval_fixed_crop=eval_fixed_crop,
n_layer=n_layer,
n_cond_layers=n_cond_layers,
n_head=n_head,
n_emb=n_emb,
p_drop_emb=p_drop_emb,
p_drop_attn=p_drop_attn,
causal_attn=causal_attn,
time_as_cond=time_as_cond,
obs_as_cond=obs_as_cond,
pred_action_steps_only=pred_action_steps_only,
**kwargs,
)
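        # The parent constructor builds a single-timestep TransformerForDiffusion;
        # replace it below with the two-time IMF variant using the same I/O dims.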
input_dim = self.action_dim if self.obs_as_cond else (self.obs_feature_dim + self.action_dim)
cond_dim = self.obs_feature_dim if self.obs_as_cond else 0
model_horizon = self.n_action_steps if self.pred_action_steps_only else horizon
self.model = IMFTransformerForDiffusion(
input_dim=input_dim,
output_dim=input_dim,
horizon=model_horizon,
n_obs_steps=n_obs_steps,
cond_dim=cond_dim,
n_layer=n_layer,
n_head=n_head,
n_emb=n_emb,
p_drop_emb=p_drop_emb,
p_drop_attn=p_drop_attn,
causal_attn=causal_attn,
time_as_cond=time_as_cond,
obs_as_cond=obs_as_cond,
n_cond_layers=n_cond_layers,
)
self.num_inference_steps = 1
def fn(self, z: torch.Tensor, r: torch.Tensor, t: torch.Tensor, cond=None) -> torch.Tensor:
return self.model(z, r, t, cond=cond)
@staticmethod
def _broadcast_batch_time(value: torch.Tensor, reference: torch.Tensor) -> torch.Tensor:
while value.ndim < reference.ndim:
value = value.unsqueeze(-1)
return value
@staticmethod
def _apply_conditioning(
trajectory: torch.Tensor,
condition_data: Optional[torch.Tensor] = None,
condition_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if condition_data is None or condition_mask is None:
return trajectory
conditioned = trajectory.clone()
conditioned[condition_mask] = condition_data[condition_mask]
return conditioned
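    # PyTorch's flash / memory-efficient SDPA kernels do not support
    # forward-mode autograd, so JVPs are forced onto the math kernel on CUDA.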
@staticmethod
def _jvp_math_sdp_context(z_t: torch.Tensor):
if z_t.is_cuda:
return torch.backends.cuda.sdp_kernel(
enable_flash=False,
enable_math=True,
enable_mem_efficient=False,
enable_cudnn=False,
)
return nullcontext()
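    # Tangent direction (dz, dr, dt) = (v, 0, 1): differentiate the model along
    # the trajectory in t with r held fixed and v as the state velocity.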
@staticmethod
def _jvp_tangents(v: torch.Tensor, r: torch.Tensor, t: torch.Tensor):
return v.detach(), torch.zeros_like(r), torch.ones_like(t)
def _compute_u_and_du_dt(
self,
z_t: torch.Tensor,
r: torch.Tensor,
t: torch.Tensor,
cond,
v: torch.Tensor,
condition_data: Optional[torch.Tensor] = None,
condition_mask: Optional[torch.Tensor] = None,
):
tangents = self._jvp_tangents(v, r, t)
def g(z, r_value, t_value):
conditioned_z = self._apply_conditioning(z, condition_data, condition_mask)
return self.fn(conditioned_z, r_value, t_value, cond=cond)
with self._jvp_math_sdp_context(z_t):
if TORCH_FUNC_JVP is not None:
try:
return TORCH_FUNC_JVP(g, (z_t, r, t), tangents)
except (RuntimeError, TypeError, NotImplementedError):
pass
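            # Fallback for older torch without torch.func: compute u directly,
            # then obtain du/dt via torch.autograd.functional.jvp.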
u = g(z_t, r, t)
_, du_dt = torch.autograd.functional.jvp(
g,
(z_t, r, t),
tangents,
create_graph=False,
strict=False,
)
return u, du_dt
def _compound_velocity(
self,
u: torch.Tensor,
du_dt: torch.Tensor,
r: torch.Tensor,
t: torch.Tensor,
) -> torch.Tensor:
delta = self._broadcast_batch_time(t - r, u)
return u + delta * du_dt.detach()
def _sample_one_step(
self,
z_t: torch.Tensor,
        r: Optional[torch.Tensor] = None,
        t: Optional[torch.Tensor] = None,
cond=None,
) -> torch.Tensor:
batch_size = z_t.shape[0]
if t is None:
t = torch.ones(batch_size, device=z_t.device, dtype=z_t.dtype)
if r is None:
r = torch.zeros(batch_size, device=z_t.device, dtype=z_t.dtype)
u = self.fn(z_t, r, t, cond=cond)
delta = self._broadcast_batch_time(t - r, z_t)
return z_t - delta * u
def conditional_sample(
self,
condition_data,
condition_mask,
cond=None,
generator=None,
**kwargs,
):
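        # Draw z_1 ~ N(0, I), impose the conditioning, take one average-velocity
        # step to z_0, then re-impose the conditioning on the result.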
trajectory = torch.randn(
size=condition_data.shape,
dtype=condition_data.dtype,
device=condition_data.device,
generator=generator,
)
trajectory = self._apply_conditioning(trajectory, condition_data, condition_mask)
trajectory = self._sample_one_step(trajectory, cond=cond)
trajectory = self._apply_conditioning(trajectory, condition_data, condition_mask)
return trajectory
def compute_loss(self, batch):
assert 'valid_mask' not in batch
nobs = self.normalizer.normalize(batch['obs'])
nactions = self.normalizer['action'].normalize(batch['action'])
batch_size = nactions.shape[0]
horizon = nactions.shape[1]
To = self.n_obs_steps
cond = None
trajectory = nactions
if self.obs_as_cond:
this_nobs = dict_apply(
nobs,
lambda x: x[:, :To, ...].reshape(-1, *x.shape[2:]),
)
nobs_features = self.obs_encoder(this_nobs)
cond = nobs_features.reshape(batch_size, To, -1)
if self.pred_action_steps_only:
start = To - 1
end = start + self.n_action_steps
trajectory = nactions[:, start:end]
else:
this_nobs = dict_apply(nobs, lambda x: x.reshape(-1, *x.shape[2:]))
nobs_features = self.obs_encoder(this_nobs)
nobs_features = nobs_features.reshape(batch_size, horizon, -1)
trajectory = torch.cat([nactions, nobs_features], dim=-1).detach()
if self.pred_action_steps_only:
condition_mask = torch.zeros_like(trajectory, dtype=torch.bool)
else:
condition_mask = self.mask_generator(trajectory.shape)
loss_mask = torch.zeros_like(trajectory, dtype=torch.bool)
loss_mask[..., : self.action_dim] = True
loss_mask = loss_mask & ~condition_mask
x = trajectory
e = torch.randn_like(x)
t = torch.rand(batch_size, device=x.device, dtype=x.dtype)
r = torch.rand(batch_size, device=x.device, dtype=x.dtype)
t, r = torch.maximum(t, r), torch.minimum(t, r)
t_broadcast = self._broadcast_batch_time(t, x)
z_t = (1 - t_broadcast) * x + t_broadcast * e
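        # z_t linearly interpolates data x and noise e, so the instantaneous
        # velocity dz/dt = e - x is the regression target below.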
z_t = self._apply_conditioning(z_t, x, condition_mask)
v = self.fn(z_t, t, t, cond=cond)
u, du_dt = self._compute_u_and_du_dt(
z_t,
r,
t,
cond=cond,
v=v,
condition_data=x,
condition_mask=condition_mask,
)
V = self._compound_velocity(u, du_dt, r, t)
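        # du/dt is detached inside _compound_velocity, so gradients reach the
        # model only through u; V is matched to the instantaneous velocity e - x.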
target = e - x
loss = F.mse_loss(V, target, reduction='none')
loss = loss * loss_mask.type(loss.dtype)
loss = reduce(loss, 'b ... -> b (...)', 'mean')
loss = loss.mean()
return loss
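
The one-step sampler reduces to z_r = z_t - (t - r) * u with t = 1, r = 0. A minimal standalone restatement of that update (a sketch, not part of the commit; `fn` is any average-velocity model with the same signature as IMFTransformerForDiffusion.forward):

import torch

def one_step_sample(fn, z1: torch.Tensor) -> torch.Tensor:
    B = z1.shape[0]
    t = torch.ones(B, device=z1.device, dtype=z1.dtype)   # noise end of the flow
    r = torch.zeros(B, device=z1.device, dtype=z1.dtype)  # data end of the flow
    u = fn(z1, r, t)  # predicted average velocity over [r, t]
    delta = (t - r).view(B, *([1] * (z1.ndim - 1)))
    return z1 - delta * u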