Created
August 22, 2023 14:10
-
-
Save ericwen229/bdcd1c7ac93cd10bf1296913f95d8b2f to your computer and use it in GitHub Desktop.
Distributed k-th smallest element selection (quickselect across multiple workers, zero-based k), implemented in Rust
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| use rand::Rng; | |
| use std::fmt::{Display, Formatter, Result}; | |
/// One participant in the distributed selection: it owns a chunk of the data
/// and tracks the part of that chunk still under consideration.
///
/// The active part is the half-open index interval `[begin, end)` into `arr`.
#[derive(Debug, Clone)]
pub struct Worker<T> {
    /// Elements owned by this worker; partitioning reorders them in place.
    arr: Vec<T>,
    /// Inclusive start index of the active interval.
    begin: usize,
    /// Exclusive end index of the active interval.
    end: usize,
}
| impl<T> Display for Worker<T> | |
| where T: Display { | |
| fn fmt(&self, f: &mut Formatter<'_>) -> Result { | |
| write!(f, "[")?; | |
| for i in self.begin..self.end { | |
| if i != self.begin { | |
| write!(f, ", ")?; | |
| } | |
| write!(f, "{}", self.arr[i])?; | |
| } | |
| write!(f, "]") | |
| } | |
| } | |
| impl<T> Worker<T> | |
| where T: Copy + Default + Ord + Display { | |
| pub fn new(arr: Vec<T>) -> Worker<T> { | |
| let n = arr.len(); | |
| Worker { | |
| arr, | |
| begin: 0, | |
| end: n, | |
| } | |
| } | |
| fn empty(&self) -> bool { | |
| return self.begin >= self.end | |
| } | |
| fn rand(&self) -> Option<T> { | |
| if self.empty() { | |
| None | |
| } else { | |
| Some(self.arr[rand::thread_rng().gen_range(self.begin..self.end)]) | |
| } | |
| } | |
| fn set_interval(&mut self, begin: usize, end: usize) { | |
| (self.begin, self.end) = (begin, end); | |
| } | |
| pub fn reset_interval(&mut self) { | |
| self.begin = 0; | |
| self.end = self.arr.len(); | |
| } | |
| fn swap(&mut self, i: usize, j: usize) { | |
| if i != j { | |
| (self.arr[i], self.arr[j]) = (self.arr[j], self.arr[i]); | |
| } | |
| } | |
| fn split(&mut self, pivot: T) -> ((usize, usize), (usize, usize)) { | |
| if self.empty() { | |
| return ((self.begin, self.end), (self.begin, self.end)) | |
| } | |
| // le&eq&i | |
| // [ ] [ ] [ ] [ ] [ ] [ ] [ ] [ ] | |
| // le eq i | |
| // [-] [-] [=] [+] [+] [ ] [ ] [ ] | |
| // le eq i | |
| // [-] [-] [-] [=] [=] [=] [+] [+] | |
| let mut le_end = self.begin; | |
| let mut eq_end: usize = self.begin; | |
| for i in self.begin..self.end { | |
| if self.arr[i] < pivot { | |
| self.swap(i, eq_end); | |
| self.swap(eq_end, le_end); | |
| eq_end += 1; | |
| le_end += 1; | |
| } else if self.arr[i] == pivot { | |
| self.swap(i, eq_end); | |
| eq_end += 1; | |
| } | |
| } | |
| ((self.begin, le_end), (eq_end, self.end)) | |
| } | |
| pub fn find_kth(workers: &mut Vec<Worker<T>>, k: usize) -> Option<T> { | |
| let mut k = k; | |
| loop { | |
| println!("k: {k}"); | |
| println!("workers:"); | |
| workers.iter().for_each(|w| println!("{w}")); | |
| let pivots: Vec<_> = workers | |
| .iter() | |
| .map(|w| w.rand()) | |
| .filter(|v| v.is_some()) | |
| .map(|v| v.unwrap()) | |
| .collect(); | |
| if pivots.is_empty() { | |
| return None; | |
| } | |
| let pivot = pivots[rand::thread_rng().gen_range(0..pivots.len())]; | |
| println!("pivot: {pivot}"); | |
| let splits: Vec<_> = workers | |
| .iter_mut() | |
| .map(|w| w.split(pivot)) | |
| .collect(); | |
| let less_count: usize = splits | |
| .iter() | |
| .map(|s| s.0.1 - s.0.0) | |
| .sum(); | |
| let equal_count: usize = splits | |
| .iter() | |
| .map(|s| s.1.0 - s.0.1) | |
| .sum(); | |
| if k < less_count { | |
| workers | |
| .iter_mut() | |
| .zip(splits.iter()) | |
| .for_each(|(w, ((begin, end), _))| w.set_interval(*begin, *end)) | |
| } else if k < less_count + equal_count { | |
| return Some(pivot) | |
| } else { | |
| k -= less_count + equal_count; | |
| workers | |
| .iter_mut() | |
| .zip(splits.iter()) | |
| .for_each(|(w, (_, (begin, end)))| w.set_interval(*begin, *end)) | |
| } | |
| } | |
| } | |
| } | |
#[cfg(test)]
mod tests {
    use super::Worker;
    use rand::{thread_rng, Rng};

