Add Permutation Importance

Open mayer79 opened this issue 6 months ago • 8 comments

Implements https://github.com/lorentzenchr/model-diagnostics/issues/201

mayer79 avatar May 10 '25 10:05 mayer79

This is the current basic call:

import numpy as np
import polars as pl
from sklearn.linear_model import LinearRegression

from model_diagnostics.xai import plot_permutation_importance

rng = np.random.default_rng(1)
n = 1000

X = pl.DataFrame(
    {
        "area": rng.uniform(30, 120, n),
        "rooms": rng.choice([2.5, 3.5, 4.5], n),
        "age": rng.uniform(0, 100, n),
    }
)

y = X["area"] + 20 * X["rooms"] + rng.normal(0, 1, n)

model = LinearRegression()
model.fit(X, y)

_ = plot_permutation_importance(
    predict_function=model.predict,
    X=X,
    y=y,
)

image

The extended feature API allows permuting groups of features together, like this:

_ = plot_permutation_importance(
    predict_function=model.predict,
    features={"size": ["area", "rooms"], "age": "age"},
    X=X,
    y=y,
)

image
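To illustrate what permuting a group presumably means, here is a minimal NumPy sketch. It is not the code from this PR: the helper `permute_group` is hypothetical, and it assumes that all columns of a group are shuffled with the same row permutation, so that the dependence within the group is kept while its link to `y` and to the other features is broken.

```python
import numpy as np


def permute_group(X, columns, rng):
    """Return a copy of X in which the given columns are shuffled jointly.

    Hypothetical helper for illustration only, not part of model_diagnostics.
    """
    X_perm = X.copy()
    idx = rng.permutation(X.shape[0])  # one row permutation shared by the whole group
    X_perm[:, columns] = X[idx][:, columns]
    return X_perm


rng = np.random.default_rng(1)
n = 1000
X = np.column_stack(
    [
        rng.uniform(30, 120, n),  # area
        rng.choice([2.5, 3.5, 4.5], n),  # rooms
        rng.uniform(0, 100, n),  # age
    ]
)

# Shuffle the "size" group (area, rooms) together; age stays untouched.
X_perm = permute_group(X, columns=[0, 1], rng=rng)
```

Because both columns use the same permutation, every (area, rooms) pair of the original data still occurs in `X_perm`, only in a different row.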

mayer79 avatar May 16 '25 11:05 mayer79

This will be a great addition! Thanks @mayer79

lorentzenchr avatar May 19 '25 18:05 lorentzenchr

The failing test is in the Python 3.9 env with numpy 1.22.0, polars 1.0.0, scipy 1.10.0, pandas 1.5.3, and pyarrow 11.0.0.

Could you check if increasing one of the versions fixes the problem, e.g. the polars version?

lorentzenchr avatar May 23 '25 20:05 lorentzenchr

The following version bumps would be necessary in the 3.9 env (I don't know how much it would hurt to abandon pandas 1):

  • pyarrow 11 -> 13
  • pandas 1.5 -> 2.0

I have added some additional unit tests and moved safe_copy() and get_column_names() to array.py.

mayer79 avatar May 24 '25 13:05 mayer79

fyi, CI will fail due to new versions of polars and numpy. I am working on a fix.

lorentzenchr avatar Jun 25 '25 21:06 lorentzenchr

Fixed in #203. You need to sync (e.g. merge) with the main branch, and maybe run hatch env prune on your local machine.

lorentzenchr avatar Jun 26 '25 20:06 lorentzenchr

fyi: I am preparing to bump the minimum versions of Python to 3.11 and numpy to 2. This implies polars 1.1.0, pandas >= 2.2.2, and pyarrow >= 16; see #206.

lorentzenchr avatar Jul 17 '25 14:07 lorentzenchr

I have modified these aspects in the main functionality:

  • compute_permutation_importance() now returns both score differences and score ratios
  • Instead of standard deviations, the function returns standard errors
  • The plot function has received an argument which="difference" to select whether score differences or score ratios are plotted.
  • By default, the plot function shows approximate 95% CIs. The API is the same as in your other functions, i.e., when confidence_level=0, no error bars are plotted.
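As a rough illustration of these quantities, here is a NumPy-only sketch. It is not the implementation of compute_permutation_importance(): the helper `importance_sketch`, the use of mean squared error, the sign convention (permuted score minus original score), and the normal-approximation interval mean ± 1.96 · SE are all assumptions made for this example.

```python
import numpy as np


def mse(y, pred):
    """Mean squared error; smaller is better."""
    return float(np.mean((y - pred) ** 2))


def importance_sketch(predict, X, y, column, n_repeats, rng):
    """Hypothetical helper: permutation importance of one column.

    Returns, for score differences and score ratios, a tuple
    (mean, standard error, lower, upper) over the repetitions.
    """
    base = mse(y, predict(X))
    scores = np.empty(n_repeats)
    for i in range(n_repeats):
        X_perm = X.copy()
        X_perm[:, column] = rng.permutation(X[:, column])
        scores[i] = mse(y, predict(X_perm))

    out = {}
    for name, values in {"difference": scores - base, "ratio": scores / base}.items():
        mean = values.mean()
        # Standard error of the mean, not the standard deviation of the repetitions.
        se = values.std(ddof=1) / np.sqrt(n_repeats)
        # Assumed normal approximation for an approximate 95% interval.
        out[name] = (mean, se, mean - 1.96 * se, mean + 1.96 * se)
    return out


rng = np.random.default_rng(1)
n = 1000
X = np.column_stack(
    [
        rng.uniform(30, 120, n),  # area
        rng.choice([2.5, 3.5, 4.5], n),  # rooms
        rng.uniform(0, 100, n),  # age
    ]
)
y = X[:, 0] + 20 * X[:, 1] + rng.normal(0, 1, n)

# Ordinary least squares with intercept as a stand-in for LinearRegression.
A = np.column_stack([np.ones(n), X])
coef, *_ = np.linalg.lstsq(A, y, rcond=None)


def predict(X_new):
    return coef[0] + X_new @ coef[1:]


area = importance_sketch(predict, X, y, column=0, n_repeats=5, rng=rng)
age = importance_sketch(predict, X, y, column=2, n_repeats=5, rng=rng)
```

In this setup, `area` should show a large difference and a ratio far above 1, while `age`, which does not enter `y`, should have a difference near 0 and a ratio near 1.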

mayer79 avatar Jul 26 '25 18:07 mayer79