Fix OOM when running inference with Llama-3.1-70B
What does this PR do?
Background
When I run inference with the following command:
```bash
INPUT=32768
OUTPUT=32768
BATCH_SIZE=12
python gaudi_spawn.py --use_deepspeed --world_size 8 run_generation.py \
--model_name_or_path Meta-Llama-3.1-70B-Instruct/ \
--max_input_tokens ${INPUT} \
--max_new_tokens ${OUTPUT} \
--bf16 \
--use_hpu_graphs \
--use_kv_cache \
--batch_size ${BATCH_SIZE} \
--attn_softmax_bf16 \
--limit_hpu_graphs \
--trim_logits \
--flash_attention_causal_mask \
--flash_attention_recompute \
--warmup 1 \
--n_iteration 1 \
--bucket_internal \
--bucket_size=512 \
--use_flash_attention
```
it hits an out-of-memory (OOM) error, while BATCH_SIZE=11 does not.
After debugging with a memory analysis tool, I found that the first creation of the causal attention mask tensor requires too much device memory, which leads to device memory exhaustion.
Details of creating the causal mask tensor
Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length, key_value_length) shape and by adding a large negative bias to not-attended positions.
If attention_mask is causal, a causal mask will be added.
The first time this tensor is created, its shape is very large (in my case, [12, 1, 32768, 32768]). Creating it requires an intermediate mask tensor. That mask tensor's dtype could be torch.bool, but it is actually torch.int, which causes four times the memory usage: for shape [12, 1, 32768, 32768], it needs 48 GB of device memory, which produces the peak memory usage.
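A quick back-of-the-envelope check of those numbers (plain Python, treating 1 GB as 1 GiB = 1024³ bytes):

```python
# Footprint of the intermediate mask for the shape from the run above.
bsz, heads, q_len, kv_len = 12, 1, 32768, 32768
elements = bsz * heads * q_len * kv_len            # ~12.9 billion elements

print(elements * 4 / 1024**3)  # torch.int  (4 bytes/elem) -> 48.0 GB
print(elements * 1 / 1024**3)  # torch.bool (1 byte/elem)  -> 12.0 GB
```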
Fixes
This PR explicitly makes the computation of the causal attention mask tensor use less device memory by using a torch.bool mask tensor.
The code change simply overrides the base class's to_4d function, as sketched below.
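For illustration, here is a minimal sketch of what such an override can look like. It follows the signature of transformers' AttentionMaskConverter.to_4d, but the body is an illustrative reimplementation that builds the intermediate mask directly in torch.bool; it is not the exact diff of this PR, and the class name GaudiAttentionMaskConverter is only a placeholder.

```python
import torch
from transformers.modeling_attn_mask_utils import AttentionMaskConverter


class GaudiAttentionMaskConverter(AttentionMaskConverter):
    """Sketch: build the causal mask in torch.bool to cut the intermediate's footprint 4x."""

    def to_4d(self, attention_mask_2d, query_length, dtype, key_value_length=None):
        bsz = attention_mask_2d.shape[0]
        device = attention_mask_2d.device
        key_value_length = key_value_length if key_value_length is not None else query_length

        # Boolean causal mask: 1 byte/element instead of 4 bytes for torch.int.
        # True marks positions that must NOT be attended to (future tokens).
        causal_bool = torch.triu(
            torch.ones((query_length, key_value_length), dtype=torch.bool, device=device),
            diagonal=key_value_length - query_length + 1,
        )

        # Merge with the 2D padding mask (1 = attend, 0 = padding).
        padding_bool = attention_mask_2d[:, None, None, :].eq(0)
        mask_bool = causal_bool[None, None, :, :] | padding_bool

        # Materialize the additive 4D mask only once, directly in the target dtype,
        # with the large negative bias on masked positions.
        min_value = torch.finfo(dtype).min
        return torch.zeros(
            (bsz, 1, query_length, key_value_length), dtype=dtype, device=device
        ).masked_fill_(mask_bool, min_value)
```

With this approach, the only large tensor materialized in a float dtype is the final 4D mask; every intermediate stays in 1-byte bool.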
Others
But why does BATCH_SIZE=11 not cause device memory exhaustion?
I think this is a bug in lazy graph mode.
In lazy mode, the graph should optimize the computation of this large tensor so that it uses less device memory.
So the best way to fix this bug would be to do more optimization in the lazy graph compiler.
Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Did you make sure to update the documentation with your changes?
- [ ] Did you write any new necessary tests?