-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathmain.py
More file actions
375 lines (316 loc) · 15 KB
/
main.py
File metadata and controls
375 lines (316 loc) · 15 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
"""Copilot Hub - FastAPI 主入口。
启动顺序:
1. 初始化 SQLite(建表)
2. 启动共享 httpx.AsyncClient
3. 从 DB 加载账号,初始化 token 管理池(后台刷新循环)
4. 注册所有路由
5. 挂载静态前端
"""
import logging
from contextlib import asynccontextmanager
import httpx
import uvicorn
from fastapi import FastAPI, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
import aiosqlite
import aiosqlite as _aiosqlite
import config as _config
from config import PORT, DB_PATH
import database
from database import init_db
from routers.deps import db_dep as _db_dep
from routers import api_keys, auth, dashboard, proxy, admin
from routers import user_auth, avatar
from routers import providers as providers_router
from routers import backup as backup_router
from routers import install as install_router
from services import token_manager, provider_manager, backup_manager
# Process-wide logging: INFO and above, one line per record formatted as
# "<timestamp> <LEVEL> <logger name> <message>".
logging.basicConfig(
    format="%(asctime)s %(levelname)-8s %(name)s %(message)s",
    level=logging.INFO,
)

# Module-scoped logger shared by everything defined in this file.
logger = logging.getLogger(__name__)
# ── Lifecycle ─────────────────────────────────────────────────────────────────
# A restart within this many seconds of the newest startup backup reuses that
# backup instead of creating another copy.
_STARTUP_BACKUP_DEDUP_SECS = 300


def _create_startup_backup() -> tuple[str, str] | None:
    """Copy the existing database file into the ``backups`` directory.

    This runs before ``init_db()``, so it uses a plain file copy: at that point
    the ``database_backups`` table may not exist yet.

    Returns:
        ``(filename, full_path)`` of the new backup, or ``None`` when there is
        no non-empty database file or a startup backup younger than
        ``_STARTUP_BACKUP_DEDUP_SECS`` already exists.
    """
    import os
    import shutil
    import time

    if not (os.path.exists(DB_PATH) and os.path.getsize(DB_PATH) > 0):
        return None

    backup_dir = os.path.join(os.path.dirname(DB_PATH), "backups")
    os.makedirs(backup_dir, exist_ok=True)

    # De-duplicate rapid restarts. File names look like
    # copilot_hub_{unix_ts}_startup.db, so field 2 is the timestamp.
    try:
        existing = sorted(
            (f for f in os.listdir(backup_dir) if "_startup.db" in f),
            reverse=True,
        )
        if existing:
            latest_ts = int(existing[0].split("_")[2])
            if time.time() - latest_ts < _STARTUP_BACKUP_DEDUP_SECS:
                logger.info("Skipping startup backup (recent one exists: %s)", existing[0])
                return None
    except Exception:
        # Best effort: an unreadable directory or an unexpected file name must
        # not prevent the backup itself, but it should be visible in the logs.
        logger.warning(
            "Could not inspect existing startup backups in %s; creating a new one",
            backup_dir,
            exc_info=True,
        )

    name = f"copilot_hub_{int(time.time())}_startup.db"
    destination = os.path.join(backup_dir, name)
    shutil.copy2(DB_PATH, destination)
    logger.info("Startup backup: %s", name)
    return name, destination


async def _record_startup_backup(filename: str, path: str) -> None:
    """Register a startup backup in ``database_backups`` and prune old backups.

    Both steps are best effort: failures are logged and never abort startup.
    Pruning is skipped when the row could not be recorded.
    """
    import os

    try:
        from services.backup_manager import _prune_backups, get_backup_config

        async with aiosqlite.connect(DB_PATH) as db:
            await db.execute(
                "INSERT OR IGNORE INTO database_backups(filename, size_bytes, trigger)"
                " VALUES (?, ?, 'startup')",
                (filename, os.path.getsize(path)),
            )
            await db.commit()
    except Exception:
        logger.warning("Failed to record startup backup %s", filename, exc_info=True)
        return

    # Trim backups beyond the configured max_count so they do not pile up.
    try:
        cfg = await get_backup_config()
        await _prune_backups(cfg["max_count"], cfg["backup_dir"])
    except Exception:
        logger.warning("Failed to prune old backups", exc_info=True)


