
Add more Base types

Open mcabbott opened this issue 3 years ago • 8 comments

This should just work, right? https://discourse.julialang.org/t/data-science-lessons-making-10-neural-networks-run-on-gpu/74592

mcabbott avatar Jan 14 '22 16:01 mcabbott

This is a bit of a drive-by comment; some of this might already be addressed. See some earlier discussion in #21. I think in general this would make sense to do, but we need to think carefully about what should be a leaf (i.e. not functor'd) and about the ergonomics of users overriding something Functors.jl says is not a leaf to be a leaf.
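For illustration, the opt-in design being discussed can be sketched in a few lines. This is a toy, not Functors.jl's actual implementation; the names `toy_fmap` and `is_registered` are made up for the sketch:

```julia
# Toy sketch of the opt-in design: recurse only into explicitly
# registered types; everything else is a leaf handed to `f`.
# (Hypothetical names, not part of Functors.jl's API.)
is_registered(::Type) = false

struct MyLayer
    weight::Vector{Float64}
end
is_registered(::Type{MyLayer}) = true

function toy_fmap(f, x)
    if is_registered(typeof(x))
        vals = map(n -> toy_fmap(f, getfield(x, n)), fieldnames(typeof(x)))
        typeof(x)(vals...)   # relies on a default constructor existing
    else
        f(x)
    end
end

m = MyLayer([1.0, 2.0])
m2 = toy_fmap(v -> v .* 2, m)   # MyLayer([2.0, 4.0])
```

Note that the `typeof(x)(vals...)` reconstruction step is exactly where types without default constructors break, which is the failure mode discussed further down the thread.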

darsnack avatar Jan 14 '22 17:01 darsnack

Thanks, I didn't see that; I have not followed this package closely.

I'm not actually too sure why it requires types to be annotated, rather than the reverse: recursing into anything unless told not to. That would seem more convenient, at least for what Flux wants.

I guess you have the potential to run into things without default constructors... but at least that's a noisy failure, not a silent don't-do-what-I-asked. I thought Setfield worked around this, but it seems I was mistaken:

julia> using Setfield

julia> struct Foo{T,S} x::T; y::S end;

julia> foo = Foo(1,2); @set foo.x = "not an integer"
Foo{String, Int64}("not an integer", 2)

julia> struct Bar{T,S} x::T; y::S;
       Bar(x::T) where T = new{T,T}(x)
       end;

julia> bar = Bar(1); @set bar.x = "not an integer"
ERROR: MethodError: no method matching Bar(::String, ::Int64)

mcabbott avatar Jan 14 '22 18:01 mcabbott

Setfield.jl relies on ConstructionBase.jl, which has the benefit of being a standardized interface but is not fancy (read: crazy) enough to make Bar work automatically. We could try emulating what BSON.jl does to resuscitate types during deserialization, but the level of black magic required there has caused more than one headache in the past.

ToucheSir avatar Jan 30 '22 04:01 ToucheSir

Ok. Without going so far, the desirable Base types to handle seem to be

  • Iterators, like the above issue. These all have default constructors.
  • Transpose, Adjoint, and maybe ReshapedArray, PermutedDimsArray. For the reason that shared weights will often be transposed.
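As a quick check of the first point, Base iterators do keep their state in plain fields and reconstruct cleanly with their default constructors. A sketch, using `Base.Generator` as the example:

```julia
# Base iterators store their pieces as plain fields and have default
# constructors, so a Functors-style rebuild "just works":
it = Base.Generator(x -> x^2, [1, 2, 3])

# Walk the children, transform one, and reconstruct -- the same
# map-fields-then-rebuild pattern `functor` would use:
children = (it.f, it.iter)
rebuilt  = Base.Generator(children[1], children[2] .* 2)
collect(rebuilt)   # squares of [2, 4, 6]
```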

Shared weights with Transpose seem somewhat tricky, since the gradient of a Transpose will often be a plain Matrix, so you need to know the inverse transformation. But not impossible, and it seems a smaller change than (say) keying the cache on pointers?
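The inverse-transformation point can be illustrated with plain LinearAlgebra code. This is only a sketch; the gradient values here are made-up numbers standing in for whatever an AD system would produce:

```julia
using LinearAlgebra

W  = [1.0 2.0; 3.0 4.0]
Wt = transpose(W)        # lazy wrapper; parent(Wt) === W

# Suppose a gradient for Wt arrives as a plain Matrix (as AD often
# produces).  To accumulate it into the shared parent W, the inverse
# transformation -- here just another transpose -- must be applied:
dWt = [10.0 20.0; 30.0 40.0]
dW  = transpose(dWt)     # maps the gradient back into W's layout

parent(Wt) === W         # the wrapper shares storage with W
```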

Is it a breaking change to add types to recurse into?

mcabbott avatar Jan 31 '22 14:01 mcabbott

> Transpose, Adjoint, and maybe ReshapedArray, PermutedDimsArray. For the reason that shared weights will often be transposed.

Depending on the shared weight solution, these types might be ones we don't want to recurse into. I'd suggest adding these to the list when we merge something complete w.r.t. shared weights.

darsnack avatar Jan 31 '22 15:01 darsnack

I think such recursion is the only change needed in Functors for e.g. https://github.com/FluxML/Optimisers.jl/pull/42 (maybe with a few local changes...) to work with transposed shared weights.

Functors already tries to be aware of sharing; it just won't notice it if it's hidden by a wrapper.

A bigger change would be to teach it to notice sharing despite the wrapper, but without just recursing inwards. That seems harder; e.g. it would have to keep track of the bijection which relates the two tied branches.

mcabbott avatar Jan 31 '22 15:01 mcabbott

One headache with auto-detecting shared params is figuring out which is the "original". Currently we're assuming that any wrappers are derived arrays, but a user could construct a model as encoder = Dense(W'); decoder = Dense(encoder.weight'). I believe all the current rules are dimension-agnostic, but any that aren't would interact badly with tied weights that way. I guess we could just put up a warning that wrappers are considered tied by default?
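Stripped of the Flux layers, the ambiguity is visible with bare arrays (here `enc` and `dec` stand in for the encoder and decoder weights):

```julia
using LinearAlgebra

W   = rand(3, 3)
enc = W'      # Adjoint wrapper around W
dec = enc'    # the adjoint of an Adjoint just unwraps

# Structurally, dec *is* the bare W again, so "which array is the
# original?" cannot be answered from the data structures alone:
dec === W     # true for real element types
```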

ToucheSir avatar Jan 31 '22 18:01 ToucheSir

The scheme in FluxML/Optimisers.jl#42 is to keep the first one it encounters as original. The unwrapped Array, not the Transpose if any. (Edit: an example here: https://github.com/FluxML/Optimisers.jl/pull/42/files#diff-3b9314a6f9f2d7eec1d0ef69fa76cfabafdbe6d0df923768f9ec32f27a249c63R205-R212)

Optimisers' rules all just treat the matrix as a list of numbers. If you were to write something else... in this scheme it's going to get the inner array. Presumably nobody is so perverse as to wrap the only copy of an unshared array, and if the arrays are shared, there is one optimiser state, and it can't be both ways around.
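A toy version of that first-encountered-wins scheme might look like the following. This is my sketch, not Optimisers.jl's actual code; `unwrap` and `setup_states` are made-up names:

```julia
using LinearAlgebra  # for the Transpose and Adjoint wrapper types

# Strip known wrappers so W and transpose(W) share one cache key:
unwrap(x::Union{Transpose, Adjoint}) = parent(x)
unwrap(x) = x

function setup_states(arrays)
    cache = IdDict{Any,Any}()
    for (i, a) in enumerate(arrays)
        # first sighting wins; later wrappers of the same parent reuse it
        get!(cache, unwrap(a), i)
    end
    cache
end

W = rand(2, 2)
states = setup_states([W, transpose(W)])
length(states)   # 1: both entries share one underlying Array
```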

mcabbott avatar Jan 31 '22 19:01 mcabbott