-
Notifications
You must be signed in to change notification settings - Fork 13
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: Artsiom Koltun <artsiom.koltun@intel.com>
- Loading branch information
1 parent
325cfd9
commit 34c1dfa
Showing
3 changed files
with
348 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// Copyright (C) 2023 Intel Corporation | ||
|
||
// Package server implements the server | ||
package server | ||
|
||
import ( | ||
"context" | ||
"errors" | ||
"log" | ||
"net" | ||
"os" | ||
"os/signal" | ||
"syscall" | ||
"time" | ||
|
||
"google.golang.org/grpc" | ||
) | ||
|
||
// GRPCServerWrapper wraps gRPC server to provide graceful shutdown capabilities | ||
type GRPCServerWrapper struct { | ||
waitSignal chan os.Signal | ||
signalsToWait []os.Signal | ||
|
||
timeout time.Duration | ||
|
||
server *grpc.Server | ||
listener net.Listener | ||
waitServeComplete chan error | ||
serve func(*grpc.Server, net.Listener) error | ||
} | ||
|
||
func defaultServe(s *grpc.Server, l net.Listener) error { return s.Serve(l) } | ||
|
||
// NewGRPCServerWrapper creates a new instance of GRPCServerWrapper | ||
func NewGRPCServerWrapper( | ||
timeout time.Duration, server *grpc.Server, listener net.Listener, | ||
) *GRPCServerWrapper { | ||
if timeout == 0 { | ||
log.Panicf("timeout cannot be zero") | ||
} | ||
|
||
if server == nil { | ||
log.Panicf("grpc server cannot be nil") | ||
} | ||
|
||
if listener == nil { | ||
log.Panic("listener cannot be nil") | ||
} | ||
|
||
return &GRPCServerWrapper{ | ||
waitSignal: make(chan os.Signal, 1), | ||
signalsToWait: []os.Signal{syscall.SIGINT, syscall.SIGTERM}, | ||
timeout: timeout, | ||
server: server, | ||
listener: listener, | ||
waitServeComplete: make(chan error, 1), | ||
serve: defaultServe, | ||
} | ||
} | ||
|
||
// RunAsync runs gRPC server | ||
func (s *GRPCServerWrapper) RunAsync() { | ||
go func() { | ||
log.Printf("Server listening at %v", s.listener.Addr()) | ||
s.waitServeComplete <- s.serve(s.server, s.listener) | ||
}() | ||
} | ||
|
||
// Wait waits for a signal and handles graceful completion | ||
func (s *GRPCServerWrapper) Wait() error { | ||
ctx, cancel := context.WithCancel(context.Background()) | ||
defer cancel() | ||
go func() { | ||
signal.Notify(s.waitSignal, s.signalsToWait...) | ||
select { | ||
case sig := <-s.waitSignal: | ||
log.Printf("Got signal: %v", sig) | ||
log.Printf("Start graceful shutdown with timeout: %v", s.timeout) | ||
time.AfterFunc(s.timeout, func() { cancel() }) | ||
s.stopServer(ctx) | ||
case <-ctx.Done(): | ||
log.Println("Stop listening for a signal") | ||
} | ||
}() | ||
|
||
select { | ||
case err := <-s.waitServeComplete: | ||
return err | ||
case <-ctx.Done(): | ||
return errors.New("server stop timeout elapsed") | ||
} | ||
} | ||
|
||
func (s *GRPCServerWrapper) stopServer(ctx context.Context) { | ||
log.Println("Stop server") | ||
|
||
stopped := make(chan struct{}, 1) | ||
go func() { | ||
s.server.GracefulStop() | ||
close(stopped) | ||
}() | ||
|
||
select { | ||
case <-ctx.Done(): | ||
log.Println("Server stop context done") | ||
s.server.Stop() | ||
case <-stopped: | ||
log.Println("GracefulStop completed") | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,227 @@ | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// Copyright (C) 2023 Intel Corporation | ||
|
||
// Package server implements the server | ||
package server | ||
|
||
import ( | ||
"context" | ||
"errors" | ||
"log" | ||
"net" | ||
"os" | ||
"sync" | ||
"syscall" | ||
"testing" | ||
"time" | ||
|
||
"google.golang.org/grpc" | ||
"google.golang.org/grpc/credentials/insecure" | ||
"google.golang.org/grpc/test/bufconn" | ||
|
||
pb "github.com/opiproject/opi-api/storage/v1alpha1/gen/go" | ||
) | ||
|
||
const timeout = 10 * time.Millisecond | ||
|
||
type TestServer struct { | ||
pb.MiddleendEncryptionServiceServer | ||
wait time.Duration | ||
startedHandlingCall sync.WaitGroup | ||
} | ||
|
||
func (b *TestServer) CreateEncryptedVolume(_ context.Context, _ *pb.CreateEncryptedVolumeRequest) (*pb.EncryptedVolume, error) { | ||
b.startedHandlingCall.Done() | ||
time.Sleep(b.wait) | ||
return &pb.EncryptedVolume{}, nil | ||
} | ||
|
||
type testEnv struct { | ||
testServer *TestServer | ||
client pb.MiddleendEncryptionServiceClient | ||
conn *grpc.ClientConn | ||
ln net.Listener | ||
grpcServer *grpc.Server | ||
} | ||
|
||
func (e *testEnv) Close() { | ||
CloseGrpcConnection(e.conn) | ||
CloseListener(e.ln) | ||
} | ||
|
||
func createTestEnvironment(callTime time.Duration) *testEnv { | ||
env := &testEnv{} | ||
env.testServer = &TestServer{ | ||
pb.UnimplementedMiddleendEncryptionServiceServer{}, | ||
callTime, | ||
sync.WaitGroup{}, | ||
} | ||
env.grpcServer = grpc.NewServer() | ||
listener := bufconn.Listen(1024 * 1024) | ||
env.ln = listener | ||
pb.RegisterMiddleendEncryptionServiceServer(env.grpcServer, env.testServer) | ||
|
||
ctx := context.Background() | ||
conn, err := grpc.DialContext(ctx, | ||
"", | ||
grpc.WithTransportCredentials(insecure.NewCredentials()), | ||
grpc.WithContextDialer(func(context.Context, string) (net.Conn, error) { | ||
return listener.Dial() | ||
})) | ||
if err != nil { | ||
log.Fatal(err) | ||
} | ||
env.conn = conn | ||
env.client = pb.NewMiddleendEncryptionServiceClient(env.conn) | ||
|
||
return env | ||
} | ||
|
||
func TestGRPCWrapperWait(t *testing.T) { | ||
tests := map[string]struct { | ||
callTime time.Duration | ||
wantErr bool | ||
serve func(*grpc.Server, net.Listener) error | ||
}{ | ||
"server stop timeout": { | ||
callTime: timeout * 2, | ||
wantErr: true, | ||
}, | ||
"successful server stop": { | ||
callTime: timeout / 2, | ||
wantErr: false, | ||
}, | ||
} | ||
for testName, tt := range tests { | ||
t.Run(testName, func(t *testing.T) { | ||
testEnv := createTestEnvironment(tt.callTime) | ||
defer testEnv.Close() | ||
testEnv.testServer.startedHandlingCall.Add(1) | ||
|
||
serverWrapper := NewGRPCServerWrapper(timeout, testEnv.grpcServer, testEnv.ln) | ||
// use rare signal in order not to catch a real interrupt | ||
serverWrapper.signalsToWait = []os.Signal{syscall.SIGILL} | ||
serverWrapper.RunAsync() | ||
|
||
var ( | ||
clientResponse *pb.EncryptedVolume | ||
clientErr error | ||
) | ||
clientDone := sync.WaitGroup{} | ||
clientDone.Add(1) | ||
go func() { | ||
clientResponse, clientErr = testEnv.client.CreateEncryptedVolume( | ||
context.Background(), &pb.CreateEncryptedVolumeRequest{}) | ||
clientDone.Done() | ||
}() | ||
testEnv.testServer.startedHandlingCall.Wait() | ||
|
||
serverWrapper.waitSignal <- os.Interrupt | ||
waitErr := serverWrapper.Wait() | ||
|
||
if (waitErr != nil) != tt.wantErr { | ||
t.Errorf("Expected elapsed: %v. received: %v", tt.wantErr, waitErr) | ||
} | ||
clientDone.Wait() | ||
if (clientErr != nil) != tt.wantErr { | ||
t.Errorf("Expected error %v, received: %v", tt.wantErr, clientErr) | ||
} | ||
if (clientResponse == nil) != tt.wantErr { | ||
t.Errorf("Expected not nil response %v, received: %v", tt.wantErr, clientResponse) | ||
} | ||
}) | ||
} | ||
|
||
t.Run("failed serve", func(t *testing.T) { | ||
testEnv := createTestEnvironment(timeout) | ||
defer testEnv.Close() | ||
serverWrapper := NewGRPCServerWrapper(timeout, testEnv.grpcServer, testEnv.ln) | ||
// use rare signal in order not to catch a real interrupt | ||
serverWrapper.signalsToWait = []os.Signal{syscall.SIGILL} | ||
stubErr := errors.New("some serve error") | ||
serverWrapper.serve = func(s *grpc.Server, l net.Listener) error { return stubErr } | ||
serverWrapper.RunAsync() | ||
|
||
waitErr := serverWrapper.Wait() | ||
|
||
if waitErr != stubErr { | ||
t.Errorf("Expected error: %v, received: %v", stubErr, waitErr) | ||
} | ||
}) | ||
|
||
t.Run("failed serve after signal received", func(t *testing.T) { | ||
testEnv := createTestEnvironment(timeout) | ||
defer testEnv.Close() | ||
serverWrapper := NewGRPCServerWrapper(timeout, testEnv.grpcServer, testEnv.ln) | ||
// use rare signal in order not to catch a real interrupt | ||
serverWrapper.signalsToWait = []os.Signal{syscall.SIGILL} | ||
stubErr := errors.New("some serve error") | ||
wg := sync.WaitGroup{} | ||
wg.Add(1) | ||
serverWrapper.serve = func(s *grpc.Server, l net.Listener) error { | ||
wg.Wait() | ||
return stubErr | ||
} | ||
serverWrapper.RunAsync() | ||
go func() { | ||
serverWrapper.waitSignal <- os.Interrupt | ||
time.Sleep(timeout / 5) | ||
wg.Done() | ||
}() | ||
|
||
waitErr := serverWrapper.Wait() | ||
|
||
if waitErr != stubErr { | ||
t.Errorf("Expected error: %v, received: %v", stubErr, waitErr) | ||
} | ||
}) | ||
} | ||
|
||
func TestNewGRPCWrapper(t *testing.T) { | ||
tests := map[string]struct { | ||
timeout time.Duration | ||
grpcServer *grpc.Server | ||
listener net.Listener | ||
wantPanic bool | ||
}{ | ||
"zero timeout": { | ||
timeout: 0, | ||
grpcServer: grpc.NewServer(), | ||
listener: bufconn.Listen(32), | ||
wantPanic: true, | ||
}, | ||
"nil grpc server": { | ||
timeout: timeout, | ||
grpcServer: nil, | ||
listener: bufconn.Listen(32), | ||
wantPanic: true, | ||
}, | ||
"nil listener": { | ||
timeout: timeout, | ||
grpcServer: grpc.NewServer(), | ||
listener: nil, | ||
wantPanic: true, | ||
}, | ||
"successful wrapper creation": { | ||
timeout: timeout, | ||
grpcServer: grpc.NewServer(), | ||
listener: bufconn.Listen(32), | ||
wantPanic: false, | ||
}, | ||
} | ||
for testName, tt := range tests { | ||
t.Run(testName, func(t *testing.T) { | ||
defer func() { | ||
r := recover() | ||
if (r != nil) != tt.wantPanic { | ||
t.Errorf("GRPCServerWrapper.Run() recover = %v, wantPanic = %v", r, tt.wantPanic) | ||
} | ||
}() | ||
|
||
wrapper := NewGRPCServerWrapper(tt.timeout, tt.grpcServer, tt.listener) | ||
if !tt.wantPanic && wrapper == nil { | ||
t.Error("Expect not nil wrapper") | ||
} | ||
}) | ||
} | ||
} |