Skip to content

Commit

Permalink
Merge pull request #74 from Yelp/u/dpopes/msk-auth-support
Browse files Browse the repository at this point in the history
Support MSK IAM Authentication
  • Loading branch information
danielpops authored Nov 29, 2023
2 parents 5e508ed + c062282 commit 427de2d
Show file tree
Hide file tree
Showing 5 changed files with 391 additions and 3 deletions.
7 changes: 6 additions & 1 deletion kafka/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ class SimpleClient(object):
# socket timeout.
def __init__(self, hosts, client_id=CLIENT_ID,
timeout=DEFAULT_SOCKET_TIMEOUT_SECONDS,
correlation_id=0, metrics=None):
correlation_id=0, metrics=None, **kwargs):
# We need one connection to bootstrap
self.client_id = client_id
self.timeout = timeout
Expand All @@ -90,6 +90,10 @@ def __init__(self, hosts, client_id=CLIENT_ID,
self.topics_to_brokers = {} # TopicPartition -> BrokerMetadata
self.topic_partitions = {} # topic -> partition -> leader

# Support arbitrary kwargs to be provided as config to BrokerConnection
# This will allow advanced features like Authentication to work
self.config = kwargs

self.load_metadata_for_topics() # bootstrap with all metadata

##################
Expand All @@ -108,6 +112,7 @@ def _get_conn(self, host, port, afi, node_id='bootstrap'):
metrics=self._metrics_registry,
metric_group_prefix='simple-client',
node_id=node_id,
**self.config,
)

conn = self._conns[host_key]
Expand Down
50 changes: 49 additions & 1 deletion kafka/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import kafka.errors as Errors
from kafka.future import Future
from kafka.metrics.stats import Avg, Count, Max, Rate
from kafka.msk import AwsMskIamClient
from kafka.oauth.abstract import AbstractTokenProvider
from kafka.protocol.admin import SaslHandShakeRequest
from kafka.protocol.commit import OffsetFetchRequest
Expand Down Expand Up @@ -81,6 +82,12 @@ class SSLWantWriteError(Exception):
gssapi = None
GSSError = None

# needed for AWS_MSK_IAM authentication:
try:
from botocore.session import Session as BotoSession
except ImportError:
# no botocore available, will disable AWS_MSK_IAM mechanism
BotoSession = None

AFI_NAMES = {
socket.AF_UNSPEC: "unspecified",
Expand Down Expand Up @@ -224,7 +231,7 @@ class BrokerConnection(object):
'sasl_oauth_token_provider': None
}
SECURITY_PROTOCOLS = ('PLAINTEXT', 'SSL', 'SASL_PLAINTEXT', 'SASL_SSL')
SASL_MECHANISMS = ('PLAIN', 'GSSAPI', 'OAUTHBEARER')
SASL_MECHANISMS = ('PLAIN', 'GSSAPI', 'OAUTHBEARER', 'AWS_MSK_IAM')

def __init__(self, host, port, afi, **configs):
self.host = host
Expand Down Expand Up @@ -269,6 +276,11 @@ def __init__(self, host, port, afi, **configs):
token_provider = self.config['sasl_oauth_token_provider']
assert token_provider is not None, 'sasl_oauth_token_provider required for OAUTHBEARER sasl'
assert callable(getattr(token_provider, "token", None)), 'sasl_oauth_token_provider must implement method #token()'

if self.config['sasl_mechanism'] == 'AWS_MSK_IAM':
assert BotoSession is not None, 'AWS_MSK_IAM requires the "botocore" package'
assert self.config['security_protocol'] == 'SASL_SSL', 'AWS_MSK_IAM requires SASL_SSL'

