support dims
Support sum(arr, dims=1), as in Julia's built-in sum, to reduce along a given axis.
FWIW, I find that in this case something like
sum(x -> !isnan(x) * x, arr, dims=1)
(i.e., without using NaNMath) works just as well. (And is maybe a bit faster than a NaNMath implementation since it relies on the built-in sum?)
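A quick way to check this trick, as a sketch with made-up data: it relies on false acting as a "strong zero" in Julia, so false * NaN evaluates to 0.0 rather than NaN. The names arr, nansum_cols and expected are only for illustration.

```julia
# Small matrix with one NaN (made-up data for illustration)
arr = [1.0 2.0; NaN 4.0; 5.0 6.0]

# false * NaN == 0.0 in Julia, so each NaN contributes nothing to the sum.
# The parentheses around !isnan(x) are only there for readability.
nansum_cols = sum(x -> (!isnan(x)) * x, arr, dims=1)

# Reference result: drop the NaNs from each column explicitly, then sum
expected = [sum(filter(!isnan, col)) for col in eachcol(arr)]

println(vec(nansum_cols) == expected)
```

Note that the result keeps the reduced dimension (a 1×2 matrix here), just like the built-in sum with dims.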
Nice solution for sum, but unfortunately it doesn't work for, e.g., mean, since the NaNs that were replaced by zeros would still count toward the number of elements.
riffing on julia Base and daneel's code:
using Statistics, Test
_nanfunc(f, A, ::Colon) = f(filter(!isnan, A))
_nanfunc(f, A, dims) = mapslices(a->_nanfunc(f,a,:), A, dims=dims)
nanfunc(f, A; dims=:) = _nanfunc(f, A, dims)
A = [1 2 3; 4 5 6; 7 8 9; NaN 11 12]
@test isapprox(nanfunc(mean, A), mean(filter(!isnan, A)))
@test nanfunc(mean, A, dims=1) == [4.0 6.5 7.5]
@test nanfunc(mean, A, dims=2) == transpose([2.0 5.0 8.0 11.5])
@test isapprox(nanfunc(var, A), var(filter(!isnan, A)))
@test nanfunc(var, A, dims=1) == [9.0 15.0 15.0]
@test nanfunc(var, A, dims=2) == transpose([1.0 1.0 1.0 0.5])
Can we actually make this a PR? One issue I see is that mapslices doesn't play nicely with @view, so at the moment using dims is significantly slower and allocates heavily:
julia> a = rand([NaN, 1,2,3,4,5], 100,100,100);
julia> @btime nanfunc(mean, a);
1.188 ms (4 allocations: 7.63 MiB)
julia> @btime NaNMath.mean(a);
2.035 ms (1 allocation: 16 bytes)
julia> @btime nanfunc(mean, a; dims=2);
10.382 ms (120039 allocations: 11.37 MiB)
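One possible way around mapslices, as a rough sketch rather than a tested replacement: loop over the other dimensions with CartesianIndices and hand f a filtered view of each slice. The name nanfunc_views is hypothetical, it only handles a single integer dimension, and it assumes f returns something convertible to Float64. The filter call still allocates one vector per slice, so whether this is actually faster would need to be benchmarked.

```julia
using Statistics

# Hypothetical helper (not part of NaNMath or of the code posted above):
# apply f along one integer dimension, skipping NaNs, without mapslices.
function nanfunc_views(f, A::AbstractArray, dim::Integer)
    # Index ranges for the dimensions before and after `dim`
    Rpre = CartesianIndices(size(A)[1:dim-1])
    Rpost = CartesianIndices(size(A)[dim+1:end])
    # Keep the reduced dimension with length 1, as sum/mean with dims do
    out = Array{Float64}(undef, size(Rpre)..., 1, size(Rpost)...)
    for Ipost in Rpost, Ipre in Rpre
        # view avoids copying the slice; filter copies only the non-NaN values
        out[Ipre, 1, Ipost] = f(filter(!isnan, view(A, Ipre, :, Ipost)))
    end
    return out
end

A = [1 2 3; 4 5 6; 7 8 9; NaN 11 12]
println(nanfunc_views(mean, A, 1))
println(nanfunc_views(mean, A, 2))
```

On the test matrix from above this should agree with nanfunc(mean, A, dims=1) and nanfunc(mean, A, dims=2).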
Hi @bjarthur! I have a question about how to pass function-specific arguments through nanfunc(). For example, what if we wanted to calculate std(), which can be corrected or not, or call any other function that needs one or more extra arguments?
Thanks!
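Two ways this might be done, as a sketch only: wrap the extra argument in an anonymous function, which needs no change to nanfunc, or forward keyword arguments through a thin wrapper. The wrapper name nanfunc_kw is hypothetical and not part of the code posted above; nanfunc itself is repeated here so the snippet runs on its own.

```julia
using Statistics

# nanfunc as posted earlier in this thread
_nanfunc(f, A, ::Colon) = f(filter(!isnan, A))
_nanfunc(f, A, dims) = mapslices(a -> _nanfunc(f, a, :), A, dims=dims)
nanfunc(f, A; dims=:) = _nanfunc(f, A, dims)

A = [1 2 3; 4 5 6; 7 8 9; NaN 11 12]

# Option 1: close over the extra argument, nanfunc stays unchanged
s1 = nanfunc(x -> std(x; corrected=false), A, dims=1)

# Option 2 (hypothetical wrapper): forward any extra keywords to f
nanfunc_kw(f, A; dims=:, kwargs...) = nanfunc(x -> f(x; kwargs...), A; dims=dims)
s2 = nanfunc_kw(std, A; dims=1, corrected=false)

println(s1 == s2)
```

Option 1 also covers positional arguments, e.g. x -> quantile(x, 0.5), while Option 2 only forwards keywords.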