
[CPU][Flex attn] Add a readable error message for the backward path

Open · Valentine233 opened this pull request 2 weeks ago • 1 comment

Fixes #169224. Flex attention does not support the backward path on CPU. This PR adds a readable, meaningful error message for that case.
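
For context, here is a minimal repro in the spirit of the test_flex_attn.py seen in the tracebacks below; the tensor shapes and the omission of the block_mask are illustrative assumptions, not taken from the actual test script:

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

# Illustrative shapes (batch, heads, seq_len, head_dim); any CPU inputs with
# requires_grad=True exercise the unsupported backward path.
query = torch.randn(1, 2, 128, 64, device="cpu", requires_grad=True)
key = torch.randn(1, 2, 128, 64, device="cpu", requires_grad=True)
value = torch.randn(1, 2, 128, 64, device="cpu", requires_grad=True)

# Before this PR: an opaque IndexError out of Inductor.
# After this PR: a NotImplementedError with a clear message.
output = flex_attention(query, key, value)
```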

Before:

```
Traceback (most recent call last):
  File "/workspace/test_flex_attn.py", line 24, in <module>
    output = flex_attention(query, key, value, block_mask=block_mask)
  File "/workspace/pytorch/torch/_dynamo/eval_frame.py", line 940, in compile_wrapper
    raise e.remove_dynamo_frames() from None  # see TORCHDYNAMO_VERBOSE=1
  File "/workspace/pytorch/torch/_inductor/compile_fx.py", line 1019, in _compile_fx_inner
    raise InductorError(e, currentframe()).with_traceback(
  File "/workspace/pytorch/torch/_inductor/compile_fx.py", line 1003, in _compile_fx_inner
    mb_compiled_graph = fx_codegen_and_compile(
  File "/workspace/pytorch/torch/_inductor/compile_fx.py", line 1757, in fx_codegen_and_compile
    return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs)
  File "/workspace/pytorch/torch/_inductor/compile_fx.py", line 1452, in codegen_and_compile
    graph.run(*example_inputs)
  File "/workspace/pytorch/torch/_inductor/graph.py", line 987, in run
    return super().run(*args)
  File "/workspace/pytorch/torch/fx/interpreter.py", line 200, in run
    self.env[node] = self.run_node(node)
  File "/workspace/pytorch/torch/_inductor/graph.py", line 1726, in run_node
    result = super().run_node(n)
  File "/workspace/pytorch/torch/fx/interpreter.py", line 295, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "/workspace/pytorch/torch/_inductor/graph.py", line 1257, in call_function
    return super().call_function(target, args, kwargs)
  File "/workspace/pytorch/torch/fx/interpreter.py", line 375, in call_function
    return target(*args, **kwargs)
torch._inductor.exc.InductorError: IndexError: tuple index out of range

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
```

After:

```
Traceback (most recent call last):
  File "/workspace/test_flex_attn.py", line 24, in <module>
    output = flex_attention(query, key, value, block_mask=block_mask)
  File "/workspace/pytorch/torch/_dynamo/eval_frame.py", line 926, in compile_wrapper
    return fn(*args, **kwargs)
  File "/workspace/pytorch/torch/nn/attention/flex_attention.py", line 1481, in flex_attention
    _validate_device(query, key, value)
  File "/workspace/pytorch/torch/nn/attention/flex_attention.py", line 1332, in _validate_device
    raise NotImplementedError(
NotImplementedError: FlexAttention does not support backward on CPU. Please set the input requires_grad to False or use another device.
```
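
Based on that traceback, the new guard in flex_attention.py resembles the sketch below; only the function name and the error message come from the traceback, while the exact condition is an assumption, not the PR's actual diff:

```python
import torch

def _validate_device(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> None:
    # Sketch only: reject CPU inputs that require grad, since the flex
    # attention backward path is unsupported on CPU.
    if query.device.type == "cpu" and any(
        t.requires_grad for t in (query, key, value)
    ):
        raise NotImplementedError(
            "FlexAttention does not support backward on CPU. Please set the "
            "input requires_grad to False or use another device."
        )
```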

cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @voznesenskym @penguinwu @EikanWang @Guobing-Chen @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @jataylo @Lucaskabela

— Valentine233, Dec 05 '25

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/169646

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 2b1ac972043eb6de2189b9498d36255ee3ef92b7 with merge base dd1f0f8f66671424347b323f59beae2964254491:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

— pytorch-bot[bot], Dec 05 '25

@pytorchbot merge

— Valentine233, Dec 15 '25

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: Check the merge workflow status here

— pytorchmergebot, Dec 15 '25