Skip to content

Commit

Permalink
Allow assume role for AWS_MSK_IAM
Browse files Browse the repository at this point in the history
  • Loading branch information
everesio committed Feb 15, 2025
1 parent f5e4d2d commit 6fca61d
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 21 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -191,8 +191,10 @@ You can launch a kafka-proxy container with auth-ldap plugin for trying it out w
--proxy-listener-write-buffer-size int Sets the size of the operating system's transmit buffer associated with the connection. If zero, system default is used
--proxy-request-buffer-size int Request buffer size pro tcp connection (default 4096)
--proxy-response-buffer-size int Response buffer size pro tcp connection (default 4096)
--sasl-aws-identity-lookup Verify AWS authentication identity
--sasl-aws-profile string AWS profile
--sasl-aws-region string Region for AWS IAM Auth
--sasl-aws-role-arn string AWS Role ARN to assume
--sasl-enable Connect using SASL
--sasl-jaas-config-file string Location of JAAS config file with SASL username and password
--sasl-method string SASL method to use (PLAIN, SCRAM-SHA-256, SCRAM-SHA-512, GSSAPI, AWS_MSK_IAM (default "PLAIN")
Expand Down
2 changes: 2 additions & 0 deletions cmd/kafka-proxy/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,8 @@ func initFlags() {
// SASL AWS_MSK_IAM
Server.Flags().StringVar(&c.Kafka.SASL.AWSConfig.Region, "sasl-aws-region", "", "Region for AWS IAM Auth")
Server.Flags().StringVar(&c.Kafka.SASL.AWSConfig.Profile, "sasl-aws-profile", "", "AWS profile")
Server.Flags().StringVar(&c.Kafka.SASL.AWSConfig.RoleArn, "sasl-aws-role-arn", "", "AWS Role ARN to assume")
Server.Flags().BoolVar(&c.Kafka.SASL.AWSConfig.IdentityLookup, "sasl-aws-identity-lookup", false, "Verify AWS authentication identity")

// SASL by Proxy plugin
Server.Flags().BoolVar(&c.Kafka.SASL.Plugin.Enable, "sasl-plugin-enable", false, "Use plugin for SASL authentication")
Expand Down
6 changes: 4 additions & 2 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,10 @@ type GSSAPIConfig struct {
}

type AWSConfig struct {
Region string
Profile string
Region string
Profile string
RoleArn string
IdentityLookup bool
}

type Config struct {
Expand Down
56 changes: 37 additions & 19 deletions proxy/sasl_aws_msk_iam.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,18 @@ import (
"bytes"
"context"
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"time"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials/stscreds"
"github.com/aws/aws-sdk-go-v2/service/sts"
proxyconfig "github.com/grepplabs/kafka-proxy/config"
"github.com/grepplabs/kafka-proxy/proxy/protocol"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
)

Expand Down Expand Up @@ -44,6 +47,21 @@ func NewAwsMSKIamAuth(
if err != nil {
return nil, fmt.Errorf("loading aws config: %v", err)
}
if awsConfig.RoleArn != "" {
stsClient := sts.NewFromConfig(cfg)
assumeRoleProvider := stscreds.NewAssumeRoleProvider(stsClient, awsConfig.RoleArn)
cfg.Credentials = aws.NewCredentialsCache(assumeRoleProvider)
}
if awsConfig.IdentityLookup {
ctx := context.Background()
ctx, cancel := context.WithTimeout(ctx, 15*time.Second)
defer cancel()
output, err := sts.NewFromConfig(cfg).GetCallerIdentity(ctx, &sts.GetCallerIdentityInput{})
if err != nil {
return nil, fmt.Errorf("failed to get caller identity: %v", err)
}
logrus.Infof("AWS_MSK_IAM caller identity %v", output.Arn)
}
return &AwsMSKIamAuth{
clientID: clientId,
signer: newMechanism(cfg),
Expand All @@ -56,11 +74,11 @@ func NewAwsMSKIamAuth(
// sendAndReceiveSASLAuth handles the entire SASL authentication process
func (a *AwsMSKIamAuth) sendAndReceiveSASLAuth(conn DeadlineReaderWriter, brokerString string) error {
if err := a.saslHandshake(conn); err != nil {
return errors.Wrap(err, "handshake failed")
return fmt.Errorf("handshake failed: %w", err)
}

if err := a.saslAuthenticate(conn, brokerString); err != nil {
return errors.Wrap(err, "authenticate failed")
return fmt.Errorf("authenticate failed: %w", err)
}

return nil
Expand All @@ -76,21 +94,21 @@ func (a *AwsMSKIamAuth) saslHandshake(conn DeadlineReaderWriter) error {
Body: rb,
}
if err := a.write(conn, req); err != nil {
return errors.Wrap(err, "writing SASL handshake")
return fmt.Errorf("writing SASL handshake: %w", err)
}

payload, err := a.read(conn)
if err != nil {
return errors.Wrap(err, "reading SASL handshake")
return fmt.Errorf("reading SASL handshake: %w", err)
}

res := &protocol.SaslHandshakeResponseV0orV1{}
if err := protocol.Decode(payload, res); err != nil {
return errors.Wrap(err, "parsing SASL handshake response")
return fmt.Errorf("parsing SASL handshake response: %w", err)
}

if res.Err != protocol.ErrNoError {
return errors.Wrap(res.Err, "sasl handshake protocol error")
if !errors.Is(res.Err, protocol.ErrNoError) {
return fmt.Errorf("sasl handshake protocol error: %w", res.Err)
}
logrus.Debugf("Successful IAM SASL handshake. Available mechanisms: %v", res.EnabledMechanisms)
return nil
Expand All @@ -114,59 +132,59 @@ func (a *AwsMSKIamAuth) saslAuthenticate(conn DeadlineReaderWriter, brokerString
Body: saslAuthReqV0,
}
if err := a.write(conn, req); err != nil {
return errors.Wrap(err, "writing SASL authentication request")
return fmt.Errorf("writing SASL authentication request: %w", err)
}

payload, err := a.read(conn)
if err != nil {
return errors.Wrap(err, "reading SASL authentication response")
return fmt.Errorf("reading SASL authentication response: %w", err)
}

res := &protocol.SaslAuthenticateResponseV0{}
err = protocol.Decode(payload, res)
if err != nil {
return errors.Wrap(err, "parsing SASL authentication response")
return fmt.Errorf("parsing SASL authentication response: %w", err)
}
if res.Err != protocol.ErrNoError {
return errors.Wrap(res.Err, "sasl authentication protocol error")
if !errors.Is(res.Err, protocol.ErrNoError) {
return fmt.Errorf("sasl authentication protocol error: %w", res.Err)
}
return nil
}

func (a *AwsMSKIamAuth) write(conn DeadlineReaderWriter, req *protocol.Request) error {
reqBuf, err := protocol.Encode(req)
if err != nil {
return errors.Wrap(err, "serializing request")
return fmt.Errorf("serializing request: %w", err)
}

sizeBuf := make([]byte, 4)
binary.BigEndian.PutUint32(sizeBuf, uint32(len(reqBuf)))

if err := conn.SetWriteDeadline(time.Now().Add(a.writeTimeout)); err != nil {
return errors.Wrap(err, "setting write deadline")
return fmt.Errorf("setting write deadline: %w", err)
}

if _, err := conn.Write(bytes.Join([][]byte{sizeBuf, reqBuf}, nil)); err != nil {
return errors.Wrap(err, "writing bytes")
return fmt.Errorf("writing bytes: %w", err)
}
return nil
}

func (a *AwsMSKIamAuth) read(conn DeadlineReaderWriter) ([]byte, error) {
if err := conn.SetReadDeadline(time.Now().Add(a.readTimeout)); err != nil {
return nil, errors.Wrap(err, "setting read deadline")
return nil, fmt.Errorf("setting read deadline: %w", err)
}

//wait for the handshake response
header := make([]byte, 8) // response header
if _, err := io.ReadFull(conn, header); err != nil {
return nil, errors.Wrap(err, "reading header")
return nil, fmt.Errorf("reading header: %w", err)
}

length := binary.BigEndian.Uint32(header[:4])
payload := make([]byte, length-4)
if _, err := io.ReadFull(conn, payload); err != nil {
return nil, errors.Wrap(err, "reading payload")
return nil, fmt.Errorf("reading payload: %w", err)
}

return payload, nil
Expand Down

0 comments on commit 6fca61d

Please sign in to comment.