Skip to content

Commit

Permalink
Update all plots, and use unique data caches with DynamicBind
Browse files Browse the repository at this point in the history
  • Loading branch information
amorehead committed Aug 17, 2024
1 parent 93ea040 commit 662d147
Show file tree
Hide file tree
Showing 35 changed files with 235 additions and 214 deletions.
1 change: 1 addition & 0 deletions configs/model/dynamicbind_inference.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ samples_per_complex: 40 # the number of samples to generate per complex
savings_per_complex: 1 # the (top-N) number of sample visualizations to save per complex
inference_steps: 20 # the number of inference steps to run for each complex
batch_size: 5 # the batch size to use for inference
cache_path: ${oc.env:PROJECT_ROOT}/data/dynamicbind_cache/cache # the cache directory to use for storing intermediate data files
header: ${dataset} # name of the results directory to create
num_workers: 1 # the number of workers to use for native relaxation during inference
skip_existing: true # whether to skip existing predictions
Expand Down
Binary file modified docs/source/_static/PoseBench.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
5 changes: 3 additions & 2 deletions forks/DynamicBind/run_single_protein_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
parser.add_argument('--savings_per_complex', type=int, default=1, help='num of samples data saved for movie generation.')
parser.add_argument('--inference_steps', type=int, default=20, help='num of coordinate updates. (movie frames)')
parser.add_argument('--batch_size', type=int, default=5, help='chunk size for inference batches.')
parser.add_argument('--cache_path', type=str, default='data/cache', help='Folder from where to load/restore cached dataset')
parser.add_argument('--header', type=str, default='test', help='informative name used to name result folder')
parser.add_argument('--results', type=str, default='results', help='result folder.')
parser.add_argument('--device', type=int, default=0, help='CUDA_VISIBLE_DEVICES')
Expand Down Expand Up @@ -302,7 +303,7 @@ def ref_filename_sort_key(filepath):
do(cmd)
cmd = f"CUDA_VISIBLE_DEVICES={args.device} {python} {script_folder}/esm/scripts/extract.py esm2_t33_650M_UR50D {os.path.join(outputs_dir, f'prepared_for_esm_{header}.fasta')} {os.path.join(outputs_dir, 'esm2_output' + unique_id)} --repr_layers 33 --include per_tok --truncation_seq_length 10000 --model_dir {script_folder}/esm_models"
do(cmd)
cmd = f"{python} {script_folder}/inference.py --seed {args.seed} --ckpt {ckpt} {protein_dynamic}"
cmd = f"{python} {script_folder}/inference.py --cache_path {args.cache_path} --seed {args.seed} --ckpt {ckpt} {protein_dynamic}"
cmd += f" --save_visualisation --model_dir {model_workdir} --protein_ligand_csv {ligandFile_with_protein_path} "
cmd += f" --esm_embeddings_path {os.path.join(outputs_dir, 'esm2_output' + unique_id)} --out_dir {args.results}/{header} --inference_steps {args.inference_steps} --samples_per_complex {args.samples_per_complex} --savings_per_complex {args.savings_per_complex} --batch_size {args.batch_size} --actual_steps {args.inference_steps} --no_final_step_noise"
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.device)
Expand Down Expand Up @@ -391,7 +392,7 @@ def ref_filename_sort_key(filepath):
do(cmd)
cmd = f"CUDA_VISIBLE_DEVICES={args.device} {python} {script_folder}/esm/scripts/extract.py esm2_t33_650M_UR50D {os.path.join(outputs_dir, f'prepared_for_esm_{header}.fasta')} {os.path.join(outputs_dir, 'esm2_output' + unique_id)} --repr_layers 33 --include per_tok --truncation_seq_length 10000 --model_dir {script_folder}/esm_models"
do(cmd)
cmd = f"CUDA_VISIBLE_DEVICES={args.device} {python} {script_folder}/inference.py --seed {args.seed} --ckpt {ckpt} {protein_dynamic}"
cmd = f"CUDA_VISIBLE_DEVICES={args.device} {python} {script_folder}/inference.py --cache_path {args.cache_path} --seed {args.seed} --ckpt {ckpt} {protein_dynamic}"
cmd += f" --save_visualisation --model_dir {model_workdir} --protein_ligand_csv {ligandFile_with_protein_path} "
cmd += f" --esm_embeddings_path {os.path.join(outputs_dir, 'esm2_output' + unique_id)} --out_dir {args.results}/{header} --inference_steps {args.inference_steps} --samples_per_complex {args.samples_per_complex} --savings_per_complex {args.savings_per_complex} --batch_size {args.batch_size} --actual_steps {args.inference_steps} --no_final_step_noise"
do(cmd)
Expand Down
Binary file modified img/PoseBench.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified notebooks/casp15_all_multi_ligand_relaxed_rmsd_violin_plot.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified notebooks/casp15_all_single_ligand_relaxed_rmsd_violin_plot.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
47 changes: 29 additions & 18 deletions notebooks/casp15_inference_results_plotting.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@
" \"dynamicbind\": \"DynamicBind\",\n",
" \"neuralplexer\": \"NeuralPLexer\",\n",
" \"neuralplexer_no_ilcl\": \"NeuralPLexer w/o ILCL\",\n",
" \"rfaa\": \"RoseTTAFold-All-Atom\",\n",
" \"rfaa\": \"RoseTTAFold-AA\",\n",
" \"tulip\": \"TULIP\",\n",
" \"vina_diffdock\": \"DiffDock-L-Vina\",\n",
" \"vina_p2rank\": \"P2Rank-Vina\",\n",
Expand Down Expand Up @@ -133,11 +133,11 @@
" \"casp15\",\n",
" f\"top_{method}{'' if 'ensemble' in method else '_ensemble'}_predictions_{repeat_index}\",\n",
" )\n",
" globals()[\n",
" f\"{method}{config}_scoring_results_csv_filepath_{repeat_index}\"\n",
" ] = os.path.join(\n",
" globals()[f\"{method}_output_dir_{repeat_index}\"] + config,\n",
" \"scoring_results.csv\",\n",
" globals()[f\"{method}{config}_scoring_results_csv_filepath_{repeat_index}\"] = (\n",
" os.path.join(\n",
" globals()[f\"{method}_output_dir_{repeat_index}\"] + config,\n",
" \"scoring_results.csv\",\n",
" )\n",
" )\n",
" globals()[f\"{method}{config}_bust_results_csv_filepath_{repeat_index}\"] = os.path.join(\n",
" globals()[f\"{method}_output_dir_{repeat_index}\"] + config,\n",
Expand All @@ -160,12 +160,6 @@
" .groupby([\"target\", \"mdl\"])[\"pose\"]\n",
" .transform(\"count\")\n",
" )\n",
" grouped_num_target_ligands = (\n",
" globals()[f\"{method}{config}_scoring_results_table_{repeat_index}\"]\n",
" .groupby([\"target\", \"mdl\"])[\"num_target_ligands\"]\n",
" .first()\n",
" )\n",
" num_ligands_per_complex = grouped_num_target_ligands.loc[(slice(None), 1)].tolist()\n",
" globals()[f\"{method}{config}_bust_results_table_{repeat_index}\"] = (\n",
" pd.read_csv(\n",
" globals()[f\"{method}{config}_bust_results_csv_filepath_{repeat_index}\"]\n",
Expand Down Expand Up @@ -204,6 +198,13 @@
" <= 2\n",
" )\n",
"\n",
" grouped_num_target_ligands = (\n",
" globals()[f\"{method}{config}_scoring_results_table_{repeat_index}\"]\n",
" .groupby([\"target\", \"mdl\"])[\"num_target_ligands\"]\n",
" .first()\n",
" )\n",
" num_ligands_per_complex = grouped_num_target_ligands.loc[(slice(None), 1)].tolist()\n",
"\n",
" print(\n",
" f\"{method_title}{config}_{repeat_index} CASP15 set average `lddt_pli`: {globals()[f'{method}{config}_scoring_results_table_{repeat_index}']['lddt_pli'].mean()}\"\n",
" )\n",
Expand All @@ -229,6 +230,16 @@
" globals()[f\"{method}{config}_bust_results_table_{repeat_index}\"].loc[\n",
" :, \"dataset\"\n",
" ] = \"casp15\"\n",
" # filter bust results to only those for targets that were scoreable using the CASP scoring pipeline\n",
" globals()[f\"{method}{config}_bust_results_table_{repeat_index}\"] = globals()[\n",
" f\"{method}{config}_bust_results_table_{repeat_index}\"\n",
" ][\n",
" globals()[f\"{method}{config}_bust_results_table_{repeat_index}\"].target.isin(\n",
" globals()[\n",
" f\"{method}{config}_scoring_results_table_{repeat_index}\"\n",
" ].target.unique()\n",
" )\n",
" ]\n",
" globals()[f\"{method}{config}_bust_results_table_{repeat_index}\"].loc[\n",
" :, \"num_target_ligands\"\n",
" ] = num_ligands_per_complex\n",
Expand Down Expand Up @@ -357,7 +368,7 @@
" )\n",
" ]\n",
" combined_data_list.append(pd.concat([casp15_results_table, casp15_relaxed_results_table]))\n",
"combined_data = pd.concat(combined_data_list)\n",
"combined_data = pd.concat(combined_data_list).sort_values(\"method_assignment_index\")\n",
"\n",
"for complex_type in [\"single\", \"multi\"]:\n",
" for complex_license in [\"all\", \"public\"]:\n",
Expand Down Expand Up @@ -439,7 +450,7 @@
" )\n",
" ]\n",
" combined_data_list.append(pd.concat([casp15_results_table, casp15_relaxed_results_table]))\n",
"combined_data = pd.concat(combined_data_list)\n",
"combined_data = pd.concat(combined_data_list).sort_values(\"method_assignment_index\")\n",
"\n",
"for complex_type in [\"single\", \"multi\"]:\n",
" for complex_license in [\"all\", \"public\"]:\n",
Expand All @@ -453,7 +464,7 @@
" hue=\"post-processing\",\n",
" data=combined_data[\n",
" # ignore outliers\n",
" (combined_data[\"rmsd\"] < 50)\n",
" (combined_data[\"rmsd\"] < 150)\n",
" & (\n",
" # filter the data based on the complex type and license\n",
" combined_data[\"target\"].isin(\n",
Expand Down Expand Up @@ -724,7 +735,7 @@
" \"NeuralPLexer\",\n",
" \"DL-based blind\",\n",
" \"NeuralPLexer w/o ILCL\",\n",
" \"RoseTTAFold-All-Atom\",\n",
" \"RoseTTAFold-AA\",\n",
" \"TULIP\",\n",
" \"DiffDock-L-Vina\",\n",
" \"Conventional blind\",\n",
Expand Down Expand Up @@ -1042,7 +1053,7 @@
" \"NeuralPLexer\",\n",
" \"DL-based blind\",\n",
" \"NeuralPLexer w/o ILCL\",\n",
" \"RoseTTAFold-All-Atom\",\n",
" \"RoseTTAFold-AA\",\n",
" \"TULIP\",\n",
" \"DiffDock-L-Vina\",\n",
" \"Conventional blind\",\n",
Expand Down Expand Up @@ -1108,7 +1119,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.17"
"version": "3.10.14"
}
},
"nbformat": 4,
Expand Down
Binary file modified notebooks/casp15_method_interaction_analysis.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
4 changes: 2 additions & 2 deletions notebooks/casp15_method_interaction_analysis_plotting.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@
" \"diffdock\": \"DiffDock-L\",\n",
" \"dynamicbind\": \"DynamicBind\",\n",
" \"neuralplexer\": \"NeuralPLexer\",\n",
" \"rfaa\": \"RoseTTAFold-All-Atom\",\n",
" \"rfaa\": \"RoseTTAFold-AA\",\n",
" \"tulip\": \"TULIP\",\n",
" \"vina_diffdock\": \"DiffDock-L-Vina\",\n",
" \"vina_p2rank\": \"P2Rank-Vina\",\n",
Expand Down Expand Up @@ -482,7 +482,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.17"
"version": "3.10.14"
}
},
"nbformat": 4,
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
96 changes: 48 additions & 48 deletions notebooks/dockgen_inference_results_plotting.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -96,12 +96,12 @@
" )\n",
"\n",
" # DiffDock (relaxed-protein) results\n",
" globals()[\n",
" f\"diffdock_relaxed_protein_dockgen_bust_results_csv_filepath_{repeat_index}\"\n",
" ] = os.path.join(\n",
" globals()[\"diffdock_output_dir\"],\n",
" f\"diffdock_dockgen_output_{repeat_index}\",\n",
" \"bust_results.csv\",\n",
" globals()[f\"diffdock_relaxed_protein_dockgen_bust_results_csv_filepath_{repeat_index}\"] = (\n",
" os.path.join(\n",
" globals()[\"diffdock_output_dir\"],\n",
" f\"diffdock_dockgen_output_{repeat_index}\",\n",
" \"bust_results.csv\",\n",
" )\n",
" )\n",
" globals()[\n",
" f\"diffdock_relaxed_protein_dockgen_relaxed_bust_results_csv_filepath_{repeat_index}\"\n",
Expand Down Expand Up @@ -129,12 +129,12 @@
" f\"dockgen_{repeat_index}\",\n",
" \"bust_results.csv\",\n",
" )\n",
" globals()[\n",
" f\"dynamicbind_dockgen_relaxed_bust_results_csv_filepath_{repeat_index}\"\n",
" ] = os.path.join(\n",
" globals()[\"dynamicbind_output_dir\"],\n",
" f\"dockgen_{repeat_index}_relaxed\",\n",
" \"bust_results.csv\",\n",
" globals()[f\"dynamicbind_dockgen_relaxed_bust_results_csv_filepath_{repeat_index}\"] = (\n",
" os.path.join(\n",
" globals()[\"dynamicbind_output_dir\"],\n",
" f\"dockgen_{repeat_index}_relaxed\",\n",
" \"bust_results.csv\",\n",
" )\n",
" )\n",
"\n",
" # NeuralPLexer results\n",
Expand All @@ -143,12 +143,12 @@
" f\"neuralplexer_dockgen_outputs_{repeat_index}\",\n",
" \"bust_results.csv\",\n",
" )\n",
" globals()[\n",
" f\"neuralplexer_dockgen_relaxed_bust_results_csv_filepath_{repeat_index}\"\n",
" ] = os.path.join(\n",
" globals()[\"neuralplexer_output_dir\"],\n",
" f\"neuralplexer_dockgen_outputs_{repeat_index}_relaxed\",\n",
" \"bust_results.csv\",\n",
" globals()[f\"neuralplexer_dockgen_relaxed_bust_results_csv_filepath_{repeat_index}\"] = (\n",
" os.path.join(\n",
" globals()[\"neuralplexer_output_dir\"],\n",
" f\"neuralplexer_dockgen_outputs_{repeat_index}_relaxed\",\n",
" \"bust_results.csv\",\n",
" )\n",
" )\n",
"\n",
" # RoseTTAFold-All-Atom results\n",
Expand All @@ -169,12 +169,12 @@
" f\"vina_diffdock_dockgen_outputs_{repeat_index}\",\n",
" \"bust_results.csv\",\n",
" )\n",
" globals()[\n",
" f\"vina_diffdock_dockgen_relaxed_bust_results_csv_filepath_{repeat_index}\"\n",
" ] = os.path.join(\n",
" globals()[\"vina_output_dir\"],\n",
" f\"vina_diffdock_dockgen_outputs_{repeat_index}_relaxed\",\n",
" \"bust_results.csv\",\n",
" globals()[f\"vina_diffdock_dockgen_relaxed_bust_results_csv_filepath_{repeat_index}\"] = (\n",
" os.path.join(\n",
" globals()[\"vina_output_dir\"],\n",
" f\"vina_diffdock_dockgen_outputs_{repeat_index}_relaxed\",\n",
" \"bust_results.csv\",\n",
" )\n",
" )\n",
"\n",
" # P2Rank-Vina results\n",
Expand All @@ -183,38 +183,38 @@
" f\"vina_p2rank_dockgen_outputs_{repeat_index}\",\n",
" \"bust_results.csv\",\n",
" )\n",
" globals()[\n",
" f\"vina_p2rank_dockgen_relaxed_bust_results_csv_filepath_{repeat_index}\"\n",
" ] = os.path.join(\n",
" globals()[\"vina_output_dir\"],\n",
" f\"vina_p2rank_dockgen_outputs_{repeat_index}_relaxed\",\n",
" \"bust_results.csv\",\n",
" globals()[f\"vina_p2rank_dockgen_relaxed_bust_results_csv_filepath_{repeat_index}\"] = (\n",
" os.path.join(\n",
" globals()[\"vina_output_dir\"],\n",
" f\"vina_p2rank_dockgen_outputs_{repeat_index}_relaxed\",\n",
" \"bust_results.csv\",\n",
" )\n",
" )\n",
"\n",
" # Consensus ensemble results\n",
" globals()[\n",
" f\"consensus_ensemble_dockgen_bust_results_csv_filepath_{repeat_index}\"\n",
" ] = os.path.join(\n",
" os.path.join(\"..\", \"data\", \"test_cases\", \"dockgen\"),\n",
" f\"top_consensus_ensemble_predictions_{repeat_index}\",\n",
" \"bust_results.csv\",\n",
" globals()[f\"consensus_ensemble_dockgen_bust_results_csv_filepath_{repeat_index}\"] = (\n",
" os.path.join(\n",
" os.path.join(\"..\", \"data\", \"test_cases\", \"dockgen\"),\n",
" f\"top_consensus_ensemble_predictions_{repeat_index}\",\n",
" \"bust_results.csv\",\n",
" )\n",
" )\n",
" globals()[\n",
" f\"consensus_ensemble_dockgen_relaxed_bust_results_csv_filepath_{repeat_index}\"\n",
" ] = os.path.join(\n",
" os.path.join(\"..\", \"data\", \"test_cases\", \"dockgen\"),\n",
" f\"top_consensus_ensemble_predictions_{repeat_index}_relaxed\",\n",
" \"bust_results.csv\",\n",
" globals()[f\"consensus_ensemble_dockgen_relaxed_bust_results_csv_filepath_{repeat_index}\"] = (\n",
" os.path.join(\n",
" os.path.join(\"..\", \"data\", \"test_cases\", \"dockgen\"),\n",
" f\"top_consensus_ensemble_predictions_{repeat_index}_relaxed\",\n",
" \"bust_results.csv\",\n",
" )\n",
" )\n",
"\n",
"# Mappings\n",
"method_mapping = {\n",
" \"diffdock\": \"DiffDock-L\",\n",
" \"diffdock_relaxed_protein\": \"DiffDock-L (Relaxed-Protein)\",\n",
" \"diffdock_relaxed_protein\": \"DiffDock-L-Relax-Prot\",\n",
" \"fabind\": \"FABind\",\n",
" \"dynamicbind\": \"DynamicBind\",\n",
" \"neuralplexer\": \"NeuralPLexer\",\n",
" \"rfaa\": \"RoseTTAFold-All-Atom\",\n",
" \"rfaa\": \"RoseTTAFold-AA\",\n",
" \"vina_diffdock\": \"DiffDock-L-Vina\",\n",
" \"vina_p2rank\": \"P2Rank-Vina\",\n",
" \"consensus_ensemble\": \"Ensemble (Con)\",\n",
Expand Down Expand Up @@ -429,7 +429,7 @@
" x=\"method\",\n",
" y=\"rmsd\",\n",
" hue=\"post-processing\",\n",
" data=combined_relaxed_data[combined_relaxed_data[\"rmsd\"] < 50],\n",
" data=combined_relaxed_data[combined_relaxed_data[\"rmsd\"] < 150], # ignore outliers\n",
" split=True,\n",
" inner=\"quartile\",\n",
" palette=colors,\n",
Expand Down Expand Up @@ -724,12 +724,12 @@
"axis.set_xticklabels(\n",
" [\n",
" \"DiffDock-L\",\n",
" \"DiffDock-L (Relax-P)\",\n",
" \"DiffDock-L-Relax-Prot\",\n",
" \"FABind\",\n",
" \"DL-based blind\",\n",
" \"DynamicBind\",\n",
" \"NeuralPLexer\",\n",
" \"RoseTTAFold-All-Atom\",\n",
" \"RoseTTAFold-AA\",\n",
" \"DiffDock-L-Vina\",\n",
" \"Conventional blind\",\n",
" \"P2Rank-Vina\",\n",
Expand Down Expand Up @@ -791,7 +791,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.17"
"version": "3.10.14"
}
},
"nbformat": 4,
Expand Down
Binary file modified notebooks/dockgen_single_ligand_relaxed_bar_chart.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified notebooks/dockgen_single_ligand_relaxed_rmsd_violin_plot.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading

0 comments on commit 662d147

Please sign in to comment.