diff --git a/.gitignore b/.gitignore index 3a646d9..f425c41 100644 --- a/.gitignore +++ b/.gitignore @@ -111,6 +111,12 @@ venv.tar.gz .idea .vscode +# TensorBoard +tb_logs/ + +# Feature Processing +*work_filenames*.csv + # DIPS project/datasets/DIPS/complexes/** project/datasets/DIPS/interim/** @@ -119,6 +125,7 @@ project/datasets/DIPS/parsed/** project/datasets/DIPS/raw/** project/datasets/DIPS/final/raw/** project/datasets/DIPS/final/final_raw_dips.tar.gz* +project/datasets/DIPS/final/processed/** # DB5 project/datasets/DB5/processed/** @@ -126,6 +133,7 @@ project/datasets/DB5/raw/** project/datasets/DB5/interim/** project/datasets/DB5/final/raw/** project/datasets/DB5/final/final_raw_db5.tar.gz* +project/datasets/DB5/final/processed/** # EVCoupling project/datasets/EVCoupling/raw/** @@ -137,4 +145,7 @@ project/datasets/EVCoupling/final/processed/** project/datasets/CASP-CAPRI/raw/** project/datasets/CASP-CAPRI/interim/** project/datasets/CASP-CAPRI/final/raw/** -project/datasets/CASP-CAPRI/final/processed/** \ No newline at end of file +project/datasets/CASP-CAPRI/final/processed/** + +# Input +project/datasets/Input/** \ No newline at end of file diff --git a/README.md b/README.md index 50fdbda..72d8d3f 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ The Enhanced Database of Interacting Protein Structures for Interface Prediction -[![Paper](http://img.shields.io/badge/paper-arxiv.2106.04362-B31B1B.svg)](https://arxiv.org/abs/2106.04362) [![CC BY 4.0][cc-by-shield]][cc-by] [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.5134732.svg)](https://doi.org/10.5281/zenodo.5134732) +[![Paper](http://img.shields.io/badge/paper-arxiv.2106.04362-B31B1B.svg)](https://arxiv.org/abs/2106.04362) [![CC BY 4.0][cc-by-shield]][cc-by] [![Primary Data DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.5134732.svg)](https://doi.org/10.5281/zenodo.5134732) [![Supplementary Data DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.8071136.svg)](https://doi.org/10.5281/zenodo.8071136) [cc-by]: http://creativecommons.org/licenses/by/4.0/ [cc-by-image]: https://i.creativecommons.org/l/by/4.0/88x31.png @@ -25,8 +25,9 @@ The Enhanced Database of Interacting Protein Structures for Interface Prediction * DB5-Plus' final 'raw' tar archive now also includes a corrected (i.e. de-duplicated) list of filenames for its 55 test complexes * Benchmark results included in our paper were run after this issue was resolved * However, if you ran experiments using DB5-Plus' filename list for its test complexes, please re-run them using the latest list +* Version 1.2.0: Minor additions to DIPS-Plus tar archives, including new residue-level intrinsic disorder region annotations and raw Jackhmmer-small BFD MSAs (Supplementary Data DOI: 10.5281/zenodo.8071136) -## How to run creation tools +## How to set up First, download Mamba (if not already downloaded): ```bash @@ -51,66 +52,135 @@ conda activate DIPS-Plus # Note: One still needs to use `conda` to (de)activate pip3 install -e . 
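+
+# (Optional) Sanity-check the editable install; assumes the `DIPS-Plus` environment is active:
+python3 -c "import project; print('DIPS-Plus is importable')"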
``` -## Default DIPS-Plus directory structure +To install PSAIA for feature generation, install GCC 10 for PSAIA: + +```bash +# Install GCC 10 for Ubuntu 20.04: +sudo apt install software-properties-common +sudo add-apt-repository ppa:ubuntu-toolchain-r/ppa +sudo apt update +sudo apt install gcc-10 g++-10 + +# Or install GCC 10 for Arch Linux/Manjaro: +yay -S gcc10 +``` + +Then install QT4 for PSAIA: + +```bash +# Install QT4 for Ubuntu 20.04: +sudo add-apt-repository ppa:rock-core/qt4 +sudo apt update +sudo apt install libqt4* libqtcore4 libqtgui4 libqtwebkit4 qt4* libxext-dev + +# Or install QT4 for Arch Linux/Manjaro: +yay -S qt4 +``` + +Conclude by compiling PSAIA from source: + +```bash +# Select the location to install the software: +MY_LOCAL=~/Programs + +# Download and extract PSAIA's source code: +mkdir "$MY_LOCAL" +cd "$MY_LOCAL" +wget http://complex.zesoi.fer.hr/data/PSAIA-1.0-source.tar.gz +tar -xvzf PSAIA-1.0-source.tar.gz + +# Compile PSAIA (i.e., a GUI for PSA): +cd PSAIA_1.0_source/make/linux/psaia/ +qmake-qt4 psaia.pro +make + +# Compile PSA (i.e., the protein structure analysis (PSA) program): +cd ../psa/ +qmake-qt4 psa.pro +make + +# Compile PIA (i.e., the protein interaction analysis (PIA) program): +cd ../pia/ +qmake-qt4 pia.pro +make + +# Test run any of the above-compiled programs: +cd "$MY_LOCAL"/PSAIA_1.0_source/bin/linux +# Test run PSA inside a GUI: +./psaia/psaia +# Test run PIA through a terminal: +./pia/pia +# Test run PSA through a terminal: +./psa/psa +``` + +Lastly, install Docker following the instructions from https://docs.docker.com/engine/install/ + +## How to generate protein feature inputs +In our [feature generation notebook](notebooks/feature_generation.ipynb), we provide examples of how users can generate the protein features described in our [accompanying manuscript](https://arxiv.org/abs/2106.04362) for individual protein inputs. + +## How to use data +In our [data usage notebook](notebooks/data_usage.ipynb), we provide examples of how users might use DIPS-Plus (or DB5-Plus) for downstream analysis or prediction tasks. 
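+
+As a quick illustration, a single postprocessed pair can also be inspected directly from Python (a minimal sketch, assuming the raw `.dill` pair files have already been downloaded and extracted as described below; the glob pattern is illustrative):
+
+```python
+# Minimal sketch: load one postprocessed DB5-Plus pair with dill and inspect it.
+import glob
+
+import dill as pickle
+
+# Assumes the default layout `project/datasets/DB5/final/raw/<subdir>/<pair>.dill`
+pair_files = sorted(glob.glob("project/datasets/DB5/final/raw/**/*.dill", recursive=True))
+if pair_files:
+    with open(pair_files[0], "rb") as f:
+        pair = pickle.load(f)
+    print(type(pair))  # the loaded object bundles residue-level features for both binding partners
+```
+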
For example, to train a new NeiA model with DB5-Plus as its cross-validation dataset, first download DB5-Plus' raw files and process them via the `data_usage` notebook: + +```bash +mkdir -p project/datasets/DB5/final +wget https://zenodo.org/record/5134732/files/final_raw_db5.tar.gz -O project/datasets/DB5/final/final_raw_db5.tar.gz +tar -xzf project/datasets/DB5/final/final_raw_db5.tar.gz -C project/datasets/DB5/final/ + +# To process these raw files for training and subsequently train a model: +python3 notebooks/data_usage.py +``` + +## Standard DIPS-Plus directory structure ``` DIPS-Plus │ └───project -│ │ -│ └───datasets -│ │ │ -│ │ └───builder -│ │ │ -│ │ └───DB5 -│ │ │ │ -│ │ │ └───final -│ │ │ │ │ -│ │ │ │ └───raw -│ │ │ │ -│ │ │ └───interim -│ │ │ │ │ -│ │ │ │ └───complexes -│ │ │ │ │ -│ │ │ │ └───external_feats -│ │ │ │ │ -│ │ │ │ └───pairs -│ │ │ │ -│ │ │ └───raw -│ │ │ │ -│ │ │ README -│ │ │ -│ │ └───DIPS -│ │ │ -│ │ └───filters -│ │ │ -│ │ └───final -│ │ │ │ -│ │ │ └───raw -│ │ │ -│ │ └───interim -│ │ │ │ -│ │ │ └───complexes -│ │ │ │ -│ │ │ └───external_feats -│ │ │ │ -│ │ │ └───pairs-pruned -│ │ │ -│ │ └───raw -│ │ │ -│ │ └───pdb -│ │ -│ └───utils -│ constants.py -│ utils.py -│ -.gitignore -environment.yml -LICENSE -README.md -requirements.txt -setup.cfg -setup.py + │ + └───datasets + │ + └───DB5 + │ │ + │ └───final + │ │ │ + │ │ └───processed # task-ready features for each dataset example + │ │ │ + │ │ └───raw # generic features for each dataset example + │ │ + │ └───interim + │ │ │ + │ │ └───complexes # metadata for each dataset example + │ │ │ + │ │ └───external_feats # features curated for each dataset example using external tools + │ │ │ + │ │ └───pairs # pair-wise features for each dataset example + │ │ + │ └───raw # raw PDB data downloads for each dataset example + │ + └───DIPS + │ + └───filters # filters to apply to each (un-pruned) dataset example + │ + └───final + │ │ + │ └───processed # task-ready features for each dataset example + │ │ + │ └───raw # generic features for each dataset example + │ + └───interim + │ │ + │ └───complexes # metadata for each dataset example + │ │ + │ └───external_feats # features curated for each dataset example using external tools + │ │ + │ └───pairs-pruned # filtered pair-wise features for each dataset example + │ │ + │ └───parsed # pair-wise features for each dataset example after initial parsing + │ + └───raw + │ + └───pdb # raw PDB data downloads for each dataset example ``` ## How to compile DIPS-Plus from scratch @@ -122,7 +192,7 @@ Retrieve protein complexes from the RCSB PDB and build out directory structure: rm project/datasets/DIPS/final/raw/pairs-postprocessed.txt project/datasets/DIPS/final/raw/pairs-postprocessed-train.txt project/datasets/DIPS/final/raw/pairs-postprocessed-val.txt project/datasets/DIPS/final/raw/pairs-postprocessed-test.txt # Create data directories (if not already created): -mkdir project/datasets/DIPS/raw project/datasets/DIPS/raw/pdb project/datasets/DIPS/interim project/datasets/DIPS/interim/external_feats project/datasets/DIPS/final project/datasets/DIPS/final/raw project/datasets/DIPS/final/processed +mkdir project/datasets/DIPS/raw project/datasets/DIPS/raw/pdb project/datasets/DIPS/interim project/datasets/DIPS/interim/pairs-pruned project/datasets/DIPS/interim/external_feats project/datasets/DIPS/final project/datasets/DIPS/final/raw project/datasets/DIPS/final/processed # Download the raw PDB files: rsync -rlpt -v -z --delete --port=33444 --include='*.gz' --include='*.xz' --include='*/' 
--exclude '*' \ @@ -139,7 +209,17 @@ python3 project/datasets/builder/prune_pairs.py project/datasets/DIPS/interim/pa # Generate externally-sourced features: python3 project/datasets/builder/generate_psaia_features.py "$PSAIADIR" "$PROJDIR"/project/datasets/builder/psaia_config_file_dips.txt "$PROJDIR"/project/datasets/DIPS/raw/pdb "$PROJDIR"/project/datasets/DIPS/interim/parsed "$PROJDIR"/project/datasets/DIPS/interim/pairs-pruned "$PROJDIR"/project/datasets/DIPS/interim/external_feats --source_type rcsb -python3 project/datasets/builder/generate_hhsuite_features.py "$PROJDIR"/project/datasets/DIPS/interim/parsed "$PROJDIR"/project/datasets/DIPS/interim/pairs-pruned "$HHSUITE_DB" "$PROJDIR"/project/datasets/DIPS/interim/external_feats --num_cpu_jobs 4 --num_cpus_per_job 8 --num_iter 2 --source_type rcsb --write_file +python3 project/datasets/builder/generate_hhsuite_features.py "$PROJDIR"/project/datasets/DIPS/interim/parsed "$PROJDIR"/project/datasets/DIPS/interim/pairs-pruned "$HHSUITE_DB" "$PROJDIR"/project/datasets/DIPS/interim/external_feats --num_cpu_jobs 4 --num_cpus_per_job 8 --num_iter 2 --source_type rcsb --write_file # Note: After this, one needs to re-run this command with `--read_file` instead + +# Generate multiple sequence alignments (MSAs) using a smaller sequence database (if not already created using the standard BFD): +DOWNLOAD_DIR="$HHSUITE_DB_DIR" && ROOT_DIR="${DOWNLOAD_DIR}/small_bfd" && SOURCE_URL="https://storage.googleapis.com/alphafold-databases/reduced_dbs/bfd-first_non_consensus_sequences.fasta.gz" && BASENAME=$(basename "${SOURCE_URL}") && mkdir --parents "${ROOT_DIR}" && aria2c "${SOURCE_URL}" --dir="${ROOT_DIR}" && pushd "${ROOT_DIR}" && gunzip "${ROOT_DIR}/${BASENAME}" && popd # e.g., Download the small BFD +python3 project/datasets/builder/generate_hhsuite_features.py "$PROJDIR"/project/datasets/DIPS/interim/parsed "$PROJDIR"/project/datasets/DIPS/interim/pairs-pruned "$HHSUITE_DB_DIR"/small_bfd "$PROJDIR"/project/datasets/DIPS/interim/external_feats --num_cpu_jobs 4 --num_cpus_per_job 8 --num_iter 2 --source_type rcsb --generate_msa_only --write_file # Note: After this, one needs to re-run this command with `--read_file` instead + +# Identify interfaces within intrinsically disordered regions (IDRs) # +# (1) Pull down the Docker image for `flDPnn` +docker pull docker.io/sinaghadermarzi/fldpnn +# (2) For all sequences in the dataset, predict which interface residues reside within IDRs +python3 project/datasets/builder/annotate_idr_interfaces.py "$PROJDIR"/project/datasets/DIPS/final/raw # Add new features to the filtered pairs, ensuring that the pruned pairs' original PDB files are stored locally for DSSP: python3 project/datasets/builder/download_missing_pruned_pair_pdbs.py "$PROJDIR"/project/datasets/DIPS/raw/pdb "$PROJDIR"/project/datasets/DIPS/interim/pairs-pruned --num_cpus 32 --rank "$1" --size "$2" @@ -198,7 +278,7 @@ python3 project/datasets/builder/convert_complexes_to_graphs.py "$PROJDIR"/proje We split the (tar.gz) archive into eight separate parts with 'split -b 4096M interim_external_feats_dips.tar.gz "interim_external_feats_dips.tar.gz.part"' -to upload it to Zenodo, so to recover the original archive: +to upload it to the dataset's primary Zenodo record, so to recover the original archive: ```bash # Reassemble external features archive with 'cat' diff --git a/environment.yml b/environment.yml index 99d5000..e792c78 100644 --- a/environment.yml +++ b/environment.yml @@ -11,6 +11,11 @@ dependencies: - _libgcc_mutex=0.1=conda_forge - 
_openmp_mutex=4.5=2_kmp_llvm - appdirs=1.4.4=pyhd3eb1b0_0 + - asttokens=2.2.1=pyhd8ed1ab_0 + - backcall=0.2.0=pyh9f0ad1d_0 + - backports=1.0=pyhd8ed1ab_3 + - backports.functools_lru_cache=1.6.4=pyhd8ed1ab_0 + - binutils_impl_linux-64=2.36.1=h193b22a_2 - biopython=1.78=py38h7f8727e_0 - blas=1.0=openblas - bottleneck=1.3.5=py38h7deecbd_0 @@ -21,6 +26,7 @@ dependencies: - cffi=1.15.1=py38h5eee18b_3 - charset-normalizer=2.0.4=pyhd3eb1b0_0 - colorama=0.4.6=pyhd8ed1ab_0 + - comm=0.1.3=pyhd8ed1ab_0 - cryptography=40.0.2=py38h3d167d9_0 - cuda=11.6.1=0 - cuda-cccl=11.6.55=hf6102b2_0 @@ -56,21 +62,37 @@ dependencies: - cuda-visual-tools=11.6.1=0 - cudatoolkit=11.7.0=hd8887f6_10 - cudnn=8.8.0.121=h0800d71_0 + - debugpy=1.6.7=py38h8dc9893_0 + - decorator=5.1.1=pyhd8ed1ab_0 - dgl=1.1.0.cu116=py38_0 - dssp=3.0.0=h3fd9d12_4 + - executing=1.2.0=pyhd8ed1ab_0 - ffmpeg=4.3=hf484d3e_0 - freetype=2.12.1=hca18f0e_1 + - gcc=10.3.0=he2824d0_10 + - gcc_impl_linux-64=10.3.0=hf2f2afa_16 - gds-tools=1.6.1.9=0 - gmp=6.2.1=h58526e2_0 - gnutls=3.6.13=h85f3911_1 + - gxx=10.3.0=he2824d0_10 + - gxx_impl_linux-64=10.3.0=hf2f2afa_16 - hhsuite=3.3.0=py38pl5321hcbe9525_8 + - hmmer=3.3.2=hdbdd923_4 - icu=58.2=he6710b0_3 - idna=3.4=py38h06a4308_0 + - importlib-metadata=6.6.0=pyha770c72_0 + - importlib_metadata=6.6.0=hd8ed1ab_0 + - ipykernel=6.23.1=pyh210e3f2_0 + - ipython=8.4.0=py38h578d9bd_0 + - jedi=0.18.2=pyhd8ed1ab_0 - joblib=1.1.1=py38h06a4308_0 - jpeg=9e=h0b41bf4_3 + - jupyter_client=8.2.0=pyhd8ed1ab_0 + - jupyter_core=4.12.0=py38h578d9bd_0 + - kernel-headers_linux-64=2.6.32=he073ed8_15 - lame=3.100=h166bdaf_1003 - lcms2=2.15=hfd0df8a_0 - - ld_impl_linux-64=2.38=h1181459_1 + - ld_impl_linux-64=2.36.1=hea4e1c9_2 - lerc=4.0.0=h27087fc_0 - libblas=3.9.0=16_linux64_openblas - libboost=1.73.0=h28710b8_12 @@ -88,9 +110,10 @@ dependencies: - libcusparse-dev=11.7.2.124=hbbe9722_0 - libdeflate=1.17=h0b41bf4_0 - libffi=3.4.4=h6a678d5_0 + - libgcc-devel_linux-64=10.3.0=he6cfe16_16 - libgcc-ng=12.2.0=h65d4601_19 - - libgfortran-ng=11.2.0=h00389a5_1 - - libgfortran5=11.2.0=h1234567_1 + - libgfortran-ng=12.2.0=h69a702a_19 + - libgfortran5=12.2.0=h337968e_19 - libgomp=12.2.0=h65d4601_19 - libiconv=1.17=h166bdaf_0 - liblapack=3.9.0=16_linux64_openblas @@ -102,7 +125,10 @@ dependencies: - libopenblas=0.3.21=h043d6bf_0 - libpng=1.6.39=h753d276_0 - libprotobuf=3.21.12=h3eb15da_0 + - libsanitizer=10.3.0=h26c7422_16 + - libsodium=1.0.18=h36c2ea0_1 - libsqlite=3.42.0=h2797004_0 + - libstdcxx-devel_linux-64=10.3.0=he6cfe16_16 - libstdcxx-ng=12.2.0=h46fd767_19 - libtiff=4.5.0=h6adf6a1_2 - libuuid=2.38.1=h0b41bf4_0 @@ -112,10 +138,13 @@ dependencies: - llvm-openmp=16.0.4=h4dfa4b3_0 - lz4-c=1.9.4=h6a678d5_0 - magma=2.6.2=hc72dce7_0 + - matplotlib-inline=0.1.6=pyhd8ed1ab_0 - mkl=2022.2.1=h84fe81f_16997 + - mpi=1.0=openmpi - msms=2.6.1=h516909a_0 - nccl=2.15.5.1=h0800d71_0 - ncurses=6.4=h6a678d5_0 + - nest-asyncio=1.5.6=pyhd8ed1ab_0 - nettle=3.6=he412f7d_0 - networkx=3.1=pyhd8ed1ab_0 - ninja=1.11.1=h924138e_0 @@ -125,16 +154,23 @@ dependencies: - numpy-base=1.24.3=py38h1e6e340_0 - openh264=2.1.1=h780b84a_0 - openjpeg=2.5.0=hfec8fc6_2 - - openssl=3.1.0=hd590300_3 + - openmpi=4.1.5=h414af15_101 + - openssl=3.1.1=hd590300_1 - packaging=23.0=py38h06a4308_0 - pandas=1.5.3=py38h417a72b_0 + - parso=0.8.3=pyhd8ed1ab_0 - perl=5.32.1=0_h5eee18b_perl5 + - pexpect=4.8.0=pyh1a96a4e_2 + - pickleshare=0.7.5=py_1003 - pillow=9.4.0=py38hde6dc18_1 - - pip=23.0.1=py38h06a4308_0 - pooch=1.4.0=pyhd3eb1b0_0 + - prompt-toolkit=3.0.38=pyha770c72_0 - 
psutil=5.9.5=py38h1de0b5d_0 - pthread-stubs=0.4=h36c2ea0_1001 + - ptyprocess=0.7.0=pyhd3deb0d_0 + - pure_eval=0.2.2=pyhd8ed1ab_0 - pycparser=2.21=pyhd3eb1b0_0 + - pygments=2.15.1=pyhd8ed1ab_0 - pyopenssl=23.1.1=pyhd8ed1ab_0 - pysocks=1.7.1=py38h06a4308_0 - python=3.8.16=he550d4f_1_cpython @@ -144,6 +180,7 @@ dependencies: - pytorch-cuda=11.6=h867d48c_1 - pytorch-mutex=1.0=cuda - pytz=2022.7=py38h06a4308_0 + - pyzmq=25.0.2=py38he24dcef_0 - readline=8.2=h5eee18b_0 - requests=2.29.0=py38h06a4308_0 - scikit-learn=1.2.2=py38h6a678d5_0 @@ -151,37 +188,78 @@ dependencies: - six=1.16.0=pyhd3eb1b0_1 - sleef=3.5.1=h9b69904_2 - sqlite=3.41.2=h5eee18b_0 + - stack_data=0.6.2=pyhd8ed1ab_0 + - sysroot_linux-64=2.12=he073ed8_15 - tbb=2021.7.0=h924138e_0 - threadpoolctl=2.2.0=pyh0d69192_0 - tk=8.6.12=h1ccaba5_0 - torchaudio=0.13.1=py38_cu116 - torchvision=0.14.1=py38_cu116 + - tornado=6.3.2=py38h01eb140_0 + - traitlets=5.9.0=pyhd8ed1ab_0 - typing_extensions=4.6.0=pyha770c72_0 - urllib3=1.26.15=py38h06a4308_0 + - wcwidth=0.2.6=pyhd8ed1ab_0 - wheel=0.38.4=py38h06a4308_0 - xorg-libxau=1.0.11=hd590300_0 - xorg-libxdmcp=1.1.3=h7f98852_0 - xz=5.4.2=h5eee18b_0 + - zeromq=4.3.4=h9c3ff4c_1 + - zipp=3.15.0=pyhd8ed1ab_0 - zlib=1.2.13=h166bdaf_4 - zstd=1.5.5=hc292b87_0 - pip: + - absl-py==1.4.0 + - aiohttp==3.8.4 + - aiosignal==1.3.1 - alabaster==0.7.13 - - atom3-py3==0.1.9.9 + - async-timeout==4.0.2 + - git+https://github.com/amorehead/atom3.git@83987404ceed38a1f5a5abd517aa38128d0a4f2c + - attrs==23.1.0 - babel==2.12.1 + - cachetools==5.3.1 - click==7.0 + - configparser==5.3.0 - dill==0.3.3 + - docker-pycreds==0.4.0 - docutils==0.17.1 - easy-parallel-py3==0.1.6.4 + - fairscale==0.4.0 + - frozenlist==1.3.3 + - fsspec==2023.5.0 + - future==0.18.3 + - gitdb==4.0.10 + - gitpython==3.1.31 + - google-auth==2.19.0 + - google-auth-oauthlib==1.0.0 + - grpcio==1.54.2 - h5py==3.8.0 + - hickle==5.0.2 - imagesize==1.4.1 + - install==1.3.5 - jinja2==2.11.3 + - markdown==3.4.3 - markupsafe==1.1.1 + - mpi4py==3.0.3 + - multidict==6.0.4 - multiprocess==0.70.11.1 + - oauthlib==3.2.2 - pathos==0.2.7 + - pathtools==0.1.2 - pox==0.3.2 - ppft==1.7.6.6 - - pygments==2.15.1 - - setuptools==56.2.0 + - promise==2.3 + - protobuf==3.20.3 + - pyasn1==0.5.0 + - pyasn1-modules==0.3.0 + - pydeprecate==0.3.1 + - pytorch-lightning==1.4.8 + - pyyaml==6.0 + - requests-oauthlib==1.3.1 + - rsa==4.9 + - sentry-sdk==1.24.0 + - shortuuid==1.0.11 + - smmap==5.0.0 - snowballstemmer==2.2.0 - sphinx==4.0.1 - sphinxcontrib-applehelp==1.0.4 @@ -190,4 +268,12 @@ dependencies: - sphinxcontrib-jsmath==1.0.1 - sphinxcontrib-qthelp==1.0.3 - sphinxcontrib-serializinghtml==1.1.5 - - tqdm==4.49.0 + - subprocess32==3.5.4 + - tensorboard==2.13.0 + - tensorboard-data-server==0.7.0 + - termcolor==2.3.0 + - torchmetrics==0.5.1 + - wandb==0.12.2 + - werkzeug==2.3.4 + - yarl==1.9.2 + - yaspin==2.3.0 diff --git a/notebooks/data_usage.ipynb b/notebooks/data_usage.ipynb new file mode 100644 index 0000000..6c8e59f --- /dev/null +++ b/notebooks/data_usage.ipynb @@ -0,0 +1,297 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Example of data usage" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Neural network model training" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# -------------------------------------------------------------------------------------------------------------------------------------\n", 
+ "# Following code adapted from NeiA-PyTorch (https://github.com/amorehead/NeiA-PyTorch):\n", + "# -------------------------------------------------------------------------------------------------------------------------------------\n", + "\n", + "import os\n", + "import sys\n", + "from pathlib import Path\n", + "\n", + "import pytorch_lightning as pl\n", + "import torch.nn as nn\n", + "from pytorch_lightning.plugins import DDPPlugin\n", + "\n", + "from project.datasets.DB5.db5_dgl_data_module import DB5DGLDataModule\n", + "from project.utils.modules import LitNeiA\n", + "from project.utils.training_utils import collect_args, process_args, construct_pl_logger" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def main(args):\n", + " # -----------\n", + " # Data\n", + " # -----------\n", + " # Load Docking Benchmark 5 (DB5) data module\n", + " db5_data_module = DB5DGLDataModule(data_dir=args.db5_data_dir,\n", + " batch_size=args.batch_size,\n", + " num_dataloader_workers=args.num_workers,\n", + " knn=args.knn,\n", + " self_loops=args.self_loops,\n", + " percent_to_use=args.db5_percent_to_use,\n", + " process_complexes=args.process_complexes,\n", + " input_indep=args.input_indep)\n", + " db5_data_module.setup()\n", + "\n", + " # ------------\n", + " # Model\n", + " # ------------\n", + " # Assemble a dictionary of model arguments\n", + " dict_args = vars(args)\n", + " use_wandb_logger = args.logger_name.lower() == 'wandb' # Determine whether the user requested to use WandB\n", + "\n", + " # Pick model and supply it with a dictionary of arguments\n", + " if args.model_name.lower() == 'neiwa': # Neighborhood Weighted Average (NeiWA)\n", + " model = LitNeiA(num_node_input_feats=db5_data_module.db5_test.num_node_features,\n", + " num_edge_input_feats=db5_data_module.db5_test.num_edge_features,\n", + " gnn_activ_fn=nn.Tanh(),\n", + " interact_activ_fn=nn.ReLU(),\n", + " num_classes=db5_data_module.db5_test.num_classes,\n", + " weighted_avg=True, # Use the neighborhood weighted average variant of NeiA\n", + " num_gnn_layers=dict_args['num_gnn_layers'],\n", + " num_interact_layers=dict_args['num_interact_layers'],\n", + " num_interact_hidden_channels=dict_args['num_interact_hidden_channels'],\n", + " num_epochs=dict_args['num_epochs'],\n", + " pn_ratio=dict_args['pn_ratio'],\n", + " knn=dict_args['knn'],\n", + " dropout_rate=dict_args['dropout_rate'],\n", + " metric_to_track=dict_args['metric_to_track'],\n", + " weight_decay=dict_args['weight_decay'],\n", + " batch_size=dict_args['batch_size'],\n", + " lr=dict_args['lr'],\n", + " multi_gpu_backend=dict_args[\"accelerator\"])\n", + " args.experiment_name = f'LitNeiWA-b{args.batch_size}-gl{args.num_gnn_layers}' \\\n", + " f'-n{db5_data_module.db5_test.num_node_features}' \\\n", + " f'-e{db5_data_module.db5_test.num_edge_features}' \\\n", + " f'-il{args.num_interact_layers}-i{args.num_interact_hidden_channels}' \\\n", + " if not args.experiment_name \\\n", + " else args.experiment_name\n", + " template_ckpt_filename = 'LitNeiWA-{epoch:02d}-{val_ce:.2f}'\n", + "\n", + " else: # Default Model - Neighborhood Average (NeiA)\n", + " model = LitNeiA(num_node_input_feats=db5_data_module.db5_test.num_node_features,\n", + " num_edge_input_feats=db5_data_module.db5_test.num_edge_features,\n", + " gnn_activ_fn=nn.Tanh(),\n", + " interact_activ_fn=nn.ReLU(),\n", + " num_classes=db5_data_module.db5_test.num_classes,\n", + " weighted_avg=False,\n", + " num_gnn_layers=dict_args['num_gnn_layers'],\n", 
+ " num_interact_layers=dict_args['num_interact_layers'],\n", + " num_interact_hidden_channels=dict_args['num_interact_hidden_channels'],\n", + " num_epochs=dict_args['num_epochs'],\n", + " pn_ratio=dict_args['pn_ratio'],\n", + " knn=dict_args['knn'],\n", + " dropout_rate=dict_args['dropout_rate'],\n", + " metric_to_track=dict_args['metric_to_track'],\n", + " weight_decay=dict_args['weight_decay'],\n", + " batch_size=dict_args['batch_size'],\n", + " lr=dict_args['lr'],\n", + " multi_gpu_backend=dict_args[\"accelerator\"])\n", + " args.experiment_name = f'LitNeiA-b{args.batch_size}-gl{args.num_gnn_layers}' \\\n", + " f'-n{db5_data_module.db5_test.num_node_features}' \\\n", + " f'-e{db5_data_module.db5_test.num_edge_features}' \\\n", + " f'-il{args.num_interact_layers}-i{args.num_interact_hidden_channels}' \\\n", + " if not args.experiment_name \\\n", + " else args.experiment_name\n", + " template_ckpt_filename = 'LitNeiA-{epoch:02d}-{val_ce:.2f}'\n", + "\n", + " # ------------\n", + " # Checkpoint\n", + " # ------------\n", + " ckpt_path = os.path.join(args.ckpt_dir, args.ckpt_name)\n", + " ckpt_path_exists = os.path.exists(ckpt_path)\n", + " ckpt_provided = args.ckpt_name != '' and ckpt_path_exists\n", + " model = model.load_from_checkpoint(ckpt_path,\n", + " use_wandb_logger=use_wandb_logger,\n", + " batch_size=args.batch_size,\n", + " lr=args.lr,\n", + " weight_decay=args.weight_decay,\n", + " dropout_rate=args.dropout_rate) if ckpt_provided else model\n", + "\n", + " # ------------\n", + " # Trainer\n", + " # ------------\n", + " trainer = pl.Trainer.from_argparse_args(args)\n", + "\n", + " # -------------\n", + " # Learning Rate\n", + " # -------------\n", + " if args.find_lr:\n", + " lr_finder = trainer.tuner.lr_find(model, datamodule=db5_data_module) # Run learning rate finder\n", + " fig = lr_finder.plot(suggest=True) # Plot learning rates\n", + " fig.savefig('optimal_lr.pdf')\n", + " fig.show()\n", + " model.hparams.lr = lr_finder.suggestion() # Save optimal learning rate\n", + " print(f'Optimal learning rate found: {model.hparams.lr}')\n", + "\n", + " # ------------\n", + " # Logger\n", + " # ------------\n", + " pl_logger = construct_pl_logger(args) # Log everything to an external logger\n", + " trainer.logger = pl_logger # Assign specified logger (e.g. 
TensorBoardLogger) to Trainer instance\n", + "\n", + " # -----------\n", + " # Callbacks\n", + " # -----------\n", + " # Create and use callbacks\n", + " mode = 'min' if 'ce' in args.metric_to_track else 'max'\n", + " early_stop_callback = pl.callbacks.EarlyStopping(monitor=args.metric_to_track,\n", + " mode=mode,\n", + " min_delta=args.min_delta,\n", + " patience=args.patience)\n", + " ckpt_callback = pl.callbacks.ModelCheckpoint(\n", + " monitor=args.metric_to_track,\n", + " mode=mode,\n", + " verbose=True,\n", + " save_last=True,\n", + " save_top_k=3,\n", + " filename=template_ckpt_filename # Warning: May cause a race condition if calling trainer.test() with many GPUs\n", + " )\n", + " lr_monitor_callback = pl.callbacks.LearningRateMonitor(logging_interval='step', log_momentum=True)\n", + " trainer.callbacks = [early_stop_callback, ckpt_callback, lr_monitor_callback]\n", + "\n", + " # ------------\n", + " # Restore\n", + " # ------------\n", + " # If using WandB, download checkpoint artifact from their servers if the checkpoint is not already stored locally\n", + " if use_wandb_logger and args.ckpt_name != '' and not os.path.exists(ckpt_path):\n", + " checkpoint_reference = f'{args.entity}/{args.project_name}/model-{args.run_id}:best'\n", + " artifact = trainer.logger.experiment.use_artifact(checkpoint_reference, type='model')\n", + " artifact_dir = artifact.download()\n", + " model = model.load_from_checkpoint(Path(artifact_dir) / 'model.ckpt',\n", + " use_wandb_logger=use_wandb_logger,\n", + " batch_size=args.batch_size,\n", + " lr=args.lr,\n", + " weight_decay=args.weight_decay)\n", + "\n", + " # -------------\n", + " # Training\n", + " # -------------\n", + " # Train with the provided model and DataModule\n", + " trainer.fit(model=model, datamodule=db5_data_module)\n", + "\n", + " # -------------\n", + " # Testing\n", + " # -------------\n", + " trainer.test()\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# -----------\n", + "# Jupyter\n", + "# -----------\n", + "sys.argv = ['']" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# -----------\n", + "# Arguments\n", + "# -----------\n", + "# Collect all arguments\n", + "parser = collect_args()\n", + "\n", + "# Parse all known and unknown arguments\n", + "args, unparsed_argv = parser.parse_known_args()\n", + "\n", + "# Let the model add what it wants\n", + "parser = LitNeiA.add_model_specific_args(parser)\n", + "\n", + "# Re-parse all known and unknown arguments after adding those that are model specific\n", + "args, unparsed_argv = parser.parse_known_args()\n", + "\n", + "# TODO: Manually set arguments within a Jupyter notebook from here\n", + "args.model_name = \"neia\"\n", + "args.multi_gpu_backend = \"dp\"\n", + "args.db5_data_dir = \"../project/datasets/DB5/final/raw\"\n", + "args.process_complexes = True\n", + "args.batch_size = 1 # Note: `batch_size` must be `1` for compatibility with the current model implementation\n", + "\n", + "# Set Lightning-specific parameter values before constructing Trainer instance\n", + "args.max_time = {'hours': args.max_hours, 'minutes': args.max_minutes}\n", + "args.max_epochs = args.num_epochs\n", + "args.profiler = args.profiler_method\n", + "args.accelerator = args.multi_gpu_backend\n", + "args.auto_select_gpus = args.auto_choose_gpus\n", + "args.gpus = args.num_gpus\n", + "args.num_nodes = args.num_compute_nodes\n", + "args.precision = args.gpu_precision\n", 
+ "args.accumulate_grad_batches = args.accum_grad_batches\n", + "args.gradient_clip_val = args.grad_clip_val\n", + "args.gradient_clip_algo = args.grad_clip_algo\n", + "args.stochastic_weight_avg = args.stc_weight_avg\n", + "args.deterministic = True # Make LightningModule's training reproducible\n", + "\n", + "# Set plugins for Lightning\n", + "args.plugins = [\n", + " # 'ddp_sharded', # For sharded model training (to reduce GPU requirements)\n", + " # DDPPlugin(find_unused_parameters=False),\n", + "]\n", + "\n", + "# Finalize all arguments as necessary\n", + "args = process_args(args)\n", + "\n", + "# Begin execution of model training with given args above\n", + "main(args)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "DIPS-Plus", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.16" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/notebooks/data_usage.py b/notebooks/data_usage.py new file mode 100644 index 0000000..2ed9296 --- /dev/null +++ b/notebooks/data_usage.py @@ -0,0 +1,239 @@ +# %% [markdown] +# # Example of data usage + +# %% [markdown] +# ### Neural network model training + +# %% +# ------------------------------------------------------------------------------------------------------------------------------------- +# Following code adapted from NeiA-PyTorch (https://github.com/amorehead/NeiA-PyTorch): +# ------------------------------------------------------------------------------------------------------------------------------------- + +import os +import sys +from pathlib import Path + +import pytorch_lightning as pl +import torch.nn as nn +from pytorch_lightning.plugins import DDPPlugin + +from project.datasets.DB5.db5_dgl_data_module import DB5DGLDataModule +from project.utils.modules import LitNeiA +from project.utils.training_utils import collect_args, process_args, construct_pl_logger + +# %% +def main(args): + # ----------- + # Data + # ----------- + # Load Docking Benchmark 5 (DB5) data module + db5_data_module = DB5DGLDataModule(data_dir=args.db5_data_dir, + batch_size=args.batch_size, + num_dataloader_workers=args.num_workers, + knn=args.knn, + self_loops=args.self_loops, + percent_to_use=args.db5_percent_to_use, + process_complexes=args.process_complexes, + input_indep=args.input_indep) + db5_data_module.setup() + + # ------------ + # Model + # ------------ + # Assemble a dictionary of model arguments + dict_args = vars(args) + use_wandb_logger = args.logger_name.lower() == 'wandb' # Determine whether the user requested to use WandB + + # Pick model and supply it with a dictionary of arguments + if args.model_name.lower() == 'neiwa': # Neighborhood Weighted Average (NeiWA) + model = LitNeiA(num_node_input_feats=db5_data_module.db5_test.num_node_features, + num_edge_input_feats=db5_data_module.db5_test.num_edge_features, + gnn_activ_fn=nn.Tanh(), + interact_activ_fn=nn.ReLU(), + num_classes=db5_data_module.db5_test.num_classes, + weighted_avg=True, # Use the neighborhood weighted average variant of NeiA + num_gnn_layers=dict_args['num_gnn_layers'], + num_interact_layers=dict_args['num_interact_layers'], + num_interact_hidden_channels=dict_args['num_interact_hidden_channels'], + num_epochs=dict_args['num_epochs'], + pn_ratio=dict_args['pn_ratio'], + 
knn=dict_args['knn'], + dropout_rate=dict_args['dropout_rate'], + metric_to_track=dict_args['metric_to_track'], + weight_decay=dict_args['weight_decay'], + batch_size=dict_args['batch_size'], + lr=dict_args['lr'], + multi_gpu_backend=dict_args["accelerator"]) + args.experiment_name = f'LitNeiWA-b{args.batch_size}-gl{args.num_gnn_layers}' \ + f'-n{db5_data_module.db5_test.num_node_features}' \ + f'-e{db5_data_module.db5_test.num_edge_features}' \ + f'-il{args.num_interact_layers}-i{args.num_interact_hidden_channels}' \ + if not args.experiment_name \ + else args.experiment_name + template_ckpt_filename = 'LitNeiWA-{epoch:02d}-{val_ce:.2f}' + + else: # Default Model - Neighborhood Average (NeiA) + model = LitNeiA(num_node_input_feats=db5_data_module.db5_test.num_node_features, + num_edge_input_feats=db5_data_module.db5_test.num_edge_features, + gnn_activ_fn=nn.Tanh(), + interact_activ_fn=nn.ReLU(), + num_classes=db5_data_module.db5_test.num_classes, + weighted_avg=False, + num_gnn_layers=dict_args['num_gnn_layers'], + num_interact_layers=dict_args['num_interact_layers'], + num_interact_hidden_channels=dict_args['num_interact_hidden_channels'], + num_epochs=dict_args['num_epochs'], + pn_ratio=dict_args['pn_ratio'], + knn=dict_args['knn'], + dropout_rate=dict_args['dropout_rate'], + metric_to_track=dict_args['metric_to_track'], + weight_decay=dict_args['weight_decay'], + batch_size=dict_args['batch_size'], + lr=dict_args['lr'], + multi_gpu_backend=dict_args["accelerator"]) + args.experiment_name = f'LitNeiA-b{args.batch_size}-gl{args.num_gnn_layers}' \ + f'-n{db5_data_module.db5_test.num_node_features}' \ + f'-e{db5_data_module.db5_test.num_edge_features}' \ + f'-il{args.num_interact_layers}-i{args.num_interact_hidden_channels}' \ + if not args.experiment_name \ + else args.experiment_name + template_ckpt_filename = 'LitNeiA-{epoch:02d}-{val_ce:.2f}' + + # ------------ + # Checkpoint + # ------------ + ckpt_path = os.path.join(args.ckpt_dir, args.ckpt_name) + ckpt_path_exists = os.path.exists(ckpt_path) + ckpt_provided = args.ckpt_name != '' and ckpt_path_exists + model = model.load_from_checkpoint(ckpt_path, + use_wandb_logger=use_wandb_logger, + batch_size=args.batch_size, + lr=args.lr, + weight_decay=args.weight_decay, + dropout_rate=args.dropout_rate) if ckpt_provided else model + + # ------------ + # Trainer + # ------------ + trainer = pl.Trainer.from_argparse_args(args) + + # ------------- + # Learning Rate + # ------------- + if args.find_lr: + lr_finder = trainer.tuner.lr_find(model, datamodule=db5_data_module) # Run learning rate finder + fig = lr_finder.plot(suggest=True) # Plot learning rates + fig.savefig('optimal_lr.pdf') + fig.show() + model.hparams.lr = lr_finder.suggestion() # Save optimal learning rate + print(f'Optimal learning rate found: {model.hparams.lr}') + + # ------------ + # Logger + # ------------ + pl_logger = construct_pl_logger(args) # Log everything to an external logger + trainer.logger = pl_logger # Assign specified logger (e.g. 
TensorBoardLogger) to Trainer instance + + # ----------- + # Callbacks + # ----------- + # Create and use callbacks + mode = 'min' if 'ce' in args.metric_to_track else 'max' + early_stop_callback = pl.callbacks.EarlyStopping(monitor=args.metric_to_track, + mode=mode, + min_delta=args.min_delta, + patience=args.patience) + ckpt_callback = pl.callbacks.ModelCheckpoint( + monitor=args.metric_to_track, + mode=mode, + verbose=True, + save_last=True, + save_top_k=3, + filename=template_ckpt_filename # Warning: May cause a race condition if calling trainer.test() with many GPUs + ) + lr_monitor_callback = pl.callbacks.LearningRateMonitor(logging_interval='step', log_momentum=True) + trainer.callbacks = [early_stop_callback, ckpt_callback, lr_monitor_callback] + + # ------------ + # Restore + # ------------ + # If using WandB, download checkpoint artifact from their servers if the checkpoint is not already stored locally + if use_wandb_logger and args.ckpt_name != '' and not os.path.exists(ckpt_path): + checkpoint_reference = f'{args.entity}/{args.project_name}/model-{args.run_id}:best' + artifact = trainer.logger.experiment.use_artifact(checkpoint_reference, type='model') + artifact_dir = artifact.download() + model = model.load_from_checkpoint(Path(artifact_dir) / 'model.ckpt', + use_wandb_logger=use_wandb_logger, + batch_size=args.batch_size, + lr=args.lr, + weight_decay=args.weight_decay) + + # ------------- + # Training + # ------------- + # Train with the provided model and DataModule + trainer.fit(model=model, datamodule=db5_data_module) + + # ------------- + # Testing + # ------------- + trainer.test() + + +# %% +# ----------- +# Jupyter +# ----------- +# sys.argv = [''] + +# %% +# ----------- +# Arguments +# ----------- +# Collect all arguments +parser = collect_args() + +# Parse all known and unknown arguments +args, unparsed_argv = parser.parse_known_args() + +# Let the model add what it wants +parser = LitNeiA.add_model_specific_args(parser) + +# Re-parse all known and unknown arguments after adding those that are model specific +args, unparsed_argv = parser.parse_known_args() + +# TODO: Manually set arguments within a Jupyter notebook from here +args.model_name = "neia" +args.multi_gpu_backend = "dp" +args.db5_data_dir = "project/datasets/DB5/final/raw" +args.process_complexes = True +args.batch_size = 1 # Note: `batch_size` must be `1` for compatibility with the current model implementation + +# Set Lightning-specific parameter values before constructing Trainer instance +args.max_time = {'hours': args.max_hours, 'minutes': args.max_minutes} +args.max_epochs = args.num_epochs +args.profiler = args.profiler_method +args.accelerator = args.multi_gpu_backend +args.auto_select_gpus = args.auto_choose_gpus +args.gpus = args.num_gpus +args.num_nodes = args.num_compute_nodes +args.precision = args.gpu_precision +args.accumulate_grad_batches = args.accum_grad_batches +args.gradient_clip_val = args.grad_clip_val +args.gradient_clip_algo = args.grad_clip_algo +args.stochastic_weight_avg = args.stc_weight_avg +args.deterministic = True # Make LightningModule's training reproducible + +# Set plugins for Lightning +args.plugins = [ + # 'ddp_sharded', # For sharded model training (to reduce GPU requirements) + # DDPPlugin(find_unused_parameters=False), +] + +# Finalize all arguments as necessary +args = process_args(args) + +# Begin execution of model training with given args above +main(args) + + diff --git a/notebooks/feature_generation.ipynb b/notebooks/feature_generation.ipynb new file mode 
100644 index 0000000..6550802 --- /dev/null +++ b/notebooks/feature_generation.ipynb @@ -0,0 +1,283 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Feature generation for PDB file inputs" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "import atom3.complex as comp\n", + "import atom3.conservation as con\n", + "import atom3.neighbors as nb\n", + "import atom3.pair as pair\n", + "import atom3.parse as parse\n", + "import dill as pickle\n", + "\n", + "from pathlib import Path\n", + "\n", + "from project.utils.utils import annotate_idr_residues, impute_missing_feature_values, postprocess_pruned_pair, process_raw_file_into_dgl_graphs" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 1. Parse PDB file input to pair-wise features" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pdb_filename = \"../project/datasets/Input/raw/pdb/2g/12gs.pdb1\" # note: an input PDB must be uncompressed (e.g., not in `.gz` archive format) using e.g., `gunzip`\n", + "output_pkl = \"../project/datasets/Input/interim/parsed/2g/12gs.pdb1.pkl\"\n", + "complexes_dill = \"../project/datasets/Input/interim/complexes/complexes.dill\"\n", + "pairs_dir = \"../project/datasets/Input/interim/pairs\"\n", + "pkl_filenames = [output_pkl]\n", + "source_type = \"rcsb\" # note: this default value will likely work for common use cases (i.e., those concerning bound-state PDB protein complex structure inputs)\n", + "neighbor_def = \"non_heavy_res\"\n", + "cutoff = 6 # note: distance threshold (in Angstrom) for classifying inter-chain interactions can be customized here\n", + "unbound = False # note: if `source_type` is set to `rcsb`, this value should likely be `False`\n", + "\n", + "for item in [\n", + " Path(pdb_filename).parent,\n", + " Path(output_pkl).parent,\n", + " Path(complexes_dill).parent,\n", + " pairs_dir,\n", + "]:\n", + " os.makedirs(item, exist_ok=True)\n", + "\n", + "# note: the following replicates the logic within `make_dataset.py` for a single PDB file input\n", + "parse.parse(\n", + " # note: assumes the PDB file input (i.e., `pdb_filename`) is not compressed\n", + " pdb_filename=pdb_filename,\n", + " output_pkl=output_pkl\n", + ")\n", + "complexes = comp.get_complexes(filenames=pkl_filenames, type=source_type)\n", + "comp.write_complexes(complexes=complexes, output_dill=complexes_dill)\n", + "get_neighbors = nb.build_get_neighbors(criteria=neighbor_def, cutoff=cutoff)\n", + "get_pairs = pair.build_get_pairs(\n", + " neighbor_def=neighbor_def,\n", + " type=source_type,\n", + " unbound=unbound,\n", + " nb_fn=get_neighbors,\n", + " full=False\n", + ")\n", + "complexes = comp.read_complexes(input_dill=complexes_dill)\n", + "pair.complex_to_pairs(\n", + " complex=list(complexes['data'].values())[0],\n", + " source_type=source_type,\n", + " get_pairs=get_pairs,\n", + " output_dir=pairs_dir\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2. 
Compute sequence-based features using external tools" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "psaia_dir = \"~/Programs/PSAIA-1.0/bin/linux/psa\" # note: replace this with the path to your local installation of PSAIA\n", + "psaia_config_file = \"../project/datasets/builder/psaia_config_file_dips.txt\" # note: choose `psaia_config_file_dips.txt` according to the `source_type` selected above\n", + "file_list_file = os.path.join(\"../project/datasets/Input/interim/external_feats/\", 'PSAIA', source_type.upper(), 'pdb_list.fls')\n", + "num_cpus = 8\n", + "pkl_filename = \"../project/datasets/Input/interim/parsed/2g/12gs.pdb1.pkl\"\n", + "output_filename = \"../project/datasets/Input/interim/external_feats/parsed/2g/12gs.pdb1.pkl\"\n", + "hhsuite_db = \"~/Data/Databases/pfamA_35.0/pfam\" # note: substitute the path to your local HHsuite3 database here\n", + "num_iter = 2\n", + "msa_only = False\n", + "\n", + "for item in [\n", + " Path(file_list_file).parent,\n", + " Path(output_filename).parent,\n", + "]:\n", + " os.makedirs(item, exist_ok=True)\n", + "\n", + "# note: the following replicates the logic within `generate_psaia_features.py` and `generate_hhsuite_features.py` for a single PDB file input\n", + "with open(file_list_file, 'w') as file:\n", + " file.write(f'{pdb_filename}\\n') # note: references the `pdb_filename` as defined previously\n", + "con.gen_protrusion_index(\n", + " psaia_dir=psaia_dir,\n", + " psaia_config_file=psaia_config_file,\n", + " file_list_file=file_list_file,\n", + ")\n", + "con.map_profile_hmms(\n", + " num_cpus=num_cpus,\n", + " pkl_filename=pkl_filename,\n", + " output_filename=output_filename,\n", + " hhsuite_db=hhsuite_db,\n", + " source_type=source_type,\n", + " num_iter=num_iter,\n", + " msa_only=msa_only,\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3. Compute structure-based features" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from project.utils.utils import __should_keep_postprocessed\n", + "\n", + "\n", + "raw_pdb_dir = \"../project/datasets/Input/raw/pdb\"\n", + "pair_filename = \"../project/datasets/Input/interim/pairs/2g/12gs.pdb1_0.dill\"\n", + "source_type = \"rcsb\"\n", + "external_feats_dir = \"../project/datasets/Input/interim/external_feats/parsed\"\n", + "output_filename = \"../project/datasets/Input/final/raw/2g/12gs.pdb1_0.dill\"\n", + "\n", + "unprocessed_pair, raw_pdb_filenames, should_keep = __should_keep_postprocessed(raw_pdb_dir, pair_filename, source_type)\n", + "if should_keep:\n", + " # note: save `postprocessed_pair` to local storage within `project/datasets/Input/final/raw` for future reference as desired\n", + " postprocessed_pair = postprocess_pruned_pair(\n", + " raw_pdb_filenames=raw_pdb_filenames,\n", + " external_feats_dir=external_feats_dir,\n", + " original_pair=unprocessed_pair,\n", + " source_type=source_type,\n", + " )\n", + " # write into output_filenames if not exist\n", + " os.makedirs(Path(output_filename).parent, exist_ok=True)\n", + " with open(output_filename, 'wb') as f:\n", + " pickle.dump(postprocessed_pair, f)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 4. 
Embed deep learning-based IDR features" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# note: ensures the Docker image for `flDPnn` is available locally before trying to run inference with the model\n", + "!docker pull docker.io/sinaghadermarzi/fldpnn\n", + "\n", + "input_pair_filename = \"../project/datasets/Input/final/raw/2g/12gs.pdb1_0.dill\"\n", + "pickle_filepaths = [input_pair_filename]\n", + "\n", + "annotate_idr_residues(\n", + " pickle_filepaths=pickle_filepaths\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 5. Impute missing feature values (optional)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "input_pair_filename = \"../project/datasets/Input/final/raw/2g/12gs.pdb1_0.dill\"\n", + "output_pair_filename = \"../project/datasets/Input/final/raw/2g/12gs.pdb1_0_imputed.dill\"\n", + "impute_atom_features = False\n", + "advanced_logging = False\n", + "\n", + "impute_missing_feature_values(\n", + " input_pair_filename=input_pair_filename,\n", + " output_pair_filename=output_pair_filename,\n", + " impute_atom_features=impute_atom_features,\n", + " advanced_logging=advanced_logging,\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 6. Convert pair-wise features into graph inputs (optional)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "raw_filepath = \"../project/datasets/Input/final/raw/2g/12gs.pdb1_0_imputed.dill\"\n", + "new_graph_dir = \"../project/datasets/Input/final/processed/2g\"\n", + "processed_filepath = \"../project/datasets/Input/final/processed/2g/12gs.pdb1.pt\"\n", + "edge_dist_cutoff = 15.0\n", + "edge_limit = 5000\n", + "self_loops = True\n", + "\n", + "os.makedirs(new_graph_dir, exist_ok=True)\n", + "\n", + "process_raw_file_into_dgl_graphs(\n", + " raw_filepath=raw_filepath,\n", + " new_graph_dir=new_graph_dir,\n", + " processed_filepath=processed_filepath,\n", + " edge_dist_cutoff=edge_dist_cutoff,\n", + " edge_limit=edge_limit,\n", + " self_loops=self_loops,\n", + ")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "DIPS-Plus", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.16" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/notebooks/feature_generation.py b/notebooks/feature_generation.py new file mode 100644 index 0000000..3d832fc --- /dev/null +++ b/notebooks/feature_generation.py @@ -0,0 +1,181 @@ +# %% [markdown] +# # Feature generation for PDB file inputs + +# %% +import os + +import atom3.complex as comp +import atom3.conservation as con +import atom3.neighbors as nb +import atom3.pair as pair +import atom3.parse as parse +import dill as pickle + +from pathlib import Path + +from project.utils.utils import annotate_idr_residues, impute_missing_feature_values, postprocess_pruned_pair, process_raw_file_into_dgl_graphs + +# %% [markdown] +# ### 1. 
Parse PDB file input to pair-wise features + +# %% +pdb_filename = "project/datasets/Input/raw/pdb/2g/12gs.pdb1" # note: an input PDB must be uncompressed (e.g., not in `.gz` archive format) using e.g., `gunzip` +output_pkl = "project/datasets/Input/interim/parsed/2g/12gs.pdb1.pkl" +complexes_dill = "project/datasets/Input/interim/complexes/complexes.dill" +pairs_dir = "project/datasets/Input/interim/pairs" +pkl_filenames = [output_pkl] +source_type = "rcsb" # note: this default value will likely work for common use cases (i.e., those concerning bound-state PDB protein complex structure inputs) +neighbor_def = "non_heavy_res" +cutoff = 6 # note: distance threshold (in Angstrom) for classifying inter-chain interactions can be customized here +unbound = False # note: if `source_type` is set to `rcsb`, this value should likely be `False` + +for item in [ + Path(pdb_filename).parent, + Path(output_pkl).parent, + Path(complexes_dill).parent, + pairs_dir, +]: + os.makedirs(item, exist_ok=True) + +# note: the following replicates the logic within `make_dataset.py` for a single PDB file input +parse.parse( + # note: assumes the PDB file input (i.e., `pdb_filename`) is not compressed + pdb_filename=pdb_filename, + output_pkl=output_pkl +) +complexes = comp.get_complexes(filenames=pkl_filenames, type=source_type) +comp.write_complexes(complexes=complexes, output_dill=complexes_dill) +get_neighbors = nb.build_get_neighbors(criteria=neighbor_def, cutoff=cutoff) +get_pairs = pair.build_get_pairs( + neighbor_def=neighbor_def, + type=source_type, + unbound=unbound, + nb_fn=get_neighbors, + full=False +) +complexes = comp.read_complexes(input_dill=complexes_dill) +pair.complex_to_pairs( + complex=list(complexes['data'].values())[0], + source_type=source_type, + get_pairs=get_pairs, + output_dir=pairs_dir +) + +# %% [markdown] +# ### 2. Compute sequence-based features using external tools + +# %% +psaia_dir = "~/Programs/PSAIA-1.0/bin/linux/psa" # note: replace this with the path to your local installation of PSAIA +psaia_config_file = "project/datasets/builder/psaia_config_file_dips.txt" # note: choose `psaia_config_file_dips.txt` according to the `source_type` selected above +file_list_file = os.path.join("project/datasets/Input/interim/external_feats/", 'PSAIA', source_type.upper(), 'pdb_list.fls') +num_cpus = 8 +pkl_filename = "project/datasets/Input/interim/parsed/2g/12gs.pdb1.pkl" +output_filename = "project/datasets/Input/interim/external_feats/parsed/2g/12gs.pdb1.pkl" +hhsuite_db = "~/Data/Databases/pfamA_35.0/pfam" # note: substitute the path to your local HHsuite3 database here +num_iter = 2 +msa_only = False + +for item in [ + Path(file_list_file).parent, + Path(output_filename).parent, +]: + os.makedirs(item, exist_ok=True) + +# note: the following replicates the logic within `generate_psaia_features.py` and `generate_hhsuite_features.py` for a single PDB file input +with open(file_list_file, 'w') as file: + file.write(f'{pdb_filename}\n') # note: references the `pdb_filename` as defined previously +con.gen_protrusion_index( + psaia_dir=psaia_dir, + psaia_config_file=psaia_config_file, + file_list_file=file_list_file, +) +con.map_profile_hmms( + num_cpus=num_cpus, + pkl_filename=pkl_filename, + output_filename=output_filename, + hhsuite_db=hhsuite_db, + source_type=source_type, + num_iter=num_iter, + msa_only=msa_only, +) + +# %% [markdown] +# ### 3. 
Compute structure-based features + +# %% +from project.utils.utils import __should_keep_postprocessed + + +raw_pdb_dir = "project/datasets/Input/raw/pdb" +pair_filename = "project/datasets/Input/interim/pairs/2g/12gs.pdb1_0.dill" +source_type = "rcsb" +external_feats_dir = "project/datasets/Input/interim/external_feats/parsed" +output_filename = "project/datasets/Input/final/raw/2g/12gs.pdb1_0.dill" + +unprocessed_pair, raw_pdb_filenames, should_keep = __should_keep_postprocessed(raw_pdb_dir, pair_filename, source_type) +if should_keep: + # note: save `postprocessed_pair` to local storage within `project/datasets/Input/final/raw` for future reference as desired + postprocessed_pair = postprocess_pruned_pair( + raw_pdb_filenames=raw_pdb_filenames, + external_feats_dir=external_feats_dir, + original_pair=unprocessed_pair, + source_type=source_type, + ) + # write into output_filenames if not exist + os.makedirs(Path(output_filename).parent, exist_ok=True) + with open(output_filename, 'wb') as f: + pickle.dump(postprocessed_pair, f) + +# %% [markdown] +# ### 4. Embed deep learning-based IDR features + +# %% +# note: ensures the Docker image for `flDPnn` is available locally before trying to run inference with the model +# !docker pull docker.io/sinaghadermarzi/fldpnn + +input_pair_filename = "project/datasets/Input/final/raw/2g/12gs.pdb1_0.dill" +pickle_filepaths = [input_pair_filename] + +annotate_idr_residues( + pickle_filepaths=pickle_filepaths +) + +# %% [markdown] +# ### 5. Impute missing feature values (optional) + +# %% +input_pair_filename = "project/datasets/Input/final/raw/2g/12gs.pdb1_0.dill" +output_pair_filename = "project/datasets/Input/final/raw/2g/12gs.pdb1_0_imputed.dill" +impute_atom_features = False +advanced_logging = False + +impute_missing_feature_values( + input_pair_filename=input_pair_filename, + output_pair_filename=output_pair_filename, + impute_atom_features=impute_atom_features, + advanced_logging=advanced_logging, +) + +# %% [markdown] +# ### 6. 
Convert pair-wise features into graph inputs (optional) + +# %% +raw_filepath = "project/datasets/Input/final/raw/2g/12gs.pdb1_0_imputed.dill" +new_graph_dir = "project/datasets/Input/final/processed/2g" +processed_filepath = "project/datasets/Input/final/processed/2g/12gs.pdb1.pt" +edge_dist_cutoff = 15.0 +edge_limit = 5000 +self_loops = True + +os.makedirs(new_graph_dir, exist_ok=True) + +process_raw_file_into_dgl_graphs( + raw_filepath=raw_filepath, + new_graph_dir=new_graph_dir, + processed_filepath=processed_filepath, + edge_dist_cutoff=edge_dist_cutoff, + edge_limit=edge_limit, + self_loops=self_loops, +) + + diff --git a/project/datasets/DB5/db5_dgl_data_module.py b/project/datasets/DB5/db5_dgl_data_module.py new file mode 100644 index 0000000..8e5b3e7 --- /dev/null +++ b/project/datasets/DB5/db5_dgl_data_module.py @@ -0,0 +1,52 @@ +from typing import Optional + +from pytorch_lightning import LightningDataModule +from torch.utils.data import DataLoader + +from project.datasets.DB5.db5_dgl_dataset import DB5DGLDataset + + +class DB5DGLDataModule(LightningDataModule): + """Unbound protein complex data module for DGL with PyTorch.""" + + # Dataset partition instantiation + db5_train = None + db5_val = None + db5_test = None + + def __init__(self, data_dir: str, batch_size: int, num_dataloader_workers: int, knn: int, + self_loops: bool, percent_to_use: float, process_complexes: bool, input_indep: bool): + super().__init__() + + self.data_dir = data_dir + self.batch_size = batch_size + self.num_dataloader_workers = num_dataloader_workers + self.knn = knn + self.self_loops = self_loops + self.percent_to_use = percent_to_use # Fraction of DB5 dataset to use + self.process_complexes = process_complexes # Whether to process any unprocessed complexes before training + self.input_indep = input_indep # Whether to use an input-independent pipeline to train the model + + def setup(self, stage: Optional[str] = None): + # Assign training/validation/testing data set for use in DataLoaders - called on every GPU + self.db5_train = DB5DGLDataset(mode='train', raw_dir=self.data_dir, knn=self.knn, self_loops=self.self_loops, + percent_to_use=self.percent_to_use, process_complexes=self.process_complexes, + input_indep=self.input_indep) + self.db5_val = DB5DGLDataset(mode='val', raw_dir=self.data_dir, knn=self.knn, self_loops=self.self_loops, + percent_to_use=self.percent_to_use, process_complexes=self.process_complexes, + input_indep=self.input_indep) + self.db5_test = DB5DGLDataset(mode='test', raw_dir=self.data_dir, knn=self.knn, self_loops=self.self_loops, + percent_to_use=self.percent_to_use, process_complexes=self.process_complexes, + input_indep=self.input_indep) + + def train_dataloader(self) -> DataLoader: + return DataLoader(self.db5_train, batch_size=self.batch_size, shuffle=True, + num_workers=self.num_dataloader_workers, collate_fn=lambda x: x, pin_memory=True) + + def val_dataloader(self) -> DataLoader: + return DataLoader(self.db5_val, batch_size=self.batch_size, shuffle=False, + num_workers=self.num_dataloader_workers, collate_fn=lambda x: x, pin_memory=True) + + def test_dataloader(self) -> DataLoader: + return DataLoader(self.db5_test, batch_size=self.batch_size, shuffle=False, + num_workers=self.num_dataloader_workers, collate_fn=lambda x: x, pin_memory=True) diff --git a/project/datasets/DB5/db5_dgl_dataset.py b/project/datasets/DB5/db5_dgl_dataset.py new file mode 100644 index 0000000..cd5cd4e --- /dev/null +++ b/project/datasets/DB5/db5_dgl_dataset.py @@ -0,0 +1,258 @@ +import os 
+import pickle + +import pandas as pd +from dgl.data import DGLDataset, download, check_sha1 + +from project.utils.training_utils import construct_filenames_frame_txt_filenames, build_filenames_frame_error_message, process_complex_into_dict, zero_out_complex_features + + +class DB5DGLDataset(DGLDataset): + r"""Unbound protein complex dataset for DGL with PyTorch. + + Statistics: + + - Train examples: 140 + - Validation examples: 35 + - Test examples: 55 + - Number of structures per complex: 2 + ---------------------- + - Total examples: 230 + ---------------------- + + Parameters + ---------- + mode: str, optional + Should be one of ['train', 'val', 'test']. Default: 'train'. + raw_dir: str + Raw file directory to download/contains the input data directory. Default: 'final/raw'. + knn: int + How many nearest neighbors to which to connect a given node. Default: 20. + self_loops: bool + Whether to connect a given node to itself. Default: True. + percent_to_use: float + How much of the dataset to load. Default: 1.0. + process_complexes: bool + Whether to process each unprocessed complex as we load in the dataset. Default: True. + input_indep: bool + Whether to zero-out each input node and edge feature for an input-independent baseline. Default: False. + force_reload: bool + Whether to reload the dataset. Default: False. + verbose: bool + Whether to print out progress information. Default: False. + + Notes + ----- + All the samples will be loaded and preprocessed in the memory first. + + Examples + -------- + >>> # Get dataset + >>> train_data = DB5DGLDataset() + >>> val_data = DB5DGLDataset(mode='val') + >>> test_data = DB5DGLDataset(mode='test') + >>> + >>> len(test_data) + 55 + >>> test_data.num_chains + 2 + """ + + def __init__(self, + mode='test', + raw_dir=f'final{os.sep}raw', + knn=20, + self_loops=True, + percent_to_use=1.0, + process_complexes=True, + input_indep=False, + force_reload=False, + verbose=False): + assert mode in ['train', 'val', 'test'] + assert 0.0 < percent_to_use <= 1.0 + self.mode = mode + self.root = raw_dir + self.knn = knn + self.self_loops = self_loops + self.percent_to_use = percent_to_use # How much of the DB5 dataset to use + self.process_complexes = process_complexes # Whether to process any unprocessed complexes before training + self.input_indep = input_indep # Whether to use an input-independent pipeline to train the model + self.final_dir = os.path.join(*self.root.split(os.sep)[:-1]) + self.processed_dir = os.path.join(self.final_dir, 'processed') + + self.filename_sampling = 0.0 < self.percent_to_use < 1.0 + self.base_txt_filename, self.filenames_frame_txt_filename, self.filenames_frame_txt_filepath = \ + construct_filenames_frame_txt_filenames(self.mode, self.percent_to_use, self.filename_sampling, self.root) + + # Try to load the text file containing all DB5 filenames, and alert the user if it is missing or corrupted + filenames_frame_to_be_written = not os.path.exists(self.filenames_frame_txt_filepath) + + # Randomly sample DataFrame of filenames with requested cross validation ratio + if self.filename_sampling: + if filenames_frame_to_be_written: + try: + self.filenames_frame = pd.read_csv( + os.path.join(self.root, self.base_txt_filename + '.txt'), header=None) + except Exception: + raise FileNotFoundError( + build_filenames_frame_error_message('DB5-Plus', 'load', self.filenames_frame_txt_filepath)) + self.filenames_frame = self.filenames_frame.sample(frac=self.percent_to_use).reset_index() + try: + 
self.filenames_frame[0].to_csv(self.filenames_frame_txt_filepath, header=None, index=None) + except Exception: + raise Exception( + build_filenames_frame_error_message('DB5-Plus', 'write', self.filenames_frame_txt_filepath)) + + # Load in existing DataFrame of filenames as requested (or if a sampled DataFrame .txt has already been written) + if not filenames_frame_to_be_written: + try: + self.filenames_frame = pd.read_csv(self.filenames_frame_txt_filepath, header=None) + except Exception: + raise FileNotFoundError( + build_filenames_frame_error_message('DB5-Plus', 'load', self.filenames_frame_txt_filepath)) + + # Process any unprocessed examples prior to using the dataset + self.process() + + super(DB5DGLDataset, self).__init__(name='DB5-Plus', + raw_dir=raw_dir, + force_reload=force_reload, + verbose=verbose) + print(f"Loaded DB5-Plus {mode}-set, source: {self.processed_dir}, length: {len(self)}") + + def download(self): + """Download and extract a pre-packaged version of the raw pairs if 'self.raw_dir' is not already populated.""" + # Path to store the file + gz_file_path = os.path.join(os.path.join(*self.raw_dir.split(os.sep)[:-1]), 'final_raw_db5.tar.gz') + + # Download file + download(self.url, path=gz_file_path) + + # Check SHA-1 + if not check_sha1(gz_file_path, self._sha1_str): + raise UserWarning('File {} is downloaded but the content hash does not match.' + 'The repo may be outdated or download may be incomplete. ' + 'Otherwise you can create an issue for it.'.format(gz_file_path)) + + # Remove existing raw directory to make way for the new archive to be extracted + if os.path.exists(self.raw_dir): + os.removedirs(self.raw_dir) + + # Extract archive to parent directory of `self.raw_dir` + self._extract_gz(gz_file_path, os.path.join(*self.raw_dir.split(os.sep)[:-1])) + + def process(self): + """Process each protein complex into a testing-ready dictionary representing both structures.""" + if self.process_complexes: + # Ensure the directory of processed complexes is already created + os.makedirs(self.processed_dir, exist_ok=True) + # Process each unprocessed protein complex + for (i, raw_path) in self.filenames_frame.iterrows(): + raw_filepath = os.path.join(self.root, f'{os.path.splitext(raw_path[0])[0]}.dill') + processed_filepath = os.path.join(self.processed_dir, f'{os.path.splitext(raw_path[0])[0]}.dill') + if not os.path.exists(processed_filepath): + processed_parent_dir_to_make = os.path.join(self.processed_dir, os.path.split(raw_path[0])[0]) + os.makedirs(processed_parent_dir_to_make, exist_ok=True) + process_complex_into_dict(raw_filepath, processed_filepath, self.knn, self.self_loops, False) + + def has_cache(self): + """Check if each complex is downloaded and available for training, validation, or testing.""" + for (i, raw_path) in self.filenames_frame.iterrows(): + processed_filepath = os.path.join(self.processed_dir, f'{os.path.splitext(raw_path[0])[0]}.dill') + if not os.path.exists(processed_filepath): + print( + f'Unable to load at least one processed DB5 pair. ' + f'Please make sure all processed pairs have been successfully downloaded and are not corrupted.') + raise FileNotFoundError + print('DB5 cache found') # Otherwise, a cache was found! + + def __getitem__(self, idx): + r""" Get feature dictionary by index of complex. 
+ + Parameters + ---------- + idx : int + + Returns + ------- + :class:`dict` + + - ``complex['graph1_node_feats']:`` PyTorch Tensor containing each of the first graph's encoded node features + - ``complex['graph2_node_feats']``: PyTorch Tensor containing each of the second graph's encoded node features + - ``complex['graph1_node_coords']:`` PyTorch Tensor containing each of the first graph's node coordinates + - ``complex['graph2_node_coords']``: PyTorch Tensor containing each of the second graph's node coordinates + - ``complex['graph1_edge_feats']:`` PyTorch Tensor containing each of the first graph's edge features for each node + - ``complex['graph2_edge_feats']:`` PyTorch Tensor containing each of the second graph's edge features for each node + - ``complex['graph1_nbrhd_indices']:`` PyTorch Tensor containing each of the first graph's neighboring node indices + - ``complex['graph2_nbrhd_indices']:`` PyTorch Tensor containing each of the second graph's neighboring node indices + - ``complex['examples']:`` PyTorch Tensor containing the labels for inter-graph node pairs + - ``complex['complex']:`` Python string describing the complex's code and original pdb filename + """ + # Assemble filepath of processed protein complex + complex_filepath = f'{os.path.splitext(self.filenames_frame[0][idx])[0]}.dill' + processed_filepath = os.path.join(self.processed_dir, complex_filepath) + + # Load in processed complex + with open(processed_filepath, 'rb') as f: + processed_complex = pickle.load(f) + processed_complex['filepath'] = complex_filepath # Add filepath to each complex dictionary + + # Optionally zero-out input data for an input-independent pipeline (per Karpathy's suggestion) + if self.input_indep: + processed_complex = zero_out_complex_features(processed_complex) + + # Manually filter for desired node and edge features + # n_feat_idx_1, n_feat_idx_2 = 43, 85 # HSAAC + # processed_complex['graph1'].ndata['f'] = processed_complex['graph1'].ndata['f'][:, n_feat_idx_1: n_feat_idx_2] + # processed_complex['graph2'].ndata['f'] = processed_complex['graph2'].ndata['f'][:, n_feat_idx_1: n_feat_idx_2] + + # g1_rsa = processed_complex['graph1'].ndata['f'][:, 35: 36].reshape(-1, 1) # RSA + # g1_psaia = processed_complex['graph1'].ndata['f'][:, 37: 43] # PSAIA + # g1_hsaac = processed_complex['graph1'].ndata['f'][:, 43: 85] # HSAAC + # processed_complex['graph1'].ndata['f'] = torch.cat((g1_rsa, g1_psaia, g1_hsaac), dim=1) + # + # g2_rsa = processed_complex['graph2'].ndata['f'][:, 35: 36].reshape(-1, 1) # RSA + # g2_psaia = processed_complex['graph2'].ndata['f'][:, 37: 43] # PSAIA + # g2_hsaac = processed_complex['graph2'].ndata['f'][:, 43: 85] # HSAAC + # processed_complex['graph2'].ndata['f'] = torch.cat((g2_rsa, g2_psaia, g2_hsaac), dim=1) + + # processed_complex['graph1'].edata['f'] = processed_complex['graph1'].edata['f'][:, 1].reshape(-1, 1) + # processed_complex['graph2'].edata['f'] = processed_complex['graph2'].edata['f'][:, 1].reshape(-1, 1) + + # Return requested complex to DataLoader + return processed_complex + + def __len__(self) -> int: + r"""Number of graph batches in the dataset.""" + return len(self.filenames_frame) + + @property + def num_chains(self) -> int: + """Number of protein chains in each complex.""" + return 2 + + @property + def num_classes(self) -> int: + """Number of classes for each pair of inter-protein residues.""" + return 2 + + @property + def num_node_features(self) -> int: + """Number of node feature values after encoding them.""" + return 107 + + @property + def 
num_edge_features(self) -> int: + """Number of edge feature values after encoding them.""" + return 3 + + @property + def raw_path(self) -> str: + """Directory in which to locate raw pairs.""" + return self.raw_dir + + @property + def url(self) -> str: + """URL with which to download TAR archive of preprocessed pairs.""" + # TODO: Update URL + return 'https://zenodo.org/record/4815267/files/final_raw_db5.tar.gz?download=1' diff --git a/project/datasets/builder/annotate_idr_residues.py b/project/datasets/builder/annotate_idr_residues.py new file mode 100644 index 0000000..3c2576c --- /dev/null +++ b/project/datasets/builder/annotate_idr_residues.py @@ -0,0 +1,51 @@ +import click +import logging +import multiprocessing +import os + +from pathlib import Path + +from project.utils.utils import annotate_idr_residues + + +@click.command() +@click.argument('raw_data_dir', default='../DIPS/final/raw', type=click.Path(exists=True)) +@click.option('--num_cpus', '-c', default=1) +def main(raw_data_dir: str, num_cpus: int): + # Collect paths of files to analyze + raw_data_dir = Path(raw_data_dir) + raw_data_pickle_filepaths = [] + for root, dirs, files in os.walk(raw_data_dir): + for dir in dirs: + for subroot, subdirs, subfiles in os.walk(raw_data_dir / dir): + for file in subfiles: + if file.endswith('.dill'): + raw_data_pickle_filepaths.append(raw_data_dir / dir / file) + + # Annotate whether each residue resides in an IDR, using multiprocessing # + # Define the number of processes to use + num_processes = min(num_cpus, multiprocessing.cpu_count()) + + # Split the list of file paths into chunks + chunk_size = len(raw_data_pickle_filepaths) // num_processes + file_path_chunks = [ + raw_data_pickle_filepaths[i:i+chunk_size] + for i in range(0, len(raw_data_pickle_filepaths), chunk_size) + ] + + # Create a pool of worker processes + pool = multiprocessing.Pool(processes=num_processes) + + # Process each chunk of file paths in parallel + pool.map(annotate_idr_residues, file_path_chunks) + + # Close the pool and wait for all processes to finish + pool.close() + pool.join() + + +if __name__ == "__main__": + log_fmt = '%(asctime)s %(levelname)s %(process)d: %(message)s' + logging.basicConfig(level=logging.INFO, format=log_fmt) + + main() diff --git a/project/datasets/builder/create_hdf5_dataset.py b/project/datasets/builder/create_hdf5_dataset.py new file mode 100644 index 0000000..b367072 --- /dev/null +++ b/project/datasets/builder/create_hdf5_dataset.py @@ -0,0 +1,42 @@ +import click +import logging +import os + +from pathlib import Path +from tqdm import tqdm + +from project.utils.utils import convert_pair_pickle_to_hdf5 + + +@click.command() +@click.argument('raw_data_dir', default='../DIPS/final/raw', type=click.Path(exists=True)) +def main(raw_data_dir: str): + raw_data_dir = Path(raw_data_dir) + raw_data_pickle_filepaths = [] + for root, dirs, files in os.walk(raw_data_dir): + for dir in dirs: + for subroot, subdirs, subfiles in os.walk(raw_data_dir / dir): + for file in subfiles: + if file.endswith('.dill'): + raw_data_pickle_filepaths.append(raw_data_dir / dir / file) + for pickle_filepath in tqdm(raw_data_pickle_filepaths): + convert_pair_pickle_to_hdf5( + pickle_filepath=pickle_filepath, + hdf5_filepath=Path(pickle_filepath).with_suffix(".hdf5") + ) + # filepath = Path("project/datasets/DIPS/final/raw/0g/10gs.pdb1_0.dill") + # pickle_example = convert_pair_hdf5_to_pickle( + # hdf5_filepath=Path(filepath).with_suffix(".hdf5") + # ) + # hdf5_file_example = convert_pair_hdf5_to_hdf5_file( + # 
hdf5_filepath=Path(filepath).with_suffix(".hdf5") + # ) + # print(f"pickle_example: {pickle_example}") + # print(f"hdf5_file_example: {hdf5_file_example}") + + +if __name__ == "__main__": + log_fmt = '%(asctime)s %(levelname)s %(process)d: %(message)s' + logging.basicConfig(level=logging.INFO, format=log_fmt) + + main() diff --git a/project/datasets/builder/generate_hhsuite_features.py b/project/datasets/builder/generate_hhsuite_features.py index 2895a01..64a5f60 100644 --- a/project/datasets/builder/generate_hhsuite_features.py +++ b/project/datasets/builder/generate_hhsuite_features.py @@ -2,6 +2,7 @@ Source code (MIT-Licensed) inspired by Atom3 (https://github.com/drorlab/atom3 & https://github.com/amorehead/atom3) """ +import glob import logging import os from os import cpu_count @@ -24,15 +25,18 @@ @click.option('--num_cpus_per_job', '-c', default=2, type=int) @click.option('--num_iter', '-i', default=2, type=int) @click.option('--source_type', default='rcsb', type=click.Choice(['rcsb', 'db5', 'evcoupling', 'casp_capri'])) +@click.option('--generate_hmm_profile/--generate_msa_only', '-p', default=True) @click.option('--write_file/--read_file', '-w', default=True) def main(pkl_dataset: str, pruned_dataset: str, hhsuite_db: str, output_dir: str, rank: int, - size: int, num_cpu_jobs: int, num_cpus_per_job: int, num_iter: int, source_type: str, write_file: bool): + size: int, num_cpu_jobs: int, num_cpus_per_job: int, num_iter: int, source_type: str, + generate_hmm_profile: bool, write_file: bool): """Run external programs for feature generation to turn raw PDB files from (../raw) into sequence or structure-based residue features (saved in ../interim/external_feats by default).""" logger = logging.getLogger(__name__) logger.info(f'Generating external features from PDB files in {pkl_dataset}') # Reestablish global rank rank = get_global_node_rank(rank, size) + logger.info(f"Assigned global rank {rank} of world size {size}") # Determine true rank and size for a given node bfd_copy_ids = ["_1", "_2", "_3", "_4", "_5", "_6", "_7", "_8", @@ -40,8 +44,9 @@ def main(pkl_dataset: str, pruned_dataset: str, hhsuite_db: str, output_dir: str bfd_copy_id = bfd_copy_ids[rank] # Assemble true ID of the BFD copy to use for generating profile HMMs - hhsuite_db = os.path.join(hhsuite_db + bfd_copy_id, - 'bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt') + hhsuite_dbs = glob.glob(os.path.join(hhsuite_db + bfd_copy_id, "*bfd*")) + assert len(hhsuite_dbs) == 1, "Only a single BFD database must be present in the given database directory." + hhsuite_db = hhsuite_dbs[0] logger.info(f'Starting HH-suite for node {rank + 1} out of a global MPI world of size {size},' f' with a local MPI world size of {MPI.COMM_WORLD.Get_size()}.' 
f' This node\'s copy of the BFD is {hhsuite_db}') @@ -50,7 +55,7 @@ def main(pkl_dataset: str, pruned_dataset: str, hhsuite_db: str, output_dir: str # Run with --write_file=True using one node # Then run with --read_file=True using multiple nodes to distribute workload across nodes and their CPU cores map_all_profile_hmms(pkl_dataset, pruned_dataset, output_dir, hhsuite_db, num_cpu_jobs, - num_cpus_per_job, source_type, num_iter, rank, size, write_file) + num_cpus_per_job, source_type, num_iter, not generate_hmm_profile, rank, size, write_file) if __name__ == '__main__': diff --git a/project/utils/constants.py b/project/utils/constants.py index c3fd5f6..83afa1f 100644 --- a/project/utils/constants.py +++ b/project/utils/constants.py @@ -72,7 +72,7 @@ # Features to be one-hot encoded during graph processing and what their values could be FEAT_COLS = [ - # 'resname', # By default, leave out one-hot encoding of residues' type to decrease feature redundancy + 'resname', # By default, leave out one-hot encoding of residues' type to decrease feature redundancy 'ss_value', 'rsa_value', 'rd_value' @@ -88,8 +88,8 @@ ALLOWABLE_FEATS = [ # By default, leave out one-hot encoding of residues' type to decrease feature redundancy - # ["TRP", "PHE", "LYS", "PRO", "ASP", "ALA", "ARG", "CYS", "VAL", "THR", - # "GLY", "SER", "HIS", "LEU", "GLU", "TYR", "ILE", "ASN", "MET", "GLN"], + ["TRP", "PHE", "LYS", "PRO", "ASP", "ALA", "ARG", "CYS", "VAL", "THR", + "GLY", "SER", "HIS", "LEU", "GLU", "TYR", "ILE", "ASN", "MET", "GLN"], ['H', 'B', 'E', 'G', 'I', 'T', 'S', '-'], # Populated 1D list means restrict column feature values by list values [], # Empty list means take scalar value as is [], diff --git a/project/utils/modules.py b/project/utils/modules.py new file mode 100644 index 0000000..f0d2bca --- /dev/null +++ b/project/utils/modules.py @@ -0,0 +1,554 @@ +# ------------------------------------------------------------------------------------------------------------------------------------- +# Following code adapted from NeiA-PyTorch (https://github.com/amorehead/NeiA-PyTorch): +# ------------------------------------------------------------------------------------------------------------------------------------- + +from argparse import ArgumentParser + +import numpy as np +import pytorch_lightning as pl +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchmetrics as tm +from torch.optim import Adam + +from project.utils.training_utils import construct_interact_tensor + + +# ------------------ +# PyTorch Modules +# ------------------ +class NeiAGraphConv(nn.Module): + """A neighborhood-averaging graph neural network layer as a PyTorch module. + + NeiAGraphConv stands for a Graph Convolution neighborhood-averaging layer. It is the + equivalent of a linear layer in an MLP, a conv layer in a CNN, or a graph conv layer in a GCN. + """ + + def __init__( + self, + num_node_feats: int, + num_edge_feats: int, + nbrhd_size: int, + activ_fn=nn.Tanh(), + dropout=0.1, + **kwargs + ): + """Neighborhood-Averaging Graph Conv Layer + + Parameters + ---------- + num_node_feats : int + Input node feature size. + num_edge_feats : int + Input edge feature size. + nbrhd_size : int + The size of each residue's receptive field for feature updates. + activ_fn : Module + Activation function to apply in MLPs. + dropout : float + How much dropout (forget rate) to apply before activation functions. 
+ """ + super().__init__() + + # Initialize shared layer variables + self.num_node_feats = num_node_feats + self.num_edge_feats = num_edge_feats + self.nbrhd_size = nbrhd_size + self.activ_fn = activ_fn + self.dropout = dropout + + # Define weight matrix for neighboring node feature matrix (i.e. W^N for H_i) + self.W_N = nn.Linear(self.num_node_feats, self.num_node_feats, bias=False) + + # Define weight matrix for neighboring edge feature matrix (i.e. W^E for E_i) + self.W_E = nn.Linear(self.num_edge_feats, self.num_node_feats, bias=False) + + def forward(self, X: torch.Tensor, H: torch.Tensor, E: torch.Tensor, device: torch.device): + """Forward pass of the network + + Parameters + ---------- + X : Tensor + Tensor of node features to update with graph convolutions. + H : Tensor + Tensor of neighboring node features with which to convolve. + E : Tensor + Tensor of neighboring edge features with which to convolve. + device : torch.device + Computation device (e.g. CPU or GPU) on which to collect tensors. + """ + + # Create neighbor node signals + H_sig = self.W_N.weight.matmul(H.transpose(1, 2)) # (n_node_feats, n_nodes) + assert not H_sig.isnan().any() + + # Create neighbor in-edge signals + E_sig = self.W_E.weight.matmul(E.transpose(1, 2)) # (n_nodes, n_node_feats, n_nbrs) + assert not E_sig.isnan().any() + + # Create combined neighbor node + in-edge signals + Z = self.activ_fn(H_sig + E_sig) + assert not Z.isnan().any() + + # Average each learned feature vector in each sub-tensor Z_i corresponding to node i + Z_avg = Z / Z.shape[-1] + Z_sig = Z_avg.matmul(torch.ones(Z.shape[-1], device=device)) + Z_sig = F.dropout(Z_sig, p=self.dropout, training=self.training).squeeze() # Remove "1" leftover from 'q' + assert not Z_sig.isnan().any() + + # Apply a residual node feature update with the learned matrix Z after performing neighborhood averaging + X += Z_sig + assert not X.isnan().any() + + # Update node features of original graph via an updated subgraph + return X + + def __repr__(self): + return f'NeiAGraphConv(structure=h_in{self.num_node_feats}-h_hid{self.num_node_feats}' \ + f'-h_out{self.num_node_feats}-h_e{self.num_edge_feats})' + + +class NeiWAGraphConv(nn.Module): + """A neighborhood weighted-averaging graph neural network layer as a PyTorch module. + + NeiWAGraphConv stands for a Graph Convolution neighborhood weighted-averaging layer. It is the + equivalent of a linear layer in an MLP, a conv layer in a CNN, or a graph conv layer in a GCN. + """ + + def __init__( + self, + num_node_feats: int, + num_edge_feats: int, + nbrhd_size: int, + activ_fn=nn.Tanh(), + dropout=0.3, + **kwargs + ): + """Neighborhood Weighted-Averaging Graph Conv Layer + + Parameters + ---------- + num_node_feats : int + Input node feature size. + num_edge_feats : int + Input edge feature size. + nbrhd_size : int + The size of each residue's receptive field for feature updates. + activ_fn : Module + Activation function to apply in MLPs. + dropout : float + How much dropout (forget rate) to apply before activation functions. + """ + super().__init__() + + # Initialize shared layer variables + self.num_node_feats = num_node_feats + self.num_edge_feats = num_edge_feats + self.nbrhd_size = nbrhd_size + self.activ_fn = activ_fn + self.dropout = dropout + + # Define weight matrix for neighboring node feature matrix (i.e. W^N for H_i) + self.W_N = nn.Linear(self.num_node_feats, self.num_node_feats, bias=False) + + # Define weight matrix for neighboring edge feature matrix (i.e. 
W^E for E_i) + self.W_E = nn.Linear(self.num_edge_feats, self.num_node_feats, bias=False) + + # Define weight vector for neighboring node-edge matrix (i.e. the q in 'a = softmax(Z^T matmul q)') + self.q = nn.Linear(1, self.num_node_feats, bias=False) + + def forward(self, X: torch.Tensor, H: torch.Tensor, E: torch.Tensor, device: torch.device): + """Forward pass of the network + + Parameters + ---------- + X : Tensor + Tensor of node features to update with graph convolutions. + H : Tensor + Tensor of neighboring node features with which to convolve. + E : Tensor + Tensor of neighboring edge features with which to convolve. + device : torch.device + Computation device (e.g. CPU or GPU) on which to collect tensors. + """ + + # Create neighbor node signals + H_sig = self.W_N.weight.matmul(H.transpose(1, 2)) # (n_node_feats, n_nodes) + + # Create neighbor in-edge signals + E_sig = self.W_E.weight.matmul(E.transpose(1, 2)) # (n_nodes, n_node_feats, n_nbrs) + + # Create combined neighbor node + in-edge signals + Z = self.activ_fn(H_sig + E_sig) + + # Calculate weight vector for neighboring node-edge features (i.e. the a in 'a = softmax(Z^T matmul q)') + a = torch.softmax(Z.transpose(1, 2).matmul(self.q.weight), dim=0) # Element-wise softmax each row + + # Average each learned feature vector in each sub-tensor Z_i corresponding to node i + Z_avg = Z / Z.shape[-1] + Z_sig = Z_avg.matmul(a) + Z_sig = F.dropout(Z_sig, p=self.dropout, training=self.training).squeeze() # Remove "1" leftover from 'q' + + # Apply a residual node feature update with the learned matrix Z after performing neighborhood averaging + X += Z_sig.squeeze() # Remove trivial "1" dimension leftover from the vector q's definition + + # Update node features of original graph via an updated subgraph + return X + + def __repr__(self): + return f'NeiWAGraphConv(structure=h_in{self.num_node_feats}-h_hid{self.num_node_feats}' \ + f'-h_out{self.num_node_feats}-h_e{self.num_edge_feats})' + + +# ------------------ +# Lightning Modules +# ------------------ +class LitNeiA(pl.LightningModule): + """A siamese neighborhood-averaging (NeiA) module.""" + + def __init__(self, num_node_input_feats: int, num_edge_input_feats: int, gnn_activ_fn=nn.Tanh(), + interact_activ_fn=nn.ReLU(), num_classes=2, weighted_avg=False, num_gnn_layers=1, + num_interact_layers=3, num_interact_hidden_channels=214, num_epochs=50, pn_ratio=0.1, knn=20, + dropout_rate=0.3, metric_to_track='val_bce', weight_decay=1e-7, batch_size=32, lr=1e-5, + multi_gpu_backend="ddp"): + """Initialize all the parameters for a LitNeiA module.""" + super().__init__() + + # Build the network + self.num_node_input_feats = num_node_input_feats + self.num_edge_input_feats = num_edge_input_feats + self.gnn_activ_fn = gnn_activ_fn + self.interact_activ_fn = interact_activ_fn + self.num_classes = num_classes + self.weighted_avg = weighted_avg + + # GNN module's keyword arguments provided via the command line + self.num_gnn_layers = num_gnn_layers + + # Interaction module's keyword arguments provided via the command line + self.num_interact_layers = num_interact_layers + self.num_interact_hidden_channels = num_interact_hidden_channels + + # Model hyperparameter keyword arguments provided via the command line + self.num_epochs = num_epochs + self.pn_ratio = pn_ratio + self.nbrhd_size = knn + self.dropout_rate = dropout_rate + self.metric_to_track = metric_to_track + self.weight_decay = weight_decay + self.batch_size = batch_size + self.lr = lr + self.multi_gpu_backend = multi_gpu_backend + + # 
Assemble the layers of the network + self.gnn_block = self.build_gnn_block() + self.init_res_block, self.interim_res_blocks, self.final_res_block, self.final_conv_layer = self.build_i_block() + + # Declare loss functions and metrics for training, validation, and testing + self.loss_fn = nn.BCEWithLogitsLoss() + self.test_auroc = tm.AUROC(average='weighted', pos_label=1) + self.test_auprc = tm.AveragePrecision(pos_label=1) + self.test_acc = tm.Accuracy(average='weighted', num_classes=self.num_classes, multiclass=True) + self.test_f1 = tm.F1(average='weighted', num_classes=self.num_classes, multiclass=True) + + # Log hyperparameters + self.save_hyperparameters() + + def build_gnn_block(self): + """Define the layers for all NeiA GNN modules.""" + # Marshal all GNN layers, allowing the user to choose which kind of neighborhood averaging they would like + if self.weighted_avg: + gnn_layer = (NeiWAGraphConv(num_node_feats=self.num_node_input_feats, + num_edge_feats=self.num_edge_input_feats, + nbrhd_size=self.nbrhd_size, + activ_fn=self.gnn_activ_fn, + dropout=self.dropout_rate)) + else: + gnn_layer = (NeiAGraphConv(num_node_feats=self.num_node_input_feats, + num_edge_feats=self.num_edge_input_feats, + nbrhd_size=self.nbrhd_size, + activ_fn=self.gnn_activ_fn, + dropout=self.dropout_rate)) + gnn_layers = [gnn_layer for _ in range(self.num_gnn_layers)] + return nn.ModuleList(gnn_layers) + + def get_res_block(self): + """Retrieve a residual block of a specific type (e.g. ResNet).""" + res_block = nn.ModuleList([ + nn.Conv2d(in_channels=self.num_interact_hidden_channels, + out_channels=self.num_interact_hidden_channels, + kernel_size=(3, 3), + padding=(1, 1)), + self.interact_activ_fn, + nn.Conv2d(in_channels=self.num_interact_hidden_channels, + out_channels=self.num_interact_hidden_channels, + kernel_size=(3, 3), + padding=(1, 1)), + ]) + return res_block + + def build_i_block(self): + """Define the layers of the interaction block for an interaction tensor.""" + # Marshal all interaction layers, beginning with the initial residual block + init_res_block = nn.ModuleList([ + nn.Conv2d(self.num_node_input_feats * 2, + self.num_interact_hidden_channels, + kernel_size=(1, 1), + padding=(0, 0)), + self.interact_activ_fn, + nn.Conv2d(self.num_interact_hidden_channels, + self.num_interact_hidden_channels, + kernel_size=(1, 1), + padding=(0, 0)), + ]) + # Unroll requested number of intermediate residual blocks + interim_res_blocks = [] + for _ in range(self.num_interact_layers - 2): + interim_res_block = self.get_res_block() + interim_res_blocks.append(interim_res_block) + interim_res_blocks = nn.ModuleList(interim_res_blocks) + # Attach final residual block to project channel dimensionality down to original size + final_res_block = nn.ModuleList([ + nn.Conv2d(self.num_interact_hidden_channels, + self.num_interact_hidden_channels, + kernel_size=(1, 1), + padding=(0, 0)), + self.interact_activ_fn, + nn.Conv2d(self.num_interact_hidden_channels, + self.num_node_input_feats * 2, + kernel_size=(1, 1), + padding=(0, 0)) + ]) + # Craft final convolution layer to project channel dimensionality down to 1, the target number of channels + final_conv_layer = nn.Conv2d(self.num_node_input_feats * 2, 1, kernel_size=(1, 1), padding=(0, 0)) + return init_res_block, interim_res_blocks, final_res_block, final_conv_layer + + # --------------------- + # Training + # --------------------- + def gnn_forward(self, node_feats: torch.Tensor, nbrhd_node_feats: torch.Tensor, + nbrhd_edge_feats: torch.Tensor, gnn_layer_id: int): + 
"""Make a forward pass through a single GNN layer.""" + # Convolve over graph nodes and edges + new_node_feats = self.gnn_block[gnn_layer_id](node_feats, nbrhd_node_feats, nbrhd_edge_feats, self.device) + return new_node_feats + + def interact_forward(self, interact_tensor: torch.Tensor): + """Make a forward pass through the interaction blocks.""" + residual, logits = interact_tensor, interact_tensor + # Convolve over the 3D "interaction" tensor given using the initial residual block + for layer in self.init_res_block: + logits = layer(logits) + logits += residual + residual = logits + logits = self.interact_activ_fn(logits) + # Convolve over the 3D "interaction" logits using the interim residual blocks + for interim_res_block in self.interim_res_blocks: + for layer in interim_res_block: + logits = layer(logits) + logits += residual + residual = logits + logits = self.interact_activ_fn(logits) + # Convolve over the 3D "interaction" tensor given using the final residual block + for layer in self.final_res_block: + logits = layer(logits) + logits += residual + logits = self.interact_activ_fn(logits) + # Project number of channels down to target size, 1 + logits = self.final_conv_layer(logits) + return logits + + def forward(self, cmplx: dict, labels: torch.Tensor): + """Make a forward pass through the entire siamese network.""" + # Make a copy of the complex's feature and index tensors to prevent feature overflow between epochs + graph1_node_feats = cmplx['graph1_node_feats'].clone() + graph2_node_feats = cmplx['graph2_node_feats'].clone() + graph1_nbrhd_indices = cmplx['graph1_nbrhd_indices'].clone() + graph2_nbrhd_indices = cmplx['graph2_nbrhd_indices'].clone() + graph1_nbrhd_node_feats = cmplx['graph1_node_feats'].clone() + graph2_nbrhd_node_feats = cmplx['graph2_node_feats'].clone() + graph1_nbrhd_edge_feats = cmplx['graph1_edge_feats'].clone() + graph2_nbrhd_edge_feats = cmplx['graph2_edge_feats'].clone() + + # Replace any leftover NaN values in edge features with zero + if True in torch.isnan(cmplx['graph1_edge_feats']): + graph1_nbrhd_edge_feats = torch.tensor( + np.nan_to_num(graph1_nbrhd_edge_feats.cpu().numpy()), device=self.device + ) + if True in torch.isnan(cmplx['graph2_edge_feats']): + graph2_nbrhd_edge_feats = torch.tensor( + np.nan_to_num(graph2_nbrhd_edge_feats.cpu().numpy()), device=self.device + ) + + # Secure layer-specific copy of node features to restrict each node's receptive field to the current hop + graph1_layer_node_feats = cmplx['graph1_node_feats'].clone() + graph2_layer_node_feats = cmplx['graph2_node_feats'].clone() + # Convolve node features using a specified number of GNN layers + for gnn_layer_id in range(len(self.gnn_block)): + # Update node features in batches of residue-residue pairs in a node-unique manner + unique_examples = max(len(torch.unique(labels[:, 0])), len(torch.unique(labels[:, 1]))) + for i in range(int(unique_examples / self.batch_size)): + index = int(i * self.batch_size) + # Get a batch of unique node IDs + batch = labels[index: index + self.batch_size] + graph1_batch_n_ids, graph2_batch_n_ids = batch[:, 0], batch[:, 1] + g1_nbrhd_indices = graph1_nbrhd_indices[graph1_batch_n_ids].squeeze() + g2_nbrhd_indices = graph2_nbrhd_indices[graph2_batch_n_ids].squeeze() + # Get unique features selected for the batch + g1_node_feats = graph1_node_feats[graph1_batch_n_ids] + g2_node_feats = graph2_node_feats[graph2_batch_n_ids] + g1_nbrhd_node_feats = graph1_nbrhd_node_feats[g1_nbrhd_indices].reshape( + -1, g1_nbrhd_indices.shape[-1], 
graph1_nbrhd_node_feats.shape[-1] + ) + g2_nbrhd_node_feats = graph2_nbrhd_node_feats[g2_nbrhd_indices].reshape( + -1, g2_nbrhd_indices.shape[-1], graph2_nbrhd_node_feats.shape[-1] + ) + g1_nbrhd_edge_feats = graph1_nbrhd_edge_feats[graph1_batch_n_ids] + g2_nbrhd_edge_feats = graph2_nbrhd_edge_feats[graph2_batch_n_ids] + # Forward propagate with weight-shared GNN layers using batch of residues + updated_node_feats1 = self.gnn_forward( + g1_node_feats, g1_nbrhd_node_feats, g1_nbrhd_edge_feats, gnn_layer_id + ) + updated_node_feats2 = self.gnn_forward( + g2_node_feats, g2_nbrhd_node_feats, g2_nbrhd_edge_feats, gnn_layer_id + ) + # Update original node features according to updated node feature batch + graph1_layer_node_feats[graph1_batch_n_ids] = updated_node_feats1 + graph2_layer_node_feats[graph2_batch_n_ids] = updated_node_feats2 + # Update original clone of node features for next hop + graph1_node_feats = graph1_layer_node_feats.clone() + graph2_node_feats = graph2_layer_node_feats.clone() + graph1_nbrhd_node_feats = graph1_layer_node_feats.clone() + graph2_nbrhd_node_feats = graph2_layer_node_feats.clone() + + # Interleave node features from both graphs to achieve the desired interaction tensor + interact_tensor = construct_interact_tensor(graph1_node_feats, graph2_node_feats) + + # Predict residue-residue pair interactions using a convolution block (i.e. series of residual CNN blocks) + logits = self.interact_forward(interact_tensor) + + # Return network prediction + return logits.squeeze() # Remove any trivial dimensions from logits + + def downsample_examples(self, examples: torch.tensor): + """Randomly sample enough negative pairs to achieve requested positive-negative class ratio (via shuffling).""" + examples = examples[torch.randperm(len(examples))] # Randomly shuffle examples (during training) + pos_examples = examples[examples[:, 2] == 1] # Find out how many interacting pairs there are + num_neg_pairs_to_sample = int(len(pos_examples) / self.pn_ratio) # Determine negative sample size + neg_examples = examples[examples[:, 2] == 0][:num_neg_pairs_to_sample] # Sample negative pairs + downsampled_examples = torch.cat((pos_examples, neg_examples)) + return downsampled_examples + + def training_step(self, batch, batch_idx): + """Lightning calls this inside the training loop.""" + # Make a forward pass through the network for a batch of protein complexes + cmplx = batch[0] + examples = cmplx['examples'] + examples = self.downsample_examples(examples) + logits = self(cmplx, examples) + sampled_indices = examples[:, :2][:, 1] + logits.shape[1] * examples[:, :2][:, 0] # 1d_idx = x + width * y + flattened_logits = torch.flatten(logits) + downsampled_logits = flattened_logits[sampled_indices] + + # Down-weight negative pairs to achieve desired PN weight, leaving positive pairs with a weight of one + sample_weights = examples[:, 2].float() + sample_weights[sample_weights == 0] = self.pn_ratio + loss_fn = nn.BCEWithLogitsLoss(weight=sample_weights) # Weight each class separately for a given complex + loss = loss_fn(downsampled_logits, examples[:, 2].float()) # Calculate loss of a single complex + + # Log training step metric(s) + self.log('train_bce', loss) + + return {'loss': loss} + + def validation_step(self, batch, batch_idx): + """Lightning calls this inside the validation loop.""" + # Make a forward pass through the network for a batch of protein complexes + cmplx = batch[0] + examples = cmplx['examples'] + logits = self(cmplx, examples) + sampled_indices = examples[:, :2][:, 1] + 
logits.shape[1] * examples[:, :2][:, 0] # 1d_idx = x + width * y + flattened_logits = torch.flatten(logits) + sampled_logits = flattened_logits[sampled_indices] + + # Calculate the complex loss and metrics + loss = self.loss_fn(sampled_logits, examples[:, 2].float()) # Calculate loss of a single complex + + # Log validation step metric(s) + self.log('val_bce', loss, sync_dist=True) + + return {'loss': loss} + + def test_step(self, batch, batch_idx): + """Lightning calls this inside the testing loop.""" + # Make a forward pass through the network for a batch of protein complexes + cmplx = batch[0] + examples = cmplx['examples'] + logits = self(cmplx, examples) + sampled_indices = examples[:, :2][:, 1] + logits.shape[1] * examples[:, :2][:, 0] # 1d_idx = x + width * y + flattened_logits = torch.flatten(logits) + sampled_logits = flattened_logits[sampled_indices] + + # Make predictions + preds = torch.softmax(sampled_logits, dim=0) + preds_rounded = torch.round(preds) + int_labels = examples[:, 2].int() + + # Calculate the complex loss and metrics + loss = self.loss_fn(sampled_logits, examples[:, 2].float()) # Calculate loss of a single complex + test_acc = self.test_acc(preds_rounded, int_labels) # Calculate Accuracy of a single complex + test_f1 = self.test_f1(preds_rounded, int_labels) # Calculate F1 score of a single complex + test_auroc = self.test_auroc(preds, int_labels) # Calculate AUROC of a complex + test_auprc = self.test_auprc(preds, int_labels) # Calculate AveragePrecision (i.e. AUPRC) of a complex + + # Log test step metric(s) + self.log('test_bce', loss, sync_dist=True) + + return { + 'loss': loss, 'test_acc': test_acc, 'test_f1': test_f1, 'test_auroc': test_auroc, 'test_auprc': test_auprc + } + + def test_epoch_end(self, outputs: pl.utilities.types.EPOCH_OUTPUT): + """Lightning calls this at the end of every test epoch.""" + # Tuplize scores for the current device (e.g. Rank 0) + test_accs = torch.cat([output_dict['test_acc'].unsqueeze(0) for output_dict in outputs]) + test_f1s = torch.cat([output_dict['test_f1'].unsqueeze(0) for output_dict in outputs]) + test_aurocs = torch.cat([output_dict['test_auroc'].unsqueeze(0) for output_dict in outputs]) + test_auprcs = torch.cat([output_dict['test_auprc'].unsqueeze(0) for output_dict in outputs]) + # Concatenate scores over all devices (e.g. Rank 0 | ... 
| Rank N) - Warning: Memory Intensive + test_accs = test_accs if self.multi_gpu_backend in ["dp"] else torch.cat([test_acc for test_acc in self.all_gather(test_accs)]) + test_f1s = test_f1s if self.multi_gpu_backend in ["dp"] else torch.cat([test_f1 for test_f1 in self.all_gather(test_f1s)]) + test_aurocs = test_aurocs if self.multi_gpu_backend in ["dp"] else torch.cat([test_auroc for test_auroc in self.all_gather(test_aurocs)]) + test_auprcs = test_auprcs if self.multi_gpu_backend in ["dp"] else torch.cat([test_auprc for test_auprc in self.all_gather(test_auprcs)]) + + # Reset test TorchMetrics for all devices + self.test_acc.reset() + self.test_f1.reset() + self.test_auroc.reset() + self.test_auprc.reset() + + # When logging only on rank 0, add 'rank_zero_only=True' to avoid deadlocks on synchronization + if self.trainer.is_global_zero: + self.log('med_test_acc', torch.median(test_accs), rank_zero_only=True) # Log MedAccuracy of an epoch + self.log('med_test_f1', torch.median(test_f1s), rank_zero_only=True) # Log MedF1 of an epoch + self.log('med_test_auroc', torch.median(test_aurocs), rank_zero_only=True) # Log MedAUROC of an epoch + self.log('med_test_auprc', torch.median(test_auprcs), rank_zero_only=True) # Log epoch MedAveragePrecision + + # --------------------- + # Training Setup + # --------------------- + def configure_optimizers(self): + """Called to configure the trainer's optimizer(s).""" + optimizer = Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay) + return [optimizer] + + @staticmethod + def add_model_specific_args(parent_parser): + parser = ArgumentParser(parents=[parent_parser], add_help=False) + # ----------------- + # Model arguments + # ----------------- + parser.add_argument('--num_interact_hidden_channels', type=int, default=214, + help='Dimensionality of interaction module filters') + return parser diff --git a/project/utils/training_constants.py b/project/utils/training_constants.py new file mode 100644 index 0000000..957b503 --- /dev/null +++ b/project/utils/training_constants.py @@ -0,0 +1,99 @@ +# ------------------------------------------------------------------------------------------------------------------------------------- +# Following code curated for NeiA-PyTorch (https://github.com/amorehead/NeiA-PyTorch): +# ------------------------------------------------------------------------------------------------------------------------------------- +import numpy as np + +# Dataset-global node count limits to restrict computational learning complexity +ATOM_COUNT_LIMIT = 2048 # Default atom count filter for DIPS-Plus when encoding complexes at an atom-based level +RESIDUE_COUNT_LIMIT = 256 # Default residue count limit for DIPS-Plus (empirically determined for smoother training) +NODE_COUNT_LIMIT = 2304 # An upper-bound on the node count limit for Geometric Transformers - equal to 9-sized batch +KNN = 20 # Default number of nearest neighbors to query for during graph message passing + +# The PDB codes of structures added between DB4 and DB5 (to be used for testing dataset) +DB5_TEST_PDB_CODES = ['3R9A', '4GAM', '3AAA', '4H03', '1EXB', + '2GAF', '2GTP', '3RVW', '3SZK', '4IZ7', + '4GXU', '3BX7', '2YVJ', '3V6Z', '1M27', + '4FQI', '4G6J', '3BIW', '3PC8', '3HI6', + '2X9A', '3HMX', '2W9E', '4G6M', '3LVK', + '1JTD', '3H2V', '4DN4', 'BP57', '3L5W', + '3A4S', 'CP57', '3DAW', '3VLB', '3K75', + '2VXT', '3G6D', '3EO1', '4JCV', '4HX3', + '3F1P', '3AAD', '3EOA', '3MXW', '3L89', + '4M76', 'BAAD', '4FZA', '4LW4', '1RKE', + '3FN1', '3S9D', '3H11', '2A1A', 
'3P57'] + +# Default fill values for missing features +HSAAC_DIM = 42 # We have 2 + (2 * 20) HSAAC values from the two instances of the unknown residue symbol '-' +DEFAULT_MISSING_FEAT_VALUE = np.nan +DEFAULT_MISSING_SS = '-' +DEFAULT_MISSING_RSA = DEFAULT_MISSING_FEAT_VALUE +DEFAULT_MISSING_RD = DEFAULT_MISSING_FEAT_VALUE +DEFAULT_MISSING_PROTRUSION_INDEX = [DEFAULT_MISSING_FEAT_VALUE for _ in range(6)] +DEFAULT_MISSING_HSAAC = [DEFAULT_MISSING_FEAT_VALUE for _ in range(HSAAC_DIM)] +DEFAULT_MISSING_CN = DEFAULT_MISSING_FEAT_VALUE +DEFAULT_MISSING_SEQUENCE_FEATS = np.array([DEFAULT_MISSING_FEAT_VALUE for _ in range(27)]) +DEFAULT_MISSING_NORM_VEC = [DEFAULT_MISSING_FEAT_VALUE for _ in range(3)] + +# Dict for converting three letter codes to one letter codes +D3TO1 = {'CYS': 'C', 'ASP': 'D', 'SER': 'S', 'GLN': 'Q', 'LYS': 'K', + 'ILE': 'I', 'PRO': 'P', 'THR': 'T', 'PHE': 'F', 'ASN': 'N', + 'GLY': 'G', 'HIS': 'H', 'LEU': 'L', 'ARG': 'R', 'TRP': 'W', + 'ALA': 'A', 'VAL': 'V', 'GLU': 'E', 'TYR': 'Y', 'MET': 'M'} + +# PSAIA features to encode as DataFrame columns +PSAIA_COLUMNS = ['avg_cx', 's_avg_cx', 's_ch_avg_cx', 's_ch_s_avg_cx', 'max_cx', 'min_cx'] + +# Features to be one-hot encoded during graph processing and what their values could be +FEAT_COLS = [ + 'resname', # [7:27] + 'ss_value', # [27:35] + 'rsa_value', # [35:36] + 'rd_value' # [36:37] +] +FEAT_COLS.extend( + PSAIA_COLUMNS + # [37:43] + ['hsaac', # [43:85] + 'cn_value', # [85:86] + 'sequence_feats', # [86:113] + 'amide_norm_vec', # [Stored separately] + # 'element' # For atom-level learning only + ]) + +ALLOWABLE_FEATS = [ + ["TRP", "PHE", "LYS", "PRO", "ASP", "ALA", "ARG", "CYS", "VAL", "THR", + "GLY", "SER", "HIS", "LEU", "GLU", "TYR", "ILE", "ASN", "MET", "GLN"], + ['H', 'B', 'E', 'G', 'I', 'T', 'S', '-'], # Populated 1D list means restrict column feature values by list values + [], # Empty list means take scalar value as is + [], + [], + [], + [], + [], + [], + [], + [[]], # Doubly-nested, empty list means take first-level nested list as is + [], + [[]], + [[]], + # ['C', 'O', 'N', 'S'] # For atom-level learning only +] + +# A schematic of which tensor indices correspond to which node and edge features +FEATURE_INDICES = { + # Node feature indices + 'node_pos_enc': 0, + 'node_geo_feats_start': 1, + 'node_geo_feats_end': 7, + 'node_dips_plus_feats_start': 7, + 'node_dips_plus_feats_end': 113, + # Edge feature indices + 'edge_pos_enc': 0, + 'edge_weights': 1, + 'edge_dist_feats_start': 2, + 'edge_dist_feats_end': 20, + 'edge_dir_feats_start': 20, + 'edge_dir_feats_end': 23, + 'edge_orient_feats_start': 23, + 'edge_orient_feats_end': 27, + 'edge_amide_angles': 27 +} diff --git a/project/utils/training_utils.py b/project/utils/training_utils.py new file mode 100644 index 0000000..b684eee --- /dev/null +++ b/project/utils/training_utils.py @@ -0,0 +1,463 @@ +import logging +import os +import pickle +from argparse import ArgumentParser +from typing import List + +import dgl +import numpy as np +import pandas as pd +import pytorch_lightning as pl +import torch +import torch.nn.functional as F +from dgl.nn import pairwise_squared_distance +from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger + +from project.utils.constants import D3TO1, FEAT_COLS, ALLOWABLE_FEATS, DEFAULT_MISSING_HSAAC, HSAAC_DIM + + +# ------------------------------------------------------------------------------------------------------------------------------------- +# Following code adapted from Atom3D 
(https://github.com/drorlab/atom3d/blob/master/benchmarking/pytorch_geometric/ppi_dataloader.py): +# ------------------------------------------------------------------------------------------------------------------------------------- +def prot_df_to_dgl_graph_feats(df: pd.DataFrame, feat_cols: List, allowable_feats: List[List], knn: int): + r"""Convert protein in dataframe representation to a graph compatible with DGL, where each node is a residue. + + :param df: Protein structure in dataframe format. + :type df: pandas.DataFrame + :param feat_cols: Columns of dataframe in which to find node feature values. For example, for residues use ``feat_cols=["element", ...]`` and for residues use ``feat_cols=["resname", ...], or both!`` + :type feat_cols: list[list[Any]] + :param allowable_feats: List of lists containing all possible values of node type, to be converted into 1-hot node features. + Any elements in ``feat_col`` that are not found in ``allowable_feats`` will be added to an appended "unknown" bin (see :func:`atom3d.util.graph.one_of_k_encoding_unk`). + :param knn: Maximum number of nearest neighbors (i.e. edges) to allow for a given node. + :type knn: int + + :return: tuple containing + - knn_graph (dgl.DGLGraph): K-nearest neighbor graph for the structure DataFrame given. + + - pairwise_dists (torch.FloatTensor): Pairwise squared distances for the K-nearest neighbor graph's coordinates. + + - node_coords (torch.FloatTensor): Cartesian coordinates of each node. + + - node_feats (torch.FloatTensor): Features for each node, one-hot encoded by values in ``allowable_feats``. + :rtype: Tuple + """ + # Exit early if feat_cols or allowable_feats do not align in dimensionality + if len(feat_cols) != len(allowable_feats): + raise Exception('feat_cols does not match the length of allowable_feats') + + # Aggregate structure-based node features + node_feats = torch.FloatTensor([]) + for i in range(len(feat_cols)): + # Search through embedded 2D list for allowable values + feat_vecs = [one_of_k_encoding_unk(feat, allowable_feats[i], feat_cols[i]) for feat in df[feat_cols[i]]] + one_hot_feat_vecs = torch.FloatTensor(feat_vecs) + node_feats = torch.cat((node_feats, one_hot_feat_vecs), 1) + + # Organize residue coordinates into a FloatTensor + node_coords = torch.tensor(df[['x', 'y', 'z']].values, dtype=torch.float32) + + # Define edges - KNN argument determines whether a residue-residue edge gets created in the resulting graph + knn_graph = dgl.knn_graph(node_coords, knn) + pairwise_dists = torch.topk(pairwise_squared_distance(node_coords), knn, 1, largest=False).values + + return knn_graph, pairwise_dists, node_coords, node_feats + + +def one_of_k_encoding_unk(feat, allowable_set, feat_col): + """Converts input to 1-hot encoding given a set of (or sets of) allowable values. Additionally maps inputs not in the allowable set to the last element.""" + if len(allowable_set) == 0: # e.g. RSA values + return [feat] + elif len(allowable_set) == 1 and type(allowable_set[0]) == list and len(allowable_set[0]) == 0: # e.g. HSAAC values + if len(feat) == 0: + return DEFAULT_MISSING_HSAAC if feat_col == 'hsaac' else [] # Else means skip encoding amide_norm_vec + if feat_col == 'hsaac' and len(feat) > HSAAC_DIM: # Handle for edge case from postprocessing + return np.array(DEFAULT_MISSING_HSAAC) + return feat if feat_col == 'hsaac' or feat_col == 'sequence_feats' else [] # Else means skip encoding amide_norm_vec as a node feature + else: # e.g. 
Residue element type values + if feat not in allowable_set: + feat = allowable_set[-1] + return list(map(lambda s: feat == s, allowable_set)) + + +# ------------------------------------------------------------------------------------------------------------------------------------- +# Following code curated for NeiA-PyTorch (https://github.com/amorehead/NeiA-PyTorch): +# ------------------------------------------------------------------------------------------------------------------------------------- +def construct_filenames_frame_txt_filenames(mode: str, percent_to_use: float, filename_sampling: bool, root: str): + """Build the file path of the requested filename DataFrame text file.""" + base_txt_filename = f'pairs-postprocessed' if mode == 'full' else f'pairs-postprocessed-{mode}' + filenames_frame_txt_filename = base_txt_filename + f'-{int(percent_to_use * 100)}%-sampled.txt' \ + if filename_sampling else base_txt_filename + '.txt' + filenames_frame_txt_filepath = os.path.join(root, filenames_frame_txt_filename) + return base_txt_filename, filenames_frame_txt_filename, filenames_frame_txt_filepath + + +def build_filenames_frame_error_message(dataset: str, task: str, filenames_frame_txt_filepath: str): + """Assemble the standard error message for a corrupt or missing filenames DataFrame text file.""" + return f'Unable to {task} {dataset} filenames text file' \ + f' (i.e. {filenames_frame_txt_filepath}).' \ + f' Please make sure it is downloaded and not corrupted.' + + +def min_max_normalize_tensor(tensor: torch.Tensor, device=None): + """Normalize provided tensor to have values be in range [0, 1].""" + min_value = min(tensor) + max_value = max(tensor) + tensor = torch.tensor([(value - min_value) / (max_value - min_value) for value in tensor], device=device) + return tensor + + +def convert_df_to_dgl_graph(df: pd.DataFrame, input_file: str, knn: int, self_loops: bool) -> dgl.DGLGraph: + r""" Transform a given DataFrame of residues into a corresponding DGL graph. + + Parameters + ---------- + df : pandas.DataFrame + input_file : str + knn : int + self_loops : bool + + Returns + ------- + :class:`dgl.DGLGraph` + + Graph structure, feature tensors for each node and edge. + +... node_feats = graph.ndata['f'] +... node_coords = graph.ndata['x'] +... edge_weights = graph.edata['w'] +... residue_residue_angles = graph.edata['a'] + + - ``ndata['f']``: feature tensors of the nodes + - ``ndata['x']:`` Cartesian coordinate tensors of the nodes + - ``ndata['f']``: feature tensors of the edges + """ + # Derive node features, with edges being defined via a k-nearest neighbors approach and a maximum distance threshold + struct_df = df[df['atom_name'] == 'CA'] + graph, _, node_coords, node_feats = prot_df_to_dgl_graph_feats( + struct_df, # Only use CA atoms when constructing the initial graph + FEAT_COLS, + ALLOWABLE_FEATS, + knn + ) + + # Retrieve src and destination node IDs + srcs = graph.edges()[0] + dsts = graph.edges()[1] + + # Remove self-loops (if requested) + if not self_loops: + graph = dgl.remove_self_loop(graph) + srcs = graph.edges()[0] + dsts = graph.edges()[1] + + # Manually add isolated nodes (i.e. 
those with no connected edges) to the graph + if len(node_feats) > graph.number_of_nodes(): + num_of_isolated_nodes = len(node_feats) - graph.number_of_nodes() + raise Exception(f'{num_of_isolated_nodes} isolated node(s) detected in {input_file}') + + """Encode node features and labels in graph""" + # Positional encoding for each node (used for Transformer-like GNNs) + graph.ndata['f'] = min_max_normalize_tensor(graph.nodes()).reshape(-1, 1) # [num_res_in_struct_df, 1] + # One-hot features for each residue + graph.ndata['f'] = torch.cat((graph.ndata['f'], node_feats), dim=1) # [num_res_in_struct_df, num_node_feats] + # Cartesian coordinates for each residue + graph.ndata['x'] = node_coords # [num_res_in_struct_df, 3] + + """Encode edge features and labels in graph""" + # Positional encoding for each edge (used for sequentially-ordered inputs like proteins) + graph.edata['f'] = torch.sin((graph.edges()[0] - graph.edges()[1]).float()).reshape(-1, 1) # [num_edges, 1] + # Normalized edge weights (according to Euclidean distance) + edge_weights = min_max_normalize_tensor(torch.sum(node_coords[srcs] - node_coords[dsts] ** 2, 1)).reshape(-1, 1) + graph.edata['f'] = torch.cat((graph.edata['f'], edge_weights), dim=1) # [num_edges, 1] + + # Angle between the two amide normal vectors for a pair of residues, for all edge-connected residue pairs + plane1 = struct_df[['amide_norm_vec']].iloc[dsts] + plane2 = struct_df[['amide_norm_vec']].iloc[srcs] + plane1.columns = ['amide_norm_vec'] + plane2.columns = ['amide_norm_vec'] + plane1 = torch.from_numpy(np.stack(plane1['amide_norm_vec'].values).astype('float32')) + plane2 = torch.from_numpy(np.stack(plane2['amide_norm_vec'].values).astype('float32')) + angles = np.array([ + torch.acos(torch.dot(vec1, vec2) / (torch.linalg.norm(vec1) * torch.linalg.norm(vec2))) + for vec1, vec2 in zip(plane1, plane2) + ]) + # Ensure amide plane normal vector angles on each edge are zeroed out rather than being left as NaN (in some cases) + np.nan_to_num(angles, copy=False, nan=0.0, posinf=None, neginf=None) + amide_angles = torch.from_numpy(np.nan_to_num( + min_max_normalize_tensor(torch.from_numpy(angles)).cpu().numpy(), + copy=True, nan=0.0, posinf=None, neginf=None + )).reshape(-1, 1) # [num_edges, 1] + graph.edata['f'] = torch.cat((graph.edata['f'], amide_angles), dim=1) # Amide-amide angles: [num_edges, 1] + + return graph + + +def build_complex_labels(bound_complex: any, df0: pd.DataFrame, df1: pd.DataFrame, + df0_index_to_node_id: dict, df1_index_to_node_id: dict, shuffle: bool): + """ Construct the labels matrix for a given protein complex and mode (e.g. train, val, or test).""" + # Get Cartesian product of CA atom row indices for both structures, making an array copy for future view calls + index_pairs = np.transpose([ + np.tile(df0.index.values, len(df1.index.values)), np.repeat(df1.index.values, len(df0.index.values)) + ]).copy() + + # Get an array copy of pos_idx and neg_idx for (row) view calls + pos_idx = bound_complex.pos_idx.copy() + + # Derive inter-protein node-node (i.e. 
residue-residue) interaction array (Interacting = 1, 0 otherwise) + pos_labels = np.hstack((pos_idx, np.ones((len(pos_idx), 1), dtype=np.int64))) + labels = pos_labels + + # Find residue-residue pairs not already included in pos_idx + index_pair_rows = index_pairs.view([('', index_pairs.dtype)] * index_pairs.shape[1]) + pos_idx_rows = pos_idx.view([('', pos_idx.dtype)] * pos_idx.shape[1]) + unique_index_pairs = np.setdiff1d(index_pair_rows, pos_idx_rows) \ + .view(index_pairs.dtype).reshape(-1, index_pairs.shape[1]).copy() + + new_labels = np.hstack( + (unique_index_pairs, np.zeros((len(unique_index_pairs), 1), dtype=np.int64))) + + # Derive inter-protein node-node (i.e. residue-residue) interaction matrix (Interacting = 1, 0 otherwise) + labels = np.concatenate((labels, new_labels)) + + # Shuffle rows corresponding to residue-residue pairs + if shuffle: + np.random.shuffle(labels) + + # Map DataFrame indices to graph node IDs for each structure + labels[:, 0] = np.vectorize(df0_index_to_node_id.get)(labels[:, 0]) + labels[:, 1] = np.vectorize(df1_index_to_node_id.get)(labels[:, 1]) + + # Return new labels matrix + return torch.from_numpy(labels) + + +def process_complex_into_dict(raw_filepath: str, processed_filepath: str, + knn: int, self_loops: bool, check_sequence: bool): + """Process protein complex into a dictionary representing both structures and ready for a given mode (e.g. val).""" + # Retrieve specified complex + bound_complex = pd.read_pickle(raw_filepath) + + # Isolate CA atoms in each structure's DataFrame + df0 = bound_complex.df0[bound_complex.df0['atom_name'] == 'CA'] + df1 = bound_complex.df1[bound_complex.df1['atom_name'] == 'CA'] + + # Ensure that the sequence of each DataFrame's residues matches its original FASTA sequence, character-by-character + if check_sequence: + df0_sequence = bound_complex.sequences['l_b'] + for i, (df_res_name, orig_res) in enumerate(zip(df0['resname'].values, df0_sequence)): + if D3TO1[df_res_name] != orig_res: + raise Exception(f'DataFrame 0 residue sequence does not match original FASTA sequence at position {i}') + df1_sequence = bound_complex.sequences['r_b'] + for i, (df_res_name, orig_res) in enumerate(zip(df1['resname'].values, df1_sequence)): + if D3TO1[df_res_name] != orig_res: + raise Exception(f'DataFrame 1 residue sequence does not match original FASTA sequence at position {i}') + + # Convert each DataFrame into its DGLGraph representation, using all atoms to generate geometric features + all_atom_df0, all_atom_df1 = bound_complex.df0, bound_complex.df1 + graph1 = convert_df_to_dgl_graph(all_atom_df0, raw_filepath, knn, self_loops) + graph2 = convert_df_to_dgl_graph(all_atom_df1, raw_filepath, knn, self_loops) + + # Assemble the examples (containing labels) for the complex + df0_index_to_node_id = {df0_index: idx for idx, df0_index in enumerate(df0.index.values)} + df1_index_to_node_id = {df1_index: idx for idx, df1_index in enumerate(df1.index.values)} + examples = build_complex_labels(bound_complex, df0, df1, df0_index_to_node_id, df1_index_to_node_id, shuffle=False) + + # Use pure PyTorch tensors to represent a given complex - Assemble tensors for storage in complex's dictionary + graph1_node_feats = graph1.ndata['f'] # (n_nodes, n_node_feats) + graph2_node_feats = graph2.ndata['f'] + + graph1_node_coords = graph1.ndata['x'] # (n_nodes, 3) + graph2_node_coords = graph2.ndata['x'] + + # Collect the neighboring node and in-edge features for each of the first graph's nodes (in a consistent order) + graph1_edge_feats = [] + 
graph1_nbrhd_indices = [] + for h_i in graph1.nodes(): + in_edge_ids_for_h_i = graph1.in_edges(h_i) + in_edges_for_h_i = graph1.edges[in_edge_ids_for_h_i] + graph1_edge_feats.append(in_edges_for_h_i.data['f']) + dst_node_ids_for_h_i = in_edge_ids_for_h_i[0].reshape(-1, 1) + graph1_nbrhd_indices.append(dst_node_ids_for_h_i) + graph1_edge_feats = torch.stack(graph1_edge_feats) # (n_nodes, nbrhd_size, n_edge_feats) + graph1_nbrhd_indices = torch.stack(graph1_nbrhd_indices) # (n_nodes, nbrhd_size, 1) + + # Collect the neighboring node and in-edge features for each of the second graph's nodes (in a consistent order) + graph2_edge_feats = [] + graph2_nbrhd_indices = [] + for h_i in graph2.nodes(): + in_edge_ids_for_h_i = graph2.in_edges(h_i) + in_edges_for_h_i = graph2.edges[in_edge_ids_for_h_i] + graph2_edge_feats.append(in_edges_for_h_i.data['f']) + dst_node_ids_for_h_i = in_edge_ids_for_h_i[0].reshape(-1, 1) + graph2_nbrhd_indices.append(dst_node_ids_for_h_i) + graph2_edge_feats = torch.stack(graph2_edge_feats) + graph2_nbrhd_indices = torch.stack(graph2_nbrhd_indices) + + # Initialize the complex's new representation as a dictionary + processed_complex = { + 'graph1_node_feats': torch.nan_to_num(graph1_node_feats), + 'graph2_node_feats': torch.nan_to_num(graph2_node_feats), + 'graph1_node_coords': torch.nan_to_num(graph1_node_coords), + 'graph2_node_coords': torch.nan_to_num(graph2_node_coords), + 'graph1_edge_feats': torch.nan_to_num(graph1_edge_feats), + 'graph2_edge_feats': torch.nan_to_num(graph2_edge_feats), + 'graph1_nbrhd_indices': graph1_nbrhd_indices, + 'graph2_nbrhd_indices': graph2_nbrhd_indices, + 'examples': examples, + 'complex': bound_complex.complex + } + + # Write into processed_filepath + processed_file_dir = os.path.join(*processed_filepath.split(os.sep)[: -1]) + os.makedirs(processed_file_dir, exist_ok=True) + with open(processed_filepath, 'wb') as f: + pickle.dump(processed_complex, f) + + +def zero_out_complex_features(cmplx: dict): + """Zero-out the input features for a given protein complex dictionary (for an input-independent baseline).""" + cmplx['graph1_node_feats'] = torch.zeros_like(cmplx['graph1_node_feats']) + cmplx['graph2_node_feats'] = torch.zeros_like(cmplx['graph2_node_feats']) + cmplx['graph1_edge_feats'] = torch.zeros_like(cmplx['graph1_edge_feats']) + cmplx['graph2_edge_feats'] = torch.zeros_like(cmplx['graph2_edge_feats']) + return cmplx + + +def construct_interact_tensor(graph1_feats: torch.Tensor, graph2_feats: torch.Tensor, pad=False, max_len=256): + """Build the interaction tensor for given node representations, optionally padding up to the node count limit.""" + # Get descriptors and reshaped versions of the input feature matrices + len_1, len_2 = graph1_feats.shape[0], graph2_feats.shape[0] + x_a, x_b = graph1_feats.permute(1, 0).unsqueeze(0), graph2_feats.permute(1, 0).unsqueeze(0) + if pad: + x_a_num_zeros = max_len - x_a.shape[2] + x_b_num_zeros = max_len - x_b.shape[2] + x_a = F.pad(x_a, (0, x_a_num_zeros, 0, 0, 0, 0)) # Pad the start of 3D tensors + x_b = F.pad(x_b, (0, x_b_num_zeros, 0, 0, 0, 0)) # Pad the end of 3D tensors + len_1, len_2 = max_len, max_len + # Interleave 2D input matrices into a 3D interaction tensor + interact_tensor = torch.cat((torch.repeat_interleave(x_a.unsqueeze(3), repeats=len_2, dim=3), + torch.repeat_interleave(x_b.unsqueeze(2), repeats=len_1, dim=2)), dim=1) + return interact_tensor + + +def collect_args(): + """Collect all arguments required for training/testing.""" + parser = ArgumentParser() + parser = 
pl.Trainer.add_argparse_args(parser) + + # ----------------- + # Model arguments + # ----------------- + parser.add_argument('--model_name', type=str, default='NeiA', help='Options are NeiA or NeiWA') + parser.add_argument('--num_gnn_layers', type=int, default=1, help='Number of GNN layers') + parser.add_argument('--num_interact_layers', type=int, default=3, help='Number of layers in interaction module') + parser.add_argument('--metric_to_track', type=str, default='val_bce', help='Scheduling and early stop') + + # ----------------- + # Data arguments + # ----------------- + parser.add_argument('--knn', type=int, default=20, help='Number of nearest neighbor edges for each node') + parser.add_argument('--self_loops', action='store_true', dest='self_loops', help='Allow node self-loops') + parser.add_argument('--no_self_loops', action='store_false', dest='self_loops', help='Disable self-loops') + parser.add_argument('--pn_ratio', type=float, default=0.1, + help='Positive-negative class ratio to instate during training with DIPS-Plus') + parser.add_argument('--dips_data_dir', type=str, default='datasets/DIPS/final/raw', help='Path to DIPS-Plus') + parser.add_argument('--dips_percent_to_use', type=float, default=1.00, + help='Fraction of DIPS-Plus dataset splits to use') + parser.add_argument('--db5_data_dir', type=str, default='datasets/DB5/final/raw', help='Path to DB5-Plus') + parser.add_argument('--db5_percent_to_use', type=float, default=1.00, + help='Fraction of DB5-Plus dataset splits to use') + parser.add_argument('--process_complexes', action='store_true', dest='process_complexes', + help='Check if all complexes for a dataset are processed and, if not, process those remaining') + + # ----------------- + # Logging arguments + # ----------------- + parser.add_argument('--logger_name', type=str, default='TensorBoard', help='Which logger to use for experiments') + parser.add_argument('--experiment_name', type=str, default=None, help='Logger experiment name') + parser.add_argument('--project_name', type=str, default='NeiA-PyTorch', help='Logger project name') + parser.add_argument('--entity', type=str, default='PyTorch', help='Logger entity (i.e. 
team) name') + parser.add_argument('--run_id', type=str, default='', help='Logger run ID') + parser.add_argument('--offline', action='store_true', dest='offline', help='Whether to log locally or remotely') + parser.add_argument('--online', action='store_false', dest='offline', help='Whether to log locally or remotely') + parser.add_argument('--tb_log_dir', type=str, default='tb_logs', help='Where to store TensorBoard log files') + parser.set_defaults(offline=False) # Default to using online logging mode + + # ----------------- + # Seed arguments + # ----------------- + parser.add_argument('--seed', type=int, default=None, help='Seed for NumPy and PyTorch') + + # ----------------- + # Meta-arguments + # ----------------- + parser.add_argument('--batch_size', type=int, default=1, help='Number of samples included in each data batch') + parser.add_argument('--lr', type=float, default=1e-5, help='Learning rate') + parser.add_argument('--weight_decay', type=float, default=1e-7, help='Decay rate of optimizer weight') + parser.add_argument('--num_epochs', type=int, default=50, help='Maximum number of epochs to run for training') + parser.add_argument('--dropout_rate', type=float, default=0.3, help='Dropout (forget) rate') + parser.add_argument('--patience', type=int, default=5, help='Number of epochs to wait until early stopping') + parser.add_argument('--pad', action='store_true', dest='pad', help='Whether to zero pad interaction tensors') + + # ----------------- + # Miscellaneous + # ----------------- + parser.add_argument('--max_hours', type=int, default=1, help='Maximum number of hours to allot for training') + parser.add_argument('--max_minutes', type=int, default=55, help='Maximum number of minutes to allot for training') + parser.add_argument('--multi_gpu_backend', type=str, default='ddp', help='Multi-GPU backend for training') + parser.add_argument('--num_gpus', type=int, default=1, help='Number of GPUs to use (e.g. -1 = all available GPUs)') + parser.add_argument('--auto_choose_gpus', action='store_true', dest='auto_choose_gpus', help='Auto-select GPUs') + parser.add_argument('--num_compute_nodes', type=int, default=1, help='Number of compute nodes to use') + parser.add_argument('--gpu_precision', type=int, default=32, help='Bit size used during training (e.g. 16-bit)') + parser.add_argument('--num_workers', type=int, default=4, help='Number of CPU threads for loading data') + parser.add_argument('--profiler_method', type=str, default=None, help='PL profiler to use (e.g. 
simple)') + parser.add_argument('--ckpt_dir', type=str, default=f'{os.path.join(os.getcwd(), "checkpoints")}', + help='Directory in which to save checkpoints') + parser.add_argument('--ckpt_name', type=str, default='', help='Filename of best checkpoint') + parser.add_argument('--min_delta', type=float, default=5e-6, help='Minimum percentage of change required to' + ' "metric_to_track" before early stopping' + ' after surpassing patience') + parser.add_argument('--accum_grad_batches', type=int, default=1, help='Norm over which to clip gradients') + parser.add_argument('--grad_clip_val', type=float, default=0.5, help='Norm over which to clip gradients') + parser.add_argument('--grad_clip_algo', type=str, default='norm', help='Algorithm with which to clip gradients') + parser.add_argument('--stc_weight_avg', action='store_true', dest='stc_weight_avg', help='Smooth loss landscape') + parser.add_argument('--find_lr', action='store_true', dest='find_lr', help='Find an optimal learning rate a priori') + parser.add_argument('--input_indep', action='store_true', dest='input_indep', help='Whether to zero input for test') + + return parser + + +def process_args(args): + """Process all arguments required for training/testing.""" + # --------------------------------------- + # Seed fixing for random numbers + # --------------------------------------- + if not args.seed: + args.seed = 42 # np.random.randint(100000) + logging.info(f'Seeding everything with random seed {args.seed}') + pl.seed_everything(args.seed) + + return args + + +def construct_pl_logger(args): + """Return a specific Logger instance requested by the user.""" + if args.logger_name.lower() == 'wandb': + return construct_wandb_pl_logger(args) + else: # Default to using TensorBoard + return construct_tensorboard_pl_logger(args) + + +def construct_wandb_pl_logger(args): + """Return an instance of WandbLogger with corresponding project and name strings.""" + return WandbLogger(name=args.experiment_name, + offline=args.offline, + project=args.project_name, + log_model=True, + entity=args.entity) + + +def construct_tensorboard_pl_logger(args): + """Return an instance of TensorBoardLogger with corresponding project and experiment name strings.""" + return TensorBoardLogger(save_dir=args.tb_log_dir, + name=args.experiment_name) diff --git a/project/utils/utils.py b/project/utils/utils.py index 7028f05..8c8219c 100644 --- a/project/utils/utils.py +++ b/project/utils/utils.py @@ -1,19 +1,25 @@ import collections as col +import copy +import dill import gzip +import h5py import logging import os import re import shutil +import subprocess +import tempfile import urllib.request as request from contextlib import closing from pathlib import Path -from typing import List, Tuple +from typing import List import atom3.database as db import atom3.neighbors as nb import atom3.pair as pa import dgl import dill as pickle +import hickle as hkl import numpy as np import pandas as pd import torch @@ -639,7 +645,10 @@ def postprocess_pruned_pair(raw_pdb_filenames: List[str], external_feats_dir: st # Get protrusion indices using PSAIA psaia_filepath = os.path.relpath(os.path.splitext(os.path.split(raw_pdb_filename)[-1])[0]) - psaia_filename = [path for path in Path(external_feats_dir).rglob(f'{psaia_filepath}*.tbl')][0] # 1st path + psaia_filenames = [path for path in Path(external_feats_dir).rglob(f'{psaia_filepath}*.tbl')] + if len(psaia_filenames) == 0: + psaia_filenames = [path for path in Path(external_feats_dir).parent.rglob(f'{psaia_filepath}*.tbl')] + 
psaia_filename = psaia_filenames[0] psaia_df = get_df_from_psaia_tbl_file(psaia_filename) # Extract half-sphere exposure (HSE) statistics for each PDB model (including HSAAC and CN values) @@ -1252,6 +1261,70 @@ def log_dataset_statistics(logger, dataset_statistics: dict): f' valid amide normal vectors found for df1 structures in total') +def convert_pair_pickle_to_hdf5(pickle_filepath: Path, hdf5_filepath: Path): + # Load pickle file + with open(str(pickle_filepath), 'rb') as f: + pair_data = dill.load(f) + + # Save data to HDF5 file + hkl_data = list(pair_data) + hkl.dump(hkl_data, str(hdf5_filepath)) + + +def convert_pair_hdf5_to_pickle(hdf5_filepath: Path) -> pa.Pair: + # Load HDF5 file as pickle object + hkl_data = hkl.load(str(hdf5_filepath)) + pair_data = pa.Pair(*hkl_data) + return pair_data + + +def convert_pair_hdf5_to_hdf5_file(hdf5_filepath: Path) -> h5py.File: + # Load HDF5 file as HDF5 file + data = h5py.File(str(hdf5_filepath), 'r') + return data + + +def annotate_idr_residues(pickle_filepaths: List[Path]): + # Process each pickle file input + for pickle_filepath in pickle_filepaths: + # Load pickle file + with open(str(pickle_filepath), 'rb') as f: + pair_data = dill.load(f) + # Find IDR interface residues + annotating_new_residues = False + annotated_residue_sequences = copy.deepcopy(pair_data.sequences) + for key, sequence in pair_data.sequences.items(): + if "idr_annotations" not in key and f"{key}_idr_annotations" not in pair_data.sequences: + annotating_new_residues = True + fasta_filepath = os.path.join(tempfile.mkdtemp(), f'{key}.fasta') + with open(fasta_filepath, 'w') as file: + file.write(f'>sequence\n{sequence}\n') + # Run `flDPnn` to predict which residues reside in an IDR + output_filepath = os.path.join(tempfile.mkdtemp(), "fldpnn_results.csv") + cmd = [ + 'docker', 'run', '-i', 'sinaghadermarzi/fldpnn', + 'fldpnn/dockerinout_nofunc', '<', fasta_filepath, '>', output_filepath + ] + try: + subprocess.run(' '.join(cmd), shell=True, check=True) + print("Docker command executed successfully.") + except subprocess.CalledProcessError as e: + print(f"Error executing Docker command: {e}") + # Parse output from `flDPnn` + output_df = pd.read_csv(output_filepath, skiprows=1) + annotated_residue_sequences[f"{key}_idr_annotations"] = output_df['Binary Prediction for Disorder'].tolist() + assert len(annotated_residue_sequences[f"{key}_idr_annotations"]) == len(pair_data.sequences[key]), "IDR annotations must match length of input sequence." 
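+                # Note: `skiprows=1` and the 'Binary Prediction for Disorder' column name assumed above reflect
+                # the CSV layout produced by the referenced flDPnn Docker image; if a different flDPnn build is
+                # used, the header offset and column name may need to be adjusted accordingly.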
+ # Record IDR residue annotations within pickle file + if annotating_new_residues: + pair_data.sequences.update({ + key: value + for key, value in annotated_residue_sequences.items() + if key not in pair_data.sequences + }) + with open(str(pickle_filepath), 'wb') as f: + dill.dump(pair_data, f) + + def process_raw_file_into_dgl_graphs(raw_filepath: str, new_graph_dir: str, processed_filepath: str, edge_dist_cutoff: float, edge_limit: int, self_loops: bool): """Process each postprocessed pair into a graph pair.""" diff --git a/setup.py b/setup.py index 622dd79..548a1b9 100644 --- a/setup.py +++ b/setup.py @@ -4,7 +4,7 @@ setup( name='DIPS-Plus', - version='1.1.0', + version='1.2.0', description='The Enhanced Database of Interacting Protein Structures for Interface Prediction', author='Alex Morehead', author_email='acmwhb@umsystem.edu', @@ -15,9 +15,7 @@ 'tqdm==4.49.0', 'Sphinx==4.0.1', 'easy-parallel-py3==0.1.6.4', - 'atom3-py3==0.1.9.9', 'click==7.0.0', - # mpi4py==3.0.3 # On Andes, do 'source venv/bin/activate', 'module load gcc/10.3.0', and 'pip install mpi4py --no-cache-dir --no-binary :all:' ], packages=find_packages(), )
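
For convenience, below is a minimal sketch of how the new pair-conversion and IDR-annotation utilities introduced in this release might be driven end-to-end. The `project/datasets/DIPS/final/raw` directory and the `*.dill` pair extension are placeholders for your local layout, `project.utils.utils` is assumed to be importable (e.g. after `pip3 install -e .`), and the IDR step requires Docker with the `sinaghadermarzi/fldpnn` image available locally.

```python
from pathlib import Path

from project.utils.utils import (
    annotate_idr_residues,
    convert_pair_hdf5_to_pickle,
    convert_pair_pickle_to_hdf5,
)

# Hypothetical location of postprocessed DIPS-Plus pairs; adjust to your own layout
pair_dir = Path('project/datasets/DIPS/final/raw')
pickle_filepaths = sorted(pair_dir.rglob('*.dill'))

# Add residue-level intrinsic disorder (IDR) annotations to each pair, in place
annotate_idr_residues(pickle_filepaths)

# Convert each annotated pair into an HDF5 file alongside the original pickle
for pickle_filepath in pickle_filepaths:
    convert_pair_pickle_to_hdf5(pickle_filepath, pickle_filepath.with_suffix('.hdf5'))

# Round-trip one pair back into an atom3 Pair object to confirm the conversion
pair = convert_pair_hdf5_to_pickle(pickle_filepaths[0].with_suffix('.hdf5'))
print(pair.complex)
```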