diff --git a/cmd/repeater/repeater.go b/cmd/repeater/repeater.go index c87d125..7ba1621 100644 --- a/cmd/repeater/repeater.go +++ b/cmd/repeater/repeater.go @@ -49,7 +49,7 @@ func main() { log.Fatalf("Error loading mTLS keypair: %v\n", err) os.Exit(1) } - roundtrip.Client = &http.Client{ + roundtrip.MTLSClient = &http.Client{ Transport: &http.Transport{ TLSClientConfig: &tls.Config{ Certificates: []tls.Certificate{cert}, @@ -67,11 +67,19 @@ func main() { insecure := os.Getenv("ESCAPE_REPEATER_INSECURE") if insecure == "1" || insecure == "true" { + if mTLScrt != "" && mTLSkey != "" { + logger.Warn("Insecure SSL flag is enabled, so mTLS will not be used.") + roundtrip.MTLSClient = nil + } + logger.Debug("Allowing insecure ssl connections") - if roundtrip.Client.Transport == nil { - roundtrip.Client.Transport = http.DefaultTransport + roundtrip.DefaultClient = &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + }, } - roundtrip.Client.Transport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: true} } logger.Info("Starting repeater client...") diff --git a/pkg/roundtrip/roudtrip.go b/pkg/roundtrip/roudtrip.go index 4ee5851..b204e55 100644 --- a/pkg/roundtrip/roudtrip.go +++ b/pkg/roundtrip/roudtrip.go @@ -7,7 +7,10 @@ import ( proto "github.com/Escape-Technologies/repeater/proto/repeater/v1" ) -var Client = &http.Client{} +var DefaultClient = &http.Client{} +var MTLSClient *http.Client = nil + +const mTLSHeader = "X-Escape-mTLS" func protoErr(status int, corr int64) *proto.Response { res, err := responseToTransport(&http.Response{ @@ -42,9 +45,23 @@ func HandleRequest(protoReq *proto.Request) *proto.Response { traceroute(protoReq.Url) tls(protoReq.Url) } + client := DefaultClient + mTLS := false + if httpReq.Header.Get(mTLSHeader) != "" { + if MTLSClient != nil { + client = MTLSClient + mTLS = true + } else { + logger.Warn("The current request asked for mTLS but the current configuration does not support it. Falling back to regular TLS.") + } + } - logger.Debug("Sending request (%v)", protoReq.Correlation) - httpRes, err := Client.Do(httpReq) + if mTLS { + logger.Debug("Sending request (%v) with mTLS", protoReq.Correlation) + } else { + logger.Debug("Sending request (%v)", protoReq.Correlation) + } + httpRes, err := client.Do(httpReq) if err != nil { logger.Error("ERROR sending request : %v", err) return protoErr(599, protoReq.Correlation)