Use a Bayesian CNN on the MNIST dataset

Open • rlouf opened this issue 3 years ago • 5 comments

Blackjax already has an example in which SGLD is used to sample from the posterior of a 3-layer MLP. It reaches very decent accuracy when the uncertainties are used to discard ambiguous predictions. We can use the CNN architecture from the Flax documentation:

```python
from flax import linen as nn


class CNN(nn.Module):
  """A simple CNN model."""

  @nn.compact
  def __call__(self, x):
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))  # flatten
    x = nn.Dense(features=256)(x)
    x = nn.relu(x)
    x = nn.Dense(features=10)(x)

    return x
```
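Every weight of this network becomes one dimension of the posterior that the sampler has to explore, so it may help to know how large that space is. The following plain-Python sketch walks through the layer shapes and counts parameters, assuming 28x28 single-channel MNIST inputs and Flax's defaults (Conv uses padding="SAME" with a bias term, avg_pool uses padding="VALID"). The helper names are hypothetical and only for illustration.

```python
# Shape and parameter bookkeeping for the CNN above.
# Assumptions: 28x28x1 MNIST inputs; Flax defaults, i.e. nn.Conv uses
# padding="SAME" plus a bias, nn.avg_pool uses padding="VALID".


def conv_params(kernel_h, kernel_w, in_ch, out_ch):
    """Weights of a 2D convolution kernel plus one bias per output channel."""
    return kernel_h * kernel_w * in_ch * out_ch + out_ch


def dense_params(in_features, out_features):
    """Weight matrix plus one bias per output unit."""
    return in_features * out_features + out_features


def cnn_param_count(height=28, width=28, channels=1):
    """Return (flattened feature size, total parameter count)."""
    total = 0
    # Conv(32, 3x3): SAME padding keeps the spatial size at 28x28.
    total += conv_params(3, 3, channels, 32)
    # avg_pool with a 2x2 window and stride 2 halves each spatial dimension.
    height, width = height // 2, width // 2
    # Conv(64, 3x3) on 14x14x32 feature maps.
    total += conv_params(3, 3, 32, 64)
    height, width = height // 2, width // 2
    # Flatten: height * width * 64 features feed the first Dense layer.
    flat = height * width * 64
    total += dense_params(flat, 256)
    total += dense_params(256, 10)
    return flat, total


flat_features, n_params = cnn_param_count()
print(flat_features, n_params)  # 3136 824458
```

Under these assumptions the flattened feature vector has 3136 entries and the model has 824,458 parameters, the vast majority of them in the first Dense layer.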

The log-probability function could then be written as follows (not tested):

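```python
import distrax
from jax.flatten_util import ravel_pytree


def logpdf(params, images, categories, model):
    """Unnormalized log-posterior of the CNN weights given a batch of data."""
    # Logits have shape (batch_size, 10): one score per digit class.
    logits = model.apply(params, images)
    # Isotropic standard-normal prior on all weights, flattened into one vector.
    flat_params, _ = ravel_pytree(params)
    log_prior = distrax.Normal(0.0, 1.0).log_prob(flat_params).sum()
    # MNIST has 10 classes, so the likelihood is Categorical, not Bernoulli;
    # `categories` holds integer labels of shape (batch_size,).
    log_likelihood = distrax.Categorical(logits=logits).log_prob(categories).sum()

    return log_prior + log_likelihood
```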
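To make the sampling step concrete, here is a minimal NumPy sketch of the SGLD update on a toy model whose posterior is known in closed form, so the samples can be checked. It is not Blackjax's implementation; `sgld` and its arguments are hypothetical. Each step computes theta <- theta + (eps / 2) * (grad log p(theta) + (N / n) * sum over the batch of grad log p(x_i | theta)) + noise, with noise ~ Normal(0, eps), where N is the dataset size and n the batch size. The N / n factor matters for the CNN as well: `logpdf` above sums the likelihood over whatever batch it receives, so on minibatches that term has to be rescaled unless the sampler does it for you.

```python
import numpy as np


def sgld(grad_log_prior, grad_log_lik, data, theta0, step_size, batch_size,
         num_steps, seed=0):
    """Stochastic gradient Langevin dynamics for a scalar parameter (sketch).

    Each step uses a minibatch estimate of the log-posterior gradient, with the
    likelihood term rescaled by N / batch_size, and injects Gaussian noise
    whose variance equals the step size.
    """
    rng = np.random.default_rng(seed)
    n_data = len(data)
    theta = theta0
    samples = np.empty(num_steps)
    for t in range(num_steps):
        # Minibatch drawn with replacement keeps the gradient estimate unbiased.
        batch = data[rng.integers(0, n_data, size=batch_size)]
        grad = grad_log_prior(theta) + (n_data / batch_size) * np.sum(
            grad_log_lik(theta, batch)
        )
        noise = np.sqrt(step_size) * rng.standard_normal()
        theta = theta + 0.5 * step_size * grad + noise
        samples[t] = theta
    return samples


# Toy check: x_i ~ N(theta, 1) with prior theta ~ N(0, 1); the exact posterior
# is N(sum(x) / (N + 1), 1 / (N + 1)).
data_rng = np.random.default_rng(1)
x = data_rng.normal(loc=2.0, scale=1.0, size=1000)
samples = sgld(
    grad_log_prior=lambda th: -th,
    grad_log_lik=lambda th, xb: xb - th,
    data=x,
    theta0=0.0,
    step_size=1e-4,
    batch_size=100,
    num_steps=10_000,
)
kept = samples[1000:]  # discard burn-in
exact_mean = x.sum() / (len(x) + 1)
exact_var = 1.0 / (len(x) + 1)
print(kept.mean(), exact_mean)
print(kept.var(), exact_var)
```

With a fixed step size SGLD only approximates the posterior: the discretization and the minibatch gradient noise inflate the spread a little, so the sample variance is typically close to, but somewhat larger than, 1 / (N + 1). For the CNN, the scalar gradients would be replaced by `jax.grad` of the log-density with respect to the parameter pytree.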

We should look at:

  • Comparison between SGLD and SGHMC (#211)
  • Raw accuracy compared to a solution that uses SGD (with Optax)
  • Distribution of "confidence" in predictions
  • Accuracy once we've removed examples where the model is not sure
  • Examples where the model is not sure, and the proportion of examples where it is not sure
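Most of the evaluation items above can be computed from a single array of per-sample class probabilities. Below is a minimal NumPy sketch, assuming the softmax outputs of each posterior sample have been stacked into an array of shape (num_samples, num_examples, num_classes). Here "confidence" is taken to be the largest posterior-predictive probability; that choice, the `threshold` cutoff, and the function name `predictive_summary` are assumptions for illustration, not something fixed in this issue.

```python
import numpy as np


def predictive_summary(sample_probs, labels, threshold):
    """Summarize posterior predictions for the evaluation items listed above.

    sample_probs: shape (num_samples, num_examples, num_classes), softmax
        probabilities produced by each posterior sample.
    labels: integer array of shape (num_examples,).
    threshold: minimum confidence for a prediction to be kept.
    """
    # Posterior predictive: average the class probabilities over samples.
    mean_probs = sample_probs.mean(axis=0)
    predictions = mean_probs.argmax(axis=-1)
    # "Confidence" = predictive probability of the predicted class.
    confidence = mean_probs.max(axis=-1)
    correct = predictions == labels
    keep = confidence >= threshold
    return {
        "raw_accuracy": correct.mean(),
        "confidence": confidence,  # e.g. plot a histogram of these values
        "retained_accuracy": correct[keep].mean() if keep.any() else float("nan"),
        "fraction_uncertain": 1.0 - keep.mean(),
        "uncertain_indices": np.flatnonzero(~keep),
    }


# Two posterior samples, three examples, three classes (made-up numbers).
probs = np.array([
    [[0.9, 0.05, 0.05], [0.2, 0.7, 0.1], [0.4, 0.3, 0.3]],
    [[0.8, 0.1, 0.1], [0.1, 0.8, 0.1], [0.2, 0.5, 0.3]],
])
labels = np.array([0, 1, 0])
summary = predictive_summary(probs, labels, threshold=0.6)
# Raw accuracy is about 0.667; dropping the one low-confidence example
# (index 2) leaves a retained accuracy of 1.0 with about 1/3 flagged.
print(summary["raw_accuracy"], summary["retained_accuracy"],
      summary["fraction_uncertain"], summary["uncertain_indices"])
```

Other uncertainty measures, such as the entropy of the predictive distribution, could be swapped in for the max probability without changing the rest of the function. The raw accuracy from an SGD baseline would be computed the same way with a single "sample".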

rlouf, Jun 29 '22 08:06

Hey @rlouf, love the example! Inside the logpdf function the y variable doesn't exist; I'm guessing it should be categories instead?

cgarciae, Jun 30 '22 15:06

Yes, I made the change, thank you! I can't guarantee that this will work, though.

rlouf, Jun 30 '22 15:06

Hi @rlouf, I’ll work on this issue!

gerdm, Jun 30 '22 15:06

Hey @gerdm, do you still intend to work on this?

rlouf, Aug 29 '22 12:08

Hey @rlouf. Yes, still planning to work on it. Expect updates in September.

gerdm, Aug 30 '22 07:08