bm25_retriever.py
import logging
import json
import hydra
import hydra.utils as hu
import numpy as np
from tqdm import tqdm
import multiprocessing
from transformers import set_seed
from rank_bm25 import BM25Okapi
from omegaconf import DictConfig
from nltk.tokenize import word_tokenize
from src.dataset_readers.dataset_wrappers import get_dataset_wrapper
logger = logging.getLogger(__name__)
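
# Orientation note (added; my reading of the code below, not an original comment):
# this script builds a BM25 index over the "index" split of a task's dataset,
# retrieves the top-scoring examples for every query in the target split, and
# writes the queries back out with the retrieved in-context examples attached.
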
class BM25Finder:
    def __init__(self, cfg: DictConfig) -> None:
        self.output_file = cfg.output_file
        self.is_train = cfg.dataset_split == "train"
        self.num_candidates = cfg.num_candidates
        self.num_ice = cfg.num_ice

        self.index_dataset = hu.instantiate(cfg.index_reader).dataset_wrapper
        logger.info("task_name: %s", cfg.task_name)
        self.dataset_wrapper = get_dataset_wrapper(cfg.task_name,
                                                   dataset_path=cfg.index_reader.dataset_path,
                                                   dataset_split=cfg.dataset_split,
                                                   ds_size=cfg.ds_size)

        logger.info("started creating the corpus")
        self.index_corpus = [word_tokenize(i) for i in self.index_dataset.get_corpus(cfg.index_reader.field)]
        self.bm25 = BM25Okapi(self.index_corpus)
        self.tokenized_queries = [word_tokenize(i) for i in self.dataset_wrapper.get_corpus(cfg.query_field)]
        logger.info("finished creating the corpus")
def knn_search(tokenized_query, is_train, idx, num_candidates=1, num_ice=1):
    bm25 = bm25_global
    scores = bm25.get_scores(tokenized_query)
    near_ids = list(np.argsort(scores)[::-1][:max(num_candidates, num_ice)])
    near_ids = near_ids[1:] if is_train else near_ids
    near_ids = [int(a) for a in near_ids]
    return near_ids[:num_ice], [[i] for i in near_ids[:num_candidates]], idx
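
# Note (added for readability): for multi-example candidate sets, search keeps the
# plain BM25 top-num_ice list as the first candidate and fills the rest with distinct
# random num_ice-sized subsets, kept in score order, drawn from the top-100 BM25 hits.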
def search(tokenized_query, is_train, idx, num_candidates, num_ice):
    """For BM25, candidates beyond the top-k set are random subsets of the top hits."""
    if num_ice == 1 or num_candidates == 1:
        return knn_search(tokenized_query, is_train, idx,
                          num_candidates=num_candidates, num_ice=num_ice)

    candidates = knn_search(tokenized_query, is_train, idx, num_ice=100)[0]
    # add topk as one of the candidates
    ctxs_candidates = [candidates[:num_ice]]
    while len(ctxs_candidates) < num_candidates:
        # ordered by sim score
        samples_ids = np.random.choice(len(candidates), num_ice, replace=False)
        samples_ids = sorted(samples_ids)
        candidate = [candidates[i] for i in samples_ids]
        if candidate not in ctxs_candidates:
            ctxs_candidates.append(candidate)

    return ctxs_candidates[0], ctxs_candidates, idx
def _search(args):
    return search(*args)
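
# Note (added for readability): BM25 scoring is CPU-bound, so find fans the queries
# out over a multiprocessing.Pool. The BM25 index is handed to each worker once via
# the pool initializer (set_global_object) and read from the module-level
# bm25_global, instead of being shipped along with every query.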
def find(cfg):
    global bm25_global
    knn_finder = BM25Finder(cfg)
    bm25_global = knn_finder.bm25

    def set_global_object(bm25):
        global bm25_global
        bm25_global = bm25

    pool = multiprocessing.Pool(processes=16, initializer=set_global_object,
                                initargs=(knn_finder.bm25,))

    cntx_pre = [[tokenized_query, knn_finder.is_train, idx, knn_finder.num_candidates, knn_finder.num_ice]
                for idx, tokenized_query in enumerate(knn_finder.tokenized_queries)]

    data_list = list(knn_finder.dataset_wrapper.dataset)
    logger.info("number of queries to retrieve for: %d", len(data_list))

    cntx_post = []
    with tqdm(total=len(cntx_pre)) as pbar:
        for res in pool.imap_unordered(_search, cntx_pre):
            pbar.update()
            cntx_post.append(res)

    for ctxs, ctxs_candidates, idx in cntx_post:
        data_list[idx]['ctxs'] = ctxs
        data_list[idx]['ctxs_candidates'] = ctxs_candidates

    logger.info("saving retrieval results to %s", cfg.output_file)
    with open(cfg.output_file, "w") as f:
        json.dump(data_list, f)

    return data_list
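
# Note (added for readability): the hydra entry point below expects
# configs/bm25_retriever.yaml to define at least output_file, task_name,
# dataset_split, ds_size, num_ice, num_candidates, query_field, and an index_reader
# group with dataset_path and field, since those keys are read above.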
@hydra.main(config_path="configs", config_name="bm25_retriever")
def main(cfg):
    logger.info(cfg)
    set_seed(43)
    find(cfg)


if __name__ == "__main__":
    main()
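
# Example invocation, added for illustration. The override values are hypothetical;
# only the keys are taken from the config accesses above:
#
#   python bm25_retriever.py task_name=sst2 dataset_split=train \
#       output_file=output/sst2_bm25_train.json num_ice=8 num_candidates=30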