Skip to content

Commit

Permalink
add zstd-stream compression to gateway
Browse files Browse the repository at this point in the history
  • Loading branch information
RuslanUC committed Nov 5, 2024
1 parent 29ebc3b commit 91d2999
Show file tree
Hide file tree
Showing 7 changed files with 142 additions and 12 deletions.
116 changes: 114 additions & 2 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ yc-protobuf3-to-dict = "^0.3.0"
s3lite = "^0.1.4"
fast-depends = ">=2.4.2"
faststream = {extras = ["kafka", "nats", "rabbit", "redis"], version = "^0.5.4"}
zstandard = "^0.23.0"

[tool.poetry.group.dev.dependencies]
pytest = "^8.2.0"
Expand Down
20 changes: 16 additions & 4 deletions yepcord/gateway/compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from zlib import compressobj, Z_FULL_FLUSH
import zlib

import zstandard


class WsCompressor(ABC):
Expand All @@ -38,11 +40,21 @@ class ZlibCompressor(WsCompressor):
__slots__ = ("_obj",)

def __init__(self):
self._obj = compressobj()
self._obj = zlib.compressobj()

def __call__(self, data: bytes) -> bytes:
return self._obj.compress(data) + self._obj.flush(zlib.Z_FULL_FLUSH)


class ZstdCompressor(WsCompressor):
__slots__ = ("_obj",)

def __init__(self):
self._obj = zstandard.ZstdCompressor().compressobj()

def __call__(self, data: bytes) -> bytes:
return self._obj.compress(data) + self._obj.flush(Z_FULL_FLUSH)
return self._obj.compress(data)# + self._obj.flush(zlib.Z_FULL_FLUSH)


WsCompressor.CLSs["zlib-stream"] = ZlibCompressor
# TODO: add zstd-stream
WsCompressor.CLSs["zstd-stream"] = ZstdCompressor
5 changes: 4 additions & 1 deletion yepcord/gateway/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,10 @@ async def json(self) -> dict:
},
"user_settings": settings.ds_json() if not self.user.is_bot else {},
"user_settings_proto": b64encode(proto.SerializeToString()).decode("utf8") if not self.user.is_bot
else None
else None,
"notification_settings": { # What?
"flags": 0,
}
}
}
if self.user.is_bot:
Expand Down
7 changes: 4 additions & 3 deletions yepcord/gateway/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,9 +401,10 @@ async def process(self, ws: Websocket, data: dict):
kwargs["new_client"] = client
client = _client[0]

func = getattr(client, f"handle_{GatewayOp.reversed()[op]}", None)
if func:
return await func(data.get("d"), **kwargs)
if op in GatewayOp.reversed():
func = getattr(client, f"handle_{GatewayOp.reversed()[op]}", None)
if func:
return await func(data.get("d"), **kwargs)

print("-" * 16)
print(f" Unknown op code: {op}")
Expand Down
1 change: 1 addition & 0 deletions yepcord/rest_api/routes/other.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ async def api_auth_locationmetadata():


@other.post("/api/v9/science")
@other.post("/api/v9/metrics/v2")
async def api_science():
return "", 204

Expand Down
4 changes: 2 additions & 2 deletions yepcord/yepcord/models/user_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ class UserSettings(Model):
mfa: str = fields.CharField(max_length=64, null=True, default=None)
render_spoilers: str = fields.CharField(max_length=16, default="ON_CLICK",
validators=[ChoicesValidator({"ALWAYS", "ON_CLICK", "IF_MODERATOR"})])
dismissed_contents: str = fields.CharField(max_length=64, default="510109000002000080")
dismissed_contents: str = fields.CharField(max_length=128, default="510109000002000080")
status: str = fields.CharField(max_length=32, default="online",
validators=[ChoicesValidator({"online", "idle", "dnd", "offline", "invisible"})])
custom_status: Optional[dict] = fields.JSONField(null=True, default=None)
Expand Down Expand Up @@ -286,7 +286,7 @@ async def update(self, new_proto: PreloadedUserSettings) -> None:
else:
changes["friend_source_flags"] = {"all": False, "mutual_friends": False, "mutual_guilds": False}
if (dismissed_contents := dict_get(proto_d, "user_content.dismissed_contents")) is not None:
changes["dismissed_contents"] = dismissed_contents[:128].hex()
changes["dismissed_contents"] = dismissed_contents[:64].hex()
if guild_folders := dict_get(proto_d, "guild_folders.folders"):
folders = []
for folder in guild_folders:
Expand Down

0 comments on commit 91d2999

Please sign in to comment.