Python bindings don't support bfloat16 (yet)
The main issue is that the Python buffer protocol understands every other common scalar type we use (even float16) but not bfloat16, so anything that passes bfloat16 data through it (e.g. to or from NumPy) will likely fail.
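The buffer protocol's format characters follow `struct` module syntax, which has `e` for float16 but no code for bfloat16. A common workaround is to carry bfloat16 as raw 16-bit patterns in a `uint16` (`H`) buffer, because bfloat16 is just the upper half of an IEEE-754 float32. The following is a minimal sketch of that workaround using only the standard library. The helper names are hypothetical and are not part of Halide's Python bindings.

```python
import math
import struct


def float_to_bf16_bits(value: float) -> int:
    """Round a Python float to bfloat16 and return its 16-bit pattern.

    bfloat16 is the upper 16 bits of a float32, so pack the value as float32
    and round the low 16 bits away (round-to-nearest-even). Values outside
    the float32 range make struct.pack raise OverflowError.
    """
    if math.isnan(value):
        return 0x7FC0  # canonical quiet NaN
    (bits,) = struct.unpack("<I", struct.pack("<f", value))
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)  # ties go to the even result
    return ((bits + rounding_bias) >> 16) & 0xFFFF


def bf16_bits_to_float(bits: int) -> float:
    """Widen a bfloat16 bit pattern back to a Python float (exact)."""
    (value,) = struct.unpack("<f", struct.pack("<I", (bits & 0xFFFF) << 16))
    return value


def pack_bf16(values) -> bytes:
    """Pack floats into a buffer of little-endian uint16 ('H') bfloat16 bits."""
    return struct.pack(f"<{len(values)}H", *(float_to_bf16_bits(v) for v in values))


def unpack_bf16(buf: bytes) -> list:
    """Decode a buffer of little-endian uint16 bfloat16 bits into floats."""
    count = len(buf) // 2
    return [bf16_bits_to_float(b) for b in struct.unpack(f"<{count}H", buf)]
```

For example, `unpack_bf16(pack_bf16([1.0, -2.0, 3.14159]))` gives `[1.0, -2.0, 3.140625]`, where the last value shows the precision lost to bfloat16's 8-bit significand. Anything that receives the buffer sees plain `uint16` data, so the bfloat16 interpretation has to be agreed on out of band.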
This isn't a huge blocker for anyone (AFAIK), since very little code outside of TensorFlow/TFLite uses bfloat16, but closing the loop on this would be nice. I'm opening this as a tracking issue.
@steven-johnson
I'm implementing a Halide backend for PyTorch/torch.compile/TorchInductor. Early work-in-progress version here: https://github.com/pytorch/pytorch/pull/126417
For this backend I am using the Halide Python bindings to define an hl.generator that generates the kernel (the motivation is to lower compile times by skipping a C++ compile). I then call the generated kernels through the C++ API, using the generated *.{h,a} files.
I don't actually need to be able to call bfloat16 kernels from Python, just define them in an hl.generator. So I don't think the lack of buffer protocol support is an issue for me.
As far as I can tell, there is no way to do this today: hl.BFloat (and its variants) is not defined in Python. Would this be easy to add? Any pointers?
It's been a minute since I worked on this code, but IIRC the main issue when I opened this was that the Python buffer protocol (which we use for interop with NumPy, etc.) doesn't have an intrinsic type for bfloat16, and at the time there were no plans to add one. If that support exists now, the rest would be pretty trivial to add. We'd welcome a contribution along these lines from anyone: we'd love to support bfloat16 here, it just didn't seem (easily) doable at the time, and we didn't know of any pressing demand for it.