[Target] Support CUDA device function calls
This commit adds support for CUDA device function calls by:
- Modifying the calling convention handling in CUDA codegen to support both device kernel launches and device function calls
- Updating the function signature printing to emit appropriate CUDA attributes (global vs device) based on calling convention
- Adding a test case demonstrating device function calls
- Fixing target handling in split_host_device_mods to properly handle device function dictionaries
- Adding a safety check for global symbol extraction
The changes enable proper compilation and execution of CUDA device functions that can be called from CUDA kernels.
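As a rough sketch of the codegen-side decision described above, here is the global-vs-device split expressed in plain Python. The enum values and function name are illustrative stand-ins, not TVM's actual codegen API:

```python
# Hypothetical sketch: choosing the CUDA function qualifier from the
# calling convention, mirroring the commit's global-vs-device split.
# Names here are illustrative, not TVM's real definitions.
from enum import Enum

class CallingConv(Enum):
    DEFAULT = 0
    C_PACKED_FUNC = 1
    DEVICE_KERNEL_LAUNCH = 2

def cuda_qualifier(conv: CallingConv) -> str:
    """Kernels launched from the host get __global__; functions
    called from other device code get __device__."""
    if conv == CallingConv.DEVICE_KERNEL_LAUNCH:
        return "__global__"
    return "__device__"

print(cuda_qualifier(CallingConv.DEVICE_KERNEL_LAUNCH))  # __global__
print(cuda_qualifier(CallingConv.DEFAULT))               # __device__
```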
Example:
```python
@I.ir_module
class Module:
    @T.prim_func(private=True)
    def add(a: T.float32, b: T.float32) -> T.float32:
        return a + b

    @T.prim_func
    def main(
        A: T.Buffer((1024, 1024), "float32"),
        B: T.Buffer((1024, 1024), "float32"),
        C: T.Buffer((1024, 1024), "float32"),
    ):
        for bx in T.thread_binding(1024, "blockIdx.x"):
            for tx in T.thread_binding(1024, "threadIdx.x"):
                C[bx, tx] = Module.add(A[bx, tx], B[bx, tx])
```
Super helpful enhancement, thanks!
After reading the comments so far on the host/device function split and the compiler phases:
- S0: In the beginning (before SplitHostDevice), we don't distinguish host from device functions; a function can contain kernels.
- S1: The host/device split becomes clear after the SplitHostDevice pass. Currently, in the single-device-launch case:
  - global kernels are annotated with the DeviceKernelLaunch calling convention
  - host functions are annotated with other conventions

Once we enable the compiler to handle device functions, the first thing we need to pin down is the expected behavior after S1. It would be useful to clarify this in the PR with comments.
Summarizing the logic so far:
- At S0 (before SplitHostDevice), the decision is to not distinguish host from device functions; the distinction stays implicit.
- The distinction should become clear after S1, by checking the target annotation of each function, which marks its default calling convention.
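The summarized rule could be sketched as a reachability check over the call graph: anything reachable from a device-side call site gets the device target. This is illustrative pseudologic only, not TVM's actual pass API:

```python
# Hypothetical sketch: classify each function as 'host' or 'device' so the
# target annotation is unambiguous after S1 (SplitHostDevice).
# The call-graph representation here is a toy, not TVM's IRModule.
def classify_functions(call_graph, device_entry_calls):
    """call_graph maps function name -> set of callee names.
    Every function reachable from a device-side call site is 'device';
    everything else stays 'host'."""
    device = set()
    stack = list(device_entry_calls)
    while stack:
        fn = stack.pop()
        if fn in device:
            continue
        device.add(fn)
        stack.extend(call_graph.get(fn, ()))
    return {fn: ("device" if fn in device else "host") for fn in call_graph}

# 'add' is only called inside the thread-bound loops, i.e. from device code.
graph = {"main": {"add"}, "add": set()}
print(classify_functions(graph, device_entry_calls={"add"}))
```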
Here is an example of such a case:
```python
@I.ir_module
class Module:
    @T.prim_func(private=True)
    def add(a: T.float32, b: T.float32) -> T.float32:
        return a + b

    @T.prim_func
    def main(
        A: T.Buffer((1024, 1024), "float32"),
        B: T.Buffer((1024, 1024), "float32"),
        C: T.Buffer((1024, 1024), "float32"),
    ):
        # bind a temp var on the host side
        temp_var = T.float32()
        with T.LetStmt(
            Module.add(T.float32(1), T.float32(2)),
            var=temp_var,
        ):
            for bx in T.thread_binding(1024, "blockIdx.x"):
                for tx in T.thread_binding(1024, "threadIdx.x"):
                    C[bx, tx] = Module.add(A[bx, tx], B[bx, tx]) + temp_var
```
Because of this implicitness, we may need to cross-check the current behavior of SplitHostDevice for rare cases where both host and device call the same function. In such cases we could either:
- S0a: place a constraint and report an error, or
- S0b: have the SplitHostDevice pass duplicate such functions and mark the target on each copy.

In both cases, it would be good to enhance the SplitHostDevice test cases to ensure the target field is unambiguous after S1.
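Option S0b could look roughly like the following: when a callee is reachable from both sides, duplicate it and point each caller at the copy with the matching target. This is an illustrative sketch of the duplication idea, not the actual SplitHostDevice implementation:

```python
# Hypothetical sketch of option S0b: duplicate a function called from both
# host and device, so each copy can carry a single unambiguous target.
# The function table is a toy dict, not TVM's IRModule.
def split_mixed_callee(funcs, host_callers, device_callers, name):
    """funcs: name -> function body (opaque here). Returns the updated table
    plus a rewrite map telling each side which copy to call."""
    called_from_host = any(name in callees for callees in host_callers.values())
    called_from_device = any(name in callees for callees in device_callers.values())
    rewrites = {}
    if called_from_host and called_from_device:
        body = funcs.pop(name)
        funcs[name + "_host"] = body      # this copy gets the host target
        funcs[name + "_device"] = body    # this copy gets the device target
        rewrites = {"host": name + "_host", "device": name + "_device"}
    return funcs, rewrites

funcs = {"add": "<body>", "main": "<body>"}
funcs, rw = split_mixed_callee(
    funcs,
    host_callers={"main": {"add"}},      # e.g. computing a launch extent
    device_callers={"main": {"add"}},    # e.g. the per-element call
    name="add",
)
print(sorted(funcs), rw)
```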
Thanks @tqchen and @Kathryn-cat for the valuable comments; I will refactor the PR to enhance SplitHostDevice systematically.
@tqchen @Kathryn-cat I've updated a version that detects whether a function is on the host side or the device side at an early stage (in the BindTarget pass). Here is an example of a mixed call to the same function:
```python
@I.ir_module
class Module:
    @T.prim_func(private=True)
    def add(a: T.int32, b: T.int32) -> T.int32:
        return a + b

    @T.prim_func
    def main(
        A: T.Buffer((128, 128), "int32"),
        B: T.Buffer((128, 128), "int32"),
        C: T.Buffer((128, 128), "int32"),
    ):
        T.func_attr({"global_symbol": "main"})
        length: T.int32 = Module.add(64, 64)  # call from host
        for bx in T.thread_binding(length, "blockIdx.x"):
            for tx in T.thread_binding(length, "threadIdx.x"):
                C[bx, tx] = Module.add(A[bx, tx], B[bx, tx])  # call from device
```
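A plain-Python simulation of the example above, just to check the intended semantics: the same `add` is used once on the host (to compute the launch extent) and once per element on the device. The loop nest stands in for the thread bindings:

```python
# Plain-Python simulation of the mixed-call example: one host-side call
# computes the launch extent, one device-side call runs per element.
def add(a: int, b: int) -> int:
    return a + b

def main(A, B):
    length = add(64, 64)  # host-side call: length == 128
    C = [[0] * length for _ in range(length)]
    for bx in range(length):       # stands in for blockIdx.x
        for tx in range(length):   # stands in for threadIdx.x
            C[bx][tx] = add(A[bx][tx], B[bx][tx])  # device-side call
    return C

A = [[1] * 128 for _ in range(128)]
B = [[2] * 128 for _ in range(128)]
C = main(A, B)
print(C[0][0], len(C))  # 3 128
```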
Please review again when you have time
@Hzfengsy LGTM! Just added a small comment and I think we're good to go.