当前位置: 首页 > 知识库问答 >
问题:

为什么我不能使用TensorArray.gather()在@tf.function?

尹昀
2023-03-14

从TensorArray读取:

def __init__(self, size):
    self.obs_buf = tf.TensorArray(tf.float32, size=size, clear_after_read=False)
    self.obs2_buf = tf.TensorArray(tf.float32, size=size, clear_after_read=False)
    self.act_buf = tf.TensorArray(tf.float32, size=size, clear_after_read=False)
    self.rew_buf = tf.TensorArray(tf.float32, size=size, clear_after_read=False)
    self.done_buf = tf.TensorArray(tf.float32, size=size, clear_after_read=False)

def get_sample(self, batch_size):
        idxs = tf.random.uniform(shape=[batch_size], maxval=self.size, dtype=tf.int32)
        tf.print(idxs)
        return self.obs_buf.gather(indices=idxs),     \     # HERE IS THE ISSUE
               self.act_buf.gather(indices=idxs),     \
               self.rew_buf.gather(indices=idxs),     \
               self.obs2_buf.gather(indices=idxs),    \
               self.done_buf.gather(indices=idxs)

使用:

@tf.function
def train(self, rpm, batch_size, gradient_steps):
    for gradient_step in tf.range(1, gradient_steps + 1):
        obs, act, rew, next_obs, done = rpm.get_sample(batch_size)

        with tf.GradientTape() as tape:
        ...

问题:

回溯(最近一次调用last):RLU培训中第130行的文件“\main.py”。train()文件“C:\Users\user\Documents\Projects\rl toolkit\rl_training.py”,第129行,在train self中_rpm,赛尔夫。批量大小,自行确定。梯度步数,记录步数b=self。在call result=self中记录文件“C:\Users\user\AppData\Local\Programs\Python\36\lib\site packages\tensorflow\Python\eager\def\u function.py”,第828行_调用(*args,**kwds)文件“C:\Users\user\AppData\Local\Programs\Python\Python36\lib\site packages\tensorflow\Python\eager\def_function.py”,第871行,在调用self中_初始化(args,kwds,add_initializers_to=initializers)文件“C:\Users\user\AppData\Local\Programs\Python36\lib\site packages\tensorflow\Python\eager\def\function.py”,第726行,在_initialize*args,**kwds)文件“C:\Users\user\AppData\Local\Programs\Python\Python36\lib\site packages\tensorflow\Python\eager\function.py”,第2969行,在"获取"具体"函数"内部"垃圾"收集图"函数中,"自我"_可能定义函数(args,kwargs)文件“C:\Users\user\AppData\Local\Programs\Python36\lib\site packages\tensorflow\Python\eager\function.py”,第3361行,在可能定义函数图\u function=self中_创建图函数(args,kwargs)文件“C:\Users\user\AppData\Local\Programs\Python\Python36\lib\site packages\tensorflow\Python\eager\function.py”,第3206行,在创建图函数捕获中,按值=self_按值捕获),文件“C:\Users\user\AppData\Local\Programs\Python\36\lib\site packages\tensorflow\Python\framework\func\graph.py”,第990行,func-graph\u-from\u-py-func-func\u-outputs=Python\u-func(*-func参数,**func\u-kwargs)文件“C:\Users\user\AppData\Local\Programs\Python\Python\Python36\lib\site packages\site-function\tensorflow\Python\Python\def\def\func\func\fu,第634行,in-wrapped_fn out=弱_wrapped_fn()。wrapped(*args,**kwds)文件“C:\Users\user\AppData\Local\Programs\Python36\lib\site packages\tensorflow\Python\eager\function.py”,第3887行,在绑定方法包装返回wrapped\u fn(*args,**kwargs)文件“C:\Users\user\AppData\Local\Programs\Python36\lib\site packages\tensorflow\Python\framework\func\graph.py”,第977行,在包装器中生成e.ag\u错误\u元数据。例外情况(e)tensorflow。python框架错误。OperatorNotAllowedInGraphError:在用户代码中:

C:\Users\user\Documents\Projects\rl-toolkit\policy\sac\sac.py:183 update  *
    obs, act, rew, next_obs, done = rpm.get_sample(batch_size)
C:\Users\user\Documents\Projects\rl-toolkit\utils\replay_buffer.py:39 __call__  *
    return self.obs_buf.gather(indices=idxs),                    self.act_buf.gather(indices=idxs),                    self.rew_buf.gather(indices=idxs),                    self.obs2_buf.gather(indices=idxs),                   self.done_buf.gather(indices=idxs)
C:\Users\user\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\ops\tensor_array_ops.py:1190 gather  **
    return self._implementation.gather(indices, name=name)
C:\Users\user\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\ops\tensor_array_ops.py:861 gather
    return array_ops.stack([self._maybe_zero(i) for i in indices])
C:\Users\user\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\framework\ops.py:505 __iter__
    self._disallow_iteration()
C:\Users\user\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\framework\ops.py:498 _disallow_iteration
    self._disallow_when_autograph_enabled("iterating over `tf.Tensor`")
C:\Users\user\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\framework\ops.py:476 _disallow_when_autograph_enabled
    " indicate you are trying to use an unsupported feature.".format(task))

OperatorNotAllowedInGraphError: iterating over `tf.Tensor` is not allowed: AutoGraph did convert this function. This might indicate you are trying to use an unsupported feature.

为什么我不能在这种情况下使用Tensorary?我还有什么选择?

共有1个答案

公冶智刚
2023-03-14

在这里解决了。必须使用tf。变量而不是tf。TensorArray.

 类似资料:
  • 我正在尝试使用文件系统。我的< code>CMakeLists.txt中有< code>-std=c 11 -std=c 1y。GCC版本为4.9.2。然而,我得到了一个错误: 使用的正确方法是什么?

  • 我试图使用Java8Javadoc工具,但它抱怨是一个未知标记: 我看到有一些方法可以禁用doclint,但我真的想知道哪些标签列表被支持(或者为什么这个不支持)。 更多信息在这个问题,这个问题和从这个博文。

  • 问题内容: 如果html文件是本地文件(在我的C驱动器上),则可以使用,但是如果html文件在服务器上并且图像文件是本地文件,则无法使用。这是为什么? 任何可能的解决方法? 问题答案: 如果客户端可以请求本地文件系统文件,然后使用JavaScript找出其中的内容,则将是一个安全漏洞。 解决此问题的唯一方法是在浏览器中构建扩展。Firefox扩展和IE扩展可以访问本地资源。Chrome的限制更为严

  • 我想在spring boot项目中使用“truncate table”语句而不是“delete”语句,因为我需要重置mysql中的自动增量id。这是我的代码: 但有一个例外,就像这样: 其他操作已经工作,如插入、更新或选择,原因是什么,我应该修改什么?

  • 我无法理解为什么以下操作不起作用? 我一直收到错误: 我尝试使用通用参数: 我好像想不出来。 我完全错过了什么? 我正在使用OpenJDK 11。

  • 我试图隐藏Actionbar并改用工具栏,但如果我更改Theme.appcompat.light.NoActionBar,它确实会隐藏,但应用程序在行setContentView(r.layout.activity_main)处崩溃; 尽管它确实使用getSupportActionBar()隐藏。hide();但我不能用这个代码 Toolbar Toolbar=(Toolbar)findViewB