Skip to content

Commit

Permalink
added docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
ritugala committed Mar 30, 2024
1 parent acd6e7a commit 415aa21
Show file tree
Hide file tree
Showing 9 changed files with 88 additions and 44 deletions.
1 change: 1 addition & 0 deletions litellm_uuid.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
d936d070-a20e-445e-b721-d668b1c1706e
26 changes: 16 additions & 10 deletions prompt2model/dataset_retriever/reranking_prompt.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
""" This module contains the functions to generate the prompt for dataset reranking. """
from __future__ import annotations # noqa FI58

METAPROMPT_BASE_DATASET = """Your objective is to choose the most relevant dataset for a given a task (and few examples of the task). For each dataset, you will be provided with the dataset description, and tags related to the dataset which provide meta-information about the dataset. Please return the most relevant dataset, e.g. squad """ # noqa: E501
Expand All @@ -20,15 +21,19 @@
CONFIG_TEMPLATE = """\t[{counter}] **{config_name}**\n: The columns in this config are {dataset_columns}.\n An example row from this config is {sample_row}.\n """ # noqa: E501


def build_datasets_prompt(instruction: str, examples: str, datasets_infos):
"""Constructs a prompt that describes each dataset.
def build_datasets_prompt(instruction: str, examples: str, datasets_infos: dict):
"""
Builds the prompt for dataset reranking.
Args:
datasets_infos (dict): Dictionary with dataset information.
instruction (str): Task instructions
examples (str): Task Examples
datasets_infos (dict): A dictionary containing information about all datasets.
Returns:
str: A string that lists each dataset with its description and tags.
str: The input prompt for dataset retrieval.
"""

dataset_string = ""
for i, (dataset_name, dataset_info) in enumerate(datasets_infos.items(), start=1):
dataset_string += f"""{DATASET_TEMPLATE.format(
Expand All @@ -48,15 +53,16 @@ def build_datasets_prompt(instruction: str, examples: str, datasets_infos):


def build_configs_prompt(instruction: str, examples: str, dataset_info: dict):
"""Constructs a prompt for selecting relevant configurations from a given dataset.
"""
Builds the prompt for config reranking.
Args:
instruction (str): Instruction of the task.
examples (str): Examples of the task.
dataset_info (dict): Information about the dataset and its configurations.
instruction (str): Task instructions
examples (str): Task Examples
datasets_infos (dict): A dictionary containing information about the specific dataset, which includes config information.
Returns:
str: A string that lists each configuration with its details for the specified dataset.
str: The input prompt for dataset retrieval.
"""
configs_string = ""
for j, config in dataset_info["configs"].items():
Expand All @@ -78,7 +84,7 @@ def build_configs_prompt(instruction: str, examples: str, dataset_info: dict):
return input_prompt


def construct_prompt_for_dataset_reranking(instruction: str, examples: str, datasets_infos,is_config:bool=False):
def construct_prompt_for_dataset_reranking(instruction: str, examples: str, datasets_infos: dict,is_config:bool=False):
"""Generate the full prompt for dataset reranking based on the given parameters.
Args:
Expand Down
15 changes: 14 additions & 1 deletion prompt2model/dataset_retriever/task_expansion_prompt.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
""" This module contains the functions to construct the prompt for task expansion. """
METAPROMPT_BASE = "Carefully analyse the task description and examples of the task, and explain the task to give a clearer description. Do not explain each example, but rather capture the general trends. Also place special focus on the format of the input/output examples."

TASK = """
Expand All @@ -6,7 +7,19 @@
Task Examples: {examples}
"""

def construct_prompt_for_task_explanation(instruction, demonstrations):
def construct_prompt_for_task_explanation(instruction: str, demonstrations: str):
"""
Constructs prompt for task explanation.
This is useful for clarifying the requirements of a task, and providing a clearer description of the task.
Args:
instruction (str): The task instruction.
demonstrations (str): The task demonstrations.
Returns:
str: The constructed prompt.
"""
task = TASK.format(task_description=instruction, examples=demonstrations)
prompt = "\n--------\n".join([METAPROMPT_BASE, task])
return prompt
77 changes: 49 additions & 28 deletions prompt2model/dataset_transformer/prompt_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,36 +34,48 @@ class PromptBasedDatasetTransformer(DatasetTransformer):
"""Transform data based on a transform prompt."""

def __init__(
self,
num_points_to_transform: int ,
max_allowed_failed_transforms: int,
plan_prompt_fn: Callable[
[str, list[dict], str], str
] = construct_prompt_for_plan,
transform_prompt_fn: Callable[
[str, dict, str, str], str
] = construct_prompt_for_transform_data,

):
"""Initialize the class."""
self.plan_prompt_fn = plan_prompt_fn
self.transform_prompt_fn = transform_prompt_fn
self.plan: str = ""
self.num_points_to_transform = num_points_to_transform
self.curr_failed_transforms = 0
self.max_allowed_failed_transforms = max_allowed_failed_transforms


def generate_task_explanation(self, prompt_spec) -> str:
self,
num_points_to_transform: int ,
max_allowed_failed_transforms: int,
plan_prompt_fn: Callable[
[str, list[dict], str], str
] = construct_prompt_for_plan,
transform_prompt_fn: Callable[
[str, dict, str, str], str
] = construct_prompt_for_transform_data,

):
"""
Initializes an instance of the PromptBasedDatasetTransformer class.
Args:
num_points_to_transform (int): The number of points to transform.
max_allowed_failed_transforms (int): The maximum number of failed transforms allowed.
plan_prompt_fn (Callable[[str, list[dict], str], str], optional): The function to construct the prompt for plan data. Defaults to construct_prompt_for_plan.
transform_prompt_fn (Callable[[str, dict, str, str], str], optional): The function to construct the prompt for transform data. Defaults to construct_prompt_for_transform_data.
"""

self.plan_prompt_fn = plan_prompt_fn
self.transform_prompt_fn = transform_prompt_fn
self.plan: str = ""
self.num_points_to_transform = num_points_to_transform
self.curr_failed_transforms = 0
self.max_allowed_failed_transforms = max_allowed_failed_transforms


def generate_task_explanation(self, prompt_spec: PromptSpec) -> str:
""" Generate task explanation"""
task_explanation_prompt = construct_prompt_for_task_explanation(prompt_spec.instruction, prompt_spec.examples)
return make_single_api_request(task_explanation_prompt, max_api_calls=10)

def generate_plan(self, task_explanation, dataset, prompt_spec) -> str:
def generate_plan(self, task_explanation:str, dataset:datasets.Dataset, prompt_spec: PromptSpec) -> str:
""" Generate plan for the task"""
plan_prompt = self.plan_prompt_fn(task_explanation, dataset, prompt_spec.examples)
return make_single_api_request(plan_prompt, max_api_calls=100)
return make_single_api_request(plan_prompt, max_api_calls=10)


def generate_transform_prompts(self, task_explanation:str, dataset:datasets.Dataset, prompt_spec:PromptSpec,) -> List[str]:
""" Get transform prompts for each row in the dataset."""
transform_prompts = []
for i in range(min(self.num_points_to_transform, len(dataset))):
row = dataset[i]
Expand All @@ -72,8 +84,8 @@ def generate_transform_prompts(self, task_explanation:str, dataset:datasets.Dat
return transform_prompts


def generate_responses(self, transform_prompts_batch) -> List[str]:

def generate_responses(self, transform_prompts_batch:List[str]) -> List[str]:
""" Generate responses for the transform prompts."""
async def generate_responses_async(transform_prompts):
"""
Generate responses asynchronously using the specified model.
Expand All @@ -86,7 +98,6 @@ async def generate_responses_async(transform_prompts):
)
return responses


try:
loop = asyncio.get_event_loop()
responses = loop.run_until_complete(generate_responses_async(transform_prompts_batch))
Expand All @@ -96,7 +107,7 @@ async def generate_responses_async(transform_prompts):
return responses


def process_responses(self, responses:list, prompt_spec) -> Tuple[List[str], List[str]]:
def process_responses(self, responses:list, prompt_spec: PromptSpec) -> Tuple[List[str], List[str]]:
"""
Process the responses received from the API. Also write the current set of inputs and outputs to a file.
Expand Down Expand Up @@ -143,7 +154,17 @@ def process_responses(self, responses:list, prompt_spec) -> Tuple[List[str], Lis

return inputs, outputs

def transform_data(self, prompt_spec, dataset: datasets.Dataset) -> datasets.DatasetDict:
def transform_data(self, prompt_spec:PromptSpec, dataset: datasets.Dataset) -> tuple[list[str], list[str]]:
"""
Transforms the given dataset based on the provided prompt specification.
Args:
prompt_spec (PromptSpec): The prompt specification object that defines the transformation rules.
dataset (datasets.Dataset): The dataset to be transformed.
Returns:
A tuple containing two lists: inputs and outputs.
"""
task_explanation = self.generate_task_explanation(prompt_spec)
self.plan = self.generate_plan(task_explanation, dataset, prompt_spec)
logger.info(f"Plan created. Plan: {self.plan}")
Expand Down
4 changes: 2 additions & 2 deletions prompt2model/dataset_transformer/prompt_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@ def truncate_row(example_row: dict, max_length=200) -> str:


def construct_prompt_for_plan(
task_description: str, dataset: list[dict], example: str, num_rows: int = 5
task_description: str,example: str, dataset: list[dict], num_rows: int = 5
) -> str:
"""Construct prompt for plan.
Expand Down Expand Up @@ -439,7 +439,7 @@ def construct_prompt_for_plan(


def construct_prompt_for_transform_data(
task_description: str, dataset_row: dict, plan: str, example: str
task_description: str, dataset_row: str, plan: str, example: str
) -> str:
"""Construct prompt for dataset transformation.
Expand Down
2 changes: 1 addition & 1 deletion prompt2model/prompt_parser/instr_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def parse_from_prompt(self, prompt: str) -> None:
self._instruction = extraction["Instruction"]
self._examples = extraction["Demonstrations"]

def set_instruction_and_examples(self, instruction="", examples=""):
def set_instruction_and_examples(self, instruction:str="", examples:str="")->None:
"""Set the instruction and examples directly."""
self._instruction = instruction
self._examples = examples
1 change: 1 addition & 0 deletions prompt2model/utils/dataset_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Util functions for datasets."""


import datasets
import requests

Expand Down
5 changes: 3 additions & 2 deletions prompt2model/utils/parse_responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ def find_and_parse_json(
return final_response


def find_rightmost_brackets(text):
def find_rightmost_brackets(text:str)->str|None:
"""Find the rightmost complete set of brackets in a string."""
stack = []
for i, char in enumerate(reversed(text)):
if char == '}':
Expand All @@ -71,7 +72,7 @@ def find_rightmost_brackets(text):

import re

def parse_dataset_config_responses(response):
def parse_dataset_config_responses(response:dict)->str:
"""
Parses the response to extract relevant information from dataset/configuration.
Expand Down
1 change: 1 addition & 0 deletions scripts/dataset_index/retrieve_dataset_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import datasets
import requests


def parse_arguments():
"""Parse command line arguments for the script.
Expand Down

0 comments on commit 415aa21

Please sign in to comment.