当前位置: 首页 > 工具软件 > Faster Dom > 使用案例 >

使用faster rcnn训练自己的模型

殷学
2023-12-01

转载请注明:http://blog.csdn.net/c602273091/article/details/53945485

安装caffe

可以看我之前的博客:
安装caffe
安装faster rcnn:
faster rcnn

数据预处理

进行数据标注:
https://github.com/saicoco/object_labelImg

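这里主要用 Python 生成 VOC 格式的 xml 标注文件:下面的脚本逐行读取 training/label.idl(每行是一个 JSON 对象,键为图片文件名,值为 [xmin, ymin, xmax, ymax, 类别] 形式的框列表),并为每张有目标的图片在 Annotations/ 下生成一个 xml 文件。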

生成xml的代码:

# -*- coding:utf-8 -*-

__author__ = "Yu Chen"

import xml.dom
import xml.dom.minidom
import os
import json
import scipy
import numpy
import matplotlib
from PIL import Image

# Folder that receives one annotation XML file per image.
_ANNOTATION_SAVE_FOLDER_ = 'Annotations'

# XML layout: indentation added per nesting level, and the newline string.
_INDENT, _NEW_LINE = '\t', '\n'

# Fixed header values written into every annotation file (VOC 2007 layout).
_FOLDER_NODE = 'VOC2007'
_ROOT_NODE = 'annotation'
_DATABASE_NAME = 'INRIA'
_ANNOTATION = 'PASCAL VOC2007'
_AUTHOR = 'Yu Chen'

# Constant text for <segmented> (per image) and <difficult>/<truncated>/<pose>
# (per object).
_SEGMENTED = _DIFFICULT = _TRUNCATED = '0'
_POSE = 'Unspecified'

# Image geometry written into <size>; the same values are used for every image.
_IMAGE_WIDTH, _IMAGE_HEIGHT = 640, 360
_IMAGE_CHANNEL = 3


# 封装创建节点的过程
def createElementNode(doc, tag, attr):
    """Return a new, unattached ``<tag>attr</tag>`` element built from *doc*.

    Args:
        doc: DOM document used as the node factory.
        tag: element name.
        attr: text content, stored as the element's only child (a text node).
    """
    node = doc.createElement(tag)
    node.appendChild(doc.createTextNode(attr))
    return node


# 封装添加一个子节点的过程
def createChildNode(doc, tag, attr, parent_node):
    """Append a new ``<tag>attr</tag>`` text element to *parent_node*.

    Thin convenience wrapper around createElementNode; returns None.
    """
    parent_node.appendChild(createElementNode(doc, tag, attr))

# object节点比较特殊
def createObjectNode(doc, attrs):
    """Build one VOC ``<object>`` element (not yet attached to any parent).

    Args:
        doc: DOM document used as the node factory.
        attrs: dict providing the string values for ``'classification'``
            (written as ``<name>``) and ``'xmin'``, ``'ymin'``, ``'xmax'``,
            ``'ymax'`` (written inside ``<bndbox>``).

    Returns:
        The ``<object>`` element, with children in this order: name, pose,
        truncated, difficult, bndbox.
    """
    obj = doc.createElement('object')

    # Descriptive children: the class label plus the fixed per-object flags.
    header = (('name', attrs['classification']),
              ('pose', _POSE),
              ('truncated', _TRUNCATED),
              ('difficult', _DIFFICULT))
    for child_tag, child_text in header:
        createChildNode(doc, child_tag, child_text, obj)

    # Bounding box, coordinates copied verbatim from attrs.
    box = doc.createElement('bndbox')
    for coord in ('xmin', 'ymin', 'xmax', 'ymax'):
        createChildNode(doc, coord, attrs[coord], box)
    obj.appendChild(box)

    return obj

# 将documentElement写入XML文件中
def writeXMLFile(doc, filename):
    """Write *doc* to *filename* as indented XML without the declaration line.

    The document is serialized with ``doc.writexml`` using ``_INDENT`` per
    nesting level. Two kinds of lines are then dropped before *filename*
    (overwritten if present) is written:

    * the first line, i.e. the ``<?xml ...?>`` declaration added by default;
    * whitespace-only lines. They appear when *doc* was re-parsed from an
      already indented file, because the parsed whitespace text nodes are
      printed again on lines of their own.
    """
    # In-memory sink for writexml(), which only calls .write() on its writer.
    # This replaces a round trip through a 'tmp.xml' scratch file in the
    # working directory. A tiny class is used instead of io.StringIO because
    # under Python 2 writexml can emit byte strings, which io.StringIO rejects.
    class _Sink(object):
        def __init__(self):
            self.parts = []

        def write(self, text):
            self.parts.append(text)

    sink = _Sink()
    doc.writexml(sink, addindent=_INDENT, newl='\n', encoding='utf-8')
    lines = ''.join(sink.parts).split('\n')

    # Open the output exactly once; `with` closes it even if writing fails.
    with open(filename, 'w') as fout:
        for line in lines[1:]:
            if line.strip():
                fout.write(line + '\n')

# 创建XML文档并写入节点信息
def createXMLFile(attrs, width, height, filename):
    """Create a new PASCAL-VOC-style annotation file holding one object.

    Args:
        attrs: dict with the image file name under ``'name'`` and the object's
            ``'classification'``, ``'xmin'``, ``'ymin'``, ``'xmax'``,
            ``'ymax'`` (string values, written verbatim as text nodes).
        width: image width written into ``<size>`` (converted with ``str``).
        height: image height written into ``<size>`` (converted with ``str``).
        filename: path of the XML file to write; overwritten if present.
    """
    # Create the document object; it is the factory for every node below.
    my_dom = xml.dom.getDOMImplementation()
    doc = my_dom.createDocument(None, _ROOT_NODE, None)

    # Root node (<annotation>).
    root_node = doc.documentElement

    # <folder>
    createChildNode(doc, 'folder', _FOLDER_NODE, root_node)

    # <filename>
    createChildNode(doc, 'filename', attrs['name'], root_node)

    # <source> and its children.
    source_node = doc.createElement('source')
    createChildNode(doc, 'database', _DATABASE_NAME, source_node)
    createChildNode(doc, 'annotation', _ANNOTATION, source_node)
    createChildNode(doc, 'image', 'flickr', source_node)
    createChildNode(doc, 'flickrid', 'NULL', source_node)
    root_node.appendChild(source_node)

    # <owner> and its children.
    owner_node = doc.createElement('owner')
    createChildNode(doc, 'flickrid', 'NULL', owner_node)
    createChildNode(doc, 'name', _AUTHOR, owner_node)
    # Previously over-indented, which made the whole module an IndentationError.
    root_node.appendChild(owner_node)

    # <size>: width, height and channel count.
    size_node = doc.createElement('size')
    createChildNode(doc, 'width', str(width), size_node)
    createChildNode(doc, 'height', str(height), size_node)
    createChildNode(doc, 'depth', str(_IMAGE_CHANNEL), size_node)
    root_node.appendChild(size_node)

    # <segmented>
    createChildNode(doc, 'segmented', _SEGMENTED, root_node)

    # The first <object>.
    object_node = createObjectNode(doc, attrs)
    root_node.appendChild(object_node)

    # Serialize to disk.
    writeXMLFile(doc, filename)

if __name__ == "__main__":
    # Convert training/label.idl into one annotation XML file per image.
    # Each line of the label file is parsed as a JSON object that maps an image
    # file name to a list of boxes laid out as [xmin, ymin, xmax, ymax, class].
    with open('training/label.idl', 'r') as fid:

        # Output folder for the XML files.
        if not os.path.exists(_ANNOTATION_SAVE_FOLDER_):
            os.mkdir(_ANNOTATION_SAVE_FOLDER_)

        # XML files started by *this* run. The first box of an image creates
        # (or overwrites) its file; later boxes are appended to it. Deciding
        # this with os.path.exists() would append to files left over from an
        # earlier run and duplicate every object whenever the script is re-run.
        created_files = set()

        for line in fid:
            if not line.strip():
                continue  # json.loads() rejects blank lines
            data = json.loads(line)
            for ite_key, bboxes in data.items():
                attrs = dict()
                attrs['name'] = str(ite_key)
                # Annotation name = image name up to its first '.', plus '.xml'.
                xml_file_name = os.path.join(
                    _ANNOTATION_SAVE_FOLDER_,
                    attrs['name'].split('.')[0] + '.xml')
                print(xml_file_name)

                if not bboxes:
                    # Images without any box get no XML file.
                    print("Empty List")
                    continue

                for bbx in bboxes:
                    attrs['xmin'] = str(bbx[0])
                    attrs['ymin'] = str(bbx[1])
                    attrs['xmax'] = str(bbx[2])
                    attrs['ymax'] = str(bbx[3])
                    attrs['classification'] = str(bbx[4])

                    if xml_file_name in created_files:
                        # File already started in this run: add one <object>.
                        existed_doc = xml.dom.minidom.parse(xml_file_name)
                        root_node = existed_doc.documentElement
                        object_node = createObjectNode(existed_doc, attrs)
                        root_node.appendChild(object_node)
                        writeXMLFile(existed_doc, xml_file_name)
                    else:
                        # First box of this image: create the file from scratch.
                        # NOTE(review): every image is assumed to be
                        # _IMAGE_WIDTH x _IMAGE_HEIGHT; the real image size is
                        # never read — confirm this holds for the dataset.
                        createXMLFile(attrs, _IMAGE_WIDTH, _IMAGE_HEIGHT,
                                      xml_file_name)
                        created_files.add(xml_file_name)

生成 ImageSets/Main 下各 txt 文件(train.txt、val.txt、trainval.txt、test.txt)的代码:

# -*- coding:utf-8 -*-

import os
import random

__author__ = 'Yu Chen'

'''

设置trainval和test数据集包含的图片

'''

# Where the image-set lists go, and where the per-image XML files are read from.
_IMAGE_SETS_PATH = 'ImageSets'
_MAin_PATH = _IMAGE_SETS_PATH + '/Main'  # train/val/trainval/test .txt files
_XML_FILE_PATH = 'Annotations'           # one XML file per annotated image

# Split settings.
# Split point between train.txt and val.txt, counted over the XML files.
_TRAIN_NUMBER = 6000
# First numeric image id written to test.txt (the last one written is 72090).
_TEST_NUM = 70091

if __name__ == '__main__':
    # Write ImageSets/Main/{trainval,train,val,test}.txt for a VOC-style dataset.

    # Create ImageSets/ and ImageSets/Main/ if they do not exist yet.
    if os.path.exists(_IMAGE_SETS_PATH):
        print('ImageSets dir is already exists')
        if os.path.exists(_MAin_PATH):
            print('Main dir is already in ImageSets')
        else:
            os.mkdir(_MAin_PATH)
    else:
        os.mkdir(_IMAGE_SETS_PATH)
        os.mkdir(_MAin_PATH)

    # `with` closes all four lists even if something fails halfway.
    with open(os.path.join(_MAin_PATH, 'test.txt'), 'w') as f_test, \
            open(os.path.join(_MAin_PATH, 'trainval.txt'), 'w') as f_trainval, \
            open(os.path.join(_MAin_PATH, 'train.txt'), 'w') as f_train, \
            open(os.path.join(_MAin_PATH, 'val.txt'), 'w') as f_val:

        # trainval.txt lists every annotated image; the first _TRAIN_NUMBER of
        # them go to train.txt and the rest to val.txt. Directories and files
        # are visited in sorted order so the split is reproducible (os.walk
        # itself returns entries in arbitrary order).
        num = 0
        for _root, dirs, files in os.walk(_XML_FILE_PATH):
            dirs.sort()
            print(len(files))
            for f in sorted(files):
                if not f.lower().endswith('.xml'):
                    continue  # skip stray non-annotation files
                element = f.split('.')[0]
                f_trainval.write(element + '\n')
                # `>=` so that exactly _TRAIN_NUMBER images land in train.txt.
                if num >= _TRAIN_NUMBER:
                    f_val.write(element + '\n')
                else:
                    f_train.write(element + '\n')
                num += 1

        # NOTE(review): test.txt is filled from a fixed id range
        # (_TEST_NUM .. 72090) without checking that those images or their XML
        # files exist — confirm this matches the dataset.
        for i in range(_TEST_NUM, 72091):
            f_test.write(str(i) + '\n')

主要参考了:
1、http://blog.csdn.net/sinat_30071459/article/details/50723212
2、http://blog.csdn.net/gvfdbdf/article/details/52214008
3、https://github.com/Parlefan/create-voc2007-dataset/blob/master/create_ImageSets.py
4、https://github.com/Parlefan/create-voc2007-dataset/blob/master/create_JPEGImages.py
5、https://saicoco.github.io/object-detection-4/
6、http://www.cnblogs.com/louyihang-loves-baiyan/p/4885659.html
7、http://www.cnblogs.com/louyihang-loves-baiyan/p/4903231.html

对于训练代码的修改

主要是参考了:
http://blog.csdn.net/sinat_30071459/article/details/51332084

1、http://www.voidcn.com/blog/sinat_30071459/article/p-5957360.html
2、http://www.cnblogs.com/CarryPotMan/p/5390336.html

遇到问题

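1、错误 1:assert (boxes[:, 2] >= boxes[:, 0]).all() 断言失败(AssertionError)。
一般是因为标注框超出了图片范围(例如 xmax 大于图片的实际宽度)。py-faster-rcnn 中 boxes 以无符号整数(uint16)存储,水平翻转时 width - x2 - 1 一旦为负数,就会溢出成很大的值,于是出现 x1 > x2。最好先检查并修正标注;也可以像下面这样,将 py-faster-rcnn/lib/datasets/imdb.py 中的相应代码改成如下代码: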

def append_flipped_images(self):
        """Patched excerpt of py-faster-rcnn's imdb.append_flipped_images.

        Mirrors every image's ground-truth boxes horizontally (data
        augmentation). Only the start of the method is shown here; the
        flipped boxes are presumably stored into new roidb entries after the
        assert below in the full method — TODO confirm against
        lib/datasets/imdb.py.
        """
        num_images = self.num_images
        # Image widths are read from the actual image files via PIL.
        widths = [PIL.Image.open(self.image_path_at(i)).size[0]
                  for i in xrange(num_images)]
        for i in xrange(num_images):
            boxes = self.roidb[i]['boxes'].copy()
            oldx1 = boxes[:, 0].copy()
            oldx2 = boxes[:, 2].copy()
            # Horizontal flip: new x1 = W - old x2 - 1, new x2 = W - old x1 - 1.
            boxes[:, 0] = widths[i] - oldx2 - 1
            boxes[:, 2] = widths[i] - oldx1 - 1

            # Workaround so the assert below passes: any box whose flipped x2
            # is smaller than its flipped x1 gets x1 forced to 0.
            # NOTE(review): this hides the symptom, not the cause. If boxes has
            # an unsigned dtype (uint16 in stock py-faster-rcnn — confirm), an
            # old x2 greater than W - 1 (a box extending past the image) makes
            # W - oldx2 - 1 wrap around to a huge x1, which is the case caught
            # here. Fixing the offending annotations is the real fix.
            for b in range(len(boxes)):
                if boxes[b][2] < boxes[b][0]:
                   boxes[b][0] = 0

            assert (boxes[:, 2] >= boxes[:, 0]).all()

2、IndexError: list index out of range

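这是因为 data/cache 中的 .pkl 文件缓存的是旧数据集的 roidb,数据集改动后必须重新生成。删除 fast-rcnn-master/data/cache/(py-faster-rcnn 中为 py-faster-rcnn/data/cache/)文件夹下的 .pkl 文件,或将其改名备份,然后重新训练即可。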

3、image_num 相关的断言失败 / 除零错误(divide by 0)。
这是因为生成 xml 时把没有目标的图片也记录了进去;没有目标的图片不能记录,不要为它们生成 xml。

参考了:
1、https://github.com/rbgirshick/py-faster-rcnn/issues
2、https://github.com/rbgirshick/fast-rcnn/issues/
3、http://blog.csdn.net/marshwb/article/details/50451548
4、http://blog.csdn.net/sinat_30071459/article/details/51332084
5、http://blog.csdn.net/xzzppp/article/details/52036794

参考性很强

以下文章都使用了自己的数据集,很实用:
http://www.cnblogs.com/louyihang-loves-baiyan/p/4906690.html
http://blog.csdn.net/sinat_30071459/article/details/50723212
http://download.csdn.net/detail/sinat_30071459/9531172
http://download.csdn.net/detail/sinat_30071459/9532108
https://saicoco.github.io/object-detection-4/
http://blog.csdn.net/sinat_30071459/article/details/51332084

 类似资料: