当前位置: 首页 > 工具软件 > F.I.S > 使用案例 >

一文搞懂F.binary_cross_entropy以及weight参数

景安翔
2023-12-01

相信有很多人在用pytorch做深度学习的时候,可能只是知道模型中用的是F.binary_cross_entropy或者F.cross_entropy,但是从来没有想过这两者的区别,即使知道这两者是分别在什么情况下使用的,也没有想过它们在pytorch中是如何具体实现的。在另一篇文章中介绍了F.cross_entropy()的具体实现,所以本文将介绍F.binary_cross_entropy的具体实现。
当你分别了解了它们在pytorch中的具体实现,也就自然知道它们的区别以及应用场景了。

1、pytorch对BCELoss的官方解释
在自己实现F.binary_cross_entropy之前,我们首先得看一下pytorch的官方实现,下面是pytorch官方对BCELoss类的描述:
在目标和输出之间创建一个衡量二进制交叉熵的标准。the unreduced loss(如:reduction属性被设置为none) 的数学表达式为:
l n = − w n [ y n ⋅ log ⁡ x n + ( 1 − y n ) ⋅ log ⁡ ( 1 − x n ) ] \quad l_n = - w_n \left[ y_n \cdot \log x_n + (1 - y_n) \cdot \log (1 - x_n) \right] ln=wn[ynlogxn+(1yn)log(1xn)]
其中,N表示batch size,如果reduction is not none(reduction的默认是‘mean’)时的表达式为:
ℓ ( x , y ) = { mean ⁡ ( L ) , if reduction = ’mean’; sum ⁡ ( L ) , if reduction = ’sum’. \ell(x, y) = \begin{cases} \operatorname{mean}(L), & \text{if reduction} = \text{'mean';}\\ \operatorname{sum}(L), & \text{if reduction} = \text{'sum'.} \end{cases} (x,y)={mean(L),sum(L),if reduction=’mean’;if reduction=’sum’.


补充:targets也就是表达式中的y应该是0-1之间的数,Xn不能为0或1,如果Xn是0或者1,也就意味着log(Xn)或者log(1-Xn)中的一项没有意义,pytorch中对log(0)作出的定义如下,也是数学上对log(0)的定义:
log ⁡ ( 0 ) = − ∞ , lim ⁡ x → 0 log ⁡ ( x ) = − ∞ \log (0) = -\infty,\lim_{x\to 0} \log (x) = -\infty log(0)=x0limlog(x)=
然而,由于一些原因,无穷项在在损失函数中无法表述。举个例子:如果Yn=0或者1-Yn=0,我们就会用0乘上无穷。而且如果我们有一个无穷的损失值,我们在计算梯度的时候也会是一个无穷,也是因为数学上的定义:
lim ⁡ x → 0 d d x log ⁡ ( x ) = ∞ \lim_{x\to 0} \frac{d}{dx} \log (x) = \infty x0limdxdlog(x)=
而且会导致BECLoss的反向传播方法非线性。对于上述可能会出现的问题,pytorch官方给出的解决方案是限制log函数的输出大于等于-100,这样的话就可以得到一个有限的损失值,以及线性的反向传播方法。下面写个代码测试一下pytorch限制log函数输出的机制:

print(np.log(1e-50))
input = torch.tensor([1e-50])
target = torch.tensor([1.0])
print(F.binary_cross_entropy(input, target))
print(torch.log(input))
# 输出
-115.12925464970229
tensor(100.)
tensor([-inf])

首先,我们取一个数让其log运算后的值小于-100,发现F.binary_cross_entropy中的计算结果为100,而torch.log()的计算结果为负无穷,原因在于pytorch官方实现的F.binary_cross_entropy对log输出做了限制。大家不要对100感到疑惑呀,为什么不是-100,那是因为损失函数计算的时候前面有个负号。


2、pytorch的官方实现
input的维度(N,*),其中*表示可以是任何维度。target和input的维度需一致。OK,其实最关键的还是上面的数学表达式,知道了表达式也就可以简单实现二值交叉熵了。

input = torch.rand(1, 3, 3)
target = torch.rand(1, 3, 3).random_(2)
print(input)
print(target)
input = torch.sigmoid(input)
output = torch.nn.functional.binary_cross_entropy(input, target)
print(output)

输出:

# input
tensor([[[0.7266, 0.9478, 0.3987],
         [0.4134, 0.1654, 0.0298],
         [0.1266, 0.1153, 0.0549]]])
# target
tensor([[[0., 1., 1.],
         [1., 0., 0.],
         [0., 0., 0.]]])
# output
tensor(0.6877)

3、根据公式自己实现

class binary_ce_loss(torch.nn.Module):
    def __init__(self):
        super(binary_ce_loss, self).__init__()

    def forward(self, input, target):
        input = input.view(input.shape[0], -1)
        target = target.view(target.shape[0], -1)
        loss = 0.0
        for i in range(input.shape[0]):
            for j in range(input.shape[1]):
            	loss += -(target[i][j] * torch.log(input[i][j]) + (1 - target[i][j]) * torch.log(1 - input[i][j]))
        return loss/(input.shape[0]*input.shape[1]) # 默认取均值

input和target的维度需相同,上述的例子中,它们的维度均是[1,3,3],我们可以把1看作batchsize的大小,3*3看作是图片的大小。首先将shape变成[1,3*3],然后按照公式计算每一个batchsize的损失,再求和,最后按照pytorch官方默认的方式求平均,即可大功告成。
4、weight参数含义
在写代码的过程中,我们会发现F.binary_cross_entropy中还有一个参数weight,它的默认值是None,估计很多人不知道weight参数怎么作用的,下面简单的分析一下:
首先,看一下pytorch官方对weight给出的解释,if provided it’s repeated to match input tensor shape,就是给出weight参数后,会将其shape和input的shape相匹配。回忆公式:
l n = − w n [ y n ⋅ log ⁡ x n + ( 1 − y n ) ⋅ log ⁡ ( 1 − x n ) ] \quad l_n = - w_n \left[ y_n \cdot \log x_n + (1 - y_n) \cdot \log (1 - x_n) \right] ln=wn[ynlogxn+(1yn)log(1xn)]
默认情况,也就是weight=None时,上述公式中的Wn=1;当weight!=None时,也就意味着我们需要为每一个样本赋予权重Wi,这样weight的shape和input一致就很好理解了。
首先看pytorch中weight参数作用后的结果:

input = torch.rand(3, 3)  
target = torch.rand(3, 3).random_(2)
print(input)
print(target)
w = [0.1, 0.9] # 标签0和标签1的权重
weight = torch.zeros(target.shape)  # 权重矩阵
for i in range(target.shape[0]):
    for j in range(target.shape[1]):
        weight[i][j] = w[int(target[i][j])]
print(weight)
loss = F.binary_cross_entropy(input, target, weight=weight)
print(loss)
"""
# input
tensor([[0.1531, 0.3302, 0.7537],
        [0.2200, 0.6875, 0.2268],
        [0.5109, 0.5873, 0.9275]])
# target
tensor([[1., 0., 0.],
        [0., 0., 1.],
        [0., 1., 0.]])
# weight
tensor([[0.9000, 0.1000, 0.1000],
        [0.1000, 0.1000, 0.9000],
        [0.1000, 0.9000, 0.1000]])
# loss
tensor(0.4621)
"""

通过下面的代码再次验证weight是如何作用的,weight就是为每一个样本加权

class binary_ce_loss(torch.nn.Module):
    def __init__(self):
        super(binary_ce_loss, self).__init__()

    def forward(self, input, target, weight=None):
        input = input.view(input.shape[0], -1)
        target = target.view(target.shape[0], -1)
        loss = 0.0
        for i in range(input.shape[0]):
            for j in range(input.shape[1]):
        	    loss += -weight[i][j] * (target[i][j] * torch.log(input[i][j]) + (1 - target[i][j]) * torch.log(1 - input[i][j]))
        return loss/(input.shape[0]*input.shape[1]) # 默认取均值
myloss = binary_ce_loss()
print(myloss(input, target, weight=weight))
"""
# myloss
tensor(0.4621)
"""

pytorch官方的代码和自己实现的计算出的损失一致,再次说明binary_cross_entropy的weight权重会分别对应的作用在每一个样本上。
5、总结
看源码是最直接有效的手段。 留个彩蛋,下篇文章讲balanced_cross_entropy,解决样本之间的不平衡问题。

注:如有错误还请指出!

 类似资料: