New OOM errors for ThunderFX and FSDP
🐛 Bug
We recently started seeing OOM errors that cause failures in Gemma-2-2b canary runs and in distributed training of stablecode-completion-alpha-3b.
To Reproduce
Please use 1 node with 8 GPUs and the image "INTERNAL_IMAGE:pjnl-20241107". Training script:
```
python /opt/pytorch/lightning-thunder/thunder/benchmarks/benchmark_litgpt.py \
  --model_name Gemma-2-2b \
  --distributed_mode fsdp \
  --shard_mode zero2 \
  --compile dynamo_thunder \
  --checkpoint_activations False \
  --low_precision_mode none \
  --micro_batch_size 1
```
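To narrow down which step pushes a rank over the memory limit while reproducing this, a small per-rank memory logger can be dropped into the training loop. This is only a sketch using standard PyTorch CUDA memory APIs; the `log_cuda_memory` helper is not part of `benchmark_litgpt.py` and would need to be wired in manually.

```
# Sketch only: per-rank CUDA memory logging to help localize the OOM.
# log_cuda_memory is a hypothetical helper, not part of benchmark_litgpt.py.
import os
import torch


def log_cuda_memory(tag: str) -> None:
    """Print current/peak allocated and reserved CUDA memory for this rank."""
    if not torch.cuda.is_available():
        return
    rank = int(os.environ.get("RANK", 0))
    dev = torch.cuda.current_device()
    allocated = torch.cuda.memory_allocated(dev) / 2**30
    peak = torch.cuda.max_memory_allocated(dev) / 2**30
    reserved = torch.cuda.memory_reserved(dev) / 2**30
    print(f"[rank {rank}] {tag}: allocated={allocated:.2f} GiB "
          f"peak={peak:.2f} GiB reserved={reserved:.2f} GiB")


# Example usage around a training step (assumes model/optimizer exist):
# torch.cuda.reset_peak_memory_stats()
# loss = model(batch); loss.backward(); optimizer.step()
# log_cuda_memory("after step")
```

Comparing the peak values between the last known-good image and pjnl-20241107 should show whether the regression is a gradual increase or a sudden allocation spike.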
Environment
system.device_product_name: DGXH100
system.gpu_driver_version: 535.129.03
libraries.cuda: 12.6.98.001
libraries.pip.lightning: 2.4.0.dev20240728
libraries.pip.lightning-thunder: 0.2.0.dev0
libraries.pip.lightning-utilities: 0.11.8
libraries.pip.litgpt: 0.4.11
libraries.pip.nvfuser: 0.2.22+gitba4f7d4
libraries.pip.pytorch-lightning: 2.4.0
libraries.pip.torch: 2.6.0a0+gita9b4989
libraries.pip.torchao: 0.6.1
libraries.pip.torchmetrics: 1.5.1
libraries.pip.torchvision: 0.19.0a0+d23a6e1