Segment tree

A segment tree is a data structure that can answer a range query and update an array value in O(log n) time each.

Segment trees support sum queries, minimum queries, and many other queries.

Implementation

We will use the bottom-up technique to construct the segment tree. It requires an array of size 2 * N. The steps are given below.

  1. Decide which value each node in the segment tree will store, e.g. sum, max, or min
  2. Create an array of size 2 * N, where N is the number of items in the original array
  3. Copy the original array into the second half of the new array (positions N to 2 * N - 1)
  4. Construct the tree backwards: start at position N - 1 and move down to position 1, and for each position i apply the function (sum, max, etc.) to its left child at 2 * i and its right child at 2 * i + 1
  • The following diagram shows the segment tree for sum queries. Starting from the bottom, pairs of nodes are picked and the sum function is applied until we reach the root node, which contains the sum of the whole array.
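The construction steps above can be sketched in Kotlin as follows. This is a minimal sketch for the sum case only; `buildSumTree` is a hypothetical helper name used for illustration, and the input is the same sample array used later in this article.

```kotlin
// Hypothetical helper (not part of the article's SegmentTree class):
// builds a bottom-up sum tree in an array of size 2 * N.
fun buildSumTree(items: IntArray): IntArray {
    val n = items.size
    val tree = IntArray(2 * n)                   // step 2: index 0 stays unused
    items.copyInto(tree, destinationOffset = n)  // step 3: leaves sit at n..2n-1
    for (i in n - 1 downTo 1) {                  // step 4: parents, bottom to top
        tree[i] = tree[2 * i] + tree[2 * i + 1]
    }
    return tree
}

fun main() {
    val tree = buildSumTree(intArrayOf(3, 0, 2, 4, 1, 5, 6, 7))
    println(tree.joinToString())
    // prints: 0, 28, 9, 19, 3, 6, 6, 13, 3, 0, 2, 4, 1, 5, 6, 7
}
```

Position 1 holds 28, the sum of the whole array, and position 0 is never written.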

Finding sum of a range

The sum of the range from index a to index b (both inclusive) can be found using the following function. It can be modified to perform other operations like min or max.

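int sum(int a, int b) {
    // tree and n are assumed to be globals: tree has 2*n entries, leaves start at index n
    a += n; b += n;
    int s = 0;
    while (a <= b) {
        if (a%2 == 1) s += tree[a++];  // a is a right child: take it, step right
        if (b%2 == 0) s += tree[b--];  // b is a left child: take it, step left
        a /= 2; b /= 2;                // move both ends up to the parent level
    }
    return s;
}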
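As a rough illustration of the "other operations" remark, here is the same loop adapted to minimum queries in Kotlin. Only two things change: the combining function and the starting value, which must be the identity for that function (`Int.MAX_VALUE` for min). The names `buildMinTree` and `rangeMin` are hypothetical and used only for this sketch.

```kotlin
import kotlin.math.min

// Hypothetical helper: bottom-up min tree in an array of size 2 * N.
fun buildMinTree(items: IntArray): IntArray {
    val n = items.size
    val tree = IntArray(2 * n)
    items.copyInto(tree, destinationOffset = n)
    for (i in n - 1 downTo 1) {
        tree[i] = min(tree[2 * i], tree[2 * i + 1])
    }
    return tree
}

// Minimum of items[from..to], both ends inclusive.
fun rangeMin(tree: IntArray, from: Int, to: Int): Int {
    val n = tree.size / 2
    var a = from + n
    var b = to + n
    var m = Int.MAX_VALUE            // identity element for min
    while (a <= b) {
        if (a % 2 == 1) m = min(m, tree[a++])
        if (b % 2 == 0) m = min(m, tree[b--])
        a /= 2
        b /= 2
    }
    return m
}

fun main() {
    val tree = buildMinTree(intArrayOf(3, 0, 2, 4, 1, 5, 6, 7))
    println(rangeMin(tree, 1, 5)) // prints: 0
    println(rangeMin(tree, 2, 5)) // prints: 1
}
```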

The even and odd checks in the sum function decide, at each level of the tree, whether the node at either end of the range must be added on its own before moving up to its parent.

If the left endpoint a is odd, the node is a right child. Its parent also covers the left sibling, which lies outside the range, so the node's value is added to the sum directly and a moves one position to the right.

Similarly, if the right endpoint b is even, the node is a left child. Its parent also covers the right sibling, which lies outside the range, so the node's value is added to the sum directly and b moves one position to the left.
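To make the odd and even checks concrete, the following Kotlin sketch records which tree nodes the loop picks for the range 1..5 and which slice of the array each one covers. The helper names (`buildSumTree`, `pickedNodes`, `coveredRange`) are hypothetical, and `coveredRange` assumes the array length is a power of two, as it is for the 8-element sample array.

```kotlin
// Hypothetical helper: bottom-up sum tree in an array of size 2 * N.
fun buildSumTree(items: IntArray): IntArray {
    val n = items.size
    val tree = IntArray(2 * n)
    items.copyInto(tree, destinationOffset = n)
    for (i in n - 1 downTo 1) tree[i] = tree[2 * i] + tree[2 * i + 1]
    return tree
}

// Same loop as the sum function, but returns the node indices it picks.
fun pickedNodes(tree: IntArray, from: Int, to: Int): List<Int> {
    val n = tree.size / 2
    var a = from + n
    var b = to + n
    val picked = mutableListOf<Int>()
    while (a <= b) {
        if (a % 2 == 1) picked += a++   // right child: parent would overshoot left
        if (b % 2 == 0) picked += b--   // left child: parent would overshoot right
        a /= 2
        b /= 2
    }
    return picked
}

// Array indices covered by a node; assumes n is a power of two.
fun coveredRange(node: Int, n: Int): IntRange {
    var lo = node
    var hi = node
    while (lo < n) { lo *= 2; hi = 2 * hi + 1 }
    return (lo - n)..(hi - n)
}

fun main() {
    val tree = buildSumTree(intArrayOf(3, 0, 2, 4, 1, 5, 6, 7))
    val nodes = pickedNodes(tree, 1, 5)
    for (node in nodes) {
        println("node $node covers ${coveredRange(node, 8)}, value ${tree[node]}")
    }
    println("sum = ${nodes.sumOf { tree[it] }}")
    // prints:
    // node 9 covers 1..1, value 0
    // node 5 covers 2..3, value 6
    // node 6 covers 4..5, value 6
    // sum = 12
}
```

The three picked nodes cover 1..1, 2..3 and 4..5, which together make up exactly the requested range.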

This can be verified by writing out the range that each node covers in the above tree. See the image below.

Segment tree sample for example given above

import kotlin.math.min


fun main() {
    val arr = arrayOf(3, 0, 2, 4, 1, 5, 6, 7)
    val minTree = SegmentTree(arr) { a, b -> min(a, b) }
    val sumTree = SegmentTree(arr) { a, b -> a + b }

    printArrayWithIndices("Array", arr)
    printArrayWithIndices("Min Tree", minTree.tree)
    printArrayWithIndices("Sum Tree", sumTree.tree)
    println()

    val range = 1..5

    val minimum = minTree.computeValue(
        range = range,
        initialValue = Int.MAX_VALUE
    ) { previous, new ->
        min(previous, new)
    }

    val sum = sumTree.computeValue(
        range = range,
        initialValue = 0
    ) { previous, new ->
        previous + new
    }

    println("Min of range: $range = $minimum")
    println("Sum of range: $range = $sum")
    minTree.update(3, -5)
    println("update minTree index 3 to -5")
    minTree.computeValue(
        range = range,
        initialValue = Int.MAX_VALUE
    ) { previous, new ->
        min(previous, new)
    }.also {
        println("Min of range: $range = $it")
    }
}

//sampleStart
class SegmentTree(
    private val arr: Array<Int>,
    private val nodeOperation: (a: Int, b: Int) -> Int
) {

    val tree: Array<Int>

    init {
        // construct segment tree
        val size = arr.size
        tree = Array(arr.size.times(2)) { 0 }
        System.arraycopy(arr, 0, tree, size, size)
        for (i in size.dec() downTo 1) {
            tree[i] = nodeOperation.invoke(
                tree[i.times(2)],
                tree[i.times(2).inc()]
            )
        }
    }

    fun update(index: Int, value: Int) {
        var idx = index.plus(arr.size)
        tree[idx] = value
        // walk up to the root at index 1; index 0 is unused and must not be written
        while (idx > 1) {
            idx = idx.shr(1)
            tree[idx] = nodeOperation(tree[idx.shl(1)], tree[idx.shl(1) + 1])
        }
    }

    fun computeValue(
        range: IntRange, initialValue: Int,
        merge: (previousValue: Int, newValue: Int) -> Int
    ): Int {
        var from = range.first.plus(arr.size)
        var to = range.last.plus(arr.size)
        var value = initialValue

        while (from <= to) {
            // if the start index is odd, the node is a right child, so its
            // parent would also cover an element left of the range. pick it
            if (from % 2 == 1) {
                value = merge(value, tree[from++])
            }

            // if the end index is even, the node is a left child, so its
            // parent would also cover an element right of the range. pick it
            if (to % 2 == 0) {
                value = merge(value, tree[to--])
            }
            // go to upper level
            from /= 2
            to /= 2
        }
        return value
    }
}
//sampleEnd

fun printArrayWithIndices(name: String, tree: Array<Int>) {
    println()
    System.out.printf("%-8s: ", name)
    tree.forEachIndexed { index, value ->
        val suffix = if (index < tree.size.dec()) ", " else ""
        System.out.printf("%2d$suffix", value)
    }
    println()
    System.out.printf("%-8s: ", "Index")
    tree.indices.forEachIndexed { index, value ->
        val suffix = if (index < tree.size.dec()) ", " else ""
        System.out.printf("%2d$suffix", value)
    }
    println()
}
