Skip to content

Commit

Permalink
Merge branch 'rebase-txt2kg' into improve-system-prompt
Browse files Browse the repository at this point in the history
  • Loading branch information
puririshi98 authored Dec 12, 2024
2 parents 7708674 + 872547f commit 5807004
Show file tree
Hide file tree
Showing 5 changed files with 8 additions and 8 deletions.
8 changes: 4 additions & 4 deletions examples/llm/README.md
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# Examples for Co-training LLMs and GNNs

| Example | Description |
| -------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| [`g_retriever.py`](./g_retriever.py) | Example for Retrieval-Augmented Generation (RAG) w/ GNN+LLM by co-training `LLAMA2` with `GAT` for answering questions based on knowledge graph information |
| [`g_retriever_utils/`](./g_retriever_utils/) | Contains multiple scripts for benchmarking GRetriever's architecture and evaluating different retrieval methods. |
| Example | Description |
| -------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------- |
| [`g_retriever.py`](./g_retriever.py) | Example for Retrieval-Augmented Generation (RAG) w/ GNN+LLM by co-training `LLAMA2` with `GAT` for answering questions based on knowledge graph information |
| [`g_retriever_utils/`](./g_retriever_utils/) | Contains multiple scripts for benchmarking GRetriever's architecture and evaluating different retrieval methods. |

| [`nvtx_examples/`](./nvtx_examples/) | Contains examples of how to wrap functions using the NVTX profiler for CUDA runtime analysis. |
| [`molecule_gpt.py`](./molecule_gpt.py) | Example for MoleculeGPT: Instruction Following Large Language Models for Molecular Property Prediction |
Expand Down
2 changes: 1 addition & 1 deletion examples/llm/g_retriever_utils/rag_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,4 +107,4 @@ def check_retrieval_recall(subg: Data, ground_truth: Data):

pd.DataFrame.from_dict(retrieval_stats).to_csv(
args.out_file.split('.')[0] + '_metadata.csv')
torch.save(subgs, args.out_file)
torch.save(subgs, args.out_file)
2 changes: 1 addition & 1 deletion torch_geometric/utils/rag/backend_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,4 +277,4 @@ def apply_retrieval_via_pcst(
out_graph["triples"] = parsed_trips
return out_graph

return apply_retrieval_via_pcst
return apply_retrieval_via_pcst
2 changes: 1 addition & 1 deletion torch_geometric/utils/rag/feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,4 +186,4 @@ class SentenceTransformerApproxFeatureStore(ApproxKNNRAGFeatureStore):
def __init__(self, *args, **kwargs):
kwargs['model_name'] = kwargs.get(
'model_name', 'sentence-transformers/all-roberta-large-v1')
super().__init__(SentenceTransformer, *args, **kwargs)
super().__init__(SentenceTransformer, *args, **kwargs)
2 changes: 1 addition & 1 deletion torch_geometric/utils/rag/graph_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,4 +104,4 @@ def sample_subgraph(
node_sample_input = NodeSamplerInput(input_id=None, node=seed_nodes)
out = self.sampler.sample_from_nodes(node_sample_input)

return out
return out

0 comments on commit 5807004

Please sign in to comment.