#!/usr/bin/env python
# -*- encoding:utf8 -*-
# protoc-gen-erl
# Google's Protocol Buffers project, ported to lua.
# https://code.google.com/p/protoc-gen-lua/
#
# Copyright (c) 2010 , 林卓毅 (Zhuoyi Lin) netsnail@gmail.com
# All rights reserved.
#
# Use, modification and distribution are subject to the "New BSD License"
# as listed at <url: http://www.opensource.org/licenses/bsd-license.php >.
import sys
import os.path as path
from cStringIO import StringIO
import plugin_pb2
import google.protobuf.descriptor_pb2 as descriptor_pb2
# Module-level registries. In the code visible here only _files is used:
# code_gen_file stores each generated Lua source under its output file name
# and main() copies the entries into the plugin response.
_packages = {}
_files = {}
_message = {}

# Short alias for the field descriptor message; its TYPE_* and LABEL_*
# constants key the lookup tables and drive field generation below.
FDP = plugin_pb2.descriptor_pb2.FieldDescriptorProto

# main() reads a serialized request from stdin and writes a serialized
# response to stdout, so on Windows both streams are put into binary mode.
if sys.platform == "win32":
    import msvcrt
    import os
    for _stream in (sys.stdin, sys.stdout):
        msvcrt.setmode(_stream.fileno(), os.O_BINARY)
class CppType:
    """Integer codes for the storage category of a protobuf field.

    CPP_TYPE maps every wire-level field type onto one of these codes and
    code_gen_field writes the code into the generated ``.cpp_type`` line.
    The values run consecutively from 1 (INT32) to 10 (MESSAGE).
    """

    (CPPTYPE_INT32,      # 1
     CPPTYPE_INT64,      # 2
     CPPTYPE_UINT32,     # 3
     CPPTYPE_UINT64,     # 4
     CPPTYPE_DOUBLE,     # 5
     CPPTYPE_FLOAT,      # 6
     CPPTYPE_BOOL,       # 7
     CPPTYPE_ENUM,       # 8
     CPPTYPE_STRING,     # 9
     CPPTYPE_MESSAGE,    # 10
     ) = range(1, 11)
# Maps each FieldDescriptorProto TYPE_* constant to the CppType code that
# code_gen_field emits as ``.cpp_type``. Entries are grouped by target code.
CPP_TYPE = {
    # -> CPPTYPE_DOUBLE / CPPTYPE_FLOAT
    FDP.TYPE_DOUBLE: CppType.CPPTYPE_DOUBLE,
    FDP.TYPE_FLOAT: CppType.CPPTYPE_FLOAT,

    # -> CPPTYPE_INT32
    FDP.TYPE_INT32: CppType.CPPTYPE_INT32,
    FDP.TYPE_SINT32: CppType.CPPTYPE_INT32,
    FDP.TYPE_SFIXED32: CppType.CPPTYPE_INT32,

    # -> CPPTYPE_INT64
    FDP.TYPE_INT64: CppType.CPPTYPE_INT64,
    FDP.TYPE_SINT64: CppType.CPPTYPE_INT64,
    FDP.TYPE_SFIXED64: CppType.CPPTYPE_INT64,

    # -> CPPTYPE_UINT32
    FDP.TYPE_UINT32: CppType.CPPTYPE_UINT32,
    FDP.TYPE_FIXED32: CppType.CPPTYPE_UINT32,

    # -> CPPTYPE_UINT64
    FDP.TYPE_UINT64: CppType.CPPTYPE_UINT64,
    FDP.TYPE_FIXED64: CppType.CPPTYPE_UINT64,

    # -> CPPTYPE_BOOL / CPPTYPE_ENUM
    FDP.TYPE_BOOL: CppType.CPPTYPE_BOOL,
    FDP.TYPE_ENUM: CppType.CPPTYPE_ENUM,

    # -> CPPTYPE_STRING (bytes share the string category)
    FDP.TYPE_STRING: CppType.CPPTYPE_STRING,
    FDP.TYPE_BYTES: CppType.CPPTYPE_STRING,

    # -> CPPTYPE_MESSAGE
    FDP.TYPE_MESSAGE: CppType.CPPTYPE_MESSAGE,
}
def printerr(*args):
    """Write all arguments to stderr as one space-separated line.

    Every argument is formatted with ``%s`` before joining, so numbers,
    ``None`` and other non-string objects can be passed directly instead of
    making ``str.join`` raise TypeError. The stream is flushed afterwards
    so the message shows up immediately.
    """
    # "%s" % (arg,) keeps text arguments as they are and also copes with a
    # tuple being passed as a single argument.
    line = " ".join("%s" % (arg,) for arg in args)
    sys.stderr.write(line)
    sys.stderr.write("\n")
    sys.stderr.flush()
