Skip to content

MST Algorithm

Minimum Spanning Tree network construction.

Overview

The MST (Minimum Spanning Tree) algorithm creates the simplest tree connecting all haplotypes with minimum total distance.

Characteristics: - Always produces a tree (no cycles) - Deterministic result - Fast computation - Best for initial exploration

Classes

pypopart.algorithms.mst

Minimum Spanning Tree (MST) algorithm for haplotype network construction.

This module implements both Prim's and Kruskal's algorithms for constructing minimum spanning trees from genetic sequence data. The MST forms the foundation for more complex network algorithms including MSN and median-joining networks.

References

.. [1] Excoffier, L. & Smouse, P. E. (1994). Using allele frequencies and geographic subdivision to reconstruct gene trees within a species: molecular variance parsimony. Genetics, 136(1), 343-359.

MinimumSpanningTree

Bases: NetworkAlgorithm

Construct a Minimum Spanning Tree from haplotype data.

A minimum spanning tree (MST) connects all haplotypes with the minimum total genetic distance. This is the simplest haplotype network algorithm and forms the basis for more complex methods like MSN (Minimum Spanning Network).

The MST is guaranteed to be a tree (no cycles) and provides the most parsimonious representation of relationships between haplotypes.

Supports both Prim's and Kruskal's algorithms for MST construction:

  • Prim's algorithm: Grows the tree from a single starting node, always adding the minimum-weight edge that connects a new node. Time complexity: O(E log V) with binary heap.

  • Kruskal's algorithm: Sorts all edges and adds them in order of increasing weight, skipping edges that would create cycles. Time complexity: O(E log E) with union-find.

Parameters:

Name Type Description Default
distance_method str

str, default='hamming'. Method for calculating pairwise distances between sequences.

'hamming'
Options

'hamming', 'jukes_cantor', 'kimura_2p', 'tamura_nei'.

required
algorithm str

str, default='prim'. MST construction algorithm to use: 'prim' or 'kruskal'.

'prim'
**kwargs

dict. Additional parameters passed to base NetworkAlgorithm.

{}

Attributes:

Name Type Description
algorithm str

The selected MST algorithm

_distance_matrix DistanceMatrix

Cached distance matrix from last construction

Examples:

>>> from pypopart.algorithms import MinimumSpanningTree
>>> from pypopart.io import load_alignment
>>>
>>> # Load alignment
>>> alignment = load_alignment('sequences.fasta')
>>>
>>> # Construct MST using Prim's algorithm
>>> mst = MinimumSpanningTree(algorithm='prim')
>>> network = mst.build_network(alignment)
>>>
>>> # Construct using Kruskal's algorithm
>>> mst = MinimumSpanningTree(algorithm='kruskal')
>>> network = mst.build_network(alignment)
Notes

For most applications, Prim's algorithm is preferred as it's typically faster and uses less memory. Kruskal's algorithm can be advantageous when the graph is sparse or when edges are already sorted.

See Also

MinimumSpanningNetwork : Extension of MST allowing alternative connections TCS : Statistical parsimony network construction

