OOM for ThunderFX and Thunder with DDP for Mistral-7B-v0.1
🐛 Bug
When running Mistral-7B-v0.1 we get an OOM error. The same configuration passes with torch.compile.
To Reproduce
Steps to reproduce the behavior:
Please use 1 node with 8 GPUs and the image "INTERNAL_IMAGE:pjnl_20241112". Training script:
python /opt/pytorch/lightning-thunder/thunder/benchmarks/benchmark_litgpt.py \
--model_name Mistral-7B-v0.1 \
--distributed_mode ddp \
--shard_mode None \
--compile dynamo_thunder \
--checkpoint_activations False \
--low_precision_mode none \
--micro_batch_size 1
Expected behavior
If we can run training with torch.compile, we should be able to run it with Thunder as well.
Environment
system.device_product_name: DGXH100
system.gpu_driver_version: 535.129.03
libraries.cuda: 12.6.98.001
libraries.pip.lightning: 2.4.0.dev20240728
libraries.pip.lightning-thunder: 0.2.0.dev0
libraries.pip.lightning-utilities: 0.11.8
libraries.pip.litgpt: 0.4.11
libraries.pip.nvfuser: 0.2.22+gitba4f7d4
libraries.pip.pytorch-lightning: 2.4.0
libraries.pip.torch: 2.6.0a0+gita9b4989
libraries.pip.torchao: 0.6.1
libraries.pip.torchmetrics: 1.5.1
libraries.pip.torchvision: 0.19.0a0+d23a6e1
By reducing n_layers we observe the following peak allocated memory (GB):
| n_layers | torch.compile | ThunderFX |
|---|---|---|
| 2 | 8.18 | 8.18 |
| 4 | 12.88 | 13.12 |
| 16 | 41.68 | 42.72 |
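For reference, here is a minimal sketch of how a peak-memory figure like the ones above can be collected; the exact placement of the reset and readout calls is my assumption, and the benchmark script may report its numbers differently.

import torch

# Hypothetical sketch: measure peak allocated GPU memory (GB) around a training step.
torch.cuda.reset_peak_memory_stats()
# ... run one or more training iterations with the compiled model here ...
peak_gb = torch.cuda.max_memory_allocated() / (1024 ** 3)
print(f"peak allocated memory: {peak_gb:.2f} GB")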
We observe that in ThunderFX the model is segmented into several graphs (graphs 0–7 below), so I measured the memory usage of each graph:
| graph | torch.compile (MB) | thunder.jit (MB) |
|---|---|---|
| 0 | 876.578125 | 908.578125 |
| 1 | 930.5234375 | 962.5234375 |
| 2 | 752.0234375 | 784.0234375 |
| 3 | 882.546875 | 946.546875 |
| 4 | 1026.546875 | 1090.546875 |
| 5 | 832.0234375 | 832.0234375 |
| 6 | 564 | 564 |
| 7 | 5864.242188 | 6088.242188 |
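For context, a minimal sketch of how such per-graph numbers can be measured; measure_graph_memory_mb and its before/peak bookkeeping are my assumptions, not necessarily the exact methodology used here (compiled_fn stands for the compiled callable of one FX graph).

import torch

def measure_graph_memory_mb(compiled_fn, *inputs):
    # Hypothetical helper: peak allocated memory (MB) attributable to running one
    # compiled graph, relative to the allocation level right before the call.
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    before = torch.cuda.memory_allocated()
    compiled_fn(*inputs)
    torch.cuda.synchronize()
    return (torch.cuda.max_memory_allocated() - before) / (1024 ** 2)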
With the same script (a piece of graph 2), Thunder (with nv_enable_linear=True) fuses all of the operators into a single nvFusion region, but the memory usage is a bit higher than torch.compile: 224 MB vs. 192 MB.
torch.compile:
According to the line numbers listed in the snapshot, the 3 memory blocks are allocated at lines 138, 140, and 144 of the following function.
Thunder:
According to the line numbers listed in the snapshot, the 4 memory blocks are allocated at line 12 of the trace, which is the nvFusion0 call:
# Constructed by Unwrap the actual return value
import torch
from thunder.executors.torchex import no_autocast
@torch.no_grad()
@no_autocast
def computation(y_5, l_self_modules_transformer_modules_h_modules_1_modules_attn_modules_proj_parameters_weight_, x_5, l_self_modules_transformer_modules_h_modules_1_modules_norm_2_parameters_weight_):
  # y_5: "cuda:0 bf16[1, 4096, 4096]"
  # l_self_modules_transformer_modules_h_modules_1_modules_attn_modules_proj_parameters_weight_: "cuda:0 bf16[4096, 4096]"
  # x_5: "cuda:0 bf16[1, 4096, 4096]"
  # l_self_modules_transformer_modules_h_modules_1_modules_norm_2_parameters_weight_: "cuda:0 bf16[4096]"
  [add, to] = nvFusion0(y_5, l_self_modules_transformer_modules_h_modules_1_modules_attn_modules_proj_parameters_weight_, x_5, l_self_modules_transformer_modules_h_modules_1_modules_norm_2_parameters_weight_)
    # linear = prims.linear(y_5, l_self_modules_transformer_modules_h_modules_1_modules_attn_modules_proj_parameters_weight_, None)  # linear: "cuda:0 bf16[1, 4096, 4096]"
    # t1 = prims.convert_element_type(linear, dtypes.float32)  # t1: "cuda:0 f32[1, 4096, 4096]"
    # t2 = prims.convert_element_type(x_5, dtypes.float32)  # t2: "cuda:0 f32[1, 4096, 4096]"
    # t3 = prims.add(t1, t2)  # t3: "cuda:0 f32[1, 4096, 4096]"
    # add = prims.convert_element_type(t3, dtypes.bfloat16)  # add: "cuda:0 bf16[1, 4096, 4096]"
    # mul = prims.mul(t3, t3)  # mul: "cuda:0 f32[1, 4096, 4096]"
    # t7 = prims.sum(mul, (2,))  # t7: "cuda:0 f32[1, 4096]"
    # t8 = prims.broadcast_in_dim(t7, [1, 4096, 1], [0, 1])  # t8: "cuda:0 f32[1, 4096, 1]"
    # mean = prims.div(t8, 4096.0)  # mean: "cuda:0 f32[1, 4096, 1]"
    # add_1 = prims.add(mean, 1e-05)  # add_1: "cuda:0 f32[1, 4096, 1]"
    # rsqrt = prims.rsqrt(add_1)  # rsqrt: "cuda:0 f32[1, 4096, 1]"
    # t12 = prims.broadcast_in_dim(rsqrt, (1, 4096, 4096), (0, 1, 2))  # t12: "cuda:0 f32[1, 4096, 4096]"
    # mul_1 = prims.mul(t3, t12)  # mul_1: "cuda:0 f32[1, 4096, 4096]"
    # float_2 = prims.convert_element_type(l_self_modules_transformer_modules_h_modules_1_modules_norm_2_parameters_weight_, dtypes.float32)  # float_2: "cuda:0 f32[4096]"
    # t15 = prims.broadcast_in_dim(float_2, (1, 4096, 4096), (2,))  # t15: "cuda:0 f32[1, 4096, 4096]"
    # mul_2 = prims.mul(mul_1, t15)  # mul_2: "cuda:0 f32[1, 4096, 4096]"
    # to = prims.convert_element_type(mul_2, dtypes.bfloat16)  # to: "cuda:0 bf16[1, 4096, 4096]"
  return (to, add)
The reproduction script for this graph:
import torch
import thunder

def test_graph2_thunder_0():
    class DynamoModule(torch.nn.Module):
        def forward(self, y_5, l_self_modules_transformer_modules_h_modules_1_modules_attn_modules_proj_parameters_weight_: torch.nn.parameter.Parameter, x_5, l_self_modules_transformer_modules_h_modules_1_modules_norm_2_parameters_weight_: torch.nn.parameter.Parameter):
            linear = torch._C._nn.linear(y_5, l_self_modules_transformer_modules_h_modules_1_modules_attn_modules_proj_parameters_weight_, None); y_5 = l_self_modules_transformer_modules_h_modules_1_modules_attn_modules_proj_parameters_weight_ = None
            add = linear + x_5; linear = x_5 = None
            float_1 = add.float()
            mul = float_1 * float_1
            mean = torch.mean(mul, dim=-1, keepdim=True); mul = None
            add_1 = mean + 1e-05; mean = None
            rsqrt = torch.rsqrt(add_1); add_1 = None
            mul_1 = float_1 * rsqrt; float_1 = rsqrt = None
            float_2 = l_self_modules_transformer_modules_h_modules_1_modules_norm_2_parameters_weight_.float(); l_self_modules_transformer_modules_h_modules_1_modules_norm_2_parameters_weight_ = None
            mul_2 = mul_1 * float_2; mul_1 = float_2 = None
            to = mul_2.to(dtype=torch.bfloat16); mul_2 = None
            return (to, add)

    inputs = [
        torch.testing.make_tensor((1, 4096, 4096), dtype=torch.bfloat16, device='cuda:0', requires_grad=False, low=None, high=None).as_strided((1, 4096, 4096), (16777216, 4096, 1)),
        torch.testing.make_tensor((4096, 4096), dtype=torch.bfloat16, device='cuda:0', requires_grad=False, low=-0.015625, high=0.015625).as_strided((4096, 4096), (4096, 1)),
        torch.testing.make_tensor((1, 4096, 4096), dtype=torch.bfloat16, device='cuda:0', requires_grad=False, low=None, high=None).as_strided((1, 4096, 4096), (16777216, 4096, 1)),
        torch.full((4096,), 1.0, dtype=torch.bfloat16, device='cuda:0', requires_grad=False, layout=torch.strided).as_strided((4096,), (1,)),
    ]
    mod = DynamoModule()
    compiled = thunder.jit(mod, nv_enable_linear=True)
    # compiled = torch.compile(mod)
    torch.cuda.memory._record_memory_history()
    compiled(*inputs)
    torch.cuda.memory._dump_snapshot("piece_snapshot_thunder_nvlinear_rgfalse.pickle")
    print(thunder.last_traces(compiled)[-1])

test_graph2_thunder_0()
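For the torch.compile comparison, only the compile call and the snapshot filename need to change; the filename below is made up.

    compiled = torch.compile(mod)
    torch.cuda.memory._record_memory_history()
    compiled(*inputs)
    torch.cuda.memory._dump_snapshot("piece_snapshot_torchcompile.pickle")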
Very cool analysis, @kiya00! I look forward to developing tools that automate more of that process.
fyi @kevinstephano -- what should we do next?
We have new OOM errors with Thunder for Mistral-7B-v0.2, longchat-13b-16k, and vicuna-7b-v1.5-16k. The three models (shown in red) fail for the configurations in the image below.
Happy to share the exact container on Slack.