
Commit 7223bf8

neubig and saum7800 authored
Make API-based model configurable (#344)
* Make models configurable
* Additional description on colab
* Revert unrelated change
* Remove unused argument
* Remove output from notebook
* Remove noqa
* Fix test
* Update prompt2model/utils/api_tools.py
  Co-authored-by: Saumya Gandhi <gandhisaumya8@gmail.com>
* Update prompt2model/utils/api_tools.py
  Co-authored-by: Saumya Gandhi <gandhisaumya8@gmail.com>
* Change test comments
* Default to maximum tokens for model

---------

Co-authored-by: Saumya Gandhi <gandhisaumya8@gmail.com>
1 parent 9714923 commit 7223bf8

File tree

10 files changed: +171 −14 lines changed

prompt2model/dataset_generator/prompt_based.py
prompt2model/model_retriever/generate_hypothetical_document.py
prompt2model/prompt_parser/instr_parser.py
prompt2model/utils/api_tools.py
prompt2model_demo.ipynb
test_helpers/mock_api.py
test_helpers/test_utils.py
tests/dataset_generator_test.py
tests/model_retriever_test.py
tests/prompt_parser_test.py
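Taken together, these changes replace hard-coded `APIAgent(...)` constructions with a module-level `api_tools.default_api_agent` that callers can overwrite before running any pipeline component. A minimal sketch of the new usage, mirroring the notebook cell added below (the model name here is only illustrative):

from prompt2model.utils import api_tools

# Overwrite the global default agent; every component that previously
# constructed its own APIAgent now reads this module-level attribute.
api_tools.default_api_agent = api_tools.APIAgent(model_name="gpt-4")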

prompt2model/dataset_generator/prompt_based.py

Lines changed: 2 additions & 1 deletion

@@ -20,6 +20,7 @@
 from prompt2model.utils import (
     API_ERRORS,
     APIAgent,
+    api_tools,
     count_tokens_from_string,
     get_formatted_logger,
     handle_api_error,
@@ -415,7 +416,7 @@ def generate_dataset_split(
         generated_examples: list[Example] = []

         pbar = tqdm(total=num_examples, desc="Generating examples")
-        chat_api = APIAgent()
+        chat_api = api_tools.default_api_agent

         while len(generated_examples) < num_examples:
             if self.max_api_calls and self.api_call_counter >= self.max_api_calls:

prompt2model/model_retriever/generate_hypothetical_document.py

Lines changed: 2 additions & 2 deletions

@@ -5,7 +5,7 @@
 import logging

 from prompt2model.prompt_parser import PromptSpec
-from prompt2model.utils import API_ERRORS, APIAgent, handle_api_error
+from prompt2model.utils import API_ERRORS, api_tools, handle_api_error

 PROMPT_PREFIX = '''HuggingFace contains models, which are each given a user-generated description. The first section of the description, delimited with two "---" lines, consists of a YAML description of the model. This may contain fields like "language" (supported by model), "datasets" (used to train the model), "tags" (e.g. tasks relevant to the model), and "metrics" (used to evaluate the model). Create a hypothetical HuggingFace model description that would satisfy a given user instruction. Here are some examples:

@@ -427,7 +427,7 @@ def generate_hypothetical_model_description(
     api_call_counter = 0

     instruction = prompt.instruction
-    api_agent = APIAgent("gpt-3.5-turbo-16k")
+    api_agent = api_tools.default_api_agent
     chatgpt_prompt = (
         PROMPT_PREFIX
         + "\n"

prompt2model/prompt_parser/instr_parser.py

Lines changed: 3 additions & 3 deletions

@@ -13,7 +13,7 @@
     construct_prompt_for_instruction_parsing,
 )

-from prompt2model.utils import APIAgent, get_formatted_logger
+from prompt2model.utils import api_tools, get_formatted_logger
 from prompt2model.utils.api_tools import API_ERRORS, handle_api_error

 logger = get_formatted_logger("PromptParser")
@@ -61,7 +61,7 @@ def extract_response(self, response: openai.Completion) -> tuple[str, str] | None:
         try:
             response_json = json.loads(response_text, strict=False)
         except json.decoder.JSONDecodeError:
-            logger.warning("API response was not a valid JSON")
+            logger.warning(f"API response was not a valid JSON: {response_text}")
             return None

         required_keys = ["Instruction", "Demonstrations"]
@@ -85,7 +85,7 @@ def parse_from_prompt(self, prompt: str) -> None:
         """
         parsing_prompt_for_chatgpt = construct_prompt_for_instruction_parsing(prompt)

-        chat_api = APIAgent()
+        chat_api = api_tools.default_api_agent
         last_error = None
         while True:
             self.api_call_counter += 1

prompt2model/utils/api_tools.py

Lines changed: 27 additions & 4 deletions

@@ -8,6 +8,7 @@
 import time

 import aiolimiter
+import litellm.utils
 import openai
 import openai.error
 import tiktoken
@@ -40,13 +41,25 @@
 class APIAgent:
     """A class for accessing API-based models."""

-    def __init__(self, model_name: str = "gpt-3.5-turbo"):
-        """Initialize APIAgent with an API key.
+    def __init__(
+        self,
+        model_name: str = "gpt-3.5-turbo",
+        max_tokens: int | None = None,
+    ):
+        """Initialize APIAgent with model_name and max_tokens.

         Args:
             model_name: Name fo the model to use (by default, gpt-3.5-turbo).
+            max_tokens: The maximum number of tokens to generate. Defaults to the max
+                value for the model if available through litellm.
         """
         self.model_name = model_name
+        self.max_tokens = max_tokens
+        if max_tokens is None:
+            try:
+                self.max_tokens = litellm.utils.get_max_tokens(model_name)
+            except Exception:
+                pass

     def generate_one_completion(
         self,
@@ -73,6 +86,7 @@ def generate_one_completion(
             An OpenAI-like response object if there were no errors in generation.
             In case of API-specific error, Exception object is captured and returned.
         """
+        max_tokens = self.max_tokens or 4 * count_tokens_from_string(prompt)
         response = completion(  # completion gets the key from os.getenv
             model=self.model_name,
             messages=[
@@ -81,6 +95,7 @@ def generate_one_completion(
             temperature=temperature,
             presence_penalty=presence_penalty,
             frequency_penalty=frequency_penalty,
+            max_tokens=max_tokens,
         )
         return response

@@ -154,14 +169,17 @@ async def _throttled_completion_acreate(
                 await asyncio.sleep(10)
                 return {"choices": [{"message": {"content": ""}}]}

+        max_tokens = self.max_tokens or 4 * max(
+            count_tokens_from_string(prompt) for prompt in prompts
+        )
         async_responses = [
             _throttled_completion_acreate(
-                model="gpt-3.5-turbo",
+                model=self.model_name,
                 messages=[
                     {"role": "user", "content": f"{prompt}"},
                 ],
                 temperature=temperature,
-                max_tokens=500,
+                max_tokens=max_tokens,
                 n=responses_per_request,
                 top_p=1,
                 limiter=limiter,
@@ -205,3 +223,8 @@ def count_tokens_from_string(string: str, encoding_name: str = "cl100k_base") ->
     encoding = tiktoken.get_encoding(encoding_name)
     num_tokens = len(encoding.encode(string))
     return num_tokens
+
+
+# This is the default API agent that is used everywhere if a different agent is not
+# specified
+default_api_agent = APIAgent()
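To illustrate the new max_tokens handling, here is a small sketch of the two constructor paths (the 512 cap is an arbitrary example value):

from prompt2model.utils.api_tools import APIAgent

# Explicit cap: every completion is limited to 512 generated tokens.
capped_agent = APIAgent(model_name="gpt-3.5-turbo", max_tokens=512)

# No cap given: __init__ asks litellm for the model's maximum. If litellm
# does not recognize the model, max_tokens stays None and each request
# falls back to 4x the prompt's token count at call time.
auto_agent = APIAgent(model_name="gpt-3.5-turbo")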

prompt2model_demo.ipynb

Lines changed: 19 additions & 1 deletion

@@ -42,7 +42,7 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "Set your OpenAI API key as an environment variable. A good way to do this is to create a `.env` file with a single line.\n",
+    "prompt2model requires that you use a base LLM to help out with various parts of the training process. The default is OpenAI's `gpt-3.5-turbo`, but you can use any model supported by [litellm](https://github.com/BerriAI/litellm). Set the appropriate API key as an environment variable. A good way to do this is to create a `.env` file with a single line, like below if you're using OpenAI.\n",
     "\n",
     "```text\n",
     "OPENAI_API_KEY=<your key here>\n",
@@ -81,6 +81,24 @@
     "os.environ['OPENAI_API_KEY'][:3]"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Finally, we specify the base model that we want to use here."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from prompt2model.utils import api_tools\n",
+    "# CHANGE THIS if you want to try a different model\n",
+    "api_tools.default_api_agent = api_tools.APIAgent(model_name=\"gpt-3.5-turbo\")"
+   ]
+  },
  {
   "cell_type": "markdown",
   "metadata": {},

test_helpers/mock_api.py

Lines changed: 37 additions & 1 deletion

@@ -1,6 +1,10 @@
 """Tools for mocking API responses (for testing purposes)."""

-from __future__ import annotations  # noqa FI58
+from __future__ import annotations
+
+import openai
+
+from prompt2model.utils.api_tools import APIAgent


 class MockCompletion:
@@ -177,6 +181,38 @@ def mock_batch_api_response_identical_completions(
     return mock_completions


+class MockAPIAgent(APIAgent):
+    """A mock API agent that always returns the same content."""
+
+    def __init__(self, default_content):
+        """Initialize the API agent."""
+        self.generate_one_call_counter = 0
+        self.generate_batch_call_counter = 0
+        self.default_content = default_content
+
+    def generate_one_completion(
+        self,
+        prompt: str,
+        temperature: float = 0,
+        presence_penalty: float = 0,
+        frequency_penalty: float = 0,
+    ) -> openai.Completion:
+        """Return a mocked object and increment the counter."""
+        self.generate_one_call_counter += 1
+        return MockCompletion(content=self.default_content)
+
+    async def generate_batch_completion(
+        self,
+        prompts: list[str],
+        temperature: float = 1,
+        responses_per_request: int = 5,
+        requests_per_minute: int = 80,
+    ) -> list[openai.Completion]:
+        """Return a mocked object and increment the counter."""
+        self.generate_batch_call_counter += 1
+        return [MockCompletion(content=self.default_content) for _ in prompts]
+
+
 class UnknownGpt3Exception(Exception):
     """This is a newly-defined exception for testing purposes."""

test_helpers/test_utils.py

Lines changed: 16 additions & 0 deletions

@@ -0,0 +1,16 @@
+"""Utility functions for testing."""
+from contextlib import contextmanager
+
+
+@contextmanager
+def temp_setattr(obj, attr, value):
+    """Temporarily set an attribute on an object."""
+    original = getattr(obj, attr, None)
+    setattr(obj, attr, value)
+    try:
+        yield
+    finally:
+        if original is not None:
+            setattr(obj, attr, original)
+        else:
+            delattr(obj, attr)
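The tests below lean on this helper to install a mock agent for the duration of a block and restore the real one afterwards. A minimal, self-contained sketch of the pattern:

from prompt2model.utils import api_tools
from test_helpers.mock_api import MockAPIAgent
from test_helpers.test_utils import temp_setattr

mock_agent = MockAPIAgent(default_content="canned response")
with temp_setattr(api_tools, "default_api_agent", mock_agent):
    # Inside the block, every component that reads
    # api_tools.default_api_agent sees the mock.
    assert api_tools.default_api_agent is mock_agent
# On exit, the original agent is restored (or the attribute is deleted
# if it did not exist beforehand).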

tests/dataset_generator_test.py

Lines changed: 25 additions & 1 deletion

@@ -16,12 +16,14 @@
     PromptBasedDatasetGenerator,
 )
 from prompt2model.prompt_parser import MockPromptSpec, TaskType
+from prompt2model.utils import api_tools
 from test_helpers import (
     MockCompletion,
     UnknownGpt3Exception,
     mock_batch_api_response_identical_completions,
 )
-from test_helpers.mock_api import MockBatchDifferentCompletions
+from test_helpers.mock_api import MockAPIAgent, MockBatchDifferentCompletions
+from test_helpers.test_utils import temp_setattr

 logger = logging.getLogger("DatasetGenerator")
@@ -946,3 +948,25 @@ def test_dataset_generator_terminates(mocked_generate_example):
     generated_df = generated_dataset.to_pandas()
     assert len(generated_dataset) == 100
     assert list(generated_df.columns) == ["input_col", "output_col"]
+
+
+def test_generate_dataset_agent_switch():
+    """Test if dataset generation can use a user-set API agent."""
+    my_agent = MockAPIAgent(
+        default_content='{"input": "This is input.", "output": "This is an output."}'
+    )
+    with temp_setattr(api_tools, "default_api_agent", my_agent):
+        prompt_spec = MockPromptSpec(TaskType.CLASSIFICATION)
+        dataset_generator = PromptBasedDatasetGenerator(
+            initial_temperature=0.3,
+            max_temperature=1.4,
+            responses_per_request=1,
+            max_api_calls=100,
+            requests_per_minute=80,
+            filter_duplicated_examples=False,
+        )
+        dataset_generator.generate_dataset_split(
+            prompt_spec, 100, split=DatasetSplit.TRAIN
+        )
+    # 100 outputs, and each batch has 5 outputs so 20 api calls
+    assert my_agent.generate_batch_call_counter == 20

tests/model_retriever_test.py

Lines changed: 15 additions & 0 deletions

@@ -11,8 +11,14 @@
 import torch

 from prompt2model.model_retriever import DescriptionModelRetriever
+from prompt2model.model_retriever.generate_hypothetical_document import (
+    generate_hypothetical_model_description,
+)
 from prompt2model.prompt_parser import MockPromptSpec, TaskType
+from prompt2model.utils import api_tools
 from test_helpers import create_test_search_index
+from test_helpers.mock_api import MockAPIAgent
+from test_helpers.test_utils import temp_setattr

 TINY_MODEL_NAME = "google/bert_uncased_L-2_H-128_A-2"

@@ -238,3 +244,12 @@ def test_retrieve_bm25_when_no_index_exists():
     assert top_model_names[0] == "t5-base"
     # Clear search index from disk.
     shutil.rmtree(retriever.search_index_path)
+
+
+def test_generate_hypothetical_document_agent_switch():
+    """Test if generate_hypothetical_document can use a user-set API agent."""
+    my_agent = MockAPIAgent(default_content="test response")
+    with temp_setattr(api_tools, "default_api_agent", my_agent):
+        prompt_spec = MockPromptSpec(TaskType.CLASSIFICATION)
+        generate_hypothetical_model_description(prompt_spec, max_api_calls=3)
+    assert my_agent.generate_one_call_counter == 1

tests/prompt_parser_test.py

Lines changed: 25 additions & 1 deletion

@@ -8,7 +8,11 @@
 import pytest

 from prompt2model.prompt_parser import PromptBasedInstructionParser, TaskType
+from prompt2model.prompt_parser.mock import MockPromptSpec
+from prompt2model.utils import api_tools
 from test_helpers import MockCompletion, UnknownGpt3Exception
+from test_helpers.mock_api import MockAPIAgent
+from test_helpers.test_utils import temp_setattr

 logger = logging.getLogger("PromptParser")
 GPT3_RESPONSE_WITH_DEMONSTRATIONS = MockCompletion(
@@ -116,7 +120,13 @@ def test_instruction_parser_with_invalid_json(mocked_parsing_method):
         prompt_spec.parse_from_prompt(prompt)
     mock_info.assert_not_called()
     warning_list = [each.args[0] for each in mock_warning.call_args_list]
-    assert warning_list == ["API response was not a valid JSON"] * 3
+    assert (
+        warning_list
+        == [
+            'API response was not a valid JSON: {"Instruction": "A", "Demonstrations": "B}'  # noqa: E501
+        ]
+        * 3
+    )  # noqa: E501
     assert mocked_parsing_method.call_count == 3
     assert prompt_spec._instruction is None
     assert prompt_spec._examples is None
@@ -179,3 +189,17 @@ def test_instruction_parser_with_unexpected_error(mocked_parsing_method):
     # Check that we only tried calling the API once.
     assert mocked_parsing_method.call_count == 1
     gc.collect()
+
+
+def test_prompt_parser_agent_switch():
+    """Test if prompt parser can use a user-set API agent."""
+    my_agent = MockAPIAgent(
+        default_content='{"Instruction": "test response", "Demonstrations": "test response"}'  # noqa: E501
+    )
+    with temp_setattr(api_tools, "default_api_agent", my_agent):
+        prompt_parser = PromptBasedInstructionParser(
+            TaskType.CLASSIFICATION, max_api_calls=3
+        )
+        prompt_spec = MockPromptSpec(TaskType.CLASSIFICATION)
+        prompt_parser.parse_from_prompt(prompt_spec)
+    assert my_agent.generate_one_call_counter == 1
