
Add docs note about saving/loading models with anonymous functions

Open darsnack opened this issue 2 years ago • 9 comments

The new save/load docs promote JLD2.jl, which does not reliably support saving/loading anonymous functions. This most commonly occurs with activation functions. The solution is to use Flux.state + Flux.loadmodel! and set the desired anonymous function in the destination model passed to loadmodel!. This avoids needing the serialization library to handle the anonymous function correctly.
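A minimal sketch of that workflow, assuming JLD2 for the on-disk format (the file name and layer sizes here are invented for illustration):

```julia
using Flux, JLD2

# A model whose activation is an anonymous function; JLD2 cannot
# reliably round-trip the closure itself.
model = Chain(Dense(2 => 3, x -> leakyrelu(x, 0.2f0)), Dense(3 => 1))

# Save only the tree of parameters/state, not the functions.
jldsave("checkpoint.jld2"; model_state = Flux.state(model))

# Later, possibly in a fresh session: rebuild the model (anonymous
# function and all), then copy the saved numbers back into it.
model2 = Chain(Dense(2 => 3, x -> leakyrelu(x, 0.2f0)), Dense(3 => 1))
Flux.loadmodel!(model2, JLD2.load("checkpoint.jld2", "model_state"))
```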

This will be problematic when the anonymous function contains data (state) that actually should be restored. A possible solution here is to make the closure an explicit struct. Maybe there are better solutions.
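For illustration, the explicit-struct approach might look like this (the `Scale` layer and its field are hypothetical, not part of Flux):

```julia
using Flux

# Instead of closing over `scale` with `x -> x .* scale`, make the
# captured data an explicit field of a callable struct.
struct Scale{T}
    scale::T
end
(s::Scale)(x) = x .* s.scale

Flux.@functor Scale  # so Flux.state and loadmodel! see the `scale` field

layer = Scale([1.0f0, 2.0f0])
Flux.state(layer)    # a plain namedtuple containing the array
```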

Regardless, new users are unlikely to realize these edge cases. We should expand the saving/loading documentation to explain how to handle these cases with code examples.

darsnack avatar Jun 09 '23 15:06 darsnack

Given that activation functions have always been handled somewhat specially, it might not be a bad idea, even though it is a little against the Flux style, to have an Activation layer that just does this explicitly. We've had this discussion before (https://github.com/FluxML/NNlib.jl/pull/423#issuecomment-1164558904 and https://github.com/FluxML/NNlib.jl/pull/423#issuecomment-1164670060), so perhaps it is time to do this.

theabhirath avatar Jun 10 '23 05:06 theabhirath

An Activation layer won't help if it wraps an anonymous function. It's a wrapper so it just pushes the issue one node deeper in the tree.

The cleaner and correct solution is simply to name the function (e.g. myact(x) = ...). If you are closing over data that needs to be serialized, then define a callable struct instead.
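A quick sketch of the naming approach (the function body and layer sizes are made up for illustration):

```julia
using Flux

# An anonymous activation, which serializers may fail to round-trip:
#   Dense(4 => 4, x -> relu(x) + 0.1f0 * x)

# A named, top-level function instead; only the name must resolve
# when the model is rebuilt:
myact(x) = relu(x) + 0.1f0 * x
layer = Dense(4 => 4, myact)
```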

darsnack avatar Jun 10 '23 12:06 darsnack

I wonder if we could create a helper function which searches the model for these closures and warns the user if it finds them?
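A rough sketch of such a helper, assuming anonymous functions can be detected by the `#` that Julia puts in their compiler-generated type names (this helper does not exist in Flux; it is just the idea above made concrete):

```julia
using Flux, Functors

# Walk the model tree and warn on leaves that look like anonymous
# functions (their generated type names contain '#').
function warn_anonymous(model)
    Functors.fmap(model) do leaf
        if leaf isa Function && occursin('#', String(nameof(typeof(leaf))))
            @warn "Model contains a likely anonymous function" leaf
        end
        leaf
    end
    return nothing
end

warn_anonymous(Chain(Dense(2 => 2, x -> relu(x))))  # should warn
```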

ToucheSir avatar Jun 10 '23 14:06 ToucheSir

Might my issue at https://github.com/FluxML/Flux.jl/issues/2339 be related to this? It contains anonymous functions that slice the input arrays, like x->x[begin:inputpoints, 1, :] for example. How would one go about correctly saving a model like this according to your advice?

tom-plaa avatar Sep 20 '23 13:09 tom-plaa

Just as mentioned up top: extract the parameters with Flux.state and only save those. I suspect we'll be scrubbing any examples that use BSON to save the whole model from the docs soon because it's just too error-prone.
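Concretely for the slicing example, one might name the closure and save only the state (treating `inputpoints` as a constant known at model-definition time; the file name and sizes are illustrative):

```julia
using Flux, JLD2

const inputpoints = 10  # assumed fixed when the model is defined
slice_inputs(x) = x[begin:inputpoints, 1, :]

model = Chain(slice_inputs, Dense(inputpoints => 4))

# Only the numeric state goes to disk; `slice_inputs` lives in code
# and is simply defined again before loading.
jldsave("model_state.jld2"; model_state = Flux.state(model))
```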

ToucheSir avatar Sep 20 '23 13:09 ToucheSir

After checking the docs, this implies that the model definition must be available in the session, right? Is it necessary to create a custom struct and apply the Flux.@functor macro to it before saving (like in the docs)? Must we also repeat these steps before loading it (creating the same struct and applying the macro)? I'm saying this because of this line in the docs: model = MyModel(); # MyModel definition must be available

tom-plaa avatar Sep 20 '23 13:09 tom-plaa

state strips out any custom container types and gives you a tree of plain old Julia objects (tuples, namedtuples, arrays), which should be easier to save and mostly immune to type-related breakage down the line. Any layer types with parameters/non-trainable state do need to support `@functor` for this to work, but you'll need those declarations anyhow, because loadmodel! takes an already-constructed model to stuff the aforementioned tree of plain old Julia objects back into.
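Putting that together for a custom layer type (`MyModel` here is a made-up example mirroring the docs):

```julia
using Flux

struct MyModel{D}
    dense::D
end
Flux.@functor MyModel           # expose fields to state / loadmodel!
MyModel() = MyModel(Dense(2 => 1))

m = MyModel()
s = Flux.state(m)               # nested namedtuples / arrays only

# In a later session the same definitions must exist again, because
# loadmodel! needs an already-constructed destination:
m2 = MyModel()
Flux.loadmodel!(m2, s)
```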

ToucheSir avatar Sep 20 '23 14:09 ToucheSir

Thank you, I managed to make it work with the loadmodel! function. I will update my other issue accordingly. It might be clearer to expand on this in the documentation, mentioning that you need to rebuild the custom struct all over again and apply the functor macro when loading as well.

tom-plaa avatar Sep 20 '23 15:09 tom-plaa

The reason we don't mention that in the docs is the same reason PyTorch doesn't mention that you need to define all the layer types for a model before calling model.load_state_dict(...): if you have a model ready to load into, that means all of the custom layer structs, @functor definitions, etc. must already be present! That said, this issue exists in the first place because the saving and loading docs could use some work, so any suggestions (ideally in the form of PRs) are appreciated :)

ToucheSir avatar Sep 20 '23 19:09 ToucheSir