Lookback-Lens

No logits_warper applied in guided_decoding method

Open yw-ucsb opened this issue 1 year ago • 0 comments

Hi,

Thanks for sharing this impressive work! This repo has been very helpful in my recent research. However, when I checked the guided decoding method, I found that no logits warper is passed to it. My understanding of the paper is that guided decoding runs multiple rounds of the "sample" method to generate candidate chunks, so its sampling process should ideally stay consistent with the default "sample" method. For reference, the default sampling path in Hugging Face transformers looks like this:

......
elif generation_mode == GenerationMode.SAMPLE:
    # 11. prepare logits warper
    logits_warper = self._get_logits_warper(generation_config)

    # 12. expand input_ids with `num_return_sequences` additional sequences per batch
    input_ids, model_kwargs = self._expand_inputs_for_generation(
        input_ids=input_ids,
        expand_size=generation_config.num_return_sequences,
        is_encoder_decoder=self.config.is_encoder_decoder,
        **model_kwargs,
    )

    # 13. run sample
    return self.sample(
        input_ids,
        logits_processor=logits_processor,
        logits_warper=logits_warper,
        stopping_criteria=stopping_criteria,
        pad_token_id=generation_config.pad_token_id,
        eos_token_id=generation_config.eos_token_id,
        output_scores=generation_config.output_scores,
        return_dict_in_generate=generation_config.return_dict_in_generate,
        synced_gpus=synced_gpus,
        streamer=streamer,
        **model_kwargs,
    )
......
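To make step 11 concrete, here is a minimal, stdlib-only sketch of what a warper chain does to a single sampling step. It imitates the temperature, top-k and top-p warpers that `_get_logits_warper` typically builds from `generation_config`, but `warp_logits`, `softmax` and `sample_token` are hypothetical helpers, not the transformers implementation (the real warpers work on tensors and also honour options such as `min_tokens_to_keep`).

```python
import math
import random


def softmax(logits):
    """Numerically stable softmax; -inf logits get probability 0.0."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def warp_logits(logits, temperature=1.0, top_k=0, top_p=1.0):
    """Hypothetical, simplified temperature -> top-k -> top-p chain."""
    # Temperature: T < 1 sharpens the distribution, T > 1 flattens it.
    warped = [x / temperature for x in logits]
    # Top-k: mask everything below the k-th largest logit (ties are kept).
    if top_k > 0:
        kth = sorted(warped, reverse=True)[min(top_k, len(warped)) - 1]
        warped = [x if x >= kth else -math.inf for x in warped]
    # Top-p: keep the most likely tokens until their total mass reaches top_p.
    if top_p < 1.0:
        probs = softmax(warped)
        order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
        keep, mass = set(), 0.0
        for i in order:
            keep.add(i)
            mass += probs[i]
            if mass >= top_p:
                break
        warped = [x if i in keep else -math.inf for i, x in enumerate(warped)]
    return warped


def sample_token(logits, rng, warp=None):
    """One sampling step: optional warper, then softmax, then a multinomial draw."""
    if warp is not None:
        logits = warp(logits)
    probs = softmax(logits)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


logits = [2.0, 1.0, 0.5, -1.0]
rng = random.Random(0)
top2 = lambda l: warp_logits(l, temperature=0.7, top_k=2)
unwarped = {sample_token(logits, rng) for _ in range(2000)}
warped = {sample_token(logits, rng, warp=top2) for _ in range(2000)}
print(sorted(unwarped))  # any of the four ids may appear
print(sorted(warped))    # only ids 0 and 1 can appear
```

Skipping the warper therefore does not just change randomness: tokens that top-k or top-p would have masked out keep a nonzero probability of being sampled.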

The guided decoding path skips step 11 above (though I did find a corresponding input argument in its implementation):

......
if guiding_classifier is not None:
    # 11. classifier guided decoding
    return self.classifier_guided_decoding(
        input_ids,
        logits_processor=logits_processor,
        stopping_criteria=stopping_criteria,
        pad_token_id=generation_config.pad_token_id,
        eos_token_id=generation_config.eos_token_id,
        output_scores=generation_config.output_scores,
        return_dict_in_generate=generation_config.return_dict_in_generate,
        synced_gpus=synced_gpus,
        extra_prompt_length=extra_prompt_length,
        guiding_classifier=guiding_classifier,
        chunk_size=chunk_size,
        num_candidates=num_candidates,
        conversion_matrix=conversion_matrix,
        feat_layer=feat_layer,
        **model_kwargs,
    )
......
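A minimal sketch of where a warper would sit in chunk-level guided decoding, assuming the procedure described in the paper as summarized above: sample num_candidates chunks of chunk_size tokens, keep the one the classifier prefers, and repeat. `next_logits`, `score`, `sample_chunk` and `guided_decode` are hypothetical stand-ins, not the repo's `classifier_guided_decoding`; `score` plays the role of the guiding classifier, and the toy model ignores its prefix.

```python
import math
import random


def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sample_chunk(next_logits, prefix, chunk_size, rng, warp=None, eos_id=None):
    """Draw up to chunk_size tokens, warping the logits before every draw."""
    chunk = []
    for _ in range(chunk_size):
        logits = next_logits(prefix + chunk)
        if warp is not None:
            logits = warp(logits)  # the step this issue reports as skipped
        probs = softmax(logits)
        tok = rng.choices(range(len(probs)), weights=probs, k=1)[0]
        chunk.append(tok)
        if tok == eos_id:
            break
    return chunk


def guided_decode(next_logits, score, prompt, *, chunk_size, num_candidates,
                  max_new_tokens, rng, warp=None, eos_id=None):
    """Repeatedly sample candidate chunks and keep the best-scoring one."""
    out = list(prompt)
    while len(out) - len(prompt) < max_new_tokens:
        candidates = [
            sample_chunk(next_logits, out, chunk_size, rng, warp, eos_id)
            for _ in range(num_candidates)
        ]
        best = max(candidates, key=lambda c: score(out, c))
        best = best[: max_new_tokens - (len(out) - len(prompt))]  # respect the budget
        out.extend(best)
        if eos_id is not None and best[-1] == eos_id:
            break
    return out[len(prompt):]


# Toy usage: a fixed 5-token "model" and a classifier stand-in that prefers even ids.
toy_model = lambda prefix: [1.0, 0.5, 0.0, -0.5, -1.0]
prefer_even = lambda context, chunk: -sum(t % 2 for t in chunk)
greedy = lambda l: [x if x == max(l) else -math.inf for x in l]
rng = random.Random(0)
print(guided_decode(toy_model, prefer_even, [7], chunk_size=3, num_candidates=4,
                    max_new_tokens=8, rng=rng))
print(guided_decode(toy_model, prefer_even, [7], chunk_size=3, num_candidates=4,
                    max_new_tokens=8, rng=rng, warp=greedy))  # [0, 0, 0, 0, 0, 0, 0, 0]
```

In this sketch, calling `guided_decode` with `warp=None` corresponds to the behaviour described in this issue: every candidate chunk is drawn from the unwarped distribution, so temperature, top-k or top-p settings would have no effect on which candidates the classifier gets to choose from.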

Am I missing something here? I'd appreciate any help with this question. Thanks again!

yw-ucsb, Sep 21 '24 04:09