diff --git a/sonar/inference_pipelines/text.py b/sonar/inference_pipelines/text.py
index 0d8dd00..f51c4c0 100644
--- a/sonar/inference_pipelines/text.py
+++ b/sonar/inference_pipelines/text.py
@@ -143,6 +143,7 @@ def predict(
         max_seq_len: Optional[int] = None,
         progress_bar: bool = False,
         target_device: Optional[Device] = None,
+        sort_sent: bool = True,
     ) -> torch.Tensor:
         """
         Transform the input texts (from a list of strings or from a text file) into a matrix of their embeddings.
@@ -165,12 +166,28 @@ def truncate(x: torch.Tensor) -> torch.Tensor:
                 nonlocal n_truncated
                 n_truncated += 1
             return x[:max_seq_len]
+
+        def sort_input(input_sent: Sequence[str]) -> Sequence[str]:
+            # Order sentences by length when `sort_sent` is set; otherwise
+            # hand the input back untouched.
+            # NOTE(review): nothing in this hunk restores the original order,
+            # so with sort_sent=True the returned rows follow the sorted order
+            # unless the rest of predict() un-permutes them -- please confirm.
+            return sorted(input_sent, key=len) if sort_sent else input_sent
+
+        if not isinstance(input, (str, Path)):
+            input_data = read_sequence(sort_input(input))
+        elif sort_sent:
+            # read_text() gives a pipeline builder, not a list of strings, so
+            # the lines are materialised first and then wrapped again.
+            lines = list(read_text(input).and_return())
+            input_data = read_sequence(sort_input(lines))
+        else:
+            # No sorting requested: keep reading the file lazily, as before.
+            input_data = read_text(input)
 
         pipeline: Iterable = (
             (
-                read_text(input)
-                if isinstance(input, (str, Path))
-                else read_sequence(input)
+                input_data
             )
             .map(tokenizer_encoder)
             .map(truncate)