flash-attention
feat(attention): add Bi-Directional MLM attention model
I want to implement this kind of mask in xformers, for bidirectional masked-language-model (MLM) style attention:
from typing import Tuple

import torch
from typing_extensions import Self  # or `from typing import Self` on Python 3.11+
from xformers.ops import fmha


class BlockDiagoNULLMask(fmha.attn_bias.BlockDiagonalMask):
    """
    Modification of `BlockDiagonalMask` where each block is fully connected
    internally, except for the diagonal elements: every token is masked from
    attending to itself.
    """

    def _create_block_mask(
        self: Self,
        shape: Tuple[int, ...],
        dtype: torch.dtype,
        device: str | torch.device,
    ) -> torch.Tensor:
        # Create a matrix filled with `-inf` on the diagonal and `0` elsewhere
        return torch.zeros(shape, dtype=dtype, device=device).fill_diagonal_(-torch.inf)
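For reference, here is a minimal sketch of the intended masking behaviour in plain PyTorch, independent of xformers internals: a block-diagonal additive bias where each token may attend to every other token in its own block, but not to itself. The helper name `block_diagonal_null_bias`, the block layout, and the shapes are illustrative assumptions, not part of the xformers API.

import torch
import torch.nn.functional as F


def block_diagonal_null_bias(seqlens, dtype=torch.float32, device="cpu"):
    """Additive bias: 0 inside each block, -inf across blocks and on the diagonal."""
    total = sum(seqlens)
    bias = torch.full((total, total), -torch.inf, dtype=dtype, device=device)
    offset = 0
    for n in seqlens:
        # Allow attention within each block...
        bias[offset:offset + n, offset:offset + n] = 0.0
        offset += n
    # ...but mask every token from attending to itself.
    bias.fill_diagonal_(-torch.inf)
    return bias


# Example usage with PyTorch's scaled_dot_product_attention (hypothetical shapes).
q = k = v = torch.randn(1, 4, 6, 8)      # (batch, heads, tokens, head_dim)
bias = block_diagonal_null_bias([2, 4])  # two blocks of lengths 2 and 4
out = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)

Note that a block of length 1 would leave its only entry masked, producing an all -inf attention row, so this construction assumes every block has at least two tokens.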
Hi @TamirFriedman-RecoLabs, are you working on an encoder stack? For example, a generative model for video, music, and so on.
Are you still working on this branch? I'd be happy to hear from you soon!