Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improving Transform and Rerank Module #396

Merged
merged 47 commits into from
Apr 19, 2024
Merged
Show file tree
Hide file tree
Changes from 39 commits
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
509f5d9
add top_dataset_info_return
saum7800 Jan 18, 2024
90a4a04
make reranking call from 16k context
saum7800 Jan 18, 2024
7193ec0
add incontext and COT to dataset transformation
saum7800 Jan 19, 2024
1ae1b0d
log selected columns
saum7800 Jan 22, 2024
71e129b
add peft training
saum7800 Jan 23, 2024
faa5b7b
remove import
saum7800 Jan 24, 2024
c0e6ef6
two returns
saum7800 Jan 24, 2024
9e40238
create promptspec
saum7800 Jan 24, 2024
f4298b8
add logging
saum7800 Jan 24, 2024
9c9fa75
add logging
saum7800 Jan 24, 2024
665b546
add logging
saum7800 Jan 24, 2024
e18989a
pass text
saum7800 Jan 24, 2024
5d9761e
change params
saum7800 Jan 24, 2024
b5f094f
change params
saum7800 Jan 24, 2024
1e169eb
change params
saum7800 Jan 24, 2024
bf72344
change params
saum7800 Jan 24, 2024
758a6da
change paths
saum7800 Jan 24, 2024
7bfa104
minor changes
saum7800 Jan 24, 2024
9e9f33e
remove arg
saum7800 Jan 25, 2024
a12e062
clear cache
saum7800 Jan 25, 2024
d9dc1ab
add wandb changes and minor changes
saum7800 Jan 26, 2024
63a4a75
change eval steps
saum7800 Jan 27, 2024
931edf0
change eval steps
saum7800 Jan 27, 2024
a2c2cc6
modify qlora params
saum7800 Jan 27, 2024
41d6fd0
make lr changes
saum7800 Jan 27, 2024
d32dc22
curr changes
ritugala Mar 9, 2024
5a93ae7
delete saumya changes
ritugala Mar 9, 2024
921a98f
initial changes
ritugala Mar 25, 2024
88130ba
merging changes
ritugala Mar 25, 2024
379a503
first pass refactoring
ritugala Mar 26, 2024
acd6e7a
minor bug fixing
ritugala Mar 30, 2024
415aa21
added docstrings
ritugala Mar 30, 2024
0d009b2
completed precommit hooks
ritugala Apr 2, 2024
60051a7
minor refactoring fixes
ritugala Apr 2, 2024
e0c0d0f
fixing pytests
ritugala Apr 2, 2024
6558279
updated p2m_demo.py
ritugala Apr 2, 2024
15a61eb
fix linting
ritugala Apr 5, 2024
775da63
PR changes
ritugala Apr 11, 2024
ad0d502
comment change
ritugala Apr 11, 2024
4b6427b
merged with main
ritugala Apr 18, 2024
95dcea0
PR changes and fixed test
ritugala Apr 18, 2024
b92dca2
fixed linting
ritugala Apr 18, 2024
423510e
Update prompt2model/dataset_retriever/description_dataset_retriever.py
ritugala Apr 18, 2024
92f127a
Update prompt2model/utils/parse_responses.py
ritugala Apr 18, 2024
547d4a3
Update prompt2model/dataset_retriever/description_dataset_retriever.py
ritugala Apr 18, 2024
c2ba8b3
Update prompt2model/dataset_retriever/description_dataset_retriever.py
ritugala Apr 18, 2024
0c81d18
fixed minor linting
ritugala Apr 18, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@ huggingface_data/huggingface_datasets/huggingface_datasets_datafinder_index
huggingface_data/huggingface_datasets/reranking_dataset_index.json
huggingface_data/huggingface_models/
retrieved_dataset_dict/
result/
checkpoint/
status.yaml

# Outputs generated by the colab demo
trained_model/
trained_tokenizer/
9 changes: 5 additions & 4 deletions examples/create_transform_data_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,13 @@

