Frédéric Bastien
You don't have a normal setup, and there isn't any documentation for it. Some of the packages above are provided by the CUDA SDK, but cuDNN and NCCL aren't. Can you try...
You probably installed too many packages. I don't think this is a bug. Should we close this issue?
Instead of using cudaMalloc/cudaFree, have you looked at the CUDA stream-ordered memory allocator (cudaMallocAsync/cudaFreeAsync)? The free isn't blocking: cudaFreeAsync is enqueued on a stream instead of synchronizing the device the way cudaFree does. There is a blog post about this: https://developer.nvidia.com/blog/using-cuda-stream-ordered-memory-allocator-part-1/ Also, allocation and free are much faster.
Hi, there isn't an easy way to do this right now. To use it in JAX *without* jit, you could make a JAX primitive in Python. But...
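A minimal sketch of a Python-defined primitive that only has an eager implementation rule. It assumes a recent JAX where `Primitive` lives in `jax.extend.core` (older releases exposed it as `jax.core.Primitive`, hence the fallback import). The name `my_op` and its NumPy body are hypothetical placeholders. Since only `def_impl` is registered, calling it under `jax.jit` would fail; jit support would also need an abstract evaluation rule and a lowering.

```python
import numpy as np
import jax.numpy as jnp

try:
    from jax.extend.core import Primitive  # newer JAX releases
except ImportError:
    from jax.core import Primitive  # older JAX releases

# Hypothetical primitive wrapping arbitrary host-side Python code.
my_op_p = Primitive("my_op")


def _my_op_impl(x):
    # Eager-only rule: it receives concrete arrays, so any NumPy/Python code works.
    host_result = 2.0 * np.sin(np.asarray(x))
    return jnp.asarray(host_result)


my_op_p.def_impl(_my_op_impl)


def my_op(x):
    """User-facing wrapper: outside jit, bind() dispatches to the impl rule."""
    return my_op_p.bind(x)


print(my_op(jnp.arange(3.0)))
```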
I don't think it is trivial, but it could be mechanical. That part of the code isn't simple. You can use the jax-triton code as a starting point. This JAX custom operation...