From 0e99bc39b1281f88503cee5e499f8e5c2026469d Mon Sep 17 00:00:00 2001 From: Marina Sakai <118230951+Marina-Sakai@users.noreply.github.com> Date: Sat, 16 Nov 2024 00:28:34 +0800 Subject: [PATCH] feat(thrift): unwrap struct for streaming type descriptor (#81) * feat: unwrap struct for streaming type descriptor * fix: only add IsWithoutWrapping method * fix: not need streaming mode --- go.sum | 2 - thrift/descriptor.go | 20 +++-- thrift/idl.go | 183 ++++++++++++++++++++++++++----------------- thrift/idl_test.go | 35 +++++++++ 4 files changed, 157 insertions(+), 83 deletions(-) diff --git a/go.sum b/go.sum index 9b27430d..8edb759d 100644 --- a/go.sum +++ b/go.sum @@ -186,8 +186,6 @@ google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp0 google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI= google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= -google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg= -google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/thrift/descriptor.go b/thrift/descriptor.go index 11cd4554..4cc4dba9 100644 --- a/thrift/descriptor.go +++ b/thrift/descriptor.go @@ -331,13 +331,14 @@ func (f FieldDescriptor) DefaultValue() *DefaultValue { // FunctionDescriptor idl function descriptor type FunctionDescriptor struct { - oneway bool - hasRequestBase bool - request *TypeDescriptor - response *TypeDescriptor - name string - endpoints []http.Endpoint - annotations []parser.Annotation + oneway bool + hasRequestBase bool + request *TypeDescriptor + response *TypeDescriptor + name string + endpoints []http.Endpoint + annotations []parser.Annotation + isWithoutWrapping bool } // Name returns the name of the function @@ -377,6 +378,11 @@ func (f FunctionDescriptor) Annotations() []parser.Annotation { return f.annotations } +// IsWithoutWrapping returns if the request and response are not wrapped in struct +func (f FunctionDescriptor) IsWithoutWrapping() bool { + return f.isWithoutWrapping +} + // ServiceDescriptor is the runtime descriptor of a service type ServiceDescriptor struct { name string diff --git a/thrift/idl.go b/thrift/idl.go index 4183ca80..62885b4f 100644 --- a/thrift/idl.go +++ b/thrift/idl.go @@ -27,6 +27,8 @@ import ( "time" "unsafe" + "github.com/cloudwego/thriftgo/generator/golang/streaming" + "github.com/cloudwego/dynamicgo/http" "github.com/cloudwego/dynamicgo/internal/json" "github.com/cloudwego/dynamicgo/internal/rt" @@ -371,104 +373,137 @@ func addFunction(ctx context.Context, fn *parser.Function, tree *parser.Thrift, } + st, err := streaming.ParseStreaming(fn) + if err != nil { + return err + } + isStreaming := st.ClientStreaming || st.ServerStreaming + var hasRequestBase bool var req *TypeDescriptor var resp *TypeDescriptor // parse request field if opts.ParseFunctionMode != meta.ParseResponseOnly { - // WARN: only support single argument - reqAst := fn.Arguments[0] - req = &TypeDescriptor{ - typ: STRUCT, - struc: &StructDescriptor{ - baseID: FieldID(math.MaxUint16), - ids: util.FieldIDMap{}, - names: util.FieldNameMap{}, - requires: make(RequiresBitmap, 1), - }, - } - - reqType, err := parseType(ctx, reqAst.Type, tree, structsCache, 0, opts, nextAnns, Request) + req, hasRequestBase, err = parseRequest(ctx, isStreaming, fn, tree, structsCache, nextAnns, opts) if err != nil { return err } - if reqType.Type() == STRUCT { - for _, f := range reqType.Struct().names.All() { - x := (*FieldDescriptor)(f.Val) - if x.isRequestBase { - hasRequestBase = true - break - } - } - } - reqField := &FieldDescriptor{ - name: reqAst.Name, - id: FieldID(reqAst.ID), - typ: reqType, - } - req.Struct().ids.Set(int32(reqAst.ID), unsafe.Pointer(reqField)) - req.Struct().names.Set(reqAst.Name, unsafe.Pointer(reqField)) - req.Struct().names.Build() } // parse response filed if opts.ParseFunctionMode != meta.ParseRequestOnly { - respAst := fn.FunctionType - resp = &TypeDescriptor{ - typ: STRUCT, - struc: &StructDescriptor{ - baseID: FieldID(math.MaxUint16), - ids: util.FieldIDMap{}, - names: util.FieldNameMap{}, - requires: make(RequiresBitmap, 1), - }, - } - respType, err := parseType(ctx, respAst, tree, structsCache, 0, opts, nextAnns, Response) + resp, err = parseResponse(ctx, isStreaming, fn, tree, structsCache, nextAnns, opts) if err != nil { return err } - respField := &FieldDescriptor{ - typ: respType, - } - resp.Struct().ids.Set(0, unsafe.Pointer(respField)) - // response has no name or id - resp.Struct().names.Set("", unsafe.Pointer(respField)) - - // parse exceptions - if len(fn.Throws) > 0 { - // only support single exception - exp := fn.Throws[0] - exceptionType, err := parseType(ctx, exp.Type, tree, structsCache, 0, opts, nextAnns, Exception) - if err != nil { - return err - } - exceptionField := &FieldDescriptor{ - name: exp.Name, - alias: exp.Name, - id: FieldID(exp.ID), - // isException: true, - typ: exceptionType, - } - resp.Struct().ids.Set(int32(exp.ID), unsafe.Pointer(exceptionField)) - resp.Struct().names.Set(exp.Name, unsafe.Pointer(exceptionField)) - } - resp.Struct().names.Build() } fnDsc := &FunctionDescriptor{ - name: fn.Name, - oneway: fn.Oneway, - request: req, - response: resp, - hasRequestBase: hasRequestBase, - endpoints: enpdoints, - annotations: annos, + name: fn.Name, + oneway: fn.Oneway, + request: req, + response: resp, + hasRequestBase: hasRequestBase, + endpoints: enpdoints, + annotations: annos, + isWithoutWrapping: isStreaming, } sDsc.functions[fn.Name] = fnDsc return nil } +func parseRequest(ctx context.Context, isStreaming bool, fn *parser.Function, tree *parser.Thrift, structsCache compilingCache, nextAnns []parser.Annotation, opts Options) (req *TypeDescriptor, hasRequestBase bool, err error) { + // WARN: only support single argument + reqAst := fn.Arguments[0] + reqType, err := parseType(ctx, reqAst.Type, tree, structsCache, 0, opts, nextAnns, Request) + if err != nil { + return nil, hasRequestBase, err + } + if reqType.Type() == STRUCT { + for _, f := range reqType.Struct().names.All() { + x := (*FieldDescriptor)(f.Val) + if x.isRequestBase { + hasRequestBase = true + break + } + } + } + + if isStreaming { + return reqType, hasRequestBase, nil + } + + // wrap with a struct + wrappedTyDsc := &TypeDescriptor{ + typ: STRUCT, + struc: &StructDescriptor{ + baseID: FieldID(math.MaxUint16), + ids: util.FieldIDMap{}, + names: util.FieldNameMap{}, + requires: make(RequiresBitmap, 1), + }, + } + reqField := &FieldDescriptor{ + name: reqAst.Name, + id: FieldID(reqAst.ID), + typ: reqType, + } + wrappedTyDsc.Struct().ids.Set(int32(reqAst.ID), unsafe.Pointer(reqField)) + wrappedTyDsc.Struct().names.Set(reqAst.Name, unsafe.Pointer(reqField)) + wrappedTyDsc.Struct().names.Build() + return wrappedTyDsc, hasRequestBase, nil +} + +func parseResponse(ctx context.Context, isStreaming bool, fn *parser.Function, tree *parser.Thrift, structsCache compilingCache, nextAnns []parser.Annotation, opts Options) (resp *TypeDescriptor, err error) { + respAst := fn.FunctionType + respType, err := parseType(ctx, respAst, tree, structsCache, 0, opts, nextAnns, Response) + if err != nil { + return nil, err + } + + if isStreaming { + return respType, nil + } + + wrappedResp := &TypeDescriptor{ + typ: STRUCT, + struc: &StructDescriptor{ + baseID: FieldID(math.MaxUint16), + ids: util.FieldIDMap{}, + names: util.FieldNameMap{}, + requires: make(RequiresBitmap, 1), + }, + } + respField := &FieldDescriptor{ + typ: respType, + } + wrappedResp.Struct().ids.Set(0, unsafe.Pointer(respField)) + // response has no name or id + wrappedResp.Struct().names.Set("", unsafe.Pointer(respField)) + + // parse exceptions + if len(fn.Throws) > 0 { + // only support single exception + exp := fn.Throws[0] + exceptionType, err := parseType(ctx, exp.Type, tree, structsCache, 0, opts, nextAnns, Exception) + if err != nil { + return nil, err + } + exceptionField := &FieldDescriptor{ + name: exp.Name, + alias: exp.Name, + id: FieldID(exp.ID), + // isException: true, + typ: exceptionType, + } + wrappedResp.Struct().ids.Set(int32(exp.ID), unsafe.Pointer(exceptionField)) + wrappedResp.Struct().names.Set(exp.Name, unsafe.Pointer(exceptionField)) + } + wrappedResp.Struct().names.Build() + return wrappedResp, nil +} + // reuse builtin types var builtinTypes = map[string]*TypeDescriptor{ "void": {name: "void", typ: VOID, struc: new(StructDescriptor)}, diff --git a/thrift/idl_test.go b/thrift/idl_test.go index 2d9d90ac..0cc1cf76 100644 --- a/thrift/idl_test.go +++ b/thrift/idl_test.go @@ -402,3 +402,38 @@ func TestNewFunctionDescriptorFromPath(t *testing.T) { require.NotNil(t, p.Functions()["ExampleMethod"]) require.Nil(t, p.Functions()["Ping"]) } + +func TestStreamingFunctionDescriptorFromContent(t *testing.T) { + path := "a/b/main.thrift" + content := ` + namespace go thrift + + struct Request { + 1: required string message, + } + + struct Response { + 1: required string message, + } + + service TestService { + Response Echo (1: Request req) (streaming.mode="bidirectional"), + Response EchoClient (1: Request req) (streaming.mode="client"), + Response EchoServer (1: Request req) (streaming.mode="server"), + Response EchoUnary (1: Request req) (streaming.mode="unary"), // not recommended + Response EchoBizException (1: Request req) (streaming.mode="client"), + + Response EchoPingPong (1: Request req), // KitexThrift, non-streaming + } + ` + dsc, err := NewDescritorFromContent(context.Background(), path, content, nil, false) + require.Nil(t, err) + require.Equal(t, true, dsc.Functions()["Echo"].IsWithoutWrapping()) + require.Equal(t, true, dsc.Functions()["EchoClient"].IsWithoutWrapping()) + require.Equal(t, true, dsc.Functions()["EchoServer"].IsWithoutWrapping()) + require.Equal(t, false, dsc.Functions()["EchoUnary"].IsWithoutWrapping()) + require.Equal(t, true, dsc.Functions()["EchoBizException"].IsWithoutWrapping()) + require.Equal(t, false, dsc.Functions()["EchoPingPong"].IsWithoutWrapping()) + require.Equal(t, "Request", dsc.Functions()["EchoClient"].Request().Struct().Name()) + require.Equal(t, "", dsc.Functions()["EchoUnary"].Request().Struct().Name()) +}