Should we add bfloat16 support for HNSW?
Description
One of the biggest pain points of HNSW is that the graph and vectors must be in memory.
Since the vectors are stored off heap and read in via byte streams, it seems like we could reduce the memory requirements by half for a typical use case if we stored the vector dimensions in 2 bytes instead of 4.
When the bytes are read into heap for comparison, we would still be required to use float (until the JVM supports a native float16 type).
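For concreteness, here is a minimal sketch of the conversion this would involve, assuming bfloat16 is simply the upper 16 bits of the IEEE 754 float32 bit pattern (the class and method names are hypothetical, not an existing Lucene API):

```java
// Hypothetical helpers illustrating bfloat16 <-> float32 conversion.
// bfloat16 is assumed to be the top 16 bits of the IEEE 754 float32 layout.
final class BFloat16 {

  // Encode by truncating the low 16 mantissa bits (round-to-nearest would be
  // slightly more accurate, at the cost of a little extra work).
  static short floatToBFloat16(float f) {
    return (short) (Float.floatToIntBits(f) >>> 16);
  }

  // Decode by widening back to 32 bits and zero-padding the dropped mantissa bits.
  static float bFloat16ToFloat(short bits) {
    return Float.intBitsToFloat((bits & 0xFFFF) << 16);
  }
}
```

That is 2 bytes per dimension instead of 4, which is where the roughly 2x saving for the typical all-float32 case comes from.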
I guess the open questions are:
- Do we think this is worth it? It seems like users could get 2x memory savings with almost no configuration change.
- How big a hit on performance is this? I am assuming decoding bfloat16 bytes will take longer than decoding a float, as float decoding can use intrinsics.
I recognize that in the future Lucene will likely support a separate vector codec that uses less memory and is more disk-friendly, but I would argue even such a structure could benefit from only storing bfloat16 instead of float32.
afaik 16-bit fp support is in newer versions of java (21?) and being worked on for vector api there too. not sure of its current state.
in java 20+ there are at least functions for simple scalar conversions: https://docs.oracle.com/en/java/javase/20/docs/api/java.base/java/lang/Float.html#float16ToFloat(short) https://docs.oracle.com/en/java/javase/20/docs/api/java.base/java/lang/Float.html#floatToFloat16(float)
So it could be used to store data as short and then "expand" to float32 for calculations.
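A rough sketch of that store-as-short / expand-to-float32 approach using those Java 20 methods (the class and method names here are made up for illustration):

```java
// Hypothetical codec: store a float32 vector as IEEE fp16 shorts using the
// Java 20 scalar conversions, and expand back to float32 before doing any math.
final class Float16Codec {

  static short[] toFloat16(float[] vector) {
    short[] packed = new short[vector.length];
    for (int i = 0; i < vector.length; i++) {
      packed[i] = Float.floatToFloat16(vector[i]);
    }
    return packed;
  }

  static float[] toFloat32(short[] packed) {
    float[] expanded = new float[packed.length];
    for (int i = 0; i < packed.length; i++) {
      expanded[i] = Float.float16ToFloat(packed[i]);
    }
    return expanded;
  }
}
```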
But doing this would just cost CPU for conversion and save a little space, without the real advantage of float16, which is much faster vector processing.
Here is link to their in-progress half-float vectorization: https://github.com/openjdk/panama-vector/blob/vectorIntrinsics%2Bfp16/src/jdk.incubator.vector/share/classes/jdk/incubator/vector/HalffloatVector.java
afaik it's not even in 21 yet, so not yet usable.
looking at that branch too, the hardware support currently only exists for x86:
- add: https://github.com/openjdk/panama-vector/blob/vectorIntrinsics%2Bfp16/src/hotspot/cpu/x86/x86.ad#L5527-L5536
- mul: https://github.com/openjdk/panama-vector/blob/vectorIntrinsics%2Bfp16/src/hotspot/cpu/x86/x86.ad#L6036-L6045
- reduce: https://github.com/openjdk/panama-vector/blob/vectorIntrinsics%2Bfp16/src/hotspot/cpu/x86/x86.ad#L5002-L5011
- sub: https://github.com/openjdk/panama-vector/blob/vectorIntrinsics%2Bfp16/src/hotspot/cpu/x86/x86.ad#L5746-L5755
This is all we use across the floating point vector functions.
There is an issue with a pull request for ARM support: https://bugs.openjdk.org/browse/JDK-8305563
I recommend waiting.
I wonder, now that main requires JDK 21, if it's worth it now? I would have to dig around to see if there are fast intrinsics now for encoding/decoding short floats. We could inflate to float32 for comparisons (not optimal), but folks might be willing to sacrifice the fp16 performance in order to gain a 2x reduction in disk & memory space.
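For illustration, a scalar sketch of what "inflate to float32 for comparisons" could look like for bfloat16-encoded vectors (the method name and encoding are assumptions, not existing Lucene code):

```java
// Scalar dot product over bfloat16-encoded vectors: each 2-byte value is widened
// to float32 on the fly, so the arithmetic itself is still plain float32.
static float dotProductBFloat16(short[] a, short[] b) {
  float sum = 0f;
  for (int i = 0; i < a.length; i++) {
    float fa = Float.intBitsToFloat((a[i] & 0xFFFF) << 16);
    float fb = Float.intBitsToFloat((b[i] & 0xFFFF) << 16);
    sum += fa * fb;
  }
  return sum;
}
```

The per-element widening is why this is "not optimal" compared to native fp16 arithmetic, but it keeps the footprint at 2 bytes per dimension on disk and in memory.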
vector api still doesn't support it yet in openjdk main
I think the JDK now supports Float16 objects. The HotSpot C2 compiler can auto-vectorize addition, subtraction, division, multiplication, square root, and fused multiply/add on supporting CPUs (from JEP 529).
But the Vector API still doesn't support Float16 vector operations in main. From the doc:
We may broaden the auto-vectorization of Float16 operations, eventually covering all relevant operations on supporting hardware. We may also enhance the Vector API and implementation to cover vectors of Float16 values; for exploratory work, see the [vectorIntrinsics+fp16 branch](https://github.com/openjdk/panama-vector/tree/vectorIntrinsics+fp16) of Project Panama's development repository. We will migrate Float16 to become a value class when [Project Valhalla](https://openjdk.org/projects/valhalla/) becomes available.
Does this mean we should wait for Project Valhalla before pursuing this idea of supporting fp16 vectors, or should we just add support using Float16 arrays and implement the vector scoring in the DefaultVectorUtilSupport class?
Also note that bfloat16 (the subject of this issue) is different from IEEE fp16, which is what I think Float16 represents.
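A small example of that difference, assuming the usual layouts (bfloat16: 8 exponent / 7 mantissa bits, so float32-like range with less precision; IEEE fp16: 5 exponent / 10 mantissa bits, so more precision but a max finite value around 65504):

```java
// Round-trip a value that is representable in bfloat16 but out of range for IEEE fp16.
public class HalfPrecisionDemo {
  public static void main(String[] args) {
    float value = 100_000f;

    // bfloat16 (truncating the low 16 bits of the float32 pattern): keeps the
    // magnitude but loses precision, printing something close to 99840.0.
    short bf16 = (short) (Float.floatToIntBits(value) >>> 16);
    System.out.println(Float.intBitsToFloat((bf16 & 0xFFFF) << 16));

    // IEEE fp16 via the Java 20 scalar conversions: 100000 exceeds the fp16
    // maximum (~65504), so the round-trip overflows to Infinity.
    short fp16 = Float.floatToFloat16(value);
    System.out.println(Float.float16ToFloat(fp16));
  }
}
```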