
Commit

comment change
ritugala committed Apr 11, 2024
1 parent 775da63 commit ad0d502
Showing 4 changed files with 38 additions and 25 deletions.
4 changes: 2 additions & 2 deletions examples/create_transform_data_example.py
@@ -33,10 +33,10 @@

# run this pipeline to retrieve relevant datasets, rerank them,
# and transform them based on the prompt
num_points_to_transform = 20
total_num_points_to_transform = 20
retriever = DescriptionDatasetRetriever(
auto_transform_data=True,
num_points_to_transform=num_points_to_transform,
total_num_points_to_transform=total_num_points_to_transform,
)
retrieved_dataset_dict = retriever.retrieve_dataset_dict(
prompt_spec,
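Taken together, the example now reads roughly as follows. This is a minimal sketch for orientation, not the full file: the import path is inferred from the repository layout above, and the prompt_spec construction (plus the remaining retrieve_dataset_dict arguments) is not shown in this hunk, so it is left as a placeholder.

# Sketch of the updated example usage; the import path and prompt_spec placeholder are assumptions.
from prompt2model.dataset_retriever import DescriptionDatasetRetriever

prompt_spec = ...  # placeholder: the example builds a real PromptSpec earlier (not shown in this hunk)

total_num_points_to_transform = 20  # renamed from num_points_to_transform in this commit
retriever = DescriptionDatasetRetriever(
    auto_transform_data=True,
    total_num_points_to_transform=total_num_points_to_transform,
)
retrieved_dataset_dict = retriever.retrieve_dataset_dict(prompt_spec)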
26 changes: 12 additions & 14 deletions prompt2model/dataset_retriever/description_dataset_retriever.py
@@ -66,9 +66,10 @@ def __init__(
max_number_of_dataset_rows: Limit the number of rows for large datasets.
allow_gated_datasets: Use only if the user explicitly wants gated datasets
auto_transform_data: Automatically transform data to match the prompt.
total_num_points_to_transform: Number of data points to transform across all datasets.
max_allowed_failed_transforms: Maximum number of failed transforms allowed
for a given dataset. Skip the dataset if it exceed the
total_num_points_to_transform: Number of data points to transform across
all datasets.
max_allowed_failed_transforms: Maximum number of failed transforms allowed
for a given dataset. Skip the dataset if it exceed the
maximum number of allowed transforms.
max_datasets_to_choose: Maximum number of datasets to choose from.
num_votes: Number of votes to consider for reranking.
@@ -83,9 +84,9 @@ def __init__(
self.max_number_of_dataset_rows = max_number_of_dataset_rows
self.allow_gated_datasets = allow_gated_datasets
self.auto_transform_data = auto_transform_data
self.num_points_to_transform = num_points_to_transform
self.total_num_points_to_transform = total_num_points_to_transform
if max_allowed_failed_transforms is None:
self.max_allowed_failed_transforms: int = num_points_to_transform // 3
self.max_allowed_failed_transforms: int = total_num_points_to_transform // 3
else:
self.max_allowed_failed_transforms = max_allowed_failed_transforms

@@ -521,9 +522,9 @@ def rerank_datasets(
config_name = self.get_rerank_with_highest_votes(
config_selection_prompt, curr_dataset["configs"]
)

logger.info(f"Chosen dataset and config: {dataset_name=} {config_name=}")
#config name being None gets handled in calling function
# config name being None gets handled in calling function
return dataset_name, config_name

def canonicalize_dataset_automatically(
@@ -609,7 +610,6 @@ def canonicalize_dataset_automatically(

logger.info(f"Transformed dataset. Example row:\n{example_rows}\n")


return canonicalized_dataset
else:
canonicalized_dataset = self.canonicalize_dataset_using_columns(
@@ -626,7 +626,7 @@ def get_datasets_of_required_size(
"""Combine multiple datasets to get the required size.
Args:
dataset_list (list[dict]): A list of dictionaries representing the datasets.
datasets_info_dict (dict): A list of dictionaries representing the datasets.
prompt_spec (PromptSpec): An object representing the prompt specification.
Returns:
@@ -640,7 +640,7 @@ def get_datasets_of_required_size(
dataset_contributions = {}
number_of_chosen_datasets = 0
while (
curr_datasets_size < self.num_points_to_transform
curr_datasets_size < self.total_num_points_to_transform
and len(datasets_info_dict.keys()) > 0
and number_of_chosen_datasets <= self.max_datasets_to_choose
):
@@ -654,16 +654,14 @@
number_of_chosen_datasets += 1

if config_name is None:
del datasets_info_dict[
dataset_name
]
del datasets_info_dict[dataset_name]
continue # If it couldn't find a relevant config, delete the entire dataset. # noqa: E501

dataset_info = datasets_info_dict[dataset_name]["configs"][config_name]
canonicalized_dataset = self.canonicalize_dataset_automatically(
dataset_info,
prompt_spec,
self.num_points_to_transform - curr_datasets_size,
self.total_num_points_to_transform - curr_datasets_size,
)
curr_datasets_size += len(canonicalized_dataset["train"]["input_col"])
inputs += canonicalized_dataset["train"]["input_col"]
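Two behavioural points are worth pulling out of this file's hunks: the default failure budget is now a third of the overall transform target (so total_num_points_to_transform=20 gives 20 // 3 == 6 allowed failed transforms per dataset), and get_datasets_of_required_size keeps choosing datasets until that overall target is met, it runs out of candidates, or it hits max_datasets_to_choose. The sketch below is an illustrative simplification of that loop, not the project's code: fake_rerank and fake_canonicalize are made-up stand-ins for rerank_datasets and canonicalize_dataset_automatically, and per-dataset contribution tracking is omitted.

def fake_rerank(info: dict) -> tuple:
    # Stand-in for rerank_datasets: pick any remaining dataset and its first config (or None).
    name = next(iter(info))
    configs = info[name]["configs"]
    return name, (next(iter(configs)) if configs else None)

def fake_canonicalize(config_info: dict, num_needed: int) -> list:
    # Stand-in for canonicalize_dataset_automatically: pretend every row transforms cleanly.
    return config_info["rows"][:num_needed]

def gather_points(datasets_info_dict: dict, total_num_points_to_transform: int,
                  max_datasets_to_choose: int = 3) -> list:
    inputs = []
    number_of_chosen_datasets = 0
    while (
        len(inputs) < total_num_points_to_transform
        and len(datasets_info_dict) > 0
        and number_of_chosen_datasets <= max_datasets_to_choose
    ):
        dataset_name, config_name = fake_rerank(datasets_info_dict)
        number_of_chosen_datasets += 1
        if config_name is None:
            del datasets_info_dict[dataset_name]  # no usable config: drop the whole dataset
            continue
        # Ask only for however many points are still missing.
        inputs += fake_canonicalize(
            datasets_info_dict[dataset_name]["configs"][config_name],
            total_num_points_to_transform - len(inputs),
        )
        del datasets_info_dict[dataset_name]  # simplification: one pass per dataset
    return inputs

info = {
    "dataset_a": {"configs": {"default": {"rows": [f"row {i}" for i in range(15)]}}},
    "dataset_b": {"configs": {"default": {"rows": [f"row {i}" for i in range(15)]}}},
}
print(len(gather_points(info, total_num_points_to_transform=20)))  # 20: 15 from dataset_a, 5 from dataset_b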
25 changes: 20 additions & 5 deletions prompt2model/dataset_transformer/prompt_based.py
@@ -92,12 +92,29 @@ def generate_transform_prompts(
transform_prompts.append(transform_prompt)
return transform_prompts

def generate_responses(self, transform_prompts_batch: list[str], model_name="gpt-3.5-turbo") -> list[str]:
"""Generate responses for the transform prompts. Use gpt 3.5 for transformation as it is cheaper."""
def generate_responses(
self, transform_prompts_batch: list[str], model_name="gpt-3.5-turbo"
) -> list[str]:
"""Generate responses for the given transform prompts.
Args:
transform_prompts_batch (list[str]): A list of transform prompts.
model_name (str, optional): The name of the model to use. Defaults to
"gpt-3.5-turbo" to save costs.
Returns:
list[str]: A list of generated responses.
Raises:
API_ERRORS: If there is an error with the API.
"""

async def generate_responses_async(transform_prompts):
"""Generate responses asynchronously using the specified model."""
responses = await api_tools.APIAgent(model_name=model_name).generate_batch_completion(
responses = await api_tools.APIAgent(
model_name=model_name
).generate_batch_completion(
transform_prompts,
temperature=0,
responses_per_request=1,
@@ -154,7 +171,6 @@ def process_responses(
if self.curr_failed_transforms > self.max_allowed_failed_transforms:
break


return inputs, outputs

def transform_data(
@@ -192,7 +208,6 @@ def transform_data(
)
self.max_allowed_failed_transforms = 0
break


logger.info(
f"Requested length: {self.num_points_to_transform}\nActual length: {len(inputs)}\n" # noqa: E501
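Besides the docstring and line-wrapping fixes, the hunks above show the failure budget in action: process_responses stops once curr_failed_transforms exceeds max_allowed_failed_transforms, and transform_data zeroes the budget and breaks out entirely in its error branch. A self-contained sketch of that give-up-after-too-many-failures pattern follows; json.loads and the "input"/"output" keys are illustrative assumptions, not the project's actual response format.

import json

def parse_with_budget(responses: list, max_allowed_failed_transforms: int):
    # Keep whatever parses cleanly; stop early once the number of failed
    # transforms exceeds the allowed budget (mirrors the break shown above).
    inputs, outputs, curr_failed_transforms = [], [], 0
    for response in responses:
        try:
            record = json.loads(response)
            inputs.append(record["input"])
            outputs.append(record["output"])
        except (json.JSONDecodeError, KeyError):
            curr_failed_transforms += 1
            if curr_failed_transforms > max_allowed_failed_transforms:
                break
    return inputs, outputs

good = '{"input": "What is 2 + 2?", "output": "4"}'
bad = "not valid json"
print(parse_with_budget([good, bad, bad, good], max_allowed_failed_transforms=1))
# Prints (['What is 2 + 2?'], ['4']): the second failure exhausts the budget,
# so the last good response is never reached.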
8 changes: 4 additions & 4 deletions prompt2model_demo.py
@@ -206,18 +206,18 @@ def main():
)
line = input()
try:
num_points_to_transform = int(line)
total_num_points_to_transform = int(line)
except ValueError:
line_print("Invalid input. Please enter a number.")
continue
if num_points_to_transform <= 0:
if total_num_points_to_transform <= 0:
line_print("Invalid input. Please enter a number greater than 0.")
continue
status["num_transform"] = num_points_to_transform
status["num_transform"] = total_num_points_to_transform
break
retriever = DescriptionDatasetRetriever(
auto_transform_data=auto_transform_data,
num_points_to_transform=num_points_to_transform,
total_num_points_to_transform=total_num_points_to_transform,
)
retrieved_dataset_dict = retriever.retrieve_dataset_dict(prompt_spec)

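The demo change is the same rename threaded through the interactive prompt. As an aside, the surrounding read-validate-retry loop is a reusable pattern; a hypothetical refactor (not part of this commit) could package it as a small helper:

def prompt_for_positive_int(message: str) -> int:
    # Keep asking until the user enters an integer greater than zero.
    while True:
        line = input(message)
        try:
            value = int(line)
        except ValueError:
            print("Invalid input. Please enter a number.")
            continue
        if value <= 0:
            print("Invalid input. Please enter a number greater than 0.")
            continue
        return value

# total_num_points_to_transform = prompt_for_positive_int("How many data points to transform? ")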
