lightweight_mmm

previous_extra_features data in predict function

Open • alberto-molinaro opened this issue 2 years ago • 0 comments

Hello,

In the `predict` function of the `lightweight_mmm` module, there is this code:

```python
if media_gap is not None:
  if media.ndim != media_gap.ndim:
    raise ValueError("Original media data and media gap must have the same "
                     "number of dimensions.")
  if media.ndim > 1 and media.shape[1] != media_gap.shape[1]:
    raise ValueError("Media gap must have the same number of media channels "
                     "as the original media data.")
  # History = training media followed by the gap media.
  previous_media = jnp.concatenate(arrays=[self.media, media_gap], axis=0)
  if extra_features is not None:
    # No extra-feature values exist for the gap, so it is filled with zeros.
    previous_extra_features = jnp.concatenate(
        arrays=[
            self._extra_features,
            jnp.zeros((media_gap.shape[0], *self._extra_features.shape[1:]))
        ],
        axis=0)
```
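To see exactly what gets built for the gap, the padding step can be isolated. In this sketch a plain `extra_features` array stands in for `self._extra_features`, `pad_extra_features_with_zeros` is a hypothetical helper name, and the shapes are arbitrary example values, not anything taken from the library.

```python
import jax.numpy as jnp


def pad_extra_features_with_zeros(extra_features, media_gap):
  """Hypothetical helper mimicking the current code: one zero row per gap period."""
  gap_zeros = jnp.zeros((media_gap.shape[0], *extra_features.shape[1:]))
  return jnp.concatenate([extra_features, gap_zeros], axis=0)


# Example shapes (assumed): 10 training weeks, 2 extra features,
# 3 gap weeks, 4 media channels.
train_extra_features = jnp.ones((10, 2))
media_gap = jnp.ones((3, 4))
previous_extra_features = pad_extra_features_with_zeros(
    train_extra_features, media_gap)
print(previous_extra_features.shape)  # (13, 2)
```

Whatever the real extra features were during those 3 gap weeks, the last 3 rows are always zero.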

When `media_gap` is not None, instead of filling `previous_extra_features` with zeros for the gap period, would it be possible to pass those values as a function argument, like `media_gap`? It could be something like `extra_features_data_gap`.

This way, we could simulate different scenarios for the extra-feature variables over the gap period when making predictions.
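A minimal sketch of what the proposed argument could do, written as a standalone function rather than as a change to `predict` itself. `extra_features_data_gap` is only the name suggested in this issue, not an existing parameter, and `build_previous_extra_features` is a hypothetical helper; the shape check follows the spirit of the existing `media_gap` validation.

```python
import jax.numpy as jnp


def build_previous_extra_features(extra_features, media_gap,
                                  extra_features_data_gap=None):
  """Hypothetical helper: pads training extra features over the media gap.

  If `extra_features_data_gap` is None, the gap is filled with zeros (the
  current behaviour). Otherwise the caller-supplied gap values are used.
  """
  expected_shape = (media_gap.shape[0], *extra_features.shape[1:])
  if extra_features_data_gap is None:
    gap_values = jnp.zeros(expected_shape)
  else:
    gap_values = jnp.asarray(extra_features_data_gap)
    # One row per gap period, same feature dimensions as the training data.
    if gap_values.shape != expected_shape:
      raise ValueError(
          f"extra_features_data_gap must have shape {expected_shape}, "
          f"got {gap_values.shape}.")
  return jnp.concatenate([extra_features, gap_values], axis=0)


# Example scenario (assumed shapes): 10 training weeks, 2 extra features,
# 3 gap weeks, 4 media channels; both features held at 5.0 during the gap.
train_extra_features = jnp.ones((10, 2))
media_gap = jnp.ones((3, 4))
scenario_gap = jnp.full((3, 2), 5.0)

with_scenario = build_previous_extra_features(
    train_extra_features, media_gap, extra_features_data_gap=scenario_gap)
with_zeros = build_previous_extra_features(train_extra_features, media_gap)
```

Keeping `None` as the default would preserve today's zero-filling, so existing calls would behave the same.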

Thank you in advance!

alberto-molinaro, Jun 26 '23 09:06