From 11e5baae89b212a9c4b60dcb6cd3bdcaae30ee90 Mon Sep 17 00:00:00 2001 From: Ben King Date: Wed, 2 Oct 2024 16:36:38 -0400 Subject: [PATCH] Fix for #540, issues (+ support for multiple translations) with experiment --test --- silnlp/nmt/experiment.py | 1 + silnlp/nmt/hugging_face_config.py | 107 ++++++++++++++++++++++++------ silnlp/nmt/test.py | 73 +++++++++++++++++--- 3 files changed, 150 insertions(+), 31 deletions(-) diff --git a/silnlp/nmt/experiment.py b/silnlp/nmt/experiment.py index 591e7f8d..da697a9d 100644 --- a/silnlp/nmt/experiment.py +++ b/silnlp/nmt/experiment.py @@ -83,6 +83,7 @@ def test(self): best=self.config.model_dir.exists(), by_book=self.score_by_book, scorers=self.scorers, + produce_multiple_translations=self.produce_multiple_translations, ) SIL_NLP_ENV.copy_experiment_to_bucket( self.name, patterns=("scores-*.csv", "test.*trg-predictions.*"), overwrite=True diff --git a/silnlp/nmt/hugging_face_config.py b/silnlp/nmt/hugging_face_config.py index 3b30ae52..56fa11f7 100644 --- a/silnlp/nmt/hugging_face_config.py +++ b/silnlp/nmt/hugging_face_config.py @@ -70,7 +70,7 @@ from ..common.corpus import Term, count_lines, get_terms from ..common.environment import SIL_NLP_ENV, download_if_s3_paths -from ..common.translator import TranslationGroup +from ..common.translator import DraftGroup, TranslationGroup from ..common.utils import NoiseMethod, ReplaceRandomToken, Side, create_noise_methods, merge_dict from .config import CheckpointType, Config, DataFile, NMTModel from .tokenizer import NullTokenizer, Tokenizer @@ -1008,23 +1008,64 @@ def translate_test_files( with ExitStack() as stack: src_file = stack.enter_context(input_path.open("r", encoding="utf-8-sig")) sentences = (line.strip().split() for line in src_file) - out_file = stack.enter_context(translation_path.open("w", encoding="utf-8", newline="\n")) vrefs: Optional[Iterable[VerseRef]] = None if vref_path is not None: vref_file = stack.enter_context(vref_path.open("r", encoding="utf-8-sig")) vrefs = (VerseRef.from_string(line.strip(), ORIGINAL_VERSIFICATION) for line in vref_file) - for output_group in tqdm( - self._translate_sentences( - tokenizer, pipeline, sentences, vrefs, produce_multiple_translations, return_tensors=True - ), - total=length, - unit="ex", - ): - ids = to_py_obj(output_group.get_token_ids()[0]) - ids = [id for id in ids[1:] if id != tokenizer.pad_token_id] - tokens = tokenizer.convert_ids_to_tokens(ids) - out_file.write(" ".join(tokens) + "\n") + draft_group = DraftGroup( + list( + self._translate_test_sentences( + tokenizer, pipeline, sentences, vrefs, length, produce_multiple_translations + ) + ) + ) + + for draft_index, translated_draft in enumerate(draft_group.get_drafts(), 1): + + if produce_multiple_translations: + translation_draft_path = translation_path.with_suffix( + f".{draft_index}{translation_path.suffix}" + ) + else: + translation_draft_path = translation_path + out_file = stack.enter_context(translation_draft_path.open("w", encoding="utf-8", newline="\n")) + + out_file.write("\n".join(translated_draft) + "\n") + + def _translate_test_sentences( + self, + tokenizer: PreTrainedTokenizer, + pipeline: TranslationPipeline, + sentences: Iterable[List[str]], + vrefs: Iterable[VerseRef], + length: int, + produce_multiple_translations: bool = False, + ) -> Iterable[TranslationGroup]: + num_drafts = self.get_num_drafts() + if produce_multiple_translations and num_drafts > 1: + LOGGER.info("Producing %i translated drafts", num_drafts) + elif produce_multiple_translations and num_drafts <= 1: + LOGGER.warning( + "num_drafts must be greater than 1 when using --multiple-translations. " + "Falling back to a single translation." + ) + + for output_group in tqdm( + self._translate_sentences( + tokenizer, pipeline, sentences, vrefs, produce_multiple_translations, return_tensors=True + ), + total=length, + unit="ex", + ): + ids = to_py_obj(output_group.get_token_ids()) + ids = [[id for id in output[1:] if id != tokenizer.pad_token_id] for output in ids] + tokens = [tokenizer.convert_ids_to_tokens(id_group) for id_group in ids] + yield [" ".join(token_group) for token_group in tokens] + + def get_num_drafts(self) -> int: + num_drafts = self._config.infer.get("num_drafts") + return num_drafts def translate( self, @@ -1048,7 +1089,7 @@ def translate( device=0, ) - num_drafts = self._config.infer.get("num_drafts") + num_drafts = self.get_num_drafts() if produce_multiple_translations and num_drafts > 1: LOGGER.info("Producing %i translated drafts", num_drafts) elif produce_multiple_translations and num_drafts <= 1: @@ -1300,10 +1341,13 @@ def _translate_sentence_helper( force_words_ids: List[List[List[int]]] = None, produce_multiple_translations: bool = False, ) -> Iterable[OutputGroup]: - num_drafts = self._config.infer.get("num_drafts") + + num_drafts = self.get_num_drafts() if produce_multiple_translations and num_drafts > 1: multiple_translations_method: str = self._config.infer.get("multiple_translations_method") + sentences = list(sentences) + if multiple_translations_method == "hybrid": beam_search_results: List[dict] = self._translate_with_beam_search( pipeline, @@ -1325,7 +1369,7 @@ def _translate_sentence_helper( # concatenate the beam search results with the sampling results yield from [ - OutputGroup([beam_search_results[i]] + sampling_results[i]) for i in range(len(beam_search_results)) + OutputGroup(beam_search_results[i] + sampling_results[i]) for i in range(len(beam_search_results)) ] elif multiple_translations_method == "sampling": @@ -1371,7 +1415,7 @@ def _translate_sentence_helper( else: yield from [ - OutputGroup([translated_sentence]) + OutputGroup([translated_sentence[0]]) for translated_sentence in self._translate_with_beam_search( pipeline, sentences, @@ -1382,6 +1426,12 @@ def _translate_sentence_helper( ) ] + # When translating tokenized sentences, for some reason the Huggingface pipeline + # returns List[List[dict]] instead of List[dict]. Each nested list is a + # singleton. This function flattens the structure. + def _flatten_tokenized_translations(self, pipeline_output) -> List[dict]: + return [[i if isinstance(i, dict) else i[0] for i in translation] for translation in pipeline_output] + def _translate_with_beam_search( self, pipeline: TranslationPipeline, @@ -1390,12 +1440,12 @@ def _translate_with_beam_search( return_tensors: bool, num_return_sequences: int = 1, force_words_ids: List[List[List[int]]] = None, - ) -> List[dict]: + ) -> List[List[dict]]: num_beams: Optional[int] = self._config.infer.get("num_beams") if num_beams is None: num_beams = self._config.params.get("generation_num_beams") - return pipeline( + translations = pipeline( sentences, num_beams=num_beams, num_return_sequences=num_return_sequences, @@ -1405,6 +1455,11 @@ def _translate_with_beam_search( return_tensors=return_tensors, ) + if num_return_sequences == 1: + translations = [[t] for t in translations] + + return self._flatten_tokenized_translations(translations) + def _translate_with_sampling( self, pipeline: TranslationPipeline, @@ -1417,7 +1472,7 @@ def _translate_with_sampling( temperature: Optional[int] = self._config.infer.get("temperature") - return pipeline( + translations = pipeline( sentences, do_sample=True, temperature=temperature, @@ -1428,6 +1483,11 @@ def _translate_with_sampling( return_tensors=return_tensors, ) + if num_return_sequences == 1: + translations = [[t] for t in translations] + + return self._flatten_tokenized_translations(translations) + def _translate_with_diverse_beam_search( self, pipeline: TranslationPipeline, @@ -1442,7 +1502,7 @@ def _translate_with_diverse_beam_search( num_beams = self._config.params.get("generation_num_beams") diversity_penalty: Optional[float] = self._config.infer.get("diversity_penalty") - return pipeline( + translations = pipeline( sentences, num_beams=num_beams, num_beam_groups=num_beams, @@ -1454,6 +1514,11 @@ def _translate_with_diverse_beam_search( return_tensors=return_tensors, ) + if num_return_sequences == 1: + translations = [[t] for t in translations] + + return self._flatten_tokenized_translations(translations) + def _create_inference_model( self, ckpt: Union[CheckpointType, str, int], tokenizer: PreTrainedTokenizer ) -> PreTrainedModel: diff --git a/silnlp/nmt/test.py b/silnlp/nmt/test.py index 6f04343d..42324f87 100644 --- a/silnlp/nmt/test.py +++ b/silnlp/nmt/test.py @@ -35,6 +35,7 @@ def __init__( sent_len: int, projects: Set[str], other_scores: Dict[str, float] = {}, + draft_index: int = 1, ) -> None: self.src_iso = src_iso self.trg_iso = trg_iso @@ -44,10 +45,11 @@ def __init__( self.refs = "_".join(sorted(projects)) self.other_scores = other_scores self.book = book + self.draft_index = draft_index def writeHeader(self, file: IO) -> None: header = ( - "book,src_iso,trg_iso,num_refs,references,sent_len" + "book,draft_index,src_iso,trg_iso,num_refs,references,sent_len" + ( ",BLEU,BLEU_1gram_prec,BLEU_2gram_prec,BLEU_3gram_prec,BLEU_4gram_prec,BLEU_brevity_penalty,BLEU_total_sys_len,BLEU_total_ref_len" if self.bleu is not None @@ -60,7 +62,10 @@ def writeHeader(self, file: IO) -> None: file.write(header) def write(self, file: IO) -> None: - file.write(f"{self.book},{self.src_iso},{self.trg_iso}," f"{self.num_refs},{self.refs},{self.sent_len:d}") + file.write( + f"{self.book},{self.draft_index},{self.src_iso},{self.trg_iso}," + f"{self.num_refs},{self.refs},{self.sent_len:d}" + ) if self.bleu is not None: file.write( f",{self.bleu.score:.2f},{self.bleu.precisions[0]:.2f},{self.bleu.precisions[1]:.2f}" @@ -82,6 +87,7 @@ def score_pair( scorers: Set[str], config: Config, ref_projects: Set[str], + draft_index: int = 1, ) -> PairScore: bleu_score = None if "bleu" in scorers: @@ -142,7 +148,7 @@ def score_pair( if ter_score.score >= 0: other_scores["TER"] = ter_score.score - return PairScore(book, src_iso, trg_iso, bleu_score, len(pair_sys), ref_projects, other_scores) + return PairScore(book, src_iso, trg_iso, bleu_score, len(pair_sys), ref_projects, other_scores, draft_index) def score_individual_books( @@ -361,6 +367,7 @@ def test_checkpoint( step: int, scorers: Set[str], books: Dict[int, List[int]], + produce_multiple_translations: bool = False, ) -> List[PairScore]: config.set_seed() vref_file_names: List[str] = [] @@ -413,17 +420,48 @@ def test_checkpoint( model.translate_test_files( source_paths, translation_paths, - produce_multiple_translations=False, - vref_paths=vref_paths, - ckpt=step if checkpoint_type is CheckpointType.OTHER else checkpoint_type, + produce_multiple_translations, + vref_paths, + step if checkpoint_type is CheckpointType.OTHER else checkpoint_type, ) + if produce_multiple_translations: + num_drafts = model.get_num_drafts() + vref_file_names = num_drafts * vref_file_names + source_file_names = num_drafts * source_file_names + translation_file_names = [ + str(Path(file_name).with_suffix(f".{draft_index}{Path(file_name).suffix}")) + for draft_index in range(1, num_drafts + 1) + for file_name in translation_file_names + ] + refs_patterns = num_drafts * refs_patterns + translation_detok_file_names = [ + str(Path(file_name).with_suffix(f".{draft_index}{Path(file_name).suffix}")) + for draft_index in range(1, num_drafts + 1) + for file_name in translation_detok_file_names + ] + draft_indices = num_drafts * list(range(1, num_drafts + 1)) + else: + draft_indices = len(source_file_names) * [1] + LOGGER.info(f"Scoring {checkpoint_name}") scores: List[PairScore] = [] overall_sys: List[str] = [] overall_refs: List[List[str]] = [] - for vref_file_name, features_file_name, predictions_file_name, refs_pattern, predictions_detok_file_name in zip( - vref_file_names, source_file_names, translation_file_names, refs_patterns, translation_detok_file_names + for ( + vref_file_name, + features_file_name, + predictions_file_name, + refs_pattern, + predictions_detok_file_name, + draft_index, + ) in zip( + vref_file_names, + source_file_names, + translation_file_names, + refs_patterns, + translation_detok_file_names, + draft_indices, ): src_iso = config.default_test_src_iso trg_iso = config.default_test_trg_iso @@ -456,7 +494,16 @@ def test_checkpoint( scores.append( score_pair( - pair_sys, pair_refs, "ALL", src_iso, trg_iso, predictions_detok_file_name, scorers, config, ref_projects + pair_sys, + pair_refs, + "ALL", + src_iso, + trg_iso, + predictions_detok_file_name, + scorers, + config, + ref_projects, + draft_index, ) ) @@ -495,6 +542,7 @@ def test( ref_projects: Set[str] = set(), books: List[str] = [], by_book: bool = False, + produce_multiple_translations: bool = False, ): exp_name = experiment SIL_NLP_ENV.copy_experiment_from_bucket(exp_name) @@ -527,6 +575,7 @@ def test( step, scorers, books_nums, + produce_multiple_translations, ) if avg: @@ -543,6 +592,7 @@ def test( step, scorers, books_nums, + produce_multiple_translations, ) except ValueError: LOGGER.warn("No average checkpoint available.") @@ -563,6 +613,7 @@ def test( step, scorers, books_nums, + produce_multiple_translations, ) if last or (not best and checkpoint is None and not avg and config.model_dir.exists()): @@ -579,6 +630,7 @@ def test( step, scorers, books_nums, + produce_multiple_translations, ) if not config.model_dir.exists(): @@ -593,6 +645,7 @@ def test( 0, scorers, books_nums, + produce_multiple_translations, ) SIL_NLP_ENV.copy_experiment_to_bucket( @@ -611,7 +664,7 @@ def test( checkpoint_name = f"checkpoint {step}" books_str = "ALL" if len(books_nums) == 0 else ", ".join(sorted(str(num) for num in books_nums.keys())) LOGGER.info(f"Test results for {checkpoint_name} ({num_refs} reference(s), books: {books_str})") - header = "book,src_iso,trg_iso,num_refs,references,sent_len" + header = "book,draft_index,src_iso,trg_iso,num_refs,references,sent_len" if len(results[step]) > 0: pair_score = results[step][0] header += (