[FR] Improve performance of Bool embeddings
domluna posted a really nice gist on the generativeAI Slack using Bool embeddings (held in Int8) + StaticArrays (here).
It seems to provide huge performance benefits compared to my fairly trivial Bool retrieval implementation in RAGTools (here).
Playing around with RAG and binary vectors (https://huggingface.co/blog/embedding-quantization); the idea of 64 bytes is from this post: https://www.mixedbread.ai/blog/binary-mrl. This is a brute-force parallel implementation of search that assumes the data is stored as binary in byte elements (Int8, UInt8), so 512 bits is 64 Int8 elements, or a 64-element static vector using StaticArrays. For a 100M-vector dataset I get < 1s comparison time on my M1 MacBook Air. I'm wondering if there's anything I can do to make it faster.
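To make the layout concrete, here is an illustrative snippet (not code from the gist): a 512-bit embedding is 64 bytes, and the Hamming distance between two such embeddings is the popcount of their XOR.

using StaticArrays

# Two 512-bit embeddings stored as 64 UInt8 bytes each (8 bits per byte)
a = rand(UInt8, 64)
b = rand(UInt8, 64)

# ...or as 64-element static vectors
sa, sb = SVector{64}(a), SVector{64}(b)

# Hamming distance = number of differing bits = popcount of the XOR, summed over bytes
hamming(x, y) = sum(count_ones.(x .⊻ y))
hamming(a, b), hamming(sa, sb)   # identical results, in the range 0..512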
It would be excellent to:
- [ ] Add an extension for StaticArrays.jl
- [ ] Build a kernel for similarity on top of Bool embeddings in Int8 (see the gist; a rough sketch follows below the checklists)
- [ ] Add the corresponding type <: AbstractSimilarityFinder
- [ ] Add relevant tests (+ make sure that the dimension checks still work, since size(emb, 1) would now be 64-times smaller)
- [ ] Overload the EmbeddingEltype trait (to provide a downstream signal of the embedding type/logic)
- [ ] Update any serialize/deserialize functions necessary to be able to store on disk (AIHelpMe uses HDF5, so just test it... eg, BitMatrix wasn't working with HDF5)
Ideally, also integrate into AIHelpMe so everyone downstream can benefit (these TODOs would be transferred to the other repo)
- [ ] Add SA.jl as a direct dep of AIHelpMe
- [ ] Change all Bool knowledge packs to this new implementation
- [ ] Update the artifacts correspondingly
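To make the kernel/finder items more concrete, here is a rough sketch of how a new finder could plug into RAGTools. Everything below except AbstractSimilarityFinder and the EmbeddingEltype trait (both named in the checklist) is a hypothetical placeholder, not the actual RAGTools API, and the real trait/kernel signatures may differ.

# Already defined in RAGTools; repeated here only so the sketch runs standalone.
abstract type AbstractSimilarityFinder end

# Hypothetical new finder type dispatching a Hamming-distance kernel over byte-packed embeddings.
struct BinaryHammingSimilarity <: AbstractSimilarityFinder end

# Hypothetical trait overload signalling how embeddings for this finder are stored.
EmbeddingEltype(::Type{BinaryHammingSimilarity}) = UInt8

# Hypothetical kernel: emb is a (dims ÷ 8) × n_docs UInt8 matrix, query is an already-packed query.
function find_closest_sketch(::BinaryHammingSimilarity,
        emb::AbstractMatrix{UInt8}, query::AbstractVector{UInt8}; top_k::Int = 100)
    distances = map(col -> sum(count_ones.(col .⊻ query)), eachcol(emb))
    positions = partialsortperm(distances, 1:min(top_k, length(distances)))
    return positions, distances[positions]
end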
For reference, I did some benchmarking a few weeks ago and Bool embeddings performed really well (especially assuming that we would use a reranker/cross-encoder downstream).
With top_k=20, 1024 dims in Bool still retains 97% recall and a very competitive MRR.
nice!
StaticArrays makes it faster, but it's not massive: ~2x at most (on 1M vectors), so it's technically not explicitly required:
# no StaticArrays
julia> @b $k_closest_parallel(X1, q1, 10)
21.741 ms (52 allocs: 5.250 KiB)
# q2 is a StaticArray
julia> @b $k_closest_parallel(X1, q2, 10)
12.545 ms (52 allocs: 5.500 KiB)
# q2 is a StaticArray, X2 is a list of static arrays
julia> @b $k_closest_parallel(X2, q2, 10)
9.180 ms (52 allocs: 5.500 KiB)
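The benchmark inputs above aren't shown; the setup below is my guess at how they were built (1M embeddings of 64 bytes each, plus the StaticArrays variants; @b presumably comes from Chairmarks.jl):

using StaticArrays

X1 = [rand(Int8, 64) for _ in 1:1_000_000]  # plain Vector of Int8 vectors
q1 = rand(Int8, 64)                         # plain query

q2 = SVector{64,Int8}(q1)                   # static query
X2 = [SVector{64,Int8}(x) for x in X1]      # vector of static arrays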
The cool thing about binary embeddings is that you can keep everything in memory and you don't need an enormously powerful computer: for 1 billion rows you would need 64 GB instead of 1 TB, which greatly decreases costs. Furthermore, you can potentially use this as part of a reranking pipeline where you keep a higher-dimensional embedding version on disk and then seek the relevant rows based on the binary embedding similarity.
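Spelling out the packed-side arithmetic (the unpacked baseline depends on which representation you compare against):

rows          = 1_000_000_000               # 1 billion embeddings
bytes_per_row = 64                          # 512 bits packed into bytes
rows * bytes_per_row / 1e9                  # = 64.0, i.e. ~64 GB for the whole packed index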
# Hamming distance between two byte vectors: sum of the per-byte distances.
function hamming_distance(x1::AbstractArray{T}, x2::AbstractArray{T})::Int where {T<:INT}
    s = 0
    for i in eachindex(x1, x2)
        s += hamming_distance(x1[i], x2[i])
    end
    s
end
Changing the sum calculation to the above now produces these timings (adding @simd or @inbounds macros seems to have no effect):
julia> @b k_closest_parallel(X1, q1, 10)
4.710 ms (52 allocs: 5.250 KiB)
julia> @b k_closest_parallel(X1, q2, 10)
4.080 ms (52 allocs: 5.500 KiB)
julia> @b k_closest_parallel(X2, q2, 10)
4.070 ms (52 allocs: 5.500 KiB)
So using StaticArrays doesn't add much.
julia> versioninfo()
Julia Version 1.11.0-beta1
Commit 08e1fc0abb9 (2024-04-10 08:40 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: macOS (arm64-apple-darwin22.4.0)
CPU: 8 × Apple M1
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, apple-m1)
Threads: 4 default, 0 interactive, 2 GC (on 4 virtual cores)
Environment:
JULIA_STACKTRACE_MINIMAL = true
DYLD_LIBRARY_PATH = /Users/lunaticd/.wasmedge/lib
JULIA_EDITOR = nvim
using StaticArrays
using Base.Threads

const INT = Union{Int8,UInt8}   # byte element types used for packed binary embeddings

# Hamming distance between two bytes: count the differing bits
# (equivalent to count_ones(x1 ⊻ x2)).
function hamming_distance(x1::T, x2::T)::Int where {T<:INT}
    c = 0
    for i = 0:7
        c += ((x1 >> i) & 1) ⊻ ((x2 >> i) & 1)
    end
    return Int(c)
end

# Hamming distance between two byte vectors: sum of the per-byte distances.
function hamming_distance(x1::AbstractArray{T}, x2::AbstractArray{T})::Int where {T<:INT}
    s = 0
    @inbounds @simd for i in eachindex(x1, x2)
        s += hamming_distance(x1[i], x2[i])
    end
    s
end

# Split the database into one chunk per thread, search each chunk in parallel,
# then merge the partial results and keep the overall k closest.
function k_closest_parallel(
    db::AbstractArray{V},
    query::AbstractVector{T},
    k::Int,
) where {T<:INT,V<:AbstractVector{T}}
    n = length(db)
    t = nthreads()
    task_ranges = [(i:min(i + n ÷ t - 1, n)) for i = 1:n÷t:n]
    tasks = map(task_ranges) do r
        Threads.@spawn k_closest(view(db, r), query, k)
    end
    results = fetch.(tasks)
    sort!(vcat(results...), by = x -> x[1])[1:k]
end

# Serial scan keeping a fixed-size buffer of the k smallest (distance => index) pairs.
function k_closest(
    db::AbstractVector{V},
    query::AbstractVector{T},
    k::Int,
) where {T<:INT,V<:AbstractVector{T}}
    results = Vector{Pair{Int,Int}}(undef, k)
    m = typemax(Int)
    fill!(results, (m => -1))
    @inbounds for i in eachindex(db)
        d = hamming_distance(db[i], query)
        for j = 1:k
            if d < results[j][1]
                old = results[j]
                results[j] = d => i
                # shift the remaining entries down one slot, dropping the old k-th entry
                for l = j+1:k
                    old, results[l] = results[l], old
                end
                break
            end
        end
    end
    return results
end
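The X and q used in the benchmarks below aren't shown; a minimal setup consistent with the description (1M vectors, 64 Int8 elements each) could look like this (illustrative only, not the original benchmark script):

X = [rand(Int8, 64) for _ in 1:1_000_000]   # 1M packed 512-bit embeddings
q = rand(Int8, 64)                          # packed query embedding

top10 = k_closest_parallel(X, q, 10)        # Vector of (distance => index) pairs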
The core distance operation takes ~20 ns on a static array, but when everything is combined we actually average even less than that per comparison.
On a 1M-element vector where each element is a 64-element vector of Int8:
julia> @b k_closest_parallel(X, q, 1)
2.816 ms (50 allocs: 3.547 KiB)
julia> @b k_closest_parallel(X, q, 5)
3.142 ms (52 allocs: 4.516 KiB)
julia> @b k_closest_parallel(X, q, 10)
3.560 ms (52 allocs: 5.500 KiB)
julia> @b k_closest_parallel(X, q, 50)
7.449 ms (54 allocs: 13.938 KiB)
julia> @b k_closest_parallel(X, q, 100)
11.626 ms (54 allocs: 23.781 KiB)
Naively looping and doing the distance op 1M times without any additional work would be ~20 ms, but parallelized over 4 cores we come in under that.
I'm actually slightly hesitant to enforce threading under the hood, because:
- if we have composite indices (eg, 5 indices with different types of data), it might be more natural to thread over them
- we don't actually need the benefits (and the associated future complications); 99% of the time is spent on GenAI calls, so perhaps we can enjoy the luxury of letting LLVM handle the vectorization and keep it single-threaded at the Julia level?
Btw, I'm a bit lost in the references above -- what's your recommendation w.r.t. StaticArrays? Do you think they are a valuable addition, or should we keep it simple?
It doesn't seem that StaticArrays is absolutely necessary. The performance will be better, but not drastically so: for 1M rows, ~17.5% faster. Parallelism isn't necessary either, but this is a situation where the problem is easily parallelizable (map-reduce pattern), so we get very close to perfect scaling with cores, i.e., 4 cores makes it ~4x faster, 16 cores ~16x faster.
Linking a great writeup by @domluna here: https://github.com/domluna/tinyrag
In particular, this function looks the same as the inner function here.
It needs some benchmarking and potentially a mini PR if someone is interested!
EDIT: I should have said that the PR could:
- Improve the current binary emb. saved as Bool
- Add an implementation for binary emb. saved as UInt8 (rough packing sketch below)
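For the UInt8 item, a rough sketch of what the Bool-to-packed-UInt8 conversion could look like (the function name is a placeholder, not an existing RAGTools/PromptingTools function):

# Placeholder sketch: pack a dims × n_docs Bool embedding matrix column-wise into
# a (dims ÷ 8) × n_docs UInt8 matrix, 8 dimensions per byte.
function pack_bool_embeddings(emb::AbstractMatrix{Bool})
    dims, n = size(emb)
    @assert dims % 8 == 0 "embedding dimension must be divisible by 8"
    packed = zeros(UInt8, dims ÷ 8, n)
    for j in 1:n, i in 1:dims
        packed[(i - 1) ÷ 8 + 1, j] |= UInt8(emb[i, j]) << ((i - 1) % 8)
    end
    return packed
end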
Closed by https://github.com/svilupp/PromptingTools.jl/pull/152