diff --git a/src/chromatinhd/biomart/dataset.py b/src/chromatinhd/biomart/dataset.py index 2312549..2d7d9d8 100755 --- a/src/chromatinhd/biomart/dataset.py +++ b/src/chromatinhd/biomart/dataset.py @@ -178,7 +178,7 @@ def get(self, attributes=[], filters=[], use_cache=True, timeout = 20) -> pd.Dat try: response = requests.get(url, timeout=timeout) except requests.exceptions.Timeout: - raise ValueError("Ensembl web service timed out") + raise ValueError("Ensembl web service timed out: ", url) # check response status if response.status_code != 200: raise ValueError(f"Response status code is {response.status_code} and not 200. Response text: {response.text}") diff --git a/src/chromatinhd/data/folds/folds.py b/src/chromatinhd/data/folds/folds.py index 028aa6c..c654a25 100755 --- a/src/chromatinhd/data/folds/folds.py +++ b/src/chromatinhd/data/folds/folds.py @@ -11,7 +11,7 @@ class Folds(Flow): """ - Folds of multiple cell and reion combinations + Folds of multiple cell and region combinations """ folds: dict = Stored() diff --git a/src/chromatinhd/data/motifscan/download.py b/src/chromatinhd/data/motifscan/download.py index 68ff29f..fd13f49 100755 --- a/src/chromatinhd/data/motifscan/download.py +++ b/src/chromatinhd/data/motifscan/download.py @@ -81,7 +81,7 @@ def get_hocomoco(path, organism="hs", variant="CORE", overwrite=False): path.mkdir(parents=True, exist_ok=True) # download cutoffs, pwms and annotations - if overwrite or (not (path / "pwm_cutoffs.txt").exists()): + if overwrite or (not (path / "pwms.tar.gz").exists()): urllib.request.urlretrieve( f"https://hocomoco12.autosome.org/final_bundle/hocomoco12/H12{variant}/H12{variant}_annotation.jsonl", path / "annotation.jsonl", diff --git a/src/chromatinhd/data/motifscan/plot.py b/src/chromatinhd/data/motifscan/plot.py index 7254e60..dcc89b7 100755 --- a/src/chromatinhd/data/motifscan/plot.py +++ b/src/chromatinhd/data/motifscan/plot.py @@ -115,7 +115,11 @@ def __init__( super().__init__() - motifs_oi, group_info, motifdata = _process_grouped_motifs( + motifs_oi, group_info = _process_grouped_motifs_oi( + motifs_oi, motifscan, group_info=group_info + ) + + motifdata = _process_grouped_motifs( gene, motifs_oi, motifscan, group_info=group_info, window=window ) @@ -206,26 +210,9 @@ def blend_with_white(color, alpha): return blended_color - -def _process_grouped_motifs( - gene, motifs_oi, motifscan, window=None, group_info=None, slices_oi=None +def _process_grouped_motifs_oi( + motifs_oi, motifscan, group_info=None ): - region_ix = motifscan.regions.coordinates.index.get_loc(gene) - - # get motif data - if window is not None: - positions, indices = motifscan.get_slice( - region_ix=region_ix, - start=window[0], - end=window[1], - return_scores=False, - return_strands=False, - ) - else: - positions, indices = motifscan.get_slice( - region_ix=region_ix, return_scores=False, return_strands=False - ) - # check motifs oi if not isinstance(motifs_oi, pd.DataFrame): raise ValueError("motifs_oi should be a dataframe") @@ -235,10 +222,6 @@ def _process_grouped_motifs( ) elif "group" not in motifs_oi.columns: motifs_oi["group"] = motifs_oi.index - # raise ValueError("motifs_oi should have a 'group' column") - # ensure all rows are unique - elif not motifs_oi.index.is_unique: - raise ValueError("motifs_oi should have unique rows", motifs_oi) # get group info if group_info is None: @@ -256,17 +239,51 @@ def _process_grouped_motifs( for x in mpl.cm.tab20(np.arange(group_info.shape[0]) % 20) ] motifs_oi["color"] = 
motifs_oi["group"].map(group_info["color"].to_dict()) + return motifs_oi, group_info + +def _process_grouped_motifs( + gene, motifs_oi, motifscan, window=None, group_info=None, slices_oi=None, return_strands = False, prune = 20, +): + region_ix = motifscan.regions.coordinates.index.get_loc(gene) + + if prune is not False: + return_strands = True + + # get motif data + if window is not None: + positions, indices, scores, strands = motifscan.get_slice( + region_ix=region_ix, + start=window[0], + end=window[1], + return_scores=True, + return_strands=True, + ) + else: + positions, indices, scores, strands = motifscan.get_slice( + region_ix=region_ix, return_scores=True, return_strands=True, + ) motifscan.motifs["ix"] = np.arange(motifscan.motifs.shape[0]) motifdata = [] - for motif in motifs_oi.index: - group = motifs_oi.loc[motif, "group"] + for motif, group in zip(motifs_oi.index, motifs_oi["group"]): motif_ix = motifscan.motifs.index.get_loc(motif) positions_oi = positions[indices == motif_ix] - motifdata.extend( - [{"position": pos, "motif": motif, "group": group} for pos in positions_oi] - ) - motifdata = pd.DataFrame(motifdata, columns=["position", "motif", "group"]) + if return_strands: + strands_oi = strands[indices == motif_ix] + scores_oi = scores[indices == motif_ix] + motifdata.extend( + [ + {"position": pos, "motif": motif, "group": group, "strand": strand, "score":score} + for pos, strand, score in zip(positions_oi, strands_oi, scores_oi) + ] + ) + else: + motifdata.extend( + [{"position": pos, "motif": motif, "group": group,} for pos in positions_oi] + ) + + motifdata = pd.DataFrame(motifdata, columns=["position", "motif", "group", "strand", "score"]) if return_strands else pd.DataFrame(motifdata, columns=["position", "motif", "group"]) + motifdata = motifdata.sort_values("position", ascending = True) # check slices oi if slices_oi is not None: @@ -290,7 +307,22 @@ def _process_grouped_motifs( ) ).any(axis=0) - return motifs_oi, group_info, motifdata + # prune + if (prune is not False) and len(motifdata): + if "score" not in motifdata.columns: + raise ValueError("motifdata should have a 'score' column to prune, set return_strands = True") + + # go over each group and delete motifs that are too close + motifdata_pruned = [] + for _, group in motifdata.groupby("group"): + group["distance"] = group["position"].diff().fillna(0) + + group["section"] = (group["distance"] > prune).cumsum() + group = group.sort_values("score", ascending = False).groupby("section", as_index = False).first() + motifdata_pruned.append(group) + motifdata = pd.concat(motifdata_pruned) + + return motifdata def intersect_positions_slices(positions, slices_start, slices_end): @@ -317,17 +349,45 @@ def __init__( group_info=None, panel_height=0.1, slices_oi=None, + show_triangle:bool=True, + show_bar:bool=True, ): """ Plot the location of motifs in a region. + + Parameters + ---------- + motifscan : MotifScan + MotifScan object + gene : str + Gene name + motifs_oi : pd.DataFrame + Dataframe with motifs to plot. Should have a 'group' column. + breaking : Breaking + Breaking object + group_info : pd.DataFrame + Dataframe with group information. Should have a 'label' column. + panel_height : float + Height in inches of each motif group line + slices_oi : pd.DataFrame + Dataframe with slices information. Should have 'start', 'end', 'cluster', and 'region' columns. 
+ show_triangle : bool + Show triangle markers + show_bar : bool + Show bar markers + """ super().__init__() - motifs_oi, group_info, motifdata = _process_grouped_motifs( - gene, motifs_oi, motifscan, group_info=group_info, slices_oi=slices_oi + motifs_oi, group_info = _process_grouped_motifs_oi( + motifs_oi, motifscan, group_info=group_info ) + motifdatas = [_process_grouped_motifs( + gene, motifs_oi, motifscan, group_info=group_info, window=[region["start"], region["end"]] + ) for _, region in breaking.regions.iterrows()] + for group, group_info_oi in group_info.iterrows(): broken = self.add_under( Broken( @@ -356,29 +416,27 @@ def __init__( # create a very narrow triangle marker = mpl.path.Path( [ - [-0.4, 0.0], - [0.4, 0.], - [0.0, -1], - [-0.4, 0.0], + [-0.5, 1.0], + [0.5, 1.], + [0.0, 0], + [-0.5, 1.], ] ) marker = mpl.path.Path( [ - [-0.4, -1], - [0.4, -1], - [0.0, 0.05], - [-0.4, -1], + [-0.5, 0], + [0.5, 0], + [0.0, 1.], + [-0.5, 0], ] ) # plot the motifs - for (region, region_info), (panel, ax) in zip( - breaking.regions.iterrows(), broken + for (region, region_info), (panel, ax), motifdata in zip( + breaking.regions.iterrows(), broken, motifdatas ): motifdata_region = motifdata.loc[ - (motifdata["position"] >= region_info["start"]) - & (motifdata["position"] <= region_info["end"]) - & (motifdata["motif"].isin(group_motifs.index)) + (motifdata["motif"].isin(group_motifs.index)) ] # remove duplicates from the same group @@ -386,33 +444,84 @@ def __init__( keep="first", subset=["position"] ) - for motif in group_motifs.itertuples(): - # add motifs - motif_id = motif.Index + plotdata = motifdata_region - plotdata = motifdata_region.loc[motifdata["motif"] == motif_id].copy() - - if len(plotdata) > 0: - if "oi" in plotdata.columns: - plotdata_oi = plotdata.loc[plotdata["oi"]] - plotdata_not_oi = plotdata.loc[~plotdata["oi"]] + if len(plotdata) > 0: + if "oi" not in plotdata.columns: + plotdata["oi"] = True + if "oi" in plotdata.columns: + plotdata_oi = plotdata.loc[plotdata["oi"]] + plotdata_not_oi = plotdata.loc[~plotdata["oi"]] + + # join very close + plotdata_oi = plotdata_oi.sort_values("position") + plotdata_oi["distance_to_next"] = plotdata_oi["position"].diff().fillna(0.) + plotdata_oi["group"] = np.cumsum(plotdata_oi["distance_to_next"] > 50) + plotdata_oi["n"] = 1 + plotdata_oi = plotdata_oi.groupby("group").agg( + {"position": "mean", "n": "sum", "oi": "first"} + ) + + # oi + if show_triangle: ax.scatter( plotdata_oi["position"], - [1.] * len(plotdata_oi), + [0.] * len(plotdata_oi), transform=mpl.transforms.blended_transform_factory( ax.transData, ax.transAxes ), # marker="v", marker = marker, color=color, - s=250, + s=100, zorder=20, - lw=0.5, - edgecolor="white", + lw = 0. + ) + # background white + ax.scatter( + plotdata_oi["position"], + [1.] * len(plotdata_oi), + transform=mpl.transforms.blended_transform_factory( + ax.transData, ax.transAxes + ), + # marker="v", + marker = "|", + color="white", + s=400, + zorder=19, + lw = 1. ) ax.scatter( plotdata_not_oi["position"], - [1.] * len(plotdata_not_oi), + [0.] * len(plotdata_not_oi), + transform=mpl.transforms.blended_transform_factory( + ax.transData, ax.transAxes + ), + # marker="v", + marker = marker, + color=blend_with_white(color, 0.3), + s=100, + zorder=18, + lw = 0. + ) + if show_bar: + ax.scatter( + plotdata_oi["position"], + [1.] 
* len(plotdata_oi), + transform=mpl.transforms.blended_transform_factory( + ax.transData, ax.transAxes + ), + # marker="v", + marker = "|", + color=color, + s=200, + zorder=22, + lw = 1.5 + ) + + ax.scatter( + plotdata_oi["position"], + [1.] * len(plotdata_oi), transform=mpl.transforms.blended_transform_factory( ax.transData, ax.transAxes ), @@ -424,23 +533,22 @@ def __init__( lw=0.0, edgecolor="white", ) - else: ax.scatter( - plotdata["position"], - [1.] * len(plotdata), + plotdata_not_oi["position"], + [1.] * len(plotdata_not_oi), transform=mpl.transforms.blended_transform_factory( ax.transData, ax.transAxes ), # marker="v", - marker = marker, - color=motif.color, - s=250, - zorder=20, - lw=0.5, - edgecolor="white", + marker = "|", + color=blend_with_white(color, 0.3), + s=200, + zorder=18, + lw = 1. ) + def _setup_group(ax, group_info_oi, group_motifs): ax.set_xticks([]) ax.set_yticks([]) diff --git a/src/chromatinhd/data/motifscan/plot_genome.py b/src/chromatinhd/data/motifscan/plot_genome.py new file mode 100644 index 0000000..ee0a25a --- /dev/null +++ b/src/chromatinhd/data/motifscan/plot_genome.py @@ -0,0 +1,331 @@ +from matplotlib.path import Path +import matplotlib as mpl +import numpy as np +import copy +import pandas as pd + +from polyptich.grid import Grid, Panel, Broken + +path_a = Path(np.array([[0.65971097, 0.78427118], + [0.33386837, 0.78427118], + [0.27528096, 1. ], + [0. , 1. ], + [0.34028903, 0. ], + [0.65971097, 0. ], + [1. , 1. ], + [0.71829837, 1. ], + [0.65971097, 0.78427118], + [0.65971097, 0.78427118], + [0.37560198, 0.61255405], + [0.61637226, 0.61255405], + [0.49598698, 0.17027416], + [0.37560198, 0.61255405], + [0.37560198, 0.61255405]]), np.array([ 1, 2, 2, 2, 2, 2, 2, 2, 2, 79, 1, 2, 2, 2, 79], + dtype=np.uint8)) + + +path_t = Path(np.array([[0.6588729 , 0.18109665], + [0.6688729 , 1. ], + [0.34746375, 1. ], + [0.34746375, 0.18109665], + [0. , 0.18109665], + [0. , 0. ], + [1. , 0. ], + [1. 
, 0.18109665], + [0.6688729 , 0.18109665], + [0.6688729 , 0.18109665]]), np.array([ 1, 2, 2, 2, 2, 2, 2, 2, 2, 79], dtype=np.uint8)) + + +path_c = Path(np.array([[5.83258302e-01, 0.00000000e+00], + [7.17371840e-01, 0.00000000e+00], + [8.10981077e-01, 2.74914295e-02], + [9.05490696e-01, 5.42953515e-02], + [9.83798468e-01, 1.03780069e-01], + [8.38884117e-01, 2.37113370e-01], + [7.87578689e-01, 2.04810826e-01], + [7.25472763e-01, 1.86254117e-01], + [6.63366523e-01, 1.67010620e-01], + [5.90459157e-01, 1.67010620e-01], + [5.12151071e-01, 1.67010620e-01], + [4.47344628e-01, 2.01374967e-01], + [3.82538185e-01, 2.35051807e-01], + [3.43834332e-01, 3.08591375e-01], + [3.05130480e-01, 3.81443675e-01], + [3.05130480e-01, 4.98282071e-01], + [3.05130480e-01, 6.70103325e-01], + [3.86138455e-01, 7.48453768e-01], + [4.68046812e-01, 8.26116942e-01], + [5.95859563e-01, 8.26116942e-01], + [6.89469114e-01, 8.26116942e-01], + [7.51575354e-01, 7.99312780e-01], + [8.13681594e-01, 7.72508858e-01], + [8.65886776e-01, 7.40893582e-01], + [1.00000000e+00, 8.71477813e-01], + [9.29793151e-01, 9.24398629e-01], + [8.27182923e-01, 9.62199314e-01], + [7.24572381e-01, 1.00000000e+00], + [5.79658031e-01, 1.00000000e+00], + [4.10440911e-01, 1.00000000e+00], + [2.79027889e-01, 9.42955338e-01], + [1.48514621e-01, 8.85223408e-01], + [7.38071197e-02, 7.73883394e-01], + [0.00000000e+00, 6.61855872e-01], + [0.00000000e+00, 4.98282071e-01], + [0.00000000e+00, 3.38831875e-01], + [7.65076369e-02, 2.27491861e-01], + [1.53915341e-01, 1.16151608e-01], + [2.86228431e-01, 5.84196776e-02], + [4.18541834e-01, 4.79768063e-07], + [5.83258302e-01, 4.79768063e-07], + [5.83258302e-01, 0.00000000e+00], + [5.83258302e-01, 0.00000000e+00]]), np.array([ 1, 3, 3, 3, 3, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, + 3, 3, 3, 3, 3, 3, 3, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, + 3, 3, 3, 3, 3, 3, 3, 2, 79], dtype=np.uint8)) + + +path_g = Path(np.array([[0.55545272, 1. ], + [0.28505905, 1. ], + [0.14207124, 0.87216502], + [0. , 0.74364277], + [0. , 0.49828183], + [0. , 0.33608249], + [0.07882659, 0.22542945], + [0.15765318, 0.11408938], + [0.29147501, 0.05704469], + [0.42529716, 0. ], + [0.59028319, 0. ], + [0.72960473, 0. ], + [0.82492983, 0.03230232], + [0.92025524, 0.06460464], + [0.994499 , 0.11752572], + [0.8368455 , 0.23986256], + [0.77910057, 0.20137458], + [0.72593846, 0.18281787], + [0.67277604, 0.16426043], + [0.60494856, 0.16426043], + [0.51970569, 0.16426043], + [0.45279477, 0.19862455], + [0.38680043, 0.23230165], + [0.3483037 , 0.30652852], + [0.30980697, 0.38006788], + [0.30980697, 0.49965565], + [0.30980697, 0.62405429], + [0.33638898, 0.69759389], + [0.36388596, 0.77113325], + [0.41888183, 0.80343581], + [0.47387739, 0.83505086], + [0.55820338, 0.83505086], + [0.60311639, 0.83505086], + [0.64252968, 0.82748972], + [0.68194298, 0.8192425 ], + [0.71585688, 0.80618321], + [0.71585688, 0.58831611], + [0.55637025, 0.58831611], + [0.52704014, 0.43024061], + [1. , 0.43024061], + [1. , 0.90378007], + [0.9046749 , 0.94776643], + [0.79468411, 0.97388333], + [0.68560988, 1. ], + [0.55545432, 1. ], + [0.55545272, 1. ], + [0.55545272, 1. 
]]), np.array([ 1, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2, 3, + 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, + 3, 3, 2, 2, 2, 2, 2, 3, 3, 3, 3, 2, 79], dtype=np.uint8)) + + +path_c = path_g = path_a = path_t = Path(np.array([[0, 0], [0, 1], [1, 1], [1, 0], [0, 0]]), np.array([1, 2, 2, 2, 79], dtype=np.uint8)) +polygon_c = mpl.patches.PathPatch( + path_c +) +polygon_a = mpl.patches.PathPatch( + path_a +) +polygon_t = mpl.patches.PathPatch( + path_t +) +polygon_g = mpl.patches.PathPatch( + path_g +) + +polygons = { + "A": polygon_a, + "T": polygon_t, + "G": polygon_g, + "C": polygon_c, +} +polygon_colors = { + "A": "#2ECC40", + "T": "#FF4136", + "G": "#FFDC00", + "C": "#0074D9", +} + + + +def plot_sequence(sequence, x): + patches = [] + colors = [] + import copy + for i, char in enumerate(sequence): + polygon = copy.copy(polygons[char.upper()]) + polygon.set_transform(mpl.transforms.Affine2D().translate(x + i, -1).scale(1, -1)) + + patches.append(polygon) + colors.append(polygon_colors[char.upper()]) + collection = mpl.collections.PatchCollection( + patches, + lw=0, + facecolor=colors, + ) + return collection + +def plot_motif(pwm, x, y): + patches = [] + colors = [] + for row in range(pwm.shape[0]): + pos = 0 + for col, char in zip(range(pwm.shape[1]), ["A", "C", "G", "T"]): + score = pwm[row, col] / np.sqrt(2) + if score > 0: + patch = copy.copy(polygons[char.upper()]) + patches.append(patch) + patch.set_transform(mpl.transforms.Affine2D().scale(1, -score).translate(x+row, y+pos+score)) + pos += score + colors.append(polygon_colors[char.upper()]) + + collection = mpl.collections.PatchCollection( + patches, + lw=0, + facecolor=colors, + ) + + return collection + +def plot_motifs(ax, motifdata, pwms): + # do the actual plotting of motifs + prev_max_xs = [] # keeps track of previously plotted x and y to avoid overlap + y = 0 + full_max_y = 0 # keeps track of the maximum y to know dimensions of plot + + motifdata = motifdata.sort_values("position") + + for _, row in motifdata.iterrows(): + pwm = pwms[row["motif"]].numpy() + length = pwm.shape[0] + x = row["position"] + max_x = row["position"] + length + color = row["color"] + label = row["label"] + prev_max_xs = [(prev_max_x, y) for prev_max_x, y in prev_max_xs if x < prev_max_x] + + if len(prev_max_xs) > 0: + ys = [y for _, y in prev_max_xs] + for i in range(0, -10, -1): + if i not in ys: + y = i + break + else: + y = 0 + + full_max_y = max(y, full_max_y) + + rect = mpl.patches.Rectangle( + (x, y - 1), + length, + 1, + fc = color, + alpha = 0.1, + zorder = -5, + ) + ax.add_patch(rect) + + if row["strand"] == -1: + pwm = pwm[::-1, ::-1] + + # plot motif + collection = plot_motif(pwm, x, y-1) + ax.add_collection(collection) + + # plot motif name + text = ax.text( + x, + # x + length / 2, + y-0.5, + label, + # ha="center", + ha="right", + va="center", + fontsize=6, + color = color, + # color = "white", + fontweight = "bold", + ) + text.set_path_effects( + [ + mpl.patheffects.withStroke(linewidth=1, foreground="white"), + ] + ) + + prev_max_xs.append((max_x, y)) + + return full_max_y-2 + +from .plot import _process_grouped_motifs + + + +class GroupedMotifsGenomeBroken(Broken): + def __init__( + self, + motifscan, + gene, + motifs_oi, + breaking, + genome, + # group_info, + pwms, + panel_height=0.2, + ): + """ + Plot the location of motifs in a region. 
+ + Parameters + ---------- + + """ + + super().__init__(breaking = breaking, height = 0.2) + + # plot the motifs + for (window, window_info), (panel, ax) in zip( + breaking.regions.iterrows(), self + ): + ax.set_xlim(window_info["start"], window_info["end"]) + ax.axis("off") + + sequence = genome.fetch( + window_info["chrom"], window_info["start_chrom"], window_info["end_chrom"] + ) + + collection = plot_sequence(sequence, window_info["start"]) + ax.add_collection(collection) + + motifdata = _process_grouped_motifs( + gene, + motifs_oi, + motifscan, + return_strands = True, + window = [window_info["start"], window_info["end"]], + ) + + motifdata["label"] = motifscan.motifs.loc[motifdata["motif"]]["tf"].values + motifdata["color"] = motifs_oi.reset_index().set_index(["group", "motif"])["color"].loc[pd.MultiIndex.from_frame(motifdata[["group", "motif"]])].values + + full_max_y = plot_motifs(ax, motifdata, pwms) + + ax.set_ylim(full_max_y, 1) + + ax.dim = (ax.width, (-full_max_y)*panel_height) + + diff --git a/src/chromatinhd/data/peakcounts/plot.py b/src/chromatinhd/data/peakcounts/plot.py index 4d44dfd..9bf43f2 100755 --- a/src/chromatinhd/data/peakcounts/plot.py +++ b/src/chromatinhd/data/peakcounts/plot.py @@ -1,4 +1,3 @@ -import pybedtools import pandas as pd import numpy as np @@ -24,6 +23,11 @@ def get_usecols_and_names(peakcaller): def extract_peaks(peaks_bed, promoter, peakcaller): if peaks_bed is None: return pd.DataFrame({"start": [], "end": [], "method": [], "peak": []}) + + try: + import pybedtools + except ImportError: + raise ImportError("pybedtools is required to plot peaks, install using `pip install pybedtools` or `conda install -c bioconda pybedtools`. You may also need to install bedtools, e.g. using `conda install -c bioconda bedtools`.") promoter_bed = pybedtools.BedTool.from_dataframe(pd.DataFrame(promoter).T[["chrom", "start", "end"]]) @@ -266,7 +270,10 @@ def _get_peaks(region, peakcallers): peaks = [] - import pybedtools + try: + import pybedtools + except ImportError: + raise ImportError("pybedtools is required to plot peaks, install using `pip install pybedtools` or `conda install -c bioconda pybedtools`. You may also need to install bedtools, e.g. using `conda install -c bioconda bedtools`.") for peakcaller, peakcaller_info in peakcallers.iterrows(): if not pathlib.Path(peakcaller_info["path"]).exists(): diff --git a/src/chromatinhd/data/transcriptome/transcriptome.py b/src/chromatinhd/data/transcriptome/transcriptome.py index 6a92bc3..7fc777a 100755 --- a/src/chromatinhd/data/transcriptome/transcriptome.py +++ b/src/chromatinhd/data/transcriptome/transcriptome.py @@ -166,12 +166,15 @@ def filter_cells(self, cells, path=None): adata = None return Transcriptome.create(var=self.var, obs=self.obs.loc[cells], X=X, layers=layers, path=path, adata=adata) - def get_X(self, gene_ids, layer=None): """ Get the counts for a given set of genes. """ - gene_ixs = self.var.index.get_loc(gene_ids) + + if isinstance(gene_ids, str): + gene_ixs = self.var.index.get_loc(gene_ids) + else: + gene_ixs = self.var.index.get_indexer(gene_ids) if layer is None: value = self.X[:, gene_ixs] diff --git a/src/chromatinhd/models/diff/interpret/regionpositional.py b/src/chromatinhd/models/diff/interpret/regionpositional.py index eaaf366..2677753 100755 --- a/src/chromatinhd/models/diff/interpret/regionpositional.py +++ b/src/chromatinhd/models/diff/interpret/regionpositional.py @@ -369,6 +369,8 @@ def get_plotdata(self, region: str, clusters=None, relative_to=None, scale = 1.) 
Parameters: region: the region + relative_to: + the clusters to normalize to Returns: Two dataframes, one with the probabilities per cluster, one with the mean @@ -381,7 +383,11 @@ def get_plotdata(self, region: str, clusters=None, relative_to=None, scale = 1.) plotdata = probs.to_dataframe("prob") # plotdata["prob"] = plotdata["prob"] * scale - plotdata["prob"].mean() * scale - if relative_to is not None: + if relative_to == "previous": + plotdata_mean = plotdata[["prob"]].groupby("coord", observed=True).mean() + elif relative_to is not None: + if relative_to not in plotdata.index.get_level_values("cluster"): + raise ValueError(f"Cluster {relative_to} not in clusters") plotdata_mean = plotdata[["prob"]].query("cluster in @relative_to").groupby("coord", observed=False).mean() else: plotdata_mean = plotdata[["prob"]].groupby("coord", observed=True).mean() @@ -621,7 +627,7 @@ def spread_true(arr, width=5): .reset_index(drop=True) ) - if differential_prob_cutoff is not None: + if len(windows) and (differential_prob_cutoff is not None): windows["extra_selection"] = windows.apply( lambda x: (plotdata_diff.iloc[:, plotdata_diff.columns.get_loc(x["start"]) : plotdata_diff.columns.get_loc(x["end"])] > np.log(differential_prob_cutoff)).any().any(), axis=1 ) diff --git a/src/chromatinhd/models/diff/model/binary.py b/src/chromatinhd/models/diff/model/binary.py index f75243b..32f36f0 100755 --- a/src/chromatinhd/models/diff/model/binary.py +++ b/src/chromatinhd/models/diff/model/binary.py @@ -162,7 +162,7 @@ def train_model( n_epochs=30, lr=1e-2, pbar=True, - early_stopping=True, + early_stopping=False, fold=None, fragments: Fragments = None, clustering=None, @@ -229,7 +229,6 @@ def train_model( ) trainer = Trainer( - # trainer = TrainerPerFeature( self, loaders_train, loaders_validation, @@ -466,6 +465,10 @@ def evaluate_pseudo( class Models(Flow): + """ + Multiple ChromatinHD-diff models, based on different train/test/validation splits + """ + models = LinkedDict() clustering = Linked() @@ -480,6 +483,49 @@ class Models(Flow): model_params = Stored(default=dict) train_params = Stored(default=dict) + @classmethod + def create( + cls, + fragments=None, + clustering=None, + folds=None, + model_params:dict=None, + train_params:dict=None, + path=None, + reset=False, + ): + """ + Creates a new Models object + + Parameters: + fragments: + Fragments object + clustering: + Clustering object + folds: + List of folds + model_params: + Parameters for the model. See Model.create + train_params: + Parameters for training. 
See Model.train_model + """ + self = super(Models, cls).create(path=path, reset=reset) + + self.fragments = fragments + self.clustering = clustering + self.folds = folds + + if model_params is not None: + self.model_params = model_params + else: + self.model_params = dict() + if train_params is not None: + self.train_params = train_params + else: + self.train_params = dict() + + return self + @property def models_path(self): path = self.path / "models" @@ -489,6 +535,9 @@ def models_path(self): def train_models( self, fragments=None, clustering=None, folds=None, device=None, pbar=True, regions_oi=None, **kwargs ): + """ + Create and train all models + """ if fragments is None: fragments = self.fragments if clustering is None: @@ -553,5 +602,4 @@ def get_prediction(self, fold_ix, **kwargs): @property def trained(self): - print(len(self)) return len(self) > 0 diff --git a/src/chromatinhd/models/diff/plot/differential.py b/src/chromatinhd/models/diff/plot/differential.py index a979828..7ae9ec4 100755 --- a/src/chromatinhd/models/diff/plot/differential.py +++ b/src/chromatinhd/models/diff/plot/differential.py @@ -5,9 +5,13 @@ import numpy as np import pandas as pd +cmap_atac_diff = mpl.colors.LinearSegmentedColormap.from_list( + "RdBu_r_cb", [mpl.cm.RdBu_r(x) for x in np.linspace(0.1, 0.9, 100)] +) + def get_cmap_atac_diff(): - return mpl.cm.RdBu_r + return cmap_atac_diff def get_norm_atac_diff(): @@ -25,7 +29,7 @@ def __init__( panel_height=0.5, plotdata_empirical=None, show_atac_diff=True, - cmap_atac_diff=mpl.cm.RdBu_r, + cmap_atac_diff=cmap_atac_diff, norm_atac_diff=mpl.colors.Normalize(np.log(1 / 4), np.log(4.0), clip=True), ymax=100, ylintresh=25, @@ -76,7 +80,7 @@ def __init__( self.window = window self.cluster_info = cluster_info - plotdata, plotdata_mean, order, window = _process_plotdata( + plotdata, order, window = _process_plotdata( plotdata, plotdata_mean, cluster_info, @@ -91,7 +95,7 @@ def __init__( panel, ax = polyptich.grid.Panel((width, panel_height)) self.add(panel) - _scale_differential(ax, ymax) + _scale_differential(ax, ymax, lintresh=ylintresh) _setup_differential( ax, ymax, @@ -104,10 +108,13 @@ def __init__( ax.spines["top"].set_visible(False) ax.spines["right"].set_visible(False) + ax.spines["bottom"].set_color("grey") if plotdata_empirical is not None: # empirical distribution of atac-seq cuts - plotdata_empirical_cluster = plotdata_empirical.query("cluster == @cluster") + plotdata_empirical_cluster = plotdata_empirical.query( + "cluster == @cluster" + ) ax.fill_between( plotdata_empirical_cluster["coord"], np.exp(plotdata_empirical_cluster["prob"]), @@ -128,7 +135,6 @@ def draw(self, plotdata, plotdata_mean): self.artists = [] for ax, cluster in zip(self.elements, self.order): - ax = ax.ax if self.show_atac_diff: # posterior distribution of atac-seq cuts plotdata_cluster = plotdata.xs(cluster, level="cluster") @@ -183,8 +189,12 @@ def draw(self, plotdata, plotdata_mean): self.artists.extend([gradient, polygon, background, differential]) @classmethod - def from_regionpositional(cls, region_id, regionpositional, width, relative_to=None, **kwargs): - plotdata, plotdata_mean = regionpositional.get_plotdata(region_id, relative_to=relative_to) + def from_regionpositional( + cls, region_id, regionpositional, width, relative_to=None, **kwargs + ): + plotdata, plotdata_mean = regionpositional.get_plotdata( + region_id, relative_to=relative_to + ) self = cls( plotdata=plotdata, plotdata_mean=plotdata_mean, @@ -211,28 +221,9 @@ def get_artists(self): class 
DifferentialBroken(polyptich.grid.Wrap): - def __init__( - self, - plotdata, - plotdata_mean, - cluster_info, - breaking, - window=None, - panel_height=0.5, - show_atac_diff=True, - cmap_atac_diff=mpl.cm.RdBu_r, - norm_atac_diff=mpl.colors.Normalize(np.log(1 / 4), np.log(4.0), clip=True), - ymax=100, - ylintresh=25, - order=False, - relative_to=None, - label_accessibility=True, - label_cluster=True, - **kwargs, - ): - """ - Parameters - + """ + Parameters + --- plotdata: dataframe with columns "coord", "prob", "cluster" plotdata_mean: @@ -261,8 +252,30 @@ def __init__( order of the clusters relative_to: cluster or clusters to show the differential accessibility relative to - """ + label_accessibility: + label the accessibility + """ + def __init__( + self, + plotdata, + plotdata_mean, + cluster_info, + breaking, + window=None, + panel_height=0.5, + show_atac_diff=True, + cmap_atac_diff=cmap_atac_diff, + norm_atac_diff=mpl.colors.Normalize(np.log(1 / 4), np.log(4.0), clip=True), + ymax=100, + ylintresh=25, + order=False, + relative_to=None, + label_accessibility=True, + label_cluster=True, + show_scale=True, + **kwargs, + ): super().__init__(ncol=1, **{"padding_height": 0, **kwargs}) self.show_atac_diff = show_atac_diff self.cmap_atac_diff = cmap_atac_diff @@ -276,22 +289,28 @@ def __init__( plotdata = plotdata.query("cluster in @cluster_info.index") if order is True: - self.order = plotdata.groupby(level=0).mean().sort_values(ascending=False).index + self.order = ( + plotdata.groupby(level=0).mean().sort_values(ascending=False).index + ) elif order is False: self.order = cluster_info.index else: self.order = order - plotdata, plotdata_mean, order, _ = _process_plotdata(plotdata, plotdata_mean, cluster_info, order, relative_to) + plotdata, order, _ = _process_plotdata( + plotdata, plotdata_mean, cluster_info, order, relative_to + ) self.order = order for cluster, cluster_info_oi in self.cluster_info.loc[self.order].iterrows(): broken = self.add( - polyptich.grid.Broken(breaking, height=panel_height, margin_top=0.0, padding_height=0.0) + polyptich.grid.Broken( + breaking, height=panel_height, margin_top=0.0, padding_height=0.0 + ) ) - panel, ax = broken[0, 0] + ax = broken[0, 0] _setup_differential( ax, ymax, @@ -300,43 +319,130 @@ def __init__( label_cluster=label_cluster, ) - for panel, ax in broken: - _scale_differential(ax, ymax, lintresh = ylintresh) + for ax in broken: + _scale_differential(ax, ymax, lintresh=ylintresh) ax.set_yticks([]) ax.set_yticks([], minor=True) ax.axvline(0, dashes=(1, 1), color="#AAA", zorder=-1, lw=1) + ax.spines["bottom"].set_color("#33333333") + ax.spines["left"].set_visible(False) + ax.spines["right"].set_visible(False) + if self.show_atac_diff: - self.draw(plotdata, plotdata_mean) + self.draw(plotdata) + + if relative_to == "previous": + for i in range(len(cluster_info) - 1): + ax = self[i][0, 0] + ax.annotate( + "", + xy=(0, -0.5), + xycoords="axes fraction", + xytext=(0, 0.5), + textcoords="axes fraction", + arrowprops=dict( + arrowstyle="->", + connectionstyle="angle3,angleA=60,angleB=-60", + ec="#333333", + ), + zorder=-5, + ) - def draw(self, plotdata, plotdata_mean): + # scale + if show_scale is not False: + self.add_scale() + + def add_scale(self): + # ax = self[0][0, -1] + ax = self[0][0, 0] + pad = self.breaking.resolution * 0.05 + x1 = self.breaking.regions["start"].iloc[0] + pad + x2 = self.breaking.regions["start"].iloc[0] + pad + 500 + transform = mpl.transforms.blended_transform_factory(ax.transData, ax.transAxes) + ax.plot( + [x1, x2], + [1, 
1], + transform=transform, + zorder=100, + clip_on=False, + lw=1, + color="grey", + ) + text = ax.annotate( + "500bp", + xy=(x1 + (x2 - x1) / 2, 1), + xycoords=transform, + xytext=(0, 2), + textcoords="offset points", + ha="center", + va="bottom", + zorder=100, + fontsize=6, + clip_on=False, + color="grey", + path_effects=[mpl.patheffects.withStroke(linewidth=1, foreground="white")], + ) + text.set_path_effects( + [mpl.patheffects.withStroke(linewidth=1, foreground="white")] + ) + + def draw(self, plotdata): artists = [] - for (cluster, cluster_info_oi), broken in zip(self.cluster_info.loc[self.order].iterrows(), self): - for (region, region_info), (panel, ax) in zip(self.breaking.regions.iterrows(), broken): + for (cluster, cluster_info_oi), broken in zip( + self.cluster_info.loc[self.order].iterrows(), self + ): + for (region, region_info), ax in zip( + self.breaking.regions.iterrows(), broken + ): plotdata_cluster = plotdata.xs(cluster, level="cluster") plotdata_cluster_break = plotdata_cluster.loc[ - (plotdata_cluster.index.get_level_values("coord") >= region_info["start"]) - & (plotdata_cluster.index.get_level_values("coord") <= region_info["end"]) - ] - plotdata_mean_break = plotdata_mean.loc[ - (plotdata_mean.index >= region_info["start"]) & (plotdata_mean.index <= region_info["end"]) + ( + plotdata_cluster.index.get_level_values("coord") + >= (region_info["start"] - 100) + ) + & ( + plotdata_cluster.index.get_level_values("coord") + <= (region_info["end"] + 100) + ) ] artists_cluster_region = _draw_differential( - ax, plotdata_cluster_break, plotdata_mean_break, self.cmap_atac_diff, self.norm_atac_diff + ax, + plotdata_cluster_break, + self.cmap_atac_diff, + self.norm_atac_diff, ) artists.extend(artists_cluster_region) - ax.axvline(region_info["start"], color="#AAA", zorder=-1, lw=1) - ax.axvline(region_info["end"], color="#AAA", zorder=-1, lw=1) + ax.axvspan( + region_info["start"], + region_info["end"], + color="#22222209", + zorder=0, + lw=0, + ) + # ax.axvline(region_info["start"], color="#AAA", zorder=-1, lw=1) + # ax.axvline(region_info["end"], color="#AAA", zorder=-1, lw=1) return artists @classmethod - def from_regionpositional(cls, region_id, regionpositional, breaking, cluster_info, relative_to=None, **kwargs): - plotdata, plotdata_mean = regionpositional.get_plotdata(region_id, relative_to=relative_to) + def from_regionpositional( + cls, + region_id, + regionpositional, + breaking, + cluster_info, + relative_to=None, + **kwargs, + ): + plotdata, plotdata_mean = regionpositional.get_plotdata( + region_id, relative_to=relative_to + ) self = cls( plotdata=plotdata, plotdata_mean=plotdata_mean, cluster_info=cluster_info, breaking=breaking, + relative_to=relative_to, **kwargs, ) self.region_id = region_id @@ -347,21 +453,33 @@ def add_differential_slices(self, differential_slices): slicescores = differential_slices.get_slice_scores() slicescores = slicescores.loc[slicescores["region_ix"] == self.region_ix] - for (cluster, cluster_info_oi), broken in zip(self.cluster_info.loc[self.order].iterrows(), self): - for (region, region_info), (panel, ax) in zip(self.breaking.regions.iterrows(), broken): + for (cluster, cluster_info_oi), broken in zip( + self.cluster_info.loc[self.order].iterrows(), self + ): + for (region, region_info), (panel, ax) in zip( + self.breaking.regions.iterrows(), broken + ): # find slicescores that (partially) overlap slicescores_oi = slicescores.loc[ - ~((slicescores["start"] >= region_info["end"]) & (slicescores["end"] <= region_info["start"])) + ~( + 
(slicescores["start"] >= region_info["end"]) + & (slicescores["end"] <= region_info["start"]) + ) ] # slicescores_oi = slicescores.loc[(slicescores["start"] >= region_info["start"]) & (slicescores["end"] <= region_info["end"])] for start, end in zip(slicescores_oi["start"], slicescores_oi["end"]): ax.axvspan(start, end, color="#33333333", zorder=0, lw=0) -def _draw_differential(ax, plotdata_cluster, plotdata_mean, cmap_atac_diff, norm_atac_diff): +def _draw_differential(ax, plotdata_cluster, cmap_atac_diff, norm_atac_diff): + if any(np.isnan(plotdata_cluster["prob"])): + raise ValueError("plotdata_cluster contains NaN values") + if any(np.isnan(plotdata_cluster["prob_diff"])): + raise ValueError("plotdata_cluster contains NaN values") + (background,) = ax.plot( - plotdata_mean.index, - np.exp(plotdata_mean["prob"]), + plotdata_cluster.index, + np.exp(plotdata_cluster["prob_reference"]), color="grey", lw=0.5, zorder=1, @@ -375,8 +493,8 @@ def _draw_differential(ax, plotdata_cluster, plotdata_mean, cmap_atac_diff, norm zorder=1, ) polygon = ax.fill_between( - plotdata_mean.index, - np.exp(plotdata_mean["prob"]), + plotdata_cluster.index, + np.exp(plotdata_cluster["prob_reference"]), np.exp(plotdata_cluster["prob"]), color="black", zorder=0, @@ -406,12 +524,14 @@ def _draw_differential(ax, plotdata_cluster, plotdata_mean, cmap_atac_diff, norm return gradient, polygon, background, differential -def _scale_differential(ax, ymax, lintresh = 25): +def _scale_differential(ax, ymax, lintresh=25): ax.set_yscale("symlog", linthresh=lintresh) ax.set_ylim(0, ymax) -def _setup_differential(ax, ymax, cluster_info_oi, label=False, label_cluster=True, show_tss=True): +def _setup_differential( + ax, ymax, cluster_info_oi, label=False, label_cluster=True, show_tss=True +): minor_ticks = np.array([2.5, 5, 7.5, 25, 50, 75, 250, 500]) minor_ticks = minor_ticks[minor_ticks <= ymax] @@ -425,19 +545,33 @@ def _setup_differential(ax, ymax, cluster_info_oi, label=False, label_cluster=Tr if show_tss: ax.axvline(0, dashes=(1, 1), color="#AAA", zorder=-1, lw=1) - if label_cluster: - text = ax.annotate( - text=f"{cluster_info_oi['label']}", - xy=(0, 1), - xytext=(2, -2), - textcoords="offset points", - xycoords="axes fraction", - ha="left", - va="top", - fontsize=10, - color="#333", - zorder=30, - ) + if label_cluster is not None: + if label_cluster == "front": + text = ax.annotate( + text=f"{cluster_info_oi['label']}", + xy=(0, 0.5), + xytext=(-5, 0), + textcoords="offset points", + xycoords="axes fraction", + ha="right", + va="center", + fontsize=10, + color="#333", + zorder=30, + ) + else: + text = ax.annotate( + text=f"{cluster_info_oi['label']}", + xy=(0, 1), + xytext=(2, -2), + textcoords="offset points", + xycoords="axes fraction", + ha="left", + va="top", + fontsize=10, + color="#333", + zorder=30, + ) text.set_path_effects( [ mpl.patheffects.Stroke(linewidth=2, foreground="white"), @@ -448,21 +582,35 @@ def _setup_differential(ax, ymax, cluster_info_oi, label=False, label_cluster=Tr ax.set_xticks([]) if label: - ax.set_ylabel("Accessibility\nper 100 cells\nper 100bp", rotation=0, ha="right", va="center") + ax.set_ylabel( + "Accessibility\nper 100 cells\nper 100bp", rotation=0, ha="right", va="center" + ) else: ax.set_yticklabels([]) ax.set_yticklabels([], minor=True) -def _process_plotdata(plotdata, plotdata_mean, cluster_info, order, relative_to, window=None): +def _process_plotdata( + plotdata, plotdata_mean, cluster_info, order, relative_to, window=None +): # check plotdata - plotdata = 
plotdata.reset_index().assign(coord=lambda x: x.coord.astype(int)).set_index(["cluster", "coord"]) - plotdata_mean = plotdata_mean.reset_index().assign(coord=lambda x: x.coord.astype(int)).set_index(["coord"]) + plotdata = ( + plotdata.reset_index() + .assign(coord=lambda x: x.coord.astype(int)) + .set_index(["cluster", "coord"]) + ) + plotdata_mean = ( + plotdata_mean.reset_index() + .assign(coord=lambda x: x.coord.astype(int)) + .set_index(["coord"]) + ) # determine relative to - # if "prob_diff" not in plotdata.loc[self.order].columns: if relative_to is None: plotdata["prob_diff"] = plotdata["prob"] - plotdata_mean["prob"] + plotdata["prob_reference"] = ( + plotdata_mean["prob"].loc[plotdata.index.get_level_values("coord")].values + ) elif isinstance(relative_to, (list, tuple, np.ndarray, pd.Series, pd.Index)): plotdata["prob_diff"] = ( plotdata["prob"] @@ -471,10 +619,39 @@ def _process_plotdata(plotdata, plotdata_mean, cluster_info, order, relative_to, .mean()["prob"][plotdata.index.get_level_values("coord")] .values ) - plotdata_mean = plotdata.loc[relative_to].groupby(level="coord").mean() + plotdata["prob_reference"] = ( + plotdata.loc[relative_to].groupby(level="coord").mean()["prob"] + ) + elif isinstance(relative_to, str) and relative_to == "previous": + reference = pd.Series( + { + cluster: cluster_info.index[i - 1] if i > 0 else cluster_info.index[0] + for i, cluster in enumerate(cluster_info.index) + } + ) + plotdata["prob_reference"] = plotdata.loc[ + pd.MultiIndex.from_frame( + pd.DataFrame( + { + "cluster": reference[plotdata.index.get_level_values("cluster")], + "coord": plotdata.index.get_level_values("coord"), + } + ) + ) + ]["prob"].values + plotdata["prob_diff"] = plotdata["prob"] - plotdata["prob_reference"] else: plotdata["prob_diff"] = plotdata["prob"] - plotdata.loc[relative_to]["prob"] - plotdata_mean = plotdata.loc[relative_to] + plotdata["prob_reference"] = ( + plotdata.loc[relative_to]["prob"] + .loc[plotdata.index.get_level_values("coord")] + .values + ) + + if np.isnan(plotdata["prob_diff"]).any(): + raise ValueError("plotdata contains NaN values") + if np.isnan(plotdata["prob_reference"]).any(): + raise ValueError("plotdata contains NaN values") # subset on requested clusters plotdata = plotdata.query("cluster in @cluster_info.index") @@ -499,6 +676,23 @@ def _process_plotdata(plotdata, plotdata_mean, cluster_info, order, relative_to, (plotdata.index.get_level_values("coord") >= window[0]) & (plotdata.index.get_level_values("coord") <= window[1]) ] - plotdata_mean = plotdata_mean.loc[(plotdata_mean.index >= window[0]) & (plotdata_mean.index <= window[1])] - return plotdata, plotdata_mean, order, window + return plotdata, order, window + + +def create_colorbar_horizontal(): + import matplotlib.pyplot as plt + + fig_colorbar = plt.figure(figsize=(3.0, 0.1)) + ax_colorbar = fig_colorbar.add_axes([0.05, 0.05, 0.5, 0.9]) + mappable = mpl.cm.ScalarMappable( + norm=get_norm_atac_diff(), + cmap=get_cmap_atac_diff(), + ) + colorbar = plt.colorbar( + mappable, cax=ax_colorbar, orientation="horizontal", extend="both" + ) + colorbar.set_label("Differential accessibility") + colorbar.set_ticks(np.log([0.25, 0.5, 1, 2, 4])) + colorbar.set_ticklabels(["¼", "½", "1", "2", "4"]) + return fig_colorbar diff --git a/src/chromatinhd/models/diff/plot/differential_expression.py b/src/chromatinhd/models/diff/plot/differential_expression.py index 4ab81da..35f5da4 100755 --- a/src/chromatinhd/models/diff/plot/differential_expression.py +++ 
b/src/chromatinhd/models/diff/plot/differential_expression.py @@ -2,6 +2,9 @@ import seaborn as sns import polyptich import numpy as np +import chromatinhd as chd +import pandas as pd + def get_cmap_rna(): return mpl.cm.BuGn @@ -10,6 +13,7 @@ def get_cmap_rna(): def get_cmap_rna_diff(): return mpl.cm.RdBu_r + class DifferentialExpression(polyptich.grid.Wrap): def __init__( self, @@ -22,8 +26,9 @@ def __init__( show_cluster=True, show_n_cells=True, order=False, - relative_to = None, + relative_to=None, annotate_expression=False, + gene_info = None, **kwargs, ): super().__init__(ncol=1, **{"padding_height": 0, **kwargs}) @@ -31,15 +36,22 @@ def __init__( if relative_to is None: cmap_expression = get_cmap_rna() if norm_expression is None: - norm_expression = mpl.colors.Normalize( - min(0.0, plotdata_expression_clusters.min()), plotdata_expression_clusters.max(), clip=True - ) + norm_expression = {gene:mpl.colors.Normalize( + min(0.0, plotdata_expression_clusters[gene].min()), + plotdata_expression_clusters[gene].max(), + clip=True, + ) for gene in plotdata_expression_clusters.columns} else: cmap_expression = get_cmap_rna_diff() - plotdata_expression_clusters = plotdata_expression_clusters - plotdata_expression_clusters.loc[relative_to] - norm_expression = mpl.colors.Normalize(np.log(0.25), np.log(4), clip=True) + plotdata_expression_clusters = ( + plotdata_expression_clusters + - plotdata_expression_clusters.loc[relative_to] + ) + norm_expression = {gene:mpl.colors.Normalize(np.log(0.25), np.log(4), clip=True) for gene in plotdata_expression_clusters.columns} - plotdata_expression_clusters = plotdata_expression_clusters.loc[cluster_info.index] + plotdata_expression_clusters = plotdata_expression_clusters.loc[ + cluster_info.index + ] if order is True: self.order = plotdata_expression_clusters.sort_values(ascending=False).index @@ -48,24 +60,34 @@ def __init__( else: self.order = plotdata_expression_clusters.index + if gene_info is None: + gene_info = pd.DataFrame( + {"gene": plotdata_expression_clusters.columns} + ).set_index("gene").assign(label = plotdata_expression_clusters.columns) + + gene_info["ix"] = np.arange(len(gene_info)) + for cluster_id in self.order: - panel, ax = polyptich.grid.Panel((width, panel_height)) + panel, ax = polyptich.grid.Panel((width * len(gene_info), panel_height)) self.add(panel) sns.despine(ax=ax, left=True, right=True, top=True, bottom=True) ax.set_yticks([]) ax.set_xticks([]) - circle = mpl.patches.Circle( - (0, 0), - norm_expression(plotdata_expression_clusters[cluster_id]) * 0.9 + 0.1, - fc=cmap_expression(norm_expression(plotdata_expression_clusters[cluster_id])), - lw=1, - ec="#333333", - ) - ax.add_patch(circle) - ax.set_xlim(-1.05, 1.05) - ax.set_ylim(-1.05, 1.05) + for gene, ix in zip(gene_info.index, gene_info["ix"]): + circle = mpl.patches.Circle( + (0, 0), + norm_expression[gene](plotdata_expression_clusters.loc[cluster_id, gene]) * 0.45 + 0.05, + fc=cmap_expression( + norm_expression[gene](plotdata_expression_clusters.loc[cluster_id, gene]) + ), + lw=1, + ec="#333333", + ) + ax.add_patch(circle) + ax.set_xlim(-0.5, 0.5+len(gene_info)) + ax.set_ylim(-0.5, 0.5) ax.set_aspect(1) if show_cluster: @@ -74,7 +96,7 @@ def __init__( if show_n_cells and "n_cells" in cluster_info.columns: label += f" ({cluster_info.loc[cluster_id, 'n_cells']})" ax.text( - 1.05, + 0.5, 0, label, ha="left", @@ -111,12 +133,51 @@ def __init__( @classmethod def from_transcriptome( - cls, transcriptome, clustering, gene, width=0.5, panel_height=0.5, cluster_info=None, 
layer="counts", **kwargs + cls, + transcriptome: chd.data.Transcriptome, + clustering: chd.data.Clustering, + gene: str, + width=0.5, + panel_height=0.5, + cluster_info=None, + layer="normalized", + **kwargs, ): + """ + Create a DifferentialExpression plot from a transcriptome. + + Parameters + ---------- + transcriptome + The transcriptome to plot. + clustering + The clustering object. + gene + The gene to plot. Can also be a list of genes. + + """ import pandas as pd - plotdata_expression_clusters = np.log( - (np.exp(pd.Series(transcriptome.get_X(gene, layer=layer), index=transcriptome.obs.index))-1) + if isinstance(gene, str): + gene = [gene] + gene_info = pd.DataFrame({"gene": gene}).set_index("gene").assign(label = gene) + elif isinstance(gene, pd.Series): + gene_info = gene.to_frame().assign(label = gene) + else: + gene_info = pd.DataFrame({"gene": gene}).set_index("gene").assign(label = gene) + + plotdata_expression_clusters = np.log1p( + ( + np.exp( + pd.DataFrame( + transcriptome.get_X(gene, layer=layer), + index=transcriptome.obs.index, + columns=gene, + ) + ) + - 1 + ) + # (np.exp(pd.Series(transcriptome.get_X(gene, layer=layer), index=transcriptome.obs.index))-1) .groupby(clustering.labels.values, observed=True) .mean() ) @@ -129,5 +190,6 @@ def from_transcriptome( cluster_info=cluster_info, width=width, panel_height=panel_height, + gene_info = gene_info, **kwargs, - ) + ) diff --git a/src/chromatinhd/plot/genome/genes.py b/src/chromatinhd/plot/genome/genes.py index 9ef8ece..19cb7cf 100755 --- a/src/chromatinhd/plot/genome/genes.py +++ b/src/chromatinhd/plot/genome/genes.py @@ -390,9 +390,9 @@ def __init__( ax = self[0, 0] ax.set_yticks(np.arange(len(plotdata_genes))) ax.set_yticklabels(plotdata_genes["symbol"], fontsize=6, style="italic") - for tick in ax.yaxis.get_major_ticks(): - if tick.label1.get_text() == plotdata_genes.loc[gene_id, "symbol"]: - tick.label1.set_weight("bold") + # for tick in ax.yaxis.get_major_ticks(): + # if tick.label1.get_text() == plotdata_genes.loc[gene_id, "symbol"]: + # tick.label1.set_weight("bold") ax.tick_params(axis="y", length=0, pad=2, width=0.5) @classmethod @@ -491,12 +491,12 @@ def __init__( xticks = np.array(xticks) ax.set_xticks(xticks) ax.xaxis.set_major_formatter(chromatinhd.plot.gene_ticker) + # ax.get_xticklabels()[0].set_horizontalalignment("left") + # ax.get_xticklabels()[-1].set_horizontalalignment("right") sns.despine(ax=ax, right=True, left=True, bottom=True, top=True) # top bar - ax.axhline(0.5, color = "#333", lw = 1., zorder = -10) - ax.set_xlim(*window) ax.set_yticks([]) @@ -533,6 +533,8 @@ def __init__( labelrotation=90, ) + ax.axhline(genes_max, color = "#333", lw = 1., zorder = -10) + for gene, gene_info in plotdata_genes.reset_index().set_index("transcript").iterrows(): y = gene_info["ix"] is_oi = (gene == gene_id) or ("gene" in gene_info and gene_info["gene"] == gene_id) diff --git a/src/chromatinhd/utils/numpy.py b/src/chromatinhd/utils/numpy.py index 8bbeae3..d63e965 100755 --- a/src/chromatinhd/utils/numpy.py +++ b/src/chromatinhd/utils/numpy.py @@ -29,3 +29,16 @@ def indices_to_indptr_chunked(x, n, dtype=np.int32, batch_size=10e3): cur_value = x_[-1] indptr = np.cumsum(counts, dtype=dtype) return indptr + + + +def interpolate_1d(x: np.ndarray, xp: np.ndarray, fp: np.ndarray) -> np.ndarray: + a = (fp[1:] - fp[:-1]) / (xp[1:] - xp[:-1]) + b = fp[:-1] - (a * xp[:-1]) + + indices = np.searchsorted(xp, x, side="left") - 1 + indices = np.clip(indices, 0, a.shape[0] - 1) + + slope = a[indices] + intercept = b[indices] + 
return x * slope + intercept diff --git a/src/chromatinhd/utils/torch.py b/src/chromatinhd/utils/torch.py index 9d1878a..edf5fc1 100755 --- a/src/chromatinhd/utils/torch.py +++ b/src/chromatinhd/utils/torch.py @@ -13,7 +13,6 @@ def interpolate_1d(x: torch.Tensor, xp: torch.Tensor, fp: torch.Tensor) -> torch intercept = b.index_select(a.ndim - 1, indices) return x * slope + intercept - def interpolate_0d(x: torch.Tensor, xp: torch.Tensor, fp: torch.Tensor) -> torch.Tensor: a = (fp[..., 1:] - fp[..., :-1]) / (xp[..., 1:] - xp[..., :-1]) b = fp[..., :-1] - (a.mul(xp[..., :-1]))
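A central change in `differential.py` is that `DifferentialBroken` now draws every cluster against an explicit reference track (`prob_reference`) instead of a shared mean, and accepts `relative_to="previous"` to compare each cluster with the one plotted above it. A minimal usage sketch, assuming a fitted `RegionPositional`, a polyptich `Breaking`, and a `cluster_info` table already exist; the region id is hypothetical:

```python
from chromatinhd.models.diff.plot.differential import (
    DifferentialBroken,
    create_colorbar_horizontal,
)

# `regionpositional`, `breaking` and `cluster_info` are assumed to exist already;
# the region id below is a hypothetical Ensembl gene id
panel = DifferentialBroken.from_regionpositional(
    region_id="ENSG00000170458",
    regionpositional=regionpositional,
    breaking=breaking,
    cluster_info=cluster_info,
    relative_to="previous",  # each track is filled relative to the cluster above it
)

# companion colorbar for the differential accessibility colormap
fig_colorbar = create_colorbar_horizontal()
```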
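`DifferentialExpression.from_transcriptome` now accepts a list of genes and draws one circle per gene and cluster. The sketch below assumes existing `transcriptome` and `clustering` objects; the gene ids are hypothetical:

```python
from chromatinhd.models.diff.plot.differential_expression import DifferentialExpression

panel = DifferentialExpression.from_transcriptome(
    transcriptome=transcriptome,  # chd.data.Transcriptome, assumed to exist
    clustering=clustering,        # chd.data.Clustering, assumed to exist
    gene=["ENSG00000170458", "ENSG00000174059"],  # hypothetical gene ids
    layer="normalized",           # new default layer
)
```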
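The new NumPy `interpolate_1d` in `chromatinhd/utils/numpy.py` mirrors the torch helper: it precomputes a slope and intercept per segment, so points inside the knot range match `np.interp`, while points outside are linearly extrapolated from the edge segments rather than clamped. A quick check, assuming the module is importable as `chromatinhd.utils.numpy`:

```python
import numpy as np

from chromatinhd.utils.numpy import interpolate_1d  # import path assumed from the file location

xp = np.array([0.0, 1.0, 3.0, 6.0])  # knot positions (must be increasing)
fp = np.array([0.0, 2.0, 1.0, 4.0])  # values at the knots

# inside the knot range the result agrees with np.interp
x = np.linspace(0.0, 6.0, 25)
assert np.allclose(interpolate_1d(x, xp, fp), np.interp(x, xp, fp))

# outside the range the edge segments are extrapolated instead of clamped
print(interpolate_1d(np.array([-1.0, 7.0]), xp, fp))  # [-2.  5.]
print(np.interp(np.array([-1.0, 7.0]), xp, fp))       # [0. 4.]
```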