
Mac M1 GPU support via PyTorch and MPS

Open • ajrnz opened this issue 1 year ago • 1 comment

Given that PyTorch has supported GPUs (via Metal/MPS) on Mac M1 machines since PyTorch 1.12, shouldn't it be possible to enable this in the PyTorch engine?

How much work would this be, now that the JNI bindings are built for ARM (aarch64)?

The speed-up is significant for Mac M1 users.

References

https://pytorch.org/blog/introducing-accelerated-pytorch-training-on-mac/

ajrnz, Sep 15 '22 09:09

We already use the official torch 1.12.1 macOS arm64 wheel to build the JNI, so it has the capability to support the GPU, but we haven't implemented the MPS device yet. We welcome any contribution that adds this device. You can try to get it working by adding the MPS device in the code.

lanking520, Sep 22 '22 04:09
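As a rough illustration of what "adding the MPS device in the code" might involve on the Java side, here is a minimal sketch of device selection. It assumes a hypothetical `DeviceSelector.select` helper (not part of DJL) and a hypothetical `mpsAvailable` flag that a real engine would have to obtain from the native PyTorch library through JNI. It only chooses a device name. It does not load models or call PyTorch.

```java
import java.util.Locale;

/**
 * Hypothetical sketch: decide which device name a PyTorch-backed engine
 * could request. DeviceSelector and select() are illustrative names only,
 * not part of the DJL API.
 */
public final class DeviceSelector {

    private DeviceSelector() {}

    /**
     * Returns "mps" on Apple Silicon macOS when the native backend reports
     * MPS support, otherwise falls back to "cpu".
     *
     * @param osName       e.g. the value of System.getProperty("os.name")
     * @param osArch       e.g. the value of System.getProperty("os.arch")
     * @param mpsAvailable assumption: in a real engine this would come from
     *                     a JNI call into libtorch
     */
    public static String select(String osName, String osArch, boolean mpsAvailable) {
        boolean isMac = osName.toLowerCase(Locale.ROOT).startsWith("mac");
        // A native arm64 JVM on Apple Silicon typically reports "aarch64"
        boolean isArm64 = "aarch64".equals(osArch) || "arm64".equals(osArch);
        if (isMac && isArm64 && mpsAvailable) {
            return "mps";
        }
        return "cpu";
    }

    public static void main(String[] args) {
        // Hypothetical: replace `false` with a real native availability check
        String device = select(
                System.getProperty("os.name"),
                System.getProperty("os.arch"),
                false);
        System.out.println("Selected device: " + device);
    }
}
```

Falling back to `"cpu"` keeps behavior unchanged on Intel Macs and other platforms. The native side would still need to map the `"mps"` name to PyTorch's MPS device type when creating tensors and moving models.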