
Models should not need to be re-loaded between back-to-back prompts

Open neilmehta24 opened this issue 1 year ago • 2 comments

When using mlx-vlm through the Python API, we need to call mlx_vlm.utils.load before every request to stream_generate. If we call stream_generate without reloading the model first, exceptions are raised. We see this across multiple VLM architectures; the exceptions differ from one architecture to another, but in each case some state appears not to be reset between calls.

neilmehta24 avatar Feb 21 '25 17:02 neilmehta24

Hey @neilmehta24

Thanks for reporting this!

Could you share a reproducible example?

Blaizzy avatar Feb 24 '25 11:02 Blaizzy

> The exceptions are different from one another between architectures, there is usually some state that is not being reset

I usually use stream_generate in the way you describe during development.


So I suspect either the KV cache or the increase in input size (i.e., the number of images). The former is easy to fix; the latter has limitations, because not all models support multiple images and/or multi-turn conversation.

Blaizzy avatar Feb 24 '25 11:02 Blaizzy

My LLM made this suggestion after digging around. It works for me: I can run multiple queries without reloading the model.

MLX-VLM Multi-Query Fix

Issue: MLX-VLM fails after the first `generate()` call with `ValueError: input operand has more dimensions than allowed by the axis remapping`.

Root Cause: In `mlx_vlm/models/llava_next/llava_next.py`, the `get_input_embeddings` method modifies `self.image_newline` in-place.

Lines 107-108 modify the array from 1D to 3D:

```python
self.image_newline = np.array(self.image_newline)[None, None, :]
self.image_newline = np.broadcast_to(self.image_newline, image_features.shape)
```

On subsequent calls, `self.image_newline` is already 3D, causing dimension mismatch errors.
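The failure mode can be reproduced with plain NumPy, independent of mlx-vlm (the array sizes below are made up for illustration):

```python
import numpy as np

# Stand-in for the model's 1D image_newline parameter and the
# image_features shape it gets broadcast against (sizes illustrative).
image_newline = np.ones(4)
features_shape = (2, 3, 4)

# First call: 1D -> (1, 1, 4) -> broadcast to (2, 3, 4). Works fine.
image_newline = np.array(image_newline)[None, None, :]
image_newline = np.broadcast_to(image_newline, features_shape)

# Second call: the attribute is already 3D, so [None, None, :] makes it
# 5D, and broadcasting 5D down to a 3D target shape raises ValueError.
try:
    again = np.array(image_newline)[None, None, :]  # shape (1, 1, 2, 3, 4)
    np.broadcast_to(again, features_shape)
    second_call_failed = False
except ValueError:
    second_call_failed = True
```

Running this, the second broadcast raises the same "input operand has more dimensions than allowed by the axis remapping" ValueError reported above.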

Fix: Save the original 1D state and restore it before each use:

```python
def fixed_get_input_embeddings(self, input_ids, pixel_values):
    # Save original state on first call
    if hasattr(self, 'image_newline') and not hasattr(self, '_original_image_newline'):
        self._original_image_newline = np.array(self.image_newline).copy()

    # Restore original state before each use
    if hasattr(self, '_original_image_newline'):
        self.image_newline = self._original_image_newline.copy()

    # Continue with original method
    return original_get_input_embeddings(self, input_ids, pixel_values)
```
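The save/restore idea can be exercised end-to-end on a small stand-in class. `DummyModel` below is a hypothetical mock (not the real LlavaNext model), with NumPy standing in for mlx arrays:

```python
import numpy as np

class DummyModel:
    """Hypothetical mock reproducing the buggy in-place overwrite."""
    def __init__(self):
        self.image_newline = np.ones(4)  # 1D parameter, as loaded from the checkpoint

    def get_input_embeddings(self, image_features):
        # Buggy pattern: the 1D attribute is overwritten with a 3D broadcast
        self.image_newline = np.array(self.image_newline)[None, None, :]
        self.image_newline = np.broadcast_to(self.image_newline, image_features.shape)
        return self.image_newline

def fixed_get_input_embeddings(self, image_features):
    # Save the original 1D state on the first call
    if not hasattr(self, '_original_image_newline'):
        self._original_image_newline = np.array(self.image_newline).copy()
    # Restore it before each use so the method always starts from 1D
    self.image_newline = self._original_image_newline.copy()
    return DummyModel.get_input_embeddings(self, image_features)

model = DummyModel()
features = np.zeros((2, 3, 4))
fixed_get_input_embeddings(model, features)
out = fixed_get_input_embeddings(model, features)  # second call no longer raises
```

Without the wrapper, the second call on `DummyModel` would fail exactly as described above; with it, both calls succeed.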

Result: Multiple `generate()` calls now work without model reloading, improving throughput by roughly 10x for batch processing in my runs.

Suggested permanent fix: Instead of modifying `self.image_newline` in-place, use local variables:

```python
# Better approach - don't modify self.image_newline
image_newline_expanded = np.array(self.image_newline)[None, None, :]
image_newline_broadcasted = np.broadcast_to(image_newline_expanded, image_features.shape)
```
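A quick sanity check of the local-variable approach on the same kind of stand-in (again a hypothetical mock, with NumPy in place of mlx arrays), showing the attribute stays 1D across calls:

```python
import numpy as np

class FixedModel:
    """Hypothetical mock: get_input_embeddings never reassigns self.image_newline."""
    def __init__(self):
        self.image_newline = np.ones(4)  # stays 1D for the model's lifetime

    def get_input_embeddings(self, image_features):
        # Local variables only; self.image_newline is read, never written
        image_newline_expanded = np.array(self.image_newline)[None, None, :]
        return np.broadcast_to(image_newline_expanded, image_features.shape)

model = FixedModel()
features = np.zeros((2, 3, 4))
first = model.get_input_embeddings(features)
second = model.get_input_embeddings(features)  # safe to call repeatedly
```

Because the method has no side effects on model state, it needs no save/restore bookkeeping at all, which is why it is the better permanent fix.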

billbarnes avatar Jun 16 '25 04:06 billbarnes