BUG: C backend fails to handle Elemwise ScalarLoop code
Describe the issue:
I wrote a small example to exercise the Elemwise + ScalarLoop pattern. The code computes the (n + 1)th Fibonacci number. With linker=py it works fine, but with the default cvm linker I get Scalar check failed (npy_int32).
Reproducible code example:
import numpy as np
import pytensor
from pytensor.compile.io import In
from pytensor.scalar import ScalarLoop, float32, int32
from pytensor.tensor.elemwise import Elemwise
n_steps = int32("n_steps")
f0 = float32("f0")
f1 = float32("f1")
end = float32("end")
i = float32("i")
op = ScalarLoop(
    init=[f0, f1, end, i],
    update=[
        pytensor.scalar.basic.identity(f1),
        f0 + f1,
        pytensor.scalar.basic.identity(end),
        i + 1,
    ],
    until=i >= end,
)
e = Elemwise(op)
_, p, _, _, done = e(n_steps, f0, f1, end, i)
fn = pytensor.function(
    [n_steps, end, In(f0, value=0), In(f1, value=1), In(i, value=2)],
    [p, done],
)
print(fn(np.array([7.0]), np.array([10.0, 5.0, 6.0])))
Error message:
(pytensor-dev) ✘ ch0ronomato@macbook-pro ~/dev/pytensor scalarloop PYTENSOR_FLAGS="optimizer=None" python example.py
Traceback (most recent call last):
File "/Users/ch0ronomato/dev/pytensor/pytensor/compile/function/types.py", line 959, in __call__
self.vm()
ValueError: Scalar check failed (npy_int32)
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/Users/ch0ronomato/dev/pytensor/example.py", line 30, in <module>
print(fn([7.0], [10.0, 5.0]))
^^^^^^^^^^^^^^^^^^^^^^
File "/Users/ch0ronomato/dev/pytensor/pytensor/compile/function/types.py", line 972, in __call__
raise_with_op(
File "/Users/ch0ronomato/dev/pytensor/pytensor/link/utils.py", line 524, in raise_with_op
raise exc_value.with_traceback(exc_trace)
File "/Users/ch0ronomato/dev/pytensor/pytensor/compile/function/types.py", line 959, in __call__
self.vm()
ValueError: Scalar check failed (npy_int32)
Apply node that caused the error: TensorFromScalar(n_steps)
Toposort index: 4
Inputs types: [ScalarType(int32)]
Inputs shapes: [(1,)]
Inputs strides: [(4,)]
Inputs values: [array([7], dtype=int32)]
Outputs clients: [[Scalarloop(TensorFromScalar.0, TensorFromScalar.0, TensorFromScalar.0, TensorFromScalar.0, TensorFromScalar.0)]]
Backtrace when the node is created (use PyTensor flag traceback__limit=N to make it longer):
File "/Users/ch0ronomato/dev/pytensor/example.py", line 22, in <module>
_, p, _, _, done = e(n_steps, f0, f1, end, i)
File "/Users/ch0ronomato/dev/pytensor/pytensor/graph/op.py", line 293, in __call__
node = self.make_node(*inputs, **kwargs)
File "/Users/ch0ronomato/dev/pytensor/pytensor/tensor/elemwise.py", line 491, in make_node
inputs = [as_tensor_variable(i) for i in inputs]
File "/Users/ch0ronomato/dev/pytensor/pytensor/tensor/elemwise.py", line 491, in <listcomp>
inputs = [as_tensor_variable(i) for i in inputs]
File "/Users/ch0ronomato/dev/pytensor/pytensor/tensor/__init__.py", line 50, in as_tensor_variable
return _as_tensor_variable(x, name, ndim, **kwargs)
File "/opt/anaconda3/envs/pytensor-dev/lib/python3.11/functools.py", line 909, in wrapper
return dispatch(args[0].__class__)(*args, **kw)
PyTensor version information:
Python 3.11.9; cloned from repo and ran. All config below.
Note: float16 support is experimental, use at your own risk. Value: float64
warn_float64 ({'ignore', 'raise', 'warn', 'pdb'}) Doc: Do an action when a tensor variable with float64 dtype is created. Value: ignore
pickle_test_value (<bound method BoolParam._apply of <pytensor.configparser.BoolParam object at 0x1020bee50>>) Doc: Dump test values while pickling model. If True, test values will be dumped with model. Value: True
cast_policy ({'custom', 'numpy+floatX'}) Doc: Rules for implicit type casting Value: custom
device (cpu) Doc: Default device for computations. only cpu is supported for now Value: cpu
conv__assert_shape (<bound method BoolParam._apply of <pytensor.configparser.BoolParam object at 0x1027ce350>>) Doc: If True, AbstractConv* ops will verify that user-provided shapes match the runtime shapes (debugging option, may slow down compilation) Value: False
print_global_stats (<bound method BoolParam._apply of <pytensor.configparser.BoolParam object at 0x10570a250>>) Doc: Print some global statistics (time spent) at the end Value: False
unpickle_function (<bound method BoolParam._apply of <pytensor.configparser.BoolParam object at 0x105708e90>>) Doc: Replace unpickled PyTensor functions with None. This is useful to unpickle old graphs that pickled them when it shouldn't Value: True
<pytensor.configparser.ConfigParam object at 0x105708ed0> Doc: Default compilation mode Value: Mode
cxx (<class 'str'>) Doc: The C++ compiler to use. Currently only g++ is supported, but supporting additional compilers should not be too difficult. If it is empty, no C++ code is compiled. Value: /opt/anaconda3/envs/pytensor-dev/bin/clang++
linker ({'vm_nogc', 'c|py_nogc', 'vm', 'c', 'cvm', 'c|py', 'cvm_nogc', 'py'}) Doc: Default linker used if the pytensor flags mode is Mode Value: cvm
allow_gc (<bound method BoolParam._apply of <pytensor.configparser.BoolParam object at 0x105708090>>) Doc: Do we default to delete intermediate results during PyTensor function calls? Doing so lowers the memory requirement, but asks that we reallocate memory at the next function call. This is implemented for the default linker, but may not work for all linkers. Value: True
optimizer ({'o3', 'o4', 'fast_run', 'fast_compile', 'None', 'o1', 'o2', 'unsafe', 'merge'}) Doc: Default optimizer. If not None, will use this optimizer with the Mode Value: o4
optimizer_verbose (<bound method BoolParam._apply of <pytensor.configparser.BoolParam object at 0x102086d10>>) Doc: If True, we print all optimization being applied Value: False
on_opt_error ({'ignore', 'raise', 'warn', 'pdb'}) Doc: What to do when an optimization crashes: warn and skip it, raise the exception, or fall into the pdb debugger. Value: warn
nocleanup (<bound method BoolParam._apply of <pytensor.configparser.BoolParam object at 0x1056f9090>>) Doc: Suppress the deletion of code files that did not compile cleanly Value: False
on_unused_input ({'ignore', 'raise', 'warn'}) Doc: What to do if a variable in the 'inputs' list of pytensor.function() is not used in the graph. Value: raise
gcc__cxxflags (<class 'str'>) Doc: Extra compiler flags for gcc Value: -Wno-c++11-narrowing
cmodule__warn_no_version (<bound method BoolParam._apply of <pytensor.configparser.BoolParam object at 0x105708c90>>) Doc: If True, will print a warning when compiling one or more Op with C code that can't be cached because there is no c_code_cache_version() function associated to at least one of those Ops. Value: False
cmodule__remove_gxx_opt (<bound method BoolParam._apply of <pytensor.configparser.BoolParam object at 0x10570a710>>) Doc: If True, will remove the -O* parameter passed to g++.This is useful to debug in gdb modules compiled by PyTensor.The parameter -g is passed by default to g++ Value: False
cmodule__compilation_warning (<bound method BoolParam._apply of <pytensor.configparser.BoolParam object at 0x10570b110>>) Doc: If True, will print compilation warnings. Value: False
cmodule__preload_cache (<bound method BoolParam._apply of <pytensor.configparser.BoolParam object at 0x10570b090>>) Doc: If set to True, will preload the C module cache at import time Value: False
cmodule__age_thresh_use (<class 'int'>) Doc: In seconds. The time after which PyTensor won't reuse a compile c module. Value: 2073600
cmodule__debug (<bound method BoolParam._apply of <pytensor.configparser.BoolParam object at 0x101bd1610>>) Doc: If True, define a DEBUG macro (if not exists) for any compiled C code. Value: False
compile__wait (<class 'int'>) Doc: Time to wait before retrying to acquire the compile lock. Value: 5
compile__timeout (<class 'int'>) Doc: In seconds, time that a process will wait before deciding to override an existing lock. An override only happens when the existing lock is held by the same owner and has not been 'refreshed' by this owner for more than this period. Refreshes are done every half timeout period for running processes. Value: 120
tensor__cmp_sloppy (<class 'int'>) Doc: Relax pytensor.tensor.math._allclose (0) not at all, (1) a bit, (2) more Value: 0
lib__amdlibm (<bound method BoolParam._apply of <pytensor.configparser.BoolParam object at 0x102716390>>) Doc: Use amd's amdlibm numerical library Value: False
tensor__insert_inplace_optimizer_validate_nb (<class 'int'>) Doc: -1: auto, if graph have less then 500 nodes 1, else 10 Value: -1
traceback__limit (<class 'int'>) Doc: The number of stack to trace. -1 mean all. Value: 8
traceback__compile_limit (<class 'int'>) Doc: The number of stack to trace to keep during compilation. -1 mean all. If greater then 0, will also make us save PyTensor internal stack trace. Value: 0
warn__ignore_bug_before ({'0.5', '0.4', '0.6', '1.0.4', '0.9', '0.3', '1.0', '1.0.1', '1.0.3', '0.7', '0.8.2', '1.0.5', '0.10', '0.4.1', 'None', '0.8.1', 'all', '1.0.2', '0.8'}) Doc: If 'None', we warn about all PyTensor bugs found by default. If 'all', we don't warn about PyTensor bugs found by default. If a version, we print only the warnings relative to PyTensor bugs found after that version. Warning for specific bugs can be configured with specific [warn] flags. Value: 0.9
exception_verbosity ({'low', 'high'}) Doc: If 'low', the text of exceptions will generally refer to apply nodes with short names such as Elemwise{add_no_inplace}. If 'high', some exceptions will also refer to apply nodes with long descriptions like: A. Elemwise{add_no_inplace} B. log_likelihood_v_given_h C. log_likelihood_h Value: low
print_test_value (<bound method BoolParam._apply of <pytensor.configparser.BoolParam object at 0x1020de2d0>>)
Doc: If 'True', the eval of an PyTensor variable will return its test_value when this is available. This has the practical consequence that, e.g., in debugging my_var will print the same as my_var.tag.test_value when a test value is defined.
Value: False
compute_test_value ({'ignore', 'warn', 'pdb', 'raise', 'off'}) Doc: If 'True', PyTensor will run each op at graph build time, using Constants, SharedVariables and the tag 'test_value' as inputs to the function. This helps the user track down problems in the graph before it gets optimized. Value: off
compute_test_value_opt ({'ignore', 'warn', 'pdb', 'raise', 'off'}) Doc: For debugging PyTensor optimization only. Same as compute_test_value, but is used during PyTensor optimization Value: off
check_input (<bound method BoolParam._apply of <pytensor.configparser.BoolParam object at 0x10570b050>>) Doc: Specify if types should check their input in their C code. It can be used to speed up compilation, reduce overhead (particularly for scalars) and reduce the number of generated C files. Value: True
NanGuardMode__nan_is_error (<bound method BoolParam._apply of <pytensor.configparser.BoolParam object at 0x1024b6110>>) Doc: Default value for nan_is_error Value: True
NanGuardMode__inf_is_error (<bound method BoolParam._apply of <pytensor.configparser.BoolParam object at 0x10570b310>>) Doc: Default value for inf_is_error Value: True
NanGuardMode__big_is_error (<bound method BoolParam._apply of <pytensor.configparser.BoolParam object at 0x10570b2d0>>) Doc: Default value for big_is_error Value: True
NanGuardMode__action ({'raise', 'warn', 'pdb'}) Doc: What NanGuardMode does when it finds a problem Value: raise
DebugMode__patience (<class 'int'>) Doc: Optimize graph this many times to detect inconsistency Value: 10
DebugMode__check_c (<bound method BoolParam._apply of <pytensor.configparser.BoolParam object at 0x1025962d0>>) Doc: Run C implementations where possible Value: True
DebugMode__check_py (<bound method BoolParam._apply of <pytensor.configparser.BoolParam object at 0x10570b550>>) Doc: Run Python implementations where possible Value: True
DebugMode__check_finite (<bound method BoolParam._apply of <pytensor.configparser.BoolParam object at 0x102717910>>) Doc: True -> complain about NaN/Inf results Value: True
DebugMode__check_strides (<class 'int'>) Doc: Check that Python- and C-produced ndarrays have same strides. On difference: (0) - ignore, (1) warn, or (2) raise error Value: 0
DebugMode__warn_input_not_reused (<bound method BoolParam._apply of <pytensor.configparser.BoolParam object at 0x10570b750>>) Doc: Generate a warning when destroy_map or view_map says that an op works inplace, but the op did not reuse the input for its output. Value: True
DebugMode__check_preallocated_output (<class 'str'>) Doc: Test thunks with pre-allocated memory as output storage. This is a list of strings separated by ":". Valid values are: "initial" (initial storage in storage map, happens with Scan),"previous" (previously-returned memory), "c_contiguous", "f_contiguous", "strided" (positive and negative strides), "wrong_size" (larger and smaller dimensions), and "ALL" (all of the above). Value:
DebugMode__check_preallocated_output_ndim (<class 'int'>) Doc: When testing with "strided" preallocated output memory, test all combinations of strides over that number of (inner-most) dimensions. You may want to reduce that number to reduce memory or time usage, but it is advised to keep a minimum of 2. Value: 4
profiling__time_thunks (<bound method BoolParam._apply of <pytensor.configparser.BoolParam object at 0x10570b610>>) Doc: Time individual thunks when profiling Value: True
profiling__n_apply (<class 'int'>) Doc: Number of Apply instances to print by default Value: 20
profiling__n_ops (<class 'int'>) Doc: Number of Ops to print by default Value: 20
profiling__output_line_width (<class 'int'>) Doc: Max line width for the profiling output Value: 512
profiling__min_memory_size (<class 'int'>) Doc: For the memory profile, do not print Apply nodes if the size of their outputs (in bytes) is lower than this threshold Value: 1024
profiling__min_peak_memory (<bound method BoolParam._apply of <pytensor.configparser.BoolParam object at 0x102716050>>) Doc: The min peak memory usage of the order Value: False
profiling__destination (<class 'str'>) Doc: File destination of the profiling output Value: stderr
profiling__debugprint (<bound method BoolParam._apply of <pytensor.configparser.BoolParam object at 0x102717890>>) Doc: Do a debugprint of the profiled functions Value: False
profiling__ignore_first_call (<bound method BoolParam._apply of <pytensor.configparser.BoolParam object at 0x10570bad0>>) Doc: Do we ignore the first call of an PyTensor function. Value: False
on_shape_error ({'raise', 'warn'}) Doc: warn: print a warning and use the default value. raise: raise an error Value: warn
openmp (<bound method BoolParam._apply of <pytensor.configparser.BoolParam object at 0x1056c6350>>) Doc: Allow (or not) parallel computation on the CPU with OpenMP. This is the default value used when creating an Op that supports OpenMP parallelization. It is preferable to define it via the PyTensor configuration file ~/.pytensorrc or with the environment variable PYTENSOR_FLAGS. Parallelization is only done for some operations that implement it, and even for operations that implement parallelism, each operation is free to respect this flag or not. You can control the number of threads used with the environment variable OMP_NUM_THREADS. If it is set to 1, we disable openmp in PyTensor by default. Value: False
openmp_elemwise_minsize (<class 'int'>) Doc: If OpenMP is enabled, this is the minimum size of vectors for which the openmp parallelization is enabled in element wise ops. Value: 200000
optimizer_excluding (<class 'str'>) Doc: When using the default mode, we will remove optimizer with these tags. Separate tags with ':'. Value:
optimizer_including (<class 'str'>) Doc: When using the default mode, we will add optimizer with these tags. Separate tags with ':'. Value:
optimizer_requiring (<class 'str'>) Doc: When using the default mode, we will require optimizer with these tags. Separate tags with ':'. Value:
optdb__position_cutoff (<class 'float'>) Doc: Where to stop earlier during optimization. It represent the position of the optimizer where to stop. Value: inf
optdb__max_use_ratio (<class 'float'>) Doc: A ratio that prevent infinite loop in EquilibriumGraphRewriter. Value: 8.0
cycle_detection ({'regular', 'fast'}) Doc: If cycle_detection is set to regular, most inplaces are allowed,but it is slower. If cycle_detection is set to faster, less inplacesare allowed, but it makes the compilation faster.The interaction of which one give the lower peak memory usage iscomplicated and not predictable, so if you are close to the peakmemory usage, triyng both could give you a small gain. Value: regular
check_stack_trace ({'raise', 'warn', 'log', 'off'}) Doc: A flag for checking the stack trace during the optimization process. default (off): does not check the stack trace of any optimization log: inserts a dummy stack trace that identifies the optimizationthat inserted the variable that had an empty stack trace.warn: prints a warning if a stack trace is missing and also a dummystack trace is inserted that indicates which optimization insertedthe variable that had an empty stack trace.raise: raises an exception if a stack trace is missing Value: off
metaopt__verbose (<class 'int'>) Doc: 0 for silent, 1 for only warnings, 2 for full output withtimings and selected implementation Value: 0
unittests__rseed (<class 'str'>) Doc: Seed to use for randomized unit tests. Special value 'random' means using a seed of None. Value: 666
warn__round (<bound method BoolParam._apply of <pytensor.configparser.BoolParam object at 0x105718310>>)
Doc: Warn when using tensor.round with the default mode. Round changed its default from half_away_from_zero to half_to_even to have the same default as NumPy.
Value: False
profile (<bound method BoolParam._apply of <pytensor.configparser.BoolParam object at 0x105718410>>) Doc: If VM should collect profile information Value: False
profile_optimizer (<bound method BoolParam._apply of <pytensor.configparser.BoolParam object at 0x1057184d0>>) Doc: If VM should collect optimizer profile information Value: False
profile_memory (<bound method BoolParam._apply of <pytensor.configparser.BoolParam object at 0x105718490>>) Doc: If VM should collect memory profile information and print it Value: False
<pytensor.configparser.ConfigParam object at 0x1056bbb10> Doc: Useful only for the VM Linkers. When lazy is None, auto detect if lazy evaluation is needed and use the appropriate version. If the C loop isn't being used and lazy is True, use the Stack VM; otherwise, use the Loop VM. Value: None
numba__vectorize_target ({'cpu', 'parallel', 'cuda'}) Doc: Default target for numba.vectorize. Value: cpu
numba__fastmath (<bound method BoolParam._apply of <pytensor.configparser.BoolParam object at 0x1057186d0>>) Doc: If True, use Numba's fastmath mode. Value: True
numba__cache (<bound method BoolParam._apply of <pytensor.configparser.BoolParam object at 0x1026f9290>>) Doc: If True, use Numba's file based caching. Value: True
compiledir_format (<class 'str'>) Doc: Format string for platform-dependent compiled module subdirectory (relative to base_compiledir). Available keys: device, gxx_version, hostname, numpy_version, platform, processor, pytensor_version, python_bitwidth, python_int_bitwidth, python_version, short_platform. Defaults to compiledir_%(short_platform)s-%(processor)s- %(python_version)s-%(python_bitwidth)s. Value: compiledir_%(short_platform)s-%(processor)s-%(python_version)s-%(python_bitwidth)s
<pytensor.configparser.ConfigParam object at 0x101bd1110> Doc: platform-independent root directory for compiled modules Value: /Users/ch0ronomato/.pytensor
<pytensor.configparser.ConfigParam object at 0x1024c2690> Doc: platform-dependent cache directory for compiled modules Value: /Users/ch0ronomato/.pytensor/compiledir_macOS-14.5-x86_64-i386-64bit-i386-3.11.9-64
blas__ldflags (<class 'str'>) Doc: lib[s] to include for [Fortran] level-3 blas implementation Value: -L/opt/anaconda3/envs/pytensor-dev/lib -llapack -lblas -lcblas -lm -Wl,-rpath,/opt/anaconda3/envs/pytensor-dev/lib
blas__check_openmp (<bound method BoolParam._apply of <pytensor.configparser.BoolParam object at 0x105ad1bd0>>) Doc: Check for openmp library conflict. WARNING: Setting this to False leaves you open to wrong results in blas-related operations. Value: True
scan__allow_gc (<bound method BoolParam._apply of <pytensor.configparser.BoolParam object at 0x118628c10>>) Doc: Allow/disallow gc inside of Scan (default: False) Value: False
scan__allow_output_prealloc (<bound method BoolParam._apply of <pytensor.configparser.BoolParam object at 0x118379250>>) Doc: Allow/disallow memory preallocation for outputs inside of scan (default: True) Value: True
Context for the issue:
This doesn't particularly affect me. I want to make sure the PyTorch backend's ScalarLoop and Elemwise implementations work well; this seemed like a good way to test them, and I ran into this along the way.
Which version of numpy?
Numpy is
numpy 1.26.4 py310hd45542a_0 conda-forge
You had some problems in how you specified the graph. The error is a bit obtuse, but what happened is that you defined a PyTensor function from scalar inputs to scalar outputs (note that end is a float32 scalar, not a vector of float32). Instead, you want something like this:
import pytensor.tensor as pt
vector_end = pt.vector(dtype="float32")
_, p, _, _, done = e(n_steps, f0, f1, vector_end, i)
fn = pytensor.function(
    [n_steps, vector_end, In(f0, value=0), In(f1, value=1), In(i, value=2)],
    [p, done],
)
print(fn(np.int32(7), np.array([10.0, 5.0, 6.0]).astype("float32")))
Note how n_steps is still a scalar int32 variable, not a tensor of int32, so I pass it as a NumPy scalar (np.int32) rather than as an array.
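The Inputs values: [array([7], dtype=int32)] line in the traceback shows that n_steps reached the compiled code as a one-element array rather than a NumPy scalar. The NumPy-only sketch below contrasts the two forms. That the C backend's scalar check accepts only the second form is an assumption inferred from the Scalar check failed (npy_int32) message, not something verified against PyTensor's source.

```python
import numpy as np

# Converting a one-element array, like the np.array([7.0]) in the original
# call, with the int32 scalar type still yields an ndarray of shape (1,).
arr_in = np.int32(np.array([7.0]))

# Converting a plain number yields a genuine NumPy int32 scalar.
# Assumption: this is the only form the C-side scalar check accepts.
scalar_in = np.int32(7.0)

print(type(arr_in).__name__, arr_in.shape, arr_in.dtype)
print(type(scalar_in).__name__)
```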
In general your example is still a bit odd, in that we usually use scalar types only to define the inner ScalarLoop graph and then use only tensor types in the actual function. I also swapped float32 for int32 in the variables that behave as integers, though it doesn't matter here. The full example looks like:
import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.scalar import ScalarLoop, float32, int32
from pytensor.tensor.elemwise import Elemwise
n_steps = int32("n_steps")
f0 = float32("f0")
f1 = float32("f1")
end = int32("end")
i = int32("i")
op = ScalarLoop(
    init=[f0, f1, end, i],
    update=[
        pytensor.scalar.basic.identity(f1),
        f0 + f1,
        pytensor.scalar.basic.identity(end),
        i + 1,
    ],
    until=i >= end,
)
e = Elemwise(op)
# Elemwise takes as inputs tensor types. If you pass scalars you will see a `TensorFromScalar` Op in the graph
n_steps = pt.scalar("n_steps", dtype="int32")
f0 = pt.scalar("f0", dtype="float32")
f1 = pt.scalar("f1", dtype="float32")
end = pt.vector("end", dtype="int32")
i = pt.scalar("i", dtype="int32")
_, p, _, _, done = e(n_steps, f0, f1, end, i)
fn = pytensor.function([n_steps, f0, f1, end, i], [p, done])
print(
    fn(
        n_steps=np.array(7).astype("int32"),
        f0=np.array(0.0).astype("float32"),
        f1=np.array(1.0).astype("float32"),
        end=np.array([10, 5, 6]).astype("int32"),
        i=np.array(2).astype("int32"),
    )
)
# [array([21., 5., 8.], dtype=float32), array([False, True, True])]
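To see why that call prints [21., 5., 8.] and [False, True, True], here is a plain-Python model of the loop for one element of end. It is a sketch under assumed semantics (each step evaluates until from the values entering the step, keeps that step's updates, and then stops if until was true); the helper name scalar_loop_fib is hypothetical and not part of PyTensor.

```python
def scalar_loop_fib(n_steps, f0, f1, end, i):
    """Plain-Python model of the Fibonacci ScalarLoop for one element.

    Assumed semantics (not checked against PyTensor's source): each step
    computes `until` from the values entering the step, keeps that step's
    updates, and stops afterwards if `until` was true.
    """
    done = False
    for _ in range(n_steps):
        done = i >= end                 # until=i >= end, on incoming values
        f0, f1, i = f1, f0 + f1, i + 1  # identity(f1), f0 + f1, i + 1
        if done:
            break
    return f1, done


# Same inputs as the call above: n_steps=7, f0=0, f1=1, i=2, end in [10, 5, 6]
results = [scalar_loop_fib(7, 0.0, 1.0, end, 2) for end in (10, 5, 6)]
print(results)
```

Under this model, end=10 never triggers until within the 7 steps, so that element runs all steps and ends with f1 = 21, while end=5 and end=6 stop early at 5 and 8.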
Sounds good, thanks for letting me know the example also looks strange. That will help when I actually test the right thing.