torch MPS backend support
Thanks for your great work making VM accessible!
I was looking into starting with himalaya, but it seems that you do not currently support pytorch's MPS backend for working on the M1 GPU. Is this correct?
As the MPS backend has been officially released for almost a year, it would be great to take advantage of it to accelerate himalaya models! Is this something that you would be interested in?
Thanks again, Elizabeth
oh cool, thanks for letting us know! I didn't know that pytorch supports the M1. Since I have an M1-based machine, I will try to implement this in himalaya :)
Unfortunately, it looks like MPS support in pytorch is far from complete. Many linear algebra operators are not implemented yet (https://github.com/pytorch/pytorch/issues/77764). For example, using `torch.linalg.eigh` returns the following error:

```
NotImplementedError: The operator 'aten::_linalg_eigh.eigenvalues' is not currently implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment variable PYTORCH_ENABLE_MPS_FALLBACK=1 to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.
```
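For context, that fallback variable has to be set before `torch` is imported. A minimal sketch of how one might enable it from Python (the matrix size is just for illustration):

```python
# Minimal sketch: enable the CPU fallback for ops missing on MPS.
# The variable must be set before torch is imported.
import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import torch

x = torch.randn(64, 64, device="mps")
# eigh is not implemented natively on MPS; with the fallback it runs on CPU
eigenvalues, eigenvectors = torch.linalg.eigh(x @ x.T)
```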
I don't think we can support the MPS backend in himalaya until all the linalg operations are fully implemented in pytorch.
Thanks for getting to this so quickly, @mvdoc!!
Would it make sense to try the `PYTORCH_ENABLE_MPS_FALLBACK=1` flag as suggested in the linked thread? I completely understand if you'd rather wait until full M1 linalg support is available, but it might also be nice to take advantage of what is currently available.
Tom can correct me if I'm wrong, but I think that using that flag will defeat the purpose of using GPU support, and it will be equivalent to running himalaya with a CPU backend (`backend = set_backend("numpy")` or `backend = set_backend("torch")`). Most of the GPU acceleration we get in himalaya comes from the speed of those linear algebra operators on GPU. But if these operators are not implemented yet in pytorch for the MPS backend, I don't think there will be any noticeable speed-up in himalaya.
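For readers following the thread, switching backends in himalaya looks like this (a short sketch using himalaya's `set_backend` API; the model and alphas are arbitrary, and `"torch_mps"` is only the hypothetical backend discussed here):

```python
# Sketch: selecting a himalaya backend before fitting a model.
from himalaya.backend import set_backend
from himalaya.kernel_ridge import KernelRidgeCV

# "numpy" and "torch" run on CPU; "torch_cuda" targets CUDA GPUs.
# A "torch_mps" backend (discussed in this thread) would follow the same pattern.
backend = set_backend("torch", on_error="warn")

model = KernelRidgeCV(alphas=[0.1, 1.0, 10.0])
# model.fit(X, y)  # X and y are placeholders for your data
```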
Himalaya solvers use GPUs to speed up two kinds of expensive computations:
- matrix inversions, through `torch.linalg.eigh` or `torch.linalg.svd`
- matrix multiplications and other operations, through `torch.matmul`, `torch.mean`, etc.
I think both improvements are important. So even though MPS does not support `torch.linalg.eigh`, it could still be useful with `PYTORCH_ENABLE_MPS_FALLBACK=1` to speed up matrix multiplications and other operations. In fact, some solvers do not use matrix inversions at all, e.g. `KernelRidge(solver="conjugate_gradient")`, `WeightedKernelRidge()`, or `MultipleKernelRidgeCV(solver="hyper_gradient")`, as in the sketch below. For these solvers, an MPS backend would likely be beneficial.
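As a concrete illustration, instantiating those inversion-free solvers looks like this (class and solver names are from the message above; the `alpha` value is arbitrary):

```python
# Sketch: himalaya solvers that avoid eigh/svd-based matrix inversions.
from himalaya.kernel_ridge import (
    KernelRidge,
    MultipleKernelRidgeCV,
    WeightedKernelRidge,
)

model_cg = KernelRidge(alpha=1.0, solver="conjugate_gradient")
model_wkr = WeightedKernelRidge()
model_hg = MultipleKernelRidgeCV(solver="hyper_gradient")
```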
Also, all solvers using `torch.linalg.eigh` can also work with `torch.linalg.svd`. Do you know if MPS supports `torch.linalg.svd`?
> Also, all solvers using `torch.linalg.eigh` can also work with `torch.linalg.svd`. Do you know if MPS supports `torch.linalg.svd`?
Locally I'm not able to confirm MPS support using either the stable or nightly build, getting:

```
UserWarning: The operator 'aten::linalg_svd' is not currently supported on the MPS backend and will fall back to run on the CPU. This may have performance implications.
```
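For anyone who wants to reproduce this check, a minimal probe might look like the following (a sketch; escalating the warning to an error is my addition, so the CPU fallback doesn't pass silently):

```python
# Sketch: probe whether torch.linalg.svd runs natively on MPS.
import warnings

import torch

if torch.backends.mps.is_available():
    a = torch.randn(8, 8, device="mps")
    with warnings.catch_warnings():
        # The CPU fallback only emits a UserWarning; raise it instead.
        warnings.simplefilter("error")
        try:
            torch.linalg.svd(a)
            print("torch.linalg.svd ran natively on MPS")
        except (NotImplementedError, UserWarning) as err:
            print("No native MPS support:", err)
```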
~~Looking at this older (solved) bug report, however, it seems that they were able to run `torch.linalg.svd` with the MPS backend... I'll try to track down the discrepancy.~~
EDIT: It looks like this is set here and indeed triggered as a warning in the thread I sent. So: no, `torch.linalg.svd` is not currently supported on the MPS backend!
I will experiment a bit more and see what speedup we get even if we use `PYTORCH_ENABLE_MPS_FALLBACK=1` :)
Well, I'm happy I was wrong (by a lot). I ran the voxelwise tutorial that fits the banded ridge model (it's actually one banded ridge model + one ridge model). We get a ~3x speed-up by using the MPS backend. (The `eigh` diagonalizer kept crashing with `torch_mps`, so I had to switch to the `svd` diagonalizer; I guess if we compared against the `svd` diagonalizer on the CPU as well, the speed-up might be slightly larger.)
There are still things to check (some tests fail with the `torch_mps` backend due to numerical precision), but I can see the `torch_mps` backend being implemented soon.
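For reference, the diagonalizer switch looks roughly like this (a sketch; I'm assuming the `diagonalize_method` parameter of the random-search solver, and `"torch_mps"` is the backend being prototyped in this thread):

```python
# Sketch: fit a banded ridge model with the svd diagonalizer instead of eigh.
from himalaya.backend import set_backend
from himalaya.kernel_ridge import MultipleKernelRidgeCV

set_backend("torch_mps")  # hypothetical backend from this thread

model = MultipleKernelRidgeCV(
    solver="random_search",
    solver_params=dict(diagonalize_method="svd"),  # eigh crashed on MPS
)
```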
Backend `numpy`, `eigh` diagonalizer:

```
python 06_plot_banded_ridge_model.py  15898.39s user 1106.48s system 243% cpu 1:56:27.62 total
```

Backend `torch_mps`, `svd` diagonalizer:

```
python 06_plot_banded_ridge_model.py  1195.92s user 120.63s system 60% cpu 36:22.02 total
```
Another test: we don't get a noticeable speedup when running a simple ridge model.
Backend `torch`, `svd` diagonalizer:

```
python 05_plot_motion_energy_model.py  236.70s user 50.63s system 177% cpu 2:41.66 total
```

Backend `torch_mps`, `svd` diagonalizer:

```
python 05_plot_motion_energy_model.py  89.28s user 29.34s system 79% cpu 2:28.93 total
```
> We get a ~3x speed-up by using the MPS backend.
Nice! Thanks for working on this.
Yes, thank you @mvdoc for working on this and @TomDLT for your insight!
If there's anything I can provide to help here, please let me know. Happy to help develop or review!