当前位置: 首页 > 工具软件 > json-mask > 使用案例 >

批量分割mask转json

傅越
2023-12-01

代码

import os
import cv2
import sys
import PIL
import copy
import json
import yaml
import base64
import numpy as np
import skimage.io as io
from glob import glob
try:
    from labelme import __version__ as labelme_version
except Exception:
    # labelme is optional: fall back to a fixed version string when it is
    # missing or fails while importing its own dependencies. Catching
    # Exception (not a bare ``except:``) still lets KeyboardInterrupt and
    # SystemExit propagate.
    labelme_version = '4.2.9'
# NOTE(review): adds the parent directory to the module search path; no import
# visible in this file appears to depend on it - confirm before removing.
sys.path.append('..')
# Installed OpenCV version string, available to code that must branch on the
# OpenCV release (some cv2 APIs changed their return values between releases).
currentCV_version = cv2.__version__


def rm(filepath):
    """Rewrite a JSON file in place so ``group_id`` holds the JSON null.

    Every occurrence of the text ``"group_id": "null",`` (a *string* "null")
    is replaced by ``"group_id": null,`` (the JSON literal).

    :param filepath: path of the JSON file; it is read and overwritten in place.
    """
    # The context manager closes the handle even if reading/writing raises.
    with open(filepath, 'r+', encoding='utf-8') as fp:
        content = fp.read()
        # The pattern never spans a newline, so one whole-file replace is
        # equivalent to replacing line by line, without building the result
        # through repeated string concatenation.
        fixed = content.replace('"group_id": "null",', '"group_id": null,')
        fp.seek(0)
        fp.truncate()
        fp.write(fixed)


def imgEncode(img_or_path):
    """Encode an image as base64 text for labelme's ``imageData`` field.

    :param img_or_path: one of
        * ``numpy.ndarray`` - converted with PIL and encoded as PNG
          (adapted from labelme's image.py);
        * ``str`` path - the file's raw bytes are encoded unchanged;
        * an open binary stream (e.g. ``open(path, 'rb')`` or ``BytesIO``) -
          its remaining bytes are encoded; the stream is not closed here.
    :return: the base64 text as ``str`` (same type for every input kind).
    :raises TypeError: for any other input type.
    """
    # The module-level name ``io`` is ``skimage.io``, which has neither
    # BytesIO nor BufferedReader, so the standard library module is imported
    # locally under a different name.
    import io as std_io

    if isinstance(img_or_path, np.ndarray):
        # ``import PIL`` alone does not guarantee the Image submodule is loaded.
        from PIL import Image

        buffer = std_io.BytesIO()
        Image.fromarray(img_or_path).save(buffer, format='PNG')
        raw = buffer.getvalue()
    elif isinstance(img_or_path, str):
        # ``with`` closes the file we opened ourselves.
        with open(img_or_path, 'rb') as fp:
            raw = fp.read()
    elif isinstance(img_or_path, (std_io.BufferedIOBase, std_io.RawIOBase)):
        # Caller owns this stream; just read from it.
        raw = img_or_path.read()
    else:
        raise TypeError('Input type error!')
    return base64.b64encode(raw).decode('ascii')


def rs(st: str):
    """Drop every newline from *st*, then trim surrounding whitespace."""
    without_newlines = "".join(st.split('\n'))
    return without_newlines.strip()


def readYmal(filepath, labeledImg=None):
    """Return the ``(class_name, mask_value)`` pairs used to split a mask.

    Sources, in order of precedence:

    * existing ``.yaml`` / ``.yml`` file with a top-level ``label_names``
      mapping of class name -> mask pixel value; sorted by class name.
    * existing ``.txt`` file with one class name per line; line *k*
      (1-based) gets value *k*; sorted by class name.
    * ``filepath == ""`` with ``labeledImg`` given: every non-zero pixel is
      treated as foreground, connected components are labelled, and
      component *k* (k >= 1) becomes class ``"class{k-1}"`` with value *k*.
      The caller should make sure the mask is correct.

    :param filepath: yaml/txt label file, or ``""`` to derive from the image.
    :param labeledImg: mask image, used only when ``filepath`` is ``""``.
    :return: list of ``(class_name, value)`` tuples.
    :raises ValueError: ``filepath`` exists but has an unsupported extension.
    :raises FileExistsError: no label file and no image fallback (type kept
        for backward compatibility; the condition is really "not found").
    """
    if os.path.exists(filepath):
        if filepath.endswith(('.yaml', '.yml')):
            # Binary stream: PyYAML detects the encoding itself instead of
            # depending on the platform locale.
            with open(filepath, 'rb') as f:
                # safe_load only constructs plain YAML types, which is all a
                # name -> value mapping needs.
                y = yaml.safe_load(f)
            return sorted(y['label_names'].items())
        if filepath.endswith('.txt'):
            with open(filepath, 'r', encoding='utf-8') as f:
                names = [rs(line) for line in f]
            return sorted(zip(names, range(1, len(names) + 1)))
        # Previously this fell through and returned None, which only failed
        # later in the caller with an unhelpful TypeError.
        raise ValueError(
            'unsupported label file (expected .yaml/.yml/.txt): %s' % filepath)

    if labeledImg is not None and filepath == "":
        # np.array copies, so the caller's image is not modified.
        binary = np.array(labeledImg, dtype=np.uint8)
        binary[binary > 0] = 255
        _, components, _, _ = cv2.connectedComponentsWithStats(binary)
        # Component ids run 0..max; 0 is the background and gets no class.
        ids = range(1, int(np.max(components)) + 1)
        return [("class{}".format(k), value) for k, value in enumerate(ids)]

    raise FileExistsError('file not found')


