Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
SilianZ committed Mar 24, 2024
2 parents 2fbc427 + 4d3b5ff commit cccefdc
Show file tree
Hide file tree
Showing 14 changed files with 253 additions and 122 deletions.
58 changes: 34 additions & 24 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -102,30 +102,40 @@
## 配置文件

```yml
# 是否不使用 BMCLAPI 分发的证书, 同 CLUSTER_BYOC
byoc: false
# OpenBMCLAPI 的 CLUSTER_ID
cluster_id: ''
# OpenBMCLAPI 的 CLUSTER_SECRET
cluster_secret: ''
# 同步文件时最多打开的连接数量
download_threads: 64
# 超时时间
timeout: 30
# 实际开放的公网主机名, 同 CLUSTER_IP
web_host: ''
# 要监听的本地端口, 同 CLUSTER_PORT
web_port: 8800
# 实际开放的公网端口, 同 CLUSTER_PUBLIC_PORT
web_publicport: 8800
io_buffer: 16777216
max_download: 64
min_rate: 500
min_rate_timestamp: 1000
port: 8800
public_host: ''
public_port: null
server_name: TTB-Network
advanced:
# 新连接读取数据头大小
header_bytes: 4096
# 数据传输缓存大小
io_buffer: 16777216
# 最小读取速率(Bytes)
min_rate: 500
# 最小读取速率时间
min_rate_timestamp: 1000
# 请求缓存大小
request_buffer: 8192
# 超时时间
timeout: 30
cluster:
# 是否不使用 BMCLAPI 分发的证书, 同 CLUSTER_BYOC
byoc: false
# OpenBMCLAPI 的 CLUSTER_ID
id: ''
# 实际开放的公网主机名, 同 CLUSTER_IP
public_host: ''
# 实际开放的公网端口, 同 CLUSTER_PUBLIC_PORT
public_port: 8800
# OpenBMCLAPI 的 CLUSTER_SECRET
secret: ''
download:
# 最高下载线程
threads: 64
web:
# 要监听的本地端口, 同 CLUSTER_PORT
port: 80
# 服务器名字
server_name: TTB-Network
# SSL 端口
ssl_port: 8800
```

# 贡献
Expand Down
2 changes: 1 addition & 1 deletion bmclapi_dashboard/static/js/index.min.js

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion bmclapi_dashboard/static/js/ttb.min.js

Large diffs are not rendered by default.

25 changes: 20 additions & 5 deletions core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,13 @@ def disconnect(self, client: ProxyClient):
return
self._tables.remove(client)

def get_origin_from_ip(self, ip: tuple[str, int]):
# ip is connected client
for target in self._tables:
if target.target.get_sock_address() == ip:
return target.origin.get_address()
return None


ssl_server: Optional[asyncio.Server] = None
server: Optional[asyncio.Server] = None
Expand All @@ -106,7 +113,14 @@ def disconnect(self, client: ProxyClient):


