Skip to content

Commit 25e0a96

Browse files
authored
Add mistral fine-tuning and examples (#395)
* add mistral, incontext examples, and examples * minor error * remove unnecessary files * changes for clarity * change any to dict * add peft to reqs * ignore BLK100 from flake, since Black is called anyways * ignore W503 * remove test issue regarding peft * formatting changes * change QLora to QLoRA
1 parent 67673b1 commit 25e0a96

File tree

13 files changed

+940
-69
lines changed

13 files changed

+940
-69
lines changed

.flake8

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
[flake8]
22
max-line-length = 88
3-
extend-ignore = E203,FI10,FI11,FI12,FI13,FI14,FI15,FI16,FI17,FI18
3+
extend-ignore = E203,FI10,FI11,FI12,FI13,FI14,FI15,FI16,FI17,FI18,BLK100,W503
4+
per-file-ignores = prompt2model/dataset_transformer/prompt_template.py:E501
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
"""Example to demonstrate how to create synthetic data based on prompt."""
2+
3+
import prompt2model.utils.api_tools as api_tools
4+
from prompt2model.dataset_generator.base import DatasetSplit
5+
from prompt2model.dataset_generator.prompt_based import PromptBasedDatasetGenerator
6+
from prompt2model.prompt_parser import PromptBasedInstructionParser, TaskType
7+
from prompt2model.utils.api_tools import APIAgent
8+
9+
if __name__ == "__main__":
10+
# set API keys and create default API agent.
11+
api_tools.default_api_agent = APIAgent(
12+
model_name="gpt-3.5-turbo-16k", max_tokens=8000
13+
)
14+
15+
# create prompt based on which transform data will be created
16+
prompt = """
17+
Your task is to generate an answer to a natural question. In this task, the input is a string that consists of both a question and a context passage. The context is a descriptive passage related to the question and contains the answer. And the question can range from Math, Cultural, Social, Geometry, Biology, History, Sports, Technology, Science, and so on.
18+
19+
Here are examples with input questions and context passages, along with their expected outputs:
20+
21+
input="Question: What city did Super Bowl 50 take place in? Context: Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers 24–10 to earn their third Super Bowl title. The game was played on February 7, 2016, at Levi's Stadium in the San Francisco Bay Area at Santa Clara, California. As this was the 50th Super Bowl, the league emphasized the "golden anniversary" with various gold-themed initiatives, as well as temporarily suspending the tradition of naming each Super Bowl game with Roman numerals (under which the game would have been known as "Super Bowl L"), so that the logo could prominently feature the Arabic numerals 50."
22+
output="Santa Clara"
23+
24+
input="Question: What river runs through Warsaw? Context: Warsaw (Polish: Warszawa [varˈʂava] ( listen); see also other names) is the capital and largest city of Poland. It stands on the Vistula River in east-central Poland, roughly 260 kilometres (160 mi) from the Baltic Sea and 300 kilometres (190 mi) from the Carpathian Mountains. Its population is estimated at 1.740 million residents within a greater metropolitan area of 2.666 million residents, which makes Warsaw the 9th most-populous capital city in the European Union. The city limits cover 516.9 square kilometres (199.6 sq mi), while the metropolitan area covers 6,100.43 square kilometres (2,355.39 sq mi)."
25+
output="Vistula River"
26+
27+
input="Question: The Ottoman empire controlled territory on three continents, Africa, Asia and which other? Context: The Ottoman Empire was an imperial state that lasted from 1299 to 1923. During the 16th and 17th centuries, in particular at the height of its power under the reign of Suleiman the Magnificent, the Ottoman Empire was a powerful multinational, multilingual empire controlling much of Southeast Europe, Western Asia, the Caucasus, North Africa, and the Horn of Africa. At the beginning of the 17th century the empire contained 32 provinces and numerous vassal states. Some of these were later absorbed into the empire, while others were granted various types of autonomy during the course of centuries."
28+
output="Europe"
29+
""" # noqa: E501
30+
# parse the prompt to get the instruction and examples
31+
prompt_spec = PromptBasedInstructionParser(task_type=TaskType.TEXT_GENERATION)
32+
prompt_spec.parse_from_prompt(prompt)
33+
print(f"Instruction: {prompt_spec.instruction}\nExamples: {prompt_spec.examples}")
34+
35+
# set hyperparams
36+
initial_temperature = 0.4
37+
max_temperature = 1.4
38+
num_samples_total = 20
39+
40+
# run this pipeline to generate data synthetically based on prompt
41+
unlimited_dataset_generator = PromptBasedDatasetGenerator(
42+
initial_temperature=initial_temperature,
43+
max_temperature=max_temperature,
44+
responses_per_request=3,
45+
)
46+
generated_dataset = unlimited_dataset_generator.generate_dataset_split(
47+
prompt_spec, num_samples_total, split=DatasetSplit.TRAIN
48+
)
49+
50+
# save the final generated dataset to disk
51+
generated_dataset.save_to_disk("demo_generated_dataset")
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
"""Example of how to create transform data based on a prompt."""
2+
3+
import prompt2model.utils.api_tools as api_tools
4+
from prompt2model.dataset_retriever import DescriptionDatasetRetriever
5+
from prompt2model.prompt_parser import PromptBasedInstructionParser, TaskType
6+
from prompt2model.utils.api_tools import APIAgent
7+
8+
if __name__ == "__main__":
9+
# set API keys and create default API agent.
10+
api_tools.default_api_agent = APIAgent(
11+
model_name="gpt-3.5-turbo-16k", max_tokens=8000
12+
)
13+
14+
# create prompt based on which transform data will be created
15+
prompt = """
16+
Your task is to generate an answer to a natural question. In this task, the input is a string that consists of both a question and a context passage. The context is a descriptive passage related to the question and contains the answer. And the question can range from Math, Cultural, Social, Geometry, Biology, History, Sports, Technology, Science, and so on.
17+
18+
Here are examples with input questions and context passages, along with their expected outputs:
19+
20+
input="Question: What city did Super Bowl 50 take place in? Context: Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers 24–10 to earn their third Super Bowl title. The game was played on February 7, 2016, at Levi's Stadium in the San Francisco Bay Area at Santa Clara, California. As this was the 50th Super Bowl, the league emphasized the "golden anniversary" with various gold-themed initiatives, as well as temporarily suspending the tradition of naming each Super Bowl game with Roman numerals (under which the game would have been known as "Super Bowl L"), so that the logo could prominently feature the Arabic numerals 50."
21+
output="Santa Clara"
22+
23+
input="Question: What river runs through Warsaw? Context: Warsaw (Polish: Warszawa [varˈʂava] ( listen); see also other names) is the capital and largest city of Poland. It stands on the Vistula River in east-central Poland, roughly 260 kilometres (160 mi) from the Baltic Sea and 300 kilometres (190 mi) from the Carpathian Mountains. Its population is estimated at 1.740 million residents within a greater metropolitan area of 2.666 million residents, which makes Warsaw the 9th most-populous capital city in the European Union. The city limits cover 516.9 square kilometres (199.6 sq mi), while the metropolitan area covers 6,100.43 square kilometres (2,355.39 sq mi)."
24+
output="Vistula River"
25+
26+
input="Question: The Ottoman empire controlled territory on three continents, Africa, Asia and which other? Context: The Ottoman Empire was an imperial state that lasted from 1299 to 1923. During the 16th and 17th centuries, in particular at the height of its power under the reign of Suleiman the Magnificent, the Ottoman Empire was a powerful multinational, multilingual empire controlling much of Southeast Europe, Western Asia, the Caucasus, North Africa, and the Horn of Africa. At the beginning of the 17th century the empire contained 32 provinces and numerous vassal states. Some of these were later absorbed into the empire, while others were granted various types of autonomy during the course of centuries."
27+
output="Europe"
28+
""" # noqa: E501
29+
# parse the prompt to get the instruction and examples
30+
prompt_spec = PromptBasedInstructionParser(task_type=TaskType.TEXT_GENERATION)
31+
prompt_spec.parse_from_prompt(prompt)
32+
print(f"Instruction: {prompt_spec.instruction}\nExamples: {prompt_spec.examples}")
33+
34+
# run this pipeline to retrieve relevant datasets, rerank them,
35+
# and transform them based on the prompt
36+
retriever = DescriptionDatasetRetriever()
37+
num_points_to_transform = 20
38+
retrieved_dataset_dict = retriever.retrieve_dataset_dict(
39+
prompt_spec,
40+
auto_transform_data=True,
41+
num_points_to_transform=num_points_to_transform,
42+
)
43+
44+
# save the final dataset to disk
45+
if retrieved_dataset_dict is not None:
46+
retrieved_dataset_dict.save_to_disk("demo_retrieved_dataset_dict")

examples/huggingface_data/huggingface_datasets/dataset_index.json

Lines changed: 1 addition & 0 deletions
Large diffs are not rendered by default.

examples/huggingface_data/huggingface_datasets/reranking_dataset_index.json

Lines changed: 1 addition & 0 deletions
Large diffs are not rendered by default.
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
"""Example of how to fine-tune a model using the QLoRATrainer class."""
2+
3+
import os
4+
5+
from datasets import load_from_disk
6+
7+
from prompt2model.model_trainer.qlora_trainer import QLoRATrainer
8+
from prompt2model.utils.dataset_utils import format_train_data, make_combined_datasets
9+
10+
if __name__ == "__main__":
11+
# First, we load in the datasets we want to fine-tune on.
12+
retrieved_dataset_dict = load_from_disk("demo_retrieved_dataset_dict")
13+
retrieved_dataset = retrieved_dataset_dict["train"]
14+
generated_dataset = load_from_disk("demo_generated_dataset")
15+
dataset_list = [retrieved_dataset, generated_dataset]
16+
17+
# Next, we combine datasets and create train and eval splits.
18+
train_dataset = make_combined_datasets(dataset_list)
19+
splits = train_dataset.train_test_split(test_size=0.1)
20+
train_dataset = splits["train"]
21+
eval_dataset = splits["test"]
22+
23+
# At this point, both train_dataset and eval_dataset are datasets with two
24+
# columns: "input_col" and "output_col".
25+
# We need to format them into a single column, "text", for the QLoRATrainer to use.
26+
formatted_train_dataset = format_train_data(train_dataset)
27+
formatted_eval_dataset = format_train_data(eval_dataset)
28+
29+
# Next, we define the hyperparameters for the QLoRATrainer.
30+
num_epochs = 1
31+
qlora_alpha = 8
32+
qlora_r = 16
33+
qlora_lr = 1e-5
34+
save_folder_path = "qlora_finetuned_model"
35+
load_best_model_at_end = False
36+
37+
# Next, we create a QLoRATrainer object and call the train_model method.
38+
trainer = QLoRATrainer(model_name="mistralai/Mistral-7B-v0.1", model_max_length=512)
39+
40+
# `formatted_eval_dataset` contains just one column: "text",
41+
# and is used to calculate eval loss, by checking loss for each next token.
42+
# `eval_dataset` contains two columns: "input_col" and "output_col",
43+
# and is used to calculate eval accuracy, by checking whether the generated output
44+
# exactly matches the expected output.
45+
trained_model, trained_tokenizer = trainer.train_model(
46+
formatted_train_dataset,
47+
formatted_eval_dataset,
48+
eval_dataset,
49+
num_epochs=1,
50+
alpha=qlora_alpha,
51+
r=qlora_r,
52+
lr=qlora_lr,
53+
save_folder_path=save_folder_path,
54+
load_best_model_at_end=load_best_model_at_end,
55+
)
56+
57+
# Finally, we save the trained model and tokenizer to disk.
58+
trained_model.save_pretrained(os.path.join(save_folder_path, "demo_final_model"))
59+
trained_tokenizer.save_pretrained(
60+
os.path.join(save_folder_path, "demo_final_tokenizer")
61+
)

prompt2model/dataset_transformer/prompt_based.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,10 @@
1818
get_formatted_logger,
1919
handle_api_error,
2020
)
21-
from prompt2model.utils.parse_responses import make_single_api_request, parse_json
21+
from prompt2model.utils.parse_responses import (
22+
find_and_parse_json,
23+
make_single_api_request,
24+
)
2225

