Skip to content

Commit

Permalink
Add data transformation capability to dataset retrieval step (#385)
Browse files Browse the repository at this point in the history
* add dataset transformation

* add tests

* make PR revisions

* merge auto into normal demo

* merge reranking and transformation flows

* update test

* verbose line print in demo

* minor grammar changes

---------

Co-authored-by: Graham Neubig <neubig@gmail.com>
  • Loading branch information
saum7800 and neubig authored Jan 15, 2024
1 parent f2eabc1 commit 67673b1
Show file tree
Hide file tree
Showing 10 changed files with 555 additions and 29 deletions.
2 changes: 2 additions & 0 deletions prompt2model/dataset_processor/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,8 @@ def _split_dataset_into_dataset_dict(
datasets.DatasetDict: A dictionary containing the `train`,
`val`, and `test` datasets.
"""
if "train" in dataset:
dataset = dataset["train"]
num_of_examples = len(dataset)
train_num = int(train_proportion * num_of_examples)
val_num = int(val_proportion * num_of_examples)
Expand Down
110 changes: 83 additions & 27 deletions prompt2model/dataset_retriever/description_dataset_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from __future__ import annotations # noqa FI58

import json
import logging
import os
import random
import urllib.request
Expand All @@ -18,13 +17,14 @@
from prompt2model.dataset_retriever.reranking_prompt import (
construct_prompt_for_dataset_reranking,
)
from prompt2model.dataset_transformer.prompt_based import PromptBasedDatasetTransformer
from prompt2model.prompt_parser import PromptSpec
from prompt2model.utils import encode_text, retrieve_objects
from prompt2model.utils import encode_text, get_formatted_logger, retrieve_objects
from prompt2model.utils.dataset_utils import get_dataset_size
from prompt2model.utils.parse_responses import parse_prompt_to_fields

datasets.utils.logging.disable_progress_bar()
logger = logging.getLogger(__name__)
logger = get_formatted_logger("DescriptionDatasetRetriever")


class DescriptionDatasetRetriever(DatasetRetriever):
Expand Down Expand Up @@ -206,6 +206,7 @@ def get_all_dataset_infos(self, dataset_list: list[str]) -> dict:
The keys are dataset names and the values are dictionaries
with dataset information.
"""
dataset_info_dict = {}
for dataset_name in dataset_list:
if dataset_name not in self.reranking_datasets_infos:
continue
Expand All @@ -219,11 +220,8 @@ def get_all_dataset_infos(self, dataset_list: list[str]) -> dict:
curr_dataset["configs"] = dict(
random.sample(list(curr_dataset["configs"].items()), 5)
)
dataset_info_dict = {
dataset_name: self.reranking_datasets_infos[dataset_name]
for dataset_name in dataset_list
if dataset_name in self.reranking_datasets_infos
}
dataset_info_dict[dataset_name] = curr_dataset

return dataset_info_dict

@staticmethod
Expand Down Expand Up @@ -312,11 +310,12 @@ def canonicalize_dataset_by_cli(
)
self._print_divider()

dataset_info = self.get_all_dataset_infos([dataset_name])[dataset_name]
dataset_info = self.get_all_dataset_infos([dataset_name])
if len(dataset_info.keys()) == 0:
return None
dataset_info = dataset_info[dataset_name]["configs"][chosen_config]
if dataset_info is None:
return None
dataset_info = dataset_info["configs"][chosen_config]
assert dataset_info is not None
try:
input_columns, output_column = self.automatic_column_selection(
prompt_spec.instruction,
Expand Down Expand Up @@ -437,7 +436,11 @@ def rerank_datasets(self, dataset_list: list[str], prompt_spec: PromptSpec):
return dataset_info_dict[dataset_name]["configs"][config_name]

def canonicalize_dataset_automatically(
self, top_dataset_info: dict, task_instruction: str
self,
top_dataset_info: dict,
prompt_spec: PromptSpec,
auto_transform_data: bool = False,
num_points_to_transform: int = 10,
):
"""Automatically canonicalize dataset (instead of cli).
Expand All @@ -446,18 +449,29 @@ def canonicalize_dataset_automatically(
the top dataset information exists. If so, it proceeds to automatically
select the input and output columns based on the task instruction. The
dataset is then loaded, flattened, and renamed according to the columns
mapping. Finally, the dataset is canonicalized using the selected columns.
mapping. If auto_transform_data is true, num_points_to_transform points
from the dataset are transformed by an LLM to desired format according
to the prompt_spec, and transformed dataset is returned. If
auto_transform_data is false, the dataset is canonicalized using the
selected columns.
Args:
top_dataset_info: Contains info about the top-ranked dataset.
task_instruction: A string representing the instruction for the task,
used to guide column selection.
prompt_spec: prompt object storing the original task and examples.
auto_transform_data: Specifies whether a dataset is to be
transformed. Samples from the original dataset will be transformed
by an LLM to match a desired format as specified by prompt_spec.
num_points_to_transform: Number of data points you wish to
transform. Number must be greater than zero. If number is greater
than size of dataset, whole dataset will be transformed. ignored
if data_transform is False.
Returns:
The canonicalized dataset, or None if the dataset is invalid or
if column selection fails, or if any other error occurs
during the process.
"""
task_instruction = prompt_spec.instruction
if top_dataset_info is None:
logger.warning("None of the retrieved datasets were relevant.")
return None
Expand All @@ -472,34 +486,76 @@ def canonicalize_dataset_automatically(
except Exception as e:
logger.warning("Column selection failed: ", e)
return None
full_dataset = datasets.load_dataset(
top_dataset_info["dataset_name"], top_dataset_info["config_name"]
).flatten()
full_dataset = full_dataset.rename_columns(top_dataset_info["columns_mapping"])
canonicalized_dataset = self.canonicalize_dataset_using_columns(
full_dataset, input_columns, output_column
logger.info("Column selection completed")
full_dataset = (
datasets.load_dataset(
top_dataset_info["dataset_name"], top_dataset_info["config_name"]
)
.shuffle()
.flatten()
)
logger.info(f"Using dataset {top_dataset_info['dataset_name']}")
full_dataset = full_dataset.rename_columns(top_dataset_info["columns_mapping"])
logger.info("Dataset loaded")

if auto_transform_data:
# remove columns not selected by automatic column selection
full_dataset = full_dataset.remove_columns(
[
col_name
for col_name in full_dataset["train"].column_names
if col_name not in input_columns + [output_column]
]
)
logger.info("Unnecessary columns removed")

return canonicalized_dataset
dataset_transformer = PromptBasedDatasetTransformer()
canonicalized_dataset = dataset_transformer.transform_data(
prompt_spec=prompt_spec,
dataset=full_dataset["train"],
num_points_to_transform=num_points_to_transform,
)
logger.info("Data transformation completed")

example_rows = json.dumps(canonicalized_dataset["train"][0], indent=4)

logger.info(f"Transformed dataset. Example row:\n{example_rows}\n")

return canonicalized_dataset
else:
canonicalized_dataset = self.canonicalize_dataset_using_columns(
full_dataset, input_columns, output_column
)
logger.info(
f"No transformation. Using dataset {top_dataset_info['dataset_name']}"
) # noqa E501
return canonicalized_dataset

def retrieve_dataset_dict(
    self,
    prompt_spec: PromptSpec,
    auto_transform_data: bool = False,
    num_points_to_transform: int = 10,
) -> datasets.DatasetDict | None:
    """Select a dataset from a prompt using a dual-encoder retriever.

    Retrieves candidate datasets, reranks them with an LLM, and then
    canonicalizes (and optionally transforms) the top-ranked dataset.

    Args:
        prompt_spec: Prompt object storing the original task and examples.
        auto_transform_data: Specifies whether a dataset is to be
            transformed. Samples from the original dataset will be
            transformed by an LLM to match a desired format as specified
            by prompt_spec.
        num_points_to_transform: Number of data points you wish to
            transform. Number must be greater than zero. If number is
            greater than size of dataset, whole dataset will be
            transformed. Ignored if auto_transform_data is False.

    Return:
        The most relevant dataset, canonicalized;
        or None if there are no relevant datasets.
    """
    sorted_list = self.retrieve_top_datasets(prompt_spec)
    logger.info(f"Top datasets retrieved. Top datasets: {sorted_list}")
    top_dataset_info = self.rerank_datasets(sorted_list, prompt_spec)
    # Use the module logger for progress reporting; the bare debug
    # print() that used to sit here duplicated the log line below.
    logger.info(f"Rerank completed. Top dataset info: {top_dataset_info}")
    return self.canonicalize_dataset_automatically(
        top_dataset_info, prompt_spec, auto_transform_data, num_points_to_transform
    )
8 changes: 8 additions & 0 deletions prompt2model/dataset_transformer/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
"""Import DatasetTransformer classes."""
from prompt2model.dataset_transformer.base import DatasetTransformer
from prompt2model.dataset_transformer.prompt_based import PromptBasedDatasetTransformer

__all__ = (
    "PromptBasedDatasetTransformer",
    "DatasetTransformer",
)
34 changes: 34 additions & 0 deletions prompt2model/dataset_transformer/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
"""An interface for dataset transformation."""

from __future__ import annotations # noqa FI58

from abc import ABC, abstractmethod

import datasets

from prompt2model.prompt_parser import PromptSpec


class DatasetTransformer(ABC):
    """A class for transforming a given dataset to a desired format."""

    @abstractmethod
    def transform_data(
        self,
        prompt_spec: PromptSpec,
        dataset: datasets.Dataset,
        num_points_to_transform: int,
    ) -> datasets.DatasetDict:
        """Transform a split of data.

        Args:
            prompt_spec: A prompt spec (containing a system description)
                that describes the desired input/output format.
            dataset: A dataset split to draw rows from.
            num_points_to_transform: Number of data points you wish to
                transform. Number must be greater than zero. If number is
                greater than the size of the dataset, the whole dataset
                will be transformed.

        Returns:
            A DatasetDict holding the transformed data (the concrete
            PromptBasedDatasetTransformer returns a single "train" split).
        """
134 changes: 134 additions & 0 deletions prompt2model/dataset_transformer/prompt_based.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
"""A simple dataset transformer that uses a plan prompt and transform prompt."""
from __future__ import annotations

import asyncio
from collections.abc import Callable

import datasets

from prompt2model.dataset_transformer.base import DatasetTransformer
from prompt2model.dataset_transformer.prompt_template import (
construct_prompt_for_plan,
construct_prompt_for_transform_data,
)
from prompt2model.prompt_parser import PromptSpec
from prompt2model.utils import (
API_ERRORS,
api_tools,
get_formatted_logger,
handle_api_error,
)
from prompt2model.utils.parse_responses import make_single_api_request, parse_json

logger = get_formatted_logger("DatasetTransformer")


class PromptBasedDatasetTransformer(DatasetTransformer):
    """Transform data based on a transform prompt.

    Makes one LLM call to produce a transformation "plan", then a batched
    LLM request to rewrite up to ``num_points_to_transform`` dataset rows
    into ``input_col`` / ``output_col`` pairs matching the prompt spec.
    """

    def __init__(
        self,
        plan_prompt_fn: Callable[
            [str, list[dict], str], str
        ] = construct_prompt_for_plan,
        transform_prompt_fn: Callable[
            [str, dict, str, str], str
        ] = construct_prompt_for_transform_data,
    ):
        """Initialize the class.

        Args:
            plan_prompt_fn: Builds the planning prompt from the task
                instruction, the dataset, and the demonstration examples.
            transform_prompt_fn: Builds the per-row transform prompt from
                the instruction, one dataset row, the examples, and the plan.
        """
        self.plan_prompt_fn = plan_prompt_fn
        self.transform_prompt_fn = transform_prompt_fn
        # Most recent plan produced by transform_data; kept for inspection.
        self.plan: str = ""

    def make_dataset_from_samples(
        self,
        inputs: list[str],
        outputs: list[str],
    ) -> datasets.DatasetDict:
        """Given a list of inputs and outputs, make a dataset.

        This function takes in inputs and outputs, both as list of strings,
        and returns a DatasetDict object with a single split, "train". It has
        two columns, "input_col" and "output_col".

        Args:
            inputs: A list of inputs, each input is a string.
            outputs: A list of outputs, each output is a string.

        Returns:
            A DatasetDict object with a single split, "train". It has two
            columns, "input_col" and "output_col".

        Raises:
            ValueError: If the lists are empty or of unequal length.
        """
        if len(inputs) <= 0 or len(inputs) != len(outputs):
            raise ValueError("Length of inputs and outputs must be >0 and equal.")

        return datasets.DatasetDict(
            {
                "train": datasets.Dataset.from_dict(
                    {"input_col": inputs, "output_col": outputs}
                )
            }
        )

    def transform_data(
        self,
        prompt_spec: PromptSpec,
        dataset: datasets.Dataset,
        num_points_to_transform: int,
    ) -> datasets.DatasetDict:
        """Transform the dataset according to the prompt_spec and dataset.

        Args:
            prompt_spec: Prompt spec describing the desired data format.
            dataset: Source dataset split whose rows are transformed.
            num_points_to_transform: Maximum number of rows to transform;
                capped at the dataset size.

        Returns:
            A DatasetDict with one "train" split containing the rows that
            were successfully transformed and parsed.
        """
        plan_prompt = self.plan_prompt_fn(
            prompt_spec.instruction,
            dataset,
            prompt_spec.examples,
        )
        self.plan = make_single_api_request(plan_prompt)

        logger.info(f"Plan created. Plan: {self.plan}")

        # Build one transform prompt per row, up to num_points_to_transform.
        max_len = min(num_points_to_transform, len(dataset))
        transform_prompts = []
        for row in dataset:
            transform_prompt = self.transform_prompt_fn(
                prompt_spec.instruction,
                row,
                prompt_spec.examples,
                self.plan,
            )
            transform_prompts.append(transform_prompt)
            if len(transform_prompts) >= max_len:
                break

        async def generate_responses(transform_prompts):
            responses = await api_tools.default_api_agent.generate_batch_completion(
                transform_prompts,
                temperature=0,
                responses_per_request=1,
                requests_per_minute=15,
            )
            return responses

        # BUG FIX: `responses` was previously left unbound when the API call
        # raised (handle_api_error may return), causing a NameError below.
        responses = []
        try:
            loop = asyncio.get_event_loop()
            responses = loop.run_until_complete(generate_responses(transform_prompts))
        except API_ERRORS as e:
            handle_api_error(e)

        inputs = []
        outputs = []
        for response in responses:
            try:
                extraction = parse_json(response, ["input", "output"], [])
                if extraction is not None:
                    inputs.append(str(extraction["input"]))
                    outputs.append(str(extraction["output"]))
            except Exception as e:
                # A malformed response only drops that row, never the batch.
                logger.error(f"Error extracting from response: {response}\nError: {e}")
                continue

        # BUG FIX: the requested length was hard-coded as "27,734";
        # report the actual number of rows requested for transformation.
        logger.info(f"Requested length: {max_len}\nActual length: {len(inputs)}\n")

        return self.make_dataset_from_samples(inputs, outputs)
Loading

0 comments on commit 67673b1

Please sign in to comment.