-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathopenai_api_endpoint.py
135 lines (111 loc) · 4.47 KB
/
openai_api_endpoint.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import time
from pathlib import Path
from fastapi import FastAPI, Request
from just_agents.interfaces.IAgent import build_agent, IAgent
from just_agents.llm_session import LLMSession
from literature.routes import _hybrid_search
from genetics.main import rsid_lookup, gene_lookup, pathway_lookup, disease_lookup, sequencing_info
from clinical_trials.clinical_trails_router import _process_sql, clinical_trails_full_trial
from precious3GPT.p3gpt_tool import get_omics_data, get_enrichment
from precious3GPT.routes import omics_router
from starlette.responses import StreamingResponse
from dotenv import load_dotenv
from fastapi.middleware.cors import CORSMiddleware
from open_genes.tools import db_query
import loguru
import yaml
import mimetypes
import base64
import hashlib
import json
# Configure rotating file logging next to this module: logs/openai_api_endpoint.log
log_path = Path(__file__)
log_path = Path(log_path.parent, "logs", "openai_api_endpoint.log")
loguru.logger.add(log_path.absolute(), rotation="10 MB")  # start a new file once the log exceeds 10 MB
load_dotenv(override=True)  # .env values take precedence over pre-existing environment variables
# Example question this endpoint is meant to answer:
# What is the influence of different alleles in rs10937739 and what is MTOR gene?
app = FastAPI(title="Genetics Genie API endpoint.")
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# maximally permissive CORS — confirm this is intended outside development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.include_router(omics_router)  # mount the precious3GPT omics routes on this app
# Typo fix in the user-facing OpenAPI description: "Defalt" -> "Default".
@app.get("/", description="Default message", response_model=str)
def default():
    """Return a human-readable banner for the API root path."""
    return "This is default page for Genetics Genie API endpoint."
def clean_messages(request: dict) -> None:
    """Flatten OpenAI-style structured user content in place.

    Some clients send ``message["content"]`` as a list of content parts,
    e.g. ``[{"type": "text", "text": "..."}]``. For each *user* message whose
    first part is a dict text part with a string ``"text"`` value, replace the
    list with that plain string so downstream agents receive simple content.
    All other messages (and malformed parts) are left untouched.

    Args:
        request: Parsed chat-completion body; must contain a "messages" list.
    """
    for message in request["messages"]:
        # Guard clauses replace the original five-level if pyramid.
        if message["role"] != "user":
            continue
        content = message["content"]
        if not (isinstance(content, list) and content):
            continue
        first = content[0]
        if (
            isinstance(first, dict)
            and first.get("type", "") == "text"
            and isinstance(first.get("text"), str)
        ):
            message["content"] = first["text"]
def remove_system_prompt(request: dict) -> None:
    """Drop a leading system message from ``request["messages"]`` in place.

    Only the first message is inspected; any later system messages are kept.
    Fix: the original indexed ``messages[0]`` unconditionally and raised
    IndexError on an empty message list — now an empty list is a no-op.

    Args:
        request: Parsed chat-completion body; must contain a "messages" list.
    """
    messages = request["messages"]
    if messages and messages[0]["role"] == "system":
        request["messages"] = messages[1:]
def get_agent(request):
    """Build and return the agent configured for the requested model.

    Looks up ``request["model"]`` in ``endpoint_options.yaml`` (relative to the
    working directory) and hands the matching schema to ``build_agent``.
    A model name absent from the file yields ``None`` as the schema.
    """
    with open("endpoint_options.yaml") as config_file:
        schemas = yaml.full_load(config_file)
    return build_agent(schemas.get(request["model"]))
def sha256sum(content_str: str) -> str:
    """Return the hex-encoded SHA-256 digest of a UTF-8 string.

    Args:
        content_str: Text to hash.

    Returns:
        The 64-character lowercase hexadecimal digest.
    """
    # Single-call form; the original built the digest incrementally and
    # shadowed the builtin `hash` with a local variable.
    return hashlib.sha256(content_str.encode("utf-8")).hexdigest()
def save_files(file_params: list):
    """Persist base64-encoded upload payloads to /tmp.

    Each entry in ``file_params`` is a dict with keys:
        name: file name (any directory components are stripped — see below),
        content: urlsafe-base64-encoded file body,
        checksum: SHA-256 hex digest of the base64 *string* (not decoded bytes),
        mime: MIME type, used to pick the file extension.

    Raises:
        ValueError: if a checksum does not match its content.
    """
    if not file_params:
        return
    for file in file_params:
        file_name = file.get("name")
        file_content_base64 = file.get("content")
        file_checksum = file.get("checksum")
        file_mime = file.get("mime")
        if sha256sum(file_content_base64) != file_checksum:
            # ValueError instead of bare Exception; message kept verbatim.
            raise ValueError("File checksum does not match")
        # guess_extension returns None for unknown MIME types; the original
        # then crashed on str + None concatenation — fall back to no extension.
        extension = mimetypes.guess_extension(file_mime) or ""
        file_content = base64.urlsafe_b64decode(file_content_base64.encode('utf-8'))
        # Path(...).name strips directory components, preventing path traversal
        # via a crafted client-supplied file name (e.g. "../../etc/cron.d/x").
        full_file_name = Path(file_name + extension).name
        file_path = Path('/tmp', full_file_name)
        loguru.logger.debug(f"Saving file {file_path}")
        with open(file_path, "wb") as f:
            f.write(file_content)
@app.post("/v1/chat/completions")
def chat_completions(request: dict):
    """OpenAI-compatible chat completion endpoint.

    Builds the agent named by ``request["model"]``, normalizes incoming
    messages, saves any uploads from ``request["metadata"]["file_params"]``,
    then either streams the agent response (when "stream" is truthy and not
    the string "false") or returns one OpenAI-style completion object.
    Any exception is logged and surfaced as the assistant message content
    rather than an HTTP error.
    """
    try:
        loguru.logger.debug(request)
        agent: IAgent = get_agent(request)
        clean_messages(request)
        remove_system_prompt(request)
        if request["messages"]:
            file_params = request.get("metadata", {}).get("file_params", [])
            save_files(file_params)
            # "stream" may arrive as a bool or as the string "false".
            if request.get("stream") and str(request.get("stream")).lower() != "false":
                return StreamingResponse(
                    agent.stream(request["messages"]), media_type="application/x-ndjson"
                )
            resp_content = agent.query(request["messages"])
        else:
            resp_content = "Something goes wrong, request did not contain messages!!!"
    except Exception as e:
        # Best-effort error reporting to the client; never a 500.
        loguru.logger.error(str(e))
        resp_content = str(e)
    return {
        "id": "1",
        "object": "chat.completion",
        # OpenAI spec: "created" is a Unix timestamp in whole seconds.
        "created": int(time.time()),
        # .get avoids a second, unhandled KeyError when "model" is missing
        # (the try block above already captured the original failure).
        "model": request.get("model", ""),
        "choices": [{"message": {"role": "assistant", "content": resp_content}}],
    }
@app.get("/v1/get_prompt_examples")
def get_prompt_examples():
    """Serve the canned prompt examples shipped alongside the service.

    Reads ``prompt_examples.json`` from the working directory on every call
    and returns its parsed content.
    """
    with open("prompt_examples.json") as examples_file:
        examples = json.load(examples_file)
    return examples
# Run a local server when executed directly.
if __name__ == "__main__":
    import uvicorn
    # workers > 1 requires passing the app as an import string ("module:attr")
    # rather than the app object itself.
    uvicorn.run("openai_api_endpoint:app", host="0.0.0.0", port=8088, workers=10)