Segment tree
A segment tree provides O(log n) time operations to answer a range query and to update an array value.
Segment trees support sum queries, minimum queries, and many other range queries, such as maximum.
Implementation
We will use the bottom-up technique to construct the segment tree. It requires an array of 2 * N elements. The steps are given below:
- Decide which value each node in the segment tree will store, e.g. sum, max, or min.
- Create an array of size 2 * N, where N is the number of items in the original array.
- Copy the original array into the second half of the new array (positions N to 2 * N - 1).
- Construct the tree backwards: for each position i from N - 1 down to 1, take its left and right children at positions 2 * i and 2 * i + 1 and apply the function (sum, max, etc.) to them. Position 0 is left unused.
The following diagram shows the segment tree for sum queries. Starting from the bottom, pairs of nodes are combined with the sum function until we reach the root node (position 1), which contains the sum of the whole array.
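The construction steps above can be sketched in Kotlin as follows. This is a minimal version assuming an `IntArray` input and a hypothetical `buildTree` helper; the combining function is passed in as a lambda so the same code builds a sum, min, or max tree.

```kotlin
// Minimal sketch: bottom-up construction in an array of size 2 * n.
// buildTree is a hypothetical helper name; position 0 is left unused.
fun buildTree(arr: IntArray, combine: (Int, Int) -> Int): IntArray {
    val n = arr.size
    val tree = IntArray(2 * n)
    // Copy the original array into positions n .. 2n - 1 (the leaves).
    arr.copyInto(tree, destinationOffset = n)
    // Fill internal nodes backwards; node i has children 2 * i and 2 * i + 1.
    for (i in n - 1 downTo 1) {
        tree[i] = combine(tree[2 * i], tree[2 * i + 1])
    }
    return tree
}

fun main() {
    val arr = intArrayOf(3, 0, 2, 4, 1, 5, 6, 7)
    val sumTree = buildTree(arr) { a, b -> a + b }
    val maxTree = buildTree(arr) { a, b -> maxOf(a, b) }
    println(sumTree.joinToString()) // 0, 28, 9, 19, 3, 6, 6, 13, 3, 0, 2, 4, 1, 5, 6, 7
    println(maxTree[1])             // 7
}
```

Because the leaves occupy the second half, the children of any internal node i are always at 2 * i and 2 * i + 1, and no pointers are needed.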
Finding the sum of a range
The sum of a range [a, b] (inclusive, 0-based indices into the original array) can be found using the following function. It can be modified to perform other operations, such as min or max.
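// Assumes globals: n (number of elements) and tree[] of size 2 * n,
// with the original array stored at tree[n .. 2n - 1].
int sum(int a, int b) {
    a += n; b += n;                      // move to the leaf positions
    int s = 0;
    while (a <= b) {
        if (a % 2 == 1) s += tree[a++];  // a is a right child: take it, step right
        if (b % 2 == 0) s += tree[b--];  // b is a left child: take it, step left
        a /= 2; b /= 2;                  // go up one level
    }
    return s;
}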
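For reference, here is a hedged Kotlin translation of the same loop, generalized as the text suggests. The names `build` and `query` and the `identity` parameter are assumptions for illustration: the identity is the value that leaves a result unchanged (0 for sum, `Int.MAX_VALUE` for min). As written, the loop folds nodes from both ends into one accumulator, so it is only safe for operations that are associative and commutative.

```kotlin
// Sketch: the sum loop generalized with an identity value and a combine
// function (build, query and identity are hypothetical names).
// Assumes combine is associative and commutative, e.g. sum, min or max.
fun build(arr: IntArray, combine: (Int, Int) -> Int): IntArray {
    val n = arr.size
    val tree = IntArray(2 * n)
    arr.copyInto(tree, destinationOffset = n) // leaves at n .. 2n - 1
    for (i in n - 1 downTo 1) tree[i] = combine(tree[2 * i], tree[2 * i + 1])
    return tree
}

// Combines arr[a..b] (inclusive, 0-based) using a tree built by build().
fun query(tree: IntArray, n: Int, a: Int, b: Int, identity: Int, combine: (Int, Int) -> Int): Int {
    var lo = a + n // leaf position of the left end
    var hi = b + n // leaf position of the right end
    var result = identity
    while (lo <= hi) {
        if (lo % 2 == 1) result = combine(result, tree[lo++]) // right child: take it
        if (hi % 2 == 0) result = combine(result, tree[hi--]) // left child: take it
        lo /= 2 // move both ends up one level
        hi /= 2
    }
    return result
}

fun main() {
    val arr = intArrayOf(3, 0, 2, 4, 1, 5, 6, 7)
    val sumTree = build(arr) { x, y -> x + y }
    val minTree = build(arr) { x, y -> minOf(x, y) }
    println(query(sumTree, arr.size, 1, 5, 0) { x, y -> x + y })                  // 12
    println(query(minTree, arr.size, 1, 5, Int.MAX_VALUE) { x, y -> minOf(x, y) }) // 0
}
```

Passing the identity explicitly keeps the loop identical for every operation; only the starting value and the lambda change.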
The odd and even checks in the sum function decide which nodes are added to the sum at each level of the tree.
If the left position a is odd, the node at a is a right child, so its parent also covers elements before the range. The parent cannot be used; instead, tree[a] is added to the sum directly and a moves one step to the right.
Similarly, if the right position b is even, the node at b is a left child, so its parent also covers elements after the range. tree[b] is added to the sum directly and b moves one step to the left.
After these adjustments, a and b are halved to move up one level, and the loop stops once a passes b.
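To make the odd/even rule concrete, the hypothetical helper below runs the same loop but records which tree positions get added instead of summing them. For an 8-element array and the range 1..5 it picks position 9 (the single element at index 1), position 5 (indices 2..3), and position 6 (indices 4..5).

```kotlin
// Hypothetical helper: runs the bottom-up query loop for arr[a..b] in a tree
// with n leaves and records which positions would be added to the result.
fun pickedNodes(n: Int, a: Int, b: Int): List<Int> {
    val picked = mutableListOf<Int>()
    var lo = a + n
    var hi = b + n
    while (lo <= hi) {
        // lo is odd: a right child whose parent also covers elements
        // before the range, so take this node itself and step right.
        if (lo % 2 == 1) picked.add(lo++)
        // hi is even: a left child whose parent also covers elements
        // after the range, so take this node itself and step left.
        if (hi % 2 == 0) picked.add(hi--)
        lo /= 2
        hi /= 2
    }
    return picked
}

fun main() {
    println(pickedNodes(8, 1, 5)) // [9, 5, 6]
    println(pickedNodes(8, 0, 7)) // [1]
}
```

Querying the whole array picks only the root, while a partial range picks at most two nodes per level.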
This can be verified by writing down the range of array indices that each node covers in the tree above, as in the image below.
Segment tree sample for the example given above
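The ranges shown in the image can also be computed. The sketch below assumes N is a power of two, so every node covers one contiguous block of the array; `coveredRanges` is a hypothetical helper. A leaf at position N + i covers index i, and each internal node spans from its left child's first index to its right child's last index.

```kotlin
// Sketch: inclusive range of original-array indices covered by each node.
// Assumption: n is a power of two (coveredRanges is a hypothetical name).
fun coveredRanges(n: Int): Array<IntRange> {
    require(n > 0 && (n and (n - 1)) == 0) { "n must be a power of two" }
    val ranges = Array(2 * n) { IntRange.EMPTY } // position 0 stays empty
    // Leaves: position n + i covers only index i.
    for (i in 0 until n) ranges[n + i] = i..i
    // Internal node i spans from its left child's start to its right child's end.
    for (i in n - 1 downTo 1) ranges[i] = ranges[2 * i].first..ranges[2 * i + 1].last
    return ranges
}

fun main() {
    val ranges = coveredRanges(8)
    for (i in 1 until ranges.size) println("node $i covers ${ranges[i]}")
}
```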
import kotlin.math.min

fun main() {
    val arr = arrayOf(3, 0, 2, 4, 1, 5, 6, 7)
    val minTree = SegmentTree(arr) { a, b -> min(a, b) }
    val sumTree = SegmentTree(arr) { a, b -> a + b }
    printArrayWithIndices("Array", arr)
    printArrayWithIndices("Min Tree", minTree.tree)
    printArrayWithIndices("Sum Tree", sumTree.tree)
    println()

    val range = 1..5
    val minimum = minTree.computeValue(
        range = range,
        initialValue = Int.MAX_VALUE
    ) { previous, new ->
        min(previous, new)
    }
    val sum = sumTree.computeValue(
        range = range,
        initialValue = 0
    ) { previous, new ->
        previous + new
    }
    println("Min of range: $range = $minimum")
    println("Sum of range: $range = $sum")

    minTree.update(3, -5)
    println("update minTree index 3 to -5")
    minTree.computeValue(
        range = range,
        initialValue = Int.MAX_VALUE
    ) { previous, new ->
        min(previous, new)
    }.also {
        println("Min of range: $range = $it")
    }
}
class SegmentTree(
    private val arr: Array<Int>,
    private val nodeOperation: (a: Int, b: Int) -> Int
) {
    // tree[1] is the root, tree[n .. 2n - 1] are the leaves, tree[0] is unused
    val tree: Array<Int>

    init {
        // construct segment tree: copy the leaves, then fill parents backwards
        val size = arr.size
        tree = Array(arr.size.times(2)) { 0 }
        System.arraycopy(arr, 0, tree, size, size)
        for (i in size.dec() downTo 1) {
            tree[i] = nodeOperation.invoke(
                tree[i.times(2)],
                tree[i.times(2).inc()]
            )
        }
    }

    fun update(index: Int, value: Int) {
        var idx = index.plus(arr.size)
        tree[idx] = value
        // recompute every ancestor up to the root (index 1); index 0 is unused
        while (idx > 1) {
            idx = idx.shr(1)
            tree[idx] = nodeOperation(tree[idx.shl(1)], tree[idx.shl(1) + 1])
        }
    }

    fun computeValue(
        range: IntRange, initialValue: Int,
        merge: (previousValue: Int, newValue: Int) -> Int
    ): Int {
        var from = range.first.plus(arr.size)
        var to = range.last.plus(arr.size)
        var value = initialValue
        while (from <= to) {
            // if start index is odd, it is a right child whose parent also
            // covers elements before the range: pick it and move right
            if (from % 2 == 1) {
                value = merge(value, tree[from++])
            }
            // if end index is even, it is a left child whose parent also
            // covers elements after the range: pick it and move left
            if (to % 2 == 0) {
                value = merge(value, tree[to--])
            }
            // go to upper level
            from /= 2
            to /= 2
        }
        return value
    }
}
fun printArrayWithIndices(name: String, tree: Array<Int>) {
    println()
    System.out.printf("%-8s: ", name)
    tree.forEachIndexed { index, value ->
        val suffix = if (index < tree.size.dec()) ", " else ""
        System.out.printf("%2d$suffix", value)
    }
    println()
    System.out.printf("%-8s: ", "Index")
    tree.indices.forEach { index ->
        val suffix = if (index < tree.size.dec()) ", " else ""
        System.out.printf("%2d$suffix", index)
    }
    println()
}
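The `update` function above rewrites one leaf and then recomputes each ancestor, so it touches one node per level, which is where the O(log n) bound comes from. The sketch below is a sum-only variant with hypothetical names (`buildSumTree`, `updateAndTrace`) that records the positions it recomputes; it stops at the root (position 1) because position 0 is unused.

```kotlin
// Sketch: point update on a sum segment tree that records the positions it
// recomputes. buildSumTree and updateAndTrace are hypothetical names.
fun buildSumTree(arr: IntArray): IntArray {
    val n = arr.size
    val tree = IntArray(2 * n)
    arr.copyInto(tree, destinationOffset = n) // leaves at n .. 2n - 1
    for (i in n - 1 downTo 1) tree[i] = tree[2 * i] + tree[2 * i + 1]
    return tree
}

fun updateAndTrace(tree: IntArray, n: Int, index: Int, value: Int): List<Int> {
    val touched = mutableListOf<Int>()
    var pos = index + n
    tree[pos] = value // overwrite the leaf
    touched.add(pos)
    while (pos > 1) { // stop at the root; position 0 is unused
        pos /= 2
        tree[pos] = tree[2 * pos] + tree[2 * pos + 1] // recompute the parent
        touched.add(pos)
    }
    return touched
}

fun main() {
    val tree = buildSumTree(intArrayOf(3, 0, 2, 4, 1, 5, 6, 7))
    println(updateAndTrace(tree, 8, 3, -5)) // [11, 5, 2, 1]
    println(tree[1])                        // 19
}
```

For the example array, updating index 3 to -5 recomputes positions 11, 5, 2 and 1, and the root sum drops from 28 to 19.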