当前位置: 首页 > 面试题库 >

如何为由tf操作组成的操作注册自定义渐变

公孙霖
2023-03-14
问题内容

更具体地说,我有一个简单的fprop,它是tf操作的组成部分。我想使用RegisterGradient用我自己的梯度方法覆盖tensorflow梯度计算。

此代码有什么问题?

import tensorflow as tf
from tensorflow.python.framework import ops

@ops.RegisterGradient("MyopGrad")
def frop_grad(op, grad):
    x = op.inputs[0]
    return 0 * x  # zero out to see the difference:

def fprop(x):
    x = tf.sqrt(x)
    out = tf.maximum(x, .2)
    return out

a = tf.Variable(tf.constant([5., 4., 3., 2., 1.], dtype=tf.float32))
h = fprop(a)
h = tf.identity(h, name="Myop")
grad = tf.gradients(h, a)

g = tf.get_default_graph()
with g.gradient_override_map({'Myop': 'MyopGrad'}):
    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())
        result = sess.run(grad)

print(result[0])

我想查看打印中的所有零,但我得到的是:

[ 0.2236068   0.25000003  0.28867513  0.35355341  0.5       ]

问题答案:

您需要在以下范围内定义操作 with g.gradient_override_map({'Myop': 'MyopGrad'})

另外,您需要映射Identity而不是名称Myop到新渐变。

这是完整的代码:

import tensorflow as tf
from tensorflow.python.framework import ops

@ops.RegisterGradient("MyopGrad")
def frop_grad(op, grad):
    x = op.inputs[0]
    return 0 * x  # zero out to see the difference:

def fprop(x):
    x = tf.sqrt(x)
    out = tf.maximum(x, .2)
    return out

a = tf.Variable(tf.constant([5., 4., 3., 2., 1.], dtype=tf.float32))
h = fprop(a)

g = tf.get_default_graph()
with g.gradient_override_map({'Identity': 'MyopGrad'}):
    h = tf.identity(h, name="Myop")
    grad = tf.gradients(h, a)

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    result = sess.run(grad)

print(result[0])

输出:

[ 0.  0.  0.  0.  0.]


 类似资料:
  • TensorFlow GraphDef based models (typically created via the Python API) may be saved in one of following formats: TensorFlow SavedModel Frozen Model Session Bundle Tensorflow Hub module All of above f

  • 本文向大家介绍使用Spring组合自定义的注释 mscharhag操作,包括了使用Spring组合自定义的注释 mscharhag操作的使用技巧和注意事项,需要的朋友参考一下 在本文中,我们将介绍一个非常有用的Spring功能,该功能允许我们基于一个或多个Spring注释创建自己的注释。 假设我们有一组经常一起使用的Spring注释。一个常见的示例是@Service和@Transactional的

  • 问题内容: 我一直在一些存储库中使用自定义操作。到目前为止,我只需要指定url和方法。 例如: 但是随后,我不得不编写一个自定义操作,该操作不包含一个,而是两个路径参数: 所以我首先将其编码为: 但这是行不通的。参数未传递。 经过几次尝试,我发现在自定义操作定义之前添加一些参数定义可以正常工作。 它必须像: 请注意以下情况: 当时我的理解是,在$ resource定义中,具有多个路径参数的自定义操

  • 本文向大家介绍JavaScript jQuery 中定义数组与操作及jquery数组操作,包括了JavaScript jQuery 中定义数组与操作及jquery数组操作的使用技巧和注意事项,需要的朋友参考一下 首先给大家介绍javascript jquery中定义数组与操作的相关知识,具体内容如下所示: 1.认识数组 数组就是某类数据的集合,数据类型可以是整型、字符串、甚至是对象 Javascr

  • 问题内容: 以下工作正常,但是我认为这会全局修改$ httpProvider,这不是我想要的。 反正有这样做吗? “标题”参数似乎被忽略了。请求仍然 我的标头值可以吗? 问题答案: 我已经确认1.1.3确实支持这一点。但是,您需要确保您还获得了资源服务的1.1.3版本。快速测试: 这将发出一个标头设置为(使用Chrome确认)的请求: 快速说明,我无法找到angular- resource.js的

  • 操作组合注记 除了send()和sendAsync之外,所有JSON-RPC方法在web3j中都实现了支持observable()方法来创建可观察的异步执行请求。这使得将JSON-RPC调用组合成新的函数是非常容易和直接的。 例如, blockObservable本身由许多单独的JSON-RPC调用组成: public Observable<EthBlock> blockObservable(