From 6113ed3a9140dcdfce18d546c0079d6eec0be6ac Mon Sep 17 00:00:00 2001 From: mrekucci <4932785+mrekucci@users.noreply.github.com> Date: Mon, 12 Feb 2024 09:42:13 -0300 Subject: [PATCH] feat: enable tls on rpc endpoints --- .github/workflows/go-ci.yml | 1 + cmd/main.go | 107 +++++++++++---------- examples/provideremulator/client/client.go | 13 ++- gen-certificates.sh | 66 +++++++++++++ integration-compose.yml | 9 ++ integrationtest/Dockerfile | 29 ++++-- integrationtest/bidder/main.go | 39 ++++---- integrationtest/config/bidder.yaml | 2 + integrationtest/config/bootnode.yaml | 2 + integrationtest/config/provider.yaml | 2 + integrationtest/entrypoint.sh | 6 +- integrationtest/provider/client.go | 12 ++- integrationtest/provider/main.go | 98 ++++++++++--------- integrationtest/real-bidder/main.go | 39 ++++---- pkg/node/node.go | 98 +++++++++++++++---- 15 files changed, 343 insertions(+), 180 deletions(-) create mode 100755 gen-certificates.sh diff --git a/.github/workflows/go-ci.yml b/.github/workflows/go-ci.yml index e88e1963..aff8a6f0 100644 --- a/.github/workflows/go-ci.yml +++ b/.github/workflows/go-ci.yml @@ -27,6 +27,7 @@ jobs: uses: golangci/golangci-lint-action@v2 with: version: v1.54.2 + args: --timeout 5m - name: Vet run: go vet ./... diff --git a/cmd/main.go b/cmd/main.go index e6572458..390fc5e7 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -6,6 +6,7 @@ import ( "log/slog" "os" "path/filepath" + "slices" "strings" "time" @@ -21,9 +22,16 @@ import ( "github.com/urfave/cli/v2/altsrc" ) +// The following const block contains the name of the cli flags, especially +// for reuse purposes. +const ( + serverTLSCertificateFlagName = "server-tls-certificate" + serverTLSPrivateKeyFlagName = "server-tls-private-key" +) + const ( - defaultP2PPort = 13522 defaultP2PAddr = "0.0.0.0" + defaultP2PPort = 13522 defaultHTTPPort = 13523 defaultRPCPort = 13524 @@ -37,19 +45,17 @@ const ( var ( portCheck = func(c *cli.Context, p int) error { if p < 0 || p > 65535 { - return fmt.Errorf("Invalid port number %d, expected 0 <= port <= 65535", p) + return fmt.Errorf("invalid port number %d, expected 0 <= port <= 65535", p) } return nil } stringInCheck = func(flag string, opts []string) func(c *cli.Context, p string) error { return func(c *cli.Context, p string) error { - for _, opt := range opts { - if p == opt { - return nil - } + if !slices.Contains(opts, p) { + return fmt.Errorf("invalid %s option %q, expected one of %s", flag, p, strings.Join(opts, ", ")) } - return fmt.Errorf("Invalid %s option '%s', expected one of %s", flag, p, strings.Join(opts, ", ")) + return nil } } ) @@ -203,6 +209,18 @@ var ( EnvVars: []string{"MEV_COMMIT_NAT_PORT"}, Value: defaultP2PPort, }) + + optionServerTLSCert = altsrc.NewStringFlag(&cli.StringFlag{ + Name: serverTLSCertificateFlagName, + Usage: "Path to the server TLS certificate", + EnvVars: []string{"MEV_COMMIT_SERVER_TLS_CERTIFICATE"}, + }) + + optionServerTLSPrivateKey = altsrc.NewStringFlag(&cli.StringFlag{ + Name: serverTLSPrivateKeyFlagName, + Usage: "Path to the server TLS private key", + EnvVars: []string{"MEV_COMMIT_SERVER_TLS_PRIVATE_KEY"}, + }) ) func main() { @@ -228,6 +246,8 @@ func main() { optionSettlementRPCEndpoint, optionNATAddr, optionNATPort, + optionServerTLSCert, + optionServerTLSPrivateKey, } app := &cli.App{ @@ -240,47 +260,34 @@ func main() { } if err := app.Run(os.Args); err != nil { - fmt.Fprintf(app.Writer, "exited with error: %v\n", err) + fmt.Fprintln(app.Writer, "exited with error:", err) } } func createKeyIfNotExists(c *cli.Context, path string) error { - // check if key already exists if _, err := os.Stat(path); err == nil { - fmt.Fprintf(c.App.Writer, "Using existing private key: %s\n", path) + fmt.Fprintln(c.App.Writer, "using existing private key:", path) return nil } - fmt.Fprintf(c.App.Writer, "Creating new private key: %s\n", path) - - // check if parent directory exists - if _, err := os.Stat(filepath.Dir(path)); os.IsNotExist(err) { - // create parent directory - if err := os.MkdirAll(filepath.Dir(path), 0700); err != nil { - return err - } - } - - privKey, err := crypto.GenerateKey() - if err != nil { + fmt.Fprintln(c.App.Writer, "creating new private key:", path) + if err := os.MkdirAll(filepath.Dir(path), 0700); err != nil { return err } - f, err := os.Create(path) + key, err := crypto.GenerateKey() if err != nil { return err } - defer f.Close() - - if err := crypto.SaveECDSA(path, privKey); err != nil { + if err := crypto.SaveECDSA(path, key); err != nil { return err } - wallet := libp2p.GetEthAddressFromPubKey(&privKey.PublicKey) + addr := libp2p.GetEthAddressFromPubKey(&key.PublicKey) - fmt.Fprintf(c.App.Writer, "Private key saved to file: %s\n", path) - fmt.Fprintf(c.App.Writer, "Wallet address: %s\n", wallet.Hex()) + fmt.Fprintln(c.App.Writer, "private key saved to file:", path) + fmt.Fprintln(c.App.Writer, "wallet address:", addr.Hex()) return nil } @@ -343,6 +350,12 @@ func launchNodeWithConfig(c *cli.Context) error { natAddr = fmt.Sprintf("%s:%d", c.String(optionNATAddr.Name), c.Int(optionNATPort.Name)) } + crtFile := c.String(serverTLSCertificateFlagName) + keyFile := c.String(serverTLSPrivateKeyFlagName) + if (crtFile == "") != (keyFile == "") { + return fmt.Errorf("both -%s and -%s must be provided to enable TLS", serverTLSCertificateFlagName, serverTLSPrivateKeyFlagName) + } + nd, err := node.NewNode(&node.Options{ KeySigner: keysigner, Secret: c.String(optionSecret.Name), @@ -358,13 +371,15 @@ func launchNodeWithConfig(c *cli.Context) error { BidderRegistryContract: c.String(optionBidderRegistryAddr.Name), RPCEndpoint: c.String(optionSettlementRPCEndpoint.Name), NatAddr: natAddr, + TLSCertificateFile: crtFile, + TLSPrivateKeyFile: keyFile, }) if err != nil { return fmt.Errorf("failed starting node: %w", err) } <-c.Done() - fmt.Fprintf(c.App.Writer, "shutting down...\n") + fmt.Fprintln(c.App.Writer, "shutting down...") closed := make(chan struct{}) go func() { @@ -386,31 +401,23 @@ func launchNodeWithConfig(c *cli.Context) error { } func newLogger(lvl, logFmt string, sink io.Writer) (*slog.Logger, error) { + level := new(slog.LevelVar) + if err := level.UnmarshalText([]byte(lvl)); err != nil { + return nil, fmt.Errorf("invalid log level: %w", err) + } + var ( - level = new(slog.LevelVar) // Info by default handler slog.Handler + options = &slog.HandlerOptions{ + AddSource: true, + Level: level, + } ) - - switch lvl { - case "debug": - level.Set(slog.LevelDebug) - case "info": - level.Set(slog.LevelInfo) - case "warn": - level.Set(slog.LevelWarn) - case "error": - level.Set(slog.LevelError) - default: - return nil, fmt.Errorf("invalid log level: %s", lvl) - } - switch logFmt { case "text": - handler = slog.NewTextHandler(sink, &slog.HandlerOptions{AddSource: true, Level: level}) - case "none": - fallthrough - case "json": - handler = slog.NewJSONHandler(sink, &slog.HandlerOptions{AddSource: true, Level: level}) + handler = slog.NewTextHandler(sink, options) + case "json", "none": + handler = slog.NewJSONHandler(sink, options) default: return nil, fmt.Errorf("invalid log format: %s", logFmt) } diff --git a/examples/provideremulator/client/client.go b/examples/provideremulator/client/client.go index b084e616..40f5a458 100644 --- a/examples/provideremulator/client/client.go +++ b/examples/provideremulator/client/client.go @@ -1,5 +1,5 @@ -// package client implements a simple gRPC client which is to be run by the provider -// in their environment to get a stream of bids that are being gossip'd in the +// Package client implements a simple gRPC client which is to be run by the provider +// in their environment to get a stream of bids that are being gossiped in the // mev-commit network. The providers can then decide to accept or reject the bid. // This status is sent back to the mev-commit node to further take action on the // network. The client can be improved by handling connection failures or using @@ -33,18 +33,17 @@ func NewProviderClient( return nil, err } - client := providerapiv1.NewProviderClient(conn) - b := &ProviderClient{ conn: conn, - client: client, + client: providerapiv1.NewProviderClient(conn), logger: logger, senderC: make(chan *providerapiv1.BidResponse), senderClosed: make(chan struct{}), } - b.startSender() - + if err := b.startSender(); err != nil { + return nil, err + } return b, nil } diff --git a/gen-certificates.sh b/gen-certificates.sh new file mode 100755 index 00000000..4c1a3f20 --- /dev/null +++ b/gen-certificates.sh @@ -0,0 +1,66 @@ +#!/usr/bin/env sh + +# This script automates the creation of a Certificate Authority (CA) and a server +# certificate. It generates a private key and self-signed certificate for the CA, +# then creates a server private key and CSR, and signs the CSR with the CA's +# private key to produce a server certificate. The script requires OpenSSL and optionally +# accepts parameters for the CA subject, server subject, and a server extension config file. +# If the server extension config file is not provided, a default file (server-ext.cnf) +# will be generated with basic constraints and key usage. For more information about the +# configuration file, see: https://www.openssl.org/docs/manmaster/man5/x509v3_config.html + +# Parameters: +# 1. CA Subject (optional): The subject details for the CA certificate. Default: "/C=US/O=Server's CA" +# 2. Server Subject (optional): The subject details for the server certificate. Default: "/C=US/O=Server" +# 3. Server Extension Config File (optional): Path to a configuration file with extensions +# for the server certificate. If not provided, a default configuration is used. + +# Generated files: +# - CA private key (ca-key.pem) +# - CA self-signed certificate (ca-cert.pem) +# - Server private key (server-key.pem) +# - Server CSR (server-req.pem) +# - Server certificate (server-cert.pem) + +# Usage: +# Execute this script with up to three optional arguments: +# ./script.sh [CA Subject] [Server Subject] [Server Extension Config File] +# Example: ./script.sh "/C=US/O=My CA" "/C=US/CN=myserver.example.com" "myserver-ext.cnf" + +# Ensure OpenSSL is installed and accessible, prepare the optional server-ext.cnf file +# with necessary server certificate extensions if desired, and execute this script. +# Verify output for CA and server certificate details. + +# Note: Designed for educational or development purposes. Adapt carefully for production use. + + +CA_KEY="ca-key.pem" +CA_CERT="ca-cert.pem" +SERVER_KEY="server-key.pem" +SERVER_REQ="server-req.pem" +SERVER_CERT="server-cert.pem" +CA_SUBJ=${1:-"/C=US/O=Server's CA"} +SERVER_SUBJ=${2:-"/C=US/O=Server"} +SERVER_EXT=${3:-"server-ext.cnf"} + +# Generate a default server-ext.cnf file if not provided. +if [ ! -f "${SERVER_EXT}" ]; then + echo "No server extension conf provided; generating a default configuration:" +cat << EOH > "${SERVER_EXT}" +basicConstraints = CA:FALSE +keyUsage = digitalSignature, keyEncipherment +EOH + cat "${SERVER_EXT}" +fi + +# Generate CA's private key and self-signed certificate. +openssl req -x509 -newkey rsa:4096 -days 365 -nodes -keyout "${CA_KEY}" -out "${CA_CERT}" -subj "${CA_SUBJ}" +echo "CA's self-signed certificate:" +openssl x509 -in "${CA_CERT}" -noout -text + +# Generate server's private key and certificate request (CSR). +openssl req -newkey rsa:4096 -nodes -keyout "${SERVER_KEY}" -out "${SERVER_REQ}" -subj "${SERVER_SUBJ}" +# Use CA's private key to sign server's CSR and generate the server's certificate. +openssl x509 -req -in "${SERVER_REQ}" -days 365 -CA "${CA_CERT}" -CAkey "${CA_KEY}" -CAcreateserial -out "${SERVER_CERT}" -extfile "${SERVER_EXT}" +echo "Server's CA signed certificate:" +openssl x509 -in "${SERVER_CERT}" -noout -text diff --git a/integration-compose.yml b/integration-compose.yml index 5c4c668d..f4cc688e 100644 --- a/integration-compose.yml +++ b/integration-compose.yml @@ -7,6 +7,7 @@ services: dockerfile: ./integrationtest/Dockerfile args: node_type: bootnode + service_name: bootnode restart: always volumes: - ./integrationtest/keys/bootnode:/key @@ -44,6 +45,7 @@ services: dockerfile: ./integrationtest/Dockerfile args: node_type: provider + service_name: provider1 restart: always depends_on: - bootnode @@ -110,6 +112,7 @@ services: dockerfile: ./integrationtest/Dockerfile args: node_type: provider + service_name: provider2 restart: always depends_on: - bootnode @@ -174,6 +177,7 @@ services: dockerfile: ./integrationtest/Dockerfile args: node_type: provider + service_name: provider3 restart: always depends_on: - bootnode @@ -238,6 +242,7 @@ services: dockerfile: ./integrationtest/Dockerfile args: node_type: bidder + service_name: bidder1 restart: always depends_on: - bootnode @@ -306,6 +311,7 @@ services: dockerfile: ./integrationtest/Dockerfile args: node_type: bidder + service_name: bidder2 restart: always depends_on: - bootnode @@ -372,6 +378,7 @@ services: dockerfile: ./integrationtest/Dockerfile args: node_type: bidder + service_name: bidder3 restart: always depends_on: - bootnode @@ -438,6 +445,7 @@ services: dockerfile: ./integrationtest/Dockerfile args: node_type: bidder + service_name: bidder4 restart: always depends_on: - bootnode @@ -504,6 +512,7 @@ services: dockerfile: ./integrationtest/Dockerfile args: node_type: bidder + service_name: bidder5 restart: always depends_on: - bootnode diff --git a/integrationtest/Dockerfile b/integrationtest/Dockerfile index d70675d9..02a9703c 100644 --- a/integrationtest/Dockerfile +++ b/integrationtest/Dockerfile @@ -1,23 +1,34 @@ -FROM golang:1.21.1 AS builder +FROM alpine:latest AS cert_builder +# A unique service name. +ARG service_name +ENV SERVICE_NAME=${service_name} -WORKDIR /app -COPY . . +RUN apk --no-cache add openssl -ARG node_type +WORKDIR / +COPY gen-certificates.sh . +RUN chmod +x /gen-certificates.sh +RUN /gen-certificates.sh "/C=US/O=${SERVICE_NAME} CA" "/C=US/O=${SERVICE_NAME}" + +FROM golang:1.21.1 AS mev_commit_builder +WORKDIR / +COPY . . RUN CGO_ENABLED=0 GOOS=linux make build -FROM alpine:latest +FROM alpine:latest +# Type of node: (bootnode|bidder|provider). ARG node_type ENV NODE_TYPE=${node_type} -COPY --from=builder /app/bin/mev-commit /app/mev-commit -COPY --from=builder /app/integrationtest/config/${NODE_TYPE}.yaml /config.yaml -COPY --from=builder /app/integrationtest/entrypoint.sh /entrypoint.sh +COPY --from=cert_builder /server-cert.pem /server-cert.pem +COPY --from=cert_builder /server-key.pem /server-key.pem +COPY --from=mev_commit_builder /bin/mev-commit /app/mev-commit +COPY --from=mev_commit_builder /integrationtest/config/${NODE_TYPE}.yaml /config.yaml +COPY --from=mev_commit_builder /integrationtest/entrypoint.sh /entrypoint.sh RUN chmod +x /entrypoint.sh EXPOSE 13522 13523 13524 - ENTRYPOINT ["/entrypoint.sh"] diff --git a/integrationtest/bidder/main.go b/integrationtest/bidder/main.go index 7c44a4e0..d49d88d5 100644 --- a/integrationtest/bidder/main.go +++ b/integrationtest/bidder/main.go @@ -3,6 +3,7 @@ package main import ( "context" cryptorand "crypto/rand" + "crypto/tls" "errors" "flag" "fmt" @@ -22,7 +23,7 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promhttp" "google.golang.org/grpc" - "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/credentials" ) var ( @@ -75,7 +76,16 @@ func main() { return } - logger := newLogger(*logLevel) + level := new(slog.LevelVar) + if err := level.UnmarshalText([]byte(*logLevel)); err != nil { + level.Set(slog.LevelDebug) + fmt.Printf("Invalid log level: %s; using %q", err, level) + } + + logger := slog.New(slog.NewTextHandler( + os.Stdout, + &slog.HandlerOptions{Level: level}, + )) registry := prometheus.NewRegistry() registry.MustRegister( @@ -107,7 +117,11 @@ func main() { conn, err := grpc.Dial( *serverAddr, - grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithTransportCredentials(credentials.NewTLS( + // Integration tests take place in a controlled environment, + // thus we do not expect machine-in-the-middle attacks. + &tls.Config{InsecureSkipVerify: true}, + )), ) if err != nil { logger.Error("failed to connect to server", "err", err) @@ -265,22 +279,3 @@ func sendBid( sendBidSuccessDuration.Set(float64(time.Since(start).Milliseconds())) return nil } - -func newLogger(lvl string) *slog.Logger { - var level = new(slog.LevelVar) // debug by default - - switch lvl { - case "debug": - level.Set(slog.LevelDebug) - case "info": - level.Set(slog.LevelInfo) - case "warn": - level.Set(slog.LevelWarn) - case "error": - level.Set(slog.LevelError) - default: - level.Set(slog.LevelDebug) - } - - return slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: level})) -} diff --git a/integrationtest/config/bidder.yaml b/integrationtest/config/bidder.yaml index 11362a5f..70994abf 100644 --- a/integrationtest/config/bidder.yaml +++ b/integrationtest/config/bidder.yaml @@ -6,6 +6,8 @@ rpc-port: 13524 secret: hello log-fmt: json log-level: debug +server-tls-certificate: /server-cert.pem +server-tls-private-key: /server-key.pem bidder-registry-contract: provider-registry-contract: settlement-rpc-endpoint: diff --git a/integrationtest/config/bootnode.yaml b/integrationtest/config/bootnode.yaml index 890c54f0..a04da47b 100644 --- a/integrationtest/config/bootnode.yaml +++ b/integrationtest/config/bootnode.yaml @@ -6,6 +6,8 @@ rpc-port: 13524 secret: hello log-fmt: json log-level: debug +server-tls-certificate: /server-cert.pem +server-tls-private-key: /server-key.pem bidder-registry-contract: provider-registry-contract: settlement-rpc-endpoint: diff --git a/integrationtest/config/provider.yaml b/integrationtest/config/provider.yaml index 962a543e..56fcb6ba 100644 --- a/integrationtest/config/provider.yaml +++ b/integrationtest/config/provider.yaml @@ -6,6 +6,8 @@ rpc-port: 13524 secret: hello log-fmt: json log-level: debug +server-tls-certificate: /server-cert.pem +server-tls-private-key: /server-key.pem preconf-contract: bidder-registry-contract: provider-registry-contract: diff --git a/integrationtest/entrypoint.sh b/integrationtest/entrypoint.sh index 71e06028..b150b55d 100755 --- a/integrationtest/entrypoint.sh +++ b/integrationtest/entrypoint.sh @@ -1,9 +1,9 @@ #!/bin/sh -echo "Node Type: $NODE_TYPE" +echo "Node Type: ${NODE_TYPE}" # If this is not the bootnode, update the bootnodes entry with P2P ID -if [ "$NODE_TYPE" != "bootnode" ]; then +if [ "${NODE_TYPE}" != "bootnode" ]; then # Wait for a few seconds to ensure the bootnode is up and its API is accessible sleep 10 fi @@ -12,7 +12,7 @@ sed -i "s||${BIDDER_REGISTRY}|" /config.yaml sed -i "s||${PROVIDER_REGISTRY}|" /config.yaml sed -i "s||${RPC_URL}|" /config.yaml -if [ "$NODE_TYPE" == "provider" ]; then +if [ "${NODE_TYPE}" == "provider" ]; then sed -i "s||${PRECONF_CONTRACT}|" /config.yaml fi diff --git a/integrationtest/provider/client.go b/integrationtest/provider/client.go index ccfa0d20..19d2dfb2 100644 --- a/integrationtest/provider/client.go +++ b/integrationtest/provider/client.go @@ -2,6 +2,7 @@ package main import ( "context" + "crypto/tls" "errors" "fmt" "log/slog" @@ -9,7 +10,7 @@ import ( providerapiv1 "github.com/primevprotocol/mev-commit/gen/go/rpc/providerapi/v1" "google.golang.org/grpc" - "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/credentials" ) type ProviderClient struct { @@ -24,7 +25,14 @@ func NewProviderClient( serverAddr string, logger *slog.Logger, ) (*ProviderClient, error) { - conn, err := grpc.Dial(serverAddr, grpc.WithTransportCredentials(insecure.NewCredentials())) + conn, err := grpc.Dial( + serverAddr, + grpc.WithTransportCredentials(credentials.NewTLS( + // Integration tests take place in a controlled environment, + // thus we do not expect machine-in-the-middle attacks. + &tls.Config{InsecureSkipVerify: true}, + )), + ) if err != nil { return nil, err } diff --git a/integrationtest/provider/main.go b/integrationtest/provider/main.go index d5045e02..3cc4beca 100644 --- a/integrationtest/provider/main.go +++ b/integrationtest/provider/main.go @@ -2,6 +2,7 @@ package main import ( "context" + "encoding/json" "errors" "flag" "fmt" @@ -15,16 +16,35 @@ import ( "github.com/prometheus/client_golang/prometheus/promhttp" ) +// The following const block contains the name of the cli flags, especially +// for reuse purposes. +const ( + serverAddrFlagName = "server-addr" + logLevelFlagName = "log-level" + httpPortFlagName = "http-port" + errorProbabilityFlagName = "error-probability" +) + var ( serverAddr = flag.String( - "server-addr", + serverAddrFlagName, "localhost:13524", "The server address in the format of host:port", ) - logLevel = flag.String("log-level", "debug", "Verbosity level (debug|info|warn|error)") - httpPort = flag.Int("http-port", 8080, "The port to serve the HTTP metrics endpoint on") - errorProbablity = flag.Int( - "error-probability", 0, "The probability of returning an error when sending a bid response", + logLevel = flag.String( + logLevelFlagName, + "debug", + "Verbosity level (debug|info|warn|error)", + ) + httpPort = flag.Int( + httpPortFlagName, + 8080, + "The port to serve the HTTP metrics endpoint on", + ) + errorProbability = flag.Int( + errorProbabilityFlagName, + 0, + "The probability of returning an error when sending a bid response", ) ) @@ -52,26 +72,34 @@ var ( func main() { flag.Parse() if *serverAddr == "" { - fmt.Println("Please provide a valid server address with the -serverAddr flag") + fmt.Printf("please provide a valid server address with the -%s flag\n", serverAddrFlagName) return } - logger := newLogger(*logLevel) + level := new(slog.LevelVar) + if err := level.UnmarshalText([]byte(*logLevel)); err != nil { + level.Set(slog.LevelDebug) + fmt.Printf("invalid log level: %s; using %q", err, level) + } + + logger := slog.New(slog.NewTextHandler( + os.Stdout, + &slog.HandlerOptions{Level: level}, + )) registry := prometheus.NewRegistry() registry.MustRegister(receivedBids, sentBids) - router := http.NewServeMux() - router.Handle("/metrics", promhttp.HandlerFor(registry, promhttp.HandlerOpts{})) - - server := &http.Server{ - Addr: fmt.Sprintf(":%d", *httpPort), - Handler: router, - } - go func() { + router := http.NewServeMux() + router.Handle("/metrics", promhttp.HandlerFor(registry, promhttp.HandlerOpts{})) + + server := &http.Server{ + Addr: fmt.Sprintf(":%d", *httpPort), + Handler: router, + } if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { - logger.Error("failed to start server", "err", err) + logger.Error("failed to start server", "error", err) } }() @@ -98,17 +126,21 @@ func main() { for bid := range bidS { receivedBids.Inc() - logger.Info("received new bid", "bid", bidString(bid)) + buf, err := json.Marshal(bid) + if err != nil { + logger.Error("failed to marshal bid", "error", err) + } + logger.Info("received new bid", "bid", string(buf)) status := providerapiv1.BidResponse_STATUS_ACCEPTED - if *errorProbablity > 0 { - if rand.Intn(100) < *errorProbablity { + if *errorProbability > 0 { + if rand.Intn(100) < *errorProbability { logger.Warn("sending error response") status = providerapiv1.BidResponse_STATUS_REJECTED rejectedBids.Inc() } } - err := providerClient.SendBidResponse(context.Background(), &providerapiv1.BidResponse{ + err = providerClient.SendBidResponse(context.Background(), &providerapiv1.BidResponse{ BidDigest: bid.BidDigest, Status: status, }) @@ -120,29 +152,3 @@ func main() { logger.Info("sent bid", "status", status.String()) } } - -func bidString(bid *providerapiv1.Bid) string { - return fmt.Sprintf( - "bid: {txnHashes: %v, block_number: %d, bid_amount: %s, bid_hash: %x}", - bid.TxHashes, bid.BlockNumber, bid.BidAmount, bid.BidDigest, - ) -} - -func newLogger(lvl string) *slog.Logger { - var level = new(slog.LevelVar) // debug by default - - switch lvl { - case "debug": - level.Set(slog.LevelDebug) - case "info": - level.Set(slog.LevelInfo) - case "warn": - level.Set(slog.LevelWarn) - case "error": - level.Set(slog.LevelError) - default: - level.Set(slog.LevelDebug) - } - - return slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: level})) -} diff --git a/integrationtest/real-bidder/main.go b/integrationtest/real-bidder/main.go index 769d33e0..c9e94df8 100644 --- a/integrationtest/real-bidder/main.go +++ b/integrationtest/real-bidder/main.go @@ -2,6 +2,7 @@ package main import ( "context" + "crypto/tls" "errors" "flag" "fmt" @@ -20,7 +21,7 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promhttp" "google.golang.org/grpc" - "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/credentials" ) var ( @@ -69,7 +70,16 @@ func main() { return } - logger := newLogger(*logLevel) + level := new(slog.LevelVar) + if err := level.UnmarshalText([]byte(*logLevel)); err != nil { + level.Set(slog.LevelDebug) + fmt.Printf("Invalid log level: %s; using %q", err, level) + } + + logger := slog.New(slog.NewTextHandler( + os.Stdout, + &slog.HandlerOptions{Level: level}, + )) registry := prometheus.NewRegistry() registry.MustRegister(receivedPreconfs, sentBids) @@ -96,7 +106,11 @@ func main() { conn, err := grpc.Dial( *serverAddr, - grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithTransportCredentials(credentials.NewTLS( + // Integration tests take place in a controlled environment, + // thus we do not expect machine-in-the-middle attacks. + &tls.Config{InsecureSkipVerify: true}, + )), ) if err != nil { logger.Error("failed to connect to server", "err", err) @@ -268,22 +282,3 @@ func sendBid( ).Observe(time.Since(start).Seconds()) return nil } - -func newLogger(lvl string) *slog.Logger { - var level = new(slog.LevelVar) // debug by default - - switch lvl { - case "debug": - level.Set(slog.LevelDebug) - case "info": - level.Set(slog.LevelInfo) - case "warn": - level.Set(slog.LevelWarn) - case "error": - level.Set(slog.LevelError) - default: - level.Set(slog.LevelDebug) - } - - return slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: level})) -} diff --git a/pkg/node/node.go b/pkg/node/node.go index 623e28d5..b83eeecf 100644 --- a/pkg/node/node.go +++ b/pkg/node/node.go @@ -2,6 +2,7 @@ package node import ( "context" + "crypto/tls" "errors" "fmt" "io" @@ -9,6 +10,7 @@ import ( "math/big" "net" "net/http" + "time" "github.com/bufbuild/protovalidate-go" "github.com/ethereum/go-ethereum/common" @@ -33,9 +35,14 @@ import ( "github.com/primevprotocol/mev-commit/pkg/topology" "google.golang.org/grpc" "google.golang.org/grpc/connectivity" + "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" ) +const ( + grpcServerDialTimeout = 5 * time.Second +) + type Options struct { Version string KeySigner keysigner.KeySigner @@ -52,6 +59,8 @@ type Options struct { BidderRegistryContract string RPCEndpoint string NatAddr string + TLSCertificateFile string + TLSPrivateKeyFile string } type Node struct { @@ -137,10 +146,19 @@ func NewNode(opts *Options) (*Node, error) { return nil, errors.Join(err, nd.Close()) } - grpcServer := grpc.NewServer() - preconfSigner := preconfsigner.NewSigner( - opts.KeySigner, - ) + var tlsCredentials credentials.TransportCredentials + if opts.TLSCertificateFile != "" && opts.TLSPrivateKeyFile != "" { + tlsCredentials, err = credentials.NewServerTLSFromFile( + opts.TLSCertificateFile, + opts.TLSPrivateKeyFile, + ) + if err != nil { + return nil, fmt.Errorf("unable to load TLS credentials: %w", err) + } + } + + grpcServer := grpc.NewServer(grpc.Creds(tlsCredentials)) + preconfSigner := preconfsigner.NewSigner(opts.KeySigner) validator, err := protovalidate.New() if err != nil { return nil, errors.Join(err, nd.Close()) @@ -223,35 +241,64 @@ func NewNode(opts *Options) (*Node, error) { // Wait for the server to start <-started - gwMux := runtime.NewServeMux() - bgCtx := context.Background() + // Since we don't know if the server has TLS enabled on its rpc + // endpoint, we try different strategies from most secure to + // least secure. In the future, when only TLS-enabled servers + // are allowed, only the TLS system pool certificate strategy + // should be used. + var grpcConn *grpc.ClientConn + for _, e := range []struct { + strategy string + isSecure bool + credential credentials.TransportCredentials + }{ + {"TLS system pool certificate", true, credentials.NewClientTLSFromCert(nil, "")}, + {"TLS skip verification", false, credentials.NewTLS(&tls.Config{InsecureSkipVerify: true})}, + {"TLS disabled", false, insecure.NewCredentials()}, + } { + ctx, cancel := context.WithTimeout(context.Background(), grpcServerDialTimeout) + opts.Logger.Info("dialing to grpc server", "strategy", e.strategy) + grpcConn, err = grpc.DialContext( + ctx, + opts.RPCAddr, + grpc.WithBlock(), + grpc.WithTransportCredentials(e.credential), + ) + if err != nil { + opts.Logger.Error("failed to dial grpc server", "err", err) + cancel() + continue + } - grpcConn, err := grpc.DialContext( - bgCtx, - opts.RPCAddr, - grpc.WithTransportCredentials(insecure.NewCredentials()), - ) - if err != nil { - opts.Logger.Error("failed to dial grpc server", "err", err) - return nil, errors.Join(err, nd.Close()) + cancel() + if !e.isSecure { + opts.Logger.Warn("established connection with the grpc server has potential security risk") + } + break + } + if grpcConn == nil { + return nil, errors.New("dialing of grpc server failed") } + gatewayMux := runtime.NewServeMux() + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() switch opts.PeerType { case p2p.PeerTypeProvider.String(): - err := providerapiv1.RegisterProviderHandler(bgCtx, gwMux, grpcConn) + err := providerapiv1.RegisterProviderHandler(ctx, gatewayMux, grpcConn) if err != nil { opts.Logger.Error("failed to register provider handler", "err", err) return nil, errors.Join(err, nd.Close()) } case p2p.PeerTypeBidder.String(): - err := bidderapiv1.RegisterBidderHandler(bgCtx, gwMux, grpcConn) + err := bidderapiv1.RegisterBidderHandler(ctx, gatewayMux, grpcConn) if err != nil { opts.Logger.Error("failed to register bidder handler", "err", err) return nil, errors.Join(err, nd.Close()) } } - srv.ChainHandlers("/", gwMux) + srv.ChainHandlers("/", gatewayMux) srv.ChainHandlers( "/health", http.HandlerFunc( @@ -273,7 +320,20 @@ func NewNode(opts *Options) (*Node, error) { } go func() { - if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { + var ( + err error + tlsEnabled = opts.TLSCertificateFile != "" && opts.TLSPrivateKeyFile != "" + ) + opts.Logger.Info("starting to listen", "tls", tlsEnabled) + if tlsEnabled { + err = server.ListenAndServeTLS( + opts.TLSCertificateFile, + opts.TLSPrivateKeyFile, + ) + } else { + err = server.ListenAndServe() + } + if err != nil && !errors.Is(err, http.ErrServerClosed) { opts.Logger.Error("failed to start server", "err", err) } }() @@ -293,7 +353,7 @@ func (n *Node) Close() error { type noOpBidProcessor struct{} -// The noOpBidProcesor auto accepts all bids sent +// ProcessBid auto accepts all bids sent. func (noOpBidProcessor) ProcessBid( _ context.Context, _ *preconfsigner.Bid,