Skip to content

feat(codec): Unknown Service Handler (#1321) #1498

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion pkg/remote/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
Expand All @@ -27,6 +26,7 @@ import (
"github.com/cloudwego/kitex/pkg/rpcinfo"
"github.com/cloudwego/kitex/pkg/serviceinfo"
"github.com/cloudwego/kitex/pkg/streaming"
"github.com/cloudwego/kitex/pkg/unknownservice/service"
)

// Option is used to pack the inbound and outbound handlers.
Expand Down Expand Up @@ -113,6 +113,8 @@ type ServerOption struct {

GRPCUnknownServiceHandler func(ctx context.Context, method string, stream streaming.Stream) error

UnknownServiceHandler service.UnknownServiceHandler

Option

// invoking chain with recv/send middlewares for streaming APIs
Expand Down
85 changes: 85 additions & 0 deletions pkg/unknownservice/service/unknown_service.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
/*
* Copyright 2024 CloudWeGo Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package service

import (
"context"

"github.com/cloudwego/kitex/pkg/serviceinfo"
)

const (
// UnknownService name
UnknownService = "$UnknownService" // private as "$"
// UnknownMethod name
UnknownMethod = "$UnknownMethod"
)

type Args struct {
Request []byte
Method string
ServiceName string
}

type Result struct {
Success []byte
Method string
ServiceName string
}

type UnknownServiceHandler interface {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The definition can put in unknowservice package directly

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

import cycle not allowed

UnknownServiceHandler(ctx context.Context, serviceName, method string, request []byte) ([]byte, error)
}

// NewServiceInfo creates a new ServiceInfo containing unknown methods
func NewServiceInfo(pcType serviceinfo.PayloadCodec, service, method string) *serviceinfo.ServiceInfo {
methods := map[string]serviceinfo.MethodInfo{
method: serviceinfo.NewMethodInfo(callHandler, newServiceArgs, newServiceResult, false),
}
handlerType := (*UnknownServiceHandler)(nil)

svcInfo := &serviceinfo.ServiceInfo{
ServiceName: service,
HandlerType: handlerType,
Methods: methods,
PayloadCodec: pcType,
Extra: make(map[string]interface{}),
}

return svcInfo
}

func callHandler(ctx context.Context, handler, arg, result interface{}) error {
realArg := arg.(*Args)
realResult := result.(*Result)
realResult.Method = realArg.Method
realResult.ServiceName = realArg.ServiceName
success, err := handler.(UnknownServiceHandler).UnknownServiceHandler(ctx, realArg.ServiceName, realArg.Method, realArg.Request)
if err != nil {
return err
}
realResult.Success = success
return nil
}

func newServiceArgs() interface{} {
return &Args{}
}

func newServiceResult() interface{} {
return &Result{}
}
197 changes: 197 additions & 0 deletions pkg/unknownservice/unknownservice_codec.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
/*
* Copyright 2024 CloudWeGo Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package unknownservice

import (
"context"
"encoding/binary"
"errors"
"fmt"

gthrift "github.com/cloudwego/gopkg/protocol/thrift"
"github.com/cloudwego/kitex/pkg/remote"
"github.com/cloudwego/kitex/pkg/remote/codec"
"github.com/cloudwego/kitex/pkg/remote/codec/perrors"
"github.com/cloudwego/kitex/pkg/rpcinfo"
"github.com/cloudwego/kitex/pkg/serviceinfo"
unknownservice "github.com/cloudwego/kitex/pkg/unknownservice/service"
)

// UnknownCodec implements PayloadCodec
type unknownServiceCodec struct {
Codec remote.PayloadCodec
}

// NewUnknownServiceCodec creates the unknown binary codec.
func NewUnknownServiceCodec(code remote.PayloadCodec) remote.PayloadCodec {
return &unknownServiceCodec{code}
}

// Marshal implements the remote.PayloadCodec interface.
func (c unknownServiceCodec) Marshal(ctx context.Context, msg remote.Message, out remote.ByteBuffer) error {
ink := msg.RPCInfo().Invocation()
data := msg.Data()

res, ok := data.(*unknownservice.Result)
if !ok {
return c.Codec.Marshal(ctx, msg, out)
}
if msg.MessageType() == remote.Exception {
return c.Codec.Marshal(ctx, msg, out)
}
if ink, ok := ink.(rpcinfo.InvocationSetter); ok {
ink.SetMethodName(res.Method)
ink.SetServiceName(res.ServiceName)
} else {
return errors.New("the interface Invocation doesn't implement InvocationSetter")
}

if res.Success == nil {
sz := gthrift.Binary.MessageBeginLength(msg.RPCInfo().Invocation().MethodName())
if msg.ProtocolInfo().CodecType == serviceinfo.Thrift {
sz += gthrift.Binary.FieldStopLength()
buf, err := out.Malloc(sz)
if err != nil {
return perrors.NewProtocolError(fmt.Errorf("binary thrift generic marshal, remote.ByteBuffer Malloc err: %w", err))
}
buf = gthrift.Binary.AppendMessageBegin(buf[:0],
msg.RPCInfo().Invocation().MethodName(), gthrift.TMessageType(msg.MessageType()), msg.RPCInfo().Invocation().SeqID())
buf = gthrift.Binary.AppendFieldStop(buf)
_ = buf
}

if msg.ProtocolInfo().CodecType == serviceinfo.Protobuf {
buf, err := out.Malloc(sz)
if err != nil {
return perrors.NewProtocolError(fmt.Errorf("binary thrift generic marshal, remote.ByteBuffer Malloc err: %w", err))
}
binary.BigEndian.PutUint32(buf, codec.ProtobufV1Magic+uint32(msg.MessageType()))
offset := 4
offset += gthrift.Binary.WriteString(buf[offset:], res.Method)
offset += gthrift.Binary.WriteI32(buf[offset:], msg.RPCInfo().Invocation().SeqID())
_ = buf
}
return nil
}
out.WriteBinary(res.Success)
return nil
}

// Unmarshal implements the remote.PayloadCodec interface.
func (c unknownServiceCodec) Unmarshal(ctx context.Context, message remote.Message, in remote.ByteBuffer) error {
ink := message.RPCInfo().Invocation()
magicAndMsgType, err := codec.PeekUint32(in)
if err != nil {
return err
}
msgType := magicAndMsgType & codec.FrontMask
if msgType == uint32(remote.Exception) {
return c.Codec.Unmarshal(ctx, message, in)
}
if err = codec.UpdateMsgType(msgType, message); err != nil {
return err
}
service, method, err := readDecode(message, in)
if err != nil {
return err
}
err = codec.SetOrCheckMethodName(method, message)
var te *remote.TransError
if errors.As(err, &te) && (te.TypeID() == remote.UnknownMethod || te.TypeID() == remote.UnknownService) {
svcInfo, err := message.SpecifyServiceInfo(unknownservice.UnknownService, unknownservice.UnknownMethod)
if err != nil {
return err
}

if ink, ok := ink.(rpcinfo.InvocationSetter); ok {
ink.SetMethodName(unknownservice.UnknownMethod)
ink.SetPackageName(svcInfo.GetPackageName())
ink.SetServiceName(unknownservice.UnknownService)
} else {
return errors.New("the interface Invocation doesn't implement InvocationSetter")
}
if err = codec.NewDataIfNeeded(unknownservice.UnknownMethod, message); err != nil {
return err
}

data := message.Data()

if data, ok := data.(*unknownservice.Args); ok {
data.Method = method
data.ServiceName = service
buf, err := in.Next(in.ReadableLen())
if err != nil {
return err
}
data.Request = buf
}
return nil
}

return c.Codec.Unmarshal(ctx, message, in)
}

// Name implements the remote.PayloadCodec interface.
func (c unknownServiceCodec) Name() string {
return "unknownServiceCodec"
}

func readDecode(message remote.Message, in remote.ByteBuffer) (string, string, error) {
code := message.ProtocolInfo().CodecType
if code == serviceinfo.Thrift || code == serviceinfo.Protobuf {
method, size, err := peekMethod(in)
if err != nil {
return "", "", err
}

seqID, err := peekSeqID(in, size)
if err != nil {
return "", "", err
}
if err = codec.SetOrCheckSeqID(seqID, message); err != nil {
return "", "", err
}
return message.RPCInfo().Invocation().ServiceName(), method, nil
}
return "", "", nil
}

func peekMethod(in remote.ByteBuffer) (string, int32, error) {
buf, err := in.Peek(8)
if err != nil {
return "", 0, err
}
buf = buf[4:]
size := int32(binary.BigEndian.Uint32(buf))
buf, err = in.Peek(int(size + 8))
if err != nil {
return "", 0, perrors.NewProtocolError(err)
}
buf = buf[8:]
method := string(buf)
return method, size + 8, nil
}

func peekSeqID(in remote.ByteBuffer, size int32) (int32, error) {
buf, err := in.Peek(int(size + 4))
if err != nil {
return 0, perrors.NewProtocolError(err)
}
buf = buf[size:]
seqID := int32(binary.BigEndian.Uint32(buf))
return seqID, nil
}
Loading