Autodiff problem with block_neural_autoregressive_flow
Very nice package :-)
While playing around with different optimization objectives I ran into an autodiff issue. The snippet below always returns exactly zeros, which I don't think is correct. This might be related to the bisection search (I think that's what's used here?), but if the while_loop in there were the problem I would have expected an error, not a silently incorrect result.
import jax
import jax.numpy as jnp
import numpy as np

import flowjax.distributions
import flowjax.flows

flow_key = jax.random.PRNGKey(0)
point = np.random.randn(5)  # input at which to evaluate the pullback
cotan = np.random.randn(5)  # cotangent vector for the transformed output

base_dist = flowjax.distributions.Normal(jnp.zeros(5))
flow = flowjax.flows.block_neural_autoregressive_flow(
    flow_key, base_dist=base_dist, invert=True
)

# VJP of x -> (y, log_det) through the flow's bijection.
out, pull_grad_fn = jax.vjp(lambda x: flow.bijection.transform_and_log_det(x), point)
pullback = pull_grad_fn((cotan, 1.0))  # cotangent of 1.0 for the log-determinant
pullback
# (Array([0., 0., 0., 0., 0.], dtype=float32),)
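For what it's worth, here is the finite-difference check I used to convince myself that the pullback really should be nonzero. It approximates the same vector-Jacobian product by numerically differentiating the scalar cotan @ y + 1.0 * log_det, so it never touches JAX's autodiff. The helper name, the step size eps, and the scalarisation are just my own choices for this sketch, and the inverse_and_log_det comparison at the end assumes I'm reading the invert=True semantics correctly, i.e. that that direction is the analytic one which doesn't go through the bisection.

# Finite-difference approximation of the same VJP, avoiding autodiff entirely.
# grad_x of [cotan @ y(x) + 1.0 * log_det(x)] is exactly what the pullback above
# should return; if the transform really had zero derivative at this point,
# these entries would be ~0 as well.
def scalar_objective(x):
    y, log_det = flow.bijection.transform_and_log_det(x)
    return jnp.dot(cotan, y) + 1.0 * log_det

eps = 1e-3  # fairly coarse step since everything runs in float32
fd_grad = np.array([
    (scalar_objective(point + eps * e) - scalar_objective(point - eps * e)) / (2 * eps)
    for e in np.eye(5)
])
print(fd_grad)

# For comparison: the pullback through the other direction of the bijection,
# which (I believe) is evaluated analytically rather than via the bisection.
out_inv, pull_inv_fn = jax.vjp(lambda x: flow.bijection.inverse_and_log_det(x), point)
print(pull_inv_fn((cotan, 1.0)))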