3 Commits

Author SHA1 Message Date
gameloader
42dc29a2cb feat: align pmf transformer training and config defaults 2026-03-16 15:37:32 +08:00
Logic
79f31940c4 feat(pusht): add pMF-style DiT image policy 2026-03-16 11:11:43 +08:00
Logic
2aa06c8917 fix(pusht): stabilize DiT pusht training on current stack 2026-03-15 18:54:50 +08:00
17 changed files with 1134 additions and 1776 deletions

View File

@@ -1,37 +1,21 @@
import wandb
import numpy as np
import torch
import collections
import pathlib
import tqdm
import dill
import math
import wandb.sdk.data_types.video as wv
from diffusion_policy.env.pusht.pusht_image_env import PushTImageEnv
from diffusion_policy.gym_util.sync_vector_env import SyncVectorEnv
from diffusion_policy.gym_util.multistep_wrapper import MultiStepWrapper
from diffusion_policy.gym_util.video_recording_wrapper import VideoRecordingWrapper, VideoRecorder
from diffusion_policy.policy.base_image_policy import BaseImagePolicy
from diffusion_policy.common.pytorch_util import dict_apply
from diffusion_policy.env_runner.base_image_runner import BaseImageRunner
def summarize_rollout_metrics(env_seeds, env_prefixs, all_rewards, all_video_paths=None):
del all_video_paths
max_rewards = collections.defaultdict(list)
log_data = dict()
for seed, prefix, rewards in zip(env_seeds, env_prefixs, all_rewards):
max_reward = np.max(rewards)
max_rewards[prefix].append(max_reward)
log_data[prefix + f'sim_max_reward_{seed}'] = max_reward
aggregate_key_map = {
'train/': 'train_mean_score',
'test/': 'test_mean_score',
}
for prefix, value in max_rewards.items():
log_data[aggregate_key_map.get(prefix, prefix + 'mean_score')] = np.mean(value)
return log_data
class PushTImageRunner(BaseImageRunner):
def __init__(self,
output_dir,
@@ -56,11 +40,24 @@ class PushTImageRunner(BaseImageRunner):
if n_envs is None:
n_envs = n_train + n_test
steps_per_render = max(10 // fps, 1)
def env_fn():
return MultiStepWrapper(
PushTImageEnv(
legacy=legacy_test,
render_size=render_size
VideoRecordingWrapper(
PushTImageEnv(
legacy=legacy_test,
render_size=render_size
),
video_recoder=VideoRecorder.create_h264(
fps=fps,
codec='h264',
input_pix_fmt='rgb24',
crf=crf,
thread_type='FRAME',
thread_count=1
),
file_path=None,
steps_per_render=steps_per_render
),
n_obs_steps=n_obs_steps,
n_action_steps=n_action_steps,
@@ -74,8 +71,21 @@ class PushTImageRunner(BaseImageRunner):
# train
for i in range(n_train):
seed = train_start_seed + i
enable_render = i < n_train_vis
def init_fn(env, seed=seed, enable_render=enable_render):
# setup rendering
# video_wrapper
assert isinstance(env.env, VideoRecordingWrapper)
env.env.video_recoder.stop()
env.env.file_path = None
if enable_render:
filename = pathlib.Path(output_dir).joinpath(
'media', wv.util.generate_id() + ".mp4")
filename.parent.mkdir(parents=False, exist_ok=True)
filename = str(filename)
env.env.file_path = filename
def init_fn(env, seed=seed):
# set seed
assert isinstance(env, MultiStepWrapper)
env.seed(seed)
@@ -87,8 +97,21 @@ class PushTImageRunner(BaseImageRunner):
# test
for i in range(n_test):
seed = test_start_seed + i
enable_render = i < n_test_vis
def init_fn(env, seed=seed, enable_render=enable_render):
# setup rendering
# video_wrapper
assert isinstance(env.env, VideoRecordingWrapper)
env.env.video_recoder.stop()
env.env.file_path = None
if enable_render:
filename = pathlib.Path(output_dir).joinpath(
'media', wv.util.generate_id() + ".mp4")
filename.parent.mkdir(parents=False, exist_ok=True)
filename = str(filename)
env.env.file_path = filename
def init_fn(env, seed=seed):
# set seed
assert isinstance(env, MultiStepWrapper)
env.seed(seed)
@@ -131,6 +154,7 @@ class PushTImageRunner(BaseImageRunner):
n_chunks = math.ceil(n_inits / n_envs)
# allocate data
all_video_paths = [None] * n_inits
all_rewards = [None] * n_inits
for chunk_idx in range(n_chunks):
@@ -190,16 +214,39 @@ class PushTImageRunner(BaseImageRunner):
pbar.update(action.shape[1])
pbar.close()
all_video_paths[this_global_slice] = env.render()[this_local_slice]
all_rewards[this_global_slice] = env.call('get_attr', 'reward')[this_local_slice]
# reset env state between evaluation calls
# clear out video buffer
_ = env.reset()
# results reported in the paper are generated using the commented out
# line below, which would only report and average metrics from the
# first n_envs initial conditions and seeds. We keep the full n_inits
# behavior here.
return summarize_rollout_metrics(
env_seeds=self.env_seeds[:n_inits],
env_prefixs=self.env_prefixs[:n_inits],
all_rewards=all_rewards[:n_inits],
)
# log
max_rewards = collections.defaultdict(list)
log_data = dict()
# results reported in the paper are generated using the commented out line below
# which will only report and average metrics from first n_envs initial condition and seeds
# fortunately this won't invalidate our conclusion since
# 1. This bug only affects the variance of metrics, not their mean
# 2. All baseline methods are evaluated using the same code
# to completely reproduce reported numbers, uncomment this line:
# for i in range(len(self.env_fns)):
# and comment out this line
for i in range(n_inits):
seed = self.env_seeds[i]
prefix = self.env_prefixs[i]
max_reward = np.max(all_rewards[i])
max_rewards[prefix].append(max_reward)
log_data[prefix+f'sim_max_reward_{seed}'] = max_reward
# visualize sim
video_path = all_video_paths[i]
if video_path is not None:
sim_video = wandb.Video(video_path)
log_data[prefix+f'sim_video_{seed}'] = sim_video
# log aggregate metrics
for prefix, value in max_rewards.items():
name = prefix+'mean_score'
value = np.mean(value)
log_data[name] = value
return log_data

View File

@@ -8,8 +8,7 @@ import dill
import math
import wandb.sdk.data_types.video as wv
from diffusion_policy.env.pusht.pusht_keypoints_env import PushTKeypointsEnv
from diffusion_policy.gym_util.async_vector_env import AsyncVectorEnv
# from diffusion_policy.gym_util.sync_vector_env import SyncVectorEnv
from diffusion_policy.gym_util.sync_vector_env import SyncVectorEnv
from diffusion_policy.gym_util.multistep_wrapper import MultiStepWrapper
from diffusion_policy.gym_util.video_recording_wrapper import VideoRecordingWrapper, VideoRecorder
@@ -133,7 +132,7 @@ class PushTKeypointsRunner(BaseLowdimRunner):
env_prefixs.append('test/')
env_init_fn_dills.append(dill.dumps(init_fn))
env = AsyncVectorEnv(env_fns)
env = SyncVectorEnv(env_fns)
# test env
# env.reset(seed=env_seeds)

View File

@@ -1,298 +0,0 @@
from typing import Optional, Tuple, Union
import logging
import torch
import torch.nn as nn
from diffusion_policy.model.common.module_attr_mixin import ModuleAttrMixin
from diffusion_policy.model.diffusion.positional_embedding import SinusoidalPosEmb
logger = logging.getLogger(__name__)
class IMFTransformerForDiffusion(ModuleAttrMixin):
def __init__(
self,
input_dim: int,
output_dim: int,
horizon: int,
n_obs_steps: int = None,
cond_dim: int = 0,
n_layer: int = 12,
n_head: int = 1,
n_emb: int = 768,
p_drop_emb: float = 0.1,
p_drop_attn: float = 0.1,
causal_attn: bool = False,
time_as_cond: bool = True,
obs_as_cond: bool = False,
n_cond_layers: int = 0,
) -> None:
super().__init__()
assert n_head == 1, 'IMFTransformerForDiffusion currently supports single-head attention only.'
if n_obs_steps is None:
n_obs_steps = horizon
T = horizon
T_cond = 2
if not time_as_cond:
T += 2
T_cond -= 2
obs_as_cond = cond_dim > 0
if obs_as_cond:
assert time_as_cond
T_cond += n_obs_steps
self.input_emb = nn.Linear(input_dim, n_emb)
self.pos_emb = nn.Parameter(torch.zeros(1, T, n_emb))
self.drop = nn.Dropout(p_drop_emb)
self.time_emb = SinusoidalPosEmb(n_emb)
self.cond_obs_emb = None
if obs_as_cond:
self.cond_obs_emb = nn.Linear(cond_dim, n_emb)
self.cond_pos_emb = None
self.encoder = None
self.decoder = None
encoder_only = False
if T_cond > 0:
self.cond_pos_emb = nn.Parameter(torch.zeros(1, T_cond, n_emb))
if n_cond_layers > 0:
encoder_layer = nn.TransformerEncoderLayer(
d_model=n_emb,
nhead=n_head,
dim_feedforward=4 * n_emb,
dropout=p_drop_attn,
activation='gelu',
batch_first=True,
norm_first=True,
)
self.encoder = nn.TransformerEncoder(
encoder_layer=encoder_layer,
num_layers=n_cond_layers,
)
else:
self.encoder = nn.Sequential(
nn.Linear(n_emb, 4 * n_emb),
nn.Mish(),
nn.Linear(4 * n_emb, n_emb),
)
decoder_layer = nn.TransformerDecoderLayer(
d_model=n_emb,
nhead=n_head,
dim_feedforward=4 * n_emb,
dropout=p_drop_attn,
activation='gelu',
batch_first=True,
norm_first=True,
)
self.decoder = nn.TransformerDecoder(
decoder_layer=decoder_layer,
num_layers=n_layer,
)
else:
encoder_only = True
encoder_layer = nn.TransformerEncoderLayer(
d_model=n_emb,
nhead=n_head,
dim_feedforward=4 * n_emb,
dropout=p_drop_attn,
activation='gelu',
batch_first=True,
norm_first=True,
)
self.encoder = nn.TransformerEncoder(
encoder_layer=encoder_layer,
num_layers=n_layer,
)
if causal_attn:
sz = T
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
self.register_buffer('mask', mask)
if time_as_cond and obs_as_cond:
S = T_cond
t_idx, s_idx = torch.meshgrid(
torch.arange(T),
torch.arange(S),
indexing='ij',
)
mask = t_idx >= (s_idx - 2)
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
self.register_buffer('memory_mask', mask)
else:
self.memory_mask = None
else:
self.mask = None
self.memory_mask = None
self.ln_f = nn.LayerNorm(n_emb)
self.head = nn.Linear(n_emb, output_dim)
self.T = T
self.T_cond = T_cond
self.horizon = horizon
self.time_as_cond = time_as_cond
self.obs_as_cond = obs_as_cond
self.encoder_only = encoder_only
self.apply(self._init_weights)
logger.info(
'number of parameters: %e',
sum(p.numel() for p in self.parameters()),
)
def _init_weights(self, module):
ignore_types = (
nn.Dropout,
SinusoidalPosEmb,
nn.TransformerEncoderLayer,
nn.TransformerDecoderLayer,
nn.TransformerEncoder,
nn.TransformerDecoder,
nn.ModuleList,
nn.Mish,
nn.Sequential,
)
if isinstance(module, (nn.Linear, nn.Embedding)):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
if isinstance(module, nn.Linear) and module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.MultiheadAttention):
weight_names = ['in_proj_weight', 'q_proj_weight', 'k_proj_weight', 'v_proj_weight']
for name in weight_names:
weight = getattr(module, name)
if weight is not None:
torch.nn.init.normal_(weight, mean=0.0, std=0.02)
bias_names = ['in_proj_bias', 'bias_k', 'bias_v']
for name in bias_names:
bias = getattr(module, name)
if bias is not None:
torch.nn.init.zeros_(bias)
elif isinstance(module, nn.LayerNorm):
torch.nn.init.zeros_(module.bias)
torch.nn.init.ones_(module.weight)
elif isinstance(module, IMFTransformerForDiffusion):
torch.nn.init.normal_(module.pos_emb, mean=0.0, std=0.02)
if module.cond_obs_emb is not None:
torch.nn.init.normal_(module.cond_pos_emb, mean=0.0, std=0.02)
elif isinstance(module, ignore_types):
pass
else:
raise RuntimeError(f'Unaccounted module {module}')
def get_optim_groups(self, weight_decay: float = 1e-3):
decay = set()
no_decay = set()
whitelist_weight_modules = (torch.nn.Linear, torch.nn.MultiheadAttention)
blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
for mn, m in self.named_modules():
for pn, _ in m.named_parameters():
fpn = '%s.%s' % (mn, pn) if mn else pn
if pn.endswith('bias'):
no_decay.add(fpn)
elif pn.startswith('bias'):
no_decay.add(fpn)
elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
decay.add(fpn)
elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
no_decay.add(fpn)
no_decay.add('pos_emb')
no_decay.add('_dummy_variable')
if self.cond_pos_emb is not None:
no_decay.add('cond_pos_emb')
param_dict = {pn: p for pn, p in self.named_parameters()}
inter_params = decay & no_decay
union_params = decay | no_decay
assert len(inter_params) == 0, f'parameters {inter_params} made it into both decay/no_decay sets!'
assert len(param_dict.keys() - union_params) == 0, (
f'parameters {param_dict.keys() - union_params} were not separated into either decay/no_decay sets!'
)
optim_groups = [
{
'params': [param_dict[pn] for pn in sorted(list(decay))],
'weight_decay': weight_decay,
},
{
'params': [param_dict[pn] for pn in sorted(list(no_decay))],
'weight_decay': 0.0,
},
]
return optim_groups
def configure_optimizers(
self,
learning_rate: float = 1e-4,
weight_decay: float = 1e-3,
betas: Tuple[float, float] = (0.9, 0.95),
):
optim_groups = self.get_optim_groups(weight_decay=weight_decay)
optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas)
return optimizer
def _prepare_time_input(self, value: Union[torch.Tensor, float, int], sample: torch.Tensor) -> torch.Tensor:
if not torch.is_tensor(value):
value = torch.tensor([value], dtype=sample.dtype, device=sample.device)
elif value.ndim == 0:
value = value[None].to(device=sample.device, dtype=sample.dtype)
else:
value = value.to(device=sample.device, dtype=sample.dtype)
return value.expand(sample.shape[0])
def forward(
self,
sample: torch.Tensor,
r: Union[torch.Tensor, float, int],
t: Union[torch.Tensor, float, int],
cond: Optional[torch.Tensor] = None,
):
r = self._prepare_time_input(r, sample)
t = self._prepare_time_input(t, sample)
r_emb = self.time_emb(r).unsqueeze(1)
t_emb = self.time_emb(t).unsqueeze(1)
input_emb = self.input_emb(sample)
if self.encoder_only:
token_embeddings = torch.cat([r_emb, t_emb, input_emb], dim=1)
token_count = token_embeddings.shape[1]
position_embeddings = self.pos_emb[:, :token_count, :]
x = self.drop(token_embeddings + position_embeddings)
x = self.encoder(src=x, mask=self.mask)
x = x[:, 2:, :]
else:
cond_embeddings = torch.cat([r_emb, t_emb], dim=1)
if self.obs_as_cond:
cond_obs_emb = self.cond_obs_emb(cond)
cond_embeddings = torch.cat([cond_embeddings, cond_obs_emb], dim=1)
token_count = cond_embeddings.shape[1]
position_embeddings = self.cond_pos_emb[:, :token_count, :]
x = self.drop(cond_embeddings + position_embeddings)
x = self.encoder(x)
memory = x
token_embeddings = input_emb
token_count = token_embeddings.shape[1]
position_embeddings = self.pos_emb[:, :token_count, :]
x = self.drop(token_embeddings + position_embeddings)
x = self.decoder(
tgt=x,
memory=memory,
tgt_mask=self.mask,
memory_mask=self.memory_mask,
)
x = self.ln_f(x)
x = self.head(x)
return x

View File

@@ -0,0 +1,265 @@
from typing import Optional, Tuple, Union
import logging
import torch
import torch.nn as nn
from diffusion_policy.model.common.module_attr_mixin import ModuleAttrMixin
from diffusion_policy.model.diffusion.positional_embedding import SinusoidalPosEmb
logger = logging.getLogger(__name__)
class PMFTransformerForDiffusion(ModuleAttrMixin):
def __init__(
self,
input_dim: int,
output_dim: int,
horizon: int,
n_obs_steps: Optional[int] = None,
cond_dim: int = 0,
n_layer: int = 12,
n_head: int = 12,
n_emb: int = 768,
p_drop_emb: float = 0.1,
p_drop_attn: float = 0.1,
causal_attn: bool = False,
obs_as_cond: bool = False,
n_cond_layers: int = 0,
n_time_tokens: int = 4,
) -> None:
super().__init__()
if n_obs_steps is None:
n_obs_steps = horizon
if n_time_tokens < 1:
raise ValueError("n_time_tokens must be >= 1")
obs_as_cond = cond_dim > 0
T = horizon
n_global_cond_tokens = 2 * n_time_tokens
T_cond = n_global_cond_tokens + (n_obs_steps if obs_as_cond else 0)
self.input_emb = nn.Linear(input_dim, n_emb)
self.pos_emb = nn.Parameter(torch.zeros(1, T, n_emb))
self.drop = nn.Dropout(p_drop_emb)
self.t_emb = SinusoidalPosEmb(n_emb)
self.r_emb = SinusoidalPosEmb(n_emb)
self.t_tokens = nn.Parameter(torch.zeros(1, n_time_tokens, n_emb))
self.r_tokens = nn.Parameter(torch.zeros(1, n_time_tokens, n_emb))
self.cond_obs_emb = nn.Linear(cond_dim, n_emb) if obs_as_cond else None
self.cond_pos_emb = nn.Parameter(torch.zeros(1, T_cond, n_emb))
if n_cond_layers > 0:
encoder_layer = nn.TransformerEncoderLayer(
d_model=n_emb,
nhead=n_head,
dim_feedforward=4 * n_emb,
dropout=p_drop_attn,
activation="gelu",
batch_first=True,
norm_first=True,
)
self.encoder = nn.TransformerEncoder(
encoder_layer=encoder_layer,
num_layers=n_cond_layers,
)
else:
self.encoder = nn.Sequential(
nn.Linear(n_emb, 4 * n_emb),
nn.Mish(),
nn.Linear(4 * n_emb, n_emb),
)
decoder_layer = nn.TransformerDecoderLayer(
d_model=n_emb,
nhead=n_head,
dim_feedforward=4 * n_emb,
dropout=p_drop_attn,
activation="gelu",
batch_first=True,
norm_first=True,
)
self.decoder = nn.TransformerDecoder(
decoder_layer=decoder_layer,
num_layers=n_layer,
)
if causal_attn:
sz = T
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
mask = mask.float().masked_fill(mask == 0, float("-inf")).masked_fill(mask == 1, float(0.0))
self.register_buffer("mask", mask)
if obs_as_cond:
q_idx, c_idx = torch.meshgrid(
torch.arange(T),
torch.arange(T_cond),
indexing="ij",
)
obs_offset = n_global_cond_tokens
visible = c_idx < obs_offset
visible = visible | (q_idx >= (c_idx - obs_offset))
memory_mask = visible.float().masked_fill(~visible, float("-inf")).masked_fill(visible, float(0.0))
self.register_buffer("memory_mask", memory_mask)
else:
self.memory_mask = None
else:
self.mask = None
self.memory_mask = None
self.ln_f = nn.LayerNorm(n_emb)
self.head_u = nn.Linear(n_emb, output_dim)
self.head_v = nn.Linear(n_emb, output_dim)
self.T = T
self.T_cond = T_cond
self.horizon = horizon
self.n_obs_steps = n_obs_steps
self.obs_as_cond = obs_as_cond
self.n_global_cond_tokens = n_global_cond_tokens
self.n_time_tokens = n_time_tokens
self.apply(self._init_weights)
logger.info(
"number of parameters: %e", sum(p.numel() for p in self.parameters())
)
def _init_weights(self, module):
ignore_types = (
nn.Dropout,
SinusoidalPosEmb,
nn.TransformerEncoderLayer,
nn.TransformerDecoderLayer,
nn.TransformerEncoder,
nn.TransformerDecoder,
nn.ModuleList,
nn.Mish,
nn.Sequential,
)
if isinstance(module, (nn.Linear, nn.Embedding)):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
if isinstance(module, nn.Linear) and module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.MultiheadAttention):
for name in ("in_proj_weight", "q_proj_weight", "k_proj_weight", "v_proj_weight"):
weight = getattr(module, name)
if weight is not None:
torch.nn.init.normal_(weight, mean=0.0, std=0.02)
for name in ("in_proj_bias", "bias_k", "bias_v"):
bias = getattr(module, name)
if bias is not None:
torch.nn.init.zeros_(bias)
elif isinstance(module, nn.LayerNorm):
torch.nn.init.zeros_(module.bias)
torch.nn.init.ones_(module.weight)
elif isinstance(module, PMFTransformerForDiffusion):
torch.nn.init.normal_(module.pos_emb, mean=0.0, std=0.02)
torch.nn.init.normal_(module.cond_pos_emb, mean=0.0, std=0.02)
torch.nn.init.normal_(module.t_tokens, mean=0.0, std=0.02)
torch.nn.init.normal_(module.r_tokens, mean=0.0, std=0.02)
elif isinstance(module, ignore_types):
pass
else:
raise RuntimeError("Unaccounted module {}".format(module))
def get_optim_groups(self, weight_decay: float = 1e-3):
decay = set()
no_decay = set()
whitelist_weight_modules = (torch.nn.Linear, torch.nn.MultiheadAttention)
blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
for mn, m in self.named_modules():
for pn, _ in m.named_parameters():
fpn = "%s.%s" % (mn, pn) if mn else pn
if pn.endswith("bias"):
no_decay.add(fpn)
elif pn.startswith("bias"):
no_decay.add(fpn)
elif pn.endswith("weight") and isinstance(m, whitelist_weight_modules):
decay.add(fpn)
elif pn.endswith("weight") and isinstance(m, blacklist_weight_modules):
no_decay.add(fpn)
no_decay.update(
{
"pos_emb",
"cond_pos_emb",
"t_tokens",
"r_tokens",
"_dummy_variable",
}
)
param_dict = {pn: p for pn, p in self.named_parameters()}
inter_params = decay & no_decay
union_params = decay | no_decay
assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params),)
assert len(param_dict.keys() - union_params) == 0, (
"parameters %s were not separated into either decay/no_decay set!" % (str(param_dict.keys() - union_params),)
)
return [
{
"params": [param_dict[pn] for pn in sorted(list(decay))],
"weight_decay": weight_decay,
},
{
"params": [param_dict[pn] for pn in sorted(list(no_decay))],
"weight_decay": 0.0,
},
]
def configure_optimizers(
self,
learning_rate: float = 1e-4,
weight_decay: float = 1e-3,
betas: Tuple[float, float] = (0.9, 0.95),
):
optim_groups = self.get_optim_groups(weight_decay=weight_decay)
return torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas)
def _broadcast_time(self, value: Union[torch.Tensor, float, int], batch_size: int, device: torch.device):
if not torch.is_tensor(value):
value = torch.tensor([value], dtype=torch.float32, device=device)
elif value.ndim == 0:
value = value[None].to(device=device, dtype=torch.float32)
else:
value = value.to(device=device, dtype=torch.float32)
return value.expand(batch_size)
def forward(
self,
sample: torch.Tensor,
t: Union[torch.Tensor, float, int],
r: Union[torch.Tensor, float, int],
cond: Optional[torch.Tensor] = None,
):
batch_size = sample.shape[0]
device = sample.device
t = self._broadcast_time(t, batch_size, device)
r = self._broadcast_time(r, batch_size, device)
input_emb = self.input_emb(sample)
t_cond = self.t_tokens + self.t_emb(t).unsqueeze(1)
r_cond = self.r_tokens + self.r_emb(r).unsqueeze(1)
cond_embeddings = [t_cond, r_cond]
if self.obs_as_cond:
cond_embeddings.append(self.cond_obs_emb(cond))
cond_embeddings = torch.cat(cond_embeddings, dim=1)
cond_pos = self.cond_pos_emb[:, : cond_embeddings.shape[1], :]
memory = self.drop(cond_embeddings + cond_pos)
memory = self.encoder(memory)
token_pos = self.pos_emb[:, : input_emb.shape[1], :]
x = self.drop(input_emb + token_pos)
x = self.decoder(
tgt=x,
memory=memory,
tgt_mask=self.mask,
memory_mask=self.memory_mask,
)
x = self.ln_f(x)
return self.head_u(x), self.head_v(x)

