Prompt Lookup Decoding - merged under Speculative example

Open LeonEricsson opened this issue 1 year ago • 10 comments

Continuation of #202. Decided to merge the Prompt Lookup Decoding example into the Speculative Decoding example.

This PR implements an example for the "Prompt Lookup Decoding" technique:

https://github.com/apoorvumang/prompt-lookup-decoding

  • This approach works best on input-grounded tasks such as summarization, document QA, and code editing, where there is high overlap between input and output.
  • It replaces the draft model in speculative decoding with a simple ngram search over the input tokens. As with speculative decoding, it has no impact on the output.
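A minimal pure-Python sketch of the two bullets above, written for illustration rather than taken from the PR (the PR operates on MLX arrays). The names `lookup_draft`, `generate_with_lookup`, `greedy_generate`, and the `next_token` callable that stands in for a model's greedy step are all hypothetical. The lookup tries the longest trailing n-gram first and falls back to shorter ones, taking the earliest earlier match; verification keeps only the draft prefix the model itself would have produced, which is why greedy output is unchanged.

```python
from typing import Callable, List, Sequence

# Hypothetical stand-in for a model's greedy step: maps the token sequence so
# far to the argmax next token.
NextToken = Callable[[List[int]], int]


def lookup_draft(tokens: Sequence[int], ngram_max: int = 3, ngram_min: int = 1,
                 n_draft: int = 5) -> List[int]:
    """Propose draft tokens by matching the trailing n-gram to an earlier spot."""
    n = len(tokens)
    # Try the longest key first, falling back to shorter keys.
    for size in range(ngram_max, ngram_min - 1, -1):
        if size >= n:
            continue  # no earlier position to match against
        key = list(tokens[n - size:])
        # Windows starting before n - size are followed by at least one token,
        # and the trailing key itself is excluded.
        for start in range(n - size):
            if list(tokens[start:start + size]) == key:
                follow = start + size
                return list(tokens[follow:follow + n_draft])
    return []


def generate_with_lookup(prompt: Sequence[int], next_token: NextToken,
                         max_tokens: int, **lookup_kwargs) -> List[int]:
    """Greedy decoding where prompt lookup supplies the draft tokens."""
    tokens = list(prompt)
    out: List[int] = []
    while len(out) < max_tokens:
        draft = lookup_draft(tokens, **lookup_kwargs)
        # Keep the longest draft prefix the model agrees with. A real model
        # would score all draft positions in one forward pass; this toy calls
        # next_token once per position.
        accepted: List[int] = []
        for d in draft:
            if next_token(tokens + accepted) != d:
                break
            accepted.append(d)
        # The model always contributes one token after the accepted prefix.
        step = accepted + [next_token(tokens + accepted)]
        step = step[: max_tokens - len(out)]  # the last step may be truncated
        tokens += step
        out += step
    return out


def greedy_generate(prompt: Sequence[int], next_token: NextToken,
                    max_tokens: int) -> List[int]:
    """Plain greedy decoding, used as the reference output."""
    tokens = list(prompt)
    for _ in range(max_tokens):
        tokens.append(next_token(tokens))
    return tokens[len(prompt):]
```

With the toy model `lambda toks: toks[-4]` and prompt `[1, 2, 3, 4]`, both `generate_with_lookup` and `greedy_generate` return `[1, 2, 3, 4, 1, 2, 3, 4, 1, 2]`; the lookup version gets there in three verification rounds instead of ten because the repeated pattern makes the drafts accepted.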

TODO

  • [x] Debug: Model output doesn't have any spaces
  • [x] Add --color flag to SpeculativeDecoder. This ended up being quite a messy implementation; it needs to deal with the fact that the output can be truncated and hence come only from the draft model
  • [x] Update README
  • [x] Code cleaning
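One way to handle the truncation issue from the `--color` item, sketched under the assumption that the flag highlights tokens accepted from the draft. The step format `(accepted_draft_tokens, model_token)`, the helper names `tag_steps` and `colorize`, the `vocab` mapping, and the ANSI color choice are all hypothetical, not the PR's code. Tagging each token with its origin before truncating keeps the colors correct even when the cut falls inside a draft.

```python
from typing import Dict, List, Sequence, Tuple

# ANSI escape codes; blue for draft tokens is an arbitrary choice.
DRAFT_COLOR = "\033[34m"
RESET = "\033[0m"


def tag_steps(steps: Sequence[Tuple[Sequence[int], int]],
              max_tokens: int) -> List[Tuple[int, bool]]:
    """Flatten verification steps into (token, came_from_draft) pairs.

    Each step is (accepted_draft_tokens, model_token). Truncating to
    max_tokens can cut a step before its model token, so the tail of the
    output may consist only of draft tokens.
    """
    tagged: List[Tuple[int, bool]] = []
    for accepted, model_token in steps:
        tagged.extend((t, True) for t in accepted)
        tagged.append((model_token, False))
    return tagged[:max_tokens]


def colorize(tagged: Sequence[Tuple[int, bool]], vocab: Dict[int, str]) -> str:
    """Render tokens as text, wrapping draft-sourced tokens in ANSI color."""
    parts = []
    for token, from_draft in tagged:
        text = vocab[token]
        parts.append(f"{DRAFT_COLOR}{text}{RESET}" if from_draft else text)
    return "".join(parts)
```

For example, `tag_steps([([2, 3], 4), ([], 5), ([6, 7, 8], 9)], 6)` ends with `(6, True), (7, True)`: the final step was cut before the model's own token, so the last visible tokens come only from the draft.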

LeonEricsson, Jan 05 '24 21:01

@awni perhaps we can leave this as T5 and then make an attempt at swapping to Llama in a new PR? I was thinking we could also adopt the model format/conversion from hf_llm while we're at it. I could take a first crack at it.

LeonEricsson, Jan 08 '24 09:01

Yea that sounds like a great plan to me! Sorry for the delay in the review here, I will get to it shortly!

awni, Jan 08 '24 14:01

Looks really nice! I think we can get this in soon. I haven't looked at the core of the prompt decoder yet, but I left a few comments.

Thanks a ton for refactoring them together, I think it makes a lot of sense this way.

Thanks! Addressed all your comments.

LeonEricsson, Jan 18 '24 06:01

I've been poking around your code @LeonEricsson because I have some long summarization tasks that I'd like to speed up, but I noticed a significant bottleneck in the loop. This is probably still not perfect, but I've had a go at speeding it up. This implementation is about 500x faster:

import mlx.core as mx
import numpy as np

def find_draft(self, input_ids):
    # Convert the MLX array to NumPy for vectorized operations
    input_ids_np = np.array(input_ids)

    # sliding_window_view raises if the window is longer than the input
    if input_ids_np.size < self.ngram_max:
        return mx.array([], dtype=mx.uint32)

    # Use the last ngram_max tokens as the search key
    ngram = input_ids_np[-self.ngram_max:]

    # Vectorized comparison of the key with every window of length ngram_max
    matches = np.lib.stride_tricks.sliding_window_view(input_ids_np, self.ngram_max) == ngram

    # A window matches only if all of its elements match
    match_indices = np.all(matches, axis=1).nonzero()[0]

    # Keep matches that end before the trailing key starts (so the key does not
    # match itself) and that start at index >= self.ngram_min
    match_indices = match_indices[
        (match_indices + self.ngram_max <= input_ids_np.size - self.ngram_max)
        & (match_indices >= self.ngram_min)
    ]

    # Take the most recent match; every match has length ngram_max
    if match_indices.size > 0:
        last_match_idx = match_indices[-1]
        start_idx = last_match_idx + self.ngram_max
        end_idx = min(start_idx + self.n_draft, input_ids_np.size)
        candidate = input_ids_np[start_idx:end_idx]

        # Convert the candidate back to an MLX array
        return mx.array(candidate, dtype=mx.uint32)

    return mx.array([], dtype=mx.uint32)

cmcmaster1, Jan 22 '24 10:01

Nice :rocket:! The original implementation used NumPy's sliding windows, but I chose to keep a pure MLX approach. However, since these are user examples, we should prioritize what is most useful to the user. A performance bottleneck like this is a significant issue, and I agree that it warrants a change.

Side note: perhaps we could get comparable speedups using mlx.core.vmap?

LeonEricsson, Jan 22 '24 13:01

That makes sense. And I'm guessing you didn't notice a huge performance gap because you didn't try it on long texts? I'm shaving ~30 seconds off inference time. I was thinking about trying a vmap version next.

Edit: to clarify, without vectorization, prompt lookup is slower than generate for anything but the most trivial tasks (e.g. repetition). So I think this change is necessary to really justify its existence as a useful example for the community.

cmcmaster1, Jan 22 '24 15:01

@cmcmaster1 I finally implemented a pure MLX version that should be comparable in performance to the NumPy one. It would be great if you could confirm this on your end. Before you do, note that your current implementation only considers ngram matches of size self.ngram_max, which is not how Prompt Lookup was originally proposed. The idea of prompt lookup is to iteratively check for smaller ngram keys until you reach self.ngram_min; note the for loop here. I spent a lot of time trying to do away with this for loop and let mlx.core do the work, but couldn't find a way I was happy with. The larger the gap between self.ngram_max and self.ngram_min, the more Python overhead you incur; you could set self.ngram_max = self.ngram_min if you don't want this behaviour.
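For comparison, a NumPy sketch (not the PR's MLX code) that combines the vectorized window comparison from the earlier comment with the descending ngram loop described here: each ngram size is checked with one vectorized comparison, and Python only loops over the `ngram_max - ngram_min + 1` sizes. The function name `lookup_draft_np` and the choice of the earliest match are assumptions for illustration.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def lookup_draft_np(tokens, ngram_max=3, ngram_min=1, n_draft=10):
    """Vectorize the search for each ngram size; loop only over the sizes."""
    ids = np.asarray(tokens)
    n = ids.size
    # Longest key first, falling back to shorter keys down to ngram_min.
    for size in range(ngram_max, ngram_min - 1, -1):
        if size >= n:
            continue  # no earlier position to match against
        key = ids[n - size:]
        # Windows over ids[:-1] are each followed by at least one token,
        # which also excludes the trailing key window itself.
        windows = sliding_window_view(ids[:-1], size)
        hits = np.flatnonzero(np.all(windows == key, axis=1))
        if hits.size > 0:
            start = hits[0] + size  # earliest match (an assumption)
            return ids[start:start + n_draft]
    return ids[:0]  # empty array with the input dtype
```

Setting `ngram_min = ngram_max` reduces this to a single vectorized comparison, which is the trade-off mentioned above: fewer Python iterations, but no fallback to shorter keys.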

@awni imo this is ready to be merged, sorry for the delay.

LeonEricsson, Jan 27 '24 20:01

@LeonEricsson oops, you're right. I somehow missed that and just tested on examples where it made no difference! Still much faster than the original and definitely comparable to the (flawed) numpy implementation.

cmcmaster1, Jan 30 '24 23:01

ping @awni

LeonEricsson, Feb 23 '24 12:02

Sorry for the delay!! I will review and get this in early next week

awni, Feb 23 '24 15:02