2326
logger = get_formatted_logger("DatasetTransformer")
2427

@@ -29,13 +32,27 @@ class PromptBasedDatasetTransformer(DatasetTransformer):
2932
def __init__(
3033
self,
3134
plan_prompt_fn: Callable[
32-
[str, list[dict], str], str
35+
[str, str, list[dict], int], str
3336
] = construct_prompt_for_plan,
3437
transform_prompt_fn: Callable[
35-
[str, dict, str, str], str
38+
[str, str, str, dict], str
3639
] = construct_prompt_for_transform_data,
3740
):
38-
"""Initialize the class."""
41+
"""Initialize the class.
42+
43+
Args:
44+
plan_prompt_fn: A function that takes in a description of the target task,
45+
example of the target task,
46+
list of dictionaries where each dictionary is a row from a potentially
47+
relevant dataset,
48+
and the number of rows to use from this potentially relevant dataset,
49+
and returns a plan prompt.
50+
51+
transform_prompt_fn: A function that takes in a description of the target
52+
task, an example of the target task,
53+
plan for dataset transformation,
54+
and the row from a potentially relevant dataset to be transformed.
55+
"""
3956
self.plan_prompt_fn = plan_prompt_fn
4057
self.transform_prompt_fn = transform_prompt_fn
4158
self.plan: str = ""
@@ -78,8 +95,9 @@ def transform_data(
7895
"""Transform the dataset according to the prompt_spec and dataset."""
7996
plan_prompt = self.plan_prompt_fn(
8097
prompt_spec.instruction,
81-
dataset,
8298
prompt_spec.examples,
99+
dataset,
100+
min(5, len(dataset)),
83101
)
84102
self.plan = make_single_api_request(plan_prompt)
85103

@@ -94,9 +112,9 @@ def transform_data(
94112
for row in dataset:
95113
transform_prompt = self.transform_prompt_fn(
96114
prompt_spec.instruction,
97-
row,
98115
prompt_spec.examples,
99116
self.plan,
117+
row,
100118
)
101119
transform_prompts.append(transform_prompt)
102120

@@ -121,7 +139,7 @@ async def generate_responses(transform_prompts):
121139

122140
for response in responses:
123141
try:
124-
extraction = parse_json(response, ["input", "output"], [])
142+
extraction = find_and_parse_json(response, ["input", "output"], [])
125143
if extraction is not None:
126144
inputs.append(str(extraction["input"]))
127145
outputs.append(str(extraction["output"]))

0 commit comments

Comments
 (0)