numpyro

Block Neural Autoregressive Flow density not properly normalised

Open • danielward27 opened this issue 2 years ago • 1 comment

Hi all,

I believe there is an issue with `BlockNeuralAutoregressiveTransform`: it does not produce a properly normalised density when its inverse is used to transform a distribution (the same issue we have in flowjax: https://github.com/danielward27/flowjax/issues/102). See the example below:

```python
from functools import partial

import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
from jax.scipy.integrate import trapezoid
from numpyro.distributions import Normal, TransformedDistribution
from numpyro.distributions.flows import BlockNeuralAutoregressiveTransform
from numpyro.nn.block_neural_arn import BlockNeuralAutoregressiveNN

if __name__ == "__main__":
    dim = (1,)
    init_fn, apply_fn = BlockNeuralAutoregressiveNN(dim[0])

    # init_fn returns (output_shape, params); keep only the params
    params = init_fn(jr.PRNGKey(0), dim)[1]
    x = jnp.linspace(-30, 30, 10000)[:, None]
    arn = partial(apply_fn, params)
    bnaf = BlockNeuralAutoregressiveTransform(arn)

    # Plot showing the bijection (1D): note that it does not map real -> real!
    plt.plot(x, bnaf(x))
    plt.show()

    # Plot the density of a standard normal transformed by the inverse BNAF
    dist = TransformedDistribution(Normal(jnp.zeros(dim)), bnaf.inv)
    probs = jnp.exp(dist.log_prob(x))
    probs, x = jnp.squeeze(probs), jnp.squeeze(x)

    plt.plot(x, probs)
    plt.show()

    # Rough integral of the density over [-30, 30]
    print(trapezoid(probs, x))  # ~0.17
```
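Why the integral falls short of 1 can be seen from a one-dimensional change of variables. Write $f$ for the forward map `bnaf`, assumed strictly increasing (BNAF is constructed to be monotone), and $\varphi$, $\Phi$ for the standard normal density and CDF. Because `TransformedDistribution(Normal(...), bnaf.inv)` scores a point $x$ by pushing it through $f$, its density and total mass are

$$
p(x) = \varphi\big(f(x)\big)\,\lvert f'(x)\rvert,
\qquad
\int_{-\infty}^{\infty} p(x)\,dx
= \int_{f(-\infty)}^{f(\infty)} \varphi(u)\,du
= \Phi\big(f(\infty)\big) - \Phi\big(f(-\infty)\big) < 1,
$$

where $f(\pm\infty) = \lim_{x \to \pm\infty} f(x)$. These limits are finite because the last layer is a linear transformation of bounded tanh outputs, so the inequality is strict. The reported ~0.17 would then be this difference of CDF values for the sampled parameters, up to truncating the grid at ±30.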

Note that the codomain of `BlockNeuralAutoregressiveTransform` is set to `real_vector`, although the output is actually a linear transformation applied after a `Tanh` bijection, so its image is a bounded interval rather than the whole real line. I'm not sure what the best solution is. Maybe implement something like a LeakyTanh (i.e. tanh, but switching to linear outside some interval such as [-3, 3]) and use that inside BNAFs instead?
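A minimal sketch of the LeakyTanh idea in plain JAX, outside numpyro. The names `leaky_tanh`, `leaky_tanh_log_deriv`, `leaky_tanh_inverse`, `pushforward_mass` and the cutoff parameter `a` are hypothetical, not part of numpyro. The linear pieces match tanh's value and slope at ±a, so the map stays a continuously differentiable, strictly increasing bijection from the real line onto the real line. The script compares how much probability mass a standard normal keeps when pulled back through plain tanh versus the leaky version.

```python
import jax
import jax.numpy as jnp
from jax.scipy.integrate import trapezoid
from jax.scipy.stats import norm

jax.config.update("jax_enable_x64", True)  # float64 for an accurate numerical integral


def leaky_tanh(x, a=3.0):
    """Hypothetical LeakyTanh: tanh on [-a, a], continued linearly outside.

    The linear pieces match tanh's value and slope at +/-a, so the map is a
    continuously differentiable, strictly increasing bijection from R to R.
    """
    x_in = jnp.clip(x, -a, a)
    slope = 1.0 - jnp.tanh(a) ** 2  # tanh'(a)
    return jnp.tanh(x_in) + slope * (x - x_in)


def leaky_tanh_log_deriv(x, a=3.0):
    """log of the derivative: log(1 - tanh(x)^2) inside, constant outside."""
    t = jnp.tanh(jnp.clip(x, -a, a))
    return jnp.log1p(-(t**2))


def leaky_tanh_inverse(y, a=3.0):
    """Inverse map: arctanh on [-tanh(a), tanh(a)], linear outside."""
    t = jnp.tanh(a)
    slope = 1.0 - t**2
    y_in = jnp.clip(y, -t, t)
    return jnp.arctanh(y_in) + (y - y_in) / slope


def pushforward_mass(f, log_deriv, x):
    """Integrate p(x) = N(f(x); 0, 1) * f'(x) over the grid x (trapezoid rule)."""
    log_p = norm.logpdf(f(x)) + log_deriv(x)
    return trapezoid(jnp.exp(log_p), x)


# Wide grid: with a = 3 the slope outside [-3, 3] is only about 0.0099,
# so the leaky density spreads far into the tails.
x = jnp.linspace(-1000.0, 1000.0, 200_001)

mass_tanh = pushforward_mass(jnp.tanh, lambda z: jnp.log1p(-jnp.tanh(z) ** 2), x)
mass_leaky = pushforward_mass(leaky_tanh, leaky_tanh_log_deriv, x)

print(f"tanh:       {float(mass_tanh):.4f}")   # Phi(1) - Phi(-1), about 0.6827
print(f"leaky tanh: {float(mass_leaky):.4f}")  # close to 1
```

One design consideration: the slope outside the interval is tanh'(a) = 1 − tanh²(a), which shrinks quickly as `a` grows, so a wide cutoff produces very flat tails, while a smaller `a` gives steeper linear pieces.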

danielward27, Sep 29 '23 12:09

Very interesting! Thanks for the detailed explanation. Unfortunately, I'm not sure what the best solution is here. :(

fehiepsi, Oct 02 '23 14:10