Skip to content

Commit

Permalink
add _accept_string as opposed to _accept_prefix
Browse files Browse the repository at this point in the history
  • Loading branch information
Saibo-creator committed Feb 29, 2024
1 parent 258de17 commit 9a5d840
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 7 deletions.
8 changes: 6 additions & 2 deletions docs/debugging_custom_grammars.md
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ If the grammar can be parsed, it means that it is syntactically correct.

After you have checked that the grammar can be parsed, you can test it with a simple input to see if it can generate the expected output.
We provide a simple script to do this:

```python
from transformers_cfg.parser import parse_ebnf
from transformers_cfg.recognizer import GrammarRecognizer
Expand All @@ -100,22 +101,25 @@ recognizer = GrammarRecognizer(parsed_grammar.grammar_encoding, start_rule_id)

# Test the grammar with a simple input
json_input = '{"foo": "bar", "baz": "bat"}'
is_accepted = recognizer._accept_string(json_input, recognizer.stacks)
is_accepted = recognizer._accept_prefix(json_input, recognizer.stacks)
print(is_accepted)
```

If the above script returns `True`, it means that the grammar can recognize the input string.
If it returns `False`, it means that the grammar cannot recognize the input string.
In this case, you need to check in which step the input string is rejected.
N.B. the recognizer can accept partial input, so you can try the following:

```python
json_input = '{"foo": "bar"'
is_accepted = recognizer._accept_string(json_input, recognizer.stacks)
is_accepted = recognizer._accept_prefix(json_input, recognizer.stacks)
print(is_accepted)
```

This helps you to see where the grammar fails to recognize the input string.

If you want to check if the sentence is complete or not, you can use `_accept_string` method, which returns `True` if the input string is complete and `False` otherwise.

## DEBUG mode

You can enable the DEBUG mode to see the parsing process of the input string.
Expand Down
12 changes: 11 additions & 1 deletion tests/test_string_recognizer/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,17 +53,27 @@ def test_minimal_json_object(self):

# accept_state = AcceptState.empty_state()

self.assertEqual(
is_json_parsable(json),
self.recognizer._accept_prefix(json),
)

self.assertEqual(
is_json_parsable(json),
self.recognizer._accept_string(json),
)

prefix_json = json[: len(json) // 2]
self.assertTrue(self.recognizer._accept_prefix(prefix_json))

self.assertFalse(self.recognizer._accept_string(prefix_json))

def test_systematic_examples(self):

for name, json_object in json_examples.items():
# accept_state = AcceptState.empty_state()
self.assertEqual(
is_json_parsable(json_object),
self.recognizer._accept_string(json_object),
self.recognizer._accept_prefix(json_object),
msg=f"Failed on {name}, {json_object}",
)
2 changes: 1 addition & 1 deletion tests/test_string_recognizer/test_json_arr.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,6 @@ def test_minimal_json_array(self):
# accept_state = AcceptState.empty_state()
self.assertEqual(
is_json_parsable(json),
recognizer._accept_string(json),
recognizer._accept_prefix(json),
f"Failed on {json}",
)
4 changes: 2 additions & 2 deletions tests/test_string_recognizer/test_unicode.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def test_accept_japanese(self):

# accept_state = AcceptState.empty_state()

self.assertTrue(recognizer._accept_string(japanese))
self.assertTrue(recognizer._accept_prefix(japanese))

def test_emoji(self):
"""
Expand All @@ -42,4 +42,4 @@ def test_emoji(self):

recognizer = StringRecognizer(parsed_grammar.grammar_encoding, start_rule_id)

self.assertTrue(recognizer._accept_string(emoji))
self.assertTrue(recognizer._accept_prefix(emoji))
11 changes: 10 additions & 1 deletion transformers_cfg/recognizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,12 +370,21 @@ def _consume_string(self, string: str, accept_state: AcceptState):
stacks = self._consume_code_points(code_points, accept_state.stacks)
return AcceptState(stacks, accept_state.partial_utf8)

def _accept_string(self, string: str, accept_state: AcceptState = None):
def _accept_prefix(self, string: str, accept_state: AcceptState = None):
if accept_state is None:
accept_state = self.get_initial_accept_state()
new_accept_state = self._consume_string(string, accept_state)
return len(new_accept_state.stacks) > 0

def _accept_string(self, string: str, accept_state: AcceptState = None):
if accept_state is None:
accept_state = self.get_initial_accept_state()
new_accept_state = self._consume_string(string, accept_state)
at_least_one_stack_is_empty = any(
len(stack) == 0 for stack in new_accept_state.stacks
)
return at_least_one_stack_is_empty

def _can_stop(self, stacks: List[List[int]]):
# This happens in practice, but maybe it shouldn't? TODO
if len(stacks) == 0:
Expand Down

0 comments on commit 9a5d840

Please sign in to comment.