fastertransformer_backend
How can I get the logits of all tokens in the vocabulary at each step?
Hey, thanks for providing such a great tool! I noticed that gpt_guide.md mentions a parameter, output_log_probs. It records the log probability of the logits at each sampling step, and its shape is [batch_size, beam_width, request_output_seq_len].
Would it be possible to add a parameter that records the logits over the whole vocabulary at every output position, so that I can inspect the top-k outputs at any position?
Its shape could be [batch_size, beam_width, request_output_seq_len, vocab_size].
That buffer would be too large, and we don't have plans to add it right now.
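To see why, here is a back-of-the-envelope size estimate for the proposed [batch_size, beam_width, request_output_seq_len, vocab_size] buffer. The batch size, beam width, and sequence length below are illustrative assumptions, 50257 is the GPT-2 vocabulary size, and float32 storage is assumed.

```python
def logits_buffer_bytes(batch_size: int, beam_width: int,
                        output_seq_len: int, vocab_size: int,
                        bytes_per_elem: int = 4) -> int:
    """Size in bytes of a dense [batch, beam, seq, vocab] logits buffer."""
    return batch_size * beam_width * output_seq_len * vocab_size * bytes_per_elem


if __name__ == "__main__":
    # Illustrative values only: batch 8, beam 4, 1024 output tokens,
    # GPT-2 vocabulary (50257), float32 elements.
    size = logits_buffer_bytes(8, 4, 1024, 50257)
    print(f"{size} bytes = {size / 2**30:.2f} GiB")
```

With these numbers the full-vocabulary buffer is about 6.1 GiB, while a [batch_size, beam_width, request_output_seq_len] log-prob output of the same precision is only 128 KiB.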
You can try the ParallelGptDecoderOp, which contains only the transformer blocks of GPT.
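As a rough sketch of what computing the logits yourself could look like: a GPT-style model typically turns a hidden state into vocabulary logits by applying a final layer norm and multiplying by the (often tied) word-embedding matrix. The function names, the toy dimensions, and the assumption that the decoder op returns the hidden state before the final layer norm are illustrative assumptions, not the ParallelGptDecoderOp interface, so check the op's actual output and your model's weights. In PyTorch this is one layer norm followed by one matmul; plain Python is used here only to keep the example self-contained.

```python
import math
from typing import List, Sequence


def layer_norm(x: Sequence[float], gamma: Sequence[float],
               beta: Sequence[float], eps: float = 1e-5) -> List[float]:
    """Standard layer normalization over the hidden dimension."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    inv_std = 1.0 / math.sqrt(var + eps)
    return [(v - mean) * inv_std * g + b for v, g, b in zip(x, gamma, beta)]


def hidden_to_logits(hidden: Sequence[float],
                     embedding: Sequence[Sequence[float]],
                     gamma: Sequence[float],
                     beta: Sequence[float]) -> List[float]:
    """Project one hidden state [hidden_size] to logits [vocab_size].

    Assumes (hypothetically) an LM head tied to the word embedding table
    of shape [vocab_size, hidden_size], i.e. logits = LayerNorm(h) @ E^T.
    """
    normed = layer_norm(hidden, gamma, beta)
    # One dot product per vocabulary row.
    return [sum(h * e for h, e in zip(normed, row)) for row in embedding]


if __name__ == "__main__":
    # Toy sizes: hidden_size = 4, vocab_size = 3 (purely illustrative).
    hidden = [0.5, -1.0, 2.0, 0.0]
    embedding = [[1.0, 0.0, 0.0, 0.0],
                 [0.0, 1.0, 0.0, 0.0],
                 [0.0, 0.0, 1.0, 1.0]]
    gamma, beta = [1.0] * 4, [0.0] * 4
    print(hidden_to_logits(hidden, embedding, gamma, beta))
```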
If you use beam search, output_log_probs already contains the logits.
For sampling, we don't support this feature at the moment.
Thanks for your reply!
If I don't need all the logits for every token, only the top-k logits at a given position together with the tokens they correspond to, how can I get those returned?
My requirements are as follows:
- I use a GPT model.
- I input a prompt.
- I want to know the top-k candidates for the next predicted token and their logits.
Thanks again for your reply!
If you are interested in modifying the source code, you can change the beam search / sampling kernels to save the logits of the top-k tokens. A relatively simpler way is to use the decoder op, which outputs the results of the transformer blocks; you can then handle the logits on the PyTorch side.
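For the stated requirement (top-k candidates for the next token and their logits), once the full logit vector of the last prompt position is available on the host, selecting the top-k is a small step. This sketch uses heapq.nlargest over a plain list; in PyTorch the equivalent is torch.topk over the last dimension. The log-softmax column is an optional extra for comparison with output_log_probs, and the function name topk_next_tokens is hypothetical.

```python
import heapq
import math
from typing import List, Sequence, Tuple


def topk_next_tokens(logits: Sequence[float],
                     k: int) -> List[Tuple[int, float, float]]:
    """Return (token_id, logit, log_prob) for the k highest logits.

    `logits` is the full vocabulary logit vector for one position,
    e.g. the last position of the prompt. log_prob is the log-softmax
    over the whole vocabulary, computed in a numerically stable way.
    """
    max_logit = max(logits)
    log_z = max_logit + math.log(sum(math.exp(x - max_logit) for x in logits))
    # Indices of the k largest logits, highest first.
    best = heapq.nlargest(k, range(len(logits)), key=lambda i: logits[i])
    return [(i, logits[i], logits[i] - log_z) for i in best]


if __name__ == "__main__":
    toy_logits = [1.0, 3.5, -2.0, 3.0, 0.5]  # illustrative vocab of 5
    for token_id, logit, log_prob in topk_next_tokens(toy_logits, k=3):
        print(token_id, logit, round(log_prob, 4))
```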
Thank you for your reply. If I want to modify the source code to achieve this, which files do I need to change? Any advice would be appreciated. Thank you very much.
For sampling, you need to modify https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/layers/DynamicDecodeLayer.cc, some files in https://github.com/NVIDIA/FasterTransformer/tree/main/src/fastertransformer/layers/sampling_layers, and the related kernels. You also need to modify the API of the custom op for the framework you use.
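If you go the kernel route, one way to keep the extra output small is to store only k candidates per step, i.e. a buffer of shape [batch_size, beam_width, request_output_seq_len, k] instead of [..., vocab_size]. The following CPU reference sketch shows the bookkeeping a modified top-k sampling step might do for a single sequence. It is not FasterTransformer code: sample_step, generate_with_topk_record, the toy model, and the record layout are all hypothetical. A reference like this can help check the top-k ids and logits a modified CUDA kernel writes out.

```python
import heapq
import math
import random
from typing import Callable, List, Sequence, Tuple

TopK = List[Tuple[int, float]]  # (token_id, logit) pairs, highest first


def sample_step(logits: Sequence[float], k: int,
                rng: random.Random) -> Tuple[int, TopK]:
    """Top-k sampling for one sequence at one step.

    Returns the sampled token id plus the k (token_id, logit) candidates
    it was drawn from, which is what a modified layer could write out.
    """
    top = heapq.nlargest(k, range(len(logits)), key=lambda i: logits[i])
    m = logits[top[0]]
    weights = [math.exp(logits[i] - m) for i in top]  # softmax over top-k
    token = rng.choices(top, weights=weights, k=1)[0]
    return token, [(i, logits[i]) for i in top]


def generate_with_topk_record(
        next_logits: Callable[[List[int]], Sequence[float]],
        prompt: List[int], steps: int, k: int,
        seed: int = 0) -> Tuple[List[int], List[TopK]]:
    """Generate `steps` tokens and record the top-k candidates per step.

    The record has shape [steps, k]; a batched kernel version would use
    [batch_size, beam_width, request_output_seq_len, k].
    """
    rng = random.Random(seed)
    tokens = list(prompt)
    record: List[TopK] = []
    for _ in range(steps):
        token, top = sample_step(next_logits(tokens), k, rng)
        tokens.append(token)
        record.append(top)
    return tokens[len(prompt):], record


if __name__ == "__main__":
    vocab_size = 6

    def toy_model(tokens: List[int]) -> List[float]:
        # Hypothetical stand-in for the model: favors (last token + 1).
        favored = (tokens[-1] + 1) % vocab_size
        return [2.0 if i == favored else float(-i) for i in range(vocab_size)]

    out, record = generate_with_topk_record(toy_model, [0], steps=3, k=2)
    for step, (tok, top) in enumerate(zip(out, record)):
        print(step, tok, top)
```

Storing (token_id, logit) pairs per step keeps the output proportional to k rather than to the vocabulary size, which sidesteps the buffer-size concern raised earlier in the thread.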