From 912a07ebc42144362769066014505937ba1fdc16 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Barnab=C3=A1s=20Pataki?= Date: Fri, 30 Mar 2018 13:44:55 +0200 Subject: [PATCH] Added ErrMultiRegister on multiple Register calls --- server/rpc_server.go | 10 ++++++++++ server/server.go | 5 ++++- server/simple_server.go | 10 ++++++++++ server/simple_server_test.go | 34 ++++++++++++++++++++++++---------- 4 files changed, 48 insertions(+), 11 deletions(-) diff --git a/server/rpc_server.go b/server/rpc_server.go index a53f1f3d6..578b2dcc5 100644 --- a/server/rpc_server.go +++ b/server/rpc_server.go @@ -23,6 +23,9 @@ import ( // RPCServer is an experimental server that serves a gRPC server on one // port and the same endpoints via JSON on another port. type RPCServer struct { + // tracks if the Register function is already called or not + registered bool + cfg *Config // exit chan for graceful shutdown @@ -67,6 +70,13 @@ func NewRPCServer(cfg *Config) *RPCServer { // Register will attempt to register the given RPCService with the server. // If any other types are passed, Register will panic. func (r *RPCServer) Register(svc Service) error { + // check multiple register call error + if r.registered { + return ErrMultiRegister + } + // set registered to true because we called it + r.registered = true + rpcsvc, ok := svc.(RPCService) if !ok { Log.Fatalf("invalid service type for rpc server: %T", svc) diff --git a/server/server.go b/server/server.go index e81a61ae6..c50e0a51f 100644 --- a/server/server.go +++ b/server/server.go @@ -11,9 +11,9 @@ import ( "syscall" "time" - "github.com/sirupsen/logrus" "github.com/go-kit/kit/metrics/provider" "github.com/nu7hatch/gouuid" + "github.com/sirupsen/logrus" "github.com/NYTimes/gizmo/config/metrics" "github.com/NYTimes/gizmo/web" @@ -31,6 +31,9 @@ type Server interface { } var ( + // ErrMultiRegister occurs when a Register method is called multiple times + ErrMultiRegister = errors.New("register method has been called multiple times") + // Name is used for status and logging. Name = "nyt-awesome-go-server" // Log is the global logger for the server. It will take care of logrotate diff --git a/server/simple_server.go b/server/simple_server.go index 673a3b966..34c201920 100644 --- a/server/simple_server.go +++ b/server/simple_server.go @@ -23,6 +23,9 @@ import ( // SimpleServer is a basic http Server implementation for // serving SimpleService, JSONService or MixedService implementations. type SimpleServer struct { + // tracks if the Register function is already called or not + registered bool + cfg *Config // exit chan for graceful shutdown @@ -217,6 +220,13 @@ func (s *SimpleServer) Stop() error { // Register will accept and register SimpleServer, JSONService or MixedService implementations. func (s *SimpleServer) Register(svcI Service) error { + // check multiple register call error + if s.registered { + return ErrMultiRegister + } + // set registered to true because we called it + s.registered = true + s.svc = svcI prefix := svcI.Prefix() // quick fix for backwards compatibility diff --git a/server/simple_server_test.go b/server/simple_server_test.go index f8af5c047..1dc9f2c46 100644 --- a/server/simple_server_test.go +++ b/server/simple_server_test.go @@ -367,20 +367,34 @@ func TestFactory(*testing.T) { } func TestBasicRegistration(t *testing.T) { - s := NewSimpleServer(nil) - services := []Service{ - &benchmarkSimpleService{}, - &benchmarkJSONService{}, - &testMixedService{}, - &benchmarkContextService{}, + tests := []struct { + server *SimpleServer + service Service + }{ + { + NewSimpleServer(nil), + &benchmarkSimpleService{}, + }, + { + NewSimpleServer(nil), + &benchmarkJSONService{}, + }, + { + NewSimpleServer(nil), + &testMixedService{}, + }, + { + NewSimpleServer(nil), + &benchmarkContextService{}, + }, } - for _, svc := range services { - if err := s.Register(svc); err != nil { + for _, test := range tests { + if err := test.server.Register(test.service); err != nil { t.Errorf("Basic registration of services should not encounter an error: %s\n", err) } } - - if err := s.Register(&testInvalidService{}); err == nil { + invServer := NewSimpleServer(nil) + if err := invServer.Register(&testInvalidService{}); err == nil { t.Error("Invalid services should produce an error in service registration") } }