diff --git a/.github/codecov.yml b/.github/codecov.yml
index 854f75b..504dfe7 100644
--- a/.github/codecov.yml
+++ b/.github/codecov.yml
@@ -23,6 +23,4 @@ comment:
ignore:
- "tests/**"
- - "test_*.py"
- - "**/__main__.py"
- - "**/__init__.py"
\ No newline at end of file
+ - "test_*.py"
\ No newline at end of file
diff --git a/README.md b/README.md
index 94d72a5..7e36452 100644
--- a/README.md
+++ b/README.md
@@ -10,19 +10,50 @@
### Augmented Recurrent Neural G2P with Inflectional Orthography
-Grapheme-to-phoneme (G2P) conversion is the process of converting the written form of words (Graphemes) to their
-pronunciations (Phonemes). Deep learning models for text-to-speech (TTS) synthesis using phoneme / mixed symbols
-typically require a G2P conversion method for both training and inference.
-
-Aquila Resolve presents a new approach for accurate and efficient English G2P resolution.
-Input text graphemes are translated into their phonetic pronunciations,
-using [ARPAbet](https://wikipedia.org/wiki/ARPABET) as the [phoneme symbol set](#Symbol-Set).
+Aquila Resolve presents a new approach for accurate and efficient English to
+[ARPAbet](https://wikipedia.org/wiki/ARPABET) G2P resolution.
The pipeline employs a context layer, multiple transformer and n-gram morpho-orthographical search layers,
-and an autoregressive recurrent neural transformer base.
-
-The current implementation offers state-of-the-art accuracy for out-of-vocabulary (OOV) words, as well as contextual
+and an autoregressive recurrent neural transformer base. The current implementation offers state-of-the-art accuracy for out-of-vocabulary (OOV) words, as well as contextual
analysis for correct inferencing of [English Heteronyms](https://en.wikipedia.org/wiki/Heteronym_(linguistics)).
+The package is offered in a pre-trained state that is ready for use as a dependency or in
+notebook environments. No additional resources are needed, other than the model checkpoint, which is
+downloaded automatically on first use. See [Installation](#Installation) for more information.
+
+### 1. Dynamic Word Mappings based on context:
+
+```pycon
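+# Setup: instantiating G2p downloads the model checkpoint on first use
+from Aquila_Resolve import G2p
+g2p = G2p()
+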
+g2p.convert('I read the book, did you read it?')
+# >> '{AY1} {R EH1 D} {DH AH0} {B UH1 K}, {D IH1 D} {Y UW1} {R IY1 D} {IH1 T}?'
+```
+```pycon
+g2p.convert('The researcher was to subject the subject to a test.')
+# >> '{DH AH0} {R IY1 S ER0 CH ER0} {W AA1 Z} {T UW1} {S AH0 B JH EH1 K T} {DH AH0} {S AH1 B JH IH0 K T} {T UW1} {AH0} {T EH1 S T}.'
+```
+
+| | 'The subject was told to read. Eight records were read in total.' |
+|--------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------|
+| *Ground Truth* | The `S AH1 B JH IH0 K T` was told to `R IY1 D`. Eight `R EH1 K ER0 D Z` were `R EH1 D` in total. |
+| Aquila Resolve | The `S AH1 B JH IH0 K T` was told to `R IY1 D`. Eight `R EH1 K ER0 D Z` were `R EH1 D` in total. |
+| [Deep Phonemizer](https://github.com/as-ideas/DeepPhonemizer) ([en_us_cmudict_forward.pt](https://github.com/as-ideas/DeepPhonemizer#pretrained-models)) | The **S AH B JH EH K T** was told to **R EH D**. Eight **R AH K AO R D Z** were `R EH D` in total. |
+| [CMUSphinx Seq2Seq](https://github.com/cmusphinx/g2p-seq2seq) ([checkpoint](https://github.com/cmusphinx/g2p-seq2seq#running-g2p)) | The `S AH1 B JH IH0 K T` was told to `R IY1 D`. Eight **R IH0 K AO1 R D Z** were **R IY1 D** in total. |
+| [ESpeakNG](https://github.com/espeak-ng/espeak-ng) (with [phonecodes](https://github.com/jhasegaw/phonecodes)) | The **S AH1 B JH EH K T** was told to `R IY1 D`. Eight `R EH1 K ER0 D Z` were **R IY1 D** in total. |
+
+### 2. Leading Accuracy for unseen words:
+
+```pycon
+g2p.convert('Did you kalpe the Hevinet?')
+# >> '{D IH1 D} {Y UW1} {K AE1 L P} {DH AH0} {HH EH1 V IH0 N EH2 T}?'
+```
+
+| | "tensorflow" | "agglomerative" | "necrophages" |
+|--------------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------|------------------------------------|----------------------------------|
+| Aquila Resolve | `T EH1 N S ER0 F L OW2` | `AH0 G L AA1 M ER0 EY2 T IH0 V` | `N EH1 K R OW0 F EY2 JH IH0 Z` |
+| [Deep Phonemizer](https://github.com/as-ideas/DeepPhonemizer) ([en_us_cmudict_forward.pt](https://github.com/as-ideas/DeepPhonemizer#pretrained-models)) | `T EH N S ER F L OW` | **AH G L AA M ER AH T IH V** | `N EH K R OW F EY JH IH Z` |
+| [CMUSphinx Seq2Seq](https://github.com/cmusphinx/g2p-seq2seq) ([checkpoint](https://github.com/cmusphinx/g2p-seq2seq#running-g2p)) | **T EH1 N S ER0 L OW0 F** | **AH0 G L AA1 M ER0 T IH0 V** | **N AE1 K R AH0 F IH0 JH IH0 Z** |
+| [ESpeakNG](https://github.com/espeak-ng/espeak-ng) (with [phonecodes](https://github.com/jhasegaw/phonecodes)) | **T EH1 N S OW0 R F L OW2** | **AA G L AA1 M ER0 R AH0 T IH2 V** | **N EH1 K R AH0 F IH JH EH0 Z** |
+
+
## Installation
```bash
@@ -32,8 +63,8 @@ pip install aquila-resolve
> automatically downloaded on the first use of relevant public methods that require inferencing. For example,
> when [instantiating `G2p`](#Usage). You can also start this download manually by calling `Aquila_Resolve.download()`.
>
-> If you are in an environment where remote file downloads are not possible, you can also download the checkpoint
-> manually and instantiate `G2p` with the flag: `G2p(custom_checkpoint='path/model.pt')`
+> If you are in an environment where remote file downloads are not possible, you can also transfer the checkpoint
+> manually, placing `model.pt` within the `Aquila_Resolve.data` module folder.
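+>
+> As a minimal sketch (assuming a standard install), you can locate the target folder using the
+> standard library's `importlib.resources`:
+>
+> ```python
+> from importlib.resources import files  # Python 3.9+
+>
+> # Prints the folder of the Aquila_Resolve.data module, where model.pt belongs
+> print(files('Aquila_Resolve.data'))
+> ```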
## Usage
@@ -48,10 +79,10 @@ g2p.convert('The book costs $5, will you read it?')
> Additional optional parameters are available when defining a `G2p` instance:
-| Parameter | Default | Description |
-|--------------------|----------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
-| `device` | `'cpu'` | Device for Pytorch inference model |
-| `process_numbers` | `True` | Toggles conversion of some numbers and symbols to their spoken pronunciation forms. See [numbers.py](src/Aquila_Resolve/text/numbers.py) for details on what is covered. |
+| Parameter | Default | Description |
+|-------------------|---------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
+| `device`          | `'cpu'` | Device for PyTorch inference model. GPU is supported using `'cuda'`                                                                                        |
+| `process_numbers` | `True` | Toggles conversion of some numbers and symbols to their spoken pronunciation forms. See [numbers.py](src/Aquila_Resolve/text/numbers.py) for details on what is covered. |
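+
+For example, a minimal sketch of instantiation with these parameters (assuming a CUDA-capable GPU is available):
+
+```pycon
+from Aquila_Resolve import G2p
+
+g2p = G2p(device='cuda')  # inference runs on GPU
+g2p.convert('The book costs $5, will you read it?')
+```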
## Model Architecture
diff --git a/setup.cfg b/setup.cfg
index f4e2dd4..97c127a 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -1,12 +1,12 @@
[metadata]
name = Aquila-Resolve
-version = 0.1.2-dev1
+version = 0.1.2
author = ionite
author_email = dev@ionite.io
description = Augmented Recurrent Neural Grapheme-to-Phoneme conversion with Inflectional Orthography.
long_description = file: README.md
long_description_content_type = text/markdown
-url = https://github.com/ionite34/Aquila-Resolve'
+url = https://github.com/ionite34/Aquila-Resolve
license = Apache 2.0
license_file = LICENSE
classifiers =
diff --git a/src/Aquila_Resolve/__init__.py b/src/Aquila_Resolve/__init__.py
index 16af5b0..fb63580 100644
--- a/src/Aquila_Resolve/__init__.py
+++ b/src/Aquila_Resolve/__init__.py
@@ -4,7 +4,7 @@
Grapheme to Phoneme Resolver
"""
-__version__ = "0.1.2-dev1"
+__version__ = "0.1.2"
from .g2p import G2p
from .data.remote import download
diff --git a/src/Aquila_Resolve/data/__init__.py b/src/Aquila_Resolve/data/__init__.py
index fd5de86..8bff045 100644
--- a/src/Aquila_Resolve/data/__init__.py
+++ b/src/Aquila_Resolve/data/__init__.py
@@ -2,7 +2,7 @@
if sys.version_info < (3, 9):
# In Python versions below 3.9, this is needed
- from importlib_resources import files
+ from importlib_resources import files # pragma: no cover
else:
# Since python 3.9+, importlib.resources.files is built-in
from importlib.resources import files
diff --git a/src/Aquila_Resolve/g2p.py b/src/Aquila_Resolve/g2p.py
index ec9ca6b..4b99000 100644
--- a/src/Aquila_Resolve/g2p.py
+++ b/src/Aquila_Resolve/g2p.py
@@ -9,15 +9,14 @@
from nltk.stem.snowball import SnowballStemmer
from .h2p import H2p
-from .h2p import replace_first
+from .text.replace import replace_first
from .format_ph import with_cb
-# from .dict_reader import DictReader
from .static_dict import get_cmudict
from .text.numbers import normalize_numbers
from .filter import filter_text
from .processors import Processor
from .infer import Infer
-from .symbols import contains_alpha, brackets_match
+from .symbols import contains_alpha, valid_braces
re_digit = re.compile(r"\((\d+)\)")
re_bracket_with_digit = re.compile(r"\(.*\)")
@@ -143,13 +142,9 @@ def convert(self, text: str, convert_num: bool = True) -> str | None:
:param convert_num: True to convert numbers to words
"""
- # Check that every {} bracket is paired
- check = brackets_match(text)
- if check is not None:
- raise ValueError(check)
-
- # Normalize numbers, if enabled
+ # Convert numbers, if enabled
if convert_num:
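+        # Validate that {} phoneme braces are paired; raises ValueError if not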
+ valid_braces(text, raise_on_invalid=True)
text = normalize_numbers(text)
# Filter and Tokenize
diff --git a/src/Aquila_Resolve/h2p.py b/src/Aquila_Resolve/h2p.py
index e013a53..4be9df6 100644
--- a/src/Aquila_Resolve/h2p.py
+++ b/src/Aquila_Resolve/h2p.py
@@ -1,27 +1,18 @@
-import nltk
-import re
from nltk.tokenize import TweetTokenizer
from nltk import pos_tag
from nltk import pos_tag_sents
from .dictionary import Dictionary
from .filter import filter_text as ft
+from .text.replace import replace_first
from . import format_ph as ph
-# Check that the nltk data is downloaded, if not, download it
+# Check that required nltk data exists; if not, download it
try:
- nltk.data.find('taggers/averaged_perceptron_tagger.zip')
-except LookupError:
- nltk.download('averaged_perceptron_tagger')
-
-
-# Method to use Regex to replace the first instance of a word with its phonemes
-def replace_first(target, replacement, text):
- # Skip if target invalid
- if target is None or target == '':
- return text
- # Replace the first instance of a word with its phonemes
- # return re.sub(r'(?i)\b' + target + r'\b', replacement, text, 1)
-    return re.sub(r'(?i)(?<!{)\b' + target + r'\b(?!})', replacement, text, 1)
diff --git a/src/Aquila_Resolve/infer.py b/src/Aquila_Resolve/infer.py
--- a/src/Aquila_Resolve/infer.py
+++ b/src/Aquila_Resolve/infer.py
-    def __call__(self, words: list[str]) -> list[str]:
+ def __call__(self, text: list[str]) -> list[str]:
"""
Infers phonemes for a list of words.
- :param words: list of words
+ :param text: list of words
:return: dict of {word: phonemes}
"""
- res = self.model.phonemise_list(words, lang=self.lang, batch_size=self.batch_size).phonemes
+ res = self.model.phonemise_list(text, lang=self.lang, batch_size=self.batch_size).phonemes
# Replace all occurrences of '][' with spaces, remove remaining brackets
res = [r.replace('][', ' ').replace('[', '').replace(']', '') for r in res]
return res
diff --git a/src/Aquila_Resolve/models/__init__.py b/src/Aquila_Resolve/models/__init__.py
index e3d4f0e..ac7d693 100644
--- a/src/Aquila_Resolve/models/__init__.py
+++ b/src/Aquila_Resolve/models/__init__.py
@@ -2,7 +2,7 @@
if sys.version_info < (3, 9):
# In Python versions below 3.9, this is needed
- from importlib_resources import files
+ from importlib_resources import files # pragma: no cover
else:
# Since python 3.9+, importlib.resources.files is built-in
from importlib.resources import files
diff --git a/src/Aquila_Resolve/models/dp/model/model.py b/src/Aquila_Resolve/models/dp/model/model.py
index 46d4d18..5b66235 100644
--- a/src/Aquila_Resolve/models/dp/model/model.py
+++ b/src/Aquila_Resolve/models/dp/model/model.py
@@ -4,8 +4,7 @@
import torch
import torch.nn as nn
-from torch.nn import TransformerEncoderLayer, LayerNorm, TransformerEncoder
-from .utils import get_dedup_tokens, _make_len_mask, _generate_square_subsequent_mask, PositionalEncoding
+from .utils import _make_len_mask, _generate_square_subsequent_mask, PositionalEncoding
from ..preprocessing.text import Preprocessor
@@ -17,7 +16,7 @@ def is_autoregressive(self) -> bool:
"""
Returns: bool: Whether the model is autoregressive.
"""
- return self in {ModelType.AUTOREG_TRANSFORMER}
+ return self in {ModelType.AUTOREG_TRANSFORMER} # pragma: no cover
class Model(torch.nn.Module, ABC):
@@ -39,91 +38,7 @@ def generate(self, batch: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.
Tuple[torch.Tensor, torch.Tensor]: The predictions. The first element is a tensor (phoneme tokens)
and the second element is a tensor (phoneme token probabilities)
"""
- pass
-
-
-class ForwardTransformer(Model):
-
- def __init__(self,
- encoder_vocab_size: int,
- decoder_vocab_size: int,
- d_model=512,
- d_fft=1024,
- layers=4,
- dropout=0.1,
- heads=1) -> None:
- super().__init__()
-
- self.d_model = d_model
-
- self.embedding = nn.Embedding(encoder_vocab_size, d_model)
- self.pos_encoder = PositionalEncoding(d_model, dropout)
-
- encoder_layer = TransformerEncoderLayer(d_model=d_model,
- nhead=heads,
- dim_feedforward=d_fft,
- dropout=dropout,
- activation='relu')
- encoder_norm = LayerNorm(d_model)
- self.encoder = TransformerEncoder(encoder_layer=encoder_layer,
- num_layers=layers,
- norm=encoder_norm)
-
- self.fc_out = nn.Linear(d_model, decoder_vocab_size)
-
- def forward(self,
- batch: Dict[str, torch.Tensor]) -> torch.Tensor: # shape: [N, T]
- """
- Forward pass of the model on a data batch.
-
- Args:
- batch (Dict[str, torch.Tensor]): Input batch entry 'text' (text tensor).
-
- Returns:
- Tensor: Predictions.
- """
-
- x = batch['text']
- x = x.transpose(0, 1) # shape: [T, N]
- src_pad_mask = _make_len_mask(x).to(x.device)
- x = self.embedding(x)
- x = self.pos_encoder(x)
- x = self.encoder(x, src_key_padding_mask=src_pad_mask)
- x = self.fc_out(x)
- x = x.transpose(0, 1)
- return x
-
- @torch.jit.export
- def generate(self,
- batch: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
- """
- Inference pass on a batch of tokenized texts.
-
- Args:
- batch (Dict[str, torch.Tensor]): Input batch with entry 'text' (text tensor).
-
- Returns:
- Tuple: The first element is a Tensor (phoneme tokens) and the second element
- is a tensor (phoneme token probabilities).
- """
-
- with torch.no_grad():
- x = self.forward(batch)
- tokens, logits = get_dedup_tokens(x)
- return tokens, logits
-
- @classmethod
- def from_config(cls, config: dict) -> 'ForwardTransformer':
- preprocessor = Preprocessor.from_config(config)
- return ForwardTransformer(
- encoder_vocab_size=preprocessor.text_tokenizer.vocab_size,
- decoder_vocab_size=preprocessor.phoneme_tokenizer.vocab_size,
- d_model=config['model']['d_model'],
- d_fft=config['model']['d_fft'],
- layers=config['model']['layers'],
- dropout=config['model']['dropout'],
- heads=config['model']['heads']
- )
+ pass # pragma: no cover
class AutoregressiveTransformer(Model):
@@ -151,42 +66,6 @@ def __init__(self,
dropout=dropout, activation='relu')
self.fc_out = nn.Linear(d_model, decoder_vocab_size)
- def forward(self, batch: Dict[str, torch.Tensor]): # shape: [N, T]
- """
- Foward pass of the model on a data batch.
-
- Args:
- batch (Dict[str, torch.Tensor]): Input batch with entries 'text' (text tensor) and 'phonemes'
- (phoneme tensor for teacher forcing).
-
- Returns:
- Tensor: Predictions.
- """
-
- src = batch['text']
- trg = batch['phonemes'][:, :-1]
-
- src = src.transpose(0, 1) # shape: [T, N]
- trg = trg.transpose(0, 1)
-
- trg_mask = _generate_square_subsequent_mask(len(trg)).to(trg.device)
-
- src_pad_mask = _make_len_mask(src).to(trg.device)
- trg_pad_mask = _make_len_mask(trg).to(trg.device)
-
- src = self.encoder(src)
- src = self.pos_encoder(src)
-
- trg = self.decoder(trg)
- trg = self.pos_decoder(trg)
-
- output = self.transformer(src, trg, src_mask=None, tgt_mask=trg_mask,
- memory_mask=None, src_key_padding_mask=src_pad_mask,
- tgt_key_padding_mask=trg_pad_mask, memory_key_padding_mask=src_pad_mask)
- output = self.fc_out(output)
- output = output.transpose(0, 1)
- return output
-
@torch.jit.export
def generate(self,
batch: Dict[str, torch.Tensor],
@@ -278,15 +157,10 @@ def create_model(model_type: ModelType, config: Dict[str, Any]) -> Model:
Returns: Model: Model object.
"""
-
- if model_type is ModelType.TRANSFORMER:
- model = ForwardTransformer.from_config(config)
- elif model_type is ModelType.AUTOREG_TRANSFORMER:
- model = AutoregressiveTransformer.from_config(config)
- else:
+ if model_type is not ModelType.AUTOREG_TRANSFORMER: # pragma: no cover
raise ValueError(f'Unsupported model type: {model_type}. '
- f'Supported types: {[t.value for t in ModelType]}')
- return model
+ 'Supported type: AUTOREG_TRANSFORMER')
+ return AutoregressiveTransformer.from_config(config)
def load_checkpoint(checkpoint_path: str, device: str = 'cpu') -> Tuple[Model, Dict[str, Any]]:
diff --git a/src/Aquila_Resolve/models/dp/model/utils.py b/src/Aquila_Resolve/models/dp/model/utils.py
index 1ba58b7..3518390 100644
--- a/src/Aquila_Resolve/models/dp/model/utils.py
+++ b/src/Aquila_Resolve/models/dp/model/utils.py
@@ -1,8 +1,5 @@
import math
-from typing import Tuple
-
import torch
-from torch.nn.utils.rnn import pad_sequence
class PositionalEncoding(torch.nn.Module):
@@ -30,48 +27,11 @@ def __init__(self, d_model: int, dropout=0.1, max_len=5000) -> None:
pe = pe.unsqueeze(0).transpose(0, 1)
self.register_buffer('pe', pe)
- def forward(self, x: torch.Tensor) -> torch.Tensor: # shape: [T, N]
+ def forward(self, x: torch.Tensor) -> torch.Tensor: # shape: [T, N]
x = x + self.scale * self.pe[:x.size(0), :]
return self.dropout(x)
-def get_dedup_tokens(logits_batch: torch.Tensor) \
- -> Tuple[torch.Tensor, torch.Tensor]:
- """Converts a batch of logits into the batch most probable tokens and their probabilities.
-
- Args:
- logits_batch (Tensor): Batch of logits (N x T x V).
-
- Returns:
- Tuple: Deduplicated tokens. The first element is a tensor (token indices) and the second element
- is a tensor (token probabilities)
-
- """
-
- logits_batch = logits_batch.softmax(-1)
- out_tokens, out_probs = [], []
- for i in range(logits_batch.size(0)):
- logits = logits_batch[i]
- max_logits, max_indices = torch.max(logits, dim=-1)
- max_logits = max_logits[max_indices!=0]
- max_indices = max_indices[max_indices!=0]
- cons_tokens, counts = torch.unique_consecutive(
- max_indices, return_counts=True)
- out_probs_i = torch.zeros(len(counts), device=logits.device)
- ind = 0
- for i, c in enumerate(counts):
- max_logit = max_logits[ind:ind + c].max()
- out_probs_i[i] = max_logit
- ind = ind + c
- out_tokens.append(cons_tokens)
- out_probs.append(out_probs_i)
-
- out_tokens = pad_sequence(out_tokens, batch_first=True, padding_value=0.).long()
- out_probs = pad_sequence(out_probs, batch_first=True, padding_value=0.)
-
- return out_tokens, out_probs
-
-
def _generate_square_subsequent_mask(sz: int) -> torch.Tensor:
mask = torch.triu(torch.ones(sz, sz), 1)
mask = mask.masked_fill(mask == 1, float('-inf'))
@@ -86,9 +46,4 @@ def _get_len_util_stop(sequence: torch.Tensor, end_index: int) -> int:
for i, val in enumerate(sequence):
if val == end_index:
return i + 1
- return len(sequence)
-
-
-def _trim_util_stop(sequence: torch.Tensor, end_index: int) -> torch.Tensor:
- seq_len = _get_len_util_stop(sequence, end_index)
- return sequence[:seq_len]
+ return len(sequence) # pragma: no cover
diff --git a/src/Aquila_Resolve/symbols.py b/src/Aquila_Resolve/symbols.py
index 0acbb77..71ccb4a 100644
--- a/src/Aquila_Resolve/symbols.py
+++ b/src/Aquila_Resolve/symbols.py
@@ -98,35 +98,53 @@ def get_parent_pos(pos: str) -> str | None:
return None
-def contains_alpha(s: str) -> bool:
- # Check if a word contains an alpha character
- return s.upper().isupper()
-
-
-def is_phoneme(s: str) -> bool:
- # Check if a word is a phoneme, detect brackets
- return s.startswith('{') and s.endswith('}')
-
-
-def brackets_match(s: str) -> str | None:
- # Check if string contains brackets at all
- if not ('{' in s or '}' in s):
- return None # Valid
- index_opened = -1
- in_bracket = False
- for i in range(len(s)):
- if not in_bracket:
- if s[i] == '{':
- in_bracket = True
- index_opened = i
- elif s[i] == '}':
- return f'Unexpected close bracket at index {i} without open.'
+def contains_alpha(text: str) -> bool:
+ """
+    Check if a string contains any alphabetic characters.
+
+    :param text: String to check.
+    :return: True if the string contains at least one alphabetic character.
+ """
+ return text.upper().isupper()
+
+
+def is_braced(word: str) -> bool:
+ """
+    Check if a word is wrapped in braces, e.g. '{AH0}'.
+
+    :param word: Word to check.
+    :return: True if the word is brace-marked.
+ """
+ return word.startswith('{') and word.endswith('}')
+
+
+def valid_braces(text: str, raise_on_invalid: bool = False) -> bool:
+ """
+    Check that braces in a text are correctly paired.
+
+    :param text: Text to check.
+    :param raise_on_invalid: Raise ValueError instead of returning False if invalid.
+    :return: True if all braces in the text are correctly paired.
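+
+    Example (illustrative):
+        valid_braces('{R IY1 D} it')  # True
+        valid_braces('{R IY1 D it')   # False, or ValueError if raise_on_invalid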
+ """
+ def invalid(msg: str) -> bool:
+ if raise_on_invalid:
+            raise ValueError(f'Invalid brace-marked text ({msg}) in "{text}"')
else:
- if s[i] == '}':
- in_bracket = False
- elif s[i] == '{':
- return f'Unexpected nested open bracket at index {i}.'
- if in_bracket:
- return f'Bracket opened at index {index_opened} but was never closed.'
- return None # Valid
+ return False
+ if not any(c in text for c in {'{', '}'}):
+ return True # No braces, so valid.
+ in_braces = False
+ for char in text:
+ if char == '{':
+ if not in_braces:
+ in_braces = True
+ else:
+ return invalid('Nested braces')
+ elif char == '}':
+ if in_braces:
+ in_braces = False
+ else:
+ return invalid('Closing brace without opening')
+ if in_braces:
+ return invalid('Opening brace without closing')
+ return True
diff --git a/src/Aquila_Resolve/text/numbers.py b/src/Aquila_Resolve/text/numbers.py
index 8905072..c3050b0 100644
--- a/src/Aquila_Resolve/text/numbers.py
+++ b/src/Aquila_Resolve/text/numbers.py
@@ -22,7 +22,8 @@
_decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
_currency_re = re.compile(r'([$€£₩])([0-9.,]*[0-9]+)(?:[ ]?({})(?=[^a-zA-Z]|$))?'.format("|".join(_magnitudes)),
re.IGNORECASE)
-_measurement_re = re.compile(r'([0-9.,]*[0-9]+(\s)?{}\b)'.format(_measurements), re.IGNORECASE)
+# _measurement_re = re.compile(r'([0-9.,]*[0-9]+(\s)?{}\b)'.format(_measurements), re.IGNORECASE)
+_measurement_re = re.compile(r'(?<!\w)([0-9.,]*[0-9]+(\s)?{}\b)'.format(_measurements), re.IGNORECASE)
diff --git a/src/Aquila_Resolve/text/replace.py b/src/Aquila_Resolve/text/replace.py
new file mode 100644
--- /dev/null
+++ b/src/Aquila_Resolve/text/replace.py
+import re
+
+
+def replace_first(target: str, replacement: str, text: str) -> str:
+ """
+ Use Regex to replace the first instance of a word
+
+ Words within braces are ignored (e.g. '{word} is ignored')
+
+ :param target: The word to be replaced
+ :param replacement: Replacement word
+ :param text: Text to be searched
+ :return: Text with the first instance of the word replaced
+ """
+ if not target or not text:
+ return text # Return original if no target or text
+ # Replace the first instance of a word with its phonemes
+ # return re.sub(r'(?i)\b' + target + r'\b', replacement, text, 1)
+    return re.sub(r'(?i)(?<!{)\b' + target + r'\b(?!})', replacement, text, 1)
diff --git a/tests/test_g2p.py b/tests/test_g2p.py
--- a/tests/test_g2p.py
+++ b/tests/test_g2p.py
 def g2p() -> G2p:
g2p = G2p()
- assert isinstance(g2p, G2p)
yield g2p
# Test for lookup method
@pytest.mark.parametrize("word, phoneme", [
- ('cat', ['K', 'AE1', 'T']),
- ('CaT', ['K', 'AE1', 'T']),
- ('CAT', ['K', 'AE1', 'T']),
- ('test', ['T', 'EH1', 'S', 'T']),
- ('testers', ['T', 'EH1', 'S', 'T', 'ER0', 'Z']),
- ('testers(2)', ['T', 'EH1', 'S', 'T', 'AH0', 'Z']),
+ ('cat', 'K AE1 T'),
+ ('CaT', 'K AE1 T'),
+ ('CAT', 'K AE1 T'),
+ ('test', 'T EH1 S T'),
+ ('testers', 'T EH1 S T ER0 Z'),
+ ('testers(2)', 'T EH1 S T AH0 Z'),
])
def test_lookup(g2p, word, phoneme):
- assert g2p.lookup(word) == ' '.join(phoneme)
+ assert g2p.lookup(word) == phoneme
# Test for convert method
@pytest.mark.parametrize("line, ph_line", zip(cde_lines, cde_expected_results))
def test_convert(g2p, line, ph_line):
assert g2p.convert(line) == ph_line
+
+
+# Test for convert format exception
+@pytest.mark.parametrize("case", [
+ "The cat {R {IY1 D} the} book.",
+ "The cat {R {IY1 D the} book.",
+ "The cat {R IY1 D} the} book.",
+ "The cat {R IY1 D the book.",
+ "The cat R IY1 D} the book.",
+])
+def test_convert_ex_format(g2p, case):
+ with pytest.raises(ValueError):
+ g2p.convert(case)
diff --git a/tests/test_h2p.py b/tests/test_h2p.py
index aaaeeac..858ff54 100644
--- a/tests/test_h2p.py
+++ b/tests/test_h2p.py
@@ -1,6 +1,4 @@
import pytest
-import random
-from Aquila_Resolve.h2p import replace_first
# List of lines
ex_lines = [
@@ -17,40 +15,14 @@
]
-# Function to generate test lines with n sentences, randomly chosen from the list of lines
-# The number of heteronyms would be 2n
-def gen_line(n):
- # List of lines
- lines = [
- "The cat read the book. It was a good book to read.",
- "You should absent yourself from the meeting. Then you would be absent.",
- "The machine would automatically reject products. These were the reject products.",
- ]
- test_line = ""
- # Loop through n sentences
- for i in range(n):
- # Add space if not the first part
- if not i == 0:
- test_line += " "
- test_line += (random.choice(lines))
- return test_line
-
-
-# Test Data, for contains_het()
-# List of tuples, each tuple contains:
-# - Line to test
-# - Expected result (True/False)
-contains_het_data = [
+# Test the contains_het function
+@pytest.mark.parametrize("line, expected", [
("The cat read the book. It was a good book to read.", True),
("The effect was absent.", True),
("Symbols like !, ?, and ;", False),
("The product was a reject.", True),
("", False), (" ", False), ("\n", False), ("\t", False)
-]
-
-
-# Test the contains_het function
-@pytest.mark.parametrize("line, expected", contains_het_data)
+])
def test_contains_het(h2p, line, expected):
assert h2p.contains_het(line) == expected
@@ -66,16 +38,3 @@ def test_replace_het_list(h2p):
results = h2p.replace_het_list(ex_lines)
for result, expected in zip(results, ex_expected_results):
assert expected == result
-
-
-replace_first_data = [
- ("the", "re", "The cat read the book.", "re cat read the book."),
- ("the", "{re mult}", "The effect was absent.", "{re mult} effect was absent."),
- ("the", "re", "Symbols !, ?, and ;", "Symbols !, ?, and ;")
-]
-
-
-# Test for the test_replace_first function
-@pytest.mark.parametrize("search, replace, line, expected", replace_first_data)
-def test_replace_first(search, replace, line, expected):
- assert replace_first(search, replace, line) == expected
diff --git a/tests/test_infer.py b/tests/test_infer.py
new file mode 100644
index 0000000..77dc5eb
--- /dev/null
+++ b/tests/test_infer.py
@@ -0,0 +1,19 @@
+import pytest
+from Aquila_Resolve.infer import Infer
+
+
+@pytest.fixture(scope="module")
+def infer():
+ yield Infer()
+
+
+# noinspection SpellCheckingInspection
+@pytest.mark.parametrize("case, exp", [
+ ([""], [""]),
+ (["a"], ["AH0"]),
+ (["a", "a"], ["AH0", "AH0"]), # Test De-duplication
+ (["a", "b"], ["AH0", "B IY1"]),
+ (["ioniformi"], ["IY0 AA2 N IH0 F AO1 R M IY0"]), # OOV word
+])
+def test_infer(infer, case, exp):
+ assert infer(case) == exp
diff --git a/tests/test_replace.py b/tests/test_replace.py
new file mode 100644
index 0000000..8997cd1
--- /dev/null
+++ b/tests/test_replace.py
@@ -0,0 +1,16 @@
+import pytest
+from Aquila_Resolve.text.replace import replace_first
+
+
+# Tests for the replace_first function
+@pytest.mark.parametrize("search, replace, line, expected", [
+ (None, "re", "Text.", "Text."),
+ ("the", "re", "", ""),
+ ("the", "re", None, None),
+ ("the", "re", "Thesis.", "Thesis."),
+ ("the", "re", "The cat read the book.", "re cat read the book."),
+ ("the", "{re mult}", "The effect was absent.", "{re mult} effect was absent."),
+ ("the", "re", "Symbols !, ?, and ;", "Symbols !, ?, and ;")
+])
+def test_replace_first(search, replace, line, expected):
+ assert replace_first(search, replace, line) == expected
diff --git a/tests/test_symbols.py b/tests/test_symbols.py
index 9d11d6e..a13aa0c 100644
--- a/tests/test_symbols.py
+++ b/tests/test_symbols.py
@@ -36,3 +36,25 @@ def test_get_parent_pos_verb(tag, expected):
def test_get_parent_pos_invalid_tag():
# If the pos tag is not in the list, expect None
assert symbols.get_parent_pos('XYZ') is None
+
+
+@pytest.mark.parametrize('case, exp', [
+ ('ABC', True),
+ ('abc', True),
+ ('0', False),
+ ('1A', True),
+ ('1@$%&', False),
+ ('@a$%&', True),
+])
+def test_contains_alpha(case, exp):
+ assert symbols.contains_alpha(case) == exp
+
+
+@pytest.mark.parametrize('case, exp', [
+ ('word', False),
+ ('{AH0}', True),
+ ('{C AH0 T}', True),
+ ('In {AH0} line.', False),
+])
+def test_is_braced(case, exp):
+ assert symbols.is_braced(case) == exp