// Header-only interval tree: a static tree of closed intervals [start, stop],
// each carrying a Value, supporting overlap and containment queries.
//
// NOTE(review): the include-guard name is kept unchanged for compatibility
// with code that may test it, although identifiers starting with "__" are
// reserved for the implementation.
#ifndef __INTERVAL_TREE_H
#define __INTERVAL_TREE_H

#include <vector>
#include <algorithm>
#include <iostream>
#include <memory>
#include <cassert>
#include <limits>
#include <utility>

#ifdef USE_INTERVAL_TREE_NAMESPACE
namespace interval_tree {
#endif

/// A closed interval [start, stop] with an attached value.
///
/// The constructor normalizes its endpoints, so `start <= stop` holds no
/// matter in which order they are passed.
template <class Scalar, typename Value>
class Interval {
public:
    Scalar start;
    Scalar stop;
    Value value;
    Interval(const Scalar& s, const Scalar& e, const Value& v)
    : start(std::min(s, e))
    , stop(std::max(s, e))
    , value(v)
    {}
};

/// @return the lower endpoint of @p i. The endpoint has type Scalar; it was
///         previously (incorrectly) declared to return Value.
template <class Scalar, typename Value>
Scalar intervalStart(const Interval<Scalar, Value>& i) {
    return i.start;
}

/// @return the upper endpoint of @p i. The endpoint has type Scalar; it was
///         previously (incorrectly) declared to return Value.
template <class Scalar, typename Value>
Scalar intervalStop(const Interval<Scalar, Value>& i) {
    return i.stop;
}

/// Prints @p i as "Interval(start, stop): value".
template <class Scalar, typename Value>
std::ostream& operator<<(std::ostream& out, const Interval<Scalar, Value>& i) {
    out << "Interval(" << i.start << ", " << i.stop << "): " << i.value;
    return out;
}

/// A centered interval tree built once from a batch of intervals.
///
/// Each node stores the intervals that contain its `center`. Intervals lying
/// entirely below `center` are pushed into the `left` subtree, and those lying
/// entirely above it into the `right` subtree. Node-local intervals are kept
/// sorted by start.
template <class Scalar, class Value>
class IntervalTree {
public:
    typedef Interval<Scalar, Value> interval;
    typedef std::vector<interval> interval_vector;

    /// Orders intervals by their lower endpoint.
    struct IntervalStartCmp {
        bool operator()(const interval& a, const interval& b) const {
            return a.start < b.start;
        }
    };

    /// Orders intervals by their upper endpoint.
    struct IntervalStopCmp {
        bool operator()(const interval& a, const interval& b) const {
            return a.stop < b.stop;
        }
    };

    /// Constructs an empty tree.
    IntervalTree()
        : left(nullptr)
        , right(nullptr)
        , center(0)
    {}

    ~IntervalTree() = default;

    /// @return a heap-allocated deep copy of this tree.
    std::unique_ptr<IntervalTree> clone() const {
        return std::unique_ptr<IntervalTree>(new IntervalTree(*this));
    }

    /// Deep copy: child subtrees are cloned, never shared.
    IntervalTree(const IntervalTree& other)
        : intervals(other.intervals)
        , left(other.left ? other.left->clone() : nullptr)
        , right(other.right ? other.right->clone() : nullptr)
        , center(other.center)
    {}

    IntervalTree& operator=(IntervalTree&&) = default;
    IntervalTree(IntervalTree&&) = default;

    /// Deep-copy assignment. The copy is built first and then moved in, so
    /// *this is left untouched if copying throws. Self-assignment is a no-op.
    IntervalTree& operator=(const IntervalTree& other) {
        if (this != &other) {
            IntervalTree copy(other);
            *this = std::move(copy);
        }
        return *this;
    }

