
TorchNet ConfusionMeter Explained

南宫才艺
2023-12-01

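I have recently been reading Chen Yun's simple-faster-rcnn code. While going through it, I noticed that the RPN and RoI training metrics are tracked with the ConfusionMeter class from tnt (torchnet), so I looked at its source and decided the implementation was worth writing down.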

ConfusionMeter class methods

  • init

    def __init__(self, k, normalized=False):
        super(ConfusionMeter, self).__init__()
        self.conf = np.ndarray((k, k), dtype=np.int32)
        self.normalized = normalized
        self.k = k
        self.reset()
    

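    self.conf: the k x k confusion matrix for a multi-class problem. Multi-label classification is not supported.

    self.k: the number of classes in the classification problem.

    self.normalized: a flag that the value method checks to decide whether to normalize the confusion matrix before returning it.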

  • reset

    def reset(self):
        self.conf.fill(0)
    

    The reset method sets every entry of the confusion matrix back to 0.
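One detail worth spelling out: np.ndarray((k, k), ...) allocates the array without initializing it, so the reset() call at the end of __init__ is what makes the matrix start from zero. A minimal sketch with plain NumPy (the variable names and the value of k are illustrative, not taken from torchnet):

```python
import numpy as np

k = 3  # number of classes (illustrative value)

# np.ndarray allocates memory without initializing it, so the
# contents are arbitrary until they are overwritten.
conf = np.ndarray((k, k), dtype=np.int32)

# This is what reset() does: overwrite every entry with 0 in place.
conf.fill(0)

print(conf.shape, conf.dtype, int(conf.sum()))
```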

  • add

        def add(self, predicted, target):
            """Computes the confusion matrix of K x K size where K is no of classes
    
            Args:
                predicted (tensor): Can be an N x K tensor of predicted scores obtained from
                    the model for N examples and K classes or an N-tensor of
                    integer values between 0 and K-1.
                target (tensor): Can be a N-tensor of integer values assumed to be integer
                    values between 0 and K-1 or N x K tensor, where targets are
                    assumed to be provided as one-hot vectors
    
            """
            predicted = predicted.cpu().numpy()
            target = target.cpu().numpy()
    
            assert predicted.shape[0] == target.shape[0], \
                'number of targets and predicted outputs do not match'
    
            if np.ndim(predicted) != 1:
                assert predicted.shape[1] == self.k, \
                    'number of predictions does not match size of confusion matrix'
                predicted = np.argmax(predicted, 1)
            else:
                assert (predicted.max() < self.k) and (predicted.min() >= 0), \
                    'predicted values are not between 1 and k'
    
            onehot_target = np.ndim(target) != 1
            if onehot_target:
                assert target.shape[1] == self.k, \
                    'Onehot target does not match size of confusion matrix'
                assert (target >= 0).all() and (target <= 1).all(), \
                    'in one-hot encoding, target values should be 0 or 1'
                assert (target.sum(1) == 1).all(), \
                    'multi-label setting is not supported'
                target = np.argmax(target, 1)
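            else:
                # Note: as quoted, this branch re-checks `predicted`; an
                # index-valued `target` is not range-checked here.
                assert (predicted.max() < self.k) and (predicted.min() >= 0), \
                    'predicted values are not between 0 and k-1'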
    
            # hack for bincounting 2 arrays together
            x = predicted + self.k * target
            bincount_2d = np.bincount(x.astype(np.int32),
                                      minlength=self.k ** 2)
            assert bincount_2d.size == self.k ** 2
            conf = bincount_2d.reshape((self.k, self.k))
    
            self.conf += conf
    

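    Most of the add method checks that the shapes of predicted and target match a single-label problem (class indices or one-hot vectors). The core of the method, pulled out on its own, is: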

            predicted = np.argmax(predicted, 1)
            target = np.argmax(target, 1)
            
            x = predicted + self.k * target
            bincount_2d = np.bincount(x.astype(np.int32),
                                      minlength=self.k ** 2)
            assert bincount_2d.size == self.k ** 2
            conf = bincount_2d.reshape((self.k, self.k))
    
            self.conf += conf
    

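    Assume predicted and target are both 2-D arrays of shape (N, self.k), where N is the number of samples. argmax over axis 1 reduces each of them to a 1-D array of length N that holds, for every sample, the index of the largest entry along axis 1, i.e. the class index. These indices are what locate the sample in the confusion matrix.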
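To make the argmax step concrete, here is a small sketch with made-up scores and one-hot targets for 4 samples and 3 classes (the arrays are illustrative, not from the original post):

```python
import numpy as np

# Scores for N = 4 samples and K = 3 classes (made-up numbers).
scores = np.array([
    [0.1, 0.7, 0.2],
    [0.8, 0.1, 0.1],
    [0.2, 0.3, 0.5],
    [0.3, 0.4, 0.3],
])
# One-hot targets for the same 4 samples.
onehot = np.array([
    [0, 1, 0],
    [1, 0, 0],
    [0, 1, 0],
    [0, 0, 1],
])

# argmax over axis 1 turns each (N, K) array into N class indices.
predicted = np.argmax(scores, 1)
target = np.argmax(onehot, 1)

print(predicted)  # [1 0 2 1]
print(target)     # [1 0 1 2]
```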

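    In x = predicted + self.k * target, the term self.k * target selects the row of the k*k matrix (the true class) and predicted selects the column within that row, so x is the row-major flattened index of the cell (target, predicted). If the prediction is correct, predicted equals target and the index falls on the diagonal of the confusion matrix; otherwise it falls off the diagonal.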
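The index arithmetic can be checked by hand. The sketch below uses a hypothetical helper, flat_index, and plain Python integers to show that the mapping is invertible with divmod:

```python
k = 3  # number of classes (illustrative value)

def flat_index(predicted, target, k):
    """Row-major position of cell (target, predicted) in a k x k matrix."""
    return predicted + k * target

# (target, predicted) pairs: two correct predictions and one mistake.
pairs = [(0, 0), (2, 2), (1, 2)]

for target, predicted in pairs:
    x = flat_index(predicted, target, k)
    # divmod undoes the mapping: row = x // k, column = x % k.
    row, col = divmod(x, k)
    place = "diagonal" if row == col else "off-diagonal"
    print(target, predicted, x, (row, col), place)
```

The correct pairs map to indices 0 and 8, which are diagonal cells of a 3 x 3 matrix, while the mistaken pair (true class 1, predicted 2) maps to index 5, the cell in row 1, column 2.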

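    bincount_2d counts how many times each value occurs in the 1-D vector of flattened $k \times k$ indices, with the output length fixed at $k \times k$ through minlength. The 1-D array bincount_2d is then reshaped back into a $k \times k$ matrix.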
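A small worked example of the bincount and reshape steps, using made-up class indices for 4 samples and 3 classes:

```python
import numpy as np

k = 3
predicted = np.array([1, 0, 2, 1])
target = np.array([1, 0, 1, 2])

x = predicted + k * target          # flattened cell index per sample
print(x)                            # [4 0 5 7]

# Without minlength the result would stop at the largest index seen
# (length 8 here); minlength pads it with zeros up to k * k = 9.
counts = np.bincount(x, minlength=k ** 2)
print(counts)                       # [1 0 0 0 1 1 0 1 0]

conf = counts.reshape((k, k))       # rows: true class, columns: predicted
print(conf)
```

Row 1 of the result is [0, 1, 1]: of the two samples whose true class is 1, one was predicted as class 1 and one as class 2.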

    Finally, the matrix computed for this batch is added to the running confusion matrix.

    That completes one confusion matrix update.
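Because each call adds to the stored matrix, counts accumulate across batches until reset is called. A minimal standalone sketch of that behavior, assuming the inputs are already integer class indices (update is a hypothetical helper, not part of torchnet):

```python
import numpy as np

def update(conf, predicted, target):
    """Add one batch of integer class indices to conf in place (hypothetical helper)."""
    k = conf.shape[0]
    x = predicted + k * target
    conf += np.bincount(x, minlength=k ** 2).reshape((k, k))
    return conf

k = 3
conf = np.zeros((k, k), dtype=np.int32)

# Two made-up batches of class indices.
update(conf, np.array([1, 0, 2, 1]), np.array([1, 0, 1, 2]))
update(conf, np.array([2, 2, 0]), np.array([2, 1, 0]))

print(conf)
print(int(conf.sum()))  # 7, one count per sample seen so far
```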

  • value

    def value(self):
        """
        Returns:
            Confusion matrix of K rows and K columns, where rows corresponds
            to ground-truth targets and columns corresponds to predicted
            targets.
        """
        if self.normalized:
            conf = self.conf.astype(np.float32)
            return conf / conf.sum(1).clip(min=1e-12)[:, None]
        else:
            return self.conf

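    The value method returns the confusion matrix. When self.normalized is True, the matrix is converted to float32 and each row is divided by its row sum (clipped to at least 1e-12 to avoid division by zero), so each row gives the fraction of samples of that true class assigned to each predicted class. Otherwise the raw counts are returned.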
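The row normalization can be reproduced outside the class. The sketch below uses made-up counts in which one class never appears as a true label, to show what the clip does for an empty row:

```python
import numpy as np

# Made-up counts; class 2 never appears as a true label, so row 2 is all zeros.
conf = np.array([
    [3, 1, 0],
    [1, 1, 2],
    [0, 0, 0],
], dtype=np.int32)

as_float = conf.astype(np.float32)
# Row sums, clipped so an empty row divides by 1e-12 instead of 0.
row_sums = as_float.sum(1).clip(min=1e-12)
normalized = as_float / row_sums[:, None]

print(normalized)
```

The first two rows each sum to 1, and the empty row stays all zeros instead of turning into NaN.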
