mlx
Add GPU implementation of QR factorization [wip]
Proposed changes
Add a GPU implementation of QR factorization using the blocked Householder reflection algorithm, see:
- Andrew Kerr, Dan Campbell, Mark Richards, QR Decomposition on GPUs
- Jan Priessnitz, GPU acceleration of matrix factorization
Here is the reference code in NumPy for the algorithm.
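For readers without the reference code at hand, the following is a minimal NumPy sketch of blocked Householder QR. It is an illustration written under stated assumptions, not the PR's reference code. Each column of a panel is reduced with a reflector `H = I - tau v v^T`; the panel's reflectors are accumulated into the compact form `I - W Y^T`, so the trailing-matrix update becomes two matrix multiplies, which is the part that maps well onto a GPU. The function names `householder_vector` and `blocked_householder_qr` and the `block_size` parameter are hypothetical.

```python
import numpy as np


def householder_vector(x):
    """Return (v, tau) with v[0] == 1 so that (I - tau v v^T) x = [beta, 0, ..., 0]."""
    v = np.array(x, dtype=float)
    alpha = v[0]
    norm_x = np.linalg.norm(v)
    if norm_x == 0.0:
        # Nothing to eliminate: tau = 0 makes the reflector the identity.
        v[0] = 1.0
        return v, 0.0
    beta = -np.copysign(norm_x, alpha)  # sign choice avoids cancellation
    tau = (beta - alpha) / beta
    v[1:] /= alpha - beta
    v[0] = 1.0
    return v, tau


def blocked_householder_qr(A, block_size=4):
    """Full QR (Q is m x m, R is m x n) using panels of `block_size` columns."""
    A = np.array(A, dtype=float)  # work on a copy; it is overwritten with R
    m, n = A.shape
    k = min(m, n)
    Q = np.eye(m)
    for j0 in range(0, k, block_size):
        j1 = min(j0 + block_size, k)
        nb = j1 - j0
        Y = np.zeros((m - j0, nb))  # Householder vectors of this panel
        W = np.zeros((m - j0, nb))  # chosen so that H_1 ... H_nb = I - W Y^T
        for i in range(nb):
            col = j0 + i
            v, tau = householder_vector(A[col:, col])
            # Panel factorization: apply H to the remaining panel columns only.
            A[col:, col:j1] -= tau * np.outer(v, v @ A[col:, col:j1])
            y = np.zeros(m - j0)
            y[i:] = v
            # (I - W Y^T)(I - tau y y^T) = I - [W, z][Y, y]^T
            # with z = tau * (y - W (Y^T y)).
            W[:, i] = tau * (y - W[:, :i] @ (Y[:, :i].T @ y))
            Y[:, i] = y
        # Trailing update with the block reflector (I - W Y^T)^T = I - Y W^T.
        A[j0:, j1:] -= Y @ (W.T @ A[j0:, j1:])
        # Accumulate Q <- Q * blockdiag(I, I - W Y^T).
        Q[:, j0:] -= (Q[:, j0:] @ W) @ Y.T
    return Q, np.triu(A)
```

A call such as `Q, R = blocked_householder_qr(A, block_size=2)` should satisfy `Q @ R ≈ A` with `Q` orthogonal and `R` upper triangular; the block size only changes how much work is done as matrix-matrix products, not the result.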
Left todo
- [x] clean up handling of batched inputs: slice the inputs/outputs and only pass the slice to the algorithm. Temporaries need only be sized for a single input matrix.
- [ ] share some constants between the kernel and the driver function.
- [x] consider merging the two kernels to compute W.
- [ ] benchmark and optimize grid/block sizes.
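The first item above (slicing batched inputs so that only one matrix reaches the algorithm) can be sketched in NumPy as follows. This is an assumption-laden illustration rather than the MLX implementation: `batched_qr` and its `qr_single` parameter are hypothetical, and `np.linalg.qr` (reduced mode) stands in for the single-matrix GPU routine, which in MLX would operate on device buffers instead of a Python loop.

```python
import numpy as np


def batched_qr(a, qr_single=np.linalg.qr):
    """Factor each matrix of a (..., M, N) stack independently (reduced QR)."""
    a = np.asarray(a, dtype=float)
    *batch, m, n = a.shape
    k = min(m, n)
    flat = a.reshape(-1, m, n)  # collapse all leading batch dimensions
    q_out = np.empty((flat.shape[0], m, k))
    r_out = np.empty((flat.shape[0], k, n))
    for b in range(flat.shape[0]):
        # Only the current slice is handed to the single-matrix routine, so any
        # temporaries it allocates are sized for one M x N matrix.
        q_out[b], r_out[b] = qr_single(flat[b])
    return q_out.reshape(*batch, m, k), r_out.reshape(*batch, k, n)
```

Keeping the per-matrix routine unaware of batching means its workspace does not grow with the batch size; the trade-off is one launch (here, one call) per matrix.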
Checklist
Put an `x` in the boxes that apply.
- [ ] I have read the CONTRIBUTING document
- [ ] I have run `pre-commit run --all-files` to format my code / installed pre-commit prior to committing changes
- [ ] I have added tests that prove my fix is effective or that my feature works
- [ ] I have updated the necessary documentation (if needed)