# run this pipeline to retrieve relevant datasets, rerank them,
# and transform them based on the prompt
retriever = DescriptionDatasetRetriever()
num_points_to_transform = 20
total_num_points_to_transform = 20
retriever = DescriptionDatasetRetriever(
auto_transform_data=True,
total_num_points_to_transform=total_num_points_to_transform,
)
retrieved_dataset_dict = retriever.retrieve_dataset_dict(
prompt_spec,
auto_transform_data=True,
num_points_to_transform=num_points_to_transform,
)

# save the final dataset to disk
Expand Down

This file was deleted.

This file was deleted.

Binary file not shown.
301 changes: 237 additions & 64 deletions prompt2model/dataset_retriever/description_dataset_retriever.py

Large diffs are not rendered by default.

205 changes: 78 additions & 127 deletions prompt2model/dataset_retriever/reranking_prompt.py
ritugala marked this conversation as resolved.
Show resolved Hide resolved

Large diffs are not rendered by default.

26 changes: 26 additions & 0 deletions prompt2model/dataset_retriever/task_expansion_prompt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
"""This module contains the functions to construct the prompt for task expansion."""
METAPROMPT_BASE = "Carefully analyse the task description and examples of the task, and explain the task to give a clearer description. Do not explain each example, but rather capture the general trends. Also place special focus on the format of the input/output examples." # noqa: E501

TASK = """
Task Description: {task_description}

Task Examples: {examples}
"""


def construct_prompt_for_task_explanation(instruction: str, demonstrations: str):
"""Constructs prompt for task explanation.

This is useful for clarifying the requirements of a task,
and providing a clearer description of the task.

Args:
instruction (str): The task instruction.
demonstrations (str): The task demonstrations.

Returns:
str: The constructed prompt.
"""
task = TASK.format(task_description=instruction, examples=demonstrations)
prompt = "\n--------\n".join([METAPROMPT_BASE, task])
return prompt
1 change: 0 additions & 1 deletion prompt2model/dataset_transformer/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ def transform_data(
self,
prompt_spec: PromptSpec,
dataset: datasets.Dataset,
num_points_to_transform: int,
) -> datasets.Dataset:
"""Transform a split of data.

Expand Down
219 changes: 141 additions & 78 deletions prompt2model/dataset_transformer/prompt_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@

import datasets

from prompt2model.dataset_retriever.task_expansion_prompt import (
construct_prompt_for_task_explanation,
)
from prompt2model.dataset_transformer.base import DatasetTransformer
from prompt2model.dataset_transformer.prompt_template import (
construct_prompt_for_plan,
Expand All @@ -31,99 +34,87 @@ class PromptBasedDatasetTransformer(DatasetTransformer):

def __init__(
self,
num_points_to_transform: int = 10,
max_allowed_failed_transforms: int = 3,
plan_prompt_fn: Callable[
[str, str, list[dict], int], str
[str, str, list[dict]], str
] = construct_prompt_for_plan,
transform_prompt_fn: Callable[
[str, str, str, dict], str
[str, str, str, str], str
] = construct_prompt_for_transform_data,
):
"""Initialize the class.
"""Initializes an instance of the PromptBasedDatasetTransformer class.

Args:
plan_prompt_fn: A function that takes in a description of the target task,
example of the target task,
list of dictionaries where each dictionary is a row from a potentially
relevant dataset,
and the number of rows to use from this potentially relevant dataset,
and returns a plan prompt.

transform_prompt_fn: A function that takes in a description of the target
task, an example of the target task,
plan for dataset transformation,
and the row from a potentially relevant dataset to be transformed.
num_points_to_transform: The number of points to transform.
max_allowed_failed_transforms: The maximum number of
failed transforms allowed.
plan_prompt_fn: The function to construct the prompt for plan
transform_prompt_fn: The function to construct the prompt
ritugala marked this conversation as resolved.
Show resolved Hide resolved
for transform data.
"""
self.plan_prompt_fn = plan_prompt_fn
self.transform_prompt_fn = transform_prompt_fn
self.plan: str = ""

