Adstock Normalization
(Correct me if I am wrong..)
Currently, normalization of the adstock transformation is performed in media_transforms/adstock via:
lambda adstock_values: adstock_values / (1. / (1 - lag_weight))
However, I believe this is the infinite geometric series sum 1 / (1 - lw), where lw is the lag_weight. In my transformations this creates a window effect toward the start of the data, where adstock-transformed values are divided by very large normalizers despite having few contributing terms, especially in cases of a strong carryover effect.
This issue can be resolved by instead normalizing with the partial geometric series sum (1 - lw**n) / (1 - lw), where n is the 1-based index of the current data point: index 1 --> 1, index 2 --> 1 + lw, etc.
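To make the window effect concrete, here is a small illustrative sketch (my own numbers, not the library's code): with lag_weight = 0.9 the current normalizer is 10 for every time step, while the partial sum grows from 1 toward 10.

import jax.numpy as jnp

lag_weight = 0.9
n = jnp.arange(1, 6)                                     # 1-based time index
infinite_norm = 1. / (1 - lag_weight)                    # 10.0, applied at every index today
partial_norm = (1 - lag_weight ** n) / (1 - lag_weight)  # 1.0, 1.9, 2.71, 3.439, 4.0951
print(infinite_norm, partial_norm)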
For my local purposes, I implemented this using the equivalent expression (1 / (1 - lw)) / (1 / (1 - lw**n)), but it can obviously be done more cleanly:
(1. / (1 - lag_weight.reshape(1, -1))) / (1. / (1 - lag_weight.reshape(1, -1) ** (1 + jnp.arange(0, data.shape[0])).reshape(-1, 1)))
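For readability, here is a minimal sketch of the same idea as a standalone helper. partial_norm_adstock is a hypothetical name and this is not the library's media_transforms/adstock; it only shows the normalization step for already-adstocked values of shape (time, channels):

import jax.numpy as jnp

def partial_norm_adstock(adstock_values: jnp.ndarray,
                         lag_weight: jnp.ndarray) -> jnp.ndarray:
  """Divides adstocked values of shape (time, channels) by the partial
  geometric series sum (1 - lw**n) / (1 - lw) instead of 1 / (1 - lw)."""
  n_timesteps = adstock_values.shape[0]
  # 1-based time index, shaped (time, 1) so it broadcasts across channels.
  n = jnp.arange(1, n_timesteps + 1).reshape(-1, 1)
  lw = lag_weight.reshape(1, -1)              # shape (1, channels)
  partial_sum = (1 - lw ** n) / (1 - lw)      # shape (time, channels)
  return adstock_values / partial_sum

This is algebraically the same as the one-liner above, since (1 / (1 - lw)) / (1 / (1 - lw**n)) reduces to (1 - lw**n) / (1 - lw).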