funsor
Improve the speed of tests under JAX backend
Currently, tests under the `jax` backend are considerably slower than under the `torch` backend. It would be nice if we could speed those tests up.
This is most noticeable in two test files: `test_distribution_generic.py` and `test_sum_product.py`.
I have tried various ways to reduce the JAX compilation time for the samplers, but without success. :(
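One way to confirm that the JAX slowdown comes from compilation rather than execution is to time the two phases separately. The sketch below assumes a recent JAX version with the ahead-of-time `jax.jit(...).lower(...).compile()` API. `sample_normal` and `time_compile_and_run` are hypothetical helpers for illustration only; they are not part of funsor or its test suite.

```python
import time

import jax
import jax.numpy as jnp


def sample_normal(key, loc, scale):
    # Hypothetical stand-in for a sampler under test:
    # draws 1000 reparameterized Normal samples per batch element.
    eps = jax.random.normal(key, shape=(1000,) + loc.shape)
    return loc + scale * eps


def time_compile_and_run(fn, *args):
    """Return (compile_seconds, run_seconds, output) for a jitted fn."""
    jitted = jax.jit(fn)

    # Ahead-of-time lowering + compilation isolates the trace/XLA compile cost.
    start = time.perf_counter()
    compiled = jitted.lower(*args).compile()
    compile_s = time.perf_counter() - start

    # Executing the already-compiled function measures runtime only.
    start = time.perf_counter()
    out = compiled(*args)
    out.block_until_ready()  # JAX dispatches asynchronously; wait for the result.
    run_s = time.perf_counter() - start
    return compile_s, run_s, out


if __name__ == "__main__":
    key = jax.random.PRNGKey(0)
    compile_s, run_s, _ = time_compile_and_run(
        sample_normal, key, jnp.zeros(3), jnp.ones(3)
    )
    print(f"compile: {compile_s:.4f}s, run: {run_s:.4f}s")
```

If the compile time dominates for a given sampler, the cost is paid on each new input shape or dtype, so tests that parametrize over many shapes would recompile repeatedly.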