From 405b56a8b451e0cf298a76b4b60e93623ec9686d Mon Sep 17 00:00:00 2001 From: George Ohashi Date: Fri, 15 Mar 2024 22:45:56 +0000 Subject: [PATCH 1/8] mask, offset_mapping and filtering --- _scratch/mask.py | 72 +++++++++++++++++++ .../transformers/finetune/data/base.py | 19 ++++- 2 files changed, 88 insertions(+), 3 deletions(-) create mode 100644 _scratch/mask.py diff --git a/_scratch/mask.py b/_scratch/mask.py new file mode 100644 index 00000000000..8d2cc6a1b69 --- /dev/null +++ b/_scratch/mask.py @@ -0,0 +1,72 @@ +data = { + "instruction":"Improve the following piece of code that throws multiple errors. You should also write a function test to make sure your code runs correctly.\n\nInitial code:\ndef my_function(x, y, z):\n a = x + 5y - z\n print x, y, z, a\n return a, b, c, d\n\nYour task is not only to resolve the errors but also to improve the code, making it efficient and well-structured. End your task by writing a function test to ensure your code is running as expected. \n\nNote: The context of this coding exercise is language-agnostic. Therefore, be aware of potential differences in syntax and functionalities among different programming languages, and adjust your answer with respect to the programming language of your choice.", + "output":"The initial code contains some errors which include:\n\n1. The syntax for the arithmetic operation is wrong, it should be 5*y instead of 5y.\n2. Variables \\'b\\',\\'c\\', and \\'d\\' are used without being declared.\n3. Using commas in print function may cause errors in some programming languages.\n\nBelow is the corrected code in Python:\n\n```python\ndef my_function(x, y, z):\n a = x + 5*y - z\n print (x, y, z, a)\n return a\n```\n\nThe above code will calculate the expression x + 5*y - z and return the result as \\'a\\'. Then it will print the values of x, y, z, and a.\n\nThe function test could look something like this:\n\n```python\ndef test_function():\n result = my_function(1, 2, 3)\n assert(result == -2), \"The values seem to be incorrect.\"\n print(\"my_function passed.\")\n```\n\nAbove code will test my_function using the inputs (1, 2, 3). If the output is not -2, it will print \"The values seem to be incorrect.\" If the output is correct, it will print \"my_function passed.\"\n\nSo, when we run the test function:\n\n```python\ntest_function()\n```\n\nIt will test our my_function to see if it\\'s working as expected. The \"assert\" statement is used to let the program test if the condition is true. If the condition is false, the program throws an AssertionError exception.\n\nNote: The above codes are in Python, if these are converted in other language, kindly do the necessary syntax adjustments.", + "prompt":"[Instruction]:\nImprove the following piece of code that throws multiple errors. You should also write a function test to make sure your code runs correctly.\n\nInitial code:\ndef my_function(x, y, z):\n a = x + 5y - z\n print x, y, z, a\n return a, b, c, d\n\nYour task is not only to resolve the errors but also to improve the code, making it efficient and well-structured. End your task by writing a function test to ensure your code is running as expected. \n\nNote: The context of this coding exercise is language-agnostic. Therefore, be aware of potential differences in syntax and functionalities among different programming languages, and adjust your answer with respect to the programming language of your choice.\n\n[Response]:", + "text":"[Instruction]:\nImprove the following piece of code that throws multiple errors. You should also write a function test to make sure your code runs correctly.\n\nInitial code:\ndef my_function(x, y, z):\n a = x + 5y - z\n print x, y, z, a\n return a, b, c, d\n\nYour task is not only to resolve the errors but also to improve the code, making it efficient and well-structured. End your task by writing a function test to ensure your code is running as expected. \n\nNote: The context of this coding exercise is language-agnostic. Therefore, be aware of potential differences in syntax and functionalities among different programming languages, and adjust your answer with respect to the programming language of your choice.\n\n[Response]:The initial code contains some errors which include:\n\n1. The syntax for the arithmetic operation is wrong, it should be 5*y instead of 5y.\n2. Variables \\'b\\',\\'c\\', and \\'d\\' are used without being declared.\n3. Using commas in print function may cause errors in some programming languages.\n\nBelow is the corrected code in Python:\n\n```python\ndef my_function(x, y, z):\n a = x + 5*y - z\n print (x, y, z, a)\n return a\n```\n\nThe above code will calculate the expression x + 5*y - z and return the result as \\'a\\'. Then it will print the values of x, y, z, and a.\n\nThe function test could look something like this:\n\n```python\ndef test_function():\n result = my_function(1, 2, 3)\n assert(result == -2), \"The values seem to be incorrect.\"\n print(\"my_function passed.\")\n```\n\nAbove code will test my_function using the inputs (1, 2, 3). If the output is not -2, it will print \"The values seem to be incorrect.\" If the output is correct, it will print \"my_function passed.\"\n\nSo, when we run the test function:\n\n```python\ntest_function()\n```\n\nIt will test our my_function to see if it\\'s working as expected. The \"assert\" statement is used to let the program test if the condition is true. If the condition is false, the program throws an AssertionError exception.\n\nNote: The above codes are in Python, if these are converted in other language, kindly do the necessary syntax adjustments." +} + +prompt = "[fooo]:\nname me some cold blooded animals\n\n[Response]:\n" +bs = "jdlsaj;fdsa dls;ajfd;a dfjdaklfd jsak; fsajfiajeiga sdkga idgdhangia n" + +prompts = prompt + bs + prompt + bs + + +def mask_prompt(target: str): + if not target.startswith("["): + target = "[" + target + if not target.endswith("]"): + target += "]" + + +def generate_mask(string:str, prompt: str, censor: str): + mask = ['1'] * len(string) + is_prompt = True + counter = 0 + for i, char in enumerate(string): + if not is_prompt: + mask[i] = "0" + + if counter > 0: + if not is_prompt and char == prompt[counter]: + counter += 1 + elif is_prompt and char == censor[counter]: + counter += 1 + else: + counter = 0 + + if counter == len(prompt) and not is_prompt: + mask[i-counter+1:i+1] = ["1"] * counter + counter = 0 + is_prompt = True + + if counter == len(censor) and is_prompt: + mask[i-counter+1:i+1] = ["0"] * counter + counter = 0 + is_prompt = False + + if prompt.startswith(char) or censor.startswith(char): + counter = 1 + + return ''.join(mask) + +def filter_string(mask:str, string:str): + filtered_list = [] + current_chunk = '' + for char, m in zip(string, mask): + if m == '1': + current_chunk += char + elif current_chunk: + filtered_list.append(current_chunk) + current_chunk = '' + if current_chunk: + filtered_list.append(current_chunk) + return filtered_list + +# Example usage: +string = "[prompt] hey there\n\n[censor]hello back" +string += string +prompt = "[prompt]" +censor = "[censor]" +mask = generate_mask(string, prompt, censor) +filtered_list = filter_string(mask, string) + diff --git a/src/sparseml/transformers/finetune/data/base.py b/src/sparseml/transformers/finetune/data/base.py index 6f34bc352d5..322949af66f 100644 --- a/src/sparseml/transformers/finetune/data/base.py +++ b/src/sparseml/transformers/finetune/data/base.py @@ -13,7 +13,7 @@ # limitations under the License. import logging -from typing import Optional, Union +from typing import Optional, Union, Tuple from datasets import Dataset, IterableDataset from transformers import AutoTokenizer @@ -125,6 +125,7 @@ def tokenize_fn(data): padding=self.padding, max_length=self.max_seq_length, truncation=True, + return_offsets_mapping=True, ) # store unpadded prompt so we can mask out correct number of elements @@ -156,6 +157,18 @@ def group_text_fn(data): def label_fn(data): # if the dataset uses prompts, mask them out so they don't contribute # to the loss calculation + + if "offset_mapping" in data: + offset_mapping = data["offset_mapping"] + input_ids = data["input_ids"] + # get the character level mask + mask = ""# str of 01 + mask = "1" * len(input_ids) + for i, (start, end) in enumerate(offset_mapping): + # if any char is to be filtered + if '0' in mask[start:end]: + input_ids[i] = LABELS_MASK_VALUE + prompt_len = 0 if self.PROMPT_KEY in data: prompt_len = len(data[self.PROMPT_KEY]) @@ -166,6 +179,7 @@ def label_fn(data): padding = len(data["attention_mask"]) - sum(data["attention_mask"]) if padding > 0: data["labels"][-padding:] = [LABELS_MASK_VALUE] * padding + return data dataset = self.map( @@ -195,6 +209,7 @@ def label_fn(data): column_names = dataset.column_names if isinstance(column_names, dict): column_names = column_names[list(column_names)[0]] + breakpoint() dataset = self.map( dataset, function=label_fn, @@ -206,7 +221,6 @@ def label_fn(data): load_from_cache_file=not self.data_args.overwrite_cache, desc="Adding labels", ) - print(dataset.column_names) return dataset @@ -226,5 +240,4 @@ def map( kwargs.pop("num_proc", None) kwargs.pop("load_from_cache_file", None) kwargs.pop("desc", None) - return dataset.map(**kwargs) From 7ffc0b51f35a3f1403a87ccd40261281702cffeb Mon Sep 17 00:00:00 2001 From: George Ohashi Date: Mon, 18 Mar 2024 22:43:45 +0000 Subject: [PATCH 2/8] bug cuda bug on applying mask --- .../transformers/finetune/data/base.py | 21 ++++----- .../transformers/finetune/data/custom.py | 7 ++- src/sparseml/transformers/utils/helpers.py | 47 +++++++++++++++++++ .../utils/preprocessing_functions.py | 4 ++ .../transformers/utils/test_helpers.py | 27 +++++++++++ 5 files changed, 93 insertions(+), 13 deletions(-) diff --git a/src/sparseml/transformers/finetune/data/base.py b/src/sparseml/transformers/finetune/data/base.py index 322949af66f..49d244f5c44 100644 --- a/src/sparseml/transformers/finetune/data/base.py +++ b/src/sparseml/transformers/finetune/data/base.py @@ -13,7 +13,7 @@ # limitations under the License. import logging -from typing import Optional, Union, Tuple +from typing import Optional, Tuple, Union from datasets import Dataset, IterableDataset from transformers import AutoTokenizer @@ -161,14 +161,15 @@ def label_fn(data): if "offset_mapping" in data: offset_mapping = data["offset_mapping"] input_ids = data["input_ids"] - # get the character level mask - mask = ""# str of 01 + # get the character level mask mask = "1" * len(input_ids) - for i, (start, end) in enumerate(offset_mapping): - # if any char is to be filtered - if '0' in mask[start:end]: - input_ids[i] = LABELS_MASK_VALUE - + mask = data.get('mask') + if mask is not None: + for i, (start, end) in enumerate(offset_mapping): + # if any char is to be filtered + if "0" in mask[start:end]: + input_ids[i] = LABELS_MASK_VALUE + prompt_len = 0 if self.PROMPT_KEY in data: prompt_len = len(data[self.PROMPT_KEY]) @@ -179,7 +180,7 @@ def label_fn(data): padding = len(data["attention_mask"]) - sum(data["attention_mask"]) if padding > 0: data["labels"][-padding:] = [LABELS_MASK_VALUE] * padding - + return data dataset = self.map( @@ -209,7 +210,6 @@ def label_fn(data): column_names = dataset.column_names if isinstance(column_names, dict): column_names = column_names[list(column_names)[0]] - breakpoint() dataset = self.map( dataset, function=label_fn, @@ -221,7 +221,6 @@ def label_fn(data): load_from_cache_file=not self.data_args.overwrite_cache, desc="Adding labels", ) - return dataset def map( diff --git a/src/sparseml/transformers/finetune/data/custom.py b/src/sparseml/transformers/finetune/data/custom.py index f1bdcb6085f..c86b4ed5057 100644 --- a/src/sparseml/transformers/finetune/data/custom.py +++ b/src/sparseml/transformers/finetune/data/custom.py @@ -88,10 +88,11 @@ def get_raw_dataset(self, *_ignore, **__ignore) -> Union[DatasetDict, Dataset]: raw_dataset, batched=True, remove_columns=self.remove_columns, - num_proc=self.data_args.preprocessing_num_workers, + # num_proc=self.data_args.preprocessing_num_workers, + num_proc=1, + desc="Removing unneeded columns", ) - return raw_dataset def get_remove_columns_from_dataset( @@ -108,5 +109,7 @@ def get_remove_columns_from_dataset( remove_columns.remove(self.text_column) if self.PROMPT_KEY in remove_columns: remove_columns.remove(self.PROMPT_KEY) + if "mask" in remove_columns: + remove_columns.remove("mask") return list(remove_columns) diff --git a/src/sparseml/transformers/utils/helpers.py b/src/sparseml/transformers/utils/helpers.py index d797afbacc0..554ca55466c 100644 --- a/src/sparseml/transformers/utils/helpers.py +++ b/src/sparseml/transformers/utils/helpers.py @@ -52,6 +52,7 @@ "ALL_TASK_NAMES", "create_fake_dataloader", "POSSIBLE_TOKENIZER_FILES", + "generate_mask", ] @@ -554,3 +555,49 @@ def fetch_recipe_path(target: str): recipe_path = hf_hub_download(repo_id=target, filename=DEFAULT_RECIPE_NAME) return recipe_path + + +def generate_mask(string: str, prompt: str, censor: str) -> str: + """ + Generate a mask based on provided prompt and censor strings to obscure + characters in the input string. + + Args: + :param string: The input string to be masked. + :param prompt: The prompt string to identify characters to keep visible. + :param censor: The censor string to identify characters to obscure. + + Returns: + str: A string representing the mask where '1' indicates visible + characters and '0' indicates obscured characters. + + """ + mask = ["1"] * len(string) + is_prompt = True + counter = 0 + for i, char in enumerate(string): + if not is_prompt: + mask[i] = "0" + + if counter > 0: + if not is_prompt and char == prompt[counter]: + counter += 1 + elif is_prompt and char == censor[counter]: + counter += 1 + else: + counter = 0 + + if counter == len(prompt) and not is_prompt: + mask[i - counter + 1 : i + 1] = ["1"] * counter + counter = 0 + is_prompt = True + + if counter == len(censor) and is_prompt: + mask[i - counter + 1 : i + 1] = ["0"] * counter + counter = 0 + is_prompt = False + + if prompt.startswith(char) or censor.startswith(char): + counter = 1 + + return "".join(mask) diff --git a/src/sparseml/transformers/utils/preprocessing_functions.py b/src/sparseml/transformers/utils/preprocessing_functions.py index 8b019094377..b26bc3f587c 100644 --- a/src/sparseml/transformers/utils/preprocessing_functions.py +++ b/src/sparseml/transformers/utils/preprocessing_functions.py @@ -14,6 +14,7 @@ from typing import Dict +from sparseml.transformers.utils.helpers import generate_mask from sparsezoo.utils.registry import RegistryMixin @@ -26,4 +27,7 @@ def custom_evolved_codealpaca_dataset(data: Dict): PROMPT_DICT = """[Instruction]:\n{instruction}\n\n[Response]:""" data["prompt"] = PROMPT_DICT.format_map(data) data["text"] = data["prompt"] + data["output"] + data["mask"] = generate_mask( + data["text"], prompt="[Instruction]", censor="[Response]" + ) return data diff --git a/tests/sparseml/transformers/utils/test_helpers.py b/tests/sparseml/transformers/utils/test_helpers.py index 5e0ab5e93da..7ab256fba39 100644 --- a/tests/sparseml/transformers/utils/test_helpers.py +++ b/tests/sparseml/transformers/utils/test_helpers.py @@ -21,6 +21,7 @@ from accelerate import init_empty_weights from sparseml.transformers.utils.helpers import ( create_fake_dataloader, + generate_mask, infer_recipe_from_model_path, is_transformer_model, resolve_recipe_file, @@ -166,3 +167,29 @@ def test_save_zoo_directory(tmp_path, stub): assert zoo_model.validate(minimal_validation=True, validate_onnxruntime=False) shutil.rmtree(path_to_training_outputs) shutil.rmtree(save_dir) + + +@pytest.mark.parametrize( + "string, prompt, censor, expected_mask", + [ + ("[foo]hello\n\n[bar]world", "[foo]", "[bar]", "1111111111110000000000"), + ( + ( + "[Instruction]python is\n\n" # 24 + "[Response]great\n\n" # 17 + "[Instruction]What about Java" # 28 + "[Response]Meh" # 13 + ), + "[Instruction]", + "[Response]", + ( + "111111111111111111111111" # 24 + "00000000000000000" # 17 + "1111111111111111111111111111" # 28 + "0000000000000" # 13 + ), + ), + ], +) +def test_generate_mask(string, prompt, censor, expected_mask): + assert generate_mask(string, prompt, censor) == expected_mask From 0a81315cd0136421b0ad047b9cf4bb96f527d1b6 Mon Sep 17 00:00:00 2001 From: George Ohashi Date: Wed, 20 Mar 2024 13:56:28 +0000 Subject: [PATCH 3/8] fix cuda bug --- src/sparseml/transformers/finetune/data/base.py | 13 ++++++------- src/sparseml/transformers/finetune/data/custom.py | 4 +--- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/src/sparseml/transformers/finetune/data/base.py b/src/sparseml/transformers/finetune/data/base.py index 49d244f5c44..935a0c5f78b 100644 --- a/src/sparseml/transformers/finetune/data/base.py +++ b/src/sparseml/transformers/finetune/data/base.py @@ -13,7 +13,7 @@ # limitations under the License. import logging -from typing import Optional, Tuple, Union +from typing import Optional, Union from datasets import Dataset, IterableDataset from transformers import AutoTokenizer @@ -157,23 +157,22 @@ def group_text_fn(data): def label_fn(data): # if the dataset uses prompts, mask them out so they don't contribute # to the loss calculation - + labels = data["input_ids"].copy() if "offset_mapping" in data: offset_mapping = data["offset_mapping"] - input_ids = data["input_ids"] # get the character level mask - mask = "1" * len(input_ids) - mask = data.get('mask') + mask = data.get("mask") if mask is not None: for i, (start, end) in enumerate(offset_mapping): # if any char is to be filtered if "0" in mask[start:end]: - input_ids[i] = LABELS_MASK_VALUE + labels[i] = LABELS_MASK_VALUE prompt_len = 0 if self.PROMPT_KEY in data: prompt_len = len(data[self.PROMPT_KEY]) - data["labels"] = data["input_ids"].copy() + + data["labels"] = labels data["labels"][:prompt_len] = [LABELS_MASK_VALUE] * prompt_len # mask out padding in the labels as well diff --git a/src/sparseml/transformers/finetune/data/custom.py b/src/sparseml/transformers/finetune/data/custom.py index c86b4ed5057..8af2fd674f6 100644 --- a/src/sparseml/transformers/finetune/data/custom.py +++ b/src/sparseml/transformers/finetune/data/custom.py @@ -88,9 +88,7 @@ def get_raw_dataset(self, *_ignore, **__ignore) -> Union[DatasetDict, Dataset]: raw_dataset, batched=True, remove_columns=self.remove_columns, - # num_proc=self.data_args.preprocessing_num_workers, - num_proc=1, - + num_proc=self.data_args.preprocessing_num_workers, desc="Removing unneeded columns", ) return raw_dataset From cd6634095d98069af8d84833c0bbc92ad91c737f Mon Sep 17 00:00:00 2001 From: George Date: Wed, 20 Mar 2024 09:56:54 -0400 Subject: [PATCH 4/8] Delete _scratch/mask.py --- _scratch/mask.py | 72 ------------------------------------------------ 1 file changed, 72 deletions(-) delete mode 100644 _scratch/mask.py diff --git a/_scratch/mask.py b/_scratch/mask.py deleted file mode 100644 index 8d2cc6a1b69..00000000000 --- a/_scratch/mask.py +++ /dev/null @@ -1,72 +0,0 @@ -data = { - "instruction":"Improve the following piece of code that throws multiple errors. You should also write a function test to make sure your code runs correctly.\n\nInitial code:\ndef my_function(x, y, z):\n a = x + 5y - z\n print x, y, z, a\n return a, b, c, d\n\nYour task is not only to resolve the errors but also to improve the code, making it efficient and well-structured. End your task by writing a function test to ensure your code is running as expected. \n\nNote: The context of this coding exercise is language-agnostic. Therefore, be aware of potential differences in syntax and functionalities among different programming languages, and adjust your answer with respect to the programming language of your choice.", - "output":"The initial code contains some errors which include:\n\n1. The syntax for the arithmetic operation is wrong, it should be 5*y instead of 5y.\n2. Variables \\'b\\',\\'c\\', and \\'d\\' are used without being declared.\n3. Using commas in print function may cause errors in some programming languages.\n\nBelow is the corrected code in Python:\n\n```python\ndef my_function(x, y, z):\n a = x + 5*y - z\n print (x, y, z, a)\n return a\n```\n\nThe above code will calculate the expression x + 5*y - z and return the result as \\'a\\'. Then it will print the values of x, y, z, and a.\n\nThe function test could look something like this:\n\n```python\ndef test_function():\n result = my_function(1, 2, 3)\n assert(result == -2), \"The values seem to be incorrect.\"\n print(\"my_function passed.\")\n```\n\nAbove code will test my_function using the inputs (1, 2, 3). If the output is not -2, it will print \"The values seem to be incorrect.\" If the output is correct, it will print \"my_function passed.\"\n\nSo, when we run the test function:\n\n```python\ntest_function()\n```\n\nIt will test our my_function to see if it\\'s working as expected. The \"assert\" statement is used to let the program test if the condition is true. If the condition is false, the program throws an AssertionError exception.\n\nNote: The above codes are in Python, if these are converted in other language, kindly do the necessary syntax adjustments.", - "prompt":"[Instruction]:\nImprove the following piece of code that throws multiple errors. You should also write a function test to make sure your code runs correctly.\n\nInitial code:\ndef my_function(x, y, z):\n a = x + 5y - z\n print x, y, z, a\n return a, b, c, d\n\nYour task is not only to resolve the errors but also to improve the code, making it efficient and well-structured. End your task by writing a function test to ensure your code is running as expected. \n\nNote: The context of this coding exercise is language-agnostic. Therefore, be aware of potential differences in syntax and functionalities among different programming languages, and adjust your answer with respect to the programming language of your choice.\n\n[Response]:", - "text":"[Instruction]:\nImprove the following piece of code that throws multiple errors. You should also write a function test to make sure your code runs correctly.\n\nInitial code:\ndef my_function(x, y, z):\n a = x + 5y - z\n print x, y, z, a\n return a, b, c, d\n\nYour task is not only to resolve the errors but also to improve the code, making it efficient and well-structured. End your task by writing a function test to ensure your code is running as expected. \n\nNote: The context of this coding exercise is language-agnostic. Therefore, be aware of potential differences in syntax and functionalities among different programming languages, and adjust your answer with respect to the programming language of your choice.\n\n[Response]:The initial code contains some errors which include:\n\n1. The syntax for the arithmetic operation is wrong, it should be 5*y instead of 5y.\n2. Variables \\'b\\',\\'c\\', and \\'d\\' are used without being declared.\n3. Using commas in print function may cause errors in some programming languages.\n\nBelow is the corrected code in Python:\n\n```python\ndef my_function(x, y, z):\n a = x + 5*y - z\n print (x, y, z, a)\n return a\n```\n\nThe above code will calculate the expression x + 5*y - z and return the result as \\'a\\'. Then it will print the values of x, y, z, and a.\n\nThe function test could look something like this:\n\n```python\ndef test_function():\n result = my_function(1, 2, 3)\n assert(result == -2), \"The values seem to be incorrect.\"\n print(\"my_function passed.\")\n```\n\nAbove code will test my_function using the inputs (1, 2, 3). If the output is not -2, it will print \"The values seem to be incorrect.\" If the output is correct, it will print \"my_function passed.\"\n\nSo, when we run the test function:\n\n```python\ntest_function()\n```\n\nIt will test our my_function to see if it\\'s working as expected. The \"assert\" statement is used to let the program test if the condition is true. If the condition is false, the program throws an AssertionError exception.\n\nNote: The above codes are in Python, if these are converted in other language, kindly do the necessary syntax adjustments." -} - -prompt = "[fooo]:\nname me some cold blooded animals\n\n[Response]:\n" -bs = "jdlsaj;fdsa dls;ajfd;a dfjdaklfd jsak; fsajfiajeiga sdkga idgdhangia n" - -prompts = prompt + bs + prompt + bs - - -def mask_prompt(target: str): - if not target.startswith("["): - target = "[" + target - if not target.endswith("]"): - target += "]" - - -def generate_mask(string:str, prompt: str, censor: str): - mask = ['1'] * len(string) - is_prompt = True - counter = 0 - for i, char in enumerate(string): - if not is_prompt: - mask[i] = "0" - - if counter > 0: - if not is_prompt and char == prompt[counter]: - counter += 1 - elif is_prompt and char == censor[counter]: - counter += 1 - else: - counter = 0 - - if counter == len(prompt) and not is_prompt: - mask[i-counter+1:i+1] = ["1"] * counter - counter = 0 - is_prompt = True - - if counter == len(censor) and is_prompt: - mask[i-counter+1:i+1] = ["0"] * counter - counter = 0 - is_prompt = False - - if prompt.startswith(char) or censor.startswith(char): - counter = 1 - - return ''.join(mask) - -def filter_string(mask:str, string:str): - filtered_list = [] - current_chunk = '' - for char, m in zip(string, mask): - if m == '1': - current_chunk += char - elif current_chunk: - filtered_list.append(current_chunk) - current_chunk = '' - if current_chunk: - filtered_list.append(current_chunk) - return filtered_list - -# Example usage: -string = "[prompt] hey there\n\n[censor]hello back" -string += string -prompt = "[prompt]" -censor = "[censor]" -mask = generate_mask(string, prompt, censor) -filtered_list = filter_string(mask, string) - From 5b7626741fd93ac7701a485436c92ce217f2ae41 Mon Sep 17 00:00:00 2001 From: George Ohashi Date: Wed, 27 Mar 2024 17:56:25 +0000 Subject: [PATCH 5/8] fix non prompt case --- .../transformers/finetune/data/base.py | 1 + .../transformers/finetune/data/custom.py | 4 +-- src/sparseml/transformers/utils/helpers.py | 25 +++++++++++-------- .../transformers/utils/test_helpers.py | 22 ++++++++++++---- 4 files changed, 35 insertions(+), 17 deletions(-) diff --git a/src/sparseml/transformers/finetune/data/base.py b/src/sparseml/transformers/finetune/data/base.py index 935a0c5f78b..4a4f386ba4c 100644 --- a/src/sparseml/transformers/finetune/data/base.py +++ b/src/sparseml/transformers/finetune/data/base.py @@ -41,6 +41,7 @@ class TextGenerationDataset(RegistryMixin): """ PROMPT_KEY = "prompt" + MASK_KEY = "mask" def __init__( self, diff --git a/src/sparseml/transformers/finetune/data/custom.py b/src/sparseml/transformers/finetune/data/custom.py index 8af2fd674f6..55586e7d562 100644 --- a/src/sparseml/transformers/finetune/data/custom.py +++ b/src/sparseml/transformers/finetune/data/custom.py @@ -107,7 +107,7 @@ def get_remove_columns_from_dataset( remove_columns.remove(self.text_column) if self.PROMPT_KEY in remove_columns: remove_columns.remove(self.PROMPT_KEY) - if "mask" in remove_columns: - remove_columns.remove("mask") + if self.MASK_KEY in remove_columns: + remove_columns.remove(self.MASK_KEY) return list(remove_columns) diff --git a/src/sparseml/transformers/utils/helpers.py b/src/sparseml/transformers/utils/helpers.py index 554ca55466c..afb1944293b 100644 --- a/src/sparseml/transformers/utils/helpers.py +++ b/src/sparseml/transformers/utils/helpers.py @@ -557,15 +557,15 @@ def fetch_recipe_path(target: str): return recipe_path -def generate_mask(string: str, prompt: str, censor: str) -> str: +def generate_mask(string: str, prompt: str, respomse: str = "") -> str: """ - Generate a mask based on provided prompt and censor strings to obscure + Generate a mask based on provided prompt and respomse strings to obscure characters in the input string. Args: :param string: The input string to be masked. - :param prompt: The prompt string to identify characters to keep visible. - :param censor: The censor string to identify characters to obscure. + :param prompt: The prompt string to identify characters to obscure. + :param respomse: The respomse string to identify characters to keep visible. Returns: str: A string representing the mask where '1' indicates visible @@ -576,28 +576,33 @@ def generate_mask(string: str, prompt: str, censor: str) -> str: is_prompt = True counter = 0 for i, char in enumerate(string): - if not is_prompt: + # if not is_prompt: + if is_prompt: mask[i] = "0" if counter > 0: if not is_prompt and char == prompt[counter]: counter += 1 - elif is_prompt and char == censor[counter]: + elif is_prompt and char == respomse[counter]: counter += 1 else: counter = 0 if counter == len(prompt) and not is_prompt: - mask[i - counter + 1 : i + 1] = ["1"] * counter + # mask[i - counter + 1 : i + 1] = ["1"] * counter + mask[i - counter + 1 : i + 1] = ["0"] * counter + counter = 0 is_prompt = True - if counter == len(censor) and is_prompt: - mask[i - counter + 1 : i + 1] = ["0"] * counter + if counter == len(respomse) and is_prompt: + # mask[i - counter + 1 : i + 1] = ["0"] * counter + mask[i - counter + 1 : i + 1] = ["1"] * counter + counter = 0 is_prompt = False - if prompt.startswith(char) or censor.startswith(char): + if prompt.startswith(char) or respomse.startswith(char): counter = 1 return "".join(mask) diff --git a/tests/sparseml/transformers/utils/test_helpers.py b/tests/sparseml/transformers/utils/test_helpers.py index 7ab256fba39..94f7c0205e5 100644 --- a/tests/sparseml/transformers/utils/test_helpers.py +++ b/tests/sparseml/transformers/utils/test_helpers.py @@ -172,7 +172,12 @@ def test_save_zoo_directory(tmp_path, stub): @pytest.mark.parametrize( "string, prompt, censor, expected_mask", [ - ("[foo]hello\n\n[bar]world", "[foo]", "[bar]", "1111111111110000000000"), + ( + ("[foo]hello\n\n" "[bar]world"), + "[foo]", + "[bar]", + ("000000000000" "1111111111"), + ), ( ( "[Instruction]python is\n\n" # 24 @@ -183,12 +188,19 @@ def test_save_zoo_directory(tmp_path, stub): "[Instruction]", "[Response]", ( - "111111111111111111111111" # 24 - "00000000000000000" # 17 - "1111111111111111111111111111" # 28 - "0000000000000" # 13 + "000000000000000000000000" # 24 + "11111111111111111" # 17 + "0000000000000000000000000000" # 28 + "1111111111111" # 13 ), ), + ( + ("hello\n\n" "[bar]world"), + "[foo]", + "[bar]", + ("000000000" "1111111111"), + ( 1111111111) + ), ], ) def test_generate_mask(string, prompt, censor, expected_mask): From 2ca4be7f2d5d91a42aee301af09ebea8792f51cf Mon Sep 17 00:00:00 2001 From: George Ohashi Date: Mon, 1 Apr 2024 17:44:43 +0000 Subject: [PATCH 6/8] fix bug for no prompt --- src/sparseml/transformers/utils/helpers.py | 24 ++++++++++--------- .../transformers/utils/test_helpers.py | 17 +++++++------ 2 files changed, 21 insertions(+), 20 deletions(-) diff --git a/src/sparseml/transformers/utils/helpers.py b/src/sparseml/transformers/utils/helpers.py index afb1944293b..f8f4b67fad0 100644 --- a/src/sparseml/transformers/utils/helpers.py +++ b/src/sparseml/transformers/utils/helpers.py @@ -557,52 +557,54 @@ def fetch_recipe_path(target: str): return recipe_path -def generate_mask(string: str, prompt: str, respomse: str = "") -> str: +def generate_mask(string: str, response: str, prompt: Optional[str] = None) -> str: """ - Generate a mask based on provided prompt and respomse strings to obscure - characters in the input string. + Generate a mask based on provided prompt and response strings to obscure + characters in the input string. Prompt will be masked and string in response + will be kept represented by 0 - remove and 1 - keep. + By default, non-reponse wrapped strings will be matched with 0 Args: :param string: The input string to be masked. :param prompt: The prompt string to identify characters to obscure. - :param respomse: The respomse string to identify characters to keep visible. + :param response: The response string to identify characters to keep visible. Returns: str: A string representing the mask where '1' indicates visible characters and '0' indicates obscured characters. """ + if prompt is None: + prompt = "" + mask = ["1"] * len(string) is_prompt = True counter = 0 for i, char in enumerate(string): - # if not is_prompt: if is_prompt: mask[i] = "0" if counter > 0: if not is_prompt and char == prompt[counter]: counter += 1 - elif is_prompt and char == respomse[counter]: + elif is_prompt and char == response[counter]: counter += 1 else: counter = 0 - if counter == len(prompt) and not is_prompt: - # mask[i - counter + 1 : i + 1] = ["1"] * counter + if len(prompt) > 0 and counter == len(prompt) and not is_prompt: mask[i - counter + 1 : i + 1] = ["0"] * counter counter = 0 is_prompt = True - if counter == len(respomse) and is_prompt: - # mask[i - counter + 1 : i + 1] = ["0"] * counter + if counter == len(response) and is_prompt: mask[i - counter + 1 : i + 1] = ["1"] * counter counter = 0 is_prompt = False - if prompt.startswith(char) or respomse.startswith(char): + if prompt.startswith(char) or response.startswith(char): counter = 1 return "".join(mask) diff --git a/tests/sparseml/transformers/utils/test_helpers.py b/tests/sparseml/transformers/utils/test_helpers.py index 94f7c0205e5..980fc9f2dbb 100644 --- a/tests/sparseml/transformers/utils/test_helpers.py +++ b/tests/sparseml/transformers/utils/test_helpers.py @@ -170,12 +170,12 @@ def test_save_zoo_directory(tmp_path, stub): @pytest.mark.parametrize( - "string, prompt, censor, expected_mask", + "string, response, prompt, expected_mask", [ ( ("[foo]hello\n\n" "[bar]world"), - "[foo]", "[bar]", + "[foo]", ("000000000000" "1111111111"), ), ( @@ -185,8 +185,8 @@ def test_save_zoo_directory(tmp_path, stub): "[Instruction]What about Java" # 28 "[Response]Meh" # 13 ), - "[Instruction]", "[Response]", + "[Instruction]", ( "000000000000000000000000" # 24 "11111111111111111" # 17 @@ -195,13 +195,12 @@ def test_save_zoo_directory(tmp_path, stub): ), ), ( - ("hello\n\n" "[bar]world"), - "[foo]", + ("[foo]hello\n\n" "[bar]world"), "[bar]", - ("000000000" "1111111111"), - ( 1111111111) + None, + ("000000000000" "1111111111"), ), ], ) -def test_generate_mask(string, prompt, censor, expected_mask): - assert generate_mask(string, prompt, censor) == expected_mask +def test_generate_mask(string, response, prompt, expected_mask): + assert generate_mask(string, response, prompt) == expected_mask From d819f8d890dbdc2e1cf5eef35a49c6d045a6c43b Mon Sep 17 00:00:00 2001 From: George Ohashi Date: Mon, 1 Apr 2024 17:50:56 +0000 Subject: [PATCH 7/8] test case with no starting prompt tag --- src/sparseml/transformers/utils/helpers.py | 1 - tests/sparseml/transformers/utils/test_helpers.py | 6 ++++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/src/sparseml/transformers/utils/helpers.py b/src/sparseml/transformers/utils/helpers.py index f8f4b67fad0..4177d6b09a5 100644 --- a/src/sparseml/transformers/utils/helpers.py +++ b/src/sparseml/transformers/utils/helpers.py @@ -606,5 +606,4 @@ def generate_mask(string: str, response: str, prompt: Optional[str] = None) -> s if prompt.startswith(char) or response.startswith(char): counter = 1 - return "".join(mask) diff --git a/tests/sparseml/transformers/utils/test_helpers.py b/tests/sparseml/transformers/utils/test_helpers.py index 980fc9f2dbb..0d77a6ee7d2 100644 --- a/tests/sparseml/transformers/utils/test_helpers.py +++ b/tests/sparseml/transformers/utils/test_helpers.py @@ -200,6 +200,12 @@ def test_save_zoo_directory(tmp_path, stub): None, ("000000000000" "1111111111"), ), + ( + ("hello\n\n" "[bar]world"), + "[bar]", + None, + ("0000000" "1111111111"), + ), ], ) def test_generate_mask(string, response, prompt, expected_mask): From 68f4eacf7fdbaba86d39b33dc0b4f8710ac91d35 Mon Sep 17 00:00:00 2001 From: George Ohashi Date: Fri, 5 Apr 2024 17:17:40 +0000 Subject: [PATCH 8/8] comments --- src/sparseml/export/validators.py | 3 +- src/sparseml/transformers/utils/helpers.py | 8 ++--- .../transformers/finetune/test_finetune.py | 33 +++++++++++++++++++ .../transformers/utils/test_helpers.py | 11 ++++++- 4 files changed, 48 insertions(+), 7 deletions(-) diff --git a/src/sparseml/export/validators.py b/src/sparseml/export/validators.py index 52c9fa05ee0..f513bda21aa 100644 --- a/src/sparseml/export/validators.py +++ b/src/sparseml/export/validators.py @@ -17,8 +17,9 @@ import os.path from collections import OrderedDict from pathlib import Path -from typing import Callable, List, Optional, Union +from typing import Callable, List, Optional from typing import OrderedDict as OrderedDictType +from typing import Union import numpy import onnx diff --git a/src/sparseml/transformers/utils/helpers.py b/src/sparseml/transformers/utils/helpers.py index 9cc4481c694..48aec7f763a 100644 --- a/src/sparseml/transformers/utils/helpers.py +++ b/src/sparseml/transformers/utils/helpers.py @@ -559,7 +559,7 @@ def fetch_recipe_path(target: str): return recipe_path -def generate_mask(string: str, response: str, prompt: Optional[str] = None) -> str: +def generate_mask(string: str, response: str, prompt: str = "") -> str: """ Generate a mask based on provided prompt and response strings to obscure characters in the input string. Prompt will be masked and string in response @@ -576,18 +576,16 @@ def generate_mask(string: str, response: str, prompt: Optional[str] = None) -> s characters and '0' indicates obscured characters. """ - if prompt is None: - prompt = "" mask = ["1"] * len(string) - is_prompt = True + is_prompt = False if string.startswith(response) else True counter = 0 for i, char in enumerate(string): if is_prompt: mask[i] = "0" if counter > 0: - if not is_prompt and char == prompt[counter]: + if not is_prompt and len(prompt) > 1 and char == prompt[counter]: counter += 1 elif is_prompt and char == response[counter]: counter += 1 diff --git a/tests/sparseml/transformers/finetune/test_finetune.py b/tests/sparseml/transformers/finetune/test_finetune.py index b2a437751da..aa9f84e02b1 100644 --- a/tests/sparseml/transformers/finetune/test_finetune.py +++ b/tests/sparseml/transformers/finetune/test_finetune.py @@ -32,6 +32,7 @@ oneshot, train, ) +from sparseml.transformers.utils.helpers import generate_mask def test_oneshot_and_finetune(tmp_path: Path): @@ -319,3 +320,35 @@ def test_oneshot_with_modifier_object(tmp_path: Path): splits=splits, oneshot_device=device, ) + + +def test_finetune_wout_recipe_with_mask(tmp_path: Path): + recipe_str = None + model = "Xenova/llama2.c-stories15M" + device = "cuda:0" + if not torch.cuda.is_available(): + device = "cpu" + dataset = "open_platypus" + concatenate_data = False + output_dir = tmp_path + max_steps = 50 + splits = "train" + + def preprocessing_func(example): + example["text"] = "[foo]" + example["text"] + "[bar] mask this" + example["mask"] = generate_mask( + example["text"], response="[bar]", prompt="[foo]" + ) + return example + + train( + model=model, + dataset=dataset, + output_dir=output_dir, + recipe=recipe_str, + max_steps=max_steps, + concatenate_data=concatenate_data, + splits=splits, + oneshot_device=device, + preprocessing_func=preprocessing_func, + ) diff --git a/tests/sparseml/transformers/utils/test_helpers.py b/tests/sparseml/transformers/utils/test_helpers.py index 0d77a6ee7d2..04ebf735d7e 100644 --- a/tests/sparseml/transformers/utils/test_helpers.py +++ b/tests/sparseml/transformers/utils/test_helpers.py @@ -206,7 +206,16 @@ def test_save_zoo_directory(tmp_path, stub): None, ("0000000" "1111111111"), ), + ( + ("[bar]world" "[foo]hello\n\n" "[bar]world"), + "[bar]", + "[foo]", + ("1111111111" "000000000000" "1111111111"), + ), ], ) def test_generate_mask(string, response, prompt, expected_mask): - assert generate_mask(string, response, prompt) == expected_mask + if prompt is not None: + assert generate_mask(string, response, prompt) == expected_mask + else: + assert generate_mask(string, response) == expected_mask