-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
164 lines (148 loc) · 7.35 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
import warnings
from cryptography.utils import CryptographyDeprecationWarning
from logger_config import setup_logger
from endpoint_functions import router
import asyncio
import os
import random
import traceback
import fastapi
from fastapi import Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.exceptions import RequestValidationError
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.base import RequestResponseEndpoint
from starlette.responses import Response
import uvloop
from uvicorn import Config, Server
from decouple import Config as DecoupleConfig, RepositoryEnv
from database_code import initialize_db
from setup_swiss_army_llama import check_and_setup_swiss_army_llama
from service_functions import (
monitor_new_messages,
generate_or_load_encryption_key_sync,
decrypt_sensitive_data,
get_env_value,
fetch_all_mnid_tickets_details,
establish_ssh_tunnel,
schedule_micro_benchmark_periodically,
list_generic_tickets_in_blockchain_and_parse_and_validate_and_store_them,
generate_supernode_inference_ip_blacklist,
kill_open_ssh_tunnels
)
# Suppress CryptographyDeprecationWarning messages for the whole process.
warnings.filterwarnings("ignore", category=CryptographyDeprecationWarning)
# Settings are read from the `.env` file, resolved relative to the working directory.
config = DecoupleConfig(RepositoryEnv('.env'))
# No default is supplied for the port, the mapped port or the token below.
UVICORN_PORT = config.get("UVICORN_PORT", cast=int)
# Integer flag (defaults to 0); a non-zero value enables the SSH tunnel setup in the script entry point.
USE_REMOTE_SWISS_ARMY_LLAMA_IF_AVAILABLE = config.get("USE_REMOTE_SWISS_ARMY_LLAMA_IF_AVAILABLE", default=0, cast=int)
REMOTE_SWISS_ARMY_LLAMA_MAPPED_PORT = config.get("REMOTE_SWISS_ARMY_LLAMA_MAPPED_PORT", cast=int)
# Value exactly as stored in `.env`; decrypt_sensitive_fields() rebinds this global during startup.
SWISS_ARMY_LLAMA_SECURITY_TOKEN = config.get("SWISS_ARMY_LLAMA_SECURITY_TOKEN", cast=str)
# NOTE(review): time.tzset() is not called after this assignment, so the already-initialised
# `time` module may keep using the previous zone - confirm this has the intended effect.
os.environ['TZ'] = 'UTC' # Set timezone to UTC for the current session
# Make uvloop the event loop implementation for every loop created after this point.
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
logger = setup_logger()
# The interactive API docs are served at the site root and ReDoc at /redoc.
app = fastapi.FastAPI(
    title="Pastel-Supernode-Inference-Layer",
    description="Pastel Supernode Inference Layer API",
    docs_url="/",
    redoc_url="/redoc"
)
class LimitRequestSizeMiddleware(BaseHTTPMiddleware):
    """Reject requests whose declared ``Content-Length`` exceeds a fixed limit.

    Only the header is inspected; the body itself is never read or measured, so
    a request that carries no ``Content-Length`` header is passed through.
    """

    def __init__(self, app, max_request_size: int):
        """Wrap *app* and remember the size limit.

        Args:
            app: The ASGI application being wrapped.
            max_request_size: Largest accepted ``Content-Length`` value.
        """
        super().__init__(app)
        self.max_request_size = max_request_size

    async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
        """Check the declared request size before handing the request on.

        Args:
            request: The incoming request.
            call_next: Callable that forwards the request down the stack.

        Returns:
            A 400 response when ``Content-Length`` is not a non-negative
            integer, a 413 response when it exceeds ``max_request_size``,
            otherwise whatever the downstream application returns.
        """
        declared_length = request.headers.get('content-length')
        if declared_length is not None:
            try:
                request_size = int(declared_length)
            except ValueError:
                # A malformed header is the client's mistake; answer 400 rather
                # than letting the ValueError surface as a server error.
                return Response("Invalid Content-Length header", status_code=400)
            if request_size < 0:
                return Response("Invalid Content-Length header", status_code=400)
            if request_size > self.max_request_size:
                return Response("Request size exceeds the limit", status_code=413)
        return await call_next(request)
