$\text{KL}(p\|q)$ Minimization
One form of variational inference minimizes the Kullback-Leibler divergence from $p(\mathbf{z} \mid \mathbf{x})$ to $q(\mathbf{z}\;;\;\lambda)$, \[\begin{aligned} \lambda^* &= \arg\min_\lambda \text{KL}( p(\mathbf{z} \mid \mathbf{x}) \;\|\; q(\mathbf{z}\;;\;\lambda) )\\ &= \arg\min_\lambda\; \mathbb{E}_{p(\mathbf{z} \mid \mathbf{x})} \big[ \log p(\mathbf{z} \mid \mathbf{x}) - \log q(\mathbf{z}\;;\;\lambda) \big].\end{aligned}\] The KL divergence is an asymmetric, information-theoretic measure of the difference between two probability distributions.
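To make the asymmetry concrete, here is a small illustrative check (not part of Edward) using the closed-form KL divergence between two univariate Gaussians; the function name `kl_gauss` and the chosen parameters are ours.

```python
import numpy as np

def kl_gauss(mu1, s1, mu2, s2):
    """Closed-form KL( N(mu1, s1^2) || N(mu2, s2^2) ) for univariate Gaussians."""
    return np.log(s2 / s1) + (s1**2 + (mu1 - mu2)**2) / (2.0 * s2**2) - 0.5

# Swapping the arguments changes the value: KL is not symmetric.
print(kl_gauss(0.0, 1.0, 1.0, 2.0))  # KL(p || q) ~= 0.443
print(kl_gauss(1.0, 2.0, 0.0, 1.0))  # KL(q || p) ~= 1.307
```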
Minimizing an intractable objective function
The $\text{KL}(p\|q)$ objective we seek to minimize is intractable: it directly involves the posterior $p(\mathbf{z} \mid \mathbf{x})$. Ignoring this for the moment, consider its gradient \[\begin{aligned} \nabla_\lambda\; \text{KL}( p(\mathbf{z} \mid \mathbf{x}) \;\|\; q(\mathbf{z}\;;\;\lambda) ) &= - \mathbb{E}_{p(\mathbf{z} \mid \mathbf{x})} \big[ \nabla_\lambda\; \log q(\mathbf{z}\;;\;\lambda) \big],\end{aligned}\] where the term $\log p(\mathbf{z} \mid \mathbf{x})$ drops out because it does not depend on $\lambda$. Both $\text{KL}(p\|q)$ and its gradient are intractable because of the posterior expectation. We can use importance sampling both to estimate the objective and to calculate stochastic gradients (Oh & Berger, 1992).
Adaptive importance sampling
First, rewrite the expectation to be with respect to the variational distribution, \[\begin{aligned} - \mathbb{E}_{p(\mathbf{z} \mid \mathbf{x})} \big[ \nabla_\lambda\; \log q(\mathbf{z}\;;\;\lambda) \big] &= - \mathbb{E}_{q(\mathbf{z}\;;\;\lambda)} \Bigg[ \frac{p(\mathbf{z} \mid \mathbf{x})}{q(\mathbf{z}\;;\;\lambda)} \nabla_\lambda\; \log q(\mathbf{z}\;;\;\lambda) \Bigg].\end{aligned}\]
We then use importance sampling to obtain a noisy estimate of this gradient. The basic procedure follows these steps:
- draw $S$ samples $\{\mathbf{z}_s\}_1^S \sim q(\mathbf{z}\;;\;\lambda)$,
- evaluate $\nabla_\lambda\; \log q(\mathbf{z}_s\;;\;\lambda)$ at each sample,
- compute the normalized importance weights \[\begin{aligned} w_s &= \frac{p(\mathbf{z}_s \mid \mathbf{x})}{q(\mathbf{z}_s\;;\;\lambda)} \Bigg/ \sum_{s=1}^{S} \frac{p(\mathbf{z}_s \mid \mathbf{x})}{q(\mathbf{z}_s\;;\;\lambda)} \end{aligned}\]
- compute the weighted sum of the evaluated gradients.
The key insight is that we can use the joint $p(\mathbf{x}, \mathbf{z})$ instead of the posterior when estimating the normalized importance weights \[\begin{aligned} w_s &= \frac{p(\mathbf{z}_s \mid \mathbf{x})}{q(\mathbf{z}_s\;;\;\lambda)} \Bigg/ \sum_{s=1}^{S} \frac{p(\mathbf{z}_s \mid \mathbf{x})}{q(\mathbf{z}_s\;;\;\lambda)} \\ &= \frac{p(\mathbf{x}, \mathbf{z}_s)}{q(\mathbf{z}_s\;;\;\lambda)} \Bigg/ \sum_{s=1}^{S} \frac{p(\mathbf{x}, \mathbf{z}_s)}{q(\mathbf{z}_s\;;\;\lambda)}.\end{aligned}\] This follows from Bayes’ rule \[\begin{aligned} p(\mathbf{z} \mid \mathbf{x}) &= p(\mathbf{x}, \mathbf{z}) / p(\mathbf{x})\\ &= p(\mathbf{x}, \mathbf{z}) / \text{a constant function of }\mathbf{z}.\end{aligned}\]
Importance sampling thus gives the following biased yet consistent gradient estimate \[\begin{aligned} \nabla_\lambda\; \text{KL}( p(\mathbf{z} \mid \mathbf{x}) \;\|\; q(\mathbf{z}\;;\;\lambda) ) &\approx - \sum_{s=1}^S w_s \nabla_\lambda\; \log q(\mathbf{z}_s\;;\;\lambda).\end{aligned}\] The objective $\text{KL}(p\|q)$ can be estimated in the same fashion. The only new ingredient for the gradient is the score function $\nabla_\lambda \log q(\mathbf{z}\;;\;\lambda)$. Edward uses automatic differentiation, specifically with TensorFlow’s computational graphs, making this gradient computation both simple and efficient to distribute.
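The estimator is only a few lines of code. The following NumPy sketch is illustrative rather than Edward's actual implementation: it assumes a toy model with prior $p(z) = \mathcal{N}(0, 1)$, likelihood $p(x \mid z) = \mathcal{N}(z, 1)$, a single observation `x_obs = 3.0`, and a Gaussian $q$ parameterized by $\lambda = (m, \log s)$; all names are ours.

```python
import numpy as np

rng = np.random.default_rng(0)
x_obs = 3.0  # single observed data point in the toy model (our assumption)

def log_joint(z):
    """log p(x, z) for the toy model, up to an additive constant."""
    return -0.5 * z**2 - 0.5 * (x_obs - z)**2

def klpq_grad_estimate(m, log_s, S=100):
    """Self-normalized importance sampling estimate of
    grad KL(p || q) for q(z) = N(m, exp(log_s)^2)."""
    s = np.exp(log_s)
    z = m + s * rng.standard_normal(S)            # draw z_s ~ q
    log_q = -np.log(s) - 0.5 * ((z - m) / s)**2   # log q(z_s), up to a constant
    log_w = log_joint(z) - log_q                  # unnormalized log-weights
    log_w -= log_w.max()                          # stabilize before exponentiating
    w = np.exp(log_w)
    w /= w.sum()                                  # normalized importance weights
    grad_m = (z - m) / s**2                       # score w.r.t. m
    grad_log_s = ((z - m) / s)**2 - 1.0           # score w.r.t. log s
    return -np.sum(w * grad_m), -np.sum(w * grad_log_s)
```

Note that the additive constants in `log_joint` and `log_q` cancel in the normalized weights, and subtracting the maximum log-weight before exponentiating is the usual trick to avoid overflow.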
Adaptive importance sampling follows this gradient to a local optimum using stochastic optimization. It is adaptive because the variational distribution $q(\mathbf{z}\;;\;\lambda)$ iteratively gets closer to the posterior $p(\mathbf{z} \mid \mathbf{x})$.
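Continuing the illustrative sketch above (the step size and iteration count are arbitrary choices, not tuned values), a plain stochastic gradient loop makes the procedure adaptive; for this toy model the exact posterior is $\mathcal{N}(1.5, 0.5)$, so $m$ and $s^2$ should settle near those values.

```python
# Continuing the sketch above: follow the noisy gradient with
# stochastic gradient descent so that q adapts toward the posterior.
m, log_s = 0.0, 0.0
step = 0.05
for t in range(2000):
    g_m, g_log_s = klpq_grad_estimate(m, log_s)
    m -= step * g_m
    log_s -= step * g_log_s

# For this toy model the exact posterior is N(1.5, 0.5).
print(m, np.exp(log_s)**2)
```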
For more details, see the API as well as its implementation in Edward’s code base.
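For illustration, here is a minimal usage sketch against Edward 1.x, assuming the `ed.KLpq` inference class and the `Normal` random variable with `loc`/`scale` arguments (argument names vary across Edward and TensorFlow versions, so treat this as indicative rather than exact):

```python
import numpy as np
import tensorflow as tf
import edward as ed
from edward.models import Normal

# Toy model matching the sketch above: p(z) = N(0, 1), p(x | z) = N(z, 1).
z = Normal(loc=0.0, scale=1.0)
x = Normal(loc=z, scale=1.0)

# Variational approximation q(z; lambda) with free parameters lambda.
qz = Normal(loc=tf.Variable(0.0),
            scale=tf.nn.softplus(tf.Variable(0.0)))

# KLpq runs the adaptive importance sampling scheme described above.
inference = ed.KLpq({z: qz}, data={x: np.array(3.0, dtype=np.float32)})
inference.run(n_samples=50, n_iter=1000)
```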
References
Oh, M.-S., & Berger, J. O. (1992). Adaptive importance sampling in Monte Carlo integration. Journal of Statistical Computation and Simulation, 41(3-4), 143–168.