ColossalAI
[BUG]: OOM during llama2 pretraining with flashattention and PP
🐛 Describe the bug
I understand that this error comes from the flash attention software stack, but there seems to be no related issue there except for https://github.com/Dao-AILab/flash-attention/issues/590, so I am opening an issue here anyway. The problem also occurs with flash-attn 2.0.5.
Using PP in HybridParallelPlugin (no ZeRO) together with flash attention for Llama2 results in OOM.
When I run examples/language/llama2/pretrain.py with padding added back to the inputs, it fails with OOM. Without flash attention it works fine.
plugin = HybridParallelPlugin(
    tp_size=2,
    pp_size=2,
    # all the other args are the same as in the example
)
Note that if you set pp_size=1, you get a "cache only has 0 layers" exception (#5410) even before hitting the OOM :) So there is another bug in the llama2 forward with attention parallelism. Just a side note.
PYTHONPATH=/path/to/colossalai/examples/language/llama2 torchrun --standalone --nproc-per-node 4 pretrain.py -p hybrid_parallel -a -g -x bf16 -o /tmp/llama_checkpoint
File "/data/insujang/colossalai/examples/language/llama2/attn.py", line 174, in attention_forward
q, indices, cu_q_lens, max_q_len = unpad_input(hidden_states=q, attention_mask=key_padding_mask)
File "/opt/conda/lib/python3.10/site-packages/flash_attn-2.5.6-py3.10-linux-x86_64.egg/flash_attn/bert_padding.py", line 119, in unpad_input
index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
File "/opt/conda/lib/python3.10/site-packages/torch/autograd/function.py", line 553, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
File "/opt/conda/lib/python3.10/site-packages/flash_attn-2.5.6-py3.10-linux-x86_64.egg/flash_attn/bert_padding.py", line 17, in forward
return torch.gather(
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 59.67 GiB. GPU 1 has a total capacity of 44.35 GiB of which 34.03 GiB is free. Process 1325526 has 10.31 GiB memory in use. Of the allocated memory 9.76 GiB is allocated by PyTorch, and 40.83 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
https://github.com/hpcaitech/ColossalAI/blob/7e0ec5a85c73fcc5666b9d218e43865141587dde/applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py#L174
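The size of the failed allocation can be checked by hand. This sketch assumes LLaMA-2 7B shapes (32 attention heads, head dimension 128) with heads split over tp_size=2, bf16 activations, and the indices count of 15,642,705 printed from inside unpad_input. Under those assumptions the gather output alone matches the 59.67 GiB in the error message, versus roughly 16 MiB if there were one row per token.

```python
# Back-of-the-envelope size of the torch.gather output in IndexFirstAxis.forward.
# Assumed (not shown in the log): LLaMA-2 7B shapes (32 heads x head_dim 128),
# heads split across tp_size=2 ranks, and bf16 activations (2 bytes/element).
num_indices = 15_642_705      # indices.shape observed inside unpad_input
heads_per_rank = 32 // 2      # assumed: 32 heads / tp_size=2
head_dim = 128                # assumed LLaMA-2 7B head dimension
bytes_per_elem = 2            # bf16

second_dim = heads_per_rank * head_dim  # per-token width after "b ... -> b (...)"
gather_bytes = num_indices * second_dim * bytes_per_elem
print(f"gather output: {gather_bytes / 2**30:.2f} GiB")  # gather output: 59.67 GiB

# For comparison: one row per token, assuming batch 1 and seqlen 4096
# (taken from the 4096 x 4096 mask shape).
per_token_bytes = 1 * 4096 * second_dim * bytes_per_elem
print(f"one row per token: {per_token_bytes / 2**20:.0f} MiB")  # one row per token: 16 MiB
```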
I think this might be related to the size of attention_mask, but I'm not sure:
# from flash_attn/bert_padding.py
def unpad_input(hidden_states, attention_mask):
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    # attention_mask.shape=torch.Size([1, 1, 4096, 4096])
    # indices.shape=torch.Size([15642705])
    ...
    return (
        index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),  # Error here
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
    )

# index_first_axis calls IndexFirstAxis.forward()
class IndexFirstAxis(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, indices):
        ...
        return torch.gather(
            rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim)  # Error here
        ).reshape(-1, *other_shape)
where attention_mask is created here:
https://github.com/hpcaitech/ColossalAI/blob/fd4444058f9ebd5f99cfc60e2e5bf69a7dd38d73/colossalai/shardformer/modeling/llama.py#L101-L103
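One possible reading of the shapes above: unpad_input flattens whatever mask it receives, so a 4D (batch, 1, seqlen, seqlen) mask yields on the order of seqlen² indices instead of at most batch × seqlen. The toy sketch below illustrates that at small sizes. It re-implements only the indexing step in plain PyTorch (unpad_indices is a hypothetical helper, not the flash_attn function), and it uses a boolean lower-triangular mask as a stand-in; the mask actually built in llama.py may be encoded differently, so the exact nonzero count will differ.

```python
import torch

def unpad_indices(attention_mask: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper mirroring the first step of unpad_input:
    # flatten the mask and keep the positions of its non-zero entries.
    return torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()

batch, seqlen = 1, 8
num_rows = batch * seqlen  # rows in hidden_states after "b s ... -> (b s) ..."

# What unpad_input expects: a 2D (batch, seqlen) padding mask, True = real token.
padding_mask_2d = torch.ones(batch, seqlen, dtype=torch.bool)

# Stand-in (assumption) for a 4D (batch, 1, seqlen, seqlen) causal mask.
causal_mask_4d = torch.tril(torch.ones(seqlen, seqlen, dtype=torch.bool)).expand(batch, 1, seqlen, seqlen)

idx_2d = unpad_indices(padding_mask_2d)
idx_4d = unpad_indices(causal_mask_4d)

print(num_rows, idx_2d.numel(), idx_4d.numel())  # 8 8 36
print(int(idx_4d.max()) >= num_rows)             # True: indices point past the last row
```

With the 4D mask, the index count grows quadratically with seqlen, which is the same order of magnitude as the 15,642,705 indices seen for a 4096 × 4096 mask (4096² ≈ 16.8M), and the gather in IndexFirstAxis.forward allocates one output row per index.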
I would appreciate it if you could check whether this is reproducible and what causes it.
Environment
4x A40 (48 GB), PyTorch 2.2.1, CUDA 12.1, ColossalAI branch feature/update-transformers, transformers 4.36.0, flash-attn 2.5.6
@wangbluo Could you please help me solve this issue? Thanks
Hi, could you please tell us which model size you are using?
I used the 7B configuration.