API error in jax substrate?
Hello,
If I install TensorFlow Probability and follow the instructions for creating a RealNVP flow, I get an import error. (This reproduces across several installs and operating systems, including Linux, macOS, and Google Colab, with Python 3.9 and 3.10. It happens even in a fresh conda environment where nothing but tensorflow-probability has been installed.)
The error can be reproduced as follows:
```python
from tensorflow_probability.substrates import jax as tfp
tfb = tfp.bijectors
tfb.real_nvp_default_template(hidden_layers=[512, 512])
```
This is a minimal reproduction of the error I hit when running the example from here: https://www.tensorflow.org/probability/api_docs/python/tfp/bijectors/RealNVP
To implement the JAX substrate, we essentially reimplement the TF API surface using JAX under the hood. This lets us leave most TFP code intact and just swap out imports. We haven't done this for the Keras portions of the TF API surface, so there is no underlying implementation for things like the default MLP built by real_nvp_default_template. You'd need to provide an alternative implementation, e.g. using Flax or a similar library. I feel like there are examples floating around; I'll try to find one.
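A minimal sketch of such a replacement, written in plain JAX rather than Flax to keep it dependency-light. The helper name `make_shift_and_log_scale_fn`, the ReLU activation, and the weight initialization are assumptions for illustration, not TFP's `real_nvp_default_template`; the returned callable only mimics that template's calling convention, taking `(x, output_units)` and returning `(shift, log_scale)`.

```python
import jax
import jax.numpy as jnp


def make_shift_and_log_scale_fn(key, input_dim, hidden_layers, output_dim):
    """Hypothetical plain-JAX stand-in for real_nvp_default_template.

    Builds a ReLU MLP mapping `input_dim` features to 2 * `output_dim`
    values and returns a callable with the (x, output_units) signature.
    """
    sizes = [input_dim, *hidden_layers, 2 * output_dim]
    params = []
    for n_in, n_out in zip(sizes[:-1], sizes[1:]):
        key, subkey = jax.random.split(key)
        # Scaled Gaussian weights, zero biases (an assumed, simple init).
        w = jax.random.normal(subkey, (n_in, n_out)) / jnp.sqrt(n_in)
        params.append((w, jnp.zeros(n_out)))

    def shift_and_log_scale_fn(x, output_units):
        # `output_units` is accepted for signature compatibility only; this
        # sketch fixes the output width (`output_dim`) at construction time.
        h = x
        for w, b in params[:-1]:
            h = jax.nn.relu(h @ w + b)
        w, b = params[-1]
        out = h @ w + b
        # First half of the outputs is the shift, second half the log-scale.
        shift, log_scale = jnp.split(out, 2, axis=-1)
        return shift, log_scale

    return shift_and_log_scale_fn
```

For a 5-dimensional event with `num_masked=2`, this would presumably stand in for the template as `tfb.RealNVP(num_masked=2, shift_and_log_scale_fn=make_shift_and_log_scale_fn(jax.random.PRNGKey(0), 2, [512, 512], 3))`. The weights are closed over and fixed here, so training them would require rewriting the network to take its parameters explicitly (or using a library such as Flax).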
Thank you, that would be super helpful! (Honestly, any example of how to use RealNVP from JAX would help. I don't mind writing my own alternative template, but I'm not sure what that would involve.)
Hi Justin, I put together a quick gist here: https://colab.research.google.com/gist/brianwa84/dfa3d56cded8e56038184fb17048afc6/rnvp-jax.ipynb Hopefully that's enough to get you going. LMK if you have questions.
Thanks for the colab, @brianwa84. I found it very helpful.
If I understand correctly, shift_and_log_scale_fn expects two arguments: the input x and the number of output units. It's not clear from the documentation that this second argument is expected, is it?
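To illustrate why the second argument is needed, here is a plain-JAX sketch of an affine coupling step, assuming the `(x0, output_units)` contract described above: the function only sees the first `num_masked` units, yet must return a shift and log-scale for the remaining `output_units` units, which it cannot infer from `x0` alone. `toy_shift_and_log_scale_fn`, `coupling_forward`, and `coupling_inverse` are hypothetical illustrations of the technique, not TFP's internals.

```python
import jax.numpy as jnp


def toy_shift_and_log_scale_fn(x0, output_units):
    # Hypothetical conditioner: any function of the masked slice works, as
    # long as it returns two arrays with `output_units` trailing features.
    summary = jnp.sum(x0, axis=-1, keepdims=True)
    ones = jnp.ones(x0.shape[:-1] + (output_units,))
    shift = 0.5 * summary * ones
    log_scale = jnp.tanh(summary) * ones
    return shift, log_scale


def coupling_forward(x, num_masked, fn):
    # The first `num_masked` units pass through unchanged.
    x0, x1 = x[..., :num_masked], x[..., num_masked:]
    # The second argument tells `fn` how many units it must parameterize.
    shift, log_scale = fn(x0, x.shape[-1] - num_masked)
    y1 = x1 * jnp.exp(log_scale) + shift
    return jnp.concatenate([x0, y1], axis=-1)


def coupling_inverse(y, num_masked, fn):
    y0, y1 = y[..., :num_masked], y[..., num_masked:]
    # y0 equals x0, so the same shift and log-scale can be recomputed.
    shift, log_scale = fn(y0, y.shape[-1] - num_masked)
    x1 = (y1 - shift) * jnp.exp(-log_scale)
    return jnp.concatenate([y0, x1], axis=-1)
```

Because the masked half passes through unchanged, the inverse can recompute exactly the same shift and log-scale, which keeps the step invertible no matter what the conditioner network computes.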