strum

Implement strum(flatten) for EnumIter

Open juliancoffee opened this issue 9 months ago • 3 comments

Fixes #424

As I said, it's possible, if slightly complex. I'm not an expert at writing iterators, so there may be rough edges that could be smoothed out; I focused on making it correct. I tried to keep the diff slim, but the DoubleEndedIterator implementation ended up split into pieces.

Also, as you can see in the tests, Color::simple_iter() is a much simpler implementation, though it may be a bit slower to run and/or compile; I haven't benchmarked it.

Update: I think I know how to simplify this a little (without switching to the implementation I pointed out above), so if you're interested, I'll try to refactor it.

juliancoffee avatar Mar 25 '25 22:03 juliancoffee

This is a manual implementation of what these macros generate. It has some dbg!() calls here and there, which I used while developing the algorithm. Of course, none of them are present in the MR code :P

/// Fieldless enum used as the nested payload of `Shade`'s `Med*` variants.
///
/// Prototype-only: per the discussion, real code would put
/// `#[derive(EnumIter)]` on this enum instead of the hand-written impls below.
#[derive(Debug, PartialEq, Eq)]
enum Vibe {
    Weak,
    Average,
    Strong,
}

impl Vibe {
    /// Returns an iterator over every `Vibe` variant in declaration order.
    fn iter() -> <Self as IntoIterator>::IntoIter {
        Vibe::Weak.into_iter()
    }
}

impl IntoIterator for Vibe {
    type Item = Vibe;
    type IntoIter = std::vec::IntoIter<Vibe>;

    /// Ignores `self` and yields all three variants: `Weak`, `Average`, `Strong`.
    fn into_iter(self) -> Self::IntoIter {
        let every_variant = vec![Self::Weak, Self::Average, Self::Strong];
        every_variant.into_iter()
    }
}

/// Number of top-level `Shade` variants (nested `Vibe` payloads not counted).
const SHADE_NUM: usize = 5;

/// Outer enum whose three `Med*` tuple variants each carry a nested `Vibe`.
#[derive(Debug, PartialEq, Eq)]
enum Shade {
    Light,
    Med1(Vibe),
    Med2(Vibe),
    Med3(Vibe),
    Dark,
}

impl Shade {
    /// Returns a fresh [`ShadeIter`] positioned before the first variant.
    ///
    /// Both cursors start at zero, and each `Med*` variant gets its own
    /// nested `Vibe` iterator to draw payloads from.
    fn iter() -> ShadeIter {
        ShadeIter {
            idx: 0,
            back_idx: 0,
            med1_iter: Some(Vibe::iter()),
            med2_iter: Some(Vibe::iter()),
            med3_iter: Some(Vibe::iter()),
        }
    }
}

impl Shade {
    /// Reference implementation of the flattened iteration order, built from
    /// `Iterator::chain`; the tests use it to cross-check `ShadeIter`.
    ///
    /// The single unit variants are produced with `std::iter::once` rather
    /// than a one-element `Vec`, which avoids two heap allocations per call
    /// while keeping the iterator double-ended with an exact size hint.
    fn simple_iter() -> impl DoubleEndedIterator<Item = Shade> {
        std::iter::once(Shade::Light)
            .chain(Vibe::iter().map(Shade::Med1))
            .chain(Vibe::iter().map(Shade::Med2))
            .chain(Vibe::iter().map(Shade::Med3))
            .chain(std::iter::once(Shade::Dark))
    }
}

/// Two-ended cursor over the flattened values of [`Shade`].
///
/// `idx` counts top-level variant positions consumed from the front, and
/// `back_idx` counts those consumed from the back. Iteration is over once
/// `idx + back_idx` reaches `SHADE_NUM`. Each `Med*` variant owns the nested
/// iterator it draws payloads from; that iterator becomes `None` once it
/// has been exhausted.
struct ShadeIter {
    idx: usize,
    back_idx: usize,
    med1_iter: Option<<Vibe as IntoIterator>::IntoIter>,
    med2_iter: Option<<Vibe as IntoIterator>::IntoIter>,
    med3_iter: Option<<Vibe as IntoIterator>::IntoIter>,
}

