Prim's minimum spanning tree

Prim’s algorithm is a method for finding a minimum spanning tree (MST) in a connected, weighted undirected graph.

Minimum spanning tree

An MST is a subset of the edges in the graph that connects all vertices without forming any cycle and has the minimum possible total weight (i.e., the sum of the weights of its edges is as small as possible).

An MST exists only if the graph is connected, i.e., every vertex can be reached from every other vertex. A graph with disconnected components has no spanning tree, and therefore no MST.
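Connectivity can be checked with a simple traversal before running Prim's algorithm. Below is a minimal sketch, assuming an undirected graph given as adjacency lists of neighbour indices; the function name isConnected is hypothetical and not part of any library.

```kotlin
// Hypothetical helper: true if every vertex is reachable from vertex 0.
// adjacency[v] lists the neighbours of vertex v in an undirected graph.
fun isConnected(adjacency: List<List<Int>>): Boolean {
    if (adjacency.isEmpty()) return true
    val visited = BooleanArray(adjacency.size)
    val stack = ArrayDeque<Int>()  // kotlin.collections.ArrayDeque used as a DFS stack
    visited[0] = true
    stack.addLast(0)
    var reached = 1
    while (stack.isNotEmpty()) {
        val v = stack.removeLast()
        for (w in adjacency[v]) {
            if (!visited[w]) {
                visited[w] = true
                reached++
                stack.addLast(w)
            }
        }
    }
    return reached == adjacency.size
}

fun main() {
    // Path 0-1-2: connected, so a spanning tree exists
    println(isConnected(listOf(listOf(1), listOf(0, 2), listOf(1))))  // true
    // Edge 0-1 plus isolated vertex 2: disconnected, so no spanning tree exists
    println(isConnected(listOf(listOf(1), listOf(0), emptyList())))   // false
}
```

A depth-first traversal is used here, but a breadth-first one works equally well; only the count of reached vertices matters.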

A spanning tree will have |V| - 1 edges, where |V| is the number of vertices.

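A graph can have multiple spanning trees. Every spanning tree consists of exactly |V| - 1 edges and contains no cycle, and conversely any |V| - 1 edges without a cycle form a spanning tree. So the number of spanning trees equals the number of ways to choose |V| - 1 of the |E| edges, minus the number of those choices that contain a cycle:

$$N_{\text{ST}} = \binom{|E|}{|V| - 1} - N_{\text{cyc}}$$

where $N_{\text{cyc}}$ is the number of (|V| - 1)-edge subsets that contain a cycle.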
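The counting rule can be illustrated by brute force: enumerate every (|V| - 1)-edge subset and discard the ones that contain a cycle, detected here with a union-find structure. This is a sketch for small graphs only, since the number of subsets grows exponentially; the names UEdge, findRoot, isAcyclic, combinations and countSpanningTrees are hypothetical. For the complete graph on 4 vertices there are C(6, 3) = 20 three-edge subsets, 4 of which are triangles, leaving 16 spanning trees.

```kotlin
// Hypothetical undirected edge type (weights are irrelevant for counting)
data class UEdge(val u: Int, val v: Int)

// Union-find root lookup with path halving
fun findRoot(parent: IntArray, x: Int): Int {
    var r = x
    while (parent[r] != r) {
        parent[r] = parent[parent[r]]
        r = parent[r]
    }
    return r
}

// True if the chosen edges contain no cycle
fun isAcyclic(vertexCount: Int, chosen: List<UEdge>): Boolean {
    val parent = IntArray(vertexCount) { it }
    for ((u, v) in chosen) {
        val ru = findRoot(parent, u)
        val rv = findRoot(parent, v)
        if (ru == rv) return false  // u and v already connected: this edge closes a cycle
        parent[ru] = rv
    }
    return true
}

// All k-element subsets of items
fun <T> combinations(items: List<T>, k: Int): List<List<T>> = when {
    k == 0 -> listOf(emptyList())
    items.size < k -> emptyList()
    else -> combinations(items.drop(1), k - 1).map { listOf(items[0]) + it } +
        combinations(items.drop(1), k)
}

// C(|E|, |V| - 1) minus the subsets that contain a cycle
fun countSpanningTrees(vertexCount: Int, edges: List<UEdge>): Int {
    val subsets = combinations(edges, vertexCount - 1)
    val withCycle = subsets.count { !isAcyclic(vertexCount, it) }
    return subsets.size - withCycle
}

fun main() {
    // Complete graph K4: 4 vertices, 6 edges
    val k4 = listOf(UEdge(0, 1), UEdge(0, 2), UEdge(0, 3), UEdge(1, 2), UEdge(1, 3), UEdge(2, 3))
    println(countSpanningTrees(4, k4))  // 16
}
```

For larger graphs, Kirchhoff's matrix-tree theorem counts spanning trees in polynomial time instead of enumerating subsets.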

Algorithm

Prim’s algorithm starts at an arbitrary vertex and grows the MST by repeatedly adding the minimum-weight edge that connects a vertex in the MST to a vertex outside it. The algorithm stops when all vertices are included in the MST. This greedy choice is safe because of the cut property: for any division of the vertices into two groups, a minimum-weight edge crossing between the groups belongs to some MST. The result is therefore a tree (connected and without cycles) whose total weight is minimal.

Steps

  1. Create a min heap (priority queue) that stores edges ordered by weight.
  2. Mark the start vertex as visited and add all of its adjacent edges to the queue.
  3. Remove the edge with the smallest weight from the queue.
  4. If the vertex that edge points to has not been visited, add the edge to the MST, mark the vertex as visited, and add all of its adjacent edges to the queue. Otherwise, discard the edge, since it would create a cycle.
  5. Repeat steps 3 and 4 until the MST has |V| - 1 edges or the queue is empty.
  • The following script finds the MST of an example graph with 9 vertices (numbered 0 to 8), defined as adjacency lists in main():
```kotlin
import java.util.PriorityQueue

fun main() {
    // Adjacency lists of the example graph: index i holds the edges leaving vertex i.
    // Each undirected edge is stored twice, once per direction.
    val vertexToAdjacentEdges: List<List<Edge>> = listOf(
        listOf(Edge(0, 3, 1), Edge(0, 4, 2)),                  // vertex 0
        listOf(Edge(1, 3, 1), Edge(1, 5, 4), Edge(1, 8, 4)),   // vertex 1
        listOf(Edge(2, 6, 3)),                                 // vertex 2
        listOf(Edge(3, 0, 1), Edge(3, 1, 1)),                  // vertex 3
        listOf(Edge(4, 0, 2), Edge(4, 5, 3), Edge(4, 6, 3)),   // vertex 4
        listOf(Edge(5, 1, 4), Edge(5, 4, 3), Edge(5, 7, 1)),   // vertex 5
        listOf(Edge(6, 4, 3), Edge(6, 7, 1), Edge(6, 2, 3)),   // vertex 6
        listOf(Edge(7, 5, 1), Edge(7, 6, 1), Edge(7, 8, 2)),   // vertex 7
        listOf(Edge(8, 1, 4), Edge(8, 7, 2)),                  // vertex 8
    )

    val (totalWeight, treeEdges) = minimumSpanningTree(0, vertexToAdjacentEdges)
    println("MST: total weight -> $totalWeight")
    treeEdges.forEach(::println)
}

data class Edge(
    val from: Int,
    val to: Int,
    val weight: Int
)

/**
 * Returns the total weight and the edges of the minimum spanning tree grown from [startVertex].
 * If the graph is disconnected, the result only spans the component containing [startVertex].
 */
fun minimumSpanningTree(
    startVertex: Int,
    vertexToAdjacentEdges: List<List<Edge>>
): Pair<Int, List<Edge>> {
    if (vertexToAdjacentEdges.isEmpty()) return Pair(0, emptyList())

    val tree = mutableListOf<Edge>()
    val visited = BooleanArray(vertexToAdjacentEdges.size)

    // Min heap of candidate edges, ordered by weight
    val queue = PriorityQueue(compareBy<Edge> { it.weight })

    // Mark the start vertex as visited and add all its adjacent edges to the queue
    visited[startVertex] = true
    queue.addAll(vertexToAdjacentEdges[startVertex])

    var totalWeight = 0

    // Total edges in a spanning tree = number of vertices - 1
    val totalEdges = vertexToAdjacentEdges.size - 1

    while (queue.isNotEmpty() && tree.size != totalEdges) {
        // Cheapest candidate edge
        val (vertex, adjacentVertex, weight) = queue.poll()
        // Both ends already in the tree: taking this edge would create a cycle
        if (visited[adjacentVertex]) continue

        visited[adjacentVertex] = true
        totalWeight += weight
        tree.add(Edge(vertex, adjacentVertex, weight))
        // The newly added vertex contributes its edges as candidates
        vertexToAdjacentEdges[adjacentVertex].forEach { edge -> queue.add(edge) }
    }
    return Pair(totalWeight, tree)
}
```
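  • Running the script prints "MST: total weight -> 14" followed by the 8 tree edges. The edges 4-5 and 4-6 both weigh 3, so which of them enters the tree depends on how the priority queue breaks the tie; either choice gives a minimum spanning tree.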

Time complexity

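The script above runs in O(E log E) time: every directed edge entry is pushed onto the heap at most once (so up to 2E entries for an undirected graph), and each push or poll costs O(log E). It can be improved to O(E log V) by replacing the queue with an indexed priority queue, which stores at most one entry per vertex (keyed by the cheapest known edge connecting it to the tree) and lowers that key when a cheaper edge is found. For a simple graph E < V², so log E < 2 log V and the two bounds are asymptotically the same; the practical gain is that the queue holds at most V entries instead of up to 2E.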
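The indexed-priority-queue variant can be sketched as follows. The Java standard library has no indexed priority queue, so this sketch assumes a small hand-written binary heap; the names IndexMinPQ, decreaseKey, keyOf and eagerPrim are hypothetical. Each vertex outside the tree has at most one entry, whose key is the weight of the cheapest known edge connecting it to the tree.

```kotlin
// Hypothetical minimal indexed min-priority queue over vertices 0 until capacity:
// a binary heap that keeps one key per vertex and supports decreaseKey.
class IndexMinPQ(capacity: Int) {
    private val heap = IntArray(capacity)        // heap slot -> vertex
    private val pos = IntArray(capacity) { -1 }  // vertex -> heap slot, -1 if not queued
    private val keys = IntArray(capacity)        // vertex -> current key
    private var size = 0

    fun isEmpty() = size == 0
    operator fun contains(v: Int) = pos[v] != -1
    fun keyOf(v: Int) = keys[v]

    fun insert(v: Int, key: Int) {
        keys[v] = key; heap[size] = v; pos[v] = size
        size++
        siftUp(size - 1)
    }

    // Lowers the key of a queued vertex, then restores heap order
    fun decreaseKey(v: Int, key: Int) {
        keys[v] = key
        siftUp(pos[v])
    }

    // Removes and returns the vertex with the smallest key
    fun pollMin(): Int {
        val min = heap[0]
        size--
        swap(0, size)
        pos[min] = -1
        siftDown(0)
        return min
    }

    private fun less(i: Int, j: Int) = keys[heap[i]] < keys[heap[j]]

    private fun swap(i: Int, j: Int) {
        val t = heap[i]; heap[i] = heap[j]; heap[j] = t
        pos[heap[i]] = i; pos[heap[j]] = j
    }

    private fun siftUp(start: Int) {
        var i = start
        while (i > 0 && less(i, (i - 1) / 2)) { swap(i, (i - 1) / 2); i = (i - 1) / 2 }
    }

    private fun siftDown(start: Int) {
        var i = start
        while (2 * i + 1 < size) {
            var child = 2 * i + 1
            if (child + 1 < size && less(child + 1, child)) child++
            if (!less(child, i)) break
            swap(i, child)
            i = child
        }
    }
}

data class Edge(val from: Int, val to: Int, val weight: Int)

// Eager Prim: one queue entry per vertex outside the tree, keyed by its cheapest connecting edge
fun eagerPrim(start: Int, adjacency: List<List<Edge>>): Pair<Int, List<Edge>> {
    val inTree = BooleanArray(adjacency.size)
    val bestEdge = arrayOfNulls<Edge>(adjacency.size)  // cheapest known connecting edge per vertex
    val pq = IndexMinPQ(adjacency.size)
    val tree = mutableListOf<Edge>()
    var total = 0
    pq.insert(start, 0)
    while (!pq.isEmpty()) {
        val v = pq.pollMin()
        inTree[v] = true
        bestEdge[v]?.let { tree.add(it); total += it.weight }  // null only for the start vertex
        for (e in adjacency[v]) {
            val w = e.to
            if (inTree[w]) continue
            if (w !in pq) {
                bestEdge[w] = e; pq.insert(w, e.weight)
            } else if (e.weight < pq.keyOf(w)) {
                bestEdge[w] = e; pq.decreaseKey(w, e.weight)  // found a cheaper connection
            }
        }
    }
    return Pair(total, tree)
}

fun main() {
    // The example graph from the script above, as (u, v, weight) triples
    val edges = listOf(Triple(0, 3, 1), Triple(0, 4, 2), Triple(1, 3, 1), Triple(1, 5, 4),
        Triple(1, 8, 4), Triple(2, 6, 3), Triple(4, 5, 3), Triple(4, 6, 3),
        Triple(5, 7, 1), Triple(6, 7, 1), Triple(7, 8, 2))
    val adjacency = List(9) { mutableListOf<Edge>() }
    for ((u, v, w) in edges) { adjacency[u].add(Edge(u, v, w)); adjacency[v].add(Edge(v, u, w)) }
    val (total, tree) = eagerPrim(0, adjacency)
    println("MST: total weight -> $total")  // 14 for this graph
    tree.forEach(::println)
}
```

Because each vertex has at most one queue entry, the heap never exceeds |V| entries, and every insert, decreaseKey or pollMin costs O(log V).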

Usage

  • One of the main applications of Prim’s algorithm is in designing and optimizing computer networks. Given a network represented as a graph, where nodes represent devices and edges represent possible connections weighted by their cost, an MST identifies the cheapest set of connections that keeps all devices connected. Any spanning tree already uses the fewest possible links (|V| - 1); the MST additionally picks the set with the lowest total cost, which reduces the overall cost of the network while maintaining its functionality.

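  • Another common application is in geographic information systems (GIS), where MSTs can link a set of locations with the least total connection length, for example when planning road, pipeline, or cable layouts. MST-based methods can also feed into planning tasks such as delivery optimization, routing of emergency services, and traffic management. Note that an MST minimizes the total weight of the whole tree, not the distance between two particular locations, so it does not in general give the shortest path between them.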

  • Overall, Prim’s algorithm is widely used in real-world applications that need to connect a set of nodes or locations at minimum total cost.
