有时，需要根据 label 对 loss 做 mask（只累加 label 为 1 的样本的 loss），示例如下：
#!/usr/bin/env python
# coding=utf-8
"""
tf version: 1.15.0
"""
import tensorflow as tf
# Demo: mask per-sample losses with their labels so that only samples whose
# label is 1.0 contribute to each summed loss.

# Every label / loss tensor below has shape [batch_size, 1] (batch_size == 3).
label1 = tf.constant([[0.0],
                      [1.0],
                      [1.0]])
label2 = tf.constant([[1.0],
                      [0.0],
                      [1.0]])
loss1 = tf.constant([[0.1],
                     [0.2],
                     [0.3]])
loss2 = tf.constant([[0.4],
                     [0.5],
                     [0.6]])

# Stacking on a new axis 1 gives shape [batch_size, 2, 1]; index 0 / 1 on
# that axis selects the first / second loss.
stack_loss = tf.stack([loss1, loss2], axis=1)

# Alternative mask: use logical_or, i.e. keep a sample when
# [label1 or label2] is set:
#   label_or = tf.cast(tf.logical_or(tf.cast(label1, dtype=tf.bool), tf.cast(label2, dtype=tf.bool)), dtype=tf.float32)
#   stack_label = tf.stack([label_or, label_or], axis=1)
# Final result with that mask:
#   ('loss_1: ', 0.6)  # 1.0 * 0.1 + 1.0 * 0.2 + 1.0 * 0.3 = 0.6
#   ('loss_2: ', 1.5)  # 1.0 * 0.4 + 1.0 * 0.5 + 1.0 * 0.6 = 1.5

# Mask used here: each loss is masked by its own label.
stack_label = tf.stack([label1, label2], axis=1)

# Element-wise product zeroes out the losses whose label is 0.0.
losses_masked = tf.multiply(stack_loss, stack_label)
loss_1 = tf.reduce_sum(losses_masked[:, 0])
loss_2 = tf.reduce_sum(losses_masked[:, 1])

# Use the session as a context manager so it is closed (and its resources
# released) even if evaluation raises; fetch both losses in a single run.
with tf.Session() as sess:
    loss_1_value, loss_2_value = sess.run([loss_1, loss_2])

print("loss_1: ", loss_1_value)  # 1.0 * 0.2 + 1.0 * 0.3 = 0.5
print("loss_2: ", loss_2_value)  # 1.0 * 0.4 + 1.0 * 0.6 = 1.0
输出（Python 2 下 print 会按元组形式打印）：
('loss_1: ', 0.5)
('loss_2: ', 1.0)