[CPU][Flex attn] Add a readable error message for the backward path
Fixes #169224. Flex attention does not support the backward path on CPU. This PR adds a readable and meaningful error message for this case.
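For context, here is a minimal sketch of the kind of script that hits this path. The tensor shapes, the causal mask, and the `torch.compile` wrapping are assumptions for illustration, not taken from the original `test_flex_attn.py`:

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

# flex_attention is typically compiled so it goes through the Inductor path.
flex_attention = torch.compile(flex_attention)

def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

B, H, S, D = 1, 2, 128, 64
# CPU inputs that require grad force the (unsupported) backward lowering.
query = torch.randn(B, H, S, D, device="cpu", requires_grad=True)
key = torch.randn(B, H, S, D, device="cpu", requires_grad=True)
value = torch.randn(B, H, S, D, device="cpu", requires_grad=True)

block_mask = create_block_mask(causal_mask, B, H, S, S, device="cpu")
output = flex_attention(query, key, value, block_mask=block_mask)
```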
Before:
Traceback (most recent call last):
File "/workspace/test_flex_attn.py", line 24, in <module>
output = flex_attention(query, key, value, block_mask=block_mask)
File "/workspace/pytorch/torch/_dynamo/eval_frame.py", line 940, in compile_wrapper
raise e.remove_dynamo_frames() from None # see TORCHDYNAMO_VERBOSE=1
File "/workspace/pytorch/torch/_inductor/compile_fx.py", line 1019, in _compile_fx_inner
raise InductorError(e, currentframe()).with_traceback(
File "/workspace/pytorch/torch/_inductor/compile_fx.py", line 1003, in _compile_fx_inner
mb_compiled_graph = fx_codegen_and_compile(
File "/workspace/pytorch/torch/_inductor/compile_fx.py", line 1757, in fx_codegen_and_compile
return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs)
File "/workspace/pytorch/torch/_inductor/compile_fx.py", line 1452, in codegen_and_compile
graph.run(*example_inputs)
File "/workspace/pytorch/torch/_inductor/graph.py", line 987, in run
return super().run(*args)
File "/workspace/pytorch/torch/fx/interpreter.py", line 200, in run
self.env[node] = self.run_node(node)
File "/workspace/pytorch/torch/_inductor/graph.py", line 1726, in run_node
result = super().run_node(n)
File "/workspace/pytorch/torch/fx/interpreter.py", line 295, in run_node
return getattr(self, n.op)(n.target, args, kwargs)
File "/workspace/pytorch/torch/_inductor/graph.py", line 1257, in call_function
return super().call_function(target, args, kwargs)
File "/workspace/pytorch/torch/fx/interpreter.py", line 375, in call_function
return target(*args, **kwargs)
torch._inductor.exc.InductorError: IndexError: tuple index out of range
Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
After:
Traceback (most recent call last):
File "/workspace/test_flex_attn.py", line 24, in <module>
output = flex_attention(query, key, value, block_mask=block_mask)
File "/workspace/pytorch/torch/_dynamo/eval_frame.py", line 926, in compile_wrapper
return fn(*args, **kwargs)
File "/workspace/pytorch/torch/nn/attention/flex_attention.py", line 1481, in flex_attention
_validate_device(query, key, value)
File "/workspace/pytorch/torch/nn/attention/flex_attention.py", line 1332, in _validate_device
raise NotImplementedError(
NotImplementedError: FlexAttention does not support backward on CPU. Please set the input requires_grad to False or use another device.
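For reference, a rough sketch of the guard this PR introduces in `torch/nn/attention/flex_attention.py`. The real `_validate_device` performs additional device validation; the exact condition shown here is an assumption reconstructed from the new error message:

```python
import torch

def _validate_device(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> None:
    # Sketch only: CPU backward kernels for FlexAttention are not available,
    # so fail fast with a clear message instead of letting Inductor crash later.
    if query.device.type == "cpu" and any(
        t.requires_grad for t in (query, key, value)
    ):
        raise NotImplementedError(
            "FlexAttention does not support backward on CPU. "
            "Please set the input requires_grad to False or use another device."
        )
```

With this check in place, the failure surfaces at the `flex_attention` call as a clear `NotImplementedError` rather than an opaque `IndexError` from deep inside the Inductor lowering.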
cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @voznesenskym @penguinwu @EikanWang @Guobing-Chen @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @jataylo @Lucaskabela
:white_check_mark: You can merge normally! (1 Unrelated Failure)
As of commit 2b1ac972043eb6de2189b9498d36255ee3ef92b7 with merge base dd1f0f8f66671424347b323f59beae2964254491:
FLAKY - The following job failed but was likely due to flakiness present on trunk:
- inductor / unit-test / inductor-pallas-cpu-test / test (inductor-pallas-cpu, 1, 1, linux.12xlarge): Process completed with exit code 1.
@pytorchbot merge
Merge started
Your change will be merged once all checks pass (ETA 0-4 Hours).
Learn more about merging in the wiki.
Questions? Feedback? Please reach out to the PyTorch DevX Team.
Advanced Debugging: check the merge workflow status here.