
Initializing Reach and calling it across separate functions runs too slowly and consumes excessive resources (nearest_neighbor_threshold)

Jester6136 opened this issue 1 year ago · 3 comments

Here is my code:

from difflib import ndiff
from time import perf_counter

from datasets import load_dataset
from datasketch import MinHash, MinHashLSH
import numpy as np
from model2vec import StaticModel
from reach import Reach
from tqdm import tqdm
from wordllama import WordLlama
# Load the model and dataset
model = StaticModel.from_pretrained("jester6136/multilingual-e5-large-m2v")
ds = load_dataset("jester6136/osint")["train"]
texts = ds['value'][:30000]
ids = ds['id'][:30000]

# Work around needing at least one vector: initialize Reach with a dummy entry, then delete it
vectors = np.zeros((1, model.dim))
tmp_items = ['empty']
reach = Reach(vectors, tmp_items)
reach.delete(tmp_items)

def encode_texts(reach, model, texts: list[str], ids: list[str]) -> None:
    prompted_texts = [f"query: {text}" for text in texts]
    embedding_matrix = model.encode(prompted_texts, show_progressbar=True)
    reach.insert(ids, embedding_matrix)

def deduplicate(reach, threshold: float, batch_size: int = 1024) -> tuple[np.ndarray, dict[int, int]]:
    """
    Deduplicate embeddings and return the deduplicated indices and a mapping of removed indices to their corresponding original indices.

    :param reach: The Reach instance containing the embeddings to deduplicate.
    :param threshold: The similarity threshold to use for deduplication.
    :param batch_size: The batch size to use for similarity computation.
    :return: A tuple containing the deduplicated indices and a dictionary mapping removed indices to original indices.
    """
    embedding_matrix = reach.vectors
    items = reach.sorted_items
    # Use a set for deduplicated IDs and track duplicates
    deduplicated_ids = set(items)  # Start with all IDs as deduplicated
    duplicate_to_original_mapping = {}

    results = reach.nearest_neighbor_threshold(
        embedding_matrix,
        threshold=threshold,
        batch_size=batch_size,
        show_progressbar=True
    )
    # Process duplicate detection results
    for i, similar_items in enumerate(tqdm(results)):
        original_id = items[i]  # Get the original ID for the current embedding

        if original_id not in deduplicated_ids:
            continue  # Skip if already marked as a duplicate

        # Collect IDs of similar items (excluding the current ID)
        similar_ids = [item[0] for item in similar_items if item[0] != original_id]

        # Group similar items under the original ID
        if similar_ids:
            if original_id not in duplicate_to_original_mapping:
                duplicate_to_original_mapping[original_id] = []  # Initialize list for this original ID

            # Add all similar items to the group and remove from deduplicated set
            for sim_id in similar_ids:
                if sim_id in deduplicated_ids:
                    deduplicated_ids.remove(sim_id)  # Remove from deduplicated set
                    duplicate_to_original_mapping[original_id].append(sim_id)  # Group under original ID

    # Return deduplicated IDs and the mapping of duplicates
    return list(deduplicated_ids), duplicate_to_original_mapping

encode_texts(reach, model, texts, ids)
deduplicated_indices, duplicate_to_original_mapping = deduplicate(reach, 0.85, 1024)

It takes around 15 minutes and consumes a large amount of RAM to run. I have tested it on both my local machine and Google Colab, but the result is the same.
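To confirm that the time is going into nearest_neighbor_threshold rather than into encoding, the last two calls of the script above can be wrapped with the already-imported perf_counter. This timing harness is a minimal sketch and is not part of the original script:

t0 = perf_counter()
encode_texts(reach, model, texts, ids)
t1 = perf_counter()
deduplicated_indices, duplicate_to_original_mapping = deduplicate(reach, 0.85, 1024)
t2 = perf_counter()
print(f"encode_texts: {t1 - t0:.1f}s, deduplicate: {t2 - t1:.1f}s")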

And here is the version where Reach is not spread across functions; it runs smoothly, taking only about 20 seconds to complete nearest_neighbor_threshold.

from difflib import ndiff
from time import perf_counter

from datasets import load_dataset
from datasketch import MinHash, MinHashLSH
import numpy as np
from model2vec import StaticModel
from reach import Reach
from tqdm import tqdm
from wordllama import WordLlama
# Load the model and dataset
model = StaticModel.from_pretrained("jester6136/multilingual-e5-large-m2v")
ds = load_dataset("jester6136/osint")["train"]
texts = ds['value'][:30000]
ids = ds['id'][:30000]

# Encode texts into embeddings
embedding_matrix = model.encode(texts)
def deduplicate(embedding_matrix: np.ndarray, items: list[str], threshold: float, batch_size: int = 1024) -> tuple[np.ndarray, dict[int, int]]:
    """
    Deduplicate embeddings and return the deduplicated indices and a mapping of removed indices to their corresponding original indices.

    :param embedding_matrix: The embeddings to deduplicate.
    :param items: The IDs corresponding to each row of the embedding matrix.
    :param threshold: The similarity threshold to use for deduplication.
    :param batch_size: The batch size to use for similarity computation.
    :return: A tuple containing the deduplicated indices and a dictionary mapping removed indices to original indices.
    """
    # Work around needing at least one vector: initialize Reach with a dummy entry, then delete it
    vectors = np.zeros((1, model.dim))
    tmp_items = ['empty']
    reach = Reach(vectors, tmp_items)
    reach.delete(tmp_items)
    reach.insert(items, embedding_matrix)  # insert the passed-in items rather than the global ids

    # Use a set for deduplicated IDs and track duplicates
    deduplicated_ids = set(items)  # Start with all IDs as deduplicated
    duplicate_to_original_mapping = {}

    results = reach.nearest_neighbor_threshold(
        embedding_matrix,
        threshold=threshold,
        batch_size=batch_size,
        show_progressbar=True
    )

    # Process duplicate detection results
    for i, similar_items in enumerate(tqdm(results)):
        original_id = items[i]  # Get the original ID for the current embedding

        if original_id not in deduplicated_ids:
            continue  # Skip if already marked as a duplicate

        # Collect IDs of similar items (excluding the current ID)
        similar_ids = [item[0] for item in similar_items if item[0] != original_id]

        # Group similar items under the original ID
        if similar_ids:
            if original_id not in duplicate_to_original_mapping:
                duplicate_to_original_mapping[original_id] = []  # Initialize list for this original ID

            # Add all similar items to the group and remove from deduplicated set
            for sim_id in similar_ids:
                if sim_id in deduplicated_ids:
                    deduplicated_ids.remove(sim_id)  # Remove from deduplicated set
                    duplicate_to_original_mapping[original_id].append(sim_id)  # Group under original ID

    # Return deduplicated IDs and the mapping of duplicates
    return list(deduplicated_ids), duplicate_to_original_mapping

# Deduplicate (with a high threshold)
deduplicated_indices, duplicate_to_original_mapping = deduplicate(embedding_matrix, ids, threshold=0.99)
print(f"Number of deduplicated docs: {len(deduplicated_indices)}")
print(f"Number of duplicate_to_original_mapping: {len(duplicate_to_original_mapping)}")

Setup in Colab:

!pip install datasets model2vec numpy wordllama tqdm datasketch
!git clone https://github.com/stephantul/reach.git
%cd /content/reach
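One thing worth noting about this setup: the repository is cloned and the working directory changed into it, but the package itself is never installed, so the import presumably resolves against the checkout in the working directory. An explicit editable install (an assumed extra step, not shown in the original setup) would be:

!pip install -e .  # assumed step: install the cloned reach checkout in editable mode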

Jester6136 · Oct 23 '24 02:10

Could this be happening because of the threshold?

Jester6136 · Oct 23 '24 03:10

Hey @Jester6136 ,

Yep, this is because of the high threshold. My apologies, this is rather inefficient. @Pringled contributed a fix, which is now in a PR; see #73. I'm merging that soon, so using that should make it a lot faster.

I'll ping you when I release it.

Stéphan

stephantul · Oct 23 '24 08:10
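For context on the numbers involved: nearest_neighbor_threshold materialises, for every one of the 30,000 queries, a Python list of (item, score) tuples for each neighbour that clears the threshold, so the total result size grows with the number of qualifying pairs. A rough back-of-the-envelope sketch, where the per-item neighbour count and per-tuple cost are illustrative assumptions rather than measurements:

n_queries = 30_000
avg_neighbours = 500   # assumed average number of neighbours above the threshold (illustrative)
bytes_per_pair = 150   # rough cost of one (item, score) tuple plus list overhead (illustrative)

total_pairs = n_queries * avg_neighbours
print(f"~{total_pairs:,} (item, score) tuples ≈ {total_pairs * bytes_per_pair / 1e9:.1f} GB of Python objects")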

I just merged #73. Could you try the new function? It just returns indices, so that should be a lot faster.

stephantul · Oct 23 '24 08:10
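Adapting the post-processing to a result that contains only integer indices (as described for #73) removes the per-neighbour ID and score bookkeeping entirely. A minimal sketch of what that loop could look like, assuming results[i] is simply a sequence of neighbour indices for embedding i; this is a hypothetical shape for the output, not the exact API added in the PR:

def deduplicate_from_indices(results, items):
    """Keep the first occurrence of each duplicate group, given per-item neighbour index lists."""
    keep = set(range(len(items)))
    duplicate_to_original = {}
    for i, neighbour_indices in enumerate(results):
        if i not in keep:
            continue  # this item was already folded into an earlier original
        for j in neighbour_indices:
            if j != i and j in keep:
                keep.remove(j)
                duplicate_to_original[items[j]] = items[i]
    deduplicated_ids = [items[i] for i in sorted(keep)]
    return deduplicated_ids, duplicate_to_original

Because everything stays as integer indices until the final lookup, the memory footprint is dominated by the neighbour index arrays rather than by millions of small Python tuples.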