Skip to content

Dense_Retrieval

KwonTaeYang edited this page May 27, 2021 · 1 revision

1️⃣ Passage Tokenizer - Max Length

1. Why

  • 주어진 dataset의 context(passage)의 길이가 매우 길어서 기존 tokenizer를 그대로 사용하게 되면 뒷 부분의 내용이 잘려서 사용되게된다.

  • 따라서 가능한 긴 내용을 사용하기 위해 max length를 1536으로 늘려서 사용하였다.

2. code

model_checkpoint = 'bert-base-multilingual-cased'
p_tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
p_tokenizer.model_max_length = 1536

3. result

Before

After

2️⃣ Training & Infernece

1. Why

  • Passage tokenizer의 max length를 늘려주면서 tokenize된 input의 길이가 길어지게 된다. 이렇게 길어진 input을 model에 그대로 넣어주려하면 error가 나게 된다. 왜냐하면 dense retireval를 학습하기 위해 사용하는 pre-trained model의 position ids가 512로 사전 학습되었기 때문이다. 만약 position ids에 해당하는 embedding layer를 max length에 맞춰서 늘려주면 pre-trained model을 사용할 수 없게 된다.

  • pre-trained model의 max length는 512로 되어 있고, 이 길이에 맞춰서 training과 inference 방법을 바꿔주어야 했다.

  • 따라서 training 시에는 passage의 길이가 512보다 긴 경우 전체 passage 중 random 하게 512만큼 선택하여 사용하였다.

  • 그리고 inference 시에는 passage를 512의 window를 가지고 50% overlap하여 여러 passage로 잘라서 사용하였고, 만약 이 중 하나라도 question과의 similarity가 높게 나오면 찾아낸 것으로 판단하였다.

2. Code

Training

class TrainRetrievalDataset(torch.utils.data.Dataset):
    ...
    def _select_range(self, attention_mask):
        sent_len = len([i for i in attention_mask if i != 0])
        if sent_len <= 512:
            return 1, 511
        else:
            start_idx = random.randint(1, sent_len-511)
            end_idx = start_idx + 510
            return start_idx, end_idx

Inference

class ValidRetrievalDataset(torch.utils.data.Dataset):
    ...
    def _select_range(self, attention_mask):
        sent_len = len([i for i in attention_mask if i != 0])
        if sent_len <= 512:
            return [(1,511)]
        else:
            num = sent_len // 255
            res = sent_len % 255
            if res == 0:
                num -= 1
            ids_list = []
            for n in range(num):
                if res > 0 and n == num-1:
                    end_idx = sent_len-1
                    start_idx = end_idx - 510
                else:
                    start_idx = n*255+1
                    end_idx = start_idx + 510
                ids_list.append((start_idx, end_idx))
            return ids_list

3️⃣ Model

1. Why

  • multilingual bert, facebook/dpr, koelectra, xlm-roberta 중 multilingual bert 선택

  • facebook/dpr : 자소 단위 tokenizing을 사용하고 있어 context의 길이가 매우 길어지기 때문에 다른 모델들과 비교하여 동일한 token 수 대비 적은 정보를 가진다고 판단하였고, context의 앞부분 내용만으로 학습한 결과도 낮은 성능을 보였다.

    Before

    "이순신은 조선 중기의 무신이다."

    After

    ['ᄋ', '# # ᅵ', '# # ᄉ', '# # ᅮ', '# # ᆫ', '# # ᄉ', '# # ᅵ', '# # ᆫ', '# # ᄋ', '# # ᅳ', '# # ᆫ', 'ᄌ', '# # ᅩ', '# # ᄉ', '# # ᅥ', '# # ᆫ', 'ᄌ', '# # ᅮ', '# # ᆼ', '# # ᄀ', '# # ᅵ', '# # ᄋ', '# # ᅴ', 'ᄆ', '# # ᅮ', '# # ᄉ', '# # ᅵ', '# # ᆫ', '# # ᄋ', '# # ᅵ', '# # ᄃ', '# # ᅡ', '.']

  • xlm-roberta : question과 context에 해당하는 모델을 각각 사용해야 한다. 하지만 xlm-roberta의 경우엔 모델이 굉장히 크기 때문에 메모리에 2개의 모델을 올릴 수가 없었다. 따라서 1개의 모델로 question과 context에 대한 embedding vector을 만드는 방식으로 성능을 비교하였다.

2. code

retrieval_model.py

from torch import nn
from transformers import AutoModel, AutoConfig

class BertPooler(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output

class Encoder(nn.Module):
    def __init__(self, model_checkpoint):
        super(Encoder, self).__init__()
        self.model_checkpoint = model_checkpoint
        config = AutoConfig.from_pretrained(self.model_checkpoint)
        
        if self.model_checkpoint == 'monologg/koelectra-base-v3-discriminator':
            self.pooler = BertPooler(config)
        config = AutoConfig.from_pretrained(self.model_checkpoint)
        self.model = AutoModel.from_pretrained(self.model_checkpoint, config=config)
    
    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None):
        outputs = self.model(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids)
        if self.model_checkpoint == 'monologg/koelectra-base-v3-discriminator':
            sequence_output = outputs[0]
            pooled_output = self.pooler(sequence_output)
        else:
            pooled_output = outputs[1]
        return pooled_output

