
FIX Prompt learning issue with 4d attention mask

Open BenjaminBossan opened this pull request 9 months ago • 19 comments

Resolves #2452

Some causal language models in transformers create 4d attention masks during input preparation. So far, we have assumed 2d attention masks, which results in an error in that case. This PR fixes that.
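
To illustrate the mismatch with made-up shapes: the prompt learning code prepends a mask of ones for the virtual tokens along the sequence dimension, which only works when both masks are 2d.

import torch

# Hypothetical shapes: a 2d mask is (batch, kv_len), while some models prepare a
# 4d causal mask of shape (batch, 1, query_len, kv_len) during input preparation.
mask_2d = torch.ones(2, 10)           # (batch, kv_len)
mask_4d = torch.zeros(2, 1, 10, 10)   # (batch, 1, query_len, kv_len)
prefix = torch.ones(2, 3)             # mask covering 3 virtual tokens

torch.cat([prefix, mask_2d], dim=1)   # fine: shape (2, 13)
# torch.cat([prefix, mask_4d], dim=1) # fails: tensors must have the same number of dimensions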

My first attempt was to transform the 2d prefix attention mask (from the virtual tokens) into a 4d attention mask before concatenating the two. However, this was error-prone, and I was unsure whether the approach would generalize to architectures other than the one tested (gemma), as it involved using private transformers methods (model._prepare_4d_causal_attention_mask_with_cache_position). The simpler approach was thus to just create a 2d attention mask and let the model handle it.
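
As a minimal sketch of that simpler approach (the helper name extend_attention_mask and its signature are mine, not the actual PEFT code): keep the prefix mask for the virtual tokens 2d, prepend it to the user-supplied 2d mask, and let the model expand the result to 4d where it needs to.

import torch

def extend_attention_mask(attention_mask: torch.Tensor, num_virtual_tokens: int) -> torch.Tensor:
    # Prepend a 2d mask of ones for the virtual tokens; the model's own
    # mask-preparation code turns this into a 4d causal mask if required.
    batch_size = attention_mask.shape[0]
    prefix_mask = torch.ones(
        batch_size, num_virtual_tokens, dtype=attention_mask.dtype, device=attention_mask.device
    )
    return torch.cat([prefix_mask, attention_mask], dim=1)

# Example: a (2, 5) mask with 3 virtual tokens becomes a (2, 8) mask.
mask = torch.ones(2, 5, dtype=torch.long)
print(extend_attention_mask(mask, num_virtual_tokens=3).shape)  # torch.Size([2, 8])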

The test suite has been extended to include a tiny gemma model. To prevent the test suite from ballooning, I removed another model, namely GPT-NeoX, which, judging by the HF download stats, appears to be one of the least popular architectures in our test suite. I also extended the default parameters in constants.py for the different PEFT methods to support gemma.

Unfortunately, some tests are failing with gemma. When they were unrelated to changes in this PR, I chose to just skip those tests, as I consider them out of scope for this PR.

BenjaminBossan avatar Mar 27 '25 11:03 BenjaminBossan

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@krishnakanthnakkav2 This PR should hopefully fix your issue. If you have the opportunity, please check if this branch resolves your initial problem.

BenjaminBossan avatar Mar 27 '25 15:03 BenjaminBossan

Thank you @BenjaminBossan. I will try it and update here as soon as possible.

krishnakanthnakkav2 avatar Mar 27 '25 15:03 krishnakanthnakkav2

 warnings.warn(
Traceback (most recent call last):
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/utils.py", line 2586, in run_node
   return node.target(*args, **kwargs)
          ~~~~~~~~~~~^^^^^^^^^^^^^^^^^
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/utils/_stats.py", line 21, in wrapper
   return fn(*args, **kwargs)
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_subclasses/fake_tensor.py", line 1276, in __torch_dispatch__
   return self.dispatch(func, types, args, kwargs)
          ~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_subclasses/fake_tensor.py", line 1816, in dispatch
   return self._cached_dispatch_impl(func, types, args, kwargs)
          ~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_subclasses/fake_tensor.py", line 1386, in _cached_dispatch_impl
   output = self._dispatch_impl(func, types, args, kwargs)
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_subclasses/fake_tensor.py", line 2354, in _dispatch_impl
   op_impl_out = op_impl(self, func, *args, **kwargs)
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_subclasses/fake_impls.py", line 160, in dispatch_to_op_implementations_dict
   return op_implementations_dict[func](fake_mode, func, *args, **kwargs)
          ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_subclasses/fake_impls.py", line 399, in local_scalar_dense
   raise DataDependentOutputException(func)
torch._subclasses.fake_tensor.DataDependentOutputException: aten._local_scalar_dense.default

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/utils.py", line 2471, in get_fake_value
   ret_val = wrap_fake_exception(
       lambda: run_node(tx.output, node, args, kwargs, nnmodule)
   )
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/utils.py", line 2017, in wrap_fake_exception
   return fn()
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/utils.py", line 2472, in <lambda>
   lambda: run_node(tx.output, node, args, kwargs, nnmodule)
           ~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/utils.py", line 2604, in run_node
   raise RuntimeError(make_error_message(e)).with_traceback(
       e.__traceback__
   ) from e
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/utils.py", line 2586, in run_node
   return node.target(*args, **kwargs)
          ~~~~~~~~~~~^^^^^^^^^^^^^^^^^
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/utils/_stats.py", line 21, in wrapper
   return fn(*args, **kwargs)
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_subclasses/fake_tensor.py", line 1276, in __torch_dispatch__
   return self.dispatch(func, types, args, kwargs)
          ~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_subclasses/fake_tensor.py", line 1816, in dispatch
   return self._cached_dispatch_impl(func, types, args, kwargs)
          ~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_subclasses/fake_tensor.py", line 1386, in _cached_dispatch_impl
   output = self._dispatch_impl(func, types, args, kwargs)
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_subclasses/fake_tensor.py", line 2354, in _dispatch_impl
   op_impl_out = op_impl(self, func, *args, **kwargs)
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_subclasses/fake_impls.py", line 160, in dispatch_to_op_implementations_dict
   return op_implementations_dict[func](fake_mode, func, *args, **kwargs)
          ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_subclasses/fake_impls.py", line 399, in local_scalar_dense
   raise DataDependentOutputException(func)
RuntimeError: Failed running call_function <built-in method arange of type object at 0x71841af9af60>(*(FakeTensor(..., device='cuda:0', size=(), dtype=torch.int64), FakeTensor(..., device='cuda:0', size=(), dtype=torch.int64)), **{'device': device(type='cuda', index=0)}):
aten._local_scalar_dense.default

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
 File "/home/krishna/PII/cs-llm-temp/misc.py", line 97, in <module>
   generated_output = model.generate(
       input_ids=inputs["input_ids"],
       attention_mask=inputs["attention_mask"],
       max_length=60
   )
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/peft/peft_model.py", line 1919, in generate
   outputs = self.base_model.generate(**kwargs)
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
   return func(*args, **kwargs)
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/transformers/generation/utils.py", line 2223, in generate
   result = self._sample(
       input_ids,
   ...<5 lines>...
       **model_kwargs,
   )
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/transformers/generation/utils.py", line 3214, in _sample
   outputs = model_forward(**model_inputs, return_dict=True)
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/eval_frame.py", line 574, in _fn
   return fn(*args, **kwargs)
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
   return self._call_impl(*args, **kwargs)
          ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
   return forward_call(*args, **kwargs)
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/convert_frame.py", line 1380, in __call__
   return self._torchdynamo_orig_callable(
          ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^
       frame, cache_entry, self.hooks, frame_state, skip=1
       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
   )
   ^
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/convert_frame.py", line 547, in __call__
   return _compile(
       frame.f_code,
   ...<14 lines>...
       skip=skip + 1,
   )
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/convert_frame.py", line 986, in _compile
   guarded_code = compile_inner(code, one_graph, hooks, transform)
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/convert_frame.py", line 715, in compile_inner
   return _compile_inner(code, one_graph, hooks, transform)
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_utils_internal.py", line 95, in wrapper_function
   return function(*args, **kwargs)
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/convert_frame.py", line 750, in _compile_inner
   out_code = transform_code_object(code, transform)
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/bytecode_transformation.py", line 1361, in transform_code_object
   transformations(instructions, code_options)
   ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/convert_frame.py", line 231, in _fn
   return fn(*args, **kwargs)
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/convert_frame.py", line 662, in transform
   tracer.run()
   ~~~~~~~~~~^^
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 2868, in run
   super().run()
   ~~~~~~~~~~~^^
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
   while self.step():
         ~~~~~~~~~^^
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
   self.dispatch_table[inst.opcode](self, inst)
   ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 659, in wrapper
   return inner_fn(self, inst)
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 1736, in CALL_FUNCTION_EX
   self.call_function(fn, argsvars.items, kwargsvars)
   ~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 897, in call_function
   self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
             ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/variables/functions.py", line 378, in call_function
   return super().call_function(tx, args, kwargs)
          ~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/variables/functions.py", line 317, in call_function
   return super().call_function(tx, args, kwargs)
          ~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/variables/functions.py", line 118, in call_function
   return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
          ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 903, in inline_user_function_return
   return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
          ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 3072, in inline_call
   return cls.inline_call_(parent, func, args, kwargs)
          ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 3198, in inline_call_
   tracer.run()
   ~~~~~~~~~~^^
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
   while self.step():
         ~~~~~~~~~^^
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
   self.dispatch_table[inst.opcode](self, inst)
   ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 659, in wrapper
   return inner_fn(self, inst)
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 1736, in CALL_FUNCTION_EX
   self.call_function(fn, argsvars.items, kwargsvars)
   ~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 897, in call_function
   self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
             ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/variables/functions.py", line 317, in call_function
   return super().call_function(tx, args, kwargs)
          ~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/variables/functions.py", line 118, in call_function
   return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
          ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 903, in inline_user_function_return
   return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
          ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 3072, in inline_call
   return cls.inline_call_(parent, func, args, kwargs)
          ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 3198, in inline_call_
   tracer.run()
   ~~~~~~~~~~^^
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
   while self.step():
         ~~~~~~~~~^^
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
   self.dispatch_table[inst.opcode](self, inst)
   ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 659, in wrapper
   return inner_fn(self, inst)
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 1736, in CALL_FUNCTION_EX
   self.call_function(fn, argsvars.items, kwargsvars)
   ~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 897, in call_function
   self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
             ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/variables/lazy.py", line 170, in realize_and_forward
   return getattr(self.realize(), name)(*args, **kwargs)
          ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/variables/nn_module.py", line 914, in call_function
   return variables.UserFunctionVariable(fn, source=source).call_function(
          ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^
       tx, [self] + list(args), kwargs
       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
   )
   ^
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/variables/functions.py", line 317, in call_function
   return super().call_function(tx, args, kwargs)
          ~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/variables/functions.py", line 118, in call_function
   return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
          ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 903, in inline_user_function_return
   return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
          ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 3072, in inline_call
   return cls.inline_call_(parent, func, args, kwargs)
          ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 3198, in inline_call_
   tracer.run()
   ~~~~~~~~~~^^
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
   while self.step():
         ~~~~~~~~~^^
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
   self.dispatch_table[inst.opcode](self, inst)
   ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 659, in wrapper
   return inner_fn(self, inst)
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 2482, in CALL_KW
   self._call(inst, call_kw=True)
   ~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 2335, in _call
   self.call_function(fn, args, kwargs)
   ~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 897, in call_function
   self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
             ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/variables/torch.py", line 953, in call_function
   tensor_variable = wrap_fx_proxy(
       tx=tx,
   ...<4 lines>...
       ),
   )
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/variables/builder.py", line 2153, in wrap_fx_proxy
   return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/variables/builder.py", line 2219, in wrap_fx_proxy_cls
   return _wrap_fx_proxy(
       target_cls, tx, proxy, example_value, subclass_type, **options
   )
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/variables/builder.py", line 2315, in _wrap_fx_proxy
   example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/utils.py", line 2484, in get_fake_value
   unimplemented(
   ~~~~~~~~~~~~~^
       f"data dependent operator: {cause.func}; "
       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
       "to enable, set torch._dynamo.config.capture_scalar_outputs = True"
       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
   )
   ^
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/exc.py", line 317, in unimplemented
   raise Unsupported(msg, case_name=case_name)
torch._dynamo.exc.Unsupported: data dependent operator: aten._local_scalar_dense.default; to enable, set torch._dynamo.config.capture_scalar_outputs = True

from user code:
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/accelerate/hooks.py", line 176, in new_forward
   output = module._old_forward(*args, **kwargs)
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/transformers/utils/deprecation.py", line 172, in wrapped_func
   return func(*args, **kwargs)
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/transformers/models/gemma2/modeling_gemma2.py", line 887, in forward
   outputs = self.model(
 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/transformers/models/gemma2/modeling_gemma2.py", line 612, in forward
   cache_position = torch.arange(

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
   import torch._dynamo
   torch._dynamo.config.suppress_errors = True

Can you please take a look at this? Also, could you share the library versions that you tested this PR with?

krishnakanthnakkav2 avatar Mar 27 '25 20:03 krishnakanthnakkav2

@krishnakanthnakkav2 To install the PEFT version from this PR, run:

python -m pip install -U git+https://github.com/BenjaminBossan/peft.git@fix-prompt-tuning-4d-attention-mask

(after you finish testing, please install the normal PEFT version again)

The error you show above appears to be related to the use of torch.compile. Please try again without compilation and check if it works.
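
For reference, a minimal sketch for ruling out compilation (assuming model and inputs are set up as in your script; this is the eager fallback that the error message itself suggests):

import torch._dynamo

# Either remove any torch.compile call from the script, or let Dynamo fall back
# to eager mode when it hits an unsupported op instead of raising:
torch._dynamo.config.suppress_errors = True

generated_output = model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    max_length=60,
)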

BenjaminBossan avatar Mar 28 '25 10:03 BenjaminBossan

gentle ping @krishnakanthnakkav2

BenjaminBossan avatar Apr 01 '25 10:04 BenjaminBossan

Hi,

I tried again with the latest code. I still face errors. My transformers version is 4.49.0.

 File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/utils.py", line 2586, in run_node
    return node.target(*args, **kwargs)
           ~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/utils/_stats.py", line 21, in wrapper
    return fn(*args, **kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_subclasses/fake_tensor.py", line 1276, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
           ~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_subclasses/fake_tensor.py", line 1816, in dispatch
    return self._cached_dispatch_impl(func, types, args, kwargs)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_subclasses/fake_tensor.py", line 1377, in _cached_dispatch_impl
    output = self._dispatch_impl(func, types, args, kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_subclasses/fake_tensor.py", line 2290, in _dispatch_impl
    decomposition_table[func](*args, **kwargs)
    ~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_decomp/decompositions.py", line 4822, in arange_start
    return aten.arange.start_step(
           ~~~~~~~~~~~~~~~~~~~~~~^
        start, end, 1, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_ops.py", line 723, in __call__
    return self._op(*args, **kwargs)
           ~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/utils/_stats.py", line 21, in wrapper
    return fn(*args, **kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_subclasses/fake_tensor.py", line 1276, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
           ~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_subclasses/fake_tensor.py", line 1816, in dispatch
    return self._cached_dispatch_impl(func, types, args, kwargs)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_subclasses/fake_tensor.py", line 1377, in _cached_dispatch_impl
    output = self._dispatch_impl(func, types, args, kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_subclasses/fake_tensor.py", line 2290, in _dispatch_impl
    decomposition_table[func](*args, **kwargs)
    ~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_prims_common/wrappers.py", line 291, in _fn
    result = fn(*args, **kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_refs/__init__.py", line 5085, in arange
    return prims.iota(
           ~~~~~~~~~~^
        length,
        ^^^^^^^
    ...<4 lines>...
        requires_grad=requires_grad,
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_ops.py", line 723, in __call__
    return self._op(*args, **kwargs)
           ~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/utils/_stats.py", line 21, in wrapper
    return fn(*args, **kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_subclasses/fake_tensor.py", line 1276, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
           ~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_subclasses/fake_tensor.py", line 1816, in dispatch
    return self._cached_dispatch_impl(func, types, args, kwargs)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_subclasses/fake_tensor.py", line 1377, in _cached_dispatch_impl
    output = self._dispatch_impl(func, types, args, kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_subclasses/fake_tensor.py", line 2312, in _dispatch_impl
    func.prim_meta_impl(*args, **kwargs)
    ~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_prims/__init__.py", line 2342, in _iota_meta
    return torch.empty(
           ~~~~~~~~~~~^
        length,
        ^^^^^^^
    ...<2 lines>...
        requires_grad=requires_grad,
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/utils/_stats.py", line 21, in wrapper
    return fn(*args, **kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_subclasses/fake_tensor.py", line 1276, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
           ~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_subclasses/fake_tensor.py", line 1816, in dispatch
    return self._cached_dispatch_impl(func, types, args, kwargs)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_subclasses/fake_tensor.py", line 1378, in _cached_dispatch_impl
    entry = self._make_cache_entry(state, key, func, args, kwargs, output)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_subclasses/fake_tensor.py", line 1648, in _make_cache_entry
    output_info = self._get_output_info_for_cache_entry(
        state, key, func, args, kwargs, output
    )
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_subclasses/fake_tensor.py", line 1591, in _get_output_info_for_cache_entry
    synth_output = self._output_from_cache_entry(
        state, entry_for_synth_output, key, func, args
    )
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_subclasses/fake_tensor.py", line 1746, in _output_from_cache_entry
    return self._get_output_tensor_from_cache_entry(
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^
        state, entry.output_infos[0], key, func, args
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_subclasses/fake_tensor.py", line 1697, in _get_output_tensor_from_cache_entry
    empty = torch.empty_strided(
        shape,
    ...<4 lines>...
        requires_grad=metadata.requires_grad,
    )
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/fx/experimental/sym_node.py", line 564, in guard_size_oblivious
    r = self.shape_env.evaluate_expr(
        self.expr, self.hint, fx_node=self.fx_node, size_oblivious=True
    )
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/fx/experimental/recording.py", line 263, in wrapper
    return retlog(fn(*args, **kwargs))
                  ~~^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/fx/experimental/symbolic_shapes.py", line 6303, in evaluate_expr
    return self._evaluate_expr(
           ~~~~~~~~~~~~~~~~~~~^
        orig_expr, hint, fx_node, size_oblivious, forcing_spec=forcing_spec
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/fx/experimental/symbolic_shapes.py", line 6493, in _evaluate_expr
    raise self._make_data_dependent_error(
    ...<3 lines>...
    )
torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not guard on data-dependent expression Eq(u0 - u1, 0) (unhinted: Eq(u0 - u1, 0)).  (Size-like symbols: none)

Caused by: cache_position = torch.arange(  # transformers/models/gemma2/modeling_gemma2.py:612 in forward (utils/_stats.py:21 in wrapper)
For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u1,u0"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing

User Stack (most recent call last):
  (snipped, see stack below for prefix)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/accelerate/hooks.py", line 176, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/transformers/utils/deprecation.py", line 172, in wrapped_func
    return func(*args, **kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/transformers/models/gemma2/modeling_gemma2.py", line 887, in forward
    outputs = self.model(
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/transformers/models/gemma2/modeling_gemma2.py", line 612, in forward
    cache_position = torch.arange(

For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/utils.py", line 2471, in get_fake_value
    ret_val = wrap_fake_exception(
        lambda: run_node(tx.output, node, args, kwargs, nnmodule)
    )
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/utils.py", line 2017, in wrap_fake_exception
    return fn()
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/utils.py", line 2472, in <lambda>
    lambda: run_node(tx.output, node, args, kwargs, nnmodule)
            ~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/utils.py", line 2604, in run_node
    raise RuntimeError(make_error_message(e)).with_traceback(
        e.__traceback__
    ) from e
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/utils.py", line 2586, in run_node
    return node.target(*args, **kwargs)
           ~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/utils/_stats.py", line 21, in wrapper
    return fn(*args, **kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_subclasses/fake_tensor.py", line 1276, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
           ~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_subclasses/fake_tensor.py", line 1816, in dispatch
    return self._cached_dispatch_impl(func, types, args, kwargs)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_subclasses/fake_tensor.py", line 1377, in _cached_dispatch_impl
    output = self._dispatch_impl(func, types, args, kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_subclasses/fake_tensor.py", line 2290, in _dispatch_impl
    decomposition_table[func](*args, **kwargs)
    ~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_decomp/decompositions.py", line 4822, in arange_start
    return aten.arange.start_step(
           ~~~~~~~~~~~~~~~~~~~~~~^
        start, end, 1, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_ops.py", line 723, in __call__
    return self._op(*args, **kwargs)
           ~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/utils/_stats.py", line 21, in wrapper
    return fn(*args, **kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_subclasses/fake_tensor.py", line 1276, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
           ~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_subclasses/fake_tensor.py", line 1816, in dispatch
    return self._cached_dispatch_impl(func, types, args, kwargs)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_subclasses/fake_tensor.py", line 1377, in _cached_dispatch_impl
    output = self._dispatch_impl(func, types, args, kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_subclasses/fake_tensor.py", line 2290, in _dispatch_impl
    decomposition_table[func](*args, **kwargs)
    ~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_prims_common/wrappers.py", line 291, in _fn
    result = fn(*args, **kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_refs/__init__.py", line 5085, in arange
    return prims.iota(
           ~~~~~~~~~~^
        length,
        ^^^^^^^
    ...<4 lines>...
        requires_grad=requires_grad,
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_ops.py", line 723, in __call__
    return self._op(*args, **kwargs)
           ~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/utils/_stats.py", line 21, in wrapper
    return fn(*args, **kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_subclasses/fake_tensor.py", line 1276, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
           ~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_subclasses/fake_tensor.py", line 1816, in dispatch
    return self._cached_dispatch_impl(func, types, args, kwargs)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_subclasses/fake_tensor.py", line 1377, in _cached_dispatch_impl
    output = self._dispatch_impl(func, types, args, kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_subclasses/fake_tensor.py", line 2312, in _dispatch_impl
    func.prim_meta_impl(*args, **kwargs)
    ~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_prims/__init__.py", line 2342, in _iota_meta
    return torch.empty(
           ~~~~~~~~~~~^
        length,
        ^^^^^^^
    ...<2 lines>...
        requires_grad=requires_grad,
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/utils/_stats.py", line 21, in wrapper
    return fn(*args, **kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_subclasses/fake_tensor.py", line 1276, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
           ~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_subclasses/fake_tensor.py", line 1816, in dispatch
    return self._cached_dispatch_impl(func, types, args, kwargs)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_subclasses/fake_tensor.py", line 1378, in _cached_dispatch_impl
    entry = self._make_cache_entry(state, key, func, args, kwargs, output)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_subclasses/fake_tensor.py", line 1648, in _make_cache_entry
    output_info = self._get_output_info_for_cache_entry(
        state, key, func, args, kwargs, output
    )
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_subclasses/fake_tensor.py", line 1591, in _get_output_info_for_cache_entry
    synth_output = self._output_from_cache_entry(
        state, entry_for_synth_output, key, func, args
    )
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_subclasses/fake_tensor.py", line 1746, in _output_from_cache_entry
    return self._get_output_tensor_from_cache_entry(
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^
        state, entry.output_infos[0], key, func, args
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_subclasses/fake_tensor.py", line 1697, in _get_output_tensor_from_cache_entry
    empty = torch.empty_strided(
        shape,
    ...<4 lines>...
        requires_grad=metadata.requires_grad,
    )
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/fx/experimental/sym_node.py", line 564, in guard_size_oblivious
    r = self.shape_env.evaluate_expr(
        self.expr, self.hint, fx_node=self.fx_node, size_oblivious=True
    )
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/fx/experimental/recording.py", line 263, in wrapper
    return retlog(fn(*args, **kwargs))
                  ~~^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/fx/experimental/symbolic_shapes.py", line 6303, in evaluate_expr
    return self._evaluate_expr(
           ~~~~~~~~~~~~~~~~~~~^
        orig_expr, hint, fx_node, size_oblivious, forcing_spec=forcing_spec
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/fx/experimental/symbolic_shapes.py", line 6493, in _evaluate_expr
    raise self._make_data_dependent_error(
    ...<3 lines>...
    )
RuntimeError: Failed running call_function <built-in method arange of type object at 0x79838e99af60>(*(FakeTensor(..., device='cuda:0', size=(), dtype=torch.int64), FakeTensor(..., device='cuda:0', size=(), dtype=torch.int64)), **{'device': device(type='cuda', index=0)}):
Could not guard on data-dependent expression Eq(u0 - u1, 0) (unhinted: Eq(u0 - u1, 0)).  (Size-like symbols: none)

Caused by: cache_position = torch.arange(  # transformers/models/gemma2/modeling_gemma2.py:612 in forward (utils/_stats.py:21 in wrapper)
For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u1,u0"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing

User Stack (most recent call last):
  (snipped, see stack below for prefix)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/accelerate/hooks.py", line 176, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/transformers/utils/deprecation.py", line 172, in wrapped_func
    return func(*args, **kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/transformers/models/gemma2/modeling_gemma2.py", line 887, in forward
    outputs = self.model(
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/transformers/models/gemma2/modeling_gemma2.py", line 612, in forward
    cache_position = torch.arange(

For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/krishna/PII/cs-llm-temp/misc.py", line 104, in <module>
    generated_output = model.generate(
        input_ids=inputs["input_ids"].cuda(),
        attention_mask=inputs["attention_mask"].cuda(),
        max_length=60
    )
  File "/home/krishna/PII/cs-llm-temp/libs/peftgemma/peft/src/peft/peft_model.py", line 1935, in generate
    outputs = self.base_model.generate(**kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/transformers/generation/utils.py", line 2223, in generate
    result = self._sample(
        input_ids,
    ...<5 lines>...
        **model_kwargs,
    )
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/transformers/generation/utils.py", line 3214, in _sample
    outputs = model_forward(**model_inputs, return_dict=True)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/eval_frame.py", line 574, in _fn
    return fn(*args, **kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/convert_frame.py", line 1380, in __call__
    return self._torchdynamo_orig_callable(
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^
        frame, cache_entry, self.hooks, frame_state, skip=1
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/convert_frame.py", line 547, in __call__
    return _compile(
        frame.f_code,
    ...<14 lines>...
        skip=skip + 1,
    )
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/convert_frame.py", line 986, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/convert_frame.py", line 715, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_utils_internal.py", line 95, in wrapper_function
    return function(*args, **kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/convert_frame.py", line 750, in _compile_inner
    out_code = transform_code_object(code, transform)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/bytecode_transformation.py", line 1361, in transform_code_object
    transformations(instructions, code_options)
    ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/convert_frame.py", line 231, in _fn
    return fn(*args, **kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/convert_frame.py", line 662, in transform
    tracer.run()
    ~~~~~~~~~~^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 2868, in run
    super().run()
    ~~~~~~~~~~~^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
    while self.step():
          ~~~~~~~~~^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
    self.dispatch_table[inst.opcode](self, inst)
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 659, in wrapper
    return inner_fn(self, inst)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 1736, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
    ~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 897, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
              ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/variables/functions.py", line 378, in call_function
    return super().call_function(tx, args, kwargs)
           ~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/variables/functions.py", line 317, in call_function
    return super().call_function(tx, args, kwargs)
           ~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/variables/functions.py", line 118, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 903, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 3072, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
           ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 3198, in inline_call_
    tracer.run()
    ~~~~~~~~~~^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
    while self.step():
          ~~~~~~~~~^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
    self.dispatch_table[inst.opcode](self, inst)
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 659, in wrapper
    return inner_fn(self, inst)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 1736, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
    ~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 897, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
              ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/variables/functions.py", line 317, in call_function
    return super().call_function(tx, args, kwargs)
           ~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/variables/functions.py", line 118, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 903, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 3072, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
           ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 3198, in inline_call_
    tracer.run()
    ~~~~~~~~~~^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
    while self.step():
          ~~~~~~~~~^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
    self.dispatch_table[inst.opcode](self, inst)
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 659, in wrapper
    return inner_fn(self, inst)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 1736, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
    ~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 897, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
              ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/variables/lazy.py", line 170, in realize_and_forward
    return getattr(self.realize(), name)(*args, **kwargs)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/variables/nn_module.py", line 914, in call_function
    return variables.UserFunctionVariable(fn, source=source).call_function(
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^
        tx, [self] + list(args), kwargs
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/variables/functions.py", line 317, in call_function
    return super().call_function(tx, args, kwargs)
           ~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/variables/functions.py", line 118, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 903, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 3072, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
           ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 3198, in inline_call_
    tracer.run()
    ~~~~~~~~~~^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
    while self.step():
          ~~~~~~~~~^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
    self.dispatch_table[inst.opcode](self, inst)
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 659, in wrapper
    return inner_fn(self, inst)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 2482, in CALL_KW
    self._call(inst, call_kw=True)
    ~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 2335, in _call
    self.call_function(fn, args, kwargs)
    ~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 897, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
              ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/variables/torch.py", line 953, in call_function
    tensor_variable = wrap_fx_proxy(
        tx=tx,
    ...<4 lines>...
        ),
    )
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/variables/builder.py", line 2153, in wrap_fx_proxy
    return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/variables/builder.py", line 2219, in wrap_fx_proxy_cls
    return _wrap_fx_proxy(
        target_cls, tx, proxy, example_value, subclass_type, **options
    )
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/variables/builder.py", line 2315, in _wrap_fx_proxy
    example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/utils.py", line 2526, in get_fake_value
    raise UserError(  # noqa: B904
    ...<3 lines>...
    )
torch._dynamo.exc.UserError: Could not guard on data-dependent expression Eq(u0 - u1, 0) (unhinted: Eq(u0 - u1, 0)).  (Size-like symbols: none)

Caused by: cache_position = torch.arange(  # transformers/models/gemma2/modeling_gemma2.py:612 in forward (utils/_stats.py:21 in wrapper)
For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u1,u0"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing

User Stack (most recent call last):
  (snipped, see stack below for prefix)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/accelerate/hooks.py", line 176, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/transformers/utils/deprecation.py", line 172, in wrapped_func
    return func(*args, **kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/transformers/models/gemma2/modeling_gemma2.py", line 887, in forward
    outputs = self.model(
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/transformers/models/gemma2/modeling_gemma2.py", line 612, in forward
    cache_position = torch.arange(

For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#constrain-as-size-example

from user code:
   File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/accelerate/hooks.py", line 176, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/transformers/utils/deprecation.py", line 172, in wrapped_func
    return func(*args, **kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/transformers/models/gemma2/modeling_gemma2.py", line 887, in forward
    outputs = self.model(
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/transformers/models/gemma2/modeling_gemma2.py", line 612, in forward
    cache_position = torch.arange(

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

krishnakanthnakkav2 avatar Apr 13 '25 16:04 krishnakanthnakkav2

@krishnakanthnakkav2 It looks like you're using torch.compile, could you please check if you get an error without compilation?
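
For a quick check (just a sketch, assuming the error is triggered by TorchDynamo compiling the forward pass inside generate), you could disable dynamo entirely so everything runs in eager mode, e.g. by putting this at the very top of check.py:

import torch._dynamo

# assumption: this config flag is available in your torch version (it is in recent 2.x releases)
torch._dynamo.config.disable = True

If the script then runs without the error, we know the remaining problem is specific to compilation rather than to the prompt learning attention mask handling.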

BenjaminBossan avatar Apr 14 '25 09:04 BenjaminBossan

> @krishnakanthnakkav2 It looks like you're using torch.compile, could you please check if you get an error without compilation?

Can you please clarify what exactly you mean by compilation? I am using the following code and run it with the command CUDA_VISIBLE_DEVICES=0 python check.py:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import get_peft_model, TaskType, PromptTuningConfig, PromptTuningInit


# Dynamo workarounds I tried at various points (kept here, commented out):
# import torch._dynamo
# torch._dynamo.config.capture_scalar_outputs = True  # option 1: capture scalar outputs
# torch._dynamo.config.suppress_errors = True         # option 2: fall back to eager execution


model_name = "google/gemma-2-2b"
cache_dir = "/assets/hub"
attention = "eager"
tokenizer_model_name = model_name

tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    cache_dir=cache_dir,
)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    cache_dir=cache_dir,
    attn_implementation=attention,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

model.config.use_cache = False
model.config.pretraining_tp = 1
tokenizer.pad_token = tokenizer.eos_token

config = PromptTuningConfig(
    peft_type="PROMPT_TUNING",
    task_type=TaskType.CAUSAL_LM,
    prompt_tuning_init=PromptTuningInit.RANDOM,
    prompt_tuning_init_text="email",  # "phone" # "address"; note: ignored with PromptTuningInit.RANDOM, only used with TEXT init
    num_virtual_tokens=20,
    tokenizer_name_or_path=model_name)

model = get_peft_model(model, config)

input_texts = [
    "In the world of artificial intelligence,",
]

inputs = tokenizer(input_texts, return_tensors="pt",
                   padding=True, truncation=True
                   )

print(f"Input attention mask shape: {inputs['attention_mask'].shape}")

generated_output = model.generate(
    input_ids=inputs["input_ids"].cuda(),
    attention_mask=inputs["attention_mask"].cuda(),
    max_length=60
)

generated_texts = tokenizer.batch_decode(
    generated_output, skip_special_tokens=True
)

for idx, generated_text in enumerate(generated_texts):
    print(f"Generated text for input {idx + 1}: {generated_text}")

The error is:

Traceback (most recent call last):
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/utils.py", line 2586, in run_node
    return node.target(*args, **kwargs)
           ~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/utils/_stats.py", line 21, in wrapper
    return fn(*args, **kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_subclasses/fake_tensor.py", line 1276, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
           ~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_subclasses/fake_tensor.py", line 1816, in dispatch
    return self._cached_dispatch_impl(func, types, args, kwargs)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_subclasses/fake_tensor.py", line 1386, in _cached_dispatch_impl
    output = self._dispatch_impl(func, types, args, kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_subclasses/fake_tensor.py", line 2354, in _dispatch_impl
    op_impl_out = op_impl(self, func, *args, **kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_subclasses/fake_impls.py", line 160, in dispatch_to_op_implementations_dict
    return op_implementations_dict[func](fake_mode, func, *args, **kwargs)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_subclasses/fake_impls.py", line 399, in local_scalar_dense
    raise DataDependentOutputException(func)
torch._subclasses.fake_tensor.DataDependentOutputException: aten._local_scalar_dense.default

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/utils.py", line 2471, in get_fake_value
    ret_val = wrap_fake_exception(
        lambda: run_node(tx.output, node, args, kwargs, nnmodule)
    )
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/utils.py", line 2017, in wrap_fake_exception
    return fn()
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/utils.py", line 2472, in <lambda>
    lambda: run_node(tx.output, node, args, kwargs, nnmodule)
            ~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/utils.py", line 2604, in run_node
    raise RuntimeError(make_error_message(e)).with_traceback(
        e.__traceback__
    ) from e
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/utils.py", line 2586, in run_node
    return node.target(*args, **kwargs)
           ~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/utils/_stats.py", line 21, in wrapper
    return fn(*args, **kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_subclasses/fake_tensor.py", line 1276, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
           ~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_subclasses/fake_tensor.py", line 1816, in dispatch
    return self._cached_dispatch_impl(func, types, args, kwargs)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_subclasses/fake_tensor.py", line 1386, in _cached_dispatch_impl
    output = self._dispatch_impl(func, types, args, kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_subclasses/fake_tensor.py", line 2354, in _dispatch_impl
    op_impl_out = op_impl(self, func, *args, **kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_subclasses/fake_impls.py", line 160, in dispatch_to_op_implementations_dict
    return op_implementations_dict[func](fake_mode, func, *args, **kwargs)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_subclasses/fake_impls.py", line 399, in local_scalar_dense
    raise DataDependentOutputException(func)
RuntimeError: Failed running call_function <built-in method arange of type object at 0x70dc31b9af60>(*(FakeTensor(..., device='cuda:0', size=(), dtype=torch.int64), FakeTensor(..., device='cuda:0', size=(), dtype=torch.int64)), **{'device': device(type='cuda', index=0)}):
aten._local_scalar_dense.default

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/krishna/PII/cs-llm-temp/misc.py", line 62, in <module>
    generated_output = model.generate(
        input_ids=inputs["input_ids"].cuda(),
        attention_mask=inputs["attention_mask"].cuda(),
        max_length=60
    )
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/peft/peft_model.py", line 1919, in generate
    outputs = self.base_model.generate(**kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/transformers/generation/utils.py", line 2223, in generate
    result = self._sample(
        input_ids,
    ...<5 lines>...
        **model_kwargs,
    )
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/transformers/generation/utils.py", line 3214, in _sample
    outputs = model_forward(**model_inputs, return_dict=True)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/eval_frame.py", line 574, in _fn
    return fn(*args, **kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/convert_frame.py", line 1380, in __call__
    return self._torchdynamo_orig_callable(
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^
        frame, cache_entry, self.hooks, frame_state, skip=1
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/convert_frame.py", line 547, in __call__
    return _compile(
        frame.f_code,
    ...<14 lines>...
        skip=skip + 1,
    )
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/convert_frame.py", line 986, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/convert_frame.py", line 715, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_utils_internal.py", line 95, in wrapper_function
    return function(*args, **kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/convert_frame.py", line 750, in _compile_inner
    out_code = transform_code_object(code, transform)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/bytecode_transformation.py", line 1361, in transform_code_object
    transformations(instructions, code_options)
    ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/convert_frame.py", line 231, in _fn
    return fn(*args, **kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/convert_frame.py", line 662, in transform
    tracer.run()
    ~~~~~~~~~~^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 2868, in run
    super().run()
    ~~~~~~~~~~~^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
    while self.step():
          ~~~~~~~~~^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
    self.dispatch_table[inst.opcode](self, inst)
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 659, in wrapper
    return inner_fn(self, inst)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 1736, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
    ~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 897, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
              ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/variables/functions.py", line 378, in call_function
    return super().call_function(tx, args, kwargs)
           ~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/variables/functions.py", line 317, in call_function
    return super().call_function(tx, args, kwargs)
           ~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/variables/functions.py", line 118, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 903, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 3072, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
           ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 3198, in inline_call_
    tracer.run()
    ~~~~~~~~~~^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
    while self.step():
          ~~~~~~~~~^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
    self.dispatch_table[inst.opcode](self, inst)
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 659, in wrapper
    return inner_fn(self, inst)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 1736, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
    ~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 897, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
              ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/variables/functions.py", line 317, in call_function
    return super().call_function(tx, args, kwargs)
           ~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/variables/functions.py", line 118, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 903, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 3072, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
           ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 3198, in inline_call_
    tracer.run()
    ~~~~~~~~~~^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
    while self.step():
          ~~~~~~~~~^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
    self.dispatch_table[inst.opcode](self, inst)
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 659, in wrapper
    return inner_fn(self, inst)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 1736, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
    ~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 897, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
              ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/variables/lazy.py", line 170, in realize_and_forward
    return getattr(self.realize(), name)(*args, **kwargs)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/variables/nn_module.py", line 914, in call_function
    return variables.UserFunctionVariable(fn, source=source).call_function(
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^
        tx, [self] + list(args), kwargs
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/variables/functions.py", line 317, in call_function
    return super().call_function(tx, args, kwargs)
           ~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/variables/functions.py", line 118, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 903, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 3072, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
           ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 3198, in inline_call_
    tracer.run()
    ~~~~~~~~~~^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
    while self.step():
          ~~~~~~~~~^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
    self.dispatch_table[inst.opcode](self, inst)
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 659, in wrapper
    return inner_fn(self, inst)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 2482, in CALL_KW
    self._call(inst, call_kw=True)
    ~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 2335, in _call
    self.call_function(fn, args, kwargs)
    ~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 897, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
              ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/variables/torch.py", line 953, in call_function
    tensor_variable = wrap_fx_proxy(
        tx=tx,
    ...<4 lines>...
        ),
    )
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/variables/builder.py", line 2153, in wrap_fx_proxy
    return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/variables/builder.py", line 2219, in wrap_fx_proxy_cls
    return _wrap_fx_proxy(
        target_cls, tx, proxy, example_value, subclass_type, **options
    )
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/variables/builder.py", line 2315, in _wrap_fx_proxy
    example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/utils.py", line 2484, in get_fake_value
    unimplemented(
    ~~~~~~~~~~~~~^
        f"data dependent operator: {cause.func}; "
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
        "to enable, set torch._dynamo.config.capture_scalar_outputs = True"
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/exc.py", line 317, in unimplemented
    raise Unsupported(msg, case_name=case_name)
torch._dynamo.exc.Unsupported: data dependent operator: aten._local_scalar_dense.default; to enable, set torch._dynamo.config.capture_scalar_outputs = True

from user code:
   File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/accelerate/hooks.py", line 176, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/transformers/utils/deprecation.py", line 172, in wrapped_func
    return func(*args, **kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/transformers/models/gemma2/modeling_gemma2.py", line 887, in forward
    outputs = self.model(
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/transformers/models/gemma2/modeling_gemma2.py", line 612, in forward
    cache_position = torch.arange(

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

I then tried adding dynamo-related settings such as torch._dynamo.config.inline_inbuilt_nn_modules = True, torch._dynamo.config.capture_scalar_outputs = True, and torch._dynamo.config.suppress_errors = True (see the snippet below), but the error remains a similar dynamo issue.
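
Concretely, the settings I tried look like this at the top of the script (these are exactly the flags mentioned above):

import torch._dynamo

torch._dynamo.config.inline_inbuilt_nn_modules = True
torch._dynamo.config.capture_scalar_outputs = True
torch._dynamo.config.suppress_errors = True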

Can you please guide me on this? How should I proceed from here regarding the compilation?

krishnakanthnakkav2 avatar Apr 14 '25 09:04 krishnakanthnakkav2

@krishnakanthnakkav2 Could you please try again with the latest changes on this branch?
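
In case it helps, the branch can be installed directly from git (a sketch; <branch-name> is a placeholder for the actual branch of this PR):

pip install -U git+https://github.com/huggingface/peft.git@<branch-name>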

BenjaminBossan avatar Apr 14 '25 11:04 BenjaminBossan

I just tried again.

The error trace is very similar:

Traceback (most recent call last):
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/utils.py", line 2586, in run_node
    return node.target(*args, **kwargs)
           ~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/utils/_stats.py", line 21, in wrapper
    return fn(*args, **kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_subclasses/fake_tensor.py", line 1276, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
           ~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_subclasses/fake_tensor.py", line 1816, in dispatch
    return self._cached_dispatch_impl(func, types, args, kwargs)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_subclasses/fake_tensor.py", line 1386, in _cached_dispatch_impl
    output = self._dispatch_impl(func, types, args, kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_subclasses/fake_tensor.py", line 2354, in _dispatch_impl
    op_impl_out = op_impl(self, func, *args, **kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_subclasses/fake_impls.py", line 160, in dispatch_to_op_implementations_dict
    return op_implementations_dict[func](fake_mode, func, *args, **kwargs)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_subclasses/fake_impls.py", line 399, in local_scalar_dense
    raise DataDependentOutputException(func)
torch._subclasses.fake_tensor.DataDependentOutputException: aten._local_scalar_dense.default

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/utils.py", line 2471, in get_fake_value
    ret_val = wrap_fake_exception(
        lambda: run_node(tx.output, node, args, kwargs, nnmodule)
    )
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/utils.py", line 2017, in wrap_fake_exception
    return fn()
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/utils.py", line 2472, in <lambda>
    lambda: run_node(tx.output, node, args, kwargs, nnmodule)
            ~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/utils.py", line 2604, in run_node
    raise RuntimeError(make_error_message(e)).with_traceback(
        e.__traceback__
    ) from e
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/utils.py", line 2586, in run_node
    return node.target(*args, **kwargs)
           ~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/utils/_stats.py", line 21, in wrapper
    return fn(*args, **kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_subclasses/fake_tensor.py", line 1276, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
           ~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_subclasses/fake_tensor.py", line 1816, in dispatch
    return self._cached_dispatch_impl(func, types, args, kwargs)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_subclasses/fake_tensor.py", line 1386, in _cached_dispatch_impl
    output = self._dispatch_impl(func, types, args, kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_subclasses/fake_tensor.py", line 2354, in _dispatch_impl
    op_impl_out = op_impl(self, func, *args, **kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_subclasses/fake_impls.py", line 160, in dispatch_to_op_implementations_dict
    return op_implementations_dict[func](fake_mode, func, *args, **kwargs)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_subclasses/fake_impls.py", line 399, in local_scalar_dense
    raise DataDependentOutputException(func)
RuntimeError: Failed running call_function <built-in method arange of type object at 0x7bf40539af60>(*(FakeTensor(..., device='cuda:0', size=(), dtype=torch.int64), FakeTensor(..., device='cuda:0', size=(), dtype=torch.int64)), **{'device': device(type='cuda', index=0)}):
aten._local_scalar_dense.default

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/krishna/PII/cs-llm-temp/misc.py", line 63, in <module>
    generated_output = model.generate(
        input_ids=inputs["input_ids"].cuda(),
        attention_mask=inputs["attention_mask"].cuda(),
        max_length=60
    )
  File "/home/krishna/PII/cs-llm-temp/libs/peftgemmav2/peft/src/peft/peft_model.py", line 1935, in generate
    outputs = self.base_model.generate(**kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/transformers/generation/utils.py", line 2223, in generate
    result = self._sample(
        input_ids,
    ...<5 lines>...
        **model_kwargs,
    )
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/transformers/generation/utils.py", line 3214, in _sample
    outputs = model_forward(**model_inputs, return_dict=True)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/eval_frame.py", line 574, in _fn
    return fn(*args, **kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/convert_frame.py", line 1380, in __call__
    return self._torchdynamo_orig_callable(
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^
        frame, cache_entry, self.hooks, frame_state, skip=1
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/convert_frame.py", line 547, in __call__
    return _compile(
        frame.f_code,
    ...<14 lines>...
        skip=skip + 1,
    )
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/convert_frame.py", line 986, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/convert_frame.py", line 715, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_utils_internal.py", line 95, in wrapper_function
    return function(*args, **kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/convert_frame.py", line 750, in _compile_inner
    out_code = transform_code_object(code, transform)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/bytecode_transformation.py", line 1361, in transform_code_object
    transformations(instructions, code_options)
    ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/convert_frame.py", line 231, in _fn
    return fn(*args, **kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/convert_frame.py", line 662, in transform
    tracer.run()
    ~~~~~~~~~~^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 2868, in run
    super().run()
    ~~~~~~~~~~~^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
    while self.step():
          ~~~~~~~~~^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
    self.dispatch_table[inst.opcode](self, inst)
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 659, in wrapper
    return inner_fn(self, inst)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 1736, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
    ~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 897, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
              ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/variables/functions.py", line 378, in call_function
    return super().call_function(tx, args, kwargs)
           ~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/variables/functions.py", line 317, in call_function
    return super().call_function(tx, args, kwargs)
           ~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/variables/functions.py", line 118, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 903, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 3072, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
           ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 3198, in inline_call_
    tracer.run()
    ~~~~~~~~~~^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
    while self.step():
          ~~~~~~~~~^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
    self.dispatch_table[inst.opcode](self, inst)
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 659, in wrapper
    return inner_fn(self, inst)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 1736, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
    ~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 897, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
              ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/variables/functions.py", line 317, in call_function
    return super().call_function(tx, args, kwargs)
           ~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/variables/functions.py", line 118, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 903, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 3072, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
           ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 3198, in inline_call_
    tracer.run()
    ~~~~~~~~~~^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
    while self.step():
          ~~~~~~~~~^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
    self.dispatch_table[inst.opcode](self, inst)
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 659, in wrapper
    return inner_fn(self, inst)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 1736, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
    ~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 897, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
              ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/variables/lazy.py", line 170, in realize_and_forward
    return getattr(self.realize(), name)(*args, **kwargs)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/variables/nn_module.py", line 443, in call_function
    return tx.inline_user_function_return(
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^
        variables.UserFunctionVariable(fn, source=fn_source),
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
        args,
        ^^^^^
        kwargs,
        ^^^^^^^
    )
    ^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 903, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 3072, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
           ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 3198, in inline_call_
    tracer.run()
    ~~~~~~~~~~^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
    while self.step():
          ~~~~~~~~~^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
    self.dispatch_table[inst.opcode](self, inst)
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 659, in wrapper
    return inner_fn(self, inst)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 1736, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
    ~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 897, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
              ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/variables/functions.py", line 378, in call_function
    return super().call_function(tx, args, kwargs)
           ~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/variables/functions.py", line 317, in call_function
    return super().call_function(tx, args, kwargs)
           ~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/variables/functions.py", line 118, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 903, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 3072, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
           ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 3198, in inline_call_
    tracer.run()
    ~~~~~~~~~~^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
    while self.step():
          ~~~~~~~~~^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
    self.dispatch_table[inst.opcode](self, inst)
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 659, in wrapper
    return inner_fn(self, inst)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 2482, in CALL_KW
    self._call(inst, call_kw=True)
    ~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 2335, in _call
    self.call_function(fn, args, kwargs)
    ~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 897, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
              ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/variables/torch.py", line 953, in call_function
    tensor_variable = wrap_fx_proxy(
        tx=tx,
    ...<4 lines>...
        ),
    )
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/variables/builder.py", line 2153, in wrap_fx_proxy
    return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/variables/builder.py", line 2219, in wrap_fx_proxy_cls
    return _wrap_fx_proxy(
        target_cls, tx, proxy, example_value, subclass_type, **options
    )
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/variables/builder.py", line 2315, in _wrap_fx_proxy
    example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/utils.py", line 2484, in get_fake_value
    unimplemented(
    ~~~~~~~~~~~~~^
        f"data dependent operator: {cause.func}; "
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
        "to enable, set torch._dynamo.config.capture_scalar_outputs = True"
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/_dynamo/exc.py", line 317, in unimplemented
    raise Unsupported(msg, case_name=case_name)
torch._dynamo.exc.Unsupported: data dependent operator: aten._local_scalar_dense.default; to enable, set torch._dynamo.config.capture_scalar_outputs = True

from user code:
   File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/accelerate/hooks.py", line 176, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/transformers/utils/deprecation.py", line 172, in wrapped_func
    return func(*args, **kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/transformers/models/gemma2/modeling_gemma2.py", line 887, in forward
    outputs = self.model(
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v4/lib/python3.13/site-packages/transformers/models/gemma2/modeling_gemma2.py", line 612, in forward
    cache_position = torch.arange(

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

Would you mind sharing your exact versions of the important libraries (Python, torch, transformers, accelerate, ...)? I will try to reproduce this in a different environment.
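
For example, the output of a quick one-liner like this would already help (just a convenience snippet):

python -c "import sys, torch, transformers, peft, accelerate; print(sys.version); print(torch.__version__, transformers.__version__, peft.__version__, accelerate.__version__)"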

krishnakanthnakkav2 avatar Apr 14 '25 13:04 krishnakanthnakkav2

FYI, my conda environment is:


# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                        main
_openmp_mutex             5.1                       1_gnu
accelerate                1.5.2                    pypi_0    pypi
aiohappyeyeballs          2.6.1                    pypi_0    pypi
aiohttp                   3.11.14                  pypi_0    pypi
aiosignal                 1.3.2                    pypi_0    pypi
annotated-types           0.7.0                    pypi_0    pypi
asttokens                 3.0.0              pyhd8ed1ab_1    conda-forge
attrs                     25.3.0                   pypi_0    pypi
av                        14.3.0                   pypi_0    pypi
blobfile                  3.0.0                    pypi_0    pypi
bzip2                     1.0.8                h5eee18b_6
ca-certificates           2025.2.25            h06a4308_0
certifi                   2025.1.31                pypi_0    pypi
charset-normalizer        3.4.1                    pypi_0    pypi
click                     8.1.8                    pypi_0    pypi
colorama                  0.4.6                    pypi_0    pypi
comm                      0.2.2              pyhd8ed1ab_1    conda-forge
configspace               1.2.1                    pypi_0    pypi
contourpy                 1.3.1                    pypi_0    pypi
cycler                    0.12.1                   pypi_0    pypi
datasets                  3.4.1                    pypi_0    pypi
debugpy                   1.8.13          py313h46c70d0_0    conda-forge
decorator                 5.2.1              pyhd8ed1ab_0    conda-forge
dill                      0.3.8                    pypi_0    pypi
einops                    0.8.1                    pypi_0    pypi
exceptiongroup            1.2.2              pyhd8ed1ab_1    conda-forge
executing                 2.2.0                    pypi_0    pypi
expat                     2.6.4                h6a678d5_0
filelock                  3.18.0                   pypi_0    pypi
filetype                  1.2.0                    pypi_0    pypi
flash-attn                2.7.4.post1              pypi_0    pypi
fonttools                 4.56.0                   pypi_0    pypi
frozenlist                1.5.0                    pypi_0    pypi
fsspec                    2024.12.0                pypi_0    pypi
fuzzywuzzy                0.18.0                   pypi_0    pypi
grpcio                    1.71.0                   pypi_0    pypi
huggingface-hub           0.29.3                   pypi_0    pypi
hurry                     1.1                      pypi_0    pypi
hurry-filesize            0.9                      pypi_0    pypi
icecream                  2.1.4                    pypi_0    pypi
idna                      3.10                     pypi_0    pypi
importlib-metadata        8.6.1              pyha770c72_0    conda-forge
ipykernel                 6.29.5             pyh3099207_0    conda-forge
ipython                   9.0.2              pyhfb0248b_0    conda-forge
ipython_pygments_lexers   1.1.1              pyhd8ed1ab_0    conda-forge
jedi                      0.19.2             pyhd8ed1ab_1    conda-forge
jinja2                    3.1.6                    pypi_0    pypi
joblib                    1.4.2                    pypi_0    pypi
jsonschema                4.23.0                   pypi_0    pypi
jsonschema-specifications 2024.10.1                pypi_0    pypi
jupyter_client            8.6.3              pyhd8ed1ab_1    conda-forge
jupyter_core              5.7.2              pyh31011fe_1    conda-forge
kiwisolver                1.4.8                    pypi_0    pypi
krb5                      1.21.3               h143b758_0
ld_impl_linux-64          2.40                 h12ee557_0
libedit                   3.1.20230828         h5eee18b_0
libffi                    3.4.4                h6a678d5_1
libgcc                    14.2.0               h767d61c_2    conda-forge
libgcc-ng                 14.2.0               h69a702a_2    conda-forge
libgomp                   14.2.0               h767d61c_2    conda-forge
libmpdec                  4.0.0                h5eee18b_0
libsodium                 1.0.20               h4ab18f5_0    conda-forge
libstdcxx                 14.2.0               h8f9b012_2    conda-forge
libstdcxx-ng              11.2.0               h1234567_1
libuuid                   1.41.5               h5eee18b_0
lxml                      5.3.1                    pypi_0    pypi
markupsafe                3.0.2                    pypi_0    pypi
matplotlib                3.10.1                   pypi_0    pypi
matplotlib-inline         0.1.7              pyhd8ed1ab_1    conda-forge
more-itertools            10.6.0                   pypi_0    pypi
mpmath                    1.3.0                    pypi_0    pypi
msgpack                   1.1.0                    pypi_0    pypi
multidict                 6.2.0                    pypi_0    pypi
multiprocess              0.70.16                  pypi_0    pypi
ncurses                   6.4                  h6a678d5_0
nest-asyncio              1.6.0              pyhd8ed1ab_1    conda-forge
networkx                  3.4.2                    pypi_0    pypi
numpy                     2.2.4                    pypi_0    pypi
nvidia-cublas-cu12        12.4.5.8                 pypi_0    pypi
nvidia-cuda-cupti-cu12    12.4.127                 pypi_0    pypi
nvidia-cuda-nvrtc-cu12    12.4.127                 pypi_0    pypi
nvidia-cuda-runtime-cu12  12.4.127                 pypi_0    pypi
nvidia-cudnn-cu12         9.1.0.70                 pypi_0    pypi
nvidia-cufft-cu12         11.2.1.3                 pypi_0    pypi
nvidia-curand-cu12        10.3.5.147               pypi_0    pypi
nvidia-cusolver-cu12      11.6.1.9                 pypi_0    pypi
nvidia-cusparse-cu12      12.3.1.170               pypi_0    pypi
nvidia-cusparselt-cu12    0.6.2                    pypi_0    pypi
nvidia-htop               1.2.0                    pypi_0    pypi
nvidia-nccl-cu12          2.21.5                   pypi_0    pypi
nvidia-nvjitlink-cu12     12.4.127                 pypi_0    pypi
nvidia-nvtx-cu12          12.4.127                 pypi_0    pypi
opencv-python-headless    4.11.0.86                pypi_0    pypi
openssl                   3.4.1                h7b32b05_0    conda-forge
packaging                 24.2               pyhd8ed1ab_2    conda-forge
pandas                    2.2.3                    pypi_0    pypi
parso                     0.8.4              pyhd8ed1ab_1    conda-forge
peft                      0.14.0                   pypi_0    pypi
pexpect                   4.9.0              pyhd8ed1ab_1    conda-forge
pickleshare               0.7.5           pyhd8ed1ab_1004    conda-forge
pillow                    10.4.0                   pypi_0    pypi
pip                       25.0            py313h06a4308_0
platformdirs              4.3.7              pyh29332c3_0    conda-forge
prompt-toolkit            3.0.50             pyha770c72_0    conda-forge
propcache                 0.3.1                    pypi_0    pypi
protobuf                  6.30.1                   pypi_0    pypi
psutil                    7.0.0           py313h536fd9c_0    conda-forge
ptyprocess                0.7.0              pyhd8ed1ab_1    conda-forge
pure_eval                 0.2.3              pyhd8ed1ab_1    conda-forge
pyarrow                   19.0.1                   pypi_0    pypi
pycryptodomex             3.22.0                   pypi_0    pypi
pydantic                  2.11.3                   pypi_0    pypi
pydantic-core             2.33.1                   pypi_0    pypi
pydantic-settings         2.8.1                    pypi_0    pypi
pygments                  2.19.1             pyhd8ed1ab_0    conda-forge
pympler                   1.1                      pypi_0    pypi
pyparsing                 3.2.3                    pypi_0    pypi
pypdfium2                 4.30.0                   pypi_0    pypi
python                    3.13.2          hf623796_100_cp313
python-dateutil           2.9.0.post0        pyhff2d567_1    conda-forge
python-dotenv             1.1.0                    pypi_0    pypi
python_abi                3.13                    0_cp313
pytz                      2025.2                   pypi_0    pypi
pyyaml                    6.0.2                    pypi_0    pypi
pyzmq                     26.3.0          py313h8e95178_0    conda-forge
qwen-vl-utils             0.0.10                   pypi_0    pypi
ray                       3.0.0.dev0               pypi_0    pypi
readline                  8.2                  h5eee18b_0
referencing               0.36.2                   pypi_0    pypi
regex                     2024.11.6                pypi_0    pypi
requests                  2.32.3                   pypi_0    pypi
rpds-py                   0.24.0                   pypi_0    pypi
safetensors               0.5.3                    pypi_0    pypi
scikit-learn              1.6.1                    pypi_0    pypi
scipy                     1.15.2                   pypi_0    pypi
setuptools                75.8.0          py313h06a4308_0
six                       1.17.0             pyhd8ed1ab_0    conda-forge
sqlite                    3.45.3               h5eee18b_0
stack_data                0.6.3              pyhd8ed1ab_1    conda-forge
surya-ocr                 0.13.1                   pypi_0    pypi
sympy                     1.13.1                   pypi_0    pypi
tensorboardx              2.6.2.2                  pypi_0    pypi
termcolor                 2.5.0                    pypi_0    pypi
threadpoolctl             3.6.0                    pypi_0    pypi
tiktoken                  0.8.0                    pypi_0    pypi
tk                        8.6.14               h39e8969_0
tokenizers                0.21.1                   pypi_0    pypi
torch                     2.6.0                    pypi_0    pypi
torchsummary              1.5.1                    pypi_0    pypi
torchvision               0.21.0                   pypi_0    pypi
tornado                   6.4.2           py313h536fd9c_0    conda-forge
tqdm                      4.67.1                   pypi_0    pypi
traitlets                 5.14.3             pyhd8ed1ab_1    conda-forge
transformers              4.49.0                   pypi_0    pypi
triton                    3.2.0                    pypi_0    pypi
typing-inspection         0.4.0                    pypi_0    pypi
typing_extensions         4.13.0             pyh29332c3_1    conda-forge
tzdata                    2025.2                   pypi_0    pypi
urllib3                   2.3.0                    pypi_0    pypi
wcwidth                   0.2.13             pyhd8ed1ab_1    conda-forge
wheel                     0.45.1          py313h06a4308_0
xxhash                    3.5.0                    pypi_0    pypi
xz                        5.6.4                h5eee18b_1
yarl                      1.18.3                   pypi_0    pypi
zeromq                    4.3.5                h3b0a872_7    conda-forge
zipp                      3.21.0             pyhd8ed1ab_1    conda-forge
zlib                      1.2.13               h5eee18b_1


krishnakanthnakkav2 avatar Apr 14 '25 13:04 krishnakanthnakkav2

@krishnakanthnakkav2 Your code snippet worked for me (I only changed the cache_dir). What I noticed with your env is that it shows PEFT v0.14.0. This should not be the case if you're installing from this branch. Could you please try:

python -m pip install -U git+https://github.com/BenjaminBossan/peft.git@fix-prompt-tuning-4d-attention-mask

If this succeeds, the PEFT version should be v0.15.2.dev0.
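
A quick, generic way to confirm which version actually gets imported (nothing specific to this branch, just a standard check) is:

python -c "import peft; print(peft.__version__)"

If that still prints 0.14.0, the site-packages copy is the one being picked up.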

BenjaminBossan avatar Apr 14 '25 13:04 BenjaminBossan

True, but I made sure I was using this branch via PYTHONPATH. v0.14.0 was installed inside the conda environment and lives in /miniconda3/<envname>/xx/site-packages/peft. I didn't want to override that since it was working great for other models. What I did was clone this branch, put it locally next to the Python script, and prepend it to PYTHONPATH.
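
(For reference, a quick way to see which copy Python resolves when both the clone on PYTHONPATH and the installed package are present is to print the module location, e.g.

python -c "import peft; print(peft.__file__, peft.__version__)"

which should point at the local clone rather than site-packages if the PYTHONPATH approach is taking effect.)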

Anyway, I have now installed it directly into the environment by running python -m pip install -U git+https://github.com/BenjaminBossan/peft.git@fix-prompt-tuning-4d-attention-mask, and the conda env is:

# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                        main
_openmp_mutex             5.1                       1_gnu
accelerate                1.5.2                    pypi_0    pypi
aiohappyeyeballs          2.6.1                    pypi_0    pypi
aiohttp                   3.11.14                  pypi_0    pypi
aiosignal                 1.3.2                    pypi_0    pypi
annotated-types           0.7.0                    pypi_0    pypi
asttokens                 3.0.0              pyhd8ed1ab_1    conda-forge
attrs                     25.3.0                   pypi_0    pypi
av                        14.3.0                   pypi_0    pypi
blobfile                  3.0.0                    pypi_0    pypi
bzip2                     1.0.8                h5eee18b_6
ca-certificates           2025.2.25            h06a4308_0
certifi                   2025.1.31                pypi_0    pypi
charset-normalizer        3.4.1                    pypi_0    pypi
click                     8.1.8                    pypi_0    pypi
colorama                  0.4.6                    pypi_0    pypi
comm                      0.2.2              pyhd8ed1ab_1    conda-forge
configspace               1.2.1                    pypi_0    pypi
contourpy                 1.3.1                    pypi_0    pypi
cycler                    0.12.1                   pypi_0    pypi
datasets                  3.4.1                    pypi_0    pypi
debugpy                   1.8.13          py313h46c70d0_0    conda-forge
decorator                 5.2.1              pyhd8ed1ab_0    conda-forge
dill                      0.3.8                    pypi_0    pypi
einops                    0.8.1                    pypi_0    pypi
exceptiongroup            1.2.2              pyhd8ed1ab_1    conda-forge
executing                 2.2.0                    pypi_0    pypi
expat                     2.6.4                h6a678d5_0
filelock                  3.18.0                   pypi_0    pypi
filetype                  1.2.0                    pypi_0    pypi
flash-attn                2.7.4.post1              pypi_0    pypi
fonttools                 4.56.0                   pypi_0    pypi
frozenlist                1.5.0                    pypi_0    pypi
fsspec                    2024.12.0                pypi_0    pypi
fuzzywuzzy                0.18.0                   pypi_0    pypi
grpcio                    1.71.0                   pypi_0    pypi
huggingface-hub           0.29.3                   pypi_0    pypi
hurry                     1.1                      pypi_0    pypi
hurry-filesize            0.9                      pypi_0    pypi
icecream                  2.1.4                    pypi_0    pypi
idna                      3.10                     pypi_0    pypi
importlib-metadata        8.6.1              pyha770c72_0    conda-forge
ipykernel                 6.29.5             pyh3099207_0    conda-forge
ipython                   9.0.2              pyhfb0248b_0    conda-forge
ipython_pygments_lexers   1.1.1              pyhd8ed1ab_0    conda-forge
jedi                      0.19.2             pyhd8ed1ab_1    conda-forge
jinja2                    3.1.6                    pypi_0    pypi
joblib                    1.4.2                    pypi_0    pypi
jsonschema                4.23.0                   pypi_0    pypi
jsonschema-specifications 2024.10.1                pypi_0    pypi
jupyter_client            8.6.3              pyhd8ed1ab_1    conda-forge
jupyter_core              5.7.2              pyh31011fe_1    conda-forge
kiwisolver                1.4.8                    pypi_0    pypi
krb5                      1.21.3               h143b758_0
ld_impl_linux-64          2.40                 h12ee557_0
libedit                   3.1.20230828         h5eee18b_0
libffi                    3.4.4                h6a678d5_1
libgcc                    14.2.0               h767d61c_2    conda-forge
libgcc-ng                 14.2.0               h69a702a_2    conda-forge
libgomp                   14.2.0               h767d61c_2    conda-forge
libmpdec                  4.0.0                h5eee18b_0
libsodium                 1.0.20               h4ab18f5_0    conda-forge
libstdcxx                 14.2.0               h8f9b012_2    conda-forge
libstdcxx-ng              11.2.0               h1234567_1
libuuid                   1.41.5               h5eee18b_0
lxml                      5.3.1                    pypi_0    pypi
markupsafe                3.0.2                    pypi_0    pypi
matplotlib                3.10.1                   pypi_0    pypi
matplotlib-inline         0.1.7              pyhd8ed1ab_1    conda-forge
more-itertools            10.6.0                   pypi_0    pypi
mpmath                    1.3.0                    pypi_0    pypi
msgpack                   1.1.0                    pypi_0    pypi
multidict                 6.2.0                    pypi_0    pypi
multiprocess              0.70.16                  pypi_0    pypi
ncurses                   6.4                  h6a678d5_0
nest-asyncio              1.6.0              pyhd8ed1ab_1    conda-forge
networkx                  3.4.2                    pypi_0    pypi
numpy                     2.2.4                    pypi_0    pypi
nvidia-cublas-cu12        12.4.5.8                 pypi_0    pypi
nvidia-cuda-cupti-cu12    12.4.127                 pypi_0    pypi
nvidia-cuda-nvrtc-cu12    12.4.127                 pypi_0    pypi
nvidia-cuda-runtime-cu12  12.4.127                 pypi_0    pypi
nvidia-cudnn-cu12         9.1.0.70                 pypi_0    pypi
nvidia-cufft-cu12         11.2.1.3                 pypi_0    pypi
nvidia-curand-cu12        10.3.5.147               pypi_0    pypi
nvidia-cusolver-cu12      11.6.1.9                 pypi_0    pypi
nvidia-cusparse-cu12      12.3.1.170               pypi_0    pypi
nvidia-cusparselt-cu12    0.6.2                    pypi_0    pypi
nvidia-htop               1.2.0                    pypi_0    pypi
nvidia-nccl-cu12          2.21.5                   pypi_0    pypi
nvidia-nvjitlink-cu12     12.4.127                 pypi_0    pypi
nvidia-nvtx-cu12          12.4.127                 pypi_0    pypi
opencv-python-headless    4.11.0.86                pypi_0    pypi
openssl                   3.4.1                h7b32b05_0    conda-forge
packaging                 24.2               pyhd8ed1ab_2    conda-forge
pandas                    2.2.3                    pypi_0    pypi
parso                     0.8.4              pyhd8ed1ab_1    conda-forge
peft                      0.15.2.dev0              pypi_0    pypi
pexpect                   4.9.0              pyhd8ed1ab_1    conda-forge
pickleshare               0.7.5           pyhd8ed1ab_1004    conda-forge
pillow                    10.4.0                   pypi_0    pypi
pip                       25.0            py313h06a4308_0
platformdirs              4.3.7              pyh29332c3_0    conda-forge
prompt-toolkit            3.0.50             pyha770c72_0    conda-forge
propcache                 0.3.1                    pypi_0    pypi
protobuf                  6.30.1                   pypi_0    pypi
psutil                    7.0.0           py313h536fd9c_0    conda-forge
ptyprocess                0.7.0              pyhd8ed1ab_1    conda-forge
pure_eval                 0.2.3              pyhd8ed1ab_1    conda-forge
pyarrow                   19.0.1                   pypi_0    pypi
pycryptodomex             3.22.0                   pypi_0    pypi
pydantic                  2.11.3                   pypi_0    pypi
pydantic-core             2.33.1                   pypi_0    pypi
pydantic-settings         2.8.1                    pypi_0    pypi
pygments                  2.19.1             pyhd8ed1ab_0    conda-forge
pympler                   1.1                      pypi_0    pypi
pyparsing                 3.2.3                    pypi_0    pypi
pypdfium2                 4.30.0                   pypi_0    pypi
python                    3.13.2          hf623796_100_cp313
python-dateutil           2.9.0.post0        pyhff2d567_1    conda-forge
python-dotenv             1.1.0                    pypi_0    pypi
python_abi                3.13                    0_cp313
pytz                      2025.2                   pypi_0    pypi
pyyaml                    6.0.2                    pypi_0    pypi
pyzmq                     26.3.0          py313h8e95178_0    conda-forge
qwen-vl-utils             0.0.10                   pypi_0    pypi
ray                       3.0.0.dev0               pypi_0    pypi
readline                  8.2                  h5eee18b_0
referencing               0.36.2                   pypi_0    pypi
regex                     2024.11.6                pypi_0    pypi
requests                  2.32.3                   pypi_0    pypi
rpds-py                   0.24.0                   pypi_0    pypi
safetensors               0.5.3                    pypi_0    pypi
scikit-learn              1.6.1                    pypi_0    pypi
scipy                     1.15.2                   pypi_0    pypi
setuptools                75.8.0          py313h06a4308_0
six                       1.17.0             pyhd8ed1ab_0    conda-forge
sqlite                    3.45.3               h5eee18b_0
stack_data                0.6.3              pyhd8ed1ab_1    conda-forge
surya-ocr                 0.13.1                   pypi_0    pypi
sympy                     1.13.1                   pypi_0    pypi
tensorboardx              2.6.2.2                  pypi_0    pypi
termcolor                 2.5.0                    pypi_0    pypi
threadpoolctl             3.6.0                    pypi_0    pypi
tiktoken                  0.8.0                    pypi_0    pypi
tk                        8.6.14               h39e8969_0
tokenizers                0.21.1                   pypi_0    pypi
torch                     2.6.0                    pypi_0    pypi
torchsummary              1.5.1                    pypi_0    pypi
torchvision               0.21.0                   pypi_0    pypi
tornado                   6.4.2           py313h536fd9c_0    conda-forge
tqdm                      4.67.1                   pypi_0    pypi
traitlets                 5.14.3             pyhd8ed1ab_1    conda-forge
transformers              4.49.0                   pypi_0    pypi
triton                    3.2.0                    pypi_0    pypi
typing-inspection         0.4.0                    pypi_0    pypi
typing_extensions         4.13.0             pyh29332c3_1    conda-forge
tzdata                    2025.2                   pypi_0    pypi
urllib3                   2.3.0                    pypi_0    pypi
wcwidth                   0.2.13             pyhd8ed1ab_1    conda-forge
wheel                     0.45.1          py313h06a4308_0
xxhash                    3.5.0                    pypi_0    pypi
xz                        5.6.4                h5eee18b_1
yarl                      1.18.3                   pypi_0    pypi
zeromq                    4.3.5                h3b0a872_7    conda-forge
zipp                      3.21.0             pyhd8ed1ab_1    conda-forge
zlib                      1.2.13               h5eee18b_1

and it throws the same error:

  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/_dynamo/utils.py", line 2586, in run_node
    return node.target(*args, **kwargs)
           ~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/utils/_stats.py", line 21, in wrapper
    return fn(*args, **kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/_subclasses/fake_tensor.py", line 1276, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
           ~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/_subclasses/fake_tensor.py", line 1816, in dispatch
    return self._cached_dispatch_impl(func, types, args, kwargs)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/_subclasses/fake_tensor.py", line 1386, in _cached_dispatch_impl
    output = self._dispatch_impl(func, types, args, kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/_subclasses/fake_tensor.py", line 2354, in _dispatch_impl
    op_impl_out = op_impl(self, func, *args, **kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/_subclasses/fake_impls.py", line 160, in dispatch_to_op_implementations_dict
    return op_implementations_dict[func](fake_mode, func, *args, **kwargs)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/_subclasses/fake_impls.py", line 399, in local_scalar_dense
    raise DataDependentOutputException(func)
torch._subclasses.fake_tensor.DataDependentOutputException: aten._local_scalar_dense.default

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/_dynamo/utils.py", line 2471, in get_fake_value
    ret_val = wrap_fake_exception(
        lambda: run_node(tx.output, node, args, kwargs, nnmodule)
    )
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/_dynamo/utils.py", line 2017, in wrap_fake_exception
    return fn()
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/_dynamo/utils.py", line 2472, in <lambda>
    lambda: run_node(tx.output, node, args, kwargs, nnmodule)
            ~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/_dynamo/utils.py", line 2604, in run_node
    raise RuntimeError(make_error_message(e)).with_traceback(
        e.__traceback__
    ) from e
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/_dynamo/utils.py", line 2586, in run_node
    return node.target(*args, **kwargs)
           ~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/utils/_stats.py", line 21, in wrapper
    return fn(*args, **kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/_subclasses/fake_tensor.py", line 1276, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
           ~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/_subclasses/fake_tensor.py", line 1816, in dispatch
    return self._cached_dispatch_impl(func, types, args, kwargs)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/_subclasses/fake_tensor.py", line 1386, in _cached_dispatch_impl
    output = self._dispatch_impl(func, types, args, kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/_subclasses/fake_tensor.py", line 2354, in _dispatch_impl
    op_impl_out = op_impl(self, func, *args, **kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/_subclasses/fake_impls.py", line 160, in dispatch_to_op_implementations_dict
    return op_implementations_dict[func](fake_mode, func, *args, **kwargs)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/_subclasses/fake_impls.py", line 399, in local_scalar_dense
    raise DataDependentOutputException(func)
RuntimeError: Failed running call_function <built-in method arange of type object at 0x7f88faf9af60>(*(FakeTensor(..., device='cuda:0', size=(), dtype=torch.int64), FakeTensor(..., device='cuda:0', size=(), dtype=torch.int64)), **{'device': device(type='cuda', index=0)}):
aten._local_scalar_dense.default

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/krishna/PII/cs-llm-temp/misc.py", line 63, in <module>
    generated_output = model.generate(
        input_ids=inputs["input_ids"].cuda(),
        attention_mask=inputs["attention_mask"].cuda(),
        max_length=60
    )
  File "/home/krishna/PII/cs-llm-temp/libs/peftgemmav2/peft/src/peft/peft_model.py", line 1935, in generate
    outputs = self.base_model.generate(**kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/transformers/generation/utils.py", line 2223, in generate
    result = self._sample(
        input_ids,
    ...<5 lines>...
        **model_kwargs,
    )
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/transformers/generation/utils.py", line 3214, in _sample
    outputs = model_forward(**model_inputs, return_dict=True)
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/_dynamo/eval_frame.py", line 574, in _fn
    return fn(*args, **kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/_dynamo/convert_frame.py", line 1380, in __call__
    return self._torchdynamo_orig_callable(
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^
        frame, cache_entry, self.hooks, frame_state, skip=1
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/_dynamo/convert_frame.py", line 547, in __call__
    return _compile(
        frame.f_code,
    ...<14 lines>...
        skip=skip + 1,
    )
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/_dynamo/convert_frame.py", line 986, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/_dynamo/convert_frame.py", line 715, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/_utils_internal.py", line 95, in wrapper_function
    return function(*args, **kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/_dynamo/convert_frame.py", line 750, in _compile_inner
    out_code = transform_code_object(code, transform)
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/_dynamo/bytecode_transformation.py", line 1361, in transform_code_object
    transformations(instructions, code_options)
    ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/_dynamo/convert_frame.py", line 231, in _fn
    return fn(*args, **kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/_dynamo/convert_frame.py", line 662, in transform
    tracer.run()
    ~~~~~~~~~~^^
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 2868, in run
    super().run()
    ~~~~~~~~~~~^^
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
    while self.step():
          ~~~~~~~~~^^
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
    self.dispatch_table[inst.opcode](self, inst)
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 659, in wrapper
    return inner_fn(self, inst)
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 1736, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
    ~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 897, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
              ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/_dynamo/variables/functions.py", line 378, in call_function
    return super().call_function(tx, args, kwargs)
           ~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/_dynamo/variables/functions.py", line 317, in call_function
    return super().call_function(tx, args, kwargs)
           ~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/_dynamo/variables/functions.py", line 118, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 903, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 3072, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
           ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 3198, in inline_call_
    tracer.run()
    ~~~~~~~~~~^^
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
    while self.step():
          ~~~~~~~~~^^
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
    self.dispatch_table[inst.opcode](self, inst)
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 659, in wrapper
    return inner_fn(self, inst)
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 1736, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
    ~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 897, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
              ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/_dynamo/variables/functions.py", line 317, in call_function
    return super().call_function(tx, args, kwargs)
           ~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/_dynamo/variables/functions.py", line 118, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 903, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 3072, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
           ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 3198, in inline_call_
    tracer.run()
    ~~~~~~~~~~^^
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
    while self.step():
          ~~~~~~~~~^^
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
    self.dispatch_table[inst.opcode](self, inst)
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 659, in wrapper
    return inner_fn(self, inst)
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 1736, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
    ~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 897, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
              ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/_dynamo/variables/lazy.py", line 170, in realize_and_forward
    return getattr(self.realize(), name)(*args, **kwargs)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/_dynamo/variables/nn_module.py", line 443, in call_function
    return tx.inline_user_function_return(
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^
        variables.UserFunctionVariable(fn, source=fn_source),
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
        args,
        ^^^^^
        kwargs,
        ^^^^^^^
    )
    ^
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 903, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 3072, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
           ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 3198, in inline_call_
    tracer.run()
    ~~~~~~~~~~^^
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
    while self.step():
          ~~~~~~~~~^^
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
    self.dispatch_table[inst.opcode](self, inst)
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 659, in wrapper
    return inner_fn(self, inst)
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 1736, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
    ~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 897, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
              ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/_dynamo/variables/functions.py", line 378, in call_function
    return super().call_function(tx, args, kwargs)
           ~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/_dynamo/variables/functions.py", line 317, in call_function
    return super().call_function(tx, args, kwargs)
           ~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/_dynamo/variables/functions.py", line 118, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 903, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 3072, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
           ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 3198, in inline_call_
    tracer.run()
    ~~~~~~~~~~^^
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
    while self.step():
          ~~~~~~~~~^^
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
    self.dispatch_table[inst.opcode](self, inst)
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 659, in wrapper
    return inner_fn(self, inst)
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 2482, in CALL_KW
    self._call(inst, call_kw=True)
    ~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 2335, in _call
    self.call_function(fn, args, kwargs)
    ~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 897, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
              ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/_dynamo/variables/torch.py", line 953, in call_function
    tensor_variable = wrap_fx_proxy(
        tx=tx,
    ...<4 lines>...
        ),
    )
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/_dynamo/variables/builder.py", line 2153, in wrap_fx_proxy
    return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/_dynamo/variables/builder.py", line 2219, in wrap_fx_proxy_cls
    return _wrap_fx_proxy(
        target_cls, tx, proxy, example_value, subclass_type, **options
    )
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/_dynamo/variables/builder.py", line 2315, in _wrap_fx_proxy
    example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/_dynamo/utils.py", line 2484, in get_fake_value
    unimplemented(
    ~~~~~~~~~~~~~^
        f"data dependent operator: {cause.func}; "
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
        "to enable, set torch._dynamo.config.capture_scalar_outputs = True"
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/_dynamo/exc.py", line 317, in unimplemented
    raise Unsupported(msg, case_name=case_name)
torch._dynamo.exc.Unsupported: data dependent operator: aten._local_scalar_dense.default; to enable, set torch._dynamo.config.capture_scalar_outputs = True

from user code:
   File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/accelerate/hooks.py", line 176, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/transformers/utils/deprecation.py", line 172, in wrapped_func
    return func(*args, **kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/transformers/models/gemma2/modeling_gemma2.py", line 887, in forward
    outputs = self.model(
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/transformers/models/gemma2/modeling_gemma2.py", line 612, in forward
    cache_position = torch.arange(

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

krishnakanthnakkav2 avatar Apr 14 '25 13:04 krishnakanthnakkav2

That's strange. I created a completely fresh environment and the script passes for me:

conda create -n test6 python=3.12
conda activate test6
python -m pip install -U git+https://github.com/BenjaminBossan/peft.git@fix-prompt-tuning-4d-attention-mask
python issue-2458.py
conda env export
name: test6
channels:
  - defaults
dependencies:
  - _libgcc_mutex=0.1=main
  - _openmp_mutex=5.1=1_gnu
  - bzip2=1.0.8=h5eee18b_6
  - ca-certificates=2025.2.25=h06a4308_0
  - expat=2.6.4=h6a678d5_0
  - ld_impl_linux-64=2.40=h12ee557_0
  - libffi=3.4.4=h6a678d5_1
  - libgcc-ng=11.2.0=h1234567_1
  - libgomp=11.2.0=h1234567_1
  - libstdcxx-ng=11.2.0=h1234567_1
  - libuuid=1.41.5=h5eee18b_0
  - ncurses=6.4=h6a678d5_0
  - openssl=3.0.16=h5eee18b_0
  - pip=25.0=py312h06a4308_0
  - python=3.12.9=h5148396_0
  - readline=8.2=h5eee18b_0
  - setuptools=75.8.0=py312h06a4308_0
  - sqlite=3.45.3=h5eee18b_0
  - tk=8.6.14=h39e8969_0
  - tzdata=2025a=h04d1e81_0
  - wheel=0.45.1=py312h06a4308_0
  - xz=5.6.4=h5eee18b_1
  - zlib=1.2.13=h5eee18b_1
  - pip:
      - accelerate==1.6.0
      - certifi==2025.1.31
      - charset-normalizer==3.4.1
      - filelock==3.18.0
      - fsspec==2025.3.2
      - huggingface-hub==0.30.2
      - idna==3.10
      - jinja2==3.1.6
      - markupsafe==3.0.2
      - mpmath==1.3.0
      - networkx==3.4.2
      - numpy==2.2.4
      - nvidia-cublas-cu12==12.4.5.8
      - nvidia-cuda-cupti-cu12==12.4.127
      - nvidia-cuda-nvrtc-cu12==12.4.127
      - nvidia-cuda-runtime-cu12==12.4.127
      - nvidia-cudnn-cu12==9.1.0.70
      - nvidia-cufft-cu12==11.2.1.3
      - nvidia-curand-cu12==10.3.5.147
      - nvidia-cusolver-cu12==11.6.1.9
      - nvidia-cusparse-cu12==12.3.1.170
      - nvidia-cusparselt-cu12==0.6.2
      - nvidia-nccl-cu12==2.21.5
      - nvidia-nvjitlink-cu12==12.4.127
      - nvidia-nvtx-cu12==12.4.127
      - packaging==24.2
      - peft==0.15.2.dev0
      - psutil==7.0.0
      - pyyaml==6.0.2
      - regex==2024.11.6
      - requests==2.32.3
      - safetensors==0.5.3
      - sympy==1.13.1
      - tokenizers==0.21.1
      - torch==2.6.0
      - tqdm==4.67.1
      - transformers==4.51.3
      - triton==3.2.0
      - typing-extensions==4.13.2
      - urllib3==2.4.0

BenjaminBossan avatar Apr 14 '25 14:04 BenjaminBossan

It works now in the test6 environment.

I can see the output: Generated text for input 1: In the world of artificial intelligence, and the 1990s, the 1990s, and the 1990s, and the 19999999999999999999999

krishnakanthnakkav2 avatar Apr 14 '25 14:04 krishnakanthnakkav2

Thanks for confirming. Then I think we can proceed with the PR.

@githubnemo the PR is ready for review.

BenjaminBossan avatar Apr 14 '25 14:04 BenjaminBossan

Hello @BenjaminBossan,

I found some issues when the number of new tokens to generate (e.g., 25) is less than the number of virtual tokens (e.g., 50). See the code and error below:


import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import get_peft_model, TaskType, PromptTuningConfig, PromptTuningInit


# import torch._dynamo  # For managing Dynamo configuration

# # Option 1: Enable scalar output capture
# #torch._dynamo.config.capture_scalar_outputs = True

# # Option 2: Or, use eager execution (fallback)
# torch._dynamo.config.suppress_errors = True

# import torch._dynamo
# torch._dynamo.config.suppress_errors = True
# torch._dynamo.config.capture_scalar_outputs = True


model_name = "google/gemma-2-2b"
cache_dir = "/assets/hub"
attention = "eager"
tokenizer_model_name = model_name

tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    cache_dir=cache_dir,
)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    cache_dir=cache_dir,
    attn_implementation=attention,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

model.config.use_cache = False
model.config.pretraining_tp = 1
tokenizer.pad_token = tokenizer.eos_token

config = PromptTuningConfig(
    peft_type="PROMPT_TUNING",
    task_type=TaskType.CAUSAL_LM,
    prompt_tuning_init=PromptTuningInit.RANDOM,
    prompt_tuning_init_text="email",  # "phone" # "address"
    num_virtual_tokens=50,
    tokenizer_name_or_path=model_name)

model = get_peft_model(model, config)

input_texts = [
    "In the world of artificial intelligence,",
]

inputs = tokenizer(input_texts, return_tensors="pt",
                   padding=True, truncation=True
                   )

print(f"Input attention mask shape: {inputs['attention_mask'].shape}")

generated_output = model.generate(
    input_ids=inputs["input_ids"].cuda(),
    attention_mask=inputs["attention_mask"].cuda(),
    max_new_tokens=25  # UPDATED
)

generated_texts = tokenizer.batch_decode(
    generated_output, skip_special_tokens=True
)

for idx, generated_text in enumerate(generated_texts):
    print(f"Generated text for input {idx + 1}: {generated_text}")
Traceback (most recent call last):
  File "/home/krishna/PII/cs-llm-temp/misc.py", line 175, in <module>
    generated_output = model.generate(
        input_ids=inputs["input_ids"].cuda(),
        attention_mask=inputs["attention_mask"].cuda(),
        max_length=25
    )
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/peft/peft_model.py", line 1935, in generate
    outputs = self.base_model.generate(**kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/transformers/generation/utils.py", line 2465, in generate
    result = self._sample(
        input_ids,
    ...<5 lines>...
        **model_kwargs,
    )
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/transformers/generation/utils.py", line 3431, in _sample
    outputs = self(**model_inputs, return_dict=True)
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/accelerate/hooks.py", line 176, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/transformers/utils/generic.py", line 965, in wrapper
    output = func(self, *args, **kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/transformers/utils/deprecation.py", line 172, in wrapped_func
    return func(*args, **kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/transformers/models/gemma2/modeling_gemma2.py", line 851, in forward
    outputs: BaseModelOutputWithPast = self.model(
                                       ~~~~~~~~~~^
        input_ids=input_ids,
        ^^^^^^^^^^^^^^^^^^^^
    ...<8 lines>...
        **loss_kwargs,
        ^^^^^^^^^^^^^^
    )
    ^
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/transformers/utils/generic.py", line 965, in wrapper
    output = func(self, *args, **kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/transformers/models/gemma2/modeling_gemma2.py", line 596, in forward
    causal_mask = self._update_causal_mask(
        attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
    )
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/transformers/models/gemma2/modeling_gemma2.py", line 688, in _update_causal_mask
    causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask,
    ...<5 lines>...
        batch_size=input_tensor.shape[0],
    )
  File "/home/krishna/miniconda3/envs/fs-llm-v5/lib/python3.13/site-packages/transformers/models/gemma2/modeling_gemma2.py", line 747, in _prepare_4d_causal_attention_mask_with_cache_position
    padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
                   ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        causal_mask.device
        ~~~~~~~~~~~~~~~~~~
    )
    ~
RuntimeError: The size of tensor a (24) must match the size of tensor b (58) at non-singleton dimension 3

Can you please look into this?

krishnakanthnakkav2 avatar Apr 16 '25 14:04 krishnakanthnakkav2

@krishnakanthnakkav2 I can confirm that this causes issues but can't tell yet where it's coming from. I'd suggest treating this as a separate issue instead of trying to fix it in this PR.
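
For anyone following along until the separate issue is opened, a hedged observation rather than a confirmed diagnosis: the 58 in the error is the 50 virtual tokens plus the tokenized prompt, while the 24 appears to come from a causal mask sized for the much smaller max_length=25 budget shown in the traceback (note the snippet above passes max_new_tokens=25, so the pasted traceback seems to come from an earlier max_length run). A small sanity check of that arithmetic, reusing model and inputs from the snippet above and assuming the default adapter name:

# 50 virtual tokens configured above
num_virtual = model.peft_config["default"].num_virtual_tokens
# length of the tokenized prompt (no virtual tokens included here)
prompt_len = inputs["input_ids"].shape[1]
# their sum should reproduce the "tensor b (58)" from the error message
print(num_virtual, prompt_len, num_virtual + prompt_len)

Whether max_new_tokens alone avoids the crash is exactly the kind of thing the follow-up issue should pin down.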

BenjaminBossan avatar Apr 16 '25 14:04 BenjaminBossan

@BenjaminBossan @krishnakanthnakkav2 thanks for investigating and fixing this! 🚀

I can confirm that after installing the version from this PR, prompt tuning is working for me (using gemma-3-1b-it with a soft prompt size of 100). Hopefully this will get merged soon 👍

Wicwik avatar Apr 28 '25 11:04 Wicwik

Thanks for checking this and sharing your results with us, @Wicwik. Yes, hopefully we can merge the PR very soon.

BenjaminBossan avatar Apr 28 '25 11:04 BenjaminBossan