原始仓库使用 ResNet 系列作为骨干网络，网络输入分辨率为 288×800，在板卡上只能跑到 7–8 FPS，速度偏低。因此我对原始仓库做了以下改动：
1.配置文件
# Training / evaluation config. The scripts read cfg['dataset'] and
# cfg['network'] as nested mappings, so the section keys below must keep
# their indentation.
# NOTE(review): nesting of `dataset` and `network` is confirmed by those
# lookups; placement of the remaining keys (e.g. sim_loss_w / shp_loss_w under
# `network`, EXP / FINETUNE keys at top level) follows the section headers of
# the flattened original — confirm against the training script.
dataset:
  name: CULane
  data_root: '/opt/sda5/BL01_Data/Lane_Data/CULane'
  num_lanes: 4
  w: 512
  h: 256
  input_size: [256, 512]  # [h, w]; keep in sync with w / h above
  batch_size: 128
  griding_num: 200
  use_aux: False
  # row_anchor depends on the input height h: regenerate it whenever h
  # changes. Entries are listed in ascending order.
  # row_anchor: [121, 131, 141, 150, 160, 170, 180, 189, 199, 209, 219, 228, 238, 248, 258, 267, 277, 287] # h = 288
  row_anchor: [108, 116, 125, 134, 142, 151, 160, 168, 177, 186, 194, 203, 212, 220, 229, 238, 246, 255] # h = 256
  num_per_lane: 18  # presumably must equal len(row_anchor) (18 here) — confirm
  num_workers: 8
train:
  epoch: 100
  optimizer: 'SGD' # ['SGD', 'Adam']
  learning_rate: 0.1
  weight_decay: 1.0e-4
  momentum: 0.9
  scheduler: 'multi' # ['multi', 'cos']
  steps: [25, 38]
  gamma: 0.1
  warmup: 'linear'
  warmup_iters: 2
network:
  backbone: 'mobilenetv2'
  pretrained: NULL
  out_channel: [32, 96, 320] # alternative from the original comment (presumably for resnet): [128, 256, 1024]
  sim_loss_w: 0.0
  shp_loss_w: 0.0
test:
  test_model: 'weights/20210322_094501_lr_0.1/ep016.pth'
  test_work_dir: 'weights'
  val_intervals: 1
# EXP
note: ''
log_path: 'runs'
view: True
# FINETUNE or RESUME MODEL PATH
finetune: NULL
resume: NULL
所有参数统一放在一个 YAML 配置文件中（原始仓库的参数分散在两个 .py 文件中）。
2.骨干网络和分辨率
import copy
from .mobilenetv2 import MobileNetV2
from .resnet import resnet
def build_backbone(name):
    """Instantiate the backbone network selected by ``name``.

    Supported names:
        'mobilenetv2'                     -> MobileNetV2()
        'resnet_<depth>' / 'resnet<depth>' -> resnet('<depth>'), e.g. 'resnet_18'

    The depth suffix is passed to ``resnet`` as a string.

    Raises:
        ValueError: a resnet name without a depth suffix (e.g. plain 'resnet').
        NotImplementedError: any other name.
    """
    if name == 'mobilenetv2':
        return MobileNetV2()
    # Match on the prefix: an exact `name == 'resnet'` test would only accept
    # the one name that has no depth suffix to extract.
    if isinstance(name, str) and name.startswith('resnet'):
        layer = name[len('resnet'):].lstrip('_')
        if not layer:
            raise ValueError(
                "resnet backbone needs a depth, e.g. 'resnet_18', got %r" % name)
        return resnet(layer)
    raise NotImplementedError('unsupported backbone: %r' % (name,))
