Skip to content

Commit

Permalink
fix numpy warning
Browse files Browse the repository at this point in the history
  • Loading branch information
zouter committed May 31, 2024
1 parent d12354c commit b154652
Show file tree
Hide file tree
Showing 7 changed files with 22 additions and 13 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
<img src="https://raw.githubusercontent.com/DeplanckeLab/ChromatinHD/main/docs/source/static/logo.png" width="300" />
</a>
<a href="https://chromatinhd.eu">
<img src="https://raw.githubusercontent.com/DeplanckeLab/ChromatinHD/main/docs/source/static/anim.gif" />
<img src="https://raw.githubusercontent.com/DeplanckeLab/ChromatinHD/main/docs/source/static/comparison.gif" />
</a>
</p>

Expand Down
Binary file removed docs/source/quickstart/anim.gif
Binary file not shown.
2 changes: 1 addition & 1 deletion src/chromatinhd/data/motifscan/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def get_hocomoco(path, organism="hs", variant="CORE", overwrite=False):
motifs = pd.DataFrame(motifs).set_index("name")
motifs.index.name = "motif"

for thresh in motifs["standard_thresholds"][0].keys():
for thresh in motifs["standard_thresholds"].iloc[0].keys():
motifs["cutoff_" + thresh] = [thresholds[thresh] for _, thresholds in motifs["standard_thresholds"].items()]
for species in ["HUMAN", "MOUSE"]:
motifs[species + "_gene_symbol"] = [
Expand Down
12 changes: 8 additions & 4 deletions src/chromatinhd/data/motifscan/motifscan.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ def from_pwms(

# do the actual counting by looping over the batches, extract the sequences and scanning
progress = tqdm.tqdm(region_coordinates.groupby("batch"))
cur_region_index = 0
for batch, region_coordinates_batch in progress:
# extract onehot
if fasta is None:
Expand Down Expand Up @@ -245,9 +246,9 @@ def from_pwms(
strands,
) = scan(onehot, pwm2, cutoff=cutoff)

coordinates = positions.astype(np.int32) % onehot.shape[1]
coordinates = positions.astype(np.int32) % onehot.shape[-1]

region_indices = positions // onehot.shape[1]
region_indices = positions // onehot.shape[-1] + cur_region_index

coordinates = (
coordinates
Expand Down Expand Up @@ -282,6 +283,9 @@ def from_pwms(
self.coordinates.extend(coordinates)
self.region_indices.extend(region_indices)

# update current region index
cur_region_index += len(region_coordinates_batch)

return self

def create_region_indptr(self, overwrite=False):
Expand Down Expand Up @@ -569,14 +573,14 @@ def scan(onehot, pwm, cutoff=0.0):
found_positive = positive >= cutoff
scores_positive = positive[found_positive]
positions_positive = torch.stack(torch.where(found_positive)).to(torch.int64)
positions_positive = positions_positive[0] * (onehot.shape[1]) + positions_positive[1]
positions_positive = positions_positive[0] * (onehot.shape[-1]) + positions_positive[1]

negative = torch.nn.functional.conv1d(onehot_comp, pwm_rev.unsqueeze(0))[:, 0]

found_negative = negative >= cutoff
scores_negative = negative[found_negative]
positions_negative = torch.stack(torch.where(found_negative)).to(torch.int64)
positions_negative = (positions_negative[0]) * (onehot.shape[1]) + positions_negative[1]
positions_negative = (positions_negative[0]) * (onehot.shape[-1]) + positions_negative[1]

return (
torch.cat([scores_positive, scores_negative]).cpu().numpy(),
Expand Down
12 changes: 8 additions & 4 deletions src/chromatinhd/data/motifscan/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,8 @@ def __init__(
ax.set_yticks([])

ax.axis("off")
ax.axhspan(0, 1, color=group_motifs["color"].iloc[0], zorder=0, alpha=0.1, transform=ax.transAxes, lw=0)
color = group_motifs["color"].iloc[0] if group_motifs.shape[0] == 1 else "black"
ax.axhspan(0, 1, color=color, zorder=0, alpha=0.1, transform=ax.transAxes, lw=0)

if label_motifs:
if label_motifs_side == "right":
Expand All @@ -129,7 +130,7 @@ def __init__(
xycoords="axes fraction",
xytext=xytext,
textcoords="offset points",
color=group_motifs["color"][0],
color=group_motifs["color"].iloc[0] if group_motifs.shape[0] == 1 else "black",
va="center",
fontsize=9,
ha=ha,
Expand Down Expand Up @@ -228,9 +229,11 @@ def __init__(self, motifscan, gene, motifs_oi, breaking, group_info=None, panel_
panel, ax = broken[0, -1]
_setup_group(ax, group_info_oi, group_motifs)

color = group_motifs["color"].iloc[0] if group_motifs.shape[0] == 1 else "black"

for panel, ax in broken:
ax.axis("off")
ax.axhspan(0, 1, color=group_motifs["color"].iloc[0], zorder=0, alpha=0.1, transform=ax.transAxes, lw=0)
ax.axhspan(0, 1, color=color, zorder=0, alpha=0.1, transform=ax.transAxes, lw=0)

# plot the motifs
for motif in group_motifs.itertuples():
Expand Down Expand Up @@ -260,7 +263,8 @@ def _setup_group(ax, group_info_oi, group_motifs):
ax.set_yticks([])

if "label" in group_info_oi.keys():
ax.text(s=group_info_oi["label"], color=group_motifs["color"].tolist()[0], x=1.0, y=0.0, transform=ax.transAxes)
color = group_motifs["color"].tolist()[0] if group_motifs.shape[0] == 1 else "black"
ax.text(s=group_info_oi["label"], color=color, x=1.0, y=0.0, transform=ax.transAxes)
else:
rainbow_text(
ax=ax,
Expand Down
2 changes: 1 addition & 1 deletion src/chromatinhd/models/diff/interpret/regionpositional.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ def get_plotdata(self, region: str, clusters=None, relative_to=None) -> (pd.Data
plotdata["prob"] = plotdata["prob"]

if relative_to is not None:
plotdata_mean = plotdata[["prob"]].query("cluster in @relative_to").groupby("coord").mean()
plotdata_mean = plotdata[["prob"]].query("cluster in @relative_to").groupby("coord", observed=False).mean()
else:
plotdata_mean = plotdata[["prob"]].groupby("coord", observed=True).mean()

Expand Down
5 changes: 3 additions & 2 deletions src/chromatinhd/models/diff/plot/differential.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,8 +471,9 @@ def _process_plotdata(plotdata, plotdata_mean, cluster_info, order, relative_to,
plotdata.index.get_level_values("coord").min(),
plotdata.index.get_level_values("coord").max(),
)

if window is not None:
else:
if isinstance(window, pd.Series):
window = window.values.tolist()
plotdata = plotdata.loc[
(plotdata.index.get_level_values("coord") >= window[0])
& (plotdata.index.get_level_values("coord") <= window[1])
Expand Down

0 comments on commit b154652

Please sign in to comment.