FastChat
Setting attention mask based on past-key-value shape in inference.py
Hi, in the function generate_stream in inference.py, the attention mask is set up as:
attention_mask = torch.ones(
    1, past_key_values[0][0].shape[-2] + 1, device=device)
It seems that the mask length should be past_key_values[0][0].shape[2] + 1, since the third dimension of past_key_values is seq_len.
Alternatively, attention_mask could simply be set to None.
Update: The shape of past_key_values differs across models. For example, in LLaMA:
past_key_values[0][0] has shape [batch_size, num_head, seq_len, head_dim], i.e. a 4-D tensor.
In Bloom:
past_key_values[0][0] has shape [batch_size * num_head, head_dim, seq_len], i.e. a 3-D tensor.
Hence, in generate_stream in inference.py, the attention mask length (i.e., past_key_values[0][0].shape[-2] + 1) is wrong for the Bloom model, since shape[-2] there is head_dim rather than seq_len.
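A quick way to see the mismatch (a standalone toy example with made-up sizes, not FastChat code):

import torch

batch_size, num_head, seq_len, head_dim = 1, 32, 10, 128

# LLaMA-style cache entry: [batch_size, num_head, seq_len, head_dim]
llama_kv = torch.zeros(batch_size, num_head, seq_len, head_dim)
# Bloom-style cache entry: [batch_size * num_head, head_dim, seq_len]
bloom_kv = torch.zeros(batch_size * num_head, head_dim, seq_len)

print(llama_kv.shape[-2])  # 10  -> seq_len, so shape[-2] + 1 gives the right mask length
print(bloom_kv.shape[-2])  # 128 -> head_dim, so shape[-2] + 1 gives the wrong mask length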
I think it would be more flexible not to pass the attention mask (i.e., attention_mask=None), so that more models are supported.
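As a rough sketch of that proposal (a paraphrase of one decoding step, not the exact generate_stream code; the function name and arguments here are illustrative):

import torch

def decode_step(model, token, past_key_values, device):
    # Paraphrased decoding step: with attention_mask=None, a Hugging Face
    # model constructs an all-ones mask of the right length internally,
    # so the code no longer depends on the layout of past_key_values.
    out = model(
        input_ids=torch.as_tensor([[token]], device=device),
        use_cache=True,
        attention_mask=None,
        past_key_values=past_key_values,
    )
    return out.logits, out.past_key_values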
Could you send a pull request to fix it?
Of course, I will post a PR later.