numba
CUDA JIT errors when multiple signatures provided
With Numba's CPU JIT, one can supply multiple signatures (e.g. single- and double-precision floating-point values) for a function. However, trying the same with Numba's CUDA JIT results in an error. Limiting the CUDA JIT to a single signature avoids the issue. A code snippet demonstrating this is below.
Example:
In [1]: from numba import jit

In [2]: from numba.cuda import jit as cuda_jit

In [3]: @jit(["f4(f4, f4)", "f8(f8, f8)"])
   ...: def f(a, b):
   ...:     return a + b
   ...:
   ...:

In [4]: @cuda_jit(["f4(f4, f4)", "f8(f8, f8)"])
   ...: def g(a, b):
   ...:     return a + b
   ...:
   ...:
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-4-5a86fdfdf6d9> in <module>()
----> 1 @cuda_jit(["f4(f4, f4)", "f8(f8, f8)"])
      2 def g(a, b):
      3     return a + b

~/.conda/envs/numba/lib/python3.6/site-packages/numba/cuda/compiler.py in __call__(self, *args)
    700         Specialize and invoke this kernel with *args*.
    701         '''
--> 702         kernel = self.specialize(*args)
    703         cfg = kernel[self.griddim, self.blockdim, self.stream, self.sharedmem]
    704         cfg(*args)

~/.conda/envs/numba/lib/python3.6/site-packages/numba/cuda/compiler.py in specialize(self, *args)
    710         '''
    711         argtypes = tuple(
--> 712             [self.typingctx.resolve_argument_type(a) for a in args])
    713         kernel = self.compile(argtypes)
    714         return kernel

~/.conda/envs/numba/lib/python3.6/site-packages/numba/cuda/compiler.py in <listcomp>(.0)
    710         '''
    711         argtypes = tuple(
--> 712             [self.typingctx.resolve_argument_type(a) for a in args])
    713         kernel = self.compile(argtypes)
    714         return kernel

~/.conda/envs/numba/lib/python3.6/site-packages/numba/typing/context.py in resolve_argument_type(self, val)
    288         ValueError is raised for unsupported types.
    289         """
--> 290         return typeof(val, Purpose.argument)
    291
    292     def resolve_value_type(self, val):

~/.conda/envs/numba/lib/python3.6/site-packages/numba/typing/typeof.py in typeof(val, purpose)
     31         msg = errors.termcolor.errmsg(
     32             "cannot determine Numba type of %r") % (type(val),)
---> 33         raise ValueError(msg)
     34     return ty
     35

ValueError: cannot determine Numba type of <class 'function'>
Thanks for the report. That does seem to be the case: lists of signatures are not handled. We should either raise a more specific error message or just fix the problem!
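Until lists of signatures are handled, one possible workaround is to apply a single-signature decorator once per signature and pick the specialization by hand. The sketch below shows only that bookkeeping. `specialize_each` and `fake_jit` are hypothetical names, and `fake_jit` is a stand-in so the snippet runs without a GPU; with Numba one would presumably pass `cuda.jit` instead, since calling it with one signature string returns a decorator.

```python
def specialize_each(jit_factory, signatures, py_func):
    """Apply a single-signature decorator factory once per signature.

    Returns a dict mapping each signature string to its own specialization.
    """
    return {sig: jit_factory(sig)(py_func) for sig in signatures}


def fake_jit(signature):
    """Stand-in for a single-signature JIT decorator (hypothetical, no GPU needed)."""
    def decorator(func):
        def wrapper(*args):
            return func(*args)
        wrapper.signature = signature  # remember which signature this copy is for
        return wrapper
    return decorator


def add_scalar(a, b):
    # Kernel-style function: returns nothing and writes its result into `a`.
    for i in range(len(a)):
        a[i] += b


kernels = specialize_each(
    fake_jit, ["void(f4[:], f4)", "void(f8[:], f8)"], add_scalar
)

data = [1.0, 2.0]
kernels["void(f8[:], f8)"](data, 1.0)  # the caller chooses the specialization
print(data)
```

The cost of this approach is that the caller has to choose the right entry, which is exactly what a dispatcher would otherwise do automatically.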
It's also worth noting from your example that @cuda.jit functions are GPU kernel functions and should be void by definition.
Also, might be worth looking at @vectorize and @guvectorize, both of which support the kwarg target='cuda' (and multiple signatures). Here's a quick example of both ways:
from numba.cuda import jit as cuda_jit
from numba import cuda, vectorize
import numpy as np


@cuda_jit("void(f8[:], f8)")
def g(a, b):
    tx = cuda.threadIdx.x
    ty = cuda.blockIdx.x
    bw = cuda.blockDim.x
    pos = tx + ty * bw
    if pos < a.size:
        a[pos] += b


@vectorize(['f4(f4, f4)', 'f8(f8, f8)'], target='cuda')
def vec_g(a, b):
    return a + b


x = np.arange(64.)
y = 17
g[1, 64](x, y)
print(x)

x = np.arange(64.)
print(vec_g(x, y))
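The kernel `g` computes a flat index from the thread and block indices and guards it against the array size. The pure-Python sketch below replays that arithmetic sequentially on the CPU for a made-up launch configuration, to show that every element is touched exactly once and that surplus threads do nothing. `simulate_launch` is a hypothetical helper, not a Numba API.

```python
def simulate_launch(griddim, blockdim, data, b):
    """Replay the index arithmetic of kernel `g` sequentially on the CPU."""
    for block_idx in range(griddim):        # plays the role of cuda.blockIdx.x
        for thread_idx in range(blockdim):  # plays the role of cuda.threadIdx.x
            pos = thread_idx + block_idx * blockdim
            if pos < len(data):             # same bounds guard as in the kernel
                data[pos] += b


x = [float(i) for i in range(10)]
# 3 blocks of 4 threads give 12 threads for 10 elements;
# the bounds guard masks out the last 2 threads.
simulate_launch(3, 4, x, 17)
print(x)
```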
It appears that there may be cases where kernels with explicit signatures don't catch type errors as they should, e.g.: https://github.com/numba/numba/pull/5388#issuecomment-603226418. When looking into adding support for multiple signatures, type checking should also be considered.
It would be really cool if this behavior of cuda.jit was (explicitly) documented ;) I have just burned a couple of hours on this, thinking that I made an utterly stupid mistake, until I ended up digging into the sources and issues ... In the meantime, some more useful exceptions could also help. When passing a list of strings, cuda.jit currently complains: [...] is not a callable object ... which kind of makes sense but is inconsistent with numba.jit and does not tell the actual story: Only one signature allowed.
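One way to make this failure easier to diagnose would be to validate the decorator argument up front. The sketch below is hypothetical and is not Numba's code: `single_signature_jit` is an invented stand-in that accepts a function or one signature string, and rejects a list with a message that states the actual restriction.

```python
def single_signature_jit(func_or_sig):
    """Hypothetical decorator that accepts a function or ONE signature string."""
    if isinstance(func_or_sig, (list, tuple)):
        # Fail early with the real reason instead of a confusing downstream error.
        raise TypeError(
            f"only one signature is allowed, got {len(func_or_sig)}: {func_or_sig!r}"
        )
    if isinstance(func_or_sig, str):
        def decorator(func):
            func.signature = func_or_sig  # record the requested signature
            return func
        return decorator
    if callable(func_or_sig):
        func_or_sig.signature = None      # bare @decorator use, no signature given
        return func_or_sig
    raise TypeError(
        f"expected a function or a signature string, got {func_or_sig!r}"
    )


try:
    single_signature_jit(["f4(f4, f4)", "f8(f8, f8)"])
except TypeError as exc:
    print(exc)
```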
Just out of interest: What are the odds (or technical feasibility) for cuda.jit to support multiple signatures just like numba.jit?
> Just out of interest: What are the odds (or technical feasibility) for cuda.jit to support multiple signatures just like numba.jit?
Hopefully much better than they were when this issue was first reported, now that the CUDA Dispatcher uses a lot more of the infrastructure that the CPU target does for dispatch. I'll look into this and see what the gaps are.
Looks like this is pretty doable, so I'm adding it to the 0.57 milestone. With the changes in my branch so far: https://github.com/gmarkall/numba/tree/issue-3226 (I just have one commit, https://github.com/gmarkall/numba/commit/b2f5a097a7701ac087379abc3f9670c52d876752, in it at the moment) the following:
from numba import cuda
import numpy as np

# Define a kernel with multiple signatures
sigs = [
    'void(int32[::1], int32[::1])',
    'void(float32[::1], float32[::1])'
]


@cuda.jit(sigs)
def f(r, x):
    r[0] = x[0] * 2 + 1


# Create arguments for supported and unsupported types
x_int32 = np.ones(1, dtype=np.int32)
x_int64 = np.ones(1, dtype=np.int64)
x_float32 = np.ones(1, dtype=np.float32)
x_float64 = np.ones(1, dtype=np.float64)
r_int32 = np.zeros_like(x_int32)
r_int64 = np.zeros_like(x_int64)
r_float32 = np.zeros_like(x_float32)
r_float64 = np.zeros_like(x_float64)

# Execute with supported types and print results
f[1, 1](r_int32, x_int32)
f[1, 1](r_float32, x_float32)
print(r_int32)
print(r_float32)


# Demonstrate that unsupported types are rejected
def unsupported_call(r, x):
    try:
        f[1, 1](r, x)
    except TypeError as te:
        print(f'Got TypeError: {te.args[0]}')


unsupported_call(r_int64, x_int64)
unsupported_call(r_float64, x_float64)
produces:
[3]
[3.]
Got TypeError: No matching definition for argument type(s) array(int64, 1d, C), array(int64, 1d, C)
Got TypeError: No matching definition for argument type(s) array(float64, 1d, C), array(float64, 1d, C)
which is what I would hope to get.
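The behaviour shown above (pick the specialization whose signature matches the argument types exactly, otherwise raise) can be mimicked in plain Python. This is only an illustration of the dispatch idea, not Numba's dispatcher: `TinyDispatcher` is a hypothetical class, and it keys on Python types rather than Numba array types.

```python
class TinyDispatcher:
    """Hypothetical dispatcher: one implementation per exact argument-type tuple."""

    def __init__(self, py_func, signatures):
        # In a real JIT each signature would map to separately compiled code;
        # here every entry simply reuses the Python function.
        self.overloads = {tuple(sig): py_func for sig in signatures}

    def __call__(self, *args):
        argtypes = tuple(type(a) for a in args)
        try:
            impl = self.overloads[argtypes]
        except KeyError:
            names = ", ".join(t.__name__ for t in argtypes)
            raise TypeError(
                f"No matching definition for argument type(s) {names}"
            ) from None
        return impl(*args)


def double_plus_one(x):
    return x * 2 + 1


f = TinyDispatcher(double_plus_one, [(int,), (float,)])
print(f(1))    # dispatches to the int entry
print(f(1.0))  # dispatches to the float entry
try:
    f("1")     # str was never registered, so this is rejected
except TypeError as te:
    print(f"Got TypeError: {te}")
```

Matching on exact types keeps the sketch short. A real dispatcher may also have rules for safe conversions between types, which this example does not attempt to model.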
This work still needs:
- [ ] ~~Formalizing the example above into a test case~~ (decided to just use the relevant tests from the CPU dispatcher)
- [X] Porting other relevant test cases across from the CPU dispatcher tests
- [ ] Updating the documentation
- [ ] Checking that device functions also behave correctly with multiple signatures
(edited to make the to-do list a checklist to track progress)