Skip to content

Commit

Permalink
Add mistral fine-tuning and examples (#395)
Browse files Browse the repository at this point in the history
* add mistral, incontext examples, and examples

* minor error

* remove unnecessary files

* changes for clarity

* change any to dict

* add peft to reqs

* ignore BLK100 from flake, since Black is called anyways

* ignore W503

* remove test issue regarding peft

* formatting changes

* change QLora to QLoRA
  • Loading branch information
saum7800 authored Mar 11, 2024
1 parent 67673b1 commit 25e0a96
Show file tree
Hide file tree
Showing 13 changed files with 940 additions and 69 deletions.
3 changes: 2 additions & 1 deletion .flake8
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
[flake8]
max-line-length = 88
extend-ignore = E203,FI10,FI11,FI12,FI13,FI14,FI15,FI16,FI17,FI18
extend-ignore = E203,FI10,FI11,FI12,FI13,FI14,FI15,FI16,FI17,FI18,BLK100,W503
per-file-ignores = prompt2model/dataset_transformer/prompt_template.py:E501
51 changes: 51 additions & 0 deletions examples/create_synthetic_data_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""Example to demonstrate how to create synthetic data based on prompt."""

import prompt2model.utils.api_tools as api_tools
from prompt2model.dataset_generator.base import DatasetSplit
from prompt2model.dataset_generator.prompt_based import PromptBasedDatasetGenerator
from prompt2model.prompt_parser import PromptBasedInstructionParser, TaskType
from prompt2model.utils.api_tools import APIAgent

if __name__ == "__main__":
# set API keys and create default API agent.
api_tools.default_api_agent = APIAgent(
model_name="gpt-3.5-turbo-16k", max_tokens=8000
)

# create prompt based on which transform data will be created
prompt = """
Your task is to generate an answer to a natural question. In this task, the input is a string that consists of both a question and a context passage. The context is a descriptive passage related to the question and contains the answer. And the question can range from Math, Cultural, Social, Geometry, Biology, History, Sports, Technology, Science, and so on.
Here are examples with input questions and context passages, along with their expected outputs:
input="Question: What city did Super Bowl 50 take place in? Context: Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers 24–10 to earn their third Super Bowl title. The game was played on February 7, 2016, at Levi's Stadium in the San Francisco Bay Area at Santa Clara, California. As this was the 50th Super Bowl, the league emphasized the "golden anniversary" with various gold-themed initiatives, as well as temporarily suspending the tradition of naming each Super Bowl game with Roman numerals (under which the game would have been known as "Super Bowl L"), so that the logo could prominently feature the Arabic numerals 50."
output="Santa Clara"
input="Question: What river runs through Warsaw? Context: Warsaw (Polish: Warszawa [varˈʂava] ( listen); see also other names) is the capital and largest city of Poland. It stands on the Vistula River in east-central Poland, roughly 260 kilometres (160 mi) from the Baltic Sea and 300 kilometres (190 mi) from the Carpathian Mountains. Its population is estimated at 1.740 million residents within a greater metropolitan area of 2.666 million residents, which makes Warsaw the 9th most-populous capital city in the European Union. The city limits cover 516.9 square kilometres (199.6 sq mi), while the metropolitan area covers 6,100.43 square kilometres (2,355.39 sq mi)."
output="Vistula River"
input="Question: The Ottoman empire controlled territory on three continents, Africa, Asia and which other? Context: The Ottoman Empire was an imperial state that lasted from 1299 to 1923. During the 16th and 17th centuries, in particular at the height of its power under the reign of Suleiman the Magnificent, the Ottoman Empire was a powerful multinational, multilingual empire controlling much of Southeast Europe, Western Asia, the Caucasus, North Africa, and the Horn of Africa. At the beginning of the 17th century the empire contained 32 provinces and numerous vassal states. Some of these were later absorbed into the empire, while others were granted various types of autonomy during the course of centuries."
output="Europe"
""" # noqa: E501
# parse the prompt to get the instruction and examples
prompt_spec = PromptBasedInstructionParser(task_type=TaskType.TEXT_GENERATION)
prompt_spec.parse_from_prompt(prompt)
print(f"Instruction: {prompt_spec.instruction}\nExamples: {prompt_spec.examples}")

# set hyperparams
initial_temperature = 0.4
max_temperature = 1.4
num_samples_total = 20

# run this pipeline to generate data synthetically based on prompt
unlimited_dataset_generator = PromptBasedDatasetGenerator(
initial_temperature=initial_temperature,
max_temperature=max_temperature,
responses_per_request=3,
)
generated_dataset = unlimited_dataset_generator.generate_dataset_split(
prompt_spec, num_samples_total, split=DatasetSplit.TRAIN
)

# save the final generated dataset to disk
generated_dataset.save_to_disk("demo_generated_dataset")
46 changes: 46 additions & 0 deletions examples/create_transform_data_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
"""Example of how to create transform data based on a prompt."""

import prompt2model.utils.api_tools as api_tools
from prompt2model.dataset_retriever import DescriptionDatasetRetriever
from prompt2model.prompt_parser import PromptBasedInstructionParser, TaskType
from prompt2model.utils.api_tools import APIAgent

if __name__ == "__main__":
# set API keys and create default API agent.
api_tools.default_api_agent = APIAgent(
model_name="gpt-3.5-turbo-16k", max_tokens=8000
)

# create prompt based on which transform data will be created
prompt = """
Your task is to generate an answer to a natural question. In this task, the input is a string that consists of both a question and a context passage. The context is a descriptive passage related to the question and contains the answer. And the question can range from Math, Cultural, Social, Geometry, Biology, History, Sports, Technology, Science, and so on.
Here are examples with input questions and context passages, along with their expected outputs:
input="Question: What city did Super Bowl 50 take place in? Context: Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers 24–10 to earn their third Super Bowl title. The game was played on February 7, 2016, at Levi's Stadium in the San Francisco Bay Area at Santa Clara, California. As this was the 50th Super Bowl, the league emphasized the "golden anniversary" with various gold-themed initiatives, as well as temporarily suspending the tradition of naming each Super Bowl game with Roman numerals (under which the game would have been known as "Super Bowl L"), so that the logo could prominently feature the Arabic numerals 50."
output="Santa Clara"
input="Question: What river runs through Warsaw? Context: Warsaw (Polish: Warszawa [varˈʂava] ( listen); see also other names) is the capital and largest city of Poland. It stands on the Vistula River in east-central Poland, roughly 260 kilometres (160 mi) from the Baltic Sea and 300 kilometres (190 mi) from the Carpathian Mountains. Its population is estimated at 1.740 million residents within a greater metropolitan area of 2.666 million residents, which makes Warsaw the 9th most-populous capital city in the European Union. The city limits cover 516.9 square kilometres (199.6 sq mi), while the metropolitan area covers 6,100.43 square kilometres (2,355.39 sq mi)."
output="Vistula River"
input="Question: The Ottoman empire controlled territory on three continents, Africa, Asia and which other? Context: The Ottoman Empire was an imperial state that lasted from 1299 to 1923. During the 16th and 17th centuries, in particular at the height of its power under the reign of Suleiman the Magnificent, the Ottoman Empire was a powerful multinational, multilingual empire controlling much of Southeast Europe, Western Asia, the Caucasus, North Africa, and the Horn of Africa. At the beginning of the 17th century the empire contained 32 provinces and numerous vassal states. Some of these were later absorbed into the empire, while others were granted various types of autonomy during the course of centuries."
output="Europe"
""" # noqa: E501
# parse the prompt to get the instruction and examples
prompt_spec = PromptBasedInstructionParser(task_type=TaskType.TEXT_GENERATION)
prompt_spec.parse_from_prompt(prompt)
print(f"Instruction: {prompt_spec.instruction}\nExamples: {prompt_spec.examples}")

# run this pipeline to retrieve relevant datasets, rerank them,
# and transform them based on the prompt
retriever = DescriptionDatasetRetriever()
num_points_to_transform = 20
retrieved_dataset_dict = retriever.retrieve_dataset_dict(
prompt_spec,
auto_transform_data=True,
num_points_to_transform=num_points_to_transform,
)

# save the final dataset to disk
if retrieved_dataset_dict is not None:
retrieved_dataset_dict.save_to_disk("demo_retrieved_dataset_dict")

Large diffs are not rendered by default.

Large diffs are not rendered by default.

61 changes: 61 additions & 0 deletions examples/mistral_qlora_finetune_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
"""Example of how to fine-tune a model using the QLoRATrainer class."""

import os

from datasets import load_from_disk

from prompt2model.model_trainer.qlora_trainer import QLoRATrainer
from prompt2model.utils.dataset_utils import format_train_data, make_combined_datasets

if __name__ == "__main__":
# First, we load in the datasets we want to fine-tune on.
retrieved_dataset_dict = load_from_disk("demo_retrieved_dataset_dict")
retrieved_dataset = retrieved_dataset_dict["train"]
generated_dataset = load_from_disk("demo_generated_dataset")
dataset_list = [retrieved_dataset, generated_dataset]

# Next, we combine datasets and create train and eval splits.
train_dataset = make_combined_datasets(dataset_list)
splits = train_dataset.train_test_split(test_size=0.1)
train_dataset = splits["train"]
eval_dataset = splits["test"]

# At this point, both train_dataset and eval_dataset are datasets with two
# columns: "input_col" and "output_col".
# We need to format them into a single column, "text", for the QLoRATrainer to use.
formatted_train_dataset = format_train_data(train_dataset)
formatted_eval_dataset = format_train_data(eval_dataset)

# Next, we define the hyperparameters for the QLoRATrainer.
num_epochs = 1
qlora_alpha = 8
qlora_r = 16
qlora_lr = 1e-5
save_folder_path = "qlora_finetuned_model"
load_best_model_at_end = False

# Next, we create a QLoRATrainer object and call the train_model method.
trainer = QLoRATrainer(model_name="mistralai/Mistral-7B-v0.1", model_max_length=512)

# `formatted_eval_dataset` contains just one column: "text",
# and is used to calculate eval loss, by checking loss for each next token.
# `eval_dataset` contains two columns: "input_col" and "output_col",
# and is used to calculate eval accuracy, by checking whether the generated output
# exactly matches the expected output.
trained_model, trained_tokenizer = trainer.train_model(
formatted_train_dataset,
formatted_eval_dataset,
eval_dataset,
num_epochs=1,
alpha=qlora_alpha,
r=qlora_r,
lr=qlora_lr,
save_folder_path=save_folder_path,
load_best_model_at_end=load_best_model_at_end,
)

# Finally, we save the trained model and tokenizer to disk.
trained_model.save_pretrained(os.path.join(save_folder_path, "demo_final_model"))
trained_tokenizer.save_pretrained(
os.path.join(save_folder_path, "demo_final_tokenizer")
)
32 changes: 25 additions & 7 deletions prompt2model/dataset_transformer/prompt_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@
get_formatted_logger,
handle_api_error,
)
from prompt2model.utils.parse_responses import make_single_api_request, parse_json
from prompt2model.utils.parse_responses import (
find_and_parse_json,
make_single_api_request,
)

logger = get_formatted_logger("DatasetTransformer")

Expand All @@ -29,13 +32,27 @@ class PromptBasedDatasetTransformer(DatasetTransformer):
def __init__(
self,
plan_prompt_fn: Callable[
[str, list[dict], str], str
[str, str, list[dict], int], str
] = construct_prompt_for_plan,
transform_prompt_fn: Callable[
[str, dict, str, str], str
[str, str, str, dict], str
] = construct_prompt_for_transform_data,
):
"""Initialize the class."""
"""Initialize the class.
Args:
plan_prompt_fn: A function that takes in a description of the target task,
example of the target task,
list of dictionaries where each dictionary is a row from a potentially
relevant dataset,
and the number of rows to use from this potentially relevant dataset,
and returns a plan prompt.
transform_prompt_fn: A function that takes in a description of the target
task, an example of the target task,
plan for dataset transformation,
and the row from a potentially relevant dataset to be transformed.
"""
self.plan_prompt_fn = plan_prompt_fn
self.transform_prompt_fn = transform_prompt_fn
self.plan: str = ""
Expand Down Expand Up @@ -78,8 +95,9 @@ def transform_data(
"""Transform the dataset according to the prompt_spec and dataset."""
plan_prompt = self.plan_prompt_fn(
prompt_spec.instruction,
dataset,
prompt_spec.examples,
dataset,
min(5, len(dataset)),
)
self.plan = make_single_api_request(plan_prompt)

Expand All @@ -94,9 +112,9 @@ def transform_data(
for row in dataset:
transform_prompt = self.transform_prompt_fn(
prompt_spec.instruction,
row,
prompt_spec.examples,
self.plan,
row,
)
transform_prompts.append(transform_prompt)

Expand All @@ -121,7 +139,7 @@ async def generate_responses(transform_prompts):

for response in responses:
try:
extraction = parse_json(response, ["input", "output"], [])
extraction = find_and_parse_json(response, ["input", "output"], [])
if extraction is not None:
inputs.append(str(extraction["input"]))
outputs.append(str(extraction["output"]))
Expand Down
Loading

0 comments on commit 25e0a96

Please sign in to comment.