feat: switch pusht transformer logging to swanlab
diffusion_policy/env_runner/pusht_image_runner.py
@@ -1,21 +1,37 @@
-import wandb
 import numpy as np
 import torch
 import collections
-import pathlib
 import tqdm
 import dill
 import math
-import wandb.sdk.data_types.video as wv
 from diffusion_policy.env.pusht.pusht_image_env import PushTImageEnv
 from diffusion_policy.gym_util.sync_vector_env import SyncVectorEnv
 from diffusion_policy.gym_util.multistep_wrapper import MultiStepWrapper
-from diffusion_policy.gym_util.video_recording_wrapper import VideoRecordingWrapper, VideoRecorder
 from diffusion_policy.policy.base_image_policy import BaseImagePolicy
 from diffusion_policy.common.pytorch_util import dict_apply
 from diffusion_policy.env_runner.base_image_runner import BaseImageRunner
 
+
+def summarize_rollout_metrics(env_seeds, env_prefixs, all_rewards, all_video_paths=None):
+    del all_video_paths
+
+    max_rewards = collections.defaultdict(list)
+    log_data = dict()
+    for seed, prefix, rewards in zip(env_seeds, env_prefixs, all_rewards):
+        max_reward = np.max(rewards)
+        max_rewards[prefix].append(max_reward)
+        log_data[prefix + f'sim_max_reward_{seed}'] = max_reward
+
+    aggregate_key_map = {
+        'train/': 'train_mean_score',
+        'test/': 'test_mean_score',
+    }
+    for prefix, value in max_rewards.items():
+        log_data[aggregate_key_map.get(prefix, prefix + 'mean_score')] = np.mean(value)
+
+    return log_data
+
+
 class PushTImageRunner(BaseImageRunner):
     def __init__(self,
             output_dir,
@@ -40,24 +56,11 @@ class PushTImageRunner(BaseImageRunner):
         if n_envs is None:
             n_envs = n_train + n_test
 
-        steps_per_render = max(10 // fps, 1)
         def env_fn():
             return MultiStepWrapper(
-                VideoRecordingWrapper(
-                    PushTImageEnv(
-                        legacy=legacy_test,
-                        render_size=render_size
-                    ),
-                    video_recoder=VideoRecorder.create_h264(
-                        fps=fps,
-                        codec='h264',
-                        input_pix_fmt='rgb24',
-                        crf=crf,
-                        thread_type='FRAME',
-                        thread_count=1
-                    ),
-                    file_path=None,
-                    steps_per_render=steps_per_render
+                PushTImageEnv(
+                    legacy=legacy_test,
+                    render_size=render_size
                 ),
                 n_obs_steps=n_obs_steps,
                 n_action_steps=n_action_steps,
@@ -71,21 +74,8 @@ class PushTImageRunner(BaseImageRunner):
         # train
         for i in range(n_train):
             seed = train_start_seed + i
-            enable_render = i < n_train_vis
 
-            def init_fn(env, seed=seed, enable_render=enable_render):
-                # setup rendering
-                # video_wrapper
-                assert isinstance(env.env, VideoRecordingWrapper)
-                env.env.video_recoder.stop()
-                env.env.file_path = None
-                if enable_render:
-                    filename = pathlib.Path(output_dir).joinpath(
-                        'media', wv.util.generate_id() + ".mp4")
-                    filename.parent.mkdir(parents=False, exist_ok=True)
-                    filename = str(filename)
-                    env.env.file_path = filename
-
+            def init_fn(env, seed=seed):
                 # set seed
                 assert isinstance(env, MultiStepWrapper)
                 env.seed(seed)
@@ -97,21 +87,8 @@ class PushTImageRunner(BaseImageRunner):
         # test
         for i in range(n_test):
             seed = test_start_seed + i
-            enable_render = i < n_test_vis
 
-            def init_fn(env, seed=seed, enable_render=enable_render):
-                # setup rendering
-                # video_wrapper
-                assert isinstance(env.env, VideoRecordingWrapper)
-                env.env.video_recoder.stop()
-                env.env.file_path = None
-                if enable_render:
-                    filename = pathlib.Path(output_dir).joinpath(
-                        'media', wv.util.generate_id() + ".mp4")
-                    filename.parent.mkdir(parents=False, exist_ok=True)
-                    filename = str(filename)
-                    env.env.file_path = filename
-
+            def init_fn(env, seed=seed):
                 # set seed
                 assert isinstance(env, MultiStepWrapper)
                 env.seed(seed)
@@ -154,7 +131,6 @@ class PushTImageRunner(BaseImageRunner):
         n_chunks = math.ceil(n_inits / n_envs)
 
         # allocate data
-        all_video_paths = [None] * n_inits
         all_rewards = [None] * n_inits
 
         for chunk_idx in range(n_chunks):
@@ -214,39 +190,16 @@ class PushTImageRunner(BaseImageRunner):
                 pbar.update(action.shape[1])
             pbar.close()
 
-            all_video_paths[this_global_slice] = env.render()[this_local_slice]
             all_rewards[this_global_slice] = env.call('get_attr', 'reward')[this_local_slice]
-            # clear out video buffer
+            # reset env state between evaluation calls
            _ = env.reset()
 
-        # log
-        max_rewards = collections.defaultdict(list)
-        log_data = dict()
-        # results reported in the paper are generated using the commented out line below
-        # which will only report and average metrics from first n_envs initial condition and seeds
-        # fortunately this won't invalidate our conclusion since
-        # 1. This bug only affects the variance of metrics, not their mean
-        # 2. All baseline methods are evaluated using the same code
-        # to completely reproduce reported numbers, uncomment this line:
-        # for i in range(len(self.env_fns)):
-        # and comment out this line
-        for i in range(n_inits):
-            seed = self.env_seeds[i]
-            prefix = self.env_prefixs[i]
-            max_reward = np.max(all_rewards[i])
-            max_rewards[prefix].append(max_reward)
-            log_data[prefix+f'sim_max_reward_{seed}'] = max_reward
-
-            # visualize sim
-            video_path = all_video_paths[i]
-            if video_path is not None:
-                sim_video = wandb.Video(video_path)
-                log_data[prefix+f'sim_video_{seed}'] = sim_video
-
-        # log aggregate metrics
-        for prefix, value in max_rewards.items():
-            name = prefix+'mean_score'
-            value = np.mean(value)
-            log_data[name] = value
-
-        return log_data
+        # results reported in the paper are generated using the commented out
+        # line below, which would only report and average metrics from the
+        # first n_envs initial conditions and seeds. We keep the full n_inits
+        # behavior here.
+        return summarize_rollout_metrics(
+            env_seeds=self.env_seeds[:n_inits],
+            env_prefixs=self.env_prefixs[:n_inits],
+            all_rewards=all_rewards[:n_inits],
+        )
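The new helper is a pure function over rollout results, so its behavior is easy to see in isolation. A quick sketch with made-up seeds and rewards (the commit's own coverage lives in tests/test_pusht_image_runner_metrics.py below); note that all_video_paths is accepted for call-site compatibility but discarded:

    from diffusion_policy.env_runner.pusht_image_runner import summarize_rollout_metrics

    log_data = summarize_rollout_metrics(
        env_seeds=[0, 10000],                  # hypothetical seeds
        env_prefixs=['train/', 'test/'],
        all_rewards=[[0.2, 0.7], [0.9, 0.4]],  # per-env step rewards
    )
    # Per-seed maxima plus the renamed aggregates, and no sim_video entries:
    # {'train/sim_max_reward_0': 0.7, 'test/sim_max_reward_10000': 0.9,
    #  'train_mean_score': 0.7, 'test_mean_score': 0.9}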
diffusion_policy/workspace/train_diffusion_transformer_hybrid_workspace.py
@@ -8,6 +8,8 @@ if __name__ == "__main__":
     os.chdir(ROOT_DIR)
 
 import os
+import contextlib
+import importlib
 import hydra
 import torch
 from omegaconf import OmegaConf
@@ -15,7 +17,6 @@ import pathlib
 from torch.utils.data import DataLoader
 import copy
 import random
-import wandb
 import tqdm
 import numpy as np
 import shutil
@@ -31,6 +32,111 @@ from diffusion_policy.model.common.lr_scheduler import get_scheduler
 
 OmegaConf.register_new_resolver("eval", eval, replace=True)
 
+
+class _LoggingBackend:
+    def log(self, payload, step=None):
+        raise NotImplementedError
+
+    def finish(self):
+        raise NotImplementedError
+
+
+class _WandbLoggingBackend(_LoggingBackend):
+    def __init__(self, run):
+        self.run = run
+
+    def log(self, payload, step=None):
+        self.run.log(payload, step=step)
+
+    def finish(self):
+        self.run.finish()
+
+
+class _SwanLabLoggingBackend(_LoggingBackend):
+    def __init__(self, run):
+        self.run = run
+
+    def log(self, payload, step=None):
+        self.run.log(payload, step=step)
+
+    def finish(self):
+        self.run.finish()
+
+
+def _load_wandb():
+    try:
+        return importlib.import_module('wandb')
+    except ImportError as exc:
+        raise ImportError(
+            "wandb is required when cfg.logging.backend == 'wandb' or missing"
+        ) from exc
+
+
+def _load_swanlab():
+    try:
+        return importlib.import_module('swanlab')
+    except ImportError:
+        return None
+
+
+def init_logging_backend(cfg: OmegaConf, output_dir):
+    backend = OmegaConf.select(cfg, 'logging.backend', default='wandb')
+    if backend == 'swanlab':
+        swanlab = _load_swanlab()
+        if swanlab is None:
+            raise ImportError("swanlab is required when cfg.logging.backend == 'swanlab'")
+        logging_cfg = cfg.logging
+        mode = logging_cfg.mode
+        if mode == 'online':
+            mode = 'cloud'
+        run = swanlab.init(
+            project=logging_cfg.project,
+            experiment_name=logging_cfg.name,
+            group=logging_cfg.group,
+            tags=logging_cfg.tags,
+            id=logging_cfg.id,
+            resume=logging_cfg.resume,
+            mode=mode,
+            logdir=str(pathlib.Path(output_dir) / 'swanlog'),
+            config=OmegaConf.to_container(cfg, resolve=True),
+        )
+        return _SwanLabLoggingBackend(run)
+
+    if backend not in (None, 'wandb'):
+        raise ValueError(f"Unknown logging backend: {backend}")
+
+    wandb = _load_wandb()
+    logging_kwargs = OmegaConf.to_container(cfg.logging, resolve=True)
+    logging_kwargs.pop('backend', None)
+    run = wandb.init(
+        dir=str(output_dir),
+        config=OmegaConf.to_container(cfg, resolve=True),
+        **logging_kwargs
+    )
+    wandb.config.update(
+        {
+            "output_dir": str(output_dir),
+        }
+    )
+    return _WandbLoggingBackend(run)
+
+
+@contextlib.contextmanager
+def logging_backend_session(cfg: OmegaConf, output_dir):
+    logging_backend = init_logging_backend(cfg=cfg, output_dir=output_dir)
+    primary_error = None
+    try:
+        yield logging_backend
+    except BaseException as exc:
+        primary_error = exc
+        raise
+    finally:
+        try:
+            logging_backend.finish()
+        except BaseException:
+            if primary_error is None:
+                raise
+
+
 class TrainDiffusionTransformerHybridWorkspace(BaseWorkspace):
     include_keys = ['global_step', 'epoch']
 
@@ -109,18 +215,6 @@ class TrainDiffusionTransformerHybridWorkspace(BaseWorkspace):
             output_dir=self.output_dir)
         assert isinstance(env_runner, BaseImageRunner)
 
-        # configure logging
-        wandb_run = wandb.init(
-            dir=str(self.output_dir),
-            config=OmegaConf.to_container(cfg, resolve=True),
-            **cfg.logging
-        )
-        wandb.config.update(
-            {
-                "output_dir": self.output_dir,
-            }
-        )
-
         # configure checkpoint
         topk_manager = TopKCheckpointManager(
             save_dir=os.path.join(self.output_dir, 'checkpoints'),
@@ -148,140 +242,141 @@ class TrainDiffusionTransformerHybridWorkspace(BaseWorkspace):
 
         # training loop
         log_path = os.path.join(self.output_dir, 'logs.json.txt')
-        with JsonLogger(log_path) as json_logger:
-            for local_epoch_idx in range(cfg.training.num_epochs):
-                step_log = dict()
-                # ========= train for this epoch ==========
-                train_losses = list()
-                with tqdm.tqdm(train_dataloader, desc=f"Training epoch {self.epoch}",
-                        leave=False, mininterval=cfg.training.tqdm_interval_sec) as tepoch:
-                    for batch_idx, batch in enumerate(tepoch):
-                        # device transfer
-                        batch = dict_apply(batch, lambda x: x.to(device, non_blocking=True))
-                        if train_sampling_batch is None:
-                            train_sampling_batch = batch
-
-                        # compute loss
-                        raw_loss = self.model.compute_loss(batch)
-                        loss = raw_loss / cfg.training.gradient_accumulate_every
-                        loss.backward()
-
-                        # step optimizer
-                        if self.global_step % cfg.training.gradient_accumulate_every == 0:
-                            self.optimizer.step()
-                            self.optimizer.zero_grad()
-                            lr_scheduler.step()
-
-                        # update ema
-                        if cfg.training.use_ema:
-                            ema.step(self.model)
-
-                        # logging
-                        raw_loss_cpu = raw_loss.item()
-                        tepoch.set_postfix(loss=raw_loss_cpu, refresh=False)
-                        train_losses.append(raw_loss_cpu)
-                        step_log = {
-                            'train_loss': raw_loss_cpu,
-                            'global_step': self.global_step,
-                            'epoch': self.epoch,
-                            'lr': lr_scheduler.get_last_lr()[0]
-                        }
-
-                        is_last_batch = (batch_idx == (len(train_dataloader)-1))
-                        if not is_last_batch:
-                            # log of last step is combined with validation and rollout
-                            wandb_run.log(step_log, step=self.global_step)
-                            json_logger.log(step_log)
-                            self.global_step += 1
-
-                        if (cfg.training.max_train_steps is not None) \
-                            and batch_idx >= (cfg.training.max_train_steps-1):
-                            break
-
-                # at the end of each epoch
-                # replace train_loss with epoch average
-                train_loss = np.mean(train_losses)
-                step_log['train_loss'] = train_loss
-
-                # ========= eval for this epoch ==========
-                policy = self.model
-                if cfg.training.use_ema:
-                    policy = self.ema_model
-                policy.eval()
-
-                # run rollout
-                if (self.epoch % cfg.training.rollout_every) == 0:
-                    runner_log = env_runner.run(policy)
-                    # log all
-                    step_log.update(runner_log)
-
-                # run validation
-                if (self.epoch % cfg.training.val_every) == 0:
-                    with torch.no_grad():
-                        val_losses = list()
-                        with tqdm.tqdm(val_dataloader, desc=f"Validation epoch {self.epoch}",
-                                leave=False, mininterval=cfg.training.tqdm_interval_sec) as tepoch:
-                            for batch_idx, batch in enumerate(tepoch):
-                                batch = dict_apply(batch, lambda x: x.to(device, non_blocking=True))
-                                loss = self.model.compute_loss(batch)
-                                val_losses.append(loss)
-                                if (cfg.training.max_val_steps is not None) \
-                                    and batch_idx >= (cfg.training.max_val_steps-1):
-                                    break
-                        if len(val_losses) > 0:
-                            val_loss = torch.mean(torch.tensor(val_losses)).item()
-                            # log epoch average validation loss
-                            step_log['val_loss'] = val_loss
-
-                # run diffusion sampling on a training batch
-                if (self.epoch % cfg.training.sample_every) == 0:
-                    with torch.no_grad():
-                        # sample trajectory from training set, and evaluate difference
-                        batch = dict_apply(train_sampling_batch, lambda x: x.to(device, non_blocking=True))
-                        obs_dict = batch['obs']
-                        gt_action = batch['action']
-
-                        result = policy.predict_action(obs_dict)
-                        pred_action = result['action_pred']
-                        mse = torch.nn.functional.mse_loss(pred_action, gt_action)
-                        step_log['train_action_mse_error'] = mse.item()
-                        del batch
-                        del obs_dict
-                        del gt_action
-                        del result
-                        del pred_action
-                        del mse
-
-                # checkpoint
-                if (self.epoch % cfg.training.checkpoint_every) == 0:
-                    # checkpointing
-                    if cfg.checkpoint.save_last_ckpt:
-                        self.save_checkpoint()
-                    if cfg.checkpoint.save_last_snapshot:
-                        self.save_snapshot()
-
-                    # sanitize metric names
-                    metric_dict = dict()
-                    for key, value in step_log.items():
-                        new_key = key.replace('/', '_')
-                        metric_dict[new_key] = value
-
-                    # We can't copy the last checkpoint here
-                    # since save_checkpoint uses threads.
-                    # therefore at this point the file might have been empty!
-                    topk_ckpt_path = topk_manager.get_ckpt_path(metric_dict)
-
-                    if topk_ckpt_path is not None:
-                        self.save_checkpoint(path=topk_ckpt_path)
-                # ========= eval end for this epoch ==========
-                policy.train()
-
-                # end of epoch
-                # log of last step is combined with validation and rollout
-                wandb_run.log(step_log, step=self.global_step)
-                json_logger.log(step_log)
-                self.global_step += 1
-                self.epoch += 1
+        with logging_backend_session(cfg=cfg, output_dir=self.output_dir) as logging_backend:
+            with JsonLogger(log_path) as json_logger:
+                for local_epoch_idx in range(cfg.training.num_epochs):
+                    step_log = dict()
+                    # ========= train for this epoch ==========
+                    train_losses = list()
+                    with tqdm.tqdm(train_dataloader, desc=f"Training epoch {self.epoch}",
+                            leave=False, mininterval=cfg.training.tqdm_interval_sec) as tepoch:
+                        for batch_idx, batch in enumerate(tepoch):
+                            # device transfer
+                            batch = dict_apply(batch, lambda x: x.to(device, non_blocking=True))
+                            if train_sampling_batch is None:
+                                train_sampling_batch = batch
+
+                            # compute loss
+                            raw_loss = self.model.compute_loss(batch)
+                            loss = raw_loss / cfg.training.gradient_accumulate_every
+                            loss.backward()
+
+                            # step optimizer
+                            if self.global_step % cfg.training.gradient_accumulate_every == 0:
+                                self.optimizer.step()
+                                self.optimizer.zero_grad()
+                                lr_scheduler.step()
+
+                            # update ema
+                            if cfg.training.use_ema:
+                                ema.step(self.model)
+
+                            # logging
+                            raw_loss_cpu = raw_loss.item()
+                            tepoch.set_postfix(loss=raw_loss_cpu, refresh=False)
+                            train_losses.append(raw_loss_cpu)
+                            step_log = {
+                                'train_loss': raw_loss_cpu,
+                                'global_step': self.global_step,
+                                'epoch': self.epoch,
+                                'lr': lr_scheduler.get_last_lr()[0]
+                            }
+
+                            is_last_batch = (batch_idx == (len(train_dataloader)-1))
+                            if not is_last_batch:
+                                # log of last step is combined with validation and rollout
+                                logging_backend.log(step_log, step=self.global_step)
+                                json_logger.log(step_log)
+                                self.global_step += 1
+
+                            if (cfg.training.max_train_steps is not None) \
+                                and batch_idx >= (cfg.training.max_train_steps-1):
+                                break
+
+                    # at the end of each epoch
+                    # replace train_loss with epoch average
+                    train_loss = np.mean(train_losses)
+                    step_log['train_loss'] = train_loss
+
+                    # ========= eval for this epoch ==========
+                    policy = self.model
+                    if cfg.training.use_ema:
+                        policy = self.ema_model
+                    policy.eval()
+
+                    # run rollout
+                    if (self.epoch % cfg.training.rollout_every) == 0:
+                        runner_log = env_runner.run(policy)
+                        # log all
+                        step_log.update(runner_log)
+
+                    # run validation
+                    if (self.epoch % cfg.training.val_every) == 0:
+                        with torch.no_grad():
+                            val_losses = list()
+                            with tqdm.tqdm(val_dataloader, desc=f"Validation epoch {self.epoch}",
+                                    leave=False, mininterval=cfg.training.tqdm_interval_sec) as tepoch:
+                                for batch_idx, batch in enumerate(tepoch):
+                                    batch = dict_apply(batch, lambda x: x.to(device, non_blocking=True))
+                                    loss = self.model.compute_loss(batch)
+                                    val_losses.append(loss)
+                                    if (cfg.training.max_val_steps is not None) \
+                                        and batch_idx >= (cfg.training.max_val_steps-1):
+                                        break
+                            if len(val_losses) > 0:
+                                val_loss = torch.mean(torch.tensor(val_losses)).item()
+                                # log epoch average validation loss
+                                step_log['val_loss'] = val_loss
+
+                    # run diffusion sampling on a training batch
+                    if (self.epoch % cfg.training.sample_every) == 0:
+                        with torch.no_grad():
+                            # sample trajectory from training set, and evaluate difference
+                            batch = dict_apply(train_sampling_batch, lambda x: x.to(device, non_blocking=True))
+                            obs_dict = batch['obs']
+                            gt_action = batch['action']
+
+                            result = policy.predict_action(obs_dict)
+                            pred_action = result['action_pred']
+                            mse = torch.nn.functional.mse_loss(pred_action, gt_action)
+                            step_log['train_action_mse_error'] = mse.item()
+                            del batch
+                            del obs_dict
+                            del gt_action
+                            del result
+                            del pred_action
+                            del mse
+
+                    # checkpoint
+                    if (self.epoch % cfg.training.checkpoint_every) == 0:
+                        # checkpointing
+                        if cfg.checkpoint.save_last_ckpt:
+                            self.save_checkpoint()
+                        if cfg.checkpoint.save_last_snapshot:
+                            self.save_snapshot()
+
+                        # sanitize metric names
+                        metric_dict = dict()
+                        for key, value in step_log.items():
+                            new_key = key.replace('/', '_')
+                            metric_dict[new_key] = value
+
+                        # We can't copy the last checkpoint here
+                        # since save_checkpoint uses threads.
+                        # therefore at this point the file might have been empty!
+                        topk_ckpt_path = topk_manager.get_ckpt_path(metric_dict)
+
+                        if topk_ckpt_path is not None:
+                            self.save_checkpoint(path=topk_ckpt_path)
+                    # ========= eval end for this epoch ==========
+                    policy.train()
+
+                    # end of epoch
+                    # log of last step is combined with validation and rollout
+                    logging_backend.log(step_log, step=self.global_step)
+                    json_logger.log(step_log)
+                    self.global_step += 1
+                    self.epoch += 1
 
 @hydra.main(
     version_base=None,
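A minimal usage sketch of the new session helper, assuming swanlab is installed and using hypothetical cfg values (not taken from this commit). init_logging_backend maps mode 'online' to SwanLab's 'cloud' mode, and the context manager guarantees finish() runs even when the body raises, without masking the original exception if finish() itself also fails:

    from omegaconf import OmegaConf
    from diffusion_policy.workspace.train_diffusion_transformer_hybrid_workspace import (
        logging_backend_session)

    cfg = OmegaConf.create({'logging': {
        'backend': 'swanlab',
        'project': 'pusht',     # hypothetical project/run names
        'name': 'dit-demo',
        'group': None,
        'tags': [],
        'id': None,
        'resume': False,
        'mode': 'online',       # forwarded to swanlab.init as mode='cloud'
    }})

    with logging_backend_session(cfg=cfg, output_dir='data/outputs/demo') as backend:
        backend.log({'train_loss': 0.12}, step=0)
    # backend.finish() has run at this point, success or failure.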
image_pusht_diffusion_policy_dit.yaml (new file, 28 lines)
@@ -0,0 +1,28 @@
+defaults:
+  - diffusion_policy/config/train_diffusion_transformer_hybrid_workspace@_here_
+  - override /diffusion_policy/config/task@task: pusht_image
+  - _self_
+
+exp_name: pusht_image_dit
+
+policy:
+  _target_: diffusion_policy.policy.diffusion_transformer_hybrid_image_policy.DiffusionTransformerHybridImagePolicy
+
+logging:
+  backend: swanlab
+  mode: online
+  tags: ["${name}", "${task_name}", "${exp_name}", "swanlab"]
+  id: ${now:%Y%m%d%H%M%S}_${name}_${task_name}
+  group: ${exp_name}
+
+dataloader:
+  num_workers: 0
+
+val_dataloader:
+  num_workers: 0
+
+task:
+  env_runner:
+    n_envs: 1
+    n_test_vis: 0
+    n_train_vis: 0
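How the workspace reads this file's new key, sketched outside hydra: OmegaConf.select falls back to 'wandb' whenever logging.backend is absent, which is what keeps the pre-existing configs on the legacy wandb path with no changes:

    from omegaconf import OmegaConf

    cfg = OmegaConf.load('image_pusht_diffusion_policy_dit.yaml')
    assert OmegaConf.select(cfg, 'logging.backend', default='wandb') == 'swanlab'

    # A config written before this commit has no 'backend' key:
    legacy = OmegaConf.create({'logging': {'mode': 'online'}})
    assert OmegaConf.select(legacy, 'logging.backend', default='wandb') == 'wandb'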
@@ -36,3 +36,4 @@ av==14.0.1
 pygame==2.5.2
 robomimic==0.2.0
 opencv-python-headless==4.10.0.84
+swanlab
tests/test_pusht_image_runner_metrics.py (new file, 110 lines)
@@ -0,0 +1,110 @@
+import pathlib
+import sys
+
+import gym
+from gym import spaces
+import numpy as np
+import pytest
+import torch
+
+ROOT_DIR = pathlib.Path(__file__).resolve().parents[1]
+if str(ROOT_DIR) not in sys.path:
+    sys.path.append(str(ROOT_DIR))
+
+import diffusion_policy.env_runner.pusht_image_runner as runner_module
+from diffusion_policy.env_runner.pusht_image_runner import summarize_rollout_metrics
+
+
+class FakePushTImageEnv(gym.Env):
+    metadata = {'render.modes': ['rgb_array']}
+
+    def __init__(self, legacy=False, render_size=96):
+        del legacy, render_size
+        self.observation_space = spaces.Dict({
+            'image': spaces.Box(low=0, high=255, shape=(3, 4, 4), dtype=np.uint8),
+        })
+        self.action_space = spaces.Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32)
+        self.seed_value = 0
+        self.step_count = 0
+
+    def seed(self, seed=None):
+        self.seed_value = 0 if seed is None else seed
+
+    def reset(self):
+        self.step_count = 0
+        return {'image': np.zeros((3, 4, 4), dtype=np.uint8)}
+
+    def step(self, action):
+        del action
+        self.step_count += 1
+        reward = 0.1 if self.seed_value < 10000 else 0.9
+        done = self.step_count >= 1
+        obs = {'image': np.full((3, 4, 4), self.step_count, dtype=np.uint8)}
+        return obs, reward, done, {}
+
+    def render(self, *args, **kwargs):
+        raise AssertionError('render should not be called for scalar-only PushT image rollouts')
+
+
+class FakePolicy:
+    device = torch.device('cpu')
+    dtype = torch.float32
+
+    def reset(self):
+        return None
+
+    def predict_action(self, obs_dict):
+        n_envs = next(iter(obs_dict.values())).shape[0]
+        return {
+            'action': torch.zeros((n_envs, 2, 2), dtype=torch.float32),
+        }
+
+
+def test_summarize_rollout_metrics_keeps_scalar_rewards_renames_means_and_omits_videos():
+    log_data = summarize_rollout_metrics(
+        env_seeds=[11, 12, 101],
+        env_prefixs=['train/', 'train/', 'test/'],
+        all_rewards=[
+            [0.2, 0.8],
+            [0.1, 0.4],
+            [0.5, 0.9],
+        ],
+        all_video_paths=[
+            '/tmp/train-11.mp4',
+            '/tmp/train-12.mp4',
+            '/tmp/test-101.mp4',
+        ],
+    )
+
+    assert log_data['train/sim_max_reward_11'] == 0.8
+    assert log_data['train/sim_max_reward_12'] == 0.4
+    assert log_data['test/sim_max_reward_101'] == 0.9
+    assert log_data['train_mean_score'] == pytest.approx(0.6)
+    assert log_data['test_mean_score'] == pytest.approx(0.9)
+    assert not any(key.startswith('train/sim_video_') for key in log_data)
+    assert not any(key.startswith('test/sim_video_') for key in log_data)
+
+
+def test_runner_ignores_vis_flags_and_never_emits_sim_videos(tmp_path, monkeypatch):
+    monkeypatch.setattr(runner_module, 'PushTImageEnv', FakePushTImageEnv)
+
+    runner = runner_module.PushTImageRunner(
+        output_dir=tmp_path,
+        n_train=1,
+        n_train_vis=1,
+        n_test=1,
+        n_test_vis=1,
+        n_envs=2,
+        max_steps=2,
+        n_obs_steps=2,
+        n_action_steps=2,
+        tqdm_interval_sec=0.0,
+    )
+
+    log_data = runner.run(FakePolicy())
+
+    assert log_data['train/sim_max_reward_0'] == pytest.approx(0.1)
+    assert log_data['test/sim_max_reward_10000'] == pytest.approx(0.9)
+    assert log_data['train_mean_score'] == pytest.approx(0.1)
+    assert log_data['test_mean_score'] == pytest.approx(0.9)
+    assert not any('sim_video' in key for key in log_data)
tests/test_train_diffusion_transformer_workspace_logging.py (new file, 198 lines)
@@ -0,0 +1,198 @@
+import importlib
+import pathlib
+import sys
+
+import pytest
+from omegaconf import OmegaConf
+
+ROOT_DIR = pathlib.Path(__file__).resolve().parents[1]
+if str(ROOT_DIR) not in sys.path:
+    sys.path.append(str(ROOT_DIR))
+
+MODULE_NAME = 'diffusion_policy.workspace.train_diffusion_transformer_hybrid_workspace'
+
+
+def load_workspace_module(monkeypatch, *, wandb_missing=False):
+    sys.modules.pop(MODULE_NAME, None)
+    if wandb_missing:
+        monkeypatch.setitem(sys.modules, 'wandb', None)
+    return importlib.import_module(MODULE_NAME)
+
+
+def test_init_logger_uses_swanlab_backend_mapping_without_loading_wandb(tmp_path, monkeypatch):
+    workspace_module = load_workspace_module(monkeypatch, wandb_missing=True)
+    events = []
+
+    class FakeRun:
+        def log(self, payload, step=None):
+            events.append(('log', payload, step))
+
+        def finish(self):
+            events.append(('finish',))
+
+    class FakeSwanLab:
+        def init(self, **kwargs):
+            events.append(('init', kwargs))
+            return FakeRun()
+
+    monkeypatch.setattr(workspace_module, '_load_swanlab', lambda: FakeSwanLab())
+    monkeypatch.setattr(
+        workspace_module,
+        '_load_wandb',
+        lambda: pytest.fail('wandb should not be loaded for the SwanLab backend'),
+    )
+
+    cfg = OmegaConf.create({
+        'logging': {
+            'backend': 'swanlab',
+            'project': 'demo-project',
+            'name': 'demo-run',
+            'group': 'demo-group',
+            'tags': ['pusht', 'dit'],
+            'id': 'run-123',
+            'resume': True,
+            'mode': 'online',
+        }
+    })
+
+    logger = workspace_module.init_logging_backend(cfg=cfg, output_dir=tmp_path)
+    logger.log({'metric': 1.0}, step=7)
+    logger.finish()
+
+    assert events[0][0] == 'init'
+    init_kwargs = events[0][1]
+    assert init_kwargs['project'] == 'demo-project'
+    assert init_kwargs['experiment_name'] == 'demo-run'
+    assert init_kwargs['group'] == 'demo-group'
+    assert init_kwargs['tags'] == ['pusht', 'dit']
+    assert init_kwargs['id'] == 'run-123'
+    assert init_kwargs['resume'] is True
+    assert init_kwargs['mode'] == 'cloud'
+    assert init_kwargs['logdir'] == str(tmp_path / 'swanlog')
+    assert ('log', {'metric': 1.0}, 7) in events
+    assert events.count(('finish',)) == 1
+
+
+def test_init_logger_defaults_to_legacy_wandb_path_when_backend_missing(tmp_path, monkeypatch):
+    workspace_module = load_workspace_module(monkeypatch)
+    events = []
+
+    class FakeRun:
+        def log(self, payload, step=None):
+            events.append(('log', payload, step))
+
+        def finish(self):
+            events.append(('finish',))
+
+    class FakeConfig:
+        def update(self, payload):
+            events.append(('config.update', payload))
+
+    class FakeWandb:
+        def __init__(self):
+            self.config = FakeConfig()
+
+        def init(self, **kwargs):
+            events.append(('init', kwargs))
+            return FakeRun()
+
+    monkeypatch.setattr(workspace_module, '_load_wandb', lambda: FakeWandb())
+
+    cfg = OmegaConf.create({
+        'logging': {
+            'project': 'demo-project',
+            'name': 'demo-run',
+            'group': None,
+            'tags': ['shared'],
+            'id': None,
+            'resume': True,
+            'mode': 'online',
+        }
+    })
+
+    logger = workspace_module.init_logging_backend(cfg=cfg, output_dir=tmp_path)
+    logger.log({'metric': 2.0}, step=3)
+    logger.finish()
+
+    assert events[0][0] == 'init'
+    init_kwargs = events[0][1]
+    assert init_kwargs['dir'] == str(tmp_path)
+    assert init_kwargs['project'] == 'demo-project'
+    assert init_kwargs['name'] == 'demo-run'
+    assert init_kwargs['mode'] == 'online'
+    assert ('config.update', {'output_dir': str(tmp_path)}) in events
+    assert ('log', {'metric': 2.0}, 3) in events
+    assert events.count(('finish',)) == 1
+
+
+def test_init_logger_rejects_unknown_backends(tmp_path, monkeypatch):
+    workspace_module = load_workspace_module(monkeypatch)
+    cfg = OmegaConf.create({
+        'logging': {
+            'backend': 'tensorboard',
+            'project': 'demo-project',
+            'name': 'demo-run',
+            'mode': 'offline',
+        }
+    })
+
+    with pytest.raises(ValueError, match='Unknown logging backend'):
+        workspace_module.init_logging_backend(cfg=cfg, output_dir=tmp_path)
+
+
+def test_logging_backend_session_preserves_primary_exception_when_finish_fails(tmp_path, monkeypatch):
+    workspace_module = load_workspace_module(monkeypatch)
+    events = []
+
+    class FakeBackend:
+        def log(self, payload, step=None):
+            events.append(('log', payload, step))
+
+        def finish(self):
+            events.append(('finish',))
+            raise RuntimeError('finish boom')
+
+    monkeypatch.setattr(
+        workspace_module,
+        'init_logging_backend',
+        lambda cfg, output_dir: FakeBackend(),
+    )
+
+    cfg = OmegaConf.create({'logging': {'mode': 'offline'}})
+
+    with pytest.raises(ValueError, match='primary boom'):
+        with workspace_module.logging_backend_session(cfg=cfg, output_dir=tmp_path) as logger:
+            logger.log({'metric': 6.0}, step=12)
+            raise ValueError('primary boom')
+
+    assert ('log', {'metric': 6.0}, 12) in events
+    assert events.count(('finish',)) == 1
+
+
+def test_logging_backend_session_finishes_on_exception(tmp_path, monkeypatch):
+    workspace_module = load_workspace_module(monkeypatch)
+    events = []
+
+    class FakeBackend:
+        def log(self, payload, step=None):
+            events.append(('log', payload, step))
+
+        def finish(self):
+            events.append(('finish',))
+
+    monkeypatch.setattr(
+        workspace_module,
+        'init_logging_backend',
+        lambda cfg, output_dir: FakeBackend(),
+    )
+
+    cfg = OmegaConf.create({'logging': {'mode': 'offline'}})
+
+    with pytest.raises(RuntimeError, match='boom'):
+        with workspace_module.logging_backend_session(cfg=cfg, output_dir=tmp_path) as logger:
+            logger.log({'metric': 5.0}, step=11)
+            raise RuntimeError('boom')
+
+    assert ('log', {'metric': 5.0}, 11) in events
+    assert events.count(('finish',)) == 1