ed.ImplicitKLqp
Class ImplicitKLqp
Inherits From: GANInference
Aliases:
- Class ed.ImplicitKLqp
- Class ed.inferences.ImplicitKLqp
Defined in edward/inferences/implicit_klqp.py.
Variational inference with implicit probabilistic models (Tran et al., 2017).
It minimizes the KL divergence
$$\text{KL}( q(z, \beta; \lambda) \| p(z, \beta \mid x) ),$$
where $z$ are local variables associated with a data point and $\beta$ are global variables shared across data points.
Global latent variables require log_prob() and must return a random sample when fetched from the graph. Local latent variables and observed variables require only a random sample when fetched from the graph. (This holds for both $p$ and $q$.)
All variational factors must be reparameterizable: each random variable (rv) satisfies rv.is_reparameterized and rv.is_continuous.
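Reparameterizability is what lets gradients pass through sampled values. The following is a minimal pure-Python sketch, not Edward code: it assumes a Gaussian factor $z \sim \mathcal{N}(\mu, \sigma^2)$ and a toy objective $f(z) = z^2$, writes each sample as $z = \mu + \sigma\epsilon$ with parameter-free noise $\epsilon \sim \mathcal{N}(0, 1)$, and estimates the gradients of $\mathbb{E}[z^2]$ by applying the chain rule through $z$.

```python
import random


def reparam_grad_estimate(mu, sigma, n_samples=20000, seed=0):
    """Monte Carlo estimate of the gradients of E[z**2], z ~ N(mu, sigma**2).

    Each sample is written as z = mu + sigma * eps with eps ~ N(0, 1), so the
    noise does not depend on (mu, sigma) and gradients flow through z.
    """
    rng = random.Random(seed)
    grad_mu = 0.0
    grad_sigma = 0.0
    for _ in range(n_samples):
        eps = rng.gauss(0.0, 1.0)   # parameter-free noise
        z = mu + sigma * eps        # differentiable transform of the noise
        df_dz = 2.0 * z             # derivative of f(z) = z**2
        grad_mu += df_dz * 1.0      # chain rule: dz/dmu = 1
        grad_sigma += df_dz * eps   # chain rule: dz/dsigma = eps
    return grad_mu / n_samples, grad_sigma / n_samples


g_mu, g_sigma = reparam_grad_estimate(mu=1.5, sigma=0.5)
print(g_mu, g_sigma)  # close to the exact values 2*mu = 3.0 and 2*sigma = 1.0
```

The exact gradients are $2\mu$ and $2\sigma$, so the estimates should land near 3.0 and 1.0; a non-reparameterizable factor would not admit this pathwise estimator.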
Notes
Unlike GANInference, discriminator takes dicts as input and must select the appropriate values through lexical scoping from the previously defined model and latent variables. This is necessary because the discriminator can take an arbitrary set of data, latent, and global variables.
Note that the type of discriminator's output changes when one passes the scale argument to initialize():
- If scale has at most one item, then discriminator outputs a tensor whose multiplication with that element is broadcastable. (For example, the output is a tensor and the single scale factor is a scalar.)
- If scale has more than one item, then in order to scale its corresponding output, discriminator must output a dictionary with the same size and keys as scale.
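The two cases above can be illustrated with a hypothetical helper, apply_scale, which is not part of Edward. Python lists stand in for tensors and plain strings stand in for the RandomVariable keys of scale; the helper only shows which shape of discriminator output pairs with which shape of scale.

```python
def apply_scale(output, scale):
    """Hypothetical helper: scale a ratio estimator's output.

    output: list of floats (one per data point) when scale has at most one
            item; otherwise a dict with the same keys as scale.
    scale:  dict mapping a variable name to its scale factor.
    """
    if len(scale) <= 1:
        # Zero or one factor: broadcast a single scalar over the output vector.
        factor = next(iter(scale.values()), 1.0)
        return [factor * r for r in output]
    # Several factors: each key's output is scaled by its own factor.
    if not isinstance(output, dict) or set(output) != set(scale):
        raise ValueError("output must be a dict with the same keys as scale")
    return {key: [scale[key] * r for r in output[key]] for key in scale}


print(apply_scale([0.5, -1.0], {"x": 10.0}))                        # [5.0, -10.0]
print(apply_scale({"x": [1.0], "y": [2.0]}, {"x": 3.0, "y": 0.5}))  # {'x': [3.0], 'y': [1.0]}
```

Raising an error on mismatched keys is a design choice of this sketch; it makes the "same size and keys" requirement explicit.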
The objective function also adds the sum of all tensors in the REGULARIZATION_LOSSES collection.
Methods
__init__
__init__(
latent_vars,
data=None,
discriminator=None,
global_vars=None
)
Create an inference algorithm.
Args:
discriminator: function. Function (with parameters). Unlike GANInference, it is interpreted as a ratio estimator rather than a discriminator. It takes three arguments: a data dict, a local latent variable dict, and a global latent variable dict. As with GAN discriminators, it can take a batch of data points and local variables of size $M$ and output a vector of length $M$.
global_vars: dict of RandomVariable to RandomVariable. Identifies which variables in latent_vars are global variables, shared across data points. These are excluded from the ratio estimation problem and are estimated with tractable variational approximations.
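A ratio estimator with the three-dict signature might look like the following sketch. It is hypothetical: in Edward the dict keys are RandomVariable objects (reached through lexical scoping) and the function would be built from TensorFlow ops with trainable parameters, whereas here string keys, Python lists, and fixed weights stand in for them. The linear form is arbitrary; the point is that it subsets each dict and returns one value per data point.

```python
def ratio_estimator(data_dict, local_dict, global_dict, weights=(0.3, -0.2, 0.1)):
    """Hypothetical ratio estimator r(x_n, z_n, beta) for a batch of M points.

    data_dict:   maps a (hypothetical) name to a list of M observations.
    local_dict:  maps a name to a list of M local latent samples.
    global_dict: maps a name to one global latent sample.
    Returns a list of length M: one log-ratio estimate per data point.
    """
    x = data_dict["x"]            # pick out only the values this model needs
    z = local_dict["z"]
    beta = global_dict["beta"]
    w_x, w_z, w_beta = weights    # stand-ins for trainable parameters
    return [w_x * x_n + w_z * z_n + w_beta * beta for x_n, z_n in zip(x, z)]


r = ratio_estimator({"x": [1.0, 2.0]}, {"z": [0.5, -0.5]}, {"beta": 1.0})
print(r)  # two values, approximately [0.3, 0.8]
```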
build_loss_and_gradients
build_loss_and_gradients(var_list)
Build the loss function
$$-\Big(\mathbb{E}_{q(\beta)} [\log p(\beta) - \log q(\beta) ] + \sum_{n=1}^N \mathbb{E}_{q(\beta)q(z_n\mid\beta)} [ r^*(x_n, z_n, \beta) ] \Big).$$
We minimize it with respect to parameterized variational families $q(z, \beta; \lambda)$.
$r^*(x_n, z_n, \beta)$ is a function of a single data point $x_n$, a single local variable $z_n$, and all global variables $\beta$. It is equal to the log-ratio
$$\log p(x_n, z_n\mid \beta) - \log q(x_n, z_n\mid \beta),$$
where $q(x_n)$ is the empirical data distribution. Rather than being calculated explicitly, $r^*(x, z, \beta)$ is obtained by solving a ratio estimation problem, that is, by minimizing the specified ratio_loss.
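For intuition, a single-sample Monte Carlo estimate of the loss above can be written in a few lines of plain Python. This is a sketch under stated assumptions, not Edward's graph construction: log_p_beta and log_q_beta are the log densities of one draw $\beta \sim q(\beta)$, ratios holds the estimator outputs for each data point with $z_n \sim q(z_n \mid \beta)$, scale is an assumed factor on the summed ratios (for example $N/M$ when a minibatch of $M$ of the $N$ points is used), and reg_losses mirrors the REGULARIZATION_LOSSES sum.

```python
def implicit_klqp_loss(log_p_beta, log_q_beta, ratios, scale=1.0, reg_losses=()):
    """Single-sample Monte Carlo estimate of the ImplicitKLqp loss.

    log_p_beta, log_q_beta: log p(beta) and log q(beta) for one beta ~ q(beta).
    ratios:     estimated r*(x_n, z_n, beta), one per data point, z_n ~ q(z_n | beta).
    scale:      assumed factor on the summed ratios (e.g. N / M for a minibatch).
    reg_losses: extra penalty terms, mirroring the REGULARIZATION_LOSSES sum.
    """
    global_term = log_p_beta - log_q_beta   # one-sample estimate of E_q(beta)[log p - log q]
    local_term = scale * sum(ratios)        # one-sample estimate of sum_n E[r*(x_n, z_n, beta)]
    return -(global_term + local_term) + sum(reg_losses)


loss = implicit_klqp_loss(-1.0, -0.5, [0.2, 0.3], scale=2.0, reg_losses=[0.1])
print(loss)  # -(-0.5 + 1.0) + 0.1, i.e. about -0.4
```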
Gradients are taken using the reparameterization trick (Kingma & Welling, 2014).
Notes
This implementation also supports model parameters $p(x, z, \beta; \theta)$ and variational distributions with inference networks $q(z\mid x)$.
Several extensions could easily be made to this implementation:
- further factorizations can be used to better leverage the graph structure for more complicated models;
- score function gradients for global variables;
- use more samples; this would require the copy() utility function for q’s as well, and an additional loop. We opt not to because it complicates the code;
- analytic KL/swapping out the penalty term for the globals.
finalize
finalize()
Function to call after convergence.
initialize
initialize(
ratio_loss='log',
*args,
**kwargs
)
Initialize inference algorithm. It initializes hyperparameters and builds ops for the algorithm’s computation graph.
Args:
ratio_loss: str or fn. Loss function minimized to get the ratio estimator: ‘log’ or ‘hinge’. Alternatively, one can pass in a function that takes two inputs, psamples and qsamples, and outputs a point-wise value with shape matching the shapes of the two inputs.
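The sketch below gives point-wise versions of the two built-in options in plain Python, under the assumption that ‘log’ and ‘hinge’ denote the standard logistic and hinge classification losses on the estimator's outputs, with psamples (outputs on samples from $p$) treated as the positive class and qsamples (outputs on samples from $q$) as the negative class; check edward/inferences/implicit_klqp.py for the exact definitions. For the logistic loss, the minimizer in expectation is the log density ratio, which is why it recovers $r^*$. Lists stand in for tensors.

```python
import math


def softplus(t):
    """Numerically stable log(1 + exp(t))."""
    return max(t, 0.0) + math.log1p(math.exp(-abs(t)))


def log_loss(psamples, qsamples):
    """Point-wise logistic loss: push r up on p-samples, down on q-samples."""
    return [softplus(-rp) + softplus(rq) for rp, rq in zip(psamples, qsamples)]


def hinge_loss(psamples, qsamples):
    """Point-wise hinge loss with margin 1."""
    return [max(0.0, 1.0 - rp) + max(0.0, 1.0 + rq)
            for rp, rq in zip(psamples, qsamples)]


print(log_loss([0.0], [0.0]))                # [2 * log(2)], about [1.386]
print(hinge_loss([2.0, 0.0], [-2.0, 0.0]))   # [0.0, 2.0]
```

Any function with this two-argument, shape-preserving signature, written with TensorFlow ops instead of lists, fits the description of a custom ratio_loss.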
print_progress
print_progress(info_dict)
Print progress to output.
run
run(
variables=None,
use_coordinator=True,
*args,
**kwargs
)
A simple wrapper to run inference.
- Initialize algorithm via initialize.
- (Optional) Build a TensorFlow summary writer for TensorBoard.
- (Optional) Initialize TensorFlow variables.
- (Optional) Start queue runners.
- Run update for self.n_iter iterations.
- While running, call print_progress.
- Finalize algorithm via finalize.
- (Optional) Stop queue runners.
To customize the way inference is run, run these steps individually.
Args:
variables: list. A list of TensorFlow variables to initialize during inference. Default is to initialize all variables (this includes reinitializing variables that were already initialized). To avoid initializing any variables, pass in an empty list.
use_coordinator: bool. Whether to start and stop queue runners during inference using a TensorFlow coordinator. For example, queue runners are necessary for batch training with file readers.
*args, **kwargs: Passed into initialize.
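Running the steps individually amounts to the loop below. To keep the sketch runnable without TensorFlow, StubInference is a hypothetical stand-in that only mimics the documented method names (initialize, update, print_progress, finalize); the keys of its info dict and passing n_iter through initialize are assumptions for illustration. With a real ImplicitKLqp object one would also handle the optional steps (summary writer, variable initialization, queue runners) at the marked place.

```python
class StubInference:
    """Hypothetical stand-in exposing the method names documented on this page."""

    def __init__(self):
        self.n_iter = 0
        self.t = 0
        self.log = []

    def initialize(self, n_iter=1000, **kwargs):
        self.n_iter = n_iter  # assumed to be set through initialize's kwargs
        self.log.append("initialize")

    def update(self, feed_dict=None, variables=None):
        self.t += 1  # counts every call, whichever variables are updated
        return {"t": self.t, "loss": 0.0, "loss_d": 0.0}  # assumed keys

    def print_progress(self, info_dict):
        self.log.append(("progress", info_dict["t"]))

    def finalize(self):
        self.log.append("finalize")


def run_manually(inference, n_iter):
    """Mirror of the run() steps, minus the optional TensorFlow-specific ones."""
    inference.initialize(n_iter=n_iter)
    # (Optional, not shown) summary writer, variable initialization, queue runners.
    for _ in range(inference.n_iter):
        info_dict = inference.update()
        inference.print_progress(info_dict)
    inference.finalize()
    # (Optional, not shown) stop queue runners.
    return inference


inf = run_manually(StubInference(), n_iter=3)
print(inf.log)
```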
update
update(
feed_dict=None,
variables=None
)
Run one iteration of optimization.
Args:
feed_dict: dict. Feed dictionary for a TensorFlow session run. It is used to feed placeholders that are not fed during initialization.
variables: str. Which set of variables to update. Either “Disc” or “Gen”. Default is both.
Returns:
dict. Dictionary of algorithm-specific information. In this case, the iteration number and generative and discriminative losses.
Notes
The returned iteration number is the total number of calls to update. Each update may include updating only a subset of parameters.
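Because update accepts variables set to “Disc” or “Gen”, training can alternate between the ratio estimator (“Disc”) and the remaining parameters (“Gen”). The helper alternating_schedule below is hypothetical, not an Edward function: it builds the sequence of variables arguments for a schedule with several ratio-estimator steps per “Gen” step, and shows why the iteration number counts calls rather than “Gen” updates. Each entry would be passed as inference.update(variables=...).

```python
def alternating_schedule(n_rounds, disc_steps=1):
    """Hypothetical schedule of the `variables` argument passed to update().

    Each round performs `disc_steps` ratio-estimator ("Disc") updates
    followed by one "Gen" update.
    """
    schedule = []
    for _ in range(n_rounds):
        schedule.extend(["Disc"] * disc_steps)
        schedule.append("Gen")
    return schedule


steps = alternating_schedule(n_rounds=2, disc_steps=3)
print(steps)
# Each entry is one call to update(), so the reported iteration number after
# this loop would be len(steps) = 8, even though "Gen" was updated only twice.
```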
Kingma, D., & Welling, M. (2014). Auto-encoding variational Bayes. In International conference on learning representations.
Tran, D., Hoffman, M. D., Saurous, R. A., Brevdo, E., Murphy, K., & Blei, D. M. (2017). Deep probabilistic programming. In International conference on learning representations.