
[jax_intro] Update `%time` magic with `%timeit`

Open · HumphreyYang opened this pull request 7 months ago • 7 comments

This PR fixes #206.

HumphreyYang · Apr 07 '25 12:04

Deploy Preview for incomparable-parfait-2417f8 ready!

| Name | Link |
|------|------|
| Latest commit | f7bdbb302211676b838b8f6f6c748c665bef71c8 |
| Latest deploy log | https://app.netlify.com/sites/incomparable-parfait-2417f8/deploys/67f3daebe555670008250a96 |
| Deploy Preview | https://deploy-preview-207--incomparable-parfait-2417f8.netlify.app |

netlify[bot] · Apr 07 '25 12:04

🚀 Deployed on https://67f3de93938c4d7c5b8d93e2--incomparable-parfait-2417f8.netlify.app

github-actions[bot] · Apr 07 '25 14:04

Hi @jstac,

Since `%timeit` runs the code multiple times and reports the mean, it isn't useful for examples where we want to show the one-off compilation time. So I only switched to `%timeit` on the lines that were causing issues.
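
For context, the pattern I kept is roughly the following (a sketch only; `g` and the array size are made up rather than the lecture's actual code):

```python
import jax.numpy as jnp
from jax import jit, random

x = random.normal(random.PRNGKey(0), (1_000_000,))

@jit
def g(x):
    return jnp.cos(x) ** 2 + jnp.sin(x) ** 2

# %time runs the statement once, so the first call still shows
# the one-off compilation cost
%time g(x).block_until_ready()

# %timeit reruns the (now compiled) call many times and reports the
# mean, so it only measures steady-state execution
%timeit g(x).block_until_ready()
```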

Interestingly, `f_jit(x)` is still slower than `f(x)`, even under `%timeit`, in the preview here:

```
%timeit f(x).block_until_ready()
100 ms ± 27 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit f_jit(x).block_until_ready()
231 ms ± 313 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)
```

but I cannot replicate this on Colab, where the jit-compiled version runs much faster:

```
%timeit f(x).block_until_ready()
127 ms ± 365 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit f_jit(x).block_until_ready()
67.4 ms ± 32.1 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
```
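
For reference, the timings above follow this general pattern (a minimal sketch; `f`, `x` and the array size here are stand-ins rather than the lecture's exact definitions):

```python
import jax.numpy as jnp
from jax import jit, random

# Stand-ins for the lecture's test function and data
def f(x):
    return jnp.cos(x ** 2) + jnp.sin(x) - 0.5 * x

f_jit = jit(f)

x = random.normal(random.PRNGKey(0), (10_000_000,))
f_jit(x).block_until_ready()   # warm-up call so compilation is not timed

# block_until_ready() forces JAX's asynchronous dispatch to finish,
# so the timings measure the computation itself
%timeit f(x).block_until_ready()
%timeit f_jit(x).block_until_ready()
```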

Colab uses a newer version of JAX (`jax==0.5.2`), while we are using `jax==0.4.35`. Nonetheless, I don't think the issue is related to the JAX version, given that it was working properly before.

One approach might be to replace `f` with a more computationally intensive function so that the jit-compiled version shows a more noticeable performance difference.
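
Something along these lines might work (a rough sketch; the function, names and sizes are placeholders): chaining many elementwise operations gives XLA something to fuse, so the jit-compiled version should win by a clear margin.

```python
import jax.numpy as jnp
from jax import jit, random

def slow_f(x):
    # Chain many elementwise operations: without jit, each op runs
    # separately and materialises an intermediate array, while the
    # jit-compiled version lets XLA fuse them into fewer kernels
    for _ in range(20):
        x = jnp.cos(x) ** 2 + jnp.sin(x) ** 2 - 0.5 * x
    return jnp.sum(x)

slow_f_jit = jit(slow_f)

x = random.normal(random.PRNGKey(0), (5_000_000,))
slow_f_jit(x).block_until_ready()   # compile once before timing

%timeit slow_f(x).block_until_ready()
%timeit slow_f_jit(x).block_until_ready()
```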

Please let me know your thoughts on how we should proceed.

HumphreyYang · Apr 08 '25 01:04

Thanks @HumphreyYang, much appreciated.

@mmcky , do you have thoughts here?

I know you were considering having a separate runner that we control. Overall, it would be nice to control our environment and be a bit more up to date with JAX versions.

jstac · Apr 08 '25 01:04

Thanks @HumphreyYang, just doing some version checking in https://github.com/QuantEcon/lecture-jax/pull/208, as I'm not sure why the jax version would be that old.

After I get the MIT Solve application together, I am going to set up custom GitHub runners. I have done some experiments, and we should be able to get the GitHub runner going on the GPU server. It would be ideal to have a dedicated machine, though, as it will be running arbitrary code from GitHub.

mmcky · Apr 08 '25 05:04

@HumphreyYang the environment on GitHub Actions should be using

```
jax                       0.5.3                    pypi_0    pypi
```

Can you let me know where you found the old version of jax? Thanks
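
For a quick in-notebook check (a minimal sketch, assuming only that `jax` is importable), something like this would confirm which version the build environment actually picked up:

```python
import jax

print(jax.__version__)   # version string, e.g. "0.5.3"
print(jax.devices())     # shows which backend (CPU/GPU/TPU) is in use
```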

mmcky · Apr 08 '25 06:04

Many thanks @mmcky,

I found it under the preview building action:

[Screenshot of the preview build log showing jax 0.4.35 installed]

HumphreyYang · Apr 08 '25 06:04

Resolved by using `jax==0.6.0`.

HumphreyYang · Apr 30 '25 01:04