Repository: triton

Triton generates unnecessary shared-memory stores/loads

Status: Open — isuruf opened this issue 1 year ago · 8 comments

For the following Triton kernels generated by PyTorch (Inductor), Triton emits shared-memory stores and loads in the LLVM IR and PTX immediately before the atomic add operations.

from ctypes import c_void_p, c_long
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align

from torch import device, empty_strided
from torch._inductor.codecache import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall

# Short module-level aliases emitted by the Inductor wrapper code generator.
# Only assert_size_stride, empty_strided_cuda and async_compile are used by the
# code visible in this file; the others are part of the standard generated
# preamble.
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
# Used by call() to check the input tensor's size and stride
# (presumably raises on mismatch — TODO confirm against torch._C._dynamo.guards).
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
# Used by call() to allocate the output buffer on the GPU.
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
alloc_from_pool = torch.ops.inductor._alloc_from_pool
reinterpret_tensor = torch.ops.inductor._reinterpret_tensor
# Compiles the kernel source strings passed to async_compile.triton(...);
# async_compile.wait(globals()) later blocks until they are ready.
async_compile = AsyncCompile()


# kernel path: /tmp/torchinductor_isuruf/5e/c5ehw64oxeoeqqjnqn6v3gfy6z5ukksktwihp7jgzg6sujz5umto.py
# Source Nodes: [], Original ATen: []

# Kernel 0: writes 0.0 into each of the 8,750,000 fp32 elements of its output
# pointer (7 * 5 * 500 * 500, the output buffer allocated in call()), so that
# kernel 1 can accumulate into it with atomic adds.
# The triple-quoted string is the Inductor-generated kernel source handed to
# async_compile; it is runtime data, so edit it only by regenerating it.
triton_poi_fused_0 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor import triton_helpers, triton_heuristics
from torch._inductor.ir import ReductionHint, TileHint
from torch._inductor.triton_helpers import libdevice, math as tl_math
from torch._inductor.triton_heuristics import AutotuneHint
from torch._inductor.utils import instance_descriptor

