fenbux
fenbux copied to clipboard
Speed vs Distrax
I ran the following minimal working example (MWE) comparing fenbux against Distrax (https://github.com/google-deepmind/distrax), and fenbux appears to be somewhat slower. Environment: jax 0.4.23, CUDA 12.2, Python 3.10, on a GeForce RTX 4090.
It might be worth looking into why.
import timeit
def setup_code():
    """Return the ``timeit`` setup source shared by both benchmarks.

    The returned source imports fenbux and Distrax, builds three
    ``(1000000, 2)`` arrays from fixed PRNG keys, defines a jitted
    summed log-density for each library, and calls each jitted function
    once.

    Two details keep the timed statement meaningful:

    * ``sd`` is drawn from a strictly positive uniform range instead of
      ``jr.normal``. A normal draw is negative about half the time, which
      is not a valid standard deviation and is expected to turn the
      summed log-density into NaN.
    * Each jitted function is called once during setup, so tracing and
      compilation are not charged to the first timed call.

    Returns:
        str: Python source to pass as ``setup=`` to ``timeit``.
    """
    # The source deliberately starts at column 0: timeit splices the setup
    # text into its own function template and re-indents it there.
    return '''
from jax import jit
from jax import random as jr
from fenbux import logpdf
from fenbux.univariate import Normal
import distrax

key = jr.PRNGKey(0)
x_key, y_key, z_key = jr.split(key, 3)
mean = jr.normal(x_key, (1000000, 2))
sd = jr.uniform(y_key, (1000000, 2), minval=0.1, maxval=2.0)
y_k = jr.normal(z_key, (1000000, 2))


def febux_test(mean, sd, y_k):
    return logpdf(Normal(mean=mean, sd=sd), y_k).sum()


def distrax_test(mean, sd, y_k):
    return distrax.Normal(loc=mean, scale=sd).log_prob(y_k).sum()


jit_febux_test = jit(febux_test)
jit_distrax_test = jit(distrax_test)

# Warm-up: trigger tracing/compilation before the timed calls start.
jit_febux_test(mean, sd, y_k).block_until_ready()
jit_distrax_test(mean, sd, y_k).block_until_ready()
'''
# Time 1000 calls of the jitted fenbux function. Each timed call ends with
# block_until_ready(), so it waits for its result before the next call.
febux_time = timeit.Timer(
    stmt='jit_febux_test(mean, sd, y_k).block_until_ready()',
    setup=setup_code(),
).timeit(number=1000)

# Same measurement for the jitted Distrax function, with a fresh setup.
distrax_time = timeit.Timer(
    stmt='jit_distrax_test(mean, sd, y_k).block_until_ready()',
    setup=setup_code(),
).timeit(number=1000)

# Report total wall-clock seconds for the 1000 calls of each library.
print("Febux Test Time:", febux_time)
print("Distrax Test Time:", distrax_time)
Febux Test Time: 0.10123697502422146
Distrax Test Time: 0.08472020699991845