merge with the main branch

metadata
Wenzel Jakob 2011-02-10 11:45:01 +01:00
commit 2867546133
3 changed files with 427 additions and 58 deletions

View File

@ -20,6 +20,7 @@
#define __KDTREE_H
#include <mitsuba/core/aabb.h>
#include <boost/foreach.hpp>
MTS_NAMESPACE_BEGIN
@ -30,33 +31,37 @@ MTS_NAMESPACE_BEGIN
* \tparam DataRecord Custom payload to be attached to each node
*/
template <typename PointType, typename DataRecord> struct BasicKDNode {
typedef PointType point_type;
PointType position;
BasicKDNode *left, *right;
DataRecord data;
uint8_t axis;
uint32_t right;
uint16_t flags;
uint16_t axis;
DataRecord value;
/// Initialize a KD-tree node
inline BasicKDNode() : position((Float) 0),
right(0), flags(0), axis(0), value() { }
/// Initialize a KD-tree node with the given data record
inline BasicKDNode(const DataRecord &data) : position(0),
left(NULL), right(NULL), data(data), axis(0) { }
inline BasicKDNode(const DataRecord &value) : position((Float) 0),
right(0), flags(0), axis(0), value(value) { }
/// Return a pointer to the left child of this node
inline BasicKDNode *getLeft() { return left; }
/// Return a pointer to the left child of this node (const version)
inline const BasicKDNode *getLeft() const { return left; }
/// Set the left child of this node
inline void setLeft(BasicKDNode *node) { left = node; }
/// Return the index of the right child of this node
inline uint32_t getRightIndex() { return right; }
/// Return the index of the right child of this node (const version)
inline const uint32_t getRightIndex() const { return right; }
/// Set the right child index of this node
inline void setRightIndex(uint32_t node) { right = node; }
/// Return a pointer to the right child of this node
inline BasicKDNode *getRight() { return right; }
/// Return a pointer to the right child of this node (const version)
inline const BasicKDNode *getRight() const { return right; }
/// Set the right child of this node
inline void setRight(BasicKDNode *node) { right = node; }
/// Check whether this is a leaf node
inline bool isLeaf() const { return flags & 1; }
/// Specify whether this is a leaf node
inline void setLeaf(bool value) { if (value) flags |= 1; else flags &= ~1; }
/// Return the split axis associated with this node
inline uint8_t getAxis() const { return axis; }
inline uint16_t getAxis() const { return axis; }
/// Set the split axis associated with this node
inline void setAxis(uint8_t value) { axis = value; }
inline void setAxis(uint16_t value) { axis = value; }
/// Return the position associated with this node
inline const PointType &getPosition() const { return position; }
@ -64,44 +69,102 @@ template <typename PointType, typename DataRecord> struct BasicKDNode {
inline void setPosition(const PointType &value) { position = value; }
/// Return the data record associated with this node
inline DataRecord &getData() { return data; }
inline DataRecord &getValue() { return value; }
/// Return the data record associated with this node (const version)
inline const DataRecord &getData() const { return data; }
inline const DataRecord &getValue() const { return value; }
/// Set the data record associated with this node
inline void setData(const DataRecord &value) { data = value; }
inline void setValue(const DataRecord &val) { value = val; }
};
/**
* \brief Generic multi-dimensional kd-tree data structure for point data
* using the sliding midpoint tree construction rule. This ensures that
* cells do not become overly elongated.
*
* Organizes a list of point data in a hierarchical manner. For data
* with spatial extents, \ref GenericKDTree and \ref ShapeKDTree will be
* more appropriate.
*
* \tparam PointType Underlying point data type (e.g. \ref TPoint3<float>)
*
* \tparam KDNode Underlying node data structure. See \ref BasicKDNode as
* an example for the required public interface
*/
template <typename PointType, typename KDNode> class TKDTree {
typedef typename PointType::value_type value_type;
typedef typename PointType::vector_type vector_type;
typedef TAABB<PointType> aabb_type;
template <typename KDNode> class TKDTree {
public:
typedef typename KDNode::point_type point_type;
typedef typename point_type::value_type value_type;
typedef typename point_type::vector_type vector_type;
typedef TAABB<point_type> aabb_type;
/// Supported tree construction heuristics
enum EHeuristic {
/// Create a balanced tree by splitting along the median
EBalanced = 0,
/// Create a left-balanced tree
ELeftBalanced,
/**
* \brief Use the sliding midpoint tree construction rule. This
* ensures that cells do not become overly elongated.
*/
ESlidingMidpoint,
/**
* \brief Choose the split plane by optimizing a cost heuristic
* based on the ratio of voxel volumes. Note that the implementation
* here is not particularly optimized, and it furthermore runs in
* time O(n (log n)^2) instead of O(n log n)
*/
EVoxelVolume
};
/// Result data type for k-nn queries
struct SearchResult {
Float distSquared;
uint32_t index;
inline SearchResult(Float distSquared, uint32_t index)
: distSquared(distSquared), index(index) { }
std::string toString() const {
std::ostringstream oss;
oss << "SearchResult[distance=" << std::sqrt(distSquared)
<< ", index=" << index << "]";
return oss.str();
}
inline bool operator==(const SearchResult &r) const {
return distSquared == r.distSquared &&
index == r.index;
}
};
/// Comparison functor for nearest-neighbor search queries
struct SearchResultComparator : public
std::binary_function<SearchResult, SearchResult, bool> {
public:
inline bool operator()(const SearchResult &a, const SearchResult &b) const {
return a.distSquared < b.distSquared;
}
};
public:
/**
* \brief Create an empty KD-tree that can hold the specified
* number of points
*/
inline TKDTree(size_t nodes) : m_nodes(nodes) {}
inline TKDTree(size_t nodes, EHeuristic heuristic = ESlidingMidpoint)
: m_nodes(nodes), m_heuristic(heuristic), m_depth(0) { }
/// Return one of the KD-tree nodes by index (const version)
inline const KDNode &operator[](size_t idx) const { return m_nodes[idx]; }
/// Return one of the KD-tree nodes by index
inline KDNode &operator[](size_t idx) { return m_nodes[idx]; }
/// Return one of the KD-tree nodes by index (const version)
inline const KDNode &operator[](size_t idx) const { return m_nodes[idx]; }
/// Return the AABB of the underlying point data
inline const aabb_type &getAABB() const { return m_aabb; }
/// Return the depth of the constructed KD-tree
inline size_t getDepth() const { return m_depth; }
/// Construct the KD-tree hierarchy
void build() {
m_aabb.reset();
@ -109,8 +172,161 @@ template <typename PointType, typename KDNode> class TKDTree {
m_aabb.expandBy(node.getPosition());
}
build(m_nodes.begin(), m_nodes.end());
m_depth = 0;
build(1, m_nodes.begin(), m_nodes.end());
}
/**
* \brief Run a k-nearest-neighbor search query
*
* \param p Search position
* \param k Maximum number of search results
* \param result Index list of search results
* \param searchRadius Maximum search radius (this can be used to
* restrict the knn query to a subset of the data)
* \return The number of used traversal steps
*/
size_t nnSearch(const point_type &p, size_t k, std::vector<SearchResult> &results,
Float searchRadius = std::numeric_limits<Float>::infinity()) const {
uint32_t *stack = (uint32_t *) alloca((m_depth+1) * sizeof(uint32_t));
size_t index = 0, stackPos = 1, traversalSteps = 0;
bool isHeap = false;
Float distSquared = searchRadius*searchRadius;
stack[0] = 0;
results.clear();
results.reserve(k+1);
while (stackPos > 0) {
const KDNode &node = m_nodes[index];
++traversalSteps;
int nextIndex;
/* Recurse on inner nodes */
if (!node.isLeaf()) {
Float distToPlane = p[node.getAxis()]
- node.getPosition()[node.getAxis()];
uint32_t first, second;
bool searchBoth = distToPlane*distToPlane <= distSquared;
if (distToPlane > 0) {
first = node.getRightIndex();
second = searchBoth ? index+1 : 0;
} else {
first = index+1;
second = searchBoth ? node.getRightIndex() : 0;
}
if (first != 0 && second != 0) {
nextIndex = first;
stack[stackPos++] = second;
} else if (first != 0) {
nextIndex = first;
} else if (second != 0) {
nextIndex = second;
} else {
nextIndex = stack[--stackPos];
}
} else {
nextIndex = stack[--stackPos];
}
/* Check if the current point is within the query's search radius */
const Float pointDistSquared = (node.getPosition() - p).lengthSquared();
if (pointDistSquared < distSquared) {
/* Switch to a max-heap when the available search
result space is exhausted */
if (results.size() < k) {
/* There is still room, just add the point to
the search result list */
results.push_back(SearchResult(pointDistSquared, index));
} else {
if (!isHeap) {
/* Establish the max-heap property */
std::make_heap(results.begin(), results.end(),
SearchResultComparator());
isHeap = true;
}
/* Add the new point, remove the one that is farthest away */
results.push_back(SearchResult(pointDistSquared, index));
std::push_heap(results.begin(), results.end(), SearchResultComparator());
std::pop_heap(results.begin(), results.end(), SearchResultComparator());
results.pop_back();
/* Reduce the search radius accordingly */
distSquared = results[0].distSquared;
}
}
index = nextIndex;
}
return traversalSteps;
}
/**
* \brief Run a search query
*
* \param p Search position
* \param result Index list of search results
* \param searchRadius Search radius
* \return The number of used traversal steps
*/
size_t search(const point_type &p, Float searchRadius, std::vector<SearchResult> &results) const {
uint32_t *stack = (uint32_t *) alloca((m_depth+1) * sizeof(uint32_t));
size_t index = 0, stackPos = 1, traversalSteps = 0;
Float distSquared = searchRadius*searchRadius;
stack[0] = 0;
results.clear();
while (stackPos > 0) {
const KDNode &node = m_nodes[index];
++traversalSteps;
int nextIndex;
/* Recurse on inner nodes */
if (!node.isLeaf()) {
Float distToPlane = p[node.getAxis()]
- node.getPosition()[node.getAxis()];
uint32_t first, second;
bool searchBoth = distToPlane*distToPlane <= distSquared;
if (distToPlane > 0) {
first = node.getRightIndex();
second = searchBoth ? index+1 : 0;
} else {
first = index+1;
second = searchBoth ? node.getRightIndex() : 0;
}
if (first != 0 && second != 0) {
nextIndex = first;
stack[stackPos++] = second;
} else if (first != 0) {
nextIndex = first;
} else if (second != 0) {
nextIndex = second;
} else {
nextIndex = stack[--stackPos];
}
} else {
nextIndex = stack[--stackPos];
}
/* Check if the current point is within the query's search radius */
const Float pointDistSquared = (node.getPosition() - p).lengthSquared();
if (pointDistSquared < distSquared)
results.push_back(SearchResult(pointDistSquared, index));
index = nextIndex;
}
return traversalSteps;
}
protected:
struct CoordinateOrdering : public std::binary_function<KDNode, KDNode, bool> {
public:
@ -133,46 +349,137 @@ protected:
value_type m_value;
};
KDNode *build(typename std::vector<KDNode>::iterator rangeStart,
typename std::vector<KDNode>::iterator rangeEnd) {
SAssert(rangeEnd > rangeStart);
if (rangeEnd-rangeStart <= 1) {
void build(size_t depth,
typename std::vector<KDNode>::iterator rangeStart,
typename std::vector<KDNode>::iterator rangeEnd) {
m_depth = std::max(depth, m_depth);
if (rangeEnd-rangeStart <= 0) {
SLog(EError, "Internal error!");
} else if (rangeEnd-rangeStart == 1) {
/* Create a leaf node */
rangeStart->setLeaf(true);
return;
}
int axis = 0;
typename std::vector<KDNode>::iterator split;
switch (m_heuristic) {
case EBalanced: {
/* Split along the median */
split = rangeStart + (rangeEnd-rangeStart)/2;
axis = m_aabb.getLargestAxis();
std::nth_element(rangeStart, split, rangeEnd, CoordinateOrdering(axis));
};
break;
case ELeftBalanced: {
size_t treeSize = rangeEnd-rangeStart;
/* Layer 0 contains one node */
size_t p = 1;
/* Traverse downwards until the first incompletely
filled tree level is encountered */
while (2*p <= treeSize)
p *= 2;
/* Calculate the number of filled slots in the last level */
size_t remaining = treeSize - p + 1;
if (2*remaining < p) {
/* Case 2: The last level contains too few nodes. Remove
overestimate from the left subtree node count and add
the remaining nodes */
p = (p >> 1) + remaining;
}
axis = m_aabb.getLargestAxis();
split = rangeStart + (p - 1);
std::nth_element(rangeStart, split, rangeEnd,
CoordinateOrdering(axis));
};
break;
case ESlidingMidpoint: {
/* Sliding midpoint rule: find a split that is close to the spatial median */
axis = m_aabb.getLargestAxis();
value_type midpoint = (value_type) 0.5f
* (m_aabb.max[axis]+m_aabb.min[axis]);
size_t nLT = std::count_if(rangeStart, rangeEnd,
LessThanOrEqual(axis, midpoint));
/* Re-adjust the split to pass through a nearby point */
split = rangeStart + nLT;
if (split == rangeStart)
++split;
else if (split == rangeEnd)
--split;
std::nth_element(rangeStart, split, rangeEnd,
CoordinateOrdering(axis));
};
break;
/* Find a split that is close to the spatial median */
int axis = m_aabb.getLargestAxis();
value_type midpoint = (value_type) 0.5f * (m_aabb.max[axis]+m_aabb.min[axis]);
case EVoxelVolume: {
Float bestCost = std::numeric_limits<Float>::infinity();
size_t nLT = std::count_if(rangeStart, rangeEnd,
LessThanOrEqual(axis, midpoint));
for (int dim=0; dim<point_type::dim; ++dim) {
std::sort(rangeStart, rangeEnd, CoordinateOrdering(dim));
size_t numLeft = 1, numRight = rangeEnd-rangeStart-2;
aabb_type leftAABB(m_aabb), rightAABB(m_aabb);
Float invVolume = 1.0f / m_aabb.getVolume();
for (typename std::vector<KDNode>::iterator it = rangeStart+1; it != rangeEnd; ++it) {
++numLeft; --numRight;
leftAABB.max[dim] = it->getPosition()[dim];
rightAABB.min[dim] = it->getPosition()[dim];
Float cost = (numLeft * leftAABB.getVolume()
+ numRight * rightAABB.getVolume()) * invVolume;
if (cost < bestCost) {
bestCost = cost;
axis = dim;
split = it;
}
}
}
std::nth_element(rangeStart, split, rangeEnd,
CoordinateOrdering(axis));
};
break;
}
/* Re-adjust the split to pass through a nearby photon */
typename std::vector<KDNode>::iterator split = rangeStart + nLT;
std::nth_element(rangeStart, split, rangeEnd,
CoordinateOrdering(axis));
value_type splitPos = split->getPosition()[axis];
split->setAxis(axis);
if (split+1 != rangeEnd)
split->setRightIndex((uint32_t) (split + 1 - m_nodes.begin()));
else
split->setRightIndex(0);
split->setLeaf(false);
std::iter_swap(rangeStart, split);
/* Recursively build the children */
value_type temp = m_aabb.max[axis];
m_aabb.max[axis] = splitPos;
split->setLeft(build(rangeStart, split));
build(depth+1, rangeStart+1, split+1);
m_aabb.max[axis] = temp;
temp = m_aabb.min[axis];
m_aabb.min[axis] = splitPos;
split->setRight(build(split+1, rangeEnd));
m_aabb.min[axis] = temp;
return split;
if (split+1 != rangeEnd) {
temp = m_aabb.min[axis];
m_aabb.min[axis] = splitPos;
build(depth+1, split+1, rangeEnd);
m_aabb.min[axis] = temp;
}
}
protected:
std::vector<KDNode> m_nodes;
aabb_type m_aabb;
EHeuristic m_heuristic;
size_t m_depth;
};
MTS_NAMESPACE_END

View File

@ -397,7 +397,7 @@ protected:
std::binary_function<search_result, search_result, bool> {
public:
inline bool operator()(search_result &a, search_result &b) const {
return a.first <= b.first;
return a.first < b.first;
}
};

View File

@ -17,8 +17,9 @@
*/
#include <mitsuba/core/plugin.h>
#include <mitsuba/core/kdtree.h>
#include <mitsuba/render/testcase.h>
#include <mitsuba/render/gkdtree.h>
#include <mitsuba/render/skdtree.h>
MTS_NAMESPACE_BEGIN
@ -27,6 +28,7 @@ public:
MTS_BEGIN_TESTCASE()
MTS_DECLARE_TEST(test01_sutherlandHodgman)
MTS_DECLARE_TEST(test02_bunnyBenchmark)
MTS_DECLARE_TEST(test03_pointKDTree)
MTS_END_TESTCASE()
void test01_sutherlandHodgman() {
@ -125,6 +127,66 @@ public:
Log(EInfo, "");
}
}
void test03_pointKDTree() {
typedef TKDTree< BasicKDNode<Point2, Float> > KDTree2;
size_t nPoints = 50000, nTries = 20;
ref<Random> random = new Random();
for (int heuristic=0; heuristic<4; ++heuristic) {
KDTree2 kdtree(nPoints, (KDTree2::EHeuristic) heuristic);
for (size_t i=0; i<nPoints; ++i) {
kdtree[i].setPosition(Point2(random->nextFloat(), random->nextFloat()));
kdtree[i].setValue(random->nextFloat());
}
std::vector<KDTree2::SearchResult> results, resultsBF;
if (heuristic == 0) {
Log(EInfo, "Testing the balanced kd-tree construction heuristic");
} else if (heuristic == 1) {
Log(EInfo, "Testing the left-balanced kd-tree construction heuristic");
} else if (heuristic == 2) {
Log(EInfo, "Testing the sliding midpoint kd-tree construction heuristic");
} else if (heuristic == 3) {
Log(EInfo, "Testing the voxel volume kd-tree construction heuristic");
}
ref<Timer> timer = new Timer();
kdtree.build();
Log(EInfo, "Construction time = %i ms, depth = %i", timer->getMilliseconds(), kdtree.getDepth());
for (int k=1; k<=10; ++k) {
size_t nTraversals = 0;
for (size_t it = 0; it < nTries; ++it) {
Point2 p(random->nextFloat(), random->nextFloat());
nTraversals += kdtree.nnSearch(p, k, results);
resultsBF.clear();
for (size_t j=0; j<nPoints; ++j)
resultsBF.push_back(KDTree2::SearchResult((kdtree[j].getPosition()-p).lengthSquared(), j));
std::sort(results.begin(), results.end(), KDTree2::SearchResultComparator());
std::sort(resultsBF.begin(), resultsBF.end(), KDTree2::SearchResultComparator());
for (int j=0; j<k; ++j)
assertTrue(results[j] == resultsBF[j]);
}
Log(EInfo, "Average number of traversals for a %i-nn query = " SIZE_T_FMT, k, nTraversals / nTries);
}
size_t nTraversals = 0;
for (size_t it = 0; it < nTries; ++it) {
Point2 p(random->nextFloat(), random->nextFloat());
nTraversals += kdtree.search(p, 0.05, results);
resultsBF.clear();
for (size_t j=0; j<nPoints; ++j)
resultsBF.push_back(KDTree2::SearchResult((kdtree[j].getPosition()-p).lengthSquared(), j));
std::sort(results.begin(), results.end(), KDTree2::SearchResultComparator());
std::sort(resultsBF.begin(), resultsBF.end(), KDTree2::SearchResultComparator());
for (size_t j=0; j<results.size(); ++j)
assertTrue(results[j] == resultsBF[j]);
}
Log(EInfo, "Average number of traversals for a radius=0.05 search query = " SIZE_T_FMT, nTraversals / nTries);
}
}
};
MTS_EXPORT_TESTCASE(TestKDTree, "Testcase for kd-tree related code")