Add `ReduceLROnPlateau` learning rate scheduler
I needed a learning rate scheduler analogous to PyTorch's `ReduceLROnPlateau`, so I rolled my own. I figured others might find it useful as well, so I made a PR with it here.
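For readers less familiar with the PyTorch scheduler, its core logic can be sketched in plain Rust as below. This is a minimal sketch, not the code in this PR. It assumes "min" mode (lower metric is better) with a relative improvement threshold, mirroring PyTorch's default `threshold_mode='rel'`. It omits `cooldown` and `eps`, and it tracks the learning rate as a plain `f64` rather than driving a candle optimizer. The struct and field names are hypothetical.

```rust
/// Minimal sketch of a reduce-on-plateau scheduler in "min" mode
/// (a lower metric is better). Names are hypothetical, not the PR's API.
pub struct ReduceLrOnPlateau {
    lr: f64,
    factor: f64,     // multiply lr by this when a plateau is detected
    patience: usize, // non-improving steps tolerated before reducing
    threshold: f64,  // relative improvement required to count as "better"
    min_lr: f64,     // lower bound on the learning rate
    best: f64,       // best metric seen so far
    num_bad_steps: usize,
}

impl ReduceLrOnPlateau {
    pub fn new(lr: f64, factor: f64, patience: usize, threshold: f64, min_lr: f64) -> Self {
        assert!(factor > 0.0 && factor < 1.0, "factor must be in (0, 1)");
        Self {
            lr,
            factor,
            patience,
            threshold,
            min_lr,
            best: f64::INFINITY,
            num_bad_steps: 0,
        }
    }

    /// Record a new metric value (e.g. validation loss) and return the
    /// learning rate to use from now on.
    pub fn step(&mut self, metric: f64) -> f64 {
        // Relative "min" mode: the metric must beat `best` by a fraction `threshold`.
        if metric < self.best * (1.0 - self.threshold) {
            self.best = metric;
            self.num_bad_steps = 0;
        } else {
            self.num_bad_steps += 1;
        }
        // Reduce once patience is exceeded, clamp to min_lr, then start counting again.
        if self.num_bad_steps > self.patience {
            self.lr = (self.lr * self.factor).max(self.min_lr);
            self.num_bad_steps = 0;
        }
        self.lr
    }
}
```

With `new(0.1, 0.5, 2, 1e-4, 1e-6)` and losses `1.0, 0.8, 0.8, 0.8, 0.8`, the learning rate stays at `0.1` until the third consecutive non-improving step, where it drops to `0.05`. Note that the metric is an argument to `step()`, which is exactly the point of friction with a zero-argument scheduler trait.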
There is an open discussion from last year about adding this feature (https://github.com/huggingface/candle/discussions/2224). It looks like code adding a core `Scheduler` trait, as groundwork for learning rate schedulers, was written but never merged (https://github.com/huggingface/candle/pull/2225).
I think a core trait like that makes sense, and adding this feature first should still be compatible with that design if the trait is eventually merged. One caveat is that the `Scheduler` trait from that PR does not take additional arguments in its `step()` function, whereas `ReduceLROnPlateau` differs from the other PyTorch LR schedulers in that its `step()` takes a metric (e.g. validation loss) as an argument.
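To make the signature mismatch concrete, here is one hypothetical way a scheduler trait could accommodate both kinds of schedulers: an associated `Input` type that is `()` for schedulers like `StepLR` and `f64` for a metric-driven scheduler. This is an assumption for illustration only. It is not taken from the trait in PR 2225, the type names are hypothetical, and the plateau logic is reduced to `factor` and `patience`.

```rust
/// Hypothetical generalized scheduler trait: `Input` is whatever extra
/// data a scheduler needs per step (`()` for most, a metric for plateau).
trait LrScheduler {
    type Input;
    fn step(&mut self, input: Self::Input) -> f64;
}

/// StepLR-style schedule: decay lr by `gamma` every `step_size` steps.
/// `step_size` must be non-zero.
struct StepLr {
    lr: f64,
    gamma: f64,
    step_size: usize,
    steps: usize,
}

impl LrScheduler for StepLr {
    type Input = ();
    fn step(&mut self, _: ()) -> f64 {
        self.steps += 1;
        if self.steps % self.step_size == 0 {
            self.lr *= self.gamma;
        }
        self.lr
    }
}

/// Metric-driven schedule (simplified: no threshold, cooldown, or min_lr).
struct Plateau {
    lr: f64,
    factor: f64,
    patience: usize,
    best: f64,
    bad_steps: usize,
}

impl LrScheduler for Plateau {
    type Input = f64;
    fn step(&mut self, metric: f64) -> f64 {
        if metric < self.best {
            self.best = metric;
            self.bad_steps = 0;
        } else {
            self.bad_steps += 1;
        }
        if self.bad_steps > self.patience {
            self.lr *= self.factor;
            self.bad_steps = 0;
        }
        self.lr
    }
}
```

A cost of this design is that code generic over the trait must also account for `Input`: a trait object has to fix it (e.g. `dyn LrScheduler<Input = ()>`), so a `StepLr` and a `Plateau` cannot share one `Vec<Box<dyn LrScheduler<...>>>`.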
For now, `ReduceLROnPlateau` is standalone, but I think the proposed `Scheduler` trait could be generalized to support it, or the API of `ReduceLROnPlateau` could be changed to be consistent with the trait:
```rust
// Possible alternative API: record the metric first, then step with no arguments.
scheduler.set_metric(loss);
scheduler.step();
```
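The alternative API keeps a zero-argument `step()` by having the scheduler store the most recent metric until the next step. A minimal sketch, assuming a hypothetical `Scheduler` trait with `fn step(&mut self)` (modeled on the description of the unmerged trait, not copied from it) and simplified plateau logic with only `factor` and `patience`:

```rust
/// Hypothetical zero-argument scheduler trait.
trait Scheduler {
    fn step(&mut self);
    fn lr(&self) -> f64;
}

/// Plateau scheduler that receives its metric via `set_metric` before `step`.
struct ReduceLrOnPlateau {
    lr: f64,
    factor: f64,
    patience: usize,
    best: f64,
    bad_steps: usize,
    pending_metric: Option<f64>, // metric recorded since the last step
}

impl ReduceLrOnPlateau {
    fn new(lr: f64, factor: f64, patience: usize) -> Self {
        Self { lr, factor, patience, best: f64::INFINITY, bad_steps: 0, pending_metric: None }
    }

    /// Record the metric (e.g. validation loss) to be used by the next `step()`.
    fn set_metric(&mut self, metric: f64) {
        self.pending_metric = Some(metric);
    }
}

impl Scheduler for ReduceLrOnPlateau {
    fn step(&mut self) {
        // Consume the stored metric. Stepping without a new metric is a no-op
        // in this sketch (one possible choice; panicking would be another).
        let Some(metric) = self.pending_metric.take() else { return };
        if metric < self.best {
            self.best = metric;
            self.bad_steps = 0;
        } else {
            self.bad_steps += 1;
        }
        if self.bad_steps > self.patience {
            self.lr *= self.factor;
            self.bad_steps = 0;
        }
    }

    fn lr(&self) -> f64 {
        self.lr
    }
}
```

The trade-off is that forgetting to call `set_metric` fails silently here, while passing the metric to `step()` directly turns that mistake into a compile error.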