Rotary Embeddings + Triton

david-macleod opened this issue 2 years ago · 12 comments

🚀 Feature

There is a suggestion here that the next step for rotary embeddings is to use Triton. I would be keen to have a go at implementing this, unless it is already a work in progress?

Thanks!

david-macleod avatar Apr 09 '22 14:04 david-macleod

hi @david-macleod, seeing this a bit late, sorry about that and thanks for reaching out! Nothing has started that I know of. It would be great indeed, and a good match with current PRs like #263 and #248 (a bit stuck by the testing system right now :( working on that but short on time).

If you get started on that, I would recommend:

  • writing a small benchmark + parity test as a very first step; it gives a north star, and you cannot really improve what you cannot measure. Bonus if you can express the benchmark in terms of bandwidth (this is probably bandwidth limited), since that can tell us early whether this is a worthwhile use of your time or not (if it's already close to your GPU's peak bandwidth there's not much that Triton will do, of course). See the sketch after this list.

  • using triton2 (aka the version currently on the triton repo; there are pip packages available, and that's the version we're moving towards)
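For reference, a minimal sketch of what such a benchmark + parity harness could look like (`reference_rotary` and `triton_rotary` are placeholders for the eager and Triton implementations, and the shapes are illustrative); the effective bandwidth it reports can be compared against the GPU's peak DRAM bandwidth to judge how much headroom is left:

```python
import torch

# `reference_rotary` and `triton_rotary` are placeholders for the eager and
# Triton implementations being compared.

def benchmark(fn, q, cos, sin, iters=100):
    # Warm up, then time with CUDA events so only GPU time is measured.
    for _ in range(10):
        fn(q, cos, sin)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(q, cos, sin)
    end.record()
    torch.cuda.synchronize()
    ms = start.elapsed_time(end) / iters
    # Rough effective bandwidth: one read + one write of q per call.
    gbps = 2 * q.numel() * q.element_size() * 1e-9 / (ms * 1e-3)
    return ms, gbps

q = torch.randn(8, 1024, 16, 64, device="cuda", dtype=torch.float16)
cos = torch.randn(1, 1024, 1, 64, device="cuda", dtype=torch.float16)
sin = torch.randn_like(cos)

# Parity first, then speed; compare the reported GB/s to the GPU's peak bandwidth.
torch.testing.assert_close(triton_rotary(q, cos, sin), reference_rotary(q, cos, sin), rtol=1e-2, atol=1e-2)
print(benchmark(reference_rotary, q, cos, sin))
print(benchmark(triton_rotary, q, cos, sin))
```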

blefaudeux avatar Apr 13 '22 02:04 blefaudeux

Hi @blefaudeux, I will share some timings soon, but it initially looks promising, primarily because torch.jit appears to be able to fuse apply_rotary_pos_emb into a single kernel in the non-autograd scenario (forward only), but fails to fuse the cat op in the autograd case. Additionally, it fails to create a single fused kernel with NvFuser in half precision (e.g. torch.jit.fuser("fuser2")) even in the non-autograd case, so there definitely seem to be some potential benefits from Triton.
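For context, the kind of check described here can be sketched as follows, using the standard rotate-half formulation of rotary embeddings (illustrative code, not the exact xformers implementation; `graph_for` is a semi-internal TorchScript helper for inspecting the optimized graph):

```python
import torch

def rotate_half(x):
    # Standard rotary helper: swap and negate the two halves of the last dim.
    half = x.shape[-1] // 2
    x1 = x.narrow(-1, 0, half)
    x2 = x.narrow(-1, half, half)
    return torch.cat([-x2, x1], dim=-1)

@torch.jit.script
def apply_rotary_pos_emb(x, cos, sin):
    return (x * cos) + (rotate_half(x) * sin)

x = torch.randn(1, 1024, 16, 64, device="cuda", dtype=torch.float16)
cos = torch.randn(1, 1024, 1, 64, device="cuda", dtype=torch.float16)
sin = torch.randn_like(cos)

# Route TorchScript through NvFuser ("fuser2"); after a few warm-up runs the
# optimized graph shows whether everything ended up in a single fusion group.
with torch.jit.fuser("fuser2"):
    for _ in range(5):
        apply_rotary_pos_emb(x, cos, sin)
    print(apply_rotary_pos_emb.graph_for(x, cos, sin))
```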

I have managed to get output parity for a batch size 1 setup, and was starting to look at wrapping the kernel, based partially on sum_strided.py, but even after reading the tutorials I have a few questions as a Triton/CUDA newcomer which I hoped you could help with!

  • What are the heuristics [here](https://github.com/facebookresearch/xformers/blob/main/xformers/triton/sum_strided.py#L38-L43) based on? As I understand it, they are related to the amount of shared memory for a single SM (or the size of the L2 cache), to avoid swapping shared memory during the lifetime of a single program, but I am not entirely sure how the ranges are computed, and whether they depend on the GPU arch.
  • Related to the above, what is the relationship between the Triton BLOCK_SIZE and the thread block size in CUDA? The latter seems to define the number of threads available to that block (the degree of parallelism with access to the same shared memory?) and is limited to 1024 on A100. I thought the Triton block size was somewhat analogous, but there are many examples of BLOCK_SIZE > 1024 in the Triton kernels, and it appears not to relate to the number of threads accessible to the Triton program, which is controlled instead by num_warps. Is this correct?

Thanks!

david-macleod avatar May 08 '22 20:05 david-macleod

> Hi @blefaudeux, I will share some timings soon, but it initially looks promising, primarily because torch.jit appears to be able to fuse apply_rotary_pos_emb into a single kernel in the non-autograd scenario (forward only), but fails to fuse the cat op in the autograd case. Additionally, it fails to create a single fused kernel with NvFuser in half precision (e.g. torch.jit.fuser("fuser2")) even in the non-autograd case, so there definitely seem to be some potential benefits from Triton.

Hey David, sorry for the delay, this sounds great! Yes, for torchscript I'm not too surprised, but I thought that nvfuser would do a better job here, from a distance. Good to know.

> I have managed to get output parity for a batch size 1 setup, and was starting to look at wrapping the kernel, based partially on sum_strided.py, but even after reading the tutorials I have a few questions as a Triton/CUDA newcomer which I hoped you could help with!

Oh, great if you already have something at parity for batch 1, scaling that up should not be a super big issue! Tentatively replying below.

> * What are the heuristics [here](https://github.com/facebookresearch/xformers/blob/main/xformers/triton/sum_strided.py#L38-L43) based on? As I understand it, they are related to the amount of shared memory for a single SM (or the size of the L2 cache), to avoid swapping shared memory during the lifetime of a single program, but I am not entirely sure how the ranges are computed, and whether they depend on the GPU arch.

Top of mind, it came from a dropout or layernorm tutorial some time ago, but these have changed. In general this sum_strided kernel is most probably not optimal, and I would stick to triton autotune nowadays. Else yes, it's related to the number of cuda threads you keep in flight, the cost of swapping them when waiting for IO, and the parallelism you can extract from having several of them in flight (for instance by coalescing memory accesses on a strided dimension). In general for kernels this is multi-factored; @ptillet could explain a lot more, maybe @ngimel also?
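To make the autotune suggestion concrete, a rough sketch over a strided row-sum kernel (the configuration values are illustrative, not tuned):

```python
import triton
import triton.language as tl

# Let triton pick the (BLOCK_SIZE, num_warps) pair by benchmarking a few
# candidates at runtime; the best one is cached per value of `n_cols`.
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 128}, num_warps=2),
        triton.Config({"BLOCK_SIZE": 512}, num_warps=4),
        triton.Config({"BLOCK_SIZE": 1024}, num_warps=8),
    ],
    key=["n_cols"],
)
@triton.jit
def row_sum_kernel(x_ptr, out_ptr, stride, n_cols, BLOCK_SIZE: tl.constexpr):
    row = tl.program_id(0)
    acc = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
    # Walk the row in BLOCK_SIZE chunks so that every config stays correct.
    for start in range(0, n_cols, BLOCK_SIZE):
        cols = start + tl.arange(0, BLOCK_SIZE)
        mask = cols < n_cols
        acc += tl.load(x_ptr + row * stride + cols, mask=mask, other=0.0).to(tl.float32)
    tl.store(out_ptr + row, tl.sum(acc, axis=0))

# Launch with one program per row; BLOCK_SIZE is filled in by the chosen config:
# row_sum_kernel[(n_rows,)](x, out, x.stride(0), n_cols)
```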

As far as I'm concerned, I can often find some logic for a given kernel, but cannot really derive a general rule for these. The best settings will depend on the GPU arch, and even on the GPU within an arch (empirically I find that a 3080 laptop, a 3080 desktop and an A100 don't expose the same bottlenecks for the same kernels); that's where using autotune and the triton JIT shines, I think.

> * Related to the above, what is the relationship between the Triton BLOCK_SIZE and the thread block size in CUDA? The latter seems to define the number of threads available to that block (the degree of parallelism with access to the same shared memory?) and is limited to 1024 on A100. I thought the Triton block size was somewhat analogous, but there are many examples of BLOCK_SIZE > 1024 in the Triton kernels, and it appears not to relate to the number of threads accessible to the Triton program, which is controlled instead by num_warps. Is this correct?

(I've rewritten this a few times, this is my best tentative explanation :)) It's a really good question, I think, thanks for asking it here. Note that I'm not super familiar with the CUDA concepts anymore (I used to code in CUDA, but that was long enough ago that I've forgotten a lot of the specifics).

  • triton BLOCK_SIZE and thread block size in CUDA: not really related language-wise. I think it's better to think of the triton block sizes in general as the size of the memory block the kernel(s) work on; triton does the actual dispatch in terms of threads. I'm guessing that Triton takes some of this into account when actually scheduling the work, but it's not something semantically related AFAIK. Something which works for me is to think of the BLOCK_SIZE as a way for you to define the size of the work unit, in terms of read/intermediate/write, and to defer a lot of the scheduling logic to triton through a couple of knobs (num_warps mostly, and otherwise changing the size of these compute "tiles" depending on the problem size and the parallelism which can be found).

  • thread block size (as per this definition): that would be closer to the number of warps used, I think, with the relationship being that 1 warp == 32 threads. But (I may be stating something wrong here and I hope that @ptillet will catch me) I don't think it's an exact match, since the kernels written in triton don't have to communicate with each other via shared memory. It's sometimes better if they can, I suppose, but it's not mandatory; maybe triton adds some logic around that when kicking off the kernels / grid size (using thread blocks which are not == warps * 32). It's not an explicit concept in Triton since you don't specify where to allocate memory anyway; to me it's part of the complexity that triton hides in an intermediate layer.

blefaudeux avatar May 09 '22 14:05 blefaudeux

Yes, @blefaudeux is correct, so I don't have much to add and will just repeat differently :p

BLOCK_SIZE isn't a Triton keyword, it is just a constexpr for the tile sizes in the tutorial kernel. Having a BLOCK_SIZE=1024 in the case of the vecadd/norm tutorials just means that each program id will process 1024 elements.

Orthogonally, each Triton kernel is launched with a thread block size of 32*num_warps. It means that if you compile an aforementioned kernel with num_warps=4 -- the default value -- it will be launched with 128 threads, and each thread will be responsible for loading+adding+storing 8 elements.
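To illustrate those numbers, a minimal vector-add sketch along the lines of the vecadd tutorial (illustrative, not the exact tutorial code): with BLOCK_SIZE=1024 and num_warps=4, each program instance runs with 128 threads, so each thread covers 8 elements.

```python
import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK_SIZE: tl.constexpr):
    # Each program instance owns a tile of BLOCK_SIZE elements of the vectors.
    pid = tl.program_id(0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    y = tl.load(y_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x + y, mask=mask)

x = torch.randn(1 << 20, device="cuda")
y = torch.randn_like(x)
out = torch.empty_like(x)

# BLOCK_SIZE=1024 elements per program, num_warps=4 -> 4 * 32 = 128 threads,
# so each thread ends up handling 1024 / 128 = 8 elements.
grid = (triton.cdiv(x.numel(), 1024),)
add_kernel[grid](x, y, out, x.numel(), BLOCK_SIZE=1024, num_warps=4)
```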

ptillet avatar May 09 '22 15:05 ptillet

Thanks for the replies, that makes things a lot clearer! @ptillet why is it that in the softmax tutorial the BLOCK_SIZE is set to the next power of 2 greater than the number of elements to be loaded? Is there a reason we should not just set it to a multiple of the number of threads per warp?

david-macleod avatar May 12 '22 20:05 david-macleod

Ha, that's a good question. A lot of kernels can still work when the block size is a multiple of 32*num_warps. But (1) we want num_warps to be purely an optimization, meaning that it shouldn't affect the specs of the language; (2) there are a few places where the compiler assumes that shapes are a power-of-two; (3) some future hardware architectures may not even have a concept of warps, and we want Triton to be as portable as possible.

Right now, we are aware that a lot of perf is lost to padding. This is an issue that I've mitigated lately by using smaller block sizes in conjunction with L1 residency control (see the new 05-layer-norm.py tutorial on the triton repo).
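For reference, the padding pattern in question looks roughly like this in a softmax-style row kernel (a sketch along the lines of the tutorial, not the exact code): the row length is rounded up to the next power of two on the host side, and the mask keeps the padded lanes inert.

```python
import triton
import triton.language as tl

@triton.jit
def softmax_row_kernel(x_ptr, out_ptr, stride, n_cols, BLOCK_SIZE: tl.constexpr):
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_SIZE)      # BLOCK_SIZE: next power of two >= n_cols
    mask = cols < n_cols                 # lanes past n_cols are padding
    x = tl.load(x_ptr + row * stride + cols, mask=mask, other=-float("inf"))
    x = x - tl.max(x, axis=0)            # padded lanes stay at -inf, exp() -> 0
    num = tl.exp(x)
    tl.store(out_ptr + row * stride + cols, num / tl.sum(num, axis=0), mask=mask)

# Host side, as in the tutorial:
# BLOCK_SIZE = triton.next_power_of_2(n_cols)
# softmax_row_kernel[(n_rows,)](x, out, x.stride(0), n_cols, BLOCK_SIZE=BLOCK_SIZE)
```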

Hope this answers your question.

ptillet avatar May 13 '22 04:05 ptillet

Thanks @ptillet! I was also wondering whether Triton currently has the ability to slice (or otherwise chunk) tensors after they have been loaded into SRAM? Rotary embeds include an op which splits the tensor in half and applies an elementwise combination of the two parts.

My current workaround is to load each half of the tensor as a separate input and reuse it where appropriate, which seems fine for rotary embeds. But there is a similar story with activation functions like GEGLU, which would be harder to inject into the existing matmul + optional activation Triton kernel, because we don't have the option to split the input trivially when the activation is simply glued onto the end of the existing kernel.
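Roughly what the workaround looks like, under a few assumptions: the input has been flattened to one row per (token, head), cos/sin are expanded to the same row layout and repeat over the two halves (as in the usual rotate-half formulation), and BLOCK is the next power of two >= head_dim // 2.

```python
import triton
import triton.language as tl

@triton.jit
def rotary_fwd_kernel(x_ptr, cos_ptr, sin_ptr, out_ptr, head_dim, BLOCK: tl.constexpr):
    # One program per (token, head) row.
    row = tl.program_id(0)
    half = head_dim // 2
    offs = tl.arange(0, BLOCK)
    mask = offs < half

    # Workaround for the missing slicing op: load each half with its own offsets.
    x1 = tl.load(x_ptr + row * head_dim + offs, mask=mask)
    x2 = tl.load(x_ptr + row * head_dim + half + offs, mask=mask)
    cos = tl.load(cos_ptr + row * head_dim + offs, mask=mask)
    sin = tl.load(sin_ptr + row * head_dim + offs, mask=mask)

    # out = x * cos + rotate_half(x) * sin, written back half by half.
    tl.store(out_ptr + row * head_dim + offs, x1 * cos - x2 * sin, mask=mask)
    tl.store(out_ptr + row * head_dim + half + offs, x2 * cos + x1 * sin, mask=mask)
```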

david-macleod avatar Jun 01 '22 20:06 david-macleod

@david-macleod Unfortunately it doesn't right now :( It's been planned for a while, but other things always ended up taking precedence. There is indeed a similar story with geglu, and it can lead to some trouble if you end up having things like

store(x + 2*off + 0, x_slice_0)
store(x + 2*off + 1, x_slice_1)

I'm hoping that, in this case, the compiler should soon at least be able to merge the two stores into a single, coalesced store operation.

Note that for load it's usually not a big deal as the L1 cache will take care of uncoalesced accesses.

ptillet avatar Jun 01 '22 21:06 ptillet

Thanks @ptillet, just getting round to looking at this again. In the example above, what happens if the length of x_slice_1 is not divisible by the block size and we end up having to use a mask? Would the second store not potentially overwrite some of the output from the first call to store?

david-macleod avatar Aug 22 '22 07:08 david-macleod

Hmm, I don't understand why it would? x + 2*off + 0 and x + 2*off + 1 do not overlap?

ptillet avatar Aug 23 '22 00:08 ptillet

Ah sorry, I misinterpreted your proposal. Ok this makes sense, so we interleave the target memory addresses for each split of the tensor. Other than timing, is there any way to introspect whether the compiler has resolved that into a single coalesced write?
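One thing I might try, assuming recent Triton builds return a compiled-kernel handle from the launch: dump the generated PTX and check whether neighbouring stores were merged into vectorized st.global.v2/st.global.v4 instructions rather than scalar ones (kernel name and arguments below are placeholders).

```python
# `my_kernel`, `grid`, and the arguments are placeholders; the point is only the
# .asm inspection on the handle returned by the launch.
handle = my_kernel[grid](x, out, n, BLOCK_SIZE=1024)
print(handle.asm["ptx"])  # look for st.global.v2 / st.global.v4 vs scalar st.global
```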

david-macleod avatar Aug 23 '22 06:08 david-macleod

Or, thinking more about this:

> A lot of kernels can still work when the block size is a multiple of 32*num_warps.

As an interim solution until slicing is supported: if I could guarantee that the size of the vector I wanted to process was a multiple of the above, the implication would be that BLOCK_SIZE == num_elements in the vector, so I could avoid using a mask. That would mean I could split the tensor into two contiguous parts (not interleaved) and guarantee coalesced loads/stores without having to worry about overwriting. Or is that not correct?
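A tiny sketch of the layout I have in mind, under the assumption that HALF (a power of two) exactly equals the number of elements in each half, so no mask is needed and the two stores are contiguous and non-overlapping:

```python
import triton
import triton.language as tl

@triton.jit
def split_halves_kernel(x_ptr, out_ptr, HALF: tl.constexpr):
    # Assumes HALF (a power of two) == the exact number of elements per half,
    # so no mask is needed and both stores are contiguous and non-overlapping.
    row = tl.program_id(0)
    offs = tl.arange(0, HALF)
    x1 = tl.load(x_ptr + row * 2 * HALF + offs)
    x2 = tl.load(x_ptr + row * 2 * HALF + HALF + offs)
    tl.store(out_ptr + row * 2 * HALF + offs, x1)         # first contiguous half
    tl.store(out_ptr + row * 2 * HALF + HALF + offs, x2)  # second contiguous half
```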

As a side note, does the triton project have a rough/vague roadmap for things like slicing, or AOT compilation for torchscript export into C++?

david-macleod avatar Aug 23 '22 06:08 david-macleod