ChainRules.jl

`repeat` rrule ambiguity for Bool arrays

Status: Open. ToucheSir opened this issue 3 years ago · 2 comments

julia> rrule(repeat, falses(1), 1)
ERROR: MethodError: rrule(::typeof(repeat), ::BitVector, ::Int64) is ambiguous. Candidates:
  rrule(::typeof(repeat), xs::AbstractArray, counts::Integer...) in ChainRules at /home/brianc/.julia/packages/ChainRules/o1vND/src/rulesets/Base/array.jl:191
  rrule(::typeof(repeat), var"543"::AbstractArray{Bool}, var"544"...; repeat_pullback) in ChainRules at /home/brianc/.julia/packages/ChainRules/o1vND/src/rulesets/Base/nondiff.jl:65
Possible fix, define
  rrule(::typeof(repeat), ::AbstractArray{Bool}, ::Vararg{Integer})
Stacktrace:
 [1] top-level scope
   @ REPL[8]:1

Found while working on https://github.com/FluxML/Zygote.jl/issues/1234. Is there a way to constrain the currently untyped vararg in the `AbstractArray{Bool}` rule at https://github.com/JuliaDiff/ChainRules.jl/blob/1770bb29ca42d4e07643284e0a4917ad6ea35b57/src/rulesets/Base/nondiff.jl#L65 to remove this ambiguity?

— ToucheSir, Aug 01 2022, 23:08

Can we just write:

@non_differentiable repeat(::AbstractArray{Bool}, ::Integer...)

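Only integers are expected as the counts, right? This signature is exactly the intersection of the two ambiguous candidates: `AbstractArray{Bool}` from the non-differentiable rule and `Integer...` from the array rule. It is therefore more specific than both, which is the "possible fix" Julia suggests in the error message above.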

— mzgubic, Aug 02 2022, 12:08

It might be worth being systematic about these ambiguities, at least to track progress. I'm not sure what to make of the 1242 `frule` ambiguities. For `rrule`, however, there appear to be only 8, including this one:

julia> filter(Test.detect_ambiguities(ChainRules, ChainRulesCore)) do t
         (t[1].name in [:frule, Symbol("frule##kw")])
       end |> length
1242

julia> filter(Test.detect_ambiguities(ChainRules, ChainRulesCore)) do t
         (t[1].name in [:rrule, Symbol("rrule##kw")])
       end
8-element Vector{Tuple{Method, Method}}:
 (rrule(::typeof(findfirst), var"776"::Union{Regex, AbstractChar, AbstractString, Function}, var"777"::AbstractString; findfirst_pullback)
     @ ChainRules ~/.julia/packages/ChainRules/BbzFc/src/rulesets/Base/nondiff.jl:175, rrule(::typeof(findfirst), var"780"::Function, var"781"; findfirst_pullback)
     @ ChainRules ~/.julia/packages/ChainRules/BbzFc/src/rulesets/Base/nondiff.jl:176)
 ((::ChainRulesCore.var"#rrule##kw")(kwargs, ::typeof(rrule), ::typeof(findfirst), var"780"::Function, var"781")
     @ ChainRules none:0, (::ChainRulesCore.var"#rrule##kw")(kwargs, ::typeof(rrule), ::typeof(findfirst), var"776"::Union{Regex, AbstractChar, AbstractString, Function}, var"777"::AbstractString)
     @ ChainRules none:0)
 ((::ChainRulesCore.var"#rrule##kw")(kwargs, ::typeof(rrule), ::typeof(findfirst), var"776"::Union{Regex, AbstractChar, AbstractString, Function}, var"777"::AbstractString)
     @ ChainRules none:0, (::ChainRulesCore.var"#rrule##kw")(kwargs, ::typeof(rrule), ::typeof(findfirst), var"780"::Function, var"781")
     @ ChainRules none:0)
 ((::ChainRulesCore.var"#rrule##kw")(kwargs, ::typeof(rrule), ::typeof(findlast), var"791"::Function, var"792")
     @ ChainRules none:0, (::ChainRulesCore.var"#rrule##kw")(kwargs, ::typeof(rrule), ::typeof(findlast), var"787"::Union{AbstractChar, AbstractString, Function}, var"788"::AbstractString)
     @ ChainRules none:0)
 (rrule(::typeof(findlast), var"787"::Union{AbstractChar, AbstractString, Function}, var"788"::AbstractString; findlast_pullback)
     @ ChainRules ~/.julia/packages/ChainRules/BbzFc/src/rulesets/Base/nondiff.jl:178, rrule(::typeof(findlast), var"791"::Function, var"792"; findlast_pullback)
     @ ChainRules ~/.julia/packages/ChainRules/BbzFc/src/rulesets/Base/nondiff.jl:179)
 ((::ChainRulesCore.var"#rrule##kw")(kwargs, ::typeof(rrule), ::typeof(findlast), var"787"::Union{AbstractChar, AbstractString, Function}, var"788"::AbstractString)
     @ ChainRules none:0, (::ChainRulesCore.var"#rrule##kw")(kwargs, ::typeof(rrule), ::typeof(findlast), var"791"::Function, var"792")
     @ ChainRules none:0)
 (rrule(::Type{Matrix}, var"427"::AbstractArray{Bool}; Matrix_pullback)
     @ ChainRules ~/.julia/packages/ChainRules/BbzFc/src/rulesets/Base/nondiff.jl:57, rrule(TM::Type{<:Matrix}, A::Union{LinearAlgebra.Hermitian{T, S}, LinearAlgebra.Symmetric{T, S}} where {T, S})
     @ ChainRules ~/.julia/packages/ChainRules/BbzFc/src/rulesets/LinearAlgebra/symmetric.jl:25)
 (rrule(::typeof(repeat), xs::AbstractArray, counts::Integer...)
     @ ChainRules ~/.julia/packages/ChainRules/BbzFc/src/rulesets/Base/array.jl:191, rrule(::typeof(repeat), var"454"::AbstractArray{Bool}, var"455"...; repeat_pullback)
     @ ChainRules ~/.julia/packages/ChainRules/BbzFc/src/rulesets/Base/nondiff.jl:65)

— mcabbott, Aug 02 2022, 13:08