
Commit

Added API Proxy functionality to API LM (#68)
* added base_url

* Updated function descriptions

* Added api_base to the constructor

* matched structure with lm class

* Pull latest changes
AakritiKinra authored and rohanmodi2810 committed Sep 26, 2024
1 parent 2e4782b commit bac169f
Showing 1 changed file with 52 additions and 60 deletions.
112 changes: 52 additions & 60 deletions llments/lm/base/api.py
@@ -6,19 +6,19 @@
 from llments.lm.lm import LanguageModel
 from litellm import completion, batch_completion, ModelResponse


-class APIBasedLM:
+class APIBasedLM(LanguageModel):
     """Base class for API-Based Language Models.

     Represents a language model that interacts with an API for generating responses.

     Usage:
-    - Instantiate this class with the model name.
+    - Instantiate this class with the model name and the api endpoint
     - Set the API key of the language model as an environment
       variable for secure access.

     Attributes:
         model_name (str): The name of the language model.
+        api_base (str): The API endpoint to call the model.
     """

     @abc.abstractmethod
@@ -36,131 +36,123 @@ def calculate_probability(self, condition: str | None, output: str) -> float:
         raise NotImplementedError

     @abc.abstractmethod
-    def __init__(self, model_name: str) -> None:
+    def __init__(self, model_name: str, api_base: str) -> None:
         """Initialize the APIBasedLM instance.

         Args:
             model_name (str): The name of the language model.
+            api_base (str): The API endpoint to call the model.
         """
         self.model_name = model_name
+        self.api_base = api_base

     @abc.abstractmethod
     def generate(
         self,
         message: str,
         condition: str | None,
         do_sample: bool = False,
         max_length: int | None = None,
         max_new_tokens: int | None = None,
         temperature: float = 1.0,
-        num_return_sequences: int = 1,
-    ) -> list[str]:
+        num_return_sequences: int = 1
+    ) -> list[str]:
         """Generate a response based on the given prompt.

         This method sends a prompt to the language model API and retrieves
         the generated response.

         Args:
             message (str): The prompt for generating a response.
             condition (str): The conditioning sequence for the output.
                 If None, the output is not conditioned.
             do_sample (bool): Whether to use sampling or greedy decoding.
             max_length (int): The maximum length of the output sequence,
                 (defaults to model max).
             max_new_tokens (float): The maximum number of tokens to generate in the chat completion.
             temperature (float): The sampling temperature to be used, between 0 and 2.
             num_return_sequences (int): The number of chat completion choices to generate for each input message.

         Returns:
-            ModelResponse: The generated response object from the language model.
+            str: Sampled output sequences from the language model.
         """
         if condition is not None:
-            warnings.warn(
-                "A non-default value for 'condition' was provided.", UserWarning
-            )
+            warnings.warn("A non-default value for 'condition' was provided.", UserWarning)
         if do_sample:
-            warnings.warn(
-                "A non-default value for 'do_sample' was provided.", UserWarning
-            )
+            warnings.warn("A non-default value for 'do_sample' was provided.", UserWarning)
         if max_length is not None:
-            warnings.warn(
-                "A non-default value for 'max_length' was provided.", UserWarning
-            )
+            warnings.warn("A non-default value for 'max_length' was provided.", UserWarning)

         responses = []
         response = completion(
-            model=self.model_name,
-            temperature=temperature,
-            max_tokens=max_new_tokens,
-            n=num_return_sequences,
-            messages=[{"content": message, "role": "user"}],
+            model = self.model_name,
+            temperature = temperature,
+            max_tokens = max_new_tokens,
+            n = num_return_sequences,
+            api_base = self.api_base,
+            messages=[{"content": condition, "role": "user"}]
         )
-        for choice in response["choices"]:
-            responses.append(choice["message"]["content"])
+        for choice in response['choices']:
+            responses.append(choice['message']['content'])
         return responses

     @abc.abstractmethod
     def chat_generate(
         self,
-        messages: list[list[dict[str, str]]],
-        condition: str | None = None,
+        messages: list[dict[str, str]],
         do_sample: bool = False,
         max_length: int | None = None,
         max_new_tokens: int | None = None,
         temperature: float = 1.0,
-        num_return_sequences: int = 1,
-    ) -> list[list[str]]:
+        num_return_sequences: int = 1
+    ) -> list[list[dict[str, str]]]:
         """Generate responses to multiple prompts using the batch_completion function.

         This method sends multiple prompts to the language model API and retrieves
         the generated response for each of the prompts.

         Args:
-            messages (list of list): The list of prompts, where each prompt contains a sequence of messages
-                with roles (either 'system' or 'user') and their corresponding content.
-            condition (str): The conditioning sequence for the output.
-                If None, the output is not conditioned.
+            messages: A list of dictionaries, each representing a message in the chat context. Each dictionary should contain the following keys:
+                - "role": The role of the entity sending the message. This can be "system", "user", etc.
+                - "content": The actual content of the message. Example:
+                [
+                    {
+                        "role": "system",
+                        "content": "You are a friendly chatbot",
+                    },
+                    {
+                        "role": "user",
+                        "content": "How many helicopters can a human eat in one sitting?"
+                    },
+                ]
             do_sample (bool): Whether to use sampling or greedy decoding.
             max_length (int): The maximum length of the output sequence,
                 (defaults to model max).
             max_new_tokens (float): The maximum number of tokens to generate in the chat completion.
             temperature (float): The sampling temperature to be used, between 0 and 2.
             num_return_sequences (int): The number of chat completion choices to generate for each input message.

         Returns:
-            list: List of responses generated by the language model for all the prompts.
+            list[list[dict[str, str]]]: list of chat contexts with the generated responses.
         """
-        if condition is not None:
-            warnings.warn(
-                "A non-default value for 'condition' was provided.", UserWarning
-            )
         if do_sample:
-            warnings.warn(
-                "A non-default value for 'do_sample' was provided.", UserWarning
-            )
+            warnings.warn("A non-default value for 'do_sample' was provided.", UserWarning)
         if max_length is not None:
-            warnings.warn(
-                "A non-default value for 'max_length' was provided.", UserWarning
-            )
+            warnings.warn("A non-default value for 'max_length' was provided.", UserWarning)

         responses = batch_completion(
-            model=self.model_name,
-            temperature=temperature,
-            max_tokens=max_new_tokens,
-            n=num_return_sequences,
-            messages=messages,
+            model = self.model_name,
+            temperature = temperature,
+            max_tokens = max_new_tokens,
+            n = num_return_sequences,
+            api_base = self.api_base,
+            messages=messages
         )

-        return [
-            [choice["message"]["content"] for choice in response["choices"]]
-            for response in responses
-        ]
+        return [messages + [{"role": "assistant", "content": r}] for r in responses]

     @abc.abstractmethod
     def set_seed(self, seed: int) -> None:
         """Set the seed for the language model.

         Args:
             seed: The seed to set for the language model.
         """
         raise NotImplementedError

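For readers unfamiliar with LiteLLM proxying, the snippet below is a minimal sketch of the single-prompt call that the updated generate method issues once api_base is threaded through. The proxy URL, model name, and environment variable are placeholder assumptions, not values taken from this repository.

    import os

    from litellm import completion

    # Hypothetical setup: LiteLLM reads provider API keys from the environment,
    # so no key is passed in code.
    os.environ["OPENAI_API_KEY"] = "sk-..."  # placeholder key
    proxy_url = "http://localhost:4000"      # hypothetical proxy endpoint

    # Single-prompt completion routed through the proxy, mirroring generate().
    response = completion(
        model="gpt-3.5-turbo",               # assumed model name
        temperature=1.0,
        max_tokens=64,
        n=2,
        api_base=proxy_url,                  # the argument this commit adds to the call
        messages=[{"content": "Write a haiku about proxies.", "role": "user"}],
    )

    # One string per returned choice, as generate() collects them.
    outputs = [choice["message"]["content"] for choice in response["choices"]]
    print(outputs)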

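chat_generate wraps litellm's batch_completion, which takes one message list per conversation and returns one response per conversation. The sketch below assumes the same placeholder proxy endpoint, assumes api_base is forwarded to each underlying completion call, and shows one way of appending each reply back onto its chat context.

    from litellm import batch_completion

    proxy_url = "http://localhost:4000"  # hypothetical proxy endpoint (as above)

    # Two independent chat contexts; each inner list is one conversation.
    conversations = [
        [
            {"role": "system", "content": "You are a friendly chatbot"},
            {"role": "user", "content": "How many helicopters can a human eat in one sitting?"},
        ],
        [
            {"role": "user", "content": "Name three uses of an API proxy."},
        ],
    ]

    # batch_completion sends the conversations concurrently; api_base is assumed
    # to be passed through to each underlying completion call.
    responses = batch_completion(
        model="gpt-3.5-turbo",   # assumed model name
        temperature=1.0,
        max_tokens=64,
        api_base=proxy_url,
        messages=conversations,
    )

    # Append each first reply to its conversation, yielding extended chat contexts.
    extended = [
        conv + [{"role": "assistant", "content": resp["choices"][0]["message"]["content"]}]
        for conv, resp in zip(conversations, responses)
    ]
    print(extended)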