当前位置: 首页 > 工具软件 > ​DICE > 使用案例 >

损失函数:DiceLoss与Dice系数

西门良才
2023-12-01

Dice系数

  Dice系数,是一种集合相似度度量函数,通常用于计算两个样本点的相似度(值范围为[0, 1])。用于分割问题,分割最好时为1,最差为0。及用于解决样本不均衡的问题,但不稳定,容易出现梯度爆炸(?)。
dice系数越大,DiceLoss(在Dice系数的基础上进行计算,用1去减Dice系数,即 D i c e L o s s = 1 − D i c e DiceLoss = 1 - Dice DiceLoss=1Dice)越小,表明样本集合越相似。

  1. dice系数计算公式:

其中pred为预测值的集合,true为真实值的集合,分子为pred和true之间的交集,乘以2是因为分母存在重复计算pred和true之间的共同元素的原因。分母为pred和true的并集。

D i c e = 2 ∗ ( p r e d ⋂ t r u e ) p r e d ⋃ t r u e = 2 ∗ i n t e r s e c t i o n u n i o n Dice = \frac{2*(pred \bigcap true)}{pred \bigcup true} = \frac {2*intersection} {union} Dice=predtrue2(predtrue)=union2intersection

即2倍交集除以并集,加smooth防止分母为0的情况。

  • intersection近似为pred与label之间的点乘,并将结果元素相加。
  • union有人直接用简单的相加近似代替,也有用平方求和来近似代替。
  1. dice系数另一种计算形式:
    D i c e = 2 ∗ T P F P + 2 ∗ T P + F N Dice = \frac{2*TP}{FP+2*TP+FN} Dice=FP+2TP+FN2TP

DiceLoss代码实现

#Dice系数
def dice_coeff(pred, target):
    smooth = 1.
    num = pred.size(0)
    m1 = pred.view(num, -1)  # Flatten
    m2 = target.view(num, -1)  # Flatten
    intersection = (m1 * m2).sum()
 
    return (2. * intersection + smooth) / (m1.sum() + m2.sum() + smooth)

#Dice损失函数
import torch
import torch.nn as nn
import torch.nn.functional as F

class DiceLoss(nn.Module):
    def __init__(self):
        super(DiceLoss, self).__init__()
        self.epsilon = 1e-5
    
    def forward(self, predict, target):
        assert predict.size() == target.size(), "the size of predict and target must be equal."
        num = predict.size(0)
        
        pre = torch.sigmoid(predict).view(num, -1)
        tar = target.view(num, -1)
        
        intersection = (pre * tar).sum(-1).sum()  #利用预测值与标签相乘当作交集
        union = (pre + tar).sum(-1).sum()
        
        score = 1 - 2 * (intersection + self.epsilon) / (union + self.epsilon)
        
        return score

loss = DiceLoss()
predict = torch.randn(3, 4, 4)
target = torch.randn(3, 4, 4)

score = loss(predict, target)
print(score)

#BiseNet中的DiceLoss代码
import torch.nn as nn
import torch
import torch.nn.functional as F

def flatten(tensor):
    """Flattens a given tensor such that the channel axis is first.
    The shapes are transformed as follows:
       (N, C, D, H, W) -> (C, N * D * H * W)
    """
    
    C = tensor.size(1)        #获得图像的维数
    # new axis order
    axis_order = (1, 0) + tuple(range(2, tensor.dim()))     
    # Transpose: (N, C, D, H, W) -> (C, N, D, H, W)
    transposed = tensor.permute(axis_order)                 #将维数的数据转换到第一位
    # Flatten: (C, N, D, H, W) -> (C, N * D * H * W)
    return transposed.contiguous().view(C, -1)              


class DiceLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.epsilon = 1e-5

    def forward(self, output, target):
        assert output.size() == target.size(), "'input' and 'target' must have the same shape"
        output = F.softmax(output, dim=1)
        output = flatten(output)
        target = flatten(target)
        # intersect = (output * target).sum(-1).sum() + self.epsilon
        # denominator = ((output + target).sum(-1)).sum() + self.epsilon

        intersect = (output * target).sum(-1)
        denominator = (output + target).sum(-1)
        dice = intersect / denominator
        dice = torch.mean(dice)
        return 1 - dice
        # return 1 - 2. * intersect / denominator

参考:
https://zhuanlan.zhihu.com/p/349046748
https://zhuanlan.zhihu.com/p/144582930

 类似资料: