[QST] Why doesn't _CUTLASS_TYPE_TO_TORCH_TYPE support torch.bfloat16?
What is your question?
In python/cutlass/emit/pytorch.py, why is bfloat16 not supported?
```python
_CUTLASS_TYPE_TO_TORCH_TYPE = {
    DataType.f16: "torch::kF16",
    DataType.f32: "torch::kF32",
    DataType.f64: "torch::kF64",
    DataType.s8: "torch::I8",
    DataType.s32: "torch::I32",
}
```
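For concreteness, the gap shows up as a plain KeyError if the bf16 dtype is looked up in that dictionary directly. A minimal sketch; the import paths below are my guesses based on the file location and may differ:

```python
# Hypothetical reproduction of the missing mapping; the import paths are
# assumptions inferred from python/cutlass/emit/pytorch.py.
from cutlass import DataType
from cutlass.emit.pytorch import _CUTLASS_TYPE_TO_TORCH_TYPE

print(_CUTLASS_TYPE_TO_TORCH_TYPE[DataType.f16])   # "torch::kF16"
print(_CUTLASS_TYPE_TO_TORCH_TYPE[DataType.bf16])  # KeyError: no bf16 entry today
```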
We simply haven't implemented it. We welcome contributions in this space.
Hi @jackkosaian, could I take care of it? Could you assign it to me?
Yes, please feel free to submit a PR supporting this.
Hi, I have opened PR https://github.com/NVIDIA/cutlass/pull/1843. I added the mapping (see the sketch after this list) because:
- torch::kBFloat16 is the best match for the bf16 type
- torch::kBFloat16 is used in PyTorch benchmarks (pytorch/benchmarks/static_runtime/test_static_runtime.cc) and unit tests (pytorch/test/cpp/api/functional.cpp, TEST_F(FunctionalTest, ELU))
- bf16 is already mapped to kBFloat16 elsewhere in CUTLASS, in library_to_torch_dict (cutlass/python/cutlass/utils/datatypes.py, is_torch_available()). However, that dictionary has different value types than _CUTLASS_TYPE_TO_TORCH_TYPE
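The core of the change is roughly one extra entry along these lines (sketched from memory following the existing naming pattern, not the verbatim diff from the PR):

```python
# Sketch of the extended mapping; the bf16 entry is the proposed addition,
# the remaining entries are unchanged from the current file.
from cutlass import DataType

_CUTLASS_TYPE_TO_TORCH_TYPE = {
    DataType.f16: "torch::kF16",
    DataType.bf16: "torch::kBFloat16",  # proposed addition
    DataType.f32: "torch::kF32",
    DataType.f64: "torch::kF64",
    DataType.s8: "torch::I8",
    DataType.s32: "torch::I32",
}
```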
However, I have some doubts:
- as I mentioned in the PR description, void assert_tensor_equal() in pytorch/test/cpp/api/support.h casts kBFloat16 to kFloat32 because some tensor operations are not available for kBFloat16 (a Python-side analogue is sketched after this list)
- at::kBFloat16 is not explicitly exposed in the torch namespace in pytorch/torch/csrc/api/include/torch/types.h.
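For illustration, a Python-side analogue of that cast-before-compare pattern (my own sketch, not code from the PyTorch repository):

```python
import torch

a = torch.randn(4, 4).to(torch.bfloat16)
b = a.clone()

# Same idea as assert_tensor_equal in support.h: compare after upcasting to
# float32, since some tensor operations are not available for bfloat16.
torch.testing.assert_close(a.to(torch.float32), b.to(torch.float32))
```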
I was not able to run 02_pytorch_extension_grouped_gemm.ipynb with my changes, so I have no proof yet that the mapping works even for this example. Any advice on how to import PyTorch from a local repository would be welcome.