lecture-jax
[jax_intro] Update `%time` magic with `%timeit`
This PR fixes #206.
Deploy Preview for incomparable-parfait-2417f8 ready!
| Name | Link |
|---|---|
| Latest commit | f7bdbb302211676b838b8f6f6c748c665bef71c8 |
| Latest deploy log | https://app.netlify.com/sites/incomparable-parfait-2417f8/deploys/67f3daebe555670008250a96 |
| Deploy Preview | https://deploy-preview-207--incomparable-parfait-2417f8.netlify.app |
🚀 Deployed on https://67f3de93938c4d7c5b8d93e2--incomparable-parfait-2417f8.netlify.app
Hi @jstac,
Since `%timeit` runs the code multiple times, it isn't suitable for examples where we want to show the compilation time. So I only used `%timeit` on the lines that were causing issues.
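One way to keep the compilation time visible while still getting a steady-state number is to time the first call of the jitted function separately from a repeat call. The sketch below uses the standard-library `time.perf_counter` instead of notebook magics. The function `f`, the array size, and the names `first_call`/`second_call` are hypothetical stand-ins, not the lecture's definitions. `block_until_ready()` is needed because JAX dispatches work asynchronously, so without it the timer could stop before the computation has finished.

```python
import time

import jax
import jax.numpy as jnp


# Hypothetical stand-in for the lecture's function (not its actual definition).
def f(x):
    return jnp.sum(jnp.sin(x) ** 2 + jnp.cos(x) ** 2)


f_jit = jax.jit(f)
x = jnp.linspace(0.0, 1.0, 1_000_000)

# First call: includes tracing and XLA compilation, then execution.
start = time.perf_counter()
f_jit(x).block_until_ready()
first_call = time.perf_counter() - start

# Second call with the same shape and dtype reuses the cached compiled function.
start = time.perf_counter()
f_jit(x).block_until_ready()
second_call = time.perf_counter() - start

print(f"first call (compile + run): {first_call:.4f} s")
print(f"second call (run only):     {second_call:.4f} s")
```

In a notebook, the same split can be shown with `%time` on the first call and `%timeit` on later calls.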
Interestingly, `f_jit(x)` is still slower than `f(x)` in the preview here, even under `%timeit`:
%timeit f(x).block_until_ready()
100 ms ± 27 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit f_jit(x).block_until_ready()
231 ms ± 313 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)
However, I cannot replicate this on Colab, where the JIT-compiled version runs much faster:
%timeit f(x).block_until_ready()
127 ms ± 365 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit f_jit(x).block_until_ready()
67.4 ms ± 32.1 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
Colab uses a newer version of JAX (`jax==0.5.2`), while we are using `jax==0.4.35`. Nonetheless, I don't think the issue is related to the JAX version, given that it was working properly before.
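Because the two environments differ, a quick check of what a given environment (the preview build, Colab, or a local machine) is actually running is to print the JAX version and backend at the top of the notebook. A minimal sketch using only public `jax` attributes:

```python
import jax

# Report the JAX version and the backend/devices this process is using,
# so timings from different environments can be compared like for like.
print("jax version:", jax.__version__)
print("backend:", jax.default_backend())
print("devices:", jax.devices())
```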
One approach might be to replace `f` with a more computationally intensive function so that the JIT-compiled version shows a more noticeable performance difference.
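As an illustration of that idea, here is a hypothetical heavier function `g` built from a chain of elementwise operations; it is not from the lecture, and how much `jax.jit` helps will depend on the hardware and JAX version. The reasoning is that, run eagerly, each `jnp` operation is dispatched separately and materializes an intermediate array, while under `jax.jit` XLA can fuse the chain into fewer kernels.

```python
import jax
import jax.numpy as jnp


# Hypothetical heavier replacement for f: 20 rounds of elementwise updates.
# Eagerly, every jnp call is dispatched on its own and allocates an
# intermediate array; under jax.jit, XLA can fuse the whole chain.
def g(x):
    y = x
    for k in range(1, 21):
        # Contractive update (|d/dy| <= 0.5), so values stay in [-1, 1].
        y = 0.5 * jnp.sin(y) + 0.5 * jnp.cos(k * x)
    return jnp.mean(y)


g_jit = jax.jit(g)
x = jnp.linspace(0.0, 1.0, 1_000_000)
g_jit(x).block_until_ready()  # compile once before any timing
```

It can then be timed with `%timeit g(x).block_until_ready()` and `%timeit g_jit(x).block_until_ready()`, just like `f`.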
Please let me know your thoughts on how we should proceed.
Thanks @HumphreyYang, much appreciated.
@mmcky, do you have thoughts here?
I know you were considering having a separate runner that we control. Overall, it would be nice to control our environment and be a bit more up to date with JAX versions.
Thanks @HumphreyYang, I'm just doing some version checking in https://github.com/QuantEcon/lecture-jax/pull/208, as I'm not sure why the JAX version would be that old.
After I get the MIT Solve application together, I am going to set up custom GitHub runners. I have done some experiments, and we should be able to get a GitHub runner going on the GPU server. It would be ideal to have a dedicated machine, though, since it will be running arbitrary code from GitHub.
@HumphreyYang, the environment on GitHub Actions should be using
jax 0.5.3 pypi_0 pypi
Can you let me know where you found the old version of JAX? Thanks.
Many thanks @mmcky,
I found it in the preview build action:
Resolved by upgrading to `jax==0.6.0`.