async def _load_proxy_setting() -> str | None:
    """Apply the ``https_proxy`` value stored in the settings table, if any.

    A non-empty DB value overrides ``config.HTTPS_PROXY`` (which otherwise keeps
    its environment-derived value).

    Returns:
        The effective proxy URL, or ``None`` when no proxy is configured.
    """
    async with aiosqlite.connect(DB_PATH) as db:
        async with db.execute("SELECT value FROM settings WHERE key='https_proxy'") as cur:
            row = await cur.fetchone()
    if row and row[0]:
        _config.HTTPS_PROXY = row[0]
        logger.info("Proxy loaded from DB: %s", _config.HTTPS_PROXY)
    elif _config.HTTPS_PROXY:
        logger.info("Proxy loaded from env: %s", _config.HTTPS_PROXY)
    return _config.HTTPS_PROXY or None


async def _run_periodically(name, job, *, initial_delay, interval):
    """Run ``job`` forever, sleeping between runs.

    Args:
        name: Label used in log messages.
        job: Zero-argument callable returning an awaitable.
        initial_delay: Seconds to wait before the first run.
        interval: Zero-argument callable returning the pause in seconds; it is
            re-evaluated every cycle so runtime changes to the setting apply.

    A failing run is logged and retried on the next cycle instead of ending the
    loop for the rest of the process lifetime. Cancellation still propagates
    because ``asyncio.CancelledError`` is not an ``Exception`` subclass.
    """
    import asyncio

    await asyncio.sleep(initial_delay)
    while True:
        try:
            await job()
        except Exception:
            logger.exception("Background job %s failed; retrying next cycle", name)
        await asyncio.sleep(interval())