/// Outcome of asking `ShadeIter::get` for the value at one variant position.
#[derive(Debug)]
enum Res {
    /// A unit variant was produced; the cursor should step past it.
    Done(Shade),
    /// A nested variant produced an item; the cursor should stay on it.
    DoneStep(Shade),
    /// A nested variant's iterator is exhausted; step past it and retry.
    EndStep,
    /// The position lies beyond the last variant.
    End,
}

impl ShadeIter {
    /// Pulls one item from a nested `Vibe` iterator and wraps it in a `Shade`.
    ///
    /// * `nested_iter` – the nested iterator for one `Med*` variant, or `None`
    ///   once it has been exhausted.
    /// * `wrap` – the tuple-variant constructor to apply (e.g. `Shade::Med1`).
    /// * `forward` – take from the front (`next`) when `true`, otherwise from
    ///   the back (`next_back`).
    ///
    /// Returns `Res::DoneStep` with the wrapped item. When nothing is left, it
    /// drops the iterator and returns `Res::EndStep`.
    fn nested_get(
        nested_iter: &mut Option<<Vibe as IntoIterator>::IntoIter>,
        wrap: fn(<Vibe as IntoIterator>::Item) -> Shade,
        forward: bool,
    ) -> Res {
        let next_inner = nested_iter.as_mut().and_then(|inner| {
            if forward {
                inner.next()
            } else {
                inner.next_back()
            }
        });
        match next_inner {
            Some(item) => Res::DoneStep(wrap(item)),
            None => {
                // Drop the exhausted iterator so `size_hint` stops counting it.
                *nested_iter = None;
                Res::EndStep
            }
        }
    }

    /// Produces the value at top-level variant position `idx` (declaration
    /// order), drawing from the nested iterator for the `Med*` positions.
    ///
    /// `forward` selects which end of a nested iterator to take from. The
    /// function returns `Res::End` for positions past the last variant.
    fn get(&mut self, idx: usize, forward: bool) -> Res {
        match idx {
            0 => Res::Done(Shade::Light),
            1 => Self::nested_get(&mut self.med1_iter, Shade::Med1, forward),
            2 => Self::nested_get(&mut self.med2_iter, Shade::Med2, forward),
            3 => Self::nested_get(&mut self.med3_iter, Shade::Med3, forward),
            4 => Res::Done(Shade::Dark),
            _ => Res::End,
        }
    }
}

impl Iterator for ShadeIter {
    type Item = Shade;

    /// Yields the next flattened `Shade` from the front.
    ///
    /// Unit variants are yielded once. For a `Med*` variant, the front cursor
    /// stays on that variant until its nested iterator is exhausted.
    fn next(&mut self) -> Option<Self::Item> {
        // The cursors meet once `idx + back_idx` reaches `SHADE_NUM`; past
        // that point everything has already been yielded from one end.
        while self.idx + self.back_idx < SHADE_NUM {
            match self.get(self.idx, true) {
                Res::Done(x) => {
                    // Unit variant: yield it and step past it.
                    self.idx += 1;
                    return Some(x);
                }
                // Nested item: stay on this variant for the next call.
                Res::DoneStep(x) => return Some(x),
                // Nested iterator exhausted (and dropped): try the next variant.
                Res::EndStep => self.idx += 1,
                Res::End => return None,
            }
        }
        None
    }

    /// Skips `n` flattened items and returns the one after them.
    ///
    /// This must count individual items, not top-level variants. Indexing
    /// variants would jump over unread nested items, leave their iterators
    /// alive (breaking `size_hint`), and could re-yield values that were
    /// already taken from the back. Iterator adapters may call `nth` directly.
    fn nth(&mut self, n: usize) -> Option<Self::Item> {
        for _ in 0..n {
            self.next()?;
        }
        self.next()
    }

