当前位置: 首页 > 工具软件 > FLANN > 使用案例 >

OpenCV Flann

都建树
2023-12-01

OpenCV Flann

简介

本文主要记录如何在python中使用opencv的flann模块进行图像点的最近邻搜索(仅作为作者自己的记录)。

FLANN 是一个用于高维空间数据的近似最近邻搜索的库,是一个经常会被用到的库。OpenCV flann 也实现了flann,其用途主要是用在特征匹配上面。这一部分可以参考 OpenCV Python tutorial Feature Matching 。下面介绍如何使用opencv-flann进行点的最近邻搜索。

使用 opencv flann_index 进行最近邻点搜索

这个是后面查资料时,在so上看到的。发现相比于 cv2.FlannBasedMatcher ,cv2.flann_Index 用于最近邻搜索要简单得多,也更合适。因此做一个补充,并移到前面来。

代码如下:

import cv2
import numpy as np
from matplotlib import pyplot as plt

if __name__ == '__main__':
    # generate dataset
    train_dataset_pts_nx2 = np.random.rand(1000, 2) * 1000

    # create KD-Tree object
    FLANN_INDEX_KDTREE = 1  # cv2 dost not contain FLANN_INDEX_KDTREE
    index_param = {"algorithm": FLANN_INDEX_KDTREE, "trees": 1}

    # generate dataset, the input must be float32
    flann_index = cv2.flann_Index(np.float32(train_dataset_pts_nx2), 
                                  index_param)

    # generate query points
    query_pts_nx2 = np.random.rand(100, 2) * 1000

    # query points must be float32 also
    indices, dists = flann_index.knnSearch(np.float32(query_pts_nx2), 
                                           knn=1)
    indices_nx1 = np.squeeze(indices)
    dists_nx1 = np.squeeze(dists)

    assert len(indices) == len(query_pts_nx2)
    assert len(dists) == len(query_pts_nx2)

    correspondent_pts_nx2 = train_dataset_pts_nx2[indices_nx1]

    plt.scatter(train_dataset_pts_nx2[:, 0], train_dataset_pts_nx2[:, 1],
                s=10, label='dataset')
    plt.scatter(query_pts_nx2[:, 0], query_pts_nx2[:, 1],
                s=10, label='query')
    plt.scatter(correspondent_pts_nx2[:, 0], correspondent_pts_nx2[:, 1],
                s=10, label='correspondent')
    plt.legend(), plt.show()

关键点

  1. 输入的dataset array 必须是float32类型的
  2. 查询时的query points 必须是2d array,类型也必须是float32

使用opencv FlannBasedMatcher进行点的最近邻搜索

点的坐标也可以当做是一种2维的描述子,因此使用flann进行点的最近邻搜索,本质上和使用flann进行特征匹配没有区别。只是要注意传入的数据的类型和结构。

先放代码(关于其中出现的一些参数的解释,参考博客 OpenCV学习笔记-FLANN匹配器):

import cv2
import numpy as np
from matplotlib import pyplot as plt

if __name__ == '__main__':
    # generate dataset
    train_dataset_pts_nx2 = np.random.rand(1000, 2) * 1000

    # create KD-Tree object
    FLANN_INDEX_KDTREE = 1
    index_param = {"algorithm": FLANN_INDEX_KDTREE, "trees": 5}
    search_param = {"checks": 50}
    cv2_kdtree = cv2.FlannBasedMatcher(index_param, search_param)

    # reshape points from nx2 --> 1xnx2, because in OpenCV-Flann,
    # you can add different descriptors from different images at 
    # the same time. If the added array's shape is mxnxd, it indicates
    # that theses descriptors computed from m images, and each image 
    # contain n descriptors, and each descriptor has d dimension.
    # So, if you input a nx2 array, OpenCV will expand it to a 
    # nx1x2 array, which means the points come from n images, each 
    # image has only a single point.
    # The input array's shape will impact the match results, so be 
    # careful about it.
    # Another important point is that you must input a float32 array.
    train_dataset_pts_1xnx2 = np.expand_dims(train_dataset_pts_nx2, 0)
    # add dataset and train
    cv2_kdtree.add(np.float32(train_dataset_pts_1xnx2))
    cv2_kdtree.train()

    # generate query points
    query_pts_nx2 = np.random.rand(100, 2) * 1000

    # query, also input a float32 array, and set k=1 since we want 
    # a NN search only.
    match_results = cv2_kdtree.knnMatch(np.float32(query_pts_nx2), k=1)

    # the return value is a 2-level nested list of cv2.DMatch
    assert len(match_results) == len(query_pts_nx2)
    assert len(match_results[0]) == 1
    correspondent_pts_nx2 = []
    for query_pt, nn_match_results in zip(query_pts_nx2, match_results):
        for match_result in nn_match_results:
            assert isinstance(match_result, cv2.DMatch)

            # imgIdx is the index of first axis of the train dataset 
            # in shape mxnxd
            # trainIdx is the index of second axis the of train dataset 
            # in shape mxnxd
            # we can use these two indices to get the correspondent
            # descriptor value(point) from train dataset.
            print(match_result.imgIdx, match_result.trainIdx)
            correspondent_pt = train_dataset_pts_1xnx2[
                match_result.imgIdx, match_result.trainIdx]
            print(match_result.distance)
            print(np.linalg.norm(correspondent_pt - query_pt))

            correspondent_pts_nx2.append(correspondent_pt)

    correspondent_pts_nx2 = np.asarray(correspondent_pts_nx2)

    plt.scatter(train_dataset_pts_nx2[:, 0], train_dataset_pts_nx2[:, 1],
                s=10, label='dataset')
    plt.scatter(query_pts_nx2[:, 0], query_pts_nx2[:, 1],
                s=10, label='query')
    plt.scatter(correspondent_pts_nx2[:, 0], correspondent_pts_nx2[:, 1],
                s=10, label='correspondent')
    plt.legend(), plt.show()

关键点:

  1. 输入的dataset的shape,最好是手动调整成为一个3d array,保证明确性
  2. 输入的dataset array 必须是float32类型的
  3. 查询时的query points 必须是2d array,类型也必须是float32
 类似资料:

相关阅读

相关文章

相关问答