
Commit

update binary model
zouter committed Nov 5, 2024
1 parent 1a9da42 commit b353564
Showing 14 changed files with 966 additions and 193 deletions.
2 changes: 1 addition & 1 deletion src/chromatinhd/biomart/dataset.py
@@ -178,7 +178,7 @@ def get(self, attributes=[], filters=[], use_cache=True, timeout = 20) -> pd.Dat
try:
response = requests.get(url, timeout=timeout)
except requests.exceptions.Timeout:
raise ValueError("Ensembl web service timed out")
raise ValueError("Ensembl web service timed out: ", url)
# check response status
if response.status_code != 200:
raise ValueError(f"Response status code is {response.status_code} and not 200. Response text: {response.text}")
2 changes: 1 addition & 1 deletion src/chromatinhd/data/folds/folds.py
@@ -11,7 +11,7 @@

class Folds(Flow):
"""
-    Folds of multiple cell and reion combinations
+    Folds of multiple cell and region combinations
"""

folds: dict = Stored()
2 changes: 1 addition & 1 deletion src/chromatinhd/data/motifscan/download.py
@@ -81,7 +81,7 @@ def get_hocomoco(path, organism="hs", variant="CORE", overwrite=False):
path.mkdir(parents=True, exist_ok=True)

# download cutoffs, pwms and annotations
if overwrite or (not (path / "pwm_cutoffs.txt").exists()):
if overwrite or (not (path / "pwms.tar.gz").exists()):
urllib.request.urlretrieve(
f"https://hocomoco12.autosome.org/final_bundle/hocomoco12/H12{variant}/H12{variant}_annotation.jsonl",
path / "annotation.jsonl",
246 changes: 177 additions & 69 deletions src/chromatinhd/data/motifscan/plot.py
@@ -115,7 +115,11 @@ def __init__(

super().__init__()

motifs_oi, group_info, motifdata = _process_grouped_motifs(
motifs_oi, group_info = _process_grouped_motifs_oi(
motifs_oi, motifscan, group_info=group_info
)

motifdata = _process_grouped_motifs(
gene, motifs_oi, motifscan, group_info=group_info, window=window
)

@@ -206,26 +210,9 @@ def blend_with_white(color, alpha):

return blended_color


def _process_grouped_motifs(
gene, motifs_oi, motifscan, window=None, group_info=None, slices_oi=None
def _process_grouped_motifs_oi(
motifs_oi, motifscan, group_info=None
):
region_ix = motifscan.regions.coordinates.index.get_loc(gene)

# get motif data
if window is not None:
positions, indices = motifscan.get_slice(
region_ix=region_ix,
start=window[0],
end=window[1],
return_scores=False,
return_strands=False,
)
else:
positions, indices = motifscan.get_slice(
region_ix=region_ix, return_scores=False, return_strands=False
)

# check motifs oi
if not isinstance(motifs_oi, pd.DataFrame):
raise ValueError("motifs_oi should be a dataframe")
@@ -235,10 +222,6 @@ def _process_grouped_motifs(
)
elif "group" not in motifs_oi.columns:
motifs_oi["group"] = motifs_oi.index
# raise ValueError("motifs_oi should have a 'group' column")
# ensure all rows are unique
elif not motifs_oi.index.is_unique:
raise ValueError("motifs_oi should have unique rows", motifs_oi)

# get group info
if group_info is None:
@@ -256,17 +239,51 @@ def _process_grouped_motifs(
for x in mpl.cm.tab20(np.arange(group_info.shape[0]) % 20)
]
motifs_oi["color"] = motifs_oi["group"].map(group_info["color"].to_dict())
return motifs_oi, group_info

def _process_grouped_motifs(
gene, motifs_oi, motifscan, window=None, group_info=None, slices_oi=None, return_strands = False, prune = 20,
):
region_ix = motifscan.regions.coordinates.index.get_loc(gene)

if prune is not False:
return_strands = True

# get motif data
if window is not None:
positions, indices, scores, strands = motifscan.get_slice(
region_ix=region_ix,
start=window[0],
end=window[1],
return_scores=True,
return_strands=True,
)
else:
positions, indices, scores, strands = motifscan.get_slice(
region_ix=region_ix, return_scores=True, return_strands=True,
)

motifscan.motifs["ix"] = np.arange(motifscan.motifs.shape[0])
motifdata = []
for motif in motifs_oi.index:
group = motifs_oi.loc[motif, "group"]
for motif, group in zip(motifs_oi.index, motifs_oi["group"]):
motif_ix = motifscan.motifs.index.get_loc(motif)
positions_oi = positions[indices == motif_ix]
motifdata.extend(
[{"position": pos, "motif": motif, "group": group} for pos in positions_oi]
)
motifdata = pd.DataFrame(motifdata, columns=["position", "motif", "group"])
if return_strands:
strands_oi = strands[indices == motif_ix]
scores_oi = scores[indices == motif_ix]
motifdata.extend(
[
{"position": pos, "motif": motif, "group": group, "strand": strand, "score":score}
for pos, strand, score in zip(positions_oi, strands_oi, scores_oi)
]
)
else:
motifdata.extend(
[{"position": pos, "motif": motif, "group": group,} for pos in positions_oi]
)

motifdata = pd.DataFrame(motifdata, columns=["position", "motif", "group", "strand", "score"]) if return_strands else pd.DataFrame(motifdata, columns=["position", "motif", "group"])
motifdata = motifdata.sort_values("position", ascending = True)

# check slices oi
if slices_oi is not None:
@@ -290,7 +307,22 @@ def _process_grouped_motifs(
)
).any(axis=0)

return motifs_oi, group_info, motifdata
# prune
if (prune is not False) and len(motifdata):
if "score" not in motifdata.columns:
raise ValueError("motifdata should have a 'score' column to prune, set return_strands = True")

# go over each group and delete motifs that are too close
motifdata_pruned = []
for _, group in motifdata.groupby("group"):
group["distance"] = group["position"].diff().fillna(0)

group["section"] = (group["distance"] > prune).cumsum()
group = group.sort_values("score", ascending = False).groupby("section", as_index = False).first()
motifdata_pruned.append(group)
motifdata = pd.concat(motifdata_pruned)

return motifdata


def intersect_positions_slices(positions, slices_start, slices_end):
@@ -317,17 +349,45 @@ def __init__(
group_info=None,
panel_height=0.1,
slices_oi=None,
show_triangle:bool=True,
show_bar:bool=True,
):
"""
Plot the location of motifs in a region.
Parameters
----------
motifscan : MotifScan
MotifScan object
gene : str
Gene name
motifs_oi : pd.DataFrame
Dataframe with motifs to plot. Should have a 'group' column.
breaking : Breaking
Breaking object
group_info : pd.DataFrame
Dataframe with group information. Should have a 'label' column.
panel_height : float
Height in inches of each motif group line
slices_oi : pd.DataFrame
Dataframe with slices information. Should have 'start', 'end', 'cluster', and 'region' columns.
show_triangle : bool
Show triangle markers
show_bar : bool
Show bar markers
"""

super().__init__()

motifs_oi, group_info, motifdata = _process_grouped_motifs(
gene, motifs_oi, motifscan, group_info=group_info, slices_oi=slices_oi
motifs_oi, group_info = _process_grouped_motifs_oi(
motifs_oi, motifscan, group_info=group_info
)

motifdatas = [_process_grouped_motifs(
gene, motifs_oi, motifscan, group_info=group_info, window=[region["start"], region["end"]]
) for _, region in breaking.regions.iterrows()]

for group, group_info_oi in group_info.iterrows():
broken = self.add_under(
Broken(
@@ -356,63 +416,112 @@ def __init__(
# create a very narrow triangle
marker = mpl.path.Path(
[
[-0.4, 0.0],
[0.4, 0.],
[0.0, -1],
[-0.4, 0.0],
[-0.5, 1.0],
[0.5, 1.],
[0.0, 0],
[-0.5, 1.],
]
)
marker = mpl.path.Path(
[
[-0.4, -1],
[0.4, -1],
[0.0, 0.05],
[-0.4, -1],
[-0.5, 0],
[0.5, 0],
[0.0, 1.],
[-0.5, 0],
]
)

# plot the motifs
for (region, region_info), (panel, ax) in zip(
breaking.regions.iterrows(), broken
for (region, region_info), (panel, ax), motifdata in zip(
breaking.regions.iterrows(), broken, motifdatas
):
motifdata_region = motifdata.loc[
(motifdata["position"] >= region_info["start"])
& (motifdata["position"] <= region_info["end"])
& (motifdata["motif"].isin(group_motifs.index))
(motifdata["motif"].isin(group_motifs.index))
]

# remove duplicates from the same group
motifdata_region = motifdata_region.drop_duplicates(
keep="first", subset=["position"]
)

for motif in group_motifs.itertuples():
# add motifs
motif_id = motif.Index
plotdata = motifdata_region

plotdata = motifdata_region.loc[motifdata["motif"] == motif_id].copy()

if len(plotdata) > 0:
if "oi" in plotdata.columns:
plotdata_oi = plotdata.loc[plotdata["oi"]]
plotdata_not_oi = plotdata.loc[~plotdata["oi"]]
if len(plotdata) > 0:
if "oi" not in plotdata.columns:
plotdata["oi"] = True
if "oi" in plotdata.columns:
plotdata_oi = plotdata.loc[plotdata["oi"]]
plotdata_not_oi = plotdata.loc[~plotdata["oi"]]

# join very close
plotdata_oi = plotdata_oi.sort_values("position")
plotdata_oi["distance_to_next"] = plotdata_oi["position"].diff().fillna(0.)
plotdata_oi["group"] = np.cumsum(plotdata_oi["distance_to_next"] > 50)
plotdata_oi["n"] = 1
plotdata_oi = plotdata_oi.groupby("group").agg(
{"position": "mean", "n": "sum", "oi": "first"}
)

# oi
if show_triangle:
ax.scatter(
plotdata_oi["position"],
[1.] * len(plotdata_oi),
[0.] * len(plotdata_oi),
transform=mpl.transforms.blended_transform_factory(
ax.transData, ax.transAxes
),
# marker="v",
marker = marker,
color=color,
s=250,
s=100,
zorder=20,
lw=0.5,
edgecolor="white",
lw = 0.
)
# background white
ax.scatter(
plotdata_oi["position"],
[1.] * len(plotdata_oi),
transform=mpl.transforms.blended_transform_factory(
ax.transData, ax.transAxes
),
# marker="v",
marker = "|",
color="white",
s=400,
zorder=19,
lw = 1.
)
ax.scatter(
plotdata_not_oi["position"],
[1.] * len(plotdata_not_oi),
[0.] * len(plotdata_not_oi),
transform=mpl.transforms.blended_transform_factory(
ax.transData, ax.transAxes
),
# marker="v",
marker = marker,
color=blend_with_white(color, 0.3),
s=100,
zorder=18,
lw = 0.
)
if show_bar:
ax.scatter(
plotdata_oi["position"],
[1.] * len(plotdata_oi),
transform=mpl.transforms.blended_transform_factory(
ax.transData, ax.transAxes
),
# marker="v",
marker = "|",
color=color,
s=200,
zorder=22,
lw = 1.5
)

ax.scatter(
plotdata_oi["position"],
[1.] * len(plotdata_oi),
transform=mpl.transforms.blended_transform_factory(
ax.transData, ax.transAxes
),
@@ -424,23 +533,22 @@ def __init__(
lw=0.0,
edgecolor="white",
)
else:
ax.scatter(
plotdata["position"],
[1.] * len(plotdata),
plotdata_not_oi["position"],
[1.] * len(plotdata_not_oi),
transform=mpl.transforms.blended_transform_factory(
ax.transData, ax.transAxes
),
# marker="v",
marker = marker,
color=motif.color,
s=250,
zorder=20,
lw=0.5,
edgecolor="white",
marker = "|",
color=blend_with_white(color, 0.3),
s=200,
zorder=18,
lw = 1.
)



def _setup_group(ax, group_info_oi, group_motifs):
ax.set_xticks([])
ax.set_yticks([])
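The main algorithmic addition in plot.py is the pruning step in _process_grouped_motifs: hits are sorted by position, hits of the same group are chained into one section as long as the gap to the previous hit stays within prune base pairs (20 by default), and only the best-scoring hit of each section is kept. The sketch below reproduces that step on made-up toy data; the positions, motif names, group label and scores are illustrative and not taken from the commit.

# Illustrative, self-contained sketch of the pruning logic; the toy values are invented.
import pandas as pd

prune = 20  # maximum gap (bp) for hits to be merged into one section, matching the default in the diff

motifdata = pd.DataFrame(
    {
        "position": [100, 105, 112, 300, 305, 500],
        "motif": ["MOTIF_A", "MOTIF_A", "MOTIF_B", "MOTIF_A", "MOTIF_B", "MOTIF_B"],
        "group": ["TF1"] * 6,
        "score": [5.0, 8.0, 6.5, 4.0, 7.0, 3.0],
    }
).sort_values("position")

pruned = []
for _, group in motifdata.groupby("group"):
    group = group.copy()
    # distance to the previous hit within the same group
    group["distance"] = group["position"].diff().fillna(0)
    # a new section starts whenever that distance exceeds `prune`
    group["section"] = (group["distance"] > prune).cumsum()
    # keep only the best-scoring hit of every section
    group = group.sort_values("score", ascending=False).groupby("section", as_index=False).first()
    pruned.append(group)

print(pd.concat(pruned).sort_values("position"))
# keeps the hit at 105 (best of 100/105/112), the hit at 305 (best of 300/305) and the hit at 500

The same gap-then-cumsum idea is reused in the plotting loop above, where highlighted positions closer than 50 bp are merged before the triangle and bar markers are drawn.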
