We will use the bottom-up technique to construct the segment tree, which requires an array of 2 * N elements. The steps are given below:
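1. Decide which value each node in the segment tree will store, e.g. sum, max or min.
2. Create an array of size 2 * N, where N is the number of items in the original array.
3. Copy the original array into the second half of the new array (positions N to 2 * N - 1). These are the leaves.
4. Construct the tree backwards. Start at position N - 1 and go down to 1, setting node i to the function (sum, max, etc.) applied to its left child at 2 * i and its right child at 2 * i + 1. Node 1 ends up as the root and position 0 is never used.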
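The four steps can be sketched in Kotlin as follows. This is a minimal sketch that assumes an `IntArray` input. `buildTree` is a hypothetical helper name and is not part of the class shown later.

```kotlin
// Hypothetical helper: builds a bottom-up segment tree for any combining function.
fun buildTree(arr: IntArray, op: (Int, Int) -> Int): IntArray {
    val n = arr.size
    // Step 2: an array of size 2 * n; position 0 stays unused
    val tree = IntArray(2 * n)
    // Step 3: the original values become the leaves at positions n until 2 * n
    arr.copyInto(tree, destinationOffset = n)
    // Step 4: fill internal nodes backwards; node i combines children 2i and 2i + 1
    for (i in n - 1 downTo 1) {
        tree[i] = op(tree[2 * i], tree[2 * i + 1])
    }
    return tree
}

fun main() {
    val arr = intArrayOf(3, 0, 2, 4, 1, 5, 6, 7)
    // Step 1: the stored value is decided by the function passed in
    val sumTree = buildTree(arr) { a, b -> a + b }
    val maxTree = buildTree(arr) { a, b -> maxOf(a, b) }
    println(sumTree[1]) // 28, the sum of the whole array
    println(maxTree[1]) // 7, the maximum of the whole array
}
```

Passing the combining function as a parameter keeps step 1 separate from steps 2 to 4, which is the same design the `SegmentTree` class below uses.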
The following diagram shows the segment tree for sum queries. Starting from the bottom, each pair of sibling nodes is combined with the sum function. This continues until we reach the root node, which contains the sum of the whole array.
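To see the diagram's levels as plain numbers, the sketch below builds the sum tree for the array used later in this section. It then prints the tree one level at a time, from the root down to the leaves. `buildSumTree` and `levels` are hypothetical helpers. Grouping positions into levels this way assumes N is a power of two.

```kotlin
// Hypothetical helper: bottom-up sum tree, leaves at positions n until 2 * n.
fun buildSumTree(arr: IntArray): IntArray {
    val n = arr.size
    val tree = IntArray(2 * n)
    arr.copyInto(tree, destinationOffset = n)
    for (i in n - 1 downTo 1) tree[i] = tree[2 * i] + tree[2 * i + 1]
    return tree
}

// Hypothetical helper: level k holds positions 2^k until 2^(k + 1).
// Only meaningful when the number of leaves is a power of two.
fun levels(tree: IntArray): List<List<Int>> {
    val result = mutableListOf<List<Int>>()
    var start = 1
    while (start < tree.size) {
        result.add(tree.slice(start until 2 * start))
        start *= 2
    }
    return result
}

fun main() {
    val tree = buildSumTree(intArrayOf(3, 0, 2, 4, 1, 5, 6, 7))
    levels(tree).forEach { println(it) }
    // [28]
    // [9, 19]
    // [3, 6, 6, 13]
    // [3, 0, 2, 4, 1, 5, 6, 7]
}
```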
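The odd and even checks in the query function (`computeValue` in the code below) decide which nodes contribute to the result. The query starts at the two leaves that hold the range endpoints and moves both boundaries up one level per iteration. In the tree array, a left child always sits at an even position and a right child at an odd one.

If the left boundary is odd, it is a right child. Its parent also covers the element just before the range, so the parent cannot be used. Instead, the node's own value is merged into the result and the boundary moves one position to the right.

Similarly, if the right boundary is even, it is a left child. Its parent also covers the element just after the range, so that node is merged on its own and the boundary moves one position to the left.

The loop ends when the two boundaries cross.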
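The sketch below traces which nodes the query actually merges. `pickedNodes` is a hypothetical helper that runs the same loop as `computeValue` but records tree positions instead of values. Which positions get picked depends only on N and the range, not on the stored values.

```kotlin
// Hypothetical tracing version of the bottom-up range query.
fun pickedNodes(n: Int, range: IntRange): List<Int> {
    var from = range.first + n // leaf holding the left endpoint
    var to = range.last + n    // leaf holding the right endpoint
    val picked = mutableListOf<Int>()
    while (from <= to) {
        // odd `from` is a right child; its parent would also cover from - 1
        if (from % 2 == 1) picked.add(from++)
        // even `to` is a left child; its parent would also cover to + 1
        if (to % 2 == 0) picked.add(to--)
        // move both boundaries up one level
        from /= 2
        to /= 2
    }
    return picked
}

fun main() {
    val arr = intArrayOf(3, 0, 2, 4, 1, 5, 6, 7)
    val n = arr.size
    val tree = IntArray(2 * n)
    arr.copyInto(tree, destinationOffset = n)
    for (i in n - 1 downTo 1) tree[i] = tree[2 * i] + tree[2 * i + 1]

    val nodes = pickedNodes(n, 1..5)
    println(nodes)                    // [9, 5, 6]
    println(nodes.sumOf { tree[it] }) // 12 = 0 + 2 + 4 + 1 + 5
}
```

For the range 1..5, leaf 9 is odd, so it is taken on its own. Leaf 13 is odd, so the query keeps climbing. On the next level, nodes 5 and 6 together cover indices 2 to 5.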
You can verify this by writing down the range that each node covers in the tree above. See the image below.
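The ranges can also be computed rather than drawn. This sketch uses a hypothetical helper, `coveredRanges`, and assumes N is a power of two. A leaf at position i covers the single index i - N. A parent covers everything from its left child's first index to its right child's last index.

```kotlin
// Hypothetical helper: inclusive index range of the original array covered by each node.
// Assumes n is a power of two; position 0 stays null because it is unused.
fun coveredRanges(n: Int): Array<IntRange?> {
    val ranges = arrayOfNulls<IntRange>(2 * n)
    // leaf at position i holds arr[i - n]
    for (i in n until 2 * n) ranges[i] = (i - n)..(i - n)
    // a parent spans from its left child's first index to its right child's last index
    for (i in n - 1 downTo 1) {
        ranges[i] = ranges[2 * i]!!.first..ranges[2 * i + 1]!!.last
    }
    return ranges
}

fun main() {
    val ranges = coveredRanges(8)
    for (i in 1 until ranges.size) println("node $i covers ${ranges[i]}")
}
```

With N = 8, node 9 covers only 1..1, while its parent, node 4, covers 0..1 and would pull in index 0. Node 13 covers 5..5, and its parent, node 6, covers 4..5, which lies entirely inside the range 1..5.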
```kotlin
import kotlin.math.min

fun main() {
    val arr = arrayOf(3, 0, 2, 4, 1, 5, 6, 7)
    val minTree = SegmentTree(arr) { a, b -> min(a, b) }
    val sumTree = SegmentTree(arr) { a, b -> a + b }
    printArrayWithIndices("Array", arr)
    printArrayWithIndices("Min Tree", minTree.tree)
    printArrayWithIndices("Sum Tree", sumTree.tree)
    println()

    val range = 1..5
    // initialValue must be the identity of the merge function:
    // Int.MAX_VALUE for min, 0 for sum
    val minimum = minTree.computeValue(
        range = range,
        initialValue = Int.MAX_VALUE
    ) { previous, new ->
        min(previous, new)
    }
    val sum = sumTree.computeValue(
        range = range,
        initialValue = 0
    ) { previous, new ->
        previous + new
    }
    println("Min of range: $range = $minimum")
    println("Sum of range: $range = $sum")

    minTree.update(3, -5)
    println("update minTree index 3 to -5")
    minTree.computeValue(
        range = range,
        initialValue = Int.MAX_VALUE
    ) { previous, new ->
        min(previous, new)
    }.also {
        println("Min of range: $range = $it")
    }
}
```
```kotlin
class SegmentTree(
    private val arr: Array<Int>,
    private val nodeOperation: (a: Int, b: Int) -> Int
) {
    // tree[1] is the root and tree[i] has children tree[2 * i] and tree[2 * i + 1].
    // The leaves tree[n until 2 * n] hold a copy of arr; tree[0] is unused.
    val tree: Array<Int>

    init {
        // construct the segment tree bottom-up
        val size = arr.size
        tree = Array(size * 2) { 0 }
        arr.copyInto(tree, destinationOffset = size)
        for (i in size - 1 downTo 1) {
            tree[i] = nodeOperation(tree[2 * i], tree[2 * i + 1])
        }
    }

    fun update(index: Int, value: Int) {
        var idx = index + arr.size
        tree[idx] = value
        // walk up to the root (position 1), recomputing each ancestor
        // from its two children; position 0 is never written
        while (idx > 1) {
            idx /= 2
            tree[idx] = nodeOperation(tree[2 * idx], tree[2 * idx + 1])
        }
    }

    fun computeValue(
        range: IntRange,
        initialValue: Int,
        merge: (previousValue: Int, newValue: Int) -> Int
    ): Int {
        var from = range.first + arr.size
        var to = range.last + arr.size
        var value = initialValue
        while (from <= to) {
            // an odd `from` is a right child: its parent also covers from - 1,
            // which lies outside the range, so take this node on its own
            if (from % 2 == 1) {
                value = merge(value, tree[from++])
            }
            // an even `to` is a left child: its parent also covers to + 1,
            // which lies outside the range, so take this node on its own
            if (to % 2 == 0) {
                value = merge(value, tree[to--])
            }
            // go up one level
            from /= 2
            to /= 2
        }
        return value
    }
}
```
```kotlin
// Prints the array values on one row and their indices on the row below.
fun printArrayWithIndices(name: String, tree: Array<Int>) {
    println()
    println("%-8s: ".format(name) + tree.joinToString(", ") { "%2d".format(it) })
    println("%-8s: ".format("Index") + tree.indices.joinToString(", ") { "%2d".format(it) })
}
```
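One way to gain confidence in the query and update loops is to compare them against a direct loop over the array. The sketch below does this with a hypothetical `SumTree` class. It mirrors `SegmentTree` for sums only and uses `IntArray` plus `kotlin.random.Random` with a fixed seed. It checks every range before and after 20 random updates, for sizes that include a non-power of two.

```kotlin
import kotlin.random.Random

// Hypothetical IntArray-based sum tree with the same layout and loops as SegmentTree.
class SumTree(values: IntArray) {
    private val n = values.size
    private val tree = IntArray(2 * n)

    init {
        values.copyInto(tree, destinationOffset = n)
        for (i in n - 1 downTo 1) tree[i] = tree[2 * i] + tree[2 * i + 1]
    }

    fun update(index: Int, value: Int) {
        var idx = index + n
        tree[idx] = value
        while (idx > 1) { // stop at the root; position 0 is unused
            idx /= 2
            tree[idx] = tree[2 * idx] + tree[2 * idx + 1]
        }
    }

    fun sum(range: IntRange): Int {
        var from = range.first + n
        var to = range.last + n
        var result = 0
        while (from <= to) {
            if (from % 2 == 1) result += tree[from++]
            if (to % 2 == 0) result += tree[to--]
            from /= 2
            to /= 2
        }
        return result
    }
}

// True if every range sum matches a direct loop, before and after random updates.
fun crossCheck(size: Int, seed: Int): Boolean {
    val random = Random(seed)
    val values = IntArray(size) { random.nextInt(-50, 50) }
    val sumTree = SumTree(values)

    fun allRangesMatch(): Boolean {
        for (l in 0 until size) {
            for (r in l until size) {
                if (sumTree.sum(l..r) != (l..r).sumOf { values[it] }) return false
            }
        }
        return true
    }

    if (!allRangesMatch()) return false
    repeat(20) {
        val i = random.nextInt(size)
        val v = random.nextInt(-50, 50)
        values[i] = v
        sumTree.update(i, v)
        if (!allRangesMatch()) return false
    }
    return true
}

fun main() {
    for (size in listOf(1, 8, 13)) println("size $size: ${crossCheck(size, seed = 42)}")
}
```

Each line should print `true`. The brute-force comparison is quadratic in the array size, so it is only meant for small test arrays.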