def get_approx(img, contour, length_p=0.005):
    """Approximate a contour by a simpler closed polygon.

    :param img: image the contour came from; not used in the computation,
        kept so existing callers keep working.
    :param contour: contour point array as returned by ``cv2.findContours``.
    :param length_p: tolerance as a fraction of the closed contour's
        perimeter (``epsilon = length_p * perimeter``).
    :return: the polygon returned by ``cv2.approxPolyDP`` (closed).
    """
    # No copy of ``img`` is made: it was never used and cost one full image
    # allocation per contour.
    epsilon = length_p * cv2.arcLength(contour, True)
    return cv2.approxPolyDP(contour, epsilon, True)


def getBinary(img_or_path, minConnectedArea=1):
    """Binarise an image and remove tiny or degenerate connected components.

    :param img_or_path: image path (read with ``cv2.imread``) or ndarray.
    :param minConnectedArea: components with fewer pixels are removed.
        Components whose area is below 0.01 % of their bounding box area
        are removed as well.
    :return: ``(image, binary)`` - the loaded/input image and a uint8 mask
        holding 255 for kept foreground pixels and 0 elsewhere.
    :raises TypeError: input is neither a ``str`` nor an ``ndarray``.
    :raises FileNotFoundError: ``cv2.imread`` could not read the path.
    """
    if isinstance(img_or_path, str):
        i = cv2.imread(img_or_path)
        if i is None:
            # cv2.imread signals failure by returning None; fail clearly here
            # rather than with an AttributeError on ``i.shape``.
            raise FileNotFoundError('cannot read image: %s' % img_or_path)
    elif isinstance(img_or_path, np.ndarray):
        i = img_or_path
    else:
        raise TypeError('Input type error')

    img_gray = cv2.cvtColor(i, cv2.COLOR_BGR2GRAY) if i.ndim == 3 else i
    _, img_bin = cv2.threshold(img_gray, 127, 255, cv2.THRESH_BINARY)

    # labels: per-pixel component id (0 = background).
    # stats: one row per component - x, y, width, height, area.
    _, labels, stats, _ = cv2.connectedComponentsWithStats(img_bin, connectivity=4)
    areas = stats[:, cv2.CC_STAT_AREA]
    # int64 so width * height cannot overflow int32 on very large images.
    bbox_areas = (stats[:, cv2.CC_STAT_WIDTH].astype(np.int64)
                  * stats[:, cv2.CC_STAT_HEIGHT])
    keep = (areas >= minConnectedArea) & (areas >= 0.0001 * bbox_areas)
    keep[0] = False  # background component
    # One vectorised lookup instead of a full-image pass per component.
    img_bin = np.where(keep[labels], img_bin, 0).astype(np.uint8)
    return i, img_bin


def getMultiRegion(img, img_bin):
    """Extract one simplified polygon per contour of a binary mask.

    Supports several objects of the same class in one mask.

    :param img: source image, forwarded to ``get_approx``.
    :param img_bin: uint8 binary mask (0 / non-zero).
    :return: list of polygons from ``get_approx``; polygons with 3 or fewer
        vertices are discarded.
    """
    # cv2.findContours returns (image, contours, hierarchy) on older OpenCV
    # and (contours, hierarchy) on newer releases; the contours are always the
    # second-to-last item, so no version-string parsing is needed.
    contours = cv2.findContours(img_bin, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)[-2]

    regions = []
    # Every contour - including index 0 - is a real object boundary, so none
    # is skipped (skipping the first one dropped an object per class, and a
    # single-object mask produced no polygon at all).
    # NOTE(review): RETR_LIST also returns inner (hole) contours, which become
    # separate polygons with the same label.
    for contour in contours:
        region = get_approx(img, contour, 0.0001)
        if region.shape[0] > 3:
            regions.append(region)
    return regions


