optimum-habana

Fix OOM when running inference with Llama-3.1-70B

Open • harborn opened this issue 5 months ago • 6 comments

What does this PR do?

Background

When I run inference with the following command:

INPUT=32768
OUTPUT=32768
BATCH_SIZE=12

python gaudi_spawn.py --use_deepspeed --world_size 8 run_generation.py \
    --model_name_or_path Meta-Llama-3.1-70B-Instruct/ \
    --max_input_tokens ${INPUT} \
    --max_new_tokens ${OUTPUT} \
    --bf16 \
    --use_hpu_graphs \
    --use_kv_cache \
    --batch_size ${BATCH_SIZE} \
    --attn_softmax_bf16 \
    --limit_hpu_graphs \
    --trim_logits \
    --flash_attention_causal_mask \
    --flash_attention_recompute \
    --warmup 1 \
    --n_iteration 1 \
    --bucket_internal \
    --bucket_size=512 \
    --use_flash_attention

it runs out of device memory (OOM), whereas it does not OOM with BATCH_SIZE=11.

After debugging with a memory analysis tool, I found that the first creation of the causal attention mask tensor requires too much device memory, which exhausts device memory.

Details of creating the causal mask tensor

to_4d converts a 2D attention mask to a 4D attention mask by expanding it to shape (bsz, head_dim=1, query_length, key_value_length) and adding a large negative bias to positions that are not attended.

If attention_mask is causal, a causal mask will be added.
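As a rough illustration of the conversion described above, here is a minimal pure-Python sketch that uses nested lists instead of tensors. It is not the transformers or optimum-habana implementation: the helper name `to_4d_sketch` and the `min_value` default of -1e9 are hypothetical stand-ins. It keeps the "masked out" flags as booleans and only turns them into the additive bias at the very end, which is the same idea as using a torch.bool mask.

```python
from typing import List


def to_4d_sketch(
    mask_2d: List[List[int]],  # (bsz, key_value_length); 1 = attend, 0 = padding
    query_length: int,
    min_value: float = -1e9,  # hypothetical stand-in for the "large negative bias"
) -> List[List[List[List[float]]]]:
    """Expand a 2D padding mask to (bsz, 1, query_length, key_value_length),
    combine it with a causal mask, and return an additive bias (0.0 or min_value)."""
    out = []
    for row in mask_2d:
        kv_len = len(row)
        # Offset so the last query position lines up with the last key
        # (keys from earlier steps are always visible).
        past = kv_len - query_length
        plane = []
        for q in range(query_length):
            biases = []
            for k in range(kv_len):
                # Only boolean flags here: is the key in the future, or padding?
                is_future = k > q + past
                is_padding = row[k] == 0
                biases.append(min_value if (is_future or is_padding) else 0.0)
            plane.append(biases)
        out.append([plane])  # head dimension of size 1
    return out


# Batch of 2: no padding, and left padding on the first key.
mask = to_4d_sketch([[1, 1, 1], [0, 1, 1]], query_length=3)
```

For the first sequence this yields a lower-triangular pattern of zeros; for the second, column 0 is masked in every row because it is padding.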

When this tensor is created for the first time, its shape is very large (in my case, [12, 1, 32768, 32768]). Creating it requires an intermediate mask tensor. That mask tensor could be torch.bool, but it is actually torch.int, which uses four times as much memory. For shape [12, 1, 32768, 32768], the int mask alone needs 48 GiB of device memory, which drives the peak memory usage.
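The figures above can be checked with a quick back-of-the-envelope calculation. This stdlib-only sketch assumes torch.int is a 4-byte int32 and torch.bool takes 1 byte per element, as in PyTorch; the helper name `mask_bytes` is hypothetical.

```python
import math

# Bytes per element, assuming PyTorch storage sizes (torch.int == int32).
BYTES_PER_ELEMENT = {"bool": 1, "int32": 4}
GiB = 1024 ** 3


def mask_bytes(shape, dtype):
    """Size in bytes of a dense tensor with the given shape and dtype."""
    return math.prod(shape) * BYTES_PER_ELEMENT[dtype]


for batch in (11, 12):
    shape = (batch, 1, 32768, 32768)
    for dtype in ("bool", "int32"):
        print(f"batch={batch} {dtype}: {mask_bytes(shape, dtype) // GiB} GiB")
# Output:
# batch=11 bool: 11 GiB
# batch=11 int32: 44 GiB
# batch=12 bool: 12 GiB
# batch=12 int32: 48 GiB
```

Each sample adds 32768 × 32768 elements, so every extra unit of batch size costs 1 GiB as bool but 4 GiB as int32.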

Fixes

This PR explicitly reduces the device memory needed to compute the causal attention mask tensor by using a torch.bool mask tensor.

The code change simply overrides the base class's to_4d function.

Others

But why does BATCH_SIZE=11 not exhaust device memory? I think this is a bug in lazy graph mode: lazy graph mode should optimize the computation of such a large tensor so that it uses less device memory. So the best fix for this bug would be further optimization in lazy graph mode.

Before submitting

  • [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • [ ] Did you make sure to update the documentation with your changes?
  • [ ] Did you write any new necessary tests?

harborn • Aug 30 '24 07:08