From 3af6f093f052d29431ff46ab827fe62bfc75caf9 Mon Sep 17 00:00:00 2001
From: Keith Cheveralls <keith.chev@gmail.com>
Date: Thu, 25 Jan 2024 17:59:18 -0800
Subject: [PATCH] Ensure user-provided PDBs are used for clustering (#83)

* prioritize the copy_pdb rule over make_pdb

* exclude the input protids from the list of protids whose PDBs are downloaded from AlphaFold
---
 ProteinCartography/filter_uniprot_hits.py | 20 +++++++++++++++++++-
 Snakefile                                 |  9 +++++++--
 2 files changed, 26 insertions(+), 3 deletions(-)

diff --git a/ProteinCartography/filter_uniprot_hits.py b/ProteinCartography/filter_uniprot_hits.py
index dca225c..1f4b3ce 100644
--- a/ProteinCartography/filter_uniprot_hits.py
+++ b/ProteinCartography/filter_uniprot_hits.py
@@ -30,6 +30,11 @@ def parse_args():
         default="0",
         help="maximum protein length of proteins to keep. If set to 0, no upper limit is applied.",
     )
+    parser.add_argument(
+        "--excluded-protids",
+        nargs="*",
+        help="a list of protids to exclude from the results",
+    )
     args = parser.parse_args()
 
     return args
@@ -42,6 +47,7 @@ def filter_results(
     filter_fragment=True,
     min_length=0,
     max_length=0,
+    excluded_protids=None,
 ):
     """
     Takes an input uniprot_features.tsv file and filters the results based on fragment status,
@@ -78,6 +84,11 @@ def filter_results(
     if max_length > 0:
         filtered_df = filtered_df[filtered_df["Length"].astype(int) < max_length]
 
+    if excluded_protids is None:
+        excluded_protids = []
+
+    filtered_df = filtered_df[~filtered_df["protid"].isin(excluded_protids)]
+
     with open(output_file, "w+") as f:
         f.writelines([protid + "\n" for protid in filtered_df["protid"]])
 
@@ -88,6 +99,7 @@ def main():
     args = parse_args()
     input_file = args.input
     output_file = args.output
+    excluded_protids = args.excluded_protids
 
     try:
         min_length = int(args.min_length)
@@ -98,7 +110,13 @@ def main():
     except (TypeError, ValueError):
         max_length = 0
 
-    filter_results(input_file, output_file, min_length=min_length, max_length=max_length)
+    filter_results(
+        input_file,
+        output_file,
+        min_length=min_length,
+        max_length=max_length,
+        excluded_protids=excluded_protids,
+    )
 
 
 if __name__ == "__main__":
diff --git a/Snakefile b/Snakefile
index 087267f..b1b492f 100644
--- a/Snakefile
+++ b/Snakefile
@@ -125,7 +125,7 @@ rule make_pdb:
     input:
         cds=input_dir / "{protid}.fasta",
     output:
-        pdb=input_dir / "{protid}.pdb",
+        pdb=pdb_download_dir / "{protid}.pdb",
     benchmark:
         output_dir / benchmarks_dir / "{protid}.make_pdb.txt"
     conda:
@@ -150,6 +150,11 @@ rule copy_pdb:
         """
 
 
+# first try to copy any user-provided PDB files from the input directory;
+# if they don't exist, generate them using make_pdb
+ruleorder: copy_pdb > make_pdb
+
+
 rule run_blast:
     """
     Using files located in the input directory, run `blastp` using the remote BLAST API.
@@ -348,7 +353,7 @@ rule filter_uniprot_hits:
         "envs/pandas.yml"
     shell:
         """
-        python ProteinCartography/filter_uniprot_hits.py -i {input} -o {output} -m {params.min_length} -M {params.max_length}
+        python ProteinCartography/filter_uniprot_hits.py -i {input} -o {output} -m {params.min_length} -M {params.max_length} --excluded-protids {PROTID}
         """