feat: switch pusht transformer logging to swanlab

This commit is contained in:
Logic
2026-03-26 19:49:45 +08:00
parent 23374a4cd2
commit 5e7ae6cfa5
6 changed files with 601 additions and 216 deletions

View File

@@ -1,21 +1,37 @@
import wandb
import numpy as np
import torch
import collections
import pathlib
import tqdm
import dill
import math
import wandb.sdk.data_types.video as wv
from diffusion_policy.env.pusht.pusht_image_env import PushTImageEnv
from diffusion_policy.gym_util.sync_vector_env import SyncVectorEnv
from diffusion_policy.gym_util.multistep_wrapper import MultiStepWrapper
from diffusion_policy.gym_util.video_recording_wrapper import VideoRecordingWrapper, VideoRecorder
from diffusion_policy.policy.base_image_policy import BaseImagePolicy
from diffusion_policy.common.pytorch_util import dict_apply
from diffusion_policy.env_runner.base_image_runner import BaseImageRunner
def summarize_rollout_metrics(env_seeds, env_prefixs, all_rewards, all_video_paths=None):
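# accepted and immediately discarded: video logging was removed in this commit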
del all_video_paths
max_rewards = collections.defaultdict(list)
log_data = dict()
for seed, prefix, rewards in zip(env_seeds, env_prefixs, all_rewards):
max_reward = np.max(rewards)
max_rewards[prefix].append(max_reward)
log_data[prefix + f'sim_max_reward_{seed}'] = max_reward
aggregate_key_map = {
'train/': 'train_mean_score',
'test/': 'test_mean_score',
}
for prefix, value in max_rewards.items():
log_data[aggregate_key_map.get(prefix, prefix + 'mean_score')] = np.mean(value)
return log_data
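A minimal usage sketch of the new helper with fabricated seeds and rewards (values are illustrative; assumes summarize_rollout_metrics above is in scope):

import numpy as np

env_seeds = [100, 101, 200]
env_prefixs = ['train/', 'train/', 'test/']
all_rewards = [np.array([0.1, 0.8]), np.array([0.5, 0.4]), np.array([0.9, 0.2])]

log_data = summarize_rollout_metrics(env_seeds, env_prefixs, all_rewards)
# log_data == {
#     'train/sim_max_reward_100': 0.8,
#     'train/sim_max_reward_101': 0.5,
#     'test/sim_max_reward_200': 0.9,
#     'train_mean_score': 0.65,  # mean of 0.8 and 0.5
#     'test_mean_score': 0.9,
# }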
class PushTImageRunner(BaseImageRunner):
def __init__(self,
output_dir,
@@ -40,24 +56,11 @@ class PushTImageRunner(BaseImageRunner):
if n_envs is None:
n_envs = n_train + n_test
steps_per_render = max(10 // fps, 1)
def env_fn():
return MultiStepWrapper(
VideoRecordingWrapper(
PushTImageEnv(
legacy=legacy_test,
render_size=render_size
),
video_recoder=VideoRecorder.create_h264(
fps=fps,
codec='h264',
input_pix_fmt='rgb24',
crf=crf,
thread_type='FRAME',
thread_count=1
),
file_path=None,
steps_per_render=steps_per_render
PushTImageEnv(
legacy=legacy_test,
render_size=render_size
),
n_obs_steps=n_obs_steps,
n_action_steps=n_action_steps,
@@ -71,21 +74,8 @@ class PushTImageRunner(BaseImageRunner):
# train
for i in range(n_train):
seed = train_start_seed + i
enable_render = i < n_train_vis
def init_fn(env, seed=seed, enable_render=enable_render):
# setup rendering
# video_wrapper
assert isinstance(env.env, VideoRecordingWrapper)
env.env.video_recoder.stop()
env.env.file_path = None
if enable_render:
filename = pathlib.Path(output_dir).joinpath(
'media', wv.util.generate_id() + ".mp4")
filename.parent.mkdir(parents=False, exist_ok=True)
filename = str(filename)
env.env.file_path = filename
def init_fn(env, seed=seed):
# set seed
assert isinstance(env, MultiStepWrapper)
env.seed(seed)
@@ -97,21 +87,8 @@ class PushTImageRunner(BaseImageRunner):
# test
for i in range(n_test):
seed = test_start_seed + i
enable_render = i < n_test_vis
def init_fn(env, seed=seed, enable_render=enable_render):
# setup rendering
# video_wrapper
assert isinstance(env.env, VideoRecordingWrapper)
env.env.video_recoder.stop()
env.env.file_path = None
if enable_render:
filename = pathlib.Path(output_dir).joinpath(
'media', wv.util.generate_id() + ".mp4")
filename.parent.mkdir(parents=False, exist_ok=True)
filename = str(filename)
env.env.file_path = filename
def init_fn(env, seed=seed):
# set seed
assert isinstance(env, MultiStepWrapper)
env.seed(seed)
@@ -154,7 +131,6 @@ class PushTImageRunner(BaseImageRunner):
n_chunks = math.ceil(n_inits / n_envs)
# allocate data
all_video_paths = [None] * n_inits
all_rewards = [None] * n_inits
for chunk_idx in range(n_chunks):
@@ -214,39 +190,16 @@ class PushTImageRunner(BaseImageRunner):
pbar.update(action.shape[1])
pbar.close()
all_video_paths[this_global_slice] = env.render()[this_local_slice]
all_rewards[this_global_slice] = env.call('get_attr', 'reward')[this_local_slice]
# clear out video buffer
# reset env state between evaluation calls
_ = env.reset()
# log
max_rewards = collections.defaultdict(list)
log_data = dict()
# results reported in the paper are generated using the commented-out line below
# which will only report and average metrics from the first n_envs initial conditions and seeds
# fortunately this won't invalidate our conclusion since
# 1. This bug only affects the variance of metrics, not their mean
# 2. All baseline methods are evaluated using the same code
# to completely reproduce reported numbers, uncomment this line:
# for i in range(len(self.env_fns)):
# and comment out this line
for i in range(n_inits):
seed = self.env_seeds[i]
prefix = self.env_prefixs[i]
max_reward = np.max(all_rewards[i])
max_rewards[prefix].append(max_reward)
log_data[prefix+f'sim_max_reward_{seed}'] = max_reward
# visualize sim
video_path = all_video_paths[i]
if video_path is not None:
sim_video = wandb.Video(video_path)
log_data[prefix+f'sim_video_{seed}'] = sim_video
# log aggregate metrics
for prefix, value in max_rewards.items():
name = prefix+'mean_score'
value = np.mean(value)
log_data[name] = value
return log_data
# results reported in the paper were generated by a variant that only
# reported and averaged metrics from the first n_envs initial conditions
# and seeds. We keep the full n_inits behavior here.
return summarize_rollout_metrics(
env_seeds=self.env_seeds[:n_inits],
env_prefixs=self.env_prefixs[:n_inits],
all_rewards=all_rewards[:n_inits],
)
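With video recording removed, run() now returns only scalar metrics; a sketch of downstream consumption (seed and values illustrative), mirroring the training loop in the next file:

runner_log = env_runner.run(policy)
# e.g. {'test/sim_max_reward_100000': 0.9, ..., 'test_mean_score': 0.9}
step_log.update(runner_log)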

View File

@@ -8,6 +8,8 @@ if __name__ == "__main__":
os.chdir(ROOT_DIR)
import os
import contextlib
import importlib
import hydra
import torch
from omegaconf import OmegaConf
@@ -15,7 +17,6 @@ import pathlib
from torch.utils.data import DataLoader
import copy
import random
import wandb
import tqdm
import numpy as np
import shutil
@@ -31,6 +32,111 @@ from diffusion_policy.model.common.lr_scheduler import get_scheduler
OmegaConf.register_new_resolver("eval", eval, replace=True)
class _LoggingBackend:
def log(self, payload, step=None):
raise NotImplementedError
def finish(self):
raise NotImplementedError
class _WandbLoggingBackend(_LoggingBackend):
def __init__(self, run):
self.run = run
def log(self, payload, step=None):
self.run.log(payload, step=step)
def finish(self):
self.run.finish()
class _SwanLabLoggingBackend(_LoggingBackend):
def __init__(self, run):
self.run = run
def log(self, payload, step=None):
self.run.log(payload, step=step)
def finish(self):
self.run.finish()
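Both adapters expose the same two-method surface, so extra backends are cheap to add. A hypothetical no-op backend for smoke tests without any tracking service (_NullLoggingBackend is not part of this commit):

class _NullLoggingBackend(_LoggingBackend):
    def log(self, payload, step=None):
        # metrics are intentionally discarded
        del payload, step
    def finish(self):
        pass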
def _load_wandb():
try:
return importlib.import_module('wandb')
except ImportError as exc:
raise ImportError(
"wandb is required when cfg.logging.backend == 'wandb' or missing"
) from exc
def _load_swanlab():
try:
return importlib.import_module('swanlab')
except ImportError:
return None
def init_logging_backend(cfg: OmegaConf, output_dir):
backend = OmegaConf.select(cfg, 'logging.backend', default='wandb')
if backend == 'swanlab':
swanlab = _load_swanlab()
if swanlab is None:
raise ImportError("swanlab is required when cfg.logging.backend == 'swanlab'")
logging_cfg = cfg.logging
mode = logging_cfg.mode
if mode == 'online':
mode = 'cloud'
run = swanlab.init(
project=logging_cfg.project,
experiment_name=logging_cfg.name,
group=logging_cfg.group,
tags=logging_cfg.tags,
id=logging_cfg.id,
resume=logging_cfg.resume,
mode=mode,
logdir=str(pathlib.Path(output_dir) / 'swanlog'),
config=OmegaConf.to_container(cfg, resolve=True),
)
return _SwanLabLoggingBackend(run)
if backend not in (None, 'wandb'):
raise ValueError(f"Unknown logging backend: {backend}")
wandb = _load_wandb()
logging_kwargs = OmegaConf.to_container(cfg.logging, resolve=True)
logging_kwargs.pop('backend', None)
run = wandb.init(
dir=str(output_dir),
config=OmegaConf.to_container(cfg, resolve=True),
**logging_kwargs
)
wandb.config.update(
{
"output_dir": str(output_dir),
}
)
return _WandbLoggingBackend(run)
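A minimal sketch of a cfg that selects the SwanLab path; the field names follow what init_logging_backend reads above, the values are illustrative:

from omegaconf import OmegaConf

cfg = OmegaConf.create({
    'logging': {
        'backend': 'swanlab',
        'project': 'pusht_transformer',
        'name': 'demo_run',
        'group': None,
        'tags': ['pusht'],
        'id': None,
        'resume': False,
        'mode': 'online',  # init_logging_backend maps 'online' to swanlab's 'cloud'
    }
})
backend = init_logging_backend(cfg, output_dir='outputs/demo')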
@contextlib.contextmanager
def logging_backend_session(cfg: OmegaConf, output_dir):
logging_backend = init_logging_backend(cfg=cfg, output_dir=output_dir)
primary_error = None
try:
yield logging_backend
except BaseException as exc:
primary_error = exc
raise
finally:
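# a finish() failure must not mask an error raised during training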
try:
logging_backend.finish()
except BaseException:
if primary_error is None:
raise
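Usage sketch: finish() runs on both the success and failure paths (cfg and paths as in the illustrative snippet above):

with logging_backend_session(cfg=cfg, output_dir='outputs/demo') as backend:
    backend.log({'train_loss': 0.123}, step=0)
    # an exception raised here still propagates after finish() is attempted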
class TrainDiffusionTransformerHybridWorkspace(BaseWorkspace):
include_keys = ['global_step', 'epoch']
@@ -109,18 +215,6 @@ class TrainDiffusionTransformerHybridWorkspace(BaseWorkspace):
output_dir=self.output_dir)
assert isinstance(env_runner, BaseImageRunner)
# configure logging
wandb_run = wandb.init(
dir=str(self.output_dir),
config=OmegaConf.to_container(cfg, resolve=True),
**cfg.logging
)
wandb.config.update(
{
"output_dir": self.output_dir,
}
)
# configure checkpoint
topk_manager = TopKCheckpointManager(
save_dir=os.path.join(self.output_dir, 'checkpoints'),
@@ -148,140 +242,141 @@ class TrainDiffusionTransformerHybridWorkspace(BaseWorkspace):
# training loop
log_path = os.path.join(self.output_dir, 'logs.json.txt')
with JsonLogger(log_path) as json_logger:
for local_epoch_idx in range(cfg.training.num_epochs):
step_log = dict()
# ========= train for this epoch ==========
train_losses = list()
with tqdm.tqdm(train_dataloader, desc=f"Training epoch {self.epoch}",
leave=False, mininterval=cfg.training.tqdm_interval_sec) as tepoch:
for batch_idx, batch in enumerate(tepoch):
# device transfer
batch = dict_apply(batch, lambda x: x.to(device, non_blocking=True))
if train_sampling_batch is None:
train_sampling_batch = batch
with logging_backend_session(cfg=cfg, output_dir=self.output_dir) as logging_backend:
with JsonLogger(log_path) as json_logger:
for local_epoch_idx in range(cfg.training.num_epochs):
step_log = dict()
# ========= train for this epoch ==========
train_losses = list()
with tqdm.tqdm(train_dataloader, desc=f"Training epoch {self.epoch}",
leave=False, mininterval=cfg.training.tqdm_interval_sec) as tepoch:
for batch_idx, batch in enumerate(tepoch):
# device transfer
batch = dict_apply(batch, lambda x: x.to(device, non_blocking=True))
if train_sampling_batch is None:
train_sampling_batch = batch
# compute loss
raw_loss = self.model.compute_loss(batch)
loss = raw_loss / cfg.training.gradient_accumulate_every
loss.backward()
# compute loss
raw_loss = self.model.compute_loss(batch)
loss = raw_loss / cfg.training.gradient_accumulate_every
loss.backward()
# step optimizer
if self.global_step % cfg.training.gradient_accumulate_every == 0:
self.optimizer.step()
self.optimizer.zero_grad()
lr_scheduler.step()
# update ema
if cfg.training.use_ema:
ema.step(self.model)
# step optimizer
if self.global_step % cfg.training.gradient_accumulate_every == 0:
self.optimizer.step()
self.optimizer.zero_grad()
lr_scheduler.step()
# update ema
if cfg.training.use_ema:
ema.step(self.model)
# logging
raw_loss_cpu = raw_loss.item()
tepoch.set_postfix(loss=raw_loss_cpu, refresh=False)
train_losses.append(raw_loss_cpu)
step_log = {
'train_loss': raw_loss_cpu,
'global_step': self.global_step,
'epoch': self.epoch,
'lr': lr_scheduler.get_last_lr()[0]
}
# logging
raw_loss_cpu = raw_loss.item()
tepoch.set_postfix(loss=raw_loss_cpu, refresh=False)
train_losses.append(raw_loss_cpu)
step_log = {
'train_loss': raw_loss_cpu,
'global_step': self.global_step,
'epoch': self.epoch,
'lr': lr_scheduler.get_last_lr()[0]
}
is_last_batch = (batch_idx == (len(train_dataloader)-1))
if not is_last_batch:
# log of last step is combined with validation and rollout
wandb_run.log(step_log, step=self.global_step)
json_logger.log(step_log)
self.global_step += 1
is_last_batch = (batch_idx == (len(train_dataloader)-1))
if not is_last_batch:
# log of last step is combined with validation and rollout
logging_backend.log(step_log, step=self.global_step)
json_logger.log(step_log)
self.global_step += 1
if (cfg.training.max_train_steps is not None) \
and batch_idx >= (cfg.training.max_train_steps-1):
break
if (cfg.training.max_train_steps is not None) \
and batch_idx >= (cfg.training.max_train_steps-1):
break
# at the end of each epoch
# replace train_loss with epoch average
train_loss = np.mean(train_losses)
step_log['train_loss'] = train_loss
# at the end of each epoch
# replace train_loss with epoch average
train_loss = np.mean(train_losses)
step_log['train_loss'] = train_loss
# ========= eval for this epoch ==========
policy = self.model
if cfg.training.use_ema:
policy = self.ema_model
policy.eval()
# ========= eval for this epoch ==========
policy = self.model
if cfg.training.use_ema:
policy = self.ema_model
policy.eval()
# run rollout
if (self.epoch % cfg.training.rollout_every) == 0:
runner_log = env_runner.run(policy)
# log all
step_log.update(runner_log)
# run rollout
if (self.epoch % cfg.training.rollout_every) == 0:
runner_log = env_runner.run(policy)
# log all
step_log.update(runner_log)
# run validation
if (self.epoch % cfg.training.val_every) == 0:
with torch.no_grad():
val_losses = list()
with tqdm.tqdm(val_dataloader, desc=f"Validation epoch {self.epoch}",
leave=False, mininterval=cfg.training.tqdm_interval_sec) as tepoch:
for batch_idx, batch in enumerate(tepoch):
batch = dict_apply(batch, lambda x: x.to(device, non_blocking=True))
loss = self.model.compute_loss(batch)
val_losses.append(loss)
if (cfg.training.max_val_steps is not None) \
and batch_idx >= (cfg.training.max_val_steps-1):
break
if len(val_losses) > 0:
val_loss = torch.mean(torch.tensor(val_losses)).item()
# log epoch average validation loss
step_log['val_loss'] = val_loss
# run validation
if (self.epoch % cfg.training.val_every) == 0:
with torch.no_grad():
val_losses = list()
with tqdm.tqdm(val_dataloader, desc=f"Validation epoch {self.epoch}",
leave=False, mininterval=cfg.training.tqdm_interval_sec) as tepoch:
for batch_idx, batch in enumerate(tepoch):
batch = dict_apply(batch, lambda x: x.to(device, non_blocking=True))
loss = self.model.compute_loss(batch)
val_losses.append(loss)
if (cfg.training.max_val_steps is not None) \
and batch_idx >= (cfg.training.max_val_steps-1):
break
if len(val_losses) > 0:
val_loss = torch.mean(torch.tensor(val_losses)).item()
# log epoch average validation loss
step_log['val_loss'] = val_loss
# run diffusion sampling on a training batch
if (self.epoch % cfg.training.sample_every) == 0:
with torch.no_grad():
# sample trajectory from training set, and evaluate difference
batch = dict_apply(train_sampling_batch, lambda x: x.to(device, non_blocking=True))
obs_dict = batch['obs']
gt_action = batch['action']
result = policy.predict_action(obs_dict)
pred_action = result['action_pred']
mse = torch.nn.functional.mse_loss(pred_action, gt_action)
step_log['train_action_mse_error'] = mse.item()
del batch
del obs_dict
del gt_action
del result
del pred_action
del mse
# checkpoint
if (self.epoch % cfg.training.checkpoint_every) == 0:
# checkpointing
if cfg.checkpoint.save_last_ckpt:
self.save_checkpoint()
if cfg.checkpoint.save_last_snapshot:
self.save_snapshot()
# sanitize metric names
metric_dict = dict()
for key, value in step_log.items():
new_key = key.replace('/', '_')
metric_dict[new_key] = value
# run diffusion sampling on a training batch
if (self.epoch % cfg.training.sample_every) == 0:
with torch.no_grad():
# sample trajectory from training set, and evaluate difference
batch = dict_apply(train_sampling_batch, lambda x: x.to(device, non_blocking=True))
obs_dict = batch['obs']
gt_action = batch['action']
result = policy.predict_action(obs_dict)
pred_action = result['action_pred']
mse = torch.nn.functional.mse_loss(pred_action, gt_action)
step_log['train_action_mse_error'] = mse.item()
del batch
del obs_dict
del gt_action
del result
del pred_action
del mse
# We can't copy the last checkpoint here,
# since save_checkpoint uses threads;
# at this point the file might still be empty!
topk_ckpt_path = topk_manager.get_ckpt_path(metric_dict)
# checkpoint
if (self.epoch % cfg.training.checkpoint_every) == 0:
# checkpointing
if cfg.checkpoint.save_last_ckpt:
self.save_checkpoint()
if cfg.checkpoint.save_last_snapshot:
self.save_snapshot()
if topk_ckpt_path is not None:
self.save_checkpoint(path=topk_ckpt_path)
# ========= eval end for this epoch ==========
policy.train()
# sanitize metric names
metric_dict = dict()
for key, value in step_log.items():
new_key = key.replace('/', '_')
metric_dict[new_key] = value
# We can't copy the last checkpoint here,
# since save_checkpoint uses threads;
# at this point the file might still be empty!
topk_ckpt_path = topk_manager.get_ckpt_path(metric_dict)
# end of epoch
# log of last step is combined with validation and rollout
wandb_run.log(step_log, step=self.global_step)
json_logger.log(step_log)
self.global_step += 1
self.epoch += 1
if topk_ckpt_path is not None:
self.save_checkpoint(path=topk_ckpt_path)
# ========= eval end for this epoch ==========
policy.train()
# end of epoch
# log of last step is combined with validation and rollout
logging_backend.log(step_log, step=self.global_step)
json_logger.log(step_log)
self.global_step += 1
self.epoch += 1
@hydra.main(
version_base=None,