numpyro
Block Neural Autoregressive Flow density not properly normalised
Hi all,
I believe BlockNeuralAutoregressiveTransform does not give a properly normalised density when its inverse is used to transform a distribution (the same issue we have in flowjax: https://github.com/danielward27/flowjax/issues/102). The script below reproduces it:
from functools import partial
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
from numpyro.distributions import Normal, TransformedDistribution
from numpyro.distributions.flows import BlockNeuralAutoregressiveTransform
from numpyro.nn.block_neural_arn import BlockNeuralAutoregressiveNN
if __name__ == "__main__":
    dim = (1,)
    init_fn, apply_fn = BlockNeuralAutoregressiveNN(dim[0])
    params = init_fn(jr.PRNGKey(0), (1,))[1]
    x = jnp.linspace(-30, 30, 10000)[:, None]
    arn = partial(apply_fn, params)
    bnaf = BlockNeuralAutoregressiveTransform(arn)

    # Plot showing bijection (1D) - note not real -> real!
    plt.plot(x, bnaf(x))
    plt.show()

    # Plot transformed normal
    dist = TransformedDistribution(Normal(jnp.zeros(dim)), bnaf.inv)
    probs = jnp.exp(dist.log_prob(x))
    probs, x = jnp.squeeze(probs), jnp.squeeze(x)
    plt.plot(x, probs)
    plt.show()

    # Rough integral
    print(jnp.trapz(probs, x))  # ~0.17
Note that the codomain of BlockNeuralAutoregressiveTransform is set to real_vector, but the output is actually a linear transformation applied after a tanh bijection, so it is bounded and cannot cover the whole real line. The transformed density therefore only integrates to the base distribution's mass on the image of the transform, which is less than 1. I'm not sure what the best solution is. Maybe implement something like LeakyTanh (i.e. tanh, but switching to linear outside some interval such as [-3, 3]) and use that inside BNAFs instead?
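To make the LeakyTanh idea concrete, here is a rough sketch of the elementwise function, its inverse, and its log-derivative. The names `leaky_tanh`, `leaky_tanh_inverse`, `leaky_tanh_log_abs_grad` and the `bound` parameter are hypothetical and not part of numpyro. The linear tails are assumed to match tanh in value and slope at `±bound`, which is one possible choice rather than a settled design.

```python
import jax.numpy as jnp


def leaky_tanh(x, bound=3.0):
    """tanh on [-bound, bound], linear outside, so the map covers the real line."""
    # Assumption: the tails continue with the slope tanh has at the boundary.
    slope = 1.0 - jnp.tanh(bound) ** 2
    intercept = jnp.tanh(bound) - slope * bound
    linear = slope * x + jnp.sign(x) * intercept
    return jnp.where(jnp.abs(x) <= bound, jnp.tanh(x), linear)


def leaky_tanh_inverse(y, bound=3.0):
    """Inverse of leaky_tanh: arctanh inside the tanh region, linear outside."""
    slope = 1.0 - jnp.tanh(bound) ** 2
    intercept = jnp.tanh(bound) - slope * bound
    y_bound = jnp.tanh(bound)
    # Clip so arctanh stays finite on the branch that jnp.where discards.
    inner = jnp.arctanh(jnp.clip(y, -y_bound, y_bound))
    outer = (y - jnp.sign(y) * intercept) / slope
    return jnp.where(jnp.abs(y) <= y_bound, inner, outer)


def leaky_tanh_log_abs_grad(x, bound=3.0):
    """log |d leaky_tanh / dx|, the elementwise log-Jacobian term."""
    slope = 1.0 - jnp.tanh(bound) ** 2
    inner_grad = 1.0 - jnp.tanh(x) ** 2
    return jnp.log(jnp.where(jnp.abs(x) <= bound, inner_grad, slope))
```

Because the tails are linear with a positive slope, the function is strictly increasing and unbounded in both directions, so a flow built from it would map the reals onto the reals. Wiring this into BlockNeuralAutoregressiveNN in place of tanh is not shown here and would need the log-Jacobian bookkeeping inside the network to use `leaky_tanh_log_abs_grad` instead of the tanh gradient.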
Very interesting! Thanks for the detailed explanation. Unfortunately, I'm not sure what the best solution is here. :(