kd-treeを実装してみた
はじめに
kd-treeを実装してみました
最近仕事でよく使うので勉強がてら
ソースコード
以下に公開してあります
github.com
kd-treeの構築と、以下の探索機能を実装してます
- 最近傍探索 (Nearest neighbor search)
- k-最近傍探索 (K-nearest neighbor search)
- 半径内に含まれる近傍の探索 (Radius search)
あと、ヘッダ1個includeするだけで使えるのでお手軽です
アルゴリズム
kd-treeの構築
wikipediaに載ってる疑似コードが分かりやすかったので引用させていただきます
function kdtree (list of points pointList, int depth) { if pointList is empty return nil; else { // 深さに応じて軸を選択し、軸が順次選択されるようにする var int axis := depth mod k; // 点のリストをソートし、中央値の点を選択する select median from pointList; // ノードを作成し、部分木を構築する var tree_node node; node.location := median; node.leftChild := kdtree(points in pointList before median, depth+1); node.rightChild := kdtree(points in pointList after median, depth+1); return node; } }
一方私の書いたコードはこんな↓感じです
Node* buildRecursive(int* indices, int npoints, int depth) { if (npoints <= 0) return nullptr; const int axis = depth % PointT::DIM; const int mid = (npoints - 1) / 2; std::nth_element(indices, indices + mid, indices + npoints, [&](int lhs, int rhs) { return points_[lhs][axis] < points_[rhs][axis]; }); Node* node = new Node(); node->idx = indices[mid]; node->axis = axis; node->next[0] = buildRecursive(indices, mid, depth + 1); node->next[1] = buildRecursive(indices + mid + 1, npoints - mid - 1, depth + 1); return node; }
疑似コードとほとんど同じです
私の場合は点そのものではなく、点へのインデックスをノードに保持しています
(中央値を取得するためにnth_elementを初めて使った…)
最近傍探索
k-d treeが構築できたらlet's最近傍探索
ググってヒットしたこちらの資料をもとに実装しました
(URLを見るとスタンフォード大の宿題らしい)
https://web.stanford.edu/class/cs106l/handouts/assignment-3-kdtree.pdf
原理は2分探索と同じで、「探している値は木の半分のこっち側にあるよね」
ってのを繰り返しながら左か右の子ノードに降りていきます
その過程で現在ノードとクエリとの距離を計算し最近傍の値を更新します
葉までたどり着いたら、自分の兄弟ノード(親から見て自分が左だったら、右のノード)の方に
最近傍の可能性がないかチェックして、ある場合は兄弟ノードの方も探索します
void nnSearchRecursive(const PointT& query, const Node* node, int *guess, double *minDist) const { if (node == nullptr) return; const PointT& train = points_[node->idx]; const double dist = distance(query, train); if (dist < *minDist) { *minDist = dist; *guess = node->idx; } const int axis = node->axis; const int dir = query[axis] < train[axis] ? 0 : 1; nnSearchRecursive(query, node->next[dir], guess, minDist); const double diff = fabs(query[axis] - train[axis]); if (diff < *minDist) nnSearchRecursive(query, node->next[!dir], guess, minDist); }
K-近傍探索や半径探索も探索の流れは同じです
デモ
最近傍探索/k-最近傍探索/半径探索を実行した結果です
赤い点がクエリ(画像の中心座標)、青い点が近傍点になります