
SOLO Code Walkthrough

姚航
2023-12-01

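SOLO is a paradigm that predicts instance masks directly. It drops the two previously dominant approaches to instance segmentation, top-down and bottom-up, which makes the pipeline simpler and more intuitive. This post uses the demo in the official code to walk through what SOLO does at inference time. The whole codebase is built on mmdetection.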

First, demo.inference_demo.py:

from mmdet.apis import inference_detector, init_detector, show_result_ins

config_file = '../configs/solo/decoupled_solo_r50_fpn_8gpu_3x.py'
# download the checkpoint from model zoo and put it in `checkpoints/`
checkpoint_file = '../checkpoints/DECOUPLED_SOLO_R50_3x.pth'

# build the model from a config file and a checkpoint file
model = init_detector(config_file, checkpoint_file, device='cuda:0')

# test a single image
img = 'demo.jpg'
result = inference_detector(model, img)

show_result_ins(img, result, model.CLASSES, score_thr=0.25, out_file="demo_out.jpg")

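The code above is simple: init_detector creates the model, inference_detector runs the forward pass, and show_result_ins renders the final result. The core is init_detector and inference_detector. Both functions live in mmdet.apis, so let's look at that module: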

mmdet.apis.inference.py

def init_detector(config, checkpoint=None, device='cuda:0'):
    """Initialize a detector from config file.

    Args:
        config (str or :obj:`mmcv.Config`): Config file path or the config
            object.
        checkpoint (str, optional): Checkpoint path. If left as None, the model
            will not load any weights.

    Returns:
        nn.Module: The constructed detector.
    """
    if isinstance(config, str):
        config = mmcv.Config.fromfile(config)
    elif not isinstance(config, mmcv.Config):
        raise TypeError('config must be a filename or Config object, '
                        'but got {}'.format(type(config)))
    config.model.pretrained = None
    model = build_detector(config.model, test_cfg=config.test_cfg)
    if checkpoint is not None:
        checkpoint = load_checkpoint(model, checkpoint)
        if 'CLASSES' in checkpoint['meta']:
            model.CLASSES = checkpoint['meta']['CLASSES']
        else:
            warnings.warn('Class names are not saved in the checkpoint\'s '
                          'meta data, use COCO classes by default.')
            model.CLASSES = get_classes('coco')
    model.cfg = config  # save the config in the model for convenience
    model.to(device)
    model.eval()
    return model


def inference_detector(model, img):
    """Inference image(s) with the detector.

    Args:
        model (nn.Module): The loaded detector.
        imgs (str/ndarray or list[str/ndarray]): Either image files or loaded
            images.

    Returns:
        If imgs is a str, a generator will be returned, otherwise return the
        detection results directly.
    """
    cfg = model.cfg
    device = next(model.parameters()).device  # model device
    # build the data pipeline
    test_pipeline = [LoadImage()] + cfg.data.test.pipeline[1:]
    test_pipeline = Compose(test_pipeline)
    # prepare data
    data = dict(img=img)
    data = test_pipeline(data)
    data = scatter(collate([data], samples_per_gpu=1), [device])[0]
    # forward the model
    with torch.no_grad():
        result = model(return_loss=False, rescale=True, **data)
    return result

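For init_detector, the core call is build_detector, which creates the model from the config file; the checkpoint is then loaded into it. inference_detector is even simpler: it first runs the test pipeline from the config (image loading followed by a series of preprocessing/augmentation transforms), then calls the model to do inference.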
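
As a rough illustration of what the pipeline step does, here is a minimal sketch of a Compose-style pipeline in plain Python. ToyCompose and the toy transforms are hypothetical stand-ins, not mmdetection classes. The only things taken from the code above are the ideas that a pipeline is a list of callables applied in order to a dict, and that the first step of the config's pipeline is swapped for a different loader.

```python
class ToyCompose:
    """Apply a list of callables to a dict, one after another (hypothetical stand-in)."""

    def __init__(self, transforms):
        self.transforms = list(transforms)

    def __call__(self, data):
        for transform in self.transforms:
            data = transform(data)
        return data


def load_from_file(data):
    # Stand-in for the first step of the config's pipeline.
    data["loaded_by"] = "dataset loader"
    return data


def load_image(data):
    # Stand-in for LoadImage(), the loader used by the demo.
    data["loaded_by"] = "demo loader"
    return data


def resize(data):
    data["steps"] = data.get("steps", []) + ["resize"]
    return data


def normalize(data):
    data["steps"] = data.get("steps", []) + ["normalize"]
    return data


# The config's pipeline, with its first step replaced as in inference_detector.
config_pipeline = [load_from_file, resize, normalize]
test_pipeline = ToyCompose([load_image] + config_pipeline[1:])

out = test_pipeline(dict(img="demo.jpg"))
print(out)
```

Everything after the first step is reused unchanged, which is why the demo gets the same preprocessing as the test set.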

From here there are again two branches: how build_detector creates the model, and how that model runs inference. We take them one at a time.

build_detector

build_detector lives in mmdet.models.builder.py:

from torch import nn

from mmdet.utils import build_from_cfg
from .registry import (BACKBONES, DETECTORS, HEADS, LOSSES, NECKS,
                       ROI_EXTRACTORS, SHARED_HEADS)


def build(cfg, registry, default_args=None):
    if isinstance(cfg, list):
        modules = [
            build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
        ]
        return nn.Sequential(*modules)
    else:
        return build_from_cfg(cfg, registry, default_args)


def build_backbone(cfg):
    return build(cfg, BACKBONES)


def build_neck(cfg):
    return build(cfg, NECKS)


def build_roi_extractor(cfg):
    return build(cfg, ROI_EXTRACTORS)


def build_shared_head(cfg):
    return build(cfg, SHARED_HEADS)


def build_head(cfg):
    return build(cfg, HEADS)


def build_loss(cfg):
    return build(cfg, LOSSES)


def build_detector(cfg, train_cfg=None, test_cfg=None):
    return build(cfg, DETECTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg))

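build_detector calls build, and build in turn calls build_from_cfg. Note that build is passed the DETECTORS registry. Registry is a class, and the argument passed in is one instance of it; each kind of component (backbone, FPN neck, etc.) has its own Registry instance. For now, think of a registry as the thing that creates these modules and manages each kind separately.

Next, build_from_cfg in mmdet.utils.registry.py: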

def build_from_cfg(cfg, registry, default_args=None):
    """Build a module from config dict.

    Args:
        cfg (dict): Config dict. It should at least contain the key "type".
        registry (:obj:`Registry`): The registry to search the type from.
        default_args (dict, optional): Default initialization arguments.

    Returns:
        obj: The constructed object.
    """
    assert isinstance(cfg, dict) and 'type' in cfg
    assert isinstance(default_args, dict) or default_args is None
    args = cfg.copy()
    obj_type = args.pop('type')
    if mmcv.is_str(obj_type):
        obj_cls = registry.get(obj_type)
        if obj_cls is None:
            raise KeyError('{} is not in the {} registry'.format(
                obj_type, registry.name))
    elif inspect.isclass(obj_type):
        obj_cls = obj_type
    else:
        raise TypeError('type must be a str or valid type, but got {}'.format(
            type(obj_type)))
    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)
    return obj_cls(**args)

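This is where the registry is actually consulted: the model is assembled from the dicts in the config. obj_cls is the class of the module to create, such as SOLO, ResNet, or FPN. It is looked up by the type string in the config, and a module can only be built if that type has been registered in the given registry (otherwise a KeyError is raised). The remaining entries of args are passed to the class constructor. The first obj_cls returned here is SOLO (see the config file), so we need to find the file that defines the SOLO model: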
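
To make the lookup concrete, here is a minimal sketch in plain Python. A dict stands in for the Registry, and ToyDetector and toy_build_from_cfg are hypothetical names; the function is a stripped-down restatement of the build_from_cfg logic above (string types only). It shows that keys set in the config win over default_args, because setdefault only fills in what is missing.

```python
class ToyDetector:
    def __init__(self, depth, train_cfg=None, test_cfg=None):
        self.depth = depth
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg


# A plain dict stands in for the Registry: type name -> class.
TOY_DETECTORS = {"ToyDetector": ToyDetector}


def toy_build_from_cfg(cfg, registry, default_args=None):
    args = dict(cfg)                  # copy, so the caller's config is untouched
    obj_type = args.pop("type")
    if obj_type not in registry:
        raise KeyError("{} is not in the registry".format(obj_type))
    for name, value in (default_args or {}).items():
        args.setdefault(name, value)  # only fills keys the config did not set
    return registry[obj_type](**args)


cfg = dict(type="ToyDetector", depth=50, test_cfg="from config")
model = toy_build_from_cfg(
    cfg, TOY_DETECTORS,
    default_args=dict(train_cfg=None, test_cfg="from default_args"))

print(type(model).__name__, model.depth, model.test_cfg)
```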

mmdet.models.detectors.solo.py

@DETECTORS.register_module
class SOLO(SingleStageInsDetector):

    def __init__(self,
                 backbone,
                 neck,
                 bbox_head,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        super(SOLO, self).__init__(backbone, neck, bbox_head, None, train_cfg,
                                   test_cfg, pretrained)

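The first line applies a decorator. When the SOLO class is defined (that is, when the module is imported, before any instance is created), the decorator is called with the SOLO class as its argument and registers it in the DETECTORS registry. SOLO in turn inherits from SingleStageInsDetector, so that class is what matters next: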
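
The timing of registration is easy to check with a minimal registry sketch. ToyRegistry below is a hypothetical, simplified stand-in for the Registry class (the real one does more validation); the point is only that the decorator body runs when the class statement executes.

```python
class ToyRegistry:
    """Hypothetical, simplified registry: maps class names to classes."""

    def __init__(self, name):
        self.name = name
        self._module_dict = {}

    def get(self, key):
        return self._module_dict.get(key)

    def register_module(self, cls):
        if cls.__name__ in self._module_dict:
            raise KeyError("{} is already registered in {}".format(
                cls.__name__, self.name))
        self._module_dict[cls.__name__] = cls
        return cls  # return the class unchanged so it can be used as usual


TOY_DETECTORS = ToyRegistry("detector")
registered_before = TOY_DETECTORS.get("ToySOLO")  # None: nothing registered yet


@TOY_DETECTORS.register_module
class ToySOLO:
    pass


# No ToySOLO instance has been created, yet the class is already registered.
print(registered_before, TOY_DETECTORS.get("ToySOLO") is ToySOLO)
```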

mmdet.models.detectors.single_stage_ins.py

@DETECTORS.register_module
class SingleStageInsDetector(BaseDetector):

    def __init__(self,
                 backbone,
                 neck=None,
                 bbox_head=None,
                 mask_feat_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        super(SingleStageInsDetector, self).__init__()
        self.backbone = builder.build_backbone(backbone)
        if neck is not None:
            self.neck = builder.build_neck(neck)
        if mask_feat_head is not None:
            self.mask_feat_head = builder.build_head(mask_feat_head)

        self.bbox_head = builder.build_head(bbox_head)
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.init_weights(pretrained=pretrained)

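Above is the core code of SingleStageInsDetector; the args from build_from_cfg are passed in as its constructor arguments. Following the config, it creates the model's backbone, neck, and bbox_head in turn and stores test_cfg (we are doing inference here). Each of these goes through the corresponding builder function, each kind of module has its own Registry, and each module is built from the parameters in the config file. All of them end up as attributes of this one SingleStageInsDetector instance. I won't paste the construction code of every module; it just passes args through and builds the module from existing code.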
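
The nesting can be sketched as follows. All names here (ToyBackbone, ToyHead, ToyDetector, toy_build) are hypothetical, the numeric values are placeholders, and one dict stands in for all registries, whereas the real code keeps one Registry per kind of module. The detector's config carries the sub-configs as plain dicts, and the detector's constructor builds each of them.

```python
class ToyBackbone:
    def __init__(self, depth):
        self.depth = depth


class ToyHead:
    def __init__(self, num_classes, num_grids):
        self.num_classes = num_classes
        self.num_grids = num_grids


class ToyDetector:
    def __init__(self, backbone, bbox_head, neck=None, test_cfg=None):
        # Sub-configs arrive as dicts and are built here, one by one.
        self.backbone = toy_build(backbone)
        self.neck = toy_build(neck) if neck is not None else None
        self.bbox_head = toy_build(bbox_head)
        self.test_cfg = test_cfg


TOY_MODULES = {cls.__name__: cls for cls in (ToyBackbone, ToyHead, ToyDetector)}


def toy_build(cfg, **default_args):
    args = dict(cfg)
    cls = TOY_MODULES[args.pop("type")]
    for name, value in default_args.items():
        args.setdefault(name, value)
    return cls(**args)


# Placeholder config: one top-level dict with nested sub-configs.
model_cfg = dict(
    type="ToyDetector",
    backbone=dict(type="ToyBackbone", depth=50),
    bbox_head=dict(type="ToyHead", num_classes=81, num_grids=[40, 36, 24, 16, 12]),
)
detector = toy_build(model_cfg, test_cfg=dict(score_thr=0.1))
print(type(detector.backbone).__name__, type(detector.bbox_head).__name__)
```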

That is roughly how the model is built. Now for the inference process.

Inference

SOLO inherits forward from BaseDetector, where it looks like this:

    def forward_test(self, imgs, img_metas, **kwargs):
        # ...

        if num_augs == 1:
            return self.simple_test(imgs[0], img_metas[0], **kwargs)
        else:
            return self.aug_test(imgs, img_metas, **kwargs)

    @auto_fp16(apply_to=('img', ))
    def forward(self, img, img_meta, return_loss=True, **kwargs):
        if return_loss:
            return self.forward_train(img, img_meta, **kwargs)
        else:
            return self.forward_test(img, img_meta, **kwargs)

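Taking the single-GPU demo as an example (one image, no test-time augmentation, which is the num_augs == 1 branch), simple_test is what gets called. It is overridden in SingleStageInsDetector as follows: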

    def extract_feat(self, img):
        x = self.backbone(img)
        if self.with_neck:
            x = self.neck(x)
        return x
        
    def simple_test(self, img, img_meta, rescale=False):
        x = self.extract_feat(img)
        outs = self.bbox_head(x, eval=True)

        if self.with_mask_feat_head:
            mask_feat_pred = self.mask_feat_head(
                x[self.mask_feat_head.
                  start_level:self.mask_feat_head.end_level + 1])
            seg_inputs = outs + (mask_feat_pred, img_meta, self.test_cfg, rescale)
        else:
            seg_inputs = outs + (img_meta, self.test_cfg, rescale)
        seg_result = self.bbox_head.get_seg(*seg_inputs)
        return seg_result  

Inference therefore runs backbone -> neck -> bbox_head, where the backbone is ResNet50, the neck is FPN, and the bbox_head is the (decoupled) SOLO head. The feature-extraction code is simple, so I won't go over it. Let's focus on bbox_head:
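
As a quick sanity check on shapes, the sketch below reproduces the feature-map sizes that appear in the debug prints of DecoupledSOLOHead.forward. It assumes an 800 x 1216 network input (inferred from the 200 x 304 stride-4 map; the post does not state it) and that each further pyramid level halves the previous one, rounding up. fpn_sizes and split_sizes are hypothetical helper names that only do the size arithmetic, not the interpolation itself.

```python
import math


def fpn_sizes(input_hw, first_stride=4, num_levels=5):
    """Spatial sizes of the pyramid levels, halving (rounding up) at each level."""
    h = math.ceil(input_hw[0] / first_stride)
    w = math.ceil(input_hw[1] / first_stride)
    sizes = []
    for _ in range(num_levels):
        sizes.append((h, w))
        h, w = math.ceil(h / 2), math.ceil(w / 2)
    return sizes


def split_sizes(sizes):
    """Sizes after split_feats: level 0 is halved, level 4 is resized to level 3."""
    first = (sizes[0][0] // 2, sizes[0][1] // 2)
    return [first, sizes[1], sizes[2], sizes[3], sizes[3]]


feats = fpn_sizes((800, 1216))
new_feats = split_sizes(feats)
# The mask branch upsamples everything to twice the first (resized) level.
upsampled_size = (new_feats[0][0] * 2, new_feats[0][1] * 2)

print(feats)
print(new_feats)
print(upsampled_size)
```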

mmdet.models.anchor_heads.decoupled_solo_head.py

@HEADS.register_module
class DecoupledSOLOHead(nn.Module):
    def __init__(self,
                 num_classes,
                 in_channels,
                 seg_feat_channels=256,
                 stacked_convs=4,
                 strides=(4, 8, 16, 32, 64),
                 base_edge_list=(16, 32, 64, 128, 256),
                 scale_ranges=((8, 32), (16, 64), (32, 128), (64, 256), (128, 512)),
                 sigma=0.4,
                 num_grids=None,
                 cate_down_pos=0,
                 with_deform=False,
                 loss_ins=None,
                 loss_cate=None,
                 conv_cfg=None,
                 norm_cfg=None):
        super(DecoupledSOLOHead, self).__init__()
        self.num_classes = num_classes
        self.seg_num_grids = num_grids
        self.cate_out_channels = self.num_classes - 1
        self.in_channels = in_channels
        self.seg_feat_channels = seg_feat_channels
        self.stacked_convs = stacked_convs
        self.strides = strides
        self.sigma = sigma
        self.cate_down_pos = cate_down_pos
        self.base_edge_list = base_edge_list
        self.scale_ranges = scale_ranges
        self.with_deform = with_deform
        self.loss_cate = build_loss(loss_cate)
        self.ins_loss_weight = loss_ins['loss_weight']
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self._init_layers()

    def _init_layers(self):
        norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
        self.ins_convs_x = nn.ModuleList()
        self.ins_convs_y = nn.ModuleList()
        self.cate_convs = nn.ModuleList()

        for i in range(self.stacked_convs):
            # The +1 on the first layer is the CoordConv position channel concatenated to the input (+2 in the non-decoupled head)
            chn = self.in_channels + 1 if i == 0 else self.seg_feat_channels
            # ins_x branch: stacked conv + norm modules
            self.ins_convs_x.append(
                ConvModule(
                    chn,
                    self.seg_feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    norm_cfg=norm_cfg,
                    bias=norm_cfg is None))
            # ins_y branch: stacked conv + norm modules
            self.ins_convs_y.append(
                ConvModule(
                    chn,
                    self.seg_feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    norm_cfg=norm_cfg,
                    bias=norm_cfg is None))

            chn = self.in_channels if i == 0 else self.seg_feat_channels
            # cate branch: stacked conv + norm modules
            self.cate_convs.append(
                ConvModule(
                    chn,
                    self.seg_feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    norm_cfg=norm_cfg,
                    bias=norm_cfg is None))

        self.dsolo_ins_list_x = nn.ModuleList()
        self.dsolo_ins_list_y = nn.ModuleList()
        # num_grid differs per level, so each level gets output convs with the matching number of channels
        for seg_num_grid in self.seg_num_grids:
            self.dsolo_ins_list_x.append(
                nn.Conv2d(
                    self.seg_feat_channels, seg_num_grid, 3, padding=1))
            self.dsolo_ins_list_y.append(
                nn.Conv2d(
                    self.seg_feat_channels, seg_num_grid, 3, padding=1))
        self.dsolo_cate = nn.Conv2d(
            self.seg_feat_channels, self.cate_out_channels, 3, padding=1)
            
    def forward(self, feats, eval=False):
        # Debug output of `for i in feats: print(i.shape)`:
        #   torch.Size([1, 256, 200, 304])
        #   torch.Size([1, 256, 100, 152])
        #   torch.Size([1, 256, 50, 76])
        #   torch.Size([1, 256, 25, 38])
        #   torch.Size([1, 256, 13, 19])

        new_feats = self.split_feats(feats)
        # Debug output of `for i in new_feats: print(i[0].shape)`:
        #   torch.Size([256, 100, 152])
        #   torch.Size([256, 100, 152])
        #   torch.Size([256, 50, 76])
        #   torch.Size([256, 25, 38])
        #   torch.Size([256, 25, 38])

        featmap_sizes = [featmap.size()[-2:] for featmap in new_feats]
        upsampled_size = (featmap_sizes[0][0] * 2, featmap_sizes[0][1] * 2)
        # print(upsampled_size)  ->  (200, 304)
        ins_pred_x, ins_pred_y, cate_pred = multi_apply(self.forward_single, new_feats,
                                                        list(range(len(self.seg_num_grids))),
                                                        eval=eval, upsampled_size=upsampled_size)
        return ins_pred_x, ins_pred_y, cate_pred

    def split_feats(self, feats):
        return (F.interpolate(feats[0], scale_factor=0.5, mode='bilinear'), 
                feats[1], 
                feats[2], 
                feats[3], 
                F.interpolate(feats[4], size=feats[3].shape[-2:], mode='bilinear'))

    def forward_single(self, x, idx, eval=False, upsampled_size=None):
        ins_feat = x
        cate_feat = x
        # ins branch
        # concat coord
        x_range = torch.linspace(-1, 1, ins_feat.shape[-1], device=ins_feat.device)
        y_range = torch.linspace(-1, 1, ins_feat.shape[-2], device=ins_feat.device)
        y, x = torch.meshgrid(y_range, x_range)
        y = y.expand([ins_feat.shape[0], 1, -1, -1])
        x = x.expand([ins_feat.shape[0], 1, -1, -1])
#        print(ins_feat.shape)
#        print(x.shape)
        ins_feat_x = torch.cat([ins_feat, x], 1)
        ins_feat_y = torch.cat([ins_feat, y], 1)
#        print(ins_feat_x.shape)  (1, 256 + 1, ?, ?)

        for ins_layer_x, ins_layer_y in zip(self.ins_convs_x, self.ins_convs_y):
            ins_feat_x = ins_layer_x(ins_feat_x)
            ins_feat_y = ins_layer_y(ins_feat_y)

        ins_feat_x = F.interpolate(ins_feat_x, scale_factor=2, mode='bilinear')
        ins_feat_y = F.interpolate(ins_feat_y, scale_factor=2, mode='bilinear')

        ins_pred_x = self.dsolo_ins_list_x[idx](ins_feat_x)
        ins_pred_y = self.dsolo_ins_list_y[idx](ins_feat_y)
#        print(ins_pred_x.shape)   one channel per grid column/row of this level: (1,256,?,?)->(1,40/36/24/16/12,?,?)

        # cate branch
        for i, cate_layer in enumerate(self.cate_convs):
            if i == self.cate_down_pos:
                seg_num_grid = self.seg_num_grids[idx]  # idx is the feature level; num_grid differs per level
                cate_feat = F.interpolate(cate_feat, size=seg_num_grid, mode='bilinear')
            cate_feat = cate_layer(cate_feat)

        cate_pred = self.dsolo_cate(cate_feat)
#        print(cate_pred.shape)    (1, 80, num_grid, num_grid)

        if eval:
            ins_pred_x = F.interpolate(ins_pred_x.sigmoid(), size=upsampled_size, mode='bilinear')
            ins_pred_y = F.interpolate(ins_pred_y.sigmoid(), size=upsampled_size, mode='bilinear')
            cate_pred = points_nms(cate_pred.sigmoid(), kernel=2).permute(0, 2, 3, 1)
        return ins_pred_x, ins_pred_y, cate_pred

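The code above produces the outputs of the solo_head forward pass: ins_pred_x, ins_pred_y, and cate_pred. That is not yet the complete inference, though: generating the final masks takes the following two functions: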
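
One step in the eval branch of forward_single that is worth isolating is points_nms, whose body is not shown in this post. The sketch below assumes it behaves as a 2x2 max-pool local-maximum filter on the category score map (padding on the top and left only, after cropping), which is my reading of the SOLO repository; toy_points_nms is a hypothetical plain-Python restatement for a single class channel, not the library function.

```python
def toy_points_nms(heat):
    """Keep a score only if it is the maximum of the 2x2 window ending at its cell.

    `heat` is a nested list (rows x cols) of scores for one class channel.
    The window of cell (i, j) covers rows i-1..i and columns j-1..j; cells
    outside the grid are ignored, like padding with minus infinity.
    """
    rows, cols = len(heat), len(heat[0])
    out = [[0.0] * cols for _ in range(rows)]
    for i in range(rows):
        for j in range(cols):
            window = [heat[a][b]
                      for a in (i - 1, i) for b in (j - 1, j)
                      if a >= 0 and b >= 0]
            if heat[i][j] == max(window):
                out[i][j] = heat[i][j]
    return out


peak = toy_points_nms([[0.9, 0.8],
                       [0.1, 0.7]])
spread = toy_points_nms([[0.2, 0.6],
                         [0.5, 0.3]])
print(peak)
print(spread)
```

Under this assumption the filter is one-sided: a cell is only compared against its upper and left neighbours, so it thins out clusters of neighbouring responses rather than acting as a full NMS.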

    def get_seg(self, seg_preds_x, seg_preds_y, cate_preds, img_metas, cfg, rescale=None):
        assert len(seg_preds_x) == len(cate_preds)
        num_levels = len(cate_preds)
#        print(num_levels)     5
        featmap_size = seg_preds_x[0].size()[-2:]
#        print(featmap_size)   [200, 304]

#        for i in range(5):
#            print(seg_preds_x[i].shape)
#            print(cate_preds[i].shape)
#       torch.Size([1, 40, 200, 304])
#		torch.Size([1, 40, 40, 80])
#		torch.Size([1, 36, 200, 304])
#		torch.Size([1, 36, 36, 80])
#		torch.Size([1, 24, 200, 304])
#		torch.Size([1, 24, 24, 80])
#		torch.Size([1, 16, 200, 304])
#		torch.Size([1, 16, 16, 80])
#		torch.Size([1, 12, 200, 304])
#		torch.Size([1, 12, 12, 80])

        result_list = []
        # this is the demo, so there is only one img
        for img_id in range(len(img_metas)):
            cate_pred_list = [
                cate_preds[i][img_id].view(-1, self.cate_out_channels).detach() for i in range(num_levels)
            ]
#            print(cate_pred_list[0].shape)  (num_grid*num_grid, 80)
            seg_pred_list_x = [
                seg_preds_x[i][img_id].detach() for i in range(num_levels)
            ]
#            print(seg_pred_list_x[0].shape)    #(num_grid, 200, 304)
            seg_pred_list_y = [
                seg_preds_y[i][img_id].detach() for i in range(num_levels)
            ]
            img_shape = img_metas[img_id]['img_shape']
            scale_factor = img_metas[img_id]['scale_factor']
            ori_shape = img_metas[img_id]['ori_shape']

            cate_pred_list = torch.cat(cate_pred_list, dim=0)    #(3872, 80) == (40^2+36^2+24^2+16^2+12^2, 80)
            seg_pred_list_x = torch.cat(seg_pred_list_x, dim=0)    #(128, 200, 304) == (40+36+24+16+12, 200, 304)
#            print(seg_pred_list_x.shape)
            seg_pred_list_y = torch.cat(seg_pred_list_y, dim=0)

            result = self.get_seg_single(cate_pred_list, seg_pred_list_x, seg_pred_list_y,
                                         featmap_size, img_shape, ori_shape, scale_factor, cfg, rescale)
            result_list.append(result)
        return result_list

    def get_seg_single(self,
                       cate_preds,
                       seg_preds_x,
                       seg_preds_y,
                       featmap_size,
                       img_shape,
                       ori_shape,
                       scale_factor,
                       cfg,
                       rescale=False, debug=False):


        # overall info.
        h, w, _ = img_shape
        upsampled_size_out = (featmap_size[0] * 4, featmap_size[1] * 4)   # padded network input size (4x the mask feature map)

        # trans trans_diff.
        trans_size = torch.Tensor(self.seg_num_grids).pow(2).cumsum(0).long()    # [1600, 2896, 3472, 3728, 3872]
        trans_diff = torch.ones(trans_size[-1].item(), device=cate_preds.device).long()
        num_grids = torch.ones(trans_size[-1].item(), device=cate_preds.device).long()
        seg_size = torch.Tensor(self.seg_num_grids).cumsum(0).long()
        seg_diff = torch.ones(trans_size[-1].item(), device=cate_preds.device).long()
        strides = torch.ones(trans_size[-1].item(), device=cate_preds.device)	# [1, 1, ..., 1]

        n_stage = len(self.seg_num_grids)
        trans_diff[:trans_size[0]] *= 0
        seg_diff[:trans_size[0]] *= 0
        num_grids[:trans_size[0]] *= self.seg_num_grids[0]
#        print(self.strides)	[8, 8, 16, 32, 32]
        strides[:trans_size[0]] *= self.strides[0]

        for ind_ in range(1, n_stage):
            trans_diff[trans_size[ind_ - 1]:trans_size[ind_]] *= trans_size[ind_ - 1]
            seg_diff[trans_size[ind_ - 1]:trans_size[ind_]] *= seg_size[ind_ - 1]
            num_grids[trans_size[ind_ - 1]:trans_size[ind_]] *= self.seg_num_grids[ind_]
            strides[trans_size[ind_ - 1]:trans_size[ind_]] *= self.strides[ind_]  # [0-1599: 8, 1600-2895: 8, 2896-3471: 16, 3472-3871: 32]

        # process.
        inds = (cate_preds > cfg.score_thr)
#        print(inds.shape)    # [3872, 80] boolean matrix
        cate_scores = cate_preds[inds]
#        print(cate_scores.shape)    # (n,): only the scores above score_thr

        inds = inds.nonzero()
#        print(inds.shape)  # (n, 2), where n is the number of scores > score_thr
        trans_diff = torch.index_select(trans_diff, dim=0, index=inds[:, 0])
        seg_diff = torch.index_select(seg_diff, dim=0, index=inds[:, 0])
        num_grids = torch.index_select(num_grids, dim=0, index=inds[:, 0])
        strides = torch.index_select(strides, dim=0, index=inds[:, 0])

        y_inds = (inds[:, 0] - trans_diff) // num_grids
        x_inds = (inds[:, 0] - trans_diff) % num_grids
        y_inds += seg_diff
        x_inds += seg_diff

        cate_labels = inds[:, 1]
#        print(cate_labels)  # length-n vector of class indices
        seg_masks_soft = seg_preds_x[x_inds, ...] * seg_preds_y[y_inds, ...]  # [n, 200, 304]
        seg_masks = seg_masks_soft > cfg.mask_thr
        sum_masks = seg_masks.sum((1, 2)).float()  # (n,)
        keep = sum_masks > strides  # further filtering: drop masks whose pixel count does not exceed the stride
#        print(keep)

        seg_masks_soft = seg_masks_soft[keep, ...]
        seg_masks = seg_masks[keep, ...]
        cate_scores = cate_scores[keep]
        sum_masks = sum_masks[keep]
        cate_labels = cate_labels[keep]
        # maskness
        seg_score = (seg_masks_soft * seg_masks.float()).sum((1, 2)) / sum_masks
        cate_scores *= seg_score

        if len(cate_scores) == 0:
            return None

        # sort and keep top nms_pre
        sort_inds = torch.argsort(cate_scores, descending=True)
        if len(sort_inds) > cfg.nms_pre:
            sort_inds = sort_inds[:cfg.nms_pre]
        seg_masks_soft = seg_masks_soft[sort_inds, :, :]
        seg_masks = seg_masks[sort_inds, :, :]
        cate_scores = cate_scores[sort_inds]
        sum_masks = sum_masks[sort_inds]
        cate_labels = cate_labels[sort_inds]
#        print(cate_scores)

        # Matrix NMS
        cate_scores = matrix_nms(seg_masks, cate_labels, cate_scores,
                                 kernel=cfg.kernel, sigma=cfg.sigma, sum_masks=sum_masks)
#        print(cate_scores)  # shape unchanged; scores of highly overlapping (high-IoU) masks are only decayed, similar to soft-NMS

        keep = cate_scores >= cfg.update_thr
        seg_masks_soft = seg_masks_soft[keep, :, :]
        cate_scores = cate_scores[keep]
#        print(cate_scores.shape)  # some entries have been filtered out
        cate_labels = cate_labels[keep]
        # sort and keep top_k
        sort_inds = torch.argsort(cate_scores, descending=True)
        if len(sort_inds) > cfg.max_per_img:  # at most 100 instances per image on COCO
            sort_inds = sort_inds[:cfg.max_per_img]
        seg_masks_soft = seg_masks_soft[sort_inds, :, :]
        cate_scores = cate_scores[sort_inds]
        cate_labels = cate_labels[sort_inds]

        # resize the masks back to the original image resolution
        seg_masks_soft = F.interpolate(seg_masks_soft.unsqueeze(0),
                                    size=upsampled_size_out,
                                    mode='bilinear')[:, :, :h, :w]
        seg_masks = F.interpolate(seg_masks_soft,
                               size=ori_shape[:2],
                               mode='bilinear').squeeze(0)
        seg_masks = seg_masks > cfg.mask_thr

        return seg_masks, cate_labels, cate_scores
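
The index bookkeeping in get_seg_single (trans_diff, seg_diff, num_grids) is easier to follow for a single index. The sketch below is a hypothetical per-index restatement in plain Python, not code from the repository; decode_index and cell_mask are illustrative names, and the grid sizes are the ones printed in the comments above.

```python
from itertools import accumulate

NUM_GRIDS = [40, 36, 24, 16, 12]


def decode_index(flat_idx, num_grids=NUM_GRIDS):
    """Map a row of the concatenated category scores to (level, x_ind, y_ind).

    x_ind / y_ind index the channel axis of the concatenated x / y mask
    predictions (128 channels in total for the grid sizes above).
    """
    cell_ends = list(accumulate(s * s for s in num_grids))  # [1600, 2896, 3472, 3728, 3872]
    chan_ends = list(accumulate(num_grids))                 # [40, 76, 100, 116, 128]
    if not 0 <= flat_idx < cell_ends[-1]:
        raise IndexError(flat_idx)
    level = next(i for i, end in enumerate(cell_ends) if flat_idx < end)
    cell_start = cell_ends[level - 1] if level else 0  # plays the role of trans_diff
    chan_start = chan_ends[level - 1] if level else 0  # plays the role of seg_diff
    local = flat_idx - cell_start
    y_ind = local // num_grids[level] + chan_start
    x_ind = local % num_grids[level] + chan_start
    return level, x_ind, y_ind


def cell_mask(seg_x, seg_y, x_ind, y_ind):
    """Soft mask of one cell: element-wise product of one x map and one y map."""
    return [[a * b for a, b in zip(row_x, row_y)]
            for row_x, row_y in zip(seg_x[x_ind], seg_y[y_ind])]


print(decode_index(0), decode_index(1599), decode_index(1600), decode_index(3871))

# Two channels of 1 x 2 toy maps, just to show the product.
toy_x = [[[0.5, 1.0]], [[1.0, 0.0]]]
toy_y = [[[0.5, 0.5]], [[1.0, 1.0]]]
print(cell_mask(toy_x, toy_y, x_ind=0, y_ind=1))
```

This is the reason the decoupled head needs only 128 mask channels here instead of 3872: each cell's mask is recovered as the product of one x channel and one y channel.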

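Finally, in the demo the filtering threshold applied after Matrix NMS is 0.05. That value is rather small, so many images with small objects end up with the full 100 instances, and then the demo applies a 0.25 threshold when displaying the results. Aren't these two somewhat at odds?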
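
For reference, the score decay applied by Matrix NMS and the way the two thresholds act one after the other can be sketched in plain Python. This is a simplified, loop-based version of the Gaussian-kernel variant as I understand it, not the repository's matrix_nms; toy_matrix_nms and mask_iou are hypothetical names, sigma=2.0 is an assumed value, and the masks are assumed to be sorted by descending score.

```python
import math


def mask_iou(a, b):
    """IoU of two binary masks given as nested lists of 0/1."""
    inter = union = 0
    for row_a, row_b in zip(a, b):
        for p, q in zip(row_a, row_b):
            inter += 1 if (p and q) else 0
            union += 1 if (p or q) else 0
    return inter / union if union else 0.0


def toy_matrix_nms(masks, labels, scores, sigma=2.0):
    """Decay the score of each mask that overlaps a higher-scoring mask of its class."""
    n = len(masks)
    iou = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1, n):
            if labels[i] == labels[j]:
                iou[i][j] = mask_iou(masks[i], masks[j])
    # compensate[i]: largest overlap of mask i with any higher-scoring mask
    compensate = [max([iou[k][i] for k in range(i)], default=0.0)
                  for i in range(n)]
    decayed = []
    for j in range(n):
        coeff = 1.0
        for i in range(j):
            ratio = (math.exp(-sigma * iou[i][j] ** 2)
                     / math.exp(-sigma * compensate[i] ** 2))
            coeff = min(coeff, ratio)
        decayed.append(scores[j] * coeff)
    return decayed


full = [[1, 1], [1, 1]]
left = [[1, 0], [1, 0]]
masks = [full, full, left]   # already sorted by descending score
labels = [0, 0, 1]
scores = [0.9, 0.8, 0.3]

decayed = toy_matrix_nms(masks, labels, scores)
after_update_thr = [s for s in decayed if s >= 0.05]          # cut inside get_seg_single
after_score_thr = [s for s in after_update_thr if s >= 0.25]  # cut used when drawing
print(len(decayed), len(after_update_thr), len(after_score_thr))
```

In this toy run the duplicate of the first mask survives the 0.05 cut with a strongly decayed score and is only dropped by the 0.25 display threshold.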

I will add a walkthrough of the training code after the Qingming holiday.
