diff --git a/src/muse/annotation/build_notion_concept_tasks.py b/src/muse/annotation/build_notion_concept_tasks.py index 5959a23..6bfe29f 100644 --- a/src/muse/annotation/build_notion_concept_tasks.py +++ b/src/muse/annotation/build_notion_concept_tasks.py @@ -1,12 +1,14 @@ """ This script is used to prepare the input for the Notion concept annotation task with Prodigy. This corpus is built using the Notion parallel sentence corpus -and some number of Notion sentence translation corpora. +and some number of Notion sentence translation corpora. Optionally, this corpus +can be filtered to specific pair ids provided as a file with one id per line. Example Usage: - build_notion_concept_tasks.py out.jsonl notion-parallel-sents.jsonl --mt-corpus mt_corpus.jsonl - build_notion_concept_tasks.py out.jsonl notion-parallel-sents.jsonl --mt-corpus mt1.jsonl mt2.jsonl + build_notion_concept_tasks.py out.jsonl parallel-sents.jsonl --mt-corpus mt_corpus.jsonl + build_notion_concept_tasks.py out.jsonl parallel-sents.jsonl --mt-corpus mt1.jsonl mt2.jsonl + build_notion_concept_tasks.py out.jsonl parallel-sents.jsonl --mt-corpus mt.jsonl --pairfile pair_ids.txt """ import argparse @@ -17,7 +19,10 @@ def build_tasks( - parallel_corpus: pathlib.Path, mt_corpora: list[pathlib.Path], output: pathlib.Path + parallel_corpus: pathlib.Path, + mt_corpora: list[pathlib.Path], + output: pathlib.Path, + pairfile: pathlib.Path | None = None, ) -> None: # Load parallel sentences terms_df = ( @@ -27,6 +32,13 @@ def build_tasks( # Rename id to pair_id for join .rename({"id": "pair_id"}) ) + # Filter by pair ids + if pairfile: + pair_ids = None + with pairfile.open() as f: + pair_ids = {int(line.strip()) for line in f} + if pair_ids: + terms_df = terms_df.filter(pl.col("pair_id").is_in(pair_ids)) # Load machine translations mt_df = ( pl.concat([pl.read_ndjson(corpus) for corpus in mt_corpora]) @@ -36,8 +48,8 @@ def build_tasks( .rename({"tr_text": "text"}) ) - # Join dataframes on pair_id - 
result_df = mt_df.join(terms_df, "pair_id") + # Join on pair_id then shuffle rows + result_df = mt_df.join(terms_df, "pair_id").sample(fraction=1, shuffle=True) # Write output result_df.write_ndjson(output) @@ -58,9 +70,14 @@ def main(): required=True, help="One or more machine translation corpora", ) + parser.add_argument( + "--pairfile", + type=pathlib.Path, + required=False, + help="File containing a list of pair ids (one per line) to filter to", + ) args = parser.parse_args() - if not args.parallel_corpus: print(f"Error: {args.parallel_corpus} does not exist", sys.stderr) sys.exit(1) @@ -71,11 +88,19 @@ def main(): if args.output.is_file(): print(f"Error: {args.output} exist. Not overwriting.") sys.exit(1) + if args.pairfile: + if not args.pairfile.is_file(): + print(f"Error: pairfile {args.pairfile} does not exist", file=sys.stderr) + sys.exit(1) + elif args.pairfile.stat().st_size == 0: + print(f"Error: pairfile {args.pairfile} is zero size", file=sys.stderr) + sys.exit(1) build_tasks( args.parallel_corpus, args.mt_corpus, args.output, pairfile=args.pairfile, )