-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtasks.py
142 lines (113 loc) · 5.09 KB
/
tasks.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import requests
import json
import concurrent.futures
from abc import ABC, abstractmethod
from typing import List, Dict, Callable
from tqdm import tqdm
import pandas as pd
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix
from datasets import load_dataset
import os
import gc
######
# GlobalVars
######
# Absolute path of the directory containing this source file; used below to
# resolve dataset files and the HuggingFace cache relative to the repo layout.
file_path = os.path.abspath(os.path.dirname(__file__))
class DataProcessor(ABC):
    """Abstract interface for a dataset task.

    A concrete task knows how to produce train/test example dicts and how to
    score a predictor against them.

    Parameters
    ----------
    data_dir : str
        Directory holding (or caching) the task's data.
    max_threads : int, optional
        Upper bound on worker threads an implementation may use (default 1).
    """

    def __init__(self, data_dir, max_threads=1):
        self.data_dir = data_dir
        self.max_threads = max_threads

    @abstractmethod
    def get_train_examples(self):
        """Return the training examples as a list of dicts."""

    @abstractmethod
    def get_test_examples(self):
        """Return the held-out test examples as a list of dicts."""

    @abstractmethod
    def evaluate(self, predictor, test_exs):
        """Score *predictor* on *test_exs* and return the task's metrics."""

    @abstractmethod
    def stringify_prediction(self, pred):
        """Map a raw model prediction to a human-readable label."""
class ClassificationTask(DataProcessor):
    """Classification task scored with accuracy, micro-F1 and a confusion
    matrix."""

    def run_evaluate(self, predictor, prompt, test_exs, n=100):
        """Run *predictor* with *prompt* over the first *n* test examples.

        Returns a tuple ``(f1, accuracy, confusion_matrix, texts, labels,
        preds)`` where the last three are the per-example inputs, gold
        labels and predictions, in evaluation order.
        """
        texts, labels, preds = [], [], []
        for example in test_exs[:n]:
            texts.append(example['text'])
            labels.append(example['label'])
            preds.append(predictor.inference(example, prompt))
        micro_f1 = f1_score(labels, preds, average='micro')
        acc = accuracy_score(labels, preds)
        conf = confusion_matrix(labels, preds)
        return micro_f1, acc, conf, texts, labels, preds

    def evaluate(self, predictor, prompt, test_exs, n=100):
        """Like :meth:`run_evaluate`, but retries indefinitely when a worker
        pool dies or an SSL handshake fails (transient infrastructure
        errors)."""
        while True:
            try:
                return self.run_evaluate(predictor, prompt, test_exs, n=n)
            except (concurrent.futures.process.BrokenProcessPool,
                    requests.exceptions.SSLError):
                # Transient failure: retry the whole evaluation pass.
                continue
class ClimateBinaryTask(ClassificationTask):
    """Binary claim verification on the CLIMATE-FEVER test split.

    Only label values 0 ('Supports') and 1 ('Refutes') are used; examples are
    balanced, 100 of each class per split.
    """
    categories = ['Supports', 'Refutes']
    # NOTE(review): this runs at class-definition time, so merely importing
    # this module loads (and possibly downloads) the dataset as a side
    # effect. Left in place because other code may read
    # ClimateBinaryTask.raw_dataset directly — confirm before making it lazy.
    raw_dataset = load_dataset("climate_fever", split='test', cache_dir=file_path+"/../../cache/datasets")

    def _build_examples(self, prefix, idx_range):
        """Build a balanced example list for one split.

        Takes rows *idx_range* of the supporting claims (claim_label 0)
        followed by the same rows of the refuting claims (claim_label 1).
        Ids are unique across the whole list (the previous version restarted
        the counter for each class, producing duplicate ids).
        """
        exs = []
        i = 0
        for label in (0, 1):
            subset = self.raw_dataset.filter(lambda x: x["claim_label"] == label)
            for row in subset.select(idx_range):
                exs.append({'id': f'{prefix}-{i}', 'label': row['claim_label'], 'text': row['claim']})
                i += 1
        return exs

    def get_train_examples(self):
        """Rows 100-199 of each class (disjoint from the test rows)."""
        return self._build_examples('train', range(100, 200))

    def get_test_examples(self):
        """The first 100 rows of each class."""
        return self._build_examples('test', range(0, 100))

    def stringify_prediction(self, pred):
        """Map an integer label (0/1) to its category name."""
        return ClimateBinaryTask.categories[pred]
class PolitifactBinaryTask(ClassificationTask):
    """Binary fact-checking over a PolitiFact JSON-lines dump.

    Only records with an unambiguous 'true' or 'false' verdict are kept;
    label 0 = 'Supports' (true), label 1 = 'Refutes' (false). Splits are
    balanced, 100 statements of each class.
    """
    categories = ['Supports', 'Refutes']

    def __init__(self, data_dir, max_threads=1):
        # Delegate the common attributes instead of duplicating the base
        # class initialization inline (previous version re-set them by hand).
        super().__init__(data_dir, max_threads=max_threads)
        # One JSON object per line; keep only binary-verdict records.
        self.data = []
        with open(file_path+"/datasets/politifact_factcheck_data.json", 'r') as file:
            for line in file:
                record = json.loads(line.strip())
                if record['verdict'] in {'true', 'false'}:
                    self.data.append(record)

    def _build_examples(self, prefix, sl):
        """Build a balanced example list for one split.

        Takes the *sl* slice of the true statements (label 0) followed by the
        same slice of the false statements (label 1). Ids are unique across
        the whole list (the previous version restarted the counter for each
        class, producing duplicate ids).
        """
        exs = []
        i = 0
        for verdict, label in (('true', 0), ('false', 1)):
            rows = [d for d in self.data if d['verdict'] == verdict]
            for row in rows[sl]:
                exs.append({'id': f'{prefix}-{i}', 'label': label, 'text': row['statement']})
                i += 1
        return exs

    def get_train_examples(self):
        """Statements 100-199 of each class (disjoint from the test split)."""
        return self._build_examples('train', slice(100, 200))

    def get_test_examples(self):
        """The first 100 statements of each class."""
        return self._build_examples('test', slice(0, 100))

    def stringify_prediction(self, pred):
        """Map an integer label (0/1) to its category name."""
        return PolitifactBinaryTask.categories[pred]