lecture-jax

[cake_eating_numerical] Numba time less than JAX

Open mmcky opened this issue 1 year ago • 4 comments

The lecture cake_eating_numerical was removed in #175 because the JAX version took longer to run than the Numba version. We need to review the implementations in this lecture as a sense check.

A copy of the lecture is here:

cake_eating_numerical.md

The timings are at the bottom of this preview:

https://6661340ed985e83faa0cb785--incomparable-parfait-2417f8.netlify.app/cake_eating_numerical

We would expect JAX to outperform Numba unless there is a good reason, which we should then explain.

@kp992 do you have time to look into this lecture?

TODO:

  1. review the implementations and confirm why Numba's execution time is lower than JAX's
  2. submit a PR updating and re-enabling this lecture

mmcky · Jun 06 '24 04:06

Sure, will take a look.

kp992 · Jun 06 '24 14:06

Hi @mmcky, I checked the difference in timings, and the main cause is the difference in the algorithms used. JAX itself is highly optimized, but the JAX implementation finds the maximum by brute force, whereas the Numba implementation uses the brent_max function. There is currently no brent_max equivalent in the JAX implementation, so it simply searches over the whole grid.
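To make the cost difference concrete, here is a minimal pure-Python sketch that counts objective evaluations for a brute-force grid search versus a line search. It uses golden-section search as a simpler stand-in for Brent's method (Brent's method combines golden-section steps with parabolic interpolation, so on smooth objectives it typically needs even fewer evaluations). The helper names, the objective `log(c) + beta * sqrt(x - c)`, and the parameter values are assumptions for illustration, not taken from the lecture.

```python
import math

def make_counted(f):
    """Wrap f so that every call is counted in wrapped.calls."""
    def wrapped(c):
        wrapped.calls += 1
        return f(c)
    wrapped.calls = 0
    return wrapped

def grid_max(f, a, b, n):
    """Brute force: evaluate f at n >= 2 evenly spaced points in [a, b], keep the best."""
    best_c, best_val = a, -math.inf
    for i in range(n):
        c = a + (b - a) * i / (n - 1)
        val = f(c)
        if val > best_val:
            best_c, best_val = c, val
    return best_c, best_val

def golden_max(f, a, b, tol=1e-8):
    """Golden-section search for the maximizer of a unimodal f on [a, b].

    Each step shrinks the bracket by a factor of about 0.618 and reuses one
    interior point, so it costs a single new evaluation of f.
    """
    invphi = (math.sqrt(5) - 1) / 2
    c, d = b - invphi * (b - a), a + invphi * (b - a)
    fc, fd = f(c), f(d)
    while b - a > tol:
        if fc > fd:              # maximizer lies in [a, d]
            b, d, fd = d, c, fc
            c = b - invphi * (b - a)
            fc = f(c)
        else:                    # maximizer lies in [c, b]
            a, c, fc = c, d, fd
            d = a + invphi * (b - a)
            fd = f(d)
    x_star = (a + b) / 2
    return x_star, f(x_star)

# Illustrative values and functional forms (assumptions, not the lecture's)
beta, x = 0.96, 2.0

def objective(c):
    # Stand-in for u(c) + beta * v(x - c) with u = log and v = sqrt
    return math.log(c) + beta * math.sqrt(x - c)

f_grid, f_gold = make_counted(objective), make_counted(objective)
c_grid, _ = grid_max(f_grid, 1e-10, x, n=1000)
c_gold, _ = golden_max(f_gold, 1e-10, x)
print(f"grid: {f_grid.calls} calls, c = {c_grid:.4f}")
print(f"golden: {f_gold.calls} calls, c = {c_gold:.4f}")
```

Here the grid search makes exactly 1,000 calls while the golden-section search needs only a few dozen, and both land on nearly the same maximizer. In value function iteration this gap is paid once per state, so a brute-force search over an n-point grid costs on the order of n² evaluations per Bellman update, versus roughly a fixed number per state for a line search.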

kp992 · Jun 08 '24 12:06

If an equivalent of brent_max were available in JAX, we could beat the Numba timings.
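A rough sketch, under stated assumptions, of what a JAX-native alternative could look like: a derivative-free scalar maximizer written entirely in JAX so that it can be jit-compiled and vmapped across the state grid. It uses golden-section search with a fixed number of iterations rather than a port of brent_max, because a fixed trip count maps directly onto `jax.lax.fori_loop`. The helper names (`golden_max_jax`, `bellman_max`), the log utility, `beta = 0.96`, and the grid are hypothetical choices for illustration, not taken from the lecture.

```python
import jax
import jax.numpy as jnp

def golden_max_jax(f, a, b, num_iter=50):
    """Golden-section maximization of a unimodal scalar f on [a, b].

    Hypothetical helper. A fixed iteration count keeps the control flow
    static, so it can be jit-compiled and vmapped over many brackets.
    For simplicity f is evaluated twice per step instead of reusing a point.
    """
    invphi = (jnp.sqrt(5.0) - 1.0) / 2.0
    a = jnp.asarray(a, dtype=jnp.float32)
    b = jnp.asarray(b, dtype=jnp.float32)

    def body(_, bracket):
        lo, hi = bracket
        c = hi - invphi * (hi - lo)
        d = lo + invphi * (hi - lo)
        # For unimodal f the maximizer lies in [lo, d] if f(c) > f(d), else in [c, hi]
        keep_left = f(c) > f(d)
        return jnp.where(keep_left, lo, c), jnp.where(keep_left, d, hi)

    lo, hi = jax.lax.fori_loop(0, num_iter, body, (a, b))
    x_star = (lo + hi) / 2.0
    return x_star, f(x_star)

beta = 0.96                               # assumed discount factor
x_grid = jnp.linspace(1e-3, 2.5, 120)     # assumed grid of cake sizes

def bellman_max(v, x):
    """Maximize log(c) + beta * v(x - c) over c in (0, x], v linearly interpolated."""
    def objective(c):
        return jnp.log(c) + beta * jnp.interp(x - c, x_grid, v)
    return golden_max_jax(objective, 1e-10, x)

# One Bellman update at every grid point at once: v is shared, x is batched
T = jax.jit(jax.vmap(bellman_max, in_axes=(None, 0)))

v0 = jnp.zeros_like(x_grid)               # simple initial guess
c_star, v1 = T(v0, x_grid)
```

Whether this actually beats the Numba version would still have to be measured; the sketch only shows that a line search, rather than a brute-force grid search, can be written in a form JAX can compile and vectorize across the whole grid.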

kp992 · Jun 08 '24 12:06

Thanks @kp992, that is really helpful. Algorithms matter :-).

@jstac, Smit has identified the cause of the timing difference here, and there is a good explanation for why the Numba version runs faster than the JAX one. It is a good example of how algorithms matter (just as much as technology). What do you think about making this point in the lecture?

mmcky · Jun 10 '24 00:06