BUG: np.where casts Int64 to object
Pandas version checks
- [X] I have checked that this issue has not already been reported.
- [X] I have confirmed this bug exists on the latest version of pandas.
- [X] I have confirmed this bug exists on the main branch of pandas.
Reproducible Example
```python
import pandas as pd
import numpy as np

df = pd.DataFrame({
    'A': pd.array([1, 2, 3, 4], dtype='Int64'),
    'B': pd.array([2, 1, 7, 8], dtype='Int64'),
})
df['C'] = np.where(df.A > df.B, df.A, df.B)
df.dtypes
# A     Int64
# B     Int64
# C    object
# dtype: object
```
Issue Description
The dtype of column C is object.
Expected Behavior
I would expect the dtype to be Int64, since A and B are Int64 arrays:
A    Int64
B    Int64
C    Int64
dtype: object
I understand that this issue is more related to how NumPy's `where` function works. However, Int64 is a pandas dtype, so I don't know whether NumPy can fix this.
So is there a way to apply `np.where`-style logic to pandas objects?
Thanks,
Installed Versions
INSTALLED VERSIONS
commit           : 06d230151e6f18fdb8139d09abf539867a8cd481
python           : 3.8.8.final.0
python-bits      : 64
OS               : Linux
OS-release       : 3.10.0-862.el7.x86_64
Version          : #1 SMP Wed Mar 21 18:14:51 EDT 2018
machine          : x86_64
processor        : x86_64
byteorder        : little
LC_ALL           : None
LANG             : None
LOCALE           : None.None

pandas           : 1.4.1
numpy            : 1.22.2
pytz             : 2021.3
dateutil         : 2.8.2
pip              : 22.0.3
setuptools       : 59.8.0
Cython           : 0.29.28
pytest           : 7.0.1
hypothesis       : None
sphinx           : None
blosc            : None
feather          : None
xlsxwriter       : 3.0.2
lxml.etree       : 4.8.0
html5lib         : None
pymysql          : None
psycopg2         : 2.9.2
jinja2           : 2.11.3
IPython          : 8.0.1
pandas_datareader: None
bs4              : None
bottleneck       : 1.3.2
fastparquet      : None
fsspec           : 2022.01.0
gcsfs            : None
matplotlib       : 3.5.1
numba            : 0.53.1
numexpr          : 2.8.0
odfpy            : None
openpyxl         : 3.0.9
pandas_gbq       : None
pyarrow          : 7.0.0
pyreadstat       : None
pyxlsb           : None
s3fs             : None
scipy            : 1.8.0
sqlalchemy       : 1.4.31
tables           : None
tabulate         : 0.8.9
xarray           : None
xlrd             : None
xlwt             : None
zstandard        : None
NumPy does not recognise the Int64 dtype, so it converts the values to object. I would suggest using the pandas `where` or `mask` methods instead. You could also have a look at `case_when` from pyjanitor, which preserves pandas dtypes and does something similar to `np.where`/`np.select`.
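A minimal sketch of the `where`/`mask` suggestion, reusing the data from the reproducible example (the extra column names `C` and `D` are only for illustration). `Series.where` keeps values where the condition holds and takes the rest from `other`; `Series.mask` does the inverse. With nullable arguments, both should return a Series that keeps the `Int64` dtype.

```python
import numpy as np
import pandas as pd

df = pd.DataFrame({
    "A": pd.array([1, 2, 3, 4], dtype="Int64"),
    "B": pd.array([2, 1, 7, 8], dtype="Int64"),
})

cond = df["A"] > df["B"]

# Keep A where A > B, otherwise take B (same selection as np.where(cond, A, B)).
df["C"] = df["A"].where(cond, df["B"])

# Equivalent via mask: replace B with A wherever A > B.
df["D"] = df["B"].mask(cond, df["A"])

print(df.dtypes)
```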
Thanks, I didn't know about pyjanitor. The `where` method does the job for many use cases as well.
Reopening, as I think this could be fixed on the pandas side. Without further investigation, I suspect it is a bug in how pandas implements the NumPy array protocol for extension arrays, since `np.where(df.A._values > df.B._values, df.A._values, df.B._values)` also returns a NumPy array with object dtype rather than a nullable integer array.
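The `_values` experiment can be reproduced as below. `_values` is private pandas API and is used here only to reach the underlying extension array. The dtype comment reflects the behaviour reported for pandas 1.4.1; the exact dtype of the returned ndarray may differ on other pandas versions.

```python
import numpy as np
import pandas as pd

df = pd.DataFrame({
    "A": pd.array([1, 2, 3, 4], dtype="Int64"),
    "B": pd.array([2, 1, 7, 8], dtype="Int64"),
})

# Underlying nullable integer arrays (private attribute, illustration only).
a, b = df["A"]._values, df["B"]._values
result = np.where(a > b, a, b)

# np.where converts each argument to a plain ndarray before selecting,
# so the result is not a pandas extension array.
print(type(result))
print(result.dtype)  # object on pandas 1.4.1, as reported in this issue
```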
I think it is reasonable to expect `np.where(df.A > df.B, df.A, df.B)` to return the same as `df.A.where(df.A > df.B, df.B)`.
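A small check of that expectation on the data from the reproducible example: the values selected by `np.where` and by `Series.where` agree, but only the pandas method returns a Series with the nullable dtype.

```python
import numpy as np
import pandas as pd

df = pd.DataFrame({
    "A": pd.array([1, 2, 3, 4], dtype="Int64"),
    "B": pd.array([2, 1, 7, 8], dtype="Int64"),
})

cond = df.A > df.B
via_numpy = np.where(cond, df.A, df.B)  # plain ndarray, nullable dtype lost
via_pandas = df.A.where(cond, df.B)     # Series, Int64 dtype kept

# Same selected values; the difference is only in the container and dtype.
assert list(via_numpy) == via_pandas.tolist()
print(via_pandas.dtype)
```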
take
IIRC, `np.where` dispatches through NEP 18, so getting this working would require implementing `__array_function__`. I've tried that a couple of times with little luck. A PR would be welcome.
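To illustrate what NEP 18 dispatch involves, here is a toy sketch that is not pandas code. A hypothetical `MaskedInt` class (int64 values plus a boolean NA mask) implements `__array_function__` so that `np.where` returns a `MaskedInt` instead of falling back to `__array__`. The `HANDLED_FUNCTIONS` registry and `implements` decorator follow the pattern shown in NumPy's documentation for `__array_function__`; every name here is an assumption for illustration. A real fix in pandas would also have to cover the one-argument form of `np.where`, dtype promotion, and the many other NumPy functions that dispatch this way.

```python
import numpy as np

# Hypothetical registry mapping NumPy functions to our overrides.
HANDLED_FUNCTIONS = {}


def implements(np_function):
    """Register an override for `np_function` (illustrative helper)."""
    def decorator(func):
        HANDLED_FUNCTIONS[np_function] = func
        return func
    return decorator


class MaskedInt:
    """Toy nullable integer array: int64 values plus a boolean NA mask."""

    def __init__(self, data, mask=None):
        self.data = np.asarray(data, dtype=np.int64)
        if mask is None:
            mask = np.zeros(self.data.shape, dtype=bool)
        self.mask = np.asarray(mask, dtype=bool)

    def __array__(self, dtype=None, copy=None):
        # Fallback when no override applies: object array with None for NA,
        # the kind of conversion that loses the nullable dtype.
        out = self.data.astype(object)
        out[self.mask] = None
        return out if dtype is None else out.astype(dtype)

    def __array_function__(self, func, types, args, kwargs):
        # NEP 18: NumPy calls this before running its own implementation.
        if func not in HANDLED_FUNCTIONS:
            return NotImplemented
        if not all(issubclass(t, (MaskedInt, np.ndarray)) for t in types):
            return NotImplemented
        return HANDLED_FUNCTIONS[func](*args, **kwargs)


def _split(value):
    """Return (values, mask) for a MaskedInt, ndarray or scalar."""
    if isinstance(value, MaskedInt):
        return value.data, value.mask
    arr = np.asarray(value)
    return arr, np.zeros(arr.shape, dtype=bool)


@implements(np.where)
def masked_where(condition, x=None, y=None):
    if x is None or y is None:
        raise TypeError("this sketch only handles np.where(condition, x, y)")
    cond = np.asarray(condition, dtype=bool)
    x_data, x_mask = _split(x)
    y_data, y_mask = _split(y)
    # Pick values and NA flags with the same condition so they stay aligned.
    return MaskedInt(np.where(cond, x_data, y_data), np.where(cond, x_mask, y_mask))


a = MaskedInt([1, 2, 3, 4])
b = MaskedInt([2, 1, 7, 8], mask=[False, False, True, False])
result = np.where(a.data > b.data, a, b)  # dispatched to masked_where
print(type(result).__name__, result.data, result.mask)
```

Selecting the mask with the same condition as the values is what keeps NA positions aligned. For any function not in the registry, `__array_function__` returns `NotImplemented`, which makes NumPy raise a `TypeError` instead of silently coercing to object.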