Fix for #540, issues (+ support for multiple translations) with experiment --test
Ben King committed Oct 2, 2024
1 parent fd12193 commit 11e5baa
Showing 3 changed files with 150 additions and 31 deletions.
1 change: 1 addition & 0 deletions silnlp/nmt/experiment.py
@@ -83,6 +83,7 @@ def test(self):
             best=self.config.model_dir.exists(),
             by_book=self.score_by_book,
             scorers=self.scorers,
+            produce_multiple_translations=self.produce_multiple_translations,
         )
         SIL_NLP_ENV.copy_experiment_to_bucket(
             self.name, patterns=("scores-*.csv", "test.*trg-predictions.*"), overwrite=True
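Note that the existing `test.*trg-predictions.*` pattern already covers the numbered draft files that `translate_test_files` now produces, so the bucket copy needs no further change. A quick sanity check, assuming glob-style matching and these illustrative file names (the numbered variants follow the with_suffix() naming introduced below):

from fnmatch import fnmatch

# Hypothetical prediction file names, for illustration only.
names = ["test.trg-predictions.txt", "test.trg-predictions.1.txt", "test.trg-predictions.2.txt"]
assert all(fnmatch(name, "test.*trg-predictions.*") for name in names)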
107 changes: 86 additions & 21 deletions silnlp/nmt/hugging_face_config.py
@@ -70,7 +70,7 @@

 from ..common.corpus import Term, count_lines, get_terms
 from ..common.environment import SIL_NLP_ENV, download_if_s3_paths
-from ..common.translator import TranslationGroup
+from ..common.translator import DraftGroup, TranslationGroup
 from ..common.utils import NoiseMethod, ReplaceRandomToken, Side, create_noise_methods, merge_dict
 from .config import CheckpointType, Config, DataFile, NMTModel
 from .tokenizer import NullTokenizer, Tokenizer
@@ -1008,23 +1008,64 @@ def translate_test_files(
         with ExitStack() as stack:
             src_file = stack.enter_context(input_path.open("r", encoding="utf-8-sig"))
             sentences = (line.strip().split() for line in src_file)
-            out_file = stack.enter_context(translation_path.open("w", encoding="utf-8", newline="\n"))
             vrefs: Optional[Iterable[VerseRef]] = None
             if vref_path is not None:
                 vref_file = stack.enter_context(vref_path.open("r", encoding="utf-8-sig"))
                 vrefs = (VerseRef.from_string(line.strip(), ORIGINAL_VERSIFICATION) for line in vref_file)
 
-            for output_group in tqdm(
-                self._translate_sentences(
-                    tokenizer, pipeline, sentences, vrefs, produce_multiple_translations, return_tensors=True
-                ),
-                total=length,
-                unit="ex",
-            ):
-                ids = to_py_obj(output_group.get_token_ids()[0])
-                ids = [id for id in ids[1:] if id != tokenizer.pad_token_id]
-                tokens = tokenizer.convert_ids_to_tokens(ids)
-                out_file.write(" ".join(tokens) + "\n")
+            draft_group = DraftGroup(
+                list(
+                    self._translate_test_sentences(
+                        tokenizer, pipeline, sentences, vrefs, length, produce_multiple_translations
+                    )
+                )
+            )
+
+            for draft_index, translated_draft in enumerate(draft_group.get_drafts(), 1):
+
+                if produce_multiple_translations:
+                    translation_draft_path = translation_path.with_suffix(
+                        f".{draft_index}{translation_path.suffix}"
+                    )
+                else:
+                    translation_draft_path = translation_path
+                out_file = stack.enter_context(translation_draft_path.open("w", encoding="utf-8", newline="\n"))
+
+                out_file.write("\n".join(translated_draft) + "\n")
+
+    def _translate_test_sentences(
+        self,
+        tokenizer: PreTrainedTokenizer,
+        pipeline: TranslationPipeline,
+        sentences: Iterable[List[str]],
+        vrefs: Iterable[VerseRef],
+        length: int,
+        produce_multiple_translations: bool = False,
+    ) -> Iterable[TranslationGroup]:
+        num_drafts = self.get_num_drafts()
+        if produce_multiple_translations and num_drafts > 1:
+            LOGGER.info("Producing %i translated drafts", num_drafts)
+        elif produce_multiple_translations and num_drafts <= 1:
+            LOGGER.warning(
+                "num_drafts must be greater than 1 when using --multiple-translations. "
+                "Falling back to a single translation."
+            )
+
+        for output_group in tqdm(
+            self._translate_sentences(
+                tokenizer, pipeline, sentences, vrefs, produce_multiple_translations, return_tensors=True
+            ),
+            total=length,
+            unit="ex",
+        ):
+            ids = to_py_obj(output_group.get_token_ids())
+            ids = [[id for id in output[1:] if id != tokenizer.pad_token_id] for output in ids]
+            tokens = [tokenizer.convert_ids_to_tokens(id_group) for id_group in ids]
+            yield [" ".join(token_group) for token_group in tokens]
+
+    def get_num_drafts(self) -> int:
+        num_drafts = self._config.infer.get("num_drafts")
+        return num_drafts
 
     def translate(
         self,
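The refactor delegates decoding to `_translate_test_sentences`, which yields one `TranslationGroup` (a list of draft strings) per sentence; `DraftGroup` then regroups those by draft so each draft is written to its own numbered file. A minimal sketch of the intended behavior, using a stand-in for `DraftGroup` (the real class lives in `silnlp.common.translator` and may differ) and a hypothetical file name:

from pathlib import Path
from typing import List

# Stand-in for DraftGroup, assuming it transposes per-sentence translation
# groups into per-draft sentence lists.
class DraftGroupSketch:
    def __init__(self, translation_groups: List[List[str]]) -> None:
        self.translation_groups = translation_groups

    def get_drafts(self) -> List[List[str]]:
        return [list(draft) for draft in zip(*self.translation_groups)]

groups = [["s1 draft1", "s1 draft2"], ["s2 draft1", "s2 draft2"]]
translation_path = Path("test.trg-predictions.txt")  # hypothetical name

for draft_index, translated_draft in enumerate(DraftGroupSketch(groups).get_drafts(), 1):
    # Mirrors the with_suffix() naming above: ".txt" -> ".1.txt", ".2.txt", ...
    draft_path = translation_path.with_suffix(f".{draft_index}{translation_path.suffix}")
    print(draft_path, translated_draft)
# test.trg-predictions.1.txt ['s1 draft1', 's2 draft1']
# test.trg-predictions.2.txt ['s1 draft2', 's2 draft2']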
@@ -1048,7 +1089,7 @@ def translate(
                 device=0,
             )
 
-        num_drafts = self._config.infer.get("num_drafts")
+        num_drafts = self.get_num_drafts()
         if produce_multiple_translations and num_drafts > 1:
             LOGGER.info("Producing %i translated drafts", num_drafts)
         elif produce_multiple_translations and num_drafts <= 1:
@@ -1300,10 +1341,13 @@ def _translate_sentence_helper(
         force_words_ids: List[List[List[int]]] = None,
         produce_multiple_translations: bool = False,
     ) -> Iterable[OutputGroup]:
-        num_drafts = self._config.infer.get("num_drafts")
-
+        num_drafts = self.get_num_drafts()
         if produce_multiple_translations and num_drafts > 1:
             multiple_translations_method: str = self._config.infer.get("multiple_translations_method")
 
+            sentences = list(sentences)
+
             if multiple_translations_method == "hybrid":
                 beam_search_results: List[dict] = self._translate_with_beam_search(
                     pipeline,
@@ -1325,7 +1369,7 @@

             # concatenate the beam search results with the sampling results
             yield from [
-                OutputGroup([beam_search_results[i]] + sampling_results[i]) for i in range(len(beam_search_results))
+                OutputGroup(beam_search_results[i] + sampling_results[i]) for i in range(len(beam_search_results))
             ]
 
         elif multiple_translations_method == "sampling":
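Two details in the hybrid path are worth unpacking. First, `sentences` is consumed twice (once by beam search, once by sampling) but arrives as a generator, hence the new `sentences = list(sentences)`. Second, `_translate_with_beam_search` now returns `List[List[dict]]` (one candidate list per sentence, see below), so `beam_search_results[i]` can be concatenated with `sampling_results[i]` directly; the old `[beam_search_results[i]]` would nest one level too deep. A shape-only sketch with illustrative values and dict keys:

def sentence_gen():
    yield from [["a", "b"]]

sentences = sentence_gen()
list(sentences)               # first pass consumes the generator
assert list(sentences) == []  # a second pass would see nothing -> materialize first

# One sentence: beam search contributes one candidate, sampling the rest.
beam_search_results = [[{"translation_text": "draft A"}]]
sampling_results = [[{"translation_text": "draft B"}, {"translation_text": "draft C"}]]
combined = [beam_search_results[i] + sampling_results[i] for i in range(len(beam_search_results))]
# combined[0] holds all three candidate dicts for the sentence.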
@@ -1371,7 +1415,7 @@

         else:
             yield from [
-                OutputGroup([translated_sentence])
+                OutputGroup([translated_sentence[0]])
                 for translated_sentence in self._translate_with_beam_search(
                     pipeline,
                     sentences,
@@ -1382,6 +1426,12 @@
                 )
             ]
 
+    # When translating tokenized sentences, for some reason the Huggingface pipeline
+    # returns List[List[dict]] instead of List[dict]. Each nested list is a
+    # singleton. This function flattens the structure.
+    def _flatten_tokenized_translations(self, pipeline_output) -> List[dict]:
+        return [[i if isinstance(i, dict) else i[0] for i in translation] for translation in pipeline_output]
+
     def _translate_with_beam_search(
         self,
         pipeline: TranslationPipeline,
@@ -1390,12 +1440,12 @@ def _translate_with_beam_search(
         return_tensors: bool,
         num_return_sequences: int = 1,
         force_words_ids: List[List[List[int]]] = None,
-    ) -> List[dict]:
+    ) -> List[List[dict]]:
         num_beams: Optional[int] = self._config.infer.get("num_beams")
         if num_beams is None:
             num_beams = self._config.params.get("generation_num_beams")
 
-        return pipeline(
+        translations = pipeline(
             sentences,
             num_beams=num_beams,
             num_return_sequences=num_return_sequences,
@@ -1405,6 +1455,11 @@
             return_tensors=return_tensors,
         )
 
+        if num_return_sequences == 1:
+            translations = [[t] for t in translations]
+
+        return self._flatten_tokenized_translations(translations)
+
     def _translate_with_sampling(
         self,
         pipeline: TranslationPipeline,
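All three decoding helpers now normalize the pipeline output to `List[List[dict]]`: with `num_return_sequences == 1` the pipeline yields one item per sentence, so each item is wrapped in a singleton list, and `_flatten_tokenized_translations` unwraps the extra nesting that appears with pre-tokenized input. A shape-only sketch (dict keys are illustrative, not necessarily the exact Hugging Face schema):

# Case 1: num_return_sequences == 1 -> one dict per sentence; wrap each.
translations = [{"translation_token_ids": [5, 6]}, {"translation_token_ids": [7, 8]}]
translations = [[t] for t in translations]
# -> [[{...}], [{...}]]: one candidate list per sentence

# Case 2: tokenized input -> candidates may arrive as singleton lists; unwrap.
pipeline_output = [[{"translation_token_ids": [5, 6]}, [{"translation_token_ids": [9]}]]]
flattened = [
    [i if isinstance(i, dict) else i[0] for i in translation] for translation in pipeline_output
]
# -> [[{'translation_token_ids': [5, 6]}, {'translation_token_ids': [9]}]]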
@@ -1417,7 +1472,7 @@ def _translate_with_sampling(

         temperature: Optional[int] = self._config.infer.get("temperature")
 
-        return pipeline(
+        translations = pipeline(
             sentences,
             do_sample=True,
             temperature=temperature,
@@ -1428,6 +1483,11 @@
             return_tensors=return_tensors,
         )
 
+        if num_return_sequences == 1:
+            translations = [[t] for t in translations]
+
+        return self._flatten_tokenized_translations(translations)
+
     def _translate_with_diverse_beam_search(
         self,
         pipeline: TranslationPipeline,
@@ -1442,7 +1502,7 @@ def _translate_with_diverse_beam_search(
             num_beams = self._config.params.get("generation_num_beams")
         diversity_penalty: Optional[float] = self._config.infer.get("diversity_penalty")
 
-        return pipeline(
+        translations = pipeline(
             sentences,
             num_beams=num_beams,
             num_beam_groups=num_beams,
@@ -1454,6 +1514,11 @@
             return_tensors=return_tensors,
         )
 
+        if num_return_sequences == 1:
+            translations = [[t] for t in translations]
+
+        return self._flatten_tokenized_translations(translations)
+
     def _create_inference_model(
         self, ckpt: Union[CheckpointType, str, int], tokenizer: PreTrainedTokenizer
     ) -> PreTrainedModel:
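For reference, the inference options this file reads via `self._config.infer.get(...)` and `self._config.params.get(...)`, gathered as a hypothetical configuration (values illustrative; of the method names, only "hybrid" and "sampling" are confirmed by this diff):

# Hypothetical config dict mirroring the keys read in hugging_face_config.py.
config = {
    "params": {"generation_num_beams": 2},  # fallback when infer.num_beams is unset
    "infer": {
        "num_drafts": 3,                           # used by get_num_drafts()
        "multiple_translations_method": "hybrid",  # or "sampling"
        "num_beams": 4,
        "temperature": 0.9,        # sampling only
        "diversity_penalty": 1.0,  # diverse beam search only
    },
}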