MLJBase.jl
Control caching of composite models
Hi,
Problem description
I have just had the late realization that setting cache=false for a composite model is not transferred to the sub-machines built at fit time.
using MLJBase
using MLJLinearModels
n = 100
X = MLJBase.table(rand(n, 3))
y = rand(n)
stack = Stack(metalearner=LinearRegressor(),
              model1=LinearRegressor())
mach = machine(stack, X, y, cache=false)
fit!(mach, verbosity=0)
# Top level machine
@assert !isdefined(mach, :data) # not defined: ok
mach.cache # contains data: ?
# Any submachine
submachines = report(mach).machines
for submach in submachines
    @assert isdefined(submach, :data)
end
I am currently working with very big Stacks, and I think this is the main reason I am running out of memory.
Ideas:
- Of course, I guess there is the possibility of adding a hyperparameter to the composite model that can be transferred to the machine definitions afterwards (see the sketch after this list). This would probably solve my personal issue (and would be a short-term solution), but I don't think it is ideal in the long term, because every user-defined composite model would have to add this hyperparameter.
- More generally, I think this issue arises from the current impossibility(?) of communication between the machine and its submachines in this design. If I am not mistaken, we will have the same problem with computational resources, as briefly raised here. A vague idea would be to define the learning graph in a method like machine(m::MyComposite, args...) instead of the current fit.
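To make the first idea concrete, here is a rough sketch of what I mean. MyComposite and its regressor field are hypothetical, and I am assuming the current learning-network API where fit builds the network, wraps it in a surrogate machine, and calls return!:

using MLJBase, MLJLinearModels

# a hypothetical composite exposing its own `cache` hyperparameter
mutable struct MyComposite <: DeterministicComposite
    regressor
    cache::Bool
end

function MLJBase.fit(model::MyComposite, verbosity::Int, X, y)
    Xs = source(X)
    ys = source(y)
    # forward the composite's `cache` flag to the internal machine by hand
    mach = machine(model.regressor, Xs, ys, cache=model.cache)
    yhat = predict(mach, Xs)
    network_mach = machine(Deterministic(), Xs, ys; predict=yhat)
    return!(network_mach, model, verbosity)
end

This works, but every composite author has to remember to thread the flag through each machine call.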
What are your thoughts?
> Of course, I guess there is the possibility of adding a hyperparameter to the composite model that can be transferred to the machine definitions afterwards. This would probably solve my personal issue (and would be a short-term solution), but I don't think it is ideal in the long term, because every user-defined composite model would have to add this hyperparameter.
This may not be the most convenient option, but I think it is by far the simplest, and what I imagined we would live with.
I think you are right that "component" machines cannot know anything about the Composite machines that users interact with.
> A vague idea would be to define the learning graph in a method like machine(m::MyComposite, args...) instead of the current fit.
I'm not sure I follow. Could you say a little more?
Another idea is to introduce a user-modifiable global constant to control the default value of cache in the machine constructor. In fact, there is already such a method for acceleration, and we are introducing such an option for check_scitype_level at https://github.com/JuliaAI/MLJBase.jl/pull/753.
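Sketched roughly, by analogy with the existing default for acceleration (all names below are hypothetical):

const DEFAULT_CACHE = Ref(true)

default_cache() = DEFAULT_CACHE[]                     # query the global default
default_cache(flag::Bool) = (DEFAULT_CACHE[] = flag)  # set the global default

# the machine constructor would then fall back on the global default,
# so every machine built at fit time picks it up:
# machine(model, args...; cache=default_cache(), ...)

A user working with big Stacks could then call default_cache(false) once, before constructing any machines.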
Yes, I agree that the current design leaves no choice other than a hyperparameter or a global variable. Both are a bit unnatural, since the caching requirement has already been declared in the top-level machine at construction time. It is also not very extensible if other information needs to be passed in the future.
I think I have matured the idea a bit. Similar to what is done here with Dagger.jl, instead of a fit(model::Composite, verbosity::Int, args...), a user could define the following:
function MLJBase.fit!(mach::Machine{<:Composite, C}; verbosity=0, kwargs...) where C
    # RETRIEVE DATA (maybe there should be some initialization step?)
    X, y = (src() for src in mach.args)
    # HERE DO WHATEVER YOU WANT FOR THE LEARNING NETWORK:
    # the goal is just to define a set of nodes that have a specific importance,
    # like OPERATIONS or additional reports
    signature = mylearningnetwork(X, y)
    # FINALIZER MECHANISM
    return!(mach, signature)
end
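For concreteness, mylearningnetwork could look something like this (hypothetical, and I am assuming the signature is a named tuple mapping operation names to nodes):

function mylearningnetwork(X, y)
    Xs = source(X)
    ys = source(y)
    mach = machine(LinearRegressor(), Xs, ys)  # the top-level machine's caching
                                               # policy could be applied here
    yhat = predict(mach, Xs)
    return (predict=yhat,)  # this named tuple plays the role of the signature
end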
I think this may have many benefits:
- The top-level machine's internals are available for the fit procedure
- You can also tweak it by adding kwargs...
- We may get rid of the internal surrogate machine
It does not respect the convention of defining a fit for a model, but I think that's OK, because to use composability users will need MLJBase anyway.
Of course, for now this is rather hypothetical, but if you like the idea I could give it a try and make a POC with the Stack as my favorite example. Since this introduces a new method, I am hoping we might succeed in making this a new feature that could live alongside the current implementation without breaking anything.