DirectML support
See #1169 for details
Adds functions to support torch-directml. This library uses DirectX 12 to accelerate PyTorch. It is slower than CUDA, but it works with any DirectX 12-capable device on Windows, including AMD GPUs.
Also includes a working example!
@Teranis This looks good! Can you also add some documentation in the docs/installation.rst file?
Hi, I modified the Cellpose code so that it tries to use DirectML when CUDA and MPS are not available, following the existing device-selection logic. I also included installation instructions (which are now much simpler than before).
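The fallback order described above (CUDA, then MPS, then DirectML, then CPU) can be sketched as follows. This is a minimal illustration, not the code in this PR: `choose_backend` and `probe_backends` are hypothetical names, and the `torch_directml.is_available()` call assumes a torch-directml release that provides it.

```python
import importlib.util


def choose_backend(cuda_available: bool, mps_available: bool, directml_available: bool) -> str:
    """Return the preferred backend: CUDA first, then MPS, then DirectML, else CPU."""
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    if directml_available:
        return "directml"
    return "cpu"


def probe_backends() -> tuple:
    """Query installed libraries; a backend whose package is missing counts as unavailable."""
    cuda = mps = dml = False
    if importlib.util.find_spec("torch") is not None:
        import torch
        cuda = bool(torch.cuda.is_available())
        mps = bool(torch.backends.mps.is_available())
    if importlib.util.find_spec("torch_directml") is not None:
        import torch_directml
        # Assumption: this torch-directml version exposes is_available()
        dml = bool(torch_directml.is_available())
    return cuda, mps, dml


if __name__ == "__main__":
    print(choose_backend(*probe_backends()))
```

Keeping the decision in a pure function separates it from hardware probing, so the ordering can be unit-tested on a machine with no GPU at all.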
Codecov Report
Attention: Patch coverage is 40.00000% with 24 lines in your changes missing coverage. Please review.
Project coverage is 47.08%. Comparing base (682a1de) to head (e08b47f). Report is 46 commits behind head on main.
| Files with missing lines | Patch % | Lines |
|---|---|---|
| cellpose/core.py | 23.80% | 16 Missing :warning: |
| cellpose/models.py | 58.82% | 7 Missing :warning: |
| cellpose/dynamics.py | 50.00% | 1 Missing :warning: |
Additional details and impacted files:
```diff
@@           Coverage Diff            @@
##             main    #1181   +/-   ##
==========================================
- Coverage   47.13%   47.08%   -0.05%
==========================================
  Files          16       16
  Lines        3711     3761      +50
==========================================
+ Hits         1749     1771      +22
- Misses       1962     1990      +28
```
Hi @mrariden,
I have added a unit test for `core.assign_device`. I'm not sure how I would realistically cover the rest of the functionality. My approach would be to run all the existing tests again, but with a different model fixture:
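```python
import pytest
from unittest.mock import patch

# MockCellposeModel is the mock class already defined in the existing test suite


@pytest.fixture()
def cellposemodel_fixture_24layer_directml():
    """Like the existing CellposeModel fixture, but with CUDA and MPS reported
    as unavailable so that device selection falls back to DirectML."""
    gpu = True
    # Hide CUDA and MPS so the DirectML branch of the device logic is exercised
    with patch('torch.cuda.is_available', return_value=False), \
         patch('torch.backends.mps.is_available', return_value=False):
        model = MockCellposeModel(24, gpu=gpu)
        yield model
```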
Then I would compare the model output to the ground truth. However, I'm unsure how to implement these last steps, and why the tests rely strictly on the CLI.
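For the comparison step, one simple option is to check that the predicted masks contain the same number of cells as the ground truth and overlap it closely, rather than requiring bit-identical output, since a DirectML run may differ slightly from CUDA or CPU in floating-point results. The sketch below is hypothetical: `foreground_iou`, `masks_match`, and the 0.9 threshold are illustrative choices, not part of the Cellpose test suite.

```python
import numpy as np


def foreground_iou(masks_pred: np.ndarray, masks_true: np.ndarray) -> float:
    """Intersection over union of the foreground (label > 0) of two label images."""
    pred_fg = masks_pred > 0
    true_fg = masks_true > 0
    union = np.logical_or(pred_fg, true_fg).sum()
    if union == 0:
        return 1.0  # both images empty: treat as a perfect match
    inter = np.logical_and(pred_fg, true_fg).sum()
    return float(inter / union)


def masks_match(masks_pred: np.ndarray, masks_true: np.ndarray, min_iou: float = 0.9) -> bool:
    """Hypothetical acceptance check: same cell count and high foreground overlap.

    Label values themselves are ignored, since two runs may number cells differently.
    """
    n_pred = len(np.unique(masks_pred[masks_pred > 0]))
    n_true = len(np.unique(masks_true[masks_true > 0]))
    return n_pred == n_true and foreground_iou(masks_pred, masks_true) >= min_iou
```

A per-object metric such as Cellpose's `metrics.average_precision` would be a stricter alternative to this foreground-level check.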
Kind regards, Timon Stegmaier
Hi @mrariden, I was wondering whether there is an update, or whether I need to do anything else. Kind regards, Timon Stegmaier