feat: switch pusht transformer logging to swanlab

This commit is contained in:
Logic
2026-03-26 19:49:45 +08:00
parent 23374a4cd2
commit 5e7ae6cfa5
6 changed files with 601 additions and 216 deletions

View File

@@ -1,21 +1,37 @@
import wandb
import numpy as np import numpy as np
import torch import torch
import collections import collections
import pathlib
import tqdm import tqdm
import dill import dill
import math import math
import wandb.sdk.data_types.video as wv
from diffusion_policy.env.pusht.pusht_image_env import PushTImageEnv from diffusion_policy.env.pusht.pusht_image_env import PushTImageEnv
from diffusion_policy.gym_util.sync_vector_env import SyncVectorEnv from diffusion_policy.gym_util.sync_vector_env import SyncVectorEnv
from diffusion_policy.gym_util.multistep_wrapper import MultiStepWrapper from diffusion_policy.gym_util.multistep_wrapper import MultiStepWrapper
from diffusion_policy.gym_util.video_recording_wrapper import VideoRecordingWrapper, VideoRecorder
from diffusion_policy.policy.base_image_policy import BaseImagePolicy from diffusion_policy.policy.base_image_policy import BaseImagePolicy
from diffusion_policy.common.pytorch_util import dict_apply from diffusion_policy.common.pytorch_util import dict_apply
from diffusion_policy.env_runner.base_image_runner import BaseImageRunner from diffusion_policy.env_runner.base_image_runner import BaseImageRunner
def summarize_rollout_metrics(env_seeds, env_prefixs, all_rewards, all_video_paths=None):
    """Aggregate per-episode rollout rewards into a flat metrics dict.

    For each rollout, the episode's best reward is recorded under
    ``<prefix>sim_max_reward_<seed>``. Per-prefix means of those best rewards
    are then recorded under ``train_mean_score`` / ``test_mean_score`` (or
    ``<prefix>mean_score`` for any other prefix). ``all_video_paths`` is
    accepted for call-site compatibility but ignored — video logging was
    removed from this runner.
    """
    del all_video_paths  # kept in the signature only for backward compatibility
    per_prefix_maxes = collections.defaultdict(list)
    metrics = {}
    for seed, prefix, episode_rewards in zip(env_seeds, env_prefixs, all_rewards):
        best = np.max(episode_rewards)
        per_prefix_maxes[prefix].append(best)
        metrics[prefix + f'sim_max_reward_{seed}'] = best
    # Well-known prefixes get stable aggregate names; others fall back to
    # '<prefix>mean_score' exactly as the legacy runner did.
    renamed = {
        'train/': 'train_mean_score',
        'test/': 'test_mean_score',
    }
    for prefix, maxes in per_prefix_maxes.items():
        metrics[renamed.get(prefix, prefix + 'mean_score')] = np.mean(maxes)
    return metrics
