From 99da91bb9dd127e8eed62c23acdb1f76aa71f8f4 Mon Sep 17 00:00:00 2001 From: Maximiliano Marufo da Silva Date: Tue, 11 Oct 2022 11:00:16 -0300 Subject: [PATCH] Extends embed to allow sequence of texts --- bpemb/bpemb.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/bpemb/bpemb.py b/bpemb/bpemb.py index 7a212f9..b33c1a6 100644 --- a/bpemb/bpemb.py +++ b/bpemb/bpemb.py @@ -365,22 +365,30 @@ def _encode(self, texts, fn): texts = map(self.preprocess, texts) return list(map(fn, texts)) - def embed(self, text: str) -> np.ndarray: + def embed(self, texts: Union[str, Sequence[str]]) -> np.ndarray: """Byte-pair encode text and return the corresponding byte-pair embeddings. Parameters ---------- - text: ``str'', required - The text to encode and embed. + texts: ``Union[str, Sequence[str]]'', required + The text or texts to encode and embed. Returns ------- - A matrix of shape (l, d), where l is the length of the byte-pair - encoded text and d the embedding dimension. + If texts is a string, a matrix of shape (l, d), where l is the length + of the byte-pair encoded text and d the embedding dimension. + If texts is a sequence of strings, an array of shape (n,), where n is + the lenght of the text sequence, and each element is a matrix of shape + (l_i, d). """ - ids = self.encode_ids(text) - return self.emb.vectors[ids] + ids = self.encode_ids(texts) + if isinstance(texts, str): + return self.emb.vectors[ids] + return np.asarray([ + self.emb.vectors[text_ids] + for text_ids in ids + ]) def decode( self,