numba-dpex

SPIR-V flags are not set for the second call of the atomic `fetch_add`

Open · ZzEeKkAa opened this issue 1 year ago · 2 comments

When two different kernels that call the same overloaded function (e.g., `fetch_add` in the reproducer) are compiled, the extra compilation flags needed for the llvm-spirv translation are applied only to the kernel that is compiled first.

For the first kernel, the `fetch_add` overload is not yet in the compilation cache, so it is compiled, and during that compilation the intrinsic adds the extra flags to the target context's compilation options. For the second kernel, a compiled version of `fetch_add` is already available in the dispatcher's overload cache, so the intrinsic is not invoked. As a result, the extra compilation flags are never set again, and an internal compiler error is raised during llvm-spirv translation.
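The failure mode can be modeled in plain Python as a compile-time side effect hidden behind a cache. The sketch below is hypothetical: `TargetContext`, `OverloadDispatcher`, `compile_kernel`, and the `--example-spirv-flag` string are illustrative stand-ins, not numba-dpex classes or real llvm-spirv flags. It only shows why the second kernel ends up without the flag.

```python
from dataclasses import dataclass, field


@dataclass
class TargetContext:
    # Hypothetical stand-in: extra flags to pass to the SPIR-V translator.
    extra_compile_options: set = field(default_factory=set)


class OverloadDispatcher:
    """Hypothetical dispatcher that memoizes compiled overloads."""

    def __init__(self, name, required_flags):
        self.name = name
        self.required_flags = set(required_flags)
        self._cache = {}  # signature -> compiled artifact

    def compile(self, signature, context):
        if signature in self._cache:
            # Cache hit: the intrinsic body is skipped, so the side effect
            # on `context` below never happens for this kernel.
            return self._cache[signature]
        # Cache miss: "running the intrinsic" mutates the target context.
        context.extra_compile_options.update(self.required_flags)
        compiled = f"{self.name}{signature}"
        self._cache[signature] = compiled
        return compiled


def compile_kernel(kernel_name, overload, signature):
    # A fresh context per kernel models the flags being cleared once the
    # previous kernel's module has been finalized.
    context = TargetContext()
    overload.compile(signature, context)
    return kernel_name, set(context.extra_compile_options)


fetch_add = OverloadDispatcher("fetch_add", {"--example-spirv-flag"})
_, flags0 = compile_kernel("atomic_ref_0", fetch_add, ("float32",))
_, flags1 = compile_kernel("atomic_ref_1", fetch_add, ("float32",))
print(flags0)  # {'--example-spirv-flag'}
print(flags1)  # set()
```

The second kernel compiles against the cached overload and silently loses the flag, which matches the behavior described above.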

A minimal reproducer:

```python
import dpnp
import numba_dpex as dpex
import numba_dpex.experimental as dpex_exp
# Import path of AtomicRef in the experimental API at the time of this issue;
# it may have moved in later releases.
from numba_dpex.experimental.kernel_iface import AtomicRef


@dpex_exp.kernel
def atomic_ref_0(a):
    i = dpex.get_global_id(0)
    v = AtomicRef(a, index=0)
    v.fetch_add(a[i + 2])


@dpex_exp.kernel
def atomic_ref_1(a):
    i = dpex.get_global_id(0)
    v = AtomicRef(a, index=1)
    v.fetch_add(a[i + 2])


def test_spirv_flags():
    N = 10
    a = dpnp.ones(N, dtype=dpnp.float32)

    # Flags are set here because the fetch_add intrinsic is not cached yet.
    dpex_exp.call_kernel(atomic_ref_0, dpex.Range(N - 2), a)
    # The SPIR-V flags are then removed in `spirv_generator.Module.finalize`.
    # The intrinsic is already compiled and cached, so the flags are not set.
    dpex_exp.call_kernel(atomic_ref_1, dpex.Range(N - 2), a)

    assert a[0] == N - 1
    assert a[1] == N - 1
```

I did some investigation, and the old-style kernel API works. My guess is that the old style uses a lowering function instead of an intrinsic.

To reproduce, set `llvm_spirv_args` to an empty list in `spirv_generator.py` (search for the reference to this issue).

ZzEeKkAa, Dec 28 '23 19:12

@ZzEeKkAa thank you for the reproducer. The issue with the overload PR is now clear to me.

The issue happens because overloads are not compiled to SPIR-V on their own. We run the SPIR-V compilation only for kernel functions, after all overloads that were compiled to LLVM have been linked into the kernel function at the LLVM bitcode level.

As a solution, when we compile an overload such as `fetch_add`, any extra compilation flags should be stored as part of the `CompileResult` for that overload. When the dispatcher finalizes a kernel's code library, it should gather the extra compilation flags of every library linked into the kernel's final code library. llvm-spirv can then be invoked with all the flags it needs.
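A minimal sketch of the proposed fix, assuming simplified stand-ins: `CompileResult` here is a toy dataclass rather than numba's real class, and `KernelLibrary`, `link`, and `finalize` are hypothetical names. The key change is that the flags live on the cached overload's result, so reusing that result from the cache still carries them to every kernel that links it.

```python
from dataclasses import dataclass, field


@dataclass(frozen=True)
class CompileResult:
    # Simplified stand-in: the overload's flags travel with its cached result.
    name: str
    extra_spirv_flags: frozenset = frozenset()


@dataclass
class KernelLibrary:
    """Hypothetical kernel code library that links compiled overloads."""

    name: str
    own_flags: frozenset = frozenset()
    linked: list = field(default_factory=list)

    def link(self, result):
        self.linked.append(result)

    def finalize(self):
        # Union the kernel's own flags with those of every linked overload,
        # regardless of whether the overload came from the cache.
        flags = set(self.own_flags)
        for result in self.linked:
            flags |= result.extra_spirv_flags
        return sorted(flags)  # e.g. the argument list handed to llvm-spirv


# One cached result for fetch_add, reused by two kernels.
fetch_add = CompileResult("fetch_add", frozenset({"--example-spirv-flag"}))
k0 = KernelLibrary("atomic_ref_0")
k0.link(fetch_add)
k1 = KernelLibrary("atomic_ref_1")
k1.link(fetch_add)
print(k0.finalize())  # ['--example-spirv-flag']
print(k1.finalize())  # ['--example-spirv-flag']
```

Because the flags are read at finalization time instead of being pushed into a shared context during compilation, the order in which kernels are compiled no longer matters.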

diptorupd, Dec 28 '23 19:12

@diptorupd I like the idea. I know it will depend on the implementation, but we need to keep those flags propagated through any wrapping overload. In other words, if an overload calls another overload that has compilation flags, those flags must also be stored on the caller-level overload.
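One way to get the transitive behavior described above is to fold each callee's flags into the caller when the caller is compiled. The sketch below is hypothetical (`CompiledOverload`, `compile_overload`, `my_atomic_helper`, and `--example-spirv-flag` are illustrative names, not numba-dpex APIs). It relies on callees being compiled before their callers, which is what makes a single union per level transitive.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class CompiledOverload:
    name: str
    # Flags this overload needs, including those inherited from its callees.
    flags: frozenset


def compile_overload(name, own_flags=(), callees=()):
    """Hypothetical compile step that propagates callee flags to the caller.

    Each callee's `flags` already includes everything it inherited, so one
    union per call level carries flags all the way up to the kernel.
    """
    flags = frozenset(own_flags)
    for callee in callees:
        flags |= callee.flags
    return CompiledOverload(name, flags)


fetch_add = compile_overload("fetch_add", {"--example-spirv-flag"})
# A hypothetical wrapper overload that calls fetch_add internally.
wrapper = compile_overload("my_atomic_helper", callees=[fetch_add])
kernel = compile_overload("atomic_ref_2", callees=[wrapper])
print(sorted(kernel.flags))  # ['--example-spirv-flag']
```

An alternative is to leave flags only on the overload that needs them and walk the link graph at finalization; eager propagation keeps finalization trivial at the cost of duplicating flag sets along each call chain.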

ZzEeKkAa, Jan 05 '24 22:01