class TreeNode(object):
    """One named scope (package part, message or enum) in the name tree.

    Nodes are linked both ways: each node keeps its ``parent`` and the list
    of its children in ``child``. ``filename`` and ``package`` record the
    file and the split package name the node was created for.
    """

    def __init__(self, name, parent=None, filename=None, package=None):
        """Create the node and, when a parent is given, attach it to it."""
        super(TreeNode, self).__init__()
        # Set every attribute before registering with the parent so the
        # parent never holds a half-initialised child.
        self.name = name
        self.child = []
        self.parent = parent
        self.filename = filename
        self.package = package
        if parent:
            self.parent.add_child(self)

    def add_child(self, child):
        """Append *child* to this node's list of children."""
        self.child.append(child)

    def find_child(self, child_names):
        """Follow *child_names* downwards and return the node reached.

        An empty (or None) sequence returns this node itself. At each level
        only the first child with a matching name is followed.

        Raises:
            LookupError: when a name on the path has no matching child.
                On Python 2 this is a subclass of StandardError, which is
                what earlier versions raised.
        """
        node = self
        for name in child_names or ():
            found = node.get_child(name)
            if found is None:
                raise LookupError(
                    "no child named %r below %r" % (name, node.get_path()))
            node = found
        return node

    def get_child(self, child_name):
        """Return the first direct child called *child_name*, else None."""
        for node in self.child:
            if node.name == child_name:
                return node
        return None

    def get_path(self, end = None):
        """Return the dotted names from below *end* down to this node.

        The walk goes up through the parents and stops before *end*; with
        the default ``end=None`` it runs up to and including the root.
        """
        pos = self
        out = []
        while pos and pos != end:
            out.append(pos.name)
            pos = pos.parent
        out.reverse()
        return '.'.join(out)

    def get_global_name(self):
        """Return the full dotted path starting at the root node."""
        return self.get_path()

    def get_local_name(self):
        """Return the dotted path relative to the enclosing package node.

        The walk upwards stops at the first ancestor whose name equals the
        last component of ``self.package`` (or at the root when there is no
        such ancestor); that ancestor itself is not part of the result.
        """
        pos = self
        while pos.parent:
            pos = pos.parent
            if self.package and pos.name == self.package[-1]:
                break
        return self.get_path(pos)

    def __str__(self):
        return self.to_string(0)

    def __repr__(self):
        return str(self)

    def to_string(self, indent = 0):
        """Return a multi-line dump of this subtree for debugging.

        *indent* is the number of spaces put in front of this node's lines;
        children are rendered with four more.
        """
        pad = ' ' * indent
        children = ','.join(
            [node.to_string(indent + 4) for node in self.child])
        return pad + '<TreeNode ' + self.name + '(\n' + children + \
            pad + ')>\n'
