Algorithm Comparison Tutorial
Learn how to compare different network algorithms and choose the best one for your data.
Overview
We'll build MST (minimum spanning tree), MSN (minimum spanning network), TCS (statistical parsimony), and MJN (median-joining network) networks from the same dataset to understand how the algorithms differ.
Setup
from pypopart import Alignment
from pypopart.algorithms import MSTAlgorithm, MSNAlgorithm, TCSAlgorithm, MJNAlgorithm
from pypopart.visualization import StaticPlot
from pypopart.stats import NetworkStatistics
import matplotlib.pyplot as plt
# Load data: every snippet below reuses this one `alignment` object.
# Replace "sequences.fasta" with the path to your own file.
alignment = Alignment.from_fasta("sequences.fasta")
Build Networks with All Algorithms
# Short label -> configured algorithm instance for each method being compared.
# NOTE(review): TCS is given epsilon=0.95 here, whereas the complete script
# further down this page constructs TCSAlgorithm() with defaults — confirm
# which configuration is intended.
algorithms = {
    "MST": MSTAlgorithm(),
    "MSN": MSNAlgorithm(),
    "TCS": TCSAlgorithm(epsilon=0.95),
    "MJN": MJNAlgorithm(),
}

# Build one network per algorithm from the same alignment and report its size.
networks = {}
for name, algorithm in algorithms.items():
    print(f"Building {name} network...")
    built = algorithm.build_network(alignment)
    networks[name] = built
    n_nodes, n_edges = built.number_of_nodes(), built.number_of_edges()
    print(f" {name}: {n_nodes} nodes, {n_edges} edges")
Compare Network Properties
import pandas as pd

# One summary row per network, in the order the networks were built.
results = []
for name, network in networks.items():
    stats = NetworkStatistics(network)
    row = {"Algorithm": name}
    row["Nodes"] = stats.number_of_nodes()
    row["Edges"] = stats.number_of_edges()
    row["Diameter"] = stats.diameter()
    row["Avg Path Length"] = stats.average_path_length()
    row["Clustering"] = stats.clustering_coefficient()
    results.append(row)

# Tabulate the rows and print them without the positional index column.
df = pd.DataFrame(results)
print("\nNetwork Comparison:")
print(df.to_string(index=False))

# Write the same table to disk.
df.to_csv("algorithm_comparison.csv", index=False)
Visualize Side-by-Side
# 2x2 grid of panels, flattened so each network can take the next panel in turn.
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(16, 16))
axes = axes.flatten()

# Create a StaticPlot for each network on its own panel and title the panel.
for idx, (name, network) in enumerate(networks.items()):
    ax = axes[idx]
    plot = StaticPlot(network, ax=ax, layout="spring")
    ax.set_title(f"{name} Network", fontsize=16, fontweight="bold")

plt.tight_layout()
plt.savefig("algorithm_comparison.png", dpi=300)
print("\nComparison figure saved!")
Analyze Differences
# Node sets of the MST and MJN networks.
mst_nodes = set(networks["MST"].nodes())
mjn_nodes = set(networks["MJN"].nodes())

# Nodes that appear only in MJN are counted as its inferred median vectors.
# NOTE(review): assumes both networks label shared haplotypes identically — confirm.
inferred = mjn_nodes.difference(mst_nodes)
print(f"\nMJN inferred {len(inferred)} median vectors")

# Report the density statistic of every network, to three decimals.
for name, network in networks.items():
    density = NetworkStatistics(network).density()
    print(f"{name} density: {density:.3f}")
Decision Guide
def recommend_algorithm(alignment):
    """Suggest a network algorithm from simple alignment characteristics.

    Rules are checked in order and the first match wins:

    1. fewer than 20 sequences          -> MST
    2. pairwise diversity below 0.01    -> TCS
    3. fewer than 100 sequences         -> MJN
    4. anything else                    -> MSN

    Args:
        alignment: Object supporting ``len()`` (taken as the number of
            sequences) and a ``pairwise_diversity()`` method whose result is
            compared against 0.01.

    Returns:
        A string of the form ``"<ALGORITHM> - <reason>"``.
    """
    n_seqs = len(alignment)
    if n_seqs < 20:
        # Decided on size alone, so the diversity statistic is never computed
        # for small (or empty) alignments.
        return "MST - Small dataset, start simple"
    # Diversity is only evaluated once it can affect the outcome.
    if alignment.pairwise_diversity() < 0.01:
        return "TCS - Low diversity, within-species"
    if n_seqs < 100:
        return "MJN - Medium dataset, comprehensive analysis"
    return "MSN - Large dataset, balance speed and information"
