zk-kit
zk-kit copied to clipboard
Build a Rust implementation of the IMTs (Incremental Merkle Trees)
@cedoor can you schedule this PR for review
@cedoor can you schedule this PR for review
Hey @Arch0125, sure. Thanks a lot for working on this!! I didn't expect someone to implement it so quickly :D
I will be coming up with a lean-imt implementation as well :)
use std::collections::BinaryHeap; use std::cmp::Ordering;
/// An `i32` wrapper whose ordering is the reverse of `i32`'s natural ordering.
///
/// `BinaryHeap` is a max-heap, so a `BinaryHeap<ReverseOrd>` yields the
/// *smallest* wrapped integer first. This is [`std::cmp::Reverse`]
/// specialised to `i32`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
struct ReverseOrd(i32);

impl PartialOrd for ReverseOrd {
    // Delegate to `Ord` so the partial and total orderings can never disagree
    // (see clippy::non_canonical_partial_ord_impl).
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for ReverseOrd {
    /// Compares the wrapped values with the operands swapped, which inverts
    /// the natural `i32` order.
    fn cmp(&self, other: &Self) -> Ordering {
        other.0.cmp(&self.0)
    }
}
struct IncrementalMedianTracker {
lower_half: BinaryHeap<ReverseOrd>,
upper_half: BinaryHeap
impl IncrementalMedianTracker { fn new() -> Self { IncrementalMedianTracker { lower_half: BinaryHeap::new(), upper_half: BinaryHeap::new(), } }
fn add(&mut self, num: i32) {
if self.lower_half.is_empty() || num <= self.median().unwrap() {
self.lower_half.push(ReverseOrd(num));
} else {
self.upper_half.push(num);
}
if self.lower_half.len() > self.upper_half.len() + 1 {
if let Some(ReverseOrd(max_lower)) = self.lower_half.pop() {
self.upper_half.push(max_lower);
}
} else if self.upper_half.len() > self.lower_half.len() {
if let Some(min_upper) = self.upper_half.pop() {
self.lower_half.push(ReverseOrd(min_upper));
}
}
}
fn median(&self) -> Option<f64> {
if self.lower_half.is_empty() && self.upper_half.is_empty() {
return None;
}
if self.lower_half.len() > self.upper_half.len() {
Some(self.lower_half.peek().unwrap().0 as f64)
} else if self.lower_half.len() < self.upper_half.len() {
Some(*self.upper_half.peek().unwrap() as f64)
} else {
Some((self.lower_half.peek().unwrap().0 as f64 + *self.upper_half.peek().unwrap() as f64) / 2.0)
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Mixed-order insertions; checks the median after every prefix from
    /// three values onward, covering both odd and even counts.
    #[test]
    fn test_imt() {
        let mut imt = IncrementalMedianTracker::new();
        imt.add(1);
        imt.add(3);
        imt.add(2);
        assert_eq!(imt.median().unwrap(), 2.0);
        imt.add(5);
        assert_eq!(imt.median().unwrap(), 2.5);
        imt.add(4);
        assert_eq!(imt.median().unwrap(), 3.0);
        imt.add(6);
        assert_eq!(imt.median().unwrap(), 3.5);
    }

    /// A freshly created tracker has no median.
    #[test]
    fn median_of_empty_tracker_is_none() {
        let imt = IncrementalMedianTracker::new();
        assert_eq!(imt.median(), None);
    }

    /// A single value is its own median.
    #[test]
    fn single_value_is_median() {
        let mut imt = IncrementalMedianTracker::new();
        imt.add(-4);
        assert_eq!(imt.median(), Some(-4.0));
    }

    /// Descending input with duplicates and negatives forces rebalancing in
    /// both directions.
    #[test]
    fn handles_descending_duplicates_and_negatives() {
        let mut imt = IncrementalMedianTracker::new();
        for &v in &[9, 7, 7, 3, -1] {
            imt.add(v);
        }
        // sorted: -1 3 7 7 9
        assert_eq!(imt.median(), Some(7.0));
        imt.add(-5);
        // sorted: -5 -1 3 7 7 9
        assert_eq!(imt.median(), Some(5.0));
    }

    /// The mean of the two middle values must not overflow `i32`.
    #[test]
    fn even_count_near_i32_max_does_not_overflow() {
        let mut imt = IncrementalMedianTracker::new();
        imt.add(i32::MAX);
        imt.add(i32::MAX - 1);
        assert_eq!(imt.median(), Some(2147483646.5));
    }
}