MLJModelInterface.jl
Expected behavior for repeated `fit!` calls
Just wondering what the right strategy is for warm starts: what is the correct behavior for calling `fit!` on a machine twice? Here I assume that the model has an $N$ parameter controlling the number of optimization steps:
```julia
model = Regressor(N=N)
mach = machine(model, X, y)
fit!(mach)
fit!(mach) # what happens here?
```
1. Warm start; runs for another $N$ steps from where it left off.
2. Warm start; runs for $0$ steps from where it left off. In other words, the user would need to increase $N$ if they want to replicate the behavior of (1):
   ```julia
   fit!(mach)
   mach.model.N += N
   fit!(mach)
   ```
3. Cold start; resets the search state and runs for $N$ steps. Perhaps you need to use an `update!` function for explicit warm starts (in which case 1 and 2 would require using `update!` instead).
Right now, SymbolicRegression.jl is doing (2), but I'm not sure about this. Pinging @ablaom for tips.
See the `fit_only!` docstring for the detailed logic on whether `fit` or `update` is dispatched. The user typically interacts via `fit!` only (which calls `fit_only!` on the machine and, in the case of composite models, on other machines on which the output of that machine depends). There is no `update!` function. A simplified sketch of the dispatch logic follows the links below. Also, see these sections of the docs:
- https://juliaai.github.io/MLJ.jl/dev/machines/#Warm-restarts
- https://juliaai.github.io/MLJModelInterface.jl/dev/iterative_models/
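For orientation, here is a loose, hypothetical sketch of the dispatch decision (not the actual implementation; the real `fit_only!` also accounts for changes to data, rows, and upstream machines in learning networks):

```julia
import MLJModelInterface as MMI

# Hypothetical simplification of the `fit_only!` decision, for illustration
# only. `state` stands for what the machine remembers from the last `fit!`;
# `nothing` means the machine has never been trained.
function fit_only_sketch!(model, state, verbosity, data...)
    if state === nothing
        # never trained: dispatch `fit` (cold start)
        fitresult, cache, report = MMI.fit(model, verbosity, data...)
    elseif model != state.old_model
        # hyperparameters have mutated (MLJ compares models field by field):
        # dispatch `update`; whether this is warm or cold depends on how
        # (and whether) `update` has been overloaded
        fitresult, cache, report =
            MMI.update(model, verbosity, state.fitresult, state.cache, data...)
    else
        # nothing changed: do nothing
        fitresult, cache, report = state.fitresult, state.cache, state.report
    end
    # snapshot the hyperparameters for the next mutation check
    return (; fitresult, cache, report, old_model = deepcopy(model))
end
```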
The behaviour the user (should) expect is that `fit!` tries to update the learned parameters to match the current hyperparameters: if you don't change any hyperparameter at all, then the learned parameters won't update. If you want to add iterations, say, you do something like (2). That said, there are some exceptions to this rule: in MLJFlux models, it is possible to arrange that updating the optimiser does not trigger retraining from scratch (although it will by default).
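Concretely, for the hypothetical `Regressor` from the question, the expected user-facing behaviour would look like this:

```julia
using MLJ  # provides `machine` and `fit!`

model = Regressor(N=100)   # hypothetical iterative model from the question
mach = machine(model, X, y)
fit!(mach)                 # cold start: `fit` is dispatched, runs 100 steps
fit!(mach)                 # no hyperparameter changed: training is skipped
model.N += 50              # mutate a hyperparameter ...
fit!(mach)                 # ... so `update` is dispatched: 50 more steps,
                           # warm, provided `update` is overloaded as below
```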
If you change a hyperparameter, then how the update happens is an implementation detail: `update` will be dispatched, but whether training is cold or warm depends on whether and how `update` has been overloaded. If `update` is not overloaded, then training is cold. You overload `update` to get the warm restart behaviour, as in the sketch below.
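Here is a minimal sketch of such an overloading for the hypothetical `Regressor`, assuming hypothetical helpers `initial_state` and `step!` that create and advance the optimization state (the `fit`/`update` signatures are the ones documented by MLJModelInterface):

```julia
import MLJModelInterface as MMI

# Hypothetical iterative model, for illustration only.
mutable struct Regressor <: MMI.Deterministic
    N::Int   # total number of optimization steps requested
end

function MMI.fit(model::Regressor, verbosity, X, y)
    state = initial_state(X, y)          # hypothetical helper
    state = step!(state, model.N)        # hypothetical helper: run N steps
    cache = (state=state, N=model.N)     # remember progress for warm restarts
    return state, cache, (iterations=model.N,)
end

function MMI.update(model::Regressor, verbosity, old_fitresult, old_cache, X, y)
    ΔN = model.N - old_cache.N           # steps still to run
    # if N was *decreased*, fall back to a cold restart:
    ΔN < 0 && return MMI.fit(model, verbosity, X, y)
    state = step!(old_cache.state, ΔN)   # warm restart: only the extra steps
    cache = (state=state, N=model.N)
    return state, cache, (iterations=model.N,)
end
```

A real implementation would also need to detect changes to hyperparameters other than `N` (for example, by storing a deep copy of the model in `cache`) and fall back to `fit` in that case.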
By the way, the same logic applies to any hyperparameter, not just iteration parameters. You could, for example, overload `update` to get warm restarts when updating the regularisation parameter `lambda` in lasso regression.
In MLJ there isn't currently any concept of updating observations or features, only hyperparameters.
In LearnAPI.jl (where there are no machines) we have tried to be stricter about expectations around updating; if you are curious, take a look here.