4 Commits

Author SHA1 Message Date
Logic
9169e4d7e0 feat(pusht): add dual-head uv transformer 2026-03-17 17:05:02 +08:00
gameloader
42dc29a2cb feat: align pmf transformer training and config defaults 2026-03-16 15:37:32 +08:00
Logic
79f31940c4 feat(pusht): add pMF-style DiT image policy 2026-03-16 11:11:43 +08:00
Logic
2aa06c8917 fix(pusht): stabilize DiT pusht training on current stack 2026-03-15 18:54:50 +08:00
20 changed files with 1174 additions and 1988 deletions

View File

@@ -1,37 +1,21 @@
import wandb
import numpy as np import numpy as np
import torch import torch
import collections import collections
import pathlib
import tqdm import tqdm
import dill import dill
import math import math
import wandb.sdk.data_types.video as wv
from diffusion_policy.env.pusht.pusht_image_env import PushTImageEnv from diffusion_policy.env.pusht.pusht_image_env import PushTImageEnv
from diffusion_policy.gym_util.sync_vector_env import SyncVectorEnv from diffusion_policy.gym_util.sync_vector_env import SyncVectorEnv
from diffusion_policy.gym_util.multistep_wrapper import MultiStepWrapper from diffusion_policy.gym_util.multistep_wrapper import MultiStepWrapper
from diffusion_policy.gym_util.video_recording_wrapper import VideoRecordingWrapper, VideoRecorder
from diffusion_policy.policy.base_image_policy import BaseImagePolicy from diffusion_policy.policy.base_image_policy import BaseImagePolicy
from diffusion_policy.common.pytorch_util import dict_apply from diffusion_policy.common.pytorch_util import dict_apply
from diffusion_policy.env_runner.base_image_runner import BaseImageRunner from diffusion_policy.env_runner.base_image_runner import BaseImageRunner
def summarize_rollout_metrics(env_seeds, env_prefixs, all_rewards, all_video_paths=None):
    """Collapse per-episode reward traces into a flat logging dictionary.

    For every (seed, prefix, rewards) triple the maximum reward of the
    episode is stored under ``prefix + 'sim_max_reward_<seed>'``.  Afterwards
    the mean of those maxima is stored once per distinct prefix:
    ``'train/'`` maps to ``'train_mean_score'``, ``'test/'`` maps to
    ``'test_mean_score'`` and any other prefix maps to
    ``prefix + 'mean_score'``.

    Args:
        env_seeds: iterable of seeds, one per episode.
        env_prefixs: iterable of string prefixes, one per episode.
        all_rewards: iterable of reward sequences, one per episode.
        all_video_paths: accepted but unused; it is discarded immediately.

    Returns:
        dict mapping metric names to numpy scalars.
    """
    # Video paths are not part of the summary; drop the reference right away.
    del all_video_paths

    metrics = dict()
    best_by_prefix = collections.defaultdict(list)

    # The three inputs are walked in lockstep (zip stops at the shortest).
    for seed, prefix, rewards in zip(env_seeds, env_prefixs, all_rewards):
        episode_best = np.max(rewards)
        metrics[prefix + f'sim_max_reward_{seed}'] = episode_best
        best_by_prefix[prefix].append(episode_best)

    # Aggregate keys are written after all per-episode keys.
    for prefix, episode_bests in best_by_prefix.items():
        if prefix == 'train/':
            aggregate_key = 'train_mean_score'
        elif prefix == 'test/':
            aggregate_key = 'test_mean_score'
        else:
            aggregate_key = prefix + 'mean_score'
        metrics[aggregate_key] = np.mean(episode_bests)

    return metrics
