参考资料:
1. k-d tree算法的研究
2. Python手撸机器学习系列(十一):KNN之kd树实现
完整代码: https://github.com/nnzzll/NaiveKDTree
/// A 3-D point that also remembers its position in the source dataset.
/// Coordinates are reachable through operator[] so generic code (KDTree) can
/// index them by axis number.
template <typename T>
struct Point3D
{
    T x, y, z;
    int index; // Index of this point in the source dataset (-1 when not given); KDTree stores it in each node.
    Point3D() : x(0), y(0), z(0), index(-1) {}
    Point3D(T a, T b, T c) : x(a), y(b), z(c), index(-1) {}
    Point3D(T a, T b, T c, int idx) : x(a), y(b), z(c), index(idx) {}
    /// Mutable coordinate access by axis: 0 -> x, 1 -> y, any other value -> z.
    inline T &operator[](int i) { return i == 0 ? x : i == 1 ? y : z; }
    /// Read-only coordinate access with the same axis mapping, usable on const points.
    inline const T &operator[](int i) const { return i == 0 ? x : i == 1 ? y : z; }
};
/// A 2-D point that also remembers its position in the source dataset.
/// Coordinates are reachable through operator[] so generic code (KDTree) can
/// index them by axis number.
template <typename T>
struct Point2D
{
    T x, y;
    int index; // Index of this point in the source dataset (-1 when not given); KDTree stores it in each node.
    Point2D() : x(0), y(0), index(-1) {}
    Point2D(T a, T b) : x(a), y(b), index(-1) {}
    Point2D(T a, T b, int idx) : x(a), y(b), index(idx) {}
    /// Mutable coordinate access by axis: 0 -> x, any other value -> y.
    inline T &operator[](int i) { return i == 0 ? x : y; }
    /// Read-only coordinate access with the same axis mapping, usable on const points.
    inline const T &operator[](int i) const { return i == 0 ? x : y; }
};
/// One node of the k-d tree. It does not hold a point itself, only the index
/// of the point it represents plus the dimension it splits on.
struct KDNode
{
    int index;     // index of the point represented by this node
    int axis;      // splitting dimension (KDTree::build uses -1 for single-point leaves)
    KDNode *left;  // subtree built from points ordered before the median on `axis`
    KDNode *right; // subtree built from points ordered after the median on `axis`

