FunctionalTensor is not supported by IPEX custom kernel when using torch.compile after ipex.llm.optimize
Describe the issue
I was attempting to use torch.compile after running ipex.llm.optimize on a language model on a Max 1100 GPU. My goal is to improve torch.compile by recognizing FX graph patterns and directly using custom kernels such as torch_ipex.xetla_sdp_dropout. However, when I tested torch.compile on NewIPEXBertSelfAttention, as shown in the following code,
import torch
import torch.nn as nn
torch.set_default_dtype(torch.float16)
from transformers.models.bert.modeling_bert import BertConfig, BertSelfAttention
import intel_extension_for_pytorch as ipex
from intel_extension_for_pytorch.transformers.models.xpu.optimize_transformers.modules.bert import NewIPEXBertSelfAttention

config = BertConfig()
attention_layer = BertSelfAttention(config)
new_attention = NewIPEXBertSelfAttention(attention_layer, config).to('xpu')

batch_size = 2
seq_length = 10
hidden_size = config.hidden_size
hidden_states = torch.rand(batch_size, seq_length, hidden_size).to('xpu')  # Random input tensor
attention_mask = torch.ones(batch_size, 1, 1, seq_length).to('xpu')  # No masking for simplicity

# Forward pass
outputs = new_attention(
    hidden_states=hidden_states,
    attention_mask=attention_mask,
    head_mask=None,
    encoder_hidden_states=None,
    encoder_attention_mask=None,
    past_key_value=None,
    output_attentions=False,
)

# Print the outputs
print("Output shape:", outputs[0].shape)  # Should be (batch_size, seq_length, hidden_size)
print(outputs)
if len(outputs) > 1:
    print("Past key value shape:", [pkv.shape for pkv in outputs[1]])

# Compile the new_attention module using torch.compile
compiled_attention = torch.compile(new_attention)

# Forward pass with compiled module
compiled_outputs = compiled_attention(
    hidden_states=hidden_states,
    attention_mask=attention_mask,
    head_mask=None,
    encoder_hidden_states=None,
    encoder_attention_mask=None,
    past_key_value=None,
    output_attentions=False,
)

# Print the outputs from the compiled module
print("Compiled output shape:", compiled_outputs[0].shape)  # Should be (batch_size, seq_length, hidden_size)
print(compiled_outputs)
if len(compiled_outputs) > 1:
    print("Compiled past key value shape:", [pkv.shape for pkv in compiled_outputs[1]])
I got the following error:
Traceback (most recent call last):
File "/workspace/newipexatten.py", line 57, in <module>
compiled_outputs = compiled_attention(
File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 465, in _fn
return fn(*args, **kwargs)
File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 1269, in __call__
return self._torchdynamo_orig_callable(
File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 1064, in __call__
result = self._inner_convert(
File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 526, in __call__
return _compile(
File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 924, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 666, in compile_inner
return _compile_inner(code, one_graph, hooks, transform)
File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_utils_internal.py", line 87, in wrapper_function
return function(*args, **kwargs)
File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 699, in _compile_inner
out_code = transform_code_object(code, transform)
File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1322, in transform_code_object
transformations(instructions, code_options)
File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 219, in _fn
return fn(*args, **kwargs)
File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 634, in transform
tracer.run()
File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2796, in run
super().run()
File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
while self.step():
File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
self.dispatch_table[inst.opcode](self, inst)
File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 582, in wrapper
return inner_fn(self, inst)
File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1692, in CALL_FUNCTION_KW
self.call_function(fn, args, kwargs)
File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 830, in call_function
self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type]
File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 324, in call_function
return super().call_function(tx, args, kwargs)
File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 111, in call_function
return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 836, in inline_user_function_return
return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3011, in inline_call
return cls.inline_call_(parent, func, args, kwargs)
File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3139, in inline_call_
tracer.run()
File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
while self.step():
File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
self.dispatch_table[inst.opcode](self, inst)
File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 582, in wrapper
return inner_fn(self, inst)
File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1602, in CALL_FUNCTION
self.call_function(fn, args, {})
File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 830, in call_function
self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type]
File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_dynamo/variables/torch.py", line 897, in call_function
tensor_variable = wrap_fx_proxy(
File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_dynamo/variables/builder.py", line 2037, in wrap_fx_proxy
return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_dynamo/variables/builder.py", line 2124, in wrap_fx_proxy_cls
example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 2082, in get_fake_value
raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None
File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 2017, in get_fake_value
ret_val = wrap_fake_exception(
File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 1574, in wrap_fake_exception
return fn()
File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 2018, in <lambda>
lambda: run_node(tx.output, node, args, kwargs, nnmodule)
File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 2150, in run_node
raise RuntimeError(make_error_message(e)).with_traceback(
File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 2132, in run_node
return node.target(*args, **kwargs)
File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_ops.py", line 1116, in __call__
return self._op(*args, **(kwargs or {}))
torch._dynamo.exc.TorchRuntimeError: Failed running call_function torch_ipex.xetla_sdp_dropout(*(FakeTensor(..., device='xpu:0', size=(2, 12, 10, 64),
grad_fn=<PermuteBackward0>), FakeTensor(..., device='xpu:0', size=(2, 12, 10, 64),
grad_fn=<PermuteBackward0>), FakeTensor(..., device='xpu:0', size=(2, 12, 10, 64),
grad_fn=<PermuteBackward0>), FakeTensor(..., device='xpu:0', size=(2, 1, 1, 10)), 0.1, False, None), **{}):
Cannot access data pointer of Tensor (e.g. FakeTensor, FunctionalTensor). If you're using torch.compile/export/fx, it is likely that we are erroneously tracing into a custom kernel. To fix this, please wrap the custom kernel into an opaque custom op. Please see the following for details: https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html
from user code:
File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/intel_extension_for_pytorch/transformers/models/xpu/optimize_transformers/modules/bert.py", line 103, in forward
context_layer = torch.xpu.IpexSDP_dropout(
File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/intel_extension_for_pytorch/xpu/intrinsic/__init__.py", line 163, in IpexSDP_dropout
return torch.ops.torch_ipex.xetla_sdp_dropout(
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
I followed the "Adding torch.compile support for an operator" guide to add a FakeTensor kernel for torch.ops.torch_ipex.xetla_sdp_dropout, and used torch.library.opcheck to get the detailed failure reason, as shown below.
import torch
from torch import Tensor
import intel_extension_for_pytorch as ipex

@torch.library.register_fake("torch_ipex::xetla_sdp_dropout")
def _(query, key, value, attn_mask, dropout_p, is_causal, scale):
    # Python-side input checks (torch._check expects a callable message)
    torch._check(query.shape == key.shape == value.shape,
                 lambda: "query, key, and value must have the same shape")
    torch._check(query.dtype == key.dtype == value.dtype == torch.float16,
                 lambda: "all inputs must be float16")
    torch._check(query.device == key.device == value.device == attn_mask.device,
                 lambda: "all inputs must be on the same device")
    # Provide a fake output with the same shape/dtype/device as query
    return torch.empty_like(query)

sample_inputs = [
    (torch.randn(2, 12, 10, 64, dtype=torch.float16, device="xpu"),
     torch.randn(2, 12, 10, 64, dtype=torch.float16, device="xpu"),
     torch.randn(2, 12, 10, 64, dtype=torch.float16, device="xpu"),
     torch.ones(2, 1, 1, 10, dtype=torch.float16, device="xpu"),
     0.0, False, None)
]

for args in sample_inputs:
    torch.library.opcheck(torch.ops.torch_ipex.xetla_sdp_dropout, args, test_utils='test_aot_dispatch_static')
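For a per-test breakdown, opcheck can also be run with its default test suite; it returns a dict mapping each test to a status (a small sketch, using the same sample_inputs as above):

# Run the full default opcheck suite (schema, autograd registration,
# fake tensor, AOT dispatch); the return value maps test name -> status.
for args in sample_inputs:
    results = torch.library.opcheck(torch.ops.torch_ipex.xetla_sdp_dropout, args)
    print(results)  # e.g. {'test_schema': 'SUCCESS', ...}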
The code fails on test_aot_dispatch_static and test_aot_dispatch_dynamic. I had a close look at the source code and found the problem is in OpOverload: it returns None, which causes the failure. It seems the problem is that FunctionalTensor is not supported even though I have already declared the FakeTensor kernel. The error is as follows:
Traceback (most recent call last):
File "/root/miniforge3/envs/py310/lib/python3.10/runpy.py", line 196, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/root/miniforge3/envs/py310/lib/python3.10/runpy.py", line 86, in _run_code
exec(code, run_globals)
File "/root/.vscode-server/extensions/ms-python.debugpy-2024.14.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/__main__.py", line 71, in <module>
cli.main()
File "/root/.vscode-server/extensions/ms-python.debugpy-2024.14.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 501, in main
run()
File "/root/.vscode-server/extensions/ms-python.debugpy-2024.14.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 351, in run_file
runpy.run_path(target, run_name="__main__")
File "/root/.vscode-server/extensions/ms-python.debugpy-2024.14.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 310, in run_path
return _run_module_code(code, init_globals, run_name, pkg_name=pkg_name, script_name=fname)
File "/root/.vscode-server/extensions/ms-python.debugpy-2024.14.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 127, in _run_module_code
_run_code(code, mod_globals, init_globals, mod_name, mod_spec, pkg_name, script_name)
File "/root/.vscode-server/extensions/ms-python.debugpy-2024.14.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 118, in _run_code
exec(code, run_globals)
File "/workspace/opcheck.py", line 48, in <module>
torch.library.opcheck(torch.ops.torch_ipex.xetla_sdp_dropout, args, test_utils='test_aot_dispatch_static')
File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/library.py", line 1322, in opcheck
return optests.opcheck(
File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/testing/_internal/optests/generate_tests.py", line 657, in opcheck
tester(op, args, kwargs)
File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/testing/_internal/optests/generate_tests.py", line 114, in safe_aot_autograd_check
return aot_autograd_check(func, args, kwargs, dynamic, check_gradients="auto")
File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/testing/_internal/optests/aot_autograd.py", line 75, in aot_autograd_check
compiled_out = wrapper_set_seed(compiled_f, args)
File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/testing/_utils.py", line 18, in wrapper_set_seed
output = op(*args, **kwargs)
File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 868, in returned_function
compiled_fn, _ = create_aot_dispatcher_function(
File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 522, in create_aot_dispatcher_function
return _create_aot_dispatcher_function(
File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 623, in _create_aot_dispatcher_function
fw_metadata = run_functionalized_fw_and_collect_metadata(
File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/collect_metadata_analysis.py", line 173, in inner
flat_f_outs = f(*flat_f_args)
File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py", line 182, in flat_fn
tree_out = fn(*args, **kwargs)
File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/testing/_internal/optests/aot_autograd.py", line 64, in func_no_tensors
return func(*c_args, **c_kwargs)
File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/testing/_internal/optests/generate_tests.py", line 110, in func
return op(*args, **kwargs)
File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_ops.py", line 716, in __call__
return self._op(*args, **kwargs)
RuntimeError: Cannot access data pointer of Tensor (e.g. FakeTensor, FunctionalTensor). If you're using torch.compile/export/fx, it is likely that we are erroneously tracing into a custom kernel. To fix this, please wrap the custom kernel into an opaque custom op. Please see the following for details: https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html
Any tips on how to solve this problem would be really helpful!
My system configuration is as follows:
PyTorch version: 2.5.1+cxx11.abi
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: Ubuntu 22.04.5 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.35
Python version: 3.10.16 | packaged by conda-forge | (main, Dec 5 2024, 14:16:10) [GCC 13.3.0] (64-bit runtime)
Python platform: Linux-5.15.0-118-generic-x86_64-with-glibc2.35
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 52 bits physical, 57 bits virtual
Byte Order: Little Endian
CPU(s): 35
On-line CPU(s) list: 0-34
Vendor ID: GenuineIntel
Model name: Intel(R) Xeon(R) Platinum 8480+
CPU family: 6
Model: 143
Thread(s) per core: 1
Core(s) per socket: 35
Socket(s): 1
Stepping: 8
BogoMIPS: 4000.00
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl xtopology cpuid tsc_known_freq pni pclmulqdq vmx ssse3 fma cx16 pdcm pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch cpuid_fault invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves avx_vnni avx512_bf16 wbnoinvd arat avx512vbmi umip pku ospke waitpkg avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid bus_lock_detect cldemote movdiri movdir64b fsrm md_clear serialize tsxldtrk avx512_fp16 arch_capabilities
Virtualization: VT-x
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 1.1 MiB (35 instances)
L1i cache: 1.1 MiB (35 instances)
L2 cache: 140 MiB (35 instances)
L3 cache: 16 MiB (1 instance)
NUMA node(s): 1
NUMA node0 CPU(s): 0-34
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Unknown: No mitigations
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced / Automatic IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS SW sequence; BHI SW loop, KVM SW loop
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Versions of relevant libraries:
[pip3] intel_extension_for_pytorch==2.5.10+xpu
[pip3] numpy==1.26.4
[pip3] pytorch-triton-xpu==3.1.0+91b14bf559
[pip3] torch==2.5.1+cxx11.abi
[pip3] triton==3.2.0+git6ee08cd2
[conda] intel-extension-for-pytorch 2.5.10+xpu pypi_0 pypi
[conda] mkl 2025.0.1 pypi_0 pypi
[conda] mkl-dpcpp 2025.0.1 pypi_0 pypi
[conda] numpy 1.26.4 pypi_0 pypi
[conda] onemkl-sycl-blas 2025.0.1 pypi_0 pypi
[conda] onemkl-sycl-datafitting 2025.0.1 pypi_0 pypi
[conda] onemkl-sycl-dft 2025.0.1 pypi_0 pypi
[conda] onemkl-sycl-lapack 2025.0.1 pypi_0 pypi
[conda] onemkl-sycl-rng 2025.0.1 pypi_0 pypi
[conda] onemkl-sycl-sparse 2025.0.1 pypi_0 pypi
[conda] onemkl-sycl-stats 2025.0.1 pypi_0 pypi
[conda] onemkl-sycl-vm 2025.0.1 pypi_0 pypi
[conda] pytorch-triton-xpu 3.1.0+91b14bf559 pypi_0 pypi
[conda] torch 2.5.1+cxx11.abi pypi_0 pypi
[conda] triton 3.2.0+git6ee08cd2 pypi_0 pypi
IPEX version: 2.5.10+xpu
IPEX commit: Unknown
@ZhaoqiongZ , any update?
Hi @EikanWang, I am passing this issue to Su, Tong, since it involves the details of the torch.compile feature.
Hi @lostkingdom4, thanks for the detailed reproducer! I did a check; the problem is that the registered meta function's schema does not match the exact custom op. Actually, if you change your registration as below:
@torch.library.register_fake("torch_ipex::xetla_sdp_dropout")
def xetla_sdp_dropout(query, key, value, attn_mask, dropout_p, is_causal, scale):
    print("run into fake")
    assert False
You will find that it still does not hit the assert False, which means fake tensor tracing never actually reaches this registration.
After further investigation, I found that the custom op is registered under the c10::DispatchKey::AutogradXPU dispatch key.
https://github.com/intel/intel-extension-for-pytorch/blob/033af6f63745ac748cccdadee5c6140c7971edf6/csrc/gpu/aten/operators/transformers/attention.cpp#L2771
Because of this, the schema cannot be found correctly.
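You can see this from the Python side without reading the C++: dumping the dispatcher state shows which keys have kernels registered (a diagnostic sketch; look for an AutogradXPU entry and the absence of a plain XPU one):

import torch
import intel_extension_for_pytorch as ipex  # loads the torch_ipex op library

# Prints every kernel registration for the op, keyed by dispatch key.
print(torch._C._dispatch_dump("torch_ipex::xetla_sdp_dropout"))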
A quick and temporary solution is to change c10::DispatchKey::AutogradXPU to c10::DispatchKey::XPU and rebuild IPEX. Then your fake tensor registration would work.
This is not a perfect fix, just a temporary one, but normally it won't affect performance or accuracy much. We will try to fix it properly later. Thanks again for your patience!
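After rebuilding, one quick way to confirm the fake registration is actually consulted (a sketch reusing the shapes from your reproducer, and assuming your register_fake from above is in effect) is to run the op under FakeTensorMode:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode
import intel_extension_for_pytorch as ipex

with FakeTensorMode():
    # Factory calls inside the mode produce FakeTensors directly.
    q = torch.empty(2, 12, 10, 64, dtype=torch.float16, device="xpu")
    k = torch.empty_like(q)
    v = torch.empty_like(q)
    mask = torch.empty(2, 1, 1, 10, dtype=torch.float16, device="xpu")
    # Should hit the registered fake kernel instead of the real XeTLA one.
    out = torch.ops.torch_ipex.xetla_sdp_dropout(q, k, v, mask, 0.1, False, None)
    print(type(out), out.shape)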
Hi @Stonepia Thanks for the feedback. It works for me!
Hi @Stonepia, I was trying to do the same thing for torch.ops.torch_ipex.mm_qkv_out(input, self.weight, self.bias, q, k, v), with a fake registration as an example:
@torch.library.register_fake("torch_ipex::mm_qkv_out.xpu")
def _(query, key, value, attn_mask, dropout_p, is_causal, scale):
    print("run into fake")
    assert False
I got
RuntimeError: register_fake(...): the operator torch_ipex::mm_qkv_out.xpu already has an implementation for this device type via a pre-existing registration to DispatchKey::CompositeImplicitAutograd. CompositeImplicitAutograd operators do not need an fake impl; instead, the operator will decompose into its constituents and those can have fake impls defined on them.
I have already modified the dispatch key but I think this might not be the problem. Is there a quick fix for this type of operator? I went through the C++ source code but I'm still not entirely sure where it is registered as CompositeImplicitAutograd.
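For reference, dumping the dispatcher state for this op is one way to see where the CompositeImplicitAutograd kernel comes from (a sketch; the name may need the overload suffix, e.g. "torch_ipex::mm_qkv_out.xpu", which I am not certain about):

import torch
import intel_extension_for_pytorch as ipex

# If a CompositeImplicitAutograd entry shows up in this dump, that is the
# pre-existing registration the error message refers to.
print(torch._C._dispatch_dump("torch_ipex::mm_qkv_out"))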
I tried to solve the problem by modifying the code in csrc from
IPEX_OP_REGISTER("mm_qkv_out.xpu", at::AtenIpexTypeXPU::mm_qkv_out);
IPEX_OP_REGISTER_DISPATCH(
"mm_qkv_out.xpu",
at::AtenIpexTypeXPU::mm_qkv_out_autocast,
c10::DispatchKey::AutocastXPU);
to
// IPEX_OP_REGISTER("mm_qkv_out", at::AtenIpexTypeXPU::mm_qkv_out);
IPEX_OP_REGISTER_DISPATCH(
    "mm_qkv_out",
    at::AtenIpexTypeXPU::mm_qkv_out,
    c10::DispatchKey::XPU);
The problem with the fake tensor seems to be solved. However, when running with torch.compile, Dynamo causes a graph break on this operator. Could you double-check whether this works for operators with a pre-existing registration to DispatchKey::CompositeImplicitAutograd?
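One way to locate the break is to ask Dynamo for the break reasons directly (a sketch reusing new_attention, hidden_states, and attention_mask from the reproducer at the top of this thread, so it is not self-contained):

import torch
import torch._dynamo as dynamo

# Reports graph count, break count, and the reason for each graph break.
explanation = dynamo.explain(new_attention)(hidden_states, attention_mask)
print(explanation)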
Meanwhile, what is the difference between IPEX_OP_REGISTER and IPEX_OP_REGISTER_DISPATCH? And why do we put .xpu at the end of mm_qkv_out? I know it is for overloading on XPU, but is this necessary?
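One way to see which overloads exist from the Python side (a sketch; the printed list is whatever IPEX actually registered):

import torch
import intel_extension_for_pytorch as ipex

packet = torch.ops.torch_ipex.mm_qkv_out  # OpOverloadPacket
# Lists overload names, e.g. ['xpu']; a specific overload is then
# addressed as packet.xpu, analogous to torch.ops.aten.add.Tensor.
print(packet.overloads())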
Thanks
@Stonepia Is there any update on this?
Thanks
Hi @lostkingdom4, I suspect this is a bug in PyTorch, not an issue with your implementation. We haven't found it before because no one has tried this path (custom ops registered under another dispatch key). So the next step is to write a reproducer and submit it as a PyTorch issue.
Apologies that I haven't had the bandwidth for this yet. I will update you once I have new findings.
@Stonepia
Thanks for replying, and thanks for your effort on this matter.
Essentially, I want torch.compile to understand the operator via a simple fake tensor registration, so that if the operator is significantly faster than what Triton generates, we can use it instead.
I also tried to understand the problem by going through the source code of IPEX_OP_REGISTER, IPEX_OP_REGISTER_DISPATCH, and TORCH_LIBRARY_IMPL. I think the problem is that once the operator is registered using IPEX_OP_REGISTER, the dispatch key is registered as CompositeImplicitAutograd. I've tried to use IPEX_OP_REGISTER_DISPATCH to register it to a specific dispatch key and rebuild, but the build is always unsuccessful.
@Stonepia Hi, I accidentally marked this thread as closed. Can you reopen it?
Hi @lostkingdom4, yes, that's why I said the problem is not your implementation; it should be a PyTorch-side issue (not even an IPEX-side one). My plan was to write a reproducer in pure PyTorch code, but I haven't gotten the chance 😿 If you are interested, there are two manuals that would help:
- torch.compile : https://docs.google.com/document/d/1y5CRfMLdwEoF1nTk9q8qEu1mgMUuUtvhklPKJ2emLU8/edit?tab=t.0
- custom operators: https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit?tab=t.0
@Stonepia
Since I basically need to rebuild the entire PyTorch and IPEX every time I try a new registration, it would be really helpful if you could give me some tips on building just the operator without rebuilding the entire system, so I can iterate on this problem more conveniently. I might have a way to solve the problem, but building IPEX from scratch is killing me.
I don't think you need to rebuild every time. You could first try the Python op registration; it should behave the same as the C++ side. You don't need to build PyTorch either.
I suggest starting from a simpler custom op with Python registration to see if everything goes well, then moving to the harder one (the fused ops in IPEX).
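For example, a minimal sketch of that flow (mylib::toy_sdp is a made-up name for illustration; it runs on CPU so it works anywhere):

import torch

# Define a toy custom op purely in Python...
@torch.library.custom_op("mylib::toy_sdp", mutates_args=())
def toy_sdp(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # Eager implementation (stands in for the real fused kernel).
    return torch.softmax(q @ k.transpose(-2, -1), dim=-1) @ v

# ...give it a fake impl (only shapes/dtypes matter, no data access)...
@toy_sdp.register_fake
def _(q, k, v):
    return torch.empty_like(q)

# ...and check it compiles without graph breaks.
@torch.compile(fullgraph=True)  # fullgraph=True surfaces any graph break
def f(q, k, v):
    return toy_sdp(q, k, v)

q = k = v = torch.randn(2, 4, 8, 16)
print(f(q, k, v).shape)

If this works, the same registration pattern can then be pointed at the IPEX op to isolate whether the problem is the op's C++ registration or the Python side.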
I see what you mean. I've already tried the custom operator registration. It works. The only thing I'm not so sure about is these headers. For example: https://github.com/intel/intel-extension-for-pytorch/blob/5b268a58047430005f23023a0c8dbf55882c50c8/csrc/gpu/aten/operators/XeGemm.cpp#L1-L14
But I think I will try to build it within the repository so I don't need to worry about the path. I will get back to you after I have some findings.
Again, thanks for the information.