DALEX
Error when examining aggregated SHAP values in R using tidymodels
I have trained and fitted a C5.0 model with adaptive boosting to predict a binary outcome, using tidymodels and the parsnip package.
I am now trying to evaluate the model in more detail using DALEX and DALEXtra.
The outcome variable is "yes/no", so I converted it to binary 1/0. I created an explainer with the DALEXtra package using the following code:
explain_c50 <-
  explain_tidymodels(model = c5_final_fit,
                     data = final_data_test,
                     y = y_test,
                     verbose = FALSE)
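A minimal base-R sketch of the preprocessing step described above, assuming a hypothetical data frame `df` whose outcome column is called `outcome` with values "yes"/"no" (both names are assumptions, not taken from the question). It builds a 0/1 target vector and a predictor-only data frame, following the pattern used in many DALEX examples of passing `data` without the target column and supplying the target separately as `y`. With this coding, the predict function should return the probability of the class coded as 1 ("yes" here).

```r
# Hypothetical toy data standing in for final_data_test;
# `outcome` is an assumed column name
df <- data.frame(
  age     = c(54, 71, 63, 48),
  lactate = c(1.2, 4.8, 2.1, 0.9),
  outcome = factor(c("no", "yes", "yes", "no"))
)

# 0/1 target: 1 for "yes", 0 for "no"
y_test <- as.numeric(df$outcome == "yes")

# Predictors only: drop the target column, keep a data frame even if one column remains
x_test <- df[, setdiff(names(df), "outcome"), drop = FALSE]

print(y_test)
print(names(x_test))
```

`x_test` and `y_test` would then be passed as `data` and `y` to the explainer.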
I also created an explainer using only the DALEX package and ran into exactly the same issue:
custom_predict <- function(object, newdata) {
  pred <-
    predict(object, newdata, type = 'prob')[1] %>%
    pull(.pred_F)
  return(pred)
}
DALEX_explainerTest <- DALEX::explain(model = c5_final_fit,
                                      data = final_data_test,
                                      predict_function = custom_predict,
                                      y = y_test,
                                      label = "c50-train")
These run without error. However, when I try to compute aggregated SHAP values, I get the following error:
DALEX::shap_aggregated(explainer = explain_c50, new_observations = final_data_test[1:10, ])
The error message is always the same, no matter what I try:
Error in `[<-.data.frame`(`*tmp*`, , candidate, value = list(pure_model_prediction = list( :
replacement element 1 is a matrix/data frame of 1 row, need 21
I have tried processing the data slightly differently, using the training data instead of the test data, and so on. Since the error never changes, I assume I am making a fairly fundamental mistake, but I cannot see what it is.
Does anyone reading this have any idea?
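The DALEX documentation for `explain()` describes `predict_function` as a function that takes a model and new data and returns a numeric vector of predictions. The error above mentions a "matrix/data frame of 1 row" being assigned where more values are needed, so one thing that may be worth ruling out (an assumption, not a confirmed diagnosis) is a predict function that returns a data frame or tibble instead of a plain vector. The sketch below uses a hypothetical helper, `check_predict_function()`, and a base-R logistic regression on `mtcars` as a stand-in for the C5.0 model.

```r
# Hypothetical helper: checks that a predict function returns what DALEX expects,
# i.e. a plain numeric vector with one value per row of the new data
check_predict_function <- function(predict_function, model, newdata) {
  pred <- predict_function(model, newdata)
  if (is.data.frame(pred) || is.matrix(pred)) {
    stop("predict_function returned a ", class(pred)[1], ", not a numeric vector")
  }
  if (!is.numeric(pred)) {
    stop("predict_function returned a non-numeric object of class ", class(pred)[1])
  }
  if (length(pred) != nrow(newdata)) {
    stop("predict_function returned ", length(pred), " values for ",
         nrow(newdata), " rows")
  }
  invisible(TRUE)
}

# Stand-in model: base-R logistic regression on a built-in dataset
fit <- glm(am ~ wt + hp, data = mtcars, family = binomial)

# Returns a plain numeric vector of probabilities (what DALEX expects)
good_predict <- function(object, newdata) {
  as.numeric(predict(object, newdata, type = "response"))
}

# Returns a one-column data frame instead (the shape this check flags)
bad_predict <- function(object, newdata) {
  data.frame(prob = predict(object, newdata, type = "response"))
}

check_predict_function(good_predict, fit, mtcars[1:10, ])  # passes silently

res <- tryCatch(check_predict_function(bad_predict, fit, mtcars[1:10, ]),
                error = function(e) conditionMessage(e))
print(res)
```

With the objects from the question, the same check could be applied to either explainer, for example `check_predict_function(explain_c50$predict_function, explain_c50$model, final_data_test[1:10, ])`, since a DALEX explainer stores its model and predict function as list elements.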
@hbaniecki I see an "invalid" label has been added to my post. Have I posted incorrectly?
Hi, thanks for raising the issue. The "invalid" label is meant to denote a bug, i.e. something not working as intended.
Oh I understand now. As you might guess, I am an enthusiastic amateur in this world. I am a medical doctor exploring machine learning in Critical Care. Thank you for clarifying!
@hbaniecki According to https://docs.github.com/en/issues/using-labels-and-milestones-to-track-work/managing-labels, "invalid" means that the issue/PR is no longer relevant. Maybe we can replace it by "Bug"?
Labels are described at https://github.com/ModelOriented/DALEX/labels, and you can hover over them to read their descriptions. I wouldn't change their names, if only because doing so would override the labels on all previous issues.
Hi @e05bf027
Thank you for raising the issue. Unfortunately, I'm afraid that without a reproducible example I won't be able to help. I've tried to create a parsnip C5 model on a dummy dataset, and it worked without any issue, both with DALEXtra and with the custom predict function you provided:
library(parsnip)
library(tidymodels)
library(rules)
library(DALEXtra)
data <- iris
data$Species <- as.factor(ifelse(data$Species == "setosa", "yes", "no"))
model <- C5_rules(
  trees = 1,
  min_n = 1
) |>
  set_engine("C5.0") |>
  set_mode("classification") |>
  fit(Species ~ ., data = data)
explain_c50 <-
  explain_tidymodels(model = model,
                     data = data,
                     y = as.numeric(as.factor(data$Species)) - 1)
shap <- DALEX::shap_aggregated(explainer = explain_c50, new_observations = data[1:10, ])
custom_predict <- function(object, newdata) {
  pred <-
    predict(object, newdata, type = 'prob') %>%
    pull(.pred_yes)
  return(pred)
}
DALEX_explainerTest <- DALEX::explain(model = model,
                                      data = data,
                                      predict_function = custom_predict,
                                      y = as.numeric(as.factor(data$Species)) - 1,
                                      label = "c50-train")
DALEX::shap_aggregated(explainer = DALEX_explainerTest, new_observations = data[1:10, ])
Can you please provide more details about your issue, ideally a reproducible example?
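Following the request for a reproducible example, here is a minimal base-R sketch of one common way to share data in an issue: `dput()` prints R code that recreates an object, including factor levels, so a few rows of the data (or simulated data with the same column types, as in the iris example above) can be pasted directly. The data frame here is a hypothetical stand-in for `final_data_test`; its column names are assumptions.

```r
# Hypothetical stand-in for the real data; replace with final_data_test
final_data_test <- data.frame(
  age     = c(54, 71, 63, 48, 80),
  lactate = c(1.2, 4.8, 2.1, 0.9, 3.3),
  outcome = factor(c("no", "yes", "yes", "no", "yes"))
)

# Keep only a few rows so the example is small enough to paste into an issue
small <- head(final_data_test, 3)

# Print R code that recreates `small`, including the factor levels of `outcome`
dput(small)
```

Pasting the printed `structure(list(...))` output into the issue, together with the model specification, would let others rerun `shap_aggregated()` on the same data.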