triton generates unnecessary shared memory stores/loads
For the following Triton kernels generated by PyTorch (Inductor), Triton emits shared memory stores and loads in the LLVM IR and PTX just before the atomic add operations.
from ctypes import c_void_p, c_long
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.codecache import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
alloc_from_pool = torch.ops.inductor._alloc_from_pool
reinterpret_tensor = torch.ops.inductor._reinterpret_tensor
async_compile = AsyncCompile()
# kernel path: /tmp/torchinductor_isuruf/5e/c5ehw64oxeoeqqjnqn6v3gfy6z5ukksktwihp7jgzg6sujz5umto.py
# Source Nodes: [], Original ATen: []
triton_poi_fused_0 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor import triton_helpers, triton_heuristics
from torch._inductor.ir import ReductionHint, TileHint
from torch._inductor.triton_helpers import libdevice, math as tl_math
from torch._inductor.triton_heuristics import AutotuneHint
from torch._inductor.utils import instance_descriptor
@triton_heuristics.pointwise(
size_hints=[16777216],
filename=__file__,
triton_meta={'signature': {0: '*fp32', 1: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_0', 'mutated_arg_names': [], 'no_x_dim': False, 'backend_hash': '173ccbefad6764ffc6a32cfd80b0e0decca95dcaaab807475db0bd6fd7f94813'},
min_elem_per_thread=0
)
@triton.jit
def triton_(out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 8750000
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = 0.0
    tl.store(out_ptr0 + (x0), tmp0, xmask)
''', device_str='cuda')
import triton
import triton.language as tl
from torch._inductor.triton_heuristics import grid, split_scan_grid, start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
# kernel path: /tmp/torchinductor_isuruf/qg/cqgmsmgdzivumf2gmksclwbmyrwpfpouuv3s5suqkeg4j4cdmpjr.py
# Source Nodes: [], Original ATen: []
triton_poi_fused_1 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor import triton_helpers, triton_heuristics
from torch._inductor.ir import ReductionHint, TileHint
from torch._inductor.triton_helpers import libdevice, math as tl_math
from torch._inductor.triton_heuristics import AutotuneHint
from torch._inductor.utils import instance_descriptor
@triton_heuristics.pointwise(
size_hints=[67108864],
filename=__file__,
triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_1', 'mutated_arg_names': ['out_ptr0'], 'no_x_dim': False, 'backend_hash': '173ccbefad6764ffc6a32cfd80b0e0decca95dcaaab807475db0bd6fd7f94813'},
min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 35000000
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x1 = (xindex // 1000) % 1000
    x0 = xindex % 1000
    x3 = xindex
    x2 = (xindex // 1000000)
    tmp22 = tl.load(in_ptr0 + (x3), xmask)
    tmp0 = x1
    tmp1 = tmp0.to(tl.float32)
    tmp2 = 0.5
    tmp3 = tmp1 + tmp2
    tmp4 = tmp3 * tmp2
    tmp5 = tmp4 - tmp2
    tmp6 = tmp5.to(tl.int32)
    tmp7 = x0
    tmp8 = tmp7.to(tl.float32)
    tmp9 = tmp8 + tmp2
    tmp10 = tmp9 * tmp2
    tmp11 = tmp10 - tmp2
    tmp12 = tmp11.to(tl.int32)
    tmp13 = tmp6.to(tl.float32)
    tmp14 = tmp5 - tmp13
    tmp15 = 1.0
    tmp16 = tmp15 - tmp14
    tmp17 = tmp15 * tmp16
    tmp18 = tmp12.to(tl.float32)
    tmp19 = tmp11 - tmp18
    tmp20 = tmp15 - tmp19
    tmp21 = tmp17 * tmp20
    tmp23 = tmp21 * tmp22
    tmp24 = tl.full([1], 1, tl.int32)
    tmp25 = tmp12 + tmp24
    tmp26 = tl.full([1], 499, tl.int32)
    tmp27 = triton_helpers.minimum(tmp25, tmp26)
    tmp28 = tmp17 * tmp19
    tmp29 = tmp28 * tmp22
    tmp30 = tmp6 + tmp24
    tmp31 = triton_helpers.minimum(tmp30, tmp26)
    tmp32 = tmp15 * tmp14
    tmp33 = tmp32 * tmp20
    tmp34 = tmp33 * tmp22
    tmp35 = tmp32 * tmp19
    tmp36 = tmp35 * tmp22
    tl.atomic_add(out_ptr0 + (tmp12 + (500*tmp6) + (250000*x2)), tmp23, xmask)
    tl.atomic_add(out_ptr0 + (tmp27 + (500*tmp6) + (250000*x2)), tmp29, xmask)
    tl.atomic_add(out_ptr0 + (tmp12 + (500*tmp31) + (250000*x2)), tmp34, xmask)
    tl.atomic_add(out_ptr0 + (tmp27 + (500*tmp31) + (250000*x2)), tmp36, xmask)
''', device_str='cuda')
async_compile.wait(globals())
del async_compile
def call(args):
    args_1, = args
    args.clear()
    assert_size_stride(args_1, (7, 5, 1000, 1000), (5000000, 1000000, 1000, 1))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        buf0 = empty_strided_cuda((7, 5, 500, 500), (1250000, 250000, 500, 1), torch.float32)
        # Source Nodes: [], Original ATen: []
        stream0 = get_raw_stream(0)
        triton_poi_fused_0.run(buf0, 8750000, grid=grid(8750000), stream=stream0)
        # Source Nodes: [], Original ATen: []
        triton_poi_fused_1.run(args_1, buf0, 35000000, grid=grid(35000000), stream=stream0)
        del args_1
    return (buf0, )
def benchmark_compiled_module(times=10, repeat=10):
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance
    args_1 = rand_strided((7, 5, 1000, 1000), (5000000, 1000000, 1000, 1), device='cuda:0', dtype=torch.float32)
    fn = lambda: call([args_1])
    return print_performance(fn, times=times, repeat=repeat)
if __name__ == "__main__":
    from torch._inductor.wrapper_benchmark import compiled_module_main
    compiled_module_main('None', benchmark_compiled_module)
Shared memory loads/stores are unnecessary in this case. cc @peterbell10
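For reference, here is how the PTX can be checked for this shared-memory traffic on a small standalone kernel. This is only a sketch, not the inductor kernel above: the masked tl.atomic_add loosely mirrors the pattern, and it assumes the launch handle exposes the generated code through an .asm dict, as the output of triton.compile does.

import re
import torch
import triton
import triton.language as tl


@triton.jit
def scatter_add_kernel(in_ptr, out_ptr, xnumel, XBLOCK: tl.constexpr):
    # Masked load followed by a masked atomic add, loosely mirroring the
    # shape of triton_poi_fused_1 above (many lanes can hit the same slot).
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)
    xmask = xindex < xnumel
    val = tl.load(in_ptr + xindex, xmask)
    tl.atomic_add(out_ptr + (xindex % 500), val, xmask)


xnumel = 1 << 20
src = torch.randn(xnumel, device="cuda")
dst = torch.zeros(500, device="cuda")
handle = scatter_add_kernel[(triton.cdiv(xnumel, 1024),)](src, dst, xnumel, XBLOCK=1024)

ptx = handle.asm["ptx"]  # assumption: the launch returns a compiled-kernel handle with an .asm dict
print("shared mem insts:", len(re.findall(r"\b(?:ld|st)\.shared", ptx)))
print("global atomics:  ", len(re.findall(r"\batom\.global", ptx)))

Whether the shared-memory round trip actually shows up depends on the layouts Triton picks for the load and the atomic, as it does in the inductor kernel above.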
Based on a suggestion from @peterbell10 I removed AtomicRMWOp at https://github.com/openai/triton/blob/0ba87e2ff35f703f84040400554702ee55476cdb/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp#L192, after which the PTX no longer contained any shared memory loads/stores. With that change the Triton-generated kernel matches the performance of the PyTorch eager backend, whereas it was previously 50% slower with the shared memory stores and loads (see the timing sketch below).
Is there a case where removing AtomicRMWOp as a layout anchor can result in incorrect code?
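The eager-mode source that produced these kernels isn't shown in the issue, so the comparison above can only be sketched with a hypothetical stand-in: the index_put_ scatter below is just a placeholder for whatever op generated the atomic adds, and torch.utils.benchmark.Timer is used for the timing since it handles CUDA synchronization.

import torch
from torch.utils.benchmark import Timer


def scatter_add(src, idx, out_size):
    # Hypothetical stand-in: an accumulation that inductor typically lowers
    # to tl.atomic_add; not the op behind the kernels in this issue.
    out = torch.zeros(out_size, device=src.device, dtype=src.dtype)
    out.index_put_((idx,), src, accumulate=True)
    return out


src = torch.randn(35_000_000, device="cuda")
idx = torch.randint(0, 8_750_000, (35_000_000,), device="cuda")
compiled = torch.compile(scatter_add)
compiled(src, idx, 8_750_000)  # warm-up / trigger compilation

for label, fn in [("eager", scatter_add), ("compiled", compiled)]:
    t = Timer(stmt="fn(src, idx, n)",
              globals={"fn": fn, "src": src, "idx": idx, "n": 8_750_000})
    print(label, t.timeit(50))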
I don't think it will result in incorrect code, but I may be wrong. It can affect performance, so it will likely need to go through the benchmark suites to verify the performance impact. Which version of pytorch are you on? I tried to run your code, but failed. AttributeError: type object 'torch._C.Generator' has no attribute 'graphsafe_set_state'
I'm using pytorch v2.3.0-rc6
Which version of pytorch are you on? I tried to run your code, but failed. AttributeError: type object 'torch._C.Generator' has no attribute 'graphsafe_set_state'
Given that graphsafe_set_state doesn't appear in the generated code, you probably just need to rebuild pytorch.
You are right. I thought I had rebuilt it after pulling the source.
I looked at this, but I'm not sure what the best solution is :] In the meantime, I noticed a few things that I will try to figure out:
1. It is not clear to me why the atomic op gets a different layout, sizePerThread = [1] (the load op gets sizePerThread = [4]).
2. Why is the atomic op an anchor for remove-layout-conversions?
3. With both sizePerThread = [1] and sizePerThread = [4], the atomic op uses the same instruction, atom.global.gpu.acq_rel.add.f32, eight times at the PTX level. In the first case there are two different predicates, while in the latter there is only one, so sizePerThread = [4] looks more efficient.
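One way to look at the layouts in question is to dump the TTGIR of a small load-then-atomic-add kernel and grep for the #blocked attributes. A sketch only, similar to the PTX check earlier in the thread, and again assuming the launch handle exposes the same .asm dict as triton.compile output:

import torch
import triton
import triton.language as tl


@triton.jit
def load_then_atomic(in_ptr, out_ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    val = tl.load(in_ptr + offs, mask)                # coalesced load
    tl.atomic_add(out_ptr + (offs % 512), val, mask)  # atomic add on computed indices


src = torch.randn(1 << 20, device="cuda")
dst = torch.zeros(512, device="cuda")
handle = load_then_atomic[(1 << 10,)](src, dst, src.numel(), BLOCK=1024)

# The blocked layouts (and any convert_layout ops inserted between the load
# and the atomic) show up directly in the TTGIR, e.g. lines containing
# sizePerThread = [1] or sizePerThread = [4].
ttgir = handle.asm["ttgir"]  # assumption: same .asm dict as triton.compile output
for line in ttgir.splitlines():
    if "sizePerThread" in line or "convert_layout" in line or "atomic" in line:
        print(line.strip())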
cc @ThomasRaoux @Jokeren for visibility.