    KDNode(int index, int axis, KDNode *left = nullptr, KDNode *right = nullptr)
        : index(index), axis(axis), left(left), right(right)
    {
    }
};
/// k-d tree over a private copy of a point set.
/// T must provide `operator[](int)` for coordinate access by axis and an
/// `int index` member (as Point2D / Point3D do); nodes refer to points by that index.
template <class T>
class KDTree
{
private:
    int ndim;                       // number of dimensions, set by the constructor
    KDNode *root;                   // root node; nodes are allocated with `new` in build()
    KDNode *build(std::vector<T> &);
    std::set<int> visited;          // point indices already reached during a search (used for backtracking)
    std::stack<KDNode *> queueNode; // search path: nodes pushed by dfs(), popped while backtracking
    std::vector<T> m_data;          // copy of the input points, looked up via each node's `index`
    void release(KDNode *);
    void printNode(KDNode *);
    int chooseAxis(std::vector<T> &);
    void dfs(KDNode *, T);
    // Point-to-point distance.
    inline double distanceT(KDNode *, T);
    inline double distanceT(int, T);
    // Point-to-hyperplane distance.
    inline double distanceP(KDNode *, T);
    // Whether the parent's splitting hyperplane falls inside the hypersphere around the
    // query (findNearestPoint passes the current best distance as the double argument).
    inline bool checkParent(KDNode *, T, double);

public:
    KDTree(std::vector<T> &, int);
    ~KDTree();
    // Copying is disabled: `root` is a raw pointer to nodes created with `new`, so a
    // member-wise copy would make two trees share the same nodes.
    // NOTE(review): ~KDTree()/release() are not shown here; if they free the nodes
    // (as their names suggest), a copy would double-free them.
    KDTree(const KDTree &) = delete;
    KDTree &operator=(const KDTree &) = delete;
    void Print();
    int findNearestPoint(T);
};
/// Builds the tree over `data`.
/// @param data points to index. A private copy is kept in `m_data`, and each tree
///             node later looks its point up as `m_data[node->index]`, so every
///             point's `index` member is expected to equal its position in `data`.
/// @param dim  number of dimensions, stored in `ndim`.
template <class T>
KDTree<T>::KDTree(std::vector<T> &data, int dim)
    : ndim(dim), m_data(data)
{
    // Recursively split the points into a binary tree.
    root = build(data);
}
/// Recursively builds the subtree for `data` and returns its root (nullptr for no points).
/// The points are copied, ordered along the chosen axis, and the median becomes this
/// node; points ordered before it form the left subtree, points after it the right one.
template <class T>
KDNode *KDTree<T>::build(std::vector<T> &data)
{
    if (data.empty())
    {
        return nullptr;
    }

    std::vector<T> sorted(data);

    // A single point cannot be split, so leaves get axis -1. Otherwise the axis is chosen
    // by chooseAxis (described as based on the per-dimension variance; not shown here).
    const int splitAxis = (data.size() > 1) ? chooseAxis(sorted) : -1;
    std::sort(sorted.begin(), sorted.end(),
              [splitAxis](T lhs, T rhs) { return lhs[splitAxis] < rhs[splitAxis]; });

    const int median = static_cast<int>(data.size() / 2);
    std::vector<T> lower(sorted.begin(), sorted.begin() + median);
    std::vector<T> upper(sorted.begin() + median + 1, sorted.end());

    // Left subtree is built before the right one.
    KDNode *lowerChild = build(lower);
    KDNode *upperChild = build(upper);
    return new KDNode(sorted[median].index, splitAxis, lowerChild, upperChild);
}
最近邻搜索(回溯)部分的实现参考了文献[1]:
/// Returns the `index` of the stored point nearest to `pt`, or -1 if the tree is empty.
/// The search first descends from the root toward `pt` (dfs), then unwinds the recorded
/// path. At each step, if checkParent reports that the parent's splitting hyperplane is
/// within the current best distance, the parent's other child is searched as well.
template <class T>
int KDTree<T>::findNearestPoint(T pt)
{
    // Reset all per-query state. `visited` must be cleared too: dfs() stops at any node
    // whose index is already recorded, so leftovers from a previous query would make
    // it skip the root and this method would return -1.
    while (!queueNode.empty())
        queueNode.pop();
    visited.clear();

    double min_dist = DBL_MAX;
    int resNodeIdx = -1;
    dfs(root, pt);
    while (!queueNode.empty())
    {
        KDNode *curNode = queueNode.top();
        queueNode.pop();
        const double dist = distanceT(curNode, pt);
        if (dist < min_dist)
        {
            min_dist = dist;
            resNodeIdx = curNode->index;
        }
        if (queueNode.empty())
            break;

        // dfs() pushes one root-to-leaf path, so the new top is curNode's parent.
        KDNode *parentNode = queueNode.top();
        if (checkParent(parentNode, pt, min_dist))
        {
            // Pick the sibling by pointer identity rather than by comparing coordinates:
            // build() may place points equal to the median in the left subtree, and a
            // coordinate test would then re-enter the already visited left side and
            // never explore the right one.
            KDNode *sibling = (curNode == parentNode->left) ? parentNode->right : parentNode->left;
            dfs(sibling, pt);
        }
    }
    return resNodeIdx;
}
/// Descends from `node` toward the region containing `pt`. Every newly reached node is
/// pushed onto `queueNode` and its index recorded in `visited`; the walk stops at a
/// null child or at a node that has already been visited.
template <class T>
void KDTree<T>::dfs(KDNode *node, T pt)
{
    // At most one child is followed per node, so the walk is a simple loop.
    for (KDNode *cur = node; cur != nullptr;)
    {
        if (visited.count(cur->index) != 0)
            return;
        queueNode.push(cur);
        visited.insert(cur->index);

        KDNode *const lo = cur->left;
        KDNode *const hi = cur->right;
        KDNode *next = nullptr;
        if (lo != nullptr && pt[cur->axis] <= m_data[cur->index][cur->axis])
        {
            next = lo; // query is on or below the split value
        }
        else if (hi != nullptr && pt[cur->axis] >= m_data[cur->index][cur->axis])
        {
            next = hi; // query is on or above the split value
        }
        else if ((lo == nullptr) != (hi == nullptr))
        {
            // The preferred side is missing but the node has exactly one child: descend
            // into it anyway, otherwise candidate neighbours stored there would be missed.
            next = (lo != nullptr) ? lo : hi;
        }
        cur = next;
    }
}
以下与VTK官方样例ClosestNPoints进行对比验证。
在点云规模不超过1000个点时,性能与VTK的KdTree大致相当。
int main()
{
int N = 500;
// Create some random points
vtkNew<vtkPointSource> pointSource;
pointSource->SetNumberOfPoints(N);
pointSource->Update();
std::vector<Point3D<double>> datasets;
vtkPoints *randPts = pointSource->GetOutput()->GetPoints();
for (vtkIdType i = 0; i < N; i++)
{
double pts[3];
randPts->GetPoint(i, pts);
datasets.push_back(Point3D<double>(pts[0], pts[1], pts[2], i));
// std::cout << pts[0] << "," << pts[1] << "," << pts[2] << std::endl;
}
auto t1 = std::chrono::duration_cast<std::chrono::milliseconds>(
std::chrono::system_clock::now().time_since_epoch())
.count();
// Create the tree
vtkNew<vtkKdTreePointLocator> pointTree;
pointTree->SetDataSet(pointSource->GetOutput());
pointTree->BuildLocator();
// Find the k closest points to (0,0,0)
unsigned int k = 1;
vtkNew<vtkPointSource> testSource;
testSource->SetNumberOfPoints(1);
testSource->Update();
double testPoint[3];
testSource->GetOutput()->GetPoints()->GetPoint(0, testPoint);
vtkNew<vtkIdList> result;
std::cout << "Test Point: " << testPoint[0] << "," << testPoint[1] << "," << testPoint[2] << std::endl;
pointTree->FindClosestNPoints(k, testPoint, result);
for (vtkIdType i = 0; i < k; i++)
{
vtkIdType point_ind = result->GetId(i);
double p[3];
pointSource->GetOutput()->GetPoint(point_ind, p);
std::cout << "Closest point " << i << ": Point " << point_ind << ": ("
<< p[0] << ", " << p[1] << ", " << p[2] << ")" << std::endl;
}
auto t2 = std::chrono::duration_cast<std::chrono::milliseconds>(
std::chrono::system_clock::now().time_since_epoch())
.count();
// Should return:
// Closest point 0: Point 2: (-0.136162, -0.0276359, 0.0369441)
// std::vector<Point2D<double>> datasets = {Point2D<double>(7, 2, 0),
// Point2D<double>(5, 4, 1),
// Point2D<double>(9, 6, 2),
// Point2D<double>(2, 3, 3),
// Point2D<double>(4, 7, 4),
// Point2D<double>(8, 1, 5)};
KDTree<Point3D<double>> tree(datasets, 3);
// tree.Print();
std::cout << tree.findNearestPoint(Point3D<double>(testPoint[0], testPoint[1], testPoint[2])) << std::endl;
auto t3 = std::chrono::duration_cast<std::chrono::milliseconds>(
std::chrono::system_clock::now().time_since_epoch())
.count();
std::cout << "VTK Time:" << t2 - t1 << " ms" << std::endl;
std::cout << "MY Time:" << t3 - t2 << " ms" << std::endl;
return EXIT_SUCCESS;
}
Test Point: 0.117163,-0.205549,0.352397
Closest point 0: Point 474: (0.12327, -0.22358, 0.322906)
474
VTK Time:4 ms
MY Time:2 ms