Zygote.jl

Missing support for muladd when broadcasting with a complex argument

Open rcalxrc08 opened this issue 8 months ago • 1 comments

Hi all,

I noticed the following failure when combining complex numbers, muladd, and forward mode (I think forward mode is involved because I am broadcasting a function over a vector). I am using Julia 1.9.3:

Julia Version 1.9.3
Commit bed2cd540a (2023-08-24 14:43 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Windows (x86_64-w64-mingw32)
  CPU: 8 × Intel(R) Core(TM) i7-8550U CPU @ 1.80GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-14.0.6 (ORCJIT, skylake)
  Threads: 1 on 8 virtual cores

And Zygote v0.6.64. The MWE is the following:

using Zygote
function f_no_muladd(x,a,b)
	vec_res=@. real(a*exp(x+b*im))
	return sum(vec_res)
end

function f_muladd(x,a,b)
	vec_res=@. real(a*exp(muladd(b,im,x)))
	return sum(vec_res)
end
x=ones(Float64,10);
a=1.0;
b=2.0;
Zygote.gradient(f_no_muladd,x,a,b) #completely fine
Zygote.gradient(f_muladd,x,a,b) # This call fails
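
For reference, both functions should return the same value, since muladd(b, im, x) equals b*im + x and, for real x, a, b, real(a*exp(x + b*im)) = a*exp(x)*cos(b). Below is a minimal sketch of the gradient one would expect from either call, using only Base Julia. The helper name analytic_gradient is hypothetical and introduced here only for illustration; it is not part of Zygote or ForwardDiff.

```julia
# Hypothetical helper (not from Zygote/ForwardDiff): closed-form gradient of
#   f(x, a, b) = sum(real(a * exp(x + b*im))) = a * sum(exp(x)) * cos(b)
# valid for real x, a, b.
function analytic_gradient(x, a, b)
    ex = exp.(x)
    dx = a .* ex .* cos(b)       # derivative with respect to each x[i]
    da = sum(ex) * cos(b)        # derivative with respect to a
    db = -a * sum(ex) * sin(b)   # derivative with respect to b
    return (dx, da, db)
end

x = ones(Float64, 10); a = 1.0; b = 2.0
dx, da, db = analytic_gradient(x, a, b)

# Sanity check: the complex form and the real closed form give the same value.
f_ref(x, a, b) = sum(@. real(a * exp(x + b * im)))
@assert f_ref(x, a, b) ≈ a * sum(exp.(x)) * cos(b)
```

The tuple (dx, da, db) can be compared with the result of Zygote.gradient(f_no_muladd, x, a, b), and with f_muladd once that call works.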

The error from the f_muladd call is actually raised inside ForwardDiff.jl:

ERROR: MethodError: no method matching calc_muladd_xyz(::ForwardDiff.Dual{Nothing, Bool, 6}, ::ForwardDiff.Dual{Nothing, Float64, 3}, ::ForwardDiff.Dual{Nothing, Float64, 3})

Closest candidates are:
  calc_muladd_xyz(::ForwardDiff.Dual{T, <:Any, N}, ::ForwardDiff.Dual{T, <:Any, N}, ::ForwardDiff.Dual{T, <:Any, N}) where {T, N}
   @ ForwardDiff C:\Users\Nicola\.julia\packages\ForwardDiff\PcZ48\src\dual.jl:637

Stacktrace:
  [1] muladd
    @ C:\Users\Nicola\.julia\packages\ForwardDiff\PcZ48\src\dual.jl:155 [inlined]
  [2] muladd(z::Complex{ForwardDiff.Dual{Nothing, Bool, 6}}, x::ForwardDiff.Dual{Nothing, Float64, 3}, y::ForwardDiff.Dual{Nothing, Float64, 3})
    @ Base .\complex.jl:340
  [3] muladd(x::ForwardDiff.Dual{Nothing, Float64, 3}, z::Complex{ForwardDiff.Dual{Nothing, Bool, 6}}, y::ForwardDiff.Dual{Nothing, Float64, 3})
    @ Base .\complex.jl:339
  [4] (::Zygote.var"#1404#1405"{typeof(muladd)})(::Float64, ::Complex{Bool}, ::Float64)
    @ Zygote C:\Users\Nicola\.julia\packages\Zygote\4SSHS\src\lib\broadcast.jl:276
  [5] _broadcast_getindex_evalf
    @ .\broadcast.jl:683 [inlined]
  [6] _broadcast_getindex
    @ .\broadcast.jl:656 [inlined]
  [7] getindex
    @ .\broadcast.jl:610 [inlined]
  [8] copy
    @ .\broadcast.jl:912 [inlined]
  [9] materialize
    @ .\broadcast.jl:873 [inlined]
 [10] broadcast_forward
    @ C:\Users\Nicola\.julia\packages\Zygote\4SSHS\src\lib\broadcast.jl:282 [inlined]
 [11] _broadcast_generic
    @ C:\Users\Nicola\.julia\packages\Zygote\4SSHS\src\lib\broadcast.jl:212 [inlined]
 [12] adjoint
    @ C:\Users\Nicola\.julia\packages\Zygote\4SSHS\src\lib\broadcast.jl:205 [inlined]
 [13] _pullback(::Zygote.Context{false}, ::typeof(Base.Broadcast.broadcasted), ::Base.Broadcast.DefaultArrayStyle{1}, ::typeof(muladd), ::Float64, ::Complex{Bool}, ::Vector{Float64})
    @ Zygote C:\Users\Nicola\.julia\packages\ZygoteRules\OgCVT\src\adjoint.jl:66
 [14] _apply(::Function, ::Vararg{Any})
    @ Core .\boot.jl:838
 [15] adjoint
    @ C:\Users\Nicola\.julia\packages\Zygote\4SSHS\src\lib\lib.jl:203 [inlined]
 [16] _pullback
    @ C:\Users\Nicola\.julia\packages\ZygoteRules\OgCVT\src\adjoint.jl:66 [inlined]
 [17] _pullback
    @ .\broadcast.jl:1317 [inlined]
 [18] _pullback
    @ .\REPL[3]:2 [inlined]
 [19] _pullback(::Zygote.Context{false}, ::typeof(f_muladd), ::Vector{Float64}, ::Float64, ::Float64)
    @ Zygote C:\Users\Nicola\.julia\packages\Zygote\4SSHS\src\compiler\interface2.jl:0
 [20] pullback(::Function, ::Zygote.Context{false}, ::Vector{Float64}, ::Vararg{Any})
    @ Zygote C:\Users\Nicola\.julia\packages\Zygote\4SSHS\src\compiler\interface.jl:44
 [21] pullback(::Function, ::Vector{Float64}, ::Float64, ::Vararg{Float64})
    @ Zygote C:\Users\Nicola\.julia\packages\Zygote\4SSHS\src\compiler\interface.jl:42
 [22] gradient(::Function, ::Vector{Float64}, ::Vararg{Any})
    @ Zygote C:\Users\Nicola\.julia\packages\Zygote\4SSHS\src\compiler\interface.jl:96
 [23] top-level scope
    @ REPL[8]:1

rcalxrc08, Oct 07 '23 13:10

Notice that the error mixes Duals with 6 and 3 partials, which doesn't make sense for ForwardDiff: all Duals in a single operation must carry the same number of partials, as these examples show:

julia> using ForwardDiff: Dual

julia> muladd(Dual(1,2), Dual(3,4), Dual(5,6))
Dual{Nothing}(8,16)

julia> muladd(Dual(1,2,0), Dual(3,4), Dual(5,6))
ERROR: MethodError: no method matching calc_muladd_xyz(::Dual{Nothing, Int64, 2}, ::Dual{Nothing, Int64, 1}, ::Dual{Nothing, Int64, 1})

So the bug is somehow on the Zygote side?

The use of Dual for complex-number broadcasting was added in https://github.com/FluxML/Zygote.jl/pull/1324. It would be worth checking whether https://github.com/FluxML/Zygote.jl/pull/1441 changes anything.

mcabbott, Oct 07 '23 23:10