lightweight_mmm

JAX Tracer error when trying to plot_response_curves()

Status: Open. GilAdirim opened this issue 2 years ago • 4 comments

Hi,

Everything seems to be running smoothly: I'm getting good results and can run plot_media_channel_posteriors(), plot_model_fit(), plot_out_of_sample_model_fit(), plot_media_baseline_contribution_area_plot() and plot_bars_media_metrics().

However, when I try to plot the response curves with `plot.plot_response_curves(media_mix_model=mmm, media_scaler=media_scaler, target_scaler=target_scaler)`

I'm hit with the following exception:

```
---------------------------------------------------------------------------
UnfilteredStackTrace                      Traceback (most recent call last)
<ipython-input-50-456180661661> in <module>
----> 1 plot.plot_response_curves(media_mix_model=mmm, media_scaler=media_scaler, target_scaler=target_scaler)

45 frames
UnfilteredStackTrace: jax._src.errors.TracerIntegerConversionError: The __index__() method was called on the JAX Tracer object Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=3/0)>
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerIntegerConversionError

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

TracerIntegerConversionError              Traceback (most recent call last)
/usr/local/lib/python3.9/dist-packages/jax/_src/numpy/lax_numpy.py in _chunk_iter(x, size)
   5219 
   5220   This object references a source array and a specific indexer into that array.
-> 5221   Methods on this object return copies of the source array that have been
   5222   modified at the positions specified by the indexer.
   5223   """

TracerIntegerConversionError: The __index__() method was called on the JAX Tracer object Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=3/0)>
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerIntegerConversionError
```
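The exception itself is generic JAX behaviour: inside `jax.jit` (or another transformation), arguments become Tracer objects, and any code that needs a concrete Python `int`, such as `range()`, calls `__index__()` on the Tracer and fails. The sketch below is a hypothetical, self-contained reproduction that does not touch lightweight_mmm's internals; `repeat_rows` is an illustrative name, not a library function. When you control the jitted function, marking the integer argument as static is the usual fix; inside a library you don't control, pinning a compatible jax/jaxlib version is the more practical workaround.

```python
import jax
import jax.numpy as jnp


def repeat_rows(x, n):
    # Hypothetical helper: range(n) calls n.__index__(), which fails
    # when n is a Tracer rather than a concrete Python int.
    return jnp.stack([x for _ in range(n)])


x = jnp.arange(3.0)

try:
    # Under jit, the Python int 2 is traced as a weak-typed int32 scalar,
    # the same kind of Tracer shown in the error message above.
    jax.jit(repeat_rows)(x, 2)
    failed = False
except jax.errors.TracerIntegerConversionError:
    failed = True
print("traced n fails:", failed)

# static_argnums=1 passes n through as a plain Python int, so range(n) works.
repeat_rows_static = jax.jit(repeat_rows, static_argnums=1)
out = repeat_rows_static(x, 2)
print(out.shape)
```

Note that each distinct static value triggers a recompilation, which is the trade-off of this fix.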

I tried downgrading both matplotlib and jax/jaxlib, as suggested in another issue posted a month ago, but nothing seems to work.

Any ideas?

GilAdirim, Mar 28 '23

Have you tried what was discussed in #173? Downgrading jax and jaxlib worked fine for me, although I could only plot response curves after restarting my environment a couple of times.
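A plausible reason the restart matters (an assumption, not confirmed in the thread) is that a notebook kernel keeps the jax module it already imported, so a `pip install` of an older version only takes effect after a restart. The sketch below compares the version pip recorded on disk with the version loaded in the running process; the helper names are hypothetical, and it relies only on the standard library (`importlib.metadata`, `sys.modules`).

```python
import sys
from importlib.metadata import PackageNotFoundError, version


def installed_version(pkg):
    """Version recorded by pip on disk, or None if the package is not installed."""
    try:
        return version(pkg)
    except PackageNotFoundError:
        return None


def loaded_version(pkg):
    """__version__ of the module already imported in this process, or None."""
    module = sys.modules.get(pkg)
    return getattr(module, "__version__", None) if module is not None else None


def needs_restart(pkg):
    """True if a module is loaded and its version differs from the installed one."""
    loaded = loaded_version(pkg)
    return loaded is not None and loaded != installed_version(pkg)


for pkg in ("jax", "jaxlib"):
    print(pkg, "installed:", installed_version(pkg),
          "loaded:", loaded_version(pkg),
          "restart needed:", needs_restart(pkg))
```

If a module does not expose `__version__`, `loaded_version` returns None and `needs_restart` reports False, so treat the check as a hint rather than a guarantee.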

mpinheiro19, Mar 30 '23

I had tried it, but didn't think of restarting the environment again. It seems to be working now!

However, the plots I'm getting are barely usable:

[Screenshot: Screen Shot 2023-04-02 at 14 06 52]

GilAdirim, Apr 02 '23

As far as I know, these plots are OK. However, it looks to me like your KPI column is scaled. Did you get any error or warning message from the lightweight_mmm library?

Also, one graph seems to be missing from this picture, right? The last one should be a "multiline plot".
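To make "the KPI column is scaled" concrete: lightweight_mmm's README examples scale the target with `preprocessing.CustomScaler(divide_operation=jnp.mean)`, i.e. divide by the mean. Assuming that setup, a KPI axis whose values hover around 1 suggests the plot is showing scaled rather than original units. The sketch below illustrates mean scaling and its inverse with plain `jax.numpy`; it is not the library's implementation, and the function names and target values are hypothetical.

```python
import jax.numpy as jnp


def mean_scale(values):
    # Divide by the mean so the scaled series averages to 1.
    mean = values.mean()
    return values / mean, mean


def inverse_mean_scale(scaled, mean):
    # Multiply back by the stored mean to recover the original units.
    return scaled * mean


# Hypothetical KPI series in original units (e.g. weekly sales).
target = jnp.array([120.0, 80.0, 100.0, 100.0])

scaled, target_mean = mean_scale(target)
recovered = inverse_mean_scale(scaled, target_mean)

print(scaled)     # values near 1: what a plot in scaled units would show
print(recovered)  # original units again
```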

mpinheiro19, Apr 03 '23

Hi everyone! Nice workaround. However, there is no installation option for jaxlib==0.4.2 with CUDA 12 (which is what I have installed). Do you have any other workaround that would work for me?

aapiskotin, Apr 13 '23