BUG: Models that include Dirichlet and another distribution do not respect floatX

Status: Open. tvwenger opened this issue 1 year ago (0 comments).

Describe the issue:

This may be related to https://github.com/pymc-devs/pymc/issues/6779

Description:

When a model created with floatX="float32" contains only a Dirichlet distribution, the floatX setting is respected. When the model contains a Dirichlet distribution as well as another distribution, however, the floatX setting is NOT respected, and the problem only shows up upon sampling. This is a weird bug.

Expected Behavior

The model should respect floatX in all cases.

Actual Behavior

When the model includes a Dirichlet distribution and ANY other distribution, a float64 variable is created in the graph during sampling, even though floatX="float32" was requested.

Minimum Working Example

In the following MWE, I create four models. The first has one Dirichlet distribution, the second has one Normal distribution, and the remaining two include a Dirichlet distribution and then either a Normal or HalfCauchy distribution.

The first two models sample without issue, and floatX is respected.

The third and fourth models raise float64 errors during sampling. The error appears after model.point_logps(), which was the only thing checked in https://github.com/pymc-devs/pymc/issues/6779

The output (with truncated error messages) is appended below:

```shell
pytensor version: 2.18.6
pymc version: 5.10.4
pytensor.config.floatX =  float64

test_dirichlet
pytensor.config.floatX =  float32
foo float32
{'foo': -1.5}
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [foo]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 2 seconds.

test_normal
pytensor.config.floatX =  float32
foo float32
{'foo': -0.92}
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [foo]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 2 seconds.

test_dirichlet_normal
pytensor.config.floatX =  float32
foo float32
bar float32
{'foo': -1.5, 'bar': -0.92}
ERROR (pytensor.graph.rewriting.basic): Rewrite failure due to: local_sum_make_vector
ERROR (pytensor.graph.rewriting.basic): node: Sum{axes=None}(MakeVector{dtype='float32'}.0)
ERROR (pytensor.graph.rewriting.basic): TRACEBACK:
ERROR (pytensor.graph.rewriting.basic): Traceback (most recent call last):
...
Exception: You are creating a TensorVariable with float64 dtype. You requested an action via the PyTensor flag warn_float64={ignore,warn,raise,pdb}.

Multiprocess sampling (4 chains in 4 jobs)
NUTS: [foo, bar]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 3 seconds.

test_dirichlet_halfcauchy
pytensor.config.floatX =  float32
foo float32
bar float32
{'foo': -1.5, 'bar': -1.14}
ERROR (pytensor.graph.rewriting.basic): Rewrite failure due to: local_sum_make_vector
ERROR (pytensor.graph.rewriting.basic): node: Sum{axes=None}(MakeVector{dtype='float32'}.0)
ERROR (pytensor.graph.rewriting.basic): TRACEBACK:
ERROR (pytensor.graph.rewriting.basic): Traceback (most recent call last):
...
Exception: You are creating a TensorVariable with float64 dtype. You requested an action via the PyTensor flag warn_float64={ignore,warn,raise,pdb}.

Multiprocess sampling (4 chains in 4 jobs)
NUTS: [foo, bar]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 3 seconds.
```

Note that print(model.point_logps()) succeeds for every model, so the error does not occur in model.point_logps(); it occurs later, during sampling.

Reproducible code example:

```python
import pytensor
import pytensor.tensor as pt
import pymc as pm

print("pytensor version:", pytensor.__version__)
print("pymc version:", pm.__version__)

# Global default, before any change_flags context is entered
print("pytensor.config.floatX = ", pytensor.config.floatX)
print()


def test_dirichlet():
    # Model 1: Dirichlet only (samples without a float64 error)
    print("test_dirichlet")
    print("pytensor.config.floatX = ", pytensor.config.floatX)
    with pm.Model() as model:
        foo = pm.Dirichlet("foo", a=pt.ones(3))
        print(foo, foo.dtype)
    print(model.point_logps())
    with model:
        trace = pm.sample()
    print()


def test_normal():
    # Model 2: Normal only (samples without a float64 error)
    print("test_normal")
    print("pytensor.config.floatX = ", pytensor.config.floatX)
    with pm.Model() as model:
        foo = pm.Normal("foo", mu=0.0, sigma=1.0)
        print(foo, foo.dtype)
    print(model.point_logps())
    with model:
        trace = pm.sample()
    print()


def test_dirichlet_normal():
    # Model 3: Dirichlet + Normal (point_logps() succeeds, sampling hits float64)
    print("test_dirichlet_normal")
    print("pytensor.config.floatX = ", pytensor.config.floatX)
    with pm.Model() as model:
        foo = pm.Dirichlet("foo", a=pt.ones(3))
        print(foo, foo.dtype)
        bar = pm.Normal("bar", mu=0.0, sigma=1.0)
        print(bar, bar.dtype)
    print(model.point_logps())
    with model:
        trace = pm.sample()
    print()


def test_dirichlet_halfcauchy():
    # Model 4: Dirichlet + HalfCauchy (point_logps() succeeds, sampling hits float64)
    print("test_dirichlet_halfcauchy")
    print("pytensor.config.floatX = ", pytensor.config.floatX)
    with pm.Model() as model:
        foo = pm.Dirichlet("foo", a=pt.ones(3))
        print(foo, foo.dtype)
        bar = pm.HalfCauchy("bar", beta=1.0)
        print(bar, bar.dtype)
    print(model.point_logps())
    with model:
        trace = pm.sample()
    print()


# warn_float64="raise" turns the creation of any float64 variable into an exception
with pytensor.config.change_flags(warn_float64="raise", floatX="float32"):
    test_dirichlet()

with pytensor.config.change_flags(warn_float64="raise", floatX="float32"):
    test_normal()

with pytensor.config.change_flags(warn_float64="raise", floatX="float32"):
    test_dirichlet_normal()

with pytensor.config.change_flags(warn_float64="raise", floatX="float32"):
    test_dirichlet_halfcauchy()
```

Error message:

```shell
ERROR (pytensor.graph.rewriting.basic): Rewrite failure due to: local_sum_make_vector
ERROR (pytensor.graph.rewriting.basic): node: Sum{axes=None}(MakeVector{dtype='float32'}.0)
ERROR (pytensor.graph.rewriting.basic): TRACEBACK:
ERROR (pytensor.graph.rewriting.basic): Traceback (most recent call last):
  File "/home/twenger/miniconda3/envs/pymc/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1922, in process_node
    replacements = node_rewriter.transform(fgraph, node)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/twenger/miniconda3/envs/pymc/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1082, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/twenger/miniconda3/envs/pymc/lib/python3.11/site-packages/pytensor/tensor/rewriting/basic.py", line 950, in local_sum_make_vector
    add(*[cast(value, acc_dtype) for value in elements]), out_dtype
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/twenger/miniconda3/envs/pymc/lib/python3.11/site-packages/pytensor/tensor/rewriting/basic.py", line 950, in
    add(*[cast(value, acc_dtype) for value in elements]), out_dtype
          ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/twenger/miniconda3/envs/pymc/lib/python3.11/site-packages/pytensor/tensor/basic.py", line 763, in cast
    return _cast_mapping[dtype_name](x)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/twenger/miniconda3/envs/pymc/lib/python3.11/site-packages/pytensor/graph/op.py", line 295, in __call__
    node = self.make_node(*inputs, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/twenger/miniconda3/envs/pymc/lib/python3.11/site-packages/pytensor/tensor/elemwise.py", line 484, in make_node
    outputs = [
              ^
  File "/home/twenger/miniconda3/envs/pymc/lib/python3.11/site-packages/pytensor/tensor/elemwise.py", line 485, in
    TensorType(dtype=dtype, shape=shape)()
  File "/home/twenger/miniconda3/envs/pymc/lib/python3.11/site-packages/pytensor/graph/type.py", line 228, in __call__
    return utils.add_tag_trace(self.make_variable(name))
                               ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/twenger/miniconda3/envs/pymc/lib/python3.11/site-packages/pytensor/graph/type.py", line 200, in make_variable
    return self.variable_type(self, None, name=name)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/twenger/miniconda3/envs/pymc/lib/python3.11/site-packages/pytensor/tensor/variable.py", line 900, in __init__
    raise Exception(msg)
Exception: You are creating a TensorVariable with float64 dtype. You requested an action via the PyTensor flag warn_float64={ignore,warn,raise,pdb}.
```
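The traceback pins down where the float64 variable comes from. The failing node is Sum{axes=None}(MakeVector{dtype='float32'}.0), and the rewrite local_sum_make_vector replaces it with code that, judging from the source line shown, casts each element to acc_dtype, adds the results, and casts the sum to out_dtype. The frame that raises is cast(value, acc_dtype), so acc_dtype must be float64 here even though the MakeVector input is float32. The plain-Python toy below models only that dtype flow under this reading of the traceback. All of its names (WARN_FLOAT64, Float64Error, make_variable, cast, add, rewrite_sum_make_vector) are hypothetical stand-ins, not PyTensor's API, and the toy does not explain why acc_dtype is float64 in the first place.

```python
# Toy model of the dtype flow through the local_sum_make_vector rewrite.
# Every name here is hypothetical; none of this is PyTensor's real API.

WARN_FLOAT64 = "raise"  # stands in for PyTensor's warn_float64 flag


class Float64Error(Exception):
    """Raised when a float64 variable is created while WARN_FLOAT64 == "raise"."""


def make_variable(dtype):
    # Mimics the guard seen in the traceback: creating a float64 variable raises.
    if dtype == "float64" and WARN_FLOAT64 == "raise":
        raise Float64Error("You are creating a TensorVariable with float64 dtype.")
    return {"dtype": dtype}


def cast(var, dtype):
    # In this toy, a cast always builds a new variable of the target dtype.
    return make_variable(dtype)


def add(*variables):
    # Assumes all inputs already share one dtype (true after the casts below).
    return make_variable(variables[0]["dtype"])


def rewrite_sum_make_vector(elements, acc_dtype, out_dtype):
    # Same structure as the source line in the traceback:
    # cast each element to acc_dtype, add them, then cast the sum to out_dtype.
    return cast(add(*[cast(value, acc_dtype) for value in elements]), out_dtype)


elements = [make_variable("float32") for _ in range(2)]

# float32 accumulator: no float64 variable is ever created.
print(rewrite_sum_make_vector(elements, "float32", "float32")["dtype"])  # float32

# float64 accumulator: the guard fires on the intermediate cast,
# even though both the inputs and the requested output are float32.
try:
    rewrite_sum_make_vector(elements, "float64", "float32")
except Float64Error as err:
    print("raised:", err)
```

In this toy, the float32 input and float32 output are not enough to avoid the error: any float64 intermediate trips the guard, which is consistent with warn_float64="raise" firing inside the rewrite rather than at model construction.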

PyMC version information:

pytensor version: 2.18.6
pymc version: 5.10.4

Context for the issue:

Models that include a Dirichlet distribution as well as any other distribution cannot use float32.

tvwenger, Feb 21 '24 20:02