dataloader.py
# coding: utf-8
from pathlib import Path
import re

import torch
from torch.utils.data import TensorDataset, RandomSampler, DataLoader
from tqdm import tqdm
from transformers import BertTokenizer

# Matches any run of characters that is NOT a CJK Unified Ideograph
# (\u4e00-\u9fa5), a digit, or an ASCII letter; such runs are stripped.
pattern = u'[^\u4e00-\u9fa50-9a-zA-Z]+'


class DataProcessor:
    def __init__(self, vocab_path: Path, max_length: int) -> None:
        self.tokenizer = BertTokenizer.from_pretrained(vocab_path)
        self.max_length = max_length

    def readLCQMC(self, path: Path):
        """Read an LCQMC-style TSV file (sentence1 <TAB> sentence2 <TAB> label)
        and encode each sentence pair with the BERT tokenizer."""
        res = {'encodes': [], 'labels': []}
        with open(path, 'r', encoding='utf-8') as data:
            for line in tqdm(data, desc='Loading data'):
                s1, s2, label = line.strip().split('\t')
                # Drop punctuation, whitespace, and other non-target characters.
                s1 = re.sub(pattern, '', s1)
                s2 = re.sub(pattern, '', s2)
                encodes = self.tokenizer.encode_plus(
                    s1,
                    text_pair=s2,
                    max_length=self.max_length,
                    padding='max_length',
                    truncation='longest_first',
                    return_tensors='pt'
                )
                res['encodes'].append(encodes)
                res['labels'].append(int(label))
        # Concatenate the per-example (1, max_length) tensors into
        # (num_examples, max_length) tensors for the TensorDataset.
        input_ids = torch.cat([item['input_ids'] for item in res['encodes']])
        attention_mask = torch.cat([item['attention_mask'] for item in res['encodes']])
        token_type_ids = torch.cat([item['token_type_ids'] for item in res['encodes']])
        labels = torch.tensor(res['labels'], dtype=torch.long)
        return TensorDataset(input_ids, attention_mask, token_type_ids, labels)


class MyDataLoader:
    def __init__(self, vocab_path: Path, max_length: int) -> None:
        self.data_processor = DataProcessor(vocab_path, max_length)

    def load(self, path: Path, batch_size: int):
        data = self.data_processor.readLCQMC(path)
        sampler = RandomSampler(data)
        data_loader = DataLoader(data, sampler=sampler, batch_size=batch_size)
        return data_loader
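

# A minimal usage sketch: drive MyDataLoader end to end and iterate one batch.
# The tokenizer name ('bert-base-chinese') and data path ('data/LCQMC/train.txt')
# are placeholder assumptions, not part of the module above; substitute your own
# vocab directory and LCQMC split.
if __name__ == '__main__':
    loader = MyDataLoader(vocab_path='bert-base-chinese', max_length=64)
    train_loader = loader.load('data/LCQMC/train.txt', batch_size=32)
    for input_ids, attention_mask, token_type_ids, labels in train_loader:
        # Each batch yields (batch_size, max_length) tensors plus a label vector.
        print(input_ids.shape, attention_mask.shape, token_type_ids.shape, labels.shape)
        break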