Skip to content

Commit

Permalink
Use errors.Is
Browse files Browse the repository at this point in the history
  • Loading branch information
everesio committed Feb 11, 2025
1 parent 9ff9bec commit e4998c0
Showing 1 changed file with 15 additions and 16 deletions.
31 changes: 15 additions & 16 deletions proxy/sasl_by_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@ import (
"bytes"
"context"
"encoding/binary"
"errors"
"fmt"
"io"
"time"

"github.com/grepplabs/kafka-proxy/pkg/apis"
"github.com/grepplabs/kafka-proxy/proxy/protocol"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
)

Expand Down Expand Up @@ -100,7 +100,7 @@ func (b *SASLPlainAuth) sendSaslAuthenticateRequest(conn DeadlineReaderWriter) e
}
_, err = conn.Write(authBytes)
if err != nil {
return errors.Wrap(err, "Failed to write SASL auth header")
return fmt.Errorf("failed to write SASL auth header: %w", err)
}

err = conn.SetReadDeadline(time.Now().Add(b.readTimeout))
Expand All @@ -116,7 +116,7 @@ func (b *SASLPlainAuth) sendSaslAuthenticateRequest(conn DeadlineReaderWriter) e
if err == io.EOF {
return fmt.Errorf("SASL/PLAIN auth for user %s failed", b.username)
}
return errors.Wrap(err, "Failed to read response while authenticating with SASL")
return fmt.Errorf("failed to read response while authenticating with SASL: %w", err)
}
return nil
}
Expand All @@ -141,7 +141,7 @@ func (b *SASLHandshake) sendAndReceiveHandshake(conn DeadlineReaderWriter) error

_, err = conn.Write(bytes.Join([][]byte{sizeBuf, reqBuf}, nil))
if err != nil {
return errors.Wrap(err, "Failed to send SASL handshake")
return fmt.Errorf("failed to send SASL handshake: %w", err)
}

err = conn.SetReadDeadline(time.Now().Add(b.readTimeout))
Expand All @@ -153,23 +153,22 @@ func (b *SASLHandshake) sendAndReceiveHandshake(conn DeadlineReaderWriter) error
header := make([]byte, 8) // response header
_, err = io.ReadFull(conn, header)
if err != nil {
return errors.Wrap(err, "Failed to read SASL handshake header")
return fmt.Errorf("failed to read SASL handshake header: %w", err)
}
length := binary.BigEndian.Uint32(header[:4])
payload := make([]byte, length-4)
_, err = io.ReadFull(conn, payload)
if err != nil {
return errors.Wrap(err, "Failed to read SASL handshake payload")
return fmt.Errorf("failed to read SASL handshake payload: %w", err)
}
res := &protocol.SaslHandshakeResponseV0orV1{}
err = protocol.Decode(payload, res)
if err != nil {
return errors.Wrap(err, "Failed to parse SASL handshake")
return fmt.Errorf("failed to parse SASL handshake: %w", err)
}
if res.Err != protocol.ErrNoError {
return errors.Wrap(res.Err, "Invalid SASL Mechanism")
if !errors.Is(res.Err, protocol.ErrNoError) {
return fmt.Errorf("invalid SASL Mechanism: %w", res.Err)
}

logrus.Debugf("Successful SASL handshake. Available mechanisms: %v", res.EnabledMechanisms)
return nil
}
Expand Down Expand Up @@ -231,7 +230,7 @@ func (b *SASLOAuthBearerAuth) sendSaslAuthenticateRequest(token string, conn Dea

_, err = conn.Write(bytes.Join([][]byte{sizeBuf, reqBuf}, nil))
if err != nil {
return errors.Wrap(err, "Failed to send SASL auth request")
return fmt.Errorf("failed to send SASL auth request: %w", err)
}

err = conn.SetReadDeadline(time.Now().Add(b.readTimeout))
Expand All @@ -243,22 +242,22 @@ func (b *SASLOAuthBearerAuth) sendSaslAuthenticateRequest(token string, conn Dea
header := make([]byte, 8) // response header
_, err = io.ReadFull(conn, header)
if err != nil {
return errors.Wrap(err, "Failed to read SASL auth header")
return fmt.Errorf("failed to read SASL auth header: %w", err)
}
length := binary.BigEndian.Uint32(header[:4])
payload := make([]byte, length-4)
_, err = io.ReadFull(conn, payload)
if err != nil {
return errors.Wrap(err, "Failed to read SASL auth payload")
return fmt.Errorf("failed to read SASL auth payload: %w", err)
}

res := &protocol.SaslAuthenticateResponseV0{}
err = protocol.Decode(payload, res)
if err != nil {
return errors.Wrap(err, "Failed to parse SASL auth response")
return fmt.Errorf("failed to parse SASL auth response: %w", err)
}
if res.Err != protocol.ErrNoError {
return errors.Wrapf(res.Err, "SASL authentication failed, error message is '%v'", res.ErrMsg)
if !errors.Is(res.Err, protocol.ErrNoError) {
return fmt.Errorf("SASL authentication failed, error message is '%v'", res.ErrMsg)
}
return nil
}

0 comments on commit e4998c0

Please sign in to comment.