From 6fca61d52ec5abb06bde141737c4f3748bf59533 Mon Sep 17 00:00:00 2001 From: Michal Budzyn Date: Sat, 15 Feb 2025 23:33:59 +0100 Subject: [PATCH] Allow assume role for AWS_MSK_IAM --- README.md | 2 ++ cmd/kafka-proxy/server.go | 2 ++ config/config.go | 6 +++-- proxy/sasl_aws_msk_iam.go | 56 ++++++++++++++++++++++++++------------- 4 files changed, 45 insertions(+), 21 deletions(-) diff --git a/README.md b/README.md index 48ff168d..1565e904 100644 --- a/README.md +++ b/README.md @@ -191,8 +191,10 @@ You can launch a kafka-proxy container with auth-ldap plugin for trying it out w --proxy-listener-write-buffer-size int Sets the size of the operating system's transmit buffer associated with the connection. If zero, system default is used --proxy-request-buffer-size int Request buffer size pro tcp connection (default 4096) --proxy-response-buffer-size int Response buffer size pro tcp connection (default 4096) + --sasl-aws-identity-lookup Verify AWS authentication identity --sasl-aws-profile string AWS profile --sasl-aws-region string Region for AWS IAM Auth + --sasl-aws-role-arn string AWS Role ARN to assume --sasl-enable Connect using SASL --sasl-jaas-config-file string Location of JAAS config file with SASL username and password --sasl-method string SASL method to use (PLAIN, SCRAM-SHA-256, SCRAM-SHA-512, GSSAPI, AWS_MSK_IAM (default "PLAIN") diff --git a/cmd/kafka-proxy/server.go b/cmd/kafka-proxy/server.go index 452c4f56..4311ef3e 100644 --- a/cmd/kafka-proxy/server.go +++ b/cmd/kafka-proxy/server.go @@ -187,6 +187,8 @@ func initFlags() { // SASL AWS_MSK_IAM Server.Flags().StringVar(&c.Kafka.SASL.AWSConfig.Region, "sasl-aws-region", "", "Region for AWS IAM Auth") Server.Flags().StringVar(&c.Kafka.SASL.AWSConfig.Profile, "sasl-aws-profile", "", "AWS profile") + Server.Flags().StringVar(&c.Kafka.SASL.AWSConfig.RoleArn, "sasl-aws-role-arn", "", "AWS Role ARN to assume") + Server.Flags().BoolVar(&c.Kafka.SASL.AWSConfig.IdentityLookup, "sasl-aws-identity-lookup", false, "Verify AWS authentication identity") // SASL by Proxy plugin Server.Flags().BoolVar(&c.Kafka.SASL.Plugin.Enable, "sasl-plugin-enable", false, "Use plugin for SASL authentication") diff --git a/config/config.go b/config/config.go index f9529c91..fbcb0599 100644 --- a/config/config.go +++ b/config/config.go @@ -48,8 +48,10 @@ type GSSAPIConfig struct { } type AWSConfig struct { - Region string - Profile string + Region string + Profile string + RoleArn string + IdentityLookup bool } type Config struct { diff --git a/proxy/sasl_aws_msk_iam.go b/proxy/sasl_aws_msk_iam.go index f70cdf73..8edb3aa8 100644 --- a/proxy/sasl_aws_msk_iam.go +++ b/proxy/sasl_aws_msk_iam.go @@ -4,15 +4,18 @@ import ( "bytes" "context" "encoding/binary" + "errors" "fmt" "io" "net" "time" + "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials/stscreds" + "github.com/aws/aws-sdk-go-v2/service/sts" proxyconfig "github.com/grepplabs/kafka-proxy/config" "github.com/grepplabs/kafka-proxy/proxy/protocol" - "github.com/pkg/errors" "github.com/sirupsen/logrus" ) @@ -44,6 +47,21 @@ func NewAwsMSKIamAuth( if err != nil { return nil, fmt.Errorf("loading aws config: %v", err) } + if awsConfig.RoleArn != "" { + stsClient := sts.NewFromConfig(cfg) + assumeRoleProvider := stscreds.NewAssumeRoleProvider(stsClient, awsConfig.RoleArn) + cfg.Credentials = aws.NewCredentialsCache(assumeRoleProvider) + } + if awsConfig.IdentityLookup { + ctx := context.Background() + ctx, cancel := context.WithTimeout(ctx, 15*time.Second) + defer cancel() + output, err := sts.NewFromConfig(cfg).GetCallerIdentity(ctx, &sts.GetCallerIdentityInput{}) + if err != nil { + return nil, fmt.Errorf("failed to get caller identity: %v", err) + } + logrus.Infof("AWS_MSK_IAM caller identity %v", output.Arn) + } return &AwsMSKIamAuth{ clientID: clientId, signer: newMechanism(cfg), @@ -56,11 +74,11 @@ func NewAwsMSKIamAuth( // sendAndReceiveSASLAuth handles the entire SASL authentication process func (a *AwsMSKIamAuth) sendAndReceiveSASLAuth(conn DeadlineReaderWriter, brokerString string) error { if err := a.saslHandshake(conn); err != nil { - return errors.Wrap(err, "handshake failed") + return fmt.Errorf("handshake failed: %w", err) } if err := a.saslAuthenticate(conn, brokerString); err != nil { - return errors.Wrap(err, "authenticate failed") + return fmt.Errorf("authenticate failed: %w", err) } return nil @@ -76,21 +94,21 @@ func (a *AwsMSKIamAuth) saslHandshake(conn DeadlineReaderWriter) error { Body: rb, } if err := a.write(conn, req); err != nil { - return errors.Wrap(err, "writing SASL handshake") + return fmt.Errorf("writing SASL handshake: %w", err) } payload, err := a.read(conn) if err != nil { - return errors.Wrap(err, "reading SASL handshake") + return fmt.Errorf("reading SASL handshake: %w", err) } res := &protocol.SaslHandshakeResponseV0orV1{} if err := protocol.Decode(payload, res); err != nil { - return errors.Wrap(err, "parsing SASL handshake response") + return fmt.Errorf("parsing SASL handshake response: %w", err) } - if res.Err != protocol.ErrNoError { - return errors.Wrap(res.Err, "sasl handshake protocol error") + if !errors.Is(res.Err, protocol.ErrNoError) { + return fmt.Errorf("sasl handshake protocol error: %w", res.Err) } logrus.Debugf("Successful IAM SASL handshake. Available mechanisms: %v", res.EnabledMechanisms) return nil @@ -114,21 +132,21 @@ func (a *AwsMSKIamAuth) saslAuthenticate(conn DeadlineReaderWriter, brokerString Body: saslAuthReqV0, } if err := a.write(conn, req); err != nil { - return errors.Wrap(err, "writing SASL authentication request") + return fmt.Errorf("writing SASL authentication request: %w", err) } payload, err := a.read(conn) if err != nil { - return errors.Wrap(err, "reading SASL authentication response") + return fmt.Errorf("reading SASL authentication response: %w", err) } res := &protocol.SaslAuthenticateResponseV0{} err = protocol.Decode(payload, res) if err != nil { - return errors.Wrap(err, "parsing SASL authentication response") + return fmt.Errorf("parsing SASL authentication response: %w", err) } - if res.Err != protocol.ErrNoError { - return errors.Wrap(res.Err, "sasl authentication protocol error") + if !errors.Is(res.Err, protocol.ErrNoError) { + return fmt.Errorf("sasl authentication protocol error: %w", res.Err) } return nil } @@ -136,37 +154,37 @@ func (a *AwsMSKIamAuth) saslAuthenticate(conn DeadlineReaderWriter, brokerString func (a *AwsMSKIamAuth) write(conn DeadlineReaderWriter, req *protocol.Request) error { reqBuf, err := protocol.Encode(req) if err != nil { - return errors.Wrap(err, "serializing request") + return fmt.Errorf("serializing request: %w", err) } sizeBuf := make([]byte, 4) binary.BigEndian.PutUint32(sizeBuf, uint32(len(reqBuf))) if err := conn.SetWriteDeadline(time.Now().Add(a.writeTimeout)); err != nil { - return errors.Wrap(err, "setting write deadline") + return fmt.Errorf("setting write deadline: %w", err) } if _, err := conn.Write(bytes.Join([][]byte{sizeBuf, reqBuf}, nil)); err != nil { - return errors.Wrap(err, "writing bytes") + return fmt.Errorf("writing bytes: %w", err) } return nil } func (a *AwsMSKIamAuth) read(conn DeadlineReaderWriter) ([]byte, error) { if err := conn.SetReadDeadline(time.Now().Add(a.readTimeout)); err != nil { - return nil, errors.Wrap(err, "setting read deadline") + return nil, fmt.Errorf("setting read deadline: %w", err) } //wait for the handshake response header := make([]byte, 8) // response header if _, err := io.ReadFull(conn, header); err != nil { - return nil, errors.Wrap(err, "reading header") + return nil, fmt.Errorf("reading header: %w", err) } length := binary.BigEndian.Uint32(header[:4]) payload := make([]byte, length-4) if _, err := io.ReadFull(conn, payload); err != nil { - return nil, errors.Wrap(err, "reading payload") + return nil, fmt.Errorf("reading payload: %w", err) } return payload, nil