Optimize euclidean distance in host refine phase
Issue
The original code (below) compiled to serial assembly and used the strictly-ordered fadda instruction on ARM with both gcc and clang, which resulted in suboptimal performance.
for (size_t k = 0; k < dim; k++) {
  distance += DC::template eval<DistanceT>(query[k], row[k]);
}
Proposed solution
This PR provides a Euclidean distance implementation that accumulates partial sums in a vector-sized array (below). This helps the compiler vectorize the loop, but the result no longer follows strictly-ordered floating-point summation.
template <typename DC, typename DistanceT, typename DataT>
DistanceT euclidean_distance_squared_generic(DataT const* a, DataT const* b, size_t n) {
  // Number of DistanceT elements that fit in a 512-bit vector register
  size_t constexpr max_vreg_len = 512 / (8 * sizeof(DistanceT));
  // max_vreg_len is a power of two, so this mask rounds n down to a multiple of it
  size_t n_rounded = n & ~(max_vreg_len - 1);
  DistanceT distance[max_vreg_len] = {0};
  // Main loop: max_vreg_len independent partial sums
  for (size_t i = 0; i < n_rounded; i += max_vreg_len) {
    for (size_t j = 0; j < max_vreg_len; ++j) {
      distance[j] += DC::template eval<DistanceT>(a[i + j], b[i + j]);
    }
  }
  // Tail: the remaining n - n_rounded elements (fewer than max_vreg_len)
  for (size_t i = n_rounded; i < n; ++i) {
    distance[i - n_rounded] += DC::template eval<DistanceT>(a[i], b[i]);
  }
  // Reduce the partial sums into distance[0]
  for (size_t i = 1; i < max_vreg_len; ++i) {
    distance[0] += distance[i];
  }
  return distance[0];
}
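To illustrate the trade-off described above, here is a minimal, self-contained sketch that compares a single-accumulator (strictly ordered) sum with the partial-sum variant. SqDiff is a hypothetical stand-in for the DC functor, whose definition is not shown in this PR, and distance_ordered, distance_partial and relative_gap are illustration-only names. The two results are expected to agree only up to floating-point rounding, not bit for bit.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for the DC functor: squared difference of two elements.
struct SqDiff {
  template <typename DistanceT, typename DataT>
  static DistanceT eval(DataT x, DataT y) {
    DistanceT d = static_cast<DistanceT>(x) - static_cast<DistanceT>(y);
    return d * d;
  }
};

// Strictly ordered reference: one accumulator, additions in index order.
template <typename DC, typename DistanceT, typename DataT>
DistanceT distance_ordered(DataT const* a, DataT const* b, std::size_t n) {
  DistanceT acc = 0;
  for (std::size_t k = 0; k < n; ++k) {
    acc += DC::template eval<DistanceT>(a[k], b[k]);
  }
  return acc;
}

// Partial-sum variant: `lanes` independent accumulators, reduced at the end.
template <typename DC, typename DistanceT, typename DataT>
DistanceT distance_partial(DataT const* a, DataT const* b, std::size_t n) {
  constexpr std::size_t lanes = 512 / (8 * sizeof(DistanceT));
  std::size_t n_rounded = n & ~(lanes - 1);  // lanes is a power of two
  DistanceT acc[lanes] = {0};
  for (std::size_t i = 0; i < n_rounded; i += lanes) {
    for (std::size_t j = 0; j < lanes; ++j) {
      acc[j] += DC::template eval<DistanceT>(a[i + j], b[i + j]);
    }
  }
  for (std::size_t i = n_rounded; i < n; ++i) {
    acc[i - n_rounded] += DC::template eval<DistanceT>(a[i], b[i]);
  }
  DistanceT total = 0;
  for (std::size_t j = 0; j < lanes; ++j) {
    total += acc[j];
  }
  return total;
}

// Relative difference between the two summation orders for float inputs.
inline double relative_gap(std::vector<float> const& a, std::vector<float> const& b) {
  assert(a.size() == b.size());
  double s = distance_ordered<SqDiff, float>(a.data(), b.data(), a.size());
  double p = distance_partial<SqDiff, float>(a.data(), b.data(), a.size());
  return s == 0.0 ? 0.0 : std::fabs(s - p) / s;
}
```

Any check that compares the new result against the old one should therefore use a tolerance rather than exact equality, unless the inputs are small integers that float represents exactly.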
In addition, the PR includes an implementation with NEON intrinsics, which provides a further speedup on certain test cases (it can be removed if arch-specific code is undesired).
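For readers unfamiliar with the approach, a NEON version for float inputs could look roughly like the sketch below. This is not the code from this PR: the function name l2_squared_f32 is hypothetical, it assumes AArch64 (vaddvq_f32 is not available on 32-bit ARM), and it falls back to a plain scalar loop on other targets.

```cpp
#include <cassert>
#include <cstddef>
#if defined(__aarch64__) && defined(__ARM_NEON)
#include <arm_neon.h>
#endif

// Hypothetical helper: squared Euclidean distance between two float arrays.
inline float l2_squared_f32(float const* a, float const* b, std::size_t n) {
  assert(n == 0 || (a != nullptr && b != nullptr));
  std::size_t i = 0;
  float total = 0.0f;
#if defined(__aarch64__) && defined(__ARM_NEON)
  // Four float lanes accumulate independently, so no ordered reduction is needed
  float32x4_t acc = vdupq_n_f32(0.0f);
  for (; i + 4 <= n; i += 4) {
    float32x4_t d = vsubq_f32(vld1q_f32(a + i), vld1q_f32(b + i));
    acc = vfmaq_f32(acc, d, d);  // acc += d * d, per lane
  }
  total = vaddvq_f32(acc);  // horizontal sum of the four lanes
#endif
  // Scalar tail (or the whole array when NEON is unavailable)
  for (; i < n; ++i) {
    float d = a[i] - b[i];
    total += d * d;
  }
  return total;
}
```

A single 128-bit accumulator is the simplest form; unrolling with several accumulators may hide FMA latency better, but whether it helps would need to be measured.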
Results
/ok to test
UPD:
@cjnolet, it seems CI can only be triggered by repository members. Could you please trigger it once more?
I reformatted the code with clang-format.
/ok to test
You have changed the distance computation for the large batch size case, but did not change it for the small batch case (which is handled in a separate branch here). Is this because your benchmarks showed no improvement for the small batch case?
Correct. With small batch sizes I saw a minor performance degradation, so I decided to apply the optimization only to the large batch case.
/ok to test
/ok to test 2de39ef
/ok to test 18fe20f
/merge