numba-dpex
SPIRV flags are not set for the second call of atomic `fetch_add`
When two different kernels that call the same overloaded function (e.g., `fetch_add` in the reproducer) are compiled, the extra compilation flags needed for llvm-spirv translation are applied only to the kernel that is compiled first.
For the first kernel, the `fetch_add` overload is not yet in the compiled cache, so it is compiled, and during that compilation the `intrinsic` function adds the extra flags to the target context's compilation options. The next time, a compiled version of `fetch_add` is already available in the dispatcher's `overload` cache, so the `intrinsic` is not invoked. Thus, the extra compilation flags are never updated for the second kernel, and an internal compiler error is raised during llvm-spirv translation.
A minimal reproducer:
@dpex_exp.kernel
def atomic_ref_0(a):
    i = dpex.get_global_id(0)
    v = AtomicRef(a, index=0)
    v.fetch_add(a[i + 2])

@dpex_exp.kernel
def atomic_ref_1(a):
    i = dpex.get_global_id(0)
    v = AtomicRef(a, index=1)
    v.fetch_add(a[i + 2])

def test_spirv_flags():
    N = 10
    a = dpnp.ones(N, dtype=dpnp.float32)

    # Flags are set here because the intrinsic is not cached yet.
    dpex_exp.call_kernel(atomic_ref_0, dpex.Range(N - 2), a)
    # The SPIRV flags are then removed in `spirv_generator.Module.finalize`.

    # The intrinsic is already compiled and cached, so the flags are not set.
    dpex_exp.call_kernel(atomic_ref_1, dpex.Range(N - 2), a)

    assert a[0] == N - 1
    assert a[1] == N - 1
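The caching behaviour described above can be modelled without numba-dpex at all. The sketch below is a toy model only: `ToyTargetContext`, `ToyDispatcher`, `compile_kernel`, and the flag string are hypothetical names made up for illustration, not numba-dpex or llvm-spirv APIs. It assumes that the flags are added as a side effect of compiling the overload and are cleared once the kernel has been translated.

```python
# Toy model of the bug: every name here is hypothetical, not a numba-dpex API.


class ToyTargetContext:
    """Stands in for the target context holding extra llvm-spirv options."""

    def __init__(self):
        self.extra_compile_options = []


class ToyDispatcher:
    """Caches a compiled overload; the flag side effect runs only on a miss."""

    def __init__(self, context, flags):
        self._context = context
        self._flags = flags
        self._cache = {}

    def get_overload(self, name):
        if name not in self._cache:
            # Cache miss: "compile" the overload. The intrinsic would run
            # here and mutate the shared context as a side effect.
            self._context.extra_compile_options.extend(self._flags)
            self._cache[name] = f"<compiled {name}>"
        # Cache hit: nothing touches the context.
        return self._cache[name]


def compile_kernel(context, dispatcher, overload_name):
    """Compile one kernel and return the flags seen by the SPIR-V step."""
    dispatcher.get_overload(overload_name)
    flags_used = list(context.extra_compile_options)
    # Assumed to mirror the reset that happens after translation.
    context.extra_compile_options.clear()
    return flags_used


ctx = ToyTargetContext()
disp = ToyDispatcher(ctx, ["--hypothetical-atomics-flag"])

first = compile_kernel(ctx, disp, "fetch_add")
second = compile_kernel(ctx, disp, "fetch_add")
print(first)   # ['--hypothetical-atomics-flag']
print(second)  # []
```

The second kernel sees an empty flag list because the only code that sets the flags lives on the cache-miss path.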
I did some investigation, and it works with the old style. My guess is that the old style uses `lower` instead of `intrinsic`.
To reproduce, reset `llvm_spirv_args` to an empty list in `spirv_generator.py`. Search that file for the reference to this issue.
@ZzEeKkAa thank you for the reproducer. The issue with the overload PR is now clear to me.
The issue happens because overloads are not compiled to SPIR-V. We do SPIR-V compilation only for kernel functions after all overloads compiled to LLVM are linked to the kernel function at the level of LLVM bitcode.
As a solution, when we compile an `overload`, e.g., `fetch_add`, any extra compilation flags should be stored as part of the `CompileResult` for that overload. When a kernel code library is finalized by the dispatcher, the process should gather the extra compilation flags from every library that is linked into the final code library for the kernel. llvm-spirv would then be invoked with all the needed flags.
@diptorupd I like the idea. It will depend on the implementation, but we need to keep those flags propagated to any other overload wrapper. That is, if an overload calls another overload that carries compilation flags, those flags also need to be stored on the caller-level overload.
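One way to picture the proposal, including the transitive case raised in the reply, is a minimal sketch in plain Python. It assumes each compile result records its own flags and the results linked into it; `ToyCompileResult`, `gather_flags`, and the flag string are hypothetical stand-ins, and the real `CompileResult` and code library types in numba-dpex may look quite different.

```python
from dataclasses import dataclass, field


@dataclass
class ToyCompileResult:
    """Hypothetical stand-in for a compile result that carries its flags."""

    name: str
    extra_flags: list = field(default_factory=list)  # flags this result needs
    callees: list = field(default_factory=list)      # results linked into it


def gather_flags(result, _seen=None):
    """Collect flags from a result and everything linked into it.

    Order is preserved, duplicates are dropped, and each result is
    visited at most once.
    """
    seen = set() if _seen is None else _seen
    flags = []
    if id(result) in seen:
        return flags
    seen.add(id(result))
    for flag in result.extra_flags:
        if flag not in flags:
            flags.append(flag)
    for callee in result.callees:
        for flag in gather_flags(callee, seen):
            if flag not in flags:
                flags.append(flag)
    return flags


# fetch_add is "compiled" once and remembers the flags it needs.
fetch_add = ToyCompileResult("fetch_add", ["--hypothetical-atomics-flag"])
# A wrapper overload that calls fetch_add but has no flags of its own.
wrapper = ToyCompileResult("wrapper", callees=[fetch_add])

kernel_0 = ToyCompileResult("atomic_ref_0", callees=[fetch_add])
kernel_1 = ToyCompileResult("atomic_ref_1", callees=[wrapper])

flags_0 = gather_flags(kernel_0)
flags_1 = gather_flags(kernel_1)
print(flags_0)  # ['--hypothetical-atomics-flag']
print(flags_1)  # ['--hypothetical-atomics-flag']
```

Because the flags travel with the cached result rather than with the act of compiling it, both kernels recover them at finalization time, whether `fetch_add` is called directly or through a wrapper.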