numba-dpex
Failure in pairwise distance numpy implementation
The numba-dpex implementation of pairwise distance, which is written with NumPy-style calls, fails because numba-dpex does not currently support dpnp.sum calls with a non-default axis argument. See the failing code snippet below.
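```python
import dpnp as np
from numba_dpex import dpjit


@dpjit
def pairwise_distance(X1, X2, D):
    # Squared row norms; these axis=1 reductions are what numba-dpex rejects
    x1 = np.sum(np.square(X1), axis=1)
    x2 = np.sum(np.square(X2), axis=1)
    # D = X1 @ X2.T, written into D in place
    np.dot(X1, X2.T, D)
    D *= -2
    # Add ||x1_i||^2 as a column vector and ||x2_j||^2 as a row vector
    x3 = x1.reshape(x1.size, 1)
    np.add(D, x3, D)
    np.add(D, x2, D)
    np.sqrt(D, D)
```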
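To check results once the kernel compiles, a plain-NumPy reference can help. The sketch below is not dpnp and is not compiled with dpjit; it mirrors the snippet's expansion D[i, j] = ||X1[i]||^2 - 2 X1[i]·X2[j] + ||X2[j]||^2 and cross-checks it against the direct definition. The helper names (row_sum_of_squares, pairwise_distance_reference, pairwise_distance_naive) are hypothetical. row_sum_of_squares writes the axis=1 reduction as explicit loops, which is one shape a workaround for the missing axis support might take; whether dpjit accepts that form has not been verified here. The clamp to zero before the square root is an addition not present in the original snippet.

```python
import numpy as np


def row_sum_of_squares(X):
    # Equivalent to np.sum(np.square(X), axis=1), written as explicit loops
    # (hypothetical workaround shape; plain NumPy here, untested under dpjit)
    n, m = X.shape
    out = np.zeros(n, dtype=X.dtype)
    for i in range(n):
        acc = 0.0
        for j in range(m):
            acc += X[i, j] * X[i, j]
        out[i] = acc
    return out


def pairwise_distance_reference(X1, X2):
    # Same expansion as the dpjit snippet: -2 * X1 @ X2.T plus both norm terms
    x1 = row_sum_of_squares(X1)
    x2 = row_sum_of_squares(X2)
    D = -2.0 * (X1 @ X2.T)
    D += x1.reshape(x1.size, 1)  # column vector, broadcast across columns
    D += x2                      # row vector, broadcast across rows
    # Clamp tiny negatives from rounding so sqrt does not return NaN
    np.maximum(D, 0.0, out=D)
    np.sqrt(D, out=D)
    return D


def pairwise_distance_naive(X1, X2):
    # Direct definition via broadcasting, used only as a cross-check
    diff = X1[:, None, :] - X2[None, :, :]
    return np.sqrt(np.sum(diff * diff, axis=2))


rng = np.random.default_rng(0)
X1 = rng.random((5, 3))
X2 = rng.random((4, 3))
D = pairwise_distance_reference(X1, X2)
assert D.shape == (5, 4)
assert np.allclose(D, pairwise_distance_naive(X1, X2))
```

The naive version materializes an (n1, n2, m) array, so it is only practical for small inputs; the expanded form is what the benchmark uses because it reduces to a matrix product.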
How to reproduce:
Follow the instructions to set up dpbench, then run the pairwise distance benchmark: python -c "import dpbench; dpbench.run_benchmark('pairwise_distance')"
Duplicate of #784
I see that there are more issues than just sum. Let us reopen.