fastembed
fastembed copied to clipboard
[Bug/Model Request]: Load model files from path, not from huggingface cache directory
What happened?
Unable to test this in my organization, as we do not use hugging face cache folders for models, models are downloaded via git, scanned, then allowed for usage. I see some attempt to use local files via 'local_files_only' kwarg in this PR, but this won't work apparently as I do not have files in the snapshot format. Request loading models from a normal directory, like transformers/sentence-transformers and most other frameworks. Really would like to incorporate this technology in our information retrieval, but this is a show stopper.
What Python version are you on? e.g. python --version
Python 3.10
Version
0.2.7 (Latest)
What os are you seeing the problem on?
Linux
Relevant stack traces and/or logs
No response
Hi @satyaloka93
You want to put the same files as in HF hub to the cache directory and initialize from them. We are not talking about some custom models / files right now, right?
Hi, they are the files from the Qdrant HF repo: https://huggingface.co/Qdrant/Splade_PP_en_v1/tree/main. Our organization pulls them via git, scans, and moves them where we can load them up. When I try to load from that directory, even using it as cache_dir and local_files_only=True, it does not work. I’m assuming because it’s expecting to have a cache structure, versus the normal HF files in your repo.
+1. Any update on enabling loading model files from a path?
+1. Any update on enabling loading model files from a path?
I had claude help me modify fastembed/common/model_management.py to actually load the model from local files, with no HF callouts or requirements for a cache structure. It should still use HF repo if the flag isn't set, but I haven't tested it. Maybe someone could take a look.
import os
import time
import shutil
import tarfile
from pathlib import Path
from typing import Any, Dict, List, Optional
import requests
from huggingface_hub import snapshot_download
from huggingface_hub.utils import RepositoryNotFoundError
from loguru import logger
from tqdm import tqdm
class ModelManagement:
    """Locate, download and unpack model files for an embedding model.

    Subclasses provide the catalogue of models through
    ``list_supported_models``; the remaining class methods resolve a model
    either from a plain local directory, from the HuggingFace Hub, or from a
    ``.tar.gz`` archive served over HTTP.
    """

    @classmethod
    def list_supported_models(cls) -> List[Dict[str, Any]]:
        """Lists the supported models.

        Returns:
            List[Dict[str, Any]]: A list of dictionaries containing the model information.
        """
        raise NotImplementedError()

    @classmethod
    def _get_model_description(cls, model_name: str) -> Dict[str, Any]:
        """
        Gets the model description from the model_name.

        The comparison is case-insensitive.

        Args:
            model_name (str): The name of the model.

        Raises:
            ValueError: If the model_name is not supported.

        Returns:
            Dict[str, Any]: The model description.
        """
        wanted = model_name.lower()
        for model in cls.list_supported_models():
            if wanted == model["model"].lower():
                return model

        raise ValueError(f"Model {model_name} is not supported in {cls.__name__}.")

    @classmethod
    def load_from_local(
        cls,
        model_name: str,
        cache_dir: str,
        required_files: Optional[List[str]] = None,
    ) -> Path:
        """
        Loads a model from a local directory.

        The directory is used as-is (no HuggingFace cache layout is expected);
        it only has to contain every file listed in ``required_files``.

        Args:
            model_name (str): The name of the model. Currently not used to
                build the path; kept so callers can pass it positionally.
            cache_dir (str): The path to the directory holding the model files.
            required_files (Optional[List[str]]): File names (relative to
                ``cache_dir``) that must exist. Defaults to
                ``["config.json", "model.onnx"]``.

        Raises:
            FileNotFoundError: If the directory or a required file is missing.

        Returns:
            Path: The path to the local model directory.
        """
        if required_files is None:
            required_files = ["config.json", "model.onnx"]

        model_dir = Path(cache_dir)
        if not model_dir.exists():
            raise FileNotFoundError(f"Model directory {model_dir} does not exist.")

        for file in required_files:
            if not (model_dir / file).exists():
                raise FileNotFoundError(f"Required file {file} not found in {model_dir}")

        return model_dir

    @classmethod
    def download_file_from_gcs(cls, url: str, output_path: str, show_progress: bool = True) -> str:
        """
        Downloads a file from Google Cloud Storage.

        If ``output_path`` already exists the download is skipped. Data is
        streamed into a sibling ``.part`` file and moved into place only once
        the download finished, so an interrupted transfer never leaves a
        truncated file at ``output_path``.

        Args:
            url (str): The URL to download the file from.
            output_path (str): The path to save the downloaded file to.
            show_progress (bool, optional): Whether to show a progress bar. Defaults to True.

        Raises:
            PermissionError: If the server answers with HTTP 403.
            requests.HTTPError: If the server answers with any other 4xx/5xx status.

        Returns:
            str: The path to the downloaded file.
        """
        if os.path.exists(output_path):
            return output_path

        partial_path = f"{output_path}.part"

        with requests.get(url, stream=True) as response:
            # Handle HTTP errors
            if response.status_code == 403:
                raise PermissionError(
                    "Authentication Error: You do not have permission to access this resource. "
                    "Please check your credentials."
                )
            # Any other error status must not be written to disk as if it were the model.
            response.raise_for_status()

            # Get the total size of the file
            total_size_in_bytes = int(response.headers.get("content-length", 0))

            # Warn if the total size is zero
            if total_size_in_bytes == 0:
                logger.warning(
                    f"Warning: Content-length header is missing or zero in the response from {url}."
                )

            # A progress bar without a known total is useless, so disable it then.
            show_progress = bool(total_size_in_bytes) and show_progress

            try:
                with tqdm(
                    total=total_size_in_bytes,
                    unit="iB",
                    unit_scale=True,
                    disable=not show_progress,
                ) as progress_bar, open(partial_path, "wb") as file:
                    for chunk in response.iter_content(chunk_size=1024):
                        if chunk:  # Filter out keep-alive new chunks
                            progress_bar.update(len(chunk))
                            file.write(chunk)
                os.replace(partial_path, output_path)
            except BaseException:
                # Do not leave a half-written file behind; re-raise the original error.
                if os.path.exists(partial_path):
                    os.remove(partial_path)
                raise

        return output_path

    @classmethod
    def download_files_from_huggingface(
        cls,
        hf_source_repo: str,
        cache_dir: Optional[str] = None,
        extra_patterns: Optional[List[str]] = None,
        **kwargs,
    ) -> str:
        """
        Downloads a model from HuggingFace Hub.

        Args:
            hf_source_repo (str): Name of the model on HuggingFace Hub, e.g. "qdrant/all-MiniLM-L6-v2-onnx".
            cache_dir (Optional[str]): The path to the cache directory.
            extra_patterns (Optional[List[str]]): extra patterns to allow in the snapshot download, typically
                includes the required model files.
            **kwargs: Only ``local_files_only`` (default False) is read and
                forwarded to ``snapshot_download``.

        Returns:
            str: The path to the model directory, as returned by ``snapshot_download``.
        """
        allow_patterns = [
            "config.json",
            "tokenizer.json",
            "tokenizer_config.json",
            "special_tokens_map.json",
            "preprocessor_config.json",
        ]
        if extra_patterns is not None:
            allow_patterns.extend(extra_patterns)

        return snapshot_download(
            repo_id=hf_source_repo,
            allow_patterns=allow_patterns,
            cache_dir=cache_dir,
            local_files_only=kwargs.get("local_files_only", False),
        )

    @classmethod
    def decompress_to_cache(cls, targz_path: str, cache_dir: str):
        """
        Decompresses a .tar.gz file to a cache directory.

        Args:
            targz_path (str): Path to the .tar.gz file.
            cache_dir (str): Path to the cache directory.

        Raises:
            ValueError: If ``targz_path`` is not an existing ``.tar.gz`` file
                or the archive cannot be extracted.

        Returns:
            cache_dir (str): Path to the cache directory.
        """
        # Check if targz_path exists and is a file
        if not os.path.isfile(targz_path):
            raise ValueError(f"{targz_path} does not exist or is not a file.")

        # Check if targz_path is a .tar.gz file
        if not targz_path.endswith(".tar.gz"):
            raise ValueError(f"{targz_path} is not a .tar.gz file.")

        try:
            # Open the tar.gz file
            with tarfile.open(targz_path, "r:gz") as tar:
                # Extract all files into the cache directory. Where the
                # interpreter supports extraction filters, use the "data"
                # filter so a crafted archive cannot write outside cache_dir.
                if hasattr(tarfile, "data_filter"):
                    tar.extractall(path=cache_dir, filter="data")
                else:
                    tar.extractall(path=cache_dir)
        except tarfile.TarError as e:
            # If any error occurs while opening or extracting the tar.gz file,
            # delete the cache directory when it looks like a temporary one
            # and raise the error again. ignore_errors keeps a missing
            # directory from masking the real failure.
            if "tmp" in cache_dir:
                shutil.rmtree(cache_dir, ignore_errors=True)
            raise ValueError(f"An error occurred while decompressing {targz_path}: {e}") from e

        return cache_dir

    @classmethod
    def retrieve_model_gcs(cls, model_name: str, source_url: str, cache_dir: str) -> Path:
        """
        Downloads a model archive from a URL and unpacks it below ``cache_dir``.

        The archive is expected to contain a top-level directory named
        ``fast-<last component of model_name>``. It is extracted into
        ``<cache_dir>/tmp`` and then renamed to ``<cache_dir>/fast-<name>``.

        Args:
            model_name (str): The name of the model, e.g. "org/name".
            source_url (str): URL of the ``.tar.gz`` archive.
            cache_dir (str): The path to the cache directory.

        Raises:
            ValueError: If the archive cannot be extracted or does not contain
                the expected directory.

        Returns:
            Path: The path to the model directory.
        """
        fast_model_name = f"fast-{model_name.split('/')[-1]}"

        cache_root = Path(cache_dir)
        cache_tmp_dir = cache_root / "tmp"
        model_tmp_dir = cache_tmp_dir / fast_model_name
        model_dir = cache_root / fast_model_name

        # check if the model_dir and the model files are both present for macOS
        if model_dir.exists() and any(model_dir.iterdir()):
            return model_dir

        if model_tmp_dir.exists():
            shutil.rmtree(model_tmp_dir)

        cache_tmp_dir.mkdir(parents=True, exist_ok=True)

        model_tar_gz = cache_root / f"{fast_model_name}.tar.gz"

        # A leftover archive may be truncated; always fetch a fresh copy.
        if model_tar_gz.exists():
            model_tar_gz.unlink()

        cls.download_file_from_gcs(
            source_url,
            output_path=str(model_tar_gz),
        )

        cls.decompress_to_cache(targz_path=str(model_tar_gz), cache_dir=str(cache_tmp_dir))
        if not model_tmp_dir.exists():
            raise ValueError(f"Could not find {model_tmp_dir} in {cache_tmp_dir}")

        model_tar_gz.unlink()
        # Rename from tmp to final name is atomic
        model_tmp_dir.rename(model_dir)

        return model_dir

    @classmethod
    def download_model(
        cls, model: Dict[str, Any], cache_dir: Path, retries: int = 3, **kwargs
    ) -> Path:
        """
        Attempts to load a model from a local directory first, then falls back to online sources if necessary.

        ``cache_dir`` itself is checked for ``config.json`` and the model's
        ``model_file`` (``model.onnx`` when the description has none). Only if
        that fails, and ``local_files_only`` is not set, the HuggingFace and
        URL sources from ``model["sources"]`` are tried.

        Args:
            model (Dict[str, Any]): The model description.
            cache_dir (Path): The path to the cache directory.
            retries: (int): The number of times to retry (including the first attempt)
            **kwargs: Additional keyword arguments, including 'local_files_only'.

        Raises:
            ValueError: If the model is not available locally while
                ``local_files_only`` is set, if no source is configured, or if
                every download attempt failed.

        Returns:
            Path: The path to the model directory.
        """
        model_name = model["model"]
        local_files_only = kwargs.get("local_files_only", False)

        try:
            logger.info(f"Loading from: {cache_dir}")
            return cls.load_from_local(
                model_name,
                str(cache_dir),
                required_files=["config.json", model.get("model_file", "model.onnx")],
            )
        except FileNotFoundError as e:
            if local_files_only:
                raise ValueError(
                    f"Model {model_name} not found locally and local_files_only is set to True."
                ) from e

        # If local loading fails and online fetching is allowed, proceed with online sources
        sources = model.get("sources", {})
        hf_source = sources.get("hf")
        url_source = sources.get("url")

        if not hf_source and not url_source:
            # Nothing to retry against; fail immediately instead of sleeping.
            raise ValueError(f"Model {model_name} has no download sources configured.")

        sleep = 3.0
        while retries > 0:
            retries -= 1

            if hf_source:
                extra_patterns = [model["model_file"]]
                extra_patterns.extend(model.get("additional_files", []))

                try:
                    return Path(
                        cls.download_files_from_huggingface(
                            hf_source,
                            cache_dir=str(cache_dir),
                            extra_patterns=extra_patterns,
                            local_files_only=local_files_only,
                        )
                    )
                except (EnvironmentError, RepositoryNotFoundError, ValueError) as e:
                    logger.error(
                        f"Could not download model from HuggingFace: {e} "
                        "Falling back to other sources."
                    )

            if url_source:
                try:
                    return cls.retrieve_model_gcs(model["model"], url_source, str(cache_dir))
                except Exception as e:
                    logger.error(f"Could not download model from url: {url_source}: {e}")

            if retries > 0:
                logger.error(
                    f"Could not download model from either source, sleeping for {sleep} seconds, {retries} retries left."
                )
                time.sleep(sleep)
                sleep *= 3
            else:
                logger.error("Could not download model from either source, no retries left.")

        raise ValueError(f"Failed to load or download model {model_name} after all attempts.")