class Env(object):
    """Generation state shared by every file handled in one plugin run.

    The environment owns the tree of all named scopes seen so far
    (``message_tree``), tracks the scope currently being generated
    (``scope``) and collects the generated Lua fragments of the current
    file in four lists that are reset on entering and leaving a file.
    """

    filename = None    # file name given to enter_file
    package = None     # package of the current file, split on '.'
    extend = None      # not used by the code visible in this file
    descriptor = None  # fragments that create descriptor objects
    message = None     # fragments that create messages / enum constants
    context = None     # fragments that fill in descriptor attributes
    register = None    # extension registration fragments

    def __init__(self):
        # The root of the scope tree has an empty name, so global paths
        # start with a leading dot.
        self.message_tree = TreeNode('')
        self.scope = self.message_tree

    def get_global_name(self):
        """Return the global dotted name of the current scope."""
        return self.scope.get_global_name()

    def get_local_name(self):
        """Return the package-relative dotted name of the current scope."""
        return self.scope.get_local_name()

    def get_ref_name(self, type_name):
        """Return the name under which *type_name* is referenced.

        A type registered for another file is returned as
        ``<file>_pb.<local name>``; a type of the current file as its local
        name. When the lookup fails the type is assumed to belong to the
        current file (for example because it has not been registered yet)
        and the package prefix is stripped from *type_name* textually.
        """
        try:
            node = self.lookup_name(type_name)
        except Exception:
            # Only ordinary errors are treated as "not found"; a bare
            # except would also swallow KeyboardInterrupt and SystemExit.
            pkg = '.'.join(self.package or [])
            # With an empty package the prefix is just the leading dot;
            # slicing off len(pkg) + 2 characters would eat the first
            # letter of the type name in that case.
            prefix = ('.' + pkg + '.') if pkg else '.'
            if type_name.startswith(prefix):
                return type_name[len(prefix):]
            # NOTE(review): a name outside the current package is not
            # expected on this path; only its leading dot is removed.
            return type_name.lstrip('.')
        if node.filename != self.filename:
            return node.filename + '_pb.' + node.get_local_name()
        return node.get_local_name()

    def lookup_name(self, name):
        """Resolve a dotted *name* to its tree node.

        A name with a leading dot is resolved from the root of the tree,
        any other name from the parent of the current scope. Lookup
        failures propagate from TreeNode.find_child.
        """
        names = name.split('.')
        if names[0] == '':
            return self.message_tree.find_child(names[1:])
        return self.scope.parent.find_child(names)

    def enter_package(self, package):
        """Return the node for dotted *package*, creating missing parts.

        An empty package maps to the root of the tree.
        """
        if not package:
            return self.message_tree

        names = package.split('.')
        pos = self.message_tree
        for i, name in enumerate(names):
            new_pos = pos.get_child(name)
            if new_pos:
                pos = new_pos
            else:
                # Nothing below this point exists yet: build the rest.
                return self._build_nodes(pos, names[i:])
        return pos

    def enter_file(self, filename, package):
        """Start a new file: reset the fragment lists and enter its package."""
        self.filename = filename
        self.package = package.split('.')
        self._init_field()
        self.scope = self.enter_package(package)

    def exit_file(self):
        """Finish the current file and drop its per-file state."""
        self._init_field()
        self.filename = None
        self.package = []
        self.scope = self.scope.parent

    def enter(self, message_name):
        """Open a child scope called *message_name* and make it current."""
        self.scope = TreeNode(message_name, self.scope, self.filename,
                              self.package)

    def exit(self):
        """Leave the current scope and return to its parent."""
        self.scope = self.scope.parent

    def _init_field(self):
        """Reset the four per-file fragment lists to empty lists."""
        self.descriptor = []
        self.context = []
        self.message = []
        self.register = []

    def _build_nodes(self, node, names):
        """Create a chain of nodes for *names* below *node*; return the last."""
        parent = node
        for i in names:
            parent = TreeNode(i, parent, self.filename, self.package)
        return parent
class Writer(object):
    """Callable text buffer with an optional prefix and nested indentation.

    Calling the writer appends ``indent + prefix + data`` to the buffer.
    Used as a context manager it indents everything written inside the
    ``with`` block by one more level.
    """

    # One indentation level. __enter__ adds exactly this string and
    # __exit__ removes exactly its length, so nested blocks stay balanced.
    _INDENT_STEP = '    '

    def __init__(self, prefix=None):
        """Create an empty buffer; *prefix* is put before every chunk."""
        self.io = StringIO()
        self.__indent = ''
        self.__prefix = prefix

    def getvalue(self):
        """Return everything written so far as one string."""
        return self.io.getvalue()

    def __enter__(self):
        self.__indent += self._INDENT_STEP
        return self

    def __exit__(self, type, value, trackback):
        # Returns None, so an exception raised inside the block propagates.
        self.__indent = self.__indent[:-len(self._INDENT_STEP)]

    def __call__(self, data):
        """Append *data*, preceded by the current indent and the prefix."""
        self.io.write(self.__indent)
        if self.__prefix:
            self.io.write(self.__prefix)
        self.io.write(data)
