pymc-marketing
`mmm.fit()` generates XlaRuntimeError unless `pytensor.config.floatX = "float32"`
As detailed here, version 0.11.0 seems to have introduced new behavior in which calling
mmm.fit(X=X, y=y, chains=4, target_accept=0.85, nuts_sampler="blackjax", random_seed=rng)
generates an XLA error:
XlaRuntimeError: INTERNAL: Compute error: Error dispatching computation: %sCpuCallback error: Traceback (most recent call last):
File "C:\Users\terier99\Anaconda3\envs\marketing_env\Lib\site-packages\jax\_src\interpreters\mlir.py", line 2781, in _wrapped_callback
RuntimeError: Incorrect output dtype for return value #0: Expected: int64, Actual: int32
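The last line is the core of the failure. When JAX runs a host (CPU) callback, it checks that every value the callback returns has the dtype that was declared when the computation was traced. Here the declared dtype was `int64` but the callback produced `int32`. The snippet below is a minimal NumPy illustration of that kind of check. `check_callback_outputs` is a hypothetical helper written for this explanation, not JAX's internal code.

```python
import numpy as np


def check_callback_outputs(outputs, expected_dtypes):
    """Hypothetical helper: raise if a returned array's dtype differs from the declared one."""
    for i, (out, expected) in enumerate(zip(outputs, expected_dtypes)):
        actual = np.asarray(out).dtype
        if actual != np.dtype(expected):
            raise RuntimeError(
                f"Incorrect output dtype for return value #{i}: "
                f"Expected: {np.dtype(expected)}, Actual: {actual}"
            )
    return outputs


# A value that matches its declared dtype passes through unchanged.
ok = check_callback_outputs([np.array([1, 2], dtype=np.int64)], [np.int64])

# A callback declared to return int64 that actually returns int32 fails the check.
message = None
try:
    check_callback_outputs([np.array([1, 2], dtype=np.int32)], [np.int64])
except RuntimeError as err:
    message = str(err)

print(message)
```

This reproduces only the shape of the error message. It does not show where in the PyMC/PyTensor/JAX stack the `int32` value originates.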
The fix was to do the following:
import pytensor
pytensor.config.floatX = "float32"
What is the action here? Add this to the documentation? Or just have this as an issue to reference?
Unclear, because the expected behavior itself is unclear. Are users supposed to set `floatX` to `"float32"`? Was that a deliberate decision at some point?
Setting `floatX` only provides a temporary workaround for the problem, and many other features of the library don't work with it. See Issue #1633.
The maintainers should fix the root cause instead of depending on the workaround.
Hi @kb-open, this is an open source project. Would you like to make this contribution?