`AbstractRunner`, `LocalRunner`, `ParallelLocalRunner`

Open martins0n opened this issue 1 year ago • 0 comments

🚀 Feature Request

Create runner objects that control how a computation is executed.

A runner receives a function func together with its args and kwargs via __call__.

  • etna.auto.runner.LocalRunner would simply call the passed function with the given args and kwargs

  • etna.auto.runner.ParallelLocalRunner would call the given function n_jobs times in parallel via joblib.Parallel

Proposal

etna.auto.runner.AbstractRunner:

from abc import ABC, abstractmethod


class AbstractRunner(ABC):
    """Abstract class for Runner."""

    @abstractmethod
    def __call__(self, func, *args, **kwargs):
        """Call given ``func`` in specified environment with ``*args`` and ``**kwargs``."""
        pass
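
etna.auto.runner.LocalRunner is not spelled out in this issue. A minimal sketch, assuming it only forwards the call in the current process; returning the result is an assumption, since the description above only says the function is called. AbstractRunner is repeated so the block runs on its own:

```python
from abc import ABC, abstractmethod


class AbstractRunner(ABC):
    """Abstract class for Runner (repeated from the proposal above)."""

    @abstractmethod
    def __call__(self, func, *args, **kwargs):
        """Call given ``func`` in specified environment with ``*args`` and ``**kwargs``."""


class LocalRunner(AbstractRunner):
    """Sketch: run ``func`` in the current process."""

    def __call__(self, func, *args, **kwargs):
        # Assumption: the result of ``func`` is handed back to the caller.
        return func(*args, **kwargs)
```

With this sketch, LocalRunner()(func, 1, b=2) behaves like func(1, b=2).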

etna.auto.runner.ParallelLocalRunner:

from typing import Optional

import dill
from joblib import Parallel, delayed


class ParallelLocalRunner(AbstractRunner):
    """ParallelLocalRunner for multiple parallel runs with joblib.

    Notes
    -----
    Global objects may behave differently under parallel execution, because the way a new process is started is platform dependent.
    Make sure new processes are started with ``fork`` via ``multiprocessing.set_start_method``.
    If that is not possible, try defining all globals before the ``if __name__ == "__main__"`` scope.
    """

    def __init__(
        self,
        n_jobs: int = 1,
        backend: str = "multiprocessing",
        mmap_mode: str = "c",
        joblib_params: Optional[dict] = None,
    ):
        """Init ParallelLocalRunner.

        Parameters
        ----------
        n_jobs:
            number of parallel jobs to use
        backend:
            joblib backend to use
        mmap_mode:
            joblib mmap mode
        joblib_params:
            joblib additional params
        """
        self.n_jobs = n_jobs
        self.backend = backend
        self.mmap_mode = mmap_mode
        self.joblib_params = {} if joblib_params is None else joblib_params

    def __call__(self, func, *args, **kwargs):
        """Call given ``func`` with Joblib and ``*args`` and ``**kwargs``."""
        payload = dill.dumps((func, args, kwargs))
        _ = Parallel(n_jobs=self.n_jobs, backend=self.backend, mmap_mode=self.mmap_mode, **self.joblib_params)(
            delayed(run_dill_encoded)(payload) for _ in range(self.n_jobs)
        )
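
run_dill_encoded is used above but not defined in this issue. A plausible sketch, assuming it simply reverses the dill.dumps((func, args, kwargs)) call and invokes the function. The fallback to pickle is only there so the sketch runs without dill installed; it is not part of the proposal:

```python
try:
    import dill as serializer  # third-party, the serializer used in the proposal
except ImportError:  # assumption: fall back so the sketch still runs
    import pickle as serializer


def run_dill_encoded(payload: bytes):
    """Decode a ``(func, args, kwargs)`` payload and call ``func``."""
    func, args, kwargs = serializer.loads(payload)
    return func(*args, **kwargs)
```

Encoding the call with dill before handing it to joblib presumably lets lambdas and locally defined functions reach the worker processes, which plain pickle cannot serialize.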

Test cases

  • Test LocalRunner.__call__
  • Test ParallelLocalRunner.__call__: for example, have the called function create a file and check that n_jobs files have been created
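
The file-counting idea could be sketched as below. The helper names create_file and count_created_files are hypothetical, and the runner is any callable with the __call__(func, *args, **kwargs) signature from the proposal:

```python
import uuid
from pathlib import Path


def create_file(directory: str) -> None:
    """Payload for the runner: create one uniquely named empty file."""
    Path(directory, f"{uuid.uuid4().hex}.txt").touch()


def count_created_files(runner, directory: str) -> int:
    """Call ``create_file`` through ``runner`` and count the files in ``directory``."""
    runner(create_file, directory)
    return sum(1 for _ in Path(directory).iterdir())
```

In a test on an empty temporary directory, the expected count would be 1 for LocalRunner() and n_jobs for ParallelLocalRunner(n_jobs=n_jobs).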

Additional context

No response

martins0n, Aug 12 '22 16:08