Enzyme.jl

Errors in gradients of mutating NNlib functions when the return activity is marked as `Const`

Open · CarloLucibello opened this issue 1 month ago · 2 comments

See this comment: https://github.com/FluxML/NNlib.jl/pull/665#issuecomment-3616287926, and the PR diff, which excludes the failing cases from CI.

— CarloLucibello, Dec 05 '25 16:12

Can you add an isolated MWE (minimal working example) of the failure?

— wsmoses, Dec 05 '25 16:12

using NNlib: gather, gather!
using Enzyme: Const, Duplicated
using EnzymeTestUtils

# Reproducer inputs: a 5-element source vector and a 3×4 index matrix whose
# entries all lie in 1:5, i.e. every entry is a valid position in `src`.
src = Float64[3, 4, 5, 6, 7]
idx = [1 2 3 4
       4 2 1 3
       3 5 5 3]

# Reuse the out-of-place result as the destination buffer for `gather!`.
dst = gather(src, idx)

# Argument activities shared by both calls below; only the return activity differs.
gather_args = ((dst, Duplicated), (src, Duplicated), (idx, Const))

# Return activity `Duplicated`: this test passes.
EnzymeTestUtils.test_reverse(gather!, Duplicated, gather_args...)

# Return activity `Const`: this test errors with a BoundsError thrown from
# EnzymeTestUtils' finite-difference reference computation (`j′vp`).
EnzymeTestUtils.test_reverse(gather!, Const, gather_args...)
test_reverse: gather! with return activity Const on (::Matrix{Float64}, Duplicated), (::Vector{Float64}, Duplicated), (::Matrix{Int64}, Const): Error During Test at /home/lucibello/.julia/packages/EnzymeTestUtils/yGBt1/src/test_reverse.jl:84
  Got exception outside of a @test
  BoundsError: attempt to access 17×17 transpose(::Matrix{Float64}) with eltype Float64 at index [1:17, 18]
  Stacktrace:
    [1] throw_boundserror(A::LinearAlgebra.Transpose{Float64, Matrix{Float64}}, I::Tuple{Base.Slice{Base.OneTo{Int64}}, Int64})
      @ Base ./abstractarray.jl:737
    [2] checkbounds
      @ ./abstractarray.jl:702 [inlined]
    [3] _getindex
      @ ./multidimensional.jl:888 [inlined]
    [4] getindex(::LinearAlgebra.Transpose{Float64, Matrix{Float64}}, ::Function, ::Int64)
      @ Base ./abstractarray.jl:1290
    [5] j′vp(fdm::FiniteDifferences.AdaptedFiniteDifferenceMethod{5, 1, FiniteDifferences.UnadaptedFiniteDifferenceMethod{7, 5}}, f_vec::ComposedFunction{ComposedFunction{Base.Fix1{typeof(EnzymeTestUtils.multi_tovec), Bool}, Base.Splat{EnzymeTestUtils.var"#fnew#45"{Bool, EnzymeTestUtils.CallWithKWargs{@NamedTuple{}}, Tuple{typeof(gather!), Matrix{Float64}, Vector{Float64}, Matrix{Int64}}, NTuple{4, Bool}}}}, Base.Fix1{typeof(EnzymeTestUtils.from_vec), EnzymeTestUtils.var"#Tuple_from_vec#8"{Tuple{Matrix{Float64}, Vector{Float64}}, Bool}}}, ȳ::Vector{Float64}, x::Vector{Float64})
      @ EnzymeTestUtils ~/.julia/packages/EnzymeTestUtils/yGBt1/src/finite_difference_calls.jl:79
    [6] _fd_reverse(fdm::FiniteDifferences.AdaptedFiniteDifferenceMethod{5, 1, FiniteDifferences.UnadaptedFiniteDifferenceMethod{7, 5}}, f::EnzymeTestUtils.CallWithKWargs{@NamedTuple{}}, ȳ::Matrix{Float64}, activities::Tuple{Const{typeof(gather!)}, Duplicated{Matrix{Float64}}, Duplicated{Vector{Float64}}, Const{Matrix{Int64}}}, active_return::Bool)
      @ EnzymeTestUtils ~/.julia/packages/EnzymeTestUtils/yGBt1/src/finite_difference_calls.jl:124
    [7] macro expansion
      @ ~/.julia/packages/EnzymeTestUtils/yGBt1/src/test_reverse.jl:103 [inlined]
    [8] macro expansion
      @ ~/.julia/juliaup/julia-1.10.10+0.x64.linux.gnu/share/julia/stdlib/v1.10/Test/src/Test.jl:1577 [inlined]
    [9] test_reverse(::typeof(gather!), ::Type, ::Tuple{Matrix{Float64}, UnionAll}, ::Vararg{Any}; rng::Random.TaskLocalRNG, fdm::FiniteDifferences.AdaptedFiniteDifferenceMethod{5, 1, FiniteDifferences.UnadaptedFiniteDifferenceMethod{7, 5}}, fkwargs::@NamedTuple{}, rtol::Float64, atol::Float64, testset_name::Nothing, runtime_activity::Bool, output_tangent::Nothing)
      @ EnzymeTestUtils ~/.julia/packages/EnzymeTestUtils/yGBt1/src/test_reverse.jl:86
   [10] test_reverse(::Function, ::Type, ::Tuple{Matrix{Float64}, UnionAll}, ::Vararg{Any})
      @ EnzymeTestUtils ~/.julia/packages/EnzymeTestUtils/yGBt1/src/test_reverse.jl:68
   [11] top-level scope
      @ REPL[13]:1
   [12] eval
      @ ./boot.jl:385 [inlined]
   [13] eval_user_input(ast::Any, backend::REPL.REPLBackend, mod::Module)
      @ REPL ~/.julia/juliaup/julia-1.10.10+0.x64.linux.gnu/share/julia/stdlib/v1.10/REPL/src/REPL.jl:150
   [14] repl_backend_loop(backend::REPL.REPLBackend, get_module::Function)
      @ REPL ~/.julia/juliaup/julia-1.10.10+0.x64.linux.gnu/share/julia/stdlib/v1.10/REPL/src/REPL.jl:246
   [15] start_repl_backend(backend::REPL.REPLBackend, consumer::Any; get_module::Function)
      @ REPL ~/.julia/juliaup/julia-1.10.10+0.x64.linux.gnu/share/julia/stdlib/v1.10/REPL/src/REPL.jl:231
   [16] run_repl(repl::REPL.AbstractREPL, consumer::Any; backend_on_current_task::Bool, backend::Any)
      @ REPL ~/.julia/juliaup/julia-1.10.10+0.x64.linux.gnu/share/julia/stdlib/v1.10/REPL/src/REPL.jl:389
   [17] run_repl(repl::REPL.AbstractREPL, consumer::Any)
      @ REPL ~/.julia/juliaup/julia-1.10.10+0.x64.linux.gnu/share/julia/stdlib/v1.10/REPL/src/REPL.jl:375
   [18] (::Base.var"#1016#1018"{Bool, Bool, Bool})(REPL::Module)
      @ Base ./client.jl:437
   [19] #invokelatest#2
      @ ./essentials.jl:892 [inlined]
   [20] invokelatest
      @ ./essentials.jl:889 [inlined]
   [21] run_main_repl(interactive::Bool, quiet::Bool, banner::Bool, history_file::Bool, color_set::Bool)
      @ Base ./client.jl:421
   [22] exec_options(opts::Base.JLOptions)
      @ Base ./client.jl:338
   [23] _start()
      @ Base ./client.jl:557
Test Summary:                                                                                                                                  | Error  Total  Time
test_reverse: gather! with return activity Const on (::Matrix{Float64}, Duplicated), (::Vector{Float64}, Duplicated), (::Matrix{Int64}, Const) |     1      1  0.7s
ERROR: Some tests did not pass: 0 passed, 0 failed, 1 errored, 0 broken.

— CarloLucibello, Dec 08 '25 06:12