Kilosort
Improve memory management in clustering_qr.kmeans_plusplus
This change avoids creating unnecessary tensors in clustering_qr.kmeans_plusplus, or deletes them immediately after use. It helps with the OOM errors (#746) that occur at https://github.com/MouseLand/Kilosort/blob/b2f5ded41aaa7e13ed44bf43ecf488770ef54753/kilosort/clustering_qr.py#L202
Xg can sometimes be quite big (5 GB in the case where I get the OOM error). In both of these lines, a copy of Xg was created unnecessarily on the GPU:
https://github.com/MouseLand/Kilosort/blob/b2f5ded41aaa7e13ed44bf43ecf488770ef54753/kilosort/clustering_qr.py#L166-L168 & https://github.com/MouseLand/Kilosort/blob/b2f5ded41aaa7e13ed44bf43ecf488770ef54753/kilosort/clustering_qr.py#L202
The fix for line 202 does not affect speed. The fix for line 167 might affect speed, but not noticeably in my tests; for this reason I didn't extend the clear_cache argument to the kmeans_plusplus function.
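
For readers unfamiliar with how such hidden copies arise, here is a minimal sketch. It is not the diff in this PR. It assumes the expressions involved have the shape of a squared-norm reduction and a scaled matrix product, and all function names are hypothetical. It shows two generic ways to avoid a full-size temporary of a large tensor in PyTorch: reducing in chunks, and applying a scalar factor to the small operand instead of the large one.

```python
import torch

def sq_norms_with_temp(X):
    # X ** 2 materializes a temporary as large as X before the row-wise sum.
    return (X ** 2).sum(1)

def sq_norms_chunked(X, chunk=65536):
    # Same result, but only a chunk-sized temporary exists at any time.
    # This trades a Python loop for lower peak memory.
    out = torch.empty(X.shape[0], dtype=X.dtype, device=X.device)
    for start in range(0, X.shape[0], chunk):
        out[start:start + chunk] = (X[start:start + chunk] ** 2).sum(1)
    return out

def scores_with_temp(X, Xc):
    # `*` and `@` have equal precedence and associate left to right, so this
    # evaluates (2 * X) @ Xc.T and builds a scaled copy of X first.
    return 2 * X @ Xc.T - (Xc ** 2).sum(1)

def scores_no_temp(X, Xc):
    # Scaling the small matrix Xc instead leaves X untouched and uncopied.
    return X @ (2 * Xc).T - (Xc ** 2).sum(1)

if __name__ == "__main__":
    torch.manual_seed(0)
    X = torch.randn(1000, 8)   # stand-in for a large spikes-by-features tensor
    Xc = X[:5]                 # stand-in for a few candidate centroids
    print(torch.allclose(sq_norms_with_temp(X), sq_norms_chunked(X, chunk=300)))
    print(torch.allclose(scores_with_temp(X, Xc), scores_no_temp(X, Xc), atol=1e-5))
```

The chunked reduction is the variant that could cost some speed, because of the loop; moving the scalar factor is free, since it only changes which operand is scaled.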
Tested on PyTorch 2.1.2 and 2.4.1.