Zygote.jl — issue thread
Hessian of NN output wrt input (GPU)
Hello,
I need to compute the hessian of a neural network's output with respect to its inputs. I am able to do it on the CPU without any issues, but I'd like to do it on the GPU.
Neither `Zygote.hessian` nor `Zygote.hessian_reverse` works for me. Could you please help sort this out?
Attaching a simple example:
# MWE (original report): build the same NN twice — once with GPU parameters and
# once with CPU parameters — through NeuralPDE/DiffEqFlux helpers, then take
# gradients and Hessians of the scalar network output with respect to the input.
using CUDA, Flux, DiffEqFlux, DiffEqBase, NeuralPDE, Zygote
CUDA.allowscalar(false)
dim = 2; nn = 10;
chain = Chain(Dense(dim, nn, tanh), Dense(nn, nn, tanh), Dense(nn, 1));
# GPU parameter vector and the matching NeuralPDE wrapper
initθ_g = DiffEqFlux.initial_params(chain) |> gpu;
parameterless_type_θ = DiffEqBase.parameterless_type(initθ_g);
phi_g = NeuralPDE.get_phi(chain, parameterless_type_θ);
# CPU parameter vector and the matching NeuralPDE wrapper
initθ_c = DiffEqFlux.initial_params(chain);
parameterless_type_θc = DiffEqBase.parameterless_type(initθ_c);
phi_c = NeuralPDE.get_phi(chain, parameterless_type_θc);
ρFn_c(y) = sum(phi_c(y, initθ_c)); # phi_c returns the scalar NN output (cpu)
ρFn_g(y) = sum(phi_g(y, initθ_g)); # phi_g returns the scalar NN output (GPU)
xg = rand(2) |> gpu; # gpu array
xc = Array(xg); # cpu array
# @show ρFn_c(xc) # works
# @show ρFn_g(xg) # works
## gradient wrt input — first derivatives work on both devices
dρ_g(x) = Zygote.gradient(ρFn_g, x)[1];
#@show dρ_g(xg) # works
dρ_c(x) = Zygote.gradient(ρFn_c, x)[1];
# @show dρ_c(xc) # works
## hessian wrt input — only the CPU path works
d2ρ_c(x) = Zygote.hessian(ρFn_c, x);
# d2ρ_c(xc) # works (note: original comment said xg; the CPU function takes the CPU array xc)
d2ρ_g(x) = Zygote.hessian(ρFn_g, x);
# d2ρ_g(xg) # doesn't work (forward over reverse)
# Zygote.hessian_reverse(ρFn_g, xg) # doesn't work (reverse over reverse)
The error for `Zygote.hessian` is:
ERROR: LoadError: MethodError: no method matching cudnnDataType(::Type{ForwardDiff.Dual{Nothing, Float32, 2}})
Closest candidates are:
cudnnDataType(::Type{Float16}) at /scratch/user/vish0908/.julia/packages/CUDA/9T5Sq/lib/cudnn/util.jl:7
cudnnDataType(::Type{Float32}) at /scratch/user/vish0908/.julia/packages/CUDA/9T5Sq/lib/cudnn/util.jl:8
cudnnDataType(::Type{Float64}) at /scratch/user/vish0908/.julia/packages/CUDA/9T5Sq/lib/cudnn/util.jl:9
...
Stacktrace:
[1] CUDA.CUDNN.cudnnTensorDescriptor(array::CuArray{ForwardDiff.Dual{Nothing, Float32, 2}, 1, CUDA.Mem.DeviceBuffer}; format::CUDA.CUDNN.cudnnTensorFormat_t, dims::Vector{Int32})
@ CUDA.CUDNN /scratch/user/vish0908/.julia/packages/CUDA/9T5Sq/lib/cudnn/tensor.jl:9
[2] CUDA.CUDNN.cudnnTensorDescriptor(array::CuArray{ForwardDiff.Dual{Nothing, Float32, 2}, 1, CUDA.Mem.DeviceBuffer})
@ CUDA.CUDNN /scratch/user/vish0908/.julia/packages/CUDA/9T5Sq/lib/cudnn/tensor.jl:8
[3] cudnnActivationForward!(y::CuArray{ForwardDiff.Dual{Nothing, Float32, 2}, 1, CUDA.Mem.DeviceBuffer}, x::CuArray{ForwardDiff.Dual{Nothing, Float32, 2}, 1, CUDA.Mem.DeviceBuffer}; o::Base.Iterators.Pairs{Symbol, CUDA.CUDNN.cudnnActivationMode_t, Tuple{Symbol}, NamedTuple{(:mode,), Tuple{CUDA.CUDNN.cudnnActivationMode_t}}})
@ CUDA.CUDNN /scratch/user/vish0908/.julia/packages/CUDA/9T5Sq/lib/cudnn/activation.jl:22
[4] (::NNlibCUDA.var"#62#66")(src::CuArray{ForwardDiff.Dual{Nothing, Float32, 2}, 1, CUDA.Mem.DeviceBuffer}, dst::CuArray{ForwardDiff.Dual{Nothing, Float32, 2}, 1, CUDA.Mem.DeviceBuffer})
@ NNlibCUDA /scratch/user/vish0908/.julia/packages/NNlibCUDA/EENEy/src/cudnn/activations.jl:10
[5] materialize(bc::Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1}, Nothing, typeof(tanh), Tuple{CuArray{ForwardDiff.Dual{Nothing, Float32, 2}, 1, CUDA.Mem.DeviceBuffer}}})
@ NNlibCUDA /scratch/user/vish0908/.julia/packages/NNlibCUDA/EENEy/src/cudnn/activations.jl:30
[6] adjoint
@ /scratch/user/vish0908/.julia/packages/Zygote/TaBlo/src/lib/broadcast.jl:104 [inlined]
[7] _pullback(__context__::Zygote.Context, 649::typeof(Base.Broadcast.broadcasted), 650::typeof(tanh), x::CuArray{ForwardDiff.Dual{Nothing, Float32, 2}, 1, CUDA.Mem.DeviceBuffer})
@ Zygote /scratch/user/vish0908/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:57
[8] _pullback
@ /scratch/user/vish0908/.julia/packages/Flux/Zz9RI/src/layers/basic.jl:148 [inlined]
[9] _pullback(ctx::Zygote.Context, f::Dense{typeof(tanh), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, args::CuArray{ForwardDiff.Dual{Nothing, Float32, 2}, 1, CUDA.Mem.DeviceBuffer})
@ Zygote /scratch/user/vish0908/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
[10] _pullback
@ /scratch/user/vish0908/.julia/packages/Flux/Zz9RI/src/layers/basic.jl:37 [inlined]
[11] _pullback(::Zygote.Context, ::typeof(Flux.applychain), ::Tuple{Dense{typeof(tanh), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(tanh), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}, ::CuArray{ForwardDiff.Dual{Nothing, Float32, 2}, 1, CUDA.Mem.DeviceBuffer})
@ Zygote /scratch/user/vish0908/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
[12] _pullback
@ /scratch/user/vish0908/.julia/packages/Flux/Zz9RI/src/layers/basic.jl:39 [inlined]
[13] _pullback(ctx::Zygote.Context, f::Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(tanh), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, args::CuArray{ForwardDiff.Dual{Nothing, Float32, 2}, 1, CUDA.Mem.DeviceBuffer})
@ Zygote /scratch/user/vish0908/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
[14] _pullback
@ /scratch/user/vish0908/.julia/packages/NeuralPDE/HVA0c/src/pinns_pde_solve.jl:807 [inlined]
[15] _pullback(::Zygote.Context, ::NeuralPDE.var"#271#273"{UnionAll, Flux.var"#60#62"{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}}}, ::CuArray{ForwardDiff.Dual{Nothing, Float32, 2}, 1, CUDA.Mem.DeviceBuffer}, ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
@ Zygote /scratch/user/vish0908/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
[16] _pullback
@ /scratch/user/vish0908/Documents/neuripsCode/missile/test.jl:92 [inlined]
[17] _pullback(ctx::Zygote.Context, f::typeof(ρFn_g), args::CuArray{ForwardDiff.Dual{Nothing, Float32, 2}, 1, CUDA.Mem.DeviceBuffer})
@ Zygote /scratch/user/vish0908/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
[18] _pullback(f::Function, args::CuArray{ForwardDiff.Dual{Nothing, Float32, 2}, 1, CUDA.Mem.DeviceBuffer})
@ Zygote /scratch/user/vish0908/.julia/packages/Zygote/TaBlo/src/compiler/interface.jl:34
[19] pullback(f::Function, args::CuArray{ForwardDiff.Dual{Nothing, Float32, 2}, 1, CUDA.Mem.DeviceBuffer})
@ Zygote /scratch/user/vish0908/.julia/packages/Zygote/TaBlo/src/compiler/interface.jl:40
[20] gradient(f::Function, args::CuArray{ForwardDiff.Dual{Nothing, Float32, 2}, 1, CUDA.Mem.DeviceBuffer})
@ Zygote /scratch/user/vish0908/.julia/packages/Zygote/TaBlo/src/compiler/interface.jl:75
[21] (::Zygote.var"#93#94"{typeof(ρFn_g)})(x::CuArray{ForwardDiff.Dual{Nothing, Float32, 2}, 1, CUDA.Mem.DeviceBuffer})
@ Zygote /scratch/user/vish0908/.julia/packages/Zygote/TaBlo/src/lib/grad.jl:76
[22] forward_jacobian(f::Zygote.var"#93#94"{typeof(ρFn_g)}, x::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, #unused#::Val{2})
@ Zygote /scratch/user/vish0908/.julia/packages/Zygote/TaBlo/src/lib/forward.jl:29
[23] forward_jacobian(f::Function, x::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
@ Zygote /scratch/user/vish0908/.julia/packages/Zygote/TaBlo/src/lib/forward.jl:44
[24] hessian_dual
@ /scratch/user/vish0908/.julia/packages/Zygote/TaBlo/src/lib/grad.jl:76 [inlined]
[25] hessian
@ /scratch/user/vish0908/.julia/packages/Zygote/TaBlo/src/lib/grad.jl:74 [inlined]
and for `Zygote.hessian_reverse` it is:
ERROR: LoadError: Compiling Tuple{Flux.var"#_restructure_pullback#56"{Int64}, NamedTuple{(:layers,), Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Nothing}}, NamedTuple{(:weight, :bias, :σ), Tuple{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Nothing}}, NamedTuple{(:weight, :bias, :σ), Tuple{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Nothing}}}}}}: try/catch is not supported.
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:33
[2] instrument(ir::IRTools.Inner.IR)
@ Zygote /scratch/user/vish0908/.julia/packages/Zygote/TaBlo/src/compiler/reverse.jl:121
[3] #Primal#20
@ /scratch/user/vish0908/.julia/packages/Zygote/TaBlo/src/compiler/reverse.jl:202 [inlined]
[4] Zygote.Adjoint(ir::IRTools.Inner.IR; varargs::Nothing, normalise::Bool)
@ Zygote /scratch/user/vish0908/.julia/packages/Zygote/TaBlo/src/compiler/reverse.jl:315
[5] _generate_pullback_via_decomposition(T::Type)
@ Zygote /scratch/user/vish0908/.julia/packages/Zygote/TaBlo/src/compiler/emit.jl:101
[6] #s3098#1243
@ /scratch/user/vish0908/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:28 [inlined]
[7] var"#s3098#1243"(::Any, ctx::Any, f::Any, args::Any)
@ Zygote ./none:0
[8] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any, N} where N)
@ Core ./boot.jl:571
[9] _pullback
@ /scratch/user/vish0908/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59 [inlined]
[10] _pullback(ctx::Zygote.Context, f::Flux.var"#210#back#57"{Flux.var"#_restructure_pullback#56"{Int64}}, args::NamedTuple{(:layers,), Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Nothing}}, NamedTuple{(:weight, :bias, :σ), Tuple{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Nothing}}, NamedTuple{(:weight, :bias, :σ), Tuple{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Nothing}}}}})
@ Zygote /scratch/user/vish0908/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
[11] _pullback
@ /scratch/user/vish0908/.julia/packages/Flux/Zz9RI/src/utils.jl:656 [inlined]
[12] _pullback(ctx::Zygote.Context, f::typeof(∂(λ)), args::NamedTuple{(:layers,), Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Nothing}}, NamedTuple{(:weight, :bias, :σ), Tuple{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Nothing}}, NamedTuple{(:weight, :bias, :σ), Tuple{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Nothing}}}}})
@ Zygote /scratch/user/vish0908/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
[13] _pullback
@ /scratch/user/vish0908/.julia/packages/NeuralPDE/HVA0c/src/pinns_pde_solve.jl:807 [inlined]
[14] _pullback(ctx::Zygote.Context, f::typeof(∂(λ)), args::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
@ Zygote /scratch/user/vish0908/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
[15] _pullback
@ /scratch/user/vish0908/Documents/neuripsCode/missile/test.jl:92 [inlined]
[16] _pullback(ctx::Zygote.Context, f::typeof(∂(ρFn_g)), args::Float32)
@ Zygote /scratch/user/vish0908/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
[17] _pullback
@ /scratch/user/vish0908/.julia/packages/Zygote/TaBlo/src/compiler/interface.jl:41 [inlined]
[18] _pullback(ctx::Zygote.Context, f::Zygote.var"#46#47"{typeof(∂(ρFn_g))}, args::Float32)
@ Zygote /scratch/user/vish0908/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
[19] _pullback
@ /scratch/user/vish0908/.julia/packages/Zygote/TaBlo/src/compiler/interface.jl:76 [inlined]
[20] _pullback(::Zygote.Context, ::typeof(gradient), ::typeof(ρFn_g), ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
@ Zygote /scratch/user/vish0908/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
[21] _pullback
@ /scratch/user/vish0908/.julia/packages/Zygote/TaBlo/src/lib/grad.jl:87 [inlined]
[22] _pullback(ctx::Zygote.Context, f::Zygote.var"#97#98"{typeof(ρFn_g)}, args::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
@ Zygote /scratch/user/vish0908/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
[23] _apply
@ ./boot.jl:804 [inlined]
[24] adjoint
@ /scratch/user/vish0908/.julia/packages/Zygote/TaBlo/src/lib/lib.jl:200 [inlined]
[25] _pullback
@ /scratch/user/vish0908/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:57 [inlined]
[26] _pullback
@ ./operators.jl:938 [inlined]
[27] _pullback(ctx::Zygote.Context, f::ComposedFunction{typeof(Zygote._jvec), Zygote.var"#97#98"{typeof(ρFn_g)}}, args::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
@ Zygote /scratch/user/vish0908/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
[28] _pullback(f::Function, args::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
@ Zygote /scratch/user/vish0908/.julia/packages/Zygote/TaBlo/src/compiler/interface.jl:34
[29] pullback(f::Function, args::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
@ Zygote /scratch/user/vish0908/.julia/packages/Zygote/TaBlo/src/compiler/interface.jl:40
[30] withjacobian(f::Function, args::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
@ Zygote /scratch/user/vish0908/.julia/packages/Zygote/TaBlo/src/lib/grad.jl:153
[31] jacobian
@ /scratch/user/vish0908/.julia/packages/Zygote/TaBlo/src/lib/grad.jl:140 [inlined]
[32] hessian_reverse(f::Function, x::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
@ Zygote /scratch/user/vish0908/.julia/packages/Zygote/TaBlo/src/lib/grad.jl:87
Please let me know if I'm doing something wrong, or if there's an alternative solution to obtaining the hessian on the GPU.
Hmm — could we try removing the warning at https://github.com/FluxML/Flux.jl/blob/c91867fd8cc49b19d90c88e0346db674cb757abd/src/utils.jl#L623?
I would also try with an older Zygote version..
Additionally, the MWE might be reducible further, removing much of NeuralPDE and DiffEqFlux so that it depends mostly on the `restructure` pullback.
We should probably toss an `@ignore` on that line, but to Dhairya's point this isn't an MWE because DiffEq(Flux) and NeuralPDE are doing so much behind the scenes.
Would this be a MWE?
# Reduced MWE: replace the NeuralPDE/DiffEqFlux machinery with a plain
# Flux.destructure, reproducing the same failures with far fewer moving parts.
using CUDA, Flux, Zygote
CUDA.allowscalar(false)
dim = 2; nn = 10;
chain = Chain(Dense(dim, nn, tanh), Dense(nn, nn, tanh), Dense(nn, 1));
# destructure: flat parameter vector θ plus a re-builder `re`
θ, re = Flux.destructure(chain);
initθ_g = cu(θ); # parameter vector in gpu
phi_g = (x,θ) -> re(θ)(cu(x)) # nn function with gpu params
initθ_c = Array(θ); # parameter vector in cpu
phi_c = (x,θ) -> re(θ)(Array(x)) # function with cpu params
ρFn_c(y) = sum(phi_c(y, initθ_c)); # phi_c returns the NN output (cpu)
ρFn_g(y) = sum(phi_g(y, initθ_g)); # phi_g returns the NN output (GPU)
xg = rand(2) |> gpu; # input on the gpu
xc = Array(xg); # same input on the cpu
# @show ρFn_c(xc) # works
# @show ρFn_g(xg) # works
## gradient wrt input — first derivatives work on both devices
dρ_g(x) = Zygote.gradient(ρFn_g, x)[1];
# @show dρ_g(xg) # works
dρ_c(x) = Zygote.gradient(ρFn_c, x)[1];
# @show dρ_c(xc) # works
## hessian wrt input — only the CPU path works
d2ρ_c(x) = Zygote.hessian(ρFn_c, x);
# d2ρ_c(xc) # works
d2ρ_g(x) = Zygote.hessian(ρFn_g, x);
# d2ρ_g(xg) # doesn't work (forward over reverse)
# Zygote.hessian_reverse(ρFn_g, xg) # doesn't work (reverse over reverse)
The errors are the same as before.
The error for `Zygote.hessian` looks like it's from broadcasting. Zygote's broadcasting for CuArrays has always used dual numbers... but adding a second layer of them via `hessian` gives the error you had above:
julia> hessian(x -> sum(tanh.(x)), [1,2,3.4]) # CPU, ok
3×3 Matrix{Float64}:
-0.6397 0.0 0.0
0.0 -0.136219 0.0
0.0 0.0 -0.0088706
julia> hessian(x -> sum(tanh.(x)), cu([1,2,3.4])) # GPU, same error as above
ERROR: MethodError: no method matching cudnnDataType(::Type{ForwardDiff.Dual{Nothing, Float32, 3}})
julia> gradient(x -> sum(tanh.(x)), [1,2,3.4])
([0.41997434161402614, 0.07065082485316443, 0.004445193185743657],)
julia> gradient(x -> sum(tanh.(x)), cu([1,2,3.4])) # 1st derivative works fine
(Float32[0.4199743, 0.070650816, 0.004445251],)
(@v1.7) pkg> st
Status `~/.julia/environments/v1.7/Project.toml`
[052768ef] CUDA v3.4.2
[587475ba] Flux v0.12.6
[e88e6eb3] Zygote v0.6.21
Second derivatives via Zygote over Zygote should cause it to use a (slower) generic method which I believe is usually second-differentiable, but won't work on the GPU. This is also where the two `@info` statements in `destructure` might be seen by Zygote and cause the error seen. But after commenting them out, there are other problems. Zygote over Zygote is in general seldom a great idea.
julia> Zygote.hessian_reverse(x -> sum(tanh.(x)), [1,2,3.4]) # reverse over reverse, CPU
3×3 Matrix{Float64}:
-0.6397 -0.0 -0.0
-0.0 -0.136219 -0.0
-0.0 -0.0 -0.0088706
julia> Zygote.hessian_reverse(x -> sum(tanh.(x)), cu([1,2,3.4]))
ERROR: Mutating arrays is not supported -- called copyto!(::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, _...)
julia> Zygote.hessian_reverse(ρFn_c, xc) # above MWE, reverse over reverse, CPU
ERROR: Can't differentiate foreigncall expression
julia> Zygote.hessian_reverse(ρFn_g, xg) # similar error on GPU
ERROR: Can't differentiate foreigncall expression
I don't understand from the error how broadcasting is going wrong. Maybe it's possible to mess around with overloading `materialize(bc::Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1}, Nothing, typeof(tanh), Tuple{CuArray{ForwardDiff.Dual{Nothing, Float32, 2}, 1, CUDA.Mem.DeviceBuffer}}})` (or `broadcasted(::typeof(tanh), ::CuArray{<:Dual})`) to somehow bypass the problem? Maybe it's something simple like a mess-up of tagging; `Dual{Nothing, Float32, 2}` is suspicious, as for second derivatives you normally need to give each direction a "name" (tag).
Maybe it's possible to combine two different AD packages. I know some people use ReverseDiff over Zygote, or the reverse.
julia> Zygote.hessian_reverse(ρFn_c, xc)
ERROR: Can't differentiate foreigncall expression
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:33
[2] Pullback
@ ./iddict.jl:102 [inlined]
[3] (::typeof(∂(get)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface2.jl:0
[4] Pullback
@ ~/.julia/packages/Zygote/ajuwN/src/lib/lib.jl:68 [inlined]
[5] (::typeof(∂(accum_global)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface2.jl:0
[6] Pullback
@ ~/.julia/packages/Zygote/ajuwN/src/lib/lib.jl:79 [inlined]
[7] (::typeof(∂(λ)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface2.jl:0
[8] Pullback
@ ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59 [inlined]
[9] (::typeof(∂(λ)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface2.jl:0
[10] gradtuple1
@ ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:22 [inlined]
[11] #1640#back
@ ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59 [inlined]
[12] (::typeof(∂(λ)))(Δ::Tuple{Nothing, Vector{Float32}})
@ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface2.jl:0
[13] Pullback
@ ~/.julia/packages/Zygote/ajuwN/src/compiler/interface.jl:41 [inlined]
[14] (::typeof(∂(λ)))(Δ::Tuple{Vector{Float32}})
@ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface2.jl:0
[15] Pullback
@ ~/.julia/packages/Zygote/ajuwN/src/compiler/interface.jl:76 [inlined]
[16] (::typeof(∂(gradient)))(Δ::Tuple{Vector{Float32}})
@ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface2.jl:0
[17] Pullback
@ ~/.julia/packages/Zygote/ajuwN/src/lib/grad.jl:87 [inlined]
[18] (::typeof(∂(#107)))(Δ::Vector{Float32})
@ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface2.jl:0
[19] (::Zygote.var"#219#220"{Tuple{Tuple{Nothing}}, typeof(∂(#107))})(Δ::Vector{Float32})
@ Zygote ~/.julia/packages/Zygote/ajuwN/src/lib/lib.jl:203
[20] (::Zygote.var"#1760#back#221"{Zygote.var"#219#220"{Tuple{Tuple{Nothing}}, typeof(∂(#107))}})(Δ::Vector{Float32})
@ Zygote ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
[21] Pullback
@ ./operators.jl:1085 [inlined]
[22] (::typeof(∂(#_#83)))(Δ::Vector{Float32})
@ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface2.jl:0
[23] (::Zygote.var"#219#220"{Tuple{Tuple{Nothing, Nothing}, Tuple{Nothing}}, typeof(∂(#_#83))})(Δ::Vector{Float32})
@ Zygote ~/.julia/packages/Zygote/ajuwN/src/lib/lib.jl:203
[24] #1760#back
@ ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59 [inlined]
[25] Pullback
@ ./operators.jl:1085 [inlined]
[26] (::typeof(∂(ComposedFunction{typeof(Zygote._jvec), Zygote.var"#107#108"{typeof(ρFn_c)}}(Zygote._jvec, Zygote.var"#107#108"{typeof(ρFn_c)}(ρFn_c)))))(Δ::Vector{Float32})
@ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface2.jl:0
[27] (::Zygote.var"#52#53"{typeof(∂(ComposedFunction{typeof(Zygote._jvec), Zygote.var"#107#108"{typeof(ρFn_c)}}(Zygote._jvec, Zygote.var"#107#108"{typeof(ρFn_c)}(ρFn_c))))})(Δ::Vector{Float32})
@ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface.jl:41
[28] withjacobian(f::Function, args::Vector{Float32})
@ Zygote ~/.julia/packages/Zygote/ajuwN/src/lib/grad.jl:162
[29] jacobian
@ ~/.julia/packages/Zygote/ajuwN/src/lib/grad.jl:140 [inlined]
[30] hessian_reverse(f::Function, x::Vector{Float32})
@ Zygote ~/.julia/packages/Zygote/ajuwN/src/lib/grad.jl:87
[31] top-level scope
@ REPL[42]:1
[32] top-level scope
@ ~/.julia/packages/CUDA/9T5Sq/src/initialization.jl:66
julia> Zygote.hessian_reverse(ρFn_g, xg)
ERROR: Can't differentiate foreigncall expression
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:33
[2] Pullback
@ ./iddict.jl:102 [inlined]
[3] (::typeof(∂(get)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface2.jl:0
[4] Pullback
@ ~/.julia/packages/Zygote/ajuwN/src/lib/lib.jl:68 [inlined]
[5] (::typeof(∂(accum_global)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface2.jl:0
[6] Pullback
@ ~/.julia/packages/Zygote/ajuwN/src/lib/lib.jl:79 [inlined]
[7] (::typeof(∂(λ)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface2.jl:0
[8] Pullback
@ ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59 [inlined]
[9] (::typeof(∂(λ)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface2.jl:0
[10] gradtuple1
@ ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:22 [inlined]
[11] #1640#back
@ ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59 [inlined]
[12] (::typeof(∂(λ)))(Δ::Tuple{Nothing, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}})
@ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface2.jl:0
[13] Pullback
@ ~/.julia/packages/Zygote/ajuwN/src/compiler/interface.jl:41 [inlined]
[14] (::typeof(∂(λ)))(Δ::Tuple{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}})
@ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface2.jl:0
[15] Pullback
@ ~/.julia/packages/Zygote/ajuwN/src/compiler/interface.jl:76 [inlined]
[16] (::typeof(∂(gradient)))(Δ::Tuple{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}})
@ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface2.jl:0
[17] Pullback
@ ~/.julia/packages/Zygote/ajuwN/src/lib/grad.jl:87 [inlined]
[18] (::typeof(∂(#107)))(Δ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
@ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface2.jl:0
[19] (::Zygote.var"#219#220"{Tuple{Tuple{Nothing}}, typeof(∂(#107))})(Δ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
@ Zygote ~/.julia/packages/Zygote/ajuwN/src/lib/lib.jl:203
[20] (::Zygote.var"#1760#back#221"{Zygote.var"#219#220"{Tuple{Tuple{Nothing}}, typeof(∂(#107))}})(Δ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
@ Zygote ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
[21] Pullback
@ ./operators.jl:1085 [inlined]
[22] (::typeof(∂(#_#83)))(Δ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
@ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface2.jl:0
[23] (::Zygote.var"#219#220"{Tuple{Tuple{Nothing, Nothing}, Tuple{Nothing}}, typeof(∂(#_#83))})(Δ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
@ Zygote ~/.julia/packages/Zygote/ajuwN/src/lib/lib.jl:203
[24] #1760#back
@ ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59 [inlined]
[25] Pullback
@ ./operators.jl:1085 [inlined]
[26] (::typeof(∂(ComposedFunction{typeof(Zygote._jvec), Zygote.var"#107#108"{typeof(ρFn_g)}}(Zygote._jvec, Zygote.var"#107#108"{typeof(ρFn_g)}(ρFn_g)))))(Δ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
@ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface2.jl:0
[27] (::Zygote.var"#52#53"{typeof(∂(ComposedFunction{typeof(Zygote._jvec), Zygote.var"#107#108"{typeof(ρFn_g)}}(Zygote._jvec, Zygote.var"#107#108"{typeof(ρFn_g)}(ρFn_g))))})(Δ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
@ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface.jl:41
[28] withjacobian(f::Function, args::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
@ Zygote ~/.julia/packages/Zygote/ajuwN/src/lib/grad.jl:162
[29] jacobian
@ ~/.julia/packages/Zygote/ajuwN/src/lib/grad.jl:140 [inlined]
[30] hessian_reverse(f::Function, x::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
@ Zygote ~/.julia/packages/Zygote/ajuwN/src/lib/grad.jl:87
[31] top-level scope
@ REPL[44]:1
[32] top-level scope
@ ~/.julia/packages/CUDA/9T5Sq/src/initialization.jl:66