diff --git a/src/sparseml/export/validators.py b/src/sparseml/export/validators.py
index 52c9fa05ee0..f513bda21aa 100644
--- a/src/sparseml/export/validators.py
+++ b/src/sparseml/export/validators.py
@@ -17,8 +17,9 @@
 import os.path
 from collections import OrderedDict
 from pathlib import Path
-from typing import Callable, List, Optional, Union
+from typing import Callable, List, Optional
 from typing import OrderedDict as OrderedDictType
+from typing import Union
 
 import numpy
 import onnx
diff --git a/src/sparseml/transformers/utils/helpers.py b/src/sparseml/transformers/utils/helpers.py
index 9cc4481c694..48aec7f763a 100644
--- a/src/sparseml/transformers/utils/helpers.py
+++ b/src/sparseml/transformers/utils/helpers.py
@@ -559,7 +559,7 @@ def fetch_recipe_path(target: str):
     return recipe_path
 
 
-def generate_mask(string: str, response: str, prompt: Optional[str] = None) -> str:
+def generate_mask(string: str, response: str, prompt: str = "") -> str:
     """
     Generate a mask based on provided prompt and response strings to obscure
     characters in the input string. Prompt will be masked and string in response
@@ -576,18 +576,16 @@ def generate_mask(string: str, response: str, prompt: Optional[str] = None) -> s
         characters and '0' indicates obscured characters.
     """
 
-    if prompt is None:
-        prompt = ""
     mask = ["1"] * len(string)
 
-    is_prompt = True
+    is_prompt = False if string.startswith(response) else True
     counter = 0
     for i, char in enumerate(string):
         if is_prompt:
             mask[i] = "0"
 
         if counter > 0:
-            if not is_prompt and char == prompt[counter]:
+            if not is_prompt and len(prompt) > 1 and char == prompt[counter]:
                 counter += 1
             elif is_prompt and char == response[counter]:
                 counter += 1
diff --git a/tests/sparseml/transformers/finetune/test_finetune.py b/tests/sparseml/transformers/finetune/test_finetune.py
index b2a437751da..aa9f84e02b1 100644
--- a/tests/sparseml/transformers/finetune/test_finetune.py
+++ b/tests/sparseml/transformers/finetune/test_finetune.py
@@ -32,6 +32,7 @@
     oneshot,
     train,
 )
+from sparseml.transformers.utils.helpers import generate_mask
 
 
 def test_oneshot_and_finetune(tmp_path: Path):
@@ -319,3 +320,35 @@ def test_oneshot_with_modifier_object(tmp_path: Path):
         splits=splits,
         oneshot_device=device,
     )
+
+
+def test_finetune_wout_recipe_with_mask(tmp_path: Path):
+    recipe_str = None
+    model = "Xenova/llama2.c-stories15M"
+    device = "cuda:0"
+    if not torch.cuda.is_available():
+        device = "cpu"
+    dataset = "open_platypus"
+    concatenate_data = False
+    output_dir = tmp_path
+    max_steps = 50
+    splits = "train"
+
+    def preprocessing_func(example):
+        example["text"] = "[foo]" + example["text"] + "[bar] mask this"
+        example["mask"] = generate_mask(
+            example["text"], response="[bar]", prompt="[foo]"
+        )
+        return example
+
+    train(
+        model=model,
+        dataset=dataset,
+        output_dir=output_dir,
+        recipe=recipe_str,
+        max_steps=max_steps,
+        concatenate_data=concatenate_data,
+        splits=splits,
+        oneshot_device=device,
+        preprocessing_func=preprocessing_func,
+    )
diff --git a/tests/sparseml/transformers/utils/test_helpers.py b/tests/sparseml/transformers/utils/test_helpers.py
index 0d77a6ee7d2..04ebf735d7e 100644
--- a/tests/sparseml/transformers/utils/test_helpers.py
+++ b/tests/sparseml/transformers/utils/test_helpers.py
@@ -206,7 +206,16 @@ def test_save_zoo_directory(tmp_path, stub):
             None,
             ("0000000" "1111111111"),
         ),
+        (
+            ("[bar]world" "[foo]hello\n\n" "[bar]world"),
+            "[bar]",
+            "[foo]",
+            ("1111111111" "000000000000" "1111111111"),
+        ),
     ],
 )
 def test_generate_mask(string, response, prompt, expected_mask):
-    assert generate_mask(string, response, prompt) == expected_mask
+    if prompt is not None:
+        assert generate_mask(string, response, prompt) == expected_mask
+    else:
+        assert generate_mask(string, response) == expected_mask
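
Reviewer note: a minimal sketch of the masking convention that the new parametrized case and the finetune test rely on, reusing the exact strings and expected mask from the test data above (no behavior beyond what the diff shows is assumed). Characters belonging to the prompt segment come back as "0", characters in the response segment stay "1".

```python
from sparseml.transformers.utils.helpers import generate_mask

# Same string the new parametrized case builds:
# response segment, prompt segment, then another response segment.
text = "[bar]world" "[foo]hello\n\n" "[bar]world"

# Prompt characters are obscured ("0"); response characters are kept ("1").
mask = generate_mask(text, response="[bar]", prompt="[foo]")

assert mask == "1111111111" "000000000000" "1111111111"
assert len(mask) == len(text)  # mask always matches the input length
```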