Repository: warp

[REQ] Allow generic kernel overload resolution with foreign arrays

Status: Open · opened by nvlukasz on Mar 9, 2025 · 0 comments

Description

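Passing foreign arrays (e.g., CuPy arrays) to Warp kernels works because Warp parses their `__cuda_array_interface__`. However, when a generic kernel (one whose array parameters are annotated with `dtype=Any`) is launched with foreign array arguments, Warp fails to resolve the concrete overload. Ideally, the overload would be inferred from the dtype and dimensionality reported by the array interface, as it is for native `wp.array` arguments.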

Repro:

from typing import Any
import cupy as cp
import warp as wp

wp.init()

# Problem size: n_segments rows, each segment_size elements long.
segment_size = 512
n_segments = 32 * 4096
n_elements = n_segments * segment_size

# Threads per block for the tiled launch (passed as block_dim below).
TILE_THREADS = 64

# Tile length (one full row), wrapped in wp.constant so the kernel can reference it.
TILE_SIZE = wp.constant(segment_size)

# Segmented sum: launch index i reduces row i of `in_arr` and stores the result
# into `out_arr` at position i.
# The `dtype=Any` annotations make this a generic kernel, so Warp must pick a
# concrete overload from the argument types at launch time -- the step that
# fails when the arguments are foreign (CuPy) arrays.
# NOTE: "seqmented" is a typo for "segmented"; the name is kept as-is because
# the launch call references it.
@wp.kernel
def seqmented_reduce_kernel(in_arr: wp.array2d(dtype=Any), out_arr: wp.array1d(dtype=Any)):
    # obtain our block index
    # (with wp.launch_tiled, presumably one block of block_dim threads per
    # index -- TODO confirm against the Warp tile documentation)
    i = wp.tid()

    # load a row from global memory
    # (in_arr[i] selects row i; TILE_SIZE equals segment_size, so the whole
    # row is loaded as one tile)
    t = wp.tile_load(in_arr[i], TILE_SIZE)

    # cooperatively compute the sum of the tile elements; s is a single element tile
    s = wp.tile_sum(t)

    # store s in global memory
    # (the third argument i is presumably the destination offset in out_arr --
    # confirm against the installed Warp version's tile_store signature)
    wp.tile_store(out_arr, s, i)


# Input: n_segments rows of segment_size consecutive float32 values 1, 2, ...;
# output: one slot per row that receives that row's sum.
arr = cp.arange(1, n_elements + 1, dtype=cp.float32).reshape(n_segments, segment_size)
out = cp.empty(n_segments, dtype=arr.dtype)

# One launch index per row. Passing the CuPy arrays straight to the generic
# kernel is what fails: Warp cannot resolve the concrete overload from them.
wp.launch_tiled(
    kernel=seqmented_reduce_kernel,
    dim=(n_segments,),
    inputs=(arr, out),  # <--- fail
    block_dim=TILE_THREADS,
    device="cuda",
)

# `out` is already 1-D with n_segments entries, so compare it directly with
# the per-row sums computed by CuPy.
assert cp.allclose(out, arr.sum(axis=-1)), f"{out}, {arr.sum(axis=-1)}"

Context

Posted by nvlukasz on Mar 9, 2025 at 16:03.