Integrating JAX into py3nj?
Have you thought about using jax.numpy instead of numpy? I am using py3nj for my research, and I am now porting my code to JAX. Do you think the code, as it is written, would benefit from being jitted with jax.jit once all instances of np are converted to jnp?
Hi @srijaniiserprinceton
I've never considered JAX and have no experience with it, but it is something I should study.
What is the benefit of using py3nj in JAX?
Is it GPU support and autograd?
In that respect, I'm not sure py3nj would actually benefit.
Autograd is hopeless, since the 3j symbols depend only on integer (or half-integer) quantum numbers, so there is nothing to differentiate.
Running on a GPU may also be difficult, as this is originally a Fortran implementation.
py3nj just wraps the original Fortran code.
See this file:
https://github.com/fujiisoup/py3nj/blob/master/fortran/_wigner.pyf
Does JAX allow calling native NumPy functions?
If it does, we may be able to do the same thing in JAX.
I see. For some reason I thought you only used native NumPy in wigner.py. If the computation is in Fortran, I guess it won't be as straightforward. Otherwise, I was going to suggest simply changing import numpy as np to import jax.numpy as np. From my brief experience with JAX, it gives a near-C speedup because it compiles a function (which must be a pure function) and, once compiled, no longer goes through the Python interpreter, which is usually much slower than C.
The main place I wanted the speedup was in calls to py3nj.wigner3j. I tried wig_jax = jax.jit(py3nj.wigner3j), but this fails because py3nj works with numpy (and compiled Fortran) rather than jax.numpy.
In any case, I just wanted to hear what you think about the possibility. Maybe I will write a JAX version of py3nj if computing the Wigner 3j symbols becomes a real bottleneck in my code.
Closing the issue :)
Hi, reopening this thread: we are using py3nj a lot in our code, and to speed the code up we are now using JAX. To get that speedup, py3nj (like all our other functions) needs to be written in jax.numpy or otherwise be JAX-compatible. Since py3nj uses Fortran under the hood, it is becoming an obstacle to making the code compatible with JAX's just-in-time compilation. Could you please point me to the algorithm, or to the part of the Fortran code where the computation is actually carried out? I am thinking of coding up a counterpart in jax.numpy so that it can be just-in-time compiled.
Hi @srijaniiserprinceton
> We are using py3nj a lot in our code
It's really nice to know :) Thanks.
For example, the subroutine at this line https://github.com/fujiisoup/py3nj/blob/15f179ecc21033022b05e27c681a1512e2b0e604/fortran/drc.f90#L33 is the high-level entry point, and this part https://github.com/fujiisoup/py3nj/blob/15f179ecc21033022b05e27c681a1512e2b0e604/fortran/drc.f90#L93 computes the actual 3j symbols.
I assume JAX also provides some mechanism for calling external Fortran code, but I have no idea, sorry.
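For anyone attempting the port: as far as I can tell, drc3jj is the SLATEC routine based on the Schulten and Gordon recursion, which returns the 3j symbols for all allowed values of one angular momentum at once, and porting that recursion to jax.numpy is the real work. While doing so, it helps to have a slow but exact reference to test against. The sketch below is NOT py3nj's algorithm: it is the Racah closed-form expression evaluated with Python's fractions module. The function name wigner3j_exact is hypothetical, and passing every j and m as a doubled integer (2j, 2m) is an assumption of this sketch, chosen so that half-integer spins stay integers.

```python
from fractions import Fraction
from math import factorial, sqrt


def wigner3j_exact(tj1, tj2, tj3, tm1, tm2, tm3):
    """Reference Wigner 3j symbol (j1 j2 j3; m1 m2 m3) via the Racah formula.

    Hypothetical helper. Arguments are doubled (tj1 = 2*j1, tm1 = 2*m1, ...).
    The sum is evaluated exactly with Fractions; only the final square root
    is taken in floating point.
    """
    # Selection rules: m1 + m2 + m3 = 0, |m| <= j, and j - m an integer.
    if tm1 + tm2 + tm3 != 0:
        return 0.0
    for tj, tm in ((tj1, tm1), (tj2, tm2), (tj3, tm3)):
        if abs(tm) > tj or (tj + tm) % 2:
            return 0.0
    # Triangle condition and integer j1 + j2 + j3.
    if not abs(tj1 - tj2) <= tj3 <= tj1 + tj2 or (tj1 + tj2 + tj3) % 2:
        return 0.0

    f = factorial
    # Triangle coefficient Delta(j1 j2 j3) times the six (j +/- m)! factors.
    prefactor = Fraction(
        f((tj1 + tj2 - tj3) // 2) * f((tj1 - tj2 + tj3) // 2) * f((tj2 + tj3 - tj1) // 2),
        f((tj1 + tj2 + tj3) // 2 + 1),
    )
    for tj, tm in ((tj1, tm1), (tj2, tm2), (tj3, tm3)):
        prefactor *= f((tj + tm) // 2) * f((tj - tm) // 2)

    # Integer offsets of the six factorials in the denominator of the sum.
    d1 = (tj3 - tj2 + tm1) // 2  # j3 - j2 + m1
    d2 = (tj3 - tj1 - tm2) // 2  # j3 - j1 - m2
    d3 = (tj1 + tj2 - tj3) // 2  # j1 + j2 - j3
    d4 = (tj1 - tm1) // 2        # j1 - m1
    d5 = (tj2 + tm2) // 2        # j2 + m2
    # k runs over all values for which every factorial argument is >= 0.
    k_min = max(0, -d1, -d2)
    k_max = min(d3, d4, d5)
    series = sum(
        Fraction((-1) ** k, f(k) * f(d1 + k) * f(d2 + k) * f(d3 - k) * f(d4 - k) * f(d5 - k))
        for k in range(k_min, k_max + 1)
    )
    if series == 0:
        return 0.0

    # Overall phase (-1)**(j1 - j2 - m3), flipped if the series is negative.
    sign = -1 if (tj1 - tj2 - tm3) // 2 % 2 else 1
    if series < 0:
        sign = -sign
    # |3j|**2 = prefactor * series**2 <= 1, so converting it to float is safe.
    return sign * sqrt(float(prefactor * series * series))
```

For example, wigner3j_exact(2, 2, 0, 0, 0, 0) evaluates (1 1 0; 0 0 0) = -1/sqrt(3). Because everything before the final square root is exact integer arithmetic, it also works at ℓ = 300 (e.g. (300 300 0; 0 0 0) = 1/sqrt(601)), which makes it usable as ground truth for spot checks of a floating-point jax.numpy port, just far too slow to use in production.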
Thanks! Also, up to what angular degree ℓ do you expect the Wigner 3j symbols to be accurate? We go up to around ℓ = 300.
Based on the comment here https://github.com/fujiisoup/py3nj/blob/15f179ecc21033022b05e27c681a1512e2b0e604/fortran/drc3jj.f#L85-L88, I assume it is accurate even at large angular degree. You may want to look at the references cited there.
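A cheap self-check for any implementation at ℓ around 300 is the orthogonality relation: for fixed j1, j2, m1, m2 (with m3 = -m1 - m2), the sum over j3 of (2 j3 + 1) (j1 j2 j3; m1 m2 m3)² equals 1. A minimal sketch of that check in plain Python follows. The name orthogonality_residual and the doubled-integer convention are assumptions of this sketch, and w3j stands for whatever implementation is under test (for instance a jax.numpy port wrapped to return a Python float); it is assumed that |m1| <= j1 and |m2| <= j2.

```python
import math


def orthogonality_residual(w3j, tj1, tj2, tm1, tm2):
    """Return |sum_j3 (2*j3 + 1) * (j1 j2 j3; m1 m2 m3)**2 - 1| with m3 = -m1 - m2.

    Hypothetical helper. Arguments are doubled (tj = 2*j, tm = 2*m), and
    w3j(tj1, tj2, tj3, tm1, tm2, tm3) -> float is the implementation under test.
    """
    tm3 = -tm1 - tm2
    # j3 runs from max(|j1 - j2|, |m3|) to j1 + j2 in integer steps;
    # both lower bounds have the same parity as tj1 + tj2.
    start = max(abs(tj1 - tj2), abs(tm3))
    total = math.fsum(
        (tj3 + 1) * w3j(tj1, tj2, tj3, tm1, tm2, tm3) ** 2
        for tj3 in range(start, tj1 + tj2 + 1, 2)
    )
    return abs(total - 1.0)
```

A residual near machine epsilon for many (j1, j2, m1, m2) up to ℓ = 300 is necessary but not sufficient evidence of accuracy, since it only constrains squared values; comparing individual values against an exact reference is a stronger test.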
Thanks!
Let's keep this issue open. I would appreciate it if you could post any updates here :+1: