Neil Girdhar

Results: 329 comments by Neil Girdhar

> By the way, scan has a useful unroll option which will do this for you, so it should be easy to experiment:

I know that `scan` has an unroll,...
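For reference, a minimal sketch of what experimenting with `unroll` might look like (the step function and shapes here are made up purely for illustration):

```python
import jax
import jax.numpy as jnp

def step(carry, x):
    # Hypothetical loop body; substitute the real per-step computation here.
    new_carry = jnp.tanh(carry + x)
    return new_carry, new_carry

init = jnp.zeros(128)
xs = jnp.ones((1000, 128))

# unroll=1 is the default; unroll=8 asks XLA to unroll 8 steps per loop
# iteration, which can give the compiler more to fuse at the cost of
# longer compile times.
final_carry, ys = jax.lax.scan(step, init, xs, unroll=8)
```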

Here are the screenshots you requested. Again, thanks for all your help! ![Trace Screenshot 1](https://user-images.githubusercontent.com/730137/122846180-4d323300-d2d3-11eb-9c5a-1cdfa9b40462.png) ![Trace Screenshot 2](https://user-images.githubusercontent.com/730137/122846181-4dcac980-d2d3-11eb-9026-9a2c2e23136b.png)

> adding the softplus may mean that you have to launch two kernels (one for the dot, one for the presumably fused xla-generated softplus computation)

That would definitely explain some...
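For context, here is a tiny stand-in for the kind of computation being discussed (names and shapes are mine, not from the actual trace): a dot followed by a softplus, where the two pieces may end up as separate kernel launches.

```python
import jax
import jax.numpy as jnp

@jax.jit
def dot_then_softplus(w, x):
    # The dot typically lowers to a library matmul kernel on GPU, while the
    # softplus becomes a separate (possibly fused) XLA elementwise kernel.
    return jax.nn.softplus(x @ w)

key = jax.random.PRNGKey(0)
w = jax.random.normal(key, (512, 512))
x = jax.random.normal(key, (64, 512))
y = dot_then_softplus(w, x)
```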

So, in the end, is there anything I can do to make my program run faster?

I hope you don't mind that I'm looking at this again. From what I understand, https://github.com/openai/triton tries to produce single GPU kernels. Is there any hope of JAX doing something...

> Did you see this jax/triton project https://github.com/jax-ml/jax-triton?

No, I had not! Thank you for sharing that. I was aware of Triton, so this is very exciting!

> Also, what...

> The quickest path would be some manual kernel via custom ops or maybe via Triton. Do you agree with that?

You know better than I do :smile: If you...

Thanks. I will look into that! Would I have access to JAX's automatic differentiation? Or would I need to do the differentiation myself and then implement that in CUDA?
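To illustrate what "doing the differentiation myself" could look like on the JAX side, here is a hedged sketch using `jax.custom_vjp`. The plain-JAX body below is only a stand-in for an opaque custom kernel that JAX cannot trace through; the backward rule is written by hand.

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def my_softplus(x):
    # Stand-in body; in the real case this would call an opaque custom kernel
    # that JAX cannot differentiate through.
    return jnp.logaddexp(x, 0.0)

def my_softplus_fwd(x):
    # Forward pass: return the output plus residuals for the backward pass.
    return my_softplus(x), x

def my_softplus_bwd(x, g):
    # Hand-written derivative: d/dx softplus(x) = sigmoid(x).
    return (g * jax.nn.sigmoid(x),)

my_softplus.defvjp(my_softplus_fwd, my_softplus_bwd)

grads = jax.grad(lambda x: my_softplus(x).sum())(jnp.arange(3.0))
```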

First of all, JAX Triton looks amazing! Yes, it should solve my problem with quite a bit of work on my side. So thank you for that. However, I have...
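For anyone following along, this is roughly what wiring a Triton kernel into JAX looks like, adapted from my reading of the jax-triton README (an element-wise add over a fixed size of 8; the API details may differ across versions):

```python
import jax
import jax.numpy as jnp
import jax_triton as jt
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, output_ptr, block_size: tl.constexpr):
    # Each program instance handles one block of `block_size` elements.
    pid = tl.program_id(axis=0)
    block_start = pid * block_size
    offsets = block_start + tl.arange(0, block_size)
    mask = offsets < 8  # total number of elements in this toy example
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(output_ptr + offsets, x + y, mask=mask)

def add(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
    out_shape = jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype)
    return jt.triton_call(
        x, y,
        kernel=add_kernel,
        out_shape=out_shape,
        grid=(x.size // 8,),
        block_size=8)

x = jnp.arange(8, dtype=jnp.float32)
y = jnp.arange(8, 16, dtype=jnp.float32)
print(jax.jit(add)(x, y))
```

As I understand it, the resulting function is still an ordinary JAX function, so it composes with `jax.jit` and the rest of the program.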

> You do not need to do any memory allocation. You still use JAX. You write the custom kernel that XLA doesn't generate fast code for.

I understand. What I'm...