
Loss functions in PyTorch: binary_cross_entropy

曹茂材
2023-12-01

First, let's look at the signature of binary_cross_entropy:

torch.nn.functional.binary_cross_entropy(input, target, weight=None, size_average=None, reduce=None, reduction='mean')

Parameters

  • input – Tensor of arbitrary shape
  • target – Tensor of the same shape as input
  • weight (Tensor, optional) – a manual rescaling weight; if provided, it is repeated to match the input tensor shape
  • size_average (bool, optional) – Deprecated (see reduction). By default, the losses are averaged over each loss element in the batch. Note that for some losses, there are multiple elements per sample. If the field size_average is set to False, the losses are instead summed for each minibatch. Ignored when reduce is False. Default: True
  • reduce (bool, optional) – Deprecated (see reduction). By default, the losses are averaged or summed over observations for each minibatch depending on size_average. When reduce is False, returns a loss per batch element instead and ignores size_average. Default: True
  • reduction (string, optional) – Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'. 'none': no reduction will be applied, 'mean': the sum of the output will be divided by the number of elements in the output, 'sum': the output will be summed. Note: size_average and reduce are in the process of being deprecated, and in the meantime, specifying either of those two args will override reduction. Default: 'mean'

The size_average and reduce parameters are being deprecated; just leave them at their defaults and control the behaviour with reduction alone, as the sketch below shows.
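A minimal sketch of the three reduction modes and how they relate to each other (the prob and target tensors here are just random placeholders):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
prob = torch.rand(2, 2)    # predicted probabilities, already in (0, 1)
target = torch.rand(2, 2)  # targets must also lie in [0, 1]

per_element = F.binary_cross_entropy(prob, target, reduction='none')  # shape (2, 2)
mean_loss = F.binary_cross_entropy(prob, target, reduction='mean')    # scalar
sum_loss = F.binary_cross_entropy(prob, target, reduction='sum')      # scalar

# 'mean' and 'sum' are just reductions of the 'none' output
assert torch.allclose(per_element.mean(), mean_loss)
assert torch.allclose(per_element.sum(), sum_loss)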

Theory

With reduction='none', the binary cross-entropy loss for a batch is:

\ell(x, y) = L = \{l_1, \dots, l_N\}^\top, \quad l_n = -w_n \left[ y_n \cdot \log x_n + (1 - y_n) \cdot \log(1 - x_n) \right]

where N is the batch_size. When reduction is not 'none':

\ell(x, y) = \begin{cases} \operatorname{mean}(L), & \text{if reduction} = \text{'mean';} \\ \operatorname{sum}(L), & \text{if reduction} = \text{'sum'.} \end{cases}

Both the input and the target should contain values between 0 and 1 (which is why the examples below pass the raw input through a sigmoid first).
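As a quick check of the element-wise formula above (with w_n = 1), a minimal sketch comparing reduction='none' against the formula computed by hand:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.sigmoid(torch.randn(2, 2))  # predictions squashed into (0, 1)
y = torch.rand(2, 2)                  # targets in [0, 1)

l = F.binary_cross_entropy(x, y, reduction='none')
l_manual = -(y * torch.log(x) + (1 - y) * torch.log(1 - x))

# each l_n matches the formula above (w_n = 1 since no weight is passed)
assert torch.allclose(l, l_manual)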

Example 1:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
input = torch.randn((2, 2), requires_grad=True)
target = torch.rand((2, 2), requires_grad=False)
input = torch.sigmoid(input)  # squash predictions into (0, 1); F.sigmoid is deprecated
# loss 0.8580
loss = F.binary_cross_entropy(input, target)
# loss_compare 0.8580: the same value computed by hand and averaged over the 4 elements
loss_compare = torch.sum(-(target * torch.log(input) + (1 - target) * torch.log(1 - input))) / 4

Example 2 (with weight):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
input = torch.randn((2, 2), requires_grad=True)
target = torch.rand((2, 2), requires_grad=False)
input = torch.sigmoid(input)  # squash predictions into (0, 1)
weight = torch.rand(2, 2)
# loss 0.2031
loss = F.binary_cross_entropy(input, target, weight=weight)
# loss_compare 0.2031: each element-wise loss is rescaled by its weight before averaging
loss_compare = torch.sum(-weight * (target * torch.log(input) + (1 - target) * torch.log(1 - input))) / 4
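As a follow-up to Example 2, a minimal sketch showing that weight simply rescales each element-wise loss by w_n before the reduction is applied (the tensors are random placeholders):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
input = torch.sigmoid(torch.randn(2, 2))
target = torch.rand(2, 2)
weight = torch.rand(2, 2)

weighted = F.binary_cross_entropy(input, target, weight=weight, reduction='none')
manual = -weight * (target * torch.log(input) + (1 - target) * torch.log(1 - input))

# with reduction='none', each l_n is just w_n times the unweighted loss
assert torch.allclose(weighted, manual)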