djl
Mac M1 GPU support via PyTorch and MPS
Given that PyTorch has supported GPUs on Mac M1 machines (via Metal/MPS) since version 1.12, shouldn't it be possible to enable this in the PyTorch engine?
How much work would this be, now that the JNI bindings are built for ARM (aarch64)?
The speedup is significant for Mac M1 users.
References
https://pytorch.org/blog/introducing-accelerated-pytorch-training-on-mac/
We are already using the official torch 1.12.1 macOS arm64 wheel to build the JNI, so it could be extended to support the GPU, but we haven't implemented the MPS device yet. We welcome any contribution that adds this device. You can try to get it working by adding the MPS device to the code.
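As a rough illustration of what "adding the MPS device" involves, here is a minimal, self-contained sketch. It maps an engine-agnostic device type to the device string PyTorch accepts (`"cpu"`, `"cuda:<id>"`, `"mps"`). The names `MpsDeviceSketch`, `DeviceType`, `toTorchDevice`, and `isAppleSilicon` are hypothetical and are not DJL's actual API. A real change would presumably go through DJL's `Device` class and the PyTorch JNI layer. The sketch assumes that an Apple silicon JVM reports `os.arch` as `aarch64`, and it does not check whether the loaded libtorch build has MPS enabled.

```java
import java.util.Locale;

// Hypothetical sketch, not DJL's implementation: shows how an engine-agnostic
// device type could be translated into a PyTorch device string, including MPS.
public final class MpsDeviceSketch {

    // Hypothetical device types; MPS is the proposed addition.
    enum DeviceType { CPU, GPU, MPS }

    // Parse a user-supplied name such as "mps" or "GPU" (case-insensitive).
    static DeviceType parse(String name) {
        return DeviceType.valueOf(name.trim().toUpperCase(Locale.ROOT));
    }

    // Translate to the string PyTorch uses for torch.device(...).
    static String toTorchDevice(DeviceType type, int id) {
        switch (type) {
            case CPU:
                return "cpu";
            case GPU:
                return "cuda:" + id;
            case MPS:
                // Apple silicon exposes a single Metal device, so the id is ignored.
                return "mps";
            default:
                throw new IllegalArgumentException("Unknown device type: " + type);
        }
    }

    // MPS only makes sense on a macOS arm64 (Apple silicon) JVM.
    // Assumption: this does NOT verify that the PyTorch build has MPS enabled.
    static boolean isAppleSilicon() {
        String os = System.getProperty("os.name", "").toLowerCase(Locale.ROOT);
        String arch = System.getProperty("os.arch", "");
        return os.contains("mac") && arch.equals("aarch64");
    }

    public static void main(String[] args) {
        DeviceType type = isAppleSilicon() ? DeviceType.MPS : DeviceType.CPU;
        System.out.println("Selected PyTorch device: " + toTorchDevice(type, 0));
    }
}
```

Falling back to the CPU when the host is not Apple silicon keeps the default behavior unchanged for everyone else. A fuller implementation would also need a runtime availability check on the native side, along the lines of PyTorch's `torch.backends.mps.is_available()`.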