zk-kit

Build a Rust implementation of the IMTs (Incremental Merkle Trees).

Status: Open · cedoor opened this issue 1 year ago · 4 comments

cedoor avatar Jan 26 '24 15:01 cedoor

@cedoor, can you schedule this PR for review?

Arch0125 avatar Jan 30 '24 10:01 Arch0125

> @cedoor, can you schedule this PR for review?

Hey @Arch0125, sure. Thanks a lot for working on this!! I didn't expect someone to implement it so quickly :D

cedoor avatar Jan 30 '24 10:01 cedoor

I will be coming up with a lean-imt implementation as well :)

Arch0125 avatar Jan 30 '24 10:01 Arch0125

use std::collections::BinaryHeap; use std::cmp::Ordering;

/// Wrapper around an `i32` whose ordering is the reverse of the natural one.
///
/// `BinaryHeap` is a max-heap, so a `BinaryHeap<ReverseOrd>` behaves as a
/// min-heap: `peek()` / `pop()` yield the *smallest* wrapped value.
#[derive(PartialEq, Eq)]
struct ReverseOrd(i32);

impl PartialOrd for ReverseOrd {
    // Delegate to `Ord` so the partial and total orderings can never disagree.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for ReverseOrd {
    fn cmp(&self, other: &Self) -> Ordering {
        // Operands are compared in swapped order to invert the natural ordering.
        other.0.cmp(&self.0)
    }
}

/// Tracks the running median of a stream of `i32` values.
///
/// Two-heap scheme:
/// * `lower_half` is a max-heap holding the smaller half of the values;
/// * `upper_half` is a min-heap (via [`ReverseOrd`]) holding the larger half.
///
/// After every call to `add` these invariants hold:
/// * every value in `lower_half` is `<=` every value in `upper_half`;
/// * `lower_half.len()` equals `upper_half.len()` or exceeds it by exactly one.
///
/// Insertion is `O(log n)` (heap push/pop); querying the median is `O(1)` (heap peek).
struct IncrementalMedianTracker {
    /// Smaller half of the values; `peek()` yields its maximum.
    lower_half: BinaryHeap<i32>,
    /// Larger half of the values; `peek()` yields its minimum.
    upper_half: BinaryHeap<ReverseOrd>,
}

impl IncrementalMedianTracker {
    /// Creates an empty tracker; `median` returns `None` until a value is added.
    fn new() -> Self {
        IncrementalMedianTracker {
            lower_half: BinaryHeap::new(),
            upper_half: BinaryHeap::new(),
        }
    }

    /// Inserts `num` and rebalances the two halves so the size invariant holds.
    fn add(&mut self, num: i32) {
        // Anything not greater than the current maximum of the lower half
        // belongs to the lower half; this also covers the empty tracker.
        // Comparing against that maximum (instead of the `f64` median) keeps
        // the routing decision in exact integer arithmetic.
        if matches!(self.lower_half.peek(), Some(&max_lower) if num > max_lower) {
            self.upper_half.push(ReverseOrd(num));
        } else {
            self.lower_half.push(num);
        }

        // One insertion unbalances the halves by at most one element, so moving
        // a single boundary value across restores the invariant.
        if self.lower_half.len() > self.upper_half.len() + 1 {
            if let Some(max_lower) = self.lower_half.pop() {
                self.upper_half.push(ReverseOrd(max_lower));
            }
        } else if self.upper_half.len() > self.lower_half.len() {
            if let Some(ReverseOrd(min_upper)) = self.upper_half.pop() {
                self.lower_half.push(min_upper);
            }
        }
    }

    /// Returns the median of all values added so far, or `None` if none were added.
    ///
    /// For an odd count this is the middle value. For an even count it is the
    /// mean of the two middle values. Both are widened to `f64` before summing,
    /// so the average cannot overflow `i32`.
    fn median(&self) -> Option<f64> {
        // By the size invariant, `lower_half` is non-empty whenever any value
        // has been added, so an empty `lower_half` means an empty tracker.
        let max_lower = f64::from(*self.lower_half.peek()?);
        if self.lower_half.len() > self.upper_half.len() {
            Some(max_lower)
        } else {
            let min_upper = f64::from(self.upper_half.peek()?.0);
            Some((max_lower + min_upper) / 2.0)
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Feeds a short unsorted stream and checks the median after each
    /// prefix of three or more values (both odd and even counts).
    #[test]
    fn test_imt() {
        let mut imt = IncrementalMedianTracker::new();
        imt.add(1);
        imt.add(3);
        imt.add(2);
        assert_eq!(imt.median().unwrap(), 2.0);

        imt.add(5);
        assert_eq!(imt.median().unwrap(), 2.5);

        imt.add(4);
        assert_eq!(imt.median().unwrap(), 3.0);

        imt.add(6);
        assert_eq!(imt.median().unwrap(), 3.5);
    }

    /// An empty tracker has no median; a single value is its own median.
    #[test]
    fn test_empty_and_single() {
        let mut imt = IncrementalMedianTracker::new();
        assert_eq!(imt.median(), None);

        imt.add(-7);
        assert_eq!(imt.median(), Some(-7.0));
    }

    /// Descending input containing duplicates and negative values.
    #[test]
    fn test_descending_with_duplicates() {
        let mut imt = IncrementalMedianTracker::new();
        let steps = vec![(5, 5.0), (4, 4.5), (3, 4.0), (3, 3.5), (-10, 3.0), (-10, 3.0)];
        for (value, expected) in steps {
            imt.add(value);
            assert_eq!(imt.median(), Some(expected));
        }
    }
}

Nagu40 avatar Feb 08 '24 19:02 Nagu40