
How can I calibrate my predicted ROI in the MMM with my geolift result?

elisakrammerfiverr opened this issue 1 year ago • 4 comments

elisakrammerfiverr avatar Dec 04 '23 11:12 elisakrammerfiverr

@elisakrammerfiverr It depends on how much you are willing to alter the underlying codebase.

One direct method we use is to treat the GeoLift test result as an additional observation and compare it to the intermediate media contribution the model calculates, where observations is the list of geo-test results. You will need to scale the "attributed KPI" from your geolift test with the target scaler to bring it into the right domain for the model. This solution assumes you extrapolate the geolift result to the whole region (e.g. market) you are training the MMM on: if, for a given date period, you know the total KPI you expect the channel to drive in the MMM's market, then this method will work.

Below is our format for geolift test results. Note that you need to map the time_periods to be index based, as the model is date agnostic. So suppose our experiment lasted 5 days, happened to fall at the very start of the model training period, was for the 2nd channel in our media data array, and we estimated the channel drove 1000 KPI over those 5 days:

  observations = [{
      # 0-indexed rows of the training window covered by the experiment
      'time_periods': jnp.array([0, 1, 2, 3, 4]),
      # 0-indexed position of the channel in the media data array
      'channel': 1,
      # total KPI the test says the channel drove, scaled into model space
      'attributed_values': scalers['target'].transform(jnp.array([1000.0])),
  }]
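For intuition on that transform step: the target scaler in lightweight_mmm is commonly a CustomScaler that divides by the mean of the training target, so the call above just moves the 1000 attributed KPI onto the same scale the model is fit on. A minimal numpy sketch of that behaviour (the divide-by-mean configuration and the example series are assumptions, not taken from this thread):

```python
import numpy as np

# Hypothetical raw KPI series the MMM is trained on.
target = np.array([900.0, 1100.0, 1000.0, 950.0, 1050.0])

# A CustomScaler(divide_operation=jnp.mean) divides by the mean of the
# data it was fit on; we mimic that here with plain numpy.
divide_by = target.mean()  # 1000.0 for this series

# The geolift "attributed KPI" must pass through the same transform
# so it lives in the same domain as the model's media contributions.
attributed_values = np.array([1000.0]) / divide_by

print(attributed_values)  # -> [1.]
```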

By passing the observations in transform_kwargs we can avoid changes to most of the lightweight_mmm function APIs. The snippet below needs to be placed inside the media_mix_model function in models.py, above line 411. It might not work right off the bat, as at this point we are working with a fairly altered codebase relative to google/lightweight_mmm.

   media_contribution = jnp.einsum(media_einsum, media_transformed, coef_media)
   observations = transform_kwargs.get('observations', None)
   if observations is not None:
     for i, obs in enumerate(observations):
       # Soft constraint: the channel's summed contribution over the test
       # window is "observed" to equal the (scaled) geolift estimate.
       _ = numpyro.sample(
           name=f'observation_{i}',
           fn=dist.Normal(
               # With the 'tc, c -> tc' einsum, media_contribution has
               # shape (time, channels), so this slice is 1-D over time.
               loc=media_contribution[
                   obs['time_periods'],
                   obs['channel']
               ].sum(),
               scale=0.2,
           ),
           obs=jnp.array(obs['attributed_values']),
       )
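The extra numpyro.sample(..., obs=...) site acts as a soft constraint: it adds a Normal log-likelihood term that penalizes posterior draws whose summed channel contribution over the test window strays from the geolift estimate. A small numpy sketch of that penalty (the draw values are illustrative, not from the library):

```python
import numpy as np

def normal_logpdf(x, loc, scale):
    # Log-density of a Normal(loc, scale) evaluated at x.
    return -0.5 * np.log(2 * np.pi * scale**2) - (x - loc) ** 2 / (2 * scale**2)

# Scaled geolift estimate for the channel over the test window.
observed = 1.0

# Two candidate posterior draws of the summed media contribution.
close_draw = 1.05
far_draw = 2.0

# The draw consistent with the experiment gets a much higher
# log-likelihood, so MCMC is pulled toward contributions that
# agree with the geolift test.
print(normal_logpdf(observed, loc=close_draw, scale=0.2))
print(normal_logpdf(observed, loc=far_draw, scale=0.2))
```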

Using this method our MMM aligns with the geolift results and treats them as an additional observation/source of truth. We have seen this significantly reduce our uncertainty for the channel's attribution while keeping the MMM consistent with our geolift experiments. Note that you will want to remove observations from transform_kwargs at predict time.

A simpler method would be to rescale your prior beliefs about media channel contribution based on the geolift test results, but this may prove difficult to implement fairly when you only have geolift results for some channels.
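For that simpler route, here is one hedged sketch of what prior rescaling could look like: where a geolift test implies a channel's ROI, replace that channel's spend-based media prior with the experiment-implied contribution, leaving untested channels untouched. The channel costs, the ROI value, and the spend-proportional default prior are all assumptions for illustration, not lightweight_mmm defaults guaranteed by this thread:

```python
import numpy as np

# Total spend per channel; a common choice of media prior is
# proportional to spend (an assumption about your configuration).
costs = np.array([5000.0, 3000.0, 2000.0])

# Geolift-implied ROI per channel; None means no experiment ran.
geolift_roi = [None, 0.5, None]

media_prior = costs.copy()
for c, roi in enumerate(geolift_roi):
    if roi is not None:
        # Replace the spend-based prior with the experiment-implied
        # expected KPI contribution: spend * ROI.
        media_prior[c] = costs[c] * roi

print(media_prior)  # -> [5000. 1500. 2000.]
```

The fairness issue the comment raises is visible here: only channel 1's prior moves, so tested and untested channels end up with priors derived on different bases.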

becksimpson avatar Dec 07 '23 09:12 becksimpson

@becksimpson, thanks for the idea.

The thing that is not clear to me is how the part of the code below is used for model calibration here. It would be great if you could share some details. Thank you.

   media_contribution = jnp.einsum(media_einsum, media_transformed, coef_media)
   observations = transform_kwargs.get('observations', None)
   if observations is not None:
     for i, obs in enumerate(observations):
       _ = numpyro.sample(
           name=f'observation_{i}',
           fn=dist.Normal(
               loc=media_contribution[
                   obs['time_periods'],
                   obs['channel']
               ].sum(),
               scale=0.2,
           ),
           obs=jnp.array(obs['attributed_values']),
       )

shubham223601 avatar Mar 12 '24 19:03 shubham223601

channel_media_einsum = 'tc, c -> tc'
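To unpack that line: the stock national-level model uses an einsum of the form 'tc, c -> t', which multiplies each transformed channel by its coefficient and sums over channels, whereas 'tc, c -> tc' keeps the per-time, per-channel contributions so that individual channels can be indexed for the observation sites. A minimal numpy illustration (the example values are made up):

```python
import numpy as np

# 3 time periods, 2 channels of (already transformed) media.
media_transformed = np.array([[1.0, 2.0],
                              [3.0, 4.0],
                              [5.0, 6.0]])
coef_media = np.array([0.1, 0.5])

# Stock behaviour: contributions summed over channels, shape (t,).
summed = np.einsum('tc, c -> t', media_transformed, coef_media)

# Altered behaviour: per-channel contributions kept, shape (t, c),
# so media_contribution[time_periods, channel] can be sliced.
per_channel = np.einsum('tc, c -> tc', media_transformed, coef_media)

print(summed.shape)       # (3,)
print(per_channel.shape)  # (3, 2)
# Summing the per-channel version over channels recovers the stock result.
print(np.allclose(per_channel.sum(axis=1), summed))
```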

becksimpson avatar Mar 14 '24 11:03 becksimpson

@becksimpson Thank you for this idea! In this implementation, is the media_contribution above now used in the prediction calculation on line #411, replacing jnp.einsum(media_einsum, media_transformed, coef_media)? Here is how I am thinking media_contribution is now used in the prediction calc:

prediction = (intercept + coef_trend * trend ** expo_trend + seasonality * coef_seasonality + media_contribution)

nutterbrand avatar Jul 27 '24 17:07 nutterbrand