当前位置: 首页 > 工具软件 > binary_log > 使用案例 >

【pytorch】Softmax,LogSoftmax,CrossEntropyLoss,NLLLoss,F.cross_entropy, F.binary_cross_entropy傻傻分不清楚?

牟波
2023-12-01

一句话:

Softmax 后接 CrossEntropyLoss,
LogSoftmax 后接 NLLLoss
F.cross_entropy 内含 Softmax
F.binary_cross_entropy 不含 Softmax

理由

Softmax 之后,得到预测概率分布 q i \red{q_i} qi,根据交叉熵公式可计算得到和真实分布 p i \blue{p_i} pi 之间的损失:
L C E ( p , q ) = − ∑ i p i log ⁡ q i L_{CE}(\blue{p},\red{q})=-\sum_i{\blue{p_i} \log{\red{q_i}}} LCE(p,q)=ipilogqi

而 LogSoftmax 之后,得到预测概率分布的对数 log ⁡ q i \red{\log{q_i}} logqi,负对数似然损失就是将两个分布按位相乘取反,当我们输入的是 log ⁡ q i \red{\log{q_i}} logqi 时,得到:
L N L ( p , log ⁡ q ) = − ∑ i p i log ⁡ q i L_{NL}(\blue{p},\red{\log{q}})=-\sum_i{\blue{p_i} \red{\log{q_i}}} LNL(p,logq)=ipilogqi

可以看到,这两种搭配方法计算的结果是一致的。


Softmax 后接 CrossEntropyLoss,LogSoftmax 后接 NLLLoss 的区别

使用 LogSoftmax + NLLLoss 的优点有:

数值稳定性

在输入数值较大时,使用 Softmax 会导致溢出,而 LogSoftmax 则不会。

加速模型训练

对数运算时求导更容易,加快了反向传播的速度。

在实际应用中,如果发现模型训练过程中出现数值稳定性问题,或者需要加速训练过程,使用 LogSoftmax + NLLLoss 可能是一种较优的选择。

具体的解决溢出的原理:https://zhuanlan.zhihu.com/p/570141358


F.cross_entropy 和 F.binary_cross_entropy

F.cross_entropy 和 F.binary_cross_entropy 都是 PyTorch 中的交叉熵损失函数。

F.cross_entropy 用于多分类问题,它计算了预测值与真实值之间的交叉熵。

F.binary_cross_entropy 则用于二分类问题,它计算了预测值与真实值之间的二进制交叉熵。

如果输入的第一个参数是介于 0,1 之间的概率值****(经过softmax),建议使用 F.binary_cross_entropy

如果第一个参数没有经过softmax,也就是说这是一个未经过归一化的概率分布,那么还是建议使用 F.cross_entropy,这个函数会在内部进行 softmax 归一化,将未归一化的概率分布转换为归一化的概率分布。

如果使用F.binary_cross_entropy,那么需要在输入之前进行 softmax 归一化, 以便能够计算二进制交叉熵。

 类似资料: