lecture-jax

Next lecture for JAX conversion

Open jstac opened this issue 2 years ago • 6 comments

EDIT: Transferred from the intermediate lectures. The discussion is still relevant.

Perhaps this lecture is a good candidate for attempting to port to JAX: https://python.quantecon.org/ifp.html

I think it will be challenging because there is both linear interpolation and root finding.

I'm not sure how well these can be done with JAX, and it's possible that we cannot beat the Numba versions. But it would be interesting to find out.

CC @Smit-create @HumphreyYang

jstac, Jan 18 '23 09:01

Hi @jstac and @Smit-create,

I can have a first try and send it to @Smit-create for review to see if it needs further improvement :)

HumphreyYang, Jan 18 '23 11:01

Thanks @HumphreyYang. @Smit-create, you might like to try at the same time or coordinate closely with @HumphreyYang. I'm sure there will be plenty of challenges.

jstac, Jan 18 '23 21:01

Thanks @jstac, @HumphreyYang. I'll also have a look into it at the same time.

Smit-create, Jan 19 '23 06:01

> I think it will be challenging because there is both linear interpolation and root finding.

Yeah, I see. I tried writing the JAX code, but it fails at QuantEcon's root finder, brentq. Because we wrap all the functions in jax.jit, Numba fails since it cannot determine their types:

```
This error may have been caused by the following argument(s):
- argument 0: Cannot determine Numba type of <class 'jaxlib.xla_extension.CompiledFunction'>
- argument 2: Cannot determine Numba type of <class 'jax.interpreters.partial_eval.DynamicJaxprTracer'>
- argument 3: Cannot determine Numba type of <class 'tuple'>
```

Smit-create, Jan 19 '23 12:01

@Smit-create Same here. brentq takes a function as input, and it is at the center of this implementation. However, a JAX-jitted function cannot be passed into it; only a function that is not JAX-compiled will work. Without jax.jit there is little room for improvement: we might be able to eliminate some of the for loops, but we would lose the speed-up from the compiled function unless we have a JAX version of brentq.
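One possible workaround, sketched here under our own assumptions rather than as a port of brentq, is to write a bracketing root finder directly in JAX. Plain bisection is used instead of Brent's method because it is simple to express with `jax.lax.while_loop`; the names `bisect` and `sqrt_via_bisect` are hypothetical. Because it is pure JAX, it composes with `jax.jit` and `jax.vmap`, so one root-finding problem can be solved per grid point in a single batched call.

```python
import jax
import jax.numpy as jnp


def bisect(f, lo, hi, tol=1e-6, max_iter=100):
    """Bisection root finder written with jax.lax.while_loop (sketch).

    Assumes f(lo) and f(hi) have opposite signs (a valid bracket).
    """
    lo = jnp.asarray(lo, dtype=jnp.float32)
    hi = jnp.asarray(hi, dtype=jnp.float32)

    def cond(state):
        lo, hi, i = state
        # Stop once the bracket is narrow enough or the iteration cap is hit
        return (hi - lo > tol) & (i < max_iter)

    def body(state):
        lo, hi, i = state
        mid = 0.5 * (lo + hi)
        # If f(mid) has the same sign as f(lo), the root lies in [mid, hi]
        same_sign = jnp.sign(f(mid)) == jnp.sign(f(lo))
        lo = jnp.where(same_sign, mid, lo)
        hi = jnp.where(same_sign, hi, mid)
        return lo, hi, i + 1

    lo, hi, _ = jax.lax.while_loop(cond, body, (lo, hi, jnp.int32(0)))
    return 0.5 * (lo + hi)


def sqrt_via_bisect(a):
    # Hypothetical test problem: the root of x**2 - a on the bracket [0, 10]
    return bisect(lambda x: x**2 - a, 0.0, 10.0)


# vmap solves one root-finding problem per element; jit compiles the batch
roots = jax.jit(jax.vmap(sqrt_via_bisect))(jnp.array([2.0, 4.0, 9.0]))
```

Bisection converges more slowly than Brent's method, so whether this beats the Numba version is an open question; the point of the sketch is only that the root finder no longer blocks `jax.jit`.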

I also attempted to convert the last exercise, but the IFP class, defined using the extended pytree, blocked the implementation. A namedtuple may get around this up to a point, but the interp function limits how far we can parallelize the computation in JAX.
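A minimal sketch of the namedtuple route, assuming a hypothetical `IFP` container whose field names and values are illustrative only (not taken from the lecture): a `typing.NamedTuple` is already a pytree in JAX, so it can be passed straight into a jitted function without registering a custom pytree.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class IFP(NamedTuple):
    """Hypothetical parameter container; field names are illustrative."""
    R: float           # gross interest rate
    beta: float        # discount factor
    gamma: float       # CRRA coefficient
    a_grid: jax.Array  # asset grid


@jax.jit
def u_prime(c, model):
    # CRRA marginal utility; gamma is read from the NamedTuple pytree
    return c ** (-model.gamma)


# Illustrative parameter values
model = IFP(R=1.01, beta=0.96, gamma=2.0, a_grid=jnp.linspace(0.0, 10.0, 5))
mu = u_prime(2.0, model)

# Each field is a pytree leaf: R, beta, gamma and a_grid
n_leaves = len(jax.tree_util.tree_leaves(model))

# _replace returns a new tuple with one field changed
mu_gamma3 = u_prime(2.0, model._replace(gamma=3.0))
```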

HumphreyYang, Jan 19 '23 12:01

Thanks for making a start @Smit-create @HumphreyYang

Yes, the first challenge is that you need to find JAX equivalents for root finding and interpolation.

Regarding root finding, you could try https://jaxopt.github.io/stable/root_finding.html or the methods in https://python.quantecon.org/newton_method.html
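Besides jaxopt, Newton's method is straightforward to write directly in JAX because `jax.grad` supplies the derivative. A minimal scalar sketch, assuming a differentiable `f` and a reasonable starting point; the function `newton` and its default tolerances are our own choices, not an API from either link.

```python
import jax
import jax.numpy as jnp


def newton(f, x0, tol=1e-6, max_iter=50):
    """Scalar Newton's method with the derivative from jax.grad (sketch)."""
    f_prime = jax.grad(f)

    def cond(state):
        x, x_prev, i = state
        # Continue while successive iterates still differ and the cap is not hit
        return (jnp.abs(x - x_prev) > tol) & (i < max_iter)

    def body(state):
        x, _, i = state
        # Newton update: x_new = x - f(x) / f'(x)
        return x - f(x) / f_prime(x), x, i + 1

    x0 = jnp.asarray(x0, dtype=jnp.float32)
    init = (x0 - f(x0) / f_prime(x0), x0, jnp.int32(1))
    x, _, _ = jax.lax.while_loop(cond, body, init)
    return x


root = newton(lambda x: x**2 - 2.0, 1.0)
```

Unlike a bracketing method, Newton's method can diverge from a poor starting point, so in practice it may need safeguards (for example, falling back to bisection).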

For linear interpolation in one dimension there is https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.interp.html
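A minimal usage sketch of `jax.numpy.interp` inside a jitted function; the grid and the linear test "policy" are our own illustrative choices. Like `numpy.interp`, it expects the sample points `xp` to be increasing, and it already vectorizes over the query points.

```python
import jax
import jax.numpy as jnp


@jax.jit
def interp_policy(a, a_grid, c_vals):
    # Piecewise-linear interpolation of c_vals (defined on a_grid) at points a;
    # like numpy.interp, a_grid must be increasing
    return jnp.interp(a, a_grid, c_vals)


a_grid = jnp.linspace(0.0, 10.0, 51)  # illustrative asset grid
c_vals = 2.0 * a_grid + 1.0           # a linear "policy", so interpolation is exact
c_at = interp_policy(jnp.array([0.1, 3.3, 9.9]), a_grid, c_vals)
```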

You will need to be careful about what happens outside the grid points used for interpolation.
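Outside the grid, `jnp.interp` follows `numpy.interp` and clamps to the first and last function values (unless `left` or `right` is supplied). If linear extrapolation is wanted instead, one possible approach, sketched below with a hypothetical helper `interp_linear_extrap`, is to extend the first and last segments and select with `jnp.where`.

```python
import jax.numpy as jnp


def interp_linear_extrap(x, xp, fp):
    """Linear interpolation with linear extrapolation at both ends (sketch)."""
    # Inside the grid, use jnp.interp as usual
    inside = jnp.interp(x, xp, fp)
    # Slopes of the first and last grid segments
    slope_lo = (fp[1] - fp[0]) / (xp[1] - xp[0])
    slope_hi = (fp[-1] - fp[-2]) / (xp[-1] - xp[-2])
    below = fp[0] + slope_lo * (x - xp[0])
    above = fp[-1] + slope_hi * (x - xp[-1])
    return jnp.where(x < xp[0], below, jnp.where(x > xp[-1], above, inside))


xp = jnp.array([0.0, 1.0, 2.0])
fp = jnp.array([0.0, 1.0, 4.0])
x = jnp.array([-1.0, 1.5, 3.0])

clamped = jnp.interp(x, xp, fp)              # default behaviour: clamp at the ends
extrapolated = interp_linear_extrap(x, xp, fp)
```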

For some lectures we will need 2D interpolation. This might be harder. It is mentioned in https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.ndimage.map_coordinates.html
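`jax.scipy.ndimage.map_coordinates` works in array-index coordinates, and JAX's version only supports `order=0` or `order=1`. The sketch below does bilinear interpolation, assuming both grids are evenly spaced (the helper name `interp2d_uniform` is ours). It first converts physical points to fractional indices; `mode='nearest'` clamps queries outside the grid to the edge values.

```python
import jax.numpy as jnp
from jax.scipy.ndimage import map_coordinates


def interp2d_uniform(x_grid, y_grid, values, x, y):
    """Bilinear interpolation on uniform grids (sketch; assumes even spacing)."""
    # Convert physical coordinates to fractional array indices
    ix = (x - x_grid[0]) / (x_grid[1] - x_grid[0])
    iy = (y - y_grid[0]) / (y_grid[1] - y_grid[0])
    # order=1 gives (bi)linear interpolation; mode='nearest' clamps at the edges
    return map_coordinates(values, [ix, iy], order=1, mode="nearest")


x_grid = jnp.linspace(0.0, 1.0, 11)
y_grid = jnp.linspace(0.0, 2.0, 21)
X, Y = jnp.meshgrid(x_grid, y_grid, indexing="ij")
values = X + 2.0 * Y  # a linear function, which bilinear interpolation reproduces

z = interp2d_uniform(x_grid, y_grid, values,
                     jnp.array([0.25, 0.5]), jnp.array([0.33, 1.7]))
```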

For 2D I noticed https://github.com/adam-coogan/jaxinterp2d/

jstac, Jan 19 '23 21:01