# Source code for cytomulate.cell_graph_general

# Math computation
import numpy as np

# Graph package and functions
import networkx as nx
import matplotlib.pyplot as plt
from cytomulate.utilities import trajectories

# Typing
from typing import Union, Tuple

class GeneralCellGraph:
    """Base container for a cell differentiation graph and its paths.

    Attributes
    ----------
    graph:
        Directed graph whose nodes are cell-type labels; ``None`` until it
        is assigned. Presumably a ``networkx.DiGraph`` (the methods below
        use ``edges``, ``successors``, ``predecessors``, ``subgraph`` and
        ``to_undirected``) -- confirm against callers.
    n_markers: int
        Number of markers; ``-1`` until ``generate_trajectories`` infers it.
    trajectories: dict
        Maps each directed edge ``(from_label, to_label)`` to a sequence of
        per-marker callables describing the differentiation path.
    """

    def __init__(self) -> None:
        """Initialize the GeneralCellGraph object"""
        self.graph = None
        self.n_markers = -1
        # trajectories would be a dictionary whose keys are edges of a
        # directed graph and whose values would be a list of functions
        # that describes the actual differentiation path
        self.trajectories = {}

    def generate_trajectories(self, cell_types: dict, **kwargs) -> None:
        """Generate the actual differential paths

        For every edge ``(from_label, to_label)`` of ``self.graph``, the
        difference ``cell_types[to_label].cell_mean -
        cell_types[from_label].cell_mean`` is handed to
        ``cytomulate.utilities.trajectories`` and the result is stored in
        ``self.trajectories[(from_label, to_label)]``. If ``n_markers`` is
        not yet positive, it is set to the length of the first difference.

        Parameters
        ----------
        cell_types: dict
            A dictionary of CellType objects
        kwargs:
            Extra parameters needed for non-default path generation
            algorithms, which are passed to
            ``cytomulate.utilities.trajectories``.
        """
        for from_label, to_label in self.graph.edges:
            end_values = (cell_types[to_label].cell_mean
                          - cell_types[from_label].cell_mean)
            if self.n_markers <= 0:
                self.n_markers = len(end_values)
            self.trajectories[(from_label, to_label)] = trajectories(
                end_values=end_values, **kwargs)

    def sample_graph(self,
                     n_samples: int,
                     cell_label: Union[str, int],
                     beta_alpha: Union[float, int] = 0.4,
                     beta_beta: Union[float, int] = 1.0) -> Tuple[np.ndarray, np.ndarray, list]:
        """Draw random samples of a cell type from the cell differentiation graph

        Samples are split uniformly at random (multinomial) among the
        children of ``cell_label``. Rows are grouped by child in the order
        returned by ``self.graph.successors``. Every marker of every sample
        gets its own pseudo time drawn from ``Beta(beta_alpha, beta_beta)``.

        Parameters
        ----------
        n_samples: int
            Number of samples
        cell_label: str or int
            The label of the cell needed
        beta_alpha: float
            The alpha parameter of the beta distribution
        beta_beta: float
            The beta parameter of the beta distribution

        Returns
        -------
        G: np.ndarray
            The additive values of the path, shape ``(n_samples, n_markers)``;
            all zeros when ``cell_label`` has no children, and the scalar
            ``0`` when no trajectories have been generated.
        pseudo_time: np.ndarray
            The pseudo times, same shape as ``G`` (scalar ``0`` when no
            trajectories have been generated).
        labels: list
            The cell types to which each cell is differentiating, or
            ``"None"`` when there is no child to differentiate to.
        """
        labels = ["None"] * n_samples
        if not self.trajectories:
            # Nothing generated yet: nothing to add.
            return 0, 0, labels

        children_cell_labels = list(self.graph.successors(cell_label))
        G = np.zeros((n_samples, self.n_markers))
        pseudo_time = np.zeros((n_samples, self.n_markers))
        if not children_cell_labels:
            return G, pseudo_time, labels

        n_children = len(children_cell_labels)
        n_per_child = np.random.multinomial(
            n_samples, np.ones(n_children) / n_children)
        labels = [child for child, count in zip(children_cell_labels, n_per_child)
                  for _ in range(count)]

        start_n = 0
        for c_label, n in zip(children_cell_labels, n_per_child):
            if n == 0:
                continue
            end_n = start_n + n
            path = self.trajectories[(cell_label, c_label)]
            for m in range(self.n_markers):
                # A fresh pseudo time is drawn for each marker.
                p_time = np.random.beta(beta_alpha, beta_beta, n)
                G[start_n:end_n, m] = path[m](p_time)
                pseudo_time[start_n:end_n, m] = p_time
            start_n = end_n
        return G, pseudo_time, labels

    def visualize_graph(self) -> None:
        """Visualize the cell graph

        Each connected component of ``self.graph`` (ignoring edge direction)
        is drawn in its own matplotlib figure, numbered ``0, 1, ...`` and
        titled ``"Tree <i>"``, using a planar layout.

        Node styling:

        - Nodes without successors are spring green.
        - Other nodes without predecessors (roots) are royal blue.
        - All remaining nodes are magenta.
        - Nodes without predecessors are sized ``150 * max_degree``. The
          size is floored at one degree unit so an isolated node stays
          visible.
        - Every other node is sized ``100 * degree``.
        """
        connected_components = list(nx.connected_components(self.graph.to_undirected()))
        for n_plt, nodes in enumerate(connected_components):
            G = self.graph.subgraph(list(nodes))
            pos = nx.planar_layout(G, scale=20)
            # NOTE(review): this statement was truncated to ``dict(`` in the
            # original; total (in + out) degree is assumed -- confirm.
            d = dict(G.degree)
            nodelist = list(d.keys())
            # A single-node component has degree 0; keep its root visible.
            root_size = max(np.max(list(d.values())), 1) * 150

            nodesize = []
            colors = []
            # Sizes and colours are built over ``nodelist`` so they line up
            # with the ``nodelist`` handed to ``nx.draw``.
            for node in nodelist:
                n_predecessors = len(list(G.predecessors(node)))
                n_successors = len(list(G.successors(node)))
                nodesize.append(root_size if n_predecessors == 0 else d[node] * 100)
                if n_successors == 0:
                    colors.append("springgreen")
                elif n_predecessors == 0:
                    colors.append("royalblue")
                else:
                    colors.append("magenta")

            fig = plt.figure(n_plt)
            ax = fig.add_subplot(1, 1, 1)
            ax.set_title("Tree " + str(n_plt))
            nx.draw(G, pos, ax=ax, with_labels=True, nodelist=nodelist,
                    node_size=nodesize, node_color=colors)