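This post records how to use OpenCV's flann module from Python to run nearest-neighbor searches over image points (written mainly as the author's own notes).
FLANN is a widely used library for approximate nearest-neighbor search in high-dimensional spaces. OpenCV ships its own flann module, which is mainly used for feature matching; for that use case, see the Feature Matching chapter of the OpenCV Python tutorial. The rest of this post shows how to use opencv-flann for nearest-neighbor search over points.
I came across the approach in this first part on Stack Overflow while looking things up later. Compared with cv2.FlannBasedMatcher, cv2.flann_Index is much simpler to use for nearest-neighbor search and is a better fit, so I am adding it as a supplement and moving it to the front.
The code is as follows: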
import cv2
import numpy as np
from matplotlib import pyplot as plt

if __name__ == '__main__':
    # generate the dataset
    train_dataset_pts_nx2 = np.random.rand(1000, 2) * 1000
    # KD-tree index parameters
    FLANN_INDEX_KDTREE = 1  # cv2 does not expose FLANN_INDEX_KDTREE
    index_param = {"algorithm": FLANN_INDEX_KDTREE, "trees": 1}
    # build the index; the input must be float32
    flann_index = cv2.flann_Index(np.float32(train_dataset_pts_nx2),
                                  index_param)
    # generate query points
    query_pts_nx2 = np.random.rand(100, 2) * 1000
    # the query points must be float32 as well
    indices, dists = flann_index.knnSearch(np.float32(query_pts_nx2),
                                           knn=1)
    # knnSearch returns one row per query point; drop the knn axis
    indices_nx1 = np.squeeze(indices)
    dists_nx1 = np.squeeze(dists)
    assert len(indices) == len(query_pts_nx2)
    assert len(dists) == len(query_pts_nx2)
    # look up the matched dataset point for every query point
    correspondent_pts_nx2 = train_dataset_pts_nx2[indices_nx1]
    plt.scatter(train_dataset_pts_nx2[:, 0], train_dataset_pts_nx2[:, 1],
                s=10, label='dataset')
    plt.scatter(query_pts_nx2[:, 0], query_pts_nx2[:, 1],
                s=10, label='query')
    plt.scatter(correspondent_pts_nx2[:, 0], correspondent_pts_nx2[:, 1],
                s=10, label='correspondent')
    plt.legend()
    plt.show()
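Since FLANN performs an approximate search, it can be useful to check its output against an exact answer. The following is a minimal sketch of a brute-force reference using NumPy only; the helper name brute_force_nn is my own and is not part of OpenCV. It returns squared Euclidean distances because, as far as I know, that is what knnSearch reports for the default L2 distance; treat that as an assumption and verify it on your own data.

```python
import numpy as np


def brute_force_nn(dataset_nxd, query_mxd):
    """Exact 1-NN by exhaustive search.

    Returns (indices, squared_distances), one entry per query point.
    Hypothetical helper for sanity checks, not an OpenCV function.
    """
    dataset = np.asarray(dataset_nxd, dtype=np.float32)
    query = np.asarray(query_mxd, dtype=np.float32)
    # Pairwise differences via broadcasting: shape (m, n, d).
    diff = query[:, None, :] - dataset[None, :, :]
    sq_dists_mxn = np.sum(diff * diff, axis=2)
    # Index of the closest dataset point for every query point.
    indices = np.argmin(sq_dists_mxn, axis=1)
    sq_dists = sq_dists_mxn[np.arange(len(query)), indices]
    return indices, sq_dists


rng = np.random.default_rng(0)
dataset_pts = rng.random((1000, 2)) * 1000
query_pts = rng.random((100, 2)) * 1000
exact_indices, exact_sq_dists = brute_force_nn(dataset_pts, query_pts)

# To compare with the FLANN listing (run on the same points), something like
#   agreement = np.mean(exact_indices == indices_nx1)
# gives the fraction of queries for which FLANN found the exact neighbor.
```

The pairwise matrix needs O(m * n) memory, so this is only meant for small checks like the 1000 x 100 case used here.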
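Key points
A point's coordinates can be treated as a 2-D descriptor, so nearest-neighbor search over points with flann is essentially the same as feature matching with flann. You only need to pay attention to the dtype and shape of the data you pass in.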
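The dtype and shape requirement can be made explicit with a small conversion step. The sketch below uses NumPy only; the helper name as_flann_points is hypothetical, and requiring a C-contiguous array is my own precaution rather than a documented OpenCV requirement.

```python
import numpy as np


def as_flann_points(points):
    """Return points as a C-contiguous float32 array of shape (n, d).

    Hypothetical helper: np.random.rand and most NumPy code produce
    float64, while the FLANN wrappers in this post expect float32.
    """
    pts = np.ascontiguousarray(points, dtype=np.float32)
    if pts.ndim != 2:
        raise ValueError("expected an (n, d) array, got shape %s" % (pts.shape,))
    return pts


pts_float64 = np.random.rand(5, 2) * 1000   # float64 by default
pts_float32 = as_flann_points(pts_float64)  # float32, shape (5, 2)
```

Calling the helper once on both the dataset and the query points keeps the float32 conversion in one place instead of repeating np.float32(...) at every call site.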
Code first (for an explanation of some of the parameters used, see the blog post OpenCV学习笔记-FLANN匹配器):
import cv2
import numpy as np
from matplotlib import pyplot as plt

if __name__ == '__main__':
    # generate the dataset
    train_dataset_pts_nx2 = np.random.rand(1000, 2) * 1000
    # create the KD-tree based matcher
    FLANN_INDEX_KDTREE = 1
    index_param = {"algorithm": FLANN_INDEX_KDTREE, "trees": 5}
    search_param = {"checks": 50}
    cv2_kdtree = cv2.FlannBasedMatcher(index_param, search_param)
    # Reshape the points from nx2 --> 1xnx2, because OpenCV-Flann lets
    # you add descriptors from several images at the same time. If the
    # added array has shape mxnxd, the descriptors were computed from
    # m images, each image contributes n descriptors, and each
    # descriptor has d dimensions.
    # So if you pass an nx2 array, OpenCV will expand it to an nx1x2
    # array, which means the points come from n images and each image
    # has only a single point.
    # The shape of the input array affects the match results, so be
    # careful about it.
    # Another important point is that the input must be a float32 array.
    train_dataset_pts_1xnx2 = np.expand_dims(train_dataset_pts_nx2, 0)
    # add the dataset and train
    cv2_kdtree.add(np.float32(train_dataset_pts_1xnx2))
    cv2_kdtree.train()
    # generate query points
    query_pts_nx2 = np.random.rand(100, 2) * 1000
    # the query must also be a float32 array; set k=1 since we only
    # want a nearest-neighbor search
    match_results = cv2_kdtree.knnMatch(np.float32(query_pts_nx2), k=1)
    # the return value is a 2-level nested list of cv2.DMatch
    assert len(match_results) == len(query_pts_nx2)
    assert len(match_results[0]) == 1
    correspondent_pts_nx2 = []
    for query_pt, nn_match_results in zip(query_pts_nx2, match_results):
        for match_result in nn_match_results:
            assert isinstance(match_result, cv2.DMatch)
            # imgIdx is the index along the first axis of the train
            # dataset in shape mxnxd
            # trainIdx is the index along the second axis of the train
            # dataset in shape mxnxd
            # together, these two indices locate the corresponding
            # descriptor value (point) in the train dataset
            print(match_result.imgIdx, match_result.trainIdx)
            correspondent_pt = train_dataset_pts_1xnx2[
                match_result.imgIdx, match_result.trainIdx]
            # the reported distance vs. the distance recomputed by hand
            print(match_result.distance)
            print(np.linalg.norm(correspondent_pt - query_pt))
            correspondent_pts_nx2.append(correspondent_pt)
    correspondent_pts_nx2 = np.asarray(correspondent_pts_nx2)
    plt.scatter(train_dataset_pts_nx2[:, 0], train_dataset_pts_nx2[:, 1],
                s=10, label='dataset')
    plt.scatter(query_pts_nx2[:, 0], query_pts_nx2[:, 1],
                s=10, label='query')
    plt.scatter(correspondent_pts_nx2[:, 0], correspondent_pts_nx2[:, 1],
                s=10, label='correspondent')
    plt.legend()
    plt.show()
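The nested loop over cv2.DMatch objects can be collapsed into a single indexing step once imgIdx and trainIdx are collected into arrays. The sketch below assumes k=1 and that every query has exactly one match; the helper name matches_to_arrays is my own. So that it runs without OpenCV installed, the demo uses types.SimpleNamespace objects as stand-ins that carry the same three attributes as cv2.DMatch.

```python
from types import SimpleNamespace

import numpy as np


def matches_to_arrays(match_results, train_mxnxd):
    """Gather matched points and distances from k=1 knnMatch-style results.

    match_results: list of 1-element lists whose items expose imgIdx,
    trainIdx and distance (as cv2.DMatch does). Hypothetical helper.
    """
    img_idx = np.array([m[0].imgIdx for m in match_results], dtype=np.intp)
    train_idx = np.array([m[0].trainIdx for m in match_results], dtype=np.intp)
    dists = np.array([m[0].distance for m in match_results], dtype=np.float32)
    # Index the first two axes (image, descriptor) for all matches at once.
    correspondents = np.asarray(train_mxnxd)[img_idx, train_idx]
    return correspondents, dists


# Stand-in data: one "image" holding two 2-D points.
train_1x2x2 = np.array([[[0.0, 0.0], [10.0, 10.0]]])
fake_matches = [
    [SimpleNamespace(imgIdx=0, trainIdx=1, distance=1.0)],
    [SimpleNamespace(imgIdx=0, trainIdx=0, distance=2.0)],
]
matched_pts, matched_dists = matches_to_arrays(fake_matches, train_1x2x2)
```

With the real listing, the call would presumably be matches_to_arrays(match_results, train_dataset_pts_1xnx2), which yields the same points as the loop.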
Key points: