
Flash is not flash

Open · liujuncn opened this issue 1 year ago · 1 comment

I tested x-transformers flash attention vs. HF GPT-2 with a PyTorch Lightning wrapper, but it is slower than transformers.GPT2LMHeadModel with the same config parameters. Not sure where I am going wrong?

[Plot omitted. Purple is the x-transformers flash-attention run.]

```python
import pytorch_lightning as pl
from x_transformers import TransformerWrapper, Decoder, AutoregressiveWrapper


class FlashAttentionLM(pl.LightningModule):

    def __init__(self, config):
        super().__init__()
        model = TransformerWrapper(
            num_tokens = config.vocab_size,
            max_seq_len = config.seq_length,
            attn_layers = Decoder(
                dim = config.embd_size,
                depth = config.n_layer + 1,
                heads = 8,
                attn_flash = True
            )
        )
        self.model = AutoregressiveWrapper(model)
```
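For reference, a minimal sketch of what the Hugging Face baseline in my comparison looks like, assuming the same config object as above (the attribute names vocab_size, seq_length, embd_size, n_layer are the ones used in the snippet; the LightningModule wrapper class name here is just illustrative):

```python
import pytorch_lightning as pl
from transformers import GPT2Config, GPT2LMHeadModel


class HFGPT2LM(pl.LightningModule):

    def __init__(self, config):
        super().__init__()
        # Build a GPT-2 with dimensions matching the x-transformers model above.
        hf_config = GPT2Config(
            vocab_size = config.vocab_size,
            n_positions = config.seq_length,
            n_embd = config.embd_size,
            n_layer = config.n_layer,
            n_head = 8,
        )
        self.model = GPT2LMHeadModel(hf_config)
```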

liujuncn · May 17 '23 05:05