ChainRules.jl

`Zygote.pullback` throws a `MethodError` when a sparse vector is broadcast to a floating-point power

Status: Open · opened by gdalle 2 years ago · 0 comments

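The following works with a dense vector (`Vector`), and it works with an integer power (`x .^ 2`). It fails only when neither holds, i.e. for a sparse vector raised to a floating-point power (`x .^ 2.0`). According to the stack trace below, SparseArrays' zero-preserving broadcast (`_map_zeropres!` → `_isnotzero`) calls `iszero` on the `(value, pullback)` tuples produced by Zygote's generic broadcast adjoint, and there is no `zero` method for such a tuple: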

julia> using SparseArrays, Zygote

julia> f(x) = x .^ 2;

julia> g(x) = x .^ 2.0;

julia> x = sparse(rand(2))
2-element SparseVector{Float64, Int64} with 2 stored entries:
  [1]  =  0.908478
  [2]  =  0.342989

julia> Zygote.pullback(f, x)
(  [1]  =  0.825332
  [2]  =  0.117641, Zygote.var"#75#76"{Zygote.Pullback{Tuple{typeof(f), SparseVector{Float64, Int64}}, Tuple{Zygote.var"#3878#back#1235"{Zygote.var"#1231#1234"{2, SparseVector{Float64, Int64}}}, Zygote.var"#1922#back#157"{Zygote.var"#153#156"}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), SparseVector{Float64, Int64}}, Tuple{}}}}}(∂(f)))

julia> y = rand(2)
2-element Vector{Float64}:
 0.5386564359257334
 0.8942806547203198

julia> Zygote.pullback(g, y)
([0.29015075596421375, 0.7997378894070039], Zygote.var"#75#76"{Zygote.Pullback{Tuple{typeof(g), Vector{Float64}}, Tuple{Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcasted), typeof(^), Vector{Float64}, Float64}, Tuple{Zygote.var"#2169#back#289"{Zygote.var"#287#288"{Tuple{NTuple{4, Nothing}, Tuple{}}, Zygote.var"#4119#back#1356"{Zygote.var"#∇broadcasted#1367"{Tuple{Vector{Float64}, Float64}, Vector{Tuple{Float64, Zygote.ZBack{ChainRules.var"#power_pullback#1338"{Float64, Float64, ChainRulesCore.ProjectTo{Float64, @NamedTuple{}}, ChainRulesCore.ProjectTo{Float64, @NamedTuple{}}, Float64}}}}, Val{3}}}}}, Zygote.var"#2013#back#200"{typeof(identity)}, Zygote.var"#2877#back#684"{Zygote.var"#map_back#678"{typeof(Base.Broadcast.broadcastable), 1, Tuple{Tuple{}}, Tuple{Val{0}}, Tuple{}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcastable), Float64}, Tuple{}}, Zygote.var"#2169#back#289"{Zygote.var"#287#288"{Tuple{Tuple{Nothing, Nothing}, Tuple{}}, Zygote.var"#combine_styles_pullback#1162"{Tuple{Nothing, Nothing, Nothing}}}}, Zygote.var"#2013#back#200"{typeof(identity)}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcastable), Vector{Float64}}, Tuple{}}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), Vector{Float64}}, Tuple{}}}}}(∂(g)))

julia> Zygote.pullback(g, x)
ERROR: MethodError: no method matching zero(::Tuple{Float64, Zygote.ZBack{ChainRules.var"#power_pullback#1338"{Float64, Float64, ChainRulesCore.ProjectTo{Float64, @NamedTuple{}}, ChainRulesCore.ProjectTo{Float64, @NamedTuple{}}, Float64}}})

Closest candidates are:
  zero(::Type{Union{}}, Any...)
   @ Base number.jl:310
  zero(::Type{Pkg.Resolve.VersionWeight})
   @ Pkg ~/.julia/juliaup/julia-1.10.0-beta1+0.x64.linux.gnu/share/julia/stdlib/v1.10/Pkg/src/Resolve/versionweights.jl:15
  zero(::Type{Dates.Date})
   @ Dates ~/.julia/juliaup/julia-1.10.0-beta1+0.x64.linux.gnu/share/julia/stdlib/v1.10/Dates/src/types.jl:439
  ...

Stacktrace:
  [1] iszero(x::Tuple{Float64, Zygote.ZBack{ChainRules.var"#power_pullback#1338"{…}}})
    @ Base ./number.jl:42
  [2] _isnotzero
    @ SparseArrays.HigherOrderFns ~/.julia/juliaup/julia-1.10.0-beta1+0.x64.linux.gnu/share/julia/stdlib/v1.10/SparseArrays/src/SparseArrays.jl:44 [inlined]
  [3] _map_zeropres!(f::Tf, C::Union{…}, A::Union{…}) where Tf
    @ SparseArrays.HigherOrderFns ~/.julia/juliaup/julia-1.10.0-beta1+0.x64.linux.gnu/share/julia/stdlib/v1.10/SparseArrays/src/higherorderfns.jl:269 [inlined]
  [4] _noshapecheck_map(::SparseArrays.HigherOrderFns.var"#3#4"{…}, ::SparseVector{…})
    @ SparseArrays.HigherOrderFns ~/.julia/juliaup/julia-1.10.0-beta1+0.x64.linux.gnu/share/julia/stdlib/v1.10/SparseArrays/src/higherorderfns.jl:189
  [5] _shapecheckbc
    @ SparseArrays.HigherOrderFns ~/.julia/juliaup/julia-1.10.0-beta1+0.x64.linux.gnu/share/julia/stdlib/v1.10/SparseArrays/src/higherorderfns.jl:1061 [inlined]
  [6] _copy
    @ SparseArrays.HigherOrderFns ~/.julia/juliaup/julia-1.10.0-beta1+0.x64.linux.gnu/share/julia/stdlib/v1.10/SparseArrays/src/higherorderfns.jl:1050 [inlined]
  [7] _copy
    @ SparseArrays.HigherOrderFns ~/.julia/juliaup/julia-1.10.0-beta1+0.x64.linux.gnu/share/julia/stdlib/v1.10/SparseArrays/src/higherorderfns.jl:1056 [inlined]
  [8] copy
    @ SparseArrays.HigherOrderFns ~/.julia/juliaup/julia-1.10.0-beta1+0.x64.linux.gnu/share/julia/stdlib/v1.10/SparseArrays/src/higherorderfns.jl:1047 [inlined]
  [9] materialize
    @ Base.Broadcast ./broadcast.jl:903 [inlined]
 [10] _broadcast
    @ Zygote ~/.julia/packages/Zygote/JeHtr/src/lib/broadcast.jl:189 [inlined]
 [11] _broadcast_generic
    @ Zygote ~/.julia/packages/Zygote/JeHtr/src/lib/broadcast.jl:215 [inlined]
 [12] adjoint
    @ Zygote ~/.julia/packages/Zygote/JeHtr/src/lib/broadcast.jl:205 [inlined]
 [13] _pullback
    @ Zygote ~/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:66 [inlined]
 [14] _apply
    @ Core ./boot.jl:836 [inlined]
 [15] adjoint
    @ Zygote ~/.julia/packages/Zygote/JeHtr/src/lib/lib.jl:203 [inlined]
 [16] _pullback
    @ Zygote ~/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:66 [inlined]
 [17] broadcasted
    @ Base.Broadcast ./broadcast.jl:1347 [inlined]
 [18] g
    @ Zygote ./REPL[16]:1 [inlined]
 [19] _pullback(ctx::Zygote.Context{false}, f::typeof(g), args::SparseVector{Float64, Int64})
    @ Zygote ~/.julia/packages/Zygote/JeHtr/src/compiler/interface2.jl:0
 [20] pullback(f::Function, cx::Zygote.Context{false}, args::SparseVector{Float64, Int64})
    @ Zygote ~/.julia/packages/Zygote/JeHtr/src/compiler/interface.jl:44
 [21] pullback(f::Function, args::SparseVector{Float64, Int64})
    @ Zygote ~/.julia/packages/Zygote/JeHtr/src/compiler/interface.jl:42
 [22] top-level scope
    @ REPL[19]:1
 [23] top-level scope
    @ ~/.julia/packages/Infiltrator/LtFao/src/Infiltrator.jl:726
Some type information was truncated. Use `show(err)` to see complete types.

Version info:

julia> versioninfo()
Julia Version 1.10.0-beta1
Commit 6616549950e (2023-07-25 17:43 UTC)
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 12 × Intel(R) Core(TM) i7-8850H CPU @ 2.60GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, skylake)
  Threads: 1 on 12 virtual cores
Environment:
  LD_LIBRARY_PATH = :/home/guillaume/Software/gurobi1002/linux64/lib

Posted by gdalle on Aug 10 '23, 12:08.