当前位置: 首页 > 工具软件 > photo4cat > 使用案例 >

CV系列学习——基于yolov4将检测出的目标裁剪出来并保存到本地

章锦
2023-12-01

1 前言

基于项目的功能需求,需要将yolov4检测出的目标逐个裁剪出来,并保存到本地。

2 使用yolov4完成目标检测

本项目使用的是yolov4的PyTorch实现版本,源码可参见https://github.com/Tianxiaomo/pytorch-YOLOv4

3 目标裁剪

基本思路为:将检测出的各目标的四个坐标(top、left、bottom、right)暂存,然后使用opencv根据坐标将各目标裁剪出来,最后保存到本地。

3.1 暂存各目标的坐标

参看源码yolo.py中的该行代码:

            boxes = yolo_correct_boxes(top_ymin,top_xmin,top_ymax,top_xmax,np.array([self.model_image_size[0],self.model_image_size[1]]),image_shape)

并将其返回以供后续使用:

return image,boxes

具体参看:

#-------------------------------------#
#       创建YOLO类
#-------------------------------------#
import colorsys
import os

import cv2
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
from PIL import Image, ImageDraw, ImageFont
from torch.autograd import Variable

from nets.yolo4 import YoloBody
from utils.utils import (DecodeBox, bbox_iou, letterbox_image,
                         non_max_suppression, yolo_correct_boxes)

from IPython import embed

