HRNet-Semantic-Segmentation Image and Video Inference

淳于星宇
2023-12-01

Source: https://github.com/HRNet/HRNet-Semantic-Segmentation/ (I used the pytorch-v1.1 branch).
Such a good project surprisingly ships no inference code, so I put together a simple demo myself.

JIT and ONNX model export

The TorchScript (JIT) export requires torch>=1.8.

import torch
import torchvision
import argparse
import _init_paths
from config import config
from config import update_config
import models
from utils.utils import create_logger, FullModel, get_rank
from onnxruntime.datasets import get_example
import onnxruntime
from onnx import shape_inference
import os
from torch.nn import functional as F
import cv2
import numpy as np

def jit_export(model, pth_file):
    pretrained_dict = torch.load(pth_file, map_location="cpu")
    model_dict = model.state_dict()
    # strip the "model." prefix added by the FullModel wrapper during training
    pretrained_dict = {k[6:]: v for k, v in pretrained_dict.items()
                        if k[6:] in model_dict.keys()}
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)
    model.eval()
    
    dump_input = torch.rand(
            (1, 3, config.TRAIN.IMAGE_SIZE[1], config.TRAIN.IMAGE_SIZE[0])
            )
    print(model)
    print(dump_input.shape)

    traced_script_module = torch.jit.trace(model, dump_input)
    traced_script_module.save("export_models/export_model.pt")
    
    new_model = torch.jit.load("export_models/export_model.pt")
    dump_input = torch.rand(
            (1, 3, config.TRAIN.IMAGE_SIZE[1], config.TRAIN.IMAGE_SIZE[0])
            )
    out = new_model(dump_input)
    print(out.shape)

def onnx_export(model, pth_file):
    pretrained_dict = torch.load(pth_file, map_location="cpu")
    model_dict = model.state_dict()
    pretrained_dict = {k[6:]: v for k, v in pretrained_dict.items()
                        if k[6:] in model_dict.keys()}
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)
    # switch to eval mode before exporting so BN/dropout behave deterministically
    model.eval()

    dump_input = torch.rand(
            (1, 3, config.TRAIN.IMAGE_SIZE[1], config.TRAIN.IMAGE_SIZE[0])
            )

    export_onnx_file = os.path.join("export_models", os.path.splitext(os.path.basename(pth_file))[0] + ".onnx")

    torch.onnx.export(model.cpu(), dump_input.cpu(), export_onnx_file, verbose=True)
    
    dump_input = torch.rand(
            (1, 3, config.TRAIN.IMAGE_SIZE[1], config.TRAIN.IMAGE_SIZE[0])
            )

    def to_numpy(tensor):
        return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

    model.eval()
    x = torch.randn(1, 3, config.TRAIN.IMAGE_SIZE[1], config.TRAIN.IMAGE_SIZE[0]).cpu()
    with torch.no_grad():
        torch_out = model(x)
    example_model = get_example(os.getcwd()+'/'+export_onnx_file)

    sess = onnxruntime.InferenceSession(example_model)
    onnx_out = sess.run(None, {sess.get_inputs()[0].name: to_numpy(x)})

    print(torch_out.shape,torch_out[0,0,0,0:10])
    print(onnx_out[0].shape,onnx_out[0][0,0,0,0:10])

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Export segmentation network')
    parser.add_argument('--cfg',
                        help='experiment configure file name',
                        required=True,
                        type=str)
    parser.add_argument('--pth_file',type=str)
    parser.add_argument('--image_path',type=str)
    parser.add_argument('opts',
                        help="Modify config options using the command-line",
                        default=None,
                        nargs=argparse.REMAINDER)


    args = parser.parse_args()
    update_config(config, args)
    pth_file = args.pth_file
    image_path = args.image_path

    model = eval('models.'+config.MODEL.NAME +'.get_seg_model')(config)
    model.to("cpu")
    
    # onnx_export(model, pth_file)
    jit_export(model, pth_file)
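
To sanity-check the exported ONNX file on its own, ONNX Runtime is enough. Below is a minimal sketch, not part of the original scripts: the file name, test image, and the 473x473 input size are assumptions, and the preprocessing (BGR to RGB, /255, ImageNet mean/std) mirrors the config used above.

import cv2
import numpy as np
import onnxruntime

onnx_path = "export_models/export_model.onnx"   # assumed export location
image_path = "test.jpg"                         # any test image
input_size = (473, 473)                         # (width, height) used at export time

img = cv2.imread(image_path)
img = cv2.resize(img, input_size, interpolation=cv2.INTER_LINEAR)
x = img.astype(np.float32)[:, :, ::-1] / 255.0                    # BGR -> RGB, [0, 1]
x = (x - [0.485, 0.456, 0.406]) / [0.229, 0.224, 0.225]           # ImageNet mean/std
x = np.expand_dims(x.transpose(2, 0, 1), 0).astype(np.float32)    # NCHW

sess = onnxruntime.InferenceSession(onnx_path)
out = sess.run(None, {sess.get_inputs()[0].name: x})[0]           # (1, num_classes, h, w)
mask = np.argmax(out, axis=1)[0].astype(np.uint8)                 # per-pixel class ids
print(out.shape, mask.shape)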

Image and video inference with JIT and pth models

The JIT model requires torch>=1.8; the pth model works with any version. Be sure to adjust the input size, mean, and std to match your own training setup.

import torch
import torchvision
import argparse
import _init_paths
from config import config
from config import update_config
import models
import os
from torch.nn import functional as F
import cv2
import numpy as np
import time

def preprocess(img, model, device):
    def input_transform(image): 
        image = image.astype(np.float32)[:, :, ::-1]
        image = image / 255.0
        image -= mean
        image /= std
        return image
    
    def image_resize(image, long_size, label=None):
        h, w = image.shape[:2]
        if h > w:
            new_h = long_size
            new_w = int(w * long_size / h + 0.5)
        else:
            new_w = long_size
            new_h = int(h * long_size / w + 0.5)
        
        image = cv2.resize(image, (new_w, new_h), 
                           interpolation = cv2.INTER_LINEAR)
        if label is not None:
            label = cv2.resize(label, (new_w, new_h), 
                           interpolation = cv2.INTER_NEAREST)
        else:
            return image
        
        return image, label
    
    def pad_image(image, h, w, size, padvalue):
        pad_image = image.copy()
        pad_h = max(size[0] - h, 0)
        pad_w = max(size[1] - w, 0)
        if pad_h > 0 or pad_w > 0:
            pad_image = cv2.copyMakeBorder(image, 0, pad_h, 0, 
                pad_w, cv2.BORDER_CONSTANT, 
                value=padvalue)
        
        return pad_image
    
    def multi_scale_aug(image, label=None, long_size=473):
        # inference-time resize only; the random-crop path from the training
        # dataset code is not needed here
        if label is not None:
            image, label = image_resize(image, long_size, label)
            return image, label
        else:
            image = image_resize(image, long_size)
            return image

    def infer(model, image):
        size = image.size()
        # start = time.time()
        pred = model(image)
        # print("inference time:", time.time() - start)
        # F.upsample is deprecated; F.interpolate is the current equivalent
        pred = F.interpolate(input=pred, size=(size[-2], size[-1]),
                             mode='bilinear', align_corners=False)

        return pred.exp()

    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    padvalue = tuple((-1.0 * np.array(mean) / np.array(std)).tolist())
    crop_size = (473, 473)
    # sliding-window stride for images larger than crop_size
    stride_h, stride_w = crop_size[0], crop_size[1]

    
    new_img = multi_scale_aug(img)
    height, width = new_img.shape[:-1]
    new_img = input_transform(new_img)
    

    if max(height, width) <= np.min(crop_size):
        new_img = pad_image(new_img, height, width, crop_size, padvalue)
        new_img = new_img.transpose((2, 0, 1))
        new_img = np.expand_dims(new_img, axis=0)
        new_img = torch.from_numpy(new_img).to(device)
        preds = infer(model, new_img)
        preds = preds[:, :, 0:height, 0:width]

    else:
        if height < crop_size[0] or width < crop_size[1]:
            new_img = pad_image(new_img, height, width, crop_size, padvalue)
        new_h, new_w = new_img.shape[:-1]
        rows = int(np.ceil(1.0 * (new_h - crop_size[0]) / stride_h)) + 1
        cols = int(np.ceil(1.0 * (new_w - crop_size[1]) / stride_w)) + 1
        # 2 = number of classes; change it to match your model
        preds = torch.zeros([1, 2, new_h, new_w]).to(device)
        count = torch.zeros([1, 1, new_h, new_w]).to(device)

        for r in range(rows):
            for c in range(cols):
                h0 = r * stride_h
                w0 = c * stride_w
                h1 = min(h0 + crop_size[0], new_h)
                w1 = min(w0 + crop_size[1], new_w)
                crop_img = new_img[h0:h1, w0:w1, :]
                if h1 == new_h or w1 == new_w:
                    crop_img = pad_image(crop_img,
                                         h1 - h0,
                                         w1 - w0,
                                         crop_size,
                                         padvalue)
                crop_img = crop_img.transpose((2, 0, 1))
                crop_img = np.expand_dims(crop_img, axis=0)
                crop_img = torch.from_numpy(crop_img).to(device)
                pred = infer(model, crop_img)

                preds[:, :, h0:h1, w0:w1] += pred[:, :, 0:h1 - h0, 0:w1 - w0]
                count[:,:,h0:h1,w0:w1] += 1
        preds = preds / count
        preds = preds[:,:,:height,:width]

    return preds
        
def inference(model, pth_file, image_path, device):
    image = cv2.imread(image_path)
    ori_height, ori_width = image.shape[0], image.shape[1]

    preds = preprocess(image.copy(), model, device)

    #preds = F.upsample(preds, (ori_height, ori_width), mode='bilinear')

    preds = preds.detach().cpu().numpy().copy()
    preds = np.asarray(np.argmax(preds, axis=1), dtype=np.uint8)

    preds = preds.astype(np.uint8).transpose((1,2,0))
    preds[preds==1] = 255
    preds[preds!=255] = 0

    preds = cv2.merge([preds,preds,preds])
    image = cv2.resize(image,(preds.shape[1],preds.shape[0]),interpolation=cv2.INTER_LINEAR)
    image_merge = cv2.addWeighted(image, 0.5, preds, 0.5, 0)
    cv2.imshow(image_path, image_merge)
    cv2.waitKey(0)

def inference_dir(model, pth_file, image_path, device):
    for name in os.listdir(image_path):
        image = cv2.imread(os.path.join(image_path, name))
        ori_height, ori_width = image.shape[0], image.shape[1]

        preds = preprocess(image.copy(), model, device)

        # preds = F.upsample(preds, (ori_height, ori_width), mode='bilinear')

        preds = preds.detach().cpu().numpy().copy()
        preds = np.asarray(np.argmax(preds, axis=1), dtype=np.uint8)

        preds = preds.astype(np.uint8).transpose((1,2,0))
        preds[preds==1] = 255
        preds[preds!=255] = 0

        preds = cv2.merge([preds,preds,preds])
        image = cv2.resize(image,(preds.shape[1],preds.shape[0]),interpolation=cv2.INTER_LINEAR)
        image_merge = cv2.addWeighted(image, 0.5, preds, 0.5, 0)
        cv2.imshow(name,image_merge)
        key = cv2.waitKey(0)
        cv2.destroyWindow(name)
        if key == 27:
            break

def inference_video(model, pth_file, video_path, device, save=False):
    if video_path.endswith(".mp4"):
        vc = cv2.VideoCapture(video_path)
    else:
        vc = cv2.VideoCapture(0)

    if vc.isOpened():
        rval, frame = vc.read()
    else:
        rval = False

    start_time = time.time()
    frame_count = 1

    if save:
        fps = vc.get(cv2.CAP_PROP_FPS) 
        width, height = frame.shape[1], frame.shape[0]
        resize_ratio = 1.0 * 473 / max(width, height)
        target_size = (int(resize_ratio * width), int(resize_ratio * height))
        fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v')#mp4v
        outVideo = cv2.VideoWriter(save, fourcc,fps,target_size)

    while rval:
        rval, frame = vc.read()
        if rval == False:
            if save:
                outVideo.release()
            break
        
        ori_height, ori_width = frame.shape[0], frame.shape[1]
        # cv2.imshow("frame",frame)
        # cv2.waitKey(0)

        preds = preprocess(frame.copy(), model, device)
        #preds = F.upsample(preds, (ori_height, ori_width), mode='bilinear')

        preds = preds.detach().cpu().numpy().copy()
        preds = np.asarray(np.argmax(preds, axis=1), dtype=np.uint8)

        preds = preds.astype(np.uint8).transpose((1,2,0))
        preds[preds==1] = 255
        preds[preds!=255] = 0

        preds = cv2.merge([preds,preds,preds])
        frame = cv2.resize(frame,(preds.shape[1],preds.shape[0]),interpolation=cv2.INTER_LINEAR)
        image_merge = cv2.addWeighted(frame, 0.5, preds, 0.5, 0)
        cv2.imshow("merge", image_merge)
        key = cv2.waitKey(1)

        if save:
            image_merge = image_merge.astype(np.uint8)
            r = outVideo.write(image_merge)

        if key == 27:  # exit on ESC
            if save:
                outVideo.release()
            break
        if frame_count % 30 == 0:
            print("Frames per second: {:.2f} fps.".format(
                frame_count / (time.time() - start_time)))
        frame_count = frame_count + 1
    cv2.destroyAllWindows()
    

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Segmentation inference')
    parser.add_argument('--cfg',
                        help='experiment configure file name',
                        required=False,
                        type=str)
    parser.add_argument('--pth_file',type=str)
    parser.add_argument('--image_path',type=str)
    parser.add_argument('--video_path',type=str)
    parser.add_argument('opts',
                        help="Modify config options using the command-line",
                        default=None,
                        nargs=argparse.REMAINDER)

    device = "cuda:0"

    args = parser.parse_args()
    pth_file = args.pth_file
    image_path = args.image_path
    video_path = args.video_path

    # update_config(config, args)
    # model = eval('models.'+config.MODEL.NAME +'.get_seg_model')(config)
    # model.to(device)
    # pretrained_dict = torch.load(pth_file, map_location=device)
    # model_dict = model.state_dict()
    # pretrained_dict = {k[6:]: v for k, v in pretrained_dict.items()
    #                     if k[6:] in model_dict.keys()}
    # model_dict.update(pretrained_dict)
    # model.load_state_dict(model_dict)
    # model.eval()

    # # onnx_export(model, pth_file)

    # if image_path:
    #     if os.path.isfile(image_path):
    #         inference(model, pth_file ,image_path, device)
    #     elif os.path.isdir(image_path):
    #         inference_dir(model, pth_file ,image_path, device)
    # elif video_path:
    #     inference_video(model, pth_file, video_path, device)

    model_jit = torch.jit.load("export_models/export_model.pt",map_location=device)
    # dump_input = torch.rand((1,3,473,473),device=device)
    # out = model_jit(dump_input)
    # print(out)
    # out2 = model(dump_input)
    # print(out2)

    if image_path:
        if os.path.isfile(image_path):
            inference(model_jit, pth_file ,image_path, device)
        elif os.path.isdir(image_path):
            inference_dir(model_jit, pth_file ,image_path, device)
    elif video_path:
        inference_video(model_jit, pth_file, video_path, device)
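
cv2.imshow needs a display; on a headless server it is easier to write the overlay to disk. Below is a minimal sketch of such a variant of inference(), reusing preprocess() from this script (the out_path default is just an example).

def inference_save(model, image_path, device, out_path="result.png"):
    # same pipeline as inference(), but saves the overlay instead of showing it
    image = cv2.imread(image_path)
    preds = preprocess(image.copy(), model, device)
    preds = preds.detach().cpu().numpy()
    preds = np.argmax(preds, axis=1).astype(np.uint8).transpose((1, 2, 0))
    preds[preds == 1] = 255
    preds[preds != 255] = 0
    preds = cv2.merge([preds, preds, preds])
    image = cv2.resize(image, (preds.shape[1], preds.shape[0]), interpolation=cv2.INTER_LINEAR)
    cv2.imwrite(out_path, cv2.addWeighted(image, 0.5, preds, 0.5, 0))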

Fixing blocking on network video streams

When the model is slower than the incoming stream, frames back up and the display lags further and further behind. The fix below uses two threads: a capture thread keeps pushing frames into a small bounded buffer, and the inference thread always takes the newest frame, discarding stale ones.

import torch
import torchvision
import argparse
from config import config
from config import update_config
from seg_hrnet import get_seg_model
import os
from torch.nn import functional as F
import cv2
import numpy as np
import time
import threading

def preprocess(img, model, device):
    def input_transform(image): 
        image = image.astype(np.float32)[:, :, ::-1]
        image = image / 255.0
        image -= mean
        image /= std
        return image
    
    def image_resize(image, long_size, label=None):
        h, w = image.shape[:2]
        if h > w:
            new_h = long_size
            new_w = int(w * long_size / h + 0.5)
        else:
            new_w = long_size
            new_h = int(h * long_size / w + 0.5)
        
        image = cv2.resize(image, (new_w, new_h), 
                           interpolation = cv2.INTER_LINEAR)
        if label is not None:
            label = cv2.resize(label, (new_w, new_h), 
                           interpolation = cv2.INTER_NEAREST)
        else:
            return image
        
        return image, label
    
    def pad_image(image, h, w, size, padvalue):
        pad_image = image.copy()
        pad_h = max(size[0] - h, 0)
        pad_w = max(size[1] - w, 0)
        if pad_h > 0 or pad_w > 0:
            pad_image = cv2.copyMakeBorder(image, 0, pad_h, 0, 
                pad_w, cv2.BORDER_CONSTANT, 
                value=padvalue)
        
        return pad_image
    
    def multi_scale_aug(image, label=None, long_size=473):
        # inference-time resize only; the random-crop path from the training
        # dataset code is not needed here
        if label is not None:
            image, label = image_resize(image, long_size, label)
            return image, label
        else:
            image = image_resize(image, long_size)
            return image

    def infer(model, image):
        size = image.size()
        # start = time.time()
        pred = model(image)
        # print("inference time:", time.time() - start)
        # F.upsample is deprecated; F.interpolate is the current equivalent
        pred = F.interpolate(input=pred, size=(size[-2], size[-1]),
                             mode='bilinear', align_corners=False)

        return pred.exp()

    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    padvalue = tuple((-1.0 * np.array(mean) / np.array(std)).tolist())
    crop_size = (473, 473)
    # sliding-window stride for images larger than crop_size
    stride_h, stride_w = crop_size[0], crop_size[1]

    
    new_img = multi_scale_aug(img)
    height, width = new_img.shape[:-1]
    new_img = input_transform(new_img)
    

    if max(height, width) <= np.min(crop_size):
        #new_img = pad_image(new_img, height, width, crop_size, padvalue)
        new_img = new_img.transpose((2, 0, 1))
        new_img = np.expand_dims(new_img, axis=0)
        new_img = torch.from_numpy(new_img).to(device)
        preds = infer(model, new_img)
        #preds = preds[:, :, 0:height, 0:width]

    else:
        if height < crop_size[0] or width < crop_size[1]:
            new_img = pad_image(new_img, height, width, crop_size, padvalue)
        new_h, new_w = new_img.shape[:-1]
        rows = int(np.ceil(1.0 * (new_h - crop_size[0]) / stride_h)) + 1
        cols = int(np.ceil(1.0 * (new_w - crop_size[1]) / stride_w)) + 1
        # 2 = number of classes; change it to match your model
        preds = torch.zeros([1, 2, new_h, new_w]).to(device)
        count = torch.zeros([1, 1, new_h, new_w]).to(device)

        for r in range(rows):
            for c in range(cols):
                h0 = r * stride_h
                w0 = c * stride_w
                h1 = min(h0 + crop_size[0], new_h)
                w1 = min(w0 + crop_size[1], new_w)
                crop_img = new_img[h0:h1, w0:w1, :]
                if h1 == new_h or w1 == new_w:
                    crop_img = pad_image(crop_img,
                                         h1 - h0,
                                         w1 - w0,
                                         crop_size,
                                         padvalue)
                crop_img = crop_img.transpose((2, 0, 1))
                crop_img = np.expand_dims(crop_img, axis=0)
                crop_img = torch.from_numpy(crop_img).to(device)
                pred = infer(model, crop_img)

                preds[:, :, h0:h1, w0:w1] += pred[:, :, 0:h1 - h0, 0:w1 - w0]
                count[:,:,h0:h1,w0:w1] += 1
        preds = preds / count
        preds = preds[:,:,:height,:width]

    return preds
        
def inference(model, pth_file, image_path, device):
    image = cv2.imread(image_path)
    ori_height, ori_width = image.shape[0], image.shape[1]

    preds = preprocess(image.copy(), model, device)

    #preds = F.upsample(preds, (ori_height, ori_width), mode='bilinear')

    preds = preds.detach().cpu().numpy().copy()
    preds = np.asarray(np.argmax(preds, axis=1), dtype=np.uint8)

    preds = preds.astype(np.uint8).transpose((1,2,0))
    preds[preds==1] = 255
    preds[preds!=255] = 0

    preds = cv2.merge([preds,preds,preds])
    image = cv2.resize(image,(preds.shape[1],preds.shape[0]),interpolation=cv2.INTER_LINEAR)
    image_merge = cv2.addWeighted(image, 0.5, preds, 0.5, 0)
    cv2.imshow(image_path, image_merge)
    cv2.waitKey(0)

def inference_dir(model, pth_file, image_path, device):
    for name in os.listdir(image_path):
        image = cv2.imread(os.path.join(image_path, name))
        ori_height, ori_width = image.shape[0], image.shape[1]

        preds = preprocess(image.copy(), model, device)

        # preds = F.upsample(preds, (ori_height, ori_width), mode='bilinear')

        preds = preds.detach().cpu().numpy().copy()
        preds = np.asarray(np.argmax(preds, axis=1), dtype=np.uint8)

        preds = preds.astype(np.uint8).transpose((1,2,0))
        preds[preds==1] = 255
        preds[preds!=255] = 0

        preds = cv2.merge([preds,preds,preds])
        image = cv2.resize(image,(preds.shape[1],preds.shape[0]),interpolation=cv2.INTER_LINEAR)
        image_merge = cv2.addWeighted(image, 0.5, preds, 0.5, 0)
        cv2.imshow(name,image_merge)
        key = cv2.waitKey(0)
        cv2.destroyWindow(name)
        if key == 27:
            break

def inference_video(model, pth_file, video_path, device, save=False):
    if video_path.endswith(".mp4"):
        vc = cv2.VideoCapture(video_path)
    else:
        vc = cv2.VideoCapture(0)

    if vc.isOpened():
        rval, frame = vc.read()
    else:
        rval = False

    start_time = time.time()
    frame_count = 1

    if save:
        fps = vc.get(cv2.CAP_PROP_FPS) 
        width, height = frame.shape[1], frame.shape[0]
        resize_ratio = 1.0 * 473 / max(width, height)
        target_size = (int(resize_ratio * width), int(resize_ratio * height))
        fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v')#mp4v
        outVideo = cv2.VideoWriter(save, fourcc,fps,target_size)

    while rval:
        rval, frame = vc.read()
        if rval == False:
            if save:
                outVideo.release()
            break
        
        ori_height, ori_width = frame.shape[0], frame.shape[1]
        # cv2.imshow("frame",frame)
        # cv2.waitKey(0)

        preds = preprocess(frame.copy(), model, device)
        #preds = F.upsample(preds, (ori_height, ori_width), mode='bilinear')

        preds = preds.detach().cpu().numpy().copy()
        preds = np.asarray(np.argmax(preds, axis=1), dtype=np.uint8)

        preds = preds.astype(np.uint8).transpose((1,2,0))
        preds[preds==1] = 255
        preds[preds!=255] = 0

        preds = cv2.merge([preds,preds,preds])
        frame = cv2.resize(frame,(preds.shape[1],preds.shape[0]),interpolation=cv2.INTER_LINEAR)
        image_merge = cv2.addWeighted(frame, 0.5, preds, 0.5, 0)
        cv2.imshow("merge", image_merge)
        key = cv2.waitKey(1)

        if save:
            image_merge = image_merge.astype(np.uint8)
            r = outVideo.write(image_merge)

        if key == 27:  # exit on ESC
            if save:
                outVideo.release()
            break
        if frame_count % 30 == 0:
            print("Frames per second: {:.2f} fps.".format(
                frame_count / (time.time() - start_time)))
        frame_count = frame_count + 1
    cv2.destroyAllWindows()

class Stack:
    def __init__(self, stack_size):
        self.items = []
        self.stack_size = stack_size
        self.flag = True
 
    def is_empty(self):
        return len(self.items) == 0
 
    def pop(self):
        return self.items.pop()
 
    def peek(self):
        if not self.is_empty():
            return self.items[len(self.items) - 1]
 
    def size(self):
        return len(self.items)
 
    def push(self, item):
        if self.size() >= self.stack_size:
            for i in range(self.size() - self.stack_size + 1):
                self.items.remove(self.items[0])
        self.items.append(item)

    def end(self):
        self.flag = False

def capture_thread(video_path, frame_buffer, lock):
    print("capture_thread start")
    vid = cv2.VideoCapture(video_path)
    if not vid.isOpened():
        raise IOError("Couldn't open webcam or video")
    while True:
        return_value, frame = vid.read()
        if return_value is not True or frame_buffer.flag is not True:
            break
        lock.acquire()
        frame_buffer.push(frame)
        lock.release()

def play_thread(frame_buffer, lock, model):
    print("detect_thread start")
    print("detect_thread frame_buffer size is", frame_buffer.size())
 
    while True:
        if frame_buffer.size() > 0:
            lock.acquire()
            frame = frame_buffer.pop()
            lock.release()
            #  算法
            ori_height, ori_width = frame.shape[0], frame.shape[1]

            preds = preprocess(frame.copy(), model, device)
            #preds = F.upsample(preds, (ori_height, ori_width), mode='bilinear')

            preds = preds.detach().cpu().numpy().copy()
            preds = np.asarray(np.argmax(preds, axis=1), dtype=np.uint8)

            preds = preds.astype(np.uint8).transpose((1,2,0))
            preds[preds==1] = 255
            preds[preds!=255] = 0

            preds = cv2.merge([preds,preds,preds])
            frame = cv2.resize(frame,(preds.shape[1],preds.shape[0]),interpolation=cv2.INTER_LINEAR)
            image_merge = cv2.addWeighted(frame, 0.5, preds, 0.5, 0)
            cv2.imshow("merge", image_merge)
            key = cv2.waitKey(1)


            key = cv2.waitKey(1)

            if key == 27:  # exit on ESC
                frame_buffer.end()
                break
            

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Segmentation inference on video streams')
    parser.add_argument('--cfg',
                        help='experiment configure file name',
                        default = "8dataset_custom_seg_hrnetv1_w18_473x473_sgd_lr7e-3_wd5e-4_bs_32_epoch100.yaml",
                        required=False,
                        type=str)
    parser.add_argument('--pth_file',type=str,default="8dataset_custom_seg_hrnetv1_w18_473x473_sgd_lr7e-3_wd5e-4_bs_32_epoch100.pth")
    parser.add_argument('--video_path',type=str)
    parser.add_argument('opts',
                        help="Modify config options using the command-line",
                        default=None,
                        nargs=argparse.REMAINDER)

    device = "cuda:0"

    args = parser.parse_args()
    pth_file = args.pth_file
    video_path = args.video_path

    update_config(config, args)
    model = get_seg_model(config)
    model.to(device)
    pretrained_dict = torch.load(pth_file, map_location=device)
    model_dict = model.state_dict()
    pretrained_dict = {k[6:]: v for k, v in pretrained_dict.items()
                        if k[6:] in model_dict.keys()}
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)
    model.eval()

    frame_buffer = Stack(3)
    lock = threading.RLock()
    t1 = threading.Thread(target=capture_thread, args=(video_path, frame_buffer, lock))
    t1.start()
    t2 = threading.Thread(target=play_thread, args=(frame_buffer, lock, model))
    t2.start()
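
For reference, the same "keep only the newest frame" idea can also be written with the standard library's queue module instead of the hand-rolled Stack. This is only an alternative sketch (capture/consume are made-up names), not the author's code; the consumer would call preprocess() and the model where indicated.

import queue
import threading

import cv2

frame_q = queue.Queue(maxsize=1)   # one slot: always holds the latest frame

def capture(video_path):
    vc = cv2.VideoCapture(video_path)
    while True:
        ok, frame = vc.read()
        if not ok:
            break
        try:
            frame_q.put_nowait(frame)
        except queue.Full:
            try:
                frame_q.get_nowait()       # drop the stale frame
            except queue.Empty:
                pass
            frame_q.put_nowait(frame)

def consume():
    while True:
        frame = frame_q.get()              # blocks until a fresh frame arrives
        # run preprocess()/model inference on `frame` here
        cv2.imshow("latest", frame)
        if cv2.waitKey(1) == 27:           # ESC to quit
            break

# threading.Thread(target=capture, args=("rtsp://...",), daemon=True).start()
# consume()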

Visualization with imgviz

import cv2
import imgviz

def vis(lbl, img):
    # lbl: (H, W) integer label map; img: grayscale or RGB image
    if len(img.shape) == 2:
        img = cv2.merge([img, img, img])
    viz = imgviz.label2rgb(
                label=lbl,
                img=imgviz.rgb2gray(img),
                font_size=15,
                loc="rb",
            )
    return viz
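
A hypothetical usage, assuming preds is the (H, W) argmax label map produced by the inference code above and image is the matching BGR frame:

import numpy as np

image = cv2.imread("test.jpg")
preds = np.zeros(image.shape[:2], dtype=np.uint8)   # stand-in for np.argmax(model_output, axis=1)[0]

viz = vis(preds, cv2.cvtColor(image, cv2.COLOR_BGR2RGB))   # imgviz expects RGB
cv2.imshow("viz", cv2.cvtColor(viz, cv2.COLOR_RGB2BGR))
cv2.waitKey(0)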