Zygote.jl icon indicating copy to clipboard operation
Zygote.jl copied to clipboard

Needless copies in `wrap_chainrules_input` and friends

Open mcabbott opened this issue 2 years ago • 0 comments

The conversion to and from ChainRules types should be essentially free for single structs, but for arrays of structs, it at present involves a copy in each direction. For example:

julia> zyg = [(1,2,nothing), (3,4,nothing)];

julia> Zygote.wrap_chainrules_input(ans)
2-element Vector{ChainRulesCore.Tangent{Any, Tuple{Int64, Int64, ChainRulesCore.ZeroTangent}}}:
 Tangent{Any}(1, 2, ChainRulesCore.ZeroTangent())
 Tangent{Any}(3, 4, ChainRulesCore.ZeroTangent())

julia> Zygote.wrap_chainrules_output(ans)
2-element Vector{Tuple{Int64, Int64, Nothing}}:
 (1, 2, nothing)
 (3, 4, nothing)

The arrays here have exactly the same contents, so ideally this would be done by reinterpreting the data. The code needed looks something like this:

@inline function wrap_chainrules_output(dxs::AbstractArray{<:ChainRules.Tangent{<:Any, B}}) where {B}
  if isbitstype(B)
    # B is the backing type. It still contains NoTangent etc, which need converting to Nothing
    reinterpret(wrap_chainrules_output(B), dxs)
  else
    map(wrap_chainrules_output, dxs)
  end
end
wrap_chainrules_output(::Type{<:AbstractZero}) = Nothing
wrap_chainrules_output(::Type{NamedTuple{L,T}}) where {L,T} = NamedTuple{L,wrap_chainrules_output(T)}
@generated function wrap_chainrules_output(::Type{T}) where T<:Tuple
  inner = map(wrap_chainrules_output, T.parameters)
  :(Tuple{$(inner...)})
end

wrap_chainrules_input(::Type{Nothing}) = NoTangent
wrap_chainrules_input(::Type{NamedTuple{L,T}}) where {L,T} = NamedTuple{L,wrap_chainrules_input(T)}
@generated function wrap_chainrules_input(::Type{T}) where T<:Tuple
  inner = map(wrap_chainrules_input, T.parameters)
  :(Tuple{$(inner...)})
end

function z2d(dx::AbstractArray{S}, primal::AbstractArray{P}) where {S,P}
  if isbitstype(S)
    T = wrap_chainrules_input(S)
    reinterpret(Tangent{P,T}, dx)
  else
    map(z2d, dx, primal)
  end
end

But at present, pasting this in causes some 2nd derivative tests to fail.

mcabbott avatar Oct 27 '21 16:10 mcabbott