View File

@@ -1,273 +0,0 @@
from contextlib import nullcontext
from typing import Dict, Optional
import torch
import torch.nn.functional as F
from einops import reduce
from diffusion_policy.common.pytorch_util import dict_apply
from diffusion_policy.model.diffusion.imf_transformer_for_diffusion import IMFTransformerForDiffusion
from diffusion_policy.policy.diffusion_transformer_hybrid_image_policy import (
DiffusionTransformerHybridImagePolicy,
)
try:
from torch.func import jvp as TORCH_FUNC_JVP
except ImportError: # pragma: no cover - depends on torch version
TORCH_FUNC_JVP = None
class IMFTransformerHybridImagePolicy(DiffusionTransformerHybridImagePolicy):
def __init__(
self,
shape_meta: dict,
noise_scheduler,
horizon,
n_action_steps,
n_obs_steps,
num_inference_steps=None,
crop_shape=(76, 76),
obs_encoder_group_norm=False,
eval_fixed_crop=False,
n_layer=8,
n_cond_layers=0,
n_head=1,
n_emb=256,
p_drop_emb=0.0,
p_drop_attn=0.3,
causal_attn=True,
time_as_cond=True,
obs_as_cond=True,
pred_action_steps_only=False,
**kwargs,
):
if num_inference_steps is None:
num_inference_steps = 1
elif num_inference_steps != 1:
raise ValueError(
'IMFTransformerHybridImagePolicy only supports one-step inference; '
f'num_inference_steps must be 1, got {num_inference_steps}.'
)
super().__init__(
shape_meta=shape_meta,
noise_scheduler=noise_scheduler,
horizon=horizon,
n_action_steps=n_action_steps,
n_obs_steps=n_obs_steps,
num_inference_steps=num_inference_steps,
crop_shape=crop_shape,
obs_encoder_group_norm=obs_encoder_group_norm,
eval_fixed_crop=eval_fixed_crop,
n_layer=n_layer,
n_cond_layers=n_cond_layers,
n_head=n_head,
n_emb=n_emb,
p_drop_emb=p_drop_emb,
p_drop_attn=p_drop_attn,
causal_attn=causal_attn,
time_as_cond=time_as_cond,
obs_as_cond=obs_as_cond,
pred_action_steps_only=pred_action_steps_only,
**kwargs,
)
input_dim = self.action_dim if self.obs_as_cond else (self.obs_feature_dim + self.action_dim)
cond_dim = self.obs_feature_dim if self.obs_as_cond else 0
model_horizon = self.n_action_steps if self.pred_action_steps_only else horizon
self.model = IMFTransformerForDiffusion(
input_dim=input_dim,
output_dim=input_dim,
horizon=model_horizon,
n_obs_steps=n_obs_steps,
cond_dim=cond_dim,
n_layer=n_layer,
n_head=n_head,
n_emb=n_emb,
p_drop_emb=p_drop_emb,
p_drop_attn=p_drop_attn,
causal_attn=causal_attn,
time_as_cond=time_as_cond,
obs_as_cond=obs_as_cond,
n_cond_layers=n_cond_layers,
)
self.num_inference_steps = 1
def fn(self, z: torch.Tensor, r: torch.Tensor, t: torch.Tensor, cond=None) -> torch.Tensor:
return self.model(z, r, t, cond=cond)
@staticmethod
def _broadcast_batch_time(value: torch.Tensor, reference: torch.Tensor) -> torch.Tensor:
while value.ndim < reference.ndim:
value = value.unsqueeze(-1)
return value
@staticmethod
def _apply_conditioning(
trajectory: torch.Tensor,
condition_data: Optional[torch.Tensor] = None,
condition_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if condition_data is None or condition_mask is None:
return trajectory
conditioned = trajectory.clone()
conditioned[condition_mask] = condition_data[condition_mask]
return conditioned
@staticmethod
def _jvp_math_sdp_context(z_t: torch.Tensor):
if z_t.is_cuda:
return torch.backends.cuda.sdp_kernel(
enable_flash=False,
enable_math=True,
enable_mem_efficient=False,
enable_cudnn=False,
)
return nullcontext()
@staticmethod
def _jvp_tangents(v: torch.Tensor, r: torch.Tensor, t: torch.Tensor):
return v.detach(), torch.zeros_like(r), torch.ones_like(t)
def _compute_u_and_du_dt(
self,
z_t: torch.Tensor,
r: torch.Tensor,
t: torch.Tensor,
cond,
v: torch.Tensor,
condition_data: Optional[torch.Tensor] = None,
condition_mask: Optional[torch.Tensor] = None,
):
tangents = self._jvp_tangents(v, r, t)
def g(z, r_value, t_value):
conditioned_z = self._apply_conditioning(z, condition_data, condition_mask)
return self.fn(conditioned_z, r_value, t_value, cond=cond)
with self._jvp_math_sdp_context(z_t):
if TORCH_FUNC_JVP is not None:
try:
return TORCH_FUNC_JVP(g, (z_t, r, t), tangents)
except (RuntimeError, TypeError, NotImplementedError):
pass
u = g(z_t, r, t)
_, du_dt = torch.autograd.functional.jvp(
g,
(z_t, r, t),
tangents,
create_graph=False,
strict=False,
)
return u, du_dt
def _compound_velocity(
self,
u: torch.Tensor,
du_dt: torch.Tensor,
r: torch.Tensor,
t: torch.Tensor,
) -> torch.Tensor:
delta = self._broadcast_batch_time(t - r, u)
return u + delta * du_dt.detach()
def _sample_one_step(
self,
z_t: torch.Tensor,
r: torch.Tensor = None,
t: torch.Tensor = None,
cond=None,
) -> torch.Tensor:
batch_size = z_t.shape[0]
if t is None:
t = torch.ones(batch_size, device=z_t.device, dtype=z_t.dtype)
if r is None:
r = torch.zeros(batch_size, device=z_t.device, dtype=z_t.dtype)
u = self.fn(z_t, r, t, cond=cond)
delta = self._broadcast_batch_time(t - r, z_t)
return z_t - delta * u
def conditional_sample(
self,
condition_data,
condition_mask,
cond=None,
generator=None,
**kwargs,
):
trajectory = torch.randn(
size=condition_data.shape,
dtype=condition_data.dtype,
device=condition_data.device,
generator=generator,
)
trajectory = self._apply_conditioning(trajectory, condition_data, condition_mask)
trajectory = self._sample_one_step(trajectory, cond=cond)
trajectory = self._apply_conditioning(trajectory, condition_data, condition_mask)
return trajectory
def compute_loss(self, batch):
assert 'valid_mask' not in batch
nobs = self.normalizer.normalize(batch['obs'])
nactions = self.normalizer['action'].normalize(batch['action'])
batch_size = nactions.shape[0]
horizon = nactions.shape[1]
To = self.n_obs_steps
cond = None
trajectory = nactions
if self.obs_as_cond:
this_nobs = dict_apply(
nobs,
lambda x: x[:, :To, ...].reshape(-1, *x.shape[2:]),
)
nobs_features = self.obs_encoder(this_nobs)
cond = nobs_features.reshape(batch_size, To, -1)
if self.pred_action_steps_only:
start = To - 1
end = start + self.n_action_steps
trajectory = nactions[:, start:end]
else:
this_nobs = dict_apply(nobs, lambda x: x.reshape(-1, *x.shape[2:]))
nobs_features = self.obs_encoder(this_nobs)
nobs_features = nobs_features.reshape(batch_size, horizon, -1)
trajectory = torch.cat([nactions, nobs_features], dim=-1).detach()
if self.pred_action_steps_only:
condition_mask = torch.zeros_like(trajectory, dtype=torch.bool)
else:
condition_mask = self.mask_generator(trajectory.shape)
loss_mask = torch.zeros_like(trajectory, dtype=torch.bool)
loss_mask[..., : self.action_dim] = True
loss_mask = loss_mask & ~condition_mask
x = trajectory
e = torch.randn_like(x)
t = torch.rand(batch_size, device=x.device, dtype=x.dtype)
r = torch.rand(batch_size, device=x.device, dtype=x.dtype)
t, r = torch.maximum(t, r), torch.minimum(t, r)
t_broadcast = self._broadcast_batch_time(t, x)
z_t = (1 - t_broadcast) * x + t_broadcast * e
z_t = self._apply_conditioning(z_t, x, condition_mask)
v = self.fn(z_t, t, t, cond=cond)
u, du_dt = self._compute_u_and_du_dt(
z_t,
r,
t,
cond=cond,
v=v,
condition_data=x,
condition_mask=condition_mask,
)
V = self._compound_velocity(u, du_dt, r, t)
target = e - x
loss = F.mse_loss(V, target, reduction='none')
loss = loss * loss_mask.type(loss.dtype)
loss = reduce(loss, 'b ... -> b (...)', 'mean')
loss = loss.mean()
return loss

View File

@@ -0,0 +1,453 @@
from typing import Dict, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import reduce
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
import diffusion_policy.model.vision.crop_randomizer as dmvc
import robomimic.models.base_nets as rmbn
import robomimic.utils.obs_utils as ObsUtils
from diffusion_policy.common.pytorch_util import dict_apply, replace_submodules
from diffusion_policy.common.robomimic_config_util import get_robomimic_config
from diffusion_policy.model.common.normalizer import LinearNormalizer
from diffusion_policy.model.diffusion.mask_generator import LowdimMaskGenerator
from diffusion_policy.model.diffusion.pmf_transformer_for_diffusion import (
PMFTransformerForDiffusion,
)
from diffusion_policy.policy.base_image_policy import BaseImagePolicy
from robomimic.algo import algo_factory
from robomimic.algo.algo import PolicyAlgo
class PMFTransformerHybridImagePolicy(BaseImagePolicy):
def __init__(
self,
shape_meta: dict,
noise_scheduler: DDPMScheduler,
horizon,
n_action_steps,
n_obs_steps,
num_inference_steps=None,
crop_shape=(76, 76),
obs_encoder_group_norm=False,
eval_fixed_crop=False,
n_layer=8,
n_cond_layers=0,
n_head=4,
n_emb=256,
p_drop_emb=0.0,
p_drop_attn=0.0,
causal_attn=True,
obs_as_cond=True,
pred_action_steps_only=False,
n_time_tokens=4,
min_time=0.05,
du_dt_epsilon=1.0e-3,
pmf_u_loss_weight=1.0,
pmf_v_loss_weight=1.0,
noise_scale=1.0,
adatloss_eps=0.01,
p_mean=-0.4,
p_std=1.0,
tr_uniform=True,
tr_uniform_prob=0.1,
data_proportion=0.5,
**kwargs,
):
super().__init__()
action_shape = shape_meta["action"]["shape"]
assert len(action_shape) == 1
action_dim = action_shape[0]
obs_shape_meta = shape_meta["obs"]
obs_config = {
"low_dim": [],
"rgb": [],
"depth": [],
"scan": [],
}
obs_key_shapes = dict()
for key, attr in obs_shape_meta.items():
shape = attr["shape"]
obs_key_shapes[key] = list(shape)
obs_type = attr.get("type", "low_dim")
if obs_type == "rgb":
obs_config["rgb"].append(key)
elif obs_type == "low_dim":
obs_config["low_dim"].append(key)
else:
raise RuntimeError(f"Unsupported obs type: {obs_type}")
config = get_robomimic_config(
algo_name="bc_rnn",
hdf5_type="image",
task_name="square",
dataset_type="ph",
)
with config.unlocked():
config.observation.modalities.obs = obs_config
if crop_shape is None:
for _, modality in config.observation.encoder.items():
if modality.obs_randomizer_class == "CropRandomizer":
modality["obs_randomizer_class"] = None
else:
crop_h, crop_w = crop_shape
for _, modality in config.observation.encoder.items():
if modality.obs_randomizer_class == "CropRandomizer":
modality.obs_randomizer_kwargs.crop_height = crop_h
modality.obs_randomizer_kwargs.crop_width = crop_w
ObsUtils.initialize_obs_utils_with_config(config)
policy: PolicyAlgo = algo_factory(
algo_name=config.algo_name,
config=config,
obs_key_shapes=obs_key_shapes,
ac_dim=action_dim,
device="cpu",
)
obs_encoder = policy.nets["policy"].nets["encoder"].nets["obs"]
if obs_encoder_group_norm:
replace_submodules(
root_module=obs_encoder,
predicate=lambda x: isinstance(x, nn.BatchNorm2d),
func=lambda x: nn.GroupNorm(
num_groups=x.num_features // 16,
num_channels=x.num_features,
),
)
if eval_fixed_crop:
replace_submodules(
root_module=obs_encoder,
predicate=lambda x: isinstance(x, rmbn.CropRandomizer),
func=lambda x: dmvc.CropRandomizer(
input_shape=x.input_shape,
crop_height=x.crop_height,
crop_width=x.crop_width,
num_crops=x.num_crops,
pos_enc=x.pos_enc,
),
)
obs_feature_dim = obs_encoder.output_shape()[0]
input_dim = action_dim if obs_as_cond else (obs_feature_dim + action_dim)
cond_dim = obs_feature_dim if obs_as_cond else 0
self.obs_encoder = obs_encoder
self.model = PMFTransformerForDiffusion(
input_dim=input_dim,
output_dim=input_dim,
horizon=horizon if not pred_action_steps_only else n_action_steps,
n_obs_steps=n_obs_steps,
cond_dim=cond_dim,
n_layer=n_layer,
n_head=n_head,
n_emb=n_emb,
p_drop_emb=p_drop_emb,
p_drop_attn=p_drop_attn,
causal_attn=causal_attn,
obs_as_cond=obs_as_cond,
n_cond_layers=n_cond_layers,
n_time_tokens=n_time_tokens,
)
self.noise_scheduler = noise_scheduler
self.mask_generator = LowdimMaskGenerator(
action_dim=action_dim,
obs_dim=0 if obs_as_cond else obs_feature_dim,
max_n_obs_steps=n_obs_steps,
fix_obs_steps=True,
action_visible=False,
)
self.normalizer = LinearNormalizer()
self.horizon = horizon
self.obs_feature_dim = obs_feature_dim
self.action_dim = action_dim
self.n_action_steps = n_action_steps
self.n_obs_steps = n_obs_steps
self.obs_as_cond = obs_as_cond
self.pred_action_steps_only = pred_action_steps_only
self.min_time = min_time
self.du_dt_epsilon = du_dt_epsilon
self.pmf_u_loss_weight = pmf_u_loss_weight
self.pmf_v_loss_weight = pmf_v_loss_weight
self.noise_scale = noise_scale
self.adatloss_eps = adatloss_eps
self.p_mean = p_mean
self.p_std = p_std
self.tr_uniform = tr_uniform
self.tr_uniform_prob = tr_uniform_prob
self.data_proportion = data_proportion
self.kwargs = kwargs
if num_inference_steps is None:
num_inference_steps = noise_scheduler.config.num_train_timesteps
self.num_inference_steps = num_inference_steps
def _encode_obs(self, nobs: Dict[str, torch.Tensor], n_steps: int) -> torch.Tensor:
flat_nobs = dict_apply(nobs, lambda x: x[:, :n_steps, ...].reshape(-1, *x.shape[2:]))
nobs_features = self.obs_encoder(flat_nobs)
return nobs_features.reshape(next(iter(nobs.values())).shape[0], n_steps, -1)
def _time_view(self, value: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
return value.reshape(value.shape[0], *([1] * (ref.ndim - 1)))
def _adatloss(self, loss: torch.Tensor) -> torch.Tensor:
denom = loss.detach() + self.adatloss_eps
return loss / denom
def _sample_logit_normal(self, batch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
normal = torch.randn(batch_size, device=device, dtype=dtype)
return torch.sigmoid(normal * self.p_std + self.p_mean)
def _sample_tr(self, batch_size: int, device: torch.device, dtype: torch.dtype):
t = self._sample_logit_normal(batch_size, device, dtype)
r = self._sample_logit_normal(batch_size, device, dtype)
if self.tr_uniform:
uniform_mask = torch.rand(batch_size, device=device) < self.tr_uniform_prob
uniform_t = torch.rand(batch_size, device=device, dtype=dtype)
uniform_r = torch.rand(batch_size, device=device, dtype=dtype)
t = torch.where(uniform_mask, uniform_t, t)
r = torch.where(uniform_mask, uniform_r, r)
data_size = int(batch_size * self.data_proportion)
fm_mask = torch.arange(batch_size, device=device) < data_size
r = torch.where(fm_mask, t, r)
t_final = torch.maximum(t, r)
r_final = torch.minimum(t, r)
return t_final, r_final
def _trajectory_inputs(
self,
nobs: Dict[str, torch.Tensor],
nactions: torch.Tensor,
):
batch_size = nactions.shape[0]
horizon = nactions.shape[1]
cond = None
trajectory = nactions
if self.obs_as_cond:
cond = self._encode_obs(nobs, self.n_obs_steps)
if self.pred_action_steps_only:
start = self.n_obs_steps - 1
end = start + self.n_action_steps
trajectory = nactions[:, start:end]
else:
nobs_features = self._encode_obs(nobs, horizon)
trajectory = torch.cat([nactions, nobs_features], dim=-1).detach()
if self.pred_action_steps_only:
condition_mask = torch.zeros_like(trajectory, dtype=torch.bool)
else:
condition_mask = self.mask_generator(trajectory.shape)
return batch_size, trajectory, cond, condition_mask
def _apply_conditioning(
self,
sample: torch.Tensor,
condition_data: torch.Tensor,
condition_mask: torch.Tensor,
) -> torch.Tensor:
if not condition_mask.any():
return sample
return torch.where(condition_mask, condition_data, sample)
def _compute_u_v(
self,
sample: torch.Tensor,
t: torch.Tensor,
r: torch.Tensor,
cond: torch.Tensor,
):
x_hat_u, x_hat_v = self.model(sample, t, r, cond)
denom = self._time_view(t, sample)
u = (sample - x_hat_u) / denom
v = (sample - x_hat_v) / denom
return u, v
def _compute_du_dt(
self,
sample: torch.Tensor,
t: torch.Tensor,
r: torch.Tensor,
cond: torch.Tensor,
condition_data: torch.Tensor,
condition_mask: torch.Tensor,
tangent_v: torch.Tensor,
) -> torch.Tensor:
tangent_sample = tangent_v.detach()
tangent_r = torch.zeros_like(r)
tangent_t = torch.ones_like(t)
def u_fn(sample_input, r_input, t_input):
conditioned_sample = self._apply_conditioning(
sample_input, condition_data, condition_mask
)
u_value, _ = self._compute_u_v(conditioned_sample, t_input, r_input, cond)
return u_value
primals = (sample, r, t)
tangents = (tangent_sample, tangent_r, tangent_t)
try:
_, du_dt = torch.func.jvp(u_fn, primals, tangents)
except (AttributeError, NotImplementedError, RuntimeError):
_, du_dt = torch.autograd.functional.jvp(
u_fn,
primals,
tangents,
create_graph=False,
strict=False,
)
return du_dt
# ========= inference ============
def conditional_sample(
self,
condition_data,
condition_mask,
cond=None,
generator=None,
**kwargs,
):
del kwargs
trajectory = torch.randn(
size=condition_data.shape,
dtype=condition_data.dtype,
device=condition_data.device,
generator=generator,
) * self.noise_scale
time_steps = torch.linspace(
1.0,
0.0,
self.num_inference_steps + 1,
dtype=trajectory.dtype,
device=trajectory.device,
)
for step_idx in range(self.num_inference_steps):
trajectory = self._apply_conditioning(trajectory, condition_data, condition_mask)
t = time_steps[step_idx].expand(trajectory.shape[0])
r = time_steps[step_idx + 1].expand(trajectory.shape[0])
u, _ = self._compute_u_v(trajectory, t, r, cond)
delta = self._time_view(t - r, trajectory)
trajectory = trajectory - delta * u
trajectory = self._apply_conditioning(trajectory, condition_data, condition_mask)
return trajectory
def predict_action(self, obs_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
assert "past_action" not in obs_dict
nobs = self.normalizer.normalize(obs_dict)
value = next(iter(nobs.values()))
batch_size, to_steps = value.shape[:2]
horizon = self.horizon
action_dim = self.action_dim
device = self.device
dtype = self.dtype
cond = None
if self.obs_as_cond:
cond = self._encode_obs(nobs, self.n_obs_steps)
shape = (batch_size, horizon, action_dim)
if self.pred_action_steps_only:
shape = (batch_size, self.n_action_steps, action_dim)
cond_data = torch.zeros(size=shape, device=device, dtype=dtype)
cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
else:
nobs_features = self._encode_obs(nobs, self.n_obs_steps)
shape = (batch_size, horizon, action_dim + self.obs_feature_dim)
cond_data = torch.zeros(size=shape, device=device, dtype=dtype)
cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
cond_data[:, : self.n_obs_steps, action_dim:] = nobs_features
cond_mask[:, : self.n_obs_steps, action_dim:] = True
nsample = self.conditional_sample(
cond_data,
cond_mask,
cond=cond,
**self.kwargs,
)
naction_pred = nsample[..., :action_dim]
action_pred = self.normalizer["action"].unnormalize(naction_pred)
if self.pred_action_steps_only:
action = action_pred
else:
start = to_steps - 1
end = start + self.n_action_steps
action = action_pred[:, start:end]
return {
"action": action,
"action_pred": action_pred,
}
# ========= training ============
def set_normalizer(self, normalizer: LinearNormalizer):
self.normalizer.load_state_dict(normalizer.state_dict())
def get_optimizer(
self,
transformer_weight_decay: float,
obs_encoder_weight_decay: float,
learning_rate: float,
betas: Tuple[float, float],
) -> torch.optim.Optimizer:
optim_groups = self.model.get_optim_groups(weight_decay=transformer_weight_decay)
optim_groups.append(
{
"params": self.obs_encoder.parameters(),
"weight_decay": obs_encoder_weight_decay,
}
)
return torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas)
def compute_loss(self, batch):
assert "valid_mask" not in batch
nobs = self.normalizer.normalize(batch["obs"])
nactions = self.normalizer["action"].normalize(batch["action"])
_, trajectory, cond, condition_mask = self._trajectory_inputs(nobs, nactions)
noise = torch.randn_like(trajectory) * self.noise_scale
batch_size = trajectory.shape[0]
t, r = self._sample_tr(
batch_size, device=trajectory.device, dtype=trajectory.dtype
)
z_t = (1 - self._time_view(t, trajectory)) * trajectory + self._time_view(t, trajectory) * noise
z_t = self._apply_conditioning(z_t, trajectory, condition_mask)
loss_mask = ~condition_mask
target_v = noise - trajectory
u, v = self._compute_u_v(z_t, t, r, cond)
du_dt = self._compute_du_dt(
sample=z_t,
t=t,
r=r,
cond=cond,
condition_data=trajectory,
condition_mask=condition_mask,
tangent_v=v,
)
pmf_velocity = u + self._time_view(t - r, trajectory) * du_dt.detach()
loss_u = F.mse_loss(pmf_velocity, target_v, reduction="none")
loss_v = F.mse_loss(v, target_v, reduction="none")
loss_u = loss_u * loss_mask.type(loss_u.dtype)
loss_v = loss_v * loss_mask.type(loss_v.dtype)
loss_u = reduce(loss_u, "b ... -> b (...)", "mean").mean()
loss_v = reduce(loss_v, "b ... -> b (...)", "mean").mean()
loss_u = self._adatloss(loss_u)
loss_v = self._adatloss(loss_v)
return self.pmf_u_loss_weight * loss_u + self.pmf_v_loss_weight * loss_v

View File

@@ -8,8 +8,6 @@ if __name__ == "__main__":
os.chdir(ROOT_DIR)
import os
import contextlib
import importlib
import hydra
import torch
from omegaconf import OmegaConf
@@ -17,6 +15,7 @@ import pathlib
from torch.utils.data import DataLoader
import copy
import random
import wandb
import tqdm
import numpy as np
import shutil
@@ -32,111 +31,6 @@ from diffusion_policy.model.common.lr_scheduler import get_scheduler
OmegaConf.register_new_resolver("eval", eval, replace=True)
class _LoggingBackend:
def log(self, payload, step=None):
raise NotImplementedError
def finish(self):
raise NotImplementedError
class _WandbLoggingBackend(_LoggingBackend):
def __init__(self, run):
self.run = run
def log(self, payload, step=None):
self.run.log(payload, step=step)
def finish(self):
self.run.finish()
class _SwanLabLoggingBackend(_LoggingBackend):
def __init__(self, run):
self.run = run
def log(self, payload, step=None):
self.run.log(payload, step=step)
def finish(self):
self.run.finish()
def _load_wandb():
try:
return importlib.import_module('wandb')
except ImportError as exc:
raise ImportError(
"wandb is required when cfg.logging.backend == 'wandb' or missing"
) from exc
def _load_swanlab():
try:
return importlib.import_module('swanlab')
except ImportError:
return None
def init_logging_backend(cfg: OmegaConf, output_dir):
backend = OmegaConf.select(cfg, 'logging.backend', default='wandb')
if backend == 'swanlab':
swanlab = _load_swanlab()
if swanlab is None:
raise ImportError("swanlab is required when cfg.logging.backend == 'swanlab'")
logging_cfg = cfg.logging
mode = logging_cfg.mode
if mode == 'online':
mode = 'cloud'
run = swanlab.init(
project=logging_cfg.project,
experiment_name=logging_cfg.name,
group=logging_cfg.group,
tags=logging_cfg.tags,
id=logging_cfg.id,
resume=logging_cfg.resume,
mode=mode,
logdir=str(pathlib.Path(output_dir) / 'swanlog'),
config=OmegaConf.to_container(cfg, resolve=True),
)
return _SwanLabLoggingBackend(run)
if backend not in (None, 'wandb'):
raise ValueError(f"Unknown logging backend: {backend}")
wandb = _load_wandb()
logging_kwargs = OmegaConf.to_container(cfg.logging, resolve=True)
logging_kwargs.pop('backend', None)
run = wandb.init(
dir=str(output_dir),
config=OmegaConf.to_container(cfg, resolve=True),
**logging_kwargs
)
wandb.config.update(
{
"output_dir": str(output_dir),
}
)
return _WandbLoggingBackend(run)
@contextlib.contextmanager
def logging_backend_session(cfg: OmegaConf, output_dir):
logging_backend = init_logging_backend(cfg=cfg, output_dir=output_dir)
primary_error = None
try:
yield logging_backend
except BaseException as exc:
primary_error = exc
raise
finally:
try:
logging_backend.finish()
except BaseException:
if primary_error is None:
raise
class TrainDiffusionTransformerHybridWorkspace(BaseWorkspace):
include_keys = ['global_step', 'epoch']
@@ -215,6 +109,18 @@ class TrainDiffusionTransformerHybridWorkspace(BaseWorkspace):
output_dir=self.output_dir)
assert isinstance(env_runner, BaseImageRunner)
# configure logging
wandb_run = wandb.init(
dir=str(self.output_dir),
config=OmegaConf.to_container(cfg, resolve=True),
**cfg.logging
)
wandb.config.update(
{
"output_dir": self.output_dir,
}
)
# configure checkpoint
topk_manager = TopKCheckpointManager(
save_dir=os.path.join(self.output_dir, 'checkpoints'),
@@ -242,141 +148,140 @@ class TrainDiffusionTransformerHybridWorkspace(BaseWorkspace):
# training loop
log_path = os.path.join(self.output_dir, 'logs.json.txt')
with logging_backend_session(cfg=cfg, output_dir=self.output_dir) as logging_backend:
with JsonLogger(log_path) as json_logger:
for local_epoch_idx in range(cfg.training.num_epochs):
step_log = dict()
# ========= train for this epoch ==========
train_losses = list()
with tqdm.tqdm(train_dataloader, desc=f"Training epoch {self.epoch}",
leave=False, mininterval=cfg.training.tqdm_interval_sec) as tepoch:
for batch_idx, batch in enumerate(tepoch):
# device transfer
batch = dict_apply(batch, lambda x: x.to(device, non_blocking=True))
if train_sampling_batch is None:
train_sampling_batch = batch
with JsonLogger(log_path) as json_logger:
for local_epoch_idx in range(cfg.training.num_epochs):
step_log = dict()
# ========= train for this epoch ==========
train_losses = list()
with tqdm.tqdm(train_dataloader, desc=f"Training epoch {self.epoch}",
leave=False, mininterval=cfg.training.tqdm_interval_sec) as tepoch:
for batch_idx, batch in enumerate(tepoch):
# device transfer
batch = dict_apply(batch, lambda x: x.to(device, non_blocking=True))
if train_sampling_batch is None:
train_sampling_batch = batch
# compute loss
raw_loss = self.model.compute_loss(batch)
loss = raw_loss / cfg.training.gradient_accumulate_every
loss.backward()
# compute loss
raw_loss = self.model.compute_loss(batch)
loss = raw_loss / cfg.training.gradient_accumulate_every
loss.backward()
# step optimizer
if self.global_step % cfg.training.gradient_accumulate_every == 0:
self.optimizer.step()
self.optimizer.zero_grad()
lr_scheduler.step()
# step optimizer
if self.global_step % cfg.training.gradient_accumulate_every == 0:
self.optimizer.step()
self.optimizer.zero_grad()
lr_scheduler.step()
# update ema
if cfg.training.use_ema:
ema.step(self.model)
# update ema
if cfg.training.use_ema:
ema.step(self.model)
# logging
raw_loss_cpu = raw_loss.item()
tepoch.set_postfix(loss=raw_loss_cpu, refresh=False)
train_losses.append(raw_loss_cpu)
step_log = {
'train_loss': raw_loss_cpu,
'global_step': self.global_step,
'epoch': self.epoch,
'lr': lr_scheduler.get_last_lr()[0]
}
# logging
raw_loss_cpu = raw_loss.item()
tepoch.set_postfix(loss=raw_loss_cpu, refresh=False)
train_losses.append(raw_loss_cpu)
step_log = {
'train_loss': raw_loss_cpu,
'global_step': self.global_step,
'epoch': self.epoch,
'lr': lr_scheduler.get_last_lr()[0]
}
is_last_batch = (batch_idx == (len(train_dataloader)-1))
if not is_last_batch:
# log of last step is combined with validation and rollout
logging_backend.log(step_log, step=self.global_step)
json_logger.log(step_log)
self.global_step += 1
is_last_batch = (batch_idx == (len(train_dataloader)-1))
if not is_last_batch:
# log of last step is combined with validation and rollout
wandb_run.log(step_log, step=self.global_step)
json_logger.log(step_log)
self.global_step += 1
if (cfg.training.max_train_steps is not None) \
and batch_idx >= (cfg.training.max_train_steps-1):
break
if (cfg.training.max_train_steps is not None) \
and batch_idx >= (cfg.training.max_train_steps-1):
break
# at the end of each epoch
# replace train_loss with epoch average
train_loss = np.mean(train_losses)
step_log['train_loss'] = train_loss
# at the end of each epoch
# replace train_loss with epoch average
train_loss = np.mean(train_losses)
step_log['train_loss'] = train_loss
# ========= eval for this epoch ==========
policy = self.model
if cfg.training.use_ema:
policy = self.ema_model
policy.eval()
# ========= eval for this epoch ==========
policy = self.model
if cfg.training.use_ema:
policy = self.ema_model
policy.eval()
# run rollout
if (self.epoch % cfg.training.rollout_every) == 0:
runner_log = env_runner.run(policy)
# log all
step_log.update(runner_log)
# run rollout
if (self.epoch % cfg.training.rollout_every) == 0:
runner_log = env_runner.run(policy)
# log all
step_log.update(runner_log)
# run validation
if (self.epoch % cfg.training.val_every) == 0:
with torch.no_grad():
val_losses = list()
with tqdm.tqdm(val_dataloader, desc=f"Validation epoch {self.epoch}",
leave=False, mininterval=cfg.training.tqdm_interval_sec) as tepoch:
for batch_idx, batch in enumerate(tepoch):
batch = dict_apply(batch, lambda x: x.to(device, non_blocking=True))
loss = self.model.compute_loss(batch)
val_losses.append(loss)
if (cfg.training.max_val_steps is not None) \
and batch_idx >= (cfg.training.max_val_steps-1):
break
if len(val_losses) > 0:
val_loss = torch.mean(torch.tensor(val_losses)).item()
# log epoch average validation loss
step_log['val_loss'] = val_loss
# run validation
if (self.epoch % cfg.training.val_every) == 0:
with torch.no_grad():
val_losses = list()
with tqdm.tqdm(val_dataloader, desc=f"Validation epoch {self.epoch}",
leave=False, mininterval=cfg.training.tqdm_interval_sec) as tepoch:
for batch_idx, batch in enumerate(tepoch):
batch = dict_apply(batch, lambda x: x.to(device, non_blocking=True))
loss = self.model.compute_loss(batch)
val_losses.append(loss)
if (cfg.training.max_val_steps is not None) \
and batch_idx >= (cfg.training.max_val_steps-1):
break
if len(val_losses) > 0:
val_loss = torch.mean(torch.tensor(val_losses)).item()
# log epoch average validation loss
step_log['val_loss'] = val_loss
# run diffusion sampling on a training batch
if (self.epoch % cfg.training.sample_every) == 0:
with torch.no_grad():
# sample trajectory from training set, and evaluate difference
batch = dict_apply(train_sampling_batch, lambda x: x.to(device, non_blocking=True))
obs_dict = batch['obs']
gt_action = batch['action']
# run diffusion sampling on a training batch
if (self.epoch % cfg.training.sample_every) == 0:
with torch.no_grad():
# sample trajectory from training set, and evaluate difference
batch = dict_apply(train_sampling_batch, lambda x: x.to(device, non_blocking=True))
obs_dict = batch['obs']
gt_action = batch['action']
result = policy.predict_action(obs_dict)
pred_action = result['action_pred']
mse = torch.nn.functional.mse_loss(pred_action, gt_action)
step_log['train_action_mse_error'] = mse.item()
del batch
del obs_dict
del gt_action
del result
del pred_action
del mse
result = policy.predict_action(obs_dict)
pred_action = result['action_pred']
mse = torch.nn.functional.mse_loss(pred_action, gt_action)
step_log['train_action_mse_error'] = mse.item()
del batch
del obs_dict
del gt_action
del result
del pred_action
del mse
# checkpoint
if (self.epoch % cfg.training.checkpoint_every) == 0:
# checkpointing
if cfg.checkpoint.save_last_ckpt:
self.save_checkpoint()
if cfg.checkpoint.save_last_snapshot:
self.save_snapshot()
# checkpoint
if (self.epoch % cfg.training.checkpoint_every) == 0:
# checkpointing
if cfg.checkpoint.save_last_ckpt:
self.save_checkpoint()
if cfg.checkpoint.save_last_snapshot:
self.save_snapshot()
# sanitize metric names
metric_dict = dict()
for key, value in step_log.items():
new_key = key.replace('/', '_')
metric_dict[new_key] = value
# sanitize metric names
metric_dict = dict()
for key, value in step_log.items():
new_key = key.replace('/', '_')
metric_dict[new_key] = value
# We can't copy the last checkpoint here
# since save_checkpoint uses threads.
# therefore at this point the file might have been empty!
topk_ckpt_path = topk_manager.get_ckpt_path(metric_dict)
# We can't copy the last checkpoint here
# since save_checkpoint uses threads.
# therefore at this point the file might have been empty!
topk_ckpt_path = topk_manager.get_ckpt_path(metric_dict)
if topk_ckpt_path is not None:
self.save_checkpoint(path=topk_ckpt_path)
# ========= eval end for this epoch ==========
policy.train()
if topk_ckpt_path is not None:
self.save_checkpoint(path=topk_ckpt_path)
# ========= eval end for this epoch ==========
policy.train()
# end of epoch
# log of last step is combined with validation and rollout
logging_backend.log(step_log, step=self.global_step)
json_logger.log(step_log)
self.global_step += 1
self.epoch += 1
# end of epoch
# log of last step is combined with validation and rollout
wandb_run.log(step_log, step=self.global_step)
json_logger.log(step_log)
self.global_step += 1
self.epoch += 1
@hydra.main(
version_base=None,

View File

@@ -1,168 +0,0 @@
# PushT Image DiT iMF + SwanLab Design
## Goal
Migrate the PushT image DiT experiment path from W&B to SwanLab online logging, suppress simulation video logging, then add an iMeanFlow-based one-step transformer policy for PushT image experiments and run a controlled architecture sweep over embedding width and depth using `test_mean_score` as the primary metric.
## Context
- The implementation baseline is `main`.
- The experiment path is limited to the PushT image transformer workflow; unrelated workspaces and runners should remain unchanged.
- Environment management must use the repo-local `uv` workflow.
- The trusted remote machine alias `5880` refers to `droid-system-product-name` (`droid@100.73.14.65`) and can run two GPU jobs in parallel.
## Architecture Overview
The work is split into two verified phases:
1. **Logging migration phase**
- Keep the existing PushT image DiT training behavior intact.
- Replace W&B usage with SwanLab in the transformer hybrid workspace used by PushT image DiT experiments.
- Preserve local `logs.json.txt` output.
- Ensure rollout metrics such as `test_mean_score` and per-seed rewards are still logged.
- Disable simulation video logging at both the config and runner/logging boundary.
2. **iMF migration phase**
- Keep the original diffusion-based transformer image policy available on `main`.
- Add a parallel iMF-specific model/policy/config path rather than overwriting the baseline diffusion policy.
- Reuse the existing observation encoder and training workspace where possible.
- Replace diffusion training with the iMeanFlow training objective.
- Use one-step inference for validation/rollout in the iMF path.
The implementation planning boundary for this spec is:
- code changes through a smoke-tested, pushed iMF branch
- not the full 3x3 sweep execution/monitoring workflow, which should be planned separately after the code path is verified and pushed
## Logging Design
### Scope
Only the PushT image DiT experiment chain is changed:
- `train_diffusion_transformer_hybrid_workspace.py`
- `pusht_image_runner.py`
- the new/updated PushT image transformer configs
### Behavior
- SwanLab runs in `online` mode.
- Logged values are scalar metrics only, e.g.:
- `train_loss`
- `val_loss`
- `train_action_mse_error`
- `test_mean_score`
- aggregate rollout metrics and optional per-seed scalar rewards
- No simulation videos are uploaded or wrapped as logging objects.
- Local JSON logging remains enabled for auditability and remote-job fallback debugging.
### Operational safeguards
- Default PushT experiment configs set `task.env_runner.n_test_vis=0` and `task.env_runner.n_train_vis=0`.
- The PushT image runner will not emit video objects into `log_data`, preventing accidental uploads even if visualization counts are later changed.
- SwanLab credentials are provided through the environment at runtime, not committed into the repo.
## iMF Model Design
### Baseline reuse
The iMF path reuses:
- the existing image observation encoder
- the existing action/observation normalization path
- the existing training workspace skeleton
- the existing PushT image dataset and env runner
### New files
- `diffusion_policy/model/diffusion/imf_transformer_for_diffusion.py`
- `diffusion_policy/policy/imf_transformer_hybrid_image_policy.py`
- `image_pusht_diffusion_policy_dit_imf.yaml`
### Existing files changed for the iMF path
- `diffusion_policy/workspace/train_diffusion_transformer_hybrid_workspace.py`
- logging migration to SwanLab for this experiment chain
- no structural training-loop fork beyond instantiating the configured policy and logging scalar metrics
- `diffusion_policy/env_runner/pusht_image_runner.py`
- suppress video objects in returned logs
### Model structure
The iMF transformer mirrors the current transformer policy structure closely enough to reuse known-good conditioning patterns, but it remains a **single-head model** that predicts only:
- `u`: average velocity field
The same function is reused at two evaluation points:
- canonical signature: `fn(z, r, t, cond)`
- `fn(z_t, r, t, cond)` predicts average velocity `u`
- `fn(z_t, t, t, cond)` predicts the instantaneous velocity surrogate `v`
Inputs remain conditioned on encoded observations and action trajectory tokens.
## iMF Training Objective
For a normalized action trajectory `x`, the initial implementation follows the user-provided Algorithm 1 exactly:
1. sample `t, r`
2. sample Gaussian noise `e`
3. form `z_t = (1 - t) * x + t * e`
4. predict instantaneous velocity surrogate with the same network:
- `v = fn(z_t, t, t, cond)`
5. define the JVP function exactly as:
- `g(z, r, t) = fn(z, r, t, cond)`
6. compute the primal output and JVP with tangent:
- `u, du_dt = jvp(g, (z_t, r, t), (v.detach(), 0, 1))`
7. form compound velocity:
- `V = u + (t - r) * stopgrad(du_dt)`
8. train against the average-velocity target:
- `target = e - x`
9. optimize only the masked iMF loss:
- `loss = metric(V - target)`
There is **no auxiliary `v` loss** in the initial implementation. The implementation should prefer `torch.func.jvp` and keep a safe fallback path if the local Torch stack needs it.
## iMF Inference Design
Inference uses a single step starting from noise:
- initialize `z_1 ~ N(0, I)`
- set `t = 1.0`, `r = 0.0`
- predict `u = fn(z_1, r, t, cond)`
- produce the action sample with one update:
- `x_hat = z_1 - (t - r) * u`
This matches the time direction in the reference iMeanFlow sampling logic.
## Testing Strategy
### Phase 1: logging migration smoke test
- use the repo-local `uv` environment
- run a debug/smoke PushT image DiT training job on a single GPU with:
- `training.debug=true`
- `dataloader.num_workers=0`
- `val_dataloader.num_workers=0`
- `task.env_runner.n_envs=1`
- `task.env_runner.n_test_vis=0`
- `task.env_runner.n_train_vis=0`
- verify:
- SwanLab initializes successfully
- `logs.json.txt` is populated
- rollout metrics still include `test_mean_score`
- no video logging is attempted
### Phase 2: iMF smoke test
- run an equivalent debug PushT image iMF job
- verify:
- forward/backward passes succeed
- JVP path executes on the local Torch version
- one-step inference returns correctly shaped actions
- rollout produces scalar metrics including `test_mean_score`
## Branch and Commit Strategy
1. start from a `main`-based worktree branch
2. commit the SwanLab/no-video migration after smoke verification
3. continue with the iMF implementation
4. once iMF smoke tests pass, create/preserve a dedicated feature branch for the experiment code and push it to Gitea
## Post-Implementation Experiment Plan
After the iMF path is smoke-tested and pushed, a separate experiment-execution plan should launch:
- run a 3x3 grid over:
- `n_emb ∈ {128, 256, 384}`
- `n_layer ∈ {6, 12, 18}`
- keep the rest of the setup fixed
- use a fixed single-seed setting for comparability unless a later explicit experiment plan expands that scope
- run each experiment for 300 epochs
- primary comparison metric: `test_mean_score`
## Post-Implementation Resource Allocation
The separate experiment-execution plan should schedule three concurrent runs until the matrix is complete:
- local machine: 1 GPU
- `5880`: 2 GPUs
Each run uses the same uv-managed environment and the same pushed branch so the code path is consistent across hosts.
## Risks and Mitigations
- **Torch JVP compatibility risk**: provide a fallback JVP implementation and smoke-test immediately.
- **Logging regression risk**: keep local JSON logging and verify scalar rollout metrics before moving to iMF.
- **Video/logging side effects**: disable visualizations in config and filter video objects out of runner logs.
- **Cross-host drift**: push the verified branch to Gitea before launching the experiment matrix on multiple machines.

View File

@@ -1,30 +0,0 @@
defaults:
- diffusion_policy/config/train_diffusion_transformer_hybrid_workspace@_here_
- override /diffusion_policy/config/task@task: pusht_image
- _self_
exp_name: pusht_image_dit
policy:
_target_: diffusion_policy.policy.diffusion_transformer_hybrid_image_policy.DiffusionTransformerHybridImagePolicy
logging:
backend: swanlab
mode: online
name: ${exp_name}
resume: false
tags: ["${name}", "${task_name}", "${exp_name}", "swanlab"]
id: null
group: ${exp_name}
dataloader:
num_workers: 0
val_dataloader:
num_workers: 0
task:
env_runner:
n_envs: 1
n_test_vis: 0
n_train_vis: 0

View File

@@ -1,32 +0,0 @@
defaults:
- diffusion_policy/config/train_diffusion_transformer_hybrid_workspace@_here_
- override /diffusion_policy/config/task@task: pusht_image
- _self_
exp_name: pusht_image_dit_imf
policy:
_target_: diffusion_policy.policy.imf_transformer_hybrid_image_policy.IMFTransformerHybridImagePolicy
num_inference_steps: 1
n_head: 1
logging:
backend: swanlab
mode: online
name: ${exp_name}
resume: false
tags: ["${name}", "${task_name}", "${exp_name}", "swanlab"]
id: null
group: ${exp_name}
dataloader:
num_workers: 0
val_dataloader:
num_workers: 0
task:
env_runner:
n_envs: 1
n_test_vis: 0
n_train_vis: 0

View File

@@ -0,0 +1,189 @@
_target_: diffusion_policy.workspace.train_diffusion_transformer_hybrid_workspace.TrainDiffusionTransformerHybridWorkspace
checkpoint:
save_last_ckpt: true
save_last_snapshot: false
topk:
format_str: epoch={epoch:04d}-train_loss={train_loss:.3f}.ckpt
k: 5
mode: min
monitor_key: train_loss
dataloader:
batch_size: 64
num_workers: 8
persistent_workers: false
pin_memory: true
shuffle: true
dataset_obs_steps: 2
ema:
_target_: diffusion_policy.model.diffusion.ema_model.EMAModel
inv_gamma: 1.0
max_value: 0.9999
min_value: 0.0
power: 0.75
update_after_step: 0
exp_name: default
horizon: 16
keypoint_visible_rate: 1.0
logging:
group: null
id: null
mode: online
name: ${now:%Y.%m.%d-%H.%M.%S}_train_diffusion_transformer_hybrid_pmf_pusht_image
project: diffusion_policy_debug
resume: true
tags:
- train_diffusion_transformer_hybrid_pmf
- pusht_image
- default
multi_run:
run_dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_train_diffusion_transformer_hybrid_pmf_pusht_image
wandb_name_base: ${now:%Y.%m.%d-%H.%M.%S}_train_diffusion_transformer_hybrid_pmf_pusht_image
n_action_steps: 8
n_latency_steps: 0
n_obs_steps: 2
name: train_diffusion_transformer_hybrid_pmf
obs_as_cond: true
optimizer:
betas:
- 0.9
- 0.95
learning_rate: 0.0001
obs_encoder_weight_decay: 1.0e-06
transformer_weight_decay: 0.001
past_action_visible: false
policy:
_target_: diffusion_policy.policy.pmf_transformer_hybrid_image_policy.PMFTransformerHybridImagePolicy
crop_shape:
- 84
- 84
eval_fixed_crop: true
horizon: 16
n_action_steps: 8
n_cond_layers: 0
n_emb: 256
n_head: 4
n_layer: 12
n_obs_steps: 2
n_time_tokens: 4
noise_scale: 1.0
adatloss_eps: 0.01
p_mean: -0.4
p_std: 1.0
noise_scheduler:
_target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
beta_end: 0.02
beta_schedule: squaredcos_cap_v2
beta_start: 0.0001
clip_sample: true
num_train_timesteps: 100
prediction_type: sample
variance_type: fixed_small
num_inference_steps: 1
obs_as_cond: true
obs_encoder_group_norm: true
p_drop_attn: 0.0
p_drop_emb: 0.0
pmf_u_loss_weight: 1.0
pmf_v_loss_weight: 1.0
tr_uniform: true
tr_uniform_prob: 0.1
data_proportion: 0.5
shape_meta:
action:
shape:
- 2
obs:
agent_pos:
shape:
- 2
type: low_dim
image:
shape:
- 3
- 96
- 96
type: rgb
shape_meta:
action:
shape:
- 2
obs:
agent_pos:
shape:
- 2
type: low_dim
image:
shape:
- 3
- 96
- 96
type: rgb
task:
dataset:
_target_: diffusion_policy.dataset.pusht_image_dataset.PushTImageDataset
horizon: 16
max_train_episodes: 90
pad_after: 7
pad_before: 1
seed: 42
val_ratio: 0.02
zarr_path: data/pusht/pusht_cchi_v7_replay.zarr
env_runner:
_target_: diffusion_policy.env_runner.pusht_image_runner.PushTImageRunner
fps: 10
legacy_test: true
max_steps: 300
n_action_steps: 8
n_envs: null
n_obs_steps: 2
n_test: 50
n_test_vis: 4
n_train: 6
n_train_vis: 2
past_action: false
test_start_seed: 100000
train_start_seed: 0
image_shape:
- 3
- 96
- 96
name: pusht_image
shape_meta:
action:
shape:
- 2
obs:
agent_pos:
shape:
- 2
type: low_dim
image:
shape:
- 3
- 96
- 96
type: rgb
task_name: pusht_image
training:
checkpoint_every: 50
debug: false
device: cuda:0
gradient_accumulate_every: 1
lr_scheduler: cosine
lr_warmup_steps: 500
max_train_steps: null
max_val_steps: null
num_epochs: 600
resume: true
rollout_every: 50
sample_every: 5
seed: 42
tqdm_interval_sec: 1.0
use_ema: true
val_every: 1
val_dataloader:
batch_size: 64
num_workers: 8
persistent_workers: false
pin_memory: true
shuffle: false

View File

@@ -22,6 +22,7 @@ pymunk==6.2.1
wandb==0.13.3
threadpoolctl==3.1.0
shapely==1.8.5.post1
matplotlib==3.6.1
imageio==2.22.0
imageio-ffmpeg==0.4.7
termcolor==2.0.1
@@ -36,4 +37,3 @@ av==14.0.1
pygame==2.5.2
robomimic==0.2.0
opencv-python-headless==4.10.0.84
swanlab

View File

@@ -1,46 +0,0 @@
import inspect
import pathlib
import sys
import torch
ROOT_DIR = pathlib.Path(__file__).resolve().parents[1]
if str(ROOT_DIR) not in sys.path:
sys.path.append(str(ROOT_DIR))
from diffusion_policy.model.diffusion.imf_transformer_for_diffusion import ( # noqa: E402
IMFTransformerForDiffusion,
)
def test_imf_transformer_forward_signature_and_shape_single_head():
signature = inspect.signature(IMFTransformerForDiffusion.forward)
assert list(signature.parameters)[:5] == ['self', 'sample', 'r', 't', 'cond']
assert signature.parameters['cond'].default is None
model = IMFTransformerForDiffusion(
input_dim=3,
output_dim=3,
horizon=5,
n_obs_steps=2,
cond_dim=4,
n_layer=1,
n_head=1,
n_emb=16,
p_drop_emb=0.0,
p_drop_attn=0.0,
causal_attn=True,
time_as_cond=True,
obs_as_cond=True,
n_cond_layers=0,
)
model.configure_optimizers()
sample = torch.randn(2, 5, 3)
r = torch.rand(2)
t = torch.rand(2)
cond = torch.randn(2, 2, 4)
pred_u = model(sample, r, t, cond=cond)
assert pred_u.shape == sample.shape

View File

@@ -1,313 +0,0 @@
import pathlib
import sys
import pytest
import torch
import torch.nn as nn
ROOT_DIR = pathlib.Path(__file__).resolve().parents[1]
if str(ROOT_DIR) not in sys.path:
sys.path.append(str(ROOT_DIR))
import diffusion_policy.policy.imf_transformer_hybrid_image_policy as policy_module # noqa: E402
from diffusion_policy.model.common.module_attr_mixin import ModuleAttrMixin # noqa: E402
from diffusion_policy.policy.imf_transformer_hybrid_image_policy import ( # noqa: E402
IMFTransformerHybridImagePolicy,
)
class ConstantModel(nn.Module):
def __init__(self, value):
super().__init__()
self.value = value
def forward(self, sample, r, t, cond=None):
return torch.full_like(sample, self.value)
class AffineModel(nn.Module):
def __init__(self):
super().__init__()
self.weight = nn.Parameter(torch.tensor(2.0))
def forward(self, sample, r, t, cond=None):
return sample * self.weight + (r + t).view(-1, 1, 1)
class SumMixModel(nn.Module):
def __init__(self):
super().__init__()
self.weight = nn.Parameter(torch.tensor(2.0))
def forward(self, sample, r, t, cond=None):
mixed = sample.sum(dim=-1, keepdim=True).expand_as(sample)
return mixed * self.weight + t.view(-1, 1, 1)
class TrackingContext:
def __init__(self):
self.active = False
self.enter_count = 0
def __enter__(self):
self.active = True
self.enter_count += 1
return self
def __exit__(self, exc_type, exc, tb):
self.active = False
return False
def make_policy(model):
policy = IMFTransformerHybridImagePolicy.__new__(IMFTransformerHybridImagePolicy)
ModuleAttrMixin.__init__(policy)
policy.model = model
return policy
def fake_parent_init(
self,
shape_meta,
noise_scheduler,
horizon,
n_action_steps,
n_obs_steps,
num_inference_steps=None,
crop_shape=(76, 76),
obs_encoder_group_norm=False,
eval_fixed_crop=False,
n_layer=8,
n_cond_layers=0,
n_head=1,
n_emb=256,
p_drop_emb=0.0,
p_drop_attn=0.3,
causal_attn=True,
time_as_cond=True,
obs_as_cond=True,
pred_action_steps_only=False,
**kwargs,
):
ModuleAttrMixin.__init__(self)
self.action_dim = shape_meta['action']['shape'][0]
self.obs_feature_dim = 4
self.obs_as_cond = obs_as_cond
self.pred_action_steps_only = pred_action_steps_only
self.n_action_steps = n_action_steps
self.n_obs_steps = n_obs_steps
self.horizon = horizon
@pytest.fixture
def shape_meta():
return {
'action': {'shape': [2]},
'obs': {},
}
def test_sample_one_step_uses_imf_update_formula():
policy = make_policy(ConstantModel(0.25))
z_1 = torch.tensor([
[[1.0, -1.0], [0.5, 0.0]],
[[2.0, 3.0], [-2.0, 4.0]],
])
r = torch.zeros(z_1.shape[0])
t = torch.ones(z_1.shape[0])
x_hat = policy._sample_one_step(z_1, r=r, t=t, cond=None)
expected = z_1 - (t - r).view(-1, 1, 1) * 0.25
assert torch.allclose(x_hat, expected)
def test_compound_velocity_uses_detached_du_dt_term():
policy = make_policy(ConstantModel(0.0))
u = torch.tensor([[[1.0], [2.0]]], requires_grad=True)
du_dt = torch.tensor([[[3.0], [4.0]]], requires_grad=True)
r = torch.tensor([0.2])
t = torch.tensor([0.8])
compound = policy._compound_velocity(u, du_dt, r, t)
expected = u + (t - r).view(-1, 1, 1) * du_dt.detach()
assert torch.allclose(compound, expected)
compound.sum().backward()
assert u.grad is not None
assert du_dt.grad is None
def test_compute_u_and_du_dt_uses_math_sdpa_context_for_torch_func_jvp(monkeypatch):
tracker = TrackingContext()
def fake_jvp(fn, primals, tangents):
assert tracker.active is True
return fn(*primals), torch.zeros_like(primals[0])
monkeypatch.setattr(policy_module, 'TORCH_FUNC_JVP', fake_jvp)
policy = make_policy(ConstantModel(0.5))
policy._jvp_math_sdp_context = lambda tensor: tracker
z_t = torch.randn(2, 3, 4)
r = torch.rand(2, requires_grad=True)
t = torch.rand(2, requires_grad=True)
v = torch.randn_like(z_t, requires_grad=True)
policy._compute_u_and_du_dt(z_t, r, t, cond=None, v=v)
assert tracker.enter_count == 1
def test_compute_u_and_du_dt_uses_math_sdpa_context_for_autograd_fallback(monkeypatch):
tracker = TrackingContext()
def fake_autograd_jvp(fn, primals, tangents, create_graph=False, strict=False):
assert tracker.active is True
return fn(*primals), torch.zeros_like(primals[0])
monkeypatch.setattr(policy_module, 'TORCH_FUNC_JVP', None)
monkeypatch.setattr(policy_module.torch.autograd.functional, 'jvp', fake_autograd_jvp)
policy = make_policy(ConstantModel(0.5))
policy._jvp_math_sdp_context = lambda tensor: tracker
z_t = torch.randn(2, 3, 4)
r = torch.rand(2, requires_grad=True)
t = torch.rand(2, requires_grad=True)
v = torch.randn_like(z_t, requires_grad=True)
policy._compute_u_and_du_dt(z_t, r, t, cond=None, v=v)
assert tracker.enter_count == 1
def test_compute_u_and_du_dt_uses_detached_v_zero_r_unit_t_and_reapplies_conditioning(monkeypatch):
captured = {}
def fake_jvp(fn, primals, tangents):
captured['tangents'] = tangents
captured['primal_output'] = fn(*primals)
return captured['primal_output'], torch.zeros_like(primals[0])
monkeypatch.setattr(policy_module, 'TORCH_FUNC_JVP', fake_jvp)
policy = make_policy(SumMixModel())
z_t = torch.tensor([[[1.0, 2.0, 3.0]]])
r = torch.rand(1, requires_grad=True)
t = torch.rand(1, requires_grad=True)
v = torch.tensor([[[10.0, 20.0, 30.0]]], requires_grad=True)
condition_mask = torch.tensor([[[False, True, False]]])
condition_data = torch.tensor([[[0.0, 7.0, 0.0]]])
policy._compute_u_and_du_dt(
z_t,
r,
t,
cond=None,
v=v,
condition_data=condition_data,
condition_mask=condition_mask,
)
tangent_v, tangent_r, tangent_t = captured['tangents']
assert torch.equal(tangent_v, v.detach())
assert tangent_v.requires_grad is False
assert torch.equal(tangent_r, torch.zeros_like(r))
assert torch.equal(tangent_t, torch.ones_like(t))
conditioned = z_t.clone()
conditioned[condition_mask] = condition_data[condition_mask]
expected_primal = policy.model(conditioned, r, t, cond=None)
assert torch.allclose(captured['primal_output'], expected_primal)
def test_compute_u_and_du_dt_fallback_blocks_conditioned_tangent_leakage_and_keeps_primal_gradients(monkeypatch):
monkeypatch.setattr(policy_module, 'TORCH_FUNC_JVP', None)
policy = make_policy(SumMixModel())
z_t = torch.tensor([[[1.0, 2.0, 3.0]]], requires_grad=True)
r = torch.rand(1, requires_grad=True)
t = torch.rand(1, requires_grad=True)
v = torch.tensor([[[1.0, 10.0, 100.0]]], requires_grad=True)
condition_mask = torch.tensor([[[False, True, False]]])
condition_data = torch.tensor([[[0.0, 7.0, 0.0]]])
u, du_dt = policy._compute_u_and_du_dt(
z_t,
r,
t,
cond=None,
v=v,
condition_data=condition_data,
condition_mask=condition_mask,
)
conditioned = z_t.detach().clone()
conditioned[condition_mask] = condition_data[condition_mask]
expected_u = policy.model(conditioned, r, t, cond=None)
expected_du_dt_scalar = policy.model.weight.detach() * torch.tensor(101.0) + 1.0
expected_du_dt = torch.full_like(z_t, expected_du_dt_scalar)
assert u.shape == z_t.shape
assert du_dt.shape == z_t.shape
assert torch.allclose(u, expected_u)
assert torch.allclose(du_dt, expected_du_dt)
u.sum().backward()
assert policy.model.weight.grad is not None
assert torch.count_nonzero(policy.model.weight.grad) > 0
def test_init_uses_action_step_horizon_when_pred_action_steps_only(monkeypatch, shape_meta):
monkeypatch.setattr(
policy_module.DiffusionTransformerHybridImagePolicy,
'__init__',
fake_parent_init,
)
policy = IMFTransformerHybridImagePolicy(
shape_meta=shape_meta,
noise_scheduler=None,
horizon=10,
n_action_steps=4,
n_obs_steps=2,
num_inference_steps=1,
n_layer=1,
n_head=1,
n_emb=16,
p_drop_emb=0.0,
p_drop_attn=0.0,
causal_attn=True,
obs_as_cond=True,
pred_action_steps_only=True,
)
assert policy.model.horizon == 4
assert policy.num_inference_steps == 1
def test_init_rejects_non_one_step_inference(monkeypatch, shape_meta):
monkeypatch.setattr(
policy_module.DiffusionTransformerHybridImagePolicy,
'__init__',
fake_parent_init,
)
with pytest.raises(ValueError, match='num_inference_steps'):
IMFTransformerHybridImagePolicy(
shape_meta=shape_meta,
noise_scheduler=None,
horizon=10,
n_action_steps=4,
n_obs_steps=2,
num_inference_steps=2,
n_layer=1,
n_head=1,
n_emb=16,
p_drop_emb=0.0,
p_drop_attn=0.0,
causal_attn=True,
obs_as_cond=True,
pred_action_steps_only=False,
)

View File

@@ -1,110 +0,0 @@
import pathlib
import sys
import gym
from gym import spaces
import numpy as np
import pytest
import torch
ROOT_DIR = pathlib.Path(__file__).resolve().parents[1]
if str(ROOT_DIR) not in sys.path:
sys.path.append(str(ROOT_DIR))
import diffusion_policy.env_runner.pusht_image_runner as runner_module
from diffusion_policy.env_runner.pusht_image_runner import summarize_rollout_metrics
class FakePushTImageEnv(gym.Env):
metadata = {'render.modes': ['rgb_array']}
def __init__(self, legacy=False, render_size=96):
del legacy, render_size
self.observation_space = spaces.Dict({
'image': spaces.Box(low=0, high=255, shape=(3, 4, 4), dtype=np.uint8),
})
self.action_space = spaces.Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32)
self.seed_value = 0
self.step_count = 0
def seed(self, seed=None):
self.seed_value = 0 if seed is None else seed
def reset(self):
self.step_count = 0
return {'image': np.zeros((3, 4, 4), dtype=np.uint8)}
def step(self, action):
del action
self.step_count += 1
reward = 0.1 if self.seed_value < 10000 else 0.9
done = self.step_count >= 1
obs = {'image': np.full((3, 4, 4), self.step_count, dtype=np.uint8)}
return obs, reward, done, {}
def render(self, *args, **kwargs):
raise AssertionError('render should not be called for scalar-only PushT image rollouts')
class FakePolicy:
device = torch.device('cpu')
dtype = torch.float32
def reset(self):
return None
def predict_action(self, obs_dict):
n_envs = next(iter(obs_dict.values())).shape[0]
return {
'action': torch.zeros((n_envs, 2, 2), dtype=torch.float32),
}
def test_summarize_rollout_metrics_keeps_scalar_rewards_renames_means_and_omits_videos():
log_data = summarize_rollout_metrics(
env_seeds=[11, 12, 101],
env_prefixs=['train/', 'train/', 'test/'],
all_rewards=[
[0.2, 0.8],
[0.1, 0.4],
[0.5, 0.9],
],
all_video_paths=[
'/tmp/train-11.mp4',
'/tmp/train-12.mp4',
'/tmp/test-101.mp4',
],
)
assert log_data['train/sim_max_reward_11'] == 0.8
assert log_data['train/sim_max_reward_12'] == 0.4
assert log_data['test/sim_max_reward_101'] == 0.9
assert log_data['train_mean_score'] == pytest.approx(0.6)
assert log_data['test_mean_score'] == pytest.approx(0.9)
assert not any(key.startswith('train/sim_video_') for key in log_data)
assert not any(key.startswith('test/sim_video_') for key in log_data)
def test_runner_ignores_vis_flags_and_never_emits_sim_videos(tmp_path, monkeypatch):
monkeypatch.setattr(runner_module, 'PushTImageEnv', FakePushTImageEnv)
runner = runner_module.PushTImageRunner(
output_dir=tmp_path,
n_train=1,
n_train_vis=1,
n_test=1,
n_test_vis=1,
n_envs=2,
max_steps=2,
n_obs_steps=2,
n_action_steps=2,
tqdm_interval_sec=0.0,
)
log_data = runner.run(FakePolicy())
assert log_data['train/sim_max_reward_0'] == pytest.approx(0.1)
assert log_data['test/sim_max_reward_10000'] == pytest.approx(0.9)
assert log_data['train_mean_score'] == pytest.approx(0.1)
assert log_data['test_mean_score'] == pytest.approx(0.9)
assert not any('sim_video' in key for key in log_data)

View File

@@ -1,32 +0,0 @@
import pathlib
from omegaconf import OmegaConf
ROOT_DIR = pathlib.Path(__file__).resolve().parents[1]
def _load_cfg(name: str):
return OmegaConf.load(ROOT_DIR / name)
def test_image_pusht_dit_swanlab_config_uses_exp_name_and_no_resume_collision():
cfg = _load_cfg('image_pusht_diffusion_policy_dit.yaml')
assert cfg.logging.backend == 'swanlab'
assert cfg.logging.mode == 'online'
assert cfg.logging.name == cfg.exp_name
assert cfg.logging.resume is False
assert cfg.logging.id is None
assert cfg.logging.group == cfg.exp_name
def test_image_pusht_dit_imf_swanlab_config_uses_exp_name_and_no_resume_collision():
cfg = _load_cfg('image_pusht_diffusion_policy_dit_imf.yaml')
assert cfg.logging.backend == 'swanlab'
assert cfg.logging.mode == 'online'
assert cfg.logging.name == cfg.exp_name
assert cfg.logging.resume is False
assert cfg.logging.id is None
assert cfg.logging.group == cfg.exp_name

View File

@@ -1,198 +0,0 @@
import importlib
import pathlib
import sys
import pytest
from omegaconf import OmegaConf
ROOT_DIR = pathlib.Path(__file__).resolve().parents[1]
if str(ROOT_DIR) not in sys.path:
sys.path.append(str(ROOT_DIR))
MODULE_NAME = 'diffusion_policy.workspace.train_diffusion_transformer_hybrid_workspace'
def load_workspace_module(monkeypatch, *, wandb_missing=False):
sys.modules.pop(MODULE_NAME, None)
if wandb_missing:
monkeypatch.setitem(sys.modules, 'wandb', None)
return importlib.import_module(MODULE_NAME)
def test_init_logger_uses_swanlab_backend_mapping_without_loading_wandb(tmp_path, monkeypatch):
workspace_module = load_workspace_module(monkeypatch, wandb_missing=True)
events = []
class FakeRun:
def log(self, payload, step=None):
events.append(('log', payload, step))
def finish(self):
events.append(('finish',))
class FakeSwanLab:
def init(self, **kwargs):
events.append(('init', kwargs))
return FakeRun()
monkeypatch.setattr(workspace_module, '_load_swanlab', lambda: FakeSwanLab())
monkeypatch.setattr(
workspace_module,
'_load_wandb',
lambda: pytest.fail('wandb should not be loaded for the SwanLab backend'),
)
cfg = OmegaConf.create({
'logging': {
'backend': 'swanlab',
'project': 'demo-project',
'name': 'demo-run',
'group': 'demo-group',
'tags': ['pusht', 'dit'],
'id': 'run-123',
'resume': True,
'mode': 'online',
}
})
logger = workspace_module.init_logging_backend(cfg=cfg, output_dir=tmp_path)
logger.log({'metric': 1.0}, step=7)
logger.finish()
assert events[0][0] == 'init'
init_kwargs = events[0][1]
assert init_kwargs['project'] == 'demo-project'
assert init_kwargs['experiment_name'] == 'demo-run'
assert init_kwargs['group'] == 'demo-group'
assert init_kwargs['tags'] == ['pusht', 'dit']
assert init_kwargs['id'] == 'run-123'
assert init_kwargs['resume'] is True
assert init_kwargs['mode'] == 'cloud'
assert init_kwargs['logdir'] == str(tmp_path / 'swanlog')
assert ('log', {'metric': 1.0}, 7) in events
assert events.count(('finish',)) == 1
def test_init_logger_defaults_to_legacy_wandb_path_when_backend_missing(tmp_path, monkeypatch):
workspace_module = load_workspace_module(monkeypatch)
events = []
class FakeRun:
def log(self, payload, step=None):
events.append(('log', payload, step))
def finish(self):
events.append(('finish',))
class FakeConfig:
def update(self, payload):
events.append(('config.update', payload))
class FakeWandb:
def __init__(self):
self.config = FakeConfig()
def init(self, **kwargs):
events.append(('init', kwargs))
return FakeRun()
monkeypatch.setattr(workspace_module, '_load_wandb', lambda: FakeWandb())
cfg = OmegaConf.create({
'logging': {
'project': 'demo-project',
'name': 'demo-run',
'group': None,
'tags': ['shared'],
'id': None,
'resume': True,
'mode': 'online',
}
})
logger = workspace_module.init_logging_backend(cfg=cfg, output_dir=tmp_path)
logger.log({'metric': 2.0}, step=3)
logger.finish()
assert events[0][0] == 'init'
init_kwargs = events[0][1]
assert init_kwargs['dir'] == str(tmp_path)
assert init_kwargs['project'] == 'demo-project'
assert init_kwargs['name'] == 'demo-run'
assert init_kwargs['mode'] == 'online'
assert ('config.update', {'output_dir': str(tmp_path)}) in events
assert ('log', {'metric': 2.0}, 3) in events
assert events.count(('finish',)) == 1
def test_init_logger_rejects_unknown_backends(tmp_path, monkeypatch):
workspace_module = load_workspace_module(monkeypatch)
cfg = OmegaConf.create({
'logging': {
'backend': 'tensorboard',
'project': 'demo-project',
'name': 'demo-run',
'mode': 'offline',
}
})
with pytest.raises(ValueError, match='Unknown logging backend'):
workspace_module.init_logging_backend(cfg=cfg, output_dir=tmp_path)
def test_logging_backend_session_preserves_primary_exception_when_finish_fails(tmp_path, monkeypatch):
workspace_module = load_workspace_module(monkeypatch)
events = []
class FakeBackend:
def log(self, payload, step=None):
events.append(('log', payload, step))
def finish(self):
events.append(('finish',))
raise RuntimeError('finish boom')
monkeypatch.setattr(
workspace_module,
'init_logging_backend',
lambda cfg, output_dir: FakeBackend(),
)
cfg = OmegaConf.create({'logging': {'mode': 'offline'}})
with pytest.raises(ValueError, match='primary boom'):
with workspace_module.logging_backend_session(cfg=cfg, output_dir=tmp_path) as logger:
logger.log({'metric': 6.0}, step=12)
raise ValueError('primary boom')
assert ('log', {'metric': 6.0}, 12) in events
assert events.count(('finish',)) == 1
def test_logging_backend_session_finishes_on_exception(tmp_path, monkeypatch):
workspace_module = load_workspace_module(monkeypatch)
events = []
class FakeBackend:
def log(self, payload, step=None):
events.append(('log', payload, step))
def finish(self):
events.append(('finish',))
monkeypatch.setattr(
workspace_module,
'init_logging_backend',
lambda cfg, output_dir: FakeBackend(),
)
cfg = OmegaConf.create({'logging': {'mode': 'offline'}})
with pytest.raises(RuntimeError, match='boom'):
with workspace_module.logging_backend_session(cfg=cfg, output_dir=tmp_path) as logger:
logger.log({'metric': 5.0}, step=11)
raise RuntimeError('boom')
assert ('log', {'metric': 5.0}, 11) in events
assert events.count(('finish',)) == 1