import math
def gen_row_anchors(img_h, num_anchors=18, orig_h=590, step=20):
    """Return CULane-style row anchors rescaled to an input height of ``img_h``.

    Rows are sampled on the original ``orig_h``-pixel-tall image every ``step``
    pixels, starting from its last row (``orig_h - 1``) and moving up. Each row
    is then scaled to ``img_h`` and floored. The result is sorted ascending, so
    it can be pasted directly into the config's ``row_anchor`` list.

    >>> gen_row_anchors(288)[:3]
    [121, 131, 141]

    Raises:
        ValueError: if ``(num_anchors - 1) * step`` reaches above the top of the
            original image.
    """
    rows = [orig_h - 1 - i * step for i in range(num_anchors)]
    if rows and rows[-1] < 0:
        raise ValueError('num_anchors * step exceeds the original image height')
    # Integer floor division equals floor(img_h / orig_h * row) for
    # non-negative ints, without float rounding error.
    return sorted(img_h * row // orig_h for row in rows)


if __name__ == '__main__':
    lane_num = 18  # anchors per lane; matches num_per_lane in the config
    print(gen_row_anchors(256, num_anchors=lane_num))
通过配置文件中 network 下的 backbone 字段即可直接选取不同的骨干网络，输入分辨率则通过 dataset 下的 w、h、input_size 修改。注意：改变分辨率时，需要先用上面的脚本生成对应分辨率的 row anchor，并用它替换配置文件中的 row_anchor（按从小到大的顺序排列）。
3. 在训练脚本中添加测试（test）功能，用于选取最佳模型
def _f_measure(tp, fp, fn):
    """Return F1 from summed true-positive, false-positive and false-negative counts.

    Returns 0.0 when ``tp`` is 0: precision and recall are then 0 (or
    undefined), and the plain formula would raise ZeroDivisionError.
    """
    if tp <= 0:
        return 0.0
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)


def test(net, data_loader, dataset, work_dir, logger, use_aux=True):
    """Evaluate ``net`` on CULane and return the overall F-measure.

    Predictions for every batch of ``data_loader`` are written by
    ``generate_lines`` into ``<work_dir>/culane_eval_tmp``. ``call_culane_eval``
    then scores them. The per-entry F-measure of each result entry is logged.
    The returned value is F1 computed from the TP/FP/FN counts summed over all
    entries.

    Args:
        net: model under evaluation. Its train/eval mode is restored on
            return, so this can be called between training epochs.
        data_loader: iterable of ``(imgs, names)`` batches.
        dataset: dataset config mapping; reads ``name``, ``data_root`` and
            ``griding_num``.
        work_dir: directory for the temporary predictions and evaluator output.
        logger: object with a ``log(str)`` method.
        use_aux: if True and the net returns a 2-element tuple/list, only the
            first element is used. The second is presumably the auxiliary
            segmentation output — confirm against parsingNet.

    Returns:
        float F-measure in [0, 1] (0.0 when there are no true positives).

    Raises:
        NotImplementedError: ``dataset['name']`` is not ``'CULane'``.
    """
    if dataset['name'] != 'CULane':
        # Only CULane evaluation is implemented; fail loudly instead of
        # reaching the return with no score computed.
        raise NotImplementedError(
            'test() only supports CULane, got %r' % dataset['name'])

    output_path = os.path.join(work_dir, 'culane_eval_tmp')
    # makedirs also creates a missing work_dir and tolerates an existing dir.
    os.makedirs(output_path, exist_ok=True)

    was_training = net.training
    net.eval()
    try:
        for imgs, names in dist_tqdm(data_loader):
            imgs = imgs.cuda()
            with torch.no_grad():
                out = net(imgs)
            # Check the type, not just len(): len() of a plain output tensor
            # is its batch size, so a batch of 2 would be mis-split.
            if use_aux and isinstance(out, (tuple, list)) and len(out) == 2:
                out = out[0]
            # imgs[0, 0].shape: spatial size of one input channel (assumes
            # imgs is (B, C, H, W) — confirm).
            generate_lines(out, imgs[0, 0].shape, names, output_path,
                           dataset['griding_num'], localization_type='rel',
                           flip_updown=True)
    finally:
        # Put the net back in the caller's mode; otherwise eval() would leak
        # into the next training epoch.
        net.train(was_training)

    res = call_culane_eval(dataset['data_root'], 'culane_eval_tmp', work_dir)
    tp = fp = fn = 0
    for k, v in res.items():
        val = float(v['Fmeasure']) if 'nan' not in v['Fmeasure'] else 0
        tp += int(v['tp'])
        fp += int(v['fp'])
        fn += int(v['fn'])
        logger.log('k:{} val{}'.format(k, val))
    F = _f_measure(tp, fp, fn)
    logger.log('F:{}'.format(F))
    return F
