Flux.jl
`count_params` function?
The `show` methods are super nice, especially the counts of parameters and arrays. Could the functions that compute these counts be pulled out into an API function? Looking at the source, I guess it would be something like this:
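```julia
using Functors, Flux
using Functors: isleaf

# Apply `f` to every numeric array reachable from `x` and add up the results:
# `f = length` counts scalar parameters, `f = _ -> 1` counts arrays.
# Note: an array shared between layers is counted at every occurrence here,
# while `Flux.params` lists it only once.
_childarray_sum(f, x::AbstractArray{<:Number}) = f(x)
_childarray_sum(f, x) =
    isleaf(x) ? 0 : sum(y -> _childarray_sum(f, y), Functors.children(x); init = 0)

function count_params(m)
    ps = Flux.params(m)                  # trainable arrays only
    pars = sum(length, ps; init = 0)     # `init` avoids an error when `ps` is empty
    # Everything reachable via Functors minus the trainable part is non-trainable
    # (e.g. BatchNorm's running statistics).
    noncnt = _childarray_sum(_ -> 1, m) - length(ps)
    nonparam = _childarray_sum(length, m) - pars
    return (; trainable_arrays = length(ps), trainable_params = pars,
              non_trainable_arrays = noncnt, non_trainable_params = nonparam)
end
```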
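To make the recursion behind `_childarray_sum` easier to see, here is a dependency-free sketch of the same idea in plain Julia. It is only an illustration, not Flux code: instead of `Functors.children`/`isleaf`, it assumes the model is built from nested `Tuple`s and `NamedTuple`s, and the names `childarray_sum` and `model` (with its BatchNorm-like field names) are hypothetical.

```julia
# Hypothetical, Functors-free version of the counting recursion.
# Numeric arrays are "leaves" we measure with `f`.
childarray_sum(f, x::AbstractArray{<:Number}) = f(x)
# Containers (tuples, named tuples, arrays of non-numbers) are walked recursively.
childarray_sum(f, x::Union{Tuple,NamedTuple,AbstractArray}) =
    sum(y -> childarray_sum(f, y), x; init = 0)
# Anything else (functions, scalars, flags) contributes nothing.
childarray_sum(f, x) = 0

# A toy "model": a dense layer plus a normalisation layer whose running
# statistics would be non-trainable in a real framework.
model = (dense = (weight = rand(Float32, 3, 2), bias = zeros(Float32, 3), act = identity),
         norm  = (beta = zeros(Float32, 3), gamma = ones(Float32, 3),
                  running_mean = zeros(Float32, 3), running_var = ones(Float32, 3)))

n_arrays = childarray_sum(_ -> 1, model)   # counts arrays
n_params = childarray_sum(length, model)   # counts scalar entries
println((; n_arrays, n_params))
```

Here the toy model holds 6 arrays with 6 + 3 + 4 × 3 = 21 entries in total. In Flux itself, calling the `count_params` above on something like `Chain(Dense(2 => 3), BatchNorm(3))` should report BatchNorm's running mean and variance as the non-trainable arrays, since `Flux.params` only collects the trainable fields.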
Sounds like a good idea.