Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 32 additions & 7 deletions src/muse/annotation/build_notion_concept_tasks.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
"""
This script is used to prepare the input for the Notion concept annotation task
with Prodigy. This corpus is built using the Notion parallel sentence corpus
and some number of Notion sentence translation corpora.
and some number of Notion sentence translation corpora. Optionally, this corpus
can be filtered to specific pair ids provided as a file with one id per line.

Example Usage:

build_notion_concept_tasks.py out.jsonl notion-parallel-sents.jsonl --mt-corpus mt_corpus.jsonl
build_notion_concept_tasks.py out.jsonl notion-parallel-sents.jsonl --mt-corpus mt1.jsonl mt2.jsonl
build_notion_concept_tasks.py out.jsonl parallel-sents.jsonl --mt-corpus mt_corpus.jsonl
build_notion_concept_tasks.py out.jsonl parallel-sents.jsonl --mt-corpus mt1.jsonl mt2.jsonl
build_notion_concept_tasks.py out.jsonl parallel-sents.jsonl --mt-corpus mt.jsonl --pairfile pair_ids.txt
"""

import argparse
Expand All @@ -17,7 +19,10 @@


def build_tasks(
parallel_corpus: pathlib.Path, mt_corpora: list[pathlib.Path], output: pathlib.Path
parallel_corpus: pathlib.Path,
mt_corpora: list[pathlib.Path],
output: pathlib.Path,
pairfile: pathlib.Path | None = None,
) -> None:
# Load parallel sentences
terms_df = (
Expand All @@ -27,6 +32,13 @@ def build_tasks(
# Rename id to pair_id for join
.rename({"id": "pair_id"})
)
# Filter by pair ids
if pairfile:
pair_ids = None
with pairfile.open() as f:
pair_ids = {int(line.strip()) for line in f}
if pair_ids:
terms_df = terms_df.filter(pl.col("pair_id").is_in(pair_ids))
# Load machine translations
mt_df = (
pl.concat([pl.read_ndjson(corpus) for corpus in mt_corpora])
Expand All @@ -36,8 +48,8 @@ def build_tasks(
.rename({"tr_text": "text"})
)

# Join dataframes on pair_id
result_df = mt_df.join(terms_df, "pair_id")
# Join on pair_id then shuffle rows
result_df = mt_df.join(terms_df, "pair_id").sample(fraction=1, shuffle=True)

# Write output
result_df.write_ndjson(output)
Expand All @@ -58,9 +70,14 @@ def main():
required=True,
help="One or more machine translation corpora",
)
parser.add_argument(
"--pairfile",
type=pathlib.Path,
required=False,
help="File containing a list of pair ids (one per line) to filter to",
)

args = parser.parse_args()

if not args.parallel_corpus:
print(f"Error: {args.parallel_corpus} does not exist", file=sys.stderr)
sys.exit(1)
Expand All @@ -71,11 +88,19 @@ def main():
if args.output.is_file():
print(f"Error: {args.output} exists. Not overwriting.")
sys.exit(1)
if args.pairfile:
if not args.pairfile.is_file():
print(f"Error: pairfile {args.pairfile} does not exist", file=sys.stderr)
sys.exit(1)
elif args.pairfile.stat().st_size == 0:
print(f"Error: pairfile {args.pairfile} is zero size", file=sys.stderr)
sys.exit(1)

build_tasks(
args.parallel_corpus,
args.mt_corpus,
args.output,
pairfile=args.pairfile,
)


Expand Down
Loading