
Float32 parameters in structs unsupported?

Open MichaelBurge opened this issue 3 years ago • 1 comment

The tutorial introduces an Affine struct with 1-dimensional and 2-dimensional arrays as fields, and describes how to train it:

julia> import Flux

julia> struct Affine W; b end

julia> a = Affine(rand(3,3), rand(3))
Affine([0.03639048554521607 0.5361647971536908 0.6897849770910784; 0.5814757570729326 0.7659888335874412 0.9207516843236588; 0.05617376870992363 0.963698380772796 0.2242154150237281], [0.06022121006518244, 0.4059827905272061, 0.6536746552971661])

julia> Flux.params(a)
Params([])

julia> Flux.@functor Affine

julia> Flux.params(a)
Params([[0.03639048554521607 0.5361647971536908 0.6897849770910784; 0.5814757570729326 0.7659888335874412 0.9207516843236588; 0.05617376870992363 0.963698380772796 0.2242154150237281], [0.06022121006518244, 0.4059827905272061, 0.6536746552971661]])
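
For context, a minimal, hedged sketch of what "training it" looks like with implicit params (the forward-pass method, the input x, and the loss below are my own illustration, not part of the tutorial text quoted above):

julia> (m::Affine)(x) = m.W * x .+ m.b  # forward pass

julia> x = rand(3);

julia> gs = Flux.gradient(() -> sum(a(x)), Flux.params(a));  # implicit-params gradient

julia> size(gs[a.W])  # gradients are keyed by the parameter arrays
(3, 3)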

I tried a similar struct holding a scalar Float32 field:

julia> import Flux
julia> struct Foo foo::Float32 end

julia> foo = Foo(3.3)
Foo(3.3f0)

julia> Flux.params(foo)
Params([])

julia> Flux.@functor Foo

julia> Flux.params(foo)
Params([])

julia> Flux.trainable(foo)
(foo = 3.3f0,)

Here I expected the call to Flux.params(foo) after Flux.@functor Foo to return the scalar parameter, by analogy with the Affine case. The field foo is listed in the output of Flux.trainable, so Flux.params is doing some post-processing that excludes it.
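
To make that post-processing visible, here is a hedged REPL sketch (the Bar struct is my own illustration, not from Flux): an array field survives the walk, while a sibling scalar field is reported by trainable but dropped by params:

julia> struct Bar w::Vector{Float32}; s::Float32 end

julia> Flux.@functor Bar

julia> bar = Bar([1f0, 2f0], 3f0);

julia> Flux.trainable(bar)   # both fields show up here
(w = Float32[1.0, 2.0], s = 3.0f0)

julia> Flux.params(bar)      # but only the array is collected
Params([Float32[1.0, 2.0]])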

I can add a specialization of Flux.params! for Number to fix this example in the REPL:

julia> import Flux
julia> struct Foo foo::Float32 end
julia> foo = Foo(3.3)
julia> Flux.@functor Foo
julia> Flux.params!(p::Flux.Params, x::Number, seen = IdSet()) = push!(p, x)

julia> Flux.params(foo)
Params([3.3])

This example suggests three possible improvements:

  1. Should a specialization of Flux.params! for Number be added?
  2. Should Flux.params throw an error if it encounters a type with no Flux.params! specialization?
  3. Should Flux.params throw an error if a type hasn't had Flux.@functor run on it yet (and there is no manual implementation of whatever the macro would normally output)?

MichaelBurge avatar Dec 26 '21 17:12 MichaelBurge

Thanks for the issue; my points are below:

  1. No, because IdDicts and IdSets can only track mutable reference types like arrays. That's why https://github.com/FluxML/Optimisers.jl exists: we've realized that implicit params have some fundamental limitations and are trying to wire up the rest of the stack to use explicit ones instead (see the sketch after this list).
  2. It ought to, but at this point it would be a huge compat break because it hasn't done so thus far. Changing core Flux layers to not present scalar fields as params is one thing, but doing so for the whole ecosystem would be challenging.
  3. No, because it would be massively breaking for little benefit. Say you had a struct for a custom activation function with its own trainable params: you'd still want it to be picked up when used inside, say, a Dense layer, but you wouldn't want plain relu to throw an error just because it was never marked with the macro. Functors.@functor is explicitly opt-in for that reason. That said, Functors is already smart enough to pass through types it doesn't understand; see https://github.com/FluxML/Functors.jl/blob/master/src/functor.jl#L1-L2.
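
As a small illustration of the explicit style mentioned in point 1 (my own sketch, not from the thread): with explicit gradients the scalar field comes back as part of a nested NamedTuple keyed by field name, so no IdDict/IdSet-based tracking is needed:

julia> Flux.gradient(f -> f.foo^2, foo)  # gradient with respect to the struct itself
((foo = 6.6f0,),)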

Edit: note that Functors.functor(foo)[1] == (foo = 3.3f0,).
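
For completeness, a hedged sketch of that call, and of fmap mapping over the scalar (assuming the Foo and foo definitions above; Functors.jl is a Flux dependency and may need to be added to the environment to import it directly):

julia> import Functors

julia> Functors.functor(foo)[1]      # the scalar is visible to Functors
(foo = 3.3f0,)

julia> Functors.fmap(x -> 2x, foo)   # and fmap maps over it, rebuilding the struct
Foo(6.6f0)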

ToucheSir avatar Dec 26 '21 20:12 ToucheSir