
Make it possible to load siglip models from local files

Open · maxlund opened this issue 8 months ago · 5 comments

Read the image and patch size from the model_config argument, if supplied.

What is the context for parsing the image/patch size from the repo name with a regex? I guess the values in config.json aren't always correct? Anyway, this small change makes it possible to load a local model while offline:

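from mlx_embeddings.utils import load

local_model_dir_path = "/Users/maxlund/mlx-models/mlx-siglip-large-384"
# Pass the sizes explicitly instead of relying on them being parsed from the repo name
model, processor = load(
    path_or_hf_repo=local_model_dir_path,
    model_config={"image_size": 384, "patch_size": 16},
)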
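A minimal sketch of the precedence this change implies, assuming the sizes otherwise come from a vision_config dict in config.json (the key layout Hugging Face uses for SigLIP configs) and that explicit model_config values win. `resolve_vision_sizes` is a hypothetical helper, not part of mlx-embeddings.

```python
from typing import Any, Dict, Optional, Tuple


def resolve_vision_sizes(
    config: Dict[str, Any],
    model_config: Optional[Dict[str, Any]] = None,
) -> Tuple[Optional[int], Optional[int]]:
    """Hypothetical helper: return (image_size, patch_size).

    Explicit values in model_config take precedence over those found in
    config["vision_config"]; missing values come back as None so the caller
    can fall back to another source (e.g. the repo name).
    """
    overrides = model_config or {}
    vision = config.get("vision_config", {})
    image_size = overrides.get("image_size", vision.get("image_size"))
    patch_size = overrides.get("patch_size", vision.get("patch_size"))
    return image_size, patch_size


# Sizes passed explicitly, as in the load() call above.
print(resolve_vision_sizes({}, {"image_size": 384, "patch_size": 16}))  # (384, 16)
```

Returning None rather than raising leaves the decision about further fallbacks to the caller.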

FWIW, the image and patch sizes in the downloaded config.json looked correct for both mlx-community/siglip-large-patch16-384 and mlx-community/siglip-so400m-patch14-384 on the HF Hub.
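For the two repos mentioned, the sizes determine the patch grid. A ViT-style patch embedding (a convolution whose stride equals its kernel size) yields floor(image_size / patch_size) patches per side, and in SigLIP-style vision towers the position-embedding table has one entry per patch, so a wrong patch size would not match the checkpoint. A quick check with a hypothetical `patch_grid` helper:

```python
from typing import Tuple


def patch_grid(image_size: int, patch_size: int) -> Tuple[int, int]:
    """Patches per side and total patch count for a stride == kernel conv."""
    per_side = image_size // patch_size  # floor division: leftover pixels are dropped
    return per_side, per_side * per_side


print(patch_grid(384, 16))  # siglip-large-patch16-384: (24, 576)
print(patch_grid(384, 14))  # siglip-so400m-patch14-384: (27, 729)
```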

maxlund · Apr 05 '25 08:04

Hey @maxlund

Thanks for the PR!

The context is that some models don't include the patch and image size in config.json; I can only find them in the model name.
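A minimal sketch of that kind of name-based fallback, assuming names follow the patch{P}-{S} pattern seen in siglip-large-patch16-384. The function name and regex are illustrative, not the library's actual code. A local directory such as mlx-siglip-large-384 carries no patch size in its name, so this kind of parsing returns nothing for it.

```python
import re
from typing import Optional, Tuple

# Matches e.g. "patch16-384" -> patch_size=16, image_size=384.
_PATCH_RE = re.compile(r"patch(\d+)-(\d+)")


def sizes_from_name(name: str) -> Optional[Tuple[int, int]]:
    """Hypothetical helper: return (image_size, patch_size) parsed from a
    repo id or path, or None if the name doesn't encode both values."""
    match = _PATCH_RE.search(name)
    if match is None:
        return None
    patch_size, image_size = int(match.group(1)), int(match.group(2))
    return image_size, patch_size


print(sizes_from_name("mlx-community/siglip-so400m-patch14-384"))  # (384, 14)
print(sizes_from_name("/Users/maxlund/mlx-models/mlx-siglip-large-384"))  # None
```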

Besides torch (which I will address today), are you having trouble with any SigLIP model in particular?

Blaizzy · Apr 05 '25 08:04

Hey, no problem! I'm experimenting with it now and running into some issues. The following works fine and gives me embeddings for both text and images, but I want to extract them in separate steps of my pipeline:

from mlx_embeddings.utils import load, generate
import requests
from PIL import Image

# Load vision model and processor
model, processor = load("mlx-community/siglip-large-patch16-384", {"num_classes": 0})

# Load multiple images
image_urls = [
    "./images/cats.jpg",  # cats
    "./images/desktop_setup.png"  # desktop setup
]
images = [
    Image.open(requests.get(url, stream=True).raw) if url.startswith("http") else Image.open(url)
    for url in image_urls
]

# Text descriptions
texts = ["a photo of cats", "a photo of a desktop setup", "a photo of a person"]

outputs = generate(model, processor, texts=texts, images=images)

However, passing text only:

outputs = generate(model, processor, texts=texts, images=None)

fails with:

(<class 'AttributeError'>, AttributeError("'SiglipProcessor' object has no attribute 'batch_encode_plus'"), <traceback object at 0x136cc5f00>)
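The error suggests generate calls batch_encode_plus on whatever it receives as processor. In Hugging Face, batch_encode_plus is a tokenizer method, and processors such as SiglipProcessor wrap a tokenizer in a .tokenizer attribute. One possible fix is to fall back to that attribute; the sketch below uses stand-in classes, `text_encoder_for` is hypothetical, and this thread doesn't say how mlx-embeddings actually resolved it.

```python
from typing import Any


def text_encoder_for(processor: Any) -> Any:
    """Hypothetical helper: return an object that has batch_encode_plus.

    Tokenizers expose it directly; multimodal processors usually wrap a
    tokenizer in a `.tokenizer` attribute, so fall back to that.
    """
    if hasattr(processor, "batch_encode_plus"):
        return processor
    tokenizer = getattr(processor, "tokenizer", None)
    if tokenizer is not None and hasattr(tokenizer, "batch_encode_plus"):
        return tokenizer
    raise TypeError(f"{type(processor).__name__} has no usable tokenizer")


# Stand-ins for illustration only (not the real transformers classes).
class FakeTokenizer:
    def batch_encode_plus(self, texts, **kwargs):
        return {"input_ids": [[len(t)] for t in texts]}


class FakeProcessor:
    def __init__(self):
        self.tokenizer = FakeTokenizer()


encoder = text_encoder_for(FakeProcessor())
print(encoder.batch_encode_plus(["a photo of cats"]))  # {'input_ids': [[15]]}
```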

I also tried get_text_features and get_image_features directly, preparing the inputs with:

import mlx.core as mx

inputs_text = processor(text=texts, images=None, padding="max_length", return_tensors="pt")
inputs_imgs = processor(text=None, images=images, return_tensors="pt")
input_ids = mx.array(inputs_text.input_ids)
pixel_values = mx.array(inputs_imgs.pixel_values)

This ran into other issues, though. I'm about to have lunch; I'll be back in a bit with more details.

maxlund · Apr 05 '25 11:04

Okay, some progress:

import mlx.core as mx
from mlx_embeddings.utils import load, generate
import requests
from PIL import Image

model, processor = load("mlx-community/siglip-large-patch16-384", {"num_classes": 0})
image_urls = [
    "./images/cats.jpg",  # cats
    "./images/desktop_setup.png"  # desktop setup
]
images = [
    Image.open(requests.get(url, stream=True).raw) if url.startswith("http") else Image.open(url)
    for url in image_urls
]

texts = "a sentence"
inputs_text = processor(text=texts, images=None, padding="max_length", return_tensors="pt")
inputs_imgs = processor(text=None, images=images, return_tensors="pt")
input_ids = mx.array(inputs_text.input_ids)
pixel_values = mx.array(inputs_imgs.pixel_values)
print(f"{input_ids.shape=}")
print(f"{pixel_values.shape=}")
try:
    text_embs = model.get_text_features(input_ids=input_ids)
    print(f"{type(text_embs)}")
    print(f"{type(text_embs.shape)}")
    print(text_embs)
except Exception as e:
    print(f"model.get_text_features(input_ids=input_ids) error: {e}")

try:
    img_embs = model.get_image_features(pixel_values=pixel_values)
    print(f"{type(img_embs)}")
    print(f"{type(img_embs.shape)}")
except Exception as e:
    print(f"model.get_image_features(pixel_values=pixel_values) error: {e}")
    
#input_ids.shape=(1, 64)
#pixel_values.shape=(2, 3, 384, 384)
#<class 'mlx.core.array'>
#<class 'tuple'>
#array([[-0.580078, -0.153076, -0.0585327, ..., 0.469727, 0.0390015, 0.192871]], dtype=float16)
#model.get_image_features(pixel_values=pixel_values) error: 'ModelArgs' object has no attribute 'use_return_dict'
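That AttributeError is typical of code ported from Hugging Face, where use_return_dict is a property of the model config; a plain args dataclass won't have it unless it's added. A minimal sketch of a defensive default, with a hypothetical `ModelArgs` and `resolve_return_dict` standing in for the real ones. Passing return_dict explicitly skips the config lookup entirely.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ModelArgs:
    # Hypothetical stand-in for the library's args class; note there is
    # no use_return_dict field, mirroring the error above.
    hidden_size: int = 1024


def resolve_return_dict(config: ModelArgs, return_dict: Optional[bool]) -> bool:
    """Use the caller's choice if given; otherwise read the config
    attribute if it exists, defaulting to False instead of raising."""
    if return_dict is not None:
        return return_dict
    return getattr(config, "use_return_dict", False)


args = ModelArgs()
print(resolve_return_dict(args, None))  # False (no AttributeError)
print(resolve_return_dict(args, True))  # True
```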

maxlund · Apr 05 '25 12:04

Passing return_dict=False gets past that, but the convolution then fails:

img_embs = model.get_image_features(pixel_values=pixel_values, return_dict=False)

# model.get_image_features(pixel_values=pixel_values) error: [conv] Expect the input channels in the input and weight array to match but got shapes - input: (2,3,384,384) and weight: (1024,16,16,3)
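The shapes in that message explain the failure: MLX convolutions are channels-last, so the input is expected as (N, H, W, C) and the weight is laid out as (out_channels, kH, kW, in_channels) = (1024, 16, 16, 3). The processor's PyTorch-style output is channels-first (N, C, H, W), so the trailing input dimension is 384 rather than 3. The fix is a transpose with axes (0, 2, 3, 1), shown here with NumPy on a zero array of the same shape; the mx.array transpose call in this thread takes the same axes.

```python
import numpy as np

# Stand-in for the processor output: batch of 2 RGB images, channels-first.
pixel_values_nchw = np.zeros((2, 3, 384, 384), dtype=np.float32)

# Move channels to the last axis: (N, C, H, W) -> (N, H, W, C).
pixel_values_nhwc = pixel_values_nchw.transpose(0, 2, 3, 1)

# Shape of the patch-embedding weight from the error: (C_out, kH, kW, C_in).
weight_shape = (1024, 16, 16, 3)

print(pixel_values_nhwc.shape)  # (2, 384, 384, 3)
print(pixel_values_nhwc.shape[-1] == weight_shape[-1])  # True: input channels match
```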

maxlund · Apr 05 '25 12:04

Okay, I think this did the trick:

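# MLX convolutions expect channels-last input, so transpose (N, C, H, W) -> (N, H, W, C)
# and cast to the patch-embedding weight's dtype
dtype = model.vision_model.vision_model.embeddings.patch_embedding.weight.dtype
img_embs = model.get_image_features(
    pixel_values=pixel_values.transpose(0, 2, 3, 1).astype(dtype), return_dict=False
)
print(f"{type(img_embs)=}")
print(f"{img_embs.shape=}")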

# type(img_embs)=<class 'mlx.core.array'>
# img_embs.shape=(2, 1024)
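With text and image features extracted in separate steps, they can be compared later in the pipeline. A minimal sketch, assuming cosine similarity is enough for ranking: L2-normalize each embedding and take dot products. SigLIP's own image-text score additionally applies a learned scale and bias followed by a sigmoid, which this sketch omits. Toy NumPy vectors stand in for the real text_embs and img_embs arrays (MLX arrays convert with np.array(...)).

```python
import numpy as np


def cosine_similarity(text_embs: np.ndarray, img_embs: np.ndarray) -> np.ndarray:
    """Return a (num_texts, num_images) matrix of cosine similarities."""
    text_unit = text_embs / np.linalg.norm(text_embs, axis=-1, keepdims=True)
    img_unit = img_embs / np.linalg.norm(img_embs, axis=-1, keepdims=True)
    return text_unit @ img_unit.T


# Toy 3-d vectors standing in for the real (N, 1024) embeddings.
text_embs = np.array([[1.0, 0.0, 0.0], [0.0, 2.0, 0.0]])
img_embs = np.array([[3.0, 0.0, 0.0], [0.0, 1.0, 0.0]])

sims = cosine_similarity(text_embs, img_embs)
print(sims)
print(sims.argmax(axis=1))  # best-matching image index per text: [0 1]
```

Normalizing first makes the result independent of embedding magnitude, so features computed at different times can be compared directly.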

I might be able to share some benchmarks soon if there are no other roadblocks.

maxlund · Apr 05 '25 12:04