Skip to content

Instantly share code, notes, and snippets.

@lmc2179
Created February 13, 2015 13:21
Show Gist options
  • Select an option

  • Save lmc2179/48feb47a8ecc94bb0aa6 to your computer and use it in GitHub Desktop.

Select an option

Save lmc2179/48feb47a8ecc94bb0aa6 to your computer and use it in GitHub Desktop.

Revisions

  1. lmc2179 created this gist Feb 13, 2015.
    22 changes: 22 additions & 0 deletions multinomial_sample.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,22 @@
    import random
    import partition_tree

    class Multinomial_Sampler(object):
        """Draw labelled samples from a discrete (multinomial) distribution.

        The unit interval is cut into consecutive sub-intervals, one per
        event, whose widths are the event probabilities.  A uniform random
        number is then mapped back to the label of the interval it falls in.
        """

        # Largest deviation of sum(probabilities) from 1.0 that is still
        # accepted; absorbs ordinary floating-point rounding such as
        # sum([0.1] * 10) == 0.9999999999999999.
        _SUM_TOLERANCE = 1e-9

        def __init__(self, probabilities, event_names):
            """Build the lookup tree.

            probabilities -- non-negative numbers summing to 1.0
            event_names   -- labels, paired with probabilities by position

            Raises ValueError if the probabilities are invalid.
            """
            intervals = self._build_intervals_from_probabilities(probabilities)
            self.tree = partition_tree.PartitionTree(intervals, event_names)

        def _build_intervals_from_probabilities(self, probabilities):
            """Return consecutive (left, right) tuples covering [0.0, 1.0].

            Raises ValueError when a probability is negative or when the
            probabilities do not sum to 1.0 within _SUM_TOLERANCE.
            """
            # Materialise once so generators can be both validated and walked.
            probabilities = list(probabilities)
            if any(p < 0 for p in probabilities):
                raise ValueError("probabilities must be non-negative")
            total = sum(probabilities)
            # Exact comparison with 1.0 would reject valid inputs whose
            # floating-point sum is off by one rounding step.
            if abs(total - 1.0) > self._SUM_TOLERANCE:
                raise ValueError(
                    "probabilities must sum to 1.0 (got %r)" % (total,))

            intervals = []
            left_side = 0.0
            for p in probabilities:
                right_side = left_side + p
                intervals.append((left_side, right_side))
                left_side = right_side

            # Pin the final edge to 1.0 so accumulated rounding cannot leave
            # a sliver below 1.0 that no interval covers.
            last_left = intervals[-1][0]
            intervals[-1] = (last_left, max(last_left, 1.0))
            return intervals

        def sample(self):
            """Return the label of one randomly drawn event."""
            random_0_1 = random.random()
            return self.tree.get_label(random_0_1)
    38 changes: 38 additions & 0 deletions partition_tree.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,38 @@
    class PartitionTreeNode(object):
        """One node of a PartitionTree.

        Holds an interval plus references to the left and right child
        nodes.  All three attributes default to None.
        """

        def __init__(self, left=None, right=None, interval=None):
            self.left, self.right, self.interval = left, right, interval

    class PartitionTree(object):
        """Binary search tree over non-overlapping, labelled intervals.

        Intervals are (left, right) tuples.  Lookup treats both bounds as
        inclusive; a number sitting on a boundary shared by two intervals
        resolves to whichever of them is met first on the way down from
        the root.
        """

        def __init__(self, intervals, labels):
            """Insert each interval and remember its label.

            intervals -- iterable of hashable (left, right) tuples
            labels    -- iterable of labels, paired with intervals by position

            Raises ValueError if two intervals overlap.
            """
            self.mapping = {}
            self.root = PartitionTreeNode()
            for interval, label in zip(intervals, labels):
                self._add_interval(interval, self.root)
                self.mapping[interval] = label

        def _add_interval(self, interval, node):
            """Store interval in the first empty node found below node.

            Implemented as a loop rather than recursion: intervals inserted
            in ascending order form a chain as deep as the number of
            intervals, which would exceed the interpreter recursion limit.
            """
            while node.interval is not None:
                if interval[1] <= node.interval[0]:
                    node = node.left
                elif interval[0] >= node.interval[1]:
                    node = node.right
                else:
                    raise ValueError(
                        "interval %r overlaps existing interval %r"
                        % (interval, node.interval))
            node.interval = interval
            # Every filled node gets two empty children to descend into.
            node.left = PartitionTreeNode()
            node.right = PartitionTreeNode()

        def get_label(self, number):
            """Return the label of the interval containing number.

            Raises ValueError if no interval contains number.
            """
            interval = self._get_interval(number, self.root)
            return self.mapping[interval]

        def _get_interval(self, number, node):
            """Return the interval at or below node that contains number.

            Raises ValueError when the search reaches an empty node, i.e.
            when number is not covered by any stored interval.
            """
            while node is not None and node.interval is not None:
                left_bound, right_bound = node.interval
                if number < left_bound:
                    node = node.left
                elif number > right_bound:
                    node = node.right
                else:
                    return node.interval
            raise ValueError("%r is not covered by any interval" % (number,))
    27 changes: 27 additions & 0 deletions test.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,27 @@
    import unittest
    import partition_tree
    import sampler

    class PartitionTreeTrest(unittest.TestCase):
        """Checks PartitionTree label lookup, including boundary values."""

        def test_partition_tree(self):
            tree = partition_tree.PartitionTree(
                [(0.0, 0.5), (0.5, 1.0)], ['A', 'B'])
            values = [0.0, 0.3, 0.5, 0.7, 1.0]
            labels = [tree.get_label(v) for v in values]
            # 0.5 is shared by both intervals; 'A' is expected there.
            correct_labels = ['A', 'A', 'A', 'B', 'B']
            # assertEqual survives `python -O` and prints a diff on failure,
            # unlike a bare assert statement.
            self.assertEqual(labels, correct_labels)

    class MultinomialSampleTest(unittest.TestCase):
        """Statistical check that sampled frequencies match the probabilities."""

        def test_biased_coin_flip(self):
            from collections import Counter

            true_heads, true_tails = 0.3, 0.7
            P = [true_heads, true_tails]
            event_names = ['Heads', 'Tails']
            s = sampler.Multinomial_Sampler(P, event_names)

            total_samples = 400000
            sample_counter = Counter(s.sample() for _ in range(total_samples))
            head_frequency = float(sample_counter['Heads']) / total_samples
            tail_frequency = float(sample_counter['Tails']) / total_samples

            # The observed frequency has standard deviation
            # sqrt(p * (1 - p) / n), about 0.00072 here.  A fixed tolerance
            # of 0.001 is only ~1.4 standard deviations and fails in roughly
            # one run out of six; five standard deviations makes a spurious
            # failure vanishingly unlikely while still catching a wrong bias.
            allowed_error = 5 * (true_heads * true_tails / total_samples) ** 0.5

            self.assertAlmostEqual(head_frequency, true_heads, delta=allowed_error)
            self.assertAlmostEqual(tail_frequency, true_tails, delta=allowed_error)