Implement Gemini to Ragnarok #11

Open · wants to merge 6 commits into base: main
1 change: 1 addition & 0 deletions README.md
@@ -64,6 +64,7 @@ Most LLMs supported by VLLM/FastChat should additionally be supported by Ragnar
| command-r | `command-r` |
| Llama-3 8B Instruct | `meta-llama/Meta-Llama-3-8B-Instruct` |
| Llama3-ChatQA-1.5 | `nvidia/Llama3-ChatQA-1.5` |
| Gemini | `gemini-2.0-flash-thinking-exp` |


## ✨ References
3 changes: 3 additions & 0 deletions src/ragnarok/generate/api_keys.py
@@ -28,6 +28,9 @@ def get_cohere_api_key() -> str:
load_dotenv(dotenv_path=f".env.local")
return os.getenv("CO_API_KEY")

def get_gemini_api_key() -> str:
load_dotenv(dotenv_path=f".env.local")
return os.getenv("GEMINI_API_KEY")

def get_anyscale_api_key() -> str:
load_dotenv(dotenv_path=f".env.local")
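Note: the new `get_gemini_api_key` helper mirrors the existing loaders, reading `GEMINI_API_KEY` from `.env.local`. A minimal usage sketch follows; only the env-variable name and file path come from the diff above, the rest is illustrative:

```python
# .env.local in the working directory is assumed to contain a line such as:
#   GEMINI_API_KEY=your-key-here
from ragnarok.generate.api_keys import get_gemini_api_key

key = get_gemini_api_key()  # returns the key string, or None if the variable is not set
```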
1 change: 1 addition & 0 deletions src/ragnarok/generate/cohere.py
@@ -83,6 +83,7 @@ def run_llm(
self,
prompt: Union[str, List[Dict[str, Any]]],
logging: bool = False,
topk: int = 20,
) -> Tuple[Any, RAGExecInfo]:
query, top_k_docs = prompt[0]["query"], prompt[0]["context"]
if logging:
138 changes: 138 additions & 0 deletions src/ragnarok/generate/gemini.py
@@ -0,0 +1,138 @@
import time
from enum import Enum
from typing import Any, Dict, List, Tuple, Union
import os

import google.generativeai as genai
import google.api_core.exceptions as exceptions

from ragnarok.data import CitedSentence, RAGExecInfo, Request
from ragnarok.generate.llm import LLM, PromptMode
from ragnarok.generate.post_processor import gemini_post_processor
from ragnarok.generate.templates.ragnarok_templates import RagnarokTemplates


class Gemini(LLM):
    def __init__(
        self,
        model: str,
        context_size: int,
        prompt_mode: PromptMode = PromptMode.GEMINI,
        max_output_tokens: int = 1500,
        num_few_shot_examples: int = 0,
        key: str = None,
    ) -> None:
        """
        Creates an instance of the Gemini class, used to make Gemini models perform generation in RAG pipelines.
        The Gemini 1.5, 1.5-flash, and 2.0-flash-experimental models are the only ones implemented so far.

        Parameters:
        - model (str): The model identifier for the LLM.
        - context_size (int): The maximum number of tokens that the model can handle in a single request.
        - prompt_mode (PromptMode, optional): Specifies the mode of prompt generation, with the default set to GEMINI.
        - max_output_tokens (int, optional): Maximum number of tokens that can be generated in a single response. Defaults to 1500.
        - num_few_shot_examples (int, optional): Number of few-shot learning examples to include in the prompt, allowing for
          the integration of example-based learning to improve model performance. Defaults to 0, indicating no few-shot examples
          by default.
        - key (str, optional): A single Gemini API key.

        Raises:
        - ValueError: If an unsupported prompt mode is provided or if no API key or an invalid API key is supplied.
        """
        # Initialize values and check for errors in entered values
        super().__init__(
            model, context_size, prompt_mode, max_output_tokens, num_few_shot_examples
        )
        self.key = str(key)
        if not (key and isinstance(self.key, str)):
            raise ValueError(
                f"Gemini API key not provided or in an invalid format. The key provided (if any) is {key}. "
                "Assign the appropriate key to the GEMINI_API_KEY env variable."
            )
        if prompt_mode not in [
            PromptMode.CHATQA,
            PromptMode.RAGNAROK_V2,
            PromptMode.RAGNAROK_V3,
            PromptMode.RAGNAROK_V4,
            PromptMode.RAGNAROK_V4_BIOGEN,
            PromptMode.RAGNAROK_V5_BIOGEN,
            PromptMode.RAGNAROK_V5_BIOGEN_NO_CITE,
            PromptMode.RAGNAROK_V4_NO_CITE,
        ]:
            raise ValueError(
                f"unsupported prompt mode for Gemini models: {prompt_mode}, expected one of "
                f"{PromptMode.CHATQA}, {PromptMode.RAGNAROK_V2}, {PromptMode.RAGNAROK_V3}, "
                f"{PromptMode.RAGNAROK_V4}, {PromptMode.RAGNAROK_V4_NO_CITE}."
            )
        # Configure model parameters
        genai.configure(api_key=self.key)
        generation_config = {
            "temperature": 0.1,
            "top_p": 0.95,
            "top_k": 64,
            "max_output_tokens": max_output_tokens,
            "response_mime_type": "text/plain",
        }
        safety_settings = [
            {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
            {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
            {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
            {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
        ]
        system_instruction = "This is a chat between a user and an artificial intelligence assistant. The assistant gives helpful and detailed answers to the user's question based on the context references. The assistant should also indicate when the answer cannot be found in the context references."
        # Initialize model
        self.gen_model = genai.GenerativeModel(
            model_name=model,
            generation_config=generation_config,
            safety_settings=safety_settings,
            system_instruction=system_instruction,
        )

    def run_llm(
        self,
        prompt: Union[str, List[Dict[str, str]]],
        logging: bool = False,
        topk: int = 1,
    ) -> Tuple[List[CitedSentence], RAGExecInfo]:
        chat_session = self.gen_model.start_chat(history=[])
        while True:
            try:
                prompt_text = prompt[-1]["content"]
                response = chat_session.send_message(prompt_text)
                response = response.text
                break
            except exceptions.ResourceExhausted:
                print("rate limit error encountered, waiting 2 seconds...")
                time.sleep(2)
            except Exception as e:
                print(f"unknown error encountered ({e}), waiting 2 seconds...")
                time.sleep(2)
        answers, rag_exec_response = gemini_post_processor(response, topk)
        rag_exec_info = RAGExecInfo(
            prompt=prompt,
            response=rag_exec_response,
            input_token_count=None,
            output_token_count=None,
            candidates=[],
        )
        if logging:
            print(f"Prompt: {prompt}")
            print(f"Response: {response}")
            print(f"Answers: {answers}")
            print(f"RAG Exec Info: {rag_exec_info}")
        time.sleep(0.5)
        return answers, rag_exec_info

    def create_prompt(self, request: Request, topk: int) -> List[Dict[str, str]]:
        query = request.query.text
        max_length = (self._context_size - 200) // topk
        rank = 0
        context = []
        for cand in request.candidates[:topk]:
            rank += 1
            content = self.convert_doc_to_prompt_content(cand.doc, max_length)
            context.append(f"[{rank}] {self._replace_number(content)}")
        ragnarok_template = RagnarokTemplates(self._prompt_mode)
        messages = ragnarok_template(query, context, "gemini")
        return messages

    def get_num_tokens(self, prompt: Union[str, List[Dict[str, str]]]) -> int:
        """Placeholder function. Returns 1."""
        return 1

    def cost_per_1k_token(self, input_token: bool) -> float:
        """Placeholder function. Returns 1."""
        return 1
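Note: a minimal sketch of how the new class might be driven end to end, assuming a `Request` object is already available from the retrieval stage; the model identifier, context size, and `topk` value below are illustrative, not prescribed by this PR:

```python
from ragnarok.generate.api_keys import get_gemini_api_key
from ragnarok.generate.gemini import Gemini
from ragnarok.generate.llm import PromptMode

# Illustrative configuration; any supported Gemini model identifier should work here.
agent = Gemini(
    model="gemini-2.0-flash-thinking-exp",
    context_size=8192,
    prompt_mode=PromptMode.RAGNAROK_V4,
    max_output_tokens=1500,
    key=get_gemini_api_key(),
)

# `request` is a ragnarok.data.Request produced by the retrieval stage (not shown here).
messages = agent.create_prompt(request, topk=5)
answers, exec_info = agent.run_llm(messages, logging=True, topk=5)
```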
3 changes: 2 additions & 1 deletion src/ragnarok/generate/gpt.py
@@ -138,6 +138,7 @@ def run_llm(
self,
prompt: Union[str, List[Dict[str, str]]],
logging: bool = False,
topk: int = 20,
) -> Tuple[str, RAGExecInfo]:
model_key = "model"
if logging:
@@ -155,7 +156,7 @@
encoding = tiktoken.get_encoding("cl100k_base")
if logging:
print(f"Response: {response}")
answers, rag_exec_response = self._post_processor(response)
answers, rag_exec_response = self._post_processor(response, topk)
if logging:
print(f"Answers: {answers}")
rag_exec_info = RAGExecInfo(
5 changes: 3 additions & 2 deletions src/ragnarok/generate/llm.py
@@ -20,6 +20,7 @@ class PromptMode(Enum):
RAGNAROK_V5_BIOGEN = "ragnarok_v5_biogen"
RAGNAROK_V5_BIOGEN_NO_CITE = "ragnarok_v5_biogen_no_cite"
RAGNAROK_V4_NO_CITE = "ragnarok_v4_no_cite"
GEMINI = "gemini"

def __str__(self):
return self.value
@@ -193,8 +194,8 @@ def answer_batch(
initial_results.append(result)
else:
for request in requests:
prompt, input_token_count = self.create_prompt(request, topk)
answer, rag_exec_summary = self.run_llm(prompt, logging)
prompt = self.create_prompt(request, topk)
answer, rag_exec_summary = self.run_llm(prompt, logging, topk)
rag_exec_summary.candidates = [
candidate.__dict__ for candidate in request.candidates[:topk]
]
2 changes: 1 addition & 1 deletion src/ragnarok/generate/os_llm.py
@@ -133,7 +133,7 @@ def run_llm_batched(
assert False, "Failed run_llm_batched"

def run_llm(
self, prompt: str, logging: bool = False, vllm: bool = True
self, prompt: str, logging: bool = False, vllm: bool = True, topk: int = 20,
) -> Tuple[str, int]:
if logging:
print(f"Prompt: {prompt}")
73 changes: 71 additions & 2 deletions src/ragnarok/generate/post_processor.py
@@ -181,14 +181,14 @@ def _find_sentence_citations(
sentence = sentence[:-2] + sentence[-1]
return sentence, citations

def __call__(self, response) -> List[Dict[str, Any]]:
def __call__(self, response, topk) -> List[Dict[str, Any]]:
text_output = response
# Remove all \nNote: and \nReferences: from the text
text_output = re.sub(r"\nNote:.*", "", text_output)
text_output = re.sub(r"\nReferences:.*", "", text_output)
sentences = self.tokenizer.tokenize(text_output)
answers = []
citation_range = list(range(20))
citation_range = list(range(topk))
for sentence in sentences:
sentence_parsed, citations = self._find_sentence_citations(
sentence, citation_range
@@ -198,3 +198,72 @@ def __call__(self, response) -> List[Dict[str, Any]]:
rag_exec_response = {"text": response, "citations": citation_range}

return answers, rag_exec_response

def gemini_find_citations(
    sentence: str, citation_range: List[int] = list(range(20))
) -> tuple[str, List[int]]:
    # Regex pattern to find citations such as "[2]" or "[2], "
    pattern = re.compile(r"\[\d+\](?:,? ?)")

    # Find all citations
    citations = pattern.findall(sentence)
    if citations:
        # Remove citations from text
        sentence = pattern.sub("", sentence).strip()

        # Extract 0-based indices from citations
        indices = [
            int(re.search(r"\d+", citation).group()) - 1 for citation in citations
        ]
        citations = [index for index in indices if index in citation_range]
    else:
        matches = re.findall(r"\[[^\]]*\]", sentence)
        if not matches:
            return sentence, []
        citations = []
        for match in matches:
            citation = match[1:-1]
            try:
                if "," in citation:
                    flag = False
                    for cit in citation.split(","):
                        cit = int(cit) - 1
                        if cit in citation_range:
                            flag = True
                            citations.append(int(cit))
                    if flag:
                        sentence = sentence.replace(match, "")
                else:
                    citation = int(citation) - 1
                    if citation in citation_range:
                        citations.append(citation)
                        sentence = sentence.replace(match, "")
            except ValueError:
                print(f"Not a valid citation: {match}")

    sentence = re.sub(" +", " ", sentence)
    if len(sentence) > 3:
        if sentence[-2] == " ":
            sentence = sentence[:-2] + sentence[-1]
    return sentence, citations


def gemini_post_processor(response: str, topk: int) -> tuple[List[CitedSentence], Dict[str, Any]]:
    # Remove all \n, then split the text into sentences.
    text_output = response.replace("\n", "")
    sentences = text_output.split(".")

    # Avoids the last entry being empty, which can cause errors later.
    if not sentences[-1]:
        sentences.pop()

    citation_range = list(range(topk))  # 0-based indices for the top-k candidates

    rag_exec_response = {"text": response, "citations": citation_range}

    answers = []
    for sentence in sentences:
        sentence_parsed, citations = gemini_find_citations(sentence, citation_range)
        answers.append(CitedSentence(text=sentence_parsed, citations=citations))

    return answers, rag_exec_response
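Note: a quick illustration of what `gemini_find_citations` does with a typical cited sentence (the input string is made up; the expected output follows the code above):

```python
from ragnarok.generate.post_processor import gemini_find_citations

sentence = "Snorkeling is popular in Hawaii [1], [3]."
text, cites = gemini_find_citations(sentence, citation_range=list(range(5)))
# text  -> "Snorkeling is popular in Hawaii."
# cites -> [0, 2]  (references [1] and [3], converted to 0-based candidate indices)
```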


8 changes: 5 additions & 3 deletions src/ragnarok/generate/templates/ragnarok_templates.py
@@ -131,7 +131,7 @@ def __call__(self, query: str, context: List[str], model: str) -> List[str]:
+ self.sep
+ f"Instruction: {self.get_instruction()}"
)
elif "gpt" in model:
elif "gpt" in model.lower() or "gemini" in model.lower():
user_input_context = (
f"Instruction: {self.get_instruction()}"
+ self.sep
@@ -153,7 +153,7 @@ def __call__(self, query: str, context: List[str], model: str) -> List[str]:
+ f"Instruction: {self.get_instruction()}"
)

if "gpt" in model:
if "gpt" in model.lower() or "gemini" in model.lower():
messages = []
system_message = (
self.system_message_gpt_no_cite
@@ -179,6 +179,8 @@ def __call__(self, query: str, context: List[str], model: str) -> List[str]:
return messages
elif "chatqa" in model.lower():
prompt = f"{self.system_message_chatqa}{self.sep}{self.input_context.format(context=str_context)}{self.sep}User: {user_input_context}"
# elif "gemini" in model.lower():
# prompt = f"{self.system_message_chatqa}{self.sep}{self.input_context.format(context=str_context)}{self.sep}User input context: {user_input_context}"
else:
conv = get_conversation_template(model)
system_message = (
@@ -213,4 +215,4 @@ def get_instruction(self) -> str:
elif self.prompt_mode == PromptMode.RAGNAROK_V5_BIOGEN:
return self.instruction_ragnarok_v5_biogen
else:
return self.instruction_ragnarok
return self.instruction_ragnarok