warp
warp copied to clipboard
[REQ] Allow generic kernel overload resolution with foreign arrays
Description
Passing foreign arrays to Warp kernels works by parsing the array's `__cuda_array_interface__`. However, calling a generic kernel with foreign array arguments fails to resolve the concrete overload.
Repro:
from typing import Any
import cupy as cp
import warp as wp
# Explicitly initialize Warp.
wp.init()
# Threads per block; passed as `block_dim` to the tiled launch.
TILE_THREADS = 64
# Number of segments (rows) in the input; one output element per segment.
n_segments = 32 * 4096
# Number of elements in each segment (row length).
segment_size = 512
# Tile length handed to wp.tile_load in the kernel, wrapped as a Warp constant.
TILE_SIZE = wp.constant(segment_size)
# Total number of input elements across all segments.
n_elements = n_segments * segment_size
# Generic segmented-sum kernel: each block reduces row `i` of `in_arr` to a
# single value and writes it to `out_arr` at offset `i`.
#
# Both parameters are annotated with dtype=Any, which makes the kernel generic:
# the concrete overload has to be resolved from the arrays passed at launch.
#
# NOTE: the name is spelled "seqmented" in the report; it is kept unchanged
# because the launch call refers to the kernel by this name.
@wp.kernel
def seqmented_reduce_kernel(in_arr: wp.array2d(dtype=Any), out_arr: wp.array1d(dtype=Any)):
    # obtain our block index
    i = wp.tid()
    # load a row from global memory
    t = wp.tile_load(in_arr[i], TILE_SIZE)
    # cooperatively compute the sum of the tile elements; s is a single element tile
    s = wp.tile_sum(t)
    # store s in global memory
    wp.tile_store(out_arr, s, i)
# Build the input as a CuPy (foreign) array: values 1..n_elements laid out as
# one segment per row, plus an uninitialized CuPy output with one slot per row.
arr = cp.arange(1, n_elements + 1, dtype=cp.float32).reshape(n_segments, segment_size)
out = cp.empty(n_segments, dtype=arr.dtype)

# Launch the generic kernel directly on the CuPy arrays.
wp.launch_tiled(
    kernel=seqmented_reduce_kernel,
    dim=(n_segments,),
    inputs=(arr, out),  # <--- fail: overload is not resolved for foreign arrays
    block_dim=TILE_THREADS,
    device="cuda",
)

# Expected result once the launch works: each output equals its row's sum.
assert cp.allclose(out.reshape((n_segments,)), arr.sum(axis=-1)), f"{out}, {arr.sum(axis=-1)}"