Commit a024c17

Merge remote-tracking branch 'origin/main'
Riccorl committed Jul 8, 2021
2 parents 0c0f47d + 71d238f commit a024c17
Showing 1 changed file with 31 additions and 12 deletions.
43 changes: 31 additions & 12 deletions transformer_embedder/tokenizer.py
@@ -35,7 +35,9 @@ def __init__(
             self.config = tr.AutoConfig.from_pretrained(model)
         else:
             self.huggingface_tokenizer = model
-            self.config = tr.AutoConfig.from_pretrained(self.huggingface_tokenizer.name_or_path)
+            self.config = tr.AutoConfig.from_pretrained(
+                self.huggingface_tokenizer.name_or_path
+            )
         # spacy tokenizer, lazy load. None at first
         self.spacy_tokenizer = None
         # default multilingual model
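As a side note, the branch being re-wrapped here follows a simple dual-input pattern: the constructor accepts either a model name/path or an already-built Hugging Face tokenizer and derives the matching AutoConfig from it. A minimal standalone sketch of that pattern (the helper name is hypothetical, and the string branch is assumed from context since it sits outside this hunk):

    import transformers as tr

    def load_tokenizer_and_config(model):
        # `model` may be a model name/path (str) or an existing PreTrainedTokenizer.
        if isinstance(model, str):
            tokenizer = tr.AutoTokenizer.from_pretrained(model)
            config = tr.AutoConfig.from_pretrained(model)
        else:
            tokenizer = model
            config = tr.AutoConfig.from_pretrained(tokenizer.name_or_path)
        return tokenizer, config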
@@ -129,10 +131,16 @@ def __call__(
         )

         # if text is str or a list of str and they are not split, then text needs to be tokenized
-        if isinstance(text, str) or (not is_split_into_words and isinstance(text[0], str)):
+        if isinstance(text, str) or (
+            not is_split_into_words and isinstance(text[0], str)
+        ):
             if not is_batched:
                 text = self.pretokenize(text, use_spacy=use_spacy)
-                text_pair = self.pretokenize(text_pair, use_spacy=use_spacy) if text_pair else None
+                text_pair = (
+                    self.pretokenize(text_pair, use_spacy=use_spacy)
+                    if text_pair
+                    else None
+                )
             else:
                 text = [self.pretokenize(t, use_spacy=use_spacy) for t in text]
                 text_pair = (
Expand Down Expand Up @@ -216,13 +224,17 @@ def build_tokens(
         Returns:
             a dictionary with A and B encoded
         """
-        words, input_ids, token_type_ids, offsets = self._build_tokens(text, max_len=max_len)
+        words, input_ids, token_type_ids, offsets = self._build_tokens(
+            text, max_len=max_len
+        )
         if text_pair:
             words_b, input_ids_b, token_type_ids_b, offsets_b = self._build_tokens(
                 text_pair, True, max_len
             )
             # align offsets of sentence b
-            offsets_b = [(o[0] + len(input_ids), o[1] + len(input_ids)) for o in offsets_b]
+            offsets_b = [
+                (o[0] + len(input_ids), o[1] + len(input_ids)) for o in offsets_b
+            ]
             offsets = offsets + offsets_b
             input_ids += input_ids_b
             token_type_ids += token_type_ids_b
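For clarity, a small worked example of the offset shift above (values are made up; it only illustrates the arithmetic): if sentence A produced six input ids, every offset of sentence B is pushed forward by six before the two lists are concatenated.

    # Hypothetical values, only to illustrate the offset alignment in the hunk above.
    input_ids = [0, 10, 11, 12, 13, 2]    # sentence A: 6 ids, including special tokens
    offsets_b = [(1, 1), (2, 3), (4, 4)]  # word-to-subword offsets of sentence B

    # same shift as in build_tokens: move B's offsets past A's ids
    offsets_b = [(o[0] + len(input_ids), o[1] + len(input_ids)) for o in offsets_b]
    print(offsets_b)  # [(7, 7), (8, 9), (10, 10)]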
@@ -290,7 +302,9 @@ def _build_tokens(
         token_type_ids += [token_type_id]
         return words, input_ids, token_type_ids, offsets

-    def pad_batch(self, batch: Dict[str, list], max_length: int = None) -> Dict[str, list]:
+    def pad_batch(
+        self, batch: Dict[str, list], max_length: int = None
+    ) -> Dict[str, list]:
         """
         Pad the batch to its maximum length.
Expand Down Expand Up @@ -376,7 +390,9 @@ def pretokenize(self, text: str, use_spacy: bool = False) -> List[str]:
             return [t.text for t in text]
         return text.split(" ")

-    def add_special_tokens(self, special_tokens_dict: Dict[str, Union[str, tr.AddedToken]]) -> int:
+    def add_special_tokens(
+        self, special_tokens_dict: Dict[str, Union[str, tr.AddedToken]]
+    ) -> int:
         """
         Add a dictionary of special tokens (eos, pad, cls, etc.) to the encoder.
         If special tokens are NOT in the vocabulary, they are added to it (indexed starting from the last
Expand Down Expand Up @@ -442,7 +458,8 @@ def to_tensor(self, batch: Union[List[dict], dict]) -> Dict[str, "torch.Tensor"]
"""
# convert to tensor
batch = {
k: torch.as_tensor(v) if k in self.to_tensor_inputs else v for k, v in batch.items()
k: torch.as_tensor(v) if k in self.to_tensor_inputs else v
for k, v in batch.items()
}
return batch
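The comprehension being re-wrapped here converts only selected keys of the batch to tensors and leaves everything else (for example, the original words) untouched. A minimal sketch, assuming `to_tensor_inputs` holds the usual model input keys:

    import torch

    to_tensor_inputs = {"input_ids", "attention_mask", "token_type_ids"}  # assumed keys
    batch = {
        "input_ids": [[0, 10, 11, 2]],
        "attention_mask": [[1, 1, 1, 1]],
        "words": [["hello", "world"]],  # stays a plain list
    }
    batch = {
        k: torch.as_tensor(v) if k in to_tensor_inputs else v
        for k, v in batch.items()
    }
    # batch["input_ids"] and batch["attention_mask"] are now torch.Tensors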

@@ -457,7 +474,9 @@ def _load_spacy(self) -> "spacy.tokenizer.Tokenizer":
         try:
             spacy_tagger = spacy.load(self.language, exclude=["ner", "parser"])
         except OSError:
-            logger.info(f"Spacy model '{self.language}' not found. Downloading and installing.")
+            logger.info(
+                f"Spacy model '{self.language}' not found. Downloading and installing."
+            )
             spacy_download(self.language)
             spacy_tagger = spacy.load(self.language, exclude=["ner", "parser"])
         self.spacy_tokenizer = spacy_tagger.tokenizer
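A self-contained sketch of the same download-on-miss pattern shown in this hunk (it assumes `spacy_download` is `spacy.cli.download`, which is imported outside the visible diff):

    import spacy
    from spacy.cli import download as spacy_download

    def load_spacy_tokenizer(language: str):
        # Load the spaCy pipeline without NER/parser; download it first if it is missing.
        try:
            spacy_tagger = spacy.load(language, exclude=["ner", "parser"])
        except OSError:
            spacy_download(language)
            spacy_tagger = spacy.load(language, exclude=["ner", "parser"])
        return spacy_tagger.tokenizer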
@@ -564,9 +583,9 @@ def num_special_tokens(self) -> int:
             int: the number of special tokens
         """
-        if isinstance(self.huggingface_tokenizer, MODELS_WITH_DOUBLE_SEP) and isinstance(
-            self.huggingface_tokenizer, MODELS_WITH_STARTING_TOKEN
-        ):
+        if isinstance(
+            self.huggingface_tokenizer, MODELS_WITH_DOUBLE_SEP
+        ) and isinstance(self.huggingface_tokenizer, MODELS_WITH_STARTING_TOKEN):
             return 4
         if isinstance(
             self.huggingface_tokenizer,
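For context on the first branch above: models that both prepend a start token and use a double separator between a sentence pair (RoBERTa-style) encode the pair as <s> A </s></s> B </s>, i.e. four special tokens, which is why it returns 4. A hedged cross-check with transformers (assuming the roberta-base files are available):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("roberta-base")
    # RoBERTa wraps a sentence pair as <s> A </s></s> B </s>: 4 special tokens.
    print(tokenizer.num_special_tokens_to_add(pair=True))  # 4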
