-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathhelper_code.py
114 lines (89 loc) · 3.56 KB
/
helper_code.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import numpy as np
import string
import pickle as pkl
import pandas as pd
def plot_confusion_matrix(y_true, y_pred, classes,
                          normalize=False,
                          title=None,
                          cmap=plt.cm.Blues):
    """
    Plot the confusion matrix of `y_pred` against `y_true` and show it.

    Parameters
    ----------
    y_true, y_pred :
        True and predicted labels, passed unchanged to
        `sklearn.metrics.confusion_matrix`.
    classes :
        Labels to include; also used as the tick labels on both axes.
    normalize : bool
        If True, every row is divided by its sum so each cell shows the
        fraction of that true class; otherwise raw counts are shown.
    title : str or None
        Figure title. Defaults to 'Confusion Matrix' when None.
    cmap :
        Colormap handed to `imshow`.

    Returns
    -------
    None. The figure is displayed with `plt.show()`.
    """
    if title is None:
        title = 'Confusion Matrix'

    # Compute confusion matrix
    cm = confusion_matrix(y_true, y_pred, labels=classes)
    if normalize:
        row_sums = cm.sum(axis=1, keepdims=True)
        # Rows without any true samples stay all-zero instead of becoming NaN.
        cm = np.divide(cm, row_sums,
                       out=np.zeros(cm.shape, dtype=float),
                       where=row_sums != 0)

    fig, ax = plt.subplots(figsize=(5, 5))
    im = ax.imshow(cm, interpolation='nearest', cmap=cmap)
    ax.figure.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

    # Show all ticks and label them with the respective class entries.
    ax.set(xticks=np.arange(cm.shape[1]),
           yticks=np.arange(cm.shape[0]),
           xticklabels=classes, yticklabels=classes,
           title=title,
           ylabel='True label',
           xlabel='Predicted label')
    plt.setp(ax.get_xticklabels(), rotation=0)

    # Annotate every cell; use white text on dark cells for readability.
    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            ax.text(j, i, format(cm[i, j], fmt),
                    ha="center", va="center",
                    color="white" if cm[i, j] > thresh else "black")

    # Positional argument works whether the keyword is named `b` or `visible`.
    ax.grid(False)
    # Pin the y-limits to the full cell extent explicitly. Some matplotlib
    # releases cut the top/bottom rows in half
    # (https://github.com/mwaskom/seaborn/issues/1773); explicit limits are
    # correct with or without that bug.
    ax.set_ylim(cm.shape[0] - 0.5, -0.5)
    fig.tight_layout()
    plt.show()
def do_law_of_zipf(data):
    '''
    Plot word frequency against word rank on log-log axes, one curve per
    language, to visualize Zipf's Law.

    data: dictionary whose keys are languages and whose values are lists of
    sentences (strings). Each sentence is split on whitespace into words.

    The curves are drawn with pyplot on the current figure; this function
    does not call plt.show() and returns None.
    '''
    for language, sentences in data.items():
        words = [word for sentence in sentences for word in sentence.split()]
        # value_counts() sorts by count, descending, so position i holds the
        # frequency of the word with rank i + 1.
        freqs = pd.Series(words, dtype=object).value_counts().values
        ranks = range(1, len(freqs) + 1)
        # x = rank, y = frequency, matching the axis labels below.
        plt.plot(ranks, freqs, label=language)
    plt.ylabel('Frequency')
    plt.xlabel('Rank')
    plt.yscale('log')
    plt.xscale('log')
    plt.title('Zipf\'s Law')
    plt.legend()
def split_into_subwords(text, merges=None):
    '''
    Break every whitespace-separated word of `text` into known subwords.

    text: string to split.
    merges: optional iterable of subword strings, tried in iteration order.
        When None, it is unpickled from 'Data/Auxiliary/merge_ordered.pkl'
        (presumably an ordered collection of subword strings - verify
        against the script that writes that file).

    For each word, every subword that occurs in it is emitted once per
    occurrence and its occurrences are replaced by a space, so the consumed
    characters cannot be matched again by later subwords. Within a word the
    output follows the order of `merges`, not the position in the word.

    Returns the collected subwords joined by single spaces.
    '''
    if merges is None:
        # NOTE: pickle can execute arbitrary code while loading; only use a
        # trusted merge file here.
        with open('Data/Auxiliary/merge_ordered.pkl', 'rb') as merge_file:
            merges = pkl.load(merge_file)
    else:
        # Materialize once so a generator survives the per-word loop below.
        merges = list(merges)

    subwords = []
    for word in text.split():
        for subword in merges:
            subword_count = word.count(subword)
            if subword_count > 0:
                word = word.replace(subword, ' ')
                subwords.extend([subword] * subword_count)
    return ' '.join(subwords)
# Translation table for preprocess_function: newline and hyphen each become a
# space; every other punctuation character and every digit is deleted.
_PREPROCESS_TABLE = str.maketrans(
    '\n-', '  ',
    ''.join(ch for ch in string.punctuation if ch != '-') + string.digits,
)


def preprocess_function(text):
    '''
    Normalize a string for word-level processing.

    Newlines and hyphens are replaced by a space, all other ASCII
    punctuation and all ASCII digits are removed, and the result is
    converted to lowercase.
    '''
    return text.translate(_PREPROCESS_TABLE).lower()