feat: switch pusht transformer logging to swanlab

This commit is contained in:
Logic
2026-03-26 19:49:45 +08:00
parent 23374a4cd2
commit 5e7ae6cfa5
6 changed files with 601 additions and 216 deletions

View File

@@ -1,21 +1,37 @@
import wandb
import numpy as np import numpy as np
import torch import torch
import collections import collections
import pathlib
import tqdm import tqdm
import dill import dill
import math import math
import wandb.sdk.data_types.video as wv
from diffusion_policy.env.pusht.pusht_image_env import PushTImageEnv from diffusion_policy.env.pusht.pusht_image_env import PushTImageEnv
from diffusion_policy.gym_util.sync_vector_env import SyncVectorEnv from diffusion_policy.gym_util.sync_vector_env import SyncVectorEnv
from diffusion_policy.gym_util.multistep_wrapper import MultiStepWrapper from diffusion_policy.gym_util.multistep_wrapper import MultiStepWrapper
from diffusion_policy.gym_util.video_recording_wrapper import VideoRecordingWrapper, VideoRecorder
from diffusion_policy.policy.base_image_policy import BaseImagePolicy from diffusion_policy.policy.base_image_policy import BaseImagePolicy
from diffusion_policy.common.pytorch_util import dict_apply from diffusion_policy.common.pytorch_util import dict_apply
from diffusion_policy.env_runner.base_image_runner import BaseImageRunner from diffusion_policy.env_runner.base_image_runner import BaseImageRunner
def summarize_rollout_metrics(env_seeds, env_prefixs, all_rewards, all_video_paths=None):
    """Build the scalar log dictionary for a batch of rollouts.

    Args:
        env_seeds: One seed per rollout; used in the per-rollout key name.
        env_prefixs: One key prefix per rollout (e.g. ``'train/'``).
        all_rewards: One reward sequence per rollout.
        all_video_paths: Accepted for call-site compatibility and discarded;
            no video entries are ever written to the result.

    Returns:
        dict mapping ``<prefix>sim_max_reward_<seed>`` to the maximum reward
        of that rollout, plus one mean-of-maxima entry per prefix. The
        ``'train/'`` and ``'test/'`` prefixes use the flat keys
        ``train_mean_score`` / ``test_mean_score``; any other prefix uses
        ``<prefix>mean_score``.
    """
    # Videos are intentionally not part of the summary.
    del all_video_paths

    summary = {}
    best_by_prefix = collections.defaultdict(list)
    for seed, prefix, rewards in zip(env_seeds, env_prefixs, all_rewards):
        best = np.max(rewards)
        summary[prefix + f'sim_max_reward_{seed}'] = best
        best_by_prefix[prefix].append(best)

    for prefix, bests in best_by_prefix.items():
        if prefix == 'train/':
            mean_key = 'train_mean_score'
        elif prefix == 'test/':
            mean_key = 'test_mean_score'
        else:
            mean_key = prefix + 'mean_score'
        summary[mean_key] = np.mean(bests)
    return summary
