First, let's look at the signature of binary_cross_entropy:

```python
torch.nn.functional.binary_cross_entropy(input, target, weight=None, size_average=None, reduce=None, reduction='mean')
```
Parameters

- size_average (bool, optional) – Deprecated (see reduction). By default, the losses are averaged over each loss element in the batch. Note that for some losses, there are multiple elements per sample. If the field size_average is set to False, the losses are instead summed for each minibatch. Ignored when reduce is False. Default: True
- reduce (bool, optional) – Deprecated (see reduction). By default, the losses are averaged or summed over observations for each minibatch depending on size_average. When reduce is False, returns a loss per batch element instead and ignores size_average. Default: True
- reduction (str, optional) – Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'. 'none': no reduction will be applied, 'mean': the sum of the output will be divided by the number of elements in the output, 'sum': the output will be summed. Note: size_average and reduce are in the process of being deprecated, and in the meantime, specifying either of those two args will override reduction. Default: 'mean'
The size_average and reduce arguments are in the process of being deprecated; just leave them at their defaults and control the behavior with reduction.
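When the legacy arguments are passed anyway, PyTorch maps them onto a reduction string (roughly: size_average=False with reduce left True behaves like 'sum', and reduce=False like 'none'). A minimal sketch of that mapping, assuming a PyTorch version that still accepts the deprecated arguments (it emits a deprecation warning):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
probs = torch.sigmoid(torch.randn(2, 2))  # probabilities in (0, 1)
target = torch.rand(2, 2)                 # targets in [0, 1)

# size_average=False (with reduce at its default) maps to reduction='sum';
# PyTorch prints a UserWarning recommending reduction='sum' instead
legacy_sum = F.binary_cross_entropy(probs, target, size_average=False)
assert torch.allclose(legacy_sum,
                      F.binary_cross_entropy(probs, target, reduction='sum'))
```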
When reduction='none', the binary cross-entropy loss is:
$$\ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad l_n = - w_n \left[ y_n \cdot \log x_n + (1 - y_n) \cdot \log (1 - x_n) \right]$$
where N is the batch size and w_n is the optional per-element weight (taken to be 1 when weight is None). When reduction is not 'none',
$$\ell(x, y) = \begin{cases} \operatorname{mean}(L), & \text{if reduction} = \text{'mean';}\\ \operatorname{sum}(L), & \text{if reduction} = \text{'sum'.} \end{cases}$$
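These reductions are easy to verify numerically: 'mean' and 'sum' simply reduce the element-wise losses that 'none' returns. A minimal sketch:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
probs = torch.sigmoid(torch.randn(2, 2))  # probabilities in (0, 1)
target = torch.rand(2, 2)                 # targets in [0, 1)

per_element = F.binary_cross_entropy(probs, target, reduction='none')  # shape (2, 2)
mean_loss = F.binary_cross_entropy(probs, target, reduction='mean')
sum_loss = F.binary_cross_entropy(probs, target, reduction='sum')

# 'mean' and 'sum' are just reductions over the element-wise losses
assert torch.allclose(mean_loss, per_element.mean())
assert torch.allclose(sum_loss, per_element.sum())
```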
The target should contain values between 0 and 1.
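Recent PyTorch versions enforce this range at runtime; a small sketch (the exact error message may differ by version):

```python
import torch
import torch.nn.functional as F

probs = torch.tensor([0.3, 0.7])
bad_target = torch.tensor([0.0, 1.5])  # 1.5 falls outside [0, 1]

try:
    F.binary_cross_entropy(probs, bad_target)
except RuntimeError as e:
    print(e)  # complains that all elements of target should be between 0 and 1
```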
Example 1:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
input = torch.randn((2, 2), requires_grad=True)
target = torch.rand((2, 2), requires_grad=False)
# binary_cross_entropy expects probabilities, so apply a sigmoid first
# (F.sigmoid is deprecated; torch.sigmoid is the current spelling)
input = torch.sigmoid(input)

# loss: 0.8580
loss = F.binary_cross_entropy(input, target)
# manual computation of the same mean-reduced loss: 0.8580
loss_compare = torch.sum(-(target * torch.log(input) + (1 - target) * torch.log(1 - input))) / 4
```
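Since the sigmoid here is applied by hand, the same value can be obtained directly from the raw logits with F.binary_cross_entropy_with_logits, which fuses the sigmoid into the loss for better numerical stability. A minimal sketch:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(2, 2)
target = torch.rand(2, 2)

# binary_cross_entropy expects probabilities, so the sigmoid must be applied first;
# binary_cross_entropy_with_logits takes the raw logits and fuses the sigmoid
loss_bce = F.binary_cross_entropy(torch.sigmoid(logits), target)
loss_logits = F.binary_cross_entropy_with_logits(logits, target)
assert torch.allclose(loss_bce, loss_logits)
```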
Example 2 (using the weight argument):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
input = torch.randn((2, 2), requires_grad=True)
target = torch.rand((2, 2), requires_grad=False)
input = torch.sigmoid(input)
weight = torch.rand(2, 2)

# loss: 0.2031
# the weight tensor rescales each element's loss before the mean reduction
loss = F.binary_cross_entropy(input, target, weight=weight)
# manual computation with the same per-element weights: 0.2031
loss_compare = torch.sum(-weight * (target * torch.log(input) + (1 - target) * torch.log(1 - input))) / 4
```
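For completeness, the module form torch.nn.BCELoss wraps the same computation as the functional call, including the weight argument. A short sketch:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
probs = torch.sigmoid(torch.randn(2, 2))
target = torch.rand(2, 2)
weight = torch.rand(2, 2)

# the module form delegates to the same computation as the functional form
criterion = nn.BCELoss(weight=weight)
assert torch.allclose(criterion(probs, target),
                      F.binary_cross_entropy(probs, target, weight=weight))
```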