安装pydot:
pip install pydot
安装GraphViz:
sudo apt-get install graphviz
完成以上两步安装后，运行绘图脚本时遇到如下错误：
./python/draw_net.py ./examples/cifar10/cifar10_quick_train_test.prototxt ./net.png
Drawing net to ./net.png
Traceback (most recent call last):
File "./python/draw_net.py", line 45, in <module>
main()
File "./python/draw_net.py", line 41, in main
caffe.draw.draw_net_to_file(net, args.output_image_file, args.rankdir)
File "/home/andre/caffe/python/caffe/draw.py", line 222, in draw_net_to_file
fid.write(draw_net(caffe_net, rankdir, ext))
File "/home/andre/caffe/python/caffe/draw.py", line 204, in draw_net
return get_pydot_graph(caffe_net, rankdir).create(format=ext)
File "/home/andre/caffe/python/caffe/draw.py", line 151, in get_pydot_graph
node_label = get_layer_label(layer, rankdir)
File "/home/andre/caffe/python/caffe/draw.py", line 94, in get_layer_label
layer.convolution_param.kernel_size[0] if len(layer.convolution_param.kernel_size._values) else 1,
AttributeError: 'google.protobuf.pyext._message.RepeatedScalarConta' object has no attribute '_values'
解决方法:
对应修改:
1. 修改 python/caffe/draw.py 中的 get_layer_label 函数：
def get_layer_label(layer, rankdir):
    """Define node label based on layer type.

    Parameters
    ----------
    layer : LayerParameter-like object
        Must expose ``name`` and ``type``; Convolution/Deconvolution layers
        also need ``convolution_param`` and Pooling layers ``pooling_param``.
        (Presumably a ``caffe_pb2.LayerParameter`` — confirm against callers.)
    rankdir : {'LR', 'TB', 'BT'}
        Direction of graph layout.

    Returns
    -------
    string :
        A label for the current layer, wrapped in double quotes.
    """
    if rankdir in ('TB', 'BT'):
        # If graph orientation is vertical, horizontal space is free and
        # vertical space is not; separate words with spaces
        separator = ' '
    else:
        # If graph orientation is horizontal, vertical space is free and
        # horizontal space is not; separate words with newlines
        separator = '\\n'

    if layer.type == 'Convolution' or layer.type == 'Deconvolution':
        # kernel_size / stride / pad are repeated fields on
        # convolution_param. The private `_values` attribute is not available
        # on the C++-backed RepeatedScalarContainer, so use len() + indexing,
        # which every repeated field supports.
        param = layer.convolution_param

        def first_or(values, default):
            # First entry of a repeated field, or `default` when it is empty
            # (1 for kernel/stride, 0 for pad — presumably Caffe's defaults;
            # confirm against caffe.proto).
            return values[0] if len(values) else default

        # Outer double quotes needed or else colon characters don't parse
        # properly
        node_label = '"%s%s(%s)%skernel size: %d%sstride: %d%spad: %d"' % (
            layer.name,
            separator,
            layer.type,
            separator,
            first_or(param.kernel_size, 1),
            separator,
            first_or(param.stride, 1),
            separator,
            first_or(param.pad, 0),
        )
    elif layer.type == 'Pooling':
        pooling_types_dict = get_pooling_types_dict()
        node_label = '"%s%s(%s %s)%skernel size: %d%sstride: %d%spad: %d"' % (
            layer.name,
            separator,
            pooling_types_dict[layer.pooling_param.pool],
            layer.type,
            separator,
            layer.pooling_param.kernel_size,
            separator,
            layer.pooling_param.stride,
            separator,
            layer.pooling_param.pad,
        )
    else:
        node_label = '"%s%s(%s)"' % (layer.name, separator, layer.type)
    return node_label
2. 在 python/caffe/test/ 目录下新增文件 test_draw.py：
import os
import unittest
from google import protobuf
import caffe.draw
from caffe.proto import caffe_pb2
def getFilenames(root_dir=None):
    """Yield the Net prototxt files found in the Caffe source tree.

    Recursively walks the ``models`` and ``examples`` directories under
    *root_dir*. It yields the absolute path of every ``*.prototxt`` file whose
    base name does not contain ``'solver'``. Solver files are skipped,
    presumably because they are not NetParameter messages.

    Parameters
    ----------
    root_dir : str, optional
        Root of the source tree. Defaults to three directories above this
        file, i.e. the repository root when this file lives in
        ``python/caffe/test``.

    Yields
    ------
    str
        Absolute path of each matching prototxt file.
    """
    if root_dir is None:
        root_dir = os.path.join(os.path.dirname(__file__), '..', '..', '..')
    root_dir = os.path.abspath(root_dir)
    assert os.path.exists(root_dir)

    for subdir in ('models', 'examples'):
        top = os.path.join(root_dir, subdir)
        assert os.path.exists(top)
        for cwd, _, filenames in os.walk(top):
            for filename in filenames:
                # Filter on the base name only: testing the full path would
                # drop every file when the checkout itself sits under a
                # directory whose name contains "solver".
                if filename.endswith('.prototxt') and 'solver' not in filename:
                    yield os.path.join(cwd, filename)
class TestDraw(unittest.TestCase):
    """Check that every Net prototxt in the source tree can be drawn."""

    def test_draw_net(self):
        """Parse each prototxt into a NetParameter and draw it left-to-right."""
        # Import the submodule explicitly: `from google import protobuf`
        # does not by itself load `google.protobuf.text_format`, so
        # `protobuf.text_format.Merge` may raise AttributeError.
        from google.protobuf import text_format

        for filename in getFilenames():
            net = caffe_pb2.NetParameter()
            with open(filename) as infile:
                text_format.Merge(infile.read(), net)
            caffe.draw.draw_net(net, 'LR')
参考地址：（原文此处为“点击打开链接”超链接，链接地址在转载时已丢失）