diff --git a/celloracle/oracle_utility/development_analysis.py b/celloracle/oracle_utility/development_analysis.py index 9cff86a..aa96706 100644 --- a/celloracle/oracle_utility/development_analysis.py +++ b/celloracle/oracle_utility/development_analysis.py @@ -45,12 +45,20 @@ def subset_oracle_for_development_analysiis(oracle_object, cluster_column_name, def get_stat_for_inner_product(oracle_object, n_bins=10): # Prepare data - inner_product_stats = pd.DataFrame({"score": oracle_object.inner_product, - "pseudotime": oracle_object.new_pseudotime}) + inner_product_stats = pd.DataFrame({"score": oracle_object.inner_product[~oracle_object.mass_filter], + "pseudotime": oracle_object.new_pseudotime[~oracle_object.mass_filter]}) + bins = _get_bins(inner_product_stats.pseudotime, n_bins) inner_product_stats["pseudotime_id"] = np.digitize(inner_product_stats.pseudotime, bins) - 1 + try: + inner_product_stats["stage"] = oracle_object.stage_grid[~oracle_object.mass_filter] + + except: + print("stage_grid not calculated") + + # stat test ps = [] for i in np.sort(inner_product_stats["pseudotime_id"].unique()): @@ -88,11 +96,17 @@ def extract_data_from_oracle(self, oracle_object, min_mass): setattr(self.oracle_dev, i, getattr(oracle_object, i)) # 1.2. mass_filter for grid matrix - self.oracle_dev.mass_filter = mass_filter = oracle_object.total_p_mass < min_mass + self.oracle_dev.mass_filter = (oracle_object.total_p_mass < min_mass) ## 2. Extract pseudotime data self.oracle_dev.pseudotime = oracle_object.adata.obs["pseudotime"].values + try: + self.oracle_dev.stage = np.array(oracle_object.adata.obs["Stage"].values) + except: + print("Stage not in data") + + def transfer_data_into_grid(self, args={}): @@ -105,15 +119,29 @@ def transfer_data_into_grid(self, args={}): grid=self.oracle_dev.flow_grid, value=self.oracle_dev.pseudotime, **args) + try: + self.oracle_dev.stage_grid = scatter_value_to_grid_value(embedding=self.oracle_dev.embedding, + grid=self.oracle_dev.flow_grid, + value=self.oracle_dev.stage, + **{"method": "knn_class", + "n_knn": 30}) + except: + print("Stage not in data") - def calculate_gradient_and_inner_product(self, scale_factor=60): + def calculate_gradient_and_inner_product(self, scale_factor="l2_norm_mean", normalization=None): # Gradient calculation - new_pseudotime = self.oracle_dev.new_pseudotime - n = int(np.sqrt(new_pseudotime.shape[0])) - new_pseudotime_as_grid = new_pseudotime.reshape(n, n) - dy, dx = np.gradient(new_pseudotime_as_grid) - self.oracle_dev.gradient = np.stack([dx.flatten(), dy.flatten()], axis=1) * scale_factor + gradient = get_gradient(value_on_grid=self.oracle_dev.new_pseudotime.copy()) + + if normalization == "sqrt": + gradient = normalize_gradient(gradient, method="sqrt") + + if scale_factor == "l2_norm_mean": + # divide gradient by the mean of l2 norm. + l2_norm = np.linalg.norm(gradient, ord=2, axis=1) + scale_factor = 1 / l2_norm.mean() + + self.oracle_dev.gradient = gradient * scale_factor # Calculate inner product between the pseudotime-gradient and the perturb-gradient self.oracle_dev.inner_product = np.array([np.dot(i, j) for i, j in zip(self.oracle_dev.flow, self.oracle_dev.gradient)]) @@ -125,3 +153,282 @@ def calculate_stats(self, n_bins=10): self.oracle_dev.inner_product_stats = inner_product_stats self.oracle_dev.inner_product_stats_grouped = inner_product_stats_grouped + + + +class Gradient_based_trajecory(): + def __init__(self, adata=None, obsm_key=None, pseudotime_key="pseudotime", cluster_column_name=None, cluster=None, gt=None): + + if adata is not None: + self.load_adata(adata=adata, obsm_key=obsm_key, + pseudotime_key=pseudotime_key,cluster_column_name=cluster_column_name, + cluster=cluster) + elif gt is not None: + self.embedding = gt.embedding_whole.copy() + self.embedding_whole = gt.embedding_whole.copy() + self.mass_filter = gt.mass_filter_whole.copy() + self.mass_filter_whole = gt.mass_filter_whole.copy() + self.gridpoints_coordinates = gt.gridpoints_coordinates.copy() + self.pseudotime = gt.pseudotime_whole.copy() + + def load_adata(self, adata, obsm_key, pseudotime_key, cluster_column_name=None, cluster=None): + + self.embedding = adata.obsm[obsm_key] + self.pseudotime = adata.obs[pseudotime_key].values + self.embedding_whole = self.embedding.copy() + self.pseudotime_whole = self.pseudotime.copy() + + if (cluster_column_name is not None) & (cluster is not None): + cells_ix = np.where(adata.obs[cluster_column_name] == cluster)[0] + self.embedding = self.embedding[cells_ix, :] + self.pseudotime = self.pseudotime[cells_ix] + + + + + + def calculate_mass_filter(self, min_mass=0.01, smooth=0.8, steps=(40, 40), n_neighbors=200, n_jobs=4): + + x_min, y_min = self.embedding_whole.min(axis=0) + x_max, y_max = self.embedding_whole.max(axis=0) + xylim = ((x_min, x_max), (y_min, y_max)) + + total_p_mass, gridpoints_coordinates = calculate_p_mass(self.embedding, smooth=smooth, steps=steps, + n_neighbors=n_neighbors, n_jobs=n_jobs, xylim=xylim) + + total_p_mass_whole, _ = calculate_p_mass(self.embedding_whole, smooth=smooth, steps=steps, + n_neighbors=n_neighbors, n_jobs=n_jobs, xylim=xylim) + + self.total_p_mass = total_p_mass + self.mass_filter = (total_p_mass < min_mass) + self.mass_filter_whole = (total_p_mass_whole < min_mass) + self.gridpoints_coordinates = gridpoints_coordinates + + def transfer_data_into_grid(self, args={}): + + if not args: + args = {"method": "knn", + "n_knn": 30} + + self.pseudotime_on_grid = scatter_value_to_grid_value(embedding=self.embedding, + grid=self.gridpoints_coordinates, + value=self.pseudotime, + **args) + def calculate_gradient(self, scale_factor=60, normalization=None): + + # Gradient calculation + gradient = get_gradient(value_on_grid=self.pseudotime_on_grid.copy()) + + if normalization == "sqrt": + gradient = normalize_gradient(gradient, method="sqrt") + + if scale_factor == "l2_norm_mean": + # divide gradient by the mean of l2 norm. + l2_norm = np.linalg.norm(gradient, ord=2, axis=1) + scale_factor = 1 / l2_norm.mean() + + self.gradient = gradient * scale_factor + + def visualize_dev_flow(self, scale_for_pseudotime=30, s=10, s_grid=30): + visualize_dev_flow(self, scale_for_pseudotime=scale_for_pseudotime, s=s, s_grid=s_grid) + +def aggregate_GT_object(list_GT_object, base_gt=None): + + pseudotime_stack = [i.pseudotime_on_grid for i in list_GT_object] + gradient_stack = [i.gradient for i in list_GT_object] + mass_filter_stack = [i.mass_filter for i in list_GT_object] + + new_pseudotime, new_gradient, new_mass_filter = _aggregate_gradients(pseudotime_stack=pseudotime_stack, + gradient_stack=gradient_stack, + mass_filter_stack=mass_filter_stack) + + if base_gt is None: + gt = Gradient_based_trajecory(gt=list_GT_object[0]) + gt.pseudotime_on_grid = new_pseudotime + gt.gradient = new_gradient + + else: + gt = base_gt + gt.pseudotime_on_grid[~new_mass_filter] = new_pseudotime[~new_mass_filter] + gt.gradient[~new_mass_filter, :] = new_gradient[~new_mass_filter, :] + + return gt + +def _aggregate_gradients(pseudotime_stack, gradient_stack, mass_filter_stack): + + new_pseudotime = np.zeros_like(pseudotime_stack[0]) + new_pseudotime_count = np.zeros_like(pseudotime_stack[0]) + new_gradient = np.zeros_like(gradient_stack[0]) + gradient_count = np.zeros_like(gradient_stack[0]) + for fil, pt, gra in zip(mass_filter_stack, pseudotime_stack, gradient_stack): + new_pseudotime[~fil] += pt[~fil] + new_pseudotime_count[~fil] +=1 + new_gradient[~fil, :] += gra[~fil, :] + gradient_count[~fil, :] += 1 + + new_pseudotime[new_pseudotime_count != 0] /= new_pseudotime_count[new_pseudotime_count != 0] + new_gradient[gradient_count != 0] /= gradient_count[gradient_count != 0] + new_mass_filter = (gradient_count.sum(axis=1) == 0) + + return new_pseudotime, new_gradient, new_mass_filter + + +def normalize_gradient(gradient, method="sqrt"): + """ + Normalize length of 2D vector + """ + + if method == "sqrt": + + size = np.sqrt(np.power(gradient, 2).sum(axis=1)) + size_sq = np.sqrt(size) + size_sq[size_sq == 0] = 1 + factor = np.repeat(np.expand_dims(size_sq, axis=1), 2, axis=1) + + return gradient / factor + +from scipy.stats import norm as normal +from sklearn.neighbors import NearestNeighbors + + +def calculate_p_mass(embedding, smooth=0.5, steps=(40, 40), + n_neighbors=100, n_jobs=4, xylim=((None, None), (None, None))): + """Calculate the velocity using a points on a regular grid and a gaussian kernel + + Note: the function should work also for n-dimensional grid + + Arguments + --------- + embedding: + + smooth: float, smooth=0.5 + Higher value correspond to taking in consideration further points + the standard deviation of the gaussian kernel is smooth * stepsize + steps: tuple, default + the number of steps in the grid for each axis + n_neighbors: + number of neighbors to use in the calculation, bigger number should not change too much the results.. + ...as soon as smooth is small + Higher value correspond to slower execution time + n_jobs: + number of processes for parallel computing + xymin: + ((xmin, xmax), (ymin, ymax)) + + Returns + ------- + total_p_mass: np.ndarray + density at each point of the grid + + """ + + # Prepare the grid + grs = [] + for dim_i in range(embedding.shape[1]): + m, M = np.min(embedding[:, dim_i]), np.max(embedding[:, dim_i]) + + if xylim[dim_i][0] is not None: + m = xylim[dim_i][0] + if xylim[dim_i][1] is not None: + M = xylim[dim_i][1] + + m = m - 0.025 * np.abs(M - m) + M = M + 0.025 * np.abs(M - m) + gr = np.linspace(m, M, steps[dim_i]) + grs.append(gr) + + meshes_tuple = np.meshgrid(*grs) + gridpoints_coordinates = np.vstack([i.flat for i in meshes_tuple]).T + + nn = NearestNeighbors(n_neighbors=n_neighbors, n_jobs=n_jobs) + nn.fit(embedding) + dists, neighs = nn.kneighbors(gridpoints_coordinates) + + std = np.mean([(g[1] - g[0]) for g in grs]) + # isotropic gaussian kernel + gaussian_w = normal.pdf(loc=0, scale=smooth * std, x=dists) + total_p_mass = gaussian_w.sum(1) + gridpoints_coordinates + + return total_p_mass, gridpoints_coordinates + +def get_gradient(value_on_grid): + # Gradient calculation + n = int(np.sqrt(value_on_grid.shape[0])) + value_on_grid_as_matrix = value_on_grid.reshape(n, n) + dy, dx = np.gradient(value_on_grid_as_matrix) + gradient = np.stack([dx.flatten(), dy.flatten()], axis=1) + + return gradient + + +def visualize_dev_flow(self, scale_for_pseudotime=30, s=10, s_grid=30): + + embedding_whole = self.embedding_whole + embedding_of_interest= self.embedding + mass_filter = self.mass_filter + mass_filter_whole = self.mass_filter_whole + gridpoints_coordinates=self.gridpoints_coordinates + + pseudotime_raw = self.pseudotime + pseudotime_on_grid=self.pseudotime_on_grid + + gradient_pseudotime=self.gradient + + + fig, ax = plt.subplots(1, 5, figsize=[25,5]) + + ## + ax_ = ax[0] + ax_.scatter(embedding_whole[:, 0], embedding_whole[:, 1], c="lightgray", s=s) + ax_.scatter(embedding_of_interest[:, 0], embedding_of_interest[:, 1], c=pseudotime_raw, cmap="rainbow", s=s) + ax_.set_title("Pseudotime") + ax_.axis("off") + + #### + ax_ = ax[1] + ax_.scatter(gridpoints_coordinates[mass_filter, 0], gridpoints_coordinates[mass_filter, 1], s=0) + ax_.scatter(gridpoints_coordinates[~mass_filter_whole, 0], gridpoints_coordinates[~mass_filter_whole, 1], + c="lightgray", s=s_grid) + ax_.scatter(gridpoints_coordinates[~mass_filter, 0], gridpoints_coordinates[~mass_filter, 1], + c=pseudotime_on_grid[~mass_filter], cmap="rainbow", s=s_grid) + ax_.set_title("Pseudotime on grid") + ax_.axis("off") + + + + ### + ax_ = ax[2] + #ax_.scatter(gridpoints_coordinates[mass_filter, 0], gridpoints_coordinates[mass_filter, 1], s=0) + ax_.scatter(gridpoints_coordinates[mass_filter, 0], gridpoints_coordinates[mass_filter, 1], s=0) + ax_.scatter(gridpoints_coordinates[~mass_filter_whole, 0], gridpoints_coordinates[~mass_filter_whole, 1], + c="lightgray", s=s_grid) + ax_.scatter(gridpoints_coordinates[~mass_filter, 0], gridpoints_coordinates[~mass_filter, 1], + c=pseudotime_on_grid[~mass_filter], cmap="rainbow", s=s_grid) + + ax_.quiver(gridpoints_coordinates[~mass_filter, 0], gridpoints_coordinates[~mass_filter, 1], + gradient_pseudotime[~mass_filter, 0], gradient_pseudotime[~mass_filter, 1], + scale=scale_for_pseudotime) + ax_.set_title("Gradient of pseudotime \n(=Development flow)") + ax_.axis("off") + + ### + ax_ = ax[3] + #ax_.scatter(gridpoints_coordinates[mass_filter, 0], gridpoints_coordinates[mass_filter, 1], s=0) + ax_.scatter(embedding_whole[:, 0], embedding_whole[:, 1], c="lightgray", s=s) + ax_.quiver(gridpoints_coordinates[~mass_filter, 0], gridpoints_coordinates[~mass_filter, 1], + gradient_pseudotime[~mass_filter, 0], gradient_pseudotime[~mass_filter, 1], + scale=scale_for_pseudotime) + ax_.set_title("Gradient of pseudotime \n(=Development flow)") + ax_.axis("off") + + #### + ax_ = ax[4] + ax_.scatter(embedding_whole[:, 0], embedding_whole[:, 1], c="lightgray", s=s) + ax_.scatter(embedding_of_interest[:, 0], embedding_of_interest[:, 1], c=pseudotime_raw, cmap="rainbow", s=s) + + ax_.quiver(gridpoints_coordinates[~mass_filter, 0], gridpoints_coordinates[~mass_filter, 1], + gradient_pseudotime[~mass_filter, 0], gradient_pseudotime[~mass_filter, 1], + scale=scale_for_pseudotime) + ax_.set_title("Pseudotime + \nDevelopment flow") + ax_.axis("off") diff --git a/celloracle/oracle_utility/interactive_simulation_and_plot.py b/celloracle/oracle_utility/interactive_simulation_and_plot.py index 984308d..3ad4883 100644 --- a/celloracle/oracle_utility/interactive_simulation_and_plot.py +++ b/celloracle/oracle_utility/interactive_simulation_and_plot.py @@ -42,12 +42,13 @@ class Oracle_extended(Oracle_data_strage, Oracle_development_module): - def __init__(self, oracle, hdf_path, mode): + def __init__(self, oracle, hdf_path, mode, obsm_key="X_umap"): self.oracle = oracle self.gene = None self.n_neighbors = None self.n_grid = None self.names = [] + self.obsm_key = obsm_key self.set_hdf_path(path=hdf_path, create_if_not_exist=True) @@ -144,10 +145,10 @@ def interactive_plot_grid(self, plt.figure(None,(6,6)) if self.n_grid is None: - self.oracle.calculate_grid_arrows(smooth=0.8, steps=(n_grid, n_grid), n_neighbors=300) + self.oracle.calculate_grid_arrows(smooth=0.8, steps=(n_grid, n_grid), n_neighbors=200) else: if n_grid != self.n_grid: - self.oracle.calculate_grid_arrows(smooth=0.8, steps=(n_grid, n_grid), n_neighbors=300) + self.oracle.calculate_grid_arrows(smooth=0.8, steps=(n_grid, n_grid), n_neighbors=200) plt.title(f"Perturb simulation: {self.gene}") @@ -211,7 +212,7 @@ def save_development_analysis_results(self, gene, cluster_column_name, cluster): place=f"dev_analysis/{cluster_column_name}/{cluster}/{gene}", attributes=["embedding", "pseudotime", "mass_filter", "flow_grid", "flow", "flow_norm_rndm", - "new_pseudotime", "gradient", "inner_product"]) + "new_pseudotime", "gradient", "inner_product", "stage", "stage_grid"]) self.save_dfs(oracle=self.oracle_dev, place=f"inner_product/{cluster_column_name}/{cluster}/{gene}", @@ -231,7 +232,7 @@ def load_development_analysis_results(self, gene, cluster_column_name, cluster): attributes=["embedding", "pseudotime", "mass_filter", "flow_grid", "flow", "flow_norm_rndm", "new_pseudotime", "gradient", - "inner_product"]) + "inner_product", "stage", "stage_grid"]) self.load_dfs(oracle=self.oracle_dev, place=f"inner_product/{cluster_column_name}/{cluster}/{gene}", diff --git a/celloracle/oracle_utility/make_figure.py b/celloracle/oracle_utility/make_figure.py new file mode 100644 index 0000000..72b957b --- /dev/null +++ b/celloracle/oracle_utility/make_figure.py @@ -0,0 +1,285 @@ +# -*- coding: utf-8 -*- + + +import os, sys + +import pandas as pd +import numpy as np +import matplotlib.pyplot as plt +import seaborn as sns + +from ..trajectory.oracle_utility import _adata_to_color_dict + + +def _get_ix_for_a_cluster(oracle, cluster_column_name, cluster): + ix = np.arange(oracle.adata.shape[0])[oracle.adata.obs[cluster_column_name] == cluster] + return ix + +def _plot_quiver_for_a_cluster(oracle, cluster_column_name, cluster, quiver_scale, color=None, plot_whole_cells=True, args={}): + + if cluster == "whole": + ix_choice = ix = np.arange(oracle.adata.shape[0]) + else: + ix_choice = _get_ix_for_a_cluster(oracle, cluster_column_name, cluster) + + + if plot_whole_cells: + + plt.scatter(oracle.embedding[:, 0], oracle.embedding[:, 1], + c="lightgray", alpha=1, lw=0.3, rasterized=True, **args) + + plt.scatter(oracle.embedding[ix_choice, 0], oracle.embedding[ix_choice, 1], + c="lightgray", alpha=0.2, edgecolor=(0,0,0,1), lw=0.3, rasterized=True, **args) + + + + if color is None: + color=oracle.colorandum[ix_choice] + + quiver_kwargs=dict(headaxislength=7, headlength=11, headwidth=8, + linewidths=0.25, width=0.0045,edgecolors="k", + color=color, alpha=1) + + plt.quiver(oracle.embedding[ix_choice, 0], oracle.embedding[ix_choice, 1], + oracle.delta_embedding[ix_choice, 0], + oracle.delta_embedding[ix_choice, 1], + scale=quiver_scale, **quiver_kwargs) + + plt.axis("off") + +def plot_scatter_with_anndata(adata, obsm_key, cluster_column_name, args={}): + + embedding = adata.obsm[obsm_key] + colors = _adata_to_color_dict(adata=adata, cluster_use=cluster_column_name) + + for cluster, color in colors.items(): + idx = np.where(adata.obs[cluster_column_name] == cluster)[0] + plt.scatter(embedding[idx, 0], embedding[idx, 1], c=color, label=cluster, **args) + + + + + +def figures_for_trajectories301(self, save_folder, scale_for_pseudotime=30, scale_for_simulated=30, quiver_scale=30, s=10, s_grid=30, vmin=-1, vmax=1, figsize=[5, 5], fontsize=15): + + whole_embedding = self.oracle.embedding + original_embedding=self.oracle_dev.embedding + original_value=self.oracle_dev.pseudotime + mass_filter=self.oracle_dev.mass_filter + grid=self.oracle_dev.flow_grid + value_on_grid=self.oracle_dev.new_pseudotime + gradient_pseudotime=self.oracle_dev.gradient + gradient_simulated=self.oracle_dev.flow + inner_product=self.oracle_dev.inner_product + + inner_product_stats = self.oracle_dev.inner_product_stats + inner_product_stats_grouped = self.oracle_dev.inner_product_stats_grouped + + alpha = 1 + + cluster = self.oracle_dev.cluster_loaded + #if cluster == "True": + # cluster = True + + cluster_column_name = self.oracle_dev.cluster_column_name_loaded + + + + fig = plt.figure(figsize=figsize) + if cluster == "whole": + plot_scatter_with_anndata(adata=self.oracle.adata, obsm_key="X_umap", + cluster_column_name=cluster_column_name, + args={"s": s}) + + else: + cluster_color = _adata_to_color_dict(self.oracle.adata, cluster_column_name)[cluster] + + plt.scatter(whole_embedding[:, 0], whole_embedding[:, 1], c="lightgray", s=s) + plt.scatter(original_embedding[:, 0], original_embedding[:, 1], c=cluster_color, s=s) + plt.axis("off") + plt.savefig(os.path.join(save_folder, f"scatter_{cluster_column_name}_{cluster}.png"), transparent=True) + + + ## + fig = plt.figure(figsize=figsize) + plt.scatter(whole_embedding[:, 0], whole_embedding[:, 1], c="lightgray", s=s) + plt.scatter(original_embedding[:, 0], original_embedding[:, 1], c=original_value, cmap="rainbow", s=s) + #plt.title("Pseudotime") + plt.axis("off") + + plt.savefig(os.path.join(save_folder, f"pseudotime_{cluster_column_name}_{cluster}.png"), transparent=True) + + + #### + ### + fig = plt.figure(figsize=figsize) + plt.scatter(whole_embedding[:, 0], whole_embedding[:, 1], c="lightgray", s=s) + #plt.scatter(original_embedding[:, 0], original_embedding[:, 1], c=original_value, cmap="rainbow", s=s) + plt.quiver(grid[~mass_filter, 0], grid[~mass_filter, 1], + gradient_pseudotime[~mass_filter, 0], gradient_pseudotime[~mass_filter, 1], + scale=scale_for_pseudotime) + #plt.title("Differentiation") + plt.axis("off") + plt.savefig(os.path.join(save_folder, f"differentiation_{cluster_column_name}_{cluster}.png"), transparent=True) + + + + fig = plt.figure(figsize=figsize) + if cluster == "whole": + plot_scatter_with_anndata(adata=self.oracle.adata, obsm_key="X_umap", + cluster_column_name=cluster_column_name, + args={"s": s}) + else: + plt.scatter(whole_embedding[:, 0], whole_embedding[:, 1], c="lightgray", s=s) + plt.scatter(original_embedding[:, 0], original_embedding[:, 1], c=cluster_color, s=s, alpha=alpha) + #plt.scatter(original_embedding[:, 0], original_embedding[:, 1], c=original_value, cmap="rainbow", s=s) + plt.quiver(grid[~mass_filter, 0], grid[~mass_filter, 1], + gradient_pseudotime[~mass_filter, 0], gradient_pseudotime[~mass_filter, 1], + scale=scale_for_pseudotime) + #plt.title("Gradient of pseudotime \n(=Development flow)") + plt.axis("off") + + plt.savefig(os.path.join(save_folder, f"differentiation_with_cells_of_interest_{cluster_column_name}_{cluster}.png"), transparent=True) + + + +def figures_for_perturb_analysis_301(self, save_folder, scale_for_pseudotime=30, scale_for_simulated=30, quiver_scale=30, s=10, s_grid=30, vmin=-1, vmax=1, figsize=[5, 5], fontsize=15): + + whole_embedding = self.oracle.embedding + original_embedding=self.oracle_dev.embedding + original_value=self.oracle_dev.pseudotime + mass_filter=self.oracle_dev.mass_filter + grid=self.oracle_dev.flow_grid + value_on_grid=self.oracle_dev.new_pseudotime + gradient_pseudotime=self.oracle_dev.gradient + gradient_simulated=self.oracle_dev.flow + inner_product=self.oracle_dev.inner_product + + inner_product_stats = self.oracle_dev.inner_product_stats + inner_product_stats_grouped = self.oracle_dev.inner_product_stats_grouped + + alpha = 1 + + cluster = self.oracle_dev.cluster_loaded + #if cluster == "True": + # cluster = True + + cluster_column_name = self.oracle_dev.cluster_column_name_loaded + + if cluster == "whole": + pass + else: + cluster_color = _adata_to_color_dict(self.oracle.adata, cluster_column_name)[cluster] + + + + ##### + fig = plt.figure(figsize=figsize) + #plt.title(f"Perturb simulation \n color: {cluster_column_name}") + if cluster == "whole": + _plot_quiver_for_a_cluster(oracle=self.oracle, + cluster_column_name=cluster_column_name, + color=None, + cluster=cluster, quiver_scale=30, args={"s": s}) + else: + _plot_quiver_for_a_cluster(oracle=self.oracle, + cluster_column_name=cluster_column_name, + color=cluster_color, + cluster=cluster, quiver_scale=30, args={"s": s}) + plt.axis("off") + plt.savefig(os.path.join(save_folder, f"quiver_full_on_cellls_of_interest_{cluster_column_name}_{cluster}.png"), transparent=True) + + + ####### + fig = plt.figure(figsize=figsize) + #plt.title("Perturb simulation \n color: cluster") + _plot_quiver_for_a_cluster(oracle=self.oracle, + cluster_column_name=cluster_column_name, + cluster=cluster, quiver_scale=30, args={"s": s}) + plt.axis("off") + plt.savefig(os.path.join(save_folder, f"quiver_ful_on_cluster_{cluster_column_name}_{cluster}.png"), transparent=True) + + + + ######## + fig = plt.figure(figsize=figsize) + plt.scatter(whole_embedding[:, 0], whole_embedding[:, 1], c="lightgray", s=s) + plt.quiver(grid[~mass_filter, 0], grid[~mass_filter, 1], + gradient_simulated[~mass_filter, 0], gradient_simulated[~mass_filter, 1], + scale=scale_for_simulated, zorder=20000) + #plt.title("Perturb simulation result on grid") + plt.axis("off") + plt.savefig(os.path.join(save_folder, f"quiver_grid_{cluster_column_name}_{cluster}.png"), transparent=True) + + + ########## + fig = plt.figure(figsize=figsize) + if cluster == "whole": + plot_scatter_with_anndata(adata=self.oracle.adata, obsm_key="X_umap", + cluster_column_name=cluster_column_name, + args={"s": s}) + else: + plt.scatter(whole_embedding[:, 0], whole_embedding[:, 1], c="lightgray", s=s) + plt.scatter(original_embedding[:, 0], original_embedding[:, 1], c=cluster_color, s=s, alpha=alpha) + plt.quiver(grid[~mass_filter, 0], grid[~mass_filter, 1], + gradient_simulated[~mass_filter, 0], gradient_simulated[~mass_filter, 1], + scale=scale_for_simulated, zorder=20000) + #plt.title("Perturb simulation result on grid") + plt.axis("off") + plt.savefig(os.path.join(save_folder, f"quiver_grid_on_cells_of_interest_{cluster_column_name}_{cluster}.png"), transparent=True) + + + + ######### + + fig = plt.figure(figsize=figsize) + plt.scatter(grid[mass_filter, 0], grid[mass_filter, 1], s=0) + plt.scatter(grid[~mass_filter, 0], grid[~mass_filter, 1], c=inner_product[~mass_filter], + cmap="coolwarm", s=s_grid, vmin=vmin, vmax=vmax) + + plt.axis("off") + #plt.title("Inner product of \n Perturb simulation * Development flow") + plt.savefig(os.path.join(save_folder, f"inner_product_score_{cluster_column_name}_{cluster}.png"), transparent=True) + + + + fig = plt.figure(figsize=figsize) + plt.scatter(grid[mass_filter, 0], grid[mass_filter, 1], s=0) + plt.scatter(grid[~mass_filter, 0], grid[~mass_filter, 1], c=inner_product[~mass_filter], + cmap="coolwarm", s=s_grid, vmin=vmin, vmax=vmax) + + plt.quiver(grid[~mass_filter, 0], grid[~mass_filter, 1], + gradient_simulated[~mass_filter, 0], gradient_simulated[~mass_filter, 1], + scale=scale_for_simulated, zorder=20000) + plt.axis("off") + #plt.title("Inner product of \n Perturb simulation * Development flow \n + Perturb simulation") + plt.savefig(os.path.join(save_folder, f"inner_product_score_and_perturb_quiver_grid_{cluster_column_name}_{cluster}.png"), transparent=True) + + + ##### + #fig = plt.figure(figsize=figsize) + fig, ax = plt.subplots(figsize=figsize) + pcm = ax.scatter(value_on_grid[~mass_filter], inner_product[~mass_filter], + c=inner_product[~mass_filter], cmap="coolwarm", + vmin=vmin, vmax=vmax, s=s_grid) + + plt.ylim([vmin*1.1, vmax*1.1]) + plt.axhline(0, color="lightgray") + pp = fig.colorbar(pcm, ax=ax, orientation="vertical") + sns.despine() + plt.xlabel("pseudotime") + plt.ylabel("inner product score") + plt.savefig(os.path.join(save_folder, f"inner_product_score_distribution_{cluster_column_name}_{cluster}.png"), transparent=True) + + + fig = plt.figure(figsize=figsize) + #fig, ax = plt.subplots(figsize=figsize) + sns.boxplot(data=inner_product_stats, x="pseudotime_id", y="score", color="white") + plt.xlabel("Digitized_pseudotime") + plt.ylabel("inner product score") + plt.axhline(0, color="gray") + plt.ylim([vmin*1.1, vmax*1.1]) + plt.tick_params( + labelleft=False) + plt.show() + plt.savefig(os.path.join(save_folder, f"inner_product_score_distribution_box_plot_{cluster_column_name}_{cluster}.png"), transparent=True) diff --git a/celloracle/oracle_utility/scatter_to_grid.py b/celloracle/oracle_utility/scatter_to_grid.py index 5e53714..52e18e8 100644 --- a/celloracle/oracle_utility/scatter_to_grid.py +++ b/celloracle/oracle_utility/scatter_to_grid.py @@ -10,7 +10,8 @@ import numpy as np from sklearn.linear_model import Ridge -from sklearn.neighbors import KNeighborsRegressor +from sklearn.neighbors import KNeighborsRegressor, KNeighborsClassifier +from sklearn.preprocessing import PolynomialFeatures def scatter_value_to_grid_value(embedding, grid, value, method="knn", n_knn=30, n_poly=3): @@ -33,15 +34,35 @@ def scatter_value_to_grid_value(embedding, grid, value, method="knn", n_knn=30, x_new, y_new = grid[:, 0], grid[:, 1] if method == "poly": - value_on_grid = _polynomial_regression(x, y, x_new, y_new, value, n_degree=n_poly) + value_on_grid = _polynomial_regression_old_ver(x, y, x_new, y_new, value, n_degree=n_poly) + if method == "polynomial": + value_on_grid = _polynomial_regression_sklearn(x, y, x_new, y_new, value, n_degree=n_poly) elif method == "knn": value_on_grid = _knn_regression(x, y, x_new, y_new, value, n_knn=n_knn) + elif method == "knn_class": + value_on_grid = _knn_classification(x, y, x_new, y_new, value, n_knn=n_knn) + return value_on_grid +def _polynomial_regression_sklearn(x, y, x_new, y_new, value, n_degree=3): + + # Make polynomial features + data = np.stack([x, y], axis=1) + data_new = np.stack([x_new, y_new], axis=1) + + pol = PolynomialFeatures(degree=n_degree, include_bias=False) + data = pol.fit_transform(data) + data_new = pol.transform(data_new) + + + model = Ridge(random_state=123) + model.fit(data, value) + + return model.predict(data_new) -def _polynomial_regression(x, y, x_new, y_new, value, n_degree=3): +def _polynomial_regression_old_ver(x, y, x_new, y_new, value, n_degree=3): def __conv(x, y, n_degree=3): # Make polynomial data for polynomial ridge regression dic = {} @@ -69,3 +90,14 @@ def _knn_regression(x, y, x_new, y_new, value, n_knn=30): data_new = np.stack([x_new, y_new], axis=1) return model.predict(data_new) + +def _knn_classification(x, y, x_new, y_new, value, n_knn=30): + + data = np.stack([x, y], axis=1) + + model = KNeighborsClassifier(n_neighbors=n_knn) + model.fit(data, value) + + data_new = np.stack([x_new, y_new], axis=1) + + return model.predict(data_new) diff --git a/celloracle/oracle_utility/utility.py b/celloracle/oracle_utility/utility.py index 98e5743..eb95f29 100644 --- a/celloracle/oracle_utility/utility.py +++ b/celloracle/oracle_utility/utility.py @@ -63,7 +63,11 @@ def save_data(self, oracle, place, attributes): name_ = f"{place}/{j}" if name_ in self.names: del f[name_] - f[name_] = getattr(oracle, j) + att = getattr(oracle, j) + try: + f[name_] = att + except: + f[name_] = att.astype(h5py.string_dtype(encoding='utf-8')) self.names.append(name_) self.names = list(set(self.names)) diff --git a/celloracle/oracle_utility/visualization.py b/celloracle/oracle_utility/visualization.py index 18b89a5..e6dd06e 100644 --- a/celloracle/oracle_utility/visualization.py +++ b/celloracle/oracle_utility/visualization.py @@ -114,7 +114,7 @@ def visualize_developmental_analysis_ver2(self, scale_for_pseudotime=30, scale_f fig, ax = plt.subplots(2, 4, figsize=[20,10]) if cluster == "whole": - plot_scatter_with_anndata(adata=self.oracle.adata, obsm_key="X_umap", + plot_scatter_with_anndata(adata=self.oracle.adata, obsm_key=self.obsm_key, cluster_column_name=cluster_column_name, ax=ax[0, 0], args={"s": s}) ax[0, 0].set_title(f"Cluster of interest: all clusters") @@ -182,3 +182,742 @@ def visualize_developmental_analysis_ver2(self, scale_for_pseudotime=30, scale_f ax_.set_ylim([vmin*1.1, vmax*1.1]) ax_.tick_params( labelleft=False) + + + + +def visualize_developmental_analysis_ver101(self, scale_for_pseudotime=30, scale_for_simulated=30, s=10, s_grid=30, vmin=-1, vmax=1): + + whole_embedding = self.oracle.embedding + original_embedding=self.oracle_dev.embedding + original_value=self.oracle_dev.pseudotime + mass_filter=self.oracle_dev.mass_filter + grid=self.oracle_dev.flow_grid + value_on_grid=self.oracle_dev.new_pseudotime + gradient_pseudotime=self.oracle_dev.gradient + gradient_simulated=self.oracle_dev.flow + inner_product=self.oracle_dev.inner_product + + + cluster = self.oracle_dev.cluster_loaded + #if cluster == "True": + # cluster = True + + cluster_column_name = self.oracle_dev.cluster_column_name_loaded + + + fig, ax = plt.subplots(2, 4, figsize=[20,10]) + + ax_ = ax[0, 0] + + if cluster == "whole": + plot_scatter_with_anndata(adata=self.oracle.adata, obsm_key=self.obsm_key, + cluster_column_name=cluster_column_name, + ax=ax[0, 0], args={"s": s}) + ax_.set_title(f"Cluster of interest: all clusters") + + else: + cluster_color = _adata_to_color_dict(self.oracle.adata, cluster_column_name)[cluster] + + + ax_.scatter(whole_embedding[:, 0], whole_embedding[:, 1], c="lightgray", s=s) + ax_.scatter(original_embedding[:, 0], original_embedding[:, 1], c=cluster_color, s=s) + if cluster == "True" : + ax_.set_title(f"Cells of interest: \n{cluster_column_name}") + + ax_.axis("off") + + ## + ax_ = ax[0, 1] + ax_.scatter(whole_embedding[:, 0], whole_embedding[:, 1], c="lightgray", s=s) + ax_.scatter(original_embedding[:, 0], original_embedding[:, 1], c=original_value, cmap="rainbow", s=s) + ax_.set_title("Pseudotime") + ax_.axis("off") + + #### + ax_ = ax[0, 2] + ax_.scatter(grid[mass_filter, 0], grid[mass_filter, 1], s=0) + ax_.scatter(grid[~mass_filter, 0], grid[~mass_filter, 1], c=value_on_grid[~mass_filter], cmap="rainbow", s=s_grid) + ax_.set_title("Pseudotime on grid") + ax_.axis("off") + + + ### + ax_ = ax[0, 3] + ax_.scatter(grid[mass_filter, 0], grid[mass_filter, 1], s=0) + ax_.quiver(grid[~mass_filter, 0], grid[~mass_filter, 1], + gradient_pseudotime[~mass_filter, 0], gradient_pseudotime[~mass_filter, 1], + scale=scale_for_pseudotime) + ax_.set_title("Gradient of pseudotime \n(=Development flow)") + ax_.axis("off") + + + #### + ax_ = ax[1, 0] + ax_.scatter(whole_embedding[:, 0], whole_embedding[:, 1], c="lightgray", s=s) + ax_.scatter(original_embedding[:, 0], original_embedding[:, 1], c=original_value, cmap="rainbow", s=s) + ax_.quiver(grid[~mass_filter, 0], grid[~mass_filter, 1], + gradient_pseudotime[~mass_filter, 0], gradient_pseudotime[~mass_filter, 1], + scale=scale_for_pseudotime) + ax_.set_title("Pseudotime + \nDevelopment flow") + ax_.axis("off") + + #### + ax_ = ax[1, 1] + ax_.scatter(grid[mass_filter, 0], grid[mass_filter, 1], s=0) + ax_.quiver(grid[~mass_filter, 0], grid[~mass_filter, 1], + gradient_simulated[~mass_filter, 0], gradient_simulated[~mass_filter, 1], + scale=scale_for_simulated, zorder=20000) + ax_.set_title("Perturb simulation") + ax_.axis("off") + + ax_ = ax[1, 2] + ax_.scatter(grid[mass_filter, 0], grid[mass_filter, 1], s=0) + ax_.scatter(grid[~mass_filter, 0], grid[~mass_filter, 1], c=inner_product[~mass_filter], + cmap="coolwarm", s=s_grid, vmin=vmin, vmax=vmax) + + ax_.quiver(grid[~mass_filter, 0], grid[~mass_filter, 1], + gradient_simulated[~mass_filter, 0], gradient_simulated[~mass_filter, 1], + scale=scale_for_simulated, zorder=20000) + ax_.axis("off") + ax_.set_title("Inner product of \n Perturb simulation * Development flow \n + Perturb simulation") + + ##### + ax_ = ax[1, 3] + pcm = ax_.scatter(value_on_grid[~mass_filter], inner_product[~mass_filter], + c=inner_product[~mass_filter], cmap="coolwarm", + vmin=vmin, vmax=vmax, s=s_grid) + + ax_.set_ylim([vmin*1.1, vmax*1.1]) + ax_.axhline(0, color="lightgray") + pp = fig.colorbar(pcm, ax=ax_, orientation="vertical") + sns.despine() + ax_.set_xlabel("pseudotime") + ax_.set_ylabel("inner product score") + + + +def visualize_developmental_analysis_ver201(self, scale_for_pseudotime=30, scale_for_simulated=30, s=10, s_grid=30, vmin=-1, vmax=1): + + whole_embedding = self.oracle.embedding + original_embedding=self.oracle_dev.embedding + original_value=self.oracle_dev.pseudotime + mass_filter=self.oracle_dev.mass_filter + grid=self.oracle_dev.flow_grid + value_on_grid=self.oracle_dev.new_pseudotime + gradient_pseudotime=self.oracle_dev.gradient + gradient_simulated=self.oracle_dev.flow + inner_product=self.oracle_dev.inner_product + + inner_product_stats = self.oracle_dev.inner_product_stats + inner_product_stats_grouped = self.oracle_dev.inner_product_stats_grouped + + + cluster = self.oracle_dev.cluster_loaded + #if cluster == "True": + # cluster = True + + cluster_column_name = self.oracle_dev.cluster_column_name_loaded + + + fig, ax = plt.subplots(2, 4, figsize=[20,10]) + + ax_ = ax[0, 0] + + if cluster == "whole": + plot_scatter_with_anndata(adata=self.oracle.adata, obsm_key=self.obsm_key, + cluster_column_name=cluster_column_name, + ax=ax_, args={"s": s}) + ax_.set_title(f"Cluster of interest: all clusters") + + else: + cluster_color = _adata_to_color_dict(self.oracle.adata, cluster_column_name)[cluster] + + ax_.scatter(whole_embedding[:, 0], whole_embedding[:, 1], c="lightgray", s=s) + ax_.scatter(original_embedding[:, 0], original_embedding[:, 1], c=cluster_color, s=s) + if cluster == True : + ax_.set_title(f"Cells of interest: \n{cluster_column_name}") + + ax_.axis("off") + + ## + ax_ = ax[0, 1] + ax_.scatter(whole_embedding[:, 0], whole_embedding[:, 1], c="lightgray", s=s) + ax_.scatter(original_embedding[:, 0], original_embedding[:, 1], c=original_value, cmap="rainbow", s=s) + ax_.set_title("Pseudotime") + ax_.axis("off") + + #### + + ### + ax_ = ax[0, 2] + ax_.scatter(grid[mass_filter, 0], grid[mass_filter, 1], s=0) + ax_.quiver(grid[~mass_filter, 0], grid[~mass_filter, 1], + gradient_pseudotime[~mass_filter, 0], gradient_pseudotime[~mass_filter, 1], + scale=scale_for_pseudotime) + ax_.set_title("Gradient of pseudotime \n(=Development flow)") + ax_.axis("off") + + ##### + + ax_ = ax[0, 3] + ax_.scatter(grid[mass_filter, 0], grid[mass_filter, 1], s=0) + ax_.quiver(grid[~mass_filter, 0], grid[~mass_filter, 1], + gradient_simulated[~mass_filter, 0], gradient_simulated[~mass_filter, 1], + scale=scale_for_simulated, zorder=20000) + ax_.set_title("Perturb simulation") + ax_.axis("off") + + + #### + ax_ = ax[1, 0] + ax_.scatter(grid[mass_filter, 0], grid[mass_filter, 1], s=0) + ax_.scatter(grid[~mass_filter, 0], grid[~mass_filter, 1], c=inner_product[~mass_filter], + cmap="coolwarm", s=s_grid, vmin=vmin, vmax=vmax) + + ax_.axis("off") + ax_.set_title("Inner product of \n Perturb simulation * Development flow") + + + ax_ = ax[1, 1] + ax_.scatter(grid[mass_filter, 0], grid[mass_filter, 1], s=0) + ax_.scatter(grid[~mass_filter, 0], grid[~mass_filter, 1], c=inner_product[~mass_filter], + cmap="coolwarm", s=s_grid, vmin=vmin, vmax=vmax) + + ax_.quiver(grid[~mass_filter, 0], grid[~mass_filter, 1], + gradient_simulated[~mass_filter, 0], gradient_simulated[~mass_filter, 1], + scale=scale_for_simulated, zorder=20000) + ax_.axis("off") + ax_.set_title("Inner product of \n Perturb simulation * Development flow \n + Perturb simulation") + + + ##### + ax_ = ax[1, 2] + pcm = ax_.scatter(value_on_grid[~mass_filter], inner_product[~mass_filter], + c=inner_product[~mass_filter], cmap="coolwarm", + vmin=vmin, vmax=vmax, s=s_grid) + + ax_.set_ylim([vmin*1.1, vmax*1.1]) + ax_.axhline(0, color="lightgray") + pp = fig.colorbar(pcm, ax=ax_, orientation="vertical") + sns.despine() + ax_.set_xlabel("pseudotime") + ax_.set_ylabel("inner product score") + + ax_ = ax[1, 3] + sns.boxplot(data=inner_product_stats, x="pseudotime_id", y="score", color="white", ax=ax_) + ax_.set_xlabel("Digitized_pseudotime") + ax_.set_ylabel("inner product score") + ax_.axhline(0, color="gray") + ax_.set_ylim([vmin*1.1, vmax*1.1]) + ax_.tick_params( + labelleft=False) + + + +##### +def _get_ix_for_a_cluster(oracle, cluster_column_name, cluster): + ix = np.arange(oracle.adata.shape[0])[oracle.adata.obs[cluster_column_name] == cluster] + return ix + +def _plot_quiver_for_a_cluster(oracle, cluster_column_name, cluster, quiver_scale, ax, color=None, plot_whole_cells=True, args={}): + + if cluster == "whole": + ix_choice = ix = np.arange(oracle.adata.shape[0]) + else: + ix_choice = _get_ix_for_a_cluster(oracle, cluster_column_name, cluster) + + + if plot_whole_cells: + + ax.scatter(oracle.embedding[:, 0], oracle.embedding[:, 1], + c="lightgray", alpha=1, lw=0.3, rasterized=True, **args) + + ax.scatter(oracle.embedding[ix_choice, 0], oracle.embedding[ix_choice, 1], + c="lightgray", alpha=0.2, edgecolor=(0,0,0,1), lw=0.3, rasterized=True, **args) + + + + if color is None: + color=oracle.colorandum[ix_choice] + + quiver_kwargs=dict(headaxislength=7, headlength=11, headwidth=8, + linewidths=0.25, width=0.0045,edgecolors="k", + color=color, alpha=1) + + ax.quiver(oracle.embedding[ix_choice, 0], oracle.embedding[ix_choice, 1], + oracle.delta_embedding[ix_choice, 0], + oracle.delta_embedding[ix_choice, 1], + scale=quiver_scale, **quiver_kwargs) + + plt.axis("off") + + + + +def visualize_developmental_analysis_ver301(self, scale_for_pseudotime=30, scale_for_simulated=30, quiver_scale=30, s=10, s_grid=30, vmin=-1, vmax=1): + + whole_embedding = self.oracle.embedding + original_embedding=self.oracle_dev.embedding + original_value=self.oracle_dev.pseudotime + mass_filter=self.oracle_dev.mass_filter + grid=self.oracle_dev.flow_grid + value_on_grid=self.oracle_dev.new_pseudotime + gradient_pseudotime=self.oracle_dev.gradient + gradient_simulated=self.oracle_dev.flow + inner_product=self.oracle_dev.inner_product + + inner_product_stats = self.oracle_dev.inner_product_stats + inner_product_stats_grouped = self.oracle_dev.inner_product_stats_grouped + + alpha = 1 + + cluster = self.oracle_dev.cluster_loaded + #if cluster == "True": + # cluster = True + + cluster_column_name = self.oracle_dev.cluster_column_name_loaded + + + ''' fig, ax = plt.subplots(1, 4, figsize=[20,5]) + + ax_ = ax[0] + ax_.set_title(f"Clustering results") + plot_scatter_with_anndata(adata=self.oracle.adata, obsm_key=self.obsm_key,, + cluster_column_name=self.oracle.cluster_column_name, + ax=ax_, args={"s": s}) + ax_.axis("off")''' + + + fig, ax = plt.subplots(1, 4, figsize=[20,5]) + + ax_ = ax[0] + if cluster == "whole": + plot_scatter_with_anndata(adata=self.oracle.adata, obsm_key=self.obsm_key, + cluster_column_name=cluster_column_name, + ax=ax_, args={"s": s}) + ax_.set_title(f"Cluster of interest: all clusters") + + else: + cluster_color = _adata_to_color_dict(self.oracle.adata, cluster_column_name)[cluster] + + ax_.scatter(whole_embedding[:, 0], whole_embedding[:, 1], c="lightgray", s=s) + ax_.scatter(original_embedding[:, 0], original_embedding[:, 1], c=cluster_color, s=s) + if cluster == "True" : + ax_.set_title(f"Cells of interest: \n{cluster_column_name}") + ax_.axis("off") + + ## + ax_ = ax[1] + ax_.scatter(whole_embedding[:, 0], whole_embedding[:, 1], c="lightgray", s=s) + ax_.scatter(original_embedding[:, 0], original_embedding[:, 1], c=original_value, cmap="rainbow", s=s) + ax_.set_title("Pseudotime") + ax_.axis("off") + + #### + ### + ax_ = ax[2] + ax_.scatter(whole_embedding[:, 0], whole_embedding[:, 1], c="lightgray", s=s) + #ax_.scatter(original_embedding[:, 0], original_embedding[:, 1], c=original_value, cmap="rainbow", s=s) + ax_.quiver(grid[~mass_filter, 0], grid[~mass_filter, 1], + gradient_pseudotime[~mass_filter, 0], gradient_pseudotime[~mass_filter, 1], + scale=scale_for_pseudotime) + ax_.set_title("Gradient of pseudotime \n(=Development flow)") + ax_.axis("off") + + + ax_ = ax[3] + if cluster == "whole": + plot_scatter_with_anndata(adata=self.oracle.adata, obsm_key=self.obsm_key, + cluster_column_name=cluster_column_name, + ax=ax_, args={"s": s}) + else: + ax_.scatter(whole_embedding[:, 0], whole_embedding[:, 1], c="lightgray", s=s) + ax_.scatter(original_embedding[:, 0], original_embedding[:, 1], c=cluster_color, s=s, alpha=alpha) + #ax_.scatter(original_embedding[:, 0], original_embedding[:, 1], c=original_value, cmap="rainbow", s=s) + ax_.quiver(grid[~mass_filter, 0], grid[~mass_filter, 1], + gradient_pseudotime[~mass_filter, 0], gradient_pseudotime[~mass_filter, 1], + scale=scale_for_pseudotime) + ax_.set_title("Gradient of pseudotime \n(=Development flow)") + ax_.axis("off") + + plt.show() + + + + fig, ax = plt.subplots(1, 4, figsize=[20,5]) + + + ##### + ax_ = ax[0] + ax_.set_title(f"Perturb simulation \n color: {cluster_column_name}") + if cluster == "whole": + _plot_quiver_for_a_cluster(oracle=self.oracle, + cluster_column_name=cluster_column_name, + color=None, + cluster=cluster, quiver_scale=30, ax=ax_, args={"s": s}) + else: + _plot_quiver_for_a_cluster(oracle=self.oracle, + cluster_column_name=cluster_column_name, + color=cluster_color, + cluster=cluster, quiver_scale=30, ax=ax_, args={"s": s}) + ax_.axis("off") + + ax_ = ax[1] + ax_.set_title("Perturb simulation \n color: cluster") + _plot_quiver_for_a_cluster(oracle=self.oracle, + cluster_column_name=cluster_column_name, + cluster=cluster, quiver_scale=30, ax=ax_, args={"s": s}) + ax_.axis("off") + + ax_ = ax[2] + ax_.scatter(whole_embedding[:, 0], whole_embedding[:, 1], c="lightgray", s=s) + ax_.quiver(grid[~mass_filter, 0], grid[~mass_filter, 1], + gradient_simulated[~mass_filter, 0], gradient_simulated[~mass_filter, 1], + scale=scale_for_simulated, zorder=20000) + ax_.set_title("Perturb simulation result on grid") + ax_.axis("off") + + + ax_ = ax[3] + if cluster == "whole": + plot_scatter_with_anndata(adata=self.oracle.adata, obsm_key=self.obsm_key, + cluster_column_name=cluster_column_name, + ax=ax_, args={"s": s}) + else: + ax_.scatter(whole_embedding[:, 0], whole_embedding[:, 1], c="lightgray", s=s) + ax_.scatter(original_embedding[:, 0], original_embedding[:, 1], c=cluster_color, s=s, alpha=alpha) + ax_.quiver(grid[~mass_filter, 0], grid[~mass_filter, 1], + gradient_simulated[~mass_filter, 0], gradient_simulated[~mass_filter, 1], + scale=scale_for_simulated, zorder=20000) + ax_.set_title("Perturb simulation result on grid") + ax_.axis("off") + + plt.show() + + fig, ax = plt.subplots(1, 4, figsize=[20,5]) + + #### + ax_ = ax[0] + ax_.scatter(grid[mass_filter, 0], grid[mass_filter, 1], s=0) + ax_.scatter(grid[~mass_filter, 0], grid[~mass_filter, 1], c=inner_product[~mass_filter], + cmap="coolwarm", s=s_grid, vmin=vmin, vmax=vmax) + + ax_.axis("off") + ax_.set_title("Inner product of \n Perturb simulation * Development flow") + + + ax_ = ax[1] + ax_.scatter(grid[mass_filter, 0], grid[mass_filter, 1], s=0) + ax_.scatter(grid[~mass_filter, 0], grid[~mass_filter, 1], c=inner_product[~mass_filter], + cmap="coolwarm", s=s_grid, vmin=vmin, vmax=vmax) + + ax_.quiver(grid[~mass_filter, 0], grid[~mass_filter, 1], + gradient_simulated[~mass_filter, 0], gradient_simulated[~mass_filter, 1], + scale=scale_for_simulated, zorder=20000) + ax_.axis("off") + ax_.set_title("Inner product of \n Perturb simulation * Development flow \n + Perturb simulation") + + + ##### + ax_ = ax[2] + pcm = ax_.scatter(value_on_grid[~mass_filter], inner_product[~mass_filter], + c=inner_product[~mass_filter], cmap="coolwarm", + vmin=vmin, vmax=vmax, s=s_grid) + + ax_.set_ylim([vmin*1.1, vmax*1.1]) + ax_.axhline(0, color="lightgray") + pp = fig.colorbar(pcm, ax=ax_, orientation="vertical") + sns.despine() + ax_.set_xlabel("pseudotime") + ax_.set_ylabel("inner product score") + + ax_ = ax[3] + sns.boxplot(data=inner_product_stats, x="pseudotime_id", y="score", color="white", ax=ax_) + ax_.set_xlabel("Digitized_pseudotime") + ax_.set_ylabel("inner product score") + ax_.axhline(0, color="gray") + ax_.set_ylim([vmin*1.1, vmax*1.1]) + ax_.tick_params( + labelleft=False) + plt.show() + + + +def visualize_developmental_analysis_ver401(self, scale_for_pseudotime=30, scale_for_simulated=30, quiver_scale=30, s=10, s_grid=30, vmin=-1, vmax=1): + + whole_embedding = self.oracle.embedding + original_embedding=self.oracle_dev.embedding + original_value=self.oracle_dev.pseudotime + mass_filter=self.oracle_dev.mass_filter + grid=self.oracle_dev.flow_grid + value_on_grid=self.oracle_dev.new_pseudotime + gradient_pseudotime=self.oracle_dev.gradient + gradient_simulated=self.oracle_dev.flow + inner_product=self.oracle_dev.inner_product + + inner_product_stats = self.oracle_dev.inner_product_stats + inner_product_stats_grouped = self.oracle_dev.inner_product_stats_grouped + + alpha = 1 + + cluster = self.oracle_dev.cluster_loaded + #if cluster == "True": + # cluster = True + + cluster_column_name = self.oracle_dev.cluster_column_name_loaded + + + ''' fig, ax = plt.subplots(1, 4, figsize=[20,5]) + + ax_ = ax[0] + ax_.set_title(f"Clustering results") + plot_scatter_with_anndata(adata=self.oracle.adata, obsm_key=self.obsm_key,, + cluster_column_name=self.oracle.cluster_column_name, + ax=ax_, args={"s": s}) + ax_.axis("off")''' + + + fig, ax = plt.subplots(1, 4, figsize=[20,5]) + + ax_ = ax[0] + if cluster == "whole": + plot_scatter_with_anndata(adata=self.oracle.adata, obsm_key=self.obsm_key, + cluster_column_name=cluster_column_name, + ax=ax_, args={"s": s}) + ax_.set_title(f"Cluster of interest: all clusters") + + else: + cluster_color = _adata_to_color_dict(self.oracle.adata, cluster_column_name)[cluster] + + ax_.scatter(whole_embedding[:, 0], whole_embedding[:, 1], c="lightgray", s=s) + ax_.scatter(original_embedding[:, 0], original_embedding[:, 1], c=cluster_color, s=s) + if cluster == "True" : + ax_.set_title(f"Cells of interest: \n{cluster_column_name}") + ax_.axis("off") + + ## + ax_ = ax[1] + ax_.scatter(whole_embedding[:, 0], whole_embedding[:, 1], c="lightgray", s=s) + ax_.scatter(original_embedding[:, 0], original_embedding[:, 1], c=original_value, cmap="rainbow", s=s) + ax_.set_title("Pseudotime") + ax_.axis("off") + + #### + ### + ax_ = ax[2] + ax_.scatter(whole_embedding[:, 0], whole_embedding[:, 1], c="lightgray", s=s) + #ax_.scatter(original_embedding[:, 0], original_embedding[:, 1], c=original_value, cmap="rainbow", s=s) + ax_.quiver(grid[~mass_filter, 0], grid[~mass_filter, 1], + gradient_pseudotime[~mass_filter, 0], gradient_pseudotime[~mass_filter, 1], + scale=scale_for_pseudotime) + ax_.set_title("Gradient of pseudotime \n(=Development flow)") + ax_.axis("off") + + + ax_ = ax[3] + if cluster == "whole": + plot_scatter_with_anndata(adata=self.oracle.adata, obsm_key=self.obsm_key, + cluster_column_name=cluster_column_name, + ax=ax_, args={"s": s}) + else: + ax_.scatter(whole_embedding[:, 0], whole_embedding[:, 1], c="lightgray", s=s) + ax_.scatter(original_embedding[:, 0], original_embedding[:, 1], c=cluster_color, s=s, alpha=alpha) + #ax_.scatter(original_embedding[:, 0], original_embedding[:, 1], c=original_value, cmap="rainbow", s=s) + ax_.quiver(grid[~mass_filter, 0], grid[~mass_filter, 1], + gradient_pseudotime[~mass_filter, 0], gradient_pseudotime[~mass_filter, 1], + scale=scale_for_pseudotime) + ax_.set_title("Gradient of pseudotime \n(=Development flow)") + ax_.axis("off") + + plt.show() + + + + fig, ax = plt.subplots(1, 4, figsize=[20,5]) + + + ##### + ax_ = ax[0] + ax_.set_title(f"Perturb simulation \n color: {cluster_column_name}") + if cluster == "whole": + _plot_quiver_for_a_cluster(oracle=self.oracle, + cluster_column_name=cluster_column_name, + color=None, + cluster=cluster, quiver_scale=30, ax=ax_, args={"s": s}) + else: + _plot_quiver_for_a_cluster(oracle=self.oracle, + cluster_column_name=cluster_column_name, + color=cluster_color, + cluster=cluster, quiver_scale=30, ax=ax_, args={"s": s}) + ax_.axis("off") + + ax_ = ax[1] + ax_.set_title("Perturb simulation \n color: cluster") + _plot_quiver_for_a_cluster(oracle=self.oracle, + cluster_column_name=cluster_column_name, + cluster=cluster, quiver_scale=30, ax=ax_, args={"s": s}) + ax_.axis("off") + + ax_ = ax[2] + ax_.scatter(whole_embedding[:, 0], whole_embedding[:, 1], c="lightgray", s=s) + ax_.quiver(grid[~mass_filter, 0], grid[~mass_filter, 1], + gradient_simulated[~mass_filter, 0], gradient_simulated[~mass_filter, 1], + scale=scale_for_simulated, zorder=20000) + ax_.set_title("Perturb simulation result on grid") + ax_.axis("off") + + + ax_ = ax[3] + if cluster == "whole": + plot_scatter_with_anndata(adata=self.oracle.adata, obsm_key=self.obsm_key, + cluster_column_name=cluster_column_name, + ax=ax_, args={"s": s}) + else: + ax_.scatter(whole_embedding[:, 0], whole_embedding[:, 1], c="lightgray", s=s) + ax_.scatter(original_embedding[:, 0], original_embedding[:, 1], c=cluster_color, s=s, alpha=alpha) + ax_.quiver(grid[~mass_filter, 0], grid[~mass_filter, 1], + gradient_simulated[~mass_filter, 0], gradient_simulated[~mass_filter, 1], + scale=scale_for_simulated, zorder=20000) + ax_.set_title("Perturb simulation result on grid") + ax_.axis("off") + + plt.show() + + fig, ax = plt.subplots(1, 4, figsize=[20,5]) + + #### + ax_ = ax[0] + ax_.scatter(grid[mass_filter, 0], grid[mass_filter, 1], s=0) + ax_.scatter(grid[~mass_filter, 0], grid[~mass_filter, 1], c=inner_product[~mass_filter], + cmap="coolwarm", s=s_grid, vmin=vmin, vmax=vmax) + + ax_.axis("off") + ax_.set_title("Inner product of \n Perturb simulation * Development flow") + + + ax_ = ax[1] + ax_.scatter(grid[mass_filter, 0], grid[mass_filter, 1], s=0) + ax_.scatter(grid[~mass_filter, 0], grid[~mass_filter, 1], c=inner_product[~mass_filter], + cmap="coolwarm", s=s_grid, vmin=vmin, vmax=vmax) + + ax_.quiver(grid[~mass_filter, 0], grid[~mass_filter, 1], + gradient_simulated[~mass_filter, 0], gradient_simulated[~mass_filter, 1], + scale=scale_for_simulated, zorder=20000) + ax_.axis("off") + ax_.set_title("Inner product of \n Perturb simulation * Development flow \n + Perturb simulation") + + + ##### + ax_ = ax[2] + pcm = ax_.scatter(value_on_grid[~mass_filter], inner_product[~mass_filter], + c=inner_product[~mass_filter], cmap="coolwarm", + vmin=vmin, vmax=vmax, s=s_grid) + + ax_.set_ylim([vmin*1.1, vmax*1.1]) + ax_.axhline(0, color="lightgray") + #pp = fig.colorbar(pcm, ax=ax_, orientation="vertical") + sns.despine() + ax_.set_xlabel("pseudotime") + ax_.set_ylabel("inner product score") + + ax_ = ax[3] + sns.boxplot(data=inner_product_stats, x="pseudotime_id", y="score", color="white", ax=ax_) + ax_.set_xlabel("Digitized_pseudotime") + ax_.set_ylabel("inner product score") + ax_.axhline(0, color="gray") + ax_.set_ylim([vmin*1.1, vmax*1.1]) + ax_.tick_params( + labelleft=False) + plt.show() + + + + fig, ax = plt.subplots(1, 4, figsize=[20,5]) + + stage_grid = self.oracle_dev.stage_grid + colors_dict = _adata_to_color_dict(adata=self.oracle.adata, cluster_use="Stage") + stage_order = return_order(stage_list=self.oracle_dev.inner_product_stats.stage) + + + #### + ax_ = ax[0] + ax_.scatter(grid[mass_filter, 0], grid[mass_filter, 1], s=0) + plot_grid_with_categprocal_color(grid=grid, mass_filter=mass_filter, color_array=stage_grid, + colors_dict=colors_dict, + ax=ax_, args={}) + + ax_.axis("off") + ax_.set_title("Stage (hpf) on grid") + + + ax_ = ax[1] + plot_legend(labels=stage_order, palette=colors_dict, ax_=ax_) + + ##### + ax_ = ax[2] + sns.violinplot(data=self.oracle_dev.inner_product_stats, + x="pseudotime", y="stage", + palette=colors_dict, order=stage_order[::-1], + ax=ax_) + sns.despine() + #ax_.set_xlabel("pseudotime") + ax_.set_ylabel("stage (hpf)") + + ax_ = ax[3] + plot_stackedvar(pd.crosstab(self.oracle_dev.inner_product_stats['stage'], + self.oracle_dev.inner_product_stats['pseudotime_id']).loc[stage_order], + ax=ax_, palette=colors_dict) + ax_.set_xlabel("Digitized_pseudotime") + ax_.set_ylabel("Grid point count") + sns.despine() + plt.show() + +def plot_grid_with_categprocal_color(grid, mass_filter, color_array, colors_dict, ax, args={}): + x = grid[~mass_filter, 0] + y = grid[~mass_filter, 1] + color_array_filtered = color_array[~mass_filter] + + for cluster, color in colors_dict.items(): + idx = np.where(color_array_filtered == cluster)[0] + ax.scatter(x[idx], y[idx], c=color, label=cluster, **args) + +def return_order(stage_list): + stage_unique = np.unique(stage_list) + hpf_unique = [float(i[:4]) for i in stage_unique] + stage_order = list(stage_unique[np.argsort(hpf_unique)]) + return stage_order + +def plot_legend(labels, palette, ax_): + + for i, label in enumerate(labels): + ax_.scatter([0], [i], s=100, c=palette[label]) + ax_.text(1, i-len(labels)*0.015, s=label) + ax_.set_ylim([-1, len(labels)]) + ax_.set_xlim([-1, 10]) + ax_.axis("off") + +def plot_stackedvar(df, ax, palette=None): + + bottom_feats=[] + if palette is None: + for i, j in enumerate(df.index.values): + if i==0: + ax.bar(df.columns.values, df.loc[j].values, edgecolor='white', label=j) + else: + ax.bar(df.columns.values, df.loc[j].values, label=j, + bottom=df.loc[bottom_feats].sum(axis=0).values, + edgecolor='white') + bottom_feats.append(j) + else: + for i, j in enumerate(df.index.values): + if i==0: + ax.bar(df.columns.values, df.loc[j].values, + edgecolor='white', color=palette[j], label=j) + else: + ax.bar(df.columns.values, df.loc[j].values, label=j, color=palette[j], + bottom=df.loc[bottom_feats].sum(axis=0).values, + edgecolor='white') + bottom_feats.append(j) + #plt.legend() + ax.set_xticks(df.columns) diff --git a/celloracle/trajectory/modified_VelocytoLoom_class.py b/celloracle/trajectory/modified_VelocytoLoom_class.py index e8f1cf9..e930c1a 100644 --- a/celloracle/trajectory/modified_VelocytoLoom_class.py +++ b/celloracle/trajectory/modified_VelocytoLoom_class.py @@ -431,7 +431,7 @@ def calculate_embedding_shift(self, sigma_corr: float=0.05) -> None: def calculate_grid_arrows(self, smooth: float=0.5, steps: Tuple=(40, 40), - n_neighbors: int=100, n_jobs: int=4) -> None: + n_neighbors: int=100, n_jobs: int=4, xylim: Tuple=((None, None), (None, None))) -> None: """Calculate the velocity using a points on a regular grid and a gaussian kernel Note: the function should work also for n-dimensional grid @@ -452,6 +452,8 @@ def calculate_grid_arrows(self, smooth: float=0.5, steps: Tuple=(40, 40), Higher value correspond to slower execution time n_jobs: number of processes for parallel computing + xymin: + ((xmin, xmax), (ymin, ymax)) Returns ------- @@ -478,6 +480,12 @@ def calculate_grid_arrows(self, smooth: float=0.5, steps: Tuple=(40, 40), grs = [] for dim_i in range(embedding.shape[1]): m, M = np.min(embedding[:, dim_i]), np.max(embedding[:, dim_i]) + + if xylim[dim_i][0] is not None: + m = xylim[dim_i][0] + if xylim[dim_i][1] is not None: + M = xylim[dim_i][1] + m = m - 0.025 * np.abs(M - m) M = M + 0.025 * np.abs(M - m) gr = np.linspace(m, M, steps[dim_i]) @@ -718,7 +726,7 @@ def plot_grid_arrows(self, quiver_scale: Union[str, float]="auto", scale_type: s XY = XY[~(mass_filter | (self.flow_norm_magnitude_rndm < min_magnitude)), :] else: UV_rndm[mass_filter | (self.flow_norm_magnitude_rndm < min_magnitude), :] = 0 - + if plot_random: plt.subplot(122) plt.title("Randomized") diff --git a/celloracle/trajectory/oracle_core.py b/celloracle/trajectory/oracle_core.py index b6b7d29..2bce3d5 100644 --- a/celloracle/trajectory/oracle_core.py +++ b/celloracle/trajectory/oracle_core.py @@ -283,7 +283,7 @@ def import_anndata_as_raw_count(self, adata, cluster_column_name=None, embedding self.adata.var["isin_top1000_var_mean_genes"] = self.adata.var.symbol.isin(self.high_var_genes) - def import_anndata_as_normalized_count(self, adata, cluster_column_name=None, embedding_name=None): + def import_anndata_as_normalized_count(self, adata, cluster_column_name=None, embedding_name=None, test_mode=False): """ Load scRNA-seq data. scRNA-seq data should be prepared as an anndata object. Preprocessing (cell and gene filtering, dimensional reduction, clustering, etc.) should be done before loading data. @@ -320,24 +320,26 @@ def import_anndata_as_normalized_count(self, adata, cluster_column_name=None, em self.adata.layers["normalized_count"] = self.adata.X.copy() # update color information - col_dict = _get_clustercolor_from_anndata(adata=self.adata, - cluster_name=self.cluster_column_name, - return_as="dict") - self.colorandum = np.array([col_dict[i] for i in self.adata.obs[self.cluster_column_name]]) - - # variable gene detection for the QC of simulation - """N = adata.shape[1] - if N >= 3000: - N = 3000 - n = int(N/3)-1 - """ - n = 1000 - self.score_cv_vs_mean(n, plot=False, max_expr_avg=35) - self.high_var_genes = self.cv_mean_selected_genes.copy() - self.cv_mean_selected_genes = None - - self.adata.var["symbol"] = self.adata.var.index.values - self.adata.var["isin_top1000_var_mean_genes"] = self.adata.var.symbol.isin(self.high_var_genes) + if not test_mode: + + col_dict = _get_clustercolor_from_anndata(adata=self.adata, + cluster_name=self.cluster_column_name, + return_as="dict") + self.colorandum = np.array([col_dict[i] for i in self.adata.obs[self.cluster_column_name]]) + + # variable gene detection for the QC of simulation + """N = adata.shape[1] + if N >= 3000: + N = 3000 + n = int(N/3)-1 + """ + n = 1000 + self.score_cv_vs_mean(n, plot=False, max_expr_avg=35) + self.high_var_genes = self.cv_mean_selected_genes.copy() + self.cv_mean_selected_genes = None + + self.adata.var["symbol"] = self.adata.var.index.values + self.adata.var["isin_top1000_var_mean_genes"] = self.adata.var.symbol.isin(self.high_var_genes)