# Apply the decision rules to the alignment and print the suggestion.
recommendation = recommend_algorithm(alignment)
print(f"\nRecommendation: {recommendation}")
Computational Performance
import time

# Time one fresh build per algorithm.
# perf_counter() is the highest-resolution clock available for measuring short
# durations; time.time() follows the wall clock, which can be adjusted while a
# build is running and would then skew the measured interval.
times = {}
for name, algorithm in algorithms.items():
    start = time.perf_counter()
    algorithm.build_network(alignment)  # result discarded; only the duration matters here
    elapsed = time.perf_counter() - start
    times[name] = elapsed
    print(f"{name}: {elapsed:.3f} seconds")

# Plot timing
plt.figure(figsize=(8, 6))
# Materialise the dict views so matplotlib receives plain sequences of
# category labels and bar heights.
plt.bar(list(times.keys()), list(times.values()))
plt.xlabel("Algorithm")
plt.ylabel("Time (seconds)")
plt.title("Computational Performance")
plt.savefig("algorithm_timing.png", dpi=300)
Complete Comparison Script
from pypopart import Alignment
from pypopart.algorithms import MSTAlgorithm, MSNAlgorithm, TCSAlgorithm, MJNAlgorithm
from pypopart.visualization import StaticPlot
from pypopart.stats import NetworkStatistics
import matplotlib.pyplot as plt
import pandas as pd
import time
# Load data
alignment = Alignment.from_fasta("sequences.fasta")

# Define algorithms (all constructed with their default settings)
algorithms = {
    "MST": MSTAlgorithm(),
    "MSN": MSNAlgorithm(),
    "TCS": TCSAlgorithm(),
    "MJN": MJNAlgorithm(),
}

# Build and time networks
networks = {}
times = {}
results = []
for name, algorithm in algorithms.items():
    # perf_counter() is the clock intended for measuring short durations;
    # unlike time.time() it is not affected by wall-clock adjustments.
    start = time.perf_counter()
    networks[name] = algorithm.build_network(alignment)
    times[name] = time.perf_counter() - start

    stats = NetworkStatistics(networks[name])
    results.append({
        "Algorithm": name,
        "Nodes": stats.number_of_nodes(),
        "Edges": stats.number_of_edges(),
        # Stored pre-formatted so the printed table shows three decimals.
        "Time (s)": f"{times[name]:.3f}",
    })

# Print comparison
df = pd.DataFrame(results)
print(df.to_string(index=False))
# Create visualizations: one panel per network in a 2x2 grid, each titled with
# the algorithm label and its build time.
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(16, 16))
axes = axes.flatten()
for idx, (name, network) in enumerate(networks.items()):
    ax = axes[idx]
    StaticPlot(network, ax=ax, layout="spring")
    ax.set_title(f"{name} - {times[name]:.2f}s")

plt.tight_layout()
plt.savefig("algorithm_comparison.png", dpi=300)
print("Comparison complete!")
When to Use Each Algorithm
MST
- ✅ Quick exploration
- ✅ Small datasets
- ✅ Simple relationships
- ❌ Ignores alternative paths
MSN
- ✅ Alternative connections
- ✅ Medium datasets
- ✅ Ambiguous relationships
- ❌ Less complete than MJN
TCS
- ✅ Within-species data
- ✅ Statistical justification
- ✅ Population studies
- ❌ May be disconnected
MJN
- ✅ Comprehensive analysis
- ✅ Ancestral inference
- ✅ Publication figures
- ❌ Slower computation
- ❌ Complex interpretation
Next Steps
- Basic Workflow: Complete analysis tutorial
- Visualization Tutorial: Advanced plotting
- Algorithm Guide: Detailed documentation