def make_dataset_from_samples(
self,
inputs: list[str],
outputs: list[str],
) -> datasets.DatasetDict:
"""Given a list of inputs and outputs, make a dataset.

This function takes in inputs and outputs, both as list of strings,
and returns a DatasetDict object with a single split, "train". It has
two columns, "input_col" and "output_col".


Args:
inputs: A list of inputs, each input is a string.
outputs: A list of outputs, each output is a string.

Returns:
A DatasetDict object with a single split, "train". It has two
columns, "input_col" and "output_col".
"""
if len(inputs) <= 0 or len(inputs) != len(outputs):
raise ValueError("Length of inputs and outputs must be >0 and equal.")

dataset_dict = {}
dataset_dict["train"] = datasets.Dataset.from_dict(
{"input_col": inputs, "output_col": outputs}
self.num_points_to_transform = num_points_to_transform
self.curr_failed_transforms = 0
self.max_allowed_failed_transforms = max_allowed_failed_transforms

def generate_task_explanation(self, prompt_spec: PromptSpec) -> str:
"""Generate task explanation."""
task_explanation_prompt = construct_prompt_for_task_explanation(
prompt_spec.instruction, prompt_spec.examples
)
return datasets.DatasetDict(dataset_dict)
return make_single_api_request(task_explanation_prompt, max_api_calls=10)
ritugala marked this conversation as resolved.
Show resolved Hide resolved

def transform_data(
self,
prompt_spec: PromptSpec,
dataset: datasets.Dataset,
num_points_to_transform: int,
) -> datasets.DatasetDict:
"""Transform the dataset according to the prompt_spec and dataset."""
def generate_plan(
self, task_explanation: str, dataset: datasets.Dataset, prompt_spec: PromptSpec
) -> str:
"""Generate plan for the task."""
plan_prompt = self.plan_prompt_fn(
prompt_spec.instruction,
prompt_spec.examples,
dataset,
min(5, len(dataset)),
task_explanation, prompt_spec.examples, dataset
)
self.plan = make_single_api_request(plan_prompt)

logger.info(f"Plan created. Plan: {self.plan}")

inputs = []
outputs = []
return make_single_api_request(plan_prompt, max_api_calls=10)
ritugala marked this conversation as resolved.
Show resolved Hide resolved

max_len = min(num_points_to_transform, len(dataset))
len_count = 0
def generate_transform_prompts(
self,
task_explanation: str,
dataset: datasets.Dataset,
prompt_spec: PromptSpec,
) -> list[str]:
"""Get transform prompts for each row in the dataset."""
transform_prompts = []
for row in dataset:
for i in range(min(self.num_points_to_transform, len(dataset))):
row = dataset[i]
transform_prompt = self.transform_prompt_fn(
prompt_spec.instruction,
prompt_spec.examples,
self.plan,
row,
task_explanation, row, self.plan, prompt_spec.examples
)
transform_prompts.append(transform_prompt)
return transform_prompts

len_count += 1
if len_count >= max_len:
break
def generate_responses(
self, transform_prompts_batch: list[str], model_name="gpt-3.5-turbo"
) -> list[str]:
"""Generate responses for the given transform prompts.

Args:
transform_prompts_batch (list[str]): A list of transform prompts.
ritugala marked this conversation as resolved.
Show resolved Hide resolved
model_name (str, optional): The name of the model to use. Defaults to
"gpt-3.5-turbo" to save costs.

async def generate_responses(transform_prompts):
responses = await api_tools.default_api_agent.generate_batch_completion(
Returns:
list[str]: A list of generated responses.
ritugala marked this conversation as resolved.
Show resolved Hide resolved

Raises:
API_ERRORS: If there is an error with the API.

"""

async def generate_responses_async(transform_prompts):
"""Generate responses asynchronously using the specified model."""
responses = await api_tools.APIAgent(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we expose APIAgent as a function parameter with this APIAgent being the default choice? This is currently too restrictive.

Copy link
Collaborator Author

@ritugala ritugala Apr 11, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@viswavi let me know if the generate_responses() function should have a parameter of
model_name passed in (current change) or
is it better to pass in the entire api_agent?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that either is probably fine, so I'm ok with the current design

model_name=model_name
).generate_batch_completion(
transform_prompts,
temperature=0,
responses_per_request=1,
Expand All @@ -133,20 +124,92 @@ async def generate_responses(transform_prompts):

try:
loop = asyncio.get_event_loop()
responses = loop.run_until_complete(generate_responses(transform_prompts))
responses = loop.run_until_complete(
generate_responses_async(transform_prompts_batch)
)
except API_ERRORS as e:
handle_api_error(e)
# TODO: What to return here?
ritugala marked this conversation as resolved.
Show resolved Hide resolved
return responses

def process_responses(
self, responses: list, prompt_spec: PromptSpec
) -> tuple[list[str], list[str]]:
"""Process the responses received from the API.

Args:
responses: A list of response strings from the API.
prompt_spec: The PromptSpec object containing the instruction and examples.

Returns:
A tuple containing two lists: inputs and outputs.
- inputs: A list of transformed input strings.
- outputs: A list of transformed output strings.
"""
inputs, outputs = [], []
show_sample_flag = True
ritugala marked this conversation as resolved.
Show resolved Hide resolved
for response in responses:
try:
extraction = find_and_parse_json(response, ["input", "output"], [])
if extraction is not None:
inputs.append(str(extraction["input"]))
outputs.append(str(extraction["output"]))
if extraction["input"] is None or extraction["output"] is None:
raise ValueError("Input or output is None")
input = str(extraction["input"]).strip()
output = str(extraction["output"]).strip()
if input in prompt_spec.examples:
raise ValueError("Repeated Task Examples from prompt")

inputs.append(input)
outputs.append(output)
if show_sample_flag:
logger.info(f"inputs\n{input}\n\nouputs\n{output}")
show_sample_flag = False

except Exception as e:
logger.error(f"Error extracting from response: {response}\nError: {e}")
continue
logger.error(f"Error extracting from response: {e}")
self.curr_failed_transforms += 1
if self.curr_failed_transforms > self.max_allowed_failed_transforms:
break

logger.info(f"Requested length: {max_len}\nActual length: {len(inputs)}\n")
return inputs, outputs

return self.make_dataset_from_samples(inputs, outputs)
def transform_data(
self, prompt_spec: PromptSpec, dataset: datasets.Dataset
) -> tuple[list[str], list[str]]:
"""Transforms the given dataset based on the provided prompt specification.

Args:
prompt_spec (PromptSpec): The prompt specification object that defines
the transformation rules.
dataset (datasets.Dataset): The dataset to be transformed.
ritugala marked this conversation as resolved.
Show resolved Hide resolved

Returns:
A tuple containing two lists: inputs and outputs.
"""
task_explanation = self.generate_task_explanation(prompt_spec)
self.plan = self.generate_plan(task_explanation, dataset, prompt_spec)
logger.info(f"Plan created. Plan: {self.plan}")

transform_prompts = self.generate_transform_prompts(
task_explanation, dataset, prompt_spec
)
inputs, outputs = [], []
for batch_indices in range(0, len(transform_prompts), 100):
transform_prompt_batch = transform_prompts[
batch_indices : batch_indices + 100
]
responses = self.generate_responses(transform_prompt_batch)
curr_inputs, curr_outputs = self.process_responses(responses, prompt_spec)
inputs.extend(curr_inputs)
outputs.extend(curr_outputs)
if self.curr_failed_transforms > self.max_allowed_failed_transforms:
ritugala marked this conversation as resolved.
Show resolved Hide resolved
logger.error(
f"Exceeded max allowed failed transforms: {self.curr_failed_transforms}" # noqa: E501
)
self.max_allowed_failed_transforms = 0
break

logger.info(
f"Requested length: {self.num_points_to_transform}\nActual length: {len(inputs)}\n" # noqa: E501
)
return inputs, outputs
Loading
Loading