-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathprecompute_groups.py
131 lines (95 loc) · 3.79 KB
/
precompute_groups.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
"""
Precompute groups for DRO
Fasttext code adapted from: https://amitness.com/2019/07/identify-text-language-python/
"""
import os
import json
from collections import defaultdict, Counter
from sklearn.cluster import KMeans
import numpy as np
import fasttext
import torch
from transformers import AutoTokenizer
from datasets import CreoleDataset, SinglishSMSDataset
def test(path, sentences):
with open(path, 'r') as infile:
sentence_dict = json.loads(infile.read())
clusters = []
for i, sent in enumerate(sentences):
cluster = sentence_dict[sent]["cluster"]
clusters.append(cluster)
assert len(clusters) == len(sentences)
print(f"Passes the test!")
def load_creole_and_all(file, creole):
    """Collect every sentence labelled with the target creole.

    :param file: path to a JSON file holding a list of
        ``{sentence: language}`` dicts
    :param creole: language label to keep (e.g. "singlish")
    :return: list of matching sentences, preserving file order
    """
    with open(file, "r", encoding="utf-8") as input_file:
        entries = json.load(input_file)  # list of dicts
    return [
        sent
        for entry in entries
        for sent, lang in entry.items()
        if lang == creole
    ]
def main(src_dir, src_file, creole, out_file):
    """Predict per-sentence language distributions for a creole corpus.

    Builds, in memory, a list of ``{sent: {"en": .8, "zh": .1, ...}}``
    dicts scoring each sentence against the creole's substrate languages.
    Writing the result to ``out_file`` is currently disabled (commented
    out below).

    :param src_dir: directory holding the source corpus
    :param src_file: JSON file inside src_dir (list of {sentence: language} dicts)
    :param creole: key into creole_LUT ("singlish", "haitian" or "naija")
    :param out_file: name of the output JSON file (unused while writing is disabled)
    :return: None
    """
    top_langs_distribution = defaultdict(list)
    chosen_langs_distribution = defaultdict(list)
    full_src_path = os.path.join(src_dir, src_file)
    sentences = load_creole_and_all(full_src_path, creole)
    #init out json
    out_json = []
    # Load the pretrained fasttext language-identification model.
    # NOTE(review): hard-coded local path — parameterize before sharing.
    pretrained_model_path = "/Users/plq360/Desktop/tmp/lid.176.bin"
    model = fasttext.load_model(pretrained_model_path)
    # Substrate languages considered relevant for each supported creole.
    creole_LUT = {"singlish": ["en", "zh", "ms", "ta"],
                  "haitian": ["fr", "yo", "es"],
                  "naija": ["en", "yo", "pt"]}
    sub_language_keys = creole_LUT[creole]
    #Now get the language predictions
    print(f"* predicting the languages in the examples ... ")
    predictions = model.predict(sentences, k=-1)  # k=-1: get ALL the predictions!
    langs, scores = predictions
    #Build the json dict
    for sent, lang_list, score in zip(sentences, langs, scores):
        sent_dict = {}
        # Map bare language code to its rank in this sentence's prediction
        # list (fasttext labels look like "__label__en").
        lang_LUT = {lang.split("__")[-1]: i for i, lang in enumerate(lang_list)}
        # Scores for the chosen substrate languages; 0.0 when fasttext did
        # not predict that language for this sentence at all.
        for lang in sub_language_keys:
            try:
                lang_score = float(score[lang_LUT[lang]])
            except KeyError:  # language absent from the prediction list
                lang_score = 0.0
            sent_dict[lang] = lang_score
            chosen_langs_distribution[lang].append(lang_score)
        # Record the five highest-ranked languages per sentence.
        # BUGFIX: the original wrapped this in a try/except with two
        # identical branches and appended the whole score array; append
        # the per-language score instead.
        for i, lang in enumerate(lang_list[:5]):
            code = lang.split("__")[-1]
            top_langs_distribution[code].append(float(score[i]))
        out_json.append({sent: sent_dict})  # [sent] = sent_dict
    #print(chosen_langs_distribution)
    print("###############################")
    print(top_langs_distribution.keys())
    # print(f"NUM EXAMPLES: {len(out_json)}")
    # #Print the json to a file
    # new_file = os.path.join(src_dir, out_file)
    # with open(new_file, 'w', encoding="utf-8") as o:
    #     json.dump(out_json, o, indent=1)
    #test(new_file, sentences)
if __name__ == "__main__":
    # Guarded entry point so importing this module does not trigger the
    # (expensive) fasttext prediction run. Swap the active call to process
    # a different creole.
    main("/Users/plq360/Desktop/data/creoledata/train/singlish", "singlish_and_all.train.json", "singlish", "singlish_only_groups.json")
    #main("/Users/plq360/Desktop/data/creoledata/train/naija", "naija_and_all.train.json", "naija","naija_only_groups.json")
    #main("/Users/plq360/Desktop/data/creoledata/train/haitian", "haitian_and_all.train.json", "haitian", "haitian_only_groups.json")