# Lua literal that code_gen_field writes as ``.default_value`` for a
# non-repeated field that declares neither a default nor a type name.
# Entries are grouped by the literal they produce.
DEFAULT_VALUE = {
    # floating point -> 0.0
    FDP.TYPE_DOUBLE: '0.0',
    FDP.TYPE_FLOAT: '0.0',

    # integers -> 0
    FDP.TYPE_INT32: '0',
    FDP.TYPE_INT64: '0',
    FDP.TYPE_UINT32: '0',
    FDP.TYPE_UINT64: '0',
    FDP.TYPE_SINT32: '0',
    FDP.TYPE_SINT64: '0',
    FDP.TYPE_FIXED32: '0',
    FDP.TYPE_FIXED64: '0',
    FDP.TYPE_SFIXED32: '0',
    FDP.TYPE_SFIXED64: '0',

    # text and raw bytes -> empty Lua string
    FDP.TYPE_STRING: '""',
    FDP.TYPE_BYTES: '""',

    # remaining types
    FDP.TYPE_BOOL: 'false',
    FDP.TYPE_ENUM: '1',
    FDP.TYPE_MESSAGE: 'nil',
}
def code_gen_enum_item(index, enum_value, env):
    """Emit the Lua descriptor for one enum value and return its name.

    Args:
        index: position of the value inside its enum.
        enum_value: descriptor of the value; ``name`` and ``number`` are read.
        env: generation environment; one line is appended to
            ``env.descriptor`` and one fragment to ``env.context``.

    Returns:
        The Lua variable name of the value descriptor: the upper-cased
        local path with dots turned into underscores plus ``_ENUM``.
    """
    value_name = enum_value.name
    dotted = '.'.join((env.get_local_name(), value_name))
    lua_var = dotted.upper().replace('.', '_') + '_ENUM'

    declaration = "local %s = protobuf.EnumValueDescriptor();\n" % lua_var
    env.descriptor.append(declaration)

    # Every attribute line is prefixed with the variable name by the Writer.
    out = Writer(lua_var)
    out('.name = "%s"\n' % value_name)
    out('.index = %d\n' % index)
    out('.number = %d\n' % enum_value.number)
    env.context.append(out.getvalue())

    return lua_var
def code_gen_enum(enum_desc, env):
    """Emit the Lua descriptor for an enum and all of its values.

    The enum is entered as a scope for the duration of the call, so the
    values are named relative to it.

    Args:
        enum_desc: descriptor of the enum; ``name`` and ``value`` are read.
        env: generation environment receiving the generated fragments.

    Returns:
        The Lua variable name of the enum descriptor.
    """
    env.enter(enum_desc.name)

    lua_var = env.get_local_name().upper().replace('.', '_')
    env.descriptor.append(
        "local %s = protobuf.EnumDescriptor();\n" % lua_var)

    out = Writer(lua_var)
    out('.name = "%s"\n' % enum_desc.name)
    out('.full_name = "%s"\n' % env.get_global_name())

    # The value descriptors are generated (and queued in env) before the
    # enum's own attribute fragment is queued below.
    value_vars = [code_gen_enum_item(position, value, env)
                  for position, value in enumerate(enum_desc.value)]
    out('.values = {%s}\n' % ','.join(value_vars))

    env.context.append(out.getvalue())
    env.exit()
    return lua_var
def _escape_lua_string(text):
    """Escape *text* so it can stand between double quotes in Lua source."""
    # The backslash must be handled first so the escapes added afterwards
    # are not escaped a second time.
    return (text.replace('\\', '\\\\')
                .replace('"', '\\"')
                .replace('\n', '\\n')
                .replace('\r', '\\r'))


