Compare commits
6 Commits
08c1950c6d
...
feat/pusht
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
36fbf2a6b7 | ||
|
|
4cd5085b33 | ||
|
|
5e7ae6cfa5 | ||
|
|
23374a4cd2 | ||
|
|
15a0c41cbf | ||
|
|
ba6ede9425 |
@@ -1,21 +1,37 @@
|
||||
import wandb
|
||||
import numpy as np
|
||||
import torch
|
||||
import collections
|
||||
import pathlib
|
||||
import tqdm
|
||||
import dill
|
||||
import math
|
||||
import wandb.sdk.data_types.video as wv
|
||||
from diffusion_policy.env.pusht.pusht_image_env import PushTImageEnv
|
||||
from diffusion_policy.gym_util.sync_vector_env import SyncVectorEnv
|
||||
from diffusion_policy.gym_util.multistep_wrapper import MultiStepWrapper
|
||||
from diffusion_policy.gym_util.video_recording_wrapper import VideoRecordingWrapper, VideoRecorder
|
||||
|
||||
from diffusion_policy.policy.base_image_policy import BaseImagePolicy
|
||||
from diffusion_policy.common.pytorch_util import dict_apply
|
||||
from diffusion_policy.env_runner.base_image_runner import BaseImageRunner
|
||||
|
||||
|
||||
def summarize_rollout_metrics(env_seeds, env_prefixs, all_rewards, all_video_paths=None):
|
||||
del all_video_paths
|
||||
|
||||
max_rewards = collections.defaultdict(list)
|
||||
log_data = dict()
|
||||
for seed, prefix, rewards in zip(env_seeds, env_prefixs, all_rewards):
|
||||
max_reward = np.max(rewards)
|
||||
max_rewards[prefix].append(max_reward)
|
||||
log_data[prefix + f'sim_max_reward_{seed}'] = max_reward
|
||||
|
||||
aggregate_key_map = {
|
||||
'train/': 'train_mean_score',
|
||||
'test/': 'test_mean_score',
|
||||
}
|
||||
for prefix, value in max_rewards.items():
|
||||
log_data[aggregate_key_map.get(prefix, prefix + 'mean_score')] = np.mean(value)
|
||||
|
||||
return log_data
|
||||
|
||||
class PushTImageRunner(BaseImageRunner):
|
||||
def __init__(self,
|
||||
output_dir,
|
||||
@@ -40,24 +56,11 @@ class PushTImageRunner(BaseImageRunner):
|
||||
if n_envs is None:
|
||||
n_envs = n_train + n_test
|
||||
|
||||
steps_per_render = max(10 // fps, 1)
|
||||
def env_fn():
|
||||
return MultiStepWrapper(
|
||||
VideoRecordingWrapper(
|
||||
PushTImageEnv(
|
||||
legacy=legacy_test,
|
||||
render_size=render_size
|
||||
),
|
||||
video_recoder=VideoRecorder.create_h264(
|
||||
fps=fps,
|
||||
codec='h264',
|
||||
input_pix_fmt='rgb24',
|
||||
crf=crf,
|
||||
thread_type='FRAME',
|
||||
thread_count=1
|
||||
),
|
||||
file_path=None,
|
||||
steps_per_render=steps_per_render
|
||||
PushTImageEnv(
|
||||
legacy=legacy_test,
|
||||
render_size=render_size
|
||||
),
|
||||
n_obs_steps=n_obs_steps,
|
||||
n_action_steps=n_action_steps,
|
||||
@@ -71,21 +74,8 @@ class PushTImageRunner(BaseImageRunner):
|
||||
# train
|
||||
for i in range(n_train):
|
||||
seed = train_start_seed + i
|
||||
enable_render = i < n_train_vis
|
||||
|
||||
def init_fn(env, seed=seed, enable_render=enable_render):
|
||||
# setup rendering
|
||||
# video_wrapper
|
||||
assert isinstance(env.env, VideoRecordingWrapper)
|
||||
env.env.video_recoder.stop()
|
||||
env.env.file_path = None
|
||||
if enable_render:
|
||||
filename = pathlib.Path(output_dir).joinpath(
|
||||
'media', wv.util.generate_id() + ".mp4")
|
||||
filename.parent.mkdir(parents=False, exist_ok=True)
|
||||
filename = str(filename)
|
||||
env.env.file_path = filename
|
||||
|
||||
def init_fn(env, seed=seed):
|
||||
# set seed
|
||||
assert isinstance(env, MultiStepWrapper)
|
||||
env.seed(seed)
|
||||
@@ -97,21 +87,8 @@ class PushTImageRunner(BaseImageRunner):
|
||||
# test
|
||||
for i in range(n_test):
|
||||
seed = test_start_seed + i
|
||||
enable_render = i < n_test_vis
|
||||
|
||||
def init_fn(env, seed=seed, enable_render=enable_render):
|
||||
# setup rendering
|
||||
# video_wrapper
|
||||
assert isinstance(env.env, VideoRecordingWrapper)
|
||||
env.env.video_recoder.stop()
|
||||
env.env.file_path = None
|
||||
if enable_render:
|
||||
filename = pathlib.Path(output_dir).joinpath(
|
||||
'media', wv.util.generate_id() + ".mp4")
|
||||
filename.parent.mkdir(parents=False, exist_ok=True)
|
||||
filename = str(filename)
|
||||
env.env.file_path = filename
|
||||
|
||||
def init_fn(env, seed=seed):
|
||||
# set seed
|
||||
assert isinstance(env, MultiStepWrapper)
|
||||
env.seed(seed)
|
||||
@@ -154,7 +131,6 @@ class PushTImageRunner(BaseImageRunner):
|
||||
n_chunks = math.ceil(n_inits / n_envs)
|
||||
|
||||
# allocate data
|
||||
all_video_paths = [None] * n_inits
|
||||
all_rewards = [None] * n_inits
|
||||
|
||||
for chunk_idx in range(n_chunks):
|
||||
@@ -214,39 +190,16 @@ class PushTImageRunner(BaseImageRunner):
|
||||
pbar.update(action.shape[1])
|
||||
pbar.close()
|
||||
|
||||
all_video_paths[this_global_slice] = env.render()[this_local_slice]
|
||||
all_rewards[this_global_slice] = env.call('get_attr', 'reward')[this_local_slice]
|
||||
# clear out video buffer
|
||||
# reset env state between evaluation calls
|
||||
_ = env.reset()
|
||||
|
||||
# log
|
||||
max_rewards = collections.defaultdict(list)
|
||||
log_data = dict()
|
||||
# results reported in the paper are generated using the commented out line below
|
||||
# which will only report and average metrics from first n_envs initial condition and seeds
|
||||
# fortunately this won't invalidate our conclusion since
|
||||
# 1. This bug only affects the variance of metrics, not their mean
|
||||
# 2. All baseline methods are evaluated using the same code
|
||||
# to completely reproduce reported numbers, uncomment this line:
|
||||
# for i in range(len(self.env_fns)):
|
||||
# and comment out this line
|
||||
for i in range(n_inits):
|
||||
seed = self.env_seeds[i]
|
||||
prefix = self.env_prefixs[i]
|
||||
max_reward = np.max(all_rewards[i])
|
||||
max_rewards[prefix].append(max_reward)
|
||||
log_data[prefix+f'sim_max_reward_{seed}'] = max_reward
|
||||
|
||||
# visualize sim
|
||||
video_path = all_video_paths[i]
|
||||
if video_path is not None:
|
||||
sim_video = wandb.Video(video_path)
|
||||
log_data[prefix+f'sim_video_{seed}'] = sim_video
|
||||
|
||||
# log aggregate metrics
|
||||
for prefix, value in max_rewards.items():
|
||||
name = prefix+'mean_score'
|
||||
value = np.mean(value)
|
||||
log_data[name] = value
|
||||
|
||||
return log_data
|
||||
# results reported in the paper are generated using the commented out
|
||||
# line below, which would only report and average metrics from the
|
||||
# first n_envs initial conditions and seeds. We keep the full n_inits
|
||||
# behavior here.
|
||||
return summarize_rollout_metrics(
|
||||
env_seeds=self.env_seeds[:n_inits],
|
||||
env_prefixs=self.env_prefixs[:n_inits],
|
||||
all_rewards=all_rewards[:n_inits],
|
||||
)
|
||||
|
||||
@@ -0,0 +1,298 @@
|
||||
from typing import Optional, Tuple, Union
|
||||
import logging
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from diffusion_policy.model.common.module_attr_mixin import ModuleAttrMixin
|
||||
from diffusion_policy.model.diffusion.positional_embedding import SinusoidalPosEmb
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class IMFTransformerForDiffusion(ModuleAttrMixin):
|
||||
def __init__(
|
||||
self,
|
||||
input_dim: int,
|
||||
output_dim: int,
|
||||
horizon: int,
|
||||
n_obs_steps: int = None,
|
||||
cond_dim: int = 0,
|
||||
n_layer: int = 12,
|
||||
n_head: int = 1,
|
||||
n_emb: int = 768,
|
||||
p_drop_emb: float = 0.1,
|
||||
p_drop_attn: float = 0.1,
|
||||
causal_attn: bool = False,
|
||||
time_as_cond: bool = True,
|
||||
obs_as_cond: bool = False,
|
||||
n_cond_layers: int = 0,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
assert n_head == 1, 'IMFTransformerForDiffusion currently supports single-head attention only.'
|
||||
|
||||
if n_obs_steps is None:
|
||||
n_obs_steps = horizon
|
||||
|
||||
T = horizon
|
||||
T_cond = 2
|
||||
if not time_as_cond:
|
||||
T += 2
|
||||
T_cond -= 2
|
||||
obs_as_cond = cond_dim > 0
|
||||
if obs_as_cond:
|
||||
assert time_as_cond
|
||||
T_cond += n_obs_steps
|
||||
|
||||
self.input_emb = nn.Linear(input_dim, n_emb)
|
||||
self.pos_emb = nn.Parameter(torch.zeros(1, T, n_emb))
|
||||
self.drop = nn.Dropout(p_drop_emb)
|
||||
|
||||
self.time_emb = SinusoidalPosEmb(n_emb)
|
||||
self.cond_obs_emb = None
|
||||
if obs_as_cond:
|
||||
self.cond_obs_emb = nn.Linear(cond_dim, n_emb)
|
||||
|
||||
self.cond_pos_emb = None
|
||||
self.encoder = None
|
||||
self.decoder = None
|
||||
encoder_only = False
|
||||
if T_cond > 0:
|
||||
self.cond_pos_emb = nn.Parameter(torch.zeros(1, T_cond, n_emb))
|
||||
if n_cond_layers > 0:
|
||||
encoder_layer = nn.TransformerEncoderLayer(
|
||||
d_model=n_emb,
|
||||
nhead=n_head,
|
||||
dim_feedforward=4 * n_emb,
|
||||
dropout=p_drop_attn,
|
||||
activation='gelu',
|
||||
batch_first=True,
|
||||
norm_first=True,
|
||||
)
|
||||
self.encoder = nn.TransformerEncoder(
|
||||
encoder_layer=encoder_layer,
|
||||
num_layers=n_cond_layers,
|
||||
)
|
||||
else:
|
||||
self.encoder = nn.Sequential(
|
||||
nn.Linear(n_emb, 4 * n_emb),
|
||||
nn.Mish(),
|
||||
nn.Linear(4 * n_emb, n_emb),
|
||||
)
|
||||
|
||||
decoder_layer = nn.TransformerDecoderLayer(
|
||||
d_model=n_emb,
|
||||
nhead=n_head,
|
||||
dim_feedforward=4 * n_emb,
|
||||
dropout=p_drop_attn,
|
||||
activation='gelu',
|
||||
batch_first=True,
|
||||
norm_first=True,
|
||||
)
|
||||
self.decoder = nn.TransformerDecoder(
|
||||
decoder_layer=decoder_layer,
|
||||
num_layers=n_layer,
|
||||
)
|
||||
else:
|
||||
encoder_only = True
|
||||
encoder_layer = nn.TransformerEncoderLayer(
|
||||
d_model=n_emb,
|
||||
nhead=n_head,
|
||||
dim_feedforward=4 * n_emb,
|
||||
dropout=p_drop_attn,
|
||||
activation='gelu',
|
||||
batch_first=True,
|
||||
norm_first=True,
|
||||
)
|
||||
self.encoder = nn.TransformerEncoder(
|
||||
encoder_layer=encoder_layer,
|
||||
num_layers=n_layer,
|
||||
)
|
||||
|
||||
if causal_attn:
|
||||
sz = T
|
||||
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
|
||||
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
|
||||
self.register_buffer('mask', mask)
|
||||
|
||||
if time_as_cond and obs_as_cond:
|
||||
S = T_cond
|
||||
t_idx, s_idx = torch.meshgrid(
|
||||
torch.arange(T),
|
||||
torch.arange(S),
|
||||
indexing='ij',
|
||||
)
|
||||
mask = t_idx >= (s_idx - 2)
|
||||
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
|
||||
self.register_buffer('memory_mask', mask)
|
||||
else:
|
||||
self.memory_mask = None
|
||||
else:
|
||||
self.mask = None
|
||||
self.memory_mask = None
|
||||
|
||||
self.ln_f = nn.LayerNorm(n_emb)
|
||||
self.head = nn.Linear(n_emb, output_dim)
|
||||
|
||||
self.T = T
|
||||
self.T_cond = T_cond
|
||||
self.horizon = horizon
|
||||
self.time_as_cond = time_as_cond
|
||||
self.obs_as_cond = obs_as_cond
|
||||
self.encoder_only = encoder_only
|
||||
|
||||
self.apply(self._init_weights)
|
||||
logger.info(
|
||||
'number of parameters: %e',
|
||||
sum(p.numel() for p in self.parameters()),
|
||||
)
|
||||
|
||||
def _init_weights(self, module):
|
||||
ignore_types = (
|
||||
nn.Dropout,
|
||||
SinusoidalPosEmb,
|
||||
nn.TransformerEncoderLayer,
|
||||
nn.TransformerDecoderLayer,
|
||||
nn.TransformerEncoder,
|
||||
nn.TransformerDecoder,
|
||||
nn.ModuleList,
|
||||
nn.Mish,
|
||||
nn.Sequential,
|
||||
)
|
||||
if isinstance(module, (nn.Linear, nn.Embedding)):
|
||||
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
||||
if isinstance(module, nn.Linear) and module.bias is not None:
|
||||
torch.nn.init.zeros_(module.bias)
|
||||
elif isinstance(module, nn.MultiheadAttention):
|
||||
weight_names = ['in_proj_weight', 'q_proj_weight', 'k_proj_weight', 'v_proj_weight']
|
||||
for name in weight_names:
|
||||
weight = getattr(module, name)
|
||||
if weight is not None:
|
||||
torch.nn.init.normal_(weight, mean=0.0, std=0.02)
|
||||
|
||||
bias_names = ['in_proj_bias', 'bias_k', 'bias_v']
|
||||
for name in bias_names:
|
||||
bias = getattr(module, name)
|
||||
if bias is not None:
|
||||
torch.nn.init.zeros_(bias)
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
torch.nn.init.zeros_(module.bias)
|
||||
torch.nn.init.ones_(module.weight)
|
||||
elif isinstance(module, IMFTransformerForDiffusion):
|
||||
torch.nn.init.normal_(module.pos_emb, mean=0.0, std=0.02)
|
||||
if module.cond_obs_emb is not None:
|
||||
torch.nn.init.normal_(module.cond_pos_emb, mean=0.0, std=0.02)
|
||||
elif isinstance(module, ignore_types):
|
||||
pass
|
||||
else:
|
||||
raise RuntimeError(f'Unaccounted module {module}')
|
||||
|
||||
def get_optim_groups(self, weight_decay: float = 1e-3):
|
||||
decay = set()
|
||||
no_decay = set()
|
||||
whitelist_weight_modules = (torch.nn.Linear, torch.nn.MultiheadAttention)
|
||||
blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
|
||||
for mn, m in self.named_modules():
|
||||
for pn, _ in m.named_parameters():
|
||||
fpn = '%s.%s' % (mn, pn) if mn else pn
|
||||
|
||||
if pn.endswith('bias'):
|
||||
no_decay.add(fpn)
|
||||
elif pn.startswith('bias'):
|
||||
no_decay.add(fpn)
|
||||
elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
|
||||
decay.add(fpn)
|
||||
elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
|
||||
no_decay.add(fpn)
|
||||
|
||||
no_decay.add('pos_emb')
|
||||
no_decay.add('_dummy_variable')
|
||||
if self.cond_pos_emb is not None:
|
||||
no_decay.add('cond_pos_emb')
|
||||
|
||||
param_dict = {pn: p for pn, p in self.named_parameters()}
|
||||
inter_params = decay & no_decay
|
||||
union_params = decay | no_decay
|
||||
assert len(inter_params) == 0, f'parameters {inter_params} made it into both decay/no_decay sets!'
|
||||
assert len(param_dict.keys() - union_params) == 0, (
|
||||
f'parameters {param_dict.keys() - union_params} were not separated into either decay/no_decay sets!'
|
||||
)
|
||||
|
||||
optim_groups = [
|
||||
{
|
||||
'params': [param_dict[pn] for pn in sorted(list(decay))],
|
||||
'weight_decay': weight_decay,
|
||||
},
|
||||
{
|
||||
'params': [param_dict[pn] for pn in sorted(list(no_decay))],
|
||||
'weight_decay': 0.0,
|
||||
},
|
||||
]
|
||||
return optim_groups
|
||||
|
||||
def configure_optimizers(
|
||||
self,
|
||||
learning_rate: float = 1e-4,
|
||||
weight_decay: float = 1e-3,
|
||||
betas: Tuple[float, float] = (0.9, 0.95),
|
||||
):
|
||||
optim_groups = self.get_optim_groups(weight_decay=weight_decay)
|
||||
optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas)
|
||||
return optimizer
|
||||
|
||||
def _prepare_time_input(self, value: Union[torch.Tensor, float, int], sample: torch.Tensor) -> torch.Tensor:
|
||||
if not torch.is_tensor(value):
|
||||
value = torch.tensor([value], dtype=sample.dtype, device=sample.device)
|
||||
elif value.ndim == 0:
|
||||
value = value[None].to(device=sample.device, dtype=sample.dtype)
|
||||
else:
|
||||
value = value.to(device=sample.device, dtype=sample.dtype)
|
||||
return value.expand(sample.shape[0])
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
r: Union[torch.Tensor, float, int],
|
||||
t: Union[torch.Tensor, float, int],
|
||||
cond: Optional[torch.Tensor] = None,
|
||||
):
|
||||
r = self._prepare_time_input(r, sample)
|
||||
t = self._prepare_time_input(t, sample)
|
||||
r_emb = self.time_emb(r).unsqueeze(1)
|
||||
t_emb = self.time_emb(t).unsqueeze(1)
|
||||
|
||||
input_emb = self.input_emb(sample)
|
||||
|
||||
if self.encoder_only:
|
||||
token_embeddings = torch.cat([r_emb, t_emb, input_emb], dim=1)
|
||||
token_count = token_embeddings.shape[1]
|
||||
position_embeddings = self.pos_emb[:, :token_count, :]
|
||||
x = self.drop(token_embeddings + position_embeddings)
|
||||
x = self.encoder(src=x, mask=self.mask)
|
||||
x = x[:, 2:, :]
|
||||
else:
|
||||
cond_embeddings = torch.cat([r_emb, t_emb], dim=1)
|
||||
if self.obs_as_cond:
|
||||
cond_obs_emb = self.cond_obs_emb(cond)
|
||||
cond_embeddings = torch.cat([cond_embeddings, cond_obs_emb], dim=1)
|
||||
token_count = cond_embeddings.shape[1]
|
||||
position_embeddings = self.cond_pos_emb[:, :token_count, :]
|
||||
x = self.drop(cond_embeddings + position_embeddings)
|
||||
x = self.encoder(x)
|
||||
memory = x
|
||||
|
||||
token_embeddings = input_emb
|
||||
token_count = token_embeddings.shape[1]
|
||||
position_embeddings = self.pos_emb[:, :token_count, :]
|
||||
x = self.drop(token_embeddings + position_embeddings)
|
||||
x = self.decoder(
|
||||
tgt=x,
|
||||
memory=memory,
|
||||
tgt_mask=self.mask,
|
||||
memory_mask=self.memory_mask,
|
||||
)
|
||||
|
||||
x = self.ln_f(x)
|
||||
x = self.head(x)
|
||||
return x
|
||||
273
diffusion_policy/policy/imf_transformer_hybrid_image_policy.py
Normal file
273
diffusion_policy/policy/imf_transformer_hybrid_image_policy.py
Normal file
@@ -0,0 +1,273 @@
|
||||
from contextlib import nullcontext
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import reduce
|
||||
|
||||
from diffusion_policy.common.pytorch_util import dict_apply
|
||||
from diffusion_policy.model.diffusion.imf_transformer_for_diffusion import IMFTransformerForDiffusion
|
||||
from diffusion_policy.policy.diffusion_transformer_hybrid_image_policy import (
|
||||
DiffusionTransformerHybridImagePolicy,
|
||||
)
|
||||
|
||||
try:
|
||||
from torch.func import jvp as TORCH_FUNC_JVP
|
||||
except ImportError: # pragma: no cover - depends on torch version
|
||||
TORCH_FUNC_JVP = None
|
||||
|
||||
|
||||
class IMFTransformerHybridImagePolicy(DiffusionTransformerHybridImagePolicy):
|
||||
def __init__(
|
||||
self,
|
||||
shape_meta: dict,
|
||||
noise_scheduler,
|
||||
horizon,
|
||||
n_action_steps,
|
||||
n_obs_steps,
|
||||
num_inference_steps=None,
|
||||
crop_shape=(76, 76),
|
||||
obs_encoder_group_norm=False,
|
||||
eval_fixed_crop=False,
|
||||
n_layer=8,
|
||||
n_cond_layers=0,
|
||||
n_head=1,
|
||||
n_emb=256,
|
||||
p_drop_emb=0.0,
|
||||
p_drop_attn=0.3,
|
||||
causal_attn=True,
|
||||
time_as_cond=True,
|
||||
obs_as_cond=True,
|
||||
pred_action_steps_only=False,
|
||||
**kwargs,
|
||||
):
|
||||
if num_inference_steps is None:
|
||||
num_inference_steps = 1
|
||||
elif num_inference_steps != 1:
|
||||
raise ValueError(
|
||||
'IMFTransformerHybridImagePolicy only supports one-step inference; '
|
||||
f'num_inference_steps must be 1, got {num_inference_steps}.'
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
shape_meta=shape_meta,
|
||||
noise_scheduler=noise_scheduler,
|
||||
horizon=horizon,
|
||||
n_action_steps=n_action_steps,
|
||||
n_obs_steps=n_obs_steps,
|
||||
num_inference_steps=num_inference_steps,
|
||||
crop_shape=crop_shape,
|
||||
obs_encoder_group_norm=obs_encoder_group_norm,
|
||||
eval_fixed_crop=eval_fixed_crop,
|
||||
n_layer=n_layer,
|
||||
n_cond_layers=n_cond_layers,
|
||||
n_head=n_head,
|
||||
n_emb=n_emb,
|
||||
p_drop_emb=p_drop_emb,
|
||||
p_drop_attn=p_drop_attn,
|
||||
causal_attn=causal_attn,
|
||||
time_as_cond=time_as_cond,
|
||||
obs_as_cond=obs_as_cond,
|
||||
pred_action_steps_only=pred_action_steps_only,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
input_dim = self.action_dim if self.obs_as_cond else (self.obs_feature_dim + self.action_dim)
|
||||
cond_dim = self.obs_feature_dim if self.obs_as_cond else 0
|
||||
model_horizon = self.n_action_steps if self.pred_action_steps_only else horizon
|
||||
self.model = IMFTransformerForDiffusion(
|
||||
input_dim=input_dim,
|
||||
output_dim=input_dim,
|
||||
horizon=model_horizon,
|
||||
n_obs_steps=n_obs_steps,
|
||||
cond_dim=cond_dim,
|
||||
n_layer=n_layer,
|
||||
n_head=n_head,
|
||||
n_emb=n_emb,
|
||||
p_drop_emb=p_drop_emb,
|
||||
p_drop_attn=p_drop_attn,
|
||||
causal_attn=causal_attn,
|
||||
time_as_cond=time_as_cond,
|
||||
obs_as_cond=obs_as_cond,
|
||||
n_cond_layers=n_cond_layers,
|
||||
)
|
||||
self.num_inference_steps = 1
|
||||
|
||||
def fn(self, z: torch.Tensor, r: torch.Tensor, t: torch.Tensor, cond=None) -> torch.Tensor:
|
||||
return self.model(z, r, t, cond=cond)
|
||||
|
||||
@staticmethod
|
||||
def _broadcast_batch_time(value: torch.Tensor, reference: torch.Tensor) -> torch.Tensor:
|
||||
while value.ndim < reference.ndim:
|
||||
value = value.unsqueeze(-1)
|
||||
return value
|
||||
|
||||
@staticmethod
|
||||
def _apply_conditioning(
|
||||
trajectory: torch.Tensor,
|
||||
condition_data: Optional[torch.Tensor] = None,
|
||||
condition_mask: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if condition_data is None or condition_mask is None:
|
||||
return trajectory
|
||||
conditioned = trajectory.clone()
|
||||
conditioned[condition_mask] = condition_data[condition_mask]
|
||||
return conditioned
|
||||
|
||||
@staticmethod
|
||||
def _jvp_math_sdp_context(z_t: torch.Tensor):
|
||||
if z_t.is_cuda:
|
||||
return torch.backends.cuda.sdp_kernel(
|
||||
enable_flash=False,
|
||||
enable_math=True,
|
||||
enable_mem_efficient=False,
|
||||
enable_cudnn=False,
|
||||
)
|
||||
return nullcontext()
|
||||
|
||||
@staticmethod
|
||||
def _jvp_tangents(v: torch.Tensor, r: torch.Tensor, t: torch.Tensor):
|
||||
return v.detach(), torch.zeros_like(r), torch.ones_like(t)
|
||||
|
||||
def _compute_u_and_du_dt(
|
||||
self,
|
||||
z_t: torch.Tensor,
|
||||
r: torch.Tensor,
|
||||
t: torch.Tensor,
|
||||
cond,
|
||||
v: torch.Tensor,
|
||||
condition_data: Optional[torch.Tensor] = None,
|
||||
condition_mask: Optional[torch.Tensor] = None,
|
||||
):
|
||||
tangents = self._jvp_tangents(v, r, t)
|
||||
|
||||
def g(z, r_value, t_value):
|
||||
conditioned_z = self._apply_conditioning(z, condition_data, condition_mask)
|
||||
return self.fn(conditioned_z, r_value, t_value, cond=cond)
|
||||
|
||||
with self._jvp_math_sdp_context(z_t):
|
||||
if TORCH_FUNC_JVP is not None:
|
||||
try:
|
||||
return TORCH_FUNC_JVP(g, (z_t, r, t), tangents)
|
||||
except (RuntimeError, TypeError, NotImplementedError):
|
||||
pass
|
||||
|
||||
u = g(z_t, r, t)
|
||||
_, du_dt = torch.autograd.functional.jvp(
|
||||
g,
|
||||
(z_t, r, t),
|
||||
tangents,
|
||||
create_graph=False,
|
||||
strict=False,
|
||||
)
|
||||
return u, du_dt
|
||||
|
||||
def _compound_velocity(
|
||||
self,
|
||||
u: torch.Tensor,
|
||||
du_dt: torch.Tensor,
|
||||
r: torch.Tensor,
|
||||
t: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
delta = self._broadcast_batch_time(t - r, u)
|
||||
return u + delta * du_dt.detach()
|
||||
|
||||
def _sample_one_step(
|
||||
self,
|
||||
z_t: torch.Tensor,
|
||||
r: torch.Tensor = None,
|
||||
t: torch.Tensor = None,
|
||||
cond=None,
|
||||
) -> torch.Tensor:
|
||||
batch_size = z_t.shape[0]
|
||||
if t is None:
|
||||
t = torch.ones(batch_size, device=z_t.device, dtype=z_t.dtype)
|
||||
if r is None:
|
||||
r = torch.zeros(batch_size, device=z_t.device, dtype=z_t.dtype)
|
||||
u = self.fn(z_t, r, t, cond=cond)
|
||||
delta = self._broadcast_batch_time(t - r, z_t)
|
||||
return z_t - delta * u
|
||||
|
||||
def conditional_sample(
|
||||
self,
|
||||
condition_data,
|
||||
condition_mask,
|
||||
cond=None,
|
||||
generator=None,
|
||||
**kwargs,
|
||||
):
|
||||
trajectory = torch.randn(
|
||||
size=condition_data.shape,
|
||||
dtype=condition_data.dtype,
|
||||
device=condition_data.device,
|
||||
generator=generator,
|
||||
)
|
||||
trajectory = self._apply_conditioning(trajectory, condition_data, condition_mask)
|
||||
trajectory = self._sample_one_step(trajectory, cond=cond)
|
||||
trajectory = self._apply_conditioning(trajectory, condition_data, condition_mask)
|
||||
return trajectory
|
||||
|
||||
def compute_loss(self, batch):
|
||||
assert 'valid_mask' not in batch
|
||||
nobs = self.normalizer.normalize(batch['obs'])
|
||||
nactions = self.normalizer['action'].normalize(batch['action'])
|
||||
batch_size = nactions.shape[0]
|
||||
horizon = nactions.shape[1]
|
||||
To = self.n_obs_steps
|
||||
|
||||
cond = None
|
||||
trajectory = nactions
|
||||
if self.obs_as_cond:
|
||||
this_nobs = dict_apply(
|
||||
nobs,
|
||||
lambda x: x[:, :To, ...].reshape(-1, *x.shape[2:]),
|
||||
)
|
||||
nobs_features = self.obs_encoder(this_nobs)
|
||||
cond = nobs_features.reshape(batch_size, To, -1)
|
||||
if self.pred_action_steps_only:
|
||||
start = To - 1
|
||||
end = start + self.n_action_steps
|
||||
trajectory = nactions[:, start:end]
|
||||
else:
|
||||
this_nobs = dict_apply(nobs, lambda x: x.reshape(-1, *x.shape[2:]))
|
||||
nobs_features = self.obs_encoder(this_nobs)
|
||||
nobs_features = nobs_features.reshape(batch_size, horizon, -1)
|
||||
trajectory = torch.cat([nactions, nobs_features], dim=-1).detach()
|
||||
|
||||
if self.pred_action_steps_only:
|
||||
condition_mask = torch.zeros_like(trajectory, dtype=torch.bool)
|
||||
else:
|
||||
condition_mask = self.mask_generator(trajectory.shape)
|
||||
|
||||
loss_mask = torch.zeros_like(trajectory, dtype=torch.bool)
|
||||
loss_mask[..., : self.action_dim] = True
|
||||
loss_mask = loss_mask & ~condition_mask
|
||||
|
||||
x = trajectory
|
||||
e = torch.randn_like(x)
|
||||
t = torch.rand(batch_size, device=x.device, dtype=x.dtype)
|
||||
r = torch.rand(batch_size, device=x.device, dtype=x.dtype)
|
||||
t, r = torch.maximum(t, r), torch.minimum(t, r)
|
||||
|
||||
t_broadcast = self._broadcast_batch_time(t, x)
|
||||
z_t = (1 - t_broadcast) * x + t_broadcast * e
|
||||
z_t = self._apply_conditioning(z_t, x, condition_mask)
|
||||
|
||||
v = self.fn(z_t, t, t, cond=cond)
|
||||
u, du_dt = self._compute_u_and_du_dt(
|
||||
z_t,
|
||||
r,
|
||||
t,
|
||||
cond=cond,
|
||||
v=v,
|
||||
condition_data=x,
|
||||
condition_mask=condition_mask,
|
||||
)
|
||||
V = self._compound_velocity(u, du_dt, r, t)
|
||||
target = e - x
|
||||
|
||||
loss = F.mse_loss(V, target, reduction='none')
|
||||
loss = loss * loss_mask.type(loss.dtype)
|
||||
loss = reduce(loss, 'b ... -> b (...)', 'mean')
|
||||
loss = loss.mean()
|
||||
return loss
|
||||
@@ -8,6 +8,8 @@ if __name__ == "__main__":
|
||||
os.chdir(ROOT_DIR)
|
||||
|
||||
import os
|
||||
import contextlib
|
||||
import importlib
|
||||
import hydra
|
||||
import torch
|
||||
from omegaconf import OmegaConf
|
||||
@@ -15,7 +17,6 @@ import pathlib
|
||||
from torch.utils.data import DataLoader
|
||||
import copy
|
||||
import random
|
||||
import wandb
|
||||
import tqdm
|
||||
import numpy as np
|
||||
import shutil
|
||||
@@ -31,6 +32,111 @@ from diffusion_policy.model.common.lr_scheduler import get_scheduler
|
||||
|
||||
OmegaConf.register_new_resolver("eval", eval, replace=True)
|
||||
|
||||
|
||||
class _LoggingBackend:
|
||||
def log(self, payload, step=None):
|
||||
raise NotImplementedError
|
||||
|
||||
def finish(self):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class _WandbLoggingBackend(_LoggingBackend):
|
||||
def __init__(self, run):
|
||||
self.run = run
|
||||
|
||||
def log(self, payload, step=None):
|
||||
self.run.log(payload, step=step)
|
||||
|
||||
def finish(self):
|
||||
self.run.finish()
|
||||
|
||||
|
||||
class _SwanLabLoggingBackend(_LoggingBackend):
|
||||
def __init__(self, run):
|
||||
self.run = run
|
||||
|
||||
def log(self, payload, step=None):
|
||||
self.run.log(payload, step=step)
|
||||
|
||||
def finish(self):
|
||||
self.run.finish()
|
||||
|
||||
|
||||
def _load_wandb():
|
||||
try:
|
||||
return importlib.import_module('wandb')
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"wandb is required when cfg.logging.backend == 'wandb' or missing"
|
||||
) from exc
|
||||
|
||||
|
||||
def _load_swanlab():
|
||||
try:
|
||||
return importlib.import_module('swanlab')
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
|
||||
def init_logging_backend(cfg: OmegaConf, output_dir):
|
||||
backend = OmegaConf.select(cfg, 'logging.backend', default='wandb')
|
||||
if backend == 'swanlab':
|
||||
swanlab = _load_swanlab()
|
||||
if swanlab is None:
|
||||
raise ImportError("swanlab is required when cfg.logging.backend == 'swanlab'")
|
||||
logging_cfg = cfg.logging
|
||||
mode = logging_cfg.mode
|
||||
if mode == 'online':
|
||||
mode = 'cloud'
|
||||
run = swanlab.init(
|
||||
project=logging_cfg.project,
|
||||
experiment_name=logging_cfg.name,
|
||||
group=logging_cfg.group,
|
||||
tags=logging_cfg.tags,
|
||||
id=logging_cfg.id,
|
||||
resume=logging_cfg.resume,
|
||||
mode=mode,
|
||||
logdir=str(pathlib.Path(output_dir) / 'swanlog'),
|
||||
config=OmegaConf.to_container(cfg, resolve=True),
|
||||
)
|
||||
return _SwanLabLoggingBackend(run)
|
||||
|
||||
if backend not in (None, 'wandb'):
|
||||
raise ValueError(f"Unknown logging backend: {backend}")
|
||||
|
||||
wandb = _load_wandb()
|
||||
logging_kwargs = OmegaConf.to_container(cfg.logging, resolve=True)
|
||||
logging_kwargs.pop('backend', None)
|
||||
run = wandb.init(
|
||||
dir=str(output_dir),
|
||||
config=OmegaConf.to_container(cfg, resolve=True),
|
||||
**logging_kwargs
|
||||
)
|
||||
wandb.config.update(
|
||||
{
|
||||
"output_dir": str(output_dir),
|
||||
}
|
||||
)
|
||||
return _WandbLoggingBackend(run)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def logging_backend_session(cfg: OmegaConf, output_dir):
|
||||
logging_backend = init_logging_backend(cfg=cfg, output_dir=output_dir)
|
||||
primary_error = None
|
||||
try:
|
||||
yield logging_backend
|
||||
except BaseException as exc:
|
||||
primary_error = exc
|
||||
raise
|
||||
finally:
|
||||
try:
|
||||
logging_backend.finish()
|
||||
except BaseException:
|
||||
if primary_error is None:
|
||||
raise
|
||||
|
||||
class TrainDiffusionTransformerHybridWorkspace(BaseWorkspace):
|
||||
include_keys = ['global_step', 'epoch']
|
||||
|
||||
@@ -109,18 +215,6 @@ class TrainDiffusionTransformerHybridWorkspace(BaseWorkspace):
|
||||
output_dir=self.output_dir)
|
||||
assert isinstance(env_runner, BaseImageRunner)
|
||||
|
||||
# configure logging
|
||||
wandb_run = wandb.init(
|
||||
dir=str(self.output_dir),
|
||||
config=OmegaConf.to_container(cfg, resolve=True),
|
||||
**cfg.logging
|
||||
)
|
||||
wandb.config.update(
|
||||
{
|
||||
"output_dir": self.output_dir,
|
||||
}
|
||||
)
|
||||
|
||||
# configure checkpoint
|
||||
topk_manager = TopKCheckpointManager(
|
||||
save_dir=os.path.join(self.output_dir, 'checkpoints'),
|
||||
@@ -148,140 +242,141 @@ class TrainDiffusionTransformerHybridWorkspace(BaseWorkspace):
|
||||
|
||||
# training loop
|
||||
log_path = os.path.join(self.output_dir, 'logs.json.txt')
|
||||
with JsonLogger(log_path) as json_logger:
|
||||
for local_epoch_idx in range(cfg.training.num_epochs):
|
||||
step_log = dict()
|
||||
# ========= train for this epoch ==========
|
||||
train_losses = list()
|
||||
with tqdm.tqdm(train_dataloader, desc=f"Training epoch {self.epoch}",
|
||||
leave=False, mininterval=cfg.training.tqdm_interval_sec) as tepoch:
|
||||
for batch_idx, batch in enumerate(tepoch):
|
||||
# device transfer
|
||||
batch = dict_apply(batch, lambda x: x.to(device, non_blocking=True))
|
||||
if train_sampling_batch is None:
|
||||
train_sampling_batch = batch
|
||||
with logging_backend_session(cfg=cfg, output_dir=self.output_dir) as logging_backend:
|
||||
with JsonLogger(log_path) as json_logger:
|
||||
for local_epoch_idx in range(cfg.training.num_epochs):
|
||||
step_log = dict()
|
||||
# ========= train for this epoch ==========
|
||||
train_losses = list()
|
||||
with tqdm.tqdm(train_dataloader, desc=f"Training epoch {self.epoch}",
|
||||
leave=False, mininterval=cfg.training.tqdm_interval_sec) as tepoch:
|
||||
for batch_idx, batch in enumerate(tepoch):
|
||||
# device transfer
|
||||
batch = dict_apply(batch, lambda x: x.to(device, non_blocking=True))
|
||||
if train_sampling_batch is None:
|
||||
train_sampling_batch = batch
|
||||
|
||||
# compute loss
|
||||
raw_loss = self.model.compute_loss(batch)
|
||||
loss = raw_loss / cfg.training.gradient_accumulate_every
|
||||
loss.backward()
|
||||
# compute loss
|
||||
raw_loss = self.model.compute_loss(batch)
|
||||
loss = raw_loss / cfg.training.gradient_accumulate_every
|
||||
loss.backward()
|
||||
|
||||
# step optimizer
|
||||
if self.global_step % cfg.training.gradient_accumulate_every == 0:
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad()
|
||||
lr_scheduler.step()
|
||||
|
||||
# update ema
|
||||
if cfg.training.use_ema:
|
||||
ema.step(self.model)
|
||||
# step optimizer
|
||||
if self.global_step % cfg.training.gradient_accumulate_every == 0:
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad()
|
||||
lr_scheduler.step()
|
||||
|
||||
# update ema
|
||||
if cfg.training.use_ema:
|
||||
ema.step(self.model)
|
||||
|
||||
# logging
|
||||
raw_loss_cpu = raw_loss.item()
|
||||
tepoch.set_postfix(loss=raw_loss_cpu, refresh=False)
|
||||
train_losses.append(raw_loss_cpu)
|
||||
step_log = {
|
||||
'train_loss': raw_loss_cpu,
|
||||
'global_step': self.global_step,
|
||||
'epoch': self.epoch,
|
||||
'lr': lr_scheduler.get_last_lr()[0]
|
||||
}
|
||||
# logging
|
||||
raw_loss_cpu = raw_loss.item()
|
||||
tepoch.set_postfix(loss=raw_loss_cpu, refresh=False)
|
||||
train_losses.append(raw_loss_cpu)
|
||||
step_log = {
|
||||
'train_loss': raw_loss_cpu,
|
||||
'global_step': self.global_step,
|
||||
'epoch': self.epoch,
|
||||
'lr': lr_scheduler.get_last_lr()[0]
|
||||
}
|
||||
|
||||
is_last_batch = (batch_idx == (len(train_dataloader)-1))
|
||||
if not is_last_batch:
|
||||
# log of last step is combined with validation and rollout
|
||||
wandb_run.log(step_log, step=self.global_step)
|
||||
json_logger.log(step_log)
|
||||
self.global_step += 1
|
||||
is_last_batch = (batch_idx == (len(train_dataloader)-1))
|
||||
if not is_last_batch:
|
||||
# log of last step is combined with validation and rollout
|
||||
logging_backend.log(step_log, step=self.global_step)
|
||||
json_logger.log(step_log)
|
||||
self.global_step += 1
|
||||
|
||||
if (cfg.training.max_train_steps is not None) \
|
||||
and batch_idx >= (cfg.training.max_train_steps-1):
|
||||
break
|
||||
if (cfg.training.max_train_steps is not None) \
|
||||
and batch_idx >= (cfg.training.max_train_steps-1):
|
||||
break
|
||||
|
||||
# at the end of each epoch
|
||||
# replace train_loss with epoch average
|
||||
train_loss = np.mean(train_losses)
|
||||
step_log['train_loss'] = train_loss
|
||||
# at the end of each epoch
|
||||
# replace train_loss with epoch average
|
||||
train_loss = np.mean(train_losses)
|
||||
step_log['train_loss'] = train_loss
|
||||
|
||||
# ========= eval for this epoch ==========
|
||||
policy = self.model
|
||||
if cfg.training.use_ema:
|
||||
policy = self.ema_model
|
||||
policy.eval()
|
||||
# ========= eval for this epoch ==========
|
||||
policy = self.model
|
||||
if cfg.training.use_ema:
|
||||
policy = self.ema_model
|
||||
policy.eval()
|
||||
|
||||
# run rollout
|
||||
if (self.epoch % cfg.training.rollout_every) == 0:
|
||||
runner_log = env_runner.run(policy)
|
||||
# log all
|
||||
step_log.update(runner_log)
|
||||
# run rollout
|
||||
if (self.epoch % cfg.training.rollout_every) == 0:
|
||||
runner_log = env_runner.run(policy)
|
||||
# log all
|
||||
step_log.update(runner_log)
|
||||
|
||||
# run validation
|
||||
if (self.epoch % cfg.training.val_every) == 0:
|
||||
with torch.no_grad():
|
||||
val_losses = list()
|
||||
with tqdm.tqdm(val_dataloader, desc=f"Validation epoch {self.epoch}",
|
||||
leave=False, mininterval=cfg.training.tqdm_interval_sec) as tepoch:
|
||||
for batch_idx, batch in enumerate(tepoch):
|
||||
batch = dict_apply(batch, lambda x: x.to(device, non_blocking=True))
|
||||
loss = self.model.compute_loss(batch)
|
||||
val_losses.append(loss)
|
||||
if (cfg.training.max_val_steps is not None) \
|
||||
and batch_idx >= (cfg.training.max_val_steps-1):
|
||||
break
|
||||
if len(val_losses) > 0:
|
||||
val_loss = torch.mean(torch.tensor(val_losses)).item()
|
||||
# log epoch average validation loss
|
||||
step_log['val_loss'] = val_loss
|
||||
# run validation
|
||||
if (self.epoch % cfg.training.val_every) == 0:
|
||||
with torch.no_grad():
|
||||
val_losses = list()
|
||||
with tqdm.tqdm(val_dataloader, desc=f"Validation epoch {self.epoch}",
|
||||
leave=False, mininterval=cfg.training.tqdm_interval_sec) as tepoch:
|
||||
for batch_idx, batch in enumerate(tepoch):
|
||||
batch = dict_apply(batch, lambda x: x.to(device, non_blocking=True))
|
||||
loss = self.model.compute_loss(batch)
|
||||
val_losses.append(loss)
|
||||
if (cfg.training.max_val_steps is not None) \
|
||||
and batch_idx >= (cfg.training.max_val_steps-1):
|
||||
break
|
||||
if len(val_losses) > 0:
|
||||
val_loss = torch.mean(torch.tensor(val_losses)).item()
|
||||
# log epoch average validation loss
|
||||
step_log['val_loss'] = val_loss
|
||||
|
||||
# run diffusion sampling on a training batch
|
||||
if (self.epoch % cfg.training.sample_every) == 0:
|
||||
with torch.no_grad():
|
||||
# sample trajectory from training set, and evaluate difference
|
||||
batch = dict_apply(train_sampling_batch, lambda x: x.to(device, non_blocking=True))
|
||||
obs_dict = batch['obs']
|
||||
gt_action = batch['action']
|
||||
|
||||
result = policy.predict_action(obs_dict)
|
||||
pred_action = result['action_pred']
|
||||
mse = torch.nn.functional.mse_loss(pred_action, gt_action)
|
||||
step_log['train_action_mse_error'] = mse.item()
|
||||
del batch
|
||||
del obs_dict
|
||||
del gt_action
|
||||
del result
|
||||
del pred_action
|
||||
del mse
|
||||
|
||||
# checkpoint
|
||||
if (self.epoch % cfg.training.checkpoint_every) == 0:
|
||||
# checkpointing
|
||||
if cfg.checkpoint.save_last_ckpt:
|
||||
self.save_checkpoint()
|
||||
if cfg.checkpoint.save_last_snapshot:
|
||||
self.save_snapshot()
|
||||
|
||||
# sanitize metric names
|
||||
metric_dict = dict()
|
||||
for key, value in step_log.items():
|
||||
new_key = key.replace('/', '_')
|
||||
metric_dict[new_key] = value
|
||||
# run diffusion sampling on a training batch
|
||||
if (self.epoch % cfg.training.sample_every) == 0:
|
||||
with torch.no_grad():
|
||||
# sample trajectory from training set, and evaluate difference
|
||||
batch = dict_apply(train_sampling_batch, lambda x: x.to(device, non_blocking=True))
|
||||
obs_dict = batch['obs']
|
||||
gt_action = batch['action']
|
||||
|
||||
result = policy.predict_action(obs_dict)
|
||||
pred_action = result['action_pred']
|
||||
mse = torch.nn.functional.mse_loss(pred_action, gt_action)
|
||||
step_log['train_action_mse_error'] = mse.item()
|
||||
del batch
|
||||
del obs_dict
|
||||
del gt_action
|
||||
del result
|
||||
del pred_action
|
||||
del mse
|
||||
|
||||
# We can't copy the last checkpoint here
|
||||
# since save_checkpoint uses threads.
|
||||
# therefore at this point the file might have been empty!
|
||||
topk_ckpt_path = topk_manager.get_ckpt_path(metric_dict)
|
||||
# checkpoint
|
||||
if (self.epoch % cfg.training.checkpoint_every) == 0:
|
||||
# checkpointing
|
||||
if cfg.checkpoint.save_last_ckpt:
|
||||
self.save_checkpoint()
|
||||
if cfg.checkpoint.save_last_snapshot:
|
||||
self.save_snapshot()
|
||||
|
||||
if topk_ckpt_path is not None:
|
||||
self.save_checkpoint(path=topk_ckpt_path)
|
||||
# ========= eval end for this epoch ==========
|
||||
policy.train()
|
||||
# sanitize metric names
|
||||
metric_dict = dict()
|
||||
for key, value in step_log.items():
|
||||
new_key = key.replace('/', '_')
|
||||
metric_dict[new_key] = value
|
||||
|
||||
# We can't copy the last checkpoint here
|
||||
# since save_checkpoint uses threads.
|
||||
# therefore at this point the file might have been empty!
|
||||
topk_ckpt_path = topk_manager.get_ckpt_path(metric_dict)
|
||||
|
||||
# end of epoch
|
||||
# log of last step is combined with validation and rollout
|
||||
wandb_run.log(step_log, step=self.global_step)
|
||||
json_logger.log(step_log)
|
||||
self.global_step += 1
|
||||
self.epoch += 1
|
||||
if topk_ckpt_path is not None:
|
||||
self.save_checkpoint(path=topk_ckpt_path)
|
||||
# ========= eval end for this epoch ==========
|
||||
policy.train()
|
||||
|
||||
# end of epoch
|
||||
# log of last step is combined with validation and rollout
|
||||
logging_backend.log(step_log, step=self.global_step)
|
||||
json_logger.log(step_log)
|
||||
self.global_step += 1
|
||||
self.epoch += 1
|
||||
|
||||
@hydra.main(
|
||||
version_base=None,
|
||||
|
||||
168
docs/superpowers/specs/2026-03-26-pusht-imf-swanlab-design.md
Normal file
168
docs/superpowers/specs/2026-03-26-pusht-imf-swanlab-design.md
Normal file
@@ -0,0 +1,168 @@
|
||||
# PushT Image DiT iMF + SwanLab Design
|
||||
|
||||
## Goal
|
||||
Migrate the PushT image DiT experiment path from W&B to SwanLab online logging, suppress simulation video logging, then add an iMeanFlow-based one-step transformer policy for PushT image experiments and run a controlled architecture sweep over embedding width and depth using `test_mean_score` as the primary metric.
|
||||
|
||||
## Context
|
||||
- The implementation baseline is `main`.
|
||||
- The experiment path is limited to the PushT image transformer workflow; unrelated workspaces and runners should remain unchanged.
|
||||
- Environment management must use the repo-local `uv` workflow.
|
||||
- The trusted remote machine alias `5880` refers to `droid-system-product-name` (`droid@100.73.14.65`) and can run two GPU jobs in parallel.
|
||||
|
||||
## Architecture Overview
|
||||
The work is split into two verified phases:
|
||||
|
||||
1. **Logging migration phase**
|
||||
- Keep the existing PushT image DiT training behavior intact.
|
||||
- Replace W&B usage with SwanLab in the transformer hybrid workspace used by PushT image DiT experiments.
|
||||
- Preserve local `logs.json.txt` output.
|
||||
- Ensure rollout metrics such as `test_mean_score` and per-seed rewards are still logged.
|
||||
- Disable simulation video logging at both the config and runner/logging boundary.
|
||||
|
||||
2. **iMF migration phase**
|
||||
- Keep the original diffusion-based transformer image policy available on `main`.
|
||||
- Add a parallel iMF-specific model/policy/config path rather than overwriting the baseline diffusion policy.
|
||||
- Reuse the existing observation encoder and training workspace where possible.
|
||||
- Replace diffusion training with the iMeanFlow training objective.
|
||||
- Use one-step inference for validation/rollout in the iMF path.
|
||||
|
||||
The implementation planning boundary for this spec is:
|
||||
- code changes through a smoke-tested, pushed iMF branch
|
||||
- not the full 3x3 sweep execution/monitoring workflow, which should be planned separately after the code path is verified and pushed
|
||||
|
||||
## Logging Design
|
||||
### Scope
|
||||
Only the PushT image DiT experiment chain is changed:
|
||||
- `train_diffusion_transformer_hybrid_workspace.py`
|
||||
- `pusht_image_runner.py`
|
||||
- the new/updated PushT image transformer configs
|
||||
|
||||
### Behavior
|
||||
- SwanLab runs in `online` mode.
|
||||
- Logged values are scalar metrics only, e.g.:
|
||||
- `train_loss`
|
||||
- `val_loss`
|
||||
- `train_action_mse_error`
|
||||
- `test_mean_score`
|
||||
- aggregate rollout metrics and optional per-seed scalar rewards
|
||||
- No simulation videos are uploaded or wrapped as logging objects.
|
||||
- Local JSON logging remains enabled for auditability and remote-job fallback debugging.
|
||||
|
||||
### Operational safeguards
|
||||
- Default PushT experiment configs set `task.env_runner.n_test_vis=0` and `task.env_runner.n_train_vis=0`.
|
||||
- The PushT image runner will not emit video objects into `log_data`, preventing accidental uploads even if visualization counts are later changed.
|
||||
- SwanLab credentials are provided through the environment at runtime, not committed into the repo.
|
||||
|
||||
## iMF Model Design
|
||||
### Baseline reuse
|
||||
The iMF path reuses:
|
||||
- the existing image observation encoder
|
||||
- the existing action/observation normalization path
|
||||
- the existing training workspace skeleton
|
||||
- the existing PushT image dataset and env runner
|
||||
|
||||
### New files
|
||||
- `diffusion_policy/model/diffusion/imf_transformer_for_diffusion.py`
|
||||
- `diffusion_policy/policy/imf_transformer_hybrid_image_policy.py`
|
||||
- `image_pusht_diffusion_policy_dit_imf.yaml`
|
||||
|
||||
### Existing files changed for the iMF path
|
||||
- `diffusion_policy/workspace/train_diffusion_transformer_hybrid_workspace.py`
|
||||
- logging migration to SwanLab for this experiment chain
|
||||
- no structural training-loop fork beyond instantiating the configured policy and logging scalar metrics
|
||||
- `diffusion_policy/env_runner/pusht_image_runner.py`
|
||||
- suppress video objects in returned logs
|
||||
|
||||
### Model structure
|
||||
The iMF transformer mirrors the current transformer policy structure closely enough to reuse known-good conditioning patterns, but it remains a **single-head model** that predicts only:
|
||||
- `u`: average velocity field
|
||||
|
||||
The same function is reused at two evaluation points:
|
||||
- canonical signature: `fn(z, r, t, cond)`
|
||||
- `fn(z_t, r, t, cond)` predicts average velocity `u`
|
||||
- `fn(z_t, t, t, cond)` predicts the instantaneous velocity surrogate `v`
|
||||
|
||||
Inputs remain conditioned on encoded observations and action trajectory tokens.
|
||||
|
||||
## iMF Training Objective
|
||||
For a normalized action trajectory `x`, the initial implementation follows the user-provided Algorithm 1 exactly:
|
||||
1. sample `t, r`
|
||||
2. sample Gaussian noise `e`
|
||||
3. form `z_t = (1 - t) * x + t * e`
|
||||
4. predict instantaneous velocity surrogate with the same network:
|
||||
- `v = fn(z_t, t, t, cond)`
|
||||
5. define the JVP function exactly as:
|
||||
- `g(z, r, t) = fn(z, r, t, cond)`
|
||||
6. compute the primal output and JVP with tangent:
|
||||
- `u, du_dt = jvp(g, (z_t, r, t), (v.detach(), 0, 1))`
|
||||
7. form compound velocity:
|
||||
- `V = u + (t - r) * stopgrad(du_dt)`
|
||||
8. train against the average-velocity target:
|
||||
- `target = e - x`
|
||||
9. optimize only the masked iMF loss:
|
||||
- `loss = metric(V - target)`
|
||||
|
||||
There is **no auxiliary `v` loss** in the initial implementation. The implementation should prefer `torch.func.jvp` and keep a safe fallback path if the local Torch stack needs it.
|
||||
|
||||
## iMF Inference Design
|
||||
Inference uses a single step starting from noise:
|
||||
- initialize `z_1 ~ N(0, I)`
|
||||
- set `t = 1.0`, `r = 0.0`
|
||||
- predict `u = fn(z_1, r, t, cond)`
|
||||
- produce the action sample with one update:
|
||||
- `x_hat = z_1 - (t - r) * u`
|
||||
|
||||
This matches the time direction in the reference iMeanFlow sampling logic.
|
||||
|
||||
## Testing Strategy
|
||||
### Phase 1: logging migration smoke test
|
||||
- use the repo-local `uv` environment
|
||||
- run a debug/smoke PushT image DiT training job on a single GPU with:
|
||||
- `training.debug=true`
|
||||
- `dataloader.num_workers=0`
|
||||
- `val_dataloader.num_workers=0`
|
||||
- `task.env_runner.n_envs=1`
|
||||
- `task.env_runner.n_test_vis=0`
|
||||
- `task.env_runner.n_train_vis=0`
|
||||
- verify:
|
||||
- SwanLab initializes successfully
|
||||
- `logs.json.txt` is populated
|
||||
- rollout metrics still include `test_mean_score`
|
||||
- no video logging is attempted
|
||||
|
||||
### Phase 2: iMF smoke test
|
||||
- run an equivalent debug PushT image iMF job
|
||||
- verify:
|
||||
- forward/backward passes succeed
|
||||
- JVP path executes on the local Torch version
|
||||
- one-step inference returns correctly shaped actions
|
||||
- rollout produces scalar metrics including `test_mean_score`
|
||||
|
||||
## Branch and Commit Strategy
|
||||
1. start from a `main`-based worktree branch
|
||||
2. commit the SwanLab/no-video migration after smoke verification
|
||||
3. continue with the iMF implementation
|
||||
4. once iMF smoke tests pass, create/preserve a dedicated feature branch for the experiment code and push it to Gitea
|
||||
|
||||
## Post-Implementation Experiment Plan
|
||||
After the iMF path is smoke-tested and pushed, a separate experiment-execution plan should launch:
|
||||
- run a 3x3 grid over:
|
||||
- `n_emb ∈ {128, 256, 384}`
|
||||
- `n_layer ∈ {6, 12, 18}`
|
||||
- keep the rest of the setup fixed
|
||||
- use a fixed single-seed setting for comparability unless a later explicit experiment plan expands that scope
|
||||
- run each experiment for 300 epochs
|
||||
- primary comparison metric: `test_mean_score`
|
||||
|
||||
## Post-Implementation Resource Allocation
|
||||
The separate experiment-execution plan should schedule three concurrent runs until the matrix is complete:
|
||||
- local machine: 1 GPU
|
||||
- `5880`: 2 GPUs
|
||||
|
||||
Each run uses the same uv-managed environment and the same pushed branch so the code path is consistent across hosts.
|
||||
|
||||
## Risks and Mitigations
|
||||
- **Torch JVP compatibility risk**: provide a fallback JVP implementation and smoke-test immediately.
|
||||
- **Logging regression risk**: keep local JSON logging and verify scalar rollout metrics before moving to iMF.
|
||||
- **Video/logging side effects**: disable visualizations in config and filter video objects out of runner logs.
|
||||
- **Cross-host drift**: push the verified branch to Gitea before launching the experiment matrix on multiple machines.
|
||||
30
image_pusht_diffusion_policy_dit.yaml
Normal file
30
image_pusht_diffusion_policy_dit.yaml
Normal file
@@ -0,0 +1,30 @@
|
||||
defaults:
|
||||
- diffusion_policy/config/train_diffusion_transformer_hybrid_workspace@_here_
|
||||
- override /diffusion_policy/config/task@task: pusht_image
|
||||
- _self_
|
||||
|
||||
exp_name: pusht_image_dit
|
||||
|
||||
policy:
|
||||
_target_: diffusion_policy.policy.diffusion_transformer_hybrid_image_policy.DiffusionTransformerHybridImagePolicy
|
||||
|
||||
logging:
|
||||
backend: swanlab
|
||||
mode: online
|
||||
name: ${exp_name}
|
||||
resume: false
|
||||
tags: ["${name}", "${task_name}", "${exp_name}", "swanlab"]
|
||||
id: null
|
||||
group: ${exp_name}
|
||||
|
||||
dataloader:
|
||||
num_workers: 0
|
||||
|
||||
val_dataloader:
|
||||
num_workers: 0
|
||||
|
||||
task:
|
||||
env_runner:
|
||||
n_envs: 1
|
||||
n_test_vis: 0
|
||||
n_train_vis: 0
|
||||
32
image_pusht_diffusion_policy_dit_imf.yaml
Normal file
32
image_pusht_diffusion_policy_dit_imf.yaml
Normal file
@@ -0,0 +1,32 @@
|
||||
defaults:
|
||||
- diffusion_policy/config/train_diffusion_transformer_hybrid_workspace@_here_
|
||||
- override /diffusion_policy/config/task@task: pusht_image
|
||||
- _self_
|
||||
|
||||
exp_name: pusht_image_dit_imf
|
||||
|
||||
policy:
|
||||
_target_: diffusion_policy.policy.imf_transformer_hybrid_image_policy.IMFTransformerHybridImagePolicy
|
||||
num_inference_steps: 1
|
||||
n_head: 1
|
||||
|
||||
logging:
|
||||
backend: swanlab
|
||||
mode: online
|
||||
name: ${exp_name}
|
||||
resume: false
|
||||
tags: ["${name}", "${task_name}", "${exp_name}", "swanlab"]
|
||||
id: null
|
||||
group: ${exp_name}
|
||||
|
||||
dataloader:
|
||||
num_workers: 0
|
||||
|
||||
val_dataloader:
|
||||
num_workers: 0
|
||||
|
||||
task:
|
||||
env_runner:
|
||||
n_envs: 1
|
||||
n_test_vis: 0
|
||||
n_train_vis: 0
|
||||
@@ -36,3 +36,4 @@ av==14.0.1
|
||||
pygame==2.5.2
|
||||
robomimic==0.2.0
|
||||
opencv-python-headless==4.10.0.84
|
||||
swanlab
|
||||
|
||||
46
tests/test_imf_transformer_for_diffusion.py
Normal file
46
tests/test_imf_transformer_for_diffusion.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import inspect
|
||||
import pathlib
|
||||
import sys
|
||||
|
||||
import torch
|
||||
|
||||
ROOT_DIR = pathlib.Path(__file__).resolve().parents[1]
|
||||
if str(ROOT_DIR) not in sys.path:
|
||||
sys.path.append(str(ROOT_DIR))
|
||||
|
||||
from diffusion_policy.model.diffusion.imf_transformer_for_diffusion import ( # noqa: E402
|
||||
IMFTransformerForDiffusion,
|
||||
)
|
||||
|
||||
|
||||
def test_imf_transformer_forward_signature_and_shape_single_head():
|
||||
signature = inspect.signature(IMFTransformerForDiffusion.forward)
|
||||
assert list(signature.parameters)[:5] == ['self', 'sample', 'r', 't', 'cond']
|
||||
assert signature.parameters['cond'].default is None
|
||||
|
||||
model = IMFTransformerForDiffusion(
|
||||
input_dim=3,
|
||||
output_dim=3,
|
||||
horizon=5,
|
||||
n_obs_steps=2,
|
||||
cond_dim=4,
|
||||
n_layer=1,
|
||||
n_head=1,
|
||||
n_emb=16,
|
||||
p_drop_emb=0.0,
|
||||
p_drop_attn=0.0,
|
||||
causal_attn=True,
|
||||
time_as_cond=True,
|
||||
obs_as_cond=True,
|
||||
n_cond_layers=0,
|
||||
)
|
||||
model.configure_optimizers()
|
||||
|
||||
sample = torch.randn(2, 5, 3)
|
||||
r = torch.rand(2)
|
||||
t = torch.rand(2)
|
||||
cond = torch.randn(2, 2, 4)
|
||||
|
||||
pred_u = model(sample, r, t, cond=cond)
|
||||
|
||||
assert pred_u.shape == sample.shape
|
||||
313
tests/test_imf_transformer_hybrid_image_policy.py
Normal file
313
tests/test_imf_transformer_hybrid_image_policy.py
Normal file
@@ -0,0 +1,313 @@
|
||||
import pathlib
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
ROOT_DIR = pathlib.Path(__file__).resolve().parents[1]
|
||||
if str(ROOT_DIR) not in sys.path:
|
||||
sys.path.append(str(ROOT_DIR))
|
||||
|
||||
import diffusion_policy.policy.imf_transformer_hybrid_image_policy as policy_module # noqa: E402
|
||||
from diffusion_policy.model.common.module_attr_mixin import ModuleAttrMixin # noqa: E402
|
||||
from diffusion_policy.policy.imf_transformer_hybrid_image_policy import ( # noqa: E402
|
||||
IMFTransformerHybridImagePolicy,
|
||||
)
|
||||
|
||||
|
||||
class ConstantModel(nn.Module):
|
||||
def __init__(self, value):
|
||||
super().__init__()
|
||||
self.value = value
|
||||
|
||||
def forward(self, sample, r, t, cond=None):
|
||||
return torch.full_like(sample, self.value)
|
||||
|
||||
|
||||
class AffineModel(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.tensor(2.0))
|
||||
|
||||
def forward(self, sample, r, t, cond=None):
|
||||
return sample * self.weight + (r + t).view(-1, 1, 1)
|
||||
|
||||
|
||||
class SumMixModel(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.tensor(2.0))
|
||||
|
||||
def forward(self, sample, r, t, cond=None):
|
||||
mixed = sample.sum(dim=-1, keepdim=True).expand_as(sample)
|
||||
return mixed * self.weight + t.view(-1, 1, 1)
|
||||
|
||||
|
||||
class TrackingContext:
|
||||
def __init__(self):
|
||||
self.active = False
|
||||
self.enter_count = 0
|
||||
|
||||
def __enter__(self):
|
||||
self.active = True
|
||||
self.enter_count += 1
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
self.active = False
|
||||
return False
|
||||
|
||||
|
||||
def make_policy(model):
|
||||
policy = IMFTransformerHybridImagePolicy.__new__(IMFTransformerHybridImagePolicy)
|
||||
ModuleAttrMixin.__init__(policy)
|
||||
policy.model = model
|
||||
return policy
|
||||
|
||||
|
||||
def fake_parent_init(
|
||||
self,
|
||||
shape_meta,
|
||||
noise_scheduler,
|
||||
horizon,
|
||||
n_action_steps,
|
||||
n_obs_steps,
|
||||
num_inference_steps=None,
|
||||
crop_shape=(76, 76),
|
||||
obs_encoder_group_norm=False,
|
||||
eval_fixed_crop=False,
|
||||
n_layer=8,
|
||||
n_cond_layers=0,
|
||||
n_head=1,
|
||||
n_emb=256,
|
||||
p_drop_emb=0.0,
|
||||
p_drop_attn=0.3,
|
||||
causal_attn=True,
|
||||
time_as_cond=True,
|
||||
obs_as_cond=True,
|
||||
pred_action_steps_only=False,
|
||||
**kwargs,
|
||||
):
|
||||
ModuleAttrMixin.__init__(self)
|
||||
self.action_dim = shape_meta['action']['shape'][0]
|
||||
self.obs_feature_dim = 4
|
||||
self.obs_as_cond = obs_as_cond
|
||||
self.pred_action_steps_only = pred_action_steps_only
|
||||
self.n_action_steps = n_action_steps
|
||||
self.n_obs_steps = n_obs_steps
|
||||
self.horizon = horizon
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def shape_meta():
|
||||
return {
|
||||
'action': {'shape': [2]},
|
||||
'obs': {},
|
||||
}
|
||||
|
||||
|
||||
def test_sample_one_step_uses_imf_update_formula():
|
||||
policy = make_policy(ConstantModel(0.25))
|
||||
z_1 = torch.tensor([
|
||||
[[1.0, -1.0], [0.5, 0.0]],
|
||||
[[2.0, 3.0], [-2.0, 4.0]],
|
||||
])
|
||||
r = torch.zeros(z_1.shape[0])
|
||||
t = torch.ones(z_1.shape[0])
|
||||
|
||||
x_hat = policy._sample_one_step(z_1, r=r, t=t, cond=None)
|
||||
|
||||
expected = z_1 - (t - r).view(-1, 1, 1) * 0.25
|
||||
assert torch.allclose(x_hat, expected)
|
||||
|
||||
|
||||
def test_compound_velocity_uses_detached_du_dt_term():
|
||||
policy = make_policy(ConstantModel(0.0))
|
||||
u = torch.tensor([[[1.0], [2.0]]], requires_grad=True)
|
||||
du_dt = torch.tensor([[[3.0], [4.0]]], requires_grad=True)
|
||||
r = torch.tensor([0.2])
|
||||
t = torch.tensor([0.8])
|
||||
|
||||
compound = policy._compound_velocity(u, du_dt, r, t)
|
||||
expected = u + (t - r).view(-1, 1, 1) * du_dt.detach()
|
||||
|
||||
assert torch.allclose(compound, expected)
|
||||
|
||||
compound.sum().backward()
|
||||
assert u.grad is not None
|
||||
assert du_dt.grad is None
|
||||
|
||||
|
||||
def test_compute_u_and_du_dt_uses_math_sdpa_context_for_torch_func_jvp(monkeypatch):
|
||||
tracker = TrackingContext()
|
||||
|
||||
def fake_jvp(fn, primals, tangents):
|
||||
assert tracker.active is True
|
||||
return fn(*primals), torch.zeros_like(primals[0])
|
||||
|
||||
monkeypatch.setattr(policy_module, 'TORCH_FUNC_JVP', fake_jvp)
|
||||
|
||||
policy = make_policy(ConstantModel(0.5))
|
||||
policy._jvp_math_sdp_context = lambda tensor: tracker
|
||||
z_t = torch.randn(2, 3, 4)
|
||||
r = torch.rand(2, requires_grad=True)
|
||||
t = torch.rand(2, requires_grad=True)
|
||||
v = torch.randn_like(z_t, requires_grad=True)
|
||||
|
||||
policy._compute_u_and_du_dt(z_t, r, t, cond=None, v=v)
|
||||
|
||||
assert tracker.enter_count == 1
|
||||
|
||||
|
||||
def test_compute_u_and_du_dt_uses_math_sdpa_context_for_autograd_fallback(monkeypatch):
|
||||
tracker = TrackingContext()
|
||||
|
||||
def fake_autograd_jvp(fn, primals, tangents, create_graph=False, strict=False):
|
||||
assert tracker.active is True
|
||||
return fn(*primals), torch.zeros_like(primals[0])
|
||||
|
||||
monkeypatch.setattr(policy_module, 'TORCH_FUNC_JVP', None)
|
||||
monkeypatch.setattr(policy_module.torch.autograd.functional, 'jvp', fake_autograd_jvp)
|
||||
|
||||
policy = make_policy(ConstantModel(0.5))
|
||||
policy._jvp_math_sdp_context = lambda tensor: tracker
|
||||
z_t = torch.randn(2, 3, 4)
|
||||
r = torch.rand(2, requires_grad=True)
|
||||
t = torch.rand(2, requires_grad=True)
|
||||
v = torch.randn_like(z_t, requires_grad=True)
|
||||
|
||||
policy._compute_u_and_du_dt(z_t, r, t, cond=None, v=v)
|
||||
|
||||
assert tracker.enter_count == 1
|
||||
|
||||
|
||||
def test_compute_u_and_du_dt_uses_detached_v_zero_r_unit_t_and_reapplies_conditioning(monkeypatch):
|
||||
captured = {}
|
||||
|
||||
def fake_jvp(fn, primals, tangents):
|
||||
captured['tangents'] = tangents
|
||||
captured['primal_output'] = fn(*primals)
|
||||
return captured['primal_output'], torch.zeros_like(primals[0])
|
||||
|
||||
monkeypatch.setattr(policy_module, 'TORCH_FUNC_JVP', fake_jvp)
|
||||
|
||||
policy = make_policy(SumMixModel())
|
||||
z_t = torch.tensor([[[1.0, 2.0, 3.0]]])
|
||||
r = torch.rand(1, requires_grad=True)
|
||||
t = torch.rand(1, requires_grad=True)
|
||||
v = torch.tensor([[[10.0, 20.0, 30.0]]], requires_grad=True)
|
||||
condition_mask = torch.tensor([[[False, True, False]]])
|
||||
condition_data = torch.tensor([[[0.0, 7.0, 0.0]]])
|
||||
|
||||
policy._compute_u_and_du_dt(
|
||||
z_t,
|
||||
r,
|
||||
t,
|
||||
cond=None,
|
||||
v=v,
|
||||
condition_data=condition_data,
|
||||
condition_mask=condition_mask,
|
||||
)
|
||||
|
||||
tangent_v, tangent_r, tangent_t = captured['tangents']
|
||||
assert torch.equal(tangent_v, v.detach())
|
||||
assert tangent_v.requires_grad is False
|
||||
assert torch.equal(tangent_r, torch.zeros_like(r))
|
||||
assert torch.equal(tangent_t, torch.ones_like(t))
|
||||
|
||||
conditioned = z_t.clone()
|
||||
conditioned[condition_mask] = condition_data[condition_mask]
|
||||
expected_primal = policy.model(conditioned, r, t, cond=None)
|
||||
assert torch.allclose(captured['primal_output'], expected_primal)
|
||||
|
||||
|
||||
def test_compute_u_and_du_dt_fallback_blocks_conditioned_tangent_leakage_and_keeps_primal_gradients(monkeypatch):
|
||||
monkeypatch.setattr(policy_module, 'TORCH_FUNC_JVP', None)
|
||||
|
||||
policy = make_policy(SumMixModel())
|
||||
z_t = torch.tensor([[[1.0, 2.0, 3.0]]], requires_grad=True)
|
||||
r = torch.rand(1, requires_grad=True)
|
||||
t = torch.rand(1, requires_grad=True)
|
||||
v = torch.tensor([[[1.0, 10.0, 100.0]]], requires_grad=True)
|
||||
condition_mask = torch.tensor([[[False, True, False]]])
|
||||
condition_data = torch.tensor([[[0.0, 7.0, 0.0]]])
|
||||
|
||||
u, du_dt = policy._compute_u_and_du_dt(
|
||||
z_t,
|
||||
r,
|
||||
t,
|
||||
cond=None,
|
||||
v=v,
|
||||
condition_data=condition_data,
|
||||
condition_mask=condition_mask,
|
||||
)
|
||||
|
||||
conditioned = z_t.detach().clone()
|
||||
conditioned[condition_mask] = condition_data[condition_mask]
|
||||
expected_u = policy.model(conditioned, r, t, cond=None)
|
||||
expected_du_dt_scalar = policy.model.weight.detach() * torch.tensor(101.0) + 1.0
|
||||
expected_du_dt = torch.full_like(z_t, expected_du_dt_scalar)
|
||||
|
||||
assert u.shape == z_t.shape
|
||||
assert du_dt.shape == z_t.shape
|
||||
assert torch.allclose(u, expected_u)
|
||||
assert torch.allclose(du_dt, expected_du_dt)
|
||||
|
||||
u.sum().backward()
|
||||
assert policy.model.weight.grad is not None
|
||||
assert torch.count_nonzero(policy.model.weight.grad) > 0
|
||||
|
||||
|
||||
def test_init_uses_action_step_horizon_when_pred_action_steps_only(monkeypatch, shape_meta):
|
||||
monkeypatch.setattr(
|
||||
policy_module.DiffusionTransformerHybridImagePolicy,
|
||||
'__init__',
|
||||
fake_parent_init,
|
||||
)
|
||||
|
||||
policy = IMFTransformerHybridImagePolicy(
|
||||
shape_meta=shape_meta,
|
||||
noise_scheduler=None,
|
||||
horizon=10,
|
||||
n_action_steps=4,
|
||||
n_obs_steps=2,
|
||||
num_inference_steps=1,
|
||||
n_layer=1,
|
||||
n_head=1,
|
||||
n_emb=16,
|
||||
p_drop_emb=0.0,
|
||||
p_drop_attn=0.0,
|
||||
causal_attn=True,
|
||||
obs_as_cond=True,
|
||||
pred_action_steps_only=True,
|
||||
)
|
||||
|
||||
assert policy.model.horizon == 4
|
||||
assert policy.num_inference_steps == 1
|
||||
|
||||
|
||||
def test_init_rejects_non_one_step_inference(monkeypatch, shape_meta):
|
||||
monkeypatch.setattr(
|
||||
policy_module.DiffusionTransformerHybridImagePolicy,
|
||||
'__init__',
|
||||
fake_parent_init,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match='num_inference_steps'):
|
||||
IMFTransformerHybridImagePolicy(
|
||||
shape_meta=shape_meta,
|
||||
noise_scheduler=None,
|
||||
horizon=10,
|
||||
n_action_steps=4,
|
||||
n_obs_steps=2,
|
||||
num_inference_steps=2,
|
||||
n_layer=1,
|
||||
n_head=1,
|
||||
n_emb=16,
|
||||
p_drop_emb=0.0,
|
||||
p_drop_attn=0.0,
|
||||
causal_attn=True,
|
||||
obs_as_cond=True,
|
||||
pred_action_steps_only=False,
|
||||
)
|
||||
110
tests/test_pusht_image_runner_metrics.py
Normal file
110
tests/test_pusht_image_runner_metrics.py
Normal file
@@ -0,0 +1,110 @@
|
||||
import pathlib
|
||||
import sys
|
||||
|
||||
import gym
|
||||
from gym import spaces
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
ROOT_DIR = pathlib.Path(__file__).resolve().parents[1]
|
||||
if str(ROOT_DIR) not in sys.path:
|
||||
sys.path.append(str(ROOT_DIR))
|
||||
|
||||
import diffusion_policy.env_runner.pusht_image_runner as runner_module
|
||||
from diffusion_policy.env_runner.pusht_image_runner import summarize_rollout_metrics
|
||||
|
||||
|
||||
class FakePushTImageEnv(gym.Env):
|
||||
metadata = {'render.modes': ['rgb_array']}
|
||||
|
||||
def __init__(self, legacy=False, render_size=96):
|
||||
del legacy, render_size
|
||||
self.observation_space = spaces.Dict({
|
||||
'image': spaces.Box(low=0, high=255, shape=(3, 4, 4), dtype=np.uint8),
|
||||
})
|
||||
self.action_space = spaces.Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32)
|
||||
self.seed_value = 0
|
||||
self.step_count = 0
|
||||
|
||||
def seed(self, seed=None):
|
||||
self.seed_value = 0 if seed is None else seed
|
||||
|
||||
def reset(self):
|
||||
self.step_count = 0
|
||||
return {'image': np.zeros((3, 4, 4), dtype=np.uint8)}
|
||||
|
||||
def step(self, action):
|
||||
del action
|
||||
self.step_count += 1
|
||||
reward = 0.1 if self.seed_value < 10000 else 0.9
|
||||
done = self.step_count >= 1
|
||||
obs = {'image': np.full((3, 4, 4), self.step_count, dtype=np.uint8)}
|
||||
return obs, reward, done, {}
|
||||
|
||||
def render(self, *args, **kwargs):
|
||||
raise AssertionError('render should not be called for scalar-only PushT image rollouts')
|
||||
|
||||
|
||||
class FakePolicy:
|
||||
device = torch.device('cpu')
|
||||
dtype = torch.float32
|
||||
|
||||
def reset(self):
|
||||
return None
|
||||
|
||||
def predict_action(self, obs_dict):
|
||||
n_envs = next(iter(obs_dict.values())).shape[0]
|
||||
return {
|
||||
'action': torch.zeros((n_envs, 2, 2), dtype=torch.float32),
|
||||
}
|
||||
|
||||
|
||||
def test_summarize_rollout_metrics_keeps_scalar_rewards_renames_means_and_omits_videos():
|
||||
log_data = summarize_rollout_metrics(
|
||||
env_seeds=[11, 12, 101],
|
||||
env_prefixs=['train/', 'train/', 'test/'],
|
||||
all_rewards=[
|
||||
[0.2, 0.8],
|
||||
[0.1, 0.4],
|
||||
[0.5, 0.9],
|
||||
],
|
||||
all_video_paths=[
|
||||
'/tmp/train-11.mp4',
|
||||
'/tmp/train-12.mp4',
|
||||
'/tmp/test-101.mp4',
|
||||
],
|
||||
)
|
||||
|
||||
assert log_data['train/sim_max_reward_11'] == 0.8
|
||||
assert log_data['train/sim_max_reward_12'] == 0.4
|
||||
assert log_data['test/sim_max_reward_101'] == 0.9
|
||||
assert log_data['train_mean_score'] == pytest.approx(0.6)
|
||||
assert log_data['test_mean_score'] == pytest.approx(0.9)
|
||||
assert not any(key.startswith('train/sim_video_') for key in log_data)
|
||||
assert not any(key.startswith('test/sim_video_') for key in log_data)
|
||||
|
||||
|
||||
def test_runner_ignores_vis_flags_and_never_emits_sim_videos(tmp_path, monkeypatch):
|
||||
monkeypatch.setattr(runner_module, 'PushTImageEnv', FakePushTImageEnv)
|
||||
|
||||
runner = runner_module.PushTImageRunner(
|
||||
output_dir=tmp_path,
|
||||
n_train=1,
|
||||
n_train_vis=1,
|
||||
n_test=1,
|
||||
n_test_vis=1,
|
||||
n_envs=2,
|
||||
max_steps=2,
|
||||
n_obs_steps=2,
|
||||
n_action_steps=2,
|
||||
tqdm_interval_sec=0.0,
|
||||
)
|
||||
|
||||
log_data = runner.run(FakePolicy())
|
||||
|
||||
assert log_data['train/sim_max_reward_0'] == pytest.approx(0.1)
|
||||
assert log_data['test/sim_max_reward_10000'] == pytest.approx(0.9)
|
||||
assert log_data['train_mean_score'] == pytest.approx(0.1)
|
||||
assert log_data['test_mean_score'] == pytest.approx(0.9)
|
||||
assert not any('sim_video' in key for key in log_data)
|
||||
32
tests/test_pusht_swanlab_config.py
Normal file
32
tests/test_pusht_swanlab_config.py
Normal file
@@ -0,0 +1,32 @@
|
||||
import pathlib
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
|
||||
ROOT_DIR = pathlib.Path(__file__).resolve().parents[1]
|
||||
|
||||
|
||||
def _load_cfg(name: str):
|
||||
return OmegaConf.load(ROOT_DIR / name)
|
||||
|
||||
|
||||
def test_image_pusht_dit_swanlab_config_uses_exp_name_and_no_resume_collision():
|
||||
cfg = _load_cfg('image_pusht_diffusion_policy_dit.yaml')
|
||||
|
||||
assert cfg.logging.backend == 'swanlab'
|
||||
assert cfg.logging.mode == 'online'
|
||||
assert cfg.logging.name == cfg.exp_name
|
||||
assert cfg.logging.resume is False
|
||||
assert cfg.logging.id is None
|
||||
assert cfg.logging.group == cfg.exp_name
|
||||
|
||||
|
||||
def test_image_pusht_dit_imf_swanlab_config_uses_exp_name_and_no_resume_collision():
|
||||
cfg = _load_cfg('image_pusht_diffusion_policy_dit_imf.yaml')
|
||||
|
||||
assert cfg.logging.backend == 'swanlab'
|
||||
assert cfg.logging.mode == 'online'
|
||||
assert cfg.logging.name == cfg.exp_name
|
||||
assert cfg.logging.resume is False
|
||||
assert cfg.logging.id is None
|
||||
assert cfg.logging.group == cfg.exp_name
|
||||
198
tests/test_train_diffusion_transformer_workspace_logging.py
Normal file
198
tests/test_train_diffusion_transformer_workspace_logging.py
Normal file
@@ -0,0 +1,198 @@
|
||||
import importlib
|
||||
import pathlib
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
ROOT_DIR = pathlib.Path(__file__).resolve().parents[1]
|
||||
if str(ROOT_DIR) not in sys.path:
|
||||
sys.path.append(str(ROOT_DIR))
|
||||
|
||||
MODULE_NAME = 'diffusion_policy.workspace.train_diffusion_transformer_hybrid_workspace'
|
||||
|
||||
|
||||
def load_workspace_module(monkeypatch, *, wandb_missing=False):
|
||||
sys.modules.pop(MODULE_NAME, None)
|
||||
if wandb_missing:
|
||||
monkeypatch.setitem(sys.modules, 'wandb', None)
|
||||
return importlib.import_module(MODULE_NAME)
|
||||
|
||||
|
||||
def test_init_logger_uses_swanlab_backend_mapping_without_loading_wandb(tmp_path, monkeypatch):
|
||||
workspace_module = load_workspace_module(monkeypatch, wandb_missing=True)
|
||||
events = []
|
||||
|
||||
class FakeRun:
|
||||
def log(self, payload, step=None):
|
||||
events.append(('log', payload, step))
|
||||
|
||||
def finish(self):
|
||||
events.append(('finish',))
|
||||
|
||||
class FakeSwanLab:
|
||||
def init(self, **kwargs):
|
||||
events.append(('init', kwargs))
|
||||
return FakeRun()
|
||||
|
||||
monkeypatch.setattr(workspace_module, '_load_swanlab', lambda: FakeSwanLab())
|
||||
monkeypatch.setattr(
|
||||
workspace_module,
|
||||
'_load_wandb',
|
||||
lambda: pytest.fail('wandb should not be loaded for the SwanLab backend'),
|
||||
)
|
||||
|
||||
cfg = OmegaConf.create({
|
||||
'logging': {
|
||||
'backend': 'swanlab',
|
||||
'project': 'demo-project',
|
||||
'name': 'demo-run',
|
||||
'group': 'demo-group',
|
||||
'tags': ['pusht', 'dit'],
|
||||
'id': 'run-123',
|
||||
'resume': True,
|
||||
'mode': 'online',
|
||||
}
|
||||
})
|
||||
|
||||
logger = workspace_module.init_logging_backend(cfg=cfg, output_dir=tmp_path)
|
||||
logger.log({'metric': 1.0}, step=7)
|
||||
logger.finish()
|
||||
|
||||
assert events[0][0] == 'init'
|
||||
init_kwargs = events[0][1]
|
||||
assert init_kwargs['project'] == 'demo-project'
|
||||
assert init_kwargs['experiment_name'] == 'demo-run'
|
||||
assert init_kwargs['group'] == 'demo-group'
|
||||
assert init_kwargs['tags'] == ['pusht', 'dit']
|
||||
assert init_kwargs['id'] == 'run-123'
|
||||
assert init_kwargs['resume'] is True
|
||||
assert init_kwargs['mode'] == 'cloud'
|
||||
assert init_kwargs['logdir'] == str(tmp_path / 'swanlog')
|
||||
assert ('log', {'metric': 1.0}, 7) in events
|
||||
assert events.count(('finish',)) == 1
|
||||
|
||||
|
||||
def test_init_logger_defaults_to_legacy_wandb_path_when_backend_missing(tmp_path, monkeypatch):
|
||||
workspace_module = load_workspace_module(monkeypatch)
|
||||
events = []
|
||||
|
||||
class FakeRun:
|
||||
def log(self, payload, step=None):
|
||||
events.append(('log', payload, step))
|
||||
|
||||
def finish(self):
|
||||
events.append(('finish',))
|
||||
|
||||
class FakeConfig:
|
||||
def update(self, payload):
|
||||
events.append(('config.update', payload))
|
||||
|
||||
class FakeWandb:
|
||||
def __init__(self):
|
||||
self.config = FakeConfig()
|
||||
|
||||
def init(self, **kwargs):
|
||||
events.append(('init', kwargs))
|
||||
return FakeRun()
|
||||
|
||||
monkeypatch.setattr(workspace_module, '_load_wandb', lambda: FakeWandb())
|
||||
|
||||
cfg = OmegaConf.create({
|
||||
'logging': {
|
||||
'project': 'demo-project',
|
||||
'name': 'demo-run',
|
||||
'group': None,
|
||||
'tags': ['shared'],
|
||||
'id': None,
|
||||
'resume': True,
|
||||
'mode': 'online',
|
||||
}
|
||||
})
|
||||
|
||||
logger = workspace_module.init_logging_backend(cfg=cfg, output_dir=tmp_path)
|
||||
logger.log({'metric': 2.0}, step=3)
|
||||
logger.finish()
|
||||
|
||||
assert events[0][0] == 'init'
|
||||
init_kwargs = events[0][1]
|
||||
assert init_kwargs['dir'] == str(tmp_path)
|
||||
assert init_kwargs['project'] == 'demo-project'
|
||||
assert init_kwargs['name'] == 'demo-run'
|
||||
assert init_kwargs['mode'] == 'online'
|
||||
assert ('config.update', {'output_dir': str(tmp_path)}) in events
|
||||
assert ('log', {'metric': 2.0}, 3) in events
|
||||
assert events.count(('finish',)) == 1
|
||||
|
||||
|
||||
def test_init_logger_rejects_unknown_backends(tmp_path, monkeypatch):
|
||||
workspace_module = load_workspace_module(monkeypatch)
|
||||
cfg = OmegaConf.create({
|
||||
'logging': {
|
||||
'backend': 'tensorboard',
|
||||
'project': 'demo-project',
|
||||
'name': 'demo-run',
|
||||
'mode': 'offline',
|
||||
}
|
||||
})
|
||||
|
||||
with pytest.raises(ValueError, match='Unknown logging backend'):
|
||||
workspace_module.init_logging_backend(cfg=cfg, output_dir=tmp_path)
|
||||
|
||||
|
||||
|
||||
|
||||
def test_logging_backend_session_preserves_primary_exception_when_finish_fails(tmp_path, monkeypatch):
|
||||
workspace_module = load_workspace_module(monkeypatch)
|
||||
events = []
|
||||
|
||||
class FakeBackend:
|
||||
def log(self, payload, step=None):
|
||||
events.append(('log', payload, step))
|
||||
|
||||
def finish(self):
|
||||
events.append(('finish',))
|
||||
raise RuntimeError('finish boom')
|
||||
|
||||
monkeypatch.setattr(
|
||||
workspace_module,
|
||||
'init_logging_backend',
|
||||
lambda cfg, output_dir: FakeBackend(),
|
||||
)
|
||||
|
||||
cfg = OmegaConf.create({'logging': {'mode': 'offline'}})
|
||||
|
||||
with pytest.raises(ValueError, match='primary boom'):
|
||||
with workspace_module.logging_backend_session(cfg=cfg, output_dir=tmp_path) as logger:
|
||||
logger.log({'metric': 6.0}, step=12)
|
||||
raise ValueError('primary boom')
|
||||
|
||||
assert ('log', {'metric': 6.0}, 12) in events
|
||||
assert events.count(('finish',)) == 1
|
||||
|
||||
def test_logging_backend_session_finishes_on_exception(tmp_path, monkeypatch):
|
||||
workspace_module = load_workspace_module(monkeypatch)
|
||||
events = []
|
||||
|
||||
class FakeBackend:
|
||||
def log(self, payload, step=None):
|
||||
events.append(('log', payload, step))
|
||||
|
||||
def finish(self):
|
||||
events.append(('finish',))
|
||||
|
||||
monkeypatch.setattr(
|
||||
workspace_module,
|
||||
'init_logging_backend',
|
||||
lambda cfg, output_dir: FakeBackend(),
|
||||
)
|
||||
|
||||
cfg = OmegaConf.create({'logging': {'mode': 'offline'}})
|
||||
|
||||
with pytest.raises(RuntimeError, match='boom'):
|
||||
with workspace_module.logging_backend_session(cfg=cfg, output_dir=tmp_path) as logger:
|
||||
logger.log({'metric': 5.0}, step=11)
|
||||
raise RuntimeError('boom')
|
||||
|
||||
assert ('log', {'metric': 5.0}, 11) in events
|
||||
assert events.count(('finish',)) == 1
|
||||
Reference in New Issue
Block a user