From f8b23d156bb1020ed9aae17bda99132aab39eeb1 Mon Sep 17 00:00:00 2001 From: Yuv-sue1005 <168255174+Yuv-sue1005@users.noreply.github.com> Date: Tue, 31 Dec 2024 17:39:30 -0800 Subject: [PATCH 1/6] Initial Gemini implementation --- src/ragnarok/generate/api_keys.py | 3 + src/ragnarok/generate/gemini.py | 144 ++++++++++++++++++ src/ragnarok/generate/llm.py | 3 +- src/ragnarok/generate/post_processor.py | 45 ++++++ .../generate/templates/ragnarok_templates.py | 2 + src/ragnarok/retrieve_and_generate.py | 69 +++++---- 6 files changed, 237 insertions(+), 29 deletions(-) create mode 100644 src/ragnarok/generate/gemini.py diff --git a/src/ragnarok/generate/api_keys.py b/src/ragnarok/generate/api_keys.py index de43c81..e61af91 100644 --- a/src/ragnarok/generate/api_keys.py +++ b/src/ragnarok/generate/api_keys.py @@ -28,6 +28,9 @@ def get_cohere_api_key() -> str: load_dotenv(dotenv_path=f".env.local") return os.getenv("CO_API_KEY") +def get_gemini_api_key() -> str: + load_dotenv(dotenv_path=f".env.local") + return os.getenv("GEMINI_API_KEY") def get_anyscale_api_key() -> str: load_dotenv(dotenv_path=f".env.local") diff --git a/src/ragnarok/generate/gemini.py b/src/ragnarok/generate/gemini.py new file mode 100644 index 0000000..9d6d660 --- /dev/null +++ b/src/ragnarok/generate/gemini.py @@ -0,0 +1,144 @@ +import time +from enum import Enum +from typing import Any, Dict, List, Tuple, Union + +#import openai +import os +import google.generativeai as genai + +#import tiktoken + +from ragnarok.data import RAGExecInfo, Request +from ragnarok.generate.llm import LLM, PromptMode +from ragnarok.generate.post_processor import gemini_post_processor +from ragnarok.generate.templates.ragnarok_templates import RagnarokTemplates +from ragnarok.data import CitedSentence + +class Gemini(LLM): + def __init__( + self, + model: str, + context_size: int, + citation_length: int, + prompt_mode: PromptMode = PromptMode.GEMINI, + max_output_tokens: int = 1500, + num_few_shot_examples: int = 0, + key: str = None, + ) -> None: + """ + Creates instance of the Gemini class, used to make Gemini models perform generation in RAG pipelines. + The Gemini 1.5, 1.5-flash, and 2.0-flash-experimental models are the only ones implemented so far. + + Parameters: + - model (str): The model identifier for the LLM. + - context_size (int): The maximum number of tokens that the model can handle in a single request. + - citation_length (int): The number of citations used for generation. + - prompt_mode (PromptMode, optional): Specifies the mode of prompt generation, with the default set to GEMINI. + - max_output_tokens (int, optional): Maximum number of tokens that can be generated in a single response. Defaults to 1500. + - num_few_shot_examples (int, optional): Number of few-shot learning examples to include in the prompt, allowing for + the integration of example-based learning to improve model performance. Defaults to 0, indicating no few-shot examples + by default. + - key (str, optional): A single Gemini API key. + + Raises: + - ValueError: If an unsupported prompt mode is provided or if no API key / invalid API key is supplied. + """ + + # Initialize values and check for errors in entered values + super().__init__( + model, context_size, prompt_mode, max_output_tokens, num_few_shot_examples, key, citation_length + ) + if isinstance(str(key), str): + key = key + if not key: + raise ValueError("Please provide Gemini api key to GEMINI_API_KEY env variable.") + if prompt_mode not in [ + PromptMode.GEMINI, + PromptMode.CHATQA + ]: + raise ValueError( + f"unsupported prompt mode for Gemini models: {prompt_mode}, expected one of {PromptMode.GEMINI}, {PromptMode.CHATQA}." + ) + + # Configure model parameters + genai.configure(api_key=str(key)) + generation_config = { + "temperature": 1, + "top_p": 0.95, + "top_k": 40, + "max_output_tokens": max_output_tokens, + "response_mime_type": "text/plain", + } + safety_settings = [ + {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"}, + {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"}, + {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"}, + {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"} + ] + system_instruction = "This is a chat between a user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions based on the context. The assistant should also indicate when the answer cannot be found in the context." + + # Initialize model + self.gen_model = genai.GenerativeModel( + model_name=model, + generation_config=generation_config, + safety_settings=safety_settings, + system_instruction=system_instruction + ) + + def run_llm( + self, + prompt: Union[str, List[Dict[str, str]]], + logging: bool = False, + ) -> Tuple[List[CitedSentence], RAGExecInfo]: + + chat_session = self.gen_model.start_chat( + history=[ + ] + ) + + response = chat_session.send_message(prompt).text + + answers, rag_exec_response = gemini_post_processor(response, self._citation_length) + + rag_exec_info = RAGExecInfo( + prompt=prompt, + response=rag_exec_response, + input_token_count=None, + output_token_count=None, + candidates=[], + ) + + if logging: + print(f"Prompt: {prompt}") + print(f"Response: {response}") + print(f"Answers: {answers}") + print(f"RAG Exec Info: {rag_exec_info}") + + return answers, rag_exec_info + + def create_prompt( + self, request: Request, topk: int + ) -> Tuple[List[Dict[str, str]], int]: + query = request.query.text + max_length = (self._context_size - 200) // topk + rank = 0 + context = [] + for cand in request.candidates[:topk]: + rank += 1 + content = self.convert_doc_to_prompt_content(cand.doc, max_length) + context.append( + f"[{rank}] {self._replace_number(content)}", + ) + + ragnarok_template = RagnarokTemplates(self._prompt_mode) + messages = ragnarok_template(query, context, "gemini") + + return messages + + def get_num_tokens(self, prompt: Union[str, List[Dict[str, str]]]) -> int: + """Placeholder function. Returns 1.""" + return 1 + + def cost_per_1k_token(self, input_token: bool) -> float: + """Placeholder function. Returns 1""" + return 1 \ No newline at end of file diff --git a/src/ragnarok/generate/llm.py b/src/ragnarok/generate/llm.py index 2e03339..113e648 100644 --- a/src/ragnarok/generate/llm.py +++ b/src/ragnarok/generate/llm.py @@ -20,6 +20,7 @@ class PromptMode(Enum): RAGNAROK_V5_BIOGEN = "ragnarok_v5_biogen" RAGNAROK_V5_BIOGEN_NO_CITE = "ragnarok_v5_biogen_no_cite" RAGNAROK_V4_NO_CITE = "ragnarok_v4_no_cite" + GEMINI = "gemini" def __str__(self): return self.value @@ -193,7 +194,7 @@ def answer_batch( initial_results.append(result) else: for request in requests: - prompt, input_token_count = self.create_prompt(request, topk) + prompt = self.create_prompt(request, topk) answer, rag_exec_summary = self.run_llm(prompt, logging) rag_exec_summary.candidates = [ candidate.__dict__ for candidate in request.candidates[:topk] diff --git a/src/ragnarok/generate/post_processor.py b/src/ragnarok/generate/post_processor.py index 6ed9c40..7665ad5 100644 --- a/src/ragnarok/generate/post_processor.py +++ b/src/ragnarok/generate/post_processor.py @@ -198,3 +198,48 @@ def __call__(self, response) -> List[Dict[str, Any]]: rag_exec_response = {"text": response, "citations": citation_range} return answers, rag_exec_response + +def gemini_post_processor(response: str, citation_length: int) -> List[Dict[str, Any]]: + # Remove all \n then split text into sentences. + text_output = response.replace("\n", "") + sentences = text_output.split(".") + + # Avoids last entry being empty, which can cause errors later. + if not sentences[-1]: + sentences.pop() + + citation_range = list(range(1, citation_length)) + + rag_exec_response = {"text": response, "citations": citation_range} + + answers = [] + for sentence in sentences: + + sentence = sentence + "." + if sentence.startswith(" "): + sentence = sentence[1:] + + citation_list = [] + p1 = sentence.find("[") + p2 = sentence.find("]") + current_citations = [] + + while p1 != -1: + + current_citations.extend(sentence[p1+1:p2].replace(" ", "").split(",")) + for citation in current_citations: + # avoid empty citations + if citation: + # avoid repeated citations + if (not int(citation) in citation_list): + citation_list.append(int(citation)) + sentence = sentence[:p1-1] + sentence[p2+1:] + + p1 = sentence.find("[") + p2 = sentence.find("]") + + answers.append(CitedSentence(text=sentence, citations=citation_list)) + + return answers, rag_exec_response + + diff --git a/src/ragnarok/generate/templates/ragnarok_templates.py b/src/ragnarok/generate/templates/ragnarok_templates.py index caa36af..33c5d16 100644 --- a/src/ragnarok/generate/templates/ragnarok_templates.py +++ b/src/ragnarok/generate/templates/ragnarok_templates.py @@ -179,6 +179,8 @@ def __call__(self, query: str, context: List[str], model: str) -> List[str]: return messages elif "chatqa" in model.lower(): prompt = f"{self.system_message_chatqa}{self.sep}{self.input_context.format(context=str_context)}{self.sep}User: {user_input_context}" + elif "gemini" in model.lower(): + prompt = f"{self.system_message_chatqa}{self.sep}{self.input_context.format(context=str_context)}{self.sep}User input context: {user_input_context}" else: conv = get_conversation_template(model) system_message = ( diff --git a/src/ragnarok/retrieve_and_generate.py b/src/ragnarok/retrieve_and_generate.py index 22e56c4..a04ac2e 100644 --- a/src/ragnarok/retrieve_and_generate.py +++ b/src/ragnarok/retrieve_and_generate.py @@ -3,10 +3,11 @@ from ragnarok.data import Query, Request # from ragnarok.evaluation.nugget_eval import EvalFunction -from ragnarok.generate.api_keys import get_azure_openai_args, get_openai_api_key +from ragnarok.generate.api_keys import get_azure_openai_args, get_openai_api_key, get_gemini_api_key from ragnarok.generate.cohere import Cohere from ragnarok.generate.generator import RAG from ragnarok.generate.gpt import SafeOpenai +from ragnarok.generate.gemini import Gemini from ragnarok.generate.llm import PromptMode from ragnarok.generate.os_llm import OSLLM from ragnarok.retrieve_and_rerank.restriever import Restriever @@ -79,6 +80,33 @@ def retrieve_and_generate( dict: The generation results in JSON format specified by the TREC 2024 RAG Track. """ + # Retrieve + Rerank + print("Calling reranker API...") + # Only DATASET mode is currently supported. + if retrieval_mode == RetrievalMode.DATASET: + if interactive: + # Calls the host_reranker API to obtain the results after first 2 stages (retrieve+rerank) + requests = [ + Restriever.from_dataset_with_prebuilt_index( + dataset_name=dataset, + retrieval_method=retrieval_method, + host_reranker=host_reranker, + host_retriever=host_retriever, + request=Request(query=Query(text=query, qid=qid)), + k=k, + ) + ] + else: + requests = Retriever.from_dataset_with_prebuilt_index( + dataset_name=dataset, + retrieval_method=retrieval_method, + k=k, + cache_input_format=CacheInputFormat.JSONL, + ) + citation_range = len(requests) + else: + raise ValueError(f"Invalid retrieval mode: {retrieval_mode}") + # Construct Generation Agent model_full_path = "" if "gpt" in generator_path: @@ -112,36 +140,21 @@ def retrieve_and_generate( device=device, num_gpus=num_gpus, ) + elif "gemini" in generator_path.lower(): + print(f"Model: {generator_path}") + api_keys = get_gemini_api_key() + agent = Gemini( + model=generator_path, + context_size=context_size, + prompt_mode=prompt_mode, + max_output_tokens=max_output_tokens, + num_few_shot_examples=num_few_shot_examples, + keys=api_keys, + citation_range=citation_range + ) else: raise ValueError(f"Unsupported model: {generator_path}") - # Retrieve + Rerank - print("Calling reranker API...") - # Only DATASET mode is currently supported. - if retrieval_mode == RetrievalMode.DATASET: - if interactive: - # Calls the host_reranker API to obtain the results after first 2 stages (retrieve+rerank) - requests = [ - Restriever.from_dataset_with_prebuilt_index( - dataset_name=dataset, - retrieval_method=retrieval_method, - host_reranker=host_reranker, - host_retriever=host_retriever, - request=Request(query=Query(text=query, qid=qid)), - k=k, - ) - ] - else: - requests = Retriever.from_dataset_with_prebuilt_index( - dataset_name=dataset, - retrieval_method=retrieval_method, - k=k, - cache_input_format=CacheInputFormat.JSONL, - ) - print() - else: - raise ValueError(f"Invalid retrieval mode: {retrieval_mode}") - # Generation print("Generating...") rag = RAG(agent=agent, run_id=run_id) From fe6e5f7adb38a34a7f241604df9156e8a1c89e10 Mon Sep 17 00:00:00 2001 From: Yuv-sue1005 <168255174+Yuv-sue1005@users.noreply.github.com> Date: Tue, 31 Dec 2024 21:07:02 -0800 Subject: [PATCH 2/6] Bug fix and cleaning code --- src/ragnarok/generate/gemini.py | 15 ++++++--------- src/ragnarok/retrieve_and_generate.py | 6 +++--- 2 files changed, 9 insertions(+), 12 deletions(-) diff --git a/src/ragnarok/generate/gemini.py b/src/ragnarok/generate/gemini.py index 9d6d660..816c28a 100644 --- a/src/ragnarok/generate/gemini.py +++ b/src/ragnarok/generate/gemini.py @@ -2,12 +2,9 @@ from enum import Enum from typing import Any, Dict, List, Tuple, Union -#import openai import os import google.generativeai as genai -#import tiktoken - from ragnarok.data import RAGExecInfo, Request from ragnarok.generate.llm import LLM, PromptMode from ragnarok.generate.post_processor import gemini_post_processor @@ -46,12 +43,12 @@ def __init__( # Initialize values and check for errors in entered values super().__init__( - model, context_size, prompt_mode, max_output_tokens, num_few_shot_examples, key, citation_length + model, context_size, prompt_mode, max_output_tokens, num_few_shot_examples ) - if isinstance(str(key), str): - key = key - if not key: - raise ValueError("Please provide Gemini api key to GEMINI_API_KEY env variable.") + self.key = str(key) + self._citation_length = citation_length + if not (key and isinstance(self.key, str)): + raise ValueError(f"Gemini api key not provided or in an invalid format. The key provided (if any) is {key}. Assign the appropriate key to the GEMINI_API_KEY env variable.") if prompt_mode not in [ PromptMode.GEMINI, PromptMode.CHATQA @@ -61,7 +58,7 @@ def __init__( ) # Configure model parameters - genai.configure(api_key=str(key)) + genai.configure(api_key=self.key) generation_config = { "temperature": 1, "top_p": 0.95, diff --git a/src/ragnarok/retrieve_and_generate.py b/src/ragnarok/retrieve_and_generate.py index a04ac2e..1704afa 100644 --- a/src/ragnarok/retrieve_and_generate.py +++ b/src/ragnarok/retrieve_and_generate.py @@ -142,15 +142,15 @@ def retrieve_and_generate( ) elif "gemini" in generator_path.lower(): print(f"Model: {generator_path}") - api_keys = get_gemini_api_key() + api_key = get_gemini_api_key() agent = Gemini( model=generator_path, context_size=context_size, prompt_mode=prompt_mode, max_output_tokens=max_output_tokens, num_few_shot_examples=num_few_shot_examples, - keys=api_keys, - citation_range=citation_range + key=api_key, + citation_length=citation_range ) else: raise ValueError(f"Unsupported model: {generator_path}") From b6a363a84689dcb7b28fba669c0d082a989ccd9b Mon Sep 17 00:00:00 2001 From: Yuv-sue1005 <168255174+Yuv-sue1005@users.noreply.github.com> Date: Fri, 3 Jan 2025 17:19:13 -0800 Subject: [PATCH 3/6] Fix for RPM limitations --- src/ragnarok/generate/gemini.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/ragnarok/generate/gemini.py b/src/ragnarok/generate/gemini.py index 816c28a..ecec8c8 100644 --- a/src/ragnarok/generate/gemini.py +++ b/src/ragnarok/generate/gemini.py @@ -4,6 +4,7 @@ import os import google.generativeai as genai +import google.api_core.exceptions as exceptions from ragnarok.data import RAGExecInfo, Request from ragnarok.generate.llm import LLM, PromptMode @@ -93,7 +94,12 @@ def run_llm( ] ) - response = chat_session.send_message(prompt).text + try: + response = chat_session.send_message(prompt).text + except exceptions.ResourceExhausted: + print("rate limit error encountered, waiting 60 seconds...") + time.sleep(60) + response = chat_session.send_message(prompt).text answers, rag_exec_response = gemini_post_processor(response, self._citation_length) @@ -111,6 +117,8 @@ def run_llm( print(f"Answers: {answers}") print(f"RAG Exec Info: {rag_exec_info}") + time.sleep(0.5) + return answers, rag_exec_info def create_prompt( From 9ec9bbe17579b738609370952c9609d7d6ec04c5 Mon Sep 17 00:00:00 2001 From: Yuv-sue1005 <168255174+Yuv-sue1005@users.noreply.github.com> Date: Tue, 7 Jan 2025 20:16:40 -0500 Subject: [PATCH 4/6] Post processor bug fix --- src/ragnarok/generate/post_processor.py | 74 ++++++++++++++++--------- 1 file changed, 49 insertions(+), 25 deletions(-) diff --git a/src/ragnarok/generate/post_processor.py b/src/ragnarok/generate/post_processor.py index 7665ad5..7bcd9fe 100644 --- a/src/ragnarok/generate/post_processor.py +++ b/src/ragnarok/generate/post_processor.py @@ -199,6 +199,53 @@ def __call__(self, response) -> List[Dict[str, Any]]: return answers, rag_exec_response +def gemini_find_citations(sentence: str, citation_range: List[int] = list(range(20)) + ) -> tuple[str, List[int]]: + # Regex pattern to find citations + pattern = re.compile(r"\[\d+\](?:,? ?)") + + # Find all citations + citations = pattern.findall(sentence) + if citations: + # Remove citations from text + sentence = pattern.sub("", sentence).strip() + + # Extract indices from citations + indices = [ + int(re.search(r"\d+", citation).group()) - 1 for citation in citations + ] + citations = [index for index in indices if index in citation_range] + else: + matches = re.findall(r"\[[^\]]*\]", sentence) + if not matches: + return sentence, [] + citations = [] + for match in matches: + citation = match[1:-1] + try: + if "," in citation: + flag = False + for cit in citation.split(","): + cit = int(cit) - 1 + if cit in citation_range: + flag = True + citations.append(int(cit)) + if flag: + sentence = sentence.replace(match, "") + else: + citation = int(citation) - 1 + if citation in citation_range: + citations.append(citation) + sentence = sentence.replace(match, "") + except: + print(f"Not a valid citation: {match}") + + sentence = re.sub(" +", " ", sentence) + if len(sentence) > 3: + if sentence[-2] == " ": + sentence = sentence[:-2] + sentence[-1] + return sentence, citations + def gemini_post_processor(response: str, citation_length: int) -> List[Dict[str, Any]]: # Remove all \n then split text into sentences. text_output = response.replace("\n", "") @@ -214,31 +261,8 @@ def gemini_post_processor(response: str, citation_length: int) -> List[Dict[str, answers = [] for sentence in sentences: - - sentence = sentence + "." - if sentence.startswith(" "): - sentence = sentence[1:] - - citation_list = [] - p1 = sentence.find("[") - p2 = sentence.find("]") - current_citations = [] - - while p1 != -1: - - current_citations.extend(sentence[p1+1:p2].replace(" ", "").split(",")) - for citation in current_citations: - # avoid empty citations - if citation: - # avoid repeated citations - if (not int(citation) in citation_list): - citation_list.append(int(citation)) - sentence = sentence[:p1-1] + sentence[p2+1:] - - p1 = sentence.find("[") - p2 = sentence.find("]") - - answers.append(CitedSentence(text=sentence, citations=citation_list)) + sentence_parsed, citations = gemini_find_citations(sentence, citation_range) + answers.append(CitedSentence(text=sentence_parsed, citations=citations)) return answers, rag_exec_response From 7746ec34f4803989e459f2da93062fb90a478b0e Mon Sep 17 00:00:00 2001 From: Yuv-sue1005 <168255174+Yuv-sue1005@users.noreply.github.com> Date: Thu, 9 Jan 2025 12:26:07 -0500 Subject: [PATCH 5/6] Replace citation length with topk and clean up files --- src/ragnarok/generate/cohere.py | 1 + src/ragnarok/generate/gemini.py | 65 ++++++++----------- src/ragnarok/generate/gpt.py | 3 +- src/ragnarok/generate/llm.py | 2 +- src/ragnarok/generate/os_llm.py | 2 +- src/ragnarok/generate/post_processor.py | 8 +-- .../generate/templates/ragnarok_templates.py | 10 +-- src/ragnarok/retrieve_and_generate.py | 3 +- 8 files changed, 42 insertions(+), 52 deletions(-) diff --git a/src/ragnarok/generate/cohere.py b/src/ragnarok/generate/cohere.py index 96ce5a9..20c4d4b 100644 --- a/src/ragnarok/generate/cohere.py +++ b/src/ragnarok/generate/cohere.py @@ -83,6 +83,7 @@ def run_llm( self, prompt: Union[str, List[Dict[str, Any]]], logging: bool = False, + topk: int = 20, ) -> Tuple[Any, RAGExecInfo]: query, top_k_docs = prompt[0]["query"], prompt[0]["context"] if logging: diff --git a/src/ragnarok/generate/gemini.py b/src/ragnarok/generate/gemini.py index ecec8c8..af4fee4 100644 --- a/src/ragnarok/generate/gemini.py +++ b/src/ragnarok/generate/gemini.py @@ -1,69 +1,65 @@ import time from enum import Enum from typing import Any, Dict, List, Tuple, Union - import os import google.generativeai as genai import google.api_core.exceptions as exceptions - from ragnarok.data import RAGExecInfo, Request from ragnarok.generate.llm import LLM, PromptMode from ragnarok.generate.post_processor import gemini_post_processor from ragnarok.generate.templates.ragnarok_templates import RagnarokTemplates from ragnarok.data import CitedSentence - class Gemini(LLM): def __init__( self, model: str, context_size: int, - citation_length: int, prompt_mode: PromptMode = PromptMode.GEMINI, max_output_tokens: int = 1500, num_few_shot_examples: int = 0, key: str = None, ) -> None: """ - Creates instance of the Gemini class, used to make Gemini models perform generation in RAG pipelines. + Creates instance of the Gemini class, used to make Gemini models perform generation in RAG pipelines. The Gemini 1.5, 1.5-flash, and 2.0-flash-experimental models are the only ones implemented so far. - Parameters: - model (str): The model identifier for the LLM. - context_size (int): The maximum number of tokens that the model can handle in a single request. - - citation_length (int): The number of citations used for generation. - prompt_mode (PromptMode, optional): Specifies the mode of prompt generation, with the default set to GEMINI. - max_output_tokens (int, optional): Maximum number of tokens that can be generated in a single response. Defaults to 1500. - num_few_shot_examples (int, optional): Number of few-shot learning examples to include in the prompt, allowing for the integration of example-based learning to improve model performance. Defaults to 0, indicating no few-shot examples by default. - key (str, optional): A single Gemini API key. - Raises: - ValueError: If an unsupported prompt mode is provided or if no API key / invalid API key is supplied. """ - # Initialize values and check for errors in entered values super().__init__( model, context_size, prompt_mode, max_output_tokens, num_few_shot_examples ) self.key = str(key) - self._citation_length = citation_length if not (key and isinstance(self.key, str)): raise ValueError(f"Gemini api key not provided or in an invalid format. The key provided (if any) is {key}. Assign the appropriate key to the GEMINI_API_KEY env variable.") if prompt_mode not in [ - PromptMode.GEMINI, - PromptMode.CHATQA + PromptMode.CHATQA, + PromptMode.RAGNAROK_V2, + PromptMode.RAGNAROK_V3, + PromptMode.RAGNAROK_V4, + PromptMode.RAGNAROK_V4_BIOGEN, + PromptMode.RAGNAROK_V5_BIOGEN, + PromptMode.RAGNAROK_V5_BIOGEN_NO_CITE, + PromptMode.RAGNAROK_V4_NO_CITE, ]: raise ValueError( - f"unsupported prompt mode for Gemini models: {prompt_mode}, expected one of {PromptMode.GEMINI}, {PromptMode.CHATQA}." + f"unsupported prompt mode for GPT models: {prompt_mode}, expected one of {PromptMode.CHATQA}, {PromptMode.RAGNAROK_V2}, {PromptMode.RAGNAROK_V3}, {PromptMode.RAGNAROK_V4}, {PromptMode.RAGNAROK_V4_NO_CITE}." ) - # Configure model parameters genai.configure(api_key=self.key) generation_config = { - "temperature": 1, + "temperature": 0.1, "top_p": 0.95, - "top_k": 40, + "top_k": 64, "max_output_tokens": max_output_tokens, "response_mime_type": "text/plain", } @@ -73,8 +69,7 @@ def __init__( {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"}, {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"} ] - system_instruction = "This is a chat between a user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions based on the context. The assistant should also indicate when the answer cannot be found in the context." - + system_instruction = "This is a chat between a user and an artificial intelligence assistant. The assistant gives helpful and detailed answers to the user's question based on the context references. The assistant should also indicate when the answer cannot be found in the context references." # Initialize model self.gen_model = genai.GenerativeModel( model_name=model, @@ -82,27 +77,29 @@ def __init__( safety_settings=safety_settings, system_instruction=system_instruction ) - def run_llm( self, prompt: Union[str, List[Dict[str, str]]], logging: bool = False, + topk: int = 1, ) -> Tuple[List[CitedSentence], RAGExecInfo]: - chat_session = self.gen_model.start_chat( history=[ ] ) - - try: - response = chat_session.send_message(prompt).text - except exceptions.ResourceExhausted: - print("rate limit error encountered, waiting 60 seconds...") - time.sleep(60) - response = chat_session.send_message(prompt).text - - answers, rag_exec_response = gemini_post_processor(response, self._citation_length) - + while True: + try: + prompt_text = prompt[-1]["content"] + response = chat_session.send_message(prompt_text) + response = response.text + break + except exceptions.ResourceExhausted: + print("rate limit error encountered, waiting 2 seconds...") + time.sleep(2) + except: + print("unknown error encountered, waiting 2 seconds...") + time.sleep(2) + answers, rag_exec_response = gemini_post_processor(response, topk) rag_exec_info = RAGExecInfo( prompt=prompt, response=rag_exec_response, @@ -110,17 +107,13 @@ def run_llm( output_token_count=None, candidates=[], ) - if logging: print(f"Prompt: {prompt}") print(f"Response: {response}") print(f"Answers: {answers}") print(f"RAG Exec Info: {rag_exec_info}") - time.sleep(0.5) - return answers, rag_exec_info - def create_prompt( self, request: Request, topk: int ) -> Tuple[List[Dict[str, str]], int]: @@ -134,16 +127,12 @@ def create_prompt( context.append( f"[{rank}] {self._replace_number(content)}", ) - ragnarok_template = RagnarokTemplates(self._prompt_mode) messages = ragnarok_template(query, context, "gemini") - return messages - def get_num_tokens(self, prompt: Union[str, List[Dict[str, str]]]) -> int: """Placeholder function. Returns 1.""" return 1 - def cost_per_1k_token(self, input_token: bool) -> float: """Placeholder function. Returns 1""" return 1 \ No newline at end of file diff --git a/src/ragnarok/generate/gpt.py b/src/ragnarok/generate/gpt.py index 03e03e7..65f40ce 100644 --- a/src/ragnarok/generate/gpt.py +++ b/src/ragnarok/generate/gpt.py @@ -138,6 +138,7 @@ def run_llm( self, prompt: Union[str, List[Dict[str, str]]], logging: bool = False, + topk: int = 20, ) -> Tuple[str, RAGExecInfo]: model_key = "model" if logging: @@ -155,7 +156,7 @@ def run_llm( encoding = tiktoken.get_encoding("cl100k_base") if logging: print(f"Response: {response}") - answers, rag_exec_response = self._post_processor(response) + answers, rag_exec_response = self._post_processor(response, topk) if logging: print(f"Answers: {answers}") rag_exec_info = RAGExecInfo( diff --git a/src/ragnarok/generate/llm.py b/src/ragnarok/generate/llm.py index 113e648..28eb0de 100644 --- a/src/ragnarok/generate/llm.py +++ b/src/ragnarok/generate/llm.py @@ -195,7 +195,7 @@ def answer_batch( else: for request in requests: prompt = self.create_prompt(request, topk) - answer, rag_exec_summary = self.run_llm(prompt, logging) + answer, rag_exec_summary = self.run_llm(prompt, logging, topk) rag_exec_summary.candidates = [ candidate.__dict__ for candidate in request.candidates[:topk] ] diff --git a/src/ragnarok/generate/os_llm.py b/src/ragnarok/generate/os_llm.py index 17dd649..0b51cb3 100644 --- a/src/ragnarok/generate/os_llm.py +++ b/src/ragnarok/generate/os_llm.py @@ -133,7 +133,7 @@ def run_llm_batched( assert False, "Failed run_llm_batched" def run_llm( - self, prompt: str, logging: bool = False, vllm: bool = True + self, prompt: str, logging: bool = False, vllm: bool = True, topk: int = 20, ) -> Tuple[str, int]: if logging: print(f"Prompt: {prompt}") diff --git a/src/ragnarok/generate/post_processor.py b/src/ragnarok/generate/post_processor.py index 7bcd9fe..073e665 100644 --- a/src/ragnarok/generate/post_processor.py +++ b/src/ragnarok/generate/post_processor.py @@ -181,14 +181,14 @@ def _find_sentence_citations( sentence = sentence[:-2] + sentence[-1] return sentence, citations - def __call__(self, response) -> List[Dict[str, Any]]: + def __call__(self, response, topk) -> List[Dict[str, Any]]: text_output = response # Remove all \nNote: and \nReferences: from the text text_output = re.sub(r"\nNote:.*", "", text_output) text_output = re.sub(r"\nReferences:.*", "", text_output) sentences = self.tokenizer.tokenize(text_output) answers = [] - citation_range = list(range(20)) + citation_range = list(range(1, topk)) for sentence in sentences: sentence_parsed, citations = self._find_sentence_citations( sentence, citation_range @@ -246,7 +246,7 @@ def gemini_find_citations(sentence: str, citation_range: List[int] = list(range( sentence = sentence[:-2] + sentence[-1] return sentence, citations -def gemini_post_processor(response: str, citation_length: int) -> List[Dict[str, Any]]: +def gemini_post_processor(response: str, topk: int) -> List[Dict[str, Any]]: # Remove all \n then split text into sentences. text_output = response.replace("\n", "") sentences = text_output.split(".") @@ -255,7 +255,7 @@ def gemini_post_processor(response: str, citation_length: int) -> List[Dict[str, if not sentences[-1]: sentences.pop() - citation_range = list(range(1, citation_length)) + citation_range = list(range(1, topk)) rag_exec_response = {"text": response, "citations": citation_range} diff --git a/src/ragnarok/generate/templates/ragnarok_templates.py b/src/ragnarok/generate/templates/ragnarok_templates.py index 33c5d16..d24ad9f 100644 --- a/src/ragnarok/generate/templates/ragnarok_templates.py +++ b/src/ragnarok/generate/templates/ragnarok_templates.py @@ -131,7 +131,7 @@ def __call__(self, query: str, context: List[str], model: str) -> List[str]: + self.sep + f"Instruction: {self.get_instruction()}" ) - elif "gpt" in model: + elif "gpt" in model.lower() or "gemini" in model.lower(): user_input_context = ( f"Instruction: {self.get_instruction()}" + self.sep @@ -153,7 +153,7 @@ def __call__(self, query: str, context: List[str], model: str) -> List[str]: + f"Instruction: {self.get_instruction()}" ) - if "gpt" in model: + if "gpt" in model.lower() or "gemini" in model.lower(): messages = [] system_message = ( self.system_message_gpt_no_cite @@ -179,8 +179,8 @@ def __call__(self, query: str, context: List[str], model: str) -> List[str]: return messages elif "chatqa" in model.lower(): prompt = f"{self.system_message_chatqa}{self.sep}{self.input_context.format(context=str_context)}{self.sep}User: {user_input_context}" - elif "gemini" in model.lower(): - prompt = f"{self.system_message_chatqa}{self.sep}{self.input_context.format(context=str_context)}{self.sep}User input context: {user_input_context}" + # elif "gemini" in model.lower(): + # prompt = f"{self.system_message_chatqa}{self.sep}{self.input_context.format(context=str_context)}{self.sep}User input context: {user_input_context}" else: conv = get_conversation_template(model) system_message = ( @@ -215,4 +215,4 @@ def get_instruction(self) -> str: elif self.prompt_mode == PromptMode.RAGNAROK_V5_BIOGEN: return self.instruction_ragnarok_v5_biogen else: - return self.instruction_ragnarok + return self.instruction_ragnarok \ No newline at end of file diff --git a/src/ragnarok/retrieve_and_generate.py b/src/ragnarok/retrieve_and_generate.py index 1704afa..2c03c96 100644 --- a/src/ragnarok/retrieve_and_generate.py +++ b/src/ragnarok/retrieve_and_generate.py @@ -149,8 +149,7 @@ def retrieve_and_generate( prompt_mode=prompt_mode, max_output_tokens=max_output_tokens, num_few_shot_examples=num_few_shot_examples, - key=api_key, - citation_length=citation_range + key=api_key ) else: raise ValueError(f"Unsupported model: {generator_path}") From a0eff64ed803bfe5560fcee5a0fc14175064a619 Mon Sep 17 00:00:00 2001 From: Yuv-sue1005 <168255174+Yuv-sue1005@users.noreply.github.com> Date: Thu, 23 Jan 2025 14:30:10 -0500 Subject: [PATCH 6/6] README updated --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index a4de870..c8ecdd2 100644 --- a/README.md +++ b/README.md @@ -64,6 +64,7 @@ Most LLMs supported by VLLM/FastChat should additionally be supported by Ragnar | command-r | `command-r` | | Llama-3 8B Instruct | `meta-llama/Meta-Llama-3-8B-Instruct` | | Llama3-ChatQA-1.5 | `nvidia/Llama3-ChatQA-1.5` | +| Gemini | `gemini-2.0-flash-thinking-exp` | ## ✨ References