jax-flows

Jit error for NeuralSplineCoupling

Open jfcrenshaw opened this issue 4 years ago • 3 comments

Is there an example for using the NeuralSplineCoupling? I tried starting with the example notebook intro.ipynb, and just replaced the MAF with:

bijection = flows.Serial(
    *(flows.NeuralSplineCoupling(K=8, B=1, hidden_dim=128),
      flows.Reverse())*2
)

init_fun = flows.Flow(bijection, flows.Normal())

But when I run the training loop, I get the error FilteredStackTrace: IndexError: Array boolean indices must be concrete. Removing @jit in front of def step(...): gets rid of that error, but hopefully there is still some way to jit a NeuralSplineCoupling?

jfcrenshaw, Jan 07 '21 07:01

Hi.

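I encountered the same error today. Apparently, this is because JAX does not support boolean-mask indexing inside jit-compiled functions, since the shape of the result would depend on the values in the array; e.g., see here. I was curious to see if you have found a solution for this problem in this repository.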
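The limitation mentioned above can be reproduced in isolation. The following is a minimal sketch, not code from jax-flows: the function names `masked_sum_bad` and `masked_sum_ok` are made up for illustration. It shows boolean-mask indexing working eagerly, failing under `jax.jit`, and a `jnp.where`-based variant that keeps shapes static and therefore compiles.

```python
import jax
import jax.numpy as jnp


def masked_sum_bad(x):
    # Hypothetical example: boolean-mask indexing. The length of x[inside]
    # depends on the values of x, so it cannot be traced under jit.
    inside = jnp.abs(x) <= 1.0
    return x[inside].sum()


def masked_sum_ok(x):
    # Same result with static shapes: keep every element, but zero out
    # the ones that fall outside the interval.
    inside = jnp.abs(x) <= 1.0
    return jnp.where(inside, x, 0.0).sum()


x = jnp.array([-2.0, -0.5, 0.25, 3.0])

# Eagerly, the mask is concrete, so boolean indexing works.
eager = masked_sum_bad(x)

# Under jit the mask is a tracer, and JAX raises an IndexError subclass.
try:
    jax.jit(masked_sum_bad)(x)
    jit_failed = False
except IndexError as err:
    jit_failed = True
    print("jit failed:", type(err).__name__)

# The jnp.where version compiles and gives the same value.
jitted = jax.jit(masked_sum_ok)(x)
print(float(eager), float(jitted))
```

The general pattern is to replace "select a subset, then compute" with "compute on everything, then select with `jnp.where`", so that every intermediate array has a shape known at trace time.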

hmdolatabadi, Mar 16 '21 05:03

I never dug into fixing the problem in jax-flows. I built my own normalizing flow package (pzflow) that has more functionality geared towards modeling tabular data and calculating posteriors. Briefly glancing at the jax-flows code (specifically at unconstrained_RQS), it looks like he masks out the out-of-bounds inputs as he hands the inputs to RQS. In my code, I essentially calculate RQS for all inputs, then simply replace the RQS outputs for the out-of-bounds inputs.
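The "calculate for all inputs, then replace the out-of-bounds outputs" idea might look roughly like the sketch below. This is an illustration under stated assumptions, not the actual pzflow or jax-flows code: `inner_transform` is a hypothetical stand-in for a real rational-quadratic spline (a simple monotone map that fixes the endpoints of [-B, B]), and `unconstrained_transform` is a made-up name. B = 1 is taken from the snippet earlier in the thread.

```python
import jax
import jax.numpy as jnp

B = 1.0  # half-width of the spline interval (B=1 as in the snippet above)


def inner_transform(x):
    # Hypothetical stand-in for an RQS: a monotone map of [-B, B] onto
    # itself with fixed endpoints, plus its log-derivative.
    y = 0.5 * (x + x**3 / B**2)
    log_det = jnp.log(0.5 * (1.0 + 3.0 * x**2 / B**2))
    return y, log_det


@jax.jit
def unconstrained_transform(x):
    inside = jnp.abs(x) <= B
    # Clip first so the inner transform only ever sees in-bounds values;
    # this keeps the branch that is later discarded finite.
    x_safe = jnp.clip(x, -B, B)
    y_in, log_det_in = inner_transform(x_safe)
    # Out-of-bounds inputs pass through unchanged (identity, log-det 0).
    y = jnp.where(inside, y_in, x)
    log_det = jnp.where(inside, log_det_in, 0.0)
    return y, log_det


x = jnp.array([-3.0, -1.0, 0.0, 0.5, 2.0])
y, log_det = unconstrained_transform(x)
print(y)
print(log_det)
```

Clipping before the inner transform matters for training: `jnp.where` only selects the forward value, so a NaN or infinity produced in the unused branch can still contaminate gradients.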

But if I remember correctly, the neural splines in jax-flows didn't work even without jit: they completely failed to learn the target distribution. I remember finding the reason why, but I can't recall it anymore. I ended up building the splines in pzflow based on the TensorFlow Probability implementation instead.

This is just after a brief glance at the jax-flows code so I could be wrong! But I think the jax-flows splines are a little broken at the moment.

jfcrenshaw, Mar 17 '21 21:03

Hey there everyone! You are correct that the spline code is not currently working due to the reasons described. I'll try to prioritize this and get a good implementation working. If anyone wants to take a shot at implementing this / submit a PR, happy to work with you as well.

ChrisWaites, Mar 19 '21 01:03