Skip to content

[DNM] prelim fixes on ns2b3 misalignment #43

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 19 additions & 4 deletions choppa/IO/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ def phylo_json_to_df(json_file, gene=None):

`gene` can be specified to only export a specific gene into the dataframe.
"""
fitness_df = pd.DataFrame(json.load(open(json_file))["data"])
fitness_df = pd.DataFrame(
json.load(open(json_file))["ZIKV NS2B-NS3 (Open)"]["mut_metric_df"]
)

if gene:
print(f"Available genes: {set(fitness_df['gene'].values)}")
Expand All @@ -37,9 +39,22 @@ def phylo_json_to_df(json_file, gene=None):
return fitness_df


def nextstrain_to_csv(nextstrain_tsv):
""" """
def ns2b3_reset_residcs(df):
"""ns2b3 has the same indices between the two chains. Super annoying, resetting that here."""
new_idcs_col = []
for idx in df["reference_site"].values:
if "(NS2B) " in idx:
new_idcs_col.append(idx.replace("(NS2B) ", ""))
elif "(NS3) " in idx:
new_idcs_col.append(130 + int(idx.replace("(NS3) ", "")))
df["reference_site"] = new_idcs_col

return df


if __name__ == "__main__":
phylo_json_to_df(TOY_PHYLO_DATA, "N").to_csv("sars2_fitness.csv", index=False)
fitness_df = phylo_json_to_df(sys.argv[1])

fitness_df = ns2b3_reset_residcs(fitness_df)

fitness_df.to_csv(sys.argv[2], index=False)
37 changes: 30 additions & 7 deletions choppa/IO/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,15 @@ def check_validity(self, fitness_df):
self.fitness_colname: "fitness",
}
)

for resi, res_data in fitness_df.groupby(by="residue_index"):
if len(res_data) > 25:
raise ValueError(
f"Found residue indices in input fitness CSV (at index {resi}) with more mutants than expected ({len(res_data)})! Does your fitness data have "
f"multiple chains in it with overlapping residue indices? Please resolve the input data so that there is no overlap in "
f"residue indices between chains."
)

self.fitness_df = fitness_df
return True

Expand Down Expand Up @@ -218,21 +227,35 @@ def extract_ligands(system):
"""[Placeholder] Returns a system's ligands"""
return system

def check_validity(self, complex):
def reset_complex_sequence(self, complex):
"""Adjusts protein sequence indexing to run from 1 to n, rather than whatever wonky indexing
the crystallographer may have come up with. BioPython can do this but there are some
protections built in against hard re-indexing multi-chain proteins. By first setting
the indexing super high (starting at 100,000) and then re-indexing starting from 1 we can
circumvent these protections. For more details see https://github.com/biopython/biopython/pull/4623
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can install from git+ source if you want this functionality in the aforementioned biopython PR. Only issue is we would then be blocked on release upstream until biopython merges and does a release.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps its just better to use this hack.

"""
[Placeholder] Does some quick checks to make sure the imported PDB structure is valid. We're
not doing any kind of protein prep, just whether biopython _is able to_ read the PDB
file and we try to figure out what entry names the solvent/ligands have (if there are any)
"""
return complex
# first set indexing to an unphysically high number
original_index = []
residue_N = 100000
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would make this number max(residue_numbers + 1) to avoid some kind of cutoff artefact in the extremely unlikely event someone passes in a megaprotein, probably unnecessary but good practice.

for residue in complex.get_residues():
original_index.append(residue.id[1])
residue.id = (residue.id[0], residue_N, residue.id[2])
residue_N += 1

# now renumber residue in complex starting from 1
residue_N = 1
for residue in complex.get_residues():
residue.id = (residue.id[0], residue_N, residue.id[2])
residue_N += 1
return complex, original_index

def load_pdb(self):
"""
Loads an input PDB file
"""
complex = PDBParser(QUIET=False).get_structure("COMPLEX", self.path_to_pdb_file)

self.check_validity(complex)
self.reset_complex_sequence(complex)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
self.reset_complex_sequence(complex)
complex, _ = self.reset_complex_sequence(complex)

return complex

def load_pdb_rdkit(self):
Expand Down
121 changes: 94 additions & 27 deletions choppa/align.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,29 +65,56 @@ def get_fitness_alignment_shift_dict(self, alignment):
"""
Given an input complex sequence with residue indices (may not start at 0) and the fitness-complex
alignment, creates a dictionary with indices that should be used for the fitness data of the form
{fitness_idx : aligned_idx}
{fitness_idx : aligned_idx}.

This is complicated for multiple reasons, so we're iterating over each individual index. For example
in the following alignment:
CSV 50 DMYIERAGDITWEKDAEVTGNSPRLDVALDESGDFSLVEEDGPPMREIILKVVLMAICGM
0 |||||||||||||||||||||||||||||||||||||||---------------------
PDB 0 DMYIERAGDITWEKDAEVTGNSPRLDVALDESGDFSLVE---------------------

