diff --git a/.github/codecov.yml b/.github/codecov.yml
index 854f75b..504dfe7 100644
--- a/.github/codecov.yml
+++ b/.github/codecov.yml
@@ -23,6 +23,4 @@ comment:
 ignore:
   - "tests/**"
-  - "test_*.py"
-  - "**/__main__.py"
-  - "**/__init__.py"
\ No newline at end of file
+  - "test_*.py"
\ No newline at end of file
diff --git a/README.md b/README.md
index 94d72a5..7e36452 100644
--- a/README.md
+++ b/README.md
@@ -10,19 +10,50 @@
 
 ### Augmented Recurrent Neural G2P with Inflectional Orthography
 
-Grapheme-to-phoneme (G2P) conversion is the process of converting the written form of words (Graphemes) to their
-pronunciations (Phonemes). Deep learning models for text-to-speech (TTS) synthesis using phoneme / mixed symbols
-typically require a G2P conversion method for both training and inference.
-
-Aquila Resolve presents a new approach for accurate and efficient English G2P resolution.
-Input text graphemes are translated into their phonetic pronunciations,
-using [ARPAbet](https://wikipedia.org/wiki/ARPABET) as the [phoneme symbol set](#Symbol-Set).
+Aquila Resolve presents a new approach for accurate and efficient English to
+[ARPAbet](https://wikipedia.org/wiki/ARPABET) G2P resolution.
 The pipeline employs a context layer, multiple transformer and n-gram morpho-orthographical search layers,
-and an autoregressive recurrent neural transformer base.
-
-The current implementation offers state-of-the-art accuracy for out-of-vocabulary (OOV) words, as well as contextual
+and an autoregressive recurrent neural transformer base. The current implementation offers state-of-the-art accuracy for out-of-vocabulary (OOV) words, as well as contextual
 analysis for correct inferencing of [English Heteronyms](https://en.wikipedia.org/wiki/Heteronym_(linguistics)).
 
+The package is offered in a pre-trained state that is ready for use as a dependency or in
+notebook environments. No additional resources are needed, other than the model checkpoint, which is
+automatically downloaded on first use. See [Installation](#Installation) for more information.
+
+### 1. Dynamic Word Mappings based on context:
+
+```pycon
+g2p.convert('I read the book, did you read it?')
+# >> '{AY1} {R EH1 D} {DH AH0} {B UH1 K}, {D IH1 D} {Y UW1} {R IY1 D} {IH1 T}?'
+```
+```pycon
+g2p.convert('The researcher was to subject the subject to a test.')
+# >> '{DH AH0} {R IY1 S ER0 CH ER0} {W AA1 Z} {T UW1} {S AH0 B JH EH1 K T} {DH AH0} {S AH1 B JH IH0 K T} {T UW1} {AH0} {T EH1 S T}.'
+```
+
|                | 'The subject was told to read. Eight records were read in total.'                                 |
|----------------|----------------------------------------------------------------------------------------------------|
| *Ground Truth* | The `S AH1 B JH IH0 K T` was told to `R IY1 D`. Eight `R EH1 K ER0 D Z` were `R EH1 D` in total.   |
| Aquila Resolve | The `S AH1 B JH IH0 K T` was told to `R IY1 D`. Eight `R EH1 K ER0 D Z` were `R EH1 D` in total.   |
| [Deep Phonemizer](https://github.com/as-ideas/DeepPhonemizer)<br>
([en_us_cmudict_forward.pt](https://github.com/as-ideas/DeepPhonemizer#pretrained-models)) | The **S AH B JH EH K T** was told to **R EH D**. Eight **R AH K AO R D Z** were `R EH D` in total. | +| [CMUSphinx Seq2Seq](https://github.com/cmusphinx/g2p-seq2seq)
([checkpoint](https://github.com/cmusphinx/g2p-seq2seq#running-g2p)) | The `S AH1 B JH IH0 K T` was told to `R IY1 D`. Eight **R IH0 K AO1 R D Z** were **R IY1 D** in total. | +| [ESpeakNG](https://github.com/espeak-ng/espeak-ng)
(with [phonecodes](https://github.com/jhasegaw/phonecodes)) | The **S AH1 B JH EH K T** was told to `R IY1 D`. Eight `R EH1 K ER0 D Z` were **R IY1 D** in total. |
+
+### 2. Leading Accuracy for unseen words:
+
+```pycon
+g2p.convert('Did you kalpe the Hevinet?')
+# >> '{D IH1 D} {Y UW1} {K AE1 L P} {DH AH0} {HH EH1 V AH0 N EH2 T}?'
+```
+
|                | "tensorflow"              | "agglomerative"                  | "necrophages"                    |
|----------------|---------------------------|----------------------------------|----------------------------------|
| Aquila Resolve | `T EH1 N S ER0 F L OW2`   | `AH0 G L AA1 M ER0 EY2 T IH0 V`  | `N EH1 K R OW0 F EY2 JH IH0 Z`   |
| [Deep Phonemizer](https://github.com/as-ideas/DeepPhonemizer)<br>
([en_us_cmudict_forward.pt](https://github.com/as-ideas/DeepPhonemizer#pretrained-models)) | `T EH N S ER F L OW` | **AH G L AA M ER AH T IH V** | `N EH K R OW F EY JH IH Z` | +| [CMUSphinx Seq2Seq](https://github.com/cmusphinx/g2p-seq2seq)
([checkpoint](https://github.com/cmusphinx/g2p-seq2seq#running-g2p)) | **T EH1 N S ER0 L OW0 F** | **AH0 G L AA1 M ER0 T IH0 V** | **N AE1 K R AH0 F IH0 JH IH0 Z** | +| [ESpeakNG](https://github.com/espeak-ng/espeak-ng)
(with [phonecodes](https://github.com/jhasegaw/phonecodes)) | **T EH1 N S OW0 R F L OW2** | **AA G L AA1 M ER0 R AH0 T IH2 V** | **N EH1 K R AH0 F IH JH EH0 Z** |
+
+
 ## Installation
 
 ```bash
 pip install aquila-resolve
 ```
@@ -32,8 +63,8 @@
 > automatically downloaded on the first use of relevant public methods that require inferencing. For example,
 > when [instantiating `G2p`](#Usage). You can also start this download manually by calling `Aquila_Resolve.download()`.
 >
-> If you are in an environment where remote file downloads are not possible, you can also download the checkpoint
-> manually and instantiate `G2p` with the flag: `G2p(custom_checkpoint='path/model.pt')`
+> If you are in an environment where remote file downloads are not possible, you can also transfer the checkpoint
+> manually, placing `model.pt` within the `Aquila_Resolve.data` module folder.
 
 ## Usage
 
@@ -48,10 +79,10 @@
 g2p.convert('The book costs $5, will you read it?')
 ```
 
 > Additional optional parameters are available when defining a `G2p` instance:
 
-| Parameter          | Default  | Description                                                                                                                                                               |
-|--------------------|----------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
-| `device`           | `'cpu'`  | Device for Pytorch inference model                                                                                                                                        |
-| `process_numbers`  | `True`   | Toggles conversion of some numbers and symbols to their spoken pronunciation forms. See [numbers.py](src/Aquila_Resolve/text/numbers.py) for details on what is covered.  |
+| Parameter         | Default | Description                                                                                                                                                                |
+|-------------------|---------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
+| `device`          | `'cpu'` | Device for PyTorch inference model. GPU is supported using `'cuda'`                                                                                                        |
+| `process_numbers` | `True`  | Toggles conversion of some numbers and symbols to their spoken pronunciation forms. See [numbers.py](src/Aquila_Resolve/text/numbers.py) for details on what is covered.   |
 
 ## Model Architecture
diff --git a/setup.cfg b/setup.cfg
index f4e2dd4..97c127a 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -1,12 +1,12 @@
 [metadata]
 name = Aquila-Resolve
-version = 0.1.2-dev1
+version = 0.1.2
 author = ionite
 author_email = dev@ionite.io
 description = Augmented Recurrent Neural Grapheme-to-Phoneme conversion with Inflectional Orthography.
 long_description = file: README.md
 long_description_content_type = text/markdown
-url = https://github.com/ionite34/Aquila-Resolve'
+url = https://github.com/ionite34/Aquila-Resolve
 license = Apache 2.0
 license_file = LICENSE
 classifiers =
diff --git a/src/Aquila_Resolve/__init__.py b/src/Aquila_Resolve/__init__.py
index 16af5b0..fb63580 100644
--- a/src/Aquila_Resolve/__init__.py
+++ b/src/Aquila_Resolve/__init__.py
@@ -4,7 +4,7 @@
 Grapheme to Phoneme Resolver
 """
 
-__version__ = "0.1.2-dev1"
+__version__ = "0.1.2"
 
 from .g2p import G2p
 from .data.remote import download
diff --git a/src/Aquila_Resolve/data/__init__.py b/src/Aquila_Resolve/data/__init__.py
index fd5de86..8bff045 100644
--- a/src/Aquila_Resolve/data/__init__.py
+++ b/src/Aquila_Resolve/data/__init__.py
@@ -2,7 +2,7 @@
 
 if sys.version_info < (3, 9):
     # In Python versions below 3.9, this is needed
-    from importlib_resources import files
+    from importlib_resources import files  # pragma: no cover
 else:
     # Since python 3.9+, importlib.resources.files is built-in
     from importlib.resources import files
diff --git a/src/Aquila_Resolve/g2p.py b/src/Aquila_Resolve/g2p.py
index ec9ca6b..4b99000 100644
--- a/src/Aquila_Resolve/g2p.py
+++ b/src/Aquila_Resolve/g2p.py
@@ -9,15 +9,14 @@
 
 from nltk.stem.snowball import SnowballStemmer
 from .h2p import H2p
-from .h2p import replace_first
+from .text.replace import replace_first
 from .format_ph import with_cb
-# from .dict_reader import DictReader
 from .static_dict import get_cmudict
 from .text.numbers import normalize_numbers
 from .filter import filter_text
 from .processors import Processor
 from .infer import Infer
-from .symbols import contains_alpha, brackets_match
+from .symbols import contains_alpha, valid_braces
 
 re_digit = re.compile(r"\((\d+)\)")
 re_bracket_with_digit = re.compile(r"\(.*\)")
@@ -143,13 +142,11 @@ def convert(self, text: str, convert_num: bool = True) -> str | None:
 
         :param convert_num: True to convert numbers to words
         """
-        # Check that every {} bracket is paired
-        check = brackets_match(text)
-        if check is not None:
-            raise ValueError(check)
-
-        # Normalize numbers, if enabled
+        # Check that every {} brace pair is valid, regardless of number conversion
+        valid_braces(text, raise_on_invalid=True)
+
+        # Convert numbers, if enabled
         if convert_num:
             text = normalize_numbers(text)
 
         # Filter and Tokenize
diff --git a/src/Aquila_Resolve/h2p.py b/src/Aquila_Resolve/h2p.py
index e013a53..4be9df6 100644
--- a/src/Aquila_Resolve/h2p.py
+++ b/src/Aquila_Resolve/h2p.py
@@ -1,27 +1,18 @@
-import nltk
-import re
 from nltk.tokenize import TweetTokenizer
 from nltk import pos_tag
 from nltk import pos_tag_sents
 from .dictionary import Dictionary
 from .filter import filter_text as ft
+from .text.replace import replace_first
 from . import format_ph as ph
 
-# Check that the nltk data is downloaded, if not, download it
+# Check required nltk data exists, if not, download it
 try:
-    nltk.data.find('taggers/averaged_perceptron_tagger.zip')
-except LookupError:
-    nltk.download('averaged_perceptron_tagger')
-
-
-# Method to use Regex to replace the first instance of a word with its phonemes
-def replace_first(target, replacement, text):
-    # Skip if target invalid
-    if target is None or target == '':
-        return text
-    # Replace the first instance of a word with its phonemes
-    # return re.sub(r'(?i)\b' + target + r'\b', replacement, text, 1)
-    return re.sub(r'(?<!{)\b' + target + r'\b(?!})', replacement, text, 1, flags=re.IGNORECASE)
diff --git a/src/Aquila_Resolve/infer.py b/src/Aquila_Resolve/infer.py
--- a/src/Aquila_Resolve/infer.py
+++ b/src/Aquila_Resolve/infer.py
-    def __call__(self, words: list[str]) -> list[str]:
+    def __call__(self, text: list[str]) -> list[str]:
         """
         Infers phonemes for a list of words.
-        :param words: list of words
-        :return: dict of {word: phonemes}
+        :param text: list of words
+        :return: list of phoneme strings
         """
-        res = self.model.phonemise_list(words, lang=self.lang, batch_size=self.batch_size).phonemes
+        res = self.model.phonemise_list(text, lang=self.lang, batch_size=self.batch_size).phonemes
         # Replace all occurrences of '][' with spaces, remove remaining brackets
         res = [r.replace('][', ' ').replace('[', '').replace(']', '') for r in res]
         return res
diff --git a/src/Aquila_Resolve/models/__init__.py b/src/Aquila_Resolve/models/__init__.py
index e3d4f0e..ac7d693 100644
--- a/src/Aquila_Resolve/models/__init__.py
+++ b/src/Aquila_Resolve/models/__init__.py
@@ -2,7 +2,7 @@
 
 if sys.version_info < (3, 9):
     # In Python versions below 3.9, this is needed
-    from importlib_resources import files
+    from importlib_resources import files  # pragma: no cover
 else:
     # Since python 3.9+, importlib.resources.files is built-in
     from importlib.resources import files
diff --git a/src/Aquila_Resolve/models/dp/model/model.py b/src/Aquila_Resolve/models/dp/model/model.py
index 46d4d18..5b66235 100644
--- a/src/Aquila_Resolve/models/dp/model/model.py
+++ b/src/Aquila_Resolve/models/dp/model/model.py
@@ -4,8 +4,7 @@
 
 import torch
 import torch.nn as nn
-from torch.nn import TransformerEncoderLayer, LayerNorm, TransformerEncoder
-from .utils import get_dedup_tokens, _make_len_mask, _generate_square_subsequent_mask, PositionalEncoding
+from .utils import _make_len_mask, _generate_square_subsequent_mask, PositionalEncoding
 
 from ..preprocessing.text import Preprocessor
 
@@ -17,7 +16,7 @@ def is_autoregressive(self) -> bool:
         """
         Returns: bool: Whether the model is autoregressive.
         """
-        return self in {ModelType.AUTOREG_TRANSFORMER}
+        return self in {ModelType.AUTOREG_TRANSFORMER}  # pragma: no cover
 
 
 class Model(torch.nn.Module, ABC):
@@ -39,91 +38,7 @@ def generate(self, batch: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.
             Tuple[torch.Tensor, torch.Tensor]: The predictions. The first element is a tensor (phoneme tokens)
             and the second element is a tensor (phoneme token probabilities)
         """
-        pass
-
-
-class ForwardTransformer(Model):
-
-    def __init__(self,
-                 encoder_vocab_size: int,
-                 decoder_vocab_size: int,
-                 d_model=512,
-                 d_fft=1024,
-                 layers=4,
-                 dropout=0.1,
-                 heads=1) -> None:
-        super().__init__()
-
-        self.d_model = d_model
-
-        self.embedding = nn.Embedding(encoder_vocab_size, d_model)
-        self.pos_encoder = PositionalEncoding(d_model, dropout)
-
-        encoder_layer = TransformerEncoderLayer(d_model=d_model,
-                                                nhead=heads,
-                                                dim_feedforward=d_fft,
-                                                dropout=dropout,
-                                                activation='relu')
-        encoder_norm = LayerNorm(d_model)
-        self.encoder = TransformerEncoder(encoder_layer=encoder_layer,
-                                          num_layers=layers,
-                                          norm=encoder_norm)
-
-        self.fc_out = nn.Linear(d_model, decoder_vocab_size)
-
-    def forward(self,
-                batch: Dict[str, torch.Tensor]) -> torch.Tensor:  # shape: [N, T]
-        """
-        Forward pass of the model on a data batch.
-
-        Args:
-            batch (Dict[str, torch.Tensor]): Input batch entry 'text' (text tensor).
-
-        Returns:
-            Tensor: Predictions.
-        """
-
-        x = batch['text']
-        x = x.transpose(0, 1)  # shape: [T, N]
-        src_pad_mask = _make_len_mask(x).to(x.device)
-        x = self.embedding(x)
-        x = self.pos_encoder(x)
-        x = self.encoder(x, src_key_padding_mask=src_pad_mask)
-        x = self.fc_out(x)
-        x = x.transpose(0, 1)
-        return x
-
-    @torch.jit.export
-    def generate(self,
-                 batch: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
-        """
-        Inference pass on a batch of tokenized texts.
-
-        Args:
-            batch (Dict[str, torch.Tensor]): Input batch with entry 'text' (text tensor).
-
-        Returns:
-            Tuple: The first element is a Tensor (phoneme tokens) and the second element
-            is a tensor (phoneme token probabilities).
-        """
-
-        with torch.no_grad():
-            x = self.forward(batch)
-            tokens, logits = get_dedup_tokens(x)
-            return tokens, logits
-
-    @classmethod
-    def from_config(cls, config: dict) -> 'ForwardTransformer':
-        preprocessor = Preprocessor.from_config(config)
-        return ForwardTransformer(
-            encoder_vocab_size=preprocessor.text_tokenizer.vocab_size,
-            decoder_vocab_size=preprocessor.phoneme_tokenizer.vocab_size,
-            d_model=config['model']['d_model'],
-            d_fft=config['model']['d_fft'],
-            layers=config['model']['layers'],
-            dropout=config['model']['dropout'],
-            heads=config['model']['heads']
-        )
+        pass  # pragma: no cover
 
 
 class AutoregressiveTransformer(Model):
@@ -151,42 +66,6 @@ def __init__(self,
                                           dropout=dropout,
                                           activation='relu')
         self.fc_out = nn.Linear(d_model, decoder_vocab_size)
 
-    def forward(self, batch: Dict[str, torch.Tensor]):  # shape: [N, T]
-        """
-        Foward pass of the model on a data batch.
-
-        Args:
-            batch (Dict[str, torch.Tensor]): Input batch with entries 'text' (text tensor) and 'phonemes'
-            (phoneme tensor for teacher forcing).
-
-        Returns:
-            Tensor: Predictions.
-        """
-
-        src = batch['text']
-        trg = batch['phonemes'][:, :-1]
-
-        src = src.transpose(0, 1)  # shape: [T, N]
-        trg = trg.transpose(0, 1)
-
-        trg_mask = _generate_square_subsequent_mask(len(trg)).to(trg.device)
-
-        src_pad_mask = _make_len_mask(src).to(trg.device)
-        trg_pad_mask = _make_len_mask(trg).to(trg.device)
-
-        src = self.encoder(src)
-        src = self.pos_encoder(src)
-
-        trg = self.decoder(trg)
-        trg = self.pos_decoder(trg)
-
-        output = self.transformer(src, trg, src_mask=None, tgt_mask=trg_mask,
-                                  memory_mask=None, src_key_padding_mask=src_pad_mask,
-                                  tgt_key_padding_mask=trg_pad_mask, memory_key_padding_mask=src_pad_mask)
-        output = self.fc_out(output)
-        output = output.transpose(0, 1)
-        return output
-
     @torch.jit.export
     def generate(self,
                  batch: Dict[str, torch.Tensor],
@@ -278,15 +157,10 @@ def create_model(model_type: ModelType, config: Dict[str, Any]) -> Model:
         Returns:
             Model: Model object.
         """
-
-    if model_type is ModelType.TRANSFORMER:
-        model = ForwardTransformer.from_config(config)
-    elif model_type is ModelType.AUTOREG_TRANSFORMER:
-        model = AutoregressiveTransformer.from_config(config)
-    else:
+    if model_type is not ModelType.AUTOREG_TRANSFORMER:  # pragma: no cover
         raise ValueError(f'Unsupported model type: {model_type}. '
-                         f'Supported types: {[t.value for t in ModelType]}')
-    return model
+                         'Supported type: AUTOREG_TRANSFORMER')
+    return AutoregressiveTransformer.from_config(config)
 
 
 def load_checkpoint(checkpoint_path: str, device: str = 'cpu') -> Tuple[Model, Dict[str, Any]]:
diff --git a/src/Aquila_Resolve/models/dp/model/utils.py b/src/Aquila_Resolve/models/dp/model/utils.py
index 1ba58b7..3518390 100644
--- a/src/Aquila_Resolve/models/dp/model/utils.py
+++ b/src/Aquila_Resolve/models/dp/model/utils.py
@@ -1,8 +1,5 @@
 import math
-from typing import Tuple
-
 import torch
-from torch.nn.utils.rnn import pad_sequence
 
 
 class PositionalEncoding(torch.nn.Module):
@@ -30,48 +27,11 @@ def __init__(self, d_model: int, dropout=0.1, max_len=5000) -> None:
         pe = pe.unsqueeze(0).transpose(0, 1)
         self.register_buffer('pe', pe)
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor: # shape: [T, N]
+    def forward(self, x: torch.Tensor) -> torch.Tensor:  # shape: [T, N]
         x = x + self.scale * self.pe[:x.size(0), :]
         return self.dropout(x)
 
 
-def get_dedup_tokens(logits_batch: torch.Tensor) \
-        -> Tuple[torch.Tensor, torch.Tensor]:
-    """Converts a batch of logits into the batch most probable tokens and their probabilities.
-
-    Args:
-        logits_batch (Tensor): Batch of logits (N x T x V).
-
-    Returns:
-        Tuple: Deduplicated tokens. The first element is a tensor (token indices) and the second element
-        is a tensor (token probabilities)
-
-    """
-
-    logits_batch = logits_batch.softmax(-1)
-    out_tokens, out_probs = [], []
-    for i in range(logits_batch.size(0)):
-        logits = logits_batch[i]
-        max_logits, max_indices = torch.max(logits, dim=-1)
-        max_logits = max_logits[max_indices!=0]
-        max_indices = max_indices[max_indices!=0]
-        cons_tokens, counts = torch.unique_consecutive(
-            max_indices, return_counts=True)
-        out_probs_i = torch.zeros(len(counts), device=logits.device)
-        ind = 0
-        for i, c in enumerate(counts):
-            max_logit = max_logits[ind:ind + c].max()
-            out_probs_i[i] = max_logit
-            ind = ind + c
-        out_tokens.append(cons_tokens)
-        out_probs.append(out_probs_i)
-
-    out_tokens = pad_sequence(out_tokens, batch_first=True, padding_value=0.).long()
-    out_probs = pad_sequence(out_probs, batch_first=True, padding_value=0.)
-
-    return out_tokens, out_probs
-
-
 def _generate_square_subsequent_mask(sz: int) -> torch.Tensor:
     mask = torch.triu(torch.ones(sz, sz), 1)
     mask = mask.masked_fill(mask == 1, float('-inf'))
@@ -86,9 +46,4 @@ def _get_len_util_stop(sequence: torch.Tensor, end_index: int) -> int:
     for i, val in enumerate(sequence):
         if val == end_index:
             return i + 1
-    return len(sequence)
-
-
-def _trim_util_stop(sequence: torch.Tensor, end_index: int) -> torch.Tensor:
-    seq_len = _get_len_util_stop(sequence, end_index)
-    return sequence[:seq_len]
+    return len(sequence)  # pragma: no cover
diff --git a/src/Aquila_Resolve/symbols.py b/src/Aquila_Resolve/symbols.py
index 0acbb77..71ccb4a 100644
--- a/src/Aquila_Resolve/symbols.py
+++ b/src/Aquila_Resolve/symbols.py
@@ -98,35 +98,53 @@
     return None
 
 
-def contains_alpha(s: str) -> bool:
-    # Check if a word contains an alpha character
-    return s.upper().isupper()
-
-
-def is_phoneme(s: str) -> bool:
-    # Check if a word is a phoneme, detect brackets
-    return s.startswith('{') and s.endswith('}')
-
-
-def brackets_match(s: str) -> str | None:
-    # Check if string contains brackets at all
-    if not ('{' in s or '}' in s):
-        return None  # Valid
-    index_opened = -1
-    in_bracket = False
-    for i in range(len(s)):
-        if not in_bracket:
-            if s[i] == '{':
-                in_bracket = True
-                index_opened = i
-            elif s[i] == '}':
-                return f'Unexpected close bracket at index {i} without open.'
+def contains_alpha(text: str) -> bool:
+    """
+    Check if a string contains any alphabetic characters.
+    :param text: Text to check.
+    :return: True if the text contains alphabetic characters.
+    """
+    return text.upper().isupper()
+
+
+def is_braced(word: str) -> bool:
+    """
+    Check if a word is surrounded by brace-markings {}.
+
+    :param word: Word
+    :return: True if word is brace-marked.
+    """
+    return word.startswith('{') and word.endswith('}')
+
+
+def valid_braces(text: str, raise_on_invalid: bool = False) -> bool:
+    """
+    Check that the brace-markings in a text are valid.
+
+    :param text: Text to check.
+    :param raise_on_invalid: Raises ValueError if invalid.
+    :return: True if the text's brace-markings are valid.
+    """
+    def invalid(msg: str) -> bool:
+        if raise_on_invalid:
+            raise ValueError(f'Invalid brace-marked text ({msg}) in "{text}"')
         else:
-            if s[i] == '}':
-                in_bracket = False
-            elif s[i] == '{':
-                return f'Unexpected nested open bracket at index {i}.'
-    if in_bracket:
-        return f'Bracket opened at index {index_opened} but was never closed.'
-    return None  # Valid
+            return False
+    if not any(c in text for c in {'{', '}'}):
+        return True  # No braces, so valid.
+    in_braces = False
+    for char in text:
+        if char == '{':
+            if not in_braces:
+                in_braces = True
+            else:
+                return invalid('Nested braces')
+        elif char == '}':
+            if in_braces:
+                in_braces = False
+            else:
+                return invalid('Closing brace without opening')
+    if in_braces:
+        return invalid('Opening brace without closing')
+    return True
diff --git a/src/Aquila_Resolve/text/numbers.py b/src/Aquila_Resolve/text/numbers.py
index 8905072..c3050b0 100644
--- a/src/Aquila_Resolve/text/numbers.py
+++ b/src/Aquila_Resolve/text/numbers.py
@@ -22,7 +22,8 @@
 _decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
 _currency_re = re.compile(r'([$€£₩])([0-9.,]*[0-9]+)(?:[ ]?({})(?=[^a-zA-Z]|$))?'.format("|".join(_magnitudes)),
                           re.IGNORECASE)
-_measurement_re = re.compile(r'([0-9.,]*[0-9]+(\s)?{}\b)'.format(_measurements), re.IGNORECASE)
+# _measurement_re = re.compile(r'([0-9.,]*[0-9]+(\s)?{}\b)'.format(_measurements), re.IGNORECASE)
+_measurement_re = re.compile(r'(?<![a-zA-Z0-9])([0-9.,]*[0-9]+(\s)?{}\b)'.format(_measurements), re.IGNORECASE)
 _ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
 _number_re = re.compile(r"[0-9]+'?s?")
diff --git a/src/Aquila_Resolve/text/replace.py b/src/Aquila_Resolve/text/replace.py
new file mode 100644
index 0000000..d37c18e
--- /dev/null
+++ b/src/Aquila_Resolve/text/replace.py
@@ -0,0 +1,22 @@
+import re
+
+
+def replace_first(target: str, replacement: str, text: str) -> str:
+    """
+    Use Regex to replace the first instance of a word.
+
+    Words within braces are ignored (e.g. '{word} is ignored')
+
+    :param target: The word to be replaced
+    :param replacement: Replacement word
+    :param text: Text to be searched
+    :return: Text with the first instance of the word replaced
+    """
+    if not target or not text:
+        return text  # Return original if no target or text
+    # Replace the first instance of a word with its phonemes
+    # return re.sub(r'(?i)\b' + target + r'\b', replacement, text, 1)
+    return re.sub(r'(?<!{)\b' + target + r'\b(?!})', replacement, text, 1, flags=re.IGNORECASE)
diff --git a/tests/test_g2p.py b/tests/test_g2p.py
--- a/tests/test_g2p.py
+++ b/tests/test_g2p.py
 def g2p() -> G2p:
     g2p = G2p()
-    assert isinstance(g2p, G2p)
     yield g2p
 
 
 # Test for lookup method
 @pytest.mark.parametrize("word, phoneme", [
-    ('cat', ['K', 'AE1', 'T']),
-    ('CaT', ['K', 'AE1', 'T']),
-    ('CAT', ['K', 'AE1', 'T']),
-    ('test', ['T', 'EH1', 'S', 'T']),
-    ('testers', ['T', 'EH1', 'S', 'T', 'ER0', 'Z']),
-    ('testers(2)', ['T', 'EH1', 'S', 'T', 'AH0', 'Z']),
+    ('cat', 'K AE1 T'),
+    ('CaT', 'K AE1 T'),
+    ('CAT', 'K AE1 T'),
+    ('test', 'T EH1 S T'),
+    ('testers', 'T EH1 S T ER0 Z'),
+    ('testers(2)', 'T EH1 S T AH0 Z'),
 ])
 def test_lookup(g2p, word, phoneme):
-    assert g2p.lookup(word) == ' '.join(phoneme)
+    assert g2p.lookup(word) == phoneme
 
 
 # Test for convert method
 @pytest.mark.parametrize("line, ph_line", zip(cde_lines, cde_expected_results))
 def test_convert(g2p, line, ph_line):
     assert g2p.convert(line) == ph_line
+
+
+# Test for convert format exceptions
+@pytest.mark.parametrize("case", [
+    "The cat {R {IY1 D} the} book.",
+    "The cat {R {IY1 D the} book.",
+    "The cat {R IY1 D} the} book.",
+    "The cat {R IY1 D the book.",
+    "The cat R IY1 D} the book.",
+])
+def test_convert_ex_format(g2p, case):
+    with pytest.raises(ValueError):
+        g2p.convert(case)
diff --git a/tests/test_h2p.py b/tests/test_h2p.py
index aaaeeac..858ff54 100644
--- a/tests/test_h2p.py
+++ b/tests/test_h2p.py
@@ -1,6 +1,4 @@
 import pytest
-import random
-from Aquila_Resolve.h2p import replace_first
 
 # List of lines
 ex_lines = [
@@ -17,40 +15,14 @@
 ]
 
 
-# Function to generate test lines with n sentences, randomly chosen from the list of lines
-# The number of heteronyms would be 2n
-def gen_line(n):
-    # List of lines
-    lines = [
-        "The cat read the book. It was a good book to read.",
-        "You should absent yourself from the meeting. Then you would be absent.",
-        "The machine would automatically reject products. These were the reject products.",
-    ]
-    test_line = ""
-    # Loop through n sentences
-    for i in range(n):
-        # Add space if not the first part
-        if not i == 0:
-            test_line += " "
-        test_line += (random.choice(lines))
-    return test_line
-
-
-# Test Data, for contains_het()
-# List of tuples, each tuple contains:
-# - Line to test
-# - Expected result (True/False)
-contains_het_data = [
+# Test the contains_het function
+@pytest.mark.parametrize("line, expected", [
     ("The cat read the book. It was a good book to read.", True),
     ("The effect was absent.", True),
     ("Symbols like !, ?, and ;", False),
     ("The product was a reject.", True),
     ("", False), (" ", False), ("\n", False), ("\t", False)
-]
-
-
-# Test the contains_het function
-@pytest.mark.parametrize("line, expected", contains_het_data)
+])
 def test_contains_het(h2p, line, expected):
     assert h2p.contains_het(line) == expected
@@ -66,16 +38,3 @@ def test_replace_het_list(h2p):
     results = h2p.replace_het_list(ex_lines)
     for result, expected in zip(results, ex_expected_results):
         assert expected == result
-
-
-replace_first_data = [
-    ("the", "re", "The cat read the book.", "re cat read the book."),
-    ("the", "{re mult}", "The effect was absent.", "{re mult} effect was absent."),
-    ("the", "re", "Symbols !, ?, and ;", "Symbols !, ?, and ;")
-]
-
-
-# Test for the test_replace_first function
-@pytest.mark.parametrize("search, replace, line, expected", replace_first_data)
-def test_replace_first(search, replace, line, expected):
-    assert replace_first(search, replace, line) == expected
diff --git a/tests/test_infer.py b/tests/test_infer.py
new file mode 100644
index 0000000..77dc5eb
--- /dev/null
+++ b/tests/test_infer.py
@@ -0,0 +1,19 @@
+import pytest
+from Aquila_Resolve.infer import Infer
+
+
+@pytest.fixture(scope="module")
+def infer():
+    yield Infer()
+
+
+# noinspection SpellCheckingInspection
+@pytest.mark.parametrize("case, exp", [
+    ([""], [""]),
+    (["a"], ["AH0"]),
+    (["a", "a"], ["AH0", "AH0"]),  # Test De-duplication
+    (["a", "b"], ["AH0", "B IY1"]),
+    (["ioniformi"], ["IY0 AA2 N IH0 F AO1 R M IY0"]),  # OOV word
+])
+def test_infer(infer, case, exp):
+    assert infer(case) == exp
diff --git a/tests/test_replace.py b/tests/test_replace.py
new file mode 100644
index 0000000..8997cd1
--- /dev/null
+++ b/tests/test_replace.py
@@ -0,0 +1,16 @@
+import pytest
+from Aquila_Resolve.text.replace import replace_first
+
+
+# Test for the replace_first function
+@pytest.mark.parametrize("search, replace, line, expected", [
+    (None, "re", "Text.", "Text."),
+    ("the", "re", "", ""),
+    ("the", "re", None, None),
+    ("the", "re", "Thesis.", "Thesis."),
+    ("the", "re", "The cat read the book.", "re cat read the book."),
+    ("the", "{re mult}", "The effect was absent.", "{re mult} effect was absent."),
+    ("the", "re", "Symbols !, ?, and ;", "Symbols !, ?, and ;")
+])
+def test_replace_first(search, replace, line, expected):
+    assert replace_first(search, replace, line) == expected
diff --git a/tests/test_symbols.py b/tests/test_symbols.py
index 9d11d6e..a13aa0c 100644
--- a/tests/test_symbols.py
+++ b/tests/test_symbols.py
@@ -36,3 +36,25 @@
 def test_get_parent_pos_invalid_tag():
     # If the pos tag is not in the list, expect None
     assert symbols.get_parent_pos('XYZ') is None
+
+
+@pytest.mark.parametrize('case, exp', [
+    ('ABC', True),
+    ('abc', True),
+    ('0', False),
+    ('1A', True),
+    ('1@$%&', False),
+    ('@a$%&', True),
+])
+def test_contains_alpha(case, exp):
+    assert symbols.contains_alpha(case) == exp
+
+
+@pytest.mark.parametrize('case, exp', [
+    ('word', False),
+    ('{AH0}', True),
+    ('{C AH0 T}', True),
+    ('In {AH0} line.', False),
+])
+def test_is_braced(case, exp):
+    assert symbols.is_braced(case) == exp
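For quick reference, a minimal sketch (illustrative only, not part of the changeset) of the caller-side contract of the `valid_braces` function introduced in `symbols.py`, matching the behavior exercised by `test_convert_ex_format` above:

```python
from Aquila_Resolve.symbols import valid_braces

# Well-formed brace-markings are accepted.
valid_braces('The {K AE1 T} read the book.')  # True

# Malformed markings return False by default...
valid_braces('The {K AE1 T read the book.')   # False (brace opened, never closed)
valid_braces('The {R {IY1 D} the} book.')     # False (nested braces)

# ...and raise only when asked, which is how G2p.convert() invokes it.
valid_braces('The cat R IY1 D} here.', raise_on_invalid=True)  # raises ValueError
```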