Repository: numpyro
Flax BNN (NUTS) sampling is several times (~3×) slower in JAX 0.4.33 than in JAX 0.4.31
JAX 0.4.31 — runtime: 27.06 seconds (Colab: https://colab.research.google.com/drive/1EsFY1St8Y2ZNBZ9UXTa9FDWrjPDdTU4U?usp=sharing)
JAX 0.4.33 — runtime: 84.91 seconds (Colab: https://colab.research.google.com/drive/1g7GkuK4-GloO6cywvDUf5BVU9qO2jf1W?usp=sharing)
I’m not sure whether this issue is specific to `random_flax_module` or is a broader problem, but I’ve primarily been using NumPyro for HMC BNNs, and the slowdown with the latest JAX release is dramatic.
Code:
import time
import numpy as np
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import NUTS, MCMC
from numpyro.contrib.module import random_flax_module
import flax.linen as nn
# Fixed JAX PRNG key so the MCMC run draws the same random numbers each time.
# NOTE: this seeds JAX only; NumPy's global RNG (used by generate_data for the
# observation noise) is not seeded by this key.
rng_key = jax.random.PRNGKey(seed=0)
# Generate some dummy data
def generate_data(n=100, noise_std=0.1, seed=None):
    """Generate a noisy 1-D linear dataset ``y = 3 * X + 2 + noise``.

    Args:
        n: Number of points, evenly spaced on ``[-1, 1]``.
        noise_std: Standard deviation of the Gaussian noise added to ``y``.
        seed: Optional integer seed. When given, the noise is drawn from a
            dedicated ``np.random.default_rng(seed)`` generator, so the dataset
            is reproducible. When ``None`` (the default), NumPy's global RNG is
            used, exactly as before.

    Returns:
        ``(X, y)``, where ``X`` has shape ``(n, 1)`` and ``y`` has shape ``(n,)``.
    """
    X = jnp.linspace(-1, 1, n)
    # A jax.random.PRNGKey does not seed NumPy, so the NumPy-drawn noise differs
    # on every run unless a seed is supplied here.
    if seed is None:
        noise = np.random.normal(0, noise_std, size=X.shape)
    else:
        noise = np.random.default_rng(seed).normal(0, noise_std, size=X.shape)
    y = 3 * X + 2 + noise
    return X[:, None], y
# Define a simple neural network
class SimpleNN(nn.Module):
    """Two-layer MLP: Dense(hidden_features) -> ReLU -> Dense(1).

    Attributes:
        hidden_features: Width of the hidden layer. The default of 10 matches
            the previously hard-coded value.
    """

    hidden_features: int = 10

    @nn.compact
    def __call__(self, x):
        """Apply the network to a batch of inputs.

        Args:
            x: Input batch; presumably of shape ``(batch, features)``.
                TODO: confirm for callers other than ``model``.

        Returns:
            The network output with the trailing size-1 output axis removed,
            i.e. shape ``(batch,)`` for a ``(batch, features)`` input.
        """
        x = nn.Dense(self.hidden_features)(x)
        x = nn.relu(x)
        x = nn.Dense(1)(x)
        # Squeeze only the final (size-1) output axis. A bare ``squeeze()``
        # would also drop the batch axis when batch == 1 and return a scalar
        # instead of shape (1,).
        return x.squeeze(-1)
# Define the model
def model(X, y=None, obs_scale=0.1):
    """Bayesian neural-network regression model for NumPyro.

    Every parameter of ``SimpleNN`` gets a ``Normal(0, 1)`` prior via
    ``random_flax_module``. The observations get a Gaussian likelihood centred
    on the network output, with fixed scale ``obs_scale``.

    Args:
        X: Inputs. The last axis is the feature dimension; the first axis is
            the data axis covered by the ``"data"`` plate.
        y: Observed targets. Pass ``None`` to sample ``"obs"`` rather than
            condition on it (e.g. for predictive sampling).
        obs_scale: Standard deviation of the observation noise. The default of
            0.1 matches the previously hard-coded value.
    """
    # Named ``net`` rather than ``nn`` so the local does not shadow the
    # ``flax.linen as nn`` module import.
    net = random_flax_module(
        "nn",
        SimpleNN(),
        input_shape=(1, X.shape[-1]),
        prior=dist.Normal(0, 1),
    )
    with numpyro.plate("data", X.shape[0]):
        mean = net(X)
        numpyro.sample("obs", dist.Normal(mean, obs_scale), obs=y)
# Generate data
X, y = generate_data()

# Initialize the NUTS sampler
nuts_kernel = NUTS(model)

# Run inference
num_warmup, num_samples = 500, 1000
# perf_counter is a monotonic, high-resolution clock meant for measuring
# elapsed time; time.time() can jump if the system clock is adjusted.
start_time = time.perf_counter()
mcmc = MCMC(nuts_kernel, num_warmup=num_warmup, num_samples=num_samples)
mcmc.run(rng_key, X, y)
# JAX dispatches work asynchronously; wait until the samples are materialized
# so the measured interval covers the full computation.
# NOTE(review): this wall-clock figure presumably also includes JIT compilation
# of the sampler, not just sampling time.
jax.block_until_ready(mcmc.get_samples())
end_time = time.perf_counter()

# Print runtime
print(f"Runtime: {end_time - start_time:.2f} seconds")

# Print summary statistics. print_summary() prints the table itself and
# returns None, so wrapping it in print() would emit a stray "None".
mcmc.print_summary()