Retrieval_Head

unexpected keyword argument 'block_list'

Open yuhuanyuan opened this issue 1 year ago • 3 comments

Running `python3 needle_in_haystack_with_mask.py` fails with:

`TypeError: Qwen2ForCausalLM.forward() got an unexpected keyword argument 'block_list'`

yuhuanyuan, Jun 14 '24 03:06

@yuhuanyuan I hit exactly the same issue with the default setting, using `from source.modeling_qwen2 import Qwen2ForCausalLM`, and got the same error: `TypeError: Qwen2ForCausalLM.forward() got an unexpected keyword argument 'block_list'`.

If you look at faiss_attn/source/modeling_qwen2.py, you'll notice that Qwen2ForCausalLM does not accept the masking argument `block_list`, which gives the indices of the blocks to be masked. faiss_attn/source/modeling_llama.py, on the other hand, does support `block_list` in its LlamaForCausalLM, so I guess the author forgot to add this feature to Qwen2ForCausalLM.
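To see why the call fails, here is a minimal sketch in plain Python. The classes `QwenLikeModel` and `LlamaLikeModel` and the helper `call_with_mask` are hypothetical stand-ins, not the repo's code: a `forward()` whose signature does not declare `block_list` (and has no `**kwargs`) raises this kind of `TypeError` as soon as the caller passes `block_list=...`, while one that declares the parameter accepts it.

```python
# Hypothetical stand-ins for illustration; not the classes in faiss_attn/source.

class QwenLikeModel:
    """Mimics a forward() that does not declare block_list."""

    def forward(self, input_ids, attention_mask=None):
        return {"input_ids": input_ids, "masked": None}


class LlamaLikeModel:
    """Mimics a forward() that declares block_list explicitly."""

    def forward(self, input_ids, attention_mask=None, block_list=None):
        # block_list: indices of the blocks to mask (None means mask nothing)
        return {"input_ids": input_ids, "masked": block_list}


def call_with_mask(model, block_list):
    """Call forward() the way a masking script would, returning any TypeError as text."""
    try:
        return model.forward([1, 2, 3], block_list=block_list)
    except TypeError as err:
        return f"TypeError: {err}"


print(call_with_mask(QwenLikeModel(), [[0, 1]]))   # TypeError mentioning 'block_list'
print(call_with_mask(LlamaLikeModel(), [[0, 1]]))  # block_list is accepted
```

The exact wording of the error prefix varies by Python version (3.10+ includes the class name), but the "unexpected keyword argument" part is the same.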

To solve the problem, simply use Qwen2Model instead (`from source.modeling_qwen2 import Qwen2Model`); that makes it work for

https://github.com/nightdessert/Retrieval_Head/blob/2d9fb9d72eaa685acf6a4d29d0115a682c3c05ab/needle_in_haystack_with_mask.py#L342

However, I don't know whether this is just a workaround, or whether Qwen2Model can be used instead of Qwen2ForCausalLM to reproduce the results. Please correct me if I'm wrong. @nightdessert

zhouliang-yu, Jun 14 '24 08:06

Hey guys, sorry for the wait. I have updated `source/modeling_qwen2.py`; hopefully that fixes it.

nightdessert, Jun 19 '24 06:06

@nightdessert The kwargs aren't being passed through to `forward_torch`. We had to make a few modifications so that `block_list` actually reaches `forward_torch`. The following self-attention code:

```python
# Self Attention
if attn_mode == "flash":
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
        hidden_states=hidden_states,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_value=past_key_value,
        output_attentions=output_attentions,
        use_cache=use_cache,
    )
else:
    hidden_states, inspect, self_attn_weights, present_key_value = self.self_attn.forward_torch(
        hidden_states=hidden_states,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_value=past_key_value,
        output_attentions=output_attentions,
        use_cache=use_cache,
    )
```

needs to be changed to the following (note the added `**kwargs` in both calls):

```python
# Self Attention
if attn_mode == "flash":
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
        hidden_states=hidden_states,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_value=past_key_value,
        output_attentions=output_attentions,
        use_cache=use_cache,
        **kwargs,  # pass extra arguments such as block_list through
    )
else:
    hidden_states, inspect, self_attn_weights, present_key_value = self.self_attn.forward_torch(
        hidden_states=hidden_states,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_value=past_key_value,
        output_attentions=output_attentions,
        use_cache=use_cache,
        **kwargs,  # without this, block_list never reaches forward_torch
    )
```
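The effect of threading `**kwargs` through can be sketched in plain Python. Everything here is hypothetical: `Attention`, `DecoderLayer`, `pass_kwargs` and `run` are illustrative names, not the repo's code. The masking rule (zeroing the output of heads listed as `[layer, head]` pairs in `block_list`) is only an assumption made for illustration; the repo's `forward_torch` may apply the mask differently. The point is that a layer which accepts `**kwargs` but does not forward them silently drops `block_list`, so nothing gets masked and no error is raised.

```python
# Hypothetical sketch: extra keyword arguments must be forwarded at every level
# for block_list to reach the attention function. Names are illustrative only.

class Attention:
    def __init__(self, layer_idx, num_heads):
        self.layer_idx = layer_idx
        self.num_heads = num_heads

    def forward_torch(self, hidden_states, block_list=None, **kwargs):
        # Toy model: every head simply outputs hidden_states.
        head_outputs = [hidden_states] * self.num_heads
        # Assumed masking rule: block_list holds [layer, head] pairs to zero out.
        for layer, head in block_list or []:
            if layer == self.layer_idx:
                head_outputs[head] = 0.0
        return head_outputs


class DecoderLayer:
    def __init__(self, layer_idx, num_heads, pass_kwargs):
        self.self_attn = Attention(layer_idx, num_heads)
        self.pass_kwargs = pass_kwargs  # True mimics the patched version

    def forward(self, hidden_states, **kwargs):
        if self.pass_kwargs:
            # Patched: block_list (and any other extras) flow into forward_torch.
            return self.self_attn.forward_torch(hidden_states, **kwargs)
        # Unpatched: kwargs are accepted here but silently dropped.
        return self.self_attn.forward_torch(hidden_states)


def run(pass_kwargs, block_list):
    layers = [DecoderLayer(i, num_heads=2, pass_kwargs=pass_kwargs) for i in range(2)]
    return [layer.forward(1.0, block_list=block_list) for layer in layers]


print(run(pass_kwargs=False, block_list=[[1, 0]]))  # nothing is masked
print(run(pass_kwargs=True, block_list=[[1, 0]]))   # layer 1, head 0 is masked
```

Because the unpatched path fails silently rather than raising, a quick check like this (masked vs. unmasked outputs differing) is a cheap way to confirm the flag is actually wired through.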

shaswatpatel123, Apr 28 '25 18:04