From 86693d8a66d5df9cf904e21e06b05f7630291733 Mon Sep 17 00:00:00 2001
From: lqs <lqs@lqs.me>
Date: Wed, 1 Mar 2023 12:17:06 +0900
Subject: [PATCH] refactor brotliwrapper

---
 brotliwrapper.go | 83 +++++++++++++++++++++++++++++++++++++-----------
 handler.go       |  5 +--
 2 files changed, 68 insertions(+), 20 deletions(-)

diff --git a/brotliwrapper.go b/brotliwrapper.go
index edfb54c..509ee9b 100644
--- a/brotliwrapper.go
+++ b/brotliwrapper.go
@@ -1,38 +1,85 @@
 package grpcmix
 
 import (
-	"errors"
-	"github.com/andybalholm/brotli"
+	"compress/gzip"
 	"io"
 	"net/http"
+	"strings"
+
+	"github.com/andybalholm/brotli"
+	"golang.org/x/net/http/httpguts"
 )
 
 type brotliWrapper struct {
 	http.ResponseWriter
-	request      *http.Request
-	brotliWriter io.WriteCloser
-	isClosed     bool
+	request *http.Request
+	writer  io.Writer
 }
 
-func (w *brotliWrapper) Write(data []byte) (int, error) {
-	if w.isClosed {
-		return 0, errors.New("brotliWrapper is closed")
+const minimumSizeToCompress = 512
+
+type CompressionType int
+
+const (
+	compressionTypeNone CompressionType = iota
+	compressionTypeBrotli
+	compressionTypeGzip
+)
+
+func (w *brotliWrapper) checkCompressionType(data []byte) CompressionType {
+	if !strings.HasPrefix(w.Header().Get("Content-Type"), "application/grpc-web") {
+		return compressionTypeNone
 	}
-	if w.brotliWriter == nil {
-		// create brotli writer on first write, because at this point the response headers are set but not yet sent
-		w.brotliWriter = brotli.HTTPCompressor(w.ResponseWriter, w.request)
+	if len(data) < 5 {
+		// not a grpc-web header
+		return compressionTypeNone
+	}
+	size := uint32(data[1])<<24 | uint32(data[2])<<16 | uint32(data[3])<<8 | uint32(data[4])
+	if size < minimumSizeToCompress {
+		return compressionTypeNone
+	}
+	acceptEncoding := w.request.Header.Values("Accept-Encoding")
+	switch {
+	case httpguts.HeaderValuesContainsToken(acceptEncoding, "br"):
+		return compressionTypeBrotli
+	case httpguts.HeaderValuesContainsToken(acceptEncoding, "gzip"):
+		return compressionTypeGzip
+	default:
+		return compressionTypeNone
 	}
-	return w.brotliWriter.Write(data)
 }
 
-func (w *brotliWrapper) Close() {
-	if w.isClosed {
-		return
+func (w *brotliWrapper) Write(data []byte) (int, error) {
+	if w.writer == nil {
+		compressionType := w.checkCompressionType(data)
+		switch compressionType {
+		case compressionTypeBrotli:
+			w.writer = brotli.NewWriterOptions(w.ResponseWriter, brotli.WriterOptions{
+				Quality: brotli.DefaultCompression,
+				LGWin:   16,
+			})
+			w.Header().Set("Content-Encoding", "br")
+		case compressionTypeGzip:
+			var err error
+			w.writer, err = gzip.NewWriterLevel(w.ResponseWriter, gzip.DefaultCompression)
+			if err != nil {
+				return 0, err
+			}
+			w.Header().Set("Content-Encoding", "gzip")
+		default:
+			w.writer = w.ResponseWriter
+		}
+		if !httpguts.HeaderValuesContainsToken(w.Header().Values("Vary"), "Accept-Encoding") {
+			w.Header().Add("Vary", "Accept-Encoding")
+		}
 	}
-	if w.brotliWriter != nil {
-		_ = w.brotliWriter.Close()
+	return w.writer.Write(data)
+}
+
+func (w *brotliWrapper) Close() {
+	if closer, ok := w.writer.(io.Closer); ok {
+		_ = closer.Close()
 	}
-	w.isClosed = true
 }
 
 func wrapBrotli(writer http.ResponseWriter, request *http.Request) *brotliWrapper {
diff --git a/handler.go b/handler.go
index 5df0ab1..4797652 100644
--- a/handler.go
+++ b/handler.go
@@ -1,12 +1,13 @@
 package grpcmix
 
 import (
+	"net/http"
+	"strings"
+
 	"github.com/improbable-eng/grpc-web/go/grpcweb"
 	"golang.org/x/net/http2"
 	"golang.org/x/net/http2/h2c"
 	"google.golang.org/grpc"
-	"net/http"
-	"strings"
 )
 
 type mixHandler struct {