
GpuIndexIVFScalarQuantizer with quantizers that require shared memory on the GPU doesn't work for k >= 1024.

Open · gabuzi opened this issue 7 months ago · 3 comments

Summary

GpuIndexIVFScalarQuantizer with scalar quantizers that require shared memory on the GPU doesn't seem to work for k >= 1024 in Faiss 1.7.4. See the small reproduction script at the bottom.

I need such a high k in order to compensate for quantization errors via refinement on the CPU.

My investigation points towards exhausting the shared memory that is addressable on the GPU without the opt-in (see https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#fn33).

By default, GPUs can only address 48 KiB of shared memory per block without a special opt-in (see the link above). Devices with compute capability 7.0 or higher can potentially address more, but this requires dynamic allocation and the opt-in.
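For reference, these two limits can be queried directly via the CUDA runtime API; a small standalone snippet (not Faiss code) to illustrate:

#include <cstdio>
#include <cuda_runtime.h>

int main() {
    int dev = 0;
    int defaultSmem = 0;  // addressable without opt-in (typically 48 KiB)
    int optinSmem = 0;    // maximum available after the opt-in on CC >= 7.0
    cudaDeviceGetAttribute(&defaultSmem, cudaDevAttrMaxSharedMemoryPerBlock, dev);
    cudaDeviceGetAttribute(&optinSmem, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
    printf("default: %d bytes, opt-in max: %d bytes\n", defaultSmem, optinSmem);
    return 0;
}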

For k = 1024, Faiss uses 128 threads (faiss/gpu/impl/scan/IVFInterleaved1024.cu). Since the move to 64-bit indexing, this configuration already fully exhausts the 48 KiB of shared memory, as shown in the following two lines (kNumWarps = 4 for 128 threads, and NumWarpQ is 1024): https://github.com/facebookresearch/faiss/blob/0fc8456e1db0846759bd8d76a03ebe497c511a9f/faiss/gpu/impl/IVFInterleaved.cuh#L83-L84

This leaves no shared memory to be allocated dynamically at kernel launch, resulting in the CUDA error.

For k = 2048, the problem is identical: the kernel is launched with half the number of threads but double the NumWarpQ, again filling the 48 KiB of statically allocatable shared memory.
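For a quick sanity check, here is the back-of-the-envelope arithmetic behind the two cases above as a small standalone program (a sketch that assumes one float distance and one 64-bit index per warp-queue slot, as in the linked lines):

#include <cstddef>
#include <cstdint>
#include <cstdio>

int main() {
    // k = 1024: 128 threads -> kNumWarps = 4, NumWarpQ = 1024
    // k = 2048:  64 threads -> kNumWarps = 2, NumWarpQ = 2048
    const int configs[2][2] = {{4, 1024}, {2, 2048}};
    for (const auto& c : configs) {
        int kNumWarps = c[0];
        int numWarpQ = c[1];
        size_t distBytes = size_t(kNumWarps) * numWarpQ * sizeof(float);   // per-warp distance queues
        size_t idxBytes = size_t(kNumWarps) * numWarpQ * sizeof(int64_t);  // per-warp index queues, 64-bit since the indexing change
        printf("kNumWarps=%d NumWarpQ=%d -> %zu bytes static smem\n",
               kNumWarps, numWarpQ, distBytes + idxBytes);                 // 49152 bytes = 48 KiB in both cases
    }
    return 0;
}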

Quick workarounds that I have tested successfully (results equal to CPU indexes; no performance tests; both seem suboptimal / hacky):

  1. Reduce the number of threads for k >= 1024, so that less shared memory is allocated statically, leaving some for dynamic allocation.
  2. Or call the opt-in (cudaFuncSetAttribute()) before the kernel launch; see below. This requires a GPU with compute capability >= 7.0.

Code for option 2: add to IVFInterleaved.cuh, in the IVFINT_RUN macro, for both useResidual cases. Edited to add: this is based on 1.7.4. I have since noticed that this area has seen some recent changes, but apparently only a refactor to reorganize code files, so the solution should still hold.

+        int smem_dyn_bytes = codec.getSmemSize(dim);                         \
+        if (useResidual) {                                                   \
+            if (smem_dyn_bytes > 0)                                          \
+                cudaFuncSetAttribute(                                        \
+                        ivfInterleavedScan<                                  \
+                                CODEC_TYPE,                                  \
+                                METRIC_TYPE,                                 \
+                                THREADS,                                     \
+                                NUM_WARP_Q,                                  \
+                                NUM_THREAD_Q,                                \
+                                true>,                                       \
+                        cudaFuncAttributeMaxDynamicSharedMemorySize,         \
+                        smem_dyn_bytes);                                     \
             ivfInterleavedScan<                                              \
                    CODEC_TYPE,                                              \
                    METRIC_TYPE,                                             \
                    THREADS,                                                 \
                    NUM_WARP_Q,                                              \
                    NUM_THREAD_Q,                                            \
                    true>                                                    \
                    <<<grid, THREADS, codec.getSmemSize(dim), stream>>>(     \

Related: the check in https://github.com/facebookresearch/faiss/blob/7dd06dd188643523d1a0faab511ee8fc0c5b0e3b/faiss/gpu/GpuIndexIVFScalarQuantizer.cu#L92-L110 doesn't seem to take into account that the maxK selection also requires shared memory.
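For what it's worth, here is a hedged, self-contained sketch of the kind of budget check I have in mind; the helper and its parameters are illustrative only and not the existing Faiss code, which as far as I can tell only compares the codec's requirement against the device limit:

#include <cstddef>
#include <cstdint>
#include <stdexcept>

// codecSmemBytes: dynamic shared memory the scalar quantizer codec needs
// numWarps / numWarpQ: the k-select configuration, e.g. 4 / 1024 for k = 1024
// maxSmemPerBlock: e.g. 48 * 1024 without the opt-in
void checkSmemBudget(size_t codecSmemBytes,
                     int numWarps,
                     int numWarpQ,
                     size_t maxSmemPerBlock) {
    // Static per-block k-select queues: one float distance and one 64-bit
    // index per (warp, queue slot), as in IVFInterleaved.cuh.
    size_t kSelectBytes =
            size_t(numWarps) * numWarpQ * (sizeof(float) + sizeof(int64_t));
    if (codecSmemBytes + kSelectBytes > maxSmemPerBlock) {
        throw std::runtime_error(
                "scalar quantizer codec plus k-selection exceeds the "
                "available shared memory per block");
    }
}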

Platform

x86-64 (Intel), NVIDIA RTX 4090

OS: Ubuntu 20.04 LTS

Faiss version: 1.7.4

Installed from: anaconda, in a clean docker image from nvidia/cuda:11.4.3-runtime-ubi8 (but also built it myself). Install command: conda create -y -n faiss-issue -c pytorch -c nvidia python=3.10 faiss-gpu=1.7.4 mkl=2021 blas=1.0=mkl

Faiss compilation options:

Running on:

  • [ ] CPU
  • [x] GPU

Interface:

  • [x] C++
  • [x] Python

Reproduction instructions

import faiss
import numpy as np
res = faiss.StandardGpuResources()
d = 10
nlist = 10
idx = faiss.GpuIndexIVFScalarQuantizer(res, d, nlist, faiss.ScalarQuantizer.QT_8bit, faiss.METRIC_L2, False)
numvecs = 40000
vecs = np.random.rand(numvecs, d)
idx.train(vecs)
idx.add(vecs)

idx.search(vecs, 512)  # works
print("512 works")
idx.search(vecs, 1024)  # fails
print("I never get to here :(")

Output:

512 works
Faiss assertion 'err__ == cudaSuccess' failed in void faiss::gpu::ivfInterleavedScanImpl_1024_(faiss::gpu::Tensor<float, 2, true>&, faiss::gpu::Tensor<long int, 2, true>&, faiss::gpu::DeviceVector<void*>&, faiss::gpu::DeviceVector<void*>&, faiss::gpu::IndicesOptions, faiss::gpu::DeviceVector<long int>&, int, faiss::MetricType, bool, faiss::gpu::Tensor<float, 3, true>&, faiss::gpu::GpuScalarQuantizer*, faiss::gpu::Tensor<float, 2, true>&, faiss::gpu::Tensor<long int, 2, true>&, faiss::gpu::GpuResources*) at /home/circleci/miniconda/conda-bld/faiss-pkg_1681998300314/work/faiss/gpu/impl/scan/IVFInterleaved1024.cu:13; details: CUDA error 1 invalid argument
Aborted (core dumped)

Edited to remove a leftover from an earlier modification and a stray comment.

gabuzi avatar Jan 16 '24 15:01 gabuzi

Thanks for the precise report. I don't know if we can do something about it @wickedfoo ?

mdouze avatar Jan 29 '24 11:01 mdouze

Would there be an alternative approach that is slower but more parsimonious in memory?

mdouze avatar Jan 29 '24 11:01 mdouze

Thanks for getting back to me and sorry for the delay.

I may have been a bit hasty in saying that we need such a high k. In reality we were only looking into it, and we have since abandoned the high-k approach due to the increased temporary memory requirements during search. So this is no longer an issue for us, but the problem described in the initial post still stands.

Our main rationale for using this kind of quantization was to save GPU memory. Naturally, the temporary memory on the GPU grows with k (pre-refinement), which we had naively ignored. The memory saved by storing the index data compressed can thus easily be wiped out by the temporary allocations needed to hold the distances and indices of the intermediate results on the GPU before refinement. The remedy is then to reduce the query batch size. This has been exacerbated a bit by the recent move to 64-bit indexing on the GPU.
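To give a rough sense of the scale, here is a back-of-the-envelope sketch that only counts the final distance/label result buffers (Faiss's internal temporaries come on top of this):

#include <cstddef>
#include <cstdint>
#include <cstdio>

int main() {
    size_t nq = 40000;  // number of queries, as in the repro above
    size_t k = 1024;
    // One float32 distance and one 64-bit label per (query, k) pair.
    size_t bytes = nq * k * (sizeof(float) + sizeof(int64_t));
    printf("%.1f MiB just for the result buffers\n",
           bytes / (1024.0 * 1024.0));  // ~469 MiB
    return 0;
}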

gabuzi avatar Feb 13 '24 12:02 gabuzi