
Extend verify_grad to complex gradient

Open · educhesne opened this pull request 8 months ago · 4 comments

Description

Extend verify_grad to complex gradients, following the holomorphic gradient convention (as in JAX). The decision on which convention to follow (JAX-like or torch-like) has not been made yet; see issue #1366.
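
A hypothetical sketch of the intended usage once this lands (the complex support, and which convention it follows, are exactly what this PR and #1366 are deciding); under the JAX-like holomorphic convention, a holomorphic graph such as `z**2` should verify against the numeric derivative at a complex test point:

```python
import numpy as np
from pytensor.gradient import verify_grad

rng = np.random.default_rng(42)
z = (rng.normal(size=(3,)) + 1j * rng.normal(size=(3,))).astype("complex128")

# Hypothetical once complex support lands: under the holomorphic (JAX-like)
# convention, the symbolic gradient of a holomorphic function like z**2
# should match the numeric gradient at a complex point.
verify_grad(lambda z: (z**2).sum(), [z], rng=rng)
```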

Related Issue

  • [ ] Closes #
  • [x] Related to #1366

Checklist

Type of change

  • [x] New feature / enhancement
  • [ ] Bug fix
  • [ ] Documentation
  • [ ] Maintenance
  • [ ] Other (please specify):

📚 Documentation preview 📚: https://pytensor--1367.org.readthedocs.build/en/1367/

educhesne · Apr 14 '25 17:04

Codecov Report

❌ Patch coverage is 80.00000% with 4 lines in your changes missing coverage. Please review.
✅ Project coverage is 82.05%. Comparing base (f1514eb) to head (806fad1).
⚠️ Report is 293 commits behind head on main.

| Files with missing lines | Patch % | Lines |
|--------------------------|---------|-------|
| pytensor/gradient.py | 80.00% | 1 Missing and 3 partials ⚠️ |

❌ Your patch check has failed because the patch coverage (80.00%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files

Impacted file tree graph

@@           Coverage Diff           @@
##             main    #1367   +/-   ##
=======================================
  Coverage   82.05%   82.05%           
=======================================
  Files         203      203           
  Lines       48863    48875   +12     
  Branches     8695     8696    +1     
=======================================
+ Hits        40093    40103   +10     
- Misses       6619     6620    +1     
- Partials     2151     2152    +1     
| Files with missing lines | Coverage Δ |
|--------------------------|------------|
| pytensor/gradient.py | 78.62% <80.00%> (+0.07%) ⬆️ |

codecov[bot] · Apr 14 '25 18:04

The link in the original issue argues strongly for the non-JAX approach. Why did JAX go this way? Do the arguments they make hold for higher-order auto-diff?

Also, do you want to modify one Op in this PR to show it working for complex gradients?

ricardoV94 · Apr 29 '25 19:04

I'm afraid I've raised an issue which is beyond my knowledge... I'll try to summarize what I've understood so far though.

In the case of a real-valued function, jax and torch are equivalent (up to a conjugate). However, when the function is complex-valued, torch cannot compute the derivative: the chain rule in torch's gradient assumes that $\frac{\partial L}{\partial z^\star} = \left(\frac{\partial L}{\partial z}\right)^\star$, which is true when $L$ is real-valued but not in general.
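
This is easy to check numerically. A minimal sketch (plain NumPy central differences, not from this PR) computing both Wirtinger derivatives:

```python
import numpy as np

def wirtinger(f, z, h=1e-6):
    """Numeric Wirtinger derivatives (df/dz, df/dz*) via central differences."""
    dfdx = (f(z + h) - f(z - h)) / (2 * h)
    dfdy = (f(z + 1j * h) - f(z - 1j * h)) / (2 * h)
    return 0.5 * (dfdx - 1j * dfdy), 0.5 * (dfdx + 1j * dfdy)

z = 1.0 + 2.0j

# Real-valued L = |z|^2: dL/dz* == conj(dL/dz), so torch's assumption holds.
dz, dzbar = wirtinger(lambda z: abs(z) ** 2, z)
print(dz, dzbar, np.conj(dz))   # (1-2j), (1+2j), (1+2j)

# Complex-valued f = z^2: the identity fails (df/dz* = 0 but df/dz = 2z).
dz, dzbar = wirtinger(lambda z: z ** 2, z)
print(dz, dzbar)                # (2+4j), ~0
```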

The internals of jax go way over my head, but the doc suggests that it internally computes the full derivative of $f(x+iy) = u(x,y) + i v(x,y)$ as if it were an $\mathbb{R}^2 \to \mathbb{R}^2$ function, that is to say the $2\times 2$ real matrix of partial derivatives of $u$ and $v$, and returns $\frac{\partial u}{\partial x} - i\frac{\partial u}{\partial y}$ as the gradient. Whenever $f$ is real-valued ($v=0$), this expression equals $2\frac{\partial f}{\partial z}$ (i.e. the conjugate of what torch returns), and whenever $f$ is holomorphic it equals $\frac{\partial f}{\partial z}$. So jax clearly has a broader scope than torch.
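
Both cases can be seen directly in jax (a quick illustration, not part of this PR):

```python
import jax
import jax.numpy as jnp

z = jnp.asarray(1.0 + 2.0j)

# Real-valued f (v = 0): grad returns du/dx - i du/dy = 2 df/dz.
f_real = lambda z: jnp.abs(z) ** 2            # u = x^2 + y^2
print(jax.grad(f_real)(z))                    # 2 - 4j, i.e. 2 * conj(z)

# Holomorphic f: with holomorphic=True, grad returns df/dz.
f_holo = lambda z: z ** 2
print(jax.grad(f_holo, holomorphic=True)(z))  # 2 + 4j, i.e. 2 * z
```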

In order to deal with higher-order auto-diff, I reckon we need differentiation with respect to both $z$ and $z^\star$. Even if the function is real-valued, its gradient is complex-valued and not necessarily holomorphic (I don't know how to do that in jax...).

Regarding pytensor, I think that if the grads arguments of L_op contained the derivatives with respect to both $z$ and $z^\star$ (and L_op returned both as well), then it would be possible to compute the chain rule in the general complex case. I realize that it is much more complicated than I initially thought...
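
For reference, writing out the general Wirtinger chain rule makes this concrete: with $w = f(z)$ and $L = g(w, w^\star)$, using $\frac{\partial w^\star}{\partial z} = \overline{\left(\frac{\partial w}{\partial z^\star}\right)}$ and $\frac{\partial w^\star}{\partial z^\star} = \overline{\left(\frac{\partial w}{\partial z}\right)}$,

$$
\frac{\partial L}{\partial z} = \frac{\partial g}{\partial w}\frac{\partial w}{\partial z} + \frac{\partial g}{\partial w^\star}\overline{\left(\frac{\partial w}{\partial z^\star}\right)}, \qquad
\frac{\partial L}{\partial z^\star} = \frac{\partial g}{\partial w}\frac{\partial w}{\partial z^\star} + \frac{\partial g}{\partial w^\star}\overline{\left(\frac{\partial w}{\partial z}\right)}.
$$

So a general complex L_op has to propagate both cotangents; torch's assumption $\frac{\partial L}{\partial z^\star} = \left(\frac{\partial L}{\partial z}\right)^\star$ is precisely what lets it carry only one.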

educhesne · May 11 '25 20:05

Just for reference, some old theano discussions on the topic:

https://github.com/Theano/Theano/issues/3537

https://web.archive.org/web/20180714173600/http://deeplearning.net/software/theano/proposals/complex_gradient.html

ricardoV94 · May 11 '25 21:05