    /// Returns the exact number of flattened items left.
    fn size_hint(&self) -> (usize, Option<usize>) {
        // Each top-level position still between the cursors counts as one
        // item, except that a live nested iterator replaces its position's
        // single slot with however many items it still holds (possibly zero).
        let live_nested = [
            self.med1_iter.as_ref(),
            self.med2_iter.as_ref(),
            self.med3_iter.as_ref(),
        ];
        let live_count = live_nested.iter().flatten().count();
        let nested_items: usize = live_nested.iter().flatten().map(|inner| inner.len()).sum();
        // Add before subtracting so a live-but-empty nested iterator cannot
        // cause an intermediate underflow.
        let remaining = SHADE_NUM - self.idx - self.back_idx + nested_items - live_count;
        (remaining, Some(remaining))
    }
}

impl DoubleEndedIterator for ShadeIter {
    /// Yields the next flattened `Shade` from the back.
    ///
    /// This mirrors `next`: unit variants are yielded once, and the back
    /// cursor stays on a `Med*` variant until its nested iterator is
    /// exhausted. Front and back share that nested iterator, so they never
    /// yield the same item twice.
    fn next_back(&mut self) -> Option<Self::Item> {
        while self.idx + self.back_idx < SHADE_NUM {
            // Position of the variant currently under the back cursor; the
            // loop condition guarantees `back_idx < SHADE_NUM`.
            let position = SHADE_NUM - self.back_idx - 1;
            match self.get(position, false) {
                Res::Done(x) => {
                    // Unit variant: yield it and step past it.
                    self.back_idx += 1;
                    return Some(x);
                }
                // Nested item: stay on this variant for the next call.
                Res::DoneStep(x) => return Some(x),
                // Nested iterator exhausted (and dropped): try the previous variant.
                Res::EndStep => self.back_idx += 1,
                Res::End => return None,
            }
        }
        None
    }

    /// Skips `n` flattened items from the back and returns the one before them.
    ///
    /// This lives in the trait impl, rather than as an inherent method
    /// shadowing the trait's, so direct `nth_back` calls and `rev().nth(_)`
    /// agree. It counts items, not variants, and so cannot underflow for
    /// large `n`.
    fn nth_back(&mut self, n: usize) -> Option<Self::Item> {
        for _ in 0..n {
            self.next_back()?;
        }
        self.next_back()
    }
}

impl ExactSizeIterator for ShadeIter {
    /// Number of flattened items still to be yielded from either end.
    ///
    /// `size_hint` reports identical lower and upper bounds for this
    /// iterator, so its lower bound is the length.
    fn len(&self) -> usize {
        let (lower, _upper) = self.size_hint();
        lower
    }
}

/// Placeholder binary entry point; the prototype is exercised by the tests.
fn main() {
    let greeting = "Hello, world!";
    println!("{greeting}");
}

#[cfg(test)]
mod tests {
    use super::*;
    use rand::Rng;

    /// Number of values a fresh `Shade::iter()` yields: `Light`, three
    /// `Vibe`s under each of `Med1`/`Med2`/`Med3`, and `Dark`.
    const MAX: usize = 11;

    #[test]
    fn flatten() {
        let result = Shade::iter().collect::<Vec<_>>();
        let expected = vec![
            Shade::Light,
            Shade::Med1(Vibe::Weak),
            Shade::Med1(Vibe::Average),
            Shade::Med1(Vibe::Strong),
            Shade::Med2(Vibe::Weak),
            Shade::Med2(Vibe::Average),
            Shade::Med2(Vibe::Strong),
            Shade::Med3(Vibe::Weak),
            Shade::Med3(Vibe::Average),
            Shade::Med3(Vibe::Strong),
            Shade::Dark,
        ];
        assert_eq!(result, expected);
    }

    #[test]
    fn flatten_back() {
        let result = Shade::iter().rev().collect::<Vec<_>>();
        let expected = vec![
            Shade::Dark,
            Shade::Med3(Vibe::Strong),
            Shade::Med3(Vibe::Average),
            Shade::Med3(Vibe::Weak),
            Shade::Med2(Vibe::Strong),
            Shade::Med2(Vibe::Average),
            Shade::Med2(Vibe::Weak),
            Shade::Med1(Vibe::Strong),
            Shade::Med1(Vibe::Average),
            Shade::Med1(Vibe::Weak),
            Shade::Light,
        ];
        assert_eq!(result, expected);
    }

