Poor Speculative Decoding Performance on M2 Ultra
Speculative decoding does not seem to improve generation speed as expected on an M2 Ultra Mac Studio (128 GB).
Main model: https://huggingface.co/lmstudio-community/Qwen2.5-Coder-32B-Instruct-MLX-4bit
Draft model: https://huggingface.co/lmstudio-community/Qwen2.5-Coder-0.5B-Instruct-MLX-4bit or https://huggingface.co/mlx-community/Qwen2.5-0.5B-Instruct-4bit
Prompt: "Write a quicksort algorithm"
Without speculative decoding: 29.803 tokens-per-sec
With speculative decoding: 29.051 tokens-per-sec
Qwen2.5-Coder-0.5B-Instruct-MLX-4bit alone: 284.647 tokens-per-sec
With the same setup on an M3 Pro (32 GB RAM), we see a large speedup (~7 tok/s to ~16 tok/s).
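The arithmetic behind those figures, as a small Python sketch using the throughputs from the logs below. The draft-to-main cost ratio is only approximate: the standalone draft figure was measured with Qwen2.5-Coder-0.5B-Instruct-MLX-4bit, while the speculative run in the logs used mlx-community/Qwen2.5-0.5B-Instruct-4bit, and standalone per-token time is assumed to stand in for the cost of one draft step.

```python
# Generation throughputs (tokens/s) copied from the M2 Ultra logs in this issue.
main_alone = 29.803    # Qwen2.5-Coder-32B-Instruct-MLX-4bit, no draft model
with_draft = 29.051    # same main model with --draft-model
draft_alone = 284.647  # Qwen2.5-Coder-0.5B-Instruct-MLX-4bit on its own


def relative_change(baseline: float, new: float) -> float:
    """Percentage change in throughput; negative means a slowdown."""
    return (new / baseline - 1.0) * 100.0


# Approximate cost of one draft step relative to one main-model step
# (ratio of per-token latencies, assuming standalone speed reflects draft cost).
draft_cost_ratio = main_alone / draft_alone

if __name__ == "__main__":
    print(f"M2 Ultra change with draft: {relative_change(main_alone, with_draft):+.1f}%")
    print(f"Draft/main per-token cost ratio: {draft_cost_ratio:.3f}")
    print(f"M3 Pro approximate speedup: {16 / 7:.2f}x")
```

So on the M2 Ultra the draft model is roughly a tenth of the main model's per-token cost, yet throughput still drops slightly.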
Full logs:
(venv) ➜ test mlx_lm.generate --model lmstudio-community/Qwen2.5-Coder-32B-Instruct-MLX-4bit --prompt "Write a quicksort algorithm" --draft-model mlx-community/Qwen2.5-0.5B-Instruct-4bit -m 1000 --temp 0
==========
Certainly! Quicksort is a popular and efficient sorting algorithm that uses a divide-and-conquer approach to sort elements. Below is a simple implementation of the Quicksort algorithm in Python:
def quicksort(arr):
    if len(arr) <= 1:
        return arr
    else:
        pivot = arr[len(arr) // 2] # Choose the middle element as the pivot
        left = [x for x in arr if x < pivot] # Elements less than the pivot
        middle = [x for x in arr if x == pivot] # Elements equal to the pivot
        right = [x for x in arr if x > pivot] # Elements greater than the pivot
        return quicksort(left) + middle + quicksort(right)
# Example usage:
arr = [3, 6, 8, 10, 1, 2, 1]
sorted_arr = quicksort(arr)
print("Sorted array:", sorted_arr)
Explanation:
- Base Case: If the array has 0 or 1 element, it is already sorted, so we return it as is.
- Pivot Selection: We choose the middle element of the array as the pivot.
- Partitioning: We create three lists:
  - left for elements less than the pivot.
  - middle for elements equal to the pivot.
  - right for elements greater than the pivot.
- Recursive Sorting: We recursively apply the quicksort function to the left and right lists and concatenate the results with the middle list.
This implementation is simple and easy to understand, but it may not be the most efficient in terms of space complexity due to the use of additional lists. For an in-place version, you can modify the algorithm to swap elements within the original array. Here's an in-place version:
def quicksort_inplace(arr, low, high):
    if low < high:
        pi = partition(arr, low, high) # Partitioning index
        quicksort_inplace(arr, low, pi - 1) # Sort left part
        quicksort_inplace(arr, pi + 1, high) # Sort right part
def partition(arr, low, high):
    pivot = arr[high] # Choose the last element as the pivot
    i = low - 1 # Index of smaller element
    for j in range(low, high):
        if arr[j] <= pivot:
            i += 1
            arr[i], arr[j] = arr[j], arr[i] # Swap
    arr[i + 1], arr[high] = arr[high], arr[i + 1] # Swap pivot element
    return i + 1
# Example usage:
arr = [3, 6, 8, 10, 1, 2, 1]
quicksort_inplace(arr, 0, len(arr) - 1)
print("Sorted array:", arr)
In this in-place version, the partition function rearranges the elements in the array such that elements less than the pivot are on the left, elements greater than the pivot are on the right, and the pivot is in its correct position. The quicksort_inplace function then recursively sorts the subarrays.
==========
Prompt: 34 tokens, 71.386 tokens-per-sec
Generation: 709 tokens, 29.051 tokens-per-sec
Peak memory: 18.932 GB
(venv) ➜ test mlx_lm.generate --model lmstudio-community/Qwen2.5-Coder-32B-Instruct-MLX-4bit --prompt "Write a quicksort algorithm" -m 1000 --temp 0
==========
Certainly! Quicksort is a popular and efficient sorting algorithm that uses a divide-and-conquer approach to sort elements. Below is a simple implementation of the Quicksort algorithm in Python:
def quicksort(arr):
    if len(arr) <= 1:
        return arr
    else:
        pivot = arr[len(arr) // 2] # Choose the middle element as the pivot
        left = [x for x in arr if x < pivot] # Elements less than the pivot
        middle = [x for x in arr if x == pivot] # Elements equal to the pivot
        right = [x for x in arr if x > pivot] # Elements greater than the pivot
        return quicksort(left) + middle + quicksort(right)
# Example usage:
arr = [3, 6, 8, 10, 1, 2, 1]
sorted_arr = quicksort(arr)
print("Sorted array:", sorted_arr)
Explanation:
- Base Case: If the array has 0 or 1 element, it is already sorted, so we return it as is.
- Pivot Selection: We choose the middle element of the array as the pivot.
- Partitioning: We create three lists:
  - left for elements less than the pivot.
  - middle for elements equal to the pivot.
  - right for elements greater than the pivot.
- Recursive Sorting: We recursively apply the quicksort function to the left and right lists and concatenate the results with the middle list.
This implementation is simple and easy to understand, but it may not be the most efficient in terms of space complexity due to the use of additional lists. For an in-place version, you can modify the algorithm to swap elements within the original array. Here's an in-place version:
def quicksort_inplace(arr, low, high):
    if low < high:
        pi = partition(arr, low, high) # Partitioning index
        quicksort_inplace(arr, low, pi - 1) # Sort left part
        quicksort_inplace(arr, pi + 1, high) # Sort right part
def partition(arr, low, high):
    pivot = arr[high] # Choose the last element as the pivot
    i = low - 1 # Index of smaller element
    for j in range(low, high):
        if arr[j] <= pivot:
            i += 1
            arr[i], arr[j] = arr[j], arr[i] # Swap
    arr[i + 1], arr[high] = arr[high], arr[i + 1] # Swap pivot element
    return i + 1
# Example usage:
arr = [3, 6, 8, 10, 1, 2, 1]
quicksort_inplace(arr, 0, len(arr) - 1)
print("Sorted array:", arr)
In this in-place version, the partition function rearranges the elements in the array such that elements less than the pivot are on the left, elements greater than the pivot are on the right, and the pivot is in its correct position. The quicksort_inplace function then recursively sorts the subarrays.
==========
Prompt: 34 tokens, 75.790 tokens-per-sec
Generation: 709 tokens, 29.803 tokens-per-sec
Peak memory: 18.643 GB
(venv) ➜ test mlx_lm.generate --model lmstudio-community/Qwen2.5-Coder-0.5B-Instruct-MLX-4bit --prompt "Write a quicksort algorithm" -m 1000 --temp 0
==========
Sure, here's a simple implementation of the quicksort algorithm in Python:
def quicksort(arr):
    # Base case: if the array is empty or has one element, it's already sorted
    if len(arr) <= 1:
        return arr
    # Choose a pivot element
    pivot = arr[len(arr) // 2]
    # Partition the array into two sub-arrays: elements less than or equal to the pivot and elements greater than or equal to the pivot
    less_than_pivot = [x for x in arr if x <= pivot]
    greater_than_pivot = [x for x in arr if x > pivot]
    # Recursively sort the two sub-arrays
    quicksort(less_than_pivot)
    quicksort(greater_than_pivot)
    # Merge the sorted sub-arrays
    return less_than_pivot + [pivot] + greater_than_pivot
This function takes an array as input and returns a new array sorted in ascending order. It uses a simple partitioning strategy: it selects a pivot element and partitions the array into two sub-arrays: all elements less than or equal to the pivot and all elements greater than or equal to the pivot. The function then recursively sorts the two sub-arrays and merges them to form the sorted array.
==========
Prompt: 34 tokens, 683.032 tokens-per-sec
Generation: 276 tokens, 284.647 tokens-per-sec
Peak memory: 0.299 GB
I ran a couple of benchmarks on an M3 Max and an M2 Ultra. As expected, the big model's runtime scales much better with sequence length on the M3 Max than on the M2 Ultra. This probably explains why we see little to no improvement on the M2 Ultra.
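For context on why the main model's cost at short sequence lengths matters: with --temp 0, speculative decoding is greedy. The draft proposes a few tokens, the main model scores all of them in one forward pass, and the longest prefix it agrees with is kept. Below is a toy Python sketch of that loop, assuming hypothetical next-token functions in place of real models; it is an illustration of the technique, not mlx_lm's implementation.

```python
from typing import Callable, List, Tuple

# Hypothetical stand-in for a model: given the token sequence so far, return
# the greedy (argmax) next token. Real models return logits instead.
NextToken = Callable[[List[int]], int]


def greedy_generate(model: NextToken, prompt: List[int], n: int) -> List[int]:
    """Plain greedy decoding: one main-model step per generated token."""
    seq = list(prompt)
    for _ in range(n):
        seq.append(model(seq))
    return seq[len(prompt):]


def speculative_generate(target: NextToken, draft: NextToken,
                         prompt: List[int], n: int,
                         k: int = 3) -> Tuple[List[int], int]:
    """Greedy speculative decoding; returns (tokens, verification passes)."""
    seq = list(prompt)
    passes = 0
    while len(seq) - len(prompt) < n:
        # 1. The cheap draft model proposes k tokens autoregressively.
        proposal: List[int] = []
        for _ in range(k):
            proposal.append(draft(seq + proposal))
        # 2. The target's greedy choice at each of the k + 1 positions.
        #    Simulated with k + 1 calls here; a real model scores all of
        #    them in ONE forward pass over k + 1 tokens.
        passes += 1
        targets = [target(seq + proposal[:i]) for i in range(k + 1)]
        # 3. Keep the longest agreeing prefix, then append the target's
        #    token at the first mismatch (or the bonus token if all agree).
        accepted = 0
        while accepted < k and proposal[accepted] == targets[accepted]:
            accepted += 1
        seq.extend(proposal[:accepted])
        seq.append(targets[accepted])
    return seq[len(prompt):len(prompt) + n], passes


if __name__ == "__main__":
    target = lambda s: (3 * s[-1] + 1) % 17
    draft = lambda s: (3 * s[-1] + 1) % 17 if s[-1] % 4 else 0  # sometimes wrong
    out, passes = speculative_generate(target, draft, [5], n=20, k=3)
    assert out == greedy_generate(target, [5], 20)  # same text as plain greedy
    print(passes, "verification passes instead of 20 single-token steps")
```

Each pass yields at least one token, so the output always matches plain greedy decoding; the win depends on one (k + 1)-token pass costing little more than a single-token step.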
The figure below shows time as the sequence length increases. The flatter the line, the larger the possible speedup from speculative generation.
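A rough way to see why flatness matters is the usual back-of-the-envelope speedup estimate, sketched in Python below. It assumes each drafted token is accepted independently with probability alpha. The verify_cost parameter is an added assumption standing in for the slope in the figure (1.0 for a perfectly flat line), and all parameter values in the example are hypothetical.

```python
def expected_tokens_per_pass(alpha: float, k: int) -> float:
    """Expected tokens produced per verification pass, assuming each drafted
    token is accepted independently with probability alpha. Includes the
    target's own token, so the result lies between 1 and k + 1."""
    if alpha == 1.0:
        return k + 1.0
    return (1.0 - alpha ** (k + 1)) / (1.0 - alpha)


def speculative_speedup(alpha: float, k: int, draft_cost: float,
                        verify_cost: float) -> float:
    """Estimated throughput relative to plain decoding.

    draft_cost:  time of one draft step / time of one main-model step.
    verify_cost: time of one main-model pass over k + 1 tokens / time of one
                 main-model step (1.0 = flat curve, up to k + 1 = no batching
                 benefit at all).
    """
    pass_time = k * draft_cost + verify_cost
    return expected_tokens_per_pass(alpha, k) / pass_time


if __name__ == "__main__":
    # Hypothetical: 80% acceptance, 3 draft tokens, draft step = 10% of a main step.
    for verify_cost in (1.0, 1.5, 2.5):
        s = speculative_speedup(alpha=0.8, k=3, draft_cost=0.1, verify_cost=verify_cost)
        print(f"verify_cost={verify_cost}: {s:.2f}x")
```

With everything else fixed, raising verify_cost from 1.0 to 2.5 takes the estimate from about 2.3x down to about 1.05x, so a steep curve alone can erase most of the gain.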
On the optimistic side, from conversations with @angeloskath and @barronalex, there is likely room to improve small-batch qmm (quantized matrix multiplication), which should help this use case considerably.
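For readers unfamiliar with the term, the NumPy sketch below shows what a small-batch quantized matmul computes, assuming an affine group-wise 4-bit scheme (one scale and one bias per group of input columns). It is reference semantics under that assumption only: the function names are hypothetical, and optimized kernels typically pack the quantized values and dequantize on the fly instead of materializing the full weight matrix as done here.

```python
import numpy as np


def quantize(w: np.ndarray, group_size: int = 64, bits: int = 4):
    """Affine group-wise quantization of an (out, in) weight matrix.
    Each group of `group_size` input columns gets its own scale and bias,
    with w ~= scale * q + bias and q in [0, 2**bits - 1]."""
    out_dim, in_dim = w.shape
    assert in_dim % group_size == 0, "input dim must be a multiple of group_size"
    g = w.reshape(out_dim, in_dim // group_size, group_size)
    lo = g.min(axis=-1, keepdims=True)
    hi = g.max(axis=-1, keepdims=True)
    scale = (hi - lo) / (2 ** bits - 1)
    scale = np.where(scale == 0, 1.0, scale)  # constant groups: avoid divide-by-zero
    q = np.round((g - lo) / scale).astype(np.uint8)
    return q, scale, lo


def dequantize(q: np.ndarray, scale: np.ndarray, bias: np.ndarray) -> np.ndarray:
    """Rebuild an approximate (out, in) float matrix from groups."""
    return (q * scale + bias).reshape(q.shape[0], -1)


def qmm(x: np.ndarray, q: np.ndarray, scale: np.ndarray, bias: np.ndarray) -> np.ndarray:
    """x: (k, in) activations for k token positions -> (k, out).
    The same dequantized weights are reused for all k rows, so when the
    weight reads dominate, k rows can cost about as much as one."""
    return x @ dequantize(q, scale, bias).T


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    w = rng.standard_normal((256, 512)).astype(np.float32)  # (out, in)
    q, scale, bias = quantize(w)
    x = rng.standard_normal((4, 512)).astype(np.float32)    # 4 token positions
    y = qmm(x, q, scale, bias)
    print(y.shape, float(np.abs(y - x @ w.T).max()))
```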
Does this look right: roughly a 49% speedup on the M4 Max (16-core CPU / 40-core GPU) and 22% on the M3 Ultra (28-core CPU / 60-core GPU)?
M4 Max
mlx_lm.generate --model mlx-community/Qwen2.5-72B-Instruct-6bit --prompt "Write a quicksort algorithm" -m 1000 --temp 0
Prompt: 34 tokens, 39.482 tokens-per-sec
Generation: 642 tokens, 7.919 tokens-per-sec
Peak memory: 59.349 GB
mlx_lm.generate --model mlx-community/Qwen2.5-72B-Instruct-6bit --prompt "Write a quicksort algorithm" --draft-model mlx-community/Qwen2.5-0.5B-Instruct-8bit -m 1000 --temp 0
Prompt: 34 tokens, 43.196 tokens-per-sec
Generation: 642 tokens, 11.783 tokens-per-sec
Peak memory: 59.887 GB
M3 Ultra
mlx_lm.generate --model mlx-community/Qwen2.5-72B-Instruct-6bit --prompt "Write a quicksort algorithm" -m 1000 --temp 0
Prompt: 34 tokens, 45.241 tokens-per-sec
Generation: 642 tokens, 11.346 tokens-per-sec
Peak memory: 59.349 GB
mlx_lm.generate --model mlx-community/Qwen2.5-72B-Instruct-6bit --prompt "Write a quicksort algorithm" --draft-model mlx-community/Qwen2.5-0.5B-Instruct-8bit -m 1000 --temp 0
Prompt: 34 tokens, 48.431 tokens-per-sec
Generation: 642 tokens, 13.815 tokens-per-sec
Peak memory: 59.887 GB
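The two percentages can be checked directly against the generation throughputs in these logs; a short Python sketch:

```python
# Generation throughputs (tokens/s) from the M4 Max and M3 Ultra logs:
# Qwen2.5-72B-Instruct-6bit without / with the Qwen2.5-0.5B-Instruct-8bit draft.
runs = {
    "M4 Max": (7.919, 11.783),
    "M3 Ultra": (11.346, 13.815),
}


def speedup_percent(baseline: float, speculative: float) -> float:
    """Percentage throughput gain from speculative decoding."""
    return (speculative / baseline - 1.0) * 100.0


if __name__ == "__main__":
    for chip, (base, spec) in runs.items():
        print(f"{chip}: {speedup_percent(base, spec):+.1f}%")
```

This gives roughly 48.8% and 21.8%, matching the 49% and 22% in the question. The M3 Ultra still comes out ahead in absolute terms (13.815 vs 11.783 tokens-per-sec) because its baseline is higher.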
Same issue here on an M2 Max: instead of a speedup, we see a drop in throughput:
QwQ-32B-6bit
> mlx_lm.generate --model /Users/sub01/.lmstudio/models/mlx-community/QwQ-32B-6bit --prompt "How many r's are in the word 'Strawberry'?"
==========
Okay, so I need to figure out how many times the letter 'r' appears in the word 'Strawberry'. Let me start by writing down the word and looking at each letter one by one.
First, I'll spell out 'Strawberry' letter by letter to make sure I don't miss any. Let's see:
S - T - R - A - W - B - E - R - R - Y.
Wait, is that right? Let me check again
==========
Prompt: 23 tokens, 35.913 tokens-per-sec
Generation: 100 tokens, 9.643 tokens-per-sec
Peak memory: 26.710 GB
Qwen2.5-Coder-0.5B-Instruct-MLX-6bits
> mlx_lm.generate --model /Users/sub01/.lmstudio/models/moot20/Qwen2.5-Coder-0.5B-Instruct-MLX-6bits --prompt "How many r's are in the word 'Strawberry'?"
==========
The word 'Strawberry' contains 3 r's.
==========
Prompt: 42 tokens, 996.588 tokens-per-sec
Generation: 14 tokens, 127.092 tokens-per-sec
Peak memory: 0.427 GB
QwQ-32B-6bit + Qwen2.5-Coder-0.5B-Instruct-MLX-6bits
> mlx_lm.generate --model /Users/sub01/.lmstudio/models/mlx-community/QwQ-32B-6bit --prompt "How many r's are in the word 'Strawberry'?" --draft-model /Users/sub01/.lmstudio/models/moot20/Qwen2.5-Coder-0.5B-Instruct-MLX-6bits
==========
Okay, so I need to figure out how many times the letter 'r' appears in the word 'Strawberry'. Let me start by writing down the word and looking at each letter one by one.
First, I'll spell out 'Strawberry' letter by letter. Let's see: S, T, R, A, W, B, E, R, R, Y. Wait, is that right? Hmm, maybe I should double-check the spelling. Sometimes people
==========
Prompt: 23 tokens, 52.834 tokens-per-sec
Generation: 100 tokens, 9.138 tokens-per-sec
Peak memory: 27.121 GB
Similar results for Qwen2.5-Coder-32B-Instruct-6bit:
> mlx_lm.generate --model /Users/sub01/.lmstudio/models/mlx-community/Qwen2.5-Coder-32B-Instruct-6bit --prompt "How many r's are in the word 'Strawberry'?"
==========
The word "strawberry" contains three 'r's.
==========
Prompt: 42 tokens, 62.223 tokens-per-sec
Generation: 14 tokens, 13.853 tokens-per-sec
Peak memory: 26.727 GB
Qwen2.5-Coder-32B-Instruct-6bit + Qwen2.5-Coder-0.5B-Instruct-MLX-6bits
> mlx_lm.generate --model /Users/sub01/.lmstudio/models/mlx-community/Qwen2.5-Coder-32B-Instruct-6bit --prompt "How many r's are in the word 'Strawberry'?" --draft-model /Users/sub01/.lmstudio/models/moot20/Qwen2.5-Coder-0.5B-Instruct-MLX-6bits
==========
The word "strawberry" contains three 'r's.
==========
Prompt: 42 tokens, 61.163 tokens-per-sec
Generation: 14 tokens, 12.656 tokens-per-sec
Peak memory: 27.138 GB
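For comparison with the other reports, the same percentage calculation for these M2 Max runs, as a Python sketch using the generation throughputs above. The second pair generated only 14 tokens, so that figure is likely noisier than the first.

```python
# Generation throughputs (tokens/s) from the M2 Max logs, without / with the
# Qwen2.5-Coder-0.5B-Instruct-MLX-6bits draft model.
runs = {
    "QwQ-32B-6bit": (9.643, 9.138),                      # 100 generated tokens
    "Qwen2.5-Coder-32B-Instruct-6bit": (13.853, 12.656),  # 14 generated tokens
}


def change_percent(baseline: float, speculative: float) -> float:
    """Percentage throughput change; negative means speculative decoding is slower."""
    return (speculative / baseline - 1.0) * 100.0


if __name__ == "__main__":
    for model, (base, spec) in runs.items():
        print(f"{model}: {change_percent(base, spec):+.1f}%")
```

Both come out negative, roughly -5.2% for QwQ-32B-6bit and -8.6% for Qwen2.5-Coder-32B-Instruct-6bit.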