[nn] zero_grad() set_to_none default True
:link: Helpful Links
:test_tube: See artifacts and rendered test results at hud.pytorch.org/pr/92731
- :page_facing_up: Preview Python docs built from this PR
- :page_facing_up: Preview C++ docs built from this PR
- :question: Need help or want to give feedback on the CI? Visit the bot commands wiki or our office hours
Note: Links to docs will display an error until the docs builds have been completed.
:hourglass_flowing_sand: No Failures, 1 Pending
As of commit b589af5b5f7a69e127a7d465731826fa3ddacecb: :green_heart: Looks good so far! There are no failures yet. :green_heart:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Can you please add a bc-breaking note here?
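For readers skimming the thread, here is a minimal sketch of the behavior change being asked about (the `nn.Linear` model and shapes below are just for illustration):

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
model(torch.randn(8, 4)).sum().backward()

# New default from this PR: zero_grad() detaches the grads and sets them to None.
model.zero_grad()
print(model.weight.grad)  # None

# The old behavior is still available explicitly: grads stay allocated and are
# filled with zeros in place.
model(torch.randn(8, 4)).sum().backward()
model.zero_grad(set_to_none=False)
print(model.weight.grad)  # tensor of zeros with the same shape as weight
```

Code that reads `.grad` right after zeroing now has to handle `None`, or opt back into the old behavior explicitly.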
@pytorchbot merge
Merge started
Your change will be merged once all checks pass (ETA 0-4 Hours).
Learn more about merging in the wiki.
Questions? Feedback? Please reach out to the PyTorch DevX Team
Advanced Debugging
Check the merge workflow status here
ooooh exciting. this is a big change :)
This PR leads to speedups for many models in torchbench: 22 models get over a 1.03x speedup on A100! Thanks for your work! But I just found that yolov3 has about a 17% slowdown for training on A100. The nightly regression test of torchbench currently runs on T4, so it may have missed this case. I'm working on finding the root cause now; so far I can't reproduce the same results in the trace generated by the PyTorch profiler.
@FindHao I recently came back from PTO, so sorry for the delayed response. Thanks for the callout! I'm curious about the yolov3 slowdown: have you been able to root-cause it so far? The simple workaround is to pass set_to_none=False directly to regain the perf, but I would like to help figure out the cause here.
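For reference while this is being root-caused, a minimal sketch of that workaround (the model, optimizer, and hyperparameters below are placeholders):

```python
import torch

model = torch.nn.Linear(4, 2)                       # placeholder model
opt = torch.optim.SGD(model.parameters(), lr=0.01)  # placeholder optimizer

for _ in range(3):
    loss = model(torch.randn(8, 4)).sum()
    loss.backward()
    opt.step()
    # Opt back into the pre-PR behavior: grads are zeroed in place instead of
    # being freed, so there is no free/re-alloc churn between iterations.
    opt.zero_grad(set_to_none=False)
```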
Hi @janeyx99, we found it is caused by torch.cuda.empty_cache(), which takes longer than in the original version. Since torchbench only tests one training iteration and we don't need to empty the cache, we removed this call as a workaround. But we still don't know why it takes longer. Do you have any ideas?
@FindHao Ah, I spoke with @albanD and this is not surprising. When set_to_none was False, the same grad tensor was allocated once and kept alive throughout the iterations (it would be filled with 0s and then filled with values, and so forth).
Now, because we set grad to None, the number of allocations increases as we alternate between None -> real values -> None -> real values, and so on. This is typically not a problem, except for certain configurations (e.g. a particular batch_size on a particular GPU) where the stars align just right and the allocations incur actual communication with the GPU instead of being serviced from already-allocated memory.
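A small sketch of that alternation, assuming a CUDA device is available; whether the caching allocator re-serves the same cached block (the common case) or has to talk to the GPU depends on the surrounding allocation pattern, so the pointers printed here are only illustrative:

```python
import torch

model = torch.nn.Linear(4, 2, device="cuda")

for step in range(3):
    model(torch.randn(8, 4, device="cuda")).sum().backward()  # allocates .grad
    ptr = model.weight.grad.data_ptr()
    model.zero_grad(set_to_none=True)  # releases the grad storage back to the caching allocator
    assert model.weight.grad is None
    # The next backward() has to allocate .grad again. Usually the allocator
    # hands back the same cached block (the printed pointers repeat), but that
    # is not guaranteed, and this is exactly where configuration-specific churn can appear.
    print(step, hex(ptr))
```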
@janeyx99 Thanks for your explanation! It makes sense. I have another question. If I understand correctly, setting it to None means marking the memory allocated to the current tensors as to-be-freed, and it will be deallocated from the memory pool by torch.cuda.empty_cache(), right? If so, does that mean memory usage would keep increasing until we call empty_cache()?
No, set_to_none=True decreases memory usage: it frees the grad memory when called and doesn't allocate it again until the grads are recomputed (which will likely be after the high-memory watermark is reached). @robieta had plots showing how set_to_none decreases memory usage.
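The effect is easy to see with the allocator counter; this assumes a CUDA device and uses a toy model purely for illustration:

```python
import torch

def allocated_after_zero(set_to_none):
    model = torch.nn.Linear(1024, 1024, device="cuda")
    model(torch.randn(64, 1024, device="cuda")).sum().backward()
    model.zero_grad(set_to_none=set_to_none)
    return torch.cuda.memory_allocated()

# With True, grad memory is freed immediately, so less memory is allocated going
# into the next forward pass; with False, the zero-filled grads stay resident.
print(allocated_after_zero(True), allocated_after_zero(False))
```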
Haha I will attempt to answer this question by setting down some terminology. PyTorch has a CUDACachingAllocator which reserves and manages memory for the duration of a PyTorch program. In a sense, you can imagine that it reserves a chunk of memory from the GPU and sits on top of the actual GPU so that every time the program releases/requests memory, we don't have to talk to the GPU. For example, if the PyTorch program releases memory, our CachingAllocator will hold onto it instead of immediately returning it to the GPU, so that later on when the program wants memory again, it can lend that memory out. This saves time because communication with the GPU is avoided entirely.
Thus, we have the concept of memory reserved and memory allocated. The memory reserved is the total memory managed by the CUDACachingAllocator, and the memory allocated is the memory taken up by actual PyTorch tensors. Setting to None here will "free" the tensor so that the memory allocated goes down immediately BUT the memory reserved would remain the same. Calling torch.cuda.empty_cache() will empty the memory so that memory reserved approaches memory allocated.
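Putting the two counters side by side in a toy example (assumes a CUDA device; exact numbers will vary):

```python
import torch

x = torch.randn(1024, 1024, device="cuda")
print(torch.cuda.memory_allocated(), torch.cuda.memory_reserved())
# allocated: bytes backing live tensors; reserved: bytes the CUDACachingAllocator
# is holding onto from the GPU

del x  # allocated drops, reserved stays: the freed block is cached for reuse
print(torch.cuda.memory_allocated(), torch.cuda.memory_reserved())

torch.cuda.empty_cache()  # reserved shrinks back toward allocated; returning
                          # memory to the GPU is the work that shows up in a profile
print(torch.cuda.memory_allocated(), torch.cuda.memory_reserved())
```

In a steady-state training loop you rarely need empty_cache() at all; the cache exists precisely to avoid that round trip to the GPU.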
@ngimel @janeyx99 Thanks for all your explanation!