Source code for cytomulate.emulation.cell_graph

# Math computation
import numpy as np
import itertools

# Graph package and functions
import networkx as nx
from networkx.algorithms import tree
from networkx.algorithms.community import greedy_modularity_communities

# Superclass
from cytomulate.cell_graph_general import GeneralCellGraph

# Typing
from typing import Union, Optional


[docs] class EmulationCellGraph(GeneralCellGraph): def __init__(self) -> None: """Initialize the EmulationCellGraph object """ super().__init__() self.complete_undirected_graph = None self.bead_label = None
[docs] def initialize_graph(self, cell_types: dict, bead_label: Optional[Union[str, int]] = None) -> None: """Build a complete graph among all cell types Parameters ---------- cell_types: dict A dictionary of CellType objects bead_label: str or int The label for beads """ self.complete_undirected_graph = nx.Graph() self.bead_label = bead_label for label in cell_types: self.complete_undirected_graph.add_node(label) # The edges consist of n choose 2 elements for labels in itertools.combinations(cell_types.keys(), 2): if (labels[0] == bead_label) or (labels[1] == bead_label): # If any of the "cell types" is bead, we skip # since beads should not differentiate continue else: # The weights are the l2 distance between two means mean1 = cell_types[labels[0]].cell_mean mean2 = cell_types[labels[1]].cell_mean self.complete_undirected_graph.add_edge(labels[0], labels[1], weight=np.linalg.norm(mean1 - mean2))
[docs] def prune_graph(self, graph_topology: str = "tree") -> None: """Construct a tree or a forest from the complete graph Parameters ---------- graph_topology: str The type of the graph desired. Should be either tree or forest """ if graph_topology in ["tree", "forest"]: # We will first construct a minimum spanning tree mst = tree.minimum_spanning_edges(self.complete_undirected_graph, algorithm="kruskal", weight="weight", data=False) mst_G = nx.Graph() mst_G.add_edges_from(list(mst)) if graph_topology == "tree": # If a single tree is desired, we randomly select a node as the root # and construct a directed graph root = set(mst_G.nodes).pop() self.graph = nx.dfs_tree(mst_G, root) else: # If a forest is desired, we use community detection algorithm # to divide the trees forest = list(greedy_modularity_communities(mst_G)) tree_list = [] for t in range(len(forest)): nodes = list(forest[t]) root = np.random.choice(nodes) tree_list.append(nx.dfs_tree(mst_G.subgraph(nodes), root)) self.graph = nx.compose_all(tree_list) # Finally, if bead is present, we add bead to the resulting graph if self.bead_label is not None: self.graph.add_node(self.bead_label) else: raise ValueError('Unknown graph type')