-
Notifications
You must be signed in to change notification settings - Fork 4
Dense_Retrieval
-
주어진 dataset의 context(passage)의 길이가 매우 길어서 기존 tokenizer를 그대로 사용하게 되면 뒷 부분의 내용이 잘려서 사용되게된다.
-
따라서 가능한 긴 내용을 사용하기 위해 max length를 1536으로 늘려서 사용하였다.
model_checkpoint = 'bert-base-multilingual-cased'
p_tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
p_tokenizer.model_max_length = 1536
Before
After
-
Passage tokenizer의 max length를 늘려주면서 tokenize된 input의 길이가 길어지게 된다. 이렇게 길어진 input을 model에 그대로 넣어주려하면 error가 나게 된다. 왜냐하면 dense retireval를 학습하기 위해 사용하는 pre-trained model의 position ids가 512로 사전 학습되었기 때문이다. 만약 position ids에 해당하는 embedding layer를 max length에 맞춰서 늘려주면 pre-trained model을 사용할 수 없게 된다.
-
pre-trained model의 max length는 512로 되어 있고, 이 길이에 맞춰서 training과 inference 방법을 바꿔주어야 했다.
-
따라서 training 시에는 passage의 길이가 512보다 긴 경우 전체 passage 중 random 하게 512만큼 선택하여 사용하였다.
-
그리고 inference 시에는 passage를 512의 window를 가지고 50% overlap하여 여러 passage로 잘라서 사용하였고, 만약 이 중 하나라도 question과의 similarity가 높게 나오면 찾아낸 것으로 판단하였다.
Training
class TrainRetrievalDataset(torch.utils.data.Dataset):
...
def _select_range(self, attention_mask):
sent_len = len([i for i in attention_mask if i != 0])
if sent_len <= 512:
return 1, 511
else:
start_idx = random.randint(1, sent_len-511)
end_idx = start_idx + 510
return start_idx, end_idx
Inference
class ValidRetrievalDataset(torch.utils.data.Dataset):
...
def _select_range(self, attention_mask):
sent_len = len([i for i in attention_mask if i != 0])
if sent_len <= 512:
return [(1,511)]
else:
num = sent_len // 255
res = sent_len % 255
if res == 0:
num -= 1
ids_list = []
for n in range(num):
if res > 0 and n == num-1:
end_idx = sent_len-1
start_idx = end_idx - 510
else:
start_idx = n*255+1
end_idx = start_idx + 510
ids_list.append((start_idx, end_idx))
return ids_list
-
multilingual bert, facebook/dpr, koelectra, xlm-roberta 중 multilingual bert 선택
-
facebook/dpr : 자소 단위 tokenizing을 사용하고 있어 context의 길이가 매우 길어지기 때문에 다른 모델들과 비교하여 동일한 token 수 대비 적은 정보를 가진다고 판단하였고, context의 앞부분 내용만으로 학습한 결과도 낮은 성능을 보였다.
Before
"이순신은 조선 중기의 무신이다."
After
['ᄋ', '# # ᅵ', '# # ᄉ', '# # ᅮ', '# # ᆫ', '# # ᄉ', '# # ᅵ', '# # ᆫ', '# # ᄋ', '# # ᅳ', '# # ᆫ', 'ᄌ', '# # ᅩ', '# # ᄉ', '# # ᅥ', '# # ᆫ', 'ᄌ', '# # ᅮ', '# # ᆼ', '# # ᄀ', '# # ᅵ', '# # ᄋ', '# # ᅴ', 'ᄆ', '# # ᅮ', '# # ᄉ', '# # ᅵ', '# # ᆫ', '# # ᄋ', '# # ᅵ', '# # ᄃ', '# # ᅡ', '.']
-
xlm-roberta : question과 context에 해당하는 모델을 각각 사용해야 한다. 하지만 xlm-roberta의 경우엔 모델이 굉장히 크기 때문에 메모리에 2개의 모델을 올릴 수가 없었다. 따라서 1개의 모델로 question과 context에 대한 embedding vector을 만드는 방식으로 성능을 비교하였다.
retrieval_model.py
from torch import nn
from transformers import AutoModel, AutoConfig
class BertPooler(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh()
def forward(self, hidden_states):
# We "pool" the model by simply taking the hidden state corresponding
# to the first token.
first_token_tensor = hidden_states[:, 0]
pooled_output = self.dense(first_token_tensor)
pooled_output = self.activation(pooled_output)
return pooled_output
class Encoder(nn.Module):
def __init__(self, model_checkpoint):
super(Encoder, self).__init__()
self.model_checkpoint = model_checkpoint
config = AutoConfig.from_pretrained(self.model_checkpoint)
if self.model_checkpoint == 'monologg/koelectra-base-v3-discriminator':
self.pooler = BertPooler(config)
config = AutoConfig.from_pretrained(self.model_checkpoint)
self.model = AutoModel.from_pretrained(self.model_checkpoint, config=config)
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None):
outputs = self.model(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids)
if self.model_checkpoint == 'monologg/koelectra-base-v3-discriminator':
sequence_output = outputs[0]
pooled_output = self.pooler(sequence_output)
else:
pooled_output = outputs[1]
return pooled_output
facebook/dpr
from transformers import (DPRContextEncoder,
DPRContextEncoderTokenizer,
DPRQuestionEncoder,
DPRQuestionEncoderTokenizer)
p_model = DPRContextEncoder.from_pretrained('facebook/dpr-ctx_encoder-single-nq-base')
q_model = DPRQuestionEncoder.from_pretrained('facebook/dpr-question_encoder-single-nq-base')
Top 1 accuracy
-
Elastic Search의 성능을 뛰어 넘는 dense retrieval를 만드는 것보다 elastic search의 성능을 더 개선하는 점이 구현 가능성이 높다고 판단하였다.
-
Elastic Search를 통해 score가 가장 높은 20개의 context를 고른 후 dense retrieval를 통해 20개의 context 중 좀 더 유사도가 높은 context를 찾는 방식
How to use Elastic Search : Elastic Search for Beginners
-
따라서 dataset의 1개의 question 마다 elastic search에서 찾은 20개의 context들을 pair로 묶어서 dataset을 다시 만들었다.
mk_retrieval_dataset.py
def mk_new_file(mode, files, top_k, es, index_name):
if mode == 'test':
new_files = {'id':[], 'question':[], 'top_k':[]}
for file in files:
question_text = file['question']
top_list = elastic_retrieval(es, index_name, question_text, top_k)
top_list = [text for text, score in top_list]
new_files['id'].append(file['id'])
new_files['question'].append(question_text)
new_files['top_k'].append(top_list)
return new_files
else:
new_files = {'context':[], 'id':[], 'question':[], 'top_k':[], 'answer_idx':[], 'answer':[], 'start_idx':[]}
for file in files:
start_ids = file["answers"]["answer_start"][0]
before = file["context"][:start_ids]
after = file["context"][start_ids:]
process_before = preprocess(before)
process_after = preprocess(after)
new_context = process_before + process_after
start_idx = start_ids - len(before) + len(process_before)
question_text = file['question']
top_list = elastic_retrieval(es, index_name, question_text, top_k)
top_list = [text for text, score in top_list]
if not new_context in top_list:
top_list = top_list[:-1] + [new_context]
answer_idx = top_k-1
else:
answer_idx = top_list.index(new_context)
answer = file['answers']['text'][0]
new_files['context'].append(new_context)
new_files['id'].append(file['id'])
new_files['question'].append(question_text)
new_files['top_k'].append(top_list)
new_files['answer_idx'].append(answer_idx)
new_files['answer'].append(answer)
new_files['start_idx'].append(start_idx)
return new_files
def main(args):
train_file = load_from_disk("../data/train_dataset")["train"]
validation_file = load_from_disk("../data/train_dataset")["validation"]
test_file = load_from_disk("../data/test_dataset")["validation"]
es = elastic_setting(args.index_name)
print('wait...', end='\r')
new_train_file = mk_new_file('train', train_file, args.top_k, es, args.index_name)
print('make train dataset!!')
save_pickle(os.path.join(args.save_path, f'Top{args.top_k}_preprocess_train.pkl'), new_train_file)
print('wait...', end='\r')
new_valid_file = mk_new_file('valid', validation_file, args.top_k, es, args.index_name)
print('make validation dataset!!')
save_pickle(os.path.join(args.save_path, f'Top{args.top_k}_preprocess_valid.pkl'), new_valid_file)
print('wait...', end='\r')
new_test_file = mk_new_file('test', test_file, args.top_k, es, args.index_name)
print('make test dataset!!')
save_pickle(os.path.join(args.save_path, f'Top{args.top_k}_preprocess_test.pkl'), new_test_file)
print('complete!!')
-
Elastic search의 top 1 accuracy가 70% 이상으로 성능이 매우 좋았다. 그로 인해 20개의 context를 순서대로 list로 만들어주게 되면 대부분의 question이 0번 index의 context를 정답 context로 가지게 되었다.
-
따라서 inference 시 top 1 context에 해당하는 index가 좀 더 높은 확률을 가질 수 있도록 weight를 추가해주었다.
class_0 = torch.Tensor([1 if i.item() == 0 else 0 for idx, i in enumerate(top_k_id)])
w = (torch.sum(sim_scores, dim=1)*1/sim_scores.size()[1]).item()
sim_scores -= w*class_0.unsqueeze(0).cuda()