ed.WGANInference
Class WGANInference
Inherits From: GANInference
Aliases:
- Class ed.WGANInference
- Class ed.inferences.WGANInference
Defined in edward/inferences/wgan_inference.py.
Parameter estimation with GAN-style training (Goodfellow et al., 2014), using the Wasserstein distance (Arjovsky, Chintala, & Bottou, 2017).
Works for the class of implicit (and differentiable) probabilistic models. These models do not require a tractable density and assume only a program that generates samples.
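For orientation, the Wasserstein GAN objective of Arjovsky, Chintala, & Bottou (2017) can be written as below. The critic D ranges over 1-Lipschitz functions, and the inner maximum estimates the Wasserstein-1 distance between the data distribution and the distribution of generated samples G(z). The symbols p_data, p(z), G, and D are our own notation, not names used by Edward.

```latex
\min_{G} \; \max_{\|D\|_{L} \le 1} \;
  \mathbb{E}_{x \sim p_{\text{data}}}\big[ D(x) \big]
  \;-\; \mathbb{E}_{z \sim p(z)}\big[ D(G(z)) \big]
```

In practice the Lipschitz constraint is only enforced approximately, through the penalty or clip arguments of initialize.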
Notes
Argument-wise, the only difference from GANInference is conceptual: the discriminator is better described as a test function or critic. WGANInference continues to use the name discriminator only to share methods and attributes with GANInference.
The objective function also adds the sum of all tensors in the REGULARIZATION_LOSSES collection.
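As a rough illustration of the objective, the sketch below computes Monte Carlo estimates of the critic and generator losses from critic scores on a batch of real and generated samples, then adds a sum of regularization values in the spirit of the REGULARIZATION_LOSSES note. It is plain Python, not Edward's implementation: the function names, the precomputed `penalty` argument, and passing regularization values as a list are assumptions made for illustration.

```python
from statistics import fmean


def critic_loss(d_real, d_fake, penalty=0.0, reg_terms=()):
    """Loss minimized by the critic (Edward's `discriminator`); illustrative only.

    d_real, d_fake: critic scores on real and generated samples.
    penalty: a precomputed gradient-penalty value (0.0 for none).
    reg_terms: regularization values added to the objective.
    """
    # Minimizing E[D(fake)] - E[D(real)] maximizes the Wasserstein estimate.
    return fmean(d_fake) - fmean(d_real) + penalty + sum(reg_terms)


def generator_loss(d_fake, reg_terms=()):
    """Loss minimized by the generator: raise the critic's scores on fakes."""
    return -fmean(d_fake) + sum(reg_terms)


# Hypothetical critic scores for a tiny batch.
d_real = [1.0, 2.0, 3.0]
d_fake = [0.5, 1.5]
print(critic_loss(d_real, d_fake, penalty=0.5, reg_terms=[0.1]))  # about -0.4
print(generator_loss(d_fake))  # -1.0
```

The two losses pull in opposite directions on the same quantity, E[D(fake)]: the critic pushes it down relative to E[D(real)], while the generator pushes it up.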
Examples
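import edward as ed
import tensorflow as tf
from edward.models import Normal
z = Normal(loc=tf.zeros([100, 10]), scale=tf.ones([100, 10]))
x = generative_network(z)  # user-defined network mapping noise z to samples
inference = ed.WGANInference({x: x_data}, discriminator)  # x_data: observed batch shaped like x; discriminator: user-defined critic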
Methods
__init__
__init__(
*args,
**kwargs
)
build_loss_and_gradients
build_loss_and_gradients(var_list)
finalize
finalize()
Function to call after convergence.
initialize
initialize(
penalty=10.0,
clip=None,
*args,
**kwargs
)
Initialize inference algorithm. It initializes hyperparameters and builds ops for the algorithm’s computation graph.
Args:
penalty: float. Scalar weight of the gradient penalty (Gulrajani, Ahmed, Arjovsky, Dumoulin, & Courville, 2017), which encourages the critic's gradients to have norm 1. Set to None (or 0.0) to use no penalty.
clip: float. Value to clip the critic's weights by, so that each weight lies in [-clip, clip]. Default is no clipping.
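The two options can be sketched in plain Python as follows. The gradient penalty follows Gulrajani et al. (2017): sample points on straight lines between paired real and generated samples, measure the norm of the critic's gradient there, and penalize its squared deviation from 1. Weight clipping follows Arjovsky et al. (2017): each critic weight is clamped to [-clip, clip]. This is an illustrative sketch, not Edward's implementation; the finite-difference gradient (a stand-in for automatic differentiation), the function names, the list-of-floats representation, and the toy linear critic are all assumptions.

```python
import math
import random


def numerical_grad(f, x, h=1e-5):
    """Central-difference gradient of scalar function f at point x (a list).

    Stand-in for automatic differentiation; used here only for illustration.
    """
    grad = []
    for i in range(len(x)):
        up, down = list(x), list(x)
        up[i] += h
        down[i] -= h
        grad.append((f(up) - f(down)) / (2 * h))
    return grad


def gradient_penalty(critic, x_real, x_fake, penalty=10.0, rng=random):
    """penalty * batch mean of (||grad critic(x_hat)||_2 - 1)^2."""
    if not penalty:  # None or 0.0 disables the penalty
        return 0.0
    total = 0.0
    for xr, xf in zip(x_real, x_fake):
        eps = rng.random()  # one interpolation weight per (real, fake) pair
        x_hat = [eps * a + (1.0 - eps) * b for a, b in zip(xr, xf)]
        slope = math.sqrt(sum(g * g for g in numerical_grad(critic, x_hat)))
        total += (slope - 1.0) ** 2
    return penalty * total / len(x_real)


def clip_weights(weights, clip=None):
    """Clamp every weight into [-clip, clip]; clip=None leaves them unchanged."""
    if clip is None:
        return list(weights)
    return [max(-clip, min(clip, w)) for w in weights]


def linear_critic(x):
    # Toy critic with weights (3, 4): its gradient is (3, 4) everywhere, norm 5.
    return 3.0 * x[0] + 4.0 * x[1]


x_real = [[1.0, 2.0], [0.0, 1.0]]
x_fake = [[0.5, 0.5], [2.0, -1.0]]
print(gradient_penalty(linear_critic, x_real, x_fake))  # about 160.0 = 10 * (5 - 1)**2
print(clip_weights([-0.3, 0.005, 0.2], clip=0.01))  # [-0.01, 0.005, 0.01]
```

Clipping acts on the critic's parameters after each update, whereas the penalty is a term added to the critic's loss; both are approximate ways of keeping the critic close to 1-Lipschitz.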
print_progress
print_progress(info_dict)
Print progress to output.
run
run(
variables=None,
use_coordinator=True,
*args,
**kwargs
)
A simple wrapper to run inference.
- Initialize the algorithm via initialize.
- (Optional) Build a TensorFlow summary writer for TensorBoard.
- (Optional) Initialize TensorFlow variables.
- (Optional) Start queue runners.
- Run update for self.n_iter iterations.
- While running, print progress via print_progress.
- Finalize the algorithm via finalize.
- (Optional) Stop queue runners.
To customize the way inference is run, run these steps individually.
Args:
variables: list. A list of TensorFlow variables to initialize during inference. Default is to initialize all variables (this includes reinitializing variables that were already initialized). To avoid initializing any variables, pass in an empty list.
use_coordinator: bool. Whether to start and stop queue runners during inference using a TensorFlow coordinator. For example, queue runners are necessary for batch training with file readers.
*args, **kwargs: Passed into initialize.
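Running the steps individually amounts to a plain loop over the methods documented on this page. To keep the sketch runnable without TensorFlow, StubInference below is a hypothetical stand-in that only records which methods were called, and its n_iter keyword is likewise an assumption; the optional TensorBoard, variable-initialization, and queue-runner steps are omitted. With a real WGANInference object, the same loop would call its initialize, update, print_progress, and finalize methods.

```python
class StubInference:
    """Hypothetical stand-in that records method calls; not an Edward class."""

    def __init__(self):
        self.calls = []
        self.n_iter = 0

    def initialize(self, n_iter=3, **kwargs):
        self.n_iter = n_iter
        self.calls.append("initialize")

    def update(self, feed_dict=None, variables=None):
        self.calls.append("update")
        return {"t": self.calls.count("update")}  # per-iteration info dict

    def print_progress(self, info_dict):
        self.calls.append("print_progress")

    def finalize(self):
        self.calls.append("finalize")


def run_manually(inference, **kwargs):
    """Mirror the run steps listed above, skipping the optional TensorFlow ones."""
    inference.initialize(**kwargs)  # initialize the algorithm
    # TensorFlow variable initialization and queue runners would go here.
    for _ in range(inference.n_iter):  # run update for self.n_iter iterations
        info_dict = inference.update()
        inference.print_progress(info_dict)  # report progress while running
    inference.finalize()  # finalize the algorithm


stub = StubInference()
run_manually(stub, n_iter=2)
print(stub.calls)
```

Writing the loop out this way is what makes customization possible, for example by passing a different feed_dict to update on each iteration.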
update
update(
feed_dict=None,
variables=None
)
References
Arjovsky, M., Chintala, S., & Bottou, L. (2017). Wasserstein GAN. In International conference on machine learning.
Goodfellow, I., Pouget-Abadie, J., Mirza, M., Xu, B., Warde-Farley, D., Ozair, S., … Bengio, Y. (2014). Generative adversarial nets. In Neural information processing systems.
Gulrajani, I., Ahmed, F., Arjovsky, M., Dumoulin, V., & Courville, A. (2017). Improved Training of Wasserstein GANs. arXiv.org. Retrieved from http://arxiv.org/abs/1704.00028v1