Zygote.jl

Directly call `ForwardDiff` to compute the hessian

Open · YichengDWu opened this pull request 2 years ago · 16 comments

Closes #1264

Zygote.forward_jacobian is not well maintained, and the motivation for using it (instead of calling ForwardDiff.jacobian directly) is not clear.

YichengDWu · Jul 23 '22 00:07

Reverse mode over the Hessian

julia> function f1(x, ps)  # [edit: renamed not to clash]
           hess = Zygote.hessian(x -> sum(x .^ 3), x)
           return hess * x .+ ps.bias
       end
f1 (generic function with 1 method)

julia> x = rand(3);

julia> ps = (;bias = rand(3));

julia> Zygote.gradient(p -> sum(f1(x,p)), ps)
((bias = Fill(1.0, 3),),)
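Why both modes return all ones for `bias` (a worked check added here, not part of the original comment): for fixed `x`, the inner Hessian is diagonal and `bias` enters the output additively, so the nested Hessian contributes nothing to the gradient with respect to `bias`.

```latex
\nabla_x^2 \Big(\sum_i x_i^3\Big) = \operatorname{diag}(6x),
\qquad
f_1(x, b) = \operatorname{diag}(6x)\,x + b = 6\,x \odot x + b,
\qquad
\frac{\partial}{\partial b_j} \sum_i f_1(x, b)_i = 1 .
```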

Forward mode over the Hessian

julia> using ComponentArrays, ForwardDiff

julia> ForwardDiff.gradient(p -> sum(f1(x,p)), ComponentArray(ps))
ComponentVector{Float64}(bias = [1.0, 1.0, 1.0])

YichengDWu · Jul 23 '22 00:07

Looks like it wasn't a complicated fix after all! Could you add one of the examples from #1264 or elsewhere as a regression test?

ToucheSir · Jul 23 '22 04:07

using ForwardDiff, Zygote, BenchmarkTools
f(x,W) = sum((W.^3)*x)

x = randn(30);
W = randn(128,30);

@benchmark Zygote.hessian(W->f($x,W), $W)
BenchmarkTools.Trial: 13 samples with 1 evaluation.
 Range (min … max):  210.955 ms … 555.882 ms  ┊ GC (min … max):  6.64% … 41.61%
 Time  (median):     312.789 ms               ┊ GC (median):    13.30%
 Time  (mean ± σ):   386.835 ms ± 135.648 ms  ┊ GC (mean ± σ):  30.41% ± 16.92%

  ▁       ▁█  ▁▁    ▁                             ▁    ▁█     █  
  █▁▁▁▁▁▁▁██▁▁██▁▁▁▁█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▁▁▁▁██▁▁▁▁▁█ ▁
  211 ms           Histogram: frequency by time          556 ms <

 Memory estimate: 851.09 MiB, allocs estimate: 8324.

hessian_dual(f, x::AbstractArray) = ForwardDiff.jacobian(x -> gradient(f, x)[1], x)

@benchmark hessian_dual(W->f($x,W),$W)
BenchmarkTools.Trial: 34 samples with 1 evaluation.
 Range (min … max):  107.892 ms … 426.180 ms  ┊ GC (min … max): 11.03% … 51.18%
 Time  (median):     133.778 ms               ┊ GC (median):    12.17%
 Time  (mean ± σ):   149.663 ms ±  62.519 ms  ┊ GC (mean ± σ):  19.00% ± 10.50%

     ▆██▄                                                        
  ▄▄▆████▄▁▄▁▁▁▁▁▁▁▁▁▁▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▄ ▁
  108 ms           Histogram: frequency by time          426 ms <

 Memory estimate: 607.80 MiB, allocs estimate: 9604.
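For a sense of scale (a worked derivation added here, not from the original comment): the Hessian in this benchmark is taken with respect to every entry of `W`, and for this particular `f` it is diagonal.

```latex
f(x, W) = \sum_{i=1}^{128} \sum_{j=1}^{30} W_{ij}^3 \, x_j,
\qquad
\frac{\partial f}{\partial W_{ij}} = 3 W_{ij}^2 x_j,
\qquad
\frac{\partial^2 f}{\partial W_{ij}\,\partial W_{kl}} = 6\, W_{ij}\, x_j \,\delta_{ik}\,\delta_{jl}.
```

So only 3840 of the 3840² entries are nonzero, yet both approaches return a dense 3840 × 3840 `Matrix{Float64}`, which alone occupies 3840² × 8 bytes = 112.5 MiB; the memory estimates above are several times that.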

YichengDWu · Jul 23 '22 04:07

Great acceleration!

YichengDWu · Jul 23 '22 21:07

Sorry, I meant to add a test which fails on the current implementation but works on this PR. IIRC you had a couple, @MilkshakeForReal.
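A regression test along the lines requested above might look like the following sketch. It is not from the PR: it assumes Zygote and ForwardDiff are installed, the helper names (`cubic`, `f305`, `sum_jac`, and so on) are hypothetical, and it only cross-checks Zygote's nested paths against pure ForwardDiff on the CPU, so it is not necessarily a case that fails on the pre-PR implementation. The second check relies on Zygote differentiating through `ForwardDiff.jacobian`, as in the #305 example discussed in this thread.

```julia
# Sketch of a CPU regression test for nested AD (assumes Zygote and ForwardDiff are installed).
using Test
using Zygote, ForwardDiff

# Hypothetical helpers: Zygote.hessian (forward-over-reverse) vs. ForwardDiff.hessian (forward-over-forward).
cubic(x) = sum(x .^ 3)
zygote_hess(x) = Zygote.hessian(cubic, x)
forwarddiff_hess(x) = ForwardDiff.hessian(cubic, x)

# Hypothetical helpers in the spirit of #305: reverse mode (Zygote) over a Jacobian
# computed by ForwardDiff, checked against ForwardDiff applied twice.
f305(x) = [cos(x[1]) * sin(x[2]), sin(x[1]) * cos(x[2])]
sum_jac(x) = sum(ForwardDiff.jacobian(f305, x))
reverse_over_forward(x) = Zygote.gradient(sum_jac, x)[1]
forward_over_forward(x) = ForwardDiff.gradient(sum_jac, x)

@testset "nested AD regression sketch" begin
    x = [0.5, -1.0, 2.0]
    @test zygote_hess(x) ≈ forwarddiff_hess(x)

    y = [π / 4, π / 3]
    @test reverse_over_forward(y) ≈ forward_over_forward(y)
end
```

Comparing against ForwardDiff rather than hard-coded numbers keeps the test independent of which backend `Zygote.hessian` uses internally.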

ToucheSir · Jul 24 '22 02:07

It's also possible that we'd lose some code paths that worked with Zygote via forward_jacobian, since that function could be switched over to using Zygote without breaking the API.

DhairyaLGandhi · Jul 24 '22 17:07

Related to #1070.

Edit: this PR shouldn't close #1070, since `hessian_inverse` is also an issue there.

julia> using CUDA

julia> CUDA.allowscalar(false)

julia> hessian(x -> sum(tanh.(x)), cu([1,2,3.4]))
ERROR: Scalar indexing is disallowed.
Invocation of getindex resulted in scalar indexing of a GPU array.
This is typically caused by calling an iterating implementation of a method.
Such implementations *do not* execute on the GPU, but very slowly on the CPU,
and therefore are only permitted from the REPL for prototyping purposes.
If you did intend to index this array, annotate the caller with @allowscalar.
Stacktrace:
  [1] error(s::String)
    @ Base .\error.jl:33
  [2] assertscalar(op::String)
    @ GPUArraysCore C:\Users\Luffy\.julia\packages\GPUArraysCore\rSIl2\src\GPUArraysCore.jl:78
  [3] getindex
    @ C:\Users\Luffy\.julia\packages\GPUArrays\gok9K\src\host\indexing.jl:9 [inlined]
  [4] extract(xs::CuArray{ForwardDiff.Dual{Nothing, Float32, 3}, 1, CUDA.Mem.DeviceBuffer})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\lib\forward.jl:23
  [5] forward_jacobian(f::Zygote.var"#100#101"{var"#27#28"}, x::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, #unused#::Val{3})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\lib\forward.jl:29
  [6] forward_jacobian(f::Function, x::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}; chunk_threshold::Int64)
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\lib\forward.jl:44
  [7] forward_jacobian
    @ C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\lib\forward.jl:43 [inlined]
  [8] hessian_dual
    @ C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\lib\grad.jl:76 [inlined]
  [9] hessian(f::Function, x::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\lib\grad.jl:74
 [10] top-level scope
    @ REPL[135]:1
 [11] top-level scope
    @ C:\Users\Luffy\.julia\packages\CUDA\DfvRa\src\initialization.jl:52

julia> hessian_dual(x -> sum(tanh.(x)), cu([1,2,3.4]))
3×3 CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}:
 -0.6397   0.0        0.0
  0.0     -0.136219   0.0
  0.0      0.0       -0.00887072
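The diagonal printed above can be checked by hand (a worked check added here, not from the original comment): each term of the sum depends on a single coordinate, so the Hessian is diagonal with entries given by the second derivative of tanh.

```latex
\frac{\partial^2}{\partial x_i\,\partial x_j} \sum_k \tanh(x_k)
  = -2\tanh(x_i)\,\operatorname{sech}^2(x_i)\,\delta_{ij},
\qquad
-2\tanh(1)\operatorname{sech}^2(1) \approx -0.6397,\quad
-2\tanh(2)\operatorname{sech}^2(2) \approx -0.1362,\quad
-2\tanh(3.4)\operatorname{sech}^2(3.4) \approx -0.00887 .
```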

YichengDWu · Jul 26 '22 01:07

I've been digging through the issues and haven't found a single case where forward_jacobian works but ForwardDiff.jacobian doesn't. It's always the other way around.

YichengDWu · Jul 26 '22 01:07

Issue #305 is closed, but I'm still getting an error:

using Zygote
f(x) = [cos(x[1]) * sin(x[2]), sin(x[1]) * cos(x[2])]
jac(x) = last(Zygote.forward_jacobian(f, x))
g(x) = sum(jac(x))
x = [π/4, π/3]
Zygote.gradient(g, x)

julia> Zygote.gradient(g, x)
ERROR: Mutating arrays is not supported -- called setindex!(Matrix{Float64}, ...)
This error occurs when you ask Zygote to differentiate operations that change
the elements of arrays in place (e.g. setting values with x .= ...)

Possible fixes:
- avoid mutating operations (preferred)
- or read the documentation and solutions for this error
  https://fluxml.ai/Zygote.jl/dev/limitations.html#Array-mutation-1

Stacktrace:
  [1] error(s::String)
    @ Base .\error.jl:33
  [2] _throw_mutation_error(f::Function, args::Matrix{Float64})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\lib\array.jl:70
  [3] (::Zygote.var"#444#445"{Matrix{Float64}})(#unused#::Nothing)
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\lib\array.jl:82
  [4] (::Zygote.var"#2496#back#446"{Zygote.var"#444#445"{Matrix{Float64}}})(Δ::Nothing)
    @ Zygote C:\Users\Luffy\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:67
  [5] Pullback
    @ C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\lib\forward.jl:31 [inlined]
  [6] (::typeof(∂(forward_jacobian)))(Δ::Tuple{Nothing, FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
  [7] Pullback
    @ C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\lib\forward.jl:44 [inlined]
  [8] Pullback
    @ C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\lib\forward.jl:43 [inlined]
  [9] (::typeof(∂(forward_jacobian)))(Δ::Tuple{Nothing, FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
 [10] Pullback
    @ .\REPL[190]:1 [inlined]
 [11] (::typeof(∂(jac)))(Δ::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
 [12] Pullback
    @ .\REPL[191]:1 [inlined]
 [13] (::Zygote.var"#60#61"{typeof(∂(g))})(Δ::Float64)
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface.jl:41
 [14] gradient(f::Function, args::Vector{Float64})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface.jl:76
 [15] top-level scope
    @ REPL[193]:1
 [16] top-level scope
    @ C:\Users\Luffy\.julia\packages\CUDA\DfvRa\src\initialization.jl:52

Now try ForwardDiff.jacobian:

julia> jac(x) = last(ForwardDiff.jacobian(f, x))
jac (generic function with 1 method)

julia> Zygote.gradient(g, x)
([-0.6123724356957946, -0.3535533905932738],)
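A note on this comparison (a worked check added here, not from the thread): `Zygote.forward_jacobian` returns a `(y, J)` tuple, so `last` picks out the Jacobian matrix `J`, whereas `ForwardDiff.jacobian` returns `J` itself, so `last` picks out its final entry J₂₂ = ∂f₂/∂x₂. The printed gradient matches the gradient of that single entry; a like-for-like comparison would use `jac(x) = ForwardDiff.jacobian(f, x)`, whose sum is 2cos(x₁ + x₂).

```latex
J(x) = \begin{pmatrix}
  -\sin x_1 \sin x_2 & \cos x_1 \cos x_2 \\
  \cos x_1 \cos x_2 & -\sin x_1 \sin x_2
\end{pmatrix},
\qquad
\nabla J_{22} = -\begin{pmatrix} \cos x_1 \sin x_2 \\ \sin x_1 \cos x_2 \end{pmatrix}
\approx \begin{pmatrix} -0.6124 \\ -0.3536 \end{pmatrix}
\ \text{at}\ x = (\pi/4, \pi/3),

\sum_{i,j} J_{ij} = 2\cos(x_1 + x_2),
\qquad
\nabla \sum_{i,j} J_{ij} = -2\sin(x_1 + x_2)\begin{pmatrix} 1 \\ 1 \end{pmatrix}
\approx \begin{pmatrix} -1.9319 \\ -1.9319 \end{pmatrix}
\ \text{at}\ x = (\pi/4, \pi/3).
```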

YichengDWu · Jul 26 '22 01:07

Is this still active?

vpuri3 · Oct 25 '23 18:10

@YichengDWu did this solve the GPU issue?

vpuri3 · Oct 25 '23 20:10

@vpuri3 I'm no longer active in the Julia community. However, the "fix" in this PR is simple, so you can easily test it out yourself.

YichengDWu · Oct 26 '23 09:10

Thanks for letting me know, @YichengDWu. The change works with Lux.jl, but I'm still getting scalar indexing errors on the GPU.

vpuri3 · Oct 26 '23 13:10

With the change in this PR, this code works:

using Random
using Lux, CUDA, LuxCUDA, ComponentArrays
using Zygote, ForwardDiff

CUDA.allowscalar(false)

#==========================#
function testhessian(
    NN::Lux.AbstractExplicitLayer,
    data::Tuple;
    device = cpu_device(),
)
    p, st = Lux.setup(Random.default_rng(), NN)

    st = Lux.testmode(st)
    p = ComponentArray(p)

    xdata, ydata = data |> device
    p, st = (p, st)     |> device

    function loss(optx)
        ypred, _ = NN(xdata, optx, st)

        sum(abs2, ydata - ypred)
    end

    g(p) = Zygote.gradient(loss, p)[1]
    H(p) = ForwardDiff.jacobian(g, p)

    Zygote.hessian(loss, p)
end
#==========================#
NN = Chain(Dense(1, 3), Dense(3, 1))

data = ntuple(_ -> rand(1, 10), 2)
device = Lux.gpu_device()

H = testhessian(NN, data; device)
julia> include("hess.jl")
10×10 CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}:
  0.236781  -0.075257    -1.20583    0.31846   -0.101217    -1.62179   -0.713834    0.503548  -1.14138     1.98508
 -0.075257   0.0239192    0.383253  -0.101217   0.0321702    0.515458   0.0296168  -0.780695   0.362769   -0.630924
 -1.20583    0.383253     6.1408    -1.62179    0.515458     8.2591     0.474545   -2.56436    5.19194   -10.1092
  0.318461  -0.101217    -1.62179    0.514738  -0.163601    -2.62135   -2.09317     0.677249  -1.53511     3.20854
 -0.101217   0.0321702    0.515458  -0.163601   0.0519977    0.833151   0.0398333  -2.18309    0.487909   -1.01978
 -1.62179    0.515458     8.2591    -2.62135    0.833151    13.3494     0.638242   -3.44895    5.84984   -16.3398
 -0.713834   0.0296168    0.474545  -2.09317    0.0398333    0.638242   0.0366717  -0.198167   0.449183   -0.781213
  0.503548  -0.780695    -2.56436    0.677249  -2.18309     -3.44895   -0.198167    1.07086   -2.4273      4.22154
 -1.14138    0.362769     5.19194   -1.53511    0.487909     5.84984    0.449183   -2.4273     5.50193    -9.56889
  1.98508   -0.630924   -10.1092     3.20854   -1.01978    -16.3398    -0.781213    4.22154   -9.56889    20.0
(hess) pkg> st
Status `~/.julia/dev/GeometryLearning.jl/hess/Project.toml`
  [052768ef] CUDA v5.0.0
  [b0b7db55] ComponentArrays v0.15.4
  [f6369f11] ForwardDiff v0.10.36
  [b2108857] Lux v0.5.8
  [d0bbae9a] LuxCUDA v0.3.1
  [e88e6eb3] Zygote v0.6.67 `~/.julia/dev/Zygote`

vpuri3 · Oct 26 '23 16:10

But Lux.Embedding fails. This looks like an NNlib issue with ForwardDiff.Dual:

using Random
using Lux, CUDA, LuxCUDA, ComponentArrays
using Zygote, ForwardDiff

CUDA.allowscalar(false)

#==========================#
function testhessian(
    NN::Lux.AbstractExplicitLayer,
    data::Tuple;
    device = cpu_device(),
)
    p, st = Lux.setup(Random.default_rng(), NN)

    st = Lux.testmode(st)
    p = ComponentArray(p)

    xdata, ydata = data |> device
    p, st = (p, st)     |> device

    function loss(optx)
        ypred, _ = NN(xdata, optx, st)

        sum(abs2, ydata - ypred)
    end

    g(p) = Zygote.gradient(loss, p)[1]
    H(p) = ForwardDiff.jacobian(g, p)

    Zygote.hessian(loss, p)
end
#==========================#
device = Lux.gpu_device()

E, K = 1, 10
NN = Chain(Embedding(E => 3), Dense(3, 1))
data = (ones(Int32, K), rand(1, K))

H = testhessian(NN, data; device)

julia> include("hess.jl")
ERROR: LoadError: InvalidIRError: compiling MethodInstance for NNlibCUDAExt.scatter_kernel!(::typeof(+), ::CuDeviceMatrix{…}, ::CuDeviceMatrix{…}, ::CuDeviceVector{…}, ::Int64, ::Int64, ::Tuple{…}) resulted in invalid LLVM IR
Reason: unsupported call to an unknown function (call to julia.new_gc_frame)
Reason: unsupported call to an unknown function (call to julia.push_gc_frame)
Reason: unsupported call to an unknown function (call to julia.pop_gc_frame)
Reason: unsupported call to an unknown function (call to julia.get_gc_frame_slot)
Reason: unsupported dynamic function invocation (call to atomic_cas!)
Stacktrace:
 [1] atomic_op!
   @ ~/.julia/packages/CUDA/nbRJk/src/device/intrinsics/atomics.jl:228
 [2] atomic_arrayset
   @ ~/.julia/packages/CUDA/nbRJk/src/device/intrinsics/atomics.jl:468
 [3] atomic_arrayset
   @ ~/.julia/packages/CUDA/nbRJk/src/device/intrinsics/atomics.jl:440
 [4] macro expansion
   @ ~/.julia/packages/CUDA/nbRJk/src/device/intrinsics/atomics.jl:435
 [5] scatter_kernel!
   @ ~/.julia/packages/NNlib/5iRSB/ext/NNlibCUDAExt/scatter.jl:28
Hint: catch this exception as `err` and call `code_typed(err; interactive = true)` to introspect the erronous code with Cthulhu.jl
Stacktrace:
  [1] check_ir(job::GPUCompiler.CompilerJob{GPUCompiler.PTXCompilerTarget, CUDA.CUDACompilerParams}, args::LLVM.Module)
    @ GPUCompiler ~/.julia/packages/GPUCompiler/2mJjc/src/validation.jl:147
  [2] macro expansion
    @ GPUCompiler ~/.julia/packages/GPUCompiler/2mJjc/src/driver.jl:440 [inlined]
  [3] macro expansion
    @ GPUCompiler ~/.julia/packages/TimerOutputs/RsWnF/src/TimerOutput.jl:253 [inlined]
  [4] macro expansion
    @ GPUCompiler ~/.julia/packages/GPUCompiler/2mJjc/src/driver.jl:439 [inlined]
  [5] emit_llvm(job::GPUCompiler.CompilerJob; libraries::Bool, toplevel::Bool, optimize::Bool, cleanup::Bool, only_entry::Bool, validate::Bool)
    @ GPUCompiler ~/.julia/packages/GPUCompiler/2mJjc/src/utils.jl:92
  [6] emit_llvm
    @ GPUCompiler ~/.julia/packages/GPUCompiler/2mJjc/src/utils.jl:86 [inlined]
  [7]
    @ GPUCompiler ~/.julia/packages/GPUCompiler/2mJjc/src/driver.jl:129
  [8] codegen
    @ GPUCompiler ~/.julia/packages/GPUCompiler/2mJjc/src/driver.jl:110 [inlined]
  [9] compile(target::Symbol, job::GPUCompiler.CompilerJob; libraries::Bool, toplevel::Bool, optimize::Bool, cleanup::Bool, strip::Bool, validate::Bool, only_entry::Bool)
    @ GPUCompiler ~/.julia/packages/GPUCompiler/2mJjc/src/driver.jl:106
 [10] compile
    @ GPUCompiler ~/.julia/packages/GPUCompiler/2mJjc/src/driver.jl:98 [inlined]
 [11] #1042
    @ GPUCompiler ~/.julia/packages/CUDA/nbRJk/src/compiler/compilation.jl:166 [inlined]
 [12] JuliaContext(f::CUDA.var"#1042#1045"{GPUCompiler.CompilerJob{GPUCompiler.PTXCompilerTarget, CUDA.CUDACompilerParams}})
    @ GPUCompiler ~/.julia/packages/GPUCompiler/2mJjc/src/driver.jl:47
 [13] compile(job::GPUCompiler.CompilerJob)
    @ CUDA ~/.julia/packages/CUDA/nbRJk/src/compiler/compilation.jl:165
 [14] actual_compilation(cache::Dict{…}, src::Core.MethodInstance, world::UInt64, cfg::GPUCompiler.CompilerConfig{…}, compiler::typeof(CUDA.compile), linker::typeof(CUDA.link))
    @ GPUCompiler ~/.julia/packages/GPUCompiler/2mJjc/src/execution.jl:125
 [15] cached_compilation(cache::Dict{…}, src::Core.MethodInstance, cfg::GPUCompiler.CompilerConfig{…}, compiler::Function, linker::Function)
    @ GPUCompiler ~/.julia/packages/GPUCompiler/2mJjc/src/execution.jl:103
 [16] macro expansion
    @ CUDA ~/.julia/packages/CUDA/nbRJk/src/compiler/execution.jl:323 [inlined]
 [17] macro expansion
    @ CUDA ./lock.jl:267 [inlined]
 [18] cufunction(f::typeof(NNlibCUDAExt.scatter_kernel!), tt::Type{Tuple{…}}; kwargs::@Kwargs{})
    @ CUDA ~/.julia/packages/CUDA/nbRJk/src/compiler/execution.jl:318
 [19] cufunction
    @ NNlibCUDAExt ~/.julia/packages/CUDA/nbRJk/src/compiler/execution.jl:315 [inlined]
 [20] macro expansion
    @ NNlibCUDAExt ~/.julia/packages/CUDA/nbRJk/src/compiler/execution.jl:104 [inlined]
 [21] scatter!(op::typeof(+), dst::CuArray{…}, src::CuArray{…}, idx::CuArray{…})
    @ NNlibCUDAExt ~/.julia/packages/NNlib/5iRSB/ext/NNlibCUDAExt/scatter.jl:58
 [22] ∇gather_src
    @ Zygote ~/.julia/packages/NNlib/5iRSB/src/gather.jl:131 [inlined]
 [23] gather!_pullback
    @ Zygote ~/.julia/packages/NNlib/5iRSB/src/gather.jl:136 [inlined]
 [24] ZBack
    @ Zygote ~/.julia/dev/Zygote/src/compiler/chainrules.jl:211 [inlined]
 [25] gather
    @ Zygote ~/.julia/packages/NNlib/5iRSB/src/gather.jl:46 [inlined]
 [26] Embedding
    @ Zygote ~/.julia/packages/Lux/Al3Ab/src/layers/basic.jl:490 [inlined]
 [27] apply
    @ Zygote ~/.julia/packages/LuxCore/aumFq/src/LuxCore.jl:115 [inlined]
 [28] macro expansion
    @ Zygote ~/.julia/packages/Lux/Al3Ab/src/layers/containers.jl:0 [inlined]
 [29] applychain
    @ Zygote ~/.julia/packages/Lux/Al3Ab/src/layers/containers.jl:480 [inlined]
 [30] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{CuArray{…}, Nothing})
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
 [31] Chain
    @ Zygote ~/.julia/packages/Lux/Al3Ab/src/layers/containers.jl:478 [inlined]
 [32] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{CuArray{…}, Nothing})
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
 [33] loss
    @ Zygote ~/.julia/dev/GeometryLearning.jl/hess/hess.jl:23 [inlined]
 [34] (::Zygote.Pullback{Tuple{var"#loss#34"{…}, ComponentVector{…}}, Any})(Δ::ForwardDiff.Dual{ForwardDiff.Tag{Zygote.var"#125#126"{…}, Float32}, Float32, 7})
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0

vpuri3 · Oct 26 '23 17:10

I narrowed the error in the above case to the MWE in https://github.com/FluxML/NNlib.jl/issues/547. That doesn't seem related to this PR.

vpuri3 · Oct 26 '23 17:10