Missing dtype support for OnnxToTorch
Current support: https://github.com/llvm/torch-mlir/blob/main/lib/Conversion/TorchOnnxToTorch/Utils.cpp#L64
Comparing the ONNX dtype enum against the Torch dtypes, the following are missing support:
- 4: onnx.TensorProto.UINT16
- 8: onnx.TensorProto.STRING
- 12: onnx.TensorProto.UINT32
- 13: onnx.TensorProto.UINT64
- 17: onnx.TensorProto.FLOAT8E4M3FN
- 18: onnx.TensorProto.FLOAT8E4M3FNUZ
- 19: onnx.TensorProto.FLOAT8E5M2
- 20: onnx.TensorProto.FLOAT8E5M2FNUZ
- 21: onnx.TensorProto.UINT4
- 22: onnx.TensorProto.INT4
Torch dtypes: https://github.com/pytorch/pytorch/blob/main/c10/core/ScalarType.h#L55
ONNX dtypes: https://github.com/shouxieai/tensorRT_Pro/blob/main/onnx/onnx-ml.proto#L487
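The pairing above can be sketched as a small table. This is only an illustration of the missing codes listed in this issue; the torch-side names are assumptions based on recent PyTorch (`c10/core/ScalarType.h`), and the entries marked `None` have no obvious plain torch dtype (STRING has no tensor dtype, and 4-bit dtypes are experimental or absent in mainline torch).

```python
# Sketch: ONNX dtype codes still unsupported in the OnnxToTorch conversion,
# paired with *candidate* torch dtype names (assumptions, not verified
# against torch-mlir's Utils.cpp mapping).
MISSING_ONNX_TO_TORCH = {
    4:  ("UINT16",         "torch.uint16"),
    8:  ("STRING",         None),  # no torch string tensor dtype
    12: ("UINT32",         "torch.uint32"),
    13: ("UINT64",         "torch.uint64"),
    17: ("FLOAT8E4M3FN",   "torch.float8_e4m3fn"),
    18: ("FLOAT8E4M3FNUZ", "torch.float8_e4m3fnuz"),
    19: ("FLOAT8E5M2",     "torch.float8_e5m2"),
    20: ("FLOAT8E5M2FNUZ", "torch.float8_e5m2fnuz"),
    21: ("UINT4",          None),  # 4-bit dtypes experimental/absent in torch
    22: ("INT4",           None),
}

def torch_dtype_for(onnx_code: int):
    """Return the candidate torch dtype name for an ONNX code, or None."""
    entry = MISSING_ONNX_TO_TORCH.get(onnx_code)
    return entry[1] if entry else None
```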
Checkpoint branches: https://github.com/jinchen62/torch-mlir/tree/dtype_support and https://github.com/jinchen62/llvm-project/tree/dtype_support
Added the float8 types to the op definition; it can generate `%78 = "llvm.fptrunc"(%76) : (vector<4xf32>) -> vector<4xf8E5M2>`, but translation fails at https://github.com/llvm/llvm-project/blob/main/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp#L79, and there appears to be no float8 support in LLVM IR itself: https://github.com/llvm/llvm-project/blob/main/llvm/lib/IR/Type.cpp#L234