
ENH Introduces set_output API for pandas output

thomasjpfan opened this pull request.

Reference Issues/PRs

Closes https://github.com/scikit-learn/scikit-learn/issues/23001
Implements SLEP018: https://github.com/scikit-learn/enhancement_proposals/pull/68

What does this implement/fix? Explain your changes.

This PR introduces:

  1. set_output for Pipeline and transformers in preprocessing.
  2. A common test that checks the behavior of set_output, which will be used in follow-up PRs that add set_output to all other transformers.
  3. OutputTypeMixin: most transformers only need to subclass it to get the set_output API.
  4. Global configuration option set_config(output_transform="pandas") to set the output globally.

thomasjpfan, Jun 23 '22

Ok so the problem with inheriting from TransformerMixin is that it's hard to override the behavior implemented in the mixin. With the implementation in this PR, transformers can decide not to inherit from SetOutputMixin and implement transform and fit_transform themselves to support set_output.

If we add SetOutputMixin to TransformerMixin, everybody inheriting from TransformerMixin will automatically get SetOutputMixin, but it's harder to customize. There are two ways to allow customization: either make _wrap_method_output a method that subclasses can override, or add **kwargs to __init_subclass__ so that a subclass can disable the automatic wrapping behavior.
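The two customization routes can be illustrated with a small, self-contained sketch. Everything in it is hypothetical (SetOutputMixinSketch, WrappedOutput, _wrap_method and the auto_wrap_output class keyword); it is not scikit-learn's code, and a toy WrappedOutput stands in for a DataFrame so that pandas is not needed. Route one is an overridable _wrap_method_output hook; route two is a class keyword consumed by __init_subclass__.

```python
import functools
from typing import Any, NamedTuple


class WrappedOutput(NamedTuple):
    """Toy container standing in for a pandas DataFrame (hypothetical)."""
    data: Any


def _wrap_method(method):
    """Decorate a transform-like method so its result goes through the hook."""
    @functools.wraps(method)
    def wrapped(self, X, *args, **kwargs):
        return self._wrap_method_output(method(self, X, *args, **kwargs))
    return wrapped


class SetOutputMixinSketch:
    def __init_subclass__(cls, auto_wrap_output=True, **kwargs):
        super().__init_subclass__(**kwargs)
        if not auto_wrap_output:
            return  # route two: opt out with a class keyword
        for name in ("transform", "fit_transform"):
            if name in vars(cls):  # wrap only methods defined on this class
                setattr(cls, name, _wrap_method(vars(cls)[name]))

    def _wrap_method_output(self, data):
        # Route one: subclasses may override this hook.
        return WrappedOutput(data)


class Doubler(SetOutputMixinSketch):
    def transform(self, X):
        return [2 * x for x in X]


class RawDoubler(SetOutputMixinSketch, auto_wrap_output=False):
    def transform(self, X):
        return [2 * x for x in X]


class TupleDoubler(Doubler):
    def _wrap_method_output(self, data):
        return tuple(data)


print(Doubler().transform([1, 2]))       # WrappedOutput(data=[2, 4])
print(RawDoubler().transform([1, 2]))    # [2, 4]
print(TupleDoubler().transform([1, 2]))  # (2, 4)
```

Route two keeps the opt-out declarative at class-definition time, while route one lets a subclass change the container without touching transform itself.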

Let the bikeshedding begin!

amueller, Jul 17 '22

This PR was updated:

  1. TransformerMixin now inherits from SetOutputMixin, so many transformers get set_output automatically. This is both a pro and a con; I am open to reverting this inheritance.
  2. set_output is available only when get_feature_names_out is defined.
  3. SetOutputMixin includes an auto_wrap_output option to disable the automatic wrapping behavior.
  4. ColumnTransformer now outputs DataFrames with mixed types. I had to pull in the fix from #23993 for FunctionTransformer to get passthrough to always return a DataFrame.

thomasjpfan, Aug 20 '22

I have not yet looked at some of the tests (apart from the common tests); I will go through them. I think we are missing documentation (for both users and developers).

Basically, the points that could spark some discussion are:

  • we are mutating the passed estimators in the pipeline and column transformer. Are we fine with this?
  • for the config, we use a flat dict. Would a nested dict be easier to work with?
  • in the config we store 2 keys: one for transform and one for fit_transform. I would merge them, since one key for *transform is enough: fit always returns self. The same would apply to fit_predict and predict. I don't see a use case where one wants a DataFrame from transform but not from fit_transform.

glemaitre, Sep 12 '22

This PR was updated with:

  1. Moved available_if to its own file to prevent a circular dependency.
  2. SetOutputMixin is now private. Classes that inherit from TransformerMixin will get the feature right away.
  3. Added documentation for users and developers.
  4. Made all the utility functions private.

Things I need to think about:

  1. FunctionTransformer that already returns a DataFrame: https://github.com/scikit-learn/scikit-learn/pull/23734#discussion_r978181639
  2. Using sklearn.experimental as a feature flag. I can get this to work for set_output, which is a user-facing API, but I'm still deciding whether we need to mark set_output as experimental at all. The API is currently fairly simple; the only changes I foresee are adding predict="pandas" or enabling sparse support, and neither would break backward compatibility.

thomasjpfan, Sep 26 '22

How difficult/easy would it be to extend the set_output options to, e.g., "arrow"?

I see two possible paths: support it directly in scikit-learn, or provide an API to configure "arrow".

  1. Add another function _wrap_in_arrow_container and dispatch to it in _wrap_data_with_container when set_output(transform="arrow").
  2. Allow set_output(transform=callable), where a user defines a callable that constructs an arrow container. This callable would likely have the same signature as _wrap_in_pandas_container: (data, columns, index).
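The two paths can be sketched as a single dispatch function. This is a hypothetical illustration, not scikit-learn code: wrap_data_with_container, the _CONTAINERS registry and the "records" container (a list of dicts standing in for an Arrow table, so that pyarrow is not required) are all made up. Only the (data, columns, index) signature comes from the comment above.

```python
from typing import Any, Callable, Dict, Sequence


def _wrap_in_records(data, columns, index):
    """Toy container constructor; `index` is accepted for signature parity but ignored."""
    return [dict(zip(columns, row)) for row in data]


# Path 1: built-in containers registered under a string name.
_CONTAINERS: Dict[str, Callable[..., Any]] = {"records": _wrap_in_records}


def wrap_data_with_container(data, columns: Sequence[str], index, output):
    # Path 2: a user-supplied callable with the same (data, columns, index) signature.
    if callable(output):
        return output(data, columns, index)
    if output not in _CONTAINERS:
        raise ValueError(f"Unknown output container: {output!r}")
    return _CONTAINERS[output](data, columns, index)


rows = wrap_data_with_container([[1, 2], [3, 4]], ["a", "b"], None, "records")
print(rows)  # [{'a': 1, 'b': 2}, {'a': 3, 'b': 4}]
```

Path 1 keeps the set of supported containers under the library's control; path 2 moves that responsibility to the user at the cost of a less discoverable API.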

What happens with 3rd party transformers in a pipeline if global set_config(transform_output="pandas") has been set?

If they inherit from TransformerMixin, then they automatically opt into set_output and transform will output DataFrames. They can opt out by passing auto_wrap_output_keys=None in the class definition.

If set_config(transform_output="pandas") has been set, and a model is saved via pickle and later loaded in a different Python process, will it output pandas?

If the option was set globally with set_config, it must be set again in the new process to get pandas output. If it was set locally, i.e. with transformer.set_output, then the loaded transformer still outputs pandas.

In the default setting, what happens with Array API compliant data containers other than numpy, e.g. cupy, see #22352?

The default setting keeps the current behavior on main: the transformer decides what it outputs (ndarray, cupy, DataFrames, etc.).

thomasjpfan, Oct 07 '22

What happens with 3rd party transformers in a pipeline if global set_config(transform_output="pandas") has been set?

If they inherit from TransformerMixin, then they automatically opt into set_output and transform will output DataFrames. They can opt out by passing auto_wrap_output_keys=None in the class definition.

What if they do not inherit from TransformerMixin? Stated otherwise: are we clear enough on the expected API for a 3rd party transformer that, for whatever reason, does not want to depend on scikit-learn?

lorentzenchr, Oct 09 '22

@thomasjpfan The only open point for me is the name of the default option of transform_output, see https://github.com/scikit-learn/scikit-learn/pull/23734#discussion_r985071192. Your suggestion to use None instead of "default" works for me. As this would deviate from the SLEP, other opinions are highly appreciated.

Once that is settled, I will give my review approval.

lorentzenchr, Oct 09 '22

What if they do not inherit from TransformerMixin? Stated otherwise: are we clear enough on the expected API for a 3rd party transformer that, for whatever reason, does not want to depend on scikit-learn?

For now, I prefer not to require the API from third-party estimators. Our meta-estimators, such as ColumnTransformer, will complain when a transformer does not define set_output. As for the global option, I do not think we can require third parties to respect it.

Fully adopting the set_output API requires at least a soft dependency on scikit-learn and pandas:

  • scikit-learn: To follow scikit-learn's global config, one needs to get the global config from scikit-learn.
  • pandas: Outputting DataFrames requires it.
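A sketch of what such a soft dependency could look like for a third-party transformer that does not inherit from TransformerMixin. StandaloneDoubler and _sklearn_transform_output are hypothetical; sklearn.get_config and its "transform_output" key are real in scikit-learn 1.2 and later. Both scikit-learn and pandas are imported lazily, so the class needs only NumPy at import time.

```python
import numpy as np


def _sklearn_transform_output():
    """Read scikit-learn's global option if scikit-learn is installed (soft dependency)."""
    try:
        from sklearn import get_config
    except ImportError:
        return None
    # .get() because scikit-learn versions before 1.2 have no "transform_output" key.
    return get_config().get("transform_output")


class StandaloneDoubler:
    """Hypothetical third-party transformer that does not inherit from TransformerMixin."""

    def fit(self, X, y=None):
        n_features = np.shape(X)[1]
        default_names = [f"x{i}" for i in range(n_features)]
        self.feature_names_out_ = [str(c) for c in getattr(X, "columns", default_names)]
        return self

    def transform(self, X):
        out = np.asarray(X, dtype=float) * 2
        if _sklearn_transform_output() == "pandas":
            import pandas as pd  # only needed when DataFrame output is requested

            return pd.DataFrame(out, columns=self.feature_names_out_, index=getattr(X, "index", None))
        return out


est = StandaloneDoubler().fit([[1.0, 2.0]])
out = est.transform([[1.0, 2.0]])
```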

thomasjpfan, Oct 09 '22

Looking through the code, we cannot use None instead of "default". Currently, the signature is set_output(self, transform=None), where None is a sentinel meaning "do not configure transform". This is required if we add set_output(predict="pandas"), which configures the container for predict but leaves transform alone. Secondly, est.set_output() with no input configures nothing. I could use another sentinel object, but that would complicate the set_output API.
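A minimal sketch of why None has to stay a sentinel. OutputConfigSketch is hypothetical, and so is its predict keyword (the thread only mentions set_output(predict="pandas") as a possible future addition). None means "leave this method's setting alone", so it cannot also name the default container.

```python
class OutputConfigSketch:
    """Hypothetical estimator fragment: None means 'do not configure'."""

    def __init__(self):
        self._output_config = {}

    def set_output(self, *, transform=None, predict=None):
        for method, container in (("transform", transform), ("predict", predict)):
            if container is None:
                continue  # sentinel: keep whatever was configured before
            self._output_config[method] = container
        return self


est = OutputConfigSketch().set_output(transform="pandas")
est.set_output(predict="pandas")  # configures predict, leaves transform alone
est.set_output()                  # configures nothing
print(est._output_config)  # {'transform': 'pandas', 'predict': 'pandas'}
```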

I think the best option is to find a string that best describes the behavior. I think "native" is okay. Another option is "any".

@lorentzenchr WDYT?

thomasjpfan, Oct 11 '22

Yes, "native" is fine, even good! I'm happy with any default value other than "default", which I would consider bad design because it says nothing about the actual behaviour.

lorentzenchr, Oct 11 '22

I updated this PR to use "native". I also opened a PR to update SLEP: https://github.com/scikit-learn/enhancement_proposals/pull/78

thomasjpfan, Oct 11 '22

Merging. If #78 concludes that the default value should still change, we can do that in a new PR.

lorentzenchr, Oct 12 '22

Yaaaaay!!

amueller, Oct 14 '22