score_inverse_problems

Host memory leak during training

Open · cobalamin opened this issue 2 years ago · 2 comments

I'm experiencing an issue where host memory usage keeps increasing during training until my machine freezes up or the process is killed for running out of memory (I have 32GB available). Everything else about training seems to work fine until it crashes (after around 5000 iterations), and GPU memory is also fine: its usage stays completely constant. I've tried several versions of jax/jaxlib/flax, but none of them changed this behavior. I've attached the output of pip freeze from my virtualenv.
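One way to confirm that the growth is in host memory, and to see whether it is linear or eventually levels off, is to log the process's resident set size every few hundred steps. The sketch below is a minimal version of that, assuming a Linux host: it reads the current RSS from `/proc/self/statm` and falls back to the *peak* RSS from Python's standard `resource` module elsewhere. `host_rss_mib` and `maybe_log_host_memory` are hypothetical helper names, not part of score_inverse_problems or JAX.

```python
import os
import resource
import sys


def host_rss_mib() -> float:
    """Return this process's resident set size (host RAM) in MiB.

    On Linux this is the *current* RSS, read from /proc/self/statm.
    Elsewhere it falls back to the *peak* RSS from resource.getrusage.
    """
    try:
        with open("/proc/self/statm") as f:
            # Second field of statm = number of resident pages.
            resident_pages = int(f.read().split()[1])
        return resident_pages * os.sysconf("SC_PAGE_SIZE") / 2**20
    except (OSError, ValueError, IndexError):
        peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
        # ru_maxrss is reported in bytes on macOS and in KiB on Linux.
        return peak / 2**20 if sys.platform == "darwin" else peak / 2**10


def maybe_log_host_memory(step: int, every: int = 100) -> None:
    """Print host RSS every `every` training steps (hypothetical helper)."""
    if step % every == 0:
        print(f"step {step}: host RSS {host_rss_mib():.1f} MiB", flush=True)
```

Calling `maybe_log_host_memory(step)` once per iteration of the training loop gives a curve to compare against the crash point around 5000 iterations: steady growth at every step suggests something is retained per iteration, while growth that flattens out after one pass over the data points more toward caching or buffering.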

Any clues what could be causing this? I searched for JAX memory leaks on Google/StackOverflow, but didn't find anything that seemed useful/related.

pip-environment.txt

cobalamin · Jul 07 '22 14:07

Hi, I am a student at SYSU. In my experience, the GPU cannot be used with this JAX version, and it is better to use a TPU. After testing, I also found it is better to have at least 48GB of video memory (an A100), and jaxlib 1.69-1.73 works better. The first time I tried the JAX framework, this problem troubled me for a long time. I hope this helps.

tianzhijiaoziA · Jul 24 '22 08:07

Hi, thank you for your comment. I'm unable to use a TPU and only have GPUs available for training. Also, GPU memory is not the problem for me here; host memory is.

I was able to partially work around the issue by making my swapfile ridiculously large (32GB), since the increase in memory usage does seem to stop eventually. It still looks like a memory leak to me, though. Perhaps memory keeps increasing until every example has been seen once (are they all loaded into memory and kept there?).
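Growth that stops once every example has been seen is the pattern an in-memory cache of the training set would produce, whereas a true leak keeps growing with every step. The toy simulation below illustrates that difference. `CachingLoader` and `simulate` are hypothetical names, and it is an assumption, not something established in this thread, that the repository's input pipeline caches examples at all.

```python
import random


class CachingLoader:
    """Toy loader that keeps every example in memory after its first read.

    Hypothetical stand-in for a data pipeline with an in-memory cache;
    it is not the score_inverse_problems input pipeline.
    """

    def __init__(self, num_examples: int, example_bytes: int = 1024):
        self.num_examples = num_examples
        self.example_bytes = example_bytes
        self._cache = {}  # example index -> bytes, never evicted

    def _read_from_disk(self, idx: int) -> bytes:
        # Stand-in for loading and decoding one example.
        return bytes([idx % 256]) * self.example_bytes

    def get(self, idx: int) -> bytes:
        if idx not in self._cache:
            self._cache[idx] = self._read_from_disk(idx)
        return self._cache[idx]

    def cached_bytes(self) -> int:
        return len(self._cache) * self.example_bytes


def simulate(num_examples=1000, batch_size=32, num_steps=100, seed=0):
    """Return the cache size in bytes after each step, reshuffling every epoch."""
    loader = CachingLoader(num_examples)
    rng = random.Random(seed)
    order = []
    sizes = []
    for _ in range(num_steps):
        batch = []
        while len(batch) < batch_size:
            if not order:  # start a new epoch with a fresh shuffle
                order = list(range(num_examples))
                rng.shuffle(order)
            batch.append(loader.get(order.pop()))
        sizes.append(loader.cached_bytes())
    return sizes


sizes = simulate()
print(sizes[0], sizes[30], sizes[31], sizes[-1])
```

With 1000 examples and a batch size of 32, the cache grows on every step of the first epoch and stops changing at step index 31, the first step by which all examples have been drawn (32 × 32 = 1024 ≥ 1000). In TensorFlow input pipelines, one common source of this shape is `tf.data.Dataset.cache()` called without a filename, which keeps elements in memory; a large shuffle buffer filling up behaves similarly.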

cobalamin · Jul 28 '22 14:07