
[BUG] Non-contiguous tile with Any typing in wp.func

Opened by etaoxing.

Bug Description

I run into a Warp NVRTC compilation error when I pass a non-contiguous tile (the result of wp.tile_broadcast) into a wp.func whose tile parameter is typed as Any.
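As an analogy for why a broadcast result counts as "non-contiguous", consider NumPy's np.broadcast_to. It returns a view whose broadcast dimensions have stride 0, so every element aliases the same value. The sketch below is only an analogy. It assumes, without confirmation from this report, that Warp's broadcast tiles likewise use zero strides in the broadcast dimensions. The shapes mirror the (TILE_M, TILE_K) = (1, 10) tile in the reproducer.

```python
import numpy as np

# A (1, 10) float32 block, analogous to the (TILE_M, TILE_K) tile in the repro.
a = np.random.rand(1, 10).astype(np.float32)

# Reduce to one value, then broadcast it back to (1, 10), as
# wp.tile_broadcast(wp.tile_max(a), ...) does in the kernel.
# np.broadcast_to returns a read-only view with zero strides.
a_max = np.broadcast_to(a.max(), (1, 10))

print(a_max.shape)                   # (1, 10)
print(a_max.strides)                 # (0, 0)
print(a_max.flags["C_CONTIGUOUS"])   # False

# Arithmetic with the broadcast view materializes a new, contiguous array.
diff = a - a_max
print(diff.flags["C_CONTIGUOUS"])    # True
```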

Full error log: tile_fn_error.log (attached). Minimal reproducer:

from typing import Any
import warp as wp
import torch

DIM_BATCH = 2
DIM_IN = 10
SHIFTED = True

TILE_M = wp.constant(1)
TILE_K = wp.constant(DIM_IN)

@wp.func
def func2(tile: Any):
  return tile

@wp.kernel
def kernel(x: wp.array2d(dtype=wp.float32), y: wp.array2d(dtype=wp.float32)):
  i = wp.tid()
  a = wp.tile_load(x, shape=(TILE_M, TILE_K), offset=(i * TILE_M, 0))
  if wp.static(SHIFTED):  # the error occurs both inside this conditional and when the block is moved outside it
    a_max = wp.tile_broadcast(wp.tile_max(a), shape=(TILE_M, TILE_K))
    a_max2 = func2(a_max)
    a -= a_max2  # or a2 = a - a_max2

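# Allocate on the GPU: NVRTC is only used when compiling for a CUDA device.
x, y = torch.rand((DIM_BATCH, DIM_IN), device="cuda"), torch.rand((DIM_BATCH, DIM_IN), device="cuda")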
x, y = wp.from_torch(x), wp.from_torch(y)

kernel_dims = (DIM_BATCH,)
inputs, outputs = [x], [y]

wp.launch_tiled(
    kernel,
    dim=kernel_dims,
    inputs=inputs,
    outputs=outputs,
    device=x.device,
    block_dim=64,
)

System Information

Reproduced on Warp nightly 1.8.0.dev20250523 (commit https://github.com/NVIDIA/warp/commit/df2df7de42f05fdde77d16261da1c61d34cdaa1f).

Posted by etaoxing on May 24 '25 at 10:05.