
Setting attention mask based on past-key-value shape in inference.py

Open · ShomyLiu opened this issue on Apr 11, 2023 · 3 comments

Hi, in the function generate_stream in inference.py, the attention mask is set up as:

    attention_mask = torch.ones(
        1, past_key_values[0][0].shape[-2] + 1, device=device)

It seems that the attention mask size should be past_key_values[0][0].shape[2] + 1, since the third dimension of the past key values is the seq_len.

Or just set attention_mask to None.
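
Concretely, the two suggested variants might look like this (a minimal sketch; model, token, past_key_values, and device are assumed from the surrounding generate_stream loop, not shown here):

    # Variant 1: build the mask from the sequence dimension at positive index 2
    attention_mask = torch.ones(
        1, past_key_values[0][0].shape[2] + 1, device=device)
    out = model(input_ids=torch.as_tensor([[token]], device=device),
                use_cache=True,
                attention_mask=attention_mask,
                past_key_values=past_key_values)

    # Variant 2: pass no mask and let the model default to attending everywhere
    out = model(input_ids=torch.as_tensor([[token]], device=device),
                use_cache=True,
                attention_mask=None,
                past_key_values=past_key_values)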

ShomyLiu commented on Apr 11, 2023

Update: the shape of the past key values differs across models. For example, in LLaMA:

past_key_value[0][0].size = [batch_size, num_head, seq_len, head_dim], a 4-D tensor.

In Bloom:

past_key_value[0][0].size = [batch_size * num_head, head_dim, seq_len], a 3-D tensor.

Hence, in generate_stream in inference.py, the attention mask size (i.e., past_key_values[0][0].shape[-2] + 1) would be wrong for the Bloom model, where shape[-2] is head_dim rather than seq_len.
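
A standalone check with dummy tensors (tiny sizes chosen arbitrarily, layouts as described above) makes the mismatch visible:

    import torch

    batch_size, num_head, seq_len, head_dim = 1, 4, 7, 8

    # LLaMA-style cache entry: 4-D, seq_len on dim 2 (== dim -2)
    llama_key = torch.zeros(batch_size, num_head, seq_len, head_dim)
    print(llama_key.shape[-2])  # 7 -> the sequence length, as intended

    # Bloom-style cache entry: 3-D, seq_len on the last dim
    bloom_key = torch.zeros(batch_size * num_head, head_dim, seq_len)
    print(bloom_key.shape[-2])  # 8 -> head_dim, so the mask would be mis-sized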

I think it would be more flexible not to pass the attention mask at all (attention_mask=None), so the code fits a wider range of models.
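
Another shape-agnostic option, sketched here only as a hypothetical alternative (the names and loop structure are assumed for illustration, not taken from inference.py), is to track the sequence length in the generation loop itself rather than reading it off the cache:

    # Hypothetical fragment: count processed tokens instead of inspecting the cache
    seq_len = len(prompt_token_ids)  # number of tokens already in past_key_values
    for _ in range(max_new_tokens):
        attention_mask = torch.ones(1, seq_len + 1, device=device)
        out = model(input_ids=torch.as_tensor([[token]], device=device),
                    use_cache=True,
                    attention_mask=attention_mask,
                    past_key_values=past_key_values)
        past_key_values = out.past_key_values
        token = int(torch.argmax(out.logits[0, -1]))  # greedy pick for illustration
        seq_len += 1  # one more token now lives in the cache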

ShomyLiu commented on Apr 11, 2023

Could you send a pull request to fix it?

merrymercy commented on Apr 11, 2023

Of course, I will post a PR later.

ShomyLiu commented on Apr 12, 2023