Tri Dao

639 comments of Tri Dao

Sorry, I don't know much about the ROCm version; you can ask on their repo.

There's a persistent scheduler that's not yet enabled for causal; we'll update it soon.

Sure, would love to see a PR fixing this.

Can you send a short script to reproduce the speed regression? E.g., with this input, 2.5.9.post1 takes XXX seconds and 2.6.1 takes YYY seconds.
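A minimal benchmark sketch along these lines, assuming a CUDA GPU and the public `flash_attn_func` API; the shapes and dtype below are placeholders, so substitute whichever input shows the regression:

```python
import time
import torch
from flash_attn import flash_attn_func

# Placeholder shapes; substitute the configuration that shows the regression.
batch, seqlen, nheads, headdim = 4, 4096, 16, 64
q, k, v = (torch.randn(batch, seqlen, nheads, headdim,
                       device="cuda", dtype=torch.bfloat16) for _ in range(3))

# Warm up so CUDA context creation doesn't skew the timing.
for _ in range(10):
    flash_attn_func(q, k, v, causal=True)
torch.cuda.synchronize()

start = time.time()
for _ in range(100):
    flash_attn_func(q, k, v, causal=True)
torch.cuda.synchronize()
print(f"{(time.time() - start) / 100 * 1e3:.3f} ms/iter")  # run under 2.5.9.post1 and 2.6.1
```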

Please compare the error (flash-attn in bf16 - reference attn in fp32) vs (reference attn in bf16 - reference attn in fp32).
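A minimal sketch of that comparison, using a naive matmul-softmax attention as the reference (shapes and names are placeholder assumptions):

```python
import math
import torch
from flash_attn import flash_attn_func

batch, seqlen, nheads, headdim = 2, 1024, 8, 64
q, k, v = (torch.randn(batch, seqlen, nheads, headdim,
                       device="cuda", dtype=torch.bfloat16) for _ in range(3))

def ref_attn(q, k, v, dtype):
    # Naive reference attention in the given dtype, (batch, nheads, seqlen, headdim) layout.
    qt, kt, vt = (t.transpose(1, 2).to(dtype) for t in (q, k, v))
    scores = (qt @ kt.transpose(-2, -1)) / math.sqrt(headdim)
    return (scores.softmax(dim=-1) @ vt).transpose(1, 2)

out_flash = flash_attn_func(q, k, v)                  # bf16
out_ref_bf16 = ref_attn(q, k, v, torch.bfloat16)
out_ref_fp32 = ref_attn(q, k, v, torch.float32)

# Flash-attn's error vs the fp32 reference should be comparable to the
# error of the bf16 reference itself.
print((out_flash.float() - out_ref_fp32).abs().max().item())
print((out_ref_bf16.float() - out_ref_fp32).abs().max().item())
```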

We have [code](https://github.com/Dao-AILab/flash-attention/blob/6711b3bc40073e7ced2a4c7d8266feec7e6e137f/flash_attn/models/llama.py#L107) to convert weights from Meta and HF to be compatible with the implementation in this repo. A test is [here](https://github.com/Dao-AILab/flash-attention/blob/6711b3bc40073e7ced2a4c7d8266feec7e6e137f/tests/models/test_llama.py#L65) to verify that the models implemented in this...
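A rough usage sketch, assuming helper names like `llama_config_to_gpt2_config` and `remap_state_dict_hf_llama` from the linked file; the exact names and signatures may differ across versions, so check that file:

```python
# Hypothetical usage; helper names are assumptions based on the linked
# flash_attn/models/llama.py and may vary by version.
import torch
from transformers import AutoConfig, AutoModelForCausalLM
from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.models.llama import llama_config_to_gpt2_config, remap_state_dict_hf_llama

model_name = "meta-llama/Llama-2-7b-hf"  # placeholder checkpoint
config = llama_config_to_gpt2_config(AutoConfig.from_pretrained(model_name))

hf_model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
state_dict = remap_state_dict_hf_llama(hf_model.state_dict(), config)

model = GPTLMHeadModel(config, device="cuda", dtype=torch.bfloat16)
model.load_state_dict(state_dict)
```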

My guess is that it's because our `GPTLMHeadModel` doesn't return a loss; it returns the output, which is of size (batch, seqlen, vocab_size). You'd need to have a separate loss...
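A minimal sketch of computing the loss separately, assuming next-token prediction and that the model output exposes `.logits` (all names are placeholders):

```python
import torch.nn.functional as F

logits = model(input_ids).logits  # (batch, seqlen, vocab_size)
# Shift so position i predicts token i + 1, then flatten for cross-entropy.
shift_logits = logits[:, :-1, :].contiguous()
shift_labels = input_ids[:, 1:].contiguous()
loss = F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)),
                       shift_labels.view(-1))
```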

> I'm still not able to measure any difference. I'm using the HF trainer and model with this change:
>
> ```python
> import transformers
> from functools import partial...
> ```

Is this with batch size 1? My back-of-the-envelope calculation: the logits have size (batch, seqlen, vocab_size), taking 2 bytes each (e.g. training with bf16). Our xentropy kernel avoids storing an...
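To make the sizes concrete, a quick sketch of that calculation with assumed example values (batch 1, seqlen 2048, vocab 32000; all placeholders):

```python
batch, seqlen, vocab_size = 1, 2048, 32000  # assumed example values
bytes_per_elem = 2  # bf16
logits_bytes = batch * seqlen * vocab_size * bytes_per_elem
print(f"{logits_bytes / 2**20:.0f} MiB")  # ~125 MiB for the logits alone
```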