[WIP] Add custom call primitive interface
Not yet ready for review.
This is a sketch of an interface that abstracts away some of the boilerplate required when defining primitives that get lowered to custom calls. A few design decisions still need more careful consideration. My goal is to support a wide range of use cases without over-committing to a huge public API.
I note that much of this implementation comes from @superbobry.
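For context on the boilerplate mentioned above, here is a minimal sketch of what defining a primitive by hand tends to look like with existing JAX APIs. It does not show the interface proposed in this PR. The primitive name `double_sketch` and the wrapper `double` are hypothetical, and `mlir.lower_fun` is used only as a stand-in for a real custom call lowering, which would also need a registered custom call target.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.interpreters import mlir

try:  # newer JAX releases expose Primitive under jax.extend
    from jax.extend.core import Primitive
except ImportError:  # fall back for older releases
    from jax.core import Primitive

# Hypothetical primitive, for illustration only.
double_p = Primitive("double_sketch")


def double(x):
    """User-facing wrapper that binds the primitive."""
    return double_p.bind(x)


# 1. Eager implementation, used when the primitive runs outside of jit.
double_p.def_impl(lambda x: np.asarray(x) * 2)

# 2. Abstract evaluation: the output has the same shape and dtype as the input.
double_p.def_abstract_eval(lambda x: jax.core.ShapedArray(x.shape, x.dtype))

# 3. Lowering rule. A custom call primitive would emit a custom call here;
#    this sketch lowers an ordinary JAX function instead (an assumption made
#    so the example runs without any compiled kernel).
mlir.register_lowering(
    double_p,
    mlir.lower_fun(lambda x: x * 2, multiple_results=False),
)

eager_result = double(jnp.arange(3.0))
jit_result = jax.jit(double)(jnp.arange(3.0))
```

Each of the three registration steps has to be repeated for every new primitive, and batching or differentiation rules would add more on top.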