def process(oriImg):
    """Binarise *oriImg* and return the polygons found in it (getMultiRegion)."""
    return getMultiRegion(*getBinary(oriImg))


def _polygon_shape(label, contour):
    """Return one labelme polygon shape dict for *contour*.

    ``contour`` is indexed as ``contour[k][0] -> [x, y]``, i.e. the
    ``(N, 1, 2)`` point layout produced by cv2 contour functions.
    """
    return {
        'label': label,
        'points': [point[0].tolist() for point in contour],
        # json.dumps writes None as the JSON literal null, so neither the
        # written file nor the returned text needs a string patch afterwards.
        'group_id': None,
        'shape_type': 'polygon',
        'flags': {},
    }


def getMultiShapes(oriImgPath, labelPath, savePath='', labelYamlPath='', flag=False):
    """Convert a label mask into a labelme-style JSON annotation.

    oriImgPath : original image path; its file name becomes ``imagePath`` and
                 its contents are base64-encoded into ``imageData``.
    labelPath : mask path, or mask array, produced by fcn/unet or another
                segmentation model, or by converting labelme JSON to masks.
                An array passed in is never modified.
    savePath : directory for ``<image stem>.json``; ``''`` means the current
               working directory.
    labelYamlPath : ``.yaml``/``.txt`` label file (see readYmal). If empty,
                    classes are derived from the mask's connected components
                    and should be renamed by hand afterwards (labelme 4.2.9
                    raises an error when a label is changed).
    flag : when True, return the JSON text instead of writing a file.

    :return: the JSON string when ``flag`` is True, otherwise None.
    :raises FileNotFoundError: ``labelPath`` is a path that does not exist.
    """
    if isinstance(labelPath, str):
        if not os.path.exists(labelPath):
            raise FileNotFoundError('mask/labeled image not found')
        label_img = io.imread(labelPath)
    else:
        # Copy, so the in-place binarisation below never touches the caller's array.
        label_img = np.array(labelPath)

    if np.max(label_img) > 127:
        # Values this large are treated as a binary (0/255) mask -> 0/1.
        label_img[label_img > 127] = 255
        label_img[label_img != 255] = 0
        label_img = label_img / 255

    labelShape = label_img.shape
    labels = readYmal(labelYamlPath, label_img)

    shapes = []
    for name, value in labels:
        if value > 0:  # 0 is the background
            # astype returns a new array, so label_img itself stays intact.
            img = label_img.astype(np.uint8)
            img[img == value] = 255
            img[img != 255] = 0
            for region in process(img):
                shapes.append(_polygon_shape(name, region))

    _, imgname = os.path.split(oriImgPath)
    obj = {
        'version': labelme_version,
        'flags': {},
        'shapes': shapes,
        'imagePath': imgname,
        'imageData': str(imgEncode(oriImgPath)),
        'imageHeight': labelShape[0],
        'imageWidth': labelShape[1],
    }
    j = json.dumps(obj, sort_keys=True, indent=4)

    if flag:
        return j

    # splitext handles any extension length (slicing off 4 chars broke on
    # ".jpeg"/".tiff"), and os.path.join keeps the default savePath '' in the
    # working directory instead of producing a path at the filesystem root.
    saveJsonPath = os.path.join(savePath, os.path.splitext(imgname)[0] + '.json')
    with open(saveJsonPath, 'w') as f:
        f.write(j)


if __name__ == "__main__":
    # Dataset root: image/, mask/, label_names.yaml and the json/ output
    # folder are resolved relative to it ('' = current working directory).
    path = ''
    init_path = os.path.join(path, 'image')
    mask_path = os.path.join(path, 'mask')

    yaml_file = os.path.join(path, 'label_names.yaml')
    save_json = os.path.join(path, 'json')

    # glob() returns files in filesystem-dependent order, so sort both
    # listings before zipping; otherwise masks can be paired with the wrong
    # original images.
    mask_images_list = sorted(glob(os.path.join(mask_path, "*.png")))
    init_images_list = sorted(glob(os.path.join(init_path, "*.png")))
    if len(mask_images_list) != len(init_images_list):
        print('warning: %d masks but %d images; unmatched files are ignored'
              % (len(mask_images_list), len(init_images_list)))

    # makedirs also creates missing parent folders and accepts an existing one.
    os.makedirs(save_json, exist_ok=True)

    for mask_image, init_image in zip(mask_images_list, init_images_list):
        print(mask_image)
        getMultiShapes(init_image, mask_image, save_json, yaml_file)

label_names.yaml 格式（文件名须与代码中的 label_names.yaml 一致）

label_names:
    Tag1: 1
    Tag2: 2
    Tag3: 3
    类别: 掩码像素值
    ....

参考

https://github.com/guchengxi1994/mask2json

 类似资料: