Notes on Variational Inference Part I

1. Introduction

Consider the following statistical problem setting: we are given \mathbf X and \mathbf Z, where \mathbf X is a set of observations and \mathbf Z is a set of latent random variables. The goal is to compute the conditional posterior distribution of \mathbf Z given \mathbf X, i.e. P(\mathbf Z|\mathbf X). This is a very general setting, because \mathbf Z can be anything we have not observed, which includes naturally arising latent variables as well as model parameters that are treated as random variables.

2. Mean-field Approximation

2.1. Evidence Lower Bound

In many complex models, the posterior P(\mathbf Z|\mathbf X) is intractable (while the joint P(\mathbf Z, \mathbf X) is tractable), which means we cannot compute it directly from the data. The idea is therefore to find an approximating distribution over \mathbf Z, i.e. q(\mathbf Z). Under the mean-field approximation, there are two equivalent ways to introduce q(\mathbf Z) and the objective used to fit it.

KL-divergence. We quantify the quality of the approximation by the KL-divergence from the proposed distribution q(\mathbf Z) to the true posterior P(\mathbf Z|\mathbf X), that is

(1)   \begin{equation*}  D_{KL}(q(\mathbf Z)||P(\mathbf Z|\mathbf X)) = \int q(\mathbf Z) \log \frac{q(\mathbf Z)}{P(\mathbf Z|\mathbf X)} d\mathbf Z \end{equation*}
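
As a quick numerical illustration of what Eq. 1 measures, here is a minimal sketch using the discrete form of the KL-divergence and two made-up distributions; this only shows the definition and its asymmetry, since in our actual setting P(\mathbf Z|\mathbf X) is exactly what we cannot evaluate:

    import numpy as np

    # Two made-up discrete distributions over three states.
    q = np.array([0.6, 0.3, 0.1])
    p = np.array([0.4, 0.4, 0.2])

    kl_qp = np.sum(q * np.log(q / p))  # D_KL(q || p), the direction used in Eq. 1
    kl_pq = np.sum(p * np.log(p / q))  # D_KL(p || q); note the asymmetry
    print(kl_qp, kl_pq)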

To make the approximation accurate, we want to minimize the KL-divergence over q. However, once again we encounter the intractable distribution P(\mathbf Z|\mathbf X) inside the KL-divergence. To solve this problem, we expand Eq. 1:

    \[ D_{KL}(q||p) = \int q(\mathbf Z) \log \frac{q(\mathbf Z) P(\mathbf X)}{P(\mathbf Z,\mathbf X)} d\mathbf Z = \int q(\mathbf Z) \log \frac{q(\mathbf Z)}{P(\mathbf Z,\mathbf X)} d\mathbf Z + \log P(\mathbf X) \]

Here we abbreviate the distributions q(\mathbf Z) and P(\mathbf Z|\mathbf X) as q and p. Since \log P(\mathbf X) does not depend on how we choose q, minimizing the KL-divergence is equivalent to maximizing the evidence lower bound (this name will make more sense when we come to the other derivation below), i.e. \mathcal L(q):

(2)   \begin{equation*}  \mathcal L(q) = - \int q(\mathbf Z) \log \frac{q(\mathbf Z)}{P(\mathbf Z,\mathbf X)} d\mathbf Z = \mathbb E_q [\log P(\mathbf Z, \mathbf X)] - \mathbb E_q [\log q(\mathbf Z)] \end{equation*}

Here the first term \mathbb E_q [\log P(\mathbf Z, \mathbf X)] is the expected log joint, and the second term -\mathbb E_q [\log q(\mathbf Z)] is the entropy of q (the negative of \mathcal L(q) is also known as the variational free energy). Thus, minimizing the approximation error amounts to maximizing the lower bound over q:

    \[ \max_{q} \mathcal L(q) \]
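
To make the objective concrete, here is a minimal sketch of a Monte Carlo estimate of \mathcal L(q) for a toy model; the model (z ~ N(0,1), x|z ~ N(z,1)), the observed value, and the variational parameters are all made up for illustration:

    import numpy as np
    from scipy.stats import norm

    rng = np.random.default_rng(0)

    # Toy model (made up): z ~ N(0, 1), x | z ~ N(z, 1); one observation x = 2.0
    x = 2.0

    # Variational distribution q(z) = N(m, s^2) with free parameters m, s
    m, s = 1.0, 0.8

    # Monte Carlo estimate of L(q) = E_q[log P(z, x)] - E_q[log q(z)]
    z = rng.normal(m, s, size=100_000)                        # samples from q
    log_joint = norm.logpdf(z, 0, 1) + norm.logpdf(x, z, 1)   # log P(z) + log P(x | z)
    log_q = norm.logpdf(z, m, s)
    elbo = np.mean(log_joint - log_q)

    print(elbo)   # always <= log P(x); maximizing over (m, s) tightens the bound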

Jensen's inequality. Another way of arriving at the same lower bound is to start from the data log-likelihood:

    \begin{equation*} \begin{split} \log P(\mathbf X) &= \log \int P(\mathbf Z, \mathbf X) d\mathbf Z\\ &= \log \int q(\mathbf Z) \frac{P(\mathbf Z, \mathbf X)}{q(\mathbf Z)} d\mathbf Z\\ &\ge \int q(\mathbf Z) \log \frac{P(\mathbf Z, \mathbf X)}{q(\mathbf Z)} d\mathbf Z\\ &= \mathbb E_q [\log P(\mathbf Z, \mathbf X)] - \mathbb E_q [\log q(\mathbf Z)]\\ &= \mathcal L(q) \end{split} \end{equation*}

It is also easy to see that the difference between the data log-likelihood \log P(\mathbf X) and the lower bound \mathcal L(q) is exactly the KL-divergence between q and p.
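
To see this explicitly, write the evidence as an expectation under q and split the integrand:

    \begin{equation*} \begin{split} \log P(\mathbf X) &= \int q(\mathbf Z) \log P(\mathbf X) d\mathbf Z = \int q(\mathbf Z) \log \frac{P(\mathbf Z, \mathbf X)}{P(\mathbf Z|\mathbf X)} d\mathbf Z\\ &= \int q(\mathbf Z) \log \frac{P(\mathbf Z, \mathbf X)}{q(\mathbf Z)} d\mathbf Z + \int q(\mathbf Z) \log \frac{q(\mathbf Z)}{P(\mathbf Z|\mathbf X)} d\mathbf Z\\ &= \mathcal L(q) + D_{KL}(q(\mathbf Z)||P(\mathbf Z|\mathbf X)) \end{split} \end{equation*}

Since the KL-divergence is nonnegative, \mathcal L(q) is indeed a lower bound on the log evidence \log P(\mathbf X), which is where the name evidence lower bound (ELBO) comes from; maximizing \mathcal L(q) is the same as minimizing the KL-divergence, recovering the first derivation.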

2.2. Deriving q(\mathbf Z)

Now that we have an evidence lower bound containing a tractable distribution P(\mathbf Z, \mathbf X) (good) and an unknown distribution over all latent variables q(\mathbf Z) (not so good), we still need a way to specify q(\mathbf Z). Under mean-field variational Bayes, we make the assumption that the variational distribution factorizes over several disjoint groups of latent variables \{\mathbf Z_i\} (specified by the user), i.e.,

(3)   \begin{equation*}  q(\mathbf Z) = \prod_i q_i(\mathbf Z_i) \end{equation*}

Plugging Eq. 3 into the lower bound (Eq. 2) and maximizing with respect to q_i while the other factors are held fixed, the optimal solution is (expressed in logarithm):

    \[ \log q_i^*(\mathbf Z_i) = \mathbb E_{-i} [\log P(\mathbf Z, \mathbf X)] + \text{const} \]

or

    \begin{equation*} \begin{split} q_i^*(\mathbf Z_i) &\propto \exp \{ \mathbb E_{-i} [\log P(\mathbf Z, \mathbf X)] \} \\ &\propto \exp \{ \mathbb E_{-i} [\log P(\mathbf Z_i|\mathbf Z_{-i}, \mathbf X)] \} \end{split} \end{equation*}

where the expectation is taken with respect to all variational factors except q_i:

    \[ \mathbb E_{-i}[\log P(\mathbf Z_i|\mathbf Z_{-i}, \mathbf X)] = \int \log P(\mathbf Z_i |\mathbf Z_{-i}, \mathbf X) \prod_{j \neq i} q_j(\mathbf Z_j) \, d\mathbf Z_{-i} \]

The optimal q_i(\mathbf Z_i) can be derived from here, although the required expectations might be difficult to work with for some models (for most models of practical interest they are not).
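
A sketch of where this update comes from: substituting Eq. 3 into Eq. 2 and keeping only the terms that depend on q_i gives

    \begin{equation*} \begin{split} \mathcal L(q) &= \int q_i(\mathbf Z_i) \mathbb E_{-i}[\log P(\mathbf Z, \mathbf X)] d\mathbf Z_i - \int q_i(\mathbf Z_i) \log q_i(\mathbf Z_i) d\mathbf Z_i + \text{const}\\ &= -D_{KL}(q_i(\mathbf Z_i) \,||\, \tilde P(\mathbf Z_i)) + \text{const} \end{split} \end{equation*}

where \tilde P(\mathbf Z_i) \propto \exp \{ \mathbb E_{-i} [\log P(\mathbf Z, \mathbf X)] \} is a normalized distribution. With the other factors fixed, \mathcal L(q) is maximized when the KL term vanishes, i.e. when q_i = \tilde P, which is exactly the update above.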

Once the update for each factor q_i has been derived, we alternately update the factors (coordinate ascent) until convergence, which is guaranteed because the ELBO is convex with respect to each individual factor, so every update can only increase the bound. Note that the convergence point is in general only a local optimum.
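
As a concrete illustration of this coordinate ascent scheme, below is a minimal sketch for the textbook problem of inferring the mean \mu and precision \tau of a univariate Gaussian, with a conjugate Normal-Gamma prior and the factorization q(\mu, \tau) = q(\mu) q(\tau). The updates are the standard mean-field results for this model (the univariate Gaussian example in Bishop's PRML, Chapter 10); the synthetic data and prior hyperparameters are made up for illustration:

    import numpy as np

    rng = np.random.default_rng(0)
    x = rng.normal(loc=5.0, scale=2.0, size=200)   # synthetic data (made up)
    N, xbar = len(x), x.mean()

    # Prior (made-up hyperparameters): mu | tau ~ N(mu0, (lam0*tau)^-1), tau ~ Gamma(a0, b0)
    mu0, lam0, a0, b0 = 0.0, 1.0, 1.0, 1.0

    # Variational factors: q(mu) = N(mu_N, 1/lam_N), q(tau) = Gamma(a_N, b_N)
    E_tau = 1.0                       # initial guess for E_q[tau]
    a_N = a0 + (N + 1) / 2.0          # this update does not depend on q(mu), so it stays fixed

    for _ in range(50):
        # Update q(mu) holding q(tau) fixed
        mu_N = (lam0 * mu0 + N * xbar) / (lam0 + N)
        lam_N = (lam0 + N) * E_tau

        # Update q(tau) holding q(mu) fixed, using E[(c - mu)^2] = (c - mu_N)^2 + 1/lam_N
        E_mu, var_mu = mu_N, 1.0 / lam_N
        b_N = b0 + 0.5 * (np.sum((x - E_mu) ** 2) + N * var_mu
                          + lam0 * ((mu0 - E_mu) ** 2 + var_mu))
        E_tau = a_N / b_N

    print("E_q[mu] =", mu_N, "E_q[tau] =", E_tau)  # compare with true mean 5.0, true precision 0.25

Each pass of the loop can only increase the ELBO, so in practice convergence can be monitored by tracking the bound or the change in the variational parameters.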

2.2.1. q(\mathbf Z) for Exponential Family Conditionals

If P(\mathbf Z_i |\mathbf Z_{-i}, \mathbf X) is in an exponential family, then the optimal q_i(\mathbf Z_i) is in the same exponential family as P(\mathbf Z_i |\mathbf Z_{-i}, \mathbf X). This provides a shortcut for doing mean-field variational inference on graphical models whose complete conditionals are exponential families (which is the case for many graphical models): using this result, we can simply write down the optimal form of each variational distribution together with its undetermined variational parameters; setting the derivative of the ELBO with respect to those parameters to zero (or using gradient ascent) then yields a coordinate ascent learning procedure. See David Blei's note.
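
Concretely, if each complete conditional has the exponential family form

    \[ P(\mathbf Z_i | \mathbf Z_{-i}, \mathbf X) = h(\mathbf Z_i) \exp \{ \eta_i(\mathbf Z_{-i}, \mathbf X)^\top t(\mathbf Z_i) - A(\eta_i(\mathbf Z_{-i}, \mathbf X)) \} \]

then taking the expectation inside the update from Section 2.2 gives

    \[ q_i^*(\mathbf Z_i) \propto h(\mathbf Z_i) \exp \{ \mathbb E_{-i}[\eta_i(\mathbf Z_{-i}, \mathbf X)]^\top t(\mathbf Z_i) \} \]

i.e. q_i^* lies in the same exponential family, with natural parameter \mathbb E_{-i}[\eta_i(\mathbf Z_{-i}, \mathbf X)]; the coordinate ascent step reduces to computing this expected natural parameter.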

3. Expectation Propagation – A glance

Like the mean-field approximation, EP also tries to find an approximating distribution q by minimizing a KL-divergence. However, different from mean-field, which minimizes D_{KL}(q||p) indirectly by maximizing the ELBO, EP directly minimizes the KL-divergence in the reverse direction, D_{KL}(p||q). This is difficult for arbitrary distributions, but practical for some families, such as exponential families. The details may be covered in future posts.