
Add GPU implementation of QR factorization [wip]

nicolov opened this pull request 1 year ago • 6 comments

Proposed changes

Add a GPU implementation of QR factorization using the blocked Householder reflection algorithm, see:

  • Andrew Kerr, Dan Campbell, Mark Richards, QR Decomposition on GPUs
  • Jan Priessnitz, GPU acceleration of matrix factorization

Here is the reference code in numpy for the algorithm.
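
For orientation, here is a minimal numpy sketch of the blocked Householder scheme with the compact WY representation Q = I - W Y^T (a rough illustration, not the reference code linked above; house, blocked_qr, and the block size b are names made up for this sketch):

import numpy as np

def house(x):
    # Householder vector v and scalar beta such that
    # (I - beta * outer(v, v)) @ x = [alpha, 0, ..., 0].
    v = x.astype(np.float64)
    s = 1.0 if x[0] >= 0 else -1.0
    v[0] += s * np.linalg.norm(x)  # sign choice avoids cancellation
    vv = v @ v
    beta = 0.0 if vv == 0.0 else 2.0 / vv
    return v, beta

def blocked_qr(A, b=32):
    # Each panel's reflectors H_1 ... H_b are accumulated into the compact
    # WY form I - W @ Y.T, so the trailing-matrix update is two GEMMs
    # instead of b rank-1 updates (this is what maps well to the GPU).
    m, n = A.shape
    R = A.astype(np.float64)
    Q = np.eye(m)
    for k in range(0, min(m, n), b):
        kb = min(b, min(m, n) - k)
        W = np.zeros((m - k, kb))
        Y = np.zeros((m - k, kb))
        for j in range(kb):
            v, beta = house(R[k + j:, k + j])
            # Apply the new reflector to the remainder of the panel.
            R[k + j:, k + j:k + kb] -= beta * np.outer(v, v @ R[k + j:, k + j:k + kb])
            vp = np.zeros(m - k)
            vp[j:] = v
            Y[:, j] = vp
            # Recurrence: w_j = beta_j * (I - W Y^T) v_j.
            W[:, j] = beta * (vp - W[:, :j] @ (Y[:, :j].T @ vp))
        # Block update of the trailing matrix: (I - W Y^T)^T C = C - Y (W^T C).
        R[k:, k + kb:] -= Y @ (W.T @ R[k:, k + kb:])
        Q[:, k:] -= (Q[:, k:] @ W) @ Y.T
    return Q, R

A = np.random.randn(6, 5)
Q, R = blocked_qr(A, b=2)
assert np.allclose(Q @ R, A) and np.allclose(Q.T @ Q, np.eye(6))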

Left to do

  • [x] clean up handling of batched inputs: slice the inputs/outputs and only pass the slice to the algorithm. Temporaries need only be sized for a single input matrix.
  • [ ] share some constants between the kernel and the driver function.
  • [x] consider merging the two kernels to compute W.
  • [ ] benchmark and optimize grid/block sizes.

Checklist

Put an x in the boxes that apply.

  • [ ] I have read the CONTRIBUTING document
  • [ ] I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • [ ] I have added tests that prove my fix is effective or that my feature works
  • [ ] I have updated the necessary documentation (if needed)

nicolov avatar Apr 09 '24 14:04 nicolov

@awni I tried to apply your comments and pushed https://github.com/ml-explore/mlx/pull/975/commits/501b889de904adf3b60f4b19d16b3ee46bdc5db2 to avoid creating a new command buffer for each kernel, but I get:

-[AGXG13XFamilyCommandBuffer tryCoalescingPreviousComputeCommandEncoderWithConfig:nextEncoderClass:]:1015: failed assertion `A command encoder is already encoding to this command buffer'

nicolov avatar Apr 15 '24 11:04 nicolov

Did you manually make a command encoder from the command buffer? MLX manages an active command encoder, so you should not make one directly. Instead, call device.get_command_encoder() to get the active encoder.

awni avatar Apr 15 '24 13:04 awni

> Instead, call device.get_command_encoder() to get the active encoder.

I also tried doing that in https://github.com/ml-explore/mlx/pull/975/commits/b979ccf3e4051b99320e5d69abe138359c9f0660, which just produces the wrong result.

nicolov avatar Apr 15 '24 13:04 nicolov

I also tried tracing, and Xcode complains about redundant bindings. Should I somehow refactor how I bind buffers to the encoder?

(screenshot: Xcode trace showing the redundant-binding warnings, 2024-04-15)

nicolov avatar Apr 15 '24 14:04 nicolov

I fixed the code (needed to introduce one more kernel to ensure the atomics were synchronized properly across different threadgroups). It's a bit slow, so I'll try to improve it now:

  device     n  time_ms
0    cpu  2000    99.39
1    gpu  2000   283.36
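
For reproducing the numbers above, a benchmark along these lines should work (a rough sketch; time_qr is a made-up helper, and it assumes mx.linalg.qr dispatches to this PR's kernel when run on the GPU stream):

import time
import mlx.core as mx

def time_qr(device, n=2000, reps=5):
    a = mx.random.normal((n, n))
    mx.eval(a)  # materialize the input before timing
    tic = time.perf_counter()
    for _ in range(reps):
        q, r = mx.linalg.qr(a, stream=device)
        mx.eval(q, r)  # force evaluation; mlx is lazy
    return (time.perf_counter() - tic) / reps * 1e3

for name, dev in [('cpu', mx.cpu), ('gpu', mx.gpu)]:
    print(name, f'{time_qr(dev):.2f} ms')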

nicolov avatar Apr 17 '24 14:04 nicolov

@nicolov are you planning to come back to this?

awni avatar Apr 25 '24 03:04 awni

I compared with PyTorch, which can use either magma or cusolver. I used a cloud machine with an A100 and a desktop with a 1050 (which should be in the ballpark of my M1 Max):

      magma  cusolver   cpu
1050    8.8      11.7  14.1
A100   18.2       6.3    21

The magma backend (see here) ends up calling magma_sgeqrf2_gpu(m, n, dA, ldda, tau, info) (here), which is a hybrid CPU/GPU algorithm that uses queues to overlap CPU compute (factorization of small panel submatrices) with GPU compute (transformation of the remainder of the matrix).
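
Schematically, the hybrid loop looks like this (plain numpy, just to show the split; hybrid_qr_schematic and the block size b are made-up names, and the comments mark which half magma runs on the host vs. the device):

import numpy as np

def hybrid_qr_schematic(A, b=64):
    m, n = A.shape
    R = A.astype(np.float64)
    for k in range(0, min(m, n), b):
        kb = min(b, min(m, n) - k)
        # CPU side in magma: factor the tall, skinny panel
        # (small problem, latency-bound).
        Qp, Rp = np.linalg.qr(R[k:, k:k + kb], mode='complete')
        R[k:, k:k + kb] = Rp
        # GPU side in magma: apply the panel's reflectors to the trailing
        # matrix (large GEMMs, throughput-bound). The queues let this
        # overlap with the CPU factorization of the next panel.
        R[k:, k + kb:] = Qp.T @ R[k:, k + kb:]
    return R

R = hybrid_qr_schematic(np.random.randn(500, 500))
assert np.allclose(np.triu(R), R, atol=1e-8)  # R is upper triangular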

I'm not sure if it makes sense to implement such a hybrid in mlx, as it would require a lot of synchronization points. @awni what do you think?

Benchmark code:

import torch

a = torch.randn((1000, 1000))
a_cuda = a.to('cuda')

# CPU baseline
%timeit torch.linalg.qr(a, 'complete')

# GPU, magma backend
torch.backends.cuda.preferred_linalg_library('magma')
%timeit torch.linalg.qr(a_cuda, 'complete')

# GPU, cusolver backend
torch.backends.cuda.preferred_linalg_library('cusolver')
%timeit torch.linalg.qr(a_cuda, 'complete')

# Note: CUDA ops run asynchronously; for exact numbers you'd bracket the
# timed call with torch.cuda.synchronize().

nicolov avatar Jun 24 '24 20:06 nicolov

> I'm not sure if it makes sense to implement such a hybrid in mlx, as it would require a lot of synchronization points. @awni what do you think?

Indeed that sounds tricky to make fast.

awni avatar Jun 25 '24 14:06 awni

I'll close this PR and re-open it if I find an algorithm that's fast enough.

nicolov avatar Jul 04 '24 07:07 nicolov