Fixes the value_and_grad check in the splitter for the scaled_dot_product_attention symbol
…
Before submitting
- [ ] Was this discussed/approved via a GitHub issue? (no need for typos and docs improvements)
- [ ] Did you read the contributor guideline, Pull Request section?
- [ ] Did you make sure to update the docs?
- [ ] Did you write any new necessary tests?
What does this PR do?
In PR #2035, value_and_grad(thunder_symbol) is used in try_execute_thunder_symbol in the splitter to determine whether a node can be executed by Thunder.
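For context, a minimal sketch of that probe (illustrative only; the real implementation lives in thunder/dynamo/utils.py, and the function name and SplitReason handling here are simplified):

```python
# Simplified sketch of the splitter probe: run the Thunder symbol under
# value_and_grad with proxy arguments; any exception means the node cannot
# be handled by Thunder and is recorded as a SplitReason.
from thunder.core.transforms import value_and_grad

def probe_thunder_symbol(thunder_symbol, proxy_args, proxy_kwargs):
    try:
        value_and_grad(thunder_symbol)(*proxy_args, **proxy_kwargs)
    except Exception as e:
        # in the real code this becomes SplitReason(EXCEPTION_META_THUNDER_OP, ...)
        return False, e
    return True, None
```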
When running HF bigcode/starcoder2-7b, there are 96 SplitReasons:
```
(Pdb) len(subgraph_info.split_reasons)
96
(Pdb) subgraph_info.split_reasons[0]
SplitReason(reason_type=<SplitReasonType.EXCEPTION_META_THUNDER_OP: 4>, info='Failed while running meta for node with name: attn_output and target: <built-in function scaled_dot_product_attention>, see exception field', exception="Unable to cast Python instance of type <class 'thunder.core.proxies.FloatProxy'> to C++ type '?' (#define PYBIND11_DETAILED_ERROR_MESSAGES or compile in debug mode for details)")
```
The reason is:
```
Traceback (most recent call last):
  File "/wayan/lightning-thunder/thunder/dynamo/utils.py", line 308, in _run_with_cache_info
    function_to_run(*proxy_args, **proxy_kwargs)
  File "/wayan/lightning-thunder/thunder/core/transforms.py", line 2975, in _value_and_grad
    return vjp(func)(args, cotangents, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/wayan/lightning-thunder/thunder/core/transforms.py", line 2937, in _vjp
    result, vjp_result = vjp_call(flat_args, cotangents, trace=trace)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/wayan/lightning-thunder/thunder/core/transforms.py", line 2919, in vjp_call
    result, env = augmented_forward_pass(*primals, trace=trace, **kwargs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/wayan/lightning-thunder/thunder/core/transforms.py", line 2704, in augmented_forward_pass
    result, env = eval_trace(
                  ^^^^^^^^^^^
  File "/wayan/lightning-thunder/thunder/core/trace_interpreter.py", line 66, in interpret_trace
    prim_func = symbol_mapper(symbol) if symbol_mapper is not None else symbol.sym
                ^^^^^^^^^^^^^^^^^^^^^
  File "/wayan/lightning-thunder/thunder/core/transforms.py", line 2629, in vjp_symbol_mapper
    if _get_gradfn_and_executor(symbol)[0] is not None:
       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/wayan/lightning-thunder/thunder/core/transforms.py", line 1498, in _get_gradfn_and_executor
    if ex.can_execute_or_fuse(bsym):
       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/wayan/lightning-thunder/thunder/extend/__init__.py", line 87, in can_execute_or_fuse
    return self.can_execute(bsym) or self.can_fuse(bsym)
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/wayan/lightning-thunder/thunder/extend/__init__.py", line 99, in can_execute
    return impl.checker(*bsym.args, **bsym.kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/wayan/lightning-thunder/thunder/executors/cudnnex.py", line 403, in _cudnn_sdpa_checker
    _make_cudnn_sdpa_forward_graph(
  File "/wayan/lightning-thunder/thunder/executors/cudnnex.py", line 164, in _make_cudnn_sdpa_forward_graph
    O, softmax_stats = graph.scaled_dot_product_flash_attention(
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Unable to cast Python instance of type <class 'thunder.core.proxies.FloatProxy'> to C++ type '?' (#define PYBIND11_DETAILED_ERROR_MESSAGES or compile in debug mode for details)
```
This PR fixes it by converting the NumberProxy arguments in the checker to concrete values.
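The fix follows this pattern (a sketch of the idea, not the exact diff; the helper name _to_concrete is hypothetical):

```python
# Hypothetical helper illustrating the fix: unwrap NumberProxy arguments
# (e.g. the FloatProxy seen in the traceback above) to concrete Python
# numbers before they reach the pybind11-backed cuDNN graph builder.
from thunder.core.proxies import NumberProxy

def _to_concrete(x):
    # NumberProxy carries the traced Python number as .value
    return x.value if isinstance(x, NumberProxy) else x

# e.g. inside _cudnn_sdpa_checker, before building the cuDNN graph:
#   dropout_p = _to_concrete(dropout_p)
```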
It seems the error arises in _cudnn_sdpa_checker because a FloatProxy type is being passed instead of a float. My earlier assumption was that try_execute_thunder_symbol only invokes the meta function associated with the Torch symbol (without interacting with the executors). Could the reason for this behavior be related to #2035, where value_and_grad leads to transform_for_execution which then calls _cudnn_sdpa_checker? I don't believe we should be reaching the executor-level symbols in this case.
For example, in the code snippet below, even when we configure Thunder to use only the always-on executors, the same error still occurs in the cuDNN checker function (likely because value_and_grad relies on the default executors).
```python
import torch
from torch.utils import benchmark
import transformer_engine
import transformer_engine.pytorch as te
from transformers import AutoConfig, AutoModelForCausalLM
from huggingface_hub import login

dtype = torch.bfloat16
BS = 2
SEQ_LEN = 4096
HF_TOKEN = None

# These models may require accepting licenses.
models = ("bigcode/starcoder2-7b",)

from thunder.dynamo import thunderfx
import thunder

for model_name in models:
    with torch.device("cuda"):
        config = AutoConfig.from_pretrained(model_name, trust_remote_code=False)
        config.num_hidden_layers = 2
        model = AutoModelForCausalLM.from_config(config, torch_dtype=dtype, trust_remote_code=False)

        # Works
        model(torch.randint(0, 10, (BS, SEQ_LEN)))
        print("EAGER WORKS")

        # Fails
        model = thunderfx(model, executors=())  # Only always-on executors
        model(torch.randint(0, 10, (BS, SEQ_LEN)))
```
Ideally, we should:
- Prevent value_and_grad from accessing executor symbols (preferred approach), or
- Ensure value_and_grad consistently uses the same set of executors that the user passed. In this case, we would also need to validate that all checkers within the default executors correctly handle FloatProxy and IntegerProxy types (see the sketch after this list).
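As a starting point for that validation, a rough sketch that only enumerates the registered checkers of the default executors (using thunder.extend APIs; actually exercising each checker with proxy inputs is left out here):

```python
# Rough sketch: list the checkers registered by the default executors, which
# is where a FloatProxy/IntegerProxy argument could slip in during the
# value_and_grad probe. implmap maps symbol ids to ImplInfo records whose
# .checker is the callable invoked by can_execute (see the traceback above).
from thunder.extend import get_default_executors

for ex in get_default_executors():
    implmap = getattr(ex, "implmap", {})  # FusionExecutors may not have one
    checkers = [info.checker for info in implmap.values() if info.checker is not None]
    print(f"{ex.name}: {len(checkers)} checkers to audit for NumberProxy handling")
```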
cc @IvanYashchuk for an opinion
Hi @kshitij12345
> My earlier assumption was that try_execute_thunder_symbol only invokes the meta function associated with the Torch symbol (without interacting with the executors). Could the reason for this behavior be related to https://github.com/Lightning-AI/lightning-thunder/pull/2035, where value_and_grad leads to transform_for_execution which then calls _cudnn_sdpa_checker? I don't believe we should be reaching the executor-level symbols in this case.
Yes, the reason is that I added value_and_grad to check the backward in the splitter. Currently we have two ways to register a backward: augmented_forward_impls + backward_impls, or _grad_fn_map, and value_and_grad checks both (https://github.com/Lightning-AI/lightning-thunder/blob/f73bfa0eef857532beeaea44964b0bfd9926325a/thunder/core/transforms.py#L2624-L2627). _get_gradfn_and_executor, which handles the _grad_fn_map path, is intended to be bound to the executor:
https://github.com/Lightning-AI/lightning-thunder/blob/f73bfa0eef857532beeaea44964b0bfd9926325a/thunder/core/transforms.py#L1492-L1507
This is where the interaction with the executors comes from, so it may not be easy to prevent value_and_grad from accessing executor symbols.
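For reference, a rough sketch of the two registration paths (the map names are the ones mentioned above, assumed to be module-level dicts in thunder.core.transforms; the symbol id "my_op" and the functions are placeholders):

```python
# Illustrative only: the two backward-registration mechanisms, shown as the
# maps they populate.
from thunder.core.transforms import augmented_forward_impls, backward_impls, _grad_fn_map

def my_aug_fwd(x):
    # augmented forward: primal result plus residuals saved for backward
    return x * 2, (x,)

def my_bwd(x, g):
    # backward: consumes residuals and the incoming cotangent
    return g * 2

def my_gradfn(x):
    # grad fn for the _grad_fn_map path, found via _get_gradfn_and_executor
    ...

# Path 1: global maps keyed by symbol id
augmented_forward_impls["my_op"] = my_aug_fwd
backward_impls["my_op"] = my_bwd

# Path 2: _grad_fn_map; _get_gradfn_and_executor also consults each
# executor's grad_transform, which is what pulls the executor checkers
# (like _cudnn_sdpa_checker) into the value_and_grad probe.
_grad_fn_map["my_op"] = my_gradfn
```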
It also seems hard to move the ex_grad_transform: None | Callable = ex.get_grad_transform(bsym.sym) part into the later transform_for_execution.
@IvanYashchuk, maybe you have some ideas about this.
Hi @IvanYashchuk @kshitij12345, could you take a look again? I added the repro to issue #2120.
Hi @kshitij12345, I created this issue as a follow-up discussion. I did a quick check of the checker functions in the other executors, and it seems they work correctly with NumberProxy, because NumberProxy can handle most operations involving regular numbers.
Hi @t-vi, @mruberry, it's ready to merge.