async def _handle_ssl(reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
return await _handle_process(Client(reader, writer), True)
return await _handle_process(
Client(
reader,
writer,
peername=proxy.get_origin_from_ip(writer.get_extra_info("peername")),
),
True,
)


async def _handle(reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
Expand All @@ -131,7 +145,8 @@ async def _handle_process(client: Client, ssl: bool = False):
await asyncio.open_connection(
"127.0.0.1", ssl_server.sockets[0].getsockname()[1]
)
)
),
peername=client.get_address(),
)
proxying = True
await proxy.connect(client, target, header)
Expand All @@ -156,7 +171,7 @@ async def check_ports():
ports: list[tuple[asyncio.Server, ssl.SSLContext | None]] = []
for service in (
(server, None),
(ssl_server, client_side_ssl if get_loads() != 0 else None),
(ssl_server, client_side_ssl if get_loaded() else None),
):
if not service[0]:
continue
Expand Down Expand Up @@ -192,20 +207,20 @@ async def check_ports():
async def main():
global ssl_server, server, server_side_ssl, restart
await web.init()
certificate.load_cert(Path(".ssl/cert"), Path(".ssl/key"))
Timer.delay(check_ports, (), 5)
while 1:
try:
server = await asyncio.start_server(_handle, port=PORT)
ssl_server = await asyncio.start_server(
_handle_ssl,
port=0 if SSL_PORT == PORT else SSL_PORT,
ssl=server_side_ssl if get_loads() != 0 else None,
ssl=server_side_ssl if get_loaded() else None,
)
logger.info(f"Listening server on port {PORT}.")
logger.info(
f"Listening server on {ssl_server.sockets[0].getsockname()[1]}."
)
logger.info(f"Loaded {get_loads()} certificates!")
async with server, ssl_server:
await asyncio.gather(server.serve_forever(), ssl_server.serve_forever())
except asyncio.CancelledError:
Expand Down
17 changes: 16 additions & 1 deletion core/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class File:
last_hit: float = 0
last_access: float = 0
data: Optional[io.BytesIO] = None
cache: bool = False

def is_url(self):
if not isinstance(self.path, str):
Expand All @@ -60,6 +61,12 @@ def set_data(self, data: io.BytesIO | memoryview | bytes):
self.data = io.BytesIO(zlib.compress(data.getbuffer()))


@dataclass
class StatsCache:
total: int = 0
bytes: int = 0


class Storage(metaclass=abc.ABCMeta):
@abc.abstractmethod
async def get(self, file: str) -> File:
Expand All @@ -79,6 +86,10 @@ async def exists(self, hash: str) -> bool:
async def get_size(self, hash: str) -> int:
raise NotImplementedError

@abc.abstractmethod
async def copy(self, origin: Path, hash: str) -> int:
raise NotImplementedError

@abc.abstractmethod
async def write(self, hash: str, io: io.BytesIO) -> int:
raise NotImplementedError
Expand All @@ -101,6 +112,10 @@ async def get_files_size(self, dir: str) -> int:
async def removes(self, hashs: list[str]) -> int:
raise NotImplementedError

@abc.abstractmethod
async def get_cache_stats(self) -> StatsCache:
raise NotImplementedError


def get_hash(org):
if len(org) == 32:
Expand All @@ -112,7 +127,7 @@ def get_hash(org):
async def get_file_hash(org: str, path: Path):
hash = get_hash(org)
async with aiofiles.open(path, "rb") as r:
while data := await r.read(Config.get("io_buffer")):
while data := await r.read(Config.get("advanced.io_buffer")):
if not data:
break
hash.update(data)
Expand Down
18 changes: 9 additions & 9 deletions core/certificate.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from pathlib import Path
import ssl
import time
Expand All @@ -13,24 +14,25 @@
client_side_ssl = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
client_side_ssl.check_hostname = False

_loads: int = 0
_loaded: bool = False


def load_cert(cert, key):
global server_side_ssl, client_side_ssl, _loads
global server_side_ssl, client_side_ssl, _loaded
if not os.path.exists(cert) or not os.path.exists(key):
return False
try:
server_side_ssl.load_cert_chain(cert, key)
client_side_ssl.load_verify_locations(cert)
_loads += 1
_loaded = True
return True
except:
logger.error(f"Failed to load certificate: {traceback.format_exc()}.")
return False


def get_loads() -> int:
global _loads
return _loads
def get_loaded() -> bool:
return _loaded


def load_text(cert: str, key: str):
Expand All @@ -41,9 +43,7 @@ def load_text(cert: str, key: str):
c.write(cert)
k.write(key)
if load_cert(cert_file, key_file):
logger.info(
f"Loaded certificate from local files! Current certificate: {get_loads()}."
)
logger.info("Loaded certificate from local files!")
core.restart = True
if core.server:
core.server.close()
Expand Down
Loading

0 comments on commit cccefdc

Please sign in to comment.