Implementation and Spec mismatch in tosa.matmul
The lowering from torch-mlir to TOSA for matmul does not seem to follow the spec (https://www.mlplatform.org/tosa/tosa_spec.html#_matmul) with respect to the output type defined for each input type. For example, i8 x i8 requires an i32 output, and bf16 x bf16 requires an f32 output.
@eric-k256 or anyone else: is this intentional? Is the implementation (https://github.com/llvm/torch-mlir/blob/4cc62aeb24e28b3ff60df6ff4a0fd97cc045efc1/lib/Conversion/TorchToTosa/TorchToTosa.cpp#L1458) still correct, or should it strictly follow the spec (with a tosa.cast introduced at the end)?
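
For illustration, this is roughly what the spec's MATMUL table implies for the result types (shapes and exact assembly syntax are just for the example, and may differ between MLIR versions):

```mlir
// i8 x i8 -> i32 accumulator/result
%r0 = tosa.matmul %a_i8, %b_i8
    : (tensor<1x4x8xi8>, tensor<1x8x16xi8>) -> tensor<1x4x16xi32>

// bf16 x bf16 -> f32 accumulator/result
%r1 = tosa.matmul %a_bf16, %b_bf16
    : (tensor<1x4x8xbf16>, tensor<1x8x16xbf16>) -> tensor<1x4x16xf32>
```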
Thank you. Tiago
Sorry for the delay. I'd recommend we modify it to follow the specification. What is there now works for the common cases (fp32, and fp16 with fp16 accumulate), but it doesn't work for bf16 or for fp16 with fp32 accumulate, and it would not work for any of the fp8 implementations. I don't know how much anyone has looked carefully at torch bf16 implementations with TOSA, which is probably why this has been missed so far.
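
As a sketch (not the IR the conversion currently emits, and with illustrative shapes), a spec-following lowering for a bf16 torch matmul would accumulate in f32 and then cast back to the bf16 result type torch expects:

```mlir
// Accumulate in f32 per the spec, then narrow to bf16 for the torch result.
%acc = tosa.matmul %lhs, %rhs
    : (tensor<1x4x8xbf16>, tensor<1x8x16xbf16>) -> tensor<1x4x16xf32>
%out = tosa.cast %acc
    : (tensor<1x4x16xf32>) -> tensor<1x4x16xbf16>
```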