coremltools
Fix pytest-xdist compatibility in test_torch_ops
This PR replaces direct references to torch operators in the pytest.mark.parametrize arguments of test_torch_ops with getattr lookups, so that the tests can run in parallel with pytest-xdist.
Running all unit tests serially on a single machine takes too long. To speed this up, I would like to run them in parallel with pytest-xdist. However, some test cases in test_torch_ops fail under pytest-xdist.
When running any of the following commands:
pytest -n 4 --disable-warnings -x coremltools/converters/mil/frontend/torch/test/test_torch_ops.py::TestVarStd
pytest -n 4 --disable-warnings -x coremltools/converters/mil/frontend/torch/test/test_torch_ops.py::TestTorchTensor::test_torch_rank0_tensor
pytest -n 4 --disable-warnings -x coremltools/converters/mil/frontend/torch/test/test_torch_ops.py::TestSTFT
Test collection fails with the following error:
==================================== ERRORS ====================================
_____________________________ ERROR collecting gw1 _____________________________
Different tests were collected between gw0 and gw1. The difference is:
--- gw0
+++ gw1
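pytest-xdist requires every worker to collect exactly the same list of test IDs. The error above comes from the check that compares each worker's collection against the others. The sketch below roughly mimics that check with difflib, producing a message in the same shape. The function name and the sample test IDs are hypothetical and only illustrate the idea. Reordering the same IDs also counts as a mismatch here, because the load scheduler distributes tests by their index in the collected list.

```python
import difflib


def collection_diff(from_ids, to_ids, from_worker="gw0", to_worker="gw1"):
    """Return an xdist-style message if two workers collected different tests, else None.

    Hypothetical helper: it only imitates the shape of pytest-xdist's error.
    """
    if from_ids == to_ids:
        return None
    # unified_diff emits "--- gw0" / "+++ gw1" headers like the error above.
    diff = difflib.unified_diff(
        from_ids, to_ids, fromfile=from_worker, tofile=to_worker, lineterm=""
    )
    return (
        f"Different tests were collected between {from_worker} and {to_worker}. "
        "The difference is:\n" + "\n".join(diff)
    )


# Hypothetical IDs: the same tests collected in a different order still mismatch.
gw0 = ["test_op[var]", "test_op[std]"]
gw1 = ["test_op[std]", "test_op[var]"]
print(collection_diff(gw0, gw1))
```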
Solution
Replace each direct torch operator reference with a getattr lookup on the torch module.
Before:
torch_op = torch.abs
After:
torch_op = getattr(torch, "abs")
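getattr(torch, "abs") returns the very same object as torch.abs. The lookup by itself therefore does not change what pytest collects. What xdist needs is for the parametrized values, and so the generated test IDs and their order, to be identical in every worker. The sketch below shows one way to get there. It assumes the operator is parametrized by its name as a plain string and resolved with getattr inside the test. It uses the standard-library math module as a stand-in for torch so it runs without PyTorch installed. The test name and operator list are hypothetical.

```python
import math

import pytest

# Hypothetical operator names; in test_torch_ops these would be torch function
# names and the lookup would be getattr(torch, op_name) instead of math.
OP_NAMES = ["sqrt", "exp", "fabs"]


@pytest.mark.parametrize("op_name", OP_NAMES)
def test_unary_op(op_name):
    # Parametrizing over plain strings gives each case a readable test ID
    # (e.g. "test_unary_op[sqrt]") that is the same in every xdist worker.
    op = getattr(math, op_name)
    assert callable(op)
    assert op(4.0) > 0
```

Because the IDs are built from fixed strings in a fixed list order, every worker process collects the same tests in the same order.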
With this change, the full test_torch_ops module runs successfully under pytest-xdist:
pytest -n 4 --disable-warnings -x coremltools/converters/mil/frontend/torch/test/test_torch_ops.py