scikit-learn

fix: `mps` device support in `entropy`

Open EdAbati opened this issue 1 year ago • 6 comments

Reference Issues/PRs

As mentioned in #29300, some Array API tests are currently failing on main:

FAILED sklearn/metrics/tests/test_common.py::test_array_api_compliance[accuracy_score-check_array_api_multilabel_classification_metric-cupy-None-None] - ValueError: unrecognized csr_matrix constructor usage
FAILED sklearn/metrics/tests/test_common.py::test_array_api_compliance[accuracy_score-check_array_api_multilabel_classification_metric-cupy.array_api-None-None] - TypeError: bool is only allowed on arrays with 0 dimensions
FAILED sklearn/metrics/tests/test_common.py::test_array_api_compliance[accuracy_score-check_array_api_multilabel_classification_metric-torch-cuda-float64] - ValueError: unrecognized csr_matrix constructor usage
FAILED sklearn/metrics/tests/test_common.py::test_array_api_compliance[accuracy_score-check_array_api_multilabel_classification_metric-torch-cuda-float32] - ValueError: unrecognized csr_matrix constructor usage
FAILED sklearn/metrics/tests/test_common.py::test_array_api_compliance[zero_one_loss-check_array_api_multilabel_classification_metric-cupy-None-None] - ValueError: unrecognized csr_matrix constructor usage
FAILED sklearn/metrics/tests/test_common.py::test_array_api_compliance[zero_one_loss-check_array_api_multilabel_classification_metric-cupy.array_api-None-None] - TypeError: bool is only allowed on arrays with 0 dimensions
FAILED sklearn/metrics/tests/test_common.py::test_array_api_compliance[zero_one_loss-check_array_api_multilabel_classification_metric-torch-cuda-float64] - ValueError: unrecognized csr_matrix constructor usage
FAILED sklearn/metrics/tests/test_common.py::test_array_api_compliance[zero_one_loss-check_array_api_multilabel_classification_metric-torch-cuda-float32] - ValueError: unrecognized csr_matrix constructor usage

and, on the `mps` device:

FAILED sklearn/metrics/cluster/tests/test_supervised.py::test_entropy_array_api[torch-mps-float32] - TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
FAILED sklearn/metrics/tests/test_common.py::test_array_api_compliance[accuracy_score-check_array_api_multilabel_classification_metric-torch-mps-float32] - ValueError: unrecognized csr_matrix constructor usage
FAILED sklearn/metrics/tests/test_common.py::test_array_api_compliance[zero_one_loss-check_array_api_multilabel_classification_metric-torch-mps-float32] - ValueError: unrecognized csr_matrix constructor usage

What does this implement/fix? Explain your changes.

The first commit fixes the `mps` failure in `entropy`.
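The `mps` failure above is a `TypeError` raised when an MPS tensor is cast to float64. The sketch below shows the general idea of picking the floating dtype per device and falling back to float32 on `mps`. It is not the code from this PR. NumPy stands in for the Array API namespace, and `max_float_dtype` is a hypothetical helper. The entropy uses natural logs: H = -Σ p_i log p_i.

```python
import numpy as np


def max_float_dtype(device_type: str):
    # Hypothetical helper: MPS has no float64 support, so use float32 there
    # and the highest precision (float64) everywhere else.
    return np.float32 if device_type == "mps" else np.float64


def entropy(labels, device_type: str = "cpu") -> float:
    """Shannon entropy (natural log) of a label vector, in a device-safe dtype."""
    # Count how many samples carry each distinct label.
    _, counts = np.unique(labels, return_counts=True)
    pi = counts.astype(max_float_dtype(device_type))
    pi_sum = pi.sum()
    # -sum(p_i * log(p_i)), with p_i = pi / pi_sum and log(p_i) = log(pi) - log(pi_sum)
    return float(-np.sum((pi / pi_sum) * (np.log(pi) - np.log(pi_sum))))


print(entropy([0, 0, 1, 1]))               # two balanced classes: log(2)
print(entropy([0, 0, 1, 1], "mps"))        # same value, computed in float32
```

The only device-dependent choice is the dtype. The formula is the same on every device, so float32 on MPS changes precision but not the result beyond rounding.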

I believe the other failures were caused by #29269. If I remember correctly, the Array API does not support sparse matrices, so those metrics should not work in the multilabel case. It seems that the code introduced in that PR only works for NumPy and for torch on CPU, or am I missing something?

Should we revert this change?

cc @ogrisel

EdAbati, Jun 20 '24 22:06

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: 31b0477.

github-actions[bot], Jun 20 '24 22:06

Thanks for the fix. I see 3 possible solutions for the problem with sparse matrices.

  1. Don't allow Array API dispatch for multilabel metrics.
  2. If array_api_dispatch is enabled, convert the sparse matrix to a dense array here: https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/metrics/_classification.py#L217 (if I am not mistaken, the exception is raised because _average receives the sparse matrix created here).
  3. If array_api_dispatch is enabled, do not convert the array to a sparse matrix here: https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/metrics/_classification.py#L132
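Option 3 can be sketched as follows. This is not scikit-learn's implementation: `multilabel_indicator` and `subset_accuracy` are hypothetical names, and the sketch assumes the multilabel path only needs the per-sample count of differing labels (subset accuracy). The CSR conversion is kept for plain NumPy inputs and skipped when Array API dispatch is on. `scipy.sparse.csr_matrix` cannot wrap CuPy arrays or torch tensors on CUDA/MPS, which is presumably the source of the `unrecognized csr_matrix constructor usage` errors above.

```python
import numpy as np
import scipy.sparse as sp


def multilabel_indicator(y, array_api_dispatch: bool):
    """Hypothetical helper illustrating option 3."""
    if array_api_dispatch:
        # Array API inputs (CuPy, torch on CUDA/MPS) cannot be wrapped in a
        # SciPy sparse matrix, so keep the dense array as-is.
        return y
    # Plain NumPy path: keep the sparse CSR representation.
    return sp.csr_matrix(y)


def subset_accuracy(y_true, y_pred, array_api_dispatch: bool = False) -> float:
    """Fraction of samples whose full label set is predicted exactly."""
    yt = multilabel_indicator(y_true, array_api_dispatch)
    yp = multilabel_indicator(y_pred, array_api_dispatch)
    if sp.issparse(yt):
        # Sparse != sparse gives a sparse boolean matrix; .sum(axis=1) returns
        # an (n_samples, 1) np.matrix, flattened to a 1D array here.
        n_diff = np.asarray((yt != yp).sum(axis=1)).ravel()
    else:
        # Dense path: count differing labels per sample directly.
        n_diff = (yt != yp).sum(axis=1)
    # A sample is correct only if none of its labels differ.
    return float(np.mean(n_diff == 0))


y_true = np.array([[1, 0, 1], [0, 1, 0]])
y_pred = np.array([[1, 0, 1], [1, 1, 0]])
print(subset_accuracy(y_true, y_pred))                           # sparse path
print(subset_accuracy(y_true, y_pred, array_api_dispatch=True))  # dense path
```

Both paths return the same score. Skipping the sparse conversion only costs memory on very wide label matrices, and it avoids handing non-NumPy arrays to SciPy.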

I will also try to check each PR more carefully in Colab.

Tialo, Jun 20 '24 22:06

Yes, you are right! I think it would be nice to support multilabel inputs. Since the sparse matrices are created internally and not by the user, I lean towards option 3. But if we decide on option 1, we should raise a meaningful error message.

EdAbati, Jun 21 '24 06:06

I've implemented something along the lines of option 3.

I will work on mypy later today.

Not sure if I should make 2 PRs, since the fixes for entropy and for accuracy are unrelated. 🤔

EdAbati avatar Jun 21 '24 07:06 EdAbati

> Not sure if I should make 2 PRs, since the fixes for entropy and for accuracy are unrelated. 🤔

That would be great, thanks. It would ease the review in case one of the changes is controversial.

ogrisel, Jun 21 '24 07:07

Created the second PR, #29336, for `accuracy_score` and `zero_one_loss` :)

EdAbati, Jun 22 '24 13:06