Brian Patton
What is in constraining_bijector? Consider using tfp.experimental.mcmc.windowed_adaptive_nuts(..) instead. It's not clear how to further debug this without a stack trace or more code.
Looks OK to me (apart from the raytune part, which I dropped). Suggests the problem is actually in tune. constraining_bijector was still undefined, so I defined it. https://colab.research.google.com/gist/brianwa84/3c0c6859b07607416380a1e83be5e430/untitled47.ipynb
Do you get a stack trace with the exception?
Hi Justin, I put together a quick gist here: https://colab.research.google.com/gist/brianwa84/dfa3d56cded8e56038184fb17048afc6/rnvp-jax.ipynb Hopefully that's enough to get you going. LMK if you have questions.
You could use MVNDiagPlusLowRankCovariance: https://www.tensorflow.org/probability/api_docs/python/tfp/distributions/MultivariateNormalDiagPlusLowRankCovariance with very small values for the diagonal.
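For intuition, here is a minimal sketch in plain Python (not TFP) of what a diagonal-plus-low-rank covariance with a tiny diagonal looks like. It assumes the goal is a Gaussian that is nearly rank-deficient; the original question is not shown here, so that is a guess. The helper names (diag_plus_low_rank_cov, sample) and the example values are hypothetical and only illustrate the structure cov = diag(d) + U U^T.

```python
import random


def diag_plus_low_rank_cov(diag, perturb):
    """Return diag(diag) + perturb @ perturb^T as nested lists.

    diag: length-n list of positive variances.
    perturb: n x k list of lists (the low-rank factor U).
    """
    n = len(diag)
    k = len(perturb[0])
    cov = [
        [sum(perturb[i][r] * perturb[j][r] for r in range(k)) for j in range(n)]
        for i in range(n)
    ]
    for i in range(n):
        cov[i][i] += diag[i]
    return cov


def sample(loc, diag, perturb, rng):
    """Draw x = loc + U z + sqrt(diag) * eps, with z and eps standard normal."""
    k = len(perturb[0])
    z = [rng.gauss(0.0, 1.0) for _ in range(k)]
    return [
        loc[i]
        + sum(perturb[i][r] * z[r] for r in range(k))
        + (diag[i] ** 0.5) * rng.gauss(0.0, 1.0)
        for i in range(len(loc))
    ]


# Hypothetical example: a 3-D Gaussian whose mass lies almost on a line.
loc = [0.0, 0.0, 0.0]
diag = [1e-6, 1e-6, 1e-6]         # "very small values for the diagonal"
perturb = [[1.0], [2.0], [-1.0]]  # rank-1 factor U

cov = diag_plus_low_rank_cov(diag, perturb)
x = sample(loc, diag, perturb, random.Random(0))
```

The small diagonal keeps the covariance strictly positive definite, so densities stay finite, while samples remain very close to the span of the low-rank factor. In TFP the analogous constructor arguments are, to my knowledge, cov_diag_factor and cov_perturb_factor; check the API page linked above before relying on that.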
I reproduced against tfp-nightly. The intent is that this should work, but the base measure / tangent space corrections are relatively lightly exercised thus far, so I think you've found...
There are currently no such plans. Testing might be quite challenging, unless MLX provides an emulator? The best documentation about multi-backend support is here: https://github.com/tensorflow/probability/blob/main/SUBSTRATES.md You would need to...
Does running model.forward_filter(x) not work? Most TFP distributions are built to natively vectorize and broadcast across parameters.
TensorList* is how TF tracks gradients for while loops. But the representation inside XLA differs from the one outside, so it generally can't cross the jit_compile=True boundary. At least, that's mostly where...
Can you wrap the code in @tf.function(jit_compile=True)?