Debugging mmsegmentation

唐信瑞
2023-12-01

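Getting mmsegmentation to run took me quite a while. I used the Tianchi tampering-detection dataset: the images are jpg files, and each label is a png whose pixel values range from 0 to 255. That raises two problems: 1. The label pixels take many different values between 0 and 255 rather than just two, even though this is really a two-class task. 2. The 0-255 label values need to be mapped to 0 and 1.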
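
Before touching any config, it can help to confirm which values the masks really contain. The following is a minimal sketch, assuming the masks are available as NumPy arrays; the small synthetic array stands in for a real label png, and `summarize_mask` is a hypothetical helper, not part of mmsegmentation.

```python
import numpy as np


def summarize_mask(mask):
    """Return a dict mapping each pixel value in the mask to its count."""
    values, counts = np.unique(mask, return_counts=True)
    return {int(v): int(c) for v, c in zip(values, counts)}


# Synthetic stand-in for a label png: mostly 0 and 255, plus a few
# intermediate values like the ones described above.
mask = np.array([[0, 0, 128, 255],
                 [0, 37, 255, 255]], dtype=np.uint8)

summary = summarize_mask(mask)
print(summary)  # more than two keys means the mask is not strictly binary
```

If the summary has more than two keys, the masks need the binarization described in step 5.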

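1. For the dataset I picked cityscapes, because its layout is the standard image + mask. But cityscapes has 19 classes, with each pixel value standing for one class. The Tianchi tampering dataset is different: it should have only two classes, stored as 0 and 255. So num_classes has to be changed in the model config (under configs/_base_/models).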
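
In mmsegmentation 0.x style configs the class count is set on the heads. A sketch of the override, assuming a model with both a decode head and an auxiliary head (a model without an auxiliary head would simply omit that key):

```python
# Fragment of a config file, merged on top of the base model config.
# Two classes: 0 = untampered background, 1 = tampered region.
model = dict(
    decode_head=dict(num_classes=2),
    auxiliary_head=dict(num_classes=2),
)
```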

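2. Under configs/_base_/datasets, pick the cityscapes config and change the dataset paths. On a 1080 Ti, samples_per_gpu can usually go no higher than 2; set workers_per_gpu to 2 as well. There is also a reduce_zero_label option, normally False, which controls whether the background class (label 0) is ignored; it needs no change here.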
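
A sketch of what the edited dataset config might look like in the mmsegmentation 0.x style. Every path below is a placeholder, and the suffix overrides are an assumption: CityscapesDataset defaults to Cityscapes file naming, so plain `name.jpg` / `name.png` pairs probably need them.

```python
# Fragment of a dataset config; all paths are hypothetical placeholders.
dataset_type = 'CityscapesDataset'
data_root = 'data/tianchi_tamper/'  # assumed location of the dataset

train_pipeline = [
    dict(type='LoadImageFromFile'),
    # reduce_zero_label stays False so that label 0 remains a real class
    dict(type='LoadAnnotations', reduce_zero_label=False),
]

data = dict(
    samples_per_gpu=2,   # batch size per GPU; 2 is what fit on a 1080 Ti
    workers_per_gpu=2,   # dataloader worker processes per GPU
    train=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='train/img',     # assumed sub-directory with the jpg images
        ann_dir='train/mask',    # assumed sub-directory with the png masks
        img_suffix='.jpg',       # assumption: images are named <name>.jpg
        seg_map_suffix='.png',   # assumption: masks are named <name>.png
        pipeline=train_pipeline,
    ),
)
```

A real pipeline would also contain the resize, crop, normalize and formatting steps from the base config; they are left out here to keep the changed keys visible.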

3. Edit the class_names file under core/evaluation. This file is easy to forget, and mmcls no longer has it. For cityscapes, change cityscapes_classes and cityscapes_palette.

4. In the cityscapes file under datasets, change CLASSES and PALETTE.
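
Steps 3 and 4 both come down to replacing the Cityscapes definitions with two-class ones. A sketch of the shape these could take; the class names and colours are illustrative choices, not values from the dataset:

```python
# Two-class replacements for the Cityscapes definitions.
# Names and colours are illustrative assumptions.
CLASSES = ('background', 'tampered')
PALETTE = [[0, 0, 0], [255, 255, 255]]


def cityscapes_classes():
    """Class names, in label-index order."""
    return list(CLASSES)


def cityscapes_palette():
    """One RGB colour per class, in the same order as the class names."""
    return [list(rgb) for rgb in PALETTE]
```

Whatever names are chosen, the two lists must have the same length, and that length must match num_classes.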

5. In LoadAnnotations in the loading module, add a binarization step to __call__:

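    # Method of the LoadAnnotations class. The surrounding module is expected
    # to provide: import os.path as osp, import mmcv, import numpy as np.
    def __call__(self, results):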
        """Call function to load multiple types annotations.

        Args:
            results (dict): Result dict from :obj:`mmseg.CustomDataset`.

        Returns:
            dict: The dict contains loaded semantic segmentation annotations.
        """

        if self.file_client is None:
            self.file_client = mmcv.FileClient(**self.file_client_args)

        if results.get('seg_prefix', None) is not None:
            filename = osp.join(results['seg_prefix'],
                                results['ann_info']['seg_map'])
        else:
            filename = results['ann_info']['seg_map']
        img_bytes = self.file_client.get(filename)
        gt_semantic_seg = mmcv.imfrombytes(
            img_bytes, flag='unchanged',
            backend=self.imdecode_backend).squeeze().astype(np.uint8)
        # modify if custom classes
        if results.get('label_map', None) is not None:
            # Add deep copy to solve bug of repeatedly
            # replace `gt_semantic_seg`, which is reported in
            # https://github.com/open-mmlab/mmsegmentation/pull/1445/
            gt_semantic_seg_copy = gt_semantic_seg.copy()
            for old_id, new_id in results['label_map'].items():
                gt_semantic_seg[gt_semantic_seg_copy == old_id] = new_id
        
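        # Binarize the mask: every non-zero pixel becomes class 1 (tampered).
        # Enabled unless the instance sets `binarize` to False; a cleaner
        # version would make `binarize` an __init__ argument.
        if getattr(self, 'binarize', True):
            gt_semantic_seg = (gt_semantic_seg != 0).astype(np.uint8)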
            
        # reduce zero_label
        if self.reduce_zero_label:
            # avoid using underflow conversion
            gt_semantic_seg[gt_semantic_seg == 0] = 255
            gt_semantic_seg = gt_semantic_seg - 1
            gt_semantic_seg[gt_semantic_seg == 254] = 255
        results['gt_semantic_seg'] = gt_semantic_seg
        results['seg_fields'].append('gt_semantic_seg')
        return results
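
To see why reduce_zero_label should stay False for this dataset, the reduce_zero_label branch of the method can be replayed on a tiny binarized mask. This is a standalone sketch in plain NumPy; `apply_reduce_zero_label` is a hypothetical helper that copies those three lines.

```python
import numpy as np


def apply_reduce_zero_label(seg):
    """Replay the reduce_zero_label branch on a copy of a uint8 label map."""
    seg = seg.copy()
    seg[seg == 0] = 255      # background becomes the ignore value
    seg = seg - 1            # every other label shifts down by one
    seg[seg == 254] = 255    # restore the ignore value after the shift
    return seg


binary = np.array([[0, 1], [1, 0]], dtype=np.uint8)  # already binarized
shifted = apply_reduce_zero_label(binary)
print(shifted.tolist())  # [[255, 0], [0, 255]]
```

With the option enabled, background turns into 255 and the tampered class becomes 0, so only one usable class would be left.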

6. mmseg also has an ignore_index setting, which by default ignores pixel value 255. Once the masks are binarized you no longer need to worry about it.
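
A small NumPy sketch of why binarization makes ignore_index harmless here. It assumes, as a simplification, that a pixel counts toward the loss exactly when its label differs from ignore_index:

```python
import numpy as np

IGNORE_INDEX = 255  # the default ignore value mentioned in step 6

raw = np.array([[0, 255], [255, 0]], dtype=np.uint8)  # mask as stored on disk
binarized = (raw != 0).astype(np.uint8)               # same step as in __call__

# Pixels that would still count after the ignore value is filtered out
valid_raw = raw != IGNORE_INDEX
valid_bin = binarized != IGNORE_INDEX

print(int(valid_raw.sum()), int(valid_bin.sum()))
```

On the raw mask the tampered pixels (255) are exactly the ones that get dropped; after binarization no pixel equals 255, so all of them count.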
