From 0dec2d53a75361ccfe5b81e89991ba45ea138e1f Mon Sep 17 00:00:00 2001 From: Peilin Yu Date: Thu, 28 Dec 2023 17:14:39 -0500 Subject: [PATCH] adding prompt template-based solution for scoring candidates with API-based m odels --- alfred/fm/ai21.py | 36 +++++++++++++++++- alfred/fm/anthropic.py | 73 +++++++++++++++++++++++++---------- alfred/fm/cohere.py | 36 +++++++++++++++++- alfred/fm/google.py | 76 ++++++++++++++++++++++++------------- alfred/fm/huggingface.py | 6 ++- alfred/fm/model.py | 4 +- alfred/fm/openai.py | 31 ++++++++++++++- alfred/fm/utils.py | 5 ++- alfred/fm/vllm.py | 4 +- docs/alfred/fm/anthropic.md | 2 +- docs/alfred/fm/google.md | 43 +++++++++++++++++++++ docs/alfred/fm/model.md | 14 +++---- docs/alfred/fm/openai.md | 2 +- docs/alfred/fm/utils.md | 20 +++++----- 14 files changed, 273 insertions(+), 79 deletions(-) create mode 100644 docs/alfred/fm/google.md diff --git a/alfred/fm/ai21.py b/alfred/fm/ai21.py index 5ff4e80..69558d6 100644 --- a/alfred/fm/ai21.py +++ b/alfred/fm/ai21.py @@ -1,10 +1,10 @@ import logging -from typing import Optional, List +from typing import Optional, List, Union, Tuple import requests from .model import APIAccessFoundationModel -from .response import CompletionResponse +from .response import CompletionResponse, RankedResponse logger = logging.getLogger(__name__) @@ -113,3 +113,35 @@ def _generate_batch( ) ) return output + + def _score_batch( + self, + batch_instance: Union[List[Tuple[str, str]], List[str]], + scoring_instruction: str = "Instruction: Given the query, choose your answer from [[label_space]]:\nQuery:\n", + **kwargs, + ) -> List[RankedResponse]: + """ + Tentative solution for scoring candidates. + + :param batch_instance: A list of prompts for which to generate candidate preferences. + :type batch_instance: List[str] or List[Tuple] + :param scoring_instruction: The instruction prompt for scoring + :type scoring_instruction: str + """ + output = [] + for query in batch_instance: + _scoring_prompt = ( + scoring_instruction.replace( + "[[label_space]]", ",".join(query.candidates) + ) + + query.prompt + ) + output.append( + RankedResponse( + prediction=self._ai21_query( + _scoring_prompt, model=self.model_string, **kwargs + ), + scores={}, + ) + ) + return output diff --git a/alfred/fm/anthropic.py b/alfred/fm/anthropic.py index 1f03f7e..dac2bbf 100644 --- a/alfred/fm/anthropic.py +++ b/alfred/fm/anthropic.py @@ -7,7 +7,7 @@ import readline from .model import APIAccessFoundationModel -from .response import CompletionResponse +from .response import CompletionResponse, RankedResponse from .utils import colorize_str, type_print logger = logging.getLogger(__name__) @@ -38,12 +38,12 @@ class AnthropicModel(APIAccessFoundationModel): """ def _anthropic_query( - self, - query: Union[str, List], - temperature: float = 0.0, - max_tokens: int = 3, - model: str = "claude-instant-1", - **kwargs: Any, + self, + query: Union[str, List], + temperature: float = 0.0, + max_tokens: int = 3, + model: str = "claude-instant-1", + **kwargs: Any, ) -> str: """ Run a single query through the foundation model @@ -85,7 +85,7 @@ def _anthropic_query( return response["completion"] def __init__( - self, model_string: str = "claude-instant-1", api_key: Optional[str] = None + self, model_string: str = "claude-instant-1", api_key: Optional[str] = None ): """ Initialize the Anthropic API wrapper. 
@@ -100,7 +100,7 @@ def __init__( :type api_key: Optional[str] """ assert ( - model_string in ANTHROPIC_MODELS + model_string in ANTHROPIC_MODELS ), f"Model {model_string} not found. Please choose from {ANTHROPIC_MODELS}" if "ANTHROPIC_API_KEY" in os.environ: @@ -124,9 +124,9 @@ def __init__( super().__init__(model_string, {"api_key": api_key}) def _generate_batch( - self, - batch_instance: List[str], - **kwargs, + self, + batch_instance: List[str], + **kwargs, ) -> List[CompletionResponse]: """ Generate completions for a batch of prompts using the anthropic API. @@ -152,6 +152,38 @@ def _generate_batch( ) return output + def _score_batch( + self, + batch_instance: Union[List[Tuple[str, str]], List[str]], + scoring_instruction: str = "Instruction: Given the query, choose your answer from [[label_space]]:\nQuery:\n", + **kwargs, + ) -> List[RankedResponse]: + """ + Tentative solution for scoring candidates. + + :param batch_instance: A list of prompts for which to generate candidate preferences. + :type batch_instance: List[str] or List[Tuple] + :param scoring_instruction: The instruction prompt for scoring + :type scoring_instruction: str + """ + output = [] + for query in batch_instance: + _scoring_prompt = ( + scoring_instruction.replace( + "[[label_space]]", ",".join(query.candidates) + ) + + query.prompt + ) + output.append( + RankedResponse( + prediction=self._anthropic_query( + _scoring_prompt, model=self.model_string, **kwargs + ), + scores={}, + ) + ) + return output + def chat(self, **kwargs: Any): """ Launch an interactive chat session with the Anthropic API. @@ -165,9 +197,10 @@ def _feedback(feedback: str, no_newline=False, override=False): end="", ) type_print(feedback) - print("", - end="\n" if not no_newline else "", - ) + print( + "", + end="\n" if not no_newline else "", + ) model = kwargs.get("model", self.model_string) c_title = colorize_str("Alfred's Anthropic Chat", "BLUE") @@ -205,11 +238,11 @@ def _feedback(feedback: str, no_newline=False, override=False): message_log.append({"role": "user", "content": query}) response = [] for resp in self._anthropic_query( - query, - chat=True, - model=model, - temperature=temperature, - max_tokens=max_tokens, + query, + chat=True, + model=model, + temperature=temperature, + max_tokens=max_tokens, ): if resp["stop_reason"] in ["stop", "stop_sequence"]: break diff --git a/alfred/fm/cohere.py b/alfred/fm/cohere.py index 94c2cc0..4b80859 100644 --- a/alfred/fm/cohere.py +++ b/alfred/fm/cohere.py @@ -1,10 +1,10 @@ import logging -from typing import Optional, List, Any +from typing import Optional, List, Any, Union, Tuple import torch from .model import APIAccessFoundationModel -from .response import CompletionResponse +from .response import CompletionResponse, RankedResponse logger = logging.getLogger(__name__) @@ -47,6 +47,38 @@ def _cohere_query( ) return response.generations[0].text + def _score_batch( + self, + batch_instance: Union[List[Tuple[str, str]], List[str]], + scoring_instruction: str = "Instruction: Given the query, choose your answer from [[label_space]]:\nQuery:\n", + **kwargs, + ) -> List[RankedResponse]: + """ + Tentative solution for scoring candidates. + + :param batch_instance: A list of prompts for which to generate candidate preferences. 
+ :type batch_instance: List[str] or List[Tuple] + :param scoring_instruction: The instruction prompt for scoring + :type scoring_instruction: str + """ + output = [] + for query in batch_instance: + _scoring_prompt = ( + scoring_instruction.replace( + "[[label_space]]", ",".join(query.candidates) + ) + + query.prompt + ) + output.append( + RankedResponse( + prediction=self._cohere_query( + _scoring_prompt, model=self.model_string, **kwargs + ), + scores={}, + ) + ) + return output + def _cohere_embedding_query( self, query_string: str, diff --git a/alfred/fm/google.py b/alfred/fm/google.py index 0180003..2533382 100644 --- a/alfred/fm/google.py +++ b/alfred/fm/google.py @@ -8,7 +8,7 @@ import torch from .model import APIAccessFoundationModel -from .response import CompletionResponse +from .response import CompletionResponse, RankedResponse from .utils import colorize_str, retry, type_print logger = logging.getLogger(__name__) @@ -23,17 +23,12 @@ "Google GenAI module not found. Please install google-generativeai to use the Google model." ) -GOOGLE_GENAI_MODELS = ( - "gemini-pro", -) +GOOGLE_GENAI_MODELS = ("gemini-pro",) -GOOGLE_GENAI_VISION_MODELS = ( - "gemini-pro-vision", -) +GOOGLE_GENAI_VISION_MODELS = ("gemini-pro-vision",) + +GOOGLE_GENAI_EMBEDDING_MODELS = ("embedding-001",) -GOOGLE_GENAI_EMBEDDING_MODELS = ( - "embedding-001", -) class GoogleModel(APIAccessFoundationModel): """ @@ -45,9 +40,7 @@ class GoogleModel(APIAccessFoundationModel): @retry( num_retries=3, wait_time=0.1, - exceptions=( - Exception - ), + exceptions=(Exception), ) def _google_genai_query( self, @@ -73,27 +66,25 @@ def _google_genai_query( if self.model_string in GOOGLE_GENAI_VISION_MODELS: img, prompt = query[0], query[1] if not isinstance(img, PIL.Image.Image): - raise ValueError(f"Image type {type(img)} not supported. Please use PIL.Image!") - query = [ - prompt, img - ] if len(prompt) > 0 else [img] + raise ValueError( + f"Image type {type(img)} not supported. Please use PIL.Image!" + ) + query = [prompt, img] if len(prompt) > 0 else [img] response = self.model.generate_content( query, generation_config=genai.types.GenerationConfig( candidate_count=1, - stop_sequences=['x'], + stop_sequences=["x"], max_output_tokens=max_tokens, temperature=temperature, - ) + ), ) return response.text @retry( num_retries=3, wait_time=0.1, - exceptions=( - Exception - ), + exceptions=(Exception), ) def _google_genai_embedding_query( self, @@ -114,12 +105,11 @@ def _google_genai_embedding_query( model=f"models/{self.model_string}", content=query_string, task_type="retrieval_document", - title="Embedding of single string") + title="Embedding of single string", + ) ) - def __init__( - self, model_string: str = "gemini-pro", api_key: Optional[str] = None - ): + def __init__(self, model_string: str = "gemini-pro", api_key: Optional[str] = None): """ Initialize the Google API wrapper. @@ -138,7 +128,9 @@ def __init__( raise RuntimeError("Google GenAI requires Python 3.9+") assert ( model_string - in GOOGLE_GENAI_MODELS + GOOGLE_GENAI_EMBEDDING_MODELS + GOOGLE_GENAI_VISION_MODELS + in GOOGLE_GENAI_MODELS + + GOOGLE_GENAI_EMBEDDING_MODELS + + GOOGLE_GENAI_VISION_MODELS ), ( f"Model {model_string} not found. 
" f"Please choose from {GOOGLE_GENAI_MODELS + GOOGLE_GENAI_EMBEDDING_MODELS + GOOGLE_GENAI_VISION_MODELS}" @@ -189,6 +181,36 @@ def _generate_batch( ) return output + def _score_batch( + self, + batch_instance: Union[List[Tuple[str, str]], List[str]], + scoring_instruction: str = "Instruction: Given the query, choose your answer from [[label_space]]:\nQuery:\n", + **kwargs, + ) -> List[RankedResponse]: + """ + Tentative solution for scoring candidates. + + :param batch_instance: A list of prompts for which to generate candidate preferences. + :type batch_instance: List[str] or List[Tuple] + :param scoring_instruction: The instruction prompt for scoring + :type scoring_instruction: str + """ + output = [] + for query in batch_instance: + _scoring_prompt = ( + scoring_instruction.replace( + "[[label_space]]", ",".join(query.candidates) + ) + + query.prompt + ) + output.append( + RankedResponse( + prediction=self._google_genai_query(_scoring_prompt, **kwargs), + scores={}, + ) + ) + return output + def _encode_batch( self, batch_instance: [List[str]], diff --git a/alfred/fm/huggingface.py b/alfred/fm/huggingface.py index 398c922..9b27fa6 100644 --- a/alfred/fm/huggingface.py +++ b/alfred/fm/huggingface.py @@ -249,7 +249,6 @@ def _score_batch( max_length=self.max_position_embeddings, ) - end_device = list(self.model.hf_device_map.values())[-1] logger.log(logging.INFO, f"Ranking {len(candidate)} instances") @@ -276,7 +275,10 @@ def _score_batch( padding=True, truncation=True, max_length=self.max_position_embeddings, - add_special_tokens=not (isinstance(self.model, LlamaPreTrainedModel) or isinstance(self.model, MistralPreTrainedModel)), + add_special_tokens=not ( + isinstance(self.model, LlamaPreTrainedModel) + or isinstance(self.model, MistralPreTrainedModel) + ), return_tensors="pt", ) candidate_token_ids = candidate_tokens.input_ids.to(end_device) diff --git a/alfred/fm/model.py b/alfred/fm/model.py index 1e0a6c8..9446bbf 100644 --- a/alfred/fm/model.py +++ b/alfred/fm/model.py @@ -206,7 +206,9 @@ def forward( logging.info(f"New batch size: {batch_size}") elif batch_policy == "dynamic": DB = DynamicBatcher( - queries, tokenizer=tokenizer, max_batch_size=int(DB.max_batch_size * 0.9) + queries, + tokenizer=tokenizer, + max_batch_size=int(DB.max_batch_size * 0.9), ) DB.limit_size = int(DB.limit_size * 0.9) batched_queries = DB.batch() diff --git a/alfred/fm/openai.py b/alfred/fm/openai.py index fbc1dab..b125508 100644 --- a/alfred/fm/openai.py +++ b/alfred/fm/openai.py @@ -8,7 +8,7 @@ import readline from .model import APIAccessFoundationModel -from .response import CompletionResponse +from .response import CompletionResponse, RankedResponse from .utils import colorize_str, retry, encode_image, type_print logger = logging.getLogger(__name__) @@ -285,6 +285,35 @@ def _encode_batch( output.append(self._openai_embedding_query(query, **kwargs)) return output + def _score_batch( + self, + batch_instance: Union[List[Tuple[str, str]], List[str]], + scoring_instruction: str = "Instruction: Given the query, choose your answer from [[label_space]]:\nQuery:\n", + **kwargs, + ) -> List[RankedResponse]: + """ + Tentative solution for scoring candidates. + + :param batch_instance: A list of prompts for which to generate candidate preferences. 
+ :type batch_instance: List[str] or List[Tuple] + :param scoring_instruction: The instruction prompt for scoring + :type scoring_instruction: str + """ + output = [] + for query in batch_instance: + _scoring_prompt = ( + scoring_instruction.replace( + "[[label_space]]", ",".join(query.candidates) + ) + + query.prompt + ) + output.append( + RankedResponse( + prediction=self._openai_query(_scoring_prompt, **kwargs), scores={} + ) + ) + return output + def chat(self, **kwargs: Any): """ Launch an interactive chat session with the OpenAI API. diff --git a/alfred/fm/utils.py b/alfred/fm/utils.py index 9ae75c1..5e89a0d 100644 --- a/alfred/fm/utils.py +++ b/alfred/fm/utils.py @@ -166,7 +166,7 @@ def check_pkg_available(pkg_name: str) -> bool: raise ImportError(f"Please install {pkg_name} to use this feature") -def type_print(string, interval=.07, newline=False): +def type_print(string, interval=0.07, newline=False): """ Print a string word by word to simulate typing """ @@ -174,7 +174,8 @@ def type_print(string, interval=.07, newline=False): print(word, end=" ", flush=True) time.sleep(interval) print("\b", end="", flush=True) - if newline: print("") + if newline: + print("") def retry(num_retries=3, wait_time=0.1, exceptions=(Exception,)): diff --git a/alfred/fm/vllm.py b/alfred/fm/vllm.py index 7b5beae..de84e4d 100644 --- a/alfred/fm/vllm.py +++ b/alfred/fm/vllm.py @@ -22,9 +22,7 @@ class vLLMModel(LocalAccessFoundationModel): source: https://github.com/vllm-project/vllm """ - def __init__( - self, model: str, local_dir: str = None, **kwargs: Any - ): + def __init__(self, model: str, local_dir: str = None, **kwargs: Any): """ Initialize a VLLM with MultiGPU. diff --git a/docs/alfred/fm/anthropic.md b/docs/alfred/fm/anthropic.md index 1160250..8ff00cd 100644 --- a/docs/alfred/fm/anthropic.md +++ b/docs/alfred/fm/anthropic.md @@ -31,7 +31,7 @@ class AnthropicModel(APIAccessFoundationModel): ### AnthropicModel().chat -[Show source in anthropic.py:155](../../../alfred/fm/anthropic.py#L155) +[Show source in anthropic.py:187](../../../alfred/fm/anthropic.py#L187) Launch an interactive chat session with the Anthropic API. diff --git a/docs/alfred/fm/google.md b/docs/alfred/fm/google.md new file mode 100644 index 0000000..64986b8 --- /dev/null +++ b/docs/alfred/fm/google.md @@ -0,0 +1,43 @@ +# Google + +[Alfred Index](../../README.md#alfred-index) / +[Alfred](../index.md#alfred) / +[Fm](./index.md#fm) / +Google + +> Auto-generated documentation for [alfred.fm.google](../../../alfred/fm/google.py) module. + +- [Google](#google) + - [GoogleModel](#googlemodel) + - [GoogleModel().chat](#googlemodel()chat) + +## GoogleModel + +[Show source in google.py:33](../../../alfred/fm/google.py#L33) + +A wrapper for the Google API. + +This class provides a wrapper for the Google API for generating completions. + +#### Signature + +```python +class GoogleModel(APIAccessFoundationModel): + def __init__(self, model_string: str = "gemini-pro", api_key: Optional[str] = None): + ... +``` + +### GoogleModel().chat + +[Show source in google.py:246](../../../alfred/fm/google.py#L246) + +Launch an interactive chat session with the Google API. + +#### Signature + +```python +def chat(self, **kwargs: Any): + ... 
+``` + + diff --git a/docs/alfred/fm/model.md b/docs/alfred/fm/model.md index cc5468b..e494305 100644 --- a/docs/alfred/fm/model.md +++ b/docs/alfred/fm/model.md @@ -20,7 +20,7 @@ Model ## APIAccessFoundationModel -[Show source in model.py:380](../../../alfred/fm/model.py#L380) +[Show source in model.py:382](../../../alfred/fm/model.py#L382) #### Signature @@ -51,7 +51,7 @@ class FoundationModel(abc.ABC): ### FoundationModel().__call__ -[Show source in model.py:358](../../../alfred/fm/model.py#L358) +[Show source in model.py:360](../../../alfred/fm/model.py#L360) This function returns the output of the run function when the model is called as a function. It can be used as model(queries), @@ -84,7 +84,7 @@ def __call__( ### FoundationModel().encode -[Show source in model.py:275](../../../alfred/fm/model.py#L275) +[Show source in model.py:277](../../../alfred/fm/model.py#L277) This function is a wrapper around the forward function @@ -168,7 +168,7 @@ def forward( ### FoundationModel().generate -[Show source in model.py:224](../../../alfred/fm/model.py#L224) +[Show source in model.py:226](../../../alfred/fm/model.py#L226) This function is a wrapper around the forward function for running CompletionQuery objects through the foundation model. It returns a list @@ -205,7 +205,7 @@ def generate( ### FoundationModel().run -[Show source in model.py:306](../../../alfred/fm/model.py#L306) +[Show source in model.py:308](../../../alfred/fm/model.py#L308) This function is the main entry point for users to run queries through the foundation model. It accepts raw query content and automatically converts it into query objects. @@ -239,7 +239,7 @@ def run( ### FoundationModel().score -[Show source in model.py:249](../../../alfred/fm/model.py#L249) +[Show source in model.py:251](../../../alfred/fm/model.py#L251) This function is a wrapper around the forward function for running RankedQuery objects through the foundation model. @@ -278,7 +278,7 @@ def score( ## LocalAccessFoundationModel -[Show source in model.py:395](../../../alfred/fm/model.py#L395) +[Show source in model.py:397](../../../alfred/fm/model.py#L397) #### Signature diff --git a/docs/alfred/fm/openai.md b/docs/alfred/fm/openai.md index 622e9c2..0899723 100644 --- a/docs/alfred/fm/openai.md +++ b/docs/alfred/fm/openai.md @@ -31,7 +31,7 @@ class OpenAIModel(APIAccessFoundationModel): ### OpenAIModel().chat -[Show source in openai.py:288](../../../alfred/fm/openai.py#L288) +[Show source in openai.py:317](../../../alfred/fm/openai.py#L317) Launch an interactive chat session with the OpenAI API. diff --git a/docs/alfred/fm/utils.md b/docs/alfred/fm/utils.md index 7422ac5..7fb787c 100644 --- a/docs/alfred/fm/utils.md +++ b/docs/alfred/fm/utils.md @@ -29,7 +29,7 @@ Utils ## DynamicBatcher -[Show source in utils.py:318](../../../alfred/fm/utils.py#L318) +[Show source in utils.py:319](../../../alfred/fm/utils.py#L319) Dynamic Batching Utility Maximize GPU Utilization by batching queries of similar sizes @@ -50,7 +50,7 @@ class DynamicBatcher: ### DynamicBatcher().batch -[Show source in utils.py:447](../../../alfred/fm/utils.py#L447) +[Show source in utils.py:448](../../../alfred/fm/utils.py#L448) Batch a list of instances into a list of batches. 
If the instances are of different sizes, they will be sorted by size @@ -70,7 +70,7 @@ def batch(self) -> List: ### DynamicBatcher().merge_rank_response -[Show source in utils.py:363](../../../alfred/fm/utils.py#L363) +[Show source in utils.py:364](../../../alfred/fm/utils.py#L364) Merge a list of responses with raw logit into a single RankedResponse Assumption: Candidate Order is the same across all ranked queries @@ -98,7 +98,7 @@ def merge_rank_response( ### DynamicBatcher().reorder -[Show source in utils.py:406](../../../alfred/fm/utils.py#L406) +[Show source in utils.py:407](../../../alfred/fm/utils.py#L407) Reordering the responses according to the original order of the queries @@ -125,7 +125,7 @@ def reorder(self, inst: List, offset: Optional[int] = None) -> List: ## EmbeddingCache -[Show source in utils.py:241](../../../alfred/fm/utils.py#L241) +[Show source in utils.py:242](../../../alfred/fm/utils.py#L242) A simple embedding cache for VLM models @@ -139,7 +139,7 @@ class EmbeddingCache: ### EmbeddingCache().get -[Show source in utils.py:267](../../../alfred/fm/utils.py#L267) +[Show source in utils.py:268](../../../alfred/fm/utils.py#L268) Process the inputs and retrieve from the cache/embed the inputs @@ -168,7 +168,7 @@ def get( ## TokenizedBatch -[Show source in utils.py:307](../../../alfred/fm/utils.py#L307) +[Show source in utils.py:308](../../../alfred/fm/utils.py#L308) #### Signature @@ -182,7 +182,7 @@ class TokenizedBatch: ## bcolors -[Show source in utils.py:215](../../../alfred/fm/utils.py#L215) +[Show source in utils.py:216](../../../alfred/fm/utils.py#L216) #### Signature @@ -264,7 +264,7 @@ def clear_cuda_cache(): ## colorize_str -[Show source in utils.py:227](../../../alfred/fm/utils.py#L227) +[Show source in utils.py:228](../../../alfred/fm/utils.py#L228) #### Signature @@ -359,7 +359,7 @@ def reorder_array( ## retry -[Show source in utils.py:180](../../../alfred/fm/utils.py#L180) +[Show source in utils.py:181](../../../alfred/fm/utils.py#L181) A decorator to retry a function call if it raises an exception.
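
For readers skimming the diff, the new `_score_batch` added to `ai21.py`, `anthropic.py`, `cohere.py`, `google.py`, and `openai.py` all follow the same pattern: the `[[label_space]]` placeholder in the scoring instruction is replaced with the comma-joined candidate labels, the result is prepended to the query prompt, and the API's completion is wrapped as the `prediction` of a `RankedResponse` with an empty `scores` dict. Below is a minimal, self-contained sketch of that pattern. The `ScoringQuery`/`ScoredAnswer` classes and the `complete` callable are hypothetical stand-ins for illustration only, not alfred classes or APIs.

```python
# Illustrative sketch of the prompt-template scoring pattern this patch adds.
# ScoringQuery, ScoredAnswer, and `complete` are hypothetical stand-ins, not alfred APIs.
from dataclasses import dataclass, field
from typing import Callable, Dict, List


@dataclass
class ScoringQuery:
    prompt: str
    candidates: List[str]


@dataclass
class ScoredAnswer:
    prediction: str
    scores: Dict[str, float] = field(default_factory=dict)


SCORING_INSTRUCTION = (
    "Instruction: Given the query, choose your answer from [[label_space]]:\nQuery:\n"
)


def score_batch(
    queries: List[ScoringQuery],
    complete: Callable[[str], str],
    instruction: str = SCORING_INSTRUCTION,
) -> List[ScoredAnswer]:
    """Build one scoring prompt per query and let the completion endpoint pick a label."""
    answers = []
    for query in queries:
        # Substitute the candidate labels into the instruction template, then
        # prepend it to the original prompt (same shape as the patched _score_batch).
        scoring_prompt = (
            instruction.replace("[[label_space]]", ",".join(query.candidates))
            + query.prompt
        )
        # Hosted APIs return only generated text here, so scores stay empty.
        answers.append(ScoredAnswer(prediction=complete(scoring_prompt)))
    return answers


if __name__ == "__main__":
    demo = [
        ScoringQuery(
            prompt="Is the review positive or negative?\nReview: great food!",
            candidates=["positive", "negative"],
        )
    ]
    # Replace the lambda with a real completion call (OpenAI, Claude, Cohere, ...).
    print(score_batch(demo, complete=lambda p: "positive"))
```

As in the patched wrappers, the `scores` dictionary is left empty because these hosted endpoints expose no per-candidate log-probabilities; the "score" is effectively the model's free-form choice of label, which is why the docstrings describe this as a tentative solution rather than true ranked scoring.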