draw_net.py绘制网络结构图遇到的问题

寿阳华
2023-12-01

安装pydot:

pip install pydot

安装GraphViz:

sudo apt-get install graphviz


完成以上两步安装后,运行绘制脚本时遇到如下错误:

./python/draw_net.py ./examples/cifar10/cifar10_quick_train_test.prototxt ./net.png
Drawing net to ./net.png
Traceback (most recent call last):
  File "./python/draw_net.py", line 45, in <module>
    main()
  File "./python/draw_net.py", line 41, in main
    caffe.draw.draw_net_to_file(net, args.output_image_file, args.rankdir)
  File "/home/andre/caffe/python/caffe/draw.py", line 222, in draw_net_to_file
    fid.write(draw_net(caffe_net, rankdir, ext))
  File "/home/andre/caffe/python/caffe/draw.py", line 204, in draw_net
    return get_pydot_graph(caffe_net, rankdir).create(format=ext)
  File "/home/andre/caffe/python/caffe/draw.py", line 151, in get_pydot_graph
    node_label = get_layer_label(layer, rankdir)
  File "/home/andre/caffe/python/caffe/draw.py", line 94, in get_layer_label
    layer.convolution_param.kernel_size[0] if len(layer.convolution_param.kernel_size._values) else 1,
AttributeError: 'google.protobuf.pyext._message.RepeatedScalarConta' object has no attribute '_values'

解决方法:

原因:当前 protobuf 版本中 repeated 字段的容器(RepeatedScalarContainer)没有 `_values` 属性,应直接用 len() 判断该字段是否为空。对应修改如下:

1.python/caffe/draw.py

def get_layer_label(layer, rankdir):
    """Define node label based on layer type.

    Parameters
    ----------
    layer : object
        Layer description exposing ``name`` and ``type``, plus
        ``convolution_param`` (for Convolution/Deconvolution layers) or
        ``pooling_param`` (for Pooling layers). Presumably a
        ``caffe_pb2.LayerParameter`` -- confirm against the caller.
    rankdir : {'LR', 'TB', 'BT'}
        Direction of graph layout.

    Returns
    -------
    string :
        A label for the current layer
    """

    if rankdir in ('TB', 'BT'):
        # If graph orientation is vertical, horizontal space is free and
        # vertical space is not; separate words with spaces
        separator = ' '
    else:
        # If graph orientation is horizontal, vertical space is free and
        # horizontal space is not; separate words with newlines
        separator = '\\n'

    if layer.type == 'Convolution' or layer.type == 'Deconvolution':
        param = layer.convolution_param
        # kernel_size, stride and pad are repeated fields on the convolution
        # parameters: take the first entry, or fall back to a default when
        # the field is empty. len() is used instead of the container's
        # private ``_values`` attribute, which not every protobuf
        # implementation provides.
        kernel_size = param.kernel_size[0] if len(param.kernel_size) else 1
        stride = param.stride[0] if len(param.stride) else 1
        pad = param.pad[0] if len(param.pad) else 0
        # Outer double quotes needed or else colon characters don't parse
        # properly
        node_label = '"%s%s(%s)%skernel size: %d%sstride: %d%spad: %d"' %\
                     (layer.name,
                      separator,
                      layer.type,
                      separator,
                      kernel_size,
                      separator,
                      stride,
                      separator,
                      pad)
    elif layer.type == 'Pooling':
        pooling_types_dict = get_pooling_types_dict()
        node_label = '"%s%s(%s %s)%skernel size: %d%sstride: %d%spad: %d"' %\
                     (layer.name,
                      separator,
                      pooling_types_dict[layer.pooling_param.pool],
                      layer.type,
                      separator,
                      layer.pooling_param.kernel_size,
                      separator,
                      layer.pooling_param.stride,
                      separator,
                      layer.pooling_param.pad)
    else:
        node_label = '"%s%s(%s)"' % (layer.name, separator, layer.type)
    return node_label


2. 在 python/caffe/test/ 目录下新增文件 test_draw.py

import os
import unittest

from google import protobuf

import caffe.draw
from caffe.proto import caffe_pb2

def getFilenames():
    """Yield paths of Net prototxt files found in the source tree.

    The source-tree root is taken to be three directories above this file.
    Its ``models`` and ``examples`` directories are walked recursively, and
    every path that ends in ``.prototxt`` and does not contain the substring
    ``solver`` is yielded. Note that the ``solver`` check applies to the
    whole path, not just the file name.

    Yields
    ------
    string :
        Absolute path of a matching prototxt file.
    """
    root_dir = os.path.abspath(os.path.join(
        os.path.dirname(__file__), '..', '..', '..'))
    assert os.path.exists(root_dir)

    for dirname in ('models', 'examples'):
        dirname = os.path.join(root_dir, dirname)
        assert os.path.exists(dirname)
        for cwd, _, filenames in os.walk(dirname):
            for filename in filenames:
                # cwd is already rooted at the absolute dirname, so this is
                # the complete path; no further joining is required.
                filename = os.path.join(cwd, filename)
                if filename.endswith('.prototxt') and 'solver' not in filename:
                    yield filename


class TestDraw(unittest.TestCase):
    """Smoke test for caffe.draw over the prototxt files of the source tree."""

    def test_draw_net(self):
        """Parse every file from getFilenames() and draw it with rankdir 'LR'.

        The test passes when no file raises during parsing or drawing; the
        drawing result itself is not inspected.
        """
        for prototxt_path in getFilenames():
            with open(prototxt_path) as handle:
                prototxt_text = handle.read()
            net_param = caffe_pb2.NetParameter()
            protobuf.text_format.Merge(prototxt_text, net_param)
            caffe.draw.draw_net(net_param, 'LR')


参考地址:点击打开链接

 类似资料: