pytorch-lightning
CPU-Memory keeps accumulating during `trainer.predict`
Bug description
This is very similar to the closed issue #15656.
I am working on prediction with the PL Trainer on 3D images, which are huge, and my process keeps getting killed when a large number of samples need to be predicted. I found #15656 and expected that to be the solution, but setting `return_predictions=False` does not fix the memory accumulation.
What seems to work instead is adding a `gc.collect()` call in the `predict_loop`. This keeps CPU memory usage constant, as would be expected.
It seems like setting `return_predictions=False` should stop the memory accumulation, but I'm confused as to why the `gc.collect()` is needed.
This is where the `gc.collect()` is applied: https://github.com/project-lighter/lighter/blob/07018bb2c66c0c8848bab748299e2c2d21c7d185/lighter/callbacks/writer/base.py#L120
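Roughly, the workaround looks like this when sketched as a standalone callback (illustrative only; the class name and the `model` / `predict_loader` objects are placeholders, the actual code lives in the lighter writer linked above):

```python
import gc

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Callback


class FreeMemoryOnPredict(Callback):
    """Illustrative callback: clear retained predictions and force GC after each predict batch."""

    def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
        # Drop anything the predict loop may have retained for this run
        # (with return_predictions=False this should already be empty).
        trainer.predict_loop._predictions = [
            [] for _ in range(trainer.predict_loop.num_dataloaders)
        ]
        # Force a garbage-collection pass; this is what keeps CPU memory flat.
        gc.collect()


# Usage (model and predict_loader are placeholders):
# trainer = Trainer(callbacks=[FreeMemoryOnPredict()])
# trainer.predict(model, dataloaders=predict_loader, return_predictions=False)
```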
I've also attached memory logs recorded with `scalene`, comparing the `return_predictions=False` run against the `gc.collect()` run. As you can see, there is no memory growth with `gc.collect()`.
Would you be able to provide any intuition on this? It would be much appreciated!
What version are you seeing the problem on?
v2.1
How to reproduce the bug
No response
Error messages and logs
gc_collect.pdf return_predictions_false.pdf
Environment
Current environment
#- Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow):
#- PyTorch Lightning Version (e.g., 1.5.0):
#- Lightning App Version (e.g., 0.5.2):
#- PyTorch Version (e.g., 2.0):
#- Python version (e.g., 3.9):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning(`conda`, `pip`, source):
#- Running environment of LightningApp (e.g. local, cloud):
More info
No response
cc @borda
@surajpaib In addition to your `gc.collect()` call, I see you do `trainer.predict_loop._predictions = [[] for _ in range(trainer.predict_loop.num_dataloaders)]`, but based on your description (`return_predictions=False`), this should already be an empty list. Can you confirm? In any case, I can't tell why it is necessary, but if you want we can add the `gc.collect()` call in the loop, provided it doesn't impact the iteration speed / throughput (it might be expensive in certain situations).
`return_predictions=False` wasn't working without `gc.collect()`. Since we needed to call `gc.collect()` anyway, we figured we'd just clean the predictions manually right there too via `trainer.predict_loop._predictions = [[] for _ in range(trainer.predict_loop.num_dataloaders)]` and not deal with `return_predictions` until it's fixed.
To add to this, there is a minor difference in memory usage over time with and without clearing `trainer.predict_loop._predictions` once `gc.collect()` is added.
Given that our batch inferences take a long time (3D images), the `gc.collect()` call doesn't seem to have much influence on iteration speed by comparison. But this would need additional testing for the general case.
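One way to test that for the general case (a sketch, not something we have run) would be to log per-batch predict time and compare runs with and without the `gc.collect()` workaround enabled:

```python
import time

from pytorch_lightning.callbacks import Callback


class PredictBatchTimer(Callback):
    """Log per-batch predict time so runs with/without gc.collect() can be compared."""

    def on_predict_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx=0):
        # Record the start time of the current predict batch.
        self._t0 = time.perf_counter()

    def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
        # Print elapsed time for this batch.
        print(f"batch {batch_idx}: {time.perf_counter() - self._t0:.3f}s")
```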
What I still don't get is how the memory accumulates when `return_predictions` is set to `False`. I assume this should not collect any predictions and therefore show no memory growth, but that doesn't seem to be the case.