audiolm-pytorch
Question about 'attention bias not supported for flash attention'
https://github.com/lucidrains/audiolm-pytorch/blob/879f3bd2668b4fc61126b8405bbfdea8fa2c8778/audiolm_pytorch/attend.py#L112
Why not add the bias and the mask together to create an attn_mask of type float and supply it to scaled_dot_product_attention as attn_mask? Isn't that the same as what we do when not using flash attention?
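To make the idea concrete, here is a rough sketch of what I mean (the shapes, the bias, and the padding mask are made up for illustration; whether the flash kernel would actually be picked with a float attn_mask is part of my question):

```python
import torch
import torch.nn.functional as F

b, h, n, d = 2, 8, 128, 64

q = torch.randn(b, h, n, d)
k = torch.randn(b, h, n, d)
v = torch.randn(b, h, n, d)

bias = torch.randn(h, n, n)                    # e.g. a relative position bias
mask = torch.ones(b, n, dtype = torch.bool)    # key padding mask
mask[:, -16:] = False

# turn the boolean mask into an additive float mask
float_mask = torch.zeros(b, 1, 1, n).masked_fill(~mask[:, None, None, :], float('-inf'))

# combine bias and mask into a single float attn_mask, broadcast to (b, h, n, n)
attn_mask = bias.unsqueeze(0) + float_mask

out = F.scaled_dot_product_attention(q, k, v, attn_mask = attn_mask)
```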
Actually, I see that in the soundstorm-pytorch repo you started to do something like that:
https://github.com/lucidrains/soundstorm-pytorch/blob/22d257d6b5241583e84619b7af6a634158aba426/soundstorm_pytorch/attend.py#L96-L99
but you left the assert there, and I also didn't understand why you divide the value by 2?
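For context, my own guess (could be wrong, it isn't stated in the code) is that the division by 2 is a headroom trick: using half of the dtype's most negative value as the mask fill so it can be summed with another large negative term without overflowing to -inf in half precision, e.g.:

```python
import torch

dtype = torch.float16
neg_max = -torch.finfo(dtype).max        # most negative finite fp16 value

a = torch.tensor(neg_max, dtype = dtype)
print(a + a)                             # -inf: the sum overflows fp16

b = torch.tensor(neg_max / 2, dtype = dtype)
print(b + b)                             # -65504., still finite
```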
@amitaie it is not possible to get bias gradients from flash attention