lightweight_mmm
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected
Hello everyone, I am trying to find the optimal budget allocation for two given budgets (for example 2000 and 3000).
Here is my code:
########### IMPORTS #####################
########################################
import jax.numpy as jnp
from lightweight_mmm import optimize_media
# NOTE: `sim` is not defined in the original post; it is presumably the
# poster's own helper module and must be imported separately.

########### MODEL INPUT PATH ############
#######################################
model_path = 'path_to_model.pkl'
media_scaler_path = 'path_to_media_scaler.pkl'
target_scaler_path = 'path_to_target_scaler.pkl'

########### LOAD THE MODEL ##############
########################################
mmm_model, media_scaler, target_scaler, prices = sim.read_model_and_scalers(
    model_path=model_path,
    media_scaler_path=media_scaler_path,
    revenue_scaler_path=target_scaler_path)

########### FIND ALLOCATION ##############
########################################
budget = 2000

solution, kpi_without_optim, previous_budget_allocation = optimize_media.find_optimal_budgets(
    n_time_periods=30,
    media_mix_model=mmm_model,
    budget=jnp.array(budget),
    prices=prices,
    media_scaler=media_scaler,
    target_scaler=target_scaler,
    seed=52,
    bounds_lower_pct=0.2,
    bounds_upper_pct=0.2)

print(solution)
The first time I run the code with budget=2000, it runs properly. But if I change the budget parameter to budget=3000 and run it again, I get a ConcretizationTypeError:

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[])>with<DynamicJaxprTrace(level=1/0)>
The problem arose with the `bool` function.
The error occurred while tracing the function _objective_function at .venv/lib/python3.9/site-packages/lightweight_mmm/optimize_media.py:27 for jit. This value became a tracer due to JAX operations on these lines:

  operation a:f32[2] = copy b
    from line .venv/lib/python3.9/site-packages/lightweight_mmm/lightweight_mmm.py:523 (predict)

  operation a:f32[] b:f32[] = pjit[
    jaxpr={ lambda ; c:f32[2]. let
        d:f32[1] = slice[limit_indices=(1,) start_indices=(0,) strides=(1,)] c
        e:f32[] = squeeze[dimensions=(0,)] d
        f:f32[1] = slice[limit_indices=(2,) start_indices=(1,) strides=(1,)] c
        g:f32[] = squeeze[dimensions=(0,)] f
      in (e, g) }
    name=_unstack
  ] h
    from line .venv/lib/python3.9/site-packages/lightweight_mmm/lightweight_mmm.py:100 (<genexpr>)

  operation a:bool[2] = eq b c
    from line .venv/lib/python3.9/site-packages/lightweight_mmm/lightweight_mmm.py:104 (_compare_equality_for_lmmm)

  operation a:bool[2] = pjit[
    jaxpr={ lambda ; b:f32[2]. let c:bool[2] = ne b b in (c,) }
    name=isnan
  ] d
    from line .venv/lib/python3.9/site-packages/lightweight_mmm/lightweight_mmm.py:104 (_compare_equality_for_lmmm)

  operation a:bool[2] = pjit[
    jaxpr={ lambda ; b:f32[2]. let c:bool[2] = ne b b in (c,) }
    name=isnan
  ] d
    from line .venv/lib/python3.9/site-packages/lightweight_mmm/lightweight_mmm.py:104 (_compare_equality_for_lmmm)

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError
I read the JAX documentation and tried to modify the package function, but I was not able to fix this. Can someone explain how I can handle this error? Thanks in advance!
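For readers who have not met this error class before, here is a minimal sketch, independent of lightweight_mmm, of how JAX raises it. Inside a jitted function, a value that depends on an argument is abstract during tracing, so converting it to a Python `bool` (for example in an `if`) fails. The function names below are made up for illustration, and this is not a diagnosis of what happens inside `_compare_equality_for_lmmm`.

```python
import jax
import jax.numpy as jnp


@jax.jit
def branch_on_traced_value(x):
    # Under jit, `x` is a tracer, so `x.sum() > 0` has no concrete value and
    # the implicit bool() in this `if` cannot be evaluated at trace time.
    if x.sum() > 0:
        return x
    return -x


@jax.jit
def branch_with_where(x):
    # jnp.where performs the selection inside the traced computation,
    # so no Python-level bool conversion is needed.
    return jnp.where(x.sum() > 0, x, -x)


def raises_concretization_error(fn, x):
    """Return True if calling fn(x) raises ConcretizationTypeError."""
    try:
        fn(x)
    except jax.errors.ConcretizationTypeError:
        return True
    return False


x = jnp.array([1.0, -3.0])
print(raises_concretization_error(branch_on_traced_value, x))
print(raises_concretization_error(branch_with_where, x))
print(branch_with_where(x))
```

The first function fails while tracing and the second does not, which is the general pattern behind the "problem arose with the `bool` function" line in the message above.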
Hello everyone,
I found a fix: downgrading jax from 0.4.8 to 0.4.2 fixed the error on a Mac M1.
Thank you!
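If you rely on this workaround, it can help to check the installed jax version before running the optimization. The sketch below uses only the standard library. The helper names and the `KNOWN_GOOD_JAX` constant are illustrative assumptions, and 0.4.2 is simply the version reported as working in this thread, not an official requirement of lightweight_mmm. The downgrade itself would typically be a pinned install such as `pip install "jax==0.4.2"` with a compatible jaxlib, though the exact command depends on your platform.

```python
from importlib import metadata

# Version reported as working in this thread; not an official requirement.
KNOWN_GOOD_JAX = "0.4.2"


def parse_version(version):
    """Turn '0.4.8' into (0, 4, 8), keeping only leading digits of each part."""
    numbers = []
    for piece in version.split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        numbers.append(int(digits))
    return tuple(numbers)


def installed_jax_version():
    """Return the installed jax version string, or None if jax is absent."""
    try:
        return metadata.version("jax")
    except metadata.PackageNotFoundError:
        return None


def is_newer_than_known_good(installed, known_good=KNOWN_GOOD_JAX):
    """True if `installed` is a later release than `known_good`."""
    return parse_version(installed) > parse_version(known_good)


version = installed_jax_version()
if version is None:
    print("jax is not installed in this environment")
elif is_newer_than_known_good(version):
    print(f"jax {version} is newer than {KNOWN_GOOD_JAX}; the error above may appear")
else:
    print(f"jax {version} is at or below {KNOWN_GOOD_JAX}")
```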