Source code in src/pypopart/algorithms/mst.py
class MinimumSpanningTree(NetworkAlgorithm):
    """
    Construct a Minimum Spanning Tree from haplotype data.

    A minimum spanning tree (MST) connects all haplotypes with the minimum total
    genetic distance. This is the simplest haplotype network algorithm and
    forms the basis for more complex methods like MSN (Minimum Spanning Network).

    The MST is guaranteed to be a tree (no cycles) and provides the most
    parsimonious representation of relationships between haplotypes.

    Supports both Prim's and Kruskal's algorithms for MST construction:

    - **Prim's algorithm**: Grows the tree from a single starting node,
      always adding the minimum-weight edge that connects a new node.
      Time complexity: O(E log V) with binary heap.

    - **Kruskal's algorithm**: Sorts all edges and adds them in order of
      increasing weight, skipping edges that would create cycles.
      Time complexity: O(E log E) with union-find.

    Parameters
    ----------
    distance_method :
        str, default='hamming'.
        Method for calculating pairwise distances between sequences.
    Options :
        'hamming', 'jukes_cantor', 'kimura_2p', 'tamura_nei'.
    algorithm :
        str, default='prim'.
        MST construction algorithm to use: 'prim' or 'kruskal'.
    **kwargs :
        dict.
        Additional parameters passed to base NetworkAlgorithm.

    Attributes
    ----------
    algorithm : str
        The selected MST algorithm
    _distance_matrix : DistanceMatrix
        Cached distance matrix from last construction

    Examples
    --------
    >>> from pypopart.algorithms import MinimumSpanningTree
    >>> from pypopart.io import load_alignment
    >>>
    >>> # Load alignment
    >>> alignment = load_alignment('sequences.fasta')
    >>>
    >>> # Construct MST using Prim's algorithm
    >>> mst = MinimumSpanningTree(algorithm='prim')
    >>> network = mst.build_network(alignment)
    >>>
    >>> # Construct using Kruskal's algorithm
    >>> mst = MinimumSpanningTree(algorithm='kruskal')
    >>> network = mst.build_network(alignment)

    Notes
    -----
    For most applications, Prim's algorithm is preferred as it's typically
    faster and uses less memory. Kruskal's algorithm can be advantageous
    when the graph is sparse or when edges are already sorted.

    See Also
    --------
    MinimumSpanningNetwork : Extension of MST allowing alternative connections
    TCS : Statistical parsimony network construction
    """

    def __init__(
        self, distance_method: str = 'hamming', algorithm: str = 'prim', **kwargs
    ):
        """
        Initialize MST algorithm.

        Parameters
        ----------
        distance_method :
            Method for calculating distances.
        algorithm :
            MST algorithm to use ('prim' or 'kruskal').
        **kwargs :
            Additional parameters.
        """
        super().__init__(distance_method, **kwargs)
        self.algorithm = algorithm.lower()
        if self.algorithm not in ('prim', 'kruskal'):
            raise ValueError(
                f"Unknown MST algorithm: {algorithm}. Use 'prim' or 'kruskal'"
            )

    def construct_network(
        self, alignment: Alignment, distance_matrix: Optional[DistanceMatrix] = None
    ) -> HaplotypeNetwork:
        """
            Construct MST from sequence alignment.

        Parameters
        ----------
            alignment :
                Multiple sequence alignment.
            distance_matrix :
                Optional pre-computed distance matrix.

        Returns
        -------
            Haplotype network representing the MST.
        """
        # Identify unique haplotypes
        haplotypes = identify_haplotypes(alignment)

        if len(haplotypes) == 0:
            return HaplotypeNetwork()

        if len(haplotypes) == 1:
            # Single haplotype - just add it to network
            network = HaplotypeNetwork()
            network.add_haplotype(haplotypes[0])
            return network

        # Calculate distances between haplotypes
        haplotype_dist_matrix = self._calculate_haplotype_distances(haplotypes)
        self._distance_matrix = haplotype_dist_matrix

        # Build MST using selected algorithm
        if self.algorithm == 'prim':
            edges = self._prim_mst(haplotypes, haplotype_dist_matrix)
        else:  # kruskal
            edges = self._kruskal_mst(haplotypes, haplotype_dist_matrix)

        # Construct network from MST edges
        network = self._build_network(haplotypes, edges)

        return network

    def _prim_mst(
        self, haplotypes: List, distance_matrix: DistanceMatrix
    ) -> List[Tuple[str, str, float]]:
        """
            Construct MST using Prim's algorithm.

        Parameters
        ----------
            haplotypes :
                List of Haplotype objects.
            distance_matrix :
                Distance matrix between haplotypes.

        Returns
        -------
            List of edges (id1, id2, distance).
        """
        if len(haplotypes) == 0:
            return []

        # Map haplotype IDs to haplotype objects for quick lookup
        hap_dict = {h.id: h for h in haplotypes}
        hap_ids = list(hap_dict.keys())

        # Track which nodes are in the tree
        in_tree = {hap_ids[0]}  # Start with first haplotype
        not_in_tree = set(hap_ids[1:])

        edges = []

        # Priority queue: (distance, from_id, to_id)
        pq = []

        # Add all edges from first haplotype
        for other_id in not_in_tree:
            dist = distance_matrix.get_distance(hap_ids[0], other_id)
            heapq.heappush(pq, (dist, hap_ids[0], other_id))

        # Build MST
        while not_in_tree and pq:
            dist, from_id, to_id = heapq.heappop(pq)

            # Skip if to_id already in tree (edge is outdated)
            if to_id not in not_in_tree:
                continue

            # Add edge to MST
            edges.append((from_id, to_id, dist))
            in_tree.add(to_id)
            not_in_tree.remove(to_id)

            # Add new edges from newly added node
            for other_id in not_in_tree:
                new_dist = distance_matrix.get_distance(to_id, other_id)
                heapq.heappush(pq, (new_dist, to_id, other_id))

        return edges

    def _kruskal_mst(
        self, haplotypes: List, distance_matrix: DistanceMatrix
    ) -> List[Tuple[str, str, float]]:
        """
            Construct MST using Kruskal's algorithm with Union-Find.

        Parameters
        ----------
            haplotypes :
                List of Haplotype objects.
            distance_matrix :
                Distance matrix between haplotypes.

        Returns
        -------
            List of edges (id1, id2, distance).
        """
        if len(haplotypes) == 0:
            return []

        hap_ids = [h.id for h in haplotypes]

        # Get all edges sorted by distance
        all_edges = []
        for i, id1 in enumerate(hap_ids):
            for id2 in hap_ids[i + 1 :]:
                dist = distance_matrix.get_distance(id1, id2)
                all_edges.append((dist, id1, id2))

        all_edges.sort()  # Sort by distance

        # Union-Find data structure
        parent = {hap_id: hap_id for hap_id in hap_ids}
        rank = dict.fromkeys(hap_ids, 0)

        def find(x):
            """Find root of x with path compression."""
            if parent[x] != x:
                parent[x] = find(parent[x])
            return parent[x]

        def union(x, y):
            """Union sets containing x and y."""
            root_x = find(x)
            root_y = find(y)

            if root_x == root_y:
                return False

            # Union by rank
            if rank[root_x] < rank[root_y]:
                parent[root_x] = root_y
            elif rank[root_x] > rank[root_y]:
                parent[root_y] = root_x
            else:
                parent[root_y] = root_x
                rank[root_x] += 1

            return True

        # Build MST
        mst_edges = []
        for dist, id1, id2 in all_edges:
            if union(id1, id2):
                mst_edges.append((id1, id2, dist))
                # Stop when we have n-1 edges
                if len(mst_edges) == len(hap_ids) - 1:
                    break

        return mst_edges

    def _build_network(
        self, haplotypes: List, edges: List[Tuple[str, str, float]]
    ) -> HaplotypeNetwork:
        """
            Build HaplotypeNetwork from haplotypes and MST edges.

        Parameters
        ----------
            haplotypes :
                List of Haplotype objects.
            edges :
                List of edges (id1, id2, distance).

        Returns
        -------
            Constructed haplotype network.
        """
        network = HaplotypeNetwork()

        # Add all haplotypes
        for haplotype in haplotypes:
            network.add_haplotype(haplotype)

        # Add MST edges
        for id1, id2, dist in edges:
            network.add_edge(id1, id2, distance=int(round(dist)))

        return network

    def _calculate_haplotype_distances(self, haplotypes: List) -> DistanceMatrix:
        """
            Calculate pairwise distances between haplotypes.

        Parameters
        ----------
            haplotypes :
                List of Haplotype objects.

        Returns
        -------
            DistanceMatrix with distances between haplotypes.
        """
        import numpy as np

        from ..core.distance import hamming_distance

        n = len(haplotypes)
        labels = [h.id for h in haplotypes]
        matrix = np.zeros((n, n))

        for i in range(n):
            for j in range(i + 1, n):
                dist = hamming_distance(
                    haplotypes[i].sequence,
                    haplotypes[j].sequence,
                    ignore_gaps=self.params.get('ignore_gaps', True),
                )
                matrix[i, j] = matrix[j, i] = dist

        return DistanceMatrix(labels, matrix)

    def get_parameters(self) -> dict:
        """Get algorithm parameters including MST algorithm type."""
        params = super().get_parameters()
        params['algorithm'] = self.algorithm
        return params
