Skip to content

Commit

Permalink
Merge pull request #305 from gatewayd-io/json-binary
Browse files Browse the repository at this point in the history
Implement the new JSON binary protocol
  • Loading branch information
mostafa authored Sep 2, 2023
2 parents 4fb93c2 + 498df7f commit 9fdea60
Show file tree
Hide file tree
Showing 9 changed files with 88 additions and 124 deletions.
6 changes: 3 additions & 3 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ require (
github.com/NYTimes/gziphandler v1.1.1
github.com/codingsince1985/checksum v1.3.0
github.com/envoyproxy/protoc-gen-validate v1.0.2
github.com/gatewayd-io/gatewayd-plugin-sdk v0.1.0
github.com/gatewayd-io/gatewayd-plugin-sdk v0.1.1
github.com/getsentry/sentry-go v0.23.0
github.com/go-co-op/gocron v1.33.1
github.com/google/go-cmp v0.5.9
Expand Down Expand Up @@ -80,8 +80,8 @@ require (
golang.org/x/net v0.14.0 // indirect
golang.org/x/oauth2 v0.11.0 // indirect
golang.org/x/sync v0.3.0 // indirect
golang.org/x/sys v0.11.0 // indirect
golang.org/x/text v0.12.0 // indirect
golang.org/x/sys v0.12.0 // indirect
golang.org/x/text v0.13.0 // indirect
google.golang.org/appengine v1.6.7 // indirect
google.golang.org/genproto v0.0.0-20230822172742-b8732ec3820d // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20230822172742-b8732ec3820d // indirect
Expand Down
12 changes: 6 additions & 6 deletions go.sum

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 2 additions & 8 deletions network/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -674,11 +674,8 @@ func (pr *Proxy) getPluginModifiedRequest(result map[string]interface{}) []byte
defer span.End()

// If the hook modified the request, use the modified request.
if modRequest, errMsg, convErr := extractFieldValue(result, "request"); errMsg != "" {
if modRequest, errMsg := extractFieldValue(result, "request"); errMsg != "" {
pr.logger.Error().Str("error", errMsg).Msg("Error in hook")
} else if convErr != nil {
pr.logger.Error().Err(convErr).Msg("Error in data conversion")
span.RecordError(convErr)
} else if modRequest != nil {
return modRequest
}
Expand All @@ -693,11 +690,8 @@ func (pr *Proxy) getPluginModifiedResponse(result map[string]interface{}) ([]byt
defer span.End()

// If the hook returns a response, use it instead of the original response.
if modResponse, errMsg, convErr := extractFieldValue(result, "response"); errMsg != "" {
if modResponse, errMsg := extractFieldValue(result, "response"); errMsg != "" {
pr.logger.Error().Str("error", errMsg).Msg("Error in hook")
} else if convErr != nil {
pr.logger.Error().Err(convErr).Msg("Error in data conversion")
span.RecordError(convErr)
} else if modResponse != nil {
return modResponse, len(modResponse)
}
Expand Down
60 changes: 18 additions & 42 deletions network/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package network

import (
"context"
"encoding/base64"
"errors"
"testing"

Expand All @@ -15,7 +14,6 @@ import (
"github.com/rs/zerolog"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc"
"google.golang.org/protobuf/types/known/structpb"
)

// TestRunServer tests an entire server run with a single client connection and hooks.
Expand All @@ -42,24 +40,17 @@ func TestRunServer(t *testing.T) {

onTrafficFromClient := func(
ctx context.Context,
params *structpb.Struct,
params *v1.Struct,
opts ...grpc.CallOption,
) (*structpb.Struct, error) {
) (*v1.Struct, error) {
paramsMap := params.AsMap()
if paramsMap["request"] == nil {
errs <- errors.New("request is nil") //nolint:goerr113
}

logger.Info().Msg("Ingress traffic")
// Decode the request.
// The request is []byte, but it is base64-encoded as a string
// via using the structpb.NewStruct function.
if req, ok := paramsMap["request"].(string); ok {
if request, err := base64.StdEncoding.DecodeString(req); err == nil {
assert.Equal(t, CreatePgStartupPacket(), request)
} else {
errs <- err
}
if req, ok := paramsMap["request"].([]byte); ok {
assert.Equal(t, CreatePgStartupPacket(), req)
} else {
errs <- errors.New("request is not a []byte") //nolint:goerr113
}
Expand All @@ -70,24 +61,17 @@ func TestRunServer(t *testing.T) {

onTrafficToServer := func(
ctx context.Context,
params *structpb.Struct,
params *v1.Struct,
opts ...grpc.CallOption,
) (*structpb.Struct, error) {
) (*v1.Struct, error) {
paramsMap := params.AsMap()
if paramsMap["request"] == nil {
errs <- errors.New("request is nil") //nolint:goerr113
}

logger.Info().Msg("Ingress traffic")
// Decode the request.
// The request is []byte, but it is base64-encoded as a string
// via using the structpb.NewStruct function.
if req, ok := paramsMap["request"].(string); ok {
if request, err := base64.StdEncoding.DecodeString(req); err == nil {
assert.Equal(t, CreatePgStartupPacket(), request)
} else {
errs <- err
}
if req, ok := paramsMap["request"].([]byte); ok {
assert.Equal(t, CreatePgStartupPacket(), req)
} else {
errs <- errors.New("request is not a []byte") //nolint:goerr113
}
Expand All @@ -98,23 +82,19 @@ func TestRunServer(t *testing.T) {

onTrafficFromServer := func(
ctx context.Context,
params *structpb.Struct,
params *v1.Struct,
opts ...grpc.CallOption,
) (*structpb.Struct, error) {
) (*v1.Struct, error) {
paramsMap := params.AsMap()
if paramsMap["response"] == nil {
errs <- errors.New("response is nil") //nolint:goerr113
}

logger.Info().Msg("Egress traffic")
if resp, ok := paramsMap["response"].(string); ok {
if response, err := base64.StdEncoding.DecodeString(resp); err == nil {
assert.Equal(t, CreatePostgreSQLPacket('R', []byte{
0x0, 0x0, 0x0, 0xa, 0x53, 0x43, 0x52, 0x41, 0x4d, 0x2d, 0x53, 0x48, 0x41, 0x2d, 0x32, 0x35, 0x36, 0x0, 0x0,
}), response)
} else {
errs <- err
}
if resp, ok := paramsMap["response"].([]byte); ok {
assert.Equal(t, CreatePostgreSQLPacket('R', []byte{
0x0, 0x0, 0x0, 0xa, 0x53, 0x43, 0x52, 0x41, 0x4d, 0x2d, 0x53, 0x48, 0x41, 0x2d, 0x32, 0x35, 0x36, 0x0, 0x0,
}), resp)
} else {
errs <- errors.New("response is not a []byte") //nolint:goerr113
}
Expand All @@ -125,21 +105,17 @@ func TestRunServer(t *testing.T) {

onTrafficToClient := func(
ctx context.Context,
params *structpb.Struct,
params *v1.Struct,
opts ...grpc.CallOption,
) (*structpb.Struct, error) {
) (*v1.Struct, error) {
paramsMap := params.AsMap()
if paramsMap["response"] == nil {
errs <- errors.New("response is nil") //nolint:goerr113
}

logger.Info().Msg("Egress traffic")
if resp, ok := paramsMap["response"].(string); ok {
if response, err := base64.StdEncoding.DecodeString(resp); err == nil {
assert.Equal(t, uint8(0x52), response[0])
} else {
errs <- err
}
if resp, ok := paramsMap["response"].([]byte); ok {
assert.Equal(t, uint8(0x52), resp[0])
} else {
errs <- errors.New("response is not a []byte") //nolint:goerr113
}
Expand Down
16 changes: 5 additions & 11 deletions network/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package network

import (
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"errors"
"fmt"
Expand Down Expand Up @@ -95,25 +94,20 @@ func trafficData(
}

// extractFieldValue extracts the given field name and error message from the result of the hook.
func extractFieldValue(result map[string]interface{}, fieldName string) ([]byte, string, error) {
func extractFieldValue(result map[string]interface{}, fieldName string) ([]byte, string) {
var data []byte
var err string
var conversionErr error

//nolint:nestif
if result != nil {
if fieldValue, ok := result[fieldName].(string); ok {
if base64Decoded, err := base64.StdEncoding.DecodeString(fieldValue); err == nil {
data = base64Decoded
} else {
conversionErr = err
}
if val, ok := result[fieldName].([]byte); ok {
data = val
}

if errMsg, ok := result["error"].(string); ok && errMsg != "" {
err = errMsg
}
}
return data, err, conversionErr
return data, err
}

// IsConnTimedOut returns true if the error is a timeout error.
Expand Down
19 changes: 10 additions & 9 deletions plugin/plugin_registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ import (
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
"google.golang.org/grpc"
"google.golang.org/protobuf/types/known/structpb"
)

type IHook interface {
Expand Down Expand Up @@ -283,11 +282,11 @@ func (reg *Registry) Run(
// Cast custom fields to their primitive types, like time.Duration to float64.
args = CastToPrimitiveTypes(args)

// Create structpb.Struct from args.
var params *structpb.Struct
// Create v1.Struct from args.
var params *v1.Struct
if len(args) == 0 {
params = &structpb.Struct{}
} else if casted, err := structpb.NewStruct(args); err == nil {
params = &v1.Struct{}
} else if casted, err := v1.NewStruct(args); err == nil {
params = casted
} else {
span.RecordError(err)
Expand All @@ -304,11 +303,11 @@ func (reg *Registry) Run(
})

// Run hooks, passing the result of the previous hook to the next one.
returnVal := &structpb.Struct{}
returnVal := &v1.Struct{}
var removeList []sdkPlugin.Priority
// The signature of parameters and args MUST be the same for this to work.
for idx, priority := range priorities {
var result *structpb.Struct
var result *v1.Struct
var err error
if idx == 0 {
result, err = reg.hooks[hookName][priority](inheritedCtx, params, opts...)
Expand Down Expand Up @@ -501,21 +500,23 @@ func (reg *Registry) LoadPlugins(ctx context.Context, plugins []config.Plugin) {
span.AddEvent("Started plugin")

// Load metadata from the plugin.
var metadata *structpb.Struct
var metadata *v1.Struct
pluginV1, err := plugin.Dispense()
if err != nil {
reg.Logger.Debug().Str("name", plugin.ID.Name).Err(err).Msg(
"Failed to dispense plugin")
plugin.Client.Kill()
continue
}

meta, origErr := pluginV1.GetPluginConfig( //nolint:contextcheck
context.Background(), &structpb.Struct{})
context.Background(), &v1.Struct{})
if err != nil || meta == nil {
reg.Logger.Debug().Str("name", plugin.ID.Name).Err(origErr).Msg(
"Failed to get plugin metadata")
continue
}

metadata = meta

span.AddEvent("Fetched plugin metadata")
Expand Down
Loading

0 comments on commit 9fdea60

Please sign in to comment.