Brian Patton

56 comments by Brian Patton

You might need tfp.util.ParameterProperties(event_ndims=1) for all of your parameters. It's awkwardly named, but basically indicates how many final dimensions of each parameter get consumed to produce a single event.
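A minimal sketch of what that looks like in a custom distribution's `_parameter_properties` (the class `MyVectorDist` and its parameters `loc`/`scale` are hypothetical, just to show the pattern):

```python
import tensorflow_probability as tfp
tfd = tfp.distributions

# Hypothetical distribution whose parameters `loc` and `scale` are vectors.
class MyVectorDist(tfd.Distribution):

  @classmethod
  def _parameter_properties(cls, dtype, num_classes=None):
    # event_ndims=1 means the last dimension of each parameter is consumed
    # to produce a single event; any dimensions to its left are batch dims.
    return dict(
        loc=tfp.util.ParameterProperties(event_ndims=1),
        scale=tfp.util.ParameterProperties(event_ndims=1),
    )
```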

Do you have tensorflow installed? Uninstalling that might speed up the import, if you're only going to use JAX or numpy. It would be nice to disentangle how much of...
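For reference, if you only need JAX or NumPy, the substrate imports look like this (they shouldn't need TensorFlow at runtime, though import timing will depend on your environment):

```python
# JAX substrate of TFP:
from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions

# NumPy substrate:
from tensorflow_probability.substrates import numpy as tfp_np
```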

We started the work to fix this, but GeneralSpace (the tangent space we might fall back to for spherical distributions) is not yet implemented. I think internally we have a...

There's now a GeneralSpace in 30c494e3b4acbae039e5c1e74c1e310fc378eafe. The test there gives some sense of how it could be applied, e.g. [HalfSphereSpaceTest](https://github.com/tensorflow/probability/blob/main/tensorflow_probability/python/experimental/tangent_spaces/spaces_test.py#L159).

Better yet, @srvasude just committed https://github.com/tensorflow/probability/commit/c7f0ce4dfc128b63526079173207353b0496dc54, which should make the spherical distributions transform correctly in tfp-nightly!
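As a rough sketch of the kind of thing that should now work on nightly (the particular base distribution and bijector here are just illustrative, not from the commit):

```python
import tensorflow_probability as tfp
tfd = tfp.distributions
tfb = tfp.bijectors

# Transform a distribution supported on the sphere; with the fix, the
# log-det-Jacobian correction should respect the manifold's tangent space.
base = tfd.SphericalUniform(dimension=3)
transformed = tfd.TransformedDistribution(base, tfb.Scale(2.))

x = transformed.sample(4)
lp = transformed.log_prob(x)
```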

Bijector code that uses keras from TF is not going to work well in JAX (other examples would be Glow and PixelCNN). You can look at masked_autoregressive_test.py to see which...

TFP just uses raw TF ops under the hood, so as long as you have a Tensor-in, Tensor-out model, I don't see why you couldn't deploy on mobile.
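A hedged sketch of what that might look like, wrapping a TFP computation as a Tensor-in, Tensor-out concrete function and handing it to the TFLite converter (whether every op converts cleanly will depend on your model):

```python
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

# Tensor-in, Tensor-out: a concrete function with a fixed input signature.
@tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
def log_prob_fn(x):
  return tfd.Normal(loc=0., scale=1.).log_prob(x)

# Convert for mobile deployment; op coverage may vary by model.
converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [log_prob_fn.get_concrete_function()])
tflite_model = converter.convert()
```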

If stats on a large empirical sample would suffice, you could write `tfp.stats.variance(zero_infl.sample(100_000))`. We don't have variance implemented for mixtures afaik.
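A sketch, assuming `zero_infl` is built as a two-component mixture (your construction may differ):

```python
import tensorflow_probability as tfp
tfd = tfp.distributions

# Hypothetical zero-inflated lognormal: point mass at zero mixed with a LogNormal.
zero_infl = tfd.Mixture(
    cat=tfd.Categorical(probs=[0.3, 0.7]),
    components=[tfd.Deterministic(loc=0.),
                tfd.LogNormal(loc=0., scale=1.)])

# Analytic variance isn't implemented for mixtures, so estimate it from samples.
empirical_variance = tfp.stats.variance(zero_infl.sample(100_000))
```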

Maybe you can write your own loss function, something like

```python
def zero_inflated_lognormal_loss(parameters, labels):
  # I forget the ordering
  p, mu, sigma = tf.unstack(parameters, axis=-1)
  zeroness_loss = -tfd.Bernoulli(probs=p).log_prob(tf.equal(labels, 0))
  safe_labels...
```
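The snippet above is cut off; one possible completion, under the assumptions that `p` is the probability of a zero label and that the lognormal term should only apply to nonzero labels, might look like:

```python
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

def zero_inflated_lognormal_loss(parameters, labels):
  # Assumed ordering: (zero probability, lognormal loc, lognormal scale).
  p, mu, sigma = tf.unstack(parameters, axis=-1)
  is_zero = tf.cast(tf.equal(labels, 0.), labels.dtype)
  # Cross-entropy on whether the label is zero.
  zeroness_loss = -tfd.Bernoulli(probs=p).log_prob(is_zero)
  # Replace zeros with a dummy positive value so LogNormal.log_prob is finite,
  # then mask those terms out of the loss.
  safe_labels = tf.where(tf.equal(labels, 0.), tf.ones_like(labels), labels)
  magnitude_loss = -(1. - is_zero) * tfd.LogNormal(loc=mu, scale=sigma).log_prob(safe_labels)
  return tf.reduce_mean(zeroness_loss + magnitude_loss)
```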

We have taken a pass over the examples to assess what work might need to be done (which might be a bit out of date, as @davmre has already made...