From e43aa14df16c20df6131f38623d8a05f59135cfa Mon Sep 17 00:00:00 2001 From: smallnest Date: Sun, 3 Sep 2023 14:16:39 +0800 Subject: [PATCH] #817 use PreCall and PostCall plugins for RegisterFunction --- server/plugin.go | 16 ++++++++-------- server/server.go | 12 +++++++++--- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/server/plugin.go b/server/plugin.go index 6e9cef4d..3bc1114d 100644 --- a/server/plugin.go +++ b/server/plugin.go @@ -31,7 +31,7 @@ type PluginContainer interface { DoPreHandleRequest(ctx context.Context, req *protocol.Message) error DoPreCall(ctx context.Context, serviceName, methodName string, args interface{}) (interface{}, error) - DoPostCall(ctx context.Context, serviceName, methodName string, args, reply interface{}) (interface{}, error) + DoPostCall(ctx context.Context, serviceName, methodName string, args, reply interface{}, err error) (interface{}, error) DoPreWriteResponse(context.Context, *protocol.Message, *protocol.Message, error) error DoPostWriteResponse(context.Context, *protocol.Message, *protocol.Message, error) error @@ -96,7 +96,7 @@ type ( } PostCallPlugin interface { - PostCall(ctx context.Context, serviceName, methodName string, args, reply interface{}) (interface{}, error) + PostCall(ctx context.Context, serviceName, methodName string, args, reply interface{}, err error) (interface{}, error) } // PreWriteResponsePlugin represents . @@ -314,18 +314,18 @@ func (p *pluginContainer) DoPreCall(ctx context.Context, serviceName, methodName } // DoPostCall invokes PostCallPlugin plugin. -func (p *pluginContainer) DoPostCall(ctx context.Context, serviceName, methodName string, args, reply interface{}) (interface{}, error) { - var err error +func (p *pluginContainer) DoPostCall(ctx context.Context, serviceName, methodName string, args, reply interface{}, err error) (interface{}, error) { + var e error for i := range p.plugins { if plugin, ok := p.plugins[i].(PostCallPlugin); ok { - reply, err = plugin.PostCall(ctx, serviceName, methodName, args, reply) - if err != nil { - return reply, err + reply, e = plugin.PostCall(ctx, serviceName, methodName, args, reply, err) + if e != nil { + return reply, e } } } - return reply, err + return reply, e } // DoPreWriteResponse invokes PreWriteResponse plugin. diff --git a/server/server.go b/server/server.go index 2fe78523..2d410199 100644 --- a/server/server.go +++ b/server/server.go @@ -722,9 +722,7 @@ func (s *Server) handleRequest(ctx context.Context, req *protocol.Message) (res err = service.call(ctx, mtype, reflect.ValueOf(argv), reflect.ValueOf(replyv)) } - if err == nil { - replyv, err = s.Plugins.DoPostCall(ctx, serviceName, methodName, argv, replyv) - } + replyv, err = s.Plugins.DoPostCall(ctx, serviceName, methodName, argv, replyv, err) // return argc to object pool reflectTypePools.Put(mtype.ArgType, argv) @@ -795,6 +793,12 @@ func (s *Server) handleRequestForFunction(ctx context.Context, req *protocol.Mes } replyv := reflectTypePools.Get(mtype.ReplyType) + argv, err = s.Plugins.DoPreCall(ctx, serviceName, methodName, argv) + if err != nil { + // return reply to object pool + reflectTypePools.Put(mtype.ReplyType, replyv) + return s.handleError(res, err) + } if mtype.ArgType.Kind() != reflect.Ptr { err = service.callForFunction(ctx, mtype, reflect.ValueOf(argv).Elem(), reflect.ValueOf(replyv)) @@ -802,6 +806,8 @@ func (s *Server) handleRequestForFunction(ctx context.Context, req *protocol.Mes err = service.callForFunction(ctx, mtype, reflect.ValueOf(argv), reflect.ValueOf(replyv)) } + replyv, err = s.Plugins.DoPostCall(ctx, serviceName, methodName, argv, replyv, err) + reflectTypePools.Put(mtype.ArgType, argv) if err != nil {