
Commit ddf6668

Closes SEACrowd#211 | Implement dataloader for SEAHORSE (SEACrowd#407)
* implement seahorse dataloader
* update
* update
* incorporate the latest comments though tensorflow still needed for tfds
* update
* update
1 parent ca8e109 commit ddf6668

File tree

2 files changed: +194, -0 lines


seacrowd/sea_datasets/seahorse/__init__.py

Whitespace-only changes.
Lines changed: 194 additions & 0 deletions
@@ -0,0 +1,194 @@
from pathlib import Path

import datasets
import pandas as pd

from seacrowd.utils import schemas
from seacrowd.utils.configs import SEACrowdConfig
from seacrowd.utils.constants import Licenses, Tasks

_CITATION = """
@inproceedings{clark-etal-2023-seahorse,
    title = "{SEAHORSE}: A Multilingual, Multifaceted Dataset for Summarization Evaluation",
    author = "Clark, Elizabeth and
      Rijhwani, Shruti and
      Gehrmann, Sebastian and
      Maynez, Joshua and
      Aharoni, Roee and
      Nikolaev, Vitaly and
      Sellam, Thibault and
      Siddhant, Aditya and
      Das, Dipanjan and
      Parikh, Ankur",
    editor = "Bouamor, Houda and
      Pino, Juan and
      Bali, Kalika",
    booktitle = "Proceedings of the 2023 Conference on Empirical Methods in Natural Language Processing",
    month = dec,
    year = "2023",
    address = "Singapore",
    publisher = "Association for Computational Linguistics",
    url = "https://aclanthology.org/2023.emnlp-main.584",
    doi = "10.18653/v1/2023.emnlp-main.584",
    pages = "9397--9413",
}
"""

_DATASETNAME = "seahorse"

_DESCRIPTION = """
SEAHORSE is a dataset for multilingual, multifaceted summarization evaluation. It consists of 96K summaries with human
ratings along 6 quality dimensions: comprehensibility, repetition, grammar, attribution, main idea(s), and conciseness,
covering 6 languages, 9 systems and 4 datasets.
"""

_HOMEPAGE = "https://github.com/google-research-datasets/seahorse"

_LANGUAGES = ["vie"]

_LICENSE = Licenses.CC_BY_4_0.value

_LOCAL = False

_URLS = "https://storage.googleapis.com/seahorse-public/seahorse_data.zip"

_SUPPORTED_TASKS = [Tasks.SUMMARIZATION]

_SOURCE_VERSION = "1.0.0"

_SEACROWD_VERSION = "1.0.0"

# The original dataset only contains gem_id; the source article must be retrieved from GEM, following
# https://github.com/google-research-datasets/seahorse?tab=readme-ov-file#retrieving-articles-from-gem
def get_wikilingual_data(lang, split):
    ds = datasets.load_dataset("gem", name=f"wiki_lingua_{lang}", split=split)
    df = ds.to_pandas()
    return dict(zip(*[df[col] for col in ["gem_id", "source"]]))


def get_xlsum_data(lang, split):
    df = datasets.load_dataset("GEM/xlsum", lang)
    return {item["gem_id"]: item["text"] for item in df[split]}


# Both the train and validation splits in SEAHORSE are taken from the validation split of the original datasets.
_WIKILINGUAL_DATA = {split: get_wikilingual_data("vietnamese_vi", split) for split in ["test", "validation"]}
_XLSUM_DATA = {split: get_xlsum_data("vietnamese", split) for split in ["test", "validation"]}


def get_article(gem_id, split):
    if "wiki_lingua" in gem_id:
        data = _WIKILINGUAL_DATA
    elif "xlsum" in gem_id:
        data = _XLSUM_DATA
    else:
        raise AssertionError("gem_id should be either from wiki_lingua or xlsum.")
    return data[split if split == "test" else "validation"][gem_id]
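
# Illustrative example (the gem_id below is hypothetical): an id such as
# "xlsum_vietnamese-validation-42" contains "xlsum", so it is routed to
# _XLSUM_DATA; a row from SEAHORSE's train split is still looked up under
# "validation", because SEAHORSE's train and validation summaries were both
# generated from the source datasets' validation articles.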

class SeahorseDataset(datasets.GeneratorBasedBuilder):
    """Seahorse is a dataset for multilingual, multifaceted summarization evaluation."""

    SOURCE_VERSION = datasets.Version(_SOURCE_VERSION)
    SEACROWD_VERSION = datasets.Version(_SEACROWD_VERSION)

    BUILDER_CONFIGS = [
        SEACrowdConfig(
            name=f"{_DATASETNAME}_source",
            version=datasets.Version(_SOURCE_VERSION),
            description=f"{_DATASETNAME} source schema",
            schema="source",
            subset_id=_DATASETNAME,
        ),
        SEACrowdConfig(
            name=f"{_DATASETNAME}_seacrowd_t2t",
            version=datasets.Version(_SEACROWD_VERSION),
            description=f"{_DATASETNAME} SEACrowd schema",
            schema="seacrowd_t2t",
            subset_id=_DATASETNAME,
        ),
    ]

    DEFAULT_CONFIG_NAME = f"{_DATASETNAME}_source"

    def _info(self) -> datasets.DatasetInfo:
        if self.config.schema == "source":
            features = datasets.Features(
                {
                    "gem_id": datasets.Value("string"),
                    "summary": datasets.Value("string"),
                    "model": datasets.Value("string"),
                    "question1": datasets.Value("string"),
                    "question2": datasets.Value("string"),
                    "question3": datasets.Value("string"),
                    "question4": datasets.Value("string"),
                    "question5": datasets.Value("string"),
                    "question6": datasets.Value("string"),
                }
            )

        elif self.config.schema == "seacrowd_t2t":
            features = schemas.text2text_features

        return datasets.DatasetInfo(
            description=_DESCRIPTION,
            features=features,
            homepage=_HOMEPAGE,
            license=_LICENSE,
            citation=_CITATION,
        )

    def _split_generators(self, dl_manager: datasets.DownloadManager) -> list[datasets.SplitGenerator]:
        data_dir = dl_manager.download_and_extract(_URLS)

        return [
            datasets.SplitGenerator(
                name=datasets.Split.TRAIN,
                gen_kwargs={
                    "filepath": f"{data_dir}/seahorse_data/train.tsv",
                    "split": "train",
                },
            ),
            datasets.SplitGenerator(
                name=datasets.Split.VALIDATION,
                gen_kwargs={
                    "filepath": f"{data_dir}/seahorse_data/validation.tsv",
                    "split": "dev",
                },
            ),
            datasets.SplitGenerator(
                name=datasets.Split.TEST,
                gen_kwargs={
                    "filepath": f"{data_dir}/seahorse_data/test.tsv",
                    "split": "test",
                },
            ),
        ]

    def _generate_examples(self, filepath: Path, split: str) -> tuple[int, dict]:
        df = pd.read_csv(filepath, sep="\t")
        mask = df["worker_lang"] == "vi"
        df_vi = df[mask]
        if self.config.schema == "source":
            for i, row in df_vi.iterrows():
                yield i, {
                    "gem_id": row["gem_id"],
                    "summary": row["summary"],
                    "model": row["model"],
                    "question1": row["question1"],
                    "question2": row["question2"],
                    "question3": row["question3"],
                    "question4": row["question4"],
                    "question5": row["question5"],
                    "question6": row["question6"],
                }

        elif self.config.schema == "seacrowd_t2t":
            for i, row in df_vi.iterrows():
                yield i, {
                    "id": str(i),
                    "text_1": get_article(row["gem_id"], split),
                    "text_2": row["summary"],
                    "text_1_name": "article",
                    "text_2_name": "summary",
                }
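
A minimal usage sketch (assuming the loader script above is saved as seacrowd/sea_datasets/seahorse/seahorse.py; the filename is inferred from the repository layout and is not shown in this diff):

import datasets

# Load the Vietnamese subset in the SEACrowd text-to-text schema. Importing the
# script also triggers the module-level GEM downloads (wiki_lingua and xlsum)
# used to recover the source articles.
ds = datasets.load_dataset(
    "seacrowd/sea_datasets/seahorse/seahorse.py",
    name="seahorse_seacrowd_t2t",
    trust_remote_code=True,  # may be required by newer versions of `datasets` for script-based loaders
)
print(ds["validation"][0])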
