s2fft
Questions related to HEALPix support
Hello, I have been working on interfacing s2fft with https://github.com/CMBSciPol/jax-healpy. Currently only the spin-0 transform is included; I am working on the interface for spin -2 and 2 and on adapting other healpy routines. Our goal is to have jittable and differentiable JAX adaptations of the healpy routines.
For this interface, I have some questions about the jax and jax_healpy methods.
First, I am not sure whether the best practice is to use only the jax method on both CPU and GPU, or whether on CPU it would be better to switch to jax_healpy (for instance using jax.lax.platform_dependent).
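A minimal sketch of the platform-dispatch option, assuming jax.lax.platform_dependent (available in recent JAX releases). forward_jax_healpy and forward_jax are hypothetical stand-ins for the two s2fft methods, not the real transforms. As far as I understand, every branch is traced when the function is jitted, so the CPU branch would itself need to be traceable for this to work.

```python
import jax
import jax.numpy as jnp

def forward_jax_healpy(x):
    # Hypothetical stand-in for the CPU path (jax_healpy method).
    # Under jit it must be traceable, e.g. wrapped in jax.pure_callback.
    return jnp.cumsum(x)

def forward_jax(x):
    # Hypothetical stand-in for the pure-JAX path used on GPU/TPU.
    return jnp.cumsum(x)

@jax.jit
def forward(x):
    # Picks the branch matching the platform the computation is lowered for;
    # both branches must return the same shape and dtype.
    return jax.lax.platform_dependent(x, cpu=forward_jax_healpy, default=forward_jax)

print(forward(jnp.arange(4.0)))
```

The dispatch happens at lowering time, so there is no runtime branching cost once the function is compiled.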
Related to this, the forward function with the jax_healpy method is currently not jittable: healpy_forward expects a jnp.array but immediately converts it into a np.array, which cannot be done on a traced value. See here: https://github.com/astro-informatics/s2fft/blob/5d1d13fc47f4a4f8ed1ca0cf952819ff21ac141a/s2fft/transforms/c_backend_spherical.py#L331
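A minimal reproduction of the problem, using a hypothetical numpy_backed_forward in place of healpy_forward: the NumPy conversion works on concrete arrays but raises jax.errors.TracerArrayConversionError under jax.jit, because the input is then an abstract tracer with no values.

```python
import jax
import jax.numpy as jnp
import numpy as np

def numpy_backed_forward(x):
    # Hypothetical stand-in for healpy_forward: converts its input to NumPy
    # before calling into compiled (non-JAX) code.
    x_np = np.asarray(x)
    return jnp.asarray(x_np * 2.0)

x = jnp.ones(3)
eager_result = numpy_backed_forward(x)  # works: x is a concrete array

jit_failed = False
try:
    jax.jit(numpy_backed_forward)(x)
except jax.errors.TracerArrayConversionError:
    # Under jit, x is a tracer and cannot be converted to a NumPy array.
    jit_failed = True

print(eager_result, "jit failed:", jit_failed)
```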
Would it be possible to make it jittable? This could probably be solved by wrapping the call in a pure_callback; since you already have a custom_jvp in place, it would then also be auto-differentiable. However, I am not sure this is the best way to proceed.
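A minimal sketch of the pure_callback route, assuming the wrapped routine is linear (as a spherical harmonic transform is). _numpy_transform is a hypothetical NumPy-only stand-in (np.cumsum) for the healpy call, not the s2fft code. Because the map is linear, the JVP rule just applies the transform to the tangent. This covers jit and forward-mode differentiation; as far as I know, reverse mode (jax.grad) would additionally need a custom_vjp calling the adjoint transform, since pure_callback itself cannot be transposed.

```python
import jax
import jax.numpy as jnp
import numpy as np

def _numpy_transform(x):
    # Hypothetical stand-in for an untraceable NumPy/C++ routine.
    # It must be linear for the JVP rule below to be correct.
    x = np.asarray(x)
    return np.cumsum(x, axis=-1).astype(x.dtype)

@jax.custom_jvp
def transform(x):
    # Declare the output shape/dtype so JAX can trace through the callback.
    out_spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(_numpy_transform, out_spec, x)

@transform.defjvp
def _transform_jvp(primals, tangents):
    (x,), (t,) = primals, tangents
    # Linear map: the derivative applied to t is the transform of t.
    return transform(x), transform(t)

x = jnp.arange(1.0, 5.0)
y = jax.jit(transform)(x)
_, t_out = jax.jvp(transform, (x,), (jnp.ones(4),))
```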
Otherwise, I guess the only way to proceed is to use the jax method instead?
I also noticed that forward_jax uses jnp.repeat here: https://github.com/astro-informatics/s2fft/blob/5d1d13fc47f4a4f8ed1ca0cf952819ff21ac141a/s2fft/transforms/spherical.py#L733. On my side, I have generally found jnp.broadcast_to to be more performant for the same purpose. In the specific line highlighted, you may be able to replace it with indices = jnp.broadcast_to(jnp.arange(L), (2 * L - 1, L)).T to obtain the same result; in my benchmark (L = 255) this was about twice as fast on average for this specific operation.
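A quick equivalence check for the suggested replacement. The jnp.repeat expression below is an assumption about what the linked line computes (an (L, 2L-1) array whose row i is filled with i), not a copy of the s2fft code; it only confirms the broadcast version yields the same array, not the timing.

```python
import jax.numpy as jnp

L = 8  # small band-limit for the check

# Assumed form of the existing construction: row i filled with the value i.
indices_repeat = jnp.repeat(jnp.arange(L)[:, None], 2 * L - 1, axis=1)

# Proposed replacement: broadcast arange(L) to (2L-1, L), then transpose.
indices_broadcast = jnp.broadcast_to(jnp.arange(L), (2 * L - 1, L)).T

assert indices_repeat.shape == indices_broadcast.shape == (L, 2 * L - 1)
assert bool((indices_repeat == indices_broadcast).all())
```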
Again, thanks a lot for all of your hard work!