Skip to content

Commit fd6a2d0

Browse files
committed
Added sorting option for sentences during encoding
1 parent f17dffa commit fd6a2d0

File tree

1 file changed

+12
-6
lines changed

1 file changed

+12
-6
lines changed

sonar/inference_pipelines/text.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ def predict(
         max_seq_len: Optional[int] = None,
         progress_bar: bool = False,
         target_device: Optional[Device] = None,
+        sort_sent: bool = True,
     ) -> torch.Tensor:
         """
         Transform the input texts (from a list of strings or from a text file) into a matrix of their embeddings.
@@ -166,13 +167,18 @@ def truncate(x: torch.Tensor) -> torch.Tensor:
                 n_truncated += 1
             return x[:max_seq_len]
 
+        def sort_input(input_sent: Iterable[str]) -> Iterable[str]:
+            return sorted(input_sent, key=len) if sort_sent else input_sent
+
+        if isinstance(input, (str, Path)):
+            input_sent = read_text(input)
+        else:
+            input_sent = read_sequence(input)
+
+        sorted_input_sent = sort_input(input_sent=input_sent)
+
         pipeline: Iterable = (
-            (
-                read_text(input)
-                if isinstance(input, (str, Path))
-                else read_sequence(input)
-            )
-            .map(tokenizer_encoder)
+            sorted_input_sent.map(tokenizer_encoder)
             .map(truncate)
             .bucket(batch_size)
             .map(Collater(self.tokenizer.vocab_info.pad_idx))

0 commit comments

Comments
 (0)