# Limit the declared request size to 50 * 1024 * 1024 (52,428,800).
app.add_middleware(LimitRequestSizeMiddleware, max_request_size=50 * 1024 * 1024)
# Register the routes from endpoint_functions with no URL prefix, tagged 'main'.
app.include_router(router, prefix='', tags=['main'])
# Custom Exception Handling Middleware
@app.middleware("http")
async def custom_exception_handling(request: Request, call_next):
    """Turn exceptions that escape the downstream app into JSON error responses.

    Args:
        request: The incoming HTTP request.
        call_next: Callable that forwards the request to the rest of the stack.

    Returns:
        The downstream response when nothing is raised; otherwise a
        ``JSONResponse`` with status 422 for a ``RequestValidationError`` and
        status 500 for any other exception.
    """
    try:
        return await call_next(request)
    except RequestValidationError as ve:
        # Local import: the file's import block does not provide this name.
        from fastapi.encoders import jsonable_encoder
        logger.error(f"Validation error: {ve}")
        # RequestValidationError exposes its details through errors(); it has no
        # `status_code` or `error_msg` attribute, so reading those would raise
        # AttributeError from inside this handler. 422 is FastAPI's own status
        # for request validation failures.
        return JSONResponse(status_code=422, content={"detail": jsonable_encoder(ve.errors())})
    except Exception as e:
        tb = traceback.format_exc()  # Get the full traceback
        logger.error(f"Unhandled exception: {e}\n{tb}")  # Log the exception with traceback
        # NOTE(review): the exception text is returned to the client verbatim;
        # confirm that exposing internal error messages is acceptable.
        return JSONResponse(status_code=500, content={"detail": str(e)})
# CORS Middleware
# Every origin, method and request header is allowed, and the Authorization
# response header is exposed to browser scripts.
# NOTE(review): wildcard origins are combined with allow_credentials=True here;
# the FastAPI CORS documentation advises against that combination - confirm
# this is intended, or replace '*' with an explicit list of origins.
app.add_middleware(
    CORSMiddleware,
    allow_origins=['*'],
    allow_credentials=True,
    allow_methods=['*'],
    allow_headers=['*'],
    expose_headers=["Authorization"]
)
def decrypt_sensitive_fields():
    """Decrypt the protected settings and publish them as module globals.

    Each name listed below is fetched with ``get_env_value``, decrypted with
    the module-level ``encryption_key`` and stored in a module global of the
    same name. The names are processed in order and each global is assigned as
    soon as its value is ready.
    """
    protected_names = (
        "LOCAL_PASTEL_ID_PASSPHRASE",
        "SWISS_ARMY_LLAMA_SECURITY_TOKEN",
        "OPENAI_API_KEY",
        "CLAUDE3_API_KEY",
        "GROQ_API_KEY",
        "MISTRAL_API_KEY",
        "STABILITY_API_KEY",
        "OPENROUTER_API_KEY",
    )
    module_namespace = globals()
    for protected_name in protected_names:
        encrypted_value = get_env_value(protected_name)
        module_namespace[protected_name] = decrypt_sensitive_data(encrypted_value, encryption_key)
# Strong references to fire-and-forget tasks. The event loop itself only keeps
# weak references, so an unreferenced task can be garbage-collected mid-run.
_background_tasks = set()


def _on_background_task_done(task):
    """Drop the finished *task* from the registry and log it if it failed."""
    _background_tasks.discard(task)
    if task.cancelled():
        return
    task_error = task.exception()
    if task_error is not None:
        logger.error(f"Background task {task.get_name()} failed: {task_error!r}")


def _spawn_background_task(coro):
    """Schedule *coro* as a task and keep a reference to it until it finishes."""
    task = asyncio.create_task(coro)
    _background_tasks.add(task)
    task.add_done_callback(_on_background_task_done)
    return task


