hashbrown

get_many_mut with variable number of keys

Open · Ten0 opened this issue 2 years ago · 18 comments

get_many_mut on slices (https://github.com/rust-lang/rust/pull/83608) does not need to handle a variable number of keys (although it might be nice) because, in the worst-case scenario, you can sort your input array of indexes and iterate over it with successive split_at_mut calls.
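The slice workaround described above can be sketched roughly as follows. This is a minimal illustration assuming plain `usize` indices; `get_many_mut_sorted` is a hypothetical helper, not the API from the linked PR. Note that the references come back in ascending index order, not in query order.

```rust
/// Hypothetical helper (not a std or hashbrown API): returns mutable references
/// to the elements at `indices`, in ascending index order, or `None` if an index
/// is out of bounds or repeated.
fn get_many_mut_sorted<'a, T>(slice: &'a mut [T], indices: &[usize]) -> Option<Vec<&'a mut T>> {
    // Sort a copy of the indices so the slice can be walked front to back.
    let mut sorted = indices.to_vec();
    sorted.sort_unstable();
    // Reject duplicates (they would alias) and out-of-bounds indices.
    if sorted.windows(2).any(|w| w[0] == w[1]) {
        return None;
    }
    if sorted.last().map_or(false, |&last| last >= slice.len()) {
        return None;
    }
    let mut out = Vec::with_capacity(sorted.len());
    let mut rest: &'a mut [T] = slice;
    let mut offset = 0; // index in `slice` of `rest[0]`
    for idx in sorted {
        // Move `rest` out (leaving an empty slice) so the split keeps lifetime 'a.
        let (_, tail) = std::mem::take(&mut rest).split_at_mut(idx - offset);
        let (elem, tail) = tail.split_first_mut()?;
        out.push(elem);
        rest = tail;
        offset = idx + 1;
    }
    Some(out)
}
```

A caller that needs the results in query order has to carry the sort permutation alongside, which is part of why a HashMap-level API is attractive.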

In this case, however, that's not an option, as HashMaps don't have an equivalent of split_at_mut. So it may be useful to be able to look up a variable (somewhat large) number of distinct keys at the same time in a HashMap and get mutable references back.

This is needed in particular when we have storage holding many entries and want to run a round of updates on distinct subsets of keys in parallel.

It would therefore be useful for get_many_mut to follow a model similar to the one described at https://github.com/rust-lang/rust/pull/83608#pullrequestreview-966536956 and to provide an interface that supports this; in fact, it's something I would use right now.

Ten0, May 09 '22

I'd also really like to have this. I'm in a situation right now where I need an arbitrary number of mutable references, and the way I'm doing it for now isn't great.

Update: I've written some code that allows you to do this; I can share it if anyone needs it.

What42Pizza, Feb 15 '23

The problem with returning a variable number of mutable references is that it requires returning a Vec<&mut T>, which requires a dynamic allocation. It might end up being faster to just keep the keys around and do a hash table lookup every time the value is needed.
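The "keep the keys and look up again" alternative can look like this minimal sketch. It uses std's `HashMap` for self-containedness (hashbrown's `HashMap` offers the same `get_mut`); `update_each` is a hypothetical helper. Each `&mut` ends before the next lookup, so no uniqueness check is needed and repeated keys are simply updated twice.

```rust
use std::collections::HashMap;

/// Hypothetical helper: applies `f` to the value of each key by doing a fresh
/// `get_mut` lookup per key, instead of holding many `&mut V` at once.
/// Returns how many keys were found in the map.
fn update_each<F: FnMut(&mut i64)>(map: &mut HashMap<String, i64>, keys: &[&str], mut f: F) -> usize {
    let mut found = 0;
    for key in keys {
        // The mutable borrow of `map` lasts only for this iteration.
        if let Some(v) = map.get_mut(*key) {
            f(v);
            found += 1;
        }
    }
    found
}
```

The trade-off is one hash lookup per access instead of one per key up front, which is cheap for a single pass but cannot hand the references to other threads at once.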

Amanieu, Feb 22 '23

It might end up being faster to just keep the keys around and do a hash table lookup every time the value is needed.

It seems that this should be for the end user to decide, no? In particular, if the follow-up is e.g. parallel processing with rayon, it quickly becomes worth it.
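To illustrate the parallel follow-up without rayon, here is a minimal sketch using `std::thread::scope`. It swaps in a different technique from the one being proposed: a single `iter_mut` pass, filtered to the wanted keys, already yields disjoint `&mut V` safely, but it visits every entry of the map (O(map.len()) rather than O(keys.len())). `double_in_parallel` is a hypothetical helper.

```rust
use std::collections::{HashMap, HashSet};
use std::thread;

/// Hypothetical helper: doubles the value of every key in `keys`, splitting the
/// work across up to `threads` scoped threads.
fn double_in_parallel(map: &mut HashMap<u32, u64>, keys: &HashSet<u32>, threads: usize) {
    // One pass over `iter_mut` gives disjoint `&mut u64`s with no unsafe code,
    // at the cost of scanning the whole map.
    let mut refs: Vec<&mut u64> = map
        .iter_mut()
        .filter(|(k, _)| keys.contains(*k))
        .map(|(_, v)| v)
        .collect();
    let t = threads.max(1);
    let chunk_len = ((refs.len() + t - 1) / t).max(1); // `chunks_mut(0)` would panic
    thread::scope(|s| {
        for part in refs.chunks_mut(chunk_len) {
            // Each thread gets exclusive access to its own chunk of references.
            s.spawn(move || {
                for v in part.iter_mut() {
                    **v *= 2;
                }
            });
        }
    }); // all scoped threads are joined here
}
```

A `get_many_mut` over a variable number of keys would give the same `Vec<&mut V>` without the full scan, which is the case the issue is asking about.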

Ten0, Feb 22 '23

Keep in mind that get_many_mut has a complexity of O(n^2) where n is the number of keys to look up. This is because each returned value must be checked against every other returned value to ensure that none of them point to the same entry, which is necessary to satisfy Rust's safety guarantees (mutable references are unique). This can quickly add up when n becomes large.
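To make the quadratic cost concrete, here is a minimal sketch of a pairwise uniqueness check. It assumes the items compared are whatever the lookups return (entry pointers in hashbrown; plain values or raw pointers here); `all_distinct_pairwise` and `pairwise_comparisons` are hypothetical helpers.

```rust
/// Returns true if no two items are equal, comparing every pair once.
fn all_distinct_pairwise<T: PartialEq>(items: &[T]) -> bool {
    items
        .iter()
        .enumerate()
        .all(|(i, a)| items[i + 1..].iter().all(|b| a != b))
}

/// Number of comparisons the pairwise check performs: n * (n - 1) / 2.
fn pairwise_comparisons(n: usize) -> usize {
    n * n.saturating_sub(1) / 2
}
```

For n = 2 that is a single comparison, but for n = 1000 it is already 499,500.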

Amanieu, Feb 23 '23

each returned value must be checked against every other returned value

That seems reasonable for the const-size version, since N is unlikely to ever be very large, but here it would indeed be too expensive. Perhaps flagging each entry as currently borrowed would fix this?

By the way, in the const-size version, how do you deal with inconsistent Eq implementations? ("It's OK, I'm equal to nobody... oh wait, in fact I'm the same key" -> yields several mutable borrows to the same entry.) It looks like an entry-level flag would solve this as well, no?
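A minimal sketch of the flag idea, assuming each lookup yields a slot index below the table's capacity. hashbrown has no such per-entry flags, and `all_distinct_flags` is hypothetical; a real table would store the flags itself and clear them afterwards instead of allocating. Because it checks the slots actually returned rather than comparing keys, an inconsistent Eq cannot make it hand out the same slot twice.

```rust
/// Hypothetical sketch: `slots` are the slot indices returned by the lookups
/// (all `< capacity`). Detects a repeated slot in O(n) time, plus O(capacity)
/// to allocate the flags.
fn all_distinct_flags(slots: &[usize], capacity: usize) -> bool {
    let mut borrowed = vec![false; capacity];
    for &slot in slots {
        // `replace` returns the previous flag: `true` means the slot was already handed out.
        if std::mem::replace(&mut borrowed[slot], true) {
            return false;
        }
    }
    true
}
```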

Ten0, Feb 23 '23

We compare the pointers to the entries that are actually returned by the lookups.

Amanieu, Feb 23 '23

Ah of course! ^^ Then I guess that if n is too large, sorting the pointers and checking for consecutive duplicates works in O(n log(n)).

The problem becomes similar to this one: https://github.com/rust-lang/rust/pull/83608#pullrequestreview-966536956

Ten0, Feb 23 '23

Ah of course! ^^ Then I guess that if n is too large, sorting the pointers and checking for consecutive duplicates works in O(n log(n)).

Unlikely, since you still need to take into account the sort itself

JustForFun88, Feb 23 '23

Currently the implementation is optimized for small n (the most common case is expected to be n = 2).

Amanieu, Feb 23 '23

Unlikely, since you still need to take into account the sort itself

The sort itself is O(n log(n)), where n is the number of queried entries.

Currently the implementation is optimized for small n

For the const-size version, it's pretty clear that we want to be as optimized as possible for small slice sizes. It's also very practical to be able to pattern-match directly on the result (so we don't want to remove the const-size version), and I'd also imagine that the most common case is 2.
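As an illustration of why a const-size result is convenient, here is a minimal sketch for the common N = 2 case on a slice. `get_two_mut` is a hypothetical helper built only on `split_at_mut`; returning an array lets the caller destructure it directly.

```rust
/// Hypothetical helper: two disjoint mutable references from a slice,
/// returned as an array so the caller can pattern-match on it.
fn get_two_mut<T>(slice: &mut [T], i: usize, j: usize) -> Option<[&mut T; 2]> {
    if i == j || i >= slice.len() || j >= slice.len() {
        return None;
    }
    if i < j {
        let (head, tail) = slice.split_at_mut(j); // `head` holds i, `tail[0]` is j
        Some([&mut head[i], &mut tail[0]])
    } else {
        let (head, tail) = slice.split_at_mut(i); // `head` holds j, `tail[0]` is i
        Some([&mut tail[0], &mut head[j]])
    }
}
```

Usage: `if let Some([a, b]) = get_two_mut(&mut v, 0, 2) { std::mem::swap(a, b); }`.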

However, it looks like we could provide different APIs through the same function via a GetManyMut trait, as was considered for the std version on slices.

Then we could have different versions based on the slice size:

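// `ptrs` are the entry pointers returned by the lookups, in query order.
fn check_distinct<T>(ptrs: Vec<*mut T>) -> Vec<*mut T> {
    let n = ptrs.len();
    if n < 10 { // 10 is the limit for insertion sort in std
        // do n^2: compare every pair of pointers
        assert!((0..n).all(|i| ptrs[i + 1..].iter().all(|&p| p != ptrs[i])));
    } else {
        let mut ptrs_sorted = ptrs.clone(); // sort a copy so `ptrs` keeps query order
        ptrs_sorted.sort_unstable(); // n * log(n)
        assert!(ptrs_sorted.windows(2).all(|w| w[0] != w[1])); // n
    }
    ptrs // returned in the same order as the keys
}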

That gives optimal performance in both cases. (This was actually also suggested for the std const version here.)

With pointers, it looks like the problem becomes essentially the same as the one std has with indexes. https://github.com/rust-lang/rust/pull/83608#pullrequestreview-966536956 https://github.com/rust-lang/rust/pull/83608#issuecomment-1207205415

Ten0, Feb 23 '23

Unlikely, since you still need to take into account the sort itself

The sort itself is O(n log(n)), where n is the number of queried entries.

Well, as you can see from your own example, you first need to sort, and then go through the array again, looking for the equal pointers. In addition, you do not take into account that sorting will violate the order of the values. That is, the corresponding key will not point to the corresponding value.

JustForFun88, Feb 23 '23

Well, as you can see from your own example, you first need to sort, and then go through the array again, looking for the equal pointers

I'm not sure why that is an issue.

In addition, you do not take into account that sorting will violate the order of the values. That is, the corresponding key will not point to the corresponding value.

I do; that's what the ptrs.clone() is for. I'm returning ptrs at the end. If we want a single allocation, we could also use with_capacity(n*2), duplicate the contents at the end of the vec, sort that subslice for the check, then truncate it.
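The single-allocation variant can be sketched as follows. This is a minimal illustration assuming the items are `Ord + Copy` (raw pointers are both); `collect_distinct` is a hypothetical helper, not hashbrown code. The first half of the buffer keeps query order, so keys still line up with their values.

```rust
/// Hypothetical sketch: collects `items` (in the real case, entry pointers
/// produced by the lookups) into one Vec of capacity 2n, sorts a scratch copy
/// stored in the second half, checks it for adjacent duplicates, then
/// truncates the copy away.
fn collect_distinct<T: Ord + Copy>(items: impl ExactSizeIterator<Item = T>) -> Option<Vec<T>> {
    let mut buf = Vec::with_capacity(2 * items.len());
    buf.extend(items);
    let n = buf.len();
    buf.extend_from_within(..n); // scratch copy in buf[n..2n]; no realloc if `len()` was exact
    buf[n..].sort_unstable(); // O(n log n)
    let distinct = buf[n..].windows(2).all(|w| w[0] != w[1]); // O(n)
    buf.truncate(n); // drop the scratch copy; the capacity stays 2n
    if distinct {
        Some(buf)
    } else {
        None
    }
}
```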

Ten0, Feb 23 '23

I do; that's what the ptrs.clone() is for. I'm returning ptrs at the end. If we want a single allocation, we could also use with_capacity(n*2), duplicate the contents at the end of the vec, sort that subslice for the check, then truncate it.

Hmm... it might work. It looks like the method you suggested will improve the existing implementation as well.

As for the variable-length case: why not consider using something like https://doc.rust-lang.org/std/array/struct.IntoIter.html? For example, create a separate search array and a separate output array, push only the unique pointers into the output array, and return it as an IntoIter.

JustForFun88, Feb 23 '23

Why not consider...

I'm afraid I don't understand what you have in mind. Perhaps that would be easier with a code sample?

Ten0, Feb 23 '23

Why not consider...

I'm afraid I don't understand what you have in mind. Perhaps that would be easier with a code sample?

So I came up with the function below. One caveat: it may be better suited to some other variant of get_many_key_value_mut, since the returned iterator is not guaranteed to be as long as the array of keys: keys that are not in the map are skipped, and duplicate entries are yielded only once.

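// `RawTable` and `Allocator` come from hashbrown's internals; `IntoIter` is defined below.
use core::iter::FusedIterator;
use core::mem::MaybeUninit;
use core::ops::Range;
use core::ptr;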

impl<T, A: Allocator + Clone> RawTable<T, A> {
    pub fn get_many_mut_two<const N: usize>(
        &mut self,
        hashes: [u64; N],
        mut eq: impl FnMut(usize, &T) -> bool,
    ) -> IntoIter<&mut T, N> {
        let mut guard = Guard::<*mut T, N>::new();

        for (i, &hash) in hashes.iter().enumerate() {
            if let Some(cur) = self.find(hash, |k| eq(i, k)) {
                // SAFETY: `hashes` has length N, so at most N pointers are pushed.
                unsafe { guard.push_unchecked(cur.as_ptr()) }
            }
        }

        let into_iter = guard.into_iter();
        let slice = into_iter.as_slice();
        let mut guard = Guard::<&mut T, N>::new();

        for (i, &cur) in slice.iter().enumerate() {
            if slice[i + 1..].iter().all(|&next| !ptr::eq::<T>(next, cur)) {
                // SAFETY:
                // 1. `slice` has at most N elements, so at most N references are pushed.
                // 2. We have just verified that this is a unique pointer that does not repeat within the array.
                // 3. We got all the pointers from the `find` function, so they are valid.
                unsafe { guard.push_unchecked(&mut *cur) };
            }
        }

        guard.into_iter()
    }
}

struct Guard<T, const N: usize> {
    /// The array to be initialized.
    array_mut: [MaybeUninit<T>; N],
    /// The number of items that have been initialized so far.
    initialized: usize,
}

impl<T, const N: usize> Guard<T, N> {
    #[inline]
    fn new() -> Self {
        // SAFETY: The `assume_init` is safe because the type we are claiming to have
        // initialized here is a bunch of `MaybeUninit`s, which do not require initialization.
        Guard {
            array_mut: unsafe { MaybeUninit::uninit().assume_init() },
            initialized: 0,
        }
    }

    /// Adds an item to the array and updates the initialized item counter.
    ///
    /// # Safety
    ///
    /// No more than N elements must be initialized.
    #[inline(always)]
    unsafe fn push_unchecked(&mut self, item: T) {
        // SAFETY: If `initialized` was correct before and the caller does not
        // invoke this method more than N times then writes will be in-bounds
        // and slots will not be initialized more than once.
        self.array_mut
            .get_unchecked_mut(self.initialized)
            .write(item);
        self.initialized += 1;
    }

    #[inline]
    fn into_iter(self) -> IntoIter<T, N> {
        let initialized = 0..self.initialized;
        // SAFETY: We provide the number of elements that are guaranteed to be initialized
        unsafe { IntoIter::new_unchecked(self.array_mut, initialized) }
    }
}

/// A by-value [array] iterator.
pub struct IntoIter<T, const N: usize> {
    data: [MaybeUninit<T>; N],
    alive: Range<usize>,
}

impl<T, const N: usize> IntoIter<T, N> {
    /// Creates an iterator over the elements in a partially-initialized buffer.
    ///
    /// # Safety
    ///
    /// - The `buffer[initialized]` elements must all be initialized.
    /// - The range must be canonical, with `initialized.start <= initialized.end`.
    /// - The range must be in-bounds for the buffer, with `initialized.end <= N`.
    ///   (Like how indexing `[0][100..100]` fails despite the range being empty.)
    ///
    /// It's sound to have more elements initialized than mentioned, though that
    /// will most likely result in them being leaked.
    #[inline]
    const unsafe fn new_unchecked(buffer: [MaybeUninit<T>; N], initialized: Range<usize>) -> Self {
        // SAFETY: one of our safety conditions is that the range is canonical.
        Self {
            data: buffer,
            alive: initialized,
        }
    }

    #[inline]
    pub fn as_slice(&self) -> &[T] {
        unsafe {
            // SAFETY: We know that all elements within `alive` are properly initialized.
            let slice = self.data.get_unchecked(self.alive.clone());
            // SAFETY: casting `slice` to a `*const [T]` is safe since the `slice` is initialized,
            // and `MaybeUninit` is guaranteed to have the same layout as `T`.
            // The pointer obtained is valid since it refers to memory owned by `slice` which is a
            // reference and thus guaranteed to be valid for reads.
            &*(slice as *const [MaybeUninit<T>] as *const [T])
        }
    }
}

impl<T, const N: usize> Iterator for IntoIter<T, N> {
    type Item = T;

    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        // Get the next index from the front.
        //
        // Increasing `alive.start` by 1 maintains the invariant regarding
        // `alive`. However, due to this change, for a short time, the alive
        // zone is not `data[alive]` anymore, but `data[idx..alive.end]`.
        self.alive.next().map(|idx| {
            // Read the element from the array.
            // SAFETY: `idx` is an index into the former "alive" region of the
            // array. Reading this element means that `data[idx]` is regarded as
            // dead now (i.e. do not touch). As `idx` was the start of the
            // alive-zone, the alive zone is now `data[alive]` again, restoring
            // all invariants.
            unsafe { self.data.get_unchecked(idx).assume_init_read() }
        })
    }

    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        let len = self.len();
        (len, Some(len))
    }

    #[inline]
    fn fold<Acc, Fold>(mut self, init: Acc, mut fold: Fold) -> Acc
    where
        Fold: FnMut(Acc, Self::Item) -> Acc,
    {
        let data = &mut self.data;
        self.alive.by_ref().fold(init, |acc, idx| {
            // SAFETY: idx is obtained by folding over the `alive` range, which implies the
            // value is currently considered alive but as the range is being consumed each value
            // we read here will only be read once and then considered dead.
            fold(acc, unsafe { data.get_unchecked(idx).assume_init_read() })
        })
    }

    #[inline]
    fn count(self) -> usize {
        self.len()
    }

    #[inline]
    fn last(mut self) -> Option<Self::Item> {
        self.next_back()
    }
}

impl<T, const N: usize> DoubleEndedIterator for IntoIter<T, N> {
    #[inline]
    fn next_back(&mut self) -> Option<Self::Item> {
        // Get the next index from the back.
        //
        // Decreasing `alive.end` by 1 maintains the invariant regarding
        // `alive`. However, due to this change, for a short time, the alive
        // zone is not `data[alive]` anymore, but `data[alive.start..=idx]`.
        self.alive.next_back().map(|idx| {
            // Read the element from the array.
            // SAFETY: `idx` is an index into the former "alive" region of the
            // array. Reading this element means that `data[idx]` is regarded as
            // dead now (i.e. do not touch). As `idx` was the end of the
            // alive-zone, the alive zone is now `data[alive]` again, restoring
            // all invariants.
            unsafe { self.data.get_unchecked(idx).assume_init_read() }
        })
    }

    #[inline]
    fn rfold<Acc, Fold>(mut self, init: Acc, mut rfold: Fold) -> Acc
    where
        Fold: FnMut(Acc, Self::Item) -> Acc,
    {
        let data = &mut self.data;
        self.alive.by_ref().rfold(init, |acc, idx| {
            // SAFETY: idx is obtained by folding over the `alive` range, which implies the
            // value is currently considered alive but as the range is being consumed each value
            // we read here will only be read once and then considered dead.
            rfold(acc, unsafe { data.get_unchecked(idx).assume_init_read() })
        })
    }
}

impl<T, const N: usize> ExactSizeIterator for IntoIter<T, N> {
    #[inline]
    fn len(&self) -> usize {
        self.alive.len()
    }
}

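impl<T, const N: usize> FusedIterator for IntoIter<T, N> {}

impl<T, const N: usize> Drop for IntoIter<T, N> {
    fn drop(&mut self) {
        // SAFETY: the elements in `alive` are initialized and have not been
        // yielded yet, so each one is dropped exactly once here.
        for idx in self.alive.clone() {
            unsafe { self.data.get_unchecked_mut(idx).assume_init_drop() };
        }
    }
}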

JustForFun88, Feb 24 '23

get_many_mut on slices (rust-lang/rust#83608) does not need to handle a variable number of keys (although it might be nice) because, in the worst-case scenario, you can sort your input array of indexes and iterate over it with successive split_at_mut calls.

Actually, I've figured out how it can be done. I'll try to implement it in the next two days.

JustForFun88, Feb 24 '23

@Amanieu, @Ten0 What if we provide the following API (draft pull request #408):

pub fn try_get_many_key_value_mut<'a, Q, I, const N: usize>(
    &mut self,
    iter: &mut I,
) -> ArrayIter<(&K, &mut V), N>
where
    I: Iterator<Item = &'a Q>,
    Q: ?Sized + Hash + Equivalent<K> + 'a,
{
    /* implementation */
}

pub unsafe fn try_get_many_key_value_unchecked_mut<'a, Q, I, const N: usize>(
    &mut self,
    iter: &mut I,
) -> ArrayIter<(&K, &mut V), N>
where
    I: Iterator<Item = &'a Q>,
    Q: ?Sized + Hash + Equivalent<K> + 'a,
{
    /* implementation */
}

Here ArrayIter essentially duplicates the implementation of IntoIter from https://doc.rust-lang.org/std/array/struct.IntoIter.html. Unfortunately, IntoIter cannot be used directly because its new_unchecked constructor is unstable. ArrayIter is not only an iterator; it also has a few more methods:

impl<T, const N: usize> ArrayIter<T, N> {
    /// Returns an immutable slice of all elements that have not been yielded yet.
    #[inline]
    pub fn as_slice(&self) -> &[T] {
         /* implementation */
    }

    /// Returns a mutable slice of all elements that have not been yielded yet.
    #[inline]
    pub fn as_mut_slice(&mut self) -> &mut [T] {
         /* implementation */
    }

    /// Returns an [`ArrayIter`] of the same size as `self`, with function `f`
    /// applied to each element in order.
    pub fn convert<F, U>(self, mut f: F) -> ArrayIter<U, N>
    where
        F: FnMut(T) -> U,
    {
         /* implementation */
    }
}

JustForFun88, Feb 25 '23

While your implementation is very clever, I don't think this is an API I actually want in hashbrown. It just ends up being very hard to use correctly (due to the interactions between the N constant and the iterator size, for example).

As I've said before, you really should just do separate lookups when you actually need the values instead of keeping an arbitrary number of mutable references around.

Amanieu, Mar 01 '23