lightweight_mmm
JAX Tracer error when trying to plot_response_curves()
Hi,
Everything seems to be running smoothly: I'm getting good results and am able to run plot_media_channel_posteriors(), plot_model_fit(), plot_out_of_sample_model_fit(), plot_media_baseline_contribution_area_plot() and plot_bars_media_metrics().
However, when I try to plot the response curves with:
plot.plot_response_curves(media_mix_model=mmm, media_scaler=media_scaler, target_scaler=target_scaler)
I get the following exception:
---------------------------------------------------------------------------
UnfilteredStackTrace Traceback (most recent call last)
<ipython-input-50-456180661661> in <module>
----> 1 plot.plot_response_curves(media_mix_model=mmm, media_scaler=media_scaler, target_scaler=target_scaler)
UnfilteredStackTrace: jax._src.errors.TracerIntegerConversionError: The __index__() method was called on the JAX Tracer object Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=3/0)>
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerIntegerConversionError
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
TracerIntegerConversionError Traceback (most recent call last)
/usr/local/lib/python3.9/dist-packages/jax/_src/numpy/lax_numpy.py in _chunk_iter(x, size)
5219
5220 This object references a source array and a specific indexer into that array.
-> 5221 Methods on this object return copies of the source array that have been
5222 modified at the positions specified by the indexer.
5223 """
TracerIntegerConversionError: The __index__() method was called on the JAX Tracer object Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=3/0)>
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerIntegerConversionError
I tried downgrading both matplotlib and jax/jaxlib, as suggested in another issue posted a month ago, but nothing seems to work.
Any ideas?
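The trimmed traceback suggests that a jax-internal helper (_chunk_iter) received a traced integer, Traced<ShapedArray(int32[], weak_type=True)>, where it needed a concrete Python int. The sketch below reproduces the same error class with two hypothetical functions; it assumes only that jax is installed and does not reproduce lightweight_mmm's actual code path. Under jax.jit, a Python int argument is traced as a weakly typed int32 scalar, and anything that needs its concrete value (here, range()) calls __index__() and fails. Marking the argument static with static_argnums avoids that.

```python
import functools

import jax
import jax.numpy as jnp


@jax.jit
def sum_first_n_traced(x, n):
    # Under jit, the Python int n becomes a traced int32 scalar, so
    # range(n) calls n.__index__() and raises TracerIntegerConversionError.
    total = 0.0
    for i in range(n):
        total = total + x[i]
    return total


@functools.partial(jax.jit, static_argnums=1)
def sum_first_n_static(x, n):
    # static_argnums=1 keeps n as a concrete Python int while tracing,
    # so range(n) works (jit recompiles for each distinct value of n).
    total = 0.0
    for i in range(n):
        total = total + x[i]
    return total


x = jnp.arange(5.0)

traced_error = None
try:
    sum_first_n_traced(x, 3)
except TypeError as err:  # JAX's tracer conversion errors subclass TypeError
    traced_error = err

print("traced n raises:", type(traced_error).__name__)
print("static n result:", sum_first_n_static(x, 3))
```

In this thread the traced integer originates inside the library call, so it is not something the caller can mark static; the workaround reported in the replies is pinning jax/jaxlib to versions that work with lightweight_mmm.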
Have you tried what was discussed in #173? Downgrading jax and jaxlib worked perfectly for me; however, I was only able to plot response curves after restarting my environment a couple of times.
I had tried it, but didn't think of restarting the environment again. It seems to be working now!
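Why the restart matters: pip replaces the package files on disk, but a notebook kernel that has already imported jax keeps the old modules in memory (sys.modules), so the downgrade only takes effect in a fresh process. A small check like the one below, using a hypothetical helper name, compares the version of each already-imported module with the version recorded for the installed distribution; a mismatch means the runtime still needs a restart. It assumes the package exposes __version__ on its top-level module and skips it otherwise.

```python
import importlib.metadata
import sys


def stale_packages(names=("jax", "jaxlib")):
    """Map package name -> (loaded_version, installed_version) when they differ.

    A mismatch means the package was reinstalled (e.g. downgraded with pip)
    after this Python process had already imported it, so the process still
    runs the old code until it is restarted.
    """
    stale = {}
    for name in names:
        module = sys.modules.get(name)
        if module is None:
            # Not imported yet: the next import will load the installed version.
            continue
        loaded = getattr(module, "__version__", None)
        if loaded is None:
            # Some packages do not expose __version__ on the top-level module.
            continue
        try:
            installed = importlib.metadata.version(name)
        except importlib.metadata.PackageNotFoundError:
            continue
        if loaded != installed:
            stale[name] = (loaded, installed)
    return stale


# An empty dict means the loaded versions match what pip installed.
print(stale_packages())
```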
However, I'm getting barely usable plots:

As far as I know, these plots are OK. However, it looks like your KPI column is scaled. Did you get any errors or warnings from the lightweight_mmm library?
Also, there is a graph missing from this picture, right? The last one should be a multi-line plot.
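On the scaling point: when a KPI has been divided by a constant such as its mean, plotted values hover around 1 instead of the KPI's real units, and only an inverse transform with the same fitted scaler restores them. The sketch below illustrates that with a hypothetical MeanScaler written in jax.numpy; it is not lightweight_mmm's own scaler class, and whether this explains the plots in question is an assumption the thread does not confirm.

```python
import jax.numpy as jnp


class MeanScaler:
    """Hypothetical divide-by-mean scaler, standing in for a fitted target scaler."""

    def fit(self, y):
        self.divisor = jnp.mean(y)
        return self

    def transform(self, y):
        return y / self.divisor

    def inverse_transform(self, y_scaled):
        return y_scaled * self.divisor


kpi = jnp.array([120.0, 80.0, 100.0, 100.0])  # made-up KPI values in original units
scaler = MeanScaler().fit(kpi)

kpi_scaled = scaler.transform(kpi)  # values now cluster around 1.0
kpi_restored = scaler.inverse_transform(kpi_scaled)  # back to original units

print(kpi_scaled)
print(kpi_restored)
```

If a chart's KPI axis shows values clustered around 1, that is a hint the values were never passed back through the fitted target scaler.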
Hi everyone! Nice workaround. However, there is no jaxlib==0.4.2 build for CUDA 12 (which is what I have installed). Is there any other workaround that would work for me?