optuna
Implement parallelism at user-level
Motivation
The n_jobs feature is planned to be deprecated because its internal threading can clash with user-level parallelism. Whatever user-level parallelism replaces it will likely need a more complete API that lets agents check on, and respond to, the status of an ongoing study.
Description
The proposed code snippet (see Additional context below) shows how to use Study.ask() and Study.tell() to run workers as asyncio tasks. These tasks act as agents that interact with the Study in parallel. Using aiowire.EventLoop also allows them to spawn further agents.
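To make the pattern concrete without depending on optuna or aiowire, here is a minimal standard-library sketch. ToyStudy, run_one, worker and optimize are hypothetical names invented for this illustration; ToyStudy only mimics the shape of an ask/tell interface with random search and is not optuna's API. The point is the locking discipline: ask and tell are serialized, while the objective runs unlocked in a worker thread.

```python
import asyncio
import random
import threading
from concurrent.futures import ThreadPoolExecutor


class ToyStudy:
    """Hypothetical stand-in with an ask/tell interface (not optuna's API)."""

    def __init__(self, seed=0):
        self._rng = random.Random(seed)
        self.results = []  # (x, value) pairs recorded by tell()

    def ask(self):
        # Propose a new parameter value by plain random search.
        return self._rng.uniform(-10.0, 10.0)

    def tell(self, x, value):
        self.results.append((x, value))

    @property
    def best(self):
        return min(self.results, key=lambda r: r[1])


def run_one(study, lock, objective):
    # Serialize access to the study; evaluate the objective outside the lock.
    with lock:
        x = study.ask()
    value = objective(x)
    with lock:
        study.tell(x, value)


async def worker(study, lock, objective, budget, executor):
    loop = asyncio.get_running_loop()
    # budget is only touched on the event-loop thread, so it needs no lock.
    while budget["left"] > 0:
        budget["left"] -= 1
        await loop.run_in_executor(executor, run_one, study, lock, objective)


async def optimize(objective, n_trials=20, n_workers=4):
    study, lock = ToyStudy(), threading.Lock()
    budget = {"left": n_trials}
    with ThreadPoolExecutor(max_workers=n_workers) as executor:
        await asyncio.gather(
            *(worker(study, lock, objective, budget, executor)
              for _ in range(n_workers))
        )
    return study


study = asyncio.run(optimize(lambda x: (x - 2) ** 2))
print(len(study.results), study.best)
```

Each worker claims one unit of the trial budget before awaiting, so the total number of trials never exceeds n_trials regardless of how many workers run.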
Alternatives (optional)
Related issues and discussion include #1766 and #3328. An alternative using joblib has also been proposed.
As noted in #3328, serialization of the objective function needs to be tested.
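A first check of whether an objective can be serialized might look like the sketch below. It uses only the standard pickle module; is_picklable is a hypothetical helper name, and a given parallel backend may use a different serializer with different limits, so this is a rough screening test under that assumption rather than a statement about what joblib accepts.

```python
import math
import pickle
from functools import partial


def is_picklable(obj):
    """Return True if obj survives a round trip through the standard pickle module."""
    try:
        pickle.loads(pickle.dumps(obj))
    except (pickle.PicklingError, AttributeError, TypeError):
        return False
    return True


# A function that can be imported by name is pickled by reference.
print(is_picklable(math.sqrt))
# A partial over an importable function is also fine.
print(is_picklable(partial(math.pow, 2.0)))
# A lambda has no importable name, so standard pickle rejects it.
print(is_picklable(lambda trial: 0.0))
```

In practice this suggests defining objectives at module level and binding extra arguments with functools.partial rather than with lambdas or closures when trials may run in other processes.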
Additional context (optional)
""" Method for executing user-managed trials within an aiowire event loop.
"""
import asyncio
import gc
import time
import threading
from functools import partial, wraps
from typing import Tuple

import optuna
from optuna.trial import TrialState
from aiowire import EventLoop

def as_nonblocking(func):
    """ Create an async version of func that does not block
        the main thread. Adapted from https://github.com/Tinche/aiofiles
    """
    @wraps(func)
    async def run(*args, loop=None, executor=None, **kwargs):
        if loop is None:
            # We are inside a coroutine, so a running loop always exists.
            loop = asyncio.get_running_loop()
        pfunc = partial(func, *args, **kwargs)
        # executor=None selects the loop's default thread pool.
        return await loop.run_in_executor(executor, pfunc)
    return run
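
def run_trial(prog, study, func):
    # Study access is serialized through prog.mutex;
    # func itself runs outside the lock so that trials can overlap.
    with prog.mutex:
        trial = study.ask()
    state, ans = func(trial)
    with prog.mutex:
        study.tell(trial, ans, state=state)
    return ans

async_trial = as_nonblocking(run_trial)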

class Progress:
    """ This object holds progression counters
        for the study. Here, it prevents creating
        new trials after reaching num_trials.

        It also holds a mutex to ensure accesses to Study
        are serialized.

        Note: timeouts are handled by EventLoop().
    """
    def __init__(self, num_trials):
        self.launched = 0
        self.num_trials = num_trials
        self.mutex = threading.Lock()

    def incr(self):
        # Returns True once the trial budget is used up,
        # otherwise claims one trial and returns False.
        if self.launched >= self.num_trials:
            return True
        self.launched += 1
        return False

async def RunTrial(ev, prog, study, func):
    """ Replacement for optuna.Study.optimize.
        This Wire runs a trial, waits for it to complete,
        and handles the result cleanup.
    """
    if prog.incr():
        # Trial budget exhausted: stop without scheduling a follow-up.
        return
    try:
        await async_trial(prog, study, func)
    finally:
        gc.collect()
        #study._storage.remove_session()
    # Hand back this Wire with the same arguments to run the next trial.
    return (RunTrial, [prog, study, func])

def test_func(trial) -> "Tuple[TrialState, float | None]":
    # Returns (state, value); value is None when the trial is pruned.
    def f(x):
        return (x - 2) ** 2

    def df(x):
        return 2 * x - 4

    lr = trial.suggest_float("lr", 1e-5, 1e-1, log=True)

    # Iterative gradient descent objective function.
    x = 3  # Initial value.
    for step in range(128):
        y = f(x)

        trial.report(y, step=step)
        if trial.should_prune():
            print(f"Pruning trial {trial.number}")
            # Finish the trial with the pruned state.
            return TrialState.PRUNED, None

        gy = df(x)
        x -= gy * lr
    time.sleep(1)
    print(f"Completed trial {trial.number}")
    return TrialState.COMPLETE, y

async def run_all(num_workers=1, num_trials=10):
    study = optuna.create_study()
    prog = Progress(num_trials)
    async with EventLoop(4) as ev:
        for worker in range(num_workers):
            ev.start((RunTrial, [prog, study, test_func]))
    print("Complete!")
    with prog.mutex:
        print(f"Best value = {study.best_value}")
        print(f"Best params = {study.best_params}")

if __name__ == "__main__":
    asyncio.run(run_all(2))
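The executor wrapper used in the snippet can be exercised on its own. The following sketch restates as_nonblocking in a self-contained form and runs three blocking calls concurrently; slow_square and main are hypothetical helpers added only for this check, and time.sleep stands in for a blocking objective function.

```python
import asyncio
import time
from functools import partial, wraps


def as_nonblocking(func):
    """Wrap a blocking function so that it runs in an executor thread."""
    @wraps(func)
    async def run(*args, loop=None, executor=None, **kwargs):
        if loop is None:
            loop = asyncio.get_running_loop()
        return await loop.run_in_executor(executor, partial(func, *args, **kwargs))
    return run


def slow_square(x, delay=0.3):
    time.sleep(delay)  # stands in for a blocking objective function
    return x * x


async def main():
    async_square = as_nonblocking(slow_square)
    start = time.perf_counter()
    # The three calls run in separate threads of the default executor.
    results = await asyncio.gather(*(async_square(i) for i in range(3)))
    return results, time.perf_counter() - start


results, elapsed = asyncio.run(main())
print(results, f"{elapsed:.2f}s")
```

If the calls ran one after another, the elapsed time would be at least three times the delay; with the wrapper it should stay close to a single delay, since the event loop is free while the threads sleep.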
Hi,
Could github.com/colesbury/nogil be related? It is supposed to fix the current issue with the n_jobs feature.
We have decided not to remove n_jobs. See https://github.com/optuna/optuna/pull/3173 for more details.
This issue has not seen any recent activity.
This issue was closed automatically because it had not seen any recent activity. If you want to discuss it, you can reopen it freely.