Segment Anything(1)

翟功
2023-12-01

Segment Anything(1)

项目地址:https://github.com/facebookresearch/segment-anything

最新的CV大模型SAM,将会是未来各种下游任务的基石。

本系列将会介绍SAM的安装、推理、延展,特别是推理时使用到的函数接口,以及SAM能够分别与tracking和classification相结合,以便开展一系列下游任务。

SAM+tracking:https://github.com/gaomingqi/Track-Anything、https://github.com/z-x-yang/Segment-and-Track-Anything

SAM+classification:https://github.com/facebookresearch/ov-seg

SAM的基本使用

安装

环境安装起来很简单,下面几行即可。

git clone git@github.com:facebookresearch/segment-anything.git
cd segment-anything
pip install -e .
pip install opencv-python pycocotools matplotlib onnxruntime onnx

下载权重参数文件

推理

SAM是基于prompt的分割模型,输入数据可以是point(pos/neg), bbox,甚至是mask和text。但是,目前无论是demo官网还是github中的案例中,仅仅给出了三种prompt方式,分别是point, bbox和everything(自动生成)。

github中的代码里有输入mask的接口,但是笔者仅输入mask的话是不能work的,等待官方的反馈ing。

1 导入库/构建可视化函数

import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)
    
def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels==1]
    neg_points = coords[labels==0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)   
    
def show_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))    

2 初始化

import sys
sys.path.append("..")
from segment_anything import sam_model_registry, SamPredictor

# Init
sam_checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"
device = "cuda"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

# 构建预测器,并设置图片,后续将利用predictor对象进行推理
predictor = SamPredictor(sam)
predictor.set_image(image)

3 Inference

3.1 point2mask

我们可以将若干个point作为prompt输入至SAM模型,包括pos和neg两种点,分别对应的label为1和0。

# 输入正/负样本点
input_point = np.array([[500, 375], [1125, 625]])
input_label = np.array([1, 0])
# point2mask推理
masks, _, _ = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    multimask_output=False,
)
# 绘制图像
plt.figure(figsize=(10,10))
plt.imshow(image)
show_mask(masks, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.axis('off')
plt.show() 
3.2 bbox2mask

我们可以将若干个bbox作为prompt输入至SAM模型,得到对应不同的mask。

# 输入若干个bbox
input_boxes = torch.tensor([
    [75, 275, 1725, 850],
    [425, 600, 700, 875],
    [1375, 550, 1650, 800],
    [1240, 675, 1400, 750],
], device=predictor.device)
transformed_boxes = predictor.transform.apply_boxes_torch(input_boxes, image.shape[:2])
# bbox2mask推理
masks, _, _ = predictor.predict_torch(
    point_coords=None,
    point_labels=None,
    boxes=transformed_boxes,
    multimask_output=False,
)
# 绘制图像
plt.figure(figsize=(10, 10))
plt.imshow(image)
for mask in masks:
    show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
for box in input_boxes:
    show_box(box.cpu().numpy(), plt.gca())
plt.axis('off')
plt.show()
3.3 point_bbox2mask

只需结合3.1和3.2即可。

# 输入正/负样本点
input_point = np.array([[500, 375], [1125, 625]])
input_label = np.array([1, 0])
# 输入若干个bbox
input_boxes = torch.tensor([
    [75, 275, 1725, 850],
    [425, 600, 700, 875],
    [1375, 550, 1650, 800],
    [1240, 675, 1400, 750],
], device=predictor.device)
transformed_boxes = predictor.transform.apply_boxes_torch(input_boxes, image.shape[:2])
# point_bbox2mask推理
masks, _, _ = predictor.predict_torch(
    point_coords=input_point,
    point_labels=input_label,
    boxes=transformed_boxes,
    multimask_output=False,
)
# 绘制图像
plt.figure(figsize=(10, 10))
plt.imshow(image)
for mask in masks:
    show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
for box in input_boxes:
    show_box(box.cpu().numpy(), plt.gca())
plt.axis('off')
plt.show()
 类似资料: