Source code: https://github.com/HRNet/HRNet-Semantic-Segmentation/ (I used the pytorchv1.1 branch).
Surprisingly, such a good project ships without inference code, so I put together a simple demo myself.
The jit (TorchScript) model requires torch >= 1.8. The export script:
import torch
import torchvision
import argparse
import _init_paths
from config import config
from config import update_config
import models
from utils.utils import create_logger, FullModel, get_rank
import onnxruntime
import os
from torch.nn import functional as F
import cv2
import numpy as np
def jit_export(model, pth_file):
pretrained_dict = torch.load(pth_file, map_location="cpu")
model_dict = model.state_dict()
pretrained_dict = {k[6:]: v for k, v in pretrained_dict.items()  # drop the leading 'model.' prefix from checkpoint keys
if k[6:] in model_dict.keys()}
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)
model.eval()
dump_input = torch.rand(
(1, 3, config.TRAIN.IMAGE_SIZE[1], config.TRAIN.IMAGE_SIZE[0])
)
print(model)
print(dump_input.shape)
traced_script_module = torch.jit.trace(model, dump_input)
traced_script_module.save("export_models/export_model.pt")
new_model = torch.jit.load("export_models/export_model.pt")
dump_input = torch.rand(
(1, 3, config.TRAIN.IMAGE_SIZE[1], config.TRAIN.IMAGE_SIZE[0])
)
out = new_model(dump_input)
print(out.shape)
def onnx_export(model, pth_file):
pretrained_dict = torch.load(pth_file, map_location="cpu")
model_dict = model.state_dict()
pretrained_dict = {k[6:]: v for k, v in pretrained_dict.items()
if k[6:] in model_dict.keys()}
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)
dump_input = torch.rand(
(1, 3, config.TRAIN.IMAGE_SIZE[1], config.TRAIN.IMAGE_SIZE[0])
)
export_onnx_file = os.path.join("export_models",os.path.basename(args.pth_file).replace("pth","onnx"))
torch.onnx.export(model.cpu(), dump_input.cpu(), export_onnx_file, verbose=True)
dump_input = torch.rand(
(1, 3, config.TRAIN.IMAGE_SIZE[1], config.TRAIN.IMAGE_SIZE[0])
)
def to_numpy(tensor):
return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
model.eval()
x = torch.randn(1, 3, config.TRAIN.IMAGE_SIZE[1], config.TRAIN.IMAGE_SIZE[0]).cpu()
with torch.no_grad():
torch_out = model(x)
sess = onnxruntime.InferenceSession(export_onnx_file)
onnx_out = sess.run(None, {sess.get_inputs()[0].name: to_numpy(x)})
print(torch_out.shape,torch_out[0,0,0,0:10])
print(onnx_out[0].shape,onnx_out[0][0,0,0,0:10])
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Train segmentation network')
parser.add_argument('--cfg',
help='experiment configure file name',
required=True,
type=str)
parser.add_argument('--pth_file',type=str)
parser.add_argument('--image_path',type=str)
parser.add_argument('opts',
help="Modify config options using the command-line",
default=None,
nargs=argparse.REMAINDER)
args = parser.parse_args()
update_config(config, args)
pth_file = args.pth_file
image_path = args.image_path
model = eval('models.'+config.MODEL.NAME +'.get_seg_model')(config)
model.to("cpu")
# onnx_export(model, pth_file)
jit_export(model, pth_file)
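The ONNX export above traces the model at the fixed size taken from config.TRAIN.IMAGE_SIZE. If you need the exported graph to accept other resolutions, torch.onnx.export accepts a dynamic_axes argument; a minimal sketch reusing the variables from onnx_export (standard torch.onnx API, but whether HRNet's upsampling layers trace cleanly with fully dynamic shapes is not verified here):

# Hedged variant of the export call: mark batch and spatial dims as dynamic.
torch.onnx.export(
    model.cpu(), dump_input.cpu(), export_onnx_file,
    input_names=["input"], output_names=["output"],
    dynamic_axes={"input": {0: "batch", 2: "height", 3: "width"},
                  "output": {0: "batch", 2: "height", 3: "width"}},
    opset_version=11,
)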
The jit model requires torch >= 1.8; the pth checkpoint can be loaded with any version. Remember to adjust the input size, mean, and std to match your training setup.
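For quick reference, the values you would change sit near the top of preprocess() in the script below (copied here from the script):

long_size = 473                # long side the input image is resized to
crop_size = (473, 473)         # sliding-window crop size
mean = [0.485, 0.456, 0.406]   # ImageNet mean, RGB order (the transform flips BGR to RGB first)
std = [0.229, 0.224, 0.225]    # ImageNet std, RGB order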
import torch
import torchvision
import argparse
import _init_paths
from config import config
from config import update_config
import models
import os
from torch.nn import functional as F
import cv2
import numpy as np
import time
def preprocess(img, model, device):
def input_transform(image):
image = image.astype(np.float32)[:, :, ::-1]
image = image / 255.0
image -= mean
image /= std
return image
def image_resize(image, long_size, label=None):
h, w = image.shape[:2]
if h > w:
new_h = long_size
new_w = int(w * long_size / h + 0.5)
else:
new_w = long_size
new_h = int(h * long_size / w + 0.5)
image = cv2.resize(image, (new_w, new_h),
interpolation = cv2.INTER_LINEAR)
if label is not None:
label = cv2.resize(label, (new_w, new_h),
interpolation = cv2.INTER_NEAREST)
else:
return image
return image, label
def pad_image(image, h, w, size, padvalue):
pad_image = image.copy()
pad_h = max(size[0] - h, 0)
pad_w = max(size[1] - w, 0)
if pad_h > 0 or pad_w > 0:
pad_image = cv2.copyMakeBorder(image, 0, pad_h, 0,
pad_w, cv2.BORDER_CONSTANT,
value=padvalue)
return pad_image
def multi_scale_aug(image, label=None,
rand_scale=1, rand_crop=True):
long_size = 473
if label is not None:
image, label = image_resize(image, long_size, label)
# the random crop used by the original dataset class during training is skipped here; it is not needed for inference
return image, label
else:
image = image_resize(image, long_size)
return image
def infer(model, image):
size = image.size()
# start = time.time()
pred = model(image)
# print("inference time:",time.time()-start)
pred = F.interpolate(pred, size=(size[-2], size[-1]), mode='bilinear', align_corners=False)
return pred.exp()
mean=[0.485, 0.456, 0.406]
std=[0.229, 0.224, 0.225]
padvalue = -1.0 * np.array(mean) / np.array(std)
crop_size = (473, 473)
stride_h, stride_w = crop_size  # sliding-window stride: one full crop per step, i.e. no overlap between tiles
new_img = multi_scale_aug(img)
height, width = new_img.shape[:-1]
new_img = input_transform(new_img)
if max(height, width) <= np.min(crop_size):
new_img = pad_image(new_img, height, width, crop_size, padvalue)
new_img = new_img.transpose((2, 0, 1))
new_img = np.expand_dims(new_img, axis=0)
new_img = torch.from_numpy(new_img).to(device)
preds = infer(model, new_img)
preds = preds[:, :, 0:height, 0:width]
else:
if height < crop_size[0] or width < crop_size[1]:
new_img = pad_image(new_img, height, width, crop_size, padvalue)
new_h, new_w = new_img.shape[:-1]
rows = int(np.ceil(1.0 * (new_h - crop_size[0]) / stride_h)) + 1
cols = int(np.ceil(1.0 * (new_w - crop_size[1]) / stride_w)) + 1
preds = torch.zeros([1, 2, new_h, new_w]).to(device)
count = torch.zeros([1, 1, new_h, new_w]).to(device)
for r in range(rows):
for c in range(cols):
h0 = r * stride_h
w0 = c * stride_w
h1 = min(h0 + crop_size[0], new_h)
w1 = min(w0 + crop_size[1], new_w)
crop_img = new_img[h0:h1, w0:w1, :]
if h1 == new_h or w1 == new_w:
crop_img = pad_image(crop_img,
h1-h0,
w1-w0,
crop_size,
padvalue)
crop_img = crop_img.transpose((2, 0, 1))
crop_img = np.expand_dims(crop_img, axis=0)
crop_img = torch.from_numpy(crop_img).to(device)
pred = infer(model, crop_img)
preds[:,:,h0:h1,w0:w1] += pred[:,:, 0:h1-h0, 0:w1-w0]
count[:,:,h0:h1,w0:w1] += 1
preds = preds / count
preds = preds[:,:,:height,:width]
return preds
def inference(model, pth_file, image_path, device):
image = cv2.imread(image_path)
ori_height, ori_width = image.shape[0], image.shape[1]
preds = preprocess(image.copy(), model, device)
#preds = F.upsample(preds, (ori_height, ori_width), mode='bilinear')
preds = preds.detach().cpu().numpy().copy()
preds = np.asarray(np.argmax(preds, axis=1), dtype=np.uint8)
preds = preds.astype(np.uint8).transpose((1,2,0))
preds[preds==1] = 255
preds[preds!=255] = 0
preds = cv2.merge([preds,preds,preds])
image = cv2.resize(image,(preds.shape[1],preds.shape[0]),interpolation=cv2.INTER_LINEAR)
image_merge = cv2.addWeighted(image, 0.5, preds, 0.5, 0)
cv2.imshow(image_path, image_merge)
cv2.waitKey(0)
def inference_dir(model, pth_file, image_path, device):
for name in os.listdir(image_path):
image = cv2.imread(os.path.join(image_path, name))
ori_height, ori_width = image.shape[0], image.shape[1]
preds = preprocess(image.copy(), model, device)
# preds = F.upsample(preds, (ori_height, ori_width), mode='bilinear')
preds = preds.detach().cpu().numpy().copy()
preds = np.asarray(np.argmax(preds, axis=1), dtype=np.uint8)
preds = preds.astype(np.uint8).transpose((1,2,0))
preds[preds==1] = 255
preds[preds!=255] = 0
preds = cv2.merge([preds,preds,preds])
image = cv2.resize(image,(preds.shape[1],preds.shape[0]),interpolation=cv2.INTER_LINEAR)
image_merge = cv2.addWeighted(image, 0.5, preds, 0.5, 0)
cv2.imshow(name,image_merge)
key = cv2.waitKey(0)
cv2.destroyWindow(name)
if key == 27:
break
def inference_video(model, pth_file, video_path, device, save=False):
if video_path.endswith(".mp4"):
vc = cv2.VideoCapture(video_path)
else:
vc = cv2.VideoCapture(0)
if vc.isOpened():
rval, frame = vc.read()
else:
rval = False
start_time = time.time()
frame_count = 1
if save:
fps = vc.get(cv2.CAP_PROP_FPS)
width, height = frame.shape[1], frame.shape[0]
resize_ratio = 1.0 * 473 / max(width, height)
target_size = (int(resize_ratio * width), int(resize_ratio * height))
fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v')#mp4v
outVideo = cv2.VideoWriter(save, fourcc,fps,target_size)
while rval:
rval, frame = vc.read()
if rval == False:
if save:
outVideo.release()
break
ori_height, ori_width = frame.shape[0], frame.shape[1]
# cv2.imshow("frame",frame)
# cv2.waitKey(0)
preds = preprocess(frame.copy(), model, device)
#preds = F.upsample(preds, (ori_height, ori_width), mode='bilinear')
preds = preds.detach().cpu().numpy().copy()
preds = np.asarray(np.argmax(preds, axis=1), dtype=np.uint8)
preds = preds.astype(np.uint8).transpose((1,2,0))
preds[preds==1] = 255
preds[preds!=255] = 0
preds = cv2.merge([preds,preds,preds])
frame = cv2.resize(frame,(preds.shape[1],preds.shape[0]),interpolation=cv2.INTER_LINEAR)
image_merge = cv2.addWeighted(frame, 0.5, preds, 0.5, 0)
cv2.imshow("merge", image_merge)
key = cv2.waitKey(1)
if save:
image_merge = image_merge.astype(np.uint8)
r = outVideo.write(image_merge)
if key == 27: # exit on ESC
if save:
outVideo.release()
break
if frame_count % 30 == 0:
print("Frame Per second: {} fps.".format(
(time.time() - start_time) / frame_count))
frame_count = frame_count + 1
cv2.destroyAllWindows()
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Train segmentation network')
parser.add_argument('--cfg',
help='experiment configure file name',
required=False,
type=str)
parser.add_argument('--pth_file',type=str)
parser.add_argument('--image_path',type=str)
parser.add_argument('--video_path',type=str)
parser.add_argument('opts',
help="Modify config options using the command-line",
default=None,
nargs=argparse.REMAINDER)
device = "cuda:0"
args = parser.parse_args()
pth_file = args.pth_file
image_path = args.image_path
video_path = args.video_path
# update_config(config, args)
# model = eval('models.'+config.MODEL.NAME +'.get_seg_model')(config)
# model.to(device)
# pretrained_dict = torch.load(pth_file, map_location=device)
# model_dict = model.state_dict()
# pretrained_dict = {k[6:]: v for k, v in pretrained_dict.items()
# if k[6:] in model_dict.keys()}
# model_dict.update(pretrained_dict)
# model.load_state_dict(model_dict)
# model.eval()
# # onnx_export(model, pth_file)
# if image_path:
# if os.path.isfile(image_path):
# inference(model, pth_file ,image_path, device)
# elif os.path.isdir(image_path):
# inference_dir(model, pth_file ,image_path, device)
# elif video_path:
# inference_video(model, pth_file, video_path, device)
model_jit = torch.jit.load("export_models/export_model.pt",map_location=device)
# dump_input = torch.rand((1,3,473,473),device=device)
# out = model_jit(dump_input)
# print(out)
# out2 = model(dump_input)
# print(out2)
if image_path:
if os.path.isfile(image_path):
inference(model_jit, pth_file ,image_path, device)
elif os.path.isdir(image_path):
inference_dir(model_jit, pth_file ,image_path, device)
elif video_path:
inference_video(model_jit, pth_file, video_path, device)
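Finally, a threaded variant for video streams: one thread keeps pushing frames from cv2.VideoCapture into a small bounded buffer, while a second thread pops the most recent frame and runs segmentation, so frame grabbing and inference overlap instead of blocking each other.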
import torch
import torchvision
import argparse
from config import config
from config import update_config
from seg_hrnet import get_seg_model
import os
from torch.nn import functional as F
import cv2
import numpy as np
import time
import threading
def preprocess(img, model, device):
def input_transform(image):
image = image.astype(np.float32)[:, :, ::-1]
image = image / 255.0
image -= mean
image /= std
return image
def image_resize(image, long_size, label=None):
h, w = image.shape[:2]
if h > w:
new_h = long_size
new_w = int(w * long_size / h + 0.5)
else:
new_w = long_size
new_h = int(h * long_size / w + 0.5)
image = cv2.resize(image, (new_w, new_h),
interpolation = cv2.INTER_LINEAR)
if label is not None:
label = cv2.resize(label, (new_w, new_h),
interpolation = cv2.INTER_NEAREST)
else:
return image
return image, label
def pad_image(image, h, w, size, padvalue):
pad_image = image.copy()
pad_h = max(size[0] - h, 0)
pad_w = max(size[1] - w, 0)
if pad_h > 0 or pad_w > 0:
pad_image = cv2.copyMakeBorder(image, 0, pad_h, 0,
pad_w, cv2.BORDER_CONSTANT,
value=padvalue)
return pad_image
def multi_scale_aug(image, label=None,
rand_scale=1, rand_crop=True):
long_size = 473
if label is not None:
image, label = image_resize(image, long_size, label)
# the random crop used by the original dataset class during training is skipped here; it is not needed for inference
return image, label
else:
image = image_resize(image, long_size)
return image
def infer(model, image):
size = image.size()
# start = time.time()
pred = model(image)
# print("inference time:",time.time()-start)
pred = F.interpolate(pred, size=(size[-2], size[-1]), mode='bilinear', align_corners=False)
return pred.exp()
mean=[0.485, 0.456, 0.406]
std=[0.229, 0.224, 0.225]
padvalue = -1.0 * np.array(mean) / np.array(std)
crop_size = (473, 473)
stride_h, stride_w = crop_size  # sliding-window stride: one full crop per step, i.e. no overlap between tiles
new_img = multi_scale_aug(img)
height, width = new_img.shape[:-1]
new_img = input_transform(new_img)
if max(height, width) <= np.min(crop_size):
#new_img = pad_image(new_img, height, width, crop_size, padvalue)
new_img = new_img.transpose((2, 0, 1))
new_img = np.expand_dims(new_img, axis=0)
new_img = torch.from_numpy(new_img).to(device)
preds = infer(model, new_img)
#preds = preds[:, :, 0:height, 0:width]
else:
if height < crop_size[0] or width < crop_size[1]:
new_img = pad_image(new_img, height, width, crop_size, padvalue)
new_h, new_w = new_img.shape[:-1]
rows = int(np.ceil(1.0 * (new_h - crop_size[0]) / stride_h)) + 1
cols = int(np.ceil(1.0 * (new_w - crop_size[1]) / stride_w)) + 1
preds = torch.zeros([1, 2, new_h, new_w]).to(device)
count = torch.zeros([1, 1, new_h, new_w]).to(device)
for r in range(rows):
for c in range(cols):
h0 = r * stride_h
w0 = c * stride_w
h1 = min(h0 + crop_size[0], new_h)
w1 = min(w0 + crop_size[1], new_w)
crop_img = new_img[h0:h1, w0:w1, :]
if h1 == new_h or w1 == new_w:
crop_img = pad_image(crop_img,
h1-h0,
w1-w0,
crop_size,
padvalue)
crop_img = crop_img.transpose((2, 0, 1))
crop_img = np.expand_dims(crop_img, axis=0)
crop_img = torch.from_numpy(crop_img).to(device)
pred = infer(model, crop_img)
preds[:,:,h0:h1,w0:w1] += pred[:,:, 0:h1-h0, 0:w1-w0]
count[:,:,h0:h1,w0:w1] += 1
preds = preds / count
preds = preds[:,:,:height,:width]
return preds
def inference(model, pth_file, image_path, device):
image = cv2.imread(image_path)
ori_height, ori_width = image.shape[0], image.shape[1]
preds = preprocess(image.copy(), model, device)
#preds = F.upsample(preds, (ori_height, ori_width), mode='bilinear')
preds = preds.detach().cpu().numpy().copy()
preds = np.asarray(np.argmax(preds, axis=1), dtype=np.uint8)
preds = preds.astype(np.uint8).transpose((1,2,0))
preds[preds==1] = 255
preds[preds!=255] = 0
preds = cv2.merge([preds,preds,preds])
image = cv2.resize(image,(preds.shape[1],preds.shape[0]),interpolation=cv2.INTER_LINEAR)
image_merge = cv2.addWeighted(image, 0.5, preds, 0.5, 0)
cv2.imshow(image_path, image_merge)
cv2.waitKey(0)
def inference_dir(model, pth_file, image_path, device):
for name in os.listdir(image_path):
image = cv2.imread(os.path.join(image_path, name))
ori_height, ori_width = image.shape[0], image.shape[1]
preds = preprocess(image.copy(), model, device)
# preds = F.upsample(preds, (ori_height, ori_width), mode='bilinear')
preds = preds.detach().cpu().numpy().copy()
preds = np.asarray(np.argmax(preds, axis=1), dtype=np.uint8)
preds = preds.astype(np.uint8).transpose((1,2,0))
preds[preds==1] = 255
preds[preds!=255] = 0
preds = cv2.merge([preds,preds,preds])
image = cv2.resize(image,(preds.shape[1],preds.shape[0]),interpolation=cv2.INTER_LINEAR)
image_merge = cv2.addWeighted(image, 0.5, preds, 0.5, 0)
cv2.imshow(name,image_merge)
key = cv2.waitKey(0)
cv2.destroyWindow(name)
if key == 27:
break
def inference_video(model, pth_file, video_path, device, save=False):
if video_path.endswith(".mp4"):
vc = cv2.VideoCapture(video_path)
else:
vc = cv2.VideoCapture(0)
if vc.isOpened():
rval, frame = vc.read()
else:
rval = False
start_time = time.time()
frame_count = 1
if save:
fps = vc.get(cv2.CAP_PROP_FPS)
width, height = frame.shape[1], frame.shape[0]
resize_ratio = 1.0 * 473 / max(width, height)
target_size = (int(resize_ratio * width), int(resize_ratio * height))
fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v')#mp4v
outVideo = cv2.VideoWriter(save, fourcc,fps,target_size)
while rval:
rval, frame = vc.read()
if rval == False:
if save:
outVideo.release()
break
ori_height, ori_width = frame.shape[0], frame.shape[1]
# cv2.imshow("frame",frame)
# cv2.waitKey(0)
preds = preprocess(frame.copy(), model, device)
#preds = F.upsample(preds, (ori_height, ori_width), mode='bilinear')
preds = preds.detach().cpu().numpy().copy()
preds = np.asarray(np.argmax(preds, axis=1), dtype=np.uint8)
preds = preds.astype(np.uint8).transpose((1,2,0))
preds[preds==1] = 255
preds[preds!=255] = 0
preds = cv2.merge([preds,preds,preds])
frame = cv2.resize(frame,(preds.shape[1],preds.shape[0]),interpolation=cv2.INTER_LINEAR)
image_merge = cv2.addWeighted(frame, 0.5, preds, 0.5, 0)
cv2.imshow("merge", image_merge)
key = cv2.waitKey(1)
if save:
image_merge = image_merge.astype(np.uint8)
r = outVideo.write(image_merge)
if key == 27: # exit on ESC
if save:
outVideo.release()
break
if frame_count % 30 == 0:
print("Frame Per second: {} fps.".format(
(time.time() - start_time) / frame_count))
frame_count = frame_count + 1
cv2.destroyAllWindows()
class Stack:
def __init__(self, stack_size):
self.items = []
self.stack_size = stack_size
self.flag = True
def is_empty(self):
return len(self.items) == 0
def pop(self):
return self.items.pop()
def peek(self):
if not self.is_empty():
return self.items[len(self.items) - 1]
def size(self):
return len(self.items)
def push(self, item):
if self.size() >= self.stack_size:
for i in range(self.size() - self.stack_size + 1):
self.items.remove(self.items[0])
self.items.append(item)
def end(self):
self.flag = False
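# Design note: this Stack is a bounded LIFO frame buffer. push() evicts the oldest
# entries once stack_size is exceeded, so the consumer thread always pops a recent
# frame and the displayed result never lags far behind the capture.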
def capture_thread(video_path, frame_buffer, lock):
print("capture_thread start")
vid = cv2.VideoCapture(video_path)
if not vid.isOpened():
raise IOError("Couldn't open webcam or video")
while True:
return_value, frame = vid.read()
if return_value is not True or frame_buffer.flag is not True:
break
lock.acquire()
frame_buffer.push(frame)
lock.release()
def play_thread(frame_buffer, lock, model):
print("detect_thread start")
print("detect_thread frame_buffer size is", frame_buffer.size())
while True:
if frame_buffer.size() > 0:
lock.acquire()
frame = frame_buffer.pop()
lock.release()
# run segmentation on the most recent frame
ori_height, ori_width = frame.shape[0], frame.shape[1]
preds = preprocess(frame.copy(), model, device)
#preds = F.upsample(preds, (ori_height, ori_width), mode='bilinear')
preds = preds.detach().cpu().numpy().copy()
preds = np.asarray(np.argmax(preds, axis=1), dtype=np.uint8)
preds = preds.astype(np.uint8).transpose((1,2,0))
preds[preds==1] = 255
preds[preds!=255] = 0
preds = cv2.merge([preds,preds,preds])
frame = cv2.resize(frame,(preds.shape[1],preds.shape[0]),interpolation=cv2.INTER_LINEAR)
image_merge = cv2.addWeighted(frame, 0.5, preds, 0.5, 0)
cv2.imshow("merge", image_merge)
key = cv2.waitKey(1)
if key == 27: # exit on ESC
frame_buffer.end()
break
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Train segmentation network')
parser.add_argument('--cfg',
help='experiment configure file name',
default = "8dataset_custom_seg_hrnetv1_w18_473x473_sgd_lr7e-3_wd5e-4_bs_32_epoch100.yaml",
required=False,
type=str)
parser.add_argument('--pth_file',type=str,default="8dataset_custom_seg_hrnetv1_w18_473x473_sgd_lr7e-3_wd5e-4_bs_32_epoch100.pth")
parser.add_argument('--video_path',type=str)
parser.add_argument('opts',
help="Modify config options using the command-line",
default=None,
nargs=argparse.REMAINDER)
device = "cuda:0"
args = parser.parse_args()
pth_file = args.pth_file
video_path = args.video_path
update_config(config, args)
model = get_seg_model(config)
model.to(device)
pretrained_dict = torch.load(pth_file, map_location=device)
model_dict = model.state_dict()
pretrained_dict = {k[6:]: v for k, v in pretrained_dict.items()
if k[6:] in model_dict.keys()}
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)
model.eval()
frame_buffer = Stack(3)
lock = threading.RLock()
t1 = threading.Thread(target=capture_thread, args=(video_path, frame_buffer, lock))
t1.start()
t2 = threading.Thread(target=play_thread, args=(frame_buffer, lock, model))
t2.start()
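As an optional extra, a small visualization helper based on the imgviz package; it overlays the predicted label map on a grayscale copy of the image instead of the hard black-and-white mask used above (requires pip install imgviz):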
import imgviz
def vis(lbl, img):
if len(img.shape) == 2:
img = cv2.merge([img,img,img])
viz = imgviz.label2rgb(
label=lbl,
img=imgviz.rgb2gray(img),
font_size=15,
loc="rb",
)
return viz
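A hedged usage sketch with synthetic inputs (in the real demo, lbl would be np.argmax(preds, axis=1)[0] and img the matching resized frame; the names below are illustrative):

lbl = np.zeros((473, 473), dtype=np.int32)
lbl[100:300, 150:350] = 1                          # fake foreground region, just for illustration
img = np.full((473, 473, 3), 128, dtype=np.uint8)  # flat gray stand-in image
overlay = vis(lbl, img)
cv2.imshow("vis", overlay)
cv2.waitKey(0)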