Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support gzip & zstd compression #599

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 98 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "mosec"
version = "0.8.9"
version = "0.9.0"
authors = ["Keming <kemingy94@gmail.com>", "Zichen <lkevinzc@gmail.com>"]
edition = "2021"
license = "Apache-2.0"
Expand All @@ -25,3 +25,5 @@ serde = "1.0"
serde_json = "1.0"
utoipa = "5"
utoipa-swagger-ui = { version = "8", features = ["axum"] }
tower = "0.5.1"
tower-http = {version = "0.6.1", features = ["compression-zstd", "decompression-zstd", "compression-gzip", "decompression-gzip"]}
7 changes: 7 additions & 0 deletions mosec/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,13 @@ def build_arguments_parser() -> argparse.ArgumentParser:
"This will omit the worker number for each stage.",
action="store_true",
)

parser.add_argument(
"--compression",
help="Enable Zstd & Gzip compression for the request body",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Zstd & Gzip are enabled together? They are two algorithms as I understand?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. The compression layer will choose the algorithm according to the request headers.

I enabled both since gzip is widely used (included in Python std) and zstd is currently the best. Users can choose the one that better suits their use cases.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool! Can we add an example (similar to the test script) and a pointer in README to make this feature more visible?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm still unsure what the best use case for compression is in model serving. Numpy vectors to bytes can benefit a little. Images that are using compression (JPEG) should not be applied again. @aseaday do you have suggestions?

action="store_true",
)

return parser


Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ classifiers = [
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Programming Language :: Python :: Implementation :: CPython",
"Programming Language :: Rust",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
Expand Down
1 change: 1 addition & 0 deletions requirements/dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ ruff>=0.7
pre-commit>=2.15.0
httpx[http2]==0.27.2
httpx-sse==0.4.0
zstandard~=0.23
3 changes: 3 additions & 0 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ pub(crate) struct Config {
pub namespace: String,
// log level: (debug, info, warning, error)
pub log_level: String,
// Zstd & Gzip compression
pub compression: bool,
pub runtimes: Vec<Runtime>,
pub routes: Vec<Route>,
}
Expand All @@ -79,6 +81,7 @@ impl Default for Config {
port: 8000,
namespace: String::from("mosec_service"),
log_level: String::from("info"),
compression: false,
runtimes: vec![Runtime {
max_batch_size: 64,
max_wait_time: 3000,
Expand Down
15 changes: 14 additions & 1 deletion src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#![forbid(unsafe_code)]

mod apidoc;
mod config;
mod errors;
Expand All @@ -27,6 +29,9 @@ use std::net::SocketAddr;
use axum::routing::{get, post};
use axum::Router;
use tokio::signal::unix::{signal, SignalKind};
use tower::ServiceBuilder;
use tower_http::compression::CompressionLayer;
use tower_http::decompression::RequestDecompressionLayer;
use tracing::{debug, info};
use tracing_subscriber::fmt::time::UtcTime;
use tracing_subscriber::prelude::*;
Expand Down Expand Up @@ -90,12 +95,20 @@ async fn run(conf: &Config) {
}
}

if conf.compression {
router = router.layer(
ServiceBuilder::new()
.layer(RequestDecompressionLayer::new())
.layer(CompressionLayer::new()),
);
}

// wait until each stage has at least one worker alive
barrier.wait().await;
let addr: SocketAddr = format!("{}:{}", conf.address, conf.port).parse().unwrap();
let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
info!(?addr, "http service is running");
axum::serve(listener, router.into_make_service())
axum::serve(listener, router)
.with_graceful_shutdown(shutdown_signal())
.await
.unwrap();
Expand Down
4 changes: 2 additions & 2 deletions src/routes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ fn build_response(status: StatusCode, content: Bytes) -> Response<Body> {
),
),
)]
pub(crate) async fn index(_: Request<Body>) -> Response<Body> {
pub(crate) async fn index() -> Response<Body> {
let task_manager = TaskManager::global();
if task_manager.is_shutdown() {
build_response(
Expand All @@ -79,7 +79,7 @@ pub(crate) async fn index(_: Request<Body>) -> Response<Body> {
(status = StatusCode::OK, description = "Get metrics", body = String),
),
)]
pub(crate) async fn metrics(_: Request<Body>) -> Response<Body> {
pub(crate) async fn metrics() -> Response<Body> {
let mut encoded = String::new();
let registry = REGISTRY.get().unwrap();
encode(&mut encoded, registry).unwrap();
Expand Down
41 changes: 41 additions & 0 deletions tests/test_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

"""End-to-end service tests."""

import gzip
import json
import random
import re
import shlex
Expand All @@ -26,6 +28,7 @@
import msgpack # type: ignore
import pytest
from httpx_sse import connect_sse
from zstandard import ZstdCompressor

from mosec.server import GUARD_CHECK_INTERVAL
from tests.utils import wait_for_port_free, wait_for_port_open
Expand Down Expand Up @@ -349,3 +352,41 @@ def test_multi_route_service(mosec_service, http_client):
assert resp.status_code == HTTPStatus.OK, resp
assert resp.headers["content-type"] == "application/msgpack"
assert msgpack.unpackb(resp.content) == {"length": len(data)}


@pytest.mark.parametrize(
"mosec_service, http_client",
[
pytest.param("square_service --compression --debug", "", id="compression"),
],
indirect=["mosec_service", "http_client"],
)
def test_compression_service(mosec_service, http_client):
zstd_compressor = ZstdCompressor()
req = {"x": 2}
expect = {"x": 4}

# test without compression
resp = http_client.post("/inference", json=req)
assert resp.status_code == HTTPStatus.OK, resp
assert resp.json() == expect, resp.content

# test with gzip compression
binary = gzip.compress(json.dumps(req).encode())
resp = http_client.post(
"/inference",
content=binary,
headers={"Accept-Encoding": "gzip", "Content-Encoding": "gzip"},
)
assert resp.status_code == HTTPStatus.OK, resp
assert resp.json() == expect, resp.content

# test with zstd compression
binary = zstd_compressor.compress(json.dumps(req).encode())
resp = http_client.post(
"/inference",
content=binary,
headers={"Accept-Encoding": "zstd", "Content-Encoding": "zstd"},
)
assert resp.status_code == HTTPStatus.OK, resp
assert resp.json() == expect, resp.content