Zygote.jl
Missing support for muladd when broadcasting with a complex argument
Hi all,
I noticed the following when combining complex numbers, muladd, and forward mode (I believe forward mode is involved because I am broadcasting a function over a vector). I am using Julia 1.9.3:
Julia Version 1.9.3
Commit bed2cd540a (2023-08-24 14:43 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Windows (x86_64-w64-mingw32)
CPU: 8 × Intel(R) Core(TM) i7-8550U CPU @ 1.80GHz
WORD_SIZE: 64
LIBM: libopenlibm
LLVM: libLLVM-14.0.6 (ORCJIT, skylake)
Threads: 1 on 8 virtual cores
And Zygote v0.6.64. The MWE is the following:
using Zygote
function f_no_muladd(x, a, b)
    vec_res = @. real(a * exp(x + b * im))
    return sum(vec_res)
end
function f_muladd(x, a, b)
    vec_res = @. real(a * exp(muladd(b, im, x)))
    return sum(vec_res)
end
x = ones(Float64, 10);
a = 1.0;
b = 2.0;
Zygote.gradient(f_no_muladd, x, a, b) # completely fine
Zygote.gradient(f_muladd, x, a, b) # this call fails
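For reference, both functions compute sum(real(a*exp(x + b*im))) = a*cos(b)*sum(exp.(x)), so the expected gradient has a closed form. The sketch below uses only Base Julia (no Zygote or ForwardDiff) and checks that closed form against central finite differences. The names analytic_gradient, f_plain, and central_diff are hypothetical helpers introduced here for illustration; they are not part of any package.

```julia
# Closed-form gradient of f(x, a, b) = sum(real(a * exp(x + b*im)))
#                                   = a * cos(b) * sum(exp.(x))
function analytic_gradient(x, a, b)
    ex = exp.(x)
    dx = a .* cos(b) .* ex        # df/dx[i] = a * exp(x[i]) * cos(b)
    da = cos(b) * sum(ex)         # df/da
    db = -a * sin(b) * sum(ex)    # df/db
    return dx, da, db
end

# Same computation as f_muladd in the MWE, evaluated without any AD package.
function f_plain(x, a, b)
    vec_res = @. real(a * exp(muladd(b, im, x)))
    return sum(vec_res)
end

# Central finite difference of a scalar function g at t (sanity check only).
central_diff(g, t; h = 1e-6) = (g(t + h) - g(t - h)) / (2h)

x = ones(Float64, 10); a = 1.0; b = 2.0
dx, da, db = analytic_gradient(x, a, b)

# The closed form should agree with finite differences in a and b.
@assert isapprox(central_diff(t -> f_plain(x, t, b), a), da; rtol = 1e-6)
@assert isapprox(central_diff(t -> f_plain(x, a, t), b), db; rtol = 1e-6)
```

A working gradient call, such as the one on f_no_muladd, can be compared against (dx, da, db) to confirm that the two formulations are meant to be equivalent.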
The error is actually in ForwardDiff.jl:
ERROR: MethodError: no method matching calc_muladd_xyz(::ForwardDiff.Dual{Nothing, Bool, 6}, ::ForwardDiff.Dual{Nothing, Float64, 3}, ::ForwardDiff.Dual{Nothing, Float64, 3})
Closest candidates are:
calc_muladd_xyz(::ForwardDiff.Dual{T, <:Any, N}, ::ForwardDiff.Dual{T, <:Any, N}, ::ForwardDiff.Dual{T, <:Any, N}) where {T, N}
@ ForwardDiff C:\Users\Nicola\.julia\packages\ForwardDiff\PcZ48\src\dual.jl:637
Stacktrace:
[1] muladd
@ C:\Users\Nicola\.julia\packages\ForwardDiff\PcZ48\src\dual.jl:155 [inlined]
[2] muladd(z::Complex{ForwardDiff.Dual{Nothing, Bool, 6}}, x::ForwardDiff.Dual{Nothing, Float64, 3}, y::ForwardDiff.Dual{Nothing, Float64, 3})
@ Base .\complex.jl:340
[3] muladd(x::ForwardDiff.Dual{Nothing, Float64, 3}, z::Complex{ForwardDiff.Dual{Nothing, Bool, 6}}, y::ForwardDiff.Dual{Nothing, Float64, 3})
@ Base .\complex.jl:339
[4] (::Zygote.var"#1404#1405"{typeof(muladd)})(::Float64, ::Complex{Bool}, ::Float64)
@ Zygote C:\Users\Nicola\.julia\packages\Zygote\4SSHS\src\lib\broadcast.jl:276
[5] _broadcast_getindex_evalf
@ .\broadcast.jl:683 [inlined]
[6] _broadcast_getindex
@ .\broadcast.jl:656 [inlined]
[7] getindex
@ .\broadcast.jl:610 [inlined]
[8] copy
@ .\broadcast.jl:912 [inlined]
[9] materialize
@ .\broadcast.jl:873 [inlined]
[10] broadcast_forward
@ C:\Users\Nicola\.julia\packages\Zygote\4SSHS\src\lib\broadcast.jl:282 [inlined]
[11] _broadcast_generic
@ C:\Users\Nicola\.julia\packages\Zygote\4SSHS\src\lib\broadcast.jl:212 [inlined]
[12] adjoint
@ C:\Users\Nicola\.julia\packages\Zygote\4SSHS\src\lib\broadcast.jl:205 [inlined]
[13] _pullback(::Zygote.Context{false}, ::typeof(Base.Broadcast.broadcasted), ::Base.Broadcast.DefaultArrayStyle{1}, ::typeof(muladd), ::Float64, ::Complex{Bool}, ::Vector{Float64})
@ Zygote C:\Users\Nicola\.julia\packages\ZygoteRules\OgCVT\src\adjoint.jl:66
[14] _apply(::Function, ::Vararg{Any})
@ Core .\boot.jl:838
[15] adjoint
@ C:\Users\Nicola\.julia\packages\Zygote\4SSHS\src\lib\lib.jl:203 [inlined]
[16] _pullback
@ C:\Users\Nicola\.julia\packages\ZygoteRules\OgCVT\src\adjoint.jl:66 [inlined]
[17] _pullback
@ .\broadcast.jl:1317 [inlined]
[18] _pullback
@ .\REPL[3]:2 [inlined]
[19] _pullback(::Zygote.Context{false}, ::typeof(f_muladd), ::Vector{Float64}, ::Float64, ::Float64)
@ Zygote C:\Users\Nicola\.julia\packages\Zygote\4SSHS\src\compiler\interface2.jl:0
[20] pullback(::Function, ::Zygote.Context{false}, ::Vector{Float64}, ::Vararg{Any})
@ Zygote C:\Users\Nicola\.julia\packages\Zygote\4SSHS\src\compiler\interface.jl:44
[21] pullback(::Function, ::Vector{Float64}, ::Float64, ::Vararg{Float64})
@ Zygote C:\Users\Nicola\.julia\packages\Zygote\4SSHS\src\compiler\interface.jl:42
[22] gradient(::Function, ::Vector{Float64}, ::Vararg{Any})
@ Zygote C:\Users\Nicola\.julia\packages\Zygote\4SSHS\src\compiler\interface.jl:96
[23] top-level scope
@ REPL[8]:1
Notice that the error shows Duals with 6 and 3 partials together, which doesn't make sense for ForwardDiff:
julia> using ForwardDiff: Dual
julia> muladd(Dual(1,2), Dual(3,4), Dual(5,6))
Dual{Nothing}(8,16)
julia> muladd(Dual(1,2,0), Dual(3,4), Dual(5,6))
ERROR: MethodError: no method matching calc_muladd_xyz(::Dual{Nothing, Int64, 2}, ::Dual{Nothing, Int64, 1}, ::Dual{Nothing, Int64, 1})
So the bug is here in Zygote somehow?
The use of Dual for complex-number broadcasting was added in https://github.com/FluxML/Zygote.jl/pull/1324. It would be worth checking whether https://github.com/FluxML/Zygote.jl/pull/1441 changes anything.
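The partial-count mismatch in the error can be illustrated with a toy model. The sketch below is not ForwardDiff's Dual and not Zygote's actual seeding code. ToyDual, seed_real, and seed_wide are hypothetical names, and the idea that complex arguments are seeded with twice as many partials as real ones is only an assumption inferred from the 6-versus-3 counts in the stack trace for a three-argument call.

```julia
# Hypothetical toy dual number: a value plus N partial derivatives.
# This is a minimal stand-in, not ForwardDiff.Dual.
struct ToyDual{N}
    value::Float64
    partials::NTuple{N,Float64}
end

# Addition is only defined when both operands carry the same number of partials,
# mirroring the shared `N` in the candidate method listed in the error message.
Base.:+(p::ToyDual{N}, q::ToyDual{N}) where {N} =
    ToyDual{N}(p.value + q.value, p.partials .+ q.partials)

# Seed argument `i` out of `n` real arguments: one partial slot per argument.
seed_real(v, i, n) = ToyDual{n}(v, ntuple(j -> Float64(j == i), n))

# ASSUMED seeding for a component of a complex argument: twice as many slots.
seed_wide(v, i, n) = ToyDual{2 * n}(v, ntuple(j -> Float64(j == i), 2 * n))

x = seed_real(1.0, 1, 3)   # 3 partials
y = seed_real(5.0, 3, 3)   # 3 partials
z = seed_wide(0.0, 2, 3)   # 6 partials

same_width = x + y         # fine: both operands have 3 partials

# Mixing 3 and 6 partials has no matching method, as in the reported error.
mixed_fails = try
    x + z
    false
catch err
    err isa MethodError
end
```

Under that assumption, any broadcast function that combines a real argument directly with a component of a complex argument would hit the same kind of MethodError, which matches the failing muladd(b, im, x) call.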