@triton_heuristics.pointwise(
    size_hints=[16777216], 
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_0', 'mutated_arg_names': [], 'no_x_dim': False, 'backend_hash': '173ccbefad6764ffc6a32cfd80b0e0decca95dcaaab807475db0bd6fd7f94813'},
    min_elem_per_thread=0
)
@triton.jit
def triton_(out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 8750000
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = 0.0
    tl.store(out_ptr0 + (x0), tmp0, xmask)
''', device_str='cuda')

import triton
import triton.language as tl
from torch._inductor.triton_heuristics import grid, split_scan_grid, start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_raw_stream


# kernel path: /tmp/torchinductor_isuruf/qg/cqgmsmgdzivumf2gmksclwbmyrwpfpouuv3s5suqkeg4j4cdmpjr.py
# Source Nodes: [], Original ATen: []

# Kernel 1: one program element per input value (35,000,000 = 7 * 5 * 1000 * 1000).
# Each element is split into a row x1, a column x0 and a plane x2, and
# (x + 0.5) * 0.5 - 0.5 maps its row and column onto the 500x500 output grid.
# It then scatter-adds weight * input into four neighbouring output cells
# (indices clamped to 499) with tl.atomic_add. The weights are products of
# fractional offsets and their complements.
# NOTE(review): this looks like the backward of a 2x bilinear upsample with
# align_corners=False — unconfirmed, the ATen source nodes are not recorded.
# Several input elements map to the same output cell, which is why the adds
# are atomic. This is the kernel for which the issue reports shared-memory
# stores/loads emitted right before the atomic_add calls.
# The triple-quoted string is runtime data (the kernel source); do not hand-edit.
triton_poi_fused_1 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor import triton_helpers, triton_heuristics
from torch._inductor.ir import ReductionHint, TileHint
from torch._inductor.triton_helpers import libdevice, math as tl_math
from torch._inductor.triton_heuristics import AutotuneHint
from torch._inductor.utils import instance_descriptor

@triton_heuristics.pointwise(
    size_hints=[67108864], 
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_1', 'mutated_arg_names': ['out_ptr0'], 'no_x_dim': False, 'backend_hash': '173ccbefad6764ffc6a32cfd80b0e0decca95dcaaab807475db0bd6fd7f94813'},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 35000000
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x1 = (xindex // 1000) % 1000
    x0 = xindex % 1000
    x3 = xindex
    x2 = (xindex // 1000000)
    tmp22 = tl.load(in_ptr0 + (x3), xmask)
    tmp0 = x1
    tmp1 = tmp0.to(tl.float32)
    tmp2 = 0.5
    tmp3 = tmp1 + tmp2
    tmp4 = tmp3 * tmp2
    tmp5 = tmp4 - tmp2
    tmp6 = tmp5.to(tl.int32)
    tmp7 = x0
    tmp8 = tmp7.to(tl.float32)
    tmp9 = tmp8 + tmp2
    tmp10 = tmp9 * tmp2
    tmp11 = tmp10 - tmp2
    tmp12 = tmp11.to(tl.int32)
    tmp13 = tmp6.to(tl.float32)
    tmp14 = tmp5 - tmp13
    tmp15 = 1.0
    tmp16 = tmp15 - tmp14
    tmp17 = tmp15 * tmp16
    tmp18 = tmp12.to(tl.float32)
    tmp19 = tmp11 - tmp18
    tmp20 = tmp15 - tmp19
    tmp21 = tmp17 * tmp20
    tmp23 = tmp21 * tmp22
    tmp24 = tl.full([1], 1, tl.int32)
    tmp25 = tmp12 + tmp24
    tmp26 = tl.full([1], 499, tl.int32)
    tmp27 = triton_helpers.minimum(tmp25, tmp26)
    tmp28 = tmp17 * tmp19
    tmp29 = tmp28 * tmp22
    tmp30 = tmp6 + tmp24
    tmp31 = triton_helpers.minimum(tmp30, tmp26)
    tmp32 = tmp15 * tmp14
    tmp33 = tmp32 * tmp20
    tmp34 = tmp33 * tmp22
    tmp35 = tmp32 * tmp19
    tmp36 = tmp35 * tmp22
    tl.atomic_add(out_ptr0 + (tmp12 + (500*tmp6) + (250000*x2)), tmp23, xmask)
    tl.atomic_add(out_ptr0 + (tmp27 + (500*tmp6) + (250000*x2)), tmp29, xmask)
    tl.atomic_add(out_ptr0 + (tmp12 + (500*tmp31) + (250000*x2)), tmp34, xmask)
    tl.atomic_add(out_ptr0 + (tmp27 + (500*tmp31) + (250000*x2)), tmp36, xmask)
''', device_str='cuda')


# Block until the kernels submitted via async_compile.triton(...) are ready.
# Passing globals() presumably lets it rebind the module-level kernel names to
# the compiled kernels — TODO confirm. The compiler handle is then dropped
# because it is no longer needed.
async_compile.wait(globals())
del async_compile

def call(args):
    """Execute the compiled graph on a single input tensor.

    Kernel 0 zero-fills a fresh (7, 5, 500, 500) fp32 CUDA buffer. Kernel 1
    then accumulates weighted input values into that buffer with atomic adds.

    Args:
        args: a list holding exactly one tensor of size (7, 5, 1000, 1000) and
            stride (5000000, 1000000, 1000, 1). The list is emptied in place,
            so the caller's list no longer references the input.

    Returns:
        A 1-tuple containing the output buffer.
    """
    (inp,) = args
    args.clear()
    assert_size_stride(inp, (7, 5, 1000, 1000), (5000000, 1000000, 1000, 1))
    out_numel = 8750000   # 7 * 5 * 500 * 500 output elements
    in_numel = 35000000   # 7 * 5 * 1000 * 1000 input elements
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        out = empty_strided_cuda((7, 5, 500, 500), (1250000, 250000, 500, 1), torch.float32)
        stream = get_raw_stream(0)
        # Kernel 0 writes 0.0 everywhere so the atomic adds below start from zero.
        triton_poi_fused_0.run(out, out_numel, grid=grid(out_numel), stream=stream)
        # Kernel 1 scatter-adds the weighted input into ``out``.
        triton_poi_fused_1.run(inp, out, in_numel, grid=grid(in_numel), stream=stream)
        # Drop the local reference to the input once the last kernel reading it is launched.
        del inp
    return (out,)


