Asynchronous implementation of API LM (#47)
* Asynchronous implementation

* Inherit LanguageModel class

* LanguageModel inheritance 2

* Ruff changes

* Added generate_batch function and new input parameters

* mypy changes

* Updated input parameters

* matched the LanguageModel interface

* ruff and mypy changes

* mypy changes input arguments

* Update api.py

mypy interface changes

* Update api.py

ruff changes

* Warnings for non-default values

* ruff changes
AakritiKinra authored Jun 6, 2024
1 parent b0fcabd commit 24adfee
Showing 1 changed file with 118 additions and 12 deletions.
llments/lm/base/api.py: 130 changes (118 additions, 12 deletions)
@@ -1,21 +1,40 @@
 """Base class for API-Based Language Models."""
 
 import os
-from litellm import completion
+import abc
+import warnings
+from llments.lm.lm import LanguageModel
+from litellm import completion, batch_completion, ModelResponse
 
 
-class APIBasedLM:
+class APIBasedLM(LanguageModel):
     """Base class for API-Based Language Models.
 
-    This class represents an API-based language model that generates responses
-    using the API key of the model and the model name. The user sets the API Key
-    as an environment variable, and the model name is passed while creating
-    an instance of the class.
+    Represents a language model that interacts with an API for generating responses.
+
+    Usage:
+    - Instantiate this class with the model name.
+    - Set the API key of the language model as an environment
+      variable for secure access.
 
     Attributes:
         model_name (str): The name of the language model.
     """
 
+    @abc.abstractmethod
+    def calculate_probability(self, condition: str | None, output: str) -> float:
+        """Calculate the probability of an output given the language model.
+
+        Args:
+            condition: The conditioning sequence for the output.
+                If None, the output is not conditioned.
+            output: The output sequence for which the probability is calculated.
+
+        Returns:
+            float: The probability of the output given the language model.
+        """
+        raise NotImplementedError
+
+    @abc.abstractmethod
     def __init__(self, model_name: str) -> None:
         """Initialize the APIBasedLM instance.
@@ -24,19 +43,106 @@ def __init__(self, model_name: str) -> None:
         """
         self.model_name = model_name
 
-    def generate_response(self, prompt: str) -> str:
+    @abc.abstractmethod
+    def generate(
+        self,
+        message: str,
+        condition: str | None,
+        do_sample: bool = False,
+        max_length: int | None = None,
+        max_new_tokens: int | None = None,
+        temperature: float = 1.0,
+        num_return_sequences: int = 1,
+    ) -> list[str]:
         """Generate a response based on the given prompt.
 
         This method sends a prompt to the language model API and retrieves
         the generated response.
 
         Args:
-            prompt (str): The prompt for generating a response.
+            message (str): The prompt for generating a response.
+            condition (str): The conditioning sequence for the output.
+                If None, the output is not conditioned.
+            do_sample (bool): Whether to use sampling or greedy decoding.
+            max_length (int): The maximum length of the output sequence
+                (defaults to the model maximum).
+            max_new_tokens (int): The maximum number of tokens to generate
+                in the chat completion.
+            temperature (float): The sampling temperature to be used,
+                between 0 and 2.
+            num_return_sequences (int): The number of chat completion choices
+                to generate for each input message.
 
         Returns:
-            str: The generated response from the language model.
+            list[str]: The generated responses from the language model.
         """
+        if condition is not None:
+            warnings.warn("A non-default value for 'condition' was provided.", UserWarning)
+        if do_sample:
+            warnings.warn("A non-default value for 'do_sample' was provided.", UserWarning)
+        if max_length is not None:
+            warnings.warn("A non-default value for 'max_length' was provided.", UserWarning)
+
+        responses = []
         response = completion(
-            model=self.model_name, messages=[{"content": prompt, "role": "user"}]
+            model=self.model_name,
+            temperature=temperature,
+            max_tokens=max_new_tokens,
+            n=num_return_sequences,
+            messages=[{"content": message, "role": "user"}],
         )
-        return str(response["choices"][0]["message"]["content"])
+        for choice in response["choices"]:
+            responses.append(choice["message"]["content"])
+        return responses
+
+    @abc.abstractmethod
+    def chat_generate(
+        self,
+        messages: list[str],
+        condition: str | None,
+        do_sample: bool = False,
+        max_length: int | None = None,
+        max_new_tokens: int | None = None,
+        temperature: float = 1.0,
+        num_return_sequences: int = 1,
+    ) -> list[str]:
+        """Generate responses to multiple prompts using the batch_completion function.
+
+        This method sends multiple prompts to the language model API and retrieves
+        a generated response for each of the prompts.
+
+        Args:
+            messages (list[str]): The list of prompts for generating responses.
+            condition (str): The conditioning sequence for the output.
+                If None, the output is not conditioned.
+            do_sample (bool): Whether to use sampling or greedy decoding.
+            max_length (int): The maximum length of the output sequence
+                (defaults to the model maximum).
+            max_new_tokens (int): The maximum number of tokens to generate
+                in the chat completion.
+            temperature (float): The sampling temperature to be used,
+                between 0 and 2.
+            num_return_sequences (int): The number of chat completion choices
+                to generate for each input message.
+
+        Returns:
+            list[str]: List of responses generated by the language model,
+                one per prompt.
+        """
+        if condition is not None:
+            warnings.warn("A non-default value for 'condition' was provided.", UserWarning)
+        if do_sample:
+            warnings.warn("A non-default value for 'do_sample' was provided.", UserWarning)
+        if max_length is not None:
+            warnings.warn("A non-default value for 'max_length' was provided.", UserWarning)
+
+        responses = batch_completion(
+            model=self.model_name,
+            temperature=temperature,
+            max_tokens=max_new_tokens,
+            n=num_return_sequences,
+            messages=[[{"content": content, "role": "user"}] for content in messages],
+        )
+        return [response["choices"][0]["message"]["content"] for response in responses]
+
+    @abc.abstractmethod
+    def set_seed(self, seed: int) -> None:
+        """Set the seed for the language model.
+
+        Args:
+            seed: The seed to set for the language model.
+        """
+        raise NotImplementedError
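
For context, a minimal usage sketch of the new interface (not part of the commit). The model name and placeholder key are illustrative, and direct instantiation is an assumption: the methods carry @abc.abstractmethod, so if LanguageModel enforces ABC semantics, a concrete subclass overriding them would be required instead.

# Hypothetical usage sketch; model name and key are placeholders.
import os

from llments.lm.base.api import APIBasedLM

# litellm reads provider credentials from environment variables,
# which is why the class itself never handles the API key directly.
os.environ["OPENAI_API_KEY"] = "sk-..."  # placeholder, not a real key

lm = APIBasedLM("gpt-3.5-turbo")  # any litellm-supported model name

# Single prompt: returns num_return_sequences completion strings.
replies = lm.generate(
    message="Summarize litellm in one sentence.",
    condition=None,
    temperature=0.7,
    num_return_sequences=2,
)

# Several prompts in one call via batch_completion: one reply per prompt.
batch_replies = lm.chat_generate(
    messages=["Hello!", "Name three Python testing frameworks."],
    condition=None,
    max_new_tokens=64,
)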
