Adding `summarise_array` to show.jl
I'm not sure if you'll want this, and maybe there is a better way to do what this does, but:
This PR adds two lines that, unless I've missed something, should not change the current behavior of `Base.show`, but they let the user customize how models and layers are displayed.
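The general pattern is an overloadable hook: the display code calls a function whose default returns an empty string, so nothing changes until a user adds a method. Below is a standalone sketch of that idea. The module `ToyShow`, the struct `ToyLayer`, and its `show` method are hypothetical stand-ins rather than Flux's code, and the empty-string default is an assumption about what the two added lines do.

```julia
# Standalone sketch of the overloadable-hook idea. `ToyShow`, `ToyLayer`,
# and this `show` method are hypothetical stand-ins, not Flux's code.
module ToyShow

export ToyLayer

# Assumed default hook: returns "", so the printed output is unchanged.
summarise_array(a) = ""

struct ToyLayer
    weight::Matrix{Float64}
    bias::Vector{Float64}
end

function Base.show(io::IO, l::ToyLayer)
    print(io, "ToyLayer(")
    for (i, p) in enumerate((l.weight, l.bias))
        i > 1 && print(io, ", ")
        # Array size, followed by whatever the hook returns.
        print(io, join(size(p), "×"), summarise_array(p))
    end
    print(io, ")")
end

end # module

using .ToyShow, Statistics

layer = ToyLayer([1.0 2.0; 3.0 4.0], zeros(2))
before = sprint(show, layer)   # "ToyLayer(2×2, 2)"

# User-side customization in the spirit of the PR description. For
# real-valued arrays, dispatch picks this over the untyped default.
ToyShow.summarise_array(a::AbstractArray{<:Real}) = "; σ=$(round(std(a), digits = 3))"
after = sprint(show, layer)    # "ToyLayer(2×2; σ=1.291, 2; σ=0.0)"
```

Adding a more specific method, instead of redefining the untyped default, leaves the default in place for any other array types.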
The particular problem this solves for me is being able to easily inspect summary statistics of the model weights at the level of individual tensors. With the change in this PR, I can define, e.g. (with `using Statistics` loaded for `std`):
`Flux.summarise_array(a) = "; L=$(length(a)):σ=$(round(std(a), digits = 3))"`
and then when I display my model in the REPL, I can see the details I wanted for each tensor in each layer, in the context of the model structure:
PR Checklist
- [ ] Tests are added
- [ ] Entry in NEWS.md
- [ ] Documentation, if applicable
Something like this might be useful; thanks for digging into the somewhat messy `show` code.
At present, `_nan_show(io, trainables(layer))` (a few lines down) aims to warn you about Inf/NaN values and all-zero arrays. But maybe always printing something about the values would be better? Printing `σ = NaN` would convey almost the same information. Is `std` the best single number, or would, say, `norm(W)` be better?
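To make the trade-off concrete, here is a sketch of what `std` (from `Statistics`) and `norm` (from `LinearAlgebra`) report for the cases `_nan_show` currently flags. The helper name `one_number_summaries` is hypothetical. NaN and Inf entries both turn σ into NaN, but σ is 0 for any constant array, so it cannot tell "all zero" apart from "all equal"; `norm` can.

```julia
using Statistics, LinearAlgebra

# Hypothetical helper: the two candidate "one numbers" for an array.
function one_number_summaries(W::AbstractArray{<:Real})
    return (σ = std(W), l2norm = norm(W))
end

one_number_summaries([1.0, 2.0, NaN])  # σ is NaN, flagging the NaN entry
one_number_summaries([1.0, 2.0, Inf])  # σ is NaN too (the mean is Inf, and Inf - Inf is NaN)
one_number_summaries(zeros(3))         # σ == 0 and l2norm == 0
one_number_summaries(fill(2.0, 3))     # σ == 0 but l2norm > 0
```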
I'm a little reluctant to add more functions we encourage people to overload.
I wonder if it should only summarise the first (say) 5 parameter arrays, so that you never get 10 lines of noise.
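One way to cap the output, sketched below with a hypothetical `print_param_summaries` helper and `maxshow` keyword (neither is part of Flux), is to summarise only the first few arrays via `Iterators.take` and collapse the rest into one line.

```julia
using Statistics

# Hypothetical helper: summarise at most `maxshow` arrays, then collapse
# the remainder into a single "… and N more arrays" line.
function print_param_summaries(io::IO, arrays::AbstractVector; maxshow::Int = 5)
    for a in Iterators.take(arrays, maxshow)
        println(io, "  ", join(size(a), "×"), "  σ=", round(std(a), digits = 3))
    end
    extra = length(arrays) - maxshow
    extra > 0 && println(io, "  … and ", extra, " more arrays")
    return nothing
end

# Eight parameter arrays, but only five summary lines plus one "more" line.
print_param_summaries(stdout, [randn(4, 4) for _ in 1:8])
```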
If I had to pick one number, it would be `std`, but `std` is NaN when taken over a single value, and if you aren't aware of this you might go looking for bugs that aren't there. So maybe use the "population" std (which divides by N instead of N-1)? `norm` is also a defensible choice.
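In `Statistics`, the population std is available through the `corrected` keyword of `std`; the default `corrected = true` divides by N-1. A quick illustration:

```julia
using Statistics

x = [0.42]                  # a single-element parameter array
std(x)                      # NaN: default (corrected = true) divides by N - 1 = 0
std(x; corrected = false)   # 0.0: population std divides by N
std([1.0, 3.0]; corrected = false)  # 1.0 (the corrected version gives √2)
```

For weight arrays with many entries the two versions differ only by a factor of √((N-1)/N), so the choice mainly matters for tiny arrays such as scalar parameters.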
And yes, I think reporting something like this by default is good. But if you do it, please do it in a way that is easy to overload (even if we don't encourage people to do so). Sometimes I care about the std, but sometimes also the mean, the extrema, the number of parameters < 0, etc.
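As an example of the richer summaries mentioned here, the sketch below builds one string with the mean, the population std, the extrema, and the count of negative entries. The name `rich_summary` and its output format are hypothetical; a function like this is what a user might plug into an overloadable hook such as `summarise_array`.

```julia
using Statistics

# Hypothetical richer per-array summary: mean, population std, extrema,
# and the number of negative entries.
function rich_summary(a::AbstractArray{<:Real}; digits::Int = 3)
    r(x) = round(x; digits)
    lo, hi = extrema(a)
    return string("; μ=", r(mean(a)), " σ=", r(std(a; corrected = false)),
                  " range=[", r(lo), ", ", r(hi), "] #neg=", count(<(0), a))
end

rich_summary([-1.0, 0.0, 1.0, 2.0])  # "; μ=0.5 σ=1.118 range=[-1.0, 2.0] #neg=1"
```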
One option would be to support this with a more explicit call where the user passes in the array function directly, like `summarise(model, array_printing_function = f)`, and have `show` call that with the default? I'm not sure I'd be able to figure that out, though...
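A minimal sketch of that keyword-argument design follows, assuming a hypothetical `summarise` function, an `array_summary` keyword, and a toy `ToyModel` type (none of these exist in Flux). The user supplies the per-array function at the call site, and `show` is just the explicit call with a default that adds nothing, so no global method has to be redefined.

```julia
using Statistics

# Hypothetical default: adds nothing, so plain `show` output is unchanged.
default_array_summary(a::AbstractArray) = ""

# Hypothetical toy model holding named parameter arrays.
struct ToyModel{P<:NamedTuple}
    params::P
end

# Explicit entry point: the caller passes the per-array function directly.
function summarise(io::IO, m::ToyModel; array_summary = default_array_summary)
    print(io, "ToyModel(")
    for (i, (name, a)) in enumerate(pairs(m.params))
        i > 1 && print(io, ", ")
        print(io, name, "=", join(size(a), "×"), array_summary(a))
    end
    print(io, ")")
    return nothing
end
# REPL convenience: print to stdout.
summarise(m::ToyModel; kw...) = (summarise(stdout, m; kw...); println())

# `show` is just the explicit call with the default summary.
Base.show(io::IO, m::ToyModel) = summarise(io, m)

m = ToyModel((weight = [1.0 2.0; 3.0 4.0], bias = zeros(2)))
plain = sprint(show, m)   # "ToyModel(weight=2×2, bias=2)"
custom = sprint(io -> summarise(io, m; array_summary = a -> "; σ=$(round(std(a), digits = 3))"))
# custom == "ToyModel(weight=2×2; σ=1.291, bias=2; σ=0.0)"
```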