
Reverse-over-reverse fails with a Flux NN

Open YichengDWu opened this issue 1 year ago • 10 comments

using Flux, Zygote

model = Dense(3,1)
grad_f(x) = gradient(x -> sum(model(x)),x)[1]
Zygote.jacobian(grad_f,rand(3))
ERROR: Can't differentiate foreigncall expression.
You might want to check the Zygote limitations documentation.
https://fluxml.ai/Zygote.jl/dev/limitations.html

Stacktrace:
  [1] error(s::String)
    @ Base .\error.jl:33
  [2] Pullback
    @ .\iddict.jl:102 [inlined]
  [3] (::typeof(∂(get)))(Δ::Nothing)
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
  [4] Pullback
    @ C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\lib\lib.jl:68 [inlined]
  [5] (::typeof(∂(accum_global)))(Δ::Nothing)
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
  [6] Pullback
    @ C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\lib\lib.jl:79 [inlined]
  [7] (::typeof(∂(λ)))(Δ::Nothing)
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
  [8] Pullback
    @ C:\Users\Luffy\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:67 [inlined]
  [9] (::typeof(∂(λ)))(Δ::Nothing)
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
 [10] getindex
    @ .\tuple.jl:29 [inlined]
 [11] map
    @ .\tuple.jl:222 [inlined]
 [12] unthunk_tangent
    @ C:\Users\Luffy\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:36 [inlined]
 [13] #1789#back
    @ C:\Users\Luffy\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:67 [inlined]
 [14] (::typeof(∂(λ)))(Δ::Tuple{Nothing, Vector{Float64}})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
 [15] Pullback
    @ C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface.jl:41 [inlined]
 [16] (::typeof(∂(λ)))(Δ::Tuple{Vector{Float64}})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
 [17] Pullback
    @ C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface.jl:76 [inlined]
 [18] (::typeof(∂(gradient)))(Δ::Tuple{Vector{Float64}})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
 [19] Pullback
    @ .\REPL[10]:1 [inlined]
 [20] (::typeof(∂(grad_f)))(Δ::Vector{Float64})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
 [21] #216
    @ C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\lib\lib.jl:207 [inlined]
 [22] (::Zygote.var"#1909#back#218"{Zygote.var"#216#217"{Tuple{Tuple{Nothing}}, typeof(∂(grad_f))}})(Δ::Vector{Float64})
    @ Zygote C:\Users\Luffy\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:67
 [23] Pullback
    @ .\operators.jl:1085 [inlined]
 [24] (::typeof(∂(#_#83)))(Δ::Vector{Float64})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
 [25] (::Zygote.var"#216#217"{Tuple{Tuple{Nothing, Nothing}, Tuple{Nothing}}, typeof(∂(#_#83))})(Δ::Vector{Float64})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\lib\lib.jl:207
 [26] #1909#back
    @ C:\Users\Luffy\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:67 [inlined]
 [27] Pullback
    @ .\operators.jl:1085 [inlined]
 [28] (::typeof(∂(ComposedFunction{typeof(Zygote._jvec), typeof(grad_f)}(Zygote._jvec, grad_f))))(Δ::Vector{Float64})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
 [29] (::Zygote.var"#60#61"{typeof(∂(ComposedFunction{typeof(Zygote._jvec), typeof(grad_f)}(Zygote._jvec, grad_f)))})(Δ::Vector{Float64})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface.jl:41
 [30] withjacobian(f::Function, args::Vector{Float64})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\lib\grad.jl:162
 [31] jacobian(f::Function, args::Vector{Float64})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\lib\grad.jl:140
 [32] top-level scope
    @ REPL[12]:1
 [33] top-level scope
    @ C:\Users\Luffy\.julia\packages\CUDA\DfvRa\src\initialization.jl:52

YichengDWu, Jul 19 '22 06:07

The first-order gradient is treating model as a global and trying to differentiate with respect to it as well. I have no idea what the semantics of globals should be under reverse-over-reverse (we ought to make accum_global non-diff or ban it completely), so the easier path is to get rid of them entirely:

julia> grad_f_m(m) = x -> gradient(x -> sum(m(x)), x)[1]  # [edited to make name not clash]
grad_f_m (generic function with 1 method)

julia> grad_f_m(model)(rand(3))
3-element Vector{Float64}:
  0.28216660022735596
 -0.385393351316452
  0.08605238795280457

julia> Zygote.jacobian(grad_f_m(model), rand(3))[1]
3×3 Matrix{Float64}:
 0.0  0.0  0.0
 0.0  0.0  0.0
 0.0  0.0  0.0

Assuming the all-zero jacobian is incorrect, it's perhaps the more "interesting" bug here.

ToucheSir, Jul 19 '22 14:07

Globals are evil.

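I think zero is correct, as the model is linear in x: its gradient is the constant weight vector, so the Jacobian of that gradient (the Hessian) is zero. If you change it to use tanh, then these all agree: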

julia> model
Dense(3 => 1, tanh)  # 4 parameters

julia> rr = rand(3);

julia> ForwardDiff.jacobian(grad_f,rr)
3×3 Matrix{Float64}:
 -0.018781    0.013774   -0.0526393
  0.013774   -0.0101018   0.0386056
 -0.0526393   0.0386056  -0.147537

julia> ForwardDiff.jacobian(grad_f_m(model), rr)
3×3 Matrix{Float64}:
 -0.018781    0.013774   -0.0526393
  0.013774   -0.0101018   0.0386056
 -0.0526393   0.0386056  -0.147537

julia> Zygote.jacobian(grad_f_m(model), rr)[1]
3×3 Matrix{Float64}:
 -0.018781    0.013774   -0.0526393
  0.013774   -0.0101018   0.0386056
 -0.0526393   0.0386056  -0.147537
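To make the zero/non-zero distinction concrete: a Dense(3 => 1, σ) layer computes f(x) = σ(w·x + b), with w the single weight row. With σ = identity the gradient is just w, a constant, so its Jacobian is zero. With σ = tanh and z = w·x + b, the gradient is (1 - tanh²z) w and the Hessian is -2 tanh z (1 - tanh²z) w wᵀ, a symmetric rank-1 matrix, which is consistent with the proportional rows of the matrices above. A stdlib-only sketch (no Flux or Zygote; the helpers tanh_grad, tanh_hess and fd_jacobian are hypothetical names) that checks this formula against central finite differences:

```julia
using LinearAlgebra, Random

Random.seed!(1234)

# f(x) = sum(tanh.(W * x .+ b)) for a 1×n weight W, i.e. a Dense(n => 1, tanh)
# layer written out by hand. With z = w⋅x + b, its gradient is (1 - tanh(z)^2) w.
function tanh_grad(W, b, x)
    z = (W * x .+ b)[1]
    return (1 - tanh(z)^2) .* vec(W)
end

# Analytic Hessian: -2 tanh(z) (1 - tanh(z)^2) w wᵀ, symmetric and rank 1
function tanh_hess(W, b, x)
    z = (W * x .+ b)[1]
    w = vec(W)
    return (-2 * tanh(z) * (1 - tanh(z)^2)) .* (w * w')
end

# Central finite-difference Jacobian of a vector-valued function g at x
function fd_jacobian(g, x; h = 1e-6)
    n = length(x)
    J = zeros(n, n)
    for j in 1:n
        e = zeros(n)
        e[j] = h
        J[:, j] = (g(x .+ e) .- g(x .- e)) ./ (2h)
    end
    return J
end

W, b, x = randn(1, 3), randn(1), rand(3)

H    = tanh_hess(W, b, x)
H_fd = fd_jacobian(y -> tanh_grad(W, b, y), x)   # should match H closely

# Identity activation: the gradient is the constant vec(W), so the Hessian is 0
H_affine = fd_jacobian(y -> vec(W), x)
```

The finite-difference step only checks the algebra; it says nothing about which AD route Zygote takes.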

mcabbott, Jul 19 '22 17:07

I expect it to behave exactly like a pure function, which does work:

julia> grad_f(x) = gradient(x -> sum(x.^3), x)[1]
grad_f (generic function with 1 method)

julia> Zygote.jacobian(grad_f,rand(3))[1]
3×3 Matrix{Float64}:
 0.0456902  0.0       0.0
 0.0        0.449621  0.0
 0.0        0.0       4.86983
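For comparison, a sketch of the same pure-function case computed both ways Zygote offers, assuming a Zygote version that provides Zygote.hessian: reverse-over-reverse via jacobian of gradient (as in the comment above), and forward-over-reverse via Zygote.hessian, which pushes ForwardDiff dual numbers through Zygote's gradient. For sum(x.^3) the exact Hessian is Diagonal(6x), so the diagonal entries above should be 6 times the random input. The name grad_cube is hypothetical.

```julia
using Zygote, LinearAlgebra

x = rand(3)

# A pure first-order gradient: no globals are captured, and ∇ sum(y.^3) = 3y.^2
grad_cube(v) = Zygote.gradient(y -> sum(y .^ 3), v)[1]

# Reverse over reverse: Zygote differentiates its own gradient call
H_rr = Zygote.jacobian(grad_cube, x)[1]

# Forward over reverse: ForwardDiff dual numbers pushed through Zygote's gradient
H_fr = Zygote.hessian(y -> sum(y .^ 3), x)

# The exact Hessian of sum(x.^3) is Diagonal(6x)
H_exact = Diagonal(6 .* x)
```

In this thread the forward-over-reverse route (ForwardDiff.jacobian(grad_f, rr)) also worked with the global-capturing grad_f, which suggests Zygote.hessian as a possible workaround; that is not verified here for the Flux or Lux models.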

YichengDWu, Jul 19 '22 20:07

We have more than one bug here 🥲, see #1264.

YichengDWu, Jul 19 '22 21:07

Copying the example from there:

julia> function f(x, bias)
           jac = Zygote.jacobian(x->x.^3, x)[1]
           return jac * x .+ bias
       end
f (generic function with 1 method)

julia> x,bias = rand(3),rand(3)
([0.2279638899624825, 0.6476786632858718, 0.13745627655377346], [0.051516386842686224, 0.6842360463718182, 0.22031281411507742])

julia> Zygote.gradient(b -> sum(f(x,b)), rand(3))
ERROR: Mutating arrays is not supported -- called copyto!(SubArray{Float64, 1, Matrix{Float64}, Tuple{Int64, Base.Slice{Base.OneTo{Int64}}}, true}, ...)
This error occurs when you ask Zygote to differentiate operations that change
the elements of arrays in place (e.g. setting values with x .= ...)

Possible fixes:
- avoid mutating operations (preferred)
- or read the documentation and solutions for this error
  https://fluxml.ai/Zygote.jl/dev/limitations.html#Array-mutation-1

Stacktrace:
  [1] error(s::String)
    @ Base .\error.jl:33
  [2] _throw_mutation_error(f::Function, args::SubArray{Float64, 1, Matrix{Float64}, Tuple{Int64, Base.Slice{Base.OneTo{Int64}}}, true})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\lib\array.jl:70
  [3] (::Zygote.var"#448#449"{SubArray{Float64, 1, Matrix{Float64}, Tuple{Int64, Base.Slice{Base.OneTo{Int64}}}, true}})(#unused#::Nothing)
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\lib\array.jl:85
  [4] (::Zygote.var"#2506#back#450"{Zygote.var"#448#449"{SubArray{Float64, 1, Matrix{Float64}, Tuple{Int64, Base.Slice{Base.OneTo{Int64}}}, true}}})(Δ::Nothing)
    @ Zygote C:\Users\Luffy\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:67
  [5] Pullback
    @ C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\lib\grad.jl:183 [inlined]
  [6] (::typeof(∂(_gradcopy!)))(Δ::Nothing)
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
  [7] Pullback
    @ C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\lib\grad.jl:165 [inlined]
  [8] (::typeof(∂(withjacobian)))(Δ::NamedTuple{(:val, :grad), Tuple{Nothing, Tuple{Matrix{Float64}}}})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
  [9] (::Zygote.var"#216#217"{Tuple{Tuple{Nothing}, Tuple{Nothing}}, typeof(∂(withjacobian))})(Δ::NamedTuple{(:val, :grad), Tuple{Nothing, Tuple{Matrix{Float64}}}})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\lib\lib.jl:207
 [10] #1909#back
    @ C:\Users\Luffy\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:67 [inlined]
 [11] Pullback
    @ C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\lib\grad.jl:140 [inlined]
 [12] (::typeof(∂(jacobian)))(Δ::Tuple{Matrix{Float64}})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
 [13] Pullback
    @ .\REPL[20]:2 [inlined]
 [14] (::typeof(∂(f)))(Δ::FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
 [15] Pullback
    @ .\REPL[22]:1 [inlined]
 [16] (::typeof(∂(#23)))(Δ::Float64)
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
 [17] (::Zygote.var"#60#61"{typeof(∂(#23))})(Δ::Float64)
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface.jl:41
 [18] gradient(f::Function, args::Vector{Float64})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface.jl:76
 [19] top-level scope
    @ REPL[22]:1
 [20] top-level scope
    @ C:\Users\Luffy\.julia\packages\CUDA\DfvRa\src\initialization.jl:52
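The failing f can be checked by hand: y -> y.^3 acts elementwise, so its Jacobian is Diagonal(3x.^2), f(x, bias) equals 3x.^3 .+ bias, and the gradient of sum(f(x, b)) with respect to b is a vector of ones. The error comes from jacobian filling its result in place (_gradcopy! in the trace), which Zygote cannot differentiate. A sketch of one way to sidestep it in this particular case, assuming the Jacobian does not depend on the variable being differentiated: build it outside the differentiated closure. grad_wrt_bias is a hypothetical name, and this only avoids the bug; it does not make nested jacobian differentiable.

```julia
using Zygote

# Same function as in the thread: the Jacobian of y -> y.^3 is
# Diagonal(3 .* x.^2), so this equals 3 .* x.^3 .+ bias
function f(x, bias)
    jac = Zygote.jacobian(y -> y .^ 3, x)[1]
    return jac * x .+ bias
end

# Hypothetical workaround: jac depends only on x, so build it before the
# differentiated closure; Zygote then never differentiates `jacobian` itself
function grad_wrt_bias(x, bias)
    jac = Zygote.jacobian(y -> y .^ 3, x)[1]
    return Zygote.gradient(b -> sum(jac * x .+ b), bias)[1]
end

x, bias = rand(3), rand(3)
f(x, bias)                  # the forward pass runs; only differentiating it fails
grad_wrt_bias(x, bias)      # gradient of sum(... .+ b) with respect to b: all ones
```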

YichengDWu, Jul 19 '22 21:07

This might be a better example, since layers in Lux should be treated exactly like pure functions:

using Lux, Zygote, Random

model = Dense(3,1)
ps,st = Lux.setup(Random.default_rng(),model)
grad_f(x) = gradient(x -> sum(model(x,ps,st)[1]),x)[1]

Zygote.jacobian(grad_f,rand(3))
ERROR: Can't differentiate foreigncall expression.
You might want to check the Zygote limitations documentation.
https://fluxml.ai/Zygote.jl/dev/limitations.html

Stacktrace:
  [1] error(s::String)
    @ Base .\error.jl:33
  [2] Pullback
    @ .\iddict.jl:102 [inlined]
  [3] (::typeof(∂(get)))(Δ::Nothing)
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
  [4] Pullback
    @ C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\lib\lib.jl:68 [inlined]
  [5] (::typeof(∂(accum_global)))(Δ::Nothing)
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
  [6] Pullback
    @ C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\lib\lib.jl:79 [inlined]
  [7] (::typeof(∂(λ)))(Δ::Nothing)
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
  [8] Pullback
    @ C:\Users\Luffy\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:67 [inlined]
  [9] (::typeof(∂(λ)))(Δ::Nothing)
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
 [10] getindex
    @ .\tuple.jl:29 [inlined]
 [11] map
    @ .\tuple.jl:222 [inlined]
 [12] unthunk_tangent
    @ C:\Users\Luffy\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:36 [inlined]
 [13] #1789#back
    @ C:\Users\Luffy\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:67 [inlined]
 [14] (::typeof(∂(λ)))(Δ::Tuple{Nothing, Vector{Float64}})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
 [15] Pullback
    @ C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface.jl:41 [inlined]
 [16] (::typeof(∂(λ)))(Δ::Tuple{Vector{Float64}})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
 [17] Pullback
    @ C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface.jl:76 [inlined]
 [18] (::typeof(∂(gradient)))(Δ::Tuple{Vector{Float64}})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
 [19] Pullback
    @ .\REPL[11]:1 [inlined]
 [20] (::typeof(∂(grad_f)))(Δ::Vector{Float64})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
 [21] #216
    @ C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\lib\lib.jl:207 [inlined]
 [22] (::Zygote.var"#1909#back#218"{Zygote.var"#216#217"{Tuple{Tuple{Nothing}}, typeof(∂(grad_f))}})(Δ::Vector{Float64})
    @ Zygote C:\Users\Luffy\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:67
 [23] Pullback
    @ .\operators.jl:1085 [inlined]
 [24] (::typeof(∂(#_#83)))(Δ::Vector{Float64})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
 [25] (::Zygote.var"#216#217"{Tuple{Tuple{Nothing, Nothing}, Tuple{Nothing}}, typeof(∂(#_#83))})(Δ::Vector{Float64})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\lib\lib.jl:207
 [26] #1909#back
    @ C:\Users\Luffy\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:67 [inlined]
 [27] Pullback
    @ .\operators.jl:1085 [inlined]
 [28] (::typeof(∂(ComposedFunction{typeof(Zygote._jvec), typeof(grad_f)}(Zygote._jvec, grad_f))))(Δ::Vector{Float64})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
 [29] (::Zygote.var"#60#61"{typeof(∂(ComposedFunction{typeof(Zygote._jvec), typeof(grad_f)}(Zygote._jvec, grad_f)))})(Δ::Vector{Float64})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface.jl:41
 [30] withjacobian(f::Function, args::Vector{Float64})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\lib\grad.jl:162
 [31] jacobian(f::Function, args::Vector{Float64})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\lib\grad.jl:140
 [32] top-level scope
    @ REPL[13]:1
 [33] top-level scope
    @ C:\Users\Luffy\.julia\packages\CUDA\DfvRa\src\initialization.jl:52

YichengDWu, Jul 23 '22 00:07

I tried your method and it didn't work with Lux. Why? @ToucheSir

using Lux, Zygote, Random

model = Dense(3,1)
ps,st = Lux.setup(Random.default_rng(),model)
grad_f(m,p,s) = x -> gradient(y -> sum(m(y,p,s)[1]),x)[1]

Zygote.jacobian(grad_f(model,ps,st),rand(3))
ERROR: Can't differentiate foreigncall expression.
You might want to check the Zygote limitations documentation.
https://fluxml.ai/Zygote.jl/dev/limitations.html

Stacktrace:
  [1] error(s::String)
    @ Base .\error.jl:33
  [2] Pullback
    @ .\essentials.jl:599 [inlined]
  [3] (::typeof(∂(getindex)))(Δ::Nothing)
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
  [4] Pullback
    @ C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\tools\builtins.jl:12 [inlined]
  [5] (::typeof(∂(literal_getindex)))(Δ::Nothing)
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
  [6] Pullback
    @ .\reflection.jl:752 [inlined]
  [7] (::typeof(∂(fieldcount)))(Δ::Nothing)
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
  [8] Pullback
    @ C:\Users\Luffy\.julia\packages\ChainRulesCore\ctmSK\src\tangent_types\tangent.jl:220 [inlined]
  [9] (::typeof(∂(canonicalize)))(Δ::Nothing)
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
 [10] Pullback
    @ C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\chainrules.jl:115 [inlined]
 [11] Pullback
    @ C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\chainrules.jl:183 [inlined]
 [12] (::typeof(∂(_project)))(Δ::Nothing)
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
 [13] Pullback
    @ C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\lib\lib.jl:235 [inlined]
 [14] (::typeof(∂(λ)))(Δ::Nothing)
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
 [15] Pullback
    @ C:\Users\Luffy\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:67 [inlined]
 [16] (::typeof(∂(λ)))(Δ::Nothing)
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
 [17] Pullback
    @ C:\Users\Luffy\.julia\packages\Lux\lEqCI\src\layers\basic.jl:631 [inlined]
 [18] (::typeof(∂(λ)))(Δ::Tuple{Nothing, Vector{Float64}, Nothing, Nothing})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
 [19] Pullback
    @ .\REPL[4]:1 [inlined]
 [20] (::typeof(∂(λ)))(Δ::Tuple{Nothing, Vector{Float64}})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
 [21] Pullback
    @ C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface.jl:41 [inlined]
 [22] (::typeof(∂(λ)))(Δ::Tuple{Vector{Float64}})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
 [23] Pullback
    @ C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface.jl:76 [inlined]
 [24] (::typeof(∂(gradient)))(Δ::Tuple{Vector{Float64}})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
 [25] Pullback
    @ .\REPL[4]:1 [inlined]
 [26] (::typeof(∂(λ)))(Δ::Vector{Float64})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
 [27] #216
    @ C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\lib\lib.jl:207 [inlined]
 [28] (::Zygote.var"#1909#back#218"{Zygote.var"#216#217"{Tuple{Tuple{Nothing}}, typeof(∂(λ))}})(Δ::Vector{Float64})
    @ Zygote C:\Users\Luffy\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:67
 [29] Pullback
    @ .\operators.jl:1085 [inlined]
 [30] (::typeof(∂(#_#83)))(Δ::Vector{Float64})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
 [31] (::Zygote.var"#216#217"{Tuple{Tuple{Nothing, Nothing}, Tuple{Nothing}}, typeof(∂(#_#83))})(Δ::Vector{Float64})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\lib\lib.jl:207
 [32] #1909#back
    @ C:\Users\Luffy\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:67 [inlined]
 [33] Pullback
    @ .\operators.jl:1085 [inlined]
 [34] (::typeof(∂(λ)))(Δ::Vector{Float64})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
 [35] (::Zygote.var"#60#61"{typeof(∂(λ))})(Δ::Vector{Float64})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface.jl:41
 [36] withjacobian(f::Function, args::Vector{Float64})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\lib\grad.jl:162
 [37] jacobian(f::Function, args::Vector{Float64})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\lib\grad.jl:140
 [38] top-level scope
    @ REPL[5]:1
 [39] top-level scope
    @ C:\Users\Luffy\.julia\packages\CUDA\DfvRa\src\initialization.jl:52

YichengDWu, Jul 26 '22 06:07

Why would it? It's not like Lux has any special magic here. It's probably calling much the same code under the hood.

ToucheSir, Jul 26 '22 13:07

> It's probably calling much the same code under the hood.

But one works and one doesn't? Maybe I don't understand how Zygote differentiates a functor.

YichengDWu, Jul 26 '22 19:07

I don't think it has anything to do with functors, but rather that Lux has getindex calls in the forward pass where Flux does not. The getindex calls in question don't appear to come from the layer code itself, but from within Zygote's internals: given that https://github.com/FluxML/Zygote.jl/blob/cb59b6c780635a24afdc39c06f4de92ce4f52a0e/src/lib/lib.jl#L207 shows up in the stacktrace, perhaps there is some splatting happening in Lux layers?

ToucheSir, Jul 30 '22 01:07