Lookback-Lens
No logits_warper applied in guided_decoding method
Hi,
Thanks for sharing this impressive work! This repo has been very helpful in my recent research. However, when I checked the guided decoding method, I noticed that no logits warper is passed to it. My understanding of the paper is that guided decoding calls the `sample` method multiple times to generate candidate chunks, so the sampling process should ideally stay consistent with the default `sample` path. The default sampling branch in Hugging Face's `generate` looks like this:
```python
# ...
elif generation_mode == GenerationMode.SAMPLE:
    # 11. prepare logits warper
    logits_warper = self._get_logits_warper(generation_config)

    # 12. expand input_ids with `num_return_sequences` additional sequences per batch
    input_ids, model_kwargs = self._expand_inputs_for_generation(
        input_ids=input_ids,
        expand_size=generation_config.num_return_sequences,
        is_encoder_decoder=self.config.is_encoder_decoder,
        **model_kwargs,
    )

    # 13. run sample
    return self.sample(
        input_ids,
        logits_processor=logits_processor,
        logits_warper=logits_warper,
        stopping_criteria=stopping_criteria,
        pad_token_id=generation_config.pad_token_id,
        eos_token_id=generation_config.eos_token_id,
        output_scores=generation_config.output_scores,
        return_dict_in_generate=generation_config.return_dict_in_generate,
        synced_gpus=synced_gpus,
        streamer=streamer,
        **model_kwargs,
    )
# ...
```
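To make the concern concrete, here is a standalone illustration (pure Python, no `transformers` dependency) of what logits warpers do, approximating the behavior of HF's `TemperatureLogitsWarper` and `TopKLogitsWarper`. The function names here are my own; the point is that if guided decoding skips step 11, candidates are drawn from the raw softmax rather than the user-configured sampling distribution:

```python
import math

def softmax(logits):
    # Numerically stable softmax over a list of floats.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    s = sum(exps)
    return [e / s for e in exps]

def warp_temperature(logits, temperature):
    # Divides logits by temperature (temperature < 1.0 sharpens the distribution).
    return [x / temperature for x in logits]

def warp_top_k(logits, k, filter_value=-math.inf):
    # Keeps only the k largest logits; the rest receive zero probability mass.
    threshold = sorted(logits, reverse=True)[k - 1]
    return [x if x >= threshold else filter_value for x in logits]

logits = [2.0, 1.0, 0.5, -1.0]
raw = softmax(logits)                                            # no warping (step 11 skipped)
warped = softmax(warp_top_k(warp_temperature(logits, 0.7), 2))   # temperature=0.7, top_k=2

print(raw)     # all four tokens keep nonzero probability
print(warped)  # tail tokens zeroed out, top token sharpened
```

With warping, the two tail tokens are excluded entirely and the top token's probability rises, so repeated rounds of candidate sampling would explore a noticeably different set of chunks than unwarped sampling.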
The guided decoding branch does not perform step 11 above (though I did find a corresponding input argument in the implementation of `classifier_guided_decoding` itself):
```python
# ...
if guiding_classifier is not None:
    # 11. classifier guided decoding
    return self.classifier_guided_decoding(
        input_ids,
        logits_processor=logits_processor,
        stopping_criteria=stopping_criteria,
        pad_token_id=generation_config.pad_token_id,
        eos_token_id=generation_config.eos_token_id,
        output_scores=generation_config.output_scores,
        return_dict_in_generate=generation_config.return_dict_in_generate,
        synced_gpus=synced_gpus,
        extra_prompt_length=extra_prompt_length,
        guiding_classifier=guiding_classifier,
        chunk_size=chunk_size,
        num_candidates=num_candidates,
        conversion_matrix=conversion_matrix,
        feat_layer=feat_layer,
        **model_kwargs,
    )
# ...
```
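For what it's worth, the change I would have expected mirrors step 11 of the `SAMPLE` branch. This is only a hypothetical sketch (I may well be missing a reason the warper is omitted on purpose), assuming `classifier_guided_decoding` accepts a `logits_warper` keyword as its signature suggests:

```python
# Hypothetical sketch, not the repo's actual code:
if guiding_classifier is not None:
    # prepare logits warper, as in the default SAMPLE branch (step 11)
    logits_warper = self._get_logits_warper(generation_config)
    return self.classifier_guided_decoding(
        input_ids,
        logits_processor=logits_processor,
        logits_warper=logits_warper,  # pass it through to candidate sampling
        ...  # remaining arguments unchanged
    )
```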
I'm wondering if there is something I've missed, and I'd appreciate it if you could help me with this question. Thanks again!