Initializing Reach and calling it from separate (spread-out) functions runs too slowly and consumes excessive resources (nearest_neighbor_threshold)
Here is my code:
from difflib import ndiff
from time import perf_counter
from datasets import load_dataset
from datasketch import MinHash, MinHashLSH
import numpy as np
from model2vec import StaticModel
from reach import Reach
from tqdm import tqdm
from wordllama import WordLlama
# Load the model and dataset
model = StaticModel.from_pretrained("jester6136/multilingual-e5-large-m2v")
ds = load_dataset("jester6136/osint")["train"]
texts = ds['value'][:30000]
ids = ds['id'][:30000]
vectors = np.zeros((1, model.dim))
tmp_items = ['empty']
reach = Reach(vectors, tmp_items)
reach.delete(tmp_items)
def encode_texts(reach, model, texts: list[str], ids: list[str]) -> None:
    prompted_texts = [f"query: {text}" for text in texts]
    embedding_matrix = model.encode(prompted_texts, show_progressbar=True)
    reach.insert(ids, embedding_matrix)
def deduplicate(reach, threshold: float, batch_size: int = 1024) -> tuple[list[str], dict[str, list[str]]]:
    """
    Deduplicate embeddings and return the deduplicated IDs and a mapping of kept IDs to the IDs removed as their duplicates.

    :param reach: The Reach instance holding the embeddings to deduplicate.
    :param threshold: The similarity threshold to use for deduplication.
    :param batch_size: The batch size to use for similarity computation.
    :return: A tuple containing the deduplicated IDs and a dictionary mapping each kept ID to the IDs removed as its duplicates.
    """
    embedding_matrix = reach.vectors
    items = reach.sorted_items
    # Use a set for deduplicated IDs and track duplicates
    deduplicated_ids = set(items)  # Start with all IDs as deduplicated
    duplicate_to_original_mapping = {}
    results = reach.nearest_neighbor_threshold(
        embedding_matrix,
        threshold=threshold,
        batch_size=batch_size,
        show_progressbar=True,
    )
    # Process duplicate detection results
    for i, similar_items in enumerate(tqdm(results)):
        original_id = items[i]  # Get the original ID for the current embedding
        if original_id not in deduplicated_ids:
            continue  # Skip if already marked as a duplicate
        # Collect IDs of similar items (excluding the current ID)
        similar_ids = [item[0] for item in similar_items if item[0] != original_id]
        # Group similar items under the original ID
        if similar_ids:
            if original_id not in duplicate_to_original_mapping:
                duplicate_to_original_mapping[original_id] = []  # Initialize list for this original ID
            # Add all similar items to the group and remove from deduplicated set
            for sim_id in similar_ids:
                if sim_id in deduplicated_ids:
                    deduplicated_ids.remove(sim_id)  # Remove from deduplicated set
                    duplicate_to_original_mapping[original_id].append(sim_id)  # Group under original ID
    # Return deduplicated IDs and the mapping of duplicates
    return list(deduplicated_ids), duplicate_to_original_mapping
encode_texts(reach, model, texts, ids)
deduplicated_indices, duplicate_to_original_mapping = deduplicate(reach, 0.85, 1024)
It takes around 15 minutes to run and consumes a large amount of RAM. I have tested it on both my local machine and Google Colab, and the result is the same.
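For reference, here is how the slow call can be timed in isolation (a minimal sketch using the perf_counter already imported above; the arguments mirror the ones used inside deduplicate):
start = perf_counter()
_ = reach.nearest_neighbor_threshold(
    reach.vectors,
    threshold=0.85,
    batch_size=1024,
    show_progressbar=True,
)
print(f"nearest_neighbor_threshold: {perf_counter() - start:.1f}s")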
And here is the version where Reach is not spread across separate functions; it runs smoothly, taking only 20 seconds to complete nearest_neighbor_threshold.
from difflib import ndiff
from time import perf_counter
from datasets import load_dataset
from datasketch import MinHash, MinHashLSH
import numpy as np
from model2vec import StaticModel
from reach import Reach
from tqdm import tqdm
from wordllama import WordLlama
# Load the model and dataset
model = StaticModel.from_pretrained("jester6136/multilingual-e5-large-m2v")
ds = load_dataset("jester6136/osint")["train"]
texts = ds['value'][:30000]
ids = ds['id'][:30000]
# Encode texts into embeddings
embedding_matrix = model.encode(texts)
def deduplicate(embedding_matrix: np.ndarray, items: list[str], threshold: float, batch_size: int = 1024) -> tuple[list[str], dict[str, list[str]]]:
    """
    Deduplicate embeddings and return the deduplicated IDs and a mapping of kept IDs to the IDs removed as their duplicates.

    :param embedding_matrix: The embeddings to deduplicate.
    :param items: The IDs corresponding to the rows of the embedding matrix.
    :param threshold: The similarity threshold to use for deduplication.
    :param batch_size: The batch size to use for similarity computation.
    :return: A tuple containing the deduplicated IDs and a dictionary mapping each kept ID to the IDs removed as its duplicates.
    """
    vectors = np.zeros((1, model.dim))
    tmp_items = ['empty']
    reach = Reach(vectors, tmp_items)
    reach.delete(tmp_items)
    reach.insert(items, embedding_matrix)
    # Use a set for deduplicated IDs and track duplicates
    deduplicated_ids = set(items)  # Start with all IDs as deduplicated
    duplicate_to_original_mapping = {}
    results = reach.nearest_neighbor_threshold(
        embedding_matrix,
        threshold=threshold,
        batch_size=batch_size,
        show_progressbar=True,
    )
    # Process duplicate detection results
    for i, similar_items in enumerate(tqdm(results)):
        original_id = items[i]  # Get the original ID for the current embedding
        if original_id not in deduplicated_ids:
            continue  # Skip if already marked as a duplicate
        # Collect IDs of similar items (excluding the current ID)
        similar_ids = [item[0] for item in similar_items if item[0] != original_id]
        # Group similar items under the original ID
        if similar_ids:
            if original_id not in duplicate_to_original_mapping:
                duplicate_to_original_mapping[original_id] = []  # Initialize list for this original ID
            # Add all similar items to the group and remove from deduplicated set
            for sim_id in similar_ids:
                if sim_id in deduplicated_ids:
                    deduplicated_ids.remove(sim_id)  # Remove from deduplicated set
                    duplicate_to_original_mapping[original_id].append(sim_id)  # Group under original ID
    # Return deduplicated IDs and the mapping of duplicates
    return list(deduplicated_ids), duplicate_to_original_mapping
# Deduplicate (with a high threshold)
deduplicated_indices, duplicate_to_original_mapping = deduplicate(embedding_matrix, ids, threshold=0.99)
print(f"Number of deduplicated docs: {len(deduplicated_indices)}")
print(f"Number of duplicate_to_original_mapping: {len(duplicate_to_original_mapping)}")
Setup in Colab:
!pip install datasets model2vec numpy wordllama tqdm datasketch
!git clone https://github.com/stephantul/reach.git
%cd /content/reach
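(Presumably the cloned repo is then installed so the local checkout is the version that gets imported, e.g.:)
!pip install -e .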
Could this be caused by the threshold?
Hey @Jester6136 ,
Yep, this is because of the high threshold. My apologies, this is rather inefficient. @Pringled contributed a fix, which is now in a PR (see #73). I'm merging that soon, so using that version should make it a lot faster.
I'll ping you when I release. Stéphan
I just merged #73. Could you try the new function? It just returns indices, so that should be a lot faster.
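For anyone reading along, the index-returning idea can be sketched in plain NumPy. This is only an illustration of the approach (greedy cosine-similarity dedup that keeps the first occurrence of each near-duplicate group), not the actual new reach API; the function name here is made up:
import numpy as np

def deduplicate_indices(embeddings: np.ndarray, threshold: float, batch_size: int = 1024) -> np.ndarray:
    """Return the row indices that survive greedy near-duplicate removal."""
    # Normalize once so cosine similarity becomes a plain dot product.
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    normalized = embeddings / np.clip(norms, 1e-12, None)
    n = len(normalized)
    keep = np.ones(n, dtype=bool)
    for start in range(0, n, batch_size):
        batch = normalized[start:start + batch_size]
        sims = batch @ normalized.T  # (batch, n) similarities against all rows
        for row, i in enumerate(range(start, start + len(batch))):
            if not keep[i]:
                continue  # already removed as a duplicate of an earlier item
            dupes = np.flatnonzero(sims[row] >= threshold)
            keep[dupes[dupes > i]] = False  # drop later items similar to item i
    return np.flatnonzero(keep)

# Example usage (hypothetical), mapping kept indices back to IDs:
# kept = deduplicate_indices(embedding_matrix, threshold=0.99)
# deduplicated_ids = [ids[i] for i in kept]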