fastmoe
CUDA memory increases after each loss.backward()
I am trying to use the FMoE layer in a ViT-Base model for a simple classification task. However, CUDA memory grows gradually, which eventually leads to an out-of-memory error. Digging deeper, I observe a small increase in memory after every loss.backward() call. Here is what the memory growth looks like:
Training (50/10000 Steps) Data time=0.03(0.04) Batch time=1.00(0.97) Memory=8588.2(8440.4)
Training (51/10000 Steps) Data time=0.04(0.04) Batch time=0.78(0.96) Memory=8590.5(8443.3)
Training (52/10000 Steps) Data time=0.03(0.04) Batch time=1.13(0.97) Memory=8602.8(8446.3)
Training (53/10000 Steps) Data time=0.04(0.04) Batch time=0.67(0.96) Memory=8602.8(8449.2)
Training (54/10000 Steps) Data time=0.08(0.04) Batch time=1.13(0.96) Memory=8602.8(8452.0)
Training (55/10000 Steps) Data time=0.02(0.04) Batch time=0.85(0.96) Memory=8602.8(8454.7)
Training (56/10000 Steps) Data time=0.06(0.04) Batch time=0.96(0.96) Memory=8602.8(8457.3)
Training (57/10000 Steps) Data time=0.03(0.04) Batch time=1.02(0.96) Memory=8602.8(8459.9)
Training (58/10000 Steps) Data time=0.06(0.04) Batch time=0.81(0.96) Memory=8623.7(8462.5)
Training (59/10000 Steps) Data time=0.04(0.04) Batch time=1.08(0.96) Memory=8623.7(8465.2)
Training (60/10000 Steps) Data time=0.04(0.04) Batch time=0.72(0.96) Memory=8623.7(8467.8)
Training (61/10000 Steps) Data time=0.03(0.04) Batch time=1.11(0.96) Memory=8623.7(8470.4)
Training (62/10000 Steps) Data time=0.02(0.04) Batch time=0.77(0.96) Memory=8623.7(8472.8)
Training (63/10000 Steps) Data time=0.04(0.04) Batch time=1.10(0.96) Memory=8655.3(8475.4)
Training (64/10000 Steps) Data time=0.02(0.04) Batch time=0.88(0.96) Memory=8655.3(8478.2)
Training (65/10000 Steps) Data time=0.04(0.04) Batch time=0.92(0.96) Memory=8667.8(8481.0)
Training (66/10000 Steps) Data time=0.03(0.04) Batch time=1.02(0.96) Memory=8667.8(8483.8)
Training (67/10000 Steps) Data time=0.04(0.04) Batch time=0.72(0.95) Memory=8667.8(8486.6)
Training (68/10000 Steps) Data time=0.03(0.04) Batch time=1.10(0.96) Memory=8667.8(8489.2)
Training (69/10000 Steps) Data time=0.02(0.04) Batch time=0.70(0.95) Memory=8667.8(8491.8)
Training (70/10000 Steps) Data time=0.06(0.04) Batch time=1.09(0.96) Memory=8667.8(8494.3)
Training (71/10000 Steps) Data time=0.02(0.04) Batch time=0.89(0.95) Memory=8667.8(8496.7)
Training (72/10000 Steps) Data time=0.05(0.04) Batch time=0.86(0.95) Memory=8725.6(8499.5)
Training (73/10000 Steps) Data time=0.03(0.04) Batch time=1.01(0.95) Memory=8725.6(8502.5)
Training (74/10000 Steps) Data time=0.04(0.04) Batch time=0.71(0.95) Memory=8725.6(8505.5)
Here is my training loop. If I replace the FMoE Mlp layer with a regular Mlp layer, the training works fine.
for step, batch in enumerate(train_loader):
    data_time.update(time.time() - end)
    batch = tuple(t.to(args.device) for t in batch)
    x, y = batch
    pred = model(x, 0)
    loss_ = loss_fct(pred, y)
    loss_.backward()
    model.allreduce_params()
    optimizer.step()
    optimizer.zero_grad()
    MB = 1024 * 1024
    memory_meter.update(torch.cuda.max_memory_allocated() / MB)
    log.info("Training ({}/{} Steps)\tData time={:.2f}({:.2f})\tBatch time={:.2f}({:.2f})\tMemory={:.1f}({:.1f})".format(
        global_step, t_total, data_time.val, data_time.avg, batch_time.val, batch_time.avg, memory_meter.val, memory_meter.avg))
Could you please help me figure out what might be causing this? Thanks.
The maximum memory may get larger because of a more imbalanced load during the computation. Can you check if torch.cuda.memory_allocated() also gets larger here?
Yes, the memory values I reported are measured using torch.cuda.memory_allocated().
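The distinction matters here: max_memory_allocated() tracks the peak, which can rise simply because one step had a more imbalanced expert load than any step before it, whereas memory_allocated() reports what is currently held by live tensors and should stay roughly flat in a leak-free loop. A small helper along these lines (names are illustrative, not from this thread) can log both per step:

import torch

MB = 1024 * 1024

def log_cuda_memory(step):
    # Currently allocated by live tensors: a steady rise across steps
    # points at something (e.g. an autograd graph) being kept alive.
    allocated = torch.cuda.memory_allocated() / MB
    # Peak since start (or since the last reset): this can grow just
    # because one step saw a more imbalanced expert load than any before.
    peak = torch.cuda.max_memory_allocated() / MB
    print(f"step {step}: allocated={allocated:.1f} MB, peak={peak:.1f} MB")
    # Optional: reset so the peak reflects only the upcoming step.
    # torch.cuda.reset_peak_memory_stats()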
I am not able to reproduce this memory footprint increase using FMoETransformerMLP. What is your FastMoE and PyTorch version? Do you use expert parallelism or only data parallelism? A minimal script that can reproduce the issue would be highly appreciated.
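For reference, a reproduction script can be quite small: a bare FMoETransformerMLP trained on random data while logging memory_allocated(). The sketch below assumes the constructor arguments (num_expert, d_model, d_hidden) of recent FastMoE releases and a single-GPU, data-parallel-only setup; adjust to the installed version.

import torch
from fmoe.transformer import FMoETransformerMLP  # requires FastMoE to be installed

def main():
    device = "cuda"
    # Constructor arguments follow recent FastMoE releases; adjust if the
    # installed version differs.
    moe_mlp = FMoETransformerMLP(num_expert=8, d_model=768, d_hidden=3072).to(device)
    opt = torch.optim.SGD(moe_mlp.parameters(), lr=1e-3)

    for step in range(200):
        x = torch.randn(8, 197, 768, device=device)  # (batch, tokens, d_model)
        loss = moe_mlp(x).pow(2).mean()               # dummy objective
        loss.backward()
        opt.step()
        opt.zero_grad()
        if step % 10 == 0:
            mb = torch.cuda.memory_allocated() / (1024 * 1024)
            print(f"step {step}: allocated={mb:.1f} MB")

if __name__ == "__main__":
    main()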
I was using a slightly modified version inspired by FMoETransformerMLP. I observed that when I use NaiveGate, I do not have the memory issue, so I suspect the memory increase might have something to do with the gate implementation.
Thank you very much for your attention!
Are you using a gate from FastMoE or a customized gate?
I was having the memory issue with a customized gate.
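A common way a customized gate produces exactly this symptom is by storing a tensor that is still attached to the autograd graph across iterations, for example accumulating an auxiliary balance loss without detaching it; the graph from every step then stays alive and memory_allocated() creeps up after each backward(). A minimal sketch of the pitfall and the fix, using a purely illustrative gate (not the gate from this thread):

import torch
import torch.nn as nn

class ToyGate(nn.Module):
    """Illustrative top-k gate; not FastMoE's implementation."""

    def __init__(self, d_model, num_expert, top_k=2):
        super().__init__()
        self.w_gate = nn.Linear(d_model, num_expert)
        self.top_k = top_k
        self.running_balance_loss = 0.0  # statistic kept across steps

    def forward(self, x):
        logits = self.w_gate(x)                    # (tokens, num_expert)
        scores = torch.softmax(logits, dim=-1)
        topk_val, topk_idx = torch.topk(scores, self.top_k, dim=-1)

        # Auxiliary load-balancing term computed each step.
        balance_loss = scores.mean(dim=0).pow(2).sum()

        # BUG: keeps the whole autograd graph of every step alive, so
        # memory_allocated() grows a little after each backward():
        # self.running_balance_loss = self.running_balance_loss + balance_loss

        # FIX: detach (or .item()) anything stored beyond the current step.
        self.running_balance_loss = self.running_balance_loss + balance_loss.detach()

        return topk_idx, topk_val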