Skip to content

Commit

Permalink
PR changes
Browse files Browse the repository at this point in the history
  • Loading branch information
ritugala committed Apr 11, 2024
1 parent 15a61eb commit 775da63
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 35 deletions.
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ retrieved_dataset_dict/
result/
checkpoint/
status.yaml
dump.txt
# Outputs generated by the colab demo
trained_model/
trained_tokenizer/
40 changes: 20 additions & 20 deletions prompt2model/dataset_retriever/description_dataset_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def __init__(
max_number_of_dataset_rows=3000,
allow_gated_datasets=False,
auto_transform_data: bool = False,
num_points_to_transform: int = 3000,
total_num_points_to_transform: int = 3000,
max_allowed_failed_transforms: int = None,
max_datasets_to_choose: int = 3,
num_votes=5,
Expand All @@ -66,8 +66,10 @@ def __init__(
max_number_of_dataset_rows: Limit the number of rows for large datasets.
allow_gated_datasets: Use only if the user explicitly wants gated datasets
auto_transform_data: Automatically transform data to match the prompt.
num_points_to_transform: Number of data points to transform.
max_allowed_failed_transforms: Maximum number of failed transforms allowed.
total_num_points_to_transform: Number of data points to transform across all datasets.
max_allowed_failed_transforms: Maximum number of failed transforms allowed
for a given dataset. Skip the dataset if it exceeds the
maximum number of allowed transforms.
max_datasets_to_choose: Maximum number of datasets to choose from.
num_votes: Number of votes to consider for reranking.
"""
Expand Down Expand Up @@ -424,7 +426,9 @@ def make_dataset_from_samples(
columns, "input_col" and "output_col".
"""
updated_inputs, updated_outputs = [], []
if len(inputs) < 0 or len(inputs) != len(outputs):
dataset_dict = {}

if len(inputs) <= 0 or len(inputs) != len(outputs):
logger.error("Length of inputs and outputs must be >0 and equal.")
else:
for i, o in zip(inputs, outputs):
Expand All @@ -434,10 +438,9 @@ def make_dataset_from_samples(
else:
logger.warning(f"Input or output is None: {i} {o}")

dataset_dict = {}
dataset_dict["train"] = datasets.Dataset.from_dict(
{"input_col": updated_inputs, "output_col": updated_outputs}
)
dataset_dict["train"] = datasets.Dataset.from_dict(
{"input_col": updated_inputs, "output_col": updated_outputs}
)
return datasets.DatasetDict(dataset_dict)

def get_rerank_with_highest_votes(self, prompt, infos_dict):
Expand All @@ -459,11 +462,8 @@ def get_rerank_with_highest_votes(self, prompt, infos_dict):
if curr_name["name"] not in infos_dict:
logger.warning("LLM hallucinated dataset/config name: %s", curr_name)
voting.append(None)
continue
voting.append(curr_name["name"])
if len(voting) == 0:
logger.warning("Voting resulted in no dataset/config.")
return None
else:
voting.append(curr_name["name"])
chosen_one = max(set(voting), key=voting.count)
return chosen_one

Expand Down Expand Up @@ -521,7 +521,9 @@ def rerank_datasets(
config_name = self.get_rerank_with_highest_votes(
config_selection_prompt, curr_dataset["configs"]
)

logger.info(f"Chosen dataset and config: {dataset_name=} {config_name=}")
# A config name of None is handled in the calling function.
return dataset_name, config_name

def canonicalize_dataset_automatically(
Expand All @@ -543,7 +545,7 @@ def canonicalize_dataset_automatically(
Args:
top_dataset_info: Contains info about the top-ranked dataset.
prompt_spec: prompt object storing the original task and examples.
num_points_to_transform: Number of points to transform for a given dataset
Returns:
The canonicalized dataset, or None if the dataset is invalid or
if column selection fails, or if any other error occurs
Expand Down Expand Up @@ -606,9 +608,7 @@ def canonicalize_dataset_automatically(
example_rows = json.dumps(canonicalized_dataset["train"][0], indent=4)

logger.info(f"Transformed dataset. Example row:\n{example_rows}\n")
else:
dataset_name = top_dataset_info["dataset_name"]
logger.info(f"{dataset_name} exceed max allowed transforms..")


return canonicalized_dataset
else:
Expand Down Expand Up @@ -648,15 +648,15 @@ def get_datasets_of_required_size(
datasets_info_dict, prompt_spec
)
if dataset_name is None:
# If it couldn't find a relevant dataset (even after voting)
# stop trying to find more datasets.
# If it couldn't find a relevant dataset from reranking
# (even after voting) stop trying to find more datasets.
return None
number_of_chosen_datasets += 1

if config_name is None:
del datasets_info_dict[
dataset_name
] # TODO: Is this deleting the right thing?
]
continue # If it couldn't find a relevant config, delete the entire dataset. # noqa: E501

dataset_info = datasets_info_dict[dataset_name]["configs"][config_name]
Expand Down
15 changes: 5 additions & 10 deletions prompt2model/dataset_transformer/prompt_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,14 +92,12 @@ def generate_transform_prompts(
transform_prompts.append(transform_prompt)
return transform_prompts

def generate_responses(self, transform_prompts_batch: list[str]) -> list[str]:
"""Generate responses for the transform prompts."""
def generate_responses(self, transform_prompts_batch: list[str], model_name="gpt-3.5-turbo") -> list[str]:
"""Generate responses for the transform prompts. GPT-3.5 is used for transformation by default as it is cheaper."""

async def generate_responses_async(transform_prompts):
"""Generate responses asynchronously using the specified model."""
responses = await api_tools.APIAgent(
model_name="azure/GPT-3-5-turbo-chat", max_tokens=4000
).generate_batch_completion(
responses = await api_tools.APIAgent(model_name=model_name).generate_batch_completion(
transform_prompts,
temperature=0,
responses_per_request=1,
Expand All @@ -122,8 +120,6 @@ def process_responses(
) -> tuple[list[str], list[str]]:
"""Process the responses received from the API.
Also write the current set of inputs and outputs to a dump text just in case.
Args:
responses: A list of response strings from the API.
prompt_spec: The PromptSpec object containing the instruction and examples.
Expand Down Expand Up @@ -158,9 +154,6 @@ def process_responses(
if self.curr_failed_transforms > self.max_allowed_failed_transforms:
break

with open("dump.txt", "a") as file:
file.write("Input: " + ", ".join(map(str, inputs)) + "\n")
file.write("Output: " + ", ".join(map(str, outputs)) + "\n")

return inputs, outputs

Expand Down Expand Up @@ -197,7 +190,9 @@ def transform_data(
logger.error(
f"Exceeded max allowed failed transforms: {self.curr_failed_transforms}" # noqa: E501
)
self.max_allowed_failed_transforms = 0
break


logger.info(
f"Requested length: {self.num_points_to_transform}\nActual length: {len(inputs)}\n" # noqa: E501
Expand Down
9 changes: 5 additions & 4 deletions prompt2model/utils/api_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,21 +37,22 @@
openai.error.ServiceUnavailableError: "API service unavailable error: {e}",
openai.error.APIError: "API error: {e}",
}
BUFFER_DURATION = 2


class APIAgent:
"""A class for accessing API-based models."""

def __init__(
self,
model_name: str = "gpt-3.5-turbo",
model_name: str = "gpt-4",
max_tokens: int | None = 4000,
api_base: str | None = None,
):
"""Initialize APIAgent with model_name and max_tokens.
Args:
model_name: Name fo the model to use (by default, gpt-3.5-turbo).
model_name: Name of the model to use (by default, gpt-4).
max_tokens: The maximum number of tokens to generate. Defaults to the max
value for the model if available through litellm.
api_base: Custom endpoint for Hugging Face's inference API.
Expand Down Expand Up @@ -225,7 +226,8 @@ def handle_api_error(e, backoff_duration=1) -> None:
Sleeps in case the error is some type of timeout, else throws the error.
Args:
e: The API error raised.
e: The error to handle. This could be an OpenAI error or a related
non-fatal error, such as JSONDecodeError or AssertionError.
backoff_duration: The duration to wait before retrying the API call.
Raises:
Expand All @@ -246,7 +248,6 @@ def handle_api_error(e, backoff_duration=1) -> None:
match = re.search(r"Please retry after (\d+) seconds", str(e))
# If openai mentions how long to sleep use that, else do exponential backoff
if match is not None:
BUFFER_DURATION = 2
backoff_duration = int(match.group(1)) + BUFFER_DURATION

logging.info(f"Retrying in {backoff_duration} seconds...")
Expand Down

0 comments on commit 775da63

Please sign in to comment.