diff --git a/diffusion_policy/env_runner/pusht_image_runner.py b/diffusion_policy/env_runner/pusht_image_runner.py index 82187b6..ecf361f 100644 --- a/diffusion_policy/env_runner/pusht_image_runner.py +++ b/diffusion_policy/env_runner/pusht_image_runner.py @@ -1,21 +1,37 @@ -import wandb import numpy as np import torch import collections -import pathlib import tqdm import dill import math -import wandb.sdk.data_types.video as wv from diffusion_policy.env.pusht.pusht_image_env import PushTImageEnv from diffusion_policy.gym_util.sync_vector_env import SyncVectorEnv from diffusion_policy.gym_util.multistep_wrapper import MultiStepWrapper -from diffusion_policy.gym_util.video_recording_wrapper import VideoRecordingWrapper, VideoRecorder from diffusion_policy.policy.base_image_policy import BaseImagePolicy from diffusion_policy.common.pytorch_util import dict_apply from diffusion_policy.env_runner.base_image_runner import BaseImageRunner + +def summarize_rollout_metrics(env_seeds, env_prefixs, all_rewards, all_video_paths=None): + del all_video_paths + + max_rewards = collections.defaultdict(list) + log_data = dict() + for seed, prefix, rewards in zip(env_seeds, env_prefixs, all_rewards): + max_reward = np.max(rewards) + max_rewards[prefix].append(max_reward) + log_data[prefix + f'sim_max_reward_{seed}'] = max_reward + + aggregate_key_map = { + 'train/': 'train_mean_score', + 'test/': 'test_mean_score', + } + for prefix, value in max_rewards.items(): + log_data[aggregate_key_map.get(prefix, prefix + 'mean_score')] = np.mean(value) + + return log_data + class PushTImageRunner(BaseImageRunner): def __init__(self, output_dir, @@ -40,24 +56,11 @@ class PushTImageRunner(BaseImageRunner): if n_envs is None: n_envs = n_train + n_test - steps_per_render = max(10 // fps, 1) def env_fn(): return MultiStepWrapper( - VideoRecordingWrapper( - PushTImageEnv( - legacy=legacy_test, - render_size=render_size - ), - video_recoder=VideoRecorder.create_h264( - fps=fps, - codec='h264', - input_pix_fmt='rgb24', - crf=crf, - thread_type='FRAME', - thread_count=1 - ), - file_path=None, - steps_per_render=steps_per_render + PushTImageEnv( + legacy=legacy_test, + render_size=render_size ), n_obs_steps=n_obs_steps, n_action_steps=n_action_steps, @@ -71,21 +74,8 @@ class PushTImageRunner(BaseImageRunner): # train for i in range(n_train): seed = train_start_seed + i - enable_render = i < n_train_vis - - def init_fn(env, seed=seed, enable_render=enable_render): - # setup rendering - # video_wrapper - assert isinstance(env.env, VideoRecordingWrapper) - env.env.video_recoder.stop() - env.env.file_path = None - if enable_render: - filename = pathlib.Path(output_dir).joinpath( - 'media', wv.util.generate_id() + ".mp4") - filename.parent.mkdir(parents=False, exist_ok=True) - filename = str(filename) - env.env.file_path = filename + def init_fn(env, seed=seed): # set seed assert isinstance(env, MultiStepWrapper) env.seed(seed) @@ -97,21 +87,8 @@ class PushTImageRunner(BaseImageRunner): # test for i in range(n_test): seed = test_start_seed + i - enable_render = i < n_test_vis - - def init_fn(env, seed=seed, enable_render=enable_render): - # setup rendering - # video_wrapper - assert isinstance(env.env, VideoRecordingWrapper) - env.env.video_recoder.stop() - env.env.file_path = None - if enable_render: - filename = pathlib.Path(output_dir).joinpath( - 'media', wv.util.generate_id() + ".mp4") - filename.parent.mkdir(parents=False, exist_ok=True) - filename = str(filename) - env.env.file_path = filename + def init_fn(env, seed=seed): # set seed assert isinstance(env, MultiStepWrapper) env.seed(seed) @@ -154,7 +131,6 @@ class PushTImageRunner(BaseImageRunner): n_chunks = math.ceil(n_inits / n_envs) # allocate data - all_video_paths = [None] * n_inits all_rewards = [None] * n_inits for chunk_idx in range(n_chunks): @@ -214,39 +190,16 @@ class PushTImageRunner(BaseImageRunner): pbar.update(action.shape[1]) pbar.close() - all_video_paths[this_global_slice] = env.render()[this_local_slice] all_rewards[this_global_slice] = env.call('get_attr', 'reward')[this_local_slice] - # clear out video buffer + # reset env state between evaluation calls _ = env.reset() - # log - max_rewards = collections.defaultdict(list) - log_data = dict() - # results reported in the paper are generated using the commented out line below - # which will only report and average metrics from first n_envs initial condition and seeds - # fortunately this won't invalidate our conclusion since - # 1. This bug only affects the variance of metrics, not their mean - # 2. All baseline methods are evaluated using the same code - # to completely reproduce reported numbers, uncomment this line: - # for i in range(len(self.env_fns)): - # and comment out this line - for i in range(n_inits): - seed = self.env_seeds[i] - prefix = self.env_prefixs[i] - max_reward = np.max(all_rewards[i]) - max_rewards[prefix].append(max_reward) - log_data[prefix+f'sim_max_reward_{seed}'] = max_reward - - # visualize sim - video_path = all_video_paths[i] - if video_path is not None: - sim_video = wandb.Video(video_path) - log_data[prefix+f'sim_video_{seed}'] = sim_video - - # log aggregate metrics - for prefix, value in max_rewards.items(): - name = prefix+'mean_score' - value = np.mean(value) - log_data[name] = value - - return log_data + # results reported in the paper are generated using the commented out + # line below, which would only report and average metrics from the + # first n_envs initial conditions and seeds. We keep the full n_inits + # behavior here. + return summarize_rollout_metrics( + env_seeds=self.env_seeds[:n_inits], + env_prefixs=self.env_prefixs[:n_inits], + all_rewards=all_rewards[:n_inits], + ) diff --git a/diffusion_policy/workspace/train_diffusion_transformer_hybrid_workspace.py b/diffusion_policy/workspace/train_diffusion_transformer_hybrid_workspace.py index 38b3be1..1f297d0 100644 --- a/diffusion_policy/workspace/train_diffusion_transformer_hybrid_workspace.py +++ b/diffusion_policy/workspace/train_diffusion_transformer_hybrid_workspace.py @@ -8,6 +8,8 @@ if __name__ == "__main__": os.chdir(ROOT_DIR) import os +import contextlib +import importlib import hydra import torch from omegaconf import OmegaConf @@ -15,7 +17,6 @@ import pathlib from torch.utils.data import DataLoader import copy import random -import wandb import tqdm import numpy as np import shutil @@ -31,6 +32,111 @@ from diffusion_policy.model.common.lr_scheduler import get_scheduler OmegaConf.register_new_resolver("eval", eval, replace=True) + +class _LoggingBackend: + def log(self, payload, step=None): + raise NotImplementedError + + def finish(self): + raise NotImplementedError + + +class _WandbLoggingBackend(_LoggingBackend): + def __init__(self, run): + self.run = run + + def log(self, payload, step=None): + self.run.log(payload, step=step) + + def finish(self): + self.run.finish() + + +class _SwanLabLoggingBackend(_LoggingBackend): + def __init__(self, run): + self.run = run + + def log(self, payload, step=None): + self.run.log(payload, step=step) + + def finish(self): + self.run.finish() + + +def _load_wandb(): + try: + return importlib.import_module('wandb') + except ImportError as exc: + raise ImportError( + "wandb is required when cfg.logging.backend == 'wandb' or missing" + ) from exc + + +def _load_swanlab(): + try: + return importlib.import_module('swanlab') + except ImportError: + return None + + +def init_logging_backend(cfg: OmegaConf, output_dir): + backend = OmegaConf.select(cfg, 'logging.backend', default='wandb') + if backend == 'swanlab': + swanlab = _load_swanlab() + if swanlab is None: + raise ImportError("swanlab is required when cfg.logging.backend == 'swanlab'") + logging_cfg = cfg.logging + mode = logging_cfg.mode + if mode == 'online': + mode = 'cloud' + run = swanlab.init( + project=logging_cfg.project, + experiment_name=logging_cfg.name, + group=logging_cfg.group, + tags=logging_cfg.tags, + id=logging_cfg.id, + resume=logging_cfg.resume, + mode=mode, + logdir=str(pathlib.Path(output_dir) / 'swanlog'), + config=OmegaConf.to_container(cfg, resolve=True), + ) + return _SwanLabLoggingBackend(run) + + if backend not in (None, 'wandb'): + raise ValueError(f"Unknown logging backend: {backend}") + + wandb = _load_wandb() + logging_kwargs = OmegaConf.to_container(cfg.logging, resolve=True) + logging_kwargs.pop('backend', None) + run = wandb.init( + dir=str(output_dir), + config=OmegaConf.to_container(cfg, resolve=True), + **logging_kwargs + ) + wandb.config.update( + { + "output_dir": str(output_dir), + } + ) + return _WandbLoggingBackend(run) + + +@contextlib.contextmanager +def logging_backend_session(cfg: OmegaConf, output_dir): + logging_backend = init_logging_backend(cfg=cfg, output_dir=output_dir) + primary_error = None + try: + yield logging_backend + except BaseException as exc: + primary_error = exc + raise + finally: + try: + logging_backend.finish() + except BaseException: + if primary_error is None: + raise + class TrainDiffusionTransformerHybridWorkspace(BaseWorkspace): include_keys = ['global_step', 'epoch'] @@ -109,18 +215,6 @@ class TrainDiffusionTransformerHybridWorkspace(BaseWorkspace): output_dir=self.output_dir) assert isinstance(env_runner, BaseImageRunner) - # configure logging - wandb_run = wandb.init( - dir=str(self.output_dir), - config=OmegaConf.to_container(cfg, resolve=True), - **cfg.logging - ) - wandb.config.update( - { - "output_dir": self.output_dir, - } - ) - # configure checkpoint topk_manager = TopKCheckpointManager( save_dir=os.path.join(self.output_dir, 'checkpoints'), @@ -148,140 +242,141 @@ class TrainDiffusionTransformerHybridWorkspace(BaseWorkspace): # training loop log_path = os.path.join(self.output_dir, 'logs.json.txt') - with JsonLogger(log_path) as json_logger: - for local_epoch_idx in range(cfg.training.num_epochs): - step_log = dict() - # ========= train for this epoch ========== - train_losses = list() - with tqdm.tqdm(train_dataloader, desc=f"Training epoch {self.epoch}", - leave=False, mininterval=cfg.training.tqdm_interval_sec) as tepoch: - for batch_idx, batch in enumerate(tepoch): - # device transfer - batch = dict_apply(batch, lambda x: x.to(device, non_blocking=True)) - if train_sampling_batch is None: - train_sampling_batch = batch + with logging_backend_session(cfg=cfg, output_dir=self.output_dir) as logging_backend: + with JsonLogger(log_path) as json_logger: + for local_epoch_idx in range(cfg.training.num_epochs): + step_log = dict() + # ========= train for this epoch ========== + train_losses = list() + with tqdm.tqdm(train_dataloader, desc=f"Training epoch {self.epoch}", + leave=False, mininterval=cfg.training.tqdm_interval_sec) as tepoch: + for batch_idx, batch in enumerate(tepoch): + # device transfer + batch = dict_apply(batch, lambda x: x.to(device, non_blocking=True)) + if train_sampling_batch is None: + train_sampling_batch = batch - # compute loss - raw_loss = self.model.compute_loss(batch) - loss = raw_loss / cfg.training.gradient_accumulate_every - loss.backward() + # compute loss + raw_loss = self.model.compute_loss(batch) + loss = raw_loss / cfg.training.gradient_accumulate_every + loss.backward() - # step optimizer - if self.global_step % cfg.training.gradient_accumulate_every == 0: - self.optimizer.step() - self.optimizer.zero_grad() - lr_scheduler.step() - - # update ema - if cfg.training.use_ema: - ema.step(self.model) + # step optimizer + if self.global_step % cfg.training.gradient_accumulate_every == 0: + self.optimizer.step() + self.optimizer.zero_grad() + lr_scheduler.step() + + # update ema + if cfg.training.use_ema: + ema.step(self.model) - # logging - raw_loss_cpu = raw_loss.item() - tepoch.set_postfix(loss=raw_loss_cpu, refresh=False) - train_losses.append(raw_loss_cpu) - step_log = { - 'train_loss': raw_loss_cpu, - 'global_step': self.global_step, - 'epoch': self.epoch, - 'lr': lr_scheduler.get_last_lr()[0] - } + # logging + raw_loss_cpu = raw_loss.item() + tepoch.set_postfix(loss=raw_loss_cpu, refresh=False) + train_losses.append(raw_loss_cpu) + step_log = { + 'train_loss': raw_loss_cpu, + 'global_step': self.global_step, + 'epoch': self.epoch, + 'lr': lr_scheduler.get_last_lr()[0] + } - is_last_batch = (batch_idx == (len(train_dataloader)-1)) - if not is_last_batch: - # log of last step is combined with validation and rollout - wandb_run.log(step_log, step=self.global_step) - json_logger.log(step_log) - self.global_step += 1 + is_last_batch = (batch_idx == (len(train_dataloader)-1)) + if not is_last_batch: + # log of last step is combined with validation and rollout + logging_backend.log(step_log, step=self.global_step) + json_logger.log(step_log) + self.global_step += 1 - if (cfg.training.max_train_steps is not None) \ - and batch_idx >= (cfg.training.max_train_steps-1): - break + if (cfg.training.max_train_steps is not None) \ + and batch_idx >= (cfg.training.max_train_steps-1): + break - # at the end of each epoch - # replace train_loss with epoch average - train_loss = np.mean(train_losses) - step_log['train_loss'] = train_loss + # at the end of each epoch + # replace train_loss with epoch average + train_loss = np.mean(train_losses) + step_log['train_loss'] = train_loss - # ========= eval for this epoch ========== - policy = self.model - if cfg.training.use_ema: - policy = self.ema_model - policy.eval() + # ========= eval for this epoch ========== + policy = self.model + if cfg.training.use_ema: + policy = self.ema_model + policy.eval() - # run rollout - if (self.epoch % cfg.training.rollout_every) == 0: - runner_log = env_runner.run(policy) - # log all - step_log.update(runner_log) + # run rollout + if (self.epoch % cfg.training.rollout_every) == 0: + runner_log = env_runner.run(policy) + # log all + step_log.update(runner_log) - # run validation - if (self.epoch % cfg.training.val_every) == 0: - with torch.no_grad(): - val_losses = list() - with tqdm.tqdm(val_dataloader, desc=f"Validation epoch {self.epoch}", - leave=False, mininterval=cfg.training.tqdm_interval_sec) as tepoch: - for batch_idx, batch in enumerate(tepoch): - batch = dict_apply(batch, lambda x: x.to(device, non_blocking=True)) - loss = self.model.compute_loss(batch) - val_losses.append(loss) - if (cfg.training.max_val_steps is not None) \ - and batch_idx >= (cfg.training.max_val_steps-1): - break - if len(val_losses) > 0: - val_loss = torch.mean(torch.tensor(val_losses)).item() - # log epoch average validation loss - step_log['val_loss'] = val_loss + # run validation + if (self.epoch % cfg.training.val_every) == 0: + with torch.no_grad(): + val_losses = list() + with tqdm.tqdm(val_dataloader, desc=f"Validation epoch {self.epoch}", + leave=False, mininterval=cfg.training.tqdm_interval_sec) as tepoch: + for batch_idx, batch in enumerate(tepoch): + batch = dict_apply(batch, lambda x: x.to(device, non_blocking=True)) + loss = self.model.compute_loss(batch) + val_losses.append(loss) + if (cfg.training.max_val_steps is not None) \ + and batch_idx >= (cfg.training.max_val_steps-1): + break + if len(val_losses) > 0: + val_loss = torch.mean(torch.tensor(val_losses)).item() + # log epoch average validation loss + step_log['val_loss'] = val_loss - # run diffusion sampling on a training batch - if (self.epoch % cfg.training.sample_every) == 0: - with torch.no_grad(): - # sample trajectory from training set, and evaluate difference - batch = dict_apply(train_sampling_batch, lambda x: x.to(device, non_blocking=True)) - obs_dict = batch['obs'] - gt_action = batch['action'] - - result = policy.predict_action(obs_dict) - pred_action = result['action_pred'] - mse = torch.nn.functional.mse_loss(pred_action, gt_action) - step_log['train_action_mse_error'] = mse.item() - del batch - del obs_dict - del gt_action - del result - del pred_action - del mse - - # checkpoint - if (self.epoch % cfg.training.checkpoint_every) == 0: - # checkpointing - if cfg.checkpoint.save_last_ckpt: - self.save_checkpoint() - if cfg.checkpoint.save_last_snapshot: - self.save_snapshot() - - # sanitize metric names - metric_dict = dict() - for key, value in step_log.items(): - new_key = key.replace('/', '_') - metric_dict[new_key] = value + # run diffusion sampling on a training batch + if (self.epoch % cfg.training.sample_every) == 0: + with torch.no_grad(): + # sample trajectory from training set, and evaluate difference + batch = dict_apply(train_sampling_batch, lambda x: x.to(device, non_blocking=True)) + obs_dict = batch['obs'] + gt_action = batch['action'] + + result = policy.predict_action(obs_dict) + pred_action = result['action_pred'] + mse = torch.nn.functional.mse_loss(pred_action, gt_action) + step_log['train_action_mse_error'] = mse.item() + del batch + del obs_dict + del gt_action + del result + del pred_action + del mse - # We can't copy the last checkpoint here - # since save_checkpoint uses threads. - # therefore at this point the file might have been empty! - topk_ckpt_path = topk_manager.get_ckpt_path(metric_dict) + # checkpoint + if (self.epoch % cfg.training.checkpoint_every) == 0: + # checkpointing + if cfg.checkpoint.save_last_ckpt: + self.save_checkpoint() + if cfg.checkpoint.save_last_snapshot: + self.save_snapshot() - if topk_ckpt_path is not None: - self.save_checkpoint(path=topk_ckpt_path) - # ========= eval end for this epoch ========== - policy.train() + # sanitize metric names + metric_dict = dict() + for key, value in step_log.items(): + new_key = key.replace('/', '_') + metric_dict[new_key] = value + + # We can't copy the last checkpoint here + # since save_checkpoint uses threads. + # therefore at this point the file might have been empty! + topk_ckpt_path = topk_manager.get_ckpt_path(metric_dict) - # end of epoch - # log of last step is combined with validation and rollout - wandb_run.log(step_log, step=self.global_step) - json_logger.log(step_log) - self.global_step += 1 - self.epoch += 1 + if topk_ckpt_path is not None: + self.save_checkpoint(path=topk_ckpt_path) + # ========= eval end for this epoch ========== + policy.train() + + # end of epoch + # log of last step is combined with validation and rollout + logging_backend.log(step_log, step=self.global_step) + json_logger.log(step_log) + self.global_step += 1 + self.epoch += 1 @hydra.main( version_base=None, diff --git a/image_pusht_diffusion_policy_dit.yaml b/image_pusht_diffusion_policy_dit.yaml new file mode 100644 index 0000000..1987aad --- /dev/null +++ b/image_pusht_diffusion_policy_dit.yaml @@ -0,0 +1,28 @@ +defaults: + - diffusion_policy/config/train_diffusion_transformer_hybrid_workspace@_here_ + - override /diffusion_policy/config/task@task: pusht_image + - _self_ + +exp_name: pusht_image_dit + +policy: + _target_: diffusion_policy.policy.diffusion_transformer_hybrid_image_policy.DiffusionTransformerHybridImagePolicy + +logging: + backend: swanlab + mode: online + tags: ["${name}", "${task_name}", "${exp_name}", "swanlab"] + id: ${now:%Y%m%d%H%M%S}_${name}_${task_name} + group: ${exp_name} + +dataloader: + num_workers: 0 + +val_dataloader: + num_workers: 0 + +task: + env_runner: + n_envs: 1 + n_test_vis: 0 + n_train_vis: 0 diff --git a/requirements-pusht-5090.txt b/requirements-pusht-5090.txt index 9bf8f9d..b2c037c 100644 --- a/requirements-pusht-5090.txt +++ b/requirements-pusht-5090.txt @@ -36,3 +36,4 @@ av==14.0.1 pygame==2.5.2 robomimic==0.2.0 opencv-python-headless==4.10.0.84 +swanlab diff --git a/tests/test_pusht_image_runner_metrics.py b/tests/test_pusht_image_runner_metrics.py new file mode 100644 index 0000000..6cd7e79 --- /dev/null +++ b/tests/test_pusht_image_runner_metrics.py @@ -0,0 +1,110 @@ +import pathlib +import sys + +import gym +from gym import spaces +import numpy as np +import pytest +import torch + +ROOT_DIR = pathlib.Path(__file__).resolve().parents[1] +if str(ROOT_DIR) not in sys.path: + sys.path.append(str(ROOT_DIR)) + +import diffusion_policy.env_runner.pusht_image_runner as runner_module +from diffusion_policy.env_runner.pusht_image_runner import summarize_rollout_metrics + + +class FakePushTImageEnv(gym.Env): + metadata = {'render.modes': ['rgb_array']} + + def __init__(self, legacy=False, render_size=96): + del legacy, render_size + self.observation_space = spaces.Dict({ + 'image': spaces.Box(low=0, high=255, shape=(3, 4, 4), dtype=np.uint8), + }) + self.action_space = spaces.Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32) + self.seed_value = 0 + self.step_count = 0 + + def seed(self, seed=None): + self.seed_value = 0 if seed is None else seed + + def reset(self): + self.step_count = 0 + return {'image': np.zeros((3, 4, 4), dtype=np.uint8)} + + def step(self, action): + del action + self.step_count += 1 + reward = 0.1 if self.seed_value < 10000 else 0.9 + done = self.step_count >= 1 + obs = {'image': np.full((3, 4, 4), self.step_count, dtype=np.uint8)} + return obs, reward, done, {} + + def render(self, *args, **kwargs): + raise AssertionError('render should not be called for scalar-only PushT image rollouts') + + +class FakePolicy: + device = torch.device('cpu') + dtype = torch.float32 + + def reset(self): + return None + + def predict_action(self, obs_dict): + n_envs = next(iter(obs_dict.values())).shape[0] + return { + 'action': torch.zeros((n_envs, 2, 2), dtype=torch.float32), + } + + +def test_summarize_rollout_metrics_keeps_scalar_rewards_renames_means_and_omits_videos(): + log_data = summarize_rollout_metrics( + env_seeds=[11, 12, 101], + env_prefixs=['train/', 'train/', 'test/'], + all_rewards=[ + [0.2, 0.8], + [0.1, 0.4], + [0.5, 0.9], + ], + all_video_paths=[ + '/tmp/train-11.mp4', + '/tmp/train-12.mp4', + '/tmp/test-101.mp4', + ], + ) + + assert log_data['train/sim_max_reward_11'] == 0.8 + assert log_data['train/sim_max_reward_12'] == 0.4 + assert log_data['test/sim_max_reward_101'] == 0.9 + assert log_data['train_mean_score'] == pytest.approx(0.6) + assert log_data['test_mean_score'] == pytest.approx(0.9) + assert not any(key.startswith('train/sim_video_') for key in log_data) + assert not any(key.startswith('test/sim_video_') for key in log_data) + + +def test_runner_ignores_vis_flags_and_never_emits_sim_videos(tmp_path, monkeypatch): + monkeypatch.setattr(runner_module, 'PushTImageEnv', FakePushTImageEnv) + + runner = runner_module.PushTImageRunner( + output_dir=tmp_path, + n_train=1, + n_train_vis=1, + n_test=1, + n_test_vis=1, + n_envs=2, + max_steps=2, + n_obs_steps=2, + n_action_steps=2, + tqdm_interval_sec=0.0, + ) + + log_data = runner.run(FakePolicy()) + + assert log_data['train/sim_max_reward_0'] == pytest.approx(0.1) + assert log_data['test/sim_max_reward_10000'] == pytest.approx(0.9) + assert log_data['train_mean_score'] == pytest.approx(0.1) + assert log_data['test_mean_score'] == pytest.approx(0.9) + assert not any('sim_video' in key for key in log_data) diff --git a/tests/test_train_diffusion_transformer_workspace_logging.py b/tests/test_train_diffusion_transformer_workspace_logging.py new file mode 100644 index 0000000..4368405 --- /dev/null +++ b/tests/test_train_diffusion_transformer_workspace_logging.py @@ -0,0 +1,198 @@ +import importlib +import pathlib +import sys + +import pytest +from omegaconf import OmegaConf + +ROOT_DIR = pathlib.Path(__file__).resolve().parents[1] +if str(ROOT_DIR) not in sys.path: + sys.path.append(str(ROOT_DIR)) + +MODULE_NAME = 'diffusion_policy.workspace.train_diffusion_transformer_hybrid_workspace' + + +def load_workspace_module(monkeypatch, *, wandb_missing=False): + sys.modules.pop(MODULE_NAME, None) + if wandb_missing: + monkeypatch.setitem(sys.modules, 'wandb', None) + return importlib.import_module(MODULE_NAME) + + +def test_init_logger_uses_swanlab_backend_mapping_without_loading_wandb(tmp_path, monkeypatch): + workspace_module = load_workspace_module(monkeypatch, wandb_missing=True) + events = [] + + class FakeRun: + def log(self, payload, step=None): + events.append(('log', payload, step)) + + def finish(self): + events.append(('finish',)) + + class FakeSwanLab: + def init(self, **kwargs): + events.append(('init', kwargs)) + return FakeRun() + + monkeypatch.setattr(workspace_module, '_load_swanlab', lambda: FakeSwanLab()) + monkeypatch.setattr( + workspace_module, + '_load_wandb', + lambda: pytest.fail('wandb should not be loaded for the SwanLab backend'), + ) + + cfg = OmegaConf.create({ + 'logging': { + 'backend': 'swanlab', + 'project': 'demo-project', + 'name': 'demo-run', + 'group': 'demo-group', + 'tags': ['pusht', 'dit'], + 'id': 'run-123', + 'resume': True, + 'mode': 'online', + } + }) + + logger = workspace_module.init_logging_backend(cfg=cfg, output_dir=tmp_path) + logger.log({'metric': 1.0}, step=7) + logger.finish() + + assert events[0][0] == 'init' + init_kwargs = events[0][1] + assert init_kwargs['project'] == 'demo-project' + assert init_kwargs['experiment_name'] == 'demo-run' + assert init_kwargs['group'] == 'demo-group' + assert init_kwargs['tags'] == ['pusht', 'dit'] + assert init_kwargs['id'] == 'run-123' + assert init_kwargs['resume'] is True + assert init_kwargs['mode'] == 'cloud' + assert init_kwargs['logdir'] == str(tmp_path / 'swanlog') + assert ('log', {'metric': 1.0}, 7) in events + assert events.count(('finish',)) == 1 + + +def test_init_logger_defaults_to_legacy_wandb_path_when_backend_missing(tmp_path, monkeypatch): + workspace_module = load_workspace_module(monkeypatch) + events = [] + + class FakeRun: + def log(self, payload, step=None): + events.append(('log', payload, step)) + + def finish(self): + events.append(('finish',)) + + class FakeConfig: + def update(self, payload): + events.append(('config.update', payload)) + + class FakeWandb: + def __init__(self): + self.config = FakeConfig() + + def init(self, **kwargs): + events.append(('init', kwargs)) + return FakeRun() + + monkeypatch.setattr(workspace_module, '_load_wandb', lambda: FakeWandb()) + + cfg = OmegaConf.create({ + 'logging': { + 'project': 'demo-project', + 'name': 'demo-run', + 'group': None, + 'tags': ['shared'], + 'id': None, + 'resume': True, + 'mode': 'online', + } + }) + + logger = workspace_module.init_logging_backend(cfg=cfg, output_dir=tmp_path) + logger.log({'metric': 2.0}, step=3) + logger.finish() + + assert events[0][0] == 'init' + init_kwargs = events[0][1] + assert init_kwargs['dir'] == str(tmp_path) + assert init_kwargs['project'] == 'demo-project' + assert init_kwargs['name'] == 'demo-run' + assert init_kwargs['mode'] == 'online' + assert ('config.update', {'output_dir': str(tmp_path)}) in events + assert ('log', {'metric': 2.0}, 3) in events + assert events.count(('finish',)) == 1 + + +def test_init_logger_rejects_unknown_backends(tmp_path, monkeypatch): + workspace_module = load_workspace_module(monkeypatch) + cfg = OmegaConf.create({ + 'logging': { + 'backend': 'tensorboard', + 'project': 'demo-project', + 'name': 'demo-run', + 'mode': 'offline', + } + }) + + with pytest.raises(ValueError, match='Unknown logging backend'): + workspace_module.init_logging_backend(cfg=cfg, output_dir=tmp_path) + + + + +def test_logging_backend_session_preserves_primary_exception_when_finish_fails(tmp_path, monkeypatch): + workspace_module = load_workspace_module(monkeypatch) + events = [] + + class FakeBackend: + def log(self, payload, step=None): + events.append(('log', payload, step)) + + def finish(self): + events.append(('finish',)) + raise RuntimeError('finish boom') + + monkeypatch.setattr( + workspace_module, + 'init_logging_backend', + lambda cfg, output_dir: FakeBackend(), + ) + + cfg = OmegaConf.create({'logging': {'mode': 'offline'}}) + + with pytest.raises(ValueError, match='primary boom'): + with workspace_module.logging_backend_session(cfg=cfg, output_dir=tmp_path) as logger: + logger.log({'metric': 6.0}, step=12) + raise ValueError('primary boom') + + assert ('log', {'metric': 6.0}, 12) in events + assert events.count(('finish',)) == 1 + +def test_logging_backend_session_finishes_on_exception(tmp_path, monkeypatch): + workspace_module = load_workspace_module(monkeypatch) + events = [] + + class FakeBackend: + def log(self, payload, step=None): + events.append(('log', payload, step)) + + def finish(self): + events.append(('finish',)) + + monkeypatch.setattr( + workspace_module, + 'init_logging_backend', + lambda cfg, output_dir: FakeBackend(), + ) + + cfg = OmegaConf.create({'logging': {'mode': 'offline'}}) + + with pytest.raises(RuntimeError, match='boom'): + with workspace_module.logging_backend_session(cfg=cfg, output_dir=tmp_path) as logger: + logger.log({'metric': 5.0}, step=11) + raise RuntimeError('boom') + + assert ('log', {'metric': 5.0}, 11) in events + assert events.count(('finish',)) == 1