Functors.jl
Add more Base types
This should just work, right? https://discourse.julialang.org/t/data-science-lessons-making-10-neural-networks-run-on-gpu/74592
This is a bit of a drive-by comment; some of this might already be addressed. See some earlier discussion in #21. I think in general this would make sense to do, but we need to think carefully about what should be a leaf (i.e. not functor'd) and about the ergonomics of users overriding something Functors.jl says is not a leaf to be a leaf.
Thanks, I didn't see that, have not followed this package closely.
I'm not actually too sure why it requires types to be annotated, rather than the reverse, of recursing into anything unless told not to. That would seem more convenient at least for what Flux wants.
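For context, a minimal sketch of the opt-in annotation as it works today (`Affine` is a made-up example type; without the `@functor` line, `fmap` would treat the whole struct as a leaf):

```julia
using Functors

struct Affine{T,S}
    W::T
    b::S
end

# Opt in: tell Functors.jl to recurse into Affine's fields.
# Under the alternative design suggested above, this annotation
# would be unnecessary and recursion would be the default.
@functor Affine

a = Affine([1.0 2.0; 3.0 4.0], [5.0, 6.0])
doubled = fmap(x -> 2x, a)  # rebuilds an Affine with both arrays doubled
```

Note that `fmap` rebuilds `Affine` via its default constructor, which is exactly where the "things without default constructors" concern below comes in.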
I guess you have the potential to run into things without default constructors... but at least that's a noisy failure not a silent don't-do-what-I-asked. I thought Setfield worked around this but it seems I was mistaken:
```julia
julia> using Setfield

julia> struct Foo{T,S} x::T; y::S end;

julia> foo = Foo(1, 2); @set foo.x = "not an integer"
Foo{String, Int64}("not an integer", 2)

julia> struct Bar{T,S}
           x::T
           y::S
           Bar(x::T) where T = new{T,T}(x)
       end;

julia> bar = Bar(1); @set bar.x = "not an integer"
ERROR: MethodError: no method matching Bar(::String, ::Int64)
```
Setfield.jl relies on ConstructionBase.jl, which has the benefit of being a standardized interface, but it is not fancy (read: crazy) enough to make Bar work automatically. We could try emulating what BSON.jl does to resuscitate types during deserialization, but the level of black magic required there has caused more than one headache in the past.
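To make the fallback concrete: ConstructionBase's `setproperties` rebuilds a struct through `constructorof`, which by default just calls the type's name as a constructor. That is why `Foo` works (the default constructor exists and can widen the type parameters) and `Bar` cannot:

```julia
using ConstructionBase

struct Foo{T,S}
    x::T
    y::S
end

foo = Foo(1, 2)
# setproperties reconstructs via Foo("not an integer", 2), which only
# succeeds because Foo's default constructor was never replaced.
foo2 = setproperties(foo, x = "not an integer")
```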
Ok. Without going that far, the desirable Base types to handle seem to be:
- Iterators, like those in the issue above. These all have default constructors.
- Transpose, Adjoint, and maybe ReshapedArray and PermutedDimsArray, since shared weights will often be transposed.

Shared weights with Transpose seem somewhat tricky, since the gradient of a Transpose will often be a plain Matrix, so you need to know the inverse transformation. But it's not impossible, and it seems a smaller change than (say) running the cache based on pointers?
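A small sketch of the inverse-transformation point (names like `dW_contrib` are illustrative, not package code): if the gradient for `transpose(W)` arrives as a plain Matrix, accumulating it into the shared `W` means undoing the wrapper first.

```julia
using LinearAlgebra

W   = rand(2, 3)
Wt  = transpose(W)   # shares memory with W
dWt = ones(3, 2)     # a gradient w.r.t. Wt often comes back as a plain Matrix

# To add this contribution to W's gradient, apply the inverse transformation,
# so the contribution has W's shape and layout:
dW_contrib = transpose(dWt)
```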
Is it a breaking change to add types to recurse into?
> Transpose, Adjoint, and maybe ReshapedArray, PermutedDimsArray. For the reason that shared weights will often be transposed.
Depending on the shared weight solution, these types might be ones we don't want to recurse into. I'd suggest adding these to the list when we merge something complete w.r.t. shared weights.
I think such recursion is the only change needed in Functors for e.g. https://github.com/FluxML/Optimisers.jl/pull/42 (maybe with a few local changes...) to work with transposed shared weights.
Functors already tries to be aware of sharing, it just won't notice it if it's hidden by a wrapper.
A bigger change would be to teach it to notice despite the wrapper, but without just recursing inwards. That seems harder, e.g. it would have to keep track of the bijection which relates the two tied branches.
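Roughly, identity-based caching of the kind being contrasted here looks like the sketch below (a simplification; the real walk caches every node it visits, not just arrays). The point is that an `IdDict` keyed on object identity catches `W === W`, but a fresh wrapper is a distinct object:

```julia
using LinearAlgebra

cache = IdDict{Any,Bool}()
seen!(x) = haskey(cache, x) ? true : (cache[x] = true; false)

W = rand(2, 2)
seen!(W)   # false: first visit
seen!(W)   # true: identical object, sharing detected

# But a wrapper hides the sharing from an identity check:
Wt = transpose(W)
seen!(Wt)          # false: the Transpose is a distinct object
parent(Wt) === W   # true: only unwrapping reveals the tie
```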
One headache with auto-detecting shared params is figuring out which is the "original". Currently we're assuming that any wrappers are derived arrays, but a user could construct a model as encoder = Dense(W'); decoder = Dense(encoder.weight'). I believe all the current rules are dimension-agnostic, but any that aren't will get funny with tied weights that way. I guess we could just put up a warning that wrappers are considered tied by default?
The scheme in FluxML/Optimisers.jl#42 is to keep the first one it encounters as original. The unwrapped Array, not the Transpose if any. (Edit: an example here: https://github.com/FluxML/Optimisers.jl/pull/42/files#diff-3b9314a6f9f2d7eec1d0ef69fa76cfabafdbe6d0df923768f9ec32f27a249c63R205-R212)
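The "unwrapped Array as original" rule can be sketched like this (`canonical` is a hypothetical helper name, not the PR's actual function):

```julia
using LinearAlgebra

# Hypothetical helper: resolve a (possibly wrapped) weight to the
# underlying array that should be treated as the original copy.
canonical(x) = x
canonical(x::Union{Transpose,Adjoint}) = canonical(parent(x))

W = rand(2, 2)
canonical(transpose(W)) === W   # the wrapper resolves to the shared parent
canonical(W') === W             # likewise for Adjoint
```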
Optimisers.jl's rules all just treat the matrix as a list of numbers. If you were to write something else... in this scheme it's going to get the inner array. Presumably nobody is so perverse as to wrap the only copy of an unshared array; and if the arrays are shared, there is one optimiser state, and it can't be both ways around.