
torch MPS backend support

Open emdupre opened this issue 2 years ago • 11 comments

Thanks for your great work making VM accessible!

I was looking into getting started with himalaya, but it seems that you do not currently support pytorch's MPS backend for working on the M1 GPU. Is this correct?

As the MPS backend has been officially released for almost a year, it would be great to take advantage of it to accelerate himalaya models! Is this something that you would be interested in?
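For context, checking whether a local PyTorch build can see the M1 GPU is straightforward (a minimal sketch; `torch.backends.mps` is available in PyTorch 1.12+):

```python
import torch

# On Apple-silicon Macs with a recent PyTorch build, the MPS
# (Metal Performance Shaders) backend exposes the GPU as a device.
use_mps = torch.backends.mps.is_available()
device = torch.device("mps" if use_mps else "cpu")
print(f"MPS available: {use_mps} -> using {device}")
```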

Thanks again, Elizabeth

emdupre avatar Feb 14 '23 23:02 emdupre

Oh cool, thanks for letting us know! I didn't know that pytorch supports the M1. Since I have an M1-based machine, I will try to implement this in himalaya :)

mvdoc avatar Feb 15 '23 00:02 mvdoc

Unfortunately, it looks like MPS support in pytorch is far from complete. Many linear algebra operators are not implemented yet (https://github.com/pytorch/pytorch/issues/77764). For example, calling torch.linalg.eigh returns the following error:

NotImplementedError: The operator 'aten::_linalg_eigh.eigenvalues' is not currently implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment variable PYTORCH_ENABLE_MPS_FALLBACK=1 to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.

I don't think we can support the MPS backend in himalaya until all the linalg operations are fully implemented in pytorch.
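As a stopgap, a caller can catch the missing-op error and retry the decomposition on the CPU. A hedged sketch (not himalaya code; the helper name is made up, and on machines without MPS it simply runs on the CPU):

```python
import torch

def eigh_with_cpu_fallback(a):
    """Run eigh on the tensor's own device, falling back to the CPU
    when the op is missing there (as with aten::_linalg_eigh on MPS)."""
    try:
        return torch.linalg.eigh(a)
    except NotImplementedError:
        eigvals, eigvecs = torch.linalg.eigh(a.cpu())
        return eigvals.to(a.device), eigvecs.to(a.device)

x = torch.randn(4, 4)
a = x @ x.T + torch.eye(4)   # symmetric positive definite
eigvals, eigvecs = eigh_with_cpu_fallback(a)
```

The round trip through the CPU is exactly why this is unattractive for large problems: the matrix has to leave the GPU for every decomposition.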

mvdoc avatar Feb 15 '23 04:02 mvdoc

Thanks for getting to this so quickly, @mvdoc!!

Would it make sense to try the PYTORCH_ENABLE_MPS_FALLBACK=1 flag as suggested in the linked thread? I completely understand if you'd rather wait until full M1 linalg support is available, but it might also be nice to take advantage of what is currently there.
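The flag can also be set from Python rather than in the shell, as long as it is set before torch is first imported (a minimal sketch; exporting it in the shell before launching the script works just as well):

```python
import os

# The fallback flag should be visible before torch initializes, so
# set it at the very top of the script (or export it in the shell).
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import torch  # noqa: E402  (imported after setting the flag)
```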

emdupre avatar Feb 15 '23 18:02 emdupre

Tom can correct me if I'm wrong, but I think that using that flag will defeat the purpose of using GPU support, and it will be equivalent to running himalaya with a CPU backend (backend = set_backend("numpy") or backend = set_backend("torch")). Most of the GPU acceleration we get in himalaya comes from the speed of those linear algebra operators on GPU. But if these operators are not implemented yet in pytorch for the MPS backend, I don't think there will be any noticeable speed-up in himalaya.

mvdoc avatar Feb 15 '23 18:02 mvdoc

Himalaya solvers are using GPUs to speed up two kinds of expensive computations:

  • matrix inversions, through torch.linalg.eigh or torch.linalg.svd
  • matrix multiplications and other operations, through torch.matmul, torch.mean, etc.

I think both improvements are important. So even though MPS does not support torch.linalg.eigh, it could still be useful with PYTORCH_ENABLE_MPS_FALLBACK=1 to speed up matrix multiplications and other operations. In fact, some solvers are not using matrix inversions at all (e.g. KernelRidge(solver="conjugate_gradient"), WeightedKernelRidge(), or MultipleKernelRidgeCV(solver="hyper_gradient")). For these solvers, an MPS backend would likely be beneficial.
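To illustrate why those solvers can benefit: conjugate gradient solves a kernel system using nothing but matrix-vector products and reductions, which are exactly the operations MPS does run natively. A generic textbook sketch (not himalaya's implementation):

```python
import torch

def conjugate_gradient(K, y, n_iter=100, tol=1e-20):
    """Solve K @ w = y for symmetric positive-definite K using only
    matvecs and reductions (no eigh/svd), so every step could run
    natively on the MPS backend."""
    w = torch.zeros_like(y)
    r = y - K @ w          # residual
    p = r.clone()          # search direction
    rs = r @ r             # squared residual norm
    for _ in range(n_iter):
        Kp = K @ p
        alpha = rs / (p @ Kp)
        w = w + alpha * p
        r = r - alpha * Kp
        rs_new = r @ r
        if rs_new < tol:
            break
        p = r + (rs_new / rs) * p
        rs = rs_new
    return w

# Toy ridge-style kernel system (float64 on CPU for clarity).
x = torch.randn(50, 10, dtype=torch.float64)
K = x @ x.T + torch.eye(50, dtype=torch.float64)   # regularized Gram matrix
y = torch.randn(50, dtype=torch.float64)
w = conjugate_gradient(K, y)
print(f"max |K @ w - y| = {(K @ w - y).abs().max():.2e}")
```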

Also, all solvers using torch.linalg.eigh can also work with torch.linalg.svd. Do you know if MPS supports torch.linalg.svd?

TomDLT avatar Feb 15 '23 19:02 TomDLT

Also, all solvers using torch.linalg.eigh can also work with torch.linalg.svd. Do you know if MPS supports torch.linalg.svd?

Locally I'm not able to confirm MPS support using either the stable or nightly build, getting:

UserWarning: The operator 'aten::linalg_svd' is not currently supported on the MPS backend and will fall back to run on the CPU. This may have performance implications.

~~Looking at this older (solved) bug report, however, it seems that they were able to run torch.linalg.svd with the MPS backend.... I'll try to track down the discrepancy.~~

EDIT: It looks like this is set here and is indeed triggered as a warning in the thread I sent. So: no, torch.linalg.svd is not currently supported on the MPS backend!
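A quick way to probe this locally is to run the op on an MPS tensor and watch for either the CPU-fallback warning or a NotImplementedError (a sketch; the helper is hypothetical, and it simply returns False on machines without MPS):

```python
import warnings
import torch

def runs_natively_on_mps(op, *tensors):
    """Hypothetical probe: True only if `op` executes on MPS without
    a CPU-fallback warning or a NotImplementedError."""
    if not torch.backends.mps.is_available():
        return False
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        try:
            op(*(t.to("mps") for t in tensors))
        except NotImplementedError:
            return False
    return not any("fall back" in str(w.message) for w in caught)

print(runs_natively_on_mps(torch.linalg.svd, torch.randn(8, 8)))
```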

emdupre avatar Feb 15 '23 22:02 emdupre

I will experiment a bit more and see what speedup we get even if we use PYTORCH_ENABLE_MPS_FALLBACK=1 :)

mvdoc avatar Feb 15 '23 22:02 mvdoc

Well, I'm happy I was wrong (by a lot). I ran the voxelwise tutorial that fits the banded ridge model (it's actually one banded ridge model + one ridge model). We get a ~3x speed-up by using the MPS backend. (The eigh diagonalizer kept crashing with torch_mps, so I had to switch to the svd diagonalizer; I guess that if we compared against the svd diagonalizer on the CPU as well, the speed-up might be slightly larger.)

There are still things to check (some tests fail with torch_mps due to numerical precision), but I can see the torch_mps backend being implemented soon.

Backend numpy, eigh diag

python 06_plot_banded_ridge_model.py  15898.39s user 1106.48s system 243% cpu 1:56:27.62 total

Backend torch_mps, svd diag

python 06_plot_banded_ridge_model.py  1195.92s user 120.63s system 60% cpu 36:22.02 total

mvdoc avatar Feb 16 '23 14:02 mvdoc

Another test: we don't get a noticeable speedup when running a simple ridge model.

Backend torch, svd diag

python 05_plot_motion_energy_model.py  236.70s user 50.63s system 177% cpu 2:41.66 total

Backend torch_mps, svd diag

python 05_plot_motion_energy_model.py  89.28s user 29.34s system 79% cpu 2:28.93 total

mvdoc avatar Feb 16 '23 14:02 mvdoc

We get a ~3x speed-up by using the MPS backend.

Nice! Thanks for working on this.

TomDLT avatar Feb 16 '23 17:02 TomDLT

Yes, thank you @mvdoc for working on this and @TomDLT for your insight!

If there's anything I can provide to help here, please let me know. Happy to help develop or review!

emdupre avatar Feb 16 '23 19:02 emdupre