
Commit

fixing pytests
ritugala committed Apr 2, 2024
1 parent 60051a7 commit e0c0d0f
Showing 16 changed files with 348 additions and 125,408 deletions.
4 changes: 3 additions & 1 deletion .gitignore
@@ -22,8 +22,10 @@ huggingface_data/huggingface_datasets/huggingface_datasets_datafinder_index
huggingface_data/huggingface_datasets/reranking_dataset_index.json
huggingface_data/huggingface_models/
retrieved_dataset_dict/
result/
checkpoint/
status.yaml

dump.txt
# Outputs generated by the colab demo
trained_model/
trained_tokenizer/

This file was deleted.

This file was deleted.

124,960 changes: 0 additions & 124,960 deletions examples/huggingface_data/huggingface_datasets/reranking_dataset_index_v2.json

This file was deleted.

20 changes: 10 additions & 10 deletions prompt2model/dataset_retriever/description_dataset_retriever.py
@@ -1,8 +1,7 @@
"""An dual-encoder dataset retriever using HuggingFace dataset descriptions."""

from __future__ import annotations
from __future__ import annotations # noqa FI58

import asyncio # noqa FI58
import json
import os
import random
@@ -264,7 +263,9 @@ def automatic_column_selection(
input_columns = response["input"]
output_column = response["output"]
if len(input_columns) < 1 or len(output_column) != 1:
raise RuntimeError(f"Incorrect number of cols: {input_columns}, {output_column} ") # noqa: E501
raise RuntimeError(
f"Incorrect number of cols: {input_columns}, {output_column} "
) # noqa: E501

dataset_columns = dataset_columns
incorrect_columns = [
@@ -459,8 +460,7 @@ def get_rerank_with_highest_votes(self, prompt, infos_dict):
logger.warning("LLM hallucinated dataset/config name: %s", curr_name)
voting.append(None)
continue
voting.append(curr_name)

voting.append(curr_name["name"])
if len(voting) == 0:
logger.warning("Voting resulted in no dataset/config.")
return None
@@ -482,7 +482,7 @@ def rerank_datasets(
recommendation.
Args:
dataset_list: A list of dataset names to be reranked.
datasets_info_dict: The datasets to be considered
prompt_spec: An object containing the prompt specification,
ncluding instruction and examples, used for reranking datasets.
@@ -494,10 +494,10 @@
dataset_selection_prompt = construct_prompt_for_dataset_reranking(
prompt_spec.instruction, prompt_spec.examples, datasets_info_dict
)

dataset_name = self.get_rerank_with_highest_votes(
dataset_selection_prompt, datasets_info_dict
prompt=dataset_selection_prompt, infos_dict=datasets_info_dict
)

if dataset_name is None:
return None, None

@@ -521,7 +521,7 @@ config_name = self.get_rerank_with_highest_votes(
config_name = self.get_rerank_with_highest_votes(
config_selection_prompt, curr_dataset["configs"]
)

logger.info(f"Chosen dataset and config: {dataset_name=} {config_name=}")
return dataset_name, config_name

def canonicalize_dataset_automatically(
@@ -731,6 +731,6 @@ def retrieve_dataset_dict(
or None if there are no relevant datasets.
"""
sorted_list = self.retrieve_top_datasets(prompt_spec)
logger.info(f"Top datasets retrieved. Top datasets: {sorted_list}")
logger.info("Top datasets retrieved.")

return self.create_dataset(prompt_spec, sorted_list)
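
The reranking change above makes the voting list hold plain name strings (`curr_name["name"]`) rather than the parsed dicts, so the tally can settle on a single dataset by majority. A minimal sketch of that voting pattern, assuming hypothetical parsed responses of the form `{"name": ...}` in place of real LLM output:

```python
from collections import Counter

def rerank_with_highest_votes(parsed_responses, valid_names):
    """Pick the name that most parsed LLM responses agree on.

    parsed_responses: dicts like {"name": "squad"} (or None on a parse failure)
    valid_names: names we accept, guarding against hallucinated datasets
    """
    voting = []
    for parsed in parsed_responses:
        if parsed is None or parsed.get("name") not in valid_names:
            continue  # skip hallucinated or unparseable answers
        voting.append(parsed["name"])  # vote on the string, not the dict
    if not voting:
        return None
    return Counter(voting).most_common(1)[0][0]

# Hypothetical votes from three LLM calls plus one failed parse:
print(rerank_with_highest_votes(
    [{"name": "squad"}, {"name": "squad"}, {"name": "glue"}, None],
    valid_names={"squad", "glue"},
))  # -> squad
```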
12 changes: 6 additions & 6 deletions prompt2model/dataset_retriever/reranking_prompt.py
@@ -63,19 +63,19 @@ def build_configs_prompt(instruction: str, examples: str, dataset_info: dict):
str: The input prompt for dataset retrieval.
"""
configs_string = ""
for j, config in dataset_info["configs"].items():
for j, (config_name, config_info) in enumerate(dataset_info["configs"].items()):
configs_string += f"""{CONFIG_TEMPLATE.format(
counter = chr(ord('a')+j),
config_name = config["config_name"],
dataset_columns = config["columns"],
sample_row = config["sample_row"]
config_name = config_name,
dataset_columns = config_info["columns"],
sample_row = config_info["sample_row"]
)}\n""" # noqa: E501

input_prompt = INPUT_PROMPT_CONFIG_TEMPLATE.format(
instruction=instruction,
examples=examples,
dataset_name=config["dataset_name"],
dataset_description=config["dataset_description"],
dataset_name=dataset_info["dataset_name"],
dataset_description=dataset_info["description"],
configs=configs_string,
num=len(dataset_info["configs"]),
)
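
The `build_configs_prompt` fix matters because iterating `dataset_info["configs"].items()` directly binds the dict key (a string) to `j`, so `chr(ord('a') + j)` cannot label the options; `enumerate` restores an integer counter. A small sketch of the labelling pattern, with a made-up configs dict in place of the retriever's real structure:

```python
def label_config_options(configs: dict) -> str:
    """Render configs as a), b), c) ... using an integer counter from enumerate().

    Iterating configs.items() without enumerate() binds the dict key (a string)
    to the counter, so chr(ord('a') + j) would fail with a TypeError.
    """
    lines = []
    for j, (config_name, config_info) in enumerate(configs.items()):
        lines.append(
            f"{chr(ord('a') + j)}) {config_name}: columns={config_info['columns']}"
        )
    return "\n".join(lines)

# Made-up configs shaped roughly like dataset_info["configs"]:
print(label_config_options({
    "en": {"columns": ["text", "label"]},
    "de": {"columns": ["text", "label"]},
}))
```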
23 changes: 10 additions & 13 deletions prompt2model/dataset_transformer/prompt_based.py
@@ -34,8 +34,8 @@ class PromptBasedDatasetTransformer(DatasetTransformer):

def __init__(
self,
num_points_to_transform: int,
max_allowed_failed_transforms: int,
num_points_to_transform: int = 10,
max_allowed_failed_transforms: int = 3,
plan_prompt_fn: Callable[
[str, str, list[dict]], str
] = construct_prompt_for_plan,
@@ -122,7 +122,7 @@ def process_responses(
) -> tuple[list[str], list[str]]:
"""Process the responses received from the API.
Also write the current set of inputs and outputs to a file.
Also write the current set of inputs and outputs to a dump text just in case.
Args:
responses: A list of response strings from the API.
@@ -134,26 +134,23 @@
- outputs: A list of transformed output strings.
"""
inputs, outputs = [], []
counter = 0
show_sample_flag = True
for response in responses:
try:
extraction = find_and_parse_json(response, ["input", "output"], [])
if extraction is not None:
if extraction["input"] is None or extraction["output"] is None:
raise ValueError("Input or output is None")
input = str(extraction["input"]).strip()
output = str(extraction["output"]).strip()
if input in prompt_spec.examples:
raise ValueError("Repeated Task Examples from prompt")

str1 = str("Q: " + input + "\nA:")
str2 = str(extraction["output"]).strip()

inputs.append(str1)
outputs.append(str2)
if counter < 2:
logger.info(f"inputs\n{str1}\n\nouputs\n{str2}")
counter += 1
counter += 1
inputs.append(input)
outputs.append(output)
if show_sample_flag:
logger.info(f"inputs\n{input}\n\nouputs\n{output}")
show_sample_flag = False

except Exception as e:
logger.error(f"Error extracting from response: {e}")
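
In `process_responses`, the transformed inputs are now stored as-is rather than wrapped in a `Q: ... A:` template, and a boolean flag replaces the counter so only the first parsed pair is logged. A rough, self-contained sketch of that loop, using `json.loads` as a stand-in for the project's `find_and_parse_json` helper:

```python
import json
import logging

logger = logging.getLogger(__name__)

def collect_pairs(responses, prompt_examples=""):
    """Collect (input, output) pairs from JSON responses, logging one sample."""
    inputs, outputs = [], []
    show_sample_flag = True
    for response in responses:
        try:
            extraction = json.loads(response)  # stand-in for find_and_parse_json
            if extraction.get("input") is None or extraction.get("output") is None:
                raise ValueError("Input or output is None")
            input_text = str(extraction["input"]).strip()
            output_text = str(extraction["output"]).strip()
            if input_text in prompt_examples:
                raise ValueError("Repeated Task Examples from prompt")
            inputs.append(input_text)    # stored verbatim, no "Q: ... A:" wrapping
            outputs.append(output_text)
            if show_sample_flag:
                logger.info("sample input=%r output=%r", input_text, output_text)
                show_sample_flag = False  # log only the first successful pair
        except Exception as e:
            logger.error("Error extracting from response: %s", e)
    return inputs, outputs

print(collect_pairs(['{"input": "2 + 2 = ?", "output": "4"}', "not json"]))
```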
4 changes: 3 additions & 1 deletion prompt2model/dataset_transformer/prompt_template.py
@@ -399,7 +399,7 @@ def truncate_row(example_row: dict, max_length=200) -> str:


def construct_prompt_for_plan(
task_description: str, example: str, dataset: list[dict], num_rows: int = 5
task_description: str, example: str, dataset: list[dict], num_rows: int = None
) -> str:
"""Construct prompt for plan.
@@ -413,6 +413,8 @@
Returns:
str: Prompt for creating plan. Plan will be used for dataset transformation
"""
if not num_rows:
num_rows = min(len(dataset), 5)
incontext_tasks = [VITAMINC] # using one is enough for now
incontext_examples = []

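
`construct_prompt_for_plan` now defaults `num_rows` lazily, capping the number of in-context rows at five but never exceeding the dataset's length. The guard, isolated as a sketch (the helper name is illustrative):

```python
def resolve_num_rows(dataset: list, num_rows=None) -> int:
    """Use at most five sample rows by default, never more than the dataset holds."""
    if not num_rows:
        num_rows = min(len(dataset), 5)
    return num_rows

assert resolve_num_rows([{"text": "a"}] * 3) == 3      # small dataset: use all rows
assert resolve_num_rows([{"text": "a"}] * 50) == 5     # large dataset: cap at five
assert resolve_num_rows([{"text": "a"}] * 50, 2) == 2  # explicit value wins
```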
19 changes: 8 additions & 11 deletions prompt2model/prompt_parser/instr_parser.py
@@ -46,17 +46,14 @@ def parse_from_prompt(self, prompt: str) -> None:
"""
parsing_prompt_for_chatgpt = construct_prompt_for_instruction_parsing(prompt)
required_keys = ["Instruction", "Demonstrations"]
try:
extraction = parse_prompt_to_fields(
parsing_prompt_for_chatgpt,
required_keys,
max_api_calls=self.max_api_calls,
)
self._instruction = extraction["Instruction"]
self._examples = extraction["Demonstrations"]
except Exception as e:
print(e)
extraction = None

extraction = parse_prompt_to_fields(
parsing_prompt_for_chatgpt,
required_keys,
max_api_calls=self.max_api_calls,
)
self._instruction = extraction["Instruction"]
self._examples = extraction["Demonstrations"]

def set_instruction_and_examples(
self, instruction: str = "", examples: str = ""
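
Dropping the `try/except` in `parse_from_prompt` lets parsing failures propagate to the caller instead of being printed and silently swallowed, which is what makes them assertable in tests. A hypothetical pytest-style illustration (not from the repository's test suite) of the behaviour the change enables:

```python
import pytest

def parse_or_raise(parse_fn, prompt: str) -> dict:
    """Parse a prompt and let any failure propagate instead of swallowing it."""
    extraction = parse_fn(prompt)  # may raise; no blanket except around it
    return {
        "instruction": extraction["Instruction"],
        "examples": extraction["Demonstrations"],
    }

def test_parse_error_propagates():
    def failing_parser(prompt):
        raise RuntimeError("API limit reached")

    # With the old try/except the error was only printed; now a test can see it.
    with pytest.raises(RuntimeError):
        parse_or_raise(failing_parser, "some prompt")
```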
9 changes: 5 additions & 4 deletions prompt2model/utils/api_tools.py
@@ -5,6 +5,7 @@
import asyncio
import json
import logging
import re
import time

import aiolimiter
@@ -219,14 +220,15 @@


def handle_api_error(e, backoff_duration=1) -> None:
"""Handles API errors raised during API calls.
"""Handle OpenAI errors or related errors that the API may raise.
Sleeps incase error is some type of timeout, else throws error.
Args:
e: The API error raised.
backoff_duration: The duration to wait before retrying the API call.
Raises:
openai.error.OpenAIError: If the error is not an instance of OpenAIError.
e: If the error is not an instance of APIError, Timeout, or RateLimitError.
Returns:
@@ -240,7 +242,6 @@ def handle_api_error(e, backoff_duration=1) -> None:
if isinstance(
e, (openai.error.APIError, openai.error.Timeout, openai.error.RateLimitError)
):
import re

match = re.search(r"Please retry after (\d+) seconds", str(e))
# If openai mentions how long to sleep use that, else do exponential backoff
@@ -269,4 +270,4 @@ def count_tokens_from_string(string: str, encoding_name: str = "cl100k_base") ->

# This is the default API agent that is used everywhere if a different agent is not
# specified
default_api_agent = APIAgent(max_tokens=4000)
default_api_agent = APIAgent(max_tokens=4000)
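
`handle_api_error` now imports `re` at module level and documents that it re-raises anything other than an `APIError`, `Timeout`, or `RateLimitError`. The retry-delay logic it relies on, parsing the provider's "Please retry after N seconds" hint and falling back to exponential backoff, can be sketched in isolation (the function name and exact backoff formula here are illustrative):

```python
import re

def pick_backoff_seconds(error_message: str, attempt: int) -> int:
    """Prefer the wait the provider suggests; otherwise back off exponentially."""
    match = re.search(r"Please retry after (\d+) seconds", error_message)
    if match:
        return int(match.group(1))  # the API said how long to wait
    return 2 ** attempt             # otherwise exponential backoff

print(pick_backoff_seconds("Rate limit reached. Please retry after 7 seconds.", 3))  # 7
print(pick_backoff_seconds("Gateway timeout", 3))                                    # 8
```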
16 changes: 10 additions & 6 deletions prompt2model/utils/logging_utils.py
@@ -14,10 +14,14 @@ def get_formatted_logger(logger_name: str):
A logger object.
"""
logger = logging.getLogger(logger_name)
ch = logging.StreamHandler()
formatter = logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
ch.setFormatter(formatter)
logger.addHandler(ch)
# Check if the logger already has a StreamHandler to prevent adding another
if not any(
isinstance(handler, logging.StreamHandler) for handler in logger.handlers
):
ch = logging.StreamHandler()
formatter = logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
ch.setFormatter(formatter)
logger.addHandler(ch)
return logger
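
The logging change guards against attaching a second `StreamHandler` when `get_formatted_logger` is called repeatedly for the same name, which would otherwise duplicate every log line. A minimal sketch of the same guard (the helper name is illustrative):

```python
import logging

def get_console_logger(name: str) -> logging.Logger:
    """Attach a formatted StreamHandler only if the logger does not have one yet."""
    logger = logging.getLogger(name)
    if not any(isinstance(h, logging.StreamHandler) for h in logger.handlers):
        ch = logging.StreamHandler()
        ch.setFormatter(
            logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
        )
        logger.addHandler(ch)
    return logger

# Calling twice returns the same logger without stacking a second handler,
# so each record is emitted once instead of being duplicated.
first = get_console_logger("demo")
second = get_console_logger("demo")
assert first is second and len(first.handlers) == 1
```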
15 changes: 9 additions & 6 deletions prompt2model/utils/parse_responses.py
@@ -31,7 +31,7 @@ def find_and_parse_json(
final response as a Dictionary
Else returns None.
"""
if type(response) != str and "choices" in response:
if type(response) != str and hasattr(response, "choices"):
response = response.choices[0]["message"]["content"]
correct_json = find_rightmost_brackets(response)

@@ -88,17 +88,18 @@ def parse_dataset_config_responses(response: openai.ChatCompletion) -> dict:
Returns:
dict: The extracted relevant information from the dataset configuration.
"""
if "choices" in response:
response_str = response["choices"][0]["message"]["content"]
if type(response) != str and hasattr(response, "choices"):
response_str = response.choices[0]["message"]["content"]
else:
response_str = response

pattern = r"\*\*(.*?)\*\*"

match = re.search(pattern, response_str)
dataset_config = ""
if match:
dataset_config = match.group(1)
elif len(response_str.split()) > 1:
elif len(response_str.split()) >= 1:
dataset_config = response_str.split()[-1].replace(".", "")
return {"name": dataset_config}

@@ -131,7 +132,6 @@ def parse_prompt_to_fields(
Raises:
ValueError: If max_api_calls is not greater than 0.
RuntimeError: If the maximum number of API calls is reached.
Other exceptions as appropriate for other error conditions.
"""
chat_api = api_tools.default_api_agent
@@ -146,6 +146,9 @@
response: openai.ChatCompletion | Exception = (
chat_api.generate_one_completion(
prompt,
temperature=0.01,
presence_penalty=0,
frequency_penalty=0,
)
)
extraction: dict[str, Any] | None = None
@@ -188,7 +191,7 @@ def make_single_api_request(prompt: str, max_api_calls: int = 10) -> str:
api_call_counter += 1
try:
response: openai.ChatCompletion = chat_api.generate_one_completion(
prompt=prompt, temperature=0, presence_penalty=0, frequency_penalty=0
prompt=prompt, temperature=0.01, presence_penalty=0, frequency_penalty=0
)
if response is not None:
return response.choices[0]["message"]["content"]
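
`parse_dataset_config_responses` now falls back to the last word whenever the reply contains at least one word (`>= 1` instead of `> 1`), so a bare one-word answer still yields a name. The extraction logic, reproduced as a standalone sketch with a hypothetical function name:

```python
import re

def extract_config_name(response_str: str) -> dict:
    """Prefer a **bold** answer; otherwise fall back to the reply's last word."""
    match = re.search(r"\*\*(.*?)\*\*", response_str)
    dataset_config = ""
    if match:
        dataset_config = match.group(1)
    elif len(response_str.split()) >= 1:  # now also covers one-word replies
        dataset_config = response_str.split()[-1].replace(".", "")
    return {"name": dataset_config}

print(extract_config_name("The best match is **squad**."))  # {'name': 'squad'}
print(extract_config_name("squad."))                         # {'name': 'squad'}
print(extract_config_name(""))                               # {'name': ''}
```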
2 changes: 2 additions & 0 deletions prompt2model_demo.py
@@ -179,6 +179,8 @@ def main():
line_print("Prompt parsed.")

if propmt_has_been_parsed and not dataset_has_been_retrieved:
retriever_logger = get_formatted_logger("DescriptionDatasetRetriever")
retriever_logger.setLevel(logging.INFO)
prompt_spec = MockPromptSpec(
TaskType.TEXT_GENERATION, status["instruction"], status["examples"]
)
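
The demo now raises the retriever's logger to `INFO` before retrieval so its progress messages (such as "Top datasets retrieved.") are visible on the console. A rough standalone equivalent using the standard library directly (the logger name matches the retriever class; the rest is illustrative):

```python
import logging

logging.basicConfig(format="%(name)s - %(levelname)s - %(message)s")

# Raise only the retriever's logger to INFO; other modules keep the default level.
retriever_logger = logging.getLogger("DescriptionDatasetRetriever")
retriever_logger.setLevel(logging.INFO)
retriever_logger.info("Top datasets retrieved.")  # now emitted on the console
```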