previous_extra_features data in predict function
Hello,
In the predict function of the lightweight-mmm module, there is the following code:
    if media_gap is not None:
      if media.ndim != media_gap.ndim:
        raise ValueError("Original media data and media gap must have the same "
                         "number of dimensions.")
      if media.ndim > 1 and media.shape[1] != media_gap.shape[1]:
        raise ValueError("Media gap must have the same numer of media channels"
                         "as the original media data.")
      previous_media = jnp.concatenate(arrays=[self.media, media_gap], axis=0)
      if extra_features is not None:
        previous_extra_features = jnp.concatenate(
            arrays=[
                self._extra_features,
                jnp.zeros((media_gap.shape[0], *self._extra_features.shape[1:]))
            ],
            axis=0)
When media_gap is not None, instead of filling previous_extra_features with zeros for the gap period, would it be possible to accept those values as a function argument, like media_gap? It could be called something like extra_features_data_gap.
This way, we could simulate different scenarios for the extra feature variables in predictions.
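A minimal sketch of what the proposed behaviour could look like, written as a standalone helper rather than a patch to predict. The helper name build_previous_extra_features is hypothetical, and extra_features_data_gap is only the argument name suggested above; neither is part of the lightweight_mmm API. It uses numpy so it runs without JAX, but jnp.zeros and jnp.concatenate behave the same way for this purpose. When no gap values are given, it falls back to the current zero padding.

```python
from typing import Optional

import numpy as np


def build_previous_extra_features(
    train_extra_features: np.ndarray,
    media_gap: np.ndarray,
    extra_features_data_gap: Optional[np.ndarray] = None,
) -> np.ndarray:
    """Stacks training extra features with values for the gap period.

    Hypothetical helper: extra_features_data_gap is the proposed (not existing)
    argument. When it is None, this mirrors the current zero-padding behaviour.
    """
    gap_length = media_gap.shape[0]
    trailing_shape = train_extra_features.shape[1:]
    if extra_features_data_gap is None:
        # Current behaviour: the gap period gets all-zero extra features.
        gap_values = np.zeros((gap_length, *trailing_shape))
    else:
        # Proposed behaviour: user-supplied values, validated against the
        # gap length and the training extra-feature dimensions.
        expected_shape = (gap_length, *trailing_shape)
        if extra_features_data_gap.shape != expected_shape:
            raise ValueError(
                f"extra_features_data_gap must have shape {expected_shape}, "
                f"got {extra_features_data_gap.shape}.")
        gap_values = extra_features_data_gap
    return np.concatenate([train_extra_features, gap_values], axis=0)


# Usage: 10 training periods with 2 extra features, and a 3-period gap
# with 4 media channels. Compare the zero-filled default to a custom scenario.
train_extra = np.random.default_rng(0).normal(size=(10, 2))
media_gap = np.ones((3, 4))

zero_filled = build_previous_extra_features(train_extra, media_gap)
scenario = build_previous_extra_features(
    train_extra, media_gap, extra_features_data_gap=np.full((3, 2), 5.0))
print(zero_filled.shape, scenario[-3:].mean())
```

Keeping None as the default would leave existing calls to predict unchanged, while letting users pass different gap values to compare scenarios.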
Thank you in advance!