def benchmark_compiled_module(times=10, repeat=10):
    """Benchmark ``call`` on a random input of the shape it expects.

    Args:
        times: forwarded to ``print_performance`` (presumably the number of
            calls per timing run — TODO confirm).
        repeat: forwarded to ``print_performance`` (presumably the number of
            timing runs — TODO confirm).

    Returns:
        Whatever ``print_performance`` returns.
    """
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance
    args_1 = rand_strided((7, 5, 1000, 1000), (5000000, 1000000, 1000, 1), device='cuda:0', dtype=torch.float32)

    # A named closure instead of a lambda bound to a name (PEP 8 / E731).
    # A fresh list is built on every invocation because ``call`` clears the
    # list it receives.
    def fn():
        return call([args_1])

    return print_performance(fn, times=times, repeat=repeat)


if __name__ == "__main__":
    # Script entry point: hand benchmark_compiled_module to Inductor's
    # wrapper-benchmark driver. 'None' is the benchmark name the code generator
    # emitted here (presumably no model name was available — TODO confirm).
    from torch._inductor.wrapper_benchmark import compiled_module_main
    compiled_module_main('None', benchmark_compiled_module)

Shared memory loads/stores are unnecessary in this case. cc @peterbell10

— isuruf, Mar 28 '24 15:03

Based on a suggestion from @peterbell10, I removed AtomicRMWOp from the layout anchors at https://github.com/openai/triton/blob/0ba87e2ff35f703f84040400554702ee55476cdb/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp#L192, and the resulting PTX no longer contains any shared-memory loads/stores. With that change, the Triton-generated kernel matches the performance of PyTorch's eager backend. Previously, with the shared-memory stores and loads, it was 50% slower.

— isuruf, Mar 28 '24 19:03

Is there a case where removing AtomicRMWOp as a layout anchor can result in incorrect code?

— isuruf, Mar 28 '24 19:03

I don't think it will result in incorrect code, but I may be wrong. It can affect performance, so we will likely need to run the benchmark suites to verify the performance impact. Which version of PyTorch are you on? I tried to run your code, but it failed with: AttributeError: type object 'torch._C.Generator' has no attribute 'graphsafe_set_state'

— manman-ren, Mar 29 '24 19:03

I'm using pytorch v2.3.0-rc6

— isuruf, Mar 29 '24 21:03

> Which version of pytorch are you on? I tried to run your code, but failed. AttributeError: type object 'torch._C.Generator' has no attribute 'graphsafe_set_state'

Given that graphsafe_set_state doesn't appear in the generated code, you probably just need to rebuild pytorch.

— peterbell10, Mar 29 '24 21:03

You're right. I thought I had rebuilt it after pulling the source.

— manman-ren, Mar 29 '24 23:03

I looked at this, but I'm not sure what the best solution is :] Instead, I noticed a few things that I will try to understand:
1) It is not clear to me why the atomic op has a different layout (sizePerThread = [1]) from the load op (sizePerThread = [4]).
2) Why the atomic op is an anchor for remove-layout.
3) With both sizePerThread = [1] and sizePerThread = [4], at the PTX level the atomic op uses the same instruction, atom.global.gpu.acq_rel.add.f32, 8 times. In the first case there are two different predicates, whereas in the latter there is only one. So it looks like sizePerThread = [4] is more efficient?

— manman-ren, Apr 01 '24 18:04

cc @ThomasRaoux @Jokeren for visibility.

— lezcano, Apr 03 '24 18:04