type conversion before tl.dot fails compilation
Context: we are trying to upgrade the triton pin in pytorch and started encountering a test failure for python test/inductor/test_pattern_matcher.py -k test_mixed_mm.
Doing a type conversion on an input tensor to tl.dot now fails to compile. It works at triton commit e6216047b8b0aef1fe8da6ca8667a3ad0a016411 but fails on a recent commit 0410652666ddcad54aa4c403276b478a36379b90.
Here is a standalone repro w.py: https://gist.github.com/shunting314/3a3b8ce1ccee7b51b8ee0d9a2d24dd3d
Running python w.py reports the following error on the new commit:
loc("/tmp/torchinductor_shunting/bm/cbm7qsh5esh6xdkdddmv7l2ilel4kdbfwgy2luolzmme62njagrb.py":64:17): error: LLVM Translation failed for operation: builtin.unrealized_conversion_cast
Failed to emit LLVM IR
Translate to LLVM IR failed
LLVM ERROR: Failed to translate TritonGPU to LLVM IR.
Aborted (core dumped)
Here is the same repro but with the inductor dependencies removed: https://gist.github.com/shunting314/4eb4a6a7d8cbfc6396726518196e26d1
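For readers without access to the gists, the kernel below is a rough sketch of the pattern that triggers the failure: one operand is loaded in a narrower dtype and cast right before tl.dot. It is illustrative only (the names, shapes, and block sizes are assumptions), not the exact contents of the gists.
```python
# Illustrative sketch of the failing pattern, not the exact gist contents.
# Assumes M, N, K are multiples of the block sizes, so no masking is needed.
import triton
import triton.language as tl

@triton.jit
def mixed_dot_kernel(a_ptr, b_ptr, c_ptr, M, N, K,
                     BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, K, BLOCK_K):
        # a is stored in a narrower dtype (e.g. bfloat16, or int8/uint8 in the later repros)
        a = tl.load(a_ptr + offs_m[:, None] * K + (k + offs_k)[None, :])
        # b is already in the wider dtype (e.g. float32)
        b = tl.load(b_ptr + (k + offs_k)[:, None] * N + offs_n[None, :])
        # the cast right before tl.dot is what started failing on the newer commit
        acc += tl.dot(a.to(tl.float32), b)
    tl.store(c_ptr + offs_m[:, None] * N + offs_n[None, :], acc)
```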
cc @ptillet @Jokeren any idea how we can unblock the triton version upgrade in pytorch?
A similar test program https://gist.github.com/shunting314/2eb4dd1c12ec0ac90df42d0d2d6efd3a with slightly different input types also starts failing on the new commit, with the error:
unimplemented code path
UNREACHABLE executed at /home/shunting/ws/triton/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp:254!
Aborted (core dumped)
Here is the same repro without the inductor dependencies: https://gist.github.com/shunting314/d164f74f9830239fbe04b1913b690fa0
Can you provide a standalone reproducer without the inductor dependency? I'm getting errors like
TypeError: instance_descriptor.__new__() got an unexpected keyword argument 'ids_of_folded_args'
@Jokeren oh, that's because without my other PR ( https://github.com/pytorch/pytorch/pull/107722 ), inductor uses a stale definition of instance_descriptor. This slightly updated script https://gist.github.com/shunting314/3a3b8ce1ccee7b51b8ee0d9a2d24dd3d should avoid the issue you mentioned. I can also spend more time removing the inductor dependencies if that helps.
@Jokeren BTW, here is another repro with inductor dependencies removed: https://gist.github.com/shunting314/4eb4a6a7d8cbfc6396726518196e26d1
Thanks, I can reproduce the bug now
Looks like different combinations of dtypes surface different issues. Here is a summary of them:
- bfloat16 -> float32
  - repro: https://gist.github.com/shunting314/4eb4a6a7d8cbfc6396726518196e26d1
  - status: fixed by #2162
- int8 -> float32
  - repro: https://gist.github.com/shunting314/d164f74f9830239fbe04b1913b690fa0
  - status: fixed by #2184
- int8 -> bfloat16
  - repro: https://gist.github.com/shunting314/724d60b5a4f834234c6e922a7567679d
  - status: *** NEED HELP *** (check error message 1 below)
- uint8 -> float16
  - this is similar to the case above, so hopefully fixing one also fixes the other. Repro: https://gist.github.com/shunting314/9f47293d653582808bc0edd1c4db4e04
Error message 1:
python: /home/shunting/.triton/llvm/llvm+mlir-17.0.0-x86_64-linux-gnu-ubuntu-18.04-release/include/llvm/ADT/ArrayRef.h:255: const T& llvm::ArrayRef<T>::operator[](size_t) const [with T = mlir::Type; size_t = long unsigned int]: Assertion `Index < Length && "Invalid index!"' failed.
@Jokeren could you also look at the 8-bit -> 16-bit conversion case? I tried to report as much as I can, but these are the errors I see after the others are fixed. Thanks
Sure, I'll probably take a look early next week
Sounds good. Thank you for all your efforts unblocking the triton version upgrade in pytorch!
@shunting314 The last problem is actually more general. Any 8-bit dot op could break on very small shapes (16x16) running on A100, which we haven't covered in the existing test cases.
I haven't figured out the fix yet, still investigating
@shunting314 I actually left a comment here https://github.com/openai/triton/blob/1465b573e8d8e4c707d579092001bcff0f1523ed/python/triton/language/semantic.py#L1280
It's better to use a block size of at least 16x32 for i8. It's tricky in the mixed-precision case because we don't know that the operands were converted from other precisions, so the frontend didn't report an error.
And I think the check is overkill because 16x16 fma should work.
I'm checking with @ptillet to see whether we should support small matrix sizes like other precisions do, or improve the IR verification in the backend. Even if none of the proposed changes ends up landing, you can add some constraints on the inductor end.
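One hypothetical way to apply such a constraint on the caller side is to filter out autotune configs whose K tile is below the i8 minimum. The helper name and the 32 threshold below are assumptions for illustration, not an existing inductor API.
```python
import triton

# Hypothetical caller-side guard: drop autotune configs whose K tile is too
# small when one of the dot operands is an 8-bit type. The 32 threshold along
# K follows the 16x32 minimum suggested above and may differ per GPU/commit.
def filter_int8_configs(configs, has_int8_operand):
    if not has_int8_operand:
        return configs
    return [c for c in configs if c.kwargs.get("BLOCK_K", 0) >= 32]

configs = [
    triton.Config({"BLOCK_M": 16, "BLOCK_N": 16, "BLOCK_K": 16}, num_warps=2),
    triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32}, num_warps=4),
]
configs = filter_int8_configs(configs, has_int8_operand=True)  # keeps only the 64x64x32 config
```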
@Jokeren I'm a bit confused here about the block size restriction for int8
- with the old triton pin, a small block size for int8 actually works. Was the restriction added recently?
- the 'int8 -> float32' case also creates small blocks for int8 but has been fixed. Why does that one work while the 'int8 -> bfloat16' case doesn't?
In the meantime, I'm trying a larger block size to confirm.
> with the old triton pin, a small block size for int8 actually works. Was the restriction added recently?

Hmm, I'm not sure why that worked.

> the 'int8 -> float32' case also creates small blocks for int8 but has been fixed. Why does that one work while the 'int8 -> bfloat16' case doesn't?

The previous one was for 16x16 with a blocked layout (the non-tensor-core path).
@Jokeren Just FYI, we worked around the int8 issue by making sure the block size is at least 32. In case you figure out later why it worked on the previous commit, let us know :)
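A kernel-side way to encode that workaround is sketched below; this is an illustration, not something from this thread or the Triton docs. It uses tl.static_assert so that a too-small BLOCK_K fails with a readable compile-time error instead of the backend crash, with the 32 threshold taken from the comment above.
```python
import torch
import triton
import triton.language as tl

@triton.jit
def int8_dot_guard_demo(x_ptr, BLOCK_K: tl.constexpr):
    # Compile-time guard: surface a readable error instead of the backend crash.
    tl.static_assert(BLOCK_K >= 32, "8-bit dot operands need BLOCK_K >= 32 here")
    # ... the rest of the int8 dot kernel would go here ...

x = torch.empty(64, device="cuda", dtype=torch.int8)
int8_dot_guard_demo[(1,)](x, BLOCK_K=32)   # compiles; BLOCK_K=16 would fail the static_assert
```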
I am using commit id 768fc1fcd98ecfc0892f8982b0bb009dd7bb11ea and a fp32 -> fp16 type conversion. The repro code:
```python
# imports added so the snippet runs standalone
import torch
import torch.nn.functional as F
from triton.ops.matmul import matmul

def test_addmm(self):
    batch, seq, hidden = 20, 64, 768
    output = 1024
    device = torch.cuda.current_device()
    x = torch.randn(batch * seq, hidden, device=device,
                    # dtype=torch.float16,
                    requires_grad=True)
    w = torch.randn(output, hidden, device=device,
                    # dtype=torch.float16,
                    requires_grad=True)
    with torch.cuda.amp.autocast(dtype=torch.float16):
        out = F.linear(x, w)
        out_triton = matmul(x, w.T)
        torch.testing.assert_close(out, out_triton)
```
I have tested on a T4 (driver version 450.80.02, CUDA 11.0) and an A100 (driver version 470.82.01, CUDA 11.8); they hit the same issue. Call stack:
#0 0x00007ffff71188af in raise () from /usr/lib64/libc.so.6
#1 0x00007ffff711a4aa in abort () from /usr/lib64/libc.so.6
#2 0x00007ffff7110d37 in __assert_fail_base () from /usr/lib64/libc.so.6
#3 0x00007ffff7110de2 in __assert_fail () from /usr/lib64/libc.so.6
#4 0x00007fff3d51c483 in llvm::ArrayRef<mlir::Value>::operator[] (Index=<optimized out>, this=<optimized out>) at /root/.triton/llvm/llvm+mlir-17.0.0-x86_64-linux-gnu-ubuntu-18.04-release/include/llvm/ADT/ArrayRef.h:254
#5 0x00007fff3d52353f in llvm::ArrayRef<mlir::Value>::operator[] (Index=<optimized out>, this=<optimized out>) at /root/.triton/llvm/llvm+mlir-17.0.0-x86_64-linux-gnu-ubuntu-18.04-release/include/llvm/ADT/ArrayRef.h:254
#6 operator() (__closure=<optimized out>, __closure=<optimized out>, idx=<optimized out>) at /root/python/torch-triton/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp:299
#7 MMA16816SmemLoader::loadX4 (this=this@entry=0x7fffffff56e0, mat0=<optimized out>, mat1=<optimized out>, ptrs=..., matTy=..., shemPtrTy=...) at /root/python/torch-triton/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp:300
#8 0x00007fff3d525d76 in operator() (__closure=0x6517f290, a=0, b=2) at /root/.triton/llvm/llvm+mlir-17.0.0-x86_64-linux-gnu-ubuntu-18.04-release/include/llvm/ADT/SmallVector.h:291
#9 0x00007fff3d52271a in std::function<void (int, int)>::operator()(int, int) const (__args#1=2, __args#0=0, this=0x7fffffff5950) at /usr/include/c++/10/bits/std_function.h:617
#10 loadArg (rewriter=..., loc=..., tensor=..., encoding=..., smemObj=..., typeConverter=0x7fffffff76f0, thread=..., isA=true) at /root/python/torch-triton/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp:616
#11 0x00007fff3d4864bf in ConvertLayoutOpConversion::lowerSharedToDotOperandMMA (this=this@entry=0x64eb9a50, op=..., adaptor=..., rewriter=..., mmaLayout=..., dotOperandLayout=..., isOuter=isOuter@entry=false)
@Jokeren It might be T4 specific
I'm seeing similar issues here with uint8 → float16 (or float32). Using nightly with an A6000. The application is quantized matrix multiplication. I've found that basically only a block size of 16 works, while anything larger results in either "Invalid index!" for f16 or "unimplemented code path" and a segfault for f32.
Just curious if we have an idea of what's causing this, or if there are any hints for choosing the right block size. Thanks!