Compare commits
3 Commits
feat/pusht
...
DiT-imageP
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
42dc29a2cb | ||
|
|
79f31940c4 | ||
|
|
2aa06c8917 |
@@ -1,37 +1,21 @@
|
|||||||
|
import wandb
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import collections
|
import collections
|
||||||
|
import pathlib
|
||||||
import tqdm
|
import tqdm
|
||||||
import dill
|
import dill
|
||||||
import math
|
import math
|
||||||
|
import wandb.sdk.data_types.video as wv
|
||||||
from diffusion_policy.env.pusht.pusht_image_env import PushTImageEnv
|
from diffusion_policy.env.pusht.pusht_image_env import PushTImageEnv
|
||||||
from diffusion_policy.gym_util.sync_vector_env import SyncVectorEnv
|
from diffusion_policy.gym_util.sync_vector_env import SyncVectorEnv
|
||||||
from diffusion_policy.gym_util.multistep_wrapper import MultiStepWrapper
|
from diffusion_policy.gym_util.multistep_wrapper import MultiStepWrapper
|
||||||
|
from diffusion_policy.gym_util.video_recording_wrapper import VideoRecordingWrapper, VideoRecorder
|
||||||
|
|
||||||
from diffusion_policy.policy.base_image_policy import BaseImagePolicy
|
from diffusion_policy.policy.base_image_policy import BaseImagePolicy
|
||||||
from diffusion_policy.common.pytorch_util import dict_apply
|
from diffusion_policy.common.pytorch_util import dict_apply
|
||||||
from diffusion_policy.env_runner.base_image_runner import BaseImageRunner
|
from diffusion_policy.env_runner.base_image_runner import BaseImageRunner
|
||||||
|
|
||||||
|
|
||||||
def summarize_rollout_metrics(env_seeds, env_prefixs, all_rewards, all_video_paths=None):
|
|
||||||
del all_video_paths
|
|
||||||
|
|
||||||
max_rewards = collections.defaultdict(list)
|
|
||||||
log_data = dict()
|
|
||||||
for seed, prefix, rewards in zip(env_seeds, env_prefixs, all_rewards):
|
|
||||||
max_reward = np.max(rewards)
|
|
||||||
max_rewards[prefix].append(max_reward)
|
|
||||||
log_data[prefix + f'sim_max_reward_{seed}'] = max_reward
|
|
||||||
|
|
||||||
aggregate_key_map = {
|
|
||||||
'train/': 'train_mean_score',
|
|
||||||
'test/': 'test_mean_score',
|
|
||||||
}
|
|
||||||
for prefix, value in max_rewards.items():
|
|
||||||
log_data[aggregate_key_map.get(prefix, prefix + 'mean_score')] = np.mean(value)
|
|
||||||
|
|
||||||
return log_data
|
|
||||||
|
|
||||||
class PushTImageRunner(BaseImageRunner):
|
class PushTImageRunner(BaseImageRunner):
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
output_dir,
|
output_dir,
|
||||||
@@ -56,12 +40,25 @@ class PushTImageRunner(BaseImageRunner):
|
|||||||
if n_envs is None:
|
if n_envs is None:
|
||||||
n_envs = n_train + n_test
|
n_envs = n_train + n_test
|
||||||
|
|
||||||
|
steps_per_render = max(10 // fps, 1)
|
||||||
def env_fn():
|
def env_fn():
|
||||||
return MultiStepWrapper(
|
return MultiStepWrapper(
|
||||||
|
VideoRecordingWrapper(
|
||||||
PushTImageEnv(
|
PushTImageEnv(
|
||||||
legacy=legacy_test,
|
legacy=legacy_test,
|
||||||
render_size=render_size
|
render_size=render_size
|
||||||
),
|
),
|
||||||
|
video_recoder=VideoRecorder.create_h264(
|
||||||
|
fps=fps,
|
||||||
|
codec='h264',
|
||||||
|
input_pix_fmt='rgb24',
|
||||||
|
crf=crf,
|
||||||
|
thread_type='FRAME',
|
||||||
|
thread_count=1
|
||||||
|
),
|
||||||
|
file_path=None,
|
||||||
|
steps_per_render=steps_per_render
|
||||||
|
),
|
||||||
n_obs_steps=n_obs_steps,
|
n_obs_steps=n_obs_steps,
|
||||||
n_action_steps=n_action_steps,
|
n_action_steps=n_action_steps,
|
||||||
max_episode_steps=max_steps
|
max_episode_steps=max_steps
|
||||||
@@ -74,8 +71,21 @@ class PushTImageRunner(BaseImageRunner):
|
|||||||
# train
|
# train
|
||||||
for i in range(n_train):
|
for i in range(n_train):
|
||||||
seed = train_start_seed + i
|
seed = train_start_seed + i
|
||||||
|
enable_render = i < n_train_vis
|
||||||
|
|
||||||
|
def init_fn(env, seed=seed, enable_render=enable_render):
|
||||||
|
# setup rendering
|
||||||
|
# video_wrapper
|
||||||
|
assert isinstance(env.env, VideoRecordingWrapper)
|
||||||
|
env.env.video_recoder.stop()
|
||||||
|
env.env.file_path = None
|
||||||
|
if enable_render:
|
||||||
|
filename = pathlib.Path(output_dir).joinpath(
|
||||||
|
'media', wv.util.generate_id() + ".mp4")
|
||||||
|
filename.parent.mkdir(parents=False, exist_ok=True)
|
||||||
|
filename = str(filename)
|
||||||
|
env.env.file_path = filename
|
||||||
|
|
||||||
def init_fn(env, seed=seed):
|
|
||||||
# set seed
|
# set seed
|
||||||
assert isinstance(env, MultiStepWrapper)
|
assert isinstance(env, MultiStepWrapper)
|
||||||
env.seed(seed)
|
env.seed(seed)
|
||||||
@@ -87,8 +97,21 @@ class PushTImageRunner(BaseImageRunner):
|
|||||||
# test
|
# test
|
||||||
for i in range(n_test):
|
for i in range(n_test):
|
||||||
seed = test_start_seed + i
|
seed = test_start_seed + i
|
||||||
|
enable_render = i < n_test_vis
|
||||||
|
|
||||||
|
def init_fn(env, seed=seed, enable_render=enable_render):
|
||||||
|
# setup rendering
|
||||||
|
# video_wrapper
|
||||||
|
assert isinstance(env.env, VideoRecordingWrapper)
|
||||||
|
env.env.video_recoder.stop()
|
||||||
|
env.env.file_path = None
|
||||||
|
if enable_render:
|
||||||
|
filename = pathlib.Path(output_dir).joinpath(
|
||||||
|
'media', wv.util.generate_id() + ".mp4")
|
||||||
|
filename.parent.mkdir(parents=False, exist_ok=True)
|
||||||
|
filename = str(filename)
|
||||||
|
env.env.file_path = filename
|
||||||
|
|
||||||
def init_fn(env, seed=seed):
|
|
||||||
# set seed
|
# set seed
|
||||||
assert isinstance(env, MultiStepWrapper)
|
assert isinstance(env, MultiStepWrapper)
|
||||||
env.seed(seed)
|
env.seed(seed)
|
||||||
@@ -131,6 +154,7 @@ class PushTImageRunner(BaseImageRunner):
|
|||||||
n_chunks = math.ceil(n_inits / n_envs)
|
n_chunks = math.ceil(n_inits / n_envs)
|
||||||
|
|
||||||
# allocate data
|
# allocate data
|
||||||
|
all_video_paths = [None] * n_inits
|
||||||
all_rewards = [None] * n_inits
|
all_rewards = [None] * n_inits
|
||||||
|
|
||||||
for chunk_idx in range(n_chunks):
|
for chunk_idx in range(n_chunks):
|
||||||
@@ -190,16 +214,39 @@ class PushTImageRunner(BaseImageRunner):
|
|||||||
pbar.update(action.shape[1])
|
pbar.update(action.shape[1])
|
||||||
pbar.close()
|
pbar.close()
|
||||||
|
|
||||||
|
all_video_paths[this_global_slice] = env.render()[this_local_slice]
|
||||||
all_rewards[this_global_slice] = env.call('get_attr', 'reward')[this_local_slice]
|
all_rewards[this_global_slice] = env.call('get_attr', 'reward')[this_local_slice]
|
||||||
# reset env state between evaluation calls
|
# clear out video buffer
|
||||||
_ = env.reset()
|
_ = env.reset()
|
||||||
|
|
||||||
# results reported in the paper are generated using the commented out
|
# log
|
||||||
# line below, which would only report and average metrics from the
|
max_rewards = collections.defaultdict(list)
|
||||||
# first n_envs initial conditions and seeds. We keep the full n_inits
|
log_data = dict()
|
||||||
# behavior here.
|
# results reported in the paper are generated using the commented out line below
|
||||||
return summarize_rollout_metrics(
|
# which will only report and average metrics from first n_envs initial condition and seeds
|
||||||
env_seeds=self.env_seeds[:n_inits],
|
# fortunately this won't invalidate our conclusion since
|
||||||
env_prefixs=self.env_prefixs[:n_inits],
|
# 1. This bug only affects the variance of metrics, not their mean
|
||||||
all_rewards=all_rewards[:n_inits],
|
# 2. All baseline methods are evaluated using the same code
|
||||||
)
|
# to completely reproduce reported numbers, uncomment this line:
|
||||||
|
# for i in range(len(self.env_fns)):
|
||||||
|
# and comment out this line
|
||||||
|
for i in range(n_inits):
|
||||||
|
seed = self.env_seeds[i]
|
||||||
|
prefix = self.env_prefixs[i]
|
||||||
|
max_reward = np.max(all_rewards[i])
|
||||||
|
max_rewards[prefix].append(max_reward)
|
||||||
|
log_data[prefix+f'sim_max_reward_{seed}'] = max_reward
|
||||||
|
|
||||||
|
# visualize sim
|
||||||
|
video_path = all_video_paths[i]
|
||||||
|
if video_path is not None:
|
||||||
|
sim_video = wandb.Video(video_path)
|
||||||
|
log_data[prefix+f'sim_video_{seed}'] = sim_video
|
||||||
|
|
||||||
|
# log aggregate metrics
|
||||||
|
for prefix, value in max_rewards.items():
|
||||||
|
name = prefix+'mean_score'
|
||||||
|
value = np.mean(value)
|
||||||
|
log_data[name] = value
|
||||||
|
|
||||||
|
return log_data
|
||||||
|
|||||||
@@ -8,8 +8,7 @@ import dill
|
|||||||
import math
|
import math
|
||||||
import wandb.sdk.data_types.video as wv
|
import wandb.sdk.data_types.video as wv
|
||||||
from diffusion_policy.env.pusht.pusht_keypoints_env import PushTKeypointsEnv
|
from diffusion_policy.env.pusht.pusht_keypoints_env import PushTKeypointsEnv
|
||||||
from diffusion_policy.gym_util.async_vector_env import AsyncVectorEnv
|
from diffusion_policy.gym_util.sync_vector_env import SyncVectorEnv
|
||||||
# from diffusion_policy.gym_util.sync_vector_env import SyncVectorEnv
|
|
||||||
from diffusion_policy.gym_util.multistep_wrapper import MultiStepWrapper
|
from diffusion_policy.gym_util.multistep_wrapper import MultiStepWrapper
|
||||||
from diffusion_policy.gym_util.video_recording_wrapper import VideoRecordingWrapper, VideoRecorder
|
from diffusion_policy.gym_util.video_recording_wrapper import VideoRecordingWrapper, VideoRecorder
|
||||||
|
|
||||||
@@ -133,7 +132,7 @@ class PushTKeypointsRunner(BaseLowdimRunner):
|
|||||||
env_prefixs.append('test/')
|
env_prefixs.append('test/')
|
||||||
env_init_fn_dills.append(dill.dumps(init_fn))
|
env_init_fn_dills.append(dill.dumps(init_fn))
|
||||||
|
|
||||||
env = AsyncVectorEnv(env_fns)
|
env = SyncVectorEnv(env_fns)
|
||||||
|
|
||||||
# test env
|
# test env
|
||||||
# env.reset(seed=env_seeds)
|
# env.reset(seed=env_seeds)
|
||||||
|
|||||||
@@ -1,247 +0,0 @@
|
|||||||
import math
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from torch import Tensor
|
|
||||||
|
|
||||||
|
|
||||||
class RMSNorm(nn.Module):
|
|
||||||
def __init__(self, dim: int, eps: float = 1e-6) -> None:
|
|
||||||
super().__init__()
|
|
||||||
self.eps = eps
|
|
||||||
self.weight = nn.Parameter(torch.ones(dim))
|
|
||||||
|
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
|
||||||
rms = torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps)
|
|
||||||
return (x.float() * rms).to(x.dtype) * self.weight
|
|
||||||
|
|
||||||
|
|
||||||
class RMSNormNoWeight(nn.Module):
|
|
||||||
def __init__(self, eps: float = 1e-6) -> None:
|
|
||||||
super().__init__()
|
|
||||||
self.eps = eps
|
|
||||||
|
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
|
||||||
rms = torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps)
|
|
||||||
return (x.float() * rms).to(x.dtype)
|
|
||||||
|
|
||||||
|
|
||||||
def precompute_rope_freqs(
|
|
||||||
dim: int,
|
|
||||||
max_seq_len: int,
|
|
||||||
theta: float = 10000.0,
|
|
||||||
device: Optional[torch.device] = None,
|
|
||||||
) -> Tensor:
|
|
||||||
if dim % 2 != 0:
|
|
||||||
raise ValueError(f'RoPE requires an even head dimension, got {dim}.')
|
|
||||||
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, device=device).float() / dim))
|
|
||||||
positions = torch.arange(max_seq_len, device=device).float()
|
|
||||||
angles = torch.outer(positions, freqs)
|
|
||||||
return torch.polar(torch.ones_like(angles), angles)
|
|
||||||
|
|
||||||
|
|
||||||
def apply_rope(x: Tensor, freqs: Tensor) -> Tensor:
|
|
||||||
x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
|
|
||||||
freqs = freqs.unsqueeze(0).unsqueeze(2)
|
|
||||||
x_rotated = x_complex * freqs
|
|
||||||
return torch.view_as_real(x_rotated).reshape_as(x).to(x.dtype)
|
|
||||||
|
|
||||||
|
|
||||||
class GroupedQuerySelfAttention(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
d_model: int,
|
|
||||||
n_heads: int,
|
|
||||||
n_kv_heads: int,
|
|
||||||
dropout: float = 0.0,
|
|
||||||
) -> None:
|
|
||||||
super().__init__()
|
|
||||||
if d_model % n_heads != 0:
|
|
||||||
raise ValueError(f'd_model={d_model} must be divisible by n_heads={n_heads}.')
|
|
||||||
if n_heads % n_kv_heads != 0:
|
|
||||||
raise ValueError(f'n_heads={n_heads} must be divisible by n_kv_heads={n_kv_heads}.')
|
|
||||||
|
|
||||||
self.d_model = d_model
|
|
||||||
self.n_heads = n_heads
|
|
||||||
self.n_kv_heads = n_kv_heads
|
|
||||||
self.n_kv_groups = n_heads // n_kv_heads
|
|
||||||
self.d_head = d_model // n_heads
|
|
||||||
self.attn_dropout = nn.Dropout(dropout)
|
|
||||||
self.out_dropout = nn.Dropout(dropout)
|
|
||||||
|
|
||||||
self.w_q = nn.Linear(d_model, n_heads * self.d_head, bias=False)
|
|
||||||
self.w_k = nn.Linear(d_model, n_kv_heads * self.d_head, bias=False)
|
|
||||||
self.w_v = nn.Linear(d_model, n_kv_heads * self.d_head, bias=False)
|
|
||||||
self.w_o = nn.Linear(n_heads * self.d_head, d_model, bias=False)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
x: Tensor,
|
|
||||||
rope_freqs: Tensor,
|
|
||||||
mask: Optional[Tensor] = None,
|
|
||||||
) -> Tensor:
|
|
||||||
batch_size, seq_len, _ = x.shape
|
|
||||||
|
|
||||||
q = self.w_q(x).view(batch_size, seq_len, self.n_heads, self.d_head)
|
|
||||||
k = self.w_k(x).view(batch_size, seq_len, self.n_kv_heads, self.d_head)
|
|
||||||
v = self.w_v(x).view(batch_size, seq_len, self.n_kv_heads, self.d_head)
|
|
||||||
|
|
||||||
q = apply_rope(q, rope_freqs)
|
|
||||||
k = apply_rope(k, rope_freqs)
|
|
||||||
|
|
||||||
if self.n_kv_heads != self.n_heads:
|
|
||||||
k = k.unsqueeze(3).expand(
|
|
||||||
batch_size, seq_len, self.n_kv_heads, self.n_kv_groups, self.d_head
|
|
||||||
)
|
|
||||||
k = k.reshape(batch_size, seq_len, self.n_heads, self.d_head)
|
|
||||||
v = v.unsqueeze(3).expand(
|
|
||||||
batch_size, seq_len, self.n_kv_heads, self.n_kv_groups, self.d_head
|
|
||||||
)
|
|
||||||
v = v.reshape(batch_size, seq_len, self.n_heads, self.d_head)
|
|
||||||
|
|
||||||
q = q.transpose(1, 2)
|
|
||||||
k = k.transpose(1, 2)
|
|
||||||
v = v.transpose(1, 2)
|
|
||||||
|
|
||||||
scale = 1.0 / math.sqrt(self.d_head)
|
|
||||||
attn_weights = torch.matmul(q, k.transpose(-2, -1)) * scale
|
|
||||||
if mask is not None:
|
|
||||||
attn_weights = attn_weights + mask
|
|
||||||
attn_weights = F.softmax(attn_weights, dim=-1)
|
|
||||||
attn_weights = self.attn_dropout(attn_weights)
|
|
||||||
|
|
||||||
out = torch.matmul(attn_weights, v)
|
|
||||||
out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
|
|
||||||
return self.out_dropout(self.w_o(out))
|
|
||||||
|
|
||||||
|
|
||||||
class SwiGLUFFN(nn.Module):
|
|
||||||
def __init__(self, d_model: int, dropout: float = 0.0, mult: float = 2.667) -> None:
|
|
||||||
super().__init__()
|
|
||||||
raw = int(mult * d_model)
|
|
||||||
d_ff = ((raw + 7) // 8) * 8
|
|
||||||
self.w_gate = nn.Linear(d_model, d_ff, bias=False)
|
|
||||||
self.w_up = nn.Linear(d_model, d_ff, bias=False)
|
|
||||||
self.w_down = nn.Linear(d_ff, d_model, bias=False)
|
|
||||||
self.dropout = nn.Dropout(dropout)
|
|
||||||
|
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
|
||||||
return self.dropout(self.w_down(F.silu(self.w_gate(x)) * self.w_up(x)))
|
|
||||||
|
|
||||||
|
|
||||||
class AttnResOperator(nn.Module):
|
|
||||||
def __init__(self, d_model: int, eps: float = 1e-6) -> None:
|
|
||||||
super().__init__()
|
|
||||||
self.pseudo_query = nn.Parameter(torch.zeros(d_model))
|
|
||||||
self.key_norm = RMSNormNoWeight(eps=eps)
|
|
||||||
|
|
||||||
def forward(self, sources: Tensor) -> Tensor:
|
|
||||||
keys = self.key_norm(sources)
|
|
||||||
logits = torch.einsum('d,nbtd->nbt', self.pseudo_query, keys)
|
|
||||||
weights = F.softmax(logits, dim=0)
|
|
||||||
return torch.einsum('nbt,nbtd->btd', weights, sources)
|
|
||||||
|
|
||||||
|
|
||||||
class AttnResSubLayer(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
d_model: int,
|
|
||||||
n_heads: int,
|
|
||||||
n_kv_heads: int,
|
|
||||||
dropout: float,
|
|
||||||
ffn_mult: float,
|
|
||||||
eps: float,
|
|
||||||
is_attention: bool,
|
|
||||||
) -> None:
|
|
||||||
super().__init__()
|
|
||||||
self.norm = RMSNorm(d_model, eps=eps)
|
|
||||||
self.attn_res = AttnResOperator(d_model, eps=eps)
|
|
||||||
self.is_attention = is_attention
|
|
||||||
if is_attention:
|
|
||||||
self.fn = GroupedQuerySelfAttention(
|
|
||||||
d_model=d_model,
|
|
||||||
n_heads=n_heads,
|
|
||||||
n_kv_heads=n_kv_heads,
|
|
||||||
dropout=dropout,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.fn = SwiGLUFFN(d_model=d_model, dropout=dropout, mult=ffn_mult)
|
|
||||||
|
|
||||||
def forward(self, sources: Tensor, rope_freqs: Tensor, mask: Optional[Tensor] = None) -> Tensor:
|
|
||||||
h = self.attn_res(sources)
|
|
||||||
normed = self.norm(h)
|
|
||||||
if self.is_attention:
|
|
||||||
return self.fn(normed, rope_freqs, mask)
|
|
||||||
return self.fn(normed)
|
|
||||||
|
|
||||||
|
|
||||||
class AttnResTransformerBackbone(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
d_model: int,
|
|
||||||
n_blocks: int,
|
|
||||||
n_heads: int,
|
|
||||||
n_kv_heads: int,
|
|
||||||
max_seq_len: int,
|
|
||||||
dropout: float = 0.0,
|
|
||||||
ffn_mult: float = 2.667,
|
|
||||||
eps: float = 1e-6,
|
|
||||||
rope_theta: float = 10000.0,
|
|
||||||
causal_attn: bool = False,
|
|
||||||
) -> None:
|
|
||||||
super().__init__()
|
|
||||||
self.causal_attn = causal_attn
|
|
||||||
self.layers = nn.ModuleList()
|
|
||||||
for _ in range(n_blocks):
|
|
||||||
self.layers.append(
|
|
||||||
AttnResSubLayer(
|
|
||||||
d_model=d_model,
|
|
||||||
n_heads=n_heads,
|
|
||||||
n_kv_heads=n_kv_heads,
|
|
||||||
dropout=dropout,
|
|
||||||
ffn_mult=ffn_mult,
|
|
||||||
eps=eps,
|
|
||||||
is_attention=True,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
self.layers.append(
|
|
||||||
AttnResSubLayer(
|
|
||||||
d_model=d_model,
|
|
||||||
n_heads=n_heads,
|
|
||||||
n_kv_heads=n_kv_heads,
|
|
||||||
dropout=dropout,
|
|
||||||
ffn_mult=ffn_mult,
|
|
||||||
eps=eps,
|
|
||||||
is_attention=False,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
rope_freqs = precompute_rope_freqs(
|
|
||||||
dim=d_model // n_heads,
|
|
||||||
max_seq_len=max_seq_len,
|
|
||||||
theta=rope_theta,
|
|
||||||
)
|
|
||||||
self.register_buffer('rope_freqs', rope_freqs, persistent=False)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _build_causal_mask(seq_len: int, device: torch.device) -> Tensor:
|
|
||||||
mask = torch.full((seq_len, seq_len), float('-inf'), device=device)
|
|
||||||
mask = torch.triu(mask, diagonal=1)
|
|
||||||
return mask.unsqueeze(0).unsqueeze(0)
|
|
||||||
|
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
|
||||||
seq_len = x.shape[1]
|
|
||||||
rope_freqs = self.rope_freqs[:seq_len]
|
|
||||||
mask = None
|
|
||||||
if self.causal_attn:
|
|
||||||
mask = self._build_causal_mask(seq_len, x.device)
|
|
||||||
|
|
||||||
layer_outputs = [x]
|
|
||||||
for layer in self.layers:
|
|
||||||
sources = torch.stack(layer_outputs, dim=0)
|
|
||||||
output = layer(sources, rope_freqs, mask)
|
|
||||||
layer_outputs.append(output)
|
|
||||||
|
|
||||||
return torch.stack(layer_outputs, dim=0).sum(dim=0)
|
|
||||||
@@ -1,383 +0,0 @@
|
|||||||
from typing import Optional, Tuple, Union
|
|
||||||
import logging
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
from diffusion_policy.model.common.module_attr_mixin import ModuleAttrMixin
|
|
||||||
from diffusion_policy.model.diffusion.attnres_transformer_components import (
|
|
||||||
AttnResOperator,
|
|
||||||
AttnResSubLayer,
|
|
||||||
AttnResTransformerBackbone,
|
|
||||||
GroupedQuerySelfAttention,
|
|
||||||
RMSNorm,
|
|
||||||
RMSNormNoWeight,
|
|
||||||
SwiGLUFFN,
|
|
||||||
)
|
|
||||||
from diffusion_policy.model.diffusion.positional_embedding import SinusoidalPosEmb
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class IMFTransformerForDiffusion(ModuleAttrMixin):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
input_dim: int,
|
|
||||||
output_dim: int,
|
|
||||||
horizon: int,
|
|
||||||
n_obs_steps: int = None,
|
|
||||||
cond_dim: int = 0,
|
|
||||||
n_layer: int = 12,
|
|
||||||
n_head: int = 1,
|
|
||||||
n_emb: int = 768,
|
|
||||||
p_drop_emb: float = 0.1,
|
|
||||||
p_drop_attn: float = 0.1,
|
|
||||||
causal_attn: bool = False,
|
|
||||||
time_as_cond: bool = True,
|
|
||||||
obs_as_cond: bool = False,
|
|
||||||
n_cond_layers: int = 0,
|
|
||||||
backbone_type: str = 'vanilla',
|
|
||||||
n_kv_head: int = 1,
|
|
||||||
attn_res_ffn_mult: float = 2.667,
|
|
||||||
attn_res_eps: float = 1e-6,
|
|
||||||
attn_res_rope_theta: float = 10000.0,
|
|
||||||
) -> None:
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
assert n_head == 1, 'IMFTransformerForDiffusion currently supports single-head attention only.'
|
|
||||||
if n_obs_steps is None:
|
|
||||||
n_obs_steps = horizon
|
|
||||||
|
|
||||||
self.backbone_type = backbone_type
|
|
||||||
|
|
||||||
T = horizon
|
|
||||||
T_cond = 2
|
|
||||||
if not time_as_cond:
|
|
||||||
T += 2
|
|
||||||
T_cond -= 2
|
|
||||||
obs_as_cond = cond_dim > 0
|
|
||||||
if obs_as_cond:
|
|
||||||
assert time_as_cond
|
|
||||||
T_cond += n_obs_steps
|
|
||||||
|
|
||||||
self.input_emb = nn.Linear(input_dim, n_emb)
|
|
||||||
self.drop = nn.Dropout(p_drop_emb)
|
|
||||||
self.time_emb = SinusoidalPosEmb(n_emb)
|
|
||||||
self.cond_obs_emb = nn.Linear(cond_dim, n_emb) if obs_as_cond else None
|
|
||||||
self.time_token_proj = None
|
|
||||||
self.cond_pos_emb = None
|
|
||||||
self.pos_emb = None
|
|
||||||
self.encoder = None
|
|
||||||
self.decoder = None
|
|
||||||
self.attnres_backbone = None
|
|
||||||
encoder_only = False
|
|
||||||
|
|
||||||
if backbone_type == 'attnres_full':
|
|
||||||
if not time_as_cond:
|
|
||||||
raise ValueError('attnres_full backbone requires time_as_cond=True.')
|
|
||||||
if n_cond_layers != 0:
|
|
||||||
raise ValueError('attnres_full backbone does not support n_cond_layers > 0.')
|
|
||||||
|
|
||||||
self.time_token_proj = nn.Linear(n_emb, n_emb)
|
|
||||||
self.attnres_backbone = AttnResTransformerBackbone(
|
|
||||||
d_model=n_emb,
|
|
||||||
n_blocks=n_layer,
|
|
||||||
n_heads=n_head,
|
|
||||||
n_kv_heads=n_kv_head,
|
|
||||||
max_seq_len=T + T_cond,
|
|
||||||
dropout=p_drop_attn,
|
|
||||||
ffn_mult=attn_res_ffn_mult,
|
|
||||||
eps=attn_res_eps,
|
|
||||||
rope_theta=attn_res_rope_theta,
|
|
||||||
causal_attn=causal_attn,
|
|
||||||
)
|
|
||||||
self.ln_f = RMSNorm(n_emb, eps=attn_res_eps)
|
|
||||||
else:
|
|
||||||
self.pos_emb = nn.Parameter(torch.zeros(1, T, n_emb))
|
|
||||||
if T_cond > 0:
|
|
||||||
self.cond_pos_emb = nn.Parameter(torch.zeros(1, T_cond, n_emb))
|
|
||||||
if n_cond_layers > 0:
|
|
||||||
encoder_layer = nn.TransformerEncoderLayer(
|
|
||||||
d_model=n_emb,
|
|
||||||
nhead=n_head,
|
|
||||||
dim_feedforward=4 * n_emb,
|
|
||||||
dropout=p_drop_attn,
|
|
||||||
activation='gelu',
|
|
||||||
batch_first=True,
|
|
||||||
norm_first=True,
|
|
||||||
)
|
|
||||||
self.encoder = nn.TransformerEncoder(
|
|
||||||
encoder_layer=encoder_layer,
|
|
||||||
num_layers=n_cond_layers,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.encoder = nn.Sequential(
|
|
||||||
nn.Linear(n_emb, 4 * n_emb),
|
|
||||||
nn.Mish(),
|
|
||||||
nn.Linear(4 * n_emb, n_emb),
|
|
||||||
)
|
|
||||||
|
|
||||||
decoder_layer = nn.TransformerDecoderLayer(
|
|
||||||
d_model=n_emb,
|
|
||||||
nhead=n_head,
|
|
||||||
dim_feedforward=4 * n_emb,
|
|
||||||
dropout=p_drop_attn,
|
|
||||||
activation='gelu',
|
|
||||||
batch_first=True,
|
|
||||||
norm_first=True,
|
|
||||||
)
|
|
||||||
self.decoder = nn.TransformerDecoder(
|
|
||||||
decoder_layer=decoder_layer,
|
|
||||||
num_layers=n_layer,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
encoder_only = True
|
|
||||||
encoder_layer = nn.TransformerEncoderLayer(
|
|
||||||
d_model=n_emb,
|
|
||||||
nhead=n_head,
|
|
||||||
dim_feedforward=4 * n_emb,
|
|
||||||
dropout=p_drop_attn,
|
|
||||||
activation='gelu',
|
|
||||||
batch_first=True,
|
|
||||||
norm_first=True,
|
|
||||||
)
|
|
||||||
self.encoder = nn.TransformerEncoder(
|
|
||||||
encoder_layer=encoder_layer,
|
|
||||||
num_layers=n_layer,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.ln_f = nn.LayerNorm(n_emb)
|
|
||||||
|
|
||||||
if causal_attn and backbone_type != 'attnres_full':
|
|
||||||
sz = T
|
|
||||||
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
|
|
||||||
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
|
|
||||||
self.register_buffer('mask', mask)
|
|
||||||
|
|
||||||
if time_as_cond and obs_as_cond:
|
|
||||||
S = T_cond
|
|
||||||
t_idx, s_idx = torch.meshgrid(
|
|
||||||
torch.arange(T),
|
|
||||||
torch.arange(S),
|
|
||||||
indexing='ij',
|
|
||||||
)
|
|
||||||
mask = t_idx >= (s_idx - 2)
|
|
||||||
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
|
|
||||||
self.register_buffer('memory_mask', mask)
|
|
||||||
else:
|
|
||||||
self.memory_mask = None
|
|
||||||
else:
|
|
||||||
self.mask = None
|
|
||||||
self.memory_mask = None
|
|
||||||
|
|
||||||
self.head = nn.Linear(n_emb, output_dim)
|
|
||||||
|
|
||||||
self.T = T
|
|
||||||
self.T_cond = T_cond
|
|
||||||
self.horizon = horizon
|
|
||||||
self.time_as_cond = time_as_cond
|
|
||||||
self.obs_as_cond = obs_as_cond
|
|
||||||
self.encoder_only = encoder_only
|
|
||||||
|
|
||||||
self.apply(self._init_weights)
|
|
||||||
logger.info(
|
|
||||||
'number of parameters: %e',
|
|
||||||
sum(p.numel() for p in self.parameters()),
|
|
||||||
)
|
|
||||||
|
|
||||||
def _init_weights(self, module):
|
|
||||||
ignore_types = (
|
|
||||||
nn.Dropout,
|
|
||||||
SinusoidalPosEmb,
|
|
||||||
nn.TransformerEncoderLayer,
|
|
||||||
nn.TransformerDecoderLayer,
|
|
||||||
nn.TransformerEncoder,
|
|
||||||
nn.TransformerDecoder,
|
|
||||||
nn.ModuleList,
|
|
||||||
nn.Mish,
|
|
||||||
nn.Sequential,
|
|
||||||
AttnResTransformerBackbone,
|
|
||||||
AttnResSubLayer,
|
|
||||||
GroupedQuerySelfAttention,
|
|
||||||
SwiGLUFFN,
|
|
||||||
RMSNormNoWeight,
|
|
||||||
)
|
|
||||||
if isinstance(module, (nn.Linear, nn.Embedding)):
|
|
||||||
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
|
||||||
if isinstance(module, nn.Linear) and module.bias is not None:
|
|
||||||
torch.nn.init.zeros_(module.bias)
|
|
||||||
elif isinstance(module, nn.MultiheadAttention):
|
|
||||||
weight_names = ['in_proj_weight', 'q_proj_weight', 'k_proj_weight', 'v_proj_weight']
|
|
||||||
for name in weight_names:
|
|
||||||
weight = getattr(module, name)
|
|
||||||
if weight is not None:
|
|
||||||
torch.nn.init.normal_(weight, mean=0.0, std=0.02)
|
|
||||||
|
|
||||||
bias_names = ['in_proj_bias', 'bias_k', 'bias_v']
|
|
||||||
for name in bias_names:
|
|
||||||
bias = getattr(module, name)
|
|
||||||
if bias is not None:
|
|
||||||
torch.nn.init.zeros_(bias)
|
|
||||||
elif isinstance(module, (nn.LayerNorm, RMSNorm)):
|
|
||||||
if getattr(module, 'bias', None) is not None:
|
|
||||||
torch.nn.init.zeros_(module.bias)
|
|
||||||
torch.nn.init.ones_(module.weight)
|
|
||||||
elif isinstance(module, AttnResOperator):
|
|
||||||
torch.nn.init.zeros_(module.pseudo_query)
|
|
||||||
elif isinstance(module, IMFTransformerForDiffusion):
|
|
||||||
if module.pos_emb is not None:
|
|
||||||
torch.nn.init.normal_(module.pos_emb, mean=0.0, std=0.02)
|
|
||||||
if module.cond_pos_emb is not None:
|
|
||||||
torch.nn.init.normal_(module.cond_pos_emb, mean=0.0, std=0.02)
|
|
||||||
elif isinstance(module, ignore_types):
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
raise RuntimeError(f'Unaccounted module {module}')
|
|
||||||
|
|
||||||
def get_optim_groups(self, weight_decay: float = 1e-3):
|
|
||||||
decay = set()
|
|
||||||
no_decay = set()
|
|
||||||
whitelist_weight_modules = (torch.nn.Linear, torch.nn.MultiheadAttention)
|
|
||||||
blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding, RMSNorm)
|
|
||||||
for mn, m in self.named_modules():
|
|
||||||
for pn, _ in m.named_parameters(recurse=False):
|
|
||||||
fpn = '%s.%s' % (mn, pn) if mn else pn
|
|
||||||
|
|
||||||
if pn.endswith('bias'):
|
|
||||||
no_decay.add(fpn)
|
|
||||||
elif pn.startswith('bias'):
|
|
||||||
no_decay.add(fpn)
|
|
||||||
elif pn == 'pseudo_query':
|
|
||||||
no_decay.add(fpn)
|
|
||||||
elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
|
|
||||||
decay.add(fpn)
|
|
||||||
elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
|
|
||||||
no_decay.add(fpn)
|
|
||||||
|
|
||||||
if self.pos_emb is not None:
|
|
||||||
no_decay.add('pos_emb')
|
|
||||||
no_decay.add('_dummy_variable')
|
|
||||||
if self.cond_pos_emb is not None:
|
|
||||||
no_decay.add('cond_pos_emb')
|
|
||||||
|
|
||||||
param_dict = {pn: p for pn, p in self.named_parameters()}
|
|
||||||
inter_params = decay & no_decay
|
|
||||||
union_params = decay | no_decay
|
|
||||||
assert len(inter_params) == 0, f'parameters {inter_params} made it into both decay/no_decay sets!'
|
|
||||||
assert len(param_dict.keys() - union_params) == 0, (
|
|
||||||
f'parameters {param_dict.keys() - union_params} were not separated into either decay/no_decay sets!'
|
|
||||||
)
|
|
||||||
|
|
||||||
optim_groups = [
|
|
||||||
{
|
|
||||||
'params': [param_dict[pn] for pn in sorted(list(decay))],
|
|
||||||
'weight_decay': weight_decay,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
'params': [param_dict[pn] for pn in sorted(list(no_decay))],
|
|
||||||
'weight_decay': 0.0,
|
|
||||||
},
|
|
||||||
]
|
|
||||||
return optim_groups
|
|
||||||
|
|
||||||
def configure_optimizers(
|
|
||||||
self,
|
|
||||||
learning_rate: float = 1e-4,
|
|
||||||
weight_decay: float = 1e-3,
|
|
||||||
betas: Tuple[float, float] = (0.9, 0.95),
|
|
||||||
):
|
|
||||||
optim_groups = self.get_optim_groups(weight_decay=weight_decay)
|
|
||||||
optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas)
|
|
||||||
return optimizer
|
|
||||||
|
|
||||||
def _prepare_time_input(self, value: Union[torch.Tensor, float, int], sample: torch.Tensor) -> torch.Tensor:
|
|
||||||
if not torch.is_tensor(value):
|
|
||||||
value = torch.tensor([value], dtype=sample.dtype, device=sample.device)
|
|
||||||
elif value.ndim == 0:
|
|
||||||
value = value[None].to(device=sample.device, dtype=sample.dtype)
|
|
||||||
else:
|
|
||||||
value = value.to(device=sample.device, dtype=sample.dtype)
|
|
||||||
return value.expand(sample.shape[0])
|
|
||||||
|
|
||||||
def _forward_attnres_full(
|
|
||||||
self,
|
|
||||||
sample: torch.Tensor,
|
|
||||||
r: torch.Tensor,
|
|
||||||
t: torch.Tensor,
|
|
||||||
cond: Optional[torch.Tensor] = None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
sample_tokens = self.input_emb(sample)
|
|
||||||
token_parts = [
|
|
||||||
self.time_token_proj(self.time_emb(r)).unsqueeze(1),
|
|
||||||
self.time_token_proj(self.time_emb(t)).unsqueeze(1),
|
|
||||||
]
|
|
||||||
if self.obs_as_cond:
|
|
||||||
if cond is None:
|
|
||||||
raise ValueError('cond is required when obs_as_cond=True for attnres_full backbone.')
|
|
||||||
token_parts.append(self.cond_obs_emb(cond))
|
|
||||||
token_parts.append(sample_tokens)
|
|
||||||
x = torch.cat(token_parts, dim=1)
|
|
||||||
x = self.drop(x)
|
|
||||||
x = self.attnres_backbone(x)
|
|
||||||
x = x[:, -sample_tokens.shape[1] :, :]
|
|
||||||
return x
|
|
||||||
|
|
||||||
def _forward_vanilla(
|
|
||||||
self,
|
|
||||||
sample: torch.Tensor,
|
|
||||||
r: torch.Tensor,
|
|
||||||
t: torch.Tensor,
|
|
||||||
cond: Optional[torch.Tensor] = None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
r_emb = self.time_emb(r).unsqueeze(1)
|
|
||||||
t_emb = self.time_emb(t).unsqueeze(1)
|
|
||||||
input_emb = self.input_emb(sample)
|
|
||||||
|
|
||||||
if self.encoder_only:
|
|
||||||
token_embeddings = torch.cat([r_emb, t_emb, input_emb], dim=1)
|
|
||||||
token_count = token_embeddings.shape[1]
|
|
||||||
position_embeddings = self.pos_emb[:, :token_count, :]
|
|
||||||
x = self.drop(token_embeddings + position_embeddings)
|
|
||||||
x = self.encoder(src=x, mask=self.mask)
|
|
||||||
x = x[:, 2:, :]
|
|
||||||
else:
|
|
||||||
cond_embeddings = torch.cat([r_emb, t_emb], dim=1)
|
|
||||||
if self.obs_as_cond:
|
|
||||||
cond_obs_emb = self.cond_obs_emb(cond)
|
|
||||||
cond_embeddings = torch.cat([cond_embeddings, cond_obs_emb], dim=1)
|
|
||||||
token_count = cond_embeddings.shape[1]
|
|
||||||
position_embeddings = self.cond_pos_emb[:, :token_count, :]
|
|
||||||
x = self.drop(cond_embeddings + position_embeddings)
|
|
||||||
x = self.encoder(x)
|
|
||||||
memory = x
|
|
||||||
|
|
||||||
token_embeddings = input_emb
|
|
||||||
token_count = token_embeddings.shape[1]
|
|
||||||
position_embeddings = self.pos_emb[:, :token_count, :]
|
|
||||||
x = self.drop(token_embeddings + position_embeddings)
|
|
||||||
x = self.decoder(
|
|
||||||
tgt=x,
|
|
||||||
memory=memory,
|
|
||||||
tgt_mask=self.mask,
|
|
||||||
memory_mask=self.memory_mask,
|
|
||||||
)
|
|
||||||
return x
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
sample: torch.Tensor,
|
|
||||||
r: Union[torch.Tensor, float, int],
|
|
||||||
t: Union[torch.Tensor, float, int],
|
|
||||||
cond: Optional[torch.Tensor] = None,
|
|
||||||
):
|
|
||||||
r = self._prepare_time_input(r, sample)
|
|
||||||
t = self._prepare_time_input(t, sample)
|
|
||||||
|
|
||||||
if self.backbone_type == 'attnres_full':
|
|
||||||
x = self._forward_attnres_full(sample, r, t, cond=cond)
|
|
||||||
else:
|
|
||||||
x = self._forward_vanilla(sample, r, t, cond=cond)
|
|
||||||
|
|
||||||
x = self.ln_f(x)
|
|
||||||
x = self.head(x)
|
|
||||||
return x
|
|
||||||
@@ -0,0 +1,265 @@
|
|||||||
|
from typing import Optional, Tuple, Union
|
||||||
|
import logging
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from diffusion_policy.model.common.module_attr_mixin import ModuleAttrMixin
|
||||||
|
from diffusion_policy.model.diffusion.positional_embedding import SinusoidalPosEmb
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class PMFTransformerForDiffusion(ModuleAttrMixin):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
input_dim: int,
|
||||||
|
output_dim: int,
|
||||||
|
horizon: int,
|
||||||
|
n_obs_steps: Optional[int] = None,
|
||||||
|
cond_dim: int = 0,
|
||||||
|
n_layer: int = 12,
|
||||||
|
n_head: int = 12,
|
||||||
|
n_emb: int = 768,
|
||||||
|
p_drop_emb: float = 0.1,
|
||||||
|
p_drop_attn: float = 0.1,
|
||||||
|
causal_attn: bool = False,
|
||||||
|
obs_as_cond: bool = False,
|
||||||
|
n_cond_layers: int = 0,
|
||||||
|
n_time_tokens: int = 4,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
if n_obs_steps is None:
|
||||||
|
n_obs_steps = horizon
|
||||||
|
if n_time_tokens < 1:
|
||||||
|
raise ValueError("n_time_tokens must be >= 1")
|
||||||
|
|
||||||
|
obs_as_cond = cond_dim > 0
|
||||||
|
T = horizon
|
||||||
|
n_global_cond_tokens = 2 * n_time_tokens
|
||||||
|
T_cond = n_global_cond_tokens + (n_obs_steps if obs_as_cond else 0)
|
||||||
|
|
||||||
|
self.input_emb = nn.Linear(input_dim, n_emb)
|
||||||
|
self.pos_emb = nn.Parameter(torch.zeros(1, T, n_emb))
|
||||||
|
self.drop = nn.Dropout(p_drop_emb)
|
||||||
|
|
||||||
|
self.t_emb = SinusoidalPosEmb(n_emb)
|
||||||
|
self.r_emb = SinusoidalPosEmb(n_emb)
|
||||||
|
self.t_tokens = nn.Parameter(torch.zeros(1, n_time_tokens, n_emb))
|
||||||
|
self.r_tokens = nn.Parameter(torch.zeros(1, n_time_tokens, n_emb))
|
||||||
|
self.cond_obs_emb = nn.Linear(cond_dim, n_emb) if obs_as_cond else None
|
||||||
|
self.cond_pos_emb = nn.Parameter(torch.zeros(1, T_cond, n_emb))
|
||||||
|
|
||||||
|
if n_cond_layers > 0:
|
||||||
|
encoder_layer = nn.TransformerEncoderLayer(
|
||||||
|
d_model=n_emb,
|
||||||
|
nhead=n_head,
|
||||||
|
dim_feedforward=4 * n_emb,
|
||||||
|
dropout=p_drop_attn,
|
||||||
|
activation="gelu",
|
||||||
|
batch_first=True,
|
||||||
|
norm_first=True,
|
||||||
|
)
|
||||||
|
self.encoder = nn.TransformerEncoder(
|
||||||
|
encoder_layer=encoder_layer,
|
||||||
|
num_layers=n_cond_layers,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.encoder = nn.Sequential(
|
||||||
|
nn.Linear(n_emb, 4 * n_emb),
|
||||||
|
nn.Mish(),
|
||||||
|
nn.Linear(4 * n_emb, n_emb),
|
||||||
|
)
|
||||||
|
|
||||||
|
decoder_layer = nn.TransformerDecoderLayer(
|
||||||
|
d_model=n_emb,
|
||||||
|
nhead=n_head,
|
||||||
|
dim_feedforward=4 * n_emb,
|
||||||
|
dropout=p_drop_attn,
|
||||||
|
activation="gelu",
|
||||||
|
batch_first=True,
|
||||||
|
norm_first=True,
|
||||||
|
)
|
||||||
|
self.decoder = nn.TransformerDecoder(
|
||||||
|
decoder_layer=decoder_layer,
|
||||||
|
num_layers=n_layer,
|
||||||
|
)
|
||||||
|
|
||||||
|
if causal_attn:
|
||||||
|
sz = T
|
||||||
|
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
|
||||||
|
mask = mask.float().masked_fill(mask == 0, float("-inf")).masked_fill(mask == 1, float(0.0))
|
||||||
|
self.register_buffer("mask", mask)
|
||||||
|
|
||||||
|
if obs_as_cond:
|
||||||
|
q_idx, c_idx = torch.meshgrid(
|
||||||
|
torch.arange(T),
|
||||||
|
torch.arange(T_cond),
|
||||||
|
indexing="ij",
|
||||||
|
)
|
||||||
|
obs_offset = n_global_cond_tokens
|
||||||
|
visible = c_idx < obs_offset
|
||||||
|
visible = visible | (q_idx >= (c_idx - obs_offset))
|
||||||
|
memory_mask = visible.float().masked_fill(~visible, float("-inf")).masked_fill(visible, float(0.0))
|
||||||
|
self.register_buffer("memory_mask", memory_mask)
|
||||||
|
else:
|
||||||
|
self.memory_mask = None
|
||||||
|
else:
|
||||||
|
self.mask = None
|
||||||
|
self.memory_mask = None
|
||||||
|
|
||||||
|
self.ln_f = nn.LayerNorm(n_emb)
|
||||||
|
self.head_u = nn.Linear(n_emb, output_dim)
|
||||||
|
self.head_v = nn.Linear(n_emb, output_dim)
|
||||||
|
|
||||||
|
self.T = T
|
||||||
|
self.T_cond = T_cond
|
||||||
|
self.horizon = horizon
|
||||||
|
self.n_obs_steps = n_obs_steps
|
||||||
|
self.obs_as_cond = obs_as_cond
|
||||||
|
self.n_global_cond_tokens = n_global_cond_tokens
|
||||||
|
self.n_time_tokens = n_time_tokens
|
||||||
|
|
||||||
|
self.apply(self._init_weights)
|
||||||
|
logger.info(
|
||||||
|
"number of parameters: %e", sum(p.numel() for p in self.parameters())
|
||||||
|
)
|
||||||
|
|
||||||
|
def _init_weights(self, module):
|
||||||
|
ignore_types = (
|
||||||
|
nn.Dropout,
|
||||||
|
SinusoidalPosEmb,
|
||||||
|
nn.TransformerEncoderLayer,
|
||||||
|
nn.TransformerDecoderLayer,
|
||||||
|
nn.TransformerEncoder,
|
||||||
|
nn.TransformerDecoder,
|
||||||
|
nn.ModuleList,
|
||||||
|
nn.Mish,
|
||||||
|
nn.Sequential,
|
||||||
|
)
|
||||||
|
if isinstance(module, (nn.Linear, nn.Embedding)):
|
||||||
|
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
||||||
|
if isinstance(module, nn.Linear) and module.bias is not None:
|
||||||
|
torch.nn.init.zeros_(module.bias)
|
||||||
|
elif isinstance(module, nn.MultiheadAttention):
|
||||||
|
for name in ("in_proj_weight", "q_proj_weight", "k_proj_weight", "v_proj_weight"):
|
||||||
|
weight = getattr(module, name)
|
||||||
|
if weight is not None:
|
||||||
|
torch.nn.init.normal_(weight, mean=0.0, std=0.02)
|
||||||
|
for name in ("in_proj_bias", "bias_k", "bias_v"):
|
||||||
|
bias = getattr(module, name)
|
||||||
|
if bias is not None:
|
||||||
|
torch.nn.init.zeros_(bias)
|
||||||
|
elif isinstance(module, nn.LayerNorm):
|
||||||
|
torch.nn.init.zeros_(module.bias)
|
||||||
|
torch.nn.init.ones_(module.weight)
|
||||||
|
elif isinstance(module, PMFTransformerForDiffusion):
|
||||||
|
torch.nn.init.normal_(module.pos_emb, mean=0.0, std=0.02)
|
||||||
|
torch.nn.init.normal_(module.cond_pos_emb, mean=0.0, std=0.02)
|
||||||
|
torch.nn.init.normal_(module.t_tokens, mean=0.0, std=0.02)
|
||||||
|
torch.nn.init.normal_(module.r_tokens, mean=0.0, std=0.02)
|
||||||
|
elif isinstance(module, ignore_types):
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
raise RuntimeError("Unaccounted module {}".format(module))
|
||||||
|
|
||||||
|
def get_optim_groups(self, weight_decay: float = 1e-3):
|
||||||
|
decay = set()
|
||||||
|
no_decay = set()
|
||||||
|
whitelist_weight_modules = (torch.nn.Linear, torch.nn.MultiheadAttention)
|
||||||
|
blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
|
||||||
|
for mn, m in self.named_modules():
|
||||||
|
for pn, _ in m.named_parameters():
|
||||||
|
fpn = "%s.%s" % (mn, pn) if mn else pn
|
||||||
|
if pn.endswith("bias"):
|
||||||
|
no_decay.add(fpn)
|
||||||
|
elif pn.startswith("bias"):
|
||||||
|
no_decay.add(fpn)
|
||||||
|
elif pn.endswith("weight") and isinstance(m, whitelist_weight_modules):
|
||||||
|
decay.add(fpn)
|
||||||
|
elif pn.endswith("weight") and isinstance(m, blacklist_weight_modules):
|
||||||
|
no_decay.add(fpn)
|
||||||
|
|
||||||
|
no_decay.update(
|
||||||
|
{
|
||||||
|
"pos_emb",
|
||||||
|
"cond_pos_emb",
|
||||||
|
"t_tokens",
|
||||||
|
"r_tokens",
|
||||||
|
"_dummy_variable",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
param_dict = {pn: p for pn, p in self.named_parameters()}
|
||||||
|
inter_params = decay & no_decay
|
||||||
|
union_params = decay | no_decay
|
||||||
|
assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params),)
|
||||||
|
assert len(param_dict.keys() - union_params) == 0, (
|
||||||
|
"parameters %s were not separated into either decay/no_decay set!" % (str(param_dict.keys() - union_params),)
|
||||||
|
)
|
||||||
|
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"params": [param_dict[pn] for pn in sorted(list(decay))],
|
||||||
|
"weight_decay": weight_decay,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"params": [param_dict[pn] for pn in sorted(list(no_decay))],
|
||||||
|
"weight_decay": 0.0,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
def configure_optimizers(
|
||||||
|
self,
|
||||||
|
learning_rate: float = 1e-4,
|
||||||
|
weight_decay: float = 1e-3,
|
||||||
|
betas: Tuple[float, float] = (0.9, 0.95),
|
||||||
|
):
|
||||||
|
optim_groups = self.get_optim_groups(weight_decay=weight_decay)
|
||||||
|
return torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas)
|
||||||
|
|
||||||
|
def _broadcast_time(self, value: Union[torch.Tensor, float, int], batch_size: int, device: torch.device):
|
||||||
|
if not torch.is_tensor(value):
|
||||||
|
value = torch.tensor([value], dtype=torch.float32, device=device)
|
||||||
|
elif value.ndim == 0:
|
||||||
|
value = value[None].to(device=device, dtype=torch.float32)
|
||||||
|
else:
|
||||||
|
value = value.to(device=device, dtype=torch.float32)
|
||||||
|
return value.expand(batch_size)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
sample: torch.Tensor,
|
||||||
|
t: Union[torch.Tensor, float, int],
|
||||||
|
r: Union[torch.Tensor, float, int],
|
||||||
|
cond: Optional[torch.Tensor] = None,
|
||||||
|
):
|
||||||
|
batch_size = sample.shape[0]
|
||||||
|
device = sample.device
|
||||||
|
t = self._broadcast_time(t, batch_size, device)
|
||||||
|
r = self._broadcast_time(r, batch_size, device)
|
||||||
|
|
||||||
|
input_emb = self.input_emb(sample)
|
||||||
|
|
||||||
|
t_cond = self.t_tokens + self.t_emb(t).unsqueeze(1)
|
||||||
|
r_cond = self.r_tokens + self.r_emb(r).unsqueeze(1)
|
||||||
|
cond_embeddings = [t_cond, r_cond]
|
||||||
|
if self.obs_as_cond:
|
||||||
|
cond_embeddings.append(self.cond_obs_emb(cond))
|
||||||
|
cond_embeddings = torch.cat(cond_embeddings, dim=1)
|
||||||
|
|
||||||
|
cond_pos = self.cond_pos_emb[:, : cond_embeddings.shape[1], :]
|
||||||
|
memory = self.drop(cond_embeddings + cond_pos)
|
||||||
|
memory = self.encoder(memory)
|
||||||
|
|
||||||
|
token_pos = self.pos_emb[:, : input_emb.shape[1], :]
|
||||||
|
x = self.drop(input_emb + token_pos)
|
||||||
|
x = self.decoder(
|
||||||
|
tgt=x,
|
||||||
|
memory=memory,
|
||||||
|
tgt_mask=self.mask,
|
||||||
|
memory_mask=self.memory_mask,
|
||||||
|
)
|
||||||
|
x = self.ln_f(x)
|
||||||
|
return self.head_u(x), self.head_v(x)
|
||||||
@@ -1,283 +0,0 @@
|
|||||||
from contextlib import nullcontext
|
|
||||||
from typing import Dict, Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from einops import reduce
|
|
||||||
|
|
||||||
from diffusion_policy.common.pytorch_util import dict_apply
|
|
||||||
from diffusion_policy.model.diffusion.imf_transformer_for_diffusion import IMFTransformerForDiffusion
|
|
||||||
from diffusion_policy.policy.diffusion_transformer_hybrid_image_policy import (
|
|
||||||
DiffusionTransformerHybridImagePolicy,
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
from torch.func import jvp as TORCH_FUNC_JVP
|
|
||||||
except ImportError: # pragma: no cover - depends on torch version
|
|
||||||
TORCH_FUNC_JVP = None
|
|
||||||
|
|
||||||
|
|
||||||
class IMFTransformerHybridImagePolicy(DiffusionTransformerHybridImagePolicy):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
shape_meta: dict,
|
|
||||||
noise_scheduler,
|
|
||||||
horizon,
|
|
||||||
n_action_steps,
|
|
||||||
n_obs_steps,
|
|
||||||
num_inference_steps=None,
|
|
||||||
crop_shape=(76, 76),
|
|
||||||
obs_encoder_group_norm=False,
|
|
||||||
eval_fixed_crop=False,
|
|
||||||
n_layer=8,
|
|
||||||
n_cond_layers=0,
|
|
||||||
n_head=1,
|
|
||||||
n_emb=256,
|
|
||||||
p_drop_emb=0.0,
|
|
||||||
p_drop_attn=0.3,
|
|
||||||
causal_attn=True,
|
|
||||||
time_as_cond=True,
|
|
||||||
obs_as_cond=True,
|
|
||||||
pred_action_steps_only=False,
|
|
||||||
backbone_type='vanilla',
|
|
||||||
n_kv_head=1,
|
|
||||||
attn_res_ffn_mult=2.667,
|
|
||||||
attn_res_eps=1e-6,
|
|
||||||
attn_res_rope_theta=10000.0,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
if num_inference_steps is None:
|
|
||||||
num_inference_steps = 1
|
|
||||||
elif num_inference_steps != 1:
|
|
||||||
raise ValueError(
|
|
||||||
'IMFTransformerHybridImagePolicy only supports one-step inference; '
|
|
||||||
f'num_inference_steps must be 1, got {num_inference_steps}.'
|
|
||||||
)
|
|
||||||
|
|
||||||
super().__init__(
|
|
||||||
shape_meta=shape_meta,
|
|
||||||
noise_scheduler=noise_scheduler,
|
|
||||||
horizon=horizon,
|
|
||||||
n_action_steps=n_action_steps,
|
|
||||||
n_obs_steps=n_obs_steps,
|
|
||||||
num_inference_steps=num_inference_steps,
|
|
||||||
crop_shape=crop_shape,
|
|
||||||
obs_encoder_group_norm=obs_encoder_group_norm,
|
|
||||||
eval_fixed_crop=eval_fixed_crop,
|
|
||||||
n_layer=n_layer,
|
|
||||||
n_cond_layers=n_cond_layers,
|
|
||||||
n_head=n_head,
|
|
||||||
n_emb=n_emb,
|
|
||||||
p_drop_emb=p_drop_emb,
|
|
||||||
p_drop_attn=p_drop_attn,
|
|
||||||
causal_attn=causal_attn,
|
|
||||||
time_as_cond=time_as_cond,
|
|
||||||
obs_as_cond=obs_as_cond,
|
|
||||||
pred_action_steps_only=pred_action_steps_only,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
input_dim = self.action_dim if self.obs_as_cond else (self.obs_feature_dim + self.action_dim)
|
|
||||||
cond_dim = self.obs_feature_dim if self.obs_as_cond else 0
|
|
||||||
model_horizon = self.n_action_steps if self.pred_action_steps_only else horizon
|
|
||||||
self.model = IMFTransformerForDiffusion(
|
|
||||||
input_dim=input_dim,
|
|
||||||
output_dim=input_dim,
|
|
||||||
horizon=model_horizon,
|
|
||||||
n_obs_steps=n_obs_steps,
|
|
||||||
cond_dim=cond_dim,
|
|
||||||
n_layer=n_layer,
|
|
||||||
n_head=n_head,
|
|
||||||
n_emb=n_emb,
|
|
||||||
p_drop_emb=p_drop_emb,
|
|
||||||
p_drop_attn=p_drop_attn,
|
|
||||||
causal_attn=causal_attn,
|
|
||||||
time_as_cond=time_as_cond,
|
|
||||||
obs_as_cond=obs_as_cond,
|
|
||||||
n_cond_layers=n_cond_layers,
|
|
||||||
backbone_type=backbone_type,
|
|
||||||
n_kv_head=n_kv_head,
|
|
||||||
attn_res_ffn_mult=attn_res_ffn_mult,
|
|
||||||
attn_res_eps=attn_res_eps,
|
|
||||||
attn_res_rope_theta=attn_res_rope_theta,
|
|
||||||
)
|
|
||||||
self.num_inference_steps = 1
|
|
||||||
|
|
||||||
def fn(self, z: torch.Tensor, r: torch.Tensor, t: torch.Tensor, cond=None) -> torch.Tensor:
|
|
||||||
return self.model(z, r, t, cond=cond)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _broadcast_batch_time(value: torch.Tensor, reference: torch.Tensor) -> torch.Tensor:
|
|
||||||
while value.ndim < reference.ndim:
|
|
||||||
value = value.unsqueeze(-1)
|
|
||||||
return value
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _apply_conditioning(
|
|
||||||
trajectory: torch.Tensor,
|
|
||||||
condition_data: Optional[torch.Tensor] = None,
|
|
||||||
condition_mask: Optional[torch.Tensor] = None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
if condition_data is None or condition_mask is None:
|
|
||||||
return trajectory
|
|
||||||
conditioned = trajectory.clone()
|
|
||||||
conditioned[condition_mask] = condition_data[condition_mask]
|
|
||||||
return conditioned
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _jvp_math_sdp_context(z_t: torch.Tensor):
|
|
||||||
if z_t.is_cuda:
|
|
||||||
return torch.backends.cuda.sdp_kernel(
|
|
||||||
enable_flash=False,
|
|
||||||
enable_math=True,
|
|
||||||
enable_mem_efficient=False,
|
|
||||||
enable_cudnn=False,
|
|
||||||
)
|
|
||||||
return nullcontext()
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _jvp_tangents(v: torch.Tensor, r: torch.Tensor, t: torch.Tensor):
|
|
||||||
return v.detach(), torch.zeros_like(r), torch.ones_like(t)
|
|
||||||
|
|
||||||
def _compute_u_and_du_dt(
|
|
||||||
self,
|
|
||||||
z_t: torch.Tensor,
|
|
||||||
r: torch.Tensor,
|
|
||||||
t: torch.Tensor,
|
|
||||||
cond,
|
|
||||||
v: torch.Tensor,
|
|
||||||
condition_data: Optional[torch.Tensor] = None,
|
|
||||||
condition_mask: Optional[torch.Tensor] = None,
|
|
||||||
):
|
|
||||||
tangents = self._jvp_tangents(v, r, t)
|
|
||||||
|
|
||||||
def g(z, r_value, t_value):
|
|
||||||
conditioned_z = self._apply_conditioning(z, condition_data, condition_mask)
|
|
||||||
return self.fn(conditioned_z, r_value, t_value, cond=cond)
|
|
||||||
|
|
||||||
with self._jvp_math_sdp_context(z_t):
|
|
||||||
if TORCH_FUNC_JVP is not None:
|
|
||||||
try:
|
|
||||||
return TORCH_FUNC_JVP(g, (z_t, r, t), tangents)
|
|
||||||
except (RuntimeError, TypeError, NotImplementedError):
|
|
||||||
pass
|
|
||||||
|
|
||||||
u = g(z_t, r, t)
|
|
||||||
_, du_dt = torch.autograd.functional.jvp(
|
|
||||||
g,
|
|
||||||
(z_t, r, t),
|
|
||||||
tangents,
|
|
||||||
create_graph=False,
|
|
||||||
strict=False,
|
|
||||||
)
|
|
||||||
return u, du_dt
|
|
||||||
|
|
||||||
def _compound_velocity(
|
|
||||||
self,
|
|
||||||
u: torch.Tensor,
|
|
||||||
du_dt: torch.Tensor,
|
|
||||||
r: torch.Tensor,
|
|
||||||
t: torch.Tensor,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
delta = self._broadcast_batch_time(t - r, u)
|
|
||||||
return u + delta * du_dt.detach()
|
|
||||||
|
|
||||||
def _sample_one_step(
|
|
||||||
self,
|
|
||||||
z_t: torch.Tensor,
|
|
||||||
r: torch.Tensor = None,
|
|
||||||
t: torch.Tensor = None,
|
|
||||||
cond=None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
batch_size = z_t.shape[0]
|
|
||||||
if t is None:
|
|
||||||
t = torch.ones(batch_size, device=z_t.device, dtype=z_t.dtype)
|
|
||||||
if r is None:
|
|
||||||
r = torch.zeros(batch_size, device=z_t.device, dtype=z_t.dtype)
|
|
||||||
u = self.fn(z_t, r, t, cond=cond)
|
|
||||||
delta = self._broadcast_batch_time(t - r, z_t)
|
|
||||||
return z_t - delta * u
|
|
||||||
|
|
||||||
def conditional_sample(
|
|
||||||
self,
|
|
||||||
condition_data,
|
|
||||||
condition_mask,
|
|
||||||
cond=None,
|
|
||||||
generator=None,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
trajectory = torch.randn(
|
|
||||||
size=condition_data.shape,
|
|
||||||
dtype=condition_data.dtype,
|
|
||||||
device=condition_data.device,
|
|
||||||
generator=generator,
|
|
||||||
)
|
|
||||||
trajectory = self._apply_conditioning(trajectory, condition_data, condition_mask)
|
|
||||||
trajectory = self._sample_one_step(trajectory, cond=cond)
|
|
||||||
trajectory = self._apply_conditioning(trajectory, condition_data, condition_mask)
|
|
||||||
return trajectory
|
|
||||||
|
|
||||||
def compute_loss(self, batch):
|
|
||||||
assert 'valid_mask' not in batch
|
|
||||||
nobs = self.normalizer.normalize(batch['obs'])
|
|
||||||
nactions = self.normalizer['action'].normalize(batch['action'])
|
|
||||||
batch_size = nactions.shape[0]
|
|
||||||
horizon = nactions.shape[1]
|
|
||||||
To = self.n_obs_steps
|
|
||||||
|
|
||||||
cond = None
|
|
||||||
trajectory = nactions
|
|
||||||
if self.obs_as_cond:
|
|
||||||
this_nobs = dict_apply(
|
|
||||||
nobs,
|
|
||||||
lambda x: x[:, :To, ...].reshape(-1, *x.shape[2:]),
|
|
||||||
)
|
|
||||||
nobs_features = self.obs_encoder(this_nobs)
|
|
||||||
cond = nobs_features.reshape(batch_size, To, -1)
|
|
||||||
if self.pred_action_steps_only:
|
|
||||||
start = To - 1
|
|
||||||
end = start + self.n_action_steps
|
|
||||||
trajectory = nactions[:, start:end]
|
|
||||||
else:
|
|
||||||
this_nobs = dict_apply(nobs, lambda x: x.reshape(-1, *x.shape[2:]))
|
|
||||||
nobs_features = self.obs_encoder(this_nobs)
|
|
||||||
nobs_features = nobs_features.reshape(batch_size, horizon, -1)
|
|
||||||
trajectory = torch.cat([nactions, nobs_features], dim=-1).detach()
|
|
||||||
|
|
||||||
if self.pred_action_steps_only:
|
|
||||||
condition_mask = torch.zeros_like(trajectory, dtype=torch.bool)
|
|
||||||
else:
|
|
||||||
condition_mask = self.mask_generator(trajectory.shape)
|
|
||||||
|
|
||||||
loss_mask = torch.zeros_like(trajectory, dtype=torch.bool)
|
|
||||||
loss_mask[..., : self.action_dim] = True
|
|
||||||
loss_mask = loss_mask & ~condition_mask
|
|
||||||
|
|
||||||
x = trajectory
|
|
||||||
e = torch.randn_like(x)
|
|
||||||
t = torch.rand(batch_size, device=x.device, dtype=x.dtype)
|
|
||||||
r = torch.rand(batch_size, device=x.device, dtype=x.dtype)
|
|
||||||
t, r = torch.maximum(t, r), torch.minimum(t, r)
|
|
||||||
|
|
||||||
t_broadcast = self._broadcast_batch_time(t, x)
|
|
||||||
z_t = (1 - t_broadcast) * x + t_broadcast * e
|
|
||||||
z_t = self._apply_conditioning(z_t, x, condition_mask)
|
|
||||||
|
|
||||||
v = self.fn(z_t, t, t, cond=cond)
|
|
||||||
u, du_dt = self._compute_u_and_du_dt(
|
|
||||||
z_t,
|
|
||||||
r,
|
|
||||||
t,
|
|
||||||
cond=cond,
|
|
||||||
v=v,
|
|
||||||
condition_data=x,
|
|
||||||
condition_mask=condition_mask,
|
|
||||||
)
|
|
||||||
V = self._compound_velocity(u, du_dt, r, t)
|
|
||||||
target = e - x
|
|
||||||
|
|
||||||
loss = F.mse_loss(V, target, reduction='none')
|
|
||||||
loss = loss * loss_mask.type(loss.dtype)
|
|
||||||
loss = reduce(loss, 'b ... -> b (...)', 'mean')
|
|
||||||
loss = loss.mean()
|
|
||||||
return loss
|
|
||||||
453
diffusion_policy/policy/pmf_transformer_hybrid_image_policy.py
Normal file
453
diffusion_policy/policy/pmf_transformer_hybrid_image_policy.py
Normal file
@@ -0,0 +1,453 @@
|
|||||||
|
from typing import Dict, Tuple
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from einops import reduce
|
||||||
|
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
||||||
|
|
||||||
|
import diffusion_policy.model.vision.crop_randomizer as dmvc
|
||||||
|
import robomimic.models.base_nets as rmbn
|
||||||
|
import robomimic.utils.obs_utils as ObsUtils
|
||||||
|
from diffusion_policy.common.pytorch_util import dict_apply, replace_submodules
|
||||||
|
from diffusion_policy.common.robomimic_config_util import get_robomimic_config
|
||||||
|
from diffusion_policy.model.common.normalizer import LinearNormalizer
|
||||||
|
from diffusion_policy.model.diffusion.mask_generator import LowdimMaskGenerator
|
||||||
|
from diffusion_policy.model.diffusion.pmf_transformer_for_diffusion import (
|
||||||
|
PMFTransformerForDiffusion,
|
||||||
|
)
|
||||||
|
from diffusion_policy.policy.base_image_policy import BaseImagePolicy
|
||||||
|
from robomimic.algo import algo_factory
|
||||||
|
from robomimic.algo.algo import PolicyAlgo
|
||||||
|
|
||||||
|
|
||||||
|
class PMFTransformerHybridImagePolicy(BaseImagePolicy):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
shape_meta: dict,
|
||||||
|
noise_scheduler: DDPMScheduler,
|
||||||
|
horizon,
|
||||||
|
n_action_steps,
|
||||||
|
n_obs_steps,
|
||||||
|
num_inference_steps=None,
|
||||||
|
crop_shape=(76, 76),
|
||||||
|
obs_encoder_group_norm=False,
|
||||||
|
eval_fixed_crop=False,
|
||||||
|
n_layer=8,
|
||||||
|
n_cond_layers=0,
|
||||||
|
n_head=4,
|
||||||
|
n_emb=256,
|
||||||
|
p_drop_emb=0.0,
|
||||||
|
p_drop_attn=0.0,
|
||||||
|
causal_attn=True,
|
||||||
|
obs_as_cond=True,
|
||||||
|
pred_action_steps_only=False,
|
||||||
|
n_time_tokens=4,
|
||||||
|
min_time=0.05,
|
||||||
|
du_dt_epsilon=1.0e-3,
|
||||||
|
pmf_u_loss_weight=1.0,
|
||||||
|
pmf_v_loss_weight=1.0,
|
||||||
|
noise_scale=1.0,
|
||||||
|
adatloss_eps=0.01,
|
||||||
|
p_mean=-0.4,
|
||||||
|
p_std=1.0,
|
||||||
|
tr_uniform=True,
|
||||||
|
tr_uniform_prob=0.1,
|
||||||
|
data_proportion=0.5,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
action_shape = shape_meta["action"]["shape"]
|
||||||
|
assert len(action_shape) == 1
|
||||||
|
action_dim = action_shape[0]
|
||||||
|
obs_shape_meta = shape_meta["obs"]
|
||||||
|
obs_config = {
|
||||||
|
"low_dim": [],
|
||||||
|
"rgb": [],
|
||||||
|
"depth": [],
|
||||||
|
"scan": [],
|
||||||
|
}
|
||||||
|
obs_key_shapes = dict()
|
||||||
|
for key, attr in obs_shape_meta.items():
|
||||||
|
shape = attr["shape"]
|
||||||
|
obs_key_shapes[key] = list(shape)
|
||||||
|
|
||||||
|
obs_type = attr.get("type", "low_dim")
|
||||||
|
if obs_type == "rgb":
|
||||||
|
obs_config["rgb"].append(key)
|
||||||
|
elif obs_type == "low_dim":
|
||||||
|
obs_config["low_dim"].append(key)
|
||||||
|
else:
|
||||||
|
raise RuntimeError(f"Unsupported obs type: {obs_type}")
|
||||||
|
|
||||||
|
config = get_robomimic_config(
|
||||||
|
algo_name="bc_rnn",
|
||||||
|
hdf5_type="image",
|
||||||
|
task_name="square",
|
||||||
|
dataset_type="ph",
|
||||||
|
)
|
||||||
|
|
||||||
|
with config.unlocked():
|
||||||
|
config.observation.modalities.obs = obs_config
|
||||||
|
|
||||||
|
if crop_shape is None:
|
||||||
|
for _, modality in config.observation.encoder.items():
|
||||||
|
if modality.obs_randomizer_class == "CropRandomizer":
|
||||||
|
modality["obs_randomizer_class"] = None
|
||||||
|
else:
|
||||||
|
crop_h, crop_w = crop_shape
|
||||||
|
for _, modality in config.observation.encoder.items():
|
||||||
|
if modality.obs_randomizer_class == "CropRandomizer":
|
||||||
|
modality.obs_randomizer_kwargs.crop_height = crop_h
|
||||||
|
modality.obs_randomizer_kwargs.crop_width = crop_w
|
||||||
|
|
||||||
|
ObsUtils.initialize_obs_utils_with_config(config)
|
||||||
|
|
||||||
|
policy: PolicyAlgo = algo_factory(
|
||||||
|
algo_name=config.algo_name,
|
||||||
|
config=config,
|
||||||
|
obs_key_shapes=obs_key_shapes,
|
||||||
|
ac_dim=action_dim,
|
||||||
|
device="cpu",
|
||||||
|
)
|
||||||
|
|
||||||
|
obs_encoder = policy.nets["policy"].nets["encoder"].nets["obs"]
|
||||||
|
if obs_encoder_group_norm:
|
||||||
|
replace_submodules(
|
||||||
|
root_module=obs_encoder,
|
||||||
|
predicate=lambda x: isinstance(x, nn.BatchNorm2d),
|
||||||
|
func=lambda x: nn.GroupNorm(
|
||||||
|
num_groups=x.num_features // 16,
|
||||||
|
num_channels=x.num_features,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
if eval_fixed_crop:
|
||||||
|
replace_submodules(
|
||||||
|
root_module=obs_encoder,
|
||||||
|
predicate=lambda x: isinstance(x, rmbn.CropRandomizer),
|
||||||
|
func=lambda x: dmvc.CropRandomizer(
|
||||||
|
input_shape=x.input_shape,
|
||||||
|
crop_height=x.crop_height,
|
||||||
|
crop_width=x.crop_width,
|
||||||
|
num_crops=x.num_crops,
|
||||||
|
pos_enc=x.pos_enc,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
obs_feature_dim = obs_encoder.output_shape()[0]
|
||||||
|
input_dim = action_dim if obs_as_cond else (obs_feature_dim + action_dim)
|
||||||
|
cond_dim = obs_feature_dim if obs_as_cond else 0
|
||||||
|
|
||||||
|
self.obs_encoder = obs_encoder
|
||||||
|
self.model = PMFTransformerForDiffusion(
|
||||||
|
input_dim=input_dim,
|
||||||
|
output_dim=input_dim,
|
||||||
|
horizon=horizon if not pred_action_steps_only else n_action_steps,
|
||||||
|
n_obs_steps=n_obs_steps,
|
||||||
|
cond_dim=cond_dim,
|
||||||
|
n_layer=n_layer,
|
||||||
|
n_head=n_head,
|
||||||
|
n_emb=n_emb,
|
||||||
|
p_drop_emb=p_drop_emb,
|
||||||
|
p_drop_attn=p_drop_attn,
|
||||||
|
causal_attn=causal_attn,
|
||||||
|
obs_as_cond=obs_as_cond,
|
||||||
|
n_cond_layers=n_cond_layers,
|
||||||
|
n_time_tokens=n_time_tokens,
|
||||||
|
)
|
||||||
|
self.noise_scheduler = noise_scheduler
|
||||||
|
self.mask_generator = LowdimMaskGenerator(
|
||||||
|
action_dim=action_dim,
|
||||||
|
obs_dim=0 if obs_as_cond else obs_feature_dim,
|
||||||
|
max_n_obs_steps=n_obs_steps,
|
||||||
|
fix_obs_steps=True,
|
||||||
|
action_visible=False,
|
||||||
|
)
|
||||||
|
self.normalizer = LinearNormalizer()
|
||||||
|
self.horizon = horizon
|
||||||
|
self.obs_feature_dim = obs_feature_dim
|
||||||
|
self.action_dim = action_dim
|
||||||
|
self.n_action_steps = n_action_steps
|
||||||
|
self.n_obs_steps = n_obs_steps
|
||||||
|
self.obs_as_cond = obs_as_cond
|
||||||
|
self.pred_action_steps_only = pred_action_steps_only
|
||||||
|
self.min_time = min_time
|
||||||
|
self.du_dt_epsilon = du_dt_epsilon
|
||||||
|
self.pmf_u_loss_weight = pmf_u_loss_weight
|
||||||
|
self.pmf_v_loss_weight = pmf_v_loss_weight
|
||||||
|
self.noise_scale = noise_scale
|
||||||
|
self.adatloss_eps = adatloss_eps
|
||||||
|
self.p_mean = p_mean
|
||||||
|
self.p_std = p_std
|
||||||
|
self.tr_uniform = tr_uniform
|
||||||
|
self.tr_uniform_prob = tr_uniform_prob
|
||||||
|
self.data_proportion = data_proportion
|
||||||
|
self.kwargs = kwargs
|
||||||
|
|
||||||
|
if num_inference_steps is None:
|
||||||
|
num_inference_steps = noise_scheduler.config.num_train_timesteps
|
||||||
|
self.num_inference_steps = num_inference_steps
|
||||||
|
|
||||||
|
def _encode_obs(self, nobs: Dict[str, torch.Tensor], n_steps: int) -> torch.Tensor:
|
||||||
|
flat_nobs = dict_apply(nobs, lambda x: x[:, :n_steps, ...].reshape(-1, *x.shape[2:]))
|
||||||
|
nobs_features = self.obs_encoder(flat_nobs)
|
||||||
|
return nobs_features.reshape(next(iter(nobs.values())).shape[0], n_steps, -1)
|
||||||
|
|
||||||
|
def _time_view(self, value: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
|
||||||
|
return value.reshape(value.shape[0], *([1] * (ref.ndim - 1)))
|
||||||
|
|
||||||
|
def _adatloss(self, loss: torch.Tensor) -> torch.Tensor:
|
||||||
|
denom = loss.detach() + self.adatloss_eps
|
||||||
|
return loss / denom
|
||||||
|
|
||||||
|
def _sample_logit_normal(self, batch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
|
||||||
|
normal = torch.randn(batch_size, device=device, dtype=dtype)
|
||||||
|
return torch.sigmoid(normal * self.p_std + self.p_mean)
|
||||||
|
|
||||||
|
def _sample_tr(self, batch_size: int, device: torch.device, dtype: torch.dtype):
|
||||||
|
t = self._sample_logit_normal(batch_size, device, dtype)
|
||||||
|
r = self._sample_logit_normal(batch_size, device, dtype)
|
||||||
|
|
||||||
|
if self.tr_uniform:
|
||||||
|
uniform_mask = torch.rand(batch_size, device=device) < self.tr_uniform_prob
|
||||||
|
uniform_t = torch.rand(batch_size, device=device, dtype=dtype)
|
||||||
|
uniform_r = torch.rand(batch_size, device=device, dtype=dtype)
|
||||||
|
t = torch.where(uniform_mask, uniform_t, t)
|
||||||
|
r = torch.where(uniform_mask, uniform_r, r)
|
||||||
|
|
||||||
|
data_size = int(batch_size * self.data_proportion)
|
||||||
|
fm_mask = torch.arange(batch_size, device=device) < data_size
|
||||||
|
r = torch.where(fm_mask, t, r)
|
||||||
|
|
||||||
|
t_final = torch.maximum(t, r)
|
||||||
|
r_final = torch.minimum(t, r)
|
||||||
|
return t_final, r_final
|
||||||
|
|
||||||
|
def _trajectory_inputs(
|
||||||
|
self,
|
||||||
|
nobs: Dict[str, torch.Tensor],
|
||||||
|
nactions: torch.Tensor,
|
||||||
|
):
|
||||||
|
batch_size = nactions.shape[0]
|
||||||
|
horizon = nactions.shape[1]
|
||||||
|
cond = None
|
||||||
|
trajectory = nactions
|
||||||
|
if self.obs_as_cond:
|
||||||
|
cond = self._encode_obs(nobs, self.n_obs_steps)
|
||||||
|
if self.pred_action_steps_only:
|
||||||
|
start = self.n_obs_steps - 1
|
||||||
|
end = start + self.n_action_steps
|
||||||
|
trajectory = nactions[:, start:end]
|
||||||
|
else:
|
||||||
|
nobs_features = self._encode_obs(nobs, horizon)
|
||||||
|
trajectory = torch.cat([nactions, nobs_features], dim=-1).detach()
|
||||||
|
|
||||||
|
if self.pred_action_steps_only:
|
||||||
|
condition_mask = torch.zeros_like(trajectory, dtype=torch.bool)
|
||||||
|
else:
|
||||||
|
condition_mask = self.mask_generator(trajectory.shape)
|
||||||
|
|
||||||
|
return batch_size, trajectory, cond, condition_mask
|
||||||
|
|
||||||
|
def _apply_conditioning(
|
||||||
|
self,
|
||||||
|
sample: torch.Tensor,
|
||||||
|
condition_data: torch.Tensor,
|
||||||
|
condition_mask: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
if not condition_mask.any():
|
||||||
|
return sample
|
||||||
|
return torch.where(condition_mask, condition_data, sample)
|
||||||
|
|
||||||
|
def _compute_u_v(
|
||||||
|
self,
|
||||||
|
sample: torch.Tensor,
|
||||||
|
t: torch.Tensor,
|
||||||
|
r: torch.Tensor,
|
||||||
|
cond: torch.Tensor,
|
||||||
|
):
|
||||||
|
x_hat_u, x_hat_v = self.model(sample, t, r, cond)
|
||||||
|
denom = self._time_view(t, sample)
|
||||||
|
u = (sample - x_hat_u) / denom
|
||||||
|
v = (sample - x_hat_v) / denom
|
||||||
|
return u, v
|
||||||
|
|
||||||
|
def _compute_du_dt(
|
||||||
|
self,
|
||||||
|
sample: torch.Tensor,
|
||||||
|
t: torch.Tensor,
|
||||||
|
r: torch.Tensor,
|
||||||
|
cond: torch.Tensor,
|
||||||
|
condition_data: torch.Tensor,
|
||||||
|
condition_mask: torch.Tensor,
|
||||||
|
tangent_v: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
tangent_sample = tangent_v.detach()
|
||||||
|
tangent_r = torch.zeros_like(r)
|
||||||
|
tangent_t = torch.ones_like(t)
|
||||||
|
|
||||||
|
def u_fn(sample_input, r_input, t_input):
|
||||||
|
conditioned_sample = self._apply_conditioning(
|
||||||
|
sample_input, condition_data, condition_mask
|
||||||
|
)
|
||||||
|
u_value, _ = self._compute_u_v(conditioned_sample, t_input, r_input, cond)
|
||||||
|
return u_value
|
||||||
|
|
||||||
|
primals = (sample, r, t)
|
||||||
|
tangents = (tangent_sample, tangent_r, tangent_t)
|
||||||
|
try:
|
||||||
|
_, du_dt = torch.func.jvp(u_fn, primals, tangents)
|
||||||
|
except (AttributeError, NotImplementedError, RuntimeError):
|
||||||
|
_, du_dt = torch.autograd.functional.jvp(
|
||||||
|
u_fn,
|
||||||
|
primals,
|
||||||
|
tangents,
|
||||||
|
create_graph=False,
|
||||||
|
strict=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
return du_dt
|
||||||
|
|
||||||
|
# ========= inference ============
|
||||||
|
def conditional_sample(
|
||||||
|
self,
|
||||||
|
condition_data,
|
||||||
|
condition_mask,
|
||||||
|
cond=None,
|
||||||
|
generator=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
del kwargs
|
||||||
|
|
||||||
|
trajectory = torch.randn(
|
||||||
|
size=condition_data.shape,
|
||||||
|
dtype=condition_data.dtype,
|
||||||
|
device=condition_data.device,
|
||||||
|
generator=generator,
|
||||||
|
) * self.noise_scale
|
||||||
|
|
||||||
|
time_steps = torch.linspace(
|
||||||
|
1.0,
|
||||||
|
0.0,
|
||||||
|
self.num_inference_steps + 1,
|
||||||
|
dtype=trajectory.dtype,
|
||||||
|
device=trajectory.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
for step_idx in range(self.num_inference_steps):
|
||||||
|
trajectory = self._apply_conditioning(trajectory, condition_data, condition_mask)
|
||||||
|
t = time_steps[step_idx].expand(trajectory.shape[0])
|
||||||
|
r = time_steps[step_idx + 1].expand(trajectory.shape[0])
|
||||||
|
u, _ = self._compute_u_v(trajectory, t, r, cond)
|
||||||
|
delta = self._time_view(t - r, trajectory)
|
||||||
|
trajectory = trajectory - delta * u
|
||||||
|
|
||||||
|
trajectory = self._apply_conditioning(trajectory, condition_data, condition_mask)
|
||||||
|
return trajectory
|
||||||
|
|
||||||
|
def predict_action(self, obs_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
||||||
|
assert "past_action" not in obs_dict
|
||||||
|
nobs = self.normalizer.normalize(obs_dict)
|
||||||
|
value = next(iter(nobs.values()))
|
||||||
|
batch_size, to_steps = value.shape[:2]
|
||||||
|
horizon = self.horizon
|
||||||
|
action_dim = self.action_dim
|
||||||
|
|
||||||
|
device = self.device
|
||||||
|
dtype = self.dtype
|
||||||
|
cond = None
|
||||||
|
if self.obs_as_cond:
|
||||||
|
cond = self._encode_obs(nobs, self.n_obs_steps)
|
||||||
|
shape = (batch_size, horizon, action_dim)
|
||||||
|
if self.pred_action_steps_only:
|
||||||
|
shape = (batch_size, self.n_action_steps, action_dim)
|
||||||
|
cond_data = torch.zeros(size=shape, device=device, dtype=dtype)
|
||||||
|
cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
|
||||||
|
else:
|
||||||
|
nobs_features = self._encode_obs(nobs, self.n_obs_steps)
|
||||||
|
shape = (batch_size, horizon, action_dim + self.obs_feature_dim)
|
||||||
|
cond_data = torch.zeros(size=shape, device=device, dtype=dtype)
|
||||||
|
cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
|
||||||
|
cond_data[:, : self.n_obs_steps, action_dim:] = nobs_features
|
||||||
|
cond_mask[:, : self.n_obs_steps, action_dim:] = True
|
||||||
|
|
||||||
|
nsample = self.conditional_sample(
|
||||||
|
cond_data,
|
||||||
|
cond_mask,
|
||||||
|
cond=cond,
|
||||||
|
**self.kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
naction_pred = nsample[..., :action_dim]
|
||||||
|
action_pred = self.normalizer["action"].unnormalize(naction_pred)
|
||||||
|
if self.pred_action_steps_only:
|
||||||
|
action = action_pred
|
||||||
|
else:
|
||||||
|
start = to_steps - 1
|
||||||
|
end = start + self.n_action_steps
|
||||||
|
action = action_pred[:, start:end]
|
||||||
|
return {
|
||||||
|
"action": action,
|
||||||
|
"action_pred": action_pred,
|
||||||
|
}
|
||||||
|
|
||||||
|
# ========= training ============
|
||||||
|
def set_normalizer(self, normalizer: LinearNormalizer):
|
||||||
|
self.normalizer.load_state_dict(normalizer.state_dict())
|
||||||
|
|
||||||
|
def get_optimizer(
|
||||||
|
self,
|
||||||
|
transformer_weight_decay: float,
|
||||||
|
obs_encoder_weight_decay: float,
|
||||||
|
learning_rate: float,
|
||||||
|
betas: Tuple[float, float],
|
||||||
|
) -> torch.optim.Optimizer:
|
||||||
|
optim_groups = self.model.get_optim_groups(weight_decay=transformer_weight_decay)
|
||||||
|
optim_groups.append(
|
||||||
|
{
|
||||||
|
"params": self.obs_encoder.parameters(),
|
||||||
|
"weight_decay": obs_encoder_weight_decay,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas)
|
||||||
|
|
||||||
|
def compute_loss(self, batch):
|
||||||
|
assert "valid_mask" not in batch
|
||||||
|
nobs = self.normalizer.normalize(batch["obs"])
|
||||||
|
nactions = self.normalizer["action"].normalize(batch["action"])
|
||||||
|
|
||||||
|
_, trajectory, cond, condition_mask = self._trajectory_inputs(nobs, nactions)
|
||||||
|
noise = torch.randn_like(trajectory) * self.noise_scale
|
||||||
|
batch_size = trajectory.shape[0]
|
||||||
|
|
||||||
|
t, r = self._sample_tr(
|
||||||
|
batch_size, device=trajectory.device, dtype=trajectory.dtype
|
||||||
|
)
|
||||||
|
z_t = (1 - self._time_view(t, trajectory)) * trajectory + self._time_view(t, trajectory) * noise
|
||||||
|
z_t = self._apply_conditioning(z_t, trajectory, condition_mask)
|
||||||
|
|
||||||
|
loss_mask = ~condition_mask
|
||||||
|
target_v = noise - trajectory
|
||||||
|
|
||||||
|
u, v = self._compute_u_v(z_t, t, r, cond)
|
||||||
|
du_dt = self._compute_du_dt(
|
||||||
|
sample=z_t,
|
||||||
|
t=t,
|
||||||
|
r=r,
|
||||||
|
cond=cond,
|
||||||
|
condition_data=trajectory,
|
||||||
|
condition_mask=condition_mask,
|
||||||
|
tangent_v=v,
|
||||||
|
)
|
||||||
|
pmf_velocity = u + self._time_view(t - r, trajectory) * du_dt.detach()
|
||||||
|
|
||||||
|
loss_u = F.mse_loss(pmf_velocity, target_v, reduction="none")
|
||||||
|
loss_v = F.mse_loss(v, target_v, reduction="none")
|
||||||
|
loss_u = loss_u * loss_mask.type(loss_u.dtype)
|
||||||
|
loss_v = loss_v * loss_mask.type(loss_v.dtype)
|
||||||
|
loss_u = reduce(loss_u, "b ... -> b (...)", "mean").mean()
|
||||||
|
loss_v = reduce(loss_v, "b ... -> b (...)", "mean").mean()
|
||||||
|
loss_u = self._adatloss(loss_u)
|
||||||
|
loss_v = self._adatloss(loss_v)
|
||||||
|
return self.pmf_u_loss_weight * loss_u + self.pmf_v_loss_weight * loss_v
|
||||||
@@ -8,8 +8,6 @@ if __name__ == "__main__":
|
|||||||
os.chdir(ROOT_DIR)
|
os.chdir(ROOT_DIR)
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import contextlib
|
|
||||||
import importlib
|
|
||||||
import hydra
|
import hydra
|
||||||
import torch
|
import torch
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
@@ -17,6 +15,7 @@ import pathlib
|
|||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
import copy
|
import copy
|
||||||
import random
|
import random
|
||||||
|
import wandb
|
||||||
import tqdm
|
import tqdm
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import shutil
|
import shutil
|
||||||
@@ -32,111 +31,6 @@ from diffusion_policy.model.common.lr_scheduler import get_scheduler
|
|||||||
|
|
||||||
OmegaConf.register_new_resolver("eval", eval, replace=True)
|
OmegaConf.register_new_resolver("eval", eval, replace=True)
|
||||||
|
|
||||||
|
|
||||||
class _LoggingBackend:
|
|
||||||
def log(self, payload, step=None):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def finish(self):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
|
|
||||||
class _WandbLoggingBackend(_LoggingBackend):
|
|
||||||
def __init__(self, run):
|
|
||||||
self.run = run
|
|
||||||
|
|
||||||
def log(self, payload, step=None):
|
|
||||||
self.run.log(payload, step=step)
|
|
||||||
|
|
||||||
def finish(self):
|
|
||||||
self.run.finish()
|
|
||||||
|
|
||||||
|
|
||||||
class _SwanLabLoggingBackend(_LoggingBackend):
|
|
||||||
def __init__(self, run):
|
|
||||||
self.run = run
|
|
||||||
|
|
||||||
def log(self, payload, step=None):
|
|
||||||
self.run.log(payload, step=step)
|
|
||||||
|
|
||||||
def finish(self):
|
|
||||||
self.run.finish()
|
|
||||||
|
|
||||||
|
|
||||||
def _load_wandb():
|
|
||||||
try:
|
|
||||||
return importlib.import_module('wandb')
|
|
||||||
except ImportError as exc:
|
|
||||||
raise ImportError(
|
|
||||||
"wandb is required when cfg.logging.backend == 'wandb' or missing"
|
|
||||||
) from exc
|
|
||||||
|
|
||||||
|
|
||||||
def _load_swanlab():
|
|
||||||
try:
|
|
||||||
return importlib.import_module('swanlab')
|
|
||||||
except ImportError:
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def init_logging_backend(cfg: OmegaConf, output_dir):
|
|
||||||
backend = OmegaConf.select(cfg, 'logging.backend', default='wandb')
|
|
||||||
if backend == 'swanlab':
|
|
||||||
swanlab = _load_swanlab()
|
|
||||||
if swanlab is None:
|
|
||||||
raise ImportError("swanlab is required when cfg.logging.backend == 'swanlab'")
|
|
||||||
logging_cfg = cfg.logging
|
|
||||||
mode = logging_cfg.mode
|
|
||||||
if mode == 'online':
|
|
||||||
mode = 'cloud'
|
|
||||||
run = swanlab.init(
|
|
||||||
project=logging_cfg.project,
|
|
||||||
experiment_name=logging_cfg.name,
|
|
||||||
group=logging_cfg.group,
|
|
||||||
tags=logging_cfg.tags,
|
|
||||||
id=logging_cfg.id,
|
|
||||||
resume=logging_cfg.resume,
|
|
||||||
mode=mode,
|
|
||||||
logdir=str(pathlib.Path(output_dir) / 'swanlog'),
|
|
||||||
config=OmegaConf.to_container(cfg, resolve=True),
|
|
||||||
)
|
|
||||||
return _SwanLabLoggingBackend(run)
|
|
||||||
|
|
||||||
if backend not in (None, 'wandb'):
|
|
||||||
raise ValueError(f"Unknown logging backend: {backend}")
|
|
||||||
|
|
||||||
wandb = _load_wandb()
|
|
||||||
logging_kwargs = OmegaConf.to_container(cfg.logging, resolve=True)
|
|
||||||
logging_kwargs.pop('backend', None)
|
|
||||||
run = wandb.init(
|
|
||||||
dir=str(output_dir),
|
|
||||||
config=OmegaConf.to_container(cfg, resolve=True),
|
|
||||||
**logging_kwargs
|
|
||||||
)
|
|
||||||
wandb.config.update(
|
|
||||||
{
|
|
||||||
"output_dir": str(output_dir),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
return _WandbLoggingBackend(run)
|
|
||||||
|
|
||||||
|
|
||||||
@contextlib.contextmanager
|
|
||||||
def logging_backend_session(cfg: OmegaConf, output_dir):
|
|
||||||
logging_backend = init_logging_backend(cfg=cfg, output_dir=output_dir)
|
|
||||||
primary_error = None
|
|
||||||
try:
|
|
||||||
yield logging_backend
|
|
||||||
except BaseException as exc:
|
|
||||||
primary_error = exc
|
|
||||||
raise
|
|
||||||
finally:
|
|
||||||
try:
|
|
||||||
logging_backend.finish()
|
|
||||||
except BaseException:
|
|
||||||
if primary_error is None:
|
|
||||||
raise
|
|
||||||
|
|
||||||
class TrainDiffusionTransformerHybridWorkspace(BaseWorkspace):
|
class TrainDiffusionTransformerHybridWorkspace(BaseWorkspace):
|
||||||
include_keys = ['global_step', 'epoch']
|
include_keys = ['global_step', 'epoch']
|
||||||
|
|
||||||
@@ -215,6 +109,18 @@ class TrainDiffusionTransformerHybridWorkspace(BaseWorkspace):
|
|||||||
output_dir=self.output_dir)
|
output_dir=self.output_dir)
|
||||||
assert isinstance(env_runner, BaseImageRunner)
|
assert isinstance(env_runner, BaseImageRunner)
|
||||||
|
|
||||||
|
# configure logging
|
||||||
|
wandb_run = wandb.init(
|
||||||
|
dir=str(self.output_dir),
|
||||||
|
config=OmegaConf.to_container(cfg, resolve=True),
|
||||||
|
**cfg.logging
|
||||||
|
)
|
||||||
|
wandb.config.update(
|
||||||
|
{
|
||||||
|
"output_dir": self.output_dir,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
# configure checkpoint
|
# configure checkpoint
|
||||||
topk_manager = TopKCheckpointManager(
|
topk_manager = TopKCheckpointManager(
|
||||||
save_dir=os.path.join(self.output_dir, 'checkpoints'),
|
save_dir=os.path.join(self.output_dir, 'checkpoints'),
|
||||||
@@ -242,7 +148,6 @@ class TrainDiffusionTransformerHybridWorkspace(BaseWorkspace):
|
|||||||
|
|
||||||
# training loop
|
# training loop
|
||||||
log_path = os.path.join(self.output_dir, 'logs.json.txt')
|
log_path = os.path.join(self.output_dir, 'logs.json.txt')
|
||||||
with logging_backend_session(cfg=cfg, output_dir=self.output_dir) as logging_backend:
|
|
||||||
with JsonLogger(log_path) as json_logger:
|
with JsonLogger(log_path) as json_logger:
|
||||||
for local_epoch_idx in range(cfg.training.num_epochs):
|
for local_epoch_idx in range(cfg.training.num_epochs):
|
||||||
step_log = dict()
|
step_log = dict()
|
||||||
@@ -285,7 +190,7 @@ class TrainDiffusionTransformerHybridWorkspace(BaseWorkspace):
|
|||||||
is_last_batch = (batch_idx == (len(train_dataloader)-1))
|
is_last_batch = (batch_idx == (len(train_dataloader)-1))
|
||||||
if not is_last_batch:
|
if not is_last_batch:
|
||||||
# log of last step is combined with validation and rollout
|
# log of last step is combined with validation and rollout
|
||||||
logging_backend.log(step_log, step=self.global_step)
|
wandb_run.log(step_log, step=self.global_step)
|
||||||
json_logger.log(step_log)
|
json_logger.log(step_log)
|
||||||
self.global_step += 1
|
self.global_step += 1
|
||||||
|
|
||||||
@@ -373,7 +278,7 @@ class TrainDiffusionTransformerHybridWorkspace(BaseWorkspace):
|
|||||||
|
|
||||||
# end of epoch
|
# end of epoch
|
||||||
# log of last step is combined with validation and rollout
|
# log of last step is combined with validation and rollout
|
||||||
logging_backend.log(step_log, step=self.global_step)
|
wandb_run.log(step_log, step=self.global_step)
|
||||||
json_logger.log(step_log)
|
json_logger.log(step_log)
|
||||||
self.global_step += 1
|
self.global_step += 1
|
||||||
self.epoch += 1
|
self.epoch += 1
|
||||||
|
|||||||
@@ -1,60 +0,0 @@
|
|||||||
# PushT iMF Full-Attention Implementation Plan
|
|
||||||
|
|
||||||
> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking.
|
|
||||||
|
|
||||||
**Goal:** Add a separate full-attention PushT image iMF config, commit/push it on a new branch, and launch the 9-run 350-epoch architecture sweep across 3 GPUs.
|
|
||||||
|
|
||||||
**Architecture:** Keep the existing causal iMF path untouched and add a standalone full-attention config that only flips `policy.causal_attn=false` while retaining one-step iMF inference and SwanLab-safe naming. Reuse the previous 9-run architecture matrix and balanced 3-queue scheduling across local 5090 plus 5880 GPU0/GPU1.
|
|
||||||
|
|
||||||
**Tech Stack:** Hydra, Diffusion Policy iMF image workspace, SwanLab, uv env, local shell + trusted remote 5880 over SSH.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### Task 1: Add full-attention iMF config with TDD
|
|
||||||
|
|
||||||
**Files:**
|
|
||||||
- Create: `image_pusht_diffusion_policy_dit_imf_fullattn.yaml`
|
|
||||||
- Modify: `tests/test_pusht_swanlab_config.py`
|
|
||||||
|
|
||||||
- [ ] Write a failing config regression test asserting the new config uses SwanLab-safe naming and `policy.causal_attn == False`.
|
|
||||||
- [ ] Run the targeted pytest command and verify it fails because the config does not exist yet.
|
|
||||||
- [ ] Add the minimal full-attention config by composing from the existing PushT image iMF config and overriding only `exp_name` and `policy.causal_attn=false`.
|
|
||||||
- [ ] Re-run the targeted pytest and verify it passes.
|
|
||||||
|
|
||||||
### Task 2: Verify the new config
|
|
||||||
|
|
||||||
**Files:**
|
|
||||||
- Read: `image_pusht_diffusion_policy_dit_imf_fullattn.yaml`
|
|
||||||
|
|
||||||
- [ ] Run `train.py --help` for the new config.
|
|
||||||
- [ ] Run a real `training.debug=true` smoke test locally to confirm the training path is valid.
|
|
||||||
|
|
||||||
### Task 3: Commit and push the new branch
|
|
||||||
|
|
||||||
**Files:**
|
|
||||||
- Commit only the new config/test/plan files needed for the full-attention experiment chain.
|
|
||||||
|
|
||||||
- [ ] Run verification commands again before commit.
|
|
||||||
- [ ] Commit with a focused message.
|
|
||||||
- [ ] Push `feat/pusht-imf-fullattn` to origin.
|
|
||||||
|
|
||||||
### Task 4: Launch the 9-run sweep
|
|
||||||
|
|
||||||
**Files:**
|
|
||||||
- Write queue scripts and logs under `data/run_logs/` locally and on 5880.
|
|
||||||
- Write outputs under `data/outputs/` locally and on 5880.
|
|
||||||
|
|
||||||
- [ ] Use the same matrix as the prior iMF sweep: `n_emb ∈ {128,256,384}`, `n_layer ∈ {6,12,18}`, `seed=42`.
|
|
||||||
- [ ] Set `training.num_epochs=350` for all 9 runs.
|
|
||||||
- [ ] Encode `fullattn` in every `exp_name`, `logging.name`, and run directory to avoid collisions.
|
|
||||||
- [ ] Balance the 9 runs across local 5090, 5880 GPU0, and 5880 GPU1 as three serial queues.
|
|
||||||
- [ ] Sync the new config to the remote smoke repo before launching remote queues.
|
|
||||||
|
|
||||||
### Task 5: Monitor and auto-summarize
|
|
||||||
|
|
||||||
**Files:**
|
|
||||||
- Read local and remote pid files, logs, outputs, checkpoints.
|
|
||||||
|
|
||||||
- [ ] Start an xhigh monitoring agent that polls all three queues.
|
|
||||||
- [ ] On completion, parse all 9 `logs.json.txt` files and rank by max `test_mean_score`.
|
|
||||||
- [ ] Report embedding/layer trends and the best configuration.
|
|
||||||
@@ -1,57 +0,0 @@
|
|||||||
# PushT Image iMF AttnRes Implementation Plan
|
|
||||||
|
|
||||||
> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking.
|
|
||||||
|
|
||||||
**Goal:** Add an AttnRes-backed full-attention iMF backbone for the PushT image experiment path, verify it with tests/smoke runs, then launch the 9-run 350-epoch architecture sweep across the local 5090 and remote 5880 GPUs.
|
|
||||||
|
|
||||||
**Architecture:** Extend `IMFTransformerForDiffusion` with a selectable `attnres_full` backbone that keeps the current iMF training/inference API unchanged while replacing the transformer internals with RMSNorm + RoPE self-attention + SwiGLU + Full AttnRes depth-wise residual routing. Add one standalone Hydra config for the PushT image sweep and reuse queue-style launch scripts with unique SwanLab names.
|
|
||||||
|
|
||||||
**Tech Stack:** Python 3.9 via uv, PyTorch 2.8 CUDA, Hydra, SwanLab online logging, local shell + SSH to trusted 5880 host.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### Task 1: Add regression tests for the new AttnRes path
|
|
||||||
|
|
||||||
**Files:**
|
|
||||||
- Modify: `tests/test_imf_transformer_for_diffusion.py`
|
|
||||||
- Modify: `tests/test_pusht_swanlab_config.py`
|
|
||||||
|
|
||||||
- [ ] Add a failing model test that instantiates `IMFTransformerForDiffusion(backbone_type='attnres_full', causal_attn=False, ...)`, runs a forward pass with conditional observations, and asserts output shape plus optimizer construction.
|
|
||||||
- [ ] Run the targeted pytest selection and confirm the new test fails for the expected missing-backbone reason.
|
|
||||||
- [ ] Add a failing config regression test for `image_pusht_diffusion_policy_dit_imf_attnres_full.yaml` asserting SwanLab naming fields and `policy.causal_attn == False`.
|
|
||||||
- [ ] Re-run the targeted pytest selection and confirm the config test fails before implementation.
|
|
||||||
|
|
||||||
### Task 2: Implement the AttnRes-backed iMF backbone
|
|
||||||
|
|
||||||
**Files:**
|
|
||||||
- Create: `diffusion_policy/model/diffusion/attnres_transformer_components.py`
|
|
||||||
- Modify: `diffusion_policy/model/diffusion/imf_transformer_for_diffusion.py`
|
|
||||||
|
|
||||||
- [ ] Add focused reusable modules for `RMSNorm`, RoPE helpers, grouped-query self-attention, SwiGLU FFN, and the Full AttnRes operator.
|
|
||||||
- [ ] Extend `IMFTransformerForDiffusion` with a `backbone_type` switch that preserves the existing vanilla path and adds an `attnres_full` path using concatenated `[r, t, obs, sample]` tokens.
|
|
||||||
- [ ] Ensure the AttnRes path slices condition tokens away before the output head so the returned tensor still matches the sample/action horizon.
|
|
||||||
- [ ] Update optimizer parameter grouping to treat RMSNorm weights like LayerNorm weights (no decay) and include any new positional/conditioning parameters.
|
|
||||||
- [ ] Run the targeted tests and get them green.
|
|
||||||
|
|
||||||
### Task 3: Add the new PushT config and smoke-test path
|
|
||||||
|
|
||||||
**Files:**
|
|
||||||
- Create: `image_pusht_diffusion_policy_dit_imf_attnres_full.yaml`
|
|
||||||
- Modify: `tests/test_pusht_swanlab_config.py`
|
|
||||||
|
|
||||||
- [ ] Add a standalone PushT image config for the AttnRes iMF variant with SwanLab online logging, `policy.backbone_type=attnres_full`, and `policy.causal_attn=false`.
|
|
||||||
- [ ] Verify `uv run python train.py --config-dir=. --config-name=image_pusht_diffusion_policy_dit_imf_attnres_full.yaml --help` succeeds.
|
|
||||||
- [ ] Run a real smoke training command with `training.debug=true`, `training.device=cuda:0`, safety overrides (`dataloader.num_workers=0`, `task.env_runner.n_envs=1`, no vis), and confirm it reaches the training loop and writes a run directory.
|
|
||||||
|
|
||||||
### Task 4: Prepare launch scripts and start the 9-run sweep
|
|
||||||
|
|
||||||
**Files:**
|
|
||||||
- Create or modify: `data/run_logs/imf_attnres_local_queue.sh`
|
|
||||||
- Create or modify locally before copy: `data/run_logs/imf_attnres_remote_gpu0_queue.sh`
|
|
||||||
- Create or modify locally before copy: `data/run_logs/imf_attnres_remote_gpu1_queue.sh`
|
|
||||||
|
|
||||||
- [ ] Write queue command templates for the 9 runs using config `image_pusht_diffusion_policy_dit_imf_attnres_full.yaml`, `training.num_epochs=350`, unique `exp_name/logging.name`, and shared `logging.group=imf_pusht_attnres_arch_sweep`.
|
|
||||||
- [ ] Sync the necessary config/model files plus remote queue scripts to `droid@100.73.14.65:~/project/diffusion_policy-smoke`.
|
|
||||||
- [ ] Start the local queue under `nohup`, record PID, and verify the first run log is advancing.
|
|
||||||
- [ ] Start the two remote queues under `nohup`, record PIDs, and verify both first-run logs are advancing.
|
|
||||||
- [ ] Confirm all three GPUs have officially entered training for the new sweep.
|
|
||||||
@@ -1,168 +0,0 @@
|
|||||||
# PushT Image DiT iMF + SwanLab Design
|
|
||||||
|
|
||||||
## Goal
|
|
||||||
Migrate the PushT image DiT experiment path from W&B to SwanLab online logging, suppress simulation video logging, then add an iMeanFlow-based one-step transformer policy for PushT image experiments and run a controlled architecture sweep over embedding width and depth using `test_mean_score` as the primary metric.
|
|
||||||
|
|
||||||
## Context
|
|
||||||
- The implementation baseline is `main`.
|
|
||||||
- The experiment path is limited to the PushT image transformer workflow; unrelated workspaces and runners should remain unchanged.
|
|
||||||
- Environment management must use the repo-local `uv` workflow.
|
|
||||||
- The trusted remote machine alias `5880` refers to `droid-system-product-name` (`droid@100.73.14.65`) and can run two GPU jobs in parallel.
|
|
||||||
|
|
||||||
## Architecture Overview
|
|
||||||
The work is split into two verified phases:
|
|
||||||
|
|
||||||
1. **Logging migration phase**
|
|
||||||
- Keep the existing PushT image DiT training behavior intact.
|
|
||||||
- Replace W&B usage with SwanLab in the transformer hybrid workspace used by PushT image DiT experiments.
|
|
||||||
- Preserve local `logs.json.txt` output.
|
|
||||||
- Ensure rollout metrics such as `test_mean_score` and per-seed rewards are still logged.
|
|
||||||
- Disable simulation video logging at both the config and runner/logging boundary.
|
|
||||||
|
|
||||||
2. **iMF migration phase**
|
|
||||||
- Keep the original diffusion-based transformer image policy available on `main`.
|
|
||||||
- Add a parallel iMF-specific model/policy/config path rather than overwriting the baseline diffusion policy.
|
|
||||||
- Reuse the existing observation encoder and training workspace where possible.
|
|
||||||
- Replace diffusion training with the iMeanFlow training objective.
|
|
||||||
- Use one-step inference for validation/rollout in the iMF path.
|
|
||||||
|
|
||||||
The implementation planning boundary for this spec is:
|
|
||||||
- code changes through a smoke-tested, pushed iMF branch
|
|
||||||
- not the full 3x3 sweep execution/monitoring workflow, which should be planned separately after the code path is verified and pushed
|
|
||||||
|
|
||||||
## Logging Design
|
|
||||||
### Scope
|
|
||||||
Only the PushT image DiT experiment chain is changed:
|
|
||||||
- `train_diffusion_transformer_hybrid_workspace.py`
|
|
||||||
- `pusht_image_runner.py`
|
|
||||||
- the new/updated PushT image transformer configs
|
|
||||||
|
|
||||||
### Behavior
|
|
||||||
- SwanLab runs in `online` mode.
|
|
||||||
- Logged values are scalar metrics only, e.g.:
|
|
||||||
- `train_loss`
|
|
||||||
- `val_loss`
|
|
||||||
- `train_action_mse_error`
|
|
||||||
- `test_mean_score`
|
|
||||||
- aggregate rollout metrics and optional per-seed scalar rewards
|
|
||||||
- No simulation videos are uploaded or wrapped as logging objects.
|
|
||||||
- Local JSON logging remains enabled for auditability and remote-job fallback debugging.
|
|
||||||
|
|
||||||
### Operational safeguards
|
|
||||||
- Default PushT experiment configs set `task.env_runner.n_test_vis=0` and `task.env_runner.n_train_vis=0`.
|
|
||||||
- The PushT image runner will not emit video objects into `log_data`, preventing accidental uploads even if visualization counts are later changed.
|
|
||||||
- SwanLab credentials are provided through the environment at runtime, not committed into the repo.
|
|
||||||
|
|
||||||
## iMF Model Design
|
|
||||||
### Baseline reuse
|
|
||||||
The iMF path reuses:
|
|
||||||
- the existing image observation encoder
|
|
||||||
- the existing action/observation normalization path
|
|
||||||
- the existing training workspace skeleton
|
|
||||||
- the existing PushT image dataset and env runner
|
|
||||||
|
|
||||||
### New files
|
|
||||||
- `diffusion_policy/model/diffusion/imf_transformer_for_diffusion.py`
|
|
||||||
- `diffusion_policy/policy/imf_transformer_hybrid_image_policy.py`
|
|
||||||
- `image_pusht_diffusion_policy_dit_imf.yaml`
|
|
||||||
|
|
||||||
### Existing files changed for the iMF path
|
|
||||||
- `diffusion_policy/workspace/train_diffusion_transformer_hybrid_workspace.py`
|
|
||||||
- logging migration to SwanLab for this experiment chain
|
|
||||||
- no structural training-loop fork beyond instantiating the configured policy and logging scalar metrics
|
|
||||||
- `diffusion_policy/env_runner/pusht_image_runner.py`
|
|
||||||
- suppress video objects in returned logs
|
|
||||||
|
|
||||||
### Model structure
|
|
||||||
The iMF transformer mirrors the current transformer policy structure closely enough to reuse known-good conditioning patterns, but it remains a **single-head model** that predicts only:
|
|
||||||
- `u`: average velocity field
|
|
||||||
|
|
||||||
The same function is reused at two evaluation points:
|
|
||||||
- canonical signature: `fn(z, r, t, cond)`
|
|
||||||
- `fn(z_t, r, t, cond)` predicts average velocity `u`
|
|
||||||
- `fn(z_t, t, t, cond)` predicts the instantaneous velocity surrogate `v`
|
|
||||||
|
|
||||||
Inputs remain conditioned on encoded observations and action trajectory tokens.
|
|
||||||
|
|
||||||
## iMF Training Objective
|
|
||||||
For a normalized action trajectory `x`, the initial implementation follows the user-provided Algorithm 1 exactly:
|
|
||||||
1. sample `t, r`
|
|
||||||
2. sample Gaussian noise `e`
|
|
||||||
3. form `z_t = (1 - t) * x + t * e`
|
|
||||||
4. predict instantaneous velocity surrogate with the same network:
|
|
||||||
- `v = fn(z_t, t, t, cond)`
|
|
||||||
5. define the JVP function exactly as:
|
|
||||||
- `g(z, r, t) = fn(z, r, t, cond)`
|
|
||||||
6. compute the primal output and JVP with tangent:
|
|
||||||
- `u, du_dt = jvp(g, (z_t, r, t), (v.detach(), 0, 1))`
|
|
||||||
7. form compound velocity:
|
|
||||||
- `V = u + (t - r) * stopgrad(du_dt)`
|
|
||||||
8. train against the average-velocity target:
|
|
||||||
- `target = e - x`
|
|
||||||
9. optimize only the masked iMF loss:
|
|
||||||
- `loss = metric(V - target)`
|
|
||||||
|
|
||||||
There is **no auxiliary `v` loss** in the initial implementation. The implementation should prefer `torch.func.jvp` and keep a safe fallback path if the local Torch stack needs it.
|
|
||||||
|
|
||||||
## iMF Inference Design
|
|
||||||
Inference uses a single step starting from noise:
|
|
||||||
- initialize `z_1 ~ N(0, I)`
|
|
||||||
- set `t = 1.0`, `r = 0.0`
|
|
||||||
- predict `u = fn(z_1, r, t, cond)`
|
|
||||||
- produce the action sample with one update:
|
|
||||||
- `x_hat = z_1 - (t - r) * u`
|
|
||||||
|
|
||||||
This matches the time direction in the reference iMeanFlow sampling logic.
|
|
||||||
|
|
||||||
## Testing Strategy
|
|
||||||
### Phase 1: logging migration smoke test
|
|
||||||
- use the repo-local `uv` environment
|
|
||||||
- run a debug/smoke PushT image DiT training job on a single GPU with:
|
|
||||||
- `training.debug=true`
|
|
||||||
- `dataloader.num_workers=0`
|
|
||||||
- `val_dataloader.num_workers=0`
|
|
||||||
- `task.env_runner.n_envs=1`
|
|
||||||
- `task.env_runner.n_test_vis=0`
|
|
||||||
- `task.env_runner.n_train_vis=0`
|
|
||||||
- verify:
|
|
||||||
- SwanLab initializes successfully
|
|
||||||
- `logs.json.txt` is populated
|
|
||||||
- rollout metrics still include `test_mean_score`
|
|
||||||
- no video logging is attempted
|
|
||||||
|
|
||||||
### Phase 2: iMF smoke test
|
|
||||||
- run an equivalent debug PushT image iMF job
|
|
||||||
- verify:
|
|
||||||
- forward/backward passes succeed
|
|
||||||
- JVP path executes on the local Torch version
|
|
||||||
- one-step inference returns correctly shaped actions
|
|
||||||
- rollout produces scalar metrics including `test_mean_score`
|
|
||||||
|
|
||||||
## Branch and Commit Strategy
|
|
||||||
1. start from a `main`-based worktree branch
|
|
||||||
2. commit the SwanLab/no-video migration after smoke verification
|
|
||||||
3. continue with the iMF implementation
|
|
||||||
4. once iMF smoke tests pass, create/preserve a dedicated feature branch for the experiment code and push it to Gitea
|
|
||||||
|
|
||||||
## Post-Implementation Experiment Plan
|
|
||||||
After the iMF path is smoke-tested and pushed, a separate experiment-execution plan should launch:
|
|
||||||
- run a 3x3 grid over:
|
|
||||||
- `n_emb ∈ {128, 256, 384}`
|
|
||||||
- `n_layer ∈ {6, 12, 18}`
|
|
||||||
- keep the rest of the setup fixed
|
|
||||||
- use a fixed single-seed setting for comparability unless a later explicit experiment plan expands that scope
|
|
||||||
- run each experiment for 300 epochs
|
|
||||||
- primary comparison metric: `test_mean_score`
|
|
||||||
|
|
||||||
## Post-Implementation Resource Allocation
|
|
||||||
The separate experiment-execution plan should schedule three concurrent runs until the matrix is complete:
|
|
||||||
- local machine: 1 GPU
|
|
||||||
- `5880`: 2 GPUs
|
|
||||||
|
|
||||||
Each run uses the same uv-managed environment and the same pushed branch so the code path is consistent across hosts.
|
|
||||||
|
|
||||||
## Risks and Mitigations
|
|
||||||
- **Torch JVP compatibility risk**: provide a fallback JVP implementation and smoke-test immediately.
|
|
||||||
- **Logging regression risk**: keep local JSON logging and verify scalar rollout metrics before moving to iMF.
|
|
||||||
- **Video/logging side effects**: disable visualizations in config and filter video objects out of runner logs.
|
|
||||||
- **Cross-host drift**: push the verified branch to Gitea before launching the experiment matrix on multiple machines.
|
|
||||||
@@ -1,107 +0,0 @@
|
|||||||
# PushT Image iMF Full-Attention Sweep Design
|
|
||||||
|
|
||||||
## Goal
|
|
||||||
在一个独立新分支上,为 PushT 图像 iMF 路线新增 **full-attention** 变体(关闭因果注意力),并按与之前相同的架构扫描网格运行 **9 组实验**,每组训练 **350 epochs**。所有实验完成后,提取每组 **`max(test_mean_score)`** 并输出完整排名和趋势总结。
|
|
||||||
|
|
||||||
## Scope
|
|
||||||
本次工作仅覆盖:
|
|
||||||
1. 在不影响现有因果版 iMF 路线的前提下,新增 full-attention 实验链路;
|
|
||||||
2. 对 `n_emb ∈ {128, 256, 384}` 与 `n_layer ∈ {6, 12, 18}` 的 9 组组合做 350-epoch 扫描;
|
|
||||||
3. 在本机 5090 与 5880 双卡上做三路并行调度;
|
|
||||||
4. 在全部实验完成后自动汇总结果并直接向用户汇报。
|
|
||||||
|
|
||||||
不在本次范围内:
|
|
||||||
- 不替换或删除现有因果版 iMF 配置;
|
|
||||||
- 不改动已有 DiT baseline 实现;
|
|
||||||
- 不做多 seed 扩展;
|
|
||||||
- 不额外增加视频记录。
|
|
||||||
|
|
||||||
## Design Choice
|
|
||||||
采用“**新增独立配置 + 新分支**”的方式,而不是覆盖现有 iMF 默认配置。
|
|
||||||
|
|
||||||
原因:
|
|
||||||
- 现有因果版 iMF 已完成实验与结果记录,保持不动更利于对照;
|
|
||||||
- full-attention 作为新的实验链路,使用独立配置更易复现;
|
|
||||||
- 运行时只需要通过配置切换 `policy.causal_attn=false`,不需要重新设计 iMF 算法本身。
|
|
||||||
|
|
||||||
## Configuration Design
|
|
||||||
新增一个独立配置文件,例如:
|
|
||||||
- `image_pusht_diffusion_policy_dit_imf_fullattn.yaml`
|
|
||||||
|
|
||||||
其职责:
|
|
||||||
- 继承当前 PushT image iMF 配置链路;
|
|
||||||
- 保持 iMF 单步推理、SwanLab 标量记录、无视频记录;
|
|
||||||
- 显式设置:
|
|
||||||
- `policy.causal_attn=false`
|
|
||||||
- `policy.n_head=1`
|
|
||||||
- 保持其余 iMF 训练语义不变。
|
|
||||||
|
|
||||||
SwanLab 命名延续当前修复后的策略:
|
|
||||||
- `logging.name=${exp_name}`
|
|
||||||
- `logging.resume=false`
|
|
||||||
- `logging.id=null`
|
|
||||||
- `logging.group=${exp_name}` 或统一 sweep group override
|
|
||||||
|
|
||||||
## Code Change Strategy
|
|
||||||
优先最小改动:
|
|
||||||
- 若当前 `IMFTransformerForDiffusion` 已支持 `causal_attn=False` 分支,则不改核心算法,仅通过新配置关闭因果 mask;
|
|
||||||
- 如需补充回归验证,则新增针对 full-attention 配置/掩码行为的最小测试;
|
|
||||||
- 不改变已有因果版实验配置和已有测试语义。
|
|
||||||
|
|
||||||
## Experiment Matrix
|
|
||||||
实验网格固定为:
|
|
||||||
|
|
||||||
- `n_emb=128, n_layer=6`
|
|
||||||
- `n_emb=128, n_layer=12`
|
|
||||||
- `n_emb=128, n_layer=18`
|
|
||||||
- `n_emb=256, n_layer=6`
|
|
||||||
- `n_emb=256, n_layer=12`
|
|
||||||
- `n_emb=256, n_layer=18`
|
|
||||||
- `n_emb=384, n_layer=6`
|
|
||||||
- `n_emb=384, n_layer=12`
|
|
||||||
- `n_emb=384, n_layer=18`
|
|
||||||
|
|
||||||
统一设置:
|
|
||||||
- `training.num_epochs=350`
|
|
||||||
- `training.resume=false`
|
|
||||||
- `seed=42`
|
|
||||||
- PushT image 数据路径不变
|
|
||||||
- 指标以 **`logs.json.txt` 中 `test_mean_score` 的最大值** 为准
|
|
||||||
|
|
||||||
## Scheduling Design
|
|
||||||
使用三路串行队列并行执行 9 个实验:
|
|
||||||
|
|
||||||
- 本机 5090:1 个顺序队列
|
|
||||||
- 5880 GPU0:1 个顺序队列
|
|
||||||
- 5880 GPU1:1 个顺序队列
|
|
||||||
|
|
||||||
分配原则:
|
|
||||||
- 延续按 `n_emb × n_layer` 近似平衡工作量;
|
|
||||||
- 每张卡同一时刻只跑 1 个实验;
|
|
||||||
- 队列脚本负责“前一个结束后自动启动下一个”。
|
|
||||||
|
|
||||||
## Monitoring Design
|
|
||||||
继续采用“**训练队列脚本 + 监控 agent**”双层机制:
|
|
||||||
|
|
||||||
1. **实际调度**由本地/远端队列脚本负责;
|
|
||||||
2. **监控**由一个 xhigh 子 agent 轮询:
|
|
||||||
- 读取 pid 状态
|
|
||||||
- 检查 master log
|
|
||||||
- 检查每个 run 的 `logs.json.txt`
|
|
||||||
- 判断是否卡死/失败/全部完成
|
|
||||||
3. 一旦全部完成,监控 agent 直接返回:
|
|
||||||
- 9 组实验的最终 epoch
|
|
||||||
- 每组 `max(test_mean_score)`
|
|
||||||
- 排名表
|
|
||||||
- embedding / layer 趋势总结
|
|
||||||
|
|
||||||
本次要求下,agent 在收到全部完成信号后应直接向主会话回报结果,不等待用户再次提醒。
|
|
||||||
|
|
||||||
## Success Criteria
|
|
||||||
满足以下条件即视为完成:
|
|
||||||
1. full-attention iMF 配置在新分支上可运行;
|
|
||||||
2. 9 组 350-epoch 实验全部完成;
|
|
||||||
3. 不记录仿真视频,只记录标量;
|
|
||||||
4. SwanLab 运行命名不冲突;
|
|
||||||
5. 输出 9 组实验 `max(test_mean_score)` 的完整汇总与结论;
|
|
||||||
6. 全部实验结束后主会话可直接给用户最终总结。
|
|
||||||
@@ -1,108 +0,0 @@
|
|||||||
# PushT Image iMF AttnRes Design
|
|
||||||
|
|
||||||
## Goal
|
|
||||||
在现有 PushT 图像 iMF full-attention 路线之上,引入 `attn_res` 仓库中的 **Full AttnRes** 残差聚合形式,并同步使用与其匹配的 **RMSNorm + 自注意力 + SwiGLU FFN** 模块,保持 iMF 训练目标与一步推理语义不变,仅作用于本次实验链路。实现完成并验证后,启动与此前相同的 9 组 `n_emb × n_layer` 扫描(350 epochs, seed=42, SwanLab online, 无视频记录)。
|
|
||||||
|
|
||||||
## Scope
|
|
||||||
本次工作仅覆盖:
|
|
||||||
1. 为 `IMFTransformerForDiffusion` 增加一个 AttnRes-backed backbone 变体;
|
|
||||||
2. 保持 `forward(sample, r, t, cond=None)`、iMF loss、一步推理策略接口不变;
|
|
||||||
3. 新增独立 PushT 图像配置用于该变体;
|
|
||||||
4. 复用本地 5090 + 远端 5880 双卡三路并行调度 9 组实验。
|
|
||||||
|
|
||||||
不在范围内:
|
|
||||||
- 不替换已有 vanilla iMF/full-attn 配置;
|
|
||||||
- 不修改 DiT baseline;
|
|
||||||
- 不增加视频日志;
|
|
||||||
- 不扩大到多 seed。
|
|
||||||
|
|
||||||
## Recommended Approach
|
|
||||||
采用“**在当前 iMF 模型内增加可选 AttnRes backbone**”的方式,而不是新建独立 policy 链路。
|
|
||||||
|
|
||||||
理由:
|
|
||||||
- policy / workspace / loss / sampling 路径已经被验证,保留这些路径可最大程度缩小变动面;
|
|
||||||
- 仅在模型内部切换 backbone,可以让新实验与既有 iMF 结果保持可比;
|
|
||||||
- 配置上只需显式打开 `backbone_type=attnres_full`、`causal_attn=false` 等开关,复现实验更直接。
|
|
||||||
|
|
||||||
## Architecture
|
|
||||||
### 1. Backbone split
|
|
||||||
`IMFTransformerForDiffusion` 保留现有 vanilla encoder/decoder 实现为默认路径,并新增 `attnres_full` 路径:
|
|
||||||
- **vanilla**:保持当前实现不变;
|
|
||||||
- **attnres_full**:使用单栈式全注意力 Transformer,输入 token 序列为
|
|
||||||
`[r token, t token, obs cond tokens..., action/sample tokens...]`。
|
|
||||||
|
|
||||||
模型只对末尾的 action/sample token 位置输出 `u` 预测,前置条件 token 仅参与上下文建模。
|
|
||||||
|
|
||||||
### 2. AttnRes stack
|
|
||||||
新 backbone 使用以下模块:
|
|
||||||
- `RMSNorm`
|
|
||||||
- `Rotary Position Embedding`(用于自注意力 q/k)
|
|
||||||
- `GroupedQueryAttention`(本实验默认 `n_kv_head=1`,与单头配置兼容)
|
|
||||||
- `SwiGLU` FFN
|
|
||||||
- `AttnResOperator`(每个子层一个 pseudo-query,执行 full depth-wise residual aggregation)
|
|
||||||
|
|
||||||
每个 transformer block 由两个子层组成:
|
|
||||||
1. self-attention 子层
|
|
||||||
2. FFN 子层
|
|
||||||
|
|
||||||
每个子层的输入不再是简单 `x + f(x)`,而是从 embedding 与全部历史子层输出中通过 Full AttnRes 聚合得到 `h_l`,再执行 `RMSNorm(h_l) -> sublayer_fn(...)`。
|
|
||||||
|
|
||||||
### 3. Conditioning and token flow
|
|
||||||
- `sample` 先经 `input_emb` 映射为 action tokens;
|
|
||||||
- `r` 和 `t` 各自经 `SinusoidalPosEmb + linear` 映射为两个条件 token;
|
|
||||||
- 图像观测编码后的 `cond` 通过 `cond_obs_emb` 映射为 obs tokens;
|
|
||||||
- 拼接后的完整 token 序列进入 AttnRes stack;
|
|
||||||
- 输出时切掉前置条件 token,仅保留 action/sample token 段,随后经 `RMSNorm + head` 得到最终 `u`。
|
|
||||||
|
|
||||||
### 4. Attention mode
|
|
||||||
本次实验链路固定为 **non-causal full attention**:
|
|
||||||
- `causal_attn=false`
|
|
||||||
- 不构造 causal mask
|
|
||||||
- 所有 token 可彼此双向可见
|
|
||||||
|
|
||||||
这与用户指定的“训练过程仍然使用全注意力(不加因果注意)”一致。
|
|
||||||
|
|
||||||
## Config and Logging
|
|
||||||
新增独立配置文件,例如:
|
|
||||||
- `image_pusht_diffusion_policy_dit_imf_attnres_full.yaml`
|
|
||||||
|
|
||||||
该配置需要:
|
|
||||||
- 指向现有 `IMFTransformerHybridImagePolicy`
|
|
||||||
- 显式开启 AttnRes backbone 相关参数
|
|
||||||
- 设置 `policy.causal_attn=false`
|
|
||||||
- 保持 `logging.backend=swanlab`、`logging.mode=online`
|
|
||||||
- 运行时通过覆盖保证:
|
|
||||||
- `logging.name=<unique_run_name>`
|
|
||||||
- `logging.group=imf_pusht_attnres_arch_sweep`
|
|
||||||
- `exp_name=<unique_run_name>`
|
|
||||||
- 保持 `task.env_runner.n_test_vis=0` 与 `n_train_vis=0`,仅记录标量
|
|
||||||
|
|
||||||
## Experiment Matrix
|
|
||||||
固定 9 组:
|
|
||||||
- `n_emb ∈ {128, 256, 384}`
|
|
||||||
- `n_layer ∈ {6, 12, 18}`
|
|
||||||
- `seed=42`
|
|
||||||
- `training.num_epochs=350`
|
|
||||||
|
|
||||||
## Scheduling
|
|
||||||
沿用之前验证过的三队列分配:
|
|
||||||
- 本机 5090:`384x18`, `256x6`, `128x6`
|
|
||||||
- 5880 GPU0:`384x12`, `256x12`, `128x12`
|
|
||||||
- 5880 GPU1:`384x6`, `256x18`, `128x18`
|
|
||||||
|
|
||||||
每个 run name 编码 backbone 与结构,例如:
|
|
||||||
`imf_attnres_emb256_layer12_seed42_5880gpu0`
|
|
||||||
|
|
||||||
## Verification
|
|
||||||
实现阶段至少验证:
|
|
||||||
1. 新配置的 SwanLab 命名与 `causal_attn=false` 正确;
|
|
||||||
2. 新 backbone 的 forward shape 与 `configure_optimizers()` 可用;
|
|
||||||
3. 旧 vanilla 路径测试不回归;
|
|
||||||
4. `training.debug=true` smoke run 可以完整通过。
|
|
||||||
|
|
||||||
## Success Criteria
|
|
||||||
1. 新 AttnRes iMF 变体在本分支可训练、可一步推理;
|
|
||||||
2. 不影响已有 vanilla iMF/full-attn 链路;
|
|
||||||
3. 9 组实验成功在三张卡上正式启动;
|
|
||||||
4. SwanLab run 名称唯一,无冲突;
|
|
||||||
5. 不记录视频,仅记录标量。
|
|
||||||
@@ -1,30 +0,0 @@
|
|||||||
defaults:
|
|
||||||
- diffusion_policy/config/train_diffusion_transformer_hybrid_workspace@_here_
|
|
||||||
- override /diffusion_policy/config/task@task: pusht_image
|
|
||||||
- _self_
|
|
||||||
|
|
||||||
exp_name: pusht_image_dit
|
|
||||||
|
|
||||||
policy:
|
|
||||||
_target_: diffusion_policy.policy.diffusion_transformer_hybrid_image_policy.DiffusionTransformerHybridImagePolicy
|
|
||||||
|
|
||||||
logging:
|
|
||||||
backend: swanlab
|
|
||||||
mode: online
|
|
||||||
name: ${exp_name}
|
|
||||||
resume: false
|
|
||||||
tags: ["${name}", "${task_name}", "${exp_name}", "swanlab"]
|
|
||||||
id: null
|
|
||||||
group: ${exp_name}
|
|
||||||
|
|
||||||
dataloader:
|
|
||||||
num_workers: 0
|
|
||||||
|
|
||||||
val_dataloader:
|
|
||||||
num_workers: 0
|
|
||||||
|
|
||||||
task:
|
|
||||||
env_runner:
|
|
||||||
n_envs: 1
|
|
||||||
n_test_vis: 0
|
|
||||||
n_train_vis: 0
|
|
||||||
@@ -1,32 +0,0 @@
|
|||||||
defaults:
|
|
||||||
- diffusion_policy/config/train_diffusion_transformer_hybrid_workspace@_here_
|
|
||||||
- override /diffusion_policy/config/task@task: pusht_image
|
|
||||||
- _self_
|
|
||||||
|
|
||||||
exp_name: pusht_image_dit_imf
|
|
||||||
|
|
||||||
policy:
|
|
||||||
_target_: diffusion_policy.policy.imf_transformer_hybrid_image_policy.IMFTransformerHybridImagePolicy
|
|
||||||
num_inference_steps: 1
|
|
||||||
n_head: 1
|
|
||||||
|
|
||||||
logging:
|
|
||||||
backend: swanlab
|
|
||||||
mode: online
|
|
||||||
name: ${exp_name}
|
|
||||||
resume: false
|
|
||||||
tags: ["${name}", "${task_name}", "${exp_name}", "swanlab"]
|
|
||||||
id: null
|
|
||||||
group: ${exp_name}
|
|
||||||
|
|
||||||
dataloader:
|
|
||||||
num_workers: 0
|
|
||||||
|
|
||||||
val_dataloader:
|
|
||||||
num_workers: 0
|
|
||||||
|
|
||||||
task:
|
|
||||||
env_runner:
|
|
||||||
n_envs: 1
|
|
||||||
n_test_vis: 0
|
|
||||||
n_train_vis: 0
|
|
||||||
@@ -1,35 +0,0 @@
|
|||||||
defaults:
|
|
||||||
- diffusion_policy/config/train_diffusion_transformer_hybrid_workspace@_here_
|
|
||||||
- override /diffusion_policy/config/task@task: pusht_image
|
|
||||||
- _self_
|
|
||||||
|
|
||||||
exp_name: pusht_image_dit_imf_attnres_full
|
|
||||||
|
|
||||||
policy:
|
|
||||||
_target_: diffusion_policy.policy.imf_transformer_hybrid_image_policy.IMFTransformerHybridImagePolicy
|
|
||||||
num_inference_steps: 1
|
|
||||||
n_head: 1
|
|
||||||
n_kv_head: 1
|
|
||||||
causal_attn: false
|
|
||||||
backbone_type: attnres_full
|
|
||||||
|
|
||||||
logging:
|
|
||||||
backend: swanlab
|
|
||||||
mode: online
|
|
||||||
name: ${exp_name}
|
|
||||||
resume: false
|
|
||||||
tags: ["${name}", "${task_name}", "${exp_name}", "swanlab"]
|
|
||||||
id: null
|
|
||||||
group: ${exp_name}
|
|
||||||
|
|
||||||
dataloader:
|
|
||||||
num_workers: 0
|
|
||||||
|
|
||||||
val_dataloader:
|
|
||||||
num_workers: 0
|
|
||||||
|
|
||||||
task:
|
|
||||||
env_runner:
|
|
||||||
n_envs: 1
|
|
||||||
n_test_vis: 0
|
|
||||||
n_train_vis: 0
|
|
||||||
@@ -1,33 +0,0 @@
|
|||||||
defaults:
|
|
||||||
- diffusion_policy/config/train_diffusion_transformer_hybrid_workspace@_here_
|
|
||||||
- override /diffusion_policy/config/task@task: pusht_image
|
|
||||||
- _self_
|
|
||||||
|
|
||||||
exp_name: pusht_image_dit_imf_fullattn
|
|
||||||
|
|
||||||
policy:
|
|
||||||
_target_: diffusion_policy.policy.imf_transformer_hybrid_image_policy.IMFTransformerHybridImagePolicy
|
|
||||||
num_inference_steps: 1
|
|
||||||
n_head: 1
|
|
||||||
causal_attn: false
|
|
||||||
|
|
||||||
logging:
|
|
||||||
backend: swanlab
|
|
||||||
mode: online
|
|
||||||
name: ${exp_name}
|
|
||||||
resume: false
|
|
||||||
tags: ["${name}", "${task_name}", "${exp_name}", "swanlab"]
|
|
||||||
id: null
|
|
||||||
group: ${exp_name}
|
|
||||||
|
|
||||||
dataloader:
|
|
||||||
num_workers: 0
|
|
||||||
|
|
||||||
val_dataloader:
|
|
||||||
num_workers: 0
|
|
||||||
|
|
||||||
task:
|
|
||||||
env_runner:
|
|
||||||
n_envs: 1
|
|
||||||
n_test_vis: 0
|
|
||||||
n_train_vis: 0
|
|
||||||
189
image_pusht_diffusion_policy_dit_pmf.yaml
Normal file
189
image_pusht_diffusion_policy_dit_pmf.yaml
Normal file
@@ -0,0 +1,189 @@
|
|||||||
|
_target_: diffusion_policy.workspace.train_diffusion_transformer_hybrid_workspace.TrainDiffusionTransformerHybridWorkspace
|
||||||
|
checkpoint:
|
||||||
|
save_last_ckpt: true
|
||||||
|
save_last_snapshot: false
|
||||||
|
topk:
|
||||||
|
format_str: epoch={epoch:04d}-train_loss={train_loss:.3f}.ckpt
|
||||||
|
k: 5
|
||||||
|
mode: min
|
||||||
|
monitor_key: train_loss
|
||||||
|
dataloader:
|
||||||
|
batch_size: 64
|
||||||
|
num_workers: 8
|
||||||
|
persistent_workers: false
|
||||||
|
pin_memory: true
|
||||||
|
shuffle: true
|
||||||
|
dataset_obs_steps: 2
|
||||||
|
ema:
|
||||||
|
_target_: diffusion_policy.model.diffusion.ema_model.EMAModel
|
||||||
|
inv_gamma: 1.0
|
||||||
|
max_value: 0.9999
|
||||||
|
min_value: 0.0
|
||||||
|
power: 0.75
|
||||||
|
update_after_step: 0
|
||||||
|
exp_name: default
|
||||||
|
horizon: 16
|
||||||
|
keypoint_visible_rate: 1.0
|
||||||
|
logging:
|
||||||
|
group: null
|
||||||
|
id: null
|
||||||
|
mode: online
|
||||||
|
name: ${now:%Y.%m.%d-%H.%M.%S}_train_diffusion_transformer_hybrid_pmf_pusht_image
|
||||||
|
project: diffusion_policy_debug
|
||||||
|
resume: true
|
||||||
|
tags:
|
||||||
|
- train_diffusion_transformer_hybrid_pmf
|
||||||
|
- pusht_image
|
||||||
|
- default
|
||||||
|
multi_run:
|
||||||
|
run_dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_train_diffusion_transformer_hybrid_pmf_pusht_image
|
||||||
|
wandb_name_base: ${now:%Y.%m.%d-%H.%M.%S}_train_diffusion_transformer_hybrid_pmf_pusht_image
|
||||||
|
n_action_steps: 8
|
||||||
|
n_latency_steps: 0
|
||||||
|
n_obs_steps: 2
|
||||||
|
name: train_diffusion_transformer_hybrid_pmf
|
||||||
|
obs_as_cond: true
|
||||||
|
optimizer:
|
||||||
|
betas:
|
||||||
|
- 0.9
|
||||||
|
- 0.95
|
||||||
|
learning_rate: 0.0001
|
||||||
|
obs_encoder_weight_decay: 1.0e-06
|
||||||
|
transformer_weight_decay: 0.001
|
||||||
|
past_action_visible: false
|
||||||
|
policy:
|
||||||
|
_target_: diffusion_policy.policy.pmf_transformer_hybrid_image_policy.PMFTransformerHybridImagePolicy
|
||||||
|
crop_shape:
|
||||||
|
- 84
|
||||||
|
- 84
|
||||||
|
eval_fixed_crop: true
|
||||||
|
horizon: 16
|
||||||
|
n_action_steps: 8
|
||||||
|
n_cond_layers: 0
|
||||||
|
n_emb: 256
|
||||||
|
n_head: 4
|
||||||
|
n_layer: 12
|
||||||
|
n_obs_steps: 2
|
||||||
|
n_time_tokens: 4
|
||||||
|
noise_scale: 1.0
|
||||||
|
adatloss_eps: 0.01
|
||||||
|
p_mean: -0.4
|
||||||
|
p_std: 1.0
|
||||||
|
noise_scheduler:
|
||||||
|
_target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
|
||||||
|
beta_end: 0.02
|
||||||
|
beta_schedule: squaredcos_cap_v2
|
||||||
|
beta_start: 0.0001
|
||||||
|
clip_sample: true
|
||||||
|
num_train_timesteps: 100
|
||||||
|
prediction_type: sample
|
||||||
|
variance_type: fixed_small
|
||||||
|
num_inference_steps: 1
|
||||||
|
obs_as_cond: true
|
||||||
|
obs_encoder_group_norm: true
|
||||||
|
p_drop_attn: 0.0
|
||||||
|
p_drop_emb: 0.0
|
||||||
|
pmf_u_loss_weight: 1.0
|
||||||
|
pmf_v_loss_weight: 1.0
|
||||||
|
tr_uniform: true
|
||||||
|
tr_uniform_prob: 0.1
|
||||||
|
data_proportion: 0.5
|
||||||
|
shape_meta:
|
||||||
|
action:
|
||||||
|
shape:
|
||||||
|
- 2
|
||||||
|
obs:
|
||||||
|
agent_pos:
|
||||||
|
shape:
|
||||||
|
- 2
|
||||||
|
type: low_dim
|
||||||
|
image:
|
||||||
|
shape:
|
||||||
|
- 3
|
||||||
|
- 96
|
||||||
|
- 96
|
||||||
|
type: rgb
|
||||||
|
shape_meta:
|
||||||
|
action:
|
||||||
|
shape:
|
||||||
|
- 2
|
||||||
|
obs:
|
||||||
|
agent_pos:
|
||||||
|
shape:
|
||||||
|
- 2
|
||||||
|
type: low_dim
|
||||||
|
image:
|
||||||
|
shape:
|
||||||
|
- 3
|
||||||
|
- 96
|
||||||
|
- 96
|
||||||
|
type: rgb
|
||||||
|
task:
|
||||||
|
dataset:
|
||||||
|
_target_: diffusion_policy.dataset.pusht_image_dataset.PushTImageDataset
|
||||||
|
horizon: 16
|
||||||
|
max_train_episodes: 90
|
||||||
|
pad_after: 7
|
||||||
|
pad_before: 1
|
||||||
|
seed: 42
|
||||||
|
val_ratio: 0.02
|
||||||
|
zarr_path: data/pusht/pusht_cchi_v7_replay.zarr
|
||||||
|
env_runner:
|
||||||
|
_target_: diffusion_policy.env_runner.pusht_image_runner.PushTImageRunner
|
||||||
|
fps: 10
|
||||||
|
legacy_test: true
|
||||||
|
max_steps: 300
|
||||||
|
n_action_steps: 8
|
||||||
|
n_envs: null
|
||||||
|
n_obs_steps: 2
|
||||||
|
n_test: 50
|
||||||
|
n_test_vis: 4
|
||||||
|
n_train: 6
|
||||||
|
n_train_vis: 2
|
||||||
|
past_action: false
|
||||||
|
test_start_seed: 100000
|
||||||
|
train_start_seed: 0
|
||||||
|
image_shape:
|
||||||
|
- 3
|
||||||
|
- 96
|
||||||
|
- 96
|
||||||
|
name: pusht_image
|
||||||
|
shape_meta:
|
||||||
|
action:
|
||||||
|
shape:
|
||||||
|
- 2
|
||||||
|
obs:
|
||||||
|
agent_pos:
|
||||||
|
shape:
|
||||||
|
- 2
|
||||||
|
type: low_dim
|
||||||
|
image:
|
||||||
|
shape:
|
||||||
|
- 3
|
||||||
|
- 96
|
||||||
|
- 96
|
||||||
|
type: rgb
|
||||||
|
task_name: pusht_image
|
||||||
|
training:
|
||||||
|
checkpoint_every: 50
|
||||||
|
debug: false
|
||||||
|
device: cuda:0
|
||||||
|
gradient_accumulate_every: 1
|
||||||
|
lr_scheduler: cosine
|
||||||
|
lr_warmup_steps: 500
|
||||||
|
max_train_steps: null
|
||||||
|
max_val_steps: null
|
||||||
|
num_epochs: 600
|
||||||
|
resume: true
|
||||||
|
rollout_every: 50
|
||||||
|
sample_every: 5
|
||||||
|
seed: 42
|
||||||
|
tqdm_interval_sec: 1.0
|
||||||
|
use_ema: true
|
||||||
|
val_every: 1
|
||||||
|
val_dataloader:
|
||||||
|
batch_size: 64
|
||||||
|
num_workers: 8
|
||||||
|
persistent_workers: false
|
||||||
|
pin_memory: true
|
||||||
|
shuffle: false
|
||||||
@@ -22,6 +22,7 @@ pymunk==6.2.1
|
|||||||
wandb==0.13.3
|
wandb==0.13.3
|
||||||
threadpoolctl==3.1.0
|
threadpoolctl==3.1.0
|
||||||
shapely==1.8.5.post1
|
shapely==1.8.5.post1
|
||||||
|
matplotlib==3.6.1
|
||||||
imageio==2.22.0
|
imageio==2.22.0
|
||||||
imageio-ffmpeg==0.4.7
|
imageio-ffmpeg==0.4.7
|
||||||
termcolor==2.0.1
|
termcolor==2.0.1
|
||||||
@@ -36,4 +37,3 @@ av==14.0.1
|
|||||||
pygame==2.5.2
|
pygame==2.5.2
|
||||||
robomimic==0.2.0
|
robomimic==0.2.0
|
||||||
opencv-python-headless==4.10.0.84
|
opencv-python-headless==4.10.0.84
|
||||||
swanlab
|
|
||||||
|
|||||||
@@ -1,29 +0,0 @@
|
|||||||
#!/usr/bin/env bash
|
|
||||||
set -euo pipefail
|
|
||||||
cd /home/droid/project/diffusion_policy/.worktrees/feat-pusht-imf-attnres
|
|
||||||
export PYTHONUNBUFFERED=1
|
|
||||||
export SWANLAB_API_KEY='PSZrBMLx1XAjDjvmhUcNz'
|
|
||||||
export LD_LIBRARY_PATH="$(printf '%s:' .venv/lib/python3.9/site-packages/nvidia/*/lib | sed 's/:$//')"
|
|
||||||
run_exp() {
|
|
||||||
local name="$1" emb="$2" layer="$3"
|
|
||||||
echo "[$(date '+%F %T')] START $name emb=$emb layer=$layer"
|
|
||||||
.venv/bin/python train.py \
|
|
||||||
--config-dir=. \
|
|
||||||
--config-name=image_pusht_diffusion_policy_dit_imf_attnres_full.yaml \
|
|
||||||
training.device=cuda:0 \
|
|
||||||
training.num_epochs=350 \
|
|
||||||
training.resume=false \
|
|
||||||
exp_name="$name" \
|
|
||||||
logging.group=imf_pusht_attnres_arch_sweep \
|
|
||||||
logging.name="$name" \
|
|
||||||
logging.resume=false \
|
|
||||||
logging.id=null \
|
|
||||||
hydra.run.dir="data/outputs/$name" \
|
|
||||||
policy.n_emb="$emb" \
|
|
||||||
policy.n_layer="$layer" \
|
|
||||||
> "data/run_logs/${name}.log" 2>&1
|
|
||||||
echo "[$(date '+%F %T')] END $name"
|
|
||||||
}
|
|
||||||
run_exp imf_attnres_emb384_layer18_seed42_local 384 18
|
|
||||||
run_exp imf_attnres_emb256_layer6_seed42_local 256 6
|
|
||||||
run_exp imf_attnres_emb128_layer6_seed42_local 128 6
|
|
||||||
@@ -1,29 +0,0 @@
|
|||||||
#!/usr/bin/env bash
|
|
||||||
set -euo pipefail
|
|
||||||
cd /home/droid/project/diffusion_policy-smoke
|
|
||||||
export PYTHONUNBUFFERED=1
|
|
||||||
export SWANLAB_API_KEY='PSZrBMLx1XAjDjvmhUcNz'
|
|
||||||
export LD_LIBRARY_PATH="$(printf '%s:' .venv/lib/python3.9/site-packages/nvidia/*/lib | sed 's/:$//')"
|
|
||||||
run_exp() {
|
|
||||||
local name="$1" emb="$2" layer="$3"
|
|
||||||
echo "[$(date '+%F %T')] START $name emb=$emb layer=$layer"
|
|
||||||
.venv/bin/python train.py \
|
|
||||||
--config-dir=. \
|
|
||||||
--config-name=image_pusht_diffusion_policy_dit_imf_attnres_full.yaml \
|
|
||||||
training.device=cuda:0 \
|
|
||||||
training.num_epochs=350 \
|
|
||||||
training.resume=false \
|
|
||||||
exp_name="$name" \
|
|
||||||
logging.group=imf_pusht_attnres_arch_sweep \
|
|
||||||
logging.name="$name" \
|
|
||||||
logging.resume=false \
|
|
||||||
logging.id=null \
|
|
||||||
hydra.run.dir="data/outputs/$name" \
|
|
||||||
policy.n_emb="$emb" \
|
|
||||||
policy.n_layer="$layer" \
|
|
||||||
> "data/run_logs/${name}.log" 2>&1
|
|
||||||
echo "[$(date '+%F %T')] END $name"
|
|
||||||
}
|
|
||||||
run_exp imf_attnres_emb384_layer12_seed42_5880gpu0 384 12
|
|
||||||
run_exp imf_attnres_emb256_layer12_seed42_5880gpu0 256 12
|
|
||||||
run_exp imf_attnres_emb128_layer12_seed42_5880gpu0 128 12
|
|
||||||
@@ -1,29 +0,0 @@
|
|||||||
#!/usr/bin/env bash
|
|
||||||
set -euo pipefail
|
|
||||||
cd /home/droid/project/diffusion_policy-smoke
|
|
||||||
export PYTHONUNBUFFERED=1
|
|
||||||
export SWANLAB_API_KEY='PSZrBMLx1XAjDjvmhUcNz'
|
|
||||||
export LD_LIBRARY_PATH="$(printf '%s:' .venv/lib/python3.9/site-packages/nvidia/*/lib | sed 's/:$//')"
|
|
||||||
run_exp() {
|
|
||||||
local name="$1" emb="$2" layer="$3"
|
|
||||||
echo "[$(date '+%F %T')] START $name emb=$emb layer=$layer"
|
|
||||||
.venv/bin/python train.py \
|
|
||||||
--config-dir=. \
|
|
||||||
--config-name=image_pusht_diffusion_policy_dit_imf_attnres_full.yaml \
|
|
||||||
training.device=cuda:1 \
|
|
||||||
training.num_epochs=350 \
|
|
||||||
training.resume=false \
|
|
||||||
exp_name="$name" \
|
|
||||||
logging.group=imf_pusht_attnres_arch_sweep \
|
|
||||||
logging.name="$name" \
|
|
||||||
logging.resume=false \
|
|
||||||
logging.id=null \
|
|
||||||
hydra.run.dir="data/outputs/$name" \
|
|
||||||
policy.n_emb="$emb" \
|
|
||||||
policy.n_layer="$layer" \
|
|
||||||
> "data/run_logs/${name}.log" 2>&1
|
|
||||||
echo "[$(date '+%F %T')] END $name"
|
|
||||||
}
|
|
||||||
run_exp imf_attnres_emb384_layer6_seed42_5880gpu1 384 6
|
|
||||||
run_exp imf_attnres_emb256_layer18_seed42_5880gpu1 256 18
|
|
||||||
run_exp imf_attnres_emb128_layer18_seed42_5880gpu1 128 18
|
|
||||||
@@ -1,77 +0,0 @@
|
|||||||
import inspect
|
|
||||||
import pathlib
|
|
||||||
import sys
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
ROOT_DIR = pathlib.Path(__file__).resolve().parents[1]
|
|
||||||
if str(ROOT_DIR) not in sys.path:
|
|
||||||
sys.path.append(str(ROOT_DIR))
|
|
||||||
|
|
||||||
from diffusion_policy.model.diffusion.imf_transformer_for_diffusion import ( # noqa: E402
|
|
||||||
IMFTransformerForDiffusion,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_imf_transformer_forward_signature_and_shape_single_head():
|
|
||||||
signature = inspect.signature(IMFTransformerForDiffusion.forward)
|
|
||||||
assert list(signature.parameters)[:5] == ['self', 'sample', 'r', 't', 'cond']
|
|
||||||
assert signature.parameters['cond'].default is None
|
|
||||||
|
|
||||||
model = IMFTransformerForDiffusion(
|
|
||||||
input_dim=3,
|
|
||||||
output_dim=3,
|
|
||||||
horizon=5,
|
|
||||||
n_obs_steps=2,
|
|
||||||
cond_dim=4,
|
|
||||||
n_layer=1,
|
|
||||||
n_head=1,
|
|
||||||
n_emb=16,
|
|
||||||
p_drop_emb=0.0,
|
|
||||||
p_drop_attn=0.0,
|
|
||||||
causal_attn=True,
|
|
||||||
time_as_cond=True,
|
|
||||||
obs_as_cond=True,
|
|
||||||
n_cond_layers=0,
|
|
||||||
)
|
|
||||||
model.configure_optimizers()
|
|
||||||
|
|
||||||
sample = torch.randn(2, 5, 3)
|
|
||||||
r = torch.rand(2)
|
|
||||||
t = torch.rand(2)
|
|
||||||
cond = torch.randn(2, 2, 4)
|
|
||||||
|
|
||||||
pred_u = model(sample, r, t, cond=cond)
|
|
||||||
|
|
||||||
assert pred_u.shape == sample.shape
|
|
||||||
|
|
||||||
|
|
||||||
def test_imf_transformer_attnres_full_backbone_forward_shape_and_optimizer():
|
|
||||||
model = IMFTransformerForDiffusion(
|
|
||||||
input_dim=3,
|
|
||||||
output_dim=3,
|
|
||||||
horizon=5,
|
|
||||||
n_obs_steps=2,
|
|
||||||
cond_dim=4,
|
|
||||||
n_layer=2,
|
|
||||||
n_head=1,
|
|
||||||
n_emb=16,
|
|
||||||
p_drop_emb=0.0,
|
|
||||||
p_drop_attn=0.0,
|
|
||||||
causal_attn=False,
|
|
||||||
time_as_cond=True,
|
|
||||||
obs_as_cond=True,
|
|
||||||
n_cond_layers=0,
|
|
||||||
backbone_type='attnres_full',
|
|
||||||
)
|
|
||||||
optimizer = model.configure_optimizers()
|
|
||||||
|
|
||||||
sample = torch.randn(2, 5, 3)
|
|
||||||
r = torch.rand(2)
|
|
||||||
t = torch.rand(2)
|
|
||||||
cond = torch.randn(2, 2, 4)
|
|
||||||
|
|
||||||
pred_u = model(sample, r, t, cond=cond)
|
|
||||||
|
|
||||||
assert pred_u.shape == sample.shape
|
|
||||||
assert optimizer is not None
|
|
||||||
@@ -1,313 +0,0 @@
|
|||||||
import pathlib
|
|
||||||
import sys
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
ROOT_DIR = pathlib.Path(__file__).resolve().parents[1]
|
|
||||||
if str(ROOT_DIR) not in sys.path:
|
|
||||||
sys.path.append(str(ROOT_DIR))
|
|
||||||
|
|
||||||
import diffusion_policy.policy.imf_transformer_hybrid_image_policy as policy_module # noqa: E402
|
|
||||||
from diffusion_policy.model.common.module_attr_mixin import ModuleAttrMixin # noqa: E402
|
|
||||||
from diffusion_policy.policy.imf_transformer_hybrid_image_policy import ( # noqa: E402
|
|
||||||
IMFTransformerHybridImagePolicy,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ConstantModel(nn.Module):
|
|
||||||
def __init__(self, value):
|
|
||||||
super().__init__()
|
|
||||||
self.value = value
|
|
||||||
|
|
||||||
def forward(self, sample, r, t, cond=None):
|
|
||||||
return torch.full_like(sample, self.value)
|
|
||||||
|
|
||||||
|
|
||||||
class AffineModel(nn.Module):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
self.weight = nn.Parameter(torch.tensor(2.0))
|
|
||||||
|
|
||||||
def forward(self, sample, r, t, cond=None):
|
|
||||||
return sample * self.weight + (r + t).view(-1, 1, 1)
|
|
||||||
|
|
||||||
|
|
||||||
class SumMixModel(nn.Module):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
self.weight = nn.Parameter(torch.tensor(2.0))
|
|
||||||
|
|
||||||
def forward(self, sample, r, t, cond=None):
|
|
||||||
mixed = sample.sum(dim=-1, keepdim=True).expand_as(sample)
|
|
||||||
return mixed * self.weight + t.view(-1, 1, 1)
|
|
||||||
|
|
||||||
|
|
||||||
class TrackingContext:
|
|
||||||
def __init__(self):
|
|
||||||
self.active = False
|
|
||||||
self.enter_count = 0
|
|
||||||
|
|
||||||
def __enter__(self):
|
|
||||||
self.active = True
|
|
||||||
self.enter_count += 1
|
|
||||||
return self
|
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc, tb):
|
|
||||||
self.active = False
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def make_policy(model):
|
|
||||||
policy = IMFTransformerHybridImagePolicy.__new__(IMFTransformerHybridImagePolicy)
|
|
||||||
ModuleAttrMixin.__init__(policy)
|
|
||||||
policy.model = model
|
|
||||||
return policy
|
|
||||||
|
|
||||||
|
|
||||||
def fake_parent_init(
|
|
||||||
self,
|
|
||||||
shape_meta,
|
|
||||||
noise_scheduler,
|
|
||||||
horizon,
|
|
||||||
n_action_steps,
|
|
||||||
n_obs_steps,
|
|
||||||
num_inference_steps=None,
|
|
||||||
crop_shape=(76, 76),
|
|
||||||
obs_encoder_group_norm=False,
|
|
||||||
eval_fixed_crop=False,
|
|
||||||
n_layer=8,
|
|
||||||
n_cond_layers=0,
|
|
||||||
n_head=1,
|
|
||||||
n_emb=256,
|
|
||||||
p_drop_emb=0.0,
|
|
||||||
p_drop_attn=0.3,
|
|
||||||
causal_attn=True,
|
|
||||||
time_as_cond=True,
|
|
||||||
obs_as_cond=True,
|
|
||||||
pred_action_steps_only=False,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
ModuleAttrMixin.__init__(self)
|
|
||||||
self.action_dim = shape_meta['action']['shape'][0]
|
|
||||||
self.obs_feature_dim = 4
|
|
||||||
self.obs_as_cond = obs_as_cond
|
|
||||||
self.pred_action_steps_only = pred_action_steps_only
|
|
||||||
self.n_action_steps = n_action_steps
|
|
||||||
self.n_obs_steps = n_obs_steps
|
|
||||||
self.horizon = horizon
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def shape_meta():
|
|
||||||
return {
|
|
||||||
'action': {'shape': [2]},
|
|
||||||
'obs': {},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def test_sample_one_step_uses_imf_update_formula():
|
|
||||||
policy = make_policy(ConstantModel(0.25))
|
|
||||||
z_1 = torch.tensor([
|
|
||||||
[[1.0, -1.0], [0.5, 0.0]],
|
|
||||||
[[2.0, 3.0], [-2.0, 4.0]],
|
|
||||||
])
|
|
||||||
r = torch.zeros(z_1.shape[0])
|
|
||||||
t = torch.ones(z_1.shape[0])
|
|
||||||
|
|
||||||
x_hat = policy._sample_one_step(z_1, r=r, t=t, cond=None)
|
|
||||||
|
|
||||||
expected = z_1 - (t - r).view(-1, 1, 1) * 0.25
|
|
||||||
assert torch.allclose(x_hat, expected)
|
|
||||||
|
|
||||||
|
|
||||||
def test_compound_velocity_uses_detached_du_dt_term():
|
|
||||||
policy = make_policy(ConstantModel(0.0))
|
|
||||||
u = torch.tensor([[[1.0], [2.0]]], requires_grad=True)
|
|
||||||
du_dt = torch.tensor([[[3.0], [4.0]]], requires_grad=True)
|
|
||||||
r = torch.tensor([0.2])
|
|
||||||
t = torch.tensor([0.8])
|
|
||||||
|
|
||||||
compound = policy._compound_velocity(u, du_dt, r, t)
|
|
||||||
expected = u + (t - r).view(-1, 1, 1) * du_dt.detach()
|
|
||||||
|
|
||||||
assert torch.allclose(compound, expected)
|
|
||||||
|
|
||||||
compound.sum().backward()
|
|
||||||
assert u.grad is not None
|
|
||||||
assert du_dt.grad is None
|
|
||||||
|
|
||||||
|
|
||||||
def test_compute_u_and_du_dt_uses_math_sdpa_context_for_torch_func_jvp(monkeypatch):
|
|
||||||
tracker = TrackingContext()
|
|
||||||
|
|
||||||
def fake_jvp(fn, primals, tangents):
|
|
||||||
assert tracker.active is True
|
|
||||||
return fn(*primals), torch.zeros_like(primals[0])
|
|
||||||
|
|
||||||
monkeypatch.setattr(policy_module, 'TORCH_FUNC_JVP', fake_jvp)
|
|
||||||
|
|
||||||
policy = make_policy(ConstantModel(0.5))
|
|
||||||
policy._jvp_math_sdp_context = lambda tensor: tracker
|
|
||||||
z_t = torch.randn(2, 3, 4)
|
|
||||||
r = torch.rand(2, requires_grad=True)
|
|
||||||
t = torch.rand(2, requires_grad=True)
|
|
||||||
v = torch.randn_like(z_t, requires_grad=True)
|
|
||||||
|
|
||||||
policy._compute_u_and_du_dt(z_t, r, t, cond=None, v=v)
|
|
||||||
|
|
||||||
assert tracker.enter_count == 1
|
|
||||||
|
|
||||||
|
|
||||||
def test_compute_u_and_du_dt_uses_math_sdpa_context_for_autograd_fallback(monkeypatch):
|
|
||||||
tracker = TrackingContext()
|
|
||||||
|
|
||||||
def fake_autograd_jvp(fn, primals, tangents, create_graph=False, strict=False):
|
|
||||||
assert tracker.active is True
|
|
||||||
return fn(*primals), torch.zeros_like(primals[0])
|
|
||||||
|
|
||||||
monkeypatch.setattr(policy_module, 'TORCH_FUNC_JVP', None)
|
|
||||||
monkeypatch.setattr(policy_module.torch.autograd.functional, 'jvp', fake_autograd_jvp)
|
|
||||||
|
|
||||||
policy = make_policy(ConstantModel(0.5))
|
|
||||||
policy._jvp_math_sdp_context = lambda tensor: tracker
|
|
||||||
z_t = torch.randn(2, 3, 4)
|
|
||||||
r = torch.rand(2, requires_grad=True)
|
|
||||||
t = torch.rand(2, requires_grad=True)
|
|
||||||
v = torch.randn_like(z_t, requires_grad=True)
|
|
||||||
|
|
||||||
policy._compute_u_and_du_dt(z_t, r, t, cond=None, v=v)
|
|
||||||
|
|
||||||
assert tracker.enter_count == 1
|
|
||||||
|
|
||||||
|
|
||||||
def test_compute_u_and_du_dt_uses_detached_v_zero_r_unit_t_and_reapplies_conditioning(monkeypatch):
|
|
||||||
captured = {}
|
|
||||||
|
|
||||||
def fake_jvp(fn, primals, tangents):
|
|
||||||
captured['tangents'] = tangents
|
|
||||||
captured['primal_output'] = fn(*primals)
|
|
||||||
return captured['primal_output'], torch.zeros_like(primals[0])
|
|
||||||
|
|
||||||
monkeypatch.setattr(policy_module, 'TORCH_FUNC_JVP', fake_jvp)
|
|
||||||
|
|
||||||
policy = make_policy(SumMixModel())
|
|
||||||
z_t = torch.tensor([[[1.0, 2.0, 3.0]]])
|
|
||||||
r = torch.rand(1, requires_grad=True)
|
|
||||||
t = torch.rand(1, requires_grad=True)
|
|
||||||
v = torch.tensor([[[10.0, 20.0, 30.0]]], requires_grad=True)
|
|
||||||
condition_mask = torch.tensor([[[False, True, False]]])
|
|
||||||
condition_data = torch.tensor([[[0.0, 7.0, 0.0]]])
|
|
||||||
|
|
||||||
policy._compute_u_and_du_dt(
|
|
||||||
z_t,
|
|
||||||
r,
|
|
||||||
t,
|
|
||||||
cond=None,
|
|
||||||
v=v,
|
|
||||||
condition_data=condition_data,
|
|
||||||
condition_mask=condition_mask,
|
|
||||||
)
|
|
||||||
|
|
||||||
tangent_v, tangent_r, tangent_t = captured['tangents']
|
|
||||||
assert torch.equal(tangent_v, v.detach())
|
|
||||||
assert tangent_v.requires_grad is False
|
|
||||||
assert torch.equal(tangent_r, torch.zeros_like(r))
|
|
||||||
assert torch.equal(tangent_t, torch.ones_like(t))
|
|
||||||
|
|
||||||
conditioned = z_t.clone()
|
|
||||||
conditioned[condition_mask] = condition_data[condition_mask]
|
|
||||||
expected_primal = policy.model(conditioned, r, t, cond=None)
|
|
||||||
assert torch.allclose(captured['primal_output'], expected_primal)
|
|
||||||
|
|
||||||
|
|
||||||
def test_compute_u_and_du_dt_fallback_blocks_conditioned_tangent_leakage_and_keeps_primal_gradients(monkeypatch):
|
|
||||||
monkeypatch.setattr(policy_module, 'TORCH_FUNC_JVP', None)
|
|
||||||
|
|
||||||
policy = make_policy(SumMixModel())
|
|
||||||
z_t = torch.tensor([[[1.0, 2.0, 3.0]]], requires_grad=True)
|
|
||||||
r = torch.rand(1, requires_grad=True)
|
|
||||||
t = torch.rand(1, requires_grad=True)
|
|
||||||
v = torch.tensor([[[1.0, 10.0, 100.0]]], requires_grad=True)
|
|
||||||
condition_mask = torch.tensor([[[False, True, False]]])
|
|
||||||
condition_data = torch.tensor([[[0.0, 7.0, 0.0]]])
|
|
||||||
|
|
||||||
u, du_dt = policy._compute_u_and_du_dt(
|
|
||||||
z_t,
|
|
||||||
r,
|
|
||||||
t,
|
|
||||||
cond=None,
|
|
||||||
v=v,
|
|
||||||
condition_data=condition_data,
|
|
||||||
condition_mask=condition_mask,
|
|
||||||
)
|
|
||||||
|
|
||||||
conditioned = z_t.detach().clone()
|
|
||||||
conditioned[condition_mask] = condition_data[condition_mask]
|
|
||||||
expected_u = policy.model(conditioned, r, t, cond=None)
|
|
||||||
expected_du_dt_scalar = policy.model.weight.detach() * torch.tensor(101.0) + 1.0
|
|
||||||
expected_du_dt = torch.full_like(z_t, expected_du_dt_scalar)
|
|
||||||
|
|
||||||
assert u.shape == z_t.shape
|
|
||||||
assert du_dt.shape == z_t.shape
|
|
||||||
assert torch.allclose(u, expected_u)
|
|
||||||
assert torch.allclose(du_dt, expected_du_dt)
|
|
||||||
|
|
||||||
u.sum().backward()
|
|
||||||
assert policy.model.weight.grad is not None
|
|
||||||
assert torch.count_nonzero(policy.model.weight.grad) > 0
|
|
||||||
|
|
||||||
|
|
||||||
def test_init_uses_action_step_horizon_when_pred_action_steps_only(monkeypatch, shape_meta):
|
|
||||||
monkeypatch.setattr(
|
|
||||||
policy_module.DiffusionTransformerHybridImagePolicy,
|
|
||||||
'__init__',
|
|
||||||
fake_parent_init,
|
|
||||||
)
|
|
||||||
|
|
||||||
policy = IMFTransformerHybridImagePolicy(
|
|
||||||
shape_meta=shape_meta,
|
|
||||||
noise_scheduler=None,
|
|
||||||
horizon=10,
|
|
||||||
n_action_steps=4,
|
|
||||||
n_obs_steps=2,
|
|
||||||
num_inference_steps=1,
|
|
||||||
n_layer=1,
|
|
||||||
n_head=1,
|
|
||||||
n_emb=16,
|
|
||||||
p_drop_emb=0.0,
|
|
||||||
p_drop_attn=0.0,
|
|
||||||
causal_attn=True,
|
|
||||||
obs_as_cond=True,
|
|
||||||
pred_action_steps_only=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert policy.model.horizon == 4
|
|
||||||
assert policy.num_inference_steps == 1
|
|
||||||
|
|
||||||
|
|
||||||
def test_init_rejects_non_one_step_inference(monkeypatch, shape_meta):
|
|
||||||
monkeypatch.setattr(
|
|
||||||
policy_module.DiffusionTransformerHybridImagePolicy,
|
|
||||||
'__init__',
|
|
||||||
fake_parent_init,
|
|
||||||
)
|
|
||||||
|
|
||||||
with pytest.raises(ValueError, match='num_inference_steps'):
|
|
||||||
IMFTransformerHybridImagePolicy(
|
|
||||||
shape_meta=shape_meta,
|
|
||||||
noise_scheduler=None,
|
|
||||||
horizon=10,
|
|
||||||
n_action_steps=4,
|
|
||||||
n_obs_steps=2,
|
|
||||||
num_inference_steps=2,
|
|
||||||
n_layer=1,
|
|
||||||
n_head=1,
|
|
||||||
n_emb=16,
|
|
||||||
p_drop_emb=0.0,
|
|
||||||
p_drop_attn=0.0,
|
|
||||||
causal_attn=True,
|
|
||||||
obs_as_cond=True,
|
|
||||||
pred_action_steps_only=False,
|
|
||||||
)
|
|
||||||
@@ -1,110 +0,0 @@
|
|||||||
import pathlib
|
|
||||||
import sys
|
|
||||||
|
|
||||||
import gym
|
|
||||||
from gym import spaces
|
|
||||||
import numpy as np
|
|
||||||
import pytest
|
|
||||||
import torch
|
|
||||||
|
|
||||||
ROOT_DIR = pathlib.Path(__file__).resolve().parents[1]
|
|
||||||
if str(ROOT_DIR) not in sys.path:
|
|
||||||
sys.path.append(str(ROOT_DIR))
|
|
||||||
|
|
||||||
import diffusion_policy.env_runner.pusht_image_runner as runner_module
|
|
||||||
from diffusion_policy.env_runner.pusht_image_runner import summarize_rollout_metrics
|
|
||||||
|
|
||||||
|
|
||||||
class FakePushTImageEnv(gym.Env):
|
|
||||||
metadata = {'render.modes': ['rgb_array']}
|
|
||||||
|
|
||||||
def __init__(self, legacy=False, render_size=96):
|
|
||||||
del legacy, render_size
|
|
||||||
self.observation_space = spaces.Dict({
|
|
||||||
'image': spaces.Box(low=0, high=255, shape=(3, 4, 4), dtype=np.uint8),
|
|
||||||
})
|
|
||||||
self.action_space = spaces.Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32)
|
|
||||||
self.seed_value = 0
|
|
||||||
self.step_count = 0
|
|
||||||
|
|
||||||
def seed(self, seed=None):
|
|
||||||
self.seed_value = 0 if seed is None else seed
|
|
||||||
|
|
||||||
def reset(self):
|
|
||||||
self.step_count = 0
|
|
||||||
return {'image': np.zeros((3, 4, 4), dtype=np.uint8)}
|
|
||||||
|
|
||||||
def step(self, action):
|
|
||||||
del action
|
|
||||||
self.step_count += 1
|
|
||||||
reward = 0.1 if self.seed_value < 10000 else 0.9
|
|
||||||
done = self.step_count >= 1
|
|
||||||
obs = {'image': np.full((3, 4, 4), self.step_count, dtype=np.uint8)}
|
|
||||||
return obs, reward, done, {}
|
|
||||||
|
|
||||||
def render(self, *args, **kwargs):
|
|
||||||
raise AssertionError('render should not be called for scalar-only PushT image rollouts')
|
|
||||||
|
|
||||||
|
|
||||||
class FakePolicy:
|
|
||||||
device = torch.device('cpu')
|
|
||||||
dtype = torch.float32
|
|
||||||
|
|
||||||
def reset(self):
|
|
||||||
return None
|
|
||||||
|
|
||||||
def predict_action(self, obs_dict):
|
|
||||||
n_envs = next(iter(obs_dict.values())).shape[0]
|
|
||||||
return {
|
|
||||||
'action': torch.zeros((n_envs, 2, 2), dtype=torch.float32),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def test_summarize_rollout_metrics_keeps_scalar_rewards_renames_means_and_omits_videos():
|
|
||||||
log_data = summarize_rollout_metrics(
|
|
||||||
env_seeds=[11, 12, 101],
|
|
||||||
env_prefixs=['train/', 'train/', 'test/'],
|
|
||||||
all_rewards=[
|
|
||||||
[0.2, 0.8],
|
|
||||||
[0.1, 0.4],
|
|
||||||
[0.5, 0.9],
|
|
||||||
],
|
|
||||||
all_video_paths=[
|
|
||||||
'/tmp/train-11.mp4',
|
|
||||||
'/tmp/train-12.mp4',
|
|
||||||
'/tmp/test-101.mp4',
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
assert log_data['train/sim_max_reward_11'] == 0.8
|
|
||||||
assert log_data['train/sim_max_reward_12'] == 0.4
|
|
||||||
assert log_data['test/sim_max_reward_101'] == 0.9
|
|
||||||
assert log_data['train_mean_score'] == pytest.approx(0.6)
|
|
||||||
assert log_data['test_mean_score'] == pytest.approx(0.9)
|
|
||||||
assert not any(key.startswith('train/sim_video_') for key in log_data)
|
|
||||||
assert not any(key.startswith('test/sim_video_') for key in log_data)
|
|
||||||
|
|
||||||
|
|
||||||
def test_runner_ignores_vis_flags_and_never_emits_sim_videos(tmp_path, monkeypatch):
|
|
||||||
monkeypatch.setattr(runner_module, 'PushTImageEnv', FakePushTImageEnv)
|
|
||||||
|
|
||||||
runner = runner_module.PushTImageRunner(
|
|
||||||
output_dir=tmp_path,
|
|
||||||
n_train=1,
|
|
||||||
n_train_vis=1,
|
|
||||||
n_test=1,
|
|
||||||
n_test_vis=1,
|
|
||||||
n_envs=2,
|
|
||||||
max_steps=2,
|
|
||||||
n_obs_steps=2,
|
|
||||||
n_action_steps=2,
|
|
||||||
tqdm_interval_sec=0.0,
|
|
||||||
)
|
|
||||||
|
|
||||||
log_data = runner.run(FakePolicy())
|
|
||||||
|
|
||||||
assert log_data['train/sim_max_reward_0'] == pytest.approx(0.1)
|
|
||||||
assert log_data['test/sim_max_reward_10000'] == pytest.approx(0.9)
|
|
||||||
assert log_data['train_mean_score'] == pytest.approx(0.1)
|
|
||||||
assert log_data['test_mean_score'] == pytest.approx(0.9)
|
|
||||||
assert not any('sim_video' in key for key in log_data)
|
|
||||||
@@ -1,57 +0,0 @@
|
|||||||
import pathlib
|
|
||||||
|
|
||||||
from omegaconf import OmegaConf
|
|
||||||
|
|
||||||
|
|
||||||
ROOT_DIR = pathlib.Path(__file__).resolve().parents[1]
|
|
||||||
|
|
||||||
|
|
||||||
def _load_cfg(name: str):
|
|
||||||
return OmegaConf.load(ROOT_DIR / name)
|
|
||||||
|
|
||||||
|
|
||||||
def test_image_pusht_dit_swanlab_config_uses_exp_name_and_no_resume_collision():
|
|
||||||
cfg = _load_cfg('image_pusht_diffusion_policy_dit.yaml')
|
|
||||||
|
|
||||||
assert cfg.logging.backend == 'swanlab'
|
|
||||||
assert cfg.logging.mode == 'online'
|
|
||||||
assert cfg.logging.name == cfg.exp_name
|
|
||||||
assert cfg.logging.resume is False
|
|
||||||
assert cfg.logging.id is None
|
|
||||||
assert cfg.logging.group == cfg.exp_name
|
|
||||||
|
|
||||||
|
|
||||||
def test_image_pusht_dit_imf_swanlab_config_uses_exp_name_and_no_resume_collision():
|
|
||||||
cfg = _load_cfg('image_pusht_diffusion_policy_dit_imf.yaml')
|
|
||||||
|
|
||||||
assert cfg.logging.backend == 'swanlab'
|
|
||||||
assert cfg.logging.mode == 'online'
|
|
||||||
assert cfg.logging.name == cfg.exp_name
|
|
||||||
assert cfg.logging.resume is False
|
|
||||||
assert cfg.logging.id is None
|
|
||||||
assert cfg.logging.group == cfg.exp_name
|
|
||||||
|
|
||||||
|
|
||||||
def test_image_pusht_dit_imf_fullattn_config_uses_exp_name_and_disables_causal_attention():
|
|
||||||
cfg = _load_cfg('image_pusht_diffusion_policy_dit_imf_fullattn.yaml')
|
|
||||||
|
|
||||||
assert cfg.logging.backend == 'swanlab'
|
|
||||||
assert cfg.logging.mode == 'online'
|
|
||||||
assert cfg.logging.name == cfg.exp_name
|
|
||||||
assert cfg.logging.resume is False
|
|
||||||
assert cfg.logging.id is None
|
|
||||||
assert cfg.logging.group == cfg.exp_name
|
|
||||||
assert cfg.policy.causal_attn is False
|
|
||||||
|
|
||||||
|
|
||||||
def test_image_pusht_dit_imf_attnres_full_config_uses_exp_name_and_disables_causal_attention():
|
|
||||||
cfg = _load_cfg('image_pusht_diffusion_policy_dit_imf_attnres_full.yaml')
|
|
||||||
|
|
||||||
assert cfg.logging.backend == 'swanlab'
|
|
||||||
assert cfg.logging.mode == 'online'
|
|
||||||
assert cfg.logging.name == cfg.exp_name
|
|
||||||
assert cfg.logging.resume is False
|
|
||||||
assert cfg.logging.id is None
|
|
||||||
assert cfg.logging.group == cfg.exp_name
|
|
||||||
assert cfg.policy.causal_attn is False
|
|
||||||
assert cfg.policy.backbone_type == 'attnres_full'
|
|
||||||
@@ -1,198 +0,0 @@
|
|||||||
import importlib
|
|
||||||
import pathlib
|
|
||||||
import sys
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from omegaconf import OmegaConf
|
|
||||||
|
|
||||||
ROOT_DIR = pathlib.Path(__file__).resolve().parents[1]
|
|
||||||
if str(ROOT_DIR) not in sys.path:
|
|
||||||
sys.path.append(str(ROOT_DIR))
|
|
||||||
|
|
||||||
MODULE_NAME = 'diffusion_policy.workspace.train_diffusion_transformer_hybrid_workspace'
|
|
||||||
|
|
||||||
|
|
||||||
def load_workspace_module(monkeypatch, *, wandb_missing=False):
|
|
||||||
sys.modules.pop(MODULE_NAME, None)
|
|
||||||
if wandb_missing:
|
|
||||||
monkeypatch.setitem(sys.modules, 'wandb', None)
|
|
||||||
return importlib.import_module(MODULE_NAME)
|
|
||||||
|
|
||||||
|
|
||||||
def test_init_logger_uses_swanlab_backend_mapping_without_loading_wandb(tmp_path, monkeypatch):
|
|
||||||
workspace_module = load_workspace_module(monkeypatch, wandb_missing=True)
|
|
||||||
events = []
|
|
||||||
|
|
||||||
class FakeRun:
|
|
||||||
def log(self, payload, step=None):
|
|
||||||
events.append(('log', payload, step))
|
|
||||||
|
|
||||||
def finish(self):
|
|
||||||
events.append(('finish',))
|
|
||||||
|
|
||||||
class FakeSwanLab:
|
|
||||||
def init(self, **kwargs):
|
|
||||||
events.append(('init', kwargs))
|
|
||||||
return FakeRun()
|
|
||||||
|
|
||||||
monkeypatch.setattr(workspace_module, '_load_swanlab', lambda: FakeSwanLab())
|
|
||||||
monkeypatch.setattr(
|
|
||||||
workspace_module,
|
|
||||||
'_load_wandb',
|
|
||||||
lambda: pytest.fail('wandb should not be loaded for the SwanLab backend'),
|
|
||||||
)
|
|
||||||
|
|
||||||
cfg = OmegaConf.create({
|
|
||||||
'logging': {
|
|
||||||
'backend': 'swanlab',
|
|
||||||
'project': 'demo-project',
|
|
||||||
'name': 'demo-run',
|
|
||||||
'group': 'demo-group',
|
|
||||||
'tags': ['pusht', 'dit'],
|
|
||||||
'id': 'run-123',
|
|
||||||
'resume': True,
|
|
||||||
'mode': 'online',
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
logger = workspace_module.init_logging_backend(cfg=cfg, output_dir=tmp_path)
|
|
||||||
logger.log({'metric': 1.0}, step=7)
|
|
||||||
logger.finish()
|
|
||||||
|
|
||||||
assert events[0][0] == 'init'
|
|
||||||
init_kwargs = events[0][1]
|
|
||||||
assert init_kwargs['project'] == 'demo-project'
|
|
||||||
assert init_kwargs['experiment_name'] == 'demo-run'
|
|
||||||
assert init_kwargs['group'] == 'demo-group'
|
|
||||||
assert init_kwargs['tags'] == ['pusht', 'dit']
|
|
||||||
assert init_kwargs['id'] == 'run-123'
|
|
||||||
assert init_kwargs['resume'] is True
|
|
||||||
assert init_kwargs['mode'] == 'cloud'
|
|
||||||
assert init_kwargs['logdir'] == str(tmp_path / 'swanlog')
|
|
||||||
assert ('log', {'metric': 1.0}, 7) in events
|
|
||||||
assert events.count(('finish',)) == 1
|
|
||||||
|
|
||||||
|
|
||||||
def test_init_logger_defaults_to_legacy_wandb_path_when_backend_missing(tmp_path, monkeypatch):
|
|
||||||
workspace_module = load_workspace_module(monkeypatch)
|
|
||||||
events = []
|
|
||||||
|
|
||||||
class FakeRun:
|
|
||||||
def log(self, payload, step=None):
|
|
||||||
events.append(('log', payload, step))
|
|
||||||
|
|
||||||
def finish(self):
|
|
||||||
events.append(('finish',))
|
|
||||||
|
|
||||||
class FakeConfig:
|
|
||||||
def update(self, payload):
|
|
||||||
events.append(('config.update', payload))
|
|
||||||
|
|
||||||
class FakeWandb:
|
|
||||||
def __init__(self):
|
|
||||||
self.config = FakeConfig()
|
|
||||||
|
|
||||||
def init(self, **kwargs):
|
|
||||||
events.append(('init', kwargs))
|
|
||||||
return FakeRun()
|
|
||||||
|
|
||||||
monkeypatch.setattr(workspace_module, '_load_wandb', lambda: FakeWandb())
|
|
||||||
|
|
||||||
cfg = OmegaConf.create({
|
|
||||||
'logging': {
|
|
||||||
'project': 'demo-project',
|
|
||||||
'name': 'demo-run',
|
|
||||||
'group': None,
|
|
||||||
'tags': ['shared'],
|
|
||||||
'id': None,
|
|
||||||
'resume': True,
|
|
||||||
'mode': 'online',
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
logger = workspace_module.init_logging_backend(cfg=cfg, output_dir=tmp_path)
|
|
||||||
logger.log({'metric': 2.0}, step=3)
|
|
||||||
logger.finish()
|
|
||||||
|
|
||||||
assert events[0][0] == 'init'
|
|
||||||
init_kwargs = events[0][1]
|
|
||||||
assert init_kwargs['dir'] == str(tmp_path)
|
|
||||||
assert init_kwargs['project'] == 'demo-project'
|
|
||||||
assert init_kwargs['name'] == 'demo-run'
|
|
||||||
assert init_kwargs['mode'] == 'online'
|
|
||||||
assert ('config.update', {'output_dir': str(tmp_path)}) in events
|
|
||||||
assert ('log', {'metric': 2.0}, 3) in events
|
|
||||||
assert events.count(('finish',)) == 1
|
|
||||||
|
|
||||||
|
|
||||||
def test_init_logger_rejects_unknown_backends(tmp_path, monkeypatch):
|
|
||||||
workspace_module = load_workspace_module(monkeypatch)
|
|
||||||
cfg = OmegaConf.create({
|
|
||||||
'logging': {
|
|
||||||
'backend': 'tensorboard',
|
|
||||||
'project': 'demo-project',
|
|
||||||
'name': 'demo-run',
|
|
||||||
'mode': 'offline',
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
with pytest.raises(ValueError, match='Unknown logging backend'):
|
|
||||||
workspace_module.init_logging_backend(cfg=cfg, output_dir=tmp_path)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def test_logging_backend_session_preserves_primary_exception_when_finish_fails(tmp_path, monkeypatch):
|
|
||||||
workspace_module = load_workspace_module(monkeypatch)
|
|
||||||
events = []
|
|
||||||
|
|
||||||
class FakeBackend:
|
|
||||||
def log(self, payload, step=None):
|
|
||||||
events.append(('log', payload, step))
|
|
||||||
|
|
||||||
def finish(self):
|
|
||||||
events.append(('finish',))
|
|
||||||
raise RuntimeError('finish boom')
|
|
||||||
|
|
||||||
monkeypatch.setattr(
|
|
||||||
workspace_module,
|
|
||||||
'init_logging_backend',
|
|
||||||
lambda cfg, output_dir: FakeBackend(),
|
|
||||||
)
|
|
||||||
|
|
||||||
cfg = OmegaConf.create({'logging': {'mode': 'offline'}})
|
|
||||||
|
|
||||||
with pytest.raises(ValueError, match='primary boom'):
|
|
||||||
with workspace_module.logging_backend_session(cfg=cfg, output_dir=tmp_path) as logger:
|
|
||||||
logger.log({'metric': 6.0}, step=12)
|
|
||||||
raise ValueError('primary boom')
|
|
||||||
|
|
||||||
assert ('log', {'metric': 6.0}, 12) in events
|
|
||||||
assert events.count(('finish',)) == 1
|
|
||||||
|
|
||||||
def test_logging_backend_session_finishes_on_exception(tmp_path, monkeypatch):
|
|
||||||
workspace_module = load_workspace_module(monkeypatch)
|
|
||||||
events = []
|
|
||||||
|
|
||||||
class FakeBackend:
|
|
||||||
def log(self, payload, step=None):
|
|
||||||
events.append(('log', payload, step))
|
|
||||||
|
|
||||||
def finish(self):
|
|
||||||
events.append(('finish',))
|
|
||||||
|
|
||||||
monkeypatch.setattr(
|
|
||||||
workspace_module,
|
|
||||||
'init_logging_backend',
|
|
||||||
lambda cfg, output_dir: FakeBackend(),
|
|
||||||
)
|
|
||||||
|
|
||||||
cfg = OmegaConf.create({'logging': {'mode': 'offline'}})
|
|
||||||
|
|
||||||
with pytest.raises(RuntimeError, match='boom'):
|
|
||||||
with workspace_module.logging_backend_session(cfg=cfg, output_dir=tmp_path) as logger:
|
|
||||||
logger.log({'metric': 5.0}, step=11)
|
|
||||||
raise RuntimeError('boom')
|
|
||||||
|
|
||||||
assert ('log', {'metric': 5.0}, 11) in events
|
|
||||||
assert events.count(('finish',)) == 1
|
|
||||||
Reference in New Issue
Block a user