LLM.py
""" This file contains the code for calling all LLM APIs. """
import os
from functools import partial
import tiktoken
# from schema import TooLongPromptError, LLMError
enc = tiktoken.get_encoding("cl100k_base")
try:
    from helm.common.authentication import Authentication
    from helm.common.request import Request, RequestResult
    from helm.proxy.accounts import Account
    from helm.proxy.services.remote_service import RemoteService

    # set up the CRFM (HELM proxy) API
    auth = Authentication(api_key=open("crfm_api_key.txt").read().strip())
    service = RemoteService("https://crfm-models.stanford.edu")
    account: Account = service.get_account(auth)
except Exception as e:
    print(e)
    print("Could not load CRFM API key crfm_api_key.txt.")
try:
    import anthropic

    # set up the Anthropic API client
    anthropic_client = anthropic.Anthropic(api_key=open("claude_api_key.txt").read().strip())
except Exception as e:
    print(e)
    print("Could not load Anthropic API key claude_api_key.txt.")
try:
    import openai
    from openai import OpenAI

    # the key file stores "organization:api_key"
    organization, api_key = open("openai_api_key.txt").read().strip().split(":")
    os.environ["OPENAI_API_KEY"] = api_key
    client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
except Exception as e:
    print(e)
    print("Could not load OpenAI API key openai_api_key.txt.")
def log_to_file(log_file, prompt, completion, model, max_tokens_to_sample):
    """ Log the prompt and completion to a file."""
    with open(log_file, "a") as f:
        f.write("\n===================prompt=====================\n")
        f.write(f"{anthropic.HUMAN_PROMPT} {prompt} {anthropic.AI_PROMPT}")
        num_prompt_tokens = len(enc.encode(f"{anthropic.HUMAN_PROMPT} {prompt} {anthropic.AI_PROMPT}"))
        f.write(f"\n==================={model} response ({max_tokens_to_sample})=====================\n")
        f.write(completion)
        num_sample_tokens = len(enc.encode(completion))
        f.write("\n===================tokens=====================\n")
        f.write(f"Number of prompt tokens: {num_prompt_tokens}\n")
        f.write(f"Number of sampled tokens: {num_sample_tokens}\n")
        f.write("\n\n")
def complete_text_claude(prompt, stop_sequences=[anthropic.HUMAN_PROMPT], model="claude-v1", max_tokens_to_sample=2000, temperature=0.5, log_file=None, **kwargs):
    """ Call the Claude API to complete a prompt."""
    ai_prompt = anthropic.AI_PROMPT
    if "ai_prompt" in kwargs:
        ai_prompt = kwargs["ai_prompt"]
        del kwargs["ai_prompt"]
    if model.startswith("claude-3"):
        # claude-3 models use the Messages API instead of the legacy
        # Completions API (stop_sequences and temperature are not passed here)
        messages = [
            {'role': 'user', 'content': f"{anthropic.HUMAN_PROMPT} {prompt}"}
        ]
        rsp = anthropic_client.messages.create(
            model=model,
            messages=messages,
            max_tokens=max_tokens_to_sample
        )
        completion = rsp.content[0].text
        if log_file is not None:
            log_to_file(log_file, prompt, completion, model, max_tokens_to_sample)
        return completion
    try:
        rsp = anthropic_client.completions.create(
            prompt=f"{anthropic.HUMAN_PROMPT} {prompt} {ai_prompt}",
            stop_sequences=stop_sequences,
            model=model,
            temperature=temperature,
            max_tokens_to_sample=max_tokens_to_sample,
            **kwargs
        )
    except anthropic.APIStatusError as e:
        print(e)
        raise TooLongPromptError()
    except Exception as e:
        raise LLMError(e)
    completion = rsp.completion
    if log_file is not None:
        log_to_file(log_file, prompt, completion, model, max_tokens_to_sample)
    return completion
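# A minimal usage sketch for complete_text_claude (not part of the original
# file). It assumes a valid claude_api_key.txt was loaded above; the prompt,
# log path, and token budget are illustrative placeholders.
def _demo_claude():
    completion = complete_text_claude(
        "Summarize the goal of this repository in one sentence.",
        model="claude-v1",
        max_tokens_to_sample=256,
        temperature=0.5,
        log_file="demo_claude.log",
    )
    print(completion)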
def get_embedding_crfm(text, model="openai/gpt-4-0314"):
    """ Get an embedding from the CRFM API; the model argument is currently unused. """
    request = Request(model="openai/text-similarity-ada-001", prompt=text, embedding=True)
    request_result: RequestResult = service.make_request(auth, request)
    return request_result.embedding
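# A minimal usage sketch for get_embedding_crfm (not part of the original
# file). It assumes the CRFM authentication above succeeded; the input text
# is an illustrative placeholder.
def _demo_embedding():
    vector = get_embedding_crfm("gradient descent")
    print(f"embedding dimension: {len(vector)}")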
def complete_text_crfm(prompt=None, stop_sequences=None, model="openai/gpt-4-0314", max_tokens_to_sample=2000, temperature=0.5, log_file=None, messages=None, **kwargs):
    """ Call the CRFM (HELM proxy) API to complete a prompt."""
    # reuse the log file path as the request's random field, so identical
    # prompts are not served from the proxy's cache
    random = log_file
    if messages:
        request = Request(
            prompt=prompt,
            messages=messages,
            model=model,
            stop_sequences=stop_sequences,
            temperature=temperature,
            max_tokens=max_tokens_to_sample,
            random=random
        )
    else:
        request = Request(
            prompt=prompt,
            model=model,
            stop_sequences=stop_sequences,
            temperature=temperature,
            max_tokens=max_tokens_to_sample,
            random=random
        )
    try:
        request_result: RequestResult = service.make_request(auth, request)
    except Exception as e:
        # most likely the prompt is too long
        print(e)
        raise TooLongPromptError()
    if not request_result.success:
        print(request_result.error)
        raise LLMError(request_result.error)
    completion = request_result.completions[0].text
    if log_file is not None:
        log_to_file(log_file, prompt, completion, model, max_tokens_to_sample)
    return completion
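# A minimal usage sketch for complete_text_crfm (not part of the original
# file). It assumes the CRFM authentication above succeeded; the model name
# follows the "organization/model" convention used by the HELM proxy, and
# the prompt is an illustrative placeholder.
def _demo_crfm():
    completion = complete_text_crfm(
        prompt="List three common evaluation metrics for classification.",
        stop_sequences=["Observation:"],
        model="openai/gpt-4-0314",
        max_tokens_to_sample=256,
    )
    print(completion)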
def complete_text_openai(prompt, stop_sequences=[], model="gpt-3.5-turbo", max_tokens_to_sample=2000, temperature=0.5, log_file=None, **kwargs):
    """ Call the OpenAI API to complete a prompt."""
    raw_request = {
        "model": model,
        # "temperature": temperature,
        # "max_completion_tokens": max_tokens_to_sample,
        # "stop": stop_sequences or None,  # the API rejects an empty stop list
        **kwargs
    }
    if model.startswith("gpt-3.5") or model.startswith("gpt-4") or model.startswith("o1"):
        # chat models; requires openai>=1.42.0
        messages = [{"role": "user", "content": prompt}]
        response = client.chat.completions.create(**{"messages": messages, **raw_request})
        completion = response.choices[0].message.content
    else:
        # legacy completion models
        response = client.completions.create(**{"prompt": prompt, **raw_request})
        completion = response.choices[0].text
    if log_file is not None:
        log_to_file(log_file, prompt, completion, model, max_tokens_to_sample)
    return completion
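# A minimal usage sketch for complete_text_openai (not part of the original
# file). It assumes openai_api_key.txt was loaded above in the
# "organization:api_key" format this module expects; the prompt is an
# illustrative placeholder.
def _demo_openai():
    completion = complete_text_openai(
        "Write a haiku about gradient descent.",
        model="gpt-3.5-turbo",
        max_tokens_to_sample=128,
    )
    print(completion)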
def complete_text(prompt, log_file, model, **kwargs):
    """ Complete text using the specified model with the appropriate API. """
    if model.startswith("claude"):
        # use the Anthropic API
        completion = complete_text_claude(prompt, stop_sequences=[anthropic.HUMAN_PROMPT, "Observation:"], log_file=log_file, model=model, **kwargs)
    elif "/" in model:
        # use the CRFM API, since the name specifies an organization like "openai/..."
        completion = complete_text_crfm(prompt, stop_sequences=["Observation:"], log_file=log_file, model=model, **kwargs)
    else:
        # use the OpenAI API
        completion = complete_text_openai(prompt, stop_sequences=["Observation:"], log_file=log_file, model=model, **kwargs)
    return completion
# specify a fast model for summarization etc.
FAST_MODEL = "claude-v1"

def complete_text_fast(prompt, **kwargs):
    return complete_text(prompt=prompt, model=FAST_MODEL, temperature=0.01, **kwargs)
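# A minimal sketch of the dispatch behavior (not part of the original file):
# "claude-*" names route to the Anthropic API, names containing "/" route to
# CRFM, and everything else routes to OpenAI. Guarded so importing this
# module has no side effects; the prompt and log file names are illustrative.
if __name__ == "__main__":
    for name in ["claude-v1", "openai/gpt-4-0314", "gpt-3.5-turbo"]:
        text = complete_text(
            "Reply with the single word: ready.",
            log_file=f"demo_{name.replace('/', '_')}.log",
            model=name,
            max_tokens_to_sample=16,
        )
        print(name, "->", text)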