ed.copy
Edited by 小牛
2023-12-01
Aliases:
ed.copy
ed.util.copy
```python
copy(
    org_instance,
    dict_swap=None,
    scope='copied',
    replace_itself=False,
    copy_q=False,
    copy_parent_rvs=True
)
```
Defined in `edward/util/random_variables.py`.
Build a new node in the TensorFlow graph from `org_instance`, where any of its ancestors existing in `dict_swap` are replaced with `dict_swap`'s corresponding value.
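The swap rule can be illustrated without TensorFlow. The sketch below is a hypothetical pure-Python model, not Edward's implementation: a toy `Node` class (an assumption made for illustration) stands in for a graph node, and `toy_copy` rebuilds the node's ancestry under a new scope, substituting any ancestor found in `dict_swap`. It ignores the `replace_itself`, `copy_q`, and `copy_parent_rvs` options. The original nodes are left untouched.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass(eq=False)  # eq=False keeps identity hashing, so nodes can be dict keys
class Node:
    """Hypothetical stand-in for a graph node (not a TensorFlow type)."""
    name: str
    op: str                          # "const" or "mul" in this toy model
    inputs: Tuple["Node", ...] = ()
    value: Optional[float] = None    # only used when op == "const"


def evaluate(node: Node) -> float:
    """Recursively compute a node's value."""
    if node.op == "const":
        return node.value
    a, b = (evaluate(p) for p in node.inputs)
    return a * b  # the only non-constant op in this toy model is "mul"


def toy_copy(node: Node, dict_swap: dict, scope: str = "copied") -> Node:
    """Rebuild `node` under `scope`, replacing any ancestor found in dict_swap."""
    if node in dict_swap:
        return dict_swap[node]  # swapped nodes are used as-is
    new_inputs = tuple(toy_copy(p, dict_swap, scope) for p in node.inputs)
    return Node(f"{scope}/{node.name}", node.op, new_inputs, node.value)


x = Node("x", "const", value=2.0)
y = Node("y", "const", value=3.0)
z = Node("z", "mul", (x, y))
qx = Node("qx", "const", value=4.0)

z_new = toy_copy(z, {x: qx})
print(evaluate(z), evaluate(z_new))  # 6.0 12.0
print(z_new.inputs[1].name)          # copied/y
```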
Copying is done recursively. Any `Operation` whose output is required to copy `org_instance` is also copied (if it isn't already copied within the new scope).
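The "copy at most once within the new scope" rule matters when an ancestor is shared, as in a diamond-shaped graph. Below is a minimal sketch using a hypothetical toy `Node` class rather than TensorFlow objects. Copies are cached in a `registry` dict keyed by the scoped name, which loosely mirrors checking whether a name already exists in the new scope; this caching scheme is an assumption of the sketch, not Edward's code.

```python
from dataclasses import dataclass
from typing import Dict, Tuple


@dataclass(eq=False)  # identity hashing so nodes can be dict keys
class Node:
    """Hypothetical stand-in for a graph node (not a TensorFlow type)."""
    name: str
    inputs: Tuple["Node", ...] = ()


def toy_copy(node: Node, dict_swap: dict, scope: str,
             registry: Dict[str, Node]) -> Node:
    """Recursively copy `node`, reusing any copy already made in `scope`."""
    if node in dict_swap:
        return dict_swap[node]
    new_name = f"{scope}/{node.name}"
    if new_name in registry:  # already copied within this scope: reuse it
        return registry[new_name]
    new_inputs = tuple(toy_copy(p, dict_swap, scope, registry)
                       for p in node.inputs)
    new_node = Node(new_name, new_inputs)
    registry[new_name] = new_node
    return new_node


# Diamond: `a` feeds both `b` and `d`, which both feed `f`.
a, c, e = Node("a"), Node("c"), Node("e")
b, d = Node("b", (a, c)), Node("d", (a, e))
f = Node("f", (b, d))
qc = Node("qc")

registry: Dict[str, Node] = {}
f_new = toy_copy(f, {c: qc}, "copied", registry)
b_new, d_new = f_new.inputs
print(b_new.inputs[0] is d_new.inputs[0])  # True: `a` was copied only once
```

Calling `toy_copy` again with the same `registry` returns the existing `copied/f` instead of building a second copy.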
`tf.Variable`s, `tf.placeholder`s, and nodes of type `Queue` are always reused and not copied. In addition, `tf.Operation`s with operation-level seeds are copied with a new operation-level seed.
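The reuse and reseeding rules can be modeled in a few lines. In this hypothetical toy model, each node carries a `kind` string and an optional `seed`. Nodes whose kind is `"variable"`, `"placeholder"`, or `"queue"` are returned unchanged, and seeded nodes receive a fresh seed. Drawing that seed from an `itertools.count` counter is purely an assumption for illustration; it is not how Edward or TensorFlow choose seeds.

```python
import itertools
from dataclasses import dataclass
from typing import Optional, Tuple

REUSED_KINDS = {"variable", "placeholder", "queue"}  # never copied
_fresh_seeds = itertools.count(1000)  # hypothetical seed source for this sketch


@dataclass(eq=False)
class Node:
    """Hypothetical stand-in for a graph node (not a TensorFlow type)."""
    name: str
    kind: str
    inputs: Tuple["Node", ...] = ()
    seed: Optional[int] = None  # operation-level seed, if any


def toy_copy(node: Node, dict_swap: dict, scope: str = "copied") -> Node:
    if node in dict_swap:
        return dict_swap[node]
    if node.kind in REUSED_KINDS:
        return node  # shared by the original and the copied subgraph
    new_inputs = tuple(toy_copy(p, dict_swap, scope) for p in node.inputs)
    # Seeded ops get a new seed so the copy draws different random numbers.
    new_seed = next(_fresh_seeds) if node.seed is not None else None
    return Node(f"{scope}/{node.name}", node.kind, new_inputs, new_seed)


w = Node("w", "variable")
data = Node("data", "placeholder")
noise = Node("noise", "random_normal", seed=42)
out = Node("out", "add", (Node("xw", "matmul", (data, w)), noise))

out_new = toy_copy(out, {})
xw_new, noise_new = out_new.inputs
print(xw_new.inputs[0] is data, xw_new.inputs[1] is w)  # True True
print(noise.seed, noise_new.seed)                        # 42 1000
```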
Args:
- `org_instance`: RandomVariable, tf.Operation, tf.Tensor, or tf.Variable. Node to add to the graph with replaced ancestors.
- `dict_swap`: dict. Random variables, variables, tensors, or operations to swap with. Its keys are what `org_instance` may depend on, and its values are the corresponding objects used in exchange (not necessarily instances of the same class, but they must have the same type, e.g., `float32`).
- `scope`: str. A scope for the new node(s). This is used to avoid name conflicts with the original node(s).
- `replace_itself`: bool. Whether to replace `org_instance` itself if it exists in `dict_swap`. (This is used for the recursion.)
- `copy_q`: bool. Whether to also copy the replacement tensors, i.e., the values of `dict_swap` (if they are not already copied within the new scope). Otherwise, they are reused.
- `copy_parent_rvs`: bool. Whether to copy the parent random variables that `org_instance` depends on. Otherwise, only the sample tensors are copied, not the random variable class itself.
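How `replace_itself` and `copy_q` interact can be seen in a toy model. The sketch below is hypothetical and uses a minimal `Node` class instead of TensorFlow objects. It follows the documented behavior: the top-level call swaps `org_instance` itself only when `replace_itself=True`, recursive calls on ancestors pass `replace_itself=True` so that ancestors can be swapped, and `copy_q=True` copies the replacement into the new scope instead of reusing it. `copy_parent_rvs` is omitted because the toy has no random-variable class, and the private `_registry` argument is an invention of this sketch.

```python
from dataclasses import dataclass
from typing import Dict, Optional, Tuple


@dataclass(eq=False)
class Node:
    """Hypothetical stand-in for a graph node (not a TensorFlow type)."""
    name: str
    inputs: Tuple["Node", ...] = ()


def toy_copy(node: Node, dict_swap: Optional[dict] = None, scope: str = "copied",
             replace_itself: bool = False, copy_q: bool = False,
             _registry: Optional[Dict[str, Node]] = None) -> Node:
    dict_swap = {} if dict_swap is None else dict_swap
    _registry = {} if _registry is None else _registry
    if replace_itself and node in dict_swap:
        node = dict_swap[node]
        if not copy_q:
            return node  # reuse the replacement as-is
    new_name = f"{scope}/{node.name}"
    if new_name in _registry:  # already copied within this scope
        return _registry[new_name]
    # Ancestors are always eligible for swapping, hence replace_itself=True.
    new_inputs = tuple(toy_copy(p, dict_swap, scope, True, copy_q, _registry)
                       for p in node.inputs)
    _registry[new_name] = Node(new_name, new_inputs)
    return _registry[new_name]


x, y, qx = Node("x"), Node("y"), Node("qx")
z = Node("z", (x, y))
w = Node("w")

print(toy_copy(z, {z: w}).name)                          # copied/z
print(toy_copy(z, {z: w}, replace_itself=True).name)     # w
print(toy_copy(z, {x: qx}).inputs[0] is qx)              # True
print(toy_copy(z, {x: qx}, copy_q=True).inputs[0].name)  # copied/qx
```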
Returns:
RandomVariable, tf.Variable, tf.Tensor, or tf.Operation. The copied node.
Raises:
TypeError. If `org_instance` is not one of the above types.
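The `TypeError` contract amounts to a guard clause at the start of the function. A minimal sketch: the four placeholder classes are hypothetical stand-ins for `RandomVariable`, `tf.Operation`, `tf.Tensor`, and `tf.Variable` so that the code runs without Edward or TensorFlow, and the error message wording is also an assumption.

```python
class RandomVariable: ...  # hypothetical stand-ins, not the real classes
class Operation: ...
class Tensor: ...
class Variable: ...


ACCEPTED_TYPES = (RandomVariable, Operation, Tensor, Variable)


def check_copy_input(org_instance):
    """Raise TypeError unless org_instance is one of the accepted node types."""
    if not isinstance(org_instance, ACCEPTED_TYPES):
        raise TypeError(
            f"Could not copy instance: {type(org_instance).__name__!r} is not "
            "a RandomVariable, Operation, Tensor, or Variable.")
    return org_instance


check_copy_input(Tensor())  # accepted
try:
    check_copy_input(3.0)  # a Python float is not a graph node
except TypeError as err:
    print("TypeError:", err)
```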
Examples
```python
import tensorflow as tf  # Edward targets the TensorFlow 1.x graph API
import edward as ed

x = tf.constant(2.0)
y = tf.constant(3.0)
z = x * y

qx = tf.constant(4.0)
# The TensorFlow graph is currently
# `x` -> `z` <- `y`, `qx`

# This adds a subgraph with newly copied nodes,
# `qx` -> `copied/z` <- `copied/y`
z_new = ed.copy(z, {x: qx})

sess = tf.Session()
print(sess.run(z))      # 6.0
print(sess.run(z_new))  # 12.0
```