
Commit

horheynm committed Apr 5, 2024
1 parent 1392a35 commit 68f4eac
Showing 4 changed files with 48 additions and 7 deletions.
3 changes: 2 additions & 1 deletion src/sparseml/export/validators.py
@@ -17,8 +17,9 @@
 import os.path
 from collections import OrderedDict
 from pathlib import Path
-from typing import Callable, List, Optional, Union
+from typing import Callable, List, Optional
 from typing import OrderedDict as OrderedDictType
+from typing import Union

 import numpy
 import onnx
8 changes: 3 additions & 5 deletions src/sparseml/transformers/utils/helpers.py
@@ -559,7 +559,7 @@ def fetch_recipe_path(target: str):
     return recipe_path


-def generate_mask(string: str, response: str, prompt: Optional[str] = None) -> str:
+def generate_mask(string: str, response: str, prompt: str = "") -> str:
     """
     Generate a mask based on provided prompt and response strings to obscure
     characters in the input string. Prompt will be masked and string in response
@@ -576,18 +576,16 @@ def generate_mask(string: str, response: str, prompt: Optional[str] = None) -> s
     characters and '0' indicates obscured characters.
     """
-    if prompt is None:
-        prompt = ""

     mask = ["1"] * len(string)
-    is_prompt = True
+    is_prompt = False if string.startswith(response) else True
     counter = 0
     for i, char in enumerate(string):
         if is_prompt:
             mask[i] = "0"

         if counter > 0:
-            if not is_prompt and char == prompt[counter]:
+            if not is_prompt and len(prompt) > 1 and char == prompt[counter]:
                 counter += 1
             elif is_prompt and char == response[counter]:
                 counter += 1
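Note: a minimal usage sketch of the updated helper (assuming sparseml is installed); the input string and expected mask are copied from the new parametrized case added to tests/sparseml/transformers/utils/test_helpers.py below, not re-derived here.

from sparseml.transformers.utils.helpers import generate_mask

# Response spans (the "[bar]..." sections) stay visible as '1'; the prompt span
# ("[foo]hello\n\n") is obscured as '0'; one flag per character of the input.
string = "[bar]world" + "[foo]hello\n\n" + "[bar]world"
mask = generate_mask(string, response="[bar]", prompt="[foo]")
assert mask == "1111111111" + "000000000000" + "1111111111"
assert len(mask) == len(string)

# With the new default prompt="", the helper can also be called without a prompt,
# matching the existing prompt=None test case.
mask_without_prompt = generate_mask(string, response="[bar]")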
33 changes: 33 additions & 0 deletions tests/sparseml/transformers/finetune/test_finetune.py
@@ -32,6 +32,7 @@
     oneshot,
     train,
 )
+from sparseml.transformers.utils.helpers import generate_mask


 def test_oneshot_and_finetune(tmp_path: Path):
@@ -319,3 +320,35 @@ def test_oneshot_with_modifier_object(tmp_path: Path):
         splits=splits,
         oneshot_device=device,
     )
+
+
+def test_finetune_wout_recipe_with_mask(tmp_path: Path):
+    recipe_str = None
+    model = "Xenova/llama2.c-stories15M"
+    device = "cuda:0"
+    if not torch.cuda.is_available():
+        device = "cpu"
+    dataset = "open_platypus"
+    concatenate_data = False
+    output_dir = tmp_path
+    max_steps = 50
+    splits = "train"
+
+    def preprocessing_func(example):
+        example["text"] = "[foo]" + example["text"] + "[bar] mask this"
+        example["mask"] = generate_mask(
+            example["text"], response="[bar]", prompt="[foo]"
+        )
+        return example
+
+    train(
+        model=model,
+        dataset=dataset,
+        output_dir=output_dir,
+        recipe=recipe_str,
+        max_steps=max_steps,
+        concatenate_data=concatenate_data,
+        splits=splits,
+        oneshot_device=device,
+        preprocessing_func=preprocessing_func,
+    )
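Note: a rough sketch of what the preprocessing_func above does to a single dataset row; the row text here is purely illustrative, and the resulting mask values depend on generate_mask's matching logic, so they are not asserted.

from sparseml.transformers.utils.helpers import generate_mask

example = {"text": "What is magnitude pruning?"}  # hypothetical dataset row
example["text"] = "[foo]" + example["text"] + "[bar] mask this"
example["mask"] = generate_mask(example["text"], response="[bar]", prompt="[foo]")
# example["mask"] is a per-character string of '0'/'1' flags aligned
# one-to-one with example["text"].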
11 changes: 10 additions & 1 deletion tests/sparseml/transformers/utils/test_helpers.py
@@ -206,7 +206,16 @@ def test_save_zoo_directory(tmp_path, stub):
             None,
             ("0000000" "1111111111"),
         ),
+        (
+            ("[bar]world" "[foo]hello\n\n" "[bar]world"),
+            "[bar]",
+            "[foo]",
+            ("1111111111" "000000000000" "1111111111"),
+        ),
     ],
 )
 def test_generate_mask(string, response, prompt, expected_mask):
-    assert generate_mask(string, response, prompt) == expected_mask
+    if prompt is not None:
+        assert generate_mask(string, response, prompt) == expected_mask
+    else:
+        assert generate_mask(string, response) == expected_mask
