From f5e660be6646e559ee8f6fefc967933a20235bf4 Mon Sep 17 00:00:00 2001 From: Raphael Jolivet Date: Tue, 19 Apr 2022 11:19:29 +0200 Subject: [PATCH] Adding support for 'deflate' stream compression. --- flask_compress/flask_compress.py | 38 +++++++++++++++++++++----------- tests/test_flask_compress.py | 20 +++++++++++++---- 2 files changed, 41 insertions(+), 17 deletions(-) diff --git a/flask_compress/flask_compress.py b/flask_compress/flask_compress.py index 8f01287..91c23da 100644 --- a/flask_compress/flask_compress.py +++ b/flask_compress/flask_compress.py @@ -97,7 +97,7 @@ def init_app(self, app): app.config['COMPRESS_MIMETYPES']): app.after_request(self.after_request) - def _choose_compress_algorithm(self, accept_encoding_header): + def _choose_compress_algorithm(self, accept_encoding_header, is_streaming=False): """ Determine which compression algorithm we're going to use based on the client request. The `Accept-Encoding` header may list one or more desired @@ -116,7 +116,7 @@ def _choose_compress_algorithm(self, accept_encoding_header): algos_by_quality = defaultdict(set) # Set of supported algorithms - server_algos_set = set(self.enabled_algorithms) + server_algos_set = set(["deflate"]) if is_streaming else set(self.enabled_algorithms) for part in accept_encoding_header.lower().split(','): part = part.strip() @@ -172,30 +172,33 @@ def after_request(self, response): response.headers['Vary'] = '{}, Accept-Encoding'.format(vary) accept_encoding = request.headers.get('Accept-Encoding', '') - chosen_algorithm = self._choose_compress_algorithm(accept_encoding) + chosen_algorithm = self._choose_compress_algorithm(accept_encoding, response.is_streamed) if (chosen_algorithm is None or response.mimetype not in app.config["COMPRESS_MIMETYPES"] or response.status_code < 200 or response.status_code >= 300 or - response.is_streamed or "Content-Encoding" in response.headers or - (response.content_length is not None and + (not response.is_streamed and response.content_length is not None and response.content_length < app.config["COMPRESS_MIN_SIZE"])): return response response.direct_passthrough = False - if self.cache is not None: - key = self.cache_key(request) - compressed_content = self.cache.get(key) - if compressed_content is None: + if response.is_streamed and chosen_algorithm == "deflate" : + # Handling stream + response.response = self.stream_compress(response.response) + else : + if self.cache is not None: + key = self.cache_key(request) + compressed_content = self.cache.get(key) + if compressed_content is None: + compressed_content = self.compress(app, response, chosen_algorithm) + self.cache.set(key, compressed_content) + else: compressed_content = self.compress(app, response, chosen_algorithm) - self.cache.set(key, compressed_content) - else: - compressed_content = self.compress(app, response, chosen_algorithm) - response.set_data(compressed_content) + response.set_data(compressed_content) response.headers['Content-Encoding'] = chosen_algorithm response.headers['Content-Length'] = response.content_length @@ -236,3 +239,12 @@ def compress(self, app, response, algorithm): quality=app.config['COMPRESS_BR_LEVEL'], lgwin=app.config['COMPRESS_BR_WINDOW'], lgblock=app.config['COMPRESS_BR_BLOCK']) + + + def stream_compress(self, chunks) : + compressor = zlib.compressobj() + for data in chunks: + out = compressor.compress(data.encode()) + if out: + yield out + yield compressor.flush() \ No newline at end of file diff --git a/tests/test_flask_compress.py b/tests/test_flask_compress.py index 84a8d75..c765b06 100644 --- a/tests/test_flask_compress.py +++ b/tests/test_flask_compress.py @@ -4,7 +4,7 @@ from flask import Flask, render_template from flask_compress import Compress - +import zlib class DefaultsTest(unittest.TestCase): def setUp(self): @@ -358,6 +358,9 @@ def setUp(self): 'large.html') self.file_size = os.path.getsize(self.file_path) + with open(self.file_path) as f: + self.file_content = f.read() + Compress(self.app) @self.app.route('/stream/large') @@ -368,15 +371,24 @@ def _stream(): yield line return self.app.response_class(_stream(), mimetype='text/html') - def test_no_compression_stream(self): - """ Tests compression is skipped when response is streamed""" + def decompress(self, data): + return zlib.decompress(data).decode() + + def test_compression_stream(self): + """ Tests content is compressed if algorithm is 'deflate' """ client = self.app.test_client() for algorithm in ('gzip', 'deflate', 'br', ''): headers = [('Accept-Encoding', algorithm)] response = client.get('/stream/large', headers=headers) self.assertEqual(response.status_code, 200) self.assertEqual(response.is_streamed, True) - self.assertEqual(self.file_size, len(response.data)) + + if algorithm == 'deflate' : + self.assertEqual(response.headers["Content-Encoding"], 'deflate') + self.assertLess(len(response.data), self.file_size) + self.assertEqual(self.decompress(response.data), self.file_content) + else: + self.assertEqual(len(response.data), self.file_size) if __name__ == '__main__':