Qwen2-VL-7B-Instruct-4bit allocation crashes on larger dimension images
Qwen2-VL-7B-Instruct-4bit crashes with a memory allocation error on images with larger dimensions.
My machine: Apple M3 Pro, 36 GB RAM
The error reproduction below uses an image with dimensions 1978 × 2806. With an image of dimensions 1278 × 816, I do not see the crash.
Ideally I'm wondering if there is a way to do some combination of:
- Compare the size of the allocation that would be attempted to the maximum allowed buffer size, and if it is larger, raise an exception instead of attempting it and crashing the process
- Shrink the image so that it doesn't try to allocate over the maximum allowed buffer size (even if there is some loss in quality); a rough sketch of this option follows below
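For illustration, a minimal sketch of the shrink-before-trying idea using Pillow on the client side (the 1280-pixel cap is an arbitrary illustrative value, not something derived from the actual buffer limit):

from PIL import Image

def downscale_to_fit(path, max_dim=1280):
    # Shrink the image so its longest side is at most max_dim, preserving aspect ratio.
    img = Image.open(path)
    if max(img.size) > max_dim:
        img.thumbnail((max_dim, max_dim))  # resizes in place, keeps aspect ratio
        resized_path = path + ".resized.jpg"
        img.convert("RGB").save(resized_path)
        return resized_path
    return path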
Steps to reproduce
- Clone mlx-vlm at commit https://github.com/Blaizzy/mlx-vlm/commit/ae66c0b518e7851337c6ec2f76c637b9c4f3b11c
- Create a virtual environment
python -m venv myenv
source myenv/bin/activate
- Install dependencies
pip install -r requirements.txt
- Download https://huggingface.co/mlx-community/Qwen2-VL-7B-Instruct-4bit
- Run command:
-> % python -m mlx_vlm.generate --image '/Users/matt/Downloads/math-proof.jpg' --temp 0.0 --prompt "what is this" --model "/Users/matt/.cache/lm-studio/models/mlx-community/Qwen2-VL-7B-Instruct-4bit"
None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.
None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.
==========
Image: ['/Users/matt/Downloads/math-proof.jpg']
Prompt: <|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
what is this<|vision_start|><|image_pad|><|vision_end|><|im_end|>
<|im_start|>assistant
libc++abi: terminating due to uncaught exception of type std::runtime_error: Attempting to allocate 51619840000 bytes which is greater than the maximum allowed buffer size of 21743271936 bytes.
[1] 23291 abort python -m mlx_vlm.generate --image '/Users/matt/Downloads/math-proof.jpg' 0.
Image used for this test
Dimensions: 1978 × 2806
Thanks @mattjcly !
I have streamlined a way to resize images here: #83.
Now, regarding the buffer size: do you have a suggested default you would like to use?
Or would just limiting the size of the image to at most X based on the spec help?
🙌🙌 Thanks for the quick response @Blaizzy
Or would just limiting the size of the image to at most X based on the spec help?
Yes, something to this effect.
I think the following example client flow could be desirable (if possible, of course):
1. I call generate or generate_step with an image of size 1080x1080 (the full size of the image)
2. This raises an exception that lets me know the attempted allocation is too large for the machine, instead of the process crashing
3. Then either:
- 3a. I can then pass in a smaller resize shape of 512x512 through the API you've provided. If it succeeds, great. If not, I get another exception and can continue to resize smaller until it succeeds (a rough sketch of this retry loop is below)

OR

- 3b. I can pass in a memory size that is my upper bound (e.g., 27GB), and mlx_vlm internally resizes the image so that it will fit
What do you think? I'm not exactly sure how easy it is to get step 2 working. I'd also imagine 3a to be an easier implementation than 3b.
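For illustration, a rough sketch of what the 3a client loop might look like, assuming generate grew a resize_shape parameter (hypothetical) and the allocation failure surfaced as a catchable MemoryError instead of an abort (neither is true today):

# Hypothetical 3a flow: retry with progressively smaller resize shapes until one fits.
# resize_shape is a made-up keyword; model/processor/image/formatted_prompt are assumed
# to be loaded as in the usual mlx_vlm example.
output = None
for shape in [(1080, 1080), (512, 512), (256, 256)]:
    try:
        output = generate(model, processor, image, formatted_prompt, resize_shape=shape)
        break
    except MemoryError:  # today this is an uncatchable C++ abort, so this is aspirational
        continue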
I'm not sure about step 2 either.
Let me check.
Could you try to run the model in a Python file like this:
import sys

def main():
    try:
        import mlx.core as mx
        from mlx_vlm import load, generate
        from mlx_vlm.prompt_utils import apply_chat_template
        from mlx_vlm.utils import load_config

        # Load the model
        model_path = "mlx-community/Qwen2-VL-7B-Instruct-4bit"
        model, processor = load(model_path)
        config = load_config(model_path)

        # Prepare input
        image = ["http://images.cocodataset.org/val2017/000000039769.jpg"]
        prompt = "Describe this image."

        # Apply chat template
        formatted_prompt = apply_chat_template(
            processor, config, prompt, num_images=len(image)
        )

        # Generate output
        output = generate(model, processor, image, formatted_prompt, verbose=False)
        print(output)

    # This is a placeholder to simulate the error:
    except MemoryError as e:
        print(f"Memory allocation error: {e}", file=sys.stderr)
        print("The program attempted to allocate more memory than available or allowed.", file=sys.stderr)
        print("Consider reducing the size of your input or using a machine with more memory.", file=sys.stderr)
        sys.exit(1)
    except Exception as e:
        print(f"An unexpected error occurred: {e}", file=sys.stderr)
        sys.exit(1)

if __name__ == "__main__":
    main()
Unfortunately the above still results in the libc++abi: terminating due to uncaught exception of type std::runtime_error: Attempting to allocate 51619840000 bytes which is greater than the maximum allowed buffer size of 21743271936 bytes. crash :(
From limited research, it seems like the crash comes from a C++ error that isn't caught and propagated to Python properly, so generate never even returns.
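One pre-emptive workaround might be to estimate the allocation size up front and compare it against the device limit before calling generate. A minimal sketch, assuming recent MLX versions expose the limit via mx.metal.device_info() (the byte estimate itself would have to come from the image/token shapes and is not shown here):

import mlx.core as mx

def check_allocation(estimated_bytes):
    # Raise a normal Python exception instead of letting Metal abort the process.
    info = mx.metal.device_info()  # dict of device limits on recent MLX versions
    max_buffer = info.get("max_buffer_length", 0)
    if max_buffer and estimated_bytes > max_buffer:
        raise MemoryError(
            f"Would allocate {estimated_bytes} bytes, but the maximum allowed "
            f"buffer size is {max_buffer} bytes"
        )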
Interestingly enough though, I was able to avoid the attempted large allocation at https://github.com/Blaizzy/mlx-vlm/blob/d4b562f221f510063bd9add2334d8ee2e3b2192f/mlx_vlm/models/qwen2_vl/qwen2_vl.py#L70-L78 with the following (not polished) batched implementation of _merge_input_ids_with_image_features:
def _merge_input_ids_with_image_features(
    self, image_features, inputs_embeds, input_ids, batch_size=100
):
    image_token_index = self.config.image_token_index
    total_length = input_ids.shape[1]
    num_batches = (total_length + batch_size - 1) // batch_size

    image_features_start_idx = 0
    image_features_end_idx = 0
    for i in range(num_batches):
        start_idx = i * batch_size
        end_idx = min((i + 1) * batch_size, total_length)

        # Slicing the current batch
        batch_input_ids = input_ids[:, start_idx:end_idx]
        batch_image_positions = batch_input_ids == image_token_index

        # Find indices where image tokens are present
        indices_to_replace = np.where(batch_image_positions.flatten())[0]
        image_features_end_idx = image_features_start_idx + indices_to_replace.size

        # Get corresponding image features for the batch
        batch_image_features = image_features[
            :, image_features_start_idx:image_features_end_idx, :
        ]
        batch_inputs_embeds = inputs_embeds[:, start_idx:end_idx, :]

        # Ensure the assignment is only done for matching shapes
        image_features_idx = 0
        for idx in indices_to_replace:
            idx = int(idx)
            batch_inputs_embeds[:, idx, :] = batch_image_features[:, image_features_idx, :]
            image_features_idx += 1

        image_features_start_idx = image_features_end_idx

        # Update the main embeddings array
        inputs_embeds[:, start_idx:end_idx, :] = batch_inputs_embeds

    return inputs_embeds
to effectively break down whatever large allocation numpy attempts in lines:
inputs_embeds = np.array(inputs_embeds.astype(mx.float32))
inputs_embeds[image_positions] = image_features
But then a similar allocation attempt occurs in the first mx.async_eval(y) of generate_step and causes the same crash: libc++abi: terminating due to uncaught exception of type std::runtime_error: Attempting to allocate 51619840000 bytes which is greater than the maximum allowed buffer size of 21743271936 bytes.
I think my personal takeaways are:
- Seems like catching this error goes deeper than the code in this repo
- Maybe there's a further possibility for batched mx evaluation that could reduce the required memory footprint to make this inference attempt possible on my machine, given that there was a way to batch and get around the original crashing line?
Seems totally reasonable to me to leave this scope out of https://github.com/Blaizzy/mlx-vlm/pull/83 though
Interesting!
Thanks for the update.
I have one question: are you making requests in batch? If so, what is the use case?
Are you making requests in batch? If so, what is the use case?
Not currently making requests in batch!
Sorry, I could have expressed my thoughts around the "batching" thing more clearly. It's more about the potential for breaking down large allocations into smaller batches, so that a single large request can succeed. My chain of thought:
1. On my 36GB RAM M3 Pro, mlx_vlm.generate (as-is) crashes with Qwen2-VL-7B-Instruct-4bit and the provided image because it tries to allocate too much memory within the _merge_input_ids_with_image_features function
2. My (rough) alternative batched implementation of _merge_input_ids_with_image_features in https://github.com/Blaizzy/mlx-vlm/issues/79#issuecomment-2411147795 is able to get around this large allocation, so that _merge_input_ids_with_image_features no longer crashes
3. However, even with my alternative batched implementation in (2), the first mx.async_eval(y) of generate_step crashes because it tries to allocate too much memory for my machine
4. If there was an alternative implementation of _merge_input_ids_with_image_features that performed the computations in a way that requires less memory, avoiding the crash, I'm wondering if this approach could be extended to the first mx.async_eval(y) of generate_step (and any other places that may attempt large all-at-once allocations), allowing mlx_vlm.generate to use less memory to the point where mlx_vlm.generate with Qwen2-VL-7B-Instruct-4bit and the provided image can successfully run inference on my 36GB RAM M3 Pro
Does that clarify things? Maybe it's not possible for some reason that I do not yet know of; just an idea.
Could you share the method / reproducible example you used to identify _merge_input_ids_with_image_features and async_eval as the culprits?
Certainly!
If I add:
import faulthandler
faulthandler.enable()
at the top of generate.py, and run:
python -m mlx_vlm.generate --image '/Users/matt/Downloads/math-proof.jpg' --temp 0.0 --prompt "what is this" --model "/Users/matt/.cache/lm-studio/models/mlx-community/Qwen2-VL-7B-Instruct-4bit"
I see the following. With mlx-vlm https://github.com/Blaizzy/mlx-vlm/commit/ae66c0b518e7851337c6ec2f76c637b9c4f3b11c as-is:
libc++abi: terminating due to uncaught exception of type std::runtime_error: Attempting to allocate 51619840000 bytes which is greater than the maximum allowed buffer size of 21743271936 bytes.
Fatal Python error: Aborted
Thread 0x00000001e136dc40 (most recent call first):
File "/Users/matt/Workspace/mlx-vlm/mlx_vlm/models/qwen2_vl/qwen2_vl.py", line 81 in _merge_input_ids_with_image_features
File "/Users/matt/Workspace/mlx-vlm/mlx_vlm/models/qwen2_vl/qwen2_vl.py", line 68 in get_input_embeddings
File "/Users/matt/Workspace/mlx-vlm/mlx_vlm/models/qwen2_vl/qwen2_vl.py", line 97 in __call__
File "/Users/matt/Workspace/mlx-vlm/mlx_vlm/utils.py", line 888 in generate_step
File "/Users/matt/Workspace/mlx-vlm/mlx_vlm/utils.py", line 1021 in generate
File "/Users/matt/Workspace/mlx-vlm/mlx_vlm/generate.py", line 84 in main
File "/Users/matt/Workspace/mlx-vlm/mlx_vlm/generate.py", line 99 in <module>
File "/opt/homebrew/Cellar/[email protected]/3.10.14_1/Frameworks/Python.framework/Versions/3.10/lib/python3.10/runpy.py", line 86 in _run_code
File "/opt/homebrew/Cellar/[email protected]/3.10.14_1/Frameworks/Python.framework/Versions/3.10/lib/python3.10/runpy.py", line 196 in _run_module_as_main
With the addition of my rough batched implementation of _merge_input_ids_with_image_features AND changing https://github.com/Blaizzy/mlx-vlm/blob/main/mlx_vlm/models/qwen2_vl/language.py#L185 to use mx.tile instead of np.tile (or else it crashes there too):
libc++abi: terminating due to uncaught exception of type std::runtime_error: Attempting to allocate 51619840000 bytes which is greater than the maximum allowed buffer size of 21743271936 bytes.
Fatal Python error: Aborted
Thread 0x00000001e136dc40 (most recent call first):
File "/Users/matt/Workspace/mlx-vlm/mlx_vlm/utils.py", line 891 in generate_step
File "/Users/matt/Workspace/mlx-vlm/mlx_vlm/utils.py", line 1021 in generate
File "/Users/matt/Workspace/mlx-vlm/mlx_vlm/generate.py", line 85 in main
File "/Users/matt/Workspace/mlx-vlm/mlx_vlm/generate.py", line 100 in <module>
File "/opt/homebrew/Cellar/[email protected]/3.10.14_1/Frameworks/Python.framework/Versions/3.10/lib/python3.10/runpy.py", line 86 in _run_code
File "/opt/homebrew/Cellar/[email protected]/3.10.14_1/Frameworks/Python.framework/Versions/3.10/lib/python3.10/runpy.py", line 196 in _run_module_as_main
(Please feel free to let me know if you know of better ways to trace)
Seems like there is some relationship between large allocations and mlx->np translations, but then once it goes into mx.async_eval(y) territory I'm afraid I'm no longer sure how to make modifications (at the moment)
I have the same issue. When I use a picture with a resolution of 4032 × 3024, it gives me the following error:
libc++abi: terminating due to uncaught exception of type std::runtime_error: Attempting to allocate 247669456896 bytes which is greater than the maximum allowed buffer size of 21743271936 bytes. zsh: abort python qwen2-vl.py
247669456896 bytes is about 247 GB. I don't think processing this single image should consume that much memory. There must be something wrong with the calculation.
I am using the mlx-vlm version 0.1.0.
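For scale, a quick back-of-the-envelope in Python shows how far the attempted allocation is from the raw pixel data (just arithmetic, no mlx involved):

# Raw RGB float32 pixel data for a 4032 x 3024 image
pixels = 4032 * 3024                     # 12,192,768 pixels
raw_bytes = pixels * 3 * 4               # ~146 MB
attempted = 247_669_456_896              # the failed allocation from the error message
print(raw_bytes, attempted / raw_bytes)  # the allocation is roughly 1700x the raw image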
Ohh yeah, I found the issue here #84.
Upgrade your MLX version to the latest and let me know if it solves it.
pip install -U mlx
And yeah, that large of an image is not a good idea.
Try passing --resize-shape to generate.py and use a smaller image
It will be faster and less resource intensive
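For example, something along these lines (check python -m mlx_vlm.generate --help for the exact --resize-shape argument format; the 1024 values are just an illustration):

python -m mlx_vlm.generate --image '/Users/matt/Downloads/math-proof.jpg' --temp 0.0 --prompt "what is this" --model mlx-community/Qwen2-VL-7B-Instruct-4bit --resize-shape 1024 1024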
And yeah, that large of an image is not a good idea.
Try passing --resize-shape to generate.py and use a smaller image. It will be faster and less resource intensive.
The mlx version is 0.19.0, which is the latest as of this writing. The --resize-shape option works perfectly, thank you @Blaizzy!
My pleasure!