
[Bug/Model Request]: Load model files from path, not from huggingface cache directory

satyaloka93 opened this issue 1 year ago • 4 comments

What happened?

Unable to test this in my organization, as we do not use Hugging Face cache folders for models; models are downloaded via git, scanned, and then approved for use. I see an attempt to support local files via the 'local_files_only' kwarg in this PR, but that apparently won't work, as I do not have the files in the snapshot format. Please support loading models from a normal directory, like transformers/sentence-transformers and most other frameworks do. I would really like to incorporate this technology into our information retrieval, but this is a show stopper.

What Python version are you on? e.g. python --version

Python 3.10

Version

0.2.7 (Latest)

What OS are you seeing the problem on?

Linux

Relevant stack traces and/or logs

No response

satyaloka93 avatar Aug 10 '24 14:08 satyaloka93

Hi @satyaloka93

You want to put the same files as on the HF Hub into the cache directory and initialize from them. We are not talking about custom models / files right now, right?

joein avatar Aug 11 '24 21:08 joein

Hi, they are the files from the Qdrant HF repo: https://huggingface.co/Qdrant/Splade_PP_en_v1/tree/main. Our organization pulls them via git, scans them, and moves them to where we can load them. When I try to load from that directory, even using it as cache_dir with local_files_only=True, it does not work. I'm assuming that's because it expects a cache structure rather than the plain HF files in your repo.
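
To make the mismatch concrete, here is a rough, untested sketch of the failing setup; the sparse-model class, model name, and paths are assumptions for illustration, not taken from this thread:

# Illustrative layouts (not verbatim): a flat git clone vs. the huggingface_hub cache
#
#   /models/Splade_PP_en_v1/              <- git clone: config.json, model.onnx, tokenizer.json, ...
#   ~/.cache/huggingface/hub/models--Qdrant--Splade_PP_en_v1/snapshots/<revision>/
#                                          <- what snapshot_download expects to find
#
# Pointing cache_dir at the flat clone therefore fails even with local_files_only=True,
# because the lookup goes through snapshot_download and wants the snapshots/ structure.
from fastembed import SparseTextEmbedding  # class name assumed for the Splade model

model = SparseTextEmbedding(
    model_name="prithivida/Splade_PP_en_v1",  # supported model name may differ by fastembed version
    cache_dir="/models/Splade_PP_en_v1",      # placeholder path to the scanned git clone
    local_files_only=True,
)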

satyaloka93 avatar Aug 11 '24 23:08 satyaloka93

Any update on enabling loading model files from a path?

snassimr avatar Sep 15 '24 10:09 snassimr

> Any update on enabling loading model files from a path?

I had Claude help me modify fastembed/common/model_management.py to actually load the model from local files, with no calls out to HF and no requirement for a cache structure. It should still use the HF repo if the flag isn't set, but I haven't tested it. Maybe someone could take a look.

import os
import time
import shutil
import tarfile
from pathlib import Path
from typing import Any, Dict, List, Optional

import requests
from huggingface_hub import snapshot_download
from huggingface_hub.utils import RepositoryNotFoundError
from loguru import logger
from tqdm import tqdm


class ModelManagement:
    @classmethod
    def list_supported_models(cls) -> List[Dict[str, Any]]:
        """Lists the supported models.

        Returns:
            List[Dict[str, Any]]: A list of dictionaries containing the model information.
        """
        raise NotImplementedError()

    @classmethod
    def _get_model_description(cls, model_name: str) -> Dict[str, Any]:
        """
        Gets the model description from the model_name.

        Args:
            model_name (str): The name of the model.

        raises:
            ValueError: If the model_name is not supported.

        Returns:
            Dict[str, Any]: The model description.
        """
        for model in cls.list_supported_models():
            if model_name.lower() == model["model"].lower():
                return model

        raise ValueError(f"Model {model_name} is not supported in {cls.__name__}.")

    @classmethod
    def load_from_local(cls, model_name: str, cache_dir: str) -> Path:
        """
        Loads a model from a local directory.

        Args:
            model_name (str): The name of the model.
            cache_dir (str): The path to the cache directory.

        Returns:
            Path: The path to the local model directory.
        """
        #model_dir = Path(cache_dir) / model_name
        model_dir = Path(cache_dir)
        if not model_dir.exists():
            raise FileNotFoundError(f"Model directory {model_dir} does not exist.")

        required_files = ["config.json", "model.onnx"]  # Add or modify as needed
        for file in required_files:
            if not (model_dir / file).exists():
                raise FileNotFoundError(f"Required file {file} not found in {model_dir}")

        return model_dir

    @classmethod
    def download_file_from_gcs(cls, url: str, output_path: str, show_progress: bool = True) -> str:
        """
        Downloads a file from Google Cloud Storage.

        Args:
            url (str): The URL to download the file from.
            output_path (str): The path to save the downloaded file to.
            show_progress (bool, optional): Whether to show a progress bar. Defaults to True.

        Returns:
            str: The path to the downloaded file.
        """

        if os.path.exists(output_path):
            return output_path
        response = requests.get(url, stream=True)

        # Handle HTTP errors
        if response.status_code == 403:
            raise PermissionError(
                "Authentication Error: You do not have permission to access this resource. "
                "Please check your credentials."
            )

        # Get the total size of the file
        total_size_in_bytes = int(response.headers.get("content-length", 0))

        # Warn if the total size is zero
        if total_size_in_bytes == 0:
            print(f"Warning: Content-length header is missing or zero in the response from {url}.")

        show_progress = total_size_in_bytes and show_progress

        with tqdm(
            total=total_size_in_bytes,
            unit="iB",
            unit_scale=True,
            disable=not show_progress,
        ) as progress_bar:
            with open(output_path, "wb") as file:
                for chunk in response.iter_content(chunk_size=1024):
                    if chunk:  # Filter out keep-alive new chunks
                        progress_bar.update(len(chunk))
                        file.write(chunk)
        return output_path

    @classmethod
    def download_files_from_huggingface(
        cls,
        hf_source_repo: str,
        cache_dir: Optional[str] = None,
        extra_patterns: Optional[List[str]] = None,
        **kwargs,
    ) -> str:
        """
        Downloads a model from HuggingFace Hub.
        Args:
            hf_source_repo (str): Name of the model on HuggingFace Hub, e.g. "qdrant/all-MiniLM-L6-v2-onnx".
            cache_dir (Optional[str]): The path to the cache directory.
            extra_patterns (Optional[List[str]]): extra patterns to allow in the snapshot download, typically
                includes the required model files.
        Returns:
            Path: The path to the model directory.
        """
        allow_patterns = [
            "config.json",
            "tokenizer.json",
            "tokenizer_config.json",
            "special_tokens_map.json",
            "preprocessor_config.json",
        ]
        if extra_patterns is not None:
            allow_patterns.extend(extra_patterns)

        return snapshot_download(
            repo_id=hf_source_repo,
            allow_patterns=allow_patterns,
            cache_dir=cache_dir,
            local_files_only=kwargs.get("local_files_only", False),
        )

    @classmethod
    def decompress_to_cache(cls, targz_path: str, cache_dir: str):
        """
        Decompresses a .tar.gz file to a cache directory.

        Args:
            targz_path (str): Path to the .tar.gz file.
            cache_dir (str): Path to the cache directory.

        Returns:
            cache_dir (str): Path to the cache directory.
        """
        # Check if targz_path exists and is a file
        if not os.path.isfile(targz_path):
            raise ValueError(f"{targz_path} does not exist or is not a file.")

        # Check if targz_path is a .tar.gz file
        if not targz_path.endswith(".tar.gz"):
            raise ValueError(f"{targz_path} is not a .tar.gz file.")

        try:
            # Open the tar.gz file
            with tarfile.open(targz_path, "r:gz") as tar:
                # Extract all files into the cache directory
                tar.extractall(path=cache_dir)
        except tarfile.TarError as e:
            # If any error occurs while opening or extracting the tar.gz file,
            # delete the cache directory (if it was created in this function)
            # and raise the error again
            if "tmp" in cache_dir:
                shutil.rmtree(cache_dir)
            raise ValueError(f"An error occurred while decompressing {targz_path}: {e}")

        return cache_dir

    @classmethod
    def retrieve_model_gcs(cls, model_name: str, source_url: str, cache_dir: str) -> Path:
        fast_model_name = f"fast-{model_name.split('/')[-1]}"

        cache_tmp_dir = Path(cache_dir) / "tmp"
        model_tmp_dir = cache_tmp_dir / fast_model_name
        # original: model_dir = Path(cache_dir) / fast_model_name
        model_dir = Path(cache_dir)

        # check if the model_dir and the model files are both present for macOS
        if model_dir.exists() and len(list(model_dir.glob("*"))) > 0:
            return model_dir

        if model_tmp_dir.exists():
            shutil.rmtree(model_tmp_dir)

        cache_tmp_dir.mkdir(parents=True, exist_ok=True)

        model_tar_gz = Path(cache_dir) / f"{fast_model_name}.tar.gz"

        if model_tar_gz.exists():
            model_tar_gz.unlink()

        cls.download_file_from_gcs(
            source_url,
            output_path=str(model_tar_gz),
        )

        cls.decompress_to_cache(targz_path=str(model_tar_gz), cache_dir=str(cache_tmp_dir))
        assert model_tmp_dir.exists(), f"Could not find {model_tmp_dir} in {cache_tmp_dir}"

        model_tar_gz.unlink()
        # Rename from tmp to final name is atomic
        model_tmp_dir.rename(model_dir)

        return model_dir

    @classmethod
    def download_model(cls, model: Dict[str, Any], cache_dir: Path, retries=3, **kwargs) -> Path:
        """
        Attempts to load a model from a local directory first, then falls back to online sources if necessary.

        Args:
            model (Dict[str, Any]): The model description.
            cache_dir (str): The path to the cache directory.
            retries: (int): The number of times to retry (including the first attempt)
            **kwargs: Additional keyword arguments, including 'local_files_only'.
    
        Returns:
            Path: The path to the model directory.
        """
        model_name = model["model"]
        local_files_only = kwargs.get("local_files_only", False)


        try:
            logger.info(f"Loading from: {cache_dir}")
            return cls.load_from_local(model_name, str(cache_dir))
        except FileNotFoundError:
            if local_files_only:
                raise ValueError(f"Model {model_name} not found locally and local_files_only is set to True.")

        # If local loading fails and online fetching is allowed, proceed with online sources
        hf_source = model.get("sources", {}).get("hf")
        url_source = model.get("sources", {}).get("url")

        sleep = 3.0
        while retries > 0:
            retries -= 1

            if hf_source:
                # ... [rest of the Hugging Face download logic] ...
                extra_patterns = [model["model_file"]]
                extra_patterns.extend(model.get("additional_files", []))

                try:
                    return Path(
                        cls.download_files_from_huggingface(
                            hf_source,
                            cache_dir=str(cache_dir),
                            extra_patterns=extra_patterns,
                            local_files_only=kwargs.get("local_files_only", False),
                        )
                    )    
                except (EnvironmentError, RepositoryNotFoundError, ValueError) as e:
                    logger.error(
                        f"Could not download model from HuggingFace: {e} "
                        "Falling back to other sources."
                    )
            if url_source:
                # ... [rest of the GCS download logic] ...
                try:
                    return cls.retrieve_model_gcs(model["model"], url_source, str(cache_dir))
                except Exception:
                    logger.error(f"Could not download model from url: {url_source}")
            logger.error(
                f"Could not download model from either source, sleeping for {sleep} seconds, {retries} retries left."
            ) 
            time.sleep(sleep)
            sleep *= 3

        raise ValueError(f"Failed to load or download model {model_name} after all attempts.")

satyaloka93 avatar Sep 15 '24 15:09 satyaloka93

Okay, so to load the model from a specific directory you need to pass local_dir to the snapshot_download call at line 120 of fastembed/common/model_management.py, and by default you can't. You need to modify some of the code to make this possible:

In line 194 of fastembed/text/onnx_embedding.py: replace local_files_only=self._local_files_only with **kwargs in the parameters of self.download_model.

In line 244 of fastembed/common/model_management.py: replace local_files_only=kwargs.get("local_files_only", False) with **kwargs in the parameters of cls.download_files_from_huggingface.

In line 124 of fastembed/common/model_management.py: replace local_files_only=kwargs.get("local_files_only", False) with **kwargs in the parameters of snapshot_download.

Then when you create the TextEmbedding instance, pass your local model directory as local_dir and it should work (see the sketch below). It also works in my case when calling QdrantClient.set_model.
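
For concreteness, a hedged sketch of what the call could look like after those three changes (the path is a placeholder):

from fastembed import TextEmbedding

# local_dir is forwarded through **kwargs down to huggingface_hub.snapshot_download,
# which then works against that plain directory instead of the hub cache layout.
emb = TextEmbedding(
    "sentence-transformers/all-MiniLM-L6-v2",
    local_dir="/models/all-MiniLM-L6-v2",  # placeholder path
)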

I get that you guys want to explicitly pass local_files_only=False to snapshot_download if it's not specified by the user, but local_files_only is False by default anyway, so it just doesn't matter. Changing those kwargs.get("local_files_only", False) to **kwargs also enables the user to pass other arguments to the final call.

@joein can you ping your dev team about this? It should be an easy fix.

Also, if the sentence-transformers/all-MiniLM-L6-v2 model is downloaded from the URL (https://storage.googleapis.com/qdrant-fastembed/sentence-transformers-all-MiniLM-L6-v2.tar.gz) (line 18 in fastembed/text/pooled_normalize_embedding.py), the model file is model_optimized.onnx, not model.onnx as it should be.

TrickyWhiteCat avatar Oct 16 '24 08:10 TrickyWhiteCat

Without changing any of the existing code, you can use the following to do this (if you are using v0.3.x):

from typing import Iterable, List, Optional, Sequence, Union
from pathlib import Path

import numpy as np

from fastembed.common import OnnxProvider
from fastembed.text.onnx_embedding import OnnxTextEmbedding

class OnnxTextEmbeddingLocal(OnnxTextEmbedding):
    def __init__(
        self,
        model_name: str,
        model_dir: Path,
        model_file: str,
        threads: Optional[int] = None,
        providers: Optional[Sequence[OnnxProvider]] = None,
        cuda: bool = False,
        device_ids: Optional[List[int]] = None,
        device_id: Optional[int] = None,
    ):
        self.model_name = model_name
        self.cache_dir = ""
        self._model_dir = model_dir
        self._model_file = model_file
        self.threads = threads
        self.providers = providers
        self.cuda = cuda
        self.device_ids = device_ids

        if device_id is not None:
            self.device_id = device_id
        elif self.device_ids is not None:
            self.device_id = self.device_ids[0]
        else:
            self.device_id = None

        self.load_onnx_model(
            model_dir=self._model_dir,
            model_file=self._model_file,
            threads=self.threads,
            providers=self.providers,
        )

model = OnnxTextEmbeddingLocal(
    model_name=model_name, model_dir=Path(model_path), model_file=model_file
)

LaurentGoderre avatar Nov 13 '24 15:11 LaurentGoderre

This functionality has been added as of fastembed v0.6.0 and can be used via the specific_model_path argument.

An example of using it:

from fastembed import TextEmbedding

emb = TextEmbedding("sentence-transformers/all-MiniLM-l6-v2", specific_model_path="my_model")
print(list(emb.embed('single query')))

where the my_model directory structure is:

config.json
model.onnx
special_tokens_map.json
tokenizer.json
tokenizer_config.json
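
Presumably the same argument works for the sparse model from the original report; a hedged sketch (the class, the supported model name, and the path are assumptions):

from fastembed import SparseTextEmbedding

# specific_model_path points at a flat directory containing the Splade files
sparse = SparseTextEmbedding(
    "prithivida/Splade_PP_en_v1",                   # check list_supported_models() for the exact name
    specific_model_path="/models/Splade_PP_en_v1",  # placeholder path
)
print(list(sparse.embed(["single query"])))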

Closing this as completed; feel free to re-open if it does not work for you.

joein avatar Mar 02 '25 22:03 joein