-
Notifications
You must be signed in to change notification settings - Fork 0
/
sentence_type_classifier.py
51 lines (40 loc) · 1.75 KB
/
sentence_type_classifier.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
from transformers import BertForSequenceClassification, AutoTokenizer
import torch
class SentenceTypeClassifier:
    """Lazily-initialized wrapper around a fine-tuned BERT sequence classifier.

    The device, tokenizer, and model are each created on first access, so
    constructing an instance is cheap. The Hugging Face auth token is read
    from a file on disk when the tokenizer/model are first loaded.
    """

    def __init__(self, pretrained_classifier: str, token_file: str):
        """
        :param pretrained_classifier: model id or path passed to ``from_pretrained``
        :param token_file: path to a file containing the Hugging Face auth token
        """
        self._pretrained_classifier = pretrained_classifier
        self._token_file = token_file
        # Lazy caches as instance attributes (the original used class-level
        # attributes, which risks state shared across instances).
        self._bert = None
        self._tokenizer = None
        self._device = None

    @property
    def device(self) -> torch.device:
        """Torch device used for inference; prefers CUDA when available."""
        if self._device is None:
            self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        return self._device

    @property
    def token(self) -> str:
        """Read and return the auth token from the token file."""
        # Context manager closes the handle deterministically; the original
        # leaked the open file to the garbage collector.
        with open(self._token_file, 'r') as fh:
            return fh.read()

    @property
    def tokenizer(self):
        """Lazily-loaded tokenizer for the pretrained classifier."""
        if self._tokenizer is None:
            self._tokenizer = AutoTokenizer.from_pretrained(
                self._pretrained_classifier, token=self.token
            )
        return self._tokenizer

    @property
    def bert(self) -> "BertForSequenceClassification":
        """Lazily-loaded model, moved to ``self.device`` and put in eval mode."""
        if self._bert is None:
            self._bert = BertForSequenceClassification.from_pretrained(
                self._pretrained_classifier, token=self.token
            ).to(self.device)
            self._bert.eval()
        return self._bert

    def _get_classification_inputs(self, message):
        """Tokenize *message*; return ``(input_ids, attention_mask)`` on ``self.device``."""
        tokens = self.tokenizer(
            message,
            return_tensors='pt',
            max_length=400,
            padding='max_length',
            truncation=True,
        )
        input_ids = tokens['input_ids'].to(self.device)
        attn_mask = tokens['attention_mask'].to(self.device)
        return input_ids, attn_mask

    def classify(self, message) -> torch.Tensor:
        """Return the predicted class-index tensor for *message*.

        Inference runs under ``torch.no_grad()`` — no autograd graph is
        needed, which saves memory without changing the predictions.
        """
        input_ids, attn_mask = self._get_classification_inputs(message)
        with torch.no_grad():
            outputs = self.bert(input_ids=input_ids, attention_mask=attn_mask)
        _, preds = torch.max(outputs.logits, dim=1)
        return preds