probability
Probabilistic reasoning and statistical analysis in TensorFlow
Hello! I am currently trying to use `JointDistributionSequential` to predict multiple distributions using a Mixture Density Network. Minimal example: ```python import tensorflow as tf import tensorflow_probability as tfp from tensorflow_probability...
## Code ``` from jax.config import config config.update("jax_enable_x64", True) import os #os.environ['XLA_PYTHON_CLIENT_ALLOCATOR']='platform' import jax import jax.numpy as jnp from tensorflow_probability.substrates import jax as tfp tfd = tfp.distributions Y=jnp.ones((4032,258),dtype=jnp.float64) distribution =...
TFP 0.19.0 is using a deprecated TFP kwarg `scale_identity_multiplier` which results in the following log spam: ``` From site-packages/tensorflow_probability/python/distributions/distribution.py:342: calling MultivariateNormalDiag.__init__ (from tensorflow_probability.python.distributions.mvn_diag) with scale_identity_multiplier is deprecated and will be...
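Wherever the deprecated call originates, the fix is the same. `scale_identity_multiplier=s` means a scale of `s * I`. That is the same distribution as a diagonal scale filled with `s`, for example `scale_diag=s * tf.ones_like(loc)`, assuming `loc` holds the mean vector.

The pure-Python sketch below checks that equivalence on the log-density without needing TensorFlow. Both function names are hypothetical.

```python
import math


def mvn_diag_log_prob(x, loc, scale_diag):
    """Log-density of a Gaussian with covariance diag(scale_diag) ** 2."""
    k = len(x)
    z2 = sum(((xi - mi) / si) ** 2 for xi, mi, si in zip(x, loc, scale_diag))
    log_det = sum(math.log(si) for si in scale_diag)  # log|det(scale)|
    return -0.5 * z2 - log_det - 0.5 * k * math.log(2 * math.pi)


def isotropic_log_prob(x, loc, multiplier):
    """Log-density of a Gaussian with scale = multiplier * I (the deprecated form)."""
    k = len(x)
    z2 = sum((xi - mi) ** 2 for xi, mi in zip(x, loc)) / multiplier ** 2
    return -0.5 * z2 - k * math.log(multiplier) - 0.5 * k * math.log(2 * math.pi)
```

Calling `mvn_diag_log_prob(x, loc, [s] * len(x))` agrees with `isotropic_log_prob(x, loc, s)`, so dropping the deprecated kwarg should not change results.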
I'm currently running distributed HMC on 4 Tesla V100 GPUs, and my code looks like this: `import functools` `import collections` `import contextlib` `from jax.config import config` `config.update("jax_enable_x64", True)` `import jax`...
Hi, I was trying out the `RelaxedOneHotCategorical` distribution in tensorflow_probability 0.19.0, and the following code gives me an incorrect distribution. ``` from tensorflow_probability.substrates import jax as tfp temperature = 0.5...
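One way to judge whether samples look wrong is to compare them against an independent Gumbel-softmax sampler. The sampler computes `softmax((logits + g) / temperature)` with `g ~ Gumbel(0, 1)`. A useful check follows from the Gumbel-max property: for any positive temperature, the argmax of each relaxed sample is distributed as `Categorical(softmax(logits))`.

This is a plain-Python sketch assuming logits as input and the issue's temperature of 0.5. It is not TFP's implementation, and the function names are hypothetical.

```python
import math
import random


def sample_gumbel(rng):
    """Draw Gumbel(0, 1) noise by inverting its CDF: g = -log(-log(u))."""
    u = rng.random()
    while u == 0.0:  # random() is in [0, 1); log(0) is undefined
        u = rng.random()
    return -math.log(-math.log(u))


def relaxed_one_hot_sample(logits, temperature, rng):
    """One Gumbel-softmax draw: softmax((logits + gumbel_noise) / temperature)."""
    scaled = [(l + sample_gumbel(rng)) / temperature for l in logits]
    m = max(scaled)  # subtract the max before exp() for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


if __name__ == "__main__":
    rng = random.Random(0)
    probs = [0.2, 0.3, 0.5]
    logits = [math.log(p) for p in probs]
    n = 20000
    counts = [0] * len(probs)
    for _ in range(n):
        y = relaxed_one_hot_sample(logits, temperature=0.5, rng=rng)
        counts[y.index(max(y))] += 1
    # Gumbel-max property: these frequencies should approach `probs`.
    print([c / n for c in counts])
```

Argmax frequencies from TFP samples that drift far from `softmax(logits)` would point to a real problem rather than sampling noise.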
Hi all, I have a simple BNN whose output I just tried to change to a negative binomial distribution: ``` def get_model(input_shape, loss, optimizer, metrics, kl_weight, output_shape): inputs...
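Debugging a negative binomial output head is easier with the target quantity computed by hand. TFP's `NegativeBinomial` counts successes `s` before `total_count` failures, each trial succeeding with probability `probs`. Its log-pmf is `lgamma(s + f) - lgamma(s + 1) - lgamma(f) + s*log(p) + f*log(1 - p)`.

The sketch below computes that log-pmf and the mean negative log-likelihood that a `-dist.log_prob(y)`-style loss would minimize. It assumes the `probs` parameterization; the function names are hypothetical.

```python
import math


def negative_binomial_log_prob(s, total_count, probs):
    """log P(S = s): successes s before `total_count` failures, success prob `probs`.

    total_count may be any positive real; lgamma generalizes the binomial coefficient.
    """
    return (math.lgamma(s + total_count) - math.lgamma(s + 1) - math.lgamma(total_count)
            + s * math.log(probs) + total_count * math.log1p(-probs))


def negative_binomial_nll(counts, total_count, probs):
    """Mean negative log-likelihood of observed integer counts."""
    return -sum(negative_binomial_log_prob(c, total_count, probs) for c in counts) / len(counts)
```

With `total_count=1` this reduces to a geometric distribution, and the mean is `total_count * probs / (1 - probs)`. Both facts make quick sanity checks on a model's predicted parameters.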
There appears to be a breaking change in the way `MultivariateNormalTriL` works together with `tf.keras` in TF 2.16.1 with TFP 0.24.0 and tf_keras 2.16.0. I'm using Python 3.11.8 on a...
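From TF 2.16 on, `tf.keras` refers to Keras 3, while TFP's Keras layers are built against Keras 2, which ships as the `tf_keras` package. Assuming `tf_keras` is installed, as it is here, one commonly suggested workaround is to set `TF_USE_LEGACY_KERAS=1` before TensorFlow is first imported. The alternative is to import `tf_keras` directly and build the model with it instead of `tf.keras`.

A minimal sketch of the environment-variable route:

```python
import os

# Must run before the first `import tensorflow` anywhere in the process;
# TensorFlow consults this variable when deciding which Keras `tf.keras` refers to.
os.environ["TF_USE_LEGACY_KERAS"] = "1"

# Afterwards (not executed here, since it needs TensorFlow and tf_keras installed):
#   import tensorflow as tf
#   import tensorflow_probability as tfp
# tf.keras is then expected to resolve to tf_keras, matching TFP's layers.
```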
Given this brief [code](https://colab.research.google.com/drive/18JkC4UjMLz5YYDZj31IZLXF_7TuRKrKd#scrollTo=DbfH5lcL4EnI): ```py import pandas as pd import tensorflow_probability as tfp import tensorflow as tf data = pd.read_csv('https://raw.githubusercontent.com/WillianFuks/tfcausalimpact/master/tests/fixtures/arma_data.csv')[['y']].astype('float32') data.index = pd.date_range(start='2024-01-01', periods=len(data), freq='D') obs = data.iloc[:70] model =...
I am trying to use a TensorFlow Probability statistic as a metric in Keras. With `kendalls_tau`, I get the following error: ``` import tensorflow_probability as tfp import tensorflow as tf...
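An independent reference value helps separate a wiring problem (shapes or dtypes passed by Keras) from a numerical one. The sketch below is a plain O(n²) Kendall's tau-b for two 1-D sequences, with ties handled by the standard tau-b correction. It is not TFP's implementation, and it is an assumption that `kendalls_tau` uses the same tie correction; check that against its docs. The function name is hypothetical.

```python
import math


def kendalls_tau_b(x, y):
    """Kendall's tau-b of two equal-length 1-D sequences, via all pairs."""
    n = len(x)
    concordant = discordant = ties_x = ties_y = 0
    for i in range(n):
        for j in range(i + 1, n):
            dx = x[i] - x[j]
            dy = y[i] - y[j]
            if dx == 0:
                ties_x += 1  # pair tied in x (may also be tied in y)
            if dy == 0:
                ties_y += 1  # pair tied in y (may also be tied in x)
            if dx * dy > 0:
                concordant += 1
            elif dx * dy < 0:
                discordant += 1
    n0 = n * (n - 1) / 2
    denom = math.sqrt((n0 - ties_x) * (n0 - ties_y))
    return (concordant - discordant) / denom if denom > 0 else float("nan")
```

For example, `kendalls_tau_b([1, 2, 3, 4], [1, 3, 2, 4])` has 5 concordant and 1 discordant pair out of 6, giving 2/3.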
Dear all, I am currently learning Bayesian analysis and utilizing `tensorflow_probability.substrates.jax`, but I've encountered some issues. While using `jax` with `jit` for NUTS alone, the performance is quite fast. However,...