#--------------------------------------------#
#   使用自己训练好的模型预测需要修改2个参数
#   model_path和classes_path都需要修改!
#   如果出现shape不匹配,一定要注意
#   训练时的model_path和classes_path参数的修改
#--------------------------------------------#
class YOLO(object):
    _defaults = {
        "model_path"        : 'model_data/yolo4_weights.pth',
        "anchors_path"      : 'model_data/yolo_anchors.txt',
        "classes_path"      : 'model_data/coco_classes.txt',
        "model_image_size"  : (416, 416, 3),#这里的model_image_size是什么,不会跟图像size产生冲突吗,为什么不可以改????
        "confidence"        : 0.3,
        "iou"               : 0.3,
        "cuda"              : False
    }

    @classmethod
    def get_defaults(cls, n):
        if n in cls._defaults:
            return cls._defaults[n]
        else:
            return "Unrecognized attribute name '" + n + "'"

    #---------------------------------------------------#
    #   初始化YOLO
    #---------------------------------------------------#
    def __init__(self, **kwargs):
        self.__dict__.update(self._defaults)
        self.class_names = self._get_class()
        self.anchors = self._get_anchors()
        self.generate()

    #---------------------------------------------------#
    #   获得所有的分类
    #---------------------------------------------------#
    def _get_class(self):
        classes_path = os.path.expanduser(self.classes_path)
        with open(classes_path) as f:
            class_names = f.readlines()
        class_names = [c.strip() for c in class_names]
        return class_names
    
    #---------------------------------------------------#
    #   获得所有的先验框
    #---------------------------------------------------#
    def _get_anchors(self):
        anchors_path = os.path.expanduser(self.anchors_path)
        with open(anchors_path) as f:
            anchors = f.readline()
        anchors = [float(x) for x in anchors.split(',')]
        return np.array(anchors).reshape([-1, 3, 2])[::-1,:,:]

    #---------------------------------------------------#
    #   生成模型
    #---------------------------------------------------#
    def generate(self):
        #---------------------------------------------------#
        #   建立yolov4模型
        #---------------------------------------------------#
        self.net = YoloBody(len(self.anchors[0]), len(self.class_names)).eval()

        #---------------------------------------------------#
        #   载入yolov4模型的权重
        #---------------------------------------------------#
        print('Loading weights into state dict...')
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        state_dict = torch.load(self.model_path, map_location=device)
        self.net.load_state_dict(state_dict)
        print('Finished!')
        
        if self.cuda:
            os.environ["CUDA_VISIBLE_DEVICES"] = '0'
            self.net = nn.DataParallel(self.net)
            self.net = self.net.cuda()

        #---------------------------------------------------#
        #   建立三个特征层解码用的工具
        #---------------------------------------------------#
        self.yolo_decodes = []# 创建数组,将三个解码器放到数组中
        for i in range(3):
            self.yolo_decodes.append(DecodeBox(self.anchors[i], len(self.class_names),  (self.model_image_size[1], self.model_image_size[0])))


        print('{} model, anchors, and classes loaded.'.format(self.model_path))
        # 画框设置不同的颜色
        hsv_tuples = [(x / len(self.class_names), 1., 1.)
                      for x in range(len(self.class_names))]
        self.colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), hsv_tuples))
        self.colors = list(
            map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)),
                self.colors))

    #---------------------------------------------------#
    #   检测图片
    #---------------------------------------------------#
    def detect_image(self, image):
        # embed()
        image_shape = np.array(np.shape(image)[0:2])
        num_class = len(self.class_names)# 有80类
        # embed()

        #---------------------------------------------------------#
        #   给图像增加灰条(什么是灰条),实现不失真的resize
        #---------------------------------------------------------#
        # 复制image return new_image
        crop_img = np.array(letterbox_image(image, (self.model_image_size[1],self.model_image_size[0])))
        photo = np.array(crop_img,dtype = np.float32) / 255.0# 归一化?
        photo = np.transpose(photo, (2, 0, 1))# 转置:将Image.open(img)得到的[H,W,C]格式转换permute为pytorch可以处理的[C,H,W]格式
        #---------------------------------------------------------#
        #   添加上batch_size维度
        #---------------------------------------------------------#
        images = [photo]# 将photo变为list类型

        with torch.no_grad():# disabled gradient calculation,reduce memory consumption for computations
            images = torch.from_numpy(np.asarray(images))# Creates a Tensor from a numpy.ndarray,此时images的shape为[1, 3, 416, 416]
            if self.cuda:
                images = images.cuda()

            #---------------------------------------------------------#
            #   将图像输入网络当中进行预测!
            #---------------------------------------------------------#
            # embed()
            # 从这里开始处理
            # 特征提取
            # 输出outputs为tuple,len=3,每个tensor的shape分别为 第一个特征层[1, 255, 13, 13],第二个特征层[1, 255, 26, 26],第三个特征层[1, 255, 52, 52]
            outputs = self.net(images)
            # embed()
            output_list = []
            for i in range(3):# 为什么是3
                # 有三个特征层,每个特征层对应自己的decode解码器
                output_list.append(self.yolo_decodes[i](outputs[i]))# 在这里打几个断点看看

            #---------------------------------------------------------#
            #   将预测框进行堆叠,然后进行非极大抑制
            #---------------------------------------------------------#
            # torch.cat()对矩阵按行进行拼接得到向量
            output = torch.cat(output_list, 1)# 这里也打几个断点
            # output就是predictions,格式为[batch_size, num_anchors, 85]
            batch_detections = non_max_suppression(output, len(self.class_names),
                                                    conf_thres=self.confidence,
                                                    nms_thres=self.iou)
            # embed()

            #---------------------------------------------------------#
            #   如果没有检测出物体,返回原图
            #---------------------------------------------------------#
            try:
                batch_detections = batch_detections[0].cpu().numpy()
            except:
                return image
            
            #---------------------------------------------------------#
            #   对预测框进行得分筛选
            #---------------------------------------------------------#
            # coordinates = []# bboxes的坐标

            top_index = batch_detections[:,4] * batch_detections[:,5] > self.confidence
            top_conf = batch_detections[top_index,4]*batch_detections[top_index,5]
            top_label = np.array(batch_detections[top_index,-1],np.int32)
            top_bboxes = np.array(batch_detections[top_index,:4])

            # 得到坐标点
            top_xmin, top_ymin, top_xmax, top_ymax = np.expand_dims(top_bboxes[:,0],-1),np.expand_dims(top_bboxes[:,1],-1),np.expand_dims(top_bboxes[:,2],-1),np.expand_dims(top_bboxes[:,3],-1)

            # coordinates.append((top_xmin,top_xmax,top_ymin,top_ymax))# 把四个坐标点看做一个整体

            #-----------------------------------------------------------------#
            #   在图像传入网络预测前会进行letterbox_image给图像周围添加灰条
            #   因此生成的top_bboxes是相对于有灰条的图像的
            #   我们需要对其进行修改,去除灰条的部分。
            #-----------------------------------------------------------------#

            # boxes存放各目标的坐标
            boxes = yolo_correct_boxes(top_ymin,top_xmin,top_ymax,top_xmax,np.array([self.model_image_size[0],self.model_image_size[1]]),image_shape)

        font = ImageFont.truetype(font='model_data/simhei.ttf',size=np.floor(3e-2 * np.shape(image)[1] + 0.5).astype('int32'))

        thickness = max((np.shape(image)[0] + np.shape(image)[1]) // self.model_image_size[0], 1)

        for i, c in enumerate(top_label):
            # embed()
            predicted_class = self.class_names[c]
            score = top_conf[i]

            top, left, bottom, right = boxes[i]
            top = top - 5
            left = left - 5
            bottom = bottom + 5
            right = right + 5

            # 左上角点的坐标
            top = max(0, np.floor(top + 0.5).astype('int32'))
            left = max(0, np.floor(left + 0.5).astype('int32'))
            # 右下角点的坐标
            bottom = min(np.shape(image)[0], np.floor(bottom + 0.5).astype('int32'))
            right = min(np.shape(image)[1], np.floor(right + 0.5).astype('int32'))

            # 画框框
            label = '{} {:.2f}'.format(predicted_class, score)
            draw = ImageDraw.Draw(image)
            label_size = draw.textsize(label, font)
            label = label.encode('utf-8')
            print(label, top, left, bottom, right)
            
            if top - label_size[1] >= 0:
                text_origin = np.array([left, top - label_size[1]])
            else:
                text_origin = np.array([left, top + 1])

            for i in range(thickness):
                draw.rectangle(
                    [left + i, top + i, right - i, bottom - i],
                    outline=self.colors[self.class_names.index(predicted_class)])
            draw.rectangle(# 画框框
                [tuple(text_origin), tuple(text_origin + label_size)],
                fill=self.colors[self.class_names.index(predicted_class)])
            draw.text(text_origin, str(label,'UTF-8'), fill=(0, 0, 0), font=font)
            del draw
        return image,boxes# 将boxes返回


3.2 使用opencv进行目标裁剪并保存各目标到本地

然后在处理目标裁剪的主程序中进行目标裁剪并保存到本地。
参考:

'''
predict.py有几个注意点
1、无法进行批量预测,如果想要批量预测,可以利用os.listdir()遍历文件夹,利用Image.open打开图片文件进行预测。
2、如果想要保存,利用r_image.save("img.jpg")即可保存。
3、如果想要获得框的坐标,可以进入detect_image函数,读取top,left,bottom,right这四个值。
4、如果想要截取下目标,可以利用获取到的top,left,bottom,right这四个值在原图上利用矩阵的方式进行截取。
'''
import cv2
from PIL import Image
import numpy as np

from yolo import YOLO

from IPython import embed

yolo = YOLO()

image = Image.open('./img/view2.jpg')# 返回PIL.img对象
uncroped_image = cv2.imread("./img/view2.jpg")

r_image,boxes = yolo.detect_image(image)

# 进行裁剪并保存到本地
box = boxes

for i in range(boxes.shape[0]):
    # top, left, bottom, right = boxes[i]
    # 或者用下面这句等价
    top = boxes[i][0]
    left = boxes[i][1]
    bottom = boxes[i][2]
    right = boxes[i][3]

    top = top - 5
    left = left - 5
    bottom = bottom + 5
    right = right + 5

    # 左上角点的坐标
    top = int(max(0, np.floor(top + 0.5).astype('int32')))

    left = int(max(0, np.floor(left + 0.5).astype('int32')))
    # 右下角点的坐标
    bottom = int(min(np.shape(image)[0], np.floor(bottom + 0.5).astype('int32')))
    right = int(min(np.shape(image)[1], np.floor(right + 0.5).astype('int32')))

    # embed()

    # 问题出在这里:不能用这个方法,看两个参数是长和宽,是从图像的原点开始裁剪的,这样肯定是不对的
    # 指定裁剪的目标范围
    croped_region = uncroped_image[top:bottom,left:right]# 先高后宽
    # 将裁剪好的目标保存到本地
    cv2.imwrite("./output/croped_view2_img_"+str(i)+".jpg",croped_region)

# embed()
r_image.show()

4 完整代码

完整的实现代码参考本项目的完整代码https://github.com/GaoZiqiang/crop_objects_based_yolov4_pytorch,主要改动在yolo.py和predict.py两个文件。

5 参考材料

https://blog.csdn.net/weixin_40119167/article/details/103685746

 类似资料: