Correct counting of shared parameters in `Base.show`.

Open · codetalker7 opened this pull request · 12 comments

This PR fixes the show function by correctly counting non-trainable parameters. The earlier counting code duplicated shared parameters in its count (https://github.com/FluxML/Flux.jl/blob/ebcbbe495e45c68c84382b5ee0282fe9edf441b8/src/layers/show.jl#L96), and hence some shared trainable parameters were being counted as non-trainable. The change is in the _big_finale function: instead of duplicating the counts, we use an IdSet to keep track of which parameters have already been counted, so that no parameter is counted twice.
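For intuition, here is a minimal sketch of that idea (my own illustration, not the code in this PR; the helper name count_params is made up): collect every numeric array reachable through Functors.children, and use an identity-keyed container so that an array shared between layers is counted only once.

using Flux, Functors

# Recursively gather numeric arrays; the IdDict is keyed on object identity,
# so a parameter array shared between layers contributes to the total once.
function count_params(model)
    seen = IdDict{Any,Nothing}()
    total = 0
    function walk(x)
        if x isa AbstractArray{<:Number}
            if !haskey(seen, x)
                seen[x] = nothing
                total += length(x)
            end
        else
            foreach(walk, Functors.children(x))
        end
    end
    walk(model)
    return total
end

d = Dense(10 => 10)
count_params(Chain(Embedding(10 => 10), d, d))                              # 210
count_params(Chain(Embedding(10 => 10), Dense(10 => 10), Dense(10 => 10)))  # 320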

As an example, the following code now shows the correct output. Because d is shared, its 110 parameters (a 10×10 weight matrix plus a 10-element bias) are counted only once, so the shared model reports 3 arrays and 210 parameters rather than 5 arrays and 320:

julia> using Flux;

julia> d = Dense(10 => 10);

julia> shared_layer = Chain(Embedding(10, 10), d, d)
Chain(
  Embedding(10 => 10),                  # 100 parameters
  Dense(10 => 10),                      # 110 parameters
  Dense(10 => 10),                      # 110 parameters
)                   # Total: 3 arrays, 210 parameters, 1.055 KiB.

julia> normal_layer = Chain(Embedding(10, 10), Dense(10 => 10), Dense(10 => 10))
Chain(
  Embedding(10 => 10),                  # 100 parameters
  Dense(10 => 10),                      # 110 parameters
  Dense(10 => 10),                      # 110 parameters
)                   # Total: 5 arrays, 320 parameters, 1.562 KiB.

TODO:

  • [ ] Add tests.
  • [ ] Add an example in the docs for shared parameters?

Closes #2321.

codetalker7 avatar Sep 14 '23 09:09 codetalker7

If this looks good, I'll go ahead and add some tests and an example in the documentation as well.

codetalker7 avatar Sep 14 '23 09:09 codetalker7

The documentation CI is failing on an RNN example. I assume the output there is also incorrect? It's probably showing the state parameter as non-trainable (which is trainable, right?)

codetalker7 avatar Sep 14 '23 14:09 codetalker7

It's probably showing the state parameter as non-trainable (which is trainable, right?)

That's Recur.state, which should be non-trainable. Note how only cell is included below:

julia> Flux.trainable(RNN(2 => 5))
(cell = RNNCell(2 => 5, tanh),)

ToucheSir avatar Sep 15 '23 04:09 ToucheSir

It's probably showing the state parameter as non-trainable (which is trainable, right?)

That's Recur.state, which should be non-trainable. Note how only cell is included below:

julia> Flux.trainable(RNN(2 => 5))
(cell = RNNCell(2 => 5, tanh),)

I see, yes, that makes sense. I think I understand now why it's not showing any non-trainable parameters for RNN(2 => 5): both the cell's initial state (state0) and Recur.state refer to the same zero matrix, and hence pushing both into the IdSet adds just one matrix instead of two. ~~Instead of this, we'd have to push the names of parameters to the IdSet as well (to distinguish between two distinct parameters that happen to have the same value).~~ Even pushing names might not work, since two layers can share the same parameter name and the same parameter values and still be different.

~~Just to confirm: is it true that all parameters in Flux (i.e., Functors.children(m), where m is some layer) have unique names associated with them? If not, I don't immediately see a way of counting the total number of distinct parameters.~~

codetalker7 avatar Sep 15 '23 13:09 codetalker7

Yes, tied parameters are tricky as we found out while working on Optimisers.jl. Sometimes it feels like a philosophical question. Do we consider array wrappers like Adjoint and Transpose as aliases? Which wrappers in particular? What about reshapes of an Array, which share the same data but have different objectids and thus aren't caught by using an IdSet? It's not an easy problem, but this PR is a good start.

ToucheSir avatar Sep 20 '23 20:09 ToucheSir
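(For reference, a quick REPL illustration of the aliasing cases mentioned above; this is my own example, with the results I would expect from Base Julia:)

julia> A = rand(3, 3);

julia> A' === A            # the Adjoint wrapper is a different object...
false

julia> parent(A') === A    # ...but it shares A's data
true

julia> R = reshape(A, 9);  # reshaping an Array also shares the data

julia> R === A             # yet R has its own objectid, so an IdSet misses the tie
false

julia> R[1] = 0; A[1]      # mutating one mutates the other
0.0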

Yes, tied parameters are tricky as we found out while working on Optimisers.jl. Sometimes it feels like a philosophical question. Do we consider array wrappers like Adjoint and Transpose as aliases? Which wrappers in particular? What about reshapes of an Array, which share the same data but have different objectids and thus aren't caught by using an IdSet? It's not an easy problem, but this PR is a good start.

Taking inspiration from Flux.params!, I tried to push the whole layer to the IdSet instead of just AbstractArrays, and that seems to be giving correct results. How does it look now?

codetalker7 avatar Sep 22 '23 09:09 codetalker7

I believe that'd run into the same problem with shared params across nominally different layers. Maybe one idea would be to separately count the number of shared params and report that?

ToucheSir avatar Sep 27 '23 00:09 ToucheSir

Can we farm more of this out to Functors / Optimisers? Instead of building an IdSet by hand, let Functors cache things. Then this will inherit its understanding of Adjoint etc.

(I believe Optimisers.jl has a trainable-only walk definition, since it owns that concept.)

mcabbott avatar Sep 29 '23 03:09 mcabbott

I believe that'd run into the same problem with shared params across nominally different layers. Maybe one idea would be to separately count the number of shared params and report that?

Hi @ToucheSir, could you explain the "nominally different layers" part? I didn't quite follow it. Maybe an example?

Can we farm more of this out to Functors / Optimisers? Instead of building an IdSet by hand, let Functors cache things. Then this will inherit its understanding of Adjoint etc.

(I believe Optimisers.jl has a trainable-only walk definition, since it owns that concept.)

Sure; I'll take a look at both Functors and Optimisers more closely.

codetalker7 avatar Sep 30 '23 15:09 codetalker7

Something like this:

d1 = Dense(3 => 4)
d2 = Dense(d1.weight)
d1.weight === d2.weight # tied
d1 !== d2 # but pushing the whole layer won't capture that

ToucheSir avatar Sep 30 '23 16:09 ToucheSir
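(To spell that out with a hedged continuation of the snippet above: if whole layers are pushed into an identity-keyed set, d1 and d2 both look new, so their shared weight is counted twice.)

seen = IdDict{Any,Nothing}()   # stand-in for an IdSet of layers
seen[d1] = nothing
seen[d2] = nothing
length(seen)                   # 2 entries, even though d1.weight === d2.weight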

d1 !== d2

I see, yes, that makes sense. I think any solution which counts distinct (or shared) parameters in a model must use some form of unique ID associated with each parameter (I can't think of other ways at the moment; maybe there are more clever ways). Can we somehow associate such an ID with every parameter in a Flux model? Or, more generally, attach some metadata to each leaf of a struct?

codetalker7 avatar Sep 30 '23 17:09 codetalker7

Functors uses a cache which should detect such sharing. It's a little smarter than just using objectid, so as not to catch immutable objects which are accidentally ===.

julia> using Functors, Flux

julia> let mat = rand(2,2)
         model = Chain(Dense(mat), Dense(mat')) # separate bias vectors
         cnt = Ref(0)
         fmapstructure(model; exclude=x->x isa Array) do x
           cnt[] += 1
         end
       end
(layers = ((weight = 1, bias = 2, σ = ()), (weight = (parent = 1,), bias = 3, σ = ())),)

julia> using StaticArrays

julia> [1,2] === [1,2]  # different arrays with same value, not shared
false

julia> SA[1,2] === SA[1,2]  # here the value is the only identity... still not shared.
true

julia> let mat = @SMatrix rand(2,2)
         model = Chain(Dense(mat), Dense(mat'))  # still has mat === mat
         cnt = Ref(0)
         fmapstructure(model; exclude=x->x isa AbstractArray) do x
           cnt[] += 1
         end
       end
(layers = ((weight = 1, bias = 2, σ = ()), (weight = 3, bias = 4, σ = ())),)

I think fmap like this ought to be equivalent to Flux.params. But the trainable count needs a modified walk to exclude some children.

mcabbott avatar Sep 30 '23 18:09 mcabbott
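(A closing sketch of that suggestion, hedged: this is a hand-rolled recursion, not the trainable-only walk Optimisers.jl defines internally. It uses the same identity-based dedup as the earlier count_params sketch, but recurses only through Flux.trainable children, so non-trainable fields such as Recur.state are skipped and tied trainable arrays are counted once.)

using Flux

function count_trainable(model)
    seen = IdDict{Any,Nothing}()
    total = 0
    function walk(x)
        if x isa AbstractArray{<:Number}
            if !haskey(seen, x)
                seen[x] = nothing
                total += length(x)
            end
        else
            # Flux.trainable(x) returns only the trainable children of x
            foreach(walk, Flux.trainable(x))
        end
    end
    walk(model)
    return total
end

d = Dense(10 => 10)
count_trainable(Chain(Embedding(10 => 10), d, d))   # 210: d's weight and bias counted once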