Advanced indexing with scalar booleans is broken
Describe the issue:
Calling pymc.draw() on a graph that contains a truncated distribution emits PyTensor warnings about failed shape inference when no shape information is given on the PyMC side. The draws are still returned. Adding shape parameters to the PyMC code makes the warnings go away.
Reproducible code example:
import pymc as pm

with pm.Model() as m:
    θ = pm.Bernoulli("θ", p=0.5)
    days = pm.Truncated("days", pm.Binomial.dist(n=7, p=0.5), lower=1)
    observed_days = θ * days
    draws = pm.draw([θ, observed_days], draws=100)
Error message:
/home/aurimas/micromamba/envs/pymc5/lib/python3.11/site-packages/pytensor/tensor/rewriting/shape.py:157: UserWarning: Failed to infer_shape from Op AdvancedSubtensor.
Input shapes: [(), ()]
Exception encountered during infer_shape: <class 'ValueError'>
Exception message: Nonzero only supports non-scalar arrays.
Traceback: Traceback (most recent call last):
File "/home/aurimas/micromamba/envs/pymc5/lib/python3.11/site-packages/pytensor/tensor/rewriting/shape.py", line 133, in get_node_infer_shape
o_shapes = shape_infer(
^^^^^^^^^^^^
File "/home/aurimas/micromamba/envs/pymc5/lib/python3.11/site-packages/pytensor/tensor/subtensor.py", line 2628, in infer_shape
(shape_i(nz_dim, 0, fgraph=fgraph),) for nz_dim in nonzero(idx)
^^^^^^^^^^^^
File "/home/aurimas/micromamba/envs/pymc5/lib/python3.11/site-packages/pytensor/tensor/basic.py", line 935, in nonzero
res = _nonzero(a)
^^^^^^^^^^^
File "/home/aurimas/micromamba/envs/pymc5/lib/python3.11/site-packages/pytensor/graph/op.py", line 304, in __call__
node = self.make_node(*inputs, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/aurimas/micromamba/envs/pymc5/lib/python3.11/site-packages/pytensor/tensor/basic.py", line 884, in make_node
raise ValueError("Nonzero only supports non-scalar arrays.")
ValueError: Nonzero only supports non-scalar arrays.
PyTensor version information:
pytensor==2.16.1, pymc==5.8.1
Apparently numpy behavior with scalar booleans is pretty odd, and PyTensor never really supported it (not just infer_shape). We should raise a NotImplementedError in make_node.
https://stackoverflow.com/questions/75828008/docs-about-boolean-scalars-in-indexing-of-numpy-array
https://stackoverflow.com/questions/45493270/what-does-xfalse-do-in-numpy
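For reference, a minimal NumPy-only sketch of the behavior those links discuss. It only shows what NumPy itself returns for a 0-d boolean index and says nothing about how PyTensor should handle it:

```python
import numpy as np

x = np.arange(10)

# A 0-d boolean index does not pick elements out of x. NumPy prepends a
# new axis instead, with length 1 for True and length 0 for False.
picked = x[np.array(True)]
dropped = x[np.array(False)]

print(picked.shape)   # (1, 10)
print(dropped.shape)  # (0, 10)
```

So the result shape is `(int(b),) + x.shape` for a scalar boolean `b`, which is neither a scalar nor a 1-d selection.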
JAX doesn't support boolean scalar indices either
The PyMC issue should be fixed by https://github.com/pymc-devs/pymc/pull/6923
Pure PyTensor example of the problem:
import pytensor
import numpy as np
import pytensor.tensor as pt
x = pt.arange(10)
y = x[np.array(True)]
y.type # TensorType(int64, shape=())
It thinks it should be a scalar output, but it's not: the evaluated result has shape (1, 10), as the constant_folding error below shows. NumPy is doing something weird with scalar boolean indices. We either figure out the rule or raise NotImplementedError.
If you then eval, there will be a lot of failures:
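A rough sketch of what the "raise NotImplementedError" option could look like, written against plain NumPy values. The helper names `is_scalar_bool_index` and `check_index` are hypothetical and only for illustration. This is not PyTensor's actual `make_node` code, which would have to inspect symbolic variable types rather than concrete arrays:

```python
import numpy as np


def is_scalar_bool_index(idx):
    """Return True for Python/NumPy bool scalars and 0-d boolean arrays."""
    if isinstance(idx, (bool, np.bool_)):
        return True
    return isinstance(idx, np.ndarray) and idx.ndim == 0 and idx.dtype == np.bool_


def check_index(idx):
    """Hypothetical guard: reject scalar boolean indices, pass others through."""
    if is_scalar_bool_index(idx):
        raise NotImplementedError("Boolean scalar indices are not supported")
    return idx


# Non-scalar boolean masks and ordinary indices pass through unchanged.
mask = check_index(np.array([True, False, True]))
position = check_index(3)

# A 0-d boolean index is rejected up front instead of failing later
# during shape inference or constant folding.
try:
    check_index(np.array(True))
except NotImplementedError as err:
    message = str(err)
```

Failing early like this trades NumPy compatibility for a clear error, which also matches the JAX behavior mentioned above.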
y.eval()
/home/ricardo/Documents/Projects/pytensor/pytensor/tensor/rewriting/shape.py:157: UserWarning: Failed to infer_shape from Op AdvancedSubtensor.
Input shapes: [(Cast{int64}.0,), ()]
Exception encountered during infer_shape: <class 'ValueError'>
Exception message: Nonzero only supports non-scalar arrays.
Traceback: Traceback (most recent call last):
File "/home/ricardo/Documents/Projects/pytensor/pytensor/tensor/rewriting/shape.py", line 133, in get_node_infer_shape
o_shapes = shape_infer(
^^^^^^^^^^^^
File "/home/ricardo/Documents/Projects/pytensor/pytensor/tensor/subtensor.py", line 2626, in infer_shape
(shape_i(nz_dim, 0, fgraph=fgraph),) for nz_dim in nonzero(idx)
^^^^^^^^^^^^
File "/home/ricardo/Documents/Projects/pytensor/pytensor/tensor/basic.py", line 935, in nonzero
res = _nonzero(a)
^^^^^^^^^^^
File "/home/ricardo/Documents/Projects/pytensor/pytensor/graph/op.py", line 293, in __call__
node = self.make_node(*inputs, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ricardo/Documents/Projects/pytensor/pytensor/tensor/basic.py", line 884, in make_node
raise ValueError("Nonzero only supports non-scalar arrays.")
ValueError: Nonzero only supports non-scalar arrays.
warn(msg)
<<!! BUG IN FGRAPH.REPLACE OR A LISTENER !!>> <class 'TypeError'> Cannot convert Type Matrix(int64, shape=(1, 10)) (of Variable [[0 1 2 3 ... 6 7 8 9]]) into Type Scalar(int64, shape=()). You can try to manually convert [[0 1 2 3 ... 6 7 8 9]] into a Scalar(int64, shape=()). constant_folding
ERROR (pytensor.graph.rewriting.basic): Rewrite failure due to: constant_folding
ERROR (pytensor.graph.rewriting.basic): node: AdvancedSubtensor([0 1 2 3 4 ... 5 6 7 8 9], True)
ERROR (pytensor.graph.rewriting.basic): TRACEBACK:
ERROR (pytensor.graph.rewriting.basic): Traceback (most recent call last):
File "/home/ricardo/Documents/Projects/pytensor/pytensor/graph/rewriting/basic.py", line 1966, in process_node
fgraph.replace_all_validate_remove( # type: ignore
File "/home/ricardo/Documents/Projects/pytensor/pytensor/graph/features.py", line 628, in replace_all_validate_remove
chk = fgraph.replace_all_validate(replacements, reason=reason, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ricardo/Documents/Projects/pytensor/pytensor/graph/features.py", line 573, in replace_all_validate
fgraph.replace(r, new_r, reason=reason, verbose=False, **kwargs)
File "/home/ricardo/Documents/Projects/pytensor/pytensor/graph/fg.py", line 481, in replace
new_var = var.type.filter_variable(new_var, allow_convert=True)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ricardo/Documents/Projects/pytensor/pytensor/tensor/type.py", line 278, in filter_variable
raise TypeError(
TypeError: Cannot convert Type Matrix(int64, shape=(1, 10)) (of Variable [[0 1 2 3 ... 6 7 8 9]]) into Type Scalar(int64, shape=()). You can try to manually convert [[0 1 2 3 ... 6 7 8 9]] into a Scalar(int64, shape=()).
[the same constant_folding rewrite failure and traceback repeat five more times]