@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan handler: start every shared service, then tear down.

    Startup order: startup backup → ``init_db()`` → backup bookkeeping →
    proxy setting → shared HTTP client → token / provider pools → background
    loops. On shutdown the background loops are cancelled, the token manager is
    shut down and the shared HTTP client is closed.
    """
    import asyncio
    import os

    logger.info("Starting Copilot Hub...")

    # Take the startup backup before init_db() touches the file.
    backup = _create_startup_backup()

    # Create tables (per the original note: idempotent, keeps existing data).
    await init_db()
    logger.info("Database initialized")

    if backup is not None and os.path.exists(backup[1]):
        await _record_startup_backup(*backup)

    # Proxy configuration stored in the DB overrides the environment value.
    proxy_url = await _load_proxy_setting()

    # Shared HTTP client. HTTP/2 is disabled on purpose: with HTTP/1.1 each
    # request owns its connection, so long streaming requests cannot exhaust a
    # multiplexed pool and make short requests (e.g. token refresh) hit
    # PoolTimeout.
    client = httpx.AsyncClient(
        timeout=httpx.Timeout(120.0, connect=15.0),
        follow_redirects=True,
        http2=False,
        proxy=proxy_url,
        limits=httpx.Limits(
            max_connections=200,
            max_keepalive_connections=50,
            keepalive_expiry=30,
        ),
    )
    if proxy_url:
        logger.info("HTTP client using proxy: %s", proxy_url)

    # Token pool (loads accounts from the DB and starts its refresh loop),
    # then the third-party provider pool.
    await token_manager.startup(client)
    await provider_manager.startup(client)

    # Hand the shared client to the routers that issue upstream requests.
    proxy.set_http_client(client)
    providers_router.set_http_client(client)

    from routers.proxy import _usage_flush_loop

    # Keep strong references: the event loop only holds weak references to
    # tasks, so an unreferenced task may be garbage-collected mid-flight.
    background_tasks = [
        # Model list cache; the initial delay lets token_manager become ready.
        asyncio.create_task(
            _run_periodically(
                "model-refresh-loop",
                lambda: proxy._refresh_models_cache(),
                initial_delay=10,
                interval=lambda: proxy._MODELS_TTL,
            ),
            name="model-refresh-loop",
        ),
        # Official quota of every account (used for load balancing).
        asyncio.create_task(
            _run_periodically(
                "quota-refresh-loop",
                lambda: token_manager.refresh_quota_for_all(),
                initial_delay=15,
                interval=lambda: token_manager.QUOTA_REFRESH_INTERVAL,
            ),
            name="quota-refresh-loop",
        ),
        # Third-party provider quota.
        asyncio.create_task(
            _run_periodically(
                "provider-quota-refresh-loop",
                lambda: provider_manager.refresh_quota_for_all(),
                initial_delay=20,
                interval=lambda: 300,
            ),
            name="provider-quota-refresh-loop",
        ),
        # Batched usage-record writer.
        asyncio.create_task(_usage_flush_loop(), name="usage-flush-loop"),
    ]

    # Periodic automatic database backups.
    backup_manager.start_backup_loop()

    logger.info("Copilot Hub started, listening on :%d", PORT)
    try:
        yield
    finally:
        # Shutdown: stop the background loops first (they may use the client),
        # then the token manager, and finally close the shared HTTP client.
        for task in background_tasks:
            task.cancel()
        await asyncio.gather(*background_tasks, return_exceptions=True)
        try:
            await token_manager.shutdown()
        finally:
            # NOTE(review): whether token_manager.shutdown() already closes the
            # client is not visible here; closing again is harmless.
            await client.aclose()
        logger.info("Copilot Hub stopped")
# ── FastAPI application ───────────────────────────────────────────────────────
app = FastAPI(
    lifespan=lifespan,
    title="Copilot Hub",
    description="GitHub Copilot 账号共享代理系统",
    version="1.0.0",
)

# CORS: every origin, method and header is accepted, with credentials allowed.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive — confirm this is intended outside local development.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
# ── /api prefix rewrite middleware ────────────────────────────────────────────
# Keeps production compatible with development, where the Vite dev proxy
# already strips the /api prefix before requests reach this app.
class _StripApiPrefix:
    """Pure ASGI middleware that rewrites ``/api/...`` to ``/...``.

    Only ``http`` scopes whose path starts with ``/api/`` are rewritten; every
    other scope is passed through untouched. The incoming scope dict is never
    mutated: a shallow copy is rewritten and forwarded.
    """

    def __init__(self, app):
        # The wrapped ASGI application.
        self._app = app

    async def __call__(self, scope, receive, send):
        if scope.get("type") == "http":
            path = scope.get("path", "")
            if path.startswith("/api/"):
                scope = dict(scope)
                scope["path"] = path[4:]
                raw_path = scope.get("raw_path")
                if isinstance(raw_path, (bytes, bytearray)) and raw_path.startswith(b"/api/"):
                    # Strip the prefix from the bytes the server received so
                    # percent-encoding (e.g. %20, %2F) survives the rewrite;
                    # ``path`` is already decoded and cannot reproduce it.
                    scope["raw_path"] = bytes(raw_path[4:])
                else:
                    # raw_path is optional in ASGI (or its prefix may itself be
                    # percent-encoded): fall back to the decoded path.
                    scope["raw_path"] = path[4:].encode()
        await self._app(scope, receive, send)
app.add_middleware(_StripApiPrefix)

# ── Route registration ────────────────────────────────────────────────────────
# Routers are included in this exact order.
for _router in (
    user_auth.router,
    auth.router,
    api_keys.router,
    proxy.router,
    dashboard.router,
    admin.router,
    avatar.router,
    providers_router.router,
    providers_router.hub_router,
    backup_router.router,
    install_router.router,
):
    app.include_router(_router)
del _router  # do not leave the loop variable behind in the module namespace
@app.get("/hub/models", include_in_schema=False)
async def public_models():
    """Serve the cached model list to the frontend; no authentication needed."""
    # Payload comes straight from the proxy router's cache; per the original
    # inline note its keys are {all, chat, embedding}.
    cached_models = proxy.get_cached_models()
    return cached_models
def _parse_capabilities(raw) -> list:
    """Normalise a stored ``capabilities`` value into a list.

    Args:
        raw: The column value — a JSON-encoded string, an already decoded
            list/tuple, or ``None``/empty.

    Returns:
        The decoded list (possibly empty). Missing, malformed or non-list
        values yield ``["chat"]`` so that one bad row cannot break the whole
        endpoint.
    """
    import json

    if isinstance(raw, str):
        try:
            raw = json.loads(raw or '["chat"]')
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            logger.warning("Ignoring malformed model capabilities value: %r", raw)
            return ["chat"]
    elif not raw:
        return ["chat"]
    if isinstance(raw, (list, tuple)):
        return list(raw)
    return ["chat"]


@app.get("/hub/models/info", include_in_schema=False)
async def public_models_info(db: aiosqlite.Connection = Depends(_db_dep)):
    """Public endpoint: rich model metadata (context window, vendor, caps) — no auth.

    Combines the cached Copilot model metadata with every enabled row of
    ``ext_provider_models``. A Copilot entry always wins when an external alias
    collides with it.

    Returns:
        ``{"models": [...]}``. External entries carry ``source="ext"`` and their
        stored ``is_public`` flag; all other entries get ``is_public=False``.
    """
    # Copy so the merge below never mutates the proxy router's cache mapping.
    info = dict(proxy.get_cached_models_info())

    # Merge external provider models.
    db.row_factory = aiosqlite.Row
    async with db.execute(
        """SELECT epm.*, ep.name AS provider_name, ep.provider_type
           FROM ext_provider_models epm
           JOIN ext_providers ep ON ep.id = epm.provider_id
           WHERE epm.enabled=1"""
    ) as cur:
        ext_rows = await cur.fetchall()

    for row in ext_rows:
        alias = row["model_alias"]
        if alias in info:
            # Don't override Copilot model entries.
            continue
        caps = _parse_capabilities(row["capabilities"])
        if "chat" in caps or not caps:
            model_type = "chat"
        else:
            model_type = caps[0]
        info[alias] = {
            "id": alias,
            "name": row["model_name"] or alias,
            "vendor": row["provider_name"],
            "type": model_type,
            "preview": False,
            "model_picker_enabled": True,
            "context_window": row["context_length"],
            "max_output_tokens": row["max_output_tokens"],
            "max_prompt_tokens": row["max_prompt_tokens"],
            "supports_vision": bool(row["supports_vision"]),
            "supports_tools": bool(row["supports_tools"]),
            "supports_reasoning": bool(row["supports_reasoning"]),
            "source": "ext",
            "provider_type": row["provider_type"],
            "is_public": bool(row["is_public"]),
        }

    # Copilot models require a GitHub account — mark them as not public. A new
    # dict is built so the cached entry itself stays unmodified.
    models = [
        entry if entry.get("source") == "ext" else dict(entry, is_public=False)
        for entry in info.values()
    ]
    return {"models": models}
# ── Frontend static files ─────────────────────────────────────────────────────
import os as _os

# Directory holding the compiled frontend, next to this file.
_STATIC_DIR = _os.path.join(_os.path.dirname(__file__), "static")
# Optional Vite dev-server origin, e.g. "http://localhost:8002".
_VITE_DEV_URL = _os.getenv("VITE_DEV_URL", "")
# Path prefixes (without a leading slash) that belong to the API rather than
# to the frontend.
_API_PREFIXES = ("v1/", "auth/", "keys/", "dashboard/", "admin/", "avatar/", "hub/")


def _is_api(path: str) -> bool:
    """Return True when ``path`` begins with one of ``_API_PREFIXES``."""
    return any(path.startswith(prefix) for prefix in _API_PREFIXES)
# Upstream response headers that must not be copied onto our own response:
# httpx has already decoded the body, and the ASGI server owns message framing
# and connection management.
_UNFORWARDABLE_HEADERS = frozenset(
    {"content-encoding", "content-length", "transfer-encoding", "connection", "keep-alive"}
)


def _resolve_static_file(rel_path: str) -> str | None:
    """Map a request path onto a regular file inside ``_STATIC_DIR``.

    Args:
        rel_path: The URL path captured by the catch-all route.

    Returns:
        The resolved absolute file path, or ``None`` when ``rel_path`` is
        empty, does not name an existing regular file, or resolves outside
        ``_STATIC_DIR`` (``..`` segments, absolute paths, escaping symlinks).
    """
    if not rel_path:
        return None
    try:
        root = _os.path.realpath(_STATIC_DIR)
        candidate = _os.path.realpath(_os.path.join(root, rel_path))
        if _os.path.commonpath([root, candidate]) != root:
            return None
    except ValueError:
        # Embedded NUL bytes, or paths on different drives on Windows.
        return None
    return candidate if _os.path.isfile(candidate) else None


if _VITE_DEV_URL or _os.path.exists(_os.path.join(_os.path.dirname(__file__), ".vite_proxy")):
    # ── Development mode: forward frontend requests to the Vite dev server ──
    _VITE_DEV_URL = _VITE_DEV_URL or "http://localhost:8002"
    from fastapi import Request
    from fastapi.responses import Response
    import httpx as _httpx

    @app.api_route("/{full_path:path}", methods=["GET", "HEAD"], include_in_schema=False)
    async def vite_proxy(full_path: str, request: Request):
        """Proxy a non-API GET/HEAD request to the Vite dev server.

        Falls back to the compiled ``index.html`` when the dev server cannot be
        reached.

        Raises:
            HTTPException: 404 for API-looking paths that matched no route;
                503 when neither the dev server nor a static build is available.
        """
        from fastapi import HTTPException

        if _is_api(full_path):
            raise HTTPException(status_code=404)
        target = f"{_VITE_DEV_URL}/{full_path}"
        async with _httpx.AsyncClient() as c:
            try:
                resp = await c.request(
                    method=request.method,
                    url=target,
                    headers={k: v for k, v in request.headers.items() if k.lower() != "host"},
                    params=dict(request.query_params),
                    timeout=5.0,
                )
                # resp.content is already decompressed, so the upstream
                # encoding/length/framing headers would describe a different
                # body than the one we send.
                forwarded = {
                    k: v
                    for k, v in resp.headers.items()
                    if k.lower() not in _UNFORWARDABLE_HEADERS
                }
                return Response(
                    content=resp.content,
                    status_code=resp.status_code,
                    headers=forwarded,
                    media_type=resp.headers.get("content-type"),
                )
            except Exception:
                # Vite dev server not reachable — fall back to static build
                logger.debug("Vite dev server request to %s failed", target, exc_info=True)
                index = _os.path.join(_STATIC_DIR, "index.html")
                if _os.path.exists(index):
                    return FileResponse(index)
                raise HTTPException(status_code=503, detail="Frontend not available")

elif _os.path.exists(_os.path.join(_STATIC_DIR, "index.html")):
    # ── Production mode: serve the compiled static files ──
    # Mount the assets sub-directory first (hashed file names, safe to cache hard).
    _assets_dir = _os.path.join(_STATIC_DIR, "assets")
    if _os.path.exists(_assets_dir):
        app.mount("/assets", StaticFiles(directory=_assets_dir), name="assets")

    @app.get("/{full_path:path}", include_in_schema=False)
    async def spa_fallback(full_path: str):
        """Serve a root-level static file, or ``index.html`` for SPA routes.

        Raises:
            HTTPException: 404 for API-looking paths that matched no route.
        """
        from fastapi import HTTPException
        from fastapi.responses import Response as _Resp

        if _is_api(full_path):
            raise HTTPException(status_code=404)

        # Try a real static file first (vite.svg, favicon and other root
        # files). The helper refuses anything that resolves outside the static
        # directory, so crafted paths cannot read arbitrary files.
        static_file = _resolve_static_file(full_path)
        if static_file is not None:
            return FileResponse(static_file)

        # SPA fallback — every other path gets index.html, sent with no-cache
        # headers so browsers never keep a stale build.
        with open(_os.path.join(_STATIC_DIR, "index.html"), "rb") as fh:
            content = fh.read()
        return _Resp(
            content=content,
            media_type="text/html",
            headers={
                "Cache-Control": "no-cache, no-store, must-revalidate",
                "Pragma": "no-cache",
                "Expires": "0",
            },
        )
# ── Local run entry point ─────────────────────────────────────────────────────
if __name__ == "__main__":
    # uvloop is an optional accelerator and is not installable everywhere
    # (e.g. Windows); fall back to the stock asyncio loop instead of crashing
    # on import.
    try:
        import uvloop as _uvloop
    except ImportError:
        _loop_impl = "asyncio"
        logger.warning("uvloop is not installed; using the default asyncio event loop")
    else:
        _uvloop.install()
        _loop_impl = "uvloop"

    uvicorn.run(
        "main:app",
        host="0.0.0.0",
        port=PORT,
        reload=False,
        log_level="info",
        loop=_loop_impl,
    )