Finding the LCA (lowest common ancestor) of two nodes in a tree

There are many ways to find the lowest common ancestor (LCA) of two nodes in a tree. This article discusses two of them: a parent-pointer solution and a stack-based solution.

Parent Pointers

Finding the LCA with parent pointers takes three steps:

  1. Find both nodes and record their depths.
  2. Move up from the deeper node, following parent pointers, until both nodes are on the same level.
  3. Move up from both nodes in lock-step until they reach the same node; that node is the LCA.

Suppose we want to find the LCA of nodes 4 and 6.

  1. Start by finding nodes 4 and 6 in the tree.
  2. Node 6 is one level deeper than node 4 (depth 3 vs. depth 2), so move it up to its parent, node 5, which is on the same level as node 4.
  3. Move up from both nodes together. In this example, after the first step both nodes 4 and 5 reach node 3, so node 3 is the LCA.

These steps are demonstrated in the following diagram.

//sampleStart
/**
 * Runs a handful of LCA queries against a small sample tree and prints each result.
 *
 * Tree used below (node 0 is the root):
 *
 *         0
 *       / | \
 *      1  2  3
 *           / \
 *          4   5
 *             / \
 *            6   7
 */
fun main() {
    // click + to see full script
    val adjList: List<List<Int>> = listOf(
        listOf(1, 2, 3), // children of 0
        emptyList(),     // 1 is a leaf
        emptyList(),     // 2 is a leaf
        listOf(4, 5),    // children of 3
        emptyList(),     // 4 is a leaf
        listOf(6, 7),    // children of 5
        emptyList(),     // 6 is a leaf
        emptyList(),     // 7 is a leaf
    )
    val solver = LcaTwoPointers(adjList)
    val queries = listOf(6 to 5, 4 to 7, 2 to 3, 0 to 0)
    for (query in queries) {
        val (u, v) = query
        println("lca of $query = ${solver.findLca(u, v)}")
    }
}
//sampleEnd

/**
 * Lowest-common-ancestor queries on a tree rooted at node 0, solved with parent pointers.
 *
 * The tree is built once, in the constructor, by a depth-first walk of [adjList] starting at node 0.
 * [adjList] may be a directed child list (each entry lists only the node's children, as in the
 * sample) or an undirected adjacency list (each entry also lists the node's parent). The edge back
 * to the parent is skipped either way. A node index that has no entry in [adjList] is treated as a leaf.
 *
 * NOTE(review): [adjList] is assumed to describe a tree. A genuine cycle (one not made of a
 * parent/child back-edge) would make the construction recurse forever.
 *
 * @param adjList `adjList[i]` holds the neighbours of node `i`.
 */
class LcaTwoPointers(
    private val adjList: List<List<Int>>
) {

    /** Root (node 0) of the tree built from [adjList]. */
    private val tree: Node

    init {
        // Construct the tree rooted at node 0 from the graph.
        fun dfs(node: Node) {
            val parentValue = node.parent?.value
            for (neighbour in adjList.getOrElse(node.value) { emptyList() }) {
                // Skip self-loops and, for undirected input, the edge leading back to the parent.
                // Following that edge would rebuild the parent as a child and recurse forever.
                if (neighbour != node.value && neighbour != parentValue) {
                    val child = Node(neighbour, node)
                    node.children.add(child)
                    dfs(child)
                }
            }
        }
        tree = Node(0, null).apply { dfs(this) }
    }

    /**
     * Returns the value of the lowest common ancestor of the nodes whose values are [first] and
     * [second], or -1 if either value is not present in the tree.
     *
     * Steps: locate both nodes (recording their depths), lift the deeper one until both are on
     * the same level, then walk both upward in lock-step until they meet.
     */
    fun findLca(first: Int, second: Int): Int {
        // Depth-first search below [node]; returns the depth and node holding [findValue], or null.
        fun findNode(depth: Int, node: Node, findValue: Int): Pair<Int, Node>? {
            if (findValue == node.value) return Pair(depth, node)
            for (child in node.children) {
                findNode(depth + 1, child, findValue)?.let { return it }
            }
            return null
        }

        var (depth1, node1) = findNode(0, tree, first) ?: return -1
        var (depth2, node2) = findNode(0, tree, second) ?: return -1

        // Lift whichever node is deeper until both sit on the same level.
        while (depth1 > depth2) {
            node1 = node1.parent ?: break
            depth1--
        }
        while (depth2 > depth1) {
            node2 = node2.parent ?: break
            depth2--
        }
        return findCommonParent(node1, node2)
    }

    /**
     * Walks [first] and [second] upward in lock-step until they are the same node, then returns
     * that node's value. Both nodes are expected to be on the same level. If either one runs past
     * the root, the root's value is returned.
     */
    private fun findCommonParent(first: Node, second: Node): Int {
        var a = first
        var b = second
        // Node is not a data class, so this compares identity, as the original `==` did.
        while (a !== b) {
            a = a.parent ?: return a.value
            b = b.parent ?: return b.value
        }
        return a.value
    }

    /** A tree node: its [value] (node index), its [parent] (null for the root) and its [children]. */
    class Node(
        val value: Int,
        var parent: Node? = null,
        val children: MutableList<Node> = mutableListOf()
    )
}

Using stacks

The same LCA search can also be implemented with stacks. The steps are slightly different:

  1. Find both nodes, saving each root-to-node path on its own stack (the target node ends up on top).
  2. Pop from the larger stack until both stacks have the same size.
  3. Pop from both stacks together until their tops are the same node; that node is the LCA.

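For example, finding the LCA of nodes 4 and 6 creates the following stacks: [0, 3, 4] for node 4 and [0, 3, 5, 6] for node 6 (listed bottom to top). Popping 6 makes both stacks the same size. Popping 4 and 5 then leaves node 3 on top of both stacks, so node 3 is the LCA.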

 /**
  * Returns the value of the lowest common ancestor of the nodes whose values are [a] and [b],
  * or -1 if either value is not present in the tree.
  *
  * Each node's root-to-node path is recorded on its own stack. The longer stack is popped until
  * both have the same size. Then both are popped together until their tops are the same node.
  */
 fun findLca(a: Int, b: Int): Int {
        // Each stack holds a root-to-node path, with the target node on top (push/pop/peek at the head).
        val stackA = LinkedList<Int>()
        val stackB = LinkedList<Int>()

        // Depth-first search below [node]. On success it returns true and leaves the path to
        // [findValue] on [stack]. On failure it returns false and [stack] is back as it was on entry.
        fun findNode(stack: LinkedList<Int>, node: Node, findValue: Int): Boolean {
            if (findValue == node.value) return true
            for (child in node.children) {
                stack.push(child.value)
                if (findNode(stack, child, findValue)) return true
                stack.pop() // dead end: backtrack
            }
            return false
        }

        // Step 1: record both paths, each starting from the root.
        stackA.push(tree.value)
        if (!findNode(stackA, tree, a)) return -1
        stackB.push(tree.value)
        if (!findNode(stackB, tree, b)) return -1

        // Step 2: pop the longer path until both stacks have the same size.
        while (stackA.size > stackB.size) stackA.pop()
        while (stackB.size > stackA.size) stackB.pop()

        // Step 3: pop both in lock-step until their tops are the same node.
        // Both stacks start at the root, so they meet there at the latest.
        while (stackA.isNotEmpty()) {
            if (stackA.peek() == stackB.peek()) return stackA.peek()
            stackA.pop()
            stackB.pop()
        }
        return -1
    }

top