-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtokenization.py
33 lines (26 loc) · 1 KB
/
tokenization.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import os
from typing import List, Optional
import sentencepiece
class Tokenizer:
    """Thin wrapper around a SentencePiece model for text <-> token-ID conversion.

    Attributes:
        sp_model: The loaded ``SentencePieceProcessor``.
        n_words: Vocabulary size (number of pieces in the model).
        bos_id / eos_id / pad_id: Special-token IDs as reported by the model.
            Per SentencePiece convention these are -1 when the model does not
            define the corresponding token.
    """

    def __init__(self, model_path: Optional[str]):
        """Load the SentencePiece model stored at ``model_path``.

        Args:
            model_path: Filesystem path to a ``.model`` SentencePiece file.

        Raises:
            FileNotFoundError: If ``model_path`` is None or not an existing file.
        """
        # Validate explicitly rather than with `assert`, which is stripped
        # under `python -O` and would let a bad path fall through to Load().
        if model_path is None or not os.path.isfile(model_path):
            raise FileNotFoundError(f"SentencePiece model file not found: {model_path}")
        self.sp_model = sentencepiece.SentencePieceProcessor()
        self.sp_model.Load(model_path)
        # Vocabulary size and special-token IDs (BOS / EOS / PAD).
        self.n_words: int = self.sp_model.GetPieceSize()
        self.bos_id: int = self.sp_model.bos_id()
        self.eos_id: int = self.sp_model.eos_id()
        self.pad_id: int = self.sp_model.pad_id()

    def encode(self, s: str, bos: bool = True, eos: bool = False) -> List[int]:
        """Convert a string into a list of token IDs.

        Args:
            s: Input text to tokenize.
            bos: If True (default), prepend the beginning-of-sequence token.
            eos: If True, append the end-of-sequence token.

        Returns:
            Token IDs for ``s``, with optional BOS/EOS markers.

        Raises:
            TypeError: If ``s`` is not a string.
        """
        # Explicit check instead of `assert` so it survives `python -O`.
        if not isinstance(s, str):
            raise TypeError(f"expected str, got {type(s).__name__}")
        t = self.sp_model.EncodeAsIds(s)
        if bos:
            t = [self.bos_id] + t
        if eos:
            t = t + [self.eos_id]
        return t

    def decode(self, t: List[int]) -> str:
        """Convert a list of token IDs back into a string."""
        return self.sp_model.DecodeIds(t)