multiDataset.py
# -*- coding: utf-8 -*-
import json
import torch
from PIL import Image
from torch.utils.data import Dataset
from transformers import BertTokenizer, ViTFeatureExtractor
import config
from img import vit
from text import bert, textConfig
config.setup_seed()

# Sentiment label-to-id mapping
tags = {
    'positive': 0,
    'negative': 1,
    'neutral': 2,
    '': 3  # placeholder only
}


class MultiDataset(Dataset):
    """Dataset pairing text (BERT inputs) with images (ViT features) and sentiment tags."""

    def __init__(self, data: list, tokenizer: BertTokenizer, extractor: ViTFeatureExtractor, maxLen: int):
        self.data = data
        self.tokenizer = tokenizer
        self.extractor = extractor
        self.maxLen = maxLen

    def __len__(self):
        return len(self.data)

    def __getitem__(self, item: int):
        guid = self.data[item]['guid']
        text = self.data[item]['text']
        img = self.data[item]['img']
        tag = self.data[item]['tag']
        tag = torch.tensor(tags[tag], dtype=torch.long)
        # Tokenize the text, padding/truncating to a fixed length
        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            truncation=True,
            max_length=self.maxLen,
            padding='max_length',
            return_token_type_ids=True,
            return_attention_mask=True,
            return_tensors='pt'
        )
        # Load the image from disk and extract ViT pixel values
        img = self.extractor(
            images=Image.open(config.raw_data_path + img),
            return_tensors='pt'
        )
        return {
            'guid': guid,
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'token_type_ids': encoding['token_type_ids'].flatten(),
            'img': img,
            'tag': tag
        }


def getMultiDataset(path: str):
    """Load a JSON list of samples from `path` and wrap it in a MultiDataset."""
    with open(path, 'r', encoding='utf-8') as fs:
        data = json.load(fs)
    tokenizer = bert.getTokenizer()
    extractor = vit.getExtractor()
    return MultiDataset(data, tokenizer, extractor, textConfig.max_len)
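

# --- Usage sketch (not part of the original file) ---
# A minimal example of how getMultiDataset might be combined with a PyTorch
# DataLoader. The JSON path ('data/train.json') and batch size are assumptions;
# config, bert, and vit are this project's own modules, so the real entry point
# may differ.
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    dataset = getMultiDataset('data/train.json')  # hypothetical path
    loader = DataLoader(dataset, batch_size=8, shuffle=True)
    batch = next(iter(loader))
    # input_ids / attention_mask / token_type_ids are (batch, max_len) tensors;
    # 'img' holds the ViT extractor output and 'tag' the sentiment label ids.
    print(batch['input_ids'].shape, batch['tag'])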