fix: `mps` device support in `entropy`
Reference Issues/PRs
As mentioned in #29300, we have some tests regarding the Array API that are failing in main.
FAILED sklearn/metrics/tests/test_common.py::test_array_api_compliance[accuracy_score-check_array_api_multilabel_classification_metric-cupy-None-None] - ValueError: unrecognized csr_matrix constructor usage
FAILED sklearn/metrics/tests/test_common.py::test_array_api_compliance[accuracy_score-check_array_api_multilabel_classification_metric-cupy.array_api-None-None] - TypeError: bool is only allowed on arrays with 0 dimensions
FAILED sklearn/metrics/tests/test_common.py::test_array_api_compliance[accuracy_score-check_array_api_multilabel_classification_metric-torch-cuda-float64] - ValueError: unrecognized csr_matrix constructor usage
FAILED sklearn/metrics/tests/test_common.py::test_array_api_compliance[accuracy_score-check_array_api_multilabel_classification_metric-torch-cuda-float32] - ValueError: unrecognized csr_matrix constructor usage
FAILED sklearn/metrics/tests/test_common.py::test_array_api_compliance[zero_one_loss-check_array_api_multilabel_classification_metric-cupy-None-None] - ValueError: unrecognized csr_matrix constructor usage
FAILED sklearn/metrics/tests/test_common.py::test_array_api_compliance[zero_one_loss-check_array_api_multilabel_classification_metric-cupy.array_api-None-None] - TypeError: bool is only allowed on arrays with 0 dimensions
FAILED sklearn/metrics/tests/test_common.py::test_array_api_compliance[zero_one_loss-check_array_api_multilabel_classification_metric-torch-cuda-float64] - ValueError: unrecognized csr_matrix constructor usage
FAILED sklearn/metrics/tests/test_common.py::test_array_api_compliance[zero_one_loss-check_array_api_multilabel_classification_metric-torch-cuda-float32] - ValueError: unrecognized csr_matrix constructor usage
and
FAILED sklearn/metrics/cluster/tests/test_supervised.py::test_entropy_array_api[torch-mps-float32] - TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
FAILED sklearn/metrics/tests/test_common.py::test_array_api_compliance[accuracy_score-check_array_api_multilabel_classification_metric-torch-mps-float32] - ValueError: unrecognized csr_matrix constructor usage
FAILED sklearn/metrics/tests/test_common.py::test_array_api_compliance[zero_one_loss-check_array_api_multilabel_classification_metric-torch-mps-float32] - ValueError: unrecognized csr_matrix constructor usage
What does this implement/fix? Explain your changes.
The first commit fixes the issue with mps.
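As a rough illustration of the kind of change involved (not the actual diff), here is a minimal NumPy-only sketch of an entropy computation that picks its floating dtype from the device instead of always casting to float64. The helper `max_float_dtype` and the string `device` argument are hypothetical names used only for this example.

```python
import numpy as np


def max_float_dtype(device):
    # Hypothetical helper: MPS does not support float64, so use float32 there.
    return np.float32 if device == "mps" else np.float64


def entropy_sketch(labels, device="cpu"):
    # Count the occurrences of each distinct label.
    _, counts = np.unique(np.asarray(labels), return_counts=True)
    # Cast to the widest float dtype the (assumed) device supports.
    pi = counts.astype(max_float_dtype(device))
    pi_sum = pi.sum()
    # Shannon entropy in nats, written with log differences.
    return float(-np.sum((pi / pi_sum) * (np.log(pi) - np.log(pi_sum))))
```

With two equally frequent labels, the result should be close to `log(2)` whichever dtype is selected.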
I believe the others were caused by #29269. If I remember correctly, the Array API does not support sparse matrices, so those metrics should not be expected to work with it in the multilabel case. It seems that the code introduced in that PR only works for numpy and torch on CPU, or am I missing something?
Should we revert this change?
cc @ogrisel
Thanks for the fix. I see 3 possible solutions for the problem with sparse matrices.
- Don't allow Array API for multilabel metrics.
- If we use `array_api_dispatch`, convert the sparse matrix to an array here https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/metrics/_classification.py#L217 (If I am not mistaken, the exception is raised because `_average` gets a sparse matrix which is created here)
- If we use `array_api_dispatch`, do not convert the array to a sparse matrix here https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/metrics/_classification.py#L132
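To make the third option concrete, here is a minimal sketch of multilabel subset accuracy computed directly on dense label-indicator arrays, with no sparse matrix involved. It uses plain NumPy and a hypothetical function name, `multilabel_accuracy_dense`; it is not the code in `_classification.py`.

```python
import numpy as np


def multilabel_accuracy_dense(y_true, y_pred):
    # Both inputs are dense (n_samples, n_labels) indicator arrays.
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    # A sample counts as correct only if every one of its labels matches.
    row_match = np.all(y_true == y_pred, axis=1)
    # Fraction of fully matching samples.
    return float(row_match.astype(np.float64).mean())
```

The same elementwise comparison and row-wise reduction could presumably be written against another array namespace, which is what makes this route attractive when the inputs live on a GPU.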
Also, I will try to check each PR in colab more carefully.
Yes, you are right! I think it would be nice to support multilabel. And given that the sparse matrices are created internally and not by the user, I lean towards option 3. But if we decide on option 1, I think we should raise a meaningful error message.
I've implemented something in the direction of option 3.
I will work on mypy later today.
Not sure if I should make 2 PRs, since the fixes for entropy and for accuracy are unrelated. 🤔
That would be great, thanks. It would ease the review in case one of the changes is controversial.
Created the 2nd PR #29336 for accuracy_score and zero_one_loss :)