How to reduce memory usage of backpropagation?
I implemented the tiny NeRF example using candle here: https://github.com/laptou/nerfy/blob/fc50dbd61c4012d1f12f556a72474b59a8b3c158/examples/tiny_nerf.rs
The original example, which is written using TensorFlow, runs fine on my laptop. My candle implementation consumes all available memory, which crashes my desktop session if I use the CPU, and errors out with a CUDA memory allocation error if I use the GPU. I'm running on a laptop with 32 GB of RAM, 32 GB of swap, and an RTX A3000 with 12 GB of VRAM.
I'm barely able to run it on CPU if I decrease the hidden layer size from 256 to 64.
I tracked the memory allocations using heaptrack, and it seems like most of them are related to keeping track of the operations for backpropagation.
Can you spot any obvious issues in my implementation that are causing it to consume so much memory? Is there a way that I can disable or reduce this behavior in some parts of the code to reduce the amount of memory that it uses?
Here's a screenshot from heaptrack that I am using to blame backpropagation for my memory woes.
The program errors out with Error: Cuda(Cuda(DriverError(CUDA_ERROR_OUT_OF_MEMORY, "out of memory"))) almost immediately if I try to use CUDA, so I think there are additional issues at play here besides the backpropagation implementation possibly being space-inefficient.
It seems like the reference count of a tensor (and therefore its Storage) never reaches zero until the tensor can no longer be used for backpropagation at all. This seems related to this comment.
I reimplemented Tiny NeRF using tch. Neither implementation works on CUDA using f64. The tch version works much better using f32, but the candle version is not much better (it still doesn't work on CUDA). On CPU, memory usage is significantly lower for both the f32 and f64 tch implementations than for their corresponding candle implementations.
Interesting, the memory usage of candle during backprop is certainly not optimal.
Could you maybe measure the memory footprint after doing the forward pass but before doing the backward pass? If these are on par between tch and candle, it would mean that the tracking during the backprop is the issue - this one should be easy to improve. If these are already very different, I'm not really sure what would be going on - maybe tch somehow notices that some parts of the graph are not necessary for backprop, but I'm not sure which parts, so trimming down the example might be a good thing (it's certainly a bit long at the moment).
Also, obviously all this should only be an issue if you care about backprop; if you only care about running inference, it's easy to disable gradient tracking, and the memory footprint should be reduced a lot.
Also, could you give it a try with #1243? This should improve memory consumption only in the backprop part, by avoiding keeping all the intermediary values stored until the end of the backward pass.
When I try with #1243, I can successfully set the batch size back to 16. Thank you very much; that seems to have solved my problem. But I would like to ask: does it make a difference when training with or without grad.detach()?
Using #1243 shouldn't make a difference as long as you're not doing very hacky things such as taking the second-order derivative by applying grad to grad. If you just have a normal optimizer loop with SGD or AdamW, I would expect the behavior to be exactly identical.
Thanks a lot. That really helped me.
I've merged #1243 as it seems to fix things for you. @laptou, it would be great to hear about your experience with this too!
I wouldn't expect the new behavior to be different, and the additional compute cost should be very small (tensors are detached lazily, i.e. nothing happens if they were already detached, which will often be the case). However, if you notice anything weird, you can get the old behavior back by setting the following environment variable: CANDLE_GRAD_DO_NOT_DETACH=1.
I'll test it out soon!
Just tested it, here are my observations:
torch cuda f32 8x256: mem usage = 6398 MiB, train time = 4m30s
torch cuda f64 8x256: crash due to gpu oom
torch cpu f32 8x256: mem usage = 6304 MiB, train time = 74m17s
before merge:
candle cuda f32 8x256: crash due to gpu oom during forward pass
candle cuda f32 8x64: crash due to gpu oom at backpropagation step
after merge:
candle cuda f32 8x256: crash due to gpu oom during forward pass (running the network forward uses way too much memory)
candle cuda f32 8x64: peak mem usage = 9426 MiB, train time = 4m6s
Running the network forward on a single batch makes GPU memory increase by 1966 MiB (1886 MiB -> 3852 MiB) just to produce a 65536x4 f32 output, probably because the intermediate tensors aren't deleted. It OOMs on 8x256 because too many batches are held in memory at once and each batch consumes a ton of memory.
Thanks for trying this out, that's very interesting.
Do you have a sense of why the PyTorch model wouldn't have to retain the intermediary values? It could be that we have ops on the candle side that get split into multiple sub-ops and we retain every one of these, whereas in PyTorch these only result in a single op that gets retained. It could also be a difference in the model, but I imagine you've made them pretty similar - also, you should ensure that the set of variables in the VarMap only contains the variables that you want the optimizer to work on, so as not to trigger unneeded dependencies.
Could you measure the PyTorch memory usage for 8x64? Also maybe measure them in the forward pass only if that's easy?
Sure.
- candle cuda f32 8x64 GPU mem usage at forward pass: 4702 MiB
- candle cuda f32 8x64 GPU mem usage after backprop: 5438 MiB
- torch cuda f32 8x64 GPU mem usage at forward pass: 1912 MiB
- torch cuda f32 8x64 GPU mem usage after backprop: 1982 MiB
I attached a debugger and placed a breakpoint on the line that calls Optimizer::backward_step, then measured the memory usage using nvidia-smi before and after calling it.
The implementation of the two models is as 1:1 as I could get them. Here's the diff.
Thanks, it's hard to tell just by looking at the code, as the model seems pretty involved. A way to investigate this would probably be to use heaptrack as you already did, while running on CPU just for the forward pass, and see if there are any ops that could be optimized to reduce allocations. The most likely culprit is that some part of the architecture gets broken down into lots of ops on the candle side and we retain values for all these intermediary steps, whereas on the PyTorch side these parts may be able to use a far smaller number of ops.
I ran into the same high memory usage trying to train my implementation of FCN; it uses pretty much twice the memory that PyTorch does:
Before forward: 1250MiB/7930MiB
After forward: 3138MiB/7930MiB
Before backward: 3102MiB/7930MiB
After backward: 7841MiB/7930MiB
VRAM usage was captured with this code; I then proceeded to annotate the main backward method to see where the memory jumps:
vram_through_backpropagation.txt
It appears to creep up with each operation: the majority of the increase happens in the Op::Conv2D backward propagation for my network. One instance, for example, shows:
[ 3449MiB/7930MiB ] Remove node: TensorId(12152) Tensor[dims 3, 4096, 9, 9; f32, cuda:0]
[ 3449MiB/7930MiB ] Detached grad:: TensorId(12373)
[ 3449MiB/7930MiB ] Start of Op::Conv2D
[ 4633MiB/7930MiB ] End of Op::Conv2D
[ 4633MiB/7930MiB ] After node iter: 14
This shows a large jump in backprop.rs when the Op::Conv2D is handled. But we also see it increase by 32 MiB in the addition just following the detached 12282 tensor.
After the backpropagation finishes, the total size of the gradients is 556408028 bytes, so roughly 530 MiB, which is an expected amount, but the overall VRAM grew by more than a factor of two, to over 7 GiB.
I proceeded to write a singleton to track tensor creation and deletion, but unfortunately those backtraces do not point at a smoking gun, because tensors appear to be dropped while the strong count on their storage is still more than one - likely because storage can be shared between tensors? The backtraces do point at the math operations performed throughout the backpropagation calculation itself...
I did take a stab at sprinkling detach() throughout the backprop function in e357f2bd58d3690a7bf52adae38e577becac792b, which reduced memory usage by 2 GB:
Before forward: 1264MiB/7930MiB
After forward: 3152MiB/7930MiB
Before backward: 3120MiB/7930MiB
After backward: 5816MiB/7930MiB
but I then thought that there must be a better, more correct way to do this, perhaps with a thread-local bool or something similar that allows us to disable recording backpropagation information while some RAII object is held?
I've already spent quite a bit of time trying to understand the issue and attempt to solve it, and I'm not sure how much more time I want to commit, so thought I'd share my findings thus far.
I've found that one thing that makes backprop use significantly more memory (and compute time) than necessary is that it computes the gradients for all tensors that appear anywhere inside a BackpropOp, even the ones that are not on the "path" to the variables. For example, if you do backprop on a tensor computed from T = T1 + (T2 + (T3 + V1)), we store the gradients for T1, T2, and T3 in the gradstore even though we have already determined (via the boolean stored in the already_seen hashmap) that we don't need to track their gradients.
For large models this makes a particularly big difference, especially because these extra grads never get dropped during the sorted_nodes loop: since we exclude these untracked tensors from sorted_nodes, they never even hit the grads.remove line that would normally drop the temporary intermediate gradients.
The way the code is currently structured, it's a bit tricky to prevent computing the gradients of these unnecessary "branch" tensors. I have a prototype that uses the booleans stored in the already_seen map to avoid computing these unnecessary gradients at all. This restructuring of the code also makes it easy to take care of the additional detaches that @iwanders found are still necessary (https://github.com/huggingface/candle/pull/1243 only detaches gradients upon removing them from the queue, but they can still accrue BackpropOps and OOM the GPU on their way into the queue).
Anecdotally, I've found that excluding these unnecessary gradients from the gradstore (in addition to @iwanders' additional detaches) makes a substantial difference in memory usage -- in my case (LoRA training on Llama 3) it was the difference between OOMing a 48 GB GPU 5% of the way through backprop and making it all the way through backprop and being able to train continuously. @LaurentMazare, would it be helpful if I cleaned up that prototype into a standalone pull request with some empirical measurements to quantify the impact?