# This is not a general lock / this class is not generally thread-safe yet
# However, to avoid pushing responsibility for maintaining
# per-connection locks to the upstream client, we will use this lock to
Expand Down Expand Up @@ -552,6 +564,8 @@ def _handle_sasl_handshake_response(self, future, response):
return self._try_authenticate_gssapi(future)
elif self.config['sasl_mechanism'] == 'OAUTHBEARER':
return self._try_authenticate_oauth(future)
elif self.config['sasl_mechanism'] == 'AWS_MSK_IAM':
return self._try_authenticate_aws_msk_iam(future)
else:
return future.failure(
Errors.UnsupportedSaslMechanismError(
Expand Down Expand Up @@ -652,6 +666,40 @@ def _try_authenticate_plain(self, future):
log.info('%s: Authenticated as %s via PLAIN', self, self.config['sasl_plain_username'])
return future.success(True)

def _try_authenticate_aws_msk_iam(self, future):
session = BotoSession()
client = AwsMskIamClient(
host=self.host,
boto_session=session,
)

msg = client.first_message()
size = Int32.encode(len(msg))

err = None
close = False
with self._lock:
if not self._can_send_recv():
err = Errors.NodeNotReadyError(str(self))
close = False
else:
try:
self._send_bytes_blocking(size + msg)
data = self._recv_bytes_blocking(4)
data = self._recv_bytes_blocking(struct.unpack('4B', data)[-1])
except (ConnectionError, TimeoutError) as e:
log.exception("%s: Error receiving reply from server", self)
err = Errors.KafkaConnectionError("%s: %s" % (self, e))
close = True

if err is not None:
if close:
self.close(error=err)
return future.failure(err)

log.info('%s: Authenticated via AWS_MSK_IAM %s', self, data.decode('utf-8'))
return future.success(True)

def _try_authenticate_gssapi(self, future):
kerberos_damin_name = self.config['sasl_kerberos_domain_name'] or self.host
auth_id = self.config['sasl_kerberos_service_name'] + '@' + kerberos_damin_name
Expand Down
213 changes: 213 additions & 0 deletions kafka/msk.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
import datetime
import hashlib
import hmac
import json
import string

from kafka.errors import IllegalArgumentError
from kafka.vendor.six.moves import urllib


class AwsMskIamClient:
UNRESERVED_CHARS = string.ascii_letters + string.digits + '-._~'

def __init__(self, host, boto_session):
"""
Arguments:
host (str): The hostname of the broker.
boto_session (botocore.BotoSession) the boto session
"""
self.algorithm = 'AWS4-HMAC-SHA256'
self.expires = '900'
self.hashfunc = hashlib.sha256
self.headers = [
('host', host)
]
self.version = '2020_10_22'

self.service = 'kafka-cluster'
self.action = '{}:Connect'.format(self.service)

now = datetime.datetime.utcnow()
self.datestamp = now.strftime('%Y%m%d')
self.timestamp = now.strftime('%Y%m%dT%H%M%SZ')

self.host = host
self.boto_session = boto_session

# This will raise if the region can't be determined
# Do this during init instead of waiting for failures downstream
if self.region:
pass

@property
def access_key(self):
return self.boto_session.get_credentials().access_key

@property
def secret_key(self):
return self.boto_session.get_credentials().secret_key

@property
def token(self):
return self.boto_session.get_credentials().token

@property
def region(self):
# Try to get the region information from the broker hostname
for host in self.host.split(','):
if 'amazonaws.com' in host:
return host.split('.')[-3]

# If the region can't be determined from hostname, try the boto session
# This will only have a value if:
# - `AWS_DEFAULT_REGION` environment variable is set
# - `~/.aws/config` region variable is set
region = self.boto_session.get_config_variable('region')
if region:
return region

# Otherwise give up
raise IllegalArgumentError('Could not determine region from broker host(s) or aws configuration')

@property
def _credential(self):
return '{0.access_key}/{0._scope}'.format(self)

@property
def _scope(self):
return '{0.datestamp}/{0.region}/{0.service}/aws4_request'.format(self)

@property
def _signed_headers(self):
"""
Returns (str):
An alphabetically sorted, semicolon-delimited list of lowercase
request header names.
"""
return ';'.join(sorted(k.lower() for k, _ in self.headers))

@property
def _canonical_headers(self):
"""
Returns (str):
A newline-delited list of header names and values.
Header names are lowercased.
"""
return '\n'.join(map(':'.join, self.headers)) + '\n'

@property
def _canonical_request(self):
"""
Returns (str):
An AWS Signature Version 4 canonical request in the format:
<Method>\n
<Path>\n
<CanonicalQueryString>\n
<CanonicalHeaders>\n
<SignedHeaders>\n
<HashedPayload>
"""
# The hashed_payload is always an empty string for MSK.
hashed_payload = self.hashfunc(b'').hexdigest()
return '\n'.join((
'GET',
'/',
self._canonical_querystring,
self._canonical_headers,
self._signed_headers,
hashed_payload,
))

@property
def _canonical_querystring(self):
"""
Returns (str):
A '&'-separated list of URI-encoded key/value pairs.
"""
params = []
params.append(('Action', self.action))
params.append(('X-Amz-Algorithm', self.algorithm))
params.append(('X-Amz-Credential', self._credential))
params.append(('X-Amz-Date', self.timestamp))
params.append(('X-Amz-Expires', self.expires))
if self.token:
params.append(('X-Amz-Security-Token', self.token))
params.append(('X-Amz-SignedHeaders', self._signed_headers))

return '&'.join(self._uriencode(k) + '=' + self._uriencode(v) for k, v in params)

@property
def _signing_key(self):
"""
Returns (bytes):
An AWS Signature V4 signing key generated from the secret_key, date,
region, service, and request type.
"""
key = self._hmac(('AWS4' + self.secret_key).encode('utf-8'), self.datestamp)
key = self._hmac(key, self.region)
key = self._hmac(key, self.service)
key = self._hmac(key, 'aws4_request')
return key

@property
def _signing_str(self):
"""
Returns (str):
A string used to sign the AWS Signature V4 payload in the format:
<Algorithm>\n
<Timestamp>\n
<Scope>\n
<CanonicalRequestHash>
"""
canonical_request_hash = self.hashfunc(self._canonical_request.encode('utf-8')).hexdigest()
return '\n'.join((self.algorithm, self.timestamp, self._scope, canonical_request_hash))

def _uriencode(self, msg):
"""
Arguments:
msg (str): A string to URI-encode.
Returns (str):
The URI-encoded version of the provided msg, following the encoding
rules specified: https://github.com/aws/aws-msk-iam-auth#uriencode
"""
return urllib.parse.quote(msg, safe=self.UNRESERVED_CHARS)

def _hmac(self, key, msg):
"""
Arguments:
key (bytes): A key to use for the HMAC digest.
msg (str): A value to include in the HMAC digest.
Returns (bytes):
An HMAC digest of the given key and msg.
"""
return hmac.new(key, msg.encode('utf-8'), digestmod=self.hashfunc).digest()

def first_message(self):
"""
Returns (bytes):
An encoded JSON authentication payload that can be sent to the
broker.
"""
signature = hmac.new(
self._signing_key,
self._signing_str.encode('utf-8'),
digestmod=self.hashfunc,
).hexdigest()
msg = {
'version': self.version,
'host': self.host,
'user-agent': 'kafka-python',
'action': self.action,
'x-amz-algorithm': self.algorithm,
'x-amz-credential': self._credential,
'x-amz-date': self.timestamp,
'x-amz-signedheaders': self._signed_headers,
'x-amz-expires': self.expires,
'x-amz-signature': signature,
}
if self.token:
msg['x-amz-security-token'] = self.token

return json.dumps(msg, separators=(',', ':')).encode('utf-8')
2 changes: 1 addition & 1 deletion kafka/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '1.4.7.post4'
__version__ = '1.4.7.post5'
Loading

0 comments on commit 427de2d

Please sign in to comment.