code release
This commit is contained in:
121
ray_exec.py
Normal file
121
ray_exec.py
Normal file
@@ -0,0 +1,121 @@
|
||||
"""
|
||||
Usage:
|
||||
Training:
|
||||
python train.py --config-name=train_diffusion_lowdim_workspace -- logger.mode=online
|
||||
"""
|
||||
import os
|
||||
import ray
|
||||
import click
|
||||
|
||||
def worker_fn(command_args, data_src=None, unbuffer_python=False, use_shell=False):
|
||||
import os
|
||||
import subprocess
|
||||
import signal
|
||||
import time
|
||||
|
||||
# setup data symlink
|
||||
if data_src is not None:
|
||||
cwd = os.getcwd()
|
||||
src = os.path.expanduser(data_src)
|
||||
dst = os.path.join(cwd, 'data')
|
||||
try:
|
||||
os.symlink(src=src, dst=dst)
|
||||
except FileExistsError:
|
||||
# it's fine if it already exists
|
||||
pass
|
||||
|
||||
# run command
|
||||
process_env = os.environ.copy()
|
||||
if unbuffer_python:
|
||||
# disable stdout/stderr buffering for subprocess (if python)
|
||||
# to remove latency between print statement and receiving printed result
|
||||
process_env['PYTHONUNBUFFERED'] = 'TRUE'
|
||||
|
||||
# ray worker masks out Ctrl-C signal (ie SIGINT)
|
||||
# here we unblock this signal for the child process
|
||||
def preexec_function():
|
||||
import signal
|
||||
signal.pthread_sigmask(signal.SIG_UNBLOCK, {signal.SIGINT})
|
||||
|
||||
if use_shell:
|
||||
command_args = ' '.join(command_args)
|
||||
|
||||
# stdout passtrough to ray worker, which is then passed to ray driver
|
||||
process = subprocess.Popen(
|
||||
args=command_args,
|
||||
env=process_env,
|
||||
preexec_fn=preexec_function,
|
||||
shell=use_shell)
|
||||
|
||||
while process.poll() is None:
|
||||
try:
|
||||
# sleep to ensure that monitor thread can acquire gil
|
||||
# and raise KeyboardInterrupt here.
|
||||
time.sleep(0.01)
|
||||
except KeyboardInterrupt:
|
||||
process.send_signal(signal.SIGINT)
|
||||
print('SIGINT sent to subprocess')
|
||||
except Exception as e:
|
||||
process.terminate()
|
||||
raise e
|
||||
|
||||
if process.returncode not in (0, -2):
|
||||
print("Failed execution!")
|
||||
raise RuntimeError("Failed execution.")
|
||||
return process.returncode
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.option('--ray_address', '-ra', default='auto')
|
||||
@click.option('--num_cpus', '-nc', default=7, type=float)
|
||||
@click.option('--num_gpus', '-ng', default=1, type=float)
|
||||
@click.option('--max_retries', '-mr', default=0, type=int)
|
||||
@click.option('--data_src', '-d', default='./data', type=str)
|
||||
@click.option('--unbuffer_python', '-u', is_flag=True, default=False)
|
||||
@click.argument('command_args', nargs=-1, type=str)
|
||||
def main(ray_address,
|
||||
num_cpus, num_gpus, max_retries,
|
||||
data_src, unbuffer_python,
|
||||
command_args):
|
||||
# expand path
|
||||
if data_src is not None:
|
||||
data_src = os.path.abspath(os.path.expanduser(data_src))
|
||||
|
||||
# init ray
|
||||
root_dir = os.path.dirname(__file__)
|
||||
runtime_env = {
|
||||
'working_dir': root_dir,
|
||||
'excludes': ['.git']
|
||||
}
|
||||
ray.init(
|
||||
address=ray_address,
|
||||
runtime_env=runtime_env
|
||||
)
|
||||
# remote worker func
|
||||
worker_ray = ray.remote(worker_fn).options(
|
||||
num_cpus=num_cpus,
|
||||
num_gpus=num_gpus,
|
||||
max_retries=max_retries,
|
||||
# resources=resources,
|
||||
retry_exceptions=True
|
||||
)
|
||||
# run
|
||||
task_ref = worker_ray.remote(command_args, data_src, unbuffer_python)
|
||||
|
||||
try:
|
||||
# normal case
|
||||
result = ray.get(task_ref)
|
||||
print('Return code: ', result)
|
||||
except KeyboardInterrupt:
|
||||
# a KeyboardInterrupt will be raised in worker
|
||||
ray.cancel(task_ref, force=False)
|
||||
result = ray.get(task_ref)
|
||||
print('Return code: ', result)
|
||||
except Exception as e:
|
||||
# worker will be terminated
|
||||
ray.cancel(task_ref, force=True)
|
||||
raise e
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
Reference in New Issue
Block a user