def code_gen_field(index, field_desc, env):
    """Emit the Lua descriptor for one field or extension; return its name.

    Args:
        index: position of the field inside its message.
        field_desc: field descriptor; name, number, label, type and the
            optional default_value, type_name and extendee are read.
        env: generation environment. A declaration is appended to
            ``env.descriptor``, the attribute fragment to ``env.context``
            and, for extensions, a registration line to ``env.register``.

    Returns:
        The Lua variable name of the field descriptor: the upper-cased
        local path with dots turned into underscores plus ``_FIELD``.
    """
    full_name = env.get_local_name() + '.' + field_desc.name
    obj_name = full_name.upper().replace('.', '_') + '_FIELD'
    env.descriptor.append(
        "local %s = protobuf.FieldDescriptor();\n" % obj_name
    )

    context = Writer(obj_name)
    context('.name = "%s"\n' % field_desc.name)
    context('.full_name = "%s"\n' % (
        env.get_global_name() + '.' + field_desc.name))
    context('.number = %d\n' % field_desc.number)
    context('.index = %d\n' % index)
    context('.label = %d\n' % field_desc.label)

    if field_desc.HasField("default_value"):
        context('.has_default_value = true\n')
        value = field_desc.default_value
        if field_desc.type == FDP.TYPE_STRING:
            # The descriptor stores string defaults as raw text, so quotes,
            # backslashes and line breaks must be escaped to keep the
            # generated Lua literal valid and equal to the original text.
            context('.default_value = "%s"\n' % _escape_lua_string(value))
        else:
            # Every other type, bytes included, is written verbatim.
            # NOTE(review): confirm that unquoted output is intended for
            # bytes defaults.
            context('.default_value = %s\n' % value)
    else:
        context('.has_default_value = false\n')
        if field_desc.label == FDP.LABEL_REPEATED:
            default_value = "{}"
        elif field_desc.HasField('type_name'):
            # Message and enum typed fields without a default start as nil.
            default_value = "nil"
        else:
            default_value = DEFAULT_VALUE[field_desc.type]
        context('.default_value = %s\n' % default_value)

    if field_desc.HasField('type_name'):
        type_name = env.get_ref_name(field_desc.type_name).upper()
        if field_desc.type == FDP.TYPE_MESSAGE:
            context('.message_type = %s\n' % type_name)
        else:
            context('.enum_type = %s\n' % type_name)

    if field_desc.HasField('extendee'):
        # Extensions are registered on the message they extend.
        type_name = env.get_ref_name(field_desc.extendee)
        env.register.append(
            "%s.RegisterExtension(%s)\n" % (type_name, obj_name)
        )

    context('.type = %d\n' % field_desc.type)
    context('.cpp_type = %d\n\n' % CPP_TYPE[field_desc.type])

    env.context.append(context.getvalue())
    return obj_name
def code_gen_message(message_descriptor, env, containing_type=None):
    """Emit the Lua descriptor and message class for one message.

    Nested messages, nested enums, fields and extensions are generated in
    that order while the message is the current scope.

    Args:
        message_descriptor: descriptor of the message to generate.
        env: generation environment receiving the generated fragments.
        containing_type: Lua variable name of the enclosing message's
            descriptor when this message is nested, otherwise None.

    Returns:
        The Lua variable name of the message descriptor.
    """
    env.enter(message_descriptor.name)

    local_name = env.get_local_name()
    lua_var = local_name.upper().replace('.', '_')
    env.descriptor.append("%s = protobuf.Descriptor();\n" % lua_var)

    out = Writer(lua_var)
    out('.name = "%s"\n' % message_descriptor.name)
    out('.full_name = "%s"\n' % env.get_global_name())

    # Nested messages get this message's descriptor as containing type.
    nested_vars = [code_gen_message(nested, env, lua_var)
                   for nested in message_descriptor.nested_type]
    out('.nested_types = {%s}\n' % ', '.join(nested_vars))

    enum_vars = [code_gen_enum(enum_desc, env)
                 for enum_desc in message_descriptor.enum_type]
    out('.enum_types = {%s}\n' % ', '.join(enum_vars))

    field_vars = [code_gen_field(position, field, env)
                  for position, field in enumerate(message_descriptor.field)]
    out('.fields = {%s}\n' % ', '.join(field_vars))

    # A message is extendable when it declares at least one extension range.
    if len(message_descriptor.extension_range) == 0:
        out('.is_extendable = false\n')
    else:
        out('.is_extendable = true\n')

    extension_vars = [
        code_gen_field(position, field, env)
        for position, field in enumerate(message_descriptor.extension)]
    out('.extensions = {%s}\n' % ', '.join(extension_vars))

    if containing_type:
        out('.containing_type = %s\n' % containing_type)

    env.message.append(
        '%s = protobuf.Message(%s)\n' % (local_name, lua_var))
    env.context.append(out.getvalue())

    env.exit()
    return lua_var
