SMAC3
[Refactor] Move cost validation to `tell`
Currently the runner validates the costs. This is all well and good until you have to use the `ask` and `tell` interface: reporting a correct `TrialValue` requires the user to re-implement all of the validation logic found within the runner here:
https://github.com/automl/SMAC3/blob/731854e8effe72d76cb07636458e783143f32dc5/smac/runner/target_function_runner.py#L156-L215
This validation logic makes sure that the costs are ordered correctly if reported as a dict, and also ensures that `crash_cost` is extended to the length of the `cost` list when using multi-objective optimization and `crash_cost` is a single float. Re-implementing this just to use `ask` and `tell` is a little verbose.
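To make those two responsibilities concrete, here is a minimal, stdlib-only sketch of the normalization a user currently has to redo by hand. `normalize_cost` is a hypothetical helper, not SMAC's runner code; it assumes `objectives` is either a single name (a `str`) or a list of names, as in `Scenario(objectives=...)`, that dict-reported costs are keyed by objective name, and that `cost=None` means a failed trial reported no cost.

```python
from __future__ import annotations

import math
from collections.abc import Mapping, Sequence
from typing import Optional, Union


def normalize_cost(
    cost: Optional[Union[float, Sequence[float], Mapping[str, float]]],
    objectives: Union[str, Sequence[str]],
    crash_cost: Union[float, Sequence[float]] = math.inf,  # assumed to mirror Scenario's default
) -> Union[float, list[float]]:
    """Hypothetical helper: return a float (single objective) or a list ordered like `objectives`."""
    # A plain str is also a Sequence, so detect multi-objective explicitly
    multi = not isinstance(objectives, str)
    names = list(objectives) if multi else [objectives]

    if cost is None:
        # No reported cost: fall back to crash_cost
        if isinstance(crash_cost, (int, float)):
            # Broadcast a scalar crash cost to every objective
            return [float(crash_cost)] * len(names) if multi else float(crash_cost)
        cost = list(crash_cost)

    if isinstance(cost, Mapping):
        # Reorder dict-reported costs to follow the objective order
        missing = [name for name in names if name not in cost]
        if missing:
            raise ValueError(f"Objectives {missing} missing from reported cost {cost}")
        values = [float(cost[name]) for name in names]
    elif isinstance(cost, (int, float)):
        values = [float(cost)]
    else:
        values = [float(c) for c in cost]

    if len(values) != len(names):
        raise ValueError(f"Got {len(values)} cost(s) for {len(names)} objective(s)")
    return values if multi else values[0]
```

For example, `normalize_cost({"time": 2.0, "error": 0.1}, ["error", "time"])` gives `[0.1, 2.0]`, and `normalize_cost(None, ["error", "time"])` gives `[inf, inf]`.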
This issue is best illustrated by trying to create a more thorough example of using ask-and-tell, one which also uses multi-objective optimization and reports a `TrialValue` with failed results, where the `Scenario` did not explicitly provide a `Scenario(crash_cost=...)` for each objective.
https://automl.github.io/SMAC3/main/examples/1_basics/3_ask_and_tell.html#ask-and-tell
I'm not sure how best to approach this, but one possibility might be to add the validation logic to `runhistory.add`; this is just a guess and may not be the best solution.
For reference, here's some sample code of mine that wraps SMAC's `tell`:
```python
from __future__ import annotations

from collections.abc import Sequence

from ConfigSpace import Configuration
from smac.runhistory.dataclasses import TrialInfo as SMACTrialInfo
from smac.runhistory.dataclasses import TrialValue as SMACTrialValue
from smac.runhistory.enumerations import StatusType

# TrialReport, SuccessReport, FailReport, CrashReport, MemoryLimitException and
# TimeoutException belong to the wrapping library, not SMAC; their imports are omitted.


def tell(self, report: TrialReport[SMACTrialInfo, Configuration]) -> None:
    """Tell the optimizer the result of the sampled config.

    Args:
        report: The report of the trial.
    """
    # If we're successful, get the cost and times and report them
    if isinstance(report, SuccessReport):
        if "cost" not in report.results:
            raise ValueError(
                f"Report must have 'cost' if successful but got {report}."
                " Use `trial.success(cost=...)` to set the results of the trial."
            )
        trial_value = SMACTrialValue(
            time=report.time.duration,
            starttime=report.time.start,
            endtime=report.time.end,
            cost=report.results["cost"],
            status=StatusType.SUCCESS,
            additional_info=report.results.get("additional_info", {}),
        )
        # Pass the successful trial on to SMAC before returning
        self.facade.tell(info=report.info, value=trial_value, save=True)
        return

    if isinstance(report, FailReport):
        duration = report.time.duration
        start = report.time.start
        end = report.time.end
        reported_cost = report.results.get("cost", None)
        additional_info = report.results.get("additional_info", {})
    elif isinstance(report, CrashReport):
        # A crash carries no timing or cost information
        duration = 0
        start = 0
        end = 0
        reported_cost = None
        additional_info = {}
    else:
        raise TypeError(f"Unsupported report type: {type(report).__name__}")

    # We got either a fail or a crash, time to deal with it
    status_types: dict[type, StatusType] = {
        MemoryLimitException: StatusType.MEMORYOUT,
        TimeoutException: StatusType.TIMEOUT,
    }
    status_type = StatusType.CRASHED
    if report.exception is not None:
        status_type = status_types.get(type(report.exception), StatusType.CRASHED)

    # If we have no reported costs, we need to ensure that we have a
    # valid crash_cost based on the number of objectives
    crash_cost = self.facade.scenario.crash_cost
    objectives = self.facade.scenario.objectives
    # A single objective is given as a str, which is itself a Sequence,
    # so multi-objective has to be detected explicitly
    multi_objective = not isinstance(objectives, str)

    if reported_cost is not None:
        cost = reported_cost
    elif isinstance(crash_cost, (int, float)):
        # Broadcast a single crash cost to every objective if needed
        cost = [crash_cost] * len(objectives) if multi_objective else crash_cost
    elif multi_objective:
        cost = crash_cost
    else:
        raise ValueError(
            f"Multiple crash costs reported ({crash_cost}) for only a single"
            f" objective in `Scenario({objectives=}, ...)`"
        )

    if multi_objective and isinstance(cost, Sequence) and len(cost) != len(objectives):
        raise ValueError(
            f"Length of cost ({len(cost)}) and objectives "
            f"({len(objectives)}) must be equal"
        )

    trial_value = SMACTrialValue(
        time=duration,
        starttime=start,
        endtime=end,
        cost=cost,
        status=status_type,
        additional_info=additional_info,
    )
    self.facade.tell(info=report.info, value=trial_value, save=True)
```
I get what you mean, and it probably makes sense. It would basically reduce the runner to parsing, resulting in less code. However, a lot of code would have to be rearranged and tested again. Maybe something for the next hackathon?
The question is whether it should be placed inside the run history or the `tell` method.
Hmm, I think the `tell` method makes a lot more sense than the `RunHistory`, so that's a good idea. I think the core issue is that the validation should happen at the entry point to the optimizer, so that SMAC is easier to integrate without subscribing to its runners.
tl;dr: putting it in `tell` sounds amazing; I'm just not sure whether we allow the `TrialValue` to be incorrect when calling `tell` and have it fixed in there, or what.
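One way to picture the "fix it in `tell`" option is a `tell` that accepts a value whose cost is left empty for a failed trial and fills in the crash cost before anything reaches the run history. The sketch below is a toy, stdlib-only model of that idea: `ToyOptimizer`, `Value` and `Status` are hypothetical stand-ins for SMAC's facade, `TrialValue` and `StatusType`, not SMAC's actual API, and it assumes `cost=None` is how a caller would signal "use the crash cost".

```python
from __future__ import annotations

import math
from dataclasses import dataclass, replace
from enum import Enum
from typing import Optional, Union


class Status(Enum):
    """Hypothetical stand-in for SMAC's StatusType."""

    SUCCESS = "success"
    CRASHED = "crashed"


@dataclass(frozen=True)
class Value:
    """Hypothetical stand-in for TrialValue; cost may be None for a failed trial."""

    cost: Optional[Union[float, list[float]]]
    status: Status = Status.SUCCESS


class ToyOptimizer:
    """Hypothetical optimizer whose `tell` is the single place costs are validated."""

    def __init__(
        self,
        objectives: Union[str, list[str]],
        crash_cost: Union[float, list[float]] = math.inf,
    ) -> None:
        self.objectives = objectives
        self.crash_cost = crash_cost
        self.history: dict[str, Value] = {}  # stand-in for the run history

    def tell(self, trial_id: str, value: Value) -> Value:
        multi = not isinstance(self.objectives, str)  # a str means a single objective
        n = len(self.objectives) if multi else 1
        cost = value.cost
        if cost is None:
            if value.status is Status.SUCCESS:
                raise ValueError("A successful trial must report a cost")
            cost = self.crash_cost  # fill in the crash cost for failed trials
        if isinstance(cost, (int, float)):
            # A scalar crash cost is broadcast; a scalar reported cost counts once
            costs = [float(cost)] * n if value.cost is None else [float(cost)]
        else:
            costs = [float(c) for c in cost]
        if len(costs) != n:
            raise ValueError(f"Expected {n} cost(s), got {len(costs)}")
        fixed = replace(value, cost=costs if multi else costs[0])
        self.history[trial_id] = fixed  # only validated values are stored
        return fixed
```

Filling the cost inside `tell` keeps the run history free of invalid values, at the price of `tell` rewriting what the caller passed in; raising on any incomplete value instead would be the stricter alternative.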