Implement array-valued signatures

Open adeak opened this issue 3 years ago • 12 comments

As of https://github.com/numba/numba-scipy/pull/54 the simplest scalar calls to jitted special functions should work.

However there's no support yet for array-valued inputs:

import numpy as np
from numba import njit
from scipy import special

x = np.linspace(-10, 10, 1000)

@njit
def jitted_j0(x):
    res = special.j0(x[0])  # works after PR #54
    # res = special.j0(x)  # breaks
    return res

print(jitted_j0(x))

This is not obviously a shortcoming, since looping in jitted functions should be alright. So this is just a mild suggestion to consider adding support for array-valued signatures. (This should probably be preceded by some benchmarks to see whether it would help performance at all.)

adeak avatar Apr 16 '21 15:04 adeak

This is not obviously a shortcoming, since looping in jitted functions should be alright.

It's definitely a shortcoming, because the corresponding scipy.special functions that are being overloaded are ufuncs and do not have this limitation.

I would say that it doesn't render the library useless, though.

Anyway, I took a quick shot at using numba.vectorize on the functions produced by choose_kernel, but numba.extending.overload does not like the returned type of a numba.vectorize-wrapped function. Is that expected?

brandonwillard avatar Apr 16 '21 16:04 brandonwillard

Anyway, I took a quick shot at using numba.vectorize on the functions produced by choose_kernel, but numba.extending.overload does not like the returned type of a numba.vectorize-wrapped function. Is that expected?

do you have an example, perchance?

esc avatar Apr 16 '21 16:04 esc

modified   numba_scipy/special/overloads.py
@@ -10,7 +10,12 @@ def choose_kernel(name, all_signatures):
         for signature in all_signatures:
             if args == signature:
                 f = signatures.name_and_types_to_pointer[(name, *signature)]
-                return lambda *args: f(*args)
+
+                @numba.vectorize
+                def _f(*args):
+                    return f(*args)
+
+                return _f
 
     return choice_function

results in the following error:

E   numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
E   No implementation of function Function(<ufunc 'agm'>) found for signature:
E    
E    >>> agm(float64, float64)
E    
E   There are 2 candidate implementations:
E     - Of which 2 did not match due to:
E     Overload in function 'choose_kernel.<locals>.choice_function': File: ../code/python/numba-scipy/numba_scipy/special/overloads.py: Line 9.
E       With argument(s): '(float64, float64)':
E      Rejected as the implementation raised a specific error:
E        AssertionError: Implementator function returned by `@overload` has an unexpected type.  Got <numba._DUFunc '_f'>
E     raised from ~/envs/numba-scipy-env/lib/python3.7/site-packages/numba/core/typing/templates.py:742
E   
E   During: resolving callee type: Function(<ufunc 'agm'>)
E   During: typing of call at ~/code/python/numba-scipy/numba_scipy/tests/test_special.py (76)
E   
E   
E   File "numba_scipy/tests/test_special.py", line 76:
E       def numba_func(*args):
E           return scipy_func(*args)
E           ^

Is numba.extending.overload attempting to numba.jit the function returned by choose_kernel? The error looks similar to the one produced when numba.njit-ing a function wrapped with numba.vectorize.

brandonwillard avatar Apr 16 '21 17:04 brandonwillard

The varargs could also be a problem.

brandonwillard avatar Apr 16 '21 17:04 brandonwillard

I have a hack to get this working in my vectorized-overloads branch. It creates a fixed-arguments function on the fly to get past some apparent varargs issues with numba.vectorize.

If anyone knows how to get past this varargs issue without creating functions in this fashion (or any other fundamentally AST-based approach), please tell me; it would really help with the work we're doing in Aesara as well.
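For readers unfamiliar with the idea, generating a fixed-argument function on the fly can be sketched roughly as follows. This is only an illustration in plain Python, not the code from the vectorized-overloads branch: `make_fixed_arity`, `kernel`, and the `arg0`, `arg1`, ... parameter names are hypothetical. The generated function has an explicit signature, which is what a varargs `lambda *args` lacks; passing the result to numba.vectorize is not shown here and would additionally require the wrapped callable to be something Numba can type.

```python
import inspect


def make_fixed_arity(func, n_args, name="_f"):
    """Build a wrapper around ``func`` with exactly ``n_args`` named positional arguments."""
    # Hypothetical helper: generate source such as
    #   def _f(arg0, arg1):
    #       return _wrapped(arg0, arg1)
    arg_list = ", ".join(f"arg{i}" for i in range(n_args))
    source = f"def {name}({arg_list}):\n    return _wrapped({arg_list})\n"
    # The wrapped callable is exposed to the generated code as a global.
    namespace = {"_wrapped": func}
    exec(compile(source, f"<generated {name}>", "exec"), namespace)
    return namespace[name]


# Stand-in for a varargs kernel such as the lambda returned by choose_kernel.
def kernel(*args):
    return sum(args)


two_arg = make_fixed_arity(kernel, 2)
print(inspect.signature(two_arg))
```

The trade-off is the one described above: the function is built from generated source, so it is still a code-generation approach rather than a real fix for the varargs limitation.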

brandonwillard avatar Apr 17 '21 00:04 brandonwillard

There's no public extension API in Numba for declaring this in a simple manner, but something like this could be a workaround.

import numpy as np
from numba import njit, vectorize, types
from numba.extending import overload
from scipy import special

x = np.linspace(-10, 10, 1000)

# This is just a dummy scalar function standing in for those in numba-scipy's
# wrapper for scipy.special.*; now that #54 is in, the standard overload for
# scalar j0 should just work.
@njit
def pretend_j0_from_cython(x):
    return x + 12.34

@vectorize
def vectorize_j0(x):
    return pretend_j0_from_cython(x)

# This gets the vectorization mechanics but will end up "hiding" the NumPy ufunc
@overload(special.j0)
def ol_j0(x):
    if isinstance(x, (types.Array, types.Number)):
        def impl(x):
            return vectorize_j0(x)
        return impl

@njit
def jitted_j0(x):
    res1 = special.j0(x[0])
    res2 = special.j0(x)
    return res1, res2

print(jitted_j0(x))

stuartarchibald avatar Apr 20 '21 08:04 stuartarchibald

The issue I ran into above is the signature for the @vectorized function: varargs wouldn't work, so I had to construct the function via compile/AST.

brandonwillard avatar Apr 20 '21 16:04 brandonwillard

The issue I ran into above is the signature for the @vectorized function: varargs wouldn't work, so I had to construct the function via compile/AST.

Ah, I see, I misinterpreted this as not being able to register an overload with vectorize, and whilst that's a problem, I can see why *args failing is also a problem if you want to do that automatic generation!

Opened https://github.com/numba/numba/issues/6954 to track.

stuartarchibald avatar Apr 20 '21 17:04 stuartarchibald

Opened numba/numba#6954 to track.

Thanks for that; it's a problem that shows up in at least a couple places where we're trying to use Numba as a backend (e.g. here).

brandonwillard avatar Apr 20 '21 18:04 brandonwillard

Hello, I have been able to use the workaround by @stuartarchibald. Is there any plan to add this so that there is no need to write a vectorized version of every function?

PabloRdrRbl avatar Jun 17 '21 15:06 PabloRdrRbl

@PabloRdrRbl I think a PR has already been opened: https://github.com/numba/numba-scipy/pull/58

esc avatar Jun 17 '21 15:06 esc

There's no public extension API in Numba for declaring this in a simple manner, this sort of thing could be a work around.

Is it possible to extend it to a function like jv, which takes two arguments?

PabloRdrRbl avatar Mar 10 '22 11:03 PabloRdrRbl