flashinfer
add multi-item scoring
Co-authored with Qingquan Song (@qingquansong) and Ziang Li (@zianglih )
Multi-item scoring
- Concatenate multiple candidate items of the same member with all ranking candidates, separated by a delimiter token:
  `<member prefix (profile & history)> + <delimiter> + <item 1> + <delimiter> + <item 2> + ... + <item N>`
- Extract the logits of the hidden states of the token before each delimiter token and take the log probs of the given label tokens. For each single prompt, the returned output is a 2D list of shape N * K, where N is the number of candidate items it contains and K is the number of choices provided to the server engine (e.g., 2 for ["Yes", "No"]). (This is mainly done in the logit processor.)
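As a concrete illustration of this layout and extraction, here is a minimal sketch. The delimiter id, token ids, and label ids are made-up placeholders and the logits are random; this is not the flashinfer logit processor.

```python
import torch

DELIM = 100                                  # assumed delimiter token id
YES_ID, NO_ID = 42, 43                       # assumed ids for the label tokens ["Yes", "No"]

# <member prefix (profile & history)> + <delimiter> + <item 1> + <delimiter> + ...
prefix = [1, 2, 3, 4, 5]                     # member profile & history tokens
items = [[11, 12, 13], [21, 22], [31, 32, 33, 34]]
prompt = prefix + [DELIM]
for item in items:
    prompt += item + [DELIM]
prompt = torch.tensor(prompt)

# Stand-in logits: one row per prompt token over a toy vocabulary of size 200.
logits = torch.randn(len(prompt), 200)

# The last token of each item sits right before an item-closing delimiter
# (skip the delimiter that closes the member prefix).
delim_pos = (prompt == DELIM).nonzero(as_tuple=True)[0][1:]
score_pos = delim_pos - 1

log_probs = torch.log_softmax(logits[score_pos], dim=-1)
scores = log_probs[:, [YES_ID, NO_ID]]       # shape: N items x K choices
print(scores.shape)                          # torch.Size([3, 2])
```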
The PR optimizes multi-item scoring attention by passing four new arguments and using them to check the masking condition. The provided arguments are:
prefix_len_ptr: Optional[torch.Tensor]
Prefix length: a uint32 1D tensor indicating the prefix length of each prompt. The tensor size is equal to the batch size.
token_pos_in_items_ptr: Optional[torch.Tensor]
A uint16 1D tensor (it will be converted to uint16 inside flashinfer) indicating the token position within each item, starting from 0 at each item's delimiter.
E.g., if a member has 3 items of length 3, 2, and 4 respectively, this vector will look like
`[0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0]`, with the 4 delimiters indexed as 0. For batch size > 1,
we concatenate the per-prompt vectors into a single 1D tensor and zero-pad each prompt to the same length; the padding length for a prompt is
`token_pos_in_items_len` minus the length of its raw `token_pos_in_items_ptr`.
token_pos_in_items_len: Optional[int]
Zero-padded length of `token_pos_in_items_ptr` per prompt, used to better handle the batch size > 1 case. Still using the 3, 2, 4 example above:
if we set `token_pos_in_items_len` to 20, the vector becomes `[0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0]`
with 7 padded zeros. (Note there are 8 zeros at the end, where the first one is the delimiter token 0 at the end of the prompt.)
max_item_len_ptr: Optional[torch.Tensor]
A uint16 1D tensor containing the maximum token length over all items for each prompt.
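For concreteness, here is a minimal sketch of how these four arguments could be assembled for a batch of two prompts, following the semantics above. This is plain PyTorch, not a flashinfer helper; the prefix lengths and item lengths are made-up values, and int32 dtypes are used to keep the sketch runnable (the docs above describe the tensors as uint32/uint16, so adjust dtypes as the API requires).

```python
import torch

# Item lengths per prompt, e.g. prompt 0 has items of length 3, 2, 4.
batch_item_lens = [[3, 2, 4], [2, 5]]
batch_prefix_lens = [7, 9]               # hypothetical member-prefix lengths

prefix_len_ptr = torch.tensor(batch_prefix_lens, dtype=torch.int32)

per_prompt_pos = []
for item_lens in batch_item_lens:
    pos = []
    for n in item_lens:
        pos += [0] + list(range(1, n + 1))   # delimiter (0), then item token positions 1..n
    pos.append(0)                            # trailing delimiter at the end of the prompt
    per_prompt_pos.append(pos)
# per_prompt_pos[0] == [0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0], as in the docs above

token_pos_in_items_len = 20                  # chosen padded length (>= longest raw vector)
padded = [pos + [0] * (token_pos_in_items_len - len(pos)) for pos in per_prompt_pos]
token_pos_in_items_ptr = torch.tensor(sum(padded, []), dtype=torch.int32)

max_item_len_ptr = torch.tensor(
    [max(item_lens) for item_lens in batch_item_lens], dtype=torch.int32
)

print(prefix_len_ptr, token_pos_in_items_ptr.shape, max_item_len_ptr)
```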
Optimizations
- Implement an efficient multi-item scoring mask for FA2 and FA3 (a plain-Python reference for the masking rule follows this list).
- Enhance FA3 to support batch indices for the multi-item scoring mask.
- Implement tile skipping for FA2 and FA3 multi-item scoring.
- Optimize the mask by preloading it into the L1 cache / thread registers.
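The sketch below is a plain-Python reference for the masking rule these kernels implement efficiently; it is my reading of the semantics of the four arguments above, for illustration only, not the CUDA code. Prefix tokens attend causally within the prefix; each item token attends to the whole member prefix and to its own item up to itself, never to other items. The tile-skipping optimization exploits that the blocks between different items are fully masked.

```python
import torch

def multi_item_scoring_mask(prefix_len: int, token_pos_in_items: list[int]) -> torch.Tensor:
    """Boolean attention mask for one prompt: mask[q, k] is True if query q may see key k."""
    seq_len = prefix_len + len(token_pos_in_items)
    mask = torch.zeros(seq_len, seq_len, dtype=torch.bool)
    for q in range(seq_len):
        if q < prefix_len:
            mask[q, : q + 1] = True                 # causal within the member prefix
        else:
            t = token_pos_in_items[q - prefix_len]  # position within the item, 0 for delimiters
            mask[q, :prefix_len] = True             # every item token sees the whole prefix
            mask[q, q - t : q + 1] = True           # ...and its own item's tokens up to itself
    return mask

# Example: prefix of length 4, items of length 3, 2, 4 as in the docs above.
pos = [0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0]
print(multi_item_scoring_mask(4, pos).int())
```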
Hey @yzh119, as discussed, here's the PR for multi-item scoring masked attention. Please feel free to leave comments and provide suggestions if there could be better ways to help upstream the change. Thank you in advance!
Hey @yzh119, @arde171 has resolved the comments, could you help take another look? Thank you!