[backend] Ensure coalescing and saturation for every memop in a slice.
While the current memory coalescing heuristic works quite well, a couple of opportunities may be missing:
- It uses the narrowest data type for the layouts of all memops in a slice, which may break memory coalescing of wider memops. We have actually seen this (https://github.com/pytorch/pytorch/issues/120667), where full vectorization of a narrower memop breaks memory coalescing of a wider memop: the wider memop ends up being fulfilled by two consecutive accesses per thread, and neither of them is consecutive across threads. It seems that in general full memory coalescing should always be ensured (within or across threads), and on top of that, maximizing vectorization for each memop is a plus (see the sketch after this list).
- All memops could end up non-vectorizable if any memop in the slice is non-vectorizable.
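For illustration (a minimal sketch, not the kernel from the linked PyTorch issue), a slice with a narrow load feeding a wider store is enough to hit the first problem: if the layout is sized for full vectorization of an fp16 load (e.g. 8 elements, 16 bytes per thread), an fp32 store of the same elements needs 32 bytes per thread and gets split into two accesses whose addresses are no longer consecutive across threads.

```python
import triton
import triton.language as tl

@triton.jit
def upcast_copy(src_ptr, dst_ptr, n, BLOCK: tl.constexpr):
    # Assumes src is fp16 and dst is fp32: the two memops in the slice
    # have different element widths.
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(src_ptr + offs, mask=mask)                   # 2 bytes per element
    tl.store(dst_ptr + offs, x.to(tl.float32), mask=mask)    # 4 bytes per element
```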
This patch tries to fix these issues by trading off between the performance of individual memops and overall performance. It takes bandwidth saturation and correct alignment for every memop into consideration and tries to compute a unified layout among the memops while favoring vectorization of each memop as much as possible.
An overview of what the new heuristic does:
- Compute a slice of same-shaped memory ops that have a dataflow connection.
- For each memop in the slice, compute:
  i. the maximum number of consecutive elements each thread should own (named maxPerThread) without breaking cross-thread memory coalescing;
  ii. the minimum number of consecutive elements each thread should own (named minPerThread) to saturate the memory bandwidth and avoid alignment issues.
- The unified perThread should be no bigger than the minimal maxPerThread and no smaller than the maximal minPerThread.
The new heuristic could also take the execution frequency of each memop into account to be more accurate, but that is not included in this work.
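A minimal sketch of the selection rule above (the field names and the bounds handling are illustrative assumptions, not the PR's actual helpers):

```python
def pick_unified_per_thread(memops):
    # max_per_thread: most consecutive elements a thread can own for this memop
    #                 without breaking cross-thread coalescing.
    # min_per_thread: fewest consecutive elements a thread must own for this memop
    #                 to saturate bandwidth and stay aligned.
    upper = min(op.max_per_thread for op in memops)   # minimal maxPerThread in the slice
    lower = max(op.min_per_thread for op in memops)   # maximal minPerThread in the slice
    assert lower <= upper, "bounds conflict: a trade-off would be needed here"
    # Favor per-memop vectorization as much as possible within the bounds.
    return upper
```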
cc @shunting314
Looks like a test is failing?
Yes, I'm trying to repro it locally. Looks like an assert failure about the number of add instructions.
BTW, do you know which version of Pytorch is used in the lab?
I don't know offhand, but it should all be (effectively) specified by the pip install and similar commands inside the .github folder. There's no "magic" or pre-installed stuff afaik.
This is changing the current memory coalescing heuristic, which uses the narrowest data type for the layouts of all memops in a slice; the patch changes it to use the widest data type instead.
If so, shouldn't we consider the number of bytes instead of the number of elements?
I think that's what it does. Number of bytes for the data type is used to compute the number of elements that each thread should own consecutively (getNumElementsPerThread).
Oh, I see it. getNumElementsPerThread actually refers to getNumBytesPerThread.
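Roughly, the per-thread element count falls out of a per-thread byte budget; a hypothetical sketch (not the actual getNumElementsPerThread, and the 4-byte target is an assumption based on the saturation discussion later in the thread):

```python
def num_elements_per_thread(elem_bytes: int, target_bytes_per_thread: int = 4) -> int:
    # Hypothetical helper: the byte budget per thread is an assumption; the
    # point is only that the element count is derived from a byte count.
    return max(1, target_bytes_per_thread // elem_bytes)

# fp32 (4 B) -> 1 element, fp16 (2 B) -> 2 elements, int8 (1 B) -> 4 elements
```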
We'll have to try this heuristic on our internal workloads. I wonder if this is going to be a problem for some sub-byte types; I also wonder if some cases where we have a 16-byte load in a loop and a 32-byte store outside of it are going to be suboptimal. Were you able to try out such scenarios?
That's a good point. I haven't specifically checked those cases. The number of memops of each dtype may also matter.
@ThomasRaoux It looks like the test failure is due to an unhandled case in wgmma to LLVM lowering, where the accumulator was zero. The zero accumulator was introduced by a layout optimization triggered by this patch:
dot(a, b, load(ptr)) -> add(load(ptr), dot(a, b, 0))
This optimization seems good. I tried to fix the compiler error but then I hit the assert in the test which I need some help to understand:
https://github.com/openai/triton/blob/38cc733efd1262dc6c81a1862247c09e9d982350/python/test/unit/language/test_core.py#L3090
BTW, have you got a chance to try this heuristic with your workloads?
@ThomasRaoux It looks like the test failure is due to an unhandled case in wgmma to LLVM lowering, where the accumulator was zero. The zero accumulator was introduced by a layout optimization triggered by this patch:
dot(a, b, load(ptr)) -> add(load(ptr), dot(a, b, 0))
I'm surprised we don't support zero accumulator, I'm pretty sure we have cases like that already. Where is the problem?
This optimization seems good. I tried to fix the compiler error but then I hit the assert in the test which I need some help to understand:
https://github.com/openai/triton/blob/38cc733efd1262dc6c81a1862247c09e9d982350/python/test/unit/language/test_core.py#L3090
This is checking that we do the accumulation outside of the tensorcores. Not sure what would cause this.
BTW, have you got a chance to try this heuristic with your workloads?
Not yet, I can try it soon, however it feels like the heuristic would be suboptimal in some simple cases?
I'll try to run this tomorrow
@ThomasRaoux It looks like the test failure is due to an unhandled case in wgmma to LLVM lowering, where the accumulator was zero. The zero accumulator was introduced by a layout optimization triggered by this patch:
dot(a, b, load(ptr)) -> add(load(ptr), dot(a, b, 0))
I'm surprised we don't support zero accumulator, I'm pretty sure we have cases like that already. Where is the problem?
Basically I hit an AV when d is null because zeroAcc==true:
https://github.com/openai/triton/blob/38cc733efd1262dc6c81a1862247c09e9d982350/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp#L451
This optimization seems good. I tried to fix the compiler error but then I hit the assert in the test which I need some help to understand: https://github.com/openai/triton/blob/38cc733efd1262dc6c81a1862247c09e9d982350/python/test/unit/language/test_core.py#L3090
This is checking that we do the accumulation outside of the tensorcores. Not sure what would cause this.
With my workaround, I was seeing h.asm["ptx"].count("add.f32") == 128 while the other value was 256.
I'll try to run this tomorrow
Thanks!
Basically I hit an AV when d is null because zeroAcc==true: https://github.com/openai/triton/blob/38cc733efd1262dc6c81a1862247c09e9d982350/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp#L451
This should be fixed in the latest version.
Sorry for the delay. I did run this on our internal workloads: none of the kernels improved, and multiple kernels had large regressions. Since we cannot share our kernels, I'll try to isolate a few cases and describe the high-level reasons why it makes things worse.
One example should be easy to reproduce: for types narrower than 4 bytes (for example f16), if there are multiple loads/stores and one of them cannot be vectorized, then none of them are vectorized, and reading less than 4 bytes per thread would tank the perf. You should be able to reproduce that with a simple kernel. I assume the heuristic will have to consider these narrower types.
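A minimal sketch of such a kernel (hypothetical, not one of the internal workloads): several contiguous fp16 memops plus one gather load that cannot be vectorized. If the pass unifies every layout to match the gather, every access drops to 2 bytes per thread.

```python
import triton
import triton.language as tl

@triton.jit
def gather_add(dst_ptr, a_ptr, b_ptr, idx_ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    a = tl.load(a_ptr + offs, mask=mask)        # contiguous fp16 load, vectorizable
    idx = tl.load(idx_ptr + offs, mask=mask)    # indices
    b = tl.load(b_ptr + idx, mask=mask)         # gather: not vectorizable
    tl.store(dst_ptr + offs, a + b, mask=mask)  # contiguous fp16 store
```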
Thanks a lot for measuring perf for this change. At this point it appears that the heuristic needs more tweaks.
One example should be easy to reproduce: for types narrower than 4 bytes (for example f16), if there are multiple loads/stores and one of them cannot be vectorized, then none of them are vectorized, and reading less than 4 bytes per thread would tank the perf.
What is the rationale behind this? Is it because a less-than-4-byte load per thread would waste part of a 128-byte memory transaction per warp?
Yes, that's one way to see it. In general, threads accessing unaligned addresses (not aligned on 4B) or accessing less than 4B of data per thread won't be able to saturate the bandwidth.
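Back-of-the-envelope numbers for that, using the 128-byte-transaction-per-warp model from the question above (an illustrative sketch, not a hardware model):

```python
WARP_SIZE = 32       # threads per warp
TRANSACTION = 128    # bytes per memory transaction (assumed model)

for bytes_per_thread in (1, 2, 4, 8, 16):
    useful = WARP_SIZE * bytes_per_thread
    # Transactions needed for a fully coalesced warp access, and how much of
    # each transaction actually carries useful data.
    transactions = max(1, useful // TRANSACTION)
    utilization = useful / (transactions * TRANSACTION)
    print(f"{bytes_per_thread:2d} B/thread -> {utilization:.0%} of each 128 B transaction used")
```

Below 4 bytes per thread, part of every 128-byte transaction goes unused under this model.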
Just sent out a new iteration. I'm experimenting with a new heuristic based on equivalence classes. It makes sure of two things:
- Ensure cross-thread coalescing for every memop.
- Ensure bandwidth saturation and correct alignment for every memop.
I'd like to get your thoughts on it. The code isn't polished much though.
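Purely as a guess at what "equivalence classes" could mean here (not the PR's actual data structure), one could group dataflow-connected, same-shaped memops with a small union-find and then pick a single layout per class:

```python
class DisjointSet:
    """Tiny union-find; illustrative only."""
    def __init__(self):
        self.parent = {}

    def find(self, x):
        self.parent.setdefault(x, x)
        while self.parent[x] != x:
            self.parent[x] = self.parent[self.parent[x]]  # path halving
            x = self.parent[x]
        return x

    def union(self, a, b):
        self.parent[self.find(a)] = self.find(b)

# Memops connected by dataflow (and with the same shape) end up in one class;
# each class then gets one layout chosen to keep every member cross-thread
# coalesced and above the saturation/alignment bound.
```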
I scanned quickly through the code; could you add a high-level comment explaining the heuristic, as I'm not sure I fully understand it from the code.
Ensure cross-thread coalescing for every memop.
What does that mean when there is a chain of memops? Do we just want to pay for a convert every time?
I scanned quickly through the code; could you add a high-level comment explaining the heuristic, as I'm not sure I fully understand it from the code.
Thanks for taking a look, and sorry about missing a high-level description. I'll add it.
What does that mean when there is a chain of memops? Do we just want to pay for a convert every time?
We try to compute a unified perThread without losing cross-thread coalescing for any memop. The vectorization of each memop may not be guaranteed though.
Summary updated to reflect the new heuristic.
closing this PR due to no activity. Feel free to reopen.
@htyu @ThomasRaoux is the optimization still on the menu? I am just learning how to auto-coalesce global accesses to SMEM (to make sure data is loaded/stored contiguously).
Is there an overview of the Coalesce transform cpp? Could we illustrate the algorithm with the classic transposeOp example?
@htyu @ThomasRaoux is the optimization still on the menu? I am just learning how to auto-coalesce global accesses to SMEM (to make sure data is loaded/stored contiguously).
The specific optimization implemented by this PR doesn't seem to work well across the board, i.e., perf regressions were seen in some cases, so it's on hold for now.
Is there an overview of the Coalesce transform cpp? Could we illustrate the algorithm with the classic transposeOp example?
I'm not sure about an overview anywhere, but if you have a specific case you'd like to study, feel free to share. The transposeOp is a good example to start with.
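For anyone who wants to study the pass, a minimal transpose kernel (illustrative sketch) is a handy input: the load is contiguous along one dimension and the store along the other, so the Coalesce pass has to decide which memop's preferred layout wins (or where to pay for a layout convert).

```python
import triton
import triton.language as tl

@triton.jit
def transpose_kernel(src_ptr, dst_ptr, M, N,
                     BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # Row-major load of a BLOCK_M x BLOCK_N tile: contiguous along N.
    tile = tl.load(src_ptr + rm[:, None] * N + rn[None, :],
                   mask=(rm[:, None] < M) & (rn[None, :] < N))
    # Store the transposed tile: contiguous along M in the destination.
    tl.store(dst_ptr + rn[:, None] * M + rm[None, :], tl.trans(tile),
             mask=(rn[:, None] < N) & (rm[None, :] < M))
```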