Skip to content

Commit

Permalink
added standard alphafold database coloring to notebook displays
Browse files Browse the repository at this point in the history
  • Loading branch information
gchojnowski committed Mar 1, 2024
1 parent 0be420d commit a9cba49
Showing 1 changed file with 25 additions and 3 deletions.
28 changes: 25 additions & 3 deletions alphapulldown/analysis_pipeline/af2_3dmol.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,15 @@
import os, sys, re
import glob
import iotbx
import cctbx
import iotbx.pdb
import re
import py3Dmol
from scitbx import matrix
from scitbx.math import superpose
import numpy as np

PLDDT_BANDS = [(50, '#FF7D45'), (70, '#FFDB13'), (90, '#65CBF3'), (100, '#0053D6')]

def parse_pdbstring(pdb_string):

Expand Down Expand Up @@ -92,6 +96,9 @@ def parse_results(output, color=None, models=5, multimer=False):
)

viewer = (0, 1)

set_b_to_plddtbands(ph_array[0])

view.addModel(ph_array[0].as_pdb_string(), "pdb", viewer=viewer)
view.zoomTo(viewer=viewer)
set_3dmol_styles(
Expand All @@ -108,6 +115,9 @@ def parse_results(output, color=None, models=5, multimer=False):

for idx, _ph in enumerate(ph_array):
viewer = (0, idx)

if color=="lDDT": set_b_to_plddtbands(_ph)

view.addModel(_ph.as_pdb_string(), "pdb", viewer=viewer)
view.zoomTo(viewer=viewer)

Expand Down Expand Up @@ -175,6 +185,9 @@ def parse_results_colour_chains(output, color=None, models=5, multimer=False):

for idx, _ph in enumerate(ph_array):
viewer = (0, idx)

if color=="lDDT": set_b_to_plddtbands(_ph)

view.addModel(_ph.as_pdb_string(), "pdb", viewer=viewer)
view.zoomTo(viewer=viewer)

Expand All @@ -186,6 +199,14 @@ def parse_results_colour_chains(output, color=None, models=5, multimer=False):
# ------------------------------------------------------


def set_b_to_plddtbands(ph):

plddt_lims = np.array([_[0] for _ in PLDDT_BANDS])
for resi in protomer_ph.residue_groups():
resi.atoms().set_b(new_b=cctbx.array_family.flex.double(resi.atoms().size(), float(np.argmax(plddt_lims>resi.atoms()[0].b))))

# ------------------------------------------------------

def set_3dmol_styles(
view,
viewer,
Expand All @@ -201,14 +222,15 @@ def set_3dmol_styles(
"""

if color == "lDDT":

color_map = {i: band[1] for i, band in enumerate(PLDDT_BANDS)}

view.setStyle(
{
"cartoon": {
"colorscheme": {
"prop": "b",
"gradient": "roygb",
"min": 0,
"max": 100,
'map': color_map
}
}
},
Expand Down

0 comments on commit a9cba49

Please sign in to comment.