[BUG] DeepCompile: MemoryProfiling error /pytorch/build/aten/src/ATen/RegisterCUDA.cpp:7280: SymIntArrayRef expected to contain only concrete integers
Describe the bug
@tohtana Thanks for your great work on DeepCompile. I am trying to enable DeepCompile for a 2-node training job, but I hit the exception below:
MemoryProfiling error /pytorch/build/aten/src/ATen/RegisterCUDA.cpp:7280: SymIntArrayRef expected to contain only concrete integers
Have you seen a similar issue before? Any clues would be appreciated!
Launching compile passes: global_steps=0 passes=[<function add_z1_reduce at 0x7fe9857581f0>]
MemoryProfiling error /pytorch/build/aten/src/ATen/RegisterCUDA.cpp:7280: SymIntArrayRef expected to contain only concrete integers
While executing %iota : [num_users=3] = call_function[target=torch.ops.prims.iota.default](args = (%primals_2,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:0, requires_grad: False})
Original traceback:
File "/scratch/azureml/cr/j/89c766bec23f4cbb85212c424847a0fd/exe/wd/bgm/modeling/load_train.py", line 361, in causal_lm_forward
outputs = model.__original_forward__(input_ids, attention_mask=attention_mask)
File "/opt/conda/envs/ptca/lib/python3.10/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 1165, in forward
outputs = self.model(
File "/opt/conda/envs/ptca/lib/python3.10/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 858, in forward
cache_position = torch.arange(
[rank0]: Traceback (most recent call last):
[rank0]: File "/scratch/azureml/cr/j/89c766bec23f4cbb85212c424847a0fd/exe/wd/train.py", line 350, in <module>
[rank0]: run_app()
[rank0]: File "/scratch/azureml/cr/j/89c766bec23f4cbb85212c424847a0fd/exe/wd/train.py", line 334, in run_app
[rank0]: train(
[rank0]: File "/scratch/azureml/cr/j/89c766bec23f4cbb85212c424847a0fd/exe/wd/train.py", line 164, in train
[rank0]: outputs, _ = _forward(model, batch, conf)
[rank0]: File "/scratch/azureml/cr/j/89c766bec23f4cbb85212c424847a0fd/exe/wd/train.py", line 112, in _forward
[rank0]: outputs = model(**inputs, use_cache=False)
[rank0]: File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: File "/opt/conda/envs/ptca/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 20, in wrapped_fn
[rank0]: ret_val = func(*args, **kwargs)
[rank0]: File "/opt/conda/envs/ptca/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 2054, in forward
[rank0]: loss = self.module(*inputs, **kwargs)
[rank0]: File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1737, in _wrapped_call_impl
[rank0]: return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
[rank0]: File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 574, in _fn
[rank0]: return fn(*args, **kwargs)
[rank0]: File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 1380, in __call__
[rank0]: return self._torchdynamo_orig_callable(
[rank0]: File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 1164, in __call__
[rank0]: result = self._inner_convert(
[rank0]: File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 547, in __call__
[rank0]: return _compile(
[rank0]: File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 986, in _compile
[rank0]: guarded_code = compile_inner(code, one_graph, hooks, transform)
[rank0]: File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 715, in compile_inner
[rank0]: return _compile_inner(code, one_graph, hooks, transform)
[rank0]: File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/_utils_internal.py", line 95, in wrapper_function
[rank0]: return function(*args, **kwargs)
[rank0]: File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 750, in _compile_inner
[rank0]: out_code = transform_code_object(code, transform)
[rank0]: File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1361, in transform_code_object
[rank0]: transformations(instructions, code_options)
[rank0]: File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 231, in _fn
[rank0]: return fn(*args, **kwargs)
[rank0]: File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 662, in transform
[rank0]: tracer.run()
[rank0]: File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2868, in run
[rank0]: super().run()
[rank0]: File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
[rank0]: while self.step():
[rank0]: File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
[rank0]: self.dispatch_table[inst.opcode](self, inst)
[rank0]: File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 657, in wrapper
[rank0]: return handle_graph_break(self, inst, speculation.reason)
[rank0]: File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 698, in handle_graph_break
[rank0]: self.output.compile_subgraph(self, reason=reason)
[rank0]: File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1136, in compile_subgraph
[rank0]: self.compile_and_call_fx_graph(
[rank0]: File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1382, in compile_and_call_fx_graph
[rank0]: compiled_fn = self.call_user_compiler(gm)
[rank0]: File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1432, in call_user_compiler
[rank0]: return self._call_user_compiler(gm)
[rank0]: File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1483, in _call_user_compiler
[rank0]: raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
[rank0]: File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1462, in _call_user_compiler
[rank0]: compiled_fn = compiler_fn(gm, self.example_inputs())
[rank0]: File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/_dynamo/repro/after_dynamo.py", line 130, in __call__
[rank0]: compiled_gm = compiler_fn(gm, example_inputs)
[rank0]: File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/__init__.py", line 2385, in __call__
[rank0]: return self.compiler_fn(model_, inputs_, **self.kwargs)
[rank0]: File "/opt/conda/envs/ptca/lib/python3.10/site-packages/deepspeed/compile/backend.py", line 275, in backend_fn
[rank0]: return torch._inductor.compile(gm, real_inputs)
[rank0]: File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/_inductor/__init__.py", line 48, in compile
[rank0]: return compile_fx(gm, example_inputs, config_patches=options)
[rank0]: File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1863, in compile_fx
[rank0]: return aot_autograd(
[rank0]: File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/_dynamo/backends/common.py", line 83, in __call__
[rank0]: cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
[rank0]: File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1155, in aot_module_simplified
[rank0]: compiled_fn = dispatch_and_compile()
[rank0]: File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1131, in dispatch_and_compile
[rank0]: compiled_fn, _ = create_aot_dispatcher_function(
[rank0]: File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 580, in create_aot_dispatcher_function
[rank0]: return _create_aot_dispatcher_function(
[rank0]: File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 830, in _create_aot_dispatcher_function
[rank0]: compiled_fn, fw_metadata = compiler_fn(
[rank0]: File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 678, in aot_dispatch_autograd
[rank0]: compiled_fw_func = aot_config.fw_compiler(fw_module, adjusted_flat_args)
[rank0]: File "/opt/conda/envs/ptca/lib/python3.10/site-packages/deepspeed/compile/inductor.py", line 27, in wrapped_compiler
[rank0]: mod_graph = dc_compiler(gm, fake_inputs)
[rank0]: File "/opt/conda/envs/ptca/lib/python3.10/site-packages/deepspeed/compile/backend.py", line 178, in make_fw_graph
[rank0]: run_opt_passes(
[rank0]: File "/opt/conda/envs/ptca/lib/python3.10/site-packages/deepspeed/compile/backend.py", line 129, in run_opt_passes
[rank0]: mem_prof.run(*create_inputs_fn())
[rank0]: File "/opt/conda/envs/ptca/lib/python3.10/site-packages/deepspeed/compile/profilers/graph_profile.py", line 258, in run
[rank0]: return return_val
[rank0]: torch._dynamo.exc.BackendCompilerFailed: backend='backend_fn' raised:
[rank0]: UnboundLocalError: local variable 'return_val' referenced before assignment
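Looking at the traceback, the UnboundLocalError itself seems to be a secondary symptom: the real SymIntArrayRef failure is apparently caught and only logged inside the profiler's run(), which then returns a variable that was never assigned. A minimal sketch of the pattern I suspect (my guess, not the actual deepspeed/compile/profilers/graph_profile.py code):

```python
def run_with_memory_profiling(run_graph, *args):
    # Hypothetical sketch of the failure pattern suggested by the traceback,
    # not the actual graph_profile.py implementation.
    try:
        return_val = run_graph(*args)          # hits the SymIntArrayRef error
    except Exception as e:
        print(f"MemoryProfiling error {e}")    # error is only logged, not re-raised
    return return_val                          # UnboundLocalError when the try block failed
```

If that is right, the error that actually needs fixing is the SymIntArrayRef one raised for the iota node above.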
System info:
- GPU count and types: two machines with x8 A100s each
Hi @unavailableun
Can you share any code to repro?
Sure, I need some time to prepare a clean repro script.
@therealnaveenkamal Here is the repro script: https://gist.github.com/unavailableun/5efab330f82ec6af051717adee4c3455
Command run on my local machine (1 node with 2 A6000 GPUs): torchrun --nproc_per_node=2 train.py
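For context, the way I enable DeepCompile is roughly like this (a simplified sketch, not copied verbatim from the gist; the "compile": {"deepcompile": true} config key and the engine.compile() call reflect my understanding of the DeepSpeed API, so treat them as assumptions):

```python
import torch
import deepspeed

model = torch.nn.Linear(8, 8)  # placeholder; the real script builds a Qwen2 model

# Sketch of the DeepSpeed config I am assuming; stage 1 matches the
# add_z1_reduce pass shown in the log above.
ds_config = {
    "train_micro_batch_size_per_gpu": 2,
    "zero_optimization": {"stage": 1},
    "compile": {"deepcompile": True},  # assumed key that enables DeepCompile
}

model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model, model_parameters=model.parameters(), config=ds_config
)
model_engine.compile()  # assumed: compiles the module with the DeepCompile backend
```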
Environment:
- cuda 12.4
- python 3.10
- torch 2.6.0+cu124
- deepspeed 0.16.9
- transformers 4.47.1
The MemoryProfiling error above is thrown from the memory profiling part in compile/backend.py (the mem_prof.run call visible in the traceback). It did not seem to have a big impact on the run, so I commented that part out, but I still get an AOTAutograd TypeError:
Launching compile passes: global_steps=0 passes=[<function add_z1_reduce at 0x7f41cb7443a0>]
/home/user_name/miniforge3/envs/lightning/lib/python3.10/site-packages/torch/_inductor/compile_fx.py:194: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.
warnings.warn(
[rank1]:W0529 09:48:51.006000 43431 site-packages/torch/fx/experimental/symbolic_shapes.py:6184] [0/0] Ignored guard s1 - 1 < 2 == False, this could result in accuracy problems
[rank1]:W0529 09:48:57.388000 43431 site-packages/torch/fx/experimental/symbolic_shapes.py:6184] [0/0] Ignored guard 49*(((s0*s1 + 13)//14)) < 8388608 == True, this could result in accuracy problems
[rank1]:W0529 09:48:57.641000 43431 site-packages/torch/fx/experimental/symbolic_shapes.py:6184] [0/0] Ignored guard 7*(((s0*s1 + 13)//14)) < 8388608 == True, this could result in accuracy problems
[rank0]:W0529 09:49:00.081000 43430 site-packages/torch/fx/experimental/symbolic_shapes.py:6184] [0/0] Ignored guard s1 - 1 < 2 == False, this could result in accuracy problems
[rank0]:W0529 09:49:06.436000 43430 site-packages/torch/fx/experimental/symbolic_shapes.py:6184] [0/0] Ignored guard 49*(((s0*s1 + 13)//14)) < 8388608 == True, this could result in accuracy problems
[rank0]:W0529 09:49:06.661000 43430 site-packages/torch/fx/experimental/symbolic_shapes.py:6184] [0/0] Ignored guard 7*(((s0*s1 + 13)//14)) < 8388608 == True, this could result in accuracy problems
/home/user_name/miniforge3/envs/lightning/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py:130: UserWarning: Your compiler for AOTAutograd is returning a function that doesn't take boxed arguments. Please wrap it with functorch.compile.make_boxed_func or handle the boxed arguments yourself. See https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670 for rationale.
warnings.warn(
[rank1]: Traceback (most recent call last):
[rank1]: File "/home/user_name/deeprank/projects/bgm/scripts/deep_compile/train.py", line 270, in <module>
[rank1]: train()
[rank1]: File "/home/user_name/deeprank/projects/bgm/scripts/deep_compile/train.py", line 246, in train
[rank1]: model_engine.backward(loss)
[rank1]: File "/home/user_name/miniforge3/envs/lightning/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 20, in wrapped_fn
[rank1]: ret_val = func(*args, **kwargs)
[rank1]: File "/home/user_name/miniforge3/envs/lightning/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 2216, in backward
[rank1]: self._do_optimizer_backward(loss, retain_graph)
[rank1]: File "/home/user_name/miniforge3/envs/lightning/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 2162, in _do_optimizer_backward
[rank1]: self.optimizer.backward(loss, retain_graph=retain_graph)
[rank1]: File "/home/user_name/miniforge3/envs/lightning/lib/python3.10/site-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 2082, in backward
[rank1]: self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)
[rank1]: File "/home/user_name/miniforge3/envs/lightning/lib/python3.10/site-packages/deepspeed/runtime/fp16/loss_scaler.py", line 63, in backward
[rank1]: scaled_loss.backward(retain_graph=retain_graph)
[rank1]: File "/home/user_name/miniforge3/envs/lightning/lib/python3.10/site-packages/torch/_tensor.py", line 626, in backward
[rank1]: torch.autograd.backward(
[rank1]: File "/home/user_name/miniforge3/envs/lightning/lib/python3.10/site-packages/torch/autograd/__init__.py", line 347, in backward
[rank1]: _engine_run_backward(
[rank1]: File "/home/user_name/miniforge3/envs/lightning/lib/python3.10/site-packages/torch/autograd/graph.py", line 823, in _engine_run_backward
[rank1]: return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
[rank1]: File "/home/user_name/miniforge3/envs/lightning/lib/python3.10/site-packages/torch/autograd/function.py", line 307, in apply
[rank1]: return user_fn(self, *args)
[rank1]: File "/home/user_name/miniforge3/envs/lightning/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 1710, in backward
[rank1]: return impl_fn()
[rank1]: File "/home/user_name/miniforge3/envs/lightning/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 1700, in impl_fn
[rank1]: out = CompiledFunction._backward_impl(ctx, all_args)
[rank1]: File "/home/user_name/miniforge3/envs/lightning/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 2065, in _backward_impl
[rank1]: out = call_func_at_runtime_with_args(
[rank1]: File "/home/user_name/miniforge3/envs/lightning/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py", line 135, in call_func_at_runtime_with_args
[rank1]: out = normalize_as_list(f(*args))
[rank1]: TypeError: DisableContext.__call__() takes 2 positional arguments but 670 were given
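One thing I notice is the UserWarning above about the compiler returning a non-boxed function. I am not sure whether it is related to this TypeError, but for reference this is the wrapping the warning asks for (just a sketch of the pattern, not a verified fix for the DeepCompile backend):

```python
from functorch.compile import make_boxed_func

def bw_compiler(gm, example_inputs):
    # AOTAutograd expects a "boxed" callable that takes its arguments as a single
    # list; make_boxed_func wraps a plain callable accordingly (see the warning above).
    return make_boxed_func(gm.forward)
```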
Hi @unavailableun,
Thank you for reporting this and offering a repro! I'm currently working on a symint issue in #7243. Since that PR is still failing CI tests, let me try your repro after I fix it.
Hi @tohtana, have you been able to reproduce the error after your fix?
Hi @unavailableun,
It took some time, but your repro works after a few fixes. Until #7386 is merged, please use the tohtana/dc_improve_z3_coverage branch.
Since we need exactly the same graph across all ranks, we also need to pad sequences to the same length. You can check this modified version of the script; I added padding as well as these two options (a padding sketch follows the list):
- --deepcompile: enable DeepCompile
- --dynamic: enable dynamic shapes
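For reference, the padding I added looks roughly like this (a sketch; the tokenizer, checkpoint, and max length below are placeholders, not necessarily what the modified script uses):

```python
from transformers import AutoTokenizer

# Placeholders: the modified script may use a different checkpoint and max length.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B")
MAX_LEN = 512

def collate_fn(texts):
    # Pad every batch to the same fixed length so that all ranks trace
    # exactly the same graph shape.
    batch = tokenizer(
        texts,
        padding="max_length",
        truncation=True,
        max_length=MAX_LEN,
        return_tensors="pt",
    )
    return batch["input_ids"], batch["attention_mask"]
```

Without this, ranks can end up with different sequence lengths, which violates the requirement that all ranks share exactly the same graph.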
After several recompilations (they take a long time), these are the iteration times:
- Baseline (no option): Epoch 3, Step 750, deepcompile False, Average Loss: 11.8906, Avg Iteration Time: 0.0630s
- DeepCompile (--deepcompile): Epoch 3, Step 750, deepcompile True, Average Loss: 11.9013, Avg Iteration Time: 0.0463s
- DeepCompile with dynamic shape (--deepcompile --dynamic): Epoch 3, Step 750, deepcompile True, Average Loss: 11.8886, Avg Iteration Time: 0.0373s