diff --git a/utils_nlp/models/transformers/abstractive_summarization_bertsum.py b/utils_nlp/models/transformers/abstractive_summarization_bertsum.py
index 03beb3e24..0984f7c65 100644
--- a/utils_nlp/models/transformers/abstractive_summarization_bertsum.py
+++ b/utils_nlp/models/transformers/abstractive_summarization_bertsum.py
@@ -318,7 +318,7 @@ def preprocess(self, story_lines, summary_lines=None):
                 if len(line) <= 0:
                     continue
                 story_lines_token_ids.append(
-                    self.tokenizer.encode(line, max_length=self.max_src_len)
+                    self.tokenizer.encode(line, truncation=True, max_length=self.max_src_len)
                 )
             except:
                 print(line)
@@ -333,7 +333,7 @@ def preprocess(self, story_lines, summary_lines=None):
                 if len(line) <= 0:
                     continue
                 summary_lines_token_ids.append(
-                    self.tokenizer.encode(line, max_length=self.max_tgt_len)
+                    self.tokenizer.encode(line, truncation=True, max_length=self.max_tgt_len)
                 )
             except:
                 print(line)
diff --git a/utils_nlp/models/transformers/bertsum/predictor.py b/utils_nlp/models/transformers/bertsum/predictor.py
index 20a30b895..81b6beb76 100644
--- a/utils_nlp/models/transformers/bertsum/predictor.py
+++ b/utils_nlp/models/transformers/bertsum/predictor.py
@@ -256,7 +256,7 @@ def _fast_translate_batch(self, src, segs, mask_src, max_length, min_length=0):
            topk_log_probs = topk_scores * length_penalty

            # Resolve beam origin and true word ids.
-           topk_beam_index = topk_ids.div(vocab_size)
+           topk_beam_index = topk_ids.true_divide(vocab_size)
            topk_ids = topk_ids.fmod(vocab_size)

            # Map beam_index to batch_index in the flat representation.
@@ -267,7 +267,7 @@ def _fast_translate_batch(self, src, segs, mask_src, max_length, min_length=0):

            # Append last prediction.
            alive_seq = torch.cat(
-               [alive_seq.index_select(0, select_indices), topk_ids.view(-1, 1)], -1
+               [alive_seq.index_select(0, select_indices.view(-1).long()), topk_ids.view(-1, 1)], -1
            )

            is_finished = topk_ids.eq(self.end_token)
@@ -310,9 +310,9 @@ def _fast_translate_batch(self, src, segs, mask_src, max_length, min_length=0):
                )
            # Reorder states.
            select_indices = batch_index.view(-1)
-           src_features = src_features.index_select(0, select_indices)
+           src_features = src_features.index_select(0, select_indices.view(-1).long())
            dec_states.map_batch_fn(
-               lambda state, dim: state.index_select(dim, select_indices)
+               lambda state, dim: state.index_select(dim, select_indices.view(-1).long())
            )

            empty_output = [len(results["predictions"][b]) <= 0 for b in batch_offset]