当前位置: 首页 > 工具软件 > kdtree > 使用案例 >

KDTree的C++实现

姚信鸥
2023-12-01

KDTree原理:

请参考
1. k-d tree算法的研究
2. Python手撸机器学习系列(十一):KNN之kd树实现
完整代码: https://github.com/nnzzll/NaiveKDTree

C++实现

点的结构

template <typename T>
struct Point3D
{
    T x, y, z;
    // Index of this point in the original data set; stored in the point itself
    // so the KD-tree nodes can refer to points by index.
    int index;
    Point3D() : x(0), y(0), z(0), index(-1) {}
    Point3D(T a, T b, T c) : x(a), y(b), z(c), index(-1) {}
    Point3D(T a, T b, T c, int idx) : x(a), y(b), z(c), index(idx) {}

    // Coordinate access by axis: 0 -> x, 1 -> y, any other value -> z.
    T &operator[](int i) { return i == 0 ? x : i == 1 ? y : z; }
    // Read-only access so const points (and const references to points) can be indexed too.
    const T &operator[](int i) const { return i == 0 ? x : i == 1 ? y : z; }
};

template <typename T>
struct Point2D
{
    T x, y;
    // Index of this point in the original data set (see Point3D::index).
    int index;
    Point2D() : x(0), y(0), index(-1) {}
    Point2D(T a, T b) : x(a), y(b), index(-1) {}
    Point2D(T a, T b, int idx) : x(a), y(b), index(idx) {}

    // Coordinate access by axis: 0 -> x, any other value -> y.
    T &operator[](int i) { return i == 0 ? x : y; }
    // Read-only access so const points can be indexed too.
    const T &operator[](int i) const { return i == 0 ? x : y; }
};

KDTree结点的结构

// A single KD-tree node. It holds no coordinates of its own: only the index of
// the point it represents, the axis it splits on, and its two children.
struct KDNode
{
	int index;     // index of the point stored at this node
	int axis;      // dimension this node splits on (-1 for leaves)
	KDNode *left;  // subtree with coordinates <= this node's on `axis`
	KDNode *right; // subtree with coordinates >= this node's on `axis`

	KDNode(int index, int axis, KDNode *left = nullptr, KDNode *right = nullptr)
		: index(index), axis(axis), left(left), right(right)
	{
	}
};

KDTree的结构

// KD-tree over points of type T, where T is indexable by axis (operator[]) and
// carries an integer `index` member.
template <class T>
class KDTree
{
private:
	int ndim;                       // number of dimensions of the points
	KDNode *root;                   // root of the tree; nodes are allocated with `new` in build()
	KDNode *build(std::vector<T> &);
	std::set<int> visited;          // node indices already visited during a search (used for backtracking)
	std::stack<KDNode *> queueNode; // search path: nodes still to be examined, deepest on top
	std::vector<T> m_data;          // private copy of the input points; KDNode::index values index into it

	void release(KDNode *);
	void printNode(KDNode *);
	int chooseAxis(std::vector<T> &);
	void dfs(KDNode *, T);
	// Point-to-point distance.
	inline double distanceT(KDNode *, T);
	inline double distanceT(int, T);
	// Distance from a point to a node's splitting hyperplane.
	inline double distanceP(KDNode *, T);
	// Whether the parent's splitting hyperplane intersects the current search hypersphere.
	inline bool checkParent(KDNode *, T, double);

public:
	KDTree(std::vector<T> &, int);
	~KDTree();
	// The tree holds raw pointers to nodes it allocated itself; a member-wise copy
	// would make two trees share the same nodes, so copying is disabled.
	KDTree(const KDTree &) = delete;
	KDTree &operator=(const KDTree &) = delete;
	void Print();
	int findNearestPoint(T);
};

KDTree的构造函数

// Builds the tree over a private copy of `data` (the caller's vector is not modified).
// `dim` is the dimensionality of the points.
//
// Tree nodes and the nearest-point search refer to points through `m_data[point.index]`,
// so each stored point's `index` must equal its position in `m_data`. The indices of the
// copy are normalized to positions here. Otherwise default-constructed points (index -1)
// or arbitrarily numbered points would make those lookups read the wrong point or go out
// of bounds. As a result, findNearestPoint() reports positions in `data`.
template <class T>
KDTree<T>::KDTree(std::vector<T> &data, int dim)
	: ndim(dim), root(nullptr), m_data(data)
{
	for (std::size_t i = 0; i < m_data.size(); ++i)
		m_data[i].index = static_cast<int>(i);
	root = build(m_data); // recursively construct the binary tree
}

