Feat/rag (#26) (#27)
RAG implementation:

LOTR (Merger Retriever)
Agentic Routing RAG
Agentic Tool Use RAG
haruiz authored Jun 27, 2024
1 parent 91251a4 commit a3bc0cb
Showing 42 changed files with 1,991 additions and 366 deletions.
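
Of the three modes listed in the commit message, only `AgenticToolUseRAG` is exercised end-to-end in the new `examples/rag.py` below; `LOTRRAG` and `AgenticRoutingRAG` appear there only in commented-out form. A minimal, hedged sketch of the new API (constructor arguments inferred from that example; passing an empty `retrievers_info` list is an assumption):

```python
# Sketch only: signatures inferred from examples/rag.py in this commit;
# an empty retrievers_info list is assumed to be accepted.
from langchain_core.tools import tool
from langchain_google_genai import ChatGoogleGenerativeAI

from geminiplayground.rag import AgenticToolUseRAG


@tool
def add(x: float, y: float) -> float:
    """Calculate the sum of 'x' and 'y'."""
    return x + y


chat_model = ChatGoogleGenerativeAI(model="models/gemini-1.5-flash-latest", temperature=0.0)
rag = AgenticToolUseRAG(
    chat_model=chat_model,
    retrievers_info=[],  # normally: [{"name": ..., "description": ..., "retriever": ...}, ...]
    custom_tools=[add],  # LangChain @tool callables the agent may invoke
    chat_history=[],
)
result = rag.invoke("What is 2 + 3?")
print(result.answer)  # result.docs holds any retrieved documents
```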
2 changes: 2 additions & 0 deletions .gitignore
@@ -21,6 +21,7 @@ poetry.lock
lib64/
sdist/
var/
.python-version
wheels/
.cache/
share/python-wheels/
@@ -30,6 +31,7 @@ share/python-wheels/
.idea
.DS_Store
MANIFEST
data/vis-language-model.pdf


# PyInstaller
2 changes: 1 addition & 1 deletion README.md
@@ -81,7 +81,7 @@ response = gemini_client.generate_response("models/gemini-1.5-pro-latest", multi
generation_config={"temperature": 0.0, "top_p": 1.0})
# Print the response
for candidate in response.candidates:
for part in candidate.content.parts:
if part.text:
print(part.text)
```
Binary file modified data/.DS_Store
Binary file not shown.
892 changes: 892 additions & 0 deletions data/titanic.csv

Large diffs are not rendered by default.

Binary file added data/vis-language-model.pdf
Binary file not shown.
2 changes: 2 additions & 0 deletions examples/chat_wit_your_code.py
@@ -3,6 +3,7 @@
from geminiplayground.core import GeminiClient
from geminiplayground.parts import GitRepo
from dotenv import load_dotenv, find_dotenv
from geminiplayground.catching import cache

load_dotenv(find_dotenv())

@@ -12,6 +13,7 @@ def chat_wit_your_code():
Get the content parts of a github repo and generate a request.
:return:
"""
cache.clear()
repo = GitRepo.from_url(
"https://github.com/karpathy/ng-video-lecture",
branch="master",
26 changes: 26 additions & 0 deletions examples/chat_with_multiple_files.py
@@ -0,0 +1,26 @@
from rich import print

from dotenv import load_dotenv, find_dotenv
from pathlib import Path

from geminiplayground.parts import ImageFile
from geminiplayground.catching import cache

from geminiplayground.core import GeminiClient

load_dotenv(find_dotenv())

if __name__ == '__main__':
cache.clear()

gemini_client = GeminiClient()
images = [ImageFile(image_file, gemini_client=gemini_client) for image_file in Path("./../data").glob("*.jpg")]
prompt = ["Please describe the following images:"] + images

model_name = "models/gemini-1.5-pro-latest"
tokens_count = gemini_client.count_tokens(model_name, prompt)
print(f"Tokens count: {tokens_count}")
response = gemini_client.generate_response(model_name, prompt, stream=True)
for message_chunk in response:
if message_chunk.parts:
print(message_chunk.text, end="")
6 changes: 3 additions & 3 deletions examples/chat_with_your_audios.py
@@ -13,15 +13,15 @@ def chat_wit_your_audios():
audio_file_path = "./../data/audio_example.mp3"
gemini_client = GeminiClient()
audio_file = AudioFile(audio_file_path, gemini_client=gemini_client)
# audio_file.delete()
# audio_file.clear_cache()
prompt = ["Listen this audio:", audio_file, "Describe what you heard"]
model_name = "models/gemini-1.5-pro-latest"
model_name = "models/gemini-1.5-flash-latest"
tokens_count = gemini_client.count_tokens(model_name, prompt)
print(f"Tokens count: {tokens_count}")
response = gemini_client.generate_response(model_name, prompt, stream=True)
for message_chunk in response:
if message_chunk.parts:
print(message_chunk.text)
print(message_chunk.text, end="")


if __name__ == "__main__":
4 changes: 2 additions & 2 deletions examples/chat_with_your_images.py
@@ -17,13 +17,13 @@ def chat_wit_your_images():
image_file_path = "https://upload.wikimedia.org/wikipedia/commons/4/47/PNG_transparency_demonstration_1.png"
image_file = ImageFile(image_file_path, gemini_client=gemini_client)
prompt = ["what do you see in this image?", image_file]
model_name = "models/gemini-1.5-pro-latest"
model_name = "models/gemini-1.5-flash-latest"
tokens_count = gemini_client.count_tokens(model_name, prompt)
print(f"Tokens count: {tokens_count}")
response = gemini_client.generate_response(model_name, prompt, stream=True)
for message_chunk in response:
if message_chunk.parts:
print(message_chunk.text)
print(message_chunk.text, end="")


if __name__ == "__main__":
6 changes: 3 additions & 3 deletions examples/chat_with_your_pdf.py
@@ -1,8 +1,7 @@
from rich import print

from geminiplayground.core import GeminiClient
from geminiplayground.parts import PdfFile
from dotenv import load_dotenv, find_dotenv
from geminiplayground.catching import cache

load_dotenv(find_dotenv())

@@ -12,10 +11,11 @@ def chat_wit_your_pdf():
Get the content parts of a pdf file and generate a request.
:return:
"""
cache.clear()
gemini_client = GeminiClient()
pdf_file_path = "https://www.tnstate.edu/faculty/fyao/COMP3050/Py-tutorial.pdf"
# pdf_file_path = "./../data/vis-language-model.pdf"
pdf_file = PdfFile(pdf_file_path, gemini_client=gemini_client)

prompt = ["Please create a summary of the pdf file:", pdf_file]
model_name = "models/gemini-1.5-pro-latest"
tokens_count = gemini_client.count_tokens(model_name, prompt)
4 changes: 2 additions & 2 deletions examples/chat_with_your_video.py
@@ -13,7 +13,7 @@ def chat_wit_your_video():
:return:
"""
gemini_client = GeminiClient()
model_name = "models/gemini-1.5-pro-latest"
model_name = "models/gemini-1.5-flash-latest"

video_file_path = "./../data/transformers-explained.mp4"
video_file = VideoFile(video_file_path, gemini_client=gemini_client)
@@ -30,7 +30,7 @@ def chat_wit_your_video():
response = gemini_client.generate_response(model_name, prompt, stream=True)
for message_chunk in response:
if message_chunk.parts:
print(message_chunk.text)
print(message_chunk.text, end="")


if __name__ == "__main__":
42 changes: 42 additions & 0 deletions examples/code_repo.py
@@ -0,0 +1,42 @@
from rich import print

from geminiplayground.core import GeminiClient
from geminiplayground.parts import GitRepo
from dotenv import load_dotenv, find_dotenv
from geminiplayground.catching import cache

load_dotenv(find_dotenv())

cache.clear()


def chat_wit_your_code():
"""
Get the content parts of a github repo and generate a request.
:return:
"""
repo = GitRepo.from_url(
"https://github.com/mhdawson/node-core-utils.git",
branch="main",
config={
"content": "code-files"
},
)
prompt = [
"Describe the following codebase:",
repo
]
model = "models/gemini-1.5-flash-latest"
gemini_client = GeminiClient()
tokens_count = gemini_client.count_tokens(model, prompt)
print("Tokens count: ", tokens_count)
response = gemini_client.generate_response(model, prompt, stream=True)

# Print the response
for message_chunk in response:
if message_chunk.parts:
print(message_chunk.text, end="")


if __name__ == "__main__":
chat_wit_your_code()
3 changes: 2 additions & 1 deletion examples/gemini_client_api.py
@@ -8,4 +8,5 @@
if __name__ == "__main__":
gemini_client = GeminiClient()
files = gemini_client.query_files(page_size=5)
print(files)
for file in files:
print(file)
2 changes: 1 addition & 1 deletion examples/multimodal.py
@@ -31,6 +31,6 @@

# Print the response
for candidate in response.candidates:
for part in candidate.content.parts:
if part.text:
print(part.text)
10 changes: 1 addition & 9 deletions examples/playground_api.py
@@ -1,7 +1,6 @@
from dotenv import load_dotenv, find_dotenv

from geminiplayground.core import GeminiPlayground, Message, ToolCall
from geminiplayground.parts import ImageFile
from geminiplayground.core import GeminiPlayground, ToolCall

load_dotenv(find_dotenv())

@@ -24,13 +23,6 @@ def write_poem() -> str:


chat = playground.start_chat(history=[])

image = ImageFile("./data/dog.jpg")
ai_message = chat.send_message(["can you describe the following image: ", image], stream=True)
for response_chunk in ai_message:
if isinstance(response_chunk, Message):
print(response_chunk.text, end="")
print()
while True:
user_input = input("You: ")
if user_input == "exit":
153 changes: 153 additions & 0 deletions examples/rag.py
@@ -0,0 +1,153 @@
from langchain_core.messages import HumanMessage
from langchain_core.tools import tool
from langchain_google_genai import GoogleGenerativeAIEmbeddings, ChatGoogleGenerativeAI
from langchain_weaviate import WeaviateVectorStore

from geminiplayground.parts import MultimodalPart, ImageFile, AudioFile, GitRepo, PdfFile, VideoFile
from langchain_core.retrievers import BaseRetriever
from dotenv import load_dotenv, find_dotenv
from geminiplayground.rag import SummarizationLoader, AgenticToolUseRAG
from rich.console import Console
from langchain_core.vectorstores import VectorStore
import typing
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
import weaviate

# from geminiplayground.catching import cache

console = Console()

load_dotenv(find_dotenv())


class MultiModalSummarizationRetriever(BaseRetriever):
summarization_model: str
docs: typing.List[MultimodalPart]
"""List of documents to retrieve from."""
vectorstore: VectorStore
batch_docs_size: int = 50
"""How many documents to add to the vector store per batch."""
k: int
"""Number of top results to return."""

def index_docs(self):
"""
Index all the documents.
"""
loader = SummarizationLoader(self.summarization_model, *self.docs)
docs = loader.load()
self.vectorstore.add_documents(docs, batch_size=self.batch_docs_size)

def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> typing.List[Document]:
docs, scores = zip(
*self.vectorstore.similarity_search_with_score(query, k=self.k)
)
for doc, score in zip(docs, scores):
doc.metadata["score"] = score

return list(docs)


def create_retriever_from_multimodal_data(docs_index_name: str,
docs: typing.List[MultimodalPart]):
"""
Create a retriever for a document
"""
return MultiModalSummarizationRetriever(
docs=docs,
summarization_model="models/gemini-1.5-flash-latest",
vectorstore=WeaviateVectorStore(
client=weaviate_client,
index_name=docs_index_name,
embedding=embeddings_model,
text_key="page_content"
),
k=5
)


if __name__ == '__main__':

weaviate_client = weaviate.connect_to_embedded()
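# connect_to_embedded() runs Weaviate in-process; no separate server is required.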
embeddings_model = GoogleGenerativeAIEmbeddings(model="models/embedding-001", task_type="retrieval_document")
chat_model = ChatGoogleGenerativeAI(model="models/gemini-1.5-flash-latest", temperature=0.0)

retrievers = [{
"name": "media_files",
"description": "This Retriever combine a various media files, including a picture of my dog",
"retriever": create_retriever_from_multimodal_data("media_files", [
ImageFile("./../data/dog.jpg"),

])
}, {
"name": "code_files",
"description": "This Retriever contains code from karpathy's ng-video-lecture repo about transformers",
"retriever": create_retriever_from_multimodal_data("code_files", [
GitRepo.from_url(
"https://github.com/karpathy/ng-video-lecture",
branch="master",
config={
"content": "code-files"
},
)])
}, {
"name": "transformer_files",
"description": "This Retriever contains various media files, relating to transformers and language models",
"retriever": create_retriever_from_multimodal_data("pdf_files", [
VideoFile("./../data/transformers-explained.mp4"),
PdfFile("./../data/vis-language-model.pdf"),
AudioFile("./../data/audio_example.mp3")
])
}]

# Index all the documents in the retrievers
for retriever in retrievers:
retriever["retriever"].index_docs()


# rag = LOTRRAG(
# chat_model=chat_model,
# retrievers_info=retrievers,
# chat_history=[]
# )

# rag = AgenticRoutingRAG(
# chat_model=chat_model,
# retrievers_info=retrievers,
# chat_history=[]
# )

@tool
def subtract(x: float, y: float) -> float:
"""Subtract 'x' from 'y'."""
return y - x


@tool
def sum(x: float, y: float) -> float:
"""Calculate the percentage difference between 'x' and 'y'."""
return x + y


rag = AgenticToolUseRAG(
chat_model=chat_model,
retrievers_info=retrievers,
custom_tools=[subtract, sum],
chat_history=[
HumanMessage(content="Hello, I am a Henry Ruiz")
])

while True:
question = input("Question: ")
if question.lower() == "exit":
print(rag.chat_history)
weaviate_client.close()
break
result = rag.invoke(question)
rag.chat_history.extend([HumanMessage(content=question), result.answer])
console.print(f"Answer: {result.answer}")
docs = result.docs
for doc in docs[:3]:
console.print(doc.page_content[:100], doc.metadata)
2 changes: 1 addition & 1 deletion examples/readme_example.py
@@ -25,6 +25,6 @@

# Print the response
for candidate in response.candidates:
for part in candidate.content.parts:
if part.text:
print(part.text)