Use flash-attn's fused cross-entropy loss to reduce metric memory usage
This PR uses the fused cross-entropy loss from flash-attn in the LanguageCrossEntropy metric (and, by extension, LanguagePerplexity).
The current torch.nn.CrossEntropyLoss call requires on the order of 6 * seq_len * vocab_size GPU memory, and can become the memory bottleneck when the sequence length is long (a regime where activation checkpointing is typically already in use). Switching to the fused cross-entropy loss from flash-attn resolves this problem.
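A minimal sketch of the swap, assuming flash-attn exposes its fused loss at `flash_attn.losses.cross_entropy` (the exact import path and the `ignore_index=-100` default here are assumptions, not taken from this PR), with a fallback to the torch implementation when flash-attn is not installed:

```python
import torch

# Assumed import path for flash-attn's fused cross-entropy loss; the fused
# kernel needs a CUDA build of flash-attn, so fall back to the standard
# torch implementation when it is unavailable.
try:
    from flash_attn.losses.cross_entropy import (
        CrossEntropyLoss as FusedCrossEntropyLoss,
    )
    loss_fn = FusedCrossEntropyLoss(ignore_index=-100)
except ImportError:
    loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-100)

# Toy stand-in for the metric's inputs: logits flattened to
# (batch * seq_len, vocab_size) plus integer targets, mirroring how a
# language cross-entropy metric calls the loss.
torch.manual_seed(0)
vocab_size = 32
logits = torch.randn(8, vocab_size)
targets = torch.randint(0, vocab_size, (8,))

loss = loss_fn(logits, targets)  # scalar mean loss over non-ignored tokens
```

Because both loss objects share the `nn.CrossEntropyLoss` call signature, the metric code itself does not need to change beyond constructing a different loss object.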
Memory usage for an example test model with a long sequence length and full activation checkpointing:
With the torch loss fn:
With the flash_attn loss fn: