diff --git a/alfred/client/client.py b/alfred/client/client.py
index 0da0a62..67308f1 100644
--- a/alfred/client/client.py
+++ b/alfred/client/client.py
@@ -80,6 +80,7 @@ def __init__(
                 "google",
                 "groq",
                 "torch",
+                "openllm",
                 "dummy",
             ], f"Invalid model type: {self.model_type}"
         else:
@@ -99,7 +100,7 @@ def __init__(
             self.run = self.cache.cached_query(self.run)
 
         self.grpcClient = None
-        if end_point:
+        if end_point and model_type not in ["dummy", "openllm"]:
             end_point_pieces = end_point.split(":")
             self.end_point_ip, self.end_point_port = (
                 "".join(end_point_pieces[:-1]),
@@ -180,6 +181,12 @@ def __init__(
             from ..fm.openai import OpenAIModel
 
             self.model = OpenAIModel(self.model, **kwargs)
+        elif self.model_type == "openllm":
+            from ..fm.openllm import OpenLLMModel
+
+            # pop (not get) so base_url is not passed twice via **kwargs
+            base_url = kwargs.pop("base_url", end_point)
+            self.model = OpenLLMModel(self.model, base_url=base_url, **kwargs)
         elif self.model_type == "cohere":
             from ..fm.cohere import CohereModel
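The client-side wiring above is all that is needed to route queries to an OpenLLM server. A minimal usage sketch, assuming a server with an OpenAI-compatible API is already running; the endpoint URL and model name are placeholders, and `generate()` is assumed to accept a plain string query:

```python
# Hypothetical usage of the new "openllm" model type (placeholder URL/model).
from alfred.client import Client

client = Client(
    model_type="openllm",
    model="llama-3-8b-instruct",           # placeholder: whatever the server serves
    end_point="http://localhost:3000/v1",  # forwarded to OpenLLMModel as base_url
)

# Completion-style generation through the OpenAI-compatible endpoint
print(client.generate("Translate 'bonjour' to English:"))
```

Note that `end_point` is deliberately excluded from the gRPC setup for `openllm` (second hunk), so the URL passes through to the OpenAI client untouched.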
diff --git a/alfred/fm/openllm.py b/alfred/fm/openllm.py
new file mode 100644
index 0000000..6c83f6e
--- /dev/null
+++ b/alfred/fm/openllm.py
@@ -0,0 +1,154 @@
+import logging
+from typing import Optional, List, Any, Union, Tuple
+
+import openai
+from openai import (
+    AuthenticationError,
+    APIError,
+    APITimeoutError,
+    RateLimitError,
+    BadRequestError,
+    APIConnectionError,
+    APIStatusError,
+)
+
+from .model import APIAccessFoundationModel
+from .response import CompletionResponse, RankedResponse
+from .utils import retry
+
+logger = logging.getLogger(__name__)
+
+
+class OpenLLMModel(APIAccessFoundationModel):
+    """
+    A wrapper for OpenLLM-served models, accessed through OpenAI's Python package
+    """
+
+    @retry(
+        num_retries=3,
+        wait_time=0.1,
+        exceptions=(
+            AuthenticationError,
+            APIConnectionError,
+            APITimeoutError,
+            RateLimitError,
+            APIError,
+            BadRequestError,
+            APIStatusError,
+        ),
+    )
+    def _api_query(
+        self,
+        query: Union[str, List, Tuple],
+        temperature: float = 0.0,
+        max_tokens: int = 64,
+        **kwargs: Any,
+    ) -> str:
+        """
+        Run a single query through the foundation model using OpenAI's Python package
+
+        :param query: The prompt to be used for the query
+        :type query: Union[str, List, Tuple]
+        :param temperature: The temperature of the model
+        :type temperature: float
+        :param max_tokens: The maximum number of tokens to be returned
+        :type max_tokens: int
+        :param kwargs: Additional keyword arguments (e.g. ``chat`` to use the chat endpoint)
+        :type kwargs: Any
+        :return: The generated completion
+        :rtype: str
+        """
+        chat = kwargs.get("chat", False)
+
+        if chat:
+            messages = query if isinstance(query, list) else [{"role": "user", "content": query}]
+            response = self.openai_client.chat.completions.create(
+                model=self.model_string,
+                messages=messages,
+                max_tokens=max_tokens,
+                temperature=temperature,
+            )
+            return response.choices[0].message.content
+        else:
+            prompt = query[0]["content"] if isinstance(query, list) else query
+            response = self.openai_client.completions.create(
+                model=self.model_string,
+                prompt=prompt,
+                max_tokens=max_tokens,
+                temperature=temperature,
+            )
+            return response.choices[0].text
+
+    def __init__(
+        self, model_string: str = "", api_key: Optional[str] = None, **kwargs: Any
+    ):
+        """
+        Initialize the OpenLLM API wrapper.
+
+        :param model_string: The model to be used for generating completions.
+        :type model_string: str
+        :param api_key: The API key to be used for the OpenAI API.
+        :type api_key: Optional[str]
+        :param kwargs: Additional keyword arguments (e.g. ``base_url`` of the server).
+        :type kwargs: Any
+        """
+        self.model_string = model_string
+        base_url = kwargs.get("base_url", None)
+        # the OpenAI client requires a non-empty key even when the server ignores it
+        api_key = api_key or "na"
+        self.openai_client = openai.OpenAI(base_url=base_url, api_key=api_key)
+        super().__init__(model_string, {"api_key": api_key, "base_url": base_url})
+
+    def _generate_batch(
+        self,
+        batch_instance: Union[List[str], Tuple],
+        **kwargs,
+    ) -> List[CompletionResponse]:
+        """
+        Generate completions for a batch of prompts using the OpenAI API.
+
+        :param batch_instance: A list of prompts for which to generate completions.
+        :type batch_instance: List[str] or List[Tuple]
+        :param kwargs: Additional keyword arguments to pass to the API.
+        :type kwargs: Any
+        :return: A list of `CompletionResponse` objects containing the generated completions.
+        :rtype: List[CompletionResponse]
+        """
+        output = []
+        for query in batch_instance:
+            output.append(
+                CompletionResponse(prediction=self._api_query(query, **kwargs))
+            )
+        return output
+
+    def _score_batch(
+        self,
+        batch_instance: Union[List[Tuple[str, str]], List[str]],
+        scoring_instruction: str = "Instruction: Given the query, choose your answer from [[label_space]]:\nQuery:\n",
+        **kwargs,
+    ) -> List[RankedResponse]:
+        """
+        Score candidates using the OpenAI API.
+
+        :param batch_instance: A list of prompts for which to generate candidate preferences.
+        :type batch_instance: List[str] or List[Tuple]
+        :param scoring_instruction: The instruction prompt for scoring
+        :type scoring_instruction: str
+        :return: A list of `RankedResponse` objects containing the predicted labels.
+        :rtype: List[RankedResponse]
+        """
+        output = []
+        for query in batch_instance:
+            _scoring_prompt = (
+                scoring_instruction.replace(
+                    "[[label_space]]", ",".join(query.candidates)
+                )
+                + query.prompt
+            )
+            output.append(
+                RankedResponse(
+                    # scores left empty: the endpoint returns text, not logits
+                    prediction=self._api_query(_scoring_prompt, **kwargs), scores={}
+                )
+            )
+        return output
diff --git a/docs/README.md b/docs/README.md
index 73a0881..ebe555f 100644
--- a/docs/README.md
+++ b/docs/README.md
@@ -32,6 +32,7 @@ A full list of `Alfred` project modules.
   - [Model](alfred/fm/model.md#model)
   - [Onnx](alfred/fm/onnx.md#onnx)
   - [Openai](alfred/fm/openai.md#openai)
+  - [Openllm](alfred/fm/openllm.md#openllm)
   - [Query](alfred/fm/query/index.md#query)
     - [CompletionQuery](alfred/fm/query/completion_query.md#completionquery)
     - [Query](alfred/fm/query/query.md#query)
diff --git a/docs/alfred/client/client.md b/docs/alfred/client/client.md
index 59b57d1..6b24c54 100644
--- a/docs/alfred/client/client.md
+++ b/docs/alfred/client/client.md
@@ -44,7 +44,7 @@ class Client:
 
 ### Client().__call__
 
-[Show source in client.py:313](../../../alfred/client/client.py#L313)
+[Show source in client.py:319](../../../alfred/client/client.py#L319)
 
 __call__() function to run the model on the queries.
 Equivalent to run() function.
@@ -71,7 +71,7 @@ def __call__(
 
 ### Client().calibrate
 
-[Show source in client.py:329](../../../alfred/client/client.py#L329)
+[Show source in client.py:335](../../../alfred/client/client.py#L335)
 
 calibrate are used to calibrate foundation models contextually given the template.
 A voter class may be passed to calibrate the model with a specific voter.
@@ -115,7 +115,7 @@ def calibrate(
 
 ### Client().chat
 
-[Show source in client.py:427](../../../alfred/client/client.py#L427)
+[Show source in client.py:433](../../../alfred/client/client.py#L433)
 
 Chat with the model APIs.
 Currently, Alfred supports Chat APIs from Anthropic and OpenAI
@@ -133,7 +133,7 @@ def chat(self, log_save_path: Optional[str] = None, **kwargs: Any): ...
 
 ### Client().encode
 
-[Show source in client.py:401](../../../alfred/client/client.py#L401)
+[Show source in client.py:407](../../../alfred/client/client.py#L407)
 
 embed() function to embed the queries.
 
@@ -155,7 +155,7 @@ def encode(
 
 ### Client().generate
 
-[Show source in client.py:272](../../../alfred/client/client.py#L272)
+[Show source in client.py:278](../../../alfred/client/client.py#L278)
 
 Wrapper function to generate the response(s) from the model.
 (For completion)
@@ -183,7 +183,7 @@ def generate(
 
 ### Client().remote_run
 
-[Show source in client.py:246](../../../alfred/client/client.py#L246)
+[Show source in client.py:252](../../../alfred/client/client.py#L252)
 
 Wrapper function for running the model on the queries thru a gRPC Server.
 
@@ -209,7 +209,7 @@ def remote_run(
 
 ### Client().run
 
-[Show source in client.py:226](../../../alfred/client/client.py#L226)
+[Show source in client.py:232](../../../alfred/client/client.py#L232)
 
 Run the model on the queries.
 
@@ -235,7 +235,7 @@ def run(
 
 ### Client().score
 
-[Show source in client.py:289](../../../alfred/client/client.py#L289)
+[Show source in client.py:295](../../../alfred/client/client.py#L295)
 
 Wrapper function to score the response(s) from the model.
 (For ranking)
diff --git a/docs/alfred/fm/index.md b/docs/alfred/fm/index.md
index 16305a8..a40970b 100644
--- a/docs/alfred/fm/index.md
+++ b/docs/alfred/fm/index.md
@@ -22,6 +22,7 @@
 - [Model](./model.md)
 - [Onnx](./onnx.md)
 - [Openai](./openai.md)
+- [Openllm](./openllm.md)
 - [Query](query/index.md)
 - [Remote](remote/index.md)
 - [Response](response/index.md)
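For reference, the new wrapper reduces to pointing OpenAI's v1 client at a custom `base_url`. A sketch of what the `chat=True` branch of `_api_query` sends over the wire, with placeholder endpoint and model name:

```python
# Equivalent raw call for the chat branch of OpenLLMModel._api_query.
# The URL and model name are placeholders; "na" mirrors the wrapper's
# dummy key for servers that do not check authentication.
import openai

client = openai.OpenAI(base_url="http://localhost:3000/v1", api_key="na")

response = client.chat.completions.create(
    model="llama-3-8b-instruct",
    messages=[{"role": "user", "content": "Hello!"}],
    max_tokens=64,
    temperature=0.0,
)
print(response.choices[0].message.content)
```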