jax-flows
Jit error for NeuralSplineCoupling
Is there an example of how to use NeuralSplineCoupling?
I started from the example notebook intro.ipynb and just replaced the MAF with
bijection = flows.Serial(
*(flows.NeuralSplineCoupling(K=8, B=1, hidden_dim=128),
flows.Reverse())*2
)
init_fun = flows.Flow(bijection, flows.Normal())
But when I run the training loop, I get the error
FilteredStackTrace: IndexError: Array boolean indices must be concrete.
Removing @jit from in front of def step(...): gets rid of that error, but hopefully there is still some way to jit a NeuralSplineCoupling?
Hi.
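I encountered the same error today. Apparently, this is because JAX does not support boolean indexing with non-concrete (traced) masks inside jit: the shape of the result would depend on the mask's values, which are unknown while tracing (e.g., see here). I was wondering whether you have found a solution to this problem in this repository.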
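A minimal illustration of the failure and the usual workaround, assuming only `jax` and `jax.numpy`; the two function names are hypothetical and not part of jax-flows. Summing the positive entries via `x[x > 0]` works eagerly but raises an `IndexError` under `jax.jit`, while the `jnp.where` version keeps a fixed output shape and jits fine.

```python
import jax
import jax.numpy as jnp


def sum_positive_boolean(x):
    # Boolean indexing: the shape of x[x > 0] depends on the *values* of x,
    # which jit cannot know while tracing, so jax.jit raises an IndexError.
    return jnp.sum(x[x > 0])


def sum_positive_masked(x):
    # Shape-stable alternative: keep every element and zero out the
    # unwanted ones, so the output shape never depends on the data.
    return jnp.sum(jnp.where(x > 0, x, 0.0))


x = jnp.array([-1.0, 2.0, -3.0, 4.0])
print(sum_positive_boolean(x))          # works without jit
print(jax.jit(sum_positive_masked)(x))  # works with jit

try:
    jax.jit(sum_positive_boolean)(x)
except IndexError as err:  # NonConcreteBooleanIndexError in recent JAX versions
    print("jit failed:", type(err).__name__)
```

The same idea applies to a spline layer: compute on a fixed-shape array and select results with `jnp.where` instead of slicing out a data-dependent subset.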
I never dug into fixing the problem in jax-flows. Instead, I built my own normalizing flow package (pzflow), which has more functionality geared towards modeling tabular data and calculating posteriors. Briefly glancing at the jax-flows code (specifically at unconstrained_RQS), it looks like he uses a mask to pick out only the in-bounds inputs as he hands them to RQS. In my code, I essentially calculate RQS for all inputs, then simply replace the RQS outputs for the out-of-bounds inputs.
But if I remember correctly, the neural splines in jax-flows didn't work even without jit: they totally failed to learn the target distribution. I remember finding the reason why, but I can't recall it anymore. I ended up basing the splines in pzflow on the TensorFlow Probability implementation instead.
This is just after a brief glance at the jax-flows code, so I could be wrong! But I think the jax-flows splines are a little broken at the moment.
Hey there, everyone! You are correct that the spline code is not currently working, for the reasons described. I'll try to prioritize this and get a good implementation working. If anyone wants to take a shot at implementing this or submit a PR, I'm happy to work with you as well.