Kruskal's minimum spanning tree

Kruskal's algorithm builds a minimum spanning tree by repeatedly taking the cheapest remaining edge. An edge is skipped if it would form a cycle with the edges already chosen; disjoint sets (union-find) are used to detect such cycles.

Steps

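  1. Create a disjoint-set data structure in which every vertex starts in its own set
  2. Pick the minimum-cost remaining edge
  3. If the edge would form a cycle (both endpoints are already in the same set), ignore it; otherwise add it to the tree and merge the two endpoints' sets
  4. Pick the next edge and repeat until all edges have been visited (the loop can also stop as soon as the tree has V − 1 edges, since no further edge can be added without forming a cycle)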

Example

  • The following script finds the MST of the example graph whose 9 vertices and weighted edges are listed in `main`

The script uses disjoint sets to detect whether an edge would form a cycle, and stores the edges in a priority queue (min-heap) ordered by weight, so the lightest remaining edge is always taken next.

Click here for more information about disjoint-sets (union-find)

import java.util.*
import kotlin.math.absoluteValue

fun main() {
    // Example graph with 9 vertices; every undirected edge is listed once in each direction.
    val graphEdges = listOf(
        Edge(0, 3, 1),
        Edge(0, 4, 2),
        Edge(1, 3, 1),
        Edge(1, 5, 4),
        Edge(1, 8, 4),
        Edge(2, 6, 3),
        Edge(3, 0, 1),
        Edge(3, 1, 1),
        Edge(4, 0, 2),
        Edge(4, 5, 3),
        Edge(4, 6, 3),
        Edge(5, 1, 4),
        Edge(5, 4, 3),
        Edge(5, 7, 1),
        Edge(6, 4, 3),
        Edge(6, 7, 1),
        Edge(6, 2, 3),
        Edge(7, 5, 1),
        Edge(7, 6, 1),
        Edge(7, 8, 2),
        Edge(8, 1, 4),
        Edge(8, 7, 2),
    )

    // The result is (total weight, edges of the tree).
    val (totalWeight, mstEdges) = kruskalMinimumSpanningTree(
        noOfVertices = 9,
        edgeList = graphEdges,
    )

    println("MST: total weight -> $totalWeight")
    for (edge in mstEdges) {
        println(edge)
    }
}

//sampleStart

/** A weighted edge from vertex [from] to vertex [to] whose cost is [weight]. */
data class Edge(val from: Int, val to: Int, val weight: Int)

/**
 * Finds a minimum spanning tree using Kruskal's algorithm.
 *
 * Edges are processed in ascending order of weight. An edge is kept only if
 * its endpoints are currently in different disjoint sets, i.e. adding it does
 * not close a cycle. If the graph is disconnected, the result is a minimum
 * spanning forest with fewer than `noOfVertices - 1` edges.
 *
 * @param noOfVertices number of vertices; vertices are the indices
 *   `0 until noOfVertices`.
 * @param edgeList edges of the graph; every `from`/`to` must be a valid vertex index.
 * @return a [Pair] of (total weight of the selected edges, the selected edges
 *   in the order they were accepted).
 */
fun kruskalMinimumSpanningTree(
        noOfVertices: Int,
        edgeList: List<Edge>
): Pair<Int, List<Edge>> {
    val tree = mutableListOf<Edge>()
    var totalWeight = 0

    // Min-heap of edges keyed on weight. compareBy is used instead of
    // `a.weight - b.weight`, which can overflow for large or opposite-sign weights.
    val queue = PriorityQueue<Edge>(compareBy { it.weight })
    edgeList.forEach { edge -> queue.offer(edge) }

    val unionFind = UnionFind(noOfVertices = noOfVertices)

    // A spanning tree over V vertices has exactly V - 1 edges; once that many
    // are accepted, every remaining edge would form a cycle, so stop early.
    val totalEdgeCount = noOfVertices - 1
    var edgeCount = 0

    while (queue.isNotEmpty() && edgeCount < totalEdgeCount) {
        val currentEdge = queue.poll()
        if (unionFind.addEdge(currentEdge.from, currentEdge.to)) {
            // Endpoints were in different sets: the edge joins two components
            // without forming a cycle, so it belongs to the tree.
            tree.add(currentEdge)
            totalWeight += currentEdge.weight
            edgeCount++
        }
    }
    return Pair(totalWeight, tree)
}
//sampleEnd

/**
 * Disjoint-set (union-find) structure over the vertices `0 until noOfVertices`,
 * using union by size and path compression.
 */
class UnionFind(noOfVertices: Int) {
    /**
     * Parent array indexed by vertex.
     *
     * - For a non-root vertex `v`, `vertexSets[v]` is its parent (a non-negative index).
     * - For a root `r`, `vertexSets[r]` is the negated size of r's set.
     *
     * Initially every vertex is the root of its own singleton set, hence -1.
     * Because sizes are encoded as negative numbers, vertex ids must be
     * non-negative array indices.
     */
    private val vertexSets = IntArray(noOfVertices) { -1 }

    /**
     * Merges the sets that contain [from] and [to].
     *
     * @return `false` if both vertices were already in the same set (the edge
     *   would close a cycle), `true` if two different sets were merged.
     */
    fun addEdge(from: Int, to: Int): Boolean {
        val rootFrom = find(from)
        val rootTo = find(to)
        if (rootFrom == rootTo) {
            return false
        }
        union(rootFrom, rootTo)
        return true
    }

    /**
     * Returns the root of [vertex]'s set.
     *
     * Path compression: every vertex visited on the way up is re-pointed
     * directly at the root, which shortens later lookups. The implementation
     * is iterative, so it cannot overflow the call stack.
     */
    private fun find(vertex: Int): Int {
        var root = vertex
        while (vertexSets[root] >= 0) {
            root = vertexSets[root]
        }
        // Second pass: attach every vertex on the walked path directly to the root.
        var current = vertex
        while (current != root) {
            val next = vertexSets[current]
            vertexSets[current] = root
            current = next
        }
        return root
    }

    /**
     * Union by size: attaches the root of the smaller set under the root of
     * the larger one (ties keep [rootA] as the root) and stores the combined
     * size, negated, on the surviving root. Both arguments must be roots.
     */
    private fun union(rootA: Int, rootB: Int) {
        // Roots hold negated sizes, so negating recovers the positive size.
        val sizeA = -vertexSets[rootA]
        val sizeB = -vertexSets[rootB]
        if (sizeA >= sizeB) {
            vertexSets[rootB] = rootA
            vertexSets[rootA] = -(sizeA + sizeB)
        } else {
            vertexSets[rootA] = rootB
            vertexSets[rootB] = -(sizeA + sizeB)
        }
    }

    override fun toString(): String {
        return "Parents: ${vertexSets.contentToString()}\nVertices: ${vertexSets.indices.toList()}"
    }
}

top