diff --git a/examples/create_transform_data_example.py b/examples/create_transform_data_example.py index 8900e547d..ed33af765 100644 --- a/examples/create_transform_data_example.py +++ b/examples/create_transform_data_example.py @@ -33,10 +33,10 @@ # run this pipeline to retrieve relevant datasets, rerank them, # and transform them based on the prompt - num_points_to_transform = 20 + total_num_points_to_transform = 20 retriever = DescriptionDatasetRetriever( auto_transform_data=True, - num_points_to_transform=num_points_to_transform, + total_num_points_to_transform=total_num_points_to_transform, ) retrieved_dataset_dict = retriever.retrieve_dataset_dict( prompt_spec, diff --git a/prompt2model/dataset_retriever/description_dataset_retriever.py b/prompt2model/dataset_retriever/description_dataset_retriever.py index 3332f6bda..3fede74f8 100644 --- a/prompt2model/dataset_retriever/description_dataset_retriever.py +++ b/prompt2model/dataset_retriever/description_dataset_retriever.py @@ -66,9 +66,10 @@ def __init__( max_number_of_dataset_rows: Limit the number of rows for large datasets. allow_gated_datasets: Use only if the user explicitly wants gated datasets auto_transform_data: Automatically transform data to match the prompt. - total_num_points_to_transform: Number of data points to transform across all datasets. - max_allowed_failed_transforms: Maximum number of failed transforms allowed - for a given dataset. Skip the dataset if it exceed the + total_num_points_to_transform: Number of data points to transform across + all datasets. + max_allowed_failed_transforms: Maximum number of failed transforms allowed + for a given dataset. Skip the dataset if it exceeds the maximum number of allowed transforms. max_datasets_to_choose: Maximum number of datasets to choose from. num_votes: Number of votes to consider for reranking. 
@@ -83,9 +84,9 @@ def __init__( self.max_number_of_dataset_rows = max_number_of_dataset_rows self.allow_gated_datasets = allow_gated_datasets self.auto_transform_data = auto_transform_data - self.num_points_to_transform = num_points_to_transform + self.total_num_points_to_transform = total_num_points_to_transform if max_allowed_failed_transforms is None: - self.max_allowed_failed_transforms: int = num_points_to_transform // 3 + self.max_allowed_failed_transforms: int = total_num_points_to_transform // 3 else: self.max_allowed_failed_transforms = max_allowed_failed_transforms @@ -521,9 +522,9 @@ def rerank_datasets( config_name = self.get_rerank_with_highest_votes( config_selection_prompt, curr_dataset["configs"] ) - + logger.info(f"Chosen dataset and config: {dataset_name=} {config_name=}") - #config name being None gets handled in calling function + # config name being None gets handled in calling function return dataset_name, config_name def canonicalize_dataset_automatically( @@ -609,7 +610,6 @@ def canonicalize_dataset_automatically( logger.info(f"Transformed dataset. Example row:\n{example_rows}\n") - return canonicalized_dataset else: canonicalized_dataset = self.canonicalize_dataset_using_columns( @@ -626,7 +626,7 @@ def get_datasets_of_required_size( """Combine multiple datasets to get the required size. Args: - dataset_list (list[dict]): A list of dictionaries representing the datasets. + datasets_info_dict (dict): A dictionary of dataset information, keyed by dataset name. prompt_spec (PromptSpec): An object representing the prompt specification. 
Returns: @@ -640,7 +640,7 @@ def get_datasets_of_required_size( dataset_contributions = {} number_of_chosen_datasets = 0 while ( - curr_datasets_size < self.num_points_to_transform + curr_datasets_size < self.total_num_points_to_transform and len(datasets_info_dict.keys()) > 0 and number_of_chosen_datasets <= self.max_datasets_to_choose ): @@ -654,16 +654,14 @@ def get_datasets_of_required_size( number_of_chosen_datasets += 1 if config_name is None: - del datasets_info_dict[ - dataset_name - ] + del datasets_info_dict[dataset_name] continue # If it couldn't find a relevant config, delete the entire dataset. # noqa: E501 dataset_info = datasets_info_dict[dataset_name]["configs"][config_name] canonicalized_dataset = self.canonicalize_dataset_automatically( dataset_info, prompt_spec, - self.num_points_to_transform - curr_datasets_size, + self.total_num_points_to_transform - curr_datasets_size, ) curr_datasets_size += len(canonicalized_dataset["train"]["input_col"]) inputs += canonicalized_dataset["train"]["input_col"] diff --git a/prompt2model/dataset_transformer/prompt_based.py b/prompt2model/dataset_transformer/prompt_based.py index 1f2b5e0b8..cad2790de 100644 --- a/prompt2model/dataset_transformer/prompt_based.py +++ b/prompt2model/dataset_transformer/prompt_based.py @@ -92,12 +92,29 @@ def generate_transform_prompts( transform_prompts.append(transform_prompt) return transform_prompts - def generate_responses(self, transform_prompts_batch: list[str], model_name="gpt-3.5-turbo") -> list[str]: - """Generate responses for the transform prompts. Use gpt 3.5 for transformation as it is cheaper.""" + def generate_responses( + self, transform_prompts_batch: list[str], model_name="gpt-3.5-turbo" + ) -> list[str]: + """Generate responses for the given transform prompts. + + Args: + transform_prompts_batch (list[str]): A list of transform prompts. + model_name (str, optional): The name of the model to use. Defaults to + "gpt-3.5-turbo" to save costs. 
+ + Returns: + list[str]: A list of generated responses. + + Raises: + API_ERRORS: If there is an error with the API. + + """ async def generate_responses_async(transform_prompts): """Generate responses asynchronously using the specified model.""" - responses = await api_tools.APIAgent(model_name=model_name).generate_batch_completion( + responses = await api_tools.APIAgent( + model_name=model_name + ).generate_batch_completion( transform_prompts, temperature=0, responses_per_request=1, @@ -154,7 +171,6 @@ def process_responses( if self.curr_failed_transforms > self.max_allowed_failed_transforms: break - return inputs, outputs def transform_data( @@ -192,7 +208,6 @@ def transform_data( ) self.max_allowed_failed_transforms = 0 break - logger.info( f"Requested length: {self.num_points_to_transform}\nActual length: {len(inputs)}\n" # noqa: E501 diff --git a/prompt2model_demo.py b/prompt2model_demo.py index a2c881af6..b20021c71 100644 --- a/prompt2model_demo.py +++ b/prompt2model_demo.py @@ -206,18 +206,18 @@ def main(): ) line = input() try: - num_points_to_transform = int(line) + total_num_points_to_transform = int(line) except ValueError: line_print("Invalid input. Please enter a number.") continue - if num_points_to_transform <= 0: + if total_num_points_to_transform <= 0: line_print("Invalid input. Please enter a number greater than 0.") continue - status["num_transform"] = num_points_to_transform + status["num_transform"] = total_num_points_to_transform break retriever = DescriptionDatasetRetriever( auto_transform_data=auto_transform_data, - num_points_to_transform=num_points_to_transform, + total_num_points_to_transform=total_num_points_to_transform, ) retrieved_dataset_dict = retriever.retrieve_dataset_dict(prompt_spec)