AITemplate
What is the best way to accept uint8 input?
float16 is not CPU-friendly and float32 input is unnecessarily large (if we are to add data marshaling).
I usually pass the input as bytes (uint8) and then convert to float16 inside the model (in a GPU node, e.g. in ONNX). Currently I am considering adding an ops.castfp16 op, but wanted to ask whether there is already a better solution.
I think it is possible to add a specialized CUTLASS/CK kernel that accepts uint8 input and performs the cast in the prologue. An unfused alternative is to add a standalone cast op.
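For reference, here is a minimal numpy sketch of the unfused approach being discussed: the host marshals raw uint8 (1 byte/element, 4x less traffic than float32), and the cast to float16 plus normalization happens as the first op inside the graph. This is only an illustration of the data flow, not AITemplate API; the function name `cast_and_normalize` is hypothetical.

```python
import numpy as np

def cast_and_normalize(x_u8: np.ndarray) -> np.ndarray:
    # Hypothetical stand-in for the first graph node: in the real
    # pipeline this cast would run on the GPU (unfused cast op, or
    # fused into a CUTLASS/CK kernel prologue).
    return x_u8.astype(np.float16) / np.float16(255)

# Host-side payload stays uint8: 1 byte/element.
img_u8 = np.arange(256, dtype=np.uint8).reshape(16, 16)
img_f16 = cast_and_normalize(img_u8)

# float16 is half the size of float32, a quarter of float64,
# but the wire format above was uint8 -- the smallest of all.
assert img_f16.dtype == np.float16
assert img_f16.nbytes == 2 * img_u8.nbytes
```

The fused variant would do the same `astype` + scale inside the kernel's epilogue/prologue, avoiding a separate pass over the tensor.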