
[BUG] DeepCompile: MemoryProfiling error /pytorch/build/aten/src/ATen/RegisterCUDA.cpp:7280: SymIntArrayRef expected to contain only concrete integers

Open unavailableun opened this issue 7 months ago • 4 comments

Describe the bug: @tohtana Thanks for your great work on DeepCompile. I am trying to enable DeepCompile for a 2-node training job, but I hit the following exception:

MemoryProfiling error /pytorch/build/aten/src/ATen/RegisterCUDA.cpp:7280: SymIntArrayRef expected to contain only concrete integers

Have you run into a similar issue before? I'd appreciate any clues!

Launching compile passes: global_steps=0 passes=[<function add_z1_reduce at 0x7fe9857581f0>]
MemoryProfiling error /pytorch/build/aten/src/ATen/RegisterCUDA.cpp:7280: SymIntArrayRef expected to contain only concrete integers

While executing %iota : [num_users=3] = call_function[target=torch.ops.prims.iota.default](args = (%primals_2,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:0, requires_grad: False})
Original traceback:
  File "/scratch/azureml/cr/j/89c766bec23f4cbb85212c424847a0fd/exe/wd/bgm/modeling/load_train.py", line 361, in causal_lm_forward
    outputs = model.__original_forward__(input_ids, attention_mask=attention_mask)
  File "/opt/conda/envs/ptca/lib/python3.10/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 1165, in forward
    outputs = self.model(
  File "/opt/conda/envs/ptca/lib/python3.10/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 858, in forward
    cache_position = torch.arange(

[rank0]: Traceback (most recent call last):
[rank0]:   File "/scratch/azureml/cr/j/89c766bec23f4cbb85212c424847a0fd/exe/wd/train.py", line 350, in <module>
[rank0]:     run_app()
[rank0]:   File "/scratch/azureml/cr/j/89c766bec23f4cbb85212c424847a0fd/exe/wd/train.py", line 334, in run_app
[rank0]:     train(
[rank0]:   File "/scratch/azureml/cr/j/89c766bec23f4cbb85212c424847a0fd/exe/wd/train.py", line 164, in train
[rank0]:     outputs, _ = _forward(model, batch, conf)
[rank0]:   File "/scratch/azureml/cr/j/89c766bec23f4cbb85212c424847a0fd/exe/wd/train.py", line 112, in _forward
[rank0]:     outputs = model(**inputs, use_cache=False)
[rank0]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 20, in wrapped_fn
[rank0]:     ret_val = func(*args, **kwargs)
[rank0]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 2054, in forward
[rank0]:     loss = self.module(*inputs, **kwargs)
[rank0]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1737, in _wrapped_call_impl
[rank0]:     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
[rank0]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 574, in _fn
[rank0]:     return fn(*args, **kwargs)
[rank0]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 1380, in __call__
[rank0]:     return self._torchdynamo_orig_callable(
[rank0]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 1164, in __call__
[rank0]:     result = self._inner_convert(
[rank0]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 547, in __call__
[rank0]:     return _compile(
[rank0]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 986, in _compile
[rank0]:     guarded_code = compile_inner(code, one_graph, hooks, transform)
[rank0]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 715, in compile_inner
[rank0]:     return _compile_inner(code, one_graph, hooks, transform)
[rank0]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/_utils_internal.py", line 95, in wrapper_function
[rank0]:     return function(*args, **kwargs)
[rank0]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 750, in _compile_inner
[rank0]:     out_code = transform_code_object(code, transform)
[rank0]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1361, in transform_code_object
[rank0]:     transformations(instructions, code_options)
[rank0]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 231, in _fn
[rank0]:     return fn(*args, **kwargs)
[rank0]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 662, in transform
[rank0]:     tracer.run()
[rank0]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2868, in run
[rank0]:     super().run()
[rank0]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
[rank0]:     while self.step():
[rank0]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
[rank0]:     self.dispatch_table[inst.opcode](self, inst)
[rank0]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 657, in wrapper
[rank0]:     return handle_graph_break(self, inst, speculation.reason)
[rank0]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 698, in handle_graph_break
[rank0]:     self.output.compile_subgraph(self, reason=reason)
[rank0]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1136, in compile_subgraph
[rank0]:     self.compile_and_call_fx_graph(
[rank0]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1382, in compile_and_call_fx_graph
[rank0]:     compiled_fn = self.call_user_compiler(gm)
[rank0]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1432, in call_user_compiler
[rank0]:     return self._call_user_compiler(gm)
[rank0]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1483, in _call_user_compiler
[rank0]:     raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
[rank0]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1462, in _call_user_compiler
[rank0]:     compiled_fn = compiler_fn(gm, self.example_inputs())
[rank0]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/_dynamo/repro/after_dynamo.py", line 130, in __call__
[rank0]:     compiled_gm = compiler_fn(gm, example_inputs)
[rank0]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/__init__.py", line 2385, in __call__
[rank0]:     return self.compiler_fn(model_, inputs_, **self.kwargs)
[rank0]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/deepspeed/compile/backend.py", line 275, in backend_fn
[rank0]:     return torch._inductor.compile(gm, real_inputs)
[rank0]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/_inductor/__init__.py", line 48, in compile
[rank0]:     return compile_fx(gm, example_inputs, config_patches=options)
[rank0]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1863, in compile_fx
[rank0]:     return aot_autograd(
[rank0]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/_dynamo/backends/common.py", line 83, in __call__
[rank0]:     cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
[rank0]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1155, in aot_module_simplified
[rank0]:     compiled_fn = dispatch_and_compile()
[rank0]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1131, in dispatch_and_compile
[rank0]:     compiled_fn, _ = create_aot_dispatcher_function(
[rank0]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 580, in create_aot_dispatcher_function
[rank0]:     return _create_aot_dispatcher_function(
[rank0]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 830, in _create_aot_dispatcher_function
[rank0]:     compiled_fn, fw_metadata = compiler_fn(
[rank0]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 678, in aot_dispatch_autograd
[rank0]:     compiled_fw_func = aot_config.fw_compiler(fw_module, adjusted_flat_args)
[rank0]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/deepspeed/compile/inductor.py", line 27, in wrapped_compiler
[rank0]:     mod_graph = dc_compiler(gm, fake_inputs)
[rank0]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/deepspeed/compile/backend.py", line 178, in make_fw_graph
[rank0]:     run_opt_passes(
[rank0]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/deepspeed/compile/backend.py", line 129, in run_opt_passes
[rank0]:     mem_prof.run(*create_inputs_fn())
[rank0]:   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/deepspeed/compile/profilers/graph_profile.py", line 258, in run
[rank0]:     return return_val
[rank0]: torch._dynamo.exc.BackendCompilerFailed: backend='backend_fn' raised:
[rank0]: UnboundLocalError: local variable 'return_val' referenced before assignment
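
The final UnboundLocalError appears to hide the real failure: the profiler printed the MemoryProfiling error and then reached return return_val without return_val ever being assigned. The source of graph_profile.py is not quoted in this thread, so the snippet below is only a hypothetical reduction of that pattern (run_profiled_buggy, run_profiled_fixed and failing_op are made-up names), contrasted with a variant that re-raises so the original error reaches the caller.

```python
def run_profiled_buggy(fn, *args):
    """Hypothetical reduction of the failing pattern (not DeepSpeed's actual code)."""
    try:
        return_val = fn(*args)
    except Exception as e:
        # The error is only logged, so return_val is never bound on failure...
        print(f"MemoryProfiling error {e}")
    # ...and this line then raises UnboundLocalError instead of the real error.
    return return_val


def run_profiled_fixed(fn, *args):
    """Same pattern, but the original exception propagates to the caller."""
    try:
        return fn(*args)
    except Exception as e:
        print(f"MemoryProfiling error {e}")
        raise


def failing_op():
    # Hypothetical stand-in for the profiled graph hitting the symbolic-size error.
    raise RuntimeError("SymIntArrayRef expected to contain only concrete integers")


if __name__ == "__main__":
    for runner in (run_profiled_buggy, run_profiled_fixed):
        try:
            runner(failing_op)
        except Exception as exc:
            print(f"{runner.__name__} raised {type(exc).__name__}")
```

With the re-raising variant, the caller would see the RuntimeError about concrete integers rather than a misleading UnboundLocalError.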

System info:

  • GPU count and types: two machines with 8x A100s each

unavailableun, May 26 '25 08:05

Hi @unavailableun

Can you share any code to repro?

therealnaveenkamal, May 27 '25 14:05

Sure, I need some time to prepare a clean repro script.

unavailableun, May 28 '25 08:05

@therealnaveenkamal Here is the repro script: https://gist.github.com/unavailableun/5efab330f82ec6af051717adee4c3455

Command I run on my local machine (1 node with 2 A6000 GPUs): torchrun --nproc_per_node=2 train.py

Environment:

  • cuda 12.4
  • python 3.10
  • torch 2.6.0+cu124
  • deepspeed 0.16.9
  • transformers 4.47.1

unavailableun, May 29 '25 08:05

The MemoryProfiling error above is thrown from the memory-profiling section of compile/backend.py shown below. It does not seem to have much impact on the run itself, so I commented that section out, but I still get a TypeError from AOTAutograd:

[Screenshot: memory-profiling section of compile/backend.py]

Launching compile passes: global_steps=0 passes=[<function add_z1_reduce at 0x7f41cb7443a0>]
/home/user_name/miniforge3/envs/lightning/lib/python3.10/site-packages/torch/_inductor/compile_fx.py:194: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.
  warnings.warn(
[rank1]:W0529 09:48:51.006000 43431 site-packages/torch/fx/experimental/symbolic_shapes.py:6184] [0/0] Ignored guard s1 - 1 < 2 == False, this could result in accuracy problems
[rank1]:W0529 09:48:57.388000 43431 site-packages/torch/fx/experimental/symbolic_shapes.py:6184] [0/0] Ignored guard 49*(((s0*s1 + 13)//14)) < 8388608 == True, this could result in accuracy problems
[rank1]:W0529 09:48:57.641000 43431 site-packages/torch/fx/experimental/symbolic_shapes.py:6184] [0/0] Ignored guard 7*(((s0*s1 + 13)//14)) < 8388608 == True, this could result in accuracy problems
[rank0]:W0529 09:49:00.081000 43430 site-packages/torch/fx/experimental/symbolic_shapes.py:6184] [0/0] Ignored guard s1 - 1 < 2 == False, this could result in accuracy problems
[rank0]:W0529 09:49:06.436000 43430 site-packages/torch/fx/experimental/symbolic_shapes.py:6184] [0/0] Ignored guard 49*(((s0*s1 + 13)//14)) < 8388608 == True, this could result in accuracy problems
[rank0]:W0529 09:49:06.661000 43430 site-packages/torch/fx/experimental/symbolic_shapes.py:6184] [0/0] Ignored guard 7*(((s0*s1 + 13)//14)) < 8388608 == True, this could result in accuracy problems
/home/user_name/miniforge3/envs/lightning/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py:130: UserWarning: Your compiler for AOTAutograd is returning a function that doesn't take boxed arguments. Please wrap it with functorch.compile.make_boxed_func or handle the boxed arguments yourself. See https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670 for rationale.
  warnings.warn(
[rank1]: Traceback (most recent call last):
[rank1]:   File "/home/user_name/deeprank/projects/bgm/scripts/deep_compile/train.py", line 270, in <module>
[rank1]:     train()
[rank1]:   File "/home/user_name/deeprank/projects/bgm/scripts/deep_compile/train.py", line 246, in train
[rank1]:     model_engine.backward(loss)
[rank1]:   File "/home/user_name/miniforge3/envs/lightning/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 20, in wrapped_fn
[rank1]:     ret_val = func(*args, **kwargs)
[rank1]:   File "/home/user_name/miniforge3/envs/lightning/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 2216, in backward
[rank1]:     self._do_optimizer_backward(loss, retain_graph)
[rank1]:   File "/home/user_name/miniforge3/envs/lightning/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 2162, in _do_optimizer_backward
[rank1]:     self.optimizer.backward(loss, retain_graph=retain_graph)
[rank1]:   File "/home/user_name/miniforge3/envs/lightning/lib/python3.10/site-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 2082, in backward
[rank1]:     self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)
[rank1]:   File "/home/user_name/miniforge3/envs/lightning/lib/python3.10/site-packages/deepspeed/runtime/fp16/loss_scaler.py", line 63, in backward
[rank1]:     scaled_loss.backward(retain_graph=retain_graph)
[rank1]:   File "/home/user_name/miniforge3/envs/lightning/lib/python3.10/site-packages/torch/_tensor.py", line 626, in backward
[rank1]:     torch.autograd.backward(
[rank1]:   File "/home/user_name/miniforge3/envs/lightning/lib/python3.10/site-packages/torch/autograd/__init__.py", line 347, in backward
[rank1]:     _engine_run_backward(
[rank1]:   File "/home/user_name/miniforge3/envs/lightning/lib/python3.10/site-packages/torch/autograd/graph.py", line 823, in _engine_run_backward
[rank1]:     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank1]:   File "/home/user_name/miniforge3/envs/lightning/lib/python3.10/site-packages/torch/autograd/function.py", line 307, in apply
[rank1]:     return user_fn(self, *args)
[rank1]:   File "/home/user_name/miniforge3/envs/lightning/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 1710, in backward
[rank1]:     return impl_fn()
[rank1]:   File "/home/user_name/miniforge3/envs/lightning/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 1700, in impl_fn
[rank1]:     out = CompiledFunction._backward_impl(ctx, all_args)
[rank1]:   File "/home/user_name/miniforge3/envs/lightning/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 2065, in _backward_impl
[rank1]:     out = call_func_at_runtime_with_args(
[rank1]:   File "/home/user_name/miniforge3/envs/lightning/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py", line 135, in call_func_at_runtime_with_args
[rank1]:     out = normalize_as_list(f(*args))
[rank1]: TypeError: DisableContext.__call__() takes 2 positional arguments but 670 were given
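
The UserWarning about boxed arguments and the final TypeError look related. Judging from the traceback (utils.py line 130 warns, line 135 fails), AOTAutograd calls the compiled backward as f(args) when the callable has a _boxed_call attribute and as f(*args) otherwise. An object whose __call__ takes a single argument after self then receives every backward input separately: 669 inputs plus self gives the 670 in the message. The pure-Python model below is a simplification, not PyTorch code: call_with_args only mimics that dispatch, make_boxed_func is a local copy of the idea behind functorch.compile.make_boxed_func, and OneArgCallable is a hypothetical stand-in for DisableContext.

```python
def make_boxed_func(f):
    """Local re-implementation of the boxing idea: take one list, unpack it for f."""
    def g(args):
        return f(*args)
    g._boxed_call = True
    return g


def call_with_args(f, args):
    """Simplified model of AOTAutograd's runtime dispatch (no torch involved)."""
    args = list(args)
    if hasattr(f, "_boxed_call"):
        return f(args)  # boxed: a single list argument
    return f(*args)     # unboxed: every input becomes its own positional argument


class OneArgCallable:
    """Hypothetical stand-in for a callable whose __call__ takes one argument."""
    def __call__(self, fn):
        return fn


def add_all(*xs):
    return sum(xs)


if __name__ == "__main__":
    inputs = list(range(669))  # 669 inputs, matching the count in the traceback
    # A variadic function gives the same result boxed or unboxed.
    print(call_with_args(add_all, inputs) == call_with_args(make_boxed_func(add_all), inputs))
    # A one-argument __call__ invoked unboxed reproduces the arity error.
    try:
        call_with_args(OneArgCallable(), inputs)
    except TypeError as e:
        print(e)  # message ends with: takes 2 positional arguments but 670 were given
```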

unavailableun, May 29 '25 09:05

Hi @unavailableun,

Thank you for reporting and offering a repro! I'm currently working on a SymInt issue in #7243. Since that PR currently fails CI tests, I'll try your repro after I fix it.

tohtana, Jun 02 '25 07:06

Hi @tohtana, have you been able to reproduce the error after your fix?

unavailableun, Jun 11 '25 05:06

Hi @unavailableun,

It took some time, but your repro works after some fixes. Until #7386 is merged, please use the tohtana/dc_improve_z3_coverage branch.

Because we need exactly the same graph on all ranks, sequences also have to be padded to the same length. You can check this modified version of the script. Besides the padding, I added these two options:

  • --deepcompile: Enable DeepCompile
  • --dynamic: Enable dynamic shape
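
The padding code in the modified script is not reproduced in this thread. As a minimal sketch, assuming right-padding with a hypothetical pad_token_id of 0 and one fixed max_length shared by every rank, the idea is that each rank produces inputs of identical shape (pad_to_fixed_length is a made-up helper name):

```python
from typing import List, Tuple


def pad_to_fixed_length(
    batch: List[List[int]], max_length: int, pad_token_id: int = 0
) -> Tuple[List[List[int]], List[List[int]]]:
    """Right-pad (or truncate) every sequence to max_length.

    Returns (input_ids, attention_mask); the mask is 1 for real tokens, 0 for padding.
    """
    input_ids, attention_mask = [], []
    for seq in batch:
        seq = seq[:max_length]  # truncate sequences longer than max_length
        n_pad = max_length - len(seq)
        input_ids.append(seq + [pad_token_id] * n_pad)
        attention_mask.append([1] * len(seq) + [0] * n_pad)
    return input_ids, attention_mask


if __name__ == "__main__":
    ids, mask = pad_to_fixed_length([[5, 6, 7], [8]], max_length=4)
    print(ids)   # [[5, 6, 7, 0], [8, 0, 0, 0]]
    print(mask)  # [[1, 1, 1, 0], [1, 0, 0, 0]]
```

With Hugging Face tokenizers, passing padding="max_length", truncation=True and a fixed max_length has the same effect. Presumably the batch size must also stay the same on every rank for the shapes to match, for example by dropping an incomplete final batch (drop_last=True in a PyTorch DataLoader).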

After several recompilations (they take a long time), these are the iteration times:

  • Baseline (no option): Epoch 3, Step 750, deepcompile False, Average Loss: 11.8906, Avg Iteration Time: 0.0630s
  • DeepCompile (--deepcompile): Epoch 3, Step 750, deepcompile True, Average Loss: 11.9013, Avg Iteration Time: 0.0463s
  • DeepCompile with dynamic shape (--deepcompile --dynamic): Epoch 3, Step 750, deepcompile True, Average Loss: 11.8886, Avg Iteration Time: 0.0373s
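
Relative to the baseline, the reported averages translate into the speedups below; the snippet only divides the baseline time by each DeepCompile time, using the numbers from the three runs above (summarize is a hypothetical helper):

```python
def summarize(baseline: float, timings: dict) -> dict:
    """Return {name: (speedup, percent_time_saved)} relative to the baseline time."""
    return {
        name: (baseline / t, 100.0 * (1.0 - t / baseline))
        for name, t in timings.items()
    }


if __name__ == "__main__":
    # Avg Iteration Time values reported above (seconds, Step 750).
    results = summarize(0.0630, {
        "DeepCompile": 0.0463,
        "DeepCompile + dynamic shape": 0.0373,
    })
    for name, (speedup, saved) in results.items():
        print(f"{name}: {speedup:.2f}x, {saved:.1f}% less time per iteration")
    # DeepCompile: 1.36x, 26.5% less time per iteration
    # DeepCompile + dynamic shape: 1.69x, 40.8% less time per iteration
```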

tohtana, Jun 27 '25 07:06