added eval script and documentation
This commit is contained in:
64
eval.py
Normal file
64
eval.py
Normal file
@@ -0,0 +1,64 @@
|
||||
"""
|
||||
Usage:
|
||||
python eval.py --checkpoint data/image/pusht/diffusion_policy_cnn/train_0/checkpoints/latest.ckpt -o data/pusht_eval_output
|
||||
"""
|
||||
|
||||
import sys
|
||||
# use line-buffering for both stdout and stderr
|
||||
sys.stdout = open(sys.stdout.fileno(), mode='w', buffering=1)
|
||||
sys.stderr = open(sys.stderr.fileno(), mode='w', buffering=1)
|
||||
|
||||
import os
|
||||
import pathlib
|
||||
import click
|
||||
import hydra
|
||||
import torch
|
||||
import dill
|
||||
import wandb
|
||||
import json
|
||||
from diffusion_policy.workspace.base_workspace import BaseWorkspace
|
||||
|
||||
@click.command()
|
||||
@click.option('-c', '--checkpoint', required=True)
|
||||
@click.option('-o', '--output_dir', required=True)
|
||||
@click.option('-d', '--device', default='cuda:0')
|
||||
def main(checkpoint, output_dir, device):
|
||||
if os.path.exists(output_dir):
|
||||
click.confirm(f"Output path {output_dir} already exists! Overwrite?", abort=True)
|
||||
pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# load checkpoint
|
||||
payload = torch.load(open(checkpoint, 'rb'), pickle_module=dill)
|
||||
cfg = payload['cfg']
|
||||
cls = hydra.utils.get_class(cfg._target_)
|
||||
workspace = cls(cfg, output_dir=output_dir)
|
||||
workspace: BaseWorkspace
|
||||
workspace.load_payload(payload, exclude_keys=None, include_keys=None)
|
||||
|
||||
# get policy from workspace
|
||||
policy = workspace.model
|
||||
if cfg.training.use_ema:
|
||||
policy = workspace.ema_model
|
||||
|
||||
device = torch.device(device)
|
||||
policy.to(device)
|
||||
policy.eval()
|
||||
|
||||
# run eval
|
||||
env_runner = hydra.utils.instantiate(
|
||||
cfg.task.env_runner,
|
||||
output_dir=output_dir)
|
||||
runner_log = env_runner.run(policy)
|
||||
|
||||
# dump log to json
|
||||
json_log = dict()
|
||||
for key, value in runner_log.items():
|
||||
if isinstance(value, wandb.sdk.data_types.video.Video):
|
||||
json_log[key] = value._path
|
||||
else:
|
||||
json_log[key] = value
|
||||
out_path = os.path.join(output_dir, 'eval_log.json')
|
||||
json.dump(json_log, open(out_path, 'w'), indent=2, sort_keys=True)
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
Reference in New Issue
Block a user