numba-dpex

numba.literally doesn't work in kernels

Open AlexanderKalistratov opened this issue 1 year ago • 2 comments

Reproducer:

import dpnp as np

import numba_dpex as ndpx
from numba import literally


@ndpx.kernel
def lit(a, n):
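    literally(n)  # ask Numba to type n as a compile-time literal
    local_a_0 = ndpx.private.array(n, dtype=a.dtype)  # private array size must be a compile-time constant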
    lid = ndpx.get_local_id(0)

    if lid == 0:
        result = a.dtype.type(0)
        for i in range(n):
            local_a_0[i] = a[i]

        for i in range(n):
            a[0] += local_a_0[i]


a = np.ones(10, dtype='float32')

lit[a.shape[0], ](a, a.shape[0])

print("Result: ", a[0])

Actual result:

/home/akalistr/repo/literally.py:28: DeprecationWarning: The current syntax for specification of kernel launch parameters is deprecated. Users should set the kernel parameters through Range/NdRange classes.
Example:
    from numba_dpex import Range,NdRange

    # for global range only
    <function>[Range(X,Y)](<parameters>)
    # or,
    # for both global and local ranges
    <function>[NdRange((X,Y), (P,Q))](<parameters>)
  lit[a.shape[0], ](a, a.shape[0])
Traceback (most recent call last):
  File "/home/akalistr/repo/literally.py", line 28, in <module>
    lit[a.shape[0], ](a, a.shape[0])
  File "/home/akalistr/miniconda3/envs/dpbench-dev/lib/python3.10/site-packages/numba_dpex/core/kernel_interface/dispatcher.py", line 455, in __call__
    ) = self._compile_and_cache(
  File "/home/akalistr/miniconda3/envs/dpbench-dev/lib/python3.10/site-packages/numba_dpex/core/kernel_interface/dispatcher.py", line 141, in _compile_and_cache
    kernel.compile(
  File "/home/akalistr/miniconda3/envs/dpbench-dev/lib/python3.10/site-packages/numba_dpex/core/kernel_interface/spirv_kernel.py", line 123, in compile
    cres = compile_with_dpex(
  File "/home/akalistr/miniconda3/envs/dpbench-dev/lib/python3.10/site-packages/numba/core/compiler_lock.py", line 35, in _acquire_compile_lock
    return func(*args, **kwargs)
  File "/home/akalistr/miniconda3/envs/dpbench-dev/lib/python3.10/site-packages/numba_dpex/core/compiler.py", line 64, in compile_with_dpex
    cres = compiler.compile_extra(
  File "/home/akalistr/miniconda3/envs/dpbench-dev/lib/python3.10/site-packages/numba/core/compiler.py", line 742, in compile_extra
    return pipeline.compile_extra(func)
  File "/home/akalistr/miniconda3/envs/dpbench-dev/lib/python3.10/site-packages/numba/core/compiler.py", line 460, in compile_extra
    return self._compile_bytecode()
  File "/home/akalistr/miniconda3/envs/dpbench-dev/lib/python3.10/site-packages/numba/core/compiler.py", line 528, in _compile_bytecode
    return self._compile_core()
  File "/home/akalistr/miniconda3/envs/dpbench-dev/lib/python3.10/site-packages/numba/core/compiler.py", line 507, in _compile_core
    raise e
  File "/home/akalistr/miniconda3/envs/dpbench-dev/lib/python3.10/site-packages/numba/core/compiler.py", line 494, in _compile_core
    pm.run(self.state)
  File "/home/akalistr/miniconda3/envs/dpbench-dev/lib/python3.10/site-packages/numba/core/compiler_machinery.py", line 368, in run
    raise patched_exception
  File "/home/akalistr/miniconda3/envs/dpbench-dev/lib/python3.10/site-packages/numba/core/compiler_machinery.py", line 356, in run
    self._runPass(idx, pass_inst, state)
  File "/home/akalistr/miniconda3/envs/dpbench-dev/lib/python3.10/site-packages/numba/core/compiler_lock.py", line 35, in _acquire_compile_lock
    return func(*args, **kwargs)
  File "/home/akalistr/miniconda3/envs/dpbench-dev/lib/python3.10/site-packages/numba/core/compiler_machinery.py", line 311, in _runPass
    mutated |= check(pss.run_pass, internal_state)
  File "/home/akalistr/miniconda3/envs/dpbench-dev/lib/python3.10/site-packages/numba/core/compiler_machinery.py", line 273, in check
    mangled = func(compiler_state)
  File "/home/akalistr/miniconda3/envs/dpbench-dev/lib/python3.10/site-packages/numba/core/untyped_passes.py", line 426, in run_pass
    find_literally_calls(state.func_ir, state.args)
  File "/home/akalistr/miniconda3/envs/dpbench-dev/lib/python3.10/site-packages/numba/core/analysis.py", line 696, in find_literally_calls
    raise errors.ForceLiteralArg(marked_args, loc=loc)
numba.core.errors.ForceLiteralArg: Failed in dpex_kernel_nopython mode pipeline (step: find literally calls)
Pseudo-exception to force literal arguments in the dispatcher

File "literally.py", line 13:
def lit(a, n):
    literally(n)
    ^

Expected result:

/home/akalistr/repo/literally.py:30: DeprecationWarning: The current syntax for specification of kernel launch parameters is deprecated. Users should set the kernel parameters through Range/NdRange classes.
Example:
    from numba_dpex import Range,NdRange

    # for global range only
    <function>[Range(X,Y)](<parameters>)
    # or,
    # for both global and local ranges
    <function>[NdRange((X,Y), (P,Q))](<parameters>)
  lit[a.shape[0], ](a)
Result:  11.

AlexanderKalistratov, Jun 14 '23 15:06

@ZzEeKkAa Can you please investigate this issue? I tried the reproducer with the new experimental.kernel decorator instead of kernel and ran into what looks like a problem in the experimental.launcher module:

Traceback (most recent call last):
  File "/home/diptorupd/Desktop/devel/numba-dpex/driver.py", line 26, in <module>
    ndpx_exp.call_kernel(lit, ndpx.Range(a.shape[0]), a, a.shape[0])
  File "/home/diptorupd/miniconda3/envs/dpex-devel/lib/python3.11/site-packages/numba/core/dispatcher.py", line 487, in _compile_for_args
    raise e
  File "/home/diptorupd/miniconda3/envs/dpex-devel/lib/python3.11/site-packages/numba/core/dispatcher.py", line 420, in _compile_for_args
    return_val = self.compile(tuple(argtypes))
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/diptorupd/miniconda3/envs/dpex-devel/lib/python3.11/site-packages/numba/core/dispatcher.py", line 965, in compile
    cres = self._compiler.compile(args, return_type)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/diptorupd/miniconda3/envs/dpex-devel/lib/python3.11/site-packages/numba/core/dispatcher.py", line 125, in compile
    status, retval = self._compile_cached(args, return_type)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/diptorupd/miniconda3/envs/dpex-devel/lib/python3.11/site-packages/numba/core/dispatcher.py", line 139, in _compile_cached
    retval = self._compile_core(args, return_type)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/diptorupd/miniconda3/envs/dpex-devel/lib/python3.11/site-packages/numba/core/dispatcher.py", line 152, in _compile_core
    cres = compiler.compile_extra(self.targetdescr.typing_context,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/diptorupd/miniconda3/envs/dpex-devel/lib/python3.11/site-packages/numba/core/compiler.py", line 762, in compile_extra
    return pipeline.compile_extra(func)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/diptorupd/miniconda3/envs/dpex-devel/lib/python3.11/site-packages/numba/core/compiler.py", line 460, in compile_extra
    return self._compile_bytecode()
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/diptorupd/miniconda3/envs/dpex-devel/lib/python3.11/site-packages/numba/core/compiler.py", line 528, in _compile_bytecode
    return self._compile_core()
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/diptorupd/miniconda3/envs/dpex-devel/lib/python3.11/site-packages/numba/core/compiler.py", line 503, in _compile_core
    raise e
  File "/home/diptorupd/miniconda3/envs/dpex-devel/lib/python3.11/site-packages/numba/core/compiler.py", line 494, in _compile_core
    pm.run(self.state)
  File "/home/diptorupd/miniconda3/envs/dpex-devel/lib/python3.11/site-packages/numba/core/compiler_machinery.py", line 364, in run
    raise e
  File "/home/diptorupd/miniconda3/envs/dpex-devel/lib/python3.11/site-packages/numba/core/compiler_machinery.py", line 356, in run
    self._runPass(idx, pass_inst, state)
  File "/home/diptorupd/miniconda3/envs/dpex-devel/lib/python3.11/site-packages/numba/core/compiler_lock.py", line 35, in _acquire_compile_lock
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/diptorupd/miniconda3/envs/dpex-devel/lib/python3.11/site-packages/numba/core/compiler_machinery.py", line 311, in _runPass
    mutated |= check(pss.run_pass, internal_state)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/diptorupd/miniconda3/envs/dpex-devel/lib/python3.11/site-packages/numba/core/compiler_machinery.py", line 273, in check
    mangled = func(compiler_state)
              ^^^^^^^^^^^^^^^^^^^^
  File "/home/diptorupd/miniconda3/envs/dpex-devel/lib/python3.11/site-packages/numba/core/typed_passes.py", line 110, in run_pass
    typemap, return_type, calltypes, errs = type_inference_stage(
                                            ^^^^^^^^^^^^^^^^^^^^^
  File "/home/diptorupd/miniconda3/envs/dpex-devel/lib/python3.11/site-packages/numba/core/typed_passes.py", line 88, in type_inference_stage
    errs = infer.propagate(raise_errors=raise_errors)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/diptorupd/miniconda3/envs/dpex-devel/lib/python3.11/site-packages/numba/core/typeinfer.py", line 1078, in propagate
    errors = self.constraints.propagate(self)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/diptorupd/miniconda3/envs/dpex-devel/lib/python3.11/site-packages/numba/core/typeinfer.py", line 177, in propagate
    raise e
  File "/home/diptorupd/miniconda3/envs/dpex-devel/lib/python3.11/site-packages/numba/core/typeinfer.py", line 155, in propagate
    constraint(typeinfer)
  File "/home/diptorupd/miniconda3/envs/dpex-devel/lib/python3.11/site-packages/numba/core/typeinfer.py", line 578, in __call__
    self.resolve(typeinfer, typevars, fnty)
  File "/home/diptorupd/miniconda3/envs/dpex-devel/lib/python3.11/site-packages/numba/core/typeinfer.py", line 607, in resolve
    folded = e.fold_arguments(folding_args, self.kws)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/diptorupd/Desktop/devel/numba-dpex/numba_dpex/experimental/kernel_dispatcher.py", line 345, in folded
    return self._compiler.fold_argument_types(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/diptorupd/miniconda3/envs/dpex-devel/lib/python3.11/site-packages/numba/core/dispatcher.py", line 118, in fold_argument_types
    args = fold_arguments(self.pysig, args, kws,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/diptorupd/miniconda3/envs/dpex-devel/lib/python3.11/site-packages/numba/core/typing/templates.py", line 221, in fold_arguments
    ba = pysig.bind(*bind_args, **bind_kws)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/diptorupd/miniconda3/envs/dpex-devel/lib/python3.11/inspect.py", line 3212, in bind
    return self._bind(args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/diptorupd/miniconda3/envs/dpex-devel/lib/python3.11/inspect.py", line 3133, in _bind
    raise TypeError('too many positional arguments') from None
TypeError: too many positional arguments

This may point to other problems in the launcher that need fixing.
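The last frame of that traceback comes from the Python standard library: Numba's fold_arguments calls pysig.bind on the kernel's Python signature, and inspect.Signature.bind raises this TypeError whenever it receives more positional arguments than the function declares. A minimal illustration, assuming a plain Python function with the same (a, n) parameters as the reproducer's kernel; why the launcher ends up passing an extra argument is not established here.

```python
import inspect


def lit(a, n):
    """Plain-Python stand-in with the same parameters as the reproducer's kernel."""
    return None


sig = inspect.signature(lit)

# Binding exactly as many positional arguments as declared succeeds.
bound = sig.bind("array", 10)
print(dict(bound.arguments))  # {'a': 'array', 'n': 10}

# One positional argument too many reproduces the error at the bottom of the traceback.
try:
    sig.bind("extra", "array", 10)
except TypeError as exc:
    print(exc)  # too many positional arguments
```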

diptorupd, Dec 10 '23 17:12

Supporting the numba.literally feature requires changes to the numba_dpex dispatcher and the numba_dpex kernel launch mechanisms.

The feature is implemented in Numba as follows:

  • When the compiler encounters a numba.literally call whose argument is not typed as a literal, the "find literally calls" pass raises a pseudo-exception (numba.core.errors.ForceLiteralArg, as in the first traceback above) that records which arguments must become literals.
  • The dispatcher catches that exception and recompiles the function, this time typing the requested arguments as literal values, so the argument eventually passed to numba.literally has a literal type.
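The two steps above amount to a compile-and-retry loop inside the dispatcher. The sketch below is a toy pure-Python model of that loop, not Numba's actual code: Literal, ForceLiteral, typeof, compile_once and dispatch are hypothetical names, with ForceLiteral playing the role of numba.core.errors.ForceLiteralArg.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Literal:
    """Hypothetical literal type: it remembers the value, like a Numba literal type."""
    value: object


class ForceLiteral(Exception):
    """Hypothetical stand-in for numba.core.errors.ForceLiteralArg."""

    def __init__(self, positions):
        super().__init__(f"force literal arguments at {sorted(positions)}")
        self.positions = frozenset(positions)


def typeof(value, as_literal):
    # Non-literal typing keeps only the Python type name; literal typing keeps the value.
    return Literal(value) if as_literal else type(value).__name__


def compile_once(argtypes, literally_positions):
    # Models the "find literally calls" pass: every argument passed to literally()
    # must already be typed as a Literal, otherwise a recompile is requested.
    missing = {i for i in literally_positions if not isinstance(argtypes[i], Literal)}
    if missing:
        raise ForceLiteral(missing)
    return tuple(argtypes)  # stands in for the compiled signature


def dispatch(args, literally_positions):
    """Compile, catch ForceLiteral, retype the requested arguments as literals, retry."""
    literal_positions = set()
    attempts = 0
    while True:
        attempts += 1
        argtypes = [typeof(v, i in literal_positions) for i, v in enumerate(args)]
        try:
            return compile_once(argtypes, literally_positions), attempts
        except ForceLiteral as exc:
            if exc.positions <= literal_positions:
                raise  # already literal: retrying again would loop forever
            literal_positions |= exc.positions


signature, attempts = dispatch(([1.0] * 10, 10), literally_positions={1})
print(signature, attempts)  # ('list', Literal(value=10)) 2
```

The first attempt types n as a plain int and fails; the second attempt succeeds because the dispatcher itself controls how the arguments are typed.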

Submitting a numba_dpex.experimental.kernel involves a two-step compilation process:

a) Compilation of an instance of the call_kernel launcher function, which performs type inference on all the kernel arguments passed to call_kernel.
b) Compilation of the kernel function that was passed to call_kernel. The kernel's argument signature is generated from the types inferred for the arguments passed to call_kernel.

The exception from the numba.literally call is raised while compiling the kernel (step b), but the argument's type is inferred while compiling the call_kernel function (step a). Recompiling the kernel alone therefore cannot turn the argument into a literal, so supporting the feature will require a significant design change in the numba-dpex compilation pipeline.
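A toy pure-Python model of the two-step flow makes the mismatch concrete. It is an illustration, not numba-dpex code and not a proposed design: Literal, ForceLiteral, infer_launcher_types, compile_kernel and the two launch_* functions are hypothetical names. Step (a) fixes the argument types; step (b) derives the kernel signature from them. A retry limited to step (b) keeps failing, while feeding the kernel's request back into step (a) succeeds.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Literal:
    """Hypothetical literal type that remembers its value."""
    value: object


class ForceLiteral(Exception):
    """Hypothetical stand-in for numba.core.errors.ForceLiteralArg."""

    def __init__(self, positions):
        super().__init__(f"force literal arguments at {sorted(positions)}")
        self.positions = frozenset(positions)


def infer_launcher_types(args, literal_positions=frozenset()):
    # Step (a): type inference for the launcher's arguments.
    return tuple(Literal(v) if i in literal_positions else type(v).__name__
                 for i, v in enumerate(args))


def compile_kernel(kernel_sig, literally_positions):
    # Step (b): the kernel signature comes from step (a); literally() checks it.
    missing = {i for i in literally_positions if not isinstance(kernel_sig[i], Literal)}
    if missing:
        raise ForceLiteral(missing)
    return kernel_sig


def launch_retrying_kernel_only(args, literally_positions, max_tries=3):
    # Retrying only step (b): the signature from step (a) never changes, so this fails.
    kernel_sig = infer_launcher_types(args)
    for _ in range(max_tries):
        try:
            return compile_kernel(kernel_sig, literally_positions)
        except ForceLiteral:
            continue
    raise RuntimeError("kernel still needs literal arguments")


def launch_retrying_both_steps(args, literally_positions):
    # The kernel's request is fed back into step (a) and both steps are redone.
    literal_positions = frozenset()
    while True:
        kernel_sig = infer_launcher_types(args, literal_positions)
        try:
            return compile_kernel(kernel_sig, literally_positions)
        except ForceLiteral as exc:
            if exc.positions <= literal_positions:
                raise
            literal_positions |= exc.positions


args = ([1.0] * 10, 10)
print(launch_retrying_both_steps(args, {1}))  # ('list', Literal(value=10))
try:
    launch_retrying_kernel_only(args, {1})
except RuntimeError as exc:
    print(exc)  # kernel still needs literal arguments
```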

diptorupd, Jan 22 '24 20:01