added datasets and models for text generation evaluation #291
base: main
```diff
@@ -1,4 +1,5 @@
 import numpy as np
+import torch.cuda
 from datasets import load_dataset
 from sacrebleu import corpus_bleu
 from transformers import pipeline
```
```diff
@@ -11,31 +12,54 @@ def sacrebleu_score(hypotheses, references):
     return corpus_bleu(hypotheses, [references]).score


+def _process_data(dataset_name, split):
+    '''Function for extracting expected columns and create a dataset.'''
+    if dataset_name == "xsum":
+        hf_dataset = load_dataset(dataset_name, "3.0.0", split=split)
+        dataset = KeyValueDataset.from_huggingface(
+            hf_dataset, TaskType.TEXT_TO_TEXT_GENERATION, ["document", "summary"]
+        )
+        return dataset
+    elif dataset_name == "cnn_dailymail":
+        hf_dataset = load_dataset(dataset_name, "3.0.0", split=split)
+        dataset = KeyValueDataset.from_huggingface(
+            hf_dataset, TaskType.TEXT_TO_TEXT_GENERATION, ["article", "highlights"]
+        )
+        return dataset
+    elif dataset_name == "big_patent":
+        hf_dataset = load_dataset(dataset_name, split)
+        dataset = KeyValueDataset.from_huggingface(
+            hf_dataset, TaskType.TEXT_TO_TEXT_GENERATION, ["description", "abstract"]
+        )
+        return dataset
+    elif dataset_name == "billsum":
+        hf_dataset = load_dataset(dataset_name, split)
+        dataset = KeyValueDataset.from_huggingface(
+            hf_dataset, TaskType.TEXT_TO_TEXT_GENERATION, ["text", "summary"]
+        )
```
Review comment: Also, I would suggest adding an else block and raising an exception with a proper message.
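A minimal sketch of that suggestion (the helper name, error type, and message wording here are my assumptions, not code from this PR):

```python
# Sketch only: dataset -> expected (input, target) columns, mirroring the diff above.
SUPPORTED_COLUMNS = {
    "xsum": ["document", "summary"],
    "cnn_dailymail": ["article", "highlights"],
    "big_patent": ["description", "abstract"],
    "billsum": ["text", "summary"],
}

def _validate_dataset_name(dataset_name):
    '''Raise a descriptive error instead of silently returning None.'''
    if dataset_name not in SUPPORTED_COLUMNS:
        raise ValueError(
            f"Unsupported dataset '{dataset_name}'; "
            f"expected one of {sorted(SUPPORTED_COLUMNS)}."
        )
```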
```diff
 def evaluate(
-    operation, evaluate_filter, model_name, dataset_name, split="test[:20%]"
-):
+    operation, evaluate_filter, model_name,
+    dataset_name, split="test[:20%]", is_cuda=torch.cuda.is_available()):
     # load model
-    if model_name is None:
-        model_name = "sshleifer/distilbart-xsum-12-6"
+    if model_name is None: model_name = "sshleifer/distilbart-xsum-12-6" # default model
     # load test set
-    if dataset_name is None:
-        dataset_name = "xsum"
+    if dataset_name is None: dataset_name = "xsum" # default dataset

     print(
         f"Loading <{dataset_name}> dataset to evaluate <{model_name}> model."
     )
-    hf_dataset = (
-        load_dataset(dataset_name, "3.0.0", split=split)
-        if dataset_name == "xsum"
-        else load_dataset(dataset_name, split=split)
-    )
-
-    dataset = KeyValueDataset.from_huggingface(
-        hf_dataset, TaskType.TEXT_TO_TEXT_GENERATION, ["document", "summary"]
-    )
     summarization_pipeline = pipeline(
-        "summarization", model=model_name, tokenizer=model_name
-    )
+        "summarization", model=model_name, tokenizer=model_name, device=0 if is_cuda else -1)
+    #percent = f"[{split.split('[')[-1]}" if "[" in split else ""
+    #if dataset_name == "wikihow": split = "all[:1%]" # f"all{percent}"
```
Review comment: I think we can remove this commented code.

```diff
+    dataset = _process_data(dataset_name, split)
+    print(
+        f"Here is the performance of the model {model_name} on the {split} split of the {dataset_name} dataset"
+    )

     print(
         f"Here is the performance of the model {model_name} on the {split} split of the {dataset_name} dataset"
     )
```

Review comment: Duplicate print statement.
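Taken together with the comments above, the tail of evaluate would then shrink to something like this (a fragment, assuming the surrounding names from the diff):

```python
    summarization_pipeline = pipeline(
        "summarization",
        model=model_name,
        tokenizer=model_name,
        device=0 if is_cuda else -1,  # transformers convention: GPU index, or -1 for CPU
    )
    dataset = _process_data(dataset_name, split)
    # Commented-out wikihow handling dropped; only one status print kept.
    print(
        f"Here is the performance of the model {model_name} "
        f"on the {split} split of the {dataset_name} dataset"
    )
```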
```diff
@@ -55,20 +79,16 @@ def evaluate(


 def filter_performance(dataset, summarization_pipeline, filter):
+    '''Evaluate performance on filtered dataset.'''
     print("Here is the performance of the model on the filtered set")
     filtered_dataset = dataset.apply_filter(filter, subfields=["document"])
     return performance_on_dataset(filtered_dataset, summarization_pipeline)


-"""
-Evaluates performance on the original set
-and on the perturbed set.
-"""
-
-
 def transformation_performance(
     dataset, summarization_pipeline, transformation
 ):
+    '''Evaluates performance on the original set and on the perturbed set.'''
     performance = performance_on_dataset(
         dataset, summarization_pipeline
     ) # 15.989 BLEU
```
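As background for the BLEU figure noted in the diff above: sacrebleu's corpus_bleu takes a list of hypotheses plus a list of reference streams, which is why sacrebleu_score wraps references in an extra list. A standalone sanity check:

```python
from sacrebleu import corpus_bleu

hypotheses = ["the cat sat on the mat"]
references = ["the cat sat on the mat"]

# One reference stream for the whole corpus, hence the nested list.
print(corpus_bleu(hypotheses, [references]).score)  # 100.0 for an exact match
```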
```diff
@@ -83,11 +103,13 @@ def transformation_performance(


 def performance_on_dataset(dataset, summarization_pipeline):
+    '''Evaluate performance on a given dataset.'''
     references = []
     raw_hypotheses = []
     print(f"Length of Evaluation dataset is {len(dataset)}")

-    for example in dataset:
+    for i,example in enumerate(dataset):
+        print(i)
```

Review comment: Do we need this print statement?

```diff
         article, gold_summary = example
         max_len = (
             len(gold_summary.split(" ")) + 10
```
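If the print(i) is only meant as a progress indicator, one option is a tqdm progress bar instead (an assumption on my part; tqdm is not imported in this file, and this is just a fragment of the loop above):

```python
from tqdm import tqdm

# Same iteration as the diff, but with a progress bar instead of bare prints.
for example in tqdm(dataset, desc="Evaluating"):
    article, gold_summary = example
```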
Review comment: Missing return statement for "billsum".
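Concretely, the billsum branch needs the same return dataset the other branches already have; without it the caller receives None:

```python
    elif dataset_name == "billsum":
        hf_dataset = load_dataset(dataset_name, split)
        dataset = KeyValueDataset.from_huggingface(
            hf_dataset, TaskType.TEXT_TO_TEXT_GENERATION, ["text", "summary"]
        )
        return dataset  # this return is the missing piece
```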