OOM with rematerialization when torch.compile works
🐛 Bug
When benchmarking the model 'Mixtral-8x7B-v0.1' we get OOM errors even with --checkpoint_activations True. The same configuration works with torch.compile. This might be related to https://github.com/Lightning-AI/lightning-thunder/issues/194.
The same issue occurs for falcon-180B.
To Reproduce
Please use:
8 node(s), each with 8 GPUs.
Image "INTERNAL_IMAGE:pjnl-20240930"
Training script:
python /opt/pytorch/lightning-thunder/thunder/benchmarks/benchmark_litgpt.py \
    --model_name Mixtral-8x7B-v0.1 \
    --distributed_mode fsdp \
    --shard_mode zero3 \
    --compile dynamo_thunder \
    --checkpoint_activations True \
    --low_precision_mode none \
    --micro_batch_size 1
Expected behavior
We should be able to run the script with the given arguments.
Environment
system.device_product_name DGXH100
system.gpu_driver_version 535.129.03
libraries.cuda 12.6.2.004
libraries.pip.lightning 2.4.0.dev20240728
libraries.pip.lightning-thunder 0.2.0.dev0
libraries.pip.lightning-utilities 0.11.7
libraries.pip.litgpt 0.4.11
libraries.pip.nvfuser 0.2.13+git2cee59d
libraries.pip.pytorch-lightning 2.4.0
libraries.pip.torch 2.6.0a0+gitc4ae451
libraries.pip.torchmetrics 1.4.2
libraries.pip.torchvision 0.19.0a0+d23a6e1
@riccardofelluga can you take a look here?
I've looked into it a bit over the last couple of days, and it basically boils down to the scheduling of operations in the trace. This happens both with and without saved-for-backward rematerialization enabled. My hypothesis is that when a fusion region is created, the executor claims operations without taking into account when the outputs of the computed tensors will actually be used. This leads to tensors that live in memory for a long time while being unused for most of their lifetime. I'm looking into our options for fixing this; one idea would be to reorder the computation so that producers sit closer to their consumers.
The same is valid for #246 and the other OOMs. The reason it shows up with saved-for-backward rematerialization is that rematerialization makes the problem worse: it adds a bunch of computation at the start of the backward trace, whose results are only consumed much later.
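To make the producer-consumer reordering idea concrete, here is a toy sketch. It is purely illustrative and not Thunder's trace representation or API: `Op`, `peak_live`, and `sink_producers` are invented names. It estimates the peak amount of live tensor memory for a schedule, then reorders ops so each producer is emitted just before its first consumer. In the toy trace the two large producers are computed up front, much like recomputation grouped at the start of a backward trace, and sinking them toward their consumers roughly halves the estimated peak.

```python
# Toy model of a trace: a list of ops over named tensors with known sizes.
# NOT Thunder's internal API; everything here is made up for illustration.
from dataclasses import dataclass


@dataclass
class Op:
    name: str
    inputs: tuple[str, ...]
    outputs: tuple[str, ...]


def peak_live(schedule: list[Op], sizes: dict[str, int], trace_outputs: set[str]) -> int:
    """Estimate peak memory: the largest sum of sizes of simultaneously live tensors."""
    last_use = {t: len(schedule) for t in trace_outputs}  # trace outputs live to the end
    for i, op in enumerate(schedule):
        for t in op.inputs:
            last_use[t] = max(last_use.get(t, -1), i)
    live: set[str] = set()
    peak = 0
    for i, op in enumerate(schedule):
        live |= set(op.outputs)
        peak = max(peak, sum(sizes[t] for t in live))
        live = {t for t in live if last_use.get(t, -1) > i}  # free tensors with no later use
    return peak


def sink_producers(schedule: list[Op]) -> list[Op]:
    """Rebuild the schedule from the end, inserting each op just before its first consumer."""
    new_sched: list[Op] = []
    for op in reversed(schedule):
        produced = set(op.outputs)
        pos = next((i for i, later in enumerate(new_sched)
                    if produced & set(later.inputs)), len(new_sched))
        new_sched.insert(pos, op)
    return new_sched


# Two big producers computed up front but consumed only at the very end,
# similar to recomputation placed at the start of a backward trace.
trace = [
    Op("prod_a", (), ("a",)), Op("prod_b", (), ("b",)),
    Op("small1", (), ("s1",)), Op("small2", ("s1",), ("s2",)),
    Op("use_a", ("a", "s2"), ("c",)), Op("use_b", ("b", "c"), ("out",)),
]
sizes = {"a": 100, "b": 100, "s1": 1, "s2": 1, "c": 1, "out": 1}
print(peak_live(trace, sizes, {"out"}))                  # producers live too early -> high peak
print(peak_live(sink_producers(trace), sizes, {"out"}))  # producers next to consumers -> lower peak
```

Running the sketch prints 202 and then 102 size units for the same set of ops, i.e. peak memory can change substantially purely as a function of operation order, without changing what is computed.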
What's the latest status of this problem?
Right now Mixtral-8x7B-v0.1 works with ThunderFX, so we might close this issue. However, it still doesn't work with Thunder.
Thank you for checking! Closing. ThunderFX is a better entry point for getting Thunder optimizations, since it moves some of the responsibility for capturing the program and for implementing sharded data parallelism to PyTorch. We won't reimplement the missing features needed to get plain Thunder working for this model at this time.