AxisKeys.jl
Support for NaNStatistics
It would be very handy to have support for NaNStatistics. I couldn't find an easy way to rewrap the output array, so I wrote this ugly workaround:
NaNStatistics.nansum(A::KeyedArray; dim=:, dims=:) = rewrap_keyedarray(nansum(A.data.data; dim, dims), A, dims isa Colon ? dim : dims)

function rewrap_keyedarray(new, original, dims)
    axs = NamedTuple{propertynames(original)}(original.keys) |> pairs
    if ndims(new) < ndims(original)
        axs = collect(axs)[1:end .!= dims]
    else
        axs = Dict(axs)
        for dim in dims
            axs[propertynames(original)[dim]] = axes(new, dim)
        end
        axs = [k => axs[k] for k in propertynames(original)]
    end
    KeyedArray(new; axs...)
end
Thanks
This code is not far from everything else in this package, and now that the package has extensions (thanks to #151), it's fine to add one more for NaNStatistics. It could probably be mostly copied from here:
https://github.com/mcabbott/AxisKeys.jl/blob/master/ext/StatisticsExt.jl
Of course, this support can be added explicitly as an extension here. You may want to make a PR for that.
But note that Julia already has a fully general way to apply reductions (like sum, nansum, ...) along array axes:
julia> A = KeyedArray(rand(2, 3), x=[:a,:b], y=10:10:30)
2-dimensional KeyedArray(NamedDimsArray(...)) with keys:
↓ x ∈ 2-element Vector{Symbol}
→ y ∈ 3-element StepRange{Int64,...}
And data, 2×3 Matrix{Float64}:
        (10)       (20)        (30)
  (:a)   0.990372   0.0498909   0.869283
  (:b)   0.994879   0.541963    0.279493
julia> nansum(A)
3.7258801629137928
julia> map(nansum, eachslice(A, dims=:x))
1-dimensional KeyedArray(NamedDimsArray(...)) with keys:
↓ x ∈ 2-element Vector{Symbol}
And data, 2-element Vector{Float64}:
  (:a)  1.9095453258981778
  (:b)  1.8163348370156154
This way, there's no need for explicit support of each individual reduction function; everything just works.
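For anyone who wants to try the pattern without installing anything, here is a minimal sketch using only Base Julia and a plain `Matrix`. Note that `skipnan_sum` is a hypothetical stand-in I wrote for `nansum`, not part of NaNStatistics, and a plain array has no keys to preserve. It only illustrates that `map` over `eachslice` handles the per-axis reduction for any function that reduces a slice to a scalar:

```julia
# Hypothetical stand-in for NaNStatistics.nansum, so the sketch needs no packages:
# drop the NaN entries, then sum what is left (0.0 if nothing is left).
skipnan_sum(x) = sum(filter(!isnan, x); init=0.0)

M = [1.0 NaN 3.0;
     4.0 5.0 NaN]

# One result per slice along dimension 1, i.e. one per row.
row_sums = map(skipnan_sum, eachslice(M, dims=1))   # [4.0, 9.0]

# One result per slice along dimension 2, i.e. one per column.
col_sums = map(skipnan_sum, eachslice(M, dims=2))   # [5.0, 5.0, 3.0]
```

With a `KeyedArray`, the same `map(f, eachslice(A, dims=:x))` call keeps the keys of the sliced dimension, as in the REPL session above.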