lightweight_mmm

Dtype object is not a valid JAX array type. Only arrays of numeric types are supported by JAX.

Open • EldadWinter opened this issue 1 year ago • 5 comments

When trying to do this:

```python
import jax.numpy as jnp
from lightweight_mmm import preprocessing

media_scaler = preprocessing.CustomScaler(divide_operation=jnp.mean)
target_scaler = preprocessing.CustomScaler(divide_operation=jnp.mean)
cost_scaler = preprocessing.CustomScaler(divide_operation=jnp.mean)

media_data_train = media_scaler.fit_transform(media_data_train)
target_train = target_scaler.fit_transform(target_train)
costs2 = cost_scaler.fit_transform(costs)
```

I got the error above. Up until then, everything went as in the tutorial.
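For context: as far as we understand lightweight_mmm's `CustomScaler`, passing `divide_operation=jnp.mean` makes `fit` compute one mean per column and `transform` divide by it, which only works on numeric input. The sketch below is a simplified NumPy stand-in written for illustration, not the library's implementation; `mean_scale` is a hypothetical name, and it assumes axis 0 is the time axis.

```python
import numpy as np


def mean_scale(data):
    """Hypothetical stand-in for a mean-based scaler (not lightweight_mmm's code).

    Divides each column (axis 0 = time) by its mean and returns the scaled
    array together with the divisors, so the scaling can be inverted later.
    """
    arr = np.asarray(data)
    # Object arrays (e.g. mixed strings and numbers) cannot be averaged
    # reliably; fail early with a clear message, similar in spirit to JAX.
    if not np.issubdtype(arr.dtype, np.number):
        raise TypeError(f"expected a numeric array, got dtype {arr.dtype}")
    divisors = arr.mean(axis=0)
    return arr / divisors, divisors


media = np.array([[1.0, 10.0], [3.0, 30.0]], dtype=np.float32)
scaled, divisors = mean_scale(media)
print(divisors)  # per-channel means: [ 2. 20.]
print(scaled)
```

With numeric input each channel ends up with a mean of 1; an `object` array is rejected before any arithmetic happens.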

EldadWinter, Jan 06 '24 12:01

Have you converted your DataFrame or dataset to an array, e.g. with `np.array(media_data_train)`?
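When the data come from a pandas DataFrame, `np.array(df)` only yields a numeric array if every column is numeric; a single text column makes the whole result `object`, which is exactly the dtype JAX refuses. A minimal sketch, assuming a hypothetical DataFrame whose `sales` column was read as strings (for example because of thousands separators in a CSV):

```python
import numpy as np
import pandas as pd

# Hypothetical data: "sales" was read as text, e.g. because of "1,200"-style values.
df = pd.DataFrame({"tv": [1.0, 2.0, 3.0], "sales": ["1,200", "950", "1,010"]})

naive = np.array(df)
print(naive.dtype)  # object, which JAX rejects

# Clean the text column, then convert everything to float32 explicitly.
df["sales"] = pd.to_numeric(df["sales"].str.replace(",", "", regex=False))
clean = df.to_numpy(dtype=np.float32)
print(clean.dtype)  # float32
```

Passing an explicit `dtype` to `to_numpy` makes a bad column fail loudly at conversion time instead of later inside JAX.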

entzyeung, Jan 07 '24 23:01

I did it this way, following the tutorial. It is an array, but it still doesn't let me use the scalers:

Screenshot 2024-01-08 at 16 39 27

EldadWinter, Jan 08 '24 14:01

It is an array of dtype float32, so it should be in good shape already. Which dataset threw the error: target, media, or another one?
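One quick way to answer that question, rather than inspecting each array by eye, is to print the dtype of every input and flag the non-numeric ones. An `object` array of plain floats prints like an ordinary array, so only `.dtype` reveals it. The helper name `report_dtypes` and the stand-in arrays are hypothetical; substitute the notebook's real variables.

```python
import numpy as np


def report_dtypes(arrays):
    """Print dtype and shape of each named array; return names of non-numeric ones."""
    non_numeric = []
    for name, arr in arrays.items():
        arr = np.asarray(arr)
        numeric = np.issubdtype(arr.dtype, np.number)
        print(f"{name}: dtype={arr.dtype}, shape={arr.shape}, numeric={numeric}")
        if not numeric:
            non_numeric.append(name)
    return non_numeric


# Hypothetical stand-ins for the tutorial's arrays; substitute your own.
arrays = {
    "media_data_train": np.ones((4, 3), dtype=np.float32),
    "target_train": np.array([1.0, 2.0, 3.0, 4.0], dtype=object),  # looks numeric
    "costs": np.ones(3, dtype=np.float32),
}
print(report_dtypes(arrays))  # ['target_train']

# If every element really is a number, an explicit cast is enough.
fixed = np.asarray(arrays["target_train"], dtype=np.float32)
```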

entzyeung, Jan 08 '24 17:01

As you can see below, `target_train` caused the error. The scalers and `fit_transform` worked for `media_data_train` and `costs`.

Screenshot 2024-01-08 at 19 12 25

And this is what `target_train` looks like:

Screenshot 2024-01-08 at 19 15 21
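If `target_train` turns out to have dtype `object`, a few stray values (a string, `None`, a nested list) may be hiding among ordinary numbers. A minimal sketch for locating them, assuming the array fits in memory; `non_numeric_positions` is a hypothetical helper, not part of lightweight_mmm:

```python
import numpy as np


def non_numeric_positions(arr):
    """Return indices of elements that float() cannot convert.

    Intended for object-dtype arrays, where a few stray values can hide
    among otherwise ordinary numbers.
    """
    bad = []
    for idx, value in np.ndenumerate(np.asarray(arr, dtype=object)):
        try:
            float(value)
        except (TypeError, ValueError):
            bad.append(idx)
    return bad


target = np.array([1.0, 2.5, "n/a", None, 4.0], dtype=object)
print(non_numeric_positions(target))  # [(2,), (3,)]
# Listing the Python types present is often just as telling.
print(sorted({type(v).__name__ for v in target.ravel()}))  # ['NoneType', 'float', 'str']
```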

EldadWinter, Jan 08 '24 17:01

Hmm... I know you have checked it already, but is there any chance that your `target_train` contains NaN, null, or zero values, or has a different dtype?
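These checks can be scripted. Note that NaN or zero values would typically produce NaN or inf after mean scaling rather than this particular dtype error, which points at an `object` array; still, they are worth ruling out. A minimal sketch; `target_health` is a hypothetical helper, and the plain `mean` entry mimics how a `jnp.mean` divisor would be affected by a single NaN.

```python
import numpy as np


def target_health(target):
    """Summarise a target array: original dtype, size, NaN/inf/zero counts, means.

    The cast to float64 raises if any element cannot be converted, which is
    itself a useful signal for object-dtype arrays.
    """
    original_dtype = str(np.asarray(target).dtype)
    arr = np.asarray(target, dtype=np.float64)
    return {
        "dtype": original_dtype,
        "n": int(arr.size),
        "nan": int(np.isnan(arr).sum()),
        "inf": int(np.isinf(arr).sum()),
        "zero": int((arr == 0).sum()),
        "mean": float(arr.mean()),  # NaN if any value is NaN, as with jnp.mean
        "nanmean": float(np.nanmean(arr)),  # mean ignoring NaNs
    }


target = np.array([120.0, 0.0, np.nan, 95.0], dtype=np.float32)
print(target_health(target))
```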

Would you mind sharing your notebook and data for further investigation? I am not sure whether GitHub has a private-message function.

entzyeung, Jan 09 '24 23:01