Exploring the EM Algorithm

tl;dr: I explain, visualize, and code the EM Algorithm for Gaussian Mixture Models and show how it is basically just an example of Variational Inference.

In my AMLED class, we recently discussed how Maximum Likelihood Estimation (MLE) can help you find the best parameters for a statistical model:

  1. Write down the formula for the likelihood of the data, $P(\mathbf{X})$. This is how probable the observed data are under your probabilistic model's parameters.
  2. Take the natural logarithm of that likelihood. This won't change the optimal parameters that maximize the likelihood, since the logarithm is a monotonically increasing function.
  3. Maximize that log-likelihood, typically by setting its partial derivative with respect to each parameter equal to 0 and solving the resulting equations for the parameters.
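
To make steps 1 and 2 concrete, here is a small sketch on synthetic data (the data, the grid, and the assumption that $\sigma = 1$ is known are all made up for illustration). It evaluates the likelihood and the log-likelihood over a grid of candidate $\mu$ values and checks that both peak at the same place. It also shows a practical reason for taking the log: with many data points, the raw product of densities underflows to zero in floating point.

```python
import numpy as np
from scipy.stats import norm

rng = np.random.default_rng(0)
# Synthetic data; sigma is treated as known (= 1) so only mu is estimated
x = rng.normal(loc=2.0, scale=1.0, size=20)
mu_grid = np.linspace(0.0, 4.0, 401)

# Step 1: the likelihood is a product of per-point densities
likelihood = np.array([np.prod(norm.pdf(x, loc=m, scale=1.0)) for m in mu_grid])
# Step 2: its logarithm is a sum of per-point log-densities
log_likelihood = np.array([np.sum(norm.logpdf(x, loc=m, scale=1.0)) for m in mu_grid])

# log is monotonically increasing, so both curves peak at the same mu
best_mu_lik = mu_grid[np.argmax(likelihood)]
best_mu_loglik = mu_grid[np.argmax(log_likelihood)]
print(best_mu_lik, best_mu_loglik)

# With many points the raw product underflows to 0.0,
# while the sum of logs stays a finite, usable number
big_x = rng.normal(loc=2.0, scale=1.0, size=2000)
product_big = np.prod(norm.pdf(big_x, loc=2.0, scale=1.0))
log_sum_big = np.sum(norm.logpdf(big_x, loc=2.0, scale=1.0))
print(product_big, log_sum_big)
```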

A typical example of this is finding the mean $\mu$ and standard deviation $\sigma$ of a Gaussian Distribution.
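
For the Gaussian, step 3 has a closed-form answer: $\hat{\mu}$ is the sample mean and $\hat{\sigma}$ is the root-mean-square deviation from it (dividing by $n$, not $n-1$). The sketch below, on made-up synthetic data, checks that both partial derivatives of the log-likelihood vanish at those values, that nudging either parameter lowers the log-likelihood, and that `scipy.stats.norm.fit` returns the same estimates.

```python
import numpy as np
from scipy.stats import norm

rng = np.random.default_rng(1)
x = rng.normal(loc=-1.0, scale=2.5, size=500)   # synthetic data for illustration
n = len(x)

# Setting d/dmu and d/dsigma of the log-likelihood to 0 gives closed forms:
mu_hat = np.mean(x)
sigma_hat = np.sqrt(np.mean((x - mu_hat) ** 2))  # divides by n, not n - 1

# Both partial derivatives of sum_n log N(x_n | mu, sigma) vanish there
grad_mu = np.sum(x - mu_hat) / sigma_hat**2
grad_sigma = -n / sigma_hat + np.sum((x - mu_hat) ** 2) / sigma_hat**3

def log_likelihood(mu, sigma):
    return np.sum(norm.logpdf(x, loc=mu, scale=sigma))

# Nudging either parameter away from the MLE lowers the log-likelihood
best = log_likelihood(mu_hat, sigma_hat)
nudged = (
    [log_likelihood(mu_hat + d, sigma_hat) for d in (-0.05, 0.05)]
    + [log_likelihood(mu_hat, sigma_hat + d) for d in (-0.05, 0.05)]
)

# scipy's norm.fit returns the same maximum-likelihood estimates
loc_fit, scale_fit = norm.fit(x)
print(mu_hat, sigma_hat, loc_fit, scale_fit)
```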

This works fine for well-behaved likelihoods (for example, log-likelihoods that are concave in the parameters), provided you take care of things like singularities and ill-conditioned matrices. For many types of models, however, this optimization is not straightforward: the optimal parameters depend on one another in complex, non-linear ways, and you can't just "plug in" the data to get the optimal answer. One case where this occurs is Latent Variable models such as Gaussian Mixture Models, where the data are generated from multiple Gaussian distributions but you don't know which distribution generated which data point. While this might seem similar to the MLE of the Gaussian Distribution above, it is actually much more complex.
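
To see where the difficulty comes from, here is a sketch of the log-likelihood of a one-dimensional, two-component mixture, $\sum_n \log \sum_k \pi_k \, \mathcal{N}(x_n \mid \mu_k, \sigma_k)$, on made-up synthetic data (the helper `gmm_log_likelihood` is hypothetical, not from the original post). Because each point's density is a *sum* over components, the log cannot be pushed inside, and setting the derivative with respect to $\mu_k$ to zero yields a weighted mean whose weights themselves depend on every $\pi_j$, $\mu_j$, and $\sigma_j$. If the labels $z$ were observed, the problem would instead decouple into one ordinary Gaussian fit per group.

```python
import numpy as np
from scipy.special import logsumexp
from scipy.stats import norm

def gmm_log_likelihood(x, weights, means, stds):
    """Log-likelihood of 1-D data under a Gaussian mixture:
    sum_n log( sum_k pi_k * N(x_n | mu_k, sigma_k) )."""
    x = np.asarray(x, dtype=float)[:, None]            # shape (n, 1)
    # log(pi_k) + log N(x_n | mu_k, sigma_k), broadcast to shape (n, K)
    log_terms = np.log(weights) + norm.logpdf(x, loc=means, scale=stds)
    # logsumexp computes log(sum_k exp(...)) without underflow
    return float(np.sum(logsumexp(log_terms, axis=1)))

# Synthetic data: each point comes from one of two Gaussians,
# but only x is observed; the labels z are the latent variables.
rng = np.random.default_rng(2)
n = 300
z = rng.choice(2, size=n, p=[0.3, 0.7])
x = np.where(z == 0, rng.normal(-2.0, 0.5, size=n), rng.normal(1.5, 1.0, size=n))

weights = np.array([0.3, 0.7])
means = np.array([-2.0, 1.5])
stds = np.array([0.5, 1.0])
ll = gmm_log_likelihood(x, weights, means, stds)

# Same quantity computed directly, for comparison
naive = np.sum(np.log(np.sum(weights * norm.pdf(x[:, None], loc=means, scale=stds), axis=1)))

# If z were observed, the MLE would decouple into one Gaussian fit per group
known_label_means = np.array([x[z == k].mean() for k in range(2)])
print(ll, naive, known_label_means)
```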

One way to solve this problem is to first find a "simpler" distribution for which analytical, closed-form updates are possible, and then match that simpler distribution as closely as possible to the "harder" distribution. This is the basic idea behind Variational Inference. Often, one can iteratively update simpler models to get closer and closer to the "harder" one; that iteration between simpler models is the basic idea of the Expectation-Maximization Algorithm. In some cases, with certain models and distributions, this iterative process of inching closer and closer with simpler distributions can actually get you the exact solution to the harder distribution that you originally wanted. Below, we will see that Gaussian Mixture Models are one case where that happens. Most of the time, however, you will not be so lucky: the optimized simple model will differ a little from what an optimized hard model would be, and you just hope that the difference doesn't affect actual performance much.
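
A minimal sketch of this alternation for a one-dimensional, two-component mixture follows; the helper names `e_step`, `m_step`, and `elbo`, the synthetic data, and the starting guess are all hypothetical choices for illustration, not the implementation developed later in the post. The E-step sets the simple distribution $q(z_n = k)$ to the posterior under the current parameters, and the M-step solves the resulting weighted Gaussian fits in closed form. The loop records the log-likelihood (which EM never decreases) and the gap between the log-likelihood and the variational lower bound, which is $\mathrm{KL}(q \,\|\, p(z \mid x))$ and is zero right after each E-step. That zero gap is the sense in which, for GMMs, the simple distribution matches the hard one exactly.

```python
import numpy as np
from scipy.special import logsumexp, xlogy
from scipy.stats import norm

def e_step(x, weights, means, stds):
    """Set q(z_n = k) to the exact posterior p(z_n = k | x_n) under the current parameters."""
    log_joint = np.log(weights) + norm.logpdf(x[:, None], loc=means, scale=stds)  # (n, K)
    log_evidence = logsumexp(log_joint, axis=1, keepdims=True)                     # (n, 1)
    resp = np.exp(log_joint - log_evidence)                                        # rows sum to 1
    return resp, log_joint, float(np.sum(log_evidence))

def m_step(x, resp):
    """Closed-form parameter updates with the soft assignments q held fixed."""
    nk = resp.sum(axis=0)                      # effective number of points per component
    weights = nk / len(x)
    means = resp.T @ x / nk
    stds = np.sqrt(np.sum(resp * (x[:, None] - means) ** 2, axis=0) / nk)
    return weights, means, stds

def elbo(resp, log_joint):
    """Lower bound: sum_{n,k} q_nk * (log p(x_n, z_n = k) - log q_nk)."""
    # xlogy(q, q) is q * log(q), defined as 0 when q == 0
    return float(np.sum(resp * log_joint) - np.sum(xlogy(resp, resp)))

rng = np.random.default_rng(3)
x = np.concatenate([rng.normal(-2.0, 0.5, size=150), rng.normal(1.5, 1.0, size=350)])

# Arbitrary starting guess for (pi, mu, sigma)
weights, means, stds = np.array([0.5, 0.5]), np.array([-1.0, 1.0]), np.array([1.0, 1.0])
log_liks, gaps = [], []
for _ in range(50):
    resp, log_joint, ll = e_step(x, weights, means, stds)
    log_liks.append(ll)
    gaps.append(ll - elbo(resp, log_joint))    # KL(q || posterior): ~0 right after the E-step
    weights, means, stds = m_step(x, resp)

print(log_liks[0], log_liks[-1], max(abs(g) for g in gaps))
print(weights, means, stds)
```

Because the M-step only ever maximizes a bound that touches the log-likelihood at the current parameters, each iteration can only keep the log-likelihood the same or raise it; it may still converge to a local rather than a global maximum, depending on the starting guess.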

In [1]:
# First, let's set up the libraries we'll need
from scipy.stats import norm
from scipy.stats import entropy
from bokeh.plotting import figure, output_notebook, show
from bokeh.models import ColumnDataSource
from ipywidgets import interact
import numpy as np
# Default figure size, in pixels, for the Bokeh plots
plot_height = 400
plot_width = 800
# output to the IPython Notebook