Using a wrapper `byte_array` to support both `uint8_t` and `int8_t`

Open divyegala opened this issue 2 months ago • 6 comments

Currently, libcuvs instantiates kernels for both `uint8_t` and `int8_t`, which causes binary size bloat. The two types are very similar, and values of one are easily normalized into the range of the other. I therefore propose that we combine the two types using a wrapper, `cuvs::detail::byte_array`.

The public API remains the same:

```cpp
namespace cuvs {

void foo(raft::device_matrix_view<uint8_t, ...>, ...);

void foo(raft::device_matrix_view<int8_t, ...>, ...);

}
```

The definition of `cuvs::detail::byte_array`, a struct that always normalizes `int8_t` values to `uint8_t` when reading and de-normalizes them when assigning:

```cpp
#include <cstdint>

namespace cuvs::detail {

struct byte_array {
  void* data = nullptr;
  bool is_signed = false;

  byte_array(void* ptr, bool signed_flag)
      : data(ptr), is_signed(signed_flag) {}

  // Proxy that references an element in the array
  struct byte {
    byte_array* parent = nullptr;
    int64_t idx = -1;
    uint8_t value = 0;   // used for detached proxies

    // Constructor for live proxy
    byte(byte_array& p, int64_t i) : parent(&p), idx(i) {}

    // Copy constructor: detached copy stores the current value
    byte(const byte& other)
        : parent(nullptr), idx(-1), value(static_cast<uint8_t>(other)) {}

    // Copy assignment: assigns the other element's value through this proxy,
    // so a[i] = a[j] writes to the data (or to the local value if detached)
    byte& operator=(const byte& other) {
        return *this = static_cast<uint8_t>(other);
    }

    // Deleted move constructor: returning a live proxy by value relies on
    // C++17 guaranteed copy elision. No move assignment is declared, so
    // rvalue proxies use the write-through copy assignment above.
    byte(byte&& other) = delete;

    // Conversion to uint8_t
    operator uint8_t() const {
        if (parent) {
            if (parent->is_signed) {
                int8_t val = reinterpret_cast<int8_t*>(parent->data)[idx];
                return static_cast<uint8_t>(static_cast<int16_t>(val) + 128);
            } else {
                return reinterpret_cast<uint8_t*>(parent->data)[idx];
            }
        } else {
            return value;  // return local value if detached
        }
    }

    // Assignment from uint8_t
    byte& operator=(uint8_t normalized_value) {
        if (parent) {
            if (parent->is_signed) {
                reinterpret_cast<int8_t*>(parent->data)[idx] =
                    static_cast<int8_t>(static_cast<int16_t>(normalized_value) - 128);
            } else {
                reinterpret_cast<uint8_t*>(parent->data)[idx] = normalized_value;
            }
        } else {
            value = normalized_value;  // store in local value if detached
        }
        return *this;
    }
  };

  // Non-const index access: returns live proxy
  byte operator[](int64_t idx) { return byte(*this, idx); }

  // Const index access: returns immediate value
  uint8_t operator[](int64_t idx) const {
    if (is_signed) {
      int8_t val = reinterpret_cast<int8_t*>(data)[idx];
      return static_cast<uint8_t>(static_cast<int16_t>(val) + 128);
    } else {
      return reinterpret_cast<uint8_t*>(data)[idx];
    }
  }

  // Dereference (like *ptr)
  uint8_t operator*() const { return (*this)[0]; }
  byte operator*() { return byte(*this, 0); }

  // Pointer arithmetic
  byte_array operator+(int64_t offset) const {
    if (is_signed)
      return byte_array(static_cast<int8_t*>(data) + offset, true);
    else
      return byte_array(static_cast<uint8_t*>(data) + offset, false);
  }

  bool operator==(const byte_array& other) const { return data == other.data; }
  bool operator!=(const byte_array& other) const { return !(*this == other); }
};

} // namespace cuvs::detail
```
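The read/write conversions above are a fixed shift by 128. As a host-only sketch, the same arithmetic can be pulled out into two hypothetical helpers, `normalize` and `denormalize` (names not from the proposal), and checked exhaustively: the mapping round-trips for every `int8_t` value and preserves order.

```cpp
#include <cstdint>

// Hypothetical helpers (not part of the proposal) holding the same arithmetic
// byte_array uses: reads shift int8_t [-128, 127] up to uint8_t [0, 255],
// writes shift back down.
inline uint8_t normalize(int8_t v) {
  return static_cast<uint8_t>(static_cast<int16_t>(v) + 128);
}

inline int8_t denormalize(uint8_t u) {
  return static_cast<int8_t>(static_cast<int16_t>(u) - 128);
}

// Exhaustively checks that the mapping round-trips for every int8_t value
// and is strictly increasing, so comparisons agree in both domains.
inline bool mapping_round_trips_and_preserves_order() {
  for (int v = -128; v <= 127; ++v) {
    const int8_t s = static_cast<int8_t>(v);
    if (denormalize(normalize(s)) != s) return false;
    if (v > -128 && normalize(static_cast<int8_t>(v - 1)) >= normalize(s)) return false;
  }
  return true;
}
```

For example, `normalize(-128)` is 0, `normalize(0)` is 128, and `normalize(127)` is 255.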

This would change the kernel definition to:

```cpp
template <typename T,
          typename DataT = std::conditional_t<std::is_same_v<T, uint8_t> || std::is_same_v<T, int8_t>,
                                              cuvs::detail::byte_array,
                                              T*>>
__global__ void kernel(DataT data, ...) {
  ...
}
```
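The default `DataT` argument is an ordinary type trait, so the selection can be checked at compile time on the host. A sketch with a hypothetical alias `data_type_t` and a stripped-down stand-in for `byte_array`:

```cpp
#include <cstdint>
#include <type_traits>

namespace sketch {

// Minimal stand-in for cuvs::detail::byte_array; only its type matters here.
struct byte_array {
  void* data = nullptr;
  bool is_signed = false;
};

// Same rule as the kernel's default DataT argument: 8-bit integer element
// types select byte_array, all other element types select a plain T*.
template <typename T>
using data_type_t =
    std::conditional_t<std::is_same_v<T, uint8_t> || std::is_same_v<T, int8_t>, byte_array, T*>;

static_assert(std::is_same_v<data_type_t<uint8_t>, byte_array>);
static_assert(std::is_same_v<data_type_t<int8_t>, byte_array>);
static_assert(std::is_same_v<data_type_t<float>, float*>);
// A const-qualified element type matches neither is_same_v check,
// so it falls through to the pointer branch.
static_assert(std::is_same_v<data_type_t<const int8_t>, const int8_t*>);

}  // namespace sketch
```

With the rule as written, a `const int8_t` element type is not routed to `byte_array`; applying `std::remove_cv_t<T>` inside the checks would be one possible way to change that, if desired.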

And kernel instantiation would now look like:

```cpp
// T is still uint8_t or int8_t at this point
template <typename T>
void launch_kernel(raft::device_matrix_view<T, ...> data, ...) {
  if constexpr (std::is_same_v<T, uint8_t> || std::is_same_v<T, int8_t>) {
    // Wrap the raw pointer; the flag records whether the bytes are signed
    cuvs::detail::byte_array data_b{data.data_handle(), std::is_same_v<T, int8_t>};
    // Both byte types share this single instantiation
    kernel<uint8_t><<<...>>>(data_b, ...);
  } else {
    // Normal instantiation, DataT defaults to T*
    kernel<T><<<...>>>(data.data_handle(), ...);
  }
}
```
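A host-only analogue of `launch_kernel` may make the dispatch easier to follow. This sketch assumes a reduced `byte_array` stand-in and replaces the CUDA kernels with hypothetical functions `sum_bytes` and `sum_plain`; both 8-bit types go through the single non-template `sum_bytes`, so its body is compiled only once.

```cpp
#include <cstdint>
#include <type_traits>

namespace sketch {

// Reduced stand-in for cuvs::detail::byte_array: read-only indexing only.
struct byte_array {
  void* data = nullptr;
  bool is_signed = false;

  uint8_t operator[](int64_t i) const {
    return is_signed
               ? static_cast<uint8_t>(static_cast<int16_t>(static_cast<const int8_t*>(data)[i]) + 128)
               : static_cast<const uint8_t*>(data)[i];
  }
};

// Stand-in for the byte kernel: one non-template body serves int8_t and uint8_t.
inline int64_t sum_bytes(byte_array d, int64_t n) {
  int64_t s = 0;
  for (int64_t i = 0; i < n; ++i) s += d[i];
  return s;
}

// Stand-in for the ordinary kernel, still instantiated per element type.
template <typename T>
int64_t sum_plain(const T* d, int64_t n) {
  int64_t s = 0;
  for (int64_t i = 0; i < n; ++i) s += static_cast<int64_t>(d[i]);
  return s;
}

// Host analogue of launch_kernel: 8-bit inputs are wrapped, others pass through.
template <typename T>
int64_t launch(T* ptr, int64_t n) {
  if constexpr (std::is_same_v<T, uint8_t> || std::is_same_v<T, int8_t>) {
    const byte_array b{ptr, std::is_same_v<T, int8_t>};
    return sum_bytes(b, n);
  } else {
    return sum_plain(ptr, n);
  }
}

}  // namespace sketch
```

An `int8_t` input `{-128, 0, 127}` and a `uint8_t` input `{0, 128, 255}` produce the same normalized sum, since they map to the same bytes.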

This approach lets us instantiate the byte kernels only once and do all arithmetic in `uint8_t` without modifying user data. At search time, we still do arithmetic in `uint8_t`, and we can write data back as `int8_t` if that is what the user requests.
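Doing arithmetic in `uint8_t` relies on the +128 shift not changing the result. For a difference-based metric such as squared L2 distance, the offset cancels term by term. A small host-only check, with hypothetical helpers `to_u8`, `l2_sq_int8`, and `l2_sq_normalized`. It covers squared L2 only. Metrics that are not shift-invariant, a plain dot product for instance, would need separate handling.

```cpp
#include <cstdint>

// Hypothetical helper: same +128 shift that byte_array applies on reads.
inline uint8_t to_u8(int8_t v) {
  return static_cast<uint8_t>(static_cast<int16_t>(v) + 128);
}

// Squared L2 distance computed directly on the user's int8_t values.
inline int32_t l2_sq_int8(const int8_t* x, const int8_t* y, int n) {
  int32_t s = 0;
  for (int i = 0; i < n; ++i) {
    const int32_t d = static_cast<int32_t>(x[i]) - static_cast<int32_t>(y[i]);
    s += d * d;
  }
  return s;
}

// The same distance computed only on normalized uint8_t values. The +128
// shift cancels in every difference, so the result is identical.
inline int32_t l2_sq_normalized(const int8_t* x, const int8_t* y, int n) {
  int32_t s = 0;
  for (int i = 0; i < n; ++i) {
    const int32_t d = static_cast<int32_t>(to_u8(x[i])) - static_cast<int32_t>(to_u8(y[i]));
    s += d * d;
  }
  return s;
}
```

For `x = {-3, 5}` and `y = {4, -2}`, both functions return 98.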

divyegala, Oct 09 '25 19:10

This is a good proposal and I think it would solve the double instantiation that we are facing with uint8 and int8. I have a few remarks:

  • Can we do without the `byte` struct to simplify things for users? I think not, but can you explicitly say what the main issue is?
  • How can we preserve constness of the underlying data in case we get a const pointer, e.g. `raft::device_matrix_view<const int8_t, ...>`?

As an additional note for later, some algorithms in cuvs take advantage of vectorization where multiple values are loaded at the same time (raft/util/vectorized.cuh and cuvs/src/neighbors/detail/device_common.hpp). So we will have to make sure that this does not bypass the live proxy.

lowener, Oct 16 '25 11:10

Thanks for your review @lowener!

> Can we do without the `byte` struct to simplify things for users? I think not, but can you explicitly say what the main issue is?

The `byte` type is an internal proxy that is invisible to the user. It solves the dangling-reference problem: without it, `operator[]` cannot return a reference into the data, which means we can't support writes back to it. In this example, in the `int8_t` case, `val` is a local variable and the converted value is a temporary, so there is nothing valid to return a reference to:

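```cpp
  // Counterexample (does not work): in the int8_t branch the normalized value
  // is a temporary, so there is no uint8_t in memory for the reference to bind to.
  uint8_t& operator[](int64_t idx) {
    if (is_signed) {
      int8_t val = reinterpret_cast<int8_t*>(data)[idx];
      return static_cast<uint8_t>(static_cast<int16_t>(val) + 128);
    } else {
      return reinterpret_cast<uint8_t*>(data)[idx];
    }
  }
```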
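For contrast, here is the proxy idea in its smallest form, as a host-only sketch that assumes signed storage only (`signed_byte_ref` and `signed_bytes` are hypothetical names). The proxy stores a pointer into the buffer, so reads normalize on the fly and writes de-normalize back into the caller's `int8_t` data. Like `std::vector<bool>::reference`, it assigns the value, not the pointer, when one proxy is assigned to another.

```cpp
#include <cstdint>

namespace sketch {

// Hypothetical minimal proxy for signed storage: it keeps a pointer into the
// user's int8_t buffer rather than a reference to a converted temporary.
struct signed_byte_ref {
  int8_t* p;

  // Read: normalize the stored int8_t to uint8_t on the fly.
  operator uint8_t() const {
    return static_cast<uint8_t>(static_cast<int16_t>(*p) + 128);
  }

  // Write: de-normalize and store back into the int8_t buffer.
  signed_byte_ref& operator=(uint8_t u) {
    *p = static_cast<int8_t>(static_cast<int16_t>(u) - 128);
    return *this;
  }

  // Proxy-to-proxy assignment copies the value, not the pointer.
  signed_byte_ref& operator=(const signed_byte_ref& other) {
    return *this = static_cast<uint8_t>(other);
  }
};

struct signed_bytes {
  int8_t* data;
  signed_byte_ref operator[](int64_t i) { return signed_byte_ref{data + i}; }
};

}  // namespace sketch
```

Writing the normalized value 200 through the proxy stores 72 (200 - 128) in the underlying `int8_t` buffer.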

> How can we preserve constness of the underlying data in case we get a const pointer, e.g. `raft::device_matrix_view<const int8_t, ...>`?

This one is already supported! Just declare a `const byte_array b{mds.data_handle(), ...}`. The struct has const-qualified operators that return read-only values instead of live proxies.
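The const behavior can be checked with a host-only stand-in that keeps the two `operator[]` overloads and reduces `byte` to the fields needed to tell them apart (a sketch, not the proposed code). One caveat the sketch does not address: the proposed constructor takes `void*`, and a `const int8_t*` does not convert to `void*` implicitly. Passing `data_handle()` from a `device_matrix_view<const int8_t, ...>` would therefore presumably need a `const_cast` or an extra `const void*` overload.

```cpp
#include <cstdint>
#include <type_traits>
#include <utility>

namespace sketch {

// Stand-in for byte_array with the two operator[] overloads from the proposal.
struct byte_array {
  void* data = nullptr;
  bool is_signed = false;

  struct byte {  // reduced live proxy
    byte_array* parent;
    int64_t idx;
  };

  // Non-const access: returns a proxy that could write through.
  byte operator[](int64_t i) { return byte{this, i}; }

  // Const access: returns the normalized value itself, so nothing can be written.
  uint8_t operator[](int64_t i) const {
    return is_signed
               ? static_cast<uint8_t>(static_cast<int16_t>(static_cast<const int8_t*>(data)[i]) + 128)
               : static_cast<const uint8_t*>(data)[i];
  }
};

// On a const byte_array, overload resolution selects the const operator[].
static_assert(std::is_same_v<decltype(std::declval<const byte_array&>()[0]), uint8_t>);
static_assert(std::is_same_v<decltype(std::declval<byte_array&>()[0]), byte_array::byte>);

}  // namespace sketch
```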

> As an additional note for later, some algorithms in cuvs take advantage of vectorization where multiple values are loaded at the same time (raft/util/vectorized.cuh and cuvs/src/neighbors/detail/device_common.hpp). So we will have to make sure that this does not bypass the live proxy.

Thanks for pointing this out. I looked at the code, and we should be able to support vectorized loads with some minor changes.
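One observation that suggests the changes could indeed be small (an assumption about the approach, not the plan stated above): adding 128 modulo 256 only flips a byte's top bit, so normalizing `int8_t` data is the same as XOR with `0x80`. On a packed word, one XOR normalizes every byte lane at once with no carries between lanes. The helper names below are hypothetical; unsigned data would simply skip the XOR.

```cpp
#include <cstdint>
#include <cstring>

// Normalizes four packed int8_t values at once: XOR with 0x80 in each byte
// lane equals the per-element +128 shift, and XOR never carries across lanes.
inline uint32_t normalize_packed4(uint32_t packed_int8) {
  return packed_int8 ^ 0x80808080u;
}

// Reference per-element version, packed into a word with the same byte layout.
inline uint32_t normalize_bytewise4(const int8_t in[4]) {
  uint8_t out[4];
  for (int k = 0; k < 4; ++k) {
    out[k] = static_cast<uint8_t>(static_cast<int16_t>(in[k]) + 128);
  }
  uint32_t word;
  std::memcpy(&word, out, sizeof word);
  return word;
}
```

Because both versions use `memcpy` with the same byte layout, the comparison holds regardless of endianness.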

divyegala, Oct 16 '25 16:10

Thanks for the proposal.

  1. This looks good to me, and it won't affect any of the kernels in my PR https://github.com/rapidsai/cuvs/pull/1099, because that PR only allows uint8 inputs for binary IVF-Flat.
  2. If my understanding is right, for unsigned types this would be non-owning and would not create a copy of the dataset. However, for signed types, would we be preprocessing the entire dataset beforehand, essentially doubling memory consumption? Or is it still non-owning, doing the cast/conversion for every fetched byte?

tarang-jain, Oct 17 '25 20:10

Thanks for the review, @tarang-jain! There is no preprocessing, and this type is always non-owning. The signed-to-unsigned conversion happens only inside kernels, at the moment the data is accessed. For most kernels that happens only once per value, since we generally read from global memory once and keep the value in shared memory or registers.
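A host-only sketch of the non-owning behavior, assuming a reduced `byte_array` stand-in and a hypothetical `load_tile` step standing in for the global-to-shared/register load. The wrapper holds only a pointer and a flag, so a change to the user's buffer after the wrapper is created is visible on the next read. Each element is converted once, when it is copied into the local tile.

```cpp
#include <cstdint>

namespace sketch {

// Stand-in for byte_array: a pointer plus a flag, never a copy of the data.
struct byte_array {
  void* data = nullptr;
  bool is_signed = false;

  // Conversion happens here, at the moment an element is read.
  uint8_t operator[](int64_t i) const {
    return is_signed
               ? static_cast<uint8_t>(static_cast<int16_t>(static_cast<const int8_t*>(data)[i]) + 128)
               : static_cast<const uint8_t*>(data)[i];
  }
};

// Hypothetical load step: each element is read once from the source
// (standing in for global memory) and the converted value is kept in a
// local tile (standing in for shared memory or registers).
inline void load_tile(const byte_array& src, uint8_t* tile, int64_t n) {
  for (int64_t i = 0; i < n; ++i) tile[i] = src[i];
}

}  // namespace sketch
```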

divyegala, Oct 17 '25 20:10

OK, so the kernels will take `byte_array` as input instead of `uint8_t*` and `int8_t*`?

tarang-jain, Oct 17 '25 20:10

Yep, that's exactly right. I should note that I haven't added the `__host__ __device__` annotations yet, but I will when I implement this.
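A common pattern for such annotations, shown here as a sketch with a hypothetical macro name `BYTE_ARRAY_HD` and hypothetical helpers: the macro expands to `__host__ __device__` when `__CUDACC__` is defined (nvcc, or clang in CUDA mode) and to nothing otherwise, so the same code also builds with a plain C++ compiler.

```cpp
#include <cstdint>

// Hypothetical macro: annotate only when compiling as CUDA.
#if defined(__CUDACC__)
#define BYTE_ARRAY_HD __host__ __device__
#else
#define BYTE_ARRAY_HD
#endif

// Read one element, normalizing int8_t storage to uint8_t.
BYTE_ARRAY_HD inline uint8_t read_normalized(const void* data, bool is_signed, int64_t i) {
  return is_signed
             ? static_cast<uint8_t>(static_cast<int16_t>(static_cast<const int8_t*>(data)[i]) + 128)
             : static_cast<const uint8_t*>(data)[i];
}

// Write one normalized value, de-normalizing it for int8_t storage.
BYTE_ARRAY_HD inline void write_normalized(void* data, bool is_signed, int64_t i, uint8_t u) {
  if (is_signed) {
    static_cast<int8_t*>(data)[i] = static_cast<int8_t>(static_cast<int16_t>(u) - 128);
  } else {
    static_cast<uint8_t*>(data)[i] = u;
  }
}
```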

divyegala, Oct 17 '25 20:10