class PushTImageRunner(BaseImageRunner): class PushTImageRunner(BaseImageRunner):
def __init__(self, def __init__(self,
output_dir, output_dir,
@@ -40,25 +56,12 @@ class PushTImageRunner(BaseImageRunner):
if n_envs is None: if n_envs is None:
n_envs = n_train + n_test n_envs = n_train + n_test
steps_per_render = max(10 // fps, 1)
def env_fn(): def env_fn():
return MultiStepWrapper( return MultiStepWrapper(
VideoRecordingWrapper(
PushTImageEnv( PushTImageEnv(
legacy=legacy_test, legacy=legacy_test,
render_size=render_size render_size=render_size
), ),
video_recoder=VideoRecorder.create_h264(
fps=fps,
codec='h264',
input_pix_fmt='rgb24',
crf=crf,
thread_type='FRAME',
thread_count=1
),
file_path=None,
steps_per_render=steps_per_render
),
n_obs_steps=n_obs_steps, n_obs_steps=n_obs_steps,
n_action_steps=n_action_steps, n_action_steps=n_action_steps,
max_episode_steps=max_steps max_episode_steps=max_steps
@@ -71,21 +74,8 @@ class PushTImageRunner(BaseImageRunner):
# train # train
for i in range(n_train): for i in range(n_train):
seed = train_start_seed + i seed = train_start_seed + i
enable_render = i < n_train_vis
def init_fn(env, seed=seed, enable_render=enable_render):
# setup rendering
# video_wrapper
assert isinstance(env.env, VideoRecordingWrapper)
env.env.video_recoder.stop()
env.env.file_path = None
if enable_render:
filename = pathlib.Path(output_dir).joinpath(
'media', wv.util.generate_id() + ".mp4")
filename.parent.mkdir(parents=False, exist_ok=True)
filename = str(filename)
env.env.file_path = filename
def init_fn(env, seed=seed):
# set seed # set seed
assert isinstance(env, MultiStepWrapper) assert isinstance(env, MultiStepWrapper)
env.seed(seed) env.seed(seed)
@@ -97,21 +87,8 @@ class PushTImageRunner(BaseImageRunner):
# test # test
for i in range(n_test): for i in range(n_test):
seed = test_start_seed + i seed = test_start_seed + i
enable_render = i < n_test_vis
def init_fn(env, seed=seed, enable_render=enable_render):
# setup rendering
# video_wrapper
assert isinstance(env.env, VideoRecordingWrapper)
env.env.video_recoder.stop()
env.env.file_path = None
if enable_render:
filename = pathlib.Path(output_dir).joinpath(
'media', wv.util.generate_id() + ".mp4")
filename.parent.mkdir(parents=False, exist_ok=True)
filename = str(filename)
env.env.file_path = filename
def init_fn(env, seed=seed):
# set seed # set seed
assert isinstance(env, MultiStepWrapper) assert isinstance(env, MultiStepWrapper)
env.seed(seed) env.seed(seed)
@@ -154,7 +131,6 @@ class PushTImageRunner(BaseImageRunner):
n_chunks = math.ceil(n_inits / n_envs) n_chunks = math.ceil(n_inits / n_envs)
# allocate data # allocate data
all_video_paths = [None] * n_inits
all_rewards = [None] * n_inits all_rewards = [None] * n_inits
for chunk_idx in range(n_chunks): for chunk_idx in range(n_chunks):
@@ -214,39 +190,16 @@ class PushTImageRunner(BaseImageRunner):
pbar.update(action.shape[1]) pbar.update(action.shape[1])
pbar.close() pbar.close()
all_video_paths[this_global_slice] = env.render()[this_local_slice]
all_rewards[this_global_slice] = env.call('get_attr', 'reward')[this_local_slice] all_rewards[this_global_slice] = env.call('get_attr', 'reward')[this_local_slice]
# clear out video buffer # reset env state between evaluation calls
_ = env.reset() _ = env.reset()
# log # results reported in the paper are generated using the commented out
max_rewards = collections.defaultdict(list) # line below, which would only report and average metrics from the
log_data = dict() # first n_envs initial conditions and seeds. We keep the full n_inits
# results reported in the paper are generated using the commented out line below # behavior here.
# which will only report and average metrics from first n_envs initial condition and seeds return summarize_rollout_metrics(
# fortunately this won't invalidate our conclusion since env_seeds=self.env_seeds[:n_inits],
# 1. This bug only affects the variance of metrics, not their mean env_prefixs=self.env_prefixs[:n_inits],
# 2. All baseline methods are evaluated using the same code all_rewards=all_rewards[:n_inits],
# to completely reproduce reported numbers, uncomment this line: )
# for i in range(len(self.env_fns)):
# and comment out this line
for i in range(n_inits):
seed = self.env_seeds[i]
prefix = self.env_prefixs[i]
max_reward = np.max(all_rewards[i])
max_rewards[prefix].append(max_reward)
log_data[prefix+f'sim_max_reward_{seed}'] = max_reward
# visualize sim
video_path = all_video_paths[i]
if video_path is not None:
sim_video = wandb.Video(video_path)
log_data[prefix+f'sim_video_{seed}'] = sim_video
# log aggregate metrics
for prefix, value in max_rewards.items():
name = prefix+'mean_score'
value = np.mean(value)
log_data[name] = value
return log_data

View File

@@ -8,6 +8,8 @@ if __name__ == "__main__":
os.chdir(ROOT_DIR) os.chdir(ROOT_DIR)
import os import os
import contextlib
import importlib
import hydra import hydra
import torch import torch
from omegaconf import OmegaConf from omegaconf import OmegaConf
@@ -15,7 +17,6 @@ import pathlib
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
import copy import copy
import random import random
import wandb
import tqdm import tqdm
import numpy as np import numpy as np
import shutil import shutil
@@ -31,6 +32,111 @@ from diffusion_policy.model.common.lr_scheduler import get_scheduler
OmegaConf.register_new_resolver("eval", eval, replace=True) OmegaConf.register_new_resolver("eval", eval, replace=True)
class _LoggingBackend:
    """Abstract interface for experiment-metric logging backends."""
    def log(self, payload, step=None):
        """Record a dict of metrics, optionally at an explicit global step."""
        raise NotImplementedError
    def finish(self):
        """Flush pending data and close the underlying run."""
        raise NotImplementedError
class _WandbLoggingBackend(_LoggingBackend):
    """Thin adapter forwarding logging calls to a live wandb run object."""
    def __init__(self, run):
        # `run` is the object returned by wandb.init().
        self.run = run
    def log(self, payload, step=None):
        """Forward `payload` (and optional `step`) to the wandb run."""
        self.run.log(payload, step=step)
    def finish(self):
        """Finalize the wandb run."""
        self.run.finish()
class _SwanLabLoggingBackend(_LoggingBackend):
    """Thin adapter forwarding logging calls to a live SwanLab run object."""
    def __init__(self, run):
        # `run` is the object returned by swanlab.init().
        self.run = run
    def log(self, payload, step=None):
        """Forward `payload` (and optional `step`) to the SwanLab run."""
        self.run.log(payload, step=step)
    def finish(self):
        """Finalize the SwanLab run."""
        self.run.finish()
def _load_wandb():
try:
return importlib.import_module('wandb')
except ImportError as exc:
raise ImportError(
"wandb is required when cfg.logging.backend == 'wandb' or missing"
) from exc
def _load_swanlab():
try:
return importlib.import_module('swanlab')
except ImportError:
return None
def init_logging_backend(cfg: OmegaConf, output_dir):
    """Create the experiment logger selected by ``cfg.logging.backend``.

    Supported backends:
      * ``'swanlab'`` — maps the wandb-style logging config onto
        ``swanlab.init`` (``name`` -> ``experiment_name``, mode ``'online'``
        -> ``'cloud'``) and stores local logs under ``<output_dir>/swanlog``.
      * ``'wandb'`` or unset — legacy behavior: ``cfg.logging`` is passed
        through to ``wandb.init`` as keyword arguments.

    Raises ImportError when the selected backend's package is unavailable,
    and ValueError for any other backend name.
    """
    # Default to wandb so configs written before this option keep working.
    backend = OmegaConf.select(cfg, 'logging.backend', default='wandb')
    if backend == 'swanlab':
        swanlab = _load_swanlab()
        if swanlab is None:
            raise ImportError("swanlab is required when cfg.logging.backend == 'swanlab'")
        # NOTE(review): the keys below (project/name/group/tags/id/resume/mode)
        # are assumed to exist in cfg.logging — missing keys will raise.
        logging_cfg = cfg.logging
        mode = logging_cfg.mode
        # SwanLab uses 'cloud' where wandb uses 'online'; other modes pass through.
        if mode == 'online':
            mode = 'cloud'
        run = swanlab.init(
            project=logging_cfg.project,
            experiment_name=logging_cfg.name,
            group=logging_cfg.group,
            tags=logging_cfg.tags,
            id=logging_cfg.id,
            resume=logging_cfg.resume,
            mode=mode,
            logdir=str(pathlib.Path(output_dir) / 'swanlog'),
            config=OmegaConf.to_container(cfg, resolve=True),
        )
        return _SwanLabLoggingBackend(run)
    if backend not in (None, 'wandb'):
        raise ValueError(f"Unknown logging backend: {backend}")
    wandb = _load_wandb()
    # Strip the local-only 'backend' selector before forwarding to wandb.init.
    logging_kwargs = OmegaConf.to_container(cfg.logging, resolve=True)
    logging_kwargs.pop('backend', None)
    run = wandb.init(
        dir=str(output_dir),
        config=OmegaConf.to_container(cfg, resolve=True),
        **logging_kwargs
    )
    wandb.config.update(
        {
            "output_dir": str(output_dir),
        }
    )
    return _WandbLoggingBackend(run)
@contextlib.contextmanager
def logging_backend_session(cfg: OmegaConf, output_dir):
    """Yield a logging backend and guarantee ``finish()`` runs on exit.

    Error semantics:
      * If the ``with`` body raised, its exception propagates; a secondary
        failure inside ``finish()`` is swallowed so it cannot mask the
        original error.
      * If the body succeeded, a ``finish()`` failure propagates normally.
    """
    backend = init_logging_backend(cfg=cfg, output_dir=output_dir)
    body_failed = False
    try:
        yield backend
    except BaseException:
        body_failed = True
        raise
    finally:
        try:
            backend.finish()
        except BaseException:
            # Only surface finish() errors when they are the sole failure.
            if not body_failed:
                raise
class TrainDiffusionTransformerHybridWorkspace(BaseWorkspace): class TrainDiffusionTransformerHybridWorkspace(BaseWorkspace):
include_keys = ['global_step', 'epoch'] include_keys = ['global_step', 'epoch']
@@ -109,18 +215,6 @@ class TrainDiffusionTransformerHybridWorkspace(BaseWorkspace):
output_dir=self.output_dir) output_dir=self.output_dir)
assert isinstance(env_runner, BaseImageRunner) assert isinstance(env_runner, BaseImageRunner)
# configure logging
wandb_run = wandb.init(
dir=str(self.output_dir),
config=OmegaConf.to_container(cfg, resolve=True),
**cfg.logging
)
wandb.config.update(
{
"output_dir": self.output_dir,
}
)
# configure checkpoint # configure checkpoint
topk_manager = TopKCheckpointManager( topk_manager = TopKCheckpointManager(
save_dir=os.path.join(self.output_dir, 'checkpoints'), save_dir=os.path.join(self.output_dir, 'checkpoints'),
@@ -148,6 +242,7 @@ class TrainDiffusionTransformerHybridWorkspace(BaseWorkspace):
# training loop # training loop
log_path = os.path.join(self.output_dir, 'logs.json.txt') log_path = os.path.join(self.output_dir, 'logs.json.txt')
with logging_backend_session(cfg=cfg, output_dir=self.output_dir) as logging_backend:
with JsonLogger(log_path) as json_logger: with JsonLogger(log_path) as json_logger:
for local_epoch_idx in range(cfg.training.num_epochs): for local_epoch_idx in range(cfg.training.num_epochs):
step_log = dict() step_log = dict()
@@ -190,7 +285,7 @@ class TrainDiffusionTransformerHybridWorkspace(BaseWorkspace):
is_last_batch = (batch_idx == (len(train_dataloader)-1)) is_last_batch = (batch_idx == (len(train_dataloader)-1))
if not is_last_batch: if not is_last_batch:
# log of last step is combined with validation and rollout # log of last step is combined with validation and rollout
wandb_run.log(step_log, step=self.global_step) logging_backend.log(step_log, step=self.global_step)
json_logger.log(step_log) json_logger.log(step_log)
self.global_step += 1 self.global_step += 1
@@ -278,7 +373,7 @@ class TrainDiffusionTransformerHybridWorkspace(BaseWorkspace):
# end of epoch # end of epoch
# log of last step is combined with validation and rollout # log of last step is combined with validation and rollout
wandb_run.log(step_log, step=self.global_step) logging_backend.log(step_log, step=self.global_step)
json_logger.log(step_log) json_logger.log(step_log)
self.global_step += 1 self.global_step += 1
self.epoch += 1 self.epoch += 1

View File

@@ -0,0 +1,28 @@
# Hydra config: train the diffusion-transformer hybrid policy on PushT image
# observations, logging to SwanLab instead of wandb.
defaults:
  # Start from the standard transformer-hybrid training workspace config.
  - diffusion_policy/config/train_diffusion_transformer_hybrid_workspace@_here_
  - override /diffusion_policy/config/task@task: pusht_image
  - _self_

exp_name: pusht_image_dit

policy:
  _target_: diffusion_policy.policy.diffusion_transformer_hybrid_image_policy.DiffusionTransformerHybridImagePolicy

logging:
  backend: swanlab  # selects the SwanLab backend in init_logging_backend
  mode: online      # mapped to SwanLab's 'cloud' mode at init time
  tags: ["${name}", "${task_name}", "${exp_name}", "swanlab"]
  id: ${now:%Y%m%d%H%M%S}_${name}_${task_name}
  group: ${exp_name}

# Single-process data loading.
dataloader:
  num_workers: 0
val_dataloader:
  num_workers: 0

task:
  env_runner:
    n_envs: 1
    # Video rendering was removed from the runner; keep vis counts at zero.
    n_test_vis: 0
    n_train_vis: 0

View File

@@ -36,3 +36,4 @@ av==14.0.1
pygame==2.5.2 pygame==2.5.2
robomimic==0.2.0 robomimic==0.2.0
opencv-python-headless==4.10.0.84 opencv-python-headless==4.10.0.84
swanlab

View File

@@ -0,0 +1,110 @@
import pathlib
import sys
import gym
from gym import spaces
import numpy as np
import pytest
import torch
ROOT_DIR = pathlib.Path(__file__).resolve().parents[1]
if str(ROOT_DIR) not in sys.path:
sys.path.append(str(ROOT_DIR))
import diffusion_policy.env_runner.pusht_image_runner as runner_module
from diffusion_policy.env_runner.pusht_image_runner import summarize_rollout_metrics
class FakePushTImageEnv(gym.Env):
    """Lightweight PushT stand-in: tiny images, seed-dependent scalar rewards.

    Episodes terminate after a single step. Seeds below 10000 (the runner's
    train seeds) earn 0.1 reward per step; larger seeds (test seeds) earn 0.9,
    so train/test metrics are distinguishable in assertions.
    """
    metadata = {'render.modes': ['rgb_array']}

    def __init__(self, legacy=False, render_size=96):
        # Constructor kwargs mirror the real env's signature but are unused.
        del legacy, render_size
        self.observation_space = spaces.Dict({
            'image': spaces.Box(low=0, high=255, shape=(3, 4, 4), dtype=np.uint8),
        })
        self.action_space = spaces.Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32)
        self.seed_value = 0
        self.step_count = 0

    def seed(self, seed=None):
        # The stored seed doubles as the reward selector in step().
        self.seed_value = seed if seed is not None else 0

    def reset(self):
        self.step_count = 0
        return {'image': np.zeros((3, 4, 4), dtype=np.uint8)}

    def step(self, action):
        del action
        self.step_count += 1
        is_train_seed = self.seed_value < 10000
        reward = 0.1 if is_train_seed else 0.9
        obs = {'image': np.full((3, 4, 4), self.step_count, dtype=np.uint8)}
        return obs, reward, self.step_count >= 1, {}

    def render(self, *args, **kwargs):
        raise AssertionError('render should not be called for scalar-only PushT image rollouts')
class FakePolicy:
    """Deterministic stand-in policy that always emits zero actions.

    Mimics the interface the runner consumes: ``device``/``dtype`` class
    attributes plus ``reset`` and ``predict_action``.
    """
    device = torch.device('cpu')
    dtype = torch.float32

    def reset(self):
        """No internal state to clear."""
        return None

    def predict_action(self, obs_dict):
        """Return zero actions batched to match the observation batch size."""
        any_obs = next(iter(obs_dict.values()))
        batch = any_obs.shape[0]
        zeros = torch.zeros((batch, 2, 2), dtype=torch.float32)
        return {'action': zeros}
def test_summarize_rollout_metrics_keeps_scalar_rewards_renames_means_and_omits_videos():
    """Per-seed max rewards survive, means are renamed, and no video keys appear."""
    seeds = [11, 12, 101]
    prefixes = ['train/', 'train/', 'test/']
    rewards = [
        [0.2, 0.8],
        [0.1, 0.4],
        [0.5, 0.9],
    ]
    # Video paths must be silently ignored by the summarizer.
    videos = [
        '/tmp/train-11.mp4',
        '/tmp/train-12.mp4',
        '/tmp/test-101.mp4',
    ]
    log_data = summarize_rollout_metrics(
        env_seeds=seeds,
        env_prefixs=prefixes,
        all_rewards=rewards,
        all_video_paths=videos,
    )
    assert log_data['train/sim_max_reward_11'] == 0.8
    assert log_data['train/sim_max_reward_12'] == 0.4
    assert log_data['test/sim_max_reward_101'] == 0.9
    assert log_data['train_mean_score'] == pytest.approx(0.6)
    assert log_data['test_mean_score'] == pytest.approx(0.9)
    assert not any(key.startswith('train/sim_video_') for key in log_data)
    assert not any(key.startswith('test/sim_video_') for key in log_data)
def test_runner_ignores_vis_flags_and_never_emits_sim_videos(tmp_path, monkeypatch):
    """End-to-end: the runner must produce scalar metrics and no video keys,
    even when n_train_vis / n_test_vis are nonzero."""
    # Swap the real env for the fake so no pygame/pymunk simulation runs.
    monkeypatch.setattr(runner_module, 'PushTImageEnv', FakePushTImageEnv)
    runner = runner_module.PushTImageRunner(
        output_dir=tmp_path,
        n_train=1,
        n_train_vis=1,
        n_test=1,
        n_test_vis=1,
        n_envs=2,
        max_steps=2,
        n_obs_steps=2,
        n_action_steps=2,
        tqdm_interval_sec=0.0,
    )
    log_data = runner.run(FakePolicy())
    # NOTE(review): seeds 0 / 10000 assume the runner's default
    # train_start_seed=0 and test_start_seed=10000 — confirm against defaults.
    assert log_data['train/sim_max_reward_0'] == pytest.approx(0.1)
    assert log_data['test/sim_max_reward_10000'] == pytest.approx(0.9)
    assert log_data['train_mean_score'] == pytest.approx(0.1)
    assert log_data['test_mean_score'] == pytest.approx(0.9)
    assert not any('sim_video' in key for key in log_data)

View File

@@ -0,0 +1,198 @@
import importlib
import pathlib
import sys
import pytest
from omegaconf import OmegaConf
ROOT_DIR = pathlib.Path(__file__).resolve().parents[1]
if str(ROOT_DIR) not in sys.path:
sys.path.append(str(ROOT_DIR))
MODULE_NAME = 'diffusion_policy.workspace.train_diffusion_transformer_hybrid_workspace'
def load_workspace_module(monkeypatch, *, wandb_missing=False):
    """Import the workspace module fresh, optionally simulating a missing wandb.

    Popping the cached entry forces module top-level code to rerun per test.
    Setting ``sys.modules['wandb'] = None`` makes a subsequent ``import wandb``
    raise ImportError, which is how the "wandb not installed" case is faked.
    """
    sys.modules.pop(MODULE_NAME, None)
    if wandb_missing:
        monkeypatch.setitem(sys.modules, 'wandb', None)
    return importlib.import_module(MODULE_NAME)
def test_init_logger_uses_swanlab_backend_mapping_without_loading_wandb(tmp_path, monkeypatch):
    """backend='swanlab' must map wandb-style config keys onto swanlab.init
    (name->experiment_name, online->cloud) and never touch wandb."""
    workspace_module = load_workspace_module(monkeypatch, wandb_missing=True)
    # Shared event log recording every call made against the fakes.
    events = []
    class FakeRun:
        def log(self, payload, step=None):
            events.append(('log', payload, step))
        def finish(self):
            events.append(('finish',))
    class FakeSwanLab:
        def init(self, **kwargs):
            events.append(('init', kwargs))
            return FakeRun()
    monkeypatch.setattr(workspace_module, '_load_swanlab', lambda: FakeSwanLab())
    # Loading wandb at all would be a regression for the SwanLab path.
    monkeypatch.setattr(
        workspace_module,
        '_load_wandb',
        lambda: pytest.fail('wandb should not be loaded for the SwanLab backend'),
    )
    cfg = OmegaConf.create({
        'logging': {
            'backend': 'swanlab',
            'project': 'demo-project',
            'name': 'demo-run',
            'group': 'demo-group',
            'tags': ['pusht', 'dit'],
            'id': 'run-123',
            'resume': True,
            'mode': 'online',
        }
    })
    logger = workspace_module.init_logging_backend(cfg=cfg, output_dir=tmp_path)
    logger.log({'metric': 1.0}, step=7)
    logger.finish()
    assert events[0][0] == 'init'
    init_kwargs = events[0][1]
    assert init_kwargs['project'] == 'demo-project'
    assert init_kwargs['experiment_name'] == 'demo-run'
    assert init_kwargs['group'] == 'demo-group'
    assert init_kwargs['tags'] == ['pusht', 'dit']
    assert init_kwargs['id'] == 'run-123'
    assert init_kwargs['resume'] is True
    assert init_kwargs['mode'] == 'cloud'
    assert init_kwargs['logdir'] == str(tmp_path / 'swanlog')
    assert ('log', {'metric': 1.0}, 7) in events
    assert events.count(('finish',)) == 1
def test_init_logger_defaults_to_legacy_wandb_path_when_backend_missing(tmp_path, monkeypatch):
    """Without a 'backend' key, init must fall back to the legacy wandb path,
    forwarding cfg.logging as kwargs and updating wandb.config.output_dir."""
    workspace_module = load_workspace_module(monkeypatch)
    # Shared event log recording every call made against the fakes.
    events = []
    class FakeRun:
        def log(self, payload, step=None):
            events.append(('log', payload, step))
        def finish(self):
            events.append(('finish',))
    class FakeConfig:
        def update(self, payload):
            events.append(('config.update', payload))
    class FakeWandb:
        def __init__(self):
            self.config = FakeConfig()
        def init(self, **kwargs):
            events.append(('init', kwargs))
            return FakeRun()
    monkeypatch.setattr(workspace_module, '_load_wandb', lambda: FakeWandb())
    cfg = OmegaConf.create({
        'logging': {
            'project': 'demo-project',
            'name': 'demo-run',
            'group': None,
            'tags': ['shared'],
            'id': None,
            'resume': True,
            'mode': 'online',
        }
    })
    logger = workspace_module.init_logging_backend(cfg=cfg, output_dir=tmp_path)
    logger.log({'metric': 2.0}, step=3)
    logger.finish()
    assert events[0][0] == 'init'
    init_kwargs = events[0][1]
    # Legacy path: no key renaming, no mode mapping.
    assert init_kwargs['dir'] == str(tmp_path)
    assert init_kwargs['project'] == 'demo-project'
    assert init_kwargs['name'] == 'demo-run'
    assert init_kwargs['mode'] == 'online'
    assert ('config.update', {'output_dir': str(tmp_path)}) in events
    assert ('log', {'metric': 2.0}, 3) in events
    assert events.count(('finish',)) == 1
def test_init_logger_rejects_unknown_backends(tmp_path, monkeypatch):
    """Any unrecognized cfg.logging.backend value must raise ValueError."""
    workspace_module = load_workspace_module(monkeypatch)
    bad_cfg = OmegaConf.create({
        'logging': {
            'backend': 'tensorboard',
            'project': 'demo-project',
            'name': 'demo-run',
            'mode': 'offline',
        }
    })
    with pytest.raises(ValueError, match='Unknown logging backend'):
        workspace_module.init_logging_backend(cfg=bad_cfg, output_dir=tmp_path)
def test_logging_backend_session_preserves_primary_exception_when_finish_fails(tmp_path, monkeypatch):
    """When both the with-body and finish() raise, the body's exception must
    win; the finish() error is suppressed but finish() still runs once."""
    workspace_module = load_workspace_module(monkeypatch)
    events = []
    class FakeBackend:
        # Records calls so the test can assert counts afterwards.
        def log(self, payload, step=None):
            events.append(('log', payload, step))
        def finish(self):
            events.append(('finish',))
            raise RuntimeError('finish boom')
    monkeypatch.setattr(
        workspace_module,
        'init_logging_backend',
        lambda cfg, output_dir: FakeBackend(),
    )
    cfg = OmegaConf.create({'logging': {'mode': 'offline'}})
    # ValueError (not RuntimeError) proves the primary exception propagated.
    with pytest.raises(ValueError, match='primary boom'):
        with workspace_module.logging_backend_session(cfg=cfg, output_dir=tmp_path) as logger:
            logger.log({'metric': 6.0}, step=12)
            raise ValueError('primary boom')
    assert ('log', {'metric': 6.0}, 12) in events
    assert events.count(('finish',)) == 1
def test_logging_backend_session_finishes_on_exception(tmp_path, monkeypatch):
    """finish() must be called exactly once even when the with-body raises."""
    workspace_module = load_workspace_module(monkeypatch)
    events = []
    class FakeBackend:
        # Records calls so the test can assert counts afterwards.
        def log(self, payload, step=None):
            events.append(('log', payload, step))
        def finish(self):
            events.append(('finish',))
    monkeypatch.setattr(
        workspace_module,
        'init_logging_backend',
        lambda cfg, output_dir: FakeBackend(),
    )
    cfg = OmegaConf.create({'logging': {'mode': 'offline'}})
    with pytest.raises(RuntimeError, match='boom'):
        with workspace_module.logging_backend_session(cfg=cfg, output_dir=tmp_path) as logger:
            logger.log({'metric': 5.0}, step=11)
            raise RuntimeError('boom')
    assert ('log', {'metric': 5.0}, 11) in events
    assert events.count(('finish',)) == 1