AxisKeys.jl

unsupported keyword argument "time" when taking a gradient with Zygote

Open • lazarusA opened this issue 2 years ago • 1 comment

Is the error from the following code expected? Any hints? A similar function without AxisKeys works, but then I lose the advantage of using the time dimension, which I would like to keep.

I suppose I will need a new rule, but I'm not sure how to start writing one.

using AxisKeys
using AxisKeys: KeyedArray as KA
using Zygote

ab = (;
    a = KA([5.0f0, 10.0f0]; time=1:2),
    b = KA([-2.0f0, 0.1f0]; time=1:2),
)

function getVals(ab::NamedTuple, ts::Int)
    map(ab) do v
        in(:time, AxisKeys.dimnames(v)) ? v[time=ts][1] : v
    end
end

gradient(x -> x^2 + x*sum(getVals(ab,2)), 5)
ERROR: MethodError: no method matching adjoint(::Zygote.Context{false}, ::typeof(getindex), ::KeyedArray{Float32, 1, NamedDimsArray{(:time,), Float32, 1, Vector{Float32}}, Base.RefValue{UnitRange{Int64}}}; time=2)
Closest candidates are:
  adjoint(::ZygoteRules.AContext, ::typeof(getindex), ::AbstractArray, ::Any...) at none:0 got unsupported keyword argument "time"
  adjoint(::ZygoteRules.AContext, ::Base.Fix2, ::Any) at none:0 got unsupported keyword argument "time"
  adjoint(::ZygoteRules.AContext, ::Base.Fix1, ::Any) at none:0 got unsupported keyword argument "time"
  ...
Stacktrace:
  [1] _pullback
    @ ~/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:75 [inlined]
  [2] _pullback
  ...
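One possible workaround, assuming the gradient is only needed with respect to x as in the call above: getVals(ab, 2) does not depend on x, so it can be evaluated outside the function being differentiated. Zygote then never traces the keyword getindex, while the time dimension is still used to pick the values. This is a sketch under that assumption, not a fix in Zygote; the name s is introduced here for illustration.

```julia
using AxisKeys
using AxisKeys: KeyedArray as KA
using Zygote

ab = (;
    a = KA([5.0f0, 10.0f0]; time=1:2),
    b = KA([-2.0f0, 0.1f0]; time=1:2),
)

function getVals(ab::NamedTuple, ts::Int)
    map(ab) do v
        in(:time, AxisKeys.dimnames(v)) ? v[time=ts][1] : v
    end
end

# Keyword indexing runs here, outside of Zygote, so no adjoint is needed for it.
s = sum(getVals(ab, 2))

# Only x is differentiated; s is captured as a constant.
g = gradient(x -> x^2 + x * s, 5)
```

This does not help if gradients with respect to ab are needed later, because then the keyword indexing has to sit inside the differentiated code again.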

lazarusA, Apr 29 '23 13:04

A minimal working example (MWE):

julia> using Zygote, NamedDims

julia> gradient(x -> x[1], NamedDimsArray(rand(3), :a))  # easy case
(NamedDimsArray([1.0, 0.0, 0.0], :a),)

julia> gradient(x -> x[a=1], NamedDimsArray(rand(3), :a))
ERROR: MethodError: no method matching adjoint(::Zygote.Context{false}, ::typeof(getindex), ::NamedDimsArray{(:a,), Float64, 1, Vector{Float64}}; a::Int64)
Closest candidates are:
  adjoint(::ZygoteRules.AContext, ::typeof(getindex), ::AbstractArray, ::Any...) got unsupported keyword argument "a"
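The closest candidate above is selected from the positional argument types alone; only afterwards does Julia find that this method accepts no keyword a. The same behavior can be reproduced in plain Julia with a hypothetical function f (not part of any package). Base's hasmethod, which takes an optional tuple of keyword names, reports the mismatch without calling anything:

```julia
# Hypothetical function: only the Int method accepts the keyword k.
f(x::Int; k = 0) = "Int method, k = $k"
f(x::Real) = "Real method"

# Dispatch looks at positional types: Int picks the first method, which accepts k.
r1 = f(1; k = 2)

# Float64 picks f(::Real); that method has no keyword k, so the call fails with
# a MethodError ("got unsupported keyword argument"), like Zygote's adjoint above.
r2 = try
    f(1.5; k = 2)
catch err
    err
end

# hasmethod(f, argtypes, kwnames) checks the same thing without calling f.
ok_int = hasmethod(f, Tuple{Int}, (:k,))
ok_float = hasmethod(f, Tuple{Float64}, (:k,))
```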

Ideally, Zygote would treat this call as having no rule and keep going, so that it reaches the later calls to getindex without keywords. On some level, keywords don't participate in dispatch, but something like the following, where rrule returns nothing, is the desired outcome:

julia> using ChainRulesCore

julia> rrule(getindex, NamedDimsArray(rand(3), :a), 1)  # easy case
(0.48728253527621856, ChainRules.var"#getindex_pullback#1601"{NamedDimsArray{(:a,), Float64, 1, Vector{Float64}}, Tuple{Int64}, Tuple{NoTangent}}(NamedDimsArray([0.48728253527621856, 0.20976131397698006, 0.6193363603857295], :a), (1,), (NoTangent(),)))

julia> rrule(getindex, NamedDimsArray(rand(3), :a); a=1)  # nothing, as if no rule

julia> rrule(getindex, NamedDimsArray(rand(3), :a))  # same positional arguments, without keyword
ERROR: BoundsError: attempt to access 3-element Vector{Float64} at index []

That's not how Zygote works, because (1) it uses its own @adjoint rule for getindex, and (2) even when it uses an rrule, it doesn't call it and check for nothing; instead, it asks the compiler which method would hypothetically be used.
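Until Zygote handles keyword indexing, another possible sketch, assuming 1-D arrays like the ones in this issue, is to translate the time keyword into a positional index before indexing, since positional getindex already works (the easy case above). getVals_positional is a hypothetical name; the dimension lookup uses only AxisKeys.dimnames, which the original code already calls.

```julia
using AxisKeys
using AxisKeys: KeyedArray as KA
using Zygote

ab = (;
    a = KA([5.0f0, 10.0f0]; time=1:2),
    b = KA([-2.0f0, 0.1f0]; time=1:2),
)

# Hypothetical variant of getVals that avoids keyword getindex.
function getVals_positional(ab::NamedTuple, ts::Int)
    map(ab) do v
        # Which dimension is called :time? (nothing if there is none)
        d = findfirst(==(:time), AxisKeys.dimnames(v))
        if d === nothing
            v
        else
            # Index ts along :time and take every other dimension whole.
            inds = ntuple(i -> i == d ? ts : Colon(), ndims(v))
            v[inds...]
        end
    end
end

g = gradient(x -> x^2 + x * sum(getVals_positional(ab, 2)), 5)
```

For 1-D arrays, v[inds...] already returns a scalar, so the trailing [1] from the original is dropped; for higher-dimensional arrays it returns a slice instead.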

mcabbott, Apr 30 '23 14:04