fastmoe

CUDA memory increases after each loss.backward()

Open • sreetamasarkar opened this issue 11 months ago • 6 comments

I am trying to use the FMoE layer in a ViT-Base model for a simple classification task. However, CUDA memory grows gradually and eventually causes an out-of-memory error. Digging deeper, I observe a small increase in memory after every loss.backward() call. Here is what the memory growth looks like:

    Training (50/10000 Steps) Data time=0.03(0.04) Batch time=1.00(0.97) Memory=8588.2(8440.4)
    Training (51/10000 Steps) Data time=0.04(0.04) Batch time=0.78(0.96) Memory=8590.5(8443.3)
    Training (52/10000 Steps) Data time=0.03(0.04) Batch time=1.13(0.97) Memory=8602.8(8446.3)
    Training (53/10000 Steps) Data time=0.04(0.04) Batch time=0.67(0.96) Memory=8602.8(8449.2)
    Training (54/10000 Steps) Data time=0.08(0.04) Batch time=1.13(0.96) Memory=8602.8(8452.0)
    Training (55/10000 Steps) Data time=0.02(0.04) Batch time=0.85(0.96) Memory=8602.8(8454.7)
    Training (56/10000 Steps) Data time=0.06(0.04) Batch time=0.96(0.96) Memory=8602.8(8457.3)
    Training (57/10000 Steps) Data time=0.03(0.04) Batch time=1.02(0.96) Memory=8602.8(8459.9)
    Training (58/10000 Steps) Data time=0.06(0.04) Batch time=0.81(0.96) Memory=8623.7(8462.5)
    Training (59/10000 Steps) Data time=0.04(0.04) Batch time=1.08(0.96) Memory=8623.7(8465.2)
    Training (60/10000 Steps) Data time=0.04(0.04) Batch time=0.72(0.96) Memory=8623.7(8467.8)
    Training (61/10000 Steps) Data time=0.03(0.04) Batch time=1.11(0.96) Memory=8623.7(8470.4)
    Training (62/10000 Steps) Data time=0.02(0.04) Batch time=0.77(0.96) Memory=8623.7(8472.8)
    Training (63/10000 Steps) Data time=0.04(0.04) Batch time=1.10(0.96) Memory=8655.3(8475.4)
    Training (64/10000 Steps) Data time=0.02(0.04) Batch time=0.88(0.96) Memory=8655.3(8478.2)
    Training (65/10000 Steps) Data time=0.04(0.04) Batch time=0.92(0.96) Memory=8667.8(8481.0)
    Training (66/10000 Steps) Data time=0.03(0.04) Batch time=1.02(0.96) Memory=8667.8(8483.8)
    Training (67/10000 Steps) Data time=0.04(0.04) Batch time=0.72(0.95) Memory=8667.8(8486.6)
    Training (68/10000 Steps) Data time=0.03(0.04) Batch time=1.10(0.96) Memory=8667.8(8489.2)
    Training (69/10000 Steps) Data time=0.02(0.04) Batch time=0.70(0.95) Memory=8667.8(8491.8)
    Training (70/10000 Steps) Data time=0.06(0.04) Batch time=1.09(0.96) Memory=8667.8(8494.3)
    Training (71/10000 Steps) Data time=0.02(0.04) Batch time=0.89(0.95) Memory=8667.8(8496.7)
    Training (72/10000 Steps) Data time=0.05(0.04) Batch time=0.86(0.95) Memory=8725.6(8499.5)
    Training (73/10000 Steps) Data time=0.03(0.04) Batch time=1.01(0.95) Memory=8725.6(8502.5)
    Training (74/10000 Steps) Data time=0.04(0.04) Batch time=0.71(0.95) Memory=8725.6(8505.5)
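The growth in a log like this can be quantified directly. The sketch below is illustrative and not part of the original report: it assumes lines in the `Training (step/total Steps) ... Memory=value(average)` format produced by the `log.info` call in the training loop, and the names `parse_memory_log`, `summarize_growth`, and `GrowthSummary` are hypothetical. It extracts the per-step reading (not the running average in parentheses) and reports the total rise, a least-squares slope in MB per step, and whether the series ever decreases.

```python
import re
from dataclasses import dataclass
from typing import List, Tuple

# Hypothetical pattern: assumes log text shaped like
# "Training (50/10000 Steps) ... Memory=8588.2(8440.4)".
# Group 1 is the step, group 2 the per-step reading; the value in
# parentheses (the meter's running average) is matched but not kept.
LOG_RE = re.compile(
    r"Training \((\d+)/\d+ Steps\).*?Memory=([\d.]+)\([\d.]+\)", re.DOTALL
)


@dataclass
class GrowthSummary:
    total_rise_mb: float
    slope_mb_per_step: float
    never_decreases: bool


def parse_memory_log(text: str) -> List[Tuple[int, float]]:
    """Return (step, memory_mb) pairs in the order they appear."""
    return [(int(step), float(mem)) for step, mem in LOG_RE.findall(text)]


def summarize_growth(series: List[Tuple[int, float]]) -> GrowthSummary:
    """Total rise, least-squares slope (MB/step), and monotonicity."""
    if len(series) < 2:
        raise ValueError("need readings from at least two steps")
    steps = [float(s) for s, _ in series]
    mems = [m for _, m in series]
    mean_x = sum(steps) / len(steps)
    mean_y = sum(mems) / len(mems)
    sxx = sum((x - mean_x) ** 2 for x in steps)
    if sxx == 0:
        raise ValueError("step numbers must not all be equal")
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in zip(steps, mems))
    return GrowthSummary(
        total_rise_mb=mems[-1] - mems[0],
        slope_mb_per_step=sxy / sxx,
        # A running peak may plateau but never drops.
        never_decreases=all(b >= a for a, b in zip(mems, mems[1:])),
    )
```

Fed the 25 readings for steps 50 to 74 above, `never_decreases` is `True` and `total_rise_mb` is about 137.4 (8588.2 to 8725.6). A series that rises in plateaus and never drops is what a running peak such as `torch.cuda.max_memory_allocated()` looks like; a genuine leak would also make `torch.cuda.memory_allocated()` grow from step to step.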

Here is my training loop. If I replace the FMoE Mlp layer with a regular Mlp layer, the training works fine.

    MB = 1024 * 1024  # bytes per MiB

    for step, batch in enumerate(train_loader):
        data_time.update(time.time() - end)
        batch = tuple(t.to(args.device) for t in batch)

        x, y = batch

        pred = model(x, 0)
        loss_ = loss_fct(pred, y)

        loss_.backward()
        model.allreduce_params()  # synchronize across data-parallel workers
        optimizer.step()
        optimizer.zero_grad()

        # max_memory_allocated() reports the peak since the program started
        # (unless the peak stats are reset), so this reading can only stay
        # flat or grow.
        memory_meter.update(torch.cuda.max_memory_allocated() / MB)
        log.info("Training ({}/{} Steps)\tData time={:.2f}({:.2f})\tBatch time={:.2f}({:.2f})\tMemory={:.1f}({:.1f})".format(
            global_step, t_total, data_time.val, data_time.avg, batch_time.val, batch_time.avg, memory_meter.val, memory_meter.avg))

Could you help me figure out what might be causing this? Thanks.

sreetamasarkar • Mar 22 '24

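The maximum memory may get larger because the load across experts becomes more imbalanced during the computation; since the peak is a running maximum, a single unbalanced batch raises it permanently. Can you check whether torch.cuda.memory_allocated() also gets larger here?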
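One way to act on this suggestion is to record both quantities at the end of every step and reset the peak counter at the start of each step, so that a transient spike from one unbalanced batch does not carry over into later readings. A minimal sketch, assuming the three readers are injected so the logic can be exercised without a GPU; the class `StepMemoryTrace` and its `leak_suspected` threshold are hypothetical choices for illustration, not FastMoE or PyTorch APIs.

```python
from typing import Callable, List

MB = 1024 * 1024


class StepMemoryTrace:
    """Hypothetical helper: per-step current vs. per-step peak memory (MB)."""

    def __init__(self, read_current: Callable[[], int],
                 read_peak: Callable[[], int],
                 reset_peak: Callable[[], None]) -> None:
        self.read_current = read_current
        self.read_peak = read_peak
        self.reset_peak = reset_peak
        self.current_mb: List[float] = []
        self.peak_mb: List[float] = []

    def begin_step(self) -> None:
        # Restart peak tracking so peak_mb describes this step only.
        self.reset_peak()

    def end_step(self) -> None:
        # Memory still allocated here carries into the next step; with a
        # fixed model and optimizer it should settle to a constant.
        self.current_mb.append(self.read_current() / MB)
        self.peak_mb.append(self.read_peak() / MB)

    def leak_suspected(self, min_growth_mb: float = 1.0) -> bool:
        # Heuristic (assumption): net growth of the end-of-step allocation
        # over the traced window of at least min_growth_mb.
        if len(self.current_mb) < 2:
            return False
        return self.current_mb[-1] - self.current_mb[0] >= min_growth_mb
```

In a real loop the readers would be `torch.cuda.memory_allocated`, `torch.cuda.max_memory_allocated`, and `torch.cuda.reset_peak_memory_stats`, with `begin_step()` called at the top of each iteration and `end_step()` after `optimizer.zero_grad()`. If `peak_mb` fluctuates while `current_mb` stays flat, the growth is a peak effect from uneven expert load rather than a leak.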

laekov • Mar 26 '24

Yes, the memory values I reported are measured using torch.cuda.memory_allocated().

sreetamasarkar • Mar 27 '24

I am not able to reproduce this memory footprint increase using FMoETransformerMLP. What are your FastMoE and PyTorch versions? Do you use expert parallelism or only data parallelism? A minimal script that reproduces the issue would be highly appreciated.

laekov • Mar 28 '24

I was using a slightly modified layer based on FMoETransformerMLP. I observed that when I use NaiveGate, the memory issue does not occur, so I suspect the memory increase has something to do with the gate implementation.

Thank you very much for your attention!

sreetamasarkar • Apr 03 '24

Are you using a gate from FastMoE or a customized gate?

laekov • Apr 03 '24

I was having the memory issue with a customized gate.

sreetamasarkar • Apr 04 '24