Skip to content

Instantly share code, notes, and snippets.

@lfloeer
Created May 4, 2015 15:34
Show Gist options
  • Select an option

  • Save lfloeer/d8a2d3ccc8898eeb02af to your computer and use it in GitHub Desktop.

Select an option

Save lfloeer/d8a2d3ccc8898eeb02af to your computer and use it in GitHub Desktop.
Discrete variables in emcee
import emcee
from scipy import stats
from scipy.special import binom
import numpy as np
import pylab as pl
def ln_prior(p):
    """Return the (unnormalised) log-prior of the parameter vector ``p``.

    ``p[0]`` (trials) is flat on the closed interval [5, 15]; ``p[1]``
    (probability per trial) carries a Beta(1/2, 1/2)-shaped density on the
    open interval (0, 1). Anything outside that support yields ``-inf``.
    """
    inside_support = 5 <= p[0] <= 15 and 0. < p[1] < 1.
    if not inside_support:
        return -np.inf
    prob = p[1]
    return -0.5 * np.log(prob) - 0.5 * np.log(1 - prob)
def ln_like(p, data):
    """Return the binomial log-likelihood of ``data`` summed over all points.

    The continuous first parameter ``p[0]`` is truncated with ``int`` to
    obtain the discrete number of trials; ``p[1]`` is the success
    probability per trial.
    """
    n_trials = int(p[0])
    prob = p[1]
    # log C(n, k) + (n - k) * log(1 - prob) + k * log(prob), per data point.
    log_coefficient = np.log(binom(n_trials, data))
    log_failures = (n_trials - data) * np.log(1 - prob)
    log_successes = data * np.log(prob)
    return np.sum(log_coefficient + log_failures + log_successes)
def ln_posterior(p, data):
    """Return the log-posterior: ``ln_prior(p) + ln_like(p, data)``.

    The likelihood is only evaluated when the prior is finite; otherwise the
    prior value itself is returned unchanged.
    """
    log_prior = ln_prior(p)
    if not np.isfinite(log_prior):
        return log_prior
    return log_prior + ln_like(p, data)
def plot_chain(chain, thin_walkers=10, burn=50):
    """Draw diagnostic figures for a sampler chain and show them.

    ``chain`` is indexed as ``chain[walker, step, parameter]``. The first
    figure is a 2x2 grid: for each of the two parameters, the traces of every
    ``thin_walkers``-th walker (left) and a histogram of all walkers at the
    last step (right). The second figure scatters parameter 0 against
    parameter 1 for the thinned walkers, dropping the first ``burn`` steps.
    """
    thinned = chain[::thin_walkers]
    # (parameter index, axis label, histogram bins) for each row of the grid.
    panels = (
        (0, 'Trials', np.arange(31)),
        (1, 'Probability', 50),
    )

    pl.figure()
    for dim, label, bins in panels:
        # Left column: walker traces over the steps.
        pl.subplot(2, 2, 2 * dim + 1)
        pl.plot(thinned[:, :, dim].T, alpha=0.1, color='k')
        pl.ylabel(label)
        # Right column: distribution of the walkers at the last step.
        pl.subplot(2, 2, 2 * dim + 2)
        pl.hist(chain[:, -1, dim], bins=bins)

    pl.figure()
    for walker in thinned:
        pl.plot(walker[burn:, 0], walker[burn:, 1], ls='None', marker='.', markeredgecolor='None')
    pl.xlabel('Trials')
    pl.ylabel('Probability')
    pl.show()
def main():
    """Simulate binomial data, sample its posterior with emcee, and plot it.

    Draws 2000 samples from Binomial(10, 0.3), runs a 500-walker ensemble
    sampler over the two parameters (trials, probability) for 500 steps,
    prints the 16th/50th/84th percentiles of the walkers' acceptance
    fractions and the sampler's autocorrelation estimate, then plots the
    chain.
    """
    n_walkers = 500
    n_dim = 2
    data = stats.binom(10, 0.3).rvs(2000)
    sampler = emcee.EnsembleSampler(n_walkers, n_dim, ln_posterior, args=[data])

    # Start the walkers uniformly in [5, 15) x [0.2, 0.4): unit-uniform draws
    # are stretched by ``scales`` and shifted by ``offsets`` per parameter.
    scales = np.array([[10., 0.2]])
    offsets = np.array([[5, 0.2]])
    pos0 = np.random.uniform(0, 1, (n_walkers, n_dim)) * scales + offsets
    sampler.run_mcmc(pos0, 500)

    # Parenthesised single-argument print works identically as a Python 2
    # statement and as the Python 3 function.
    print(np.percentile(sampler.acceptance_fraction, [16, 50, 84]))
    print(sampler.acor)
    plot_chain(sampler.chain)
# Script entry point: seed NumPy's global random generator before running
# main() so the simulated data and the walkers' starting positions are
# reproducible from run to run.
if __name__ == "__main__":
    np.random.seed(42)
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment