pymc-marketing

`mmm.fit()` generates XlaRuntimeError unless `pytensor.config.floatX = "float32"`

Open · cluhmann opened this issue 10 months ago · 3 comments

As detailed here, version 0.11.0 seems to have introduced new behavior in which calling

mmm.fit(X=X, y=y, chains=4, target_accept=0.85, nuts_sampler="blackjax", random_seed=rng)

generates an XLA error:

XlaRuntimeError: INTERNAL: Compute error: Error dispatching computation: %sCpuCallback error: Traceback (most recent call last):
  File "C:\Users\terier99\Anaconda3\envs\marketing_env\Lib\site-packages\jax\_src\interpreters\mlir.py", line 2781, in _wrapped_callback
RuntimeError: Incorrect output dtype for return value #0: Expected: int64, Actual: int32
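The last line of the traceback says a host callback was declared to return `int64` but actually returned `int32`. The sketch below mimics that kind of check with NumPy alone so the mismatch is easy to see. `check_callback_output` is a hypothetical helper, not JAX's code, and it says nothing about where in pymc-marketing or PyTensor the `int32` value comes from. One detail that may or may not be relevant: the traceback paths point to a Windows install, and with NumPy older than 2.0 on Windows the default integer dtype is `int32` rather than `int64`.

```python
import numpy as np


def check_callback_output(result, expected_dtype):
    """Hypothetical helper: reject a callback result whose dtype differs from the declared one.

    It only mimics the error text in the traceback; it is not JAX's implementation.
    """
    expected = np.dtype(expected_dtype)
    actual = np.asarray(result).dtype
    if actual != expected:
        raise RuntimeError(
            "Incorrect output dtype for return value #0: "
            f"Expected: {expected}, Actual: {actual}"
        )
    return result


# The dtype NumPy picks for plain Python ints on this machine
# (int32 on Windows with NumPy < 2.0, int64 on most other setups).
print("default int dtype:", np.asarray([1, 2, 3]).dtype)

# A callback declared to return int64 but producing int32 fails the check...
try:
    check_callback_output(np.array([1, 2, 3], dtype=np.int32), np.int64)
except RuntimeError as err:
    print(err)

# ...while casting to the declared dtype before returning passes it.
ok = check_callback_output(np.array([1, 2, 3], dtype=np.int32).astype(np.int64), np.int64)
print(ok.dtype)
```

Casting to the declared dtype (or declaring the dtype the callback really produces) is the generic way to satisfy such a check; whether that is the right fix here depends on the root cause, which this thread leaves open.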

The error went away after adding:

import pytensor  
pytensor.config.floatX = "float32"
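The same workaround can also be applied through the `PYTENSOR_FLAGS` environment variable, which PyTensor reads when it is first imported, as comma-separated `name=value` pairs. The sketch below assumes it runs before `pytensor` (or anything that imports it, such as pymc-marketing) is imported. `set_pytensor_flag` is a hypothetical helper that adds or overwrites one flag while keeping any others already set. Like the in-code setting, this only applies the workaround; it does not address the underlying dtype mismatch.

```python
import os


def set_pytensor_flag(name: str, value: str) -> str:
    """Hypothetical helper: add or overwrite one entry in PYTENSOR_FLAGS.

    Other comma-separated name=value entries already in the variable are kept.
    Only takes effect if called before pytensor is imported for the first time.
    """
    current = os.environ.get("PYTENSOR_FLAGS", "")
    flags = {}
    for item in current.split(","):
        if "=" in item:
            key, val = item.split("=", 1)
            flags[key.strip()] = val.strip()
    flags[name] = value  # overwrite in place, or append if new
    os.environ["PYTENSOR_FLAGS"] = ",".join(f"{k}={v}" for k, v in flags.items())
    return os.environ["PYTENSOR_FLAGS"]


print(set_pytensor_flag("floatX", "float32"))

# Only after this point:
# import pytensor
# assert pytensor.config.floatX == "float32"
```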

cluhmann · Jan 24 '25 13:01

What is the action here? Add this to the documentation? Or just keep this as an issue for reference?

williambdean · Mar 21 '25 08:03

Unclear, because the expected behavior is itself unclear. Are users supposed to set floatX to "float32"? Was that a decision made somewhere along the way?

cluhmann · Mar 21 '25 12:03

Setting floatX to "float32" is only a temporary workaround for the problem, and many other features of the library don't work with it. See issue #1633.

The maintainers should fix the root cause instead of relying on the workaround.

kb-open · Apr 21 '25 10:04

Hi @kb-open, this is an open source project. Would you like to make this contribution?

williambdean · May 21 '25 01:05