Zygote.jl
Gradient of `cat` that introduces new dims does not match the dims of the input
MWE:
```julia
julia> Zygote.gradient(randn(3,3)) do x
           sum(sin.(cat(x; dims=4)))
       end[1]
3×3×1×1 Array{Float64, 4}:  # should be 3×3, but got 3×3×1×1
[:, :, 1, 1] =
 0.930559 0.810081 0.894403
 0.951607 -0.659616 0.310079
 0.950346 0.774937 0.910482
```
Fixed in ChainRules, but I think Zygote still uses its own, older version of the rule. With Diffractor:
```julia
julia> Diffractor.gradient(randn(3,3)) do x
           sum(sin.(cat(x; dims=4)))
       end[1]
3×3 Matrix{Float64}:
 0.954619 0.903961 0.297481
 0.174126 0.99136 0.972581
 0.53144 0.439785 0.877705
```
Seems to be fixed already. Tested with Zygote 0.6.34:
```julia
julia> Zygote.gradient(randn(3,3)) do x
           sum(sin.(cat(x; dims=4)))
       end[1]
3×3 Matrix{Float64}:
 0.598398 0.988603 0.999602
 0.84835 -0.101217 0.391286
 0.9785 -0.717415 0.87662
```
FWIW, this only looks fixed because `gradient` calls `ProjectTo` on the final answer. The rule itself is unchanged, so intermediate results (e.g. from `pullback`) still have the extra dimensions:
```julia
julia> Zygote.pullback(randn(3,3)) do x
           sum(sin.(cat(x; dims=4)))
       end[2](1.0)[1]
3×3×1×1 Array{Float64, 4}:
[:, :, 1, 1] =
 0.392695 -0.145669 0.450509
 0.334361 0.316647 0.980656
 0.987435 0.843904 0.985057
```
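The final projection step that `gradient` applies can be reproduced by hand. A minimal sketch, assuming ChainRulesCore ≥ 1.0 (which provides `ProjectTo`); it calls `ProjectTo` directly rather than going through Zygote's internals:

```julia
using ChainRulesCore  # provides ProjectTo

x  = randn(3, 3)                   # primal input, as in the MWE
dx = reshape(cos.(x), 3, 3, 1, 1)  # 3×3×1×1 cotangent, shaped like the raw `pullback` result

project = ProjectTo(x)             # records x's axes and element type
dx_projected = project(dx)         # reshapes away the trailing singleton dims

size(dx_projected)                 # (3, 3)
```

`ProjectTo` only adds or drops trivial (length-1) dimensions here; a cotangent of genuinely different size, such as a 1×9 array, is rejected rather than reshaped.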
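The extra dimensions break downstream rules, e.g. when the 4-d cotangent reaches the pullback of `*`:

```julia
julia> gradient(x -> sum(abs2, cat(x * x'; dims=4)), [1 2; 3 4])
ERROR: MethodError: no method matching *(::Array{Int64, 4}, ::Matrix{Int64})
```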
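To be safe for rules like `*`, the `cat` pullback itself has to return one piece per input, each shaped like that input. A minimal sketch in plain Julia, assuming a single integer `dims`; `split_cat_cotangent` is a hypothetical helper, not the actual Zygote or ChainRules rule:

```julia
# Hypothetical helper, NOT part of Zygote or ChainRules: split the cotangent `dy`
# of `cat(xs...; dims=dims)` into one piece per input, reshaped to that input's size.
function split_cat_cotangent(dy::AbstractArray, xs::AbstractArray...; dims::Int)
    pieces = AbstractArray[]
    offset = 0
    for x in xs
        n = size(x, dims)  # extent along the cat dimension; 1 if dims > ndims(x)
        # Select x's block along `dims`, keep every other dimension whole.
        inds = ntuple(k -> k == dims ? (offset+1:offset+n) : Colon(), ndims(dy))
        # `cat` may have added trailing singleton dims; reshape drops them again.
        push!(pieces, reshape(dy[inds...], size(x)))
        offset += n
    end
    return pieces
end
```

For the MWE above, `split_cat_cotangent(reshape(cos.(x), 3, 3, 1, 1), x; dims=4)[1]` is a 3×3 `Matrix`. The `reshape` is what removes the new dims; the slicing only matters when several arrays are concatenated.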
IIRC, the hurdle to simply deleting all of these (Zygote's own `cat` rules) is https://github.com/JuliaGPU/GPUArrays.jl/issues/362. `vcat` of a mix of numbers and `CuArray`s mostly works, and its gradient should not use scalar indexing. ChainRules doesn't depend on GPU packages, but it calls `sum(view(x, i))` as a workaround, which doesn't work yet. Maybe there's a smarter way.
This is solved now, right?