__init__
__init__(
    distance_method: str = "hamming",
    algorithm: str = "prim",
    **kwargs
)

Initialize MST algorithm.

Parameters:

Name Type Description Default
distance_method str

Method for calculating distances.

'hamming'
algorithm str

MST algorithm to use ('prim' or 'kruskal').

'prim'
**kwargs

Additional parameters.

{}
Source code in src/pypopart/algorithms/mst.py
def __init__(
    self, distance_method: str = 'hamming', algorithm: str = 'prim', **kwargs
):
    """
    Initialize MST algorithm.

    Parameters
    ----------
    distance_method :
        Method for calculating distances.
    algorithm :
        MST algorithm to use ('prim' or 'kruskal').
    **kwargs :
        Additional parameters.
    """
    super().__init__(distance_method, **kwargs)
    self.algorithm = algorithm.lower()
    if self.algorithm not in ('prim', 'kruskal'):
        raise ValueError(
            f"Unknown MST algorithm: {algorithm}. Use 'prim' or 'kruskal'"
        )
construct_network
construct_network(
    alignment: Alignment,
    distance_matrix: Optional[DistanceMatrix] = None,
) -> HaplotypeNetwork
Construct MST from sequence alignment.

Returns:

Type Description
Haplotype network representing the MST.
Source code in src/pypopart/algorithms/mst.py
def construct_network(
    self, alignment: Alignment, distance_matrix: Optional[DistanceMatrix] = None
) -> HaplotypeNetwork:
    """
        Construct MST from sequence alignment.

    Parameters
    ----------
        alignment :
            Multiple sequence alignment.
        distance_matrix :
            Optional pre-computed distance matrix.

    Returns
    -------
        Haplotype network representing the MST.
    """
    # Identify unique haplotypes
    haplotypes = identify_haplotypes(alignment)

    if len(haplotypes) == 0:
        return HaplotypeNetwork()

    if len(haplotypes) == 1:
        # Single haplotype - just add it to network
        network = HaplotypeNetwork()
        network.add_haplotype(haplotypes[0])
        return network

    # Calculate distances between haplotypes
    haplotype_dist_matrix = self._calculate_haplotype_distances(haplotypes)
    self._distance_matrix = haplotype_dist_matrix

    # Build MST using selected algorithm
    if self.algorithm == 'prim':
        edges = self._prim_mst(haplotypes, haplotype_dist_matrix)
    else:  # kruskal
        edges = self._kruskal_mst(haplotypes, haplotype_dist_matrix)

    # Construct network from MST edges
    network = self._build_network(haplotypes, edges)

    return network
get_parameters
get_parameters() -> dict

Get algorithm parameters including MST algorithm type.

Source code in src/pypopart/algorithms/mst.py
def get_parameters(self) -> dict:
    """Get algorithm parameters including MST algorithm type."""
    params = super().get_parameters()
    params['algorithm'] = self.algorithm
    return params