4.onnx转换脚本
import torch, os, cv2
from model.model import parsingNet
import torch
import scipy.special, tqdm
import numpy as np
import argparse
import yaml
def get_args():
    """Build the command-line parser for the ONNX export script.

    Options:
        --params      YAML config path (default 'configs/culane.yaml')
        --batch_size  batch dimension of the dummy export input (default 1)
        --weights     checkpoint to export (default 'model_last.pth')
        --img-size    one or more ints, height then width (default 256 512)

    Returns:
        The ArgumentParser itself; callers invoke ``.parse_args()``.
    """
    cli = argparse.ArgumentParser()
    add = cli.add_argument
    add('--params', type=str, default='configs/culane.yaml')
    add('--batch_size', type=int, default=1)
    add('--weights', type=str, default='model_last.pth')
    # height, width
    add('--img-size', nargs='+', type=int, default=[256, 512], help='image size')
    return cli
def strip_module_prefix(state_dict):
    """Return a copy of ``state_dict`` with a leading ``'module.'`` removed from keys.

    Wrappers such as torch.nn.DataParallel / DistributedDataParallel save
    parameter names with a ``'module.'`` prefix, which the bare network does
    not use. Only a *leading* prefix is stripped; keys that merely contain
    ``'module.'`` elsewhere are kept unchanged. The input is not modified.
    """
    prefix = 'module.'
    return {(k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in state_dict.items()}


if __name__ == '__main__':
    args = get_args().parse_args()
    # A single --img-size value is used for both height and width.
    args.img_size *= 2 if len(args.img_size) == 1 else 1
    with open(args.params) as cfg_file:
        # The config holds only plain YAML scalars, lists and mappings, so the
        # safe loader is sufficient.
        cfg = yaml.safe_load(cfg_file)

    net = parsingNet(network=cfg['network'], datasets=cfg['dataset']).cuda()
    state_dict = torch.load(args.weights, map_location='cpu')['model']
    # strict=False is kept so partially matching checkpoints still load, but
    # report what did not match instead of silently exporting a partly
    # initialised model.
    incompatible = net.load_state_dict(strip_module_prefix(state_dict), strict=False)
    if incompatible.missing_keys:
        print('missing keys: %s' % incompatible.missing_keys)
    if incompatible.unexpected_keys:
        print('unexpected keys: %s' % incompatible.unexpected_keys)
    net.eval()
    print('loaded weights from %s' % args.weights)

    height, width = args.img_size[:2]
    img = torch.zeros(args.batch_size, 3, height, width).cuda()
    with torch.no_grad():
        net(img)  # dry run so shape errors surface before the ONNX export

    # Derive the output name from the extension only. A str.replace('.pth', ...)
    # leaves non-.pth paths unchanged and would overwrite the checkpoint.
    onnx_path = os.path.splitext(args.weights)[0] + '.onnx'
    try:
        import onnx
        from onnxsim import simplify

        print('\nStarting ONNX export with onnx %s...' % onnx.__version__)
        torch.onnx.export(net, img, onnx_path, verbose=False, opset_version=11,
                          input_names=['images'], output_names=['output'])
        model_simp, check = simplify(onnx.load(onnx_path))
        if not check:
            raise RuntimeError('Simplified ONNX model could not be validated')
        onnx.save(model_simp, onnx_path)
    except Exception as e:
        # Exit non-zero so calling scripts / CI notice the failure.
        raise SystemExit('ONNX export failure: %s' % e) from e
    # Print the graph that was actually saved (the simplified one).
    print(onnx.helper.printable_graph(model_simp.graph))
    print('ONNX export success, saved as %s' % onnx_path)