numba-scipy
scipy.ndimage (.distance_transform_edt) support?
Hey, I'm using numba to speed up my project, and the only slow function left is the one that calls scipy.ndimage.distance_transform_edt, which numba 0.54.1 doesn't support. Sadly, I'm stuck with this version at the moment because of dependency issues between the numba package, the Python 3.9.7 I get from Anaconda, and packages that depend on GDAL. So I looked for a faster implementation but couldn't find one.
Now I'm thinking about writing a faster implementation of my own that works with numba.
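Here is a minimal sketch of what such a numba-compatible implementation could look like, assuming 2D input and numba's nopython mode. It swaps in the separable lower-envelope algorithm of Felzenszwalb and Huttenlocher (the one the OpenCV documentation cites for DIST_L2 with DIST_MASK_PRECISE), which is not necessarily the algorithm SciPy uses internally. The names edt_2d and _sq_dist_1d are hypothetical, BIG is an assumed stand-in for infinity, and unlike distance_transform_edt there is no sampling, return_indices or N-D support. If numba is not installed, the decorator falls back to plain Python so the sketch still runs (slowly).

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # assumption: plain-Python fallback when numba is unavailable
    def njit(func):
        return func

BIG = 1e20  # hypothetical stand-in for "infinity" on foreground pixels (avoids inf - inf = nan)


@njit
def _sq_dist_1d(f, n, d, v, z):
    """1D lower envelope of parabolas: d[q] = min over p of (q - p)**2 + f[p]."""
    k = 0          # index of the rightmost parabola in the envelope
    v[0] = 0       # positions of the parabolas in the envelope
    z[0] = -np.inf # boundaries between neighbouring envelope parabolas
    z[1] = np.inf
    for q in range(1, n):
        s = ((f[q] + q * q) - (f[v[k]] + v[k] * v[k])) / (2 * q - 2 * v[k])
        while s <= z[k]:  # parabola q hides the previous one, drop it
            k -= 1
            s = ((f[q] + q * q) - (f[v[k]] + v[k] * v[k])) / (2 * q - 2 * v[k])
        k += 1
        v[k] = q
        z[k] = s
        z[k + 1] = np.inf
    k = 0
    for q in range(n):  # read the minimum off the envelope
        while z[k + 1] < q:
            k += 1
        d[q] = (q - v[k]) ** 2 + f[v[k]]


@njit
def edt_2d(img):
    """Euclidean distance from each non-zero pixel to the nearest zero pixel (2D only)."""
    rows, cols = img.shape
    n = max(rows, cols)
    f = np.empty(n)
    d = np.empty(n)
    v = np.empty(n, np.int64)
    z = np.empty(n + 1)
    out = np.empty((rows, cols))
    # Pass 1: squared distance to the nearest zero within each column.
    for j in range(cols):
        for i in range(rows):
            f[i] = 0.0 if img[i, j] == 0 else BIG
        _sq_dist_1d(f, rows, d, v, z)
        for i in range(rows):
            out[i, j] = d[i]
    # Pass 2: combine the column results along each row, then take the square root.
    for i in range(rows):
        for j in range(cols):
            f[j] = out[i, j]
        _sq_dist_1d(f, cols, d, v, z)
        for j in range(cols):
            out[i, j] = np.sqrt(d[j])
    return out
```

It can be called as edt_2d(array) from Python or from other njit functions. The two 1D passes make it linear in the number of pixels. When benchmarking, note that the first call includes numba's compile time. An image with no zero pixel has no defined distance; here it yields values of roughly 1e10 (the square root of BIG).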
As far as I can see, this numba-scipy package only supports scipy.special?
The documentation of the main numba package only mentions scipy in connection with its np.linalg.* support, and I guess the newest version, 0.55.1, doesn't support scipy.ndimage either? (I didn't find it mentioned in the list of supported features in the numba documentation.)
If I manage to implement it and create an overload, I'd like to share it here. Or is this the wrong package, and should the overload go into the main numba package instead?
Thanks for your help and numba (support) :)
OK, I didn't know there was already an alternative implemented in OpenCV, called distanceTransform.
With the parameters cv2.distanceTransform(<array>.astype(np.uint8), distanceType=cv2.DIST_L2, maskSize=cv2.DIST_MASK_PRECISE, dstType=cv2.CV_32F) it performs the same calculation as scipy.ndimage.distance_transform_edt(<array>).
That one is written in C++ and is actually pretty fast, so we might not need a numba variant... For comparison:
```python
import timeit

import cv2
import numpy as np
from scipy import ndimage

array = np.array([[0, 0, 0, 0, 1, 1],
                  [0, 0, 0, 1, 1, 1],
                  [1, 0, 0, 0, 1, 1],
                  [1, 1, 0, 0, 1, 0],
                  [1, 1, 0, 0, 1, 0],
                  [1, 1, 0, 0, 1, 0]])

# distanceType=2 is cv2.DIST_L2, maskSize=0 is cv2.DIST_MASK_PRECISE, dstType=5 is cv2.CV_32F
dist_array_cv2 = cv2.distanceTransform(array.astype(np.uint8), distanceType=2, maskSize=0, dstType=5)
dist_array_scipy = ndimage.distance_transform_edt(array)
if np.allclose(dist_array_cv2, dist_array_scipy):
    print("same content")
print(timeit.timeit("cv2.distanceTransform(array.astype(np.uint8), distanceType=2, maskSize=0, dstType=5)", globals=globals(), number=10000))
print(timeit.timeit("ndimage.distance_transform_edt(array)", globals=globals(), number=10000))
```
This prints:

```
same content
0.0344644999999999
0.4173283000000001
```
On this small array the OpenCV call is roughly 12 times faster than SciPy, so I'm not sure numba could beat that...
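To pin down what "the same calculation" means, and to have something to check any faster (e.g. numba) version against, here is a deliberately slow brute-force reference in plain NumPy. The helper name edt_bruteforce is hypothetical and not part of SciPy or OpenCV. For every non-zero element it takes the Euclidean distance to the nearest zero element, which is what distance_transform_edt computes with its default sampling.

```python
import numpy as np


def edt_bruteforce(img):
    """Hypothetical reference: distance from each non-zero element to the nearest zero element.

    O(elements * zeros), so only suitable for small test arrays. Works for any number of
    dimensions; raises ValueError if img contains no zero element (distance undefined).
    """
    img = np.asarray(img)
    zero_coords = np.argwhere(img == 0)  # (k, ndim) coordinates of background elements
    out = np.zeros(img.shape)            # zero elements keep distance 0
    for idx in np.argwhere(img != 0):
        sq = ((zero_coords - idx) ** 2).sum(axis=1)  # squared distances to every zero
        out[tuple(idx)] = np.sqrt(sq.min())
    return out
```

For the array above, np.allclose(edt_bruteforce(array), ndimage.distance_transform_edt(array)) should hold, so it can serve as the ground truth in unit tests for a faster implementation.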