Skip to content

Instantly share code, notes, and snippets.

@BaxterEaves
Last active February 18, 2016 21:06
Show Gist options
  • Select an option

  • Save BaxterEaves/687ef5ed49a3f26d3196 to your computer and use it in GitHub Desktop.

Select an option

Save BaxterEaves/687ef5ed49a3f26d3196 to your computer and use it in GitHub Desktop.
Infinite Gaussian mixture model using the collapsed Gibbs sampler
"""
Copyright (C) 2016 Baxter Eaves
License: Do what the fuck you want to public license (WTFPL) V2
Infinite Gaussian Mixture Model
Requires: numpy, scipy, matplotlib, seaborn, pandas
"""
import numpy as np
import sys
from random import shuffle
from scipy.special import gammaln
from scipy.misc import logsumexp
from math import log
from math import pi
LOG2 = log(2)
LOG2PI = log(2*pi)
LOGSQRTPI = .5*log(pi)
LOGSQRT2PI = .5*log(2*pi)
class ConjugateMixtureComponent(object):
def __init__(self, x, *params):
raise NotImplementedError
def posterior_predictive(self, x):
raise NotImplementedError
def marginal_likelihood(self):
raise NotImplementedError
def insert_datum(x):
raise NotImplementedError
def remove_datum(x):
raise NotImplementedError
@property
def params(self):
raise NotImplementedError
class NIG(ConjugateMixtureComponent):
""" Normal inverse-Gamma component model """
def __init__(self, x, m, r, s, nu):
self.m = m
self.r = r
self.s = s
self.nu = nu
self.logz = self._psi(r, s, nu)
self.n = 0.
self.sum_x = 0.
self.sum_x_sq = 0.
# TODO: put this in superclass constructor?
for x_i in x:
self.insert_datum(x_i)
def posterior_predictive(self, x):
(mn, rn, sn, nun) = self._update_params(
self.n, self.sum_x, self.sum_x_sq, self.m, self.r, self.s, self.nu)
(_, rm, sm, num) = self._update_params(1, x, x*x, mn, rn, sn, nun)
return -LOGSQRT2PI + self._psi(rm, sm, num) - self._psi(rn, sn, nun)
def marginal_likelihood(self):
(mn, rn, sn, nun) = self._update_params(
self.n, self.sum_x, self.sum_x_sq, self.m, self.r, self.s, self.nu)
return -(self.n/2.)*LOG2PI + self._psi(rn, sn, nun) - self.logz
def insert_datum(self, x):
self.n += 1
self.sum_x += x
self.sum_x_sq += x*x
def remove_datum(self, x):
self.n -= 1
self.sum_x -= x
self.sum_x_sq -= x*x
@property
def params(self):
return self.m, self.r, self.s, self.nu
@staticmethod
def _psi(r, s, nu):
logz = (nu+1.)/2.*LOG2 + LOGSQRTPI - .5*log(r) - nu/2.*log(s)
logz += gammaln(nu/2.)
return logz
@staticmethod
def _update_params(n, sum_x, sum_x_sq, m, r, s, nu):
rn = r + n
nun = nu + n
mn = (r*m + sum_x)/rn
sn = s + sum_x_sq + r*m*m - rn*mn*mn
return mn, rn, sn, nun
class IGMM(object):
""" Infinite Gaussian Mixture Model """
def __init__(self, x, alpha, component_model, params, seqinit=False):
self.x = x
self.n = len(x)
self.alpha = alpha
self.params = params
self.model = component_model
if seqinit:
self.z = np.zeros(self.n, dtype=int)
self.nk = [1]
self.k = 1
self.components = [self.model([self.x[0]], *params)]
for i in range(1, self.n):
self.step(i, init_mode=True)
else:
self.z, self.nk, self.k = crp_gen(self.n, alpha)
self.components = []
for j in range(self.k):
x_j = self.x[self.z == j]
self.components.append(self.model(x_j, *params))
def _remove_datum(self, idx):
""" Remove a datum from a component """
assert idx < self.n
j = self.z[idx]
is_singleton = self.nk[j] == 1
self.z[idx] = -1
if is_singleton:
self.k -= 1
del self.nk[j]
self.z[self.z > j] -= 1
del self.components[j]
else:
self.components[j].remove_datum(self.x[idx])
self.nk[j] -= 1
assert len(self.nk) == self.k
assert len(self.components) == self.k
# assert sum(self.nk) == self.n-1
assert max(self.z) == self.k-1
assert min(self.z) == -1
assert sum(self.z < 0) == 1
def _insert_datum(self, idx, j):
""" Insert a datum into a component """
assert idx < self.n
assert j <= self.k
self.z[idx] = j
if j == self.k:
self.nk.append(1)
self.k += 1
self.components.append(self.model([self.x[idx]], *self.params))
else:
self.nk[j] += 1
self.components[j].insert_datum(self.x[idx])
assert len(self.nk) == self.k
assert len(self.components) == self.k
# assert sum(self.nk) == self.n
assert max(self.z) == self.k-1
assert min(self.z) == 0
assert sum(self.z < 0) == 0
def step(self, idx, init_mode=False, return_prob=False):
""" Resample the component assignment of x_idx """
assert idx < self.n
if not init_mode:
self._remove_datum(idx)
logps = np.log(np.array(self.nk + [self.alpha]))
x_i = self.x[idx]
for j, component in enumerate(self.components):
logps[j] += component.posterior_predictive(x_i)
temp_component = self.model([x_i], *self.params)
logps[-1] += temp_component.marginal_likelihood()
z_i = logpflip(logps)
self._insert_datum(idx, z_i)
if return_prob:
logps -= logsumexp(logps)
return logps[z_i]
def infer(self, n_sweeps):
""" Run the Gibbs sampler for n_sweeps. Each sweep resamples each
datum.
"""
for i in range(n_sweeps):
sys.stdout.write("<<<SWEEP %d OF %d>>> " % (i+1, n_sweeps,))
sys.stdout.flush()
for idx in range(self.n):
self.step(idx)
sys.stdout.write(".")
sys.stdout.flush()
sys.stdout.write("Done\n")
def pflip(w):
""" Normalize the weights, w, and do a categorical draw """
p = np.cumsum(w)
p /= p[-1]
return np.digitize([np.random.rand()], p)[0]
def logpflip(lw):
""" Normalize log weights, lw, and do a categorical draw """
p = np.cumsum(np.exp(lw-logsumexp(lw)))
return np.digitize([np.random.rand()], p)[0]
def crp_gen(n, alpha):
""" Draw an n-length partition from CRP(alpha) """
if alpha <= 0.0:
raise ValueError('alpha must be greater than 0.')
z = np.zeros(n, dtype=int)
nk = [1]
k = 1
for i in range(1, n):
j = pflip(np.array(nk + [alpha]))
assert j <= k
z[i] = j
if j == k:
k += 1
nk.append(1)
else:
nk[j] += 1
shuffle(z)
assert len(z) == n
assert len(nk) == k
assert sum(nk) == n
assert min(z) == 0
assert max(z) == k-1
return z, nk, k
if __name__ == '__main__':
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
sns.set_context('paper')
n = 50
std = .75
x = np.append(np.random.randn(n)*std,
(np.random.randn(n)*std+4., np.random.randn(n)*std-4.,))
igmm = IGMM(x, .5, NIG, (0., 1., 1., 1.,), seqinit=True)
plt.figure(figsize=(7.5, 3.5), dpi=150)
igmm.infer(100)
z = igmm.z
k = igmm.k
df = pd.DataFrame([{'x': xi, 'j': j} for j, xi in zip(z, x)])
plt.clf()
plt.subplot(1, 2, 1)
sns.distplot(df['x'], bins=35, hist=True, rug=True, kde=False)
plt.title('Raw Data')
plt.subplot(1, 2, 2)
for j in range(k):
xj = df['x'][df['j'] == j]
if len(xj) == 1:
continue
sns.distplot(xj, hist=False, rug=True, kde=True, norm_hist=False)
plt.title('Categorized Data')
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment