From 3e6a8b8305f13f9475ca9ceaf39e0178101704c6 Mon Sep 17 00:00:00 2001
From: Aakriti Kinra <52823721+AakritiKinra@users.noreply.github.com>
Date: Sun, 15 Sep 2024 16:03:06 -0400
Subject: [PATCH 1/4] added base_url

---
 llments/lm/base/api.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/llments/lm/base/api.py b/llments/lm/base/api.py
index 8cb7f57..b9631fd 100644
--- a/llments/lm/base/api.py
+++ b/llments/lm/base/api.py
@@ -47,6 +47,7 @@ def __init__(self, model_name: str) -> None:
     def generate(
         self,
         message: str,
+        base_url: str | None,
         condition: str | None,
         do_sample: bool = False,
         max_length: int | None = None,
@@ -86,6 +87,7 @@ def generate(
             temperature = temperature,
             max_tokens = max_new_tokens,
             n = num_return_sequences,
+            base_url = base_url,
             messages=[{"content": message, "role": "user"}]
         )
         for choice in response['choices']:
@@ -96,6 +98,7 @@ def generate(
     def chat_generate(
         self,
         messages: list[str],
+        base_url: str | None,
         condition: str | None,
         do_sample: bool = False,
         max_length: int | None = None,
@@ -134,6 +137,7 @@ def chat_generate(
             temperature = temperature,
             max_tokens = max_new_tokens,
             n = num_return_sequences,
+            base_url = base_url,
             messages=[[{"content": content, "role": "user"}] for content in messages]
         )
         return [response['choices'][0]['message']['content'] for response in responses]

From b1d4d287230f29f118de7618fee5470b206d9d1c Mon Sep 17 00:00:00 2001
From: Aakriti Kinra <52823721+AakritiKinra@users.noreply.github.com>
Date: Sun, 15 Sep 2024 16:12:35 -0400
Subject: [PATCH 2/4] Updated function descriptions

---
 llments/lm/base/api.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/llments/lm/base/api.py b/llments/lm/base/api.py
index b9631fd..5eceaf2 100644
--- a/llments/lm/base/api.py
+++ b/llments/lm/base/api.py
@@ -47,7 +47,7 @@ def __init__(self, model_name: str) -> None:
     def generate(
         self,
         message: str,
-        base_url: str | None,
+        api_base: str | None,
         condition: str | None,
         do_sample: bool = False,
         max_length: int | None = None,
@@ -62,6 +62,7 @@ def generate(
 
         Args:
             message (str): The prompt for generating a response.
+            api_base (str): The API endpoint to call the model.
            condition (str): The conditioning sequence for the output.
                 If None, the output is not conditioned.
             do_sample (bool): Whether to use sampling or greedy decoding.
@@ -87,7 +88,7 @@ def generate(
             temperature = temperature,
             max_tokens = max_new_tokens,
             n = num_return_sequences,
-            base_url = base_url,
+            api_base = api_base,
             messages=[{"content": message, "role": "user"}]
         )
         for choice in response['choices']:
@@ -98,7 +99,7 @@ def chat_generate(
     def chat_generate(
         self,
         messages: list[str],
-        base_url: str | None,
+        api_base: str | None,
         condition: str | None,
         do_sample: bool = False,
         max_length: int | None = None,
@@ -113,6 +114,7 @@ def chat_generate(
 
         Args:
             messages (list): The list of prompts for generating responses.
+            api_base (str): The API endpoint to call the model.
             condition (str): The conditioning sequence for the output.
                 If None, the output is not conditioned.
             do_sample (bool): Whether to use sampling or greedy decoding.
@@ -137,7 +139,7 @@ def chat_generate(
             temperature = temperature,
             max_tokens = max_new_tokens,
             n = num_return_sequences,
-            base_url = base_url,
+            api_base = api_base,
             messages=[[{"content": content, "role": "user"}] for content in messages]
         )
         return [response['choices'][0]['message']['content'] for response in responses]

From 94e7955309f12fd0c010f442fc32315ffdd1a268 Mon Sep 17 00:00:00 2001
From: Aakriti Kinra <52823721+AakritiKinra@users.noreply.github.com>
Date: Mon, 16 Sep 2024 09:54:29 -0400
Subject: [PATCH 3/4] Added api_base to the constructor

---
 llments/lm/base/api.py | 14 ++++++--------
 1 file changed, 6 insertions(+), 8 deletions(-)

diff --git a/llments/lm/base/api.py b/llments/lm/base/api.py
index 5eceaf2..e9424d0 100644
--- a/llments/lm/base/api.py
+++ b/llments/lm/base/api.py
@@ -6,7 +6,7 @@
 from llments.lm.lm import LanguageModel
 from litellm import completion, batch_completion, ModelResponse
 
-class APIBasedLM():
+class APIBasedLM(LanguageModel):
     """Base class for API-Based Language Models.
 
     Represents a language model that interacts with an API
@@ -35,19 +35,20 @@ def calculate_probability(self, condition: str | None, output: str) -> float:
         raise NotImplementedError
 
     @abc.abstractmethod
-    def __init__(self, model_name: str) -> None:
+    def __init__(self, model_name: str, api_base: str) -> None:
         """Initialize the APIBasedLM instance.
 
         Args:
             model_name (str): The name of the language model.
+            api_base (str): The API endpoint to call the model.
         """
         self.model_name = model_name
+        self.api_base = api_base
 
     @abc.abstractmethod
     def generate(
         self,
         message: str,
-        api_base: str | None,
         condition: str | None,
         do_sample: bool = False,
         max_length: int | None = None,
@@ -62,7 +63,6 @@ def generate(
 
         Args:
             message (str): The prompt for generating a response.
-            api_base (str): The API endpoint to call the model.
             condition (str): The conditioning sequence for the output.
                 If None, the output is not conditioned.
             do_sample (bool): Whether to use sampling or greedy decoding.
@@ -88,7 +88,7 @@ def generate(
             temperature = temperature,
             max_tokens = max_new_tokens,
             n = num_return_sequences,
-            api_base = api_base,
+            api_base = self.api_base,
             messages=[{"content": message, "role": "user"}]
         )
         for choice in response['choices']:
@@ -99,7 +99,6 @@ def chat_generate(
     def chat_generate(
         self,
         messages: list[str],
-        api_base: str | None,
         condition: str | None,
         do_sample: bool = False,
         max_length: int | None = None,
@@ -114,7 +113,6 @@ def chat_generate(
 
         Args:
             messages (list): The list of prompts for generating responses.
-            api_base (str): The API endpoint to call the model.
             condition (str): The conditioning sequence for the output.
                 If None, the output is not conditioned.
             do_sample (bool): Whether to use sampling or greedy decoding.
@@ -139,7 +137,7 @@ def chat_generate(
             temperature = temperature,
             max_tokens = max_new_tokens,
             n = num_return_sequences,
-            api_base = api_base,
+            api_base = self.api_base,
             messages=[[{"content": content, "role": "user"}] for content in messages]
         )
         return [response['choices'][0]['message']['content'] for response in responses]

From 4fd0ec50795c866463d887c15f11a6d11d2fab7a Mon Sep 17 00:00:00 2001
From: Aakriti Kinra <52823721+AakritiKinra@users.noreply.github.com>
Date: Thu, 19 Sep 2024 21:07:43 -0400
Subject: [PATCH 4/4] matched structure with lm class

---
 llments/lm/base/api.py | 38 ++++++++++++++++++++++----------------
 1 file changed, 22 insertions(+), 16 deletions(-)

diff --git a/llments/lm/base/api.py b/llments/lm/base/api.py
index e9424d0..3d54190 100644
--- a/llments/lm/base/api.py
+++ b/llments/lm/base/api.py
@@ -12,12 +12,13 @@ class APIBasedLM(LanguageModel):
     Represents a language model that interacts with an API
     for generating responses.
     Usage:
-    - Instantiate this class with the model name.
+    - Instantiate this class with the model name and the API endpoint.
     - Set the API key of the language model as an
       environment variable for secure access.
 
     Attributes:
         model_name (str): The name of the language model.
+        api_base (str): The API endpoint to call the model.
     """
 
     @abc.abstractmethod
@@ -48,7 +49,6 @@ def __init__(self, model_name: str, api_base: str) -> None:
     @abc.abstractmethod
     def generate(
         self,
-        message: str,
         condition: str | None,
         do_sample: bool = False,
         max_length: int | None = None,
@@ -62,7 +62,6 @@ def generate(
         the generated response.
 
         Args:
-            message (str): The prompt for generating a response.
             condition (str): The conditioning sequence for the output.
                 If None, the output is not conditioned.
             do_sample (bool): Whether to use sampling or greedy decoding.
@@ -73,7 +72,7 @@ def generate(
             num_return_sequences (int): The number of chat completion choices
                 to generate for each input message.
         Returns:
-            ModelResponse: The generated response object from the language model.
+            list[str]: Sampled output sequences from the language model.
         """
         if condition is not None:
             warnings.warn("A non-default value for 'condition' was provided.", UserWarning)
@@ -89,7 +88,7 @@ def generate(
             max_tokens = max_new_tokens,
             n = num_return_sequences,
             api_base = self.api_base,
-            messages=[{"content": message, "role": "user"}]
+            messages=[{"content": condition, "role": "user"}]
         )
         for choice in response['choices']:
             responses.append(choice['message']['content'])
@@ -98,23 +97,32 @@ def generate(
     @abc.abstractmethod
     def chat_generate(
         self,
-        messages: list[str],
-        condition: str | None,
+        messages: list[dict[str, str]],
         do_sample: bool = False,
         max_length: int | None = None,
         max_new_tokens: int | None = None,
         temperature: float = 1.0,
         num_return_sequences: int = 1
-    ) -> list[str]:
+    ) -> list[list[dict[str, str]]]:
         """Generate responses to multiple prompts using the batch_completion function.
 
         This method sends multiple prompts to the language
         model API and retrieves the generated response
         for each of the prompts.
 
         Args:
-            messages (list): The list of prompts for generating responses.
-            condition (str): The conditioning sequence for the output.
-                If None, the output is not conditioned.
+            messages: A list of dictionaries, each representing a message in the chat context. Each dictionary should contain the following keys:
+                - "role": The role of the entity sending the message. This can be "system", "user", etc.
+                - "content": The actual content of the message.
+                Example:
+                [
+                    {
+                        "role": "system",
+                        "content": "You are a friendly chatbot",
+                    },
+                    {
+                        "role": "user",
+                        "content": "How many helicopters can a human eat in one sitting?"
+                    },
+                ]
             do_sample (bool): Whether to use sampling or greedy decoding.
             max_length (int): The maximum length of the output sequence,
                 (defaults to model max).
             max_new_tokens (float): The maximum numbers of tokens to generate,
                 ignoring the number of tokens in the prompt.
             num_return_sequences (int): The number of chat completion choices
                 to generate for each input message.
         Returns:
-            list: List of responses generated by the language model for all the prompts.
+            list[list[dict[str, str]]]: List of chat contexts with the generated responses.
         """
-        if condition is not None:
-            warnings.warn("A non-default value for 'condition' was provided.", UserWarning)
         if do_sample:
             warnings.warn("A non-default value for 'do_sample' was provided.", UserWarning)
         if max_length is not None:
@@ -138,9 +144,9 @@ def chat_generate(
             max_tokens = max_new_tokens,
             n = num_return_sequences,
             api_base = self.api_base,
-            messages=[[{"content": content, "role": "user"}] for content in messages]
+            messages=[messages]
         )
-        return [response['choices'][0]['message']['content'] for response in responses]
+        return [messages + [{"role": "assistant", "content": choice['message']['content']}] for choice in responses[0]['choices']]
 
     @abc.abstractmethod
     def set_seed(self, seed: int) -> None:
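
For reviewers, here is a minimal sketch of how a concrete subclass might be defined and used once PATCH 4/4 is applied. Everything below is illustrative rather than part of the patches: DemoAPILM, the model string, the endpoint URL, and the OPENAI_API_KEY variable are placeholder assumptions, and the chat_generate override shown here issues a single litellm completion call (one conversation in, one extended conversation per sampled choice) rather than going through batch_completion.

import os

from litellm import completion

from llments.lm.base.api import APIBasedLM


class DemoAPILM(APIBasedLM):  # hypothetical subclass, for illustration only
    """Concrete API-based LM; every abstract method must be overridden."""

    def __init__(self, model_name: str, api_base: str) -> None:
        super().__init__(model_name, api_base)

    def generate(self, condition, do_sample=False, max_length=None,
                 max_new_tokens=None, temperature=1.0,
                 num_return_sequences=1):
        # Defer to the base implementation, where `condition` now doubles
        # as the prompt (PATCH 4/4).
        return super().generate(condition, do_sample, max_length,
                                max_new_tokens, temperature,
                                num_return_sequences)

    def chat_generate(self, messages, do_sample=False, max_length=None,
                      max_new_tokens=None, temperature=1.0,
                      num_return_sequences=1):
        # One conversation in, one extended conversation per sampled choice.
        response = completion(
            model=self.model_name,
            api_base=self.api_base,
            temperature=temperature,
            max_tokens=max_new_tokens,
            n=num_return_sequences,
            messages=messages,
        )
        return [
            messages + [{"role": "assistant",
                         "content": choice['message']['content']}]
            for choice in response['choices']
        ]

    def calculate_probability(self, condition, output):
        raise NotImplementedError

    def set_seed(self, seed: int) -> None:
        raise NotImplementedError


# The API key is read from an environment variable, as the class docstring
# requires (OPENAI_API_KEY is an assumption for this particular model).
os.environ["OPENAI_API_KEY"] = "sk-..."  # placeholder

lm = DemoAPILM(model_name="gpt-3.5-turbo",
               api_base="https://api.openai.com/v1")

# generate() still warns when `condition` is not None, since the base
# implementation kept the old warning after condition became the prompt.
samples = lm.generate(condition="Write a haiku about APIs.",
                      temperature=0.7, num_return_sequences=2)

# chat_generate() consumes a chat context shaped like the docstring Example
# and returns that context extended with each generated assistant reply.
chat = [
    {"role": "system", "content": "You are a friendly chatbot"},
    {"role": "user",
     "content": "How many helicopters can a human eat in one sitting?"},
]
extended_contexts = lm.chat_generate(chat, num_return_sequences=1)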