Feature Request: Add Min-P sampling layer
It would be very nice if the library supported Min-P sampling as an alternative to Top-P/Top-K. It has become popular for local LLMs over the past few months because it gives significantly more useful results, or at least feels like it does. More info here: https://www.reddit.com/r/LocalLLaMA/comments/17vonjo/your_settings_are_probably_hurting_your_model_why/
Most other libraries already support it, examples: https://github.com/turboderp/exllamav2/commit/0d436d7ac35c402698c37bb73e534f8e5aad839a https://github.com/ggerganov/llama.cpp/pull/3841
This only requires a single parameter: keep all tokens whose probability is greater than the probability of the most likely token scaled down by some factor (the min-p value).
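A minimal pure-Python sketch of the rule described above, applied to a single already-normalized probability vector. The names `min_p_filter` and `sample` are hypothetical and not part of any library. It keeps tokens sitting exactly at the cutoff (`>=`), which matches the logit-processor snippet later in the thread (it removes only tokens strictly below the cutoff).

```python
import random

def min_p_filter(probs, min_p):
    """Hypothetical helper: keep tokens with p >= min_p * max(probs).

    probs: list of non-negative floats summing to 1.
    Returns a renormalized list where filtered tokens have probability 0.
    """
    threshold = min_p * max(probs)  # cutoff is relative to the top token
    kept = [p if p >= threshold else 0.0 for p in probs]
    total = sum(kept)  # never 0: the top token always survives
    return [p / total for p in kept]

def sample(probs, rng=random):
    """Hypothetical helper: draw one token index from the filtered distribution."""
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]

probs = [0.5, 0.3, 0.15, 0.05]
filtered = min_p_filter(probs, min_p=0.2)  # cutoff = 0.2 * 0.5 = 0.1
token = sample(filtered)
```

With `min_p=0.2` and a top probability of 0.5 the cutoff is 0.1, so only the 0.05 token is dropped; the three survivors are renormalized before sampling. Because the cutoff scales with the top probability, a confident distribution prunes aggressively while a flat one keeps many candidates.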
Forgot to mention: this sampling method should be applied before temperature.
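To illustrate why the order matters, here is a small pure-Python comparison (all function names are hypothetical). Filtering first computes the cutoff on the untempered (T = 1) distribution; applying temperature first flattens the distribution when T > 1, so more tokens clear the same relative cutoff.

```python
import math

def softmax(logits, temperature=1.0):
    # Numerically stable softmax; -inf logits map to probability 0.
    scaled = [l / temperature for l in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def min_p_mask(probs, min_p):
    threshold = min_p * max(probs)
    return [p >= threshold for p in probs]

def min_p_then_temperature(logits, min_p, temperature):
    # Cutoff is computed on the untempered distribution,
    # then temperature reshapes only the surviving tokens.
    keep = min_p_mask(softmax(logits), min_p)
    masked = [l if k else -math.inf for l, k in zip(logits, keep)]
    return softmax(masked, temperature)

def temperature_then_min_p(logits, min_p, temperature):
    # Temperature first: the cutoff is computed on the flattened distribution.
    probs = softmax(logits, temperature)
    keep = min_p_mask(probs, min_p)
    kept = [p if k else 0.0 for p, k in zip(probs, keep)]
    total = sum(kept)
    return [p / total for p in kept]

logits = [2.0, 1.0, 0.0, -1.0]
a = min_p_then_temperature(logits, min_p=0.3, temperature=2.0)
b = temperature_then_min_p(logits, min_p=0.3, temperature=2.0)
```

With `min_p=0.3` and T = 2.0, filtering first keeps 2 of the 4 tokens (their probability ratios to the top token at T = 1 are about 0.37, 0.14 and 0.05), while tempering first keeps 3 (ratios at T = 2 are about 0.61, 0.37 and 0.22). Filtering first keeps the candidate set independent of temperature, so a high temperature cannot pull in low-probability tokens.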
@ncomly-nvidia Seconding the request for min-p: it makes a noticeable difference in production, and it doesn't seem hard to implement compared to some other samplers.
@byshiue any chance of this being added soon? 👀
Most other engines have it now: it's in vLLM, and HF transformers is adding it too.
^^ It has a huge impact in production.
If NVIDIA doesn't want to do it (why? The inference results are much better...), maybe we can add it ourselves? The sampling layers appear to be part of the open-source code:
https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/samplingTopPKernels.h https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/samplingTopPKernels.cu
https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/layers/topPSamplingLayer.h https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/layers/topPSamplingLayer.cu
https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/layers/samplingLayer.cpp#L52
@aikitoria I found it far easier and more performant to implement in decodingCommon.cu, since the same math used for logprobs can be reused to compute the relative threshold for min-p sampling. See my PR in #1536.
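One way to see why logprob-style math helps: the softmax normalizer cancels in the ratio p_i / p_max = exp(l_i - l_max), so the min-p test p_i >= min_p * p_max is equivalent to l_i >= l_max + ln(min_p) and needs no normalization. Below is a hedged pure-Python sketch of that identity with hypothetical function names; it is not the code from #1536.

```python
import math

def min_p_keep_mask_from_logits(logits, min_p):
    # min_p <= 0 disables the filter (log(0) would be undefined).
    if min_p <= 0.0:
        return [True] * len(logits)
    # p_i / p_max = exp(l_i - l_max), so p_i >= min_p * p_max
    # is equivalent to l_i >= l_max + log(min_p).
    cutoff = max(logits) + math.log(min_p)
    return [l >= cutoff for l in logits]

def min_p_keep_mask_from_probs(logits, min_p):
    # Reference version: stable softmax first, then compare probabilities.
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    threshold = min_p * max(probs)
    return [p >= threshold for p in probs]

logits = [2.0, 1.0, 0.0, -1.0]
print(min_p_keep_mask_from_logits(logits, 0.3))  # [True, True, False, False]
```

The logit-space version needs only a max-reduction and one comparison per token, whereas the probability version also needs the exponentials and the sum. That matches the intuition that the threshold can share work already done for logprobs (the max logit), though whether #1536 does exactly this is not confirmed here.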
I still need to verify that things are done in the correct order. I'm still getting to grips with the codebase, but I assumed it should apply min-p before sampling.
Also mentioned in this issue https://github.com/NVIDIA/TensorRT-LLM/issues/1683.
I hope this feature gets added soon!
You can implement this as a logit processor as far as I can tell:
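```python
import torch

def _get_min_p_fn(min_p: float):
    def _fn(logits: torch.Tensor) -> torch.Tensor:
        # Convert logits to probabilities over the vocabulary (last) dimension.
        probs = torch.softmax(logits, dim=-1)
        # Highest probability per row; keepdim=True so it broadcasts against probs.
        top_prob = probs.max(dim=-1, keepdim=True).values
        # The min-p cutoff is relative to the most likely token.
        scaled_min_p = min_p * top_prob
        tokens_to_remove = probs < scaled_min_p
        # Removed tokens get -inf logits so they can never be sampled.
        return logits.masked_fill(tokens_to_remove, -float("inf"))

    return _fn
```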
Isn't that much slower than a proper implementation in the CUDA sampling layers would be?
Just saw the PR above too. That's an interesting approach. I highly doubt NVIDIA would accept it, given that it seems more like a hack, but it gives us something to experiment with...
@aikitoria Yes, since it's computed per request rather than batched (see also https://github.com/NVIDIA/TensorRT-LLM/issues/1681). But if you are already using other logit processors, it might not make as big a difference.
FWIW, the new executor API no longer allows parametrizing logit processors per request (they are fixed at startup), so MinP can't be implemented that way. You have to go lower-level, to GptManager in C++, so bumping this thread @ncomly-nvidia @AdamzNV