
Speed vs Distrax

Open adam-hartshorne opened this issue 1 year ago • 2 comments

I ran the following minimal example comparing fenbux against Distrax (https://github.com/google-deepmind/distrax), and fenbux appears to be slower: about 0.101 s versus 0.085 s for 1000 calls, as shown in the output at the end. I am running this with jax 0.4.23, CUDA 12.2, and Python 3.10 on a GeForce 4090.

It might be worth looking into why.

import timeit

def setup_code():
    return '''
from jax import jit
from jax import random as jr
from fenbux import logpdf
from fenbux.univariate import Normal
import distrax

key = jr.PRNGKey(0)
x_key, y_key, z_key = jr.split(key, 3)

mean = jr.normal(x_key, (1000000, 2))
sd = jr.normal(y_key, (1000000, 2))
y_k = jr.normal(z_key, (1000000, 2))

def febux_test(mean, sd, y_k):
    return logpdf(Normal(mean=mean, sd=sd), y_k).sum()

def distrax_test(mean, sd, y_k):
    return distrax.Normal(loc=mean, scale=sd).log_prob(y_k).sum()

jit_febux_test = jit(febux_test)
jit_distrax_test = jit(distrax_test)
'''


# Timing febux_test (the first of the 1000 calls also triggers JIT compilation)
febux_time = timeit.timeit('jit_febux_test(mean, sd, y_k).block_until_ready()',
                           setup=setup_code(), number=1000)

# Timing distrax_test
distrax_time = timeit.timeit('jit_distrax_test(mean, sd, y_k).block_until_ready()',
                             setup=setup_code(), number=1000)

print("Febux Test Time:", febux_time)
print("Distrax Test Time:", distrax_time)

Febux Test Time: 0.10123697502422146
Distrax Test Time: 0.08472020699991845
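
One thing that may be worth ruling out before comparing the libraries is measurement overhead. In the script above, each `timeit` run starts from a fresh setup, so the first of the 1000 calls also pays for tracing and JIT compilation, and `sd` is drawn from a standard normal, so roughly half of the scales are negative and are not valid standard deviations. Below is a minimal sketch of a timing helper that warms up before measuring and reports time per call. `time_per_call` and `block` are hypothetical names introduced here, not part of fenbux, Distrax, or JAX; the only assumption about the result is that JAX arrays expose `block_until_ready()`. The last lines simply convert the totals reported above into per-call figures.

```python
import time


def block(result):
    """Wait for an asynchronous result if it exposes block_until_ready()."""
    wait = getattr(result, "block_until_ready", None)
    return wait() if callable(wait) else result


def time_per_call(fn, *args, warmup=3, number=1000):
    """Return mean seconds per call, excluding warm-up calls (e.g. JIT compilation)."""
    # Warm-up: the first call to a jitted function traces and compiles it.
    for _ in range(warmup):
        block(fn(*args))
    # Timed region: only steady-state calls are measured.
    start = time.perf_counter()
    for _ in range(number):
        block(fn(*args))
    return (time.perf_counter() - start) / number


# Totals reported above (seconds for number=1000), converted to per-call times.
febux_total, distrax_total = 0.10123697502422146, 0.08472020699991845
febux_per_call = febux_total / 1000      # about 101 microseconds
distrax_per_call = distrax_total / 1000  # about 85 microseconds
ratio = febux_total / distrax_total      # about 1.195
```

With the names from the script above, this would be called as `time_per_call(jit_febux_test, mean, sd, y_k)` and `time_per_call(jit_distrax_test, mean, sd, y_k)`, ideally after replacing `sd` with a strictly positive array. If the gap shrinks once compilation is excluded, the difference is more likely in tracing or compile time than in the compiled kernel itself.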

adam-hartshorne, Jan 16 '24 15:01