🐛 [gpt2][torch.compile] Encountered guard failures with GPT2 compilation
Bug Description
Configuration: llm_examples_main branch, torch 2.4, transformers==4.41.2
Traceback (most recent call last):
File "/home/dperi/Downloads/TensorRT/examples/dynamo/torch_compile_gpt2.py", line 44, in <module>
trt_outputs = model(input_ids)
^^^^^^^^^^^^^^^^
File "/home/dperi/.pyenv/versions/3.11.7/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/dperi/.pyenv/versions/3.11.7/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/dperi/.pyenv/versions/3.11.7/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 433, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/dperi/.pyenv/versions/3.11.7/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 1116, in __call__
return self._torchdynamo_orig_callable(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/dperi/.pyenv/versions/3.11.7/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 948, in __call__
result = self._inner_convert(
^^^^^^^^^^^^^^^^^^^^
File "/home/dperi/.pyenv/versions/3.11.7/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 472, in __call__
return _compile(
^^^^^^^^^
File "/home/dperi/.pyenv/versions/3.11.7/lib/python3.11/site-packages/torch/_utils_internal.py", line 84, in wrapper_function
return StrobelightCompileTimeProfiler.profile_compile_time(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/dperi/.pyenv/versions/3.11.7/lib/python3.11/site-packages/torch/_strobelight/compile_time_profiler.py", line 129, in profile_compile_time
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/dperi/.pyenv/versions/3.11.7/lib/python3.11/contextlib.py", line 81, in inner
return func(*args, **kwds)
^^^^^^^^^^^^^^^^^^^
File "/home/dperi/.pyenv/versions/3.11.7/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 817, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/dperi/.pyenv/versions/3.11.7/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 230, in time_wrapper
r = func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/dperi/.pyenv/versions/3.11.7/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 726, in compile_inner
check_fn = CheckFunctionManager(
^^^^^^^^^^^^^^^^^^^^^
File "/home/dperi/.pyenv/versions/3.11.7/lib/python3.11/site-packages/torch/_dynamo/guards.py", line 2130, in __init__
guard.create(builder)
File "/home/dperi/.pyenv/versions/3.11.7/lib/python3.11/site-packages/torch/_guards.py", line 260, in create
return self.create_fn(builder, self)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/dperi/.pyenv/versions/3.11.7/lib/python3.11/site-packages/torch/_dynamo/guards.py", line 1717, in SHAPE_ENV
guards = output_graph.shape_env.produce_guards(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/dperi/.pyenv/versions/3.11.7/lib/python3.11/site-packages/torch/fx/experimental/symbolic_shapes.py", line 4163, in produce_guards
raise ConstraintViolationError(
torch.fx.experimental.symbolic_shapes.ConstraintViolationError: Constraints violated (L['input_ids'].size()[1])! For more information, run with TORCH_LOGS="+dynamic".
- Not all values of L['input_ids'].size()[1] = L['input_ids'].size()[1] in the specified range L['input_ids'].size()[1] <= 1023 satisfy the generated guard 7 <= L['input_ids'].size()[1] and L['input_ids'].size()[1] <= 1023
To Reproduce
python examples/dynamo/torch_compile_gpt2.py
Expected behavior
Environment
Build information about Torch-TensorRT can be found by turning on debug messages
- Torch-TensorRT Version (e.g. 1.0.0):
- PyTorch Version (e.g. 1.0):
- CPU Architecture:
- OS (e.g., Linux):
- How you installed PyTorch (conda, pip, libtorch, source):
- Build command you used (if compiling from source):
- Are you using local sources or building from archives:
- Python version:
- CUDA version:
- GPU models and configuration:
- Any other relevant information:
Additional context
torch._dynamo.mark_dynamic(input_ids, 1, min=7, max=1023) fixes this issue, though. Previously, I was using min=2.
PyTorch issue: https://github.com/pytorch/pytorch/issues/125604
- Here the input prompt length used for compilation is 7, so guards were introduced for that length. If I use min=1 with a prompt length of 7, it fails with the above error.
- torch._dynamo.mark_dynamic(input_ids, 1, min=1, max=1023) with a prompt length of 1 fails with constraint violation errors.
- torch._dynamo.mark_dynamic(input_ids, 1, min=2, max=1023) works fine.
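
One way to read the error message is as a containment check: every value in the range declared through mark_dynamic must also satisfy the guard generated during tracing. The sketch below is a simplified pure-Python model of that reading, not PyTorch's actual implementation. IntRange and check_declared_range are hypothetical names introduced only for illustration, and the guard bounds (7 and 1023) are copied from the error message above.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class IntRange:
    """Closed integer interval [lower, upper] (hypothetical helper type)."""
    lower: int
    upper: int

    def contains(self, other: "IntRange") -> bool:
        # True when every value of `other` also lies inside this interval.
        return self.lower <= other.lower and other.upper <= self.upper


def check_declared_range(declared: IntRange, guard: IntRange) -> str:
    """Model of the check described by the error message (an assumption,
    not the real symbolic_shapes logic): the declared range must fit
    entirely inside the range allowed by the generated guard."""
    if guard.contains(declared):
        return "ok"
    return (
        f"violation: declared [{declared.lower}, {declared.upper}] "
        f"is not inside guard [{guard.lower}, {guard.upper}]"
    )


# Guard taken from the error message: 7 <= size <= 1023.
guard = IntRange(7, 1023)

# min=1 declares sizes 1..6 as valid, which the guard rejects.
print(check_declared_range(IntRange(1, 1023), guard))

# min=7 matches the guard's lower bound, so the declared range fits.
print(check_declared_range(IntRange(7, 1023), guard))
```

Under this model, min=1 with a prompt length of 7 fails because the values 1 through 6 are declared valid but excluded by the guard, while min=7 passes. The model does not cover the prompt-length-1 case; that failure is plausibly related to PyTorch specializing dimensions of size 0 or 1, but the issue itself does not confirm this.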