Flux.jl
Float32 parameters in structs unsupported?
The tutorial introduces an Affine struct with 1-dimensional and 2-dimensional arrays as fields, and describes how to train it:
julia> import Flux
julia> struct Affine W; b end
julia> a = Affine(rand(3,3), rand(3))
Affine([0.03639048554521607 0.5361647971536908 0.6897849770910784; 0.5814757570729326 0.7659888335874412 0.9207516843236588; 0.05617376870992363 0.963698380772796 0.2242154150237281], [0.06022121006518244, 0.4059827905272061, 0.6536746552971661])
julia> Flux.params(a)
Params([])
julia> Flux.@functor Affine
julia> Flux.params(a)
Params([[0.03639048554521607 0.5361647971536908 0.6897849770910784; 0.5814757570729326 0.7659888335874412 0.9207516843236588; 0.05617376870992363 0.963698380772796 0.2242154150237281], [0.06022121006518244, 0.4059827905272061, 0.6536746552971661]])
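For reference, training then goes through those params; below is a minimal sketch of my own, assuming a forward pass along the lines of the tutorial's (the exact tutorial code may differ):
julia> (m::Affine)(x) = m.W * x .+ m.b  # forward pass
julia> x = rand(3);
julia> gs = Flux.gradient(() -> sum(a(x)), Flux.params(a));  # implicit-params gradient
julia> size(gs[a.W])  # each array in params(a) gets a gradient entry
(3, 3)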
I tried a similar struct holding a plain Float32 scalar:
julia> import Flux
julia> struct Foo foo::Float32 end
julia> foo = Foo(3.3)
Foo(3.3f0)
julia> Flux.params(foo)
Params([])
julia> Flux.@functor Foo
julia> Flux.params(foo)
Params([])
julia> Flux.trainable(foo)
(foo = 3.3f0,)
Here, I expect the call to Flux.params(foo) after Flux.@functor Foo to pick up the field, by analogy with the Affine case. The parameter foo is listed in the output of Flux.trainable, so Flux.params is doing some postprocessing that excludes it.
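That reading is supported by the generic Functors traversal, which does reach the scalar; a quick check of my own (not part of the original report), using the fmap that Flux pulls in from Functors:
julia> Flux.fmap(x -> 2x, foo)  # the Float32 leaf is visited and doubled
Foo(6.6f0)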
I can add a specialization for Flux.params! to fix this example in the REPL:
julia> import Flux
julia> struct Foo foo::Float32 end
julia> foo = Foo(3.3)
julia> Flux.@functor Foo
julia> Flux.params!(p::Flux.Params, x::Number, seen = IdSet()) = push!(p, x)
julia> Flux.params(foo)
Params([3.3])
This example raises 3 possible improvements:
- Should a specialization for Number be added for Flux.params!?
- Should Flux.params throw an error if it encounters a type with no Flux.params! specialization?
- Should Flux.params throw an error if a type hasn't had Flux.@functor run on it yet (and there is no manual implementation of whatever the macro would normally output)?
Thanks for the issue, my points below:
- No, because IdDicts and IdSets can only track mutable reference types like arrays. That's why https://github.com/FluxML/Optimisers.jl exists: we've realized that implicit params have some fundamental limitations and are trying to wire up the rest of the stack to use explicit ones instead (a brief sketch of that explicit style follows at the end of this comment).
- It ought to, but at this point it would be a huge compat break because it hasn't done so thus far. Changing core Flux layers to not present scalar fields as params is one thing, but doing so for the whole ecosystem would be challenging.
- No, because it would be massively breaking for little benefit. Say you had a struct for a custom activation function with its own trainable params. You'd still want that to be picked up if it were used in, say, a Dense layer, but you wouldn't want, say, using relu instead to throw an error because of that. Functors.@functor is explicitly opt-in for that reason. That said, Functors is already smart enough to pass through types it doesn't understand, see https://github.com/FluxML/Functors.jl/blob/master/src/functor.jl#L1-L2.
Edit: note that Functors.functor(foo)[1] == (3.3f0,).
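To make the explicit route in the first point concrete, here is a small sketch of my own (not from the thread): taking the gradient with respect to the struct itself, rather than through Params, already reaches the scalar field and returns a structural NamedTuple, the style Optimisers.jl is built around.
julia> Flux.gradient(m -> m.foo^2, foo)  # explicit gradient, no Params involved
((foo = 6.6f0,),)
Because nothing here relies on tracking parameters by object identity in an IdDict, the immutable Float32 field is no longer a problem.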