Potential performance regression due to JAX version
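### Describe the issue as clearly as possible:
Our benchmark runtimes more than doubled after upgrading JAX to 0.4.34. The mean time per round rose from 4.7685 s to 9.6681 s for `test_regression_nuts` (about 2.0x) and from 7.6479 s to 20.6382 s for `test_regression_hmc` (about 2.7x).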
Reproduced locally. On JAX 0.4.30 (pytest-benchmark, 2 tests):

| Name (time in s) | Min | Max | Mean | StdDev | Median | IQR | Outliers | OPS | Rounds | Iterations |
|---|---|---|---|---|---|---|---|---|---|---|
| test_regression_nuts | 4.2739 (1.0) | 5.0262 (1.0) | 4.7685 (1.0) | 0.3398 (1.0) | 4.9671 (1.0) | 0.5408 (1.0) | 1;0 | 0.2097 (1.0) | 5 | 1 |
| test_regression_hmc | 7.2055 (1.69) | 8.1514 (1.62) | 7.6479 (1.60) | 0.4128 (1.22) | 7.5257 (1.52) | 0.7291 (1.35) | 2;0 | 0.1308 (0.62) | 5 | 1 |
On JAX 0.4.34 (pytest-benchmark, 2 tests):

| Name (time in s) | Min | Max | Mean | StdDev | Median | IQR | Outliers | OPS | Rounds | Iterations |
|---|---|---|---|---|---|---|---|---|---|---|
| test_regression_nuts | 9.2754 (1.0) | 10.2643 (1.0) | 9.6681 (1.0) | 0.3660 (1.0) | 9.6078 (1.0) | 0.3647 (1.0) | 2;0 | 0.1034 (1.0) | 5 | 1 |
| test_regression_hmc | 19.7752 (2.13) | 21.4303 (2.09) | 20.6382 (2.13) | 0.7185 (1.96) | 20.4793 (2.13) | 1.2633 (3.46) | 2;0 | 0.0485 (0.47) | 5 | 1 |
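As a quick check on the "more than 2x" figure, the slowdown can be computed directly from the mean times reported in the two tables above. This uses only the means; the medians give a smaller ratio for NUTS (9.6078 / 4.9671, about 1.93x), so the exact factor depends on which statistic is compared.

```python
# Mean runtimes in seconds, copied from the two pytest-benchmark tables.
MEAN_SECONDS = {
    "test_regression_nuts": {"0.4.30": 4.7685, "0.4.34": 9.6681},
    "test_regression_hmc": {"0.4.30": 7.6479, "0.4.34": 20.6382},
}


def slowdown(old: float, new: float) -> float:
    """Ratio of new to old runtime; a value above 1 means the new run is slower."""
    return new / old


ratios = {
    name: slowdown(times["0.4.30"], times["0.4.34"])
    for name, times in MEAN_SECONDS.items()
}

for name, ratio in ratios.items():
    # Prints 2.03x for NUTS and 2.70x for HMC.
    print(f"{name}: {ratio:.2f}x the JAX 0.4.30 mean runtime")
```

Both tests exceed 2x on the mean, and the HMC benchmark regresses noticeably more than NUTS.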
### Steps/code to reproduce the bug:
Pin JAX to the version under test (0.4.30, then 0.4.34) and run the benchmark suite:
`pytest --benchmark-only`
### Expected result:
```shell
N/A
```
### Error message:
N/A
### Blackjax/JAX/jaxlib/Python version information:
N/A
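The version field was left empty. A minimal sketch for collecting it before each benchmark run, assuming the packages are installed under their PyPI distribution names (`blackjax`, `jax`, `jaxlib`). The helper `collect_versions` is hypothetical and not part of BlackJAX or JAX; it relies only on the standard library, so it also works in environments where a package is missing.

```python
import platform
from importlib.metadata import PackageNotFoundError, version


# Hypothetical helper: gather the versions this issue template asks for.
def collect_versions(packages=("blackjax", "jax", "jaxlib")):
    """Map 'python' and each distribution name to its version (None if absent)."""
    info = {"python": platform.python_version()}
    for name in packages:
        try:
            info[name] = version(name)
        except PackageNotFoundError:
            info[name] = None  # distribution not installed in this environment
    return info


if __name__ == "__main__":
    for name, ver in collect_versions().items():
        print(f"{name}: {ver}")
```

Running this once under each pinned JAX version makes it easy to pair every timing table with exact versions. When JAX is installed, `jax.print_environment_info()` prints similar information.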
### Context for the issue:
No response
Related: https://github.com/pyro-ppl/numpyro/issues/1867. For the likely root cause and a workaround, see https://github.com/jax-ml/jax/discussions/23822.
See also https://github.com/jax-ml/jax/discussions/24501.
(In reply to Junpeng Lao's comment of Mon, Oct 14, 2024, 8:23 AM.)