Flux.jl
Correct counting of shared parameters in `Base.show`.
This PR fixes the `show` function by correctly counting non-trainable parameters. The earlier counting code duplicated shared parameters in its count (https://github.com/FluxML/Flux.jl/blob/ebcbbe495e45c68c84382b5ee0282fe9edf441b8/src/layers/show.jl#L96), and hence some shared trainable parameters were being counted as non-trainable. The change is in the `_big_finale` function: instead of duplicating the counts, we use an `IdSet` to keep track of which parameters have already been counted, so that no parameter is counted twice.
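For illustration, here is a minimal sketch of the identity-based counting idea (not the PR's actual `_big_finale` code; `count_distinct_params` is a hypothetical helper, and `Base.IdSet` is an unexported Base type used here in place of the `IdSet` mentioned above):

```julia
using Flux, Functors

# Hypothetical helper: walk the model and count each parameter array only the
# first time it is seen, so tied layers do not inflate the total.
function count_distinct_params(model)
    seen = Base.IdSet{Any}()      # identity-based set, playing the role of the PR's IdSet
    total = 0
    walk(x) = if x isa AbstractArray{<:Number}
        if !(x in seen)           # skip arrays we have already counted
            push!(seen, x)
            total += length(x)
        end
    else
        foreach(walk, Functors.children(x))
    end
    walk(model)
    return total
end

d = Dense(10 => 10)
count_distinct_params(Chain(Embedding(10, 10), d, d))   # 210, not 320
```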
As an example, the following code now shows the correct output:
```julia
julia> using Flux;

julia> d = Dense(10 => 10);

julia> shared_layer = Chain(Embedding(10, 10), d, d)
Chain(
  Embedding(10 => 10),                  # 100 parameters
  Dense(10 => 10),                      # 110 parameters
  Dense(10 => 10),                      # 110 parameters
)                   # Total: 3 arrays, 210 parameters, 1.055 KiB.

julia> normal_layer = Chain(Embedding(10, 10), Dense(10 => 10), Dense(10 => 10))
Chain(
  Embedding(10 => 10),                  # 100 parameters
  Dense(10 => 10),                      # 110 parameters
  Dense(10 => 10),                      # 110 parameters
)                   # Total: 5 arrays, 320 parameters, 1.562 KiB.
```
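As a cross-check (not part of the PR, just a sanity check): `Flux.params` already deduplicates shared arrays by identity, so its counts should agree with the corrected totals above:

```julia
julia> length(Flux.params(shared_layer)), sum(length, Flux.params(shared_layer))
(3, 210)

julia> length(Flux.params(normal_layer)), sum(length, Flux.params(normal_layer))
(5, 320)
```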
TODO:
- [ ] Add tests.
- [ ] Add an example in the docs for shared parameters?
Closes #2321.
If this looks good, I'll go ahead and add some tests and add an example in the documentation as well.
The documentation CI is failing for an `RNN`. I assume the output there is also incorrect? It's probably showing the `state` parameter as non-trainable (which is trainable, right?)
> It's probably showing the `state` parameter as non-trainable (which is trainable, right?)
That's `Recur.state`, which should be non-trainable. Note how only `cell` is included below:

```julia
julia> Flux.trainable(RNN(2 => 5))
(cell = RNNCell(2 => 5, tanh),)
```
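To make the trainable/non-trainable split concrete, a quick REPL sketch (assuming `Recur` has the fields `cell` and `state`, as I believe it does):

```julia
julia> using Flux, Functors

julia> r = RNN(2 => 5);

julia> keys(Functors.children(r))   # all children of the Recur wrapper
(:cell, :state)

julia> keys(Flux.trainable(r))      # only the cell is marked trainable
(:cell,)
```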
I see, yes, that makes sense. I think I understand now why it's not showing any non-trainable parameters for `RNN(2 => 5)`: this is because both the initial state (`state0` of the `cell`) and `Recur.state` are initialized to the zero matrix (and hence pushing both of these matrices to the `IdSet` just pushes one matrix instead of two). ~~Instead of this, we'll have to push names of parameters to the `IdSet` as well (to distinguish between two distinct parameters having the same value).~~ Even pushing names of parameters might not work, since two layers can share the same parameter name and the same parameter values and still be different.
~~Just to confirm: is it true that all parameters in Flux (i.e., `Functors.children(m)`, where `m` is some layer) have unique names associated with them? If not, I don't immediately see a way of counting the total number of distinct parameters.~~
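(For what it's worth, the deduplication seems to happen because a freshly constructed `Recur` stores its cell's `state0` directly, so the two are the same array; a quick REPL check, assuming that aliasing holds:)

```julia
julia> r = RNN(2 => 5);

julia> r.state === r.cell.state0   # same array, so the IdSet sees one object
true
```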
Yes, tied parameters are tricky, as we found out while working on Optimisers.jl. Sometimes it feels like a philosophical question. Do we consider array wrappers like `Adjoint` and `Transpose` as aliases? Which wrappers in particular? What about `reshape`s of an `Array`, which share the same data but have different `objectid`s and thus aren't caught by using an `IdSet`? It's not an easy problem, but this PR is a good start.
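A small REPL illustration of why identity alone misses these cases (using only Base; `Base.mightalias` is an unexported helper for conservative aliasing checks):

```julia
julia> W = rand(3, 3);

julia> V = reshape(W, 9);        # shares W's memory...

julia> V === W                   # ...but is a different object, so an IdSet sees two
false

julia> Base.mightalias(W, V)     # they do alias, though
true

julia> Wt = transpose(W);        # a Transpose wrapper around the same data

julia> Wt === W, parent(Wt) === W
(false, true)
```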
Taking inspiration from `Flux.params!`, I tried to push the whole layer to the `IdSet` instead of just `AbstractArray`s, and that seems to be giving correct results. How does it look now?
I believe that'd run into the same problem with shared params across nominally different layers. Maybe one idea would be to separately count the number of shared params and report that?
Can we farm more of this out to Functors / Optimisers? Instead of building an `IdSet` by hand, let Functors cache things. Then this will inherit its understanding of `Adjoint` etc.
(I believe Optimisers.jl has a trainable-only walk definition, since it owns that concept.)
> I believe that'd run into the same problem with shared params across nominally different layers. Maybe one idea would be to separately count the number of shared params and report that?
Hi @ToucheSir, could you explain the "nominally different layers" part? I didn't quite follow it. Maybe an example?
> Can we farm more of this out to Functors / Optimisers? Instead of building an `IdSet` by hand, let Functors cache things. Then this will inherit its understanding of `Adjoint` etc.
> (I believe Optimisers.jl has a trainable-only walk definition, since it owns that concept.)
Sure; I'll take a look at both Functors and Optimisers more closely.
Something like this:
```julia
d1 = Dense(3 => 4)
d2 = Dense(d1.weight)

d1.weight === d2.weight  # tied
d1 !== d2                # but pushing the whole layer won't capture that
```
I see, yes, that makes sense. I think any solution which counts distinct (or shared) parameters in a model must use some form of unique ID associated with each parameter (I can't think of other ways at the moment; maybe there are more clever ones). Can we somehow associate such an ID to every parameter in a Flux model? Or, more generally, associate some metadata with each leaf of a struct?
Functors uses a cache which should detect such sharing. It's a little smarter than just using `objectid`, so as not to catch immutable objects which are accidentally `===`.
```julia
julia> using Functors, Flux

julia> let mat = rand(2,2)
           model = Chain(Dense(mat), Dense(mat'))  # separate bias vectors
           cnt = Ref(0)
           fmapstructure(model; exclude = x -> x isa Array) do x
               cnt[] += 1
           end
       end
(layers = ((weight = 1, bias = 2, σ = ()), (weight = (parent = 1,), bias = 3, σ = ())),)

julia> using StaticArrays

julia> [1,2] === [1,2]  # different arrays with same value, not shared
false

julia> SA[1,2] === SA[1,2]  # here the value is the only identity... still not shared.
true

julia> let mat = @SMatrix rand(2,2)
           model = Chain(Dense(mat), Dense(mat'))  # still has mat === mat
           cnt = Ref(0)
           fmapstructure(model; exclude = x -> x isa AbstractArray) do x
               cnt[] += 1
           end
       end
(layers = ((weight = 1, bias = 2, σ = ()), (weight = 3, bias = 4, σ = ())),)
```
I think `fmap` like this ought to be equivalent to `Flux.params`. But the trainable count needs a modified walk to exclude some children.
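For instance, a rough sketch of such a modified walk (not Flux's actual implementation; `count_trainable` is a hypothetical helper that recurses through `Flux.trainable`, whose definition Optimisers.jl owns, instead of through all children, and it still deduplicates only by identity, so it misses the wrapper/`reshape` cases discussed above):

```julia
using Flux

# Hypothetical helper: count trainable parameters only, visiting each shared
# array once.  Recursing via Flux.trainable skips non-trainable children such
# as Recur.state.
function count_trainable(model)
    seen = Base.IdSet{Any}()
    total = 0
    walk(x) = if x isa AbstractArray{<:Number}
        if !(x in seen)
            push!(seen, x)
            total += length(x)
        end
    else
        foreach(walk, Flux.trainable(x))
    end
    walk(model)
    return total
end

count_trainable(RNN(2 => 5))   # counts only what Flux.trainable exposes
```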