Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft: TXT2KG w/ hotpot_qa.py and tech_qa.py examples #9846

Open
wants to merge 279 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
279 commits
Select commit Hold shift + click to select a range
5775885
note about retriever bug for customers taking early look
puririshi98 Jan 7, 2025
032c2cd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 7, 2025
c010843
adding tqdm to retrieval step
Jan 7, 2025
02298b4
comments and clean up (#9922)
puririshi98 Jan 7, 2025
09c6013
Merge branch 'master' into rebase-txt2kg
puririshi98 Jan 7, 2025
dc795a6
Update CHANGELOG.md
puririshi98 Jan 7, 2025
d6c8eeb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 7, 2025
325ae68
cleanup (#9923)
puririshi98 Jan 8, 2025
8ffda39
cleanup
Jan 8, 2025
737ff16
removing rag generate example which is now subsumed by hotpot qa example
Jan 8, 2025
aec5398
reorg of readme
Jan 8, 2025
5c37dba
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 8, 2025
1097ff5
updating hyperparams for knn+neighborsampling, necesarry to get the g…
Jan 8, 2025
857d35a
Merge branch 'rebase-txt2kg' of https://github.com/pyg-team/pytorch_g…
Jan 8, 2025
8074555
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 8, 2025
b1ff4b7
updating hyperparams for knn+neighborsampling, necesarry to get the g…
Jan 9, 2025
6960715
Merge branch 'rebase-txt2kg' of https://github.com/pyg-team/pytorch_g…
Jan 9, 2025
5b8e0f5
commenting out part that makes it so that the selected src/dst nodes …
Jan 9, 2025
4a5ee11
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 9, 2025
4c280ed
adding comments about retriever hyperparams being important, tuning m…
Jan 9, 2025
ec71020
Merge branch 'rebase-txt2kg' of https://github.com/pyg-team/pytorch_g…
Jan 9, 2025
53572aa
removing seed nodes from the ragquerylaoder and feature store since t…
Jan 9, 2025
0de2b2f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 9, 2025
d4295c0
hyperparams for pcst
Jan 10, 2025
4097e45
Merge branch 'rebase-txt2kg' of https://github.com/pyg-team/pytorch_g…
Jan 10, 2025
5a8111e
relation should be lower to in preprocess triplets
Jan 10, 2025
c2d5d7e
retriever hyperparams
Jan 10, 2025
f9c522f
potential bugfix and some cleanups
Jan 10, 2025
f68b8f9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 10, 2025
270cd2c
Merge branch 'master' into rebase-txt2kg
puririshi98 Jan 10, 2025
103aeb1
potential bugfix and some cleanups
Jan 10, 2025
b0a7049
typo fix
Jan 10, 2025
7cd5138
typo fix
Jan 10, 2025
5fabf44
typo fix
Jan 10, 2025
0999e53
typo fix
Jan 10, 2025
6b74594
typo fix
Jan 10, 2025
8964588
tqdm for triplets
Jan 10, 2025
02c53ed
tqdm for triplets
Jan 10, 2025
139e0f8
tqdm for triplets
Jan 10, 2025
9da04ba
not using seed edges
Jan 10, 2025
e8dd95d
not using seed edges
Jan 10, 2025
d092411
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 10, 2025
b7742c0
speedup indexing
Jan 10, 2025
ece8b61
Merge branch 'rebase-txt2kg' of https://github.com/pyg-team/pytorch_g…
Jan 10, 2025
6d95bbf
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 10, 2025
4bd4e08
speedup indexing
Jan 10, 2025
3e0a0de
Merge branch 'rebase-txt2kg' of https://github.com/pyg-team/pytorch_g…
Jan 10, 2025
725cab0
speedup indexing
Jan 10, 2025
703c83d
speedup indexing
Jan 10, 2025
538eeba
speedup indexing
Jan 10, 2025
f1a3e7a
speedup indexing
Jan 10, 2025
9be12dc
speedup indexing
Jan 10, 2025
3a0bee2
speedup indexing
Jan 10, 2025
1cc5091
speedup indexing
Jan 10, 2025
845f186
debug
Jan 10, 2025
40caac7
debug
Jan 10, 2025
f4a5edc
debug
Jan 10, 2025
218f418
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 10, 2025
741a750
debug
Jan 10, 2025
0f46bec
Merge branch 'rebase-txt2kg' of https://github.com/pyg-team/pytorch_g…
Jan 10, 2025
14e962d
debug
Jan 10, 2025
8dd0b70
debug
Jan 10, 2025
3b6ecbe
Update hotpot_qa.py
puririshi98 Jan 10, 2025
b0075cd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 10, 2025
0716668
debug
Jan 10, 2025
1ecdaa7
debug
Jan 10, 2025
21c5875
debug
Jan 10, 2025
88ce2b8
debug
Jan 10, 2025
1d3030e
debug
Jan 10, 2025
454bf0a
debug
Jan 10, 2025
e6057a8
debug
Jan 10, 2025
3841de2
debug
Jan 10, 2025
9c2850b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 10, 2025
ff63faf
debug
Jan 10, 2025
3acba5e
debug
Jan 10, 2025
4b1b06b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 10, 2025
c207009
debug
Jan 10, 2025
4d71797
Merge branch 'rebase-txt2kg' of https://github.com/pyg-team/pytorch_g…
Jan 10, 2025
c2670ed
debug
Jan 10, 2025
19b91b3
debug
Jan 10, 2025
68833b2
debug
Jan 10, 2025
8c9a041
debug
Jan 10, 2025
0d3318b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 10, 2025
f0d3710
debug
Jan 10, 2025
f959254
debug
Jan 10, 2025
fa90418
debug
Jan 10, 2025
bf2e13e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 10, 2025
34fb3da
debug
Jan 10, 2025
803c29b
debug
Jan 10, 2025
00d1e7e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 10, 2025
fd9f099
debug
Jan 10, 2025
518c5b5
debug
Jan 10, 2025
3632cec
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 10, 2025
3a7d7b2
debug
Jan 10, 2025
1ca06cd
debug
Jan 10, 2025
231e9f4
debug
Jan 10, 2025
17933a0
debug
Jan 10, 2025
7db3248
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 10, 2025
7b2d2f6
debug
Jan 10, 2025
673f8fb
Merge branch 'rebase-txt2kg' of https://github.com/pyg-team/pytorch_g…
Jan 10, 2025
fc18b59
debug
Jan 10, 2025
9a09c98
debug
Jan 10, 2025
955304e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 10, 2025
bb6512a
debug
Jan 11, 2025
976d175
Merge branch 'rebase-txt2kg' of https://github.com/pyg-team/pytorch_g…
Jan 11, 2025
b6d80b0
debug
Jan 11, 2025
eba0173
debug
Jan 11, 2025
eeb2a43
debug
Jan 11, 2025
954490e
debug
Jan 11, 2025
8ac4974
debug
Jan 11, 2025
d1c9e13
debug
Jan 11, 2025
c3dc6f5
debug
Jan 11, 2025
8579159
debug
Jan 11, 2025
2670959
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 11, 2025
961ff44
debug
Jan 11, 2025
44e676a
Merge branch 'rebase-txt2kg' of https://github.com/pyg-team/pytorch_g…
Jan 11, 2025
644bcc3
debug
Jan 11, 2025
44a17f3
debug
Jan 11, 2025
105783c
debug
Jan 11, 2025
c8415c1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 11, 2025
628324d
debug
Jan 11, 2025
a815db3
Merge branch 'rebase-txt2kg' of https://github.com/pyg-team/pytorch_g…
Jan 11, 2025
1c75a75
debug
Jan 12, 2025
217ec15
debug
Jan 12, 2025
b6aa2c3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 12, 2025
1f15921
debug
Jan 13, 2025
bb5d4b1
debug
Jan 13, 2025
f077e26
debug
Jan 13, 2025
7a68eb6
debug
Jan 13, 2025
777c1d1
debug
Jan 13, 2025
571e09c
debug
Jan 13, 2025
5aabf3c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 13, 2025
b1c1dc4
debug
Jan 13, 2025
2eae672
Merge branch 'rebase-txt2kg' of https://github.com/pyg-team/pytorch_g…
Jan 13, 2025
d7dab0e
debug
Jan 13, 2025
5f24624
debug
Jan 13, 2025
ab187ee
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 13, 2025
6c2cf17
debug
Jan 13, 2025
4a83a6a
Merge branch 'rebase-txt2kg' of https://github.com/pyg-team/pytorch_g…
Jan 13, 2025
335d628
debug
Jan 13, 2025
9a3967a
debug
Jan 14, 2025
9747aa2
debug
Jan 14, 2025
5a32171
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 14, 2025
8185846
debug
Jan 14, 2025
4e702da
Merge branch 'rebase-txt2kg' of https://github.com/pyg-team/pytorch_g…
Jan 14, 2025
0fffa39
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 14, 2025
5b2f845
debug
Jan 14, 2025
8ac7adc
Merge branch 'rebase-txt2kg' of https://github.com/pyg-team/pytorch_g…
Jan 14, 2025
a00cc1b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 14, 2025
d7a78db
debug
Jan 14, 2025
fe7e3e5
debug
Jan 14, 2025
9d873f4
debug
Jan 14, 2025
571b0e7
debug
Jan 14, 2025
e28c88f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 14, 2025
76db078
debug
Jan 14, 2025
f1815ee
debug
Jan 14, 2025
8359d78
debug
Jan 14, 2025
ba911b8
debug
Jan 14, 2025
dcf2cc0
debug
Jan 14, 2025
767526e
debug
Jan 14, 2025
1418cdc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 14, 2025
a5bd9d7
debug
Jan 14, 2025
bb7d0f9
Merge branch 'rebase-txt2kg' of https://github.com/pyg-team/pytorch_g…
Jan 14, 2025
bb26487
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 14, 2025
04519b0
debug
Jan 14, 2025
ae84691
Merge branch 'rebase-txt2kg' of https://github.com/pyg-team/pytorch_g…
Jan 14, 2025
a760b5a
debug
Jan 14, 2025
f9e73bf
debug
Jan 14, 2025
e1f369a
debug
Jan 14, 2025
ef925a8
debug
Jan 14, 2025
6350622
debug
Jan 14, 2025
2020bd0
debug
Jan 14, 2025
675e744
system works pretty well now, adding approx recall
Jan 14, 2025
497e65b
debug
Jan 14, 2025
8371afb
debug
Jan 14, 2025
0a0a090
debug
Jan 14, 2025
2d00ed2
debug
Jan 14, 2025
cdc86cb
debug
Jan 14, 2025
57b61c7
debug
Jan 14, 2025
f0f8ae9
debug
Jan 14, 2025
9f16ec4
debug
Jan 14, 2025
fbc8243
final fix
Jan 14, 2025
c96e9d3
removed prints, done debugging. will tune hyperparams for retriever p…
Jan 14, 2025
c76686e
removed prints, done debugging. will tune hyperparams for retriever p…
Jan 14, 2025
fa445a5
Merge branch 'master' into rebase-txt2kg
puririshi98 Jan 15, 2025
3ecccda
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 15, 2025
63ea8b2
removing unneeded comments
puririshi98 Jan 15, 2025
2c8ec46
cleaning
puririshi98 Jan 15, 2025
7954786
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 15, 2025
5ab3299
Update graph_store.py
puririshi98 Jan 16, 2025
f9c0e9c
clean
puririshi98 Jan 16, 2025
423891d
improving defaults for hotpotQA retrieval, will continue to tune afte…
puririshi98 Jan 16, 2025
bb18b77
updating txt2kg system prompt, last sentence caused wierd behavior
puririshi98 Jan 16, 2025
5addd93
Update hotpot_qa.py
puririshi98 Jan 16, 2025
a5faf9b
Update hotpot_qa.py
puririshi98 Jan 16, 2025
a7e044b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 16, 2025
8f984d1
Update txt2kg.py
puririshi98 Jan 17, 2025
d44f280
fixing typo
puririshi98 Jan 20, 2025
6e3aefc
Update backend_utils.py
puririshi98 Jan 21, 2025
9e8d5cb
fix for bellow issue
puririshi98 Jan 21, 2025
4eb0ed8
Update backend_utils.py
puririshi98 Jan 21, 2025
fa5c659
Update txt2kg.py
puririshi98 Jan 21, 2025
0969d67
commenting and usability
Jan 22, 2025
d5e49a6
commenting and usability
Jan 22, 2025
0cc3ef5
commenting and usability
Jan 22, 2025
0f19922
skip pcst for bad graphs
Jan 22, 2025
4f28162
Merge branch 'master' into rebase-txt2kg
puririshi98 Jan 22, 2025
e7529a6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 22, 2025
ec702db
commenting and usability
Jan 22, 2025
3a7cc3f
Merge branch 'rebase-txt2kg' of https://github.com/pyg-team/pytorch_g…
Jan 22, 2025
29523fe
commenting and usability
Jan 22, 2025
3904f51
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 22, 2025
1ed1c18
commenting and usability
Jan 22, 2025
718d82f
commenting and usability
Jan 22, 2025
ff9c3fa
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 22, 2025
5521aea
commenting and usability
Jan 22, 2025
6de0788
commenting and usability
Jan 22, 2025
7309bc6
commenting and usability
Jan 22, 2025
3c93c41
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 22, 2025
88a7986
drafting techqa
Jan 22, 2025
8f7a7dd
drafting techqa
Jan 22, 2025
6c9a0db
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 22, 2025
c1aad3f
drafting techqa
Jan 22, 2025
dd026fb
drafting techqa
Jan 22, 2025
1deb4f0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 22, 2025
e1581e1
drafting techqa
Jan 22, 2025
400eea0
Merge branch 'rebase-txt2kg' of https://github.com/pyg-team/pytorch_g…
Jan 22, 2025
0cbd525
drafting techqa
Jan 22, 2025
2484d2d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 22, 2025
afb0ac5
drafting techqa
Jan 22, 2025
741bf2e
drafting techqa
Jan 22, 2025
f8c0ee1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 22, 2025
c7b5462
drafting techqa
Jan 22, 2025
b491aae
Merge branch 'rebase-txt2kg' of https://github.com/pyg-team/pytorch_g…
Jan 22, 2025
75157bd
drafting techqa
Jan 22, 2025
04c80a0
drafting techqa
Jan 22, 2025
2826852
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 22, 2025
0b081fe
drafting techqa
Jan 22, 2025
3bba220
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 22, 2025
f0aec29
drafting techqa
Jan 22, 2025
db6cee7
Merge branch 'rebase-txt2kg' of https://github.com/pyg-team/pytorch_g…
Jan 22, 2025
6bbb6ae
drafting techqa
Jan 22, 2025
9cf3a89
drafting techqa
Jan 22, 2025
451ae4d
add techqa to Readme
Jan 22, 2025
2111b02
ai cleanup (#9970)
puririshi98 Jan 23, 2025
2ff6a85
ai cleanup for utils (#9971)
puririshi98 Jan 23, 2025
afe44d6
Merge branch 'master' into rebase-txt2kg
puririshi98 Jan 23, 2025
c7a5e4d
fix for "Note:"
puririshi98 Jan 24, 2025
024cabe
Merge branch 'master' into rebase-txt2kg
puririshi98 Jan 24, 2025
cdc492f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 24, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Adds TXT2KG class with example on HotPotQA ([#9846](https://github.com/pyg-team/pytorch_geometric/pull/9846))
- Added `InstructMol` dataset ([#9975](https://github.com/pyg-team/pytorch_geometric/pull/9975))
- Added support for weighted `LinkPredRecall` metric ([#9947](https://github.com/pyg-team/pytorch_geometric/pull/9947))
- Added support for weighted `LinkPredNDCG` metric ([#9945](https://github.com/pyg-team/pytorch_geometric/pull/9945))
Expand Down
19 changes: 10 additions & 9 deletions examples/llm/README.md
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
# Examples for Co-training LLMs and GNNs

| Example | Description |
| -------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| [`g_retriever.py`](./g_retriever.py) | Example for Retrieval-Augmented Generation (RAG) w/ GNN+LLM by co-training `LLAMA2` with `GAT` for answering questions based on knowledge graph information |
| [`g_retriever_utils/`](./g_retriever_utils/) | Contains multiple scripts for benchmarking GRetriever's architecture and evaluating different retrieval methods. |
| [`multihop_rag/`](./multihop_rag/) | Contains starter code and an example run for building a Multi-hop dataset using WikiHop5M and 2WikiMultiHopQA |
| [`nvtx_examples/`](./nvtx_examples/) | Contains examples of how to wrap functions using the NVTX profiler for CUDA runtime analysis. |
| [`molecule_gpt.py`](./molecule_gpt.py) | Example for MoleculeGPT: Instruction Following Large Language Models for Molecular Property Prediction. Supports MoleculeGPT and InstructMol dataset |
| [`glem.py`](./glem.py) | Example for [GLEM](https://arxiv.org/abs/2210.14709), a GNN+LLM co-training model via variational Expectation-Maximization (EM) framework on node classification tasks to achieve SOTA results |
| [`git_mol.py`](./git_mol.py) | Example for GIT-Mol: A Multi-modal Large Language Model for Molecular Science with Graph, Image, and Text |
| Example | Description |
| -------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| [`g_retriever.py`](./g_retriever.py) | Example for Retrieval-Augmented Generation (RAG) w/ GNN+LLM by co-training `LLAMA2` with `GAT` for answering questions based on knowledge graph information |
| [`g_retriever_utils/`](./g_retriever_utils/) | Contains multiple scripts for benchmarking GRetriever's architecture and evaluating different retrieval methods. |
| [`nvtx_examples/`](./nvtx_examples/) | Contains examples of how to wrap functions using the NVTX profiler for CUDA runtime analysis. |
| [`molecule_gpt.py`](./molecule_gpt.py) | Example for MoleculeGPT: Instruction Following Large Language Models for Molecular Property Prediction. Supports MoleculeGPT and InstructMol dataset |
| [`glem.py`](./glem.py) | Example for [GLEM](https://arxiv.org/abs/2210.14709), a GNN+LLM co-training model via variational Expectation-Maximization (EM) framework on node classification tasks to achieve SOTA results |
| [`git_mol.py`](./git_mol.py) | Example for GIT-Mol: A Multi-modal Large Language Model for Molecular Science with Graph, Image, and Text |
| [`hotpot_qa.py`](./hotpot_qa.py) | Example for converting adapting the retrieval step of conventional Retrieval-Augmented Generation (RAG) for use with G-retriever, and how to approximate the precision/recall of a subgraph retrieval method. Uses the HotPotQA dataset from [Hugging Face](https://huggingface.co/datasets/hotpotqa/hotpot_qa). This is it is multihop in nature. |
| [`tech_qa.py`](./tech_qa.py) | Full end 2 end GraphRAG pipeline combining txt2kg and retrieval from `hotpot_qa.py` and training/testing from g_retriever.py. Uses the techQA dataset from [Hugging Face](https://huggingface.co/datasets/rojagtap/tech-qa) |
52 changes: 27 additions & 25 deletions examples/llm/g_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def load_params_dict(model, save_path):
return model


def get_loss(model, batch, model_save_name: str) -> Tensor:
def get_loss(model, batch, model_save_name="gnn+llm") -> Tensor:
"""Compute the loss for a given model and batch of data.

Args:
Expand All @@ -158,7 +158,7 @@ def get_loss(model, batch, model_save_name: str) -> Tensor:
)


def inference_step(model, batch, model_save_name):
def inference_step(model, batch, model_save_name="gnn+llm"):
"""Performs inference on a given batch of data using the provided model.

Args:
Expand All @@ -184,6 +184,29 @@ def inference_step(model, batch, model_save_name):
)


def adjust_learning_rate(param_group, LR, epoch, num_epochs):
"""Decay learning rate with half-cycle cosine after warmup.

Args:
param_group (dict): Parameter group.
LR (float): Learning rate.
epoch (int): Current epoch.

Returns:
float: Adjusted learning rate.
"""
min_lr = 5e-6
warmup_epochs = 1
if epoch < warmup_epochs:
lr = LR
else:
lr = min_lr + (LR - min_lr) * 0.5 * (
1.0 + math.cos(math.pi * (epoch - warmup_epochs) /
(num_epochs - warmup_epochs)))
param_group['lr'] = lr
return lr


def train(
num_epochs, # Total number of training epochs
hidden_channels, # Number of hidden channels in GNN
Expand Down Expand Up @@ -216,28 +239,6 @@ def train(
Returns:
None
"""
def adjust_learning_rate(param_group, LR, epoch):
"""Decay learning rate with half-cycle cosine after warmup.

Args:
param_group (dict): Parameter group.
LR (float): Learning rate.
epoch (int): Current epoch.

Returns:
float: Adjusted learning rate.
"""
min_lr = 5e-6
warmup_epochs = 1
if epoch < warmup_epochs:
lr = LR
else:
lr = min_lr + (LR - min_lr) * 0.5 * (
1.0 + math.cos(math.pi * (epoch - warmup_epochs) /
(num_epochs - warmup_epochs)))
param_group['lr'] = lr
return lr

# Start training time
start_time = time.time()

Expand Down Expand Up @@ -322,7 +323,8 @@ def adjust_learning_rate(param_group, LR, epoch):

if (step + 1) % 2 == 0:
adjust_learning_rate(optimizer.param_groups[0], lr,
step / len(train_loader) + epoch)
step / len(train_loader) + epoch,
num_epochs)

optimizer.step()
epoch_loss = epoch_loss + float(loss)
Expand Down
11 changes: 4 additions & 7 deletions examples/llm/g_retriever_utils/README.md
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
# Examples for LLM and GNN co-training

| Example | Description |
| ---------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| [`rag_feature_store.py`](./rag_feature_store.py) | A Proof of Concept Implementation of a RAG enabled FeatureStore that can serve as a starting point for implementing a custom RAG Remote Backend |
| [`rag_graph_store.py`](./rag_graph_store.py) | A Proof of Concept Implementation of a RAG enabled GraphStore that can serve as a starting point for implementing a custom RAG Remote Backend |
| [`rag_backend_utils.py`](./rag_backend_utils.py) | Utility functions used for loading a series of Knowledge Graph Triplets into the Remote Backend defined by a FeatureStore and GraphStore |
| [`rag_generate.py`](./rag_generate.py) | Script for generating a unique set of subgraphs from the WebQSP dataset using a custom defined retrieval algorithm (defaults to the FeatureStore and GraphStore provided) |
| [`benchmark_model_archs_rag.py`](./benchmark_model_archs_rag.py) | Script for running a GNN/LLM benchmark on GRetriever while grid searching relevent architecture parameters and datasets. |
| Example | Description |
| ---------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------ |
| [`benchmark_model_archs_rag.py`](./benchmark_model_archs_rag.py) | Script for running a GNN/LLM benchmark on GRetriever while grid searching relevent architecture parameters and datasets. |
| [`minimal_demo.py`](./minimal_demo.py) | Minimal demo for WebQSP dataset comparing GNN+LLM vs LLM |

NOTE: Evaluating performance on GRetriever with smaller sample sizes may result in subpar performance. It is not unusual for the fine-tuned model/LLM to perform worse than an untrained LLM on very small sample sizes.
139 changes: 0 additions & 139 deletions examples/llm/g_retriever_utils/rag_generate.py

This file was deleted.

Loading
Loading