// Recursively builds a balanced KD-tree over `data` and returns its root
// (nullptr when `data` is empty). `data` itself is not modified.
template <class T>
KDNode *KDTree<T>::build(std::vector<T> &data)
{
	if (data.empty())
		return nullptr;

	std::vector<T> temp = data; // work on a copy so the caller's vector keeps its order
	int mid_index = static_cast<int>(data.size() / 2); // position of the median (the split point)
	// chooseAxis() picks the split dimension from the per-dimension variance. A single
	// point cannot be split, so leaves get axis -1.
	int axis = data.size() > 1 ? chooseAxis(temp) : -1;

	// Only the median needs to be in its sorted position. nth_element puts it at
	// mid_index and guarantees every element before it is <= and every element after
	// it is >= on `axis`. That is exactly what the split needs, in O(n) instead of the
	// O(n log n) of a full sort.
	std::nth_element(temp.begin(), temp.begin() + mid_index, temp.end(),
					 [axis](T a, T b) { return a[axis] < b[axis]; });

	std::vector<T> leftData(temp.begin(), temp.begin() + mid_index);
	std::vector<T> rightData(temp.begin() + mid_index + 1, temp.end());

	KDNode *leftNode = build(leftData);
	KDNode *rightNode = build(rightData);
	return new KDNode(temp[mid_index].index, axis, leftNode, rightNode);
}

最近邻搜索

搜索算法参考文献[1](k-d tree算法的研究):先自顶向下找到查询点所在的叶子区域,再沿搜索路径回溯;若当前最近距离构成的超球体与父结点的分割超平面相交,则继续搜索父结点的另一棵子树。

// Nearest-neighbour query. Returns the `index` stored in the node of the closest
// point found (distance as computed by distanceT), or -1 if the tree is empty.
template <class T>
int KDTree<T>::findNearestPoint(T pt)
{
	// Reset all per-query search state. `visited` must be cleared as well: otherwise
	// every node reached by an earlier query is skipped by dfs(), and later queries
	// return wrong results (or -1).
	while (!queueNode.empty())
		queueNode.pop();
	visited.clear();

	double min_dist = DBL_MAX;
	int resNodeIdx = -1;
	dfs(root, pt); // initial descent towards the region containing pt

	while (!queueNode.empty())
	{
		KDNode *curNode = queueNode.top();
		queueNode.pop();
		const double dist = distanceT(curNode, pt);
		if (dist < min_dist)
		{
			min_dist = dist;
			resNodeIdx = curNode->index;
		}

		if (queueNode.empty())
			continue;

		// dfs() pushes a node and then descends into one child, so the entry now
		// on top of the stack is curNode's parent.
		KDNode *parentNode = queueNode.top();
		// If the best-so-far hypersphere crosses the parent's splitting hyperplane,
		// a closer point may lie in the parent's other subtree: search it as well.
		if (checkParent(parentNode, pt, min_dist))
		{
			// Pick the sibling by identity rather than by comparing coordinates.
			// Points equal to the parent on the split axis can sit in either subtree,
			// and a coordinate test would send the search back into curNode's own
			// (already visited) side, skipping the real sibling.
			KDNode *sibling = (curNode == parentNode->left) ? parentNode->right : parentNode->left;
			dfs(sibling, pt);
		}
	}
	return resNodeIdx;
}

