
Loss explodes with Flash/Triton Attention

Open germanjke opened this issue 1 year ago • 10 comments

Hi!

Using this Docker image: mosaicml/llm-foundry:2.0.1_cu118-latest

I'm training mpt-125m with your default parameters, and my loss explodes after some number of steps.

I have also added a 2k-step warmup.

It looks really strange, because it works fine with torch attention. Can you help, please?

It happens on different datasets; the moment of the explosion is slightly different each time.

[image: training loss curves]

germanjke avatar Jun 05 '23 19:06 germanjke

Same with torch 1.13 and CUDA 11.7.

germanjke avatar Jun 05 '23 19:06 germanjke

Maybe there are some problems in mpt-125m? Have you noticed any bugs in flash attention with certain num_heads values or other settings?

germanjke avatar Jun 05 '23 19:06 germanjke

It looks really strange, because it works fine with torch attention. Can you help, please?

Can you clarify what you mean by "it works fine with torch attention"? You don't actually show that in your figure; can that be added for clarification?

vchiley avatar Jun 06 '23 11:06 vchiley

I mean attn_config: attn_impl: torch works fine. Figure updated in the header.

germanjke avatar Jun 06 '23 11:06 germanjke

Also, my inference (on a Russian dataset) does not look promising (after converting this torch MPT checkpoint from Composer to HF and running hf_generate.py):

*************'***********......
. ::: :::=:::: : : : : : : : :: ::-: : :. 28........ ... . }[[[[[
.  | " ходы 0 а а а аки блиииииии и и и ааааааааааа??????

,,,,,,,,,,,,,.
,,,,,,,,,,,,,,
,,,,,,,,,,,,,,,, ,,
,,,,,,,,,

I don't know, but maybe there is some problem in the MPT graph?

germanjke avatar Jun 06 '23 12:06 germanjke
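As a quick sanity check on a converted checkpoint, one could also load it directly with transformers and sample from it. A minimal sketch, where the checkpoint path and sampling settings are placeholders rather than values from this thread:

# Quick sanity check of the converted HF checkpoint (the path and sampling
# settings here are placeholders, not taken from this thread).
from transformers import AutoModelForCausalLM, AutoTokenizer

ckpt_path = "path/to/converted_hf_checkpoint"  # hypothetical output dir of the composer-to-HF conversion
tokenizer = AutoTokenizer.from_pretrained(ckpt_path)
# MPT ships custom modeling code, so trust_remote_code is required.
model = AutoModelForCausalLM.from_pretrained(ckpt_path, trust_remote_code=True)
model.eval()

inputs = tokenizer("Hello, my name is", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=50, do_sample=True, top_p=0.95)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

If the output is garbage here as well, the problem is in the trained weights (or the conversion) rather than in hf_generate.py itself.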

I just want to verify that you used the exact same configuration for all 3 runs and the only diff was the attn_impl

vchiley avatar Jun 06 '23 16:06 vchiley

Exactly, only attn_impl differs across these 3 runs.

germanjke avatar Jun 06 '23 17:06 germanjke

Have you trained your mpt-125m on big datasets (more than the Chinchilla-optimal ~2.5B tokens for this size)? Was it OK?

germanjke avatar Jun 06 '23 17:06 germanjke

I guess mpt-125m is OK, but the loss explosion still happens. The checkpoint at 4k steps is much better than the one at 15k (after the explosion), so it would be nice to fix this.

germanjke avatar Jun 06 '23 21:06 germanjke

I'm thinking, maybe I'm feeding too big a batch size to the 125m model? I'm using max_seq_len 2048 and a global train batch size of 256 => 2048 * 256 ≈ 0.5M tokens per batch. That's in line with Brown et al.'s OpenAI research.

[image: batch size table from Brown et al.]

But the table lists the same batch size for both the 125M and 760M models, which looks a little strange.

germanjke avatar Jun 07 '23 11:06 germanjke
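For concreteness, the tokens-per-batch arithmetic behind the 0.5M figure above is just:

max_seq_len = 2048
global_train_batch_size = 256
tokens_per_batch = max_seq_len * global_train_batch_size
print(tokens_per_batch)  # 524288 tokens, i.e. roughly 0.5M per optimizer step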

attn_impl: torch | flash | triton handle numerical precision differently. attn_impl: torch operates under the torch.autocast(**kwargs) context manager and should be the most numerically stable; the other attn_impl options are a bit more lax about how numerics are handled.

LLMs get loss spikes for a lot of reasons, and the larger the model, the more likely it is to get a spike. Given that attn_impl: torch is fine and the others are not, it seems as though your training has entered a regime where it is extremely sensitive to the numerical precision differences within the attn layer. There are potentially a lot of ways to get around loss spikes. The OPT paper talks about how they dealt with loss spikes; they also released an on-call logbook chronicling how they dealt with all the issues encountered during training (including loss spikes). There are also other works that discuss the different reasons why loss spikes happen and how to mitigate them.

MosaicML helps its customers get into a stable training regime so that they do not need to deal with this.

vchiley avatar Jun 08 '23 23:06 vchiley
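To illustrate the numerical precision point above, here is a toy comparison (not llm-foundry's actual kernels): the same attention math evaluated at two precisions gives slightly different results, and a model already in an unstable regime can be sensitive to exactly this kind of gap. The fp16/bf16 paths used by flash/triton drift further from the reference than the float32-vs-float64 gap shown here.

# Toy illustration of precision-dependent attention outputs (not llm-foundry code).
import torch

torch.manual_seed(0)
# (batch, heads, seq, head_dim); arbitrary toy sizes
q = torch.randn(1, 8, 128, 64, dtype=torch.float64)
k = torch.randn(1, 8, 128, 64, dtype=torch.float64)
v = torch.randn(1, 8, 128, 64, dtype=torch.float64)

def attn(q, k, v):
    # plain scaled dot-product attention
    scores = q @ k.transpose(-2, -1) / (q.shape[-1] ** 0.5)
    return torch.softmax(scores, dim=-1) @ v

ref = attn(q, k, v)                                   # float64 reference
low = attn(q.float(), k.float(), v.float()).double()  # same math in float32
print((ref - low).abs().max())  # small but nonzero rounding difference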

@vchiley since some anomalous batches may be coming through here, do you provide a technique to skip batches with a large loss during training with streaming?

It looks like it should be something like:

prev_loss = None
for i, batch in enumerate(dataloader):
    batch_loss = loss(batch)
    # skip anomalous batches, e.g. when the loss is more than 2x the previous one
    if prev_loss is not None and batch_loss > prev_loss * 2:
        continue
    losses.append(batch_loss)
    prev_loss = batch_loss

Is there a way to do this elegantly? Of course I could just go into the streaming source, but I'd prefer to work only with llm-foundry.

germanjke avatar Jun 14 '23 07:06 germanjke

I don't think we've implemented anything like that, but you could probably implement it as a Composer callback triggered on AFTER_TRAIN_BATCH or BEFORE_BACKWARD, depending on exactly what you want to implement and how you want to implement it.

vchiley avatar Jun 14 '23 13:06 vchiley
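A rough sketch of what such a callback might look like is below. It is an illustration rather than an existing llm-foundry/Composer feature, and it assumes that state.loss is a scalar tensor at the BEFORE_BACKWARD event and that zeroing it effectively skips the update; both assumptions should be checked against the Composer version in use.

# Rough sketch of a loss-spike-skipping callback (illustration only).
from composer.core import Callback, State
from composer.loggers import Logger


class SkipSpikedBatches(Callback):
    def __init__(self, spike_factor: float = 2.0):
        self.spike_factor = spike_factor
        self.prev_loss = None

    def before_backward(self, state: State, logger: Logger) -> None:
        # Assumes state.loss is a scalar tensor at this point in the step.
        loss = float(state.loss)
        if self.prev_loss is not None and loss > self.prev_loss * self.spike_factor:
            # Zero the loss so backward produces (near-)zero gradients for this batch.
            state.loss = state.loss * 0.0
            return
        self.prev_loss = loss

The callback would then be passed to the Composer Trainer via its callbacks argument, e.g. callbacks=[SkipSpikedBatches()].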