
[Pallas] Introduce make_kernel_from_pallas

Open: alanwaketan opened this issue 11 months ago • 1 comment

Summary: This pull request introduces the make_kernel_from_pallas API, the top-level API for interacting with the Pallas integration. It takes a pallas_call wrapper and turns it into a custom PyTorch op.
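For readers unfamiliar with the term, a "pallas_call wrapper" is an ordinary JAX function that invokes a Pallas kernel through `pl.pallas_call`. The minimal sketch below assumes a recent JAX with `jax.experimental.pallas`. The names `add_vectors_kernel` and `add_vectors` are hypothetical, and `interpret=True` is set only so the example runs on CPU. How make_kernel_from_pallas is then called (for example, how output shapes and dtypes are supplied to it) is not shown here.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def add_vectors_kernel(x_ref, y_ref, o_ref):
    # Each *_ref points at a block of kernel memory. With no grid or
    # BlockSpecs supplied, each block is the whole input/output array.
    o_ref[...] = x_ref[...] + y_ref[...]


@jax.jit
def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
    # Hypothetical pallas_call wrapper: the kind of function that
    # make_kernel_from_pallas is meant to expose as a PyTorch op.
    return pl.pallas_call(
        add_vectors_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=True,  # illustration only: run the kernel on CPU
    )(x, y)


x = jnp.arange(8, dtype=jnp.float32)
y = jnp.ones(8, dtype=jnp.float32)
print(add_vectors(x, y))
```

On a TPU you would drop `interpret=True` so that Pallas compiles the kernel for the device instead of interpreting it.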

Test Plan: python test/test_pallas.py

alanwaketan, Mar 11 '24 19:03

Do you need this PR in 2.3?

Yea, and I will also need a couple of follow-ups for the TODOs.

alanwaketan, Mar 11 '24 19:03

Can I get any reviews?

alanwaketan, Mar 12 '24 23:03

I still think we should refactor convert_torch_dtype_to_jax and investigate bf16 (which I assume most people will use), but I'm approving to unblock.

Yea, for sure. Let me follow up with that.

alanwaketan, Mar 13 '24 00:03
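One possible shape for the convert_torch_dtype_to_jax refactor suggested in review is a table-driven lookup that covers bf16 explicitly. This is a hypothetical sketch, not the code in this PR: the set of dtypes in the table is an assumption, and it needs both `torch` and `jax` installed.

```python
import jax.numpy as jnp
import torch

# Hypothetical lookup table (not this PR's implementation); the set of
# supported dtypes is an assumption for illustration.
_TORCH_TO_JAX_DTYPE = {
    torch.float32: jnp.float32,
    torch.float16: jnp.float16,
    torch.bfloat16: jnp.bfloat16,  # the dtype flagged in review
    torch.int32: jnp.int32,
    torch.int16: jnp.int16,
    torch.int8: jnp.int8,
    torch.uint8: jnp.uint8,
    torch.bool: jnp.bool_,
}


def convert_torch_dtype_to_jax(dtype: torch.dtype):
    """Map a torch dtype to the matching jax.numpy dtype, or raise ValueError."""
    try:
        return _TORCH_TO_JAX_DTYPE[dtype]
    except KeyError:
        raise ValueError(f"Unsupported torch dtype: {dtype}") from None
```

Raising on unknown dtypes keeps an unsupported type from silently passing through, and adding a new dtype becomes a one-line table entry.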