
Wandb bugfix: handle pruned trials correctly

Open neel04 opened this pull request 1 year ago • 13 comments

Motivation

In the wandb callback, if someone raises optuna.exceptions.TrialPruned() to prune a trial, as is recommended, program execution is interrupted immediately and thus no value is reported for the trial.

This errors out later, as subsequent code applies functions such as len, which are invalid for NoneType.

Description of the changes

This PR tries to be minimally invasive (since I'm not familiar with this codebase). Simply put, if trial.values is None, it wraps it in a list of the appropriate length so that it can be processed further, and None values are logged to wandb, which can handle such datapoints appropriately.
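Roughly, the idea is the following (a sketch, not the exact diff; the helper name _values_or_nones is mine):

from typing import List, Optional

import optuna


def _values_or_nones(
    study: optuna.study.Study, trial: optuna.trial.FrozenTrial
) -> List[Optional[float]]:
    # trial.values is None when the trial was pruned (or failed) before
    # reporting a value; substitute one None per objective so downstream
    # code that calls len() keeps working, and wandb receives explicit
    # None datapoints it knows how to handle.
    if trial.values is not None:
        return list(trial.values)
    return [None] * len(study.directions)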

Note: There are some other changes here that were added automatically by the code formatters. LMK if I should remove those so they can be handled internally.

neel04 avatar May 13 '24 17:05 neel04

Could you revert the changes to non-wandb related files?

nzw0301 avatar May 14 '24 01:05 nzw0301

Is that fine? @nzw0301

neel04 avatar May 14 '24 12:05 neel04

Thanks. In addition, could you share a minimal reproducible script that raises the error you report? This info would be helpful for reviewers.

nzw0301 avatar May 14 '24 16:05 nzw0301

Thanks. In addition, could you share a minimal reproducible script that raises the error you report? This info would be helpful for reviewers.

Here's a repro. Make sure to run wandb offline first:

import optuna
from optuna.integration.wandb import WeightsAndBiasesCallback

def some_func(trial):
    # Some computation
    # ...

    if trial is not None:
        raise optuna.exceptions.TrialPruned()

    return 0

def objective(trial):
    x = trial.suggest_float("x", -3, 3)

    output = some_func(trial)

    return output

wandb_kwargs = {
    "project": "example",
    "anonymous": "allow",
    "entity": "test",
    "magic": True,
}

wandbc = WeightsAndBiasesCallback(
    metric_name='Train/acc',
    wandb_kwargs=wandb_kwargs,
    as_multirun=True
)

study = optuna.create_study()
study.optimize(
    objective,
    n_trials=5,  # bound the run; the error appears on the first pruned trial
    callbacks=[wandbc],
)

neel04 avatar May 14 '24 22:05 neel04

I think this happens when a trial has no intermediate value (i.e., trial.report is never called), which is unusual. So I'm not sure the shared script is the right usage of the TrialPruned exception. Could you elaborate on your use case?

nzw0301 avatar May 15 '24 13:05 nzw0301

I think this happens when a trial has no intermediate value (i.e., trial.report is never called), which is unusual. So I'm not sure the shared script is the right usage of the TrialPruned exception. Could you elaborate on your use case?

There are cases where one wishes to prune a trial before reporting any metric. Say a metric has to be reported every n steps, and I detect huge gradient fluctuations that invalidate the run due to bad warmup, or my loss suddenly goes NaN; in those cases, one would want to terminate the run immediately rather than waiting n more steps.
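For instance (a sketch; train_one_epoch is a placeholder for real training code):

import math

import optuna


def train_one_epoch(lr: float) -> float:
    # Stand-in for real training; pretend large learning rates diverge.
    return float("nan") if lr > 1e-2 else 1.0 / lr


def objective(trial: optuna.Trial) -> float:
    lr = trial.suggest_float("lr", 1e-5, 1e-1, log=True)
    loss = train_one_epoch(lr)
    if math.isnan(loss):
        # Terminate immediately: trial.report() has never been called,
        # so trial.values will be None in any completion callback.
        raise optuna.exceptions.TrialPruned()
    trial.report(loss, step=0)
    return loss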

Also, from a design perspective, it seems rather wonky to require reporting a metric at least once: if there's an error before a metric can be reported (such as triggering some misplaced assert), one would want just that specific run to crash or be pruned, not the entire optuna study.

So I don't think it's wise to have such a brittle system, especially in logging, where robustness is expected; otherwise a lot of $ can be wasted very quickly.

neel04 avatar May 15 '24 22:05 neel04

Thank you for your clarification. I also realised that a failed trial, with or without an intermediate value, has the same issue.
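E.g., reusing the study and wandbc from the repro above (sketch only; catch keeps the study alive through the failure):

def failing_objective(trial: optuna.trial.Trial) -> float:
    trial.suggest_float("x", -3, 3)
    raise ValueError("boom")  # trial ends in the FAIL state with values=None


# Callbacks still fire for failed trials, so the wandb callback sees values=None.
study.optimize(failing_objective, n_trials=1, callbacks=[wandbc], catch=(ValueError,))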

nzw0301 avatar May 16 '24 02:05 nzw0301

Alright @nzw0301, how does it look now?

neel04 avatar May 19 '24 11:05 neel04

As I said in https://github.com/optuna/optuna-integration/pull/119#discussion_r1605935768, the current change does not store any data, which is not what we expect. Could you discard values only when values is None, as implemented in the MLflow callback, rather than returning None as in the current changes?

nzw0301 avatar May 19 '24 14:05 nzw0301

As I said in #119 (comment), the current change does not store any data, which is not what we expect. Could you discard values only when values is None, as implemented in the MLflow callback, rather than returning None as in the current changes?

Ah, I see: there's been a miscommunication. This was the exact problem I was trying to explain to you in https://github.com/optuna/optuna-integration/pull/119#discussion_r1604837213

So I should do something like:

if values is None:
    run.log({**trial.params, **metrics}, step=step)
    return

like https://github.com/optuna/optuna-integration/blob/15e6b0ec6d9a0d7f572ad387be8478c56257bef7/optuna_integration/wandb/wandb.py#L176

neel04 avatar May 19 '24 15:05 neel04

Yes, I suppose so. I think the minimal change looks like:

    def __call__(self, study: optuna.study.Study, trial: optuna.trial.FrozenTrial) -> None:
        if trial.values is not None:
            if isinstance(self._metric_name, str):
                if len(trial.values) > 1:
                    # Broadcast default name for multi-objective optimization.
                    names = ["{}_{}".format(self._metric_name, i) for i in range(len(trial.values))]

                else:
                    names = [self._metric_name]

            else:
                if len(self._metric_name) != len(trial.values):
                    raise ValueError(
                        "Running multi-objective optimization "
                        "with {} objective values, but {} names specified. "
                        "Match objective values and names, or use default broadcasting.".format(
                            len(trial.values), len(self._metric_name)
                        )
                    )

                else:
                    names = [*self._metric_name]

            metrics = {name: value for name, value in zip(names, trial.values)}
        else:
            metrics = {}

Alternatively, define a helper function for logging trial.values to avoid this deep nesting.
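Such a helper could look like this (name and exact signature are illustrative, not part of the PR):

from typing import Dict, Optional, Sequence, Union


def _metrics_from_values(
    metric_name: Union[str, Sequence[str]],
    values: Optional[Sequence[float]],
) -> Dict[str, float]:
    # No metrics when the trial never reported a value (pruned/failed).
    if values is None:
        return {}
    if isinstance(metric_name, str):
        if len(values) > 1:
            # Broadcast the default name for multi-objective optimization.
            names = ["{}_{}".format(metric_name, i) for i in range(len(values))]
        else:
            names = [metric_name]
    else:
        if len(metric_name) != len(values):
            raise ValueError(
                "Running multi-objective optimization with {} objective values, "
                "but {} names specified. Match objective values and names, "
                "or use default broadcasting.".format(len(values), len(metric_name))
            )
        names = list(metric_name)
    return dict(zip(names, values))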

nzw0301 avatar May 20 '24 01:05 nzw0301

I changed the syntax slightly to be more Pythonic, overriding metrics only when values is not None, and fixed the tests.

Should I add some more checks in the tests? Right now, it's just checking that n_trials calls have been made, i.e. each and every "failed" run is still logged to wandb. But maybe I could do some more thorough checks here? 🤔
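For example (a sketch; mocked_log stands in for whatever mock the existing tests use to capture wandb's run.log calls):

# Inspect the payload of each log call rather than only counting calls.
payloads = [call.args[0] for call in mocked_log.call_args_list]
assert all("x" in payload for payload in payloads)  # params logged for every trial
assert any(payload.get("Train/acc") is None for payload in payloads)  # pruned trials log None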

neel04 avatar May 20 '24 11:05 neel04

Thanks! I'll check your changes in a few days.

nzw0301 avatar May 22 '24 02:05 nzw0301

I'm not sure Callable is a good type annotation, but at least mypy looks happy.

Narrator: mypy was in fact not happy :wink:

neel04 avatar May 24 '24 12:05 neel04

@nzw0301 How does it look?

neel04 avatar May 24 '24 12:05 neel04