Okay so to load the model from a specific directory, you need to pass local_dir to the snapshot_download call in line 120 of fastembed/common/model_management.py and by default you can't. You need to modify some of the code to make this possible:
In line 194 of fastembed/text/onnx_embedding.py: replace local_files_only=self._local_files_only with **kwargs in the parameters of self.download_model.
In line 244 of fastembed/common/model_management.py: replace local_files_only=kwargs.get("local_files_only", False) with **kwargs in the parameters of cls.download_files_from_huggingface.
In line 124 of fastembed/common/model_management.py: replace local_files_only=kwargs.get("local_files_only", False) with **kwargs in the parameters of snapshot_download.
Then when you create the TextEmbedding instance, pass your local model directory as local_dir and it should work. It also works in my case, when calling QdrantClient.set_model.
I get that you guys want to explicitly pass local_files_only=False to snapshot_download if it's not specified by the user, but local_files_only defaults to False anyway, so it makes no difference. Changing those kwargs.get("local_files_only", False) calls to **kwargs also enables the user to pass other arguments through to the final call.
@joein can you ping your dev team about this? It should be an easy fix.
Also, if the sentence-transformers/all-MiniLM-L6-v2 model is downloaded from the URL (https://storage.googleapis.com/qdrant-fastembed/sentence-transformers-all-MiniLM-L6-v2.tar.gz) (line 18 in fastembed/text/pooled_normalize_embedding.py), the model file is model_optimized.onnx, not model.onnx as it should be.
Without changing any of the existing code you can use the following to do this (if you are using v0.3.x)
from typing import Iterable, List, Optional, Sequence, Union
from pathlib import Path
import numpy as np
from fastembed.common import OnnxProvider
from fastembed.text.onnx_embedding import OnnxTextEmbedding
class OnnxTextEmbeddingLocal(OnnxTextEmbedding):
    """Text embedding model whose ONNX file is read from a given directory.

    The initializer does not call the parent initializer; it stores the
    settings on the instance and calls ``load_onnx_model`` directly with the
    supplied directory and file name.
    """

    def __init__(
        self,
        model_name: str,
        model_dir: Path,
        model_file: str,
        threads: Optional[int] = None,
        providers: Optional[Sequence[OnnxProvider]] = None,
        cuda: bool = False,
        device_ids: Optional[List[int]] = None,
        device_id: Optional[int] = None,
    ):
        """
        Args:
            model_name (str): Stored as ``self.model_name``.
            model_dir (Path): Directory passed to ``load_onnx_model``.
            model_file (str): File name passed to ``load_onnx_model``.
            threads (Optional[int]): Passed to ``load_onnx_model``.
            providers (Optional[Sequence[OnnxProvider]]): Passed to ``load_onnx_model``.
            cuda (bool): Stored as ``self.cuda``.
            device_ids (Optional[List[int]]): Stored as ``self.device_ids``.
            device_id (Optional[int]): Explicit device; when None, the first
                entry of ``device_ids`` is used if that list is given.
        """
        # Identification and location of the model; no cache directory is used.
        self.model_name = model_name
        self.cache_dir = ""
        self._model_dir = model_dir
        self._model_file = model_file

        # Runtime settings.
        self.threads = threads
        self.providers = providers
        self.cuda = cuda
        self.device_ids = device_ids

        # An explicit device_id wins; otherwise fall back to the first listed device.
        if device_id is None and self.device_ids is not None:
            self.device_id = self.device_ids[0]
        else:
            self.device_id = device_id

        self.load_onnx_model(
            model_dir=self._model_dir,
            model_file=self._model_file,
            threads=self.threads,
            providers=self.providers,
        )
# model_name, model_path and model_file are placeholders that must be defined before this call.
model = OnnxTextEmbeddingLocal(
    model_name=model_name,
    model_dir=Path(model_path),
    model_file=model_file,
)
This functionality has been added as of fastembed v0.6.0 and can be used via the specific_model_path argument.
An example of using it:
from fastembed import TextEmbedding
# Build the embedder from the files in the local "my_model" directory.
emb = TextEmbedding(
    "sentence-transformers/all-MiniLM-l6-v2",
    specific_model_path="my_model",
)
# embed() output is materialized into a list before printing.
embeddings = list(emb.embed('single query'))
print(embeddings)
where my model structure is:
config.json
model.onnx
special_tokens_map.json
tokenizer.json
tokenizer_config.json
Closing this as completed, feel free to re-open if it does not work for you