
OOM for ThunderFX and Thunder with DDP for Mistral-7B-v0.1

Status: Open. mpatel31415 opened this issue 1 year ago (4 comments).

🐛 Bug

When training Mistral-7B-v0.1 with ThunderFX (and Thunder) under DDP, we get an out-of-memory (OOM) error. The same configuration runs successfully with torch.compile.

To Reproduce

Steps to reproduce the behavior:

Use 1 node with 8 GPUs and the image "INTERNAL_IMAGE:pjnl_20241112". Training script:

python /opt/pytorch/lightning-thunder/thunder/benchmarks/benchmark_litgpt.py \
    --model_name Mistral-7B-v0.1 \
    --distributed_mode ddp \
    --shard_mode None \
    --compile dynamo_thunder \
    --checkpoint_activations False \
    --low_precision_mode none  \
    --micro_batch_size 1

Expected behavior

If training runs with torch.compile, it should also run with Thunder.

Environment

system.device_product_name: DGXH100
system.gpu_driver_version: 535.129.03
libraries.cuda: 12.6.98.001
libraries.pip.lightning: 2.4.0.dev20240728
libraries.pip.lightning-thunder: 0.2.0.dev0
libraries.pip.lightning-utilities: 0.11.8
libraries.pip.litgpt: 0.4.11
libraries.pip.nvfuser: 0.2.22+gitba4f7d4
libraries.pip.pytorch-lightning: 2.4.0
libraries.pip.torch: 2.6.0a0+gita9b4989
libraries.pip.torchao: 0.6.1
libraries.pip.torchmetrics: 1.5.1
libraries.pip.torchvision: 0.19.0a0+d23a6e1

mpatel31415, Nov 12 '24 11:11

With a reduced n_layers, we observe the following peak allocated memory:

n_layers   torch.compile (GB)   ThunderFX (GB)
2          8.18                 8.18
4          12.88                13.12
16         41.68                42.72
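A quick way to read this table is to compute the ThunderFX overhead over torch.compile for each depth. The sketch below only reuses the peak-memory numbers reported above; the dictionary and the helper name are illustrative, not part of the benchmark script.

```python
# Peak allocated memory (GB) reported above: n_layers -> (torch.compile, ThunderFX)
peak_gb = {
    2: (8.18, 8.18),
    4: (12.88, 13.12),
    16: (41.68, 42.72),
}

def thunderfx_overhead(table):
    """Return {n_layers: (extra GB, extra percent)} of ThunderFX over torch.compile."""
    result = {}
    for n_layers, (compile_gb, thunder_gb) in sorted(table.items()):
        extra = round(thunder_gb - compile_gb, 2)        # absolute gap in GB
        pct = round(100.0 * extra / compile_gb, 1)       # gap relative to torch.compile
        result[n_layers] = (extra, pct)
    return result

for n, (extra, pct) in thunderfx_overhead(peak_gb).items():
    print(f"{n:>2} layers: +{extra:.2f} GB ({pct:.1f}%)")
```

This prints a gap of 0.00 GB for 2 layers, 0.24 GB (1.9%) for 4 layers and 1.04 GB (2.5%) for 16 layers, so the gap grows with the number of layers.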
To keep the memory snapshot small, I analyzed the 4-layer case. The memory snapshots of torch.compile and Thunder are very similar and hard to analyze: [image]

We observe that ThunderFX splits the model into 7 graphs, so I measured the memory usage of each graph:

graph   torch.compile (MB)   Thunder JIT (MB)
0       876.578125           908.578125
1       930.5234375          962.5234375
2       752.0234375          784.0234375
3       882.546875           946.546875
4       1026.546875          1090.546875
5       832.0234375          832.0234375
6       564                  564
7       5864.242188          6088.242188
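Every non-zero gap in this table is a multiple of 32. Assuming the "MB" figures are MiB (2**20 bytes), which the exact 32.0 differences suggest, 32 MiB is precisely the size of one bf16[1, 4096, 4096] activation. The sketch below only does that arithmetic on the reported numbers; the helper name is illustrative.

```python
# Per-graph memory (MB) from the table above: (torch.compile, Thunder JIT)
per_graph_mb = [
    (876.578125, 908.578125),
    (930.5234375, 962.5234375),
    (752.0234375, 784.0234375),
    (882.546875, 946.546875),
    (1026.546875, 1090.546875),
    (832.0234375, 832.0234375),
    (564.0, 564.0),
    (5864.242188, 6088.242188),
]

# Size of one bf16 activation of shape [1, 4096, 4096] in MiB (bf16 = 2 bytes).
BF16_BYTES = 2
activation_mib = 1 * 4096 * 4096 * BF16_BYTES / 2**20  # 32.0

def extra_activations(rows, unit_mib):
    """Express each Thunder-minus-torch.compile gap as a count of activation-sized buffers."""
    return [round((thunder - compiled) / unit_mib, 3) for compiled, thunder in rows]

print(activation_mib)
print(extra_activations(per_graph_mb, activation_mib))
```

Read this way, the per-graph gap corresponds to between zero and seven extra buffers of activation size, if the extra memory is made of such buffers.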
I further analyzed the smallest graph, graph 2. Here are some results:

Running the same script (a piece of graph 2, shown below) with both compilers, Thunder (with nv_enable_linear=True) passes all the operators to nvFuser, but its memory usage is slightly higher than torch.compile's: 224 MB vs 192 MB:

[image]

torch.compile: According to the line numbers listed in the snapshot, the three memory blocks are allocated at lines 138, 140 and 144 of the following function: [image]

Thunder: According to the line numbers listed in the snapshot, all four memory blocks are allocated at line 12 of the trace, which is the nvFusion0 call:

# Constructed by Unwrap the actual return value
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(y_5, l_self_modules_transformer_modules_h_modules_1_modules_attn_modules_proj_parameters_weight_, x_5, l_self_modules_transformer_modules_h_modules_1_modules_norm_2_parameters_weight_):
  # y_5: "cuda:0 bf16[1, 4096, 4096]"
  # l_self_modules_transformer_modules_h_modules_1_modules_attn_modules_proj_parameters_weight_: "cuda:0 bf16[4096, 4096]"
  # x_5: "cuda:0 bf16[1, 4096, 4096]"
  # l_self_modules_transformer_modules_h_modules_1_modules_norm_2_parameters_weight_: "cuda:0 bf16[4096]"
  [add, to] = nvFusion0(y_5, l_self_modules_transformer_modules_h_modules_1_modules_attn_modules_proj_parameters_weight_, x_5, l_self_modules_transformer_modules_h_modules_1_modules_norm_2_parameters_weight_)
    # linear = prims.linear(y_5, l_self_modules_transformer_modules_h_modules_1_modules_attn_modules_proj_parameters_weight_, None)  # linear: "cuda:0 bf16[1, 4096, 4096]"
    # t1 = prims.convert_element_type(linear, dtypes.float32)  # t1: "cuda:0 f32[1, 4096, 4096]"
    # t2 = prims.convert_element_type(x_5, dtypes.float32)  # t2: "cuda:0 f32[1, 4096, 4096]"
    # t3 = prims.add(t1, t2)  # t3: "cuda:0 f32[1, 4096, 4096]"
    # add = prims.convert_element_type(t3, dtypes.bfloat16)  # add: "cuda:0 bf16[1, 4096, 4096]"
    # mul = prims.mul(t3, t3)  # mul: "cuda:0 f32[1, 4096, 4096]"
    # t7 = prims.sum(mul, (2,))  # t7: "cuda:0 f32[1, 4096]"
    # t8 = prims.broadcast_in_dim(t7, [1, 4096, 1], [0, 1])  # t8: "cuda:0 f32[1, 4096, 1]"
    # mean = prims.div(t8, 4096.0)  # mean: "cuda:0 f32[1, 4096, 1]"
    # add_1 = prims.add(mean, 1e-05)  # add_1: "cuda:0 f32[1, 4096, 1]"
    # rsqrt = prims.rsqrt(add_1)  # rsqrt: "cuda:0 f32[1, 4096, 1]"
    # t12 = prims.broadcast_in_dim(rsqrt, (1, 4096, 4096), (0, 1, 2))  # t12: "cuda:0 f32[1, 4096, 4096]"
    # mul_1 = prims.mul(t3, t12)  # mul_1: "cuda:0 f32[1, 4096, 4096]"
    # float_2 = prims.convert_element_type(l_self_modules_transformer_modules_h_modules_1_modules_norm_2_parameters_weight_, dtypes.float32)  # float_2: "cuda:0 f32[4096]"
    # t15 = prims.broadcast_in_dim(float_2, (1, 4096, 4096), (2,))  # t15: "cuda:0 f32[1, 4096, 4096]"
    # mul_2 = prims.mul(mul_1, t15)  # mul_2: "cuda:0 f32[1, 4096, 4096]"
    # to = prims.convert_element_type(mul_2, dtypes.bfloat16)  # to: "cuda:0 bf16[1, 4096, 4096]"
  return (to, add)
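The 32 MB gap between the two runs of this piece (224 MB vs 192 MB) is the size of exactly one extra block. The sketch below is one consistent reading of the numbers, assuming each allocated block is a bf16[1, 4096, 4096] tensor and that the reported totals include the three large inputs (y_5, x_5 and the [4096, 4096] projection weight); the 8 KiB norm weight is ignored and `tensor_mib` is a hypothetical helper.

```python
def tensor_mib(shape, bytes_per_element):
    """Size of a dense tensor in MiB."""
    numel = 1
    for dim in shape:
        numel *= dim
    return numel * bytes_per_element / 2**20

BF16 = 2  # bytes per bf16 element

# Large inputs of the graph-2 piece (the 8 KiB norm weight is ignored).
inputs_mib = (
    tensor_mib((1, 4096, 4096), BF16)    # y_5
    + tensor_mib((4096, 4096), BF16)     # attn proj weight
    + tensor_mib((1, 4096, 4096), BF16)  # x_5
)
block_mib = tensor_mib((1, 4096, 4096), BF16)  # assumed size of each allocated block

compile_total = inputs_mib + 3 * block_mib  # 3 blocks (lines 138, 140, 144)
thunder_total = inputs_mib + 4 * block_mib  # 4 blocks, all from nvFusion0

print(inputs_mib, block_mib, compile_total, thunder_total)
```

Under these assumptions the totals come out to 192 and 224, matching the measured values, so the difference is one additional bf16[1, 4096, 4096] buffer allocated by nvFusion0.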

The script for this piece of graph 2:

import torch
import thunder

def test_graph2_thunder_0():
    class DynamoModule(torch.nn.Module):
      def forward(self, y_5, l_self_modules_transformer_modules_h_modules_1_modules_attn_modules_proj_parameters_weight_ : torch.nn.parameter.Parameter, x_5, l_self_modules_transformer_modules_h_modules_1_modules_norm_2_parameters_weight_ : torch.nn.parameter.Parameter, ):
          linear = torch._C._nn.linear(y_5, l_self_modules_transformer_modules_h_modules_1_modules_attn_modules_proj_parameters_weight_, None);  y_5 = l_self_modules_transformer_modules_h_modules_1_modules_attn_modules_proj_parameters_weight_ = None
          add = linear + x_5;  linear = x_5 = None
          float_1 = add.float()
          mul = float_1 * float_1
          mean = torch.mean(mul, dim = -1, keepdim = True);  mul = None
          add_1 = mean + 1e-05;  mean = None
          rsqrt = torch.rsqrt(add_1);  add_1 = None
          mul_1 = float_1 * rsqrt;  float_1 = rsqrt = None
          float_2 = l_self_modules_transformer_modules_h_modules_1_modules_norm_2_parameters_weight_.float();  l_self_modules_transformer_modules_h_modules_1_modules_norm_2_parameters_weight_ = None
          mul_2 = mul_1 * float_2;  mul_1 = float_2 = None
          to = mul_2.to(dtype = torch.bfloat16);  mul_2 = None
          # Same outputs as the trace: the normalized tensor and the residual sum
          return (to, add)

    inputs = [
        torch.testing.make_tensor((1, 4096, 4096), dtype=torch.bfloat16,  device='cuda:0', requires_grad=False, low=None, high=None,).as_strided((1, 4096, 4096), (16777216, 4096, 1)),
        torch.testing.make_tensor((4096, 4096), dtype=torch.bfloat16,  device='cuda:0', requires_grad=False, low=-0.015625, high=0.015625,).as_strided((4096, 4096), (4096, 1)),
        torch.testing.make_tensor((1, 4096, 4096), dtype=torch.bfloat16,  device='cuda:0', requires_grad=False, low=None, high=None,).as_strided((1, 4096, 4096), (16777216, 4096, 1)),
        torch.full((4096,), 1.0, dtype=torch.bfloat16, device='cuda:0', requires_grad=False, layout=torch.strided).as_strided((4096,), (1,)),
    ]

    mod = DynamoModule()
    compiled = thunder.jit(mod, nv_enable_linear=True)
    # compiled = torch.compile(mod)
    torch.cuda.memory._record_memory_history()
    compiled(*inputs)
    torch.cuda.memory._dump_snapshot("piece_snapshot_thunder_nvlinear_rgfalse.pickle")
    print(thunder.last_traces(compiled)[-1])

test_graph2_thunder_0()
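For reference, this piece is a bias-free linear projection, a residual add, and an RMSNorm computed in float32 (norm_2 in the module names). The plain-Python sketch below mirrors that math on tiny shapes; it ignores the bfloat16 casts and the function names are illustrative, so it documents the semantics only, not the memory behavior.

```python
import math

def linear(x, weight):
    """y = x @ weight.T for a list of row vectors (no bias)."""
    return [[sum(xi * wi for xi, wi in zip(row, w_row)) for w_row in weight] for row in x]

def graph2_piece(y, proj_weight, x, norm_weight, eps=1e-5):
    """Return (to, add) as in the traced computation, without dtype casts."""
    # Residual add: linear(y) + x
    add = [[a + b for a, b in zip(r1, r2)] for r1, r2 in zip(linear(y, proj_weight), x)]
    to = []
    for row in add:
        mean_sq = sum(v * v for v in row) / len(row)  # mean(x * x, dim=-1)
        inv_rms = 1.0 / math.sqrt(mean_sq + eps)      # rsqrt(mean + eps)
        to.append([v * inv_rms * w for v, w in zip(row, norm_weight)])
    return to, add

identity = [[1.0 if i == j else 0.0 for j in range(4)] for i in range(4)]
to, add = graph2_piece([[1.0, 2.0, 3.0, 4.0]], identity, [[0.0] * 4], [1.0] * 4)
print(add)
```

With an identity projection and a zero residual, `add` is the input row and each row of `to` has a root-mean-square of almost exactly 1.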

kiya00, Dec 12 '24 16:12

Very cool analysis, @kiya00! I look forward to developing tools that automate more of that process.

FYI @kevinstephano, what should we do next?

mruberry, Dec 12 '24 16:12

We now see new Thunder OOM errors for Mistral-7B-v0.2, longchat-13b-16k and vicuna-7b-v1.5-16k. These three models (shown in red) fail with the configurations shown in the image below.

[image]

Happy to share the exact container on Slack.

wprazuch, Feb 04 '25 10:02

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

stale[bot], Apr 16 '25 05:04