Zygote.jl
Zygote.jl copied to clipboard
Needless copies in `wrap_chainrules_input` and friends
The conversion to and from ChainRules types should be essentially free for single structs, but for arrays of structs it currently involves a copy in each direction. For example:
julia> zyg = [(1,2,nothing), (3,4,nothing)];
julia> Zygote.wrap_chainrules_input(ans)
2-element Vector{ChainRulesCore.Tangent{Any, Tuple{Int64, Int64, ChainRulesCore.ZeroTangent}}}:
Tangent{Any}(1, 2, ChainRulesCore.ZeroTangent())
Tangent{Any}(3, 4, ChainRulesCore.ZeroTangent())
julia> Zygote.wrap_chainrules_output(ans)
2-element Vector{Tuple{Int64, Int64, Nothing}}:
(1, 2, nothing)
(3, 4, nothing)
The arrays here have exactly the same contents, so ideally this would be done by `reinterpret`ing the data. The code needed looks something like this:
"""
    wrap_chainrules_output(dxs::AbstractArray{<:ChainRules.Tangent{<:Any,B}}) where {B}

Convert an array of `Tangent`s with backing type `B` into an array whose element type is
`wrap_chainrules_output(B)`.

When the array's element type, the backing type `B`, and the converted type are all bits
types of the same size, the data is reinterpreted in place instead of being copied.
Otherwise each element is converted individually with `map`.
"""
@inline function wrap_chainrules_output(dxs::AbstractArray{<:ChainRules.Tangent{<:Any, B}}) where {B}
    S = eltype(dxs)
    # `reinterpret` needs the *stored* element type to be a bits type too; an array whose
    # eltype is e.g. `Tangent{<:Any,B}` matches this signature but stores boxed elements.
    if isbitstype(S) && isbitstype(B)
        # B is the backing type. It still contains NoTangent etc, which need converting to Nothing
        T = wrap_chainrules_output(B)
        # Only reinterpret when the target is a bits type with the same size; anything else
        # would make `reinterpret` throw, so fall through to the element-wise path.
        if isbitstype(T) && sizeof(T) == sizeof(S)
            return reinterpret(T, dxs)
        end
    end
    return map(wrap_chainrules_output, dxs)
end
"""
    wrap_chainrules_output(::Type{<:AbstractZero})

Type-level conversion: any subtype of `AbstractZero` maps to `Nothing`.
"""
function wrap_chainrules_output(::Type{Z}) where {Z<:AbstractZero}
    return Nothing
end
"""
    wrap_chainrules_output(::Type{NamedTuple{L,T}}) where {L,T}

Type-level conversion for named tuples: keep the field names `L` and convert the
tuple of field types `T` with `wrap_chainrules_output`.
"""
function wrap_chainrules_output(::Type{NamedTuple{L,T}}) where {L,T}
    fieldtuple = wrap_chainrules_output(T)
    return NamedTuple{L,fieldtuple}
end
"""
    wrap_chainrules_output(::Type{T}) where {T<:Tuple}

Type-level conversion for tuple types: apply `wrap_chainrules_output` to every
parameter of `T` and return the `Tuple` type built from the results.
"""
@generated function wrap_chainrules_output(::Type{T}) where {T<:Tuple}
    # A plain loop is used because generator bodies may not contain comprehensions.
    converted = Any[]
    for param in T.parameters
        push!(converted, wrap_chainrules_output(param))
    end
    # Same expression as `:(Tuple{$(converted...)})`.
    return Expr(:curly, :Tuple, converted...)
end
"""
    wrap_chainrules_input(::Type{Nothing})

Type-level conversion: `Nothing` maps to `NoTangent`.
"""
function wrap_chainrules_input(::Type{Nothing})
    # NOTE(review): the value-level `wrap_chainrules_input` in the example above prints
    # `ZeroTangent()` for `nothing`, while this type-level map yields `NoTangent` —
    # confirm the difference is intended.
    return NoTangent
end
"""
    wrap_chainrules_input(::Type{NamedTuple{L,T}}) where {L,T}

Type-level conversion for named tuples: keep the field names `L` and convert the
tuple of field types `T` with `wrap_chainrules_input`.
"""
function wrap_chainrules_input(::Type{NamedTuple{L,T}}) where {L,T}
    fieldtuple = wrap_chainrules_input(T)
    return NamedTuple{L,fieldtuple}
end
"""
    wrap_chainrules_input(::Type{T}) where {T<:Tuple}

Type-level conversion for tuple types: apply `wrap_chainrules_input` to every
parameter of `T` and return the `Tuple` type built from the results.
"""
@generated function wrap_chainrules_input(::Type{T}) where {T<:Tuple}
    # A plain loop is used because generator bodies may not contain comprehensions.
    converted = Any[]
    for param in T.parameters
        push!(converted, wrap_chainrules_input(param))
    end
    # Same expression as `:(Tuple{$(converted...)})`.
    return Expr(:curly, :Tuple, converted...)
end
"""
    z2d(dx::AbstractArray{S}, primal::AbstractArray{P}) where {S,P}

Convert an array of gradients `dx` into an array of `Tangent`s for the elements of `primal`.

When `S` is an isbits `Tuple` or `NamedTuple` type whose converted backing type
`wrap_chainrules_input(S)` is a bits type of the same size, the data is reinterpreted as
`Tangent{P,T}` without copying. In every other case the elements are converted one by one
with `map(z2d, dx, primal)`.
"""
function z2d(dx::AbstractArray{S}, primal::AbstractArray{P}) where {S,P}
    # Only tuple-like element types can serve as a `Tangent` backing; other bits types
    # (e.g. `Nothing`) must not be wrapped in `Tangent{P,T}` and take the element-wise path.
    if S <: Union{Tuple,NamedTuple} && isbitstype(S)
        T = wrap_chainrules_input(S)
        # Reinterpret only when source and target have the same bits layout size;
        # otherwise `reinterpret` would throw.
        if isbitstype(T) && sizeof(T) == sizeof(S)
            return reinterpret(Tangent{P,T}, dx)
        end
    end
    return map(z2d, dx, primal)
end
But at present, pasting this in causes some 2nd derivative tests to fail.