-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodels.py
163 lines (136 loc) · 5.22 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
import torch
from transformers import pipeline
class ZeroShotModels:
    """Registry and thin runtime wrapper for zero-shot classification models.

    Holds a static catalogue of Hugging Face model configurations used for
    binary "attack" vs. "normal" classification, plus helpers to look up
    entries, lazily initialise a ``transformers`` pipeline, and run it.

    Attributes:
        ZERO_SHOT (str): The ``transformers`` pipeline task name.
        ATTACK (str): Candidate label for malicious input.
        NORMAL (str): Candidate label for benign input.
        candidate_labels (list[str]): Labels passed to the classifier.
        models (list[dict]): Model configuration records. Each record has
            "model_name" (Hugging Face hub id), "suffix" (short local alias),
            "context_size" (int token budget) and "model" (pipeline instance,
            ``None`` until initialised).

    NOTE: ``candidate_labels`` and ``models`` are mutable *class-level*
    attributes, so they are shared by every instance of this class.
    """

    ZERO_SHOT = "zero-shot-classification"
    ATTACK = "attack"
    NORMAL = "normal"

    candidate_labels = [
        ATTACK,
        NORMAL,
    ]

    # Key order normalised across entries for readability; values unchanged.
    models = [
        {
            "model_name": "MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli",
            "suffix": "DeBERTa",
            "model": None,
            "context_size": 800,
        },
        {
            "model_name": "facebook/bart-large-mnli",
            "suffix": "facebook_zero",
            "model": None,
            "context_size": 800,
        },
        {
            "model_name": "mjwong/e5-large-v2-mnli-anli",
            "suffix": "glue-anli",
            "model": None,
            "context_size": 1600,
        },
        {
            "model_name": "niting3c/llama-2-7b-hf-zero-shot",
            "suffix": "llama-2-7b",
            "model": None,
            "context_size": 3500,
            "train": [],
        },
        {
            "model_name": "lmsys/vicuna-7b-v1.3",
            "suffix": "vicuna",
            "model": None,
            "context_size": 1600,
        },
    ]

    def get_models_by_suffix(self, suffix):
        """Return all model records whose suffix matches (case-insensitive).

        Args:
            suffix (str): The short alias to filter on, e.g. "DeBERTa".

        Returns:
            list[dict]: Matching configuration records (possibly empty).
        """
        wanted = suffix.lower()
        return [m for m in self.models if m["suffix"].lower() == wanted]

    def get_models_by_name(self, model_name):
        """Return all model records whose hub id matches (case-insensitive).

        Args:
            model_name (str): The Hugging Face model id to filter on.

        Returns:
            list[dict]: Matching configuration records (possibly empty).
        """
        wanted = model_name.lower()
        return [m for m in self.models if m["model_name"].lower() == wanted]

    def get_all_suffixes(self):
        """Return the unique suffixes in the catalogue (unordered).

        Returns:
            list[str]: Unique "suffix" values.
        """
        return list({m["suffix"] for m in self.models})

    def get_all_model_names(self):
        """Return the unique model ids in the catalogue (unordered).

        Returns:
            list[str]: Unique "model_name" values.
        """
        return list({m["model_name"] for m in self.models})

    def get_all_models(self):
        """Return the full catalogue of model configuration records.

        Returns:
            list[dict]: The shared ``models`` list (not a copy).
        """
        return self.models

    def initialise_model(self, hugging_face_model_name):
        """Build a zero-shot classification pipeline for the given model.

        Uses the GPU (device 0, fp16) when CUDA is available, otherwise the
        CPU (fp32).

        Args:
            hugging_face_model_name (str): Hugging Face hub model id.

        Returns:
            pipeline | None: The initialised pipeline, or ``None`` if any
            error occurred while loading (the error is printed, not raised).
        """
        try:
            print(f"Loading: {hugging_face_model_name}")
            print(f"GPU Being Used: {torch.cuda.is_available()}")
            use_gpu = torch.cuda.is_available()
            # "token" replaces the deprecated "use_auth_token" argument
            # (removed in transformers v5).
            return pipeline(task=self.ZERO_SHOT,
                            model=hugging_face_model_name,
                            token=True,
                            trust_remote_code=True,
                            device=0 if use_gpu else -1,
                            torch_dtype=torch.float16 if use_gpu else torch.float32)
        except Exception as e:
            # Best-effort loading: report and signal failure with None so
            # callers can skip this model rather than crash.
            print(f"Error initializing {hugging_face_model_name}: {e}")
            return None

    def classify(self, model, input_strings):
        """Classify input string(s) against ``candidate_labels``.

        Args:
            model (pipeline): An initialised pipeline (from
                :meth:`initialise_model`), or ``None``.
            input_strings (str or list[str]): Input(s) to classify.

        Returns:
            The pipeline's score object(s), or an empty dict when the model
            is ``None`` or classification fails (errors are printed).
        """
        if model is None:
            print("Model not initialized")
            return {}
        try:
            return model(input_strings, candidate_labels=self.candidate_labels)
        except Exception as e:
            print(f"Error generating response from classifier: {e}")
            return {}