Last active
August 22, 2021 09:24
-
-
Save teryror/3d52a64a7081257503dd0787a47c3f21 to your computer and use it in GitHub Desktop.
Revisions
-
teryror revised this gist
Aug 22, 2021 . 1 changed file with 2 additions and 0 deletions.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -166,6 +166,8 @@ impl<T, const N: usize> Distribution<T> for PopulationTable<T, N> where T: Clone } } // TODO(macro_metavar_expr): The expansion of this macro will contain the weight array // twice to automatically determine its length, which is literally redundant work. macro_rules! population_table { ($v:vis $name:ident : $t:ty = [ $( $weight:expr => $item:expr ),+ $(,)? ] ) => { $v const $name: PopulationTable<$t, {[$($weight),*].len()}> = PopulationTable::new( -
teryror revised this gist
Aug 22, 2021 . 1 changed file with 11 additions and 11 deletions.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -5,7 +5,7 @@ /// finite probability distribution in O(1) time, by first simulating a fair /// n-sided die, followed by a biased coin. /// /// Because floating point arithmetic cannot be used in const functions, this is /// built to operate on integer weights, rather than precomputed probabilities. /// /// Where the standard Alias Method scales the probabilities by a factor of n @@ -64,9 +64,8 @@ impl<const N: usize> AliasTable<N> { let mut total_weight = 0; let mut i = 0; while i < N { // TODO(const_panic): assert_ne!(weights[i], 0, "Weight at position {} is zero!", i); let _ = 1 / weights[i]; total_weight += weights[i]; i += 1; @@ -76,7 +75,7 @@ impl<const N: usize> AliasTable<N> { let weight_factor = rescaled_total / total_weight; let mut i = 0; while i < N { weights[i] *= weight_factor; i += 1; } @@ -167,11 +166,9 @@ impl<T, const N: usize> Distribution<T> for PopulationTable<T, N> where T: Clone } } macro_rules! population_table { ($v:vis $name:ident : $t:ty = [ $( $weight:expr => $item:expr ),+ $(,)? ] ) => { $v const $name: PopulationTable<$t, {[$($weight),*].len()}> = PopulationTable::new( [$($item),*], [$($weight),*] ); } @@ -181,11 +178,11 @@ population_table! { NAME_TABLE: &'static str = [ 2 => "Alice", 1 => "Bob", 3 => "Charlie", ] } pub fn main() { assert_eq!(NAME_TABLE.distr.alias, [0, 2, 2]); assert_eq!(NAME_TABLE.distr.prob, [u32::MAX, 1 << 31, u32::MAX]); @@ -194,7 +191,10 @@ fn main() { println!("Hello, {}!", name); } #[cfg(test)] mod test { use super::*; #[test] fn greatest_common_divisor() { assert_eq!(gcd(2, 4), 2); -
teryror revised this gist
Aug 22, 2021 . 1 changed file with 48 additions and 16 deletions.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -26,9 +26,9 @@ use rand::distributions::Distribution; const fn gcd(mut a: u32, mut b: u32) -> u32 { while b != 0 { let t = b; b = a % b; a = t; } a @@ -38,14 +38,13 @@ const fn lcm(a: u32, b: u32) -> u32 { (a * b) / gcd(a, b) } pub struct AliasTable<const N: usize> { prob: [u32; N], alias: [usize; N], } impl<const N: usize> AliasTable<N> { pub const fn new(mut weights: [u32; N]) -> Self { let mut prob = [0; N]; let mut alias = [0; N]; @@ -134,28 +133,61 @@ impl<T, const N: usize> PopulationTable<T, N> { alias[l] = l; } AliasTable { prob, alias } } } impl<const N: usize> Distribution<usize> for AliasTable<N> { fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize { let i = rng.gen_range(0..N); let x = rng.gen::<u32>(); if x < self.prob[i] { i } else { self.alias[i] } } } pub struct PopulationTable<T, const N: usize> { items: [T; N], distr: AliasTable<N>, } impl<T, const N: usize> PopulationTable<T, N> { pub const fn new(items: [T; N], weights: [u32; N]) -> Self { PopulationTable { items, distr: AliasTable::new(weights) } } } impl<T, const N: usize> Distribution<T> for PopulationTable<T, N> where T: Clone { fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> T { let idx = rng.sample(&self.distr); self.items[idx].clone() } } // TODO(macro_metavar_expr): The expansion of this macro will contain the weight array // twice to automatically determine its length, which is literally redundant work. macro_rules! population_table { ($name:ident : $t:ty = [ $( $weight:expr => $item:expr ),+ ] ) => { const $name: PopulationTable<$t, {[$($weight),*].len()}> = PopulationTable::new( [$($item),*], [$($weight),*] ); } } population_table! 
{ NAME_TABLE: &'static str = [ 2 => "Alice", 1 => "Bob", 3 => "Charlie" ] } fn main() { assert_eq!(NAME_TABLE.distr.alias, [0, 2, 2]); assert_eq!(NAME_TABLE.distr.prob, [u32::MAX, 1 << 31, u32::MAX]); let mut rng = thread_rng(); let name = rng.sample(&NAME_TABLE); -
teryror created this gist
Aug 21, 2021 .There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,179 @@
/// Const evaluatable Rust implementation of Vose's Alias Method, as described
/// by Keith Schwarz at https://www.keithschwarz.com/darts-dice-coins/
///
/// In brief, this is an O(n) precomputation, which allows sampling an arbitrary
/// finite probability distribution in O(1) time, by first simulating a fair
/// n-sided die, followed by a biased coin.
///
/// Because floating point arithmetic cannot be used in const context, this is
/// built to operate on integer weights, rather than precomputed probabilities.
///
/// Where the standard Alias Method scales the probabilities by a factor of n
/// and uses 1 as a cutoff to partition them into large and small probabilities,
/// this finds the least common multiple of n and the total weight, scales up
/// the weights to match it, and uses the LCM divided by n as the threshold.
///
/// Unlike the original method, this approach is perfectly exact and numerically
/// stable; I only switch to fixed point arithmetic for the final probability
/// calculation, which introduces negligible rounding errors.
///
/// The implementation could be made much more elegant as more language features
/// become available in const fns, most notably for loops, panics, and arguments
/// of mutable reference types.
use rand::{Rng, thread_rng}; use rand::distributions::Distribution; const fn gcd(mut a: u32, mut b: u32) -> u32 { while b != 0 { let tmp = b; b = a % b; a = tmp; } a } const fn lcm(a: u32, b: u32) -> u32 { (a * b) / gcd(a, b) } pub struct PopulationTable<T, const N: usize> { items: [T; N], prob: [u32; N], alias: [usize; N], } impl<T, const N: usize> PopulationTable<T, N> { pub const fn new(items: [T; N], mut weights: [u32; N]) -> Self { let mut prob = [0; N]; let mut alias = [0; N]; // Vec and similar data structures cannot be used in const fns, because // only other const fns may be called, which may not take &mut arguments. // So we have to use an ad-hoc, inline implementation for the work lists. // // These could have capacity N - 1, except the current state of const // generics doesn't allow that. let mut small = [0; N]; let mut small_count = 0; let mut large = [0; N]; let mut large_count = 0; let mut total_weight = 0; let mut i = 0; while i < N { // TODO(const_panic): assert_ne!(weights[i], 0); // Such an assertion is not strictly necessary, but keeping items // with probability 0 unnecessarily wastes memory. 
total_weight += weights[i]; i += 1; } let rescaled_total = lcm(total_weight, N as u32); let weight_factor = rescaled_total / total_weight; let mut i = 0; while i < N { weights[i] = weights[i] * weight_factor; i += 1; } let weight_threshold = rescaled_total / (N as u32); let mut i = 0; while i < N { if weights[i] < weight_threshold { small[small_count] = i; small_count += 1; } else { large[large_count] = i; large_count += 1; } i += 1; } while small_count > 0 && large_count > 0 { small_count -= 1; let l = small[small_count]; large_count -= 1; let g = large[large_count]; prob[l] = (((weights[l] as u64) << 32) / (weight_threshold as u64)) as u32; alias[l] = g; weights[g] -= weight_threshold - weights[l]; if weights[g] < weight_threshold { small[small_count] = g; small_count += 1; } else { large[large_count] = g; large_count += 1; } } while large_count > 0 { large_count -= 1; let g = large[large_count]; prob[g] = u32::MAX; alias[g] = g; } // TODO(const_panic): assert_eq!(small_count, 0); // This should only be possible with floating point arithmetic // due to numerical instability; we should be fine without this loop: while small_count > 0 { small_count -= 1; let l = small[small_count]; prob[l] = u32::MAX; alias[l] = l; } PopulationTable { items, prob, alias } } } impl<T, const N: usize> Distribution<T> for PopulationTable<T, N> where T: Clone { fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> T { let i = rng.gen_range(0..N); let x = rng.gen::<u32>(); if self.prob[i] < x { self.items[i].clone() } else { let a = self.alias[i]; self.items[a].clone() } } } const NAME_TABLE: PopulationTable<&'static str, 3> = PopulationTable::new(["Alice", "Bob", "Charlie"], [2, 1, 3]); fn main() { assert_eq!(NAME_TABLE.alias, [0, 2, 2]); assert_eq!(NAME_TABLE.prob, [u32::MAX, 1 << 31, u32::MAX]); let mut rng = thread_rng(); let name = rng.sample(&NAME_TABLE); println!("Hello, {}!", name); } mod test { #[test] fn greatest_common_divisor() { assert_eq!(gcd(2, 4), 2); assert_eq!(gcd(2, 5), 
1); assert_eq!(gcd(252, 105), 21); } #[test] fn least_common_multiple() { assert_eq!(lcm(2, 4), 4); assert_eq!(lcm(2, 5), 10); assert_eq!(lcm(18, 12), 36); } }