    #[test]
    fn iter_mixed_next_and_next_back() {
        let mut iter = Shade::iter();

        assert_eq!(iter.next(), Some(Shade::Light));
        assert_eq!(iter.next_back(), Some(Shade::Dark));

        assert_eq!(iter.next(), Some(Shade::Med1(Vibe::Weak)));
        assert_eq!(iter.next_back(), Some(Shade::Med3(Vibe::Strong)));

        assert_eq!(iter.next(), Some(Shade::Med1(Vibe::Average)));
        assert_eq!(iter.next_back(), Some(Shade::Med3(Vibe::Average)));

        assert_eq!(iter.next(), Some(Shade::Med1(Vibe::Strong)));
        assert_eq!(iter.next_back(), Some(Shade::Med3(Vibe::Weak)));

        assert_eq!(iter.next(), Some(Shade::Med2(Vibe::Weak)));
        assert_eq!(iter.next_back(), Some(Shade::Med2(Vibe::Strong)));

        assert_eq!(iter.next(), Some(Shade::Med2(Vibe::Average)));
        assert_eq!(iter.next_back(), None);
    }

    /// Randomly interleaves `next` / `next_back` and checks the hand-written
    /// iterator against the `chain`-based reference `Shade::simple_iter`.
    #[test]
    fn iter_quickcheck() {
        let mut rng = rand::rng();
        for _ in 0..1000 {
            let mut iter = Shade::iter();
            let mut simple_iter = Shade::simple_iter();

            let mut results = vec![];
            let mut expected = vec![];
            for _ in 0..500 {
                if rng.random_bool(0.5) {
                    results.push(iter.next());
                    expected.push(simple_iter.next());
                } else {
                    results.push(iter.next_back());
                    expected.push(simple_iter.next_back());
                }
            }
            assert_eq!(results, expected);
        }
    }

    /// Checks that `size_hint` matches the reference after every random step.
    #[test]
    fn iter_quickcheck_sizehint() {
        let mut rng = rand::rng();
        for _ in 0..1000 {
            let mut iter = Shade::iter();
            let mut simple_iter = Shade::simple_iter();

            assert_eq!(iter.size_hint(), simple_iter.size_hint());
            for _ in 0..500 {
                if rng.random_bool(0.5) {
                    _ = iter.next();
                    _ = simple_iter.next();
                } else {
                    _ = iter.next_back();
                    _ = simple_iter.next_back();
                }
                assert_eq!(iter.size_hint(), simple_iter.size_hint());
            }
        }
    }

    /// Checks that `len` drops by exactly one per consumed item from either end.
    #[test]
    fn iter_quickcheck_len() {
        let mut rng = rand::rng();
        for _ in 0..1000 {
            let mut iter = Shade::iter();

            assert_eq!(iter.len(), MAX);
            for consumed in 1..=MAX {
                if rng.random_bool(0.5) {
                    _ = iter.next();
                } else {
                    _ = iter.next_back();
                }
                assert_eq!(iter.len(), MAX - consumed);
            }
        }
    }
}

Open to your comments 🙌

juliancoffee avatar Mar 25 '25 22:03 juliancoffee

@vic1707 Oh, don't worry about that. Both Vibe::iter() and Shade::simple_iter() are functions I added for prototyping.

Vibe::iter() was added because I needed a nested iterator, and yeah, I didn't care much about its implementation, because it wouldn't be present in "real" code. In a real-world use case, Vibe would get the same implementation as Shade if you put #[derive(EnumIter)] on both.

Shade::simple_iter() is there so that I have something to compare results against without writing too many tests, so it wouldn't be present in the generated code either.

Thanks for noting that, though. I guess the drawback of the simple_iter() approach is that it would be harder to get working with #![no_std], although, as you mentioned, it could work if you just replace vec! with an array.

juliancoffee avatar Mar 31 '25 12:03 juliancoffee

Sorry for the misunderstanding on my part. Good job, can't wait to see it land if the devs are OK with it :+1:

vic1707 avatar Mar 31 '25 12:03 vic1707