修改 YOLOv4 网络结构代码中 keras 的导入语句:TensorFlow 2.2.0 已自带 keras(即 `tf.keras`),因此统一改为从 `tensorflow.keras` 导入。
# 将所有形如下面的导入语句(`keras.*` 表示任意子模块,例如 keras.layers、keras.models):
from keras import *
from keras.* import *
# 修改为:
from tensorflow.keras import *
from tensorflow.keras.* import *
加载网络结构和权重,并导出为 SavedModel 与 TFLite 模型
import os

import tensorflow as tf
from tensorflow.keras.models import Model
class YOLO(object):
    """YOLOv4 wrapper that builds the network, loads ``.h5`` weights and
    exports the model (with box decoding appended) to SavedModel and TFLite.

    Configuration is taken from ``_defaults`` and may be overridden with
    keyword arguments passed to the constructor.
    """

    # Default configuration. The concrete entries were elided in this excerpt.
    # NOTE(review): as written, ``{...}`` is a *set* containing Ellipsis, not a
    # dict; it presumably must be replaced with the real mapping (model_path,
    # classes_path, anchors_path, ...) before the class is usable.
    _defaults = {
        ...
    }

    @classmethod
    def get_defaults(cls, n):
        """Return the default value registered for attribute name ``n``.

        Unknown names do not raise; an error-message string is returned
        instead (kept as-is for compatibility with existing callers).
        """
        if n in cls._defaults:
            return cls._defaults[n]
        else:
            return "Unrecognized attribute name '" + n + "'"

    #---------------------------------------------------#
    #   Initialise YOLO
    #---------------------------------------------------#
    def __init__(self, **kwargs):
        """Load defaults, apply ``kwargs`` overrides, read classes/anchors,
        then build and export the model via ``generate()``."""
        self.__dict__.update(self._defaults)
        for name, value in kwargs.items():
            setattr(self, name, value)
            # NOTE(review): this writes into the class-level ``_defaults``, so
            # an override given to one instance becomes the default for every
            # later instance and for ``get_defaults``. Kept for compatibility;
            # confirm whether that sharing is intended.
            self._defaults[name] = value
        self.class_names, self.num_classes = get_classes(self.classes_path)
        self.anchors, self.num_anchors = get_anchors(self.anchors_path)
        self.generate()

    @staticmethod
    def _resolve_model_path(path):
        """Expand a leading ``~`` in ``path`` and check that it names a ``.h5`` file.

        Returns the expanded path; raises AssertionError otherwise.
        """
        model_path = os.path.expanduser(path)
        assert model_path.endswith('.h5'), 'Keras model or weights must be a .h5 file.'
        return model_path

    #---------------------------------------------------#
    #   Load the model and export it
    #---------------------------------------------------#
    def generate(self, saved_model_dir='yolo_tflite/yolov4',
                 tflite_path='yolo_tflite/yolo_fp32.tflite'):
        """Build the network, load its weights and export it.

        The decoded model is saved as a SavedModel in ``saved_model_dir``.
        That SavedModel is then converted to a float32 TFLite file at
        ``tflite_path``.

        Args:
            saved_model_dir: directory the SavedModel is written to and then
                converted from (defaults to the previously hard-coded path).
            tflite_path: output file for the converted TFLite model (defaults
                to the previously hard-coded path).
        """
        model_path = self._resolve_model_path(self.model_path)
        # Remaining yolo_body arguments were elided in this excerpt.
        self.model = yolo_body([640, 640, 3], ...)
        # Load from the expanded, validated path. Using the raw
        # self.model_path would let '~/...' pass the check but fail here.
        self.model.load_weights(model_path)
        # Append DecodeBox as a Lambda layer so the exported graph includes
        # box decoding; self.model.output is the three feature-layer outputs.
        outputs = Lambda(
            DecodeBox,
            output_shape=(1,),
            name='yolo_eval',
            arguments={...},  # decode arguments elided in this excerpt
        )(self.model.output)
        self.yolo_model = Model(self.model.input, outputs)

        tf.saved_model.save(self.yolo_model, saved_model_dir)
        converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
        # Allow ops without a TFLite builtin kernel to fall back to TF ops.
        converter.target_spec.supported_ops = [
            tf.lite.OpsSet.TFLITE_BUILTINS,
            tf.lite.OpsSet.SELECT_TF_OPS,
        ]
        converter.allow_custom_ops = True
        tflite_model = converter.convert()

        # Make sure the destination directory exists when a custom path is given.
        out_dir = os.path.dirname(tflite_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        # The context manager guarantees the handle is flushed and closed.
        with open(tflite_path, 'wb') as f:
            f.write(tflite_model)