ChainRulesCore.jl
WIP: Wirtinger support
As discussed in #40, Wirtinger support is going to be moved out of master for now. I'm going to start working on it in this branch, but might eventually decide to move this into a package.
My current plans for this so far are:
- [x] remove type constraints for `Wirtinger` (since pretty much anything can be a differential, these don't make much sense to me anymore)
- [x] ~~make `wirtinger_[primal|conjugate]` recursive, to work better for things like `Thunk`s~~ add a function `unthunk` for this instead
- [x] introduce a type `ComplexGradient`, which works like Zygote's complex derivatives, to address #23 and make porting Zygote to ChainRules easier (still needs docstrings)
- [x] introduce an abstract type `AbstractWirtinger`, of which both `Wirtinger` and `ComplexGradient` are subtypes (still needs docstrings)
- [x] add a function `chain`, which works mostly like `*` but respects which argument is the derivative of the outer function and which is the derivative of the inner one; this matters for `AbstractWirtinger` differentials (still needs docstrings)
- [x] simplify `@scalar_rule` using this `chain` function
- [x] disallow multiplication of complex numbers with `AbstractWirtinger` and require users to use `chain`, since the order of chaining matters here, too (still needs a better error)
- [ ] tests
- [ ] docs
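Why `chain` has to care about argument order can be sketched outside Julia. The Python below is a minimal illustration, not ChainRulesCore code: the `Wirtinger` dataclass, its field names `primal`/`conjugate` (borrowed from `wirtinger_[primal|conjugate]`), and the `chain(outer, inner)` argument order are assumptions made for this example. It applies the Wirtinger chain rule for h = f ∘ g, namely ∂h/∂w = f_z·g_w + f_z̄·conj(g_w̄) and ∂h/∂w̄ = f_z·g_w̄ + f_z̄·conj(g_w), and compares the result with a finite-difference estimate.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Wirtinger:
    """Hypothetical stand-in for a Wirtinger differential: the pair
    (df/dz, df/dz̄) of a function f: C -> C at some point."""
    primal: complex     # df/dz
    conjugate: complex  # df/dz̄


def chain(outer: Wirtinger, inner: Wirtinger) -> Wirtinger:
    """Chain rule for h = f ∘ g: `outer` holds f's derivatives (at g(w)),
    `inner` holds g's derivatives (at w). Uses the identities
    d(conj g)/dw = conj(dg/dw̄) and d(conj g)/dw̄ = conj(dg/dw)."""
    a, b = outer.primal, outer.conjugate
    c, d = inner.primal, inner.conjugate
    return Wirtinger(a * c + b * d.conjugate(),
                     a * d + b * c.conjugate())


def numeric_wirtinger(f, w: complex, h: float = 1e-6) -> Wirtinger:
    """Central finite-difference estimate of (df/dz, df/dz̄) at w, using
    df/dz = (f_x - i f_y) / 2 and df/dz̄ = (f_x + i f_y) / 2."""
    fx = (f(w + h) - f(w - h)) / (2 * h)
    fy = (f(w + 1j * h) - f(w - 1j * h)) / (2 * h)
    return Wirtinger((fx - 1j * fy) / 2, (fx + 1j * fy) / 2)


w = 1.0 + 2.0j
inner = Wirtinger(2 * w, 0j)   # g(w) = w**2 is holomorphic, so dg/dw̄ = 0
outer = Wirtinger(0j, 1 + 0j)  # f(z) = conj(z): df/dz = 0, df/dz̄ = 1

print(chain(outer, inner))  # derivatives of h(w) = conj(w**2)
print(chain(inner, outer))  # swapped arguments: a different (wrong) result
print(numeric_wirtinger(lambda z: (z * z).conjugate(), w))
```

For w = 1+2i the correct ∂h/∂w̄ is conj(2w) = 2-4i, while swapping the arguments yields 2+4i, which is why an order-agnostic `*` is not enough for these differentials.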
I always appreciate any feedback.
Thinking about making `ComplexGradient` hold two real parameters instead of one complex one, so we could make use of `Zero` in some places. That would complicate some arithmetic, though.
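One way the two-real-parameter idea could look is sketched below in Python. Everything here is hypothetical: `ZeroType`/`ZERO` only mimic the role of ChainRulesCore's `Zero()`, and `ComplexGradient2` assumes the gradient of a real-valued f(z), z = x + iy, is stored as the pair (∂f/∂x, ∂f/∂y). A function that ignores the imaginary part can then carry a strong zero in that slot, and addition skips it without doing any work.

```python
from dataclasses import dataclass


class ZeroType:
    """Hypothetical strong zero mimicking `Zero()`: adding it returns the
    other operand unchanged, and scaling it stays zero."""

    def __add__(self, other):
        return other

    __radd__ = __add__

    def __mul__(self, other):
        return self

    __rmul__ = __mul__

    def __repr__(self):
        return "Zero()"


ZERO = ZeroType()


@dataclass(frozen=True)
class ComplexGradient2:
    """Hypothetical layout: re = df/dx, im = df/dy for a real-valued f of
    z = x + iy. Either field may be ZERO."""
    re: object
    im: object

    def __add__(self, other):
        # Field-wise addition; ZERO fields cost nothing.
        return ComplexGradient2(self.re + other.re, self.im + other.im)

    def scale(self, s: float):
        # Scaling by a real number keeps ZERO fields as ZERO.
        return ComplexGradient2(s * self.re, s * self.im)

    def to_complex(self) -> complex:
        # Materialize as a single complex number df/dx + i df/dy.
        re = 0.0 if self.re is ZERO else self.re
        im = 0.0 if self.im is ZERO else self.im
        return complex(re, im)


x, y = 3.0, -1.0
grad_re_sq = ComplexGradient2(2 * x, ZERO)  # f(z) = real(z)^2 ignores y
grad_abs2 = ComplexGradient2(2 * x, 2 * y)  # g(z) = |z|^2
total = grad_re_sq + grad_abs2
print(total, total.to_complex())
```

Only addition and real scaling are shown. Multiplying by a general complex number p + iq mixes the two fields (re' = p·re - q·im, im' = q·re + p·im), so every such operation would need `Zero`-aware handling, which is the extra arithmetic cost of this layout.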
I've also written down some thoughts about this in FluxML/Zygote.jl#328. I still need to finish that, though, and will then try to adapt some of it for the docs here.
I have not forgotten about this and started reading it again yesterday.
> disallow multiplication of complex numbers with `AbstractWirtinger` and require users to use `chain`, since the order of chaining matters here, too (Still needs a better error)
Why disallow `::Wirtinger * ::Complex`? How would a user interact with the `Wirtinger` object if there is no generic function defined?
Most of the time, users shouldn't interact with `Wirtinger` objects at all. One of the main use cases I see for them is as intermediate representations in mixed-mode AD implementations. If a library wants to expose that functionality to users, its authors can add their own abstraction on top that best fits their particular needs.
To understand why we don't want to multiply complex numbers and `Wirtinger` objects, we can think of `Wirtinger` objects as 2x2 real Jacobians in a different basis, where two real numbers just happen to be represented by one complex number. In this picture, it is clear that these objects form only a real vector space and that multiplication by a complex number is not well defined. We can still define an injective homomorphism from the complex numbers to `Wirtinger` objects by mapping `z` to `Wirtinger(z, Zero())`, but the images no longer commute with all `Wirtinger` objects (for example with complex conjugation, `Wirtinger(Zero(), 1)`), which would be required for a complex vector space. JuliaDiff/ChainRules.jl#133 and JuliaDiff/ChainRules.jl#135 make it quite clear that subtypes of `AbstractDifferential` should form a vector space, so I believe disallowing `::Wirtinger * ::Complex` is the only reasonable thing to do here.
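The 2x2-Jacobian view can be checked numerically. In the Python sketch below (again with a hypothetical `Wirtinger` pair and `chain(outer, inner)` order, not the ChainRulesCore types), a pair (a, b) stands for the real-linear map dz ↦ a·dz + b·conj(dz), whose Jacobian in the (dx, dy) basis is [[Re(a+b), -Im(a-b)], [Im(a+b), Re(a-b)]]. Chaining then matches matrix multiplication, and the embedded scalar i (here `Wirtinger(1j, 0j)`, playing the role of `Wirtinger(z, Zero())`) fails to commute with conjugation.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Wirtinger:
    """Hypothetical (df/dz, df/dz̄) pair; not the ChainRulesCore type."""
    primal: complex
    conjugate: complex


def chain(outer: Wirtinger, inner: Wirtinger) -> Wirtinger:
    # Wirtinger chain rule, outer function's derivatives first.
    a, b = outer.primal, outer.conjugate
    c, d = inner.primal, inner.conjugate
    return Wirtinger(a * c + b * d.conjugate(), a * d + b * c.conjugate())


def jacobian(wt: Wirtinger):
    """Real 2x2 Jacobian of dz -> a*dz + b*conj(dz), acting on (dx, dy)."""
    a, b = wt.primal, wt.conjugate
    p, q = a + b, a - b
    return [[p.real, -q.imag],
            [p.imag, q.real]]


def matmul(A, B):
    # Plain 2x2 matrix product.
    return [[sum(A[i][k] * B[k][j] for k in range(2)) for j in range(2)]
            for i in range(2)]


def is_close(A, B, tol=1e-12):
    return all(abs(A[i][j] - B[i][j]) < tol
               for i in range(2) for j in range(2))


scalar = Wirtinger(1j, 0j)       # embedding of z = i: multiplication by i
conj_map = Wirtinger(0j, 1 + 0j)  # complex conjugation, a real-linear map

# Chaining corresponds to multiplying the real Jacobians ...
assert is_close(jacobian(chain(scalar, conj_map)),
                matmul(jacobian(scalar), jacobian(conj_map)))

# ... but the embedded scalar does not commute with conjugation.
print(chain(scalar, conj_map))
print(chain(conj_map, scalar))
```

The two printed pairs differ only in the sign of their `conjugate` field (i versus -i), so "multiply by i" is not a well-defined scalar action on these objects: it depends on which side it is applied from.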
@simeonschaub what's the status of this? Has it been completely superseded by the other complex-number work?
What has been superseded is `ComplexGradient`, now that we just use `Adjoint` for that. People still seem to be interested in Wirtinger derivatives, though, so this might be something to consider for v2. I currently don't have much use for this anymore, but if someone wanted to push this forward, I would certainly offer to help wherever I can.