Bessel Functions to Match scipy
Dear Jax team,
There is an issue (https://github.com/google/jax/issues/12402) about problems with Bessel function support in core Jax. In particular, the Bessel Jv functions (and especially J0 and J1) are the basis functions of the Hankel transform (radially symmetric Fourier transform) and crop up everywhere in optics and computer vision, so it is very important to get these right.
As noted on that issue, the existing Jax implementation is not numerically stable. I have copied the C implementation used in scipy exactly and made this pull request. See demo: https://github.com/benjaminpope/sibylla/blob/main/notebooks/bessel_test.ipynb
In this PR are:
- functions under `jax.scipy.special`: [`bessel_j0`, `bessel_j1`, `bessel_jn`], with implementation under `_src` as normal. These call helper functions `j0_small`, `j0_large`, `j1_small`, `j1_large` that are not exposed to the user.
- a basic unit test under `tests/third_party/test_bessel.py`
Things that I would like help with:
- we use the Bessel recurrence relation to generate arbitrary higher-order Bessel functions Jv for v >= 2, bootstrapping from J0 and J1. But if you call `bessel_jn` with v < 2, it will currently fail. I don't know how to make it default to just calling J0 or J1 as appropriate without some horrible `where` calls, but I bet the jax ninjas on here will know the right syntax.
- the unit test is basic - to be honest, I don't understand the syntax used in the core unit tests for `jax.scipy.special`, and perhaps someone can help with this!
Caveats
- The higher-order `bessel_jn` for large n start to diverge from scipy for x << 1 - only by up to about 1e-6, but still much worse than the machine precision that holds everywhere else. This is due to accumulated numerical error in the recurrence relation. It's still better than the existing implementation, but perhaps we can improve on this.
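For context, the error growth comes from the forward recurrence J_{k+1}(x) = (2k/x) J_k(x) - J_{k-1}(x), which amplifies rounding error once k exceeds |x|. A bare sketch of that recurrence (seed values J0(x), J1(x) supplied by the caller):

```python
import jax.numpy as jnp

def jn_forward(n, j0_val, j1_val, x):
    # Forward recurrence J_{k+1}(x) = (2k/x) J_k(x) - J_{k-1}(x),
    # seeded with caller-supplied J0(x) and J1(x). Rounding error is
    # amplified once k exceeds |x|, which is the small-x divergence
    # described above.
    jkm1, jk = jnp.asarray(j0_val), jnp.asarray(j1_val)
    for k in range(1, n):
        jkm1, jk = jk, (2.0 * k / x) * jk - jkm1
    return jk
```

Running the same relation backwards from a high trial order and normalising (Miller's algorithm) is the standard remedy for small x.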
Hi,
I'm quite interested in this new implementation. It would be wonderful if the new code could be differentiated with jacrev, at least wrt the z argument.
For this, it might be worth defining custom JVPs using recurrence identities. We do something similar in this example: https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html#example-pure-callback-with-custom-jvp
Hi, I have updated my Colab notebook, where I show an implementation of J0 and J1 that matches scipy.special.j0 & j1 and supports jit/vmap/jacrev, in case anyone wants to compare against this PR.
Just bumping this @jakevdp - I think the recurrence identities for higher integer Jv are not necessarily numerically stable. I think float-order Jv is probably in the too-hard basket right now, but @shashankdholakia might be interested in it for some of his applications.
Are you happy if we do a minimal PR on this of just J0, J1 this time?
Hi @jakevdp,
Are we happy to merge this?
Cheers,
Ben
Sorry, I'd been waiting for the outstanding issues here to be addressed: for example, there's still a notebook that needs to be removed, and I thought the plan was to replace the existing implementation with this one (and add tests to ensure it has the same API as the scipy.special versions).
What do you think? Does that sound reasonable?
Also, I don't think this is complete without usable autodiff rules - autodiff should not have to trace through the implementation. I think custom JVP rules based on the derivative recurrence relations are probably the best answer to that.
Hi all, thanks for all the great work on this! I personally use Bessel functions quite often and would love to see them added to JAX. I've implemented a version of `bessel_jn` that should hopefully match the scipy implementation for higher orders. It uses a combination of a power-series expansion of the Bessel function at small z, the backwards recurrence relation currently in `jax.scipy.special` with some small modifications, and the forward recurrence relation method proposed here. The combination of methods passes the np.allclose test against scipy for all v and z I've tested at float64 precision. It's also vectorized to match `scipy.special.jn`, and I added a custom JVP using the derivative recurrence relation. Here's the notebook I've been testing in. Not sure if it makes sense to add it here or make a separate PR. Thanks!
Very interested in using this for the reasons mentioned by @benjaminpope. @jakevdp Is there any working code for at least low order Bessel functions, e.g., J0, J1, Y0, Y1? Any insight is appreciated.
The existing Bessel function implementations in JAX are all in jax.scipy.special, namely i0, i0e, i1, i1e, and bessel_jn.
@jakevdp Thank you for your prompt reply. Unfortunately, I need BesselK (orders 0 and 1) evaluated at complex inputs. It seems the bessel_jn function you referenced does not support complex inputs, and even if it did, I would also need the BesselY function to construct BesselK. I believe the i0, i0e, i1, and i1e functions don't help with this either. I have found an implementation of BesselK in TensorFlow, but it also does not support complex inputs. Jax would likely find more immediate applications in physics if there was, for example, an equivalent to the scipy.special.kv function, which does support complex inputs. I am curious to know if @benjaminpope 's implementation resolves these issues and if so, will it be merged in the near future?
I don't know of any concrete plans to add more bessel implementations to the JAX core library.
I've also implemented spherical Bessel functions with my student @shashankdholakia, with custom vjp - if you don't want these in core jax, do we just want a separate repo with the Bessels?
TBH it's a bit weird to me that we don't want to support special functions in Jax when in a lot of fields of science and engineering these are as fundamental as sines and cosines.
Thanks for the comment. We've written up what's in-scope and out-of-scope for JAX, and the reasoning behind it, in this doc: https://jax.readthedocs.io/en/latest/jep/18137-numpy-scipy-scope.html
Please let me know if you have questions!
I have to agree with @benjaminpope. I am deeply grateful to the developers for all of the time and effort they devote to maintaining this library, but it seems the cost of implementing the Bessel functions is dramatically outweighed by the benefits arising from their prospective applications in physics/engineering. While JAX is probably most often used for machine learning projects, it's also ideal for solving boundary value problems (BVPs) on a discretized numerical grid based on high-dimensional root finders like Newton's method. Such methods require the computation of high-dimensional Jacobians, which JAX is uniquely situated to tackle with auto-differentiation and efficient matrix operations. Non-JAX implementations rely on numerical approximations of the Jacobian and are consequently less accurate. It therefore seems a shame that solutions to the most common BVPs in physics/engineering rely on Green's functions (among them the Bessel functions) that are not available in JAX.
High-dimensional root finding for the solution of boundary value problems seems well within the scope of JAX, and I would argue that implementations of the most common Green's functions required in these problems (e.g., the Bessel functions) therefore are as well. Compared to the immense amount of work done by the developers to create JAX as we know it today, this addition feels like the last mile of a marathon: difficult, but nearly complete. It seems @benjaminpope has already laid the foundation. @jakevdp, forgive me if I am underestimating the difficulty of the last mile.
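For concreteness, the Newton/BVP pattern being described - exact Jacobians via autodiff driving Newton iterations - is only a few lines in JAX today; `newton_solve` and the residual below are hypothetical names, a minimal sketch:

```python
import jax
import jax.numpy as jnp

def newton_solve(residual, u0, iters=20):
    # Newton's method with the exact Jacobian from jax.jacfwd, rather
    # than a finite-difference approximation. `residual` maps the
    # discretized solution vector to its residual vector.
    u = u0
    for _ in range(iters):
        jac = jax.jacfwd(residual)(u)
        u = u - jnp.linalg.solve(jac, residual(u))
    return u
```

For a real BVP, `residual` would encode the discretized differential operator and boundary conditions; here a toy algebraic residual like `lambda u: u**2 - 2.0` already exercises the machinery.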
I agree that adding Bessel functions to JAX would be very advantageous, but most implementations rely on series expansions that are not well suited to implementation in JAX, for several reasons: (1) they are often designed to work only in float64, meaning the functions would return erroneous results in the default computation mode; (2) the expansions are iterative in nature, and thus often not efficiently expressible with the computational tools available at the JAX level; and (3) they sometimes have complicated convergence criteria that vary by input domain, making the computation even less suited to JAX-level implementation.
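To illustrate point (2): a convergence-tested series can be written with `lax.while_loop`, but such loops are not reverse-mode differentiable and convergence testing gets awkward under `vmap`, which reinforces the point. A toy power series for I0 (illustration only, not a proposed implementation):

```python
import jax.numpy as jnp
from jax import lax

def i0_series(x):
    # Toy power series I0(x) = sum_k (x/2)^(2k) / (k!)^2 with a
    # data-dependent stopping rule via lax.while_loop: the iterative,
    # convergence-tested style described above.
    q = jnp.float32(0.25) * jnp.float32(x) ** 2

    def cond(state):
        k, term, acc = state
        return term > 1e-8 * acc  # relative convergence test

    def body(state):
        k, term, acc = state
        term = term * q / (k * k)  # next series term from the previous
        return k + 1.0, term, acc + term

    init = (jnp.float32(1.0), jnp.float32(1.0), jnp.float32(1.0))
    _, _, acc = lax.while_loop(cond, body, init)
    return acc
```

A fixed iteration count chosen for the worst-case input domain avoids the data-dependent loop, at the price of wasted work for easy inputs - which is roughly the trade-off being described.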
I'm all for adding Bessel function implementations for JAX, but only if the implementations are robust and efficient, which generally means implementation at a lower level.
So a quick update on what would be needed for good performance in float32:
J0, J1, Jv
The Jax translation of the implementation in CEPHES used in scipy (which is just piecewise polynomials, not iterative) works to machine precision for j0 and j1 in float32 in this notebook. So I am not at all worried about the float32 performance of these two functions. The trouble is that if you want to define custom exact derivatives, which we should, you need to be able to go up to arbitrary order jv. This also matches scipy at close to machine precision for most values of x, but breaks in float32 for very small values.
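The piecewise-polynomial structure being described looks roughly like this - toy coefficients and switch point, NOT the real CEPHES ones, just the truncated power series plus the leading asymptotic term to show the shape of the implementation:

```python
import jax.numpy as jnp

def j0_sketch(x):
    # Shape of a CEPHES-style j0: truncated power series for small x,
    # leading asymptotic term for large x. Toy accuracy only.
    x = jnp.abs(jnp.asarray(x))
    z = x * x
    # power series: J0(x) = 1 - x^2/4 + x^4/64 - x^6/2304 + ...
    small = 1.0 - z / 4.0 + z * z / 64.0 - z * z * z / 2304.0
    # leading-order asymptotic: sqrt(2/(pi x)) * cos(x - pi/4)
    large = jnp.sqrt(2.0 / (jnp.pi * x)) * jnp.cos(x - jnp.pi / 4.0)
    return jnp.where(x < 5.0, small, large)
```

The real implementation uses much longer minimax polynomials and rational corrections to the asymptotic phase and amplitude, but the small/large branch with a fixed switch point is the whole control structure - no iteration, which is why it maps well onto JAX.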
This is probably fundamental: the derivatives of Jv(0) vanish up to order v, so you end up needing quite high-order polynomials close to the origin to express any curvature at all, and this will in general be a battle in float32. Realistically, most use cases involve J0, J1 and their first few derivatives (in our use cases, you want to calculate the grad of an element of the Fisher matrix, i.e. the Hessian of something with a J1 in it: so this gets you up to J4). So I think it's not unreasonable to truncate these at some order and meet most people's needs, but it's not ideal.
The cephes implementation (eg used in torch) for jv is quite complicated, and uses several different series expansions dynamically; I could have a crack at this but have a busy travel schedule and am unlikely to tackle this in the immediate future. Implementing this in Jax or XLA would allow us to go to arbitrary order in the custom_jvp, and would therefore be probably necessary for its inclusion in Jax.
One way forward is to use trigonometric approximations: https://arxiv.org/abs/2206.05334. It looks pretty good - even in float32 we're at machine precision for J2 and arguments less than about 3 or 4 (see last cells of this notebook), which gives us plenty of room to switch over to the recursion relation for larger values which are stable. What do you think?
Spherical Bessels
The implementation of spherical Bessels by @shashankdholakia is good to machine precision in float64, but breaks for small arguments (returns nan) because of the while loops. I suggest this can easily be addressed with a branch doing a series expansion around the origin, though these also have the same property as their capital-letter cousins that the derivatives vanish up to order nu at the origin, so you can get to quite high order. I wonder whether there are trig expansions analogous to those for the capital-Jv functions?
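On the nan issue: whatever the loop fix ends up being, the small-argument branch is easy to sketch for the simplest case j0(x) = sin(x)/x, using the standard "double where" trick so the untaken branch can't generate nan (including in gradients):

```python
import jax.numpy as jnp

def spherical_j0(x):
    # sin(x)/x is nan at x = 0 (and loses accuracy nearby), so switch
    # to the Taylor series 1 - x^2/6 + x^4/120 for small |x|. The
    # "double where" keeps the untaken branch nan-free, so gradients
    # at x = 0 are finite too.
    x = jnp.asarray(x)
    small = jnp.abs(x) < 1e-3
    safe_x = jnp.where(small, 1.0, x)  # dummy value for the untaken branch
    series = 1.0 - x * x / 6.0 + x ** 4 / 120.0
    return jnp.where(small, series, jnp.sin(safe_x) / safe_x)
```

Higher-order j_n would need more series terms near the origin (their derivatives also vanish up to order n there), but the branch structure is the same.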
@jakevdp in terms of lower-level implementations, is there any enthusiasm in the team for doing either of these things in XLA? A pure-Jax implementation along the lines of scipy will I think be stable even in float32 except for small values, but it's going to be hard to get good float32 performance with polynomial expansions.