
Missing dtype support for OnnxToTorch

Open • jinchen62 opened this issue 1 year ago • 3 comments

Currently supported dtypes: https://github.com/llvm/torch-mlir/blob/main/lib/Conversion/TorchOnnxToTorch/Utils.cpp#L64

Comparing the ONNX dtypes against the torch dtypes, the following ONNX types are still missing support:

  • 4: onnx.TensorProto.UINT16
  • 8: onnx.TensorProto.STRING
  • 12: onnx.TensorProto.UINT32
  • 13: onnx.TensorProto.UINT64
  • 17: onnx.TensorProto.FLOAT8E4M3FN
  • 18: onnx.TensorProto.FLOAT8E4M3FNUZ
  • 19: onnx.TensorProto.FLOAT8E5M2
  • 20: onnx.TensorProto.FLOAT8E5M2FNUZ
  • 21: onnx.TensorProto.UINT4
  • 22: onnx.TensorProto.INT4
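To make the gap concrete, here is a minimal Python sketch of the ONNX-to-torch dtype correspondence. The ONNX enum values are those of TensorProto.DataType (UNDEFINED = 0 is left out). The torch names are public torch.* dtypes. Which ONNX types count as already supported is assumed from the list above rather than read from Utils.cpp, and the helper names are hypothetical.

```python
from typing import Dict, List, Optional, Set, Tuple

# ONNX TensorProto.DataType value -> (ONNX name, PyTorch dtype name).
# None means "assumed to have no direct torch counterpart" in this sketch.
ONNX_TO_TORCH: Dict[int, Tuple[str, Optional[str]]] = {
    1: ("FLOAT", "torch.float32"),
    2: ("UINT8", "torch.uint8"),
    3: ("INT8", "torch.int8"),
    4: ("UINT16", "torch.uint16"),  # recent PyTorch, limited op support
    5: ("INT16", "torch.int16"),
    6: ("INT32", "torch.int32"),
    7: ("INT64", "torch.int64"),
    8: ("STRING", None),
    9: ("BOOL", "torch.bool"),
    10: ("FLOAT16", "torch.float16"),
    11: ("DOUBLE", "torch.float64"),
    12: ("UINT32", "torch.uint32"),  # recent PyTorch, limited op support
    13: ("UINT64", "torch.uint64"),  # recent PyTorch, limited op support
    14: ("COMPLEX64", "torch.complex64"),
    15: ("COMPLEX128", "torch.complex128"),
    16: ("BFLOAT16", "torch.bfloat16"),
    17: ("FLOAT8E4M3FN", "torch.float8_e4m3fn"),
    18: ("FLOAT8E4M3FNUZ", "torch.float8_e4m3fnuz"),
    19: ("FLOAT8E5M2", "torch.float8_e5m2"),
    20: ("FLOAT8E5M2FNUZ", "torch.float8_e5m2fnuz"),
    21: ("UINT4", None),
    22: ("INT4", None),
}

# Hypothetical: values treated as already supported, derived from this
# issue's missing-list, not read from Utils.cpp.
ASSUMED_SUPPORTED: Set[int] = {1, 2, 3, 5, 6, 7, 9, 10, 11, 14, 15, 16}


def missing_onnx_dtypes(supported: Set[int]) -> List[str]:
    """ONNX dtype names, in enum order, that `supported` does not cover."""
    return [name for value, (name, _) in sorted(ONNX_TO_TORCH.items())
            if value not in supported]


def missing_without_torch_dtype(supported: Set[int]) -> List[str]:
    """Missing ONNX dtypes that also lack a torch counterpart in the table."""
    return [name for value, (name, torch_name) in sorted(ONNX_TO_TORCH.items())
            if value not in supported and torch_name is None]
```

Under these assumptions, seven of the ten missing ONNX types have a named torch dtype to map to. STRING, UINT4 and INT4 would need a separate decision.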

jinchen62 · Apr 29 '24 01:04

PyTorch dtypes (c10 ScalarType): https://github.com/pytorch/pytorch/blob/main/c10/core/ScalarType.h#L55

jinchen62 · May 09 '24 09:05

ONNX dtypes (TensorProto data types): https://github.com/shouxieai/tensorRT_Pro/blob/main/onnx/onnx-ml.proto#L487

jinchen62 · May 09 '24 09:05

Checkpoint branches: torch-mlir https://github.com/jinchen62/torch-mlir/tree/dtype_support and llvm-project https://github.com/jinchen62/llvm-project/tree/dtype_support

I added the float8 types to the op definition. With that, lowering can generate %78 = "llvm.fptrunc"(%76) : (vector<4xf32>) -> vector<4xf8E5M2>, but translation fails at https://github.com/llvm/llvm-project/blob/main/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp#L79, and LLVM IR itself seems to have no float8 type support: https://github.com/llvm/llvm-project/blob/main/llvm/lib/IR/Type.cpp#L234
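For reference, a minimal Python sketch of what f8E5M2 encodes, which may help explain why a plain fptrunc has no LLVM IR type to target. E5M2 keeps float16's sign bit and 5-bit exponent (bias 15) and keeps only 2 of its 10 mantissa bits, so an E5M2 byte decodes as the high byte of a float16, and a float16 bit pattern rounds to E5M2 with integer arithmetic. This only illustrates the format. The function names are hypothetical, and this is not what the checkpoint branches do. Integer emulation like this is one conceivable option (an assumption, not something proposed here) when a backend lacks a native float8 type.

```python
import math
import struct


def e5m2_to_float(byte: int) -> float:
    """Decode one f8E5M2 byte: 1 sign bit, 5 exponent bits (bias 15), 2 mantissa bits.

    E5M2 shares float16's sign/exponent layout, so the byte is exactly the
    high byte of a float16 bit pattern whose low byte is zero.
    """
    if not 0 <= byte <= 0xFF:
        raise ValueError("expected an 8-bit value")
    # '<e' is little-endian IEEE float16: low byte first, then the high byte.
    return struct.unpack("<e", bytes([0x00, byte]))[0]


def f16_bits_to_e5m2(h: int) -> int:
    """Round a float16 bit pattern to f8E5M2 using round-to-nearest-even."""
    if not 0 <= h <= 0xFFFF:
        raise ValueError("expected a 16-bit value")
    exponent = (h >> 10) & 0x1F
    mantissa = h & 0x3FF
    if exponent == 0x1F and mantissa != 0:
        # NaN: keep the sign, return a quiet NaN (exponent all ones, mantissa 0b10).
        return ((h >> 8) & 0x80) | 0x7E
    # Drop the low 8 bits with round-to-nearest-even. A carry may spill into
    # the exponent, which rounds large finite values up to infinity.
    keep_lsb = (h >> 8) & 1
    return ((h + 0x7F + keep_lsb) >> 8) & 0xFF
```

For example, e5m2_to_float(0x7B) gives 57344.0, the largest finite E5M2 value, and f16_bits_to_e5m2(0x7BFF) rounds float16's largest finite value (65504) up to E5M2 infinity (0x7C).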

jinchen62 · May 14 '24 04:05