[FRONTEND][BACKEND] Add `tl.atomic_load` and `tl.atomic_store`
Currently triton only has read-modify-write style atomic operations, but for many communication strategies you don't need the full power of RMW; plain load and store may be enough if they have strong memory consistency guarantees.
Here I've split this out from normal weak load and store both in the frontend and backend because I think they require very different thinking both from the compiler and the programmer. I've also intentionally limited this to scalars at the moment since the memory consistency is much simpler to reason about and I suspect scalars are the main use case anyway.
It seems sensible to me, since C++ `std::atomic` also has specialized load and store implementations.
- could we have a `sem` param to load/store, similar to e.g. `atomic_add`, to enforce ordering?
- couldn't we do atomic loads with `atomic_add(Val, 0)`?
> could we have a `sem` param to load/store, similar to e.g. `atomic_add`, to enforce ordering?
That is one option, but I prefer this form for a number of reasons:
- The `atomic_` name is a clear signal to the reader that this code requires deep scrutiny, whereas `sem="relaxed"` is not especially clear and could easily blend in with other load/store ops with various kwargs.
- It means there is one section in the documentation for all atomic operations.
- Having it be its own operation in the IR is also good because many of the optimizations on `LoadOp` and `StoreOp` aren't valid for atomics, so you would have to add `if (sem != WEAK) { return mlir::failure(); }` absolutely everywhere.
> couldn't we do atomic loads with `atomic_add(Val, 0)`?
Yes, I've been using this as a workaround; `atomic_xchg` also works as a poor man's store. It would be nicer to have first-class support in the language, though:
- It's cryptic in atomic code that is already hard to reason about
- It supports fewer dtypes
- It unnecessarily reads back the old value from global memory, which may affect performance
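The semantics of these workarounds can be sketched in plain Python. The `AtomicCell` class below is a hypothetical stand-in for a single word of global memory, not Triton API; it just shows why `atomic_add(ptr, 0)` acts as a load and `atomic_xchg` acts as a store:

```python
import threading

class AtomicCell:
    """Hypothetical model of one global-memory word with RMW atomics."""

    def __init__(self, value=0):
        self._value = value
        self._lock = threading.Lock()

    def atomic_add(self, delta):
        # Atomically add `delta` and return the old value.
        with self._lock:
            old = self._value
            self._value += delta
            return old

    def atomic_xchg(self, new):
        # Atomically swap in `new` and return the old value.
        with self._lock:
            old, self._value = self._value, new
            return old

cell = AtomicCell(42)
loaded = cell.atomic_add(0)   # "atomic load": add 0, keep the old value
cell.atomic_xchg(7)           # "atomic store": swap, discard the old value
```

Note that both operations still return the old value even when it is unused, which is exactly the wasted read-back the bullet above complains about.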
BTW, I did confirm that with these workarounds the generated SASS uses `ATOMG` instructions instead of the equivalent `LDG`/`STG` instructions, even though the read value isn't used.
Could we have `atomic_load` and `atomic_store` at the Triton language level, lower them to `atomic_add` and `atomic_xchg`, and then pattern-match in codegen to generate something more optimized that wouldn't use `ATOMG`?
I am very wary of growing Triton-IR
For some additional context, I was hoping to extend this to support vectorized loads and stores, so I could write a pair of values in a single atomic operation. For this, atomic read-modify-write operations are only atomic per element of the vector (according to the PTX documentation), whereas loads and stores can handle the entire vector in one atomic operation.
This is necessary, for example, in CUB's single-kernel device scan implementation, where each thread communicates via a "status word" that is twice the size of the data type being scanned. For 32-bit types this is fine and I get great performance by packing the values into a uint64; for 64-bit dtypes, however, it's not possible. This is a real bummer because `torch.cumsum` promotes all integer types to int64, so we wouldn't be able to use this for any indexing calculations, which is the main place where cumsum comes up at the moment. The equivalent code using acquire-release semantics instead of relaxed atomics is also much, much slower, so that's not a viable solution.
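To make the 32-bit packing trick concrete, here is a host-side sketch using Python's standard `struct` module (the function names and field layout are made up for this illustration; the point is that one 64-bit word carries both the partial aggregate and the status flag, so a single relaxed atomic store publishes both at once):

```python
import struct

def pack_status(aggregate: int, flag: int) -> int:
    # Reinterpret a signed 32-bit partial sum plus an unsigned 32-bit
    # flag as one 64-bit word (little-endian layout assumed).
    return struct.unpack("<Q", struct.pack("<iI", aggregate, flag))[0]

def unpack_status(word: int) -> tuple:
    # Split the 64-bit word back into (aggregate, flag).
    return struct.unpack("<iI", struct.pack("<Q", word))

word = pack_status(-5, 1)
aggregate, flag = unpack_status(word)
```

For a 64-bit scanned type the status word would need 128 bits, which is exactly where the per-element atomicity of RMW ops breaks down and a genuinely vectorized atomic store is needed.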
I see. That's an interesting (and good) point. Wouldn't `atomic_load` on a tensor end up leaking abstractions, though? I.e., things would only be atomic at the granularity the compiler decides to vectorize, not the whole Triton tensor?
I thought about it and I think I am supportive of this work. We just need to make sure that `atomic_load` works at the level of a tensor rather than at the granularity of the underlying PTX instructions.
So from the above, I'm taking it that we should support tensors of 1, 2, or 4 elements where all the elements are loaded/stored from a single thread. I'd like to propose an alternative, where you pass a group of 1, 2, or 4 scalar tensors to `tl.atomic_store` and they are stored in one vectorized chunk. Correspondingly, you would get 1, 2, or 4 return values from `tl.atomic_load`.
e.g. something like

```python
block_sum = tl.sum(data)
flag = tl.full(block_sum.shape, 1, tl.int32).to(tl.float32, bitcast=True)
tl.atomic_store(workspace + xid, (block_sum, flag))

sum, flag = tl.atomic_load(workspace + xid - 1, vec_size=2)
flag = flag.to(tl.int32, bitcast=True)
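The `bitcast=True` conversions in that snippet reinterpret the flag's bits rather than converting its value, so an integer flag can ride in a float-typed slot. A host-side equivalent of that reinterpretation, again using the standard `struct` module:

```python
import struct

def bitcast_i32_to_f32(x: int) -> float:
    # Reinterpret the 32 bits of an int32 as a float32 (no value conversion).
    return struct.unpack("<f", struct.pack("<i", x))[0]

def bitcast_f32_to_i32(x: float) -> int:
    # Reinterpret the 32 bits of a float32 as an int32.
    return struct.unpack("<i", struct.pack("<f", x))[0]

# The integer round-trips exactly; a value conversion would not.
restored = bitcast_f32_to_i32(bitcast_i32_to_f32(1))
```

This is why the proposal returns separate scalar tensors: each packed element keeps its own dtype and can be bitcast independently after the single atomic transfer.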
This has the major advantage that we can process each packed element separately, which is required for the use case I outlined above. As a side note, this would also be useful for plain load/store, e.g. for implementing complex numbers: you load/store the (real, imag) pair in one operation but get them as separate tensors at the program level.
A further generalization of the concept would be to support a `tl.struct` type,

```python
pack_ty = tl.struct(tl.float32, tl.int32)
pack = tl.pack(pack_ty, block_sum, flag)
tl.atomic_store(workspace + xid, pack)
```

but that would probably be a much larger change.
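For comparison, a hypothetical `tl.struct(tl.float32, tl.int32)` would describe the same 8-byte layout as this host-side `ctypes` mirror (illustrative only; the field names are invented here):

```python
import ctypes

class Status(ctypes.Structure):
    # Two 4-byte fields packed into one naturally aligned 8-byte word,
    # so the pair can move in a single 64-bit atomic access.
    _fields_ = [
        ("block_sum", ctypes.c_float),
        ("flag", ctypes.c_int32),
    ]

s = Status(block_sum=3.5, flag=1)
```

The hardware requirement is just that the whole struct fits in (and is aligned to) one atomically accessible word, which is the property the simpler multi-scalar proposal above already guarantees.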
I do like the first proposal of explicit vectorization in the argument to atomic_load :)
Hey, I stumbled on this PR as I also have a need for `atomic_load` and `atomic_store` in Triton! Just like @ptillet suggested, we've been using `atomic_cas` to implement loads and `atomic_xchg` to implement stores but, apart from the readability concerns, these will stop working for us because we need to move to uint8 dtypes. (Not sure if this is a limitation of Triton or of CUDA?)
We're resorting to re-implementing `atomic_load` and `atomic_store` using `inline_asm_elementwise`, but it isn't pretty.
I haven't fully followed the discussion about tensor-wide atomics and vectorization, but in any case we only need support for scalars. Is there a chance that this PR will land shortly? Thanks!
I'd be happy to rebase if @ptillet is okay with adding tl.atomic_{load,store} as scalar-only for now.