Make it possible to load SigLIP models from local files

Read the image and patch size if supplied in the `model_config` arg.

What is the context for the regex parsing of the repo name? I'm guessing the img/patch size isn't always correct in the config.json file? Anyway, this small change makes it possible to load a local model while being offline:
```python
local_model_dir_path = "/Users/maxlund/mlx-models/mlx-siglip-large-384"
model, processor = load(
    path_or_hf_repo=local_model_dir_path,
    model_config={"image_size": 384, "patch_size": 16},
)
```
FWIW, the image and patch size seemed to be correct in the downloaded config.json for both mlx-community/siglip-large-patch16-384 and mlx-community/siglip-so400m-patch14-384.
Hey @maxlund
Thanks for the PR!
The context is that certain models don't supply the patch and image size in config.json; I can only find them in the repo name.

Besides torch, which I will address today, are you having trouble with any Siglip model in particular?
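For reference, a minimal sketch of what parsing the sizes out of a repo name could look like. The pattern and helper name here are assumptions for illustration, not the actual mlx-embeddings code:

```python
import re

def sizes_from_repo_name(repo: str):
    """Best-effort parse of patch and image size from repo names like
    'mlx-community/siglip-large-patch16-384'.
    Returns (patch_size, image_size), or None if the name doesn't encode them."""
    m = re.search(r"patch(\d+)-(\d+)", repo)
    if m is None:
        return None
    return int(m.group(1)), int(m.group(2))

print(sizes_from_repo_name("mlx-community/siglip-large-patch16-384"))   # (16, 384)
print(sizes_from_repo_name("mlx-community/siglip-so400m-patch14-384"))  # (14, 384)
```

Names that don't follow the `patchN-M` convention (like a renamed local directory) return None, which is exactly the case where the `model_config` override in this PR helps.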
Hey, no problem. I'm messing around with it now and running into some issues. This seems to work fine and gives me embeddings for both text and images, but I want to extract them in separate steps of my pipeline:
```python
from mlx_embeddings.utils import load, generate
import requests
from PIL import Image

# Load vision model and processor
model, processor = load("mlx-community/siglip-large-patch16-384", {"num_classes": 0})

# Load multiple images
image_urls = [
    "./images/cats.jpg",           # cats
    "./images/desktop_setup.png",  # desktop setup
]
images = [
    Image.open(requests.get(url, stream=True).raw) if url.startswith("http") else Image.open(url)
    for url in image_urls
]

# Text descriptions
texts = ["a photo of cats", "a photo of a desktop setup", "a photo of a person"]

outputs = generate(model, processor, texts=texts, images=images)
```
This:

```python
outputs = generate(model, processor, texts=texts, images=None)
```

gives me:

```
(<class 'AttributeError'>, AttributeError("'SiglipProcessor' object has no attribute 'batch_encode_plus'"), <traceback object at 0x136cc5f00>)
```
So I tried get_text_features and get_image_features using:

```python
inputs_text = processor(text=texts, images=None, padding="max_length", return_tensors="pt")
inputs_imgs = processor(text=None, images=images, return_tensors="pt")
input_ids = mx.array(inputs_text.input_ids)
pixel_values = mx.array(inputs_imgs.pixel_values)
```

but ran into other issues. Just about to have lunch, back in a bit, and I can give more details.
Okay, some progress:

```python
import mlx.core as mx
from mlx_embeddings.utils import load, generate
import requests
from PIL import Image

model, processor = load("mlx-community/siglip-large-patch16-384", {"num_classes": 0})

image_urls = [
    "./images/cats.jpg",           # cats
    "./images/desktop_setup.png",  # desktop setup
]
images = [
    Image.open(requests.get(url, stream=True).raw) if url.startswith("http") else Image.open(url)
    for url in image_urls
]
texts = "a sentence"

inputs_text = processor(text=texts, images=None, padding="max_length", return_tensors="pt")
inputs_imgs = processor(text=None, images=images, return_tensors="pt")
input_ids = mx.array(inputs_text.input_ids)
pixel_values = mx.array(inputs_imgs.pixel_values)
print(f"{input_ids.shape=}")
print(f"{pixel_values.shape=}")

try:
    text_embs = model.get_text_features(input_ids=input_ids)
    print(f"{type(text_embs)}")
    print(f"{type(text_embs.shape)}")
    print(text_embs)
except Exception as e:
    print(f"model.get_text_features(input_ids=input_ids) error: {e}")

try:
    img_embs = model.get_image_features(pixel_values=pixel_values)
    print(f"{type(img_embs)}")
    print(f"{type(img_embs.shape)}")
except Exception as e:
    print(f"model.get_image_features(pixel_values=pixel_values) error: {e}")

# input_ids.shape=(1, 64)
# pixel_values.shape=(2, 3, 384, 384)
# <class 'mlx.core.array'>
# <class 'tuple'>
# array([[-0.580078, -0.153076, -0.0585327, ..., 0.469727, 0.0390015, 0.192871]], dtype=float16)
# model.get_image_features(pixel_values=pixel_values) error: 'ModelArgs' object has no attribute 'use_return_dict'
```
Passing return_dict=False gets past that, but then hits a shape mismatch:

```python
img_embs = model.get_image_features(pixel_values=pixel_values, return_dict=False)
# model.get_image_features(pixel_values=pixel_values) error: [conv] Expect the input channels
# in the input and weight array to match but got shapes - input: (2,3,384,384) and weight: (1024,16,16,3)
```
Okay, this did the trick, I think:

```python
dtype = model.vision_model.vision_model.embeddings.patch_embedding.weight.dtype
img_embs = model.get_image_features(
    pixel_values=pixel_values.transpose(0, 2, 3, 1).astype(dtype),
    return_dict=False,
)
print(f"{type(img_embs)=}")
print(f"{img_embs.shape=}")
# type(img_embs)=<class 'mlx.core.array'>
# img_embs.shape=(2, 1024)
```
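In case it helps others reading along: the transpose is needed because the HF processor's "pt" tensors are channels-first (NCHW), while MLX convolutions take channels-last input (the weight shape (1024, 16, 16, 3) in the error above has channels last). A throwaway sketch of where each axis lands under transpose(0, 2, 3, 1); the helper name is just for illustration:

```python
def to_nhwc_shape(nchw):
    """Permute an NCHW shape tuple to NHWC, mirroring what
    pixel_values.transpose(0, 2, 3, 1) does to the array itself."""
    n, c, h, w = nchw
    return (n, h, w, c)

# The processor's (batch, channels, height, width) becomes MLX's (batch, height, width, channels)
print(to_nhwc_shape((2, 3, 384, 384)))  # (2, 384, 384, 3)
```

The astype(dtype) cast matters too, since the float32 tensors from the processor otherwise won't match the model's float16 weights.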
I might be able to get some benchmarks soon if there are no other road bumps.