
ColumnTransformer _hstack incompatible with scikit's version

Open • avalanche-pwn opened this issue 7 months ago • 1 comment

Describe the issue: The _hstack method of dask-ml's ColumnTransformer has a different signature from the one in scikit-learn. It lacks the n_samples keyword argument, which scikit-learn's fit_transform passes when stacking the transformer outputs.
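The failure is an ordinary Python override mismatch. The sketch below reproduces it with the standard library only; Base and Override are hypothetical stand-ins, not the real scikit-learn or dask-ml classes. As in the traceback, the parent's fit_transform passes n_samples as a keyword, so a subclass that overrides _hstack without that parameter raises TypeError.

```python
class Base:
    """Stand-in for scikit-learn's ColumnTransformer (heavily simplified)."""

    def fit_transform(self, Xs):
        n_samples = len(Xs[0])  # row count of the first block
        # The parent passes n_samples as a keyword-only argument
        return self._hstack(list(Xs), n_samples=n_samples)

    def _hstack(self, Xs, *, n_samples):
        # Concatenate the blocks column-wise, one row at a time
        return [sum(rows, []) for rows in zip(*Xs)]


class Override(Base):
    """Stand-in for a subclass that overrides _hstack without n_samples."""

    def _hstack(self, Xs):
        return [sum(rows, []) for rows in zip(*Xs)]


blocks = [[[1], [2]], [[3], [4]]]  # two blocks with two rows each
print(Base().fit_transform(blocks))  # [[1, 3], [2, 4]]

try:
    Override().fit_transform(blocks)
except TypeError as exc:
    print(exc)  # ...got an unexpected keyword argument 'n_samples'
```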

Minimal Complete Verifiable Example:

from dask_ml.feature_extraction.text import HashingVectorizer
import dask.dataframe as dd
import pandas as pd
from dask_ml.compose import ColumnTransformer

data = {
    "test1": ["example", "text"],
    "test2": ["lorem", "ipsum"]
}

df = pd.DataFrame(data)
df = dd.from_pandas(df).astype(str)

pipeline = ColumnTransformer([
    ("test1", HashingVectorizer(), "test1"),
    ("test2", HashingVectorizer(), "test2"),
    ])

pipeline.fit(df)

Anything else we need to know?: Calling pipeline.fit(df) crashes with the following traceback:

Traceback (most recent call last):
  File "/home/antoni/Documents/projects/dask/reproducers/1/main.py", line 20, in <module>
    pipeline.fit(df)
    ~~~~~~~~~~~~^^^^
  File "/home/antoni/Documents/projects/dask/reproducers/1/.venv/lib/python3.13/site-packages/sklearn/compose/_column_transformer.py", line 947, in fit
    self.fit_transform(X, y=y, **params)
    ~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^
  File "/home/antoni/Documents/projects/dask/reproducers/1/.venv/lib/python3.13/site-packages/sklearn/utils/_set_output.py", line 319, in wrapped
    data_to_wrap = f(self, X, *args, **kwargs)
  File "/home/antoni/Documents/projects/dask/reproducers/1/.venv/lib/python3.13/site-packages/sklearn/base.py", line 1389, in wrapper
    return fit_method(estimator, *args, **kwargs)
  File "/home/antoni/Documents/projects/dask/reproducers/1/.venv/lib/python3.13/site-packages/sklearn/compose/_column_transformer.py", line 1031, in fit_transform
    return self._hstack(list(Xs), n_samples=n_samples)
           ~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: ColumnTransformer._hstack() got an unexpected keyword argument 'n_samples'

The fix seems simple: add the n_samples parameter and, before returning, a check similar to the one in scikit-learn's version. I can implement this; please let me know whether this kind of fix would be enough.

Environment:

  • Dask version: 2025.5.0
  • Dask-ml version: 2025.1.0
  • scikit-learn version: 1.6.1
  • Python version: 3.13.3
  • Operating System: Linux
  • Install method (conda, pip, source): pip

avalanche-pwn • May 19 '25 14:05
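The fix proposed in the issue could look roughly like the sketch below. It assumes the check works like a row-count consistency check: accept n_samples as a keyword-only argument and raise ValueError if any transformer output does not have n_samples rows. StackingTransformer is a hypothetical stand-in, and plain lists of rows take the place of arrays or DataFrames so the sketch runs without dependencies.

```python
class StackingTransformer:
    """Hypothetical stand-in for a ColumnTransformer subclass."""

    def _hstack(self, Xs, *, n_samples):
        # Every transformer output must have one row per input sample
        row_counts = [len(X) for X in Xs]
        if any(count != n_samples for count in row_counts):
            raise ValueError(
                f"Inconsistent number of samples: expected {n_samples} rows "
                f"per transformer output, got {row_counts}."
            )
        # Concatenate the outputs column-wise, one row at a time
        return [sum(rows, []) for rows in zip(*Xs)]


stacker = StackingTransformer()
print(stacker._hstack([[[1], [2]], [[3], [4]]], n_samples=2))  # [[1, 3], [2, 4]]

try:
    # The second output is missing a row, so the check fires
    stacker._hstack([[[1], [2]], [[3]]], n_samples=2)
except ValueError as exc:
    print(exc)
```

Checking each block before zipping matters here because zip silently stops at the shortest input, which would otherwise hide a row-count mismatch.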

We could add that keyword. We'll need to figure out how to interpret it, though, since for Dask DataFrames and some array inputs the number of samples won't be available.

TomAugspurger • May 25 '25 13:05
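One possible interpretation of n_samples for lazy inputs, sketched under assumptions rather than taken from dask-ml: make the argument optional and skip the consistency check whenever it is not a concrete integer, for example None, or the nan that Dask arrays report in .shape when chunk sizes are unknown. is_known_count and LazyAwareStacker are hypothetical names, and lists of rows stand in for real containers.

```python
import math
import numbers


def is_known_count(n):
    """Return True only for a concrete integer sample count."""
    if n is None:
        return False
    # Dask arrays use nan for dimensions whose size is unknown
    if isinstance(n, float) and math.isnan(n):
        return False
    return isinstance(n, numbers.Integral)


class LazyAwareStacker:
    """Hypothetical _hstack that validates rows only when n_samples is known."""

    def _hstack(self, Xs, *, n_samples=None):
        if is_known_count(n_samples):
            row_counts = [len(X) for X in Xs]
            if any(count != n_samples for count in row_counts):
                raise ValueError(
                    f"expected {n_samples} rows per block, got {row_counts}"
                )
        # Concatenate the blocks column-wise, one row at a time
        return [sum(rows, []) for rows in zip(*Xs)]


stacker = LazyAwareStacker()
print(stacker._hstack([[[1]], [[2]]]))                          # [[1, 2]]
print(stacker._hstack([[[1]], [[2]]], n_samples=float("nan")))  # [[1, 2]]
```

A default of None also keeps the method callable without the keyword, so the same override works whether or not the caller passes n_samples.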