BUG: Models that include Dirichlet and another distribution do not respect floatX
Describe the issue:
This may be related to https://github.com/pymc-devs/pymc/issues/6779
Description:
When creating a model with floatX="float32" that includes only a Dirichlet distribution, the floatX setting is respected. When a model includes a Dirichlet distribution together with any other distribution, however, floatX is NOT respected, and the violation only surfaces during sampling.
Expected Behavior
The model should respect floatX in all cases.
Actual Behavior
When the model includes a Dirichlet distribution and then ANY other distribution, the graph contains float64 variables despite floatX="float32".
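Condensed to its essence (the full MWE follows below), the failing combination is:

# Minimal trigger, condensed from the full MWE below.
import pytensor
import pytensor.tensor as pt
import pymc as pm

with pytensor.config.change_flags(warn_float64="raise", floatX="float32"):
    with pm.Model() as model:
        pm.Dirichlet("foo", a=pt.ones(3))
        pm.Normal("bar", mu=0.0, sigma=1.0)  # any second distribution, e.g. HalfCauchy
        pm.sample()  # float64 errors are reported here, during sampling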
Minimum Working Example
In the following MWE, I create four models. The first has one Dirichlet distribution, the second has one Normal distribution, and the remaining two include a Dirichlet distribution and then either a Normal or HalfCauchy distribution.
The first two models sample without issue, and floatX is respected.
The last two models raise float64 errors during sampling. The errors appear after model.point_logps(), which is all that was checked in https://github.com/pymc-devs/pymc/issues/6779
The output (with truncated error messages) is appended below:
pytensor version: 2.18.6
pymc version: 5.10.4
pytensor.config.floatX = float64
test_dirichlet
pytensor.config.floatX = float32
foo float32
{'foo': -1.5}
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [foo]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 2 seconds.
test_normal
pytensor.config.floatX = float32
foo float32
{'foo': -0.92}
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [foo]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 2 seconds.
test_dirichlet_normal
pytensor.config.floatX = float32
foo float32
bar float32
{'foo': -1.5, 'bar': -0.92}
ERROR (pytensor.graph.rewriting.basic): Rewrite failure due to: local_sum_make_vector
ERROR (pytensor.graph.rewriting.basic): node: Sum{axes=None}(MakeVector{dtype='float32'}.0)
ERROR (pytensor.graph.rewriting.basic): TRACEBACK:
ERROR (pytensor.graph.rewriting.basic): Traceback (most recent call last):
...
Exception: You are creating a TensorVariable with float64 dtype. You requested an action via the PyTensor flag warn_float64={ignore,warn,raise,pdb}.
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [foo, bar]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 3 seconds.
test_dirichlet_halfcauchy
pytensor.config.floatX = float32
foo float32
bar float32
{'foo': -1.5, 'bar': -1.14}
ERROR (pytensor.graph.rewriting.basic): Rewrite failure due to: local_sum_make_vector
ERROR (pytensor.graph.rewriting.basic): node: Sum{axes=None}(MakeVector{dtype='float32'}.0)
ERROR (pytensor.graph.rewriting.basic): TRACEBACK:
ERROR (pytensor.graph.rewriting.basic): Traceback (most recent call last):
...
Exception: You are creating a TensorVariable with float64 dtype. You requested an action via the PyTensor flag warn_float64={ignore,warn,raise,pdb}.
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [foo, bar]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 3 seconds.
Note that print(model.point_logps()) succeeds in every case, which shows the error occurs after model.point_logps() is evaluated: it only appears during sampling.
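To narrow this down without running full MCMC (my own check, not part of the original logs), one could compile the model's logp and gradient separately. The assumption here is that the float64 graph appears while compiling the gradient that NUTS needs, since point_logps() itself is clean:

# Hedged sketch: assumes the failure lives in the compiled dlogp graph,
# not in the plain logp that point_logps() evaluates.
with pytensor.config.change_flags(warn_float64="raise", floatX="float32"):
    with pm.Model() as model:
        pm.Dirichlet("foo", a=pt.ones(3))
        pm.Normal("bar", mu=0.0, sigma=1.0)
    ip = model.initial_point()
    model.compile_logp()(ip)   # expected to pass, consistent with point_logps()
    model.compile_dlogp()(ip)  # expected to surface the float64 rewrite errors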
Reproducible code example:
import pytensor
import pytensor.tensor as pt
import pymc as pm

print("pytensor version:", pytensor.__version__)
print("pymc version:", pm.__version__)
print("pytensor.config.floatX = ", pytensor.config.floatX)
print()

def test_dirichlet():
    print("test_dirichlet")
    print("pytensor.config.floatX = ", pytensor.config.floatX)
    with pm.Model() as model:
        foo = pm.Dirichlet("foo", a=pt.ones(3))
        print(foo, foo.dtype)
    print(model.point_logps())
    with model:
        trace = pm.sample()
    print()

def test_normal():
    print("test_normal")
    print("pytensor.config.floatX = ", pytensor.config.floatX)
    with pm.Model() as model:
        foo = pm.Normal("foo", mu=0.0, sigma=1.0)
        print(foo, foo.dtype)
    print(model.point_logps())
    with model:
        trace = pm.sample()
    print()

def test_dirichlet_normal():
    print("test_dirichlet_normal")
    print("pytensor.config.floatX = ", pytensor.config.floatX)
    with pm.Model() as model:
        foo = pm.Dirichlet("foo", a=pt.ones(3))
        print(foo, foo.dtype)
        bar = pm.Normal("bar", mu=0.0, sigma=1.0)
        print(bar, bar.dtype)
    print(model.point_logps())
    with model:
        trace = pm.sample()
    print()

def test_dirichlet_halfcauchy():
    print("test_dirichlet_halfcauchy")
    print("pytensor.config.floatX = ", pytensor.config.floatX)
    with pm.Model() as model:
        foo = pm.Dirichlet("foo", a=pt.ones(3))
        print(foo, foo.dtype)
        bar = pm.HalfCauchy("bar", beta=1.0)
        print(bar, bar.dtype)
    print(model.point_logps())
    with model:
        trace = pm.sample()
    print()

with pytensor.config.change_flags(warn_float64="raise", floatX="float32"):
    test_dirichlet()

with pytensor.config.change_flags(warn_float64="raise", floatX="float32"):
    test_normal()

with pytensor.config.change_flags(warn_float64="raise", floatX="float32"):
    test_dirichlet_normal()

with pytensor.config.change_flags(warn_float64="raise", floatX="float32"):
    test_dirichlet_halfcauchy()
Context for the issue:
Models that include a Dirichlet distribution as well as any other distribution cannot use float32.
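One way to confirm the leak directly (my own sketch, not from the report; the .f attribute of PyMC's PointFunc wrapper and fn.maker.fgraph are assumptions about internals) is to list float64 variables in the rewritten gradient graph:

# Illustrative diagnostic: inspect the rewritten dlogp graph for
# float64 variables in a model built under floatX="float32".
with pytensor.config.change_flags(floatX="float32"):
    with pm.Model() as model:
        pm.Dirichlet("foo", a=pt.ones(3))
        pm.Normal("bar", mu=0.0, sigma=1.0)
    dlogp_fn = model.compile_dlogp()
    # Assumption: PointFunc stores the underlying pytensor Function in .f
    fgraph = dlogp_fn.f.maker.fgraph
    offenders = [v for v in fgraph.variables
                 if getattr(v.type, "dtype", None) == "float64"]
    print(offenders)  # non-empty output would confirm the float64 leak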