ChainRules.jl
`repeat` rrule ambiguity for Bool arrays
julia> rrule(repeat, falses(1), 1)
ERROR: MethodError: rrule(::typeof(repeat), ::BitVector, ::Int64) is ambiguous. Candidates:
rrule(::typeof(repeat), xs::AbstractArray, counts::Integer...) in ChainRules at /home/brianc/.julia/packages/ChainRules/o1vND/src/rulesets/Base/array.jl:191
rrule(::typeof(repeat), var"543"::AbstractArray{Bool}, var"544"...; repeat_pullback) in ChainRules at /home/brianc/.julia/packages/ChainRules/o1vND/src/rulesets/Base/nondiff.jl:65
Possible fix, define
rrule(::typeof(repeat), ::AbstractArray{Bool}, ::Vararg{Integer})
Stacktrace:
[1] top-level scope
@ REPL[8]:1
Found while working on https://github.com/FluxML/Zygote.jl/issues/1234. Is there a way to constrain the vararg type in https://github.com/JuliaDiff/ChainRules.jl/blob/1770bb29ca42d4e07643284e0a4917ad6ea35b57/src/rulesets/Base/nondiff.jl#L65 to remove this ambiguity?
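Can we just do:
@non_differentiable repeat(::AbstractArray{Bool}, ::Integer...)
Only integers are expected as the counts, right? With that signature, the Bool method becomes strictly more specific than the generic `rrule(::typeof(repeat), xs::AbstractArray, counts::Integer...)`, so dispatch is no longer ambiguous.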
It might be worth being systematic about these ambiguities, at least to track progress. I'm not sure what to make of the 1242 frule ambiguities, but for rrule there appear to be 8, including this one:
julia> filter(Test.detect_ambiguities(ChainRules, ChainRulesCore)) do t
(t[1].name in [:frule, Symbol("frule##kw")])
end |> length
1242
julia> filter(Test.detect_ambiguities(ChainRules, ChainRulesCore)) do t
(t[1].name in [:rrule, Symbol("rrule##kw")])
end
8-element Vector{Tuple{Method, Method}}:
(rrule(::typeof(findfirst), var"776"::Union{Regex, AbstractChar, AbstractString, Function}, var"777"::AbstractString; findfirst_pullback)
@ ChainRules ~/.julia/packages/ChainRules/BbzFc/src/rulesets/Base/nondiff.jl:175, rrule(::typeof(findfirst), var"780"::Function, var"781"; findfirst_pullback)
@ ChainRules ~/.julia/packages/ChainRules/BbzFc/src/rulesets/Base/nondiff.jl:176)
((::ChainRulesCore.var"#rrule##kw")(kwargs, ::typeof(rrule), ::typeof(findfirst), var"780"::Function, var"781")
@ ChainRules none:0, (::ChainRulesCore.var"#rrule##kw")(kwargs, ::typeof(rrule), ::typeof(findfirst), var"776"::Union{Regex, AbstractChar, AbstractString, Function}, var"777"::AbstractString)
@ ChainRules none:0)
((::ChainRulesCore.var"#rrule##kw")(kwargs, ::typeof(rrule), ::typeof(findfirst), var"776"::Union{Regex, AbstractChar, AbstractString, Function}, var"777"::AbstractString)
@ ChainRules none:0, (::ChainRulesCore.var"#rrule##kw")(kwargs, ::typeof(rrule), ::typeof(findfirst), var"780"::Function, var"781")
@ ChainRules none:0)
((::ChainRulesCore.var"#rrule##kw")(kwargs, ::typeof(rrule), ::typeof(findlast), var"791"::Function, var"792")
@ ChainRules none:0, (::ChainRulesCore.var"#rrule##kw")(kwargs, ::typeof(rrule), ::typeof(findlast), var"787"::Union{AbstractChar, AbstractString, Function}, var"788"::AbstractString)
@ ChainRules none:0)
(rrule(::typeof(findlast), var"787"::Union{AbstractChar, AbstractString, Function}, var"788"::AbstractString; findlast_pullback)
@ ChainRules ~/.julia/packages/ChainRules/BbzFc/src/rulesets/Base/nondiff.jl:178, rrule(::typeof(findlast), var"791"::Function, var"792"; findlast_pullback)
@ ChainRules ~/.julia/packages/ChainRules/BbzFc/src/rulesets/Base/nondiff.jl:179)
((::ChainRulesCore.var"#rrule##kw")(kwargs, ::typeof(rrule), ::typeof(findlast), var"787"::Union{AbstractChar, AbstractString, Function}, var"788"::AbstractString)
@ ChainRules none:0, (::ChainRulesCore.var"#rrule##kw")(kwargs, ::typeof(rrule), ::typeof(findlast), var"791"::Function, var"792")
@ ChainRules none:0)
(rrule(::Type{Matrix}, var"427"::AbstractArray{Bool}; Matrix_pullback)
@ ChainRules ~/.julia/packages/ChainRules/BbzFc/src/rulesets/Base/nondiff.jl:57, rrule(TM::Type{<:Matrix}, A::Union{LinearAlgebra.Hermitian{T, S}, LinearAlgebra.Symmetric{T, S}} where {T, S})
@ ChainRules ~/.julia/packages/ChainRules/BbzFc/src/rulesets/LinearAlgebra/symmetric.jl:25)
(rrule(::typeof(repeat), xs::AbstractArray, counts::Integer...)
@ ChainRules ~/.julia/packages/ChainRules/BbzFc/src/rulesets/Base/array.jl:191, rrule(::typeof(repeat), var"454"::AbstractArray{Bool}, var"455"...; repeat_pullback)
@ ChainRules ~/.julia/packages/ChainRules/BbzFc/src/rulesets/Base/nondiff.jl:65)