TensorRT-LLM
Question: Return log probabilities
Trying out T5 with the Python backend: https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/enc_dec/run.py#L484
I see SamplingConfig has output_log_probs https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/runtime/generation.py#L355.
But the return dict does not include the log probabilities: https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/runtime/generation.py#L2515.
Is there any other way to get the log probabilities?
Hi @sindhuvahinis, this is a general missing return in generation.py, not something specific to T5 or enc-dec. In summary, returning log probs and cumulative log probs are both supported, as you saw in SamplingConfig, but get_outputs_dict() doesn't include them.
I recommend you manually add output['log_probs'] = self.log_probs (and similarly self.cum_log_probs if you need it) in get_outputs_dict(). And don't forget to set return_dict=True when calling generate() inside enc_dec/run.py.
Please let me know if this works.
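For reference, a minimal sketch of that workaround; the surrounding structure of get_outputs_dict() and the exact dict keys here are illustrative assumptions, not the actual generation.py implementation:

# sketch: inside GenerationSession.get_outputs_dict() in tensorrt_llm/runtime/generation.py (assumed structure)
def get_outputs_dict(self, output_ids):
    outputs = {}
    outputs['output_ids'] = output_ids
    # workaround: expose the per-token log probabilities
    # (requires sampling_config.output_log_probs = True)
    outputs['log_probs'] = self.log_probs
    # optional, if you also want cumulative log probabilities
    outputs['cum_log_probs'] = self.cum_log_probs
    return outputs

The call site then needs return_dict=True (as noted above) so that this dict, rather than just output_ids, is returned.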
Hey @symphonylyh, thanks for the reply. Yes, I understand it's a generic setting, and whether I use return_dict or not doesn't matter; I just want the log_probs of the generated output tokens. I tried accessing log_probs directly, but it is always zero. Here is my sample code:
encoder_input_ids = encoder_input_ids.to(self.device)
decoder_input_ids = decoder_input_ids.to(self.device)

# encoder run
encoder_input_ids, encoder_input_lengths, encoder_max_input_length = self.process_input(
    encoder_input_ids, self.encoder_model_config.remove_input_padding,
    pad_token_id)
encoder_output = self.encoder_run(encoder_input_ids,
                                  encoder_input_lengths,
                                  encoder_max_input_length,
                                  debug_mode=debug_mode)

# decoder run
decoder_input_ids, decoder_input_lengths, decoder_max_input_length = self.process_input(
    decoder_input_ids, self.decoder_model_config.remove_input_padding,
    pad_token_id)

# generation config
sampling_config = SamplingConfig(end_id=eos_token_id,
                                 pad_id=pad_token_id,
                                 **kwargs)
sampling_config.output_log_probs = True

# decoder autoregressive generation
self.decoder_session.setup(
    decoder_input_lengths.size(0),
    decoder_max_input_length,
    max_new_tokens,
    num_beams,
    max_attention_window_size=None,
    encoder_max_input_length=encoder_max_input_length,
)
torch.cuda.synchronize()

output_ids = self.decoder_session.decode(
    decoder_input_ids,
    decoder_input_lengths,
    sampling_config,
    encoder_output=encoder_output,
    encoder_input_lengths=encoder_input_lengths,
)
torch.cuda.synchronize()

if self.runtime_rank == 0:
    # [max_new_tokens, batch_size, num_beams] -> [batch_size, max_new_tokens, num_beams]
    log_probs = self.decoder_session.log_probs.cpu().transpose(0, 1).numpy()
    logging.info(f"shape {log_probs.shape}")
    logging.info(f"log probs {log_probs[0]}")

return output_ids
Output I got:
shape (1, 256, 1)
log probs [[0.]
[0.]
[0.]
[0.]
[0.]
[0.]
[0.]
...
[0.]]
Hello @symphonylyh, I printed out self.log_probs after each handle_per_step call. For each generated token it is a zero tensor, and it does not look like log_probs is updated anywhere.
I also printed log_probs after generation was complete; it is still all zeros.
Hi @sindhuvahinis, I reproduced your observation, and confirmed it's a bug.
I have a full fix internally, but it involves too many small changes to communicate here. As a temporary workaround, can you try the following changes:
- move this line to the end of the setup() call
- for beam_width > 1, change this line to:
  dynamic_decode_outputs.output_log_probs = outputs.output_log_probs ? outputs.output_log_probs : outputs.output_log_probs_tiled;
- for beam_width = 1, add the following code before this line:
if (outputs.output_log_probs)
{
    TLLM_CHECK(0 <= mCyclicStep && mCyclicStep < max_seq_len);
    Tensor& output_log_probs = outputs.output_log_probs.value();
    size_t step_offset = mCyclicStep * batch_size * beam_width;
    decode_outputs.output_log_probs
        = output_log_probs.slice({1, local_batch_size * beam_width}, step_offset + local_batch_offset);
}
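After rebuilding with that patch, a quick sanity check (a sketch reusing decoder_session from the sample code earlier in this thread, not an official API):

# shape: [max_new_tokens, batch_size, num_beams] -> [batch_size, max_new_tokens, num_beams]
log_probs = self.decoder_session.log_probs.cpu().transpose(0, 1).numpy()
# with the fix applied, the generated positions should no longer be all zero
assert (log_probs != 0.0).any(), "log_probs are still all zero; the patch may not have taken effect"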
Thank you @symphonylyh. Will check this out.
@symphonylyh Are you sure v0.8.0 fixed this issue? I tried with 0.8.0 and I still don't see log_probs; they are all set to zero.
+1
I also can't see log_probs returning non-zero values in v0.8.0.
@symphonylyh I see your suggested code is in v0.8.0, but does it work for you? Could you confirm?