
Use flash-attn fused cross entropy loss to reduce metric memory usage

Open cli99 opened this pull request 1 year ago • 2 comments

This PR uses the fused cross entropy loss from flash attention in the metric LanguageCrossEntropy (and likewise LanguagePerplexity). The current torch.nn.CrossEntropyLoss call needs 6 * seq_len * vocab_size GPU memory, and can become the memory bottleneck when the sequence length is long (where activation checkpointing is probably already in use). Using the cross entropy loss from flash-attn resolves this problem.
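As a back-of-envelope illustration of why this matters, here is a minimal sketch of the 6 * seq_len * vocab_size footprint from the description. The interpretation of the factor 6 as counting fp32 elements (4 bytes each), and the example seq_len/vocab_size values, are assumptions for illustration, not measurements from this PR:

```python
# Hedged estimate of the ~6 * seq_len * vocab_size peak memory of the
# torch cross entropy path, assuming the factor of 6 counts fp32
# elements (4 bytes each). Numbers below are illustrative only.
def ce_loss_mem_gib(seq_len: int, vocab_size: int, bytes_per_elem: int = 4) -> float:
    # 6 * seq_len * vocab_size elements, converted to GiB
    return 6 * seq_len * vocab_size * bytes_per_elem / 2**30

# e.g. a 32k-token sequence with a ~50k vocab is already tens of GiB
# just for the loss computation:
print(round(ce_loss_mem_gib(32768, 50304), 1))
```

This scales linearly with sequence length, which is why the loss (rather than attention) can dominate memory once activation checkpointing has trimmed everything else.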

Example test model with a long sequence and full activation checkpointing:

with torch loss fn: [memory profile screenshot]

with flash_attn loss fn: [memory profile screenshot]
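The swap the screenshots compare can be sketched roughly as below. This is not the PR's actual diff; the flash_attn import path and the try/except fallback are assumptions (flash-attn's fused kernel also requires a CUDA device, so the torch fallback runs on CPU):

```python
import torch

# Hedged sketch: prefer flash-attn's fused cross entropy when the
# package is available, otherwise fall back to torch's implementation.
# The import path below matches recent flash-attn releases; adjust it
# if your installed version differs.
try:
    from flash_attn.losses.cross_entropy import CrossEntropyLoss as FusedCE
    loss_fn = FusedCE(ignore_index=-100)
except ImportError:
    loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-100)

# Metric-style call: logits flattened to (batch * seq_len, vocab_size),
# targets flattened to (batch * seq_len,). Shapes here are toy values.
logits = torch.randn(8, 1000)
targets = torch.randint(0, 1000, (8,))
loss = loss_fn(logits, targets)
print(loss.item())
```

The call signature is interchangeable with torch.nn.CrossEntropyLoss, which is what makes the drop-in replacement inside the metric straightforward.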

cli99 · Feb 09 '24