
Poor Speculative Decoding Performance on M2 Ultra

Open mattjcly opened this issue 10 months ago • 4 comments

Speculative decoding does not seem to improve generation speed as expected on an M2 Ultra Mac Studio (128 GB).

Main model: https://huggingface.co/lmstudio-community/Qwen2.5-Coder-32B-Instruct-MLX-4bit
Draft model: https://huggingface.co/lmstudio-community/Qwen2.5-Coder-0.5B-Instruct-MLX-4bit or https://huggingface.co/mlx-community/Qwen2.5-0.5B-Instruct-4bit

Prompt: "Write a quicksort algorithm"

Without speculative decoding: 29.803 tokens-per-sec
With speculative decoding: 29.051 tokens-per-sec
Qwen2.5-Coder-0.5B-Instruct-MLX-4bit alone: 284.647 tokens-per-sec

In the same setup on an M3 Pro with 32 GB of RAM, we see a large speedup (~7 -> ~16 tokens-per-sec).
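To put the two machines side by side, the reported generation throughputs can be turned into a relative speedup. This is a minimal Python sketch; the M3 Pro figures are the approximate values quoted above, so that ratio is only approximate.

```python
def speedup(baseline_tps: float, spec_tps: float) -> float:
    """Ratio of throughput with a draft model to throughput without one."""
    return spec_tps / baseline_tps


# M2 Ultra, 32B main model: generation tokens-per-sec from the logs in this issue.
m2_ultra = speedup(baseline_tps=29.803, spec_tps=29.051)

# M3 Pro: the approximate figures quoted above (~7 and ~16 tokens-per-sec).
m3_pro = speedup(baseline_tps=7.0, spec_tps=16.0)

print(f"M2 Ultra: {m2_ultra:.3f}x")  # below 1.0, i.e. slightly slower
print(f"M3 Pro:   {m3_pro:.2f}x")
```

On M2 Ultra the draft model costs about 2.5% of throughput, while on M3 Pro the same kind of setup more than doubles it.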

Full logs:


(venv) ➜ test mlx_lm.generate --model lmstudio-community/Qwen2.5-Coder-32B-Instruct-MLX-4bit --prompt "Write a quicksort algorithm" --draft-model mlx-community/Qwen2.5-0.5B-Instruct-4bit -m 1000 --temp 0

==========
Certainly! Quicksort is a popular and efficient sorting algorithm that uses a divide-and-conquer approach to sort elements. Below is a simple implementation of the Quicksort algorithm in Python:

def quicksort(arr):
    if len(arr) <= 1:
        return arr
    else:
        pivot = arr[len(arr) // 2]  # Choose the middle element as the pivot
        left = [x for x in arr if x < pivot]  # Elements less than the pivot
        middle = [x for x in arr if x == pivot]  # Elements equal to the pivot
        right = [x for x in arr if x > pivot]  # Elements greater than the pivot
        return quicksort(left) + middle + quicksort(right)

# Example usage:
arr = [3, 6, 8, 10, 1, 2, 1]
sorted_arr = quicksort(arr)
print("Sorted array:", sorted_arr)

Explanation:

  1. Base Case: If the array has 0 or 1 element, it is already sorted, so we return it as is.
  2. Pivot Selection: We choose the middle element of the array as the pivot.
  3. Partitioning: We create three lists:
    • left for elements less than the pivot.
    • middle for elements equal to the pivot.
    • right for elements greater than the pivot.
  4. Recursive Sorting: We recursively apply the quicksort function to the left and right lists and concatenate the results with the middle list.

This implementation is simple and easy to understand, but it may not be the most efficient in terms of space complexity due to the use of additional lists. For an in-place version, you can modify the algorithm to swap elements within the original array. Here's an in-place version:

def quicksort_inplace(arr, low, high):
    if low < high:
        pi = partition(arr, low, high)  # Partitioning index
        quicksort_inplace(arr, low, pi - 1)  # Sort left part
        quicksort_inplace(arr, pi + 1, high)  # Sort right part

def partition(arr, low, high):
    pivot = arr[high]  # Choose the last element as the pivot
    i = low - 1  # Index of smaller element
    for j in range(low, high):
        if arr[j] <= pivot:
            i += 1
            arr[i], arr[j] = arr[j], arr[i]  # Swap
    arr[i + 1], arr[high] = arr[high], arr[i + 1]  # Swap pivot element
    return i + 1

# Example usage:
arr = [3, 6, 8, 10, 1, 2, 1]
quicksort_inplace(arr, 0, len(arr) - 1)
print("Sorted array:", arr)

In this in-place version, the partition function rearranges the elements in the array such that elements less than the pivot are on the left, elements greater than the pivot are on the right, and the pivot is in its correct position. The quicksort_inplace function then recursively sorts the subarrays.

==========
Prompt: 34 tokens, 71.386 tokens-per-sec
Generation: 709 tokens, 29.051 tokens-per-sec
Peak memory: 18.932 GB

(venv) ➜ test mlx_lm.generate --model lmstudio-community/Qwen2.5-Coder-32B-Instruct-MLX-4bit --prompt "Write a quicksort algorithm" -m 1000 --temp 0

==========
Certainly! Quicksort is a popular and efficient sorting algorithm that uses a divide-and-conquer approach to sort elements. Below is a simple implementation of the Quicksort algorithm in Python:

def quicksort(arr):
    if len(arr) <= 1:
        return arr
    else:
        pivot = arr[len(arr) // 2]  # Choose the middle element as the pivot
        left = [x for x in arr if x < pivot]  # Elements less than the pivot
        middle = [x for x in arr if x == pivot]  # Elements equal to the pivot
        right = [x for x in arr if x > pivot]  # Elements greater than the pivot
        return quicksort(left) + middle + quicksort(right)

# Example usage:
arr = [3, 6, 8, 10, 1, 2, 1]
sorted_arr = quicksort(arr)
print("Sorted array:", sorted_arr)

Explanation:

  1. Base Case: If the array has 0 or 1 element, it is already sorted, so we return it as is.
  2. Pivot Selection: We choose the middle element of the array as the pivot.
  3. Partitioning: We create three lists:
    • left for elements less than the pivot.
    • middle for elements equal to the pivot.
    • right for elements greater than the pivot.
  4. Recursive Sorting: We recursively apply the quicksort function to the left and right lists and concatenate the results with the middle list.

This implementation is simple and easy to understand, but it may not be the most efficient in terms of space complexity due to the use of additional lists. For an in-place version, you can modify the algorithm to swap elements within the original array. Here's an in-place version:

def quicksort_inplace(arr, low, high):
    if low < high:
        pi = partition(arr, low, high)  # Partitioning index
        quicksort_inplace(arr, low, pi - 1)  # Sort left part
        quicksort_inplace(arr, pi + 1, high)  # Sort right part

def partition(arr, low, high):
    pivot = arr[high]  # Choose the last element as the pivot
    i = low - 1  # Index of smaller element
    for j in range(low, high):
        if arr[j] <= pivot:
            i += 1
            arr[i], arr[j] = arr[j], arr[i]  # Swap
    arr[i + 1], arr[high] = arr[high], arr[i + 1]  # Swap pivot element
    return i + 1

# Example usage:
arr = [3, 6, 8, 10, 1, 2, 1]
quicksort_inplace(arr, 0, len(arr) - 1)
print("Sorted array:", arr)

In this in-place version, the partition function rearranges the elements in the array such that elements less than the pivot are on the left, elements greater than the pivot are on the right, and the pivot is in its correct position. The quicksort_inplace function then recursively sorts the subarrays.

==========
Prompt: 34 tokens, 75.790 tokens-per-sec
Generation: 709 tokens, 29.803 tokens-per-sec
Peak memory: 18.643 GB

(venv) ➜ test mlx_lm.generate --model lmstudio-community/Qwen2.5-Coder-0.5B-Instruct-MLX-4bit --prompt "Write a quicksort algorithm" -m 1000 --temp 0

==========
Sure, here's a simple implementation of the quicksort algorithm in Python:

def quicksort(arr):
    # Base case: if the array is empty or has one element, it's already sorted
    if len(arr) <= 1:
        return arr
    
    # Choose a pivot element
    pivot = arr[len(arr) // 2]
    
    # Partition the array into two sub-arrays: elements less than or equal to the pivot and elements greater than or equal to the pivot
    less_than_pivot = [x for x in arr if x <= pivot]
    greater_than_pivot = [x for x in arr if x > pivot]
    
    # Recursively sort the two sub-arrays
    quicksort(less_than_pivot)
    quicksort(greater_than_pivot)
    
    # Merge the sorted sub-arrays
    return less_than_pivot + [pivot] + greater_than_pivot

This function takes an array as input and returns a new array sorted in ascending order. It uses a simple partitioning strategy: it selects a pivot element and partitions the array into two sub-arrays: all elements less than or equal to the pivot and all elements greater than or equal to the pivot. The function then recursively sorts the two sub-arrays and merges them to form the sorted array.

==========
Prompt: 34 tokens, 683.032 tokens-per-sec
Generation: 276 tokens, 284.647 tokens-per-sec
Peak memory: 0.299 GB

mattjcly · Feb 12 '25 16:02

I ran a couple of benchmarks on M3 Max and M2 Ultra. As expected, the big model scales much better with sequence length on M3 Max than on M2 Ultra. This probably explains why we see little to no performance improvement on M2 Ultra.

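The figure below shows time as sequence length increases. You want the line to be as flat as possible for the best speedup with speculative decoding: the main model verifies all draft tokens in a single forward pass, and that only pays off if a pass over several tokens costs about the same as a pass over one.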

[Figure: time vs. sequence length for the big model on M3 Max and M2 Ultra]
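One way to see why the flatness of that curve matters is a back-of-the-envelope cost model of speculative decoding. This Python sketch assumes each draft token is accepted independently with probability `alpha` (the simplification used in the standard speculative decoding analysis), `num_draft` tokens are proposed per step, one draft-model step costs `draft_cost` main-model steps, and verifying `num_draft + 1` tokens costs `verify_cost` main-model steps (1.0 means a perfectly flat curve). The acceptance rate, draft length, and verify costs are illustrative assumptions, not measurements; only the draft cost ratio is derived from the throughputs in the original report.

```python
def expected_tokens_per_step(alpha: float, num_draft: int) -> float:
    """Expected tokens emitted per verification step (accepted drafts plus one main-model token)."""
    if alpha >= 1.0:
        return num_draft + 1.0
    # Geometric series: 1 + alpha + alpha^2 + ... + alpha^num_draft
    return (1.0 - alpha ** (num_draft + 1)) / (1.0 - alpha)


def estimated_speedup(alpha: float, num_draft: int, draft_cost: float, verify_cost: float) -> float:
    """Speedup over plain decoding, with all costs measured in single-token main-model steps."""
    step_cost = num_draft * draft_cost + verify_cost
    return expected_tokens_per_step(alpha, num_draft) / step_cost


# Draft/main cost ratio from the M2 Ultra report: 29.803 tok/s (32B) vs 284.647 tok/s (0.5B).
draft_cost = 29.803 / 284.647

# Illustrative (assumed) acceptance rate and number of draft tokens per step.
alpha, num_draft = 0.7, 3
for verify_cost in (1.0, 1.5, 2.0, 2.5):
    s = estimated_speedup(alpha, num_draft, draft_cost, verify_cost)
    print(f"verify pass = {verify_cost:.1f}x a single step -> speedup {s:.2f}x")
```

With these assumed numbers, the estimate falls from roughly 1.9x when verifying four tokens is as cheap as a single step to below 1x once it costs about 2.5 steps, which is consistent with the near-1x result on M2 Ultra.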

awni · Feb 12 '25 17:02

On the optimistic side, from conversations with @angeloskath and @barronalex, there is likely room to improve small-batch qmm (quantized matrix-matrix multiplication), which should help this use case considerably.

awni · Feb 12 '25 17:02

Does this seem right? I see about a 49% speedup on M4 Max (16-core CPU / 40-core GPU) and 22% on M3 Ultra (28-core CPU / 60-core GPU).

M4 Max

mlx_lm.generate --model mlx-community/Qwen2.5-72B-Instruct-6bit --prompt "Write a quicksort algorithm" -m 1000 --temp 0

Prompt: 34 tokens, 39.482 tokens-per-sec
Generation: 642 tokens, 7.919 tokens-per-sec
Peak memory: 59.349 GB

mlx_lm.generate --model mlx-community/Qwen2.5-72B-Instruct-6bit --prompt "Write a quicksort algorithm" --draft-model mlx-community/Qwen2.5-0.5B-Instruct-8bit -m 1000 --temp 0

Prompt: 34 tokens, 43.196 tokens-per-sec
Generation: 642 tokens, 11.783 tokens-per-sec
Peak memory: 59.887 GB

M3 Ultra

mlx_lm.generate --model mlx-community/Qwen2.5-72B-Instruct-6bit --prompt "Write a quicksort algorithm" -m 1000 --temp 0

Prompt: 34 tokens, 45.241 tokens-per-sec
Generation: 642 tokens, 11.346 tokens-per-sec
Peak memory: 59.349 GB

mlx_lm.generate --model mlx-community/Qwen2.5-72B-Instruct-6bit --prompt "Write a quicksort algorithm" --draft-model mlx-community/Qwen2.5-0.5B-Instruct-8bit -m 1000 --temp 0

Prompt: 34 tokens, 48.431 tokens-per-sec
Generation: 642 tokens, 13.815 tokens-per-sec
Peak memory: 59.887 GB
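The percentages in the question can be recomputed from the generation throughputs above. A minimal Python sketch:

```python
# Generation tokens-per-sec from the runs above: (without draft model, with draft model).
runs = {
    "M4 Max": (7.919, 11.783),
    "M3 Ultra": (11.346, 13.815),
}

# Percentage gain in generation throughput from adding the draft model.
gains = {chip: (spec / plain - 1.0) * 100.0 for chip, (plain, spec) in runs.items()}

for chip, gain in gains.items():
    print(f"{chip}: {gain:.0f}% faster with the draft model")
# M4 Max: 49% faster with the draft model
# M3 Ultra: 22% faster with the draft model
```

So the 49% and 22% figures match the logs. In absolute terms, the M3 Ultra with the draft model is still the faster of the two (13.815 vs 11.783 tokens-per-sec).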

dave-fl · Apr 04 '25 18:04

Same issue here on an M2 Max: rather than a speedup, we see a drop in throughput:

QwQ-32B-6bit

> mlx_lm.generate --model /Users/sub01/.lmstudio/models/mlx-community/QwQ-32B-6bit --prompt "How many r's are in the word 'Strawberry'?"
==========
Okay, so I need to figure out how many times the letter 'r' appears in the word 'Strawberry'. Let me start by writing down the word and looking at each letter one by one. 

First, I'll spell out 'Strawberry' letter by letter to make sure I don't miss any. Let's see:

S - T - R - A - W - B - E - R - R - Y.

Wait, is that right? Let me check again
==========
Prompt: 23 tokens, 35.913 tokens-per-sec
Generation: 100 tokens, 9.643 tokens-per-sec
Peak memory: 26.710 GB

Qwen2.5-Coder-0.5B-Instruct-MLX-6bits

> mlx_lm.generate --model /Users/sub01/.lmstudio/models/moot20/Qwen2.5-Coder-0.5B-Instruct-MLX-6bits --prompt "How many r's are in the word 'Strawberry'?"
==========
The word 'Strawberry' contains 3 r's.
==========
Prompt: 42 tokens, 996.588 tokens-per-sec
Generation: 14 tokens, 127.092 tokens-per-sec
Peak memory: 0.427 GB

QwQ-32B-6bit + Qwen2.5-Coder-0.5B-Instruct-MLX-6bits

> mlx_lm.generate --model /Users/sub01/.lmstudio/models/mlx-community/QwQ-32B-6bit --prompt "How many r's are in the word 'Strawberry'?" --draft-model /Users/sub01/.lmstudio/models/moot20/Qwen2.5-Coder-0.5B-Instruct-MLX-6bits
==========
Okay, so I need to figure out how many times the letter 'r' appears in the word 'Strawberry'. Let me start by writing down the word and looking at each letter one by one. 

First, I'll spell out 'Strawberry' letter by letter. Let's see: S, T, R, A, W, B, E, R, R, Y. Wait, is that right? Hmm, maybe I should double-check the spelling. Sometimes people
==========
Prompt: 23 tokens, 52.834 tokens-per-sec
Generation: 100 tokens, 9.138 tokens-per-sec
Peak memory: 27.121 GB

Similar results for Qwen2.5-Coder-32B-Instruct-6bit:

> mlx_lm.generate --model /Users/sub01/.lmstudio/models/mlx-community/Qwen2.5-Coder-32B-Instruct-6bit --prompt "How many r's are in the word 'Strawberry'?"
==========
The word "strawberry" contains three 'r's.
==========
Prompt: 42 tokens, 62.223 tokens-per-sec
Generation: 14 tokens, 13.853 tokens-per-sec
Peak memory: 26.727 GB

Qwen2.5-Coder-32B-Instruct-6bit + Qwen2.5-Coder-0.5B-Instruct-MLX-6bits

> mlx_lm.generate --model /Users/sub01/.lmstudio/models/mlx-community/Qwen2.5-Coder-32B-Instruct-6bit --prompt "How many r's are in the word 'Strawberry'?" --draft-model /Users/sub01/.lmstudio/models/moot20/Qwen2.5-Coder-0.5B-Instruct-MLX-6bits
==========
The word "strawberry" contains three 'r's.
==========
Prompt: 42 tokens, 61.163 tokens-per-sec
Generation: 14 tokens, 12.656 tokens-per-sec
Peak memory: 27.138 GB
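To quantify the drop, the same ratio can be computed for the two M2 Max comparisons. A minimal Python sketch using the generation throughputs from the logs above:

```python
# (model, tokens-per-sec without draft model, tokens-per-sec with draft model), from the M2 Max runs above.
runs = [
    ("QwQ-32B-6bit", 9.643, 9.138),
    ("Qwen2.5-Coder-32B-Instruct-6bit", 13.853, 12.656),
]

# Relative change in generation throughput; negative means the draft model slowed generation down.
changes = {name: (spec / plain - 1.0) * 100.0 for name, plain, spec in runs}

for name, pct in changes.items():
    print(f"{name}: {pct:+.1f}% generation throughput with the draft model")
```

That is roughly a 5% and a 9% slowdown. Note that the Qwen2.5-Coder-32B comparison covers only 14 generated tokens, so it is a small sample.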

Sub0X · Apr 07 '25 16:04