Dice系数,是一种集合相似度度量函数,通常用于计算两个样本点的相似度(值范围为[0, 1])。用于分割问题,分割最好时为1,最差为0。及用于解决样本不均衡的问题,但不稳定,容易出现梯度爆炸(?)。
dice系数越大,DiceLoss(在Dice系数的基础上进行计算,用1去减Dice系数,即
D
i
c
e
L
o
s
s
=
1
−
D
i
c
e
DiceLoss = 1 - Dice
DiceLoss=1−Dice)越小,表明样本集合越相似。
其中pred为预测值的集合,true为真实值的集合,分子为pred和true之间的交集,乘以2是因为分母存在重复计算pred和true之间的共同元素的原因。分母为pred和true的并集。
D i c e = 2 ∗ ( p r e d ⋂ t r u e ) p r e d ⋃ t r u e = 2 ∗ i n t e r s e c t i o n u n i o n Dice = \frac{2*(pred \bigcap true)}{pred \bigcup true} = \frac {2*intersection} {union} Dice=pred⋃true2∗(pred⋂true)=union2∗intersection
即2倍交集除以并集,加smooth防止分母为0的情况。
#Dice系数
def dice_coeff(pred, target):
smooth = 1.
num = pred.size(0)
m1 = pred.view(num, -1) # Flatten
m2 = target.view(num, -1) # Flatten
intersection = (m1 * m2).sum()
return (2. * intersection + smooth) / (m1.sum() + m2.sum() + smooth)
#Dice损失函数
import torch
import torch.nn as nn
import torch.nn.functional as F
class DiceLoss(nn.Module):
def __init__(self):
super(DiceLoss, self).__init__()
self.epsilon = 1e-5
def forward(self, predict, target):
assert predict.size() == target.size(), "the size of predict and target must be equal."
num = predict.size(0)
pre = torch.sigmoid(predict).view(num, -1)
tar = target.view(num, -1)
intersection = (pre * tar).sum(-1).sum() #利用预测值与标签相乘当作交集
union = (pre + tar).sum(-1).sum()
score = 1 - 2 * (intersection + self.epsilon) / (union + self.epsilon)
return score
loss = DiceLoss()
predict = torch.randn(3, 4, 4)
target = torch.randn(3, 4, 4)
score = loss(predict, target)
print(score)
#BiseNet中的DiceLoss代码
import torch.nn as nn
import torch
import torch.nn.functional as F
def flatten(tensor):
"""Flattens a given tensor such that the channel axis is first.
The shapes are transformed as follows:
(N, C, D, H, W) -> (C, N * D * H * W)
"""
C = tensor.size(1) #获得图像的维数
# new axis order
axis_order = (1, 0) + tuple(range(2, tensor.dim()))
# Transpose: (N, C, D, H, W) -> (C, N, D, H, W)
transposed = tensor.permute(axis_order) #将维数的数据转换到第一位
# Flatten: (C, N, D, H, W) -> (C, N * D * H * W)
return transposed.contiguous().view(C, -1)
class DiceLoss(nn.Module):
def __init__(self):
super().__init__()
self.epsilon = 1e-5
def forward(self, output, target):
assert output.size() == target.size(), "'input' and 'target' must have the same shape"
output = F.softmax(output, dim=1)
output = flatten(output)
target = flatten(target)
# intersect = (output * target).sum(-1).sum() + self.epsilon
# denominator = ((output + target).sum(-1)).sum() + self.epsilon
intersect = (output * target).sum(-1)
denominator = (output + target).sum(-1)
dice = intersect / denominator
dice = torch.mean(dice)
return 1 - dice
# return 1 - 2. * intersect / denominator
参考:
https://zhuanlan.zhihu.com/p/349046748
https://zhuanlan.zhihu.com/p/144582930