CSV 110 NPIAIPFAAGAWYVYVKTGKRSGALWDVPAPKEVKKGETTDGVYRVMTRRLLGSTQVGVG
60 ------------------------------------||||||||||||||||||||||||
PDB 39 ------------------------------------GETTDGVYRVMTRRLLGSTQVGVG

we need to 1) keep track of the starting indices of fitness ('CSV') and crystal structure ('PDB') (50 and 0, resp.)
and 2) we need to be able to skip the gap in alignment.
"""

alignment_shift_dict = {}
for fitness_res, fitness_resid, pdb_res, pdb_resid in zip(
alignment[0],
self.fitness_get_seqidcs(),
alignment[1],
self.complex_get_seqidcs(),
):
# do some checks before adding to the alignment dict. we do these checks at multiple layers to be 100% sure we're not mismatching the two sequences.
if fitness_res == "-" and fitness_res != pdb_res:
# the fitness data does not contain this residue in the PDB and alignment has created a gap -> good
alignment_shift_dict[fitness_resid] = pdb_resid
elif fitness_res == pdb_res:
# the fitness data does contain this residue in the PDB and alignment has matched it -> good
alignment_shift_dict[fitness_resid] = pdb_resid
to_remove = []

# first we grab the starting indices for fitness ('CSV') and complex ('PDB') by slicing
# the alignment object view. Hacky but robust, no method implemented in BioPython for this.
start_idx_fitness = int(alignment.format().splitlines()[0].split()[1])
start_idx_complex = int(alignment.format().splitlines()[2].split()[1])
# print(f"Start fitness: {start_idx_fitness}, start complex: {start_idx_complex}") # DEBUG

# now we will loop over the alignment. We need both the fitness and complex residues
# and the original indices.

# this might break if the PDB is longer than the fitness data?
for fitness_res, complex_res in zip(alignment[0], alignment[1]):
# print(start_idx_fitness, fitness_res, complex_res, start_idx_complex) # DEBUG
if fitness_res == complex_res:
# good match. Can add this fitness data to the dict. Can bump both.
alignment_shift_dict[start_idx_fitness] = start_idx_complex
start_idx_fitness += 1
start_idx_complex += 1
else:
raise ValueError(
f"Unable to match fitness residue {fitness_res} ({fitness_resid}) to PDB residue {pdb_res} ({pdb_resid})"
)

return alignment_shift_dict
# bad match.
to_remove.append(start_idx_complex)
if complex_res == "-":
# there is a gap in the fitness data.
# skip over this fitness datapoint only
start_idx_fitness += 1
else:
# the alignment matched a fitness residue to the wrong residue type, can happen in e.g. point mutations
# in this case we also need to skip over the protein residue
start_idx_complex += 1
start_idx_fitness += 1

return alignment_shift_dict, to_remove

def fitness_reset_keys(self, alignment):
"""
Expand All @@ -98,16 +125,17 @@ def fitness_reset_keys(self, alignment):
represented as 'empty' dict entries. This way the fitness HTML view will have 'empty' fitness data
for those residues.
"""
alignment_dict = self.get_fitness_alignment_shift_dict(alignment)
alignment_dict, to_remove = self.get_fitness_alignment_shift_dict(alignment)
reset_dict = {}

for _, fitness_data in self.fitness_input.items():
# we build a new dict where keys are the aligned index, then the aligned/unaligned indices (provenance),
# then wildtype data and then per-mutant fitness data

if not fitness_data["fitness_csv_index"] in alignment_dict.keys():
logger.warn(
f"Fitness data found to have a residue (index {fitness_data['fitness_csv_index']}) not in the PDB - skipping."
)
# logger.warn( # disabled for now, can spam a lot
# f"Fitness data found to have a residue (index {fitness_data['fitness_csv_index']}) not in the PDB - skipping."
# )
continue

reset_dict[alignment_dict[fitness_data["fitness_csv_index"]]] = {
Expand All @@ -117,6 +145,19 @@ def fitness_reset_keys(self, alignment):
**fitness_data,
}

"""
are we just setting the wrong indexing somewhere?
looks like we might be taking the fitness residue
instead of complex?
"""
print(to_remove)
print(reset_dict.keys())
for resi_to_remove in set(to_remove):
# remove complex indices that we have no fitness data for. We'll fill these with empty fitness data
# later on.
if resi_to_remove in reset_dict:
reset_dict.pop(resi_to_remove)

return reset_dict

def fill_aligned_fitness(self, aligned_fitness_dict):
Expand All @@ -125,14 +166,29 @@ def fill_aligned_fitness(self, aligned_fitness_dict):
for easier parsing during visualization.
"""
filled_aligned_fitness_dict = {}

