diff --git a/llments/lm/base/api.py b/llments/lm/base/api.py
index 8cb7f57..3d54190 100644
--- a/llments/lm/base/api.py
+++ b/llments/lm/base/api.py
@@ -6,18 +6,19 @@
 from llments.lm.lm import LanguageModel
 from litellm import completion, batch_completion, ModelResponse
 
-class APIBasedLM():
+class APIBasedLM(LanguageModel):
     """Base class for API-Based Language Models.
 
     Represents a language model that interacts with an API
     for generating responses.
 
     Usage:
-    - Instantiate this class with the model name.
+    - Instantiate this class with the model name and the API endpoint.
     - Set the API key of the language model as an
       environment variable for secure access.
 
     Attributes:
         model_name (str): The name of the language model.
+        api_base (str): The API endpoint to call the model.
     """
     @abc.abstractmethod
@@ -35,18 +36,19 @@ def calculate_probability(self, condition: str | None, output: str) -> float:
         raise NotImplementedError
 
     @abc.abstractmethod
-    def __init__(self, model_name: str) -> None:
+    def __init__(self, model_name: str, api_base: str) -> None:
         """Initialize the APIBasedLM instance.
 
         Args:
             model_name (str): The name of the language model.
+            api_base (str): The API endpoint to call the model.
         """
         self.model_name = model_name
+        self.api_base = api_base
 
     @abc.abstractmethod
     def generate(
         self,
-        message: str,
         condition: str | None,
         do_sample: bool = False,
         max_length: int | None = None,
@@ -60,7 +62,6 @@
         the generated response.
 
         Args:
-            message (str): The prompt for generating a response.
             condition (str): The conditioning sequence for the output.
                 If None, the output is not conditioned.
             do_sample (bool): Whether to use sampling or greedy decoding.
@@ -71,7 +72,7 @@
             num_return_sequences (int): The number of chat completion
                 choices to generate for each input message.
         Returns:
-            ModelResponse: The generated response object from the language model.
+            list[str]: Sampled output sequences from the language model.
         """
         if condition is not None:
             warnings.warn("A non-default value for 'condition' was provided.", UserWarning)
@@ -86,7 +87,8 @@
             temperature = temperature,
             max_tokens = max_new_tokens,
             n = num_return_sequences,
-            messages=[{"content": message, "role": "user"}]
+            api_base = self.api_base,
+            messages=[{"content": condition, "role": "user"}]
         )
         for choice in response['choices']:
             responses.append(choice['message']['content'])
@@ -95,23 +97,32 @@
     @abc.abstractmethod
     def chat_generate(
         self,
-        messages: list[str],
-        condition: str | None,
+        messages: list[dict[str, str]],
         do_sample: bool = False,
         max_length: int | None = None,
         max_new_tokens: int | None = None,
         temperature: float = 1.0,
         num_return_sequences: int = 1
-    ) -> list[str]:
+    ) -> list[list[dict[str, str]]]:
         """Generate responses to multiple prompts using the batch_completion function.
 
         This method sends multiple prompts to the language model API
         and retrieves the generated response for each of the prompts.
 
         Args:
-            messages (list): The list of prompts for generating responses.
-            condition (str): The conditioning sequence for the output.
-                If None, the output is not conditioned.
+            messages (list[dict[str, str]]): A list of dictionaries, each representing a message in the chat context. Each dictionary should contain the following keys:
+                - "role": The role of the entity sending the message. This can be "system", "user", etc.
+                - "content": The actual content of the message.
+                Example:
+                [
+                    {
+                        "role": "system",
+                        "content": "You are a friendly chatbot",
+                    },
+                    {
+                        "role": "user",
+                        "content": "How many helicopters can a human eat in one sitting?"
+                    },
+                ]
             do_sample (bool): Whether to use sampling or greedy decoding.
             max_length (int): The maximum length of the output sequence,
                 (defaults to model max).
@@ -120,10 +131,8 @@
             num_return_sequences (int): The number of chat completion
                 choices to generate for each input message.
         Returns:
-            list: List of responses generated by the language model for all the prompts.
+            list[list[dict[str, str]]]: List of chat contexts, each extended with a generated assistant response.
         """
-        if condition is not None:
-            warnings.warn("A non-default value for 'condition' was provided.", UserWarning)
         if do_sample:
             warnings.warn("A non-default value for 'do_sample' was provided.", UserWarning)
         if max_length is not None:
@@ -134,9 +143,10 @@
             temperature = temperature,
             max_tokens = max_new_tokens,
             n = num_return_sequences,
-            messages=[[{"content": content, "role": "user"}] for content in messages]
+            api_base = self.api_base,
+            messages=messages
         )
-        return [response['choices'][0]['message']['content'] for response in responses]
+        return [messages + [{"role": "assistant", "content": r['choices'][0]['message']['content']}] for r in responses]
 
     @abc.abstractmethod
     def set_seed(self, seed: int) -> None:
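
Reviewer note: below is a minimal usage sketch of the interface after this patch. It is illustrative only, not part of the diff: HostedLM is a hypothetical concrete subclass (the base methods carry @abc.abstractmethod, so they presumably need overriding before instantiation, and LanguageModel may declare further abstract members), and the model name, endpoint, and environment variable are placeholders assuming an OpenAI-compatible backend.

import os

from llments.lm.base.api import APIBasedLM


class HostedLM(APIBasedLM):
    # Hypothetical concrete subclass: each override simply reuses the
    # inherited implementation, which is enough to make the class
    # instantiable if APIBasedLM's abstract machinery is active.
    def __init__(self, model_name: str, api_base: str) -> None:
        super().__init__(model_name, api_base)

    def generate(self, *args, **kwargs):
        return super().generate(*args, **kwargs)

    def chat_generate(self, *args, **kwargs):
        return super().chat_generate(*args, **kwargs)

    def calculate_probability(self, condition, output):
        return super().calculate_probability(condition, output)

    def set_seed(self, seed: int) -> None:
        super().set_seed(seed)


# The API key is read from an environment variable, per the class docstring;
# the variable name depends on the provider behind api_base.
os.environ.setdefault("OPENAI_API_KEY", "sk-placeholder")

lm = HostedLM(model_name="gpt-3.5-turbo", api_base="https://api.example.com/v1")

# generate(): after this patch, `condition` carries the prompt itself.
outputs = lm.generate(condition="Hello!", temperature=0.7, num_return_sequences=2)
print(outputs)

# chat_generate(): pass one chat context; each returned item is that context
# extended with one generated assistant message.
chats = lm.chat_generate(
    [
        {"role": "system", "content": "You are a friendly chatbot"},
        {"role": "user", "content": "How many helicopters can a human eat in one sitting?"},
    ]
)
print(chats[0][-1]["content"])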