MLJModelInterface.jl

Expected behavior for repeated `fit!` calls

Just wondering what the right strategy is for warm starts. What is the correct behavior for calling fit! on a machine twice? Here, I assume the model has a parameter $N$ controlling the number of optimization steps:

model = Regressor(N=N)
mach = machine(model, X, y)
fit!(mach)
fit!(mach)  # what happens here?
  1. Warm start; runs for another $N$ steps from where it left off.
  2. Warm start; runs for $0$ steps from where it left off. In other words, the user would need to increase $N$ if they want to replicate the behavior of (1):
    fit!(mach)
    mach.model.N += N
    fit!(mach)
    
  3. Cold start; resets the search state and runs for $N$ steps. Perhaps you need to use an update! function for explicit warm starts (in which case 1 and 2 would require using update! instead).

Right now, SymbolicRegression.jl is doing (2), but I'm not sure about this. Pinging @ablaom for tips.

MilesCranmer, Mar 06 '25

See the fit_only! docstring for the detailed logic on whether fit or update is dispatched. The user typically interacts via fit! only (which calls fit_only! on the machine and, in the case of composite models, on the other machines on which the output of that machine depends). There is no update! function. Also, see these sections of the docs (a toy sketch of the dispatch logic follows the links):

  • https://juliaai.github.io/MLJ.jl/dev/machines/#Warm-restarts
  • https://juliaai.github.io/MLJModelInterface.jl/dev/iterative_models/
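To see the shape of that dispatch, here is a toy, self-contained caricature (not the real implementation: the actual fit_only! also weighs data changes, upstream machines in learning networks, and more):

# Toy model; MLJ defines == for real models to compare hyperparameters field-wise.
mutable struct ToyModel
    N::Int
end
Base.:(==)(a::ToyModel, b::ToyModel) = a.N == b.N

mutable struct ToyMachine
    model::ToyModel
    last_model::Union{ToyModel,Nothing}  # snapshot from the last training event
    log::Vector{String}
end
ToyMachine(model) = ToyMachine(model, nothing, String[])

function toy_fit_only!(mach::ToyMachine)
    if mach.last_model === nothing
        push!(mach.log, "fit, N=$(mach.model.N)")     # never trained: dispatch fit
    elseif mach.model != mach.last_model
        push!(mach.log, "update, N=$(mach.model.N)")  # model mutated: dispatch update
    else
        push!(mach.log, "no-op")                      # nothing changed: do nothing
    end
    mach.last_model = deepcopy(mach.model)
    return mach
end

mach = ToyMachine(ToyModel(100))
toy_fit_only!(mach)   # logs "fit, N=100"
toy_fit_only!(mach)   # logs "no-op"
mach.model.N += 100
toy_fit_only!(mach)   # logs "update, N=200"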

The behaviour the user (should) expect is that fit! tries to update the learned parameters to match the current hyperparameters: if you don't change any hyperparameter at all, then the learned parameters won't update. If you want to add iterations, say, you do something like (2). That said, there are some exceptions to this rule: in MLJFlux models, it is possible to arrange that updating the optimiser does not trigger retraining from scratch (although it will by default).
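For concreteness, a minimal sketch of that pattern, reusing the hypothetical Regressor(N=...) from the question:

model = Regressor(N=100)
mach = machine(model, X, y)
fit!(mach)       # trains for 100 steps
fit!(mach)       # model unchanged, so learned parameters already match: a no-op
model.N += 100   # the machine holds a reference to model, so mutate it directly
fit!(mach)       # update is dispatched; warm-restarts only if update is overloaded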

If you change a hyperparameter, then how the update happens is an implementation detail: update will be dispatched, but whether training is cold or warm depends on whether and how update has been overloaded. If update has not been overloaded, training is cold (the fallback simply calls fit). You overload update to get the warm-restart behaviour.
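To make that concrete, here is a hedged sketch of such an overload for the hypothetical Regressor above. The signature and the (fitresult, cache, report) return contract mirror those of fit, per the iterative-models docs; is_same_except is provided by MLJModelInterface, but the cache layout and the train_more! helper are invented for illustration:

import MLJModelInterface as MMI

function MMI.update(model::Regressor, verbosity, old_fitresult, old_cache, X, y)
    old_model = old_cache.model  # assumes fit stashed the model it last trained with
    if MMI.is_same_except(model, old_model, :N) && model.N >= old_model.N
        # Only N grew: warm restart, training for just the extra steps.
        fitresult = train_more!(old_fitresult, X, y, model.N - old_model.N)
        cache = (model = deepcopy(model),)
        report = (iterations = model.N,)
        return fitresult, cache, report
    end
    return MMI.fit(model, verbosity, X, y)  # anything else: cold restart
end

For this to work, fit must of course return the same cache layout, i.e. (model = deepcopy(model),).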

By the way, the same logic applies to any hyperparameter, not just iteration parameters. You could, for example, overload update to get warm restarts when updating the regularisation parameter lambda in lasso regression, as sketched below.
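A similar sketch for that case, with LassoRegressor, coordinate_descent, and its init keyword all invented for illustration; the idea is to seed the solver for the new lambda with the previous solution:

import MLJModelInterface as MMI

function MMI.update(model::LassoRegressor, verbosity, old_fitresult, old_cache, X, y)
    if MMI.is_same_except(model, old_cache.model, :lambda)
        # Warm restart: start the solver from the previous coefficients.
        coefs = coordinate_descent(X, y, model.lambda; init = old_fitresult.coefs)
        return (coefs = coefs,), (model = deepcopy(model),), NamedTuple()
    end
    return MMI.fit(model, verbosity, X, y)  # other changes: cold restart
end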

In MLJ there isn't currently any concept of updating observations or features, only hyperparameters.

In LearnAPI.jl (where there are no machines) we have tried to be stricter about expectations around updating; if you are curious, take a look here.

ablaom, Mar 09 '25