Update CUDA custom call example code to use `ffi_call`
Following up on https://github.com/google/jax/pull/21925, this updates the example code in `docs/cuda_custom_call` to use `ffi_call` instead of manually registering `core.Primitive`s. The change removes quite a bit of boilerplate and no longer requires direct use of MLIR. It is meant as a demonstration of how `ffi_call` can be used for a common use case.