class PushTImageRunner(BaseImageRunner): class PushTImageRunner(BaseImageRunner):
def __init__(self, def __init__(self,
output_dir, output_dir,
@@ -56,11 +40,24 @@ class PushTImageRunner(BaseImageRunner):
if n_envs is None: if n_envs is None:
n_envs = n_train + n_test n_envs = n_train + n_test
steps_per_render = max(10 // fps, 1)
def env_fn(): def env_fn():
return MultiStepWrapper( return MultiStepWrapper(
PushTImageEnv( VideoRecordingWrapper(
legacy=legacy_test, PushTImageEnv(
render_size=render_size legacy=legacy_test,
render_size=render_size
),
video_recoder=VideoRecorder.create_h264(
fps=fps,
codec='h264',
input_pix_fmt='rgb24',
crf=crf,
thread_type='FRAME',
thread_count=1
),
file_path=None,
steps_per_render=steps_per_render
), ),
n_obs_steps=n_obs_steps, n_obs_steps=n_obs_steps,
n_action_steps=n_action_steps, n_action_steps=n_action_steps,
@@ -74,8 +71,21 @@ class PushTImageRunner(BaseImageRunner):
# train # train
for i in range(n_train): for i in range(n_train):
seed = train_start_seed + i seed = train_start_seed + i
enable_render = i < n_train_vis
def init_fn(env, seed=seed, enable_render=enable_render):
# setup rendering
# video_wrapper
assert isinstance(env.env, VideoRecordingWrapper)
env.env.video_recoder.stop()
env.env.file_path = None
if enable_render:
filename = pathlib.Path(output_dir).joinpath(
'media', wv.util.generate_id() + ".mp4")
filename.parent.mkdir(parents=False, exist_ok=True)
filename = str(filename)
env.env.file_path = filename
def init_fn(env, seed=seed):
# set seed # set seed
assert isinstance(env, MultiStepWrapper) assert isinstance(env, MultiStepWrapper)
env.seed(seed) env.seed(seed)
@@ -87,8 +97,21 @@ class PushTImageRunner(BaseImageRunner):
# test # test
for i in range(n_test): for i in range(n_test):
seed = test_start_seed + i seed = test_start_seed + i
enable_render = i < n_test_vis
def init_fn(env, seed=seed, enable_render=enable_render):
# setup rendering
# video_wrapper
assert isinstance(env.env, VideoRecordingWrapper)
env.env.video_recoder.stop()
env.env.file_path = None
if enable_render:
filename = pathlib.Path(output_dir).joinpath(
'media', wv.util.generate_id() + ".mp4")
filename.parent.mkdir(parents=False, exist_ok=True)
filename = str(filename)
env.env.file_path = filename
def init_fn(env, seed=seed):
# set seed # set seed
assert isinstance(env, MultiStepWrapper) assert isinstance(env, MultiStepWrapper)
env.seed(seed) env.seed(seed)
@@ -131,6 +154,7 @@ class PushTImageRunner(BaseImageRunner):
n_chunks = math.ceil(n_inits / n_envs) n_chunks = math.ceil(n_inits / n_envs)
# allocate data # allocate data
all_video_paths = [None] * n_inits
all_rewards = [None] * n_inits all_rewards = [None] * n_inits
for chunk_idx in range(n_chunks): for chunk_idx in range(n_chunks):
@@ -190,16 +214,39 @@ class PushTImageRunner(BaseImageRunner):
pbar.update(action.shape[1]) pbar.update(action.shape[1])
pbar.close() pbar.close()
all_video_paths[this_global_slice] = env.render()[this_local_slice]
all_rewards[this_global_slice] = env.call('get_attr', 'reward')[this_local_slice] all_rewards[this_global_slice] = env.call('get_attr', 'reward')[this_local_slice]
# reset env state between evaluation calls # clear out video buffer
_ = env.reset() _ = env.reset()
# results reported in the paper are generated using the commented out # log
# line below, which would only report and average metrics from the max_rewards = collections.defaultdict(list)
# first n_envs initial conditions and seeds. We keep the full n_inits log_data = dict()
# behavior here. # results reported in the paper are generated using the commented out line below
return summarize_rollout_metrics( # which will only report and average metrics from first n_envs initial condition and seeds
env_seeds=self.env_seeds[:n_inits], # fortunately this won't invalidate our conclusion since
env_prefixs=self.env_prefixs[:n_inits], # 1. This bug only affects the variance of metrics, not their mean
all_rewards=all_rewards[:n_inits], # 2. All baseline methods are evaluated using the same code
) # to completely reproduce reported numbers, uncomment this line:
# for i in range(len(self.env_fns)):
# and comment out this line
for i in range(n_inits):
seed = self.env_seeds[i]
prefix = self.env_prefixs[i]
max_reward = np.max(all_rewards[i])
max_rewards[prefix].append(max_reward)
log_data[prefix+f'sim_max_reward_{seed}'] = max_reward
# visualize sim
video_path = all_video_paths[i]
if video_path is not None:
sim_video = wandb.Video(video_path)
log_data[prefix+f'sim_video_{seed}'] = sim_video
# log aggregate metrics
for prefix, value in max_rewards.items():
name = prefix+'mean_score'
value = np.mean(value)
log_data[name] = value
return log_data

View File

@@ -8,8 +8,7 @@ import dill
import math import math
import wandb.sdk.data_types.video as wv import wandb.sdk.data_types.video as wv
from diffusion_policy.env.pusht.pusht_keypoints_env import PushTKeypointsEnv from diffusion_policy.env.pusht.pusht_keypoints_env import PushTKeypointsEnv
from diffusion_policy.gym_util.async_vector_env import AsyncVectorEnv from diffusion_policy.gym_util.sync_vector_env import SyncVectorEnv
# from diffusion_policy.gym_util.sync_vector_env import SyncVectorEnv
from diffusion_policy.gym_util.multistep_wrapper import MultiStepWrapper from diffusion_policy.gym_util.multistep_wrapper import MultiStepWrapper
from diffusion_policy.gym_util.video_recording_wrapper import VideoRecordingWrapper, VideoRecorder from diffusion_policy.gym_util.video_recording_wrapper import VideoRecordingWrapper, VideoRecorder
@@ -133,7 +132,7 @@ class PushTKeypointsRunner(BaseLowdimRunner):
env_prefixs.append('test/') env_prefixs.append('test/')
env_init_fn_dills.append(dill.dumps(init_fn)) env_init_fn_dills.append(dill.dumps(init_fn))
env = AsyncVectorEnv(env_fns) env = SyncVectorEnv(env_fns)
# test env # test env
# env.reset(seed=env_seeds) # env.reset(seed=env_seeds)

View File

@@ -1,298 +0,0 @@
from typing import Optional, Tuple, Union
import logging
import torch
import torch.nn as nn
from diffusion_policy.model.common.module_attr_mixin import ModuleAttrMixin
from diffusion_policy.model.diffusion.positional_embedding import SinusoidalPosEmb
logger = logging.getLogger(__name__)
class IMFTransformerForDiffusion(ModuleAttrMixin):
    """Transformer backbone conditioned on two scalar times ``r`` and ``t``.

    Two layouts are built depending on the constructor arguments:

    * encoder/decoder (``time_as_cond=True``): the ``r`` and ``t`` embeddings,
      optionally followed by embedded observation features, form the memory
      that a ``nn.TransformerDecoder`` cross-attends to.
    * encoder only (``time_as_cond=False``): the ``r`` and ``t`` embeddings are
      prepended to the input tokens and processed by a single
      ``nn.TransformerEncoder``; the two leading tokens are dropped from the
      output.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        horizon: int,
        n_obs_steps: Optional[int] = None,
        cond_dim: int = 0,
        n_layer: int = 12,
        n_head: int = 1,
        n_emb: int = 768,
        p_drop_emb: float = 0.1,
        p_drop_attn: float = 0.1,
        causal_attn: bool = False,
        time_as_cond: bool = True,
        obs_as_cond: bool = False,
        n_cond_layers: int = 0,
    ) -> None:
        """Build embeddings, encoder/decoder stacks, masks and output head.

        Args:
            input_dim: feature size of each input token.
            output_dim: feature size of each output token.
            horizon: number of input tokens the positional table is sized for.
            n_obs_steps: number of observation tokens; defaults to ``horizon``.
            cond_dim: observation feature size; ``> 0`` enables observation
                conditioning.
            n_layer: number of main transformer layers.
            n_head: attention heads; must be 1 (asserted).
            n_emb: embedding width.
            p_drop_emb: dropout applied to summed token + position embeddings.
            p_drop_attn: dropout inside the transformer layers.
            causal_attn: build causal self-attention (and memory) masks.
            time_as_cond: feed the time tokens through the conditioning path
                instead of prepending them to the input sequence.
            obs_as_cond: accepted for signature compatibility; the effective
                value is recomputed as ``cond_dim > 0``.
            n_cond_layers: transformer layers in the conditioning encoder;
                ``0`` selects a small MLP instead.
        """
        super().__init__()
        assert n_head == 1, 'IMFTransformerForDiffusion currently supports single-head attention only.'

        if n_obs_steps is None:
            n_obs_steps = horizon

        # T: length of the main token sequence, T_cond: length of the memory.
        T = horizon
        T_cond = 2  # one token each for r and t
        if not time_as_cond:
            # Time tokens move from the memory into the main sequence.
            T += 2
            T_cond -= 2
        # The passed-in flag is overridden: conditioning is driven by cond_dim.
        obs_as_cond = cond_dim > 0
        if obs_as_cond:
            assert time_as_cond
            T_cond += n_obs_steps

        self.input_emb = nn.Linear(input_dim, n_emb)
        self.pos_emb = nn.Parameter(torch.zeros(1, T, n_emb))
        self.drop = nn.Dropout(p_drop_emb)

        # A single sinusoidal embedding is shared by r and t.
        self.time_emb = SinusoidalPosEmb(n_emb)
        self.cond_obs_emb = None
        if obs_as_cond:
            self.cond_obs_emb = nn.Linear(cond_dim, n_emb)

        self.cond_pos_emb = None
        self.encoder = None
        self.decoder = None
        encoder_only = False
        if T_cond > 0:
            self.cond_pos_emb = nn.Parameter(torch.zeros(1, T_cond, n_emb))
            if n_cond_layers > 0:
                encoder_layer = nn.TransformerEncoderLayer(
                    d_model=n_emb,
                    nhead=n_head,
                    dim_feedforward=4 * n_emb,
                    dropout=p_drop_attn,
                    activation='gelu',
                    batch_first=True,
                    norm_first=True,
                )
                self.encoder = nn.TransformerEncoder(
                    encoder_layer=encoder_layer,
                    num_layers=n_cond_layers,
                )
            else:
                self.encoder = nn.Sequential(
                    nn.Linear(n_emb, 4 * n_emb),
                    nn.Mish(),
                    nn.Linear(4 * n_emb, n_emb),
                )
            decoder_layer = nn.TransformerDecoderLayer(
                d_model=n_emb,
                nhead=n_head,
                dim_feedforward=4 * n_emb,
                dropout=p_drop_attn,
                activation='gelu',
                batch_first=True,
                norm_first=True,
            )
            self.decoder = nn.TransformerDecoder(
                decoder_layer=decoder_layer,
                num_layers=n_layer,
            )
        else:
            # No memory tokens at all: a plain encoder processes everything.
            encoder_only = True
            encoder_layer = nn.TransformerEncoderLayer(
                d_model=n_emb,
                nhead=n_head,
                dim_feedforward=4 * n_emb,
                dropout=p_drop_attn,
                activation='gelu',
                batch_first=True,
                norm_first=True,
            )
            self.encoder = nn.TransformerEncoder(
                encoder_layer=encoder_layer,
                num_layers=n_layer,
            )

        if causal_attn:
            # Additive mask: 0.0 where attention is allowed, -inf elsewhere.
            sz = T
            mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
            mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
            self.register_buffer('mask', mask)

            if time_as_cond and obs_as_cond:
                S = T_cond
                t_idx, s_idx = torch.meshgrid(
                    torch.arange(T),
                    torch.arange(S),
                    indexing='ij',
                )
                # The first two memory slots are the time tokens, hence the
                # offset of 2: they are visible to every query position.
                mask = t_idx >= (s_idx - 2)
                mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
                self.register_buffer('memory_mask', mask)
            else:
                self.memory_mask = None
        else:
            self.mask = None
            self.memory_mask = None

        self.ln_f = nn.LayerNorm(n_emb)
        self.head = nn.Linear(n_emb, output_dim)

        self.T = T
        self.T_cond = T_cond
        self.horizon = horizon
        self.time_as_cond = time_as_cond
        self.obs_as_cond = obs_as_cond
        self.encoder_only = encoder_only

        self.apply(self._init_weights)
        logger.info(
            'number of parameters: %e',
            sum(p.numel() for p in self.parameters()),
        )

    def _init_weights(self, module):
        """Initialise one submodule; used as the callback of ``self.apply``.

        Raises:
            RuntimeError: if ``module`` is of a type this method does not
                know how to handle.
        """
        ignore_types = (
            nn.Dropout,
            SinusoidalPosEmb,
            nn.TransformerEncoderLayer,
            nn.TransformerDecoderLayer,
            nn.TransformerEncoder,
            nn.TransformerDecoder,
            nn.ModuleList,
            nn.Mish,
            nn.Sequential,
        )
        if isinstance(module, (nn.Linear, nn.Embedding)):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.MultiheadAttention):
            weight_names = ['in_proj_weight', 'q_proj_weight', 'k_proj_weight', 'v_proj_weight']
            for name in weight_names:
                weight = getattr(module, name)
                if weight is not None:
                    torch.nn.init.normal_(weight, mean=0.0, std=0.02)
            bias_names = ['in_proj_bias', 'bias_k', 'bias_v']
            for name in bias_names:
                bias = getattr(module, name)
                if bias is not None:
                    torch.nn.init.zeros_(bias)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)
        elif isinstance(module, IMFTransformerForDiffusion):
            torch.nn.init.normal_(module.pos_emb, mean=0.0, std=0.02)
            # Gate on the parameter that is actually initialised.  The table
            # also exists when only the two time tokens are used as
            # conditioning (cond_obs_emb is None in that case), and it must
            # not be left at its all-zero construction value.
            if module.cond_pos_emb is not None:
                torch.nn.init.normal_(module.cond_pos_emb, mean=0.0, std=0.02)
        elif isinstance(module, ignore_types):
            pass
        else:
            raise RuntimeError(f'Unaccounted module {module}')

    def get_optim_groups(self, weight_decay: float = 1e-3):
        """Split parameters into a weight-decayed and a non-decayed group.

        Weights of ``Linear`` / ``MultiheadAttention`` modules are decayed;
        biases, ``LayerNorm`` / ``Embedding`` weights, the positional tables
        and ``_dummy_variable`` are not.

        Returns:
            A list of two optimizer parameter-group dicts.
        """
        decay = set()
        no_decay = set()
        whitelist_weight_modules = (torch.nn.Linear, torch.nn.MultiheadAttention)
        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
        for mn, m in self.named_modules():
            for pn, _ in m.named_parameters():
                fpn = '%s.%s' % (mn, pn) if mn else pn

                if pn.endswith('bias'):
                    no_decay.add(fpn)
                elif pn.startswith('bias'):
                    # e.g. MultiheadAttention's bias_k / bias_v
                    no_decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                    decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                    no_decay.add(fpn)

        # Bare parameters owned directly by this module.
        no_decay.add('pos_emb')
        no_decay.add('_dummy_variable')
        if self.cond_pos_emb is not None:
            no_decay.add('cond_pos_emb')

        param_dict = {pn: p for pn, p in self.named_parameters()}
        inter_params = decay & no_decay
        union_params = decay | no_decay
        assert len(inter_params) == 0, f'parameters {inter_params} made it into both decay/no_decay sets!'
        assert len(param_dict.keys() - union_params) == 0, (
            f'parameters {param_dict.keys() - union_params} were not separated into either decay/no_decay sets!'
        )

        optim_groups = [
            {
                'params': [param_dict[pn] for pn in sorted(list(decay))],
                'weight_decay': weight_decay,
            },
            {
                'params': [param_dict[pn] for pn in sorted(list(no_decay))],
                'weight_decay': 0.0,
            },
        ]
        return optim_groups

    def configure_optimizers(
        self,
        learning_rate: float = 1e-4,
        weight_decay: float = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.95),
    ):
        """Return an ``AdamW`` optimizer over the groups of ``get_optim_groups``."""
        optim_groups = self.get_optim_groups(weight_decay=weight_decay)
        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas)
        return optimizer

    def _prepare_time_input(self, value: Union[torch.Tensor, float, int], sample: torch.Tensor) -> torch.Tensor:
        """Convert a time value to a 1-D tensor expanded to the batch size.

        Python scalars and 0-dim tensors become one-element tensors first;
        the result takes ``sample``'s device and dtype and is expanded to
        ``sample.shape[0]`` entries.
        """
        if not torch.is_tensor(value):
            value = torch.tensor([value], dtype=sample.dtype, device=sample.device)
        elif value.ndim == 0:
            value = value[None].to(device=sample.device, dtype=sample.dtype)
        else:
            value = value.to(device=sample.device, dtype=sample.dtype)
        return value.expand(sample.shape[0])

    def forward(
        self,
        sample: torch.Tensor,
        r: Union[torch.Tensor, float, int],
        t: Union[torch.Tensor, float, int],
        cond: Optional[torch.Tensor] = None,
    ):
        """Run the network.

        Args:
            sample: input tokens; presumably (batch, tokens, input_dim), as it
                is fed to ``input_emb`` and concatenated along dim 1.
            r: first time value (scalar, 0-dim tensor or per-batch tensor).
            t: second time value, same accepted forms as ``r``.
            cond: observation features, required when observation
                conditioning is enabled; presumably (batch, n_obs_steps,
                cond_dim) — TODO confirm against caller.

        Returns:
            Tensor produced by the output head, one ``output_dim`` vector per
            input token.
        """
        r = self._prepare_time_input(r, sample)
        t = self._prepare_time_input(t, sample)
        # One token per time value: (batch, 1, n_emb) after unsqueeze.
        r_emb = self.time_emb(r).unsqueeze(1)
        t_emb = self.time_emb(t).unsqueeze(1)
        input_emb = self.input_emb(sample)

        if self.encoder_only:
            token_embeddings = torch.cat([r_emb, t_emb, input_emb], dim=1)
            token_count = token_embeddings.shape[1]
            position_embeddings = self.pos_emb[:, :token_count, :]
            x = self.drop(token_embeddings + position_embeddings)
            x = self.encoder(src=x, mask=self.mask)
            # Drop the two leading time tokens.
            x = x[:, 2:, :]
        else:
            cond_embeddings = torch.cat([r_emb, t_emb], dim=1)
            if self.obs_as_cond:
                cond_obs_emb = self.cond_obs_emb(cond)
                cond_embeddings = torch.cat([cond_embeddings, cond_obs_emb], dim=1)
            token_count = cond_embeddings.shape[1]
            position_embeddings = self.cond_pos_emb[:, :token_count, :]
            x = self.drop(cond_embeddings + position_embeddings)
            x = self.encoder(x)
            memory = x

            token_embeddings = input_emb
            token_count = token_embeddings.shape[1]
            position_embeddings = self.pos_emb[:, :token_count, :]
            x = self.drop(token_embeddings + position_embeddings)
            x = self.decoder(
                tgt=x,
                memory=memory,
                tgt_mask=self.mask,
                memory_mask=self.memory_mask,
            )

        x = self.ln_f(x)
        x = self.head(x)
        return x

View File

@@ -0,0 +1,302 @@
from typing import Optional, Tuple, Union
import logging
import torch
import torch.nn as nn
from diffusion_policy.model.common.module_attr_mixin import ModuleAttrMixin
from diffusion_policy.model.diffusion.positional_embedding import SinusoidalPosEmb
logger = logging.getLogger(__name__)
class PMFTransformerForDiffusion(ModuleAttrMixin):
    """Dual-head transformer decoder conditioned on two times ``t`` and ``r``.

    A shared decoder trunk is followed by two separate decoder stacks
    (``u`` and ``v``), each with its own LayerNorm and linear head.  The
    memory the decoders cross-attend to is made of ``n_time_tokens`` learned
    tokens per time value (each shifted by that time's sinusoidal embedding)
    and, when ``cond_dim > 0``, embedded observation features.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        horizon: int,
        n_obs_steps: Optional[int] = None,
        cond_dim: int = 0,
        n_layer: int = 12,
        n_head: int = 12,
        n_emb: int = 768,
        p_drop_emb: float = 0.1,
        p_drop_attn: float = 0.1,
        causal_attn: bool = False,
        obs_as_cond: bool = False,
        n_cond_layers: int = 0,
        n_time_tokens: int = 4,
        n_head_layers: int = 4,
    ) -> None:
        """Construct embeddings, conditioning encoder, decoders and heads.

        ``n_layer`` is the total depth: ``n_layer - n_head_layers`` layers go
        to the shared trunk and ``n_head_layers`` to each of the two heads.
        The ``obs_as_cond`` argument is replaced by ``cond_dim > 0``.

        Raises:
            ValueError: if ``n_time_tokens < 1``, ``n_head_layers < 0`` or
                ``n_head_layers >= n_layer``.
        """
        super().__init__()

        n_obs_steps = horizon if n_obs_steps is None else n_obs_steps

        if n_time_tokens < 1:
            raise ValueError("n_time_tokens must be >= 1")
        if n_head_layers < 0:
            raise ValueError("n_head_layers must be >= 0")
        if n_head_layers >= n_layer:
            raise ValueError(
                "n_head_layers must be smaller than n_layer so shared trunk depth stays positive"
            )

        # Observation conditioning is decided purely by cond_dim.
        obs_as_cond = cond_dim > 0

        T = horizon
        n_global_cond_tokens = 2 * n_time_tokens  # t tokens followed by r tokens
        T_cond = n_global_cond_tokens
        if obs_as_cond:
            T_cond += n_obs_steps
        n_shared_layers = n_layer - n_head_layers

        # --- token / time / observation embeddings ---
        self.input_emb = nn.Linear(input_dim, n_emb)
        self.pos_emb = nn.Parameter(torch.zeros(1, T, n_emb))
        self.drop = nn.Dropout(p_drop_emb)
        self.t_emb = SinusoidalPosEmb(n_emb)
        self.r_emb = SinusoidalPosEmb(n_emb)
        self.t_tokens = nn.Parameter(torch.zeros(1, n_time_tokens, n_emb))
        self.r_tokens = nn.Parameter(torch.zeros(1, n_time_tokens, n_emb))
        if obs_as_cond:
            self.cond_obs_emb = nn.Linear(cond_dim, n_emb)
        else:
            self.cond_obs_emb = None
        self.cond_pos_emb = nn.Parameter(torch.zeros(1, T_cond, n_emb))

        # Settings shared by every transformer layer created below.
        layer_kwargs = dict(
            d_model=n_emb,
            nhead=n_head,
            dim_feedforward=4 * n_emb,
            dropout=p_drop_attn,
            activation="gelu",
            batch_first=True,
            norm_first=True,
        )

        # --- conditioning encoder: transformer stack or a small MLP ---
        if n_cond_layers > 0:
            self.encoder = nn.TransformerEncoder(
                encoder_layer=nn.TransformerEncoderLayer(**layer_kwargs),
                num_layers=n_cond_layers,
            )
        else:
            self.encoder = nn.Sequential(
                nn.Linear(n_emb, 4 * n_emb),
                nn.Mish(),
                nn.Linear(4 * n_emb, n_emb),
            )

        # --- shared trunk and the two heads, all from one layer template ---
        layer_template = nn.TransformerDecoderLayer(**layer_kwargs)
        self.shared_decoder = nn.TransformerDecoder(
            decoder_layer=layer_template,
            num_layers=n_shared_layers,
        )
        self.u_decoder = nn.TransformerDecoder(
            decoder_layer=layer_template,
            num_layers=n_head_layers,
        )
        self.v_decoder = nn.TransformerDecoder(
            decoder_layer=layer_template,
            num_layers=n_head_layers,
        )

        # --- attention masks (additive: 0.0 = visible, -inf = hidden) ---
        if not causal_attn:
            self.mask = None
            self.memory_mask = None
        else:
            sz = T
            causal = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
            causal = causal.float().masked_fill(causal == 0, float("-inf")).masked_fill(causal == 1, float(0.0))
            self.register_buffer("mask", causal)

            if not obs_as_cond:
                self.memory_mask = None
            else:
                q_idx, c_idx = torch.meshgrid(
                    torch.arange(T),
                    torch.arange(T_cond),
                    indexing="ij",
                )
                obs_offset = n_global_cond_tokens
                # Time tokens are visible everywhere; observation slot k is
                # visible to query positions >= k.
                visible = (c_idx < obs_offset) | (q_idx >= (c_idx - obs_offset))
                memory_mask = visible.float().masked_fill(~visible, float("-inf")).masked_fill(visible, float(0.0))
                self.register_buffer("memory_mask", memory_mask)

        # --- per-head normalisation and projection ---
        self.ln_u = nn.LayerNorm(n_emb)
        self.ln_v = nn.LayerNorm(n_emb)
        self.head_u = nn.Linear(n_emb, output_dim)
        self.head_v = nn.Linear(n_emb, output_dim)

        # --- bookkeeping ---
        self.T = T
        self.T_cond = T_cond
        self.horizon = horizon
        self.n_obs_steps = n_obs_steps
        self.obs_as_cond = obs_as_cond
        self.n_global_cond_tokens = n_global_cond_tokens
        self.n_time_tokens = n_time_tokens
        self.n_layer = n_layer
        self.n_head_layers = n_head_layers
        self.n_shared_layers = n_shared_layers

        self.apply(self._init_weights)

        n_params = sum(p.numel() for p in self.parameters())
        logger.info("number of parameters: %e", n_params)
        logger.info(
            "PMFTransformerForDiffusion layers: shared=%d u_head=%d v_head=%d",
            self.n_shared_layers,
            self.n_head_layers,
            self.n_head_layers,
        )

    def _init_weights(self, module):
        """Per-module initialiser passed to ``self.apply``.

        Raises:
            RuntimeError: for module types not handled here.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                torch.nn.init.zeros_(module.bias)
            return

        if isinstance(module, nn.MultiheadAttention):
            for attr in ("in_proj_weight", "q_proj_weight", "k_proj_weight", "v_proj_weight"):
                tensor = getattr(module, attr)
                if tensor is not None:
                    torch.nn.init.normal_(tensor, mean=0.0, std=0.02)
            for attr in ("in_proj_bias", "bias_k", "bias_v"):
                tensor = getattr(module, attr)
                if tensor is not None:
                    torch.nn.init.zeros_(tensor)
            return

        if isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)
            return

        if isinstance(module, PMFTransformerForDiffusion):
            # Bare parameters owned by the top-level module.
            torch.nn.init.normal_(module.pos_emb, mean=0.0, std=0.02)
            torch.nn.init.normal_(module.cond_pos_emb, mean=0.0, std=0.02)
            torch.nn.init.normal_(module.t_tokens, mean=0.0, std=0.02)
            torch.nn.init.normal_(module.r_tokens, mean=0.0, std=0.02)
            return

        # Containers and parameter-free modules need no initialisation.
        passthrough_types = (
            nn.Dropout,
            SinusoidalPosEmb,
            nn.TransformerEncoderLayer,
            nn.TransformerDecoderLayer,
            nn.TransformerEncoder,
            nn.TransformerDecoder,
            nn.ModuleList,
            nn.Mish,
            nn.Sequential,
        )
        if not isinstance(module, passthrough_types):
            raise RuntimeError("Unaccounted module {}".format(module))

    def get_optim_groups(self, weight_decay: float = 1e-3):
        """Return two optimizer groups: decayed weights and everything else.

        ``Linear`` / ``MultiheadAttention`` weights receive ``weight_decay``;
        biases, ``LayerNorm`` / ``Embedding`` weights, positional tables, the
        learned time tokens and ``_dummy_variable`` receive none.
        """
        decayed_modules = (torch.nn.Linear, torch.nn.MultiheadAttention)
        undecayed_modules = (torch.nn.LayerNorm, torch.nn.Embedding)

        decay = set()
        no_decay = set()
        for module_name, module in self.named_modules():
            for param_name, _ in module.named_parameters():
                full_name = "%s.%s" % (module_name, param_name) if module_name else param_name
                if param_name.endswith("bias") or param_name.startswith("bias"):
                    no_decay.add(full_name)
                elif param_name.endswith("weight"):
                    if isinstance(module, decayed_modules):
                        decay.add(full_name)
                    elif isinstance(module, undecayed_modules):
                        no_decay.add(full_name)

        # Parameters registered directly on this module.
        for bare_name in ("pos_emb", "cond_pos_emb", "t_tokens", "r_tokens", "_dummy_variable"):
            no_decay.add(bare_name)

        param_dict = dict(self.named_parameters())
        inter_params = decay & no_decay
        union_params = decay | no_decay
        assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params),)
        assert len(param_dict.keys() - union_params) == 0, (
            "parameters %s were not separated into either decay/no_decay set!" % (str(param_dict.keys() - union_params),)
        )

        decayed_group = {
            "params": [param_dict[name] for name in sorted(decay)],
            "weight_decay": weight_decay,
        }
        undecayed_group = {
            "params": [param_dict[name] for name in sorted(no_decay)],
            "weight_decay": 0.0,
        }
        return [decayed_group, undecayed_group]

    def configure_optimizers(
        self,
        learning_rate: float = 1e-4,
        weight_decay: float = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.95),
    ):
        """Create an ``AdamW`` optimizer over ``get_optim_groups``."""
        groups = self.get_optim_groups(weight_decay=weight_decay)
        return torch.optim.AdamW(groups, lr=learning_rate, betas=betas)

    def _broadcast_time(self, value: Union[torch.Tensor, float, int], batch_size: int, device: torch.device):
        """Return ``value`` as a float32 tensor expanded to ``batch_size``.

        Python scalars and 0-dim tensors are lifted to one element first.
        """
        if torch.is_tensor(value):
            if value.ndim == 0:
                value = value[None]
            value = value.to(device=device, dtype=torch.float32)
        else:
            value = torch.tensor([value], dtype=torch.float32, device=device)
        return value.expand(batch_size)

    def forward(
        self,
        sample: torch.Tensor,
        t: Union[torch.Tensor, float, int],
        r: Union[torch.Tensor, float, int],
        cond: Optional[torch.Tensor] = None,
    ):
        """Compute both head outputs.

        Args:
            sample: input tokens; presumably (batch, tokens, input_dim) since
                it goes through ``input_emb`` and is sliced against
                ``pos_emb`` along dim 1.
            t: time value used for the ``t`` tokens.
            r: time value used for the ``r`` tokens.
            cond: observation features, used only when observation
                conditioning is enabled.

        Returns:
            Tuple ``(u, v)`` — the outputs of the ``u`` head and the ``v``
            head, in that order.
        """
        batch_size, device = sample.shape[0], sample.device
        t = self._broadcast_time(t, batch_size, device)
        r = self._broadcast_time(r, batch_size, device)

        input_emb = self.input_emb(sample)

        # Memory layout: [t tokens | r tokens | optional observation tokens].
        cond_parts = [
            self.t_tokens + self.t_emb(t).unsqueeze(1),
            self.r_tokens + self.r_emb(r).unsqueeze(1),
        ]
        if self.obs_as_cond:
            cond_parts.append(self.cond_obs_emb(cond))
        cond_tokens = torch.cat(cond_parts, dim=1)

        n_cond = cond_tokens.shape[1]
        memory = self.drop(cond_tokens + self.cond_pos_emb[:, :n_cond, :])
        memory = self.encoder(memory)

        n_tokens = input_emb.shape[1]
        x = self.drop(input_emb + self.pos_emb[:, :n_tokens, :])

        # All three decoders see the same memory and masks.
        attend = dict(
            memory=memory,
            tgt_mask=self.mask,
            memory_mask=self.memory_mask,
        )
        trunk_out = self.shared_decoder(tgt=x, **attend)
        u_hidden = self.u_decoder(tgt=trunk_out, **attend)
        v_hidden = self.v_decoder(tgt=trunk_out, **attend)

        u_out = self.head_u(self.ln_u(u_hidden))
        v_out = self.head_v(self.ln_v(v_hidden))
        return u_out, v_out

View File

@@ -1,273 +0,0 @@
from contextlib import nullcontext
from typing import Dict, Optional
import torch
import torch.nn.functional as F
from einops import reduce
from diffusion_policy.common.pytorch_util import dict_apply
from diffusion_policy.model.diffusion.imf_transformer_for_diffusion import IMFTransformerForDiffusion
from diffusion_policy.policy.diffusion_transformer_hybrid_image_policy import (
DiffusionTransformerHybridImagePolicy,
)
try:
from torch.func import jvp as TORCH_FUNC_JVP
except ImportError: # pragma: no cover - depends on torch version
TORCH_FUNC_JVP = None
class IMFTransformerHybridImagePolicy(DiffusionTransformerHybridImagePolicy):
    """Image policy that swaps the parent's model for ``IMFTransformerForDiffusion``.

    Training (``compute_loss``) regresses a compound velocity built from the
    network output and its Jacobian-vector product; sampling
    (``conditional_sample``) performs exactly one network evaluation.
    """

    def __init__(
        self,
        shape_meta: dict,
        noise_scheduler,
        horizon,
        n_action_steps,
        n_obs_steps,
        num_inference_steps=None,
        crop_shape=(76, 76),
        obs_encoder_group_norm=False,
        eval_fixed_crop=False,
        n_layer=8,
        n_cond_layers=0,
        n_head=1,
        n_emb=256,
        p_drop_emb=0.0,
        p_drop_attn=0.3,
        causal_attn=True,
        time_as_cond=True,
        obs_as_cond=True,
        pred_action_steps_only=False,
        **kwargs,
    ):
        """Initialise the parent policy, then replace ``self.model``.

        Raises:
            ValueError: if ``num_inference_steps`` is given and is not 1.
        """
        if num_inference_steps is None:
            num_inference_steps = 1
        elif num_inference_steps != 1:
            raise ValueError(
                'IMFTransformerHybridImagePolicy only supports one-step inference; '
                f'num_inference_steps must be 1, got {num_inference_steps}.'
            )

        super().__init__(
            shape_meta=shape_meta,
            noise_scheduler=noise_scheduler,
            horizon=horizon,
            n_action_steps=n_action_steps,
            n_obs_steps=n_obs_steps,
            num_inference_steps=num_inference_steps,
            crop_shape=crop_shape,
            obs_encoder_group_norm=obs_encoder_group_norm,
            eval_fixed_crop=eval_fixed_crop,
            n_layer=n_layer,
            n_cond_layers=n_cond_layers,
            n_head=n_head,
            n_emb=n_emb,
            p_drop_emb=p_drop_emb,
            p_drop_attn=p_drop_attn,
            causal_attn=causal_attn,
            time_as_cond=time_as_cond,
            obs_as_cond=obs_as_cond,
            pred_action_steps_only=pred_action_steps_only,
            **kwargs,
        )

        # Attributes read below are presumably set by the parent constructor.
        input_dim = self.action_dim if self.obs_as_cond else (self.obs_feature_dim + self.action_dim)
        cond_dim = self.obs_feature_dim if self.obs_as_cond else 0
        model_horizon = self.n_action_steps if self.pred_action_steps_only else horizon

        self.model = IMFTransformerForDiffusion(
            input_dim=input_dim,
            output_dim=input_dim,
            horizon=model_horizon,
            n_obs_steps=n_obs_steps,
            cond_dim=cond_dim,
            n_layer=n_layer,
            n_head=n_head,
            n_emb=n_emb,
            p_drop_emb=p_drop_emb,
            p_drop_attn=p_drop_attn,
            causal_attn=causal_attn,
            time_as_cond=time_as_cond,
            obs_as_cond=obs_as_cond,
            n_cond_layers=n_cond_layers,
        )
        self.num_inference_steps = 1

    def fn(self, z: torch.Tensor, r: torch.Tensor, t: torch.Tensor, cond=None) -> torch.Tensor:
        """Evaluate the model at ``(z, r, t)`` with optional conditioning."""
        return self.model(z, r, t, cond=cond)

    @staticmethod
    def _broadcast_batch_time(value: torch.Tensor, reference: torch.Tensor) -> torch.Tensor:
        """Append trailing singleton dims to ``value`` until it has ``reference.ndim`` dims."""
        while value.ndim < reference.ndim:
            value = value.unsqueeze(-1)
        return value

    @staticmethod
    def _apply_conditioning(
        trajectory: torch.Tensor,
        condition_data: Optional[torch.Tensor] = None,
        condition_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return a copy of ``trajectory`` with masked entries taken from ``condition_data``.

        When either ``condition_data`` or ``condition_mask`` is ``None`` the
        input tensor itself is returned unchanged (no copy).
        """
        if condition_data is None or condition_mask is None:
            return trajectory
        conditioned = trajectory.clone()
        conditioned[condition_mask] = condition_data[condition_mask]
        return conditioned

    @staticmethod
    def _jvp_math_sdp_context(z_t: torch.Tensor):
        """Return a context manager restricting attention to the math backend.

        For CUDA inputs only the math scaled-dot-product-attention backend is
        left enabled; for non-CUDA inputs a no-op context is returned.
        """
        if not z_t.is_cuda:
            return nullcontext()
        try:
            # Preferred API: torch.backends.cuda.sdp_kernel is deprecated in
            # its favour and warns on every use on recent torch releases.
            from torch.nn.attention import SDPBackend, sdpa_kernel
        except ImportError:
            # Older torch without torch.nn.attention: use the legacy context
            # manager with only the flags those releases are known to accept.
            return torch.backends.cuda.sdp_kernel(
                enable_flash=False,
                enable_math=True,
                enable_mem_efficient=False,
            )
        return sdpa_kernel([SDPBackend.MATH])

    @staticmethod
    def _jvp_tangents(v: torch.Tensor, r: torch.Tensor, t: torch.Tensor):
        """Tangent direction for the JVP: ``(v detached, 0 for r, 1 for t)``."""
        return v.detach(), torch.zeros_like(r), torch.ones_like(t)

    def _compute_u_and_du_dt(
        self,
        z_t: torch.Tensor,
        r: torch.Tensor,
        t: torch.Tensor,
        cond,
        v: torch.Tensor,
        condition_data: Optional[torch.Tensor] = None,
        condition_mask: Optional[torch.Tensor] = None,
    ):
        """Evaluate the model and its JVP along ``_jvp_tangents(v, r, t)``.

        Returns:
            Tuple ``(u, du_dt)``: the model output at ``(z_t, r, t)`` (with
            conditioning re-applied to ``z``) and the Jacobian-vector product
            of that output along the tangents.
        """
        tangents = self._jvp_tangents(v, r, t)

        def g(z, r_value, t_value):
            conditioned_z = self._apply_conditioning(z, condition_data, condition_mask)
            return self.fn(conditioned_z, r_value, t_value, cond=cond)

        with self._jvp_math_sdp_context(z_t):
            if TORCH_FUNC_JVP is not None:
                try:
                    return TORCH_FUNC_JVP(g, (z_t, r, t), tangents)
                except (RuntimeError, TypeError, NotImplementedError):
                    # Deliberate best-effort: if torch.func.jvp cannot handle
                    # this graph, fall through to the autograd-based JVP.
                    pass

            u = g(z_t, r, t)
            _, du_dt = torch.autograd.functional.jvp(
                g,
                (z_t, r, t),
                tangents,
                create_graph=False,
                strict=False,
            )
            return u, du_dt

    def _compound_velocity(
        self,
        u: torch.Tensor,
        du_dt: torch.Tensor,
        r: torch.Tensor,
        t: torch.Tensor,
    ) -> torch.Tensor:
        """Return ``u + (t - r) * du_dt`` with ``du_dt`` detached from the graph."""
        delta = self._broadcast_batch_time(t - r, u)
        return u + delta * du_dt.detach()

    def _sample_one_step(
        self,
        z_t: torch.Tensor,
        r: Optional[torch.Tensor] = None,
        t: Optional[torch.Tensor] = None,
        cond=None,
    ) -> torch.Tensor:
        """Take a single step ``z_t - (t - r) * model(z_t, r, t)``.

        ``t`` defaults to all ones and ``r`` to all zeros (one value per
        batch element, matching ``z_t``'s device and dtype).
        """
        batch_size = z_t.shape[0]
        if t is None:
            t = torch.ones(batch_size, device=z_t.device, dtype=z_t.dtype)
        if r is None:
            r = torch.zeros(batch_size, device=z_t.device, dtype=z_t.dtype)

        u = self.fn(z_t, r, t, cond=cond)
        delta = self._broadcast_batch_time(t - r, z_t)
        return z_t - delta * u

    def conditional_sample(
        self,
        condition_data,
        condition_mask,
        cond=None,
        generator=None,
        **kwargs,
    ):
        """Draw Gaussian noise, apply conditioning, and take one sampling step.

        Conditioning is re-applied after the step so masked entries equal
        ``condition_data`` in the returned trajectory.  Extra ``kwargs`` are
        accepted and ignored.
        """
        trajectory = torch.randn(
            size=condition_data.shape,
            dtype=condition_data.dtype,
            device=condition_data.device,
            generator=generator,
        )
        trajectory = self._apply_conditioning(trajectory, condition_data, condition_mask)
        trajectory = self._sample_one_step(trajectory, cond=cond)
        trajectory = self._apply_conditioning(trajectory, condition_data, condition_mask)
        return trajectory

    def compute_loss(self, batch):
        """Compute the training loss for one batch.

        Args:
            batch: mapping with ``'obs'`` and ``'action'`` entries; must not
                contain ``'valid_mask'`` (asserted).

        Returns:
            Scalar loss tensor: masked mean-squared error between the
            compound velocity and ``noise - trajectory``.
        """
        assert 'valid_mask' not in batch
        nobs = self.normalizer.normalize(batch['obs'])
        nactions = self.normalizer['action'].normalize(batch['action'])
        batch_size = nactions.shape[0]
        horizon = nactions.shape[1]
        To = self.n_obs_steps

        cond = None
        trajectory = nactions
        if self.obs_as_cond:
            # Encode only the first To observation steps as conditioning.
            this_nobs = dict_apply(
                nobs,
                lambda x: x[:, :To, ...].reshape(-1, *x.shape[2:]),
            )
            nobs_features = self.obs_encoder(this_nobs)
            cond = nobs_features.reshape(batch_size, To, -1)
            if self.pred_action_steps_only:
                start = To - 1
                end = start + self.n_action_steps
                trajectory = nactions[:, start:end]
        else:
            # Observation features are appended to the actions and modelled jointly.
            this_nobs = dict_apply(nobs, lambda x: x.reshape(-1, *x.shape[2:]))
            nobs_features = self.obs_encoder(this_nobs)
            nobs_features = nobs_features.reshape(batch_size, horizon, -1)
            trajectory = torch.cat([nactions, nobs_features], dim=-1).detach()

        if self.pred_action_steps_only:
            condition_mask = torch.zeros_like(trajectory, dtype=torch.bool)
        else:
            condition_mask = self.mask_generator(trajectory.shape)

        # Only unconditioned action dimensions contribute to the loss.
        loss_mask = torch.zeros_like(trajectory, dtype=torch.bool)
        loss_mask[..., : self.action_dim] = True
        loss_mask = loss_mask & ~condition_mask

        x = trajectory
        e = torch.randn_like(x)
        # Two uniform times per sample, ordered so that t >= r.
        t = torch.rand(batch_size, device=x.device, dtype=x.dtype)
        r = torch.rand(batch_size, device=x.device, dtype=x.dtype)
        t, r = torch.maximum(t, r), torch.minimum(t, r)

        # Linear interpolation between data (t = 0) and noise (t = 1).
        t_broadcast = self._broadcast_batch_time(t, x)
        z_t = (1 - t_broadcast) * x + t_broadcast * e
        z_t = self._apply_conditioning(z_t, x, condition_mask)

        # Model evaluated with r == t provides the tangent for z in the JVP.
        v = self.fn(z_t, t, t, cond=cond)
        u, du_dt = self._compute_u_and_du_dt(
            z_t,
            r,
            t,
            cond=cond,
            v=v,
            condition_data=x,
            condition_mask=condition_mask,
        )
        V = self._compound_velocity(u, du_dt, r, t)
        target = e - x

        loss = F.mse_loss(V, target, reduction='none')
        loss = loss * loss_mask.type(loss.dtype)
        loss = reduce(loss, 'b ... -> b (...)', 'mean')
        loss = loss.mean()
        return loss

View File

@@ -0,0 +1,455 @@
from typing import Dict, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import reduce
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
import diffusion_policy.model.vision.crop_randomizer as dmvc
import robomimic.models.base_nets as rmbn
import robomimic.utils.obs_utils as ObsUtils
from diffusion_policy.common.pytorch_util import dict_apply, replace_submodules
from diffusion_policy.common.robomimic_config_util import get_robomimic_config
from diffusion_policy.model.common.normalizer import LinearNormalizer
from diffusion_policy.model.diffusion.mask_generator import LowdimMaskGenerator
from diffusion_policy.model.diffusion.pmf_transformer_for_diffusion import (
PMFTransformerForDiffusion,
)
from diffusion_policy.policy.base_image_policy import BaseImagePolicy
from robomimic.algo import algo_factory
from robomimic.algo.algo import PolicyAlgo
class PMFTransformerHybridImagePolicy(BaseImagePolicy):
    """Image policy built from a robomimic observation encoder and a
    ``PMFTransformerForDiffusion`` model.

    ``compute_loss`` trains two velocity predictions (``u`` and ``v``) derived
    from the model's two outputs; ``predict_action`` samples a trajectory by
    stepping ``u`` from time 1.0 down to 0.0 in ``num_inference_steps`` steps.
    """
def __init__(
    self,
    shape_meta: dict,
    noise_scheduler: DDPMScheduler,
    horizon,
    n_action_steps,
    n_obs_steps,
    num_inference_steps=None,
    crop_shape=(76, 76),
    obs_encoder_group_norm=False,
    eval_fixed_crop=False,
    n_layer=8,
    n_cond_layers=0,
    n_head=4,
    n_emb=256,
    p_drop_emb=0.0,
    p_drop_attn=0.0,
    causal_attn=True,
    obs_as_cond=True,
    pred_action_steps_only=False,
    n_time_tokens=4,
    n_head_layers=4,
    min_time=0.05,
    du_dt_epsilon=1.0e-3,
    pmf_u_loss_weight=1.0,
    pmf_v_loss_weight=1.0,
    noise_scale=1.0,
    adatloss_eps=0.01,
    p_mean=-0.4,
    p_std=1.0,
    tr_uniform=True,
    tr_uniform_prob=0.1,
    data_proportion=0.5,
    **kwargs,
):
    """Build the observation encoder, the transformer and the mask generator.

    Args:
        shape_meta: dict with an ``"action"`` entry (1-D ``"shape"``) and an
            ``"obs"`` entry mapping each observation key to a dict holding
            ``"shape"`` and an optional ``"type"`` (``"rgb"`` or
            ``"low_dim"``; defaults to ``"low_dim"``).
        noise_scheduler: stored as-is; only consulted here for
            ``config.num_train_timesteps`` when ``num_inference_steps`` is
            ``None``.
        horizon, n_action_steps, n_obs_steps: sequence lengths stored on the
            instance and forwarded to the transformer / mask generator.
        num_inference_steps: number of sampling steps; ``None`` falls back to
            the scheduler's ``num_train_timesteps``.
        crop_shape: ``(height, width)`` for the robomimic ``CropRandomizer``,
            or ``None`` to disable it.
        obs_encoder_group_norm: replace every ``BatchNorm2d`` in the encoder
            with ``GroupNorm``.
        eval_fixed_crop: replace robomimic ``CropRandomizer`` modules with
            the project's ``dmvc.CropRandomizer``.
        obs_as_cond: if true, observation features are passed to the model as
            ``cond``; otherwise they are concatenated to the action channels.
        pred_action_steps_only: if true, the model horizon is
            ``n_action_steps`` instead of ``horizon``.
        **kwargs: kept in ``self.kwargs`` and forwarded to
            ``conditional_sample`` by ``predict_action``.

        The remaining arguments are either forwarded unchanged to
        ``PMFTransformerForDiffusion`` or stored verbatim as attributes of
        the same name.

    Raises:
        RuntimeError: if an observation declares a type other than ``"rgb"``
            or ``"low_dim"``.
    """
    super().__init__()
    # Only flat (1-D) action vectors are supported.
    action_shape = shape_meta["action"]["shape"]
    assert len(action_shape) == 1
    action_dim = action_shape[0]
    obs_shape_meta = shape_meta["obs"]
    # Observation keys grouped by modality, in the layout robomimic expects.
    obs_config = {
        "low_dim": [],
        "rgb": [],
        "depth": [],
        "scan": [],
    }
    obs_key_shapes = dict()
    for key, attr in obs_shape_meta.items():
        shape = attr["shape"]
        obs_key_shapes[key] = list(shape)
        obs_type = attr.get("type", "low_dim")
        if obs_type == "rgb":
            obs_config["rgb"].append(key)
        elif obs_type == "low_dim":
            obs_config["low_dim"].append(key)
        else:
            raise RuntimeError(f"Unsupported obs type: {obs_type}")
    # A stock robomimic config is used only as a template for the encoder.
    config = get_robomimic_config(
        algo_name="bc_rnn",
        hdf5_type="image",
        task_name="square",
        dataset_type="ph",
    )
    with config.unlocked():
        config.observation.modalities.obs = obs_config
        if crop_shape is None:
            # Disable random cropping for every modality that had it enabled.
            for _, modality in config.observation.encoder.items():
                if modality.obs_randomizer_class == "CropRandomizer":
                    modality["obs_randomizer_class"] = None
        else:
            crop_h, crop_w = crop_shape
            for _, modality in config.observation.encoder.items():
                if modality.obs_randomizer_class == "CropRandomizer":
                    modality.obs_randomizer_kwargs.crop_height = crop_h
                    modality.obs_randomizer_kwargs.crop_width = crop_w
    ObsUtils.initialize_obs_utils_with_config(config)
    # Instantiate the full robomimic policy on CPU; only its observation
    # encoder is extracted and kept below.
    policy: PolicyAlgo = algo_factory(
        algo_name=config.algo_name,
        config=config,
        obs_key_shapes=obs_key_shapes,
        ac_dim=action_dim,
        device="cpu",
    )
    obs_encoder = policy.nets["policy"].nets["encoder"].nets["obs"]
    if obs_encoder_group_norm:
        # One group per 16 channels.
        # NOTE(review): yields num_groups == 0 when num_features < 16 — confirm
        # the encoder never has such a layer.
        replace_submodules(
            root_module=obs_encoder,
            predicate=lambda x: isinstance(x, nn.BatchNorm2d),
            func=lambda x: nn.GroupNorm(
                num_groups=x.num_features // 16,
                num_channels=x.num_features,
            ),
        )
    if eval_fixed_crop:
        # Swap robomimic's randomizer for the project's variant, copying the
        # original module's crop settings.
        replace_submodules(
            root_module=obs_encoder,
            predicate=lambda x: isinstance(x, rmbn.CropRandomizer),
            func=lambda x: dmvc.CropRandomizer(
                input_shape=x.input_shape,
                crop_height=x.crop_height,
                crop_width=x.crop_width,
                num_crops=x.num_crops,
                pos_enc=x.pos_enc,
            ),
        )
    obs_feature_dim = obs_encoder.output_shape()[0]
    # With obs_as_cond the model sees actions only and receives features via
    # `cond`; otherwise features are extra channels of the trajectory.
    input_dim = action_dim if obs_as_cond else (obs_feature_dim + action_dim)
    cond_dim = obs_feature_dim if obs_as_cond else 0
    self.obs_encoder = obs_encoder
    self.model = PMFTransformerForDiffusion(
        input_dim=input_dim,
        output_dim=input_dim,
        horizon=horizon if not pred_action_steps_only else n_action_steps,
        n_obs_steps=n_obs_steps,
        cond_dim=cond_dim,
        n_layer=n_layer,
        n_head=n_head,
        n_emb=n_emb,
        p_drop_emb=p_drop_emb,
        p_drop_attn=p_drop_attn,
        causal_attn=causal_attn,
        obs_as_cond=obs_as_cond,
        n_cond_layers=n_cond_layers,
        n_time_tokens=n_time_tokens,
        n_head_layers=n_head_layers,
    )
    self.noise_scheduler = noise_scheduler
    self.mask_generator = LowdimMaskGenerator(
        action_dim=action_dim,
        obs_dim=0 if obs_as_cond else obs_feature_dim,
        max_n_obs_steps=n_obs_steps,
        fix_obs_steps=True,
        action_visible=False,
    )
    self.normalizer = LinearNormalizer()
    self.horizon = horizon
    self.obs_feature_dim = obs_feature_dim
    self.action_dim = action_dim
    self.n_action_steps = n_action_steps
    self.n_obs_steps = n_obs_steps
    self.obs_as_cond = obs_as_cond
    self.pred_action_steps_only = pred_action_steps_only
    # Training / sampling hyper-parameters, stored verbatim.
    self.min_time = min_time
    self.du_dt_epsilon = du_dt_epsilon
    self.pmf_u_loss_weight = pmf_u_loss_weight
    self.pmf_v_loss_weight = pmf_v_loss_weight
    self.noise_scale = noise_scale
    self.adatloss_eps = adatloss_eps
    self.p_mean = p_mean
    self.p_std = p_std
    self.tr_uniform = tr_uniform
    self.tr_uniform_prob = tr_uniform_prob
    self.data_proportion = data_proportion
    self.kwargs = kwargs
    if num_inference_steps is None:
        num_inference_steps = noise_scheduler.config.num_train_timesteps
    self.num_inference_steps = num_inference_steps
def _encode_obs(self, nobs: Dict[str, torch.Tensor], n_steps: int) -> torch.Tensor:
    """Encode the first ``n_steps`` observation steps into per-step features.

    Every tensor in ``nobs`` is truncated to ``n_steps`` along dim 1, dims 0
    and 1 are merged, and the result is passed through ``self.obs_encoder``.

    Returns:
        The encoder output reshaped to ``(batch, n_steps, -1)``, where
        ``batch`` is dim 0 of the first tensor in ``nobs``.
    """
    def _truncate_and_flatten(x):
        return x[:, :n_steps, ...].reshape(-1, *x.shape[2:])

    batch_size = next(iter(nobs.values())).shape[0]
    features = self.obs_encoder(dict_apply(nobs, _truncate_and_flatten))
    return features.reshape(batch_size, n_steps, -1)
def _time_view(self, value: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
    """Reshape ``value`` to ``(value.shape[0], 1, ..., 1)`` with ``ref.ndim`` dims.

    The result broadcasts against ``ref`` along every dimension after the first.
    """
    trailing_ones = (1,) * (ref.ndim - 1)
    return value.reshape((value.shape[0],) + trailing_ones)
def _adatloss(self, loss: torch.Tensor) -> torch.Tensor:
    """Divide ``loss`` by its own detached value plus ``self.adatloss_eps``.

    The divisor carries no gradient, so only the numerator is differentiated.
    """
    return loss / (loss.detach() + self.adatloss_eps)
def _sample_logit_normal(self, batch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
    """Return ``batch_size`` samples of ``sigmoid(N(0, 1) * p_std + p_mean)``."""
    gaussian = torch.randn(batch_size, device=device, dtype=dtype)
    logits = gaussian * self.p_std + self.p_mean
    return logits.sigmoid()
def _sample_tr(self, batch_size: int, device: torch.device, dtype: torch.dtype):
    """Sample a pair of per-example times ``(t, r)`` with ``t >= r``.

    Both times start as logit-normal draws. When ``self.tr_uniform`` is set,
    each example is switched to uniform draws with probability
    ``self.tr_uniform_prob``. The first ``int(batch_size * data_proportion)``
    examples then get ``r`` overwritten with ``t``.

    Returns:
        ``(t, r)`` as the element-wise maximum and minimum of the two draws.
    """
    t = self._sample_logit_normal(batch_size, device, dtype)
    r = self._sample_logit_normal(batch_size, device, dtype)

    if self.tr_uniform:
        # Draw order (mask, t, r) is kept so seeded runs are reproducible.
        use_uniform = torch.rand(batch_size, device=device) < self.tr_uniform_prob
        t = torch.where(use_uniform, torch.rand(batch_size, device=device, dtype=dtype), t)
        r = torch.where(use_uniform, torch.rand(batch_size, device=device, dtype=dtype), r)

    # Leading examples collapse to r == t.
    n_equal = int(batch_size * self.data_proportion)
    sample_index = torch.arange(batch_size, device=device)
    r = torch.where(sample_index < n_equal, t, r)

    return torch.maximum(t, r), torch.minimum(t, r)
def _trajectory_inputs(
    self,
    nobs: Dict[str, torch.Tensor],
    nactions: torch.Tensor,
):
    """Assemble the training trajectory, conditioning and condition mask.

    With ``obs_as_cond`` the trajectory holds actions only (cropped to the
    action window when ``pred_action_steps_only``) and encoded observations
    are returned as ``cond``. Otherwise observation features for every step
    are concatenated after the action channels and the result is detached.

    Returns:
        ``(batch_size, trajectory, cond, condition_mask)``. The mask is
        all-False when ``pred_action_steps_only``; otherwise it comes from
        ``self.mask_generator``.
    """
    batch_size, horizon = nactions.shape[:2]

    if self.obs_as_cond:
        cond = self._encode_obs(nobs, self.n_obs_steps)
        trajectory = nactions
        if self.pred_action_steps_only:
            first = self.n_obs_steps - 1
            trajectory = nactions[:, first:first + self.n_action_steps]
    else:
        cond = None
        obs_features = self._encode_obs(nobs, horizon)
        trajectory = torch.cat((nactions, obs_features), dim=-1).detach()

    condition_mask = (
        torch.zeros_like(trajectory, dtype=torch.bool)
        if self.pred_action_steps_only
        else self.mask_generator(trajectory.shape)
    )
    return batch_size, trajectory, cond, condition_mask
def _apply_conditioning(
    self,
    sample: torch.Tensor,
    condition_data: torch.Tensor,
    condition_mask: torch.Tensor,
) -> torch.Tensor:
    """Overwrite masked entries of ``sample`` with ``condition_data``.

    When the mask has no True entry, ``sample`` itself is returned (no copy).
    """
    if condition_mask.any():
        return torch.where(condition_mask, condition_data, sample)
    return sample
def _compute_u_v(
    self,
    sample: torch.Tensor,
    t: torch.Tensor,
    r: torch.Tensor,
    cond: torch.Tensor,
):
    """Turn the model's two outputs into velocities ``u`` and ``v``.

    ``self.model(sample, t, r, cond)`` returns two tensors; each is converted
    with ``(sample - output) / t``, where ``t`` is broadcast over the
    non-batch dimensions.

    NOTE(review): ``t`` is used as the divisor without clamping, and
    ``self.min_time`` is not referenced here — confirm that callers never
    pass ``t == 0``.

    Returns:
        ``(u, v)`` with the same shape as ``sample``.
    """
    x_hat_u, x_hat_v = self.model(sample, t, r, cond)
    time_scale = self._time_view(t, sample)
    u_pred = (sample - x_hat_u) / time_scale
    v_pred = (sample - x_hat_v) / time_scale
    return u_pred, v_pred
def _compute_du_dt(
    self,
    sample: torch.Tensor,
    t: torch.Tensor,
    r: torch.Tensor,
    cond: torch.Tensor,
    condition_data: torch.Tensor,
    condition_mask: torch.Tensor,
    tangent_v: torch.Tensor,
) -> torch.Tensor:
    """Jacobian-vector product of ``u`` along ``(tangent_v, 0, 1)``.

    ``u`` is evaluated as a function of ``(sample, r, t)`` after conditioning
    is re-applied to the sample. The tangent is the detached ``tangent_v``
    for the sample, zeros for ``r`` and ones for ``t``.

    ``torch.func.jvp`` is tried first; on ``AttributeError``,
    ``NotImplementedError`` or ``RuntimeError`` the computation falls back to
    ``torch.autograd.functional.jvp``.
    """
    def u_of(z, r_in, t_in):
        z = self._apply_conditioning(z, condition_data, condition_mask)
        return self._compute_u_v(z, t_in, r_in, cond)[0]

    inputs = (sample, r, t)
    directions = (tangent_v.detach(), torch.zeros_like(r), torch.ones_like(t))

    try:
        return torch.func.jvp(u_of, inputs, directions)[1]
    except (AttributeError, NotImplementedError, RuntimeError):
        # Fallback path when forward-mode AD is unavailable or fails.
        return torch.autograd.functional.jvp(
            u_of,
            inputs,
            directions,
            create_graph=False,
            strict=False,
        )[1]
# ========= inference ============
def conditional_sample(
    self,
    condition_data,
    condition_mask,
    cond=None,
    generator=None,
    **kwargs,
):
    """Sample a trajectory by stepping ``u`` from time 1.0 down to 0.0.

    Starts from Gaussian noise scaled by ``self.noise_scale`` and takes
    ``self.num_inference_steps`` steps on an evenly spaced time grid. Before
    every step, and once more at the end, entries selected by
    ``condition_mask`` are overwritten with ``condition_data``.

    Args:
        condition_data: tensor whose shape, dtype and device define the sample.
        condition_mask: boolean tensor marking entries fixed to
            ``condition_data``.
        cond: optional conditioning passed through to the model.
        generator: optional ``torch.Generator`` for the initial noise.
        **kwargs: accepted and ignored.
    """
    del kwargs
    noise = torch.randn(
        size=condition_data.shape,
        dtype=condition_data.dtype,
        device=condition_data.device,
        generator=generator,
    )
    trajectory = noise * self.noise_scale
    batch = trajectory.shape[0]
    grid = torch.linspace(
        1.0,
        0.0,
        self.num_inference_steps + 1,
        dtype=trajectory.dtype,
        device=trajectory.device,
    )
    # Consecutive grid points form the (t, r) pair for each step.
    for t_now, t_next in zip(grid[:-1], grid[1:]):
        trajectory = self._apply_conditioning(trajectory, condition_data, condition_mask)
        t = t_now.expand(batch)
        r = t_next.expand(batch)
        u, _ = self._compute_u_v(trajectory, t, r, cond)
        trajectory = trajectory - self._time_view(t - r, trajectory) * u
    return self._apply_conditioning(trajectory, condition_data, condition_mask)
def predict_action(self, obs_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Sample an action trajectory for a batch of observations.

    Args:
        obs_dict: un-normalized observations keyed by observation name; must
            not contain ``"past_action"``. Only the first ``n_obs_steps``
            steps of each tensor are encoded.

    Returns:
        A dict with
        ``"action_pred"``: the full un-normalized predicted trajectory, and
        ``"action"``: the ``n_action_steps`` actions to execute. When
        ``pred_action_steps_only`` this is the whole prediction; otherwise
        it is the window starting at index ``n_obs_steps - 1``.
    """
    assert "past_action" not in obs_dict
    nobs = self.normalizer.normalize(obs_dict)
    batch_size = next(iter(nobs.values())).shape[0]
    horizon = self.horizon
    action_dim = self.action_dim
    device = self.device
    dtype = self.dtype

    cond = None
    if self.obs_as_cond:
        # Observations go through `cond`; nothing in the trajectory is fixed.
        cond = self._encode_obs(nobs, self.n_obs_steps)
        shape = (batch_size, horizon, action_dim)
        if self.pred_action_steps_only:
            shape = (batch_size, self.n_action_steps, action_dim)
        cond_data = torch.zeros(size=shape, device=device, dtype=dtype)
        cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
    else:
        # Observation features occupy the channels after the action dims and
        # are pinned for the observed steps.
        nobs_features = self._encode_obs(nobs, self.n_obs_steps)
        shape = (batch_size, horizon, action_dim + self.obs_feature_dim)
        cond_data = torch.zeros(size=shape, device=device, dtype=dtype)
        cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
        cond_data[:, : self.n_obs_steps, action_dim:] = nobs_features
        cond_mask[:, : self.n_obs_steps, action_dim:] = True

    nsample = self.conditional_sample(
        cond_data,
        cond_mask,
        cond=cond,
        **self.kwargs,
    )
    naction_pred = nsample[..., :action_dim]
    action_pred = self.normalizer["action"].unnormalize(naction_pred)

    if self.pred_action_steps_only:
        action = action_pred
    else:
        # Anchor the window on the configured number of observation steps,
        # matching `_encode_obs` and the training-time slicing, rather than on
        # however many steps the caller happened to pass in.
        start = self.n_obs_steps - 1
        end = start + self.n_action_steps
        action = action_pred[:, start:end]
    return {
        "action": action,
        "action_pred": action_pred,
    }
# ========= training ============
def set_normalizer(self, normalizer: LinearNormalizer):
    """Copy the state of ``normalizer`` into this policy's own normalizer."""
    state = normalizer.state_dict()
    self.normalizer.load_state_dict(state)
def get_optimizer(
    self,
    transformer_weight_decay: float,
    obs_encoder_weight_decay: float,
    learning_rate: float,
    betas: Tuple[float, float],
) -> torch.optim.Optimizer:
    """Create an AdamW optimizer over the transformer and observation encoder.

    The transformer's parameter groups come from
    ``self.model.get_optim_groups``; all encoder parameters form one extra
    group with its own weight decay.
    """
    param_groups = self.model.get_optim_groups(weight_decay=transformer_weight_decay)
    encoder_group = dict(
        params=self.obs_encoder.parameters(),
        weight_decay=obs_encoder_weight_decay,
    )
    param_groups.append(encoder_group)
    return torch.optim.AdamW(param_groups, lr=learning_rate, betas=betas)
def compute_loss(self, batch):
    """Compute the weighted sum of the ``u`` and ``v`` training losses.

    A noisy sample ``z_t = (1 - t) * x + t * noise`` is formed from the
    normalized trajectory ``x``, and both losses regress towards
    ``noise - x``:

    * the ``u`` loss uses ``u + (t - r) * du_dt`` with ``du_dt`` detached;
    * the ``v`` loss uses ``v`` directly.

    Entries fixed by the condition mask are excluded. Each loss is averaged
    to a scalar, passed through ``self._adatloss`` and weighted by
    ``pmf_u_loss_weight`` / ``pmf_v_loss_weight``.

    Args:
        batch: dict with ``"obs"`` and ``"action"``; must not contain
            ``"valid_mask"``.
    """
    assert "valid_mask" not in batch
    nobs = self.normalizer.normalize(batch["obs"])
    nactions = self.normalizer["action"].normalize(batch["action"])
    _, trajectory, cond, condition_mask = self._trajectory_inputs(nobs, nactions)

    # Noise is drawn before the times so the RNG sequence is unchanged.
    noise = torch.randn_like(trajectory) * self.noise_scale
    t, r = self._sample_tr(
        trajectory.shape[0], device=trajectory.device, dtype=trajectory.dtype
    )

    t_view = self._time_view(t, trajectory)
    z_t = (1 - t_view) * trajectory + t_view * noise
    z_t = self._apply_conditioning(z_t, trajectory, condition_mask)
    target_v = noise - trajectory
    loss_mask = ~condition_mask

    u, v = self._compute_u_v(z_t, t, r, cond)
    du_dt = self._compute_du_dt(
        sample=z_t,
        t=t,
        r=r,
        cond=cond,
        condition_data=trajectory,
        condition_mask=condition_mask,
        tangent_v=v,
    )
    pmf_velocity = u + self._time_view(t - r, trajectory) * du_dt.detach()

    def _masked_mse(prediction):
        # Per-element MSE, masked, averaged per example and then over the batch.
        per_element = F.mse_loss(prediction, target_v, reduction="none")
        per_element = per_element * loss_mask.type(per_element.dtype)
        return reduce(per_element, "b ... -> b (...)", "mean").mean()

    loss_u = self._adatloss(_masked_mse(pmf_velocity))
    loss_v = self._adatloss(_masked_mse(v))
    return self.pmf_u_loss_weight * loss_u + self.pmf_v_loss_weight * loss_v

View File

@@ -8,8 +8,6 @@ if __name__ == "__main__":
os.chdir(ROOT_DIR) os.chdir(ROOT_DIR)
import os import os
import contextlib
import importlib
import hydra import hydra
import torch import torch
from omegaconf import OmegaConf from omegaconf import OmegaConf
@@ -17,6 +15,7 @@ import pathlib
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
import copy import copy
import random import random
import wandb
import tqdm import tqdm
import numpy as np import numpy as np
import shutil import shutil
@@ -32,111 +31,6 @@ from diffusion_policy.model.common.lr_scheduler import get_scheduler
OmegaConf.register_new_resolver("eval", eval, replace=True) OmegaConf.register_new_resolver("eval", eval, replace=True)
class _LoggingBackend:
    """Minimal logger interface; both methods must be overridden."""

    def log(self, payload, step=None):
        """Record ``payload`` for ``step``. Not implemented in the base class."""
        raise NotImplementedError()

    def finish(self):
        """Close the logging run. Not implemented in the base class."""
        raise NotImplementedError()
class _WandbLoggingBackend(_LoggingBackend):
    """Logging backend that forwards every call to the wrapped run object."""

    def __init__(self, run):
        # Run object created by the caller.
        self.run = run

    def log(self, payload, step=None):
        """Forward ``payload`` and ``step`` to ``run.log``."""
        run = self.run
        run.log(payload, step=step)

    def finish(self):
        """Call ``run.finish()``."""
        run = self.run
        run.finish()
class _SwanLabLoggingBackend(_LoggingBackend):
    """Logging backend that forwards every call to the wrapped run object."""

    def __init__(self, run):
        # Run object created by the caller.
        self.run = run

    def log(self, payload, step=None):
        """Forward ``payload`` and ``step`` to ``run.log``."""
        run = self.run
        run.log(payload, step=step)

    def finish(self):
        """Call ``run.finish()``."""
        run = self.run
        run.finish()
def _load_wandb():
    """Import and return the ``wandb`` module.

    Raises:
        ImportError: if ``wandb`` cannot be imported; chained to the
            underlying import error.
    """
    try:
        module = importlib.import_module('wandb')
    except ImportError as exc:
        message = "wandb is required when cfg.logging.backend == 'wandb' or missing"
        raise ImportError(message) from exc
    return module
def _load_swanlab():
    """Return the ``swanlab`` module, or ``None`` if it cannot be imported."""
    try:
        module = importlib.import_module('swanlab')
    except ImportError:
        module = None
    return module
def init_logging_backend(cfg: OmegaConf, output_dir):
    """Start a logging run for the backend selected by ``cfg.logging.backend``.

    The backend defaults to ``'wandb'`` when the key is absent; ``None`` is
    treated as ``'wandb'`` as well.

    Returns:
        A ``_SwanLabLoggingBackend`` or ``_WandbLoggingBackend`` wrapping the
        started run.

    Raises:
        ValueError: for any backend other than ``'swanlab'``, ``'wandb'`` or
            ``None``.
        ImportError: if the selected backend's package is not importable.
    """
    backend = OmegaConf.select(cfg, 'logging.backend', default='wandb')
    if backend not in (None, 'wandb', 'swanlab'):
        raise ValueError(f"Unknown logging backend: {backend}")

    if backend != 'swanlab':
        wandb = _load_wandb()
        # Every logging key except `backend` is forwarded to wandb.init.
        init_kwargs = OmegaConf.to_container(cfg.logging, resolve=True)
        init_kwargs.pop('backend', None)
        run = wandb.init(
            dir=str(output_dir),
            config=OmegaConf.to_container(cfg, resolve=True),
            **init_kwargs
        )
        wandb.config.update({"output_dir": str(output_dir)})
        return _WandbLoggingBackend(run)

    swanlab = _load_swanlab()
    if swanlab is None:
        raise ImportError("swanlab is required when cfg.logging.backend == 'swanlab'")
    settings = cfg.logging
    mode = settings.mode
    # The config's 'online' mode is passed to swanlab as 'cloud'.
    if mode == 'online':
        mode = 'cloud'
    swanlab_run = swanlab.init(
        project=settings.project,
        experiment_name=settings.name,
        group=settings.group,
        tags=settings.tags,
        id=settings.id,
        resume=settings.resume,
        mode=mode,
        logdir=str(pathlib.Path(output_dir) / 'swanlog'),
        config=OmegaConf.to_container(cfg, resolve=True),
    )
    return _SwanLabLoggingBackend(swanlab_run)
@contextlib.contextmanager
def logging_backend_session(cfg: OmegaConf, output_dir):
    """Context manager yielding a logging backend and finishing it on exit.

    ``finish()`` is always called. If it raises while an exception from the
    ``with`` body is already propagating, the ``finish()`` error is dropped
    so the body's exception is the one the caller sees; otherwise the
    ``finish()`` error propagates.
    """
    backend = init_logging_backend(cfg=cfg, output_dir=output_dir)
    body_failed = False
    try:
        yield backend
    except BaseException:
        body_failed = True
        raise
    finally:
        try:
            backend.finish()
        except BaseException:
            # Only surface cleanup failures when nothing else went wrong.
            if not body_failed:
                raise
class TrainDiffusionTransformerHybridWorkspace(BaseWorkspace): class TrainDiffusionTransformerHybridWorkspace(BaseWorkspace):
include_keys = ['global_step', 'epoch'] include_keys = ['global_step', 'epoch']
@@ -215,6 +109,18 @@ class TrainDiffusionTransformerHybridWorkspace(BaseWorkspace):
output_dir=self.output_dir) output_dir=self.output_dir)
assert isinstance(env_runner, BaseImageRunner) assert isinstance(env_runner, BaseImageRunner)
# configure logging
wandb_run = wandb.init(
dir=str(self.output_dir),
config=OmegaConf.to_container(cfg, resolve=True),
**cfg.logging
)
wandb.config.update(
{
"output_dir": self.output_dir,
}
)
# configure checkpoint # configure checkpoint
topk_manager = TopKCheckpointManager( topk_manager = TopKCheckpointManager(
save_dir=os.path.join(self.output_dir, 'checkpoints'), save_dir=os.path.join(self.output_dir, 'checkpoints'),
@@ -242,141 +148,140 @@ class TrainDiffusionTransformerHybridWorkspace(BaseWorkspace):
# training loop # training loop
log_path = os.path.join(self.output_dir, 'logs.json.txt') log_path = os.path.join(self.output_dir, 'logs.json.txt')
with logging_backend_session(cfg=cfg, output_dir=self.output_dir) as logging_backend: with JsonLogger(log_path) as json_logger:
with JsonLogger(log_path) as json_logger: for local_epoch_idx in range(cfg.training.num_epochs):
for local_epoch_idx in range(cfg.training.num_epochs): step_log = dict()
step_log = dict() # ========= train for this epoch ==========
# ========= train for this epoch ========== train_losses = list()
train_losses = list() with tqdm.tqdm(train_dataloader, desc=f"Training epoch {self.epoch}",
with tqdm.tqdm(train_dataloader, desc=f"Training epoch {self.epoch}", leave=False, mininterval=cfg.training.tqdm_interval_sec) as tepoch:
leave=False, mininterval=cfg.training.tqdm_interval_sec) as tepoch: for batch_idx, batch in enumerate(tepoch):
for batch_idx, batch in enumerate(tepoch): # device transfer
# device transfer batch = dict_apply(batch, lambda x: x.to(device, non_blocking=True))
batch = dict_apply(batch, lambda x: x.to(device, non_blocking=True)) if train_sampling_batch is None:
if train_sampling_batch is None: train_sampling_batch = batch
train_sampling_batch = batch
# compute loss # compute loss
raw_loss = self.model.compute_loss(batch) raw_loss = self.model.compute_loss(batch)
loss = raw_loss / cfg.training.gradient_accumulate_every loss = raw_loss / cfg.training.gradient_accumulate_every
loss.backward() loss.backward()
# step optimizer # step optimizer
if self.global_step % cfg.training.gradient_accumulate_every == 0: if self.global_step % cfg.training.gradient_accumulate_every == 0:
self.optimizer.step() self.optimizer.step()
self.optimizer.zero_grad() self.optimizer.zero_grad()
lr_scheduler.step() lr_scheduler.step()
# update ema
if cfg.training.use_ema:
ema.step(self.model)
# logging
raw_loss_cpu = raw_loss.item()
tepoch.set_postfix(loss=raw_loss_cpu, refresh=False)
train_losses.append(raw_loss_cpu)
step_log = {
'train_loss': raw_loss_cpu,
'global_step': self.global_step,
'epoch': self.epoch,
'lr': lr_scheduler.get_last_lr()[0]
}
is_last_batch = (batch_idx == (len(train_dataloader)-1))
if not is_last_batch:
# log of last step is combined with validation and rollout
logging_backend.log(step_log, step=self.global_step)
json_logger.log(step_log)
self.global_step += 1
if (cfg.training.max_train_steps is not None) \
and batch_idx >= (cfg.training.max_train_steps-1):
break
# at the end of each epoch
# replace train_loss with epoch average
train_loss = np.mean(train_losses)
step_log['train_loss'] = train_loss
# ========= eval for this epoch ==========
policy = self.model
if cfg.training.use_ema:
policy = self.ema_model
policy.eval()
# run rollout
if (self.epoch % cfg.training.rollout_every) == 0:
runner_log = env_runner.run(policy)
# log all
step_log.update(runner_log)
# run validation
if (self.epoch % cfg.training.val_every) == 0:
with torch.no_grad():
val_losses = list()
with tqdm.tqdm(val_dataloader, desc=f"Validation epoch {self.epoch}",
leave=False, mininterval=cfg.training.tqdm_interval_sec) as tepoch:
for batch_idx, batch in enumerate(tepoch):
batch = dict_apply(batch, lambda x: x.to(device, non_blocking=True))
loss = self.model.compute_loss(batch)
val_losses.append(loss)
if (cfg.training.max_val_steps is not None) \
and batch_idx >= (cfg.training.max_val_steps-1):
break
if len(val_losses) > 0:
val_loss = torch.mean(torch.tensor(val_losses)).item()
# log epoch average validation loss
step_log['val_loss'] = val_loss
# run diffusion sampling on a training batch
if (self.epoch % cfg.training.sample_every) == 0:
with torch.no_grad():
# sample trajectory from training set, and evaluate difference
batch = dict_apply(train_sampling_batch, lambda x: x.to(device, non_blocking=True))
obs_dict = batch['obs']
gt_action = batch['action']
result = policy.predict_action(obs_dict)
pred_action = result['action_pred']
mse = torch.nn.functional.mse_loss(pred_action, gt_action)
step_log['train_action_mse_error'] = mse.item()
del batch
del obs_dict
del gt_action
del result
del pred_action
del mse
# checkpoint
if (self.epoch % cfg.training.checkpoint_every) == 0:
# checkpointing
if cfg.checkpoint.save_last_ckpt:
self.save_checkpoint()
if cfg.checkpoint.save_last_snapshot:
self.save_snapshot()
# sanitize metric names
metric_dict = dict()
for key, value in step_log.items():
new_key = key.replace('/', '_')
metric_dict[new_key] = value
# We can't copy the last checkpoint here # update ema
# since save_checkpoint uses threads. if cfg.training.use_ema:
# therefore at this point the file might have been empty! ema.step(self.model)
topk_ckpt_path = topk_manager.get_ckpt_path(metric_dict)
if topk_ckpt_path is not None: # logging
self.save_checkpoint(path=topk_ckpt_path) raw_loss_cpu = raw_loss.item()
# ========= eval end for this epoch ========== tepoch.set_postfix(loss=raw_loss_cpu, refresh=False)
policy.train() train_losses.append(raw_loss_cpu)
step_log = {
'train_loss': raw_loss_cpu,
'global_step': self.global_step,
'epoch': self.epoch,
'lr': lr_scheduler.get_last_lr()[0]
}
# end of epoch is_last_batch = (batch_idx == (len(train_dataloader)-1))
# log of last step is combined with validation and rollout if not is_last_batch:
logging_backend.log(step_log, step=self.global_step) # log of last step is combined with validation and rollout
json_logger.log(step_log) wandb_run.log(step_log, step=self.global_step)
self.global_step += 1 json_logger.log(step_log)
self.epoch += 1 self.global_step += 1
if (cfg.training.max_train_steps is not None) \
and batch_idx >= (cfg.training.max_train_steps-1):
break
# at the end of each epoch
# replace train_loss with epoch average
train_loss = np.mean(train_losses)
step_log['train_loss'] = train_loss
# ========= eval for this epoch ==========
policy = self.model
if cfg.training.use_ema:
policy = self.ema_model
policy.eval()
# run rollout
if (self.epoch % cfg.training.rollout_every) == 0:
runner_log = env_runner.run(policy)
# log all
step_log.update(runner_log)
# run validation
if (self.epoch % cfg.training.val_every) == 0:
with torch.no_grad():
val_losses = list()
with tqdm.tqdm(val_dataloader, desc=f"Validation epoch {self.epoch}",
leave=False, mininterval=cfg.training.tqdm_interval_sec) as tepoch:
for batch_idx, batch in enumerate(tepoch):
batch = dict_apply(batch, lambda x: x.to(device, non_blocking=True))
loss = self.model.compute_loss(batch)
val_losses.append(loss)
if (cfg.training.max_val_steps is not None) \
and batch_idx >= (cfg.training.max_val_steps-1):
break
if len(val_losses) > 0:
val_loss = torch.mean(torch.tensor(val_losses)).item()
# log epoch average validation loss
step_log['val_loss'] = val_loss
# run diffusion sampling on a training batch
if (self.epoch % cfg.training.sample_every) == 0:
with torch.no_grad():
# sample trajectory from training set, and evaluate difference
batch = dict_apply(train_sampling_batch, lambda x: x.to(device, non_blocking=True))
obs_dict = batch['obs']
gt_action = batch['action']
result = policy.predict_action(obs_dict)
pred_action = result['action_pred']
mse = torch.nn.functional.mse_loss(pred_action, gt_action)
step_log['train_action_mse_error'] = mse.item()
del batch
del obs_dict
del gt_action
del result
del pred_action
del mse
# checkpoint
if (self.epoch % cfg.training.checkpoint_every) == 0:
# checkpointing
if cfg.checkpoint.save_last_ckpt:
self.save_checkpoint()
if cfg.checkpoint.save_last_snapshot:
self.save_snapshot()
# sanitize metric names
metric_dict = dict()
for key, value in step_log.items():
new_key = key.replace('/', '_')
metric_dict[new_key] = value
# We can't copy the last checkpoint here
# since save_checkpoint uses threads.
# therefore at this point the file might have been empty!
topk_ckpt_path = topk_manager.get_ckpt_path(metric_dict)
if topk_ckpt_path is not None:
self.save_checkpoint(path=topk_ckpt_path)
# ========= eval end for this epoch ==========
policy.train()
# end of epoch
# log of last step is combined with validation and rollout
wandb_run.log(step_log, step=self.global_step)
json_logger.log(step_log)
self.global_step += 1
self.epoch += 1
@hydra.main( @hydra.main(
version_base=None, version_base=None,

View File

@@ -1,60 +0,0 @@
# PushT iMF Full-Attention Implementation Plan
> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking.
**Goal:** Add a separate full-attention PushT image iMF config, commit/push it on a new branch, and launch the 9-run 350-epoch architecture sweep across 3 GPUs.
**Architecture:** Keep the existing causal iMF path untouched and add a standalone full-attention config that only flips `policy.causal_attn=false` while retaining one-step iMF inference and SwanLab-safe naming. Reuse the previous 9-run architecture matrix and balanced 3-queue scheduling across local 5090 plus 5880 GPU0/GPU1.
**Tech Stack:** Hydra, Diffusion Policy iMF image workspace, SwanLab, uv env, local shell + trusted remote 5880 over SSH.
---
### Task 1: Add full-attention iMF config with TDD
**Files:**
- Create: `image_pusht_diffusion_policy_dit_imf_fullattn.yaml`
- Modify: `tests/test_pusht_swanlab_config.py`
- [ ] Write a failing config regression test asserting the new config uses SwanLab-safe naming and `policy.causal_attn == False`.
- [ ] Run the targeted pytest command and verify it fails because the config does not exist yet.
- [ ] Add the minimal full-attention config by composing from the existing PushT image iMF config and overriding only `exp_name` and `policy.causal_attn=false`.
- [ ] Re-run the targeted pytest and verify it passes.
### Task 2: Verify the new config
**Files:**
- Read: `image_pusht_diffusion_policy_dit_imf_fullattn.yaml`
- [ ] Run `train.py --help` for the new config.
- [ ] Run a real `training.debug=true` smoke test locally to confirm the training path is valid.
### Task 3: Commit and push the new branch
**Files:**
- Commit only the new config/test/plan files needed for the full-attention experiment chain.
- [ ] Run verification commands again before commit.
- [ ] Commit with a focused message.
- [ ] Push `feat/pusht-imf-fullattn` to origin.
### Task 4: Launch the 9-run sweep
**Files:**
- Write queue scripts and logs under `data/run_logs/` locally and on 5880.
- Write outputs under `data/outputs/` locally and on 5880.
- [ ] Use the same matrix as the prior iMF sweep: `n_emb ∈ {128,256,384}`, `n_layer ∈ {6,12,18}`, `seed=42`.
- [ ] Set `training.num_epochs=350` for all 9 runs.
- [ ] Encode `fullattn` in every `exp_name`, `logging.name`, and run directory to avoid collisions.
- [ ] Balance the 9 runs across local 5090, 5880 GPU0, and 5880 GPU1 as three serial queues.
- [ ] Sync the new config to the remote smoke repo before launching remote queues.
### Task 5: Monitor and auto-summarize
**Files:**
- Read local and remote pid files, logs, outputs, checkpoints.
- [ ] Start an xhigh monitoring agent that polls all three queues.
- [ ] On completion, parse all 9 `logs.json.txt` files and rank by max `test_mean_score`.
- [ ] Report embedding/layer trends and the best configuration.

View File

@@ -1,168 +0,0 @@
# PushT Image DiT iMF + SwanLab Design
## Goal
Migrate the PushT image DiT experiment path from W&B to SwanLab online logging, suppress simulation video logging, then add an iMeanFlow-based one-step transformer policy for PushT image experiments and run a controlled architecture sweep over embedding width and depth using `test_mean_score` as the primary metric.
## Context
- The implementation baseline is `main`.
- The experiment path is limited to the PushT image transformer workflow; unrelated workspaces and runners should remain unchanged.
- Environment management must use the repo-local `uv` workflow.
- The trusted remote machine alias `5880` refers to `droid-system-product-name` (`droid@100.73.14.65`) and can run two GPU jobs in parallel.
## Architecture Overview
The work is split into two verified phases:
1. **Logging migration phase**
- Keep the existing PushT image DiT training behavior intact.
- Replace W&B usage with SwanLab in the transformer hybrid workspace used by PushT image DiT experiments.
- Preserve local `logs.json.txt` output.
- Ensure rollout metrics such as `test_mean_score` and per-seed rewards are still logged.
- Disable simulation video logging at both the config and runner/logging boundary.
2. **iMF migration phase**
- Keep the original diffusion-based transformer image policy available on `main`.
- Add a parallel iMF-specific model/policy/config path rather than overwriting the baseline diffusion policy.
- Reuse the existing observation encoder and training workspace where possible.
- Replace diffusion training with the iMeanFlow training objective.
- Use one-step inference for validation/rollout in the iMF path.
The implementation planning boundary for this spec is:
- code changes through a smoke-tested, pushed iMF branch
- not the full 3x3 sweep execution/monitoring workflow, which should be planned separately after the code path is verified and pushed
## Logging Design
### Scope
Only the PushT image DiT experiment chain is changed:
- `train_diffusion_transformer_hybrid_workspace.py`
- `pusht_image_runner.py`
- the new/updated PushT image transformer configs
### Behavior
- SwanLab runs in `online` mode.
- Logged values are scalar metrics only, e.g.:
- `train_loss`
- `val_loss`
- `train_action_mse_error`
- `test_mean_score`
- aggregate rollout metrics and optional per-seed scalar rewards
- No simulation videos are uploaded or wrapped as logging objects.
- Local JSON logging remains enabled for auditability and remote-job fallback debugging.
### Operational safeguards
- Default PushT experiment configs set `task.env_runner.n_test_vis=0` and `task.env_runner.n_train_vis=0`.
- The PushT image runner will not emit video objects into `log_data`, preventing accidental uploads even if visualization counts are later changed.
- SwanLab credentials are provided through the environment at runtime, not committed into the repo.
## iMF Model Design
### Baseline reuse
The iMF path reuses:
- the existing image observation encoder
- the existing action/observation normalization path
- the existing training workspace skeleton
- the existing PushT image dataset and env runner
### New files
- `diffusion_policy/model/diffusion/imf_transformer_for_diffusion.py`
- `diffusion_policy/policy/imf_transformer_hybrid_image_policy.py`
- `image_pusht_diffusion_policy_dit_imf.yaml`
### Existing files changed for the iMF path
- `diffusion_policy/workspace/train_diffusion_transformer_hybrid_workspace.py`
- logging migration to SwanLab for this experiment chain
- no structural training-loop fork beyond instantiating the configured policy and logging scalar metrics
- `diffusion_policy/env_runner/pusht_image_runner.py`
- suppress video objects in returned logs
### Model structure
The iMF transformer mirrors the current transformer policy structure closely enough to reuse known-good conditioning patterns, but it remains a **single-head model** that predicts only:
- `u`: average velocity field
The same function is reused at two evaluation points:
- canonical signature: `fn(z, r, t, cond)`
- `fn(z_t, r, t, cond)` predicts average velocity `u`
- `fn(z_t, t, t, cond)` predicts the instantaneous velocity surrogate `v`
Inputs remain conditioned on encoded observations and action trajectory tokens.
## iMF Training Objective
For a normalized action trajectory `x`, the initial implementation follows the user-provided Algorithm 1 exactly:
1. sample `t, r`
2. sample Gaussian noise `e`
3. form `z_t = (1 - t) * x + t * e`
4. predict instantaneous velocity surrogate with the same network:
- `v = fn(z_t, t, t, cond)`
5. define the JVP function exactly as:
- `g(z, r, t) = fn(z, r, t, cond)`
6. compute the primal output and JVP with tangent:
- `u, du_dt = jvp(g, (z_t, r, t), (v.detach(), 0, 1))`
7. form compound velocity:
- `V = u + (t - r) * stopgrad(du_dt)`
8. train the compound velocity `V` against the conditional (instantaneous) velocity target:
- `target = e - x`
9. optimize only the masked iMF loss:
- `loss = metric(V - target)`
There is **no auxiliary `v` loss** in the initial implementation. The implementation should prefer `torch.func.jvp` and keep a safe fallback path if the local Torch stack needs it.
## iMF Inference Design
Inference uses a single step starting from noise:
- initialize `z_1 ~ N(0, I)`
- set `t = 1.0`, `r = 0.0`
- predict `u = fn(z_1, r, t, cond)`
- produce the action sample with one update:
- `x_hat = z_1 - (t - r) * u`
This matches the time direction in the reference iMeanFlow sampling logic.
## Testing Strategy
### Phase 1: logging migration smoke test
- use the repo-local `uv` environment
- run a debug/smoke PushT image DiT training job on a single GPU with:
- `training.debug=true`
- `dataloader.num_workers=0`
- `val_dataloader.num_workers=0`
- `task.env_runner.n_envs=1`
- `task.env_runner.n_test_vis=0`
- `task.env_runner.n_train_vis=0`
- verify:
- SwanLab initializes successfully
- `logs.json.txt` is populated
- rollout metrics still include `test_mean_score`
- no video logging is attempted
### Phase 2: iMF smoke test
- run an equivalent debug PushT image iMF job
- verify:
- forward/backward passes succeed
- JVP path executes on the local Torch version
- one-step inference returns correctly shaped actions
- rollout produces scalar metrics including `test_mean_score`
## Branch and Commit Strategy
1. start from a `main`-based worktree branch
2. commit the SwanLab/no-video migration after smoke verification
3. continue with the iMF implementation
4. once iMF smoke tests pass, create/preserve a dedicated feature branch for the experiment code and push it to Gitea
## Post-Implementation Experiment Plan
After the iMF path is smoke-tested and pushed, a separate experiment-execution plan should cover the following:
- run a 3x3 grid over:
- `n_emb ∈ {128, 256, 384}`
- `n_layer ∈ {6, 12, 18}`
- keep the rest of the setup fixed
- use a fixed single-seed setting for comparability unless a later explicit experiment plan expands that scope
- run each experiment for 300 epochs
- primary comparison metric: `test_mean_score`
## Post-Implementation Resource Allocation
The separate experiment-execution plan should schedule three concurrent runs until the matrix is complete:
- local machine: 1 GPU
- `5880`: 2 GPUs
Each run uses the same uv-managed environment and the same pushed branch so the code path is consistent across hosts.
## Risks and Mitigations
- **Torch JVP compatibility risk**: provide a fallback JVP implementation and smoke-test immediately.
- **Logging regression risk**: keep local JSON logging and verify scalar rollout metrics before moving to iMF.
- **Video/logging side effects**: disable visualizations in config and filter video objects out of runner logs.
- **Cross-host drift**: push the verified branch to Gitea before launching the experiment matrix on multiple machines.

View File

@@ -1,107 +0,0 @@
# PushT Image iMF Full-Attention Sweep Design
## Goal
在一个独立新分支上,为 PushT 图像 iMF 路线新增 **full-attention** 变体(关闭因果注意力),并按与之前相同的架构扫描网格运行 **9 组实验**,每组训练 **350 epochs**。所有实验完成后,提取每组 **`max(test_mean_score)`** 并输出完整排名和趋势总结。
## Scope
本次工作仅覆盖:
1. 在不影响现有因果版 iMF 路线的前提下,新增 full-attention 实验链路;
2. 对 `n_emb ∈ {128, 256, 384}` × `n_layer ∈ {6, 12, 18}` 的 9 组组合做 350-epoch 扫描;
3. 在本机 5090 与 5880 双卡上做三路并行调度;
4. 在全部实验完成后自动汇总结果并直接向用户汇报。
不在本次范围内:
- 不替换或删除现有因果版 iMF 配置;
- 不改动已有 DiT baseline 实现;
- 不做多 seed 扩展;
- 不额外增加视频记录。
## Design Choice
采用“**新增独立配置 + 新分支**”的方式,而不是覆盖现有 iMF 默认配置。
原因:
- 现有因果版 iMF 已完成实验与结果记录,保持不动更利于对照;
- full-attention 作为新的实验链路,使用独立配置更易复现;
- 运行时只需要通过配置切换 `policy.causal_attn=false`,不需要重新设计 iMF 算法本身。
## Configuration Design
新增一个独立配置文件,例如:
- `image_pusht_diffusion_policy_dit_imf_fullattn.yaml`
其职责:
- 继承当前 PushT image iMF 配置链路;
- 保持 iMF 单步推理、SwanLab 标量记录、无视频记录;
- 显式设置:
- `policy.causal_attn=false`
- `policy.n_head=1`
- 保持其余 iMF 训练语义不变。
SwanLab 命名延续当前修复后的策略:
- `logging.name=${exp_name}`
- `logging.resume=false`
- `logging.id=null`
- `logging.group=${exp_name}` 或统一 sweep group override
## Code Change Strategy
优先最小改动:
- 若当前 `IMFTransformerForDiffusion` 已支持 `causal_attn=False` 分支,则不改核心算法,仅通过新配置关闭因果 mask;
- 如需补充回归验证,则新增针对 full-attention 配置/掩码行为的最小测试;
- 不改变已有因果版实验配置和已有测试语义。
## Experiment Matrix
实验网格固定为:
- `n_emb=128, n_layer=6`
- `n_emb=128, n_layer=12`
- `n_emb=128, n_layer=18`
- `n_emb=256, n_layer=6`
- `n_emb=256, n_layer=12`
- `n_emb=256, n_layer=18`
- `n_emb=384, n_layer=6`
- `n_emb=384, n_layer=12`
- `n_emb=384, n_layer=18`
统一设置:
- `training.num_epochs=350`
- `training.resume=false`
- `seed=42`
- PushT image 数据路径不变
- 指标以 **`logs.json.txt` 中 `test_mean_score` 的最大值** 为准
## Scheduling Design
使用三路串行队列并行执行 9 个实验:
- 本机 5090:1 个顺序队列
- 5880 GPU0:1 个顺序队列
- 5880 GPU1:1 个顺序队列
分配原则:
- 延续按 `n_emb × n_layer` 近似平衡工作量;
- 每张卡同一时刻只跑 1 个实验;
- 队列脚本负责“前一个结束后自动启动下一个”。
## Monitoring Design
继续采用“**训练队列脚本 + 监控 agent**”双层机制:
1. **实际调度**由本地/远端队列脚本负责;
2. **监控**由一个 xhigh 子 agent 轮询:
- 读取 pid 状态
- 检查 master log
- 检查每个 run 的 `logs.json.txt`
- 判断是否卡死/失败/全部完成
3. 一旦全部完成,监控 agent 直接返回:
- 9 组实验的最终 epoch
- 每组 `max(test_mean_score)`
- 排名表
- embedding / layer 趋势总结
本次要求下,agent 在收到全部完成信号后应直接向主会话回报结果,不等待用户再次提醒。
## Success Criteria
满足以下条件即视为完成:
1. full-attention iMF 配置在新分支上可运行;
2. 9 组 350-epoch 实验全部完成;
3. 不记录仿真视频,只记录标量;
4. SwanLab 运行命名不冲突;
5. 输出 9 组实验 `max(test_mean_score)` 的完整汇总与结论;
6. 全部实验结束后主会话可直接给用户最终总结。

View File

@@ -1,30 +0,0 @@
defaults:
- diffusion_policy/config/train_diffusion_transformer_hybrid_workspace@_here_
- override /diffusion_policy/config/task@task: pusht_image
- _self_
exp_name: pusht_image_dit
policy:
_target_: diffusion_policy.policy.diffusion_transformer_hybrid_image_policy.DiffusionTransformerHybridImagePolicy
logging:
backend: swanlab
mode: online
name: ${exp_name}
resume: false
tags: ["${name}", "${task_name}", "${exp_name}", "swanlab"]
id: null
group: ${exp_name}
dataloader:
num_workers: 0
val_dataloader:
num_workers: 0
task:
env_runner:
n_envs: 1
n_test_vis: 0
n_train_vis: 0

View File

@@ -1,32 +0,0 @@
defaults:
- diffusion_policy/config/train_diffusion_transformer_hybrid_workspace@_here_
- override /diffusion_policy/config/task@task: pusht_image
- _self_
exp_name: pusht_image_dit_imf
policy:
_target_: diffusion_policy.policy.imf_transformer_hybrid_image_policy.IMFTransformerHybridImagePolicy
num_inference_steps: 1
n_head: 1
logging:
backend: swanlab
mode: online
name: ${exp_name}
resume: false
tags: ["${name}", "${task_name}", "${exp_name}", "swanlab"]
id: null
group: ${exp_name}
dataloader:
num_workers: 0
val_dataloader:
num_workers: 0
task:
env_runner:
n_envs: 1
n_test_vis: 0
n_train_vis: 0

View File

@@ -1,33 +0,0 @@
defaults:
- diffusion_policy/config/train_diffusion_transformer_hybrid_workspace@_here_
- override /diffusion_policy/config/task@task: pusht_image
- _self_
exp_name: pusht_image_dit_imf_fullattn
policy:
_target_: diffusion_policy.policy.imf_transformer_hybrid_image_policy.IMFTransformerHybridImagePolicy
num_inference_steps: 1
n_head: 1
causal_attn: false
logging:
backend: swanlab
mode: online
name: ${exp_name}
resume: false
tags: ["${name}", "${task_name}", "${exp_name}", "swanlab"]
id: null
group: ${exp_name}
dataloader:
num_workers: 0
val_dataloader:
num_workers: 0
task:
env_runner:
n_envs: 1
n_test_vis: 0
n_train_vis: 0

View File

@@ -0,0 +1,190 @@
_target_: diffusion_policy.workspace.train_diffusion_transformer_hybrid_workspace.TrainDiffusionTransformerHybridWorkspace
checkpoint:
save_last_ckpt: true
save_last_snapshot: false
topk:
format_str: epoch={epoch:04d}-train_loss={train_loss:.3f}.ckpt
k: 5
mode: min
monitor_key: train_loss
dataloader:
batch_size: 64
num_workers: 8
persistent_workers: false
pin_memory: true
shuffle: true
dataset_obs_steps: 2
ema:
_target_: diffusion_policy.model.diffusion.ema_model.EMAModel
inv_gamma: 1.0
max_value: 0.9999
min_value: 0.0
power: 0.75
update_after_step: 0
exp_name: default
horizon: 16
keypoint_visible_rate: 1.0
logging:
group: null
id: null
mode: online
name: ${now:%Y.%m.%d-%H.%M.%S}_train_diffusion_transformer_hybrid_pmf_pusht_image
project: diffusion_policy_debug
resume: true
tags:
- train_diffusion_transformer_hybrid_pmf
- pusht_image
- default
multi_run:
run_dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_train_diffusion_transformer_hybrid_pmf_pusht_image
wandb_name_base: ${now:%Y.%m.%d-%H.%M.%S}_train_diffusion_transformer_hybrid_pmf_pusht_image
n_action_steps: 8
n_latency_steps: 0
n_obs_steps: 2
name: train_diffusion_transformer_hybrid_pmf
obs_as_cond: true
optimizer:
betas:
- 0.9
- 0.95
learning_rate: 0.0001
obs_encoder_weight_decay: 1.0e-06
transformer_weight_decay: 0.001
past_action_visible: false
policy:
_target_: diffusion_policy.policy.pmf_transformer_hybrid_image_policy.PMFTransformerHybridImagePolicy
crop_shape:
- 84
- 84
eval_fixed_crop: true
horizon: 16
n_action_steps: 8
n_cond_layers: 0
n_emb: 256
n_head: 4
n_layer: 12
n_head_layers: 4
n_obs_steps: 2
n_time_tokens: 4
noise_scale: 1.0
adatloss_eps: 0.01
p_mean: -0.4
p_std: 1.0
noise_scheduler:
_target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
beta_end: 0.02
beta_schedule: squaredcos_cap_v2
beta_start: 0.0001
clip_sample: true
num_train_timesteps: 100
prediction_type: sample
variance_type: fixed_small
num_inference_steps: 1
obs_as_cond: true
obs_encoder_group_norm: true
p_drop_attn: 0.0
p_drop_emb: 0.0
pmf_u_loss_weight: 1.0
pmf_v_loss_weight: 1.0
tr_uniform: true
tr_uniform_prob: 0.1
data_proportion: 0.5
shape_meta:
action:
shape:
- 2
obs:
agent_pos:
shape:
- 2
type: low_dim
image:
shape:
- 3
- 96
- 96
type: rgb
shape_meta:
action:
shape:
- 2
obs:
agent_pos:
shape:
- 2
type: low_dim
image:
shape:
- 3
- 96
- 96
type: rgb
task:
dataset:
_target_: diffusion_policy.dataset.pusht_image_dataset.PushTImageDataset
horizon: 16
max_train_episodes: 90
pad_after: 7
pad_before: 1
seed: 42
val_ratio: 0.02
zarr_path: data/pusht/pusht_cchi_v7_replay.zarr
env_runner:
_target_: diffusion_policy.env_runner.pusht_image_runner.PushTImageRunner
fps: 10
legacy_test: true
max_steps: 300
n_action_steps: 8
n_envs: null
n_obs_steps: 2
n_test: 50
n_test_vis: 4
n_train: 6
n_train_vis: 2
past_action: false
test_start_seed: 100000
train_start_seed: 0
image_shape:
- 3
- 96
- 96
name: pusht_image
shape_meta:
action:
shape:
- 2
obs:
agent_pos:
shape:
- 2
type: low_dim
image:
shape:
- 3
- 96
- 96
type: rgb
task_name: pusht_image
training:
checkpoint_every: 50
debug: false
device: cuda:0
gradient_accumulate_every: 1
lr_scheduler: cosine
lr_warmup_steps: 500
max_train_steps: null
max_val_steps: null
num_epochs: 600
resume: true
rollout_every: 50
sample_every: 5
seed: 42
tqdm_interval_sec: 1.0
use_ema: true
val_every: 1
val_dataloader:
batch_size: 64
num_workers: 8
persistent_workers: false
pin_memory: true
shuffle: false

View File

@@ -22,6 +22,7 @@ pymunk==6.2.1
wandb==0.13.3 wandb==0.13.3
threadpoolctl==3.1.0 threadpoolctl==3.1.0
shapely==1.8.5.post1 shapely==1.8.5.post1
matplotlib==3.6.1
imageio==2.22.0 imageio==2.22.0
imageio-ffmpeg==0.4.7 imageio-ffmpeg==0.4.7
termcolor==2.0.1 termcolor==2.0.1
@@ -36,4 +37,3 @@ av==14.0.1
pygame==2.5.2 pygame==2.5.2
robomimic==0.2.0 robomimic==0.2.0
opencv-python-headless==4.10.0.84 opencv-python-headless==4.10.0.84
swanlab

View File

@@ -1,46 +0,0 @@
import inspect
import pathlib
import sys
import torch
# Two levels up from this file: the parent of the directory that contains it
# (presumably the repository root -- confirm against the repo layout).
ROOT_DIR = pathlib.Path(__file__).resolve().parents[1]
# Put that directory on the import path once, so the `diffusion_policy` import
# below resolves without the project being installed.
if str(ROOT_DIR) not in sys.path:
    sys.path.append(str(ROOT_DIR))
from diffusion_policy.model.diffusion.imf_transformer_for_diffusion import ( # noqa: E402
IMFTransformerForDiffusion,
)
def test_imf_transformer_forward_signature_and_shape_single_head():
    """Check the `forward(sample, r, t, cond=None)` signature and output shape.

    The model is built with a single attention head and a single layer, and
    the prediction must have exactly the shape of the input sample.
    """
    signature = inspect.signature(IMFTransformerForDiffusion.forward)
    # The first five parameters are fixed by the iMF call convention.
    assert list(signature.parameters)[:5] == ['self', 'sample', 'r', 't', 'cond']
    assert signature.parameters['cond'].default is None

    model = IMFTransformerForDiffusion(
        input_dim=3,
        output_dim=3,
        horizon=5,
        n_obs_steps=2,
        cond_dim=4,
        n_layer=1,
        n_head=1,
        n_emb=16,
        p_drop_emb=0.0,
        p_drop_attn=0.0,
        causal_attn=True,
        time_as_cond=True,
        obs_as_cond=True,
        n_cond_layers=0,
    )
    # Called only to confirm it runs without raising; the result is unused.
    model.configure_optimizers()

    sample = torch.randn(2, 5, 3)
    r = torch.rand(2)
    t = torch.rand(2)
    cond = torch.randn(2, 2, 4)

    pred_u = model(sample, r, t, cond=cond)
    assert pred_u.shape == sample.shape

View File

@@ -1,313 +0,0 @@
import pathlib
import sys
import pytest
import torch
import torch.nn as nn
# Two levels up from this file: the parent of the directory that contains it
# (presumably the repository root -- confirm against the repo layout).
ROOT_DIR = pathlib.Path(__file__).resolve().parents[1]
# Put that directory on the import path once, so the `diffusion_policy` imports
# below resolve without the project being installed.
if str(ROOT_DIR) not in sys.path:
    sys.path.append(str(ROOT_DIR))
import diffusion_policy.policy.imf_transformer_hybrid_image_policy as policy_module # noqa: E402
from diffusion_policy.model.common.module_attr_mixin import ModuleAttrMixin # noqa: E402
from diffusion_policy.policy.imf_transformer_hybrid_image_policy import ( # noqa: E402
IMFTransformerHybridImagePolicy,
)
class ConstantModel(nn.Module):
    """Stub network that outputs `value` everywhere.

    `r`, `t` and `cond` are accepted for signature compatibility and ignored.
    """

    def __init__(self, value):
        super().__init__()
        self.value = value

    def forward(self, sample, r, t, cond=None):
        # `full_like` keeps the shape, dtype and device of `sample`.
        return torch.full_like(sample, self.value)
class AffineModel(nn.Module):
    """Stub network computing `sample * weight + (r + t)` with a learnable scalar weight.

    `cond` is accepted for signature compatibility and ignored.
    """

    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.tensor(2.0))

    def forward(self, sample, r, t, cond=None):
        # `(r + t)` is reshaped to (-1, 1, 1) so it broadcasts over the two
        # trailing dimensions of a 3-D `sample`.
        return sample * self.weight + (r + t).view(-1, 1, 1)
class SumMixModel(nn.Module):
    """Stub network whose every output channel depends on all input channels.

    Each output element is `sum(sample over last dim) * weight + t`, so a change
    in any one channel shows up in all of them. `r` and `cond` are ignored.
    """

    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.tensor(2.0))

    def forward(self, sample, r, t, cond=None):
        # Sum across the last dimension and copy that sum back into every channel.
        mixed = sample.sum(dim=-1, keepdim=True).expand_as(sample)
        return mixed * self.weight + t.view(-1, 1, 1)
class TrackingContext:
    """Context manager that records whether it is active and how often it was entered."""

    def __init__(self):
        self.active = False
        self.enter_count = 0

    def __enter__(self):
        self.active = True
        self.enter_count += 1
        return self

    def __exit__(self, exc_type, exc, tb):
        self.active = False
        # Returning False lets any exception raised inside the block propagate.
        return False
def make_policy(model):
    """Return an `IMFTransformerHybridImagePolicy` wrapping `model`, skipping its real `__init__`."""
    # `__new__` allocates the instance without running the policy constructor;
    # only the `ModuleAttrMixin` initialiser is run before the model is attached.
    policy = IMFTransformerHybridImagePolicy.__new__(IMFTransformerHybridImagePolicy)
    ModuleAttrMixin.__init__(policy)
    policy.model = model
    return policy
def fake_parent_init(
    self,
    shape_meta,
    noise_scheduler,
    horizon,
    n_action_steps,
    n_obs_steps,
    num_inference_steps=None,
    crop_shape=(76, 76),
    obs_encoder_group_norm=False,
    eval_fixed_crop=False,
    n_layer=8,
    n_cond_layers=0,
    n_head=1,
    n_emb=256,
    p_drop_emb=0.0,
    p_drop_attn=0.3,
    causal_attn=True,
    time_as_cond=True,
    obs_as_cond=True,
    pred_action_steps_only=False,
    **kwargs,
):
    """Lightweight replacement for `DiffusionTransformerHybridImagePolicy.__init__`.

    The tests monkeypatch this over the parent constructor. It only records the
    attributes set below; all other arguments are accepted and ignored.
    """
    ModuleAttrMixin.__init__(self)
    self.action_dim = shape_meta['action']['shape'][0]
    # Fixed placeholder feature size; no observation encoder is built here.
    self.obs_feature_dim = 4
    self.obs_as_cond = obs_as_cond
    self.pred_action_steps_only = pred_action_steps_only
    self.n_action_steps = n_action_steps
    self.n_obs_steps = n_obs_steps
    self.horizon = horizon
@pytest.fixture
def shape_meta():
    """Minimal shape metadata: a 2-D action and no observation entries."""
    return {
        'action': {'shape': [2]},
        'obs': {},
    }
def test_sample_one_step_uses_imf_update_formula():
    """`_sample_one_step` must return `z_1 - (t - r) * u` for the model output `u`."""
    policy = make_policy(ConstantModel(0.25))
    z_1 = torch.tensor([
        [[1.0, -1.0], [0.5, 0.0]],
        [[2.0, 3.0], [-2.0, 4.0]],
    ])
    r = torch.zeros(z_1.shape[0])
    t = torch.ones(z_1.shape[0])

    x_hat = policy._sample_one_step(z_1, r=r, t=t, cond=None)

    # The stub model returns u == 0.25 at every position.
    expected = z_1 - (t - r).view(-1, 1, 1) * 0.25
    assert torch.allclose(x_hat, expected)
def test_compound_velocity_uses_detached_du_dt_term():
    """`_compound_velocity` must equal `u + (t - r) * du_dt` with no gradient through `du_dt`."""
    policy = make_policy(ConstantModel(0.0))
    u = torch.tensor([[[1.0], [2.0]]], requires_grad=True)
    du_dt = torch.tensor([[[3.0], [4.0]]], requires_grad=True)
    r = torch.tensor([0.2])
    t = torch.tensor([0.8])

    compound = policy._compound_velocity(u, du_dt, r, t)

    expected = u + (t - r).view(-1, 1, 1) * du_dt.detach()
    assert torch.allclose(compound, expected)

    # Backprop must reach `u` but not `du_dt`.
    compound.sum().backward()
    assert u.grad is not None
    assert du_dt.grad is None
def test_compute_u_and_du_dt_uses_math_sdpa_context_for_torch_func_jvp(monkeypatch):
    """The `TORCH_FUNC_JVP` path must run inside the policy's `_jvp_math_sdp_context`."""
    tracker = TrackingContext()

    def fake_jvp(fn, primals, tangents):
        # The JVP call has to happen while the context is entered.
        assert tracker.active is True
        return fn(*primals), torch.zeros_like(primals[0])

    monkeypatch.setattr(policy_module, 'TORCH_FUNC_JVP', fake_jvp)
    policy = make_policy(ConstantModel(0.5))
    policy._jvp_math_sdp_context = lambda tensor: tracker

    z_t = torch.randn(2, 3, 4)
    r = torch.rand(2, requires_grad=True)
    t = torch.rand(2, requires_grad=True)
    v = torch.randn_like(z_t, requires_grad=True)

    policy._compute_u_and_du_dt(z_t, r, t, cond=None, v=v)

    assert tracker.enter_count == 1
def test_compute_u_and_du_dt_uses_math_sdpa_context_for_autograd_fallback(monkeypatch):
    """With `TORCH_FUNC_JVP` unset, the autograd JVP fallback must also run inside the context."""
    tracker = TrackingContext()

    def fake_autograd_jvp(fn, primals, tangents, create_graph=False, strict=False):
        # The fallback JVP call has to happen while the context is entered.
        assert tracker.active is True
        return fn(*primals), torch.zeros_like(primals[0])

    # Setting `TORCH_FUNC_JVP` to None selects the fallback path.
    monkeypatch.setattr(policy_module, 'TORCH_FUNC_JVP', None)
    monkeypatch.setattr(policy_module.torch.autograd.functional, 'jvp', fake_autograd_jvp)
    policy = make_policy(ConstantModel(0.5))
    policy._jvp_math_sdp_context = lambda tensor: tracker

    z_t = torch.randn(2, 3, 4)
    r = torch.rand(2, requires_grad=True)
    t = torch.rand(2, requires_grad=True)
    v = torch.randn_like(z_t, requires_grad=True)

    policy._compute_u_and_du_dt(z_t, r, t, cond=None, v=v)

    assert tracker.enter_count == 1
def test_compute_u_and_du_dt_uses_detached_v_zero_r_unit_t_and_reapplies_conditioning(monkeypatch):
    """Check the JVP tangents `(v.detach(), 0, 1)` and that conditioning is applied to the primal."""
    captured = {}

    def fake_jvp(fn, primals, tangents):
        # Record what the policy hands to the JVP so it can be inspected below.
        captured['tangents'] = tangents
        captured['primal_output'] = fn(*primals)
        return captured['primal_output'], torch.zeros_like(primals[0])

    monkeypatch.setattr(policy_module, 'TORCH_FUNC_JVP', fake_jvp)
    policy = make_policy(SumMixModel())

    z_t = torch.tensor([[[1.0, 2.0, 3.0]]])
    r = torch.rand(1, requires_grad=True)
    t = torch.rand(1, requires_grad=True)
    v = torch.tensor([[[10.0, 20.0, 30.0]]], requires_grad=True)
    condition_mask = torch.tensor([[[False, True, False]]])
    condition_data = torch.tensor([[[0.0, 7.0, 0.0]]])

    policy._compute_u_and_du_dt(
        z_t,
        r,
        t,
        cond=None,
        v=v,
        condition_data=condition_data,
        condition_mask=condition_mask,
    )

    tangent_v, tangent_r, tangent_t = captured['tangents']
    assert torch.equal(tangent_v, v.detach())
    assert tangent_v.requires_grad is False
    assert torch.equal(tangent_r, torch.zeros_like(r))
    assert torch.equal(tangent_t, torch.ones_like(t))

    # The primal must be evaluated on the input with masked entries replaced
    # by the conditioning data.
    conditioned = z_t.clone()
    conditioned[condition_mask] = condition_data[condition_mask]
    expected_primal = policy.model(conditioned, r, t, cond=None)
    assert torch.allclose(captured['primal_output'], expected_primal)
def test_compute_u_and_du_dt_fallback_blocks_conditioned_tangent_leakage_and_keeps_primal_gradients(monkeypatch):
    """On the fallback path, masked tangent entries must not leak and `u` must stay differentiable."""
    monkeypatch.setattr(policy_module, 'TORCH_FUNC_JVP', None)
    policy = make_policy(SumMixModel())

    z_t = torch.tensor([[[1.0, 2.0, 3.0]]], requires_grad=True)
    r = torch.rand(1, requires_grad=True)
    t = torch.rand(1, requires_grad=True)
    v = torch.tensor([[[1.0, 10.0, 100.0]]], requires_grad=True)
    condition_mask = torch.tensor([[[False, True, False]]])
    condition_data = torch.tensor([[[0.0, 7.0, 0.0]]])

    u, du_dt = policy._compute_u_and_du_dt(
        z_t,
        r,
        t,
        cond=None,
        v=v,
        condition_data=condition_data,
        condition_mask=condition_mask,
    )

    conditioned = z_t.detach().clone()
    conditioned[condition_mask] = condition_data[condition_mask]
    expected_u = policy.model(conditioned, r, t, cond=None)
    # Only the unmasked tangent entries (1 + 100 = 101) may contribute; the
    # masked 10.0 must be dropped. The trailing + 1.0 is the unit tangent on t.
    expected_du_dt_scalar = policy.model.weight.detach() * torch.tensor(101.0) + 1.0
    expected_du_dt = torch.full_like(z_t, expected_du_dt_scalar)

    assert u.shape == z_t.shape
    assert du_dt.shape == z_t.shape
    assert torch.allclose(u, expected_u)
    assert torch.allclose(du_dt, expected_du_dt)

    # The primal output must still carry gradients to the model parameters.
    u.sum().backward()
    assert policy.model.weight.grad is not None
    assert torch.count_nonzero(policy.model.weight.grad) > 0
def test_init_uses_action_step_horizon_when_pred_action_steps_only(monkeypatch, shape_meta):
    """With `pred_action_steps_only=True` the model horizon must equal `n_action_steps`."""
    # Replace the parent constructor with the lightweight stand-in.
    monkeypatch.setattr(
        policy_module.DiffusionTransformerHybridImagePolicy,
        '__init__',
        fake_parent_init,
    )

    policy = IMFTransformerHybridImagePolicy(
        shape_meta=shape_meta,
        noise_scheduler=None,
        horizon=10,
        n_action_steps=4,
        n_obs_steps=2,
        num_inference_steps=1,
        n_layer=1,
        n_head=1,
        n_emb=16,
        p_drop_emb=0.0,
        p_drop_attn=0.0,
        causal_attn=True,
        obs_as_cond=True,
        pred_action_steps_only=True,
    )

    assert policy.model.horizon == 4
    assert policy.num_inference_steps == 1
def test_init_rejects_non_one_step_inference(monkeypatch, shape_meta):
    """Constructing the policy with `num_inference_steps=2` must raise `ValueError`."""
    # Replace the parent constructor with the lightweight stand-in.
    monkeypatch.setattr(
        policy_module.DiffusionTransformerHybridImagePolicy,
        '__init__',
        fake_parent_init,
    )

    with pytest.raises(ValueError, match='num_inference_steps'):
        IMFTransformerHybridImagePolicy(
            shape_meta=shape_meta,
            noise_scheduler=None,
            horizon=10,
            n_action_steps=4,
            n_obs_steps=2,
            num_inference_steps=2,
            n_layer=1,
            n_head=1,
            n_emb=16,
            p_drop_emb=0.0,
            p_drop_attn=0.0,
            causal_attn=True,
            obs_as_cond=True,
            pred_action_steps_only=False,
        )

View File

@@ -1,110 +0,0 @@
import pathlib
import sys
import gym
from gym import spaces
import numpy as np
import pytest
import torch
# Two levels up from this file: the parent of the directory that contains it
# (presumably the repository root -- confirm against the repo layout).
ROOT_DIR = pathlib.Path(__file__).resolve().parents[1]
# Put that directory on the import path once, so the `diffusion_policy` imports
# below resolve without the project being installed.
if str(ROOT_DIR) not in sys.path:
    sys.path.append(str(ROOT_DIR))
import diffusion_policy.env_runner.pusht_image_runner as runner_module
from diffusion_policy.env_runner.pusht_image_runner import summarize_rollout_metrics
class FakePushTImageEnv(gym.Env):
    """Tiny stand-in for `PushTImageEnv` that never renders.

    Every episode ends after one step. The reward depends only on the seed:
    0.1 for seeds below 10000 and 0.9 otherwise.
    """

    metadata = {'render.modes': ['rgb_array']}

    def __init__(self, legacy=False, render_size=96):
        # Accepted for constructor compatibility only.
        del legacy, render_size
        self.observation_space = spaces.Dict({
            'image': spaces.Box(low=0, high=255, shape=(3, 4, 4), dtype=np.uint8),
        })
        self.action_space = spaces.Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32)
        self.seed_value = 0
        self.step_count = 0

    def seed(self, seed=None):
        # A missing seed is treated as 0.
        self.seed_value = 0 if seed is None else seed

    def reset(self):
        self.step_count = 0
        return {'image': np.zeros((3, 4, 4), dtype=np.uint8)}

    def step(self, action):
        del action
        self.step_count += 1
        reward = 0.1 if self.seed_value < 10000 else 0.9
        # True from the first step onwards, so each episode lasts one step.
        done = self.step_count >= 1
        obs = {'image': np.full((3, 4, 4), self.step_count, dtype=np.uint8)}
        return obs, reward, done, {}

    def render(self, *args, **kwargs):
        # Any render call is a test failure: rollouts are expected to be scalar-only.
        raise AssertionError('render should not be called for scalar-only PushT image rollouts')
class FakePolicy:
    """Minimal policy stub that always predicts all-zero actions on the CPU."""

    device = torch.device('cpu')
    dtype = torch.float32

    def reset(self):
        return None

    def predict_action(self, obs_dict):
        # The batch size is taken from the leading dimension of the first observation.
        n_envs = next(iter(obs_dict.values())).shape[0]
        return {
            'action': torch.zeros((n_envs, 2, 2), dtype=torch.float32),
        }
def test_summarize_rollout_metrics_keeps_scalar_rewards_renames_means_and_omits_videos():
    """The summary must hold per-seed max rewards and `*_mean_score` keys, and no video keys."""
    log_data = summarize_rollout_metrics(
        env_seeds=[11, 12, 101],
        env_prefixs=['train/', 'train/', 'test/'],
        all_rewards=[
            [0.2, 0.8],
            [0.1, 0.4],
            [0.5, 0.9],
        ],
        all_video_paths=[
            '/tmp/train-11.mp4',
            '/tmp/train-12.mp4',
            '/tmp/test-101.mp4',
        ],
    )

    assert log_data['train/sim_max_reward_11'] == 0.8
    assert log_data['train/sim_max_reward_12'] == 0.4
    assert log_data['test/sim_max_reward_101'] == 0.9
    # Mean of the per-seed maxima: (0.8 + 0.4) / 2 for train, 0.9 for test.
    assert log_data['train_mean_score'] == pytest.approx(0.6)
    assert log_data['test_mean_score'] == pytest.approx(0.9)
    # Video paths were supplied but must not appear in the log.
    assert not any(key.startswith('train/sim_video_') for key in log_data)
    assert not any(key.startswith('test/sim_video_') for key in log_data)
def test_runner_ignores_vis_flags_and_never_emits_sim_videos(tmp_path, monkeypatch):
    """Even with `n_train_vis`/`n_test_vis` set to 1, the runner must log scalars only."""
    # The fake env raises if `render` is called, so any video attempt fails the test.
    monkeypatch.setattr(runner_module, 'PushTImageEnv', FakePushTImageEnv)

    runner = runner_module.PushTImageRunner(
        output_dir=tmp_path,
        n_train=1,
        n_train_vis=1,
        n_test=1,
        n_test_vis=1,
        n_envs=2,
        max_steps=2,
        n_obs_steps=2,
        n_action_steps=2,
        tqdm_interval_sec=0.0,
    )
    log_data = runner.run(FakePolicy())

    assert log_data['train/sim_max_reward_0'] == pytest.approx(0.1)
    assert log_data['test/sim_max_reward_10000'] == pytest.approx(0.9)
    assert log_data['train_mean_score'] == pytest.approx(0.1)
    assert log_data['test_mean_score'] == pytest.approx(0.9)
    assert not any('sim_video' in key for key in log_data)

View File

@@ -1,44 +0,0 @@
import pathlib
from omegaconf import OmegaConf
ROOT_DIR = pathlib.Path(__file__).resolve().parents[1]
def _load_cfg(name: str):
    """Load the YAML config file `name`, resolved relative to `ROOT_DIR`, with OmegaConf."""
    return OmegaConf.load(ROOT_DIR / name)
def test_image_pusht_dit_swanlab_config_uses_exp_name_and_no_resume_collision():
    """The DiT config must log to SwanLab online, named and grouped by `exp_name`, without resume."""
    cfg = _load_cfg('image_pusht_diffusion_policy_dit.yaml')

    assert cfg.logging.backend == 'swanlab'
    assert cfg.logging.mode == 'online'
    assert cfg.logging.name == cfg.exp_name
    assert cfg.logging.resume is False
    assert cfg.logging.id is None
    assert cfg.logging.group == cfg.exp_name
def test_image_pusht_dit_imf_swanlab_config_uses_exp_name_and_no_resume_collision():
    """The iMF config must log to SwanLab online, named and grouped by `exp_name`, without resume."""
    cfg = _load_cfg('image_pusht_diffusion_policy_dit_imf.yaml')

    assert cfg.logging.backend == 'swanlab'
    assert cfg.logging.mode == 'online'
    assert cfg.logging.name == cfg.exp_name
    assert cfg.logging.resume is False
    assert cfg.logging.id is None
    assert cfg.logging.group == cfg.exp_name
def test_image_pusht_dit_imf_fullattn_config_uses_exp_name_and_disables_causal_attention():
    """The full-attention iMF config must keep the SwanLab naming rules and set `causal_attn` to False."""
    cfg = _load_cfg('image_pusht_diffusion_policy_dit_imf_fullattn.yaml')

    assert cfg.logging.backend == 'swanlab'
    assert cfg.logging.mode == 'online'
    assert cfg.logging.name == cfg.exp_name
    assert cfg.logging.resume is False
    assert cfg.logging.id is None
    assert cfg.logging.group == cfg.exp_name
    # The one setting that distinguishes this config in these tests.
    assert cfg.policy.causal_attn is False

View File

@@ -1,198 +0,0 @@
import importlib
import pathlib
import sys
import pytest
from omegaconf import OmegaConf
# Two levels up from this file: the parent of the directory that contains it
# (presumably the repository root -- confirm against the repo layout).
ROOT_DIR = pathlib.Path(__file__).resolve().parents[1]
# Put that directory on the import path once, so the module named by
# MODULE_NAME can be imported without the project being installed.
if str(ROOT_DIR) not in sys.path:
    sys.path.append(str(ROOT_DIR))
MODULE_NAME = 'diffusion_policy.workspace.train_diffusion_transformer_hybrid_workspace'
def load_workspace_module(monkeypatch, *, wandb_missing=False):
    """Freshly import the workspace module named by `MODULE_NAME` and return it.

    Args:
        monkeypatch: pytest's monkeypatch fixture, used to patch `sys.modules`.
        wandb_missing: when true, make `import wandb` fail during the import.
    """
    # Drop any cached copy so the module body is executed again.
    sys.modules.pop(MODULE_NAME, None)
    if wandb_missing:
        # A None entry in sys.modules makes `import wandb` raise ImportError.
        monkeypatch.setitem(sys.modules, 'wandb', None)
    return importlib.import_module(MODULE_NAME)
def test_init_logger_uses_swanlab_backend_mapping_without_loading_wandb(tmp_path, monkeypatch):
    """The SwanLab backend must map config keys to `swanlab.init` kwargs and never load wandb."""
    workspace_module = load_workspace_module(monkeypatch, wandb_missing=True)
    events = []

    class FakeRun:
        def log(self, payload, step=None):
            events.append(('log', payload, step))

        def finish(self):
            events.append(('finish',))

    class FakeSwanLab:
        def init(self, **kwargs):
            events.append(('init', kwargs))
            return FakeRun()

    monkeypatch.setattr(workspace_module, '_load_swanlab', lambda: FakeSwanLab())
    # Loading wandb at all is a failure on this path.
    monkeypatch.setattr(
        workspace_module,
        '_load_wandb',
        lambda: pytest.fail('wandb should not be loaded for the SwanLab backend'),
    )
    cfg = OmegaConf.create({
        'logging': {
            'backend': 'swanlab',
            'project': 'demo-project',
            'name': 'demo-run',
            'group': 'demo-group',
            'tags': ['pusht', 'dit'],
            'id': 'run-123',
            'resume': True,
            'mode': 'online',
        }
    })

    logger = workspace_module.init_logging_backend(cfg=cfg, output_dir=tmp_path)
    logger.log({'metric': 1.0}, step=7)
    logger.finish()

    assert events[0][0] == 'init'
    init_kwargs = events[0][1]
    assert init_kwargs['project'] == 'demo-project'
    # Config `name` is passed on as `experiment_name`.
    assert init_kwargs['experiment_name'] == 'demo-run'
    assert init_kwargs['group'] == 'demo-group'
    assert init_kwargs['tags'] == ['pusht', 'dit']
    assert init_kwargs['id'] == 'run-123'
    assert init_kwargs['resume'] is True
    # Config mode 'online' is translated to 'cloud' for SwanLab.
    assert init_kwargs['mode'] == 'cloud'
    assert init_kwargs['logdir'] == str(tmp_path / 'swanlog')
    assert ('log', {'metric': 1.0}, 7) in events
    assert events.count(('finish',)) == 1
def test_init_logger_defaults_to_legacy_wandb_path_when_backend_missing(tmp_path, monkeypatch):
    """Without a `logging.backend` key, logger initialisation must go through wandb."""
    workspace_module = load_workspace_module(monkeypatch)
    events = []

    class FakeRun:
        def log(self, payload, step=None):
            events.append(('log', payload, step))

        def finish(self):
            events.append(('finish',))

    class FakeConfig:
        def update(self, payload):
            events.append(('config.update', payload))

    class FakeWandb:
        def __init__(self):
            self.config = FakeConfig()

        def init(self, **kwargs):
            events.append(('init', kwargs))
            return FakeRun()

    monkeypatch.setattr(workspace_module, '_load_wandb', lambda: FakeWandb())
    # No 'backend' key here: the default path is the one under test.
    cfg = OmegaConf.create({
        'logging': {
            'project': 'demo-project',
            'name': 'demo-run',
            'group': None,
            'tags': ['shared'],
            'id': None,
            'resume': True,
            'mode': 'online',
        }
    })

    logger = workspace_module.init_logging_backend(cfg=cfg, output_dir=tmp_path)
    logger.log({'metric': 2.0}, step=3)
    logger.finish()

    assert events[0][0] == 'init'
    init_kwargs = events[0][1]
    assert init_kwargs['dir'] == str(tmp_path)
    assert init_kwargs['project'] == 'demo-project'
    assert init_kwargs['name'] == 'demo-run'
    assert init_kwargs['mode'] == 'online'
    assert ('config.update', {'output_dir': str(tmp_path)}) in events
    assert ('log', {'metric': 2.0}, 3) in events
    assert events.count(('finish',)) == 1
def test_init_logger_rejects_unknown_backends(tmp_path, monkeypatch):
    """Expect a ``ValueError`` when ``logging.backend`` is ``'tensorboard'``.

    The error message must match ``'Unknown logging backend'``.
    """
    workspace_module = load_workspace_module(monkeypatch)

    logging_section = {
        'backend': 'tensorboard',
        'project': 'demo-project',
        'name': 'demo-run',
        'mode': 'offline',
    }
    cfg = OmegaConf.create({'logging': logging_section})

    with pytest.raises(ValueError, match='Unknown logging backend'):
        workspace_module.init_logging_backend(cfg=cfg, output_dir=tmp_path)
def test_logging_backend_session_preserves_primary_exception_when_finish_fails(tmp_path, monkeypatch):
    """Check that a failing ``finish`` does not mask the body's exception.

    The fake backend's ``finish`` raises ``RuntimeError('finish boom')``.
    The body of the session raises ``ValueError('primary boom')``, and the
    test expects that ``ValueError`` to be the one that propagates, with
    ``finish`` having been called exactly once.
    """
    workspace_module = load_workspace_module(monkeypatch)
    recorded = []

    class FakeBackend:
        def log(self, payload, step=None):
            recorded.append(('log', payload, step))

        def finish(self):
            # Record the call first so the count below still sees it.
            recorded.append(('finish',))
            raise RuntimeError('finish boom')

    def fake_init_logging_backend(cfg, output_dir):
        return FakeBackend()

    monkeypatch.setattr(
        workspace_module,
        'init_logging_backend',
        fake_init_logging_backend,
    )

    cfg = OmegaConf.create({'logging': {'mode': 'offline'}})

    with pytest.raises(ValueError, match='primary boom'):
        with workspace_module.logging_backend_session(cfg=cfg, output_dir=tmp_path) as logger:
            logger.log({'metric': 6.0}, step=12)
            raise ValueError('primary boom')

    assert ('log', {'metric': 6.0}, 12) in recorded
    assert recorded.count(('finish',)) == 1
def test_logging_backend_session_finishes_on_exception(tmp_path, monkeypatch):
    """Check that the session calls ``finish`` when its body raises.

    The body of the session raises ``RuntimeError('boom')``. The test expects
    that error to propagate and the fake backend's ``finish`` to have been
    called exactly once.
    """
    workspace_module = load_workspace_module(monkeypatch)
    recorded = []

    class FakeBackend:
        def log(self, payload, step=None):
            recorded.append(('log', payload, step))

        def finish(self):
            recorded.append(('finish',))

    def fake_init_logging_backend(cfg, output_dir):
        return FakeBackend()

    monkeypatch.setattr(
        workspace_module,
        'init_logging_backend',
        fake_init_logging_backend,
    )

    cfg = OmegaConf.create({'logging': {'mode': 'offline'}})

    with pytest.raises(RuntimeError, match='boom'):
        with workspace_module.logging_backend_session(cfg=cfg, output_dir=tmp_path) as logger:
            logger.log({'metric': 5.0}, step=11)
            raise RuntimeError('boom')

    assert ('log', {'metric': 5.0}, 11) in recorded
    assert recorded.count(('finish',)) == 1