
ed.GANInference

Edited by 小牛
2023-12-01

Class GANInference

Inherits From: VariationalInference

Aliases:

  • Class ed.GANInference
  • Class ed.inferences.GANInference

Defined in edward/inferences/gan_inference.py.

Parameter estimation with GAN-style training (Goodfellow et al., 2014).
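For reference, Goodfellow et al. (2014) train a generator G and a discriminator D on the minimax value function below. The second line is the alternative generator objective they suggest to avoid weak gradients early in training. The symbols d (the discriminator's real-valued output) and σ (the logistic sigmoid) are introduced here only to connect D to the logit output required by the discriminator argument of __init__.

$$
\min_G \max_D \; V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}\left[\log D(x)\right] + \mathbb{E}_{z \sim p(z)}\left[\log\left(1 - D(G(z))\right)\right]
$$

$$
\text{non-saturating generator objective:} \quad \max_G \; \mathbb{E}_{z \sim p(z)}\left[\log D(G(z))\right]
$$

$$
D(x) = \sigma(d(x)) = \frac{1}{1 + e^{-d(x)}}
$$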

Works for the class of implicit (and differentiable) probabilistic models. These models do not require a tractable density and assume only a program that generates samples.
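As a minimal illustration of what "only a program that generates samples" means, the plain-Python sketch below defines a toy implicit model: noise z ~ Normal(0, 1) is pushed through a hypothetical generator function. It does not use Edward or TensorFlow, and `generator` and `sample_implicit` are illustrative names, not part of the library. Nothing in it evaluates a density; GAN-style training only ever needs to call the sampler.

```python
import math
import random


def generator(z):
    """Hypothetical generator g(z): a smooth, differentiable function of the noise."""
    return 2.0 * math.tanh(z) + 0.5 * z


def sample_implicit(n, seed=0):
    """Draw n samples x = g(z) with z ~ Normal(0, 1).

    Only this sampling program is needed; no density of x is ever evaluated.
    """
    rng = random.Random(seed)
    return [generator(rng.gauss(0.0, 1.0)) for _ in range(n)]


samples = sample_implicit(5, seed=42)
print(samples)
```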

Notes

GANInference does not support latent variable inference. Note that GAN-style training also samples from the prior: this does not work well for latent variables that are shared across many data points (global variables).

In building the computation graph for inference, the discriminator’s parameters can be accessed with the variable scope “Disc”.
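In TensorFlow 1.x, variables created under a scope can be listed with `tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="Disc")`, whose `scope` argument filters item names with `re.match`. The sketch below mimics only that name filtering in plain Python over a hypothetical list of variable names (the `generator/...` names are invented for illustration), so it runs without TensorFlow.

```python
import re

# Hypothetical variable names as they might appear in a TensorFlow 1.x graph.
variable_names = [
    "Disc/dense/kernel:0",
    "Disc/dense/bias:0",
    "Disc/dense_1/kernel:0",
    "generator/dense/kernel:0",
    "generator/dense/bias:0",
]


def filter_by_scope(names, scope):
    """Keep names whose beginning matches `scope`, mimicking re.match-based filtering."""
    return [name for name in names if re.match(scope, name)]


# Discriminator variables are the ones under the "Disc" scope; the rest belong to the model.
disc_vars = filter_by_scope(variable_names, "Disc")
gen_vars = [name for name in variable_names if name not in disc_vars]
print(disc_vars)
print(gen_vars)
```

Because `re.match` anchors only at the start of the name, the pattern `"Disc"` would also match a scope such as `Discriminator2`; passing `"Disc/"` is stricter.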

GANInference also supports only one observed random variable in data.

The objective function also adds to itself a summation over all tensors in the REGULARIZATION_LOSSES collection.
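A minimal sketch of that summation, assuming each regularization term is a scalar L2 penalty (the `l2_penalty` helper and all numbers are hypothetical). In TensorFlow 1.x such terms live in the collection named by `tf.GraphKeys.REGULARIZATION_LOSSES`; here a plain list stands in for it.

```python
def l2_penalty(weights, scale):
    """Hypothetical L2 penalty: scale * sum(w**2) / 2."""
    return scale * sum(w * w for w in weights) / 2.0


# Stand-in for the REGULARIZATION_LOSSES collection: one scalar per regularized layer.
regularization_losses = [
    l2_penalty([1.0, 2.0], scale=0.1),  # 0.1 * 5 / 2 = 0.25
    l2_penalty([3.0], scale=0.1),       # 0.1 * 9 / 2 = 0.45
]

base_loss = 1.0  # hypothetical value of the adversarial loss term
# The objective adds the sum over every collected regularization term.
objective = base_loss + sum(regularization_losses)
print(objective)
```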

Examples

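import edward as ed
import tensorflow as tf
from edward.models import Normal
z = Normal(loc=tf.zeros([100, 10]), scale=tf.ones([100, 10]))
x = generative_network(z)  # user-defined generator network (not shown)
inference = ed.GANInference({x: x_data}, discriminator)  # x_data, discriminator: user-defined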

Methods

__init__

__init__(
    data,
    discriminator
)

Create an inference algorithm.

Args:

  • data: dict. Data dictionary which binds observed variables (of type RandomVariable or tf.Tensor) to their realizations (of type tf.Tensor). It can also bind placeholders (of type tf.Tensor) used in the model to their realizations.
  • discriminator: function. Function (with parameters) to discriminate samples. It should output logits (real values), not probabilities in $[0, 1]$.
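Outputting logits lets the losses be computed with a numerically stable sigmoid cross-entropy, `max(l, 0) - l * y + log(1 + exp(-|l|))`, the formula documented for `tf.nn.sigmoid_cross_entropy_with_logits`. The plain-Python sketch below uses it for the standard GAN losses: the discriminator labels real samples 1 and generated samples 0, and the generator uses the non-saturating loss suggested by Goodfellow et al. (2014). That GANInference uses exactly these forms is an assumption here, not something this page states; all function names are illustrative.

```python
import math


def sigmoid_cross_entropy_with_logits(logit, label):
    """Binary cross-entropy from a logit: max(l, 0) - l * y + log(1 + exp(-|l|)).

    Stable for large |logit|: exp() is only ever called on a non-positive number.
    """
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))


def mean(values):
    return sum(values) / len(values)


def discriminator_loss(logits_real, logits_fake):
    # Real samples are labeled 1, generated samples 0.
    return (mean([sigmoid_cross_entropy_with_logits(l, 1.0) for l in logits_real])
            + mean([sigmoid_cross_entropy_with_logits(l, 0.0) for l in logits_fake]))


def generator_loss(logits_fake):
    # Non-saturating generator loss: the generator wants its samples labeled 1.
    return mean([sigmoid_cross_entropy_with_logits(l, 1.0) for l in logits_fake])


print(discriminator_loss([2.0, 3.0], [-1.0, 0.5]))
print(generator_loss([-1.0, 0.5]))
```

Passing a probability in [0, 1] where a logit is expected would silently change these losses, which is why the discriminator should not end with a sigmoid.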

build_loss_and_gradients

build_loss_and_gradients(var_list)

finalize

finalize()

Function to call after convergence.

initialize

initialize(
    optimizer=None,
    optimizer_d=None,
    global_step=None,
    global_step_d=None,
    var_list=None,
    *args,
    **kwargs
)

Initialize inference algorithm. It initializes hyperparameters and builds ops for the algorithm’s computation graph.

Args:

  • optimizer: str or tf.train.Optimizer. A TensorFlow optimizer to use for optimizing the generator objective. Alternatively, pass the name of a TensorFlow optimizer, in which case its default parameters are used.
  • optimizer_d: str or tf.train.Optimizer. A TensorFlow optimizer to use for optimizing the discriminator objective. Alternatively, pass the name of a TensorFlow optimizer, in which case its default parameters are used.
  • global_step: tf.Variable. Optional Variable to increment by one after the variables for the generator have been updated. See tf.train.Optimizer.apply_gradients.
  • global_step_d: tf.Variable. Optional Variable to increment by one after the variables for the discriminator have been updated. See tf.train.Optimizer.apply_gradients.
  • var_list: list of tf.Variable. List of TensorFlow variables to optimize over (in the generative model). Default is all trainable variables that latent_vars and data depend on.

print_progress

print_progress(info_dict)

Print progress to output.

run

run(
    variables=None,
    use_coordinator=True,
    *args,
    **kwargs
)

A simple wrapper to run inference.

  1. Initialize algorithm via initialize.
  2. (Optional) Build a TensorFlow summary writer for TensorBoard.
  3. (Optional) Initialize TensorFlow variables.
  4. (Optional) Start queue runners.
  5. Run update for self.n_iter iterations.
  6. While running, print_progress.
  7. Finalize algorithm via finalize.
  8. (Optional) Stop queue runners.

To customize the way inference is run, run these steps individually.
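The loop below sketches the steps listed above when run individually. `ToyInference` is a hypothetical stand-in that only records which method was called, so the control flow can be checked without Edward; for simplicity it takes `n_iter` in its constructor. With a real `GANInference` object the same method names apply, and the optional variable-initialization step would be something like running `tf.global_variables_initializer()` in the TensorFlow session.

```python
class ToyInference:
    """Hypothetical stand-in using the method names listed above.

    It only records the order of calls; it is not Edward's implementation.
    """

    def __init__(self, n_iter):
        self.n_iter = n_iter
        self.calls = []
        self.t = 0

    def initialize(self, *args, **kwargs):
        self.calls.append("initialize")

    def update(self):
        self.t += 1
        self.calls.append("update")
        return {"t": self.t}

    def print_progress(self, info_dict):
        self.calls.append("print_progress")

    def finalize(self):
        self.calls.append("finalize")


inference = ToyInference(n_iter=3)
inference.initialize()
# (Optional) variable initialization and queue runners would go here.
for _ in range(inference.n_iter):
    info_dict = inference.update()
    inference.print_progress(info_dict)
inference.finalize()
print(inference.calls)
```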

Args:

  • variables: list. A list of TensorFlow variables to initialize during inference. Default is to initialize all variables (this includes reinitializing variables that were already initialized). To avoid initializing any variables, pass in an empty list.
  • use_coordinator: bool. Whether to start and stop queue runners during inference using a TensorFlow coordinator. For example, queue runners are necessary for batch training with file readers.
  • *args, **kwargs: Passed into initialize.

update

update(
    feed_dict=None,
    variables=None
)

Run one iteration of optimization.

Args:

  • feed_dict: dict. Feed dictionary for a TensorFlow session run. It is used to feed placeholders that are not fed during initialization.
  • variables: str. Which set of variables to update. Either “Disc” or “Gen”. Default is both.

Returns:

dict. Dictionary of algorithm-specific information. In this case, the iteration number and generative and discriminative losses.

Notes

The returned iteration number is the total number of calls to update. Each call may update only a subset of the parameters.
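To make the counting rule concrete, the plain-Python sketch below simulates a hypothetical schedule of two discriminator-only updates per generator-only update, the kind of schedule one could drive by passing variables="Disc" or variables="Gen" to update. The function `run_schedule` is illustrative only; it shows the iteration number growing on every call while each subset is updated less often.

```python
def run_schedule(n_calls, d_steps_per_g_step):
    """Simulate update calls under a hypothetical alternating schedule.

    Returns the `variables` value used for each call, the iteration number t,
    and how many times each subset was updated.
    """
    block = ["Disc"] * d_steps_per_g_step + ["Gen"]
    t = 0
    counts = {"Disc": 0, "Gen": 0}
    schedule = []
    for i in range(n_calls):
        variables = block[i % len(block)]
        t += 1  # incremented on every call, whichever subset is updated
        counts[variables] += 1
        schedule.append(variables)
    return schedule, t, counts


schedule, t, counts = run_schedule(6, d_steps_per_g_step=2)
print(schedule, t, counts)
```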

Goodfellow, I., Pouget-Abadie, J., Mirza, M., Xu, B., Warde-Farley, D., Ozair, S., … Bengio, Y. (2014). Generative adversarial nets. In Neural Information Processing Systems.