numpyro

Flax BNN is several times slower in JAX 0.4.33 than in JAX 0.4.31

Status: open. Opened by ziatdinovmax 1 year ago · 4 comments.

JAX 0.4.31: runtime 27.06 seconds (Colab notebook: https://colab.research.google.com/drive/1EsFY1St8Y2ZNBZ9UXTa9FDWrjPDdTU4U?usp=sharing)

JAX 0.4.33: runtime 84.91 seconds, about 3.1× slower (Colab notebook: https://colab.research.google.com/drive/1g7GkuK4-GloO6cywvDUf5BVU9qO2jf1W?usp=sharing)

I’m not sure whether this issue is specific to `random_flax_module` or is a broader problem. I’ve primarily been using NumPyro for HMC BNNs, and the slowdown with the latest JAX release is quite dramatic.

Code:

import time
import numpy as np
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import NUTS, MCMC
from numpyro.contrib.module import random_flax_module
import flax.linen as nn


# PRNG key for the MCMC run; a fixed seed makes the sampler reproducible.
rng_key = jax.random.PRNGKey(seed=0)

# Generate some dummy data
def generate_data(n=100, noise_std=0.1, seed=0):
    """Return a noisy 1-D linear regression toy dataset.

    Targets follow ``y = 3 * x + 2 + eps`` with ``eps ~ Normal(0, noise_std)``
    on ``n`` evenly spaced points in ``[-1, 1]``.

    Args:
        n: number of data points.
        noise_std: standard deviation of the additive Gaussian noise.
        seed: seed for the noise generator. The default makes the dataset
            identical on every run, which matters when comparing runtimes
            across environments. Pass ``None`` for fresh OS entropy.

    Returns:
        ``(X, y)`` where ``X`` has shape ``(n, 1)`` and ``y`` has shape ``(n,)``.
    """
    # Use a local, seeded Generator instead of the global ``np.random`` state.
    # The original drew unseeded noise, so every run (and every benchmarked
    # JAX version) saw a different dataset.
    rng = np.random.default_rng(seed)
    X = jnp.linspace(-1, 1, n)
    noise = rng.normal(0.0, noise_std, size=X.shape)
    y = 3 * X + 2 + noise
    return X[:, None], y

# Define a simple neural network
class SimpleNN(nn.Module):
    """Small MLP regressor: Dense(hidden_features) -> ReLU -> Dense(1).

    The trailing size-1 output axis is removed, so inputs of shape
    ``(batch, features)`` produce predictions of shape ``(batch,)``.
    """

    # Width of the single hidden layer; the default 10 matches the original.
    hidden_features: int = 10

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(self.hidden_features)(x)
        x = nn.relu(x)
        x = nn.Dense(1)(x)
        # Squeeze only the output-feature axis. A bare ``squeeze()`` would also
        # drop the batch axis when the batch holds a single row, returning a
        # scalar instead of a length-1 vector.
        return x.squeeze(-1)

# Define the model
def model(X, y):
    """NumPyro model for Bayesian regression with ``SimpleNN``.

    The network's parameters are turned into random variables with a
    ``Normal(0, 1)`` prior via ``random_flax_module``. Each data point then
    gets a Gaussian likelihood centred on the network output with a fixed
    scale of 0.1.

    Args:
        X: input array. The first axis indexes data points (it sizes the
            ``"data"`` plate); the last axis holds the input features.
        y: observed targets, one per data point in ``X``.
    """
    module = SimpleNN()
    # Named ``net`` rather than ``nn`` so it does not shadow the module-level
    # ``flax.linen as nn`` import. The site name "nn" is unchanged, so the
    # posterior sample names stay the same.
    net = random_flax_module(
        "nn", module, input_shape=(1, X.shape[-1]), prior=dist.Normal(0, 1)
    )

    with numpyro.plate("data", X.shape[0]):
        mean = net(X)
        numpyro.sample("obs", dist.Normal(mean, 0.1), obs=y)

# Generate data
X, y = generate_data()

# Initialize the NUTS sampler
nuts_kernel = NUTS(model)

# Run inference
num_warmup, num_samples = 500, 1000

# perf_counter() is monotonic and high-resolution, unlike time.time(), which
# can jump if the wall clock is adjusted. The measured span includes JIT
# compilation as well as sampling.
start_time = time.perf_counter()

mcmc = MCMC(nuts_kernel, num_warmup=num_warmup, num_samples=num_samples)
mcmc.run(rng_key, X, y)
# JAX dispatches work asynchronously; wait until the samples are actually
# computed so the timing covers the real work.
jax.block_until_ready(mcmc.get_samples())

end_time = time.perf_counter()

# Print runtime
print(f"Runtime: {end_time - start_time:.2f} seconds")

# Print summary statistics. print_summary() writes the table itself and
# returns None, so wrapping it in print() would emit a stray "None".
mcmc.print_summary()

— ziatdinovmax, Sep 25 '24 06:09