async def startup():
    """Initialise the service and launch its background tasks.

    Initialises the database, loads the encryption key and decrypts the
    sensitive settings, starts the long-running background coroutines, builds
    the IP blacklist once and then schedules its periodic refresh.

    Any exception raised along the way is logged together with its traceback
    and is not re-raised, so steps after the failing one are skipped.
    """
    global encryption_key  # Declare encryption_key as global
    try:
        db_init_complete = await initialize_db()
        logger.info(f"Database initialization complete: {db_init_complete}")
        encryption_key = generate_or_load_encryption_key_sync()  # Generate or load the encryption key synchronously
        decrypt_sensitive_fields()  # Now decrypt sensitive fields
        _spawn_background_task(monitor_new_messages())  # Create a background task
        _spawn_background_task(fetch_all_mnid_tickets_details())
        _spawn_background_task(list_generic_tickets_in_blockchain_and_parse_and_validate_and_store_them())
        # Run the blocking setup in a worker thread so the event loop is not blocked
        _spawn_background_task(asyncio.to_thread(check_and_setup_swiss_army_llama, SWISS_ARMY_LLAMA_SECURITY_TOKEN))
        await generate_supernode_inference_ip_blacklist()  # Compile IP blacklist text file
        _spawn_background_task(schedule_generate_supernode_inference_ip_blacklist())  # Schedule the task
        _spawn_background_task(schedule_micro_benchmark_periodically())  # Schedule the task
    except Exception as e:
        logger.error(f"Error during startup: {e}")
        logger.error(traceback.format_exc())
@app.on_event("startup")
async def startup_event():
    """Run ``startup()`` when the application emits its startup event."""
    # NOTE(review): recent FastAPI releases document `on_event` as deprecated
    # in favour of lifespan handlers - confirm against the pinned version.
    await startup()
async def schedule_generate_supernode_inference_ip_blacklist():
    """Regenerate the supernode inference IP blacklist forever, at jittered intervals.

    Each cycle sleeps for 300 seconds plus a random jitter between -180 and
    180 seconds (120 to 480 seconds in total) and then awaits
    ``generate_supernode_inference_ip_blacklist()``.

    A failure in one cycle is logged and the loop carries on, so a single
    error does not end the periodic refresh. Task cancellation is not
    swallowed: ``asyncio.CancelledError`` does not derive from ``Exception``.
    """
    while True:
        jitter = random.randint(-180, 180)  # Jitter of up to 3 minutes (180 seconds)
        interval_seconds = 300 + jitter  # 300 seconds = 5 minutes
        await asyncio.sleep(interval_seconds)
        try:
            await generate_supernode_inference_ip_blacklist()
        except Exception as e:
            logger.error(f"Error generating supernode inference IP blacklist: {e}")
            logger.error(traceback.format_exc())
async def main():
    """Serve the ``main:app`` application with uvicorn until the server exits.

    The server binds to every interface (``0.0.0.0``) on ``UVICORN_PORT`` and
    asks uvicorn to use its ``uvloop`` loop setting.
    """
    server_options = {
        "host": "0.0.0.0",
        "port": UVICORN_PORT,
        "loop": "uvloop",
    }
    await Server(Config("main:app", **server_options)).serve()
# Script entry point: optionally set up the SSH tunnel, then run the server.
if __name__ == "__main__":
    if USE_REMOTE_SWISS_ARMY_LLAMA_IF_AVAILABLE:
        # Presumably clears tunnels left over on the mapped port - confirm in service_functions.
        kill_open_ssh_tunnels(REMOTE_SWISS_ARMY_LLAMA_MAPPED_PORT)
        # Create and run event loop for SSH tunnel setup
        async def setup_tunnel():
            # Failures are logged and swallowed, so a failed tunnel does not stop the server from starting.
            try:
                await establish_ssh_tunnel()
            except Exception as e:
                logger.error(f"Error establishing SSH tunnel: {e}")
                logger.error(traceback.format_exc())
        # Run the tunnel setup in the event loop
        # A dedicated loop is used for this step and closed afterwards;
        # asyncio.run() below creates its own loop for the server.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            loop.run_until_complete(setup_tunnel())
        except Exception as e:
            logger.error(f"Error in tunnel setup loop: {e}")
            logger.error(traceback.format_exc())
        finally:
            loop.close()
    # The return value is discarded here; presumably called so the key exists
    # before the server starts - confirm in service_functions.
    generate_or_load_encryption_key_sync()
    # Rebinds the module-level `config` to a freshly loaded `.env`.
    config = DecoupleConfig(RepositoryEnv('.env'))
    asyncio.run(main())