
`Optuna` wrapper in `etna.auto.optuna` and custom optuna sampler `ConfigSampler`

Open martins0n opened this issue 2 years ago • 0 comments

🚀 Feature Request

Create a class etna.auto.optuna.Optuna that handles the logic around using Optuna:

  • the tune method runs study.optimize via the given runner

Copy etna.auto.optuna.sampler.ConfigSampler as is.

Proposal

Optuna:

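from enum import Enum
from typing import Any, Callable, Optional

import optuna


class OptunaDirection(str, Enum):
    """Optimization direction accepted by optuna.create_study."""

    maximize = "maximize"
    minimize = "minimize"


class Optuna:
    """Class that encapsulates work with Optuna."""

    def __init__(
        self,
        direction: OptunaDirection,
        sampler: optuna.samplers.BaseSampler,
        study_name: Optional[str] = None,
        # Pass e.g. optuna.storages.RDBStorage(...) so that parallel workers share one study;
        # a default instance here would be created once at import time and shared by all objects.
        storage: Optional[optuna.storages.BaseStorage] = None,
    ) -> None:
        self.sampler = sampler
        self.storage = storage
        self.study_name = study_name
        # Accepts an OptunaDirection member or the plain string "maximize"/"minimize".
        self.direction = OptunaDirection(direction)
        # load_if_exists=True lets several workers attach to the same named study.
        self.study = optuna.create_study(
            storage=self.storage,
            study_name=self.study_name,
            direction=self.direction.value,
            sampler=self.sampler,
            load_if_exists=True,
        )

    def tune(
        self,
        objective: Callable[[optuna.trial.Trial], float],
        n_trials: Optional[int] = None,
        timeout: Optional[int] = None,
        runner: Optional[Callable[..., Any]] = None,
        **kwargs: Any,
    ) -> optuna.Study:
        """Run study.optimize through the given runner and return the study."""
        if runner is None:
            # No runner given: optimize in the current process, as LocalRunner would.
            self.study.optimize(objective, n_trials=n_trials, timeout=timeout, **kwargs)
        else:
            # e.g. LocalRunner(...) or ParallelLocalRunner(...)
            runner(self.study.optimize, objective, n_trials=n_trials, timeout=timeout, **kwargs)
        return self.study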
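The tune method only hands study.optimize and its arguments to a runner, so a runner can be any callable of the form runner(func, *args, **kwargs). Below is a minimal stdlib-only sketch of the two runners named in the test cases, assuming that calling convention. It is not the etna implementation: the real LocalRunner and ParallelLocalRunner may have different constructors, and the thread pool here is an assumption standing in for whatever process-based backend they use.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, List


class LocalRunner:
    """Hypothetical sketch: run the function once in the current process."""

    def __call__(self, func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
        return func(*args, **kwargs)


class ParallelLocalRunner:
    """Hypothetical sketch: run the same function n_jobs times concurrently.

    Each job calls func independently, so for study.optimize the jobs only
    cooperate because they share one study object (or one storage backend).
    """

    def __init__(self, n_jobs: int = 2) -> None:
        self.n_jobs = n_jobs

    def __call__(self, func: Callable[..., Any], *args: Any, **kwargs: Any) -> List[Any]:
        with ThreadPoolExecutor(max_workers=self.n_jobs) as pool:
            futures = [pool.submit(func, *args, **kwargs) for _ in range(self.n_jobs)]
            # Collect results in submission order; re-raises any exception from a job.
            return [future.result() for future in futures]
```

Under this sketch, passing runner=ParallelLocalRunner(n_jobs=4) to tune would call study.optimize four times, so n_trials would apply to each job rather than to the study as a whole.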

ConfigSampler: you can get the implementation from the internal GitLab, in etna_utils.auto.optuna.confg_sampler.
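The ConfigSampler source lives in the internal repository and is not reproduced here. As a rough stdlib-only illustration of the behaviour the test cases rely on (a fixed pool of N configs, each evaluated once, with the search stopping once the pool is exhausted), here is a hypothetical sketch. The names config_hash, next_untried_config and run_pool are illustrative assumptions, not taken from etna_utils, and the sequential loop stands in for an Optuna study.

```python
import hashlib
import json
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple

Config = Dict[str, Any]


def config_hash(config: Config) -> str:
    """Stable hash of a JSON-serialisable config; key order does not matter."""
    payload = json.dumps(config, sort_keys=True)
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()


def next_untried_config(pool: List[Config], tried_hashes: Iterable[str]) -> Optional[Config]:
    """Return the first config in the pool whose hash has not been tried, or None."""
    tried = set(tried_hashes)
    for config in pool:
        if config_hash(config) not in tried:
            return config
    return None


def run_pool(pool: List[Config], objective: Callable[[Config], float]) -> List[Tuple[Config, float]]:
    """Evaluate each distinct config once, stopping as soon as the pool is exhausted."""
    tried: List[str] = []
    results: List[Tuple[Config, float]] = []
    while True:
        config = next_untried_config(pool, tried)
        if config is None:
            # Pool exhausted: analogous to the sampler stopping the study.
            break
        tried.append(config_hash(config))
        results.append((config, objective(config)))
    return results
```

With a single sequential worker this loop performs exactly len(pool) evaluations for a pool of distinct configs, matching the LocalRunner expectation in the test cases.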

Test cases

  • Create a toy objective with the default TPESampler; it should work as expected:
    • Try LocalRunner
    • Try ParallelLocalRunner
  • Create a toy objective for ConfigSampler with a pool of N pipelines:
    • With LocalRunner we should get exactly N runs
    • With ParallelLocalRunner we should get exactly N + n_jobs - 1 runs

Additional context

You can reuse code from etna_utils.auto.optuna.

martins0n, Aug 12 '22 15:08