
OOM with rematerialization when torch.compile works


🐛 Bug

When benchmarking the model 'Mixtral-8x7B-v0.1' we get OOM errors even with --checkpoint_activations True. The same configuration works with torch.compile. This might be related to https://github.com/Lightning-AI/lightning-thunder/issues/194.

The same issue occurs for falcon-180B.

To Reproduce

Please use 8 nodes, each with 8 GPUs, and the image "INTERNAL_IMAGE:pjnl-20240930".

Training script:

python /opt/pytorch/lightning-thunder/thunder/benchmarks/benchmark_litgpt.py \
    --model_name Mixtral-8x7B-v0.1 \
    --distributed_mode fsdp \
    --shard_mode zero3 \
    --compile dynamo_thunder \
    --checkpoint_activations True \
    --low_precision_mode none \
    --micro_batch_size 1
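
As a side note, here is a minimal sketch (not part of benchmark_litgpt.py; model, optimizer and batch are placeholders) of how peak GPU memory per step can be logged to see how close a given configuration gets to the OOM:

import torch

def training_step_with_memory_log(model, optimizer, batch, step):
    # Start a fresh peak-memory window for this step.
    torch.cuda.reset_peak_memory_stats()
    loss = model(batch).sum()  # placeholder loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)
    peak_gib = torch.cuda.max_memory_allocated() / 2**30
    print(f"step {step}: peak allocated {peak_gib:.1f} GiB")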

Expected behavior

We should be able to run the script with the given arguments.

Environment

system.device_product_name: DGXH100
system.gpu_driver_version: 535.129.03
libraries.cuda: 12.6.2.004
libraries.pip.lightning: 2.4.0.dev20240728
libraries.pip.lightning-thunder: 0.2.0.dev0
libraries.pip.lightning-utilities: 0.11.7
libraries.pip.litgpt: 0.4.11
libraries.pip.nvfuser: 0.2.13+git2cee59d
libraries.pip.pytorch-lightning: 2.4.0
libraries.pip.torch: 2.6.0a0+gitc4ae451
libraries.pip.torchmetrics: 1.4.2
libraries.pip.torchvision: 0.19.0a0+d23a6e1

mpatel31415 avatar Oct 01 '24 10:10 mpatel31415

@riccardofelluga can you take a look here?

tfogal avatar Oct 04 '24 18:10 tfogal

I've looked into this a bit over the last couple of days, and it boils down to the scheduling of operations in the trace. This happens both with and without saved-for-backward rematerialization enabled. My hypothesis is that when a fusion region is created, the executor claims operations without taking into account when the outputs of the computed tensors will actually be used. This leaves tensors that live in memory for a long time and are unused for most of their lifetime. I'm looking into our options for a fix; one idea would be to reorder the computation so that producers sit closer to their consumers.

The same is valid for #246 and the other OOMs. The reason it shows up with saved-for-backward rematerialization is that remat makes the problem worse by adding a bunch of computation at the start of the backward trace.
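
To make the scheduling effect concrete, here is a toy sketch (an illustration only, not Thunder's trace IR or the actual pass): it counts how many intermediate tensors are live at each point in a schedule and shows that moving each producer next to its consumer lowers the peak.

# Toy illustration of the scheduling argument above, not Thunder's trace IR.
# Each op is (name, inputs, output); an output stays live from the op that
# produces it until its last consumer has run.
def peak_live(schedule):
    last_use = {}
    for i, (_, inputs, _) in enumerate(schedule):
        for t in inputs:
            last_use[t] = i
    live, peak = set(), 0
    for i, (_, inputs, out) in enumerate(schedule):
        live.add(out)
        peak = max(peak, len(live))
        live -= {t for t in inputs if last_use[t] == i}
    return peak

# All producers claimed up front, consumers much later: t0..t2 pile up in memory.
producers_first = [
    ("p0", [], "t0"), ("p1", [], "t1"), ("p2", [], "t2"),
    ("c0", ["t0"], "a0"), ("c1", ["a0", "t1"], "a1"), ("c2", ["a1", "t2"], "a2"),
]
# Each producer reordered to sit right before its consumer.
producers_near_consumers = [
    ("p0", [], "t0"), ("c0", ["t0"], "a0"),
    ("p1", [], "t1"), ("c1", ["a0", "t1"], "a1"),
    ("p2", [], "t2"), ("c2", ["a1", "t2"], "a2"),
]
print(peak_live(producers_first))           # 4 tensors live at the peak
print(peak_live(producers_near_consumers))  # 3 tensors live at the peak; the gap grows with depth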

riccardofelluga avatar Oct 07 '24 12:10 riccardofelluga

What's the latest status of this problem?

IvanYashchuk avatar Nov 19 '24 09:11 IvanYashchuk

Right now 'Mixtral-8x7B-v0.1' works with ThunderFX, so we might close this issue. However, it still doesn't work with Thunder.

mpatel31415 avatar Nov 19 '24 11:11 mpatel31415

Thank you for checking! Closing. ThunderFX is a better entry point to get Thunder optimizations while moving some of the responsibility for capturing the program and implementing sharded data parallelism to PyTorch. We won't reimplement the missing features to get plain Thunder working for this model at this time.
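
For reference, a minimal sketch of the ThunderFX path mentioned above, assuming the thunder.dynamo.ThunderCompiler entry point (the exact import may differ between lightning-thunder versions):

import torch
from thunder.dynamo import ThunderCompiler  # assumed entry point; may vary by version

model = torch.nn.Linear(64, 64).cuda()  # placeholder model
backend = ThunderCompiler()
# Dynamo captures the graph; Thunder optimizes and executes the captured regions.
compiled_model = torch.compile(model, backend=backend)
out = compiled_model(torch.randn(8, 64, device="cuda"))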

IvanYashchuk avatar Nov 19 '24 13:11 IvanYashchuk