    /// Builds a tree from @p ivals, which is consumed.
    ///
    /// @param ivals       intervals to store; moved from.
    /// @param depth       maximum number of tree levels. A node becomes a leaf
    ///                    once the remaining depth reaches zero, so 0 and 1
    ///                    both produce a single leaf.
    /// @param minbucket   a node holding fewer than this many intervals (and
    ///                    fewer than @p maxbucket) becomes a leaf.
    /// @param maxbucket   see @p minbucket.
    /// @param leftextent  lower bound of this subtree's range, passed by the
    ///                    recursion. When both extents are 0 the input is
    ///                    sorted here; otherwise it must already be sorted by
    ///                    start (checked with assert).
    /// @param rightextent upper bound of this subtree's range; see @p leftextent.
    IntervalTree(
            interval_vector&& ivals,
            std::size_t depth = 16,
            std::size_t minbucket = 64,
            std::size_t maxbucket = 512,
            Scalar leftextent = 0,
            Scalar rightextent = 0)
        : left(nullptr)
        , right(nullptr)
        , center(0)
    {
        // Consume one level. The guard stops depth == 0 from wrapping around
        // to SIZE_MAX, which used to mean "unlimited depth".
        if (depth > 0) {
            --depth;
        }

        // Place this node's center at the midpoint of the span its input covers.
        if (!ivals.empty()) {
            const auto minStart = std::min_element(ivals.begin(), ivals.end(), IntervalStartCmp());
            const auto maxStop = std::max_element(ivals.begin(), ivals.end(), IntervalStopCmp());
            center = (minStart->start + maxStop->stop) / 2;
        }

        const bool hasExtents = !(leftextent == 0 && rightextent == 0);
        if (!hasExtents) {
            std::sort(ivals.begin(), ivals.end(), IntervalStartCmp());
        } else {
            // Recursive calls receive in-order partitions of an already sorted vector.
            assert(std::is_sorted(ivals.begin(), ivals.end(), IntervalStartCmp()));
        }

        // Leaf when:
        //   - there is nothing to split (empty input used to reach front() on
        //     an empty vector when minbucket == 0),
        //   - the depth budget is exhausted, or
        //   - the bucket is small enough.
        if (ivals.empty() || depth == 0
                || (ivals.size() < minbucket && ivals.size() < maxbucket)) {
            // Keeps node-local order even in NDEBUG builds given unsorted input.
            std::sort(ivals.begin(), ivals.end(), IntervalStartCmp());
            intervals = std::move(ivals);
            assert(is_valid().first);
            return;
        }

        Scalar leftp = 0;
        Scalar rightp = 0;
        if (hasExtents) {
            leftp = leftextent;
            rightp = rightextent;
        } else {
            leftp = ivals.front().start;
            rightp = std::max_element(ivals.begin(), ivals.end(), IntervalStopCmp())->stop;
        }

        // Partition: strictly below center -> left, strictly above -> right,
        // and intervals containing center stay in this node.
        interval_vector lefts;
        interval_vector rights;
        for (const interval& ival : ivals) {
            if (ival.stop < center) {
                lefts.push_back(ival);
            } else if (ival.start > center) {
                rights.push_back(ival);
            } else {
                assert(ival.start <= center);
                assert(center <= ival.stop);
                intervals.push_back(ival);
            }
        }

        if (!lefts.empty()) {
            left.reset(new IntervalTree(std::move(lefts), depth, minbucket, maxbucket, leftp, center));
        }
        if (!rights.empty()) {
            right.reset(new IntervalTree(std::move(rights), depth, minbucket, maxbucket, center, rightp));
        }

        assert(is_valid().first);
    }

    /// Calls @p f on every interval stored in a node that may hold intervals
    /// overlapping [start, stop]. Every overlapping interval is visited, but
    /// some non-overlapping ones may be visited too, so callers filter (see
    /// visit_overlapping / visit_contained).
    /// @p f is copied into recursive calls. Keep any accumulated state behind
    /// references (e.g. a lambda capturing by reference).
    template <class UnaryFunction>
    void visit_near(const Scalar& start, const Scalar& stop, UnaryFunction f) const {
        // Node intervals are sorted by start: if even the first one starts
        // after `stop`, none of them can be near the query.
        if (!intervals.empty() && !(stop < intervals.front().start)) {
            for (const interval& ival : intervals) {
                f(ival);
            }
        }
        // The left subtree only holds intervals that end before `center`.
        if (left && start <= center) {
            left->visit_near(start, stop, f);
        }
        // The right subtree only holds intervals that start after `center`.
        if (right && stop >= center) {
            right->visit_near(start, stop, f);
        }
    }

    /// Calls @p f on every interval containing the point @p pos.
    template <class UnaryFunction>
    void visit_overlapping(const Scalar& pos, UnaryFunction f) const {
        visit_overlapping(pos, pos, f);
    }

    /// Calls @p f on every interval overlapping the closed range [start, stop].
    template <class UnaryFunction>
    void visit_overlapping(const Scalar& start, const Scalar& stop, UnaryFunction f) const {
        auto filterF = [&](const interval& ival) {
            if (ival.stop >= start && ival.start <= stop) {
                f(ival);
            }
        };
        visit_near(start, stop, filterF);
    }

    /// Calls @p f on every interval lying entirely within [start, stop].
    template <class UnaryFunction>
    void visit_contained(const Scalar& start, const Scalar& stop, UnaryFunction f) const {
        auto filterF = [&](const interval& ival) {
            if (start <= ival.start && ival.stop <= stop) {
                f(ival);
            }
        };
        visit_near(start, stop, filterF);
    }

