xla
[Pallas] Introduce make_kernel_from_pallas
Summary: This pull request introduces the make_kernel_from_pallas API, the top-level API for interacting with the Pallas integration. It takes a pallas_call wrapper and turns it into a custom PyTorch op.
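To illustrate the general "wrap a kernel callable and expose it as a named op" pattern the summary describes, here is a minimal pure-Python sketch. It is not the make_kernel_from_pallas implementation or signature. The names make_kernel_sketch, lookup_op and _CUSTOM_OPS are hypothetical, and a plain dict stands in for PyTorch/XLA's real custom-op registration.

```python
from typing import Any, Callable, Dict

# Hypothetical registry standing in for a framework's custom-op table.
# The real integration registers a PyTorch/XLA op instead of using a dict.
_CUSTOM_OPS: Dict[str, Callable[..., Any]] = {}


def make_kernel_sketch(name: str, kernel: Callable[..., Any]) -> Callable[..., Any]:
    """Wrap `kernel` (standing in for a pallas_call wrapper) as a named op.

    Illustrative sketch only; not the API introduced by this PR.
    """
    if name in _CUSTOM_OPS:
        raise ValueError(f"op {name!r} is already registered")

    def op(*args: Any) -> Any:
        # A real integration would lower the kernel and dispatch it on
        # device tensors; this sketch simply forwards the call.
        return kernel(*args)

    op.__name__ = name
    _CUSTOM_OPS[name] = op
    return op


def lookup_op(name: str) -> Callable[..., Any]:
    """Return a previously registered op by name (hypothetical helper)."""
    return _CUSTOM_OPS[name]
```

For example, registering an element-wise add with make_kernel_sketch("add_vectors", lambda x, y: [a + b for a, b in zip(x, y)]) returns a callable that can also be retrieved later through lookup_op("add_vectors").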
Test Plan: python test/test_pallas.py
Do you need this PR in 2.3?
Yeah, and I will also need a couple more PRs for the TODOs.
Can I get any reviews?
I still think we should refactor convert_torch_dtype_to_jax and investigate bf16 (which I assume most people will use), but approving to unblock.
Yeah, for sure. Let me follow up on that.