Implement parallelism at user-level

Open frobnitzem opened this issue 2 years ago • 3 comments

Motivation

The n_jobs feature is planned to be deprecated because its internal threading can clash with user-level parallelism. The user-level parallelism that replaces it will likely need a more complete API that lets agents check on and respond to the status of an ongoing study.

Description

This proposed code snippet (in additional context below) shows how to leverage Study.ask() and Study.tell() to run workers as asyncio tasks. These asyncio tasks are agents that can interact with the Study in parallel. Using the aiowire.EventLoop also allows them to spawn further agents.
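
For readers who do not have aiowire installed, the ask-and-tell pattern described above can be sketched with plain asyncio. This is a minimal sketch under stated assumptions, not the proposal itself: ToyStudy and ToyTrial are hypothetical stand-ins that only mimic the shape of Study.ask() and Study.tell(), and optimize() and worker() are illustrative names. Unlike the snippet in this issue, it calls ask() and tell() on the event-loop thread only, so it needs no mutex; only the objective runs in the default thread pool.

```python
import asyncio
import random


class ToyTrial:
    """Hypothetical stand-in for a trial object (illustration only)."""

    def __init__(self, number):
        self.number = number
        self.params = {}

    def suggest_float(self, name, low, high):
        self.params[name] = random.uniform(low, high)
        return self.params[name]


class ToyStudy:
    """Hypothetical stand-in exposing an ask()/tell() pair."""

    def __init__(self):
        self._next = 0
        self.results = {}  # trial number -> objective value

    def ask(self):
        trial = ToyTrial(self._next)
        self._next += 1
        return trial

    def tell(self, trial, value):
        self.results[trial.number] = value


def objective(trial):
    # Blocking objective; it runs in a worker thread.
    x = trial.suggest_float("x", -10.0, 10.0)
    return (x - 2.0) ** 2


async def worker(study, func, budget):
    loop = asyncio.get_running_loop()
    while budget["left"] > 0:
        # No await between the check and the decrement, so workers
        # cannot over-claim trials.
        budget["left"] -= 1
        trial = study.ask()  # event-loop thread
        value = await loop.run_in_executor(None, func, trial)
        study.tell(trial, value)  # event-loop thread


async def optimize(study, func, n_trials, n_workers):
    budget = {"left": n_trials}
    await asyncio.gather(
        *(worker(study, func, budget) for _ in range(n_workers))
    )


if __name__ == "__main__":
    demo = ToyStudy()
    asyncio.run(optimize(demo, objective, n_trials=10, n_workers=2))
    print(f"Finished {len(demo.results)} trials")
```

The trade-off in this variant is that a slow ask() blocks the event loop while it runs, and a trial whose objective raises is never reported back through tell().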

Alternatives (optional)

Related issues and discussion include #1766 and #3328. An alternative using joblib has also been proposed.

As noted in #3328, serialization of the objective function needs to be tested.
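
One way to run such a test is a pickle round trip. The sketch below assumes the concern is pickle-based serialization, as process-based backends typically require; is_picklable() and make_closure() are hypothetical helper names, not part of Optuna or joblib.

```python
import math
import pickle


def is_picklable(obj):
    """Return True if obj survives a pickle round trip."""
    try:
        pickle.loads(pickle.dumps(obj))
    except (pickle.PicklingError, AttributeError, TypeError):
        return False
    return True


def make_closure(offset):
    # Functions defined inside another function are pickled by
    # qualified name, which cannot be looked up again on load.
    def objective(x):
        return (x - offset) ** 2

    return objective


if __name__ == "__main__":
    print(is_picklable(math.cos))
    print(is_picklable(make_closure(2)))
    print(is_picklable(lambda x: x))
```

Objectives defined at module level in an importable module are the usual safe choice; closures and lambdas generally need a different serializer or a refactor.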

Additional context (optional)

""" Method for executing user-managed trials within an aiowire event loop.
"""
import asyncio
import gc
import time
import threading
from functools import partial, wraps

from typing import Tuple

import optuna
from optuna.trial import TrialState

from aiowire import EventLoop

def as_nonblocking(func):
    """ Create an async version of func that does not block
        the main thread.  Code from https://github.com/Tinche/aiofiles
    """
    @wraps(func)
    async def run(*args, loop=None, executor=None, **kwargs):
        if loop is None:
            loop = asyncio.get_running_loop()
        pfunc = partial(func, *args, **kwargs)
        return await loop.run_in_executor(executor, pfunc)

    return run

def run_trial(prog, study, func):
    with prog.mutex:
        trial = study.ask()

    state, ans = func(trial)

    with prog.mutex:
        study.tell(trial, ans, state=state)
    return ans

async_trial = as_nonblocking(run_trial)

class Progress:
    """ This object holds progression counters
        for the study.  Here, it prevents creating
        new trials after reaching num_trials.

        It also holds a mutex to ensure accesses to Study
        are serialized.

        Note: timeouts are handled by EventLoop().
    """
    def __init__(self, num_trials):
        self.launched = 0
        self.num_trials = num_trials
        self.mutex = threading.Lock()
    def incr(self):
        if self.launched >= self.num_trials:
            return True
        self.launched += 1
        return False

async def RunTrial(ev, prog, study, func):
    """ Replacement for optuna.Study.optimize.
        This Wire runs a trial, waits for it to complete,
        and handles the result cleanup.
    """
    if prog.incr():
        return
    try:
        await async_trial(prog, study, func)
    finally:
        # Reclaim memory from the finished trial, even if it raised.
        gc.collect()
    #study._storage.remove_session()
    return (RunTrial, [prog, study, func])

def test_func(trial) -> Tuple[TrialState, ...]:
    def f(x):
        return (x - 2) ** 2
    def df(x):
        return 2 * x - 4

    lr = trial.suggest_float("lr", 1e-5, 1e-1, log=True)
    # Iterative gradient descent objective function.
    x = 3  # Initial value.
    for step in range(128):
        y = f(x)
        trial.report(y, step=step)
        if trial.should_prune():
            print(f"Pruning trial {trial.number}")
            # Finish the trial with the pruned state.
            return TrialState.PRUNED, None
        gy = df(x)
        x -= gy * lr
    time.sleep(1)
    print(f"Completed trial {trial.number}")
    return TrialState.COMPLETE, y

async def run_all(num_workers=1, num_trials=10):
    study = optuna.create_study()
    prog = Progress(num_trials)
    async with EventLoop(4) as ev:
        for _ in range(num_workers):
            ev.start((RunTrial, [prog, study, test_func]))
    print("Complete!")
    with prog.mutex:
        print(f"Best value = {study.best_value}")
        print(f"Best params = {study.best_params}")

if __name__=="__main__":
    asyncio.run( run_all(2) )

frobnitzem, May 03 '22 19:05

Hi, could github.com/colesbury/nogil be related? It is supposed to fix the current issue with the n_jobs feature.

Finebouche, May 09 '22 11:05

We have decided not to remove n_jobs. See https://github.com/optuna/optuna/pull/3173 for more details.

HideakiImamura, May 20 '22 05:05

This issue has not seen any recent activity.

github-actions[bot], Jun 05 '22 23:06

This issue was closed automatically because it had not seen any recent activity. If you want to discuss it, you can reopen it freely.

github-actions[bot], Sep 13 '22 23:09