
feat(attention): add Bi-Directional MLM attention model

TamirFriedman-RecoLabs opened this issue 1 year ago · 1 comment

I want to implement this kind of mask in xformers, to support bidirectional masked-language-model (MLM) style attention:

from typing import Tuple

import torch
from xformers.ops import fmha


class BlockDiagoNULLMask(fmha.attn_bias.BlockDiagonalMask):
    """
    Modification of `BlockDiagonalMask` where positions within a block are
    fully connected, except that each position is masked from attending
    to itself (the diagonal elements).
    """

    def _create_block_mask(
        self,
        shape: Tuple[int, ...],
        dtype: torch.dtype,
        device: str | torch.device,
    ) -> torch.Tensor:
        # Per-block bias: `0` everywhere (attend freely) and `-inf` on
        # the diagonal (each position is masked from itself).
        return torch.zeros(shape, dtype=dtype, device=device).fill_diagonal_(-torch.inf)
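
For reference, a minimal sketch of how such a subclass might be exercised. `from_seqlens` and `materialize` are the existing `BlockDiagonalMask` API; the sequence lengths below are made-up example values:

import torch

# Pack three segments of lengths 2, 3 and 2 into one sequence of total
# length 7. `from_seqlens` is a classmethod, so it returns an instance
# of the subclass.
mask = BlockDiagoNULLMask.from_seqlens([2, 3, 2])

# Materialize a dense bias for inspection: `0` within each block, `-inf`
# on the diagonal and everywhere outside the blocks.
dense = mask.materialize((7, 7), dtype=torch.float32, device="cpu")
print(dense)

One caveat: `_create_block_mask` only feeds `materialize`, which xformers describes as a debugging helper; the fused attention kernels read the block layout metadata directly, so the masked diagonal would presumably not take effect on the fast path without dedicated kernel support, which is what this request is about.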

TamirFriedman-RecoLabs · Dec 12 '23

Hi @TamirFriedman-RecoLabs, are you working on an encoder stack? For example, a generative model for video, music, and so on?

Are you still working on this branch? Happy to hear from you soon!