Research: how would Narwhals work in scikit-learn?

Status: Open. MarcoGorelli opened this issue 1 year ago • 23 comments

Scikit-learn mentioned Narwhals here https://github.com/scikit-learn/scikit-learn/pull/28513#issuecomment-2131226993

Regardless of whether they decide to use it or not, it would be beneficial to check whether Narwhals would at least work in scikit-learn, and whether it could in principle solve the linked issue.

@EdAbati - you've contributed to scikit-learn, and you know Narwhals well - perhaps this issue might be of interest to you?

MarcoGorelli, Jun 29 '24

Uuuh, this could be fun! I've contributed to features related to Array API compatibility lately, but I'll gladly try to figure out this part of scikit-learn too.

> it would be beneficial to check whether Narwhals would at least work in scikit-learn, and whether it could in principle solve the linked issue

Agreed, let's start having a look at these:

  • [x] sklearn.inspection.permutation_importance
  • [x] sklearn.model_selection.cross_val_score
  • [x] sklearn.model_selection.cross_validate

~~If anyone wants to help, please feel free to comment.~~ This is done.

EdAbati, Jul 01 '24

Some early thoughts:

  • What could be the cleanest way to select by row index? The pandas equivalent would be X.iloc[shuffling_idx, col_idx], and I guess Polars also works with X[shuffling_idx, col_idx], right? What do you think, is it a missing feature?
  • What about .copy()/.clone()? FYI, I am looking at this.
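For context, here is a minimal pandas-only sketch of the positional selection in question, roughly the "shuffle one column" step that permutation importance needs. It is an illustration under assumptions, not scikit-learn's actual code; `permute_column` is a hypothetical helper.

```python
import numpy as np
import pandas as pd


def permute_column(X: pd.DataFrame, col_idx: int, seed: int = 0) -> pd.DataFrame:
    """Return a copy of X whose column at position `col_idx` is shuffled (hypothetical helper)."""
    rng = np.random.default_rng(seed)
    shuffling_idx = rng.permutation(len(X))
    # Work on a copy so the caller's frame is untouched (the .copy()/.clone() question).
    X_permuted = X.copy()
    # Positional row + column selection, i.e. X.iloc[shuffling_idx, col_idx].
    # .to_numpy() drops the index so pandas does not re-align the shuffled values.
    X_permuted.iloc[:, col_idx] = X.iloc[shuffling_idx, col_idx].to_numpy()
    return X_permuted


X = pd.DataFrame({"a": [1, 2, 3, 4], "b": [10, 20, 30, 40]})
X_perm = permute_column(X, col_idx=1)
print(X_perm)
```

Only column `b` is shuffled; column `a` and the original `X` stay as they were.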

EdAbati, Jul 01 '24

> what do you think, is it a missing feature?

I'd say so, yes! I think we need this one in Altair too. Does implementing it interest you? I think it just requires an extra branch in DataFrame.__getitem__.

> what about .copy()/.clone()?

Sure, doesn't hurt to add DataFrame.clone 👍

MarcoGorelli, Jul 01 '24

> Does implementing it interest you?

Most of the time my answer to that question is "yes". My only problem is time 😅 I will create an issue in case someone else is able to pick it up before me.

EdAbati, Jul 02 '24

😄 I'll take on the getitem one then, so we can propose it to Altair too

MarcoGorelli, Jul 02 '24

Another useful feature could be something like Polars' DataFrame.replace_column. Or do we already have the capability to overwrite a column without changing its position?

EdAbati, Jul 10 '24

Hey - I think this can be worked around with select? Or, possibly better, with_columns followed by rename?

I'm not totally sure about including in-place operations in Narwhals; I fear they'll bite us later.

MarcoGorelli, Jul 10 '24

My daily struggle was again when the column name is an int 😅

If I'm not wrong, with_columns wants a string column name. rename seems to be happy with renaming to int. I'll clean up my messy code tonight and report back

EdAbati, Jul 11 '24

Does it work to do df.columns.index(name) to get the (int) position and then use that?

MarcoGorelli, Jul 11 '24

Ah right, I should have actually tried running it instead of commenting from my phone:

```
  File "/home/marcogorelli/scratch/.venv/lib/python3.11/site-packages/narwhals/_pandas_like/dataframe.py", line 258, in with_columns
    df = self._native_dataframe.assign(
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: keywords must be strings
```
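The error appears to come from Python itself rather than from pandas: judging by the traceback, `with_columns` forwarded the new columns to pandas' `DataFrame.assign` as keyword arguments, and Python only accepts string keyword names. A minimal pandas-only reproduction, plus one way to overwrite an int-named column in place (a sketch, not what Narwhals actually does):

```python
import pandas as pd

df = pd.DataFrame({0: [1, 2, 3], "bar": [6, 7, 8]})

# Unpacking a dict with an int key into **kwargs fails before pandas runs:
# Python requires keyword-argument names to be strings.
raised = False
try:
    df.assign(**{0: [4, 3, 2]})
except TypeError as exc:
    raised = True
    print(f"TypeError: {exc}")

# Setting the column by label on a copy accepts int names and keeps its position.
out = df.copy()
out[0] = [4, 3, 2]
print(out)
```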

MarcoGorelli, Jul 11 '24

This seems to work even for non-string column names?

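```python
import narwhals as nw
import pandas as pd


@nw.narwhalify
def replace_column(df: nw.DataFrame, index: int, column: nw.Series):
    # Name of the column currently at position `index` (may be a non-string, e.g. 0)
    col_name = df.columns[index]
    # Keep every other column as-is; put the new Series in the same position,
    # aliased to the old name for now so the column order is preserved
    out_columns = [col if col != col_name else column.alias(col_name) for col in df.columns]
    # Finally give the replaced column the new Series' name
    return df.select(out_columns).rename({col_name: column.name})


data = {
    0: [1, 2, 3],
    "bar": [6, 7, 8],
    "ham": ["a", "b", "c"],
}
print(replace_column(pd.DataFrame(data), 0, pd.Series([4, 3, 2], name="baz")))
```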

MarcoGorelli, Jul 11 '24

FYI my changes are here: https://github.com/scikit-learn/scikit-learn/compare/main...EdAbati:scikit-learn:RESEARCH-narwhals

Now the 3 functions linked above seem to work.

Need to check if they were solved in main too and if I broke any tests :) I'll keep you posted

EdAbati, Jul 22 '24

Update: cross_val_score and cross_validate were fixed in main. The example with LogisticRegression mentioned in the original ticket works. Some tests that use RandomForest fail; I am looking at those now.

The next step I'd say would be to find all places where polars and pandas are mentioned/used in sklearn and see if those lines can be "narwhalified". 👀
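One rough way to build that inventory is a plain text search over the source tree. The sketch below is a hypothetical helper (`find_mentions` is not part of any library); it only finds lines that mention the names, so deciding what can be "narwhalified" still needs a human pass.

```python
import re
import tempfile
from pathlib import Path


def find_mentions(root, names=("pandas", "polars")):
    """Yield (path, line_number, stripped_line) for .py files under `root` mentioning any of `names`."""
    pattern = re.compile(r"\b(?:" + "|".join(re.escape(n) for n in names) + r")\b")
    for path in sorted(Path(root).rglob("*.py")):
        text = path.read_text(encoding="utf-8", errors="replace")
        for lineno, line in enumerate(text.splitlines(), start=1):
            if pattern.search(line):
                yield path, lineno, line.strip()


# Demo on a throwaway directory; point `root` at a local scikit-learn
# checkout (e.g. "sklearn/") to get the real inventory.
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "mod.py").write_text("import pandas as pd\nx = 1\nimport polars as pl\n")
    hits = list(find_mentions(tmp))

for path, lineno, line in hits:
    print(f"{path.name}:{lineno}: {line}")
```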

EdAbati, Jul 23 '24

wow, amazing! 💪 well done

there's this example of theirs which uses Polars, would be good to check if it works (and also if using pyarrow instead just works "for free"?)

MarcoGorelli, Jul 23 '24

Hi all, I'm working on moving some internal modules that depend on scikit-learn over to narwhals, and am running into an issue that I feel falls under this discussion. It seems like the feature names do not get passed through if you fit with a narwhals DataFrame:

```python
import narwhals as nw
import pandas as pd
import polars as pl
from sklearn.preprocessing import StandardScaler

df_pd = pd.DataFrame({"a": [0, 1, 2], "b": [3, 4, 5]})
df_pl = pl.DataFrame(df_pd)
df_nw = nw.from_native(df_pd)

s_pd, s_pl, s_nw = StandardScaler(), StandardScaler(), StandardScaler()
s_pd.fit(df_pd)
s_pl.fit(df_pl)
s_nw.fit(df_nw)

print(s_pd.feature_names_in_)
print(s_pl.feature_names_in_)
print(s_nw.feature_names_in_)
```

Expected output:

```
['a' 'b']
['a' 'b']
['a' 'b']
```

Actual output:

```
['a' 'b']
['a' 'b']
AttributeError: 'StandardScaler' object has no attribute 'feature_names_in_'
```

Everything else is working as expected except for the column names not passing through, and it persists when converting from both native types (pandas & polars). My apologies if this is not the right place to post this, I'd just like to hear @EdAbati's thoughts on this since it seems like he's well versed in both libraries.

ryansheabla, Mar 18 '25

You may want to call to_native before passing the data to scikit-learn.

If you'd like scikit-learn to support passing a Narwhals DataFrame directly, could you open a feature request with scikit-learn, please?

MarcoGorelli, Mar 18 '25

Yes, calling to_native does cause it to work as expected.

I'll put up an issue there as well; I was unsure whether the problem was on the narwhals or the scikit-learn end, since everything works as expected except for the feature names. Thanks for the quick response!

ryansheabla, Mar 18 '25

Adoption of narwhals in scikit-learn is now properly discussed in https://github.com/scikit-learn/scikit-learn/issues/31049.

I also opened a draft PR https://github.com/scikit-learn/scikit-learn/pull/31127 to experiment. Unfortunately, it reveals problems. I wanted to use narwhals in scikit-learn's internal _safe_indexing(X, indices, *, axis=0). For row indexing (axis=0), we allow indices to be a boolean or integer array-like, an integer slice, or a scalar integer. Narwhals does not accept integer array-likes such as a pandas Series or a tuple (though NumPy arrays and lists work).

Example:

```python
import narwhals as nw
import numpy as np
import pandas as pd

df = pd.DataFrame({"col1": ["a", "b", "c", "d"], "col2": np.arange(4)})
df_nw = nw.from_native(df)

# This works
df_nw[list(range(2))]
df_nw[np.arange(2)]

# This errors
df_nw[pd.Series(np.arange(2))]
df_nw[tuple(range(2))]
```

```
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[17], line 1
----> 1 df_nw[pd.Series(np.arange(2))]

File python3.12/site-packages/narwhals/dataframe.py:916, in DataFrame.__getitem__(self, item)
    914 else:
    915     msg = f"Expected str or slice, got: {type(item)}"
--> 916     raise TypeError(msg)

TypeError: Expected str or slice, got: <class 'pandas.core.series.Series'>
```

lorentzenchr, Apr 13 '25

thanks @lorentzenchr for the report

~~i think we should allow that, it's a perfectly well-defined operation~~

MarcoGorelli, Apr 13 '25

> i think we should allow that, it's a perfectly well-defined operation

Sorry, I spoke too soon

I think it's correct that we disallow df_nw[tuple(range(2))], because that's indistinguishable from df_nw[0, 1]:

```
In [37]: class Foo:
    ...:     def __getitem__(self, item):
    ...:         print(item)
    ...:

In [38]: Foo()[0, 0]
(0, 0)

In [39]: Foo()[(0, 0)]
(0, 0)
```

And Polars in fact doesn't slice the DataFrame with that operation; it returns a single element, equivalent to calling DataFrame.item:

```
In [40]: dfpl = pl.DataFrame({'a':[1,2,3],'b':[4,5,6]})

In [41]: dfpl[tuple(range(2))]
Out[41]: 4

In [42]: dfpl.item(0, 1)
Out[42]: 4
```

So we can't accept arbitrary array-like objects in DataFrame.__getitem__, just as Polars doesn't.

@lorentzenchr given that _safe_indexing is an internal-only function, I think you could call np.asarray on the input indices? pandas Series implement __array__ so it would work if indices is a pandas Series too
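A minimal sketch of that suggestion: convert array-like indexers with `np.asarray` and leave slices and scalar integers untouched. `normalize_row_indexer` is a hypothetical name, not scikit-learn's actual `_safe_indexing` code.

```python
import numpy as np
import pandas as pd


def normalize_row_indexer(indices):
    """Turn an integer/boolean array-like into a NumPy array; leave slices and scalars alone (hypothetical helper)."""
    if isinstance(indices, (slice, int, np.integer)):
        return indices
    # pandas Series (via __array__), tuples, lists, ... all become ndarrays.
    return np.asarray(indices)


print(normalize_row_indexer(pd.Series(np.arange(2))))
print(normalize_row_indexer(tuple(range(2))))
```

The result is a NumPy array, i.e. the same kind of indexer as `np.arange(2)` in the earlier example, which Narwhals already accepts.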

MarcoGorelli, Apr 13 '25

What you can do, however, to pass a tuple is to use a full slice (:) as the second element of the key. This works in both Narwhals and Polars:

```
In [55]: df_nw[tuple(range(2)), :]
Out[55]:
┌──────────────────┐
|Narwhals DataFrame|
|------------------|
|     col1  col2   |
|   0    a     0   |
|   1    b     1   |
└──────────────────┘

In [56]: df_nw.to_polars()[tuple(range(2)), :]
Out[56]:
shape: (2, 2)
┌──────┬──────┐
│ col1 ┆ col2 │
│ ---  ┆ ---  │
│ str  ┆ i64  │
╞══════╪══════╡
│ a    ┆ 0    │
│ b    ┆ 1    │
└──────┴──────┘
```

So long as this form of __getitem__ is used, I think we should be able to relax the input to accept array-like objects
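To make the idea concrete, here is a toy sketch (a hypothetical `Frame` wrapper over pandas, not Narwhals' implementation) of a `__getitem__` that reads every tuple as `(rows, columns)` and accepts arbitrary integer array-likes only in that form. Unlike Polars, this toy returns a 1x1 frame rather than a scalar for `df[0, 1]`.

```python
import numpy as np
import pandas as pd


class Frame:
    """Toy DataFrame wrapper sketching one possible __getitem__ dispatch (hypothetical)."""

    def __init__(self, native: pd.DataFrame):
        self._native = native

    def to_native(self) -> pd.DataFrame:
        return self._native

    @staticmethod
    def _positions(indexer):
        # Slices pass through; scalars become one-element lists so .iloc keeps
        # two dimensions; any other array-like (Series, tuple, ...) becomes an ndarray.
        if isinstance(indexer, slice):
            return indexer
        if isinstance(indexer, (int, np.integer)):
            return [indexer]
        return np.asarray(indexer)

    def __getitem__(self, key):
        if isinstance(key, str):
            return self._native[key]  # a single column, selected by name
        if isinstance(key, tuple):
            # df[0, 1] and df[(0, 1)] both arrive here as the key (0, 1), so any
            # tuple is read as (rows, columns) and never as "rows 0 and 1".
            if len(key) != 2:
                raise TypeError(f"Expected (rows, columns), got a tuple of length {len(key)}")
            rows, cols = key
            return Frame(self._native.iloc[self._positions(rows), self._positions(cols)])
        if isinstance(key, (slice, list, np.ndarray)):
            # A bare (non-tuple) key selects rows and only accepts a narrow set of types.
            return Frame(self._native.iloc[key])
        raise TypeError(f"Expected str, slice, list, ndarray or (rows, columns), got: {type(key)}")


df = Frame(pd.DataFrame({"col1": ["a", "b", "c", "d"], "col2": np.arange(4)}))
print(df[pd.Series(np.arange(2)), :].to_native())  # rows 0-1, all columns
print(df[:, np.array([0, 1])].to_native())         # all rows, columns 0-1
print(df[tuple(range(2))].to_native())             # row 0, column 1 as a 1x1 frame
```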

MarcoGorelli, Apr 13 '25

Not supporting tuples is fine. But:

```python
df_nw[:, [0, 1]]  # works
df_nw[:, np.array([0, 1])]  # raises error
```

With Polars both work; pandas also raises (in scikit-learn, this is circumvented by calling df.take(index, axis=1) instead).
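For reference, the pandas workaround mentioned above can be sketched as follows: `DataFrame.take` selects by position along the chosen axis and accepts a NumPy array of positions, and `.iloc` is an equivalent positional route. The frame and `col_idx` here are illustrative.

```python
import numpy as np
import pandas as pd

df = pd.DataFrame({"col1": ["a", "b", "c", "d"], "col2": np.arange(4), "col3": np.arange(4.0)})

col_idx = np.array([0, 2])
# Positional column selection that accepts a NumPy array of positions.
subset_take = df.take(col_idx, axis=1)
# The same positional selection through .iloc.
subset_iloc = df.iloc[:, col_idx]
print(subset_take)
```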

lorentzenchr, Apr 13 '25

Thanks for reporting. I think DataFrame.__getitem__ is due a rewrite. This won't make it into tomorrow's release, but I'll try to get something ready for the following one.

MarcoGorelli, Apr 13 '25