Assertion error from linear layouts
I am running into an assertion error in the codegen for local_load which is coming from the linear layouts code. Here is a minified reproducer
// Register (blocked) layout of the destination tensor: 4 contiguous elements
// per thread along dim 1, 32 threads/warp along dim 1, 4 warps along dim 0.
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
// Swizzled shared-memory layout with maxPhase = 8 and a leading offset.
// NOTE(review): the tensor below has only 4 rows — fewer rows than swizzle
// phases — which is presumably what trips the reshapeIns assertion; confirm.
#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.shared = 2056 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} {
// Minified reproducer: allocate a 4x128 f32 buffer in swizzled shared memory
// and read it back into registers with the #blocked layout via local_load.
tt.func public @test_fn() attributes {noinline = false} {
%0 = triton_gpu.local_alloc { allocation.offset = 0 : i32} : () -> !tt.memdesc<4x128xf32, #shared, #triton_gpu.shared_memory, mutable>
%1 = triton_gpu.local_load %0 : !tt.memdesc<4x128xf32, #shared, #triton_gpu.shared_memory, mutable> -> tensor<4x128xf32, #blocked>
tt.return
}
}
When lowering to LLVM IR it fails with the following error:
$ triton-opt --convert-triton-gpu-to-llvm repro.ttgir
triton-opt: /root/code/triton/lib/Tools/LinearLayout.cpp:512: mlir::triton::LinearLayout mlir::triton::LinearLayout::reshapeIns(llvm::ArrayRef<std::pair<mlir::StringAttr, int> >) const: Assertion `getTotalInDimSize() == std::accumulate(newInDims.begin(), newInDims.end(), 1, [&](int32_t acc, auto &inDim) { return acc * inDim.second; })' failed.
PLEASE submit a bug report to https://github.com/llvm/llvm-project/issues/ and include the crash backtrace.
Stack dump:
0. Program arguments: /root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt --convert-triton-gpu-to-llvm repro.ttgir
#0 0x00005621e7032ff7 llvm::sys::PrintStackTrace(llvm::raw_ostream&, int) (/root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x2c94ff7)
#1 0x00005621e7030b1e llvm::sys::RunSignalHandlers() (/root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x2c92b1e)
#2 0x00005621e70336af SignalHandler(int) (/root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x2c956af)
#3 0x00007f20a484c420 __restore_rt (/usr/lib/x86_64-linux-gnu/libpthread.so.0+0x14420)
#4 0x00007f20a431900b raise /build/glibc-LcI20x/glibc-2.31/signal/../sysdeps/unix/sysv/linux/raise.c:51:1
#5 0x00007f20a42f8859 abort /build/glibc-LcI20x/glibc-2.31/stdlib/abort.c:81:7
#6 0x00007f20a42f8729 get_sysdep_segment_value /build/glibc-LcI20x/glibc-2.31/intl/loadmsgcat.c:509:8
#7 0x00007f20a42f8729 _nl_load_domain /build/glibc-LcI20x/glibc-2.31/intl/loadmsgcat.c:970:34
#8 0x00007f20a4309fd6 (/usr/lib/x86_64-linux-gnu/libc.so.6+0x33fd6)
#9 0x00005621e4aac52a mlir::triton::LinearLayout::reshapeIns(llvm::ArrayRef<std::pair<mlir::StringAttr, int> >) const /root/code/triton/lib/Tools/LinearLayout.cpp:520:37
#10 0x00005621e46b26dd mlir::emitTransferBetweenRegistersAndShared(mlir::RankedTensorType, mlir::triton::MemDescType, mlir::Type, std::optional<int>, mlir::Value, llvm::ArrayRef<mlir::Value>, mlir::Location, mlir::RewriterBase&, mlir::triton::TargetInfoBase const&, std::function<void (mlir::VectorType, mlir::Value)>) /root/code/triton/lib/Conversion/TritonGPUToLLVM/Utility.cpp:307:61
#11 0x00005621e46b31f3 mlir::loadSharedToDistributed(mlir::RankedTensorType, mlir::triton::MemDescType, mlir::Type, mlir::LLVM::SharedMemoryObject, mlir::Location, mlir::RewriterBase&, mlir::triton::TargetInfoBase const&) /root/code/triton/lib/Conversion/TritonGPUToLLVM/Utility.cpp:386:55
#12 0x00005621e47a4185 (anonymous namespace)::LocalLoadOpConversion::lowerSharedToDistributed(mlir::triton::gpu::LocalLoadOp, mlir::triton::gpu::LocalLoadOpAdaptor, mlir::LLVMTypeConverter const*, mlir::ConversionPatternRewriter&) const /root/code/triton/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp:172:69
#13 0x00005621e47a3d05 (anonymous namespace)::LocalLoadOpConversion::matchAndRewrite(mlir::triton::gpu::LocalLoadOp, mlir::triton::gpu::LocalLoadOpAdaptor, mlir::ConversionPatternRewriter&) const /root/code/triton/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp:124:47
#14 0x00005621e47ac85d mlir::ConvertOpToLLVMPattern<mlir::triton::gpu::LocalLoadOp>::matchAndRewrite(mlir::Operation*, llvm::ArrayRef<mlir::Value>, mlir::ConversionPatternRewriter&) const /root/.triton/llvm/llvm-c08c6a71-ubuntu-x64/include/mlir/Conversion/LLVMCommon/Pattern.h:166:77
#15 0x00005621e6b3bd10 mlir::ConversionPattern::matchAndRewrite(mlir::Operation*, mlir::PatternRewriter&) const (/root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x279dd10)
#16 0x00005621e6b7a65b mlir::PatternApplicator::matchAndRewrite(mlir::Operation*, mlir::PatternRewriter&, llvm::function_ref<bool (mlir::Pattern const&)>, llvm::function_ref<void (mlir::Pattern const&)>, llvm::function_ref<llvm::LogicalResult (mlir::Pattern const&)>)::$_2::operator()() const (/root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x27dc65b)
#17 0x00005621e6b771df mlir::PatternApplicator::matchAndRewrite(mlir::Operation*, mlir::PatternRewriter&, llvm::function_ref<bool (mlir::Pattern const&)>, llvm::function_ref<void (mlir::Pattern const&)>, llvm::function_ref<llvm::LogicalResult (mlir::Pattern const&)>) (/root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x27d91df)
#18 0x00005621e6b3cca1 (anonymous namespace)::OperationLegalizer::legalize(mlir::Operation*, mlir::ConversionPatternRewriter&) (/root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x279eca1)
#19 0x00005621e6b3bdb4 mlir::OperationConverter::convert(mlir::ConversionPatternRewriter&, mlir::Operation*) (/root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x279ddb4)
#20 0x00005621e6b3d1bf mlir::OperationConverter::convertOperations(llvm::ArrayRef<mlir::Operation*>) (/root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x279f1bf)
#21 0x00005621e6b438fb mlir::applyPartialConversion(mlir::Operation*, mlir::ConversionTarget const&, mlir::FrozenRewritePatternSet const&, mlir::ConversionConfig) (/root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x27a58fb)
#22 0x00005621e4d6e312 (anonymous namespace)::ConvertTritonGPUToLLVM::runOnOperation() /root/code/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp:178:15
#23 0x00005621e6081996 mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int) (/root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x1ce3996)
#24 0x00005621e6082140 mlir::detail::OpToOpPassAdaptor::runPipeline(mlir::OpPassManager&, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int, mlir::PassInstrumentor*, mlir::PassInstrumentation::PipelineParentInfo const*) (/root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x1ce4140)
#25 0x00005621e60845f5 mlir::PassManager::run(mlir::Operation*) (/root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x1ce65f5)
#26 0x00005621e607dccf performActions(llvm::raw_ostream&, std::shared_ptr<llvm::SourceMgr> const&, mlir::MLIRContext*, mlir::MlirOptMainConfig const&) (/root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x1cdfccf)
#27 0x00005621e607d8fd llvm::LogicalResult llvm::function_ref<llvm::LogicalResult (std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer> >, llvm::raw_ostream&)>::callback_fn<mlir::MlirOptMain(llvm::raw_ostream&, std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer> >, mlir::DialectRegistry&, mlir::MlirOptMainConfig const&)::$_2>(long, std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer> >, llvm::raw_ostream&) (/root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x1cdf8fd)
#28 0x00005621e6fb2656 mlir::splitAndProcessBuffer(std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer> >, llvm::function_ref<llvm::LogicalResult (std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer> >, llvm::raw_ostream&)>, llvm::raw_ostream&, llvm::StringRef, llvm::StringRef) (/root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x2c14656)
#29 0x00005621e6078721 mlir::MlirOptMain(llvm::raw_ostream&, std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer> >, mlir::DialectRegistry&, mlir::MlirOptMainConfig const&) (/root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x1cda721)
#30 0x00005621e60789d3 mlir::MlirOptMain(int, char**, llvm::StringRef, llvm::StringRef, mlir::DialectRegistry&) (/root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x1cda9d3)
#31 0x00005621e6078da6 mlir::MlirOptMain(int, char**, llvm::StringRef, mlir::DialectRegistry&) (/root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x1cdada6)
#32 0x00005621e4e72ad0 main /root/code/triton/bin/triton-opt.cpp:9:0
#33 0x00007f20a42fa083 __libc_start_main /build/glibc-LcI20x/glibc-2.31/csu/../csu/libc-start.c:342:3
#34 0x00005621e468707e _start (/root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x2e907e)
cc @Jokeren @jlebar
Just to confirm, the TritonGPU IR is generated from valid Triton python code?
It came from the lowering of a new operator I'm adding, but I'll see if I can reproduce it with an existing operator.
This produces the same error on the current master branch:
import triton.language as tl
import triton
import torch
@triton.jit
def test_fn(out_ptr, a_ptr, workspace, M, N, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr):
    # Build a TMA descriptor on-device in `workspace` describing [4, N_BLOCK]
    # tiles of the (M, N) tensor at `a_ptr`, then fence so the descriptor is
    # visible to the TMA proxy before it is used.
    desc_ptr = workspace
    tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=desc_ptr, global_address=a_ptr, load_size=[4, N_BLOCK], global_size=[M, N], element_ty=a_ptr.dtype.element_ty)
    tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(desc_ptr)
    # TMA load of the [4, N_BLOCK] tile at offset (0, 0) — this is the
    # local_load path where the assertion fires (see backtrace above).
    gather = tl._experimental_descriptor_load(desc_ptr, [0, 0], [4, N_BLOCK], a_ptr.dtype.element_ty)
    # Write the tile back out row-major so the host can inspect the result.
    tl.store(out_ptr + tl.arange(0, 4)[:, None] * N_BLOCK + tl.arange(0, N_BLOCK)[None, :], gather)
# Host-side driver: a 4x128 fp32 input filled with 0..511, an output buffer
# of the same shape, and a 128-byte scratch buffer that holds the on-device
# TMA descriptor.
device = "cuda"
workspace = torch.empty(128, dtype=torch.uint8, device=device)
inp = torch.arange(512, dtype=torch.float32, device=device).reshape(4, 128)
out = torch.empty((4, 128), dtype=torch.float32, device=device)
# Single-program launch with M=4, N=128, M_BLOCK=4, N_BLOCK=128.
test_fn[(1,)](out, inp, workspace, 4, 128, 4, 128)
I'll take a look today
Reopening this as it seems the TMA hardware does support swizzling with only 4 rows of data.
I get this result if it's helpful:
unswizzled:
tensor([[ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.,
12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23.,
24., 25., 26., 27., 28., 29., 30., 31., 32., 33., 34., 35.,
36., 37., 38., 39., 40., 41., 42., 43., 44., 45., 46., 47.,
48., 49., 50., 51., 52., 53., 54., 55., 56., 57., 58., 59.,
60., 61., 62., 63.],
[128., 129., 130., 131., 132., 133., 134., 135., 136., 137., 138., 139.,
140., 141., 142., 143., 144., 145., 146., 147., 148., 149., 150., 151.,
152., 153., 154., 155., 156., 157., 158., 159., 160., 161., 162., 163.,
164., 165., 166., 167., 168., 169., 170., 171., 172., 173., 174., 175.,
176., 177., 178., 179., 180., 181., 182., 183., 184., 185., 186., 187.,
188., 189., 190., 191.],
[256., 257., 258., 259., 260., 261., 262., 263., 264., 265., 266., 267.,
268., 269., 270., 271., 272., 273., 274., 275., 276., 277., 278., 279.,
280., 281., 282., 283., 284., 285., 286., 287., 288., 289., 290., 291.,
292., 293., 294., 295., 296., 297., 298., 299., 300., 301., 302., 303.,
304., 305., 306., 307., 308., 309., 310., 311., 312., 313., 314., 315.,
316., 317., 318., 319.],
[384., 385., 386., 387., 388., 389., 390., 391., 392., 393., 394., 395.,
396., 397., 398., 399., 400., 401., 402., 403., 404., 405., 406., 407.,
408., 409., 410., 411., 412., 413., 414., 415., 416., 417., 418., 419.,
420., 421., 422., 423., 424., 425., 426., 427., 428., 429., 430., 431.,
432., 433., 434., 435., 436., 437., 438., 439., 440., 441., 442., 443.,
444., 445., 446., 447.]], device='cuda:0', dtype=torch.float16)
swizzled:
tensor([[ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.,
12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23.,
24., 25., 26., 27., 28., 29., 30., 31., 32., 33., 34., 35.,
36., 37., 38., 39., 40., 41., 42., 43., 44., 45., 46., 47.,
48., 49., 50., 51., 52., 53., 54., 55., 56., 57., 58., 59.,
60., 61., 62., 63.],
[136., 137., 138., 139., 140., 141., 142., 143., 128., 129., 130., 131.,
132., 133., 134., 135., 152., 153., 154., 155., 156., 157., 158., 159.,
144., 145., 146., 147., 148., 149., 150., 151., 168., 169., 170., 171.,
172., 173., 174., 175., 160., 161., 162., 163., 164., 165., 166., 167.,
184., 185., 186., 187., 188., 189., 190., 191., 176., 177., 178., 179.,
180., 181., 182., 183.],
[272., 273., 274., 275., 276., 277., 278., 279., 280., 281., 282., 283.,
284., 285., 286., 287., 256., 257., 258., 259., 260., 261., 262., 263.,
264., 265., 266., 267., 268., 269., 270., 271., 304., 305., 306., 307.,
308., 309., 310., 311., 312., 313., 314., 315., 316., 317., 318., 319.,
288., 289., 290., 291., 292., 293., 294., 295., 296., 297., 298., 299.,
300., 301., 302., 303.],
[408., 409., 410., 411., 412., 413., 414., 415., 400., 401., 402., 403.,
404., 405., 406., 407., 392., 393., 394., 395., 396., 397., 398., 399.,
384., 385., 386., 387., 388., 389., 390., 391., 440., 441., 442., 443.,
444., 445., 446., 447., 432., 433., 434., 435., 436., 437., 438., 439.,
424., 425., 426., 427., 428., 429., 430., 431., 416., 417., 418., 419.,
420., 421., 422., 423.]], dtype=torch.float16)
I think the problem is on this line: `int tileRows = 8;`
I'll try to address it tomorrow