
Fix pytest-xdist compatibility in test_torch_ops

[Open] M-Quadra opened this issue 6 months ago · 0 comments

This PR replaces direct torch operator references in pytest.mark.parametrize with getattr lookups in test_torch_ops, so the tests can run in parallel under pytest-xdist.


Running the full unit test suite in a single process takes too long. To speed it up, I would like to run the tests in parallel with pytest-xdist. However, some test cases in test_torch_ops are incompatible with parallel collection.

When running any of the following:

pytest -n 4 --disable-warnings -x coremltools/converters/mil/frontend/torch/test/test_torch_ops.py::TestVarStd
pytest -n 4 --disable-warnings -x coremltools/converters/mil/frontend/torch/test/test_torch_ops.py::TestTorchTensor::test_torch_rank0_tensor
pytest -n 4 --disable-warnings -x coremltools/converters/mil/frontend/torch/test/test_torch_ops.py::TestSTFT

The following error occurs:

==================================== ERRORS ====================================
_____________________________ ERROR collecting gw1 _____________________________
Different tests were collected between gw0 and gw1. The difference is:
--- gw0

+++ gw1

Solution

Replace the direct torch operator references with getattr lookups.

Before:

torch_op = torch.abs

After:

torch_op = getattr(torch, "abs")
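
One way to apply the getattr approach end to end (a sketch assuming the parametrize list can carry operator names instead of operator objects; the test and operator names are hypothetical):

import pytest
import torch

# Parametrize over stable string names, then resolve the operator
# inside the test body. String parameters yield identical IDs
# (e.g. test_elementwise_op[abs]) on every xdist worker.
@pytest.mark.parametrize("op_name", ["abs", "acos", "cosh"])
def test_elementwise_op(op_name):
    torch_op = getattr(torch, op_name)  # e.g. torch.abs
    x = torch.rand(2, 3)
    torch_op(x)

Because every worker derives the same ID from the string, collection matches across gw0, gw1, and so on.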

This change allows the tests to run successfully with pytest-xdist:

pytest -n 4 --disable-warnings -x coremltools/converters/mil/frontend/torch/test/test_torch_ops.py

M-Quadra · Jun 12 '25, 15:06