Skip to content

Commit

Permalink
add kwargs
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelfeil committed May 19, 2023
1 parent 053a7a6 commit d712f19
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions hf_hub_ctranslate2/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,14 +67,14 @@ def tokenize_encode(self, text, *args, **kwargs):
def tokenize_decode(self, tokens_out, *args, **kwargs):
raise NotImplementedError

def generate(self, text: Union[str, List[str]], encode_kwargs={}, *forward_args, **forward_kwds: Any):
def generate(self, text: Union[str, List[str]], encode_kwargs={}, decode_kwargs={}, *forward_args, **forward_kwds: Any):
orig_type = list
if isinstance(text, str):
orig_type = str
text = [text]
token_list = self.tokenize_encode(text, **encode_kwargs)
tokens_out = self._forward(token_list, *forward_args, **forward_kwds)
texts_out = self.tokenize_decode(tokens_out)
texts_out = self.tokenize_decode(tokens_out, **decode_kwargs)
if orig_type == str:
return texts_out[0]
else:
Expand Down Expand Up @@ -122,11 +122,13 @@ def tokenize_decode(self, tokens_out, *args):
for i in range(len(tokens_out))
]

def generate(self, text: Union[str, List[str]], *forward_args, **forward_kwds: Any):
def generate(self, text: Union[str, List[str]], encode_tok_kwargs={}, decode_tok_kwargs={}, *forward_args, **forward_kwds: Any):
"""_summary_
Args:
text (Union[str, List[str]]): Input texts
encode_tok_kwargs (dict, optional): additional kwargs for tokenizer
decode_tok_kwargs (dict, optional): additional kwargs for tokenizer
max_batch_size (int, optional): Batch size. Defaults to 0.
batch_type (str, optional): _. Defaults to "examples".
asynchronous (bool, optional): Only False supported. Defaults to False.
Expand Down Expand Up @@ -160,7 +162,7 @@ def generate(self, text: Union[str, List[str]], *forward_args, **forward_kwds: A
Returns:
Union[str, List[str]]: text as output, if list, same len as input
"""
return super().generate(text, *forward_args, **forward_kwds)
return super().generate(text, encode_kwargs=encode_tok_kwargs, decode_kwargs=decode_tok_kwargs, *forward_args, **forward_kwds)

class MultiLingualTranslatorCT2fromHfHub(CTranslate2ModelfromHuggingfaceHub):
def __init__(
Expand Down Expand Up @@ -298,11 +300,13 @@ def tokenize_decode(self, tokens_out, *args):
]


def generate(self, text: Union[str, List[str]], *forward_args, **forward_kwds: Any):
def generate(self, text: Union[str, List[str]], encode_tok_kwargs={}, decode_tok_kwargs={}, *forward_args, **forward_kwds: Any):
"""_summary_
Args:
text (str | List[str]): Input texts
encode_tok_kwargs (dict, optional): additional kwargs for tokenizer
decode_tok_kwargs (dict, optional): additional kwargs for tokenizer
max_batch_size (int, optional): _. Defaults to 0.
batch_type (str, optional): _. Defaults to 'examples'.
asynchronous (bool, optional): _. Defaults to False.
Expand Down Expand Up @@ -330,5 +334,5 @@ def generate(self, text: Union[str, List[str]], *forward_args, **forward_kwds: A
Returns:
str | List[str]: text as output, if list, same len as input
"""
return super().generate(text, *forward_args, **forward_kwds)
return super().generate(text, encode_kwargs=encode_tok_kwargs, decode_kwargs=decode_tok_kwargs, *forward_args, **forward_kwds)

0 comments on commit d712f19

Please sign in to comment.