lecture-jax
Next lecture for JAX conversion
EDIT: Transferred from the intermediate lectures. The discussion is still relevant.
Perhaps this lecture is a good candidate for porting to JAX: https://python.quantecon.org/ifp.html
I think it will be challenging because there is both linear interpolation and root finding.
I'm not sure how well these can be done with JAX, and it's possible that we cannot beat the Numba versions. But it would be interesting to find out.
CC @Smit-create @HumphreyYang
Hi @jstac and @Smit-create,
I can have a first try and send it to @Smit-create for review to see if it needs further improvement :)
Thanks @HumphreyYang . @Smit-create , you might like to try at the same time or closely coordinate with @HumphreyYang . I'm sure there will be plenty of challenges.
Thanks @jstac, @HumphreyYang. I'll also have a look into it at the same time.
I think it will be challenging because there is both linear interpolation and root finding.
Yeah, I see. I tried writing the JAX code, but it fails in quantecon's root finding with brentq. Since we wrap all the functions in jax.jit, Numba fails because it cannot determine their types:
This error may have been caused by the following argument(s):
- argument 0: Cannot determine Numba type of <class 'jaxlib.xla_extension.CompiledFunction'>
- argument 2: Cannot determine Numba type of <class 'jax.interpreters.partial_eval.DynamicJaxprTracer'>
- argument 3: Cannot determine Numba type of <class 'tuple'>
@Smit-create Same here. brentq takes a function as input, and it is at the center of this implementation. However, a JAX-jitted function cannot be passed into it; only a function that is not JAX-compiled can. Without jax.jit there is little room for improvement: we may be able to eliminate some of the for loops, but the trade-off is losing the speed-up of the compiled function unless we have a JAX version of brentq.
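One possible way around the Numba/JAX type clash is a root finder written entirely in JAX primitives, so that it can live inside jitted code and be vmapped over grid points. The sketch below is plain bisection (not Brent's method) and assumes a valid bracket [lower, upper] is known for every point; bisect and sqrt_by_bisection are hypothetical names, and the square-root problem is only a stand-in for the per-grid-point equation solved with brentq in the lecture.

```python
import jax
import jax.numpy as jnp


def bisect(f, lower, upper, num_iter=60):
    """Bisection root finder built from JAX primitives (hypothetical helper).

    Assumes f(lower) and f(upper) have opposite signs and that lower and
    upper are arrays of the same dtype. A fixed iteration count keeps the
    loop compatible with jax.jit and jax.vmap.
    """
    def body(_, bounds):
        lo, hi = bounds
        mid = 0.5 * (lo + hi)
        # Replace the endpoint whose function value has the same sign as
        # f(mid), so that [lo, hi] keeps bracketing the root
        move_lo = jnp.sign(f(mid)) == jnp.sign(f(lo))
        lo = jnp.where(move_lo, mid, lo)
        hi = jnp.where(move_lo, hi, mid)
        return lo, hi

    lo, hi = jax.lax.fori_loop(0, num_iter, body, (lower, upper))
    return 0.5 * (lo + hi)


@jax.jit
def sqrt_by_bisection(a):
    # Hypothetical test problem: solve x**2 - a = 0 on [0, max(a, 1)]
    f = lambda x: x**2 - a
    return bisect(f, jnp.zeros_like(a), jnp.maximum(a, 1.0))


a_grid = jnp.linspace(0.5, 4.0, 8)
roots = jax.vmap(sqrt_by_bisection)(a_grid)  # one root per grid point
```

Bisection converges more slowly than Brent's method, but every iteration is branch-free (jnp.where instead of if), which is what lets jax.jit and jax.vmap handle it.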
I also attempted to convert the last exercise, but the IFP class, defined as an extended pytree, blocked the implementation. A namedtuple may get around this up to a point, but the interp function limits how far the computation can be parallelized in JAX.
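For reference, a typing.NamedTuple is already a JAX pytree, so a parameter container of this kind can be passed straight into a jitted function, and jnp.interp is vectorized over its first argument. The sketch below assumes a hypothetical Model with made-up field names and default values, and a hypothetical consumption policy σ stored on the grid; it is not the lecture's IFP class.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Model(NamedTuple):
    # Hypothetical parameters, loosely modelled on an income fluctuation problem
    R: float
    β: float
    γ: float
    s_grid: jax.Array


def create_model(R=1.01, β=0.96, γ=1.5, s_max=16.0, grid_size=50):
    s_grid = jnp.linspace(0.0, s_max, grid_size)
    return Model(R=R, β=β, γ=γ, s_grid=s_grid)


@jax.jit
def marginal_utility(model, σ_vals, x):
    # model is a NamedTuple (a pytree), so jax.jit traces each of its leaves.
    # jnp.interp evaluates the interpolated policy at every point of x at once.
    c = jnp.interp(x, model.s_grid, σ_vals)
    return c ** (-model.γ)


model = create_model()
σ_vals = 0.5 + 0.1 * model.s_grid          # hypothetical policy values on the grid
x = jnp.array([0.0, 3.3, 10.0])
mu = marginal_utility(model, σ_vals, x)    # shape (3,)
```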
Thanks for making a start @Smit-create @HumphreyYang
Yes, the first challenge is that you need to find JAX equivalents for root finding and interpolation.
Regarding root finding, you could try https://jaxopt.github.io/stable/root_finding.html or the methods in https://python.quantecon.org/newton_method.html
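As a rough illustration of the Newton route, the derivative can come from jax.grad and the iteration from jax.lax.while_loop, so the solver can sit inside jit and vmap. This is a minimal sketch in the spirit of the Newton's method lecture, not its exact code; newton and cube_root are hypothetical names, and it assumes f maps a scalar to a scalar and that the starting point is reasonable.

```python
import jax
import jax.numpy as jnp


def newton(f, x_0, tol=1e-6, max_iter=50):
    """Newton's method with the derivative supplied by jax.grad (hypothetical helper)."""
    f_prime = jax.grad(f)

    def cond(state):
        _, error, i = state
        return (error > tol) & (i < max_iter)

    def body(state):
        x, _, i = state
        x_new = x - f(x) / f_prime(x)
        return x_new, jnp.abs(x_new - x), i + 1

    x_0 = jnp.asarray(x_0, dtype=jnp.float32)
    state = (x_0, jnp.array(jnp.inf, dtype=jnp.float32), 0)
    x, _, _ = jax.lax.while_loop(cond, body, state)
    return x


@jax.jit
def cube_root(a):
    # Hypothetical test problem: solve x**3 - a = 0 starting from x = 1
    return newton(lambda x: x**3 - a, 1.0)


roots = jax.vmap(cube_root)(jnp.array([2.0, 8.0, 27.0]))
```

Under vmap, the while_loop keeps running until every element has converged, with finished elements left unchanged. Unlike bisection, Newton's method needs no bracket but can diverge from a poor starting point.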
For linear interpolation in one dimension there is https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.interp.html
You will need to be careful about what happens outside the grid points used for interpolation.
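To make the out-of-grid point concrete: like np.interp, jnp.interp clamps to fp[0] and fp[-1] outside [xp[0], xp[-1]] by default, and the left and right arguments can replace those values (for example with NaN to flag off-grid evaluations). If linear extrapolation is wanted instead, it has to be added by hand; interp_extrap below is a hypothetical helper that extends the first and last grid segments.

```python
import jax.numpy as jnp

xp = jnp.array([0.0, 1.0, 2.0])
fp = jnp.array([0.0, 10.0, 20.0])
x = jnp.array([-1.0, 0.5, 3.0])

clamped = jnp.interp(x, xp, fp)  # → [0., 5., 20.]
flagged = jnp.interp(x, xp, fp, left=jnp.nan, right=jnp.nan)  # NaN off the grid


def interp_extrap(x, xp, fp):
    """Hypothetical helper: linear interpolation inside the grid, linear
    extrapolation outside it using the slope of the first/last segment."""
    inside = jnp.interp(x, xp, fp)
    slope_lo = (fp[1] - fp[0]) / (xp[1] - xp[0])
    slope_hi = (fp[-1] - fp[-2]) / (xp[-1] - xp[-2])
    below = fp[0] + slope_lo * (x - xp[0])
    above = fp[-1] + slope_hi * (x - xp[-1])
    return jnp.where(x < xp[0], below, jnp.where(x > xp[-1], above, inside))


extrapolated = interp_extrap(x, xp, fp)  # → [-10., 5., 30.]
```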
For some lectures we will need 2D interpolation. This might be harder. It is mentioned in https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.ndimage.map_coordinates.html
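A possible way to use map_coordinates for bilinear interpolation: it works in array-index space, so on an evenly spaced grid each physical coordinate must first be converted to a fractional index. In JAX, order must be 0 or 1, and mode="nearest" clamps off-grid points to the boundary. The sketch assumes evenly spaced, increasing grids with values[i, j] = f(x_grid[i], y_grid[j]); interp2d_regular is a hypothetical name, and non-uniform grids would need a different index mapping.

```python
import jax.numpy as jnp
from jax.scipy.ndimage import map_coordinates


def interp2d_regular(x_grid, y_grid, values, x, y):
    """Bilinear interpolation on an evenly spaced 2D grid (hypothetical helper)."""
    # Convert physical coordinates into fractional array indices
    ix = (x - x_grid[0]) / (x_grid[1] - x_grid[0])
    iy = (y - y_grid[0]) / (y_grid[1] - y_grid[0])
    # order=1 gives linear interpolation along each axis
    return map_coordinates(values, [ix, iy], order=1, mode="nearest")


x_grid = jnp.linspace(0.0, 1.0, 11)
y_grid = jnp.linspace(0.0, 2.0, 21)
X, Y = jnp.meshgrid(x_grid, y_grid, indexing="ij")
values = 2 * X + 3 * Y  # a linear function, which bilinear interpolation reproduces

result = interp2d_regular(x_grid, y_grid, values,
                          jnp.array([0.25, 0.5]), jnp.array([0.35, 1.9]))
```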
For 2D I noticed https://github.com/adam-coogan/jaxinterp2d/