The Algorithms logo
The Algorithms
About | Donate

Prim's Algorithm (Adjacency Matrix)

using System;

namespace Algorithms.Graph.MinimumSpanningTree
{
    /// <summary>
    ///     Computes the minimum spanning tree (MST) of a weighted undirected graph
    ///     using Prim's (Jarník's) algorithm. Prim's algorithm is a greedy algorithm
    ///     that grows the tree one node at a time, always attaching the node that is
    ///     reachable through the cheapest edge. With an adjacency matrix representation
    ///     it runs in O(V^2) time, where V is the number of nodes/vertices.
    ///     A missing edge is represented by a non-finite weight
    ///     (<see cref="float.PositiveInfinity" />).
    ///     More information: https://en.wikipedia.org/wiki/Prim%27s_algorithm
    ///     Pseudocode and runtime analysis: https://www.personal.kent.edu/~rmuhamma/Algorithms/MyAlgorithms/GraphAlgor/primAlgor.htm .
    /// </summary>
    public static class PrimMatrix
    {
        private const string NotConnectedMessage = "Graph must be connected!";

        /// <summary>
        ///     Determine the minimum spanning tree for a given weighted undirected graph.
        /// </summary>
        /// <param name="adjacencyMatrix">
        ///     Square, symmetric adjacency matrix of the graph to find the MST of.
        ///     Entry [i, j] is the weight of the edge between i and j, or
        ///     <see cref="float.PositiveInfinity" /> when there is no such edge.
        /// </param>
        /// <param name="start">Index of the node to start the search from.</param>
        /// <returns>
        ///     Adjacency matrix of the found MST. Every entry that is not a tree edge
        ///     (including the diagonal) is <see cref="float.PositiveInfinity" />.
        /// </returns>
        /// <exception cref="ArgumentNullException"><paramref name="adjacencyMatrix" /> is null.</exception>
        /// <exception cref="ArgumentException">
        ///     The matrix is not square, is not symmetric, or describes a graph that is not connected.
        /// </exception>
        /// <exception cref="ArgumentOutOfRangeException">
        ///     <paramref name="start" /> is not a valid node index.
        /// </exception>
        public static float[,] Solve(float[,] adjacencyMatrix, int start)
        {
            if (adjacencyMatrix is null)
            {
                throw new ArgumentNullException(nameof(adjacencyMatrix));
            }

            ValidateMatrix(adjacencyMatrix);

            var numNodes = adjacencyMatrix.GetLength(0);

            if (start < 0 || start >= numNodes)
            {
                throw new ArgumentOutOfRangeException(
                    nameof(start),
                    start,
                    "Start node must be a valid node index.");
            }

            // Adjacency matrix of the resulting tree; starts with no edges at all
            var mst = new float[numNodes, numNodes];

            // Which nodes are already part of the tree
            var added = new bool[numNodes];

            // Cheapest known edge weight connecting each node to the tree
            var key = new float[numNodes];

            // Tree node through which each node is reached by its cheapest known edge
            var parent = new int[numNodes];

            for (var i = 0; i < numNodes; i++)
            {
                key[i] = float.PositiveInfinity;

                for (var j = 0; j < numNodes; j++)
                {
                    mst[i, j] = float.PositiveInfinity;
                }
            }

            // Ensures that the starting node is added first
            key[start] = 0;

            // Add every node exactly once. Running the selection for the last node as well
            // (instead of stopping one short) lets GetNextNode detect a node that can never
            // be reached, so a disconnected graph is rejected rather than silently producing
            // a tree that does not span the graph.
            for (var i = 0; i < numNodes; i++)
            {
                GetNextNode(adjacencyMatrix, key, added, parent);
            }

            // Build adjacency matrix for tree from the parent links
            for (var i = 0; i < numNodes; i++)
            {
                // The start node is the root and has no parent edge
                if (i == start)
                {
                    continue;
                }

                var weight = adjacencyMatrix[i, parent[i]];
                mst[i, parent[i]] = weight;
                mst[parent[i], i] = weight;
            }

            return mst;
        }

        /// <summary>
        ///     Ensure that the given adjacency matrix can represent a weighted undirected graph:
        ///     it must be square and symmetric, and every row must contain at least one finite
        ///     entry. The last check is only a cheap necessary condition for connectivity;
        ///     full connectivity is enforced while the tree is being built.
        /// </summary>
        /// <param name="adjacencyMatrix">Adjacency matrix to check.</param>
        /// <exception cref="ArgumentException">
        ///     The matrix is not square, not symmetric, or has a row without any finite entry.
        /// </exception>
        private static void ValidateMatrix(float[,] adjacencyMatrix)
        {
            var size = adjacencyMatrix.GetLength(0);

            // Matrix should be square
            if (size != adjacencyMatrix.GetLength(1))
            {
                throw new ArgumentException("Adjacency matrix must be square!");
            }

            // Graph needs to be undirected and connected
            for (var i = 0; i < size; i++)
            {
                var connection = false;
                for (var j = 0; j < size; j++)
                {
                    // Two infinite entries yield NaN here, which compares false and therefore
                    // counts as symmetric ("no edge" in both directions).
                    if (Math.Abs(adjacencyMatrix[i, j] - adjacencyMatrix[j, i]) > 1e-6)
                    {
                        throw new ArgumentException("Adjacency matrix must be symmetric!");
                    }

                    if (!connection && float.IsFinite(adjacencyMatrix[i, j]))
                    {
                        connection = true;
                    }
                }

                if (!connection)
                {
                    throw new ArgumentException(NotConnectedMessage);
                }
            }
        }

        /// <summary>
        ///     Add the cheapest reachable node to the MST and update the best known
        ///     edge weights of the nodes that are still outside the tree.
        /// </summary>
        /// <param name="adjacencyMatrix">Adjacency matrix of graph.</param>
        /// <param name="key">Currently known minimum edge weight connected to each node.</param>
        /// <param name="added">Whether or not a node has been added to the MST.</param>
        /// <param name="parent">The node that added the node to the MST. Used for building MST adjacency matrix.</param>
        /// <exception cref="ArgumentException">
        ///     No node outside the tree is reachable from it, i.e. the graph is not connected.
        /// </exception>
        private static void GetNextNode(float[,] adjacencyMatrix, float[] key, bool[] added, int[] parent)
        {
            var numNodes = adjacencyMatrix.GetLength(0);
            var minWeight = float.PositiveInfinity;

            var node = -1;

            // Find the node outside the tree with the smallest known edge weight.
            // On the first call this is always the starting node (key 0).
            for (var i = 0; i < numNodes; i++)
            {
                if (!added[i] && key[i] < minWeight)
                {
                    minWeight = key[i];
                    node = i;
                }
            }

            // Every remaining node still has an infinite key: nothing is reachable from the tree
            if (node < 0)
            {
                throw new ArgumentException(NotConnectedMessage);
            }

            // Add node to mst
            added[node] = true;

            // Update smallest found edge weights and parent for adjacent nodes
            for (var i = 0; i < numNodes; i++)
            {
                if (!added[i] && adjacencyMatrix[node, i] < key[i])
                {
                    key[i] = adjacencyMatrix[node, i];
                    parent[i] = node;
                }
            }
        }
    }
}