ChainRules.jl

Need adjoint for reinterpret SVector

Open jenkspt opened this issue 2 years ago • 4 comments

Minimal working example:

using StaticArrays, Zygote

f(x) = sum(sum(reinterpret(SVector{size(x, 1), eltype(x)}, x)))
Zygote.gradient(f, rand(3, 10))
ERROR: Need an adjoint for constructor Base.ReinterpretArray{SVector{3, Float64}, 2, Float64, Matrix{Float64}, false}. Gradient is of type FillArrays.Fill{FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}
Stacktrace:
 [1] error(s::String)
   @ Base ./error.jl:33
 [2] (::Zygote.Jnew{Base.ReinterpretArray{SVector{3, Float64}, 2, Float64, Matrix{Float64}, false}, Nothing, false})(Δ::FillArrays.Fill{FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
   @ Zygote ~/.julia/packages/Zygote/ajuwN/src/lib/lib.jl:323
 [3] (::Zygote.var"#1811#back#235"{Zygote.Jnew{Base.ReinterpretArray{SVector{3, Float64}, 2, Float64, Matrix{Float64}, false}, Nothing, false}})(Δ::FillArrays.Fill{FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
   @ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
 [4] Pullback
   @ ./reinterpretarray.jl:47 [inlined]
 [5] (::typeof(∂(reinterpret)))(Δ::FillArrays.Fill{FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
   @ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface2.jl:0
 [6] Pullback
   @ ./REPL[95]:1 [inlined]
 [7] (::Zygote.var"#52#53"{typeof(∂(f))})(Δ::Float64)
   @ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface.jl:41
 [8] gradient(::Function, ::Matrix{Float64}, ::Vararg{Any})
   @ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface.jl:76
 [9] top-level scope
   @ REPL[98]:1

jenkspt, Apr 09 '22 23:04

I realize that a rule specific to SVector probably shouldn't be added to chain rules core, but the goal is a more general solution for any reinterpreted composite type.
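To illustrate why a general rule seems feasible, here is a minimal sketch using only Base. It assumes NTuple{3,Float64} as a stand-in for SVector{3,Float64} (a substitution made here only to avoid the StaticArrays dependency; both are isbits types holding three Float64 values). Since reinterpret only relabels the same data, applying the inverse relabeling to a cotangent should give the gradient with respect to the original array.

```julia
# NTuple{3,Float64} stands in for SVector{3,Float64} (assumption for illustration only).
x = collect(reshape(1.0:30.0, 3, 10))

# Forward: each column of three Float64 values becomes one composite element.
y = reinterpret(NTuple{3,Float64}, x)
size(y)  # (1, 10)

# A cotangent for y has the same structure: one 3-tuple per element.
# For f(x) = sum of all entries, every component of the cotangent is 1.
dy = fill((1.0, 1.0, 1.0), size(y))

# Reverse: reinterpreting the cotangent back to Float64 gives the gradient w.r.t. x.
dx = reinterpret(Float64, dy)
dx == ones(3, 10)  # true
```

The cotangent here is a dense array built by hand. The lazily filled cotangent that Zygote produces in the error message above is a different type and is not covered by this sketch.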

jenkspt, Apr 09 '22 23:04

Would be nice. Instead of depending on StaticArrays, you could probably just specify abstract types, something like this:

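using ChainRules
using ChainRulesCore: NoTangent

# Sketch: applies only when T is an array type whose eltype S matches that of x.
function ChainRules.rrule(::typeof(reinterpret), ::Type{T}, x::AbstractArray{S}) where {T<:AbstractArray{S},S}
    # Pullback: reinterpret the cotangent back to the element type of x.
    unreinterpret(dy) = (NoTangent(), NoTangent(), reinterpret(S, dy))
    reinterpret(T, x), unreinterpret
end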

and another signature for the other direction? Restricting T this way would prevent the rule from acting on things like reinterpret(Float32, [1.0, 2.0]), although it would also exclude things like reinterpreting to remove units.
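A rough sketch of what the signature for the other direction might look like, assuming the same pattern with the roles of the two types swapped. This is a hypothetical rule, not something that exists in ChainRules; the name reinterpret_back is made up for illustration, and the method is attached to ChainRulesCore.rrule so the snippet only needs ChainRulesCore and StaticArrays.

```julia
using ChainRulesCore
using StaticArrays

# Hypothetical reverse-direction rule (not part of ChainRules): x holds small arrays
# with element type S (e.g. SVector{3,Float64}) and is reinterpreted as a flat array of S.
function ChainRulesCore.rrule(::typeof(reinterpret), ::Type{S}, x::AbstractArray{T}) where {S,T<:AbstractArray{S}}
    # Pullback: regroup the flat cotangent into elements of type T.
    reinterpret_back(dy) = (NoTangent(), NoTangent(), reinterpret(T, dy))
    return reinterpret(S, x), reinterpret_back
end

v = [SVector(1.0, 2.0, 3.0), SVector(4.0, 5.0, 6.0)]
y, back = rrule(reinterpret, Float64, v)   # y holds the six Float64 values of v
_, _, dv = back(ones(6))                   # two SVectors of ones

# Neither type is an array type here, so the rule above does not match
# and ChainRulesCore's generic fallback returns `nothing`.
rrule(reinterpret, Float32, [1.0, 2.0]) === nothing
```

The last line shows the effect of the type restriction: a bit-level reinterpretation between unrelated number types gets no rule at all, rather than a wrong gradient.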

mcabbott, Apr 10 '22 02:04

Does it make sense to add your suggestion to ChainRules?

jenkspt, Apr 10 '22 23:04

Yes, if it works... probably someone just has to tidy it up & figure out how to get tests working, etc.
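As a starting point for tests, something like the following hand-written check might do. It is only a sketch: it repeats the rule suggested above (attached to ChainRulesCore.rrule so that only ChainRulesCore, StaticArrays and the Test standard library are needed) and compares the pullback against the known gradient of a plain sum, using a dense cotangent built by hand.

```julia
using ChainRulesCore
using StaticArrays
using Test

# The rule suggested earlier in this thread, repeated so the snippet is self-contained.
function ChainRulesCore.rrule(::typeof(reinterpret), ::Type{T}, x::AbstractArray{S}) where {T<:AbstractArray{S},S}
    unreinterpret(dy) = (NoTangent(), NoTangent(), reinterpret(S, dy))
    return reinterpret(T, x), unreinterpret
end

@testset "reinterpret rrule sketch" begin
    x = rand(3, 10)
    y, back = rrule(reinterpret, SVector{3,Float64}, x)

    # The primal output must agree with a direct call.
    @test y == reinterpret(SVector{3,Float64}, x)

    # Summing every entry has gradient 1 everywhere, so a cotangent of all-ones
    # SVectors should pull back to ones(3, 10).
    dy = fill(SVector(1.0, 1.0, 1.0), size(y))
    _, _, dx = back(dy)
    @test dx == ones(3, 10)
end
```

This does not check whether the lazily filled cotangent Zygote passes in the original error can be reinterpreted the same way; that case would need its own test.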

mcabbott, Apr 10 '22 23:04