
Commit

update tutorials
Steven51516 committed Aug 16, 2024
1 parent 8661530 commit eaace97
Showing 3 changed files with 139 additions and 55 deletions.
147 changes: 98 additions & 49 deletions tutorials/FlexMol_102_Dual_Encoder.ipynb
@@ -91,178 +91,192 @@
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 0: 100%|██████████| 1/1 [00:02<00:00, 2.00s/batch, loss=0.672]\n"
"Epoch 0: 0%| | 0/1 [00:00<?, ?batch/s]"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 0: 100%|██████████| 1/1 [00:01<00:00, 1.11s/batch, loss=0.664]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 0 \tTraining Loss: 0.671772\n",
"Epoch: 0 \tValidation Loss: 2.133007\n",
"Epoch: 0 \tValidation roc-auc: 0.7872\n"
"Epoch: 0 \tTraining Loss: 0.663608\n",
"Epoch: 0 \tValidation Loss: 0.653195\n",
"Epoch: 0 \tValidation roc-auc: 0.5957\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 1: 100%|██████████| 1/1 [00:01<00:00, 1.59s/batch, loss=0.652]\n"
"Epoch 1: 100%|██████████| 1/1 [00:00<00:00, 11.10batch/s, loss=0.658]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 1 \tTraining Loss: 0.652205\n",
"Epoch: 1 \tValidation Loss: 1.683086\n",
"Epoch: 1 \tValidation roc-auc: 0.7801\n"
"Epoch: 1 \tTraining Loss: 0.658224\n",
"Epoch: 1 \tValidation Loss: 0.624172\n",
"Epoch: 1 \tValidation roc-auc: 0.4965\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 2: 100%|██████████| 1/1 [00:01<00:00, 1.69s/batch, loss=0.643]\n"
"Epoch 2: 100%|██████████| 1/1 [00:00<00:00, 11.93batch/s, loss=0.641]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 2 \tTraining Loss: 0.642842\n",
"Epoch: 2 \tValidation Loss: 1.383644\n",
"Epoch: 2 \tValidation roc-auc: 0.7660\n"
"Epoch: 2 \tTraining Loss: 0.640507\n",
"Epoch: 2 \tValidation Loss: 0.613853\n",
"Epoch: 2 \tValidation roc-auc: 0.3759\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 3: 100%|██████████| 1/1 [00:01<00:00, 1.19s/batch, loss=0.641]\n"
"Epoch 3: 100%|██████████| 1/1 [00:00<00:00, 10.70batch/s, loss=0.639]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 3 \tTraining Loss: 0.641445\n",
"Epoch: 3 \tValidation Loss: 1.163450\n",
"Epoch: 3 \tValidation roc-auc: 0.7305\n"
"Epoch: 3 \tTraining Loss: 0.638680\n",
"Epoch: 3 \tValidation Loss: 0.617383\n",
"Epoch: 3 \tValidation roc-auc: 0.3546\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 4: 100%|██████████| 1/1 [00:01<00:00, 1.69s/batch, loss=0.641]\n"
"Epoch 4: 100%|██████████| 1/1 [00:00<00:00, 12.06batch/s, loss=0.638]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 4 \tTraining Loss: 0.641228\n",
"Epoch: 4 \tValidation Loss: 1.007522\n",
"Epoch: 4 \tValidation roc-auc: 0.6738\n"
"Epoch: 4 \tTraining Loss: 0.638139\n",
"Epoch: 4 \tValidation Loss: 0.625347\n",
"Epoch: 4 \tValidation roc-auc: 0.3191\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 5: 100%|██████████| 1/1 [00:01<00:00, 1.74s/batch, loss=0.641]\n"
"Epoch 5: 100%|██████████| 1/1 [00:00<00:00, 12.16batch/s, loss=0.638]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 5 \tTraining Loss: 0.641124\n",
"Epoch: 5 \tValidation Loss: 0.912583\n",
"Epoch: 5 \tValidation roc-auc: 0.5816\n"
"Epoch: 5 \tTraining Loss: 0.637545\n",
"Epoch: 5 \tValidation Loss: 0.641414\n",
"Epoch: 5 \tValidation roc-auc: 0.3333\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 6: 100%|██████████| 1/1 [00:01<00:00, 1.86s/batch, loss=0.641]\n"
"Epoch 6: 100%|██████████| 1/1 [00:00<00:00, 11.63batch/s, loss=0.636]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 6 \tTraining Loss: 0.640793\n",
"Epoch: 6 \tValidation Loss: 0.864285\n",
"Epoch: 6 \tValidation roc-auc: 0.5745\n"
"Epoch: 6 \tTraining Loss: 0.636284\n",
"Epoch: 6 \tValidation Loss: 0.669903\n",
"Epoch: 6 \tValidation roc-auc: 0.3546\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 7: 100%|██████████| 1/1 [00:01<00:00, 1.51s/batch, loss=0.64]\n"
"Epoch 7: 100%|██████████| 1/1 [00:00<00:00, 11.79batch/s, loss=0.636]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 7 \tTraining Loss: 0.640439\n",
"Epoch: 7 \tValidation Loss: 0.829402\n",
"Epoch: 7 \tValidation roc-auc: 0.5816\n"
"Epoch: 7 \tTraining Loss: 0.635997\n",
"Epoch: 7 \tValidation Loss: 0.713974\n",
"Epoch: 7 \tValidation roc-auc: 0.3546\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 8: 100%|██████████| 1/1 [00:01<00:00, 1.74s/batch, loss=0.64]\n"
"Epoch 8: 100%|██████████| 1/1 [00:00<00:00, 11.87batch/s, loss=0.636]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 8 \tTraining Loss: 0.640099\n",
"Epoch: 8 \tValidation Loss: 0.805460\n",
"Epoch: 8 \tValidation roc-auc: 0.6099\n"
"Epoch: 8 \tTraining Loss: 0.635855\n",
"Epoch: 8 \tValidation Loss: 0.764850\n",
"Epoch: 8 \tValidation roc-auc: 0.3759\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 9: 100%|██████████| 1/1 [00:01<00:00, 1.63s/batch, loss=0.64]\n"
"Epoch 9: 100%|██████████| 1/1 [00:00<00:00, 12.12batch/s, loss=0.635]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 9 \tTraining Loss: 0.639802\n",
"Epoch: 9 \tValidation Loss: 0.795091\n",
"Epoch: 9 \tValidation roc-auc: 0.4894\n"
"Epoch: 9 \tTraining Loss: 0.635422\n",
"Epoch: 9 \tValidation Loss: 0.826206\n",
"Epoch: 9 \tValidation roc-auc: 0.3617\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 10: 100%|██████████| 1/1 [00:01<00:00, 1.39s/batch, loss=0.64]\n"
"Epoch 10: 100%|██████████| 1/1 [00:00<00:00, 12.23batch/s, loss=0.635]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 10 \tTraining Loss: 0.639539\n",
"Epoch: 10 \tValidation Loss: 0.793096\n",
"Epoch: 10 \tValidation roc-auc: 0.3404\n",
"Epoch: 10 \tTraining Loss: 0.634938\n",
"Epoch: 10 \tValidation Loss: 0.897588\n",
"Epoch: 10 \tValidation roc-auc: 0.3546\n",
"Early stopping triggered after 10 epochs.\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
@@ -272,25 +286,60 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Start testing...\n",
"Test Loss: 0.800313\n",
"accuracy: 0.080000\n",
"precision: 0.041667\n",
"Test Loss: 0.978239\n",
"accuracy: 0.060000\n",
"precision: 0.040816\n",
"recall: 1.000000\n",
"f1: 0.080000\n"
"f1: 0.078431\n"
]
}
],
"source": [
"trainer.test(test_data)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# Perform inference on the test data using the trained model\n",
"# This returns the total loss, all true labels, and all model predictions\n",
"# Note: You can now compute custom metrics using 'all_labels' and 'all_predictions'\n",
"total_loss, all_labels, all_predictions = trainer.inference(trainer.create_loader(test_data))\n"
]
},
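The raw `all_labels` and `all_predictions` returned by `trainer.inference` can feed any custom metric. As a minimal self-contained sketch (the label/prediction values below are invented, not taken from this run), ROC-AUC can be computed directly by pair counting:

```python
def roc_auc(labels, preds):
    """Pair-counting ROC-AUC: the fraction of (positive, negative)
    pairs ranked correctly, counting ties as half-correct."""
    pos = [p for y, p in zip(labels, preds) if y == 1.0]
    neg = [p for y, p in zip(labels, preds) if y == 0.0]
    wins = sum((p > n) + 0.5 * (p == n) for p in pos for n in neg)
    return wins / (len(pos) * len(neg))

# Hypothetical values shaped like trainer.inference's return:
all_labels = [0.0, 0.0, 1.0, 0.0, 1.0, 0.0]
all_predictions = [0.2, 0.8, 0.7, 0.4, 0.9, 0.1]
print(f"roc-auc: {roc_auc(all_labels, all_predictions):.4f}")  # → roc-auc: 0.8750
```

In practice you would more likely call `sklearn.metrics.roc_auc_score(all_labels, all_predictions)`; the hand-rolled version above just makes the pairwise-ranking definition explicit.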
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Ground Truth Labels:\n",
"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]\n",
"All_Predictions:\n",
"[0.6218882, 0.6287671, 0.6197413, 0.5495148, 0.643523, 0.62342805, 0.5295856, 0.4806529, 0.5538527, 0.63042957, 0.5538419, 0.57894975, 0.72568244, 0.6414241, 0.6840686, 0.56393874, 0.73658043, 0.64223135, 0.61184186, 0.5978523, 0.5393544, 0.74992156, 0.62433743, 0.6433823, 0.52405965, 0.6406623, 0.5763875, 0.57649946, 0.59156555, 0.5964469, 0.57781327, 0.5427243, 0.6857456, 0.60090387, 0.5776537, 0.59034, 0.70655006, 0.7228107, 0.6243686, 0.5771154, 0.70688933, 0.57815194, 0.6854482, 0.70658773, 0.54950184, 0.72565895, 0.59068215, 0.6951169, 0.71074164, 0.7069677]\n"
]
}
],
"source": [
"print(\"Ground Truth Labels:\")\n",
"print(all_labels)\n",
"print(\"All_Predictions:\")\n",
"print(all_predictions)"
]
}
],
"metadata": {
45 changes: 40 additions & 5 deletions tutorials/FlexMol_Dataloading.ipynb
@@ -12,6 +12,18 @@
"from FlexMol.dataset import *"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Drug-Target Interaction (DTI) Data:** \n",
" You can load DTI data using the `load_DTI` function. Ensure that your file contains a header with at least the following columns:\n",
" - `Drug`\n",
" - `Protein`\n",
" - `Y` (Interaction label) \n",
" Optionally, include a `Protein_ID` column if you plan to use a protein structure encoder."
]
},
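Under those column requirements, a minimal DTI file might look like the sketch below (the SMILES string is aspirin; the protein fragment and ID are illustrative placeholders, and the commented `load_DTI` call assumes the signature shown in this tutorial):

```python
import csv
import io

# A minimal, hypothetical DTI file: a header row plus one record.
# Protein_ID is optional and only needed for structure encoders.
dti_csv = io.StringIO()
writer = csv.writer(dti_csv)
writer.writerow(["Drug", "Protein", "Y", "Protein_ID"])
writer.writerow(["CC(=O)Oc1ccccc1C(=O)O", "MKTAYIAKQR", 1, "P00533"])
dti_csv.seek(0)

header = next(csv.reader(dti_csv))
assert {"Drug", "Protein", "Y"} <= set(header)  # required columns present
# With a real file on disk, loading would look roughly like:
#   DTI = load_DTI("dti.csv")
```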
{
"cell_type": "code",
"execution_count": 18,
@@ -50,6 +62,17 @@
"print(DTI.head())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Drug-Drug Interaction (DDI) Data:** \n",
" You can load DDI data using the `load_DDI` function. This function supports any file format readable by `pd.read_csv`. Ensure that the first line of your file contains a header with at least the following columns:\n",
" - `Drug1`\n",
" - `Drug2`\n",
" - `Y` (Interaction label)"
]
},
{
"cell_type": "code",
"execution_count": 19,
@@ -85,6 +108,18 @@
"print(DDI.head())\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Protein-Protein Interaction (PPI) Data:** \n",
" You can load PPI data using the `load_PPI` function. Ensure that your file contains a header with at least the following columns:\n",
" - `Protein1`\n",
" - `Protein2`\n",
" - `Y` (Interaction label) \n",
" Optionally, include `Protein1_ID` and `Protein2_ID` columns if you plan to use protein structure encoders."
]
},
{
"cell_type": "code",
"execution_count": 20,
@@ -207,11 +242,11 @@
]
},
{
"cell_type": "code",
"execution_count": null,
"cell_type": "markdown",
"metadata": {},
"outputs": [],
"source": []
"source": [
"Alternatively, you can load your custom data directly into a pandas DataFrame using your preferred method. As long as the DataFrame contains the required columns (`Drug1`, `Drug2`, `Protein`, `Y`, etc.), you can proceed with the FlexMol pipeline without any issues."
]
}
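As a sketch of that DataFrame route for DDI data (the SMILES strings and labels below are invented for illustration), building the table by hand and checking the required columns before handing it to the pipeline:

```python
import pandas as pd

# Build a custom DDI table directly; only the column names matter
# to the downstream FlexMol pipeline (the values are made up).
DDI = pd.DataFrame({
    "Drug1": ["CCO", "CC(=O)O", "c1ccccc1"],
    "Drug2": ["CCN", "CCO", "CC(=O)O"],
    "Y": [1, 0, 1],
})

missing = {"Drug1", "Drug2", "Y"} - set(DDI.columns)
assert not missing, f"missing required columns: {missing}"
print(DDI.head())
```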
],
"metadata": {
@@ -230,7 +265,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.16"
"version": "3.10.13"
}
},
"nbformat": 4,
2 changes: 1 addition & 1 deletion tutorials/FlexMol_TDC_Interface_Demo.ipynb
@@ -43,7 +43,7 @@
"# We load the DrugBank dataset from TDC and obtain the data splits (train, validation, and test).\n",
"\n",
"data = DDI(name='DrugBank')\n",
"split = data.get_split()\n",
"split = data.get_split()  # choose other split methods if needed, e.g. scaffold, cold_protein, cold_drug, ...\n",
"\n",
"# Print the columns of the train split to understand the data structure.\n",
"print(\"Columns in the train split:\")\n",
