lightweight_mmm
lightweight_mmm copied to clipboard
TypeError: add got incompatible shapes for broadcasting: (58,), (54,).
TypeError Traceback (most recent call last)
17 frames /usr/local/lib/python3.10/dist-packages/lightweight_mmm/lightweight_mmm.py in predict(self, media, extra_features, media_gap, target_scaler, seed) 518 if seed is None: 519 seed = utils.get_time_seed() --> 520 prediction = self._predict( 521 rng_key=jax.random.PRNGKey(seed=seed), 522 media_data=full_media,
[... skipping hidden 12 frame]
/usr/local/lib/python3.10/dist-packages/lightweight_mmm/lightweight_mmm.py in _predict(self, rng_key, media_data, extra_features, media_prior, degrees_seasonality, frequency, transform_function, weekday_seasonality, model, posterior_samples, custom_priors) 441 The predictions for the given data. 442 """ --> 443 return infer.Predictive( 444 model=model, posterior_samples=posterior_samples)( 445 rng_key=rng_key,
/usr/local/lib/python3.10/dist-packages/numpyro/infer/util.py in __call__(self, rng_key, *args, **kwargs) 1009 """ 1010 if self.batch_ndims == 0 or self.params == {} or self.guide is None: -> 1011 return self._call_with_params(rng_key, self.params, args, kwargs) 1012 elif self.batch_ndims == 1: # batch over parameters 1013 batch_size = jnp.shape(tree_flatten(self.params)[0][0])[0]
/usr/local/lib/python3.10/dist-packages/numpyro/infer/util.py in _call_with_params(self, rng_key, params, args, kwargs) 986 ) 987 model = substitute(self.model, self.params) --> 988 return _predictive( 989 rng_key, 990 model,
/usr/local/lib/python3.10/dist-packages/numpyro/infer/util.py in _predictive(rng_key, model, posterior_samples, batch_shape, return_sites, infer_discrete, parallel, model_args, model_kwargs) 823 rng_key = rng_key.reshape(batch_shape + key_shape) 824 chunk_size = num_samples if parallel else 1 --> 825 return soft_vmap( 826 single_prediction, (rng_key, posterior_samples), len(batch_shape), chunk_size 827 )
/usr/local/lib/python3.10/dist-packages/numpyro/util.py in soft_vmap(fn, xs, batch_ndims, chunk_size) 417 fn = vmap(fn) 418 --> 419 ys = lax.map(fn, xs) if num_chunks > 1 else fn(xs) 420 map_ndims = int(num_chunks > 1) + int(chunk_size > 1) 421 ys = tree_map(
[... skipping hidden 12 frame]
/usr/local/lib/python3.10/dist-packages/numpyro/infer/util.py in single_prediction(val) 796 ) 797 else: --> 798 model_trace = trace( 799 seed(substitute(masked_model, samples), rng_key) 800 ).get_trace(*model_args, **model_kwargs)
/usr/local/lib/python3.10/dist-packages/numpyro/handlers.py in get_trace(self, *args, **kwargs)
169 :return: OrderedDict containing the execution trace.
170 """
--> 171 self(*args, **kwargs)
172 return self.trace
173
/usr/local/lib/python3.10/dist-packages/numpyro/primitives.py in __call__(self, *args, **kwargs) 103 return self 104 with self: --> 105 return self.fn(*args, **kwargs) 106 107
/usr/local/lib/python3.10/dist-packages/numpyro/primitives.py in __call__(self, *args, **kwargs) 103 return self 104 with self: --> 105 return self.fn(*args, **kwargs) 106 107
/usr/local/lib/python3.10/dist-packages/numpyro/primitives.py in __call__(self, *args, **kwargs) 103 return self 104 with self: --> 105 return self.fn(*args, **kwargs) 106 107
/usr/local/lib/python3.10/dist-packages/numpyro/primitives.py in __call__(self, *args, **kwargs) 103 return self 104 with self: --> 105 return self.fn(*args, **kwargs) 106 107
/usr/local/lib/python3.10/dist-packages/numpyro/primitives.py in __call__(self, *args, **kwargs) 103 return self 104 with self: --> 105 return self.fn(*args, **kwargs) 106 107
/usr/local/lib/python3.10/dist-packages/lightweight_mmm/models.py in media_mix_model(media_data, target_data, media_prior, degrees_seasonality, frequency, transform_function, custom_priors, transform_kwargs, weekday_seasonality, extra_features) 410 # expo_trend is B(1, 1) so that the exponent on time is in [.5, 1.5]. 411 prediction = ( --> 412 intercept + coef_trend * trend ** expo_trend + 413 seasonality * coef_seasonality + 414 jnp.einsum(media_einsum, media_transformed, coef_media))
/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/array_methods.py in op(self, *args) 741 def forward_operator_to_aval(name): 742 def op(self, *args): --> 743 return getattr(self.aval, f"{name}")(self, *args) 744 return op 745
/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/array_methods.py in deferring_binary_op(self, other) 269 args = (other, self) if swap else (self, other) 270 if isinstance(other, _accepted_binop_types): --> 271 return binary_op(*args) 272 # Note: don't use isinstance here, because we don't want to raise for 273 # subclasses, e.g. NamedTuple objects that may override operators.
[... skipping hidden 12 frame]
/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/ufuncs.py in fn(x1, x2) 97 def fn(x1, x2, /): 98 x1, x2 = promote_args(numpy_fn.__name__, x1, x2) ---> 99 return lax_fn(x1, x2) if x1.dtype != np.bool_ else bool_lax_fn(x1, x2) 100 fn.__qualname__ = f"jax.numpy.{numpy_fn.__name__}" 101 fn = jit(fn, inline=True)
[... skipping hidden 7 frame]
/usr/local/lib/python3.10/dist-packages/jax/_src/lax/lax.py in broadcasting_shape_rule(name, *avals) 1597 result_shape.append(non_1s[0]) 1598 else: -> 1599 raise TypeError(f'{name} got incompatible shapes for broadcasting: ' 1600 f'{", ".join(map(str, map(tuple, shapes)))}.') 1601
TypeError: add got incompatible shapes for broadcasting: (58,), (54,).
I am also getting this issue and am looking for a solution.
Installing an older version of numpyro resolved my issue: !pip install numpyro==0.13.2
I had the same problem, and version 0.13.2 of numpyro did not work for me, so I used the following command to install numpyro while installing mmm, matplotlib, etc.:
!pip install numpyro==0.13.1
I am also facing the same problem. I would appreciate it if anyone has a solution for this. Thanks!
Just install an older version of numpyro, as stated in the comments above.
When I install an older version of numpyro, I get the following import error. Any idea how to solve this?
ModuleNotFoundError Traceback (most recent call last)
~\AppData\Local\Temp\ipykernel_12544\2990551649.py in <module>
~\Anaconda3\envs\python3\lib\site-packages\lightweight_mmm\preprocessing.py in <module>
~\Anaconda3\envs\python3\lib\site-packages\lightweight_mmm\core\core_utils.py in <module>
~\Anaconda3\envs\python3\lib\site-packages\numpyro\__init__.py in <module>
~\Anaconda3\envs\python3\lib\site-packages\numpyro\infer\__init__.py in <module>
~\Anaconda3\envs\python3\lib\site-packages\numpyro\infer\elbo.py in <module>
~\Anaconda3\envs\python3\lib\site-packages\numpyro\ops\provenance.py in <module>
ModuleNotFoundError: No module named 'jax.extend.linear_util'
Install an older version of JAX: 'jax.extend.linear_util' was removed in JAX releases after 0.4.23 (the current release is 0.4.25).
Sorry for the breakage. Could you try installing the latest version from GitHub:
pip install --upgrade git+https://github.com/google/lightweight_mmm.git