-
Notifications
You must be signed in to change notification settings - Fork 180
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add mistral fine-tuning and examples (#395)
* add mistral, incontext examples, and examples * minor error * remove unnecessary files * changes for clarity * change any to dict * add peft to reqs * ignore BLK100 from flake, since Black is called anyways * ignore W503 * remove test issue regarding peft * formatting changes * change QLora to QLoRA
- Loading branch information
Showing
13 changed files
with
940 additions
and
69 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
[flake8] | ||
max-line-length = 88 | ||
extend-ignore = E203,FI10,FI11,FI12,FI13,FI14,FI15,FI16,FI17,FI18 | ||
extend-ignore = E203,FI10,FI11,FI12,FI13,FI14,FI15,FI16,FI17,FI18,BLK100,W503 | ||
per-file-ignores = prompt2model/dataset_transformer/prompt_template.py:E501 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
"""Example to demonstrate how to create synthetic data based on prompt.""" | ||
|
||
import prompt2model.utils.api_tools as api_tools | ||
from prompt2model.dataset_generator.base import DatasetSplit | ||
from prompt2model.dataset_generator.prompt_based import PromptBasedDatasetGenerator | ||
from prompt2model.prompt_parser import PromptBasedInstructionParser, TaskType | ||
from prompt2model.utils.api_tools import APIAgent | ||
|
||
if __name__ == "__main__": | ||
# set API keys and create default API agent. | ||
api_tools.default_api_agent = APIAgent( | ||
model_name="gpt-3.5-turbo-16k", max_tokens=8000 | ||
) | ||
|
||
# create prompt based on which transform data will be created | ||
prompt = """ | ||
Your task is to generate an answer to a natural question. In this task, the input is a string that consists of both a question and a context passage. The context is a descriptive passage related to the question and contains the answer. And the question can range from Math, Cultural, Social, Geometry, Biology, History, Sports, Technology, Science, and so on. | ||
Here are examples with input questions and context passages, along with their expected outputs: | ||
input="Question: What city did Super Bowl 50 take place in? Context: Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers 24–10 to earn their third Super Bowl title. The game was played on February 7, 2016, at Levi's Stadium in the San Francisco Bay Area at Santa Clara, California. As this was the 50th Super Bowl, the league emphasized the "golden anniversary" with various gold-themed initiatives, as well as temporarily suspending the tradition of naming each Super Bowl game with Roman numerals (under which the game would have been known as "Super Bowl L"), so that the logo could prominently feature the Arabic numerals 50." | ||
output="Santa Clara" | ||
input="Question: What river runs through Warsaw? Context: Warsaw (Polish: Warszawa [varˈʂava] ( listen); see also other names) is the capital and largest city of Poland. It stands on the Vistula River in east-central Poland, roughly 260 kilometres (160 mi) from the Baltic Sea and 300 kilometres (190 mi) from the Carpathian Mountains. Its population is estimated at 1.740 million residents within a greater metropolitan area of 2.666 million residents, which makes Warsaw the 9th most-populous capital city in the European Union. The city limits cover 516.9 square kilometres (199.6 sq mi), while the metropolitan area covers 6,100.43 square kilometres (2,355.39 sq mi)." | ||
output="Vistula River" | ||
input="Question: The Ottoman empire controlled territory on three continents, Africa, Asia and which other? Context: The Ottoman Empire was an imperial state that lasted from 1299 to 1923. During the 16th and 17th centuries, in particular at the height of its power under the reign of Suleiman the Magnificent, the Ottoman Empire was a powerful multinational, multilingual empire controlling much of Southeast Europe, Western Asia, the Caucasus, North Africa, and the Horn of Africa. At the beginning of the 17th century the empire contained 32 provinces and numerous vassal states. Some of these were later absorbed into the empire, while others were granted various types of autonomy during the course of centuries." | ||
output="Europe" | ||
""" # noqa: E501 | ||
# parse the prompt to get the instruction and examples | ||
prompt_spec = PromptBasedInstructionParser(task_type=TaskType.TEXT_GENERATION) | ||
prompt_spec.parse_from_prompt(prompt) | ||
print(f"Instruction: {prompt_spec.instruction}\nExamples: {prompt_spec.examples}") | ||
|
||
# set hyperparams | ||
initial_temperature = 0.4 | ||
max_temperature = 1.4 | ||
num_samples_total = 20 | ||
|
||
# run this pipeline to generate data synthetically based on prompt | ||
unlimited_dataset_generator = PromptBasedDatasetGenerator( | ||
initial_temperature=initial_temperature, | ||
max_temperature=max_temperature, | ||
responses_per_request=3, | ||
) | ||
generated_dataset = unlimited_dataset_generator.generate_dataset_split( | ||
prompt_spec, num_samples_total, split=DatasetSplit.TRAIN | ||
) | ||
|
||
# save the final generated dataset to disk | ||
generated_dataset.save_to_disk("demo_generated_dataset") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
"""Example of how to create transform data based on a prompt.""" | ||
|
||
import prompt2model.utils.api_tools as api_tools | ||
from prompt2model.dataset_retriever import DescriptionDatasetRetriever | ||
from prompt2model.prompt_parser import PromptBasedInstructionParser, TaskType | ||
from prompt2model.utils.api_tools import APIAgent | ||
|
||
if __name__ == "__main__": | ||
# set API keys and create default API agent. | ||
api_tools.default_api_agent = APIAgent( | ||
model_name="gpt-3.5-turbo-16k", max_tokens=8000 | ||
) | ||
|
||
# create prompt based on which transform data will be created | ||
prompt = """ | ||
Your task is to generate an answer to a natural question. In this task, the input is a string that consists of both a question and a context passage. The context is a descriptive passage related to the question and contains the answer. And the question can range from Math, Cultural, Social, Geometry, Biology, History, Sports, Technology, Science, and so on. | ||
Here are examples with input questions and context passages, along with their expected outputs: | ||
input="Question: What city did Super Bowl 50 take place in? Context: Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers 24–10 to earn their third Super Bowl title. The game was played on February 7, 2016, at Levi's Stadium in the San Francisco Bay Area at Santa Clara, California. As this was the 50th Super Bowl, the league emphasized the "golden anniversary" with various gold-themed initiatives, as well as temporarily suspending the tradition of naming each Super Bowl game with Roman numerals (under which the game would have been known as "Super Bowl L"), so that the logo could prominently feature the Arabic numerals 50." | ||
output="Santa Clara" | ||
input="Question: What river runs through Warsaw? Context: Warsaw (Polish: Warszawa [varˈʂava] ( listen); see also other names) is the capital and largest city of Poland. It stands on the Vistula River in east-central Poland, roughly 260 kilometres (160 mi) from the Baltic Sea and 300 kilometres (190 mi) from the Carpathian Mountains. Its population is estimated at 1.740 million residents within a greater metropolitan area of 2.666 million residents, which makes Warsaw the 9th most-populous capital city in the European Union. The city limits cover 516.9 square kilometres (199.6 sq mi), while the metropolitan area covers 6,100.43 square kilometres (2,355.39 sq mi)." | ||
output="Vistula River" | ||
input="Question: The Ottoman empire controlled territory on three continents, Africa, Asia and which other? Context: The Ottoman Empire was an imperial state that lasted from 1299 to 1923. During the 16th and 17th centuries, in particular at the height of its power under the reign of Suleiman the Magnificent, the Ottoman Empire was a powerful multinational, multilingual empire controlling much of Southeast Europe, Western Asia, the Caucasus, North Africa, and the Horn of Africa. At the beginning of the 17th century the empire contained 32 provinces and numerous vassal states. Some of these were later absorbed into the empire, while others were granted various types of autonomy during the course of centuries." | ||
output="Europe" | ||
""" # noqa: E501 | ||
# parse the prompt to get the instruction and examples | ||
prompt_spec = PromptBasedInstructionParser(task_type=TaskType.TEXT_GENERATION) | ||
prompt_spec.parse_from_prompt(prompt) | ||
print(f"Instruction: {prompt_spec.instruction}\nExamples: {prompt_spec.examples}") | ||
|
||
# run this pipeline to retrieve relevant datasets, rerank them, | ||
# and transform them based on the prompt | ||
retriever = DescriptionDatasetRetriever() | ||
num_points_to_transform = 20 | ||
retrieved_dataset_dict = retriever.retrieve_dataset_dict( | ||
prompt_spec, | ||
auto_transform_data=True, | ||
num_points_to_transform=num_points_to_transform, | ||
) | ||
|
||
# save the final dataset to disk | ||
if retrieved_dataset_dict is not None: | ||
retrieved_dataset_dict.save_to_disk("demo_retrieved_dataset_dict") |
1 change: 1 addition & 0 deletions
1
examples/huggingface_data/huggingface_datasets/dataset_index.json
Large diffs are not rendered by default.
Oops, something went wrong.
1 change: 1 addition & 0 deletions
1
examples/huggingface_data/huggingface_datasets/reranking_dataset_index.json
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
"""Example of how to fine-tune a model using the QLoRATrainer class.""" | ||
|
||
import os | ||
|
||
from datasets import load_from_disk | ||
|
||
from prompt2model.model_trainer.qlora_trainer import QLoRATrainer | ||
from prompt2model.utils.dataset_utils import format_train_data, make_combined_datasets | ||
|
||
if __name__ == "__main__": | ||
# First, we load in the datasets we want to fine-tune on. | ||
retrieved_dataset_dict = load_from_disk("demo_retrieved_dataset_dict") | ||
retrieved_dataset = retrieved_dataset_dict["train"] | ||
generated_dataset = load_from_disk("demo_generated_dataset") | ||
dataset_list = [retrieved_dataset, generated_dataset] | ||
|
||
# Next, we combine datasets and create train and eval splits. | ||
train_dataset = make_combined_datasets(dataset_list) | ||
splits = train_dataset.train_test_split(test_size=0.1) | ||
train_dataset = splits["train"] | ||
eval_dataset = splits["test"] | ||
|
||
# At this point, both train_dataset and eval_dataset are datasets with two | ||
# columns: "input_col" and "output_col". | ||
# We need to format them into a single column, "text", for the QLoRATrainer to use. | ||
formatted_train_dataset = format_train_data(train_dataset) | ||
formatted_eval_dataset = format_train_data(eval_dataset) | ||
|
||
# Next, we define the hyperparameters for the QLoRATrainer. | ||
num_epochs = 1 | ||
qlora_alpha = 8 | ||
qlora_r = 16 | ||
qlora_lr = 1e-5 | ||
save_folder_path = "qlora_finetuned_model" | ||
load_best_model_at_end = False | ||
|
||
# Next, we create a QLoRATrainer object and call the train_model method. | ||
trainer = QLoRATrainer(model_name="mistralai/Mistral-7B-v0.1", model_max_length=512) | ||
|
||
# `formatted_eval_dataset` contains just one column: "text", | ||
# and is used to calculate eval loss, by checking loss for each next token. | ||
# `eval_dataset` contains two columns: "input_col" and "output_col", | ||
# and is used to calculate eval accuracy, by checking whether the generated output | ||
# exactly matches the expected output. | ||
trained_model, trained_tokenizer = trainer.train_model( | ||
formatted_train_dataset, | ||
formatted_eval_dataset, | ||
eval_dataset, | ||
num_epochs=1, | ||
alpha=qlora_alpha, | ||
r=qlora_r, | ||
lr=qlora_lr, | ||
save_folder_path=save_folder_path, | ||
load_best_model_at_end=load_best_model_at_end, | ||
) | ||
|
||
# Finally, we save the trained model and tokenizer to disk. | ||
trained_model.save_pretrained(os.path.join(save_folder_path, "demo_final_model")) | ||
trained_tokenizer.save_pretrained( | ||
os.path.join(save_folder_path, "demo_final_tokenizer") | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.