optuna-integration
Wandb bugfix: handle pruned trials correctly
Motivation
In the wandb callback, if someone raises optuna.exceptions.TrialPruned() to prune a trial, as is recommended, program execution is interrupted immediately and thus no value is reported for the trial.
The callback then errors out, because further code applies functions such as len, which are invalid for a NoneType.
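Concretely (illustrative only, assuming trial is the FrozenTrial passed to the callback):

values = trial.values  # None when the trial was pruned before any value was reported
len(values)            # TypeError: object of type 'NoneType' has no len()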
Description of the changes
This PR tries to be minimally invasive (since I'm not familiar with this codebase). Simply put, if trial.values is None, it is wrapped in a list of the appropriate length so that it can be processed further, and the None values are logged to wandb, which can handle such datapoints appropriately.
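Roughly, the idea is the following (a sketch of the approach, not the exact diff; how the length is derived here is only illustrative):

# Sketch: when a trial was pruned before reporting anything, trial.values is
# None; substitute a list of Nones so the downstream naming/zipping logic
# still works and wandb receives explicit None datapoints.
values = trial.values
if values is None:
    n_objectives = 1 if isinstance(self._metric_name, str) else len(self._metric_name)
    values = [None] * n_objectives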
Note: There are some other changes here that were automatically added by the code formatters. LMK if I should remove those and they'd be automatically handled internally.
Could you revert the changes to non-wandb related files?
Is that fine? @nzw0301
Thanks. In addition, could you share a minimal reproducible script that raises the error you report? That info would be helpful for reviewers.
Here's a repro. Make sure to run wandb offline first:
import optuna
from optuna.integration.wandb import WeightsAndBiasesCallback
from optuna.study import MaxTrialsCallback
from optuna.trial import TrialState


def some_func(trial):
    # Some computation
    # ...
    if trial is not None:
        raise optuna.exceptions.TrialPruned()
    return 0


def objective(trial):
    x = trial.suggest_float("x", -3, 3)
    output = some_func(trial)
    return output


wandb_kwargs = {
    "project": "example",
    "anonymous": "allow",
    "entity": "test",
    "magic": True,
}

wandbc = WeightsAndBiasesCallback(
    metric_name='Train/acc',
    wandb_kwargs=wandb_kwargs,
    as_multirun=True,
)

study = optuna.create_study()
study.optimize(
    objective,
    callbacks=[wandbc],
)
I think this happens when a trial has no intermediate value (i.e. trial.report is never called), which is unusual. So I'm not sure the shared script is the right usage of the TrialPruned function. Could you elaborate on your use case?
There are cases where one wishes to prune a trial before reporting a metric. Say a metric has to be reported every n steps, and I detect huge gradient fluctuations that invalidate the run due to bad warmup, or my loss is suddenly NaN. In that case, one could terminate the run immediately instead of waiting for up to n more steps (sketched below).
Also, from a design perspective, it seems rather wonky to require reporting a metric at least once: if there's an error before a metric can be reported (say, a misplaced assert is triggered), one would want that specific run to crash or be pruned, not the entire Optuna study.
So I don't think it's wise to have such a brittle system, especially in logging, where robustness is expected; otherwise a lot of $ can be wasted quickly.
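For concreteness, here is a sketch of such an objective; train_one_step is a hypothetical stand-in for a real training step:

import math
import random

import optuna


def train_one_step(lr: float) -> float:
    # Hypothetical stand-in for a real training step; occasionally diverges to NaN.
    return float("nan") if random.random() < 0.05 else random.random() / lr


def objective(trial: optuna.trial.Trial) -> float:
    lr = trial.suggest_float("lr", 1e-5, 1e-1, log=True)
    loss = float("inf")
    for step in range(100):
        loss = train_one_step(lr)
        if math.isnan(loss):
            # Terminate immediately, possibly before the first trial.report() call,
            # so trial.values is None when the wandb callback runs.
            raise optuna.exceptions.TrialPruned()
        if step % 10 == 0:
            trial.report(loss, step)
            if trial.should_prune():
                raise optuna.exceptions.TrialPruned()
    return loss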
Thank you for your clarification. I also realised that a failed trial, with or without an intermediate value, has the same issue.
Alright @nzw0301, how does it look now?
As I said in https://github.com/optuna/optuna-integration/pull/119#discussion_r1605935768, the current change does not store any data, which is not what we expect. Could you discard the values only when values is None, as implemented in the MLflow callback, rather than returning None as in the current changes?
Ah I see there's been a miscommunication. This was the exact problem I was trying to explain to you in https://github.com/optuna/optuna-integration/pull/119#discussion_r1604837213
So I should do something like:
if values is None:
    run.log({**trial.params, **metrics}, step=step)
    return
like https://github.com/optuna/optuna-integration/blob/15e6b0ec6d9a0d7f572ad387be8478c56257bef7/optuna_integration/wandb/wandb.py#L176
Yes, I suppose so. I think the minimal change looks like:
def __call__(self, study: optuna.study.Study, trial: optuna.trial.FrozenTrial) -> None:
    if trial.values is not None:
        if isinstance(self._metric_name, str):
            if len(trial.values) > 1:
                # Broadcast default name for multi-objective optimization.
                names = ["{}_{}".format(self._metric_name, i) for i in range(len(trial.values))]
            else:
                names = [self._metric_name]
        else:
            if len(self._metric_name) != len(trial.values):
                raise ValueError(
                    "Running multi-objective optimization "
                    "with {} objective values, but {} names specified. "
                    "Match objective values and names, or use default broadcasting.".format(
                        len(trial.values), len(self._metric_name)
                    )
                )
            else:
                names = [*self._metric_name]
        metrics = {name: value for name, value in zip(names, trial.values)}
    else:
        metrics = {}
Alternatively, define a helper function for logging trial.values to avoid this deep nesting.
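For example, something like the following (the helper name is just illustrative):

def _build_metrics(self, trial: optuna.trial.FrozenTrial) -> dict:
    # Illustrative helper: return {} for trials without values (e.g. pruned
    # before any report), otherwise map metric names to objective values.
    if trial.values is None:
        return {}
    if isinstance(self._metric_name, str):
        if len(trial.values) > 1:
            # Broadcast default name for multi-objective optimization.
            names = ["{}_{}".format(self._metric_name, i) for i in range(len(trial.values))]
        else:
            names = [self._metric_name]
    else:
        if len(self._metric_name) != len(trial.values):
            raise ValueError(
                "Running multi-objective optimization "
                "with {} objective values, but {} names specified. "
                "Match objective values and names, or use default broadcasting.".format(
                    len(trial.values), len(self._metric_name)
                )
            )
        names = [*self._metric_name]
    return {name: value for name, value in zip(names, trial.values)}


def __call__(self, study: optuna.study.Study, trial: optuna.trial.FrozenTrial) -> None:
    metrics = self._build_metrics(trial)
    ...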
I changed the syntax slightly to be more Pythonic (metrics is overridden only if trial.values is not None) and fixed the tests.
Should I add some more checks to the tests? Right now, they just check that n_trials calls have been made, i.e. each and every "failed" run is still logged to wandb. But maybe I can do some more thorough checks here? 🤔
Thanks! I'll check your changes in a few days.
I'm not sure Callable is a good type annotation, but at least mypy looks happy.
Narrator: mypy was in fact not happy :wink:
@nzw0301 How does it look?