torchrec
[Question] Is there gradient accumulation support for training?
I am tuning hyper-parameters on two different compute clusters. Since the number of GPUs differs between them, I need gradient accumulation (GA) so that the total (global) batch size is the same on both. Does torchrec support GA?
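As a rough illustration of the arithmetic behind the question: with gradient accumulation, each rank keeps its local batch size fixed and accumulates gradients over several micro-batches before one optimizer step, so the effective global batch is local batch × number of GPUs × accumulation steps. The helper below is a minimal sketch under that assumption. The name `accumulation_steps` and the example numbers are hypothetical and are not part of TorchRec.

```python
def accumulation_steps(target_global_batch: int, local_batch: int, world_size: int) -> int:
    """Hypothetical helper: micro-batches each rank accumulates before one optimizer step.

    Assumes plain data parallelism, so one micro-batch step across all ranks
    covers local_batch * world_size samples.
    """
    per_step = local_batch * world_size
    if target_global_batch % per_step != 0:
        raise ValueError(
            f"global batch {target_global_batch} is not a multiple of "
            f"local_batch * world_size = {per_step}"
        )
    return target_global_batch // per_step
```

For example, with a hypothetical target of 8192 samples and a local batch of 512, a 16-GPU cluster needs 1 accumulation step and an 8-GPU cluster needs 2.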
Although this is a feature I'm looking for as well, considering that the embedding lookup backend is FBGEMM, which fuses the optimizer update into the backward pass at every single step, I would not expect GA to be supported.
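To see why a fused backward-plus-update is not the same as accumulation, here is a toy scalar model (an assumption for illustration, not FBGEMM's kernel): a simplified Adagrad-style optimizer applied once per micro-batch versus once on the averaged gradient. The gradients are held fixed so that only the optimizer's behavior differs. The function names are hypothetical.

```python
import math


def adagrad_step(w: float, state: float, g: float, lr: float = 0.1, eps: float = 1e-8):
    """Simplified scalar Adagrad: accumulate squared grads, scale the step by their sqrt."""
    state = state + g * g
    w = w - lr * g / (math.sqrt(state) + eps)
    return w, state


def fused_per_microbatch(w: float, grads: list, lr: float = 0.1) -> float:
    # Mimics a fused backward+update: the optimizer runs once per micro-batch.
    state = 0.0
    for g in grads:
        w, state = adagrad_step(w, state, g, lr)
    return w


def accumulated(w: float, grads: list, lr: float = 0.1) -> float:
    # Gradient accumulation: average the micro-batch grads, then take ONE optimizer step.
    g_avg = sum(grads) / len(grads)
    w, _ = adagrad_step(w, 0.0, g_avg, lr)
    return w
```

With two micro-batches that both produce a gradient of 1.0, the per-micro-batch path ends near -0.171 while the accumulated path ends near -0.1, because both the optimizer state and the number of steps differ.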
Hi Jaco. In your experience, how hard would it be to add this GA functionality to the FBGEMM CPU/CUDA kernels?
Hi @gouchangjiang, I'm not an FBGEMM expert, but I think it's not a trivial amount of work. Though it's feasible, it may violate the design principle of FBGEMM.
The principle of FBGEMM is to eliminate the wgrad write-back, so users cannot access the wgrad. You can of course allocate a buffer, pass it into the backward kernels, and remove the update and optimizer-state code (the original FBGEMM kernel code is templated on the optimizer and partially instantiated). But you would have to pay for:
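A back-of-the-envelope sketch of the memory cost of exposing the wgrad: a dense wgrad needs one value per table element, whereas a row-sparse wgrad only stores the rows touched in the batch plus their indices. The sizes below are hypothetical; fp32 values and one int64 index per stored row are assumed, and the helper names are illustrative only.

```python
def dense_wgrad_bytes(num_rows: int, dim: int, bytes_per_value: int = 4) -> int:
    # A dense wgrad has one value for every element of the embedding table.
    return num_rows * dim * bytes_per_value


def sparse_wgrad_bytes(
    num_unique_ids: int, dim: int, bytes_per_value: int = 4, bytes_per_index: int = 8
) -> int:
    # A row-sparse wgrad stores one gradient row plus one row index per unique id
    # touched in the batch. num_unique_ids changes per batch, so the shape is dynamic.
    return num_unique_ids * (dim * bytes_per_value + bytes_per_index)
```

With a hypothetical 10M × 128 fp32 table, the dense buffer is 5.12 GB, while 200,000 unique ids need about 104 MB as a row-sparse buffer. The catch is that the unique-id count, and therefore the sparse buffer's shape, varies from batch to batch.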
- Extra memory footprint and time. Typically the wgrad is a sparse tensor (you may not want a dense one), and thus its shape is dynamic.
- Sparse tensor accumulation and exposure of `update`. GA means that you have to explicitly trigger an `update` method. If the wgrad is a sparse tensor, you have to implement your own accumulation operations and optimizer.
- An adapter from FBGEMM to TorchRec EBC/EC. TorchRec has a deep call stack; even if you manage to expose the wgrad from FBGEMM, you still need changes in the TorchRec codebase.
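The accumulation-plus-explicit-`update` pattern from the second bullet can be sketched in plain Python. `SparseGradAccumulator` is hypothetical, not a TorchRec or FBGEMM API: it keeps a dictionary from row id to accumulated gradient (so its size depends on which ids were seen), and a separate `update` call applies plain SGD with the gradient averaged over micro-batches. A real implementation would operate on GPU tensors and would need to reproduce whatever optimizer you actually train with.

```python
class SparseGradAccumulator:
    """Hypothetical helper: accumulates row-wise embedding grads across micro-batches."""

    def __init__(self, dim: int):
        self.dim = dim
        # row id -> accumulated gradient vector; grows with the set of unique ids seen.
        self.buffer = {}
        self.num_microbatches = 0

    def accumulate(self, ids, grads):
        # ids: row ids looked up in this micro-batch (may repeat).
        # grads: one gradient vector of length dim per looked-up id.
        for row_id, g in zip(ids, grads):
            acc = self.buffer.setdefault(row_id, [0.0] * self.dim)
            for j in range(self.dim):
                acc[j] += g[j]
        self.num_microbatches += 1

    def update(self, table, lr: float):
        # Explicit optimizer step (plain SGD here) on the touched rows only,
        # using the gradient averaged over micro-batches; then reset the buffer.
        for row_id, acc in self.buffer.items():
            row = table[row_id]
            for j in range(self.dim):
                row[j] -= lr * acc[j] / self.num_microbatches
        self.buffer.clear()
        self.num_microbatches = 0
```

For example, after `acc.accumulate([1, 3], ...)` and `acc.accumulate([1], ...)`, the buffer holds only rows 1 and 3; nothing touches the table until `acc.update(table, lr)` is called.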
Thank you @JacoCheung. That's quite a lot of work.
You can use a dense optimizer if you like and do gradient accumulation that way. However, I would advise against it for performance reasons, since the FBGEMM fused update is a huge performance improvement: it avoids committing the wgrad to memory. Depending on the number of trainers, you can instead change the local batch size so that the global batch size matches for both runs.
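The alternative suggested here, keeping one fused optimizer step per iteration and scaling the per-rank batch instead, reduces to a division. The helper name and numbers are hypothetical; it assumes data-parallel training where every rank processes the same local batch size.

```python
def local_batch_size(global_batch: int, world_size: int) -> int:
    """Hypothetical helper: per-rank batch so that all ranks together see global_batch per step."""
    if global_batch % world_size != 0:
        raise ValueError(
            f"global batch {global_batch} is not divisible by world_size {world_size}"
        )
    return global_batch // world_size
```

For a hypothetical global batch of 8192, a 16-GPU cluster would use a local batch of 512 and an 8-GPU cluster 1024. The larger local batch still has to fit in GPU memory, which is the main constraint of this approach.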