facebook/dpr

from transformers import (DPRContextEncoder,
                          DPRContextEncoderTokenizer,
                          DPRQuestionEncoder,
                          DPRQuestionEncoderTokenizer)

p_model = DPRContextEncoder.from_pretrained('facebook/dpr-ctx_encoder-single-nq-base')
q_model = DPRQuestionEncoder.from_pretrained('facebook/dpr-question_encoder-single-nq-base')

3. result

Top 1 accuracy

4️⃣ Elastic search + Dense Retrieval

1. why

  • Elastic Search의 성능을 뛰어 넘는 dense retrieval를 만드는 것보다 elastic search의 성능을 더 개선하는 점이 구현 가능성이 높다고 판단하였다.

  • Elastic Search를 통해 score가 가장 높은 20개의 context를 고른 후 dense retrieval를 통해 20개의 context 중 좀 더 유사도가 높은 context를 찾는 방식

    How to use Elastic Search : Elastic Search for Beginners

  • 따라서 dataset의 1개의 question 마다 elastic search에서 찾은 20개의 context들을 pair로 묶어서 dataset을 다시 만들었다.

2. code

mk_retrieval_dataset.py

def mk_new_file(mode, files, top_k, es, index_name):
    if mode == 'test':
        new_files = {'id':[], 'question':[], 'top_k':[]}
        for file in files:
            question_text = file['question']
            
            top_list = elastic_retrieval(es, index_name, question_text, top_k)
            top_list = [text for text, score in top_list]

            new_files['id'].append(file['id'])
            new_files['question'].append(question_text)
            new_files['top_k'].append(top_list)
        return new_files
    
    else:
        new_files = {'context':[], 'id':[], 'question':[], 'top_k':[], 'answer_idx':[], 'answer':[], 'start_idx':[]}
        for file in files:
            start_ids = file["answers"]["answer_start"][0]
            
            before = file["context"][:start_ids]
            after = file["context"][start_ids:]
            
            process_before = preprocess(before)
            process_after = preprocess(after)
            new_context = process_before + process_after
            
            start_idx = start_ids - len(before) + len(process_before)

            question_text = file['question']
            top_list = elastic_retrieval(es, index_name, question_text, top_k)
            top_list = [text for text, score in top_list]
            
            if not new_context in top_list:
                top_list = top_list[:-1] + [new_context]
                answer_idx = top_k-1
            else:
                answer_idx = top_list.index(new_context)

            answer = file['answers']['text'][0]

            new_files['context'].append(new_context)
            new_files['id'].append(file['id'])
            new_files['question'].append(question_text)
            new_files['top_k'].append(top_list)
            new_files['answer_idx'].append(answer_idx)
            new_files['answer'].append(answer)
            new_files['start_idx'].append(start_idx)
        return new_files
def main(args):
    train_file = load_from_disk("../data/train_dataset")["train"]
    validation_file = load_from_disk("../data/train_dataset")["validation"]
    test_file = load_from_disk("../data/test_dataset")["validation"]
    
    es = elastic_setting(args.index_name)

    print('wait...', end='\r')
    new_train_file =  mk_new_file('train', train_file, args.top_k, es, args.index_name)
    print('make train dataset!!')
    save_pickle(os.path.join(args.save_path, f'Top{args.top_k}_preprocess_train.pkl'), new_train_file)
    
    print('wait...', end='\r')
    new_valid_file =  mk_new_file('valid', validation_file, args.top_k, es, args.index_name)
    print('make validation dataset!!')
    save_pickle(os.path.join(args.save_path, f'Top{args.top_k}_preprocess_valid.pkl'), new_valid_file)
    
    print('wait...', end='\r')
    new_test_file =  mk_new_file('test', test_file, args.top_k, es, args.index_name)
    print('make test dataset!!')    
    save_pickle(os.path.join(args.save_path, f'Top{args.top_k}_preprocess_test.pkl'), new_test_file)
    
    print('complete!!')

5️⃣ Add weight

1. why

  • Elastic search의 top 1 accuracy가 70% 이상으로 성능이 매우 좋았다. 그로 인해 20개의 context를 순서대로 list로 만들어주게 되면 대부분의 question이 0번 index의 context를 정답 context로 가지게 되었다.

  • 따라서 inference 시 top 1 context에 해당하는 index가 좀 더 높은 확률을 가질 수 있도록 weight를 추가해주었다.

2. code

class_0 = torch.Tensor([1 if i.item() == 0 else 0 for idx, i in enumerate(top_k_id)])
w = (torch.sum(sim_scores, dim=1)*1/sim_scores.size()[1]).item()
sim_scores -= w*class_0.unsqueeze(0).cuda()

6️⃣ Retrieval Result