Aman Karmani

460 comments by Aman Karmani

Aha, I just saw this:

> `partial(CrossEntropyLoss, inplace_backward=True)`

I'm still not able to measure any difference. I'm using the HF trainer and model with this change:

```py
import transformers
from functools import partial
from flash_attn.losses.cross_entropy import CrossEntropyLoss

transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(CrossEntropyLoss, inplace_backward=True)
```

I discovered my patching code wasn't running for some silly reason. I'm interested in quantifying the differences between these implementations, especially when it comes to VRAM usage. I used the...
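
For reference, a minimal sketch of one way to quantify peak VRAM per configuration (an illustration, not the exact harness used above; the training loop is elided):

```py
import torch

# Reset the peak-allocation counter so the measurement covers only this run
# (assumes a single CUDA device).
torch.cuda.reset_peak_memory_stats()

# ... run trainer.train() or a handful of training steps here ...

peak_gb = torch.cuda.max_memory_allocated() / 1024**3
print(f"peak allocated: {peak_gb:.5f} GB")
```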

Yea, that was with batchsize=1. I made some more measurements @ ctx=4096:

| cfg | mem |
| --- | --- |
| bs=1 xentropy=false | 3.59699 GB |
| ... | ... |

Now I applied the rmsnorm kernel on top, as follows:

```py
import transformers
from flash_attn.ops.rms_norm import RMSNorm

class LlamaRMSNorm(RMSNorm):
    """Patched LlamaRMSNorm"""
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__(hidden_size, eps=eps)

transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
```

and...
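
Note that module-level patches like the two above only take effect for models constructed afterwards, since the patched names are looked up when the layers are built. A minimal sketch of the required ordering (the checkpoint path is just a placeholder):

```py
from transformers import AutoModelForCausalLM

# Apply the CrossEntropyLoss / LlamaRMSNorm patches first, then build the
# model so the patched classes are the ones actually instantiated.
model = AutoModelForCausalLM.from_pretrained("path/to/llama-checkpoint")
```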

I'm working on the rotary kernel next, but am not quite sure if I'm handling the cos/sin correctly:

```diff
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -29,6 +29,7 @@ from torch.nn import ...
```
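
For comparison, a sketch of the cos/sin handling under a few assumptions: HF's `LlamaRotaryEmbedding` returns cos/sin whose last dimension is `head_dim` with the half-dim frequencies duplicated, the flash-attn kernel wants them shaped `(seqlen, head_dim // 2)`, and `apply_rotary_emb` takes `x` as `(batch, seqlen, nheads, headdim)` with `interleaved=False` matching HF's `rotate_half` convention:

```py
import torch
from flash_attn.layers.rotary import apply_rotary_emb

def apply_flash_rotary(x, cos, sin):
    """Rotate x of shape (batch, seqlen, nheads, headdim).

    cos/sin are assumed to come from HF's LlamaRotaryEmbedding, i.e. the
    half-dim frequencies duplicated along the last axis; the kernel only
    wants the first half, shaped (seqlen, headdim // 2).
    """
    headdim = x.shape[-1]
    cos = cos.reshape(-1, headdim)[:, : headdim // 2]
    sin = sin.reshape(-1, headdim)[:, : headdim // 2]
    # Note: HF keeps q/k as (batch, nheads, seqlen, headdim), so a transpose
    # may be needed before and after this call.
    return apply_rotary_emb(x, cos, sin, interleaved=False)
```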

Thanks for the pointer. I know I can convert the weights and use the trainer here, but I'm interested in features that transformers offers out of the box, such as...

Okay I see I should probably be using the `flash_attn.layers.rotary.RotaryEmbedding` module instead of trying to call `apply_rotary_emb` directly. On the transformers side there are [several variations](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L92-L170):

```python
class LlamaRotaryEmbedding(torch.nn.Module):
class ...
```
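
For context, a minimal sketch of driving `flash_attn.layers.rotary.RotaryEmbedding` on a packed qkv tensor (the forward signature and shapes here are assumptions about the flash-attn API, and the sizes are illustrative):

```py
import torch
from flash_attn.layers.rotary import RotaryEmbedding

head_dim, n_heads = 128, 32
rotary = RotaryEmbedding(head_dim).cuda()

# Packed qkv of shape (batch, seqlen, 3, nheads, headdim); the module computes
# and caches cos/sin internally instead of having them passed in like HF does.
qkv = torch.randn(1, 4096, 3, n_heads, head_dim, dtype=torch.float16, device="cuda")
qkv = rotary(qkv)
```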

You can use this subclass with the HF trainer: https://github.com/Dao-AILab/flash-attention/pull/486

Yes, but not with fused MLP because there's no place for peft to hook into the linear layers.
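
To illustrate the point: LoRA-style adapters in peft attach to named `nn.Linear` submodules, such as the three projections inside the Llama MLP, and a fused MLP replaces those with a single fused module, leaving nothing for `target_modules` to match. A minimal sketch of the unfused case (the checkpoint path is a placeholder and the LoRA hyperparameters are arbitrary):

```py
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("path/to/llama-checkpoint")

# These names only exist when the MLP is left unfused; with a fused MLP kernel
# the gate/up/down projections are no longer separate nn.Linear modules.
config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["gate_proj", "up_proj", "down_proj"],
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, config)
```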