BigBench Fixes (#8)
* Path fixing

* Give an error

* Working on some paths

* Tweak gitignore

* Tweak ReadMe

* Working on some error messages
riedgar-ms authored Dec 15, 2023
1 parent f43cf97 commit dcafda4
Showing 5 changed files with 75 additions and 26 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -10,4 +10,5 @@ src/*.egg-info/*
src/promptbase/generations/*
datasets/*
*.log
*.jsonl
*.jsonl
src/promptbase/datasets/BigBench/**
1 change: 1 addition & 0 deletions README.md
@@ -119,6 +119,7 @@ To run evaluations, download these datasets and add them to /src/promptbase/datasets
- GSM8K: https://github.com/openai/grade-school-math
- MATH: https://huggingface.co/datasets/hendrycks/competition_math
- Big-Bench-Hard: https://github.com/suzgunmirac/BIG-Bench-Hard
The contents of this repo need to be placed in a directory called `BigBench` inside the `datasets` directory.

## Other Resources:

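The README note above leaves the expected layout implicit. Here is a minimal layout check as a sketch; the helper name and the check itself are illustrative, not part of this commit, while the `bbh` and `cot-prompts` subdirectory names come from the paths used in `bigbench_cot.py` below.

```python
import pathlib

# Hypothetical layout check (not part of this commit): confirm that
# BIG-Bench-Hard was copied to src/promptbase/datasets/BigBench with
# the bbh/ and cot-prompts/ subdirectories the scripts expect.
def check_bigbench_layout(repo_root):
    bigbench_root = repo_root / "src" / "promptbase" / "datasets" / "BigBench"
    for subdir in ("bbh", "cot-prompts"):
        candidate = bigbench_root / subdir
        if not candidate.is_dir():
            raise FileNotFoundError(f"Expected directory {candidate}")

check_bigbench_layout(pathlib.Path("."))
```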
2 changes: 2 additions & 0 deletions src/promptbase/__main__.py
@@ -34,6 +34,8 @@ def main():
elif args.dataset == "bigbench":
bigbench.generate()
bigbench.evaluate()
else:
raise ValueError(f"Bad dataset: {args.dataset}")


if __name__ == "__main__":
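The new `else` branch makes an unrecognized dataset name fail loudly instead of falling through silently. A self-contained sketch of the same dispatch pattern (the argparse wiring here is assumed; only the `bigbench` branch and the `ValueError` mirror the diff):

```python
import argparse

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", required=True)
    args = parser.parse_args()

    if args.dataset == "bigbench":
        print("would call bigbench.generate() and bigbench.evaluate()")
    else:
        # Fail fast on typos instead of exiting with no output.
        raise ValueError(f"Bad dataset: {args.dataset}")

if __name__ == "__main__":
    main()
```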
77 changes: 57 additions & 20 deletions src/promptbase/bigbench/bigbench_cot.py
@@ -1,15 +1,31 @@
import logging

import openai
import requests
import os
import json
import pathlib
import time
import argparse
import sys
import threading

_logger = logging.getLogger(pathlib.Path(__file__).name)
_logger.setLevel(logging.INFO)
_sh = logging.StreamHandler(stream=sys.stdout)
_sh.setFormatter(
logging.Formatter(
"%(asctime)s - %(name)s [%(levelname)s] : %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
)
_logger.addHandler(_sh)


my_path = pathlib.Path(__file__).parent.resolve()

bigbench_data_root = "../datasets/BigBench"
cot_prompts_dir = os.path.join(bigbench_data_root, "cot-prompts")
bbh_test_dir = os.path.join(bigbench_data_root, "bbh")
bigbench_data_root = my_path.parent / "datasets" / "BigBench"
cot_prompts_dir = bigbench_data_root / "cot-prompts"
bbh_test_dir = bigbench_data_root / "bbh"

SUBJECTS = [
"boolean_expressions",
@@ -52,7 +68,7 @@ def extract_chat_qa(few_shot_prompt):


def do_chat_cot(bbh_test_path, cot_prompt_path, test_name, cot_results_path):
print(f"Processing {test_name}")
_logger.info(f"Processing {test_name}")
test_results = []
with open(cot_prompt_path, "r", encoding="utf-8") as file:
cot_prompt_contents = file.read()
@@ -74,25 +90,31 @@ def do_chat_cot(bbh_test_path, cot_prompt_path, test_name, cot_results_path):
with open(bbh_test_path, "r", encoding="utf-8") as file:
example_data = json.load(file)
for i, example in enumerate(example_data["examples"]):
print(
_logger.info(
f"Processing example {i} of {len(example_data['examples'])} for {test_name}"
)
prompt_messages = few_shot_messages + [
{"role": "user", "content": "Q: " + example["input"]}
]
header = {"Authorization": os.getenv("AZURE_OPENAI_API_KEY")}
# These os.getenv calls should probably route to utils.py instead....
header = {"Authorization": os.getenv("AZURE_OPENAI_CHAT_API_KEY")}
data = {
"model": "gpt-4-1106-preview",
"temperature": 0,
"messages": prompt_messages,
"max_tokens": 2000,
}
url = os.getenv("AZURE_OPENAI_API_URL")
while True:
url = os.getenv("AZURE_OPENAI_CHAT_ENDPOINT_URL")

retry_count = 0
max_retries = 5
while retry_count < max_retries:
retry_count += 1
try:
response = requests.post(
url, headers=header, json=data, timeout=600
)
assert response.status_code < 400, f"{response.text}"
completion = json.loads(response.text)
test_results.append(
{
@@ -106,12 +128,16 @@ def do_chat_cot(bbh_test_path, cot_prompt_path, test_name, cot_results_path):
)
break
except Exception as e:
print("Caught exception: ", e)
print("Retrying in 35 seconds...")
_logger.warning("Caught exception: ", e)
_logger.warning("Retrying in 35 seconds...")
time.sleep(35)
cot_results_filename = os.path.join(cot_results_path, f"{test_name}_chat_cot_results.json")
cot_results_filename = os.path.join(
cot_results_path, f"{test_name}_chat_cot_results.json"
)
json.dump(
cot_results_filename, open(f"{test_name}_chat_cot_results.json", "w"), indent=4
test_results,
open(cot_results_filename, "w"),
indent=4,
)
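Both request loops in this file are changed from `while True` to a bounded retry. A minimal sketch of that pattern (the `call_with_retries` helper and `fetch` callable are hypothetical; note that, as in the diff, exhausting the retries gives up silently rather than re-raising):

```python
import time

def call_with_retries(fetch, max_retries=5, delay_seconds=35.0):
    """Hypothetical helper mirroring the loop added in this commit:
    try fetch() up to max_retries times, sleeping between failures."""
    retry_count = 0
    while retry_count < max_retries:
        retry_count += 1
        try:
            return fetch()
        except Exception as e:
            print(f"Caught exception: {e}; retrying in {delay_seconds} seconds...")
            time.sleep(delay_seconds)
    # As in the diff, exhausting the retries falls through silently;
    # the caller gets None and no result is recorded for this example.
    return None
```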


@@ -132,7 +158,10 @@ def do_completion_cot(bbh_test_path, cot_prompt_path, test_name, cot_results_path):
f"Processing example {i} of {len(example_data['examples'])} for {test_name}"
)
prompt = f"{cot_prompt_contents}\n\nQ: {example['input']}\nA: Let's think step by step.\n"
while True:
retry_count = 0
max_retries = 5
while retry_count < max_retries:
retry_count += 1
try:
completion = openai.Completion.create(
engine="gemini-compete-wus",
@@ -156,17 +185,21 @@ def do_completion_cot(bbh_test_path, cot_prompt_path, test_name, cot_results_path):
)
break
except Exception as e:
print("Caught exception: ", e)
print("Retrying in 5 seconds...")
_logger.warning("Caught exception: ", e)
_logger.warning("Retrying in 5 seconds...")
time.sleep(5)
cot_results_filename = os.path.join(cot_results_path, f"{test_name}_completion_cot_results.json")
cot_results_filename = os.path.join(
cot_results_path, f"{test_name}_completion_cot_results.json"
)
json.dump(
test_results,
open(cot_results_filename, "w"),
indent=4,
)


def process_cot(test_name: str, api_type="chat"):
_logger.info("Starting process_cot")
if test_name == "all":
subjects = SUBJECTS
elif test_name in SUBJECTS:
@@ -188,19 +221,23 @@ def process_cot(test_name: str, api_type="chat"):
print(f"COT prompt file {cot_prompt_path} does not exist")

if api_type == "completion":
_logger.info(f"Starting completion thread for {bbh_test_path}")
results_path = os.path.join(".", "results", "cot_results", "completion")
thread = threading.Thread(
target=do_completion_cot, args=(bbh_test_path, cot_prompt_path, subject, results_path)
target=do_completion_cot,
args=(bbh_test_path, cot_prompt_path, subject, results_path),
)
else:
_logger.info(f"Starting chat thread for {bbh_test_path}")
results_path = os.path.join(".", "results", "cot_results", "chat")
thread = threading.Thread(
target=do_chat_cot, args=(bbh_test_path, cot_prompt_path, subject, results_path)
target=do_chat_cot,
args=(bbh_test_path, cot_prompt_path, subject, results_path),
)
threads.append(thread)
thread.start()

for thread in threads:
thread.join()

print("Done!")
print("Done!")
18 changes: 13 additions & 5 deletions src/promptbase/bigbench/bigbench_score.py
@@ -1,11 +1,16 @@
import datetime
import os
import json
import argparse
import os
import pathlib

my_path = pathlib.Path(__file__).parent.resolve()


def score(api_type="chat"):
ground_truth_dir = os.path.join("..", "datasets", "BigBench", "bbh")
answer_dir = os.path.join(".", "results", "answers")
ground_truth_dir = my_path.parent / "datasets" / "BigBench" / "bbh"
assert ground_truth_dir.exists(), f"Checking for {ground_truth_dir}"
assert ground_truth_dir.is_dir()
answer_dir = my_path / "results" / "answers"

score_dict = {}

@@ -25,7 +30,10 @@ def score(api_type="chat"):
with open(answer_path) as f:
answer_data = json.load(f)

print("Number of ground truth examples: " + str(len(ground_truth_data["examples"])))
print(
"Number of ground truth examples: "
+ str(len(ground_truth_data["examples"]))
)
print("Number of answer examples: " + str(len(answer_data)))
if len(ground_truth_data["examples"]) != len(answer_data):
print("Number of examples does not match for file: " + filename)
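`bigbench_score.py` gets the same treatment as `bigbench_cot.py`: paths anchored on the module file plus fail-fast checks. A short sketch combining both ideas (the `counts_match` helper is hypothetical; the path assertion mirrors the diff):

```python
import json
import pathlib

# Anchor paths on this file's location so the script works no matter
# which directory it is launched from, and fail fast if the dataset
# was never copied into place.
my_path = pathlib.Path(__file__).parent.resolve()
ground_truth_dir = my_path.parent / "datasets" / "BigBench" / "bbh"
assert ground_truth_dir.is_dir(), f"Checking for {ground_truth_dir}"

def counts_match(ground_truth_path, answer_path):
    """Hypothetical helper: True when the answer file has one entry
    per ground-truth example, the per-file check the scorer performs."""
    with open(ground_truth_path) as f:
        ground_truth_data = json.load(f)
    with open(answer_path) as f:
        answer_data = json.load(f)
    return len(ground_truth_data["examples"]) == len(answer_data)
```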
