
Add GPU implementation of QR factorization [wip]

Open nicolov opened this issue 2 months ago • 6 comments

Proposed changes

Add a GPU implementation of QR factorization using the blocked Householder reflection algorithm, see:

  • Andrew Kerr, Dan Campbell, Mark Richards, QR Decomposition on GPUs
  • Jan Priessnitz, GPU acceleration of matrix factorization

Here is the reference code in numpy for the algorithm.
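
Independently of the linked reference code, the general idea of blocked Householder QR can be sketched in NumPy roughly as follows. This is an illustrative sketch only, not the code in this PR and not the GPU kernels: the names `householder` and `blocked_qr` and the `block` parameter are made up for the example, and it assumes the common WY representation, where the product of the reflectors in a panel is written as I - W Yᵀ so that the trailing update becomes matrix-matrix products.

```python
import numpy as np


def householder(x):
    """Return (v, beta) such that (I - beta * v v^T) x is a multiple of e_1."""
    v = x.astype(float)  # astype returns a copy, so x is left untouched
    norm = np.linalg.norm(x)
    if norm == 0.0:
        return v, 0.0  # nothing to eliminate; the reflector is the identity
    sign = 1.0 if x[0] >= 0 else -1.0
    v[0] += sign * norm  # sign chosen to avoid cancellation
    return v, 2.0 / (v @ v)


def blocked_qr(A, block=4):
    """Blocked Householder QR (hypothetical helper): returns Q (m x m), R (m x n)."""
    R = np.array(A, dtype=float)
    m, n = R.shape
    Q = np.eye(m)
    steps = min(m, n)
    for k in range(0, steps, block):
        b = min(block, steps - k)
        Y = np.zeros((m - k, b))  # Householder vectors of this panel
        W = np.zeros((m - k, b))  # chosen so that H_1 ... H_b = I - W Y^T
        for j in range(b):
            col = k + j
            v, beta = householder(R[col:, col])
            # Apply the reflector to the remaining columns of the panel only.
            panel = R[col:, col:k + b]
            panel -= beta * np.outer(v, v @ panel)
            Y[j:, j] = v
            # Extend W by one column: w_j = beta * (I - W Y^T) v_j.
            z = beta * Y[:, j]
            if j > 0:
                z -= W[:, :j] @ (Y[:, :j].T @ z)
            W[:, j] = z
        # Trailing update with matrix products: R <- (I - W Y^T)^T R.
        trailing = R[k:, k + b:]
        trailing -= Y @ (W.T @ trailing)
        # Accumulate Q <- Q (I - W Y^T).
        Q[:, k:] -= (Q[:, k:] @ W) @ Y.T
    return Q, np.triu(R)


rng = np.random.default_rng(0)
A = rng.standard_normal((9, 6))
Q, R = blocked_qr(A, block=4)
print(np.allclose(Q @ R, A), np.allclose(Q.T @ Q, np.eye(9)))
```

The per-column work inside a panel is memory-bound, while the trailing update and the Q accumulation are matrix-matrix products, which is the part a GPU implementation would be expected to benefit from most.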

Remaining TODOs

  • [x] clean up handling of batched inputs: slice the inputs/outputs and only pass the slice to the algorithm. Temporaries need only be sized for a single input matrix.
  • [ ] share some constants between the kernel and the driver function.
  • [x] consider merging the two kernels to compute W.
  • [ ] benchmark and optimize grid/block sizes.
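
The batched-input item above can be pictured with a small host-side sketch. It is an assumption about the general shape of the approach, not the PR's driver code: `qr_one` and `batched_qr` are hypothetical names, and `np.linalg.qr` stands in for the single-matrix algorithm. The point is only that each call sees one (m, n) slice and that the scratch buffer is allocated once, sized for a single matrix.

```python
import numpy as np


def qr_one(a, q_out, r_out, work):
    """Stand-in for the single-matrix routine (hypothetical); writes into views."""
    np.copyto(work, a)  # scratch copy, reused across the batch
    q, r = np.linalg.qr(work, mode="complete")
    q_out[...] = q
    r_out[...] = r


def batched_qr(a):
    """Factor every (m, n) matrix in an array of shape (..., m, n)."""
    a = np.asarray(a, dtype=float)
    *batch, m, n = a.shape
    flat = a.reshape(-1, m, n)  # collapse all leading batch dimensions
    q = np.empty((flat.shape[0], m, m))
    r = np.empty((flat.shape[0], m, n))
    work = np.empty((m, n))  # one temporary, sized for a single matrix
    for i in range(flat.shape[0]):
        # Only the i-th input/output slices are passed to the algorithm.
        qr_one(flat[i], q[i], r[i], work)
    return q.reshape(*batch, m, m), r.reshape(*batch, m, n)


a = np.random.default_rng(0).standard_normal((2, 3, 5, 4))
q, r = batched_qr(a)
print(q.shape, r.shape, np.allclose(q @ r, a))
```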

Checklist

Put an x in the boxes that apply.

  • [ ] I have read the CONTRIBUTING document
  • [ ] I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • [ ] I have added tests that prove my fix is effective or that my feature works
  • [ ] I have updated the necessary documentation (if needed)

nicolov • Apr 09 '24 14:04