// Walks down from `node` towards the region containing `pt`, pushing every
// not-yet-visited node onto the search-path stack and recording it in `visited`.
template <class T>
void KDTree<T>::dfs(KDNode *node, T pt)
{
	if (node == nullptr || visited.find(node->index) != visited.end())
		return;
	queueNode.push(node);
	visited.insert(node->index);

	// Leaves are built with axis == -1, which is not a valid coordinate index for
	// general point types. There is nowhere further to descend anyway.
	if (node->left == nullptr && node->right == nullptr)
		return;

	const int axis = node->axis;
	if (pt[axis] <= m_data[node->index][axis] && node->left)
		dfs(node->left, pt);
	else if (pt[axis] >= m_data[node->index][axis] && node->right)
		dfs(node->right, pt);
	// A node with only one child cannot be routed by comparing coordinates alone:
	// pt may fall on the side that has no child, which would skip candidate
	// neighbours. Descend into whichever child exists.
	else if ((node->left == nullptr) ^ (node->right == nullptr))
	{
		dfs(node->left, pt);
		dfs(node->right, pt);
	}
}

测试

与VTK官方样例ClosestNPoints进行验证对比。
点云个数在1000以内时,性能与VTK的KdTree大致相当(注:下面的计时是毫秒精度的单次测量,包含建树和一次查询,仅供粗略参考)。

int main()
{
    int N = 500;
    // Create some random points
    vtkNew<vtkPointSource> pointSource;
    pointSource->SetNumberOfPoints(N);
    pointSource->Update();

    std::vector<Point3D<double>> datasets;
    vtkPoints *randPts = pointSource->GetOutput()->GetPoints();
    for (vtkIdType i = 0; i < N; i++)
    {
        double pts[3];
        randPts->GetPoint(i, pts);
        datasets.push_back(Point3D<double>(pts[0], pts[1], pts[2], i));
        // std::cout << pts[0] << "," << pts[1] << "," << pts[2] << std::endl;
    }

    auto t1 = std::chrono::duration_cast<std::chrono::milliseconds>(
                  std::chrono::system_clock::now().time_since_epoch())
                  .count();
    
    // Create the tree
    vtkNew<vtkKdTreePointLocator> pointTree;
    pointTree->SetDataSet(pointSource->GetOutput());
    pointTree->BuildLocator();

    // Find the k closest points to (0,0,0)
    unsigned int k = 1;
    vtkNew<vtkPointSource> testSource;
    testSource->SetNumberOfPoints(1);
    testSource->Update();
    double testPoint[3];
    testSource->GetOutput()->GetPoints()->GetPoint(0, testPoint);
    vtkNew<vtkIdList> result;
    std::cout << "Test Point: " << testPoint[0] << "," << testPoint[1] << "," << testPoint[2] << std::endl;

    pointTree->FindClosestNPoints(k, testPoint, result);

    for (vtkIdType i = 0; i < k; i++)
    {
        vtkIdType point_ind = result->GetId(i);
        double p[3];
        pointSource->GetOutput()->GetPoint(point_ind, p);
        std::cout << "Closest point " << i << ": Point " << point_ind << ": ("
                  << p[0] << ", " << p[1] << ", " << p[2] << ")" << std::endl;
    }
    auto t2 = std::chrono::duration_cast<std::chrono::milliseconds>(
                  std::chrono::system_clock::now().time_since_epoch())
                  .count();

    // Should return:
    // Closest point 0: Point 2: (-0.136162, -0.0276359, 0.0369441)

    // std::vector<Point2D<double>> datasets = {Point2D<double>(7, 2, 0),
    //                                       Point2D<double>(5, 4, 1),
    //                                       Point2D<double>(9, 6, 2),
    //                                       Point2D<double>(2, 3, 3),
    //                                       Point2D<double>(4, 7, 4),
    //                                       Point2D<double>(8, 1, 5)};
    KDTree<Point3D<double>> tree(datasets, 3);
    // tree.Print();
    std::cout << tree.findNearestPoint(Point3D<double>(testPoint[0], testPoint[1], testPoint[2])) << std::endl;
    auto t3 = std::chrono::duration_cast<std::chrono::milliseconds>(
                  std::chrono::system_clock::now().time_since_epoch())
                  .count();
    std::cout << "VTK Time:" << t2 - t1 << " ms" << std::endl;
    std::cout << "MY Time:" << t3 - t2 << " ms" << std::endl;
    return EXIT_SUCCESS;
}
运行结果:

Test Point: 0.117163,-0.205549,0.352397
Closest point 0: Point 474: (0.12327, -0.22358, 0.322906)
474
VTK Time:4 ms
MY Time:2 ms
 类似资料: