
Fixes the value_and_grad check in splitter on symbol scaled_dot_product_attention

Open kiya00 opened this issue 7 months ago • 2 comments

Before submitting
  • [ ] Was this discussed/approved via a GitHub issue? (no need for typos and docs improvements)
  • [ ] Did you read the contributor guideline, Pull Request section?
  • [ ] Did you make sure to update the docs?
  • [ ] Did you write any new necessary tests?

What does this PR do?

In PR #2035, value_and_grad(thunder_symbol) is used in try_execute_thunder_symbol in the splitter to determine whether a node can be executed by Thunder. When running the HF model bigcode/starcoder2-7b, there are 96 SplitReasons like the following:

(Pdb) len(subgraph_info.split_reasons)
96
(Pdb) subgraph_info.split_reasons[0]
SplitReason(reason_type=<SplitReasonType.EXCEPTION_META_THUNDER_OP: 4>, info='Failed while running meta for node with name: attn_output and target: <built-in function scaled_dot_product_attention>, see exception field', exception="Unable to cast Python instance of type <class 'thunder.core.proxies.FloatProxy'> to C++ type '?' (#define PYBIND11_DETAILED_ERROR_MESSAGES or compile in debug mode for details)")
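In other words, any exception raised while the splitter checks a node turns into a split reason for that node. The sketch below models this in plain Python. `ToySplitReason`, `try_execute_node`, and the string reason type are hypothetical simplifications for illustration, not Thunder's actual `SplitReason` class or the real `try_execute_thunder_symbol` signature.

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional, Tuple


@dataclass
class ToySplitReason:
    """Hypothetical, simplified analogue of a splitter SplitReason record."""
    reason_type: str
    info: str
    exception: Optional[str] = None


def try_execute_node(
    node_name: str, check: Callable[..., Any], *args: Any
) -> Tuple[bool, Optional[ToySplitReason]]:
    """Run a check (e.g. value_and_grad of a symbol) and turn any exception into a split reason."""
    try:
        check(*args)
        return True, None
    except Exception as e:  # any failure means the node is not handed to Thunder
        reason = ToySplitReason(
            reason_type="EXCEPTION_META_THUNDER_OP",
            info=f"Failed while running meta for node with name: {node_name}",
            exception=str(e),
        )
        return False, reason


def failing_check(*args: Any) -> None:
    # Stand-in for an executor checker that cannot handle a proxy argument
    raise RuntimeError("Unable to cast Python instance of type FloatProxy to C++ type")
```

Usage: `try_execute_node("attn_output", failing_check, 0.0)` returns `(False, reason)` with the exception message preserved, so every node that hits the same checker produces its own split reason.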

The underlying error is:

Traceback (most recent call last):
  File "/wayan/lightning-thunder/thunder/dynamo/utils.py", line 308, in _run_with_cache_info
    function_to_run(*proxy_args, **proxy_kwargs)
  File "/wayan/lightning-thunder/thunder/core/transforms.py", line 2975, in _value_and_grad
    return vjp(func)(args, cotangents, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/wayan/lightning-thunder/thunder/core/transforms.py", line 2937, in _vjp
    result, vjp_result = vjp_call(flat_args, cotangents, trace=trace)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/wayan/lightning-thunder/thunder/core/transforms.py", line 2919, in vjp_call
    result, env = augmented_forward_pass(*primals, trace=trace, **kwargs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/wayan/lightning-thunder/thunder/core/transforms.py", line 2704, in augmented_forward_pass
    result, env = eval_trace(
                  ^^^^^^^^^^^
  File "/wayan/lightning-thunder/thunder/core/trace_interpreter.py", line 66, in interpret_trace
    prim_func = symbol_mapper(symbol) if symbol_mapper is not None else symbol.sym
                ^^^^^^^^^^^^^^^^^^^^^
  File "/wayan/lightning-thunder/thunder/core/transforms.py", line 2629, in vjp_symbol_mapper
    if _get_gradfn_and_executor(symbol)[0] is not None:
       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/wayan/lightning-thunder/thunder/core/transforms.py", line 1498, in _get_gradfn_and_executor
    if ex.can_execute_or_fuse(bsym):
       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/wayan/lightning-thunder/thunder/extend/__init__.py", line 87, in can_execute_or_fuse
    return self.can_execute(bsym) or self.can_fuse(bsym)
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/wayan/lightning-thunder/thunder/extend/__init__.py", line 99, in can_execute
    return impl.checker(*bsym.args, **bsym.kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/wayan/lightning-thunder/thunder/executors/cudnnex.py", line 403, in _cudnn_sdpa_checker
    _make_cudnn_sdpa_forward_graph(
  File "/wayan/lightning-thunder/thunder/executors/cudnnex.py", line 164, in _make_cudnn_sdpa_forward_graph
    O, softmax_stats = graph.scaled_dot_product_flash_attention(
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Unable to cast Python instance of type <class 'thunder.core.proxies.FloatProxy'> to C++ type '?' (#define PYBIND11_DETAILED_ERROR_MESSAGES or compile in debug mode for details)

This PR fixes it by converting the NumberProxy arguments in the checker to their concrete values.
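The idea of the fix can be sketched as unwrapping number proxies before they reach the C++ binding. The sketch is self-contained and does not import Thunder: `FakeNumberProxy` is a hypothetical stand-in for Thunder's NumberProxy (assumed here to expose its concrete value as `.value`), and `strict_float_binding` is a hypothetical stand-in that mimics a pybind11-bound graph builder rejecting anything that is not a plain `float`.

```python
from typing import Any, Optional


class FakeNumberProxy:
    """Hypothetical stand-in for Thunder's NumberProxy; assumed to expose its concrete value as `.value`."""

    def __init__(self, name: str, value: float) -> None:
        self.name = name
        self.value = value


def strict_float_binding(dropout_p: float, scale: Optional[float]) -> str:
    """Hypothetical stand-in for a C++ binding that only accepts plain floats (or None)."""
    for arg in (dropout_p, scale):
        if arg is not None and type(arg) is not float:
            raise RuntimeError(f"Unable to cast Python instance of type {type(arg)} to C++ type '?'")
    return "graph built"


def to_value(x: Any) -> Any:
    """Replace a number proxy with its concrete value; pass everything else through unchanged."""
    return x.value if isinstance(x, FakeNumberProxy) else x


def sdpa_checker_before(dropout_p: Any, scale: Any) -> bool:
    # Proxies reach the binding unchanged, so it raises RuntimeError
    strict_float_binding(dropout_p, scale)
    return True


def sdpa_checker_after(dropout_p: Any, scale: Any) -> bool:
    # Proxies are converted to plain numbers before crossing the binding boundary
    strict_float_binding(to_value(dropout_p), to_value(scale))
    return True
```

With `scale = FakeNumberProxy("f0", 0.125)`, `sdpa_checker_before(0.0, scale)` raises while `sdpa_checker_after(0.0, scale)` returns `True`. Doing the conversion inside the checker keeps the rest of the tracing machinery unchanged.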

kiya00 · May 15 '25

It seems the error arises in _cudnn_sdpa_checker because a FloatProxy is being passed instead of a float. My earlier assumption was that try_execute_thunder_symbol only invokes the meta function associated with the Torch symbol (without interacting with the executors). Could the reason for this behavior be related to #2035, where value_and_grad leads to transform_for_execution, which then calls _cudnn_sdpa_checker? I don't believe we should be reaching the executor-level symbols in this case.

For example, in the code snippet below, even when thunderfx is configured to use only the always-on executors, the same error still occurs in the cuDNN checker function (likely because value_and_grad relies on the default executors).

import torch
from transformers import AutoConfig, AutoModelForCausalLM

from thunder.dynamo import thunderfx

dtype = torch.bfloat16
BS = 2
SEQ_LEN = 4096

# These models may require accepting licenses on the Hugging Face Hub.
models = ("bigcode/starcoder2-7b",)

for model_name in models:
    with torch.device("cuda"):
        config = AutoConfig.from_pretrained(model_name, trust_remote_code=False)
        config.num_hidden_layers = 2  # shrink the model to keep the repro small
        model = AutoModelForCausalLM.from_config(config, torch_dtype=dtype, trust_remote_code=False)

        # Eager PyTorch: works
        model(torch.randint(0, 10, (BS, SEQ_LEN)))
        print("EAGER WORKS")

        # thunderfx with only the always-on executors: still fails in the cuDNN SDPA checker
        model = thunderfx(model, executors=())
        model(torch.randint(0, 10, (BS, SEQ_LEN)))

Ideally, we should:

  • Prevent value_and_grad from accessing executor symbols (preferred approach), or
  • Ensure value_and_grad consistently uses the same set of executors that the user passed. In this case, we would also need to verify that all checkers in the default executors correctly handle FloatProxy and IntegerProxy.
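The second option can be sketched as resolving the executor list once and reusing it for the gradient-rule lookup. In the toy model below (all names hypothetical, not Thunder's API), `None` means "the user did not pass executors", which falls back to a default list containing a cuDNN-like executor; the behavior described above corresponds to always taking that fallback. Passing the user's list plus the always-on executors means the cuDNN-like checker is never consulted for `executors=()`.

```python
from dataclasses import dataclass
from typing import Callable, List, Optional, Sequence

checker_calls: List[str] = []  # records which executor checkers were consulted


@dataclass
class ToyExecutor:
    """Hypothetical executor: a name plus a checker deciding whether it can run a symbol."""
    name: str
    checker: Callable[..., bool]

    def can_execute(self, *args) -> bool:
        checker_calls.append(self.name)
        return self.checker(*args)


ALWAYS_ON = [ToyExecutor("torch", lambda *args: True)]
DEFAULT_EXECUTORS = [ToyExecutor("cudnn", lambda *args: True)] + ALWAYS_ON


def executors_for_grad_lookup(user_executors: Optional[Sequence[ToyExecutor]]) -> List[ToyExecutor]:
    """Honor the user's executors (plus always-on ones); fall back to defaults only if none were given."""
    if user_executors is None:
        return list(DEFAULT_EXECUTORS)
    return list(user_executors) + ALWAYS_ON


def first_capable_executor(executors: Sequence[ToyExecutor], *args) -> Optional[str]:
    # Executors are consulted in priority order; the first one whose checker accepts wins
    for ex in executors:
        if ex.can_execute(*args):
            return ex.name
    return None
```

Under this model, `executors_for_grad_lookup(None)` consults the cuDNN-like checker first, while `executors_for_grad_lookup(())` only consults the always-on one.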

cc: @IvanYashchuk for an opinion

kshitij12345 · May 16 '25

Hi @kshitij12345

My earlier assumption was that try_execute_thunder_symbol only invokes the meta function associated with the Torch symbol (without interacting with the executors). Could the reason for this behavior be related to https://github.com/Lightning-AI/lightning-thunder/pull/2035, where value_and_grad leads to transform_for_execution which then calls _cudnn_sdpa_checker? I don't believe we should be reaching the executor-level symbols in this case.

Yes, the reason is that I added value_and_grad to check the backward pass in the splitter. Currently there are two ways to register a backward rule: augmented_forward_impls + backward_impls, or _grad_fn_map. value_and_grad checks both (https://github.com/Lightning-AI/lightning-thunder/blob/f73bfa0eef857532beeaea44964b0bfd9926325a/thunder/core/transforms.py#L2624-L2627), and _get_gradfn_and_executor, which handles the _grad_fn_map path, appears to be intentionally bound to the executors (https://github.com/Lightning-AI/lightning-thunder/blob/f73bfa0eef857532beeaea44964b0bfd9926325a/thunder/core/transforms.py#L1492-L1507). This is where the interaction with the executors comes from. So it may not be easy to prevent value_and_grad from accessing executor symbols: it seems hard to move the ex_grad_transform: None | Callable = ex.get_grad_transform(bsym.sym) part into the later transform_for_execution.
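The point above can be sketched as a simplified model: even for a toy symbol whose backward is registered through augmented_forward_impls + backward_impls, the executor-bound lookup runs first (as in the vjp_symbol_mapper frame of the traceback) and calls executor checkers just to decide which rule applies. This is a hypothetical reduction that assumes the lookup stops at the first executor whose checker accepts the symbol, based on the linked lines and the `ex.get_grad_transform(bsym.sym)` call; the class and function names are illustrative, not Thunder's API.

```python
from typing import Any, Callable, Dict, List, Optional, Tuple


class ToyExecutor:
    """Hypothetical executor with a checker and optional executor-specific grad transforms."""

    def __init__(self, name: str, checker: Callable[..., bool],
                 grad_transforms: Optional[Dict[str, Callable]] = None) -> None:
        self.name = name
        self.checker = checker
        self.grad_transforms = grad_transforms or {}

    def can_execute(self, sym: str, *args: Any) -> bool:
        return self.checker(*args)

    def get_grad_transform(self, sym: str) -> Optional[Callable]:
        return self.grad_transforms.get(sym)


# Path 1: separate augmented-forward and backward rules per symbol
augmented_forward_impls: Dict[str, Callable] = {}
backward_impls: Dict[str, Callable] = {}
# Path 2: one grad function per symbol
grad_fn_map: Dict[str, Callable] = {}


def get_gradfn_and_executor(sym: str, args: Tuple[Any, ...],
                            executors: List[ToyExecutor]) -> Tuple[Optional[Callable], Optional[ToyExecutor]]:
    """Simplified executor-bound lookup: the first capable executor may supply its own grad transform."""
    for ex in executors:
        if ex.can_execute(sym, *args):  # the checker runs here, even just to pick a rule
            transform = ex.get_grad_transform(sym)
            if transform is not None:
                return transform, ex
            break
    return grad_fn_map.get(sym), None


def select_vjp_rule(sym: str, args: Tuple[Any, ...], executors: List[ToyExecutor]) -> str:
    # The executor-bound path is checked first, before the augmented forward/backward registries
    if get_gradfn_and_executor(sym, args, executors)[0] is not None:
        return "grad_fn"
    if sym in augmented_forward_impls and sym in backward_impls:
        return "augmented_forward+backward"
    raise NotImplementedError(f"no backward rule registered for {sym}")
```

If the first executor's checker raises on a proxy argument, `select_vjp_rule` fails even though a path-1 rule exists, which matches the failure mode discussed here.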

@IvanYashchuk, maybe you have some ideas about this?

kiya00 · May 20 '25

Hi @IvanYashchuk @kshitij12345, could you take another look? I added the repro in issue #2120.

kiya00 · May 28 '25

Hi @kshitij12345, I created this issue as a follow-up discussion. I did a quick check of the checker functions in the other executors, and they seem to work correctly with NumberProxy, because NumberProxy supports most operations that work on regular Python numbers.
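Why most Python-level checkers tolerate NumberProxy can be sketched with a toy proxy that forwards operators to its wrapped value: comparisons and arithmetic work, while an exact type test (or a C++ binding that insists on a real float, as in the cuDNN case) does not see through it. `ToyNumberProxy` and both checkers are hypothetical; the sketch does not reproduce Thunder's actual NumberProxy implementation.

```python
from typing import Any, Optional


class ToyNumberProxy:
    """Hypothetical toy proxy that forwards arithmetic and comparisons to a wrapped concrete value."""

    def __init__(self, value: float) -> None:
        self.value = value

    @staticmethod
    def _unwrap(other: Any) -> Any:
        return other.value if isinstance(other, ToyNumberProxy) else other

    def __add__(self, other: Any) -> Any:
        return self.value + self._unwrap(other)

    __radd__ = __add__

    def __mul__(self, other: Any) -> Any:
        return self.value * self._unwrap(other)

    __rmul__ = __mul__

    def __gt__(self, other: Any) -> bool:
        return self.value > self._unwrap(other)

    def __ge__(self, other: Any) -> bool:
        return self.value >= self._unwrap(other)


def python_level_checker(dropout_p: Any, scale: Optional[Any]) -> bool:
    # Comparisons and arithmetic go through the proxy's operator overloads
    return dropout_p >= 0 and (scale is None or scale * 2 > 0)


def exact_type_checker(dropout_p: Any) -> bool:
    # An exact type test does not see through the proxy
    return type(dropout_p) is float
```

So a checker written in terms of comparisons and arithmetic passes with proxies, while one that requires an actual `float` fails; the cuDNN checker falls into the second category because the values cross a C++ boundary.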

kiya00 · May 28 '25

Hi @t-vi, @mruberry, it's ready to merge.

kiya00 · Jun 06 '25