
Likely memory fragmentation for larger models

Open parthmannan opened this issue 1 year ago • 6 comments

🐛 Bug

Running LLaMa2 13B with FSDP ZeRO2 on 8xH100

torchrun --nproc_per_node=8 --nnodes=1 benchmark_litgpt.py --model_name Llama-2-13b-hf --compile thunder_cudnn --distributed_mode fsdp --shard_mode zero2 --bucketing_mode none --micro_batch_size 1 --global_batch_size 8

Average iter time: 867.59 ms

The performance is worse than expected. Inspecting the timeline shows a large portion where the GPU is idle while many cudaMalloc and cudaFree operations are executing.

[Screenshot 2024-05-15 at 12:56:45 PM: profiler timeline showing GPU-idle gaps]

Using expandable_segments in PyTorch, the performance improves significantly and there are no gaps in the timeline.

PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True torchrun --nproc_per_node=8 --nnodes=1 benchmark_litgpt.py --model_name=Llama-2-13b-hf --compile thunder_cudnn --distributed_mode=fsdp --shard_mode zero2 --bucketing_mode none --micro_batch_size 1 --global_batch_size 8

Average iter time: 729.44 ms

[Screenshot 2024-05-15 at 12:56:49 PM: profiler timeline with no idle gaps]
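For reference, a minimal sketch of enabling the same allocator setting from inside a script instead of on the command line. This assumes the setting takes effect as long as it is in place before the first CUDA allocation, which is when PyTorch reads PYTORCH_CUDA_ALLOC_CONF:

```python
import os

# Must be set before the CUDA caching allocator is initialized,
# i.e. before the first CUDA allocation in the process.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import torch

x = torch.empty(1, device="cuda")  # allocator now uses expandable segments
```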

We should find out whether this memory fragmentation is caused by Thunder's interaction with the PyTorch memory allocator or by something else.
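One way to check (a minimal sketch, not part of benchmark_litgpt.py; the per-iteration call site is only a suggestion) is to dump PyTorch's allocator statistics and look for the usual fragmentation signature: reserved memory far above allocated memory, many cudaMalloc/cudaFree calls, and alloc retries.

```python
import torch

def report_allocator_state(device=0):
    """Print allocator stats that hint at fragmentation: a large gap between
    reserved and allocated memory plus many cudaMalloc/cudaFree calls."""
    stats = torch.cuda.memory_stats(device)
    allocated = stats["allocated_bytes.all.current"]
    reserved = stats["reserved_bytes.all.current"]
    print(f"allocated: {allocated / 2**30:.2f} GiB, "
          f"reserved: {reserved / 2**30:.2f} GiB")
    print(f"alloc retries: {stats['num_alloc_retries']}")
    # These counters are only present in recent PyTorch releases.
    print(f"cudaMalloc calls: {stats.get('num_device_alloc', 'n/a')}, "
          f"cudaFree calls: {stats.get('num_device_free', 'n/a')}")

# e.g. call report_allocator_state() at the end of each training iteration
```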

parthmannan avatar May 15 '24 20:05 parthmannan

@kshitij12345 - From our discussion, I remember you were looking into this. I believe this is what is causing the memory operations, but it needs further investigation.

parthmannan avatar May 15 '24 20:05 parthmannan

cc @eqy re: fragmentation lunch discussion

tfogal avatar May 15 '24 22:05 tfogal

Does TORCH_NCCL_AVOID_RECORD_STREAMS=1 help?

eqy avatar May 15 '24 23:05 eqy

@eqy Yes, it does. Either of the two env variables gives the same performance benefit. Is it fair to call this a memory fragmentation issue, or do you think it's something else?

parthmannan avatar May 16 '24 00:05 parthmannan

As per an offline discussion with @ptrblck, we should enable TORCH_NCCL_AVOID_RECORD_STREAMS=1 by default in Thunder.

kshitij12345 avatar May 16 '24 15:05 kshitij12345

cc @IvanYashchuk @mruberry - Can we enable this env var by default in Thunder, or should we rely on NVIDIA containers to enable it?

parthmannan avatar May 20 '24 07:05 parthmannan

Ping @IvanYashchuk @ptrblck @eqy Any reason why we shouldn't enable this by default?

csarofeen avatar Jun 03 '24 00:06 csarofeen

Let's turn this variable on by default for Thunder-generated functions.

Quoting Carilli from the PR description that added this env variable:

Because we're juggling razor blades here and it's hard to test, recordStream avoidance is off by default, and existing default (aka recordStream-based) behavior is unchanged.

Since it's hard to test, let's enable it in Thunder by default.

Linking the relevant forum post here for curious people: https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/
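For illustration, a hedged sketch of what enabling this by default could look like (hypothetical code, not the actual Thunder change): set the variable at import time with setdefault so an explicit user setting still wins, and do it before the NCCL process group is created, since the process group reads the variable at construction time.

```python
import os

# Hypothetical sketch: opt into recordStream avoidance by default,
# while letting an explicit user-provided value take precedence.
os.environ.setdefault("TORCH_NCCL_AVOID_RECORD_STREAMS", "1")

# This must run before torch.distributed.init_process_group(backend="nccl"),
# because the NCCL process group reads the variable when it is created.
```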

IvanYashchuk avatar Jun 03 '24 12:06 IvanYashchuk