Use a Bayesian CNN on the MNIST dataset
Blackjax already has an example where we use SGLD to sample from a 3-layer MLP, with very decent accuracy once the uncertainties are used to discard ambiguous predictions. We can use the CNN architecture from the Flax documentation:
```python
from flax import linen as nn


class CNN(nn.Module):
    """A simple CNN model."""

    @nn.compact
    def __call__(self, x):
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))  # flatten
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(features=10)(x)
        return x
```
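To get an initial position for the sampler, the parameters can be initialized with a dummy batch (a minimal sketch; the `(1, 28, 28, 1)` NHWC input shape is an assumption about how the MNIST images are laid out):

```python
import jax
import jax.numpy as jnp

model = CNN()
# Initialize the parameters with a dummy batch shaped like MNIST images (NHWC).
init_position = model.init(jax.random.PRNGKey(0), jnp.ones((1, 28, 28, 1)))
```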
And the log-probability function could look like this (not tested):
```python
from jax.flatten_util import ravel_pytree

import distrax


def logpdf(params, images, categories, model):
    # Logits for the 10 MNIST classes, shape (batch, 10).
    logits = model.apply(params, images)
    # Standard normal prior over all the network parameters.
    flat_params, _ = ravel_pytree(params)
    log_prior = distrax.Normal(0.0, 1.0).log_prob(flat_params).sum()
    # Categorical likelihood over the 10 digit classes.
    log_likelihood = distrax.Categorical(logits=logits).log_prob(categories).sum()
    return log_prior + log_likelihood
```
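For the sampler itself, here is what one SGLD update computes, written with core JAX only (a sketch of the update rule behind a kernel like `blackjax.sgld`, whose exact call signature varies across Blackjax versions; `sgld_step` is a hypothetical helper, not the library API):

```python
import jax
import jax.numpy as jnp


def sgld_step(rng_key, params, images, categories, model, step_size):
    # Minibatch gradient of the log-posterior. For an unbiased estimate, the
    # log-likelihood term in `logpdf` should be rescaled by
    # (dataset size / batch size); omitted here to keep the sketch short.
    grads = jax.grad(logpdf)(params, images, categories, model)

    # One Langevin step on each leaf of the parameter pytree:
    # params <- params + (step_size / 2) * grad + sqrt(step_size) * noise.
    leaves, treedef = jax.tree_util.tree_flatten(params)
    keys = jax.random.split(rng_key, len(leaves))
    noise = jax.tree_util.tree_unflatten(
        treedef, [jax.random.normal(k, leaf.shape) for k, leaf in zip(keys, leaves)]
    )
    return jax.tree_util.tree_map(
        lambda p, g, n: p + 0.5 * step_size * g + jnp.sqrt(step_size) * n,
        params, grads, noise,
    )
```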
We should look at:
- A comparison between SGLD and SGHMC (#211)
- Raw accuracy compared to a solution that uses SGD (with Optax)
- The distribution of "confidence" in predictions (see the sketch after this list)
- Accuracy once we've removed the examples where the model is not sure
- Examples where the model is not sure / the proportion of examples where it is not sure
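For the confidence-related items above, a sketch of the computation (assumptions: `posterior_samples` is a hypothetical list of parameter pytrees drawn with SGLD, `test_images`/`test_labels` are the held-out data, and the 0.9 threshold is arbitrary):

```python
import jax
import jax.numpy as jnp


def predictive_confidence(posterior_samples, images, model):
    # Posterior predictive: average the softmax probabilities over the samples.
    probs = jnp.mean(
        jnp.stack(
            [jax.nn.softmax(model.apply(p, images)) for p in posterior_samples]
        ),
        axis=0,
    )
    # Confidence is the probability assigned to the predicted (argmax) class.
    return jnp.argmax(probs, axis=-1), jnp.max(probs, axis=-1)


predictions, confidence = predictive_confidence(posterior_samples, test_images, model)
keep = confidence > 0.9  # discard predictions the model is not sure about
filtered_accuracy = jnp.mean(predictions[keep] == test_labels[keep])
```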
Hey @rlouf, love the example! Inside the `logpdf` function the `y` variable doesn't exist; I am guessing it should be `categories` instead?
Yes, made the change, thank you! I have no guarantee that this will work, though.
Hi @rlouf, I’ll work on this issue!
Hey @gerdm, do you still intend to work on this?
Hey @rlouf. Yes, still planning to work on it. Expect updates in September.