SliceMap.jl
mapcols returns CPU arrays with GPU input
julia> using SliceMap, Flux, LinearAlgebra
julia> mapcols(norm, gpu(Flux.param(randn(2,2))))
Tracked 1×2 Array{Float32,2}:
1.98362 0.443408
julia> mapcols(norm, gpu(randn(2,2)))
1×2 Array{Float32,2}:
0.891582 3.15292
Possibly because norm returns a scalar?
Yes. For scalar results I applied surevec(x) = [x] before reducing, which always produces an ordinary Array.
I tried two fixes on master. The first made a small CuArray from each scalar instead, so that reduce(hcat, …) produced a CuArray. But then the gradient wants to take first() of each of them, which gives a warning. The second was essentially just to return map(f, eachcol(M)). But then I realised that this is also an ordinary array. Surely there is a good solution to this.
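For anyone following along, here is a minimal CPU-only sketch of the scalar case described above. It is an illustration, not the package's actual code: `mapcols_sketch` is a hypothetical stand-in for `mapcols` without any gradient handling, and the pass-through method of `surevec` for vectors is an assumption.

```julia
using LinearAlgebra  # provides norm

# Wrap a scalar result in a 1-element Vector, as described in the comment above.
# The literal [x] always builds an ordinary Array, whatever array type f was given.
surevec(x::Number) = [x]
# Assumed pass-through for functions that already return a vector (hypothetical).
surevec(x::AbstractVector) = x

# Hypothetical simplified stand-in for mapcols: apply f to each column, then glue with hcat.
mapcols_sketch(f, M::AbstractMatrix) =
    reduce(hcat, map(col -> surevec(f(col)), eachcol(M)))

M = [3.0 0.0; 4.0 2.0]
out = mapcols_sketch(norm, M)   # 1×2 matrix holding the column norms 5.0 and 2.0

# The second attempt, map(f, eachcol(M)), gives a plain Vector of scalars instead.
plain = map(norm, eachcol(M))
```

Nothing in this sketch depends on the type of `M` once `f` has returned a scalar, which would be consistent with a GPU input coming back as a CPU array.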
Related thread: https://discourse.julialang.org/t/map-performance-with-cuarrays/33497/10
Are there functions for which some variant of f.(eachcol(cu(x))) does make sense? That is, for which mapping over slices of a CuArray is useful and fast?
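For the specific case of norm, one way to sidestep the question is to skip the per-column map and write the operation as a single reduction over dims=1. The sketch below runs on an ordinary Array; `colnorms` is a hypothetical helper name, not part of SliceMap. My understanding is that the same reduction and broadcast are supported on a CuArray and would keep the result on the device, but that is an assumption I have not tested here, and it does nothing for a general f.

```julia
using LinearAlgebra  # only needed for the comparison against norm

# Hypothetical helper: Euclidean norm of every column via one whole-array reduction.
# sum(abs2, M; dims=1) returns a 1×N array of the same array family as M.
colnorms(M::AbstractMatrix) = sqrt.(sum(abs2, M; dims=1))

M = [3.0 0.0; 4.0 2.0]
result = colnorms(M)                    # 1×2 matrix of column norms
reference = map(norm, eachcol(M))       # same values, computed slice by slice
```

The trade-off is that each f needs its own hand-written whole-array form, which is exactly what a generic mapcols is meant to avoid.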