print(aligned_fitness_dict.keys())
for complex_idx, complex_res in zip(
self.complex_get_seqidcs(), self.complex_get_seq()
):
if complex_res == "X":
continue # this is a ligand, we can skip because we don't show fitness for this anyway
continue # this is a ligand or water, we can skip because we don't show fitness for this anyway

if complex_idx in aligned_fitness_dict:
if complex_idx not in aligned_fitness_dict:
# no fitness data for this residue in the complex, need to make an empty one
filled_aligned_fitness_dict[complex_idx] = {"wildtype": {complex_res}}
elif complex_idx in aligned_fitness_dict:
print(
f"{complex_idx}:{self.complex_get_seq()[complex_idx]} should be {aligned_fitness_dict[complex_idx]['wildtype']['aa']}:{aligned_fitness_dict[complex_idx]['fitness_csv_index']}"
) # DEBUG
# check that the fitness wildtype equals the protein PDB residue type
if (
not self.complex_get_seq()[complex_idx]
== aligned_fitness_dict[complex_idx]["wildtype"]["aa"]
):
# hard stop - this is a critical alignment mismatch
raise ValueError(
f"Alignment mismatch between wildtype and PDB!\n\nFitness: {aligned_fitness_dict[complex_idx]}\n\nProtein: {complex_idx}{complex_res}"
)
# this fitness-complex data matches, can just copy the data across
filled_aligned_fitness_dict[complex_idx] = aligned_fitness_dict[
complex_idx
Expand All @@ -141,6 +197,16 @@ def fill_aligned_fitness(self, aligned_fitness_dict):
# no fitness data for this residue in the complex, need to make an empty one
filled_aligned_fitness_dict[complex_idx] = {"wildtype": {complex_res}}

for i, j in filled_aligned_fitness_dict.items():
print()
print(i, j)
break
# so alignment indices are correct now, but for some reason the wrong logoplots are showing up?
# do the wildtypes in the filled_aligned_fitness_dict correspond? why
# are the surface colors wrong??

# for some reason the indices in filled_aligned_fitness_dict are fucked, fitness_csv_index should start at 50

return filled_aligned_fitness_dict, len(filled_aligned_fitness_dict) - len(
aligned_fitness_dict
)
Expand All @@ -152,8 +218,9 @@ def get_alignment(self, fitness_seq, complex_seq):
"""

aligner = Align.PairwiseAligner()
aligner.mode = "local"
aligner.open_gap_score = (
-10
-20
) # set these to make gaps happen less. With fitness data we know there shouldn't
aligner.extend_gap_score = -0.5 # really be any gaps.
aligner.substitution_matrix = substitution_matrices.load(
Expand Down
8 changes: 7 additions & 1 deletion choppa/logoplots.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ def render_logoplot(
global_max_confidence=False,
lhs=True,
wildtype=False,
index=False,
):
"""
Creates a logoplot as a base64 string. Also annotes with confidence values if present.
Expand Down Expand Up @@ -210,6 +211,8 @@ def render_logoplot(
plt.yticks([])
# plt.savefig("debug_logoplot.png", dpi=70, bbox_inches="tight") # uncomment for testing
# plt object directly to base64 string instead of tmpfile
if index:
plt.annotate(str(index), xy=(0.8, 0.05), xycoords="axes fraction", size=20)
lp_bytes = io.BytesIO()
plt.savefig(
lp_bytes,
Expand All @@ -223,7 +226,9 @@ def render_logoplot(

return lp_base64

def build_logoplot(self, global_min_confidence=False, global_max_confidence=False):
def build_logoplot(
self, global_min_confidence=False, global_max_confidence=False, index=False
):
# determine the wildtype, unfit and fit mutants for this input
wildtype, unfit_mutants, fit_mutants = self.divide_fitness_types()
# generate the logoplot base64 for wildtype (LHS, top), fit (LHS, bottom) and unfit (RHS; with colorbar)
Expand All @@ -232,6 +237,7 @@ def build_logoplot(self, global_min_confidence=False, global_max_confidence=Fals
global_min_confidence=global_min_confidence,
global_max_confidence=global_max_confidence,
wildtype=True,
index=index,
)
fit_base64 = self.render_logoplot(
fit_mutants,
Expand Down
2 changes: 1 addition & 1 deletion choppa/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,6 @@ def get_confidence_limits(self):
if "mutants" in res: # skips over PDB residues that don't have fitness data
mut_conf_values = [mut["confidence"] for mut in res["mutants"]]
wildtype_conf_value = res["wildtype"]["confidence"]

for conf_val in mut_conf_values + [wildtype_conf_value]:
confidence_values.append(conf_val)
if math.isnan(confidence_values[0]):
Expand Down Expand Up @@ -413,6 +412,7 @@ def _make_logoplot_residue(self, idx, residue_fitness_dict, confidence_lims):
).build_logoplot(
global_min_confidence=confidence_lims[0],
global_max_confidence=confidence_lims[1],
index=idx,
)

return (
Expand Down
Loading