
Adds device (cuda) option for TransformerEmbeddingFunction

Open spartanhaden opened this pull request 2 years ago • 2 comments

Description of changes

Summarize the changes made by this PR.

  • Improvements & Bug fixes
    • N/A
  • New functionality
    • Adds option to pass compute device (e.g., "cpu", "cuda", "cuda:1") to the SentenceTransformerEmbeddingFunction class in chromadb/utils/embedding_functions.py.
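
For reference, a minimal usage sketch of the new parameter (the constructor arguments follow the __init__ shown later in this thread; the call-with-a-list-of-documents form is assumed from Chroma's embedding function interface at the time):

from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction

# Default: no device argument, behaves as before (CPU)
ef_cpu = SentenceTransformerEmbeddingFunction(model_name="all-MiniLM-L6-v2")

# Explicitly target the first or second GPU
ef_cuda = SentenceTransformerEmbeddingFunction(model_name="all-MiniLM-L6-v2", device="cuda")
ef_cuda1 = SentenceTransformerEmbeddingFunction(model_name="all-MiniLM-L6-v2", device="cuda:1")

embeddings = ef_cuda(["an example document to embed"])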

Test plan

How are these changes tested? I ran with no device specified and it used the CPU as before. I also ran with "cuda" and then with "cuda:1", and it ran on my first and second GPU respectively.

Documentation Changes

Are all docstrings for user-facing APIs updated if required? Do we need to make documentation changes in the docs repository?

I think all the docstrings should be fine, but it might be worth noting in the docs that you can use the GPU here, like the page already mentions for the Instructor models further down.

spartanhaden avatar Apr 12 '23 03:04 spartanhaden

@spartanhaden this is a nice PR, thanks!

Do you think we should proactively add any other inputs?

https://www.sbert.net/docs/package_reference/SentenceTransformer.html#sentence_transformers.SentenceTransformer

jeffchuber avatar Apr 12 '23 03:04 jeffchuber

I think that should be fine. I ended up using the InstructorEmbeddingFunction instead anyway, since that model seems to perform better according to some benchmarks you refer to here. We could actually auto-detect the GPU by default with something like the snippet below, which still lets the user pick a device if they want. torch is already imported by SentenceTransformer here, so importing it inside this function doesn't add a new dependency.

from typing import Optional

def __init__(self, model_name: str = "all-MiniLM-L6-v2", device: Optional[str] = None):
    try:
        from sentence_transformers import SentenceTransformer
    except ImportError:
        raise ValueError(
            "The sentence_transformers python package is not installed. Please install it with `pip install sentence_transformers`"
        )
    # torch is a dependency of sentence_transformers, so this import adds nothing new
    import torch

    # Automatically use CUDA if it's available, unless the caller picked a device
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    self._model = SentenceTransformer(model_name, device=device)
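
With that default in place, a bare constructor call would pick the GPU automatically when one is visible, while an explicit device string still overrides it (a sketch, assuming the __init__ above):

ef = SentenceTransformerEmbeddingFunction()                      # "cuda" if torch sees a GPU, else "cpu"
ef_pinned = SentenceTransformerEmbeddingFunction(device="cpu")   # force CPU regardless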

spartanhaden avatar Apr 12 '23 04:04 spartanhaden

Note to self: resolve merge conflict and then merge

jeffchuber avatar May 11 '23 17:05 jeffchuber