x-transformers
Flash is not flash
I tested flash attention against HF GPT2 using a PyTorch Lightning wrapper, but it is slower than transformers.GPT2LMHeadModel with the same config parameters. Not sure where I am going wrong?
Purple is the x-transformers flash attention run.
```python
import pytorch_lightning as pl
from x_transformers import TransformerWrapper, Decoder, AutoregressiveWrapper

class FlashAttentionLM(pl.LightningModule):
    def __init__(self, config):
        super().__init__()
        model = TransformerWrapper(
            num_tokens = config.vocab_size,
            max_seq_len = config.seq_length,
            attn_layers = Decoder(
                dim = config.embd_size,
                depth = config.n_layer + 1,  # one layer more than config.n_layer
                heads = 8,
                attn_flash = True  # enable flash attention
            )
        )
        # AutoregressiveWrapper adds target shifting and the LM loss
        self.model = AutoregressiveWrapper(model)
```
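For context, a minimal sketch of what the HF GPT2 baseline could look like under the same config (the GPT2Config field mapping and training step here are assumptions, not the exact code I ran):

```python
import pytorch_lightning as pl
from transformers import GPT2Config, GPT2LMHeadModel

class HFGPT2LM(pl.LightningModule):
    def __init__(self, config):
        super().__init__()
        # hypothetical mapping of the same config fields onto GPT2Config
        hf_config = GPT2Config(
            vocab_size = config.vocab_size,
            n_positions = config.seq_length,
            n_embd = config.embd_size,
            n_layer = config.n_layer,
            n_head = 8,
        )
        self.model = GPT2LMHeadModel(hf_config)

    def training_step(self, batch, batch_idx):
        # GPT2LMHeadModel computes the causal LM loss when labels are passed
        out = self.model(input_ids = batch, labels = batch)
        return out.loss
```

For an apples-to-apples timing comparison, the head count and layer count would need to match the x-transformers Decoder above.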