ed.KLpq
Class KLpq
Inherits From: VariationalInference
Aliases:
- Class ed.KLpq
- Class ed.inferences.KLpq
Defined in edward/inferences/klpq.py.
Variational inference with the KL divergence $\text{KL}( p(z \mid x) \| q(z) )$.
To perform the optimization, this class uses a technique from adaptive importance sampling (Oh & Berger, 1992).
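Adaptive importance sampling builds on self-normalized importance sampling: draw samples from a proposal, weight each one by the unnormalized target density divided by the proposal density, and normalize the weights to sum to one. The plain-Python sketch below shows only that building block, not Edward's implementation. The target (an unnormalized Normal(2, 1) density) and the fixed proposal Normal(0, 2) are illustrative assumptions, and the adaptive step of updating the proposal is left out.

```python
import math
import random


def log_normal_pdf(z, mu, sigma):
    """Log density of Normal(mu, sigma**2) at z."""
    return -0.5 * math.log(2 * math.pi) - math.log(sigma) - 0.5 * ((z - mu) / sigma) ** 2


def snis_expectation(f, log_target, proposal_mu, proposal_sigma, n_samples, rng):
    """Self-normalized importance sampling estimate of E_p[f(z)].

    log_target may be unnormalized: its constant cancels when the
    weights are normalized.
    """
    zs = [rng.gauss(proposal_mu, proposal_sigma) for _ in range(n_samples)]
    log_w = [log_target(z) - log_normal_pdf(z, proposal_mu, proposal_sigma) for z in zs]
    # Subtract the largest log-weight before exponentiating, for numerical stability.
    m = max(log_w)
    w = [math.exp(lw - m) for lw in log_w]
    total = sum(w)
    return sum(wi * f(z) for wi, z in zip(w, zs)) / total


def log_target(z):
    # Hypothetical target: Normal(2, 1) without its normalizing constant.
    return -0.5 * (z - 2.0) ** 2


rng = random.Random(0)
# Estimate of the target mean; with this many samples it lands near 2.0.
estimate = snis_expectation(lambda z: z, log_target, 0.0, 2.0, 20000, rng)
```

The proposal is deliberately wider than the target so that the weights stay bounded; a proposal narrower than the target can give weights with very high variance.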
Notes
KLpq also optimizes any model parameters $p(z \mid x; \theta)$. It does this by variational EM, maximizing $\mathbb{E}_{p(z \mid x; \lambda)} [ \log p(x, z; \theta) ]$ with respect to $\theta$.
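For intuition about the EM step, consider a hypothetical toy model (not from Edward) with $z \sim \text{Normal}(\theta, 1)$ and $x \mid z \sim \text{Normal}(z, 1)$. With samples $z^s$ and their normalized importance weights computed beforehand and held fixed, the weighted objective $\sum_s w_{\text{norm}}(z^s) \log p(x, z^s; \theta)$ depends on $\theta$ only through $-\tfrac{1}{2}(z^s - \theta)^2$, so its maximizer is the weighted mean of the samples. The sketch computes that closed-form update; a gradient-based implementation would instead take optimizer steps toward it.

```python
import math


def normalized_weights(log_w):
    """Turn log importance weights into weights that sum to one."""
    m = max(log_w)  # shift for numerical stability
    w = [math.exp(lw - m) for lw in log_w]
    total = sum(w)
    return [wi / total for wi in w]


def m_step_theta(zs, log_w):
    """Maximize sum_s w_norm_s * log p(x, z_s; theta) over theta.

    Hypothetical toy model: z ~ Normal(theta, 1), x | z ~ Normal(z, 1).
    Only log p(z; theta) = -(z - theta)**2 / 2 + const depends on theta,
    so the maximizer is the weighted mean of the samples. The weights
    (E-step) are treated as constants here.
    """
    w_norm = normalized_weights(log_w)
    return sum(wn * z for wn, z in zip(w_norm, zs))


# Equal weights give the plain mean: 2.0.
theta_equal = m_step_theta([1.0, 3.0], [0.0, 0.0])
# Weights 1:3 pull theta toward the second sample: approximately 2.5.
theta_skewed = m_step_theta([1.0, 3.0], [0.0, math.log(3.0)])
```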
In conditional inference, we infer $z$ in $p(z, \beta \mid x)$ while fixing inference over $\beta$ using another distribution $q(\beta)$. During gradient calculation, instead of using the model’s density $\log p(x, z^{(s)})$, $z^{(s)} \sim q(z; \lambda)$, for each sample $s = 1, \ldots, S$, KLpq uses $\log p(x, z^{(s)}, \beta^{(s)})$, where $z^{(s)} \sim q(z; \lambda)$ and $\beta^{(s)} \sim q(\beta)$.
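A plain-Python sketch of how the per-sample log-weights change under conditional inference, using a hypothetical hierarchical model $\beta \sim \text{Normal}(0, 1)$, $z \mid \beta \sim \text{Normal}(\beta, 1)$, $x \mid z \sim \text{Normal}(z, 1)$. Following the description above, each $\beta^{(s)}$ is drawn from the fixed $q(\beta)$ and plugged into the joint density, and only $\log q(z^{(s)}; \lambda)$ is subtracted. This is a reading of the text, not a copy of Edward's code, and the parameter values are arbitrary.

```python
import math
import random


def log_normal_pdf(v, mu, sigma):
    """Log density of Normal(mu, sigma**2) at v."""
    return -0.5 * math.log(2 * math.pi) - math.log(sigma) - 0.5 * ((v - mu) / sigma) ** 2


def log_joint(x, z, beta):
    """log p(x, z, beta) for the hypothetical model beta -> z -> x."""
    return (log_normal_pdf(beta, 0.0, 1.0)
            + log_normal_pdf(z, beta, 1.0)
            + log_normal_pdf(x, z, 1.0))


def conditional_log_weights(x, q_z, q_beta, n_samples, rng):
    """Per-sample log p(x, z_s, beta_s) - log q(z_s; lambda).

    q_z = (mu, sigma) is the distribution being fit; q_beta = (mu, sigma)
    is held fixed and only supplies samples of beta.
    """
    log_w = []
    for _ in range(n_samples):
        z = rng.gauss(*q_z)        # z^(s) ~ q(z; lambda)
        beta = rng.gauss(*q_beta)  # beta^(s) ~ q(beta), not optimized
        log_w.append(log_joint(x, z, beta) - log_normal_pdf(z, *q_z))
    return log_w


rng = random.Random(1)
log_w = conditional_log_weights(x=0.5, q_z=(0.0, 1.0), q_beta=(0.2, 0.5), n_samples=5, rng=rng)
```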
The objective function also adds the sum of all tensors in the REGULARIZATION_LOSSES collection.
Methods
__init__
__init__(
latent_vars=None,
data=None
)
Create an inference algorithm.
Args:
latent_vars
: list of RandomVariable or dict of RandomVariable to RandomVariable. Collection of random variables to perform inference on. If list, each random variable will be implicitly optimized using a Normal random variable that is defined internally with a free parameter per location and scale and is initialized using standard normal draws. The random variables to approximate must be continuous.
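When latent_vars is a list, each random variable receives an internally defined Normal approximation. The sketch below mimics that description for a flat vector of a given size in plain Python: one free location and one free scale parameter per element, both initialized from standard normal draws. Mapping the free scale parameter through softplus to keep the scale positive is an assumption of this sketch, and the helper name default_normal_params is hypothetical.

```python
import math
import random


def softplus(v):
    """log(1 + exp(v)), computed stably for large |v|."""
    return max(v, 0.0) + math.log1p(math.exp(-abs(v)))


def default_normal_params(size, rng):
    """Hypothetical helper: free parameters for a fully factorized Normal.

    Each element gets its own free location and free scale parameter,
    both drawn from Normal(0, 1). The scale actually used is
    softplus(free_scale) (an assumption), which keeps it positive.
    """
    free_loc = [rng.gauss(0.0, 1.0) for _ in range(size)]
    free_scale = [rng.gauss(0.0, 1.0) for _ in range(size)]
    scale = [softplus(s) for s in free_scale]
    return free_loc, free_scale, scale


rng = random.Random(0)
loc, free_scale, scale = default_normal_params(3, rng)
```

Optimizing the unconstrained free parameters rather than the scale itself lets a plain gradient step never produce an invalid (non-positive) scale.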
build_loss_and_gradients
build_loss_and_gradients(var_list)
Build loss function
$\text{KL}( p(z \mid x) \| q(z) ) = \mathbb{E}_{p(z \mid x)} [ \log p(z \mid x) - \log q(z; \lambda) ]$
and stochastic gradients based on importance sampling.
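As a check on the definition, the KL divergence between two univariate Normals has a closed form, and it can also be estimated by averaging $\log p(z) - \log q(z)$ over draws from $p$, exactly as the expectation is written. The plain-Python sketch compares the two for $p = \text{Normal}(0, 1)$ and $q = \text{Normal}(1, 1)$, where the exact value is 0.5. In KLpq the posterior $p(z \mid x)$ cannot be sampled directly, which is why the loss is estimated with importance sampling instead.

```python
import math
import random


def log_normal_pdf(z, mu, sigma):
    """Log density of Normal(mu, sigma**2) at z."""
    return -0.5 * math.log(2 * math.pi) - math.log(sigma) - 0.5 * ((z - mu) / sigma) ** 2


def kl_normal_closed_form(mu_p, sigma_p, mu_q, sigma_q):
    """KL(Normal(mu_p, sigma_p**2) || Normal(mu_q, sigma_q**2))."""
    return (math.log(sigma_q / sigma_p)
            + (sigma_p ** 2 + (mu_p - mu_q) ** 2) / (2 * sigma_q ** 2)
            - 0.5)


def kl_monte_carlo(mu_p, sigma_p, mu_q, sigma_q, n_samples, rng):
    """Average of log p(z) - log q(z) over z ~ p."""
    total = 0.0
    for _ in range(n_samples):
        z = rng.gauss(mu_p, sigma_p)
        total += log_normal_pdf(z, mu_p, sigma_p) - log_normal_pdf(z, mu_q, sigma_q)
    return total / n_samples


exact = kl_normal_closed_form(0.0, 1.0, 1.0, 1.0)  # 0.5
approx = kl_monte_carlo(0.0, 1.0, 1.0, 1.0, 100000, random.Random(0))
```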
The loss function can be estimated, up to the additive constant $\log p(x)$ (since $\log p(x, z) = \log p(z \mid x) + \log p(x)$), as
$\sum_{s=1}^S [ w_{\text{norm}}(z^s; \lambda) (\log p(x, z^s) - \log q(z^s; \lambda)) ],$
where for $z^s \sim q(z; \lambda)$,
$w_{\text{norm}}(z^s; \lambda) = w(z^s; \lambda) / \sum_{s'=1}^S w(z^{s'}; \lambda)$
normalizes the importance weights, $w(z^s; \lambda) = p(x, z^s) / q(z^s; \lambda)$.
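The estimator can be written out in a few lines of plain Python, given per-sample values of $\log p(x, z^s)$ and $\log q(z^s; \lambda)$. Working with log-weights and subtracting their maximum before exponentiating is a standard stability measure assumed here; the text does not specify it. The function name and inputs are illustrative.

```python
import math


def klpq_loss_estimate(log_p, log_q):
    """Estimate sum_s w_norm(z^s) * (log p(x, z^s) - log q(z^s)).

    log_p[s] = log p(x, z^s) and log_q[s] = log q(z^s; lambda) for
    samples z^s drawn from q. Returns (estimate, normalized weights).
    """
    log_w = [lp - lq for lp, lq in zip(log_p, log_q)]  # log w(z^s)
    m = max(log_w)
    w = [math.exp(lw - m) for lw in log_w]  # rescaled weights, same ratios
    total = sum(w)
    w_norm = [wi / total for wi in w]
    estimate = sum(wn * lw for wn, lw in zip(w_norm, log_w))
    return estimate, w_norm


# If log p - log q is the same constant c for every sample, the weights
# are uniform and the estimate equals c (here 0.5).
est_const, w_const = klpq_loss_estimate([-1.0, -2.0, -3.0], [-1.5, -2.5, -3.5])

# A sample with a larger log-weight receives a larger share of the weight.
est_mixed, w_mixed = klpq_loss_estimate([0.0, 1.0], [0.0, 0.0])
```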
Treating the normalized weights as constants, this provides a gradient with respect to $\lambda$,
$- \sum_{s=1}^S [ w_{\text{norm}}(z^s; \lambda) \nabla_{\lambda} \log q(z^s; \lambda) ].$
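Combining the weights and this gradient gives a small adaptive importance sampling loop. The plain-Python sketch fits $q(z; \lambda) = \text{Normal}(\mu, \sigma^2)$ with $\lambda = (\mu, \log\sigma)$ to a hypothetical unnormalized target $\log p(x, z) = -\tfrac{1}{2}(z - 2)^2$ by plain gradient descent, using $\partial_\mu \log q = (z - \mu)/\sigma^2$ and $\partial_{\log\sigma} \log q = (z - \mu)^2/\sigma^2 - 1$. The parameterization, learning rate, sample count, and iteration count are illustrative choices, not Edward defaults.

```python
import math
import random


def log_normal_pdf(z, mu, sigma):
    """Log density of Normal(mu, sigma**2) at z."""
    return -0.5 * math.log(2 * math.pi) - math.log(sigma) - 0.5 * ((z - mu) / sigma) ** 2


def klpq_gradient(log_joint, mu, log_sigma, n_samples, rng):
    """Estimate -sum_s w_norm(z^s) * grad_lambda log q(z^s; lambda)."""
    sigma = math.exp(log_sigma)
    zs = [rng.gauss(mu, sigma) for _ in range(n_samples)]
    log_w = [log_joint(z) - log_normal_pdf(z, mu, sigma) for z in zs]
    m = max(log_w)
    w = [math.exp(lw - m) for lw in log_w]
    total = sum(w)
    # Normalized weights are treated as constants (no gradient through them).
    w_norm = [wi / total for wi in w]
    g_mu = -sum(wn * (z - mu) / sigma ** 2 for wn, z in zip(w_norm, zs))
    g_log_sigma = -sum(wn * (((z - mu) / sigma) ** 2 - 1.0) for wn, z in zip(w_norm, zs))
    return g_mu, g_log_sigma


def log_joint(z):
    # Hypothetical target: posterior proportional to Normal(2, 1).
    return -0.5 * (z - 2.0) ** 2


rng = random.Random(0)
mu, log_sigma = 0.0, 0.0
learning_rate = 0.05
for _ in range(2000):
    g_mu, g_log_sigma = klpq_gradient(log_joint, mu, log_sigma, 50, rng)
    mu -= learning_rate * g_mu
    log_sigma -= learning_rate * g_log_sigma

sigma = math.exp(log_sigma)
# mu should end near 2 and sigma near 1, up to Monte Carlo noise.
```

Optimizing $\log\sigma$ instead of $\sigma$ keeps the scale positive without constraints. Because the normalized weights become nearly uniform once $q$ matches the target, the stochastic gradient settles around the target's mean and variance.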
finalize
finalize()
Function to call after convergence.
initialize
initialize(
n_samples=1,
*args,
**kwargs
)
Initialize inference algorithm. It initializes hyperparameters and builds ops for the algorithm’s computation graph.
Args:
n_samples
: int. Number of samples from variational model for calculating stochastic gradients.
print_progress
print_progress(info_dict)
Print progress to output.
run
run(
variables=None,
use_coordinator=True,
*args,
**kwargs
)
A simple wrapper to run inference.
- Initialize algorithm via initialize.
- (Optional) Build a TensorFlow summary writer for TensorBoard.
- (Optional) Initialize TensorFlow variables.
- (Optional) Start queue runners.
- Run update for self.n_iter iterations.
- While running, print_progress.
- Finalize algorithm via finalize.
- (Optional) Stop queue runners.
To customize the way inference is run, run these steps individually.
Args:
variables
: list. A list of TensorFlow variables to initialize during inference. Default is to initialize all variables (this includes reinitializing variables that were already initialized). To avoid initializing any variables, pass in an empty list.
use_coordinator
: bool. Whether to start and stop queue runners during inference using a TensorFlow coordinator. For example, queue runners are necessary for batch training with file readers.
*args, **kwargs
: Passed into initialize.
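The sequence of steps that run performs can be mirrored with a small stand-in class. The sketch below is hypothetical plain Python, not Edward: SketchInference only imitates the documented method names (initialize, update, print_progress, finalize, run) and the n_iter attribute. Its update takes one gradient step on a made-up quadratic loss so the loop has a loss to report, and the TensorFlow-specific optional steps (summary writer, variable initialization, queue runners) are omitted.

```python
class SketchInference:
    """Hypothetical stand-in that mirrors the control flow of run()."""

    def initialize(self, n_iter=100, learning_rate=0.1):
        # Set hyperparameters and starting state.
        self.n_iter = n_iter
        self.learning_rate = learning_rate
        self.param = 5.0
        self.history = []
        self.finished = False

    def update(self):
        # One optimization step on the made-up loss (param - 1)**2.
        grad = 2.0 * (self.param - 1.0)
        self.param -= self.learning_rate * grad
        return {"loss": (self.param - 1.0) ** 2}

    def print_progress(self, info_dict):
        # Record (rather than print) the loss to keep output quiet.
        self.history.append(info_dict["loss"])

    def finalize(self):
        # Nothing to clean up in this sketch.
        self.finished = True

    def run(self, *args, **kwargs):
        # Extra arguments are passed into initialize, as documented above.
        self.initialize(*args, **kwargs)
        for _ in range(self.n_iter):
            info_dict = self.update()
            self.print_progress(info_dict)
        self.finalize()


inference = SketchInference()
inference.run(n_iter=50)
```

Calling initialize, then update and print_progress in a loop, then finalize by hand reproduces run while leaving room to change any individual step.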
update
update(feed_dict=None)
Run one iteration of optimization.
Args:
feed_dict
: dict. Feed dictionary for a TensorFlow session run. It is used to feed placeholders that are not fed during initialization.
Returns:
dict. Dictionary of algorithm-specific information. In this case, the loss function value after one iteration.
Oh, M.-S., & Berger, J. O. (1992). Adaptive importance sampling in Monte Carlo integration. Journal of Statistical Computation and Simulation, 41(3-4), 143–168.