diff --git a/llments/lm/base/api.py b/llments/lm/base/api.py
index ee4373a..8cb7f57 100644
--- a/llments/lm/base/api.py
+++ b/llments/lm/base/api.py
@@ -1,21 +1,40 @@
 """Base class for API-Based Language Models."""
 import os
-from litellm import completion
-
-class APIBasedLM:
+import warnings
+
+from litellm import batch_completion, completion
+
+from llments.lm.lm import LanguageModel
+
+
+class APIBasedLM(LanguageModel):
     """Base class for API-Based Language Models.
 
-    This class represents an API-based language model that generates responses
-    using the API key of the model and the model name. The user sets the API Key
-    as an environment variable, and the model name is passed while creating
-    an instance of the class.
+    Represents a language model that interacts with an API for generating responses.
+
+    Usage:
+    - Instantiate this class with the model name.
+    - Set the API key of the language model as an environment
+      variable for secure access.
 
     Attributes:
         model_name (str): The name of the language model.
     """
 
+    def calculate_probability(self, condition: str | None, output: str) -> float:
+        """Calculate the probability of an output given the language model.
+
+        Args:
+            condition: The conditioning sequence for the output.
+                If None, the output is not conditioned.
+            output: The output sequence for which the probability is calculated.
+
+        Returns:
+            float: The probability of the output given the language model.
+        """
+        raise NotImplementedError
+
     def __init__(self, model_name: str) -> None:
         """Initialize the APIBasedLM instance.
@@ -24,19 +43,106 @@ def __init__(self, model_name: str) -> None:
         """
         self.model_name = model_name
 
-    def generate_response(self, prompt: str) -> str:
+    def generate(
+        self,
+        message: str,
+        condition: str | None,
+        do_sample: bool = False,
+        max_length: int | None = None,
+        max_new_tokens: int | None = None,
+        temperature: float = 1.0,
+        num_return_sequences: int = 1,
+    ) -> list[str]:
         """Generate a response based on the given prompt.
 
         This method sends a prompt to the language model API and retrieves
         the generated response.
 
         Args:
-            prompt (str): The prompt for generating a response.
+            message (str): The prompt for generating a response.
+            condition (str | None): The conditioning sequence for the output.
+                If None, the output is not conditioned.
+            do_sample (bool): Whether to use sampling or greedy decoding.
+            max_length (int | None): The maximum length of the output sequence
+                (defaults to the model maximum).
+            max_new_tokens (int | None): The maximum number of tokens to generate
+                in the chat completion.
+            temperature (float): The sampling temperature to be used, between 0 and 2.
+            num_return_sequences (int): The number of chat completion choices to
+                generate for each input message.
 
         Returns:
-            str: The generated response from the language model.
+            list[str]: The generated responses from the language model.
""" + if condition is not None: + warnings.warn("A non-default value for 'condition' was provided.", UserWarning) + if do_sample: + warnings.warn("A non-default value for 'do_sample' was provided.", UserWarning) + if max_length is not None: + warnings.warn("A non-default value for 'max_length' was provided.", UserWarning) + + responses = [] response = completion( - model=self.model_name, messages=[{"content": prompt, "role": "user"}] + model = self.model_name, + temperature = temperature, + max_tokens = max_new_tokens, + n = num_return_sequences, + messages=[{"content": message, "role": "user"}] ) - return str(response["choices"][0]["message"]["content"]) + for choice in response['choices']: + responses.append(choice['message']['content']) + return responses + + @abc.abstractmethod + def chat_generate( + self, + messages: list[str], + condition: str | None, + do_sample: bool = False, + max_length: int | None = None, + max_new_tokens: int | None = None, + temperature: float = 1.0, + num_return_sequences: int = 1 + ) -> list[str]: + """Generate responses to multiple prompts using the batch_completion function. + + This method sends multiple prompts to the language model API and retrieves + the generated response for each of the prompts. + + Args: + messages (list): The list of prompts for generating responses. + condition (str): The conditioning sequence for the output. + If None, the output is not conditioned. + do_sample (bool): Whether to use sampling or greedy decoding. + max_length (int): The maximum length of the output sequence, + (defaults to model max). + max_new_tokens (float): The maximum number of tokens to generate in the chat completion. + temperature (float): The sampling temperature to be used, between 0 and 2. + num_return_sequences (int): The number of chat completion choices to generate for each input message. + + Returns: + list: List of responses generated by the language model for all the prompts. + """ + if condition is not None: + warnings.warn("A non-default value for 'condition' was provided.", UserWarning) + if do_sample: + warnings.warn("A non-default value for 'do_sample' was provided.", UserWarning) + if max_length is not None: + warnings.warn("A non-default value for 'max_length' was provided.", UserWarning) + + responses = batch_completion( + model = self.model_name, + temperature = temperature, + max_tokens = max_new_tokens, + n = num_return_sequences, + messages=[[{"content": content, "role": "user"}] for content in messages] + ) + return [response['choices'][0]['message']['content'] for response in responses] + + @abc.abstractmethod + def set_seed(self, seed: int) -> None: + """Set the seed for the language model. + + Args: + seed: The seed to set for the language model. + """ + raise NotImplementedError