sonar/inference_pipelines (1 file changed, +12 -6 lines)

@@ -143,6 +143,7 @@ def predict(
         max_seq_len: Optional[int] = None,
         progress_bar: bool = False,
         target_device: Optional[Device] = None,
+        sort_sent: bool = True,
     ) -> torch.Tensor:
         """
         Transform the input texts (from a list of strings or from a text file) into a matrix of their embeddings.
@@ -166,13 +167,18 @@ def truncate(x: torch.Tensor) -> torch.Tensor:
                 n_truncated += 1
             return x[:max_seq_len]
 
+        def sort_input(input_sent: Iterable[str]) -> Iterable[str]:
+            return sorted(input_sent, key=len) if sort_sent else input_sent
+
+        if isinstance(input, (str, Path)):
+            input_sent = read_text(input)
+        else:
+            input_sent = read_sequence(input)
+
+        sorted_input_sent = sort_input(input_sent=input_sent)
+
         pipeline: Iterable = (
-            (
-                read_text(input)
-                if isinstance(input, (str, Path))
-                else read_sequence(input)
-            )
-            .map(tokenizer_encoder)
+            sorted_input_sent.map(tokenizer_encoder)
             .map(truncate)
             .bucket(batch_size)
             .map(Collater(self.tokenizer.vocab_info.pad_idx))
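For reference, a call-site sketch of the new flag. This assumes the patched method is TextToEmbeddingModelPipeline.predict from sonar.inference_pipelines.text (the file path above is truncated) and uses the usual SONAR example encoder/tokenizer cards; only sort_sent comes from this change.

    from sonar.inference_pipelines.text import TextToEmbeddingModelPipeline

    # Card names follow the standard SONAR text-encoder example; swap in
    # whichever encoder card you actually use.
    t2vec_model = TextToEmbeddingModelPipeline(
        encoder="text_sonar_basic_encoder",
        tokenizer="text_sonar_basic_encoder",
    )

    sentences = [
        "Tiny.",
        "A noticeably longer sentence that would otherwise share a batch with the tiny one.",
    ]

    # sort_sent defaults to True after this change; pass False to keep the
    # original input order through tokenization and bucketing.
    embeddings = t2vec_model.predict(
        sentences, source_lang="eng_Latn", batch_size=2, sort_sent=True
    )
    print(embeddings.shape)

As written in the diff, when sort_sent is True the rows of the returned tensor follow the length-sorted order rather than the original input order.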
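The motivation for sorting, presumably, is that consecutive bucketing pads every sentence in a batch to its longest member, so grouping similar lengths together wastes fewer pad tokens. A minimal, self-contained sketch of that effect, independent of the fairseq2 data pipeline used above (the sentence list, batch size, and helper names are illustrative):

    from typing import Iterable, List


    def sort_input(input_sent: Iterable[str], sort_sent: bool = True) -> List[str]:
        # Same idea as the helper added in the diff: order sentences by length.
        sents = list(input_sent)
        return sorted(sents, key=len) if sort_sent else sents


    def bucket(sentences: List[str], batch_size: int) -> List[List[str]]:
        # Group consecutive sentences into fixed-size batches, like .bucket(batch_size).
        return [sentences[i : i + batch_size] for i in range(0, len(sentences), batch_size)]


    def padded_cells(batches: List[List[str]]) -> int:
        # Characters stand in for tokens: each batch is padded to its longest sentence.
        return sum(len(b) * max(len(s) for s in b) for b in batches)


    if __name__ == "__main__":
        sentences = [
            "tiny",
            "a medium length sentence",
            "hi",
            "a considerably longer example sentence here",
        ]
        for flag in (False, True):
            batches = bucket(sort_input(sentences, sort_sent=flag), batch_size=2)
            print(f"sort_sent={flag}: padded size {padded_cells(batches)}")

Running the sketch shows a smaller padded size with sort_sent=True, since each batch then contains sentences of similar length.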