    /// Distributes `values` at random over `buckets` workers.
    fn scatter(values: impl IntoIterator<Item = i32>, buckets: usize) -> Vec<Worker<i32>> {
        let mut vecs: Vec<Vec<i32>> = vec![Vec::new(); buckets];
        let mut rng = thread_rng();
        for value in values {
            vecs[rng.gen_range(0..buckets)].push(value);
        }
        vecs.into_iter().map(Worker::new).collect()
    }

    #[test]
    fn test_split() {
        let mut worker = Worker::new(vec![5, 4, 3, 2, 1]);
        // 5 4 3 2 1
        // 3 4 5 2 1
        // 2 3 5 4 1
        // 2 1 3 4 5
        assert_eq!(worker.split(3), ((0, 2), (3, 5)));
        assert_eq!(worker.arr, vec![2, 1, 3, 4, 5]);
    }

    #[test]
    fn test_single_worker() {
        let mut workers = vec![Worker::new(vec![1, 2, 3])];
        assert_eq!(Worker::find_kth(&mut workers, 1), Some(2));
    }

    #[test]
    fn test_multi_worker() {
        let k = 49;
        let mut workers = scatter(1..=100, 10);
        assert_eq!(Worker::find_kth(&mut workers, k), Some(50));
    }

    #[test]
    fn test_all_same() {
        let mut workers = vec![Worker::new(vec![1, 1, 1, 1, 1])];
        assert_eq!(Worker::find_kth(&mut workers, 2), Some(1));
    }

    #[test]
    fn test_clusters() {
        // Ten copies of each value 1..=10, so ranks 40..=49 are 5 and 50..=59 are 6.
        let values = (1..=10).flat_map(|i| std::iter::repeat(i).take(10));
        let mut workers = scatter(values, 10);
        assert_eq!(Worker::find_kth(&mut workers, 49), Some(5));
        // A query leaves the intervals narrowed; restore them before the next one.
        workers.iter_mut().for_each(|w| w.reset_interval());
        assert_eq!(Worker::find_kth(&mut workers, 50), Some(6));
    }

    #[test]
    fn test_singletons() {
        let mut workers: Vec<_> = (0..5).map(|_| Worker::new(vec![1, 1, 1, 1, 1])).collect();
        assert_eq!(Worker::find_kth(&mut workers, 12), Some(1));
    }

    #[test]
    fn test_k_out_of_range() {
        let mut workers = vec![Worker::new(vec![1, 2, 3])];
        assert_eq!(Worker::find_kth(&mut workers, 3), None);
    }

    #[test]
    fn test_no_workers() {
        let mut workers: Vec<Worker<i32>> = Vec::new();
        assert_eq!(Worker::find_kth(&mut workers, 0), None);
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment