Skip to content

Instantly share code, notes, and snippets.

@teryror
Last active August 22, 2021 09:24
Show Gist options
  • Select an option

  • Save teryror/3d52a64a7081257503dd0787a47c3f21 to your computer and use it in GitHub Desktop.

Select an option

Save teryror/3d52a64a7081257503dd0787a47c3f21 to your computer and use it in GitHub Desktop.

Revisions

  1. teryror revised this gist Aug 22, 2021. 1 changed file with 2 additions and 0 deletions.
    2 changes: 2 additions & 0 deletions constexpr_alias_method.rs
    Original file line number Diff line number Diff line change
    @@ -166,6 +166,8 @@ impl<T, const N: usize> Distribution<T> for PopulationTable<T, N> where T: Clone
    }
    }

    // TODO(macro_metavar_expr): The expansion of this macro will contain the weight array
    // twice to automatically determine its length, which is literally redundant work.
    macro_rules! population_table {
    ($v:vis $name:ident : $t:ty = [ $( $weight:expr => $item:expr ),+ $(,)? ] ) => {
    $v const $name: PopulationTable<$t, {[$($weight),*].len()}> = PopulationTable::new(
  2. teryror revised this gist Aug 22, 2021. 1 changed file with 11 additions and 11 deletions.
    22 changes: 11 additions & 11 deletions constexpr_alias_method.rs
    Original file line number Diff line number Diff line change
    @@ -5,7 +5,7 @@
    /// finite probability distribution in O(1) time, by first simulating a fair
    /// n-sided die, followed by a biased coin.
    ///
    /// Because floating point arithmetic cannot be used in const context, this is
    /// Because floating point arithmetic cannot be used in const functions, this is
    /// built to operate on integer weights, rather than precomputed probabilities.
    ///
    /// Where the standard Alias Method scales the probabilities by a factor of n
    @@ -64,9 +64,8 @@ impl<const N: usize> AliasTable<N> {
    let mut total_weight = 0;
    let mut i = 0;
    while i < N {
    // TODO(const_panic): assert_ne!(weights[i], 0);
    // Such an assertion is not strictly necessary, but keeping items
    // with probability 0 unnecessarily wastes memory.
    // TODO(const_panic): assert_ne!(weights[i], 0, "Weight at position {} is zero!", i);
    let _ = 1 / weights[i];

    total_weight += weights[i];
    i += 1;
    @@ -76,7 +75,7 @@ impl<const N: usize> AliasTable<N> {
    let weight_factor = rescaled_total / total_weight;
    let mut i = 0;
    while i < N {
    weights[i] = weights[i] * weight_factor;
    weights[i] *= weight_factor;
    i += 1;
    }

    @@ -167,11 +166,9 @@ impl<T, const N: usize> Distribution<T> for PopulationTable<T, N> where T: Clone
    }
    }

    // TODO(macro_metavar_expr): The expansion of this macro will contain the weight array
    // twice to automatically determine its length, which is literally redundant work.
    macro_rules! population_table {
    ($name:ident : $t:ty = [ $( $weight:expr => $item:expr ),+ ] ) => {
    const $name: PopulationTable<$t, {[$($weight),*].len()}> = PopulationTable::new(
    ($v:vis $name:ident : $t:ty = [ $( $weight:expr => $item:expr ),+ $(,)? ] ) => {
    $v const $name: PopulationTable<$t, {[$($weight),*].len()}> = PopulationTable::new(
    [$($item),*], [$($weight),*]
    );
    }
    @@ -181,11 +178,11 @@ population_table! {
    NAME_TABLE: &'static str = [
    2 => "Alice",
    1 => "Bob",
    3 => "Charlie"
    3 => "Charlie",
    ]
    }

    fn main() {
    pub fn main() {
    assert_eq!(NAME_TABLE.distr.alias, [0, 2, 2]);
    assert_eq!(NAME_TABLE.distr.prob, [u32::MAX, 1 << 31, u32::MAX]);

    @@ -194,7 +191,10 @@ fn main() {
    println!("Hello, {}!", name);
    }

    #[cfg(test)]
    mod test {
    use super::*;

    #[test]
    fn greatest_common_divisor() {
    assert_eq!(gcd(2, 4), 2);
  3. teryror revised this gist Aug 22, 2021. 1 changed file with 48 additions and 16 deletions.
    64 changes: 48 additions & 16 deletions constexpr_alias_method.rs
    Original file line number Diff line number Diff line change
    @@ -26,9 +26,9 @@ use rand::distributions::Distribution;

    const fn gcd(mut a: u32, mut b: u32) -> u32 {
    while b != 0 {
    let tmp = b;
    let t = b;
    b = a % b;
    a = tmp;
    a = t;
    }

    a
    @@ -38,14 +38,13 @@ const fn lcm(a: u32, b: u32) -> u32 {
    (a * b) / gcd(a, b)
    }

    pub struct PopulationTable<T, const N: usize> {
    items: [T; N],
    pub struct AliasTable<const N: usize> {
    prob: [u32; N],
    alias: [usize; N],
    }

    impl<T, const N: usize> PopulationTable<T, N> {
    pub const fn new(items: [T; N], mut weights: [u32; N]) -> Self {
    impl<const N: usize> AliasTable<N> {
    pub const fn new(mut weights: [u32; N]) -> Self {
    let mut prob = [0; N];
    let mut alias = [0; N];

    @@ -134,28 +133,61 @@ impl<T, const N: usize> PopulationTable<T, N> {
    alias[l] = l;
    }

    PopulationTable { items, prob, alias }
    AliasTable { prob, alias }
    }
    }

    impl<T, const N: usize> Distribution<T> for PopulationTable<T, N> where T: Clone {
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> T {
    impl<const N: usize> Distribution<usize> for AliasTable<N> {
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
    let i = rng.gen_range(0..N);
    let x = rng.gen::<u32>();
    if self.prob[i] < x {
    self.items[i].clone()
    if x < self.prob[i] {
    i
    } else {
    let a = self.alias[i];
    self.items[a].clone()
    self.alias[i]
    }
    }
    }

    const NAME_TABLE: PopulationTable<&'static str, 3> = PopulationTable::new(["Alice", "Bob", "Charlie"], [2, 1, 3]);
    pub struct PopulationTable<T, const N: usize> {
    items: [T; N],
    distr: AliasTable<N>,
    }

    impl<T, const N: usize> PopulationTable<T, N> {
    pub const fn new(items: [T; N], weights: [u32; N]) -> Self {
    PopulationTable { items, distr: AliasTable::new(weights) }
    }
    }

    impl<T, const N: usize> Distribution<T> for PopulationTable<T, N> where T: Clone {
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> T {
    let idx = rng.sample(&self.distr);
    self.items[idx].clone()
    }
    }

    // TODO(macro_metavar_expr): The expansion of this macro will contain the weight array
    // twice to automatically determine its length, which is literally redundant work.
    macro_rules! population_table {
    ($name:ident : $t:ty = [ $( $weight:expr => $item:expr ),+ ] ) => {
    const $name: PopulationTable<$t, {[$($weight),*].len()}> = PopulationTable::new(
    [$($item),*], [$($weight),*]
    );
    }
    }

    population_table! {
    NAME_TABLE: &'static str = [
    2 => "Alice",
    1 => "Bob",
    3 => "Charlie"
    ]
    }

    fn main() {
    assert_eq!(NAME_TABLE.alias, [0, 2, 2]);
    assert_eq!(NAME_TABLE.prob, [u32::MAX, 1 << 31, u32::MAX]);
    assert_eq!(NAME_TABLE.distr.alias, [0, 2, 2]);
    assert_eq!(NAME_TABLE.distr.prob, [u32::MAX, 1 << 31, u32::MAX]);

    let mut rng = thread_rng();
    let name = rng.sample(&NAME_TABLE);
  4. teryror created this gist Aug 21, 2021.
    179 changes: 179 additions & 0 deletions constexpr_alias_method.rs
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,179 @@
    /// Const evaluatable Rust implementation of Vose's Alias Method, as described
    /// by Keith Schwarz at https://www.keithschwarz.com/darts-dice-coins/
    ///
    /// In brief, this is an O(n) precomputation, which allows sampling an arbitrary
    /// finite probability distribution in O(1) time, by first simulating a fair
    /// n-sided die, followed by a biased coin.
    ///
    /// Because floating point arithmetic cannot be used in const context, this is
    /// built to operate on integer weights, rather than precomputed probabilities.
    ///
    /// Where the standard Alias Method scales the probabilities by a factor of n
    /// and uses 1 as a cutoff to partition them into large and small probabilites,
    /// this finds the least common multiple of n and the total weight, scales up
    /// the weights to match it, and uses the LCM divided by n as the threshold.
    ///
    /// Unlike the original method, this approach is perfectly exact and numerically
    /// stable; I only switch to fixed point arithmetic for the final probability
    /// calculation, which introduces negligible rounding errors.
    ///
    /// The implementation could be made much more elegant as more language features
    /// become available in const fns, most notably for loops, panics, and arguments
    /// of mutable reference types.
    use rand::{Rng, thread_rng};
    use rand::distributions::Distribution;

    const fn gcd(mut a: u32, mut b: u32) -> u32 {
    while b != 0 {
    let tmp = b;
    b = a % b;
    a = tmp;
    }

    a
    }

    const fn lcm(a: u32, b: u32) -> u32 {
    (a * b) / gcd(a, b)
    }

    pub struct PopulationTable<T, const N: usize> {
    items: [T; N],
    prob: [u32; N],
    alias: [usize; N],
    }

    impl<T, const N: usize> PopulationTable<T, N> {
    pub const fn new(items: [T; N], mut weights: [u32; N]) -> Self {
    let mut prob = [0; N];
    let mut alias = [0; N];

    // Vec and similar data structures cannot be used in const fns, because
    // only other const fns may be called, which may not take &mut arguments.
    // So we have to use an ad-hoc, inline implementation for the work lists.
    //
    // These could have capacity N - 1, except the current state of const
    // generics doesn't allow that.

    let mut small = [0; N];
    let mut small_count = 0;

    let mut large = [0; N];
    let mut large_count = 0;

    let mut total_weight = 0;
    let mut i = 0;
    while i < N {
    // TODO(const_panic): assert_ne!(weights[i], 0);
    // Such an assertion is not strictly necessary, but keeping items
    // with probability 0 unnecessarily wastes memory.

    total_weight += weights[i];
    i += 1;
    }

    let rescaled_total = lcm(total_weight, N as u32);
    let weight_factor = rescaled_total / total_weight;
    let mut i = 0;
    while i < N {
    weights[i] = weights[i] * weight_factor;
    i += 1;
    }

    let weight_threshold = rescaled_total / (N as u32);

    let mut i = 0;
    while i < N {
    if weights[i] < weight_threshold {
    small[small_count] = i;
    small_count += 1;
    } else {
    large[large_count] = i;
    large_count += 1;
    }
    i += 1;
    }

    while small_count > 0 && large_count > 0 {
    small_count -= 1;
    let l = small[small_count];

    large_count -= 1;
    let g = large[large_count];

    prob[l] = (((weights[l] as u64) << 32) / (weight_threshold as u64)) as u32;
    alias[l] = g;

    weights[g] -= weight_threshold - weights[l];
    if weights[g] < weight_threshold {
    small[small_count] = g;
    small_count += 1;
    } else {
    large[large_count] = g;
    large_count += 1;
    }
    }

    while large_count > 0 {
    large_count -= 1;
    let g = large[large_count];

    prob[g] = u32::MAX;
    alias[g] = g;
    }

    // TODO(const_panic): assert_eq!(small_count, 0);
    // This should only be possible with floating point arithmetic
    // due to numerical instability; we should be fine without this loop:
    while small_count > 0 {
    small_count -= 1;
    let l = small[small_count];

    prob[l] = u32::MAX;
    alias[l] = l;
    }

    PopulationTable { items, prob, alias }
    }
    }

    impl<T, const N: usize> Distribution<T> for PopulationTable<T, N> where T: Clone {
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> T {
    let i = rng.gen_range(0..N);
    let x = rng.gen::<u32>();
    if self.prob[i] < x {
    self.items[i].clone()
    } else {
    let a = self.alias[i];
    self.items[a].clone()
    }
    }
    }

    const NAME_TABLE: PopulationTable<&'static str, 3> = PopulationTable::new(["Alice", "Bob", "Charlie"], [2, 1, 3]);

    fn main() {
    assert_eq!(NAME_TABLE.alias, [0, 2, 2]);
    assert_eq!(NAME_TABLE.prob, [u32::MAX, 1 << 31, u32::MAX]);

    let mut rng = thread_rng();
    let name = rng.sample(&NAME_TABLE);
    println!("Hello, {}!", name);
    }

    mod test {
    #[test]
    fn greatest_common_divisor() {
    assert_eq!(gcd(2, 4), 2);
    assert_eq!(gcd(2, 5), 1);
    assert_eq!(gcd(252, 105), 21);
    }

    #[test]
    fn least_common_multiple() {
    assert_eq!(lcm(2, 4), 4);
    assert_eq!(lcm(2, 5), 10);
    assert_eq!(lcm(18, 12), 36);
    }
    }