Created
February 13, 2015 13:21
-
-
Save lmc2179/48feb47a8ecc94bb0aa6 to your computer and use it in GitHub Desktop.
Revisions
-
lmc2179 created this gist
Feb 13, 2015 .There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,22 @@ import random import partition_tree class Multinomial_Sampler(object): def __init__(self, probabilities, event_names): intervals = self._build_intervals_from_probabilities(probabilities) self.tree = partition_tree.PartitionTree(intervals, event_names) def _build_intervals_from_probabilities(self, probabilities): if sum(probabilities) != 1.0: raise Exception intervals = [] left_side = 0.0 for p in probabilities: intervals.append((left_side, left_side+p)) left_side += p return intervals def sample(self): random_0_1 = random.random() return self.tree.get_label(random_0_1) This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,38 @@ class PartitionTreeNode(object): def __init__(self, left=None, right=None, interval=None): self.left = left self.right = right self.interval = interval class PartitionTree(object): def __init__(self, intervals, labels): self.mapping = {} self.root = PartitionTreeNode() for interval, label in zip(intervals, labels): self._add_interval(interval, self.root) self.mapping[interval] = label def _add_interval(self, interval, node): if not node.interval: node.interval = interval node.left = PartitionTreeNode() node.right = PartitionTreeNode() elif interval[1] <= node.interval[0]: self._add_interval(interval, node.left) elif interval[0] >= node.interval[1]: self._add_interval(interval, node.right) else: raise Exception def get_label(self, number): interval = self._get_interval(number, self.root) return self.mapping[interval] def _get_interval(self, number, node): left_bound, right_bound = node.interval if number < left_bound: return self._get_interval(number, node.left) elif number > right_bound: return self._get_interval(number, node.right) else: return node.interval This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,27 @@ import unittest import partition_tree import sampler class PartitionTreeTrest(unittest.TestCase): def test_partition_tree(self): tree = partition_tree.PartitionTree([(0.0,0.5),(0.5,1.0)],['A', 'B']) values = [0.0, 0.3, 0.5, 0.7, 1.0] labels = [tree.get_label(v) for v in values] correct_labels = ['A', 'A', 'A', 'B', 'B'] assert labels == correct_labels class MultinomialSampleTest(unittest.TestCase): def test_biased_coin_flip(self): true_heads, true_tails = 0.3, 0.7 P = [true_heads, true_tails] event_names = ['Heads', 'Tails'] s = sampler.Multinomial_Sampler(P, event_names) from collections import Counter total_samples = 400000 sample_counter = Counter([s.sample() for i in range(total_samples)]) allowed_error = 0.001 head_frequency = 1.0*sample_counter['Heads']/total_samples tail_frequency = 1.0*sample_counter['Tails']/total_samples print(head_frequency, tail_frequency) assert true_heads - allowed_error <= head_frequency <= true_heads + allowed_error assert true_tails - allowed_error <= tail_frequency <= true_tails + allowed_error