numba-dpex
numba.literally doesn't work in kernels
Reproducer:
import dpnp as np
import numba_dpex as ndpx
from numba import literally

@ndpx.kernel
def lit(a, n):
    literally(n)
    local_a_0 = ndpx.private.array(n, dtype=a.dtype)
    lid = ndpx.get_local_id(0)
    if lid == 0:
        result = a.dtype.type(0)
        for i in range(n):
            local_a_0[i] = a[i]
        for i in range(n):
            a[0] += local_a_0[i]

a = np.ones(10, dtype='float32')
lit[a.shape[0], ](a, a.shape[0])
print("Result: ", a[0])
Result:
/home/akalistr/repo/literally.py:28: DeprecationWarning: The current syntax for specification of kernel launch parameters is deprecated. Users should set the kernel parameters through Range/NdRange classes.
Example:
from numba_dpex import Range,NdRange
# for global range only
<function>[Range(X,Y)](<parameters>)
# or,
# for both global and local ranges
<function>[NdRange((X,Y), (P,Q))](<parameters>)
lit[a.shape[0], ](a, a.shape[0])
Traceback (most recent call last):
File "/home/akalistr/repo/literally.py", line 28, in <module>
lit[a.shape[0], ](a, a.shape[0])
File "/home/akalistr/miniconda3/envs/dpbench-dev/lib/python3.10/site-packages/numba_dpex/core/kernel_interface/dispatcher.py", line 455, in __call__
) = self._compile_and_cache(
File "/home/akalistr/miniconda3/envs/dpbench-dev/lib/python3.10/site-packages/numba_dpex/core/kernel_interface/dispatcher.py", line 141, in _compile_and_cache
kernel.compile(
File "/home/akalistr/miniconda3/envs/dpbench-dev/lib/python3.10/site-packages/numba_dpex/core/kernel_interface/spirv_kernel.py", line 123, in compile
cres = compile_with_dpex(
File "/home/akalistr/miniconda3/envs/dpbench-dev/lib/python3.10/site-packages/numba/core/compiler_lock.py", line 35, in _acquire_compile_lock
return func(*args, **kwargs)
File "/home/akalistr/miniconda3/envs/dpbench-dev/lib/python3.10/site-packages/numba_dpex/core/compiler.py", line 64, in compile_with_dpex
cres = compiler.compile_extra(
File "/home/akalistr/miniconda3/envs/dpbench-dev/lib/python3.10/site-packages/numba/core/compiler.py", line 742, in compile_extra
return pipeline.compile_extra(func)
File "/home/akalistr/miniconda3/envs/dpbench-dev/lib/python3.10/site-packages/numba/core/compiler.py", line 460, in compile_extra
return self._compile_bytecode()
File "/home/akalistr/miniconda3/envs/dpbench-dev/lib/python3.10/site-packages/numba/core/compiler.py", line 528, in _compile_bytecode
return self._compile_core()
File "/home/akalistr/miniconda3/envs/dpbench-dev/lib/python3.10/site-packages/numba/core/compiler.py", line 507, in _compile_core
raise e
File "/home/akalistr/miniconda3/envs/dpbench-dev/lib/python3.10/site-packages/numba/core/compiler.py", line 494, in _compile_core
pm.run(self.state)
File "/home/akalistr/miniconda3/envs/dpbench-dev/lib/python3.10/site-packages/numba/core/compiler_machinery.py", line 368, in run
raise patched_exception
File "/home/akalistr/miniconda3/envs/dpbench-dev/lib/python3.10/site-packages/numba/core/compiler_machinery.py", line 356, in run
self._runPass(idx, pass_inst, state)
File "/home/akalistr/miniconda3/envs/dpbench-dev/lib/python3.10/site-packages/numba/core/compiler_lock.py", line 35, in _acquire_compile_lock
return func(*args, **kwargs)
File "/home/akalistr/miniconda3/envs/dpbench-dev/lib/python3.10/site-packages/numba/core/compiler_machinery.py", line 311, in _runPass
mutated |= check(pss.run_pass, internal_state)
File "/home/akalistr/miniconda3/envs/dpbench-dev/lib/python3.10/site-packages/numba/core/compiler_machinery.py", line 273, in check
mangled = func(compiler_state)
File "/home/akalistr/miniconda3/envs/dpbench-dev/lib/python3.10/site-packages/numba/core/untyped_passes.py", line 426, in run_pass
find_literally_calls(state.func_ir, state.args)
File "/home/akalistr/miniconda3/envs/dpbench-dev/lib/python3.10/site-packages/numba/core/analysis.py", line 696, in find_literally_calls
raise errors.ForceLiteralArg(marked_args, loc=loc)
numba.core.errors.ForceLiteralArg: Failed in dpex_kernel_nopython mode pipeline (step: find literally calls)
Pseudo-exception to force literal arguments in the dispatcher
File "literally.py", line 13:
def lit(a, n):
literally(n)
^
Expected result:
/home/akalistr/repo/literally.py:30: DeprecationWarning: The current syntax for specification of kernel launch parameters is deprecated. Users should set the kernel parameters through Range/NdRange classes.
Example:
from numba_dpex import Range,NdRange
# for global range only
<function>[Range(X,Y)](<parameters>)
# or,
# for both global and local ranges
<function>[NdRange((X,Y), (P,Q))](<parameters>)
lit[a.shape[0], ](a)
Result: 11.
@ZzEeKkAa Can you please investigate the issue? I tried it with the new experimental.kernel instead of the kernel, and I see a potential issue with the experimental.launcher module:
Traceback (most recent call last):
File "/home/diptorupd/Desktop/devel/numba-dpex/driver.py", line 26, in <module>
ndpx_exp.call_kernel(lit, ndpx.Range(a.shape[0]), a, a.shape[0])
File "/home/diptorupd/miniconda3/envs/dpex-devel/lib/python3.11/site-packages/numba/core/dispatcher.py", line 487, in _compile_for_args
raise e
File "/home/diptorupd/miniconda3/envs/dpex-devel/lib/python3.11/site-packages/numba/core/dispatcher.py", line 420, in _compile_for_args
return_val = self.compile(tuple(argtypes))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/diptorupd/miniconda3/envs/dpex-devel/lib/python3.11/site-packages/numba/core/dispatcher.py", line 965, in compile
cres = self._compiler.compile(args, return_type)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/diptorupd/miniconda3/envs/dpex-devel/lib/python3.11/site-packages/numba/core/dispatcher.py", line 125, in compile
status, retval = self._compile_cached(args, return_type)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/diptorupd/miniconda3/envs/dpex-devel/lib/python3.11/site-packages/numba/core/dispatcher.py", line 139, in _compile_cached
retval = self._compile_core(args, return_type)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/diptorupd/miniconda3/envs/dpex-devel/lib/python3.11/site-packages/numba/core/dispatcher.py", line 152, in _compile_core
cres = compiler.compile_extra(self.targetdescr.typing_context,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/diptorupd/miniconda3/envs/dpex-devel/lib/python3.11/site-packages/numba/core/compiler.py", line 762, in compile_extra
return pipeline.compile_extra(func)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/diptorupd/miniconda3/envs/dpex-devel/lib/python3.11/site-packages/numba/core/compiler.py", line 460, in compile_extra
return self._compile_bytecode()
^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/diptorupd/miniconda3/envs/dpex-devel/lib/python3.11/site-packages/numba/core/compiler.py", line 528, in _compile_bytecode
return self._compile_core()
^^^^^^^^^^^^^^^^^^^^
File "/home/diptorupd/miniconda3/envs/dpex-devel/lib/python3.11/site-packages/numba/core/compiler.py", line 503, in _compile_core
raise e
File "/home/diptorupd/miniconda3/envs/dpex-devel/lib/python3.11/site-packages/numba/core/compiler.py", line 494, in _compile_core
pm.run(self.state)
File "/home/diptorupd/miniconda3/envs/dpex-devel/lib/python3.11/site-packages/numba/core/compiler_machinery.py", line 364, in run
raise e
File "/home/diptorupd/miniconda3/envs/dpex-devel/lib/python3.11/site-packages/numba/core/compiler_machinery.py", line 356, in run
self._runPass(idx, pass_inst, state)
File "/home/diptorupd/miniconda3/envs/dpex-devel/lib/python3.11/site-packages/numba/core/compiler_lock.py", line 35, in _acquire_compile_lock
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/diptorupd/miniconda3/envs/dpex-devel/lib/python3.11/site-packages/numba/core/compiler_machinery.py", line 311, in _runPass
mutated |= check(pss.run_pass, internal_state)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/diptorupd/miniconda3/envs/dpex-devel/lib/python3.11/site-packages/numba/core/compiler_machinery.py", line 273, in check
mangled = func(compiler_state)
^^^^^^^^^^^^^^^^^^^^
File "/home/diptorupd/miniconda3/envs/dpex-devel/lib/python3.11/site-packages/numba/core/typed_passes.py", line 110, in run_pass
typemap, return_type, calltypes, errs = type_inference_stage(
^^^^^^^^^^^^^^^^^^^^^
File "/home/diptorupd/miniconda3/envs/dpex-devel/lib/python3.11/site-packages/numba/core/typed_passes.py", line 88, in type_inference_stage
errs = infer.propagate(raise_errors=raise_errors)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/diptorupd/miniconda3/envs/dpex-devel/lib/python3.11/site-packages/numba/core/typeinfer.py", line 1078, in propagate
errors = self.constraints.propagate(self)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/diptorupd/miniconda3/envs/dpex-devel/lib/python3.11/site-packages/numba/core/typeinfer.py", line 177, in propagate
raise e
File "/home/diptorupd/miniconda3/envs/dpex-devel/lib/python3.11/site-packages/numba/core/typeinfer.py", line 155, in propagate
constraint(typeinfer)
File "/home/diptorupd/miniconda3/envs/dpex-devel/lib/python3.11/site-packages/numba/core/typeinfer.py", line 578, in __call__
self.resolve(typeinfer, typevars, fnty)
File "/home/diptorupd/miniconda3/envs/dpex-devel/lib/python3.11/site-packages/numba/core/typeinfer.py", line 607, in resolve
folded = e.fold_arguments(folding_args, self.kws)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/diptorupd/Desktop/devel/numba-dpex/numba_dpex/experimental/kernel_dispatcher.py", line 345, in folded
return self._compiler.fold_argument_types(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/diptorupd/miniconda3/envs/dpex-devel/lib/python3.11/site-packages/numba/core/dispatcher.py", line 118, in fold_argument_types
args = fold_arguments(self.pysig, args, kws,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/diptorupd/miniconda3/envs/dpex-devel/lib/python3.11/site-packages/numba/core/typing/templates.py", line 221, in fold_arguments
ba = pysig.bind(*bind_args, **bind_kws)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/diptorupd/miniconda3/envs/dpex-devel/lib/python3.11/inspect.py", line 3212, in bind
return self._bind(args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/diptorupd/miniconda3/envs/dpex-devel/lib/python3.11/inspect.py", line 3133, in _bind
raise TypeError('too many positional arguments') from None
TypeError: too many positional arguments
It may be indicative of other potential issues in the launcher that need fixing.
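The last frames show that the failure comes from inspect.Signature.bind, which Numba's fold_arguments calls with the dispatcher's pysig. The plain-Python sketch below needs no numba_dpex and reproduces the message. bind() raises this error whenever the caller supplies more positional arguments than the signature declares. The sketch does not explain why the experimental launcher passes an extra argument. The `extra` value is a hypothetical placeholder.

```python
import inspect


def lit(a, n):
    """Same parameter list as the kernel in the reproducer (body omitted)."""


sig = inspect.signature(lit)

# Two positional arguments match the two declared parameters.
bound = sig.bind([1.0] * 10, 10)

# One positional argument too many: bind() raises the TypeError seen above.
# `extra` is a placeholder; what the launcher actually passes is unknown.
extra = object()
try:
    sig.bind(extra, [1.0] * 10, 10)
    message = None
except TypeError as exc:
    message = str(exc)

print(message)  # too many positional arguments
```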
Supporting the numba.literally feature requires changes to the numba_dpex dispatcher and the numba_dpex kernel launch mechanisms.
The feature is implemented in Numba as follows:
- If the compiler encounters a numba.literally call in a function and the argument to that call is not a literal value, it raises an exception (the ForceLiteralArg pseudo-exception shown in the traceback above).
- The dispatcher catches the exception and recompiles the function that contains the numba.literally call. This time, it types the argument passed to numba.literally as a literal value.
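The two steps above can be modelled in a few lines of plain Python. This is a hypothetical toy, not Numba's code. ForceLiteralArg, toy_typeof, compile_lit and dispatch are illustrative names, and types are represented as strings. compile_lit hard-codes the check for the reproducer's lit(a, n), where n is argument 1.

```python
class ForceLiteralArg(Exception):
    """Toy stand-in for Numba's ForceLiteralArg pseudo-exception."""

    def __init__(self, requested_args):
        super().__init__("Pseudo-exception to force literal arguments")
        self.requested_args = frozenset(requested_args)


def toy_typeof(value, as_literal):
    """Return a string 'type'; literal types also record the value."""
    name = type(value).__name__
    return f"Literal[{name}]({value!r})" if as_literal else name


def compile_lit(arg_types):
    """Toy compiler for `def lit(a, n): literally(n); ...`.

    Step 1: if argument 1 (n) is not typed as a literal, raise the
    pseudo-exception naming it, as the find-literally-calls pass does.
    """
    if not arg_types[1].startswith("Literal"):
        raise ForceLiteralArg({1})
    return f"compiled lit{arg_types}"


def dispatch(compile_fn, *args):
    """Step 2: catch the pseudo-exception and recompile with literal types."""
    literal_args = frozenset()
    while True:
        arg_types = tuple(
            toy_typeof(v, i in literal_args) for i, v in enumerate(args)
        )
        try:
            return compile_fn(arg_types)
        except ForceLiteralArg as exc:
            if exc.requested_args <= literal_args:
                raise  # already literal; retrying would loop forever
            literal_args |= exc.requested_args


print(dispatch(compile_lit, [1.0] * 10, 10))
# compiled lit('list', 'Literal[int](10)')
```

The key design point is that the code which catches the exception is the same code that assigns the argument types. The retry therefore simply re-types the requested argument and compiles again.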
The numba_dpex.experimental.kernel submission involves a two-step compilation process:
a) Compilation of an instance of the call_kernel kernel launcher function, which involves type inference of all the kernel arguments passed to call_kernel.
b) Compilation of the kernel function that was passed to the call_kernel call. The argument signature for the kernel is generated from the types inferred for the arguments passed to call_kernel.
The exception from the numba.literally call is raised during kernel compilation (step b). However, type inference of the argument happens earlier, during compilation of the call_kernel function (step a). For this reason, supporting the feature will require significant design changes to the numba-dpex compilation pipeline.
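To make the ordering problem concrete, here is a hypothetical toy model of the two-step launch in plain Python. compile_call_kernel and compile_kernel are illustrative stand-ins, not the numba_dpex functions. ForceLiteralArg is again a toy class, and types are strings. Stage (a) fixes the argument types, and stage (b) derives the kernel signature from them. The pseudo-exception raised in stage (b) escapes because nothing in stage (a) catches it and re-types the argument.

```python
class ForceLiteralArg(Exception):
    """Toy stand-in for Numba's ForceLiteralArg pseudo-exception."""

    def __init__(self, requested_args):
        super().__init__("Pseudo-exception to force literal arguments")
        self.requested_args = frozenset(requested_args)


def compile_kernel(kernel_sig):
    """Stage (b): compile `lit(a, n)`, which calls literally(n)."""
    if not kernel_sig[1].startswith("Literal"):
        # Like Numba's find-literally-calls pass, ask for n to be literal.
        raise ForceLiteralArg({1})
    return f"kernel compiled for {kernel_sig}"


def compile_call_kernel(*args):
    """Stage (a): type the launcher arguments, then compile the kernel."""
    # Type inference for the launcher happens here, with non-literal types.
    launcher_sig = tuple(type(v).__name__ for v in args)
    # The kernel signature is derived from the launcher's inferred types,
    # so stage (b) never sees a literal type for n.
    return compile_kernel(launcher_sig)


def launch(*args):
    """Report whether the pseudo-exception escaped the launcher."""
    try:
        return compile_call_kernel(*args)
    except ForceLiteralArg as exc:
        return f"ForceLiteralArg escaped for args {sorted(exc.requested_args)}"


print(launch([1.0] * 10, 10))  # ForceLiteralArg escaped for args [1]
```

In this model, making the feature work would mean catching the exception around stage (a) and re-running its type inference with n typed as a literal. That matches the point above that both the dispatcher and the launch mechanism need to change.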