heat
Support Apple MPS acceleration
Description
RESEARCH FEATURE: This is a first attempt at extending PyTorch's support for Apple MPS acceleration to Heat users.
NOTE: Distributed operations between MPS devices have not been tested! This feature is meant to accelerate non-distributed, NumPy-like operations on Apple MPS.
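Example:
```python
import numpy as np
import heat as ht

# MPS has no float64 support, so create single-precision inputs
a = np.random.randn(10000, 10000).astype(np.float32)
b = np.random.randn(10000, 10000).astype(np.float32)
c = a @ b  # reference result, computed by NumPy on the CPU

# Move the data to the "gpu" device (Apple MPS with this PR) and multiply there
a = ht.array(a, device="gpu")
b = ht.array(b, device="gpu")
d = a @ b
```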
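The description does not spell out how the `"gpu"` device string is resolved on a Mac. The sketch below is hypothetical and is not the code in `heat/core/devices.py`: it assumes a `"gpu"` request prefers CUDA and falls back to Apple MPS, and it takes the availability flags as plain booleans so it runs without PyTorch. In real code those flags would come from `torch.cuda.is_available()` and `torch.backends.mps.is_available()`.

```python
from typing import Optional


def select_gpu_backend(cuda_available: bool, mps_available: bool) -> Optional[str]:
    """Hypothetical: pick the torch device type a "gpu" request maps to.

    Assumed preference order: CUDA first, then Apple MPS, otherwise None.
    """
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return None


def resolve_device(requested: str, cuda_available: bool, mps_available: bool) -> str:
    """Hypothetical: map a Heat-style device string ("cpu"/"gpu") to a torch device type."""
    if requested == "cpu":
        return "cpu"
    if requested == "gpu":
        backend = select_gpu_backend(cuda_available, mps_available)
        if backend is None:
            raise RuntimeError("no GPU backend (CUDA or MPS) is available")
        return backend
    raise ValueError(f"unknown device {requested!r}")


# On an Apple silicon machine: no CUDA, MPS available
print(resolve_device("gpu", cuda_available=False, mps_available=True))  # mps
```

Keeping the availability checks outside the helper makes the mapping easy to test on machines that have neither backend.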
Warnings
- MPS does not support float64; double-precision floats may be cast to single precision.
- Many MPS operations do not support int64 (e.g. cumulative operations such as `cumsum` and `cumprod`). 64-bit integers may be cast to 32-bit integers internally.
- Memory-distributed calculations have not been tested.
- Current status of PyTorch's MPS operation coverage
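The effect of the 32-bit narrowing described in these warnings can be illustrated in plain Python, without MPS. The helpers below are hypothetical (not part of Heat or PyTorch) and only simulate narrowing. Whether a given MPS operation actually wraps, errors, or saturates on int32 overflow is not stated in this PR; two's-complement wraparound is an assumption made here for illustration.

```python
import struct


def to_float32(x: float) -> float:
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def wrap_int32(n: int) -> int:
    """Reinterpret an integer as two's-complement int32 (assumed wraparound on overflow)."""
    n &= 0xFFFFFFFF
    return n - (1 << 32) if n >= (1 << 31) else n


def cumsum_int32(values):
    """Cumulative sum in which every partial result is narrowed to int32."""
    out, total = [], 0
    for v in values:
        total = wrap_int32(total + v)
        out.append(total)
    return out


# float64 -> float32: precision beyond roughly 7 significant digits is lost
print(to_float32(0.1))                # 0.10000000149011612
print(to_float32(1 + 2**-30) == 1.0)  # True

# int64 -> int32: a cumulative sum that exceeds 2**31 - 1 wraps to a negative value
print(cumsum_int32([2**30, 2**30]))   # [1073741824, -2147483648]
```

With exact 64-bit integers the second cumulative sum would be 2147483648, so silently narrowed results can differ from NumPy's.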
Issue/s resolved: #1053
Changes proposed:
Type of change
Memory requirements
Performance
Due Diligence
- [ ] All split configurations tested
- [ ] Multiple dtypes tested in relevant functions
- [ ] Documentation updated (if needed)
- [ ] Title of PR is suitable for corresponding CHANGELOG entry
Does this change modify the behaviour of other functions? If so, which?
yes / no
Thank you for the PR!
Codecov Report
Attention: Patch coverage is 0% with 6 lines in your changes missing coverage. Please review.
Project coverage is 91.76%. Comparing base (9296f05) to head (f3a5ad8). Report is 2 commits behind head on main.
:exclamation: Current head f3a5ad8 differs from pull request most recent head 84e4942
Please upload reports for the commit 84e4942 to get more accurate results.
Files | Patch % | Lines |
---|---|---|
heat/core/devices.py | 0.00% | 6 Missing :warning: |
Additional details and impacted files
@@ Coverage Diff @@
## main #1129 +/- ##
==========================================
+ Coverage 91.74% 91.76% +0.02%
==========================================
Files 80 73 -7
Lines 11683 10524 -1159
==========================================
- Hits 10718 9657 -1061
+ Misses 965 867 -98
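The percentages in the diff follow from the Hits and Misses rows. A minimal check, assuming Codecov's ratio hits / (hits + misses + partials) with no partial lines here (Hits + Misses equals Lines in both columns), and assuming its default of rounding down to two decimals:

```python
import math


def coverage_pct(hits: int, misses: int, partials: int = 0) -> float:
    """Coverage in percent, rounded down to two decimals (assumed Codecov default)."""
    lines = hits + misses + partials
    return math.floor(10000 * hits / lines) / 100


base = coverage_pct(hits=10718, misses=965)  # main
head = coverage_pct(hits=9657, misses=867)   # #1129
print(base, head, round(head - base, 2))     # 91.74 91.76 0.02
```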
Flag | Coverage Δ | |
---|---|---|
unit | 91.76% <0.00%> (+0.02%) | :arrow_up: |
Flags with carried forward coverage won't be shown.
@mrfh92 this is too unstable to release now, especially since I'm the only one who can test it and I'll be away for the next few weeks. Let's leave this for later.
This pull request is stale because it has been open for 60 days with no activity.