Skip to content

Commit 22d5c8a

Browse files
authored
implement chi square test (#56)
1 parent 112161b commit 22d5c8a

File tree

3 files changed

+33
-5
lines changed

3 files changed

+33
-5
lines changed

src/tests/test_kinderminer.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,19 @@ def test_fisher_exact_test():
2424
sort_ratio = km.get_sort_ratio(table)
2525
assert sort_ratio == pytest.approx(15 / 59)
2626

27+
def test_chisq_pvalue():
    """Check km.chi_square p-values for three 2x2 contingency tables."""
    # ten observations in the top-left cell
    ten_in_first_cell = [[10, 3000], [2000, 10000000]]
    assert km.chi_square(ten_in_first_cell) == pytest.approx(2.583e-30, abs=1e-30)

    # a single observation in the top-left cell
    one_in_first_cell = [[1, 3000], [2000, 10000000]]
    assert km.chi_square(one_in_first_cell) == 1

    # the whole first column is zero
    zero_first_column = [[0, 100], [0, 10000000]]
    assert km.chi_square(zero_first_column) == 1
39+
2740
def test_text_sanitation():
2841
text = 'Testing123****.'
2942
sanitized_text = index.sanitize_term(text)

src/workers/kinderminer.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,24 @@ def get_contingency_table(a_term_set: set, b_term_set: set, total_n: int):
2020
def fisher_exact(table) -> float:
    """Return the p-value of Fisher's exact test on ``table``.

    The module-level ``fet_sided`` setting is forwarded to scipy as the
    second positional argument; the second element of scipy's result is
    returned.
    """
    test_result = scipy.stats.fisher_exact(table, fet_sided)
    return test_result[1]
2222

23+
def chi_square(table) -> float:
    """Return the p-value of a chi-square test of independence on ``table``.

    Args:
        table: contingency table of observed counts, e.g. a 2x2 nested list
            such as ``[[10, 3000], [2000, 10000000]]``.

    Returns:
        The p-value (second element of the result) from
        ``scipy.stats.chi2_contingency``, or 1.0 when scipy rejects the
        table with a ``ValueError``.
    """
    try:
        # Ask for the continuity correction explicitly. The second positional
        # parameter of chi2_contingency is the boolean `correction` flag, so
        # forwarding the Fisher-specific `fet_sided` setting there only turned
        # the correction on by truthiness (presumably a non-empty string --
        # confirm against its definition).
        return scipy.stats.chi2_contingency(table, correction=True)[1]
    except ValueError:
        # default to a p-value of 1.0
        # this happens if the sum of a row or column is 0
        return 1.0
30+
2331
def get_sort_ratio(table) -> float:
    """Return the top-left cell's share of the first-column total.

    Computes ``table[0][0] / (table[0][0] + table[1][0])``. When the first
    column sums to zero the ratio is undefined and 0 is returned instead.
    """
    top_left = table[0][0]
    first_column_total = top_left + table[1][0]
    # TODO?: the value to report for an all-zero first column is still open
    return top_left / first_column_total if first_column_total != 0 else 0
2937

30-
def kinderminer_search(a_term: str, b_term: str, idx: Index, censor_year = math.inf, return_pmids = False, top_n_articles = math.inf) -> dict:
38+
def kinderminer_search(a_term: str, b_term: str, idx: Index, censor_year = math.inf,
39+
return_pmids = False, top_n_articles = math.inf,
40+
scoring = 'fet') -> dict:
3141
""""""
3242
start_time = time.perf_counter()
3343
result = dict()
@@ -48,8 +58,12 @@ def kinderminer_search(a_term: str, b_term: str, idx: Index, censor_year = math.
4858
n_a_and_b = table[0][0]
4959
n_articles = idx.n_articles(censor_year)
5060

51-
# perform fisher's exact test
52-
pvalue = fisher_exact(table)
61+
# perform statistical test (default fisher's exact test)
62+
if scoring == 'chi-square':
63+
pvalue = chi_square(table)
64+
else: # 'fet'
65+
pvalue = fisher_exact(table)
66+
5367
sort_ratio = get_sort_ratio(table)
5468

5569
run_time = time.perf_counter() - start_time

src/workers/work.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ def km_work_all_vs_all(json: dict):
6767

6868
a_terms = json['a_terms']
6969
b_terms = json['b_terms']
70+
scoring = json.get('scoring', 'fet')
7071

7172
if 'c_terms' in json:
7273
# SKiM query
@@ -107,7 +108,7 @@ def km_work_all_vs_all(json: dict):
107108
b_term = li.the_index.get_highest_priority_term(b_term_set, b_term_token_dict)
108109
b_term_set.remove(b_term)
109110

110-
res = km.kinderminer_search(a_term, b_term, li.the_index, censor_year, return_pmids, top_n_articles)
111+
res = km.kinderminer_search(a_term, b_term, li.the_index, censor_year, return_pmids, top_n_articles, scoring)
111112

112113
if res['pvalue'] <= ab_fet_threshold:
113114
ab_results.append(res)
@@ -189,7 +190,7 @@ def km_work_all_vs_all(json: dict):
189190
if b_term == c_term:
190191
continue
191192

192-
bc = km.kinderminer_search(b_term, c_term, li.the_index, censor_year, return_pmids, top_n_articles)
193+
bc = km.kinderminer_search(b_term, c_term, li.the_index, censor_year, return_pmids, top_n_articles, scoring)
193194

194195
abc_result['c_term'] = c_term
195196
abc_result['bc_pvalue'] = bc['pvalue']

0 commit comments

Comments
 (0)