model-diagnostics
Add Permutation Importance
Implements https://github.com/lorentzenchr/model-diagnostics/issues/201
This is the current basic call:
```python
import numpy as np
import polars as pl
from sklearn.linear_model import LinearRegression

from model_diagnostics.xai import plot_permutation_importance

rng = np.random.default_rng(1)
n = 1000
X = pl.DataFrame(
    {
        "area": rng.uniform(30, 120, n),
        "rooms": rng.choice([2.5, 3.5, 4.5], n),
        "age": rng.uniform(0, 100, n),
    }
)
y = X["area"] + 20 * X["rooms"] + rng.normal(0, 1, n)

model = LinearRegression()
model.fit(X, y)

_ = plot_permutation_importance(
    predict_function=model.predict,
    X=X,
    y=y,
)
```
The extended feature API additionally allows permuting groups of columns together:
```python
_ = plot_permutation_importance(
    predict_function=model.predict,
    features={"size": ["area", "rooms"], "age": "age"},
    X=X,
    y=y,
)
```
This will be a great addition! Thanks @mayer79
The failing test is in the Python 3.9 env with numpy 1.22.0, polars 1.0.0, scipy 1.10.0, pandas 1.5.3 and pyarrow 11.0.0.
Could you check whether increasing one of the versions fixes the problem, e.g. the polars version?
The following changes in the 3.9 env would be necessary; I don't know how much it would hurt to abandon pandas 1:
- pyarrow 11 -> 13
- pandas 1.5 -> 2.0
I have added some additional unit tests and moved `safe_copy()` and `get_column_names()` to `array.py`.
fyi, CI will fail due to new versions of polars and numpy. I am working on a fix.
Fix in #203; you need to sync (e.g. merge) with the main branch (and maybe run `hatch env prune` on your local machine).
fyi: I am preparing to bump the minimum versions of python to 3.11 and numpy to 2. This implies polars 1.1.0, pandas >= 2.2.2 and pyarrow >= 16, see #206.
I have modified these aspects of the main functionality:
- `compute_permutation_importance()` now returns both score differences and score ratios.
- Instead of standard deviations, the function returns standard errors.
- The plot function has received an argument `which="difference"` to select whether score differences or ratios are plotted.
- By default, the plot function shows approximate 95% CIs. The API is as in your other functions, i.e., when `confidence_level=0`, no error bars are plotted.
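To illustrate the statistics listed above, here is a small sketch of how a score difference, a score ratio, and the standard error of the difference could be derived from per-repeat scores. The helper `diff_ratio_with_se` is hypothetical and only mirrors the described return values, not the actual implementation.

```python
import numpy as np


def diff_ratio_with_se(base_score, perm_scores):
    """Given the baseline score and one score per permutation repeat,
    return (mean difference, mean ratio, standard error of the difference).
    Illustrative sketch, not the package's code."""
    perm_scores = np.asarray(perm_scores, dtype=float)
    n = perm_scores.size
    diffs = perm_scores - base_score
    ratios = perm_scores / base_score
    # Standard error = sample standard deviation / sqrt(number of repeats).
    se_diff = diffs.std(ddof=1) / np.sqrt(n)
    return diffs.mean(), ratios.mean(), se_diff
```

An approximate 95% CI for the difference is then `mean ± 1.96 * se_diff`, which collapses to no error bars when a zero confidence level is requested.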