lightweight_mmm
Dtype object is not a valid JAX array type. Only arrays of numeric types are supported by JAX.
When trying to do this:

    media_scaler = preprocessing.CustomScaler(divide_operation=jnp.mean)
    target_scaler = preprocessing.CustomScaler(divide_operation=jnp.mean)
    cost_scaler = preprocessing.CustomScaler(divide_operation=jnp.mean)

    media_data_train = media_scaler.fit_transform(media_data_train)
    target_train = target_scaler.fit_transform(target_train)
    costs2 = cost_scaler.fit_transform(costs)
I got the error. Up until then, everything went as in the tutorial.
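The message is raised by JAX rather than by the scaler logic: `jnp.mean` rejects arrays whose NumPy dtype is `object`. Below is a minimal sketch, assuming the scaler essentially divides each column by its mean. The helper `mean_scale` is hypothetical, a simplified stand-in and not lightweight_mmm's actual `CustomScaler` code. It casts to `float32` before handing the data to JAX:

```python
import numpy as np
import jax.numpy as jnp


def mean_scale(data):
    """Hypothetical stand-in for mean scaling: divide each column by its mean.

    Casting to float32 first turns an object-dtype array into one JAX accepts,
    or fails with a NumPy error if some value cannot be converted to a float.
    """
    numeric = jnp.asarray(np.asarray(data, dtype=np.float32))
    return numeric / jnp.mean(numeric, axis=0)


# An object-dtype array like this is what triggers the JAX dtype error
# when passed to jnp.mean directly.
target_like = np.array([10.0, 20.0, 30.0], dtype=object)
scaled = mean_scale(target_like)
print(scaled)  # approximately [0.5, 1.0, 1.5]
```

With the real library, the same idea applies: cast the input explicitly, e.g. `np.asarray(target_train, dtype=np.float32)`, before calling `target_scaler.fit_transform`.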
Have you converted your DataFrame or dataset to an array, e.g. with np.array(media_data_train)?
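Note that `np.array(...)` on its own does not guarantee a numeric dtype. A small sketch with made-up values (not the poster's data) shows how a single missing entry yields an `object` array, and how passing an explicit `dtype` avoids it:

```python
import numpy as np

# One missing value (None) is enough to make NumPy fall back to dtype=object.
raw = [120.0, None, 95.5]
as_is = np.array(raw)
print(as_is.dtype)  # object

# Requesting a float dtype explicitly converts None to NaN instead.
as_float = np.array(raw, dtype=np.float32)
print(as_float.dtype, np.isnan(as_float).sum())  # float32 1
```

For a pandas DataFrame or Series, `to_numpy(dtype="float32")` requests the dtype in the same way.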
I did it as in the tutorial. It is an array, but it still doesn't let me use the scalers:
It is an array and it is float32, so it should already be in good shape. Which dataset threw the error: target, media, or another one?
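One way to answer this programmatically is to check every input's dtype before scaling. A minimal sketch; the helper `non_numeric_inputs` and the arrays are hypothetical stand-ins for the tutorial's variables:

```python
import numpy as np


def non_numeric_inputs(datasets):
    """Return the names of arrays whose dtype is neither numeric nor bool."""
    bad = []
    for name, arr in datasets.items():
        dtype = np.asarray(arr).dtype
        if not (np.issubdtype(dtype, np.number) or np.issubdtype(dtype, np.bool_)):
            bad.append(name)
    return bad


# Made-up arrays standing in for the tutorial's inputs.
datasets = {
    "media_data_train": np.ones((4, 3), dtype=np.float32),
    "target_train": np.array([1.0, 2.0, None, 4.0]),  # object dtype
    "costs": np.array([100.0, 50.0, 25.0], dtype=np.float32),
}
print(non_numeric_inputs(datasets))  # ['target_train']
```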
As you can see below, target_train caused the error. The scalers and fit_transform on media_data_train and costs worked.
And this is what target_train looks like:
Hmm... I know you have checked already, but is there any chance that your target_train contains NaN, null, or zero values, or has a different dtype?
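The checks suggested above can be run in one pass. A minimal sketch with a hypothetical helper `diagnose` (not part of lightweight_mmm). The plain mean is reported because a mean-based scaler divides by it, so a NaN or zero mean would propagate into the scaled data:

```python
import numpy as np


def diagnose(arr):
    """Summarise dtype, NaN count, zero count and mean of an array."""
    arr = np.asarray(arr)
    report = {"dtype": str(arr.dtype)}
    try:
        # NumPy converts None to NaN here; non-numeric strings raise.
        values = arr.astype(np.float64)
    except (TypeError, ValueError) as err:
        report["cast_error"] = str(err)
        return report
    report["n_nan"] = int(np.isnan(values).sum())
    report["n_zero"] = int((values == 0).sum())
    # Plain mean, as a mean-based scaler would see it (NaN if any value is NaN).
    report["mean"] = float(np.mean(values))
    return report


print(diagnose(np.array([5.0, None, 0.0, 7.0])))
```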
Would you mind sharing your notebook and data for further investigation? I am not sure whether GitHub has a private message (PM) feature.