
Feature Request: Fused Linear and Cross-Entropy Loss

Open · imoneoi opened this issue 1 year ago · 1 comment

Because recent models such as Llama 3 and Gemma use very large vocabularies (128K-256K tokens), the logits tensor can become huge and consume a large share of VRAM. For example, these are the sizes for Llama 3 8B with a batch of 81920 tokens:

Logits size: 128256 (vocabulary size) * 81920 (batch size) * 2 bytes (bf16) = 19.57 GiB
Hidden state size (checkpointed): 4096 (hidden size) * 81920 (batch size) * 32 (layers) * 2 bytes (bf16) = 20 GiB
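The arithmetic above can be reproduced with a few lines of Python. The helper name `tensor_gib` is purely illustrative, and, as in the estimate above, it counts one checkpointed hidden state per layer.

```python
# Reproduce the memory estimates above (bf16 = 2 bytes per element).
GIB = 1024 ** 3


def tensor_gib(*dims, bytes_per_elem=2):
    """Size in GiB of a dense tensor with the given dimensions (illustrative helper)."""
    n = 1
    for d in dims:
        n *= d
    return n * bytes_per_elem / GIB


vocab, hidden, layers, tokens = 128256, 4096, 32, 81920

logits_gib = tensor_gib(vocab, tokens)           # full [tokens, vocab] logits
hidden_gib = tensor_gib(hidden, tokens, layers)  # one checkpointed hidden state per layer
share = logits_gib / (logits_gib + hidden_gib)

print(f"logits: {logits_gib:.2f} GiB, checkpointed hidden states: {hidden_gib:.2f} GiB")
print(f"logits share of the two terms: {share:.1%}")
```

This prints 19.57 GiB for the logits and 20.00 GiB for the hidden states, with the logits making up about 49.5% of the two terms combined.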

Therefore, a fused linear and cross-entropy loss operator that never materializes the full logits could roughly halve the memory taken by these two terms (the logits are about 49% of their combined 39.57 GiB). It would be a great addition to the FlashAttention model implementations.
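The idea can be sketched as follows. This is a minimal NumPy illustration under stated assumptions, not a FlashAttention API: the function name `chunked_linear_cross_entropy` is hypothetical, it runs on CPU in float64, and it processes tokens in chunks rather than fusing everything into one GPU kernel. It also computes the gradients with respect to the hidden states and the output weight in the same pass, because a standard autograd backward would otherwise keep the full logits around (or have to recompute them).

```python
import numpy as np


def chunked_linear_cross_entropy(h, w, targets, chunk_size=1024):
    """Mean cross-entropy of softmax(h @ w.T) without materializing all logits.

    Hypothetical helper for illustration only.
    h: [N, D] hidden states, w: [V, D] output (LM head) weight, targets: [N] ints.
    At most a [chunk_size, V] block of logits exists at any time.
    Returns (mean loss, dL/dh, dL/dw).
    """
    n = h.shape[0]
    loss_sum = 0.0
    grad_h = np.zeros_like(h)
    grad_w = np.zeros_like(w)
    for start in range(0, n, chunk_size):
        end = min(start + chunk_size, n)
        hc, tc = h[start:end], targets[start:end]
        rows = np.arange(end - start)

        logits = hc @ w.T                          # [c, V]: the only large temporary
        m = logits.max(axis=1, keepdims=True)      # subtract row max for numerical stability
        exp = np.exp(logits - m)
        sum_exp = exp.sum(axis=1, keepdims=True)
        lse = (m + np.log(sum_exp))[:, 0]          # log-sum-exp per token
        loss_sum += float((lse - logits[rows, tc]).sum())

        # dL/dlogits for the mean loss is (softmax - one_hot) / N.
        probs = exp / sum_exp
        probs[rows, tc] -= 1.0
        probs /= n
        grad_h[start:end] = probs @ w              # [c, D]
        grad_w += probs.T @ hc                     # [V, D], accumulated over chunks
    return loss_sum / n, grad_h, grad_w
```

With `chunk_size=1024` and a 128256-token vocabulary, each logits block would be about 250 MiB in bf16 instead of the 19.57 GiB full tensor. A production version would also need to handle bf16 inputs with fp32 accumulation, which this sketch does not attempt.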

imoneoi, Apr 19 '24 12:04

I personally don't have cycles for this, but we welcome PRs.

tridao, Apr 19 '24 17:04