Using a wrapper `byte_array` to support both `uint8_t` and `int8_t`
Currently, libcuvs instantiates kernels for both uint8_t and int8_t, which bloats the binary size. Because the two types are so similar, and their values can be normalized from one to the other with a simple offset of 128 (int8_t's [-128, 127] maps onto uint8_t's [0, 255]), I propose that we combine the two types using a wrapper, cuvs::detail::byte_array.
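To make the normalization concrete, here is a minimal host-side C++ sketch. The helper names normalize, denormalize, and normalization_is_lossless are illustrative, not part of cuVS. It checks that the +128 offset maps every int8_t value onto uint8_t, round-trips losslessly, and preserves ordering:

```cpp
#include <cassert>
#include <cstdint>

// Map a signed byte onto the unsigned range: [-128, 127] -> [0, 255].
// Computed in int so that the intermediate sum cannot overflow.
inline uint8_t normalize(int8_t v) {
  return static_cast<uint8_t>(static_cast<int>(v) + 128);
}

// Inverse mapping: [0, 255] -> [-128, 127].
inline int8_t denormalize(uint8_t u) {
  return static_cast<int8_t>(static_cast<int>(u) - 128);
}

// True if every int8_t value survives a normalize/denormalize round trip and
// the mapping is strictly increasing (so comparisons behave the same).
inline bool normalization_is_lossless() {
  for (int v = -128; v < 127; ++v) {
    int8_t a = static_cast<int8_t>(v);
    int8_t b = static_cast<int8_t>(v + 1);
    if (denormalize(normalize(a)) != a) return false;
    if (!(normalize(a) < normalize(b))) return false;
  }
  // The loop stops before 127, so check the last value separately.
  return denormalize(normalize(int8_t{127})) == 127;
}
```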
The public API remains the same:
namespace cuvs {
void foo(raft::device_matrix_view<uint8_t, ...> data, ...);
void foo(raft::device_matrix_view<int8_t, ...> data, ...);
}  // namespace cuvs
The definition of cuvs::detail::byte_array, a struct that always normalizes int8_t values to uint8_t when reading and de-normalizes them when assigning:
#include <cstdint>

namespace cuvs::detail {
struct byte_array {
  void* data     = nullptr;
  bool is_signed = false;

  byte_array(void* ptr, bool signed_flag) : data(ptr), is_signed(signed_flag) {}

  // Proxy that references an element in the array
  struct byte {
    byte_array* parent = nullptr;
    int64_t idx        = -1;
    uint8_t value      = 0;  // used for detached proxies

    // Constructor for live proxy
    byte(byte_array& p, int64_t i) : parent(&p), idx(i) {}

    // Copy constructor: detached copy stores the current value
    byte(const byte& other) : parent(nullptr), idx(-1), value(static_cast<uint8_t>(other)) {}

    // Copy assignment: write the other element's value through this proxy, so that
    // `arr[i] = arr[j]` copies one element into another, as with a raw pointer.
    // No move operations are declared, so rvalue proxies use these copy operations
    // (a deleted move assignment would make `arr[i] = arr[j]` fail to compile).
    byte& operator=(const byte& other) { return *this = static_cast<uint8_t>(other); }

    // Conversion to uint8_t (reads normalize int8_t by adding 128)
    operator uint8_t() const {
      if (parent) {
        if (parent->is_signed) {
          int8_t val = reinterpret_cast<int8_t*>(parent->data)[idx];
          return static_cast<uint8_t>(static_cast<int16_t>(val) + 128);
        } else {
          return reinterpret_cast<uint8_t*>(parent->data)[idx];
        }
      } else {
        return value;  // return local value if detached
      }
    }

    // Assignment from uint8_t (writes de-normalize by subtracting 128)
    byte& operator=(uint8_t normalized_value) {
      if (parent) {
        if (parent->is_signed) {
          reinterpret_cast<int8_t*>(parent->data)[idx] =
            static_cast<int8_t>(static_cast<int16_t>(normalized_value) - 128);
        } else {
          reinterpret_cast<uint8_t*>(parent->data)[idx] = normalized_value;
        }
      } else {
        value = normalized_value;  // store in local value if detached
      }
      return *this;
    }
  };

  // Non-const index access: returns live proxy
  byte operator[](int64_t idx) { return byte(*this, idx); }

  // Const index access: returns immediate value
  uint8_t operator[](int64_t idx) const {
    if (is_signed) {
      int8_t val = reinterpret_cast<int8_t*>(data)[idx];
      return static_cast<uint8_t>(static_cast<int16_t>(val) + 128);
    } else {
      return reinterpret_cast<uint8_t*>(data)[idx];
    }
  }

  // Dereference (like *ptr)
  uint8_t operator*() const { return (*this)[0]; }
  byte operator*() { return byte(*this, 0); }

  // Pointer arithmetic
  byte_array operator+(int64_t offset) const {
    if (is_signed)
      return byte_array(static_cast<int8_t*>(data) + offset, true);
    else
      return byte_array(static_cast<uint8_t*>(data) + offset, false);
  }

  bool operator==(const byte_array& other) const { return data == other.data; }
  bool operator!=(const byte_array& other) const { return !(*this == other); }
};
}  // namespace cuvs::detail
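As a rough illustration of how the proxy behaves, the following is a condensed, host-only C++ sketch of the same idea. The name byte_array_sketch and its load/store helpers are hypothetical; pointer arithmetic, comparisons, detached proxies, and host/device annotations are left out, and copy assignment between proxies writes the value through so that `a[1] = a[2]` copies one element into another.

```cpp
#include <cassert>
#include <cstdint>

// Condensed, host-only sketch of the byte_array idea (not the cuVS type).
struct byte_array_sketch {
  void* data     = nullptr;
  bool is_signed = false;

  byte_array_sketch(void* ptr, bool signed_flag) : data(ptr), is_signed(signed_flag) {}

  // Read element idx, normalized to [0, 255].
  uint8_t load(int64_t idx) const {
    if (is_signed) {
      return static_cast<uint8_t>(static_cast<int>(static_cast<int8_t*>(data)[idx]) + 128);
    }
    return static_cast<uint8_t*>(data)[idx];
  }

  // Write a normalized value to element idx, de-normalizing for int8_t storage.
  void store(int64_t idx, uint8_t v) {
    if (is_signed) {
      static_cast<int8_t*>(data)[idx] = static_cast<int8_t>(static_cast<int>(v) - 128);
    } else {
      static_cast<uint8_t*>(data)[idx] = v;
    }
  }

  // Live proxy: reads and writes go straight to the referenced element.
  struct byte {
    byte_array_sketch* parent;
    int64_t idx;

    operator uint8_t() const { return parent->load(idx); }

    byte& operator=(uint8_t v) {
      parent->store(idx, v);
      return *this;
    }

    // Proxy-to-proxy assignment copies the element value, not the proxy.
    byte& operator=(const byte& other) { return *this = static_cast<uint8_t>(other); }
  };

  byte operator[](int64_t idx) { return byte{this, idx}; }
  uint8_t operator[](int64_t idx) const { return load(idx); }
};
```

With an int8_t buffer {-128, 0, 127}, reads return 0, 128, and 255, and writing 200 through the proxy stores 72 in the underlying int8_t element.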
This would change the kernel definition to:
template <typename T,
          typename DataT = std::conditional_t<std::is_same_v<T, uint8_t> || std::is_same_v<T, int8_t>,
                                              cuvs::detail::byte_array,
                                              T*>>
__global__ void kernel(DataT data, ...) {
  ...
}
And kernel instantiation would now look like:
// T is still uint8_t or int8_t at this point
template <typename T>
void launch_kernel(raft::device_matrix_view<T, ...> data, ...) {
  if constexpr (std::is_same_v<T, uint8_t> || std::is_same_v<T, int8_t>) {
    // Wrap the user's buffer; no copy is made
    cuvs::detail::byte_array data_b{data.data_handle(), std::is_same_v<T, int8_t>};
    // One instantiation shared by both byte types
    kernel<uint8_t><<<...>>>(data_b, ...);
  } else {
    // Normal instantiation: DataT defaults to T*
    kernel<T><<<...>>>(data.data_handle(), ...);
  }
}
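The dispatch can be mocked up on the host to check that the default template argument and the `if constexpr` branch resolve as intended. In this sketch, byte_view is a hypothetical read-only stand-in for byte_array, and sum_kernel/launch_sum are hypothetical stand-ins for the CUDA kernel and its launcher:

```cpp
#include <cassert>
#include <cstdint>
#include <type_traits>

// Hypothetical read-only stand-in for cuvs::detail::byte_array.
struct byte_view {
  const void* data;
  bool is_signed;
  uint8_t operator[](int64_t i) const {
    return is_signed
             ? static_cast<uint8_t>(static_cast<int>(static_cast<const int8_t*>(data)[i]) + 128)
             : static_cast<const uint8_t*>(data)[i];
  }
};

template <typename T>
constexpr bool is_byte_v = std::is_same_v<T, uint8_t> || std::is_same_v<T, int8_t>;

// Host stand-in for the __global__ kernel: DataT defaults to byte_view for byte
// types and to a plain pointer otherwise, mirroring the proposed signature.
template <typename T, typename DataT = std::conditional_t<is_byte_v<T>, byte_view, const T*>>
double sum_kernel(DataT data, int64_t n) {
  double acc = 0;
  for (int64_t i = 0; i < n; ++i) acc += data[i];
  return acc;
}

// Mirrors launch_kernel: both byte types funnel into sum_kernel<uint8_t>.
template <typename T>
double launch_sum(const T* ptr, int64_t n) {
  if constexpr (is_byte_v<T>) {
    byte_view view{ptr, std::is_same_v<T, int8_t>};
    return sum_kernel<uint8_t>(view, n);
  } else {
    return sum_kernel<T>(ptr, n);
  }
}

static_assert(std::is_same_v<std::conditional_t<is_byte_v<int8_t>, byte_view, const int8_t*>, byte_view>);
```

Both int8_t and uint8_t inputs end up in the single sum_kernel<uint8_t> instantiation, while other types keep their own pointer-based instantiation. Since DataT can also be deduced from the argument, the default mainly documents intent.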
This approach lets us instantiate each kernel once for both byte types and do all arithmetic in uint8_t without modifying user data. At search time we still do the arithmetic in uint8_t, and we can write data back as int8_t if that is what the user requests.
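One property worth checking before moving all arithmetic to uint8_t is whether each metric is unaffected by the +128 shift. For squared L2 distance the offset cancels in every difference, whereas an inner product picks up extra terms, so metrics that are not translation invariant would need a correction. A minimal host-side C++ check (all helper names are illustrative):

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

inline uint8_t to_unsigned(int8_t v) { return static_cast<uint8_t>(static_cast<int>(v) + 128); }

// Squared L2 distance computed directly on int8_t values.
inline int64_t l2_signed(const int8_t* a, const int8_t* b, std::size_t n) {
  int64_t acc = 0;
  for (std::size_t i = 0; i < n; ++i) {
    int64_t d = static_cast<int64_t>(a[i]) - b[i];
    acc += d * d;
  }
  return acc;
}

// Same distance on the normalized uint8_t values: the +128 cancels in each difference.
inline int64_t l2_normalized(const int8_t* a, const int8_t* b, std::size_t n) {
  int64_t acc = 0;
  for (std::size_t i = 0; i < n; ++i) {
    int64_t d = static_cast<int64_t>(to_unsigned(a[i])) - to_unsigned(b[i]);
    acc += d * d;
  }
  return acc;
}

// Inner products do not cancel:
// sum((a+128)(b+128)) = sum(ab) + 128 * sum(a + b) + n * 128^2.
inline int64_t ip_signed(const int8_t* a, const int8_t* b, std::size_t n) {
  int64_t acc = 0;
  for (std::size_t i = 0; i < n; ++i) acc += static_cast<int64_t>(a[i]) * b[i];
  return acc;
}

inline int64_t ip_normalized(const int8_t* a, const int8_t* b, std::size_t n) {
  int64_t acc = 0;
  for (std::size_t i = 0; i < n; ++i)
    acc += static_cast<int64_t>(to_unsigned(a[i])) * to_unsigned(b[i]);
  return acc;
}
```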
This is a good proposal and I think it would solve the double instantiation that we are facing with uint8 and int8. I have a few remarks:
- Can we do without the byte struct to simplify things for users? I think not, but can you state explicitly what the main issue is?
- How can we preserve constness of the underlying data in case we get a const pointer:
raft::device_matrix_view<const int8_t, ....>
As an additional note for later, some algorithms in cuvs take advantage of vectorization where multiple values are loaded at the same time (raft/util/vectorized.cuh and cuvs/src/neighbors/detail/device_common.hpp). So we will have to make sure that this does not bypass the live proxy.
Thanks for your review @lowener!
Can we do without the byte struct to simplify things for users? I think not, but can you state explicitly what the main issue is?
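The byte type is an internal proxy that is invisible to users. It solves the dangling-reference problem: for int8_t data, the normalized uint8_t value only exists as a temporary computed on the fly, so there is no uint8_t object in memory that operator[] could return a reference to. Without a reference (or a proxy standing in for one), we can't support writes back to the data. In the example below, the int8_t branch would have to return a reference to that temporary: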
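uint8_t& operator[](int64_t idx) {
  if (is_signed) {
    int8_t val = reinterpret_cast<int8_t*>(data)[idx];
    // Does not compile: the normalized value is a temporary, and a uint8_t&
    // cannot bind to it (no uint8_t object with that value exists in the buffer)
    return static_cast<uint8_t>(static_cast<int16_t>(val) + 128);
  } else {
    // Fine: the stored byte already is the uint8_t value
    return reinterpret_cast<uint8_t*>(data)[idx];
  }
}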
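The underlying reason is that the normalized value never shares a bit pattern with the stored int8_t byte, because adding 128 modulo 256 always changes the top bit. A small host-side C++ check (helper names are illustrative) shows that a uint8_t& into the int8_t buffer would observe the raw bits rather than the normalized value, which is why the proxy keeps a (pointer, index) pair and converts on each access:

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// What a uint8_t& aliasing the int8_t buffer would observe: the raw bit pattern.
inline uint8_t raw_bits(int8_t v) {
  uint8_t u;
  std::memcpy(&u, &v, 1);
  return u;
}

// What kernels expect to see: the value shifted into [0, 255].
inline uint8_t normalized(int8_t v) {
  return static_cast<uint8_t>(static_cast<int>(v) + 128);
}

// True if no int8_t value has a normalized form equal to its stored byte, i.e. a
// plain reference into the buffer can never stand in for the normalized value.
inline bool never_aliases() {
  for (int v = -128; v <= 127; ++v) {
    int8_t s = static_cast<int8_t>(v);
    if (raw_bits(s) == normalized(s)) return false;
  }
  return true;
}
```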
How can we preserve constness of the underlying data in case we get a const pointer: raft::device_matrix_view<const int8_t, ....>
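This one is largely supported already: declare a const byte_array b{mds.data_handle(), ...}, and the const-qualified operator[] and operator* return read-only values (copies rather than references), so nothing can be written through b. Two caveats: the constructor takes a void*, so a const data_handle() needs a const_cast or an additional constructor taking const void*; and operator+ returns a non-const byte_array, which would drop the read-only guarantee.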
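For the const case, one option is a separate read-only type. The sketch below is hypothetical (the name const_byte_array is not part of cuVS): it stores a const void*, so a const data_handle() converts implicitly, element access returns values, and operator+ keeps the result read-only.

```cpp
#include <cassert>
#include <cstdint>
#include <type_traits>
#include <utility>

// Hypothetical read-only counterpart of byte_array for inputs such as
// raft::device_matrix_view<const int8_t, ...>.
struct const_byte_array {
  const void* data = nullptr;
  bool is_signed   = false;

  const_byte_array(const void* ptr, bool signed_flag) : data(ptr), is_signed(signed_flag) {}

  // Reads return values, so there is nothing to write through.
  uint8_t operator[](int64_t idx) const {
    if (is_signed) {
      int8_t v = static_cast<const int8_t*>(data)[idx];
      return static_cast<uint8_t>(static_cast<int>(v) + 128);
    }
    return static_cast<const uint8_t*>(data)[idx];
  }

  // Pointer arithmetic keeps the result read-only.
  const_byte_array operator+(int64_t offset) const {
    return is_signed ? const_byte_array(static_cast<const int8_t*>(data) + offset, true)
                     : const_byte_array(static_cast<const uint8_t*>(data) + offset, false);
  }
};

// Element access yields a prvalue uint8_t, which cannot be assigned to.
static_assert(!std::is_assignable_v<decltype(std::declval<const_byte_array&>()[0]), uint8_t>);
```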
As an additional note for later, some algorithms in cuvs take advantage of vectorization where multiple values are loaded at the same time (raft/util/vectorized.cuh and cuvs/src/neighbors/detail/device_common.hpp). So we will have to make sure that this does not bypass the live proxy.
Thanks for pointing this out. I looked at the code, and we should be able to support vectorized loads with some minor changes.
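One possible shape of those changes, assuming values are loaded four bytes at a time into a 32-bit word: because adding 128 modulo 256 only flips a byte's top bit, all four int8_t lanes can be normalized with a single XOR instead of going through the proxy per element. A host-side C++ sketch with a scalar reference for comparison (function names are illustrative):

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Flip the top bit of every byte lane: equivalent to adding 128 (mod 256) per lane.
inline uint32_t normalize_packed4(uint32_t packed_int8) {
  return packed_int8 ^ 0x80808080u;
}

// Scalar reference: normalize each byte lane individually.
inline uint32_t normalize_packed4_scalar(uint32_t packed_int8) {
  uint8_t lanes[4];
  std::memcpy(lanes, &packed_int8, 4);
  for (uint8_t& lane : lanes) {
    int8_t v;
    std::memcpy(&v, &lane, 1);
    lane = static_cast<uint8_t>(static_cast<int>(v) + 128);
  }
  uint32_t out;
  std::memcpy(&out, lanes, 4);
  return out;
}
```

The XOR also works in reverse, so the same trick could de-normalize a packed word before a vectorized store.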
Thanks for the proposal.
- This looks good to me, and it won't affect any of the kernels in my PR https://github.com/rapidsai/cuvs/pull/1099, because that PR only allows uint8 inputs for binary IVF-Flat.
- If my understanding is right, for unsigned types this would be non-owning and would not create a copy of the dataset. For signed types, however, would we preprocess the entire dataset beforehand, essentially doubling the memory consumption? Or is it still non-owning, with the cast/conversion done for every fetched byte?
Thanks for the review @tarang-jain! There is no preprocessing, and this type is always non-owning. The signed-to-unsigned conversion happens only inside kernels, at the moment the data is accessed. For most kernels that should happen only once per element, since we generally read each value from global memory once and then keep it in shared memory or registers.
OK, so the kernels will take byte_array as input instead of uint8_t* and int8_t*?
Yep, that's exactly right. I should note that I haven't added the `__host__ __device__` annotations yet, but I will when I implement this.
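For the annotations, a common pattern is a macro that expands to `__host__ __device__` only when compiling with nvcc, so the header still builds with a plain host compiler. The macro name BYTE_ARRAY_HD and the struct annotated_example below are hypothetical, and the project may already provide an equivalent macro that should be preferred:

```cpp
#include <cassert>
#include <cstdint>

// nvcc defines __CUDACC__; plain host compilers do not.
#if defined(__CUDACC__)
#define BYTE_ARRAY_HD __host__ __device__
#else
#define BYTE_ARRAY_HD
#endif

// Example of an annotated member function, as byte_array's operators would be.
struct annotated_example {
  const uint8_t* data;
  BYTE_ARRAY_HD uint8_t operator[](int64_t idx) const { return data[idx]; }
};
```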