From 9d9de6fa9c4f90316ea0bdf6cc80848c79537b2c Mon Sep 17 00:00:00 2001 From: Liujian <824010343@qq.com> Date: Tue, 10 Dec 2024 19:13:24 +0800 Subject: [PATCH] openai add br encoding --- drivers/ai-provider/openAI/encoding.go | 49 ++++++++++++++++++++++++++ drivers/ai-provider/openAI/mode.go | 8 +++++ go.mod | 2 +- 3 files changed, 58 insertions(+), 1 deletion(-) create mode 100644 drivers/ai-provider/openAI/encoding.go diff --git a/drivers/ai-provider/openAI/encoding.go b/drivers/ai-provider/openAI/encoding.go new file mode 100644 index 00000000..53761ef4 --- /dev/null +++ b/drivers/ai-provider/openAI/encoding.go @@ -0,0 +1,49 @@ +package openAI + +import ( + "bytes" + "fmt" + "io" + + "github.com/eolinker/eosc" + + "github.com/andybalholm/brotli" +) + +type IEncoder interface { + ToUTF8([]byte) ([]byte, error) +} + +type EncoderManger struct { + encoders eosc.Untyped[string, IEncoder] +} + +func NewEncoderManger() *EncoderManger { + return &EncoderManger{encoders: eosc.BuildUntyped[string, IEncoder]()} +} + +func (e *EncoderManger) Set(name string, encoder IEncoder) { + e.encoders.Set(name, encoder) +} + +func (e *EncoderManger) ToUTF8(name string, data []byte) ([]byte, error) { + encoder, ok := e.encoders.Get(name) + if !ok { + return nil, fmt.Errorf("encoder %s not found", name) + } + return encoder.ToUTF8(data) +} + +var encoderManger = NewEncoderManger() + +func init() { + encoderManger.Set("br", &Br{}) +} + +type Br struct { +} + +func (b *Br) ToUTF8(data []byte) ([]byte, error) { + reader := brotli.NewReader(bytes.NewReader(data)) + return io.ReadAll(reader) +} diff --git a/drivers/ai-provider/openAI/mode.go b/drivers/ai-provider/openAI/mode.go index 6df0fe4c..4f6ea1f1 100644 --- a/drivers/ai-provider/openAI/mode.go +++ b/drivers/ai-provider/openAI/mode.go @@ -81,6 +81,13 @@ func (c *Chat) ResponseConvert(ctx eocontext.EoContext) error { return nil } body := httpContext.Response().GetBody() + encoding := httpContext.Response().Headers().Get("content-encoding") + if encoding != "utf-8" && encoding != "" { + body, err = encoderManger.ToUTF8(encoding, body) + if err != nil { + return err + } + } data := eosc.NewBase[Response]() err = json.Unmarshal(body, data) if err != nil { @@ -102,6 +109,7 @@ func (c *Chat) ResponseConvert(ctx eocontext.EoContext) error { if err != nil { return err } + httpContext.Response().SetHeader("content-encoding", "utf-8") httpContext.Response().SetBody(body) return nil } diff --git a/go.mod b/go.mod index a1d66f52..92ec89b5 100644 --- a/go.mod +++ b/go.mod @@ -112,7 +112,7 @@ require ( require ( dubbo.apache.org/dubbo-go/v3 v3.0.2-0.20220519062747-f6405fa79d5c - github.com/andybalholm/brotli v1.0.5 // indirect + github.com/andybalholm/brotli v1.0.5 github.com/apache/dubbo-go-hessian2 v1.11.6 github.com/armon/go-metrics v0.3.9 // indirect github.com/beorn7/perks v1.0.1 // indirect