Skip to content

Commit

Permalink
Merge pull request #238 from yepcord/add-support-for-zstd-ws-compressor
Browse files Browse the repository at this point in the history
Add support for zstd ws compressor
  • Loading branch information
RuslanUC authored Nov 7, 2024
2 parents 7dfe87e + f4e7def commit a7b39f7
Show file tree
Hide file tree
Showing 10 changed files with 364 additions and 168 deletions.
6 changes: 3 additions & 3 deletions config.example.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,10 @@
},
}

# Redis used for users presences and statuses. If empty, presences will be stored in memory.
# Redis used for users' presences and statuses. If empty, presences will be stored in memory.
REDIS_URL = ""

# How often gateway clients must send keep-alive packets (also, presences expiration time is this variable times 1.25).
# How often gateway clients must send keep-alive packets (also, presence expiration time is this variable times 1.25).
# Default value is 45 seconds, do not set it too big or too small.
GATEWAY_KEEP_ALIVE_DELAY = 45

Expand All @@ -89,7 +89,7 @@

# Settings for external application connections
# For every application, use https://PUBLIC_HOST/connections/SERVICE_NAME/callback as redirect (callback) url,
# for example, if you need to create GitHub app and your yepcord instance (frontend) is running on 127.0.0.1:8888,
# for example, if you need to create GitHub app and your yepcord instance (front-end) is running on 127.0.0.1:8888,
# redirect url will be https://127.0.0.1:8888/connections/github/callback
CONNECTIONS = {
"github": {
Expand Down
414 changes: 269 additions & 145 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ yc-protobuf3-to-dict = "^0.3.0"
s3lite = { version = "^0.1.8", optional = true }
fast-depends = "^2.4.12"
faststream = {extras = ["kafka", "nats", "rabbit", "redis"], version = "^0.5.28"}
zstandard = "^0.23.0"

[tool.poetry.group.dev.dependencies]
pytest = "^8.2.0"
Expand All @@ -84,6 +85,7 @@ ftp = ["aioftp"]
[tool.poetry.group.profiling.dependencies]
viztracer = "^0.16.3"
pyinstrument = "^5.0.0"
tuna = "^0.5.11"

[build-system]
requires = ["poetry-core"]
Expand Down
60 changes: 60 additions & 0 deletions yepcord/gateway/compression.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
"""
YEPCord: Free open source selfhostable fully discord-compatible chat
Copyright (C) 2022-2024 RuslanUC
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published
by the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""

from __future__ import annotations

from abc import ABC, abstractmethod
import zlib

import zstandard


class WsCompressor(ABC):
CLSs = {}

@abstractmethod
def __call__(self, data: bytes) -> bytes: ...

@classmethod
def create_compressor(cls, name: str) -> WsCompressor | None:
if name in cls.CLSs:
return cls.CLSs[name]()


class ZlibCompressor(WsCompressor):
__slots__ = ("_obj",)

def __init__(self):
self._obj = zlib.compressobj()

def __call__(self, data: bytes) -> bytes:
return self._obj.compress(data) + self._obj.flush(zlib.Z_FULL_FLUSH)


class ZstdCompressor(WsCompressor):
__slots__ = ("_obj",)

def __init__(self):
self._obj = zstandard.ZstdCompressor().compressobj()

def __call__(self, data: bytes) -> bytes:
return self._obj.compress(data)# + self._obj.flush(zlib.Z_FULL_FLUSH)


WsCompressor.CLSs["zlib-stream"] = ZlibCompressor
WsCompressor.CLSs["zstd-stream"] = ZstdCompressor
5 changes: 4 additions & 1 deletion yepcord/gateway/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,10 @@ async def json(self) -> dict:
},
"user_settings": settings.ds_json() if not self.user.is_bot else {},
"user_settings_proto": b64encode(proto.SerializeToString()).decode("utf8") if not self.user.is_bot
else None
else None,
"notification_settings": { # What?
"flags": 0,
}
}
}
if self.user.is_bot:
Expand Down
27 changes: 17 additions & 10 deletions yepcord/gateway/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from redis.asyncio import Redis
from tortoise.expressions import Q

from .compression import WsCompressor
from .events import *
from .presences import Presences, Presence
from .utils import require_auth, get_token_type, TokenType, init_redis_pool
Expand All @@ -36,14 +37,18 @@


class GatewayClient:
def __init__(self, ws, gateway: Gateway):
__slots__ = (
"ws", "gateway", "seq", "sid", "id", "user_id", "is_bot", "_connected", "_compressor", "cached_presence",
)

def __init__(self, ws: Websocket, gateway: Gateway):
self.ws = ws
self.gateway = gateway
self.seq = 0
self.sid = hex(Snowflake.makeId())[2:]
self._connected = True

self.z = getattr(ws, "zlib", None)
self._compressor: WsCompressor = getattr(ws, "compressor", None)
self.id = self.user_id = None
self.is_bot = False
self.cached_presence: Optional[Presence] = None
Expand All @@ -59,17 +64,18 @@ def disconnect(self) -> None:
async def send(self, data: dict):
self.seq += 1
data["s"] = self.seq
if self.z:
if self._compressor:
return await self.ws.send(self.compress(data))
await self.ws.send_json(data)
if self.ws is not None:
await self.ws.send_json(data)

async def esend(self, event):
if not self.connected:
return
await self.send(await event.json())

def compress(self, json: dict):
return self.z(jdumps(json).encode("utf8"))
return self._compressor(jdumps(json).encode("utf8"))

async def handle_IDENTIFY(self, data: dict) -> None:
if self.user_id is not None:
Expand Down Expand Up @@ -108,7 +114,7 @@ async def handle_RESUME(self, data: dict, new_client: GatewayClient) -> None:
if (session := await S.from_token(token)) is None or self.user_id.id != session.user.id:
return await self.ws.close(4004)

self.z = new_client.z
self._compressor = new_client._compressor
self.ws = new_client.ws
setattr(self.ws, "_yepcord_client", self)

Expand Down Expand Up @@ -359,7 +365,7 @@ async def send(self, client: GatewayClient, op: int, **data) -> None:
async def sendws(self, ws, op: int, **data) -> None:
r = {"op": op}
r.update(data)
if getattr(ws, "zlib", None):
if getattr(ws, "compressor", None):
return await ws.send(ws.zlib(jdumps(r).encode("utf8")))
await ws.send_json(r)

Expand Down Expand Up @@ -401,9 +407,10 @@ async def process(self, ws: Websocket, data: dict):
kwargs["new_client"] = client
client = _client[0]

func = getattr(client, f"handle_{GatewayOp.reversed()[op]}", None)
if func:
return await func(data.get("d"), **kwargs)
if op in GatewayOp.reversed():
func = getattr(client, f"handle_{GatewayOp.reversed()[op]}", None)
if func:
return await func(data.get("d"), **kwargs)

print("-" * 16)
print(f" Unknown op code: {op}")
Expand Down
11 changes: 5 additions & 6 deletions yepcord/gateway/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,13 @@
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
from asyncio import CancelledError, shield, create_task

from quart import Quart, websocket, Websocket
from tortoise.contrib.quart import register_tortoise

from .compression import WsCompressor
from ..yepcord.config import Config
from .utils import ZlibCompressor
from json import loads as jloads
from asyncio import CancelledError
from .gateway import Gateway


Expand Down Expand Up @@ -58,12 +57,12 @@ async def set_cors_headers(response):
async def ws_gateway():
# noinspection PyProtectedMember,PyUnresolvedReferences
ws: Websocket = websocket._get_current_object()
setattr(ws, "zlib", ZlibCompressor() if websocket.args.get("compress") == "zlib-stream" else None)
setattr(ws, "compressor", WsCompressor.create_compressor(websocket.args.get("compress")))
await gw.add_client(ws)
while True:
try:
data = await ws.receive()
await gw.process(ws, jloads(data))
data = await ws.receive_json()
await shield(create_task(gw.process(ws, data)))
except CancelledError:
await gw.disconnect(ws)
raise
Expand Down
1 change: 1 addition & 0 deletions yepcord/rest_api/routes/other.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ async def api_auth_locationmetadata():


@other.post("/api/v9/science")
@other.post("/api/v9/metrics/v2")
async def api_science():
return "", 204

Expand Down
2 changes: 1 addition & 1 deletion yepcord/rest_api/routes/users_me.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ async def update_protobuf_settings(data: SettingsProtoUpdate, user: User = DepUs
@users_me.get("/settings-proto/2")
async def get_protobuf_frecency_settings(user: User = DepUser):
proto = await FrecencySettings.get_or_none(id=user.id)
return {"settings": proto if proto is not None else ""}
return {"settings": proto.settings if proto is not None else ""}


@users_me.patch("/settings-proto/2", body_cls=SettingsProtoUpdate)
Expand Down
4 changes: 2 additions & 2 deletions yepcord/yepcord/models/user_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ class UserSettings(Model):
mfa: str = fields.CharField(max_length=64, null=True, default=None)
render_spoilers: str = fields.CharField(max_length=16, default="ON_CLICK",
validators=[ChoicesValidator({"ALWAYS", "ON_CLICK", "IF_MODERATOR"})])
dismissed_contents: str = fields.CharField(max_length=64, default="510109000002000080")
dismissed_contents: str = fields.CharField(max_length=128, default="510109000002000080")
status: str = fields.CharField(max_length=32, default="online",
validators=[ChoicesValidator({"online", "idle", "dnd", "offline", "invisible"})])
custom_status: Optional[dict] = fields.JSONField(null=True, default=None)
Expand Down Expand Up @@ -286,7 +286,7 @@ async def update(self, new_proto: PreloadedUserSettings) -> None:
else:
changes["friend_source_flags"] = {"all": False, "mutual_friends": False, "mutual_guilds": False}
if (dismissed_contents := dict_get(proto_d, "user_content.dismissed_contents")) is not None:
changes["dismissed_contents"] = dismissed_contents[:128].hex()
changes["dismissed_contents"] = dismissed_contents[:64].hex()
if guild_folders := dict_get(proto_d, "guild_folders.folders"):
folders = []
for folder in guild_folders:
Expand Down

0 comments on commit a7b39f7

Please sign in to comment.