
Update CUDA custom call example code to use `ffi_call`

Open · dfm opened this issue · 0 comments

Following up on https://github.com/google/jax/pull/21925, we can update the example code in `docs/cuda_custom_call` to use `ffi_call` instead of manually registering `core.Primitive`s. This removes a good deal of boilerplate and avoids any direct use of MLIR. The updated example is meant to demonstrate how `ffi_call` can be used for a common use case.

dfm · Jun 27 '24 13:06