ChainRulesCore.jl

RFC: Rules for real-to-complex and complex-to-real functions

Open sethaxen opened this issue 5 years ago • 2 comments

Consider a function f: ℝᵐ → ℝⁿ that factors through complex intermediates as ℝᵐ → ℂʳ → ℂˢ → ℝⁿ, so that we can write f = f₃ ∘ f₂ ∘ f₁. Typically f₁ will produce a complex output by adding, subtracting, multiplying or dividing a real by a complex number or by calling promote, complex, Complex or cis. Typically f₃ will produce a real output by calling a non-holomorphic function like real, imag, abs, abs2, hypot, or angle.

Following the discussion in https://github.com/JuliaDiff/ChainRulesCore.jl/pull/167, the fact that f has complex intermediates is just an implementation detail. We could equally have defined f: ℝᵐ → ℝ²ʳ → ℝ²ˢ → ℝⁿ, and the pushforwards and pullbacks of this new f should behave the same.

Since in general tangents are derivatives of a primal wrt a real, and cotangents are derivatives of a real wrt a primal, the pushforward through f₁: ℝᵐ → ℂʳ should produce a complex tangent, while the pushforward through f₃: ℂˢ → ℝⁿ should produce a real tangent. Conversely, the pullback through f₃ should produce a complex cotangent, and the pullback through f₁ should produce a real cotangent.
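To make these types concrete, here is a hand-written sketch in plain Julia. It does not use the ChainRulesCore API; the constant c, the functions f1 and f3, and the *_dot / *_bar variable names are all made up for illustration, and f₂ is taken to be the identity.

```julia
# Hypothetical example chain: f = f3 ∘ f1 with f1: ℝ → ℂ and f3: ℂ → ℝ.
c = 1.0 + 2.0im
f1(x::Real) = x * c          # real in, complex out
f3(z::Complex) = abs2(z)     # complex in, real out (non-holomorphic)

x = 3.0
z = f1(x)

# Pushforward: tangents flow forward.
x_dot = 1.0                            # real tangent of the real input
z_dot = x_dot * c                      # complex tangent after f1
f_dot = 2 * real(conj(z) * z_dot)      # real tangent after f3

# Pullback: cotangents flow backward.
f_bar = 1.0                            # real cotangent of the real output
z_bar = 2 * f_bar * z                  # complex cotangent before f3
x_bar = real(z_bar * conj(c))          # real cotangent before f1
```

Here f(x) = abs2(x * c) = 5x², so both f_dot and x_bar should equal the ordinary derivative 10x at x = 3.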

The pushforward case is pretty easy to handle. We can 1) assume that a nonsensical tangent will never be passed and do nothing special (i.e. assume upstream AD did the right thing), or 2) define custom frules that ensure that the tangent produced by unary functions f₃(::Complex)::Real is real.

The pullback case is more complicated. Right now e.g. in Zygote, unless you create a complex number from reals by calling complex, you'll end up pulling back complex numbers through the initial real part of your program, which is not only wasteful but could also break assumptions made by the rrules of upstream functions. For the binary functions f₁, I propose adding custom rrules for f₁(::Real, ::Complex)::Complex and f₁(::Complex, ::Real)::Complex to ensure that the cotangent pulled back to a real primal is actually real.
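A minimal sketch of what such a rule could look like for multiplication, written as a plain Julia closure rather than an actual ChainRulesCore rrule. The name mul_pullback is hypothetical, and the sketch assumes the usual convention that the cotangent of a in a * b is Δ * conj(b).

```julia
# Hypothetical hand-written pullback for y = x * z with x::Real, z::Complex.
# Not the ChainRulesCore API; it only illustrates the projection step.
function mul_pullback(x::Real, z::Complex)
    y = x * z
    function pullback(Δy)
        ∂x_complex = Δy * conj(z)   # what a generic Complex * Complex rule would give
        ∂x = real(∂x_complex)       # project onto the cotangent space of the real primal
        ∂z = Δy * x                 # conj(x) == x for real x
        return ∂x, ∂z
    end
    return y, pullback
end

y, back = mul_pullback(2.0, 1.0 + 3.0im)
∂x, ∂z = back(1.0 + 1.0im)
```

Without the call to real, the value 4.0 - 2.0im would flow back into the real part of the program; with it, only the real part 4.0 does, which matches what one gets by treating z as a pair of reals.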

This came up as a point of discussion in JuliaDiff/ChainRules.jl#196, and I would appreciate feedback so we can clarify our conventions here.

sethaxen, Jun 25 '20 21:06

See also this issue about why Zygote doesn't do this (tl;dr: Zygote basically treats all reals as embedded in the complex numbers): https://github.com/FluxML/Zygote.jl/issues/342 and an update here: https://github.com/FluxML/Zygote.jl/issues/472

Also ccing @MikeInnes because this could change the behavior of Zygote.

sethaxen, Jun 26 '20 00:06

+1 for the pushforward / pullback of f(::Real)::Real with real sensitivities / adjoints to be real, i.e. for ChainRules to stay real if all inputs are real. This convention allows complex pushforwards / pullbacks to be obtained using

invoke(frule, Tuple{typeof(Δx), typeof(f), complex(typeof(x))}, Δx, f, x)
invoke(rrule, Tuple{typeof(f), complex(typeof(x))}, f, x)

Conversely, if we defined pushforwards / pullbacks to be complex for some functions f(::Real)::Real, then there would be no way to get the real version of the derivatives.


Edit: Actually, this does not work since !(Real <: Complex) :unamused:
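A small sketch of why the invoke call fails, using a made-up two-method function rule as a stand-in for frule / rrule (the name and its return values are hypothetical). invoke requires every argument to be an instance of the type requested in the signature, and a real number is not an instance of any Complex type.

```julia
# Hypothetical stand-in for a rule with separate real and complex methods.
rule(x::Real) = "real method"
rule(x::Complex) = "complex method"

x = 1.5

real_is_complex = Real <: Complex            # false: both are subtypes of Number
x_is_complex = x isa complex(typeof(x))      # false: Float64 is not Complex{Float64}

# invoke checks the arguments against the requested signature, so this throws.
invoke_failed = try
    invoke(rule, Tuple{complex(typeof(x))}, x)
    false
catch
    true
end

# Converting the value does reach the complex method, at the cost of a conversion.
converted_result = rule(complex(x))
```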

Of course, a similar effect could be achieved using

frule(Δx, f, complex(x))
rrule(f, complex(x))

but this incurs some runtime penalty. In the case of Complex vs Real, this penalty would probably be acceptable in most circumstances, but I've been thinking that the same approach could also help with related issues for AbstractArrays, see e.g. https://github.com/JuliaDiff/ChainRules.jl/issues/191 and https://github.com/JuliaDiff/ChainRules.jl/issues/52. The problem there is that it is not clear whether the adjoint of e.g. f(A, B) = A*B with respect to A::Diagonal should be a Diagonal or a Matrix. If the above invoke worked, then that would provide an interface for clarifying the intent: rrule(*, A::Diagonal, B)[2](ΔC)[2] would be a Diagonal (the first entry of the pullback's output is the cotangent for * itself), and if you wanted a Matrix instead then you could invoke the Matrix method of the rrule. And in this case, it would clearly not be acceptable to call rrule(*, Matrix(A::Diagonal), B).
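To illustrate the two candidate answers, here is a sketch using only the LinearAlgebra standard library; no ChainRules rule is called, and the variable names are made up. It assumes the standard dense formula ΔA = ΔC * B' for the cotangent of A in C = A * B.

```julia
using LinearAlgebra

A = Diagonal([1.0, 2.0])
B = [1.0 2.0; 3.0 4.0]
ΔC = [1.0 0.0; 0.0 1.0]        # cotangent of C = A * B

# Dense cotangent: treats A as a point in the space of all 2×2 matrices.
ΔA_dense = ΔC * B'

# Structured cotangent: keeps only the entries along which a Diagonal can vary.
ΔA_diag = Diagonal(ΔA_dense)

# For any diagonal perturbation δA, both give the same directional derivative.
δA = Diagonal([0.5, -1.0])
dense_pairing = tr(ΔA_dense' * δA)
diag_pairing = tr(ΔA_diag' * δA)
```

The two cotangents agree on every perturbation that keeps A diagonal, so the choice between them is about which space A is considered to live in, not about correctness of the derivative.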

ettersi, Jun 26 '20 03:06