diff --git a/llments/lm/base/api.py b/llments/lm/base/api.py
index 4872fbb..eb2be5f 100644
--- a/llments/lm/base/api.py
+++ b/llments/lm/base/api.py
@@ -6,19 +6,19 @@ from llments.lm.lm import LanguageModel
 from litellm import completion, batch_completion, ModelResponse
 
-
-class APIBasedLM:
+class APIBasedLM(LanguageModel):
     """Base class for API-Based Language Models.
 
     Represents a language model that interacts with an API
     for generating responses.
 
     Usage:
-    - Instantiate this class with the model name.
+    - Instantiate this class with the model name and the API endpoint.
     - Set the API key of the language model as an environment
       variable for secure access.
 
     Attributes:
         model_name (str): The name of the language model.
+        api_base (str): The API endpoint to call the model.
     """
 
     @abc.abstractmethod
@@ -36,126 +36,118 @@ def calculate_probability(self, condition: str | None, output: str) -> float:
         raise NotImplementedError
 
     @abc.abstractmethod
-    def __init__(self, model_name: str) -> None:
+    def __init__(self, model_name: str, api_base: str) -> None:
         """Initialize the APIBasedLM instance.
 
         Args:
             model_name (str): The name of the language model.
+            api_base (str): The API endpoint to call the model.
         """
         self.model_name = model_name
+        self.api_base = api_base
 
     @abc.abstractmethod
     def generate(
         self,
-        message: str,
         condition: str | None,
         do_sample: bool = False,
         max_length: int | None = None,
         max_new_tokens: int | None = None,
         temperature: float = 1.0,
-        num_return_sequences: int = 1,
-    ) -> list[str]:
+        num_return_sequences: int = 1
+    ) -> list[str]:
         """Generate a response based on the given prompt.
 
         This method sends a prompt to the language model API
         and retrieves the generated response.
 
         Args:
-            message (str): The prompt for generating a response.
             condition (str): The conditioning sequence for the output.
                 If None, the output is not conditioned.
            do_sample (bool): Whether to use sampling or greedy decoding.
            max_length (int): The maximum length of the output sequence,
                (defaults to model max).
            max_new_tokens (float): The maximum number of tokens to generate
                in the chat completion.
-            temperature (float): The sampling temperature to be used, between 0 and 2.
+            temperature (float): The sampling temperature to be used, between 0 and 2.
            num_return_sequences (int): The number of chat completion choices
                to generate for each input message.
 
        Returns:
-            ModelResponse: The generated response object from the language model.
+            list[str]: Sampled output sequences from the language model.
""" if condition is not None: - warnings.warn( - "A non-default value for 'condition' was provided.", UserWarning - ) + warnings.warn("A non-default value for 'condition' was provided.", UserWarning) if do_sample: - warnings.warn( - "A non-default value for 'do_sample' was provided.", UserWarning - ) + warnings.warn("A non-default value for 'do_sample' was provided.", UserWarning) if max_length is not None: - warnings.warn( - "A non-default value for 'max_length' was provided.", UserWarning - ) - + warnings.warn("A non-default value for 'max_length' was provided.", UserWarning) + responses = [] response = completion( - model=self.model_name, - temperature=temperature, - max_tokens=max_new_tokens, - n=num_return_sequences, - messages=[{"content": message, "role": "user"}], + model = self.model_name, + temperature = temperature, + max_tokens = max_new_tokens, + n = num_return_sequences, + api_base = self.api_base, + messages=[{"content": condition, "role": "user"}] ) - for choice in response["choices"]: - responses.append(choice["message"]["content"]) + for choice in response['choices']: + responses.append(choice['message']['content']) return responses @abc.abstractmethod def chat_generate( self, - messages: list[list[dict[str, str]]], - condition: str | None = None, + messages: list[dict[str, str]], do_sample: bool = False, max_length: int | None = None, max_new_tokens: int | None = None, temperature: float = 1.0, - num_return_sequences: int = 1, - ) -> list[list[str]]: + num_return_sequences: int = 1 + ) -> list[list[dict[str, str]]]: """Generate responses to multiple prompts using the batch_completion function. This method sends multiple prompts to the language model API and retrieves the generated response for each of the prompts. Args: - messages (list of list): The list of prompts, where each prompt contains a sequence of messages - with roles (either 'system' or 'user') and their corresponding content. - condition (str): The conditioning sequence for the output. - If None, the output is not conditioned. + messages: A list of dictionaries, each representing a message in the chat context. Each dictionary should contain the following keys: + - "role": The role of the entity sending the message. This can be "system", "user", etc. + - "content": The actual content of the message. Example: + [ + { + "role": "system", + "content": "You are a friendly chatbot", + }, + { + "role": "user", + "content": "How many helicopters can a human eat in one sitting?" + }, + ] do_sample (bool): Whether to use sampling or greedy decoding. max_length (int): The maximum length of the output sequence, (defaults to model max). max_new_tokens (float): The maximum number of tokens to generate in the chat completion. - temperature (float): The sampling temperature to be used, between 0 and 2. + temperature (float): The sampling temperature to be used, between 0 and 2. num_return_sequences (int): The number of chat completion choices to generate for each input message. Returns: - list: List of responses generated by the language model for all the prompts. + list[list[dict[str, str]]]: list of chat contexts with the generated responses. 
""" - if condition is not None: - warnings.warn( - "A non-default value for 'condition' was provided.", UserWarning - ) if do_sample: - warnings.warn( - "A non-default value for 'do_sample' was provided.", UserWarning - ) + warnings.warn("A non-default value for 'do_sample' was provided.", UserWarning) if max_length is not None: - warnings.warn( - "A non-default value for 'max_length' was provided.", UserWarning - ) - + warnings.warn("A non-default value for 'max_length' was provided.", UserWarning) + responses = batch_completion( - model=self.model_name, - temperature=temperature, - max_tokens=max_new_tokens, - n=num_return_sequences, - messages=messages, + model = self.model_name, + temperature = temperature, + max_tokens = max_new_tokens, + n = num_return_sequences, + api_base = self.api_base, + messages=messages ) - - return [ - [choice["message"]["content"] for choice in response["choices"]] - for response in responses - ] - + return [messages + [{"role": "assistant", "content": r}] for r in responses] + @abc.abstractmethod def set_seed(self, seed: int) -> None: """Set the seed for the language model. @@ -163,4 +155,4 @@ def set_seed(self, seed: int) -> None: Args: seed: The seed to set for the language model. """ - raise NotImplementedError + raise NotImplementedError \ No newline at end of file