AxisKeys.jl
unsupported keyword argument "time" when taking a gradient with Zygote
Running the following produces the error shown below. Is this expected? Any hints? A similar function without AxisKeys works, but then the advantage of the named time dimension is lost, and I would like to keep it.
I suppose I will need a new rule, but I am not sure how to start writing one.
using AxisKeys
using AxisKeys: KeyedArray as KA
using Zygote
ab = (;
    a = KA([5.0f0, 10.0f0]; time=1:2),
    b = KA([-2.0f0, 0.1f0]; time=1:2),
)
function getVals(ab::NamedTuple, ts::Int)
    map(ab) do v
        in(:time, AxisKeys.dimnames(v)) ? v[time=ts][1] : v
    end
end
gradient(x -> x^2 + x*sum(getVals(ab,2)), 5)
ERROR: MethodError: no method matching adjoint(::Zygote.Context{false}, ::typeof(getindex), ::KeyedArray{Float32, 1, NamedDimsArray{(:time,), Float32, 1, Vector{Float32}}, Base.RefValue{UnitRange{Int64}}}; time=2)
Closest candidates are:
adjoint(::ZygoteRules.AContext, ::typeof(getindex), ::AbstractArray, ::Any...) at none:0 got unsupported keyword argument "time"
adjoint(::ZygoteRules.AContext, ::Base.Fix2, ::Any) at none:0 got unsupported keyword argument "time"
adjoint(::ZygoteRules.AContext, ::Base.Fix1, ::Any) at none:0 got unsupported keyword argument "time"
...
Stacktrace:
[1] _pullback
@ ~/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:75 [inlined]
[2] _pullback
.
.
.
MWE is this:
julia> using Zygote, NamedDims
julia> gradient(x -> x[1], NamedDimsArray(rand(3), :a)) # easy case
(NamedDimsArray([1.0, 0.0, 0.0], :a),)
julia> gradient(x -> x[a=1], NamedDimsArray(rand(3), :a))
ERROR: MethodError: no method matching adjoint(::Zygote.Context{false}, ::typeof(getindex), ::NamedDimsArray{(:a,), Float64, 1, Vector{Float64}}; a::Int64)
Closest candidates are:
adjoint(::ZygoteRules.AContext, ::typeof(getindex), ::AbstractArray, ::Any...) got unsupported keyword argument "a"
Ideally Zygote would treat this call as not having a rule and keep going, so that it later sees the calls to getindex without keywords. On some level keywords don't participate in dispatch, but something like the nothing returned by the second call below is the desired outcome:
julia> using ChainRulesCore
julia> rrule(getindex, NamedDimsArray(rand(3), :a), 1) # easy case
(0.48728253527621856, ChainRules.var"#getindex_pullback#1601"{NamedDimsArray{(:a,), Float64, 1, Vector{Float64}}, Tuple{Int64}, Tuple{NoTangent}}(NamedDimsArray([0.48728253527621856, 0.20976131397698006, 0.6193363603857295], :a), (1,), (NoTangent(),)))
julia> rrule(getindex, NamedDimsArray(rand(3), :a); a=1) # nothing, as if no rule
julia> rrule(getindex, NamedDimsArray(rand(3), :a)) # same positional arguments, without keyword
ERROR: BoundsError: attempt to access 3-element Vector{Float64} at index []
That's not how Zygote works, because (1) it uses its own @adjoint rule for getindex, and (2) even when it uses an rrule, it doesn't call it and check for nothing; instead it asks the compiler which method would hypothetically be used.
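The point about keywords and dispatch can be reproduced in plain Julia, without Zygote or NamedDims. The sketch below is only an illustration: `rule` is a hypothetical stand-in for a positional-only rule such as the `adjoint(..., ::typeof(getindex), ::AbstractArray, ::Any...)` candidate in the error message, not Zygote's actual implementation. It shows that a lookup on positional argument types finds the method, while a call that adds a keyword then fails with a MethodError.

```julia
# Hypothetical stand-in for a rule that only accepts positional arguments.
rule(::typeof(getindex), x::AbstractArray, inds...) = "positional rule"

x = [1.0, 2.0, 3.0]

# Asking which method applies, using positional argument types only:
# the Vararg method matches, so this is true.
found_positional = hasmethod(rule, Tuple{typeof(getindex), Vector{Float64}})

# Asking the same question but also requiring support for a keyword `a`:
# the method declares no keywords, so this is false.
found_keyword = hasmethod(rule, Tuple{typeof(getindex), Vector{Float64}}, (:a,))

# Calling with the keyword anyway does not fall through to "no rule";
# it throws a MethodError ("got unsupported keyword argument").
threw_method_error = try
    rule(getindex, x; a = 1)
    false
catch err
    err isa MethodError
end

println((found_positional, found_keyword, threw_method_error))
```

A lookup that considers only the positional types therefore commits to a method that cannot accept the keyword, which is consistent with the error reported above.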