background matting v3 inference webcam

笪智志
2023-12-01

代码


import torch
from model import MattingNetwork
from torchvision import transforms
from PIL import Image
import cv2
from inference_utils import OneFrameReader
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt

# Build the matting network with the 'mobilenetv3' backbone (the original
# author notes "resnet50" as the alternative), put it in eval mode, move it
# to the GPU, and only then load the pretrained weights from disk.
model = MattingNetwork('mobilenetv3')
model = model.eval()
model = model.cuda()
model.load_state_dict(torch.load('rvm_mobilenetv3.pth'))


from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from inference_utils import VideoReader, VideoWriter

# Solid green background colour, reshaped to (3, 1, 1) so it broadcasts
# against the per-pixel foreground/alpha tensors when compositing.
bgr = torch.tensor([.47, 1, .6]).reshape(3, 1, 1).cuda()

# Recurrent states fed back into the model on every frame; all empty at start.
rec = [None, None, None, None]

# Fraction passed to the model as its downsample ratio; tune per video.
downsample_ratio = 1 / 4

import numpy as np
import cv2 as cv
from torchvision.transforms import ToTensor

# Read the video frame by frame, run the matting model on each frame,
# composite the predicted foreground onto the solid `bgr` colour and show the
# result in an OpenCV window until the stream ends or the user presses 'q'.
totensor = ToTensor()
cap = cv.VideoCapture('videos/qihang/qihang.mp4')
if not cap.isOpened():
    # Without this message a bad path or missing codec makes the script
    # fall straight through the loop and exit silently.
    print("Cannot open video source. Exiting ...")

try:
    # Inference only: disable autograd so the recurrent states carried from
    # frame to frame do not keep an ever-growing graph alive in GPU memory.
    with torch.no_grad():
        while cap.isOpened():
            ret, frame = cap.read()
            # `ret` is False once no more frames can be read.
            if not ret:
                print("Can't receive frame (stream end?). Exiting ...")
                break

            # OpenCV delivers frames in BGR order; convert to RGB, then to a
            # float tensor with a leading batch dimension on the GPU.
            frame = cv.cvtColor(frame, cv.COLOR_BGR2RGB)
            frame = totensor(frame).unsqueeze(0).cuda()

            # Cycle the recurrent states: this frame's outputs feed the next.
            fgr, pha, *rec = model(frame, *rec, downsample_ratio)

            # Alpha-composite the foreground over the solid background colour.
            com = fgr * pha + bgr * (1 - pha)

            # Scale to 8-bit, drop the batch dimension, go from (C, H, W) to
            # (H, W, C) and move to host memory as a NumPy array.
            com = com.mul(255).byte()[0]
            com = com.permute(1, 2, 0).contiguous().cpu().numpy()

            # imshow expects BGR, so convert back from RGB before displaying.
            com = cv.cvtColor(com, cv.COLOR_RGB2BGR)

            cv.imshow('com', com)
            # Mask to the low byte: some platforms set higher bits in the
            # key code returned by waitKey.
            if cv.waitKey(1) & 0xFF == ord('q'):
                break
finally:
    # Always release the capture device and close windows, even on error.
    cap.release()
    cv.destroyAllWindows()

备注

1. 如果使用 webcam,把 cv.VideoCapture('videos/qihang/qihang.mp4') 的参数改为整数 0(摄像头设备索引),注意不要加引号写成字符串 '0'。

2. OpenCV 的 cv2.imread、cv2.imwrite、cv2.imshow 所使用的图像通道顺序都是 BGR。

3. 刚接触一个代码模块时,要先把样例跑通,而不是一上来就自己手敲;手敲容易漏掉内容而跑不出结果,之后到处找原因很麻烦。

 类似资料:

相关阅读

相关文章

相关问答