llm-foundry
Loss explodes with Flash/Triton Attention
Hi!
Using this Docker image: mosaicml/llm-foundry:2.0.1_cu118-latest
I'm training mpt-125m with your default parameters, and my loss explodes after some number of steps. I have added a 2k-step warmup as well.
It looks really strange, because it works fine with torch attention. Can you help, please?
It happens on different datasets; the moment of the explosion is slightly different each time.
Same with torch 1.13 / CUDA 11.7.
Maybe there is some problem in mpt-125m? Have you noticed any bugs in flash attention with certain num_heads or other settings?
It looks really strange, because it works fine with torch attention. Can you help, please?

Can you clarify what you mean by "it works fine with torch attention"? You don't actually show that in your figure; can that be added for clarification?
I mean attn_config: attn_impl: torch works fine; the figure in the header has been updated.
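For reference, the only field that differs between the runs is attn_impl under attn_config in the model section of the YAML. As a rough Python-dict sketch (the surrounding keys and values are illustrative, not copied from the exact mpt-125m config):

    # Illustrative sketch only; the surrounding keys/values approximate the
    # mpt-125m model section and may not match the shipped YAML exactly.
    model_cfg = {
        "name": "mpt_causal_lm",
        "d_model": 768,
        "n_heads": 12,
        "n_layers": 12,
        "attn_config": {
            "attn_impl": "torch",  # the only change between runs: "torch" | "flash" | "triton"
        },
    }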
Also, my inference (on a Russian dataset) doesn't look promising (after converting this torch MPT checkpoint from Composer to HF and running hf_generate.py):
*************'***********......
. ::: :::=:::: : : : : : : : :: ::-: : :. 28........ ... . }[[[[[
. | " ходы 0 а а а аки блиииииии и и и ааааааааааа??????
,,,,,,,,,,,,,.
,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,, ,,
,,,,,,,,,
I don't know, but maybe there is some problem in the MPT graph?
I just want to verify that you used the exact same configuration for all 3 runs and that the only diff was the attn_impl.
Exactly, only attn_impl is different across these 3 runs.
Have you trained your mpt-125m on big datasets (more than the ~2.5B Chinchilla-optimal tokens for this size)? Was it OK?
I guess mpt-125m is OK, but the loss explosion does exist. The checkpoint at 4k steps is much better than the one at 15k (after the explosion), so it would be nice to fix this.
I'm thinking: maybe I'm feeding too big a batch size to a 125M model? I'm using 2048 max_seq_len and a global train batch size of 256, so 2048 * 256 ≈ 0.5M tokens per batch. That matches Brown et al.'s OpenAI research (GPT-3), but there the same batch size is used for the 125M and 760M models, which looks a little strange.
attn_impl: torch | flash | triton handle numerical precision differently.
attn_impl: torch operates under the with torch.autocast(**kwargs) context manager and should be the most numerically stable; the other attn_impl options are a bit more lax about how numerics are handled.
LLMs get loss spikes for a lot of reasons, and the larger the model, the more likely it is to get a spike. Given that attn_impl: torch is fine and the others are not, it seems as though your training has entered a regime where it is extremely sensitive to the numerical precision differences within the attn layer. There are potentially a lot of ways to get around loss spikes.
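As a loose sketch (not the actual MPT attention code) of where the precision difference can come from: the eager torch path materializes the score matrix and can run the softmax in float32 under autocast, whereas the fused flash/triton kernels manage precision internally.

    import torch

    def eager_attn(q, k, v, softmax_scale):
        # Loose sketch, not the llm-foundry implementation: scores are formed
        # in the autocast dtype (e.g. bf16/fp16), but the softmax is upcast to
        # float32 before being applied, which tends to be more numerically
        # forgiving than keeping everything in half precision.
        scores = torch.matmul(q, k.transpose(-2, -1)) * softmax_scale
        probs = torch.softmax(scores.float(), dim=-1).to(v.dtype)
        return torch.matmul(probs, v)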
The OPT paper talks about how they dealt with loss spikes; they also released an on-call logbook chronicling how they dealt with all the issues encountered during training (including loss spikes).
There are also other works that discuss the different reasons why loss spikes happen and how to mitigate them.
MosaicML helps its customers get into a stable training regime so that they do not need to deal with this.
@vchiley since maybe some anomalous batches roll in here, do you provide a technique to skip batches with a big loss during training in streaming?
It looks like it should be something like (compute_loss here is just a placeholder for however the per-batch loss is obtained):

    prev_loss = None
    for i, batch in enumerate(dataloader):
        loss = compute_loss(model, batch)  # placeholder helper
        # skip batches whose loss jumps by more than, for example, 2x
        if prev_loss is not None and loss > prev_loss * 2:
            continue
        losses.append(loss)
        prev_loss = loss
Is there any way to make this elegant? Of course, I can just go into the streaming source, but I want to work only with llm-foundry.
I don't think we've implemented anything like that, but you'd probably implement this as a Composer callback triggered on AFTER_TRAIN_BATCH or BEFORE_BACKWARD, depending on exactly what you want to implement and how you want to implement it.
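A rough, untested sketch of what such a callback could look like (the exact handling of state.loss may need adjusting for your setup):

    from composer.core import Callback, State
    from composer.loggers import Logger

    class SkipSpikedLossBatches(Callback):
        """Untested sketch: if the current batch's loss is more than
        spike_factor times the previous batch's loss, zero the loss so the
        backward pass and optimizer step are effectively a no-op."""

        def __init__(self, spike_factor: float = 2.0):
            self.spike_factor = spike_factor
            self.prev_loss = None

        def before_backward(self, state: State, logger: Logger) -> None:
            current = float(state.loss)
            if self.prev_loss is not None and current > self.prev_loss * self.spike_factor:
                state.loss = state.loss * 0.0  # skip this batch's contribution
            else:
                self.prev_loss = current

You would then pass an instance to the Composer Trainer's callbacks argument.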