[BUG] Computing SHAP explanations fails during mlflow.evaluate()
Willingness to contribute
Yes. I would be willing to contribute a fix for this bug with guidance from the MLflow community.
System information
- Databricks Runtime: 10.4 ML LTS
- MLflow version: 1.24.0
- Python version: 3.8.10
Describe the problem
While executing mlflow.evaluate() on Databricks Runtime 10.4 ML, I get an error:
Exception: Additivity check failed in TreeExplainer!
The error comes from the SHAP package; check_additivity=False needs to be specified. Could this be exposed as an option or set as the default, or should the data be downsampled instead?
The related SHAP issue is here: https://github.com/slundberg/shap/issues/941
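For context, this is roughly how the check is disabled when calling SHAP directly; a minimal sketch assuming a fitted scikit-learn tree-based model named model and a feature matrix X:

import shap

# Build a TreeExplainer for the fitted tree-based model and skip the
# additivity consistency check that raises the exception described above.
explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X, check_additivity=False)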
Tracking information
No response
Code to reproduce issue
import mlflow
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import roc_auc_score

# params_rf, X_train, X_test, y_train, y_test, and user_name are defined earlier in the notebook
model = RandomForestClassifier(**params_rf)

# construct an evaluation dataset from the test set
eval_data = X_test
eval_data["target"] = y_test

with mlflow.start_run(run_name=f"untuned_random_forest_{user_name}"):
    model.fit(X_train, y_train)

    # predict_proba returns [prob_negative, prob_positive], so slice the output with [:, 1]
    predictions_test = model.predict_proba(X_test)[:, 1]
    auc_score = roc_auc_score(y_test, predictions_test)

    # Log the parameters used to train the model
    mlflow.log_params(params_rf)

    # Use the area under the ROC curve as a metric
    mlflow.log_metric("auc", auc_score)

    model_info = mlflow.sklearn.log_model(model, "sklearn_rf_model")

    result = mlflow.evaluate(
        model_info.model_uri,
        eval_data,
        targets="target",
        model_type="classifier",
        dataset_name="wine_dataset",
        evaluators=["default"],
    )
Other info / logs
Exception: Additivity check failed in TreeExplainer! Please ensure the data matrix you passed to the explainer is the same shape that the model was trained on. If your data shape is correct then please report this on GitHub. This check failed because for one of the samples the sum of the SHAP values was 0.547230, while the model output was 0.527094. If this difference is acceptable you can set check_additivity=False to disable this check.
Partial log:
/databricks/python/lib/python3.8/site-packages/shap/explainers/_tree.py in shap_values(self, X, y, tree_limit, approximate, check_additivity, from_call)
406 out = self._get_shap_output(phi, flat_output)
407 if check_additivity and self.model.model_output == "raw":
--> 408 self.assert_additivity(out, self.model.predict(X))
409
410 return out
/databricks/python/lib/python3.8/site-packages/shap/explainers/_tree.py in assert_additivity(self, phi, model_output)
537 if type(phi) is list:
538 for i in range(len(phi)):
--> 539 check_sum(self.expected_value[i] + phi[i].sum(-1), model_output[:,i])
540 else:
541 check_sum(self.expected_value + phi.sum(-1), model_output)
/databricks/python/lib/python3.8/site-packages/shap/explainers/_tree.py in check_sum(sum_val, model_output)
533 " was %f, while the model output was %f. If this difference is acceptable" \
534 " you can set check_additivity=False to disable this check." % (sum_val[ind], model_output[ind])
--> 535 raise Exception(err_msg)
536
537 if type(phi) is list:
What component(s) does this bug affect?
- [ ] area/artifacts: Artifact stores and artifact logging
- [ ] area/build: Build and test infrastructure for MLflow
- [ ] area/docs: MLflow documentation pages
- [ ] area/examples: Example code
- [ ] area/model-registry: Model Registry service, APIs, and the fluent client calls for Model Registry
- [ ] area/models: MLmodel format, model serialization/deserialization, flavors
- [ ] area/projects: MLproject format, project running backends
- [ ] area/scoring: MLflow Model server, model deployment tools, Spark UDFs
- [ ] area/server-infra: MLflow Tracking server backend
- [ ] area/tracking: Tracking Service, tracking client APIs, autologging
What interface(s) does this bug affect?
- [ ] area/uiux: Front-end, user experience, plotting, JavaScript, JavaScript dev server
- [ ] area/docker: Docker use across MLflow's components, such as MLflow Projects and MLflow Models
- [ ] area/sqlalchemy: Use of SQLAlchemy in the Tracking Service or Model Registry
- [ ] area/windows: Windows support
What language(s) does this bug affect?
- [ ] language/r: R APIs and clients
- [ ] language/java: Java APIs and clients
- [ ] language/new: Proposals for new client languages
What integration(s) does this bug affect?
- [ ] integrations/azure: Azure and Azure ML integrations
- [ ] integrations/sagemaker: SageMaker integrations
- [ ] integrations/databricks: Databricks integrations
Hi @AnastasiaProkaieva, thank you for raising this! Can you try setting the explainability_nsamples value in the evaluator_config argument of mlflow.evaluate() to a smaller value? If that doesn't work, we can absolutely implement support for disabling the additivity check, though setting this flag to False doesn't appear to be recommended.
Finally, if you aren't interested in SHAP outputs, you can also disable model explanations by passing evaluator_config={"log_model_explainability": False} as an argument to mlflow.evaluate(). Thank you for using MLflow and Databricks!
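For illustration, those two suggestions could look roughly like this (a sketch based on the evaluator_config keys mentioned in this thread; the sample count of 500 is an arbitrary example):

# Option 1: reduce the number of samples used for SHAP explanations
result = mlflow.evaluate(
    model_info.model_uri,
    eval_data,
    targets="target",
    model_type="classifier",
    evaluators=["default"],
    evaluator_config={"explainability_nsamples": 500},
)

# Option 2: disable model explanations entirely
result = mlflow.evaluate(
    model_info.model_uri,
    eval_data,
    targets="target",
    model_type="classifier",
    evaluators=["default"],
    evaluator_config={"log_model_explainability": False},
)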
Hi @dbczumar, I tried setting explainability_nsamples to 12, 50, 100, and 500 (instead of the default 2000) and the error stays the same. It seems that TreeExplainer no longer works without check_additivity=False; I recently hit the same issue with an old SHAP script of mine.
Meanwhile, when I set explainability_nsamples to a value smaller than the number of features, another error is raised:
ValueError: The beeswarm plot does not support plotting explanations with instances that have more than one dimension!
It would also be great to raise a clearer message, or add to the docs, that explainability_nsamples cannot be less than the number of features in the dataset.
{"log_model_explainability": False} works, but I feel the whole goal of evaluate() was also to have the SHAP explainer under the hood (at least that is part of what I've been testing it for).
@AnastasiaProkaieva Thanks for trying this out! I'll go ahead and file a PR to support check_additivity=False. Apologies for the inconvenience!
@dbczumar This is still being observed. Do you know if this will be fixed anytime soon? I could try logging the SHAP-related artifacts separately and run mlflow.evaluate() with {"log_model_explainability": False}. Would that be the right approach as a quick workaround?
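For what it's worth, that workaround could look like the sketch below (hypothetical, not an official recommendation; it reuses model, model_info, and eval_data from the reproduction above and assumes an active MLflow run):

import shap
import matplotlib.pyplot as plt

# Run the evaluation without the built-in SHAP explanations
result = mlflow.evaluate(
    model_info.model_uri,
    eval_data,
    targets="target",
    model_type="classifier",
    evaluators=["default"],
    evaluator_config={"log_model_explainability": False},
)

# Compute SHAP values manually, skipping the additivity check,
# and log the summary plot as a separate run artifact
features = eval_data.drop(columns=["target"])
explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(features, check_additivity=False)

shap.summary_plot(shap_values, features, show=False)
mlflow.log_figure(plt.gcf(), "shap_summary_plot.png")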
You can also try setting explainability_algorithm to kernel. The default SHAP tree explainer might be buggy.
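For example (a sketch; explainability_algorithm is passed through the same evaluator_config dictionary as the other keys discussed above):

result = mlflow.evaluate(
    model_info.model_uri,
    eval_data,
    targets="target",
    model_type="classifier",
    evaluators=["default"],
    # "kernel" worked in this thread; "permutation" and "partition" were also tried
    evaluator_config={"explainability_algorithm": "kernel"},
)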
This worked, @WeichenXu123. I tried using kernel, permutation, and partition.
@khojarohan4 For the SHAP tree explainer, you can file a ticket in the shap repo; I think it is a shap bug. :)
@WeichenXu123 I will file one.