Add GPU implementation of QR factorization [wip]
Proposed changes
Add a GPU implementation of QR factorization using the blocked Householder reflection algorithm, see:
- Andrew Kerr, Dan Campbell, Mark Richards, *QR Decomposition on GPUs*
- Jan Priessnitz, *GPU acceleration of matrix factorization*
Here is the reference code in NumPy for the algorithm.
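The sketch below follows the standard blocked Householder / WY formulation from the references above; the panel width `r`, the helper names, and the explicit accumulation of `Q` are illustrative choices rather than the exact reference listing.

```python
import numpy as np

def house(x):
    """Householder vector v and scalar beta such that
    (I - beta * outer(v, v)) @ x is a multiple of e1."""
    v = np.asarray(x, dtype=float).copy()
    normx = np.linalg.norm(v)
    if normx == 0.0:
        return v, 0.0
    alpha = -np.copysign(normx, v[0])
    v[0] -= alpha
    return v, 2.0 / np.dot(v, v)

def blocked_qr(A, r=32):
    """Blocked Householder QR: each panel's reflectors H_1..H_b are
    accumulated as I - W @ V.T (WY representation). Returns Q, R with A = Q @ R."""
    A = np.asarray(A, dtype=float).copy()
    m, n = A.shape
    Q = np.eye(m)
    for k in range(0, n, r):
        b = min(r, n - k)
        V = np.zeros((m - k, b))
        betas = np.zeros(b)
        # Panel factorization: unblocked Householder QR of A[k:, k:k+b].
        for j in range(b):
            v, beta = house(A[k + j:, k + j])
            A[k + j:, k + j:k + b] -= beta * np.outer(v, v @ A[k + j:, k + j:k + b])
            V[j:, j] = v
            betas[j] = beta
        # Build W so that H_1 ... H_b = I - W @ V.T.
        W = np.zeros((m - k, b))
        W[:, 0] = betas[0] * V[:, 0]
        for j in range(1, b):
            W[:, j] = betas[j] * (V[:, j] - W[:, :j] @ (V[:, :j].T @ V[:, j]))
        # Trailing-matrix update: A[k:, k+b:] <- (I - V @ W.T) @ A[k:, k+b:].
        if k + b < n:
            A[k:, k + b:] -= V @ (W.T @ A[k:, k + b:])
        # Accumulate Q <- Q @ (I - W @ V.T) on the active columns.
        Q[:, k:] -= (Q[:, k:] @ W) @ V.T
    return Q, np.triu(A)

# Sanity check: reconstruction and orthogonality.
A = np.random.randn(300, 200)
Q, R = blocked_qr(A)
assert np.allclose(Q @ R, A) and np.allclose(Q.T @ Q, np.eye(Q.shape[0]))
```

The trailing-matrix update and the accumulation of `Q` are the GEMM-heavy steps that map well onto GPU kernels; the panel factorization is the small, mostly sequential part.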
Left todo
- [x] clean up handling of batched inputs: slice the inputs/outputs and only pass the slice to the algorithm. Temporaries need only be sized for a single input matrix.
- [ ] share some constants between the kernel and the driver function.
- [x] consider merging the two kernels to compute W.
- [ ] benchmark and optimize grid/block sizes.
Checklist
Put an x in the boxes that apply.
- [ ] I have read the CONTRIBUTING document
- [ ] I have run `pre-commit run --all-files` to format my code / installed pre-commit prior to committing changes
- [ ] I have added tests that prove my fix is effective or that my feature works
- [ ] I have updated the necessary documentation (if needed)
@awni I tried to apply your comments and pushed https://github.com/ml-explore/mlx/pull/975/commits/501b889de904adf3b60f4b19d16b3ee46bdc5db2 to avoid creating a new command buffer for each kernel, but I get:
```
-[AGXG13XFamilyCommandBuffer tryCoalescingPreviousComputeCommandEncoderWithConfig:nextEncoderClass:]:1015: failed assertion `A command encoder is already encoding to this command buffer'
```
Did you manually make a command encoder from the command buffer? MLX manages an active command encoder so you should not make it directly. Rather, call `device.get_command_encoder()` to get the active encoder.
> Rather, call `device.get_command_encoder()` to get the active encoder.
I also tried doing that in https://github.com/ml-explore/mlx/pull/975/commits/b979ccf3e4051b99320e5d69abe138359c9f0660 which just produces the wrong result.
I also tried tracing, and Xcode complains about redundant bindings. Should I somehow refactor how I bind buffers to the encoder?
I fixed the code (needed to introduce one more kernel to ensure the atomics were synchronized properly across different threadgroups). It's a bit slow, so I'll try to improve it now:
| device | n | time_ms |
|---|---|---|
| cpu | 2000 | 99.39 |
| gpu | 2000 | 283.36 |
@nicolov are you planning to come back to this?
I compared with PyTorch, which can use either MAGMA or cuSOLVER. I used a cloud machine with an A100 and a desktop with a 1050 (which should be in the ballpark of my M1 Max):
| | magma | cusolver | cpu |
|---|---|---|---|
| 1050 | 8.8 | 11.7 | 14.1 |
| A100 | 18.2 | 6.3 | 21 |
The MAGMA backend (see here) ends up calling `magma_sgeqrf2_gpu(m, n, dA, ldda, tau, info)` (here), which is a hybrid CPU/GPU algorithm using queues to overlap CPU compute (factorization of small submatrices) and GPU compute (transformation of the remainder of the matrix).
I'm not sure if it makes sense to implement such a hybrid in mlx, as it would require a lot of synchronization points. @awni what do you think?
Benchmark code:
```python
import torch

a = torch.randn((1000, 1000))
a_cuda = a.to('cuda')

# CPU baseline
%timeit torch.linalg.qr(a, 'complete')

# GPU, MAGMA backend
torch.backends.cuda.preferred_linalg_library('magma')
%timeit torch.linalg.qr(a_cuda, 'complete')

# GPU, cuSOLVER backend
torch.backends.cuda.preferred_linalg_library('cusolver')
%timeit torch.linalg.qr(a_cuda, 'complete')
```
> I'm not sure if it makes sense to implement such a hybrid in mlx, as it would require a lot of synchronization points. @awni what do you think?
Indeed that sounds tricky to make fast.
I'll close this PR and re-open it if I find an algorithm that's fast enough.