Draft: TXT2KG w/ hotpot_qa.py and tech_qa.py examples #9846

Closed
wants to merge 327 commits into from
Changes from 3 commits
Commits
703c83d
speedup indexing
puririshi98 Jan 10, 2025
538eeba
speedup indexing
puririshi98 Jan 10, 2025
f1a3e7a
speedup indexing
puririshi98 Jan 10, 2025
9be12dc
speedup indexing
puririshi98 Jan 10, 2025
3a0bee2
speedup indexing
puririshi98 Jan 10, 2025
1cc5091
speedup indexing
puririshi98 Jan 10, 2025
845f186
debug
puririshi98 Jan 10, 2025
40caac7
debug
puririshi98 Jan 10, 2025
f4a5edc
debug
puririshi98 Jan 10, 2025
218f418
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 10, 2025
741a750
debug
puririshi98 Jan 10, 2025
0f46bec
Merge branch 'rebase-txt2kg' of https://github.com/pyg-team/pytorch_g…
puririshi98 Jan 10, 2025
14e962d
debug
puririshi98 Jan 10, 2025
8dd0b70
debug
puririshi98 Jan 10, 2025
3b6ecbe
Update hotpot_qa.py
puririshi98 Jan 10, 2025
b0075cd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 10, 2025
0716668
debug
puririshi98 Jan 10, 2025
1ecdaa7
debug
puririshi98 Jan 10, 2025
21c5875
debug
puririshi98 Jan 10, 2025
88ce2b8
debug
puririshi98 Jan 10, 2025
1d3030e
debug
puririshi98 Jan 10, 2025
454bf0a
debug
puririshi98 Jan 10, 2025
e6057a8
debug
puririshi98 Jan 10, 2025
3841de2
debug
puririshi98 Jan 10, 2025
9c2850b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 10, 2025
ff63faf
debug
puririshi98 Jan 10, 2025
3acba5e
debug
puririshi98 Jan 10, 2025
4b1b06b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 10, 2025
c207009
debug
puririshi98 Jan 10, 2025
4d71797
Merge branch 'rebase-txt2kg' of https://github.com/pyg-team/pytorch_g…
puririshi98 Jan 10, 2025
c2670ed
debug
puririshi98 Jan 10, 2025
19b91b3
debug
puririshi98 Jan 10, 2025
68833b2
debug
puririshi98 Jan 10, 2025
8c9a041
debug
puririshi98 Jan 10, 2025
0d3318b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 10, 2025
f0d3710
debug
puririshi98 Jan 10, 2025
f959254
debug
puririshi98 Jan 10, 2025
fa90418
debug
puririshi98 Jan 10, 2025
bf2e13e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 10, 2025
34fb3da
debug
puririshi98 Jan 10, 2025
803c29b
debug
puririshi98 Jan 10, 2025
00d1e7e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 10, 2025
fd9f099
debug
puririshi98 Jan 10, 2025
518c5b5
debug
puririshi98 Jan 10, 2025
3632cec
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 10, 2025
3a7d7b2
debug
puririshi98 Jan 10, 2025
1ca06cd
debug
puririshi98 Jan 10, 2025
231e9f4
debug
puririshi98 Jan 10, 2025
17933a0
debug
puririshi98 Jan 10, 2025
7db3248
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 10, 2025
7b2d2f6
debug
puririshi98 Jan 10, 2025
673f8fb
Merge branch 'rebase-txt2kg' of https://github.com/pyg-team/pytorch_g…
puririshi98 Jan 10, 2025
fc18b59
debug
puririshi98 Jan 10, 2025
9a09c98
debug
puririshi98 Jan 10, 2025
955304e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 10, 2025
bb6512a
debug
puririshi98 Jan 11, 2025
976d175
Merge branch 'rebase-txt2kg' of https://github.com/pyg-team/pytorch_g…
puririshi98 Jan 11, 2025
b6d80b0
debug
puririshi98 Jan 11, 2025
eba0173
debug
puririshi98 Jan 11, 2025
eeb2a43
debug
puririshi98 Jan 11, 2025
954490e
debug
puririshi98 Jan 11, 2025
8ac4974
debug
puririshi98 Jan 11, 2025
d1c9e13
debug
puririshi98 Jan 11, 2025
c3dc6f5
debug
puririshi98 Jan 11, 2025
8579159
debug
puririshi98 Jan 11, 2025
2670959
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 11, 2025
961ff44
debug
puririshi98 Jan 11, 2025
44e676a
Merge branch 'rebase-txt2kg' of https://github.com/pyg-team/pytorch_g…
puririshi98 Jan 11, 2025
644bcc3
debug
puririshi98 Jan 11, 2025
44a17f3
debug
puririshi98 Jan 11, 2025
105783c
debug
puririshi98 Jan 11, 2025
c8415c1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 11, 2025
628324d
debug
puririshi98 Jan 11, 2025
a815db3
Merge branch 'rebase-txt2kg' of https://github.com/pyg-team/pytorch_g…
puririshi98 Jan 11, 2025
1c75a75
debug
puririshi98 Jan 12, 2025
217ec15
debug
puririshi98 Jan 12, 2025
b6aa2c3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 12, 2025
1f15921
debug
puririshi98 Jan 13, 2025
bb5d4b1
debug
puririshi98 Jan 13, 2025
f077e26
debug
puririshi98 Jan 13, 2025
7a68eb6
debug
puririshi98 Jan 13, 2025
777c1d1
debug
puririshi98 Jan 13, 2025
571e09c
debug
puririshi98 Jan 13, 2025
5aabf3c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 13, 2025
b1c1dc4
debug
puririshi98 Jan 13, 2025
2eae672
Merge branch 'rebase-txt2kg' of https://github.com/pyg-team/pytorch_g…
puririshi98 Jan 13, 2025
d7dab0e
debug
puririshi98 Jan 13, 2025
5f24624
debug
puririshi98 Jan 13, 2025
ab187ee
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 13, 2025
6c2cf17
debug
puririshi98 Jan 13, 2025
4a83a6a
Merge branch 'rebase-txt2kg' of https://github.com/pyg-team/pytorch_g…
puririshi98 Jan 13, 2025
335d628
debug
puririshi98 Jan 13, 2025
9a3967a
debug
puririshi98 Jan 14, 2025
9747aa2
debug
puririshi98 Jan 14, 2025
5a32171
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 14, 2025
8185846
debug
puririshi98 Jan 14, 2025
4e702da
Merge branch 'rebase-txt2kg' of https://github.com/pyg-team/pytorch_g…
puririshi98 Jan 14, 2025
0fffa39
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 14, 2025
5b2f845
debug
puririshi98 Jan 14, 2025
8ac7adc
Merge branch 'rebase-txt2kg' of https://github.com/pyg-team/pytorch_g…
puririshi98 Jan 14, 2025
a00cc1b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 14, 2025
d7a78db
debug
puririshi98 Jan 14, 2025
fe7e3e5
debug
puririshi98 Jan 14, 2025
9d873f4
debug
puririshi98 Jan 14, 2025
571b0e7
debug
puririshi98 Jan 14, 2025
e28c88f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 14, 2025
76db078
debug
puririshi98 Jan 14, 2025
f1815ee
debug
puririshi98 Jan 14, 2025
8359d78
debug
puririshi98 Jan 14, 2025
ba911b8
debug
puririshi98 Jan 14, 2025
dcf2cc0
debug
puririshi98 Jan 14, 2025
767526e
debug
puririshi98 Jan 14, 2025
1418cdc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 14, 2025
a5bd9d7
debug
puririshi98 Jan 14, 2025
bb7d0f9
Merge branch 'rebase-txt2kg' of https://github.com/pyg-team/pytorch_g…
puririshi98 Jan 14, 2025
bb26487
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 14, 2025
04519b0
debug
puririshi98 Jan 14, 2025
ae84691
Merge branch 'rebase-txt2kg' of https://github.com/pyg-team/pytorch_g…
puririshi98 Jan 14, 2025
a760b5a
debug
puririshi98 Jan 14, 2025
f9e73bf
debug
puririshi98 Jan 14, 2025
e1f369a
debug
puririshi98 Jan 14, 2025
ef925a8
debug
puririshi98 Jan 14, 2025
6350622
debug
puririshi98 Jan 14, 2025
2020bd0
debug
puririshi98 Jan 14, 2025
675e744
system works pretty well now, adding approx recall
puririshi98 Jan 14, 2025
497e65b
debug
puririshi98 Jan 14, 2025
8371afb
debug
puririshi98 Jan 14, 2025
0a0a090
debug
puririshi98 Jan 14, 2025
2d00ed2
debug
puririshi98 Jan 14, 2025
cdc86cb
debug
puririshi98 Jan 14, 2025
57b61c7
debug
puririshi98 Jan 14, 2025
f0f8ae9
debug
puririshi98 Jan 14, 2025
9f16ec4
debug
puririshi98 Jan 14, 2025
fbc8243
final fix
puririshi98 Jan 14, 2025
c96e9d3
removed prints, done debugging. will tune hyperparams for retriever p…
puririshi98 Jan 14, 2025
c76686e
removed prints, done debugging. will tune hyperparams for retriever p…
puririshi98 Jan 14, 2025
fa445a5
Merge branch 'master' into rebase-txt2kg
puririshi98 Jan 15, 2025
3ecccda
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 15, 2025
63ea8b2
removing unneeded comments
puririshi98 Jan 15, 2025
2c8ec46
cleaning
puririshi98 Jan 15, 2025
7954786
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 15, 2025
5ab3299
Update graph_store.py
puririshi98 Jan 16, 2025
f9c0e9c
clean
puririshi98 Jan 16, 2025
423891d
improving defaults for hotpotQA retrieval, will continue to tune afte…
puririshi98 Jan 16, 2025
bb18b77
updating txt2kg system prompt, last sentence caused wierd behavior
puririshi98 Jan 16, 2025
5addd93
Update hotpot_qa.py
puririshi98 Jan 16, 2025
a5faf9b
Update hotpot_qa.py
puririshi98 Jan 16, 2025
a7e044b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 16, 2025
8f984d1
Update txt2kg.py
puririshi98 Jan 17, 2025
d44f280
fixing typo
puririshi98 Jan 20, 2025
6e3aefc
Update backend_utils.py
puririshi98 Jan 21, 2025
9e8d5cb
fix for bellow issue
puririshi98 Jan 21, 2025
4eb0ed8
Update backend_utils.py
puririshi98 Jan 21, 2025
fa5c659
Update txt2kg.py
puririshi98 Jan 21, 2025
0969d67
commenting and usability
puririshi98 Jan 22, 2025
d5e49a6
commenting and usability
puririshi98 Jan 22, 2025
0cc3ef5
commenting and usability
puririshi98 Jan 22, 2025
0f19922
skip pcst for bad graphs
puririshi98 Jan 22, 2025
4f28162
Merge branch 'master' into rebase-txt2kg
puririshi98 Jan 22, 2025
e7529a6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 22, 2025
ec702db
commenting and usability
puririshi98 Jan 22, 2025
3a7cc3f
Merge branch 'rebase-txt2kg' of https://github.com/pyg-team/pytorch_g…
puririshi98 Jan 22, 2025
29523fe
commenting and usability
puririshi98 Jan 22, 2025
3904f51
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 22, 2025
1ed1c18
commenting and usability
puririshi98 Jan 22, 2025
718d82f
commenting and usability
puririshi98 Jan 22, 2025
ff9c3fa
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 22, 2025
5521aea
commenting and usability
puririshi98 Jan 22, 2025
6de0788
commenting and usability
puririshi98 Jan 22, 2025
7309bc6
commenting and usability
puririshi98 Jan 22, 2025
3c93c41
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 22, 2025
88a7986
drafting techqa
puririshi98 Jan 22, 2025
8f7a7dd
drafting techqa
puririshi98 Jan 22, 2025
6c9a0db
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 22, 2025
c1aad3f
drafting techqa
puririshi98 Jan 22, 2025
dd026fb
drafting techqa
puririshi98 Jan 22, 2025
1deb4f0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 22, 2025
e1581e1
drafting techqa
puririshi98 Jan 22, 2025
400eea0
Merge branch 'rebase-txt2kg' of https://github.com/pyg-team/pytorch_g…
puririshi98 Jan 22, 2025
0cbd525
drafting techqa
puririshi98 Jan 22, 2025
2484d2d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 22, 2025
afb0ac5
drafting techqa
puririshi98 Jan 22, 2025
741bf2e
drafting techqa
puririshi98 Jan 22, 2025
f8c0ee1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 22, 2025
c7b5462
drafting techqa
puririshi98 Jan 22, 2025
b491aae
Merge branch 'rebase-txt2kg' of https://github.com/pyg-team/pytorch_g…
puririshi98 Jan 22, 2025
75157bd
drafting techqa
puririshi98 Jan 22, 2025
04c80a0
drafting techqa
puririshi98 Jan 22, 2025
2826852
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 22, 2025
0b081fe
drafting techqa
puririshi98 Jan 22, 2025
3bba220
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 22, 2025
f0aec29
drafting techqa
puririshi98 Jan 22, 2025
db6cee7
Merge branch 'rebase-txt2kg' of https://github.com/pyg-team/pytorch_g…
puririshi98 Jan 22, 2025
6bbb6ae
drafting techqa
puririshi98 Jan 22, 2025
9cf3a89
drafting techqa
puririshi98 Jan 22, 2025
451ae4d
add techqa to Readme
puririshi98 Jan 22, 2025
2111b02
ai cleanup (#9970)
puririshi98 Jan 23, 2025
2ff6a85
ai cleanup for utils (#9971)
puririshi98 Jan 23, 2025
afe44d6
Merge branch 'master' into rebase-txt2kg
puririshi98 Jan 23, 2025
c7a5e4d
fix for "Note:"
puririshi98 Jan 24, 2025
024cabe
Merge branch 'master' into rebase-txt2kg
puririshi98 Jan 24, 2025
cdc492f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 24, 2025
4693e53
Update txt2kg.py
puririshi98 Jan 27, 2025
cadaa9d
Add eval (#9976)
puririshi98 Jan 27, 2025
e92c0e7
Merge branch 'master' into rebase-txt2kg
puririshi98 Jan 27, 2025
8249ba1
Update examples/llm/hotpot_qa.py
puririshi98 Jan 27, 2025
f134106
address reviews
puririshi98 Jan 27, 2025
b9a55cf
Update hotpot_qa.py
puririshi98 Jan 27, 2025
3dbcf8f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 27, 2025
2519e67
drafting
puririshi98 Jan 27, 2025
45214cd
drafting
puririshi98 Jan 27, 2025
8b77eaa
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 27, 2025
7747b12
drafting
puririshi98 Jan 27, 2025
5addc14
drafting
puririshi98 Jan 27, 2025
a38d489
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 27, 2025
45133b7
drafting
puririshi98 Jan 27, 2025
cae97e7
drafting
puririshi98 Jan 27, 2025
e109dc5
drafting
puririshi98 Jan 27, 2025
29109a4
drafting
puririshi98 Jan 27, 2025
3b92cf1
drafting
puririshi98 Jan 27, 2025
8712e8b
drafting
puririshi98 Jan 27, 2025
a1c0cc4
drafting
puririshi98 Jan 27, 2025
0f744a7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 27, 2025
98240cd
drafting
puririshi98 Jan 27, 2025
fb9e378
Merge branch 'rebase-txt2kg' of https://github.com/pyg-team/pytorch_g…
puririshi98 Jan 27, 2025
15fe2bd
drafting
puririshi98 Jan 27, 2025
941b745
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 27, 2025
473577d
drafting
puririshi98 Jan 27, 2025
0c664f1
Merge branch 'rebase-txt2kg' of https://github.com/pyg-team/pytorch_g…
puririshi98 Jan 27, 2025
055527b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 27, 2025
77d5951
drafting
puririshi98 Jan 27, 2025
db4a24a
Merge branch 'rebase-txt2kg' of https://github.com/pyg-team/pytorch_g…
puririshi98 Jan 27, 2025
d99a23c
drafting
puririshi98 Jan 27, 2025
c4400f1
drafting
puririshi98 Jan 28, 2025
10da58b
drafting
puririshi98 Jan 28, 2025
b977e9a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 28, 2025
5240ca0
cleaning
puririshi98 Jan 28, 2025
4a2e86b
Merge branch 'rebase-txt2kg' of https://github.com/pyg-team/pytorch_g…
puririshi98 Jan 28, 2025
c3654c8
cleaning
puririshi98 Jan 28, 2025
8b8f00f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 28, 2025
7bc4bcb
cleaning
puririshi98 Jan 28, 2025
f6c81f1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 28, 2025
906e82c
cleaning
puririshi98 Jan 28, 2025
cf6162c
Merge branch 'rebase-txt2kg' of https://github.com/pyg-team/pytorch_g…
puririshi98 Jan 28, 2025
a2ea3cd
cleaning
puririshi98 Jan 28, 2025
8883238
cleaning
puririshi98 Jan 28, 2025
11194ee
cleaning
puririshi98 Jan 28, 2025
ea15dc9
cleaning
puririshi98 Jan 28, 2025
87d7b70
cleaning
puririshi98 Jan 28, 2025
5d9b34e
Nim sent transformer and better comments for senttrans (#9990)
puririshi98 Jan 29, 2025
11 changes: 6 additions & 5 deletions examples/llm/README.md
@@ -1,11 +1,12 @@
# Examples for Co-training LLMs and GNNs

| Example | Description |
| -------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| [`g_retriever.py`](./g_retriever.py) | Example for Retrieval-Augmented Generation (RAG) w/ GNN+LLM by co-training `LLAMA2` with `GAT` for answering questions based on knowledge graph information |
| [`g_retriever_utils/`](./g_retriever_utils/) | Contains multiple scripts for benchmarking GRetriever's architecture and evaluating different retrieval methods. |
| [`multihop_rag/`](./multihop_rag/) | Contains starter code and an example run for building a Multi-hop dataset using WikiHop5M and 2WikiMultiHopQA |
| Example | Description |
| -------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------- |
| [`g_retriever.py`](./g_retriever.py) | Example for Retrieval-Augmented Generation (RAG) w/ GNN+LLM by co-training `LLAMA2` with `GAT` for answering questions based on knowledge graph information |
| [`g_retriever_utils/`](./g_retriever_utils/) | Contains multiple scripts for benchmarking GRetriever's architecture and evaluating different retrieval methods. |

| [`nvtx_examples/`](./nvtx_examples/) | Contains examples of how to wrap functions using the NVTX profiler for CUDA runtime analysis. |
| [`molecule_gpt.py`](./molecule_gpt.py) | Example for MoleculeGPT: Instruction Following Large Language Models for Molecular Property Prediction |
| [`glem.py`](./glem.py) | Example for [GLEM](https://arxiv.org/abs/2210.14709), a GNN+LLM co-training model via variational Expectation-Maximization (EM) framework on node classification tasks to achieve SOTA results |
| [`git_mol.py`](./git_mol.py) | Example for GIT-Mol: A Multi-modal Large Language Model for Molecular Science with Graph, Image, and Text |
| [`hotpot_qa.py`](./hotpot_qa.py) | Example for adapting the retrieval step of conventional Retrieval-Augmented Generation (RAG) for use with G-retriever, and how to approximate the precision/recall of a subgraph retrieval method. Uses the HotPotQA dataset as an example since it is multihop in nature. |
55 changes: 13 additions & 42 deletions examples/llm/g_retriever_utils/rag_generate.py
@@ -1,32 +1,30 @@
# %%
import argparse
from itertools import chain
from typing import Tuple

import pandas as pd
import torch
import tqdm
from rag_backend_utils import create_remote_backend_from_triplets
from rag_feature_store import SentenceTransformerFeatureStore
from rag_graph_store import NeighborSamplingRAGGraphStore

from torch_geometric.data import Data
from torch_geometric.datasets import WebQSPDataset
from torch_geometric.datasets.web_qsp_dataset import (
    preprocess_triplet,
    retrieval_via_pcst,
)
from torch_geometric.datasets.web_qsp_dataset import preprocess_triplet
from torch_geometric.loader import RAGQueryLoader
from torch_geometric.nn.nlp import SentenceTransformer
from torch_geometric.utils.rag.backend_utils import (
    create_remote_backend_from_triplets,
    make_pcst_filter,
)
from torch_geometric.utils.rag.feature_store import (
    SentenceTransformerFeatureStore,
)
from torch_geometric.utils.rag.graph_store import NeighborSamplingRAGGraphStore

# %%
parser = argparse.ArgumentParser(
    description="""Generate new WebQSP subgraphs\n""" +
    """NOTE: Evaluating with smaller samples may result in""" +
    """ poorer performance for the trained models compared""" +
    """ to untrained models.""")
parser = argparse.ArgumentParser(description="""Generate new WebQSP subgraphs
NOTE: Evaluating with smaller samples may result in poorer performance for the trained models compared to untrained models."""
)
# TODO: Add more arguments for configuring rag params
parser.add_argument("--use_pcst", action="store_true")
parser.add_argument("--num_samples", type=int, default=4700)
parser.add_argument("--out_file", default="subg_results.pt")
args = parser.parse_args()
@@ -56,37 +54,10 @@
feature_db=SentenceTransformerFeatureStore).load()

# %%


def apply_retrieval_via_pcst(graph: Data, query: str, topk: int = 3,
                             topk_e: int = 3,
                             cost_e: float = 0.5) -> Tuple[Data, str]:
    q_emb = model.encode(query)
    textual_nodes = ds.textual_nodes.iloc[graph["node_idx"]].reset_index()
    textual_edges = ds.textual_edges.iloc[graph["edge_idx"]].reset_index()
    out_graph, desc = retrieval_via_pcst(graph, q_emb, textual_nodes,
                                         textual_edges, topk, topk_e, cost_e)
    out_graph["desc"] = desc
    return out_graph


def apply_retrieval_with_text(graph: Data, query: str) -> Tuple[Data, str]:
    textual_nodes = ds.textual_nodes.iloc[graph["node_idx"]].reset_index()
    textual_edges = ds.textual_edges.iloc[graph["edge_idx"]].reset_index()
    desc = (
        textual_nodes.to_csv(index=False) + "\n" +
        textual_edges.to_csv(index=False, columns=["src", "edge_attr", "dst"]))
    graph["desc"] = desc
    return graph


transform = apply_retrieval_via_pcst \
    if args.use_pcst else apply_retrieval_with_text

query_loader = RAGQueryLoader(data=(fs, gs), seed_nodes_kwargs={"k_nodes": 5},
                              seed_edges_kwargs={"k_edges": 5},
                              sampler_kwargs={"num_neighbors": [50] * 2},
                              local_filter=transform)
                              local_filter=make_pcst_filter(triplets, model))


# %%
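
Note on the hunk above: it drops the two local helpers (`apply_retrieval_via_pcst`, `apply_retrieval_with_text`) and passes `make_pcst_filter(triplets, model)` as `local_filter` instead. As a rough illustration of the kind of text description the removed `apply_retrieval_with_text` attached to each subgraph, here is a minimal standard-library sketch. It is not the PR's code: the function name `describe_subgraph` and the node column names `node_id` and `node_attr` are assumptions made for this example, and the original used pandas `to_csv` rather than the `csv` module.

```python
import csv
import io
from typing import Iterable, Sequence


def describe_subgraph(nodes: Iterable[Sequence], edges: Iterable[Sequence]) -> str:
    """Render nodes and edges as two CSV tables separated by a blank line.

    Hypothetical helper: column names for nodes are assumed, not taken
    from the PR. Edge columns mirror the ["src", "edge_attr", "dst"]
    selection used by the removed helper.
    """
    buf = io.StringIO()
    writer = csv.writer(buf, lineterminator="\n")
    writer.writerow(["node_id", "node_attr"])  # assumed node columns
    writer.writerows(nodes)
    buf.write("\n")  # blank line between the node table and the edge table
    writer.writerow(["src", "edge_attr", "dst"])
    writer.writerows(edges)
    return buf.getvalue()


# Tiny usage example with made-up data.
desc = describe_subgraph(
    nodes=[(0, "Paris"), (1, "France")],
    edges=[(0, "capital_of", 1)],
)
print(desc)
```

The resulting string would then be stored on the graph (the removed code used `graph["desc"] = desc`) so that the LLM side can read the subgraph as text.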
128 changes: 128 additions & 0 deletions examples/llm/hotpot_qa.py
@@ -0,0 +1,128 @@
import argparse
import os
from itertools import chain

import datasets
import torch
from tqdm import tqdm

from torch_geometric import seed_everything
from torch_geometric.loader import RAGQueryLoader
from torch_geometric.nn.nlp import TXT2KG, SentenceTransformer
from torch_geometric.utils.rag.backend_utils import (
    create_remote_backend_from_triplets,
    make_pcst_filter,
    preprocess_triplet,
)
from torch_geometric.utils.rag.feature_store import (
    SentenceTransformerFeatureStore,
)
from torch_geometric.utils.rag.graph_store import NeighborSamplingRAGGraphStore

if __name__ == '__main__':
    seed_everything(50)
    parser = argparse.ArgumentParser()
    parser.add_argument('--NV_NIM_MODEL', type=str,
                        default="nvidia/llama-3.1-nemotron-70b-instruct")
    parser.add_argument('--NV_NIM_KEY', type=str, default="")
    parser.add_argument('--local_lm', action="store_true")
    parser.add_argument('--percent_data', type=float, default=1.0)
    parser.add_argument('--chunk_size', type=int, default=512)
    parser.add_argument('--verbose', action="store_true")
    args = parser.parse_args()
    assert args.percent_data <= 100 and args.percent_data > 0
    if args.local_lm:
        kg_maker = TXT2KG(
            local_LM=True,
            chunk_size=args.chunk_size,
        )
    else:
        kg_maker = TXT2KG(
            NVIDIA_NIM_MODEL=args.NV_NIM_MODEL,
            NVIDIA_API_KEY=args.NV_NIM_KEY,
            chunk_size=args.chunk_size,
        )
    if os.path.exists("hotpot_kg.pt"):
        print("Re-using existing KG...")
        relevant_triples = torch.load("hotpot_kg.pt")
    else:
        # Use training set for simplicity since our retrieval method is nonparametric
        raw_dataset = datasets.load_dataset('hotpotqa/hotpot_qa', 'fullwiki',
                                            trust_remote_code=True)["train"]
        # Build KG
        num_data_pts = len(raw_dataset)
        data_idxs = torch.randperm(num_data_pts)[0:int(num_data_pts *
                                                       args.percent_data /
                                                       100.0)]
        for idx in tqdm(data_idxs, desc="Building KG"):
            data_point = raw_dataset[int(idx)]
            q = data_point["question"]
            a = data_point["answer"]
            context_doc = ''
            for i in data_point["context"]["sentences"]:
                for sentence in i:
                    context_doc += sentence

            QA_pair = (q, a)
            kg_maker.add_doc_2_KG(
                txt=context_doc,
                QA_pair=QA_pair,
            )
        kg_maker.save_kg("hotpot_kg.pt")
        relevant_triples = kg_maker.relevant_triples
        if args.local_lm:
            print("Total number of context characters parsed by LLM",
                  kg_maker.total_chars_parsed)
            print(
                "Average number of context characters parsed by LLM per second=",
                kg_maker.avg_chars_parsed_per_sec)

    print("Size of KG (number of triples) =",
          sum([len(rel_trips) for rel_trips in relevant_triples.values()]))

    triples = list(
        chain.from_iterable(triple_set
                            for triple_set in relevant_triples.values()))
    # debug
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = SentenceTransformer(
        model_name='sentence-transformers/all-roberta-large-v1').to(device)
    fs, gs = create_remote_backend_from_triplets(
        triplets=triples, node_embedding_model=model,
        node_method_to_call="encode", path="backend",
        pre_transform=preprocess_triplet, node_method_kwargs={
            "batch_size": min(len(triples), 256)
        }, graph_db=NeighborSamplingRAGGraphStore,
        feature_db=SentenceTransformerFeatureStore).load()
    query_loader = RAGQueryLoader(
        data=(fs, gs), seed_nodes_kwargs={"k_nodes":
                                          5}, seed_edges_kwargs={"k_edges": 5},
        sampler_kwargs={"num_neighbors": [50] * 2},
        local_filter=make_pcst_filter(triples, model))
    """
    approx precision = num_relevant_out_of_retrieved/num_retrieved_triples
    We will use precision as a proxy for recall. This is because for recall,
    we must know how many relevant triples exist for each question,
    but this is not known.
    """
    precisions = []
    for QA_pair in relevant_triples.keys():
        golden_triples = relevant_triples[QA_pair]
        q = QA_pair[0]
        retrieved_subgraph = query_loader.query(q)
        retrieved_triples = retrieved_subgraph.triples

        if args.verbose:
            print("Q=", q)
            print("A=", QA_pair[1])
            print("retrieved_triples =", retrieved_triples)

        num_relevant_out_of_retrieved = float(
            sum([
                int(bool(retrieved_triple in golden_triples))
                for retrieved_triple in retrieved_triples
            ]))
        precisions.append(num_relevant_out_of_retrieved /
                          len(retrieved_triples))
    approx_precision = sum(precisions) / len(precisions)
    print("approx_precision =", approx_precision)
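
Note on the metric computed at the end of `hotpot_qa.py`: each question's precision is the fraction of retrieved triples that also appear among the triples extracted for that question, and the script reports the mean over all questions. The sketch below reproduces that arithmetic on toy data, without PyG or an LLM. The function name `approx_precision` and the sample triples are made up for illustration. One deliberate difference from the script, which is an assumption of this sketch: questions with no retrieved triples are skipped rather than dividing by zero.

```python
from typing import Dict, Hashable, Iterable, List, Tuple

Triple = Tuple[str, str, str]


def approx_precision(
    golden: Dict[Hashable, Iterable[Triple]],
    retrieved: Dict[Hashable, List[Triple]],
) -> float:
    """Mean over questions of (relevant retrieved triples / retrieved triples).

    Assumption of this sketch: questions for which nothing was retrieved
    are skipped, and 0.0 is returned if no question can be scored.
    """
    precisions = []
    for key, golden_triples in golden.items():
        retrieved_triples = retrieved.get(key, [])
        if not retrieved_triples:
            continue  # avoid division by zero
        golden_set = set(golden_triples)
        hits = sum(1 for t in retrieved_triples if t in golden_set)
        precisions.append(hits / len(retrieved_triples))
    return sum(precisions) / len(precisions) if precisions else 0.0


# Toy data keyed by (question, answer) pairs, as in the script.
golden = {
    ("Who wrote Hamlet?", "Shakespeare"): [("Shakespeare", "wrote", "Hamlet")],
    ("Where is Paris?", "France"): [("Paris", "capital_of", "France")],
}
retrieved = {
    ("Who wrote Hamlet?", "Shakespeare"): [
        ("Shakespeare", "wrote", "Hamlet"),
        ("Hamlet", "is_a", "play"),
    ],
    ("Where is Paris?", "France"): [("Paris", "capital_of", "France")],
}
score = approx_precision(golden, retrieved)
print(score)  # (0.5 + 1.0) / 2 = 0.75
```

Converting the golden triples to a set makes each membership check constant time, which may matter when a question has many extracted triples.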
9 changes: 0 additions & 9 deletions examples/llm/multihop_rag/README.md

This file was deleted.

12 changes: 0 additions & 12 deletions examples/llm/multihop_rag/multihop_download.sh

This file was deleted.
