VS中需要安装Eigen包,可自行百度下载,添加至属性——包含项目即可。
#ifndef KD_TREE_H
#define KD_TREE_H
#define lson (rt << 1)//左节点
#define rson (rt << 1 | 1)//右节点
#include
#include
#include <Eigen/Dense>
#include
using std::vector;
using Eigen::MatrixXf;
const int N = 50005;
const int k = 2; //2D-tree
struct Node {
int feature[2];//feature[0] = x, feature[1] = y
static int idx;
Node(int x0, int y0) {
feature[0] = x0;
feature[1] = y0;
}
bool operator < (const Node &u) const {
return feature[idx] < u.feature[idx];
}
//TOOD =hao
Node() {
feature[0] = 0;
feature[0] = 0;
}
};
class kdTree {
public:
kdTree();
~kdTree();
void clean();
void read_in(float* ary_x, float* ary_y, int len);
void read_in(vector points_array);
void build(int l, int r, int rt, int dept);
int find_nearest_point(float x, float y, Node &res, int& dist, int boundary);
int find_nearest_k_points(float x, float y, int k, vector &res, vector &dist, int boundary);
int distance(const Node& x, const Node& y);
private:
void query(const Node& p, Node& res, int& dist, int rt, int dept);
void query_k(const Node& p, int k, vector &res, vector &dist, int rt, int dept);
vector _data;//用vector模拟数组
vector _flag;//判断是否存在 1, -1, 0
int _idx;
vector _find_nth;
};
#endif
#include "kdTree.h"
int Node::idx = 0;
kdTree::kdTree()
{
_data.reserve(N * 4);
_flag.reserve(N * 4);//TODO init
}
kdTree::~kdTree()
{
}
void kdTree::read_in(float* ary_x, float* ary_y, int len)
{
_find_nth.reserve(N * 4);
for (int i = 0; i < len; i++) {
_find_nth.push_back(Node(ary_x[i], ary_y[i]));
}
for (int i = 0; i < N * 4; i++) {
Node tmp;
_data.push_back(tmp);
_flag.push_back(0);
}
build(0, len - 1, 1, 0);
}
void kdTree::read_in(vector<MatrixXf> points_array)
{
int len = points_array.size();
_find_nth.reserve(N * 4);
for (int i = 0; i < len; i++) {
_find_nth.push_back(Node(int(points_array[i](0, 0)), int(points_array[i](1, 0))));
}
for (int i = 0; i < N * 4; i++) {
Node tmp;
_data.push_back(tmp);
_flag.push_back(0);
}
build(0, len - 1, 1, 0);
}
void kdTree::clean()
{
_find_nth.clear();
_data.clear();
_flag.clear();
}
//建立kd-tree
void kdTree::build(int l, int r, int rt, int dept)
{
if (l > r) return;
_flag[rt] = 1; //表示标号为rt的节点存在
_flag[lson] = _flag[rson] = -1; //当前节点的孩子暂时标记不存在
int mid = (l + r + 1) >> 1;
Node::idx = dept % k; //按照编号为idx的属性进行划分
std::nth_element(_find_nth.begin() + l, _find_nth.begin() + mid, _find_nth.begin() + r + 1);
_data[rt] = _find_nth[mid];
build(l, mid - 1, lson, dept + 1); //递归左子树
build(mid + 1, r, rson, dept + 1);
}
int kdTree::find_nearest_point(float x, float y, Node &res, int& dist, int boundary)
{
Node p(x, y);
dist = boundary * boundary;
query(p, res, dist, 1, 0);
return 1;
}
int kdTree::find_nearest_k_points(float x, float y, int k, vector<Node>& res, vector<int>& dist, int boundary)
{
vector<int> idx;
idx.clear();
res.clear();
dist.clear();
Node p(x, y);
for (int i = 0; i < k; i++) {
idx.push_back(-1);
dist.push_back(boundary * boundary);
query_k(p, i, idx, dist, 1, 0);
if (idx[i] < 0) break;
_flag[idx[i]] = 0;
res.push_back(_data[idx[i]]);
}
int kfind = res.size();
for (int i = 0; i < kfind; i++) {
_flag[idx[i]] = 1;
}
return kfind;
}
//查找kd-tree距离p最近的点
void kdTree::query(const Node& p, Node& res, int& dist, int rt, int dept)
{
if (_flag[rt] == -1) return; //不存在的节点不遍历
int tmp_dist = distance(_data[rt], p);
bool fg = false; //用于标记是否需要遍历右子树
int dim = dept % k; //和建树一样, 保证相同节点的dim值不变
int x = lson;
int y = rson;
if (p.feature[dim] >= _data[rt].feature[dim]) std::swap(x, y); //数据p的第dim个特征值大于等于当前的数据,则需要进入右子树
if (x < _flag.size() && _flag[x] != -1) query(p, res, dist, x, dept + 1); //节点x存在, 则进入子树继续遍历
if (tmp_dist < dist) { //如果找到更小的距离, 则替换目前的结果dist
res = _data[rt];
dist = tmp_dist;
}
tmp_dist = (p.feature[dim] - _data[rt].feature[dim]) * (p.feature[dim] - _data[rt].feature[dim]);
if (tmp_dist < dist) fg = true; //还需要继续回溯
if (y < _flag.size() && _flag[y] != -1 && fg) query(p, res, dist, y, dept + 1);
}
void kdTree::query_k(const Node & p, int i, vector<int>& res, vector<int>& dist, int rt, int dept)
{
if (_flag[rt] == -1) return; //不存在的节点不遍历
int tmp_dist = distance(_data[rt], p);
bool fg = false; //用于标记是否需要遍历右子树
int dim = dept % k; //和建树一样, 保证相同节点的dim值不变
int x = lson;
int y = rson;
if (p.feature[dim] >= _data[rt].feature[dim]) std::swap(x, y); //数据p的第dim个特征值大于等于当前的数据,则需要进入右子树
if (x<_flag.size() && _flag[x]==1) query_k(p, i, res, dist, x, dept + 1); //节点x存在, 则进入子树继续遍历
if (tmp_dist < dist[i] && _flag[rt] == 1) { //如果找到更小的距离, 则替换目前的结果dist
res[i] = rt;
dist[i] = tmp_dist;
}
tmp_dist = (p.feature[dim] - _data[rt].feature[dim]) * (p.feature[dim] - _data[rt].feature[dim]);
if (tmp_dist < dist[i]) fg = true; //还需要继续回溯
if (y < _flag.size() && _flag[y] == 1 && fg) query_k(p, i, res, dist, y, dept + 1);
}
//计算两点间的距离的平方
int kdTree::distance(const Node& x, const Node& y)
{
int res = 0;
for (int i = 0; i < k; i++)
{
res += (x.feature[i] - y.feature[i]) * (x.feature[i] - y.feature[i]);
}
return res;
}