Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add plot of MSA coverage #24

Merged
merged 7 commits into from
Sep 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 18 additions & 3 deletions chai_lab/chai1.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
from chai_lab.ranking.frames import get_frames_and_mask
from chai_lab.ranking.rank import SampleRanking, get_scores, rank
from chai_lab.utils.paths import chai1_component
from chai_lab.utils.plot import plot_msa
from chai_lab.utils.tensor_utils import move_data_to_device, set_seed, und_self
from chai_lab.utils.typing import Float, typecheck

Expand Down Expand Up @@ -270,7 +271,7 @@ def run_inference(
constraint_context=constraint_context,
)

output_pdb_paths, _, _ = run_folding_on_context(
output_pdb_paths, _, _, _ = run_folding_on_context(
feature_context,
output_dir=output_dir,
num_trunk_recycles=num_trunk_recycles,
Expand Down Expand Up @@ -308,10 +309,16 @@ def run_folding_on_context(
num_diffn_timesteps: int = 200,
seed: int | None = None,
device: torch.device | None = None,
) -> tuple[list[Path], ConfidenceScores, list[SampleRanking]]:
) -> tuple[list[Path], ConfidenceScores, list[SampleRanking], Path]:
"""
Function for in-depth explorations.
User completely controls folding inputs.

Returns:
- list of Path corresponding to folding outputs
- ConfidenceScores object
- SampleRanking data
- Path to plot of MSA coverage
"""
# Set seed
if seed is not None:
Expand Down Expand Up @@ -609,6 +616,14 @@ def avg_1d(x):
## Write the outputs
##

# Write a MSA plot
output_dir.mkdir(parents=True, exist_ok=True)
msa_plot_path = plot_msa(
input_tokens=feature_context.structure_context.token_residue_type,
msa_tokens=feature_context.msa_context.tokens,
out_fname=output_dir / "msa_depth.pdf",
)

output_paths: list[Path] = []
ranking_data: list[SampleRanking] = []

Expand Down Expand Up @@ -676,4 +691,4 @@ def avg_1d(x):
**scores,
)

return output_paths, confidence_scores, ranking_data
return output_paths, confidence_scores, ranking_data, msa_plot_path
68 changes: 68 additions & 0 deletions chai_lab/utils/plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
""""""

import logging
from pathlib import Path

import torch
from einops import reduce
from matplotlib import pyplot as plt
from torch import Tensor

from chai_lab.data import residue_constants as rc
from chai_lab.utils.typing import Int, UInt8, typecheck


@typecheck
def plot_msa(
input_tokens: Int[Tensor, "n_tokens"],
msa_tokens: UInt8[Tensor, "msa_depth n_tokens"],
out_fname: Path,
gap: str = "-",
mask: str = ":",
sort_by_identity: bool = True,
) -> Path:
gap_idx = rc.residue_types_with_nucleotides.index(gap)
mask_idx = rc.residue_types_with_nucleotides.index(mask)

# Trim padding tokens (= pad in all alignments)
token_is_pad = torch.all(msa_tokens == mask_idx, dim=0)
msa_tokens = msa_tokens[:, ~token_is_pad]
input_tokens = input_tokens[~token_is_pad]

# Calculate sequence identity for each MSA sequence
msa_seq_ident = (msa_tokens == input_tokens).float().mean(dim=-1)
sort_idx = (
torch.argsort(msa_seq_ident, descending=True)
if sort_by_identity
else torch.arange(msa_tokens.shape[0])
)

# Valid tokens are not padding and not a gap; we plot the valid tokens
msa_tokens_is_valid = (msa_tokens != gap_idx) & (msa_tokens != mask_idx)
msa_coverage = reduce(msa_tokens_is_valid.float(), "m t -> t", "mean")

# Scale each of the MSA entries by its sequence identity for plotting
msa_by_identity = msa_tokens_is_valid.float() * msa_seq_ident.unsqueeze(-1)
msa_by_identity[~msa_tokens_is_valid] = torch.nan

# Plotting
fig, ax = plt.subplots(dpi=150)
patch = ax.imshow(
msa_by_identity[sort_idx],
cmap="rainbow_r",
vmin=0,
vmax=1,
interpolation="nearest",
)
ax.set_aspect("auto")
ax.set(ylabel="Sequences", xlabel="Positions")

ax2 = ax.twinx()
ax2.plot(msa_coverage, color="black")
ax2.set(ylim=[0, 1], yticks=[])

fig.colorbar(patch)
fig.savefig(out_fname, bbox_inches="tight")
logging.info(f"Saved MSA plot to {out_fname}")
plt.close(fig)
return out_fname
2 changes: 1 addition & 1 deletion requirements.in
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ typer~=0.12 # CLI generator
# notebooks, plotting
ipykernel~=6.27 # needed by vs code to run notebooks in devcontainer
# seaborn
# matplotlib
matplotlib

# misc
tqdm~=4.66
Expand Down
Loading