ChainRules.jl
`rrule` for `mean(f, x)` is not vectorized?
Hi, it seems that the rrule for mean(f, x) is not vectorized and thus does not play nicely with CUDA:
using Zygote, CUDA, Statistics
julia> gradient(y -> mean(x -> x.^2, y), CUDA.randn(10))
ERROR: Scalar indexing is disallowed.
Invocation of getindex resulted in scalar indexing of a GPU array.
This is typically caused by calling an iterating implementation of a method.
Such implementations *do not* execute on the GPU, but very slowly on the CPU,
and therefore are only permitted from the REPL for prototyping purposes.
If you did intend to index this array, annotate the caller with @allowscalar.
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:35
[2] assertscalar(op::String)
@ GPUArraysCore ~/.julia/packages/GPUArraysCore/uOYfN/src/GPUArraysCore.jl:103
[3] getindex
@ ~/.julia/packages/GPUArrays/5XhED/src/host/indexing.jl:9 [inlined]
[4] iterate
@ ./abstractarray.jl:1220 [inlined]
[5] iterate
@ ./abstractarray.jl:1218 [inlined]
[6] iterate
@ ./generator.jl:44 [inlined]
[7] collect(itr::Base.Generator{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ChainRules.var"#1655#1660"{Zygote.ZygoteRuleConfig{Zygote.Context{false}}, var"#24#26"}})
@ Base ./array.jl:782
[8] rrule(config::Zygote.ZygoteRuleConfig{Zygote.Context{false}}, ::typeof(sum), f::var"#24#26", xs::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}; dims::Function)
@ ChainRules ~/.julia/packages/ChainRules/9sNmB/src/rulesets/Base/mapreduce.jl:102
[9] rrule
@ ~/.julia/packages/ChainRules/9sNmB/src/rulesets/Base/mapreduce.jl:76 [inlined]
[10] #rrule#1808
@ ~/.julia/packages/ChainRules/9sNmB/src/rulesets/Statistics/statistics.jl:28 [inlined]
[11] rrule
@ ~/.julia/packages/ChainRules/9sNmB/src/rulesets/Statistics/statistics.jl:21 [inlined]
[12] chain_rrule
@ ~/.julia/packages/Zygote/4rucm/src/compiler/chainrules.jl:223 [inlined]
[13] macro expansion
@ ~/.julia/packages/Zygote/4rucm/src/compiler/interface2.jl:101 [inlined]
[14] _pullback
@ ~/.julia/packages/Zygote/4rucm/src/compiler/interface2.jl:101 [inlined]
[15] _pullback
@ ./REPL[14]:1 [inlined]
[16] _pullback(ctx::Zygote.Context{false}, f::var"#23#25", args::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
@ Zygote ~/.julia/packages/Zygote/4rucm/src/compiler/interface2.jl:0
[17] pullback(f::Function, cx::Zygote.Context{false}, args::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
@ Zygote ~/.julia/packages/Zygote/4rucm/src/compiler/interface.jl:44
[18] pullback
@ ~/.julia/packages/Zygote/4rucm/src/compiler/interface.jl:42 [inlined]
[19] gradient(f::Function, args::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
@ Zygote ~/.julia/packages/Zygote/4rucm/src/compiler/interface.jl:96
[20] top-level scope
@ REPL[14]:1
[21] top-level scope
@ ~/.julia/packages/CUDA/tVtYo/src/initialization.jl:185
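The problem seems to be that the line where this fails in the sum(f, xs) rrule (mapreduce.jl:102, frame [8] above) collects a generator instead of using map or broadcasting. But the comment there seems to suggest that we can't do that here. Is there anything we can do?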
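To make the difference concrete, here is a minimal CPU-only sketch contrasting a generator-based pullback with a broadcast-based one for mean(f, x). It assumes the derivative df of f is known in closed form, and the helper names mean_grad_generator and mean_grad_broadcast are hypothetical; the actual ChainRules rule obtains per-element pullbacks through the AD config instead. On a CuArray, the generator version iterates element by element (the collect/iterate/getindex path in the stack trace), while the broadcast version leaves the elementwise work to the array type.

```julia
# Hypothetical helper: gradient of mean(f, x) built by collecting a generator.
# On a GPU array this iterates element by element (scalar indexing).
function mean_grad_generator(df, x::AbstractArray)
    n = length(x)
    return collect(df(xi) / n for xi in x)
end

# Hypothetical helper: the same gradient expressed as a broadcast, so the
# array type (e.g. a CuArray) decides how the elementwise work is executed.
function mean_grad_broadcast(df, x::AbstractArray)
    return df.(x) ./ length(x)
end

df(x) = 2x          # derivative of f(x) = x^2, assumed known in closed form

x = Float32[1, 2, 3, 4]
g1 = mean_grad_generator(df, x)
g2 = mean_grad_broadcast(df, x)
println(g1 == g2)   # true
```

On the CPU both give the same numbers; the point of the broadcast form is only that it avoids iterating over a GPU array from the host.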
By the way, sum(f, x) for the same f works perfectly, so I'm quite curious why the result is different. Both hit the same rrule, right?
julia> gradient(y -> sum(x -> x^2, y)/10, CUDA.randn(10))
(Float32[-0.03543221, -0.002124702, 0.068868384, -0.21756743, 0.234217, -0.16418666, -0.033367466, -0.26496077, 0.095435165, -0.044487894],)
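For reference, sum(x -> x^2, y)/10 on a length-10 vector is the same quantity as mean(x -> x^2, y), so both calls should produce the gradient 2y/n with n = length(y). A quick CPU sanity check with a central finite difference, as a sketch (the helper fd_grad is hypothetical and uses Float64 for accuracy):

```julia
using Statistics

# Hypothetical helper: central finite-difference gradient of a scalar function.
function fd_grad(loss, y::Vector{Float64}; h = 1e-6)
    g = similar(y)
    for i in eachindex(y)
        e = zeros(length(y))
        e[i] = h
        g[i] = (loss(y .+ e) - loss(y .- e)) / (2h)
    end
    return g
end

y = randn(10)
loss_sum(y)  = sum(x -> x^2, y) / length(y)
loss_mean(y) = mean(x -> x^2, y)

println(loss_sum(y) ≈ loss_mean(y))                     # the two losses agree
println(fd_grad(loss_mean, y) ≈ 2 .* y ./ length(y))    # matches 2y/n
```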
This appears to be more complicated. It seems that gradient(y -> sum(x -> x^2, y)/10, CUDA.randn(10)) does not hit the sum(f, x) rrule at all, while mean(f, x) does (frames [8] to [11] in the stack trace above). This is very weird; I have no idea which rule is being hit for sum(f, x).
Zygote has this rule for sum(f, xs::CuArray), which takes precedence over the one here:
https://github.com/FluxML/Zygote.jl/blob/d4562e330d588cb986604bb4f1942bf9fca8ecc5/src/lib/broadcast.jl#L372-L377
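A user-side workaround in the same spirit, assuming the point of that Zygote rule is to route the elementwise work through broadcasting: write the loss as mean(f.(y)) instead of mean(f, y), so f is applied by a broadcast and only the plain reduction is left to mean. The helper name mean_bcast is hypothetical, and on the CPU this sketch only checks that the two forms give the same value.

```julia
using Statistics

# Hypothetical helper: apply f by broadcasting first, then reduce with mean.
# Under AD, the elementwise part should then go through broadcasting rather
# than the generic mean(f, x) / sum(f, x) rrule path.
mean_bcast(f, x) = mean(f.(x))

y = Float32[0.5, -1.0, 2.0, 3.0]
println(mean(x -> x^2, y))         # 3.5625
println(mean_bcast(x -> x^2, y))   # 3.5625
```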
Note also that sum(x -> x^2, xs) is equivalent to sum(abs2, xs) for real xs, which has a special rule. I think that mean(abs2, xs) goes through the generic mean(f, x) rule here, and should instead call that special rule.
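The equivalence with abs2 holds for real inputs only: for complex z, abs2(z) is |z|^2 (a real number), while z^2 is a complex square. For real xs, the gradient of mean(abs2, xs) is simply 2 .* xs ./ length(xs), a single broadcast. A small CPU sketch of both points (the helper mean_abs2_grad is hypothetical, not the ChainRules implementation):

```julia
xs = [1.0, -2.0, 3.0]
zs = [1.0 + 2.0im, -1.0im]

# For real inputs the two reductions agree exactly.
println(sum(x -> x^2, xs) == sum(abs2, xs))   # true

# For complex inputs they differ: abs2 is |z|^2, x^2 is a complex square.
println(sum(x -> x^2, zs))   # -4.0 + 4.0im
println(sum(abs2, zs))       # 6.0

# Hypothetical helper: analytic gradient of mean(abs2, xs) for real xs,
# written as a single broadcast.
mean_abs2_grad(xs::AbstractArray{<:Real}) = 2 .* xs ./ length(xs)

println(mean_abs2_grad(xs))  # [0.6666666666666666, -1.3333333333333333, 2.0]
```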
(The first example above uses x -> x.^2, with an extra broadcast; there is some chance that this changes which path is taken in the sum(f, xs) rule.)
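On the value side the extra broadcast is harmless: for a scalar argument, x.^2 returns a plain scalar identical to x^2. Any difference between the two examples would therefore come from how the AD system handles the broadcast inside f, not from what f computes. A quick check (the function names are illustrative only):

```julia
f_pow(x)   = x^2    # as in the sum example
f_bcast(x) = x.^2   # as in the mean example (broadcast over a scalar)

x = 3.0f0
println(f_pow(x) == f_bcast(x))   # true
println(typeof(f_bcast(x)))       # Float32 (not an array)
```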