class PushTImageRunner(BaseImageRunner): class PushTImageRunner(BaseImageRunner):
def __init__(self, def __init__(self,
output_dir, output_dir,
@@ -40,24 +56,11 @@ class PushTImageRunner(BaseImageRunner):
if n_envs is None: if n_envs is None:
n_envs = n_train + n_test n_envs = n_train + n_test
steps_per_render = max(10 // fps, 1)
def env_fn(): def env_fn():
return MultiStepWrapper( return MultiStepWrapper(
VideoRecordingWrapper( PushTImageEnv(
PushTImageEnv( legacy=legacy_test,
legacy=legacy_test, render_size=render_size
render_size=render_size
),
video_recoder=VideoRecorder.create_h264(
fps=fps,
codec='h264',
input_pix_fmt='rgb24',
crf=crf,
thread_type='FRAME',
thread_count=1
),
file_path=None,
steps_per_render=steps_per_render
), ),
n_obs_steps=n_obs_steps, n_obs_steps=n_obs_steps,
n_action_steps=n_action_steps, n_action_steps=n_action_steps,
@@ -71,21 +74,8 @@ class PushTImageRunner(BaseImageRunner):
# train # train
for i in range(n_train): for i in range(n_train):
seed = train_start_seed + i seed = train_start_seed + i
enable_render = i < n_train_vis
def init_fn(env, seed=seed, enable_render=enable_render):
# setup rendering
# video_wrapper
assert isinstance(env.env, VideoRecordingWrapper)
env.env.video_recoder.stop()
env.env.file_path = None
if enable_render:
filename = pathlib.Path(output_dir).joinpath(
'media', wv.util.generate_id() + ".mp4")
filename.parent.mkdir(parents=False, exist_ok=True)
filename = str(filename)
env.env.file_path = filename
def init_fn(env, seed=seed):
# set seed # set seed
assert isinstance(env, MultiStepWrapper) assert isinstance(env, MultiStepWrapper)
env.seed(seed) env.seed(seed)
@@ -97,21 +87,8 @@ class PushTImageRunner(BaseImageRunner):
# test # test
for i in range(n_test): for i in range(n_test):
seed = test_start_seed + i seed = test_start_seed + i
enable_render = i < n_test_vis
def init_fn(env, seed=seed, enable_render=enable_render):
# setup rendering
# video_wrapper
assert isinstance(env.env, VideoRecordingWrapper)
env.env.video_recoder.stop()
env.env.file_path = None
if enable_render:
filename = pathlib.Path(output_dir).joinpath(
'media', wv.util.generate_id() + ".mp4")
filename.parent.mkdir(parents=False, exist_ok=True)
filename = str(filename)
env.env.file_path = filename
def init_fn(env, seed=seed):
# set seed # set seed
assert isinstance(env, MultiStepWrapper) assert isinstance(env, MultiStepWrapper)
env.seed(seed) env.seed(seed)
@@ -154,7 +131,6 @@ class PushTImageRunner(BaseImageRunner):
n_chunks = math.ceil(n_inits / n_envs) n_chunks = math.ceil(n_inits / n_envs)
# allocate data # allocate data
all_video_paths = [None] * n_inits
all_rewards = [None] * n_inits all_rewards = [None] * n_inits
for chunk_idx in range(n_chunks): for chunk_idx in range(n_chunks):
@@ -214,39 +190,16 @@ class PushTImageRunner(BaseImageRunner):
pbar.update(action.shape[1]) pbar.update(action.shape[1])
pbar.close() pbar.close()
all_video_paths[this_global_slice] = env.render()[this_local_slice]
all_rewards[this_global_slice] = env.call('get_attr', 'reward')[this_local_slice] all_rewards[this_global_slice] = env.call('get_attr', 'reward')[this_local_slice]
# clear out video buffer # reset env state between evaluation calls
_ = env.reset() _ = env.reset()
# log # results reported in the paper are generated using the commented out
max_rewards = collections.defaultdict(list) # line below, which would only report and average metrics from the
log_data = dict() # first n_envs initial conditions and seeds. We keep the full n_inits
# results reported in the paper are generated using the commented out line below # behavior here.
# which will only report and average metrics from first n_envs initial condition and seeds return summarize_rollout_metrics(
# fortunately this won't invalidate our conclusion since env_seeds=self.env_seeds[:n_inits],
# 1. This bug only affects the variance of metrics, not their mean env_prefixs=self.env_prefixs[:n_inits],
# 2. All baseline methods are evaluated using the same code all_rewards=all_rewards[:n_inits],
# to completely reproduce reported numbers, uncomment this line: )
# for i in range(len(self.env_fns)):
# and comment out this line
for i in range(n_inits):
seed = self.env_seeds[i]
prefix = self.env_prefixs[i]
max_reward = np.max(all_rewards[i])
max_rewards[prefix].append(max_reward)
log_data[prefix+f'sim_max_reward_{seed}'] = max_reward
# visualize sim
video_path = all_video_paths[i]
if video_path is not None:
sim_video = wandb.Video(video_path)
log_data[prefix+f'sim_video_{seed}'] = sim_video
# log aggregate metrics
for prefix, value in max_rewards.items():
name = prefix+'mean_score'
value = np.mean(value)
log_data[name] = value
return log_data

View File

@@ -8,6 +8,8 @@ if __name__ == "__main__":
os.chdir(ROOT_DIR) os.chdir(ROOT_DIR)
import os import os
import contextlib
import importlib
import hydra import hydra
import torch import torch
from omegaconf import OmegaConf from omegaconf import OmegaConf
@@ -15,7 +17,6 @@ import pathlib
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
import copy import copy
import random import random
import wandb
import tqdm import tqdm
import numpy as np import numpy as np
import shutil import shutil
@@ -31,6 +32,111 @@ from diffusion_policy.model.common.lr_scheduler import get_scheduler
OmegaConf.register_new_resolver("eval", eval, replace=True) OmegaConf.register_new_resolver("eval", eval, replace=True)
class _LoggingBackend:
def log(self, payload, step=None):
raise NotImplementedError
def finish(self):
raise NotImplementedError
class _WandbLoggingBackend(_LoggingBackend):
    """Adapter that forwards logging calls to a wandb run object."""

    def __init__(self, run):
        # ``run`` is the object returned by ``wandb.init``.
        self.run = run

    def log(self, payload, step=None):
        """Forward ``payload`` to the wrapped run, tagged with ``step``."""
        wrapped_run = self.run
        wrapped_run.log(payload, step=step)

    def finish(self):
        """Finish the wrapped run."""
        wrapped_run = self.run
        wrapped_run.finish()
class _SwanLabLoggingBackend(_LoggingBackend):
    """Adapter that forwards logging calls to a SwanLab run object."""

    def __init__(self, run):
        # ``run`` is the object returned by ``swanlab.init``.
        self.run = run

    def log(self, payload, step=None):
        """Forward ``payload`` to the wrapped run, tagged with ``step``."""
        wrapped_run = self.run
        wrapped_run.log(payload, step=step)

    def finish(self):
        """Finish the wrapped run."""
        wrapped_run = self.run
        wrapped_run.finish()
def _load_wandb():
try:
return importlib.import_module('wandb')
except ImportError as exc:
raise ImportError(
"wandb is required when cfg.logging.backend == 'wandb' or missing"
) from exc
def _load_swanlab():
try:
return importlib.import_module('swanlab')
except ImportError:
return None
def init_logging_backend(cfg: OmegaConf, output_dir):
    """Create the experiment logger selected by ``cfg.logging.backend``.

    Args:
        cfg: Run configuration. ``cfg.logging`` holds the logger settings;
            ``cfg.logging.backend`` may be ``'swanlab'``, ``'wandb'``,
            ``None`` or absent (the last two select wandb).
        output_dir: Run output directory. wandb writes into it directly;
            SwanLab writes into its ``swanlog`` subdirectory.

    Returns:
        A ``_SwanLabLoggingBackend`` or ``_WandbLoggingBackend`` wrapping the
        freshly initialised run.

    Raises:
        ValueError: if the backend name is not recognised.
        ImportError: if the selected backend's package cannot be imported.
    """
    backend = OmegaConf.select(cfg, 'logging.backend', default='wandb')

    if backend == 'swanlab':
        swanlab = _load_swanlab()
        if swanlab is None:
            raise ImportError("swanlab is required when cfg.logging.backend == 'swanlab'")

        # Hand plain dict/list/scalar values to the third-party API instead of
        # OmegaConf nodes (``tags`` would otherwise be a ListConfig), the same
        # way the wandb branch below does. Keys absent from cfg.logging are
        # forwarded as None.
        logging_kwargs = OmegaConf.to_container(cfg.logging, resolve=True)

        # The shared config uses wandb's mode name; SwanLab calls it 'cloud'.
        mode = logging_kwargs.get('mode')
        if mode == 'online':
            mode = 'cloud'

        run = swanlab.init(
            project=logging_kwargs.get('project'),
            experiment_name=logging_kwargs.get('name'),
            group=logging_kwargs.get('group'),
            tags=logging_kwargs.get('tags'),
            id=logging_kwargs.get('id'),
            resume=logging_kwargs.get('resume'),
            mode=mode,
            logdir=str(pathlib.Path(output_dir) / 'swanlog'),
            config=OmegaConf.to_container(cfg, resolve=True),
        )
        return _SwanLabLoggingBackend(run)

    if backend not in (None, 'wandb'):
        raise ValueError(f"Unknown logging backend: {backend}")

    wandb = _load_wandb()
    logging_kwargs = OmegaConf.to_container(cfg.logging, resolve=True)
    # 'backend' is our own selector, not a wandb.init argument.
    logging_kwargs.pop('backend', None)
    run = wandb.init(
        dir=str(output_dir),
        config=OmegaConf.to_container(cfg, resolve=True),
        **logging_kwargs
    )
    wandb.config.update(
        {
            "output_dir": str(output_dir),
        }
    )
    return _WandbLoggingBackend(run)
@contextlib.contextmanager
def logging_backend_session(cfg: OmegaConf, output_dir):
    """Context manager that owns a logging backend for the ``with`` body.

    The backend is created with ``init_logging_backend`` on entry and
    ``finish()`` is always called on exit.

    Error precedence on exit:
      * body succeeded, ``finish()`` failed -> the ``finish()`` error is raised;
      * body failed, ``finish()`` succeeded -> the body error is raised;
      * both failed -> the body error is raised and the ``finish()`` failure
        is written to the module logger instead of being dropped silently.

    Args:
        cfg: Run configuration forwarded to ``init_logging_backend``.
        output_dir: Output directory forwarded to ``init_logging_backend``.

    Yields:
        The initialised logging backend.
    """
    # Function-scope import so this block does not depend on a module-level one.
    import logging as std_logging

    logging_backend = init_logging_backend(cfg=cfg, output_dir=output_dir)
    primary_error = None
    try:
        yield logging_backend
    except BaseException as exc:
        # Remember the body's error so cleanup failures cannot mask it.
        primary_error = exc
        raise
    finally:
        try:
            logging_backend.finish()
        except BaseException:
            if primary_error is None:
                raise
            # Keep the original error as the one that propagates, but leave a
            # trace of the secondary failure for debugging.
            std_logging.getLogger(__name__).exception(
                "logging backend finish() failed while handling %r; "
                "re-raising the original error",
                primary_error,
            )
class TrainDiffusionTransformerHybridWorkspace(BaseWorkspace): class TrainDiffusionTransformerHybridWorkspace(BaseWorkspace):
include_keys = ['global_step', 'epoch'] include_keys = ['global_step', 'epoch']
@@ -109,18 +215,6 @@ class TrainDiffusionTransformerHybridWorkspace(BaseWorkspace):
output_dir=self.output_dir) output_dir=self.output_dir)
assert isinstance(env_runner, BaseImageRunner) assert isinstance(env_runner, BaseImageRunner)
# configure logging
wandb_run = wandb.init(
dir=str(self.output_dir),
config=OmegaConf.to_container(cfg, resolve=True),
**cfg.logging
)
wandb.config.update(
{
"output_dir": self.output_dir,
}
)
# configure checkpoint # configure checkpoint
topk_manager = TopKCheckpointManager( topk_manager = TopKCheckpointManager(
save_dir=os.path.join(self.output_dir, 'checkpoints'), save_dir=os.path.join(self.output_dir, 'checkpoints'),
@@ -148,140 +242,141 @@ class TrainDiffusionTransformerHybridWorkspace(BaseWorkspace):
# training loop # training loop
log_path = os.path.join(self.output_dir, 'logs.json.txt') log_path = os.path.join(self.output_dir, 'logs.json.txt')
with JsonLogger(log_path) as json_logger: with logging_backend_session(cfg=cfg, output_dir=self.output_dir) as logging_backend:
for local_epoch_idx in range(cfg.training.num_epochs): with JsonLogger(log_path) as json_logger:
step_log = dict() for local_epoch_idx in range(cfg.training.num_epochs):
# ========= train for this epoch ========== step_log = dict()
train_losses = list() # ========= train for this epoch ==========
with tqdm.tqdm(train_dataloader, desc=f"Training epoch {self.epoch}", train_losses = list()
leave=False, mininterval=cfg.training.tqdm_interval_sec) as tepoch: with tqdm.tqdm(train_dataloader, desc=f"Training epoch {self.epoch}",
for batch_idx, batch in enumerate(tepoch): leave=False, mininterval=cfg.training.tqdm_interval_sec) as tepoch:
# device transfer for batch_idx, batch in enumerate(tepoch):
batch = dict_apply(batch, lambda x: x.to(device, non_blocking=True)) # device transfer
if train_sampling_batch is None: batch = dict_apply(batch, lambda x: x.to(device, non_blocking=True))
train_sampling_batch = batch if train_sampling_batch is None:
train_sampling_batch = batch
# compute loss # compute loss
raw_loss = self.model.compute_loss(batch) raw_loss = self.model.compute_loss(batch)
loss = raw_loss / cfg.training.gradient_accumulate_every loss = raw_loss / cfg.training.gradient_accumulate_every
loss.backward() loss.backward()
# step optimizer # step optimizer
if self.global_step % cfg.training.gradient_accumulate_every == 0: if self.global_step % cfg.training.gradient_accumulate_every == 0:
self.optimizer.step() self.optimizer.step()
self.optimizer.zero_grad() self.optimizer.zero_grad()
lr_scheduler.step() lr_scheduler.step()
# update ema # update ema
if cfg.training.use_ema: if cfg.training.use_ema:
ema.step(self.model) ema.step(self.model)
# logging # logging
raw_loss_cpu = raw_loss.item() raw_loss_cpu = raw_loss.item()
tepoch.set_postfix(loss=raw_loss_cpu, refresh=False) tepoch.set_postfix(loss=raw_loss_cpu, refresh=False)
train_losses.append(raw_loss_cpu) train_losses.append(raw_loss_cpu)
step_log = { step_log = {
'train_loss': raw_loss_cpu, 'train_loss': raw_loss_cpu,
'global_step': self.global_step, 'global_step': self.global_step,
'epoch': self.epoch, 'epoch': self.epoch,
'lr': lr_scheduler.get_last_lr()[0] 'lr': lr_scheduler.get_last_lr()[0]
} }
is_last_batch = (batch_idx == (len(train_dataloader)-1)) is_last_batch = (batch_idx == (len(train_dataloader)-1))
if not is_last_batch: if not is_last_batch:
# log of last step is combined with validation and rollout # log of last step is combined with validation and rollout
wandb_run.log(step_log, step=self.global_step) logging_backend.log(step_log, step=self.global_step)
json_logger.log(step_log) json_logger.log(step_log)
self.global_step += 1 self.global_step += 1
if (cfg.training.max_train_steps is not None) \ if (cfg.training.max_train_steps is not None) \
and batch_idx >= (cfg.training.max_train_steps-1): and batch_idx >= (cfg.training.max_train_steps-1):
break break
# at the end of each epoch # at the end of each epoch
# replace train_loss with epoch average # replace train_loss with epoch average
train_loss = np.mean(train_losses) train_loss = np.mean(train_losses)
step_log['train_loss'] = train_loss step_log['train_loss'] = train_loss
# ========= eval for this epoch ========== # ========= eval for this epoch ==========
policy = self.model policy = self.model
if cfg.training.use_ema: if cfg.training.use_ema:
policy = self.ema_model policy = self.ema_model
policy.eval() policy.eval()
# run rollout # run rollout
if (self.epoch % cfg.training.rollout_every) == 0: if (self.epoch % cfg.training.rollout_every) == 0:
runner_log = env_runner.run(policy) runner_log = env_runner.run(policy)
# log all # log all
step_log.update(runner_log) step_log.update(runner_log)
# run validation # run validation
if (self.epoch % cfg.training.val_every) == 0: if (self.epoch % cfg.training.val_every) == 0:
with torch.no_grad(): with torch.no_grad():
val_losses = list() val_losses = list()
with tqdm.tqdm(val_dataloader, desc=f"Validation epoch {self.epoch}", with tqdm.tqdm(val_dataloader, desc=f"Validation epoch {self.epoch}",
leave=False, mininterval=cfg.training.tqdm_interval_sec) as tepoch: leave=False, mininterval=cfg.training.tqdm_interval_sec) as tepoch:
for batch_idx, batch in enumerate(tepoch): for batch_idx, batch in enumerate(tepoch):
batch = dict_apply(batch, lambda x: x.to(device, non_blocking=True)) batch = dict_apply(batch, lambda x: x.to(device, non_blocking=True))
loss = self.model.compute_loss(batch) loss = self.model.compute_loss(batch)
val_losses.append(loss) val_losses.append(loss)
if (cfg.training.max_val_steps is not None) \ if (cfg.training.max_val_steps is not None) \
and batch_idx >= (cfg.training.max_val_steps-1): and batch_idx >= (cfg.training.max_val_steps-1):
break break
if len(val_losses) > 0: if len(val_losses) > 0:
val_loss = torch.mean(torch.tensor(val_losses)).item() val_loss = torch.mean(torch.tensor(val_losses)).item()
# log epoch average validation loss # log epoch average validation loss
step_log['val_loss'] = val_loss step_log['val_loss'] = val_loss
# run diffusion sampling on a training batch # run diffusion sampling on a training batch
if (self.epoch % cfg.training.sample_every) == 0: if (self.epoch % cfg.training.sample_every) == 0:
with torch.no_grad(): with torch.no_grad():
# sample trajectory from training set, and evaluate difference # sample trajectory from training set, and evaluate difference
batch = dict_apply(train_sampling_batch, lambda x: x.to(device, non_blocking=True)) batch = dict_apply(train_sampling_batch, lambda x: x.to(device, non_blocking=True))
obs_dict = batch['obs'] obs_dict = batch['obs']
gt_action = batch['action'] gt_action = batch['action']
result = policy.predict_action(obs_dict) result = policy.predict_action(obs_dict)
pred_action = result['action_pred'] pred_action = result['action_pred']
mse = torch.nn.functional.mse_loss(pred_action, gt_action) mse = torch.nn.functional.mse_loss(pred_action, gt_action)
step_log['train_action_mse_error'] = mse.item() step_log['train_action_mse_error'] = mse.item()
del batch del batch
del obs_dict del obs_dict
del gt_action del gt_action
del result del result
del pred_action del pred_action
del mse del mse
# checkpoint # checkpoint
if (self.epoch % cfg.training.checkpoint_every) == 0: if (self.epoch % cfg.training.checkpoint_every) == 0:
# checkpointing # checkpointing
if cfg.checkpoint.save_last_ckpt: if cfg.checkpoint.save_last_ckpt:
self.save_checkpoint() self.save_checkpoint()
if cfg.checkpoint.save_last_snapshot: if cfg.checkpoint.save_last_snapshot:
self.save_snapshot() self.save_snapshot()
# sanitize metric names # sanitize metric names
metric_dict = dict() metric_dict = dict()
for key, value in step_log.items(): for key, value in step_log.items():
new_key = key.replace('/', '_') new_key = key.replace('/', '_')
metric_dict[new_key] = value metric_dict[new_key] = value
# We can't copy the last checkpoint here # We can't copy the last checkpoint here
# since save_checkpoint uses threads. # since save_checkpoint uses threads.
# therefore at this point the file might have been empty! # therefore at this point the file might have been empty!
topk_ckpt_path = topk_manager.get_ckpt_path(metric_dict) topk_ckpt_path = topk_manager.get_ckpt_path(metric_dict)
if topk_ckpt_path is not None: if topk_ckpt_path is not None:
self.save_checkpoint(path=topk_ckpt_path) self.save_checkpoint(path=topk_ckpt_path)
# ========= eval end for this epoch ========== # ========= eval end for this epoch ==========
policy.train() policy.train()
# end of epoch # end of epoch
# log of last step is combined with validation and rollout # log of last step is combined with validation and rollout
wandb_run.log(step_log, step=self.global_step) logging_backend.log(step_log, step=self.global_step)
json_logger.log(step_log) json_logger.log(step_log)
self.global_step += 1 self.global_step += 1
self.epoch += 1 self.epoch += 1
@hydra.main( @hydra.main(
version_base=None, version_base=None,

View File

@@ -0,0 +1,28 @@
defaults:
- diffusion_policy/config/train_diffusion_transformer_hybrid_workspace@_here_
- override /diffusion_policy/config/task@task: pusht_image
- _self_
exp_name: pusht_image_dit
policy:
_target_: diffusion_policy.policy.diffusion_transformer_hybrid_image_policy.DiffusionTransformerHybridImagePolicy
logging:
backend: swanlab
mode: online
tags: ["${name}", "${task_name}", "${exp_name}", "swanlab"]
id: ${now:%Y%m%d%H%M%S}_${name}_${task_name}
group: ${exp_name}
dataloader:
num_workers: 0
val_dataloader:
num_workers: 0
task:
env_runner:
n_envs: 1
n_test_vis: 0
n_train_vis: 0

View File

@@ -36,3 +36,4 @@ av==14.0.1
pygame==2.5.2 pygame==2.5.2
robomimic==0.2.0 robomimic==0.2.0
opencv-python-headless==4.10.0.84 opencv-python-headless==4.10.0.84
swanlab

View File

@@ -0,0 +1,110 @@
import pathlib
import sys
import gym
from gym import spaces
import numpy as np
import pytest
import torch
ROOT_DIR = pathlib.Path(__file__).resolve().parents[1]
if str(ROOT_DIR) not in sys.path:
sys.path.append(str(ROOT_DIR))
import diffusion_policy.env_runner.pusht_image_runner as runner_module
from diffusion_policy.env_runner.pusht_image_runner import summarize_rollout_metrics
class FakePushTImageEnv(gym.Env):
    """Minimal stand-in for ``PushTImageEnv`` used by the runner tests.

    Observations are a single ``'image'`` entry of shape (3, 4, 4). Every
    step ends the episode, and the reward is 0.1 for seeds below 10000 and
    0.9 otherwise. ``render`` fails the test if it is ever called.
    """

    metadata = {'render.modes': ['rgb_array']}

    def __init__(self, legacy=False, render_size=96):
        # Accepted only so the constructor signature matches the real env.
        del legacy, render_size
        image_space = spaces.Box(low=0, high=255, shape=(3, 4, 4), dtype=np.uint8)
        self.observation_space = spaces.Dict({
            'image': image_space,
        })
        self.action_space = spaces.Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32)
        self.seed_value = 0
        self.step_count = 0

    def seed(self, seed=None):
        """Store the seed; ``None`` is recorded as 0."""
        if seed is None:
            seed = 0
        self.seed_value = seed

    def reset(self):
        """Reset the step counter and return an all-zero observation."""
        self.step_count = 0
        blank = np.zeros((3, 4, 4), dtype=np.uint8)
        return {'image': blank}

    def step(self, action):
        """Advance one step; the action is ignored."""
        del action
        self.step_count += 1
        if self.seed_value < 10000:
            reward = 0.1
        else:
            reward = 0.9
        done = self.step_count >= 1
        frame = np.full((3, 4, 4), self.step_count, dtype=np.uint8)
        return {'image': frame}, reward, done, {}

    def render(self, *args, **kwargs):
        """Always fail: scalar-only rollouts must not render."""
        raise AssertionError('render should not be called for scalar-only PushT image rollouts')
class FakePolicy:
    """Policy stub that always predicts all-zero actions on the CPU."""

    device = torch.device('cpu')
    dtype = torch.float32

    def reset(self):
        """No internal state to clear."""
        return None

    def predict_action(self, obs_dict):
        """Return a zero ``'action'`` tensor of shape (n_envs, 2, 2).

        ``n_envs`` is the leading dimension of the first value in
        ``obs_dict``.
        """
        first_obs = next(iter(obs_dict.values()))
        n_envs = first_obs.shape[0]
        zero_action = torch.zeros((n_envs, 2, 2), dtype=torch.float32)
        return {
            'action': zero_action,
        }
def test_summarize_rollout_metrics_keeps_scalar_rewards_renames_means_and_omits_videos():
    """Per-rollout maxima and renamed means are logged; video paths are dropped."""
    seeds = [11, 12, 101]
    prefixes = ['train/', 'train/', 'test/']
    rewards = [
        [0.2, 0.8],
        [0.1, 0.4],
        [0.5, 0.9],
    ]
    video_paths = [
        '/tmp/train-11.mp4',
        '/tmp/train-12.mp4',
        '/tmp/test-101.mp4',
    ]

    log_data = summarize_rollout_metrics(
        env_seeds=seeds,
        env_prefixs=prefixes,
        all_rewards=rewards,
        all_video_paths=video_paths,
    )

    # Per-rollout entries hold the exact maximum of each reward list.
    expected_max = {
        'train/sim_max_reward_11': 0.8,
        'train/sim_max_reward_12': 0.4,
        'test/sim_max_reward_101': 0.9,
    }
    for key, value in expected_max.items():
        assert log_data[key] == value

    # Aggregates use the flat, slash-free key names.
    assert log_data['train_mean_score'] == pytest.approx(0.6)
    assert log_data['test_mean_score'] == pytest.approx(0.9)

    # Supplying video paths must not create any video entries.
    for video_prefix in ('train/sim_video_', 'test/sim_video_'):
        assert not any(key.startswith(video_prefix) for key in log_data)
def test_runner_ignores_vis_flags_and_never_emits_sim_videos(tmp_path, monkeypatch):
    """Even with visualisation requested, the runner logs scalars only."""
    # Swap in the fake env, whose render() raises if it is ever invoked.
    monkeypatch.setattr(runner_module, 'PushTImageEnv', FakePushTImageEnv)

    runner_kwargs = dict(
        output_dir=tmp_path,
        n_train=1,
        n_train_vis=1,
        n_test=1,
        n_test_vis=1,
        n_envs=2,
        max_steps=2,
        n_obs_steps=2,
        n_action_steps=2,
        tqdm_interval_sec=0.0,
    )
    runner = runner_module.PushTImageRunner(**runner_kwargs)

    log_data = runner.run(FakePolicy())

    expected_scalars = {
        'train/sim_max_reward_0': 0.1,
        'test/sim_max_reward_10000': 0.9,
        'train_mean_score': 0.1,
        'test_mean_score': 0.9,
    }
    for key, value in expected_scalars.items():
        assert log_data[key] == pytest.approx(value)
    assert not any('sim_video' in key for key in log_data)

View File

@@ -0,0 +1,198 @@
import importlib
import pathlib
import sys
import pytest
from omegaconf import OmegaConf
ROOT_DIR = pathlib.Path(__file__).resolve().parents[1]
if str(ROOT_DIR) not in sys.path:
sys.path.append(str(ROOT_DIR))
MODULE_NAME = 'diffusion_policy.workspace.train_diffusion_transformer_hybrid_workspace'
def load_workspace_module(monkeypatch, *, wandb_missing=False):
    """Freshly import the workspace module under test.

    Args:
        monkeypatch: pytest's monkeypatch fixture.
        wandb_missing: When true, ``sys.modules['wandb']`` is set to ``None``
            before the import, so any ``import wandb`` raises ImportError.

    Returns:
        The newly imported workspace module.
    """
    # Drop any cached copy so the import below runs the module body again.
    if MODULE_NAME in sys.modules:
        del sys.modules[MODULE_NAME]
    if wandb_missing:
        monkeypatch.setitem(sys.modules, 'wandb', None)
    module = importlib.import_module(MODULE_NAME)
    return module
def test_init_logger_uses_swanlab_backend_mapping_without_loading_wandb(tmp_path, monkeypatch):
    """The SwanLab backend maps config keys to swanlab.init and never touches wandb."""
    workspace_module = load_workspace_module(monkeypatch, wandb_missing=True)
    events = []

    class FakeRun:
        def log(self, payload, step=None):
            events.append(('log', payload, step))

        def finish(self):
            events.append(('finish',))

    class FakeSwanLab:
        def init(self, **kwargs):
            events.append(('init', kwargs))
            return FakeRun()

    def forbid_wandb():
        return pytest.fail('wandb should not be loaded for the SwanLab backend')

    # Calling the class yields a fresh fake module object, like the real loader.
    monkeypatch.setattr(workspace_module, '_load_swanlab', FakeSwanLab)
    monkeypatch.setattr(workspace_module, '_load_wandb', forbid_wandb)

    logging_section = {
        'backend': 'swanlab',
        'project': 'demo-project',
        'name': 'demo-run',
        'group': 'demo-group',
        'tags': ['pusht', 'dit'],
        'id': 'run-123',
        'resume': True,
        'mode': 'online',
    }
    cfg = OmegaConf.create({'logging': logging_section})

    logger = workspace_module.init_logging_backend(cfg=cfg, output_dir=tmp_path)
    logger.log({'metric': 1.0}, step=7)
    logger.finish()

    assert events[0][0] == 'init'
    init_kwargs = events[0][1]
    expected_kwargs = {
        'project': 'demo-project',
        'experiment_name': 'demo-run',
        'group': 'demo-group',
        'tags': ['pusht', 'dit'],
        'id': 'run-123',
        # wandb-style 'online' is translated to SwanLab's 'cloud'.
        'mode': 'cloud',
        'logdir': str(tmp_path / 'swanlog'),
    }
    for key, value in expected_kwargs.items():
        assert init_kwargs[key] == value
    assert init_kwargs['resume'] is True

    assert ('log', {'metric': 1.0}, 7) in events
    assert events.count(('finish',)) == 1
def test_init_logger_defaults_to_legacy_wandb_path_when_backend_missing(tmp_path, monkeypatch):
    """Without a ``backend`` key the wandb path is used, as before."""
    workspace_module = load_workspace_module(monkeypatch)
    events = []

    class FakeRun:
        def log(self, payload, step=None):
            events.append(('log', payload, step))

        def finish(self):
            events.append(('finish',))

    class FakeConfig:
        def update(self, payload):
            events.append(('config.update', payload))

    class FakeWandb:
        def __init__(self):
            self.config = FakeConfig()

        def init(self, **kwargs):
            events.append(('init', kwargs))
            return FakeRun()

    # Calling the class yields a fresh fake module object, like the real loader.
    monkeypatch.setattr(workspace_module, '_load_wandb', FakeWandb)

    logging_section = {
        'project': 'demo-project',
        'name': 'demo-run',
        'group': None,
        'tags': ['shared'],
        'id': None,
        'resume': True,
        'mode': 'online',
    }
    cfg = OmegaConf.create({'logging': logging_section})

    logger = workspace_module.init_logging_backend(cfg=cfg, output_dir=tmp_path)
    logger.log({'metric': 2.0}, step=3)
    logger.finish()

    assert events[0][0] == 'init'
    init_kwargs = events[0][1]
    expected_kwargs = {
        'dir': str(tmp_path),
        'project': 'demo-project',
        'name': 'demo-run',
        # wandb keeps its own mode name; no translation happens here.
        'mode': 'online',
    }
    for key, value in expected_kwargs.items():
        assert init_kwargs[key] == value

    assert ('config.update', {'output_dir': str(tmp_path)}) in events
    assert ('log', {'metric': 2.0}, 3) in events
    assert events.count(('finish',)) == 1
def test_init_logger_rejects_unknown_backends(tmp_path, monkeypatch):
    """An unrecognised backend name raises ValueError."""
    workspace_module = load_workspace_module(monkeypatch)

    logging_section = {
        'backend': 'tensorboard',
        'project': 'demo-project',
        'name': 'demo-run',
        'mode': 'offline',
    }
    cfg = OmegaConf.create({'logging': logging_section})

    with pytest.raises(ValueError, match='Unknown logging backend'):
        workspace_module.init_logging_backend(cfg=cfg, output_dir=tmp_path)
def test_logging_backend_session_preserves_primary_exception_when_finish_fails(tmp_path, monkeypatch):
    """If the body and finish() both fail, the body's error is the one raised."""
    workspace_module = load_workspace_module(monkeypatch)
    events = []

    class FakeBackend:
        def log(self, payload, step=None):
            events.append(('log', payload, step))

        def finish(self):
            # Record the call first so the test can confirm cleanup was attempted.
            events.append(('finish',))
            raise RuntimeError('finish boom')

    def fake_init(cfg, output_dir):
        return FakeBackend()

    monkeypatch.setattr(workspace_module, 'init_logging_backend', fake_init)

    cfg = OmegaConf.create({'logging': {'mode': 'offline'}})
    session = workspace_module.logging_backend_session

    with pytest.raises(ValueError, match='primary boom'):
        with session(cfg=cfg, output_dir=tmp_path) as logger:
            logger.log({'metric': 6.0}, step=12)
            raise ValueError('primary boom')

    assert ('log', {'metric': 6.0}, 12) in events
    assert events.count(('finish',)) == 1
def test_logging_backend_session_finishes_on_exception(tmp_path, monkeypatch):
    """finish() is called exactly once even when the body raises."""
    workspace_module = load_workspace_module(monkeypatch)
    events = []

    class FakeBackend:
        def log(self, payload, step=None):
            events.append(('log', payload, step))

        def finish(self):
            events.append(('finish',))

    def fake_init(cfg, output_dir):
        return FakeBackend()

    monkeypatch.setattr(workspace_module, 'init_logging_backend', fake_init)

    cfg = OmegaConf.create({'logging': {'mode': 'offline'}})
    session = workspace_module.logging_backend_session

    with pytest.raises(RuntimeError, match='boom'):
        with session(cfg=cfg, output_dir=tmp_path) as logger:
            logger.log({'metric': 5.0}, step=11)
            raise RuntimeError('boom')

    assert ('log', {'metric': 5.0}, 11) in events
    assert events.count(('finish',)) == 1