    /// @return copies of all intervals overlapping [start, stop].
    interval_vector findOverlapping(const Scalar& start, const Scalar& stop) const {
        interval_vector result;
        visit_overlapping(start, stop, [&](const interval& ival) {
            result.push_back(ival);
        });
        return result;
    }

    /// @return copies of all intervals contained in [start, stop].
    interval_vector findContained(const Scalar& start, const Scalar& stop) const {
        interval_vector result;
        visit_contained(start, stop, [&](const interval& ival) {
            result.push_back(ival);
        });
        return result;
    }

    /// @return true when no node of the tree stores any interval.
    bool empty() const {
        if (left && !left->empty()) {
            return false;
        }
        if (!intervals.empty()) {
            return false;
        }
        if (right && !right->empty()) {
            return false;
        }
        return true;
    }

    /// Calls @p f on every stored interval (in-order: left, node, right).
    template <class UnaryFunction>
    void visit_all(UnaryFunction f) const {
        if (left) {
            left->visit_all(f);
        }
        std::for_each(intervals.begin(), intervals.end(), f);
        if (right) {
            right->visit_all(f);
        }
    }

    /// Computes {min start, max stop} over all intervals by visiting every one.
    /// For an empty tree returns {numeric_limits::max(), numeric_limits::lowest()}.
    /// (`lowest()` replaces `min()`, which for floating-point Scalar is the
    /// smallest positive value and broke all-negative extents.)
    std::pair<Scalar, Scalar> extentBruitForce() const {
        std::pair<Scalar, Scalar> extent(std::numeric_limits<Scalar>::max(),
                                         std::numeric_limits<Scalar>::lowest());
        visit_all([&](const interval& ival) {
            extent.first = std::min(extent.first, ival.start);
            extent.second = std::max(extent.second, ival.stop);
        });
        return extent;
    }

    /// Correctly spelled alias of extentBruitForce().
    std::pair<Scalar, Scalar> extentBruteForce() const {
        return extentBruitForce();
    }

    /// Checks the structural invariants of this subtree.
    ///
    /// @return {ok, {minStart, maxStop}}.
    ///   - `ok` is true when node-local intervals are sorted by start, every
    ///     left-subtree interval stops before `center`, and every
    ///     right-subtree interval starts after it.
    ///   - The extent pair spans all intervals in this subtree. If `ok` is
    ///     false, the extent pair is not meaningful.
    std::pair<bool, std::pair<Scalar, Scalar>> is_valid() const {
        std::pair<bool, std::pair<Scalar, Scalar>> result = {
            true, { std::numeric_limits<Scalar>::max(),
                    std::numeric_limits<Scalar>::lowest() }};

        if (!intervals.empty()) {
            const auto minStart = std::min_element(intervals.begin(), intervals.end(), IntervalStartCmp());
            const auto maxStop = std::max_element(intervals.begin(), intervals.end(), IntervalStopCmp());
            result.second.first = std::min(result.second.first, minStart->start);
            // max, not min: this accumulates the upper extent.
            result.second.second = std::max(result.second.second, maxStop->stop);
        }

        if (left) {
            const auto valid = left->is_valid();
            result.first = result.first && valid.first;
            result.second.first = std::min(result.second.first, valid.second.first);
            result.second.second = std::max(result.second.second, valid.second.second);
            if (!result.first) {
                return result;
            }
            // Everything on the left must end strictly before center.
            if (valid.second.second >= center) {
                result.first = false;
                return result;
            }
        }

        if (right) {
            const auto valid = right->is_valid();
            result.first = result.first && valid.first;
            result.second.first = std::min(result.second.first, valid.second.first);
            result.second.second = std::max(result.second.second, valid.second.second);
            if (!result.first) {
                return result;
            }
            // Everything on the right must start strictly after center.
            if (valid.second.first <= center) {
                result.first = false;
                return result;
            }
        }

        if (!std::is_sorted(intervals.begin(), intervals.end(), IntervalStartCmp())) {
            result.first = false;
        }
        return result;
    }

    /// Removes all intervals and subtrees, returning the tree to its empty state.
    void clear() {
        left.reset();
        right.reset();
        intervals.clear();
        center = 0;
    }

private:
    interval_vector intervals;            // intervals containing `center`, sorted by start
    std::unique_ptr<IntervalTree> left;   // intervals with stop < center
    std::unique_ptr<IntervalTree> right;  // intervals with start > center
    Scalar center;
};

#ifdef USE_INTERVAL_TREE_NAMESPACE
}
#endif

#endif