tensorflow GradientTape

督建柏
2023-12-01

tensorflow 自动求导
官方API

x = tf.constant(3.0)
with tf.GradientTape() as g:
  g.watch(x)
  y = x * x
dy_dx = g.gradient(y, x) # Will compute to 6.0

上面是是对方程 y = x 2 y=x^2 y=x2的一阶求导,即 ( y ′ = 2 x ∣ x = 3.0 ) (y' = 2x|_{x=3.0}) (y=2xx=3.0),故最后得到解为6.0。

x = tf.constant(3.0)
with tf.GradientTape() as g:
  g.watch(x)
  with tf.GradientTape() as gg:
    gg.watch(x)
    y = x * x
  dy_dx = gg.gradient(y, x)     # Will compute to 6.0
d2y_dx2 = g.gradient(dy_dx, x)  # Will compute to 2.0

上面包括一阶求导和二阶求导。

def f(x, y):
  output = 1.0
  for i in range(y):
    if i > 1 and i < 5:
      output = tf.multiply(output, x)
  return output

def grad(x, y):
  with tf.GradientTape() as t:
    t.watch(x)
    out = f(x, y)
  return t.gradient(out, x) 

x = tf.convert_to_tensor(2.0)

assert grad(x, 6).numpy() == 12.0
assert grad(x, 5).numpy() == 12.0
assert grad(x, 4).numpy() == 4.0

在控制流中的操作。

 类似资料:

相关阅读

相关文章

相关问答