Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a new field to Operation and log client info from connection #468

Merged
merged 2 commits into from
Jul 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -127,5 +127,5 @@ snapshot:
docker run --rm --privileged -v $(PWD):/go/tmp \
-v /var/run/docker.sock:/var/run/docker.sock \
-w /go/tmp \
ghcr.io/goreleaser/goreleaser-cross:latest --clean --snapshot --skip-publish
ghcr.io/goreleaser/goreleaser-cross:latest --clean --snapshot --skip publish

4 changes: 3 additions & 1 deletion protocol/protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -414,19 +414,21 @@ type Operation struct {
ServerIP net.IP
SNI string
CertID string
ForwardingSvc int64
CustomFuncName string
JaegerSpan []byte
ReqContext []byte
}

func (o *Operation) String() string {
return fmt.Sprintf("[Opcode: %v, SKI: %v, Digest: %02x, Client IP: %s, Server IP: %s, SNI: %s]",
return fmt.Sprintf("[Opcode: %v, SKI: %v, Digest: %02x, Client IP: %s, Server IP: %s, SNI: %s, Forwarding Service: %v]",
o.Opcode,
o.SKI,
o.Digest,
o.ClientIP,
o.ServerIP,
o.SNI,
o.ForwardingSvc,
)
}

Expand Down
107 changes: 76 additions & 31 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"net"
"net/rpc"
"os"
"strings"
"sync"
"time"

Expand Down Expand Up @@ -152,6 +153,12 @@ type Sealer interface {
Unseal(*protocol.Operation) ([]byte, error)
}

// ClientInfo has information on the client of the connection
type ClientInfo struct {
Name string
CertSerial string
}

// handler is associated with a connection and contains bookkeeping
// information used across goroutines. The channel tokens limits the
// concurrency: before reading a request a token is extracted, when
Expand All @@ -166,6 +173,7 @@ type handler struct {
conn net.Conn
timeout time.Duration
closed bool
c *ClientInfo
}

func (h *handler) close(err error) {
Expand Down Expand Up @@ -197,6 +205,12 @@ func (h *handler) handle(pkt *protocol.Packet, reqTime time.Time) {
} else {
resp = h.s.unlimitedDo(pkt, h.name)
}

if resp.op.ErrorVal() != protocol.ErrNone {
// log the client certificate information on the connection if the request failed so the caller is apparent
reqID, _ := getOperationRequestID(&pkt.Operation)
log.Errorf("operation from client %s client cert serial: %s errored. sni %s ski %s cert %s request-id %s", h.c.Name, h.c.CertSerial, resp.op.SNI, resp.op.SKI.String(), resp.op.CertID, reqID)
}
logRequestExecDuration(pkt.Operation.Opcode, start, resp.op.ErrorVal())
respPkt := protocol.Packet{
Header: protocol.Header{
Expand Down Expand Up @@ -289,32 +303,61 @@ func makeErrResponse(pkt *protocol.Packet, err protocol.Error) response {
func addOperationRequestID(op *protocol.Operation) string {
reqContext := make(map[string]interface{})
var reqID string
var gen bool

if len(op.ReqContext) > 0 {
if err := json.Unmarshal(op.ReqContext, &reqContext); err == nil {
if v, ok := reqContext["request_id"]; ok {
return v.(string)
} else {
gen = true
}
} else {
log.Errorf("malformed operation.ReqContext %v, ignoring error", op.ReqContext)
if decodeErr := json.Unmarshal(op.ReqContext, &reqContext); decodeErr != nil {
log.Error(fmt.Errorf("malformed operation.ReqContext %v, ignoring error", op.ReqContext))
return reqID
}
}

if v, ok := reqContext["request_id"]; ok {
return v.(string)
}

reqID = uuid.New().String()
reqContext["request_id"] = reqID
b, err := json.Marshal(reqContext)
if err == nil {
op.ReqContext = b
} else {
log.Errorf("error marshaling operation.ReqContext %v, ignoring error", reqContext)
reqID = ""
}
return reqID
}

func getOperationRequestID(op *protocol.Operation) (reqID string, err error) {
reqContext := make(map[string]interface{})
if len(op.ReqContext) == 0 {
return
}
if decodeErr := json.Unmarshal(op.ReqContext, &reqContext); decodeErr == nil {
if v, ok := reqContext["request_id"]; ok {
return v.(string), nil
}
} else {
err = fmt.Errorf("malformed operation.ReqContext %v, ignoring error", op.ReqContext)
log.Error(err)
return
}
return
}

if len(op.ReqContext) == 0 || gen {
reqID = uuid.New().String()
reqContext["request_id"] = reqID
b, err := json.Marshal(reqContext)
if err == nil {
op.ReqContext = b
func getClientInfoFromCerts(certs []*x509.Certificate) *ClientInfo {
cln := []string(nil)
srls := []string(nil)
for _, cert := range certs {
if cert.Subject.CommonName != "" {
cln = append(cln, cert.Subject.CommonName)
} else {
log.Errorf("error marshaling operation.ReqContext %v, ignoring error", reqContext)
reqID = ""
cln = append(cln, cert.DNSNames...)
}
srls = append(srls, cert.SerialNumber.String())
}
return reqID
name := strings.Join(cln, " , ")
serial := strings.Join(srls, " , ")
return &ClientInfo{Name: name, CertSerial: serial}
}

func (s *Server) unlimitedDo(pkt *protocol.Packet, connName string) response {
Expand All @@ -328,7 +371,7 @@ func (s *Server) unlimitedDo(pkt *protocol.Packet, connName string) response {
reqID := addOperationRequestID(&pkt.Operation)
span.SetTag("request_id", reqID)

log.Debugf("connection %s: limited=false opcode=%s id=%d sni=%s ip=%s ski=%v request-id=%s",
log.Debugf("connection %s: limited=false opcode= %s id=%d sni= %s ip= %s ski= %v request-id= %s",
connName,
pkt.Operation.Opcode,
pkt.Header.ID,
Expand Down Expand Up @@ -412,14 +455,14 @@ func (s *Server) unlimitedDo(pkt *protocol.Packet, connName string) response {

sig, err := key.Sign(rand.Reader, pkt.Operation.Payload, crypto.Hash(0))
if err != nil {
log.Errorf("Connection: %s: sni=%s ski=%v request-id=%s: Signing error: %v: request-id:%s:", connName, pkt.Operation.SNI, pkt.Operation.SKI, reqID, protocol.ErrCrypto, err, reqID)
log.Errorf("Connection: %s: sni= %s ski= %v request-id= %s: Signing error: %v", connName, pkt.Operation.SNI, pkt.Operation.SKI, reqID, protocol.ErrCrypto, err)
// This indicates that a remote keyserver is being used
var remoteConfigurationErr RemoteConfigurationErr
if errors.As(err, &remoteConfigurationErr) {
log.Errorf("Connection %v: sni=%s ski=%v request-id=%s: %s: Signing error: %v request-id:%s\n", connName, pkt.Operation.SNI, pkt.Operation.SKI, reqID, protocol.ErrRemoteConfiguration, err, reqID)
log.Errorf("Connection %v: sni= %s ski= %v request-id= %s: %s: Signing error: %v\n", connName, pkt.Operation.SNI, pkt.Operation.SKI, reqID, protocol.ErrRemoteConfiguration, err)
return makeErrResponse(pkt, protocol.ErrRemoteConfiguration)
} else {
log.Errorf("Connection %v: sni=%s ski=%v request-id=%s: %s: Signing error: %v request-id:%s\n", connName, pkt.Operation.SNI, pkt.Operation.SKI, reqID, protocol.ErrCrypto, err, reqID)
log.Errorf("Connection %v: sni= %s ski= %v request-id= %s: %s: Signing error: %v\n", connName, pkt.Operation.SNI, pkt.Operation.SKI, reqID, protocol.ErrCrypto, err)
return makeErrResponse(pkt, protocol.ErrCrypto)
}
}
Expand All @@ -430,23 +473,23 @@ func (s *Server) unlimitedDo(pkt *protocol.Packet, connName string) response {
key, err := s.keys.Get(ctx, &pkt.Operation)
logKeyLoadDuration(loadStart)
if err != nil {
log.Errorf("failed to load key with sni=%s ip=%s ski=%v request-id=%s: %v", pkt.Operation.SNI, pkt.Operation.ServerIP, pkt.Operation.SKI, reqID, err)
log.Errorf("failed to load key with sni= %s ip= %s ski=%v request-id= %s: %v", pkt.Operation.SNI, pkt.Operation.ServerIP, pkt.Operation.SKI, reqID, err)
return makeErrResponse(pkt, protocol.ErrInternal)
} else if key == nil {
log.Errorf("failed to load key with sni=%s ip=%s ski=%v request-id=%s: %v", pkt.Operation.SNI, pkt.Operation.ServerIP, pkt.Operation.SKI, reqID, protocol.ErrKeyNotFound)
log.Errorf("failed to load key with sni= %s ip= %s ski= %v request-id= %s: %v", pkt.Operation.SNI, pkt.Operation.ServerIP, pkt.Operation.SKI, reqID, protocol.ErrKeyNotFound)
return makeErrResponse(pkt, protocol.ErrKeyNotFound)
}

if _, ok := key.Public().(*rsa.PublicKey); !ok {
log.Errorf("Connection %v: sni=%s request-id=%s: %s: Key is not RSA", connName, pkt.Operation.SNI, reqID, protocol.ErrCrypto)
log.Errorf("Connection %v: sni= %s request-id= %s: %s: Key is not RSA", connName, pkt.Operation.SNI, reqID, protocol.ErrCrypto)
return makeErrResponse(pkt, protocol.ErrCrypto)
}

if rsaKey, ok := key.(*rsa.PrivateKey); ok {
// Decrypt without removing padding; that's the client's responsibility.
ptxt, err := textbook_rsa.Decrypt(rsaKey, pkt.Operation.Payload)
if err != nil {
log.Errorf("connection %v: sni=%s ip=%s ski=%v request-id=%s: %v", connName, pkt.Operation.SNI, pkt.Operation.ServerIP, pkt.Operation.SKI, reqID, err)
log.Errorf("connection %v: sni= %s ip= %s ski= %v request-id= %s: %v", connName, pkt.Operation.SNI, pkt.Operation.ServerIP, pkt.Operation.SKI, reqID, err)
return makeErrResponse(pkt, protocol.ErrCrypto)
}
return makeRespondResponse(pkt, ptxt)
Expand Down Expand Up @@ -493,10 +536,10 @@ func (s *Server) unlimitedDo(pkt *protocol.Packet, connName string) response {
key, err := s.keys.Get(ctx, &pkt.Operation)
logKeyLoadDuration(loadStart)
if err != nil {
log.Errorf("failed to load key with sni=%s ip=%s ski=%v request-id=%s: %v", pkt.Operation.SNI, pkt.Operation.ServerIP, pkt.Operation.SKI, reqID, err)
log.Errorf("failed to load key with sni= %s ip= %s ski= %v request-id= %s: %v", pkt.Operation.SNI, pkt.Operation.ServerIP, pkt.Operation.SKI, reqID, err)
return makeErrResponse(pkt, protocol.ErrInternal)
} else if key == nil {
log.Errorf("failed to load key with sni=%s ip=%s ski=%v request-id=%s: %v", pkt.Operation.SNI, pkt.Operation.ServerIP, pkt.Operation.SKI, reqID, protocol.ErrKeyNotFound)
log.Errorf("failed to load key with sni= %s ip= %s ski= %v request-id= %s: %v", pkt.Operation.SNI, pkt.Operation.ServerIP, pkt.Operation.SKI, reqID, protocol.ErrKeyNotFound)
return makeErrResponse(pkt, protocol.ErrKeyNotFound)
}

Expand Down Expand Up @@ -526,17 +569,17 @@ func (s *Server) unlimitedDo(pkt *protocol.Packet, connName string) response {
}
if err != nil {
if attempts > 1 {
log.Debugf("Connection %v sni=%s ip=%s ski=%v request-id=%s : failed sign attempt: %s, %d attempt(s) left", connName, pkt.Operation.SNI, pkt.Operation.ServerIP, pkt.Operation.SKI, reqID, err, attempts-1)
log.Debugf("Connection %v sni= %s ip= %s ski= %v request-id= %s : failed sign attempt: %s, %d attempt(s) left", connName, pkt.Operation.SNI, pkt.Operation.ServerIP, pkt.Operation.SKI, reqID, err, attempts-1)
continue
} else {
tracing.LogError(span, err)
// This indicates that a remote keyserver is being used
var remoteConfigurationErr RemoteConfigurationErr
if errors.As(err, &remoteConfigurationErr) {
log.Errorf("Connection %v sni=%s ip=%s ski=%v request-id=%s : %s: Signing error: %v\n", connName, pkt.Operation.SNI, pkt.Operation.ServerIP, pkt.Operation.SKI, reqID, protocol.ErrRemoteConfiguration, err)
log.Errorf("Connection %v sni= %s ip= %s ski= %v request-id= %s : %s: Signing error: %v\n", connName, pkt.Operation.SNI, pkt.Operation.ServerIP, pkt.Operation.SKI, reqID, protocol.ErrRemoteConfiguration, err)
return makeErrResponse(pkt, protocol.ErrRemoteConfiguration)
} else {
log.Errorf("Connection %v sni=%s ip=%s ski=%v request-id=%s : %s: Signing error: %v\n", connName, pkt.Operation.SNI, pkt.Operation.ServerIP, pkt.Operation.SKI, reqID, protocol.ErrCrypto, err)
log.Errorf("Connection %v sni= %s ip= %s ski= %v request-id= %s : %s: Signing error: %v\n", connName, pkt.Operation.SNI, pkt.Operation.ServerIP, pkt.Operation.SKI, reqID, protocol.ErrCrypto, err)
return makeErrResponse(pkt, protocol.ErrCrypto)
}
}
Expand Down Expand Up @@ -656,6 +699,7 @@ func (s *Server) spawn(l net.Listener, c net.Conn) {
}
connState := tconn.ConnectionState()
certmetrics.Observe(certmetrics.CertSourceFromCerts(fmt.Sprintf("listener: %s", l.Addr().String()), connState.PeerCertificates)...)
cl := getClientInfoFromCerts(connState.PeerCertificates)
limited, err := s.config.isLimited(connState)
if err != nil {
log.Errorf("connection %v: could not determine if limited: %v", c.RemoteAddr(), err)
Expand Down Expand Up @@ -692,6 +736,7 @@ func (s *Server) spawn(l net.Listener, c net.Conn) {
conn: tconn,
listener: l,
timeout: timeout,
c: cl,
}
err = handler.loop()

Expand Down
1 change: 1 addition & 0 deletions tracing/tracing.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ func SetOperationSpanTags(span opentracing.Span, op *protocol.Operation) {
"operation.sni": op.SNI,
"operation.certid": op.CertID,
"operation.customfuncname": op.CustomFuncName,
"operation.forwardingsvc": fmt.Sprintf("%d", op.ForwardingSvc),
}
for k, v := range tags {
span.SetTag(k, v)
Expand Down
Loading