def write_header(writer):
    """Pass the "generated file" banner comment line to *writer*.

    *writer* is any callable taking one string; it is called exactly once.
    """
    banner = '-- Generated By protoc-gen-lua Do not Edit\n'
    writer(banner)
def code_gen_file(proto_file, env, is_gen):
    """Process one .proto file and optionally store its generated Lua.

    Every file is walked so that its types are registered in *env*; the Lua
    text is assembled and stored in ``_files`` under ``<name>_pb.lua`` only
    when *is_gen* is true.

    Args:
        proto_file: file descriptor; name, package, dependency, enum_type
            and message_type are read.
        env: generation environment shared between files.
        is_gen: whether output should be produced for this file.
    """
    filename = path.splitext(proto_file.name)[0]
    env.enter_file(filename, proto_file.package)

    # Dependencies are required by their name without the extension.
    includes = [path.splitext(dep)[0] for dep in proto_file.dependency]

    # NOTE: file-level extensions (proto_file.extension) are not generated.

    for enum_desc in proto_file.enum_type:
        code_gen_enum(enum_desc, env)
        # Values of top-level enums are additionally emitted as plain
        # ``NAME = number`` assignments.
        for enum_value in enum_desc.value:
            env.message.append('%s = %d\n' % (enum_value.name,
                                              enum_value.number))

    for msg_desc in proto_file.message_type:
        code_gen_message(msg_desc, env)

    if is_gen:
        lua = Writer()
        write_header(lua)
        lua('local protobuf = require "protobuf"\n')
        for i in includes:
            lua('local %s_PB = require("%s_pb")\n' % (i.upper(), i))
        lua("module('%s_pb')\n" % env.filename)
        lua('\n\n')

        # Explicit loops instead of map(): the calls are made only for
        # their side effect, and a lazy map() would silently write nothing.
        for line in env.descriptor:
            lua(line)
        lua('\n')
        for line in env.context:
            lua(line)
        lua('\n')
        env.message.sort()
        for line in env.message:
            lua(line)
        lua('\n')
        for line in env.register:
            lua(line)

        _files[env.filename + '_pb.lua'] = lua.getvalue()
    env.exit_file()
def main():
    """Run the plugin: read a request from stdin, write the response.

    The whole of stdin is parsed as a CodeGeneratorRequest. Every file in
    the request is processed, but Lua output is produced only for the files
    listed in ``file_to_generate``. All entries collected in ``_files`` are
    then returned in one serialized CodeGeneratorResponse on stdout.
    """
    raw_request = sys.stdin.read()
    request = plugin_pb2.CodeGeneratorRequest()
    request.ParseFromString(raw_request)

    env = Env()
    for proto_file in request.proto_file:
        wanted = proto_file.name in request.file_to_generate
        code_gen_file(proto_file, env, wanted)

    response = plugin_pb2.CodeGeneratorResponse()
    for out_name, lua_source in _files.items():
        entry = response.file.add()
        entry.name = out_name
        entry.content = lua_source

    sys.stdout.write(response.SerializeToString())
# Run the plugin only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
将 protoc-gen-lua 文件的内容修改为以上代码即可。
在 Lua 中使用时，复合字段已经有值，可以直接访问：
local person = person_pb.Person()
person.company.name = "xxx" -- 其中 company 为复合字段，即另一种消息类型，例如 company_pb.Company()
直接使用即可。