pymc4
Numerical instabilities with tf.float32 for Cholesky decomposition
Carrying forward the discussion in #289, it would be great to investigate the Cholesky decomposition errors with tf.float32 further. Here's a good starting point:
```python
import tensorflow as tf

tf.random.set_seed(42)

def cov(shape, dtype):
    sigma = tf.random.normal((shape, shape), dtype=dtype)
    # sigma^T sigma is symmetric positive semi-definite by construction
    return tf.linalg.cholesky(tf.matmul(sigma, sigma, transpose_a=True))

cov(10**4, tf.float32)  # decomposition errors
cov(10**4, tf.float64)  # works fine
```
Should I open this issue on the TF side?
Your normal distribution is centered at 0, so the diagonal entries are likely to be near zero, which presents challenges for the float32 datatype. If you center the data differently, the results may change. I tried it and got a success rate of about 70% on random data of shape 5000 x 5000 over 100 experiments.
Here's the code:
```python
import sys

import tensorflow as tf

def run(N, nrep, dtype):
    sr = 0  # number of successful decompositions
    for _ in range(nrep):
        # Shift by the identity so the diagonal is centered away from zero
        a = tf.random.normal((N, N), dtype=dtype) + tf.linalg.eye(N, dtype=dtype)
        a = tf.linalg.matmul(a, a, transpose_a=True)
        try:
            _ = tf.linalg.cholesky(a)
            sr += 1
            print(".", end="")
        except tf.errors.InvalidArgumentError:
            print("x", end="")
    print()
    return sr / nrep

if __name__ == "__main__":
    if len(sys.argv) != 4:
        print("usage: python testing.py <N> <nrep> <dtype>")
    else:
        dtype = tf.float32 if sys.argv[3] == 'float32' else tf.float64
        print("success rate:", run(int(sys.argv[1]), int(sys.argv[2]), dtype))
```
This is not specific to TensorFlow, though; I also ran the same experiment with NumPy and got about 71% success. So, float32 datatype is generally not a good choice for large data...
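For reference, a NumPy version of the same experiment might look like this (a sketch; the `success_rate` name, the seed, and the default sizes are my own, not from the thread):

```python
import numpy as np

def success_rate(n, nrep, dtype, seed=0):
    """Fraction of random shifted-Gram matrices whose Cholesky succeeds."""
    rng = np.random.default_rng(seed)
    ok = 0
    for _ in range(nrep):
        # Same construction as the TF experiment: shift by I, then form A^T A
        a = rng.standard_normal((n, n)).astype(dtype) + np.eye(n, dtype=dtype)
        a = a.T @ a
        try:
            np.linalg.cholesky(a)
            ok += 1
        except np.linalg.LinAlgError:
            pass
    return ok / nrep
```

Small, well-conditioned matrices should succeed essentially always in float64; the failures only start to appear at large sizes in float32.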
GPy and other libraries just keep adding noise (jitter) to the diagonal until the Cholesky decomposition passes, so I don't think there is a better solution to this problem. It is worth some discussion, though...
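For concreteness, that jitter strategy can be sketched like this (a hypothetical helper; the starting jitter of 1e-6 times the mean diagonal and the retry count are assumptions, not GPy's exact values):

```python
import numpy as np

def jittered_cholesky(a, max_tries=6):
    """Retry Cholesky with increasing diagonal jitter until it succeeds."""
    jitter = np.mean(np.diag(a)) * 1e-6  # start small, relative to A's scale
    for _ in range(max_tries):
        try:
            return np.linalg.cholesky(a + jitter * np.eye(a.shape[0], dtype=a.dtype))
        except np.linalg.LinAlgError:
            jitter *= 10  # escalate and retry
    raise np.linalg.LinAlgError("Cholesky failed even with jitter")
```

Each failed attempt multiplies the jitter by 10, trading a small bias on the diagonal for a factorization that actually completes.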
> So, float32 datatype is generally not a good choice for large data...
That's a good observation. I think I should add a note about this in the Notes section for Full-Rank ADVI. Thanks @tirthasheshpatel for providing more insights.