diff --git a/.covrc b/.covrc deleted file mode 100644 index 43c5fd7af..000000000 --- a/.covrc +++ /dev/null @@ -1,3 +0,0 @@ -[run] -omit = - kafka/vendor/* diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index df0e4e489..c5bd66218 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -111,7 +111,7 @@ jobs: needs: - build-sdist runs-on: ubuntu-latest - timeout-minutes: 10 + timeout-minutes: 15 strategy: fail-fast: false matrix: diff --git a/CHANGES.md b/CHANGES.md index ccec6b5c3..ba40007f9 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,3 +1,8 @@ +# 2.0.3 (under development) + +Consumer +* KIP-345: Implement static membership support + # 2.0.2 (Sep 29, 2020) Consumer diff --git a/Makefile b/Makefile index 9d7d89f4d..399b69653 100644 --- a/Makefile +++ b/Makefile @@ -29,7 +29,7 @@ test-local: build-integration cov-local: build-integration KAFKA_VERSION=$(KAFKA_VERSION) SCALA_VERSION=$(SCALA_VERSION) pytest \ --pylint --pylint-rcfile=pylint.rc --pylint-error-types=EF --cov=kafka \ - --cov-config=.covrc --cov-report html $(FLAGS) kafka test + --cov-report html $(FLAGS) kafka test @echo "open file://`pwd`/htmlcov/index.html" # Check the readme for syntax errors, which can lead to invalid formatting on diff --git a/README.rst b/README.rst index b7acfc8a2..ce82c6d3b 100644 --- a/README.rst +++ b/README.rst @@ -64,8 +64,12 @@ that expose basic message attributes: topic, partition, offset, key, and value: .. code-block:: python + # join a consumer group for dynamic partition assignment and offset commits from kafka import KafkaConsumer - consumer = KafkaConsumer('my_favorite_topic') + consumer = KafkaConsumer('my_favorite_topic', group_id='my_favorite_group') + # or as a static member with a fixed group member name + # consumer = KafkaConsumer('my_favorite_topic', group_id='my_favorite_group', + # group_instance_id='consumer-1', leave_group_on_close=False) for msg in consumer: print (msg) diff --git a/benchmarks/consumer_performance.py b/benchmarks/consumer_performance.py index 9e3b6a919..19231f24a 100755 --- a/benchmarks/consumer_performance.py +++ b/benchmarks/consumer_performance.py @@ -10,8 +10,6 @@ import threading import traceback -from kafka.vendor.six.moves import range - from kafka import KafkaConsumer, KafkaProducer from test.fixtures import KafkaFixture, ZookeeperFixture diff --git a/benchmarks/producer_performance.py b/benchmarks/producer_performance.py index c0de6fd23..d000955d3 100755 --- a/benchmarks/producer_performance.py +++ b/benchmarks/producer_performance.py @@ -9,8 +9,6 @@ import threading import traceback -from kafka.vendor.six.moves import range - from kafka import KafkaProducer from test.fixtures import KafkaFixture, ZookeeperFixture diff --git a/benchmarks/varint_speed.py b/benchmarks/varint_speed.py index fd63d0ac1..83ca1c6e0 100644 --- a/benchmarks/varint_speed.py +++ b/benchmarks/varint_speed.py @@ -1,7 +1,5 @@ #!/usr/bin/env python -from __future__ import print_function import pyperf -from kafka.vendor import six test_data = [ @@ -67,6 +65,10 @@ BENCH_VALUES_DEC = list(map(bytearray, BENCH_VALUES_DEC)) +def int2byte(i): + return bytes((i),) + + def _assert_valid_enc(enc_func): for encoded, decoded in test_data: assert enc_func(decoded) == encoded, decoded @@ -116,7 +118,7 @@ def encode_varint_1(num): _assert_valid_enc(encode_varint_1) -def encode_varint_2(value, int2byte=six.int2byte): +def encode_varint_2(value, int2byte=int2byte): value = (value << 1) ^ (value >> 63) bits = value & 
0x7f @@ -151,7 +153,7 @@ def encode_varint_3(value, buf): assert res == encoded -def encode_varint_4(value, int2byte=six.int2byte): +def encode_varint_4(value, int2byte=int2byte): value = (value << 1) ^ (value >> 63) if value <= 0x7f: # 1 byte @@ -301,22 +303,13 @@ def size_of_varint_2(value): _assert_valid_size(size_of_varint_2) -if six.PY3: - def _read_byte(memview, pos): - """ Read a byte from memoryview as an integer - - Raises: - IndexError: if position is out of bounds - """ - return memview[pos] -else: - def _read_byte(memview, pos): - """ Read a byte from memoryview as an integer +def _read_byte(memview, pos): + """ Read a byte from memoryview as an integer Raises: IndexError: if position is out of bounds """ - return ord(memview[pos]) + return memview[pos] def decode_varint_1(buffer, pos=0): diff --git a/docs/changelog.rst b/docs/changelog.rst index 9d3cb6512..67013247b 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -1,6 +1,13 @@ Changelog ========= +2.2.0 +#################### + +Consumer +-------- +* KIP-345: Implement static membership support + 2.0.2 (Sep 29, 2020) #################### diff --git a/docs/usage.rst b/docs/usage.rst index 047bbad77..dbc8813f0 100644 --- a/docs/usage.rst +++ b/docs/usage.rst @@ -47,6 +47,18 @@ KafkaConsumer group_id='my-group', bootstrap_servers='my.server.com') + # Use multiple static consumers w/ 2.3.0 kafka brokers + consumer1 = KafkaConsumer('my-topic', + group_id='my-group', + group_instance_id='process-1', + leave_group_on_close=False, + bootstrap_servers='my.server.com') + consumer2 = KafkaConsumer('my-topic', + group_id='my-group', + group_instance_id='process-2', + leave_group_on_close=False, + bootstrap_servers='my.server.com') + There are many configuration options for the consumer class. See :class:`~kafka.KafkaConsumer` API documentation for more details. 
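The README and docs/usage.rst hunks above introduce static group membership (KIP-345). A minimal sketch of how the two new options work together at shutdown, assuming brokers at 2.3.0+ and the option names added in this change; the session_timeout_ms value is only an arbitrary illustration:

from kafka import KafkaConsumer

def run_static_member():
    consumer = KafkaConsumer(
        'my-topic',
        group_id='my-group',
        group_instance_id='process-1',   # stable identity reused across restarts
        leave_group_on_close=False,      # close() will not send a LeaveGroup request
        session_timeout_ms=30000,        # broker evicts the static member only after this timeout
        bootstrap_servers='my.server.com')
    try:
        for msg in consumer:
            print(msg.topic, msg.partition, msg.offset)
    finally:
        # A clean shutdown does not trigger a rebalance; a replacement process
        # that rejoins with the same group_instance_id before session_timeout_ms
        # expires keeps the same partition assignment.
        consumer.close()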
diff --git a/kafka/__init__.py b/kafka/__init__.py index d5e30affa..685593ce3 100644 --- a/kafka/__init__.py +++ b/kafka/__init__.py @@ -1,5 +1,3 @@ -from __future__ import absolute_import - __title__ = 'kafka' from kafka.version import __version__ __author__ = 'Dana Powers' diff --git a/kafka/admin/__init__.py b/kafka/admin/__init__.py index c240fc6d0..c67fb9e6a 100644 --- a/kafka/admin/__init__.py +++ b/kafka/admin/__init__.py @@ -1,5 +1,3 @@ -from __future__ import absolute_import - from kafka.admin.config_resource import ConfigResource, ConfigResourceType from kafka.admin.client import KafkaAdminClient from kafka.admin.acl_resource import (ACL, ACLFilter, ResourcePattern, ResourcePatternFilter, ACLOperation, diff --git a/kafka/admin/acl_resource.py b/kafka/admin/acl_resource.py index fd997a10a..4bf36baaa 100644 --- a/kafka/admin/acl_resource.py +++ b/kafka/admin/acl_resource.py @@ -1,12 +1,6 @@ -from __future__ import absolute_import -from kafka.errors import IllegalArgumentError +from enum import IntEnum -# enum in stdlib as of py3.4 -try: - from enum import IntEnum # pylint: disable=import-error -except ImportError: - # vendored backport module - from kafka.vendor.enum34 import IntEnum +from kafka.errors import IllegalArgumentError class ResourceType(IntEnum): @@ -69,7 +63,7 @@ class ACLResourcePatternType(IntEnum): PREFIXED = 4 -class ACLFilter(object): +class ACLFilter: """Represents a filter to use with describing and deleting ACLs The difference between this class and the ACL class is mainly that @@ -161,7 +155,7 @@ def __init__( permission_type, resource_pattern ): - super(ACL, self).__init__(principal, host, operation, permission_type, resource_pattern) + super().__init__(principal, host, operation, permission_type, resource_pattern) self.validate() def validate(self): @@ -173,7 +167,7 @@ def validate(self): raise IllegalArgumentError("resource_pattern must be a ResourcePattern object") -class ResourcePatternFilter(object): +class ResourcePatternFilter: def __init__( self, resource_type, @@ -232,7 +226,7 @@ def __init__( resource_name, pattern_type=ACLResourcePatternType.LITERAL ): - super(ResourcePattern, self).__init__(resource_type, resource_name, pattern_type) + super().__init__(resource_type, resource_name, pattern_type) self.validate() def validate(self): @@ -240,5 +234,5 @@ def validate(self): raise IllegalArgumentError("resource_type cannot be ANY") if self.pattern_type in [ACLResourcePatternType.ANY, ACLResourcePatternType.MATCH]: raise IllegalArgumentError( - "pattern_type cannot be {} on a concrete ResourcePattern".format(self.pattern_type.name) + f"pattern_type cannot be {self.pattern_type.name} on a concrete ResourcePattern" ) diff --git a/kafka/admin/client.py b/kafka/admin/client.py index 8eb7504a7..5b01f8fe6 100644 --- a/kafka/admin/client.py +++ b/kafka/admin/client.py @@ -1,12 +1,9 @@ -from __future__ import absolute_import - from collections import defaultdict import copy import logging import socket from . 
import ConfigResourceType -from kafka.vendor import six from kafka.admin.acl_resource import ACLOperation, ACLPermissionType, ACLFilter, ACL, ResourcePattern, ResourceType, \ ACLResourcePatternType @@ -20,7 +17,7 @@ from kafka.protocol.admin import ( CreateTopicsRequest, DeleteTopicsRequest, DescribeConfigsRequest, AlterConfigsRequest, CreatePartitionsRequest, ListGroupsRequest, DescribeGroupsRequest, DescribeAclsRequest, CreateAclsRequest, DeleteAclsRequest, - DeleteGroupsRequest + DeleteGroupsRequest, DescribeLogDirsRequest ) from kafka.protocol.commit import GroupCoordinatorRequest, OffsetFetchRequest from kafka.protocol.metadata import MetadataRequest @@ -32,7 +29,7 @@ log = logging.getLogger(__name__) -class KafkaAdminClient(object): +class KafkaAdminClient: """A class for administering the Kafka cluster. Warning: @@ -194,7 +191,7 @@ def __init__(self, **configs): log.debug("Starting KafkaAdminClient with configuration: %s", configs) extra_configs = set(configs).difference(self.DEFAULT_CONFIG) if extra_configs: - raise KafkaConfigurationError("Unrecognized configs: {}".format(extra_configs)) + raise KafkaConfigurationError(f"Unrecognized configs: {extra_configs}") self.config = copy.copy(self.DEFAULT_CONFIG) self.config.update(configs) @@ -874,7 +871,7 @@ def describe_configs(self, config_resources, include_synonyms=False): )) else: raise NotImplementedError( - "Support for DescribeConfigs v{} has not yet been added to KafkaAdminClient.".format(version)) + f"Support for DescribeConfigs v{version} has not yet been added to KafkaAdminClient.") self._wait_for_futures(futures) return [f.value for f in futures] @@ -1197,7 +1194,7 @@ def _list_consumer_group_offsets_send_request(self, group_id, topics_partitions_dict = defaultdict(set) for topic, partition in partitions: topics_partitions_dict[topic].add(partition) - topics_partitions = list(six.iteritems(topics_partitions_dict)) + topics_partitions = list(topics_partitions_dict.items()) request = OffsetFetchRequest[version](group_id, topics_partitions) else: raise NotImplementedError( @@ -1345,3 +1342,19 @@ def _wait_for_futures(self, futures): if future.failed(): raise future.exception # pylint: disable-msg=raising-bad-type + + def describe_log_dirs(self): + """Send a DescribeLogDirsRequest request to a broker. + + :return: A message future + """ + version = self._matching_api_version(DescribeLogDirsRequest) + if version <= 1: + request = DescribeLogDirsRequest[version]() + future = self._send_request_to_node(self._client.least_loaded_node(), request) + self._wait_for_futures([future]) + else: + raise NotImplementedError( + "Support for DescribeLogDirsRequest_v{} has not yet been added to KafkaAdminClient." + .format(version)) + return future.value diff --git a/kafka/admin/config_resource.py b/kafka/admin/config_resource.py index e3294c9c4..55a2818ea 100644 --- a/kafka/admin/config_resource.py +++ b/kafka/admin/config_resource.py @@ -1,11 +1,4 @@ -from __future__ import absolute_import - -# enum in stdlib as of py3.4 -try: - from enum import IntEnum # pylint: disable=import-error -except ImportError: - # vendored backport module - from kafka.vendor.enum34 import IntEnum +from enum import IntEnum class ConfigResourceType(IntEnum): @@ -15,7 +8,7 @@ class ConfigResourceType(IntEnum): TOPIC = 2 -class ConfigResource(object): +class ConfigResource: """A class for specifying config resources. 
Arguments: resource_type (ConfigResourceType): the type of kafka resource diff --git a/kafka/admin/new_partitions.py b/kafka/admin/new_partitions.py index 429b2e190..613fb861e 100644 --- a/kafka/admin/new_partitions.py +++ b/kafka/admin/new_partitions.py @@ -1,7 +1,4 @@ -from __future__ import absolute_import - - -class NewPartitions(object): +class NewPartitions: """A class for new partition creation on existing topics. Note that the length of new_assignments, if specified, must be the difference between the new total number of partitions and the existing number of partitions. Arguments: diff --git a/kafka/admin/new_topic.py b/kafka/admin/new_topic.py index 645ac383a..a50c3a374 100644 --- a/kafka/admin/new_topic.py +++ b/kafka/admin/new_topic.py @@ -1,9 +1,7 @@ -from __future__ import absolute_import - from kafka.errors import IllegalArgumentError -class NewTopic(object): +class NewTopic: """ A class for new topic creation Arguments: name (string): name of the topic diff --git a/kafka/client_async.py b/kafka/client_async.py index 530a1f441..b395dc5da 100644 --- a/kafka/client_async.py +++ b/kafka/client_async.py @@ -1,23 +1,13 @@ -from __future__ import absolute_import, division - import collections import copy import logging import random +import selectors import socket import threading import time import weakref -# selectors in stdlib as of py3.4 -try: - import selectors # pylint: disable=import-error -except ImportError: - # vendored backport module - from kafka.vendor import selectors34 as selectors - -from kafka.vendor import six - from kafka.cluster import ClusterMetadata from kafka.conn import BrokerConnection, ConnectionStates, collect_hosts, get_ip_port_afi from kafka import errors as Errors @@ -27,19 +17,12 @@ from kafka.metrics.stats.rate import TimeUnit from kafka.protocol.metadata import MetadataRequest from kafka.util import Dict, WeakMethod -# Although this looks unused, it actually monkey-patches socket.socketpair() -# and should be left in as long as we're using socket.socketpair() in this file -from kafka.vendor import socketpair from kafka.version import __version__ -if six.PY2: - ConnectionError = None - - log = logging.getLogger('kafka.client') -class KafkaClient(object): +class KafkaClient: """ A network client for asynchronous request/response network I/O. @@ -154,6 +137,8 @@ class KafkaClient(object): sasl mechanism handshake. Default: one of bootstrap servers sasl_oauth_token_provider (AbstractTokenProvider): OAuthBearer token provider instance. (See kafka.oauth.abstract). Default: None + raise_upon_socket_err_during_wakeup (bool): If set to True, raise an exception + upon socket error during wakeup(). 
Default: False """ DEFAULT_CONFIG = { @@ -192,7 +177,8 @@ class KafkaClient(object): 'sasl_plain_password': None, 'sasl_kerberos_service_name': 'kafka', 'sasl_kerberos_domain_name': None, - 'sasl_oauth_token_provider': None + 'sasl_oauth_token_provider': None, + 'raise_upon_socket_err_during_wakeup': False } def __init__(self, **configs): @@ -243,6 +229,8 @@ def __init__(self, **configs): check_timeout = self.config['api_version_auto_timeout_ms'] / 1000 self.config['api_version'] = self.check_version(timeout=check_timeout) + self._raise_upon_socket_err_during_wakeup = self.config['raise_upon_socket_err_during_wakeup'] + def _can_bootstrap(self): effective_failures = self._bootstrap_fails // self._num_bootstrap_hosts backoff_factor = 2 ** effective_failures @@ -369,7 +357,7 @@ def _maybe_connect(self, node_id): if conn is None: broker = self.cluster.broker_metadata(node_id) - assert broker, 'Broker id %s not in current metadata' % (node_id,) + assert broker, 'Broker id {} not in current metadata'.format(node_id) log.debug("Initiating connection to node %s at %s:%s", node_id, broker.host, broker.port) @@ -681,7 +669,7 @@ def _poll(self, timeout): unexpected_data = key.fileobj.recv(1) if unexpected_data: # anything other than a 0-byte read means protocol issues log.warning('Protocol out of sync on %r, closing', conn) - except socket.error: + except OSError: pass conn.close(Errors.KafkaConnectionError('Socket EVENT_READ without in-flight-requests')) continue @@ -696,7 +684,7 @@ def _poll(self, timeout): if conn not in processed and conn.connected() and conn._sock.pending(): self._pending_completion.extend(conn.recv()) - for conn in six.itervalues(self._conns): + for conn in self._conns.values(): if conn.requests_timed_out(): log.warning('%s timed out after %s ms. Closing connection.', conn, conn.config['request_timeout_ms']) @@ -936,15 +924,17 @@ def wakeup(self): except socket.timeout: log.warning('Timeout to send to wakeup socket!') raise Errors.KafkaTimeoutError() - except socket.error: + except OSError as e: log.warning('Unable to send to wakeup socket!') + if self._raise_upon_socket_err_during_wakeup: + raise e def _clear_wake_fd(self): # reading from wake socket should only happen in a single thread while True: try: self._wake_r.recv(1024) - except socket.error: + except OSError: break def _maybe_close_oldest_connection(self): @@ -974,7 +964,7 @@ def bootstrap_connected(self): OrderedDict = dict -class IdleConnectionManager(object): +class IdleConnectionManager: def __init__(self, connections_max_idle_ms): if connections_max_idle_ms > 0: self.connections_max_idle = connections_max_idle_ms / 1000 @@ -1036,7 +1026,7 @@ def poll_expired_connection(self): return None -class KafkaClientMetrics(object): +class KafkaClientMetrics: def __init__(self, metrics, metric_group_prefix, conns): self.metrics = metrics self.metric_group_name = metric_group_prefix + '-metrics' diff --git a/kafka/cluster.py b/kafka/cluster.py index 438baf29d..ee1fe79a0 100644 --- a/kafka/cluster.py +++ b/kafka/cluster.py @@ -1,13 +1,9 @@ -from __future__ import absolute_import - import collections import copy import logging import threading import time -from kafka.vendor import six - from kafka import errors as Errors from kafka.conn import collect_hosts from kafka.future import Future @@ -16,7 +12,7 @@ log = logging.getLogger(__name__) -class ClusterMetadata(object): +class ClusterMetadata: """ A class to manage kafka cluster metadata. 
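client_async.py above adds raise_upon_socket_err_during_wakeup so a failed wakeup() can surface instead of being logged and swallowed. A small sketch of the intended difference, assuming a reachable bootstrap server; KafkaClient is used directly because the option lives in its DEFAULT_CONFIG:

from kafka import KafkaClient

client = KafkaClient(bootstrap_servers='my.server.com',
                     raise_upon_socket_err_during_wakeup=True)
try:
    client.wakeup()
except OSError:
    # Reachable only because raise_upon_socket_err_during_wakeup=True;
    # with the default (False) the OSError is logged as a warning and dropped.
    raise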
@@ -128,9 +124,9 @@ def available_partitions_for_topic(self, topic): """ if topic not in self._partitions: return None - return set([partition for partition, metadata - in six.iteritems(self._partitions[topic]) - if metadata.leader != -1]) + return {partition for partition, metadata + in self._partitions[topic].items() + if metadata.leader != -1} def leader_for_partition(self, partition): """Return node_id of leader, -1 unavailable, None if unknown.""" @@ -361,7 +357,7 @@ def add_group_coordinator(self, group, response): # Use a coordinator-specific node id so that group requests # get a dedicated connection - node_id = 'coordinator-{}'.format(response.coordinator_id) + node_id = f'coordinator-{response.coordinator_id}' coordinator = BrokerMetadata( node_id, response.host, diff --git a/kafka/codec.py b/kafka/codec.py index c740a181c..3a9982c19 100644 --- a/kafka/codec.py +++ b/kafka/codec.py @@ -1,13 +1,8 @@ -from __future__ import absolute_import - import gzip import io import platform import struct -from kafka.vendor import six -from kafka.vendor.six.moves import range - _XERIAL_V1_HEADER = (-126, b'S', b'N', b'A', b'P', b'P', b'Y', 0, 1, 1) _XERIAL_V1_FORMAT = 'bccccccBii' ZSTD_MAX_OUTPUT_SIZE = 1024 * 1024 @@ -149,10 +144,6 @@ def snappy_encode(payload, xerial_compatible=True, xerial_blocksize=32*1024): # buffer... likely a python-snappy bug, so just use a slice copy chunker = lambda payload, i, size: payload[i:size+i] - elif six.PY2: - # Sliced buffer avoids additional copies - # pylint: disable-msg=undefined-variable - chunker = lambda payload, i, size: buffer(payload, i, size) else: # snappy.compress does not like raw memoryviews, so we have to convert # tobytes, which is a copy... oh well. it's the thought that counts. diff --git a/kafka/conn.py b/kafka/conn.py index 80f17009c..5a73ba429 100644 --- a/kafka/conn.py +++ b/kafka/conn.py @@ -1,46 +1,31 @@ -from __future__ import absolute_import, division - import copy import errno -import io import logging +import selectors from random import shuffle, uniform -# selectors in stdlib as of py3.4 -try: - import selectors # pylint: disable=import-error -except ImportError: - # vendored backport module - from kafka.vendor import selectors34 as selectors - import socket -import struct import threading import time -from kafka.vendor import six - +from kafka import sasl import kafka.errors as Errors from kafka.future import Future from kafka.metrics.stats import Avg, Count, Max, Rate -from kafka.oauth.abstract import AbstractTokenProvider -from kafka.protocol.admin import SaslHandShakeRequest, DescribeAclsRequest_v2, DescribeClientQuotasRequest +from kafka.protocol.admin import ( + DescribeAclsRequest_v2, + DescribeClientQuotasRequest, + SaslHandShakeRequest, +) from kafka.protocol.commit import OffsetFetchRequest from kafka.protocol.offset import OffsetRequest from kafka.protocol.produce import ProduceRequest from kafka.protocol.metadata import MetadataRequest from kafka.protocol.fetch import FetchRequest from kafka.protocol.parser import KafkaProtocol -from kafka.protocol.types import Int32, Int8 -from kafka.scram import ScramClient from kafka.version import __version__ -if six.PY2: - ConnectionError = socket.error - TimeoutError = socket.error - BlockingIOError = Exception - log = logging.getLogger(__name__) DEFAULT_KAFKA_PORT = 9092 @@ -83,6 +68,12 @@ class SSLWantWriteError(Exception): gssapi = None GSSError = None +# needed for AWS_MSK_IAM authentication: +try: + from botocore.session import Session as BotoSession +except ImportError: + 
# no botocore available, will disable AWS_MSK_IAM mechanism + BotoSession = None AFI_NAMES = { socket.AF_UNSPEC: "unspecified", @@ -91,7 +82,7 @@ class SSLWantWriteError(Exception): } -class ConnectionStates(object): +class ConnectionStates: DISCONNECTING = '' DISCONNECTED = '' CONNECTING = '' @@ -100,7 +91,7 @@ class ConnectionStates(object): AUTHENTICATING = '' -class BrokerConnection(object): +class BrokerConnection: """Initialize a Kafka broker connection Keyword Arguments: @@ -227,7 +218,6 @@ class BrokerConnection(object): 'sasl_oauth_token_provider': None } SECURITY_PROTOCOLS = ('PLAINTEXT', 'SSL', 'SASL_PLAINTEXT', 'SASL_SSL') - SASL_MECHANISMS = ('PLAIN', 'GSSAPI', 'OAUTHBEARER', "SCRAM-SHA-256", "SCRAM-SHA-512") def __init__(self, host, port, afi, **configs): self.host = host @@ -256,26 +246,19 @@ def __init__(self, host, port, afi, **configs): assert self.config['security_protocol'] in self.SECURITY_PROTOCOLS, ( 'security_protocol must be in ' + ', '.join(self.SECURITY_PROTOCOLS)) + if self.config['security_protocol'] in ('SSL', 'SASL_SSL'): assert ssl_available, "Python wasn't built with SSL support" + if self.config['sasl_mechanism'] == 'AWS_MSK_IAM': + assert BotoSession is not None, 'AWS_MSK_IAM requires the "botocore" package' + assert self.config['security_protocol'] == 'SASL_SSL', 'AWS_MSK_IAM requires SASL_SSL' + if self.config['security_protocol'] in ('SASL_PLAINTEXT', 'SASL_SSL'): - assert self.config['sasl_mechanism'] in self.SASL_MECHANISMS, ( - 'sasl_mechanism must be in ' + ', '.join(self.SASL_MECHANISMS)) - if self.config['sasl_mechanism'] in ('PLAIN', 'SCRAM-SHA-256', 'SCRAM-SHA-512'): - assert self.config['sasl_plain_username'] is not None, ( - 'sasl_plain_username required for PLAIN or SCRAM sasl' - ) - assert self.config['sasl_plain_password'] is not None, ( - 'sasl_plain_password required for PLAIN or SCRAM sasl' - ) - if self.config['sasl_mechanism'] == 'GSSAPI': - assert gssapi is not None, 'GSSAPI lib not available' - assert self.config['sasl_kerberos_service_name'] is not None, 'sasl_kerberos_service_name required for GSSAPI sasl' - if self.config['sasl_mechanism'] == 'OAUTHBEARER': - token_provider = self.config['sasl_oauth_token_provider'] - assert token_provider is not None, 'sasl_oauth_token_provider required for OAUTHBEARER sasl' - assert callable(getattr(token_provider, "token", None)), 'sasl_oauth_token_provider must implement method #token()' + assert self.config['sasl_mechanism'] in sasl.MECHANISMS, ( + 'sasl_mechanism must be one of {}'.format(', '.join(sasl.MECHANISMS.keys())) + ) + sasl.MECHANISMS[self.config['sasl_mechanism']].validate_config(self) # This is not a general lock / this class is not generally thread-safe yet # However, to avoid pushing responsibility for maintaining # per-connection locks to the upstream client, we will use this lock to @@ -386,7 +369,7 @@ def connect(self): ret = None try: ret = self._sock.connect_ex(self._sock_addr) - except socket.error as err: + except OSError as err: ret = err.errno # Connection succeeded @@ -418,7 +401,7 @@ def connect(self): log.error('Connect attempt to %s returned error %s.' ' Disconnecting.', self, ret) errstr = errno.errorcode.get(ret, 'UNKNOWN') - self.close(Errors.KafkaConnectionError('{} {}'.format(ret, errstr))) + self.close(Errors.KafkaConnectionError(f'{ret} {errstr}')) return self.state # Needs retry @@ -553,19 +536,9 @@ def _handle_sasl_handshake_response(self, future, response): Errors.UnsupportedSaslMechanismError( 'Kafka broker does not support %s sasl mechanism. 
Enabled mechanisms are: %s' % (self.config['sasl_mechanism'], response.enabled_mechanisms))) - elif self.config['sasl_mechanism'] == 'PLAIN': - return self._try_authenticate_plain(future) - elif self.config['sasl_mechanism'] == 'GSSAPI': - return self._try_authenticate_gssapi(future) - elif self.config['sasl_mechanism'] == 'OAUTHBEARER': - return self._try_authenticate_oauth(future) - elif self.config['sasl_mechanism'].startswith("SCRAM-SHA-"): - return self._try_authenticate_scram(future) - else: - return future.failure( - Errors.UnsupportedSaslMechanismError( - 'kafka-python does not support SASL mechanism %s' % - self.config['sasl_mechanism'])) + + try_authenticate = sasl.MECHANISMS[self.config['sasl_mechanism']].try_authenticate + return try_authenticate(self, future) def _send_bytes(self, data): """Send some data via non-blocking IO @@ -584,12 +557,9 @@ def _send_bytes(self, data): except (SSLWantReadError, SSLWantWriteError): break except (ConnectionError, TimeoutError) as e: - if six.PY2 and e.errno == errno.EWOULDBLOCK: - break raise except BlockingIOError: - if six.PY3: - break + break raise return total_sent @@ -619,225 +589,6 @@ def _recv_bytes_blocking(self, n): finally: self._sock.settimeout(0.0) - def _try_authenticate_plain(self, future): - if self.config['security_protocol'] == 'SASL_PLAINTEXT': - log.warning('%s: Sending username and password in the clear', self) - - data = b'' - # Send PLAIN credentials per RFC-4616 - msg = bytes('\0'.join([self.config['sasl_plain_username'], - self.config['sasl_plain_username'], - self.config['sasl_plain_password']]).encode('utf-8')) - size = Int32.encode(len(msg)) - - err = None - close = False - with self._lock: - if not self._can_send_recv(): - err = Errors.NodeNotReadyError(str(self)) - close = False - else: - try: - self._send_bytes_blocking(size + msg) - - # The server will send a zero sized message (that is Int32(0)) on success. 
- # The connection is closed on failure - data = self._recv_bytes_blocking(4) - - except (ConnectionError, TimeoutError) as e: - log.exception("%s: Error receiving reply from server", self) - err = Errors.KafkaConnectionError("%s: %s" % (self, e)) - close = True - - if err is not None: - if close: - self.close(error=err) - return future.failure(err) - - if data != b'\x00\x00\x00\x00': - error = Errors.AuthenticationFailedError('Unrecognized response during authentication') - return future.failure(error) - - log.info('%s: Authenticated as %s via PLAIN', self, self.config['sasl_plain_username']) - return future.success(True) - - def _try_authenticate_scram(self, future): - if self.config['security_protocol'] == 'SASL_PLAINTEXT': - log.warning('%s: Exchanging credentials in the clear', self) - - scram_client = ScramClient( - self.config['sasl_plain_username'], self.config['sasl_plain_password'], self.config['sasl_mechanism'] - ) - - err = None - close = False - with self._lock: - if not self._can_send_recv(): - err = Errors.NodeNotReadyError(str(self)) - close = False - else: - try: - client_first = scram_client.first_message().encode('utf-8') - size = Int32.encode(len(client_first)) - self._send_bytes_blocking(size + client_first) - - (data_len,) = struct.unpack('>i', self._recv_bytes_blocking(4)) - server_first = self._recv_bytes_blocking(data_len).decode('utf-8') - scram_client.process_server_first_message(server_first) - - client_final = scram_client.final_message().encode('utf-8') - size = Int32.encode(len(client_final)) - self._send_bytes_blocking(size + client_final) - - (data_len,) = struct.unpack('>i', self._recv_bytes_blocking(4)) - server_final = self._recv_bytes_blocking(data_len).decode('utf-8') - scram_client.process_server_final_message(server_final) - - except (ConnectionError, TimeoutError) as e: - log.exception("%s: Error receiving reply from server", self) - err = Errors.KafkaConnectionError("%s: %s" % (self, e)) - close = True - - if err is not None: - if close: - self.close(error=err) - return future.failure(err) - - log.info( - '%s: Authenticated as %s via %s', self, self.config['sasl_plain_username'], self.config['sasl_mechanism'] - ) - return future.success(True) - - def _try_authenticate_gssapi(self, future): - kerberos_damin_name = self.config['sasl_kerberos_domain_name'] or self.host - auth_id = self.config['sasl_kerberos_service_name'] + '@' + kerberos_damin_name - gssapi_name = gssapi.Name( - auth_id, - name_type=gssapi.NameType.hostbased_service - ).canonicalize(gssapi.MechType.kerberos) - log.debug('%s: GSSAPI name: %s', self, gssapi_name) - - err = None - close = False - with self._lock: - if not self._can_send_recv(): - err = Errors.NodeNotReadyError(str(self)) - close = False - else: - # Establish security context and negotiate protection level - # For reference RFC 2222, section 7.2.1 - try: - # Exchange tokens until authentication either succeeds or fails - client_ctx = gssapi.SecurityContext(name=gssapi_name, usage='initiate') - received_token = None - while not client_ctx.complete: - # calculate an output token from kafka token (or None if first iteration) - output_token = client_ctx.step(received_token) - - # pass output token to kafka, or send empty response if the security - # context is complete (output token is None in that case) - if output_token is None: - self._send_bytes_blocking(Int32.encode(0)) - else: - msg = output_token - size = Int32.encode(len(msg)) - self._send_bytes_blocking(size + msg) - - # The server will send a token back. 
Processing of this token either - # establishes a security context, or it needs further token exchange. - # The gssapi will be able to identify the needed next step. - # The connection is closed on failure. - header = self._recv_bytes_blocking(4) - (token_size,) = struct.unpack('>i', header) - received_token = self._recv_bytes_blocking(token_size) - - # Process the security layer negotiation token, sent by the server - # once the security context is established. - - # unwraps message containing supported protection levels and msg size - msg = client_ctx.unwrap(received_token).message - # Kafka currently doesn't support integrity or confidentiality security layers, so we - # simply set QoP to 'auth' only (first octet). We reuse the max message size proposed - # by the server - msg = Int8.encode(SASL_QOP_AUTH & Int8.decode(io.BytesIO(msg[0:1]))) + msg[1:] - # add authorization identity to the response, GSS-wrap and send it - msg = client_ctx.wrap(msg + auth_id.encode(), False).message - size = Int32.encode(len(msg)) - self._send_bytes_blocking(size + msg) - - except (ConnectionError, TimeoutError) as e: - log.exception("%s: Error receiving reply from server", self) - err = Errors.KafkaConnectionError("%s: %s" % (self, e)) - close = True - except Exception as e: - err = e - close = True - - if err is not None: - if close: - self.close(error=err) - return future.failure(err) - - log.info('%s: Authenticated as %s via GSSAPI', self, gssapi_name) - return future.success(True) - - def _try_authenticate_oauth(self, future): - data = b'' - - msg = bytes(self._build_oauth_client_request().encode("utf-8")) - size = Int32.encode(len(msg)) - - err = None - close = False - with self._lock: - if not self._can_send_recv(): - err = Errors.NodeNotReadyError(str(self)) - close = False - else: - try: - # Send SASL OAuthBearer request with OAuth token - self._send_bytes_blocking(size + msg) - - # The server will send a zero sized message (that is Int32(0)) on success. - # The connection is closed on failure - data = self._recv_bytes_blocking(4) - - except (ConnectionError, TimeoutError) as e: - log.exception("%s: Error receiving reply from server", self) - err = Errors.KafkaConnectionError("%s: %s" % (self, e)) - close = True - - if err is not None: - if close: - self.close(error=err) - return future.failure(err) - - if data != b'\x00\x00\x00\x00': - error = Errors.AuthenticationFailedError('Unrecognized response during authentication') - return future.failure(error) - - log.info('%s: Authenticated via OAuth', self) - return future.success(True) - - def _build_oauth_client_request(self): - token_provider = self.config['sasl_oauth_token_provider'] - return "n,,\x01auth=Bearer {}{}\x01\x01".format(token_provider.token(), self._token_extensions()) - - def _token_extensions(self): - """ - Return a string representation of the OPTIONAL key-value pairs that can be sent with an OAUTHBEARER - initial request. 
- """ - token_provider = self.config['sasl_oauth_token_provider'] - - # Only run if the #extensions() method is implemented by the clients Token Provider class - # Builds up a string separated by \x01 via a dict of key value pairs - if callable(getattr(token_provider, "extensions", None)) and len(token_provider.extensions()) > 0: - msg = "\x01".join(["{}={}".format(k, v) for k, v in token_provider.extensions().items()]) - return "\x01" + msg - else: - return "" - def blacked_out(self): """ Return true if we are disconnected from the given node and can't @@ -916,7 +667,7 @@ def close(self, error=None): with self._lock: if self.state is ConnectionStates.DISCONNECTED: return - log.info('%s: Closing connection. %s', self, error or '') + log.log(logging.ERROR if error else logging.INFO, '%s: Closing connection. %s', self, error or '') self._update_reconnect_backoff() self._sasl_auth_future = None self._protocol = KafkaProtocol( @@ -1003,7 +754,7 @@ def send_pending_requests(self): except (ConnectionError, TimeoutError) as e: log.exception("Error sending request data to %s", self) - error = Errors.KafkaConnectionError("%s: %s" % (self, e)) + error = Errors.KafkaConnectionError("{}: {}".format(self, e)) self.close(error=error) return False @@ -1036,7 +787,7 @@ def send_pending_requests_v2(self): except (ConnectionError, TimeoutError, Exception) as e: log.exception("Error sending request data to %s", self) - error = Errors.KafkaConnectionError("%s: %s" % (self, e)) + error = Errors.KafkaConnectionError("{}: {}".format(self, e)) self.close(error=error) return False @@ -1102,15 +853,12 @@ def _recv(self): except (SSLWantReadError, SSLWantWriteError): break except (ConnectionError, TimeoutError) as e: - if six.PY2 and e.errno == errno.EWOULDBLOCK: - break log.exception('%s: Error receiving network data' ' closing socket', self) err = Errors.KafkaConnectionError(e) break except BlockingIOError: - if six.PY3: - break + break # For PY2 this is a catchall and should be re-raised raise @@ -1145,10 +893,10 @@ def requests_timed_out(self): def _handle_api_version_response(self, response): error_type = Errors.for_code(response.error_code) assert error_type is Errors.NoError, "API version check failed" - self._api_versions = dict([ - (api_key, (min_version, max_version)) + self._api_versions = { + api_key: (min_version, max_version) for api_key, min_version, max_version in response.api_versions - ]) + } return self._api_versions def get_api_versions(self): @@ -1286,9 +1034,6 @@ def reset_override_configs(): elif (isinstance(f.exception, Errors.CorrelationIdError) and version == (0, 10)): pass - elif six.PY2: - assert isinstance(f.exception.args[0], socket.error) - assert f.exception.args[0].errno in (32, 54, 104) else: assert isinstance(f.exception.args[0], ConnectionError) log.info("Broker is not v%s -- it did not recognize %s", @@ -1306,7 +1051,7 @@ def __str__(self): AFI_NAMES[self._sock_afi], self._sock_addr) -class BrokerConnectionMetrics(object): +class BrokerConnectionMetrics: def __init__(self, metrics, metric_group_prefix, node_id): self.metrics = metrics @@ -1361,7 +1106,7 @@ def __init__(self, metrics, metric_group_prefix, node_id): # if one sensor of the metrics has been registered for the connection, # then all other sensors should have been registered; and vice versa - node_str = 'node-{0}'.format(node_id) + node_str = f'node-{node_id}' node_sensor = metrics.get_sensor(node_str + '.bytes-sent') if not node_sensor: metric_group_name = metric_group_prefix + '-node-metrics.' 
+ node_str @@ -1428,7 +1173,7 @@ def _address_family(address): try: socket.inet_pton(af, address) return af - except (ValueError, AttributeError, socket.error): + except (ValueError, AttributeError, OSError): continue return socket.AF_UNSPEC @@ -1472,7 +1217,7 @@ def get_ip_port_afi(host_and_port_str): log.warning('socket.inet_pton not available on this platform.' ' consider `pip install win_inet_pton`') pass - except (ValueError, socket.error): + except (ValueError, OSError): # it's a host:port pair pass host, port = host_and_port_str.rsplit(':', 1) @@ -1488,7 +1233,7 @@ def collect_hosts(hosts, randomize=True): randomize the returned list. """ - if isinstance(hosts, six.string_types): + if isinstance(hosts, str): hosts = hosts.strip().split(',') result = [] diff --git a/kafka/consumer/__init__.py b/kafka/consumer/__init__.py index e09bcc1b8..5341d5648 100644 --- a/kafka/consumer/__init__.py +++ b/kafka/consumer/__init__.py @@ -1,5 +1,3 @@ -from __future__ import absolute_import - from kafka.consumer.group import KafkaConsumer __all__ = [ diff --git a/kafka/consumer/fetcher.py b/kafka/consumer/fetcher.py index 20a5ecac5..6d67c6e7a 100644 --- a/kafka/consumer/fetcher.py +++ b/kafka/consumer/fetcher.py @@ -1,5 +1,3 @@ -from __future__ import absolute_import - import collections import copy import logging @@ -7,8 +5,6 @@ import sys import time -from kafka.vendor import six - import kafka.errors as Errors from kafka.future import Future from kafka.metrics.stats import Avg, Count, Max, Rate @@ -45,7 +41,7 @@ class RecordTooLargeError(Errors.KafkaError): pass -class Fetcher(six.Iterator): +class Fetcher: DEFAULT_CONFIG = { 'key_deserializer': None, 'value_deserializer': None, @@ -120,7 +116,7 @@ def send_fetches(self): List of Futures: each future resolves to a FetchResponse """ futures = [] - for node_id, request in six.iteritems(self._create_fetch_requests()): + for node_id, request in self._create_fetch_requests().items(): if self._client.ready(node_id): log.debug("Sending FetchRequest to node %s", node_id) future = self._client.send(node_id, request, wakeup=False) @@ -209,7 +205,7 @@ def end_offsets(self, partitions, timeout_ms): partitions, OffsetResetStrategy.LATEST, timeout_ms) def beginning_or_end_offset(self, partitions, timestamp, timeout_ms): - timestamps = dict([(tp, timestamp) for tp in partitions]) + timestamps = {tp: timestamp for tp in partitions} offsets = self._retrieve_offsets(timestamps, timeout_ms) for tp in timestamps: offsets[tp] = offsets[tp][0] @@ -244,7 +240,7 @@ def _reset_offset(self, partition): if self._subscriptions.is_assigned(partition): self._subscriptions.seek(partition, offset) else: - log.debug("Could not find offset for partition %s since it is probably deleted" % (partition,)) + log.debug(f"Could not find offset for partition {partition} since it is probably deleted") def _retrieve_offsets(self, timestamps, timeout_ms=float("inf")): """Fetch offset for each partition passed in ``timestamps`` map. @@ -296,7 +292,7 @@ def _retrieve_offsets(self, timestamps, timeout_ms=float("inf")): log.debug("Stale metadata was raised, and we now have an updated metadata. 
Rechecking partition existence") unknown_partition = future.exception.args[0] # TopicPartition from StaleMetadata if self._client.cluster.leader_for_partition(unknown_partition) is None: - log.debug("Removed partition %s from offsets retrieval" % (unknown_partition, )) + log.debug(f"Removed partition {unknown_partition} from offsets retrieval") timestamps.pop(unknown_partition) else: time.sleep(self.config['retry_backoff_ms'] / 1000.0) @@ -305,7 +301,7 @@ def _retrieve_offsets(self, timestamps, timeout_ms=float("inf")): remaining_ms = timeout_ms - elapsed_ms raise Errors.KafkaTimeoutError( - "Failed to get offsets by timestamps in %s ms" % (timeout_ms,)) + f"Failed to get offsets by timestamps in {timeout_ms} ms") def fetched_records(self, max_records=None, update_offsets=True): """Returns previously fetched records and updates consumed offsets. @@ -529,7 +525,7 @@ def _send_offset_requests(self, timestamps): Future: resolves to a mapping of retrieved offsets """ timestamps_by_node = collections.defaultdict(dict) - for partition, timestamp in six.iteritems(timestamps): + for partition, timestamp in timestamps.items(): node_id = self._client.cluster.leader_for_partition(partition) if node_id is None: self._client.add_topic(partition.topic) @@ -561,7 +557,7 @@ def on_fail(err): if not list_offsets_future.is_done: list_offsets_future.failure(err) - for node_id, timestamps in six.iteritems(timestamps_by_node): + for node_id, timestamps in timestamps_by_node.items(): _f = self._send_offset_request(node_id, timestamps) _f.add_callback(on_success) _f.add_errback(on_fail) @@ -569,7 +565,7 @@ def on_fail(err): def _send_offset_request(self, node_id, timestamps): by_topic = collections.defaultdict(list) - for tp, timestamp in six.iteritems(timestamps): + for tp, timestamp in timestamps.items(): if self.config['api_version'] >= (0, 10, 1): data = (tp.partition, timestamp) else: @@ -577,9 +573,9 @@ def _send_offset_request(self, node_id, timestamps): by_topic[tp.topic].append(data) if self.config['api_version'] >= (0, 10, 1): - request = OffsetRequest[1](-1, list(six.iteritems(by_topic))) + request = OffsetRequest[1](-1, list(by_topic.items())) else: - request = OffsetRequest[0](-1, list(six.iteritems(by_topic))) + request = OffsetRequest[0](-1, list(by_topic.items())) # Client returns a future that only fails on network issues # so create a separate future and attach a callback to update it @@ -720,7 +716,7 @@ def _create_fetch_requests(self): else: version = 0 requests = {} - for node_id, partition_data in six.iteritems(fetchable): + for node_id, partition_data in fetchable.items(): if version < 3: requests[node_id] = FetchRequest[version]( -1, # replica_id @@ -762,9 +758,9 @@ def _handle_fetch_response(self, request, send_time, response): partition, offset = partition_data[:2] fetch_offsets[TopicPartition(topic, partition)] = offset - partitions = set([TopicPartition(topic, partition_data[0]) + partitions = {TopicPartition(topic, partition_data[0]) for topic, partitions in response.topics - for partition_data in partitions]) + for partition_data in partitions} metric_aggregator = FetchResponseMetricAggregator(self._sensors, partitions) # randomized ordering should improve balance for short-lived consumers @@ -873,7 +869,7 @@ def _parse_fetched_data(self, completed_fetch): return parsed_records - class PartitionRecords(object): + class PartitionRecords: def __init__(self, fetch_offset, tp, messages): self.fetch_offset = fetch_offset self.topic_partition = tp @@ -917,7 +913,7 @@ def take(self, 
n=None): return res -class FetchResponseMetricAggregator(object): +class FetchResponseMetricAggregator: """ Since we parse the message data for each partition from each fetch response lazily, fetch-level metrics need to be aggregated as the messages @@ -946,10 +942,10 @@ def record(self, partition, num_bytes, num_records): self.sensors.records_fetched.record(self.total_records) -class FetchManagerMetrics(object): +class FetchManagerMetrics: def __init__(self, metrics, prefix): self.metrics = metrics - self.group_name = '%s-fetch-manager-metrics' % (prefix,) + self.group_name = f'{prefix}-fetch-manager-metrics' self.bytes_fetched = metrics.sensor('bytes-fetched') self.bytes_fetched.add(metrics.metric_name('fetch-size-avg', self.group_name, @@ -993,15 +989,15 @@ def record_topic_fetch_metrics(self, topic, num_bytes, num_records): bytes_fetched = self.metrics.sensor(name) bytes_fetched.add(self.metrics.metric_name('fetch-size-avg', self.group_name, - 'The average number of bytes fetched per request for topic %s' % (topic,), + f'The average number of bytes fetched per request for topic {topic}', metric_tags), Avg()) bytes_fetched.add(self.metrics.metric_name('fetch-size-max', self.group_name, - 'The maximum number of bytes fetched per request for topic %s' % (topic,), + f'The maximum number of bytes fetched per request for topic {topic}', metric_tags), Max()) bytes_fetched.add(self.metrics.metric_name('bytes-consumed-rate', self.group_name, - 'The average number of bytes consumed per second for topic %s' % (topic,), + f'The average number of bytes consumed per second for topic {topic}', metric_tags), Rate()) bytes_fetched.record(num_bytes) @@ -1014,10 +1010,10 @@ def record_topic_fetch_metrics(self, topic, num_bytes, num_records): records_fetched = self.metrics.sensor(name) records_fetched.add(self.metrics.metric_name('records-per-request-avg', self.group_name, - 'The average number of records in each request for topic %s' % (topic,), + f'The average number of records in each request for topic {topic}', metric_tags), Avg()) records_fetched.add(self.metrics.metric_name('records-consumed-rate', self.group_name, - 'The average number of records consumed per second for topic %s' % (topic,), + f'The average number of records consumed per second for topic {topic}', metric_tags), Rate()) records_fetched.record(num_records) diff --git a/kafka/consumer/group.py b/kafka/consumer/group.py index a1d1dfa37..fc04c4bd6 100644 --- a/kafka/consumer/group.py +++ b/kafka/consumer/group.py @@ -1,5 +1,3 @@ -from __future__ import absolute_import, division - import copy import logging import socket @@ -7,8 +5,6 @@ from kafka.errors import KafkaConfigurationError, UnsupportedVersionError -from kafka.vendor import six - from kafka.client_async import KafkaClient, selectors from kafka.consumer.fetcher import Fetcher from kafka.consumer.subscription_state import SubscriptionState @@ -23,7 +19,7 @@ log = logging.getLogger(__name__) -class KafkaConsumer(six.Iterator): +class KafkaConsumer: """Consume records from a Kafka cluster. The consumer will transparently handle the failure of servers in the Kafka @@ -56,6 +52,12 @@ class KafkaConsumer(six.Iterator): committing offsets. If None, auto-partition assignment (via group coordinator) and offset commits are disabled. Default: None + group_instance_id (str): the unique identifier to distinguish + each client instance. If set and leave_group_on_close is + False consumer group rebalancing won't be triggered until + sessiont_timeout_ms is met. Requires 2.3.0+. 
+ leave_group_on_close (bool or None): whether to leave a consumer + group or not on consumer shutdown. key_deserializer (callable): Any callable that takes a raw message key and returns a deserialized key. value_deserializer (callable): Any callable that takes a @@ -245,6 +247,7 @@ class KafkaConsumer(six.Iterator): sasl_oauth_token_provider (AbstractTokenProvider): OAuthBearer token provider instance. (See kafka.oauth.abstract). Default: None kafka_client (callable): Custom class / callable for creating KafkaClient instances + coordinator (callable): Custom class / callable for creating ConsumerCoordinator instances Note: Configuration parameters are described in more detail at @@ -254,6 +257,8 @@ class KafkaConsumer(six.Iterator): 'bootstrap_servers': 'localhost', 'client_id': 'kafka-python-' + __version__, 'group_id': None, + 'group_instance_id': '', + 'leave_group_on_close': None, 'key_deserializer': None, 'value_deserializer': None, 'fetch_max_wait_ms': 500, @@ -308,6 +313,7 @@ class KafkaConsumer(six.Iterator): 'sasl_oauth_token_provider': None, 'legacy_iterator': False, # enable to revert to < 1.4.7 iterator 'kafka_client': KafkaClient, + 'coordinator': ConsumerCoordinator, } DEFAULT_SESSION_TIMEOUT_MS_0_9 = 30000 @@ -315,7 +321,7 @@ def __init__(self, *topics, **configs): # Only check for extra config keys in top-level class extra_configs = set(configs).difference(self.DEFAULT_CONFIG) if extra_configs: - raise KafkaConfigurationError("Unrecognized configs: %s" % (extra_configs,)) + raise KafkaConfigurationError(f"Unrecognized configs: {extra_configs}") self.config = copy.copy(self.DEFAULT_CONFIG) self.config.update(configs) @@ -383,7 +389,7 @@ def __init__(self, *topics, **configs): self._subscription = SubscriptionState(self.config['auto_offset_reset']) self._fetcher = Fetcher( self._client, self._subscription, self._metrics, **self.config) - self._coordinator = ConsumerCoordinator( + self._coordinator = self.config['coordinator']( self._client, self._subscription, self._metrics, assignors=self.config['partition_assignment_strategy'], **self.config) @@ -968,7 +974,7 @@ def metrics(self, raw=False): return self._metrics.metrics.copy() metrics = {} - for k, v in six.iteritems(self._metrics.metrics.copy()): + for k, v in self._metrics.metrics.copy().items(): if k.group not in metrics: metrics[k.group] = {} if k.name not in metrics[k.group]: @@ -1013,7 +1019,7 @@ def offsets_for_times(self, timestamps): raise UnsupportedVersionError( "offsets_for_times API not supported for cluster version {}" .format(self.config['api_version'])) - for tp, ts in six.iteritems(timestamps): + for tp, ts in timestamps.items(): timestamps[tp] = int(ts) if ts < 0: raise ValueError( @@ -1118,7 +1124,7 @@ def _update_fetch_positions(self, partitions): def _message_generator_v2(self): timeout_ms = 1000 * (self._consumer_timeout - time.time()) record_map = self.poll(timeout_ms=timeout_ms, update_offsets=False) - for tp, records in six.iteritems(record_map): + for tp, records in record_map.items(): # Generators are stateful, and it is possible that the tp / records # here may become stale during iteration -- i.e., we seek to a # different offset, pause consumption, or lose assignment. 
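group.py above also adds a 'coordinator' entry to DEFAULT_CONFIG so the ConsumerCoordinator can be swapped for a custom class or callable. A sketch of that hook; the subclass is illustrative only, and its constructor signature is inferred from the self.config['coordinator'](self._client, self._subscription, self._metrics, ...) call site in this change:

import logging

from kafka import KafkaConsumer
from kafka.coordinator.consumer import ConsumerCoordinator

log = logging.getLogger(__name__)


class LoggingCoordinator(ConsumerCoordinator):
    """Drop-in coordinator that logs the group it was built for."""

    def __init__(self, client, subscription, metrics, **configs):
        log.info('coordinator created for group=%s instance=%s',
                 configs.get('group_id'), configs.get('group_instance_id'))
        super().__init__(client, subscription, metrics, **configs)


consumer = KafkaConsumer(
    'my-topic',
    group_id='my-group',
    coordinator=LoggingCoordinator,   # new config key added above
    bootstrap_servers='my.server.com')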
diff --git a/kafka/consumer/subscription_state.py b/kafka/consumer/subscription_state.py index 08842d133..bb78cd2a2 100644 --- a/kafka/consumer/subscription_state.py +++ b/kafka/consumer/subscription_state.py @@ -1,11 +1,7 @@ -from __future__ import absolute_import - import abc import logging import re -from kafka.vendor import six - from kafka.errors import IllegalStateError from kafka.protocol.offset import OffsetResetStrategy from kafka.structs import OffsetAndMetadata @@ -13,7 +9,7 @@ log = logging.getLogger(__name__) -class SubscriptionState(object): +class SubscriptionState: """ A class for tracking the topics, partitions, and offsets for the consumer. A partition is "assigned" either directly with assign_from_user() (manual @@ -130,16 +126,16 @@ def _ensure_valid_topic_name(self, topic): # https://github.com/apache/kafka/blob/39eb31feaeebfb184d98cc5d94da9148c2319d81/clients/src/main/java/org/apache/kafka/common/internals/Topic.java if topic is None: raise TypeError('All topics must not be None') - if not isinstance(topic, six.string_types): + if not isinstance(topic, str): raise TypeError('All topics must be strings') if len(topic) == 0: raise ValueError('All topics must be non-empty strings') if topic == '.' or topic == '..': raise ValueError('Topic name cannot be "." or ".."') if len(topic) > self._MAX_NAME_LENGTH: - raise ValueError('Topic name is illegal, it can\'t be longer than {0} characters, topic: "{1}"'.format(self._MAX_NAME_LENGTH, topic)) + raise ValueError(f'Topic name is illegal, it can\'t be longer than {self._MAX_NAME_LENGTH} characters, topic: "{topic}"') if not self._TOPIC_LEGAL_CHARS.match(topic): - raise ValueError('Topic name "{0}" is illegal, it contains a character other than ASCII alphanumerics, ".", "_" and "-"'.format(topic)) + raise ValueError(f'Topic name "{topic}" is illegal, it contains a character other than ASCII alphanumerics, ".", "_" and "-"') def change_subscription(self, topics): """Change the topic subscription. @@ -157,7 +153,7 @@ def change_subscription(self, topics): if self._user_assignment: raise IllegalStateError(self._SUBSCRIPTION_EXCEPTION_MESSAGE) - if isinstance(topics, six.string_types): + if isinstance(topics, str): topics = [topics] if self.subscription == set(topics): @@ -247,7 +243,7 @@ def assign_from_subscribed(self, assignments): for tp in assignments: if tp.topic not in self.subscription: - raise ValueError("Assigned partition %s for non-subscribed topic." 
% (tp,)) + raise ValueError(f"Assigned partition {tp} for non-subscribed topic.") # after rebalancing, we always reinitialize the assignment state self.assignment.clear() @@ -299,13 +295,13 @@ def assigned_partitions(self): def paused_partitions(self): """Return current set of paused TopicPartitions.""" - return set(partition for partition in self.assignment - if self.is_paused(partition)) + return {partition for partition in self.assignment + if self.is_paused(partition)} def fetchable_partitions(self): """Return set of TopicPartitions that should be Fetched.""" fetchable = set() - for partition, state in six.iteritems(self.assignment): + for partition, state in self.assignment.items(): if state.is_fetchable(): fetchable.add(partition) return fetchable @@ -317,7 +313,7 @@ def partitions_auto_assigned(self): def all_consumed_offsets(self): """Returns consumed offsets as {TopicPartition: OffsetAndMetadata}""" all_consumed = {} - for partition, state in six.iteritems(self.assignment): + for partition, state in self.assignment.items(): if state.has_valid_position: all_consumed[partition] = OffsetAndMetadata(state.position, '') return all_consumed @@ -348,7 +344,7 @@ def has_all_fetch_positions(self): def missing_fetch_positions(self): missing = set() - for partition, state in six.iteritems(self.assignment): + for partition, state in self.assignment.items(): if not state.has_valid_position: missing.add(partition) return missing @@ -372,7 +368,7 @@ def _add_assigned_partition(self, partition): self.assignment[partition] = TopicPartitionState() -class TopicPartitionState(object): +class TopicPartitionState: def __init__(self): self.committed = None # last committed OffsetAndMetadata self.has_valid_position = False # whether we have valid position @@ -420,7 +416,7 @@ def is_fetchable(self): return not self.paused and self.has_valid_position -class ConsumerRebalanceListener(object): +class ConsumerRebalanceListener: """ A callback interface that the user can implement to trigger custom actions when the set of partitions assigned to the consumer changes. 
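subscription_state.py keeps ConsumerRebalanceListener as the callback interface for reacting to assignment changes. A short sketch of an implementation; committing in on_partitions_revoked is a common pattern used here only as an example, not something this change prescribes:

from kafka import KafkaConsumer
from kafka.consumer.subscription_state import ConsumerRebalanceListener


class CommitOnRevoke(ConsumerRebalanceListener):
    """Commit offsets before partitions are taken away in a rebalance."""

    def __init__(self, consumer):
        self.consumer = consumer

    def on_partitions_revoked(self, revoked):
        # Called before the new assignment takes effect; flush positions
        # for the partitions this member is about to lose.
        if revoked:
            self.consumer.commit()

    def on_partitions_assigned(self, assigned):
        print('assigned:', sorted(assigned))


consumer = KafkaConsumer(group_id='my-group',
                         bootstrap_servers='my.server.com',
                         enable_auto_commit=False)
consumer.subscribe(['my-topic'], listener=CommitOnRevoke(consumer))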
diff --git a/kafka/coordinator/assignors/range.py b/kafka/coordinator/assignors/range.py index 299e39c48..ae64f55df 100644 --- a/kafka/coordinator/assignors/range.py +++ b/kafka/coordinator/assignors/range.py @@ -1,10 +1,6 @@ -from __future__ import absolute_import - import collections import logging -from kafka.vendor import six - from kafka.coordinator.assignors.abstract import AbstractPartitionAssignor from kafka.coordinator.protocol import ConsumerProtocolMemberMetadata, ConsumerProtocolMemberAssignment @@ -34,14 +30,14 @@ class RangePartitionAssignor(AbstractPartitionAssignor): @classmethod def assign(cls, cluster, member_metadata): consumers_per_topic = collections.defaultdict(list) - for member, metadata in six.iteritems(member_metadata): + for member, metadata in member_metadata.items(): for topic in metadata.subscription: consumers_per_topic[topic].append(member) # construct {member_id: {topic: [partition, ...]}} assignment = collections.defaultdict(dict) - for topic, consumers_for_topic in six.iteritems(consumers_per_topic): + for topic, consumers_for_topic in consumers_per_topic.items(): partitions = cluster.partitions_for_topic(topic) if partitions is None: log.warning('No partition metadata for topic %s', topic) diff --git a/kafka/coordinator/assignors/roundrobin.py b/kafka/coordinator/assignors/roundrobin.py index 2d24a5c8b..d3292dd36 100644 --- a/kafka/coordinator/assignors/roundrobin.py +++ b/kafka/coordinator/assignors/roundrobin.py @@ -1,11 +1,7 @@ -from __future__ import absolute_import - import collections import itertools import logging -from kafka.vendor import six - from kafka.coordinator.assignors.abstract import AbstractPartitionAssignor from kafka.coordinator.protocol import ConsumerProtocolMemberMetadata, ConsumerProtocolMemberAssignment from kafka.structs import TopicPartition @@ -51,7 +47,7 @@ class RoundRobinPartitionAssignor(AbstractPartitionAssignor): @classmethod def assign(cls, cluster, member_metadata): all_topics = set() - for metadata in six.itervalues(member_metadata): + for metadata in member_metadata.values(): all_topics.update(metadata.subscription) all_topic_partitions = [] diff --git a/kafka/coordinator/assignors/sticky/partition_movements.py b/kafka/coordinator/assignors/sticky/partition_movements.py index 8851e4cda..78f2eb22c 100644 --- a/kafka/coordinator/assignors/sticky/partition_movements.py +++ b/kafka/coordinator/assignors/sticky/partition_movements.py @@ -2,8 +2,6 @@ from collections import defaultdict, namedtuple from copy import deepcopy -from kafka.vendor import six - log = logging.getLogger(__name__) @@ -74,7 +72,7 @@ def get_partition_to_be_moved(self, partition, old_consumer, new_consumer): return next(iter(self.partition_movements_by_topic[partition.topic][reverse_pair])) def are_sticky(self): - for topic, movements in six.iteritems(self.partition_movements_by_topic): + for topic, movements in self.partition_movements_by_topic.items(): movement_pairs = set(movements.keys()) if self._has_cycles(movement_pairs): log.error( diff --git a/kafka/coordinator/assignors/sticky/sticky_assignor.py b/kafka/coordinator/assignors/sticky/sticky_assignor.py index dce714f1a..033642425 100644 --- a/kafka/coordinator/assignors/sticky/sticky_assignor.py +++ b/kafka/coordinator/assignors/sticky/sticky_assignor.py @@ -11,7 +11,6 @@ from kafka.protocol.struct import Struct from kafka.protocol.types import String, Array, Int32 from kafka.structs import TopicPartition -from kafka.vendor import six log = logging.getLogger(__name__) @@ -110,7 +109,7 @@ def 
balance(self): # narrow down the reassignment scope to only those partitions that can actually be reassigned fixed_partitions = set() - for partition in six.iterkeys(self.partition_to_all_potential_consumers): + for partition in self.partition_to_all_potential_consumers.keys(): if not self._can_partition_participate_in_reassignment(partition): fixed_partitions.add(partition) for fixed_partition in fixed_partitions: @@ -119,7 +118,7 @@ def balance(self): # narrow down the reassignment scope to only those consumers that are subject to reassignment fixed_assignments = {} - for consumer in six.iterkeys(self.consumer_to_all_potential_partitions): + for consumer in self.consumer_to_all_potential_partitions.keys(): if not self._can_consumer_participate_in_reassignment(consumer): self._remove_consumer_from_current_subscriptions_and_maintain_order(consumer) fixed_assignments[consumer] = self.current_assignment[consumer] @@ -148,7 +147,7 @@ def balance(self): self.current_partition_consumer.update(prebalance_partition_consumers) # add the fixed assignments (those that could not change) back - for consumer, partitions in six.iteritems(fixed_assignments): + for consumer, partitions in fixed_assignments.items(): self.current_assignment[consumer] = partitions self._add_consumer_to_current_subscriptions_and_maintain_order(consumer) @@ -156,8 +155,8 @@ def get_final_assignment(self, member_id): assignment = defaultdict(list) for topic_partition in self.current_assignment[member_id]: assignment[topic_partition.topic].append(topic_partition.partition) - assignment = {k: sorted(v) for k, v in six.iteritems(assignment)} - return six.viewitems(assignment) + assignment = {k: sorted(v) for k, v in assignment.items()} + return assignment.items() def _initialize(self, cluster): self._init_current_assignments(self.members) @@ -170,7 +169,7 @@ def _initialize(self, cluster): for p in partitions: partition = TopicPartition(topic=topic, partition=p) self.partition_to_all_potential_consumers[partition] = [] - for consumer_id, member_metadata in six.iteritems(self.members): + for consumer_id, member_metadata in self.members.items(): self.consumer_to_all_potential_partitions[consumer_id] = [] for topic in member_metadata.subscription: if cluster.partitions_for_topic(topic) is None: @@ -190,7 +189,7 @@ def _init_current_assignments(self, members): # for each partition we create a map of its consumers by generation sorted_partition_consumers_by_generation = {} - for consumer, member_metadata in six.iteritems(members): + for consumer, member_metadata in members.items(): for partitions in member_metadata.partitions: if partitions in sorted_partition_consumers_by_generation: consumers = sorted_partition_consumers_by_generation[partitions] @@ -209,7 +208,7 @@ def _init_current_assignments(self, members): # previous_assignment holds the prior ConsumerGenerationPair (before current) of each partition # current and previous consumers are the last two consumers of each partition in the above sorted map - for partitions, consumers in six.iteritems(sorted_partition_consumers_by_generation): + for partitions, consumers in sorted_partition_consumers_by_generation.items(): generations = sorted(consumers.keys(), reverse=True) self.current_assignment[consumers[generations[0]]].append(partitions) # now update previous assignment if any @@ -220,7 +219,7 @@ def _init_current_assignments(self, members): self.is_fresh_assignment = len(self.current_assignment) == 0 - for consumer_id, partitions in six.iteritems(self.current_assignment): + for 
consumer_id, partitions in self.current_assignment.items(): for partition in partitions: self.current_partition_consumer[partition] = consumer_id @@ -230,14 +229,14 @@ def _are_subscriptions_identical(self): true, if both potential consumers of partitions and potential partitions that consumers can consume are the same """ - if not has_identical_list_elements(list(six.itervalues(self.partition_to_all_potential_consumers))): + if not has_identical_list_elements(list(self.partition_to_all_potential_consumers.values())): return False - return has_identical_list_elements(list(six.itervalues(self.consumer_to_all_potential_partitions))) + return has_identical_list_elements(list(self.consumer_to_all_potential_partitions.values())) def _populate_sorted_partitions(self): # set of topic partitions with their respective potential consumers - all_partitions = set((tp, tuple(consumers)) - for tp, consumers in six.iteritems(self.partition_to_all_potential_consumers)) + all_partitions = {(tp, tuple(consumers)) + for tp, consumers in self.partition_to_all_potential_consumers.items()} partitions_sorted_by_num_of_potential_consumers = sorted(all_partitions, key=partitions_comparator_key) self.sorted_partitions = [] @@ -246,7 +245,7 @@ def _populate_sorted_partitions(self): # then we just need to simply list partitions in a round robin fashion (from consumers with # most assigned partitions to those with least) assignments = deepcopy(self.current_assignment) - for consumer_id, partitions in six.iteritems(assignments): + for consumer_id, partitions in assignments.items(): to_remove = [] for partition in partitions: if partition not in self.partition_to_all_potential_consumers: @@ -255,7 +254,7 @@ def _populate_sorted_partitions(self): partitions.remove(partition) sorted_consumers = SortedSet( - iterable=[(consumer, tuple(partitions)) for consumer, partitions in six.iteritems(assignments)], + iterable=[(consumer, tuple(partitions)) for consumer, partitions in assignments.items()], key=subscriptions_comparator_key, ) # at this point, sorted_consumers contains an ascending-sorted list of consumers based on @@ -267,7 +266,7 @@ def _populate_sorted_partitions(self): remaining_partitions = assignments[consumer] # from partitions that had a different consumer before, # keep only those that are assigned to this consumer now - previous_partitions = set(six.iterkeys(self.previous_assignment)).intersection(set(remaining_partitions)) + previous_partitions = set(self.previous_assignment.keys()).intersection(set(remaining_partitions)) if previous_partitions: # if there is a partition of this consumer that was assigned to another consumer before # mark it as good options for reassignment @@ -292,7 +291,7 @@ def _populate_partitions_to_reassign(self): self.unassigned_partitions = deepcopy(self.sorted_partitions) assignments_to_remove = [] - for consumer_id, partitions in six.iteritems(self.current_assignment): + for consumer_id, partitions in self.current_assignment.items(): if consumer_id not in self.members: # if a consumer that existed before (and had some partition assignments) is now removed, # remove it from current_assignment @@ -325,7 +324,7 @@ def _populate_partitions_to_reassign(self): def _initialize_current_subscriptions(self): self.sorted_current_subscriptions = SortedSet( - iterable=[(consumer, tuple(partitions)) for consumer, partitions in six.iteritems(self.current_assignment)], + iterable=[(consumer, tuple(partitions)) for consumer, partitions in self.current_assignment.items()], 
key=subscriptions_comparator_key, ) @@ -352,7 +351,7 @@ def _is_balanced(self): # create a mapping from partitions to the consumer assigned to them all_assigned_partitions = {} - for consumer_id, consumer_partitions in six.iteritems(self.current_assignment): + for consumer_id, consumer_partitions in self.current_assignment.items(): for partition in consumer_partitions: if partition in all_assigned_partitions: log.error("{} is assigned to more than one consumer.".format(partition)) @@ -491,7 +490,7 @@ def _get_balance_score(assignment): """ score = 0 consumer_to_assignment = {} - for consumer_id, partitions in six.iteritems(assignment): + for consumer_id, partitions in assignment.items(): consumer_to_assignment[consumer_id] = len(partitions) consumers_to_explore = set(consumer_to_assignment.keys()) @@ -593,7 +592,7 @@ def assign(cls, cluster, members): dict: {member_id: MemberAssignment} """ members_metadata = {} - for consumer, member_metadata in six.iteritems(members): + for consumer, member_metadata in members.items(): members_metadata[consumer] = cls.parse_member_metadata(member_metadata) executor = StickyAssignmentExecutor(cluster, members_metadata) @@ -660,7 +659,7 @@ def _metadata(cls, topics, member_assignment_partitions, generation=-1): partitions_by_topic = defaultdict(list) for topic_partition in member_assignment_partitions: partitions_by_topic[topic_partition.topic].append(topic_partition.partition) - data = StickyAssignorUserDataV1(six.viewitems(partitions_by_topic), generation) + data = StickyAssignorUserDataV1(partitions_by_topic.items(), generation) user_data = data.encode() return ConsumerProtocolMemberMetadata(cls.version, list(topics), user_data) diff --git a/kafka/coordinator/base.py b/kafka/coordinator/base.py index e71984108..d5ec4c720 100644 --- a/kafka/coordinator/base.py +++ b/kafka/coordinator/base.py @@ -1,5 +1,3 @@ -from __future__ import absolute_import, division - import abc import copy import logging @@ -7,8 +5,6 @@ import time import weakref -from kafka.vendor import six - from kafka.coordinator.heartbeat import Heartbeat from kafka import errors as Errors from kafka.future import Future @@ -21,13 +17,13 @@ log = logging.getLogger('kafka.coordinator') -class MemberState(object): +class MemberState: UNJOINED = '<unjoined>' # the client is not part of a group REBALANCING = '<rebalancing>' # the client has begun rebalancing STABLE = '<stable>' # the client has joined and is sending heartbeats -class Generation(object): +class Generation: def __init__(self, generation_id, member_id, protocol): self.generation_id = generation_id self.member_id = member_id @@ -43,7 +39,7 @@ class UnjoinedGroupException(Errors.KafkaError): retriable = True -class BaseCoordinator(object): +class BaseCoordinator: """ BaseCoordinator implements group management for a single group member by interacting with a designated Kafka broker (the coordinator). Group @@ -82,6 +78,8 @@ class BaseCoordinator(object): DEFAULT_CONFIG = { 'group_id': 'kafka-python-default-group', + 'group_instance_id': '', + 'leave_group_on_close': None, 'session_timeout_ms': 10000, 'heartbeat_interval_ms': 3000, 'max_poll_interval_ms': 300000, @@ -96,6 +94,12 @@ def __init__(self, client, metrics, **configs): group_id (str): name of the consumer group to join for dynamic partition assignment (if enabled), and to use for fetching and committing offsets. Default: 'kafka-python-default-group' + group_instance_id (str): the unique identifier to distinguish + each client instance.
If set and leave_group_on_close is + False, consumer group rebalancing won't be triggered until + session_timeout_ms expires. Requires 2.3.0+ brokers. + leave_group_on_close (bool or None): whether to leave the consumer + group on consumer shutdown. Default: None (behaves like True). session_timeout_ms (int): The timeout used to detect failures when using Kafka's group management facilities. Default: 30000 heartbeat_interval_ms (int): The expected time in milliseconds @@ -121,6 +125,11 @@ def __init__(self, client, metrics, **configs): "different values for max_poll_interval_ms " "and session_timeout_ms") + if self.config['group_instance_id'] and self.config['api_version'] < (2, 3, 0): + raise Errors.KafkaConfigurationError( + 'Broker version %s does not support static membership' % (self.config['api_version'],), + ) + self._client = client self.group_id = self.config['group_id'] self.heartbeat = Heartbeat(**self.config) @@ -455,30 +464,48 @@ def _send_join_group_request(self): if self.config['api_version'] < (0, 9): raise Errors.KafkaError('JoinGroupRequest api requires 0.9+ brokers') elif (0, 9) <= self.config['api_version'] < (0, 10, 1): - request = JoinGroupRequest[0]( + version = 0 + args = ( self.group_id, self.config['session_timeout_ms'], self._generation.member_id, self.protocol_type(), - member_metadata) + member_metadata, + ) elif (0, 10, 1) <= self.config['api_version'] < (0, 11, 0): - request = JoinGroupRequest[1]( + version = 1 + args = ( + self.group_id, + self.config['session_timeout_ms'], + self.config['max_poll_interval_ms'], + self._generation.member_id, + self.protocol_type(), + member_metadata, + ) + elif self.config['api_version'] >= (2, 3, 0) and self.config['group_instance_id']: + version = 5 + args = ( self.group_id, self.config['session_timeout_ms'], self.config['max_poll_interval_ms'], self._generation.member_id, + self.config['group_instance_id'], self.protocol_type(), - member_metadata) + member_metadata, + ) else: - request = JoinGroupRequest[2]( + version = 2 + args = ( self.group_id, self.config['session_timeout_ms'], self.config['max_poll_interval_ms'], self._generation.member_id, self.protocol_type(), - member_metadata) + member_metadata, + ) # create the request for the coordinator + request = JoinGroupRequest[version](*args) log.debug("Sending JoinGroup (%s) to coordinator %s", request, self.coordinator_id) future = Future() _f = self._client.send(self.coordinator_id, request) @@ -562,12 +589,25 @@ def _handle_join_group_response(self, future, send_time, response): def _on_join_follower(self): # send follower's sync group with an empty assignment - version = 0 if self.config['api_version'] < (0, 11, 0) else 1 - request = SyncGroupRequest[version]( - self.group_id, - self._generation.generation_id, - self._generation.member_id, - {}) + if self.config['api_version'] >= (2, 3, 0) and self.config['group_instance_id']: + version = 3 + args = ( + self.group_id, + self._generation.generation_id, + self._generation.member_id, + self.config['group_instance_id'], + {}, + ) + else: + version = 0 if self.config['api_version'] < (0, 11, 0) else 1 + args = ( + self.group_id, + self._generation.generation_id, + self._generation.member_id, + {}, + ) + + request = SyncGroupRequest[version](*args) log.debug("Sending follower SyncGroup for group %s to coordinator %s: %s", self.group_id, self.coordinator_id, request) return self._send_sync_group_request(request) @@ -590,15 +630,30 @@ def _on_join_leader(self, response): except Exception as e: return Future().failure(e) - version = 0 if self.config['api_version']
< (0, 11, 0) else 1 - request = SyncGroupRequest[version]( - self.group_id, - self._generation.generation_id, - self._generation.member_id, - [(member_id, - assignment if isinstance(assignment, bytes) else assignment.encode()) - for member_id, assignment in six.iteritems(group_assignment)]) + group_assignment = [ + (member_id, assignment if isinstance(assignment, bytes) else assignment.encode()) + for member_id, assignment in group_assignment.items() + ] + + if self.config['api_version'] >= (2, 3, 0) and self.config['group_instance_id']: + version = 3 + args = ( + self.group_id, + self._generation.generation_id, + self._generation.member_id, + self.config['group_instance_id'], + group_assignment, + ) + else: + version = 0 if self.config['api_version'] < (0, 11, 0) else 1 + args = ( + self.group_id, + self._generation.generation_id, + self._generation.member_id, + group_assignment, + ) + request = SyncGroupRequest[version](*args) log.debug("Sending leader SyncGroup for group %s to coordinator %s: %s", self.group_id, self.coordinator_id, request) return self._send_sync_group_request(request) @@ -764,15 +819,22 @@ def close(self): def maybe_leave_group(self): """Leave the current group and reset local generation/memberId.""" with self._client._lock, self._lock: - if (not self.coordinator_unknown() + if ( + not self.coordinator_unknown() and self.state is not MemberState.UNJOINED - and self._generation is not Generation.NO_GENERATION): - + and self._generation is not Generation.NO_GENERATION + and self._leave_group_on_close() + ): # this is a minimal effort attempt to leave the group. we do not # attempt any resending if the request fails or times out. log.info('Leaving consumer group (%s).', self.group_id) - version = 0 if self.config['api_version'] < (0, 11, 0) else 1 - request = LeaveGroupRequest[version](self.group_id, self._generation.member_id) + if self.config['api_version'] >= (2, 3, 0) and self.config['group_instance_id']: + version = 3 + args = (self.group_id, [(self._generation.member_id, self.config['group_instance_id'])]) + else: + version = 0 if self.config['api_version'] < (0, 11, 0) else 1 + args = self.group_id, self._generation.member_id + request = LeaveGroupRequest[version](*args) future = self._client.send(self.coordinator_id, request) future.add_callback(self._handle_leave_group_response) future.add_errback(log.error, "LeaveGroup request failed: %s") @@ -799,10 +861,23 @@ def _send_heartbeat_request(self): e = Errors.NodeNotReadyError(self.coordinator_id) return Future().failure(e) - version = 0 if self.config['api_version'] < (0, 11, 0) else 1 - request = HeartbeatRequest[version](self.group_id, - self._generation.generation_id, - self._generation.member_id) + if self.config['api_version'] >= (2, 3, 0) and self.config['group_instance_id']: + version = 2 + args = ( + self.group_id, + self._generation.generation_id, + self._generation.member_id, + self.config['group_instance_id'], + ) + else: + version = 0 if self.config['api_version'] < (0, 11, 0) else 1 + args = ( + self.group_id, + self._generation.generation_id, + self._generation.member_id, + ) + + request = HeartbeatRequest[version](*args) log.debug("Heartbeat: %s[%s] %s", request.group, request.generation_id, request.member_id) # pylint: disable-msg=no-member future = Future() _f = self._client.send(self.coordinator_id, request) @@ -849,8 +924,11 @@ def _handle_heartbeat_response(self, future, send_time, response): log.error("Heartbeat failed: Unhandled error: %s", error) future.failure(error) + def 
_leave_group_on_close(self): + return self.config['leave_group_on_close'] is None or self.config['leave_group_on_close'] + -class GroupCoordinatorMetrics(object): +class GroupCoordinatorMetrics: def __init__(self, heartbeat, metrics, prefix, tags=None): self.heartbeat = heartbeat self.metrics = metrics @@ -903,7 +981,7 @@ def __init__(self, heartbeat, metrics, prefix, tags=None): class HeartbeatThread(threading.Thread): def __init__(self, coordinator): - super(HeartbeatThread, self).__init__() + super().__init__() self.name = coordinator.group_id + '-heartbeat' self.coordinator = coordinator self.enabled = False diff --git a/kafka/coordinator/consumer.py b/kafka/coordinator/consumer.py index 971f5e802..cf82b69fe 100644 --- a/kafka/coordinator/consumer.py +++ b/kafka/coordinator/consumer.py @@ -1,13 +1,9 @@ -from __future__ import absolute_import, division - import collections import copy import functools import logging import time -from kafka.vendor import six - from kafka.coordinator.base import BaseCoordinator, Generation from kafka.coordinator.assignors.range import RangePartitionAssignor from kafka.coordinator.assignors.roundrobin import RoundRobinPartitionAssignor @@ -29,6 +25,8 @@ class ConsumerCoordinator(BaseCoordinator): """This class manages the coordination process with the consumer coordinator.""" DEFAULT_CONFIG = { 'group_id': 'kafka-python-default-group', + 'group_instance_id': '', + 'leave_group_on_close': None, 'enable_auto_commit': True, 'auto_commit_interval_ms': 5000, 'default_offset_commit_callback': None, @@ -49,6 +47,12 @@ def __init__(self, client, subscription, metrics, **configs): group_id (str): name of the consumer group to join for dynamic partition assignment (if enabled), and to use for fetching and committing offsets. Default: 'kafka-python-default-group' + group_instance_id (str): the unique identifier to distinguish + each client instance. If set and leave_group_on_close is + False, consumer group rebalancing won't be triggered until + session_timeout_ms expires. Requires 2.3.0+ brokers. + leave_group_on_close (bool or None): whether to leave the consumer + group on consumer shutdown. Default: None (behaves like True). enable_auto_commit (bool): If true the consumer's offset will be periodically committed in the background. Default: True. auto_commit_interval_ms (int): milliseconds between automatic @@ -78,7 +82,7 @@ def __init__(self, client, subscription, metrics, **configs): True the only way to receive records from an internal topic is subscribing to it. Requires 0.10+.
Default: True """ - super(ConsumerCoordinator, self).__init__(client, metrics, **configs) + super().__init__(client, metrics, **configs) self.config = copy.copy(self.DEFAULT_CONFIG) for key in self.config: @@ -129,7 +133,7 @@ def __init__(self, client, subscription, metrics, **configs): def __del__(self): if hasattr(self, '_cluster') and self._cluster: self._cluster.remove_listener(WeakMethod(self._handle_metadata_update)) - super(ConsumerCoordinator, self).__del__() + super().__del__() def protocol_type(self): return ConsumerProtocol.PROTOCOL_TYPE @@ -218,7 +222,7 @@ def _on_join_complete(self, generation, member_id, protocol, self._assignment_snapshot = None assignor = self._lookup_assignor(protocol) - assert assignor, 'Coordinator selected invalid assignment protocol: %s' % (protocol,) + assert assignor, f'Coordinator selected invalid assignment protocol: {protocol}' assignment = ConsumerProtocol.ASSIGNMENT.decode(member_assignment_bytes) @@ -305,13 +309,18 @@ def time_to_next_poll(self): def _perform_assignment(self, leader_id, assignment_strategy, members): assignor = self._lookup_assignor(assignment_strategy) - assert assignor, 'Invalid assignment protocol: %s' % (assignment_strategy,) + assert assignor, f'Invalid assignment protocol: {assignment_strategy}' member_metadata = {} all_subscribed_topics = set() - for member_id, metadata_bytes in members: + + for member in members: + if len(member) == 3: + member_id, group_instance_id, metadata_bytes = member + else: + member_id, metadata_bytes = member metadata = ConsumerProtocol.METADATA.decode(metadata_bytes) member_metadata[member_id] = metadata - all_subscribed_topics.update(metadata.subscription) # pylint: disable-msg=no-member + all_subscribed_topics.update(metadata.subscription) # pylint: disable-msg=no-member # the leader will begin watching for changes to any of the topics # the group is interested in, which ensures that all metadata changes @@ -336,7 +345,7 @@ def _perform_assignment(self, leader_id, assignment_strategy, members): log.debug("Finished assignment for group %s: %s", self.group_id, assignments) group_assignment = {} - for member_id, assignment in six.iteritems(assignments): + for member_id, assignment in assignments.items(): group_assignment[member_id] = assignment return group_assignment @@ -381,13 +390,13 @@ def need_rejoin(self): and self._joined_subscription != self._subscription.subscription): return True - return super(ConsumerCoordinator, self).need_rejoin() + return super().need_rejoin() def refresh_committed_offsets_if_needed(self): """Fetch committed offsets for assigned partitions.""" if self._subscription.needs_fetch_committed_offsets: offsets = self.fetch_committed_offsets(self._subscription.assigned_partitions()) - for partition, offset in six.iteritems(offsets): + for partition, offset in offsets.items(): # verify assignment is still active if self._subscription.is_assigned(partition): self._subscription.assignment[partition].committed = offset @@ -433,7 +442,7 @@ def close(self, autocommit=True): if autocommit: self._maybe_auto_commit_offsets_sync() finally: - super(ConsumerCoordinator, self).close() + super().close() def _invoke_completed_offset_commit_callbacks(self): while self.completed_offset_commits: @@ -568,7 +577,7 @@ def _send_offset_commit_request(self, offsets): # create the offset commit request offset_data = collections.defaultdict(dict) - for tp, offset in six.iteritems(offsets): + for tp, offset in offsets.items(): offset_data[tp.topic][tp.partition] = offset if 
self._subscription.partitions_auto_assigned(): @@ -593,8 +602,8 @@ def _send_offset_commit_request(self, offsets): partition, offset.offset, offset.metadata - ) for partition, offset in six.iteritems(partitions)] - ) for topic, partitions in six.iteritems(offset_data)] + ) for partition, offset in partitions.items()] + ) for topic, partitions in offset_data.items()] ) elif self.config['api_version'] >= (0, 8, 2): request = OffsetCommitRequest[1]( @@ -605,8 +614,8 @@ def _send_offset_commit_request(self, offsets): offset.offset, -1, offset.metadata - ) for partition, offset in six.iteritems(partitions)] - ) for topic, partitions in six.iteritems(offset_data)] + ) for partition, offset in partitions.items()] + ) for topic, partitions in offset_data.items()] ) elif self.config['api_version'] >= (0, 8, 1): request = OffsetCommitRequest[0]( @@ -616,8 +625,8 @@ def _send_offset_commit_request(self, offsets): partition, offset.offset, offset.metadata - ) for partition, offset in six.iteritems(partitions)] - ) for topic, partitions in six.iteritems(offset_data)] + ) for partition, offset in partitions.items()] + ) for topic, partitions in offset_data.items()] ) log.debug("Sending offset-commit request with %s for group %s to %s", @@ -809,10 +818,10 @@ def _maybe_auto_commit_offsets_async(self): self._commit_offsets_async_on_complete) -class ConsumerCoordinatorMetrics(object): +class ConsumerCoordinatorMetrics: def __init__(self, metrics, metric_group_prefix, subscription): self.metrics = metrics - self.metric_group_name = '%s-coordinator-metrics' % (metric_group_prefix,) + self.metric_group_name = f'{metric_group_prefix}-coordinator-metrics' self.commit_latency = metrics.sensor('commit-latency') self.commit_latency.add(metrics.metric_name( diff --git a/kafka/coordinator/heartbeat.py b/kafka/coordinator/heartbeat.py index 2f5930b63..b12159cdd 100644 --- a/kafka/coordinator/heartbeat.py +++ b/kafka/coordinator/heartbeat.py @@ -1,10 +1,8 @@ -from __future__ import absolute_import, division - import copy import time -class Heartbeat(object): +class Heartbeat: DEFAULT_CONFIG = { 'group_id': None, 'heartbeat_interval_ms': 3000, diff --git a/kafka/coordinator/protocol.py b/kafka/coordinator/protocol.py index 56a390159..97efc1c84 100644 --- a/kafka/coordinator/protocol.py +++ b/kafka/coordinator/protocol.py @@ -1,5 +1,3 @@ -from __future__ import absolute_import - from kafka.protocol.struct import Struct from kafka.protocol.types import Array, Bytes, Int16, Int32, Schema, String from kafka.structs import TopicPartition @@ -26,7 +24,7 @@ def partitions(self): for partition in partitions] -class ConsumerProtocol(object): +class ConsumerProtocol: PROTOCOL_TYPE = 'consumer' ASSIGNMENT_STRATEGIES = ('range', 'roundrobin') METADATA = ConsumerProtocolMemberMetadata diff --git a/kafka/errors.py b/kafka/errors.py index b33cf51e2..cb3ff285f 100644 --- a/kafka/errors.py +++ b/kafka/errors.py @@ -1,5 +1,3 @@ -from __future__ import absolute_import - import inspect import sys @@ -12,8 +10,8 @@ class KafkaError(RuntimeError): def __str__(self): if not self.args: return self.__class__.__name__ - return '{0}: {1}'.format(self.__class__.__name__, - super(KafkaError, self).__str__()) + return '{}: {}'.format(self.__class__.__name__, + super().__str__()) class IllegalStateError(KafkaError): @@ -68,7 +66,7 @@ class IncompatibleBrokerVersion(KafkaError): class CommitFailedError(KafkaError): def __init__(self, *args, **kwargs): - super(CommitFailedError, self).__init__( + super().__init__( """Commit cannot be completed since 
the group has already rebalanced and assigned the partitions to another member. This means that the time between subsequent calls to poll() @@ -96,9 +94,9 @@ class BrokerResponseError(KafkaError): def __str__(self): """Add errno to standard KafkaError str""" - return '[Error {0}] {1}'.format( + return '[Error {}] {}'.format( self.errno, - super(BrokerResponseError, self).__str__()) + super().__str__()) class NoError(BrokerResponseError): @@ -471,7 +469,7 @@ class KafkaTimeoutError(KafkaError): class FailedPayloadsError(KafkaError): def __init__(self, payload, *args): - super(FailedPayloadsError, self).__init__(*args) + super().__init__(*args) self.payload = payload @@ -498,7 +496,7 @@ class QuotaViolationError(KafkaError): class AsyncProducerQueueFull(KafkaError): def __init__(self, failed_msgs, *args): - super(AsyncProducerQueueFull, self).__init__(*args) + super().__init__(*args) self.failed_msgs = failed_msgs @@ -508,7 +506,7 @@ def _iter_broker_errors(): yield obj -kafka_errors = dict([(x.errno, x) for x in _iter_broker_errors()]) +kafka_errors = {x.errno: x for x in _iter_broker_errors()} def for_code(error_code): diff --git a/kafka/future.py b/kafka/future.py index d0f3c6658..e9f534611 100644 --- a/kafka/future.py +++ b/kafka/future.py @@ -1,12 +1,10 @@ -from __future__ import absolute_import - import functools import logging log = logging.getLogger(__name__) -class Future(object): +class Future: error_on_callbacks = False # and errbacks def __init__(self): diff --git a/kafka/metrics/__init__.py b/kafka/metrics/__init__.py index 2a62d6334..22427e967 100644 --- a/kafka/metrics/__init__.py +++ b/kafka/metrics/__init__.py @@ -1,5 +1,3 @@ -from __future__ import absolute_import - from kafka.metrics.compound_stat import NamedMeasurable from kafka.metrics.dict_reporter import DictReporter from kafka.metrics.kafka_metric import KafkaMetric diff --git a/kafka/metrics/compound_stat.py b/kafka/metrics/compound_stat.py index ac92480dc..260714256 100644 --- a/kafka/metrics/compound_stat.py +++ b/kafka/metrics/compound_stat.py @@ -1,5 +1,3 @@ -from __future__ import absolute_import - import abc from kafka.metrics.stat import AbstractStat @@ -20,7 +18,7 @@ def stats(self): raise NotImplementedError -class NamedMeasurable(object): +class NamedMeasurable: def __init__(self, metric_name, measurable_stat): self._name = metric_name self._stat = measurable_stat diff --git a/kafka/metrics/dict_reporter.py b/kafka/metrics/dict_reporter.py index 0b98fe1e4..bc8088ca9 100644 --- a/kafka/metrics/dict_reporter.py +++ b/kafka/metrics/dict_reporter.py @@ -1,5 +1,3 @@ -from __future__ import absolute_import - import logging import threading @@ -29,10 +27,10 @@ def snapshot(self): } } """ - return dict((category, dict((name, metric.value()) - for name, metric in list(metrics.items()))) + return {category: {name: metric.value() + for name, metric in list(metrics.items())} for category, metrics in - list(self._store.items())) + list(self._store.items())} def init(self, metrics): for metric in metrics: @@ -71,7 +69,7 @@ def get_category(self, metric): prefix = None, group = 'bar', tags = None returns: 'bar' """ - tags = ','.join('%s=%s' % (k, v) for k, v in + tags = ','.join(f'{k}={v}' for k, v in sorted(metric.metric_name.tags.items())) return '.'.join(x for x in [self._prefix, metric.metric_name.group, tags] if x) diff --git a/kafka/metrics/kafka_metric.py b/kafka/metrics/kafka_metric.py index 9fb8d89f1..40d74952a 100644 --- a/kafka/metrics/kafka_metric.py +++ b/kafka/metrics/kafka_metric.py @@ -1,9 +1,7 @@ -from 
__future__ import absolute_import - import time -class KafkaMetric(object): +class KafkaMetric: # NOTE java constructor takes a lock instance def __init__(self, metric_name, measurable, config): if not metric_name: diff --git a/kafka/metrics/measurable.py b/kafka/metrics/measurable.py index b06d4d789..fd5be1205 100644 --- a/kafka/metrics/measurable.py +++ b/kafka/metrics/measurable.py @@ -1,9 +1,7 @@ -from __future__ import absolute_import - import abc -class AbstractMeasurable(object): +class AbstractMeasurable: """A measurable quantity that can be registered as a metric""" @abc.abstractmethod def measure(self, config, now): diff --git a/kafka/metrics/measurable_stat.py b/kafka/metrics/measurable_stat.py index 4487adf6e..dba887d2b 100644 --- a/kafka/metrics/measurable_stat.py +++ b/kafka/metrics/measurable_stat.py @@ -1,5 +1,3 @@ -from __future__ import absolute_import - import abc from kafka.metrics.measurable import AbstractMeasurable diff --git a/kafka/metrics/metric_config.py b/kafka/metrics/metric_config.py index 2e55abfcb..39a5b4168 100644 --- a/kafka/metrics/metric_config.py +++ b/kafka/metrics/metric_config.py @@ -1,9 +1,7 @@ -from __future__ import absolute_import - import sys -class MetricConfig(object): +class MetricConfig: """Configuration values for metrics""" def __init__(self, quota=None, samples=2, event_window=sys.maxsize, time_window_ms=30 * 1000, tags=None): diff --git a/kafka/metrics/metric_name.py b/kafka/metrics/metric_name.py index b5acd1662..5739c64f0 100644 --- a/kafka/metrics/metric_name.py +++ b/kafka/metrics/metric_name.py @@ -1,9 +1,7 @@ -from __future__ import absolute_import - import copy -class MetricName(object): +class MetricName: """ This class encapsulates a metric's name, logical group and its related attributes (tags). @@ -102,5 +100,5 @@ def __ne__(self, other): return not self.__eq__(other) def __str__(self): - return 'MetricName(name=%s, group=%s, description=%s, tags=%s)' % ( + return 'MetricName(name={}, group={}, description={}, tags={})'.format( self.name, self.group, self.description, self.tags) diff --git a/kafka/metrics/metrics.py b/kafka/metrics/metrics.py index 2c53488ff..67d609d08 100644 --- a/kafka/metrics/metrics.py +++ b/kafka/metrics/metrics.py @@ -1,5 +1,3 @@ -from __future__ import absolute_import - import logging import sys import time @@ -11,7 +9,7 @@ logger = logging.getLogger(__name__) -class Metrics(object): +class Metrics: """ A registry of sensors and metrics. @@ -230,7 +228,7 @@ def register_metric(self, metric): for reporter in self._reporters: reporter.metric_change(metric) - class ExpireSensorTask(object): + class ExpireSensorTask: """ This iterates over every Sensor and triggers a remove_sensor if it has expired. Package private for testing diff --git a/kafka/metrics/metrics_reporter.py b/kafka/metrics/metrics_reporter.py index d8bd12b3b..3f0fe189b 100644 --- a/kafka/metrics/metrics_reporter.py +++ b/kafka/metrics/metrics_reporter.py @@ -1,9 +1,7 @@ -from __future__ import absolute_import - import abc -class AbstractMetricsReporter(object): +class AbstractMetricsReporter: """ An abstract class to allow things to listen as new metrics are created so they can be reported. 
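As a usage note for the reporter interface above, here is a minimal sketch (not part of the patch) of a custom reporter. It assumes the abstract methods init/metric_change/metric_removal/configure/close and the client's metric_reporters config option, which this hunk does not show in full; the class name is a placeholder.

import logging

from kafka.metrics.metrics_reporter import AbstractMetricsReporter

log = logging.getLogger(__name__)


class LoggingReporter(AbstractMetricsReporter):
    """Log every metric as it is registered or removed."""

    def init(self, metrics):
        # metrics already registered before this reporter was attached
        for metric in metrics:
            self.metric_change(metric)

    def metric_change(self, metric):
        log.info('metric registered/updated: %s', metric.metric_name)

    def metric_removal(self, metric):
        log.info('metric removed: %s', metric.metric_name)

    def configure(self, configs):
        pass

    def close(self):
        pass


# e.g. KafkaConsumer('my-topic', metric_reporters=[LoggingReporter])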
diff --git a/kafka/metrics/quota.py b/kafka/metrics/quota.py index 4d1b0d6cb..5ec5d13d1 100644 --- a/kafka/metrics/quota.py +++ b/kafka/metrics/quota.py @@ -1,7 +1,4 @@ -from __future__ import absolute_import - - -class Quota(object): +class Quota: """An upper or lower bound for metrics""" def __init__(self, bound, is_upper): self._bound = bound diff --git a/kafka/metrics/stat.py b/kafka/metrics/stat.py index 9fd2f01ec..daf01935d 100644 --- a/kafka/metrics/stat.py +++ b/kafka/metrics/stat.py @@ -1,9 +1,7 @@ -from __future__ import absolute_import - import abc -class AbstractStat(object): +class AbstractStat: """ An AbstractStat is a quantity such as average, max, etc that is computed off the stream of updates to a sensor diff --git a/kafka/oauth/__init__.py b/kafka/oauth/__init__.py index 8c8349564..f4d780892 100644 --- a/kafka/oauth/__init__.py +++ b/kafka/oauth/__init__.py @@ -1,3 +1 @@ -from __future__ import absolute_import - from kafka.oauth.abstract import AbstractTokenProvider diff --git a/kafka/oauth/abstract.py b/kafka/oauth/abstract.py index 8d89ff51d..fa5f4eb99 100644 --- a/kafka/oauth/abstract.py +++ b/kafka/oauth/abstract.py @@ -1,5 +1,3 @@ -from __future__ import absolute_import - import abc # This statement is compatible with both Python 2.7 & 3+ diff --git a/kafka/partitioner/__init__.py b/kafka/partitioner/__init__.py index 21a3bbb66..eed1dca69 100644 --- a/kafka/partitioner/__init__.py +++ b/kafka/partitioner/__init__.py @@ -1,5 +1,3 @@ -from __future__ import absolute_import - from kafka.partitioner.default import DefaultPartitioner, murmur2 diff --git a/kafka/partitioner/default.py b/kafka/partitioner/default.py index d0914c682..a33b850cc 100644 --- a/kafka/partitioner/default.py +++ b/kafka/partitioner/default.py @@ -1,11 +1,7 @@ -from __future__ import absolute_import - import random -from kafka.vendor import six - -class DefaultPartitioner(object): +class DefaultPartitioner: """Default partitioner. Hashes key to partition using murmur2 hashing (from java client) @@ -43,11 +39,6 @@ def murmur2(data): Returns: MurmurHash2 of data """ - # Python2 bytes is really a str, causing the bitwise operations below to fail - # so convert to bytearray. - if six.PY2: - data = bytearray(bytes(data)) - length = len(data) seed = 0x9747b28c # 'm' and 'r' are mixing constants generated offline. diff --git a/kafka/producer/__init__.py b/kafka/producer/__init__.py index 576c772a0..869dbb3dc 100644 --- a/kafka/producer/__init__.py +++ b/kafka/producer/__init__.py @@ -1,5 +1,3 @@ -from __future__ import absolute_import - from kafka.producer.kafka import KafkaProducer __all__ = [ diff --git a/kafka/producer/buffer.py b/kafka/producer/buffer.py index 100801700..0e5b57b93 100644 --- a/kafka/producer/buffer.py +++ b/kafka/producer/buffer.py @@ -1,5 +1,3 @@ -from __future__ import absolute_import, division - import collections import io import threading @@ -10,7 +8,7 @@ import kafka.errors as Errors -class SimpleBufferPool(object): +class SimpleBufferPool: """A simple pool of BytesIO objects with a weak memory ceiling.""" def __init__(self, memory, poolable_size, metrics=None, metric_group_prefix='producer-metrics'): """Create a new buffer pool. 
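For orientation, a minimal sketch (not part of the patch) of the SimpleBufferPool constructor shown above together with its allocate/deallocate methods; those two methods are not visible in this hunk, so their signatures are assumed here, and the sizes are arbitrary placeholders.

from kafka.producer.buffer import SimpleBufferPool

# 32 MiB ceiling, handing out 16 KiB pooled io.BytesIO buffers
pool = SimpleBufferPool(memory=32 * 1024 * 1024, poolable_size=16 * 1024)

buf = pool.allocate(16 * 1024, 1000)  # size, max_time_to_block_ms (assumed signature)
buf.write(b'record batch bytes go here')
pool.deallocate(buf)  # return the buffer so other threads can reuse it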
diff --git a/kafka/producer/future.py b/kafka/producer/future.py index 07fa4adb4..72a4d3985 100644 --- a/kafka/producer/future.py +++ b/kafka/producer/future.py @@ -1,5 +1,3 @@ -from __future__ import absolute_import - import collections import threading @@ -9,17 +7,17 @@ class FutureProduceResult(Future): def __init__(self, topic_partition): - super(FutureProduceResult, self).__init__() + super().__init__() self.topic_partition = topic_partition self._latch = threading.Event() def success(self, value): - ret = super(FutureProduceResult, self).success(value) + ret = super().success(value) self._latch.set() return ret def failure(self, error): - ret = super(FutureProduceResult, self).failure(error) + ret = super().failure(error) self._latch.set() return ret @@ -30,7 +28,7 @@ def wait(self, timeout=None): class FutureRecordMetadata(Future): def __init__(self, produce_future, relative_offset, timestamp_ms, checksum, serialized_key_size, serialized_value_size, serialized_header_size): - super(FutureRecordMetadata, self).__init__() + super().__init__() self._produce_future = produce_future # packing args as a tuple is a minor speed optimization self.args = (relative_offset, timestamp_ms, checksum, serialized_key_size, serialized_value_size, serialized_header_size) @@ -59,7 +57,7 @@ def _produce_success(self, offset_and_timestamp): def get(self, timeout=None): if not self.is_done and not self._produce_future.wait(timeout): raise Errors.KafkaTimeoutError( - "Timeout after waiting for %s secs." % (timeout,)) + f"Timeout after waiting for {timeout} secs.") assert self.is_done if self.failed(): raise self.exception # pylint: disable-msg=raising-bad-type diff --git a/kafka/producer/kafka.py b/kafka/producer/kafka.py index dd1cc508c..f58221372 100644 --- a/kafka/producer/kafka.py +++ b/kafka/producer/kafka.py @@ -1,5 +1,3 @@ -from __future__ import absolute_import - import atexit import copy import logging @@ -8,8 +6,6 @@ import time import weakref -from kafka.vendor import six - import kafka.errors as Errors from kafka.client_async import KafkaClient, selectors from kafka.codec import has_gzip, has_snappy, has_lz4, has_zstd @@ -28,7 +24,7 @@ PRODUCER_CLIENT_ID_SEQUENCE = AtomicInteger() -class KafkaProducer(object): +class KafkaProducer: """A Kafka client that publishes records to the Kafka cluster. 
The producer is thread safe and sharing a single producer instance across @@ -353,7 +349,7 @@ def __init__(self, **configs): self.config[key] = configs.pop(key) # Only check for extra config keys in top-level class - assert not configs, 'Unrecognized configs: %s' % (configs,) + assert not configs, f'Unrecognized configs: {configs}' if self.config['client_id'] is None: self.config['client_id'] = 'kafka-python-producer-%s' % \ @@ -398,10 +394,10 @@ def __init__(self, **configs): # Check compression_type for library support ct = self.config['compression_type'] if ct not in self._COMPRESSORS: - raise ValueError("Not supported codec: {}".format(ct)) + raise ValueError(f"Not supported codec: {ct}") else: checker, compression_attrs = self._COMPRESSORS[ct] - assert checker(), "Libraries for {} compression codec not found".format(ct) + assert checker(), f"Libraries for {ct} compression codec not found" self.config['compression_attrs'] = compression_attrs message_version = self._max_usable_produce_magic() @@ -453,7 +449,7 @@ def _unregister_cleanup(self): def __del__(self): # Disable logger during destruction to avoid touching dangling references - class NullLogger(object): + class NullLogger: def __getattr__(self, name): return lambda *args: None @@ -703,7 +699,7 @@ def _wait_on_metadata(self, topic, max_wait): elapsed = time.time() - begin if not metadata_event.is_set(): raise Errors.KafkaTimeoutError( - "Failed to update metadata after %.1f secs." % (max_wait,)) + f"Failed to update metadata after {max_wait:.1f} secs.") elif topic in self._metadata.unauthorized_topics: raise Errors.TopicAuthorizationFailedError(topic) else: @@ -743,7 +739,7 @@ def metrics(self, raw=False): return self._metrics.metrics.copy() metrics = {} - for k, v in six.iteritems(self._metrics.metrics.copy()): + for k, v in self._metrics.metrics.copy().items(): if k.group not in metrics: metrics[k.group] = {} if k.name not in metrics[k.group]: diff --git a/kafka/producer/record_accumulator.py b/kafka/producer/record_accumulator.py index a2aa0e8ec..fc4bfeb30 100644 --- a/kafka/producer/record_accumulator.py +++ b/kafka/producer/record_accumulator.py @@ -1,5 +1,3 @@ -from __future__ import absolute_import - import collections import copy import logging @@ -16,7 +14,7 @@ log = logging.getLogger(__name__) -class AtomicInteger(object): +class AtomicInteger: def __init__(self, val=0): self._lock = threading.Lock() self._val = val @@ -35,7 +33,7 @@ def get(self): return self._val -class ProducerBatch(object): +class ProducerBatch: def __init__(self, tp, records, buffer): self.max_record_size = 0 now = time.time() @@ -110,7 +108,7 @@ def maybe_expire(self, request_timeout_ms, retry_backoff_ms, linger_ms, is_full) if error: self.records.close() self.done(-1, None, Errors.KafkaTimeoutError( - "Batch for %s containing %s record(s) expired: %s" % ( + "Batch for {} containing {} record(s) expired: {}".format( self.topic_partition, self.records.next_offset(), error))) return True return False @@ -129,7 +127,7 @@ def __str__(self): self.topic_partition, self.records.next_offset()) -class RecordAccumulator(object): +class RecordAccumulator: """ This class maintains a dequeue per TopicPartition that accumulates messages into MessageSets to be sent to the server. 
@@ -570,7 +568,7 @@ def close(self): self._closed = True -class IncompleteProducerBatches(object): +class IncompleteProducerBatches: """A threadsafe helper class to hold ProducerBatches that haven't been ack'd yet""" def __init__(self): diff --git a/kafka/producer/sender.py b/kafka/producer/sender.py index 35688d3f1..34f049486 100644 --- a/kafka/producer/sender.py +++ b/kafka/producer/sender.py @@ -1,13 +1,9 @@ -from __future__ import absolute_import, division - import collections import copy import logging import threading import time -from kafka.vendor import six - from kafka import errors as Errors from kafka.metrics.measurable import AnonMeasurable from kafka.metrics.stats import Avg, Max, Rate @@ -35,7 +31,7 @@ class Sender(threading.Thread): } def __init__(self, client, metadata, accumulator, metrics, **configs): - super(Sender, self).__init__() + super().__init__() self.config = copy.copy(self.DEFAULT_CONFIG) for key in self.config: if key in configs: @@ -118,7 +114,7 @@ def run_once(self): if self.config['guarantee_message_order']: # Mute all the partitions drained - for batch_list in six.itervalues(batches_by_node): + for batch_list in batches_by_node.values(): for batch in batch_list: self._accumulator.muted.add(batch.topic_partition) @@ -142,7 +138,7 @@ def run_once(self): log.debug("Created %d produce requests: %s", len(requests), requests) # trace poll_timeout_ms = 0 - for node_id, request in six.iteritems(requests): + for node_id, request in requests.items(): batches = batches_by_node[node_id] log.debug('Sending Produce Request: %r', request) (self._client.send(node_id, request, wakeup=False) @@ -190,8 +186,8 @@ def _handle_produce_response(self, node_id, send_time, batches, response): # if we have a response, parse it log.debug('Parsing produce response: %r', response) if response: - batches_by_partition = dict([(batch.topic_partition, batch) - for batch in batches]) + batches_by_partition = {batch.topic_partition: batch + for batch in batches} for topic, partitions in response.topics: for partition_info in partitions: @@ -281,7 +277,7 @@ def _create_produce_requests(self, collated): dict: {node_id: ProduceRequest} (version depends on api_version) """ requests = {} - for node_id, batches in six.iteritems(collated): + for node_id, batches in collated.items(): requests[node_id] = self._produce_request( node_id, self.config['acks'], self.config['request_timeout_ms'], batches) @@ -324,7 +320,7 @@ def _produce_request(self, node_id, acks, timeout, batches): timeout=timeout, topics=[(topic, list(partition_info.items())) for topic, partition_info - in six.iteritems(produce_records_by_partition)], + in produce_records_by_partition.items()], **kwargs ) @@ -336,7 +332,7 @@ def bootstrap_connected(self): return self._client.bootstrap_connected() -class SenderMetrics(object): +class SenderMetrics: def __init__(self, metrics, client, metadata): self.metrics = metrics @@ -434,7 +430,7 @@ def add_metric(self, metric_name, measurable, group_name='producer-metrics', def maybe_register_topic_metrics(self, topic): def sensor_name(name): - return 'topic.{0}.{1}'.format(topic, name) + return f'topic.{topic}.{name}' # if one sensor of the metrics has been registered for the topic, # then all other sensors should have been registered; and vice versa diff --git a/kafka/protocol/__init__.py b/kafka/protocol/__init__.py index 025447f99..ff9c68306 100644 --- a/kafka/protocol/__init__.py +++ b/kafka/protocol/__init__.py @@ -1,6 +1,3 @@ -from __future__ import absolute_import - - API_KEYS = { 0: 
'Produce', 1: 'Fetch', diff --git a/kafka/protocol/abstract.py b/kafka/protocol/abstract.py index 2de65c4bb..10eed5649 100644 --- a/kafka/protocol/abstract.py +++ b/kafka/protocol/abstract.py @@ -1,9 +1,7 @@ -from __future__ import absolute_import - import abc -class AbstractType(object): +class AbstractType: __metaclass__ = abc.ABCMeta @abc.abstractmethod diff --git a/kafka/protocol/admin.py b/kafka/protocol/admin.py index 0bb1a7acc..bc717fc6b 100644 --- a/kafka/protocol/admin.py +++ b/kafka/protocol/admin.py @@ -1,5 +1,3 @@ -from __future__ import absolute_import - from kafka.protocol.api import Request, Response from kafka.protocol.types import Array, Boolean, Bytes, Int8, Int16, Int32, Int64, Schema, String, Float64, CompactString, CompactArray, TaggedFields @@ -790,6 +788,48 @@ class DescribeConfigsRequest_v2(Request): ] +class DescribeLogDirsResponse_v0(Response): + API_KEY = 35 + API_VERSION = 0 + FLEXIBLE_VERSION = True + SCHEMA = Schema( + ('throttle_time_ms', Int32), + ('log_dirs', Array( + ('error_code', Int16), + ('log_dir', String('utf-8')), + ('topics', Array( + ('name', String('utf-8')), + ('partitions', Array( + ('partition_index', Int32), + ('partition_size', Int64), + ('offset_lag', Int64), + ('is_future_key', Boolean) + )) + )) + )) + ) + + +class DescribeLogDirsRequest_v0(Request): + API_KEY = 35 + API_VERSION = 0 + RESPONSE_TYPE = DescribeLogDirsResponse_v0 + SCHEMA = Schema( + ('topics', Array( + ('topic', String('utf-8')), + ('partitions', Int32) + )) + ) + + +DescribeLogDirsResponse = [ + DescribeLogDirsResponse_v0, +] +DescribeLogDirsRequest = [ + DescribeLogDirsRequest_v0, +] + + class SaslAuthenticateResponse_v0(Response): API_KEY = 36 API_VERSION = 0 diff --git a/kafka/protocol/api.py b/kafka/protocol/api.py index f12cb972b..24cf61a62 100644 --- a/kafka/protocol/api.py +++ b/kafka/protocol/api.py @@ -1,5 +1,3 @@ -from __future__ import absolute_import - import abc from kafka.protocol.struct import Struct @@ -15,7 +13,7 @@ class RequestHeader(Struct): ) def __init__(self, request, correlation_id=0, client_id='kafka-python'): - super(RequestHeader, self).__init__( + super().__init__( request.API_KEY, request.API_VERSION, correlation_id, client_id ) @@ -31,7 +29,7 @@ class RequestHeaderV2(Struct): ) def __init__(self, request, correlation_id=0, client_id='kafka-python', tags=None): - super(RequestHeaderV2, self).__init__( + super().__init__( request.API_KEY, request.API_VERSION, correlation_id, client_id, tags or {} ) diff --git a/kafka/protocol/commit.py b/kafka/protocol/commit.py index 31fc23707..0fb2b646f 100644 --- a/kafka/protocol/commit.py +++ b/kafka/protocol/commit.py @@ -1,5 +1,3 @@ -from __future__ import absolute_import - from kafka.protocol.api import Request, Response from kafka.protocol.types import Array, Int8, Int16, Int32, Int64, Schema, String diff --git a/kafka/protocol/fetch.py b/kafka/protocol/fetch.py index f367848ce..3b742b357 100644 --- a/kafka/protocol/fetch.py +++ b/kafka/protocol/fetch.py @@ -1,5 +1,3 @@ -from __future__ import absolute_import - from kafka.protocol.api import Request, Response from kafka.protocol.types import Array, Int8, Int16, Int32, Int64, Schema, String, Bytes diff --git a/kafka/protocol/frame.py b/kafka/protocol/frame.py index 7b4a32bcf..10ebf7c9b 100644 --- a/kafka/protocol/frame.py +++ b/kafka/protocol/frame.py @@ -1,6 +1,6 @@ class KafkaBytes(bytearray): def __init__(self, size): - super(KafkaBytes, self).__init__(size) + super().__init__(size) self._idx = 0 def read(self, nbytes=None): diff --git 
a/kafka/protocol/group.py b/kafka/protocol/group.py index bcb96553b..9e698c21f 100644 --- a/kafka/protocol/group.py +++ b/kafka/protocol/group.py @@ -1,5 +1,3 @@ -from __future__ import absolute_import - from kafka.protocol.api import Request, Response from kafka.protocol.struct import Struct from kafka.protocol.types import Array, Bytes, Int16, Int32, Schema, String @@ -42,6 +40,23 @@ class JoinGroupResponse_v2(Response): ) +class JoinGroupResponse_v5(Response): + API_KEY = 11 + API_VERSION = 5 + SCHEMA = Schema( + ('throttle_time_ms', Int32), + ('error_code', Int16), + ('generation_id', Int32), + ('group_protocol', String('utf-8')), + ('leader_id', String('utf-8')), + ('member_id', String('utf-8')), + ('members', Array( + ('member_id', String('utf-8')), + ('group_instance_id', String('utf-8')), + ('member_metadata', Bytes))), + ) + + class JoinGroupRequest_v0(Request): API_KEY = 11 API_VERSION = 0 @@ -83,11 +98,30 @@ class JoinGroupRequest_v2(Request): UNKNOWN_MEMBER_ID = '' +class JoinGroupRequest_v5(Request): + API_KEY = 11 + API_VERSION = 5 + RESPONSE_TYPE = JoinGroupResponse_v5 + SCHEMA = Schema( + ('group', String('utf-8')), + ('session_timeout', Int32), + ('rebalance_timeout', Int32), + ('member_id', String('utf-8')), + ('group_instance_id', String('utf-8')), + ('protocol_type', String('utf-8')), + ('group_protocols', Array( + ('protocol_name', String('utf-8')), + ('protocol_metadata', Bytes))), + ) + UNKNOWN_MEMBER_ID = '' + + + JoinGroupRequest = [ - JoinGroupRequest_v0, JoinGroupRequest_v1, JoinGroupRequest_v2 + JoinGroupRequest_v0, JoinGroupRequest_v1, JoinGroupRequest_v2, None, None, JoinGroupRequest_v5, ] JoinGroupResponse = [ - JoinGroupResponse_v0, JoinGroupResponse_v1, JoinGroupResponse_v2 + JoinGroupResponse_v0, JoinGroupResponse_v1, JoinGroupResponse_v2, None, None, JoinGroupResponse_v5, ] @@ -118,6 +152,16 @@ class SyncGroupResponse_v1(Response): ) +class SyncGroupResponse_v3(Response): + API_KEY = 14 + API_VERSION = 3 + SCHEMA = Schema( + ('throttle_time_ms', Int32), + ('error_code', Int16), + ('member_assignment', Bytes) + ) + + class SyncGroupRequest_v0(Request): API_KEY = 14 API_VERSION = 0 @@ -139,8 +183,23 @@ class SyncGroupRequest_v1(Request): SCHEMA = SyncGroupRequest_v0.SCHEMA -SyncGroupRequest = [SyncGroupRequest_v0, SyncGroupRequest_v1] -SyncGroupResponse = [SyncGroupResponse_v0, SyncGroupResponse_v1] +class SyncGroupRequest_v3(Request): + API_KEY = 14 + API_VERSION = 3 + RESPONSE_TYPE = SyncGroupResponse_v3 + SCHEMA = Schema( + ('group', String('utf-8')), + ('generation_id', Int32), + ('member_id', String('utf-8')), + ('group_instance_id', String('utf-8')), + ('group_assignment', Array( + ('member_id', String('utf-8')), + ('member_metadata', Bytes))), + ) + + +SyncGroupRequest = [SyncGroupRequest_v0, SyncGroupRequest_v1, None, SyncGroupRequest_v3] +SyncGroupResponse = [SyncGroupResponse_v0, SyncGroupResponse_v1, None, SyncGroupResponse_v3] class MemberAssignment(Struct): @@ -188,8 +247,29 @@ class HeartbeatRequest_v1(Request): SCHEMA = HeartbeatRequest_v0.SCHEMA -HeartbeatRequest = [HeartbeatRequest_v0, HeartbeatRequest_v1] -HeartbeatResponse = [HeartbeatResponse_v0, HeartbeatResponse_v1] +class HeartbeatResponse_v2(Response): + API_KEY = 12 + API_VERSION = 2 + SCHEMA = Schema( + ('throttle_time_ms', Int32), + ('error_code', Int16) + ) + + +class HeartbeatRequest_v2(Request): + API_KEY = 12 + API_VERSION = 2 + RESPONSE_TYPE = HeartbeatResponse_v2 + SCHEMA = Schema( + ('group', String('utf-8')), + ('generation_id', Int32), + ('member_id', String('utf-8')), + 
('group_instance_id', String('utf-8')), + ) + + +HeartbeatRequest = [HeartbeatRequest_v0, HeartbeatRequest_v1, HeartbeatRequest_v2] +HeartbeatResponse = [HeartbeatResponse_v0, HeartbeatResponse_v1, HeartbeatResponse_v2] class LeaveGroupResponse_v0(Response): @@ -209,6 +289,15 @@ class LeaveGroupResponse_v1(Response): ) +class LeaveGroupResponse_v3(Response): + API_KEY = 13 + API_VERSION = 3 + SCHEMA = Schema( + ('throttle_time_ms', Int32), + ('error_code', Int16) + ) + + class LeaveGroupRequest_v0(Request): API_KEY = 13 API_VERSION = 0 @@ -226,5 +315,17 @@ class LeaveGroupRequest_v1(Request): SCHEMA = LeaveGroupRequest_v0.SCHEMA -LeaveGroupRequest = [LeaveGroupRequest_v0, LeaveGroupRequest_v1] -LeaveGroupResponse = [LeaveGroupResponse_v0, LeaveGroupResponse_v1] +class LeaveGroupRequest_v3(Request): + API_KEY = 13 + API_VERSION = 3 + RESPONSE_TYPE = LeaveGroupResponse_v3 + SCHEMA = Schema( + ('group', String('utf-8')), + ('member_identity_list', Array( + ('member_id', String('utf-8')), + ('group_instance_id', String('utf-8')))), + ) + + +LeaveGroupRequest = [LeaveGroupRequest_v0, LeaveGroupRequest_v1, None, LeaveGroupRequest_v3] +LeaveGroupResponse = [LeaveGroupResponse_v0, LeaveGroupResponse_v1, None, LeaveGroupResponse_v3] diff --git a/kafka/protocol/message.py b/kafka/protocol/message.py index 4c5c031b8..b07d4eb0e 100644 --- a/kafka/protocol/message.py +++ b/kafka/protocol/message.py @@ -1,5 +1,3 @@ -from __future__ import absolute_import - import io import time @@ -78,7 +76,7 @@ def _encode_self(self, recalc_crc=True): elif version == 0: fields = (self.crc, self.magic, self.attributes, self.key, self.value) else: - raise ValueError('Unrecognized message version: %s' % (version,)) + raise ValueError(f'Unrecognized message version: {version}') message = Message.SCHEMAS[version].encode(fields) if not recalc_crc: return message @@ -94,7 +92,7 @@ def decode(cls, data): data = io.BytesIO(data) # Partial decode required to determine message version base_fields = cls.SCHEMAS[0].fields[0:3] - crc, magic, attributes = [field.decode(data) for field in base_fields] + crc, magic, attributes = (field.decode(data) for field in base_fields) remaining = cls.SCHEMAS[magic].fields[3:] fields = [field.decode(data) for field in remaining] if magic == 1: @@ -147,7 +145,7 @@ def __hash__(self): class PartialMessage(bytes): def __repr__(self): - return 'PartialMessage(%s)' % (self,) + return f'PartialMessage({self})' class MessageSet(AbstractType): diff --git a/kafka/protocol/metadata.py b/kafka/protocol/metadata.py index 414e5b84a..041444100 100644 --- a/kafka/protocol/metadata.py +++ b/kafka/protocol/metadata.py @@ -1,5 +1,3 @@ -from __future__ import absolute_import - from kafka.protocol.api import Request, Response from kafka.protocol.types import Array, Boolean, Int16, Int32, Schema, String diff --git a/kafka/protocol/offset.py b/kafka/protocol/offset.py index 1ed382b0d..34b00a40b 100644 --- a/kafka/protocol/offset.py +++ b/kafka/protocol/offset.py @@ -1,12 +1,10 @@ -from __future__ import absolute_import - from kafka.protocol.api import Request, Response from kafka.protocol.types import Array, Int8, Int16, Int32, Int64, Schema, String UNKNOWN_OFFSET = -1 -class OffsetResetStrategy(object): +class OffsetResetStrategy: LATEST = -1 EARLIEST = -2 NONE = 0 diff --git a/kafka/protocol/parser.py b/kafka/protocol/parser.py index a9e767220..a667105ad 100644 --- a/kafka/protocol/parser.py +++ b/kafka/protocol/parser.py @@ -1,5 +1,3 @@ -from __future__ import absolute_import - import collections import logging @@ 
-12,7 +10,7 @@ log = logging.getLogger(__name__) -class KafkaProtocol(object): +class KafkaProtocol: """Manage the kafka network protocol Use an instance of KafkaProtocol to manage bytes send/recv'd diff --git a/kafka/protocol/pickle.py b/kafka/protocol/pickle.py index d6e5fa74f..cd73b3add 100644 --- a/kafka/protocol/pickle.py +++ b/kafka/protocol/pickle.py @@ -1,5 +1,3 @@ -from __future__ import absolute_import - try: import copyreg # pylint: disable=import-error except ImportError: diff --git a/kafka/protocol/produce.py b/kafka/protocol/produce.py index 9b3f6bf55..b62430c22 100644 --- a/kafka/protocol/produce.py +++ b/kafka/protocol/produce.py @@ -1,5 +1,3 @@ -from __future__ import absolute_import - from kafka.protocol.api import Request, Response from kafka.protocol.types import Int16, Int32, Int64, String, Array, Schema, Bytes diff --git a/kafka/protocol/struct.py b/kafka/protocol/struct.py index e9da6e6c1..eb08ac8ef 100644 --- a/kafka/protocol/struct.py +++ b/kafka/protocol/struct.py @@ -1,5 +1,3 @@ -from __future__ import absolute_import - from io import BytesIO from kafka.protocol.abstract import AbstractType @@ -57,7 +55,7 @@ def get_item(self, name): def __repr__(self): key_vals = [] for name, field in zip(self.SCHEMA.names, self.SCHEMA.fields): - key_vals.append('%s=%s' % (name, field.repr(self.__dict__[name]))) + key_vals.append(f'{name}={field.repr(self.__dict__[name])}') return self.__class__.__name__ + '(' + ', '.join(key_vals) + ')' def __hash__(self): diff --git a/kafka/protocol/types.py b/kafka/protocol/types.py index 0e3685d73..0118af11b 100644 --- a/kafka/protocol/types.py +++ b/kafka/protocol/types.py @@ -1,5 +1,3 @@ -from __future__ import absolute_import - import struct from struct import error @@ -175,7 +173,7 @@ def repr(self, value): field_val = getattr(value, self.names[i]) except AttributeError: field_val = value[i] - key_vals.append('%s=%s' % (self.names[i], self.fields[i].repr(field_val))) + key_vals.append(f'{self.names[i]}={self.fields[i].repr(field_val)}') return '(' + ', '.join(key_vals) + ')' except Exception: return repr(value) @@ -223,7 +221,7 @@ def decode(cls, data): value |= (b & 0x7f) << i i += 7 if i > 28: - raise ValueError('Invalid value {}'.format(value)) + raise ValueError(f'Invalid value {value}') value |= b << i return value @@ -263,7 +261,7 @@ def decode(cls, data): value |= (b & 0x7f) << i i += 7 if i > 63: - raise ValueError('Invalid value {}'.format(value)) + raise ValueError(f'Invalid value {value}') value |= b << i return (value >> 1) ^ -(value & 1) @@ -309,7 +307,7 @@ def decode(cls, data): for i in range(num_fields): tag = UnsignedVarInt32.decode(data) if tag <= prev_tag: - raise ValueError('Invalid or out-of-order tag {}'.format(tag)) + raise ValueError(f'Invalid or out-of-order tag {tag}') prev_tag = tag size = UnsignedVarInt32.decode(data) val = data.read(size) @@ -321,8 +319,8 @@ def encode(cls, value): ret = UnsignedVarInt32.encode(len(value)) for k, v in value.items(): # do we allow for other data types ?? 
It could get complicated really fast - assert isinstance(v, bytes), 'Value {} is not a byte array'.format(v) - assert isinstance(k, int) and k > 0, 'Key {} is not a positive integer'.format(k) + assert isinstance(v, bytes), f'Value {v} is not a byte array' + assert isinstance(k, int) and k > 0, f'Key {k} is not a positive integer' ret += UnsignedVarInt32.encode(k) ret += v return ret diff --git a/kafka/record/abc.py b/kafka/record/abc.py index 8509e23e5..f45176051 100644 --- a/kafka/record/abc.py +++ b/kafka/record/abc.py @@ -1,8 +1,7 @@ -from __future__ import absolute_import import abc -class ABCRecord(object): +class ABCRecord: __metaclass__ = abc.ABCMeta __slots__ = () @@ -44,7 +43,7 @@ def headers(self): """ -class ABCRecordBatchBuilder(object): +class ABCRecordBatchBuilder: __metaclass__ = abc.ABCMeta __slots__ = () @@ -84,7 +83,7 @@ def build(self): """ -class ABCRecordBatch(object): +class ABCRecordBatch: """ For v2 encapsulates a RecordBatch, for v0/v1 a single (maybe compressed) message. """ @@ -98,7 +97,7 @@ def __iter__(self): """ -class ABCRecords(object): +class ABCRecords: __metaclass__ = abc.ABCMeta __slots__ = () diff --git a/kafka/record/default_records.py b/kafka/record/default_records.py index a098c42a9..5045f31ee 100644 --- a/kafka/record/default_records.py +++ b/kafka/record/default_records.py @@ -68,7 +68,7 @@ import kafka.codec as codecs -class DefaultRecordBase(object): +class DefaultRecordBase: __slots__ = () @@ -116,7 +116,7 @@ def _assert_has_codec(self, compression_type): checker, name = codecs.has_zstd, "zstd" if not checker(): raise UnsupportedCodecError( - "Libraries for {} compression codec not found".format(name)) + f"Libraries for {name} compression codec not found") class DefaultRecordBatch(DefaultRecordBase, ABCRecordBatch): @@ -247,7 +247,7 @@ def _read_msg( h_key_len, pos = decode_varint(buffer, pos) if h_key_len < 0: raise CorruptRecordException( - "Invalid negative header key size {}".format(h_key_len)) + f"Invalid negative header key size {h_key_len}") h_key = buffer[pos: pos + h_key_len].decode("utf-8") pos += h_key_len @@ -287,7 +287,7 @@ def __next__(self): msg = self._read_msg() except (ValueError, IndexError) as err: raise CorruptRecordException( - "Found invalid record structure: {!r}".format(err)) + f"Found invalid record structure: {err!r}") else: self._next_record_index += 1 return msg @@ -421,10 +421,10 @@ def append(self, offset, timestamp, key, value, headers, raise TypeError(timestamp) if not (key is None or get_type(key) in byte_like): raise TypeError( - "Not supported type for key: {}".format(type(key))) + f"Not supported type for key: {type(key)}") if not (value is None or get_type(value) in byte_like): raise TypeError( - "Not supported type for value: {}".format(type(value))) + f"Not supported type for value: {type(value)}") # We will always add the first message, so those will be set if self._first_timestamp is None: @@ -598,7 +598,7 @@ def estimate_size_in_bytes(cls, key, value, headers): ) -class DefaultRecordMetadata(object): +class DefaultRecordMetadata: __slots__ = ("_size", "_timestamp", "_offset") diff --git a/kafka/record/legacy_records.py b/kafka/record/legacy_records.py index 2f8523fcb..9ab8873ca 100644 --- a/kafka/record/legacy_records.py +++ b/kafka/record/legacy_records.py @@ -55,7 +55,7 @@ from kafka.errors import CorruptRecordException, UnsupportedCodecError -class LegacyRecordBase(object): +class LegacyRecordBase: __slots__ = () @@ -124,7 +124,7 @@ def _assert_has_codec(self, compression_type): checker, name = 
codecs.has_lz4, "lz4" if not checker(): raise UnsupportedCodecError( - "Libraries for {} compression codec not found".format(name)) + f"Libraries for {name} compression codec not found") class LegacyRecordBatch(ABCRecordBatch, LegacyRecordBase): @@ -367,11 +367,11 @@ def append(self, offset, timestamp, key, value, headers=None): if not (key is None or isinstance(key, (bytes, bytearray, memoryview))): raise TypeError( - "Not supported type for key: {}".format(type(key))) + f"Not supported type for key: {type(key)}") if not (value is None or isinstance(value, (bytes, bytearray, memoryview))): raise TypeError( - "Not supported type for value: {}".format(type(value))) + f"Not supported type for value: {type(value)}") # Check if we have room for another message pos = len(self._buffer) @@ -514,7 +514,7 @@ def estimate_size_in_bytes(cls, magic, compression_type, key, value): return cls.LOG_OVERHEAD + cls.record_size(magic, key, value) -class LegacyRecordMetadata(object): +class LegacyRecordMetadata: __slots__ = ("_crc", "_size", "_timestamp", "_offset") diff --git a/kafka/record/memory_records.py b/kafka/record/memory_records.py index fc2ef2d6b..7a604887c 100644 --- a/kafka/record/memory_records.py +++ b/kafka/record/memory_records.py @@ -18,7 +18,6 @@ # # So we can iterate over batches just by knowing offsets of Length. Magic is # used to construct the correct class for Batch itself. -from __future__ import division import struct @@ -110,7 +109,7 @@ def next_batch(self, _min_slice=MIN_SLICE, return DefaultRecordBatch(next_slice) -class MemoryRecordsBuilder(object): +class MemoryRecordsBuilder: __slots__ = ("_builder", "_batch_size", "_buffer", "_next_offset", "_closed", "_bytes_written") diff --git a/kafka/sasl/__init__.py b/kafka/sasl/__init__.py new file mode 100644 index 000000000..337c90949 --- /dev/null +++ b/kafka/sasl/__init__.py @@ -0,0 +1,54 @@ +import logging + +from kafka.sasl import gssapi, oauthbearer, plain, scram, msk + +log = logging.getLogger(__name__) + +MECHANISMS = { + 'GSSAPI': gssapi, + 'OAUTHBEARER': oauthbearer, + 'PLAIN': plain, + 'SCRAM-SHA-256': scram, + 'SCRAM-SHA-512': scram, + 'AWS_MSK_IAM': msk, +} + + +def register_mechanism(key, module): + """ + Registers a custom SASL mechanism that can be used via sasl_mechanism={key}. + + Example: + import kakfa.sasl + from kafka import KafkaProducer + from mymodule import custom_sasl + kafka.sasl.register_mechanism('CUSTOM_SASL', custom_sasl) + + producer = KafkaProducer(sasl_mechanism='CUSTOM_SASL') + + Arguments: + key (str): The name of the mechanism returned by the broker and used + in the sasl_mechanism config value. + module (module): A module that implements the following methods... + + def validate_config(conn: BrokerConnection): -> None: + # Raises an AssertionError for missing or invalid conifg values. + + def try_authenticate(conn: BrokerConncetion, future: -> Future): + # Executes authentication routine and returns a resolved Future. + + Raises: + AssertionError: The registered module does not define a required method. 
+ """ + assert callable(getattr(module, 'validate_config', None)), ( + 'Custom SASL mechanism {} must implement method #validate_config()' + .format(key) + ) + assert callable(getattr(module, 'try_authenticate', None)), ( + 'Custom SASL mechanism {} must implement method #try_authenticate()' + .format(key) + ) + if key in MECHANISMS: + log.warning(f'Overriding existing SASL mechanism {key}') + + MECHANISMS[key] = module diff --git a/kafka/sasl/gssapi.py b/kafka/sasl/gssapi.py new file mode 100644 index 000000000..3daf7e148 --- /dev/null +++ b/kafka/sasl/gssapi.py @@ -0,0 +1,100 @@ +import io +import logging +import struct + +import kafka.errors as Errors +from kafka.protocol.types import Int8, Int32 + +try: + import gssapi + from gssapi.raw.misc import GSSError +except ImportError: + gssapi = None + GSSError = None + +log = logging.getLogger(__name__) + +SASL_QOP_AUTH = 1 + + +def validate_config(conn): + assert gssapi is not None, ( + 'gssapi library required when sasl_mechanism=GSSAPI' + ) + assert conn.config['sasl_kerberos_service_name'] is not None, ( + 'sasl_kerberos_service_name required when sasl_mechanism=GSSAPI' + ) + + +def try_authenticate(conn, future): + kerberos_damin_name = conn.config['sasl_kerberos_domain_name'] or conn.host + auth_id = conn.config['sasl_kerberos_service_name'] + '@' + kerberos_damin_name + gssapi_name = gssapi.Name( + auth_id, + name_type=gssapi.NameType.hostbased_service + ).canonicalize(gssapi.MechType.kerberos) + log.debug('%s: GSSAPI name: %s', conn, gssapi_name) + + err = None + close = False + with conn._lock: + if not conn._can_send_recv(): + err = Errors.NodeNotReadyError(str(conn)) + close = False + else: + # Establish security context and negotiate protection level + # For reference RFC 2222, section 7.2.1 + try: + # Exchange tokens until authentication either succeeds or fails + client_ctx = gssapi.SecurityContext(name=gssapi_name, usage='initiate') + received_token = None + while not client_ctx.complete: + # calculate an output token from kafka token (or None if first iteration) + output_token = client_ctx.step(received_token) + + # pass output token to kafka, or send empty response if the security + # context is complete (output token is None in that case) + if output_token is None: + conn._send_bytes_blocking(Int32.encode(0)) + else: + msg = output_token + size = Int32.encode(len(msg)) + conn._send_bytes_blocking(size + msg) + + # The server will send a token back. Processing of this token either + # establishes a security context, or it needs further token exchange. + # The gssapi will be able to identify the needed next step. + # The connection is closed on failure. + header = conn._recv_bytes_blocking(4) + (token_size,) = struct.unpack('>i', header) + received_token = conn._recv_bytes_blocking(token_size) + + # Process the security layer negotiation token, sent by the server + # once the security context is established. + + # unwraps message containing supported protection levels and msg size + msg = client_ctx.unwrap(received_token).message + # Kafka currently doesn't support integrity or confidentiality + # security layers, so we simply set QoP to 'auth' only (first octet). 
+ # We reuse the max message size proposed by the server + msg = Int8.encode(SASL_QOP_AUTH & Int8.decode(io.BytesIO(msg[0:1]))) + msg[1:] + # add authorization identity to the response, GSS-wrap and send it + msg = client_ctx.wrap(msg + auth_id.encode(), False).message + size = Int32.encode(len(msg)) + conn._send_bytes_blocking(size + msg) + + except (ConnectionError, TimeoutError) as e: + log.exception("%s: Error receiving reply from server", conn) + err = Errors.KafkaConnectionError(f"{conn}: {e}") + close = True + except Exception as e: + err = e + close = True + + if err is not None: + if close: + conn.close(error=err) + return future.failure(err) + + log.info('%s: Authenticated as %s via GSSAPI', conn, gssapi_name) + return future.success(True) diff --git a/kafka/sasl/msk.py b/kafka/sasl/msk.py new file mode 100644 index 000000000..6d1bb74fb --- /dev/null +++ b/kafka/sasl/msk.py @@ -0,0 +1,230 @@ +import datetime +import hashlib +import hmac +import json +import string +import struct +import logging +import urllib + +from kafka.protocol.types import Int32 +import kafka.errors as Errors + +from botocore.session import Session as BotoSession # importing it in advance is not an option apparently... + + +def try_authenticate(self, future): + + session = BotoSession() + credentials = session.get_credentials().get_frozen_credentials() + client = AwsMskIamClient( + host=self.host, + access_key=credentials.access_key, + secret_key=credentials.secret_key, + region=session.get_config_variable('region'), + token=credentials.token, + ) + + msg = client.first_message() + size = Int32.encode(len(msg)) + + err = None + close = False + with self._lock: + if not self._can_send_recv(): + err = Errors.NodeNotReadyError(str(self)) + close = False + else: + try: + self._send_bytes_blocking(size + msg) + data = self._recv_bytes_blocking(4) + data = self._recv_bytes_blocking(struct.unpack('4B', data)[-1]) + except (ConnectionError, TimeoutError) as e: + logging.exception("%s: Error receiving reply from server", self) + err = Errors.KafkaConnectionError(f"{self}: {e}") + close = True + + if err is not None: + if close: + self.close(error=err) + return future.failure(err) + + logging.info('%s: Authenticated via AWS_MSK_IAM %s', self, data.decode('utf-8')) + return future.success(True) + + +class AwsMskIamClient: + UNRESERVED_CHARS = string.ascii_letters + string.digits + '-._~' + + def __init__(self, host, access_key, secret_key, region, token=None): + """ + Arguments: + host (str): The hostname of the broker. + access_key (str): An AWS_ACCESS_KEY_ID. + secret_key (str): An AWS_SECRET_ACCESS_KEY. + region (str): An AWS_REGION. + token (Optional[str]): An AWS_SESSION_TOKEN if using temporary + credentials. 
+ """ + self.algorithm = 'AWS4-HMAC-SHA256' + self.expires = '900' + self.hashfunc = hashlib.sha256 + self.headers = [ + ('host', host) + ] + self.version = '2020_10_22' + + self.service = 'kafka-cluster' + self.action = f'{self.service}:Connect' + + now = datetime.datetime.utcnow() + self.datestamp = now.strftime('%Y%m%d') + self.timestamp = now.strftime('%Y%m%dT%H%M%SZ') + + self.host = host + self.access_key = access_key + self.secret_key = secret_key + self.region = region + self.token = token + + @property + def _credential(self): + return '{0.access_key}/{0._scope}'.format(self) + + @property + def _scope(self): + return '{0.datestamp}/{0.region}/{0.service}/aws4_request'.format(self) + + @property + def _signed_headers(self): + """ + Returns (str): + An alphabetically sorted, semicolon-delimited list of lowercase + request header names. + """ + return ';'.join(sorted(k.lower() for k, _ in self.headers)) + + @property + def _canonical_headers(self): + """ + Returns (str): + A newline-delited list of header names and values. + Header names are lowercased. + """ + return '\n'.join(map(':'.join, self.headers)) + '\n' + + @property + def _canonical_request(self): + """ + Returns (str): + An AWS Signature Version 4 canonical request in the format: + \n + \n + \n + \n + \n + + """ + # The hashed_payload is always an empty string for MSK. + hashed_payload = self.hashfunc(b'').hexdigest() + return '\n'.join(( + 'GET', + '/', + self._canonical_querystring, + self._canonical_headers, + self._signed_headers, + hashed_payload, + )) + + @property + def _canonical_querystring(self): + """ + Returns (str): + A '&'-separated list of URI-encoded key/value pairs. + """ + params = [] + params.append(('Action', self.action)) + params.append(('X-Amz-Algorithm', self.algorithm)) + params.append(('X-Amz-Credential', self._credential)) + params.append(('X-Amz-Date', self.timestamp)) + params.append(('X-Amz-Expires', self.expires)) + if self.token: + params.append(('X-Amz-Security-Token', self.token)) + params.append(('X-Amz-SignedHeaders', self._signed_headers)) + + return '&'.join(self._uriencode(k) + '=' + self._uriencode(v) for k, v in params) + + @property + def _signing_key(self): + """ + Returns (bytes): + An AWS Signature V4 signing key generated from the secret_key, date, + region, service, and request type. + """ + key = self._hmac(('AWS4' + self.secret_key).encode('utf-8'), self.datestamp) + key = self._hmac(key, self.region) + key = self._hmac(key, self.service) + key = self._hmac(key, 'aws4_request') + return key + + @property + def _signing_str(self): + """ + Returns (str): + A string used to sign the AWS Signature V4 payload in the format: + \n + \n + \n + + """ + canonical_request_hash = self.hashfunc(self._canonical_request.encode('utf-8')).hexdigest() + return '\n'.join((self.algorithm, self.timestamp, self._scope, canonical_request_hash)) + + def _uriencode(self, msg): + """ + Arguments: + msg (str): A string to URI-encode. + + Returns (str): + The URI-encoded version of the provided msg, following the encoding + rules specified: https://github.com/aws/aws-msk-iam-auth#uriencode + """ + return urllib.parse.quote(msg, safe=self.UNRESERVED_CHARS) + + def _hmac(self, key, msg): + """ + Arguments: + key (bytes): A key to use for the HMAC digest. + msg (str): A value to include in the HMAC digest. + Returns (bytes): + An HMAC digest of the given key and msg. 
+ """ + return hmac.new(key, msg.encode('utf-8'), digestmod=self.hashfunc).digest() + + def first_message(self): + """ + Returns (bytes): + An encoded JSON authentication payload that can be sent to the + broker. + """ + signature = hmac.new( + self._signing_key, + self._signing_str.encode('utf-8'), + digestmod=self.hashfunc, + ).hexdigest() + msg = { + 'version': self.version, + 'host': self.host, + 'user-agent': 'kafka-python', + 'action': self.action, + 'x-amz-algorithm': self.algorithm, + 'x-amz-credential': self._credential, + 'x-amz-date': self.timestamp, + 'x-amz-signedheaders': self._signed_headers, + 'x-amz-expires': self.expires, + 'x-amz-signature': signature, + } + if self.token: + msg['x-amz-security-token'] = self.token + + return json.dumps(msg, separators=(',', ':')).encode('utf-8') diff --git a/kafka/sasl/oauthbearer.py b/kafka/sasl/oauthbearer.py new file mode 100644 index 000000000..5d2488baf --- /dev/null +++ b/kafka/sasl/oauthbearer.py @@ -0,0 +1,80 @@ +import logging + +import kafka.errors as Errors +from kafka.protocol.types import Int32 + +log = logging.getLogger(__name__) + + +def validate_config(conn): + token_provider = conn.config.get('sasl_oauth_token_provider') + assert token_provider is not None, ( + 'sasl_oauth_token_provider required when sasl_mechanism=OAUTHBEARER' + ) + assert callable(getattr(token_provider, 'token', None)), ( + 'sasl_oauth_token_provider must implement method #token()' + ) + + +def try_authenticate(conn, future): + data = b'' + + msg = bytes(_build_oauth_client_request(conn).encode("utf-8")) + size = Int32.encode(len(msg)) + + err = None + close = False + with conn._lock: + if not conn._can_send_recv(): + err = Errors.NodeNotReadyError(str(conn)) + close = False + else: + try: + # Send SASL OAuthBearer request with OAuth token + conn._send_bytes_blocking(size + msg) + + # The server will send a zero sized message (that is Int32(0)) on success. + # The connection is closed on failure + data = conn._recv_bytes_blocking(4) + + except (ConnectionError, TimeoutError) as e: + log.exception("%s: Error receiving reply from server", conn) + err = Errors.KafkaConnectionError(f"{conn}: {e}") + close = True + + if err is not None: + if close: + conn.close(error=err) + return future.failure(err) + + if data != b'\x00\x00\x00\x00': + error = Errors.AuthenticationFailedError('Unrecognized response during authentication') + return future.failure(error) + + log.info('%s: Authenticated via OAuth', conn) + return future.success(True) + + +def _build_oauth_client_request(conn): + token_provider = conn.config['sasl_oauth_token_provider'] + return "n,,\x01auth=Bearer {}{}\x01\x01".format( + token_provider.token(), + _token_extensions(conn), + ) + + +def _token_extensions(conn): + """ + Return a string representation of the OPTIONAL key-value pairs that can be + sent with an OAUTHBEARER initial request. 
+ """ + token_provider = conn.config['sasl_oauth_token_provider'] + + # Only run if the #extensions() method is implemented by the clients Token Provider class + # Builds up a string separated by \x01 via a dict of key value pairs + if (callable(getattr(token_provider, "extensions", None)) + and len(token_provider.extensions()) > 0): + msg = "\x01".join([f"{k}={v}" for k, v in token_provider.extensions().items()]) + return "\x01" + msg + else: + return "" diff --git a/kafka/sasl/plain.py b/kafka/sasl/plain.py new file mode 100644 index 000000000..625a43f08 --- /dev/null +++ b/kafka/sasl/plain.py @@ -0,0 +1,58 @@ +import logging + +import kafka.errors as Errors +from kafka.protocol.types import Int32 + +log = logging.getLogger(__name__) + + +def validate_config(conn): + assert conn.config['sasl_plain_username'] is not None, ( + 'sasl_plain_username required when sasl_mechanism=PLAIN' + ) + assert conn.config['sasl_plain_password'] is not None, ( + 'sasl_plain_password required when sasl_mechanism=PLAIN' + ) + + +def try_authenticate(conn, future): + if conn.config['security_protocol'] == 'SASL_PLAINTEXT': + log.warning('%s: Sending username and password in the clear', conn) + + data = b'' + # Send PLAIN credentials per RFC-4616 + msg = bytes('\0'.join([conn.config['sasl_plain_username'], + conn.config['sasl_plain_username'], + conn.config['sasl_plain_password']]).encode('utf-8')) + size = Int32.encode(len(msg)) + + err = None + close = False + with conn._lock: + if not conn._can_send_recv(): + err = Errors.NodeNotReadyError(str(conn)) + close = False + else: + try: + conn._send_bytes_blocking(size + msg) + + # The server will send a zero sized message (that is Int32(0)) on success. + # The connection is closed on failure + data = conn._recv_bytes_blocking(4) + + except (ConnectionError, TimeoutError) as e: + log.exception("%s: Error receiving reply from server", conn) + err = Errors.KafkaConnectionError(f"{conn}: {e}") + close = True + + if err is not None: + if close: + conn.close(error=err) + return future.failure(err) + + if data != b'\x00\x00\x00\x00': + error = Errors.AuthenticationFailedError('Unrecognized response during authentication') + return future.failure(error) + + log.info('%s: Authenticated as %s via PLAIN', conn, conn.config['sasl_plain_username']) + return future.success(True) diff --git a/kafka/sasl/scram.py b/kafka/sasl/scram.py new file mode 100644 index 000000000..4f3e60126 --- /dev/null +++ b/kafka/sasl/scram.py @@ -0,0 +1,68 @@ +import logging +import struct + +import kafka.errors as Errors +from kafka.protocol.types import Int32 +from kafka.scram import ScramClient + +log = logging.getLogger() + + +def validate_config(conn): + assert conn.config['sasl_plain_username'] is not None, ( + 'sasl_plain_username required when sasl_mechanism=SCRAM-*' + ) + assert conn.config['sasl_plain_password'] is not None, ( + 'sasl_plain_password required when sasl_mechanism=SCRAM-*' + ) + + +def try_authenticate(conn, future): + if conn.config['security_protocol'] == 'SASL_PLAINTEXT': + log.warning('%s: Exchanging credentials in the clear', conn) + + scram_client = ScramClient( + conn.config['sasl_plain_username'], + conn.config['sasl_plain_password'], + conn.config['sasl_mechanism'], + ) + + err = None + close = False + with conn._lock: + if not conn._can_send_recv(): + err = Errors.NodeNotReadyError(str(conn)) + close = False + else: + try: + client_first = scram_client.first_message().encode('utf-8') + size = Int32.encode(len(client_first)) + conn._send_bytes_blocking(size + 
client_first) + + (data_len,) = struct.unpack('>i', conn._recv_bytes_blocking(4)) + server_first = conn._recv_bytes_blocking(data_len).decode('utf-8') + scram_client.process_server_first_message(server_first) + + client_final = scram_client.final_message().encode('utf-8') + size = Int32.encode(len(client_final)) + conn._send_bytes_blocking(size + client_final) + + (data_len,) = struct.unpack('>i', conn._recv_bytes_blocking(4)) + server_final = conn._recv_bytes_blocking(data_len).decode('utf-8') + scram_client.process_server_final_message(server_final) + + except (ConnectionError, TimeoutError) as e: + log.exception("%s: Error receiving reply from server", conn) + err = Errors.KafkaConnectionError(f"{conn}: {e}") + close = True + + if err is not None: + if close: + conn.close(error=err) + return future.failure(err) + + log.info( + '%s: Authenticated as %s via %s', + conn, conn.config['sasl_plain_username'], conn.config['sasl_mechanism'] + ) + return future.success(True) diff --git a/kafka/scram.py b/kafka/scram.py index 7f003750c..74f4716bd 100644 --- a/kafka/scram.py +++ b/kafka/scram.py @@ -1,19 +1,11 @@ -from __future__ import absolute_import - import base64 import hashlib import hmac import uuid -from kafka.vendor import six - -if six.PY2: - def xor_bytes(left, right): - return bytearray(ord(lb) ^ ord(rb) for lb, rb in zip(left, right)) -else: - def xor_bytes(left, right): - return bytes(lb ^ rb for lb, rb in zip(left, right)) +def xor_bytes(left, right): + return bytes(lb ^ rb for lb, rb in zip(left, right)) class ScramClient: @@ -38,7 +30,7 @@ def __init__(self, user, password, mechanism): self.server_signature = None def first_message(self): - client_first_bare = 'n={},r={}'.format(self.user, self.nonce) + client_first_bare = f'n={self.user},r={self.nonce}' self.auth_message += client_first_bare return 'n,,' + client_first_bare diff --git a/kafka/serializer/__init__.py b/kafka/serializer/__init__.py index 90cd93ab2..168277519 100644 --- a/kafka/serializer/__init__.py +++ b/kafka/serializer/__init__.py @@ -1,3 +1 @@ -from __future__ import absolute_import - from kafka.serializer.abstract import Serializer, Deserializer diff --git a/kafka/serializer/abstract.py b/kafka/serializer/abstract.py index 18ad8d69c..529662b07 100644 --- a/kafka/serializer/abstract.py +++ b/kafka/serializer/abstract.py @@ -1,9 +1,7 @@ -from __future__ import absolute_import - import abc -class Serializer(object): +class Serializer: __meta__ = abc.ABCMeta def __init__(self, **config): @@ -17,7 +15,7 @@ def close(self): pass -class Deserializer(object): +class Deserializer: __meta__ = abc.ABCMeta def __init__(self, **config): diff --git a/kafka/structs.py b/kafka/structs.py index bcb023670..4f11259aa 100644 --- a/kafka/structs.py +++ b/kafka/structs.py @@ -1,5 +1,4 @@ """ Other useful structs """ -from __future__ import absolute_import from collections import namedtuple diff --git a/kafka/util.py b/kafka/util.py index e31d99305..0c9c5ea62 100644 --- a/kafka/util.py +++ b/kafka/util.py @@ -1,28 +1,21 @@ -from __future__ import absolute_import - import binascii import weakref -from kafka.vendor import six - -if six.PY3: - MAX_INT = 2 ** 31 - TO_SIGNED = 2 ** 32 +MAX_INT = 2 ** 31 +TO_SIGNED = 2 ** 32 - def crc32(data): - crc = binascii.crc32(data) - # py2 and py3 behave a little differently - # CRC is encoded as a signed int in kafka protocol - # so we'll convert the py3 unsigned result to signed - if crc >= MAX_INT: - crc -= TO_SIGNED - return crc -else: - from binascii import crc32 +def crc32(data): + crc = 
binascii.crc32(data) + # py2 and py3 behave a little differently + # CRC is encoded as a signed int in kafka protocol + # so we'll convert the py3 unsigned result to signed + if crc >= MAX_INT: + crc -= TO_SIGNED + return crc -class WeakMethod(object): +class WeakMethod: """ Callable that weakly references a method and the object it is bound to. It is based on https://stackoverflow.com/a/24287465. diff --git a/kafka/vendor/__init__.py b/kafka/vendor/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/kafka/vendor/enum34.py b/kafka/vendor/enum34.py deleted file mode 100644 index 5f64bd2d8..000000000 --- a/kafka/vendor/enum34.py +++ /dev/null @@ -1,841 +0,0 @@ -# pylint: skip-file -# vendored from: -# https://bitbucket.org/stoneleaf/enum34/src/58c4cd7174ca35f164304c8a6f0a4d47b779c2a7/enum/__init__.py?at=1.1.6 - -"""Python Enumerations""" - -import sys as _sys - -__all__ = ['Enum', 'IntEnum', 'unique'] - -version = 1, 1, 6 - -pyver = float('%s.%s' % _sys.version_info[:2]) - -try: - any -except NameError: - def any(iterable): - for element in iterable: - if element: - return True - return False - -try: - from collections import OrderedDict -except ImportError: - OrderedDict = None - -try: - basestring -except NameError: - # In Python 2 basestring is the ancestor of both str and unicode - # in Python 3 it's just str, but was missing in 3.1 - basestring = str - -try: - unicode -except NameError: - # In Python 3 unicode no longer exists (it's just str) - unicode = str - -class _RouteClassAttributeToGetattr(object): - """Route attribute access on a class to __getattr__. - - This is a descriptor, used to define attributes that act differently when - accessed through an instance and through a class. Instance access remains - normal, but access to an attribute through a class will be routed to the - class's __getattr__ method; this is done by raising AttributeError. - - """ - def __init__(self, fget=None): - self.fget = fget - - def __get__(self, instance, ownerclass=None): - if instance is None: - raise AttributeError() - return self.fget(instance) - - def __set__(self, instance, value): - raise AttributeError("can't set attribute") - - def __delete__(self, instance): - raise AttributeError("can't delete attribute") - - -def _is_descriptor(obj): - """Returns True if obj is a descriptor, False otherwise.""" - return ( - hasattr(obj, '__get__') or - hasattr(obj, '__set__') or - hasattr(obj, '__delete__')) - - -def _is_dunder(name): - """Returns True if a __dunder__ name, False otherwise.""" - return (name[:2] == name[-2:] == '__' and - name[2:3] != '_' and - name[-3:-2] != '_' and - len(name) > 4) - - -def _is_sunder(name): - """Returns True if a _sunder_ name, False otherwise.""" - return (name[0] == name[-1] == '_' and - name[1:2] != '_' and - name[-2:-1] != '_' and - len(name) > 2) - - -def _make_class_unpicklable(cls): - """Make the given class un-picklable.""" - def _break_on_call_reduce(self, protocol=None): - raise TypeError('%r cannot be pickled' % self) - cls.__reduce_ex__ = _break_on_call_reduce - cls.__module__ = '' - - -class _EnumDict(dict): - """Track enum member order and ensure member names are not reused. - - EnumMeta will use the names found in self._member_names as the - enumeration member names. - - """ - def __init__(self): - super(_EnumDict, self).__init__() - self._member_names = [] - - def __setitem__(self, key, value): - """Changes anything not dundered or not a descriptor. 
- - If a descriptor is added with the same name as an enum member, the name - is removed from _member_names (this may leave a hole in the numerical - sequence of values). - - If an enum member name is used twice, an error is raised; duplicate - values are not checked for. - - Single underscore (sunder) names are reserved. - - Note: in 3.x __order__ is simply discarded as a not necessary piece - leftover from 2.x - - """ - if pyver >= 3.0 and key in ('_order_', '__order__'): - return - elif key == '__order__': - key = '_order_' - if _is_sunder(key): - if key != '_order_': - raise ValueError('_names_ are reserved for future Enum use') - elif _is_dunder(key): - pass - elif key in self._member_names: - # descriptor overwriting an enum? - raise TypeError('Attempted to reuse key: %r' % key) - elif not _is_descriptor(value): - if key in self: - # enum overwriting a descriptor? - raise TypeError('Key already defined as: %r' % self[key]) - self._member_names.append(key) - super(_EnumDict, self).__setitem__(key, value) - - -# Dummy value for Enum as EnumMeta explicity checks for it, but of course until -# EnumMeta finishes running the first time the Enum class doesn't exist. This -# is also why there are checks in EnumMeta like `if Enum is not None` -Enum = None - - -class EnumMeta(type): - """Metaclass for Enum""" - @classmethod - def __prepare__(metacls, cls, bases): - return _EnumDict() - - def __new__(metacls, cls, bases, classdict): - # an Enum class is final once enumeration items have been defined; it - # cannot be mixed with other types (int, float, etc.) if it has an - # inherited __new__ unless a new __new__ is defined (or the resulting - # class will fail). - if type(classdict) is dict: - original_dict = classdict - classdict = _EnumDict() - for k, v in original_dict.items(): - classdict[k] = v - - member_type, first_enum = metacls._get_mixins_(bases) - __new__, save_new, use_args = metacls._find_new_(classdict, member_type, - first_enum) - # save enum items into separate mapping so they don't get baked into - # the new class - members = dict((k, classdict[k]) for k in classdict._member_names) - for name in classdict._member_names: - del classdict[name] - - # py2 support for definition order - _order_ = classdict.get('_order_') - if _order_ is None: - if pyver < 3.0: - try: - _order_ = [name for (name, value) in sorted(members.items(), key=lambda item: item[1])] - except TypeError: - _order_ = [name for name in sorted(members.keys())] - else: - _order_ = classdict._member_names - else: - del classdict['_order_'] - if pyver < 3.0: - _order_ = _order_.replace(',', ' ').split() - aliases = [name for name in members if name not in _order_] - _order_ += aliases - - # check for illegal enum names (any others?) - invalid_names = set(members) & set(['mro']) - if invalid_names: - raise ValueError('Invalid enum member name(s): %s' % ( - ', '.join(invalid_names), )) - - # save attributes from super classes so we know if we can take - # the shortcut of storing members in the class dict - base_attributes = set([a for b in bases for a in b.__dict__]) - # create our new Enum type - enum_class = super(EnumMeta, metacls).__new__(metacls, cls, bases, classdict) - enum_class._member_names_ = [] # names in random order - if OrderedDict is not None: - enum_class._member_map_ = OrderedDict() - else: - enum_class._member_map_ = {} # name->value map - enum_class._member_type_ = member_type - - # Reverse value->name map for hashable values. 
- enum_class._value2member_map_ = {} - - # instantiate them, checking for duplicates as we go - # we instantiate first instead of checking for duplicates first in case - # a custom __new__ is doing something funky with the values -- such as - # auto-numbering ;) - if __new__ is None: - __new__ = enum_class.__new__ - for member_name in _order_: - value = members[member_name] - if not isinstance(value, tuple): - args = (value, ) - else: - args = value - if member_type is tuple: # special case for tuple enums - args = (args, ) # wrap it one more time - if not use_args or not args: - enum_member = __new__(enum_class) - if not hasattr(enum_member, '_value_'): - enum_member._value_ = value - else: - enum_member = __new__(enum_class, *args) - if not hasattr(enum_member, '_value_'): - enum_member._value_ = member_type(*args) - value = enum_member._value_ - enum_member._name_ = member_name - enum_member.__objclass__ = enum_class - enum_member.__init__(*args) - # If another member with the same value was already defined, the - # new member becomes an alias to the existing one. - for name, canonical_member in enum_class._member_map_.items(): - if canonical_member.value == enum_member._value_: - enum_member = canonical_member - break - else: - # Aliases don't appear in member names (only in __members__). - enum_class._member_names_.append(member_name) - # performance boost for any member that would not shadow - # a DynamicClassAttribute (aka _RouteClassAttributeToGetattr) - if member_name not in base_attributes: - setattr(enum_class, member_name, enum_member) - # now add to _member_map_ - enum_class._member_map_[member_name] = enum_member - try: - # This may fail if value is not hashable. We can't add the value - # to the map, and by-value lookups for this value will be - # linear. - enum_class._value2member_map_[value] = enum_member - except TypeError: - pass - - - # If a custom type is mixed into the Enum, and it does not know how - # to pickle itself, pickle.dumps will succeed but pickle.loads will - # fail. Rather than have the error show up later and possibly far - # from the source, sabotage the pickle protocol for this class so - # that pickle.dumps also fails. - # - # However, if the new class implements its own __reduce_ex__, do not - # sabotage -- it's on them to make sure it works correctly. We use - # __reduce_ex__ instead of any of the others as it is preferred by - # pickle over __reduce__, and it handles all pickle protocols. 
- unpicklable = False - if '__reduce_ex__' not in classdict: - if member_type is not object: - methods = ('__getnewargs_ex__', '__getnewargs__', - '__reduce_ex__', '__reduce__') - if not any(m in member_type.__dict__ for m in methods): - _make_class_unpicklable(enum_class) - unpicklable = True - - - # double check that repr and friends are not the mixin's or various - # things break (such as pickle) - for name in ('__repr__', '__str__', '__format__', '__reduce_ex__'): - class_method = getattr(enum_class, name) - obj_method = getattr(member_type, name, None) - enum_method = getattr(first_enum, name, None) - if name not in classdict and class_method is not enum_method: - if name == '__reduce_ex__' and unpicklable: - continue - setattr(enum_class, name, enum_method) - - # method resolution and int's are not playing nice - # Python's less than 2.6 use __cmp__ - - if pyver < 2.6: - - if issubclass(enum_class, int): - setattr(enum_class, '__cmp__', getattr(int, '__cmp__')) - - elif pyver < 3.0: - - if issubclass(enum_class, int): - for method in ( - '__le__', - '__lt__', - '__gt__', - '__ge__', - '__eq__', - '__ne__', - '__hash__', - ): - setattr(enum_class, method, getattr(int, method)) - - # replace any other __new__ with our own (as long as Enum is not None, - # anyway) -- again, this is to support pickle - if Enum is not None: - # if the user defined their own __new__, save it before it gets - # clobbered in case they subclass later - if save_new: - setattr(enum_class, '__member_new__', enum_class.__dict__['__new__']) - setattr(enum_class, '__new__', Enum.__dict__['__new__']) - return enum_class - - def __bool__(cls): - """ - classes/types should always be True. - """ - return True - - def __call__(cls, value, names=None, module=None, type=None, start=1): - """Either returns an existing member, or creates a new enum class. - - This method is used both when an enum class is given a value to match - to an enumeration member (i.e. Color(3)) and for the functional API - (i.e. Color = Enum('Color', names='red green blue')). - - When used for the functional API: `module`, if set, will be stored in - the new class' __module__ attribute; `type`, if set, will be mixed in - as the first base class. - - Note: if `module` is not set this routine will attempt to discover the - calling module by walking the frame stack; if this is unsuccessful - the resulting class will not be pickleable. - - """ - if names is None: # simple value lookup - return cls.__new__(cls, value) - # otherwise, functional API: we're creating a new Enum type - return cls._create_(value, names, module=module, type=type, start=start) - - def __contains__(cls, member): - return isinstance(member, cls) and member.name in cls._member_map_ - - def __delattr__(cls, attr): - # nicer error message when someone tries to delete an attribute - # (see issue19025). - if attr in cls._member_map_: - raise AttributeError( - "%s: cannot delete Enum member." % cls.__name__) - super(EnumMeta, cls).__delattr__(attr) - - def __dir__(self): - return (['__class__', '__doc__', '__members__', '__module__'] + - self._member_names_) - - @property - def __members__(cls): - """Returns a mapping of member name->value. - - This mapping lists all enum members, including aliases. Note that this - is a copy of the internal mapping. 
- - """ - return cls._member_map_.copy() - - def __getattr__(cls, name): - """Return the enum member matching `name` - - We use __getattr__ instead of descriptors or inserting into the enum - class' __dict__ in order to support `name` and `value` being both - properties for enum members (which live in the class' __dict__) and - enum members themselves. - - """ - if _is_dunder(name): - raise AttributeError(name) - try: - return cls._member_map_[name] - except KeyError: - raise AttributeError(name) - - def __getitem__(cls, name): - return cls._member_map_[name] - - def __iter__(cls): - return (cls._member_map_[name] for name in cls._member_names_) - - def __reversed__(cls): - return (cls._member_map_[name] for name in reversed(cls._member_names_)) - - def __len__(cls): - return len(cls._member_names_) - - __nonzero__ = __bool__ - - def __repr__(cls): - return "" % cls.__name__ - - def __setattr__(cls, name, value): - """Block attempts to reassign Enum members. - - A simple assignment to the class namespace only changes one of the - several possible ways to get an Enum member from the Enum class, - resulting in an inconsistent Enumeration. - - """ - member_map = cls.__dict__.get('_member_map_', {}) - if name in member_map: - raise AttributeError('Cannot reassign members.') - super(EnumMeta, cls).__setattr__(name, value) - - def _create_(cls, class_name, names=None, module=None, type=None, start=1): - """Convenience method to create a new Enum class. - - `names` can be: - - * A string containing member names, separated either with spaces or - commas. Values are auto-numbered from 1. - * An iterable of member names. Values are auto-numbered from 1. - * An iterable of (member name, value) pairs. - * A mapping of member name -> value. - - """ - if pyver < 3.0: - # if class_name is unicode, attempt a conversion to ASCII - if isinstance(class_name, unicode): - try: - class_name = class_name.encode('ascii') - except UnicodeEncodeError: - raise TypeError('%r is not representable in ASCII' % class_name) - metacls = cls.__class__ - if type is None: - bases = (cls, ) - else: - bases = (type, cls) - classdict = metacls.__prepare__(class_name, bases) - _order_ = [] - - # special processing needed for names? - if isinstance(names, basestring): - names = names.replace(',', ' ').split() - if isinstance(names, (tuple, list)) and isinstance(names[0], basestring): - names = [(e, i+start) for (i, e) in enumerate(names)] - - # Here, names is either an iterable of (name, value) or a mapping. - item = None # in case names is empty - for item in names: - if isinstance(item, basestring): - member_name, member_value = item, names[item] - else: - member_name, member_value = item - classdict[member_name] = member_value - _order_.append(member_name) - # only set _order_ in classdict if name/value was not from a mapping - if not isinstance(item, basestring): - classdict['_order_'] = ' '.join(_order_) - enum_class = metacls.__new__(metacls, class_name, bases, classdict) - - # TODO: replace the frame hack if a blessed way to know the calling - # module is ever developed - if module is None: - try: - module = _sys._getframe(2).f_globals['__name__'] - except (AttributeError, ValueError): - pass - if module is None: - _make_class_unpicklable(enum_class) - else: - enum_class.__module__ = module - - return enum_class - - @staticmethod - def _get_mixins_(bases): - """Returns the type for creating enum members, and the first inherited - enum class. 
- - bases: the tuple of bases that was given to __new__ - - """ - if not bases or Enum is None: - return object, Enum - - - # double check that we are not subclassing a class with existing - # enumeration members; while we're at it, see if any other data - # type has been mixed in so we can use the correct __new__ - member_type = first_enum = None - for base in bases: - if (base is not Enum and - issubclass(base, Enum) and - base._member_names_): - raise TypeError("Cannot extend enumerations") - # base is now the last base in bases - if not issubclass(base, Enum): - raise TypeError("new enumerations must be created as " - "`ClassName([mixin_type,] enum_type)`") - - # get correct mix-in type (either mix-in type of Enum subclass, or - # first base if last base is Enum) - if not issubclass(bases[0], Enum): - member_type = bases[0] # first data type - first_enum = bases[-1] # enum type - else: - for base in bases[0].__mro__: - # most common: (IntEnum, int, Enum, object) - # possible: (, , - # , , - # ) - if issubclass(base, Enum): - if first_enum is None: - first_enum = base - else: - if member_type is None: - member_type = base - - return member_type, first_enum - - if pyver < 3.0: - @staticmethod - def _find_new_(classdict, member_type, first_enum): - """Returns the __new__ to be used for creating the enum members. - - classdict: the class dictionary given to __new__ - member_type: the data type whose __new__ will be used by default - first_enum: enumeration to check for an overriding __new__ - - """ - # now find the correct __new__, checking to see of one was defined - # by the user; also check earlier enum classes in case a __new__ was - # saved as __member_new__ - __new__ = classdict.get('__new__', None) - if __new__: - return None, True, True # __new__, save_new, use_args - - N__new__ = getattr(None, '__new__') - O__new__ = getattr(object, '__new__') - if Enum is None: - E__new__ = N__new__ - else: - E__new__ = Enum.__dict__['__new__'] - # check all possibles for __member_new__ before falling back to - # __new__ - for method in ('__member_new__', '__new__'): - for possible in (member_type, first_enum): - try: - target = possible.__dict__[method] - except (AttributeError, KeyError): - target = getattr(possible, method, None) - if target not in [ - None, - N__new__, - O__new__, - E__new__, - ]: - if method == '__member_new__': - classdict['__new__'] = target - return None, False, True - if isinstance(target, staticmethod): - target = target.__get__(member_type) - __new__ = target - break - if __new__ is not None: - break - else: - __new__ = object.__new__ - - # if a non-object.__new__ is used then whatever value/tuple was - # assigned to the enum member name will be passed to __new__ and to the - # new enum member's __init__ - if __new__ is object.__new__: - use_args = False - else: - use_args = True - - return __new__, False, use_args - else: - @staticmethod - def _find_new_(classdict, member_type, first_enum): - """Returns the __new__ to be used for creating the enum members. - - classdict: the class dictionary given to __new__ - member_type: the data type whose __new__ will be used by default - first_enum: enumeration to check for an overriding __new__ - - """ - # now find the correct __new__, checking to see of one was defined - # by the user; also check earlier enum classes in case a __new__ was - # saved as __member_new__ - __new__ = classdict.get('__new__', None) - - # should __new__ be saved as __member_new__ later? 
- save_new = __new__ is not None - - if __new__ is None: - # check all possibles for __member_new__ before falling back to - # __new__ - for method in ('__member_new__', '__new__'): - for possible in (member_type, first_enum): - target = getattr(possible, method, None) - if target not in ( - None, - None.__new__, - object.__new__, - Enum.__new__, - ): - __new__ = target - break - if __new__ is not None: - break - else: - __new__ = object.__new__ - - # if a non-object.__new__ is used then whatever value/tuple was - # assigned to the enum member name will be passed to __new__ and to the - # new enum member's __init__ - if __new__ is object.__new__: - use_args = False - else: - use_args = True - - return __new__, save_new, use_args - - -######################################################## -# In order to support Python 2 and 3 with a single -# codebase we have to create the Enum methods separately -# and then use the `type(name, bases, dict)` method to -# create the class. -######################################################## -temp_enum_dict = {} -temp_enum_dict['__doc__'] = "Generic enumeration.\n\n Derive from this class to define new enumerations.\n\n" - -def __new__(cls, value): - # all enum instances are actually created during class construction - # without calling this method; this method is called by the metaclass' - # __call__ (i.e. Color(3) ), and by pickle - if type(value) is cls: - # For lookups like Color(Color.red) - value = value.value - #return value - # by-value search for a matching enum member - # see if it's in the reverse mapping (for hashable values) - try: - if value in cls._value2member_map_: - return cls._value2member_map_[value] - except TypeError: - # not there, now do long search -- O(n) behavior - for member in cls._member_map_.values(): - if member.value == value: - return member - raise ValueError("%s is not a valid %s" % (value, cls.__name__)) -temp_enum_dict['__new__'] = __new__ -del __new__ - -def __repr__(self): - return "<%s.%s: %r>" % ( - self.__class__.__name__, self._name_, self._value_) -temp_enum_dict['__repr__'] = __repr__ -del __repr__ - -def __str__(self): - return "%s.%s" % (self.__class__.__name__, self._name_) -temp_enum_dict['__str__'] = __str__ -del __str__ - -if pyver >= 3.0: - def __dir__(self): - added_behavior = [ - m - for cls in self.__class__.mro() - for m in cls.__dict__ - if m[0] != '_' and m not in self._member_map_ - ] - return (['__class__', '__doc__', '__module__', ] + added_behavior) - temp_enum_dict['__dir__'] = __dir__ - del __dir__ - -def __format__(self, format_spec): - # mixed-in Enums should use the mixed-in type's __format__, otherwise - # we can get strange results with the Enum name showing up instead of - # the value - - # pure Enum branch - if self._member_type_ is object: - cls = str - val = str(self) - # mix-in branch - else: - cls = self._member_type_ - val = self.value - return cls.__format__(val, format_spec) -temp_enum_dict['__format__'] = __format__ -del __format__ - - -#################################### -# Python's less than 2.6 use __cmp__ - -if pyver < 2.6: - - def __cmp__(self, other): - if type(other) is self.__class__: - if self is other: - return 0 - return -1 - return NotImplemented - raise TypeError("unorderable types: %s() and %s()" % (self.__class__.__name__, other.__class__.__name__)) - temp_enum_dict['__cmp__'] = __cmp__ - del __cmp__ - -else: - - def __le__(self, other): - raise TypeError("unorderable types: %s() <= %s()" % (self.__class__.__name__, other.__class__.__name__)) - 
temp_enum_dict['__le__'] = __le__ - del __le__ - - def __lt__(self, other): - raise TypeError("unorderable types: %s() < %s()" % (self.__class__.__name__, other.__class__.__name__)) - temp_enum_dict['__lt__'] = __lt__ - del __lt__ - - def __ge__(self, other): - raise TypeError("unorderable types: %s() >= %s()" % (self.__class__.__name__, other.__class__.__name__)) - temp_enum_dict['__ge__'] = __ge__ - del __ge__ - - def __gt__(self, other): - raise TypeError("unorderable types: %s() > %s()" % (self.__class__.__name__, other.__class__.__name__)) - temp_enum_dict['__gt__'] = __gt__ - del __gt__ - - -def __eq__(self, other): - if type(other) is self.__class__: - return self is other - return NotImplemented -temp_enum_dict['__eq__'] = __eq__ -del __eq__ - -def __ne__(self, other): - if type(other) is self.__class__: - return self is not other - return NotImplemented -temp_enum_dict['__ne__'] = __ne__ -del __ne__ - -def __hash__(self): - return hash(self._name_) -temp_enum_dict['__hash__'] = __hash__ -del __hash__ - -def __reduce_ex__(self, proto): - return self.__class__, (self._value_, ) -temp_enum_dict['__reduce_ex__'] = __reduce_ex__ -del __reduce_ex__ - -# _RouteClassAttributeToGetattr is used to provide access to the `name` -# and `value` properties of enum members while keeping some measure of -# protection from modification, while still allowing for an enumeration -# to have members named `name` and `value`. This works because enumeration -# members are not set directly on the enum class -- __getattr__ is -# used to look them up. - -@_RouteClassAttributeToGetattr -def name(self): - return self._name_ -temp_enum_dict['name'] = name -del name - -@_RouteClassAttributeToGetattr -def value(self): - return self._value_ -temp_enum_dict['value'] = value -del value - -@classmethod -def _convert(cls, name, module, filter, source=None): - """ - Create a new Enum subclass that replaces a collection of global constants - """ - # convert all constants from source (or module) that pass filter() to - # a new Enum called name, and export the enum and its members back to - # module; - # also, replace the __reduce_ex__ method so unpickling works in - # previous Python versions - module_globals = vars(_sys.modules[module]) - if source: - source = vars(source) - else: - source = module_globals - members = dict((name, value) for name, value in source.items() if filter(name)) - cls = cls(name, members, module=module) - cls.__reduce_ex__ = _reduce_ex_by_name - module_globals.update(cls.__members__) - module_globals[name] = cls - return cls -temp_enum_dict['_convert'] = _convert -del _convert - -Enum = EnumMeta('Enum', (object, ), temp_enum_dict) -del temp_enum_dict - -# Enum has now been created -########################### - -class IntEnum(int, Enum): - """Enum where members are also (and must be) ints""" - -def _reduce_ex_by_name(self, proto): - return self.name - -def unique(enumeration): - """Class decorator that ensures only unique members exist in an enumeration.""" - duplicates = [] - for name, member in enumeration.__members__.items(): - if name != member.name: - duplicates.append((name, member.name)) - if duplicates: - duplicate_names = ', '.join( - ["%s -> %s" % (alias, name) for (alias, name) in duplicates] - ) - raise ValueError('duplicate names found in %r: %s' % - (enumeration, duplicate_names) - ) - return enumeration diff --git a/kafka/vendor/selectors34.py b/kafka/vendor/selectors34.py deleted file mode 100644 index 787490340..000000000 --- a/kafka/vendor/selectors34.py +++ /dev/null @@ -1,641 
+0,0 @@ -# pylint: skip-file -# vendored from https://github.com/berkerpeksag/selectors34 -# at commit ff61b82168d2cc9c4922ae08e2a8bf94aab61ea2 (unreleased, ~1.2) -# -# Original author: Charles-Francois Natali (c.f.natali[at]gmail.com) -# Maintainer: Berker Peksag (berker.peksag[at]gmail.com) -# Also see https://pypi.python.org/pypi/selectors34 -"""Selectors module. - -This module allows high-level and efficient I/O multiplexing, built upon the -`select` module primitives. - -The following code adapted from trollius.selectors. -""" -from __future__ import absolute_import - -from abc import ABCMeta, abstractmethod -from collections import namedtuple -try: - from collections.abc import Mapping -except ImportError: - from collections import Mapping -from errno import EINTR -import math -import select -import sys - -from kafka.vendor import six - - -def _wrap_error(exc, mapping, key): - if key not in mapping: - return - new_err_cls = mapping[key] - new_err = new_err_cls(*exc.args) - - # raise a new exception with the original traceback - if hasattr(exc, '__traceback__'): - traceback = exc.__traceback__ - else: - traceback = sys.exc_info()[2] - six.reraise(new_err_cls, new_err, traceback) - - -# generic events, that must be mapped to implementation-specific ones -EVENT_READ = (1 << 0) -EVENT_WRITE = (1 << 1) - - -def _fileobj_to_fd(fileobj): - """Return a file descriptor from a file object. - - Parameters: - fileobj -- file object or file descriptor - - Returns: - corresponding file descriptor - - Raises: - ValueError if the object is invalid - """ - if isinstance(fileobj, six.integer_types): - fd = fileobj - else: - try: - fd = int(fileobj.fileno()) - except (AttributeError, TypeError, ValueError): - raise ValueError("Invalid file object: " - "{0!r}".format(fileobj)) - if fd < 0: - raise ValueError("Invalid file descriptor: {0}".format(fd)) - return fd - - -SelectorKey = namedtuple('SelectorKey', ['fileobj', 'fd', 'events', 'data']) -"""Object used to associate a file object to its backing file descriptor, -selected event mask and attached data.""" - - -class _SelectorMapping(Mapping): - """Mapping of file objects to selector keys.""" - - def __init__(self, selector): - self._selector = selector - - def __len__(self): - return len(self._selector._fd_to_key) - - def __getitem__(self, fileobj): - try: - fd = self._selector._fileobj_lookup(fileobj) - return self._selector._fd_to_key[fd] - except KeyError: - raise KeyError("{0!r} is not registered".format(fileobj)) - - def __iter__(self): - return iter(self._selector._fd_to_key) - -# Using six.add_metaclass() decorator instead of six.with_metaclass() because -# the latter leaks temporary_class to garbage with gc disabled -@six.add_metaclass(ABCMeta) -class BaseSelector(object): - """Selector abstract base class. - - A selector supports registering file objects to be monitored for specific - I/O events. - - A file object is a file descriptor or any object with a `fileno()` method. - An arbitrary object can be attached to the file object, which can be used - for example to store context information, a callback, etc. - - A selector can use various implementations (select(), poll(), epoll()...) - depending on the platform. The default `Selector` class uses the most - efficient implementation on the current platform. - """ - - @abstractmethod - def register(self, fileobj, events, data=None): - """Register a file object. 
- - Parameters: - fileobj -- file object or file descriptor - events -- events to monitor (bitwise mask of EVENT_READ|EVENT_WRITE) - data -- attached data - - Returns: - SelectorKey instance - - Raises: - ValueError if events is invalid - KeyError if fileobj is already registered - OSError if fileobj is closed or otherwise is unacceptable to - the underlying system call (if a system call is made) - - Note: - OSError may or may not be raised - """ - raise NotImplementedError - - @abstractmethod - def unregister(self, fileobj): - """Unregister a file object. - - Parameters: - fileobj -- file object or file descriptor - - Returns: - SelectorKey instance - - Raises: - KeyError if fileobj is not registered - - Note: - If fileobj is registered but has since been closed this does - *not* raise OSError (even if the wrapped syscall does) - """ - raise NotImplementedError - - def modify(self, fileobj, events, data=None): - """Change a registered file object monitored events or attached data. - - Parameters: - fileobj -- file object or file descriptor - events -- events to monitor (bitwise mask of EVENT_READ|EVENT_WRITE) - data -- attached data - - Returns: - SelectorKey instance - - Raises: - Anything that unregister() or register() raises - """ - self.unregister(fileobj) - return self.register(fileobj, events, data) - - @abstractmethod - def select(self, timeout=None): - """Perform the actual selection, until some monitored file objects are - ready or a timeout expires. - - Parameters: - timeout -- if timeout > 0, this specifies the maximum wait time, in - seconds - if timeout <= 0, the select() call won't block, and will - report the currently ready file objects - if timeout is None, select() will block until a monitored - file object becomes ready - - Returns: - list of (key, events) for ready file objects - `events` is a bitwise mask of EVENT_READ|EVENT_WRITE - """ - raise NotImplementedError - - def close(self): - """Close the selector. - - This must be called to make sure that any underlying resource is freed. - """ - pass - - def get_key(self, fileobj): - """Return the key associated to a registered file object. - - Returns: - SelectorKey for this file object - """ - mapping = self.get_map() - if mapping is None: - raise RuntimeError('Selector is closed') - try: - return mapping[fileobj] - except KeyError: - raise KeyError("{0!r} is not registered".format(fileobj)) - - @abstractmethod - def get_map(self): - """Return a mapping of file objects to selector keys.""" - raise NotImplementedError - - def __enter__(self): - return self - - def __exit__(self, *args): - self.close() - - -class _BaseSelectorImpl(BaseSelector): - """Base selector implementation.""" - - def __init__(self): - # this maps file descriptors to keys - self._fd_to_key = {} - # read-only mapping returned by get_map() - self._map = _SelectorMapping(self) - - def _fileobj_lookup(self, fileobj): - """Return a file descriptor from a file object. - - This wraps _fileobj_to_fd() to do an exhaustive search in case - the object is invalid but we still have it in our map. This - is used by unregister() so we can unregister an object that - was previously registered even if it is closed. It is also - used by _SelectorMapping. - """ - try: - return _fileobj_to_fd(fileobj) - except ValueError: - # Do an exhaustive search. - for key in self._fd_to_key.values(): - if key.fileobj is fileobj: - return key.fd - # Raise ValueError after all. 
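# Illustrative sketch (the broker address and data label below are hypothetical,
# not taken from this patch): the vendored selectors34 backport being deleted
# here mirrors the stdlib `selectors` API, so after its removal the same
# register()/select() pattern comes straight from the standard library.
import selectors
import socket

sel = selectors.DefaultSelector()         # picks epoll/kqueue/devpoll/poll/select
sock = socket.socket()
sock.setblocking(False)
sock.connect_ex(('localhost', 9092))      # non-blocking connect; address is illustrative
sel.register(sock, selectors.EVENT_READ | selectors.EVENT_WRITE, data='broker')
for key, events in sel.select(timeout=1.0):
    if events & selectors.EVENT_WRITE:
        pass  # socket is writable (connect completed or failed)
    if events & selectors.EVENT_READ:
        pass  # socket has data to read
sel.unregister(sock)
sel.close()
sock.close()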
- raise - - def register(self, fileobj, events, data=None): - if (not events) or (events & ~(EVENT_READ | EVENT_WRITE)): - raise ValueError("Invalid events: {0!r}".format(events)) - - key = SelectorKey(fileobj, self._fileobj_lookup(fileobj), events, data) - - if key.fd in self._fd_to_key: - raise KeyError("{0!r} (FD {1}) is already registered" - .format(fileobj, key.fd)) - - self._fd_to_key[key.fd] = key - return key - - def unregister(self, fileobj): - try: - key = self._fd_to_key.pop(self._fileobj_lookup(fileobj)) - except KeyError: - raise KeyError("{0!r} is not registered".format(fileobj)) - return key - - def modify(self, fileobj, events, data=None): - # TODO: Subclasses can probably optimize this even further. - try: - key = self._fd_to_key[self._fileobj_lookup(fileobj)] - except KeyError: - raise KeyError("{0!r} is not registered".format(fileobj)) - if events != key.events: - self.unregister(fileobj) - key = self.register(fileobj, events, data) - elif data != key.data: - # Use a shortcut to update the data. - key = key._replace(data=data) - self._fd_to_key[key.fd] = key - return key - - def close(self): - self._fd_to_key.clear() - self._map = None - - def get_map(self): - return self._map - - def _key_from_fd(self, fd): - """Return the key associated to a given file descriptor. - - Parameters: - fd -- file descriptor - - Returns: - corresponding key, or None if not found - """ - try: - return self._fd_to_key[fd] - except KeyError: - return None - - -class SelectSelector(_BaseSelectorImpl): - """Select-based selector.""" - - def __init__(self): - super(SelectSelector, self).__init__() - self._readers = set() - self._writers = set() - - def register(self, fileobj, events, data=None): - key = super(SelectSelector, self).register(fileobj, events, data) - if events & EVENT_READ: - self._readers.add(key.fd) - if events & EVENT_WRITE: - self._writers.add(key.fd) - return key - - def unregister(self, fileobj): - key = super(SelectSelector, self).unregister(fileobj) - self._readers.discard(key.fd) - self._writers.discard(key.fd) - return key - - if sys.platform == 'win32': - def _select(self, r, w, _, timeout=None): - r, w, x = select.select(r, w, w, timeout) - return r, w + x, [] - else: - _select = staticmethod(select.select) - - def select(self, timeout=None): - timeout = None if timeout is None else max(timeout, 0) - ready = [] - try: - r, w, _ = self._select(self._readers, self._writers, [], timeout) - except select.error as exc: - if exc.args[0] == EINTR: - return ready - else: - raise - r = set(r) - w = set(w) - for fd in r | w: - events = 0 - if fd in r: - events |= EVENT_READ - if fd in w: - events |= EVENT_WRITE - - key = self._key_from_fd(fd) - if key: - ready.append((key, events & key.events)) - return ready - - -if hasattr(select, 'poll'): - - class PollSelector(_BaseSelectorImpl): - """Poll-based selector.""" - - def __init__(self): - super(PollSelector, self).__init__() - self._poll = select.poll() - - def register(self, fileobj, events, data=None): - key = super(PollSelector, self).register(fileobj, events, data) - poll_events = 0 - if events & EVENT_READ: - poll_events |= select.POLLIN - if events & EVENT_WRITE: - poll_events |= select.POLLOUT - self._poll.register(key.fd, poll_events) - return key - - def unregister(self, fileobj): - key = super(PollSelector, self).unregister(fileobj) - self._poll.unregister(key.fd) - return key - - def select(self, timeout=None): - if timeout is None: - timeout = None - elif timeout <= 0: - timeout = 0 - else: - # poll() has a resolution 
of 1 millisecond, round away from - # zero to wait *at least* timeout seconds. - timeout = int(math.ceil(timeout * 1e3)) - ready = [] - try: - fd_event_list = self._poll.poll(timeout) - except select.error as exc: - if exc.args[0] == EINTR: - return ready - else: - raise - for fd, event in fd_event_list: - events = 0 - if event & ~select.POLLIN: - events |= EVENT_WRITE - if event & ~select.POLLOUT: - events |= EVENT_READ - - key = self._key_from_fd(fd) - if key: - ready.append((key, events & key.events)) - return ready - - -if hasattr(select, 'epoll'): - - class EpollSelector(_BaseSelectorImpl): - """Epoll-based selector.""" - - def __init__(self): - super(EpollSelector, self).__init__() - self._epoll = select.epoll() - - def fileno(self): - return self._epoll.fileno() - - def register(self, fileobj, events, data=None): - key = super(EpollSelector, self).register(fileobj, events, data) - epoll_events = 0 - if events & EVENT_READ: - epoll_events |= select.EPOLLIN - if events & EVENT_WRITE: - epoll_events |= select.EPOLLOUT - self._epoll.register(key.fd, epoll_events) - return key - - def unregister(self, fileobj): - key = super(EpollSelector, self).unregister(fileobj) - try: - self._epoll.unregister(key.fd) - except IOError: - # This can happen if the FD was closed since it - # was registered. - pass - return key - - def select(self, timeout=None): - if timeout is None: - timeout = -1 - elif timeout <= 0: - timeout = 0 - else: - # epoll_wait() has a resolution of 1 millisecond, round away - # from zero to wait *at least* timeout seconds. - timeout = math.ceil(timeout * 1e3) * 1e-3 - - # epoll_wait() expects `maxevents` to be greater than zero; - # we want to make sure that `select()` can be called when no - # FD is registered. - max_ev = max(len(self._fd_to_key), 1) - - ready = [] - try: - fd_event_list = self._epoll.poll(timeout, max_ev) - except IOError as exc: - if exc.errno == EINTR: - return ready - else: - raise - for fd, event in fd_event_list: - events = 0 - if event & ~select.EPOLLIN: - events |= EVENT_WRITE - if event & ~select.EPOLLOUT: - events |= EVENT_READ - - key = self._key_from_fd(fd) - if key: - ready.append((key, events & key.events)) - return ready - - def close(self): - self._epoll.close() - super(EpollSelector, self).close() - - -if hasattr(select, 'devpoll'): - - class DevpollSelector(_BaseSelectorImpl): - """Solaris /dev/poll selector.""" - - def __init__(self): - super(DevpollSelector, self).__init__() - self._devpoll = select.devpoll() - - def fileno(self): - return self._devpoll.fileno() - - def register(self, fileobj, events, data=None): - key = super(DevpollSelector, self).register(fileobj, events, data) - poll_events = 0 - if events & EVENT_READ: - poll_events |= select.POLLIN - if events & EVENT_WRITE: - poll_events |= select.POLLOUT - self._devpoll.register(key.fd, poll_events) - return key - - def unregister(self, fileobj): - key = super(DevpollSelector, self).unregister(fileobj) - self._devpoll.unregister(key.fd) - return key - - def select(self, timeout=None): - if timeout is None: - timeout = None - elif timeout <= 0: - timeout = 0 - else: - # devpoll() has a resolution of 1 millisecond, round away from - # zero to wait *at least* timeout seconds. 
- timeout = math.ceil(timeout * 1e3) - ready = [] - try: - fd_event_list = self._devpoll.poll(timeout) - except OSError as exc: - if exc.errno == EINTR: - return ready - else: - raise - for fd, event in fd_event_list: - events = 0 - if event & ~select.POLLIN: - events |= EVENT_WRITE - if event & ~select.POLLOUT: - events |= EVENT_READ - - key = self._key_from_fd(fd) - if key: - ready.append((key, events & key.events)) - return ready - - def close(self): - self._devpoll.close() - super(DevpollSelector, self).close() - - -if hasattr(select, 'kqueue'): - - class KqueueSelector(_BaseSelectorImpl): - """Kqueue-based selector.""" - - def __init__(self): - super(KqueueSelector, self).__init__() - self._kqueue = select.kqueue() - - def fileno(self): - return self._kqueue.fileno() - - def register(self, fileobj, events, data=None): - key = super(KqueueSelector, self).register(fileobj, events, data) - if events & EVENT_READ: - kev = select.kevent(key.fd, select.KQ_FILTER_READ, - select.KQ_EV_ADD) - self._kqueue.control([kev], 0, 0) - if events & EVENT_WRITE: - kev = select.kevent(key.fd, select.KQ_FILTER_WRITE, - select.KQ_EV_ADD) - self._kqueue.control([kev], 0, 0) - return key - - def unregister(self, fileobj): - key = super(KqueueSelector, self).unregister(fileobj) - if key.events & EVENT_READ: - kev = select.kevent(key.fd, select.KQ_FILTER_READ, - select.KQ_EV_DELETE) - try: - self._kqueue.control([kev], 0, 0) - except OSError: - # This can happen if the FD was closed since it - # was registered. - pass - if key.events & EVENT_WRITE: - kev = select.kevent(key.fd, select.KQ_FILTER_WRITE, - select.KQ_EV_DELETE) - try: - self._kqueue.control([kev], 0, 0) - except OSError: - # See comment above. - pass - return key - - def select(self, timeout=None): - timeout = None if timeout is None else max(timeout, 0) - max_ev = len(self._fd_to_key) - ready = [] - try: - kev_list = self._kqueue.control(None, max_ev, timeout) - except OSError as exc: - if exc.errno == EINTR: - return ready - else: - raise - for kev in kev_list: - fd = kev.ident - flag = kev.filter - events = 0 - if flag == select.KQ_FILTER_READ: - events |= EVENT_READ - if flag == select.KQ_FILTER_WRITE: - events |= EVENT_WRITE - - key = self._key_from_fd(fd) - if key: - ready.append((key, events & key.events)) - return ready - - def close(self): - self._kqueue.close() - super(KqueueSelector, self).close() - - -# Choose the best implementation, roughly: -# epoll|kqueue|devpoll > poll > select. 
-# select() also can't accept a FD > FD_SETSIZE (usually around 1024) -if 'KqueueSelector' in globals(): - DefaultSelector = KqueueSelector -elif 'EpollSelector' in globals(): - DefaultSelector = EpollSelector -elif 'DevpollSelector' in globals(): - DefaultSelector = DevpollSelector -elif 'PollSelector' in globals(): - DefaultSelector = PollSelector -else: - DefaultSelector = SelectSelector diff --git a/kafka/vendor/six.py b/kafka/vendor/six.py deleted file mode 100644 index 319821353..000000000 --- a/kafka/vendor/six.py +++ /dev/null @@ -1,1004 +0,0 @@ -# pylint: skip-file - -# Copyright (c) 2010-2020 Benjamin Peterson -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. - -"""Utilities for writing code that runs on Python 2 and 3""" - -from __future__ import absolute_import - -import functools -import itertools -import operator -import sys -import types - -__author__ = "Benjamin Peterson " -__version__ = "1.16.0" - - -# Useful for very coarse version differentiation. -PY2 = sys.version_info[0] == 2 -PY3 = sys.version_info[0] == 3 -PY34 = sys.version_info[0:2] >= (3, 4) - -if PY3: - string_types = str, - integer_types = int, - class_types = type, - text_type = str - binary_type = bytes - - MAXSIZE = sys.maxsize -else: - string_types = basestring, - integer_types = (int, long) - class_types = (type, types.ClassType) - text_type = unicode - binary_type = str - - if sys.platform.startswith("java"): - # Jython always uses 32 bits. - MAXSIZE = int((1 << 31) - 1) - else: - # It's possible to have sizeof(long) != sizeof(Py_ssize_t). - class X(object): - - def __len__(self): - return 1 << 31 - try: - len(X()) - except OverflowError: - # 32-bit - MAXSIZE = int((1 << 31) - 1) - else: - # 64-bit - MAXSIZE = int((1 << 63) - 1) - - # Don't del it here, cause with gc disabled this "leaks" to garbage. - # Note: This is a kafka-python customization, details at: - # https://github.com/dpkp/kafka-python/pull/979#discussion_r100403389 - # del X - -if PY34: - from importlib.util import spec_from_loader -else: - spec_from_loader = None - - -def _add_doc(func, doc): - """Add documentation to a function.""" - func.__doc__ = doc - - -def _import_module(name): - """Import module, returning the module after the last dot.""" - __import__(name) - return sys.modules[name] - - -class _LazyDescr(object): - - def __init__(self, name): - self.name = name - - def __get__(self, obj, tp): - result = self._resolve() - setattr(obj, self.name, result) # Invokes __set__. 
- try: - # This is a bit ugly, but it avoids running this again by - # removing this descriptor. - delattr(obj.__class__, self.name) - except AttributeError: - pass - return result - - -class MovedModule(_LazyDescr): - - def __init__(self, name, old, new=None): - super(MovedModule, self).__init__(name) - if PY3: - if new is None: - new = name - self.mod = new - else: - self.mod = old - - def _resolve(self): - return _import_module(self.mod) - - def __getattr__(self, attr): - _module = self._resolve() - value = getattr(_module, attr) - setattr(self, attr, value) - return value - - -class _LazyModule(types.ModuleType): - - def __init__(self, name): - super(_LazyModule, self).__init__(name) - self.__doc__ = self.__class__.__doc__ - - def __dir__(self): - attrs = ["__doc__", "__name__"] - attrs += [attr.name for attr in self._moved_attributes] - return attrs - - # Subclasses should override this - _moved_attributes = [] - - -class MovedAttribute(_LazyDescr): - - def __init__(self, name, old_mod, new_mod, old_attr=None, new_attr=None): - super(MovedAttribute, self).__init__(name) - if PY3: - if new_mod is None: - new_mod = name - self.mod = new_mod - if new_attr is None: - if old_attr is None: - new_attr = name - else: - new_attr = old_attr - self.attr = new_attr - else: - self.mod = old_mod - if old_attr is None: - old_attr = name - self.attr = old_attr - - def _resolve(self): - module = _import_module(self.mod) - return getattr(module, self.attr) - - -class _SixMetaPathImporter(object): - - """ - A meta path importer to import six.moves and its submodules. - - This class implements a PEP302 finder and loader. It should be compatible - with Python 2.5 and all existing versions of Python3 - """ - - def __init__(self, six_module_name): - self.name = six_module_name - self.known_modules = {} - - def _add_module(self, mod, *fullnames): - for fullname in fullnames: - self.known_modules[self.name + "." + fullname] = mod - - def _get_module(self, fullname): - return self.known_modules[self.name + "." + fullname] - - def find_module(self, fullname, path=None): - if fullname in self.known_modules: - return self - return None - - def find_spec(self, fullname, path, target=None): - if fullname in self.known_modules: - return spec_from_loader(fullname, self) - return None - - def __get_module(self, fullname): - try: - return self.known_modules[fullname] - except KeyError: - raise ImportError("This loader does not know module " + fullname) - - def load_module(self, fullname): - try: - # in case of a reload - return sys.modules[fullname] - except KeyError: - pass - mod = self.__get_module(fullname) - if isinstance(mod, MovedModule): - mod = mod._resolve() - else: - mod.__loader__ = self - sys.modules[fullname] = mod - return mod - - def is_package(self, fullname): - """ - Return true, if the named module is a package. 
- - We need this method to get correct spec objects with - Python 3.4 (see PEP451) - """ - return hasattr(self.__get_module(fullname), "__path__") - - def get_code(self, fullname): - """Return None - - Required, if is_package is implemented""" - self.__get_module(fullname) # eventually raises ImportError - return None - get_source = get_code # same as get_code - - def create_module(self, spec): - return self.load_module(spec.name) - - def exec_module(self, module): - pass - -_importer = _SixMetaPathImporter(__name__) - - -class _MovedItems(_LazyModule): - - """Lazy loading of moved objects""" - __path__ = [] # mark as package - - -_moved_attributes = [ - MovedAttribute("cStringIO", "cStringIO", "io", "StringIO"), - MovedAttribute("filter", "itertools", "builtins", "ifilter", "filter"), - MovedAttribute("filterfalse", "itertools", "itertools", "ifilterfalse", "filterfalse"), - MovedAttribute("input", "__builtin__", "builtins", "raw_input", "input"), - MovedAttribute("intern", "__builtin__", "sys"), - MovedAttribute("map", "itertools", "builtins", "imap", "map"), - MovedAttribute("getcwd", "os", "os", "getcwdu", "getcwd"), - MovedAttribute("getcwdb", "os", "os", "getcwd", "getcwdb"), - MovedAttribute("getoutput", "commands", "subprocess"), - MovedAttribute("range", "__builtin__", "builtins", "xrange", "range"), - MovedAttribute("reload_module", "__builtin__", "importlib" if PY34 else "imp", "reload"), - MovedAttribute("reduce", "__builtin__", "functools"), - MovedAttribute("shlex_quote", "pipes", "shlex", "quote"), - MovedAttribute("StringIO", "StringIO", "io"), - MovedAttribute("UserDict", "UserDict", "collections", "IterableUserDict", "UserDict"), - MovedAttribute("UserList", "UserList", "collections"), - MovedAttribute("UserString", "UserString", "collections"), - MovedAttribute("xrange", "__builtin__", "builtins", "xrange", "range"), - MovedAttribute("zip", "itertools", "builtins", "izip", "zip"), - MovedAttribute("zip_longest", "itertools", "itertools", "izip_longest", "zip_longest"), - MovedModule("builtins", "__builtin__"), - MovedModule("configparser", "ConfigParser"), - MovedModule("collections_abc", "collections", "collections.abc" if sys.version_info >= (3, 3) else "collections"), - MovedModule("copyreg", "copy_reg"), - MovedModule("dbm_gnu", "gdbm", "dbm.gnu"), - MovedModule("dbm_ndbm", "dbm", "dbm.ndbm"), - MovedModule("_dummy_thread", "dummy_thread", "_dummy_thread" if sys.version_info < (3, 9) else "_thread"), - MovedModule("http_cookiejar", "cookielib", "http.cookiejar"), - MovedModule("http_cookies", "Cookie", "http.cookies"), - MovedModule("html_entities", "htmlentitydefs", "html.entities"), - MovedModule("html_parser", "HTMLParser", "html.parser"), - MovedModule("http_client", "httplib", "http.client"), - MovedModule("email_mime_base", "email.MIMEBase", "email.mime.base"), - MovedModule("email_mime_image", "email.MIMEImage", "email.mime.image"), - MovedModule("email_mime_multipart", "email.MIMEMultipart", "email.mime.multipart"), - MovedModule("email_mime_nonmultipart", "email.MIMENonMultipart", "email.mime.nonmultipart"), - MovedModule("email_mime_text", "email.MIMEText", "email.mime.text"), - MovedModule("BaseHTTPServer", "BaseHTTPServer", "http.server"), - MovedModule("CGIHTTPServer", "CGIHTTPServer", "http.server"), - MovedModule("SimpleHTTPServer", "SimpleHTTPServer", "http.server"), - MovedModule("cPickle", "cPickle", "pickle"), - MovedModule("queue", "Queue"), - MovedModule("reprlib", "repr"), - MovedModule("socketserver", "SocketServer"), - MovedModule("_thread", 
"thread", "_thread"), - MovedModule("tkinter", "Tkinter"), - MovedModule("tkinter_dialog", "Dialog", "tkinter.dialog"), - MovedModule("tkinter_filedialog", "FileDialog", "tkinter.filedialog"), - MovedModule("tkinter_scrolledtext", "ScrolledText", "tkinter.scrolledtext"), - MovedModule("tkinter_simpledialog", "SimpleDialog", "tkinter.simpledialog"), - MovedModule("tkinter_tix", "Tix", "tkinter.tix"), - MovedModule("tkinter_ttk", "ttk", "tkinter.ttk"), - MovedModule("tkinter_constants", "Tkconstants", "tkinter.constants"), - MovedModule("tkinter_dnd", "Tkdnd", "tkinter.dnd"), - MovedModule("tkinter_colorchooser", "tkColorChooser", - "tkinter.colorchooser"), - MovedModule("tkinter_commondialog", "tkCommonDialog", - "tkinter.commondialog"), - MovedModule("tkinter_tkfiledialog", "tkFileDialog", "tkinter.filedialog"), - MovedModule("tkinter_font", "tkFont", "tkinter.font"), - MovedModule("tkinter_messagebox", "tkMessageBox", "tkinter.messagebox"), - MovedModule("tkinter_tksimpledialog", "tkSimpleDialog", - "tkinter.simpledialog"), - MovedModule("urllib_parse", __name__ + ".moves.urllib_parse", "urllib.parse"), - MovedModule("urllib_error", __name__ + ".moves.urllib_error", "urllib.error"), - MovedModule("urllib", __name__ + ".moves.urllib", __name__ + ".moves.urllib"), - MovedModule("urllib_robotparser", "robotparser", "urllib.robotparser"), - MovedModule("xmlrpc_client", "xmlrpclib", "xmlrpc.client"), - MovedModule("xmlrpc_server", "SimpleXMLRPCServer", "xmlrpc.server"), -] -# Add windows specific modules. -if sys.platform == "win32": - _moved_attributes += [ - MovedModule("winreg", "_winreg"), - ] - -for attr in _moved_attributes: - setattr(_MovedItems, attr.name, attr) - if isinstance(attr, MovedModule): - _importer._add_module(attr, "moves." + attr.name) -del attr - -_MovedItems._moved_attributes = _moved_attributes - -moves = _MovedItems(__name__ + ".moves") -_importer._add_module(moves, "moves") - - -class Module_six_moves_urllib_parse(_LazyModule): - - """Lazy loading of moved objects in six.moves.urllib_parse""" - - -_urllib_parse_moved_attributes = [ - MovedAttribute("ParseResult", "urlparse", "urllib.parse"), - MovedAttribute("SplitResult", "urlparse", "urllib.parse"), - MovedAttribute("parse_qs", "urlparse", "urllib.parse"), - MovedAttribute("parse_qsl", "urlparse", "urllib.parse"), - MovedAttribute("urldefrag", "urlparse", "urllib.parse"), - MovedAttribute("urljoin", "urlparse", "urllib.parse"), - MovedAttribute("urlparse", "urlparse", "urllib.parse"), - MovedAttribute("urlsplit", "urlparse", "urllib.parse"), - MovedAttribute("urlunparse", "urlparse", "urllib.parse"), - MovedAttribute("urlunsplit", "urlparse", "urllib.parse"), - MovedAttribute("quote", "urllib", "urllib.parse"), - MovedAttribute("quote_plus", "urllib", "urllib.parse"), - MovedAttribute("unquote", "urllib", "urllib.parse"), - MovedAttribute("unquote_plus", "urllib", "urllib.parse"), - MovedAttribute("unquote_to_bytes", "urllib", "urllib.parse", "unquote", "unquote_to_bytes"), - MovedAttribute("urlencode", "urllib", "urllib.parse"), - MovedAttribute("splitquery", "urllib", "urllib.parse"), - MovedAttribute("splittag", "urllib", "urllib.parse"), - MovedAttribute("splituser", "urllib", "urllib.parse"), - MovedAttribute("splitvalue", "urllib", "urllib.parse"), - MovedAttribute("uses_fragment", "urlparse", "urllib.parse"), - MovedAttribute("uses_netloc", "urlparse", "urllib.parse"), - MovedAttribute("uses_params", "urlparse", "urllib.parse"), - MovedAttribute("uses_query", "urlparse", "urllib.parse"), - 
MovedAttribute("uses_relative", "urlparse", "urllib.parse"), -] -for attr in _urllib_parse_moved_attributes: - setattr(Module_six_moves_urllib_parse, attr.name, attr) -del attr - -Module_six_moves_urllib_parse._moved_attributes = _urllib_parse_moved_attributes - -_importer._add_module(Module_six_moves_urllib_parse(__name__ + ".moves.urllib_parse"), - "moves.urllib_parse", "moves.urllib.parse") - - -class Module_six_moves_urllib_error(_LazyModule): - - """Lazy loading of moved objects in six.moves.urllib_error""" - - -_urllib_error_moved_attributes = [ - MovedAttribute("URLError", "urllib2", "urllib.error"), - MovedAttribute("HTTPError", "urllib2", "urllib.error"), - MovedAttribute("ContentTooShortError", "urllib", "urllib.error"), -] -for attr in _urllib_error_moved_attributes: - setattr(Module_six_moves_urllib_error, attr.name, attr) -del attr - -Module_six_moves_urllib_error._moved_attributes = _urllib_error_moved_attributes - -_importer._add_module(Module_six_moves_urllib_error(__name__ + ".moves.urllib.error"), - "moves.urllib_error", "moves.urllib.error") - - -class Module_six_moves_urllib_request(_LazyModule): - - """Lazy loading of moved objects in six.moves.urllib_request""" - - -_urllib_request_moved_attributes = [ - MovedAttribute("urlopen", "urllib2", "urllib.request"), - MovedAttribute("install_opener", "urllib2", "urllib.request"), - MovedAttribute("build_opener", "urllib2", "urllib.request"), - MovedAttribute("pathname2url", "urllib", "urllib.request"), - MovedAttribute("url2pathname", "urllib", "urllib.request"), - MovedAttribute("getproxies", "urllib", "urllib.request"), - MovedAttribute("Request", "urllib2", "urllib.request"), - MovedAttribute("OpenerDirector", "urllib2", "urllib.request"), - MovedAttribute("HTTPDefaultErrorHandler", "urllib2", "urllib.request"), - MovedAttribute("HTTPRedirectHandler", "urllib2", "urllib.request"), - MovedAttribute("HTTPCookieProcessor", "urllib2", "urllib.request"), - MovedAttribute("ProxyHandler", "urllib2", "urllib.request"), - MovedAttribute("BaseHandler", "urllib2", "urllib.request"), - MovedAttribute("HTTPPasswordMgr", "urllib2", "urllib.request"), - MovedAttribute("HTTPPasswordMgrWithDefaultRealm", "urllib2", "urllib.request"), - MovedAttribute("AbstractBasicAuthHandler", "urllib2", "urllib.request"), - MovedAttribute("HTTPBasicAuthHandler", "urllib2", "urllib.request"), - MovedAttribute("ProxyBasicAuthHandler", "urllib2", "urllib.request"), - MovedAttribute("AbstractDigestAuthHandler", "urllib2", "urllib.request"), - MovedAttribute("HTTPDigestAuthHandler", "urllib2", "urllib.request"), - MovedAttribute("ProxyDigestAuthHandler", "urllib2", "urllib.request"), - MovedAttribute("HTTPHandler", "urllib2", "urllib.request"), - MovedAttribute("HTTPSHandler", "urllib2", "urllib.request"), - MovedAttribute("FileHandler", "urllib2", "urllib.request"), - MovedAttribute("FTPHandler", "urllib2", "urllib.request"), - MovedAttribute("CacheFTPHandler", "urllib2", "urllib.request"), - MovedAttribute("UnknownHandler", "urllib2", "urllib.request"), - MovedAttribute("HTTPErrorProcessor", "urllib2", "urllib.request"), - MovedAttribute("urlretrieve", "urllib", "urllib.request"), - MovedAttribute("urlcleanup", "urllib", "urllib.request"), - MovedAttribute("URLopener", "urllib", "urllib.request"), - MovedAttribute("FancyURLopener", "urllib", "urllib.request"), - MovedAttribute("proxy_bypass", "urllib", "urllib.request"), - MovedAttribute("parse_http_list", "urllib2", "urllib.request"), - MovedAttribute("parse_keqv_list", "urllib2", "urllib.request"), -] 
-for attr in _urllib_request_moved_attributes: - setattr(Module_six_moves_urllib_request, attr.name, attr) -del attr - -Module_six_moves_urllib_request._moved_attributes = _urllib_request_moved_attributes - -_importer._add_module(Module_six_moves_urllib_request(__name__ + ".moves.urllib.request"), - "moves.urllib_request", "moves.urllib.request") - - -class Module_six_moves_urllib_response(_LazyModule): - - """Lazy loading of moved objects in six.moves.urllib_response""" - - -_urllib_response_moved_attributes = [ - MovedAttribute("addbase", "urllib", "urllib.response"), - MovedAttribute("addclosehook", "urllib", "urllib.response"), - MovedAttribute("addinfo", "urllib", "urllib.response"), - MovedAttribute("addinfourl", "urllib", "urllib.response"), -] -for attr in _urllib_response_moved_attributes: - setattr(Module_six_moves_urllib_response, attr.name, attr) -del attr - -Module_six_moves_urllib_response._moved_attributes = _urllib_response_moved_attributes - -_importer._add_module(Module_six_moves_urllib_response(__name__ + ".moves.urllib.response"), - "moves.urllib_response", "moves.urllib.response") - - -class Module_six_moves_urllib_robotparser(_LazyModule): - - """Lazy loading of moved objects in six.moves.urllib_robotparser""" - - -_urllib_robotparser_moved_attributes = [ - MovedAttribute("RobotFileParser", "robotparser", "urllib.robotparser"), -] -for attr in _urllib_robotparser_moved_attributes: - setattr(Module_six_moves_urllib_robotparser, attr.name, attr) -del attr - -Module_six_moves_urllib_robotparser._moved_attributes = _urllib_robotparser_moved_attributes - -_importer._add_module(Module_six_moves_urllib_robotparser(__name__ + ".moves.urllib.robotparser"), - "moves.urllib_robotparser", "moves.urllib.robotparser") - - -class Module_six_moves_urllib(types.ModuleType): - - """Create a six.moves.urllib namespace that resembles the Python 3 namespace""" - __path__ = [] # mark as package - parse = _importer._get_module("moves.urllib_parse") - error = _importer._get_module("moves.urllib_error") - request = _importer._get_module("moves.urllib_request") - response = _importer._get_module("moves.urllib_response") - robotparser = _importer._get_module("moves.urllib_robotparser") - - def __dir__(self): - return ['parse', 'error', 'request', 'response', 'robotparser'] - -_importer._add_module(Module_six_moves_urllib(__name__ + ".moves.urllib"), - "moves.urllib") - - -def add_move(move): - """Add an item to six.moves.""" - setattr(_MovedItems, move.name, move) - - -def remove_move(name): - """Remove item from six.moves.""" - try: - delattr(_MovedItems, name) - except AttributeError: - try: - del moves.__dict__[name] - except KeyError: - raise AttributeError("no such move, %r" % (name,)) - - -if PY3: - _meth_func = "__func__" - _meth_self = "__self__" - - _func_closure = "__closure__" - _func_code = "__code__" - _func_defaults = "__defaults__" - _func_globals = "__globals__" -else: - _meth_func = "im_func" - _meth_self = "im_self" - - _func_closure = "func_closure" - _func_code = "func_code" - _func_defaults = "func_defaults" - _func_globals = "func_globals" - - -try: - advance_iterator = next -except NameError: - def advance_iterator(it): - return it.next() -next = advance_iterator - - -try: - callable = callable -except NameError: - def callable(obj): - return any("__call__" in klass.__dict__ for klass in type(obj).__mro__) - - -if PY3: - def get_unbound_function(unbound): - return unbound - - create_bound_method = types.MethodType - - def create_unbound_method(func, cls): - return func - 
- Iterator = object -else: - def get_unbound_function(unbound): - return unbound.im_func - - def create_bound_method(func, obj): - return types.MethodType(func, obj, obj.__class__) - - def create_unbound_method(func, cls): - return types.MethodType(func, None, cls) - - class Iterator(object): - - def next(self): - return type(self).__next__(self) - - callable = callable -_add_doc(get_unbound_function, - """Get the function out of a possibly unbound function""") - - -get_method_function = operator.attrgetter(_meth_func) -get_method_self = operator.attrgetter(_meth_self) -get_function_closure = operator.attrgetter(_func_closure) -get_function_code = operator.attrgetter(_func_code) -get_function_defaults = operator.attrgetter(_func_defaults) -get_function_globals = operator.attrgetter(_func_globals) - - -if PY3: - def iterkeys(d, **kw): - return iter(d.keys(**kw)) - - def itervalues(d, **kw): - return iter(d.values(**kw)) - - def iteritems(d, **kw): - return iter(d.items(**kw)) - - def iterlists(d, **kw): - return iter(d.lists(**kw)) - - viewkeys = operator.methodcaller("keys") - - viewvalues = operator.methodcaller("values") - - viewitems = operator.methodcaller("items") -else: - def iterkeys(d, **kw): - return d.iterkeys(**kw) - - def itervalues(d, **kw): - return d.itervalues(**kw) - - def iteritems(d, **kw): - return d.iteritems(**kw) - - def iterlists(d, **kw): - return d.iterlists(**kw) - - viewkeys = operator.methodcaller("viewkeys") - - viewvalues = operator.methodcaller("viewvalues") - - viewitems = operator.methodcaller("viewitems") - -_add_doc(iterkeys, "Return an iterator over the keys of a dictionary.") -_add_doc(itervalues, "Return an iterator over the values of a dictionary.") -_add_doc(iteritems, - "Return an iterator over the (key, value) pairs of a dictionary.") -_add_doc(iterlists, - "Return an iterator over the (key, [values]) pairs of a dictionary.") - - -if PY3: - def b(s): - return s.encode("latin-1") - - def u(s): - return s - unichr = chr - import struct - int2byte = struct.Struct(">B").pack - del struct - byte2int = operator.itemgetter(0) - indexbytes = operator.getitem - iterbytes = iter - import io - StringIO = io.StringIO - BytesIO = io.BytesIO - del io - _assertCountEqual = "assertCountEqual" - if sys.version_info[1] <= 1: - _assertRaisesRegex = "assertRaisesRegexp" - _assertRegex = "assertRegexpMatches" - _assertNotRegex = "assertNotRegexpMatches" - else: - _assertRaisesRegex = "assertRaisesRegex" - _assertRegex = "assertRegex" - _assertNotRegex = "assertNotRegex" -else: - def b(s): - return s - # Workaround for standalone backslash - - def u(s): - return unicode(s.replace(r'\\', r'\\\\'), "unicode_escape") - unichr = unichr - int2byte = chr - - def byte2int(bs): - return ord(bs[0]) - - def indexbytes(buf, i): - return ord(buf[i]) - iterbytes = functools.partial(itertools.imap, ord) - import StringIO - StringIO = BytesIO = StringIO.StringIO - _assertCountEqual = "assertItemsEqual" - _assertRaisesRegex = "assertRaisesRegexp" - _assertRegex = "assertRegexpMatches" - _assertNotRegex = "assertNotRegexpMatches" -_add_doc(b, """Byte literal""") -_add_doc(u, """Text literal""") - - -def assertCountEqual(self, *args, **kwargs): - return getattr(self, _assertCountEqual)(*args, **kwargs) - - -def assertRaisesRegex(self, *args, **kwargs): - return getattr(self, _assertRaisesRegex)(*args, **kwargs) - - -def assertRegex(self, *args, **kwargs): - return getattr(self, _assertRegex)(*args, **kwargs) - - -def assertNotRegex(self, *args, **kwargs): - return getattr(self, 
_assertNotRegex)(*args, **kwargs) - - -if PY3: - exec_ = getattr(moves.builtins, "exec") - - def reraise(tp, value, tb=None): - try: - if value is None: - value = tp() - if value.__traceback__ is not tb: - raise value.with_traceback(tb) - raise value - finally: - value = None - tb = None - -else: - def exec_(_code_, _globs_=None, _locs_=None): - """Execute code in a namespace.""" - if _globs_ is None: - frame = sys._getframe(1) - _globs_ = frame.f_globals - if _locs_ is None: - _locs_ = frame.f_locals - del frame - elif _locs_ is None: - _locs_ = _globs_ - exec("""exec _code_ in _globs_, _locs_""") - - exec_("""def reraise(tp, value, tb=None): - try: - raise tp, value, tb - finally: - tb = None -""") - - -if sys.version_info[:2] > (3,): - exec_("""def raise_from(value, from_value): - try: - raise value from from_value - finally: - value = None -""") -else: - def raise_from(value, from_value): - raise value - - -print_ = getattr(moves.builtins, "print", None) -if print_ is None: - def print_(*args, **kwargs): - """The new-style print function for Python 2.4 and 2.5.""" - fp = kwargs.pop("file", sys.stdout) - if fp is None: - return - - def write(data): - if not isinstance(data, basestring): - data = str(data) - # If the file has an encoding, encode unicode with it. - if (isinstance(fp, file) and - isinstance(data, unicode) and - fp.encoding is not None): - errors = getattr(fp, "errors", None) - if errors is None: - errors = "strict" - data = data.encode(fp.encoding, errors) - fp.write(data) - want_unicode = False - sep = kwargs.pop("sep", None) - if sep is not None: - if isinstance(sep, unicode): - want_unicode = True - elif not isinstance(sep, str): - raise TypeError("sep must be None or a string") - end = kwargs.pop("end", None) - if end is not None: - if isinstance(end, unicode): - want_unicode = True - elif not isinstance(end, str): - raise TypeError("end must be None or a string") - if kwargs: - raise TypeError("invalid keyword arguments to print()") - if not want_unicode: - for arg in args: - if isinstance(arg, unicode): - want_unicode = True - break - if want_unicode: - newline = unicode("\n") - space = unicode(" ") - else: - newline = "\n" - space = " " - if sep is None: - sep = space - if end is None: - end = newline - for i, arg in enumerate(args): - if i: - write(sep) - write(arg) - write(end) -if sys.version_info[:2] < (3, 3): - _print = print_ - - def print_(*args, **kwargs): - fp = kwargs.get("file", sys.stdout) - flush = kwargs.pop("flush", False) - _print(*args, **kwargs) - if flush and fp is not None: - fp.flush() - -_add_doc(reraise, """Reraise an exception.""") - -if sys.version_info[0:2] < (3, 4): - # This does exactly the same what the :func:`py3:functools.update_wrapper` - # function does on Python versions after 3.2. It sets the ``__wrapped__`` - # attribute on ``wrapper`` object and it doesn't raise an error if any of - # the attributes mentioned in ``assigned`` and ``updated`` are missing on - # ``wrapped`` object. 
- def _update_wrapper(wrapper, wrapped, - assigned=functools.WRAPPER_ASSIGNMENTS, - updated=functools.WRAPPER_UPDATES): - for attr in assigned: - try: - value = getattr(wrapped, attr) - except AttributeError: - continue - else: - setattr(wrapper, attr, value) - for attr in updated: - getattr(wrapper, attr).update(getattr(wrapped, attr, {})) - wrapper.__wrapped__ = wrapped - return wrapper - _update_wrapper.__doc__ = functools.update_wrapper.__doc__ - - def wraps(wrapped, assigned=functools.WRAPPER_ASSIGNMENTS, - updated=functools.WRAPPER_UPDATES): - return functools.partial(_update_wrapper, wrapped=wrapped, - assigned=assigned, updated=updated) - wraps.__doc__ = functools.wraps.__doc__ - -else: - wraps = functools.wraps - - -def with_metaclass(meta, *bases): - """Create a base class with a metaclass.""" - # This requires a bit of explanation: the basic idea is to make a dummy - # metaclass for one level of class instantiation that replaces itself with - # the actual metaclass. - class metaclass(type): - - def __new__(cls, name, this_bases, d): - if sys.version_info[:2] >= (3, 7): - # This version introduced PEP 560 that requires a bit - # of extra care (we mimic what is done by __build_class__). - resolved_bases = types.resolve_bases(bases) - if resolved_bases is not bases: - d['__orig_bases__'] = bases - else: - resolved_bases = bases - return meta(name, resolved_bases, d) - - @classmethod - def __prepare__(cls, name, this_bases): - return meta.__prepare__(name, bases) - return type.__new__(metaclass, 'temporary_class', (), {}) - - -def add_metaclass(metaclass): - """Class decorator for creating a class with a metaclass.""" - def wrapper(cls): - orig_vars = cls.__dict__.copy() - slots = orig_vars.get('__slots__') - if slots is not None: - if isinstance(slots, str): - slots = [slots] - for slots_var in slots: - orig_vars.pop(slots_var) - orig_vars.pop('__dict__', None) - orig_vars.pop('__weakref__', None) - if hasattr(cls, '__qualname__'): - orig_vars['__qualname__'] = cls.__qualname__ - return metaclass(cls.__name__, cls.__bases__, orig_vars) - return wrapper - - -def ensure_binary(s, encoding='utf-8', errors='strict'): - """Coerce **s** to six.binary_type. - - For Python 2: - - `unicode` -> encoded to `str` - - `str` -> `str` - - For Python 3: - - `str` -> encoded to `bytes` - - `bytes` -> `bytes` - """ - if isinstance(s, binary_type): - return s - if isinstance(s, text_type): - return s.encode(encoding, errors) - raise TypeError("not expecting type '%s'" % type(s)) - - -def ensure_str(s, encoding='utf-8', errors='strict'): - """Coerce *s* to `str`. - - For Python 2: - - `unicode` -> encoded to `str` - - `str` -> `str` - - For Python 3: - - `str` -> `str` - - `bytes` -> decoded to `str` - """ - # Optimization: Fast return for the common case. - if type(s) is str: - return s - if PY2 and isinstance(s, text_type): - return s.encode(encoding, errors) - elif PY3 and isinstance(s, binary_type): - return s.decode(encoding, errors) - elif not isinstance(s, (text_type, binary_type)): - raise TypeError("not expecting type '%s'" % type(s)) - return s - - -def ensure_text(s, encoding='utf-8', errors='strict'): - """Coerce *s* to six.text_type. 
- - For Python 2: - - `unicode` -> `unicode` - - `str` -> `unicode` - - For Python 3: - - `str` -> `str` - - `bytes` -> decoded to `str` - """ - if isinstance(s, binary_type): - return s.decode(encoding, errors) - elif isinstance(s, text_type): - return s - else: - raise TypeError("not expecting type '%s'" % type(s)) - - -def python_2_unicode_compatible(klass): - """ - A class decorator that defines __unicode__ and __str__ methods under Python 2. - Under Python 3 it does nothing. - - To support Python 2 and 3 with a single code base, define a __str__ method - returning text and apply this decorator to the class. - """ - if PY2: - if '__str__' not in klass.__dict__: - raise ValueError("@python_2_unicode_compatible cannot be applied " - "to %s because it doesn't define __str__()." % - klass.__name__) - klass.__unicode__ = klass.__str__ - klass.__str__ = lambda self: self.__unicode__().encode('utf-8') - return klass - - -# Complete the moves implementation. -# This code is at the end of this module to speed up module loading. -# Turn this module into a package. -__path__ = [] # required for PEP 302 and PEP 451 -__package__ = __name__ # see PEP 366 @ReservedAssignment -if globals().get("__spec__") is not None: - __spec__.submodule_search_locations = [] # PEP 451 @UndefinedVariable -# Remove other six meta path importers, since they cause problems. This can -# happen if six is removed from sys.modules and then reloaded. (Setuptools does -# this for some reason.) -if sys.meta_path: - for i, importer in enumerate(sys.meta_path): - # Here's some real nastiness: Another "instance" of the six module might - # be floating around. Therefore, we can't use isinstance() to check for - # the six meta path importer, since the other six instance will have - # inserted an importer with different class. - if (type(importer).__name__ == "_SixMetaPathImporter" and - importer.name == __name__): - del sys.meta_path[i] - break - del i, importer -# Finally, add the importer to the meta path import hook. -sys.meta_path.append(_importer) diff --git a/kafka/vendor/socketpair.py b/kafka/vendor/socketpair.py deleted file mode 100644 index b55e629ee..000000000 --- a/kafka/vendor/socketpair.py +++ /dev/null @@ -1,58 +0,0 @@ -# pylint: skip-file -# vendored from https://github.com/mhils/backports.socketpair -from __future__ import absolute_import - -import sys -import socket -import errno - -_LOCALHOST = '127.0.0.1' -_LOCALHOST_V6 = '::1' - -if not hasattr(socket, "socketpair"): - # Origin: https://gist.github.com/4325783, by Geert Jansen. Public domain. - def socketpair(family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0): - if family == socket.AF_INET: - host = _LOCALHOST - elif family == socket.AF_INET6: - host = _LOCALHOST_V6 - else: - raise ValueError("Only AF_INET and AF_INET6 socket address families " - "are supported") - if type != socket.SOCK_STREAM: - raise ValueError("Only SOCK_STREAM socket type is supported") - if proto != 0: - raise ValueError("Only protocol zero is supported") - - # We create a connected TCP socket. Note the trick with - # setblocking(False) that prevents us from having to create a thread. 
- lsock = socket.socket(family, type, proto) - try: - lsock.bind((host, 0)) - lsock.listen(min(socket.SOMAXCONN, 128)) - # On IPv6, ignore flow_info and scope_id - addr, port = lsock.getsockname()[:2] - csock = socket.socket(family, type, proto) - try: - csock.setblocking(False) - if sys.version_info >= (3, 0): - try: - csock.connect((addr, port)) - except (BlockingIOError, InterruptedError): - pass - else: - try: - csock.connect((addr, port)) - except socket.error as e: - if e.errno != errno.WSAEWOULDBLOCK: - raise - csock.setblocking(True) - ssock, _ = lsock.accept() - except Exception: - csock.close() - raise - finally: - lsock.close() - return (ssock, csock) - - socket.socketpair = socketpair diff --git a/kafka/version.py b/kafka/version.py index 8a26a1868..ee5ea98ce 100644 --- a/kafka/version.py +++ b/kafka/version.py @@ -1,9 +1,6 @@ import sys -if sys.version_info < (3, 8): - from importlib_metadata import version -else: - from importlib.metadata import version +from importlib.metadata import version __version__ = version("kafka-python-ng") diff --git a/pylint.rc b/pylint.rc index 851275bcc..d22e523ec 100644 --- a/pylint.rc +++ b/pylint.rc @@ -1,6 +1,5 @@ [TYPECHECK] ignored-classes=SyncManager,_socketobject -ignored-modules=kafka.vendor.six.moves generated-members=py.* [MESSAGES CONTROL] diff --git a/requirements-dev.txt b/requirements-dev.txt index 1fa933da2..de29cad63 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -3,7 +3,6 @@ crc32c docker-py flake8 lz4 -mock py pylint pytest @@ -15,3 +14,4 @@ Sphinx sphinx-rtd-theme tox xxhash +botocore diff --git a/test/conftest.py b/test/conftest.py index 3fa0262fd..824c0fa76 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -119,7 +119,7 @@ def factory(**kafka_admin_client_params): @pytest.fixture def topic(kafka_broker, request): """Return a topic fixture""" - topic_name = '%s_%s' % (request.node.name, random_string(10)) + topic_name = f'{request.node.originalname}_{random_string(10)}' kafka_broker.create_topics([topic_name]) return topic_name diff --git a/test/fixtures.py b/test/fixtures.py index d9c072b86..4ed515da3 100644 --- a/test/fixtures.py +++ b/test/fixtures.py @@ -7,11 +7,11 @@ import socket import subprocess import time +import urllib import uuid +from urllib.parse import urlparse import py -from kafka.vendor.six.moves import urllib, range -from kafka.vendor.six.moves.urllib.parse import urlparse # pylint: disable=E0611,F0401 from kafka import errors, KafkaAdminClient, KafkaClient, KafkaConsumer, KafkaProducer from kafka.errors import InvalidReplicationFactorError diff --git a/test/record/test_default_records.py b/test/record/test_default_records.py index c3a7b02c8..a3f69da6d 100644 --- a/test/record/test_default_records.py +++ b/test/record/test_default_records.py @@ -1,7 +1,5 @@ -# -*- coding: utf-8 -*- -from __future__ import unicode_literals import pytest -from mock import patch +from unittest.mock import patch import kafka.codec from kafka.record.default_records import ( DefaultRecordBatch, DefaultRecordBatchBuilder @@ -197,7 +195,7 @@ def test_unavailable_codec(magic, compression_type, name, checker_name): magic=2, compression_type=compression_type, is_transactional=0, producer_id=-1, producer_epoch=-1, base_sequence=-1, batch_size=1024) - error_msg = "Libraries for {} compression codec not found".format(name) + error_msg = f"Libraries for {name} compression codec not found" with pytest.raises(UnsupportedCodecError, match=error_msg): builder.append(0, timestamp=None, key=None, value=b"M", headers=[]) 
builder.build() diff --git a/test/record/test_legacy_records.py b/test/record/test_legacy_records.py index 43970f7c9..f16c29809 100644 --- a/test/record/test_legacy_records.py +++ b/test/record/test_legacy_records.py @@ -1,6 +1,5 @@ -from __future__ import unicode_literals import pytest -from mock import patch +from unittest.mock import patch from kafka.record.legacy_records import ( LegacyRecordBatch, LegacyRecordBatchBuilder ) @@ -186,7 +185,7 @@ def test_unavailable_codec(magic, compression_type, name, checker_name): # Check that builder raises error builder = LegacyRecordBatchBuilder( magic=magic, compression_type=compression_type, batch_size=1024) - error_msg = "Libraries for {} compression codec not found".format(name) + error_msg = f"Libraries for {name} compression codec not found" with pytest.raises(UnsupportedCodecError, match=error_msg): builder.append(0, timestamp=None, key=None, value=b"M") builder.build() diff --git a/test/record/test_records.py b/test/record/test_records.py index 9f72234ae..adde3ba06 100644 --- a/test/record/test_records.py +++ b/test/record/test_records.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- -from __future__ import unicode_literals import pytest from kafka.record import MemoryRecords, MemoryRecordsBuilder from kafka.errors import CorruptRecordException diff --git a/test/test_assignors.py b/test/test_assignors.py index 858ef426d..937afa86b 100644 --- a/test/test_assignors.py +++ b/test/test_assignors.py @@ -1,5 +1,4 @@ # pylint: skip-file -from __future__ import absolute_import from collections import defaultdict from random import randint, sample @@ -11,7 +10,6 @@ from kafka.coordinator.assignors.roundrobin import RoundRobinPartitionAssignor from kafka.coordinator.assignors.sticky.sticky_assignor import StickyPartitionAssignor, StickyAssignorUserDataV1 from kafka.coordinator.protocol import ConsumerProtocolMemberAssignment, ConsumerProtocolMemberMetadata -from kafka.vendor import six @pytest.fixture(autouse=True) @@ -110,7 +108,7 @@ def test_sticky_assignor1(mocker): del subscriptions['C1'] member_metadata = {} - for member, topics in six.iteritems(subscriptions): + for member, topics in subscriptions.items(): member_metadata[member] = StickyPartitionAssignor._metadata(topics, sticky_assignment[member].partitions()) sticky_assignment = StickyPartitionAssignor.assign(cluster, member_metadata) @@ -153,7 +151,7 @@ def test_sticky_assignor2(mocker): 'C2': {'t0', 't1', 't2'}, } member_metadata = {} - for member, topics in six.iteritems(subscriptions): + for member, topics in subscriptions.items(): member_metadata[member] = StickyPartitionAssignor._metadata(topics, []) sticky_assignment = StickyPartitionAssignor.assign(cluster, member_metadata) @@ -166,7 +164,7 @@ def test_sticky_assignor2(mocker): del subscriptions['C0'] member_metadata = {} - for member, topics in six.iteritems(subscriptions): + for member, topics in subscriptions.items(): member_metadata[member] = StickyPartitionAssignor._metadata(topics, sticky_assignment[member].partitions()) sticky_assignment = StickyPartitionAssignor.assign(cluster, member_metadata) @@ -325,7 +323,7 @@ def test_sticky_add_remove_consumer_one_topic(mocker): 'C2': {'t'}, } member_metadata = {} - for member, topics in six.iteritems(subscriptions): + for member, topics in subscriptions.items(): member_metadata[member] = StickyPartitionAssignor._metadata( topics, assignment[member].partitions() if member in assignment else [] ) @@ -337,7 +335,7 @@ def test_sticky_add_remove_consumer_one_topic(mocker): 'C2': {'t'}, } 
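# Illustrative sketch (the subscriptions dict is hypothetical): the six dict
# helpers dropped throughout these assignor tests map directly onto the
# built-in dict view methods on Python 3:
#   six.iteritems(d)  -> d.items()
#   six.iterkeys(d)   -> d.keys()
#   six.itervalues(d) -> d.values()
#   six.viewkeys(d)   -> d.keys()
subscriptions = {'C1': {'t0', 't1'}, 'C2': {'t0'}}
for member, topics in subscriptions.items():
    print(member, sorted(topics))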
member_metadata = {} - for member, topics in six.iteritems(subscriptions): + for member, topics in subscriptions.items(): member_metadata[member] = StickyPartitionAssignor._metadata(topics, assignment[member].partitions()) assignment = StickyPartitionAssignor.assign(cluster, member_metadata) @@ -366,7 +364,7 @@ def test_sticky_add_remove_topic_two_consumers(mocker): 'C2': {'t1', 't2'}, } member_metadata = {} - for member, topics in six.iteritems(subscriptions): + for member, topics in subscriptions.items(): member_metadata[member] = StickyPartitionAssignor._metadata(topics, sticky_assignment[member].partitions()) sticky_assignment = StickyPartitionAssignor.assign(cluster, member_metadata) @@ -381,7 +379,7 @@ def test_sticky_add_remove_topic_two_consumers(mocker): 'C2': {'t2'}, } member_metadata = {} - for member, topics in six.iteritems(subscriptions): + for member, topics in subscriptions.items(): member_metadata[member] = StickyPartitionAssignor._metadata(topics, sticky_assignment[member].partitions()) sticky_assignment = StickyPartitionAssignor.assign(cluster, member_metadata) @@ -393,9 +391,9 @@ def test_sticky_add_remove_topic_two_consumers(mocker): def test_sticky_reassignment_after_one_consumer_leaves(mocker): - partitions = dict([('t{}'.format(i), set(range(i))) for i in range(1, 20)]) + partitions = {'t{}'.format(i): set(range(i)) for i in range(1, 20)} cluster = create_cluster( - mocker, topics=set(['t{}'.format(i) for i in range(1, 20)]), topic_partitions_lambda=lambda t: partitions[t] + mocker, topics={'t{}'.format(i) for i in range(1, 20)}, topic_partitions_lambda=lambda t: partitions[t] ) subscriptions = {} @@ -412,7 +410,7 @@ def test_sticky_reassignment_after_one_consumer_leaves(mocker): del subscriptions['C10'] member_metadata = {} - for member, topics in six.iteritems(subscriptions): + for member, topics in subscriptions.items(): member_metadata[member] = StickyPartitionAssignor._metadata(topics, assignment[member].partitions()) assignment = StickyPartitionAssignor.assign(cluster, member_metadata) @@ -434,7 +432,7 @@ def test_sticky_reassignment_after_one_consumer_added(mocker): subscriptions['C10'] = {'t'} member_metadata = {} - for member, topics in six.iteritems(subscriptions): + for member, topics in subscriptions.items(): member_metadata[member] = StickyPartitionAssignor._metadata( topics, assignment[member].partitions() if member in assignment else [] ) @@ -444,14 +442,14 @@ def test_sticky_reassignment_after_one_consumer_added(mocker): def test_sticky_same_subscriptions(mocker): - partitions = dict([('t{}'.format(i), set(range(i))) for i in range(1, 15)]) + partitions = {'t{}'.format(i): set(range(i)) for i in range(1, 15)} cluster = create_cluster( - mocker, topics=set(['t{}'.format(i) for i in range(1, 15)]), topic_partitions_lambda=lambda t: partitions[t] + mocker, topics={'t{}'.format(i) for i in range(1, 15)}, topic_partitions_lambda=lambda t: partitions[t] ) subscriptions = defaultdict(set) for i in range(1, 9): - for j in range(1, len(six.viewkeys(partitions)) + 1): + for j in range(1, len(partitions.keys()) + 1): subscriptions['C{}'.format(i)].add('t{}'.format(j)) member_metadata = make_member_metadata(subscriptions) @@ -461,7 +459,7 @@ def test_sticky_same_subscriptions(mocker): del subscriptions['C5'] member_metadata = {} - for member, topics in six.iteritems(subscriptions): + for member, topics in subscriptions.items(): member_metadata[member] = StickyPartitionAssignor._metadata(topics, assignment[member].partitions()) assignment = 
StickyPartitionAssignor.assign(cluster, member_metadata) verify_validity_and_balance(subscriptions, assignment) @@ -472,8 +470,8 @@ def test_sticky_large_assignment_with_multiple_consumers_leaving(mocker): n_topics = 40 n_consumers = 200 - all_topics = set(['t{}'.format(i) for i in range(1, n_topics + 1)]) - partitions = dict([(t, set(range(1, randint(0, 10) + 1))) for t in all_topics]) + all_topics = {'t{}'.format(i) for i in range(1, n_topics + 1)} + partitions = {t: set(range(1, randint(0, 10) + 1)) for t in all_topics} cluster = create_cluster(mocker, topics=all_topics, topic_partitions_lambda=lambda t: partitions[t]) subscriptions = defaultdict(set) @@ -487,7 +485,7 @@ def test_sticky_large_assignment_with_multiple_consumers_leaving(mocker): verify_validity_and_balance(subscriptions, assignment) member_metadata = {} - for member, topics in six.iteritems(subscriptions): + for member, topics in subscriptions.items(): member_metadata[member] = StickyPartitionAssignor._metadata(topics, assignment[member].partitions()) for i in range(50): @@ -516,7 +514,7 @@ def test_new_subscription(mocker): subscriptions['C0'].add('t1') member_metadata = {} - for member, topics in six.iteritems(subscriptions): + for member, topics in subscriptions.items(): member_metadata[member] = StickyPartitionAssignor._metadata(topics, []) assignment = StickyPartitionAssignor.assign(cluster, member_metadata) @@ -539,7 +537,7 @@ def test_move_existing_assignments(mocker): } member_metadata = {} - for member, topics in six.iteritems(subscriptions): + for member, topics in subscriptions.items(): member_metadata[member] = StickyPartitionAssignor._metadata(topics, member_assignments[member]) assignment = StickyPartitionAssignor.assign(cluster, member_metadata) @@ -559,7 +557,7 @@ def test_stickiness(mocker): assignment = StickyPartitionAssignor.assign(cluster, member_metadata) verify_validity_and_balance(subscriptions, assignment) partitions_assigned = {} - for consumer, consumer_assignment in six.iteritems(assignment): + for consumer, consumer_assignment in assignment.items(): assert ( len(consumer_assignment.partitions()) <= 1 ), 'Consumer {} is assigned more topic partitions than expected.'.format(consumer) @@ -569,14 +567,14 @@ def test_stickiness(mocker): # removing the potential group leader del subscriptions['C1'] member_metadata = {} - for member, topics in six.iteritems(subscriptions): + for member, topics in subscriptions.items(): member_metadata[member] = StickyPartitionAssignor._metadata(topics, assignment[member].partitions()) assignment = StickyPartitionAssignor.assign(cluster, member_metadata) verify_validity_and_balance(subscriptions, assignment) assert StickyPartitionAssignor._latest_partition_movements.are_sticky() - for consumer, consumer_assignment in six.iteritems(assignment): + for consumer, consumer_assignment in assignment.items(): assert ( len(consumer_assignment.partitions()) <= 1 ), 'Consumer {} is assigned more topic partitions than expected.'.format(consumer) @@ -624,7 +622,7 @@ def test_no_exceptions_when_only_subscribed_topic_is_deleted(mocker): 'C': {}, } member_metadata = {} - for member, topics in six.iteritems(subscriptions): + for member, topics in subscriptions.items(): member_metadata[member] = StickyPartitionAssignor._metadata(topics, sticky_assignment[member].partitions()) cluster = create_cluster(mocker, topics={}, topics_partitions={}) @@ -643,7 +641,7 @@ def test_conflicting_previous_assignments(mocker): 'C2': {'t'}, } member_metadata = {} - for member, topics in 
six.iteritems(subscriptions): + for member, topics in subscriptions.items(): # assume both C1 and C2 have partition 1 assigned to them in generation 1 member_metadata[member] = StickyPartitionAssignor._metadata(topics, [TopicPartition('t', 0), TopicPartition('t', 0)], 1) @@ -656,7 +654,7 @@ def test_conflicting_previous_assignments(mocker): ) def test_reassignment_with_random_subscriptions_and_changes(mocker, execution_number, n_topics, n_consumers): all_topics = sorted(['t{}'.format(i) for i in range(1, n_topics + 1)]) - partitions = dict([(t, set(range(1, i + 1))) for i, t in enumerate(all_topics)]) + partitions = {t: set(range(1, i + 1)) for i, t in enumerate(all_topics)} cluster = create_cluster(mocker, topics=all_topics, topic_partitions_lambda=lambda t: partitions[t]) subscriptions = defaultdict(set) @@ -675,7 +673,7 @@ def test_reassignment_with_random_subscriptions_and_changes(mocker, execution_nu subscriptions['C{}'.format(i)].update(topics_sample) member_metadata = {} - for member, topics in six.iteritems(subscriptions): + for member, topics in subscriptions.items(): member_metadata[member] = StickyPartitionAssignor._metadata(topics, assignment[member].partitions()) assignment = StickyPartitionAssignor.assign(cluster, member_metadata) @@ -777,7 +775,7 @@ def test_assignment_with_conflicting_previous_generations(mocker, execution_numb 'C3': 2, } member_metadata = {} - for member in six.iterkeys(member_assignments): + for member in member_assignments.keys(): member_metadata[member] = StickyPartitionAssignor._metadata({'t'}, member_assignments[member], member_generations[member]) assignment = StickyPartitionAssignor.assign(cluster, member_metadata) @@ -787,7 +785,7 @@ def test_assignment_with_conflicting_previous_generations(mocker, execution_numb def make_member_metadata(subscriptions): member_metadata = {} - for member, topics in six.iteritems(subscriptions): + for member, topics in subscriptions.items(): member_metadata[member] = StickyPartitionAssignor._metadata(topics, []) return member_metadata @@ -812,9 +810,9 @@ def verify_validity_and_balance(subscriptions, assignment): :param subscriptions topic subscriptions of each consumer :param assignment: given assignment for balance check """ - assert six.viewkeys(subscriptions) == six.viewkeys(assignment) + assert subscriptions.keys() == assignment.keys() - consumers = sorted(six.viewkeys(assignment)) + consumers = sorted(assignment.keys()) for i in range(len(consumers)): consumer = consumers[i] partitions = assignment[consumer].partitions() @@ -845,7 +843,7 @@ def verify_validity_and_balance(subscriptions, assignment): assignments_by_topic = group_partitions_by_topic(partitions) other_assignments_by_topic = group_partitions_by_topic(other_partitions) if len(partitions) > len(other_partitions): - for topic in six.iterkeys(assignments_by_topic): + for topic in assignments_by_topic.keys(): assert topic not in other_assignments_by_topic, ( 'Error: Some partitions can be moved from {} ({} partitions) ' 'to {} ({} partitions) ' @@ -854,7 +852,7 @@ def verify_validity_and_balance(subscriptions, assignment): 'Assignments: {}'.format(consumer, len(partitions), other_consumer, len(other_partitions), subscriptions, assignment) ) if len(other_partitions) > len(partitions): - for topic in six.iterkeys(other_assignments_by_topic): + for topic in other_assignments_by_topic.keys(): assert topic not in assignments_by_topic, ( 'Error: Some partitions can be moved from {} ({} partitions) ' 'to {} ({} partitions) ' diff --git 
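The assignor test changes above drop the vendored six helpers in favour of the Python 3 built-in dict views. A minimal sketch of the equivalences these tests now rely on, using an illustrative dict d:

    d = {'C1': {'t1'}, 'C2': {'t1', 't2'}}

    for member, topics in d.items():      # replaces six.iteritems(d)
        assert isinstance(topics, set)

    assert d.keys() == {'C1', 'C2'}       # replaces six.viewkeys(d); views compare set-like
    consumers = sorted(d.keys())          # replaces sorted(six.iterkeys(d))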
a/test/test_client_async.py b/test/test_client_async.py index 66b227aa9..84d52807e 100644 --- a/test/test_client_async.py +++ b/test/test_client_async.py @@ -1,13 +1,5 @@ -from __future__ import absolute_import, division - -# selectors in stdlib as of py3.4 -try: - import selectors # pylint: disable=import-error -except ImportError: - # vendored backport module - import kafka.vendor.selectors34 as selectors - import socket +import selectors import time import pytest diff --git a/test/test_codec.py b/test/test_codec.py index e05707451..91bfd01ab 100644 --- a/test/test_codec.py +++ b/test/test_codec.py @@ -4,7 +4,6 @@ import struct import pytest -from kafka.vendor.six.moves import range from kafka.codec import ( has_snappy, has_lz4, has_zstd, diff --git a/test/test_conn.py b/test/test_conn.py index 966f7b34d..d595fac3a 100644 --- a/test/test_conn.py +++ b/test/test_conn.py @@ -2,9 +2,9 @@ from __future__ import absolute_import from errno import EALREADY, EINPROGRESS, EISCONN, ECONNRESET +from unittest import mock import socket -import mock import pytest from kafka.conn import BrokerConnection, ConnectionStates, collect_hosts diff --git a/test/test_consumer.py b/test/test_consumer.py index 436fe55c0..0c6110517 100644 --- a/test/test_consumer.py +++ b/test/test_consumer.py @@ -24,3 +24,8 @@ def test_subscription_copy(self): assert sub == set(['foo']) sub.add('fizz') assert consumer.subscription() == set(['foo']) + + def test_version_for_static_membership(self): + KafkaConsumer(bootstrap_servers='localhost:9092', api_version=(2, 3, 0), group_instance_id='test') + with pytest.raises(KafkaConfigurationError): + KafkaConsumer(bootstrap_servers='localhost:9092', api_version=(2, 2, 0), group_instance_id='test') diff --git a/test/test_consumer_group.py b/test/test_consumer_group.py index 4904ffeea..ed6863fa2 100644 --- a/test/test_consumer_group.py +++ b/test/test_consumer_group.py @@ -5,7 +5,6 @@ import time import pytest -from kafka.vendor import six from kafka.conn import ConnectionStates from kafka.consumer.group import KafkaConsumer @@ -62,7 +61,7 @@ def consumer_thread(i): group_id=group_id, heartbeat_interval_ms=500) while not stop[i].is_set(): - for tp, records in six.itervalues(consumers[i].poll(100)): + for tp, records in consumers[i].poll(100).values(): messages[i][tp].extend(records) consumers[i].close() consumers[i] = None @@ -93,8 +92,8 @@ def consumer_thread(i): logging.info('All consumers have assignment... 
checking for stable group') # Verify all consumers are in the same generation # then log state and break while loop - generations = set([consumer._coordinator._generation.generation_id - for consumer in list(consumers.values())]) + generations = {consumer._coordinator._generation.generation_id + for consumer in list(consumers.values())} # New generation assignment is not complete until # coordinator.rejoining = False @@ -120,9 +119,9 @@ def consumer_thread(i): assert set.isdisjoint(consumers[c].assignment(), group_assignment) group_assignment.update(consumers[c].assignment()) - assert group_assignment == set([ + assert group_assignment == { TopicPartition(topic, partition) - for partition in range(num_partitions)]) + for partition in range(num_partitions)} logging.info('Assignment looks good!') finally: @@ -143,7 +142,7 @@ def test_paused(kafka_broker, topic): assert set() == consumer.paused() consumer.pause(topics[0]) - assert set([topics[0]]) == consumer.paused() + assert {topics[0]} == consumer.paused() consumer.resume(topics[0]) assert set() == consumer.paused() @@ -181,3 +180,23 @@ def test_heartbeat_thread(kafka_broker, topic): consumer.poll(timeout_ms=100) assert consumer._coordinator.heartbeat.last_poll > last_poll consumer.close() + + +@pytest.mark.skipif(env_kafka_version() < (2, 3, 0), reason="Requires KAFKA_VERSION >= 2.3.0") +@pytest.mark.parametrize('leave, result', [ + (False, True), + (True, False), +]) +def test_kafka_consumer_rebalance_for_static_members(kafka_consumer_factory, leave, result): + GROUP_ID = random_string(10) + + consumer1 = kafka_consumer_factory(group_id=GROUP_ID, group_instance_id=GROUP_ID, leave_group_on_close=leave) + consumer1.poll() + generation1 = consumer1._coordinator.generation().generation_id + consumer1.close() + + consumer2 = kafka_consumer_factory(group_id=GROUP_ID, group_instance_id=GROUP_ID, leave_group_on_close=leave) + consumer2.poll() + generation2 = consumer2._coordinator.generation().generation_id + consumer2.close() + assert (generation1 == generation2) is result diff --git a/test/test_consumer_integration.py b/test/test_consumer_integration.py index 90b7ed203..d3165cd63 100644 --- a/test/test_consumer_integration.py +++ b/test/test_consumer_integration.py @@ -1,9 +1,8 @@ import logging import time +from unittest.mock import patch -from mock import patch import pytest -from kafka.vendor.six.moves import range import kafka.codec from kafka.errors import UnsupportedCodecError, UnsupportedVersionError @@ -99,6 +98,7 @@ def test_kafka_consumer__blocking(kafka_consumer_factory, topic, send_messages): assert t.interval >= (TIMEOUT_MS / 1000.0) +@pytest.mark.skipif(not env_kafka_version(), reason="No KAFKA_VERSION set") @pytest.mark.skipif(env_kafka_version() < (0, 8, 1), reason="Requires KAFKA_VERSION >= 0.8.1") def test_kafka_consumer__offset_commit_resume(kafka_consumer_factory, send_messages): GROUP_ID = random_string(10) @@ -143,6 +143,7 @@ def test_kafka_consumer__offset_commit_resume(kafka_consumer_factory, send_messa assert_message_count(output_msgs1 + output_msgs2, 200) +@pytest.mark.skipif(not env_kafka_version(), reason="No KAFKA_VERSION set") @pytest.mark.skipif(env_kafka_version() < (0, 10, 1), reason="Requires KAFKA_VERSION >= 0.10.1") def test_kafka_consumer_max_bytes_simple(kafka_consumer_factory, topic, send_messages): send_messages(range(100, 200), partition=0) @@ -162,6 +163,7 @@ def test_kafka_consumer_max_bytes_simple(kafka_consumer_factory, topic, send_mes assert seen_partitions == {TopicPartition(topic, 0), 
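The new consumer tests above cover the KIP-345 static membership surface: group_instance_id is rejected unless api_version is at least (2, 3, 0), and a static member closed with leave_group_on_close=False can rejoin without the group moving to a new generation. A minimal sketch of the same configuration outside pytest, assuming an illustrative broker at localhost:9092 and illustrative group, topic, and instance names:

    from kafka import KafkaConsumer
    from kafka.errors import KafkaConfigurationError

    try:
        # api_version too old for static membership: the constructor rejects it
        KafkaConsumer(bootstrap_servers='localhost:9092',
                      api_version=(2, 2, 0), group_instance_id='worker-1')
    except KafkaConfigurationError:
        pass

    # A static member: fixed instance id and no LeaveGroup request on close, so a
    # quick restart with the same id is expected to keep the generation unchanged.
    consumer = KafkaConsumer('orders',
                             bootstrap_servers='localhost:9092',
                             api_version=(2, 3, 0),
                             group_id='orders-workers',
                             group_instance_id='worker-1',
                             leave_group_on_close=False)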
TopicPartition(topic, 1)} +@pytest.mark.skipif(not env_kafka_version(), reason="No KAFKA_VERSION set") @pytest.mark.skipif(env_kafka_version() < (0, 10, 1), reason="Requires KAFKA_VERSION >= 0.10.1") def test_kafka_consumer_max_bytes_one_msg(kafka_consumer_factory, send_messages): # We send to only 1 partition so we don't have parallel requests to 2 @@ -188,6 +190,7 @@ def test_kafka_consumer_max_bytes_one_msg(kafka_consumer_factory, send_messages) assert_message_count(fetched_msgs, 10) +@pytest.mark.skipif(not env_kafka_version(), reason="No KAFKA_VERSION set") @pytest.mark.skipif(env_kafka_version() < (0, 10, 1), reason="Requires KAFKA_VERSION >= 0.10.1") def test_kafka_consumer_offsets_for_time(topic, kafka_consumer, kafka_producer): late_time = int(time.time()) * 1000 @@ -237,6 +240,7 @@ def test_kafka_consumer_offsets_for_time(topic, kafka_consumer, kafka_producer): assert offsets == {tp: late_msg.offset + 1} +@pytest.mark.skipif(not env_kafka_version(), reason="No KAFKA_VERSION set") @pytest.mark.skipif(env_kafka_version() < (0, 10, 1), reason="Requires KAFKA_VERSION >= 0.10.1") def test_kafka_consumer_offsets_search_many_partitions(kafka_consumer, kafka_producer, topic): tp0 = TopicPartition(topic, 0) @@ -275,6 +279,7 @@ def test_kafka_consumer_offsets_search_many_partitions(kafka_consumer, kafka_pro } +@pytest.mark.skipif(not env_kafka_version(), reason="No KAFKA_VERSION set") @pytest.mark.skipif(env_kafka_version() >= (0, 10, 1), reason="Requires KAFKA_VERSION < 0.10.1") def test_kafka_consumer_offsets_for_time_old(kafka_consumer, topic): consumer = kafka_consumer @@ -284,6 +289,7 @@ def test_kafka_consumer_offsets_for_time_old(kafka_consumer, topic): consumer.offsets_for_times({tp: int(time.time())}) +@pytest.mark.skipif(not env_kafka_version(), reason="No KAFKA_VERSION set") @pytest.mark.skipif(env_kafka_version() < (0, 10, 1), reason="Requires KAFKA_VERSION >= 0.10.1") def test_kafka_consumer_offsets_for_times_errors(kafka_consumer_factory, topic): consumer = kafka_consumer_factory(fetch_max_wait_ms=200, diff --git a/test/test_msk.py b/test/test_msk.py new file mode 100644 index 000000000..05c84ad16 --- /dev/null +++ b/test/test_msk.py @@ -0,0 +1,65 @@ +import datetime +import json +from unittest import mock + +from kafka.sasl.msk import AwsMskIamClient + + +def client_factory(token=None): + + now = datetime.datetime.utcfromtimestamp(1629321911) + with mock.patch('kafka.sasl.msk.datetime') as mock_dt: + mock_dt.datetime.utcnow = mock.Mock(return_value=now) + return AwsMskIamClient( + host='localhost', + access_key='XXXXXXXXXXXXXXXXXXXX', + secret_key='XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX', + region='us-east-1', + token=token, + ) + + +def test_aws_msk_iam_client_permanent_credentials(): + client = client_factory(token=None) + msg = client.first_message() + assert msg + assert isinstance(msg, bytes) + actual = json.loads(msg) + + expected = { + 'version': '2020_10_22', + 'host': 'localhost', + 'user-agent': 'kafka-python', + 'action': 'kafka-cluster:Connect', + 'x-amz-algorithm': 'AWS4-HMAC-SHA256', + 'x-amz-credential': 'XXXXXXXXXXXXXXXXXXXX/20210818/us-east-1/kafka-cluster/aws4_request', + 'x-amz-date': '20210818T212511Z', + 'x-amz-signedheaders': 'host', + 'x-amz-expires': '900', + 'x-amz-signature': '0fa42ae3d5693777942a7a4028b564f0b372bafa2f71c1a19ad60680e6cb994b', + } + assert actual == expected + + +def test_aws_msk_iam_client_temporary_credentials(): + client = client_factory(token='XXXXX') + msg = client.first_message() + assert msg + assert isinstance(msg, bytes) + 
actual = json.loads(msg) + + expected = { + 'version': '2020_10_22', + 'host': 'localhost', + 'user-agent': 'kafka-python', + 'action': 'kafka-cluster:Connect', + 'x-amz-algorithm': 'AWS4-HMAC-SHA256', + 'x-amz-credential': 'XXXXXXXXXXXXXXXXXXXX/20210818/us-east-1/kafka-cluster/aws4_request', + 'x-amz-date': '20210818T212511Z', + 'x-amz-signedheaders': 'host', + 'x-amz-expires': '900', + 'x-amz-signature': 'b0619c50b7ecb4a7f6f92bd5f733770df5710e97b25146f97015c0b1db783b05', + 'x-amz-security-token': 'XXXXX', + } + assert actual == expected + diff --git a/tox.ini b/tox.ini index d9b1e36d4..a574dc136 100644 --- a/tox.ini +++ b/tox.ini @@ -28,8 +28,9 @@ deps = lz4 xxhash crc32c + botocore commands = - pytest {posargs:--pylint --pylint-rcfile=pylint.rc --pylint-error-types=EF --cov=kafka --cov-config=.covrc} + pytest {posargs:--pylint --pylint-rcfile=pylint.rc --pylint-error-types=EF --cov=kafka} setenv = CRC32C_SW_MODE = auto PROJECT_ROOT = {toxinidir} @@ -38,7 +39,7 @@ passenv = KAFKA_VERSION [testenv:pypy] # pylint is super slow on pypy... -commands = pytest {posargs:--cov=kafka --cov-config=.covrc} +commands = pytest {posargs:--cov=kafka} [testenv:docs] deps =
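The new test/test_msk.py pins the exact SASL payload produced by AwsMskIamClient for both permanent and temporary credentials, with the clock frozen so the SigV4 signature stays stable; botocore joins the tox test dependencies alongside it. A minimal sketch of producing such a payload directly, with an illustrative broker host and placeholder credentials (the x-amz-date and x-amz-signature fields change with the current time):

    import json
    from kafka.sasl.msk import AwsMskIamClient

    client = AwsMskIamClient(
        host='b-1.mycluster.example.amazonaws.com',  # illustrative broker host
        access_key='AKIAXXXXXXXXXXXXXXXX',           # placeholder credentials
        secret_key='XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX',
        region='us-east-1',
        token=None,                                  # or an STS session token
    )
    payload = client.first_message()                 # signed JSON payload (bytes)
    assert json.loads(payload)['action'] == 'kafka-cluster:Connect'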