
Fusing Taichi with JAX

chaoming0625 opened this issue 2 years ago • 9 comments

We have already seen examples that use Taichi as part of a PyTorch program. For example,

  • https://github.com/ailzhang/blog_code/blob/master/rwkv/benchmark.py
  • https://github.com/ifsheldon/stannum

However, is it possible to integrate Taichi into JAX?

Taichi can generate highly optimized operators and is particularly well suited to implementing operators involving sparse computation. If Taichi kernels could be used in a JAX program, it would be valuable to a broad range of programmers.

I think the key to the integration is obtaining the address of the compiled Taichi kernel. There are examples that launch a GPU kernel compiled by Triton from within JAX; perhaps the same approach is straightforward for Taichi too.

chaoming0625 avatar Oct 18 '22 11:10 chaoming0625

@chaoming0625 I'm by no means a JAX expert, so my guess could be wrong. IIUC, JAX device arrays don't expose a raw pointer to their storage the way PyTorch tensors do, which makes a torch-like zero-copy integration between JAX and Taichi rather hard. One possible workaround: copy the device array out of JAX into a numpy array or torch tensor, both of which Taichi can operate on pretty efficiently.
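For concreteness, a minimal sketch of that copy-based workaround on a CPU backend; the kernel and variable names here are illustrative, not an existing API:

import jax.numpy as jnp
import numpy as np
import taichi as ti

ti.init(arch=ti.cpu)

@ti.kernel
def scale(arr: ti.types.ndarray(), factor: ti.f32):
  for i in range(arr.shape[0]):
    arr[i] = arr[i] * factor

x_jax = jnp.arange(8, dtype=jnp.float32)
x_np = np.array(x_jax)       # host copy of the JAX device array (writable)
scale(x_np, 2.0)             # Taichi operates on the numpy buffer in place
y_jax = jnp.asarray(x_np)    # copy the result back into JAX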

Note that Taichi's sparse computation requires a specific data layout (depending on your SNode structure) in a root buffer managed by Taichi; dense numpy arrays/torch tensors are still the recommended way to exchange those sparse fields with other libraries.
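A minimal sketch of what that means in practice: the sparse field lives in a Taichi-managed layout, and to_numpy() exports it as a dense array that other libraries can consume.

import taichi as ti

ti.init(arch=ti.cpu)   # sparse SNodes need the CPU or CUDA backend

x = ti.field(ti.f32)
block = ti.root.pointer(ti.i, 4)   # sparse: 4 pointer cells...
block.dense(ti.i, 4).place(x)      # ...each holding 4 dense cells

@ti.kernel
def fill():
  x[3] = 1.0   # writing activates the containing block

fill()
dense = x.to_numpy()   # dense (16,) numpy array; inactive cells read as 0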

ailzhang avatar Oct 28 '22 05:10 ailzhang

Dear @ailzhang, one way to interoperate JAX data with Taichi is via DLPack:

import jax.dlpack
import torch
import torch.utils.dlpack

def j2t(x_jax):
  # Zero-copy JAX -> PyTorch via the DLPack protocol.
  x_torch = torch.utils.dlpack.from_dlpack(jax.dlpack.to_dlpack(x_jax))
  return x_torch

def t2j(x_torch):
  # DLPack requires contiguous memory; .contiguous() is a no-op if already so.
  x_torch = x_torch.contiguous()
  x_jax = jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(x_torch))
  return x_jax

This performs a zero-copy conversion from a JAX array to a PyTorch tensor, and the tensor can then be used in Taichi kernels. Finally, tensors returned from the Taichi kernel can also be zero-copied back to JAX.
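An illustrative round trip using the helpers above (my sketch only: it assumes JAX, PyTorch, and Taichi all target the same device, and it writes results into a fresh torch tensor rather than mutating JAX's immutable buffer):

import jax.numpy as jnp
import taichi as ti
import torch

ti.init(arch=ti.cpu)

@ti.kernel
def double(src: ti.types.ndarray(), dst: ti.types.ndarray()):
  for i in range(src.shape[0]):
    dst[i] = src[i] * 2.0

x_jax = jnp.ones(4, dtype=jnp.float32)
x_torch = j2t(x_jax)                   # zero-copy JAX -> torch
out_torch = torch.empty_like(x_torch)  # fresh output buffer
double(x_torch, out_torch)             # Taichi reads/writes the torch storage directly
y_jax = t2j(out_torch)                 # zero-copy torch -> JAX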

I think this may be one possible solution.

chaoming0625 avatar Oct 29 '22 00:10 chaoming0625

We are just wondering where we can get the address of the compiled Taichi kernels. Thanks.

chaoming0625 avatar Oct 30 '22 01:10 chaoming0625

@chaoming0625 Sounds good! Taichi ndarrays are just contiguous memory, so it should be pretty straightforward to support the DLPack format (although Taichi doesn't yet). The compiled-kernel machinery lives at https://github.com/taichi-dev/taichi/blob/master/python/taichi/lang/kernel_impl.py#L574.
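Until DLPack support lands, one way to move data between a Taichi ndarray and other libraries is an explicit numpy copy; a minimal sketch:

import numpy as np
import taichi as ti

ti.init(arch=ti.cpu)

@ti.kernel
def add_one(arr: ti.types.ndarray()):
  for i in range(arr.shape[0]):
    arr[i] = arr[i] + 1.0

nd = ti.ndarray(ti.f32, shape=(8,))            # contiguous, Taichi-managed buffer
nd.from_numpy(np.zeros(8, dtype=np.float32))   # copy in
add_one(nd)
out = nd.to_numpy()                            # copy out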

ailzhang avatar Oct 31 '22 01:10 ailzhang

Dear @ailzhang , that's wonderful. Thanks very much!

chaoming0625 avatar Nov 01 '22 01:11 chaoming0625

Hi @ailzhang

I just wanted to ask if there is any update on this issue, or an alternative to @chaoming0625's solution. Do you plan to implement support for JAX arrays via taichi.ndarray, as was done for PyTorch?

salykova avatar May 09 '23 08:05 salykova

Also curious about this, since I'd like to use some packages written for JAX (numpyro specifically) and try out the Taichi AD system.

maedoc avatar Oct 11 '23 16:10 maedoc

As further motivation, I would love to be able to tap into these JAX projects with Taichi:

  • https://github.com/RobertTLange/evosax
  • https://github.com/google/evojax

jarmitage avatar Dec 01 '23 19:12 jarmitage

See examples in https://github.com/brainpy/BrainPy/pull/553
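For readers who land here: one generic way to call a Taichi kernel from jitted JAX code is jax.pure_callback, which stages a host callback. This is only a minimal CPU-backend sketch, not necessarily the approach taken in that PR:

import jax
import jax.numpy as jnp
import numpy as np
import taichi as ti

ti.init(arch=ti.cpu)

@ti.kernel
def square(src: ti.types.ndarray(), dst: ti.types.ndarray()):
  for i in range(src.shape[0]):
    dst[i] = src[i] * src[i]

def taichi_square(x):
  x_np = np.array(x, dtype=np.float32)   # writable host copy
  out = np.empty_like(x_np)
  square(x_np, out)
  return out

@jax.jit
def f(x):
  # pure_callback needs the output shape/dtype declared up front
  return jax.pure_callback(taichi_square, jax.ShapeDtypeStruct(x.shape, x.dtype), x)

print(f(jnp.arange(4, dtype=jnp.float32)))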

chaoming0625 avatar Dec 12 '23 09:12 chaoming0625