Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for zstd ws compressor #238

Merged
merged 5 commits into from
Nov 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions config.example.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,10 @@
},
}

# Redis used for users presences and statuses. If empty, presences will be stored in memory.
# Redis used for users' presences and statuses. If empty, presences will be stored in memory.
REDIS_URL = ""

# How often gateway clients must send keep-alive packets (also, presences expiration time is this variable times 1.25).
# How often gateway clients must send keep-alive packets (also, presence expiration time is this variable times 1.25).
# Default value is 45 seconds, do not set it too big or too small.
GATEWAY_KEEP_ALIVE_DELAY = 45

Expand All @@ -89,7 +89,7 @@

# Settings for external application connections
# For every application, use https://PUBLIC_HOST/connections/SERVICE_NAME/callback as redirect (callback) url,
# for example, if you need to create GitHub app and your yepcord instance (frontend) is running on 127.0.0.1:8888,
# for example, if you need to create GitHub app and your yepcord instance (front-end) is running on 127.0.0.1:8888,
# redirect url will be https://127.0.0.1:8888/connections/github/callback
CONNECTIONS = {
"github": {
Expand Down
414 changes: 269 additions & 145 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ yc-protobuf3-to-dict = "^0.3.0"
s3lite = { version = "^0.1.8", optional = true }
fast-depends = "^2.4.12"
faststream = {extras = ["kafka", "nats", "rabbit", "redis"], version = "^0.5.28"}
zstandard = "^0.23.0"

[tool.poetry.group.dev.dependencies]
pytest = "^8.2.0"
Expand All @@ -84,6 +85,7 @@ ftp = ["aioftp"]
[tool.poetry.group.profiling.dependencies]
viztracer = "^0.16.3"
pyinstrument = "^5.0.0"
tuna = "^0.5.11"

[build-system]
requires = ["poetry-core"]
Expand Down
60 changes: 60 additions & 0 deletions yepcord/gateway/compression.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
"""
YEPCord: Free open source selfhostable fully discord-compatible chat
Copyright (C) 2022-2024 RuslanUC

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published
by the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.

You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""

from __future__ import annotations

from abc import ABC, abstractmethod
import zlib

import zstandard


class WsCompressor(ABC):
CLSs = {}

@abstractmethod
def __call__(self, data: bytes) -> bytes: ...

@classmethod
def create_compressor(cls, name: str) -> WsCompressor | None:
if name in cls.CLSs:
return cls.CLSs[name]()


class ZlibCompressor(WsCompressor):
__slots__ = ("_obj",)

def __init__(self):
self._obj = zlib.compressobj()

def __call__(self, data: bytes) -> bytes:
return self._obj.compress(data) + self._obj.flush(zlib.Z_FULL_FLUSH)


class ZstdCompressor(WsCompressor):
__slots__ = ("_obj",)

def __init__(self):
self._obj = zstandard.ZstdCompressor().compressobj()

def __call__(self, data: bytes) -> bytes:
return self._obj.compress(data)# + self._obj.flush(zlib.Z_FULL_FLUSH)


WsCompressor.CLSs["zlib-stream"] = ZlibCompressor
WsCompressor.CLSs["zstd-stream"] = ZstdCompressor
5 changes: 4 additions & 1 deletion yepcord/gateway/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,10 @@ async def json(self) -> dict:
},
"user_settings": settings.ds_json() if not self.user.is_bot else {},
"user_settings_proto": b64encode(proto.SerializeToString()).decode("utf8") if not self.user.is_bot
else None
else None,
"notification_settings": { # What?
"flags": 0,
}
}
}
if self.user.is_bot:
Expand Down
27 changes: 17 additions & 10 deletions yepcord/gateway/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from redis.asyncio import Redis
from tortoise.expressions import Q

from .compression import WsCompressor
from .events import *
from .presences import Presences, Presence
from .utils import require_auth, get_token_type, TokenType, init_redis_pool
Expand All @@ -36,14 +37,18 @@


class GatewayClient:
def __init__(self, ws, gateway: Gateway):
__slots__ = (
"ws", "gateway", "seq", "sid", "id", "user_id", "is_bot", "_connected", "_compressor", "cached_presence",
)

def __init__(self, ws: Websocket, gateway: Gateway):
self.ws = ws
self.gateway = gateway
self.seq = 0
self.sid = hex(Snowflake.makeId())[2:]
self._connected = True

self.z = getattr(ws, "zlib", None)
self._compressor: WsCompressor = getattr(ws, "compressor", None)
self.id = self.user_id = None
self.is_bot = False
self.cached_presence: Optional[Presence] = None
Expand All @@ -59,17 +64,18 @@ def disconnect(self) -> None:
async def send(self, data: dict):
self.seq += 1
data["s"] = self.seq
if self.z:
if self._compressor:
return await self.ws.send(self.compress(data))
await self.ws.send_json(data)
if self.ws is not None:
await self.ws.send_json(data)

async def esend(self, event):
if not self.connected:
return
await self.send(await event.json())

def compress(self, json: dict):
return self.z(jdumps(json).encode("utf8"))
return self._compressor(jdumps(json).encode("utf8"))

async def handle_IDENTIFY(self, data: dict) -> None:
if self.user_id is not None:
Expand Down Expand Up @@ -108,7 +114,7 @@ async def handle_RESUME(self, data: dict, new_client: GatewayClient) -> None:
if (session := await S.from_token(token)) is None or self.user_id.id != session.user.id:
return await self.ws.close(4004)

self.z = new_client.z
self._compressor = new_client._compressor
self.ws = new_client.ws
setattr(self.ws, "_yepcord_client", self)

Expand Down Expand Up @@ -359,7 +365,7 @@ async def send(self, client: GatewayClient, op: int, **data) -> None:
async def sendws(self, ws, op: int, **data) -> None:
r = {"op": op}
r.update(data)
if getattr(ws, "zlib", None):
if getattr(ws, "compressor", None):
return await ws.send(ws.zlib(jdumps(r).encode("utf8")))
await ws.send_json(r)

Expand Down Expand Up @@ -401,9 +407,10 @@ async def process(self, ws: Websocket, data: dict):
kwargs["new_client"] = client
client = _client[0]

func = getattr(client, f"handle_{GatewayOp.reversed()[op]}", None)
if func:
return await func(data.get("d"), **kwargs)
if op in GatewayOp.reversed():
func = getattr(client, f"handle_{GatewayOp.reversed()[op]}", None)
if func:
return await func(data.get("d"), **kwargs)

print("-" * 16)
print(f" Unknown op code: {op}")
Expand Down
11 changes: 5 additions & 6 deletions yepcord/gateway/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,13 @@
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
from asyncio import CancelledError, shield, create_task

from quart import Quart, websocket, Websocket
from tortoise.contrib.quart import register_tortoise

from .compression import WsCompressor
from ..yepcord.config import Config
from .utils import ZlibCompressor
from json import loads as jloads
from asyncio import CancelledError
from .gateway import Gateway


Expand Down Expand Up @@ -58,12 +57,12 @@ async def set_cors_headers(response):
async def ws_gateway():
# noinspection PyProtectedMember,PyUnresolvedReferences
ws: Websocket = websocket._get_current_object()
setattr(ws, "zlib", ZlibCompressor() if websocket.args.get("compress") == "zlib-stream" else None)
setattr(ws, "compressor", WsCompressor.create_compressor(websocket.args.get("compress")))
await gw.add_client(ws)
while True:
try:
data = await ws.receive()
await gw.process(ws, jloads(data))
data = await ws.receive_json()
await shield(create_task(gw.process(ws, data)))
except CancelledError:
await gw.disconnect(ws)
raise
Expand Down
1 change: 1 addition & 0 deletions yepcord/rest_api/routes/other.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ async def api_auth_locationmetadata():


@other.post("/api/v9/science")
@other.post("/api/v9/metrics/v2")
async def api_science():
return "", 204

Expand Down
2 changes: 1 addition & 1 deletion yepcord/rest_api/routes/users_me.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ async def update_protobuf_settings(data: SettingsProtoUpdate, user: User = DepUs
@users_me.get("/settings-proto/2")
async def get_protobuf_frecency_settings(user: User = DepUser):
proto = await FrecencySettings.get_or_none(id=user.id)
return {"settings": proto if proto is not None else ""}
return {"settings": proto.settings if proto is not None else ""}


@users_me.patch("/settings-proto/2", body_cls=SettingsProtoUpdate)
Expand Down
4 changes: 2 additions & 2 deletions yepcord/yepcord/models/user_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ class UserSettings(Model):
mfa: str = fields.CharField(max_length=64, null=True, default=None)
render_spoilers: str = fields.CharField(max_length=16, default="ON_CLICK",
validators=[ChoicesValidator({"ALWAYS", "ON_CLICK", "IF_MODERATOR"})])
dismissed_contents: str = fields.CharField(max_length=64, default="510109000002000080")
dismissed_contents: str = fields.CharField(max_length=128, default="510109000002000080")
status: str = fields.CharField(max_length=32, default="online",
validators=[ChoicesValidator({"online", "idle", "dnd", "offline", "invisible"})])
custom_status: Optional[dict] = fields.JSONField(null=True, default=None)
Expand Down Expand Up @@ -286,7 +286,7 @@ async def update(self, new_proto: PreloadedUserSettings) -> None:
else:
changes["friend_source_flags"] = {"all": False, "mutual_friends": False, "mutual_guilds": False}
if (dismissed_contents := dict_get(proto_d, "user_content.dismissed_contents")) is not None:
changes["dismissed_contents"] = dismissed_contents[:128].hex()
changes["dismissed_contents"] = dismissed_contents[:64].hex()
if guild_folders := dict_get(proto_d, "guild_folders.folders"):
folders = []
for folder in guild_folders:
Expand Down
Loading