Zygote.jl

Gradient of `cat` that introduces new dims does not match the dims of the input

Open chengchingwen opened this issue 4 years ago • 5 comments

MWE:

julia> Zygote.gradient(randn(3,3)) do x
           sum(sin.(cat(x; dims=4)))
       end[1]
3×3×1×1 Array{Float64, 4}: # should be 3×3, but got 3×3×1×1
[:, :, 1, 1] =
 0.930559   0.810081  0.894403
 0.951607  -0.659616  0.310079
 0.950346   0.774937  0.910482

chengchingwen • Sep 04 '21 12:09

Fixed in ChainRules, but I think Zygote still uses its own, older rules for `cat`. Diffractor, which uses the ChainRules rules, already returns a 3×3 gradient:

julia> Diffractor.gradient(randn(3,3)) do x
           sum(sin.(cat(x; dims=4)))
       end[1]
3×3 Matrix{Float64}:
 0.954619  0.903961  0.297481
 0.174126  0.99136   0.972581
 0.53144   0.439785  0.877705

mcabbott • Sep 04 '21 21:09

Seems to be fixed already. Tested with Zygote 0.6.34:

julia> Zygote.gradient(randn(3,3)) do x
           sum(sin.(cat(x; dims=4)))
       end[1]
3×3 Matrix{Float64}:
 0.598398   0.988603  0.999602
 0.84835   -0.101217  0.391286
 0.9785    -0.717415  0.87662

chengchingwen • Feb 12 '22 08:02

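FWIW, this now works because `gradient` calls `ProjectTo` on the final answer, which reshapes the 3×3×1×1 array back to the 3×3 shape of the input. The `cat` rule itself is unchanged, so intermediate results, such as those from `pullback`, still carry the extra dimensions and can break the next rule in the chain: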

julia> Zygote.pullback(randn(3,3)) do x
           sum(sin.(cat(x; dims=4)))
       end[2](1.0)[1]
3×3×1×1 Array{Float64, 4}:
[:, :, 1, 1] =
 0.392695  -0.145669  0.450509
 0.334361   0.316647  0.980656
 0.987435   0.843904  0.985057

julia> gradient(x -> sum(abs2, cat(x * x'; dims=4)), [1 2; 3 4])
ERROR: MethodError: no method matching *(::Array{Int64, 4}, ::Matrix{Int64})
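For illustration, here is a minimal sketch in plain Julia (Base only, no Zygote) of the kind of shape fix-up described above. `project_to_input` is a hypothetical helper, not the ChainRulesCore `ProjectTo` implementation; it assumes the only mismatch to repair is extra trailing dimensions of length 1, such as those added by `cat(x; dims=4)`.

```julia
# Hypothetical helper (not part of ChainRulesCore): reshape an array gradient
# back to the input's shape when the only difference is extra trailing
# dimensions of length 1; otherwise throw a DimensionMismatch.
function project_to_input(x::AbstractArray, dx::AbstractArray)
    size(dx) == size(x) && return dx
    n = ndims(x)
    if ndims(dx) >= n && size(dx)[1:n] == size(x) && all(==(1), size(dx)[n+1:end])
        return reshape(dx, size(x))
    end
    throw(DimensionMismatch("gradient of size $(size(dx)) cannot be projected onto input of size $(size(x))"))
end

x = randn(3, 3)
y = cat(x; dims=4)               # 3×3×1×1, as in the MWE
dy = cos.(y)                     # gradient of sum(sin.(y)) with respect to y
dx = project_to_input(x, dy)     # reshaped back to 3×3
println(size(dy), " => ", size(dx))
```

Without a step like this, whatever consumes the raw cotangent sees the 3×3×1×1 array, which is why the `*` pullback in the second example receives a 4-d array.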

mcabbott • Feb 12 '22 13:02

IIRC, the hurdle to simply deleting all of these (Zygote's own `cat` rules) is https://github.com/JuliaGPU/GPUArrays.jl/issues/362. `vcat` of a mix of numbers and `CuArray`s mostly works, and its gradient should not use scalar indexing. ChainRules doesn't depend on any GPU packages, so as a workaround it calls `sum(view(x, i))` to get the gradient for each number... which doesn't work yet on GPU arrays. Maybe there's a smarter way.
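A minimal sketch of that idea in plain Julia, on CPU arrays only. `vcat_pullback_sketch` is a hypothetical function, not the ChainRules rule, and it assumes `vcat` of numbers and vectors only (so the incoming gradient `dy` is a vector). Each vector argument gets a slice of `dy`, while each number gets `sum` over a one-element `view` (a range view here, in the spirit of the `sum(view(x, i))` workaround) instead of scalar indexing `dy[i]`.

```julia
# Hypothetical sketch (not the ChainRules rule): split the gradient `dy` of
# y = vcat(xs...) back into one gradient per argument, for numbers and vectors.
function vcat_pullback_sketch(xs::Tuple, dy::AbstractVector)
    # offsets[k] counts the entries of y produced by xs[1:k-1]; length(::Number) == 1
    offsets = cumsum([0; [length(x) for x in xs]])
    return ntuple(length(xs)) do k
        lo, hi = offsets[k] + 1, offsets[k + 1]
        if xs[k] isa Number
            # avoid scalar getindex dy[lo]: reduce over a 1-element view instead
            sum(view(dy, lo:hi))
        else
            dy[lo:hi]  # slice for a vector argument
        end
    end
end

xs = (1.0, [2.0, 3.0], 4.0)
y = vcat(xs...)                  # [1.0, 2.0, 3.0, 4.0]
dy = [10.0, 20.0, 30.0, 40.0]    # an example incoming gradient for y
dxs = vcat_pullback_sketch(xs, dy)
```

This only exercises the CPU path; per the comment above, the `sum(view(...))` approach did not yet work for GPU arrays.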

mcabbott • Feb 13 '22 01:02

This is solved now, right?

mzgubic • Aug 01 '22 21:08