Skip to content

Commit

Permalink
Extends embed to allow sequence of texts
Browse files Browse the repository at this point in the history
  • Loading branch information
Maximiliano Marufo da Silva committed Oct 11, 2022
1 parent 1c63035 commit 99da91b
Showing 1 changed file with 15 additions and 7 deletions.
22 changes: 15 additions & 7 deletions bpemb/bpemb.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,22 +365,30 @@ def _encode(self, texts, fn):
texts = map(self.preprocess, texts)
return list(map(fn, texts))

def embed(self, text: str) -> np.ndarray:
def embed(self, texts: Union[str, Sequence[str]]) -> np.ndarray:
"""Byte-pair encode text and return the corresponding byte-pair
embeddings.
Parameters
----------
text: ``str'', required
The text to encode and embed.
texts: ``Union[str, Sequence[str]]'', required
The text or texts to encode and embed.
Returns
-------
A matrix of shape (l, d), where l is the length of the byte-pair
encoded text and d the embedding dimension.
If texts is a string, a matrix of shape (l, d), where l is the length
of the byte-pair encoded text and d the embedding dimension.
If texts is a sequence of strings, an array of shape (n,), where n is
the lenght of the text sequence, and each element is a matrix of shape
(l_i, d).
"""
ids = self.encode_ids(text)
return self.emb.vectors[ids]
ids = self.encode_ids(texts)
if isinstance(texts, str):
return self.emb.vectors[ids]
return np.asarray([
self.emb.vectors[text_ids]
for text_ids in ids
])

def decode(
self,
Expand Down

0 comments on commit 99da91b

Please sign in to comment.