Extend verify_grad to complex gradient
Description
Extend verify_grad to complex gradients, following the holomorphic gradient convention (as in JAX). The decision on which convention to follow (JAX-like or torch-like) has not been made yet; see issue #1366.
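For context, a minimal sketch of what such a check means numerically (illustration only, not the actual verify_grad code; the helper name is made up): under the holomorphic convention the gradient reported for a holomorphic $f$ is $f'(z)$, which can be cross-checked with central differences along the real and imaginary axes.

```python
import numpy as np

def numeric_holomorphic_grad(f, z, eps=1e-6):
    # Central difference along the real axis: approximates f'(z) directly.
    d_real = (f(z + eps) - f(z - eps)) / (2 * eps)
    # Central difference along the imaginary axis: for a holomorphic f this
    # must give the same value (Cauchy-Riemann equations).
    d_imag = (f(z + 1j * eps) - f(z - 1j * eps)) / (2j * eps)
    return d_real, d_imag

def f(z):
    return z**3  # holomorphic test function with analytic gradient 3 z^2

z0 = 1.5 - 0.5j
d_real, d_imag = numeric_holomorphic_grad(f, z0)
assert np.allclose(d_real, d_imag)      # holomorphicity (Cauchy-Riemann) check
assert np.allclose(d_real, 3 * z0**2)   # matches the analytic gradient
```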
Related Issue
- [ ] Closes #
- [x] Related to #1366
Checklist
- [x] Checked that the pre-commit linting/style checks pass
- [ ] Included tests that prove the fix is effective or that the new feature works
- [ ] Added necessary documentation (docstrings and/or example notebooks)
- [ ] If you are a pro: each commit corresponds to a relevant logical change
Type of change
- [x] New feature / enhancement
- [ ] Bug fix
- [ ] Documentation
- [ ] Maintenance
- [ ] Other (please specify):
📚 Documentation preview 📚: https://pytensor--1367.org.readthedocs.build/en/1367/
Codecov Report
:x: Patch coverage is 80.00000% with 4 lines in your changes missing coverage. Please review.
:white_check_mark: Project coverage is 82.05%. Comparing base (f1514eb) to head (806fad1).
:warning: Report is 293 commits behind head on main.
| Files with missing lines | Patch % | Lines |
|---|---|---|
| pytensor/gradient.py | 80.00% | 1 Missing and 3 partials :warning: |
:x: Your patch check has failed because the patch coverage (80.00%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.
Additional details and impacted files
Coverage Diff
| | main | #1367 | +/- |
|---|---|---|---|
| Coverage | 82.05% | 82.05% | |
| Files | 203 | 203 | |
| Lines | 48863 | 48875 | +12 |
| Branches | 8695 | 8696 | +1 |
| Hits | 40093 | 40103 | +10 |
| Misses | 6619 | 6620 | +1 |
| Partials | 2151 | 2152 | +1 |
| Files with missing lines | Coverage Δ | |
|---|---|---|
| pytensor/gradient.py | 78.62% <80.00%> (+0.07%) | :arrow_up: |
The link in the original issue argues strongly for the non-JAX approach. Why did JAX go this way? Do the arguments they make hold for higher-order auto-diff?
Do you also want to modify one Op in this PR to show it working for complex gradients?
I'm afraid I've raised an issue which is beyond my knowledge... I'll try to summarize what I've understood so far though.
In the case of a real-valued function, jax and torch are equivalent (up to a conjugate). However, when the function is complex-valued, torch cannot compute the derivative: its chain rule assumes that $\frac{\partial L}{\partial z^\star} = \left(\frac{\partial L}{\partial z}\right)^\star$, which is true when $L$ is real-valued but not in general.
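A quick sanity check of that statement, using the Wirtinger derivatives $\frac{\partial}{\partial z} = \frac{1}{2}\left(\frac{\partial}{\partial x} - i\frac{\partial}{\partial y}\right)$ and $\frac{\partial}{\partial z^\star} = \frac{1}{2}\left(\frac{\partial}{\partial x} + i\frac{\partial}{\partial y}\right)$: for the real-valued $L(z) = z z^\star$ the identity holds, while for the complex-valued (even holomorphic) $L(z) = z$ it fails:

$$L(z) = z z^\star:\qquad \frac{\partial L}{\partial z} = z^\star,\qquad \frac{\partial L}{\partial z^\star} = z = \left(\frac{\partial L}{\partial z}\right)^\star$$

$$L(z) = z:\qquad \frac{\partial L}{\partial z} = 1,\qquad \frac{\partial L}{\partial z^\star} = 0 \neq \left(\frac{\partial L}{\partial z}\right)^\star = 1$$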
The internals of jax go way over my head, but the docs suggest that it internally computes the full derivative of $f(x+iy) = u(x,y) + i v(x,y)$ as if it were an $\mathbb R^2 \to \mathbb R^2$ function, that is to say the 2x2 $\mathbb R$-matrix of partial derivatives of $u$ and $v$, and returns $\frac{\partial u}{\partial x} - i\frac{\partial u}{\partial y}$ as the gradient. Whenever $f$ is real-valued ($v=0$), this expression equals $2\frac{\partial f}{\partial z}$ (i.e. the conjugate of what torch returns), and whenever $f$ is holomorphic it equals $\frac{\partial f}{\partial z}$. So clearly jax has a broader scope than torch.
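A small JAX sketch of the two cases; the expected values in the comments follow from the formula $\frac{\partial u}{\partial x} - i\frac{\partial u}{\partial y}$ quoted above, so treat them as an illustration of the convention rather than a specification.

```python
import jax
import jax.numpy as jnp

z0 = jnp.asarray(1.0 + 2.0j)

# Holomorphic case: grad(..., holomorphic=True) returns the complex derivative f'(z).
f_holo = lambda z: z**2
print(jax.grad(f_holo, holomorphic=True)(z0))  # expected 2*z0 = 2+4j

# Real-valued case: f(z) = |z|^2, i.e. u(x, y) = x^2 + y^2 and v = 0.
f_real = lambda z: jnp.abs(z) ** 2
# Per the convention above this should be du/dx - i du/dy = 2x - 2iy = 2*conj(z0),
# i.e. 2*(df/dz) -- the conjugate of what torch would return (2*z0).
print(jax.grad(f_real)(z0))
```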
In order to deal with higher-order auto-diff, I reckon we need differentiation with respect to both $z$ and $z^\star$. Even if the function is real-valued, its gradient is complex-valued and not necessarily holomorphic (I don't know how to do that in jax...).
Regarding pytensor, I think that if the grads argument of L_op contained the derivatives with respect to both $z$ and $z^\star$ (and L_op returned both as well), then it would be possible to compute the chain rule in the general complex case.
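For what it's worth, here is a rough NumPy sketch of that bookkeeping (the name `wirtinger_l_op` and its signature are made up for illustration; this is not PyTensor's actual L_op API): an op that knows its Wirtinger derivatives $\partial z/\partial w$ and $\partial z/\partial w^\star$ can propagate the pair $(\partial L/\partial z, \partial L/\partial z^\star)$ backwards with the general chain rule.

```python
import numpy as np

def wirtinger_l_op(dz_dw, dz_dwbar, gz, gzbar):
    """Backward step for z = f(w): map the cotangent pair
    (gz, gzbar) = (dL/dz, dL/dz*) to (dL/dw, dL/dw*), given the op's
    Wirtinger derivatives dz_dw = dz/dw and dz_dwbar = dz/dw*."""
    # Uses dz*/dw = conj(dz/dw*) and dz*/dw* = conj(dz/dw).
    gw = gz * dz_dw + gzbar * np.conj(dz_dwbar)
    gwbar = gz * dz_dwbar + gzbar * np.conj(dz_dw)
    return gw, gwbar

# Example: L(w) = |w**2|**2 = |w|**4, written as L = z * z* with z = w**2.
w = 1.0 + 2.0j
z = w**2
gz, gzbar = np.conj(z), z                           # dL/dz = z*, dL/dz* = z
gw, gwbar = wirtinger_l_op(2 * w, 0.0, gz, gzbar)   # w**2 is holomorphic: dz/dw* = 0

# Direct check on L(w) = |w|^4: dL/dw = 2 w (w*)^2 and dL/dw* = 2 w^2 w*.
assert np.allclose(gw, 2 * w * np.conj(w) ** 2)
assert np.allclose(gwbar, 2 * w**2 * np.conj(w))
assert np.allclose(gwbar, np.conj(gw))  # L is real-valued, so the pair is conjugate
```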
I realize that it is much more complicated than I thought initially...
Just for reference, some old Theano discussions on the topic:
https://github.com/Theano/Theano/issues/3537
https://web.archive.org/web/20180714173600/http://deeplearning.net/software/theano/proposals/complex_gradient.html