Flux.jl
Add docs note about saving/loading models with anonymous functions
The new save/load docs promote JLD2.jl, which does not reliably support saving and loading anonymous functions. This most commonly occurs with activation functions. The solution is to use Flux.state + Flux.loadmodel! and set the desired anonymous function in the destination model passed to loadmodel!. This avoids requiring the serialization library to handle the anonymous function correctly.
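To make the workflow above concrete, here is a minimal sketch. The file name, layer sizes, and the leaky slope are illustrative, but Flux.state, Flux.loadmodel!, and JLD2's jldsave are the real APIs:

```julia
using Flux, JLD2

model = Chain(Dense(10 => 5, x -> leakyrelu(x, 0.2)), Dense(5 => 2))

# Save only the numeric state — the anonymous activation is never serialized:
jldsave("checkpoint.jld2"; state = Flux.state(model))

# Later, possibly in a fresh session: rebuild the model (including the
# anonymous activation) and restore the saved parameters into it.
model2 = Chain(Dense(10 => 5, x -> leakyrelu(x, 0.2)), Dense(5 => 2))
Flux.loadmodel!(model2, JLD2.load("checkpoint.jld2", "state"))
```

The key point is that the anonymous function lives only in the code that constructs the destination model, never in the saved file.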
This becomes problematic when the anonymous function closes over data (state) that actually should be restored. A possible solution is to make the closure an explicit struct. Maybe there are better solutions.
Regardless, new users are unlikely to be aware of these edge cases. We should expand the saving/loading documentation to explain how to handle them, with code examples.
Activation functions have always been handled a bit awkwardly, so even though it goes slightly against the Flux style, it might not be a bad idea to have an Activation layer that just does this explicitly. We've had this discussion before (https://github.com/FluxML/NNlib.jl/pull/423#issuecomment-1164558904 and https://github.com/FluxML/NNlib.jl/pull/423#issuecomment-1164670060), so perhaps it's simply time to do it.
An Activation layer won't help if it wraps an anonymous function: being a wrapper, it just pushes the issue one node deeper in the model tree.
The cleaner and correct solution is simply to name the function (e.g. myact(x) = ...). If you are closing over data that needs to be serialized, then define a callable struct.
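A sketch of both options; the names myact and Scale are illustrative:

```julia
# Named function: serializes by name, no closure involved.
myact(x) = clamp(x, 0, 6)

# Explicit callable struct instead of a closure over `w` —
# the captured data becomes an ordinary field that Flux.state can see.
struct Scale{T}
    w::T
end
(s::Scale)(x) = s.w .* x   # callable, just like the closure was

s = Scale(3.0)
s(2.0)
```

For the struct variant to round-trip through Flux.state / Flux.loadmodel!, you would additionally mark it with Flux.@functor so its field is treated as part of the model's state.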
I wonder if we could create a helper function that searches the model for these closures and warns the user when it finds them?
Might my issue at https://github.com/FluxML/Flux.jl/issues/2339 be related to this? It contains anonymous functions that slice the input arrays, like x->x[begin:inputpoints, 1, :] for example. How would one go about correctly saving a model like this, following your advice?
Just as mentioned up top: extract the parameters with Flux.state and only save those. I suspect we'll be scrubbing any examples that use BSON to save the whole model from the docs soon because it's just too error-prone.
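Applied to the slicing case, the sketch below names the closure and saves only the state; inputpoints, the layer size, and the file name are illustrative:

```julia
using Flux, JLD2

inputpoints = 8
slice(x) = x[begin:inputpoints, 1, :]   # named, instead of x -> x[...]

model = Chain(slice, Dense(inputpoints => 4))
jldsave("model_state.jld2"; state = Flux.state(model))

# On load: the same definitions must be in scope, then restore the weights.
model2 = Chain(slice, Dense(inputpoints => 4))
Flux.loadmodel!(model2, JLD2.load("model_state.jld2", "state"))
```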
After checking the docs, this implies that the model definition must be available in the session, right? Is it necessary to create a custom struct and apply the Flux.@functor macro to it before saving (like in the docs)? Must we also repeat these steps before loading it (creating the same struct and applying the macro)? I'm saying this because of this line in the docs:
model = MyModel(); # MyModel definition must be available
state strips out any custom container types and gives you a tree of plain old Julia objects (tuples, namedtuples, arrays), which should be easier to save and mostly immune to type-related breakages down the line. Any layer types with parameters or non-trainable state must support functor for this to work, but you'll need those declarations anyhow, because loadmodel! takes an already-constructed model to stuff that tree of plain old Julia objects back into.
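For illustration, here is roughly what that tree looks like for a tiny model (the exact nesting shown in the comment is indicative, not guaranteed across Flux versions):

```julia
using Flux

m = Chain(Dense(2 => 3, relu))
s = Flux.state(m)
# s is a nested NamedTuple along the lines of
#   (layers = ((weight = Float32[...], bias = Float32[...], σ = ()),),)
# — only arrays, tuples and namedtuples, with non-numeric leaves such as
# the activation replaced by (). No layer structs appear, so the saved
# file can be read back without the layer types being loadable.
```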
Thank you, I managed to make it work with the loadmodel! function. I will update my other issue accordingly. It might be clearer to expand on this in the documentation, mentioning that you need to rebuild the custom struct and apply the functor macro when loading as well.
The reason we don't mention that in the docs is the same reason PyTorch doesn't mention that you need to define all the layer types for a model before calling model.load_state_dict(...): if you already have a model to load into, then all of the custom layer structs, @functor definitions, etc. must already be present! That said, this issue exists in the first place because the saving and loading docs could use some work, so any suggestions (ideally in the form of PRs) are appreciated :)