predictors.py
from abc import ABC, abstractmethod
from typing import List, Dict, Callable
from liquid import Template
from vllm import LLM, SamplingParams
import torch
from vllm.model_executor.parallel_utils.parallel_state import destroy_model_parallel
import utils
import tasks
import os
import gc
######
# GlobalVars
######
file_path = os.path.abspath(os.path.dirname(__file__))


class VLLMPredictor(ABC):
    """Lazily holds two vLLM engines: `llm` for classification (inference) and
    `llm_syn` for synthesis/paraphrasing (eval_multiple). Only one engine is
    kept in GPU memory at a time."""

    def __init__(self, opt):
        self.opt = opt
        self.llm = None
        self.llm_syn = None

    def eval_multiple(self, prompt, n=1):
        if self.llm_syn is None:
            # This vLLM function resets the parallel-state globals, which allows a new model to be initialized.
            destroy_model_parallel()
            # If you hit a CUDA OOM error, delete the previous engine and any leftover queued operations.
            del self.llm
            self.llm = None
            torch.cuda.synchronize()
            gc.collect()
            self.llm_syn = LLM(
                model=self.opt["model"] if self.opt["paraphraser"] is None else self.opt["paraphraser"],
                download_dir=file_path + "/../cache",
                gpu_memory_utilization=.8,
                dtype="half",
                seed=420,
            )
        outputs = self.llm_syn.generate([prompt for i in range(n)], SamplingParams(temperature=0.8, max_tokens=1024))
        return [output.outputs[0].text for output in outputs]

    def inference(self, ex, prompt):
        if self.llm is None:
            # This vLLM function resets the parallel-state globals, which allows a new model to be initialized.
            destroy_model_parallel()
            # If you hit a CUDA OOM error, delete the previous engine and any leftover queued operations.
            del self.llm_syn
            self.llm_syn = None
            torch.cuda.synchronize()
            gc.collect()
            self.llm = LLM(
                model=self.opt["model"],
                download_dir=file_path + "/../cache",
                gpu_memory_utilization=.8,
                dtype="half",
                seed=420,
            )
        prompt = Template(prompt).render(text=ex['text'])
        res = self.llm.generate([prompt], SamplingParams(temperature=self.opt["temperature"], max_tokens=self.opt["max_tokens"]))
        response = res[0].outputs[0].text
        pred = 1 if response.strip().upper().startswith('REFUTES') else 0
        return pred
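

# --- Usage sketch (illustrative addition, not part of the original file) ---
# A minimal example of how VLLMPredictor might be driven. The opt keys shown
# ("model", "paraphraser", "temperature", "max_tokens") are the ones the class
# reads above; the concrete values, the model name, the prompt wording, and
# the example record are assumptions for demonstration only.
if __name__ == "__main__":
    opt = {
        "model": "meta-llama/Llama-2-7b-chat-hf",  # hypothetical model choice
        "paraphraser": None,                       # eval_multiple then falls back to "model"
        "temperature": 0.0,
        "max_tokens": 16,
    }
    # VLLMPredictor declares no abstract methods, so it can be instantiated directly.
    predictor = VLLMPredictor(opt)

    # Liquid template: {{ text }} is filled with the example's claim text inside inference().
    prompt = "Claim: {{ text }}\nAnswer with SUPPORTS or REFUTES.\nAnswer:"
    ex = {"text": "The Eiffel Tower is located in Berlin."}

    # Returns 1 if the model's reply starts with 'REFUTES', else 0.
    print(predictor.inference(ex, prompt))

    # Generate three completions of the already-rendered prompt with the synthesis engine.
    rendered = Template(prompt).render(text=ex["text"])
    print(predictor.eval_multiple(rendered, n=3))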