diff --git a/venv/Lib/site-packages/bson/__init__.py b/venv/Lib/site-packages/bson/__init__.py new file mode 100644 index 00000000..a7c9ddc5 --- /dev/null +++ b/venv/Lib/site-packages/bson/__init__.py @@ -0,0 +1,1464 @@ +# Copyright 2009-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""BSON (Binary JSON) encoding and decoding. + +The mapping from Python types to BSON types is as follows: + +======================================= ============= =================== +Python Type BSON Type Supported Direction +======================================= ============= =================== +None null both +bool boolean both +int [#int]_ int32 / int64 py -> bson +`bson.int64.Int64` int64 both +float number (real) both +str string both +list array both +dict / `SON` object both +datetime.datetime [#dt]_ [#dt2]_ date both +`bson.regex.Regex` regex both +compiled re [#re]_ regex py -> bson +`bson.binary.Binary` binary both +`bson.objectid.ObjectId` oid both +`bson.dbref.DBRef` dbref both +None undefined bson -> py +`bson.code.Code` code both +str symbol bson -> py +bytes [#bytes]_ binary both +======================================= ============= =================== + +.. [#int] A Python int will be saved as a BSON int32 or BSON int64 depending + on its size. A BSON int32 will always decode to a Python int. A BSON + int64 will always decode to a :class:`~bson.int64.Int64`. +.. [#dt] datetime.datetime instances will be rounded to the nearest + millisecond when saved +.. [#dt2] all datetime.datetime instances are treated as *naive*. clients + should always use UTC. +.. [#re] :class:`~bson.regex.Regex` instances and regular expression + objects from ``re.compile()`` are both saved as BSON regular expressions. + BSON regular expressions are decoded as :class:`~bson.regex.Regex` + instances. +.. [#bytes] The bytes type is encoded as BSON binary with + subtype 0. It will be decoded back to bytes. 
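+
+A minimal round trip through the module-level helpers (an illustrative
+sketch; output assumes the default ``dict`` document class)::
+
+    >>> import bson
+    >>> raw = bson.encode({"a": 1})
+    >>> bson.decode(raw)
+    {'a': 1}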
+""" +from __future__ import annotations + +import datetime +import itertools +import os +import re +import struct +import sys +import uuid +from codecs import utf_8_decode as _utf_8_decode +from codecs import utf_8_encode as _utf_8_encode +from collections import abc as _abc +from typing import ( + IO, + TYPE_CHECKING, + Any, + BinaryIO, + Callable, + Generator, + Iterator, + Mapping, + MutableMapping, + NoReturn, + Optional, + Sequence, + Tuple, + Type, + TypeVar, + Union, + cast, + overload, +) + +from bson.binary import ( + ALL_UUID_SUBTYPES, + CSHARP_LEGACY, + JAVA_LEGACY, + OLD_UUID_SUBTYPE, + STANDARD, + UUID_SUBTYPE, + Binary, + UuidRepresentation, +) +from bson.code import Code +from bson.codec_options import ( + DEFAULT_CODEC_OPTIONS, + CodecOptions, + DatetimeConversion, + _raw_document_class, +) +from bson.datetime_ms import ( + EPOCH_AWARE, + EPOCH_NAIVE, + DatetimeMS, + _datetime_to_millis, + _millis_to_datetime, +) +from bson.dbref import DBRef +from bson.decimal128 import Decimal128 +from bson.errors import InvalidBSON, InvalidDocument, InvalidStringData +from bson.int64 import Int64 +from bson.max_key import MaxKey +from bson.min_key import MinKey +from bson.objectid import ObjectId +from bson.regex import Regex +from bson.son import RE_TYPE, SON +from bson.timestamp import Timestamp +from bson.tz_util import utc + +# Import some modules for type-checking only. +if TYPE_CHECKING: + from bson.raw_bson import RawBSONDocument + from bson.typings import _DocumentType, _ReadableBuffer + +try: + from bson import _cbson # type: ignore[attr-defined] + + _USE_C = True +except ImportError: + _USE_C = False + +__all__ = [ + "ALL_UUID_SUBTYPES", + "CSHARP_LEGACY", + "JAVA_LEGACY", + "OLD_UUID_SUBTYPE", + "STANDARD", + "UUID_SUBTYPE", + "Binary", + "UuidRepresentation", + "Code", + "DEFAULT_CODEC_OPTIONS", + "CodecOptions", + "DBRef", + "Decimal128", + "InvalidBSON", + "InvalidDocument", + "InvalidStringData", + "Int64", + "MaxKey", + "MinKey", + "ObjectId", + "Regex", + "RE_TYPE", + "SON", + "Timestamp", + "utc", + "EPOCH_AWARE", + "EPOCH_NAIVE", + "BSONNUM", + "BSONSTR", + "BSONOBJ", + "BSONARR", + "BSONBIN", + "BSONUND", + "BSONOID", + "BSONBOO", + "BSONDAT", + "BSONNUL", + "BSONRGX", + "BSONREF", + "BSONCOD", + "BSONSYM", + "BSONCWS", + "BSONINT", + "BSONTIM", + "BSONLON", + "BSONDEC", + "BSONMIN", + "BSONMAX", + "get_data_and_view", + "gen_list_name", + "encode", + "decode", + "decode_all", + "decode_iter", + "decode_file_iter", + "is_valid", + "BSON", + "has_c", + "DatetimeConversion", + "DatetimeMS", +] + +BSONNUM = b"\x01" # Floating point +BSONSTR = b"\x02" # UTF-8 string +BSONOBJ = b"\x03" # Embedded document +BSONARR = b"\x04" # Array +BSONBIN = b"\x05" # Binary +BSONUND = b"\x06" # Undefined +BSONOID = b"\x07" # ObjectId +BSONBOO = b"\x08" # Boolean +BSONDAT = b"\x09" # UTC Datetime +BSONNUL = b"\x0A" # Null +BSONRGX = b"\x0B" # Regex +BSONREF = b"\x0C" # DBRef +BSONCOD = b"\x0D" # Javascript code +BSONSYM = b"\x0E" # Symbol +BSONCWS = b"\x0F" # Javascript code with scope +BSONINT = b"\x10" # 32bit int +BSONTIM = b"\x11" # Timestamp +BSONLON = b"\x12" # 64bit int +BSONDEC = b"\x13" # Decimal128 +BSONMIN = b"\xFF" # Min key +BSONMAX = b"\x7F" # Max key + + +_UNPACK_FLOAT_FROM = struct.Struct(" Tuple[Any, memoryview]: + if isinstance(data, (bytes, bytearray)): + return data, memoryview(data) + view = memoryview(data) + return view.tobytes(), view + + +def _raise_unknown_type(element_type: int, element_name: str) -> NoReturn: + """Unknown type helper.""" + raise InvalidBSON( + 
"Detected unknown BSON type {!r} for fieldname '{}'. Are " + "you using the latest driver version?".format(chr(element_type).encode(), element_name) + ) + + +def _get_int( + data: Any, _view: Any, position: int, dummy0: Any, dummy1: Any, dummy2: Any +) -> Tuple[int, int]: + """Decode a BSON int32 to python int.""" + return _UNPACK_INT_FROM(data, position)[0], position + 4 + + +def _get_c_string(data: Any, view: Any, position: int, opts: CodecOptions[Any]) -> Tuple[str, int]: + """Decode a BSON 'C' string to python str.""" + end = data.index(b"\x00", position) + return _utf_8_decode(view[position:end], opts.unicode_decode_error_handler, True)[0], end + 1 + + +def _get_float( + data: Any, _view: Any, position: int, dummy0: Any, dummy1: Any, dummy2: Any +) -> Tuple[float, int]: + """Decode a BSON double to python float.""" + return _UNPACK_FLOAT_FROM(data, position)[0], position + 8 + + +def _get_string( + data: Any, view: Any, position: int, obj_end: int, opts: CodecOptions[Any], dummy: Any +) -> Tuple[str, int]: + """Decode a BSON string to python str.""" + length = _UNPACK_INT_FROM(data, position)[0] + position += 4 + if length < 1 or obj_end - position < length: + raise InvalidBSON("invalid string length") + end = position + length - 1 + if data[end] != 0: + raise InvalidBSON("invalid end of string") + return _utf_8_decode(view[position:end], opts.unicode_decode_error_handler, True)[0], end + 1 + + +def _get_object_size(data: Any, position: int, obj_end: int) -> Tuple[int, int]: + """Validate and return a BSON document's size.""" + try: + obj_size = _UNPACK_INT_FROM(data, position)[0] + except struct.error as exc: + raise InvalidBSON(str(exc)) from None + end = position + obj_size - 1 + if data[end] != 0: + raise InvalidBSON("bad eoo") + if end >= obj_end: + raise InvalidBSON("invalid object length") + # If this is the top-level document, validate the total size too. + if position == 0 and obj_size != obj_end: + raise InvalidBSON("invalid object length") + return obj_size, end + + +def _get_object( + data: Any, view: Any, position: int, obj_end: int, opts: CodecOptions[Any], dummy: Any +) -> Tuple[Any, int]: + """Decode a BSON subdocument to opts.document_class or bson.dbref.DBRef.""" + obj_size, end = _get_object_size(data, position, obj_end) + if _raw_document_class(opts.document_class): + return (opts.document_class(data[position : end + 1], opts), position + obj_size) + + obj = _elements_to_dict(data, view, position + 4, end, opts) + + position += obj_size + # If DBRef validation fails, return a normal doc. + if ( + isinstance(obj.get("$ref"), str) + and "$id" in obj + and isinstance(obj.get("$db"), (str, type(None))) + ): + return (DBRef(obj.pop("$ref"), obj.pop("$id", None), obj.pop("$db", None), obj), position) + return obj, position + + +def _get_array( + data: Any, view: Any, position: int, obj_end: int, opts: CodecOptions[Any], element_name: str +) -> Tuple[Any, int]: + """Decode a BSON array to python list.""" + size = _UNPACK_INT_FROM(data, position)[0] + end = position + size - 1 + if data[end] != 0: + raise InvalidBSON("bad eoo") + + position += 4 + end -= 1 + result: list[Any] = [] + + # Avoid doing global and attribute lookups in the loop. + append = result.append + index = data.index + getter = _ELEMENT_GETTER + decoder_map = opts.type_registry._decoder_map + + while position < end: + element_type = data[position] + # Just skip the keys. 
+ position = index(b"\x00", position) + 1 + try: + value, position = getter[element_type]( + data, view, position, obj_end, opts, element_name + ) + except KeyError: + _raise_unknown_type(element_type, element_name) + + if decoder_map: + custom_decoder = decoder_map.get(type(value)) + if custom_decoder is not None: + value = custom_decoder(value) + + append(value) + + if position != end + 1: + raise InvalidBSON("bad array length") + return result, position + 1 + + +def _get_binary( + data: Any, _view: Any, position: int, obj_end: int, opts: CodecOptions[Any], dummy1: Any +) -> Tuple[Union[Binary, uuid.UUID], int]: + """Decode a BSON binary to bson.binary.Binary or python UUID.""" + length, subtype = _UNPACK_LENGTH_SUBTYPE_FROM(data, position) + position += 5 + if subtype == 2: + length2 = _UNPACK_INT_FROM(data, position)[0] + position += 4 + if length2 != length - 4: + raise InvalidBSON("invalid binary (st 2) - lengths don't match!") + length = length2 + end = position + length + if length < 0 or end > obj_end: + raise InvalidBSON("bad binary object length") + + # Convert UUID subtypes to native UUIDs. + if subtype in ALL_UUID_SUBTYPES: + uuid_rep = opts.uuid_representation + binary_value = Binary(data[position:end], subtype) + if ( + (uuid_rep == UuidRepresentation.UNSPECIFIED) + or (subtype == UUID_SUBTYPE and uuid_rep != STANDARD) + or (subtype == OLD_UUID_SUBTYPE and uuid_rep == STANDARD) + ): + return binary_value, end + return binary_value.as_uuid(uuid_rep), end + + # Decode subtype 0 to 'bytes'. + if subtype == 0: + value = data[position:end] + else: + value = Binary(data[position:end], subtype) + + return value, end + + +def _get_oid( + data: Any, _view: Any, position: int, dummy0: Any, dummy1: Any, dummy2: Any +) -> Tuple[ObjectId, int]: + """Decode a BSON ObjectId to bson.objectid.ObjectId.""" + end = position + 12 + return ObjectId(data[position:end]), end + + +def _get_boolean( + data: Any, _view: Any, position: int, dummy0: Any, dummy1: Any, dummy2: Any +) -> Tuple[bool, int]: + """Decode a BSON true/false to python True/False.""" + end = position + 1 + boolean_byte = data[position:end] + if boolean_byte == b"\x00": + return False, end + elif boolean_byte == b"\x01": + return True, end + raise InvalidBSON("invalid boolean value: %r" % boolean_byte) + + +def _get_date( + data: Any, _view: Any, position: int, dummy0: int, opts: CodecOptions[Any], dummy1: Any +) -> Tuple[Union[datetime.datetime, DatetimeMS], int]: + """Decode a BSON datetime to python datetime.datetime.""" + return _millis_to_datetime(_UNPACK_LONG_FROM(data, position)[0], opts), position + 8 + + +def _get_code( + data: Any, view: Any, position: int, obj_end: int, opts: CodecOptions[Any], element_name: str +) -> Tuple[Code, int]: + """Decode a BSON code to bson.code.Code.""" + code, position = _get_string(data, view, position, obj_end, opts, element_name) + return Code(code), position + + +def _get_code_w_scope( + data: Any, view: Any, position: int, _obj_end: int, opts: CodecOptions[Any], element_name: str +) -> Tuple[Code, int]: + """Decode a BSON code_w_scope to bson.code.Code.""" + code_end = position + _UNPACK_INT_FROM(data, position)[0] + code, position = _get_string(data, view, position + 4, code_end, opts, element_name) + scope, position = _get_object(data, view, position, code_end, opts, element_name) + if position != code_end: + raise InvalidBSON("scope outside of javascript code boundaries") + return Code(code, scope), position + + +def _get_regex( + data: Any, view: Any, position: int, dummy0: Any, 
opts: CodecOptions[Any], dummy1: Any +) -> Tuple[Regex[Any], int]: + """Decode a BSON regex to bson.regex.Regex or a python pattern object.""" + pattern, position = _get_c_string(data, view, position, opts) + bson_flags, position = _get_c_string(data, view, position, opts) + bson_re = Regex(pattern, bson_flags) + return bson_re, position + + +def _get_ref( + data: Any, view: Any, position: int, obj_end: int, opts: CodecOptions[Any], element_name: str +) -> Tuple[DBRef, int]: + """Decode (deprecated) BSON DBPointer to bson.dbref.DBRef.""" + collection, position = _get_string(data, view, position, obj_end, opts, element_name) + oid, position = _get_oid(data, view, position, obj_end, opts, element_name) + return DBRef(collection, oid), position + + +def _get_timestamp( + data: Any, _view: Any, position: int, dummy0: Any, dummy1: Any, dummy2: Any +) -> Tuple[Timestamp, int]: + """Decode a BSON timestamp to bson.timestamp.Timestamp.""" + inc, timestamp = _UNPACK_TIMESTAMP_FROM(data, position) + return Timestamp(timestamp, inc), position + 8 + + +def _get_int64( + data: Any, _view: Any, position: int, dummy0: Any, dummy1: Any, dummy2: Any +) -> Tuple[Int64, int]: + """Decode a BSON int64 to bson.int64.Int64.""" + return Int64(_UNPACK_LONG_FROM(data, position)[0]), position + 8 + + +def _get_decimal128( + data: Any, _view: Any, position: int, dummy0: Any, dummy1: Any, dummy2: Any +) -> Tuple[Decimal128, int]: + """Decode a BSON decimal128 to bson.decimal128.Decimal128.""" + end = position + 16 + return Decimal128.from_bid(data[position:end]), end + + +# Each decoder function's signature is: +# - data: bytes +# - view: memoryview that references `data` +# - position: int, beginning of object in 'data' to decode +# - obj_end: int, end of object to decode in 'data' if variable-length type +# - opts: a CodecOptions +_ELEMENT_GETTER: dict[int, Callable[..., Tuple[Any, int]]] = { + ord(BSONNUM): _get_float, + ord(BSONSTR): _get_string, + ord(BSONOBJ): _get_object, + ord(BSONARR): _get_array, + ord(BSONBIN): _get_binary, + ord(BSONUND): lambda u, v, w, x, y, z: (None, w), # noqa: ARG005 # Deprecated undefined + ord(BSONOID): _get_oid, + ord(BSONBOO): _get_boolean, + ord(BSONDAT): _get_date, + ord(BSONNUL): lambda u, v, w, x, y, z: (None, w), # noqa: ARG005 + ord(BSONRGX): _get_regex, + ord(BSONREF): _get_ref, # Deprecated DBPointer + ord(BSONCOD): _get_code, + ord(BSONSYM): _get_string, # Deprecated symbol + ord(BSONCWS): _get_code_w_scope, + ord(BSONINT): _get_int, + ord(BSONTIM): _get_timestamp, + ord(BSONLON): _get_int64, + ord(BSONDEC): _get_decimal128, + ord(BSONMIN): lambda u, v, w, x, y, z: (MinKey(), w), # noqa: ARG005 + ord(BSONMAX): lambda u, v, w, x, y, z: (MaxKey(), w), # noqa: ARG005 +} + + +if _USE_C: + + def _element_to_dict( + data: Any, + view: Any, # noqa: ARG001 + position: int, + obj_end: int, + opts: CodecOptions[Any], + raw_array: bool = False, + ) -> Tuple[str, Any, int]: + return cast( + "Tuple[str, Any, int]", + _cbson._element_to_dict(data, position, obj_end, opts, raw_array), + ) + +else: + + def _element_to_dict( + data: Any, + view: Any, + position: int, + obj_end: int, + opts: CodecOptions[Any], + raw_array: bool = False, + ) -> Tuple[str, Any, int]: + """Decode a single key, value pair.""" + element_type = data[position] + position += 1 + element_name, position = _get_c_string(data, view, position, opts) + if raw_array and element_type == ord(BSONARR): + _, end = _get_object_size(data, position, len(data)) + return element_name, view[position : end + 1], end + 1 + try: + 
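+            # Dispatch on the element's one-byte type marker; an unknown
+            # marker raises KeyError, which is reported below as InvalidBSON.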
+            value, position = _ELEMENT_GETTER[element_type](
+                data, view, position, obj_end, opts, element_name
+            )
+        except KeyError:
+            _raise_unknown_type(element_type, element_name)
+
+        if opts.type_registry._decoder_map:
+            custom_decoder = opts.type_registry._decoder_map.get(type(value))
+            if custom_decoder is not None:
+                value = custom_decoder(value)
+
+        return element_name, value, position
+
+
+_T = TypeVar("_T", bound=MutableMapping[str, Any])
+
+
+def _raw_to_dict(
+    data: Any,
+    position: int,
+    obj_end: int,
+    opts: CodecOptions[RawBSONDocument],
+    result: _T,
+    raw_array: bool = False,
+) -> _T:
+    data, view = get_data_and_view(data)
+    return cast(
+        _T, _elements_to_dict(data, view, position, obj_end, opts, result, raw_array=raw_array)
+    )
+
+
+def _elements_to_dict(
+    data: Any,
+    view: Any,
+    position: int,
+    obj_end: int,
+    opts: CodecOptions[Any],
+    result: Any = None,
+    raw_array: bool = False,
+) -> Any:
+    """Decode a BSON document into result."""
+    if result is None:
+        result = opts.document_class()
+    end = obj_end - 1
+    while position < end:
+        key, value, position = _element_to_dict(
+            data, view, position, obj_end, opts, raw_array=raw_array
+        )
+        result[key] = value
+    if position != obj_end:
+        raise InvalidBSON("bad object or element length")
+    return result
+
+
+def _bson_to_dict(data: Any, opts: CodecOptions[_DocumentType]) -> _DocumentType:
+    """Decode a BSON string to document_class."""
+    data, view = get_data_and_view(data)
+    try:
+        if _raw_document_class(opts.document_class):
+            return opts.document_class(data, opts)  # type:ignore[call-arg]
+        _, end = _get_object_size(data, 0, len(data))
+        return cast("_DocumentType", _elements_to_dict(data, view, 4, end, opts))
+    except InvalidBSON:
+        raise
+    except Exception:
+        # Change exception type to InvalidBSON but preserve traceback.
+        _, exc_value, exc_tb = sys.exc_info()
+        raise InvalidBSON(str(exc_value)).with_traceback(exc_tb) from None
+
+
+if _USE_C:
+    _bson_to_dict = _cbson._bson_to_dict
+
+
+_PACK_FLOAT = struct.Struct("<d").pack
+_PACK_INT = struct.Struct("<i").pack
+_PACK_LENGTH_SUBTYPE = struct.Struct("<iB").pack
+_PACK_LONG = struct.Struct("<q").pack
+_PACK_TIMESTAMP = struct.Struct("<II").pack
+_LIST_NAMES = tuple((str(i) + "\x00").encode("utf8") for i in range(1000))
+
+
+def gen_list_name() -> Generator[bytes, None, None]:
+    """Generate "keys" for encoded lists in the sequence
+    b"0\x00", b"1\x00", b"2\x00", ...
+
+    The first 1000 keys are returned from a pre-built cache. All
+    subsequent keys are generated on the fly.
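+
+    A small illustration of the generated sequence::
+
+        >>> names = gen_list_name()
+        >>> next(names), next(names), next(names)
+        (b'0\x00', b'1\x00', b'2\x00')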
+ """ + yield from _LIST_NAMES + + counter = itertools.count(1000) + while True: + yield (str(next(counter)) + "\x00").encode("utf8") + + +def _make_c_string_check(string: Union[str, bytes]) -> bytes: + """Make a 'C' string, checking for embedded NUL characters.""" + if isinstance(string, bytes): + if b"\x00" in string: + raise InvalidDocument("BSON keys / regex patterns must not contain a NUL character") + try: + _utf_8_decode(string, None, True) + return string + b"\x00" + except UnicodeError: + raise InvalidStringData( + "strings in documents must be valid UTF-8: %r" % string + ) from None + else: + if "\x00" in string: + raise InvalidDocument("BSON keys / regex patterns must not contain a NUL character") + return _utf_8_encode(string)[0] + b"\x00" + + +def _make_c_string(string: Union[str, bytes]) -> bytes: + """Make a 'C' string.""" + if isinstance(string, bytes): + try: + _utf_8_decode(string, None, True) + return string + b"\x00" + except UnicodeError: + raise InvalidStringData( + "strings in documents must be valid UTF-8: %r" % string + ) from None + else: + return _utf_8_encode(string)[0] + b"\x00" + + +def _make_name(string: str) -> bytes: + """Make a 'C' string suitable for a BSON key.""" + if "\x00" in string: + raise InvalidDocument("BSON keys must not contain a NUL character") + return _utf_8_encode(string)[0] + b"\x00" + + +def _encode_float(name: bytes, value: float, dummy0: Any, dummy1: Any) -> bytes: + """Encode a float.""" + return b"\x01" + name + _PACK_FLOAT(value) + + +def _encode_bytes(name: bytes, value: bytes, dummy0: Any, dummy1: Any) -> bytes: + """Encode a python bytes.""" + # Python3 special case. Store 'bytes' as BSON binary subtype 0. + return b"\x05" + name + _PACK_INT(len(value)) + b"\x00" + value + + +def _encode_mapping(name: bytes, value: Any, check_keys: bool, opts: CodecOptions[Any]) -> bytes: + """Encode a mapping type.""" + if _raw_document_class(value): + return b"\x03" + name + cast(bytes, value.raw) + data = b"".join([_element_to_bson(key, val, check_keys, opts) for key, val in value.items()]) + return b"\x03" + name + _PACK_INT(len(data) + 5) + data + b"\x00" + + +def _encode_dbref(name: bytes, value: DBRef, check_keys: bool, opts: CodecOptions[Any]) -> bytes: + """Encode bson.dbref.DBRef.""" + buf = bytearray(b"\x03" + name + b"\x00\x00\x00\x00") + begin = len(buf) - 4 + + buf += _name_value_to_bson(b"$ref\x00", value.collection, check_keys, opts) + buf += _name_value_to_bson(b"$id\x00", value.id, check_keys, opts) + if value.database is not None: + buf += _name_value_to_bson(b"$db\x00", value.database, check_keys, opts) + for key, val in value._DBRef__kwargs.items(): + buf += _element_to_bson(key, val, check_keys, opts) + + buf += b"\x00" + buf[begin : begin + 4] = _PACK_INT(len(buf) - begin) + return bytes(buf) + + +def _encode_list( + name: bytes, value: Sequence[Any], check_keys: bool, opts: CodecOptions[Any] +) -> bytes: + """Encode a list/tuple.""" + lname = gen_list_name() + data = b"".join([_name_value_to_bson(next(lname), item, check_keys, opts) for item in value]) + return b"\x04" + name + _PACK_INT(len(data) + 5) + data + b"\x00" + + +def _encode_text(name: bytes, value: str, dummy0: Any, dummy1: Any) -> bytes: + """Encode a python str.""" + bvalue = _utf_8_encode(value)[0] + return b"\x02" + name + _PACK_INT(len(bvalue) + 1) + bvalue + b"\x00" + + +def _encode_binary(name: bytes, value: Binary, dummy0: Any, dummy1: Any) -> bytes: + """Encode bson.binary.Binary.""" + subtype = value.subtype + if subtype == 2: + value = 
_PACK_INT(len(value)) + value # type: ignore + return b"\x05" + name + _PACK_LENGTH_SUBTYPE(len(value), subtype) + value + + +def _encode_uuid(name: bytes, value: uuid.UUID, dummy: Any, opts: CodecOptions[Any]) -> bytes: + """Encode uuid.UUID.""" + uuid_representation = opts.uuid_representation + binval = Binary.from_uuid(value, uuid_representation=uuid_representation) + return _encode_binary(name, binval, dummy, opts) + + +def _encode_objectid(name: bytes, value: ObjectId, dummy: Any, dummy1: Any) -> bytes: + """Encode bson.objectid.ObjectId.""" + return b"\x07" + name + value.binary + + +def _encode_bool(name: bytes, value: bool, dummy0: Any, dummy1: Any) -> bytes: + """Encode a python boolean (True/False).""" + return b"\x08" + name + (value and b"\x01" or b"\x00") + + +def _encode_datetime(name: bytes, value: datetime.datetime, dummy0: Any, dummy1: Any) -> bytes: + """Encode datetime.datetime.""" + millis = _datetime_to_millis(value) + return b"\x09" + name + _PACK_LONG(millis) + + +def _encode_datetime_ms(name: bytes, value: DatetimeMS, dummy0: Any, dummy1: Any) -> bytes: + """Encode datetime.datetime.""" + millis = int(value) + return b"\x09" + name + _PACK_LONG(millis) + + +def _encode_none(name: bytes, dummy0: Any, dummy1: Any, dummy2: Any) -> bytes: + """Encode python None.""" + return b"\x0A" + name + + +def _encode_regex(name: bytes, value: Regex[Any], dummy0: Any, dummy1: Any) -> bytes: + """Encode a python regex or bson.regex.Regex.""" + flags = value.flags + # Python 3 common case + if flags == re.UNICODE: + return b"\x0B" + name + _make_c_string_check(value.pattern) + b"u\x00" + elif flags == 0: + return b"\x0B" + name + _make_c_string_check(value.pattern) + b"\x00" + else: + sflags = b"" + if flags & re.IGNORECASE: + sflags += b"i" + if flags & re.LOCALE: + sflags += b"l" + if flags & re.MULTILINE: + sflags += b"m" + if flags & re.DOTALL: + sflags += b"s" + if flags & re.UNICODE: + sflags += b"u" + if flags & re.VERBOSE: + sflags += b"x" + sflags += b"\x00" + return b"\x0B" + name + _make_c_string_check(value.pattern) + sflags + + +def _encode_code(name: bytes, value: Code, dummy: Any, opts: CodecOptions[Any]) -> bytes: + """Encode bson.code.Code.""" + cstring = _make_c_string(value) + cstrlen = len(cstring) + if value.scope is None: + return b"\x0D" + name + _PACK_INT(cstrlen) + cstring + scope = _dict_to_bson(value.scope, False, opts, False) + full_length = _PACK_INT(8 + cstrlen + len(scope)) + return b"\x0F" + name + full_length + _PACK_INT(cstrlen) + cstring + scope + + +def _encode_int(name: bytes, value: int, dummy0: Any, dummy1: Any) -> bytes: + """Encode a python int.""" + if -2147483648 <= value <= 2147483647: + return b"\x10" + name + _PACK_INT(value) + else: + try: + return b"\x12" + name + _PACK_LONG(value) + except struct.error: + raise OverflowError("BSON can only handle up to 8-byte ints") from None + + +def _encode_timestamp(name: bytes, value: Any, dummy0: Any, dummy1: Any) -> bytes: + """Encode bson.timestamp.Timestamp.""" + return b"\x11" + name + _PACK_TIMESTAMP(value.inc, value.time) + + +def _encode_long(name: bytes, value: Any, dummy0: Any, dummy1: Any) -> bytes: + """Encode a bson.int64.Int64.""" + try: + return b"\x12" + name + _PACK_LONG(value) + except struct.error: + raise OverflowError("BSON can only handle up to 8-byte ints") from None + + +def _encode_decimal128(name: bytes, value: Decimal128, dummy0: Any, dummy1: Any) -> bytes: + """Encode bson.decimal128.Decimal128.""" + return b"\x13" + name + value.bid + + +def _encode_minkey(name: bytes, 
dummy0: Any, dummy1: Any, dummy2: Any) -> bytes: + """Encode bson.min_key.MinKey.""" + return b"\xFF" + name + + +def _encode_maxkey(name: bytes, dummy0: Any, dummy1: Any, dummy2: Any) -> bytes: + """Encode bson.max_key.MaxKey.""" + return b"\x7F" + name + + +# Each encoder function's signature is: +# - name: utf-8 bytes +# - value: a Python data type, e.g. a Python int for _encode_int +# - check_keys: bool, whether to check for invalid names +# - opts: a CodecOptions +_ENCODERS = { + bool: _encode_bool, + bytes: _encode_bytes, + datetime.datetime: _encode_datetime, + DatetimeMS: _encode_datetime_ms, + dict: _encode_mapping, + float: _encode_float, + int: _encode_int, + list: _encode_list, + str: _encode_text, + tuple: _encode_list, + type(None): _encode_none, + uuid.UUID: _encode_uuid, + Binary: _encode_binary, + Int64: _encode_long, + Code: _encode_code, + DBRef: _encode_dbref, + MaxKey: _encode_maxkey, + MinKey: _encode_minkey, + ObjectId: _encode_objectid, + Regex: _encode_regex, + RE_TYPE: _encode_regex, + SON: _encode_mapping, + Timestamp: _encode_timestamp, + Decimal128: _encode_decimal128, + # Special case. This will never be looked up directly. + _abc.Mapping: _encode_mapping, +} + +# Map each _type_marker to its encoder for faster lookup. +_MARKERS = {} +for _typ in _ENCODERS: + if hasattr(_typ, "_type_marker"): + _MARKERS[_typ._type_marker] = _ENCODERS[_typ] + + +_BUILT_IN_TYPES = tuple(t for t in _ENCODERS) + + +def _name_value_to_bson( + name: bytes, + value: Any, + check_keys: bool, + opts: CodecOptions[Any], + in_custom_call: bool = False, + in_fallback_call: bool = False, +) -> bytes: + """Encode a single name, value pair.""" + + was_integer_overflow = False + + # First see if the type is already cached. KeyError will only ever + # happen once per subtype. + try: + return _ENCODERS[type(value)](name, value, check_keys, opts) # type: ignore + except KeyError: + pass + except OverflowError: + if not isinstance(value, int): + raise + + # Give the fallback_encoder a chance + was_integer_overflow = True + + # Second, fall back to trying _type_marker. This has to be done + # before the loop below since users could subclass one of our + # custom types that subclasses a python built-in (e.g. Binary) + marker = getattr(value, "_type_marker", None) + if isinstance(marker, int) and marker in _MARKERS: + func = _MARKERS[marker] + # Cache this type for faster subsequent lookup. + _ENCODERS[type(value)] = func + return func(name, value, check_keys, opts) # type: ignore + + # Third, check if a type encoder is registered for this type. + # Note that subtypes of registered custom types are not auto-encoded. + if not in_custom_call and opts.type_registry._encoder_map: + custom_encoder = opts.type_registry._encoder_map.get(type(value)) + if custom_encoder is not None: + return _name_value_to_bson( + name, custom_encoder(value), check_keys, opts, in_custom_call=True + ) + + # Fourth, test each base type. This will only happen once for + # a subtype of a supported base type. Unlike in the C-extensions, this + # is done after trying the custom type encoder because checking for each + # subtype is expensive. + for base in _BUILT_IN_TYPES: + if not was_integer_overflow and isinstance(value, base): + func = _ENCODERS[base] + # Cache this type for faster subsequent lookup. + _ENCODERS[type(value)] = func + return func(name, value, check_keys, opts) # type: ignore + + # As a last resort, try using the fallback encoder, if the user has + # provided one. 
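+    # (Lookup order above: exact-type cache, _type_marker, registered custom
+    # encoder, then isinstance checks against the built-in base types; the
+    # fallback encoder runs only after all of those have missed.)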
+ fallback_encoder = opts.type_registry._fallback_encoder + if not in_fallback_call and fallback_encoder is not None: + return _name_value_to_bson( + name, fallback_encoder(value), check_keys, opts, in_fallback_call=True + ) + + if was_integer_overflow: + raise OverflowError("BSON can only handle up to 8-byte ints") + raise InvalidDocument(f"cannot encode object: {value!r}, of type: {type(value)!r}") + + +def _element_to_bson(key: Any, value: Any, check_keys: bool, opts: CodecOptions[Any]) -> bytes: + """Encode a single key, value pair.""" + if not isinstance(key, str): + raise InvalidDocument(f"documents must have only string keys, key was {key!r}") + if check_keys: + if key.startswith("$"): + raise InvalidDocument(f"key {key!r} must not start with '$'") + if "." in key: + raise InvalidDocument(f"key {key!r} must not contain '.'") + + name = _make_name(key) + return _name_value_to_bson(name, value, check_keys, opts) + + +def _dict_to_bson( + doc: Any, check_keys: bool, opts: CodecOptions[Any], top_level: bool = True +) -> bytes: + """Encode a document to BSON.""" + if _raw_document_class(doc): + return cast(bytes, doc.raw) + try: + elements = [] + if top_level and "_id" in doc: + elements.append(_name_value_to_bson(b"_id\x00", doc["_id"], check_keys, opts)) + for key, value in doc.items(): + if not top_level or key != "_id": + elements.append(_element_to_bson(key, value, check_keys, opts)) + except AttributeError: + raise TypeError(f"encoder expected a mapping type but got: {doc!r}") from None + + encoded = b"".join(elements) + return _PACK_INT(len(encoded) + 5) + encoded + b"\x00" + + +if _USE_C: + _dict_to_bson = _cbson._dict_to_bson + + +_CODEC_OPTIONS_TYPE_ERROR = TypeError("codec_options must be an instance of CodecOptions") + + +def encode( + document: Mapping[str, Any], + check_keys: bool = False, + codec_options: CodecOptions[Any] = DEFAULT_CODEC_OPTIONS, +) -> bytes: + """Encode a document to BSON. + + A document can be any mapping type (like :class:`dict`). + + Raises :class:`TypeError` if `document` is not a mapping type, + or contains keys that are not instances of :class:`str`. Raises + :class:`~bson.errors.InvalidDocument` if `document` cannot be + converted to :class:`BSON`. + + :param document: mapping type representing a document + :param check_keys: check if keys start with '$' or + contain '.', raising :class:`~bson.errors.InvalidDocument` in + either case + :param codec_options: An instance of + :class:`~bson.codec_options.CodecOptions`. + + .. versionadded:: 3.9 + """ + if not isinstance(codec_options, CodecOptions): + raise _CODEC_OPTIONS_TYPE_ERROR + + return _dict_to_bson(document, check_keys, codec_options) + + +@overload +def decode(data: _ReadableBuffer, codec_options: None = None) -> dict[str, Any]: + ... + + +@overload +def decode(data: _ReadableBuffer, codec_options: CodecOptions[_DocumentType]) -> _DocumentType: + ... + + +def decode( + data: _ReadableBuffer, codec_options: Optional[CodecOptions[_DocumentType]] = None +) -> Union[dict[str, Any], _DocumentType]: + """Decode BSON to a document. + + By default, returns a BSON document represented as a Python + :class:`dict`. To use a different :class:`MutableMapping` class, + configure a :class:`~bson.codec_options.CodecOptions`:: + + >>> import collections # From Python standard library. 
+        >>> import bson
+        >>> from bson.codec_options import CodecOptions
+        >>> data = bson.encode({'a': 1})
+        >>> decoded_doc = bson.decode(data)
+
+        >>> options = CodecOptions(document_class=collections.OrderedDict)
+        >>> decoded_doc = bson.decode(data, codec_options=options)
+        >>> type(decoded_doc)
+        <class 'collections.OrderedDict'>
+
+    :param data: the BSON to decode. Any bytes-like object that implements
+        the buffer protocol.
+    :param codec_options: An instance of
+        :class:`~bson.codec_options.CodecOptions`.
+
+    .. versionadded:: 3.9
+    """
+    opts: CodecOptions[Any] = codec_options or DEFAULT_CODEC_OPTIONS
+    if not isinstance(opts, CodecOptions):
+        raise _CODEC_OPTIONS_TYPE_ERROR
+
+    return cast("Union[dict[str, Any], _DocumentType]", _bson_to_dict(data, opts))
+
+
+def _decode_all(data: _ReadableBuffer, opts: CodecOptions[_DocumentType]) -> list[_DocumentType]:
+    """Decode a BSON data to multiple documents."""
+    data, view = get_data_and_view(data)
+    data_len = len(data)
+    docs: list[_DocumentType] = []
+    position = 0
+    end = data_len - 1
+    use_raw = _raw_document_class(opts.document_class)
+    try:
+        while position < end:
+            obj_size = _UNPACK_INT_FROM(data, position)[0]
+            if data_len - position < obj_size:
+                raise InvalidBSON("invalid object size")
+            obj_end = position + obj_size - 1
+            if data[obj_end] != 0:
+                raise InvalidBSON("bad eoo")
+            if use_raw:
+                docs.append(opts.document_class(data[position : obj_end + 1], opts))  # type: ignore
+            else:
+                docs.append(_elements_to_dict(data, view, position + 4, obj_end, opts))
+            position += obj_size
+        return docs
+    except InvalidBSON:
+        raise
+    except Exception:
+        # Change exception type to InvalidBSON but preserve traceback.
+        _, exc_value, exc_tb = sys.exc_info()
+        raise InvalidBSON(str(exc_value)).with_traceback(exc_tb) from None
+
+
+if _USE_C:
+    _decode_all = _cbson._decode_all
+
+
+@overload
+def decode_all(data: _ReadableBuffer, codec_options: None = None) -> list[dict[str, Any]]:
+    ...
+
+
+@overload
+def decode_all(
+    data: _ReadableBuffer, codec_options: CodecOptions[_DocumentType]
+) -> list[_DocumentType]:
+    ...
+
+
+def decode_all(
+    data: _ReadableBuffer, codec_options: Optional[CodecOptions[_DocumentType]] = None
+) -> Union[list[dict[str, Any]], list[_DocumentType]]:
+    """Decode BSON data to multiple documents.
+
+    `data` must be a bytes-like object implementing the buffer protocol that
+    provides concatenated, valid, BSON-encoded documents.
+
+    :param data: BSON data
+    :param codec_options: An instance of
+        :class:`~bson.codec_options.CodecOptions`.
+
+    .. versionchanged:: 3.9
+       Supports bytes-like objects that implement the buffer protocol.
+
+    .. versionchanged:: 3.0
+       Removed `compile_re` option: PyMongo now always represents BSON regular
+       expressions as :class:`~bson.regex.Regex` objects. Use
+       :meth:`~bson.regex.Regex.try_compile` to attempt to convert from a
+       BSON regular expression to a Python regular expression object.
+
+       Replaced `as_class`, `tz_aware`, and `uuid_subtype` options with
+       `codec_options`.
+    """
+    if codec_options is None:
+        return _decode_all(data, DEFAULT_CODEC_OPTIONS)
+
+    if not isinstance(codec_options, CodecOptions):
+        raise _CODEC_OPTIONS_TYPE_ERROR
+
+    return _decode_all(data, codec_options)
+
+
+def _decode_selective(
+    rawdoc: Any, fields: Any, codec_options: CodecOptions[_DocumentType]
+) -> _DocumentType:
+    if _raw_document_class(codec_options.document_class):
+        # If document_class is RawBSONDocument, use vanilla dictionary for
+        # decoding command response.
+ doc: _DocumentType = {} # type:ignore[assignment] + else: + # Else, use the specified document_class. + doc = codec_options.document_class() + for key, value in rawdoc.items(): + if key in fields: + if fields[key] == 1: + doc[key] = _bson_to_dict(rawdoc.raw, codec_options)[key] # type:ignore[index] + else: + doc[key] = _decode_selective( # type:ignore[index] + value, fields[key], codec_options + ) + else: + doc[key] = value # type:ignore[index] + return doc + + +def _array_of_documents_to_buffer(view: memoryview) -> bytes: + # Extract the raw bytes of each document. + position = 0 + _, end = _get_object_size(view, position, len(view)) + position += 4 + buffers: list[memoryview] = [] + append = buffers.append + while position < end - 1: + # Just skip the keys. + while view[position] != 0: + position += 1 + position += 1 + obj_size, _ = _get_object_size(view, position, end) + append(view[position : position + obj_size]) + position += obj_size + if position != end: + raise InvalidBSON("bad object or element length") + return b"".join(buffers) + + +if _USE_C: + _array_of_documents_to_buffer = _cbson._array_of_documents_to_buffer + + +def _convert_raw_document_lists_to_streams(document: Any) -> None: + """Convert raw array of documents to a stream of BSON documents.""" + cursor = document.get("cursor") + if not cursor: + return + for key in ("firstBatch", "nextBatch"): + batch = cursor.get(key) + if not batch: + continue + data = _array_of_documents_to_buffer(batch) + if data: + cursor[key] = [data] + else: + cursor[key] = [] + + +def _decode_all_selective( + data: Any, codec_options: CodecOptions[_DocumentType], fields: Any +) -> list[_DocumentType]: + """Decode BSON data to a single document while using user-provided + custom decoding logic. + + `data` must be a string representing a valid, BSON-encoded document. + + :param data: BSON data + :param codec_options: An instance of + :class:`~bson.codec_options.CodecOptions` with user-specified type + decoders. If no decoders are found, this method is the same as + ``decode_all``. + :param fields: Map of document namespaces where data that needs + to be custom decoded lives or None. For example, to custom decode a + list of objects in 'field1.subfield1', the specified value should be + ``{'field1': {'subfield1': 1}}``. If ``fields`` is an empty map or + None, this method is the same as ``decode_all``. + + :return: Single-member list containing the decoded document. + + .. versionadded:: 3.8 + """ + if not codec_options.type_registry._decoder_map: + return decode_all(data, codec_options) + + if not fields: + return decode_all(data, codec_options.with_options(type_registry=None)) + + # Decode documents for internal use. + from bson.raw_bson import RawBSONDocument + + internal_codec_options: CodecOptions[RawBSONDocument] = codec_options.with_options( + document_class=RawBSONDocument, type_registry=None + ) + _doc = _bson_to_dict(data, internal_codec_options) + return [ + _decode_selective( + _doc, + fields, + codec_options, + ) + ] + + +@overload +def decode_iter(data: bytes, codec_options: None = None) -> Iterator[dict[str, Any]]: + ... + + +@overload +def decode_iter(data: bytes, codec_options: CodecOptions[_DocumentType]) -> Iterator[_DocumentType]: + ... + + +def decode_iter( + data: bytes, codec_options: Optional[CodecOptions[_DocumentType]] = None +) -> Union[Iterator[dict[str, Any]], Iterator[_DocumentType]]: + """Decode BSON data to multiple documents as a generator. 
+
+    Works similarly to the decode_all function, but yields one document at a
+    time.
+
+    `data` must be a bytes-like object of concatenated, valid, BSON-encoded
+    documents.
+
+    :param data: BSON data
+    :param codec_options: An instance of
+        :class:`~bson.codec_options.CodecOptions`.
+
+    .. versionchanged:: 3.0
+       Replaced `as_class`, `tz_aware`, and `uuid_subtype` options with
+       `codec_options`.
+
+    .. versionadded:: 2.8
+    """
+    opts = codec_options or DEFAULT_CODEC_OPTIONS
+    if not isinstance(opts, CodecOptions):
+        raise _CODEC_OPTIONS_TYPE_ERROR
+
+    position = 0
+    end = len(data) - 1
+    while position < end:
+        obj_size = _UNPACK_INT_FROM(data, position)[0]
+        elements = data[position : position + obj_size]
+        position += obj_size
+
+        yield _bson_to_dict(elements, opts)  # type:ignore[misc, type-var]
+
+
+@overload
+def decode_file_iter(
+    file_obj: Union[BinaryIO, IO[bytes]], codec_options: None = None
+) -> Iterator[dict[str, Any]]:
+    ...
+
+
+@overload
+def decode_file_iter(
+    file_obj: Union[BinaryIO, IO[bytes]], codec_options: CodecOptions[_DocumentType]
+) -> Iterator[_DocumentType]:
+    ...
+
+
+def decode_file_iter(
+    file_obj: Union[BinaryIO, IO[bytes]],
+    codec_options: Optional[CodecOptions[_DocumentType]] = None,
+) -> Union[Iterator[dict[str, Any]], Iterator[_DocumentType]]:
+    """Decode BSON data from a file to multiple documents as a generator.
+
+    Works similarly to the decode_all function, but reads from the file object
+    in chunks and parses BSON in chunks, yielding one document at a time.
+
+    :param file_obj: A file object containing BSON data.
+    :param codec_options: An instance of
+        :class:`~bson.codec_options.CodecOptions`.
+
+    .. versionchanged:: 3.0
+       Replaced `as_class`, `tz_aware`, and `uuid_subtype` options with
+       `codec_options`.
+
+    .. versionadded:: 2.8
+    """
+    opts = codec_options or DEFAULT_CODEC_OPTIONS
+    while True:
+        # Read size of next object.
+        size_data: Any = file_obj.read(4)
+        if not size_data:
+            break  # Finished with file normally.
+        elif len(size_data) != 4:
+            raise InvalidBSON("cut off in middle of objsize")
+        obj_size = _UNPACK_INT_FROM(size_data, 0)[0] - 4
+        elements = size_data + file_obj.read(max(0, obj_size))
+        yield _bson_to_dict(elements, opts)  # type:ignore[type-var, arg-type, misc]
+
+
+def is_valid(bson: bytes) -> bool:
+    """Check that the given bytes represent valid :class:`BSON` data.
+
+    Raises :class:`TypeError` if `bson` is not an instance of
+    :class:`bytes`. Returns ``True``
+    if `bson` is valid :class:`BSON`, ``False`` otherwise.
+
+    :param bson: the data to be validated
+    """
+    if not isinstance(bson, bytes):
+        raise TypeError("BSON data must be an instance of a subclass of bytes")
+
+    try:
+        _bson_to_dict(bson, DEFAULT_CODEC_OPTIONS)
+        return True
+    except Exception:
+        return False
+
+
+class BSON(bytes):
+    """BSON (Binary JSON) data.
+
+    .. warning:: Using this class to encode and decode BSON adds a performance
+       cost. For better performance use the module level functions
+       :func:`encode` and :func:`decode` instead.
+    """
+
+    @classmethod
+    def encode(
+        cls: Type[BSON],
+        document: Mapping[str, Any],
+        check_keys: bool = False,
+        codec_options: CodecOptions[Any] = DEFAULT_CODEC_OPTIONS,
+    ) -> BSON:
+        """Encode a document to a new :class:`BSON` instance.
+
+        A document can be any mapping type (like :class:`dict`).
+
+        Raises :class:`TypeError` if `document` is not a mapping type,
+        or contains keys that are not instances of
+        :class:`str`. Raises :class:`~bson.errors.InvalidDocument`
+        if `document` cannot be converted to :class:`BSON`.
+
+        :param document: mapping type representing a document
+        :param check_keys: check if keys start with '$' or
+            contain '.', raising :class:`~bson.errors.InvalidDocument` in
+            either case
+        :param codec_options: An instance of
+            :class:`~bson.codec_options.CodecOptions`.
+
+        .. versionchanged:: 3.0
+           Replaced `uuid_subtype` option with `codec_options`.
+        """
+        return cls(encode(document, check_keys, codec_options))
+
+    def decode(  # type:ignore[override]
+        self, codec_options: CodecOptions[Any] = DEFAULT_CODEC_OPTIONS
+    ) -> dict[str, Any]:
+        """Decode this BSON data.
+
+        By default, returns a BSON document represented as a Python
+        :class:`dict`. To use a different :class:`MutableMapping` class,
+        configure a :class:`~bson.codec_options.CodecOptions`::
+
+            >>> import collections  # From Python standard library.
+            >>> import bson
+            >>> from bson.codec_options import CodecOptions
+            >>> data = bson.BSON.encode({'a': 1})
+            >>> decoded_doc = bson.BSON(data).decode()
+
+            >>> options = CodecOptions(document_class=collections.OrderedDict)
+            >>> decoded_doc = bson.BSON(data).decode(codec_options=options)
+            >>> type(decoded_doc)
+            <class 'collections.OrderedDict'>
+
+        :param codec_options: An instance of
+            :class:`~bson.codec_options.CodecOptions`.
+
+        .. versionchanged:: 3.0
+           Removed `compile_re` option: PyMongo now always represents BSON
+           regular expressions as :class:`~bson.regex.Regex` objects. Use
+           :meth:`~bson.regex.Regex.try_compile` to attempt to convert from a
+           BSON regular expression to a Python regular expression object.
+
+           Replaced `as_class`, `tz_aware`, and `uuid_subtype` options with
+           `codec_options`.
+        """
+        return decode(self, codec_options)
+
+
+def has_c() -> bool:
+    """Is the C extension installed?"""
+    return _USE_C
+
+
+def _after_fork() -> None:
+    """Release the ObjectId lock in the child process after a fork."""
+    if ObjectId._inc_lock.locked():
+        ObjectId._inc_lock.release()
+
+
+if hasattr(os, "register_at_fork"):
+    # This will run in the same thread that called fork.
+    # If we fork in a critical region on the same thread, it should break.
+    # This is fine since we would never call fork directly from a critical region.
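+    # (Without this hook, a fork that raced another thread holding
+    # ObjectId._inc_lock would leave the lock permanently held in the child.)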
+ os.register_at_fork(after_in_child=_after_fork) diff --git a/venv/Lib/site-packages/bson/__pycache__/__init__.cpython-312.pyc b/venv/Lib/site-packages/bson/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 00000000..fc6036b6 Binary files /dev/null and b/venv/Lib/site-packages/bson/__pycache__/__init__.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/bson/__pycache__/_helpers.cpython-312.pyc b/venv/Lib/site-packages/bson/__pycache__/_helpers.cpython-312.pyc new file mode 100644 index 00000000..fdd80c66 Binary files /dev/null and b/venv/Lib/site-packages/bson/__pycache__/_helpers.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/bson/__pycache__/binary.cpython-312.pyc b/venv/Lib/site-packages/bson/__pycache__/binary.cpython-312.pyc new file mode 100644 index 00000000..98820ee5 Binary files /dev/null and b/venv/Lib/site-packages/bson/__pycache__/binary.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/bson/__pycache__/code.cpython-312.pyc b/venv/Lib/site-packages/bson/__pycache__/code.cpython-312.pyc new file mode 100644 index 00000000..d2b1ddc4 Binary files /dev/null and b/venv/Lib/site-packages/bson/__pycache__/code.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/bson/__pycache__/codec_options.cpython-312.pyc b/venv/Lib/site-packages/bson/__pycache__/codec_options.cpython-312.pyc new file mode 100644 index 00000000..f0a4340a Binary files /dev/null and b/venv/Lib/site-packages/bson/__pycache__/codec_options.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/bson/__pycache__/datetime_ms.cpython-312.pyc b/venv/Lib/site-packages/bson/__pycache__/datetime_ms.cpython-312.pyc new file mode 100644 index 00000000..3b429255 Binary files /dev/null and b/venv/Lib/site-packages/bson/__pycache__/datetime_ms.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/bson/__pycache__/dbref.cpython-312.pyc b/venv/Lib/site-packages/bson/__pycache__/dbref.cpython-312.pyc new file mode 100644 index 00000000..63173d58 Binary files /dev/null and b/venv/Lib/site-packages/bson/__pycache__/dbref.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/bson/__pycache__/decimal128.cpython-312.pyc b/venv/Lib/site-packages/bson/__pycache__/decimal128.cpython-312.pyc new file mode 100644 index 00000000..05110db6 Binary files /dev/null and b/venv/Lib/site-packages/bson/__pycache__/decimal128.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/bson/__pycache__/errors.cpython-312.pyc b/venv/Lib/site-packages/bson/__pycache__/errors.cpython-312.pyc new file mode 100644 index 00000000..e0807673 Binary files /dev/null and b/venv/Lib/site-packages/bson/__pycache__/errors.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/bson/__pycache__/int64.cpython-312.pyc b/venv/Lib/site-packages/bson/__pycache__/int64.cpython-312.pyc new file mode 100644 index 00000000..bdd70821 Binary files /dev/null and b/venv/Lib/site-packages/bson/__pycache__/int64.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/bson/__pycache__/json_util.cpython-312.pyc b/venv/Lib/site-packages/bson/__pycache__/json_util.cpython-312.pyc new file mode 100644 index 00000000..4129fe82 Binary files /dev/null and b/venv/Lib/site-packages/bson/__pycache__/json_util.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/bson/__pycache__/max_key.cpython-312.pyc b/venv/Lib/site-packages/bson/__pycache__/max_key.cpython-312.pyc new file mode 100644 index 00000000..9dbe2c7c Binary files /dev/null and b/venv/Lib/site-packages/bson/__pycache__/max_key.cpython-312.pyc differ diff --git 
a/venv/Lib/site-packages/bson/__pycache__/min_key.cpython-312.pyc b/venv/Lib/site-packages/bson/__pycache__/min_key.cpython-312.pyc new file mode 100644 index 00000000..9fddc91c Binary files /dev/null and b/venv/Lib/site-packages/bson/__pycache__/min_key.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/bson/__pycache__/objectid.cpython-312.pyc b/venv/Lib/site-packages/bson/__pycache__/objectid.cpython-312.pyc new file mode 100644 index 00000000..817704d8 Binary files /dev/null and b/venv/Lib/site-packages/bson/__pycache__/objectid.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/bson/__pycache__/raw_bson.cpython-312.pyc b/venv/Lib/site-packages/bson/__pycache__/raw_bson.cpython-312.pyc new file mode 100644 index 00000000..fd4f7e5c Binary files /dev/null and b/venv/Lib/site-packages/bson/__pycache__/raw_bson.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/bson/__pycache__/regex.cpython-312.pyc b/venv/Lib/site-packages/bson/__pycache__/regex.cpython-312.pyc new file mode 100644 index 00000000..311953f3 Binary files /dev/null and b/venv/Lib/site-packages/bson/__pycache__/regex.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/bson/__pycache__/son.cpython-312.pyc b/venv/Lib/site-packages/bson/__pycache__/son.cpython-312.pyc new file mode 100644 index 00000000..33dfe8de Binary files /dev/null and b/venv/Lib/site-packages/bson/__pycache__/son.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/bson/__pycache__/timestamp.cpython-312.pyc b/venv/Lib/site-packages/bson/__pycache__/timestamp.cpython-312.pyc new file mode 100644 index 00000000..59276cf0 Binary files /dev/null and b/venv/Lib/site-packages/bson/__pycache__/timestamp.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/bson/__pycache__/typings.cpython-312.pyc b/venv/Lib/site-packages/bson/__pycache__/typings.cpython-312.pyc new file mode 100644 index 00000000..18e39d7b Binary files /dev/null and b/venv/Lib/site-packages/bson/__pycache__/typings.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/bson/__pycache__/tz_util.cpython-312.pyc b/venv/Lib/site-packages/bson/__pycache__/tz_util.cpython-312.pyc new file mode 100644 index 00000000..6fb9eed6 Binary files /dev/null and b/venv/Lib/site-packages/bson/__pycache__/tz_util.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/bson/_cbson.cp312-win_amd64.pyd b/venv/Lib/site-packages/bson/_cbson.cp312-win_amd64.pyd new file mode 100644 index 00000000..48462a32 Binary files /dev/null and b/venv/Lib/site-packages/bson/_cbson.cp312-win_amd64.pyd differ diff --git a/venv/Lib/site-packages/bson/_cbsonmodule.c b/venv/Lib/site-packages/bson/_cbsonmodule.c new file mode 100644 index 00000000..3b3aecc4 --- /dev/null +++ b/venv/Lib/site-packages/bson/_cbsonmodule.c @@ -0,0 +1,3164 @@ +/* + * Copyright 2009-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * This file contains C implementations of some of the functions + * needed by the bson module. 
If possible, these implementations + * should be used to speed up BSON encoding and decoding. + */ + +#define PY_SSIZE_T_CLEAN +#include "Python.h" +#include "datetime.h" + +#include "buffer.h" +#include "time64.h" + +#define _CBSON_MODULE +#include "_cbsonmodule.h" + +/* New module state and initialization code. + * See the module-initialization-and-state + * section in the following doc: + * http://docs.python.org/release/3.1.3/howto/cporting.html + * which references the following pep: + * http://www.python.org/dev/peps/pep-3121/ + * */ +struct module_state { + PyObject* Binary; + PyObject* Code; + PyObject* ObjectId; + PyObject* DBRef; + PyObject* Regex; + PyObject* UUID; + PyObject* Timestamp; + PyObject* MinKey; + PyObject* MaxKey; + PyObject* UTC; + PyTypeObject* REType; + PyObject* BSONInt64; + PyObject* Decimal128; + PyObject* Mapping; + PyObject* DatetimeMS; + PyObject* _min_datetime_ms; + PyObject* _max_datetime_ms; + PyObject* _type_marker_str; + PyObject* _flags_str; + PyObject* _pattern_str; + PyObject* _encoder_map_str; + PyObject* _decoder_map_str; + PyObject* _fallback_encoder_str; + PyObject* _raw_str; + PyObject* _subtype_str; + PyObject* _binary_str; + PyObject* _scope_str; + PyObject* _inc_str; + PyObject* _time_str; + PyObject* _bid_str; + PyObject* _replace_str; + PyObject* _astimezone_str; + PyObject* _id_str; + PyObject* _dollar_ref_str; + PyObject* _dollar_id_str; + PyObject* _dollar_db_str; + PyObject* _tzinfo_str; + PyObject* _as_doc_str; + PyObject* _utcoffset_str; + PyObject* _from_uuid_str; + PyObject* _as_uuid_str; + PyObject* _from_bid_str; +}; + +#define GETSTATE(m) ((struct module_state*)PyModule_GetState(m)) + +/* Maximum number of regex flags */ +#define FLAGS_SIZE 7 + +/* Default UUID representation type code. */ +#define PYTHON_LEGACY 3 + +/* Other UUID representations. */ +#define STANDARD 4 +#define JAVA_LEGACY 5 +#define CSHARP_LEGACY 6 +#define UNSPECIFIED 0 + +#define BSON_MAX_SIZE 2147483647 +/* The smallest possible BSON document, i.e. "{}" */ +#define BSON_MIN_SIZE 5 + +/* Datetime codec options */ +#define DATETIME 1 +#define DATETIME_CLAMP 2 +#define DATETIME_MS 3 +#define DATETIME_AUTO 4 + +/* Converts integer to its string representation in decimal notation. 
*/ +extern int cbson_long_long_to_str(long long num, char* str, size_t size) { + // Buffer should fit 64-bit signed integer + if (size < 21) { + PyErr_Format( + PyExc_RuntimeError, + "Buffer too small to hold long long: %d < 21", size); + return -1; + } + int index = 0; + int sign = 1; + // Convert to unsigned to handle -LLONG_MIN overflow + unsigned long long absNum; + // Handle the case of 0 + if (num == 0) { + str[index++] = '0'; + str[index] = '\0'; + return 0; + } + // Handle negative numbers + if (num < 0) { + sign = -1; + absNum = 0ULL - (unsigned long long)num; + } else { + absNum = (unsigned long long)num; + } + // Convert the number to string + unsigned long long digit; + while (absNum > 0) { + digit = absNum % 10ULL; + str[index++] = (char)digit + '0'; // Convert digit to character + absNum /= 10; + } + // Add minus sign if negative + if (sign == -1) { + str[index++] = '-'; + } + str[index] = '\0'; // Null terminator + // Reverse the string + int start = 0; + int end = index - 1; + while (start < end) { + char temp = str[start]; + str[start++] = str[end]; + str[end--] = temp; + } + return 0; +} + +static PyObject* _test_long_long_to_str(PyObject* self, PyObject* args) { + // Test extreme values + Py_ssize_t maxNum = PY_SSIZE_T_MAX; + Py_ssize_t minNum = PY_SSIZE_T_MIN; + Py_ssize_t num; + char str_1[BUF_SIZE]; + char str_2[BUF_SIZE]; + int res = LL2STR(str_1, (long long)minNum); + if (res == -1) { + return NULL; + } + INT2STRING(str_2, (long long)minNum); + if (strcmp(str_1, str_2) != 0) { + PyErr_Format( + PyExc_RuntimeError, + "LL2STR != INT2STRING: %s != %s", str_1, str_2); + return NULL; + } + LL2STR(str_1, (long long)maxNum); + INT2STRING(str_2, (long long)maxNum); + if (strcmp(str_1, str_2) != 0) { + PyErr_Format( + PyExc_RuntimeError, + "LL2STR != INT2STRING: %s != %s", str_1, str_2); + return NULL; + } + + // Test common values + for (num = 0; num < 10000; num++) { + char str_1[BUF_SIZE]; + char str_2[BUF_SIZE]; + LL2STR(str_1, (long long)num); + INT2STRING(str_2, (long long)num); + if (strcmp(str_1, str_2) != 0) { + PyErr_Format( + PyExc_RuntimeError, + "LL2STR != INT2STRING: %s != %s", str_1, str_2); + return NULL; + } + } + + return args; +} + +/* Get an error class from the bson.errors module. + * + * Returns a new ref */ +static PyObject* _error(char* name) { + PyObject* error; + PyObject* errors = PyImport_ImportModule("bson.errors"); + if (!errors) { + return NULL; + } + error = PyObject_GetAttrString(errors, name); + Py_DECREF(errors); + return error; +} + +/* Safely downcast from Py_ssize_t to int, setting an + * exception and returning -1 on error. */ +static int +_downcast_and_check(Py_ssize_t size, uint8_t extra) { + if (size > BSON_MAX_SIZE || ((BSON_MAX_SIZE - extra) < size)) { + PyObject* InvalidStringData = _error("InvalidStringData"); + if (InvalidStringData) { + PyErr_SetString(InvalidStringData, + "String length must be <= 2147483647"); + Py_DECREF(InvalidStringData); + } + return -1; + } + return (int)size + extra; +} + +static PyObject* elements_to_dict(PyObject* self, const char* string, + unsigned max, + const codec_options_t* options); + +static int _write_element_to_buffer(PyObject* self, buffer_t buffer, + int type_byte, PyObject* value, + unsigned char check_keys, + const codec_options_t* options, + unsigned char in_custom_call, + unsigned char in_fallback_call); + +/* Write a RawBSONDocument to the buffer. + * Returns the number of bytes written or 0 on failure. 
+ */
+static int write_raw_doc(buffer_t buffer, PyObject* raw, PyObject* _raw);
+
+/* Date stuff */
+static PyObject* datetime_from_millis(long long millis) {
+    /* To encode a datetime instance like datetime(9999, 12, 31, 23, 59, 59, 999999)
+     * we follow these steps:
+     * 1. Calculate a timestamp in seconds: 253402300799
+     * 2. Multiply that by 1000: 253402300799000
+     * 3. Add in microseconds divided by 1000: 253402300799999
+     *
+     * (Note: BSON doesn't support microsecond accuracy, hence the truncation.)
+     *
+     * To decode we could do:
+     * 1. Get seconds: timestamp / 1000: 253402300799
+     * 2. Get micros: (timestamp % 1000) * 1000: 999000
+     * Resulting in datetime(9999, 12, 31, 23, 59, 59, 999000) -- the expected result
+     *
+     * Now what if we encode (1, 1, 1, 1, 1, 1, 111111)?
+     * 1. and 2. give: -62135593139000
+     * 3. Gives us: -62135593138889
+     *
+     * Now decode:
+     * 1. Gives us: -62135593138
+     * 2. Gives us: -889000
+     * Resulting in datetime(1, 1, 1, 1, 1, 2, -889000) -- an invalid result,
+     * because C integer division truncates toward zero for negative millis.
+     *
+     * If instead to decode we do:
+     * diff = ((millis % 1000) + 1000) % 1000: 111
+     * seconds = (millis - diff) / 1000: -62135593139
+     * micros = diff * 1000: 111000
+     * Resulting in datetime(1, 1, 1, 1, 1, 1, 111000) -- the expected result
+     */
+    PyObject* datetime;
+    int diff = (int)(((millis % 1000) + 1000) % 1000);
+    int microseconds = diff * 1000;
+    Time64_T seconds = (millis - diff) / 1000;
+    struct TM timeinfo;
+    cbson_gmtime64_r(&seconds, &timeinfo);
+
+    datetime = PyDateTime_FromDateAndTime(timeinfo.tm_year + 1900,
+                                          timeinfo.tm_mon + 1,
+                                          timeinfo.tm_mday,
+                                          timeinfo.tm_hour,
+                                          timeinfo.tm_min,
+                                          timeinfo.tm_sec,
+                                          microseconds);
+    if(!datetime) {
+        PyObject *etype, *evalue, *etrace;
+
+        /*
+         * Calling _error clears the error state, so fetch it first.
+         */
+        PyErr_Fetch(&etype, &evalue, &etrace);
+
+        /* Only add the additional error message on ValueError exceptions. */
+        if (PyErr_GivenExceptionMatches(etype, PyExc_ValueError)) {
+            if (evalue) {
+                PyObject* err_msg = PyObject_Str(evalue);
+                if (err_msg) {
+                    PyObject* appendage = PyUnicode_FromString(" (Consider using CodecOptions(datetime_conversion=DATETIME_AUTO) or MongoClient(datetime_conversion='DATETIME_AUTO')). See: https://pymongo.readthedocs.io/en/stable/examples/datetimes.html#handling-out-of-range-datetimes");
+                    if (appendage) {
+                        PyObject* msg = PyUnicode_Concat(err_msg, appendage);
+                        if (msg) {
+                            Py_DECREF(evalue);
+                            evalue = msg;
+                        }
+                    }
+                    Py_XDECREF(appendage);
+                }
+                Py_XDECREF(err_msg);
+            }
+            PyErr_NormalizeException(&etype, &evalue, &etrace);
+        }
+        /* Steals references to args. */
+        PyErr_Restore(etype, evalue, etrace);
+    }
+    return datetime;
+}
+
+static long long millis_from_datetime(PyObject* datetime) {
+    struct TM timeinfo;
+    long long millis;
+
+    timeinfo.tm_year = PyDateTime_GET_YEAR(datetime) - 1900;
+    timeinfo.tm_mon = PyDateTime_GET_MONTH(datetime) - 1;
+    timeinfo.tm_mday = PyDateTime_GET_DAY(datetime);
+    timeinfo.tm_hour = PyDateTime_DATE_GET_HOUR(datetime);
+    timeinfo.tm_min = PyDateTime_DATE_GET_MINUTE(datetime);
+    timeinfo.tm_sec = PyDateTime_DATE_GET_SECOND(datetime);
+
+    millis = cbson_timegm64(&timeinfo) * 1000;
+    millis += PyDateTime_DATE_GET_MICROSECOND(datetime) / 1000;
+    return millis;
+}
+
+/* Extended-range datetime, returns a DatetimeMS object with millis */
+static PyObject* datetime_ms_from_millis(PyObject* self, long long millis){
+    // Allocate a new DatetimeMS object.
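+    // (For orientation: DatetimeMS on the Python side is an int subclass
+    //  holding milliseconds since the Unix epoch, so this call behaves
+    //  roughly like bson.datetime_ms.DatetimeMS(millis).)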
+ struct module_state *state = GETSTATE(self); + if (!state) { + return NULL; + } + + PyObject* dt; + PyObject* ll_millis; + + if (!(ll_millis = PyLong_FromLongLong(millis))){ + return NULL; + } + dt = PyObject_CallFunctionObjArgs(state->DatetimeMS, ll_millis, NULL); + Py_DECREF(ll_millis); + return dt; +} + +/* Extended-range datetime, takes a DatetimeMS object and extracts the long long value. */ +static int millis_from_datetime_ms(PyObject* dt, long long* out){ + PyObject* ll_millis; + long long millis; + + if (!(ll_millis = PyNumber_Long(dt))){ + return 0; + } + millis = PyLong_AsLongLong(ll_millis); + Py_DECREF(ll_millis); + if (millis == -1 && PyErr_Occurred()) { /* Overflow */ + PyErr_SetString(PyExc_OverflowError, + "MongoDB datetimes can only handle up to 8-byte ints"); + return 0; + } + *out = millis; + return 1; +} + +/* Just make this compatible w/ the old API. */ +int buffer_write_bytes(buffer_t buffer, const char* data, int size) { + if (pymongo_buffer_write(buffer, data, size)) { + return 0; + } + return 1; +} + +int buffer_write_double(buffer_t buffer, double data) { + double data_le = BSON_DOUBLE_TO_LE(data); + return buffer_write_bytes(buffer, (const char*)&data_le, 8); +} + +int buffer_write_int32(buffer_t buffer, int32_t data) { + uint32_t data_le = BSON_UINT32_TO_LE(data); + return buffer_write_bytes(buffer, (const char*)&data_le, 4); +} + +int buffer_write_int64(buffer_t buffer, int64_t data) { + uint64_t data_le = BSON_UINT64_TO_LE(data); + return buffer_write_bytes(buffer, (const char*)&data_le, 8); +} + +void buffer_write_int32_at_position(buffer_t buffer, + int position, + int32_t data) { + uint32_t data_le = BSON_UINT32_TO_LE(data); + memcpy(pymongo_buffer_get_buffer(buffer) + position, &data_le, 4); +} + +static int write_unicode(buffer_t buffer, PyObject* py_string) { + int size; + const char* data; + PyObject* encoded = PyUnicode_AsUTF8String(py_string); + if (!encoded) { + return 0; + } + data = PyBytes_AS_STRING(encoded); + if (!data) + goto unicodefail; + + if ((size = _downcast_and_check(PyBytes_GET_SIZE(encoded), 1)) == -1) + goto unicodefail; + + if (!buffer_write_int32(buffer, (int32_t)size)) + goto unicodefail; + + if (!buffer_write_bytes(buffer, data, size)) + goto unicodefail; + + Py_DECREF(encoded); + return 1; + +unicodefail: + Py_DECREF(encoded); + return 0; +} + +/* returns 0 on failure */ +static int write_string(buffer_t buffer, PyObject* py_string) { + int size; + const char* data; + if (PyUnicode_Check(py_string)){ + return write_unicode(buffer, py_string); + } + data = PyBytes_AsString(py_string); + if (!data) { + return 0; + } + + if ((size = _downcast_and_check(PyBytes_Size(py_string), 1)) == -1) + return 0; + + if (!buffer_write_int32(buffer, (int32_t)size)) { + return 0; + } + if (!buffer_write_bytes(buffer, data, size)) { + return 0; + } + return 1; +} + +/* Load a Python object to cache. + * + * Returns non-zero on failure. */ +static int _load_object(PyObject** object, char* module_name, char* object_name) { + PyObject* module; + + module = PyImport_ImportModule(module_name); + if (!module) { + return 1; + } + + *object = PyObject_GetAttrString(module, object_name); + Py_DECREF(module); + + return (*object) ? 0 : 2; +} + +/* Load all Python objects to cache. + * + * Returns non-zero on failure. 
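+ *
+ * (Each _load_object call below is the C equivalent of a Python
+ * "from <module> import <name>", with the resulting reference cached on
+ * the per-module state to avoid repeated imports and attribute lookups.)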
*/ +static int _load_python_objects(PyObject* module) { + PyObject* empty_string = NULL; + PyObject* re_compile = NULL; + PyObject* compiled = NULL; + struct module_state *state = GETSTATE(module); + if (!state) { + return 1; + } + + /* Cache commonly used attribute names to improve performance. */ + if (!((state->_type_marker_str = PyUnicode_FromString("_type_marker")) && + (state->_flags_str = PyUnicode_FromString("flags")) && + (state->_pattern_str = PyUnicode_FromString("pattern")) && + (state->_encoder_map_str = PyUnicode_FromString("_encoder_map")) && + (state->_decoder_map_str = PyUnicode_FromString("_decoder_map")) && + (state->_fallback_encoder_str = PyUnicode_FromString("_fallback_encoder")) && + (state->_raw_str = PyUnicode_FromString("raw")) && + (state->_subtype_str = PyUnicode_FromString("subtype")) && + (state->_binary_str = PyUnicode_FromString("binary")) && + (state->_scope_str = PyUnicode_FromString("scope")) && + (state->_inc_str = PyUnicode_FromString("inc")) && + (state->_time_str = PyUnicode_FromString("time")) && + (state->_bid_str = PyUnicode_FromString("bid")) && + (state->_replace_str = PyUnicode_FromString("replace")) && + (state->_astimezone_str = PyUnicode_FromString("astimezone")) && + (state->_id_str = PyUnicode_FromString("_id")) && + (state->_dollar_ref_str = PyUnicode_FromString("$ref")) && + (state->_dollar_id_str = PyUnicode_FromString("$id")) && + (state->_dollar_db_str = PyUnicode_FromString("$db")) && + (state->_tzinfo_str = PyUnicode_FromString("tzinfo")) && + (state->_as_doc_str = PyUnicode_FromString("as_doc")) && + (state->_utcoffset_str = PyUnicode_FromString("utcoffset")) && + (state->_from_uuid_str = PyUnicode_FromString("from_uuid")) && + (state->_as_uuid_str = PyUnicode_FromString("as_uuid")) && + (state->_from_bid_str = PyUnicode_FromString("from_bid")))) { + return 1; + } + + if (_load_object(&state->Binary, "bson.binary", "Binary") || + _load_object(&state->Code, "bson.code", "Code") || + _load_object(&state->ObjectId, "bson.objectid", "ObjectId") || + _load_object(&state->DBRef, "bson.dbref", "DBRef") || + _load_object(&state->Timestamp, "bson.timestamp", "Timestamp") || + _load_object(&state->MinKey, "bson.min_key", "MinKey") || + _load_object(&state->MaxKey, "bson.max_key", "MaxKey") || + _load_object(&state->UTC, "bson.tz_util", "utc") || + _load_object(&state->Regex, "bson.regex", "Regex") || + _load_object(&state->BSONInt64, "bson.int64", "Int64") || + _load_object(&state->Decimal128, "bson.decimal128", "Decimal128") || + _load_object(&state->UUID, "uuid", "UUID") || + _load_object(&state->Mapping, "collections.abc", "Mapping") || + _load_object(&state->DatetimeMS, "bson.datetime_ms", "DatetimeMS") || + _load_object(&state->_min_datetime_ms, "bson.datetime_ms", "_min_datetime_ms") || + _load_object(&state->_max_datetime_ms, "bson.datetime_ms", "_max_datetime_ms")) { + return 1; + } + /* Reload our REType hack too. */ + empty_string = PyBytes_FromString(""); + if (empty_string == NULL) { + state->REType = NULL; + return 1; + } + + if (_load_object(&re_compile, "re", "compile")) { + state->REType = NULL; + Py_DECREF(empty_string); + return 1; + } + + compiled = PyObject_CallFunction(re_compile, "O", empty_string); + Py_DECREF(re_compile); + if (compiled == NULL) { + state->REType = NULL; + Py_DECREF(empty_string); + return 1; + } + Py_INCREF(Py_TYPE(compiled)); + state->REType = Py_TYPE(compiled); + Py_DECREF(empty_string); + Py_DECREF(compiled); + return 0; +} + +/* + * Get the _type_marker from an Object. 
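+ * (Markers recognized by the encoder below: 5 Binary, 7 ObjectId,
+ * 9 DatetimeMS, 11 Regex, 13 Code, 17 Timestamp, 18 Int64, 19 Decimal128,
+ * 100 DBRef, 101 RawBSONDocument, 127 MaxKey and 255 MinKey.)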
+ *
+ * Return the type marker, 0 if there is no marker, or -1 on failure.
+ */
+static long _type_marker(PyObject* object, PyObject* _type_marker_str) {
+    PyObject* type_marker = NULL;
+    long type = 0;
+
+    if (PyObject_HasAttr(object, _type_marker_str)) {
+        type_marker = PyObject_GetAttr(object, _type_marker_str);
+        if (type_marker == NULL) {
+            return -1;
+        }
+    }
+
+    /*
+     * Python objects with broken __getattr__ implementations could return
+     * arbitrary types for a call to PyObject_GetAttr. For example
+     * pymongo.database.Database returns a new Collection instance for
+     * __getattr__ calls with names that don't match an existing attribute
+     * or method. In some cases "value" could be a subtype of something
+     * we know how to serialize. Make a best effort to encode these types.
+     */
+    if (type_marker && PyLong_CheckExact(type_marker)) {
+        type = PyLong_AsLong(type_marker);
+        Py_DECREF(type_marker);
+    } else {
+        Py_XDECREF(type_marker);
+    }
+
+    return type;
+}
+
+/* Fill out a type_registry_t* from a TypeRegistry object.
+ *
+ * Return 1 on success. The registry's encoder_map, decoder_map,
+ * fallback_encoder and registry_obj fields are new references.
+ * Return 0 on failure.
+ */
+int cbson_convert_type_registry(PyObject* registry_obj, type_registry_t* registry, PyObject* _encoder_map_str, PyObject* _decoder_map_str, PyObject* _fallback_encoder_str) {
+    registry->encoder_map = NULL;
+    registry->decoder_map = NULL;
+    registry->fallback_encoder = NULL;
+    registry->registry_obj = NULL;
+
+    registry->encoder_map = PyObject_GetAttr(registry_obj, _encoder_map_str);
+    if (registry->encoder_map == NULL) {
+        goto fail;
+    }
+    registry->is_encoder_empty = (PyDict_Size(registry->encoder_map) == 0);
+
+    registry->decoder_map = PyObject_GetAttr(registry_obj, _decoder_map_str);
+    if (registry->decoder_map == NULL) {
+        goto fail;
+    }
+    registry->is_decoder_empty = (PyDict_Size(registry->decoder_map) == 0);
+
+    registry->fallback_encoder = PyObject_GetAttr(registry_obj, _fallback_encoder_str);
+    if (registry->fallback_encoder == NULL) {
+        goto fail;
+    }
+    registry->has_fallback_encoder = (registry->fallback_encoder != Py_None);
+
+    registry->registry_obj = registry_obj;
+    Py_INCREF(registry->registry_obj);
+    return 1;
+
+fail:
+    Py_XDECREF(registry->encoder_map);
+    Py_XDECREF(registry->decoder_map);
+    Py_XDECREF(registry->fallback_encoder);
+    return 0;
+}
+
+/* Fill out a codec_options_t* from a CodecOptions object.
+ *
+ * Return 1 on success. options->document_class is a new reference.
+ * Return 0 on failure.
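+ *
+ * (options_obj is unpacked with the format "ObbzOOb", i.e. the 7-tuple
+ * (document_class, tz_aware, uuid_rep, unicode_decode_error_handler,
+ * tzinfo, type_registry, datetime_conversion), in that order.)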
+ */ +int convert_codec_options(PyObject* self, PyObject* options_obj, codec_options_t* options) { + PyObject* type_registry_obj = NULL; + struct module_state *state = GETSTATE(self); + long type_marker; + if (!state) { + return 0; + } + + options->unicode_decode_error_handler = NULL; + + if (!PyArg_ParseTuple(options_obj, "ObbzOOb", + &options->document_class, + &options->tz_aware, + &options->uuid_rep, + &options->unicode_decode_error_handler, + &options->tzinfo, + &type_registry_obj, + &options->datetime_conversion)) { + return 0; + } + + type_marker = _type_marker(options->document_class, + state->_type_marker_str); + if (type_marker < 0) { + return 0; + } + + if (!cbson_convert_type_registry(type_registry_obj, + &options->type_registry, state->_encoder_map_str, state->_decoder_map_str, state->_fallback_encoder_str)) { + return 0; + } + + options->is_raw_bson = (101 == type_marker); + options->options_obj = options_obj; + + Py_INCREF(options->options_obj); + Py_INCREF(options->document_class); + Py_INCREF(options->tzinfo); + + return 1; +} + +void destroy_codec_options(codec_options_t* options) { + Py_CLEAR(options->document_class); + Py_CLEAR(options->tzinfo); + Py_CLEAR(options->options_obj); + Py_CLEAR(options->type_registry.registry_obj); + Py_CLEAR(options->type_registry.encoder_map); + Py_CLEAR(options->type_registry.decoder_map); + Py_CLEAR(options->type_registry.fallback_encoder); +} + +static int write_element_to_buffer(PyObject* self, buffer_t buffer, + int type_byte, PyObject* value, + unsigned char check_keys, + const codec_options_t* options, + unsigned char in_custom_call, + unsigned char in_fallback_call) { + int result = 0; + if(Py_EnterRecursiveCall(" while encoding an object to BSON ")) { + return 0; + } + result = _write_element_to_buffer(self, buffer, type_byte, + value, check_keys, options, + in_custom_call, in_fallback_call); + Py_LeaveRecursiveCall(); + return result; +} + +static void +_set_cannot_encode(PyObject* value) { + if (PyLong_Check(value)) { + if ((PyLong_AsLongLong(value) == -1) && PyErr_Occurred()) { + return PyErr_SetString(PyExc_OverflowError, + "MongoDB can only handle up to 8-byte ints"); + } + } + + PyObject* type = NULL; + PyObject* InvalidDocument = _error("InvalidDocument"); + if (InvalidDocument == NULL) { + goto error; + } + + type = PyObject_Type(value); + if (type == NULL) { + goto error; + } + PyErr_Format(InvalidDocument, "cannot encode object: %R, of type: %R", + value, type); +error: + Py_XDECREF(type); + Py_XDECREF(InvalidDocument); +} + +/* + * Encode a builtin Python regular expression or our custom Regex class. + * + * Sets exception and returns 0 on failure. + */ +static int _write_regex_to_buffer( + buffer_t buffer, int type_byte, PyObject* value, PyObject* _flags_str, PyObject* _pattern_str) { + + PyObject* py_flags; + PyObject* py_pattern; + PyObject* encoded_pattern; + PyObject* decoded_pattern; + long int_flags; + char flags[FLAGS_SIZE]; + char check_utf8 = 0; + const char* pattern_data; + int pattern_length, flags_length; + + /* + * Both the builtin re type and our Regex class have attributes + * "flags" and "pattern". 
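+ *
+ * (The flag bits tested below line up with Python's re module constants:
+ * 2 = re.I -> 'i', 4 = re.L -> 'l', 8 = re.M -> 'm', 16 = re.S -> 's',
+ * 32 = re.U -> 'u', 64 = re.X -> 'x'.)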
+ */ + py_flags = PyObject_GetAttr(value, _flags_str); + if (!py_flags) { + return 0; + } + int_flags = PyLong_AsLong(py_flags); + Py_DECREF(py_flags); + if (int_flags == -1 && PyErr_Occurred()) { + return 0; + } + py_pattern = PyObject_GetAttr(value, _pattern_str); + if (!py_pattern) { + return 0; + } + + if (PyUnicode_Check(py_pattern)) { + encoded_pattern = PyUnicode_AsUTF8String(py_pattern); + Py_DECREF(py_pattern); + if (!encoded_pattern) { + return 0; + } + } else { + encoded_pattern = py_pattern; + check_utf8 = 1; + } + + if (!(pattern_data = PyBytes_AsString(encoded_pattern))) { + Py_DECREF(encoded_pattern); + return 0; + } + if ((pattern_length = _downcast_and_check(PyBytes_Size(encoded_pattern), 0)) == -1) { + Py_DECREF(encoded_pattern); + return 0; + } + + if (strlen(pattern_data) != (size_t) pattern_length){ + PyObject* InvalidDocument = _error("InvalidDocument"); + if (InvalidDocument) { + PyErr_SetString(InvalidDocument, + "regex patterns must not contain the NULL byte"); + Py_DECREF(InvalidDocument); + } + Py_DECREF(encoded_pattern); + return 0; + } + + if (check_utf8) { + decoded_pattern = PyUnicode_DecodeUTF8(pattern_data, (Py_ssize_t) pattern_length, NULL); + if (decoded_pattern == NULL) { + PyErr_Clear(); + PyObject* InvalidStringData = _error("InvalidStringData"); + if (InvalidStringData) { + PyErr_SetString(InvalidStringData, + "regex patterns must be valid UTF-8"); + Py_DECREF(InvalidStringData); + } + Py_DECREF(encoded_pattern); + return 0; + } + Py_DECREF(decoded_pattern); + } + + if (!buffer_write_bytes(buffer, pattern_data, pattern_length + 1)) { + Py_DECREF(encoded_pattern); + return 0; + } + Py_DECREF(encoded_pattern); + + flags[0] = 0; + + if (int_flags & 2) { + STRCAT(flags, FLAGS_SIZE, "i"); + } + if (int_flags & 4) { + STRCAT(flags, FLAGS_SIZE, "l"); + } + if (int_flags & 8) { + STRCAT(flags, FLAGS_SIZE, "m"); + } + if (int_flags & 16) { + STRCAT(flags, FLAGS_SIZE, "s"); + } + if (int_flags & 32) { + STRCAT(flags, FLAGS_SIZE, "u"); + } + if (int_flags & 64) { + STRCAT(flags, FLAGS_SIZE, "x"); + } + flags_length = (int)strlen(flags) + 1; + if (!buffer_write_bytes(buffer, flags, flags_length)) { + return 0; + } + *(pymongo_buffer_get_buffer(buffer) + type_byte) = 0x0B; + return 1; +} + +/* Write a single value to the buffer (also write its type_byte, for which + * space has already been reserved. + * + * returns 0 on failure */ +static int _write_element_to_buffer(PyObject* self, buffer_t buffer, + int type_byte, PyObject* value, + unsigned char check_keys, + const codec_options_t* options, + unsigned char in_custom_call, + unsigned char in_fallback_call) { + PyObject* new_value = NULL; + int retval; + int is_list; + long type; + struct module_state *state = GETSTATE(self); + if (!state) { + return 0; + } + /* + * Use _type_marker attribute instead of PyObject_IsInstance for better perf. 
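+     * (Each bson helper class advertises its marker as a class attribute,
+     * e.g. bson.binary.Binary._type_marker == 5 and
+     * bson.int64.Int64._type_marker == 18.)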
+ */ + type = _type_marker(value, state->_type_marker_str); + if (type < 0) { + return 0; + } + + switch (type) { + case 5: + { + /* Binary */ + PyObject* subtype_object; + char subtype; + const char* data; + int size; + + *(pymongo_buffer_get_buffer(buffer) + type_byte) = 0x05; + subtype_object = PyObject_GetAttr(value, state->_subtype_str); + if (!subtype_object) { + return 0; + } + subtype = (char)PyLong_AsLong(subtype_object); + if (subtype == -1) { + Py_DECREF(subtype_object); + return 0; + } + size = _downcast_and_check(PyBytes_Size(value), 0); + if (size == -1) { + Py_DECREF(subtype_object); + return 0; + } + + Py_DECREF(subtype_object); + if (subtype == 2) { + int other_size = _downcast_and_check(PyBytes_Size(value), 4); + if (other_size == -1) + return 0; + if (!buffer_write_int32(buffer, other_size)) { + return 0; + } + if (!buffer_write_bytes(buffer, &subtype, 1)) { + return 0; + } + } + if (!buffer_write_int32(buffer, size)) { + return 0; + } + if (subtype != 2) { + if (!buffer_write_bytes(buffer, &subtype, 1)) { + return 0; + } + } + data = PyBytes_AsString(value); + if (!data) { + return 0; + } + if (!buffer_write_bytes(buffer, data, size)) { + return 0; + } + return 1; + } + case 7: + { + /* ObjectId */ + const char* data; + PyObject* pystring = PyObject_GetAttr(value, state->_binary_str); + if (!pystring) { + return 0; + } + data = PyBytes_AsString(pystring); + if (!data) { + Py_DECREF(pystring); + return 0; + } + if (!buffer_write_bytes(buffer, data, 12)) { + Py_DECREF(pystring); + return 0; + } + Py_DECREF(pystring); + *(pymongo_buffer_get_buffer(buffer) + type_byte) = 0x07; + return 1; + } + case 9: + { + /* DatetimeMS */ + long long millis; + if (!millis_from_datetime_ms(value, &millis)) { + return 0; + } + *(pymongo_buffer_get_buffer(buffer) + type_byte) = 0x09; + return buffer_write_int64(buffer, (int64_t)millis); + } + case 11: + { + /* Regex */ + return _write_regex_to_buffer(buffer, type_byte, value, state->_flags_str, state->_pattern_str); + } + case 13: + { + /* Code */ + int start_position, + length_location, + length; + + PyObject* scope = PyObject_GetAttr(value, state->_scope_str); + if (!scope) { + return 0; + } + + if (scope == Py_None) { + Py_DECREF(scope); + *(pymongo_buffer_get_buffer(buffer) + type_byte) = 0x0D; + return write_string(buffer, value); + } + + *(pymongo_buffer_get_buffer(buffer) + type_byte) = 0x0F; + + start_position = pymongo_buffer_get_position(buffer); + /* save space for length */ + length_location = pymongo_buffer_save_space(buffer, 4); + if (length_location == -1) { + Py_DECREF(scope); + return 0; + } + + if (!write_string(buffer, value)) { + Py_DECREF(scope); + return 0; + } + + if (!write_dict(self, buffer, scope, 0, options, 0)) { + Py_DECREF(scope); + return 0; + } + Py_DECREF(scope); + + length = pymongo_buffer_get_position(buffer) - start_position; + buffer_write_int32_at_position( + buffer, length_location, (int32_t)length); + return 1; + } + case 17: + { + /* Timestamp */ + PyObject* obj; + unsigned long i; + + obj = PyObject_GetAttr(value, state->_inc_str); + if (!obj) { + return 0; + } + i = PyLong_AsUnsignedLong(obj); + Py_DECREF(obj); + if (i == (unsigned long)-1 && PyErr_Occurred()) { + return 0; + } + if (!buffer_write_int32(buffer, (int32_t)i)) { + return 0; + } + + obj = PyObject_GetAttr(value, state->_time_str); + if (!obj) { + return 0; + } + i = PyLong_AsUnsignedLong(obj); + Py_DECREF(obj); + if (i == (unsigned long)-1 && PyErr_Occurred()) { + return 0; + } + if (!buffer_write_int32(buffer, (int32_t)i)) { + return 
0; + } + + *(pymongo_buffer_get_buffer(buffer) + type_byte) = 0x11; + return 1; + } + case 18: + { + /* Int64 */ + const long long ll = PyLong_AsLongLong(value); + if (PyErr_Occurred()) { /* Overflow */ + PyErr_SetString(PyExc_OverflowError, + "MongoDB can only handle up to 8-byte ints"); + return 0; + } + if (!buffer_write_int64(buffer, (int64_t)ll)) { + return 0; + } + *(pymongo_buffer_get_buffer(buffer) + type_byte) = 0x12; + return 1; + } + case 19: + { + /* Decimal128 */ + const char* data; + PyObject* pystring = PyObject_GetAttr(value, state->_bid_str); + if (!pystring) { + return 0; + } + data = PyBytes_AsString(pystring); + if (!data) { + Py_DECREF(pystring); + return 0; + } + if (!buffer_write_bytes(buffer, data, 16)) { + Py_DECREF(pystring); + return 0; + } + Py_DECREF(pystring); + *(pymongo_buffer_get_buffer(buffer) + type_byte) = 0x13; + return 1; + } + case 100: + { + /* DBRef */ + PyObject* as_doc = PyObject_CallMethodObjArgs(value, state->_as_doc_str, NULL); + if (!as_doc) { + return 0; + } + if (!write_dict(self, buffer, as_doc, 0, options, 0)) { + Py_DECREF(as_doc); + return 0; + } + Py_DECREF(as_doc); + *(pymongo_buffer_get_buffer(buffer) + type_byte) = 0x03; + return 1; + } + case 101: + { + /* RawBSONDocument */ + if (!write_raw_doc(buffer, value, state->_raw_str)) { + return 0; + } + *(pymongo_buffer_get_buffer(buffer) + type_byte) = 0x03; + return 1; + } + case 255: + { + /* MinKey */ + *(pymongo_buffer_get_buffer(buffer) + type_byte) = 0xFF; + return 1; + } + case 127: + { + /* MaxKey */ + *(pymongo_buffer_get_buffer(buffer) + type_byte) = 0x7F; + return 1; + } + } + + /* No _type_marker attribute or not one of our types. */ + + if (PyBool_Check(value)) { + const char c = (value == Py_True) ? 0x01 : 0x00; + *(pymongo_buffer_get_buffer(buffer) + type_byte) = 0x08; + return buffer_write_bytes(buffer, &c, 1); + } + else if (PyLong_Check(value)) { + const long long long_long_value = PyLong_AsLongLong(value); + if (long_long_value == -1 && PyErr_Occurred()) { + /* Ignore error and give the fallback_encoder a chance. 
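+             * (An int wider than 8 bytes, e.g. 2**64, lands here; if no
+             * custom or fallback encoder handles it, _set_cannot_encode at
+             * the end of this function reports the OverflowError.)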
*/ + PyErr_Clear(); + } else if (-2147483648LL <= long_long_value && long_long_value <= 2147483647LL) { + *(pymongo_buffer_get_buffer(buffer) + type_byte) = 0x10; + return buffer_write_int32(buffer, (int32_t)long_long_value); + } else { + *(pymongo_buffer_get_buffer(buffer) + type_byte) = 0x12; + return buffer_write_int64(buffer, (int64_t)long_long_value); + } + } else if (PyFloat_Check(value)) { + const double d = PyFloat_AsDouble(value); + *(pymongo_buffer_get_buffer(buffer) + type_byte) = 0x01; + return buffer_write_double(buffer, d); + } else if (value == Py_None) { + *(pymongo_buffer_get_buffer(buffer) + type_byte) = 0x0A; + return 1; + } else if (PyDict_Check(value)) { + *(pymongo_buffer_get_buffer(buffer) + type_byte) = 0x03; + return write_dict(self, buffer, value, check_keys, options, 0); + } else if ((is_list = PyList_Check(value)) || PyTuple_Check(value)) { + Py_ssize_t items, i; + int start_position, + length_location, + length; + char zero = 0; + + *(pymongo_buffer_get_buffer(buffer) + type_byte) = 0x04; + start_position = pymongo_buffer_get_position(buffer); + + /* save space for length */ + length_location = pymongo_buffer_save_space(buffer, 4); + if (length_location == -1) { + return 0; + } + if (is_list) { + items = PyList_Size(value); + } else { + items = PyTuple_Size(value); + } + if (items > BSON_MAX_SIZE) { + PyObject* BSONError = _error("BSONError"); + if (BSONError) { + PyErr_SetString(BSONError, + "Too many items to serialize."); + Py_DECREF(BSONError); + } + return 0; + } + for(i = 0; i < items; i++) { + int list_type_byte = pymongo_buffer_save_space(buffer, 1); + char name[BUF_SIZE]; + PyObject* item_value; + + if (list_type_byte == -1) { + return 0; + } + int res = LL2STR(name, (long long)i); + if (res == -1) { + return 0; + } + if (!buffer_write_bytes(buffer, name, (int)strlen(name) + 1)) { + return 0; + } + if (is_list) { + item_value = PyList_GET_ITEM(value, i); + } else { + item_value = PyTuple_GET_ITEM(value, i); + } + if (!item_value) { + return 0; + } + if (!write_element_to_buffer(self, buffer, list_type_byte, + item_value, check_keys, options, + 0, 0)) { + return 0; + } + } + + /* write null byte and fill in length */ + if (!buffer_write_bytes(buffer, &zero, 1)) { + return 0; + } + length = pymongo_buffer_get_position(buffer) - start_position; + buffer_write_int32_at_position( + buffer, length_location, (int32_t)length); + return 1; + /* Python3 special case. Store bytes as BSON binary subtype 0. 
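+         * (Wire layout for this case: int32 length, subtype byte 0x00, then
+         * the raw bytes -- e.g. b"ab" becomes 02 00 00 00 00 61 62. The
+         * decoder below maps subtype 0 back to bytes, so bytes round-trip.)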
*/ + } else if (PyBytes_Check(value)) { + char subtype = 0; + int size; + const char* data = PyBytes_AS_STRING(value); + if (!data) + return 0; + if ((size = _downcast_and_check(PyBytes_GET_SIZE(value), 0)) == -1) + return 0; + *(pymongo_buffer_get_buffer(buffer) + type_byte) = 0x05; + if (!buffer_write_int32(buffer, (int32_t)size)) { + return 0; + } + if (!buffer_write_bytes(buffer, &subtype, 1)) { + return 0; + } + if (!buffer_write_bytes(buffer, data, size)) { + return 0; + } + return 1; + } else if (PyUnicode_Check(value)) { + *(pymongo_buffer_get_buffer(buffer) + type_byte) = 0x02; + return write_unicode(buffer, value); + } else if (PyDateTime_Check(value)) { + long long millis; + PyObject* utcoffset = PyObject_CallMethodObjArgs(value, state->_utcoffset_str , NULL); + if (utcoffset == NULL) + return 0; + if (utcoffset != Py_None) { + PyObject* result = PyNumber_Subtract(value, utcoffset); + Py_DECREF(utcoffset); + if (!result) { + return 0; + } + millis = millis_from_datetime(result); + Py_DECREF(result); + } else { + millis = millis_from_datetime(value); + } + *(pymongo_buffer_get_buffer(buffer) + type_byte) = 0x09; + return buffer_write_int64(buffer, (int64_t)millis); + } else if (PyObject_TypeCheck(value, state->REType)) { + return _write_regex_to_buffer(buffer, type_byte, value, state->_flags_str, state->_pattern_str); + } else if (PyObject_IsInstance(value, state->Mapping)) { + /* PyObject_IsInstance returns -1 on error */ + if (PyErr_Occurred()) { + return 0; + } + *(pymongo_buffer_get_buffer(buffer) + type_byte) = 0x03; + return write_dict(self, buffer, value, check_keys, options, 0); + } else if (PyObject_IsInstance(value, state->UUID)) { + PyObject* binary_value = NULL; + PyObject *uuid_rep_obj = NULL; + int result; + + /* PyObject_IsInstance returns -1 on error */ + if (PyErr_Occurred()) { + return 0; + } + + if (!(uuid_rep_obj = PyLong_FromLong(options->uuid_rep))) { + return 0; + } + binary_value = PyObject_CallMethodObjArgs(state->Binary, state->_from_uuid_str, value, uuid_rep_obj, NULL); + Py_DECREF(uuid_rep_obj); + + if (binary_value == NULL) { + return 0; + } + + result = _write_element_to_buffer(self, buffer, + type_byte, binary_value, + check_keys, options, + in_custom_call, + in_fallback_call); + Py_DECREF(binary_value); + return result; + } + + /* Try a custom encoder if one is provided and we have not already + * attempted to use a type encoder. */ + if (!in_custom_call && !options->type_registry.is_encoder_empty) { + PyObject* value_type = NULL; + PyObject* converter = NULL; + value_type = PyObject_Type(value); + if (value_type == NULL) { + return 0; + } + converter = PyDict_GetItem(options->type_registry.encoder_map, value_type); + Py_XDECREF(value_type); + if (converter != NULL) { + /* Transform types that have a registered converter. + * A new reference is created upon transformation. */ + new_value = PyObject_CallFunctionObjArgs(converter, value, NULL); + if (new_value == NULL) { + return 0; + } + retval = write_element_to_buffer(self, buffer, type_byte, new_value, + check_keys, options, 1, 0); + Py_XDECREF(new_value); + return retval; + } + } + + /* Try the fallback encoder if one is provided and we have not already + * attempted to use the fallback encoder. 
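+     * (On the Python side this corresponds to constructing the codec
+     * options with something like
+     * CodecOptions(type_registry=TypeRegistry(fallback_encoder=fn));
+     * fn receives the unencodable value and must return something the
+     * encoder understands.)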
*/ + if (!in_fallback_call && options->type_registry.has_fallback_encoder) { + new_value = PyObject_CallFunctionObjArgs( + options->type_registry.fallback_encoder, value, NULL); + if (new_value == NULL) { + // propagate any exception raised by the callback + return 0; + } + retval = write_element_to_buffer(self, buffer, type_byte, new_value, + check_keys, options, 0, 1); + Py_XDECREF(new_value); + return retval; + } + + /* We can't determine value's type. Fail. */ + _set_cannot_encode(value); + return 0; +} + +static int check_key_name(const char* name, int name_length) { + + if (name_length > 0 && name[0] == '$') { + PyObject* InvalidDocument = _error("InvalidDocument"); + if (InvalidDocument) { + PyObject* errmsg = PyUnicode_FromFormat( + "key '%s' must not start with '$'", name); + if (errmsg) { + PyErr_SetObject(InvalidDocument, errmsg); + Py_DECREF(errmsg); + } + Py_DECREF(InvalidDocument); + } + return 0; + } + if (strchr(name, '.')) { + PyObject* InvalidDocument = _error("InvalidDocument"); + if (InvalidDocument) { + PyObject* errmsg = PyUnicode_FromFormat( + "key '%s' must not contain '.'", name); + if (errmsg) { + PyErr_SetObject(InvalidDocument, errmsg); + Py_DECREF(errmsg); + } + Py_DECREF(InvalidDocument); + } + return 0; + } + return 1; +} + +/* Write a (key, value) pair to the buffer. + * + * Returns 0 on failure */ +int write_pair(PyObject* self, buffer_t buffer, const char* name, int name_length, + PyObject* value, unsigned char check_keys, + const codec_options_t* options, unsigned char allow_id) { + int type_byte; + + /* Don't write any _id elements unless we're explicitly told to - + * _id has to be written first so we do so, but don't bother + * deleting it from the dictionary being written. */ + if (!allow_id && strcmp(name, "_id") == 0) { + return 1; + } + + type_byte = pymongo_buffer_save_space(buffer, 1); + if (type_byte == -1) { + return 0; + } + if (check_keys && !check_key_name(name, name_length)) { + return 0; + } + if (!buffer_write_bytes(buffer, name, name_length + 1)) { + return 0; + } + if (!write_element_to_buffer(self, buffer, type_byte, + value, check_keys, options, 0, 0)) { + return 0; + } + return 1; +} + +int decode_and_write_pair(PyObject* self, buffer_t buffer, + PyObject* key, PyObject* value, + unsigned char check_keys, + const codec_options_t* options, + unsigned char top_level) { + PyObject* encoded; + const char* data; + int size; + if (PyUnicode_Check(key)) { + encoded = PyUnicode_AsUTF8String(key); + if (!encoded) { + return 0; + } + if (!(data = PyBytes_AS_STRING(encoded))) { + Py_DECREF(encoded); + return 0; + } + if ((size = _downcast_and_check(PyBytes_GET_SIZE(encoded), 1)) == -1) { + Py_DECREF(encoded); + return 0; + } + if (strlen(data) != (size_t)(size - 1)) { + PyObject* InvalidDocument = _error("InvalidDocument"); + if (InvalidDocument) { + PyErr_SetString(InvalidDocument, + "Key names must not contain the NULL byte"); + Py_DECREF(InvalidDocument); + } + Py_DECREF(encoded); + return 0; + } + } else { + PyObject* InvalidDocument = _error("InvalidDocument"); + if (InvalidDocument) { + PyObject* repr = PyObject_Repr(key); + if (repr) { + PyObject* errmsg = PyUnicode_FromString( + "documents must have only string keys, key was "); + if (errmsg) { + PyObject* error = PyUnicode_Concat(errmsg, repr); + if (error) { + PyErr_SetObject(InvalidDocument, error); + Py_DECREF(error); + } + Py_DECREF(errmsg); + Py_DECREF(repr); + } else { + Py_DECREF(repr); + } + } + Py_DECREF(InvalidDocument); + } + return 0; + } + + /* If top_level is True, don't 
allow writing _id here - it was already written. */ + if (!write_pair(self, buffer, data, + size - 1, value, check_keys, options, !top_level)) { + Py_DECREF(encoded); + return 0; + } + + Py_DECREF(encoded); + return 1; +} + + +/* Write a RawBSONDocument to the buffer. + * Returns the number of bytes written or 0 on failure. + */ +static int write_raw_doc(buffer_t buffer, PyObject* raw, PyObject* _raw_str) { + char* bytes; + Py_ssize_t len; + int len_int; + int bytes_written = 0; + PyObject* bytes_obj = NULL; + + bytes_obj = PyObject_GetAttr(raw, _raw_str); + if (!bytes_obj) { + goto fail; + } + + if (-1 == PyBytes_AsStringAndSize(bytes_obj, &bytes, &len)) { + goto fail; + } + len_int = _downcast_and_check(len, 0); + if (-1 == len_int) { + goto fail; + } + if (!buffer_write_bytes(buffer, bytes, len_int)) { + goto fail; + } + bytes_written = len_int; +fail: + Py_XDECREF(bytes_obj); + return bytes_written; +} + +/* returns the number of bytes written or 0 on failure */ +int write_dict(PyObject* self, buffer_t buffer, + PyObject* dict, unsigned char check_keys, + const codec_options_t* options, unsigned char top_level) { + PyObject* key; + PyObject* iter; + char zero = 0; + int length; + int length_location; + struct module_state *state = GETSTATE(self); + long type_marker; + int is_dict = PyDict_Check(dict); + if (!state) { + return 0; + } + + if (!is_dict) { + /* check for RawBSONDocument */ + type_marker = _type_marker(dict, state->_type_marker_str); + if (type_marker < 0) { + return 0; + } + + if (101 == type_marker) { + return write_raw_doc(buffer, dict, state->_raw_str); + } + + if (!PyObject_IsInstance(dict, state->Mapping)) { + PyObject* repr; + if ((repr = PyObject_Repr(dict))) { + PyObject* errmsg = PyUnicode_FromString( + "encoder expected a mapping type but got: "); + if (errmsg) { + PyObject* error = PyUnicode_Concat(errmsg, repr); + if (error) { + PyErr_SetObject(PyExc_TypeError, error); + Py_DECREF(error); + } + Py_DECREF(errmsg); + Py_DECREF(repr); + } + else { + Py_DECREF(repr); + } + } else { + PyErr_SetString(PyExc_TypeError, + "encoder expected a mapping type"); + } + + return 0; + } + /* PyObject_IsInstance returns -1 on error */ + if (PyErr_Occurred()) { + return 0; + } + } + + length_location = pymongo_buffer_save_space(buffer, 4); + if (length_location == -1) { + return 0; + } + + /* Write _id first if this is a top level doc. */ + if (top_level) { + /* + * If "dict" is a defaultdict we don't want to call + * PyObject_GetItem on it. That would **create** + * an _id where one didn't previously exist (PYTHON-871). + */ + if (is_dict) { + /* PyDict_GetItem returns a borrowed reference. */ + PyObject* _id = PyDict_GetItem(dict, state->_id_str); + if (_id) { + if (!write_pair(self, buffer, "_id", 3, + _id, check_keys, options, 1)) { + return 0; + } + } + } else if (PyMapping_HasKey(dict, state->_id_str)) { + PyObject* _id = PyObject_GetItem(dict, state->_id_str); + if (!_id) { + return 0; + } + if (!write_pair(self, buffer, "_id", 3, + _id, check_keys, options, 1)) { + Py_DECREF(_id); + return 0; + } + /* PyObject_GetItem returns a new reference. 
*/ + Py_DECREF(_id); + } + } + + if (is_dict) { + PyObject* value; + Py_ssize_t pos = 0; + while (PyDict_Next(dict, &pos, &key, &value)) { + if (!decode_and_write_pair(self, buffer, key, value, + check_keys, options, top_level)) { + return 0; + } + } + } else { + iter = PyObject_GetIter(dict); + if (iter == NULL) { + return 0; + } + while ((key = PyIter_Next(iter)) != NULL) { + PyObject* value = PyObject_GetItem(dict, key); + if (!value) { + PyErr_SetObject(PyExc_KeyError, key); + Py_DECREF(key); + Py_DECREF(iter); + return 0; + } + if (!decode_and_write_pair(self, buffer, key, value, + check_keys, options, top_level)) { + Py_DECREF(key); + Py_DECREF(value); + Py_DECREF(iter); + return 0; + } + Py_DECREF(key); + Py_DECREF(value); + } + Py_DECREF(iter); + if (PyErr_Occurred()) { + return 0; + } + } + + /* write null byte and fill in length */ + if (!buffer_write_bytes(buffer, &zero, 1)) { + return 0; + } + length = pymongo_buffer_get_position(buffer) - length_location; + buffer_write_int32_at_position( + buffer, length_location, (int32_t)length); + return length; +} + +static PyObject* _cbson_dict_to_bson(PyObject* self, PyObject* args) { + PyObject* dict; + PyObject* result; + unsigned char check_keys; + unsigned char top_level = 1; + PyObject* options_obj; + codec_options_t options; + buffer_t buffer; + PyObject* raw_bson_document_bytes_obj; + long type_marker; + struct module_state *state = GETSTATE(self); + if (!state) { + return NULL; + } + + if (!(PyArg_ParseTuple(args, "ObO|b", &dict, &check_keys, + &options_obj, &top_level) && + convert_codec_options(self, options_obj, &options))) { + return NULL; + } + + /* check for RawBSONDocument */ + type_marker = _type_marker(dict, state->_type_marker_str); + if (type_marker < 0) { + destroy_codec_options(&options); + return NULL; + } else if (101 == type_marker) { + destroy_codec_options(&options); + raw_bson_document_bytes_obj = PyObject_GetAttr(dict, state->_raw_str); + if (NULL == raw_bson_document_bytes_obj) { + return NULL; + } + return raw_bson_document_bytes_obj; + } + + buffer = pymongo_buffer_new(); + if (!buffer) { + destroy_codec_options(&options); + return NULL; + } + + if (!write_dict(self, buffer, dict, check_keys, &options, top_level)) { + destroy_codec_options(&options); + pymongo_buffer_free(buffer); + return NULL; + } + + /* objectify buffer */ + result = Py_BuildValue("y#", pymongo_buffer_get_buffer(buffer), + (Py_ssize_t)pymongo_buffer_get_position(buffer)); + destroy_codec_options(&options); + pymongo_buffer_free(buffer); + return result; +} + +/* + * Hook for optional decoding BSON documents to DBRef. + */ +static PyObject *_dbref_hook(PyObject* self, PyObject* value) { + struct module_state *state = GETSTATE(self); + PyObject* ref = NULL; + PyObject* id = NULL; + PyObject* database = NULL; + PyObject* ret = NULL; + int db_present = 0; + if (!state) { + return NULL; + } + + /* Decoding for DBRefs */ + if (PyMapping_HasKey(value, state->_dollar_ref_str) && PyMapping_HasKey(value, state->_dollar_id_str)) { /* DBRef */ + ref = PyObject_GetItem(value, state->_dollar_ref_str); + /* PyObject_GetItem returns NULL to indicate error. */ + if (!ref) { + goto invalid; + } + id = PyObject_GetItem(value, state->_dollar_id_str); + /* PyObject_GetItem returns NULL to indicate error. 
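+         * (Sketch of this hook's effect: a decoded mapping such as
+         * {"$ref": "coll", "$id": oid, "$db": "db", "x": 1} comes back as
+         * DBRef("coll", oid, "db", x=1); the $-prefixed keys are removed
+         * before the remaining fields are passed along as extras.)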
*/ + if (!id) { + goto invalid; + } + + if (PyMapping_HasKey(value, state->_dollar_db_str)) { + database = PyObject_GetItem(value, state->_dollar_db_str); + if (!database) { + goto invalid; + } + db_present = 1; + } else { + database = Py_None; + Py_INCREF(database); + } + + // check types + if (!(PyUnicode_Check(ref) && (database == Py_None || PyUnicode_Check(database)))) { + ret = value; + goto invalid; + } + + PyMapping_DelItem(value, state->_dollar_ref_str); + PyMapping_DelItem(value, state->_dollar_id_str); + if (db_present) { + PyMapping_DelItem(value, state->_dollar_db_str); + } + + ret = PyObject_CallFunctionObjArgs(state->DBRef, ref, id, database, value, NULL); + Py_DECREF(value); + } else { + ret = value; + } +invalid: + Py_XDECREF(ref); + Py_XDECREF(id); + Py_XDECREF(database); + return ret; +} + +static PyObject* get_value(PyObject* self, PyObject* name, const char* buffer, + unsigned* position, unsigned char type, + unsigned max, const codec_options_t* options, int raw_array) { + struct module_state *state = GETSTATE(self); + PyObject* value = NULL; + if (!state) { + return NULL; + } + switch (type) { + case 1: + { + double d; + if (max < 8) { + goto invalid; + } + memcpy(&d, buffer + *position, 8); + value = PyFloat_FromDouble(BSON_DOUBLE_FROM_LE(d)); + *position += 8; + break; + } + case 2: + case 14: + { + uint32_t value_length; + if (max < 4) { + goto invalid; + } + memcpy(&value_length, buffer + *position, 4); + value_length = BSON_UINT32_FROM_LE(value_length); + /* Encoded string length + string */ + if (!value_length || max < value_length || max < 4 + value_length) { + goto invalid; + } + *position += 4; + /* Strings must end in \0 */ + if (buffer[*position + value_length - 1]) { + goto invalid; + } + value = PyUnicode_DecodeUTF8( + buffer + *position, value_length - 1, + options->unicode_decode_error_handler); + if (!value) { + goto invalid; + } + *position += value_length; + break; + } + case 3: + { + uint32_t size; + + if (max < 4) { + goto invalid; + } + memcpy(&size, buffer + *position, 4); + size = BSON_UINT32_FROM_LE(size); + if (size < BSON_MIN_SIZE || max < size) { + goto invalid; + } + /* Check for bad eoo */ + if (buffer[*position + size - 1]) { + goto invalid; + } + + value = elements_to_dict(self, buffer + *position, + size, options); + if (!value) { + goto invalid; + } + + if (options->is_raw_bson) { + *position += size; + break; + } + + /* Hook for DBRefs */ + value = _dbref_hook(self, value); + if (!value) { + goto invalid; + } + + *position += size; + break; + } + case 4: + { + uint32_t size, end; + + if (max < 4) { + goto invalid; + } + memcpy(&size, buffer + *position, 4); + size = BSON_UINT32_FROM_LE(size); + if (size < BSON_MIN_SIZE || max < size) { + goto invalid; + } + + end = *position + size - 1; + /* Check for bad eoo */ + if (buffer[end]) { + goto invalid; + } + + if (raw_array != 0) { + // Treat it as a binary buffer. + value = PyBytes_FromStringAndSize(buffer + *position, size); + *position += size; + break; + } + + *position += 4; + + value = PyList_New(0); + if (!value) { + goto invalid; + } + while (*position < end) { + PyObject* to_append; + + unsigned char bson_type = (unsigned char)buffer[(*position)++]; + + size_t key_size = strlen(buffer + *position); + if (max < key_size) { + Py_DECREF(value); + goto invalid; + } + /* just skip the key, they're in order. 
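+                 * (A BSON array is stored as a document whose keys are the
+                 * decimal strings "0", "1", "2", ...; the encoder above
+                 * writes exactly those keys via LL2STR, so their bytes are
+                 * redundant when reading elements back in order.)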
*/ + *position += (unsigned)key_size + 1; + if (Py_EnterRecursiveCall(" while decoding a list value")) { + Py_DECREF(value); + goto invalid; + } + to_append = get_value(self, name, buffer, position, bson_type, + max - (unsigned)key_size, options, raw_array); + Py_LeaveRecursiveCall(); + if (!to_append) { + Py_DECREF(value); + goto invalid; + } + if (PyList_Append(value, to_append) < 0) { + Py_DECREF(value); + Py_DECREF(to_append); + goto invalid; + } + Py_DECREF(to_append); + } + if (*position != end) { + goto invalid; + } + (*position)++; + break; + } + case 5: + { + PyObject* data; + PyObject* st; + uint32_t length, length2; + unsigned char subtype; + + if (max < 5) { + goto invalid; + } + memcpy(&length, buffer + *position, 4); + length = BSON_UINT32_FROM_LE(length); + if (max < length) { + goto invalid; + } + + subtype = (unsigned char)buffer[*position + 4]; + *position += 5; + if (subtype == 2) { + if (length < 4) { + goto invalid; + } + memcpy(&length2, buffer + *position, 4); + length2 = BSON_UINT32_FROM_LE(length2); + if (length2 != length - 4) { + goto invalid; + } + } + /* Python3 special case. Decode BSON binary subtype 0 to bytes. */ + if (subtype == 0) { + value = PyBytes_FromStringAndSize(buffer + *position, length); + *position += length; + break; + } + if (subtype == 2) { + data = PyBytes_FromStringAndSize(buffer + *position + 4, length - 4); + } else { + data = PyBytes_FromStringAndSize(buffer + *position, length); + } + if (!data) { + goto invalid; + } + /* Encode as UUID or Binary based on options->uuid_rep */ + if (subtype == 3 || subtype == 4) { + PyObject* binary_value = NULL; + char uuid_rep = options->uuid_rep; + + /* UUID should always be 16 bytes */ + if (length != 16) { + goto uuiderror; + } + + binary_value = PyObject_CallFunction(state->Binary, "(Oi)", data, subtype); + if (binary_value == NULL) { + goto uuiderror; + } + + if ((uuid_rep == UNSPECIFIED) || + (subtype == 4 && uuid_rep != STANDARD) || + (subtype == 3 && uuid_rep == STANDARD)) { + value = binary_value; + Py_INCREF(value); + } else { + PyObject *uuid_rep_obj = PyLong_FromLong(uuid_rep); + if (!uuid_rep_obj) { + goto uuiderror; + } + value = PyObject_CallMethodObjArgs(binary_value, state->_as_uuid_str, uuid_rep_obj, NULL); + Py_DECREF(uuid_rep_obj); + } + + uuiderror: + Py_XDECREF(binary_value); + Py_DECREF(data); + if (!value) { + goto invalid; + } + *position += length; + break; + } + + st = PyLong_FromLong(subtype); + if (!st) { + Py_DECREF(data); + goto invalid; + } + value = PyObject_CallFunctionObjArgs(state->Binary, data, st, NULL); + Py_DECREF(st); + Py_DECREF(data); + if (!value) { + goto invalid; + } + *position += length; + break; + } + case 6: + case 10: + { + value = Py_None; + Py_INCREF(value); + break; + } + case 7: + { + if (max < 12) { + goto invalid; + } + value = PyObject_CallFunction(state->ObjectId, "y#", buffer + *position, (Py_ssize_t)12); + *position += 12; + break; + } + case 8: + { + char boolean_raw = buffer[(*position)++]; + if (0 == boolean_raw) { + value = Py_False; + } else if (1 == boolean_raw) { + value = Py_True; + } else { + PyObject* InvalidBSON = _error("InvalidBSON"); + if (InvalidBSON) { + PyErr_Format(InvalidBSON, "invalid boolean value: %x", boolean_raw); + Py_DECREF(InvalidBSON); + } + return NULL; + } + Py_INCREF(value); + break; + } + case 9: + { + PyObject* naive; + PyObject* replace; + PyObject* args; + PyObject* kwargs; + PyObject* astimezone; + int64_t millis; + if (max < 8) { + goto invalid; + } + memcpy(&millis, buffer + *position, 8); + millis = 
(int64_t)BSON_UINT64_FROM_LE(millis); + *position += 8; + + if (options->datetime_conversion == DATETIME_MS){ + value = datetime_ms_from_millis(self, millis); + break; + } + + int dt_clamp = options->datetime_conversion == DATETIME_CLAMP; + int dt_auto = options->datetime_conversion == DATETIME_AUTO; + + + if (dt_clamp || dt_auto){ + PyObject *min_millis_fn_res; + PyObject *max_millis_fn_res; + int64_t min_millis; + int64_t max_millis; + + if (options->tz_aware){ + PyObject* tzinfo = options->tzinfo; + if (tzinfo == Py_None) { + // Default to UTC. + tzinfo = state->UTC; + } + min_millis_fn_res = PyObject_CallFunctionObjArgs(state->_min_datetime_ms, tzinfo, NULL); + max_millis_fn_res = PyObject_CallFunctionObjArgs(state->_max_datetime_ms, tzinfo, NULL); + } else { + min_millis_fn_res = PyObject_CallObject(state->_min_datetime_ms, NULL); + max_millis_fn_res = PyObject_CallObject(state->_max_datetime_ms, NULL); + } + + if (!min_millis_fn_res || !max_millis_fn_res){ + Py_XDECREF(min_millis_fn_res); + Py_XDECREF(max_millis_fn_res); + goto invalid; + } + + min_millis = PyLong_AsLongLong(min_millis_fn_res); + max_millis = PyLong_AsLongLong(max_millis_fn_res); + + if ((min_millis == -1 || max_millis == -1) && PyErr_Occurred()) + { + // min/max_millis check + goto invalid; + } + + if (dt_clamp) { + if (millis < min_millis) { + millis = min_millis; + } else if (millis > max_millis) { + millis = max_millis; + } + // Continues from here to return a datetime. + } else { // dt_auto + if (millis < min_millis || millis > max_millis){ + value = datetime_ms_from_millis(self, millis); + break; // Out-of-range so done. + } + } + } + + naive = datetime_from_millis(millis); + if (!options->tz_aware) { /* In the naive case, we're done here. */ + value = naive; + break; + } + + if (!naive) { + goto invalid; + } + replace = PyObject_GetAttr(naive, state->_replace_str); + Py_DECREF(naive); + if (!replace) { + goto invalid; + } + args = PyTuple_New(0); + if (!args) { + Py_DECREF(replace); + goto invalid; + } + kwargs = PyDict_New(); + if (!kwargs) { + Py_DECREF(replace); + Py_DECREF(args); + goto invalid; + } + if (PyDict_SetItem(kwargs, state->_tzinfo_str, state->UTC) == -1) { + Py_DECREF(replace); + Py_DECREF(args); + Py_DECREF(kwargs); + goto invalid; + } + value = PyObject_Call(replace, args, kwargs); + if (!value) { + Py_DECREF(replace); + Py_DECREF(args); + Py_DECREF(kwargs); + goto invalid; + } + + /* convert to local time */ + if (options->tzinfo != Py_None) { + astimezone = PyObject_GetAttr(value, state->_astimezone_str); + Py_DECREF(value); + if (!astimezone) { + Py_DECREF(replace); + Py_DECREF(args); + Py_DECREF(kwargs); + goto invalid; + } + value = PyObject_CallFunctionObjArgs(astimezone, options->tzinfo, NULL); + Py_DECREF(astimezone); + } + + Py_DECREF(replace); + Py_DECREF(args); + Py_DECREF(kwargs); + break; + } + case 11: + { + PyObject* pattern; + int flags; + size_t flags_length, i; + size_t pattern_length = strlen(buffer + *position); + if (pattern_length > BSON_MAX_SIZE || max < pattern_length) { + goto invalid; + } + pattern = PyUnicode_DecodeUTF8( + buffer + *position, pattern_length, + options->unicode_decode_error_handler); + if (!pattern) { + goto invalid; + } + *position += (unsigned)pattern_length + 1; + flags_length = strlen(buffer + *position); + if (flags_length > BSON_MAX_SIZE || + (BSON_MAX_SIZE - pattern_length) < flags_length) { + Py_DECREF(pattern); + goto invalid; + } + if (max < pattern_length + flags_length) { + Py_DECREF(pattern); + goto invalid; + } + flags = 0; + for (i = 
0; i < flags_length; i++) { + if (buffer[*position + i] == 'i') { + flags |= 2; + } else if (buffer[*position + i] == 'l') { + flags |= 4; + } else if (buffer[*position + i] == 'm') { + flags |= 8; + } else if (buffer[*position + i] == 's') { + flags |= 16; + } else if (buffer[*position + i] == 'u') { + flags |= 32; + } else if (buffer[*position + i] == 'x') { + flags |= 64; + } + } + *position += (unsigned)flags_length + 1; + + value = PyObject_CallFunction(state->Regex, "Oi", pattern, flags); + Py_DECREF(pattern); + break; + } + case 12: + { + uint32_t coll_length; + PyObject* collection; + PyObject* id = NULL; + + if (max < 4) { + goto invalid; + } + memcpy(&coll_length, buffer + *position, 4); + coll_length = BSON_UINT32_FROM_LE(coll_length); + /* Encoded string length + string + 12 byte ObjectId */ + if (!coll_length || max < coll_length || max < 4 + coll_length + 12) { + goto invalid; + } + *position += 4; + /* Strings must end in \0 */ + if (buffer[*position + coll_length - 1]) { + goto invalid; + } + + collection = PyUnicode_DecodeUTF8( + buffer + *position, coll_length - 1, + options->unicode_decode_error_handler); + if (!collection) { + goto invalid; + } + *position += coll_length; + + id = PyObject_CallFunction(state->ObjectId, "y#", buffer + *position, (Py_ssize_t)12); + if (!id) { + Py_DECREF(collection); + goto invalid; + } + *position += 12; + value = PyObject_CallFunctionObjArgs(state->DBRef, collection, id, NULL); + Py_DECREF(collection); + Py_DECREF(id); + break; + } + case 13: + { + PyObject* code; + uint32_t value_length; + if (max < 4) { + goto invalid; + } + memcpy(&value_length, buffer + *position, 4); + value_length = BSON_UINT32_FROM_LE(value_length); + /* Encoded string length + string */ + if (!value_length || max < value_length || max < 4 + value_length) { + goto invalid; + } + *position += 4; + /* Strings must end in \0 */ + if (buffer[*position + value_length - 1]) { + goto invalid; + } + code = PyUnicode_DecodeUTF8( + buffer + *position, value_length - 1, + options->unicode_decode_error_handler); + if (!code) { + goto invalid; + } + *position += value_length; + value = PyObject_CallFunctionObjArgs(state->Code, code, NULL, NULL); + Py_DECREF(code); + break; + } + case 15: + { + uint32_t c_w_s_size; + uint32_t code_size; + uint32_t scope_size; + uint32_t len; + PyObject* code; + PyObject* scope; + + if (max < 8) { + goto invalid; + } + + memcpy(&c_w_s_size, buffer + *position, 4); + c_w_s_size = BSON_UINT32_FROM_LE(c_w_s_size); + *position += 4; + + if (max < c_w_s_size) { + goto invalid; + } + + memcpy(&code_size, buffer + *position, 4); + code_size = BSON_UINT32_FROM_LE(code_size); + /* code_w_scope length + code length + code + scope length */ + len = 4 + 4 + code_size + 4; + if (!code_size || max < code_size || max < len || len < code_size) { + goto invalid; + } + *position += 4; + /* Strings must end in \0 */ + if (buffer[*position + code_size - 1]) { + goto invalid; + } + code = PyUnicode_DecodeUTF8( + buffer + *position, code_size - 1, + options->unicode_decode_error_handler); + if (!code) { + goto invalid; + } + *position += code_size; + + memcpy(&scope_size, buffer + *position, 4); + scope_size = BSON_UINT32_FROM_LE(scope_size); + /* code length + code + scope length + scope */ + len = 4 + 4 + code_size + scope_size; + if (scope_size < BSON_MIN_SIZE || len != c_w_s_size || len < scope_size) { + Py_DECREF(code); + goto invalid; + } + + /* Check for bad eoo */ + if (buffer[*position + scope_size - 1]) { + goto invalid; + } + scope = 
elements_to_dict(self, buffer + *position, + scope_size, options); + if (!scope) { + Py_DECREF(code); + goto invalid; + } + *position += scope_size; + + value = PyObject_CallFunctionObjArgs(state->Code, code, scope, NULL); + Py_DECREF(code); + Py_DECREF(scope); + break; + } + case 16: + { + int32_t i; + if (max < 4) { + goto invalid; + } + memcpy(&i, buffer + *position, 4); + i = (int32_t)BSON_UINT32_FROM_LE(i); + value = PyLong_FromLong(i); + if (!value) { + goto invalid; + } + *position += 4; + break; + } + case 17: + { + uint32_t time, inc; + if (max < 8) { + goto invalid; + } + memcpy(&inc, buffer + *position, 4); + memcpy(&time, buffer + *position + 4, 4); + inc = BSON_UINT32_FROM_LE(inc); + time = BSON_UINT32_FROM_LE(time); + value = PyObject_CallFunction(state->Timestamp, "II", time, inc); + *position += 8; + break; + } + case 18: + { + int64_t ll; + if (max < 8) { + goto invalid; + } + memcpy(&ll, buffer + *position, 8); + ll = (int64_t)BSON_UINT64_FROM_LE(ll); + value = PyObject_CallFunction(state->BSONInt64, "L", ll); + *position += 8; + break; + } + case 19: + { + if (max < 16) { + goto invalid; + } + PyObject *_bytes_obj = PyBytes_FromStringAndSize(buffer + *position, (Py_ssize_t)16); + if (!_bytes_obj) { + goto invalid; + } + value = PyObject_CallMethodObjArgs(state->Decimal128, state->_from_bid_str, _bytes_obj, NULL); + Py_DECREF(_bytes_obj); + *position += 16; + break; + } + case 255: + { + value = PyObject_CallFunctionObjArgs(state->MinKey, NULL); + break; + } + case 127: + { + value = PyObject_CallFunctionObjArgs(state->MaxKey, NULL); + break; + } + default: + { + PyObject* InvalidBSON = _error("InvalidBSON"); + if (InvalidBSON) { + PyObject* bobj = PyBytes_FromFormat("%c", type); + if (bobj) { + PyObject* repr = PyObject_Repr(bobj); + Py_DECREF(bobj); + /* + * See http://bugs.python.org/issue22023 for why we can't + * just use PyUnicode_FromFormat with %S or %R to do this + * work. + */ + if (repr) { + PyObject* left = PyUnicode_FromString( + "Detected unknown BSON type "); + if (left) { + PyObject* lmsg = PyUnicode_Concat(left, repr); + Py_DECREF(left); + if (lmsg) { + PyObject* errmsg = PyUnicode_FromFormat( + "%U for fieldname '%U'. Are you using the " + "latest driver version?", lmsg, name); + if (errmsg) { + PyErr_SetObject(InvalidBSON, errmsg); + Py_DECREF(errmsg); + } + Py_DECREF(lmsg); + } + } + Py_DECREF(repr); + } + } + Py_DECREF(InvalidBSON); + } + goto invalid; + } + } + + if (value) { + if (!options->type_registry.is_decoder_empty) { + PyObject* value_type = NULL; + PyObject* converter = NULL; + value_type = PyObject_Type(value); + if (value_type == NULL) { + goto invalid; + } + converter = PyDict_GetItem(options->type_registry.decoder_map, value_type); + if (converter != NULL) { + PyObject* new_value = PyObject_CallFunctionObjArgs(converter, value, NULL); + Py_DECREF(value_type); + Py_DECREF(value); + return new_value; + } else { + Py_DECREF(value_type); + return value; + } + } + return value; + } + + invalid: + + /* + * Wrap any non-InvalidBSON errors in InvalidBSON. + */ + if (PyErr_Occurred()) { + PyObject *etype, *evalue, *etrace; + PyObject *InvalidBSON; + + /* + * Calling _error clears the error state, so fetch it first. + */ + PyErr_Fetch(&etype, &evalue, &etrace); + + /* Dont reraise anything but PyExc_Exceptions as InvalidBSON. */ + if (PyErr_GivenExceptionMatches(etype, PyExc_Exception)) { + InvalidBSON = _error("InvalidBSON"); + if (InvalidBSON) { + if (!PyErr_GivenExceptionMatches(etype, InvalidBSON)) { + /* + * Raise InvalidBSON(str(e)). 
+ */ + Py_DECREF(etype); + etype = InvalidBSON; + + if (evalue) { + PyObject *msg = PyObject_Str(evalue); + Py_DECREF(evalue); + evalue = msg; + } + PyErr_NormalizeException(&etype, &evalue, &etrace); + } else { + /* + * The current exception matches InvalidBSON, so we don't + * need this reference after all. + */ + Py_DECREF(InvalidBSON); + } + } + } + /* Steals references to args. */ + PyErr_Restore(etype, evalue, etrace); + } else { + PyObject *InvalidBSON = _error("InvalidBSON"); + if (InvalidBSON) { + PyErr_SetString(InvalidBSON, "invalid length or type code"); + Py_DECREF(InvalidBSON); + } + } + return NULL; +} + +/* + * Get the next 'name' and 'value' from a document in a string, whose position + * is provided. + * + * Returns the position of the next element in the document, or -1 on error. + */ +static int _element_to_dict(PyObject* self, const char* string, + unsigned position, unsigned max, + const codec_options_t* options, + int raw_array, + PyObject** name, PyObject** value) { + unsigned char type = (unsigned char)string[position++]; + size_t name_length = strlen(string + position); + if (name_length > BSON_MAX_SIZE || position + name_length >= max) { + PyObject* InvalidBSON = _error("InvalidBSON"); + if (InvalidBSON) { + PyErr_SetString(InvalidBSON, "field name too large"); + Py_DECREF(InvalidBSON); + } + return -1; + } + *name = PyUnicode_DecodeUTF8( + string + position, name_length, + options->unicode_decode_error_handler); + if (!*name) { + /* If NULL is returned then wrap the UnicodeDecodeError + in an InvalidBSON error */ + PyObject *etype, *evalue, *etrace; + PyObject *InvalidBSON; + + PyErr_Fetch(&etype, &evalue, &etrace); + if (PyErr_GivenExceptionMatches(etype, PyExc_Exception)) { + InvalidBSON = _error("InvalidBSON"); + if (InvalidBSON) { + Py_DECREF(etype); + etype = InvalidBSON; + + if (evalue) { + PyObject *msg = PyObject_Str(evalue); + Py_DECREF(evalue); + evalue = msg; + } + PyErr_NormalizeException(&etype, &evalue, &etrace); + } + } + PyErr_Restore(etype, evalue, etrace); + return -1; + } + position += (unsigned)name_length + 1; + *value = get_value(self, *name, string, &position, type, + max - position, options, raw_array); + if (!*value) { + Py_DECREF(*name); + return -1; + } + return position; +} + +static PyObject* _cbson_element_to_dict(PyObject* self, PyObject* args) { + /* TODO: Support buffer protocol */ + char* string; + PyObject* bson; + PyObject* options_obj; + codec_options_t options; + unsigned position; + unsigned max; + int new_position; + int raw_array = 0; + PyObject* name; + PyObject* value; + PyObject* result_tuple; + + if (!(PyArg_ParseTuple(args, "OIIOp", &bson, &position, &max, + &options_obj, &raw_array) && + convert_codec_options(self, options_obj, &options))) { + return NULL; + } + + if (!PyBytes_Check(bson)) { + PyErr_SetString(PyExc_TypeError, "argument to _element_to_dict must be a bytes object"); + return NULL; + } + string = PyBytes_AS_STRING(bson); + + new_position = _element_to_dict(self, string, position, max, &options, raw_array, &name, &value); + if (new_position < 0) { + return NULL; + } + + result_tuple = Py_BuildValue("NNi", name, value, new_position); + if (!result_tuple) { + Py_DECREF(name); + Py_DECREF(value); + return NULL; + } + + destroy_codec_options(&options); + return result_tuple; +} + +static PyObject* _elements_to_dict(PyObject* self, const char* string, + unsigned max, + const codec_options_t* options) { + unsigned position = 0; + PyObject* dict = PyObject_CallObject(options->document_class, NULL); + if 
(!dict) { + return NULL; + } + int raw_array = 0; + while (position < max) { + PyObject* name = NULL; + PyObject* value = NULL; + int new_position; + + new_position = _element_to_dict( + self, string, position, max, options, raw_array, &name, &value); + if (new_position < 0) { + Py_DECREF(dict); + return NULL; + } else { + position = (unsigned)new_position; + } + + PyObject_SetItem(dict, name, value); + Py_DECREF(name); + Py_DECREF(value); + } + return dict; +} + +static PyObject* elements_to_dict(PyObject* self, const char* string, + unsigned max, + const codec_options_t* options) { + PyObject* result; + if (options->is_raw_bson) { + return PyObject_CallFunction( + options->document_class, "y#O", + string, max, options->options_obj); + } + if (Py_EnterRecursiveCall(" while decoding a BSON document")) + return NULL; + result = _elements_to_dict(self, string + 4, max - 5, options); + Py_LeaveRecursiveCall(); + return result; +} + +static int _get_buffer(PyObject *exporter, Py_buffer *view) { + if (PyObject_GetBuffer(exporter, view, PyBUF_SIMPLE) == -1) { + return 0; + } + if (!PyBuffer_IsContiguous(view, 'C')) { + PyErr_SetString(PyExc_ValueError, + "must be a contiguous buffer"); + goto fail; + } + if (!view->buf || view->len < 0) { + PyErr_SetString(PyExc_ValueError, "invalid buffer"); + goto fail; + } + if (view->itemsize != 1) { + PyErr_SetString(PyExc_ValueError, + "buffer data must be ascii or utf8"); + goto fail; + } + return 1; +fail: + PyBuffer_Release(view); + return 0; +} + +static PyObject* _cbson_bson_to_dict(PyObject* self, PyObject* args) { + int32_t size; + Py_ssize_t total_size; + const char* string; + PyObject* bson; + codec_options_t options; + PyObject* result = NULL; + PyObject* options_obj; + Py_buffer view = {0}; + + if (! (PyArg_ParseTuple(args, "OO", &bson, &options_obj) && + convert_codec_options(self, options_obj, &options))) { + return result; + } + + if (!_get_buffer(bson, &view)) { + destroy_codec_options(&options); + return result; + } + + total_size = view.len; + + if (total_size < BSON_MIN_SIZE) { + PyObject* InvalidBSON = _error("InvalidBSON"); + if (InvalidBSON) { + PyErr_SetString(InvalidBSON, + "not enough data for a BSON document"); + Py_DECREF(InvalidBSON); + } + goto done;; + } + + string = (char*)view.buf; + memcpy(&size, string, 4); + size = (int32_t)BSON_UINT32_FROM_LE(size); + if (size < BSON_MIN_SIZE) { + PyObject* InvalidBSON = _error("InvalidBSON"); + if (InvalidBSON) { + PyErr_SetString(InvalidBSON, "invalid message size"); + Py_DECREF(InvalidBSON); + } + goto done; + } + + if (total_size < size || total_size > BSON_MAX_SIZE) { + PyObject* InvalidBSON = _error("InvalidBSON"); + if (InvalidBSON) { + PyErr_SetString(InvalidBSON, "objsize too large"); + Py_DECREF(InvalidBSON); + } + goto done; + } + + if (size != total_size || string[size - 1]) { + PyObject* InvalidBSON = _error("InvalidBSON"); + if (InvalidBSON) { + PyErr_SetString(InvalidBSON, "bad eoo"); + Py_DECREF(InvalidBSON); + } + goto done; + } + + result = elements_to_dict(self, string, (unsigned)size, &options); +done: + PyBuffer_Release(&view); + destroy_codec_options(&options); + return result; +} + +static PyObject* _cbson_decode_all(PyObject* self, PyObject* args) { + int32_t size; + Py_ssize_t total_size; + const char* string; + PyObject* bson; + PyObject* dict; + PyObject* result = NULL; + codec_options_t options; + PyObject* options_obj = NULL; + Py_buffer view = {0}; + + if (!(PyArg_ParseTuple(args, "OO", &bson, &options_obj) && + convert_codec_options(self, options_obj, 
&options))) { + return NULL; + } + + if (!_get_buffer(bson, &view)) { + destroy_codec_options(&options); + return NULL; + } + total_size = view.len; + string = (char*)view.buf; + + if (!(result = PyList_New(0))) { + goto fail; + } + + while (total_size > 0) { + if (total_size < BSON_MIN_SIZE) { + PyObject* InvalidBSON = _error("InvalidBSON"); + if (InvalidBSON) { + PyErr_SetString(InvalidBSON, + "not enough data for a BSON document"); + Py_DECREF(InvalidBSON); + } + Py_DECREF(result); + goto fail; + } + + memcpy(&size, string, 4); + size = (int32_t)BSON_UINT32_FROM_LE(size); + if (size < BSON_MIN_SIZE) { + PyObject* InvalidBSON = _error("InvalidBSON"); + if (InvalidBSON) { + PyErr_SetString(InvalidBSON, "invalid message size"); + Py_DECREF(InvalidBSON); + } + Py_DECREF(result); + goto fail; + } + + if (total_size < size) { + PyObject* InvalidBSON = _error("InvalidBSON"); + if (InvalidBSON) { + PyErr_SetString(InvalidBSON, "objsize too large"); + Py_DECREF(InvalidBSON); + } + Py_DECREF(result); + goto fail; + } + + if (string[size - 1]) { + PyObject* InvalidBSON = _error("InvalidBSON"); + if (InvalidBSON) { + PyErr_SetString(InvalidBSON, "bad eoo"); + Py_DECREF(InvalidBSON); + } + Py_DECREF(result); + goto fail; + } + + dict = elements_to_dict(self, string, (unsigned)size, &options); + if (!dict) { + Py_DECREF(result); + goto fail; + } + if (PyList_Append(result, dict) < 0) { + Py_DECREF(dict); + Py_DECREF(result); + goto fail; + } + Py_DECREF(dict); + string += size; + total_size -= size; + } + goto done; +fail: + result = NULL; +done: + PyBuffer_Release(&view); + destroy_codec_options(&options); + return result; +} + + +static PyObject* _cbson_array_of_documents_to_buffer(PyObject* self, PyObject* args) { + uint32_t size; + uint32_t value_length; + uint32_t position = 0; + buffer_t buffer; + const char* string; + PyObject* arr; + PyObject* result = NULL; + Py_buffer view = {0}; + + if (!PyArg_ParseTuple(args, "O", &arr)) { + return NULL; + } + + if (!_get_buffer(arr, &view)) { + return NULL; + } + + buffer = pymongo_buffer_new(); + if (!buffer) { + PyBuffer_Release(&view); + return NULL; + } + + string = (char*)view.buf; + + if (view.len < BSON_MIN_SIZE) { + PyObject* InvalidBSON = _error("InvalidBSON"); + if (InvalidBSON) { + PyErr_SetString(InvalidBSON, + "not enough data for a BSON document"); + Py_DECREF(InvalidBSON); + } + goto done; + } + + memcpy(&size, string, 4); + size = BSON_UINT32_FROM_LE(size); + /* save space for length */ + if (pymongo_buffer_save_space(buffer, size) == -1) { + goto fail; + } + pymongo_buffer_update_position(buffer, 0); + + position += 4; + while (position < size - 1) { + // Verify the value is an object. + unsigned char type = (unsigned char)string[position]; + if (type != 3) { + PyObject* InvalidBSON = _error("InvalidBSON"); + if (InvalidBSON) { + PyErr_SetString(InvalidBSON, "array element was not an object"); + Py_DECREF(InvalidBSON); + } + goto fail; + } + + // Just skip the keys. 
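[Editor's note: `_cbson_bson_to_dict` and `_cbson_decode_all` enforce the same framing rules (minimum document size, declared int32 length matching the data, trailing NUL "eoo" byte) that surface in Python as `InvalidBSON`. A sketch of the Python-visible behavior, assuming the `bson` package in this diff is importable:]

```python
import bson
from bson.errors import InvalidBSON

a = bson.encode({"n": 1})
b = bson.encode({"n": 2})

# decode_all walks a buffer of concatenated documents, as the C loop above does.
assert bson.decode_all(a + b) == [{"n": 1}, {"n": 2}]

# Truncating the buffer breaks the framing checks ("bad eoo" / size mismatch).
assert not bson.is_valid(a[:-1])
try:
    bson.decode(a[:-1])
except InvalidBSON:
    pass  # non-InvalidBSON errors are also re-wrapped as InvalidBSON, per the code above

print(bson.has_c())  # True when the _cbson extension defined in this file compiled
```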
+ position = position + strlen(string + position) + 1; + + if (position >= size || (size - position) < BSON_MIN_SIZE) { + PyObject* InvalidBSON = _error("InvalidBSON"); + if (InvalidBSON) { + PyErr_SetString(InvalidBSON, "invalid array content"); + Py_DECREF(InvalidBSON); + } + goto fail; + } + + memcpy(&value_length, string + position, 4); + value_length = BSON_UINT32_FROM_LE(value_length); + if (value_length < BSON_MIN_SIZE) { + PyObject* InvalidBSON = _error("InvalidBSON"); + if (InvalidBSON) { + PyErr_SetString(InvalidBSON, "invalid message size"); + Py_DECREF(InvalidBSON); + } + goto fail; + } + + if (view.len < size) { + PyObject* InvalidBSON = _error("InvalidBSON"); + if (InvalidBSON) { + PyErr_SetString(InvalidBSON, "objsize too large"); + Py_DECREF(InvalidBSON); + } + goto fail; + } + + if (string[size - 1]) { + PyObject* InvalidBSON = _error("InvalidBSON"); + if (InvalidBSON) { + PyErr_SetString(InvalidBSON, "bad eoo"); + Py_DECREF(InvalidBSON); + } + goto fail; + } + + if (pymongo_buffer_write(buffer, string + position, value_length) == 1) { + goto fail; + } + position += value_length; + } + + /* objectify buffer */ + result = Py_BuildValue("y#", pymongo_buffer_get_buffer(buffer), + (Py_ssize_t)pymongo_buffer_get_position(buffer)); + goto done; +fail: + result = NULL; +done: + PyBuffer_Release(&view); + pymongo_buffer_free(buffer); + return result; +} + + +static PyMethodDef _CBSONMethods[] = { + {"_dict_to_bson", _cbson_dict_to_bson, METH_VARARGS, + "convert a dictionary to a string containing its BSON representation."}, + {"_bson_to_dict", _cbson_bson_to_dict, METH_VARARGS, + "convert a BSON string to a SON object."}, + {"_decode_all", _cbson_decode_all, METH_VARARGS, + "convert binary data to a sequence of documents."}, + {"_element_to_dict", _cbson_element_to_dict, METH_VARARGS, + "Decode a single key, value pair."}, + {"_array_of_documents_to_buffer", _cbson_array_of_documents_to_buffer, METH_VARARGS, "Convert raw array of documents to a stream of BSON documents"}, + {"_test_long_long_to_str", _test_long_long_to_str, METH_VARARGS, "Test conversion of extreme and common Py_ssize_t values to str."}, + {NULL, NULL, 0, NULL} +}; + +#define INITERROR return -1; +static int _cbson_traverse(PyObject *m, visitproc visit, void *arg) { + struct module_state *state = GETSTATE(m); + if (!state) { + return 0; + } + Py_VISIT(state->Binary); + Py_VISIT(state->Code); + Py_VISIT(state->ObjectId); + Py_VISIT(state->DBRef); + Py_VISIT(state->Regex); + Py_VISIT(state->UUID); + Py_VISIT(state->Timestamp); + Py_VISIT(state->MinKey); + Py_VISIT(state->MaxKey); + Py_VISIT(state->UTC); + Py_VISIT(state->REType); + Py_VISIT(state->_type_marker_str); + Py_VISIT(state->_flags_str); + Py_VISIT(state->_pattern_str); + Py_VISIT(state->_encoder_map_str); + Py_VISIT(state->_decoder_map_str); + Py_VISIT(state->_fallback_encoder_str); + Py_VISIT(state->_raw_str); + Py_VISIT(state->_subtype_str); + Py_VISIT(state->_binary_str); + Py_VISIT(state->_scope_str); + Py_VISIT(state->_inc_str); + Py_VISIT(state->_time_str); + Py_VISIT(state->_bid_str); + Py_VISIT(state->_replace_str); + Py_VISIT(state->_astimezone_str); + Py_VISIT(state->_id_str); + Py_VISIT(state->_dollar_ref_str); + Py_VISIT(state->_dollar_id_str); + Py_VISIT(state->_dollar_db_str); + Py_VISIT(state->_tzinfo_str); + Py_VISIT(state->_as_doc_str); + Py_VISIT(state->_utcoffset_str); + Py_VISIT(state->_from_uuid_str); + Py_VISIT(state->_as_uuid_str); + Py_VISIT(state->_from_bid_str); + return 0; +} + +static int _cbson_clear(PyObject *m) { + struct 
module_state *state = GETSTATE(m); + if (!state) { + return 0; + } + Py_CLEAR(state->Binary); + Py_CLEAR(state->Code); + Py_CLEAR(state->ObjectId); + Py_CLEAR(state->DBRef); + Py_CLEAR(state->Regex); + Py_CLEAR(state->UUID); + Py_CLEAR(state->Timestamp); + Py_CLEAR(state->MinKey); + Py_CLEAR(state->MaxKey); + Py_CLEAR(state->UTC); + Py_CLEAR(state->REType); + Py_CLEAR(state->_type_marker_str); + Py_CLEAR(state->_flags_str); + Py_CLEAR(state->_pattern_str); + Py_CLEAR(state->_encoder_map_str); + Py_CLEAR(state->_decoder_map_str); + Py_CLEAR(state->_fallback_encoder_str); + Py_CLEAR(state->_raw_str); + Py_CLEAR(state->_subtype_str); + Py_CLEAR(state->_binary_str); + Py_CLEAR(state->_scope_str); + Py_CLEAR(state->_inc_str); + Py_CLEAR(state->_time_str); + Py_CLEAR(state->_bid_str); + Py_CLEAR(state->_replace_str); + Py_CLEAR(state->_astimezone_str); + Py_CLEAR(state->_id_str); + Py_CLEAR(state->_dollar_ref_str); + Py_CLEAR(state->_dollar_id_str); + Py_CLEAR(state->_dollar_db_str); + Py_CLEAR(state->_tzinfo_str); + Py_CLEAR(state->_as_doc_str); + Py_CLEAR(state->_utcoffset_str); + Py_CLEAR(state->_from_uuid_str); + Py_CLEAR(state->_as_uuid_str); + Py_CLEAR(state->_from_bid_str); + return 0; +} + +/* Multi-phase extension module initialization code. + * See https://peps.python.org/pep-0489/. +*/ +static int +_cbson_exec(PyObject *m) +{ + PyObject *c_api_object; + static void *_cbson_API[_cbson_API_POINTER_COUNT]; + + PyDateTime_IMPORT; + if (PyDateTimeAPI == NULL) { + INITERROR; + } + + /* Export C API */ + _cbson_API[_cbson_buffer_write_bytes_INDEX] = (void *) buffer_write_bytes; + _cbson_API[_cbson_write_dict_INDEX] = (void *) write_dict; + _cbson_API[_cbson_write_pair_INDEX] = (void *) write_pair; + _cbson_API[_cbson_decode_and_write_pair_INDEX] = (void *) decode_and_write_pair; + _cbson_API[_cbson_convert_codec_options_INDEX] = (void *) convert_codec_options; + _cbson_API[_cbson_destroy_codec_options_INDEX] = (void *) destroy_codec_options; + _cbson_API[_cbson_buffer_write_double_INDEX] = (void *) buffer_write_double; + _cbson_API[_cbson_buffer_write_int32_INDEX] = (void *) buffer_write_int32; + _cbson_API[_cbson_buffer_write_int64_INDEX] = (void *) buffer_write_int64; + _cbson_API[_cbson_buffer_write_int32_at_position_INDEX] = + (void *) buffer_write_int32_at_position; + _cbson_API[_cbson_downcast_and_check_INDEX] = (void *) _downcast_and_check; + + c_api_object = PyCapsule_New((void *) _cbson_API, "_cbson._C_API", NULL); + if (c_api_object == NULL) + INITERROR; + + /* Import several python objects */ + if (_load_python_objects(m)) { + Py_DECREF(c_api_object); + Py_DECREF(m); + INITERROR; + } + + if (PyModule_AddObject(m, "_C_API", c_api_object) < 0) { + Py_DECREF(c_api_object); + Py_DECREF(m); + INITERROR; + } + + return 0; +} + +static PyModuleDef_Slot _cbson_slots[] = { + {Py_mod_exec, _cbson_exec}, +#if defined(Py_MOD_MULTIPLE_INTERPRETERS_SUPPORTED) + {Py_mod_multiple_interpreters, Py_MOD_MULTIPLE_INTERPRETERS_SUPPORTED}, +#endif + {0, NULL}, +}; + + +static struct PyModuleDef moduledef = { + PyModuleDef_HEAD_INIT, + "_cbson", + NULL, + sizeof(struct module_state), + _CBSONMethods, + _cbson_slots, + _cbson_traverse, + _cbson_clear, + NULL +}; + +PyMODINIT_FUNC +PyInit__cbson(void) +{ + return PyModuleDef_Init(&moduledef); +} diff --git a/venv/Lib/site-packages/bson/_cbsonmodule.h b/venv/Lib/site-packages/bson/_cbsonmodule.h new file mode 100644 index 00000000..3be2b744 --- /dev/null +++ b/venv/Lib/site-packages/bson/_cbsonmodule.h @@ -0,0 +1,181 @@ +/* + * Copyright 2009-present 
MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "bson-endian.h" + +#ifndef _CBSONMODULE_H +#define _CBSONMODULE_H + +#if defined(WIN32) || defined(_MSC_VER) +/* + * This macro is basically an implementation of asprintf for win32 + * We print to the provided buffer to get the string value as an int. + * USE LL2STR. This is kept only to test LL2STR. + */ +#if defined(_MSC_VER) && (_MSC_VER >= 1400) +#define INT2STRING(buffer, i) \ + _snprintf_s((buffer), \ + _scprintf("%lld", (i)) + 1, \ + _scprintf("%lld", (i)) + 1, \ + "%lld", \ + (i)) +#define STRCAT(dest, n, src) strcat_s((dest), (n), (src)) +#else +#define INT2STRING(buffer, i) \ + _snprintf((buffer), \ + _scprintf("%lld", (i)) + 1, \ + "%lld", \ + (i)) +#define STRCAT(dest, n, src) strcat((dest), (src)) +#endif +#else +#define INT2STRING(buffer, i) snprintf((buffer), sizeof((buffer)), "%lld", (i)) +#define STRCAT(dest, n, src) strcat((dest), (src)) +#endif + +/* Just enough space in char array to hold LLONG_MIN and null terminator */ +#define BUF_SIZE 21 +/* Converts integer to its string representation in decimal notation. */ +extern int cbson_long_long_to_str(long long int num, char* str, size_t size); +#define LL2STR(buffer, i) cbson_long_long_to_str((i), (buffer), sizeof(buffer)) + +typedef struct type_registry_t { + PyObject* encoder_map; + PyObject* decoder_map; + PyObject* fallback_encoder; + PyObject* registry_obj; + unsigned char is_encoder_empty; + unsigned char is_decoder_empty; + unsigned char has_fallback_encoder; +} type_registry_t; + +typedef struct codec_options_t { + PyObject* document_class; + unsigned char tz_aware; + unsigned char uuid_rep; + char* unicode_decode_error_handler; + PyObject* tzinfo; + type_registry_t type_registry; + unsigned char datetime_conversion; + PyObject* options_obj; + unsigned char is_raw_bson; +} codec_options_t; + +/* C API functions */ +#define _cbson_buffer_write_bytes_INDEX 0 +#define _cbson_buffer_write_bytes_RETURN int +#define _cbson_buffer_write_bytes_PROTO (buffer_t buffer, const char* data, int size) + +#define _cbson_write_dict_INDEX 1 +#define _cbson_write_dict_RETURN int +#define _cbson_write_dict_PROTO (PyObject* self, buffer_t buffer, PyObject* dict, unsigned char check_keys, const codec_options_t* options, unsigned char top_level) + +#define _cbson_write_pair_INDEX 2 +#define _cbson_write_pair_RETURN int +#define _cbson_write_pair_PROTO (PyObject* self, buffer_t buffer, const char* name, int name_length, PyObject* value, unsigned char check_keys, const codec_options_t* options, unsigned char allow_id) + +#define _cbson_decode_and_write_pair_INDEX 3 +#define _cbson_decode_and_write_pair_RETURN int +#define _cbson_decode_and_write_pair_PROTO (PyObject* self, buffer_t buffer, PyObject* key, PyObject* value, unsigned char check_keys, const codec_options_t* options, unsigned char top_level) + +#define _cbson_convert_codec_options_INDEX 4 +#define _cbson_convert_codec_options_RETURN int +#define _cbson_convert_codec_options_PROTO (PyObject* 
self, PyObject* options_obj, codec_options_t* options) + +#define _cbson_destroy_codec_options_INDEX 5 +#define _cbson_destroy_codec_options_RETURN void +#define _cbson_destroy_codec_options_PROTO (codec_options_t* options) + +#define _cbson_buffer_write_double_INDEX 6 +#define _cbson_buffer_write_double_RETURN int +#define _cbson_buffer_write_double_PROTO (buffer_t buffer, double data) + +#define _cbson_buffer_write_int32_INDEX 7 +#define _cbson_buffer_write_int32_RETURN int +#define _cbson_buffer_write_int32_PROTO (buffer_t buffer, int32_t data) + +#define _cbson_buffer_write_int64_INDEX 8 +#define _cbson_buffer_write_int64_RETURN int +#define _cbson_buffer_write_int64_PROTO (buffer_t buffer, int64_t data) + +#define _cbson_buffer_write_int32_at_position_INDEX 9 +#define _cbson_buffer_write_int32_at_position_RETURN void +#define _cbson_buffer_write_int32_at_position_PROTO (buffer_t buffer, int position, int32_t data) + +#define _cbson_downcast_and_check_INDEX 10 +#define _cbson_downcast_and_check_RETURN int +#define _cbson_downcast_and_check_PROTO (Py_ssize_t size, uint8_t extra) + +/* Total number of C API pointers */ +#define _cbson_API_POINTER_COUNT 11 + +#ifdef _CBSON_MODULE +/* This section is used when compiling _cbsonmodule */ + +static _cbson_buffer_write_bytes_RETURN buffer_write_bytes _cbson_buffer_write_bytes_PROTO; + +static _cbson_write_dict_RETURN write_dict _cbson_write_dict_PROTO; + +static _cbson_write_pair_RETURN write_pair _cbson_write_pair_PROTO; + +static _cbson_decode_and_write_pair_RETURN decode_and_write_pair _cbson_decode_and_write_pair_PROTO; + +static _cbson_convert_codec_options_RETURN convert_codec_options _cbson_convert_codec_options_PROTO; + +static _cbson_destroy_codec_options_RETURN destroy_codec_options _cbson_destroy_codec_options_PROTO; + +static _cbson_buffer_write_double_RETURN buffer_write_double _cbson_buffer_write_double_PROTO; + +static _cbson_buffer_write_int32_RETURN buffer_write_int32 _cbson_buffer_write_int32_PROTO; + +static _cbson_buffer_write_int64_RETURN buffer_write_int64 _cbson_buffer_write_int64_PROTO; + +static _cbson_buffer_write_int32_at_position_RETURN buffer_write_int32_at_position _cbson_buffer_write_int32_at_position_PROTO; + +static _cbson_downcast_and_check_RETURN _downcast_and_check _cbson_downcast_and_check_PROTO; + +#else +/* This section is used in modules that use _cbsonmodule's API */ + +static void **_cbson_API; + +#define buffer_write_bytes (*(_cbson_buffer_write_bytes_RETURN (*)_cbson_buffer_write_bytes_PROTO) _cbson_API[_cbson_buffer_write_bytes_INDEX]) + +#define write_dict (*(_cbson_write_dict_RETURN (*)_cbson_write_dict_PROTO) _cbson_API[_cbson_write_dict_INDEX]) + +#define write_pair (*(_cbson_write_pair_RETURN (*)_cbson_write_pair_PROTO) _cbson_API[_cbson_write_pair_INDEX]) + +#define decode_and_write_pair (*(_cbson_decode_and_write_pair_RETURN (*)_cbson_decode_and_write_pair_PROTO) _cbson_API[_cbson_decode_and_write_pair_INDEX]) + +#define convert_codec_options (*(_cbson_convert_codec_options_RETURN (*)_cbson_convert_codec_options_PROTO) _cbson_API[_cbson_convert_codec_options_INDEX]) + +#define destroy_codec_options (*(_cbson_destroy_codec_options_RETURN (*)_cbson_destroy_codec_options_PROTO) _cbson_API[_cbson_destroy_codec_options_INDEX]) + +#define buffer_write_double (*(_cbson_buffer_write_double_RETURN (*)_cbson_buffer_write_double_PROTO) _cbson_API[_cbson_buffer_write_double_INDEX]) + +#define buffer_write_int32 (*(_cbson_buffer_write_int32_RETURN (*)_cbson_buffer_write_int32_PROTO) 
_cbson_API[_cbson_buffer_write_int32_INDEX]) + +#define buffer_write_int64 (*(_cbson_buffer_write_int64_RETURN (*)_cbson_buffer_write_int64_PROTO) _cbson_API[_cbson_buffer_write_int64_INDEX]) + +#define buffer_write_int32_at_position (*(_cbson_buffer_write_int32_at_position_RETURN (*)_cbson_buffer_write_int32_at_position_PROTO) _cbson_API[_cbson_buffer_write_int32_at_position_INDEX]) + +#define _downcast_and_check (*(_cbson_downcast_and_check_RETURN (*)_cbson_downcast_and_check_PROTO) _cbson_API[_cbson_downcast_and_check_INDEX]) + +#define _cbson_IMPORT _cbson_API = (void **)PyCapsule_Import("_cbson._C_API", 0) + +#endif + +#endif // _CBSONMODULE_H diff --git a/venv/Lib/site-packages/bson/_helpers.py b/venv/Lib/site-packages/bson/_helpers.py new file mode 100644 index 00000000..5a479867 --- /dev/null +++ b/venv/Lib/site-packages/bson/_helpers.py @@ -0,0 +1,43 @@ +# Copyright 2021-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Setstate and getstate functions for objects with __slots__, allowing +compatibility with default pickling protocol +""" +from __future__ import annotations + +from typing import Any, Mapping + + +def _setstate_slots(self: Any, state: Any) -> None: + for slot, value in state.items(): + setattr(self, slot, value) + + +def _mangle_name(name: str, prefix: str) -> str: + if name.startswith("__"): + prefix = "_" + prefix + else: + prefix = "" + return prefix + name + + +def _getstate_slots(self: Any) -> Mapping[Any, Any]: + prefix = self.__class__.__name__ + ret = {} + for name in self.__slots__: + mangled_name = _mangle_name(name, prefix) + if hasattr(self, mangled_name): + ret[mangled_name] = getattr(self, mangled_name) + return ret diff --git a/venv/Lib/site-packages/bson/binary.py b/venv/Lib/site-packages/bson/binary.py new file mode 100644 index 00000000..be334644 --- /dev/null +++ b/venv/Lib/site-packages/bson/binary.py @@ -0,0 +1,367 @@ +# Copyright 2009-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Tuple, Type, Union +from uuid import UUID + +"""Tools for representing BSON binary data. +""" + +BINARY_SUBTYPE = 0 +"""BSON binary subtype for binary data. + +This is the default subtype for binary data. +""" + +FUNCTION_SUBTYPE = 1 +"""BSON binary subtype for functions. +""" + +OLD_BINARY_SUBTYPE = 2 +"""Old BSON binary subtype for binary data. + +This is the old default subtype, the current +default is :data:`BINARY_SUBTYPE`. 
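[Editor's note: the `_getstate_slots`/`_setstate_slots` pair in `_helpers.py` above gives `__slots__`-only classes a pickleable state mapping, with `_mangle_name` reproducing Python's name mangling for double-underscore slot names. A hedged sketch of the pattern; the `Point` class is hypothetical and `bson._helpers` is a private module, so this is for illustration only:]

```python
import pickle

from bson._helpers import _getstate_slots, _setstate_slots


class Point:  # hypothetical demo class; bson applies this pattern to its own slotted types
    __slots__ = ("x", "__tag")
    __getstate__ = _getstate_slots
    __setstate__ = _setstate_slots

    def __init__(self, x, tag):
        self.x = x
        self.__tag = tag  # stored under the mangled slot name "_Point__tag"


p = pickle.loads(pickle.dumps(Point(1, "a")))
assert p.x == 1  # state survived the round trip without a __dict__
```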
+""" + +OLD_UUID_SUBTYPE = 3 +"""Old BSON binary subtype for a UUID. + +:class:`uuid.UUID` instances will automatically be encoded +by :mod:`bson` using this subtype when using +:data:`UuidRepresentation.PYTHON_LEGACY`, +:data:`UuidRepresentation.JAVA_LEGACY`, or +:data:`UuidRepresentation.CSHARP_LEGACY`. + +.. versionadded:: 2.1 +""" + +UUID_SUBTYPE = 4 +"""BSON binary subtype for a UUID. + +This is the standard BSON binary subtype for UUIDs. +:class:`uuid.UUID` instances will automatically be encoded +by :mod:`bson` using this subtype when using +:data:`UuidRepresentation.STANDARD`. +""" + + +if TYPE_CHECKING: + from array import array as _array + from mmap import mmap as _mmap + + +class UuidRepresentation: + UNSPECIFIED = 0 + """An unspecified UUID representation. + + When configured, :class:`uuid.UUID` instances will **not** be + automatically encoded to or decoded from :class:`~bson.binary.Binary`. + When encoding a :class:`uuid.UUID` instance, an error will be raised. + To encode a :class:`uuid.UUID` instance with this configuration, it must + be wrapped in the :class:`~bson.binary.Binary` class by the application + code. When decoding a BSON binary field with a UUID subtype, a + :class:`~bson.binary.Binary` instance will be returned instead of a + :class:`uuid.UUID` instance. + + See :ref:`unspecified-representation-details` for details. + + .. versionadded:: 3.11 + """ + + STANDARD = UUID_SUBTYPE + """The standard UUID representation. + + :class:`uuid.UUID` instances will automatically be encoded to + and decoded from BSON binary, using RFC-4122 byte order with + binary subtype :data:`UUID_SUBTYPE`. + + See :ref:`standard-representation-details` for details. + + .. versionadded:: 3.11 + """ + + PYTHON_LEGACY = OLD_UUID_SUBTYPE + """The Python legacy UUID representation. + + :class:`uuid.UUID` instances will automatically be encoded to + and decoded from BSON binary, using RFC-4122 byte order with + binary subtype :data:`OLD_UUID_SUBTYPE`. + + See :ref:`python-legacy-representation-details` for details. + + .. versionadded:: 3.11 + """ + + JAVA_LEGACY = 5 + """The Java legacy UUID representation. + + :class:`uuid.UUID` instances will automatically be encoded to + and decoded from BSON binary subtype :data:`OLD_UUID_SUBTYPE`, + using the Java driver's legacy byte order. + + See :ref:`java-legacy-representation-details` for details. + + .. versionadded:: 3.11 + """ + + CSHARP_LEGACY = 6 + """The C#/.net legacy UUID representation. + + :class:`uuid.UUID` instances will automatically be encoded to + and decoded from BSON binary subtype :data:`OLD_UUID_SUBTYPE`, + using the C# driver's legacy byte order. + + See :ref:`csharp-legacy-representation-details` for details. + + .. versionadded:: 3.11 + """ + + +STANDARD = UuidRepresentation.STANDARD +"""An alias for :data:`UuidRepresentation.STANDARD`. + +.. versionadded:: 3.0 +""" + +PYTHON_LEGACY = UuidRepresentation.PYTHON_LEGACY +"""An alias for :data:`UuidRepresentation.PYTHON_LEGACY`. + +.. versionadded:: 3.0 +""" + +JAVA_LEGACY = UuidRepresentation.JAVA_LEGACY +"""An alias for :data:`UuidRepresentation.JAVA_LEGACY`. + +.. versionchanged:: 3.6 + BSON binary subtype 4 is decoded using RFC-4122 byte order. +.. versionadded:: 2.3 +""" + +CSHARP_LEGACY = UuidRepresentation.CSHARP_LEGACY +"""An alias for :data:`UuidRepresentation.CSHARP_LEGACY`. + +.. versionchanged:: 3.6 + BSON binary subtype 4 is decoded using RFC-4122 byte order. +.. 
versionadded:: 2.3 +""" + +ALL_UUID_SUBTYPES = (OLD_UUID_SUBTYPE, UUID_SUBTYPE) +ALL_UUID_REPRESENTATIONS = ( + UuidRepresentation.UNSPECIFIED, + UuidRepresentation.STANDARD, + UuidRepresentation.PYTHON_LEGACY, + UuidRepresentation.JAVA_LEGACY, + UuidRepresentation.CSHARP_LEGACY, +) +UUID_REPRESENTATION_NAMES = { + UuidRepresentation.UNSPECIFIED: "UuidRepresentation.UNSPECIFIED", + UuidRepresentation.STANDARD: "UuidRepresentation.STANDARD", + UuidRepresentation.PYTHON_LEGACY: "UuidRepresentation.PYTHON_LEGACY", + UuidRepresentation.JAVA_LEGACY: "UuidRepresentation.JAVA_LEGACY", + UuidRepresentation.CSHARP_LEGACY: "UuidRepresentation.CSHARP_LEGACY", +} + +MD5_SUBTYPE = 5 +"""BSON binary subtype for an MD5 hash. +""" + +COLUMN_SUBTYPE = 7 +"""BSON binary subtype for columns. + +.. versionadded:: 4.0 +""" + +SENSITIVE_SUBTYPE = 8 +"""BSON binary subtype for sensitive data. + +.. versionadded:: 4.5 +""" + + +USER_DEFINED_SUBTYPE = 128 +"""BSON binary subtype for any user defined structure. +""" + + +class Binary(bytes): + """Representation of BSON binary data. + + This is necessary because we want to represent Python strings as + the BSON string type. We need to wrap binary data so we can tell + the difference between what should be considered binary data and + what should be considered a string when we encode to BSON. + + Raises TypeError if `data` is not an instance of :class:`bytes` + or `subtype` is not an instance of :class:`int`. + Raises ValueError if `subtype` is not in [0, 256). + + .. note:: + Instances of Binary with subtype 0 will be decoded directly to :class:`bytes`. + + :param data: the binary data to represent. Can be any bytes-like type + that implements the buffer protocol. + :param subtype: the `binary subtype + `_ + to use + + .. versionchanged:: 3.9 + Support any bytes-like type that implements the buffer protocol. + """ + + _type_marker = 5 + __subtype: int + + def __new__( + cls: Type[Binary], + data: Union[memoryview, bytes, _mmap, _array[Any]], + subtype: int = BINARY_SUBTYPE, + ) -> Binary: + if not isinstance(subtype, int): + raise TypeError("subtype must be an instance of int") + if subtype >= 256 or subtype < 0: + raise ValueError("subtype must be contained in [0, 256)") + # Support any type that implements the buffer protocol. + self = bytes.__new__(cls, memoryview(data).tobytes()) + self.__subtype = subtype + return self + + @classmethod + def from_uuid( + cls: Type[Binary], uuid: UUID, uuid_representation: int = UuidRepresentation.STANDARD + ) -> Binary: + """Create a BSON Binary object from a Python UUID. + + Creates a :class:`~bson.binary.Binary` object from a + :class:`uuid.UUID` instance. Assumes that the native + :class:`uuid.UUID` instance uses the byte-order implied by the + provided ``uuid_representation``. + + Raises :exc:`TypeError` if `uuid` is not an instance of + :class:`~uuid.UUID`. + + :param uuid: A :class:`uuid.UUID` instance. + :param uuid_representation: A member of + :class:`~bson.binary.UuidRepresentation`. Default: + :const:`~bson.binary.UuidRepresentation.STANDARD`. + See :ref:`handling-uuid-data-example` for details. + + .. 
versionadded:: 3.11 + """ + if not isinstance(uuid, UUID): + raise TypeError("uuid must be an instance of uuid.UUID") + + if uuid_representation not in ALL_UUID_REPRESENTATIONS: + raise ValueError( + "uuid_representation must be a value from bson.binary.UuidRepresentation" + ) + + if uuid_representation == UuidRepresentation.UNSPECIFIED: + raise ValueError( + "cannot encode native uuid.UUID with " + "UuidRepresentation.UNSPECIFIED. UUIDs can be manually " + "converted to bson.Binary instances using " + "bson.Binary.from_uuid() or a different UuidRepresentation " + "can be configured. See the documentation for " + "UuidRepresentation for more information." + ) + + subtype = OLD_UUID_SUBTYPE + if uuid_representation == UuidRepresentation.PYTHON_LEGACY: + payload = uuid.bytes + elif uuid_representation == UuidRepresentation.JAVA_LEGACY: + from_uuid = uuid.bytes + payload = from_uuid[0:8][::-1] + from_uuid[8:16][::-1] + elif uuid_representation == UuidRepresentation.CSHARP_LEGACY: + payload = uuid.bytes_le + else: + # uuid_representation == UuidRepresentation.STANDARD + subtype = UUID_SUBTYPE + payload = uuid.bytes + + return cls(payload, subtype) + + def as_uuid(self, uuid_representation: int = UuidRepresentation.STANDARD) -> UUID: + """Create a Python UUID from this BSON Binary object. + + Decodes this binary object as a native :class:`uuid.UUID` instance + with the provided ``uuid_representation``. + + Raises :exc:`ValueError` if this :class:`~bson.binary.Binary` instance + does not contain a UUID. + + :param uuid_representation: A member of + :class:`~bson.binary.UuidRepresentation`. Default: + :const:`~bson.binary.UuidRepresentation.STANDARD`. + See :ref:`handling-uuid-data-example` for details. + + .. versionadded:: 3.11 + """ + if self.subtype not in ALL_UUID_SUBTYPES: + raise ValueError(f"cannot decode subtype {self.subtype} as a uuid") + + if uuid_representation not in ALL_UUID_REPRESENTATIONS: + raise ValueError( + "uuid_representation must be a value from bson.binary.UuidRepresentation" + ) + + if uuid_representation == UuidRepresentation.UNSPECIFIED: + raise ValueError("uuid_representation cannot be UNSPECIFIED") + elif uuid_representation == UuidRepresentation.PYTHON_LEGACY: + if self.subtype == OLD_UUID_SUBTYPE: + return UUID(bytes=self) + elif uuid_representation == UuidRepresentation.JAVA_LEGACY: + if self.subtype == OLD_UUID_SUBTYPE: + return UUID(bytes=self[0:8][::-1] + self[8:16][::-1]) + elif uuid_representation == UuidRepresentation.CSHARP_LEGACY: + if self.subtype == OLD_UUID_SUBTYPE: + return UUID(bytes_le=self) + else: + # uuid_representation == UuidRepresentation.STANDARD + if self.subtype == UUID_SUBTYPE: + return UUID(bytes=self) + + raise ValueError( + f"cannot decode subtype {self.subtype} to {UUID_REPRESENTATION_NAMES[uuid_representation]}" + ) + + @property + def subtype(self) -> int: + """Subtype of this binary data.""" + return self.__subtype + + def __getnewargs__(self) -> Tuple[bytes, int]: # type: ignore[override] + # Work around http://bugs.python.org/issue7382 + data = super().__getnewargs__()[0] + if not isinstance(data, bytes): + data = data.encode("latin-1") + return data, self.__subtype + + def __eq__(self, other: Any) -> bool: + if isinstance(other, Binary): + return (self.__subtype, bytes(self)) == (other.subtype, bytes(other)) + # We don't return NotImplemented here because if we did then + # Binary("foo") == "foo" would return True, since Binary is a + # subclass of str... 
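[Editor's note: `from_uuid`/`as_uuid` make the byte-order differences between representations concrete: STANDARD keeps RFC-4122 order under subtype 4, JAVA_LEGACY reverses each 8-byte half under subtype 3, and CSHARP_LEGACY uses `uuid.bytes_le`. A usage sketch:]

```python
import uuid

from bson.binary import Binary, UuidRepresentation

u = uuid.uuid4()

std = Binary.from_uuid(u)  # defaults to UuidRepresentation.STANDARD, subtype 4
assert std.subtype == 4 and std.as_uuid() == u

java = Binary.from_uuid(u, UuidRepresentation.JAVA_LEGACY)  # subtype 3
assert java.subtype == 3
assert bytes(java) == u.bytes[0:8][::-1] + u.bytes[8:16][::-1]  # halves reversed
assert java.as_uuid(UuidRepresentation.JAVA_LEGACY) == u
```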
+ return False + + def __hash__(self) -> int: + return super().__hash__() ^ hash(self.__subtype) + + def __ne__(self, other: Any) -> bool: + return not self == other + + def __repr__(self) -> str: + return f"Binary({bytes.__repr__(self)}, {self.__subtype})" diff --git a/venv/Lib/site-packages/bson/bson-endian.h b/venv/Lib/site-packages/bson/bson-endian.h new file mode 100644 index 00000000..e906b077 --- /dev/null +++ b/venv/Lib/site-packages/bson/bson-endian.h @@ -0,0 +1,233 @@ +/* + * Copyright 2013-2016 MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +#ifndef BSON_ENDIAN_H +#define BSON_ENDIAN_H + + +#if defined(__sun) +# include +#endif + + +#ifdef _MSC_VER +# define BSON_INLINE __inline +#else +# include +# define BSON_INLINE __inline__ +#endif + + +#define BSON_BIG_ENDIAN 4321 +#define BSON_LITTLE_ENDIAN 1234 + + +/* WORDS_BIGENDIAN from pyconfig.h / Python.h */ +#ifdef WORDS_BIGENDIAN +# define BSON_BYTE_ORDER BSON_BIG_ENDIAN +#else +# define BSON_BYTE_ORDER BSON_LITTLE_ENDIAN +#endif + + +#if defined(__sun) +# define BSON_UINT16_SWAP_LE_BE(v) BSWAP_16((uint16_t)v) +# define BSON_UINT32_SWAP_LE_BE(v) BSWAP_32((uint32_t)v) +# define BSON_UINT64_SWAP_LE_BE(v) BSWAP_64((uint64_t)v) +#elif defined(__clang__) && defined(__clang_major__) && defined(__clang_minor__) && \ + (__clang_major__ >= 3) && (__clang_minor__ >= 1) +# if __has_builtin(__builtin_bswap16) +# define BSON_UINT16_SWAP_LE_BE(v) __builtin_bswap16(v) +# endif +# if __has_builtin(__builtin_bswap32) +# define BSON_UINT32_SWAP_LE_BE(v) __builtin_bswap32(v) +# endif +# if __has_builtin(__builtin_bswap64) +# define BSON_UINT64_SWAP_LE_BE(v) __builtin_bswap64(v) +# endif +#elif defined(__GNUC__) && (__GNUC__ >= 4) +# if __GNUC__ >= 4 && defined (__GNUC_MINOR__) && __GNUC_MINOR__ >= 3 +# define BSON_UINT32_SWAP_LE_BE(v) __builtin_bswap32 ((uint32_t)v) +# define BSON_UINT64_SWAP_LE_BE(v) __builtin_bswap64 ((uint64_t)v) +# endif +# if __GNUC__ >= 4 && defined (__GNUC_MINOR__) && __GNUC_MINOR__ >= 8 +# define BSON_UINT16_SWAP_LE_BE(v) __builtin_bswap16 ((uint32_t)v) +# endif +#endif + + +#ifndef BSON_UINT16_SWAP_LE_BE +# define BSON_UINT16_SWAP_LE_BE(v) __bson_uint16_swap_slow ((uint16_t)v) +#endif + + +#ifndef BSON_UINT32_SWAP_LE_BE +# define BSON_UINT32_SWAP_LE_BE(v) __bson_uint32_swap_slow ((uint32_t)v) +#endif + + +#ifndef BSON_UINT64_SWAP_LE_BE +# define BSON_UINT64_SWAP_LE_BE(v) __bson_uint64_swap_slow ((uint64_t)v) +#endif + + +#if BSON_BYTE_ORDER == BSON_LITTLE_ENDIAN +# define BSON_UINT16_FROM_LE(v) ((uint16_t)v) +# define BSON_UINT16_TO_LE(v) ((uint16_t)v) +# define BSON_UINT16_FROM_BE(v) BSON_UINT16_SWAP_LE_BE (v) +# define BSON_UINT16_TO_BE(v) BSON_UINT16_SWAP_LE_BE (v) +# define BSON_UINT32_FROM_LE(v) ((uint32_t)v) +# define BSON_UINT32_TO_LE(v) ((uint32_t)v) +# define BSON_UINT32_FROM_BE(v) BSON_UINT32_SWAP_LE_BE (v) +# define BSON_UINT32_TO_BE(v) BSON_UINT32_SWAP_LE_BE (v) +# define BSON_UINT64_FROM_LE(v) ((uint64_t)v) +# define BSON_UINT64_TO_LE(v) ((uint64_t)v) +# define 
BSON_UINT64_FROM_BE(v) BSON_UINT64_SWAP_LE_BE (v) +# define BSON_UINT64_TO_BE(v) BSON_UINT64_SWAP_LE_BE (v) +# define BSON_DOUBLE_FROM_LE(v) ((double)v) +# define BSON_DOUBLE_TO_LE(v) ((double)v) +#elif BSON_BYTE_ORDER == BSON_BIG_ENDIAN +# define BSON_UINT16_FROM_LE(v) BSON_UINT16_SWAP_LE_BE (v) +# define BSON_UINT16_TO_LE(v) BSON_UINT16_SWAP_LE_BE (v) +# define BSON_UINT16_FROM_BE(v) ((uint16_t)v) +# define BSON_UINT16_TO_BE(v) ((uint16_t)v) +# define BSON_UINT32_FROM_LE(v) BSON_UINT32_SWAP_LE_BE (v) +# define BSON_UINT32_TO_LE(v) BSON_UINT32_SWAP_LE_BE (v) +# define BSON_UINT32_FROM_BE(v) ((uint32_t)v) +# define BSON_UINT32_TO_BE(v) ((uint32_t)v) +# define BSON_UINT64_FROM_LE(v) BSON_UINT64_SWAP_LE_BE (v) +# define BSON_UINT64_TO_LE(v) BSON_UINT64_SWAP_LE_BE (v) +# define BSON_UINT64_FROM_BE(v) ((uint64_t)v) +# define BSON_UINT64_TO_BE(v) ((uint64_t)v) +# define BSON_DOUBLE_FROM_LE(v) (__bson_double_swap_slow (v)) +# define BSON_DOUBLE_TO_LE(v) (__bson_double_swap_slow (v)) +#else +# error "The endianness of target architecture is unknown." +#endif + + +/* + *-------------------------------------------------------------------------- + * + * __bson_uint16_swap_slow -- + * + * Fallback endianness conversion for 16-bit integers. + * + * Returns: + * The endian swapped version. + * + * Side effects: + * None. + * + *-------------------------------------------------------------------------- + */ + +static BSON_INLINE uint16_t +__bson_uint16_swap_slow (uint16_t v) /* IN */ +{ + return ((v & 0x00FF) << 8) | + ((v & 0xFF00) >> 8); +} + + +/* + *-------------------------------------------------------------------------- + * + * __bson_uint32_swap_slow -- + * + * Fallback endianness conversion for 32-bit integers. + * + * Returns: + * The endian swapped version. + * + * Side effects: + * None. + * + *-------------------------------------------------------------------------- + */ + +static BSON_INLINE uint32_t +__bson_uint32_swap_slow (uint32_t v) /* IN */ +{ + return ((v & 0x000000FFU) << 24) | + ((v & 0x0000FF00U) << 8) | + ((v & 0x00FF0000U) >> 8) | + ((v & 0xFF000000U) >> 24); +} + + +/* + *-------------------------------------------------------------------------- + * + * __bson_uint64_swap_slow -- + * + * Fallback endianness conversion for 64-bit integers. + * + * Returns: + * The endian swapped version. + * + * Side effects: + * None. + * + *-------------------------------------------------------------------------- + */ + +static BSON_INLINE uint64_t +__bson_uint64_swap_slow (uint64_t v) /* IN */ +{ + return ((v & 0x00000000000000FFULL) << 56) | + ((v & 0x000000000000FF00ULL) << 40) | + ((v & 0x0000000000FF0000ULL) << 24) | + ((v & 0x00000000FF000000ULL) << 8) | + ((v & 0x000000FF00000000ULL) >> 8) | + ((v & 0x0000FF0000000000ULL) >> 24) | + ((v & 0x00FF000000000000ULL) >> 40) | + ((v & 0xFF00000000000000ULL) >> 56); +} + + +/* + *-------------------------------------------------------------------------- + * + * __bson_double_swap_slow -- + * + * Fallback endianness conversion for double floating point. + * + * Returns: + * The endian swapped version. + * + * Side effects: + * None. 
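[Editor's note: these swap helpers exist because BSON is defined as little-endian on the wire regardless of host byte order; on little-endian hosts the `BSON_UINT32_FROM_LE`-style macros compile to no-ops. The wire format can be checked from Python with `struct` and explicit `<` (little-endian) format codes, a sketch assuming the `bson` package is importable:]

```python
import struct

import bson

data = bson.encode({"n": 1})

# The first four bytes are the document's total length as a little-endian int32,
# and every document terminates with a NUL byte (the "eoo" the C code validates).
assert struct.unpack_from("<i", data, 0)[0] == len(data)
assert data[-1] == 0
```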
+ * + *-------------------------------------------------------------------------- + */ + + +static BSON_INLINE double +__bson_double_swap_slow (double v) /* IN */ +{ + uint64_t uv; + + memcpy(&uv, &v, sizeof(v)); + uv = BSON_UINT64_SWAP_LE_BE(uv); + memcpy(&v, &uv, sizeof(v)); + + return v; +} + + +#endif /* BSON_ENDIAN_H */ diff --git a/venv/Lib/site-packages/bson/buffer.c b/venv/Lib/site-packages/bson/buffer.c new file mode 100644 index 00000000..cc752027 --- /dev/null +++ b/venv/Lib/site-packages/bson/buffer.c @@ -0,0 +1,157 @@ +/* + * Copyright 2009-2015 MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* Include Python.h so we can set Python's error indicator. */ +#define PY_SSIZE_T_CLEAN +#include "Python.h" + +#include +#include + +#include "buffer.h" + +#define INITIAL_BUFFER_SIZE 256 + +struct buffer { + char* buffer; + int size; + int position; +}; + +/* Set Python's error indicator to MemoryError. + * Called after allocation failures. */ +static void set_memory_error(void) { + PyErr_NoMemory(); +} + +/* Allocate and return a new buffer. + * Return NULL and sets MemoryError on allocation failure. */ +buffer_t pymongo_buffer_new(void) { + buffer_t buffer; + buffer = (buffer_t)malloc(sizeof(struct buffer)); + if (buffer == NULL) { + set_memory_error(); + return NULL; + } + + buffer->size = INITIAL_BUFFER_SIZE; + buffer->position = 0; + buffer->buffer = (char*)malloc(sizeof(char) * INITIAL_BUFFER_SIZE); + if (buffer->buffer == NULL) { + free(buffer); + set_memory_error(); + return NULL; + } + + return buffer; +} + +/* Free the memory allocated for `buffer`. + * Return non-zero on failure. */ +int pymongo_buffer_free(buffer_t buffer) { + if (buffer == NULL) { + return 1; + } + /* Buffer will be NULL when buffer_grow fails. */ + if (buffer->buffer != NULL) { + free(buffer->buffer); + } + free(buffer); + return 0; +} + +/* Grow `buffer` to at least `min_length`. + * Return non-zero and sets MemoryError on allocation failure. */ +static int buffer_grow(buffer_t buffer, int min_length) { + int old_size = 0; + int size = buffer->size; + char* old_buffer = buffer->buffer; + if (size >= min_length) { + return 0; + } + while (size < min_length) { + old_size = size; + size *= 2; + if (size <= old_size) { + /* Size did not increase. Could be an overflow + * or size < 1. Just go with min_length. */ + size = min_length; + } + } + buffer->buffer = (char*)realloc(buffer->buffer, sizeof(char) * size); + if (buffer->buffer == NULL) { + free(old_buffer); + set_memory_error(); + return 1; + } + buffer->size = size; + return 0; +} + +/* Assure that `buffer` has at least `size` free bytes (and grow if needed). + * Return non-zero and sets MemoryError on allocation failure. + * Return non-zero and sets ValueError if `size` would exceed 2GiB. */ +static int buffer_assure_space(buffer_t buffer, int size) { + int new_size = buffer->position + size; + /* Check for overflow. 
*/ + if (new_size < buffer->position) { + PyErr_SetString(PyExc_ValueError, + "Document would overflow BSON size limit"); + return 1; + } + + if (new_size <= buffer->size) { + return 0; + } + return buffer_grow(buffer, new_size); +} + +/* Save `size` bytes from the current position in `buffer` (and grow if needed). + * Return offset for writing, or -1 on failure. + * Sets MemoryError or ValueError on failure. */ +buffer_position pymongo_buffer_save_space(buffer_t buffer, int size) { + int position = buffer->position; + if (buffer_assure_space(buffer, size) != 0) { + return -1; + } + buffer->position += size; + return position; +} + +/* Write `size` bytes from `data` to `buffer` (and grow if needed). + * Return non-zero on failure. + * Sets MemoryError or ValueError on failure. */ +int pymongo_buffer_write(buffer_t buffer, const char* data, int size) { + if (buffer_assure_space(buffer, size) != 0) { + return 1; + } + + memcpy(buffer->buffer + buffer->position, data, size); + buffer->position += size; + return 0; +} + +int pymongo_buffer_get_position(buffer_t buffer) { + return buffer->position; +} + +char* pymongo_buffer_get_buffer(buffer_t buffer) { + return buffer->buffer; +} + +void pymongo_buffer_update_position(buffer_t buffer, buffer_position new_position) { + buffer->position = new_position; +} diff --git a/venv/Lib/site-packages/bson/buffer.h b/venv/Lib/site-packages/bson/buffer.h new file mode 100644 index 00000000..a78e34e4 --- /dev/null +++ b/venv/Lib/site-packages/bson/buffer.h @@ -0,0 +1,51 @@ +/* + * Copyright 2009-2015 MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef BUFFER_H +#define BUFFER_H + +/* Note: if any of these functions return a failure condition then the buffer + * has already been freed. */ + +/* A buffer */ +typedef struct buffer* buffer_t; +/* A position in the buffer */ +typedef int buffer_position; + +/* Allocate and return a new buffer. + * Return NULL on allocation failure. */ +buffer_t pymongo_buffer_new(void); + +/* Free the memory allocated for `buffer`. + * Return non-zero on failure. */ +int pymongo_buffer_free(buffer_t buffer); + +/* Save `size` bytes from the current position in `buffer` (and grow if needed). + * Return offset for writing, or -1 on allocation failure. */ +buffer_position pymongo_buffer_save_space(buffer_t buffer, int size); + +/* Write `size` bytes from `data` to `buffer` (and grow if needed). + * Return non-zero on allocation failure. */ +int pymongo_buffer_write(buffer_t buffer, const char* data, int size); + +/* Getters for the internals of a buffer_t. + * Should try to avoid using these as much as possible + * since they break the abstraction. 
*/ +buffer_position pymongo_buffer_get_position(buffer_t buffer); +char* pymongo_buffer_get_buffer(buffer_t buffer); +void pymongo_buffer_update_position(buffer_t buffer, buffer_position new_position); + +#endif diff --git a/venv/Lib/site-packages/bson/code.py b/venv/Lib/site-packages/bson/code.py new file mode 100644 index 00000000..6b4541d0 --- /dev/null +++ b/venv/Lib/site-packages/bson/code.py @@ -0,0 +1,100 @@ +# Copyright 2009-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tools for representing JavaScript code in BSON.""" +from __future__ import annotations + +from collections.abc import Mapping as _Mapping +from typing import Any, Mapping, Optional, Type, Union + + +class Code(str): + """BSON's JavaScript code type. + + Raises :class:`TypeError` if `code` is not an instance of + :class:`str` or `scope` is not ``None`` or an instance + of :class:`dict`. + + Scope variables can be set by passing a dictionary as the `scope` + argument or by using keyword arguments. If a variable is set as a + keyword argument it will override any setting for that variable in + the `scope` dictionary. + + :param code: A string containing JavaScript code to be evaluated or another + instance of Code. In the latter case, the scope of `code` becomes this + Code's :attr:`scope`. + :param scope: dictionary representing the scope in which + `code` should be evaluated - a mapping from identifiers (as + strings) to values. Defaults to ``None``. This is applied after any + scope associated with a given `code` above. + :param kwargs: scope variables can also be passed as + keyword arguments. These are applied after `scope` and `code`. + + .. versionchanged:: 3.4 + The default value for :attr:`scope` is ``None`` instead of ``{}``. 
+ + """ + + _type_marker = 13 + __scope: Union[Mapping[str, Any], None] + + def __new__( + cls: Type[Code], + code: Union[str, Code], + scope: Optional[Mapping[str, Any]] = None, + **kwargs: Any, + ) -> Code: + if not isinstance(code, str): + raise TypeError("code must be an instance of str") + + self = str.__new__(cls, code) + + try: + self.__scope = code.scope # type: ignore + except AttributeError: + self.__scope = None + + if scope is not None: + if not isinstance(scope, _Mapping): + raise TypeError("scope must be an instance of dict") + if self.__scope is not None: + self.__scope.update(scope) # type: ignore + else: + self.__scope = scope + + if kwargs: + if self.__scope is not None: + self.__scope.update(kwargs) # type: ignore + else: + self.__scope = kwargs + + return self + + @property + def scope(self) -> Optional[Mapping[str, Any]]: + """Scope dictionary for this instance or ``None``.""" + return self.__scope + + def __repr__(self) -> str: + return f"Code({str.__repr__(self)}, {self.__scope!r})" + + def __eq__(self, other: Any) -> bool: + if isinstance(other, Code): + return (self.__scope, str(self)) == (other.__scope, str(other)) + return False + + __hash__: Any = None + + def __ne__(self, other: Any) -> bool: + return not self == other diff --git a/venv/Lib/site-packages/bson/codec_options.py b/venv/Lib/site-packages/bson/codec_options.py new file mode 100644 index 00000000..3a0b83b7 --- /dev/null +++ b/venv/Lib/site-packages/bson/codec_options.py @@ -0,0 +1,505 @@ +# Copyright 2014-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tools for specifying BSON codec options.""" +from __future__ import annotations + +import abc +import datetime +import enum +from collections.abc import MutableMapping as _MutableMapping +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Generic, + Iterable, + Mapping, + NamedTuple, + Optional, + Tuple, + Type, + Union, + cast, +) + +from bson.binary import ( + ALL_UUID_REPRESENTATIONS, + UUID_REPRESENTATION_NAMES, + UuidRepresentation, +) +from bson.typings import _DocumentType + +_RAW_BSON_DOCUMENT_MARKER = 101 + + +def _raw_document_class(document_class: Any) -> bool: + """Determine if a document_class is a RawBSONDocument class.""" + marker = getattr(document_class, "_type_marker", None) + return marker == _RAW_BSON_DOCUMENT_MARKER + + +class TypeEncoder(abc.ABC): + """Base class for defining type codec classes which describe how a + custom type can be transformed to one of the types BSON understands. + + Codec classes must implement the ``python_type`` attribute, and the + ``transform_python`` method to support encoding. + + See :ref:`custom-type-type-codec` documentation for an example. 
+ """ + + @abc.abstractproperty + def python_type(self) -> Any: + """The Python type to be converted into something serializable.""" + + @abc.abstractmethod + def transform_python(self, value: Any) -> Any: + """Convert the given Python object into something serializable.""" + + +class TypeDecoder(abc.ABC): + """Base class for defining type codec classes which describe how a + BSON type can be transformed to a custom type. + + Codec classes must implement the ``bson_type`` attribute, and the + ``transform_bson`` method to support decoding. + + See :ref:`custom-type-type-codec` documentation for an example. + """ + + @abc.abstractproperty + def bson_type(self) -> Any: + """The BSON type to be converted into our own type.""" + + @abc.abstractmethod + def transform_bson(self, value: Any) -> Any: + """Convert the given BSON value into our own type.""" + + +class TypeCodec(TypeEncoder, TypeDecoder): + """Base class for defining type codec classes which describe how a + custom type can be transformed to/from one of the types :mod:`bson` + can already encode/decode. + + Codec classes must implement the ``python_type`` attribute, and the + ``transform_python`` method to support encoding, as well as the + ``bson_type`` attribute, and the ``transform_bson`` method to support + decoding. + + See :ref:`custom-type-type-codec` documentation for an example. + """ + + +_Codec = Union[TypeEncoder, TypeDecoder, TypeCodec] +_Fallback = Callable[[Any], Any] + + +class TypeRegistry: + """Encapsulates type codecs used in encoding and / or decoding BSON, as + well as the fallback encoder. Type registries cannot be modified after + instantiation. + + ``TypeRegistry`` can be initialized with an iterable of type codecs, and + a callable for the fallback encoder:: + + >>> from bson.codec_options import TypeRegistry + >>> type_registry = TypeRegistry([Codec1, Codec2, Codec3, ...], + ... fallback_encoder) + + See :ref:`custom-type-type-registry` documentation for an example. + + :param type_codecs: iterable of type codec instances. If + ``type_codecs`` contains multiple codecs that transform a single + python or BSON type, the transformation specified by the type codec + occurring last prevails. A TypeError will be raised if one or more + type codecs modify the encoding behavior of a built-in :mod:`bson` + type. + :param fallback_encoder: callable that accepts a single, + unencodable python value and transforms it into a type that + :mod:`bson` can encode. See :ref:`fallback-encoder-callable` + documentation for an example. 
+ """ + + def __init__( + self, + type_codecs: Optional[Iterable[_Codec]] = None, + fallback_encoder: Optional[_Fallback] = None, + ) -> None: + self.__type_codecs = list(type_codecs or []) + self._fallback_encoder = fallback_encoder + self._encoder_map: dict[Any, Any] = {} + self._decoder_map: dict[Any, Any] = {} + + if self._fallback_encoder is not None: + if not callable(fallback_encoder): + raise TypeError("fallback_encoder %r is not a callable" % (fallback_encoder)) + + for codec in self.__type_codecs: + is_valid_codec = False + if isinstance(codec, TypeEncoder): + self._validate_type_encoder(codec) + is_valid_codec = True + self._encoder_map[codec.python_type] = codec.transform_python + if isinstance(codec, TypeDecoder): + is_valid_codec = True + self._decoder_map[codec.bson_type] = codec.transform_bson + if not is_valid_codec: + raise TypeError( + f"Expected an instance of {TypeEncoder.__name__}, {TypeDecoder.__name__}, or {TypeCodec.__name__}, got {codec!r} instead" + ) + + def _validate_type_encoder(self, codec: _Codec) -> None: + from bson import _BUILT_IN_TYPES + + for pytype in _BUILT_IN_TYPES: + if issubclass(cast(TypeCodec, codec).python_type, pytype): + err_msg = ( + "TypeEncoders cannot change how built-in types are " + f"encoded (encoder {codec} transforms type {pytype})" + ) + raise TypeError(err_msg) + + def __repr__(self) -> str: + return "{}(type_codecs={!r}, fallback_encoder={!r})".format( + self.__class__.__name__, + self.__type_codecs, + self._fallback_encoder, + ) + + def __eq__(self, other: Any) -> Any: + if not isinstance(other, type(self)): + return NotImplemented + return ( + (self._decoder_map == other._decoder_map) + and (self._encoder_map == other._encoder_map) + and (self._fallback_encoder == other._fallback_encoder) + ) + + +class DatetimeConversion(int, enum.Enum): + """Options for decoding BSON datetimes.""" + + DATETIME = 1 + """Decode a BSON UTC datetime as a :class:`datetime.datetime`. + + BSON UTC datetimes that cannot be represented as a + :class:`~datetime.datetime` will raise an :class:`OverflowError` + or a :class:`ValueError`. + + .. versionadded 4.3 + """ + + DATETIME_CLAMP = 2 + """Decode a BSON UTC datetime as a :class:`datetime.datetime`, clamping + to :attr:`~datetime.datetime.min` and :attr:`~datetime.datetime.max`. + + .. versionadded 4.3 + """ + + DATETIME_MS = 3 + """Decode a BSON UTC datetime as a :class:`~bson.datetime_ms.DatetimeMS` + object. + + .. versionadded 4.3 + """ + + DATETIME_AUTO = 4 + """Decode a BSON UTC datetime as a :class:`datetime.datetime` if possible, + and a :class:`~bson.datetime_ms.DatetimeMS` if not. + + .. 
versionadded 4.3 + """ + + +class _BaseCodecOptions(NamedTuple): + document_class: Type[Mapping[str, Any]] + tz_aware: bool + uuid_representation: int + unicode_decode_error_handler: str + tzinfo: Optional[datetime.tzinfo] + type_registry: TypeRegistry + datetime_conversion: Optional[DatetimeConversion] + + +if TYPE_CHECKING: + + class CodecOptions(Tuple[_DocumentType], Generic[_DocumentType]): + document_class: Type[_DocumentType] + tz_aware: bool + uuid_representation: int + unicode_decode_error_handler: Optional[str] + tzinfo: Optional[datetime.tzinfo] + type_registry: TypeRegistry + datetime_conversion: Optional[int] + + def __new__( + cls: Type[CodecOptions[_DocumentType]], + document_class: Optional[Type[_DocumentType]] = ..., + tz_aware: bool = ..., + uuid_representation: Optional[int] = ..., + unicode_decode_error_handler: Optional[str] = ..., + tzinfo: Optional[datetime.tzinfo] = ..., + type_registry: Optional[TypeRegistry] = ..., + datetime_conversion: Optional[int] = ..., + ) -> CodecOptions[_DocumentType]: + ... + + # CodecOptions API + def with_options(self, **kwargs: Any) -> CodecOptions[Any]: + ... + + def _arguments_repr(self) -> str: + ... + + def _options_dict(self) -> dict[Any, Any]: + ... + + # NamedTuple API + @classmethod + def _make(cls, obj: Iterable[Any]) -> CodecOptions[_DocumentType]: + ... + + def _asdict(self) -> dict[str, Any]: + ... + + def _replace(self, **kwargs: Any) -> CodecOptions[_DocumentType]: + ... + + _source: str + _fields: Tuple[str] + +else: + + class CodecOptions(_BaseCodecOptions): + """Encapsulates options used encoding and / or decoding BSON.""" + + def __init__(self, *args, **kwargs): + """Encapsulates options used encoding and / or decoding BSON. + + The `document_class` option is used to define a custom type for use + decoding BSON documents. Access to the underlying raw BSON bytes for + a document is available using the :class:`~bson.raw_bson.RawBSONDocument` + type:: + + >>> from bson.raw_bson import RawBSONDocument + >>> from bson.codec_options import CodecOptions + >>> codec_options = CodecOptions(document_class=RawBSONDocument) + >>> coll = db.get_collection('test', codec_options=codec_options) + >>> doc = coll.find_one() + >>> doc.raw + '\\x16\\x00\\x00\\x00\\x07_id\\x00[0\\x165\\x91\\x10\\xea\\x14\\xe8\\xc5\\x8b\\x93\\x00' + + The document class can be any type that inherits from + :class:`~collections.abc.MutableMapping`:: + + >>> class AttributeDict(dict): + ... # A dict that supports attribute access. + ... def __getattr__(self, key): + ... return self[key] + ... def __setattr__(self, key, value): + ... self[key] = value + ... + >>> codec_options = CodecOptions(document_class=AttributeDict) + >>> coll = db.get_collection('test', codec_options=codec_options) + >>> doc = coll.find_one() + >>> doc._id + ObjectId('5b3016359110ea14e8c58b93') + + See :doc:`/examples/datetimes` for examples using the `tz_aware` and + `tzinfo` options. + + See :doc:`/examples/uuid` for examples using the `uuid_representation` + option. + + :param document_class: BSON documents returned in queries will be decoded + to an instance of this class. Must be a subclass of + :class:`~collections.abc.MutableMapping`. Defaults to :class:`dict`. + :param tz_aware: If ``True``, BSON datetimes will be decoded to timezone + aware instances of :class:`~datetime.datetime`. Otherwise they will be + naive. Defaults to ``False``. + :param uuid_representation: The BSON representation to use when encoding + and decoding instances of :class:`~uuid.UUID`. 
Defaults to + :data:`~bson.binary.UuidRepresentation.UNSPECIFIED`. New + applications should consider setting this to + :data:`~bson.binary.UuidRepresentation.STANDARD` for cross language + compatibility. See :ref:`handling-uuid-data-example` for details. + :param unicode_decode_error_handler: The error handler to apply when + a Unicode-related error occurs during BSON decoding that would + otherwise raise :exc:`UnicodeDecodeError`. Valid options include + 'strict', 'replace', 'backslashreplace', 'surrogateescape', and + 'ignore'. Defaults to 'strict'. + :param tzinfo: A :class:`~datetime.tzinfo` subclass that specifies the + timezone to/from which :class:`~datetime.datetime` objects should be + encoded/decoded. + :param type_registry: Instance of :class:`TypeRegistry` used to customize + encoding and decoding behavior. + :param datetime_conversion: Specifies how UTC datetimes should be decoded + within BSON. Valid options include 'datetime_ms' to return as a + DatetimeMS, 'datetime' to return as a datetime.datetime and + raising a ValueError for out-of-range values, 'datetime_auto' to + return DatetimeMS objects when the underlying datetime is + out-of-range and 'datetime_clamp' to clamp to the minimum and + maximum possible datetimes. Defaults to 'datetime'. + + .. versionchanged:: 4.0 + The default for `uuid_representation` was changed from + :const:`~bson.binary.UuidRepresentation.PYTHON_LEGACY` to + :const:`~bson.binary.UuidRepresentation.UNSPECIFIED`. + + .. versionadded:: 3.8 + `type_registry` attribute. + + .. warning:: Care must be taken when changing + `unicode_decode_error_handler` from its default value ('strict'). + The 'replace' and 'ignore' modes should not be used when documents + retrieved from the server will be modified in the client application + and stored back to the server. + """ + super().__init__() + + def __new__( + cls: Type[CodecOptions], + document_class: Optional[Type[Mapping[str, Any]]] = None, + tz_aware: bool = False, + uuid_representation: Optional[int] = UuidRepresentation.UNSPECIFIED, + unicode_decode_error_handler: str = "strict", + tzinfo: Optional[datetime.tzinfo] = None, + type_registry: Optional[TypeRegistry] = None, + datetime_conversion: Optional[DatetimeConversion] = DatetimeConversion.DATETIME, + ) -> CodecOptions: + doc_class = document_class or dict + # issubclass can raise TypeError for generic aliases like SON[str, Any]. + # In that case we can use the base class for the comparison. 
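+        # For example, SON[str, Any] is a generic alias whose __origin__ is the
+        # plain SON class, which issubclass() can inspect safely.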
+ is_mapping = False + try: + is_mapping = issubclass(doc_class, _MutableMapping) + except TypeError: + if hasattr(doc_class, "__origin__"): + is_mapping = issubclass(doc_class.__origin__, _MutableMapping) + if not (is_mapping or _raw_document_class(doc_class)): + raise TypeError( + "document_class must be dict, bson.son.SON, " + "bson.raw_bson.RawBSONDocument, or a " + "subclass of collections.abc.MutableMapping" + ) + if not isinstance(tz_aware, bool): + raise TypeError(f"tz_aware must be True or False, was: tz_aware={tz_aware}") + if uuid_representation not in ALL_UUID_REPRESENTATIONS: + raise ValueError( + "uuid_representation must be a value from bson.binary.UuidRepresentation" + ) + if not isinstance(unicode_decode_error_handler, str): + raise ValueError("unicode_decode_error_handler must be a string") + if tzinfo is not None: + if not isinstance(tzinfo, datetime.tzinfo): + raise TypeError("tzinfo must be an instance of datetime.tzinfo") + if not tz_aware: + raise ValueError("cannot specify tzinfo without also setting tz_aware=True") + + type_registry = type_registry or TypeRegistry() + + if not isinstance(type_registry, TypeRegistry): + raise TypeError("type_registry must be an instance of TypeRegistry") + + return tuple.__new__( + cls, + ( + doc_class, + tz_aware, + uuid_representation, + unicode_decode_error_handler, + tzinfo, + type_registry, + datetime_conversion, + ), + ) + + def _arguments_repr(self) -> str: + """Representation of the arguments used to create this object.""" + document_class_repr = ( + "dict" if self.document_class is dict else repr(self.document_class) + ) + + uuid_rep_repr = UUID_REPRESENTATION_NAMES.get( + self.uuid_representation, self.uuid_representation + ) + + return ( + "document_class={}, tz_aware={!r}, uuid_representation={}, " + "unicode_decode_error_handler={!r}, tzinfo={!r}, " + "type_registry={!r}, datetime_conversion={!s}".format( + document_class_repr, + self.tz_aware, + uuid_rep_repr, + self.unicode_decode_error_handler, + self.tzinfo, + self.type_registry, + self.datetime_conversion, + ) + ) + + def _options_dict(self) -> dict[str, Any]: + """Dictionary of the arguments used to create this object.""" + # TODO: PYTHON-2442 use _asdict() instead + return { + "document_class": self.document_class, + "tz_aware": self.tz_aware, + "uuid_representation": self.uuid_representation, + "unicode_decode_error_handler": self.unicode_decode_error_handler, + "tzinfo": self.tzinfo, + "type_registry": self.type_registry, + "datetime_conversion": self.datetime_conversion, + } + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self._arguments_repr()})" + + def with_options(self, **kwargs: Any) -> CodecOptions: + """Make a copy of this CodecOptions, overriding some options:: + + >>> from bson.codec_options import DEFAULT_CODEC_OPTIONS + >>> DEFAULT_CODEC_OPTIONS.tz_aware + False + >>> options = DEFAULT_CODEC_OPTIONS.with_options(tz_aware=True) + >>> options.tz_aware + True + + .. 
versionadded:: 3.5 + """ + opts = self._options_dict() + opts.update(kwargs) + return CodecOptions(**opts) + + +DEFAULT_CODEC_OPTIONS: CodecOptions[dict[str, Any]] = CodecOptions() + + +def _parse_codec_options(options: Any) -> CodecOptions[Any]: + """Parse BSON codec options.""" + kwargs = {} + for k in set(options) & { + "document_class", + "tz_aware", + "uuidrepresentation", + "unicode_decode_error_handler", + "tzinfo", + "type_registry", + "datetime_conversion", + }: + if k == "uuidrepresentation": + kwargs["uuid_representation"] = options[k] + else: + kwargs[k] = options[k] + return CodecOptions(**kwargs) diff --git a/venv/Lib/site-packages/bson/datetime_ms.py b/venv/Lib/site-packages/bson/datetime_ms.py new file mode 100644 index 00000000..112871a1 --- /dev/null +++ b/venv/Lib/site-packages/bson/datetime_ms.py @@ -0,0 +1,171 @@ +# Copyright 2022-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you +# may not use this file except in compliance with the License. You +# may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. + +"""Tools for representing the BSON datetime type. + +.. versionadded:: 4.3 +""" +from __future__ import annotations + +import calendar +import datetime +import functools +from typing import Any, Union, cast + +from bson.codec_options import DEFAULT_CODEC_OPTIONS, CodecOptions, DatetimeConversion +from bson.errors import InvalidBSON +from bson.tz_util import utc + +EPOCH_AWARE = datetime.datetime.fromtimestamp(0, utc) +EPOCH_NAIVE = EPOCH_AWARE.replace(tzinfo=None) +_DATETIME_ERROR_SUGGESTION = ( + "(Consider Using CodecOptions(datetime_conversion=DATETIME_AUTO)" + " or MongoClient(datetime_conversion='DATETIME_AUTO'))." + " See: https://pymongo.readthedocs.io/en/stable/examples/datetimes.html#handling-out-of-range-datetimes" +) + + +class DatetimeMS: + """Represents a BSON UTC datetime.""" + + __slots__ = ("_value",) + + def __init__(self, value: Union[int, datetime.datetime]): + """Represents a BSON UTC datetime. + + BSON UTC datetimes are defined as an int64 of milliseconds since the + Unix epoch. The principal use of DatetimeMS is to represent + datetimes outside the range of the Python builtin + :class:`~datetime.datetime` class when + encoding/decoding BSON. + + To decode UTC datetimes as a ``DatetimeMS``, `datetime_conversion` in + :class:`~bson.codec_options.CodecOptions` must be set to 'datetime_ms' or + 'datetime_auto'. See :ref:`handling-out-of-range-datetimes` for + details. + + :param value: An instance of :class:`datetime.datetime` to be + represented as milliseconds since the Unix epoch, or int of + milliseconds since the Unix epoch. 
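+
+        For example (a minimal sketch; the values are arbitrary)::
+
+            >>> from bson.datetime_ms import DatetimeMS
+            >>> DatetimeMS(0)  # the Unix epoch
+            DatetimeMS(0)
+            >>> DatetimeMS(-100000000000000)  # out of datetime.datetime range
+            DatetimeMS(-100000000000000)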
+ """ + if isinstance(value, int): + if not (-(2**63) <= value <= 2**63 - 1): + raise OverflowError("Must be a 64-bit integer of milliseconds") + self._value = value + elif isinstance(value, datetime.datetime): + self._value = _datetime_to_millis(value) + else: + raise TypeError(f"{type(value)} is not a valid type for DatetimeMS") + + def __hash__(self) -> int: + return hash(self._value) + + def __repr__(self) -> str: + return type(self).__name__ + "(" + str(self._value) + ")" + + def __lt__(self, other: Union[DatetimeMS, int]) -> bool: + return self._value < other + + def __le__(self, other: Union[DatetimeMS, int]) -> bool: + return self._value <= other + + def __eq__(self, other: Any) -> bool: + if isinstance(other, DatetimeMS): + return self._value == other._value + return False + + def __ne__(self, other: Any) -> bool: + if isinstance(other, DatetimeMS): + return self._value != other._value + return True + + def __gt__(self, other: Union[DatetimeMS, int]) -> bool: + return self._value > other + + def __ge__(self, other: Union[DatetimeMS, int]) -> bool: + return self._value >= other + + _type_marker = 9 + + def as_datetime( + self, codec_options: CodecOptions[Any] = DEFAULT_CODEC_OPTIONS + ) -> datetime.datetime: + """Create a Python :class:`~datetime.datetime` from this DatetimeMS object. + + :param codec_options: A CodecOptions instance for specifying how the + resulting DatetimeMS object will be formatted using ``tz_aware`` + and ``tz_info``. Defaults to + :const:`~bson.codec_options.DEFAULT_CODEC_OPTIONS`. + """ + return cast(datetime.datetime, _millis_to_datetime(self._value, codec_options)) + + def __int__(self) -> int: + return self._value + + +# Inclusive and exclusive min and max for timezones. +# Timezones are hashed by their offset, which is a timedelta +# and therefore there are more than 24 possible timezones. 
+@functools.lru_cache(maxsize=None) +def _min_datetime_ms(tz: datetime.timezone = datetime.timezone.utc) -> int: + return _datetime_to_millis(datetime.datetime.min.replace(tzinfo=tz)) + + +@functools.lru_cache(maxsize=None) +def _max_datetime_ms(tz: datetime.timezone = datetime.timezone.utc) -> int: + return _datetime_to_millis(datetime.datetime.max.replace(tzinfo=tz)) + + +def _millis_to_datetime( + millis: int, opts: CodecOptions[Any] +) -> Union[datetime.datetime, DatetimeMS]: + """Convert milliseconds since epoch UTC to datetime.""" + if ( + opts.datetime_conversion == DatetimeConversion.DATETIME + or opts.datetime_conversion == DatetimeConversion.DATETIME_CLAMP + or opts.datetime_conversion == DatetimeConversion.DATETIME_AUTO + ): + tz = opts.tzinfo or datetime.timezone.utc + if opts.datetime_conversion == DatetimeConversion.DATETIME_CLAMP: + millis = max(_min_datetime_ms(tz), min(millis, _max_datetime_ms(tz))) + elif opts.datetime_conversion == DatetimeConversion.DATETIME_AUTO: + if not (_min_datetime_ms(tz) <= millis <= _max_datetime_ms(tz)): + return DatetimeMS(millis) + + diff = ((millis % 1000) + 1000) % 1000 + seconds = (millis - diff) // 1000 + micros = diff * 1000 + + try: + if opts.tz_aware: + dt = EPOCH_AWARE + datetime.timedelta(seconds=seconds, microseconds=micros) + if opts.tzinfo: + dt = dt.astimezone(tz) + return dt + else: + return EPOCH_NAIVE + datetime.timedelta(seconds=seconds, microseconds=micros) + except ArithmeticError as err: + raise InvalidBSON(f"{err} {_DATETIME_ERROR_SUGGESTION}") from err + + elif opts.datetime_conversion == DatetimeConversion.DATETIME_MS: + return DatetimeMS(millis) + else: + raise ValueError("datetime_conversion must be an element of DatetimeConversion") + + +def _datetime_to_millis(dtm: datetime.datetime) -> int: + """Convert datetime to milliseconds since epoch UTC.""" + if dtm.utcoffset() is not None: + dtm = dtm - dtm.utcoffset() # type: ignore + return int(calendar.timegm(dtm.timetuple()) * 1000 + dtm.microsecond // 1000) diff --git a/venv/Lib/site-packages/bson/dbref.py b/venv/Lib/site-packages/bson/dbref.py new file mode 100644 index 00000000..6c21b816 --- /dev/null +++ b/venv/Lib/site-packages/bson/dbref.py @@ -0,0 +1,133 @@ +# Copyright 2009-2015 MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tools for manipulating DBRefs (references to MongoDB documents).""" +from __future__ import annotations + +from copy import deepcopy +from typing import Any, Mapping, Optional + +from bson._helpers import _getstate_slots, _setstate_slots +from bson.son import SON + + +class DBRef: + """A reference to a document stored in MongoDB.""" + + __slots__ = "__collection", "__id", "__database", "__kwargs" + __getstate__ = _getstate_slots + __setstate__ = _setstate_slots + # DBRef isn't actually a BSON "type" so this number was arbitrarily chosen. 
+ _type_marker = 100 + + def __init__( + self, + collection: str, + id: Any, + database: Optional[str] = None, + _extra: Optional[Mapping[str, Any]] = None, + **kwargs: Any, + ) -> None: + """Initialize a new :class:`DBRef`. + + Raises :class:`TypeError` if `collection` or `database` is not + an instance of :class:`str`. `database` is optional and allows + references to documents to work across databases. Any additional + keyword arguments will create additional fields in the resultant + embedded document. + + :param collection: name of the collection the document is stored in + :param id: the value of the document's ``"_id"`` field + :param database: name of the database to reference + :param kwargs: additional keyword arguments will + create additional, custom fields + + .. seealso:: The MongoDB documentation on `dbrefs `_. + """ + if not isinstance(collection, str): + raise TypeError("collection must be an instance of str") + if database is not None and not isinstance(database, str): + raise TypeError("database must be an instance of str") + + self.__collection = collection + self.__id = id + self.__database = database + kwargs.update(_extra or {}) + self.__kwargs = kwargs + + @property + def collection(self) -> str: + """Get the name of this DBRef's collection.""" + return self.__collection + + @property + def id(self) -> Any: + """Get this DBRef's _id.""" + return self.__id + + @property + def database(self) -> Optional[str]: + """Get the name of this DBRef's database. + + Returns None if this DBRef doesn't specify a database. + """ + return self.__database + + def __getattr__(self, key: Any) -> Any: + try: + return self.__kwargs[key] + except KeyError: + raise AttributeError(key) from None + + def as_doc(self) -> SON[str, Any]: + """Get the SON document representation of this DBRef. + + Generally not needed by application developers + """ + doc = SON([("$ref", self.collection), ("$id", self.id)]) + if self.database is not None: + doc["$db"] = self.database + doc.update(self.__kwargs) + return doc + + def __repr__(self) -> str: + extra = "".join([f", {k}={v!r}" for k, v in self.__kwargs.items()]) + if self.database is None: + return f"DBRef({self.collection!r}, {self.id!r}{extra})" + return f"DBRef({self.collection!r}, {self.id!r}, {self.database!r}{extra})" + + def __eq__(self, other: Any) -> bool: + if isinstance(other, DBRef): + us = (self.__database, self.__collection, self.__id, self.__kwargs) + them = (other.__database, other.__collection, other.__id, other.__kwargs) + return us == them + return NotImplemented + + def __ne__(self, other: Any) -> bool: + return not self == other + + def __hash__(self) -> int: + """Get a hash value for this :class:`DBRef`.""" + return hash( + (self.__collection, self.__id, self.__database, tuple(sorted(self.__kwargs.items()))) + ) + + def __deepcopy__(self, memo: Any) -> DBRef: + """Support function for `copy.deepcopy()`.""" + return DBRef( + deepcopy(self.__collection, memo), + deepcopy(self.__id, memo), + deepcopy(self.__database, memo), + deepcopy(self.__kwargs, memo), + ) diff --git a/venv/Lib/site-packages/bson/decimal128.py b/venv/Lib/site-packages/bson/decimal128.py new file mode 100644 index 00000000..8581d5a3 --- /dev/null +++ b/venv/Lib/site-packages/bson/decimal128.py @@ -0,0 +1,312 @@ +# Copyright 2016-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tools for working with the BSON decimal128 type. + +.. versionadded:: 3.4 +""" +from __future__ import annotations + +import decimal +import struct +from typing import Any, Sequence, Tuple, Type, Union + +_PACK_64 = struct.Struct(" decimal.Context: + """Returns an instance of :class:`decimal.Context` appropriate + for working with IEEE-754 128-bit decimal floating point values. + """ + opts = _CTX_OPTIONS.copy() + opts["traps"] = [] + return decimal.Context(**opts) # type: ignore + + +def _decimal_to_128(value: _VALUE_OPTIONS) -> Tuple[int, int]: + """Converts a decimal.Decimal to BID (high bits, low bits). + + :param value: An instance of decimal.Decimal + """ + with decimal.localcontext(_DEC128_CTX) as ctx: + value = ctx.create_decimal(value) + + if value.is_infinite(): + return _NINF if value.is_signed() else _PINF + + sign, digits, exponent = value.as_tuple() + + if value.is_nan(): + if digits: + raise ValueError("NaN with debug payload is not supported") + if value.is_snan(): + return _NSNAN if value.is_signed() else _PSNAN + return _NNAN if value.is_signed() else _PNAN + + significand = int("".join([str(digit) for digit in digits])) + bit_length = significand.bit_length() + + high = 0 + low = 0 + for i in range(min(64, bit_length)): + if significand & (1 << i): + low |= 1 << i + + for i in range(64, bit_length): + if significand & (1 << i): + high |= 1 << (i - 64) + + biased_exponent = exponent + _EXPONENT_BIAS # type: ignore[operator] + + if high >> 49 == 1: + high = high & 0x7FFFFFFFFFFF + high |= _EXPONENT_MASK + high |= (biased_exponent & 0x3FFF) << 47 + else: + high |= biased_exponent << 49 + + if sign: + high |= _SIGN + + return high, low + + +class Decimal128: + """BSON Decimal128 type:: + + >>> Decimal128(Decimal("0.0005")) + Decimal128('0.0005') + >>> Decimal128("0.0005") + Decimal128('0.0005') + >>> Decimal128((3474527112516337664, 5)) + Decimal128('0.0005') + + :param value: An instance of :class:`decimal.Decimal`, string, or tuple of + (high bits, low bits) from Binary Integer Decimal (BID) format. + + .. note:: :class:`~Decimal128` uses an instance of :class:`decimal.Context` + configured for IEEE-754 Decimal128 when validating parameters. + Signals like :class:`decimal.InvalidOperation`, :class:`decimal.Inexact`, + and :class:`decimal.Overflow` are trapped and raised as exceptions:: + + >>> Decimal128(".13.1") + Traceback (most recent call last): + File "", line 1, in + ... + decimal.InvalidOperation: [] + >>> + >>> Decimal128("1E-6177") + Traceback (most recent call last): + File "", line 1, in + ... + decimal.Inexact: [] + >>> + >>> Decimal128("1E6145") + Traceback (most recent call last): + File "", line 1, in + ... + decimal.Overflow: [, ] + + To ensure the result of a calculation can always be stored as BSON + Decimal128 use the context returned by + :func:`create_decimal128_context`:: + + >>> import decimal + >>> decimal128_ctx = create_decimal128_context() + >>> with decimal.localcontext(decimal128_ctx) as ctx: + ... Decimal128(ctx.create_decimal(".13.3")) + ... + Decimal128('NaN') + >>> + >>> with decimal.localcontext(decimal128_ctx) as ctx: + ... 
Decimal128(ctx.create_decimal("1E-6177"))
+      ...
+      Decimal128('0E-6176')
+      >>>
+      >>> with decimal.localcontext(decimal128_ctx) as ctx:
+      ...     Decimal128(ctx.create_decimal("1E6145"))
+      ...
+      Decimal128('Infinity')
+
+    To match the behavior of MongoDB's Decimal128 implementation
+    str(Decimal(value)) may not match str(Decimal128(value)) for NaN values::
+
+      >>> Decimal128(Decimal('NaN'))
+      Decimal128('NaN')
+      >>> Decimal128(Decimal('-NaN'))
+      Decimal128('NaN')
+      >>> Decimal128(Decimal('sNaN'))
+      Decimal128('NaN')
+      >>> Decimal128(Decimal('-sNaN'))
+      Decimal128('NaN')
+
+    However, :meth:`~Decimal128.to_decimal` will return the exact value::
+
+      >>> Decimal128(Decimal('NaN')).to_decimal()
+      Decimal('NaN')
+      >>> Decimal128(Decimal('-NaN')).to_decimal()
+      Decimal('-NaN')
+      >>> Decimal128(Decimal('sNaN')).to_decimal()
+      Decimal('sNaN')
+      >>> Decimal128(Decimal('-sNaN')).to_decimal()
+      Decimal('-sNaN')
+
+    Two instances of :class:`Decimal128` compare equal if their Binary
+    Integer Decimal encodings are equal::
+
+      >>> Decimal128('NaN') == Decimal128('NaN')
+      True
+      >>> Decimal128('NaN').bid == Decimal128('NaN').bid
+      True
+
+    This differs from :class:`decimal.Decimal` comparisons for NaN::
+
+      >>> Decimal('NaN') == Decimal('NaN')
+      False
+    """
+
+    __slots__ = ("__high", "__low")
+
+    _type_marker = 19
+
+    def __init__(self, value: _VALUE_OPTIONS) -> None:
+        if isinstance(value, (str, decimal.Decimal)):
+            self.__high, self.__low = _decimal_to_128(value)
+        elif isinstance(value, (list, tuple)):
+            if len(value) != 2:
+                raise ValueError(
+                    "Invalid size for creation of Decimal128 "
+                    "from list or tuple. Must have exactly 2 "
+                    "elements."
+                )
+            self.__high, self.__low = value  # type: ignore
+        else:
+            raise TypeError(f"Cannot convert {value!r} to Decimal128")
+
+    def to_decimal(self) -> decimal.Decimal:
+        """Returns an instance of :class:`decimal.Decimal` for this
+        :class:`Decimal128`.
+        """
+        high = self.__high
+        low = self.__low
+        sign = 1 if (high & _SIGN) else 0
+
+        if (high & _SNAN) == _SNAN:
+            return decimal.Decimal((sign, (), "N"))  # type: ignore
+        elif (high & _NAN) == _NAN:
+            return decimal.Decimal((sign, (), "n"))  # type: ignore
+        elif (high & _INF) == _INF:
+            return decimal.Decimal((sign, (), "F"))  # type: ignore
+
+        if (high & _EXPONENT_MASK) == _EXPONENT_MASK:
+            exponent = ((high & 0x1FFFE00000000000) >> 47) - _EXPONENT_BIAS
+            return decimal.Decimal((sign, (0,), exponent))
+        else:
+            exponent = ((high & 0x7FFF800000000000) >> 49) - _EXPONENT_BIAS
+
+        arr = bytearray(15)
+        mask = 0x00000000000000FF
+        for i in range(14, 6, -1):
+            arr[i] = (low & mask) >> ((14 - i) << 3)
+            mask = mask << 8
+
+        mask = 0x00000000000000FF
+        for i in range(6, 0, -1):
+            arr[i] = (high & mask) >> ((6 - i) << 3)
+            mask = mask << 8
+
+        mask = 0x0001000000000000
+        arr[0] = (high & mask) >> 48
+
+        # cdecimal only accepts a tuple for digits.
+        digits = tuple(int(digit) for digit in str(int.from_bytes(arr, "big")))
+
+        with decimal.localcontext(_DEC128_CTX) as ctx:
+            return ctx.create_decimal((sign, digits, exponent))
+
+    @classmethod
+    def from_bid(cls: Type[Decimal128], value: bytes) -> Decimal128:
+        """Create an instance of :class:`Decimal128` from Binary Integer
+        Decimal string.
+
+        :param value: 16 byte string (128-bit IEEE 754-2008 decimal floating
+            point in Binary Integer Decimal (BID) format).
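+
+        A round-trip sketch using the :attr:`bid` property (the value shown
+        is illustrative)::
+
+            >>> raw = Decimal128("0.0005").bid  # 16 bytes of BID data
+            >>> Decimal128.from_bid(raw)
+            Decimal128('0.0005')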
+ """ + if not isinstance(value, bytes): + raise TypeError("value must be an instance of bytes") + if len(value) != 16: + raise ValueError("value must be exactly 16 bytes") + return cls((_UNPACK_64(value[8:])[0], _UNPACK_64(value[:8])[0])) # type: ignore + + @property + def bid(self) -> bytes: + """The Binary Integer Decimal (BID) encoding of this instance.""" + return _PACK_64(self.__low) + _PACK_64(self.__high) + + def __str__(self) -> str: + dec = self.to_decimal() + if dec.is_nan(): + # Required by the drivers spec to match MongoDB behavior. + return "NaN" + return str(dec) + + def __repr__(self) -> str: + return f"Decimal128('{self!s}')" + + def __setstate__(self, value: Tuple[int, int]) -> None: + self.__high, self.__low = value + + def __getstate__(self) -> Tuple[int, int]: + return self.__high, self.__low + + def __eq__(self, other: Any) -> bool: + if isinstance(other, Decimal128): + return self.bid == other.bid + return NotImplemented + + def __ne__(self, other: Any) -> bool: + return not self == other diff --git a/venv/Lib/site-packages/bson/errors.py b/venv/Lib/site-packages/bson/errors.py new file mode 100644 index 00000000..a3699e70 --- /dev/null +++ b/venv/Lib/site-packages/bson/errors.py @@ -0,0 +1,36 @@ +# Copyright 2009-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Exceptions raised by the BSON package.""" +from __future__ import annotations + + +class BSONError(Exception): + """Base class for all BSON exceptions.""" + + +class InvalidBSON(BSONError): + """Raised when trying to create a BSON object from invalid data.""" + + +class InvalidStringData(BSONError): + """Raised when trying to encode a string containing non-UTF8 data.""" + + +class InvalidDocument(BSONError): + """Raised when trying to create a BSON object from an invalid document.""" + + +class InvalidId(BSONError): + """Raised when trying to create an ObjectId from invalid data.""" diff --git a/venv/Lib/site-packages/bson/int64.py b/venv/Lib/site-packages/bson/int64.py new file mode 100644 index 00000000..5846504a --- /dev/null +++ b/venv/Lib/site-packages/bson/int64.py @@ -0,0 +1,39 @@ +# Copyright 2014-2015 MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A BSON wrapper for long (int in python3)""" +from __future__ import annotations + +from typing import Any + + +class Int64(int): + """Representation of the BSON int64 type. + + This is necessary because every integral number is an :class:`int` in + Python 3. 
Small integral numbers are encoded to BSON int32 by default, + but Int64 numbers will always be encoded to BSON int64. + + :param value: the numeric value to represent + """ + + __slots__ = () + + _type_marker = 18 + + def __getstate__(self) -> Any: + return {} + + def __setstate__(self, state: Any) -> None: + pass diff --git a/venv/Lib/site-packages/bson/json_util.py b/venv/Lib/site-packages/bson/json_util.py new file mode 100644 index 00000000..6c5197c7 --- /dev/null +++ b/venv/Lib/site-packages/bson/json_util.py @@ -0,0 +1,1161 @@ +# Copyright 2009-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tools for using Python's :mod:`json` module with BSON documents. + +This module provides two helper methods `dumps` and `loads` that wrap the +native :mod:`json` methods and provide explicit BSON conversion to and from +JSON. :class:`~bson.json_util.JSONOptions` provides a way to control how JSON +is emitted and parsed, with the default being the Relaxed Extended JSON format. +:mod:`~bson.json_util` can also generate Canonical or legacy `Extended JSON`_ +when :const:`CANONICAL_JSON_OPTIONS` or :const:`LEGACY_JSON_OPTIONS` is +provided, respectively. + +.. _Extended JSON: https://github.com/mongodb/specifications/blob/master/source/extended-json.rst + +Example usage (deserialization): + +.. doctest:: + + >>> from bson.json_util import loads + >>> loads( + ... '[{"foo": [1, 2]}, {"bar": {"hello": "world"}}, {"code": {"$scope": {}, "$code": "function x() { return 1; }"}}, {"bin": {"$type": "80", "$binary": "AQIDBA=="}}]' + ... ) + [{'foo': [1, 2]}, {'bar': {'hello': 'world'}}, {'code': Code('function x() { return 1; }', {})}, {'bin': Binary(b'...', 128)}] + +Example usage with :const:`RELAXED_JSON_OPTIONS` (the default): + +.. doctest:: + + >>> from bson import Binary, Code + >>> from bson.json_util import dumps + >>> dumps( + ... [ + ... {"foo": [1, 2]}, + ... {"bar": {"hello": "world"}}, + ... {"code": Code("function x() { return 1; }")}, + ... {"bin": Binary(b"\x01\x02\x03\x04")}, + ... ] + ... ) + '[{"foo": [1, 2]}, {"bar": {"hello": "world"}}, {"code": {"$code": "function x() { return 1; }"}}, {"bin": {"$binary": {"base64": "AQIDBA==", "subType": "00"}}}]' + +Example usage (with :const:`CANONICAL_JSON_OPTIONS`): + +.. doctest:: + + >>> from bson import Binary, Code + >>> from bson.json_util import dumps, CANONICAL_JSON_OPTIONS + >>> dumps( + ... [ + ... {"foo": [1, 2]}, + ... {"bar": {"hello": "world"}}, + ... {"code": Code("function x() { return 1; }")}, + ... {"bin": Binary(b"\x01\x02\x03\x04")}, + ... ], + ... json_options=CANONICAL_JSON_OPTIONS, + ... ) + '[{"foo": [{"$numberInt": "1"}, {"$numberInt": "2"}]}, {"bar": {"hello": "world"}}, {"code": {"$code": "function x() { return 1; }"}}, {"bin": {"$binary": {"base64": "AQIDBA==", "subType": "00"}}}]' + +Example usage (with :const:`LEGACY_JSON_OPTIONS`): + +.. doctest:: + + >>> from bson import Binary, Code + >>> from bson.json_util import dumps, LEGACY_JSON_OPTIONS + >>> dumps( + ... [ + ... {"foo": [1, 2]}, + ... 
{"bar": {"hello": "world"}}, + ... {"code": Code("function x() { return 1; }", {})}, + ... {"bin": Binary(b"\x01\x02\x03\x04")}, + ... ], + ... json_options=LEGACY_JSON_OPTIONS, + ... ) + '[{"foo": [1, 2]}, {"bar": {"hello": "world"}}, {"code": {"$code": "function x() { return 1; }", "$scope": {}}}, {"bin": {"$binary": "AQIDBA==", "$type": "00"}}]' + +Alternatively, you can manually pass the `default` to :func:`json.dumps`. +It won't handle :class:`~bson.binary.Binary` and :class:`~bson.code.Code` +instances (as they are extended strings you can't provide custom defaults), +but it will be faster as there is less recursion. + +.. note:: + If your application does not need the flexibility offered by + :class:`JSONOptions` and spends a large amount of time in the `json_util` + module, look to + `python-bsonjs `_ for a nice + performance improvement. `python-bsonjs` is a fast BSON to MongoDB + Extended JSON converter for Python built on top of + `libbson `_. `python-bsonjs` works best + with PyMongo when using :class:`~bson.raw_bson.RawBSONDocument`. +""" +from __future__ import annotations + +import base64 +import datetime +import json +import math +import re +import uuid +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Mapping, + MutableMapping, + Optional, + Sequence, + Tuple, + Type, + Union, + cast, +) + +from bson.binary import ALL_UUID_SUBTYPES, UUID_SUBTYPE, Binary, UuidRepresentation +from bson.code import Code +from bson.codec_options import CodecOptions, DatetimeConversion +from bson.datetime_ms import ( + EPOCH_AWARE, + DatetimeMS, + _datetime_to_millis, + _max_datetime_ms, + _millis_to_datetime, +) +from bson.dbref import DBRef +from bson.decimal128 import Decimal128 +from bson.int64 import Int64 +from bson.max_key import MaxKey +from bson.min_key import MinKey +from bson.objectid import ObjectId +from bson.regex import Regex +from bson.son import RE_TYPE +from bson.timestamp import Timestamp +from bson.tz_util import utc + +_RE_OPT_TABLE = { + "i": re.I, + "l": re.L, + "m": re.M, + "s": re.S, + "u": re.U, + "x": re.X, +} + + +class DatetimeRepresentation: + LEGACY = 0 + """Legacy MongoDB Extended JSON datetime representation. + + :class:`datetime.datetime` instances will be encoded to JSON in the + format `{"$date": }`, where `dateAsMilliseconds` is + a 64-bit signed integer giving the number of milliseconds since the Unix + epoch UTC. This was the default encoding before PyMongo version 3.4. + + .. versionadded:: 3.4 + """ + + NUMBERLONG = 1 + """NumberLong datetime representation. + + :class:`datetime.datetime` instances will be encoded to JSON in the + format `{"$date": {"$numberLong": ""}}`, + where `dateAsMilliseconds` is the string representation of a 64-bit signed + integer giving the number of milliseconds since the Unix epoch UTC. + + .. versionadded:: 3.4 + """ + + ISO8601 = 2 + """ISO-8601 datetime representation. + + :class:`datetime.datetime` instances greater than or equal to the Unix + epoch UTC will be encoded to JSON in the format `{"$date": ""}`. + :class:`datetime.datetime` instances before the Unix epoch UTC will be + encoded as if the datetime representation is + :const:`~DatetimeRepresentation.NUMBERLONG`. + + .. versionadded:: 3.4 + """ + + +class JSONMode: + LEGACY = 0 + """Legacy Extended JSON representation. + + In this mode, :func:`~bson.json_util.dumps` produces PyMongo's legacy + non-standard JSON output. Consider using + :const:`~bson.json_util.JSONMode.RELAXED` or + :const:`~bson.json_util.JSONMode.CANONICAL` instead. + + .. 
versionadded:: 3.5 + """ + + RELAXED = 1 + """Relaxed Extended JSON representation. + + In this mode, :func:`~bson.json_util.dumps` produces Relaxed Extended JSON, + a mostly JSON-like format. Consider using this for things like a web API, + where one is sending a document (or a projection of a document) that only + uses ordinary JSON type primitives. In particular, the ``int``, + :class:`~bson.int64.Int64`, and ``float`` numeric types are represented in + the native JSON number format. This output is also the most human readable + and is useful for debugging and documentation. + + .. seealso:: The specification for Relaxed `Extended JSON`_. + + .. versionadded:: 3.5 + """ + + CANONICAL = 2 + """Canonical Extended JSON representation. + + In this mode, :func:`~bson.json_util.dumps` produces Canonical Extended + JSON, a type preserving format. Consider using this for things like + testing, where one has to precisely specify expected types in JSON. In + particular, the ``int``, :class:`~bson.int64.Int64`, and ``float`` numeric + types are encoded with type wrappers. + + .. seealso:: The specification for Canonical `Extended JSON`_. + + .. versionadded:: 3.5 + """ + + +if TYPE_CHECKING: + _BASE_CLASS = CodecOptions[MutableMapping[str, Any]] +else: + _BASE_CLASS = CodecOptions + +_INT32_MAX = 2**31 + + +class JSONOptions(_BASE_CLASS): + json_mode: int + strict_number_long: bool + datetime_representation: int + strict_uuid: bool + document_class: Type[MutableMapping[str, Any]] + + def __init__(self, *args: Any, **kwargs: Any): + """Encapsulates JSON options for :func:`dumps` and :func:`loads`. + + :param strict_number_long: If ``True``, :class:`~bson.int64.Int64` objects + are encoded to MongoDB Extended JSON's *Strict mode* type + `NumberLong`, ie ``'{"$numberLong": "" }'``. Otherwise they + will be encoded as an `int`. Defaults to ``False``. + :param datetime_representation: The representation to use when encoding + instances of :class:`datetime.datetime`. Defaults to + :const:`~DatetimeRepresentation.LEGACY`. + :param strict_uuid: If ``True``, :class:`uuid.UUID` object are encoded to + MongoDB Extended JSON's *Strict mode* type `Binary`. Otherwise it + will be encoded as ``'{"$uuid": "" }'``. Defaults to ``False``. + :param json_mode: The :class:`JSONMode` to use when encoding BSON types to + Extended JSON. Defaults to :const:`~JSONMode.LEGACY`. + :param document_class: BSON documents returned by :func:`loads` will be + decoded to an instance of this class. Must be a subclass of + :class:`collections.MutableMapping`. Defaults to :class:`dict`. + :param uuid_representation: The :class:`~bson.binary.UuidRepresentation` + to use when encoding and decoding instances of :class:`uuid.UUID`. + Defaults to :const:`~bson.binary.UuidRepresentation.UNSPECIFIED`. + :param tz_aware: If ``True``, MongoDB Extended JSON's *Strict mode* type + `Date` will be decoded to timezone aware instances of + :class:`datetime.datetime`. Otherwise they will be naive. Defaults + to ``False``. + :param tzinfo: A :class:`datetime.tzinfo` subclass that specifies the + timezone from which :class:`~datetime.datetime` objects should be + decoded. Defaults to :const:`~bson.tz_util.utc`. + :param datetime_conversion: Specifies how UTC datetimes should be decoded + within BSON. 
Valid options include 'datetime_ms' to return as a
+            DatetimeMS, 'datetime' to return as a datetime.datetime and
+            raising a ValueError for out-of-range values, 'datetime_auto' to
+            return DatetimeMS objects when the underlying datetime is
+            out-of-range and 'datetime_clamp' to clamp to the minimum and
+            maximum possible datetimes. Defaults to 'datetime'. See
+            :ref:`handling-out-of-range-datetimes` for details.
+        :param args: arguments to :class:`~bson.codec_options.CodecOptions`
+        :param kwargs: arguments to :class:`~bson.codec_options.CodecOptions`
+
+        .. seealso:: The specification for Relaxed and Canonical `Extended JSON`_.
+
+        .. versionchanged:: 4.0
+           The default for `json_mode` was changed from :const:`JSONMode.LEGACY`
+           to :const:`JSONMode.RELAXED`.
+           The default for `uuid_representation` was changed from
+           :const:`~bson.binary.UuidRepresentation.PYTHON_LEGACY` to
+           :const:`~bson.binary.UuidRepresentation.UNSPECIFIED`.
+
+        .. versionchanged:: 3.5
+           Accepts the optional parameter `json_mode`.
+
+        .. versionchanged:: 4.0
+           Changed default value of `tz_aware` to False.
+        """
+        super().__init__()
+
+    def __new__(
+        cls: Type[JSONOptions],
+        strict_number_long: Optional[bool] = None,
+        datetime_representation: Optional[int] = None,
+        strict_uuid: Optional[bool] = None,
+        json_mode: int = JSONMode.RELAXED,
+        *args: Any,
+        **kwargs: Any,
+    ) -> JSONOptions:
+        kwargs["tz_aware"] = kwargs.get("tz_aware", False)
+        if kwargs["tz_aware"]:
+            kwargs["tzinfo"] = kwargs.get("tzinfo", utc)
+        if datetime_representation not in (
+            DatetimeRepresentation.LEGACY,
+            DatetimeRepresentation.NUMBERLONG,
+            DatetimeRepresentation.ISO8601,
+            None,
+        ):
+            raise ValueError(
+                "JSONOptions.datetime_representation must be one of LEGACY, "
+                "NUMBERLONG, or ISO8601 from DatetimeRepresentation."
+            )
+        self = cast(JSONOptions, super().__new__(cls, *args, **kwargs))  # type:ignore[arg-type]
+        if json_mode not in (JSONMode.LEGACY, JSONMode.RELAXED, JSONMode.CANONICAL):
+            raise ValueError(
+                "JSONOptions.json_mode must be one of LEGACY, RELAXED, "
+                "or CANONICAL from JSONMode."
+            )
+        self.json_mode = json_mode
+        if self.json_mode == JSONMode.RELAXED:
+            if strict_number_long:
+                raise ValueError("Cannot specify strict_number_long=True with JSONMode.RELAXED")
+            if datetime_representation not in (None, DatetimeRepresentation.ISO8601):
+                raise ValueError(
+                    "datetime_representation must be DatetimeRepresentation."
+                    "ISO8601 or omitted with JSONMode.RELAXED"
+                )
+            if strict_uuid not in (None, True):
+                raise ValueError("Cannot specify strict_uuid=False with JSONMode.RELAXED")
+            self.strict_number_long = False
+            self.datetime_representation = DatetimeRepresentation.ISO8601
+            self.strict_uuid = True
+        elif self.json_mode == JSONMode.CANONICAL:
+            if strict_number_long not in (None, True):
+                raise ValueError("Cannot specify strict_number_long=False with JSONMode.CANONICAL")
+            if datetime_representation not in (None, DatetimeRepresentation.NUMBERLONG):
+                raise ValueError(
+                    "datetime_representation must be DatetimeRepresentation."
+                    "NUMBERLONG or omitted with JSONMode.CANONICAL"
+                )
+            if strict_uuid not in (None, True):
+                raise ValueError("Cannot specify strict_uuid=False with JSONMode.CANONICAL")
+            self.strict_number_long = True
+            self.datetime_representation = DatetimeRepresentation.NUMBERLONG
+            self.strict_uuid = True
+        else:  # JSONMode.LEGACY
+            self.strict_number_long = False
+            self.datetime_representation = DatetimeRepresentation.LEGACY
+            self.strict_uuid = False
+        if strict_number_long is not None:
+            self.strict_number_long = strict_number_long
+        if datetime_representation is not None:
+            self.datetime_representation = datetime_representation
+        if strict_uuid is not None:
+            self.strict_uuid = strict_uuid
+        return self
+
+    def _arguments_repr(self) -> str:
+        return (
+            "strict_number_long={!r}, "
+            "datetime_representation={!r}, "
+            "strict_uuid={!r}, json_mode={!r}, {}".format(
+                self.strict_number_long,
+                self.datetime_representation,
+                self.strict_uuid,
+                self.json_mode,
+                super()._arguments_repr(),
+            )
+        )
+
+    def _options_dict(self) -> dict[Any, Any]:
+        # TODO: PYTHON-2442 use _asdict() instead
+        options_dict = super()._options_dict()
+        options_dict.update(
+            {
+                "strict_number_long": self.strict_number_long,
+                "datetime_representation": self.datetime_representation,
+                "strict_uuid": self.strict_uuid,
+                "json_mode": self.json_mode,
+            }
+        )
+        return options_dict
+
+    def with_options(self, **kwargs: Any) -> JSONOptions:
+        """
+        Make a copy of this JSONOptions, overriding some options::
+
+            >>> from bson.json_util import CANONICAL_JSON_OPTIONS
+            >>> CANONICAL_JSON_OPTIONS.tz_aware
+            True
+            >>> json_options = CANONICAL_JSON_OPTIONS.with_options(tz_aware=False, tzinfo=None)
+            >>> json_options.tz_aware
+            False
+
+        .. versionadded:: 3.12
+        """
+        opts = self._options_dict()
+        for opt in ("strict_number_long", "datetime_representation", "strict_uuid", "json_mode"):
+            opts[opt] = kwargs.get(opt, getattr(self, opt))
+        opts.update(kwargs)
+        return JSONOptions(**opts)
+
+
+LEGACY_JSON_OPTIONS: JSONOptions = JSONOptions(json_mode=JSONMode.LEGACY)
+""":class:`JSONOptions` for encoding to PyMongo's legacy JSON format.
+
+.. seealso:: The documentation for :const:`bson.json_util.JSONMode.LEGACY`.
+
+.. versionadded:: 3.5
+"""
+
+CANONICAL_JSON_OPTIONS: JSONOptions = JSONOptions(json_mode=JSONMode.CANONICAL)
+""":class:`JSONOptions` for Canonical Extended JSON.
+
+.. seealso:: The documentation for :const:`bson.json_util.JSONMode.CANONICAL`.
+
+.. versionadded:: 3.5
+"""
+
+RELAXED_JSON_OPTIONS: JSONOptions = JSONOptions(json_mode=JSONMode.RELAXED)
+""":class:`JSONOptions` for Relaxed Extended JSON.
+
+.. seealso:: The documentation for :const:`bson.json_util.JSONMode.RELAXED`.
+
+.. versionadded:: 3.5
+"""
+
+DEFAULT_JSON_OPTIONS: JSONOptions = RELAXED_JSON_OPTIONS
+"""The default :class:`JSONOptions` for JSON encoding/decoding.
+
+The same as :const:`RELAXED_JSON_OPTIONS`.
+
+.. versionchanged:: 4.0
+   Changed from :const:`LEGACY_JSON_OPTIONS` to
+   :const:`RELAXED_JSON_OPTIONS`.
+
+.. versionadded:: 3.4
+"""
+
+
+def dumps(obj: Any, *args: Any, **kwargs: Any) -> str:
+    """Helper function that wraps :func:`json.dumps`.
+
+    Recursive function that handles all BSON types including
+    :class:`~bson.binary.Binary` and :class:`~bson.code.Code`.
+
+    :param json_options: A :class:`JSONOptions` instance used to modify the
+        encoding of MongoDB Extended JSON types. Defaults to
+        :const:`DEFAULT_JSON_OPTIONS`.
+
+    .. versionchanged:: 4.0
+       Now outputs MongoDB Relaxed Extended JSON by default (using
+       :const:`DEFAULT_JSON_OPTIONS`).
+
+    ..
versionchanged:: 3.4 + Accepts optional parameter `json_options`. See :class:`JSONOptions`. + """ + json_options = kwargs.pop("json_options", DEFAULT_JSON_OPTIONS) + return json.dumps(_json_convert(obj, json_options), *args, **kwargs) + + +def loads(s: Union[str, bytes, bytearray], *args: Any, **kwargs: Any) -> Any: + """Helper function that wraps :func:`json.loads`. + + Automatically passes the object_hook for BSON type conversion. + + Raises ``TypeError``, ``ValueError``, ``KeyError``, or + :exc:`~bson.errors.InvalidId` on invalid MongoDB Extended JSON. + + :param json_options: A :class:`JSONOptions` instance used to modify the + decoding of MongoDB Extended JSON types. Defaults to + :const:`DEFAULT_JSON_OPTIONS`. + + .. versionchanged:: 4.0 + Now loads :class:`datetime.datetime` instances as naive by default. To + load timezone aware instances utilize the `json_options` parameter. + See :ref:`tz_aware_default_change` for an example. + + .. versionchanged:: 3.5 + Parses Relaxed and Canonical Extended JSON as well as PyMongo's legacy + format. Now raises ``TypeError`` or ``ValueError`` when parsing JSON + type wrappers with values of the wrong type or any extra keys. + + .. versionchanged:: 3.4 + Accepts optional parameter `json_options`. See :class:`JSONOptions`. + """ + json_options = kwargs.pop("json_options", DEFAULT_JSON_OPTIONS) + # Execution time optimization if json_options.document_class is dict + if json_options.document_class is dict: + kwargs["object_hook"] = lambda obj: object_hook(obj, json_options) + else: + kwargs["object_pairs_hook"] = lambda pairs: object_pairs_hook(pairs, json_options) + return json.loads(s, *args, **kwargs) + + +def _json_convert(obj: Any, json_options: JSONOptions = DEFAULT_JSON_OPTIONS) -> Any: + """Recursive helper method that converts BSON types so they can be + converted into json. + """ + if hasattr(obj, "items"): + return {k: _json_convert(v, json_options) for k, v in obj.items()} + elif hasattr(obj, "__iter__") and not isinstance(obj, (str, bytes)): + return [_json_convert(v, json_options) for v in obj] + try: + return default(obj, json_options) + except TypeError: + return obj + + +def object_pairs_hook( + pairs: Sequence[Tuple[str, Any]], json_options: JSONOptions = DEFAULT_JSON_OPTIONS +) -> Any: + return object_hook(json_options.document_class(pairs), json_options) # type:ignore[call-arg] + + +def object_hook(dct: Mapping[str, Any], json_options: JSONOptions = DEFAULT_JSON_OPTIONS) -> Any: + match = None + for k in dct: + if k in _PARSERS_SET: + match = k + break + if match: + return _PARSERS[match](dct, json_options) + return dct + + +def _parse_legacy_regex(doc: Any, dummy0: Any) -> Any: + pattern = doc["$regex"] + # Check if this is the $regex query operator. + if not isinstance(pattern, (str, bytes)): + return doc + flags = 0 + # PyMongo always adds $options but some other tools may not. 
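+    # For example, {"$regex": "^a", "$options": "im"} turns each option letter
+    # into the matching Python re flag via _RE_OPT_TABLE.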
+ for opt in doc.get("$options", ""): + flags |= _RE_OPT_TABLE.get(opt, 0) + return Regex(pattern, flags) + + +def _parse_legacy_uuid(doc: Any, json_options: JSONOptions) -> Union[Binary, uuid.UUID]: + """Decode a JSON legacy $uuid to Python UUID.""" + if len(doc) != 1: + raise TypeError(f"Bad $uuid, extra field(s): {doc}") + if not isinstance(doc["$uuid"], str): + raise TypeError(f"$uuid must be a string: {doc}") + if json_options.uuid_representation == UuidRepresentation.UNSPECIFIED: + return Binary.from_uuid(uuid.UUID(doc["$uuid"])) + else: + return uuid.UUID(doc["$uuid"]) + + +def _binary_or_uuid(data: Any, subtype: int, json_options: JSONOptions) -> Union[Binary, uuid.UUID]: + # special handling for UUID + if subtype in ALL_UUID_SUBTYPES: + uuid_representation = json_options.uuid_representation + binary_value = Binary(data, subtype) + if uuid_representation == UuidRepresentation.UNSPECIFIED: + return binary_value + if subtype == UUID_SUBTYPE: + # Legacy behavior: use STANDARD with binary subtype 4. + uuid_representation = UuidRepresentation.STANDARD + elif uuid_representation == UuidRepresentation.STANDARD: + # subtype == OLD_UUID_SUBTYPE + # Legacy behavior: STANDARD is the same as PYTHON_LEGACY. + uuid_representation = UuidRepresentation.PYTHON_LEGACY + return binary_value.as_uuid(uuid_representation) + + if subtype == 0: + return cast(uuid.UUID, data) + return Binary(data, subtype) + + +def _parse_legacy_binary(doc: Any, json_options: JSONOptions) -> Union[Binary, uuid.UUID]: + if isinstance(doc["$type"], int): + doc["$type"] = "%02x" % doc["$type"] + subtype = int(doc["$type"], 16) + if subtype >= 0xFFFFFF80: # Handle mongoexport values + subtype = int(doc["$type"][6:], 16) + data = base64.b64decode(doc["$binary"].encode()) + return _binary_or_uuid(data, subtype, json_options) + + +def _parse_canonical_binary(doc: Any, json_options: JSONOptions) -> Union[Binary, uuid.UUID]: + binary = doc["$binary"] + b64 = binary["base64"] + subtype = binary["subType"] + if not isinstance(b64, str): + raise TypeError(f"$binary base64 must be a string: {doc}") + if not isinstance(subtype, str) or len(subtype) > 2: + raise TypeError(f"$binary subType must be a string at most 2 characters: {doc}") + if len(binary) != 2: + raise TypeError(f'$binary must include only "base64" and "subType" components: {doc}') + + data = base64.b64decode(b64.encode()) + return _binary_or_uuid(data, int(subtype, 16), json_options) + + +def _parse_canonical_datetime( + doc: Any, json_options: JSONOptions +) -> Union[datetime.datetime, DatetimeMS]: + """Decode a JSON datetime to python datetime.datetime.""" + dtm = doc["$date"] + if len(doc) != 1: + raise TypeError(f"Bad $date, extra field(s): {doc}") + # mongoexport 2.6 and newer + if isinstance(dtm, str): + # Parse offset + if dtm[-1] == "Z": + dt = dtm[:-1] + offset = "Z" + elif dtm[-6] in ("+", "-") and dtm[-3] == ":": + # (+|-)HH:MM + dt = dtm[:-6] + offset = dtm[-6:] + elif dtm[-5] in ("+", "-"): + # (+|-)HHMM + dt = dtm[:-5] + offset = dtm[-5:] + elif dtm[-3] in ("+", "-"): + # (+|-)HH + dt = dtm[:-3] + offset = dtm[-3:] + else: + dt = dtm + offset = "" + + # Parse the optional factional seconds portion. 
+ dot_index = dt.rfind(".") + microsecond = 0 + if dot_index != -1: + microsecond = int(float(dt[dot_index:]) * 1000000) + dt = dt[:dot_index] + + aware = datetime.datetime.strptime(dt, "%Y-%m-%dT%H:%M:%S").replace( + microsecond=microsecond, tzinfo=utc + ) + + if offset and offset != "Z": + if len(offset) == 6: + hours, minutes = offset[1:].split(":") + secs = int(hours) * 3600 + int(minutes) * 60 + elif len(offset) == 5: + secs = int(offset[1:3]) * 3600 + int(offset[3:]) * 60 + elif len(offset) == 3: + secs = int(offset[1:3]) * 3600 + if offset[0] == "-": + secs *= -1 + aware = aware - datetime.timedelta(seconds=secs) + + if json_options.tz_aware: + if json_options.tzinfo: + aware = aware.astimezone(json_options.tzinfo) + if json_options.datetime_conversion == DatetimeConversion.DATETIME_MS: + return DatetimeMS(aware) + return aware + else: + aware_tzinfo_none = aware.replace(tzinfo=None) + if json_options.datetime_conversion == DatetimeConversion.DATETIME_MS: + return DatetimeMS(aware_tzinfo_none) + return aware_tzinfo_none + return _millis_to_datetime(int(dtm), cast("CodecOptions[Any]", json_options)) + + +def _parse_canonical_oid(doc: Any, dummy0: Any) -> ObjectId: + """Decode a JSON ObjectId to bson.objectid.ObjectId.""" + if len(doc) != 1: + raise TypeError(f"Bad $oid, extra field(s): {doc}") + return ObjectId(doc["$oid"]) + + +def _parse_canonical_symbol(doc: Any, dummy0: Any) -> str: + """Decode a JSON symbol to Python string.""" + symbol = doc["$symbol"] + if len(doc) != 1: + raise TypeError(f"Bad $symbol, extra field(s): {doc}") + return str(symbol) + + +def _parse_canonical_code(doc: Any, dummy0: Any) -> Code: + """Decode a JSON code to bson.code.Code.""" + for key in doc: + if key not in ("$code", "$scope"): + raise TypeError(f"Bad $code, extra field(s): {doc}") + return Code(doc["$code"], scope=doc.get("$scope")) + + +def _parse_canonical_regex(doc: Any, dummy0: Any) -> Regex[str]: + """Decode a JSON regex to bson.regex.Regex.""" + regex = doc["$regularExpression"] + if len(doc) != 1: + raise TypeError(f"Bad $regularExpression, extra field(s): {doc}") + if len(regex) != 2: + raise TypeError( + f'Bad $regularExpression must include only "pattern and "options" components: {doc}' + ) + opts = regex["options"] + if not isinstance(opts, str): + raise TypeError( + "Bad $regularExpression options, options must be string, was type %s" % (type(opts)) + ) + return Regex(regex["pattern"], opts) + + +def _parse_canonical_dbref(doc: Any, dummy0: Any) -> Any: + """Decode a JSON DBRef to bson.dbref.DBRef.""" + if ( + isinstance(doc.get("$ref"), str) + and "$id" in doc + and isinstance(doc.get("$db"), (str, type(None))) + ): + return DBRef(doc.pop("$ref"), doc.pop("$id"), database=doc.pop("$db", None), **doc) + return doc + + +def _parse_canonical_dbpointer(doc: Any, dummy0: Any) -> Any: + """Decode a JSON (deprecated) DBPointer to bson.dbref.DBRef.""" + dbref = doc["$dbPointer"] + if len(doc) != 1: + raise TypeError(f"Bad $dbPointer, extra field(s): {doc}") + if isinstance(dbref, DBRef): + dbref_doc = dbref.as_doc() + # DBPointer must not contain $db in its value. 
+        if dbref.database is not None:
+            raise TypeError(f"Bad $dbPointer, extra field $db: {dbref_doc}")
+        if not isinstance(dbref.id, ObjectId):
+            raise TypeError(f"Bad $dbPointer, $id must be an ObjectId: {dbref_doc}")
+        if len(dbref_doc) != 2:
+            raise TypeError(f"Bad $dbPointer, extra field(s) in DBRef: {dbref_doc}")
+        return dbref
+    else:
+        raise TypeError(f"Bad $dbPointer, expected a DBRef: {doc}")
+
+
+def _parse_canonical_int32(doc: Any, dummy0: Any) -> int:
+    """Decode a JSON int32 to python int."""
+    i_str = doc["$numberInt"]
+    if len(doc) != 1:
+        raise TypeError(f"Bad $numberInt, extra field(s): {doc}")
+    if not isinstance(i_str, str):
+        raise TypeError(f"$numberInt must be string: {doc}")
+    return int(i_str)
+
+
+def _parse_canonical_int64(doc: Any, dummy0: Any) -> Int64:
+    """Decode a JSON int64 to bson.int64.Int64."""
+    l_str = doc["$numberLong"]
+    if len(doc) != 1:
+        raise TypeError(f"Bad $numberLong, extra field(s): {doc}")
+    return Int64(l_str)
+
+
+def _parse_canonical_double(doc: Any, dummy0: Any) -> float:
+    """Decode a JSON double to python float."""
+    d_str = doc["$numberDouble"]
+    if len(doc) != 1:
+        raise TypeError(f"Bad $numberDouble, extra field(s): {doc}")
+    if not isinstance(d_str, str):
+        raise TypeError(f"$numberDouble must be string: {doc}")
+    return float(d_str)
+
+
+def _parse_canonical_decimal128(doc: Any, dummy0: Any) -> Decimal128:
+    """Decode a JSON decimal128 to bson.decimal128.Decimal128."""
+    d_str = doc["$numberDecimal"]
+    if len(doc) != 1:
+        raise TypeError(f"Bad $numberDecimal, extra field(s): {doc}")
+    if not isinstance(d_str, str):
+        raise TypeError(f"$numberDecimal must be string: {doc}")
+    return Decimal128(d_str)
+
+
+def _parse_canonical_minkey(doc: Any, dummy0: Any) -> MinKey:
+    """Decode a JSON MinKey to bson.min_key.MinKey."""
+    if type(doc["$minKey"]) is not int or doc["$minKey"] != 1:  # noqa: E721
+        raise TypeError(f"$minKey value must be 1: {doc}")
+    if len(doc) != 1:
+        raise TypeError(f"Bad $minKey, extra field(s): {doc}")
+    return MinKey()
+
+
+def _parse_canonical_maxkey(doc: Any, dummy0: Any) -> MaxKey:
+    """Decode a JSON MaxKey to bson.max_key.MaxKey."""
+    if type(doc["$maxKey"]) is not int or doc["$maxKey"] != 1:  # noqa: E721
+        raise TypeError(f"$maxKey value must be 1: {doc}")
+    if len(doc) != 1:
+        raise TypeError(f"Bad $maxKey, extra field(s): {doc}")
+    return MaxKey()
+
+
+def _parse_binary(doc: Any, json_options: JSONOptions) -> Union[Binary, uuid.UUID]:
+    if "$type" in doc:
+        return _parse_legacy_binary(doc, json_options)
+    else:
+        return _parse_canonical_binary(doc, json_options)
+
+
+def _parse_timestamp(doc: Any, dummy0: Any) -> Timestamp:
+    tsp = doc["$timestamp"]
+    return Timestamp(tsp["t"], tsp["i"])
+
+
+_PARSERS: dict[str, Callable[[Any, JSONOptions], Any]] = {
+    "$oid": _parse_canonical_oid,
+    "$ref": _parse_canonical_dbref,
+    "$date": _parse_canonical_datetime,
+    "$regex": _parse_legacy_regex,
+    "$minKey": _parse_canonical_minkey,
+    "$maxKey": _parse_canonical_maxkey,
+    "$binary": _parse_binary,
+    "$code": _parse_canonical_code,
+    "$uuid": _parse_legacy_uuid,
+    "$undefined": lambda _, _1: None,
+    "$numberLong": _parse_canonical_int64,
+    "$timestamp": _parse_timestamp,
+    "$numberDecimal": _parse_canonical_decimal128,
+    "$dbPointer": _parse_canonical_dbpointer,
+    "$regularExpression": _parse_canonical_regex,
+    "$symbol": _parse_canonical_symbol,
+    "$numberInt": _parse_canonical_int32,
+    "$numberDouble": _parse_canonical_double,
+}
+_PARSERS_SET = set(_PARSERS)
+
+
+def _encode_binary(data: bytes, subtype: int,
json_options: JSONOptions) -> Any: + if json_options.json_mode == JSONMode.LEGACY: + return {"$binary": base64.b64encode(data).decode(), "$type": "%02x" % subtype} + return {"$binary": {"base64": base64.b64encode(data).decode(), "subType": "%02x" % subtype}} + + +def _encode_datetimems(obj: Any, json_options: JSONOptions) -> dict: + if ( + json_options.datetime_representation == DatetimeRepresentation.ISO8601 + and 0 <= int(obj) <= _max_datetime_ms() + ): + return _encode_datetime(obj.as_datetime(), json_options) + elif json_options.datetime_representation == DatetimeRepresentation.LEGACY: + return {"$date": str(int(obj))} + return {"$date": {"$numberLong": str(int(obj))}} + + +def _encode_code(obj: Code, json_options: JSONOptions) -> dict: + if obj.scope is None: + return {"$code": str(obj)} + else: + return {"$code": str(obj), "$scope": _json_convert(obj.scope, json_options)} + + +def _encode_int64(obj: Int64, json_options: JSONOptions) -> Any: + if json_options.strict_number_long: + return {"$numberLong": str(obj)} + else: + return int(obj) + + +def _encode_noop(obj: Any, dummy0: Any) -> Any: + return obj + + +def _encode_regex(obj: Any, json_options: JSONOptions) -> dict: + flags = "" + if obj.flags & re.IGNORECASE: + flags += "i" + if obj.flags & re.LOCALE: + flags += "l" + if obj.flags & re.MULTILINE: + flags += "m" + if obj.flags & re.DOTALL: + flags += "s" + if obj.flags & re.UNICODE: + flags += "u" + if obj.flags & re.VERBOSE: + flags += "x" + if isinstance(obj.pattern, str): + pattern = obj.pattern + else: + pattern = obj.pattern.decode("utf-8") + if json_options.json_mode == JSONMode.LEGACY: + return {"$regex": pattern, "$options": flags} + return {"$regularExpression": {"pattern": pattern, "options": flags}} + + +def _encode_int(obj: int, json_options: JSONOptions) -> Any: + if json_options.json_mode == JSONMode.CANONICAL: + if -_INT32_MAX <= obj < _INT32_MAX: + return {"$numberInt": str(obj)} + return {"$numberLong": str(obj)} + return obj + + +def _encode_float(obj: float, json_options: JSONOptions) -> Any: + if json_options.json_mode != JSONMode.LEGACY: + if math.isnan(obj): + return {"$numberDouble": "NaN"} + elif math.isinf(obj): + representation = "Infinity" if obj > 0 else "-Infinity" + return {"$numberDouble": representation} + elif json_options.json_mode == JSONMode.CANONICAL: + # repr() will return the shortest string guaranteed to produce the + # original value, when float() is called on it. 
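+        # For instance, repr(0.1) == '0.1', and float('0.1') recovers the exact
+        # same IEEE-754 double, so the canonical form loses no precision.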
+ return {"$numberDouble": str(repr(obj))} + return obj + + +def _encode_datetime(obj: datetime.datetime, json_options: JSONOptions) -> dict: + if json_options.datetime_representation == DatetimeRepresentation.ISO8601: + if not obj.tzinfo: + obj = obj.replace(tzinfo=utc) + assert obj.tzinfo is not None + if obj >= EPOCH_AWARE: + off = obj.tzinfo.utcoffset(obj) + if (off.days, off.seconds, off.microseconds) == (0, 0, 0): # type: ignore + tz_string = "Z" + else: + tz_string = obj.strftime("%z") + millis = int(obj.microsecond / 1000) + fracsecs = ".%03d" % (millis,) if millis else "" + return { + "$date": "{}{}{}".format(obj.strftime("%Y-%m-%dT%H:%M:%S"), fracsecs, tz_string) + } + + millis = _datetime_to_millis(obj) + if json_options.datetime_representation == DatetimeRepresentation.LEGACY: + return {"$date": millis} + return {"$date": {"$numberLong": str(millis)}} + + +def _encode_bytes(obj: bytes, json_options: JSONOptions) -> dict: + return _encode_binary(obj, 0, json_options) + + +def _encode_binary_obj(obj: Binary, json_options: JSONOptions) -> dict: + return _encode_binary(obj, obj.subtype, json_options) + + +def _encode_uuid(obj: uuid.UUID, json_options: JSONOptions) -> dict: + if json_options.strict_uuid: + binval = Binary.from_uuid(obj, uuid_representation=json_options.uuid_representation) + return _encode_binary(binval, binval.subtype, json_options) + else: + return {"$uuid": obj.hex} + + +def _encode_objectid(obj: ObjectId, dummy0: Any) -> dict: + return {"$oid": str(obj)} + + +def _encode_timestamp(obj: Timestamp, dummy0: Any) -> dict: + return {"$timestamp": {"t": obj.time, "i": obj.inc}} + + +def _encode_decimal128(obj: Timestamp, dummy0: Any) -> dict: + return {"$numberDecimal": str(obj)} + + +def _encode_dbref(obj: DBRef, json_options: JSONOptions) -> dict: + return _json_convert(obj.as_doc(), json_options=json_options) + + +def _encode_minkey(dummy0: Any, dummy1: Any) -> dict: + return {"$minKey": 1} + + +def _encode_maxkey(dummy0: Any, dummy1: Any) -> dict: + return {"$maxKey": 1} + + +# Encoders for BSON types +# Each encoder function's signature is: +# - obj: a Python data type, e.g. a Python int for _encode_int +# - json_options: a JSONOptions +_ENCODERS: dict[Type, Callable[[Any, JSONOptions], Any]] = { + bool: _encode_noop, + bytes: _encode_bytes, + datetime.datetime: _encode_datetime, + DatetimeMS: _encode_datetimems, + float: _encode_float, + int: _encode_int, + str: _encode_noop, + type(None): _encode_noop, + uuid.UUID: _encode_uuid, + Binary: _encode_binary_obj, + Int64: _encode_int64, + Code: _encode_code, + DBRef: _encode_dbref, + MaxKey: _encode_maxkey, + MinKey: _encode_minkey, + ObjectId: _encode_objectid, + Regex: _encode_regex, + RE_TYPE: _encode_regex, + Timestamp: _encode_timestamp, + Decimal128: _encode_decimal128, +} + +# Map each _type_marker to its encoder for faster lookup. +_MARKERS: dict[int, Callable[[Any, JSONOptions], Any]] = {} +for _typ in _ENCODERS: + if hasattr(_typ, "_type_marker"): + _MARKERS[_typ._type_marker] = _ENCODERS[_typ] + +_BUILT_IN_TYPES = tuple(t for t in _ENCODERS) + + +def default(obj: Any, json_options: JSONOptions = DEFAULT_JSON_OPTIONS) -> Any: + # First see if the type is already cached. KeyError will only ever + # happen once per subtype. + try: + return _ENCODERS[type(obj)](obj, json_options) + except KeyError: + pass + + # Second, fall back to trying _type_marker. This has to be done + # before the loop below since users could subclass one of our + # custom types that subclasses a python built-in (e.g. 
Binary) + if hasattr(obj, "_type_marker"): + marker = obj._type_marker + if marker in _MARKERS: + func = _MARKERS[marker] + # Cache this type for faster subsequent lookup. + _ENCODERS[type(obj)] = func + return func(obj, json_options) + + # Third, test each base type. This will only happen once for + # a subtype of a supported base type. + for base in _BUILT_IN_TYPES: + if isinstance(obj, base): + func = _ENCODERS[base] + # Cache this type for faster subsequent lookup. + _ENCODERS[type(obj)] = func + return func(obj, json_options) + + raise TypeError("%r is not JSON serializable" % obj) + + +def _get_str_size(obj: Any) -> int: + return len(obj) + + +def _get_datetime_size(obj: datetime.datetime) -> int: + return 5 + len(str(obj.time())) + + +def _get_regex_size(obj: Regex) -> int: + return 18 + len(obj.pattern) + + +def _get_dbref_size(obj: DBRef) -> int: + return 34 + len(obj.collection) + + +_CONSTANT_SIZE_TABLE: dict[Any, int] = { + ObjectId: 28, + int: 11, + Int64: 11, + Decimal128: 11, + Timestamp: 14, + MinKey: 8, + MaxKey: 8, +} + +_VARIABLE_SIZE_TABLE: dict[Any, Callable[[Any], int]] = { + str: _get_str_size, + bytes: _get_str_size, + datetime.datetime: _get_datetime_size, + Regex: _get_regex_size, + DBRef: _get_dbref_size, +} + + +def get_size(obj: Any, max_size: int, current_size: int = 0) -> int: + """Recursively finds size of objects""" + if current_size >= max_size: + return current_size + + obj_type = type(obj) + + # Check to see if the obj has a constant size estimate + try: + return _CONSTANT_SIZE_TABLE[obj_type] + except KeyError: + pass + + # Check to see if the obj has a variable but simple size estimate + try: + return _VARIABLE_SIZE_TABLE[obj_type](obj) + except KeyError: + pass + + # Special cases that require recursion + if obj_type == Code: + if obj.scope: + current_size += ( + 5 + get_size(obj.scope, max_size, current_size) + len(obj) - len(obj.scope) + ) + else: + current_size += 5 + len(obj) + elif obj_type == dict: + for k, v in obj.items(): + current_size += get_size(k, max_size, current_size) + current_size += get_size(v, max_size, current_size) + if current_size >= max_size: + return current_size + elif hasattr(obj, "__iter__"): + for i in obj: + current_size += get_size(i, max_size, current_size) + if current_size >= max_size: + return current_size + return current_size + + +def _truncate_documents(obj: Any, max_length: int) -> Tuple[Any, int]: + """Recursively truncate documents as needed to fit inside max_length characters.""" + if max_length <= 0: + return None, 0 + remaining = max_length + if hasattr(obj, "items"): + truncated: Any = {} + for k, v in obj.items(): + truncated_v, remaining = _truncate_documents(v, remaining) + if truncated_v: + truncated[k] = truncated_v + if remaining <= 0: + break + return truncated, remaining + elif hasattr(obj, "__iter__") and not isinstance(obj, (str, bytes)): + truncated: Any = [] # type:ignore[no-redef] + for v in obj: + truncated_v, remaining = _truncate_documents(v, remaining) + if truncated_v: + truncated.append(truncated_v) + if remaining <= 0: + break + return truncated, remaining + else: + return _truncate(obj, remaining) + + +def _truncate(obj: Any, remaining: int) -> Tuple[Any, int]: + size = get_size(obj, remaining) + + if size <= remaining: + return obj, remaining - size + else: + try: + truncated = obj[:remaining] + except TypeError: + truncated = obj + return truncated, remaining - size diff --git a/venv/Lib/site-packages/bson/max_key.py b/venv/Lib/site-packages/bson/max_key.py new file mode 100644 index 
00000000..445e12f5 --- /dev/null +++ b/venv/Lib/site-packages/bson/max_key.py @@ -0,0 +1,56 @@ +# Copyright 2010-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Representation for the MongoDB internal MaxKey type.""" +from __future__ import annotations + +from typing import Any + + +class MaxKey: + """MongoDB internal MaxKey type.""" + + __slots__ = () + + _type_marker = 127 + + def __getstate__(self) -> Any: + return {} + + def __setstate__(self, state: Any) -> None: + pass + + def __eq__(self, other: Any) -> bool: + return isinstance(other, MaxKey) + + def __hash__(self) -> int: + return hash(self._type_marker) + + def __ne__(self, other: Any) -> bool: + return not self == other + + def __le__(self, other: Any) -> bool: + return isinstance(other, MaxKey) + + def __lt__(self, dummy: Any) -> bool: + return False + + def __ge__(self, dummy: Any) -> bool: + return True + + def __gt__(self, other: Any) -> bool: + return not isinstance(other, MaxKey) + + def __repr__(self) -> str: + return "MaxKey()" diff --git a/venv/Lib/site-packages/bson/min_key.py b/venv/Lib/site-packages/bson/min_key.py new file mode 100644 index 00000000..37828dcf --- /dev/null +++ b/venv/Lib/site-packages/bson/min_key.py @@ -0,0 +1,56 @@ +# Copyright 2010-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Representation for the MongoDB internal MinKey type.""" +from __future__ import annotations + +from typing import Any + + +class MinKey: + """MongoDB internal MinKey type.""" + + __slots__ = () + + _type_marker = 255 + + def __getstate__(self) -> Any: + return {} + + def __setstate__(self, state: Any) -> None: + pass + + def __eq__(self, other: Any) -> bool: + return isinstance(other, MinKey) + + def __hash__(self) -> int: + return hash(self._type_marker) + + def __ne__(self, other: Any) -> bool: + return not self == other + + def __le__(self, dummy: Any) -> bool: + return True + + def __lt__(self, other: Any) -> bool: + return not isinstance(other, MinKey) + + def __ge__(self, other: Any) -> bool: + return isinstance(other, MinKey) + + def __gt__(self, dummy: Any) -> bool: + return False + + def __repr__(self) -> str: + return "MinKey()" diff --git a/venv/Lib/site-packages/bson/objectid.py b/venv/Lib/site-packages/bson/objectid.py new file mode 100644 index 00000000..57efdc79 --- /dev/null +++ b/venv/Lib/site-packages/bson/objectid.py @@ -0,0 +1,278 @@ +# Copyright 2009-2015 MongoDB, Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tools for working with MongoDB ObjectIds.""" +from __future__ import annotations + +import binascii +import calendar +import datetime +import os +import struct +import threading +import time +from random import SystemRandom +from typing import Any, NoReturn, Optional, Type, Union + +from bson.errors import InvalidId +from bson.tz_util import utc + +_MAX_COUNTER_VALUE = 0xFFFFFF + + +def _raise_invalid_id(oid: str) -> NoReturn: + raise InvalidId( + "%r is not a valid ObjectId, it must be a 12-byte input" + " or a 24-character hex string" % oid + ) + + +def _random_bytes() -> bytes: + """Get the 5-byte random field of an ObjectId.""" + return os.urandom(5) + + +class ObjectId: + """A MongoDB ObjectId.""" + + _pid = os.getpid() + + _inc = SystemRandom().randint(0, _MAX_COUNTER_VALUE) + _inc_lock = threading.Lock() + + __random = _random_bytes() + + __slots__ = ("__id",) + + _type_marker = 7 + + def __init__(self, oid: Optional[Union[str, ObjectId, bytes]] = None) -> None: + """Initialize a new ObjectId. + + An ObjectId is a 12-byte unique identifier consisting of: + + - a 4-byte value representing the seconds since the Unix epoch, + - a 5-byte random value, + - a 3-byte counter, starting with a random value. + + By default, ``ObjectId()`` creates a new unique identifier. The + optional parameter `oid` can be an :class:`ObjectId`, or any 12 + :class:`bytes`. + + For example, the 12 bytes b'foo-bar-quux' do not follow the ObjectId + specification but they are acceptable input:: + + >>> ObjectId(b'foo-bar-quux') + ObjectId('666f6f2d6261722d71757578') + + `oid` can also be a :class:`str` of 24 hex digits:: + + >>> ObjectId('0123456789ab0123456789ab') + ObjectId('0123456789ab0123456789ab') + + Raises :class:`~bson.errors.InvalidId` if `oid` is not 12 bytes nor + 24 hex digits, or :class:`TypeError` if `oid` is not an accepted type. + + :param oid: a valid ObjectId. + + .. seealso:: The MongoDB documentation on `ObjectIds `_. + + .. versionchanged:: 3.8 + :class:`~bson.objectid.ObjectId` now implements the `ObjectID + specification version 0.2 + `_. + """ + if oid is None: + self.__generate() + elif isinstance(oid, bytes) and len(oid) == 12: + self.__id = oid + else: + self.__validate(oid) + + @classmethod + def from_datetime(cls: Type[ObjectId], generation_time: datetime.datetime) -> ObjectId: + """Create a dummy ObjectId instance with a specific generation time. + + This method is useful for doing range queries on a field + containing :class:`ObjectId` instances. + + .. warning:: + It is not safe to insert a document containing an ObjectId + generated using this method. This method deliberately + eliminates the uniqueness guarantee that ObjectIds + generally provide. ObjectIds generated with this method + should be used exclusively in queries. + + `generation_time` will be converted to UTC. Naive datetime + instances will be treated as though they already contain UTC. 
+ + An example using this helper to get documents where ``"_id"`` + was generated before January 1, 2010 would be: + + >>> gen_time = datetime.datetime(2010, 1, 1) + >>> dummy_id = ObjectId.from_datetime(gen_time) + >>> result = collection.find({"_id": {"$lt": dummy_id}}) + + :param generation_time: :class:`~datetime.datetime` to be used + as the generation time for the resulting ObjectId. + """ + offset = generation_time.utcoffset() + if offset is not None: + generation_time = generation_time - offset + timestamp = calendar.timegm(generation_time.timetuple()) + oid = struct.pack(">I", int(timestamp)) + b"\x00\x00\x00\x00\x00\x00\x00\x00" + return cls(oid) + + @classmethod + def is_valid(cls: Type[ObjectId], oid: Any) -> bool: + """Checks if a `oid` string is valid or not. + + :param oid: the object id to validate + + .. versionadded:: 2.3 + """ + if not oid: + return False + + try: + ObjectId(oid) + return True + except (InvalidId, TypeError): + return False + + @classmethod + def _random(cls) -> bytes: + """Generate a 5-byte random number once per process.""" + pid = os.getpid() + if pid != cls._pid: + cls._pid = pid + cls.__random = _random_bytes() + return cls.__random + + def __generate(self) -> None: + """Generate a new value for this ObjectId.""" + # 4 bytes current time + oid = struct.pack(">I", int(time.time())) + + # 5 bytes random + oid += ObjectId._random() + + # 3 bytes inc + with ObjectId._inc_lock: + oid += struct.pack(">I", ObjectId._inc)[1:4] + ObjectId._inc = (ObjectId._inc + 1) % (_MAX_COUNTER_VALUE + 1) + + self.__id = oid + + def __validate(self, oid: Any) -> None: + """Validate and use the given id for this ObjectId. + + Raises TypeError if id is not an instance of :class:`str`, + :class:`bytes`, or ObjectId. Raises InvalidId if it is not a + valid ObjectId. + + :param oid: a valid ObjectId + """ + if isinstance(oid, ObjectId): + self.__id = oid.binary + elif isinstance(oid, str): + if len(oid) == 24: + try: + self.__id = bytes.fromhex(oid) + except (TypeError, ValueError): + _raise_invalid_id(oid) + else: + _raise_invalid_id(oid) + else: + raise TypeError(f"id must be an instance of (bytes, str, ObjectId), not {type(oid)}") + + @property + def binary(self) -> bytes: + """12-byte binary representation of this ObjectId.""" + return self.__id + + @property + def generation_time(self) -> datetime.datetime: + """A :class:`datetime.datetime` instance representing the time of + generation for this :class:`ObjectId`. + + The :class:`datetime.datetime` is timezone aware, and + represents the generation time in UTC. It is precise to the + second. + """ + timestamp = struct.unpack(">I", self.__id[0:4])[0] + return datetime.datetime.fromtimestamp(timestamp, utc) + + def __getstate__(self) -> bytes: + """Return value of object for pickling. + needed explicitly because __slots__() defined. + """ + return self.__id + + def __setstate__(self, value: Any) -> None: + """Explicit state set from pickling""" + # Provide backwards compatibility with OIDs + # pickled with pymongo-1.9 or older. + if isinstance(value, dict): + oid = value["_ObjectId__id"] + else: + oid = value + # ObjectIds pickled in python 2.x used `str` for __id. + # In python 3.x this has to be converted to `bytes` + # by encoding latin-1. 
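+        # latin-1 maps code points 0-255 one-to-one onto byte values, so the
+        # original 12 bytes are recovered exactly (e.g. '\xff' -> b'\xff').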
+ if isinstance(oid, str): + self.__id = oid.encode("latin-1") + else: + self.__id = oid + + def __str__(self) -> str: + return binascii.hexlify(self.__id).decode() + + def __repr__(self) -> str: + return f"ObjectId('{self!s}')" + + def __eq__(self, other: Any) -> bool: + if isinstance(other, ObjectId): + return self.__id == other.binary + return NotImplemented + + def __ne__(self, other: Any) -> bool: + if isinstance(other, ObjectId): + return self.__id != other.binary + return NotImplemented + + def __lt__(self, other: Any) -> bool: + if isinstance(other, ObjectId): + return self.__id < other.binary + return NotImplemented + + def __le__(self, other: Any) -> bool: + if isinstance(other, ObjectId): + return self.__id <= other.binary + return NotImplemented + + def __gt__(self, other: Any) -> bool: + if isinstance(other, ObjectId): + return self.__id > other.binary + return NotImplemented + + def __ge__(self, other: Any) -> bool: + if isinstance(other, ObjectId): + return self.__id >= other.binary + return NotImplemented + + def __hash__(self) -> int: + """Get a hash value for this :class:`ObjectId`.""" + return hash(self.__id) diff --git a/venv/Lib/site-packages/bson/py.typed b/venv/Lib/site-packages/bson/py.typed new file mode 100644 index 00000000..0f405706 --- /dev/null +++ b/venv/Lib/site-packages/bson/py.typed @@ -0,0 +1,2 @@ +# PEP-561 Support File. +# "Package maintainers who wish to support type checking of their code MUST add a marker file named py.typed to their package supporting typing". diff --git a/venv/Lib/site-packages/bson/raw_bson.py b/venv/Lib/site-packages/bson/raw_bson.py new file mode 100644 index 00000000..2ce53143 --- /dev/null +++ b/venv/Lib/site-packages/bson/raw_bson.py @@ -0,0 +1,196 @@ +# Copyright 2015-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tools for representing raw BSON documents. + +Inserting and Retrieving RawBSONDocuments +========================================= + +Example: Moving a document between different databases/collections + +.. doctest:: + + >>> import bson + >>> from pymongo import MongoClient + >>> from bson.raw_bson import RawBSONDocument + >>> client = MongoClient(document_class=RawBSONDocument) + >>> client.drop_database("db") + >>> client.drop_database("replica_db") + >>> db = client.db + >>> result = db.test.insert_many( + ... [{"_id": 1, "a": 1}, {"_id": 2, "b": 1}, {"_id": 3, "c": 1}, {"_id": 4, "d": 1}] + ... ) + >>> replica_db = client.replica_db + >>> for doc in db.test.find(): + ... print(f"raw document: {doc.raw}") + ... print(f"decoded document: {bson.decode(doc.raw)}") + ... result = replica_db.test.insert_one(doc) + ... + raw document: b'...' + decoded document: {'_id': 1, 'a': 1} + raw document: b'...' + decoded document: {'_id': 2, 'b': 1} + raw document: b'...' + decoded document: {'_id': 3, 'c': 1} + raw document: b'...' 
+ decoded document: {'_id': 4, 'd': 1} + +For use cases like moving documents across different databases or writing binary +blobs to disk, using raw BSON documents provides better speed and avoids the +overhead of decoding or encoding BSON. +""" +from __future__ import annotations + +from typing import Any, ItemsView, Iterator, Mapping, Optional + +from bson import _get_object_size, _raw_to_dict +from bson.codec_options import _RAW_BSON_DOCUMENT_MARKER, CodecOptions +from bson.codec_options import DEFAULT_CODEC_OPTIONS as DEFAULT + + +def _inflate_bson( + bson_bytes: bytes, codec_options: CodecOptions[RawBSONDocument], raw_array: bool = False +) -> dict[str, Any]: + """Inflates the top level fields of a BSON document. + + :param bson_bytes: the BSON bytes that compose this document + :param codec_options: An instance of + :class:`~bson.codec_options.CodecOptions` whose ``document_class`` + must be :class:`RawBSONDocument`. + """ + return _raw_to_dict(bson_bytes, 4, len(bson_bytes) - 1, codec_options, {}, raw_array=raw_array) + + +class RawBSONDocument(Mapping[str, Any]): + """Representation for a MongoDB document that provides access to the raw + BSON bytes that compose it. + + Only when a field is accessed or modified within the document does + RawBSONDocument decode its bytes. + """ + + __slots__ = ("__raw", "__inflated_doc", "__codec_options") + _type_marker = _RAW_BSON_DOCUMENT_MARKER + __codec_options: CodecOptions[RawBSONDocument] + + def __init__( + self, bson_bytes: bytes, codec_options: Optional[CodecOptions[RawBSONDocument]] = None + ) -> None: + """Create a new :class:`RawBSONDocument` + + :class:`RawBSONDocument` is a representation of a BSON document that + provides access to the underlying raw BSON bytes. Only when a field is + accessed or modified within the document does RawBSONDocument decode + its bytes. + + :class:`RawBSONDocument` implements the ``Mapping`` abstract base + class from the standard library so it can be used like a read-only + ``dict``:: + + >>> from bson import encode + >>> raw_doc = RawBSONDocument(encode({'_id': 'my_doc'})) + >>> raw_doc.raw + b'...' + >>> raw_doc['_id'] + 'my_doc' + + :param bson_bytes: the BSON bytes that compose this document + :param codec_options: An instance of + :class:`~bson.codec_options.CodecOptions` whose ``document_class`` + must be :class:`RawBSONDocument`. The default is + :attr:`DEFAULT_RAW_BSON_OPTIONS`. + + .. versionchanged:: 3.8 + :class:`RawBSONDocument` now validates that the ``bson_bytes`` + passed in represent a single bson document. + + .. versionchanged:: 3.5 + If a :class:`~bson.codec_options.CodecOptions` is passed in, its + `document_class` must be :class:`RawBSONDocument`. + """ + self.__raw = bson_bytes + self.__inflated_doc: Optional[Mapping[str, Any]] = None + # Can't default codec_options to DEFAULT_RAW_BSON_OPTIONS in signature, + # it refers to this class RawBSONDocument. + if codec_options is None: + codec_options = DEFAULT_RAW_BSON_OPTIONS + elif not issubclass(codec_options.document_class, RawBSONDocument): + raise TypeError( + "RawBSONDocument cannot use CodecOptions with document " + f"class {codec_options.document_class}" + ) + self.__codec_options = codec_options + # Validate the bson object size. 
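+        # A BSON document stores its total size in its first four bytes as a
+        # little-endian int32; a size/terminator mismatch raises InvalidBSON here.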
+ _get_object_size(bson_bytes, 0, len(bson_bytes)) + + @property + def raw(self) -> bytes: + """The raw BSON bytes composing this document.""" + return self.__raw + + def items(self) -> ItemsView[str, Any]: + """Lazily decode and iterate elements in this document.""" + return self.__inflated.items() + + @property + def __inflated(self) -> Mapping[str, Any]: + if self.__inflated_doc is None: + # We already validated the object's size when this document was + # created, so no need to do that again. + self.__inflated_doc = self._inflate_bson(self.__raw, self.__codec_options) + return self.__inflated_doc + + @staticmethod + def _inflate_bson( + bson_bytes: bytes, codec_options: CodecOptions[RawBSONDocument] + ) -> Mapping[str, Any]: + return _inflate_bson(bson_bytes, codec_options) + + def __getitem__(self, item: str) -> Any: + return self.__inflated[item] + + def __iter__(self) -> Iterator[str]: + return iter(self.__inflated) + + def __len__(self) -> int: + return len(self.__inflated) + + def __eq__(self, other: Any) -> bool: + if isinstance(other, RawBSONDocument): + return self.__raw == other.raw + return NotImplemented + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.raw!r}, codec_options={self.__codec_options!r})" + + +class _RawArrayBSONDocument(RawBSONDocument): + """A RawBSONDocument that only expands sub-documents and arrays when accessed.""" + + @staticmethod + def _inflate_bson( + bson_bytes: bytes, codec_options: CodecOptions[RawBSONDocument] + ) -> Mapping[str, Any]: + return _inflate_bson(bson_bytes, codec_options, raw_array=True) + + +DEFAULT_RAW_BSON_OPTIONS: CodecOptions[RawBSONDocument] = DEFAULT.with_options( + document_class=RawBSONDocument +) +_RAW_ARRAY_BSON_OPTIONS: CodecOptions[_RawArrayBSONDocument] = DEFAULT.with_options( + document_class=_RawArrayBSONDocument +) +"""The default :class:`~bson.codec_options.CodecOptions` for +:class:`RawBSONDocument`. +""" diff --git a/venv/Lib/site-packages/bson/regex.py b/venv/Lib/site-packages/bson/regex.py new file mode 100644 index 00000000..60cff4fd --- /dev/null +++ b/venv/Lib/site-packages/bson/regex.py @@ -0,0 +1,133 @@ +# Copyright 2013-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tools for representing MongoDB regular expressions.""" +from __future__ import annotations + +import re +from typing import Any, Generic, Pattern, Type, TypeVar, Union + +from bson._helpers import _getstate_slots, _setstate_slots +from bson.son import RE_TYPE + + +def str_flags_to_int(str_flags: str) -> int: + flags = 0 + if "i" in str_flags: + flags |= re.IGNORECASE + if "l" in str_flags: + flags |= re.LOCALE + if "m" in str_flags: + flags |= re.MULTILINE + if "s" in str_flags: + flags |= re.DOTALL + if "u" in str_flags: + flags |= re.UNICODE + if "x" in str_flags: + flags |= re.VERBOSE + + return flags + + +_T = TypeVar("_T", str, bytes) + + +class Regex(Generic[_T]): + """BSON regular expression data.""" + + __slots__ = ("pattern", "flags") + + __getstate__ = _getstate_slots + __setstate__ = _setstate_slots + + _type_marker = 11 + + @classmethod + def from_native(cls: Type[Regex[Any]], regex: Pattern[_T]) -> Regex[_T]: + """Convert a Python regular expression into a ``Regex`` instance. + + Note that in Python 3, a regular expression compiled from a + :class:`str` has the ``re.UNICODE`` flag set. If it is undesirable + to store this flag in a BSON regular expression, unset it first:: + + >>> pattern = re.compile('.*') + >>> regex = Regex.from_native(pattern) + >>> regex.flags ^= re.UNICODE + >>> db.collection.insert_one({'pattern': regex}) + + :param regex: A regular expression object from ``re.compile()``. + + .. warning:: + Python regular expressions use a different syntax and different + set of flags than MongoDB, which uses `PCRE`_. A regular + expression retrieved from the server may not compile in + Python, or may match a different set of strings in Python than + when used in a MongoDB query. + + .. _PCRE: http://www.pcre.org/ + """ + if not isinstance(regex, RE_TYPE): + raise TypeError("regex must be a compiled regular expression, not %s" % type(regex)) + + return Regex(regex.pattern, regex.flags) + + def __init__(self, pattern: _T, flags: Union[str, int] = 0) -> None: + """BSON regular expression data. + + This class is useful to store and retrieve regular expressions that are + incompatible with Python's regular expression dialect. + + :param pattern: string + :param flags: an integer bitmask, or a string of flag + characters like "im" for IGNORECASE and MULTILINE + """ + if not isinstance(pattern, (str, bytes)): + raise TypeError("pattern must be a string, not %s" % type(pattern)) + self.pattern: _T = pattern + + if isinstance(flags, str): + self.flags = str_flags_to_int(flags) + elif isinstance(flags, int): + self.flags = flags + else: + raise TypeError("flags must be a string or int, not %s" % type(flags)) + + def __eq__(self, other: Any) -> bool: + if isinstance(other, Regex): + return self.pattern == other.pattern and self.flags == other.flags + else: + return NotImplemented + + __hash__ = None # type: ignore + + def __ne__(self, other: Any) -> bool: + return not self == other + + def __repr__(self) -> str: + return f"Regex({self.pattern!r}, {self.flags!r})" + + def try_compile(self) -> Pattern[_T]: + """Compile this :class:`Regex` as a Python regular expression. + + .. warning:: + Python regular expressions use a different syntax and different + set of flags than MongoDB, which uses `PCRE`_. A regular + expression retrieved from the server may not compile in + Python, or may match a different set of strings in Python than + when used in a MongoDB query. :meth:`try_compile()` may raise + :exc:`re.error`. + + .. 
_PCRE: http://www.pcre.org/ + """ + return re.compile(self.pattern, self.flags) diff --git a/venv/Lib/site-packages/bson/son.py b/venv/Lib/site-packages/bson/son.py new file mode 100644 index 00000000..cf627172 --- /dev/null +++ b/venv/Lib/site-packages/bson/son.py @@ -0,0 +1,211 @@ +# Copyright 2009-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tools for creating and manipulating SON, the Serialized Ocument Notation. + +Regular dictionaries can be used instead of SON objects, but not when the order +of keys is important. A SON object can be used just like a normal Python +dictionary. +""" +from __future__ import annotations + +import copy +import re +from collections.abc import Mapping as _Mapping +from typing import ( + Any, + Dict, + Iterable, + Iterator, + Mapping, + Optional, + Pattern, + Tuple, + Type, + TypeVar, + Union, + cast, +) + +# This sort of sucks, but seems to be as good as it gets... +# This is essentially the same as re._pattern_type +RE_TYPE: Type[Pattern[Any]] = type(re.compile("")) + +_Key = TypeVar("_Key") +_Value = TypeVar("_Value") +_T = TypeVar("_T") + + +class SON(Dict[_Key, _Value]): + """SON data. + + A subclass of dict that maintains ordering of keys and provides a + few extra niceties for dealing with SON. SON provides an API + similar to collections.OrderedDict. + """ + + __keys: list[Any] + + def __init__( + self, + data: Optional[Union[Mapping[_Key, _Value], Iterable[Tuple[_Key, _Value]]]] = None, + **kwargs: Any, + ) -> None: + self.__keys = [] + dict.__init__(self) + self.update(data) + self.update(kwargs) + + def __new__(cls: Type[SON[_Key, _Value]], *args: Any, **kwargs: Any) -> SON[_Key, _Value]: + instance = super().__new__(cls, *args, **kwargs) # type: ignore[type-var] + instance.__keys = [] + return instance + + def __repr__(self) -> str: + result = [] + for key in self.__keys: + result.append(f"({key!r}, {self[key]!r})") + return "SON([%s])" % ", ".join(result) + + def __setitem__(self, key: _Key, value: _Value) -> None: + if key not in self.__keys: + self.__keys.append(key) + dict.__setitem__(self, key, value) + + def __delitem__(self, key: _Key) -> None: + self.__keys.remove(key) + dict.__delitem__(self, key) + + def copy(self) -> SON[_Key, _Value]: + other: SON[_Key, _Value] = SON() + other.update(self) + return other + + # TODO this is all from UserDict.DictMixin. it could probably be made more + # efficient. 
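+    # Ordering sketch (illustrative values): SON([("b", 2), ("a", 1)]) iterates
+    # its keys as ["b", "a"], and __eq__ against another SON is order-sensitive.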
+ # second level definitions support higher levels + def __iter__(self) -> Iterator[_Key]: + yield from self.__keys + + def has_key(self, key: _Key) -> bool: + return key in self.__keys + + def iterkeys(self) -> Iterator[_Key]: + return self.__iter__() + + # fourth level uses definitions from lower levels + def itervalues(self) -> Iterator[_Value]: + for _, v in self.items(): + yield v + + def values(self) -> list[_Value]: # type: ignore[override] + return [v for _, v in self.items()] + + def clear(self) -> None: + self.__keys = [] + super().clear() + + def setdefault(self, key: _Key, default: _Value) -> _Value: + try: + return self[key] + except KeyError: + self[key] = default + return default + + def pop(self, key: _Key, *args: Union[_Value, _T]) -> Union[_Value, _T]: + if len(args) > 1: + raise TypeError("pop expected at most 2 arguments, got " + repr(1 + len(args))) + try: + value = self[key] + except KeyError: + if args: + return args[0] + raise + del self[key] + return value + + def popitem(self) -> Tuple[_Key, _Value]: + try: + k, v = next(iter(self.items())) + except StopIteration: + raise KeyError("container is empty") from None + del self[k] + return (k, v) + + def update(self, other: Optional[Any] = None, **kwargs: _Value) -> None: # type: ignore[override] + # Make progressively weaker assumptions about "other" + if other is None: + pass + elif hasattr(other, "items"): + for k, v in other.items(): + self[k] = v + elif hasattr(other, "keys"): + for k in other: + self[k] = other[k] + else: + for k, v in other: + self[k] = v + if kwargs: + self.update(kwargs) + + def get( # type: ignore[override] + self, key: _Key, default: Optional[Union[_Value, _T]] = None + ) -> Union[_Value, _T, None]: + try: + return self[key] + except KeyError: + return default + + def __eq__(self, other: Any) -> bool: + """Comparison to another SON is order-sensitive while comparison to a + regular dictionary is order-insensitive. + """ + if isinstance(other, SON): + return len(self) == len(other) and list(self.items()) == list(other.items()) + return cast(bool, self.to_dict() == other) + + def __ne__(self, other: Any) -> bool: + return not self == other + + def __len__(self) -> int: + return len(self.__keys) + + def to_dict(self) -> dict[_Key, _Value]: + """Convert a SON document to a normal Python dictionary instance. + + This is trickier than just *dict(...)* because it needs to be + recursive. + """ + + def transform_value(value: Any) -> Any: + if isinstance(value, list): + return [transform_value(v) for v in value] + elif isinstance(value, _Mapping): + return {k: transform_value(v) for k, v in value.items()} + else: + return value + + return cast("dict[_Key, _Value]", transform_value(dict(self))) + + def __deepcopy__(self, memo: dict[int, SON[_Key, _Value]]) -> SON[_Key, _Value]: + out: SON[_Key, _Value] = SON() + val_id = id(self) + if val_id in memo: + return memo[val_id] + memo[val_id] = out + for k, v in self.items(): + if not isinstance(v, RE_TYPE): + v = copy.deepcopy(v, memo) # noqa: PLW2901 + out[k] = v + return out diff --git a/venv/Lib/site-packages/bson/time64.c b/venv/Lib/site-packages/bson/time64.c new file mode 100644 index 00000000..a21fbb90 --- /dev/null +++ b/venv/Lib/site-packages/bson/time64.c @@ -0,0 +1,781 @@ +/* + +Copyright (c) 2007-2010 Michael G Schwern + +This software originally derived from Paul Sheer's pivotal_gmtime_r.c. 
+ +The MIT License: + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. + +*/ + +/* + +Programmers who have available to them 64-bit time values as a 'long +long' type can use cbson_localtime64_r() and cbson_gmtime64_r() which correctly +converts the time even on 32-bit systems. Whether you have 64-bit time +values will depend on the operating system. + +cbson_localtime64_r() is a 64-bit equivalent of localtime_r(). + +cbson_gmtime64_r() is a 64-bit equivalent of gmtime_r(). + +*/ + +#ifdef _MSC_VER + #define _CRT_SECURE_NO_WARNINGS +#endif + +/* Including Python.h fixes issues with interpreters built with -std=c99. */ +#define PY_SSIZE_T_CLEAN +#include "Python.h" + +#include +#include "time64.h" +#include "time64_limits.h" + + +/* Spec says except for stftime() and the _r() functions, these + all return static memory. Stabbings! */ +static struct TM Static_Return_Date; + +static const int days_in_month[2][12] = { + {31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31}, + {31, 29, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31}, +}; + +static const int julian_days_by_month[2][12] = { + {0, 31, 59, 90, 120, 151, 181, 212, 243, 273, 304, 334}, + {0, 31, 60, 91, 121, 152, 182, 213, 244, 274, 305, 335}, +}; + +static const int length_of_year[2] = { 365, 366 }; + +/* Some numbers relating to the gregorian cycle */ +static const Year years_in_gregorian_cycle = 400; +#define days_in_gregorian_cycle ((365 * 400) + 100 - 4 + 1) +static const Time64_T seconds_in_gregorian_cycle = days_in_gregorian_cycle * 60LL * 60LL * 24LL; + +/* Year range we can trust the time functions with */ +#define MAX_SAFE_YEAR 2037 +#define MIN_SAFE_YEAR 1971 + +/* 28 year Julian calendar cycle */ +#define SOLAR_CYCLE_LENGTH 28 + +/* Year cycle from MAX_SAFE_YEAR down. */ +static const int safe_years_high[SOLAR_CYCLE_LENGTH] = { + 2016, 2017, 2018, 2019, + 2020, 2021, 2022, 2023, + 2024, 2025, 2026, 2027, + 2028, 2029, 2030, 2031, + 2032, 2033, 2034, 2035, + 2036, 2037, 2010, 2011, + 2012, 2013, 2014, 2015 +}; + +/* Year cycle from MIN_SAFE_YEAR up */ +static const int safe_years_low[SOLAR_CYCLE_LENGTH] = { + 1996, 1997, 1998, 1971, + 1972, 1973, 1974, 1975, + 1976, 1977, 1978, 1979, + 1980, 1981, 1982, 1983, + 1984, 1985, 1986, 1987, + 1988, 1989, 1990, 1991, + 1992, 1993, 1994, 1995, +}; + +/* Let's assume people are going to be looking for dates in the future. + Let's provide some cheats so you can skip ahead. + This has a 4x speed boost when near 2008. 
+*/ +/* Number of days since epoch on Jan 1st, 2008 GMT */ +#define CHEAT_DAYS (1199145600 / 24 / 60 / 60) +#define CHEAT_YEARS 108 + +#define IS_LEAP(n) ((!(((n) + 1900) % 400) || (!(((n) + 1900) % 4) && (((n) + 1900) % 100))) != 0) +#define _TIME64_WRAP(a,b,m) ((a) = ((a) < 0 ) ? ((b)--, (a) + (m)) : (a)) + +#ifdef USE_SYSTEM_LOCALTIME +# define SHOULD_USE_SYSTEM_LOCALTIME(a) ( \ + (a) <= SYSTEM_LOCALTIME_MAX && \ + (a) >= SYSTEM_LOCALTIME_MIN \ +) +#else +# define SHOULD_USE_SYSTEM_LOCALTIME(a) (0) +#endif + +#ifdef USE_SYSTEM_GMTIME +# define SHOULD_USE_SYSTEM_GMTIME(a) ( \ + (a) <= SYSTEM_GMTIME_MAX && \ + (a) >= SYSTEM_GMTIME_MIN \ +) +#else +# define SHOULD_USE_SYSTEM_GMTIME(a) (0) +#endif + +/* Multi varadic macros are a C99 thing, alas */ +#ifdef TIME_64_DEBUG +# define TIME64_TRACE(format) (fprintf(stderr, format)) +# define TIME64_TRACE1(format, var1) (fprintf(stderr, format, var1)) +# define TIME64_TRACE2(format, var1, var2) (fprintf(stderr, format, var1, var2)) +# define TIME64_TRACE3(format, var1, var2, var3) (fprintf(stderr, format, var1, var2, var3)) +#else +# define TIME64_TRACE(format) ((void)0) +# define TIME64_TRACE1(format, var1) ((void)0) +# define TIME64_TRACE2(format, var1, var2) ((void)0) +# define TIME64_TRACE3(format, var1, var2, var3) ((void)0) +#endif + + +static int is_exception_century(Year year) +{ + int is_exception = ((year % 100 == 0) && !(year % 400 == 0)); + TIME64_TRACE1("# is_exception_century: %s\n", is_exception ? "yes" : "no"); + + return(is_exception); +} + + +/* Compare two dates. + The result is like cmp. + Ignores things like gmtoffset and dst +*/ +int cbson_cmp_date( const struct TM* left, const struct tm* right ) { + if( left->tm_year > right->tm_year ) + return 1; + else if( left->tm_year < right->tm_year ) + return -1; + + if( left->tm_mon > right->tm_mon ) + return 1; + else if( left->tm_mon < right->tm_mon ) + return -1; + + if( left->tm_mday > right->tm_mday ) + return 1; + else if( left->tm_mday < right->tm_mday ) + return -1; + + if( left->tm_hour > right->tm_hour ) + return 1; + else if( left->tm_hour < right->tm_hour ) + return -1; + + if( left->tm_min > right->tm_min ) + return 1; + else if( left->tm_min < right->tm_min ) + return -1; + + if( left->tm_sec > right->tm_sec ) + return 1; + else if( left->tm_sec < right->tm_sec ) + return -1; + + return 0; +} + + +/* Check if a date is safely inside a range. + The intention is to check if its a few days inside. +*/ +int cbson_date_in_safe_range( const struct TM* date, const struct tm* min, const struct tm* max ) { + if( cbson_cmp_date(date, min) == -1 ) + return 0; + + if( cbson_cmp_date(date, max) == 1 ) + return 0; + + return 1; +} + + +/* timegm() is not in the C or POSIX spec, but it is such a useful + extension I would be remiss in leaving it out. 
Also I need it + for cbson_localtime64() +*/ +Time64_T cbson_timegm64(const struct TM *date) { + Time64_T days = 0; + Time64_T seconds = 0; + Year year; + Year orig_year = (Year)date->tm_year; + int cycles = 0; + + if( orig_year > 100 ) { + cycles = (int)((orig_year - 100) / 400); + orig_year -= cycles * 400; + days += (Time64_T)cycles * days_in_gregorian_cycle; + } + else if( orig_year < -300 ) { + cycles = (int)((orig_year - 100) / 400); + orig_year -= cycles * 400; + days += (Time64_T)cycles * days_in_gregorian_cycle; + } + TIME64_TRACE3("# timegm/ cycles: %d, days: %lld, orig_year: %lld\n", cycles, days, orig_year); + + if( orig_year > 70 ) { + year = 70; + while( year < orig_year ) { + days += length_of_year[IS_LEAP(year)]; + year++; + } + } + else if ( orig_year < 70 ) { + year = 69; + do { + days -= length_of_year[IS_LEAP(year)]; + year--; + } while( year >= orig_year ); + } + + days += julian_days_by_month[IS_LEAP(orig_year)][date->tm_mon]; + days += date->tm_mday - 1; + + seconds = days * 60 * 60 * 24; + + seconds += date->tm_hour * 60 * 60; + seconds += date->tm_min * 60; + seconds += date->tm_sec; + + return(seconds); +} + + +#ifndef NDEBUG +static int check_tm(struct TM *tm) +{ + /* Don't forget leap seconds */ + assert(tm->tm_sec >= 0); + assert(tm->tm_sec <= 61); + + assert(tm->tm_min >= 0); + assert(tm->tm_min <= 59); + + assert(tm->tm_hour >= 0); + assert(tm->tm_hour <= 23); + + assert(tm->tm_mday >= 1); + assert(tm->tm_mday <= days_in_month[IS_LEAP(tm->tm_year)][tm->tm_mon]); + + assert(tm->tm_mon >= 0); + assert(tm->tm_mon <= 11); + + assert(tm->tm_wday >= 0); + assert(tm->tm_wday <= 6); + + assert(tm->tm_yday >= 0); + assert(tm->tm_yday <= length_of_year[IS_LEAP(tm->tm_year)]); + +#ifdef HAS_TM_TM_GMTOFF + assert(tm->tm_gmtoff >= -24 * 60 * 60); + assert(tm->tm_gmtoff <= 24 * 60 * 60); +#endif + + return 1; +} +#endif + + +/* The exceptional centuries without leap years cause the cycle to + shift by 16 +*/ +static Year cycle_offset(Year year) +{ + const Year start_year = 2000; + Year year_diff = year - start_year; + Year exceptions; + + if( year > start_year ) + year_diff--; + + exceptions = year_diff / 100; + exceptions -= year_diff / 400; + + TIME64_TRACE3("# year: %lld, exceptions: %lld, year_diff: %lld\n", + year, exceptions, year_diff); + + return exceptions * 16; +} + +/* For a given year after 2038, pick the latest possible matching + year in the 28 year calendar cycle. + + A matching year... + 1) Starts on the same day of the week. + 2) Has the same leap year status. + + This is so the calendars match up. + + Also the previous year must match. When doing Jan 1st you might + wind up on Dec 31st the previous year when doing a -UTC time zone. + + Finally, the next year must have the same start day of week. This + is for Dec 31st with a +UTC time zone. + It doesn't need the same leap year status since we only care about + January 1st. 
+*/ +static int safe_year(const Year year) +{ + int safe_year = 0; + Year year_cycle; + + if( year >= MIN_SAFE_YEAR && year <= MAX_SAFE_YEAR ) { + return (int)year; + } + + year_cycle = year + cycle_offset(year); + + /* safe_years_low is off from safe_years_high by 8 years */ + if( year < MIN_SAFE_YEAR ) + year_cycle -= 8; + + /* Change non-leap xx00 years to an equivalent */ + if( is_exception_century(year) ) + year_cycle += 11; + + /* Also xx01 years, since the previous year will be wrong */ + if( is_exception_century(year - 1) ) + year_cycle += 17; + + year_cycle %= SOLAR_CYCLE_LENGTH; + if( year_cycle < 0 ) + year_cycle = SOLAR_CYCLE_LENGTH + year_cycle; + + assert( year_cycle >= 0 ); + assert( year_cycle < SOLAR_CYCLE_LENGTH ); + if( year < MIN_SAFE_YEAR ) + safe_year = safe_years_low[year_cycle]; + else if( year > MAX_SAFE_YEAR ) + safe_year = safe_years_high[year_cycle]; + else + assert(0); + + TIME64_TRACE3("# year: %lld, year_cycle: %lld, safe_year: %d\n", + year, year_cycle, safe_year); + + assert(safe_year <= MAX_SAFE_YEAR && safe_year >= MIN_SAFE_YEAR); + + return safe_year; +} + + +void pymongo_copy_tm_to_TM64(const struct tm *src, struct TM *dest) { + if( src == NULL ) { + memset(dest, 0, sizeof(*dest)); + } + else { +# ifdef USE_TM64 + dest->tm_sec = src->tm_sec; + dest->tm_min = src->tm_min; + dest->tm_hour = src->tm_hour; + dest->tm_mday = src->tm_mday; + dest->tm_mon = src->tm_mon; + dest->tm_year = (Year)src->tm_year; + dest->tm_wday = src->tm_wday; + dest->tm_yday = src->tm_yday; + dest->tm_isdst = src->tm_isdst; + +# ifdef HAS_TM_TM_GMTOFF + dest->tm_gmtoff = src->tm_gmtoff; +# endif + +# ifdef HAS_TM_TM_ZONE + dest->tm_zone = src->tm_zone; +# endif + +# else + /* They're the same type */ + memcpy(dest, src, sizeof(*dest)); +# endif + } +} + + +void cbson_copy_TM64_to_tm(const struct TM *src, struct tm *dest) { + if( src == NULL ) { + memset(dest, 0, sizeof(*dest)); + } + else { +# ifdef USE_TM64 + dest->tm_sec = src->tm_sec; + dest->tm_min = src->tm_min; + dest->tm_hour = src->tm_hour; + dest->tm_mday = src->tm_mday; + dest->tm_mon = src->tm_mon; + dest->tm_year = (int)src->tm_year; + dest->tm_wday = src->tm_wday; + dest->tm_yday = src->tm_yday; + dest->tm_isdst = src->tm_isdst; + +# ifdef HAS_TM_TM_GMTOFF + dest->tm_gmtoff = src->tm_gmtoff; +# endif + +# ifdef HAS_TM_TM_ZONE + dest->tm_zone = src->tm_zone; +# endif + +# else + /* They're the same type */ + memcpy(dest, src, sizeof(*dest)); +# endif + } +} + + +/* Simulate localtime_r() to the best of our ability */ +struct tm * cbson_fake_localtime_r(const time_t *time, struct tm *result) { + const struct tm *static_result = localtime(time); + + assert(result != NULL); + + if( static_result == NULL ) { + memset(result, 0, sizeof(*result)); + return NULL; + } + else { + memcpy(result, static_result, sizeof(*result)); + return result; + } +} + + +/* Simulate gmtime_r() to the best of our ability */ +struct tm * cbson_fake_gmtime_r(const time_t *time, struct tm *result) { + const struct tm *static_result = gmtime(time); + + assert(result != NULL); + + if( static_result == NULL ) { + memset(result, 0, sizeof(*result)); + return NULL; + } + else { + memcpy(result, static_result, sizeof(*result)); + return result; + } +} + + +static Time64_T seconds_between_years(Year left_year, Year right_year) { + int increment = (left_year > right_year) ? 
1 : -1; + Time64_T seconds = 0; + int cycles; + + if( left_year > 2400 ) { + cycles = (int)((left_year - 2400) / 400); + left_year -= cycles * 400; + seconds += cycles * seconds_in_gregorian_cycle; + } + else if( left_year < 1600 ) { + cycles = (int)((left_year - 1600) / 400); + left_year += cycles * 400; + seconds += cycles * seconds_in_gregorian_cycle; + } + + while( left_year != right_year ) { + seconds += length_of_year[IS_LEAP(right_year - 1900)] * 60 * 60 * 24; + right_year += increment; + } + + return seconds * increment; +} + + +Time64_T cbson_mktime64(const struct TM *input_date) { + struct tm safe_date; + struct TM date; + Time64_T time; + Year year = input_date->tm_year + 1900; + + if( cbson_date_in_safe_range(input_date, &SYSTEM_MKTIME_MIN, &SYSTEM_MKTIME_MAX) ) + { + cbson_copy_TM64_to_tm(input_date, &safe_date); + return (Time64_T)mktime(&safe_date); + } + + /* Have to make the year safe in date else it won't fit in safe_date */ + date = *input_date; + date.tm_year = safe_year(year) - 1900; + cbson_copy_TM64_to_tm(&date, &safe_date); + + time = (Time64_T)mktime(&safe_date); + + time += seconds_between_years(year, (Year)(safe_date.tm_year + 1900)); + + return time; +} + + +/* Because I think mktime() is a crappy name */ +Time64_T timelocal64(const struct TM *date) { + return cbson_mktime64(date); +} + + +struct TM *cbson_gmtime64_r (const Time64_T *in_time, struct TM *p) +{ + int v_tm_sec, v_tm_min, v_tm_hour, v_tm_mon, v_tm_wday; + Time64_T v_tm_tday; + int leap; + Time64_T m; + Time64_T time = *in_time; + Year year = 70; + int cycles = 0; + + assert(p != NULL); + +#ifdef USE_SYSTEM_GMTIME + /* Use the system gmtime() if time_t is small enough */ + if( SHOULD_USE_SYSTEM_GMTIME(*in_time) ) { + time_t safe_time = (time_t)*in_time; + struct tm safe_date; + GMTIME_R(&safe_time, &safe_date); + + pymongo_copy_tm_to_TM64(&safe_date, p); + assert(check_tm(p)); + + return p; + } +#endif + +#ifdef HAS_TM_TM_GMTOFF + p->tm_gmtoff = 0; +#endif + p->tm_isdst = 0; + +#ifdef HAS_TM_TM_ZONE + p->tm_zone = "UTC"; +#endif + + v_tm_sec = (int)(time % 60); + time /= 60; + v_tm_min = (int)(time % 60); + time /= 60; + v_tm_hour = (int)(time % 24); + time /= 24; + v_tm_tday = time; + + _TIME64_WRAP (v_tm_sec, v_tm_min, 60); + _TIME64_WRAP (v_tm_min, v_tm_hour, 60); + _TIME64_WRAP (v_tm_hour, v_tm_tday, 24); + + v_tm_wday = (int)((v_tm_tday + 4) % 7); + if (v_tm_wday < 0) + v_tm_wday += 7; + m = v_tm_tday; + + if (m >= CHEAT_DAYS) { + year = CHEAT_YEARS; + m -= CHEAT_DAYS; + } + + if (m >= 0) { + /* Gregorian cycles, this is huge optimization for distant times */ + cycles = (int)(m / (Time64_T) days_in_gregorian_cycle); + if( cycles ) { + m -= (cycles * (Time64_T) days_in_gregorian_cycle); + year += (cycles * years_in_gregorian_cycle); + } + + /* Years */ + leap = IS_LEAP (year); + while (m >= (Time64_T) length_of_year[leap]) { + m -= (Time64_T) length_of_year[leap]; + year++; + leap = IS_LEAP (year); + } + + /* Months */ + v_tm_mon = 0; + while (m >= (Time64_T) days_in_month[leap][v_tm_mon]) { + m -= (Time64_T) days_in_month[leap][v_tm_mon]; + v_tm_mon++; + } + } else { + year--; + + /* Gregorian cycles */ + cycles = (int)((m / (Time64_T) days_in_gregorian_cycle) + 1); + if( cycles ) { + m -= (cycles * (Time64_T) days_in_gregorian_cycle); + year += (cycles * years_in_gregorian_cycle); + } + + /* Years */ + leap = IS_LEAP (year); + while (m < (Time64_T) -length_of_year[leap]) { + m += (Time64_T) length_of_year[leap]; + year--; + leap = IS_LEAP (year); + } + + /* Months */ + v_tm_mon = 11; + while (m 
< (Time64_T) -days_in_month[leap][v_tm_mon]) { + m += (Time64_T) days_in_month[leap][v_tm_mon]; + v_tm_mon--; + } + m += (Time64_T) days_in_month[leap][v_tm_mon]; + } + + p->tm_year = (int)year; + if( p->tm_year != year ) { +#ifdef EOVERFLOW + errno = EOVERFLOW; +#endif + return NULL; + } + + /* At this point m is less than a year so casting to an int is safe */ + p->tm_mday = (int) m + 1; + p->tm_yday = julian_days_by_month[leap][v_tm_mon] + (int)m; + p->tm_sec = v_tm_sec; + p->tm_min = v_tm_min; + p->tm_hour = v_tm_hour; + p->tm_mon = v_tm_mon; + p->tm_wday = v_tm_wday; + + assert(check_tm(p)); + + return p; +} + + +struct TM *cbson_localtime64_r (const Time64_T *time, struct TM *local_tm) +{ + time_t safe_time; + struct tm safe_date; + struct TM gm_tm; + Year orig_year; + int month_diff; + + assert(local_tm != NULL); + +#ifdef USE_SYSTEM_LOCALTIME + /* Use the system localtime() if time_t is small enough */ + if( SHOULD_USE_SYSTEM_LOCALTIME(*time) ) { + safe_time = (time_t)*time; + + TIME64_TRACE1("Using system localtime for %lld\n", *time); + + LOCALTIME_R(&safe_time, &safe_date); + + pymongo_copy_tm_to_TM64(&safe_date, local_tm); + assert(check_tm(local_tm)); + + return local_tm; + } +#endif + + if( cbson_gmtime64_r(time, &gm_tm) == NULL ) { + TIME64_TRACE1("cbson_gmtime64_r returned null for %lld\n", *time); + return NULL; + } + + orig_year = gm_tm.tm_year; + + if (gm_tm.tm_year > (2037 - 1900) || + gm_tm.tm_year < (1970 - 1900) + ) + { + TIME64_TRACE1("Mapping tm_year %lld to safe_year\n", (Year)gm_tm.tm_year); + gm_tm.tm_year = safe_year((Year)(gm_tm.tm_year + 1900)) - 1900; + } + + safe_time = (time_t)cbson_timegm64(&gm_tm); + if( LOCALTIME_R(&safe_time, &safe_date) == NULL ) { + TIME64_TRACE1("localtime_r(%d) returned NULL\n", (int)safe_time); + return NULL; + } + + pymongo_copy_tm_to_TM64(&safe_date, local_tm); + + local_tm->tm_year = (int)orig_year; + if( local_tm->tm_year != orig_year ) { + TIME64_TRACE2("tm_year overflow: tm_year %lld, orig_year %lld\n", + (Year)local_tm->tm_year, (Year)orig_year); + +#ifdef EOVERFLOW + errno = EOVERFLOW; +#endif + return NULL; + } + + + month_diff = local_tm->tm_mon - gm_tm.tm_mon; + + /* When localtime is Dec 31st previous year and + gmtime is Jan 1st next year. + */ + if( month_diff == 11 ) { + local_tm->tm_year--; + } + + /* When localtime is Jan 1st, next year and + gmtime is Dec 31st, previous year. + */ + if( month_diff == -11 ) { + local_tm->tm_year++; + } + + /* GMT is Jan 1st, xx01 year, but localtime is still Dec 31st + in a non-leap xx00. There is one point in the cycle + we can't account for which the safe xx00 year is a leap + year. So we need to correct for Dec 31st coming out as + the 366th day of the year. 
+ */ + if( !IS_LEAP(local_tm->tm_year) && local_tm->tm_yday == 365 ) + local_tm->tm_yday--; + + assert(check_tm(local_tm)); + + return local_tm; +} + + +int cbson_valid_tm_wday( const struct TM* date ) { + if( 0 <= date->tm_wday && date->tm_wday <= 6 ) + return 1; + else + return 0; +} + +int cbson_valid_tm_mon( const struct TM* date ) { + if( 0 <= date->tm_mon && date->tm_mon <= 11 ) + return 1; + else + return 0; +} + + +/* Non-thread safe versions of the above */ +struct TM *cbson_localtime64(const Time64_T *time) { +#ifdef _MSC_VER + _tzset(); +#else + tzset(); +#endif + return cbson_localtime64_r(time, &Static_Return_Date); +} + +struct TM *cbson_gmtime64(const Time64_T *time) { + return cbson_gmtime64_r(time, &Static_Return_Date); +} diff --git a/venv/Lib/site-packages/bson/time64.h b/venv/Lib/site-packages/bson/time64.h new file mode 100644 index 00000000..6321eb30 --- /dev/null +++ b/venv/Lib/site-packages/bson/time64.h @@ -0,0 +1,67 @@ +#ifndef TIME64_H +# define TIME64_H + +#include <time.h> +#include "time64_config.h" + +/* Set our custom types */ +typedef INT_64_T Int64; +typedef Int64 Time64_T; +typedef Int64 Year; + + +/* A copy of the tm struct but with a 64 bit year */ +struct TM64 { + int tm_sec; + int tm_min; + int tm_hour; + int tm_mday; + int tm_mon; + Year tm_year; + int tm_wday; + int tm_yday; + int tm_isdst; + +#ifdef HAS_TM_TM_GMTOFF + long tm_gmtoff; +#endif + +#ifdef HAS_TM_TM_ZONE + char *tm_zone; +#endif +}; + + +/* Decide which tm struct to use */ +#ifdef USE_TM64 +#define TM TM64 +#else +#define TM tm +#endif + + +/* Declare public functions */ +struct TM *cbson_gmtime64_r (const Time64_T *, struct TM *); +struct TM *cbson_localtime64_r (const Time64_T *, struct TM *); +struct TM *cbson_gmtime64 (const Time64_T *); +struct TM *cbson_localtime64 (const Time64_T *); + +Time64_T cbson_timegm64 (const struct TM *); +Time64_T cbson_mktime64 (const struct TM *); +Time64_T timelocal64 (const struct TM *); + + +/* Not everyone has gm/localtime_r(), provide a replacement */ +#ifdef HAS_LOCALTIME_R +# define LOCALTIME_R(clock, result) localtime_r(clock, result) +#else +# define LOCALTIME_R(clock, result) cbson_fake_localtime_r(clock, result) +#endif +#ifdef HAS_GMTIME_R +# define GMTIME_R(clock, result) gmtime_r(clock, result) +#else +# define GMTIME_R(clock, result) cbson_fake_gmtime_r(clock, result) +#endif + + +#endif diff --git a/venv/Lib/site-packages/bson/time64_config.h b/venv/Lib/site-packages/bson/time64_config.h new file mode 100644 index 00000000..9d4c111c --- /dev/null +++ b/venv/Lib/site-packages/bson/time64_config.h @@ -0,0 +1,78 @@ +/* Configuration + ------------- + Define as appropriate for your system. + Sensible defaults provided. +*/ + + +#ifndef TIME64_CONFIG_H +# define TIME64_CONFIG_H + +/* Debugging + TIME_64_DEBUG + Define if you want debugging messages +*/ +/* #define TIME_64_DEBUG */ + + +/* INT_64_T + A 64 bit integer type to use to store time and others. + Must be defined. +*/ +#define INT_64_T long long + + +/* USE_TM64 + Should we use a 64 bit safe replacement for tm? This will + let you go past year 2 billion but the struct will be incompatible + with tm. Conversion functions will be provided. +*/ +/* #define USE_TM64 */ + + +/* Availability of system functions. + + HAS_GMTIME_R + Define if your system has gmtime_r() + + HAS_LOCALTIME_R + Define if your system has localtime_r() + + HAS_TIMEGM + Define if your system has timegm(), a GNU extension.
+*/ +#if !defined(WIN32) && !defined(_MSC_VER) +#define HAS_GMTIME_R +#define HAS_LOCALTIME_R +#endif +/* #define HAS_TIMEGM */ + + +/* Details of non-standard tm struct elements. + + HAS_TM_TM_GMTOFF + True if your tm struct has a "tm_gmtoff" element. + A BSD extension. + + HAS_TM_TM_ZONE + True if your tm struct has a "tm_zone" element. + A BSD extension. +*/ +/* #define HAS_TM_TM_GMTOFF */ +/* #define HAS_TM_TM_ZONE */ + + +/* USE_SYSTEM_LOCALTIME + USE_SYSTEM_GMTIME + USE_SYSTEM_MKTIME + USE_SYSTEM_TIMEGM + Should we use the system functions if the time is inside their range? + Your system localtime() is probably more accurate, but our gmtime() is + fast and safe. +*/ +#define USE_SYSTEM_LOCALTIME +/* #define USE_SYSTEM_GMTIME */ +#define USE_SYSTEM_MKTIME +/* #define USE_SYSTEM_TIMEGM */ + +#endif /* TIME64_CONFIG_H */ diff --git a/venv/Lib/site-packages/bson/time64_limits.h b/venv/Lib/site-packages/bson/time64_limits.h new file mode 100644 index 00000000..1d30607b --- /dev/null +++ b/venv/Lib/site-packages/bson/time64_limits.h @@ -0,0 +1,95 @@ +/* + Maximum and minimum inputs your system's respective time functions + can correctly handle. time64.h will use your system functions if + the input falls inside these ranges and corresponding USE_SYSTEM_* + constant is defined. +*/ + +#ifndef TIME64_LIMITS_H +#define TIME64_LIMITS_H + +/* Max/min for localtime() */ +#define SYSTEM_LOCALTIME_MAX 2147483647 +#define SYSTEM_LOCALTIME_MIN -2147483647-1 + +/* Max/min for gmtime() */ +#define SYSTEM_GMTIME_MAX 2147483647 +#define SYSTEM_GMTIME_MIN -2147483647-1 + +/* Max/min for mktime() */ +static const struct tm SYSTEM_MKTIME_MAX = { + 7, + 14, + 19, + 18, + 0, + 138, + 1, + 17, + 0 +#ifdef HAS_TM_TM_GMTOFF + ,-28800 +#endif +#ifdef HAS_TM_TM_ZONE + ,"PST" +#endif +}; + +static const struct tm SYSTEM_MKTIME_MIN = { + 52, + 45, + 12, + 13, + 11, + 1, + 5, + 346, + 0 +#ifdef HAS_TM_TM_GMTOFF + ,-28800 +#endif +#ifdef HAS_TM_TM_ZONE + ,"PST" +#endif +}; + +/* Max/min for timegm() */ +#ifdef HAS_TIMEGM +static const struct tm SYSTEM_TIMEGM_MAX = { + 7, + 14, + 3, + 19, + 0, + 138, + 2, + 18, + 0 + #ifdef HAS_TM_TM_GMTOFF + ,0 + #endif + #ifdef HAS_TM_TM_ZONE + ,"UTC" + #endif +}; + +static const struct tm SYSTEM_TIMEGM_MIN = { + 52, + 45, + 20, + 13, + 11, + 1, + 5, + 346, + 0 + #ifdef HAS_TM_TM_GMTOFF + ,0 + #endif + #ifdef HAS_TM_TM_ZONE + ,"UTC" + #endif +}; +#endif /* HAS_TIMEGM */ + +#endif /* TIME64_LIMITS_H */ diff --git a/venv/Lib/site-packages/bson/timestamp.py b/venv/Lib/site-packages/bson/timestamp.py new file mode 100644 index 00000000..3e76e7ba --- /dev/null +++ b/venv/Lib/site-packages/bson/timestamp.py @@ -0,0 +1,123 @@ +# Copyright 2010-2015 MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
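The SYSTEM_MKTIME_MAX and SYSTEM_MKTIME_MIN initializers in time64_limits.h above fill the struct tm fields positionally: tm_sec, tm_min, tm_hour, tm_mday, tm_mon, tm_year, tm_wday, tm_yday, tm_isdst, then the optional tm_gmtoff and tm_zone. As a sanity check (illustrative only, not part of the patch), the MAX entry decodes to the 32-bit time_t limit of 2**31 - 1 seconds viewed at the UTC-8 offset the table assumes:

    from datetime import datetime, timedelta, timezone

    limit = 2**31 - 1                          # largest signed 32-bit time_t
    pst = timezone(timedelta(seconds=-28800))  # the -28800 tm_gmtoff used above
    print(datetime.fromtimestamp(limit, pst))
    # 2038-01-18 19:14:07-08:00 -> tm_sec=7, tm_min=14, tm_hour=19,
    # tm_mday=18, tm_mon=0 (January), tm_year=138 (2038 - 1900)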
+ +"""Tools for representing MongoDB internal Timestamps.""" +from __future__ import annotations + +import calendar +import datetime +from typing import Any, Union + +from bson._helpers import _getstate_slots, _setstate_slots +from bson.tz_util import utc + +UPPERBOUND = 4294967296 + + +class Timestamp: + """MongoDB internal timestamps used in the opLog.""" + + __slots__ = ("__time", "__inc") + + __getstate__ = _getstate_slots + __setstate__ = _setstate_slots + + _type_marker = 17 + + def __init__(self, time: Union[datetime.datetime, int], inc: int) -> None: + """Create a new :class:`Timestamp`. + + This class is only for use with the MongoDB opLog. If you need + to store a regular timestamp, please use a + :class:`~datetime.datetime`. + + Raises :class:`TypeError` if `time` is not an instance of + :class: `int` or :class:`~datetime.datetime`, or `inc` is not + an instance of :class:`int`. Raises :class:`ValueError` if + `time` or `inc` is not in [0, 2**32). + + :param time: time in seconds since epoch UTC, or a naive UTC + :class:`~datetime.datetime`, or an aware + :class:`~datetime.datetime` + :param inc: the incrementing counter + """ + if isinstance(time, datetime.datetime): + offset = time.utcoffset() + if offset is not None: + time = time - offset + time = int(calendar.timegm(time.timetuple())) + if not isinstance(time, int): + raise TypeError("time must be an instance of int") + if not isinstance(inc, int): + raise TypeError("inc must be an instance of int") + if not 0 <= time < UPPERBOUND: + raise ValueError("time must be contained in [0, 2**32)") + if not 0 <= inc < UPPERBOUND: + raise ValueError("inc must be contained in [0, 2**32)") + + self.__time = time + self.__inc = inc + + @property + def time(self) -> int: + """Get the time portion of this :class:`Timestamp`.""" + return self.__time + + @property + def inc(self) -> int: + """Get the inc portion of this :class:`Timestamp`.""" + return self.__inc + + def __eq__(self, other: Any) -> bool: + if isinstance(other, Timestamp): + return self.__time == other.time and self.__inc == other.inc + else: + return NotImplemented + + def __hash__(self) -> int: + return hash(self.time) ^ hash(self.inc) + + def __ne__(self, other: Any) -> bool: + return not self == other + + def __lt__(self, other: Any) -> bool: + if isinstance(other, Timestamp): + return (self.time, self.inc) < (other.time, other.inc) + return NotImplemented + + def __le__(self, other: Any) -> bool: + if isinstance(other, Timestamp): + return (self.time, self.inc) <= (other.time, other.inc) + return NotImplemented + + def __gt__(self, other: Any) -> bool: + if isinstance(other, Timestamp): + return (self.time, self.inc) > (other.time, other.inc) + return NotImplemented + + def __ge__(self, other: Any) -> bool: + if isinstance(other, Timestamp): + return (self.time, self.inc) >= (other.time, other.inc) + return NotImplemented + + def __repr__(self) -> str: + return f"Timestamp({self.__time}, {self.__inc})" + + def as_datetime(self) -> datetime.datetime: + """Return a :class:`~datetime.datetime` instance corresponding + to the time portion of this :class:`Timestamp`. + + The returned datetime's timezone is UTC. + """ + return datetime.datetime.fromtimestamp(self.__time, utc) diff --git a/venv/Lib/site-packages/bson/typings.py b/venv/Lib/site-packages/bson/typings.py new file mode 100644 index 00000000..b80c6614 --- /dev/null +++ b/venv/Lib/site-packages/bson/typings.py @@ -0,0 +1,31 @@ +# Copyright 2023-Present MongoDB, Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Type aliases used by bson""" +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Mapping, MutableMapping, TypeVar, Union + +if TYPE_CHECKING: + from array import array + from mmap import mmap + + from bson.raw_bson import RawBSONDocument + + +# Common Shared Types. +_DocumentOut = Union[MutableMapping[str, Any], "RawBSONDocument"] +_DocumentType = TypeVar("_DocumentType", bound=Mapping[str, Any]) +_DocumentTypeArg = TypeVar("_DocumentTypeArg", bound=Mapping[str, Any]) +_ReadableBuffer = Union[bytes, memoryview, "mmap", "array"] diff --git a/venv/Lib/site-packages/bson/tz_util.py b/venv/Lib/site-packages/bson/tz_util.py new file mode 100644 index 00000000..a21d3c17 --- /dev/null +++ b/venv/Lib/site-packages/bson/tz_util.py @@ -0,0 +1,53 @@ +# Copyright 2010-2015 MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Timezone related utilities for BSON.""" +from __future__ import annotations + +from datetime import datetime, timedelta, tzinfo +from typing import Optional, Tuple, Union + +ZERO: timedelta = timedelta(0) + + +class FixedOffset(tzinfo): + """Fixed offset timezone, in minutes east from UTC. + + Implementation based from the Python `standard library documentation + `_. + Defining __getinitargs__ enables pickling / copying. + """ + + def __init__(self, offset: Union[float, timedelta], name: str) -> None: + if isinstance(offset, timedelta): + self.__offset = offset + else: + self.__offset = timedelta(minutes=offset) + self.__name = name + + def __getinitargs__(self) -> Tuple[timedelta, str]: + return self.__offset, self.__name + + def utcoffset(self, dt: Optional[datetime]) -> timedelta: + return self.__offset + + def tzname(self, dt: Optional[datetime]) -> str: + return self.__name + + def dst(self, dt: Optional[datetime]) -> timedelta: + return ZERO + + +utc: FixedOffset = FixedOffset(0, "UTC") +"""Fixed offset timezone representing UTC.""" diff --git a/venv/Lib/site-packages/dns/__init__.py b/venv/Lib/site-packages/dns/__init__.py new file mode 100644 index 00000000..a4249b9e --- /dev/null +++ b/venv/Lib/site-packages/dns/__init__.py @@ -0,0 +1,70 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2007, 2009, 2011 Nominum, Inc. 
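A short sketch of the FixedOffset helper in tz_util.py above (illustrative, not part of the patch): defining __getinitargs__ is what lets instances survive pickle and copy round-trips.

    import pickle
    from datetime import datetime
    from bson.tz_util import FixedOffset, utc

    est = FixedOffset(-300, "EST")                # 300 minutes west of UTC
    dt = datetime(2023, 6, 1, 12, 0, tzinfo=est)
    print(dt.astimezone(utc))                     # 2023-06-01 17:00:00+00:00
    copy = pickle.loads(pickle.dumps(est))        # works via __getinitargs__
    assert copy.utcoffset(None) == est.utcoffset(None)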
+# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +"""dnspython DNS toolkit""" + +__all__ = [ + "asyncbackend", + "asyncquery", + "asyncresolver", + "dnssec", + "dnssecalgs", + "dnssectypes", + "e164", + "edns", + "entropy", + "exception", + "flags", + "immutable", + "inet", + "ipv4", + "ipv6", + "message", + "name", + "namedict", + "node", + "opcode", + "query", + "quic", + "rcode", + "rdata", + "rdataclass", + "rdataset", + "rdatatype", + "renderer", + "resolver", + "reversename", + "rrset", + "serial", + "set", + "tokenizer", + "transaction", + "tsig", + "tsigkeyring", + "ttl", + "rdtypes", + "update", + "version", + "versioned", + "wire", + "xfr", + "zone", + "zonetypes", + "zonefile", +] + +from dns.version import version as __version__ # noqa
diff --git a/venv/Lib/site-packages/dns/_asyncbackend.py b/venv/Lib/site-packages/dns/_asyncbackend.py new file mode 100644 index 00000000..49f14fed --- /dev/null +++ b/venv/Lib/site-packages/dns/_asyncbackend.py @@ -0,0 +1,99 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# This is a nullcontext for both sync and async. 3.7 has a nullcontext, +# but it is only for sync use.
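A small usage sketch for the NullContext class defined just below (illustrative, not part of the patch): the same helper satisfies both the sync and async context-manager protocols, so optional resources can share one code path.

    import asyncio
    from dns._asyncbackend import NullContext

    def sync_use(resource=None):
        with NullContext(resource) as r:          # plain sync code
            return r

    async def async_use(resource=None):
        async with NullContext(resource) as r:    # async code, unchanged
            return r

    print(sync_use("x"), asyncio.run(async_use("x")))  # x x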
+ + +class NullContext: + def __init__(self, enter_result=None): + self.enter_result = enter_result + + def __enter__(self): + return self.enter_result + + def __exit__(self, exc_type, exc_value, traceback): + pass + + async def __aenter__(self): + return self.enter_result + + async def __aexit__(self, exc_type, exc_value, traceback): + pass + + +# These are declared here so backends can import them without creating +# circular dependencies with dns.asyncbackend. + + +class Socket: # pragma: no cover + async def close(self): + pass + + async def getpeername(self): + raise NotImplementedError + + async def getsockname(self): + raise NotImplementedError + + async def getpeercert(self, timeout): + raise NotImplementedError + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + await self.close() + + +class DatagramSocket(Socket): # pragma: no cover + def __init__(self, family: int): + self.family = family + + async def sendto(self, what, destination, timeout): + raise NotImplementedError + + async def recvfrom(self, size, timeout): + raise NotImplementedError + + +class StreamSocket(Socket): # pragma: no cover + async def sendall(self, what, timeout): + raise NotImplementedError + + async def recv(self, size, timeout): + raise NotImplementedError + + +class NullTransport: + async def connect_tcp(self, host, port, timeout, local_address): + raise NotImplementedError + + +class Backend: # pragma: no cover + def name(self): + return "unknown" + + async def make_socket( + self, + af, + socktype, + proto=0, + source=None, + destination=None, + timeout=None, + ssl_context=None, + server_hostname=None, + ): + raise NotImplementedError + + def datagram_connection_required(self): + return False + + async def sleep(self, interval): + raise NotImplementedError + + def get_transport_class(self): + raise NotImplementedError + + async def wait_for(self, awaitable, timeout): + raise NotImplementedError diff --git a/venv/Lib/site-packages/dns/_asyncio_backend.py b/venv/Lib/site-packages/dns/_asyncio_backend.py new file mode 100644 index 00000000..9d9ed369 --- /dev/null +++ b/venv/Lib/site-packages/dns/_asyncio_backend.py @@ -0,0 +1,275 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +"""asyncio library query support""" + +import asyncio +import socket +import sys + +import dns._asyncbackend +import dns._features +import dns.exception +import dns.inet + +_is_win32 = sys.platform == "win32" + + +def _get_running_loop(): + try: + return asyncio.get_running_loop() + except AttributeError: # pragma: no cover + return asyncio.get_event_loop() + + +class _DatagramProtocol: + def __init__(self): + self.transport = None + self.recvfrom = None + + def connection_made(self, transport): + self.transport = transport + + def datagram_received(self, data, addr): + if self.recvfrom and not self.recvfrom.done(): + self.recvfrom.set_result((data, addr)) + + def error_received(self, exc): # pragma: no cover + if self.recvfrom and not self.recvfrom.done(): + self.recvfrom.set_exception(exc) + + def connection_lost(self, exc): + if self.recvfrom and not self.recvfrom.done(): + if exc is None: + # EOF we triggered. Is there a better way to do this? 
+ try: + raise EOFError + except EOFError as e: + self.recvfrom.set_exception(e) + else: + self.recvfrom.set_exception(exc) + + def close(self): + self.transport.close() + + +async def _maybe_wait_for(awaitable, timeout): + if timeout is not None: + try: + return await asyncio.wait_for(awaitable, timeout) + except asyncio.TimeoutError: + raise dns.exception.Timeout(timeout=timeout) + else: + return await awaitable + + +class DatagramSocket(dns._asyncbackend.DatagramSocket): + def __init__(self, family, transport, protocol): + super().__init__(family) + self.transport = transport + self.protocol = protocol + + async def sendto(self, what, destination, timeout): # pragma: no cover + # no timeout for asyncio sendto + self.transport.sendto(what, destination) + return len(what) + + async def recvfrom(self, size, timeout): + # ignore size as there's no way I know to tell protocol about it + done = _get_running_loop().create_future() + try: + assert self.protocol.recvfrom is None + self.protocol.recvfrom = done + await _maybe_wait_for(done, timeout) + return done.result() + finally: + self.protocol.recvfrom = None + + async def close(self): + self.protocol.close() + + async def getpeername(self): + return self.transport.get_extra_info("peername") + + async def getsockname(self): + return self.transport.get_extra_info("sockname") + + async def getpeercert(self, timeout): + raise NotImplementedError + + +class StreamSocket(dns._asyncbackend.StreamSocket): + def __init__(self, af, reader, writer): + self.family = af + self.reader = reader + self.writer = writer + + async def sendall(self, what, timeout): + self.writer.write(what) + return await _maybe_wait_for(self.writer.drain(), timeout) + + async def recv(self, size, timeout): + return await _maybe_wait_for(self.reader.read(size), timeout) + + async def close(self): + self.writer.close() + + async def getpeername(self): + return self.writer.get_extra_info("peername") + + async def getsockname(self): + return self.writer.get_extra_info("sockname") + + async def getpeercert(self, timeout): + return self.writer.get_extra_info("peercert") + + +if dns._features.have("doh"): + import anyio + import httpcore + import httpcore._backends.anyio + import httpx + + _CoreAsyncNetworkBackend = httpcore.AsyncNetworkBackend + _CoreAnyIOStream = httpcore._backends.anyio.AnyIOStream + + from dns.query import _compute_times, _expiration_for_this_attempt, _remaining + + class _NetworkBackend(_CoreAsyncNetworkBackend): + def __init__(self, resolver, local_port, bootstrap_address, family): + super().__init__() + self._local_port = local_port + self._resolver = resolver + self._bootstrap_address = bootstrap_address + self._family = family + if local_port != 0: + raise NotImplementedError( + "the asyncio transport for HTTPX cannot set the local port" + ) + + async def connect_tcp( + self, host, port, timeout, local_address, socket_options=None + ): # pylint: disable=signature-differs + addresses = [] + _, expiration = _compute_times(timeout) + if dns.inet.is_address(host): + addresses.append(host) + elif self._bootstrap_address is not None: + addresses.append(self._bootstrap_address) + else: + timeout = _remaining(expiration) + family = self._family + if local_address: + family = dns.inet.af_for_address(local_address) + answers = await self._resolver.resolve_name( + host, family=family, lifetime=timeout + ) + addresses = answers.addresses() + for address in addresses: + try: + attempt_expiration = _expiration_for_this_attempt(2.0, expiration) + timeout = 
_remaining(attempt_expiration) + with anyio.fail_after(timeout): + stream = await anyio.connect_tcp( + remote_host=address, + remote_port=port, + local_host=local_address, + ) + return _CoreAnyIOStream(stream) + except Exception: + pass + raise httpcore.ConnectError + + async def connect_unix_socket( + self, path, timeout, socket_options=None + ): # pylint: disable=signature-differs + raise NotImplementedError + + async def sleep(self, seconds): # pylint: disable=signature-differs + await anyio.sleep(seconds) + + class _HTTPTransport(httpx.AsyncHTTPTransport): + def __init__( + self, + *args, + local_port=0, + bootstrap_address=None, + resolver=None, + family=socket.AF_UNSPEC, + **kwargs, + ): + if resolver is None: + # pylint: disable=import-outside-toplevel,redefined-outer-name + import dns.asyncresolver + + resolver = dns.asyncresolver.Resolver() + super().__init__(*args, **kwargs) + self._pool._network_backend = _NetworkBackend( + resolver, local_port, bootstrap_address, family + ) + +else: + _HTTPTransport = dns._asyncbackend.NullTransport # type: ignore + + +class Backend(dns._asyncbackend.Backend): + def name(self): + return "asyncio" + + async def make_socket( + self, + af, + socktype, + proto=0, + source=None, + destination=None, + timeout=None, + ssl_context=None, + server_hostname=None, + ): + loop = _get_running_loop() + if socktype == socket.SOCK_DGRAM: + if _is_win32 and source is None: + # Win32 wants explicit binding before recvfrom(). This is the + # proper fix for [#637]. + source = (dns.inet.any_for_af(af), 0) + transport, protocol = await loop.create_datagram_endpoint( + _DatagramProtocol, + source, + family=af, + proto=proto, + remote_addr=destination, + ) + return DatagramSocket(af, transport, protocol) + elif socktype == socket.SOCK_STREAM: + if destination is None: + # This shouldn't happen, but we check to make code analysis software + # happier. + raise ValueError("destination required for stream sockets") + (r, w) = await _maybe_wait_for( + asyncio.open_connection( + destination[0], + destination[1], + ssl=ssl_context, + family=af, + proto=proto, + local_addr=source, + server_hostname=server_hostname, + ), + timeout, + ) + return StreamSocket(af, r, w) + raise NotImplementedError( + "unsupported socket " + f"type {socktype}" + ) # pragma: no cover + + async def sleep(self, interval): + await asyncio.sleep(interval) + + def datagram_connection_required(self): + return False + + def get_transport_class(self): + return _HTTPTransport + + async def wait_for(self, awaitable, timeout): + return await _maybe_wait_for(awaitable, timeout) diff --git a/venv/Lib/site-packages/dns/_ddr.py b/venv/Lib/site-packages/dns/_ddr.py new file mode 100644 index 00000000..bf5c11eb --- /dev/null +++ b/venv/Lib/site-packages/dns/_ddr.py @@ -0,0 +1,154 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license +# +# Support for Discovery of Designated Resolvers + +import socket +import time +from urllib.parse import urlparse + +import dns.asyncbackend +import dns.inet +import dns.name +import dns.nameserver +import dns.query +import dns.rdtypes.svcbbase + +# The special name of the local resolver when using DDR +_local_resolver_name = dns.name.from_text("_dns.resolver.arpa") + + +# +# Processing is split up into I/O independent and I/O dependent parts to +# make supporting sync and async versions easy. 
+# + + +class _SVCBInfo: + def __init__(self, bootstrap_address, port, hostname, nameservers): + self.bootstrap_address = bootstrap_address + self.port = port + self.hostname = hostname + self.nameservers = nameservers + + def ddr_check_certificate(self, cert): + """Verify that the _SVCBInfo's address is in the cert's subjectAltName (SAN)""" + for name, value in cert["subjectAltName"]: + if name == "IP Address" and value == self.bootstrap_address: + return True + return False + + def make_tls_context(self): + ssl = dns.query.ssl + ctx = ssl.create_default_context() + ctx.minimum_version = ssl.TLSVersion.TLSv1_2 + return ctx + + def ddr_tls_check_sync(self, lifetime): + ctx = self.make_tls_context() + expiration = time.time() + lifetime + with socket.create_connection( + (self.bootstrap_address, self.port), lifetime + ) as s: + with ctx.wrap_socket(s, server_hostname=self.hostname) as ts: + ts.settimeout(dns.query._remaining(expiration)) + ts.do_handshake() + cert = ts.getpeercert() + return self.ddr_check_certificate(cert) + + async def ddr_tls_check_async(self, lifetime, backend=None): + if backend is None: + backend = dns.asyncbackend.get_default_backend() + ctx = self.make_tls_context() + expiration = time.time() + lifetime + async with await backend.make_socket( + dns.inet.af_for_address(self.bootstrap_address), + socket.SOCK_STREAM, + 0, + None, + (self.bootstrap_address, self.port), + lifetime, + ctx, + self.hostname, + ) as ts: + cert = await ts.getpeercert(dns.query._remaining(expiration)) + return self.ddr_check_certificate(cert) + + +def _extract_nameservers_from_svcb(answer): + bootstrap_address = answer.nameserver + if not dns.inet.is_address(bootstrap_address): + return [] + infos = [] + for rr in answer.rrset.processing_order(): + nameservers = [] + param = rr.params.get(dns.rdtypes.svcbbase.ParamKey.ALPN) + if param is None: + continue + alpns = set(param.ids) + host = rr.target.to_text(omit_final_dot=True) + port = None + param = rr.params.get(dns.rdtypes.svcbbase.ParamKey.PORT) + if param is not None: + port = param.port + # For now we ignore address hints and address resolution and always use the + # bootstrap address + if b"h2" in alpns: + param = rr.params.get(dns.rdtypes.svcbbase.ParamKey.DOHPATH) + if param is None or not param.value.endswith(b"{?dns}"): + continue + path = param.value[:-6].decode() + if not path.startswith("/"): + path = "/" + path + if port is None: + port = 443 + url = f"https://{host}:{port}{path}" + # check the URL + try: + urlparse(url) + nameservers.append(dns.nameserver.DoHNameserver(url, bootstrap_address)) + except Exception: + # continue processing other ALPN types + pass + if b"dot" in alpns: + if port is None: + port = 853 + nameservers.append( + dns.nameserver.DoTNameserver(bootstrap_address, port, host) + ) + if b"doq" in alpns: + if port is None: + port = 853 + nameservers.append( + dns.nameserver.DoQNameserver(bootstrap_address, port, True, host) + ) + if len(nameservers) > 0: + infos.append(_SVCBInfo(bootstrap_address, port, host, nameservers)) + return infos + + +def _get_nameservers_sync(answer, lifetime): + """Return a list of TLS-validated resolver nameservers extracted from an SVCB + answer.""" + nameservers = [] + infos = _extract_nameservers_from_svcb(answer) + for info in infos: + try: + if info.ddr_tls_check_sync(lifetime): + nameservers.extend(info.nameservers) + except Exception: + pass + return nameservers + + +async def _get_nameservers_async(answer, lifetime): + """Return a list of TLS-validated resolver nameservers 
extracted from an SVCB + answer.""" + nameservers = [] + infos = _extract_nameservers_from_svcb(answer) + for info in infos: + try: + if await info.ddr_tls_check_async(lifetime): + nameservers.extend(info.nameservers) + except Exception: + pass + return nameservers diff --git a/venv/Lib/site-packages/dns/_features.py b/venv/Lib/site-packages/dns/_features.py new file mode 100644 index 00000000..03ccaa77 --- /dev/null +++ b/venv/Lib/site-packages/dns/_features.py @@ -0,0 +1,92 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +import importlib.metadata +import itertools +import string +from typing import Dict, List, Tuple + + +def _tuple_from_text(version: str) -> Tuple: + text_parts = version.split(".") + int_parts = [] + for text_part in text_parts: + digit_prefix = "".join( + itertools.takewhile(lambda x: x in string.digits, text_part) + ) + try: + int_parts.append(int(digit_prefix)) + except Exception: + break + return tuple(int_parts) + + +def _version_check( + requirement: str, +) -> bool: + """Is the requirement fulfilled? + + The requirement must be of the form + + package>=version + """ + package, minimum = requirement.split(">=") + try: + version = importlib.metadata.version(package) + except Exception: + return False + t_version = _tuple_from_text(version) + t_minimum = _tuple_from_text(minimum) + if t_version < t_minimum: + return False + return True + + +_cache: Dict[str, bool] = {} + + +def have(feature: str) -> bool: + """Is *feature* available? + + This tests if all optional packages needed for the + feature are available and recent enough. + + Returns ``True`` if the feature is available, + and ``False`` if it is not or if metadata is + missing. + """ + value = _cache.get(feature) + if value is not None: + return value + requirements = _requirements.get(feature) + if requirements is None: + # we make a cache entry here for consistency not performance + _cache[feature] = False + return False + ok = True + for requirement in requirements: + if not _version_check(requirement): + ok = False + break + _cache[feature] = ok + return ok + + +def force(feature: str, enabled: bool) -> None: + """Force the status of *feature* to be *enabled*. + + This method is provided as a workaround for any cases + where importlib.metadata is ineffective, or for testing. + """ + _cache[feature] = enabled + + +_requirements: Dict[str, List[str]] = { + ### BEGIN generated requirements + "dnssec": ["cryptography>=41"], + "doh": ["httpcore>=1.0.0", "httpx>=0.26.0", "h2>=4.1.0"], + "doq": ["aioquic>=0.9.25"], + "idna": ["idna>=3.6"], + "trio": ["trio>=0.23"], + "wmi": ["wmi>=1.5.1"], + ### END generated requirements +} diff --git a/venv/Lib/site-packages/dns/_immutable_ctx.py b/venv/Lib/site-packages/dns/_immutable_ctx.py new file mode 100644 index 00000000..ae7a33bf --- /dev/null +++ b/venv/Lib/site-packages/dns/_immutable_ctx.py @@ -0,0 +1,76 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# This implementation of the immutable decorator requires python >= +# 3.7, and is significantly more storage efficient when making classes +# with slots immutable. It's also faster. + +import contextvars +import inspect + +_in__init__ = contextvars.ContextVar("_immutable_in__init__", default=False) + + +class _Immutable: + """Immutable mixin class""" + + # We set slots to the empty list to say "we don't have any attributes". 
+ # We do this so that if we're mixed in with a class with __slots__, we + # don't cause a __dict__ to be added which would waste space. + + __slots__ = () + + def __setattr__(self, name, value): + if _in__init__.get() is not self: + raise TypeError("object doesn't support attribute assignment") + else: + super().__setattr__(name, value) + + def __delattr__(self, name): + if _in__init__.get() is not self: + raise TypeError("object doesn't support attribute assignment") + else: + super().__delattr__(name) + + +def _immutable_init(f): + def nf(*args, **kwargs): + previous = _in__init__.set(args[0]) + try: + # call the actual __init__ + f(*args, **kwargs) + finally: + _in__init__.reset(previous) + + nf.__signature__ = inspect.signature(f) + return nf + + +def immutable(cls): + if _Immutable in cls.__mro__: + # Some ancestor already has the mixin, so just make sure we keep + # following the __init__ protocol. + cls.__init__ = _immutable_init(cls.__init__) + if hasattr(cls, "__setstate__"): + cls.__setstate__ = _immutable_init(cls.__setstate__) + ncls = cls + else: + # Mixin the Immutable class and follow the __init__ protocol. + class ncls(_Immutable, cls): + # We have to do the __slots__ declaration here too! + __slots__ = () + + @_immutable_init + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + if hasattr(cls, "__setstate__"): + + @_immutable_init + def __setstate__(self, *args, **kwargs): + super().__setstate__(*args, **kwargs) + + # make ncls have the same name and module as cls + ncls.__name__ = cls.__name__ + ncls.__qualname__ = cls.__qualname__ + ncls.__module__ = cls.__module__ + return ncls diff --git a/venv/Lib/site-packages/dns/_trio_backend.py b/venv/Lib/site-packages/dns/_trio_backend.py new file mode 100644 index 00000000..398e3276 --- /dev/null +++ b/venv/Lib/site-packages/dns/_trio_backend.py @@ -0,0 +1,250 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +"""trio async I/O library query support""" + +import socket + +import trio +import trio.socket # type: ignore + +import dns._asyncbackend +import dns._features +import dns.exception +import dns.inet + +if not dns._features.have("trio"): + raise ImportError("trio not found or too old") + + +def _maybe_timeout(timeout): + if timeout is not None: + return trio.move_on_after(timeout) + else: + return dns._asyncbackend.NullContext() + + +# for brevity +_lltuple = dns.inet.low_level_address_tuple + +# pylint: disable=redefined-outer-name + + +class DatagramSocket(dns._asyncbackend.DatagramSocket): + def __init__(self, socket): + super().__init__(socket.family) + self.socket = socket + + async def sendto(self, what, destination, timeout): + with _maybe_timeout(timeout): + return await self.socket.sendto(what, destination) + raise dns.exception.Timeout( + timeout=timeout + ) # pragma: no cover lgtm[py/unreachable-statement] + + async def recvfrom(self, size, timeout): + with _maybe_timeout(timeout): + return await self.socket.recvfrom(size) + raise dns.exception.Timeout(timeout=timeout) # lgtm[py/unreachable-statement] + + async def close(self): + self.socket.close() + + async def getpeername(self): + return self.socket.getpeername() + + async def getsockname(self): + return self.socket.getsockname() + + async def getpeercert(self, timeout): + raise NotImplementedError + + +class StreamSocket(dns._asyncbackend.StreamSocket): + def __init__(self, family, stream, tls=False): + self.family = family + self.stream = stream + self.tls = tls + + async def sendall(self, what, 
timeout): + with _maybe_timeout(timeout): + return await self.stream.send_all(what) + raise dns.exception.Timeout(timeout=timeout) # lgtm[py/unreachable-statement] + + async def recv(self, size, timeout): + with _maybe_timeout(timeout): + return await self.stream.receive_some(size) + raise dns.exception.Timeout(timeout=timeout) # lgtm[py/unreachable-statement] + + async def close(self): + await self.stream.aclose() + + async def getpeername(self): + if self.tls: + return self.stream.transport_stream.socket.getpeername() + else: + return self.stream.socket.getpeername() + + async def getsockname(self): + if self.tls: + return self.stream.transport_stream.socket.getsockname() + else: + return self.stream.socket.getsockname() + + async def getpeercert(self, timeout): + if self.tls: + with _maybe_timeout(timeout): + await self.stream.do_handshake() + return self.stream.getpeercert() + else: + raise NotImplementedError + + +if dns._features.have("doh"): + import httpcore + import httpcore._backends.trio + import httpx + + _CoreAsyncNetworkBackend = httpcore.AsyncNetworkBackend + _CoreTrioStream = httpcore._backends.trio.TrioStream + + from dns.query import _compute_times, _expiration_for_this_attempt, _remaining + + class _NetworkBackend(_CoreAsyncNetworkBackend): + def __init__(self, resolver, local_port, bootstrap_address, family): + super().__init__() + self._local_port = local_port + self._resolver = resolver + self._bootstrap_address = bootstrap_address + self._family = family + + async def connect_tcp( + self, host, port, timeout, local_address, socket_options=None + ): # pylint: disable=signature-differs + addresses = [] + _, expiration = _compute_times(timeout) + if dns.inet.is_address(host): + addresses.append(host) + elif self._bootstrap_address is not None: + addresses.append(self._bootstrap_address) + else: + timeout = _remaining(expiration) + family = self._family + if local_address: + family = dns.inet.af_for_address(local_address) + answers = await self._resolver.resolve_name( + host, family=family, lifetime=timeout + ) + addresses = answers.addresses() + for address in addresses: + try: + af = dns.inet.af_for_address(address) + if local_address is not None or self._local_port != 0: + source = (local_address, self._local_port) + else: + source = None + destination = (address, port) + attempt_expiration = _expiration_for_this_attempt(2.0, expiration) + timeout = _remaining(attempt_expiration) + sock = await Backend().make_socket( + af, socket.SOCK_STREAM, 0, source, destination, timeout + ) + return _CoreTrioStream(sock.stream) + except Exception: + continue + raise httpcore.ConnectError + + async def connect_unix_socket( + self, path, timeout, socket_options=None + ): # pylint: disable=signature-differs + raise NotImplementedError + + async def sleep(self, seconds): # pylint: disable=signature-differs + await trio.sleep(seconds) + + class _HTTPTransport(httpx.AsyncHTTPTransport): + def __init__( + self, + *args, + local_port=0, + bootstrap_address=None, + resolver=None, + family=socket.AF_UNSPEC, + **kwargs, + ): + if resolver is None: + # pylint: disable=import-outside-toplevel,redefined-outer-name + import dns.asyncresolver + + resolver = dns.asyncresolver.Resolver() + super().__init__(*args, **kwargs) + self._pool._network_backend = _NetworkBackend( + resolver, local_port, bootstrap_address, family + ) + +else: + _HTTPTransport = dns._asyncbackend.NullTransport # type: ignore + + +class Backend(dns._asyncbackend.Backend): + def name(self): + return "trio" + + async def 
make_socket( + self, + af, + socktype, + proto=0, + source=None, + destination=None, + timeout=None, + ssl_context=None, + server_hostname=None, + ): + s = trio.socket.socket(af, socktype, proto) + stream = None + try: + if source: + await s.bind(_lltuple(source, af)) + if socktype == socket.SOCK_STREAM: + connected = False + with _maybe_timeout(timeout): + await s.connect(_lltuple(destination, af)) + connected = True + if not connected: + raise dns.exception.Timeout( + timeout=timeout + ) # lgtm[py/unreachable-statement] + except Exception: # pragma: no cover + s.close() + raise + if socktype == socket.SOCK_DGRAM: + return DatagramSocket(s) + elif socktype == socket.SOCK_STREAM: + stream = trio.SocketStream(s) + tls = False + if ssl_context: + tls = True + try: + stream = trio.SSLStream( + stream, ssl_context, server_hostname=server_hostname + ) + except Exception: # pragma: no cover + await stream.aclose() + raise + return StreamSocket(af, stream, tls) + raise NotImplementedError( + "unsupported socket " + f"type {socktype}" + ) # pragma: no cover + + async def sleep(self, interval): + await trio.sleep(interval) + + def get_transport_class(self): + return _HTTPTransport + + async def wait_for(self, awaitable, timeout): + with _maybe_timeout(timeout): + return await awaitable + raise dns.exception.Timeout( + timeout=timeout + ) # pragma: no cover lgtm[py/unreachable-statement] diff --git a/venv/Lib/site-packages/dns/asyncbackend.py b/venv/Lib/site-packages/dns/asyncbackend.py new file mode 100644 index 00000000..0ec58b06 --- /dev/null +++ b/venv/Lib/site-packages/dns/asyncbackend.py @@ -0,0 +1,101 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +from typing import Dict + +import dns.exception + +# pylint: disable=unused-import +from dns._asyncbackend import ( # noqa: F401 lgtm[py/unused-import] + Backend, + DatagramSocket, + Socket, + StreamSocket, +) + +# pylint: enable=unused-import + +_default_backend = None + +_backends: Dict[str, Backend] = {} + +# Allow sniffio import to be disabled for testing purposes +_no_sniffio = False + + +class AsyncLibraryNotFoundError(dns.exception.DNSException): + pass + + +def get_backend(name: str) -> Backend: + """Get the specified asynchronous backend. + + *name*, a ``str``, the name of the backend. Currently the "trio" + and "asyncio" backends are available. + + Raises NotImplementedError if an unknown backend name is specified. + """ + # pylint: disable=import-outside-toplevel,redefined-outer-name + backend = _backends.get(name) + if backend: + return backend + if name == "trio": + import dns._trio_backend + + backend = dns._trio_backend.Backend() + elif name == "asyncio": + import dns._asyncio_backend + + backend = dns._asyncio_backend.Backend() + else: + raise NotImplementedError(f"unimplemented async backend {name}") + _backends[name] = backend + return backend + + +def sniff() -> str: + """Attempt to determine the in-use asynchronous I/O library by using + the ``sniffio`` module if it is available. + + Returns the name of the library, or raises AsyncLibraryNotFoundError + if the library cannot be determined. 
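+
+    A minimal sketch of the asyncio case::
+
+        import asyncio
+
+        import dns.asyncbackend
+
+        async def main():
+            # inside a running asyncio event loop, "asyncio" is reported
+            assert dns.asyncbackend.sniff() == "asyncio"
+
+        asyncio.run(main())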
+ """ + # pylint: disable=import-outside-toplevel + try: + if _no_sniffio: + raise ImportError + import sniffio + + try: + return sniffio.current_async_library() + except sniffio.AsyncLibraryNotFoundError: + raise AsyncLibraryNotFoundError("sniffio cannot determine async library") + except ImportError: + import asyncio + + try: + asyncio.get_running_loop() + return "asyncio" + except RuntimeError: + raise AsyncLibraryNotFoundError("no async library detected") + + +def get_default_backend() -> Backend: + """Get the default backend, initializing it if necessary.""" + if _default_backend: + return _default_backend + + return set_default_backend(sniff()) + + +def set_default_backend(name: str) -> Backend: + """Set the default backend. + + It's not normally necessary to call this method, as + ``get_default_backend()`` will initialize the backend + appropriately in many cases. If ``sniffio`` is not installed, or + in testing situations, this function allows the backend to be set + explicitly. + """ + global _default_backend + _default_backend = get_backend(name) + return _default_backend diff --git a/venv/Lib/site-packages/dns/asyncquery.py b/venv/Lib/site-packages/dns/asyncquery.py new file mode 100644 index 00000000..4d9ab9ae --- /dev/null +++ b/venv/Lib/site-packages/dns/asyncquery.py @@ -0,0 +1,780 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2017 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
+ +"""Talk to a DNS server.""" + +import base64 +import contextlib +import socket +import struct +import time +from typing import Any, Dict, Optional, Tuple, Union + +import dns.asyncbackend +import dns.exception +import dns.inet +import dns.message +import dns.name +import dns.quic +import dns.rcode +import dns.rdataclass +import dns.rdatatype +import dns.transaction +from dns._asyncbackend import NullContext +from dns.query import ( + BadResponse, + NoDOH, + NoDOQ, + UDPMode, + _compute_times, + _make_dot_ssl_context, + _matches_destination, + _remaining, + have_doh, + ssl, +) + +if have_doh: + import httpx + +# for brevity +_lltuple = dns.inet.low_level_address_tuple + + +def _source_tuple(af, address, port): + # Make a high level source tuple, or return None if address and port + # are both None + if address or port: + if address is None: + if af == socket.AF_INET: + address = "0.0.0.0" + elif af == socket.AF_INET6: + address = "::" + else: + raise NotImplementedError(f"unknown address family {af}") + return (address, port) + else: + return None + + +def _timeout(expiration, now=None): + if expiration is not None: + if not now: + now = time.time() + return max(expiration - now, 0) + else: + return None + + +async def send_udp( + sock: dns.asyncbackend.DatagramSocket, + what: Union[dns.message.Message, bytes], + destination: Any, + expiration: Optional[float] = None, +) -> Tuple[int, float]: + """Send a DNS message to the specified UDP socket. + + *sock*, a ``dns.asyncbackend.DatagramSocket``. + + *what*, a ``bytes`` or ``dns.message.Message``, the message to send. + + *destination*, a destination tuple appropriate for the address family + of the socket, specifying where to send the query. + + *expiration*, a ``float`` or ``None``, the absolute time at which + a timeout exception should be raised. If ``None``, no timeout will + occur. The expiration value is meaningless for the asyncio backend, as + asyncio's transport sendto() never blocks. + + Returns an ``(int, float)`` tuple of bytes sent and the sent time. + """ + + if isinstance(what, dns.message.Message): + what = what.to_wire() + sent_time = time.time() + n = await sock.sendto(what, destination, _timeout(expiration, sent_time)) + return (n, sent_time) + + +async def receive_udp( + sock: dns.asyncbackend.DatagramSocket, + destination: Optional[Any] = None, + expiration: Optional[float] = None, + ignore_unexpected: bool = False, + one_rr_per_rrset: bool = False, + keyring: Optional[Dict[dns.name.Name, dns.tsig.Key]] = None, + request_mac: Optional[bytes] = b"", + ignore_trailing: bool = False, + raise_on_truncation: bool = False, + ignore_errors: bool = False, + query: Optional[dns.message.Message] = None, +) -> Any: + """Read a DNS message from a UDP socket. + + *sock*, a ``dns.asyncbackend.DatagramSocket``. + + See :py:func:`dns.query.receive_udp()` for the documentation of the other + parameters, and exceptions. + + Returns a ``(dns.message.Message, float, tuple)`` tuple of the received message, the + received time, and the address where the message arrived from. 
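+
+    A minimal sketch pairing this with :py:func:`send_udp`; ``203.0.113.1``
+    is a placeholder server address::
+
+        import socket
+
+        import dns.asyncbackend
+        import dns.asyncquery
+        import dns.message
+
+        async def udp_roundtrip():
+            q = dns.message.make_query("example.com.", "A")
+            backend = dns.asyncbackend.get_default_backend()
+            destination = ("203.0.113.1", 53)
+            async with await backend.make_socket(
+                socket.AF_INET, socket.SOCK_DGRAM
+            ) as s:
+                await dns.asyncquery.send_udp(s, q, destination)
+                (r, _, _) = await dns.asyncquery.receive_udp(s, destination)
+                return r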
+ """ + + wire = b"" + while True: + (wire, from_address) = await sock.recvfrom(65535, _timeout(expiration)) + if not _matches_destination( + sock.family, from_address, destination, ignore_unexpected + ): + continue + received_time = time.time() + try: + r = dns.message.from_wire( + wire, + keyring=keyring, + request_mac=request_mac, + one_rr_per_rrset=one_rr_per_rrset, + ignore_trailing=ignore_trailing, + raise_on_truncation=raise_on_truncation, + ) + except dns.message.Truncated as e: + # See the comment in query.py for details. + if ( + ignore_errors + and query is not None + and not query.is_response(e.message()) + ): + continue + else: + raise + except Exception: + if ignore_errors: + continue + else: + raise + if ignore_errors and query is not None and not query.is_response(r): + continue + return (r, received_time, from_address) + + +async def udp( + q: dns.message.Message, + where: str, + timeout: Optional[float] = None, + port: int = 53, + source: Optional[str] = None, + source_port: int = 0, + ignore_unexpected: bool = False, + one_rr_per_rrset: bool = False, + ignore_trailing: bool = False, + raise_on_truncation: bool = False, + sock: Optional[dns.asyncbackend.DatagramSocket] = None, + backend: Optional[dns.asyncbackend.Backend] = None, + ignore_errors: bool = False, +) -> dns.message.Message: + """Return the response obtained after sending a query via UDP. + + *sock*, a ``dns.asyncbackend.DatagramSocket``, or ``None``, + the socket to use for the query. If ``None``, the default, a + socket is created. Note that if a socket is provided, the + *source*, *source_port*, and *backend* are ignored. + + *backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``, + the default, then dnspython will use the default backend. + + See :py:func:`dns.query.udp()` for the documentation of the other + parameters, exceptions, and return type of this method. + """ + wire = q.to_wire() + (begin_time, expiration) = _compute_times(timeout) + af = dns.inet.af_for_address(where) + destination = _lltuple((where, port), af) + if sock: + cm: contextlib.AbstractAsyncContextManager = NullContext(sock) + else: + if not backend: + backend = dns.asyncbackend.get_default_backend() + stuple = _source_tuple(af, source, source_port) + if backend.datagram_connection_required(): + dtuple = (where, port) + else: + dtuple = None + cm = await backend.make_socket(af, socket.SOCK_DGRAM, 0, stuple, dtuple) + async with cm as s: + await send_udp(s, wire, destination, expiration) + (r, received_time, _) = await receive_udp( + s, + destination, + expiration, + ignore_unexpected, + one_rr_per_rrset, + q.keyring, + q.mac, + ignore_trailing, + raise_on_truncation, + ignore_errors, + q, + ) + r.time = received_time - begin_time + # We don't need to check q.is_response() if we are in ignore_errors mode + # as receive_udp() will have checked it. 
+ if not (ignore_errors or q.is_response(r)): + raise BadResponse + return r + + +async def udp_with_fallback( + q: dns.message.Message, + where: str, + timeout: Optional[float] = None, + port: int = 53, + source: Optional[str] = None, + source_port: int = 0, + ignore_unexpected: bool = False, + one_rr_per_rrset: bool = False, + ignore_trailing: bool = False, + udp_sock: Optional[dns.asyncbackend.DatagramSocket] = None, + tcp_sock: Optional[dns.asyncbackend.StreamSocket] = None, + backend: Optional[dns.asyncbackend.Backend] = None, + ignore_errors: bool = False, +) -> Tuple[dns.message.Message, bool]: + """Return the response to the query, trying UDP first and falling back + to TCP if UDP results in a truncated response. + + *udp_sock*, a ``dns.asyncbackend.DatagramSocket``, or ``None``, + the socket to use for the UDP query. If ``None``, the default, a + socket is created. Note that if a socket is provided the *source*, + *source_port*, and *backend* are ignored for the UDP query. + + *tcp_sock*, a ``dns.asyncbackend.StreamSocket``, or ``None``, the + socket to use for the TCP query. If ``None``, the default, a + socket is created. Note that if a socket is provided *where*, + *source*, *source_port*, and *backend* are ignored for the TCP query. + + *backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``, + the default, then dnspython will use the default backend. + + See :py:func:`dns.query.udp_with_fallback()` for the documentation + of the other parameters, exceptions, and return type of this + method. + """ + try: + response = await udp( + q, + where, + timeout, + port, + source, + source_port, + ignore_unexpected, + one_rr_per_rrset, + ignore_trailing, + True, + udp_sock, + backend, + ignore_errors, + ) + return (response, False) + except dns.message.Truncated: + response = await tcp( + q, + where, + timeout, + port, + source, + source_port, + one_rr_per_rrset, + ignore_trailing, + tcp_sock, + backend, + ) + return (response, True) + + +async def send_tcp( + sock: dns.asyncbackend.StreamSocket, + what: Union[dns.message.Message, bytes], + expiration: Optional[float] = None, +) -> Tuple[int, float]: + """Send a DNS message to the specified TCP socket. + + *sock*, a ``dns.asyncbackend.StreamSocket``. + + See :py:func:`dns.query.send_tcp()` for the documentation of the other + parameters, exceptions, and return type of this method. + """ + + if isinstance(what, dns.message.Message): + tcpmsg = what.to_wire(prepend_length=True) + else: + # copying the wire into tcpmsg is inefficient, but lets us + # avoid writev() or doing a short write that would get pushed + # onto the net + tcpmsg = len(what).to_bytes(2, "big") + what + sent_time = time.time() + await sock.sendall(tcpmsg, _timeout(expiration, sent_time)) + return (len(tcpmsg), sent_time) + + +async def _read_exactly(sock, count, expiration): + """Read the specified number of bytes from stream. Keep trying until we + either get the desired amount, or we hit EOF. + """ + s = b"" + while count > 0: + n = await sock.recv(count, _timeout(expiration)) + if n == b"": + raise EOFError + count = count - len(n) + s = s + n + return s + + +async def receive_tcp( + sock: dns.asyncbackend.StreamSocket, + expiration: Optional[float] = None, + one_rr_per_rrset: bool = False, + keyring: Optional[Dict[dns.name.Name, dns.tsig.Key]] = None, + request_mac: Optional[bytes] = b"", + ignore_trailing: bool = False, +) -> Tuple[dns.message.Message, float]: + """Read a DNS message from a TCP socket. + + *sock*, a ``dns.asyncbackend.StreamSocket``. 
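+
+    A minimal sketch pairing this with :py:func:`send_tcp`, which handles the
+    two-byte length prefix that frames DNS messages over TCP; ``203.0.113.1``
+    is a placeholder server address::
+
+        import socket
+
+        import dns.asyncbackend
+        import dns.asyncquery
+        import dns.message
+
+        async def tcp_roundtrip():
+            q = dns.message.make_query("example.com.", "SOA")
+            backend = dns.asyncbackend.get_default_backend()
+            async with await backend.make_socket(
+                socket.AF_INET, socket.SOCK_STREAM, 0, None, ("203.0.113.1", 53)
+            ) as s:
+                await dns.asyncquery.send_tcp(s, q)
+                (r, _) = await dns.asyncquery.receive_tcp(s)
+                return r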
+
+    See :py:func:`dns.query.receive_tcp()` for the documentation of the other
+    parameters, exceptions, and return type of this method.
+    """
+
+    ldata = await _read_exactly(sock, 2, expiration)
+    (l,) = struct.unpack("!H", ldata)
+    wire = await _read_exactly(sock, l, expiration)
+    received_time = time.time()
+    r = dns.message.from_wire(
+        wire,
+        keyring=keyring,
+        request_mac=request_mac,
+        one_rr_per_rrset=one_rr_per_rrset,
+        ignore_trailing=ignore_trailing,
+    )
+    return (r, received_time)
+
+
+async def tcp(
+    q: dns.message.Message,
+    where: str,
+    timeout: Optional[float] = None,
+    port: int = 53,
+    source: Optional[str] = None,
+    source_port: int = 0,
+    one_rr_per_rrset: bool = False,
+    ignore_trailing: bool = False,
+    sock: Optional[dns.asyncbackend.StreamSocket] = None,
+    backend: Optional[dns.asyncbackend.Backend] = None,
+) -> dns.message.Message:
+    """Return the response obtained after sending a query via TCP.
+
+    *sock*, a ``dns.asyncbackend.StreamSocket``, or ``None``, the
+    socket to use for the query. If ``None``, the default, a socket
+    is created. Note that if a socket is provided
+    *where*, *port*, *source*, *source_port*, and *backend* are ignored.
+
+    *backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``,
+    the default, then dnspython will use the default backend.
+
+    See :py:func:`dns.query.tcp()` for the documentation of the other
+    parameters, exceptions, and return type of this method.
+    """
+
+    wire = q.to_wire()
+    (begin_time, expiration) = _compute_times(timeout)
+    if sock:
+        # Verify that the socket is connected, as if it's not connected,
+        # it's not writable, and the polling in send_tcp() will time out or
+        # hang forever.
+        await sock.getpeername()
+        cm: contextlib.AbstractAsyncContextManager = NullContext(sock)
+    else:
+        # These are simple (address, port) pairs, not family-dependent tuples
+        # you pass to low-level socket code.
+        af = dns.inet.af_for_address(where)
+        stuple = _source_tuple(af, source, source_port)
+        dtuple = (where, port)
+        if not backend:
+            backend = dns.asyncbackend.get_default_backend()
+        cm = await backend.make_socket(
+            af, socket.SOCK_STREAM, 0, stuple, dtuple, timeout
+        )
+    async with cm as s:
+        await send_tcp(s, wire, expiration)
+        (r, received_time) = await receive_tcp(
+            s, expiration, one_rr_per_rrset, q.keyring, q.mac, ignore_trailing
+        )
+        r.time = received_time - begin_time
+        if not q.is_response(r):
+            raise BadResponse
+        return r
+
+
+async def tls(
+    q: dns.message.Message,
+    where: str,
+    timeout: Optional[float] = None,
+    port: int = 853,
+    source: Optional[str] = None,
+    source_port: int = 0,
+    one_rr_per_rrset: bool = False,
+    ignore_trailing: bool = False,
+    sock: Optional[dns.asyncbackend.StreamSocket] = None,
+    backend: Optional[dns.asyncbackend.Backend] = None,
+    ssl_context: Optional[ssl.SSLContext] = None,
+    server_hostname: Optional[str] = None,
+    verify: Union[bool, str] = True,
+) -> dns.message.Message:
+    """Return the response obtained after sending a query via TLS.
+
+    *sock*, an ``asyncbackend.StreamSocket``, or ``None``, the socket
+    to use for the query. If ``None``, the default, a socket is
+    created. Note that if a socket is provided, it must be a
+    connected SSL stream socket, and *where*, *port*,
+    *source*, *source_port*, *backend*, *ssl_context*, and *server_hostname*
+    are ignored.
+
+    *backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``,
+    the default, then dnspython will use the default backend.
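+
+    A minimal DNS-over-TLS sketch; ``9.9.9.9`` and ``dns.quad9.net`` are
+    placeholder server details::
+
+        import dns.asyncquery
+        import dns.message
+
+        async def dot_query():
+            q = dns.message.make_query("example.com.", "A")
+            return await dns.asyncquery.tls(
+                q, "9.9.9.9", timeout=5, server_hostname="dns.quad9.net"
+            )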
+ + See :py:func:`dns.query.tls()` for the documentation of the other + parameters, exceptions, and return type of this method. + """ + (begin_time, expiration) = _compute_times(timeout) + if sock: + cm: contextlib.AbstractAsyncContextManager = NullContext(sock) + else: + if ssl_context is None: + ssl_context = _make_dot_ssl_context(server_hostname, verify) + af = dns.inet.af_for_address(where) + stuple = _source_tuple(af, source, source_port) + dtuple = (where, port) + if not backend: + backend = dns.asyncbackend.get_default_backend() + cm = await backend.make_socket( + af, + socket.SOCK_STREAM, + 0, + stuple, + dtuple, + timeout, + ssl_context, + server_hostname, + ) + async with cm as s: + timeout = _timeout(expiration) + response = await tcp( + q, + where, + timeout, + port, + source, + source_port, + one_rr_per_rrset, + ignore_trailing, + s, + backend, + ) + end_time = time.time() + response.time = end_time - begin_time + return response + + +async def https( + q: dns.message.Message, + where: str, + timeout: Optional[float] = None, + port: int = 443, + source: Optional[str] = None, + source_port: int = 0, # pylint: disable=W0613 + one_rr_per_rrset: bool = False, + ignore_trailing: bool = False, + client: Optional["httpx.AsyncClient"] = None, + path: str = "/dns-query", + post: bool = True, + verify: Union[bool, str] = True, + bootstrap_address: Optional[str] = None, + resolver: Optional["dns.asyncresolver.Resolver"] = None, + family: Optional[int] = socket.AF_UNSPEC, +) -> dns.message.Message: + """Return the response obtained after sending a query via DNS-over-HTTPS. + + *client*, a ``httpx.AsyncClient``. If provided, the client to use for + the query. + + Unlike the other dnspython async functions, a backend cannot be provided + in this function because httpx always auto-detects the async backend. + + See :py:func:`dns.query.https()` for the documentation of the other + parameters, exceptions, and return type of this method. 
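+
+    A minimal DNS-over-HTTPS sketch, assuming the optional ``httpx``
+    dependency is installed; the URL is a placeholder resolver endpoint::
+
+        import dns.asyncquery
+        import dns.message
+
+        async def doh_query():
+            q = dns.message.make_query("example.com.", "AAAA")
+            return await dns.asyncquery.https(
+                q, "https://dns.example/dns-query", timeout=5
+            )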
+ """ + + if not have_doh: + raise NoDOH # pragma: no cover + if client and not isinstance(client, httpx.AsyncClient): + raise ValueError("session parameter must be an httpx.AsyncClient") + + wire = q.to_wire() + try: + af = dns.inet.af_for_address(where) + except ValueError: + af = None + transport = None + headers = {"accept": "application/dns-message"} + if af is not None and dns.inet.is_address(where): + if af == socket.AF_INET: + url = "https://{}:{}{}".format(where, port, path) + elif af == socket.AF_INET6: + url = "https://[{}]:{}{}".format(where, port, path) + else: + url = where + + backend = dns.asyncbackend.get_default_backend() + + if source is None: + local_address = None + local_port = 0 + else: + local_address = source + local_port = source_port + transport = backend.get_transport_class()( + local_address=local_address, + http1=True, + http2=True, + verify=verify, + local_port=local_port, + bootstrap_address=bootstrap_address, + resolver=resolver, + family=family, + ) + + if client: + cm: contextlib.AbstractAsyncContextManager = NullContext(client) + else: + cm = httpx.AsyncClient( + http1=True, http2=True, verify=verify, transport=transport + ) + + async with cm as the_client: + # see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH + # GET and POST examples + if post: + headers.update( + { + "content-type": "application/dns-message", + "content-length": str(len(wire)), + } + ) + response = await backend.wait_for( + the_client.post(url, headers=headers, content=wire), timeout + ) + else: + wire = base64.urlsafe_b64encode(wire).rstrip(b"=") + twire = wire.decode() # httpx does a repr() if we give it bytes + response = await backend.wait_for( + the_client.get(url, headers=headers, params={"dns": twire}), timeout + ) + + # see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH + # status codes + if response.status_code < 200 or response.status_code > 299: + raise ValueError( + "{} responded with status code {}" + "\nResponse body: {!r}".format( + where, response.status_code, response.content + ) + ) + r = dns.message.from_wire( + response.content, + keyring=q.keyring, + request_mac=q.request_mac, + one_rr_per_rrset=one_rr_per_rrset, + ignore_trailing=ignore_trailing, + ) + r.time = response.elapsed.total_seconds() + if not q.is_response(r): + raise BadResponse + return r + + +async def inbound_xfr( + where: str, + txn_manager: dns.transaction.TransactionManager, + query: Optional[dns.message.Message] = None, + port: int = 53, + timeout: Optional[float] = None, + lifetime: Optional[float] = None, + source: Optional[str] = None, + source_port: int = 0, + udp_mode: UDPMode = UDPMode.NEVER, + backend: Optional[dns.asyncbackend.Backend] = None, +) -> None: + """Conduct an inbound transfer and apply it via a transaction from the + txn_manager. + + *backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``, + the default, then dnspython will use the default backend. + + See :py:func:`dns.query.inbound_xfr()` for the documentation of + the other parameters, exceptions, and return type of this method. 
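+
+    A minimal AXFR sketch; ``203.0.113.1`` is a placeholder primary server.
+    ``dns.zone.Zone`` is a ``TransactionManager``, so the transferred records
+    are applied to the zone directly::
+
+        import dns.asyncquery
+        import dns.zone
+
+        async def transfer():
+            zone = dns.zone.Zone("example.")
+            await dns.asyncquery.inbound_xfr("203.0.113.1", zone)
+            return zone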
+ """ + if query is None: + (query, serial) = dns.xfr.make_query(txn_manager) + else: + serial = dns.xfr.extract_serial_from_query(query) + rdtype = query.question[0].rdtype + is_ixfr = rdtype == dns.rdatatype.IXFR + origin = txn_manager.from_wire_origin() + wire = query.to_wire() + af = dns.inet.af_for_address(where) + stuple = _source_tuple(af, source, source_port) + dtuple = (where, port) + (_, expiration) = _compute_times(lifetime) + retry = True + while retry: + retry = False + if is_ixfr and udp_mode != UDPMode.NEVER: + sock_type = socket.SOCK_DGRAM + is_udp = True + else: + sock_type = socket.SOCK_STREAM + is_udp = False + if not backend: + backend = dns.asyncbackend.get_default_backend() + s = await backend.make_socket( + af, sock_type, 0, stuple, dtuple, _timeout(expiration) + ) + async with s: + if is_udp: + await s.sendto(wire, dtuple, _timeout(expiration)) + else: + tcpmsg = struct.pack("!H", len(wire)) + wire + await s.sendall(tcpmsg, expiration) + with dns.xfr.Inbound(txn_manager, rdtype, serial, is_udp) as inbound: + done = False + tsig_ctx = None + while not done: + (_, mexpiration) = _compute_times(timeout) + if mexpiration is None or ( + expiration is not None and mexpiration > expiration + ): + mexpiration = expiration + if is_udp: + destination = _lltuple((where, port), af) + while True: + timeout = _timeout(mexpiration) + (rwire, from_address) = await s.recvfrom(65535, timeout) + if _matches_destination( + af, from_address, destination, True + ): + break + else: + ldata = await _read_exactly(s, 2, mexpiration) + (l,) = struct.unpack("!H", ldata) + rwire = await _read_exactly(s, l, mexpiration) + is_ixfr = rdtype == dns.rdatatype.IXFR + r = dns.message.from_wire( + rwire, + keyring=query.keyring, + request_mac=query.mac, + xfr=True, + origin=origin, + tsig_ctx=tsig_ctx, + multi=(not is_udp), + one_rr_per_rrset=is_ixfr, + ) + try: + done = inbound.process_message(r) + except dns.xfr.UseTCP: + assert is_udp # should not happen if we used TCP! + if udp_mode == UDPMode.ONLY: + raise + done = True + retry = True + udp_mode = UDPMode.NEVER + continue + tsig_ctx = r.tsig_ctx + if not retry and query.keyring and not r.had_tsig: + raise dns.exception.FormError("missing TSIG") + + +async def quic( + q: dns.message.Message, + where: str, + timeout: Optional[float] = None, + port: int = 853, + source: Optional[str] = None, + source_port: int = 0, + one_rr_per_rrset: bool = False, + ignore_trailing: bool = False, + connection: Optional[dns.quic.AsyncQuicConnection] = None, + verify: Union[bool, str] = True, + backend: Optional[dns.asyncbackend.Backend] = None, + server_hostname: Optional[str] = None, +) -> dns.message.Message: + """Return the response obtained after sending an asynchronous query via + DNS-over-QUIC. + + *backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``, + the default, then dnspython will use the default backend. + + See :py:func:`dns.query.quic()` for the documentation of the other + parameters, exceptions, and return type of this method. 
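+
+    A minimal DNS-over-QUIC sketch, assuming the optional ``aioquic``
+    dependency is installed; ``203.0.113.1`` is a placeholder server
+    address::
+
+        import dns.asyncquery
+        import dns.message
+
+        async def doq_query():
+            q = dns.message.make_query("example.com.", "A")
+            return await dns.asyncquery.quic(q, "203.0.113.1", timeout=5)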
+ """ + + if not dns.quic.have_quic: + raise NoDOQ("DNS-over-QUIC is not available.") # pragma: no cover + + q.id = 0 + wire = q.to_wire() + the_connection: dns.quic.AsyncQuicConnection + if connection: + cfactory = dns.quic.null_factory + mfactory = dns.quic.null_factory + the_connection = connection + else: + (cfactory, mfactory) = dns.quic.factories_for_backend(backend) + + async with cfactory() as context: + async with mfactory( + context, verify_mode=verify, server_name=server_hostname + ) as the_manager: + if not connection: + the_connection = the_manager.connect(where, port, source, source_port) + (start, expiration) = _compute_times(timeout) + stream = await the_connection.make_stream(timeout) + async with stream: + await stream.send(wire, True) + wire = await stream.receive(_remaining(expiration)) + finish = time.time() + r = dns.message.from_wire( + wire, + keyring=q.keyring, + request_mac=q.request_mac, + one_rr_per_rrset=one_rr_per_rrset, + ignore_trailing=ignore_trailing, + ) + r.time = max(finish - start, 0.0) + if not q.is_response(r): + raise BadResponse + return r diff --git a/venv/Lib/site-packages/dns/asyncresolver.py b/venv/Lib/site-packages/dns/asyncresolver.py new file mode 100644 index 00000000..8f5e062a --- /dev/null +++ b/venv/Lib/site-packages/dns/asyncresolver.py @@ -0,0 +1,475 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2017 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +"""Asynchronous DNS stub resolver.""" + +import socket +import time +from typing import Any, Dict, List, Optional, Union + +import dns._ddr +import dns.asyncbackend +import dns.asyncquery +import dns.exception +import dns.name +import dns.query +import dns.rdataclass +import dns.rdatatype +import dns.resolver # lgtm[py/import-and-import-from] + +# import some resolver symbols for brevity +from dns.resolver import NXDOMAIN, NoAnswer, NoRootSOA, NotAbsolute + +# for indentation purposes below +_udp = dns.asyncquery.udp +_tcp = dns.asyncquery.tcp + + +class Resolver(dns.resolver.BaseResolver): + """Asynchronous DNS stub resolver.""" + + async def resolve( + self, + qname: Union[dns.name.Name, str], + rdtype: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.A, + rdclass: Union[dns.rdataclass.RdataClass, str] = dns.rdataclass.IN, + tcp: bool = False, + source: Optional[str] = None, + raise_on_no_answer: bool = True, + source_port: int = 0, + lifetime: Optional[float] = None, + search: Optional[bool] = None, + backend: Optional[dns.asyncbackend.Backend] = None, + ) -> dns.resolver.Answer: + """Query nameservers asynchronously to find the answer to the question. + + *backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``, + the default, then dnspython will use the default backend. 
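+
+        A minimal sketch, runnable under any supported async library::
+
+            import dns.asyncresolver
+
+            async def lookup():
+                res = dns.asyncresolver.Resolver()
+                answer = await res.resolve("example.com.", "MX")
+                return [rdata.to_text() for rdata in answer]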
+
+        See :py:func:`dns.resolver.Resolver.resolve()` for the
+        documentation of the other parameters, exceptions, and return
+        type of this method.
+        """
+
+        resolution = dns.resolver._Resolution(
+            self, qname, rdtype, rdclass, tcp, raise_on_no_answer, search
+        )
+        if not backend:
+            backend = dns.asyncbackend.get_default_backend()
+        start = time.time()
+        while True:
+            (request, answer) = resolution.next_request()
+            # Note we need to say "if answer is not None" and not just
+            # "if answer" because answer implements __len__, and python
+            # will call that.  We want to return if we have an answer
+            # object, including in cases where its length is 0.
+            if answer is not None:
+                # cache hit!
+                return answer
+            assert request is not None  # needed for type checking
+            done = False
+            while not done:
+                (nameserver, tcp, backoff) = resolution.next_nameserver()
+                if backoff:
+                    await backend.sleep(backoff)
+                timeout = self._compute_timeout(start, lifetime, resolution.errors)
+                try:
+                    response = await nameserver.async_query(
+                        request,
+                        timeout=timeout,
+                        source=source,
+                        source_port=source_port,
+                        max_size=tcp,
+                        backend=backend,
+                    )
+                except Exception as ex:
+                    (_, done) = resolution.query_result(None, ex)
+                    continue
+                (answer, done) = resolution.query_result(response, None)
+                # Note we need to say "if answer is not None" and not just
+                # "if answer" because answer implements __len__, and python
+                # will call that.  We want to return if we have an answer
+                # object, including in cases where its length is 0.
+                if answer is not None:
+                    return answer
+
+    async def resolve_address(
+        self, ipaddr: str, *args: Any, **kwargs: Any
+    ) -> dns.resolver.Answer:
+        """Use an asynchronous resolver to run a reverse query for PTR
+        records.
+
+        This utilizes the resolve() method to perform a PTR lookup on the
+        specified IP address.
+
+        *ipaddr*, a ``str``, the IPv4 or IPv6 address you want to get
+        the PTR record for.
+
+        All other arguments that can be passed to the resolve() function
+        except for rdtype and rdclass are also supported by this
+        function.
+
+        """
+        # We make a modified kwargs for type checking happiness, as otherwise
+        # we get a legit warning about possibly having rdtype and rdclass
+        # in the kwargs more than once.
+        modified_kwargs: Dict[str, Any] = {}
+        modified_kwargs.update(kwargs)
+        modified_kwargs["rdtype"] = dns.rdatatype.PTR
+        modified_kwargs["rdclass"] = dns.rdataclass.IN
+        return await self.resolve(
+            dns.reversename.from_address(ipaddr), *args, **modified_kwargs
+        )
+
+    async def resolve_name(
+        self,
+        name: Union[dns.name.Name, str],
+        family: int = socket.AF_UNSPEC,
+        **kwargs: Any,
+    ) -> dns.resolver.HostAnswers:
+        """Use an asynchronous resolver to query for address records.
+
+        This utilizes the resolve() method to perform A and/or AAAA lookups on
+        the specified name.
+
+        *name*, a ``dns.name.Name`` or ``str``, the name to resolve.
+
+        *family*, an ``int``, the address family.  If socket.AF_UNSPEC
+        (the default), both A and AAAA records will be retrieved.
+
+        All other arguments that can be passed to the resolve() function
+        except for rdtype and rdclass are also supported by this
+        function.
+        """
+        # We make a modified kwargs for type checking happiness, as otherwise
+        # we get a legit warning about possibly having rdtype and rdclass
+        # in the kwargs more than once.
+        modified_kwargs: Dict[str, Any] = {}
+        modified_kwargs.update(kwargs)
+        modified_kwargs.pop("rdtype", None)
+        modified_kwargs["rdclass"] = dns.rdataclass.IN
+
+        if family == socket.AF_INET:
+            v4 = await self.resolve(name, dns.rdatatype.A, **modified_kwargs)
+            return dns.resolver.HostAnswers.make(v4=v4)
+        elif family == socket.AF_INET6:
+            v6 = await self.resolve(name, dns.rdatatype.AAAA, **modified_kwargs)
+            return dns.resolver.HostAnswers.make(v6=v6)
+        elif family != socket.AF_UNSPEC:
+            raise NotImplementedError(f"unknown address family {family}")
+
+        raise_on_no_answer = modified_kwargs.pop("raise_on_no_answer", True)
+        lifetime = modified_kwargs.pop("lifetime", None)
+        start = time.time()
+        v6 = await self.resolve(
+            name,
+            dns.rdatatype.AAAA,
+            raise_on_no_answer=False,
+            lifetime=self._compute_timeout(start, lifetime),
+            **modified_kwargs,
+        )
+        # Note that setting name ensures we query the same name
+        # for A as we did for AAAA.  (This is just in case search lists
+        # are active by default in the resolver configuration and
+        # we might be talking to a server that says NXDOMAIN when it
+        # wants to say NOERROR no data.)
+        name = v6.qname
+        v4 = await self.resolve(
+            name,
+            dns.rdatatype.A,
+            raise_on_no_answer=False,
+            lifetime=self._compute_timeout(start, lifetime),
+            **modified_kwargs,
+        )
+        answers = dns.resolver.HostAnswers.make(
+            v6=v6, v4=v4, add_empty=not raise_on_no_answer
+        )
+        if not answers:
+            raise NoAnswer(response=v6.response)
+        return answers
+
+    # pylint: disable=redefined-outer-name
+
+    async def canonical_name(self, name: Union[dns.name.Name, str]) -> dns.name.Name:
+        """Determine the canonical name of *name*.
+
+        The canonical name is the name the resolver uses for queries
+        after all CNAME and DNAME renamings have been applied.
+
+        *name*, a ``dns.name.Name`` or ``str``, the query name.
+
+        This method can raise any exception that ``resolve()`` can
+        raise, other than ``dns.resolver.NoAnswer`` and
+        ``dns.resolver.NXDOMAIN``.
+
+        Returns a ``dns.name.Name``.
+        """
+        try:
+            answer = await self.resolve(name, raise_on_no_answer=False)
+            canonical_name = answer.canonical_name
+        except dns.resolver.NXDOMAIN as e:
+            canonical_name = e.canonical_name
+        return canonical_name
+
+    async def try_ddr(self, lifetime: float = 5.0) -> None:
+        """Try to update the resolver's nameservers using Discovery of Designated
+        Resolvers (DDR).  If successful, the resolver will subsequently use
+        DNS-over-HTTPS or DNS-over-TLS for future queries.
+
+        *lifetime*, a float, is the maximum time to spend attempting DDR.  The default
+        is 5 seconds.
+
+        If the SVCB query is successful and results in a non-empty list of nameservers,
+        then the resolver's nameservers are set to the returned servers in priority
+        order.
+
+        The current implementation does not use any address hints from the SVCB record,
+        nor does it resolve addresses for the SVCB target name, rather it assumes that
+        the bootstrap nameserver will always be one of the addresses and uses it.
+        A future revision to the code may offer fuller support.  The code verifies that
+        the bootstrap nameserver is in the Subject Alternative Name field of the
+        TLS certificate.
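+
+        A minimal sketch; after a successful call, later queries transparently
+        use the discovered encrypted transport::
+
+            import dns.asyncresolver
+
+            async def upgrade_and_query():
+                res = dns.asyncresolver.Resolver()
+                await res.try_ddr()
+                return await res.resolve("example.com.", "A")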
+ """ + try: + expiration = time.time() + lifetime + answer = await self.resolve( + dns._ddr._local_resolver_name, "svcb", lifetime=lifetime + ) + timeout = dns.query._remaining(expiration) + nameservers = await dns._ddr._get_nameservers_async(answer, timeout) + if len(nameservers) > 0: + self.nameservers = nameservers + except Exception: + pass + + +default_resolver = None + + +def get_default_resolver() -> Resolver: + """Get the default asynchronous resolver, initializing it if necessary.""" + if default_resolver is None: + reset_default_resolver() + assert default_resolver is not None + return default_resolver + + +def reset_default_resolver() -> None: + """Re-initialize default asynchronous resolver. + + Note that the resolver configuration (i.e. /etc/resolv.conf on UNIX + systems) will be re-read immediately. + """ + + global default_resolver + default_resolver = Resolver() + + +async def resolve( + qname: Union[dns.name.Name, str], + rdtype: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.A, + rdclass: Union[dns.rdataclass.RdataClass, str] = dns.rdataclass.IN, + tcp: bool = False, + source: Optional[str] = None, + raise_on_no_answer: bool = True, + source_port: int = 0, + lifetime: Optional[float] = None, + search: Optional[bool] = None, + backend: Optional[dns.asyncbackend.Backend] = None, +) -> dns.resolver.Answer: + """Query nameservers asynchronously to find the answer to the question. + + This is a convenience function that uses the default resolver + object to make the query. + + See :py:func:`dns.asyncresolver.Resolver.resolve` for more + information on the parameters. + """ + + return await get_default_resolver().resolve( + qname, + rdtype, + rdclass, + tcp, + source, + raise_on_no_answer, + source_port, + lifetime, + search, + backend, + ) + + +async def resolve_address( + ipaddr: str, *args: Any, **kwargs: Any +) -> dns.resolver.Answer: + """Use a resolver to run a reverse query for PTR records. + + See :py:func:`dns.asyncresolver.Resolver.resolve_address` for more + information on the parameters. + """ + + return await get_default_resolver().resolve_address(ipaddr, *args, **kwargs) + + +async def resolve_name( + name: Union[dns.name.Name, str], family: int = socket.AF_UNSPEC, **kwargs: Any +) -> dns.resolver.HostAnswers: + """Use a resolver to asynchronously query for address records. + + See :py:func:`dns.asyncresolver.Resolver.resolve_name` for more + information on the parameters. + """ + + return await get_default_resolver().resolve_name(name, family, **kwargs) + + +async def canonical_name(name: Union[dns.name.Name, str]) -> dns.name.Name: + """Determine the canonical name of *name*. + + See :py:func:`dns.resolver.Resolver.canonical_name` for more + information on the parameters and possible exceptions. + """ + + return await get_default_resolver().canonical_name(name) + + +async def try_ddr(timeout: float = 5.0) -> None: + """Try to update the default resolver's nameservers using Discovery of Designated + Resolvers (DDR). If successful, the resolver will subsequently use + DNS-over-HTTPS or DNS-over-TLS for future queries. + + See :py:func:`dns.resolver.Resolver.try_ddr` for more information. 
+    """
+    return await get_default_resolver().try_ddr(timeout)
+
+
+async def zone_for_name(
+    name: Union[dns.name.Name, str],
+    rdclass: dns.rdataclass.RdataClass = dns.rdataclass.IN,
+    tcp: bool = False,
+    resolver: Optional[Resolver] = None,
+    backend: Optional[dns.asyncbackend.Backend] = None,
+) -> dns.name.Name:
+    """Find the name of the zone which contains the specified name.
+
+    See :py:func:`dns.resolver.Resolver.zone_for_name` for more
+    information on the parameters and possible exceptions.
+    """
+
+    if isinstance(name, str):
+        name = dns.name.from_text(name, dns.name.root)
+    if resolver is None:
+        resolver = get_default_resolver()
+    if not name.is_absolute():
+        raise NotAbsolute(name)
+    while True:
+        try:
+            answer = await resolver.resolve(
+                name, dns.rdatatype.SOA, rdclass, tcp, backend=backend
+            )
+            assert answer.rrset is not None
+            if answer.rrset.name == name:
+                return name
+            # otherwise we were CNAMEd or DNAMEd and need to look higher
+        except (NXDOMAIN, NoAnswer):
+            pass
+        try:
+            name = name.parent()
+        except dns.name.NoParent:  # pragma: no cover
+            raise NoRootSOA
+
+
+async def make_resolver_at(
+    where: Union[dns.name.Name, str],
+    port: int = 53,
+    family: int = socket.AF_UNSPEC,
+    resolver: Optional[Resolver] = None,
+) -> Resolver:
+    """Make a stub resolver using the specified destination as the full resolver.
+
+    *where*, a ``dns.name.Name`` or ``str``, the domain name or IP address of the
+    full resolver.
+
+    *port*, an ``int``, the port to use.  If not specified, the default is 53.
+
+    *family*, an ``int``, the address family to use.  This parameter is used if
+    *where* is not an address.  The default is ``socket.AF_UNSPEC`` in which case
+    the first address returned by ``resolve_name()`` will be used, otherwise the
+    first address of the specified family will be used.
+
+    *resolver*, a ``dns.asyncresolver.Resolver`` or ``None``, the resolver to use for
+    resolution of hostnames.  If not specified, the default resolver will be used.
+
+    Returns a ``dns.asyncresolver.Resolver`` or raises an exception.
+    """
+    if resolver is None:
+        resolver = get_default_resolver()
+    nameservers: List[Union[str, dns.nameserver.Nameserver]] = []
+    if isinstance(where, str) and dns.inet.is_address(where):
+        nameservers.append(dns.nameserver.Do53Nameserver(where, port))
+    else:
+        answers = await resolver.resolve_name(where, family)
+        for address in answers.addresses():
+            nameservers.append(dns.nameserver.Do53Nameserver(address, port))
+    res = dns.asyncresolver.Resolver(configure=False)
+    res.nameservers = nameservers
+    return res
+
+
+async def resolve_at(
+    where: Union[dns.name.Name, str],
+    qname: Union[dns.name.Name, str],
+    rdtype: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.A,
+    rdclass: Union[dns.rdataclass.RdataClass, str] = dns.rdataclass.IN,
+    tcp: bool = False,
+    source: Optional[str] = None,
+    raise_on_no_answer: bool = True,
+    source_port: int = 0,
+    lifetime: Optional[float] = None,
+    search: Optional[bool] = None,
+    backend: Optional[dns.asyncbackend.Backend] = None,
+    port: int = 53,
+    family: int = socket.AF_UNSPEC,
+    resolver: Optional[Resolver] = None,
+) -> dns.resolver.Answer:
+    """Query nameservers to find the answer to the question.
+
+    This is a convenience function that calls ``dns.asyncresolver.make_resolver_at()``
+    to make a resolver, and then uses it to resolve the query.
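+
+    A minimal sketch; ``192.0.2.53`` is a placeholder server address::
+
+        import dns.asyncresolver
+
+        async def ask_specific_server():
+            return await dns.asyncresolver.resolve_at(
+                "192.0.2.53", "example.com.", "TXT"
+            )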
+ + See ``dns.asyncresolver.Resolver.resolve`` for more information on the resolution + parameters, and ``dns.asyncresolver.make_resolver_at`` for information about the + resolver parameters *where*, *port*, *family*, and *resolver*. + + If making more than one query, it is more efficient to call + ``dns.asyncresolver.make_resolver_at()`` and then use that resolver for the queries + instead of calling ``resolve_at()`` multiple times. + """ + res = await make_resolver_at(where, port, family, resolver) + return await res.resolve( + qname, + rdtype, + rdclass, + tcp, + source, + raise_on_no_answer, + source_port, + lifetime, + search, + backend, + ) diff --git a/venv/Lib/site-packages/dns/dnssec.py b/venv/Lib/site-packages/dns/dnssec.py new file mode 100644 index 00000000..e49c3b79 --- /dev/null +++ b/venv/Lib/site-packages/dns/dnssec.py @@ -0,0 +1,1223 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2017 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +"""Common DNSSEC-related functions and constants.""" + + +import base64 +import contextlib +import functools +import hashlib +import struct +import time +from datetime import datetime +from typing import Callable, Dict, List, Optional, Set, Tuple, Union, cast + +import dns._features +import dns.exception +import dns.name +import dns.node +import dns.rdata +import dns.rdataclass +import dns.rdataset +import dns.rdatatype +import dns.rrset +import dns.transaction +import dns.zone +from dns.dnssectypes import Algorithm, DSDigest, NSEC3Hash +from dns.exception import ( # pylint: disable=W0611 + AlgorithmKeyMismatch, + DeniedByPolicy, + UnsupportedAlgorithm, + ValidationFailure, +) +from dns.rdtypes.ANY.CDNSKEY import CDNSKEY +from dns.rdtypes.ANY.CDS import CDS +from dns.rdtypes.ANY.DNSKEY import DNSKEY +from dns.rdtypes.ANY.DS import DS +from dns.rdtypes.ANY.NSEC import NSEC, Bitmap +from dns.rdtypes.ANY.NSEC3PARAM import NSEC3PARAM +from dns.rdtypes.ANY.RRSIG import RRSIG, sigtime_to_posixtime +from dns.rdtypes.dnskeybase import Flag + +PublicKey = Union[ + "GenericPublicKey", + "rsa.RSAPublicKey", + "ec.EllipticCurvePublicKey", + "ed25519.Ed25519PublicKey", + "ed448.Ed448PublicKey", +] + +PrivateKey = Union[ + "GenericPrivateKey", + "rsa.RSAPrivateKey", + "ec.EllipticCurvePrivateKey", + "ed25519.Ed25519PrivateKey", + "ed448.Ed448PrivateKey", +] + +RRsetSigner = Callable[[dns.transaction.Transaction, dns.rrset.RRset], None] + + +def algorithm_from_text(text: str) -> Algorithm: + """Convert text into a DNSSEC algorithm value. + + *text*, a ``str``, the text to convert to into an algorithm value. + + Returns an ``int``. 
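+
+    For example::
+
+        >>> import dns.dnssec
+        >>> int(dns.dnssec.algorithm_from_text("ED25519"))
+        15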
+    """
+
+    return Algorithm.from_text(text)
+
+
+def algorithm_to_text(value: Union[Algorithm, int]) -> str:
+    """Convert a DNSSEC algorithm value to text.
+
+    *value*, a ``dns.dnssec.Algorithm``.
+
+    Returns a ``str``, the name of a DNSSEC algorithm.
+    """
+
+    return Algorithm.to_text(value)
+
+
+def to_timestamp(value: Union[datetime, str, float, int]) -> int:
+    """Convert various formats to a timestamp."""
+    if isinstance(value, datetime):
+        return int(value.timestamp())
+    elif isinstance(value, str):
+        return sigtime_to_posixtime(value)
+    elif isinstance(value, float):
+        return int(value)
+    elif isinstance(value, int):
+        return value
+    else:
+        raise TypeError("Unsupported timestamp type")
+
+
+def key_id(key: Union[DNSKEY, CDNSKEY]) -> int:
+    """Return the key id (a 16-bit number) for the specified key.
+
+    *key*, a ``dns.rdtypes.ANY.DNSKEY.DNSKEY`` or
+    ``dns.rdtypes.ANY.CDNSKEY.CDNSKEY``.
+
+    Returns an ``int`` between 0 and 65535.
+    """
+
+    rdata = key.to_wire()
+    if key.algorithm == Algorithm.RSAMD5:
+        return (rdata[-3] << 8) + rdata[-2]
+    else:
+        total = 0
+        for i in range(len(rdata) // 2):
+            total += (rdata[2 * i] << 8) + rdata[2 * i + 1]
+        if len(rdata) % 2 != 0:
+            total += rdata[len(rdata) - 1] << 8
+        total += (total >> 16) & 0xFFFF
+        return total & 0xFFFF
+
+
+class Policy:
+    def __init__(self):
+        pass
+
+    def ok_to_sign(self, _: DNSKEY) -> bool:  # pragma: no cover
+        return False
+
+    def ok_to_validate(self, _: DNSKEY) -> bool:  # pragma: no cover
+        return False
+
+    def ok_to_create_ds(self, _: DSDigest) -> bool:  # pragma: no cover
+        return False
+
+    def ok_to_validate_ds(self, _: DSDigest) -> bool:  # pragma: no cover
+        return False
+
+
+class SimpleDeny(Policy):
+    def __init__(self, deny_sign, deny_validate, deny_create_ds, deny_validate_ds):
+        super().__init__()
+        self._deny_sign = deny_sign
+        self._deny_validate = deny_validate
+        self._deny_create_ds = deny_create_ds
+        self._deny_validate_ds = deny_validate_ds
+
+    def ok_to_sign(self, key: DNSKEY) -> bool:
+        return key.algorithm not in self._deny_sign
+
+    def ok_to_validate(self, key: DNSKEY) -> bool:
+        return key.algorithm not in self._deny_validate
+
+    def ok_to_create_ds(self, algorithm: DSDigest) -> bool:
+        return algorithm not in self._deny_create_ds
+
+    def ok_to_validate_ds(self, algorithm: DSDigest) -> bool:
+        return algorithm not in self._deny_validate_ds
+
+
+rfc_8624_policy = SimpleDeny(
+    {Algorithm.RSAMD5, Algorithm.DSA, Algorithm.DSANSEC3SHA1, Algorithm.ECCGOST},
+    {Algorithm.RSAMD5, Algorithm.DSA, Algorithm.DSANSEC3SHA1},
+    {DSDigest.NULL, DSDigest.SHA1, DSDigest.GOST},
+    {DSDigest.NULL},
+)
+
+allow_all_policy = SimpleDeny(set(), set(), set(), set())
+
+
+default_policy = rfc_8624_policy
+
+
+def make_ds(
+    name: Union[dns.name.Name, str],
+    key: dns.rdata.Rdata,
+    algorithm: Union[DSDigest, str],
+    origin: Optional[dns.name.Name] = None,
+    policy: Optional[Policy] = None,
+    validating: bool = False,
+) -> DS:
+    """Create a DS record for a DNSSEC key.
+
+    *name*, a ``dns.name.Name`` or ``str``, the owner name of the DS record.
+
+    *key*, a ``dns.rdtypes.ANY.DNSKEY.DNSKEY`` or ``dns.rdtypes.ANY.DNSKEY.CDNSKEY``,
+    the key the DS is about.
+
+    *algorithm*, a ``str`` or ``int`` specifying the hash algorithm.
+    The currently supported hashes are "SHA1", "SHA256", and "SHA384". Case
+    does not matter for these strings.
+
+    *origin*, a ``dns.name.Name`` or ``None``.  If *key* is a relative name,
+    then it will be made absolute using the specified origin.
+
+    *policy*, a ``dns.dnssec.Policy`` or ``None``.
If ``None``, the default policy, + ``dns.dnssec.default_policy`` is used; this policy defaults to that of RFC 8624. + + *validating*, a ``bool``. If ``True``, then policy is checked in + validating mode, i.e. "Is it ok to validate using this digest algorithm?". + Otherwise the policy is checked in creating mode, i.e. "Is it ok to create a DS with + this digest algorithm?". + + Raises ``UnsupportedAlgorithm`` if the algorithm is unknown. + + Raises ``DeniedByPolicy`` if the algorithm is denied by policy. + + Returns a ``dns.rdtypes.ANY.DS.DS`` + """ + + if policy is None: + policy = default_policy + try: + if isinstance(algorithm, str): + algorithm = DSDigest[algorithm.upper()] + except Exception: + raise UnsupportedAlgorithm('unsupported algorithm "%s"' % algorithm) + if validating: + check = policy.ok_to_validate_ds + else: + check = policy.ok_to_create_ds + if not check(algorithm): + raise DeniedByPolicy + if not isinstance(key, (DNSKEY, CDNSKEY)): + raise ValueError("key is not a DNSKEY/CDNSKEY") + if algorithm == DSDigest.SHA1: + dshash = hashlib.sha1() + elif algorithm == DSDigest.SHA256: + dshash = hashlib.sha256() + elif algorithm == DSDigest.SHA384: + dshash = hashlib.sha384() + else: + raise UnsupportedAlgorithm('unsupported algorithm "%s"' % algorithm) + + if isinstance(name, str): + name = dns.name.from_text(name, origin) + wire = name.canonicalize().to_wire() + assert wire is not None + dshash.update(wire) + dshash.update(key.to_wire(origin=origin)) + digest = dshash.digest() + + dsrdata = struct.pack("!HBB", key_id(key), key.algorithm, algorithm) + digest + ds = dns.rdata.from_wire( + dns.rdataclass.IN, dns.rdatatype.DS, dsrdata, 0, len(dsrdata) + ) + return cast(DS, ds) + + +def make_cds( + name: Union[dns.name.Name, str], + key: dns.rdata.Rdata, + algorithm: Union[DSDigest, str], + origin: Optional[dns.name.Name] = None, +) -> CDS: + """Create a CDS record for a DNSSEC key. + + *name*, a ``dns.name.Name`` or ``str``, the owner name of the DS record. + + *key*, a ``dns.rdtypes.ANY.DNSKEY.DNSKEY`` or ``dns.rdtypes.ANY.DNSKEY.CDNSKEY``, + the key the DS is about. + + *algorithm*, a ``str`` or ``int`` specifying the hash algorithm. + The currently supported hashes are "SHA1", "SHA256", and "SHA384". Case + does not matter for these strings. + + *origin*, a ``dns.name.Name`` or ``None``. If *key* is a relative name, + then it will be made absolute using the specified origin. + + Raises ``UnsupportedAlgorithm`` if the algorithm is unknown. 
+ + Returns a ``dns.rdtypes.ANY.DS.CDS`` + """ + + ds = make_ds(name, key, algorithm, origin) + return CDS( + rdclass=ds.rdclass, + rdtype=dns.rdatatype.CDS, + key_tag=ds.key_tag, + algorithm=ds.algorithm, + digest_type=ds.digest_type, + digest=ds.digest, + ) + + +def _find_candidate_keys( + keys: Dict[dns.name.Name, Union[dns.rdataset.Rdataset, dns.node.Node]], rrsig: RRSIG +) -> Optional[List[DNSKEY]]: + value = keys.get(rrsig.signer) + if isinstance(value, dns.node.Node): + rdataset = value.get_rdataset(dns.rdataclass.IN, dns.rdatatype.DNSKEY) + else: + rdataset = value + if rdataset is None: + return None + return [ + cast(DNSKEY, rd) + for rd in rdataset + if rd.algorithm == rrsig.algorithm + and key_id(rd) == rrsig.key_tag + and (rd.flags & Flag.ZONE) == Flag.ZONE # RFC 4034 2.1.1 + and rd.protocol == 3 # RFC 4034 2.1.2 + ] + + +def _get_rrname_rdataset( + rrset: Union[dns.rrset.RRset, Tuple[dns.name.Name, dns.rdataset.Rdataset]], +) -> Tuple[dns.name.Name, dns.rdataset.Rdataset]: + if isinstance(rrset, tuple): + return rrset[0], rrset[1] + else: + return rrset.name, rrset + + +def _validate_signature(sig: bytes, data: bytes, key: DNSKEY) -> None: + public_cls = get_algorithm_cls_from_dnskey(key).public_cls + try: + public_key = public_cls.from_dnskey(key) + except ValueError: + raise ValidationFailure("invalid public key") + public_key.verify(sig, data) + + +def _validate_rrsig( + rrset: Union[dns.rrset.RRset, Tuple[dns.name.Name, dns.rdataset.Rdataset]], + rrsig: RRSIG, + keys: Dict[dns.name.Name, Union[dns.node.Node, dns.rdataset.Rdataset]], + origin: Optional[dns.name.Name] = None, + now: Optional[float] = None, + policy: Optional[Policy] = None, +) -> None: + """Validate an RRset against a single signature rdata, throwing an + exception if validation is not successful. + + *rrset*, the RRset to validate. This can be a + ``dns.rrset.RRset`` or a (``dns.name.Name``, ``dns.rdataset.Rdataset``) + tuple. + + *rrsig*, a ``dns.rdata.Rdata``, the signature to validate. + + *keys*, the key dictionary, used to find the DNSKEY associated + with a given name. The dictionary is keyed by a + ``dns.name.Name``, and has ``dns.node.Node`` or + ``dns.rdataset.Rdataset`` values. + + *origin*, a ``dns.name.Name`` or ``None``, the origin to use for relative + names. + + *now*, a ``float`` or ``None``, the time, in seconds since the epoch, to + use as the current time when validating. If ``None``, the actual current + time is used. + + *policy*, a ``dns.dnssec.Policy`` or ``None``. If ``None``, the default policy, + ``dns.dnssec.default_policy`` is used; this policy defaults to that of RFC 8624. + + Raises ``ValidationFailure`` if the signature is expired, not yet valid, + the public key is invalid, the algorithm is unknown, the verification + fails, etc. + + Raises ``UnsupportedAlgorithm`` if the algorithm is recognized by + dnspython but not implemented. 
+ """ + + if policy is None: + policy = default_policy + + candidate_keys = _find_candidate_keys(keys, rrsig) + if candidate_keys is None: + raise ValidationFailure("unknown key") + + if now is None: + now = time.time() + if rrsig.expiration < now: + raise ValidationFailure("expired") + if rrsig.inception > now: + raise ValidationFailure("not yet valid") + + data = _make_rrsig_signature_data(rrset, rrsig, origin) + + for candidate_key in candidate_keys: + if not policy.ok_to_validate(candidate_key): + continue + try: + _validate_signature(rrsig.signature, data, candidate_key) + return + except (InvalidSignature, ValidationFailure): + # this happens on an individual validation failure + continue + # nothing verified -- raise failure: + raise ValidationFailure("verify failure") + + +def _validate( + rrset: Union[dns.rrset.RRset, Tuple[dns.name.Name, dns.rdataset.Rdataset]], + rrsigset: Union[dns.rrset.RRset, Tuple[dns.name.Name, dns.rdataset.Rdataset]], + keys: Dict[dns.name.Name, Union[dns.node.Node, dns.rdataset.Rdataset]], + origin: Optional[dns.name.Name] = None, + now: Optional[float] = None, + policy: Optional[Policy] = None, +) -> None: + """Validate an RRset against a signature RRset, throwing an exception + if none of the signatures validate. + + *rrset*, the RRset to validate. This can be a + ``dns.rrset.RRset`` or a (``dns.name.Name``, ``dns.rdataset.Rdataset``) + tuple. + + *rrsigset*, the signature RRset. This can be a + ``dns.rrset.RRset`` or a (``dns.name.Name``, ``dns.rdataset.Rdataset``) + tuple. + + *keys*, the key dictionary, used to find the DNSKEY associated + with a given name. The dictionary is keyed by a + ``dns.name.Name``, and has ``dns.node.Node`` or + ``dns.rdataset.Rdataset`` values. + + *origin*, a ``dns.name.Name``, the origin to use for relative names; + defaults to None. + + *now*, an ``int`` or ``None``, the time, in seconds since the epoch, to + use as the current time when validating. If ``None``, the actual current + time is used. + + *policy*, a ``dns.dnssec.Policy`` or ``None``. If ``None``, the default policy, + ``dns.dnssec.default_policy`` is used; this policy defaults to that of RFC 8624. + + Raises ``ValidationFailure`` if the signature is expired, not yet valid, + the public key is invalid, the algorithm is unknown, the verification + fails, etc. 
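+
+    A minimal sketch using the public ``dns.dnssec.validate`` alias for this
+    function; ``answer``, ``rrsigs``, and ``dnskey_rrset`` are assumed to have
+    been obtained elsewhere, e.g. from responses carrying the answer, its
+    RRSIGs, and the zone's DNSKEY RRset::
+
+        import dns.dnssec
+        import dns.name
+
+        keys = {dns.name.from_text("example.com."): dnskey_rrset}
+        # raises dns.dnssec.ValidationFailure if no signature validates
+        dns.dnssec.validate(answer, rrsigs, keys)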
+    """
+
+    if policy is None:
+        policy = default_policy
+
+    if isinstance(origin, str):
+        origin = dns.name.from_text(origin, dns.name.root)
+
+    if isinstance(rrset, tuple):
+        rrname = rrset[0]
+    else:
+        rrname = rrset.name
+
+    if isinstance(rrsigset, tuple):
+        rrsigname = rrsigset[0]
+        rrsigrdataset = rrsigset[1]
+    else:
+        rrsigname = rrsigset.name
+        rrsigrdataset = rrsigset
+
+    rrname = rrname.choose_relativity(origin)
+    rrsigname = rrsigname.choose_relativity(origin)
+    if rrname != rrsigname:
+        raise ValidationFailure("owner names do not match")
+
+    for rrsig in rrsigrdataset:
+        if not isinstance(rrsig, RRSIG):
+            raise ValidationFailure("expected an RRSIG")
+        try:
+            _validate_rrsig(rrset, rrsig, keys, origin, now, policy)
+            return
+        except (ValidationFailure, UnsupportedAlgorithm):
+            pass
+    raise ValidationFailure("no RRSIGs validated")
+
+
+def _sign(
+    rrset: Union[dns.rrset.RRset, Tuple[dns.name.Name, dns.rdataset.Rdataset]],
+    private_key: PrivateKey,
+    signer: dns.name.Name,
+    dnskey: DNSKEY,
+    inception: Optional[Union[datetime, str, int, float]] = None,
+    expiration: Optional[Union[datetime, str, int, float]] = None,
+    lifetime: Optional[int] = None,
+    verify: bool = False,
+    policy: Optional[Policy] = None,
+    origin: Optional[dns.name.Name] = None,
+) -> RRSIG:
+    """Sign RRset using private key.
+
+    *rrset*, the RRset to sign.  This can be a
+    ``dns.rrset.RRset`` or a (``dns.name.Name``, ``dns.rdataset.Rdataset``)
+    tuple.
+
+    *private_key*, the private key to use for signing, a
+    ``cryptography.hazmat.primitives.asymmetric`` private key class applicable
+    for DNSSEC.
+
+    *signer*, a ``dns.name.Name``, the Signer's name.
+
+    *dnskey*, a ``DNSKEY`` matching ``private_key``.
+
+    *inception*, a ``datetime``, ``str``, ``int``, ``float`` or ``None``, the
+    signature inception time.  If ``None``, the current time is used.  If a ``str``,
+    the format is "YYYYMMDDHHMMSS" or alternatively the number of seconds since the
+    UNIX epoch in text form; this is the same as the RRSIG rdata's text form.
+    Values of type `int` or `float` are interpreted as seconds since the UNIX epoch.
+
+    *expiration*, a ``datetime``, ``str``, ``int``, ``float`` or ``None``, the signature
+    expiration time.  If ``None``, the expiration time will be the inception time plus
+    the value of the *lifetime* parameter.  See the description of *inception* above
+    for how the various parameter types are interpreted.
+
+    *lifetime*, an ``int`` or ``None``, the signature lifetime in seconds.  This
+    parameter is only meaningful if *expiration* is ``None``.
+
+    *verify*, a ``bool``.  If set to ``True``, the signer will verify signatures
+    after they are created; the default is ``False``.
+
+    *policy*, a ``dns.dnssec.Policy`` or ``None``.  If ``None``, the default policy,
+    ``dns.dnssec.default_policy`` is used; this policy defaults to that of RFC 8624.
+
+    *origin*, a ``dns.name.Name`` or ``None``.  If ``None``, the default, then all
+    names in the rrset (including its owner name) must be absolute; otherwise the
+    specified origin will be used to make names absolute when signing.
+
+    Raises ``DeniedByPolicy`` if the signature is denied by policy.
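+
+    Example, a minimal sketch (``a_rrset`` is the RRset to sign, and ``zsk``
+    and ``dnskey`` are assumed to be a previously loaded private key and its
+    matching DNSKEY rdata)::
+
+        rrsig = dns.dnssec.sign(
+            rrset=a_rrset,
+            private_key=zsk,
+            signer=dns.name.from_text("example."),
+            dnskey=dnskey,
+            lifetime=3600,
+        )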
+ """ + + if policy is None: + policy = default_policy + if not policy.ok_to_sign(dnskey): + raise DeniedByPolicy + + if isinstance(rrset, tuple): + rdclass = rrset[1].rdclass + rdtype = rrset[1].rdtype + rrname = rrset[0] + original_ttl = rrset[1].ttl + else: + rdclass = rrset.rdclass + rdtype = rrset.rdtype + rrname = rrset.name + original_ttl = rrset.ttl + + if inception is not None: + rrsig_inception = to_timestamp(inception) + else: + rrsig_inception = int(time.time()) + + if expiration is not None: + rrsig_expiration = to_timestamp(expiration) + elif lifetime is not None: + rrsig_expiration = rrsig_inception + lifetime + else: + raise ValueError("expiration or lifetime must be specified") + + # Derelativize now because we need a correct labels length for the + # rrsig_template. + if origin is not None: + rrname = rrname.derelativize(origin) + labels = len(rrname) - 1 + + # Adjust labels appropriately for wildcards. + if rrname.is_wild(): + labels -= 1 + + rrsig_template = RRSIG( + rdclass=rdclass, + rdtype=dns.rdatatype.RRSIG, + type_covered=rdtype, + algorithm=dnskey.algorithm, + labels=labels, + original_ttl=original_ttl, + expiration=rrsig_expiration, + inception=rrsig_inception, + key_tag=key_id(dnskey), + signer=signer, + signature=b"", + ) + + data = dns.dnssec._make_rrsig_signature_data(rrset, rrsig_template, origin) + + if isinstance(private_key, GenericPrivateKey): + signing_key = private_key + else: + try: + private_cls = get_algorithm_cls_from_dnskey(dnskey) + signing_key = private_cls(key=private_key) + except UnsupportedAlgorithm: + raise TypeError("Unsupported key algorithm") + + signature = signing_key.sign(data, verify) + + return cast(RRSIG, rrsig_template.replace(signature=signature)) + + +def _make_rrsig_signature_data( + rrset: Union[dns.rrset.RRset, Tuple[dns.name.Name, dns.rdataset.Rdataset]], + rrsig: RRSIG, + origin: Optional[dns.name.Name] = None, +) -> bytes: + """Create signature rdata. + + *rrset*, the RRset to sign/validate. This can be a + ``dns.rrset.RRset`` or a (``dns.name.Name``, ``dns.rdataset.Rdataset``) + tuple. + + *rrsig*, a ``dns.rdata.Rdata``, the signature to validate, or the + signature template used when signing. + + *origin*, a ``dns.name.Name`` or ``None``, the origin to use for relative + names. + + Raises ``UnsupportedAlgorithm`` if the algorithm is recognized by + dnspython but not implemented. + """ + + if isinstance(origin, str): + origin = dns.name.from_text(origin, dns.name.root) + + signer = rrsig.signer + if not signer.is_absolute(): + if origin is None: + raise ValidationFailure("relative RR name without an origin specified") + signer = signer.derelativize(origin) + + # For convenience, allow the rrset to be specified as a (name, + # rdataset) tuple as well as a proper rrset + rrname, rdataset = _get_rrname_rdataset(rrset) + + data = b"" + data += rrsig.to_wire(origin=signer)[:18] + data += rrsig.signer.to_digestable(signer) + + # Derelativize the name before considering labels. 
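+    # The label-count checks below implement RFC 4035 section 5.3.2: if the
+    # RRSIG label count is less than the owner name's label count, the owner
+    # name was synthesized from a wildcard, and the wildcard form is digested.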
+    if not rrname.is_absolute():
+        if origin is None:
+            raise ValidationFailure("relative RR name without an origin specified")
+        rrname = rrname.derelativize(origin)
+
+    name_len = len(rrname)
+    if rrname.is_wild() and rrsig.labels != name_len - 2:
+        raise ValidationFailure("wild owner name has wrong label length")
+    if name_len - 1 < rrsig.labels:
+        raise ValidationFailure("owner name longer than RRSIG labels")
+    elif rrsig.labels < name_len - 1:
+        suffix = rrname.split(rrsig.labels + 1)[1]
+        rrname = dns.name.from_text("*", suffix)
+    rrnamebuf = rrname.to_digestable()
+    rrfixed = struct.pack("!HHI", rdataset.rdtype, rdataset.rdclass, rrsig.original_ttl)
+    rdatas = [rdata.to_digestable(origin) for rdata in rdataset]
+    for rdata in sorted(rdatas):
+        data += rrnamebuf
+        data += rrfixed
+        rrlen = struct.pack("!H", len(rdata))
+        data += rrlen
+        data += rdata
+
+    return data
+
+
+def _make_dnskey(
+    public_key: PublicKey,
+    algorithm: Union[int, str],
+    flags: int = Flag.ZONE,
+    protocol: int = 3,
+) -> DNSKEY:
+    """Convert a public key to DNSKEY Rdata
+
+    *public_key*, a ``PublicKey`` (``GenericPublicKey`` or
+    ``cryptography.hazmat.primitives.asymmetric``) to convert.
+
+    *algorithm*, a ``str`` or ``int`` specifying the DNSKEY algorithm.
+
+    *flags*: DNSKEY flags field as an integer.
+
+    *protocol*: DNSKEY protocol field as an integer.
+
+    Raises ``ValueError`` if the specified key algorithm parameters are
+    unsupported, ``TypeError`` if the key type is unsupported,
+    ``UnsupportedAlgorithm`` if the algorithm is unknown and
+    ``AlgorithmKeyMismatch`` if the algorithm does not match the key type.
+
+    Returns a DNSKEY ``Rdata``.
+    """
+
+    algorithm = Algorithm.make(algorithm)
+
+    if isinstance(public_key, GenericPublicKey):
+        return public_key.to_dnskey(flags=flags, protocol=protocol)
+    else:
+        public_cls = get_algorithm_cls(algorithm).public_cls
+        return public_cls(key=public_key).to_dnskey(flags=flags, protocol=protocol)
+
+
+def _make_cdnskey(
+    public_key: PublicKey,
+    algorithm: Union[int, str],
+    flags: int = Flag.ZONE,
+    protocol: int = 3,
+) -> CDNSKEY:
+    """Convert a public key to CDNSKEY Rdata
+
+    *public_key*, the public key to convert, a
+    ``cryptography.hazmat.primitives.asymmetric`` public key class applicable
+    for DNSSEC.
+
+    *algorithm*, a ``str`` or ``int`` specifying the DNSKEY algorithm.
+
+    *flags*: DNSKEY flags field as an integer.
+
+    *protocol*: DNSKEY protocol field as an integer.
+
+    Raises ``ValueError`` if the specified key algorithm parameters are
+    unsupported, ``TypeError`` if the key type is unsupported,
+    ``UnsupportedAlgorithm`` if the algorithm is unknown and
+    ``AlgorithmKeyMismatch`` if the algorithm does not match the key type.
+
+    Returns a CDNSKEY ``Rdata``.
+    """
+
+    dnskey = _make_dnskey(public_key, algorithm, flags, protocol)
+
+    return CDNSKEY(
+        rdclass=dnskey.rdclass,
+        rdtype=dns.rdatatype.CDNSKEY,
+        flags=dnskey.flags,
+        protocol=dnskey.protocol,
+        algorithm=dnskey.algorithm,
+        key=dnskey.key,
+    )
+
+
+def nsec3_hash(
+    domain: Union[dns.name.Name, str],
+    salt: Optional[Union[str, bytes]],
+    iterations: int,
+    algorithm: Union[int, str],
+) -> str:
+    """
+    Calculate the NSEC3 hash, according to
+    https://tools.ietf.org/html/rfc5155#section-5
+
+    *domain*, a ``dns.name.Name`` or ``str``, the name to hash.
+
+    *salt*, a ``str``, ``bytes``, or ``None``, the hash salt.  If a
+    string, it is decoded as a hex string.
+
+    *iterations*, an ``int``, the number of iterations.
+
+    *algorithm*, a ``str`` or ``int``, the hash algorithm.
+        The only defined algorithm is SHA1.
+
+    Returns a ``str``, the encoded NSEC3 hash.
+    """
+
+    b32_conversion = str.maketrans(
+        "ABCDEFGHIJKLMNOPQRSTUVWXYZ234567", "0123456789ABCDEFGHIJKLMNOPQRSTUV"
+    )
+
+    try:
+        if isinstance(algorithm, str):
+            algorithm = NSEC3Hash[algorithm.upper()]
+    except Exception:
+        raise ValueError("Wrong hash algorithm (only SHA1 is supported)")
+
+    if algorithm != NSEC3Hash.SHA1:
+        raise ValueError("Wrong hash algorithm (only SHA1 is supported)")
+
+    if salt is None:
+        salt_encoded = b""
+    elif isinstance(salt, str):
+        if len(salt) % 2 == 0:
+            salt_encoded = bytes.fromhex(salt)
+        else:
+            raise ValueError("Invalid salt length")
+    else:
+        salt_encoded = salt
+
+    if not isinstance(domain, dns.name.Name):
+        domain = dns.name.from_text(domain)
+    domain_encoded = domain.canonicalize().to_wire()
+    assert domain_encoded is not None
+
+    digest = hashlib.sha1(domain_encoded + salt_encoded).digest()
+    for _ in range(iterations):
+        digest = hashlib.sha1(digest + salt_encoded).digest()
+
+    output = base64.b32encode(digest).decode("utf-8")
+    output = output.translate(b32_conversion)
+
+    return output
+
+
+def make_ds_rdataset(
+    rrset: Union[dns.rrset.RRset, Tuple[dns.name.Name, dns.rdataset.Rdataset]],
+    algorithms: Set[Union[DSDigest, str]],
+    origin: Optional[dns.name.Name] = None,
+) -> dns.rdataset.Rdataset:
+    """Create a DS rdataset from a DNSKEY, CDNSKEY, or CDS RRset.
+
+    *rrset*, the RRset to create DS Rdataset for.  This can be a
+    ``dns.rrset.RRset`` or a (``dns.name.Name``, ``dns.rdataset.Rdataset``)
+    tuple.
+
+    *algorithms*, a set of ``str`` or ``int`` specifying the hash algorithms.
+    The currently supported hashes are "SHA1", "SHA256", and "SHA384".  Case
+    does not matter for these strings.  If the RRset is a CDS, only digest
+    algorithms matching *algorithms* are accepted.
+
+    *origin*, a ``dns.name.Name`` or ``None``.  If the RRset's owner name is
+    relative, it will be made absolute using the specified origin.
+
+    Raises ``UnsupportedAlgorithm`` if any of the algorithms are unknown and
+    ``ValueError`` if the given RRset is not usable.
+
+    Returns a ``dns.rdataset.Rdataset``
+    """
+
+    rrname, rdataset = _get_rrname_rdataset(rrset)
+
+    if rdataset.rdtype not in (
+        dns.rdatatype.DNSKEY,
+        dns.rdatatype.CDNSKEY,
+        dns.rdatatype.CDS,
+    ):
+        raise ValueError("rrset not a DNSKEY/CDNSKEY/CDS")
+
+    _algorithms = set()
+    for algorithm in algorithms:
+        try:
+            if isinstance(algorithm, str):
+                algorithm = DSDigest[algorithm.upper()]
+        except Exception:
+            raise UnsupportedAlgorithm('unsupported algorithm "%s"' % algorithm)
+        _algorithms.add(algorithm)
+
+    if rdataset.rdtype == dns.rdatatype.CDS:
+        res = []
+        for rdata in cds_rdataset_to_ds_rdataset(rdataset):
+            if rdata.digest_type in _algorithms:
+                res.append(rdata)
+        if len(res) == 0:
+            raise ValueError("no acceptable CDS rdata found")
+        return dns.rdataset.from_rdata_list(rdataset.ttl, res)
+
+    res = []
+    for algorithm in _algorithms:
+        res.extend(dnskey_rdataset_to_cds_rdataset(rrname, rdataset, algorithm, origin))
+    return dns.rdataset.from_rdata_list(rdataset.ttl, res)
+
+
+def cds_rdataset_to_ds_rdataset(
+    rdataset: dns.rdataset.Rdataset,
+) -> dns.rdataset.Rdataset:
+    """Create a DS rdataset from a CDS rdataset.
+
+    *rdataset*, a ``dns.rdataset.Rdataset``, the CDS rdataset to convert.
+
+    Raises ``ValueError`` if the rdataset is not CDS.
+
+    Returns a ``dns.rdataset.Rdataset``
+    """
+
+    if rdataset.rdtype != dns.rdatatype.CDS:
+        raise ValueError("rdataset not a CDS")
+    res = []
+    for rdata in rdataset:
+        res.append(
+            CDS(
+                rdclass=rdata.rdclass,
+                rdtype=dns.rdatatype.DS,
+                key_tag=rdata.key_tag,
+                algorithm=rdata.algorithm,
+                digest_type=rdata.digest_type,
+                digest=rdata.digest,
+            )
+        )
+    return dns.rdataset.from_rdata_list(rdataset.ttl, res)
+
+
+def dnskey_rdataset_to_cds_rdataset(
+    name: Union[dns.name.Name, str],
+    rdataset: dns.rdataset.Rdataset,
+    algorithm: Union[DSDigest, str],
+    origin: Optional[dns.name.Name] = None,
+) -> dns.rdataset.Rdataset:
+    """Create a CDS rdataset from a DNSKEY or CDNSKEY rdataset.
+
+    *name*, a ``dns.name.Name`` or ``str``, the owner name of the CDS record.
+
+    *rdataset*, a ``dns.rdataset.Rdataset``, to create CDS Rdataset for.
+
+    *algorithm*, a ``str`` or ``int`` specifying the hash algorithm.
+    The currently supported hashes are "SHA1", "SHA256", and "SHA384".  Case
+    does not matter for these strings.
+
+    *origin*, a ``dns.name.Name`` or ``None``.  If *name* is a relative name,
+    then it will be made absolute using the specified origin.
+
+    Raises ``UnsupportedAlgorithm`` if the algorithm is unknown or
+    ``ValueError`` if the rdataset is not DNSKEY/CDNSKEY.
+
+    Returns a ``dns.rdataset.Rdataset``
+    """
+
+    if rdataset.rdtype not in (dns.rdatatype.DNSKEY, dns.rdatatype.CDNSKEY):
+        raise ValueError("rdataset not a DNSKEY/CDNSKEY")
+    res = []
+    for rdata in rdataset:
+        res.append(make_cds(name, rdata, algorithm, origin))
+    return dns.rdataset.from_rdata_list(rdataset.ttl, res)
+
+
+def dnskey_rdataset_to_cdnskey_rdataset(
+    rdataset: dns.rdataset.Rdataset,
+) -> dns.rdataset.Rdataset:
+    """Create a CDNSKEY rdataset from a DNSKEY rdataset.
+
+    *rdataset*, a ``dns.rdataset.Rdataset``, to create CDNSKEY Rdataset for.
+
+    Returns a ``dns.rdataset.Rdataset``
+    """
+
+    if rdataset.rdtype != dns.rdatatype.DNSKEY:
+        raise ValueError("rdataset not a DNSKEY")
+    res = []
+    for rdata in rdataset:
+        res.append(
+            CDNSKEY(
+                rdclass=rdataset.rdclass,
+                rdtype=rdataset.rdtype,
+                flags=rdata.flags,
+                protocol=rdata.protocol,
+                algorithm=rdata.algorithm,
+                key=rdata.key,
+            )
+        )
+    return dns.rdataset.from_rdata_list(rdataset.ttl, res)
+
+
+def default_rrset_signer(
+    txn: dns.transaction.Transaction,
+    rrset: dns.rrset.RRset,
+    signer: dns.name.Name,
+    ksks: List[Tuple[PrivateKey, DNSKEY]],
+    zsks: List[Tuple[PrivateKey, DNSKEY]],
+    inception: Optional[Union[datetime, str, int, float]] = None,
+    expiration: Optional[Union[datetime, str, int, float]] = None,
+    lifetime: Optional[int] = None,
+    policy: Optional[Policy] = None,
+    origin: Optional[dns.name.Name] = None,
+) -> None:
+    """Default RRset signer"""
+
+    if rrset.rdtype in set(
+        [
+            dns.rdatatype.RdataType.DNSKEY,
+            dns.rdatatype.RdataType.CDS,
+            dns.rdatatype.RdataType.CDNSKEY,
+        ]
+    ):
+        keys = ksks
+    else:
+        keys = zsks
+
+    for private_key, dnskey in keys:
+        rrsig = dns.dnssec.sign(
+            rrset=rrset,
+            private_key=private_key,
+            dnskey=dnskey,
+            inception=inception,
+            expiration=expiration,
+            lifetime=lifetime,
+            signer=signer,
+            policy=policy,
+            origin=origin,
+        )
+        txn.add(rrset.name, rrset.ttl, rrsig)
+
+
+def sign_zone(
+    zone: dns.zone.Zone,
+    txn: Optional[dns.transaction.Transaction] = None,
+    keys: Optional[List[Tuple[PrivateKey, DNSKEY]]] = None,
+    add_dnskey: bool = True,
+    dnskey_ttl: Optional[int] = None,
+    inception: Optional[Union[datetime, str, int, float]] = None,
+    expiration: Optional[Union[datetime, str, int, float]] = None,
+    lifetime: Optional[int] = None,
+    nsec3: Optional[NSEC3PARAM] = None,
+    rrset_signer: Optional[RRsetSigner] = None,
+    policy: Optional[Policy] = None,
+) -> None:
+    """Sign zone.
+
+    *zone*, a ``dns.zone.Zone``, the zone to sign.
+
+    *txn*, a ``dns.transaction.Transaction``, an optional transaction to use for
+    signing.
+
+    *keys*, a list of (``PrivateKey``, ``DNSKEY``) tuples, to use for signing.  KSK/ZSK
+    roles are assigned automatically if the SEP flag is used, otherwise all RRsets are
+    signed by all keys.
+
+    *add_dnskey*, a ``bool``.  If ``True``, the default, all specified DNSKEYs are
+    automatically added to the zone on signing.
+
+    *dnskey_ttl*, an ``int``, specifies the TTL for DNSKEY RRs.  If not specified,
+    the TTL of the existing DNSKEY RRset is used, falling back to the TTL of the
+    SOA RRset.
+
+    *inception*, a ``datetime``, ``str``, ``int``, ``float`` or ``None``, the signature
+    inception time.  If ``None``, the current time is used.  If a ``str``, the format is
+    "YYYYMMDDHHMMSS" or alternatively the number of seconds since the UNIX epoch in text
+    form; this is the same as the RRSIG rdata's text form.  Values of type `int` or
+    `float` are interpreted as seconds since the UNIX epoch.
+
+    *expiration*, a ``datetime``, ``str``, ``int``, ``float`` or ``None``, the signature
+    expiration time.  If ``None``, the expiration time will be the inception time plus
+    the value of the *lifetime* parameter.  See the description of *inception* above for
+    how the various parameter types are interpreted.
+
+    *lifetime*, an ``int`` or ``None``, the signature lifetime in seconds.  This
+    parameter is only meaningful if *expiration* is ``None``.
+
+    *nsec3*, a ``NSEC3PARAM`` Rdata, configures signing using NSEC3.  Not yet
+    implemented.
+
+    *rrset_signer*, a ``Callable``, an optional function for signing RRsets.  The
+    function requires two arguments: transaction and RRset.  If not specified,
+    ``dns.dnssec.default_rrset_signer`` will be used.
+
+    Returns ``None``.
+    """
+
+    ksks = []
+    zsks = []
+
+    # if we have both KSKs and ZSKs, split by SEP flag.  if not, sign all
+    # records with all keys
+    if keys:
+        for key in keys:
+            if key[1].flags & Flag.SEP:
+                ksks.append(key)
+            else:
+                zsks.append(key)
+        if not ksks:
+            ksks = keys
+        if not zsks:
+            zsks = keys
+    else:
+        keys = []
+
+    if txn:
+        cm: contextlib.AbstractContextManager = contextlib.nullcontext(txn)
+    else:
+        cm = zone.writer()
+
+    with cm as _txn:
+        if add_dnskey:
+            if dnskey_ttl is None:
+                dnskey = _txn.get(zone.origin, dns.rdatatype.DNSKEY)
+                if dnskey:
+                    dnskey_ttl = dnskey.ttl
+                else:
+                    soa = _txn.get(zone.origin, dns.rdatatype.SOA)
+                    dnskey_ttl = soa.ttl
+            for _, dnskey in keys:
+                _txn.add(zone.origin, dnskey_ttl, dnskey)
+
+        if nsec3:
+            raise NotImplementedError("Signing with NSEC3 not yet implemented")
+        else:
+            _rrset_signer = rrset_signer or functools.partial(
+                default_rrset_signer,
+                signer=zone.origin,
+                ksks=ksks,
+                zsks=zsks,
+                inception=inception,
+                expiration=expiration,
+                lifetime=lifetime,
+                policy=policy,
+                origin=zone.origin,
+            )
+            return _sign_zone_nsec(zone, _txn, _rrset_signer)
+
+
+def _sign_zone_nsec(
+    zone: dns.zone.Zone,
+    txn: dns.transaction.Transaction,
+    rrset_signer: Optional[RRsetSigner] = None,
+) -> None:
+    """NSEC zone signer"""
+
+    def _txn_add_nsec(
+        txn: dns.transaction.Transaction,
+        name: dns.name.Name,
+        next_secure: Optional[dns.name.Name],
+        rdclass: dns.rdataclass.RdataClass,
+        ttl: int,
+        rrset_signer: Optional[RRsetSigner] = None,
+    ) -> None:
+        """NSEC zone signer helper"""
+        mandatory_types = set(
+            [dns.rdatatype.RdataType.RRSIG, dns.rdatatype.RdataType.NSEC]
+        )
+        node = txn.get_node(name)
+        if node and next_secure:
+            types = (
+                set([rdataset.rdtype for rdataset in node.rdatasets]) | mandatory_types
+            )
+            windows = Bitmap.from_rdtypes(list(types))
+            rrset = dns.rrset.from_rdata(
+                name,
+                ttl,
+                NSEC(
+                    rdclass=rdclass,
+                    rdtype=dns.rdatatype.RdataType.NSEC,
+                    next=next_secure,
+                    windows=windows,
+                ),
+            )
+            txn.add(rrset)
+            if rrset_signer:
+                rrset_signer(txn, rrset)
+
+    rrsig_ttl = zone.get_soa().minimum
+    delegation = None
+    last_secure = None
+
+    for name in sorted(txn.iterate_names()):
+        if delegation and name.is_subdomain(delegation):
+            # names below delegations are not secure
+            continue
+        elif txn.get(name, dns.rdatatype.NS) and name != zone.origin:
+            # inside delegation
+            delegation = name
+        else:
+            # outside delegation
+            delegation = None
+
+        if rrset_signer:
+            node = txn.get_node(name)
+            if node:
+                for rdataset in node.rdatasets:
+                    if rdataset.rdtype == dns.rdatatype.RRSIG:
+                        # do not sign RRSIGs
+                        continue
+                    elif delegation and rdataset.rdtype != dns.rdatatype.DS:
+                        # do not sign delegations except DS records
+                        continue
+                    else:
+                        rrset = dns.rrset.from_rdata(name, rdataset.ttl, *rdataset)
+                        rrset_signer(txn, rrset)
+
+        # We need "is not None" as the empty name is False because its length is 0.
+ if last_secure is not None: + _txn_add_nsec(txn, last_secure, name, zone.rdclass, rrsig_ttl, rrset_signer) + last_secure = name + + if last_secure: + _txn_add_nsec( + txn, last_secure, zone.origin, zone.rdclass, rrsig_ttl, rrset_signer + ) + + +def _need_pyca(*args, **kwargs): + raise ImportError( + "DNSSEC validation requires python cryptography" + ) # pragma: no cover + + +if dns._features.have("dnssec"): + from cryptography.exceptions import InvalidSignature + from cryptography.hazmat.primitives.asymmetric import dsa # pylint: disable=W0611 + from cryptography.hazmat.primitives.asymmetric import ec # pylint: disable=W0611 + from cryptography.hazmat.primitives.asymmetric import ed448 # pylint: disable=W0611 + from cryptography.hazmat.primitives.asymmetric import rsa # pylint: disable=W0611 + from cryptography.hazmat.primitives.asymmetric import ( # pylint: disable=W0611 + ed25519, + ) + + from dns.dnssecalgs import ( # pylint: disable=C0412 + get_algorithm_cls, + get_algorithm_cls_from_dnskey, + ) + from dns.dnssecalgs.base import GenericPrivateKey, GenericPublicKey + + validate = _validate # type: ignore + validate_rrsig = _validate_rrsig # type: ignore + sign = _sign + make_dnskey = _make_dnskey + make_cdnskey = _make_cdnskey + _have_pyca = True +else: # pragma: no cover + validate = _need_pyca + validate_rrsig = _need_pyca + sign = _need_pyca + make_dnskey = _need_pyca + make_cdnskey = _need_pyca + _have_pyca = False + +### BEGIN generated Algorithm constants + +RSAMD5 = Algorithm.RSAMD5 +DH = Algorithm.DH +DSA = Algorithm.DSA +ECC = Algorithm.ECC +RSASHA1 = Algorithm.RSASHA1 +DSANSEC3SHA1 = Algorithm.DSANSEC3SHA1 +RSASHA1NSEC3SHA1 = Algorithm.RSASHA1NSEC3SHA1 +RSASHA256 = Algorithm.RSASHA256 +RSASHA512 = Algorithm.RSASHA512 +ECCGOST = Algorithm.ECCGOST +ECDSAP256SHA256 = Algorithm.ECDSAP256SHA256 +ECDSAP384SHA384 = Algorithm.ECDSAP384SHA384 +ED25519 = Algorithm.ED25519 +ED448 = Algorithm.ED448 +INDIRECT = Algorithm.INDIRECT +PRIVATEDNS = Algorithm.PRIVATEDNS +PRIVATEOID = Algorithm.PRIVATEOID + +### END generated Algorithm constants diff --git a/venv/Lib/site-packages/dns/dnssecalgs/__init__.py b/venv/Lib/site-packages/dns/dnssecalgs/__init__.py new file mode 100644 index 00000000..3d9181a7 --- /dev/null +++ b/venv/Lib/site-packages/dns/dnssecalgs/__init__.py @@ -0,0 +1,120 @@ +from typing import Dict, Optional, Tuple, Type, Union + +import dns.name +from dns.dnssecalgs.base import GenericPrivateKey +from dns.dnssectypes import Algorithm +from dns.exception import UnsupportedAlgorithm +from dns.rdtypes.ANY.DNSKEY import DNSKEY + +if dns._features.have("dnssec"): + from dns.dnssecalgs.dsa import PrivateDSA, PrivateDSANSEC3SHA1 + from dns.dnssecalgs.ecdsa import PrivateECDSAP256SHA256, PrivateECDSAP384SHA384 + from dns.dnssecalgs.eddsa import PrivateED448, PrivateED25519 + from dns.dnssecalgs.rsa import ( + PrivateRSAMD5, + PrivateRSASHA1, + PrivateRSASHA1NSEC3SHA1, + PrivateRSASHA256, + PrivateRSASHA512, + ) + + _have_cryptography = True +else: + _have_cryptography = False + +AlgorithmPrefix = Optional[Union[bytes, dns.name.Name]] + +algorithms: Dict[Tuple[Algorithm, AlgorithmPrefix], Type[GenericPrivateKey]] = {} +if _have_cryptography: + algorithms.update( + { + (Algorithm.RSAMD5, None): PrivateRSAMD5, + (Algorithm.DSA, None): PrivateDSA, + (Algorithm.RSASHA1, None): PrivateRSASHA1, + (Algorithm.DSANSEC3SHA1, None): PrivateDSANSEC3SHA1, + (Algorithm.RSASHA1NSEC3SHA1, None): PrivateRSASHA1NSEC3SHA1, + (Algorithm.RSASHA256, None): PrivateRSASHA256, + (Algorithm.RSASHA512, None): 
PrivateRSASHA512,
+            (Algorithm.ECDSAP256SHA256, None): PrivateECDSAP256SHA256,
+            (Algorithm.ECDSAP384SHA384, None): PrivateECDSAP384SHA384,
+            (Algorithm.ED25519, None): PrivateED25519,
+            (Algorithm.ED448, None): PrivateED448,
+        }
+    )
+
+
+def get_algorithm_cls(
+    algorithm: Union[int, str], prefix: AlgorithmPrefix = None
+) -> Type[GenericPrivateKey]:
+    """Get Private Key class from Algorithm.
+
+    *algorithm*, a ``str`` or ``int`` specifying the DNSKEY algorithm.
+
+    Raises ``UnsupportedAlgorithm`` if the algorithm is unknown.
+
+    Returns a ``dns.dnssecalgs.GenericPrivateKey``
+    """
+    algorithm = Algorithm.make(algorithm)
+    cls = algorithms.get((algorithm, prefix))
+    if cls:
+        return cls
+    raise UnsupportedAlgorithm(
+        'algorithm "%s" not supported by dnspython' % Algorithm.to_text(algorithm)
+    )
+
+
+def get_algorithm_cls_from_dnskey(dnskey: DNSKEY) -> Type[GenericPrivateKey]:
+    """Get Private Key class from DNSKEY.
+
+    *dnskey*, a ``DNSKEY`` to get the algorithm class for.
+
+    Raises ``UnsupportedAlgorithm`` if the algorithm is unknown.
+
+    Returns a ``dns.dnssecalgs.GenericPrivateKey``
+    """
+    prefix: AlgorithmPrefix = None
+    if dnskey.algorithm == Algorithm.PRIVATEDNS:
+        prefix, _ = dns.name.from_wire(dnskey.key, 0)
+    elif dnskey.algorithm == Algorithm.PRIVATEOID:
+        length = int(dnskey.key[0])
+        prefix = dnskey.key[0 : length + 1]
+    return get_algorithm_cls(dnskey.algorithm, prefix)
+
+
+def register_algorithm_cls(
+    algorithm: Union[int, str],
+    algorithm_cls: Type[GenericPrivateKey],
+    name: Optional[Union[dns.name.Name, str]] = None,
+    oid: Optional[bytes] = None,
+) -> None:
+    """Register Algorithm Private Key class.
+
+    *algorithm*, a ``str`` or ``int`` specifying the DNSKEY algorithm.
+
+    *algorithm_cls*: A ``GenericPrivateKey`` class.
+
+    *name*, an optional ``dns.name.Name`` or ``str``, for PRIVATEDNS algorithms.
+
+    *oid*: an optional BER-encoded ``bytes`` for PRIVATEOID algorithms.
+
+    Raises ``ValueError`` if a name or oid is specified incorrectly.
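+
+    Example, a hypothetical sketch (``PrivateMyAlg`` stands in for a
+    user-defined ``GenericPrivateKey`` subclass)::
+
+        register_algorithm_cls(
+            algorithm=Algorithm.PRIVATEDNS,
+            algorithm_cls=PrivateMyAlg,
+            name="myalg.example.",
+        )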
+ """ + if not issubclass(algorithm_cls, GenericPrivateKey): + raise TypeError("Invalid algorithm class") + algorithm = Algorithm.make(algorithm) + prefix: AlgorithmPrefix = None + if algorithm == Algorithm.PRIVATEDNS: + if name is None: + raise ValueError("Name required for PRIVATEDNS algorithms") + if isinstance(name, str): + name = dns.name.from_text(name) + prefix = name + elif algorithm == Algorithm.PRIVATEOID: + if oid is None: + raise ValueError("OID required for PRIVATEOID algorithms") + prefix = bytes([len(oid)]) + oid + elif name: + raise ValueError("Name only supported for PRIVATEDNS algorithm") + elif oid: + raise ValueError("OID only supported for PRIVATEOID algorithm") + algorithms[(algorithm, prefix)] = algorithm_cls diff --git a/venv/Lib/site-packages/dns/dnssecalgs/__pycache__/__init__.cpython-312.pyc b/venv/Lib/site-packages/dns/dnssecalgs/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 00000000..f9233e9f Binary files /dev/null and b/venv/Lib/site-packages/dns/dnssecalgs/__pycache__/__init__.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/dnssecalgs/__pycache__/base.cpython-312.pyc b/venv/Lib/site-packages/dns/dnssecalgs/__pycache__/base.cpython-312.pyc new file mode 100644 index 00000000..21a35e7f Binary files /dev/null and b/venv/Lib/site-packages/dns/dnssecalgs/__pycache__/base.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/dnssecalgs/__pycache__/cryptography.cpython-312.pyc b/venv/Lib/site-packages/dns/dnssecalgs/__pycache__/cryptography.cpython-312.pyc new file mode 100644 index 00000000..6788e3d2 Binary files /dev/null and b/venv/Lib/site-packages/dns/dnssecalgs/__pycache__/cryptography.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/dnssecalgs/__pycache__/dsa.cpython-312.pyc b/venv/Lib/site-packages/dns/dnssecalgs/__pycache__/dsa.cpython-312.pyc new file mode 100644 index 00000000..3d042c6b Binary files /dev/null and b/venv/Lib/site-packages/dns/dnssecalgs/__pycache__/dsa.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/dnssecalgs/__pycache__/ecdsa.cpython-312.pyc b/venv/Lib/site-packages/dns/dnssecalgs/__pycache__/ecdsa.cpython-312.pyc new file mode 100644 index 00000000..3f3329c7 Binary files /dev/null and b/venv/Lib/site-packages/dns/dnssecalgs/__pycache__/ecdsa.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/dnssecalgs/__pycache__/eddsa.cpython-312.pyc b/venv/Lib/site-packages/dns/dnssecalgs/__pycache__/eddsa.cpython-312.pyc new file mode 100644 index 00000000..252a3925 Binary files /dev/null and b/venv/Lib/site-packages/dns/dnssecalgs/__pycache__/eddsa.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/dnssecalgs/__pycache__/rsa.cpython-312.pyc b/venv/Lib/site-packages/dns/dnssecalgs/__pycache__/rsa.cpython-312.pyc new file mode 100644 index 00000000..393e8da3 Binary files /dev/null and b/venv/Lib/site-packages/dns/dnssecalgs/__pycache__/rsa.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/dnssecalgs/base.py b/venv/Lib/site-packages/dns/dnssecalgs/base.py new file mode 100644 index 00000000..e990575a --- /dev/null +++ b/venv/Lib/site-packages/dns/dnssecalgs/base.py @@ -0,0 +1,84 @@ +from abc import ABC, abstractmethod # pylint: disable=no-name-in-module +from typing import Any, Optional, Type + +import dns.rdataclass +import dns.rdatatype +from dns.dnssectypes import Algorithm +from dns.exception import AlgorithmKeyMismatch +from dns.rdtypes.ANY.DNSKEY import DNSKEY +from dns.rdtypes.dnskeybase import Flag + + +class GenericPublicKey(ABC): + 
algorithm: Algorithm
+
+    @abstractmethod
+    def __init__(self, key: Any) -> None:
+        pass
+
+    @abstractmethod
+    def verify(self, signature: bytes, data: bytes) -> None:
+        """Verify signed DNSSEC data"""
+
+    @abstractmethod
+    def encode_key_bytes(self) -> bytes:
+        """Encode key as bytes for DNSKEY"""
+
+    @classmethod
+    def _ensure_algorithm_key_combination(cls, key: DNSKEY) -> None:
+        if key.algorithm != cls.algorithm:
+            raise AlgorithmKeyMismatch
+
+    def to_dnskey(self, flags: int = Flag.ZONE, protocol: int = 3) -> DNSKEY:
+        """Return public key as DNSKEY"""
+        return DNSKEY(
+            rdclass=dns.rdataclass.IN,
+            rdtype=dns.rdatatype.DNSKEY,
+            flags=flags,
+            protocol=protocol,
+            algorithm=self.algorithm,
+            key=self.encode_key_bytes(),
+        )
+
+    @classmethod
+    @abstractmethod
+    def from_dnskey(cls, key: DNSKEY) -> "GenericPublicKey":
+        """Create public key from DNSKEY"""
+
+    @classmethod
+    @abstractmethod
+    def from_pem(cls, public_pem: bytes) -> "GenericPublicKey":
+        """Create public key from PEM-encoded SubjectPublicKeyInfo as specified
+        in RFC 5280"""
+
+    @abstractmethod
+    def to_pem(self) -> bytes:
+        """Return public key as PEM-encoded SubjectPublicKeyInfo as specified
+        in RFC 5280"""
+
+
+class GenericPrivateKey(ABC):
+    public_cls: Type[GenericPublicKey]
+
+    @abstractmethod
+    def __init__(self, key: Any) -> None:
+        pass
+
+    @abstractmethod
+    def sign(self, data: bytes, verify: bool = False) -> bytes:
+        """Sign DNSSEC data"""
+
+    @abstractmethod
+    def public_key(self) -> "GenericPublicKey":
+        """Return public key instance"""
+
+    @classmethod
+    @abstractmethod
+    def from_pem(
+        cls, private_pem: bytes, password: Optional[bytes] = None
+    ) -> "GenericPrivateKey":
+        """Create private key from PEM-encoded PKCS#8"""
+
+    @abstractmethod
+    def to_pem(self, password: Optional[bytes] = None) -> bytes:
+        """Return private key as PEM-encoded PKCS#8"""
diff --git a/venv/Lib/site-packages/dns/dnssecalgs/cryptography.py b/venv/Lib/site-packages/dns/dnssecalgs/cryptography.py
new file mode 100644
index 00000000..5a31a812
--- /dev/null
+++ b/venv/Lib/site-packages/dns/dnssecalgs/cryptography.py
@@ -0,0 +1,68 @@
+from typing import Any, Optional, Type
+
+from cryptography.hazmat.primitives import serialization
+
+from dns.dnssecalgs.base import GenericPrivateKey, GenericPublicKey
+from dns.exception import AlgorithmKeyMismatch
+
+
+class CryptographyPublicKey(GenericPublicKey):
+    key: Any = None
+    key_cls: Any = None
+
+    def __init__(self, key: Any) -> None:  # pylint: disable=super-init-not-called
+        if self.key_cls is None:
+            raise TypeError("Undefined public key class")
+        if not isinstance(  # pylint: disable=isinstance-second-argument-not-valid-type
+            key, self.key_cls
+        ):
+            raise AlgorithmKeyMismatch
+        self.key = key
+
+    @classmethod
+    def from_pem(cls, public_pem: bytes) -> "GenericPublicKey":
+        key = serialization.load_pem_public_key(public_pem)
+        return cls(key=key)
+
+    def to_pem(self) -> bytes:
+        return self.key.public_bytes(
+            encoding=serialization.Encoding.PEM,
+            format=serialization.PublicFormat.SubjectPublicKeyInfo,
+        )
+
+
+class CryptographyPrivateKey(GenericPrivateKey):
+    key: Any = None
+    key_cls: Any = None
+    public_cls: Type[CryptographyPublicKey]
+
+    def __init__(self, key: Any) -> None:  # pylint: disable=super-init-not-called
+        if self.key_cls is None:
+            raise TypeError("Undefined private key class")
+        if not isinstance(  # pylint: disable=isinstance-second-argument-not-valid-type
+            key, self.key_cls
+        ):
+            raise AlgorithmKeyMismatch
+        self.key = key
+
+    def public_key(self) ->
"CryptographyPublicKey": + return self.public_cls(key=self.key.public_key()) + + @classmethod + def from_pem( + cls, private_pem: bytes, password: Optional[bytes] = None + ) -> "GenericPrivateKey": + key = serialization.load_pem_private_key(private_pem, password=password) + return cls(key=key) + + def to_pem(self, password: Optional[bytes] = None) -> bytes: + encryption_algorithm: serialization.KeySerializationEncryption + if password: + encryption_algorithm = serialization.BestAvailableEncryption(password) + else: + encryption_algorithm = serialization.NoEncryption() + return self.key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=encryption_algorithm, + ) diff --git a/venv/Lib/site-packages/dns/dnssecalgs/dsa.py b/venv/Lib/site-packages/dns/dnssecalgs/dsa.py new file mode 100644 index 00000000..0fe4690d --- /dev/null +++ b/venv/Lib/site-packages/dns/dnssecalgs/dsa.py @@ -0,0 +1,101 @@ +import struct + +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.asymmetric import dsa, utils + +from dns.dnssecalgs.cryptography import CryptographyPrivateKey, CryptographyPublicKey +from dns.dnssectypes import Algorithm +from dns.rdtypes.ANY.DNSKEY import DNSKEY + + +class PublicDSA(CryptographyPublicKey): + key: dsa.DSAPublicKey + key_cls = dsa.DSAPublicKey + algorithm = Algorithm.DSA + chosen_hash = hashes.SHA1() + + def verify(self, signature: bytes, data: bytes) -> None: + sig_r = signature[1:21] + sig_s = signature[21:] + sig = utils.encode_dss_signature( + int.from_bytes(sig_r, "big"), int.from_bytes(sig_s, "big") + ) + self.key.verify(sig, data, self.chosen_hash) + + def encode_key_bytes(self) -> bytes: + """Encode a public key per RFC 2536, section 2.""" + pn = self.key.public_numbers() + dsa_t = (self.key.key_size // 8 - 64) // 8 + if dsa_t > 8: + raise ValueError("unsupported DSA key size") + octets = 64 + dsa_t * 8 + res = struct.pack("!B", dsa_t) + res += pn.parameter_numbers.q.to_bytes(20, "big") + res += pn.parameter_numbers.p.to_bytes(octets, "big") + res += pn.parameter_numbers.g.to_bytes(octets, "big") + res += pn.y.to_bytes(octets, "big") + return res + + @classmethod + def from_dnskey(cls, key: DNSKEY) -> "PublicDSA": + cls._ensure_algorithm_key_combination(key) + keyptr = key.key + (t,) = struct.unpack("!B", keyptr[0:1]) + keyptr = keyptr[1:] + octets = 64 + t * 8 + dsa_q = keyptr[0:20] + keyptr = keyptr[20:] + dsa_p = keyptr[0:octets] + keyptr = keyptr[octets:] + dsa_g = keyptr[0:octets] + keyptr = keyptr[octets:] + dsa_y = keyptr[0:octets] + return cls( + key=dsa.DSAPublicNumbers( # type: ignore + int.from_bytes(dsa_y, "big"), + dsa.DSAParameterNumbers( + int.from_bytes(dsa_p, "big"), + int.from_bytes(dsa_q, "big"), + int.from_bytes(dsa_g, "big"), + ), + ).public_key(default_backend()), + ) + + +class PrivateDSA(CryptographyPrivateKey): + key: dsa.DSAPrivateKey + key_cls = dsa.DSAPrivateKey + public_cls = PublicDSA + + def sign(self, data: bytes, verify: bool = False) -> bytes: + """Sign using a private key per RFC 2536, section 3.""" + public_dsa_key = self.key.public_key() + if public_dsa_key.key_size > 1024: + raise ValueError("DSA key size overflow") + der_signature = self.key.sign(data, self.public_cls.chosen_hash) + dsa_r, dsa_s = utils.decode_dss_signature(der_signature) + dsa_t = (public_dsa_key.key_size // 8 - 64) // 8 + octets = 20 + signature = ( + struct.pack("!B", dsa_t) + + int.to_bytes(dsa_r, 
length=octets, byteorder="big") + + int.to_bytes(dsa_s, length=octets, byteorder="big") + ) + if verify: + self.public_key().verify(signature, data) + return signature + + @classmethod + def generate(cls, key_size: int) -> "PrivateDSA": + return cls( + key=dsa.generate_private_key(key_size=key_size), + ) + + +class PublicDSANSEC3SHA1(PublicDSA): + algorithm = Algorithm.DSANSEC3SHA1 + + +class PrivateDSANSEC3SHA1(PrivateDSA): + public_cls = PublicDSANSEC3SHA1 diff --git a/venv/Lib/site-packages/dns/dnssecalgs/ecdsa.py b/venv/Lib/site-packages/dns/dnssecalgs/ecdsa.py new file mode 100644 index 00000000..a31d79f2 --- /dev/null +++ b/venv/Lib/site-packages/dns/dnssecalgs/ecdsa.py @@ -0,0 +1,89 @@ +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.asymmetric import ec, utils + +from dns.dnssecalgs.cryptography import CryptographyPrivateKey, CryptographyPublicKey +from dns.dnssectypes import Algorithm +from dns.rdtypes.ANY.DNSKEY import DNSKEY + + +class PublicECDSA(CryptographyPublicKey): + key: ec.EllipticCurvePublicKey + key_cls = ec.EllipticCurvePublicKey + algorithm: Algorithm + chosen_hash: hashes.HashAlgorithm + curve: ec.EllipticCurve + octets: int + + def verify(self, signature: bytes, data: bytes) -> None: + sig_r = signature[0 : self.octets] + sig_s = signature[self.octets :] + sig = utils.encode_dss_signature( + int.from_bytes(sig_r, "big"), int.from_bytes(sig_s, "big") + ) + self.key.verify(sig, data, ec.ECDSA(self.chosen_hash)) + + def encode_key_bytes(self) -> bytes: + """Encode a public key per RFC 6605, section 4.""" + pn = self.key.public_numbers() + return pn.x.to_bytes(self.octets, "big") + pn.y.to_bytes(self.octets, "big") + + @classmethod + def from_dnskey(cls, key: DNSKEY) -> "PublicECDSA": + cls._ensure_algorithm_key_combination(key) + ecdsa_x = key.key[0 : cls.octets] + ecdsa_y = key.key[cls.octets : cls.octets * 2] + return cls( + key=ec.EllipticCurvePublicNumbers( + curve=cls.curve, + x=int.from_bytes(ecdsa_x, "big"), + y=int.from_bytes(ecdsa_y, "big"), + ).public_key(default_backend()), + ) + + +class PrivateECDSA(CryptographyPrivateKey): + key: ec.EllipticCurvePrivateKey + key_cls = ec.EllipticCurvePrivateKey + public_cls = PublicECDSA + + def sign(self, data: bytes, verify: bool = False) -> bytes: + """Sign using a private key per RFC 6605, section 4.""" + der_signature = self.key.sign(data, ec.ECDSA(self.public_cls.chosen_hash)) + dsa_r, dsa_s = utils.decode_dss_signature(der_signature) + signature = int.to_bytes( + dsa_r, length=self.public_cls.octets, byteorder="big" + ) + int.to_bytes(dsa_s, length=self.public_cls.octets, byteorder="big") + if verify: + self.public_key().verify(signature, data) + return signature + + @classmethod + def generate(cls) -> "PrivateECDSA": + return cls( + key=ec.generate_private_key( + curve=cls.public_cls.curve, backend=default_backend() + ), + ) + + +class PublicECDSAP256SHA256(PublicECDSA): + algorithm = Algorithm.ECDSAP256SHA256 + chosen_hash = hashes.SHA256() + curve = ec.SECP256R1() + octets = 32 + + +class PrivateECDSAP256SHA256(PrivateECDSA): + public_cls = PublicECDSAP256SHA256 + + +class PublicECDSAP384SHA384(PublicECDSA): + algorithm = Algorithm.ECDSAP384SHA384 + chosen_hash = hashes.SHA384() + curve = ec.SECP384R1() + octets = 48 + + +class PrivateECDSAP384SHA384(PrivateECDSA): + public_cls = PublicECDSAP384SHA384 diff --git a/venv/Lib/site-packages/dns/dnssecalgs/eddsa.py b/venv/Lib/site-packages/dns/dnssecalgs/eddsa.py new 
file mode 100644 index 00000000..70505342 --- /dev/null +++ b/venv/Lib/site-packages/dns/dnssecalgs/eddsa.py @@ -0,0 +1,65 @@ +from typing import Type + +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import ed448, ed25519 + +from dns.dnssecalgs.cryptography import CryptographyPrivateKey, CryptographyPublicKey +from dns.dnssectypes import Algorithm +from dns.rdtypes.ANY.DNSKEY import DNSKEY + + +class PublicEDDSA(CryptographyPublicKey): + def verify(self, signature: bytes, data: bytes) -> None: + self.key.verify(signature, data) + + def encode_key_bytes(self) -> bytes: + """Encode a public key per RFC 8080, section 3.""" + return self.key.public_bytes( + encoding=serialization.Encoding.Raw, format=serialization.PublicFormat.Raw + ) + + @classmethod + def from_dnskey(cls, key: DNSKEY) -> "PublicEDDSA": + cls._ensure_algorithm_key_combination(key) + return cls( + key=cls.key_cls.from_public_bytes(key.key), + ) + + +class PrivateEDDSA(CryptographyPrivateKey): + public_cls: Type[PublicEDDSA] + + def sign(self, data: bytes, verify: bool = False) -> bytes: + """Sign using a private key per RFC 8080, section 4.""" + signature = self.key.sign(data) + if verify: + self.public_key().verify(signature, data) + return signature + + @classmethod + def generate(cls) -> "PrivateEDDSA": + return cls(key=cls.key_cls.generate()) + + +class PublicED25519(PublicEDDSA): + key: ed25519.Ed25519PublicKey + key_cls = ed25519.Ed25519PublicKey + algorithm = Algorithm.ED25519 + + +class PrivateED25519(PrivateEDDSA): + key: ed25519.Ed25519PrivateKey + key_cls = ed25519.Ed25519PrivateKey + public_cls = PublicED25519 + + +class PublicED448(PublicEDDSA): + key: ed448.Ed448PublicKey + key_cls = ed448.Ed448PublicKey + algorithm = Algorithm.ED448 + + +class PrivateED448(PrivateEDDSA): + key: ed448.Ed448PrivateKey + key_cls = ed448.Ed448PrivateKey + public_cls = PublicED448 diff --git a/venv/Lib/site-packages/dns/dnssecalgs/rsa.py b/venv/Lib/site-packages/dns/dnssecalgs/rsa.py new file mode 100644 index 00000000..e95dcf1d --- /dev/null +++ b/venv/Lib/site-packages/dns/dnssecalgs/rsa.py @@ -0,0 +1,119 @@ +import math +import struct + +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.asymmetric import padding, rsa + +from dns.dnssecalgs.cryptography import CryptographyPrivateKey, CryptographyPublicKey +from dns.dnssectypes import Algorithm +from dns.rdtypes.ANY.DNSKEY import DNSKEY + + +class PublicRSA(CryptographyPublicKey): + key: rsa.RSAPublicKey + key_cls = rsa.RSAPublicKey + algorithm: Algorithm + chosen_hash: hashes.HashAlgorithm + + def verify(self, signature: bytes, data: bytes) -> None: + self.key.verify(signature, data, padding.PKCS1v15(), self.chosen_hash) + + def encode_key_bytes(self) -> bytes: + """Encode a public key per RFC 3110, section 2.""" + pn = self.key.public_numbers() + _exp_len = math.ceil(int.bit_length(pn.e) / 8) + exp = int.to_bytes(pn.e, length=_exp_len, byteorder="big") + if _exp_len > 255: + exp_header = b"\0" + struct.pack("!H", _exp_len) + else: + exp_header = struct.pack("!B", _exp_len) + if pn.n.bit_length() < 512 or pn.n.bit_length() > 4096: + raise ValueError("unsupported RSA key length") + return exp_header + exp + pn.n.to_bytes((pn.n.bit_length() + 7) // 8, "big") + + @classmethod + def from_dnskey(cls, key: DNSKEY) -> "PublicRSA": + cls._ensure_algorithm_key_combination(key) + keyptr = key.key + (bytes_,) = struct.unpack("!B", 
keyptr[0:1]) + keyptr = keyptr[1:] + if bytes_ == 0: + (bytes_,) = struct.unpack("!H", keyptr[0:2]) + keyptr = keyptr[2:] + rsa_e = keyptr[0:bytes_] + rsa_n = keyptr[bytes_:] + return cls( + key=rsa.RSAPublicNumbers( + int.from_bytes(rsa_e, "big"), int.from_bytes(rsa_n, "big") + ).public_key(default_backend()) + ) + + +class PrivateRSA(CryptographyPrivateKey): + key: rsa.RSAPrivateKey + key_cls = rsa.RSAPrivateKey + public_cls = PublicRSA + default_public_exponent = 65537 + + def sign(self, data: bytes, verify: bool = False) -> bytes: + """Sign using a private key per RFC 3110, section 3.""" + signature = self.key.sign(data, padding.PKCS1v15(), self.public_cls.chosen_hash) + if verify: + self.public_key().verify(signature, data) + return signature + + @classmethod + def generate(cls, key_size: int) -> "PrivateRSA": + return cls( + key=rsa.generate_private_key( + public_exponent=cls.default_public_exponent, + key_size=key_size, + backend=default_backend(), + ) + ) + + +class PublicRSAMD5(PublicRSA): + algorithm = Algorithm.RSAMD5 + chosen_hash = hashes.MD5() + + +class PrivateRSAMD5(PrivateRSA): + public_cls = PublicRSAMD5 + + +class PublicRSASHA1(PublicRSA): + algorithm = Algorithm.RSASHA1 + chosen_hash = hashes.SHA1() + + +class PrivateRSASHA1(PrivateRSA): + public_cls = PublicRSASHA1 + + +class PublicRSASHA1NSEC3SHA1(PublicRSA): + algorithm = Algorithm.RSASHA1NSEC3SHA1 + chosen_hash = hashes.SHA1() + + +class PrivateRSASHA1NSEC3SHA1(PrivateRSA): + public_cls = PublicRSASHA1NSEC3SHA1 + + +class PublicRSASHA256(PublicRSA): + algorithm = Algorithm.RSASHA256 + chosen_hash = hashes.SHA256() + + +class PrivateRSASHA256(PrivateRSA): + public_cls = PublicRSASHA256 + + +class PublicRSASHA512(PublicRSA): + algorithm = Algorithm.RSASHA512 + chosen_hash = hashes.SHA512() + + +class PrivateRSASHA512(PrivateRSA): + public_cls = PublicRSASHA512 diff --git a/venv/Lib/site-packages/dns/dnssectypes.py b/venv/Lib/site-packages/dns/dnssectypes.py new file mode 100644 index 00000000..02131e0a --- /dev/null +++ b/venv/Lib/site-packages/dns/dnssectypes.py @@ -0,0 +1,71 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2017 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +"""Common DNSSEC-related types.""" + +# This is a separate file to avoid import circularity between dns.dnssec and +# the implementations of the DS and DNSKEY types. 
+ +import dns.enum + + +class Algorithm(dns.enum.IntEnum): + RSAMD5 = 1 + DH = 2 + DSA = 3 + ECC = 4 + RSASHA1 = 5 + DSANSEC3SHA1 = 6 + RSASHA1NSEC3SHA1 = 7 + RSASHA256 = 8 + RSASHA512 = 10 + ECCGOST = 12 + ECDSAP256SHA256 = 13 + ECDSAP384SHA384 = 14 + ED25519 = 15 + ED448 = 16 + INDIRECT = 252 + PRIVATEDNS = 253 + PRIVATEOID = 254 + + @classmethod + def _maximum(cls): + return 255 + + +class DSDigest(dns.enum.IntEnum): + """DNSSEC Delegation Signer Digest Algorithm""" + + NULL = 0 + SHA1 = 1 + SHA256 = 2 + GOST = 3 + SHA384 = 4 + + @classmethod + def _maximum(cls): + return 255 + + +class NSEC3Hash(dns.enum.IntEnum): + """NSEC3 hash algorithm""" + + SHA1 = 1 + + @classmethod + def _maximum(cls): + return 255 diff --git a/venv/Lib/site-packages/dns/e164.py b/venv/Lib/site-packages/dns/e164.py new file mode 100644 index 00000000..453736d4 --- /dev/null +++ b/venv/Lib/site-packages/dns/e164.py @@ -0,0 +1,116 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2006-2017 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +"""DNS E.164 helpers.""" + +from typing import Iterable, Optional, Union + +import dns.exception +import dns.name +import dns.resolver + +#: The public E.164 domain. +public_enum_domain = dns.name.from_text("e164.arpa.") + + +def from_e164( + text: str, origin: Optional[dns.name.Name] = public_enum_domain +) -> dns.name.Name: + """Convert an E.164 number in textual form into a Name object whose + value is the ENUM domain name for that number. + + Non-digits in the text are ignored, i.e. "16505551212", + "+1.650.555.1212" and "1 (650) 555-1212" are all the same. + + *text*, a ``str``, is an E.164 number in textual form. + + *origin*, a ``dns.name.Name``, the domain in which the number + should be constructed. The default is ``e164.arpa.``. + + Returns a ``dns.name.Name``. + """ + + parts = [d for d in text if d.isdigit()] + parts.reverse() + return dns.name.from_text(".".join(parts), origin=origin) + + +def to_e164( + name: dns.name.Name, + origin: Optional[dns.name.Name] = public_enum_domain, + want_plus_prefix: bool = True, +) -> str: + """Convert an ENUM domain name into an E.164 number. + + Note that dnspython does not have any information about preferred + number formats within national numbering plans, so all numbers are + emitted as a simple string of digits, prefixed by a '+' (unless + *want_plus_prefix* is ``False``). + + *name* is a ``dns.name.Name``, the ENUM domain name. + + *origin* is a ``dns.name.Name``, a domain containing the ENUM + domain name. The name is relativized to this domain before being + converted to text. If ``None``, no relativization is done. + + *want_plus_prefix* is a ``bool``. If True, add a '+' to the beginning of + the returned number. + + Returns a ``str``. 
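+
+    Example (round trip with ``from_e164``)::
+
+        >>> import dns.e164
+        >>> n = dns.e164.from_e164('+1.650.555.1212')
+        >>> dns.e164.to_e164(n)
+        '+16505551212'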
+
+    """
+    if origin is not None:
+        name = name.relativize(origin)
+    dlabels = [d for d in name.labels if d.isdigit() and len(d) == 1]
+    if len(dlabels) != len(name.labels):
+        raise dns.exception.SyntaxError("non-digit labels in ENUM domain name")
+    dlabels.reverse()
+    text = b"".join(dlabels)
+    if want_plus_prefix:
+        text = b"+" + text
+    return text.decode()
+
+
+def query(
+    number: str,
+    domains: Iterable[Union[dns.name.Name, str]],
+    resolver: Optional[dns.resolver.Resolver] = None,
+) -> dns.resolver.Answer:
+    """Look for NAPTR RRs for the specified number in the specified domains.
+
+    e.g. query('16505551212', ['e164.dnspython.org.', 'e164.arpa.'])
+
+    *number*, a ``str``, is the number to look for.
+
+    *domains* is an iterable containing ``dns.name.Name`` values.
+
+    *resolver*, a ``dns.resolver.Resolver``, is the resolver to use.  If
+    ``None``, the default resolver is used.
+    """
+
+    if resolver is None:
+        resolver = dns.resolver.get_default_resolver()
+    e_nx = dns.resolver.NXDOMAIN()
+    for domain in domains:
+        if isinstance(domain, str):
+            domain = dns.name.from_text(domain)
+        qname = dns.e164.from_e164(number, domain)
+        try:
+            return resolver.resolve(qname, "NAPTR")
+        except dns.resolver.NXDOMAIN as e:
+            e_nx += e
+    raise e_nx
diff --git a/venv/Lib/site-packages/dns/edns.py b/venv/Lib/site-packages/dns/edns.py
new file mode 100644
index 00000000..776e5eeb
--- /dev/null
+++ b/venv/Lib/site-packages/dns/edns.py
@@ -0,0 +1,516 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2009-2017 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS.  IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
+"""EDNS Options"""
+
+import binascii
+import math
+import socket
+import struct
+from typing import Any, Dict, Optional, Union
+
+import dns.enum
+import dns.inet
+import dns.rdata
+import dns.wire
+
+
+class OptionType(dns.enum.IntEnum):
+    #: NSID
+    NSID = 3
+    #: DAU
+    DAU = 5
+    #: DHU
+    DHU = 6
+    #: N3U
+    N3U = 7
+    #: ECS (client-subnet)
+    ECS = 8
+    #: EXPIRE
+    EXPIRE = 9
+    #: COOKIE
+    COOKIE = 10
+    #: KEEPALIVE
+    KEEPALIVE = 11
+    #: PADDING
+    PADDING = 12
+    #: CHAIN
+    CHAIN = 13
+    #: EDE (extended-dns-error)
+    EDE = 15
+
+    @classmethod
+    def _maximum(cls):
+        return 65535
+
+
+class Option:
+    """Base class for all EDNS option types."""
+
+    def __init__(self, otype: Union[OptionType, str]):
+        """Initialize an option.
+
+        *otype*, a ``dns.edns.OptionType``, is the option type.
+        """
+        self.otype = OptionType.make(otype)
+
+    def to_wire(self, file: Optional[Any] = None) -> Optional[bytes]:
+        """Convert an option to wire format.
+
+        Returns a ``bytes`` or ``None``.
+
+        """
+        raise NotImplementedError  # pragma: no cover
+
+    def to_text(self) -> str:
+        raise NotImplementedError  # pragma: no cover
+
+    @classmethod
+    def from_wire_parser(cls, otype: OptionType, parser: "dns.wire.Parser") -> "Option":
+        """Build an EDNS option object from wire format.
+
+        *otype*, a ``dns.edns.OptionType``, is the option type.
+
+        *parser*, a ``dns.wire.Parser``, the parser, which should be
+        restricted to the option length.
+
+        Returns a ``dns.edns.Option``.
+        """
+        raise NotImplementedError  # pragma: no cover
+
+    def _cmp(self, other):
+        """Compare an EDNS option with another option of the same type.
+
+        Returns < 0 if < *other*, 0 if == *other*, and > 0 if > *other*.
+        """
+        wire = self.to_wire()
+        owire = other.to_wire()
+        if wire == owire:
+            return 0
+        if wire > owire:
+            return 1
+        return -1
+
+    def __eq__(self, other):
+        if not isinstance(other, Option):
+            return False
+        if self.otype != other.otype:
+            return False
+        return self._cmp(other) == 0
+
+    def __ne__(self, other):
+        if not isinstance(other, Option):
+            return True
+        if self.otype != other.otype:
+            return True
+        return self._cmp(other) != 0
+
+    def __lt__(self, other):
+        if not isinstance(other, Option) or self.otype != other.otype:
+            return NotImplemented
+        return self._cmp(other) < 0
+
+    def __le__(self, other):
+        if not isinstance(other, Option) or self.otype != other.otype:
+            return NotImplemented
+        return self._cmp(other) <= 0
+
+    def __ge__(self, other):
+        if not isinstance(other, Option) or self.otype != other.otype:
+            return NotImplemented
+        return self._cmp(other) >= 0
+
+    def __gt__(self, other):
+        if not isinstance(other, Option) or self.otype != other.otype:
+            return NotImplemented
+        return self._cmp(other) > 0
+
+    def __str__(self):
+        return self.to_text()
+
+
+class GenericOption(Option):  # lgtm[py/missing-equals]
+    """Generic Option Class
+
+    This class is used for EDNS option types for which we have no better
+    implementation.
+    """
+
+    def __init__(self, otype: Union[OptionType, str], data: Union[bytes, str]):
+        super().__init__(otype)
+        self.data = dns.rdata.Rdata._as_bytes(data, True)
+
+    def to_wire(self, file: Optional[Any] = None) -> Optional[bytes]:
+        if file:
+            file.write(self.data)
+            return None
+        else:
+            return self.data
+
+    def to_text(self) -> str:
+        return "Generic %d" % self.otype
+
+    @classmethod
+    def from_wire_parser(
+        cls, otype: Union[OptionType, str], parser: "dns.wire.Parser"
+    ) -> Option:
+        return cls(otype, parser.get_remaining())
+
+
+class ECSOption(Option):  # lgtm[py/missing-equals]
+    """EDNS Client Subnet (ECS, RFC7871)"""
+
+    def __init__(self, address: str, srclen: Optional[int] = None, scopelen: int = 0):
+        """*address*, a ``str``, is the client address information.
+
+        *srclen*, an ``int``, the source prefix length, which is the
+        leftmost number of bits of the address to be used for the
+        lookup.  The default is 24 for IPv4 and 56 for IPv6.
+
+        *scopelen*, an ``int``, the scope prefix length.  This value
+        must be 0 in queries, and should be set in responses.
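+
+        Example::
+
+            >>> import dns.edns
+            >>> dns.edns.ECSOption('1.2.3.4', 24).to_text()
+            'ECS 1.2.3.4/24 scope/0'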
+ """ + + super().__init__(OptionType.ECS) + af = dns.inet.af_for_address(address) + + if af == socket.AF_INET6: + self.family = 2 + if srclen is None: + srclen = 56 + address = dns.rdata.Rdata._as_ipv6_address(address) + srclen = dns.rdata.Rdata._as_int(srclen, 0, 128) + scopelen = dns.rdata.Rdata._as_int(scopelen, 0, 128) + elif af == socket.AF_INET: + self.family = 1 + if srclen is None: + srclen = 24 + address = dns.rdata.Rdata._as_ipv4_address(address) + srclen = dns.rdata.Rdata._as_int(srclen, 0, 32) + scopelen = dns.rdata.Rdata._as_int(scopelen, 0, 32) + else: # pragma: no cover (this will never happen) + raise ValueError("Bad address family") + + assert srclen is not None + self.address = address + self.srclen = srclen + self.scopelen = scopelen + + addrdata = dns.inet.inet_pton(af, address) + nbytes = int(math.ceil(srclen / 8.0)) + + # Truncate to srclen and pad to the end of the last octet needed + # See RFC section 6 + self.addrdata = addrdata[:nbytes] + nbits = srclen % 8 + if nbits != 0: + last = struct.pack("B", ord(self.addrdata[-1:]) & (0xFF << (8 - nbits))) + self.addrdata = self.addrdata[:-1] + last + + def to_text(self) -> str: + return "ECS {}/{} scope/{}".format(self.address, self.srclen, self.scopelen) + + @staticmethod + def from_text(text: str) -> Option: + """Convert a string into a `dns.edns.ECSOption` + + *text*, a `str`, the text form of the option. + + Returns a `dns.edns.ECSOption`. + + Examples: + + >>> import dns.edns + >>> + >>> # basic example + >>> dns.edns.ECSOption.from_text('1.2.3.4/24') + >>> + >>> # also understands scope + >>> dns.edns.ECSOption.from_text('1.2.3.4/24/32') + >>> + >>> # IPv6 + >>> dns.edns.ECSOption.from_text('2001:4b98::1/64/64') + >>> + >>> # it understands results from `dns.edns.ECSOption.to_text()` + >>> dns.edns.ECSOption.from_text('ECS 1.2.3.4/24/32') + """ + optional_prefix = "ECS" + tokens = text.split() + ecs_text = None + if len(tokens) == 1: + ecs_text = tokens[0] + elif len(tokens) == 2: + if tokens[0] != optional_prefix: + raise ValueError('could not parse ECS from "{}"'.format(text)) + ecs_text = tokens[1] + else: + raise ValueError('could not parse ECS from "{}"'.format(text)) + n_slashes = ecs_text.count("/") + if n_slashes == 1: + address, tsrclen = ecs_text.split("/") + tscope = "0" + elif n_slashes == 2: + address, tsrclen, tscope = ecs_text.split("/") + else: + raise ValueError('could not parse ECS from "{}"'.format(text)) + try: + scope = int(tscope) + except ValueError: + raise ValueError( + "invalid scope " + '"{}": scope must be an integer'.format(tscope) + ) + try: + srclen = int(tsrclen) + except ValueError: + raise ValueError( + "invalid srclen " + '"{}": srclen must be an integer'.format(tsrclen) + ) + return ECSOption(address, srclen, scope) + + def to_wire(self, file: Optional[Any] = None) -> Optional[bytes]: + value = ( + struct.pack("!HBB", self.family, self.srclen, self.scopelen) + self.addrdata + ) + if file: + file.write(value) + return None + else: + return value + + @classmethod + def from_wire_parser( + cls, otype: Union[OptionType, str], parser: "dns.wire.Parser" + ) -> Option: + family, src, scope = parser.get_struct("!HBB") + addrlen = int(math.ceil(src / 8.0)) + prefix = parser.get_bytes(addrlen) + if family == 1: + pad = 4 - addrlen + addr = dns.ipv4.inet_ntoa(prefix + b"\x00" * pad) + elif family == 2: + pad = 16 - addrlen + addr = dns.ipv6.inet_ntoa(prefix + b"\x00" * pad) + else: + raise ValueError("unsupported family") + + return cls(addr, src, scope) + + +class 
EDECode(dns.enum.IntEnum): + OTHER = 0 + UNSUPPORTED_DNSKEY_ALGORITHM = 1 + UNSUPPORTED_DS_DIGEST_TYPE = 2 + STALE_ANSWER = 3 + FORGED_ANSWER = 4 + DNSSEC_INDETERMINATE = 5 + DNSSEC_BOGUS = 6 + SIGNATURE_EXPIRED = 7 + SIGNATURE_NOT_YET_VALID = 8 + DNSKEY_MISSING = 9 + RRSIGS_MISSING = 10 + NO_ZONE_KEY_BIT_SET = 11 + NSEC_MISSING = 12 + CACHED_ERROR = 13 + NOT_READY = 14 + BLOCKED = 15 + CENSORED = 16 + FILTERED = 17 + PROHIBITED = 18 + STALE_NXDOMAIN_ANSWER = 19 + NOT_AUTHORITATIVE = 20 + NOT_SUPPORTED = 21 + NO_REACHABLE_AUTHORITY = 22 + NETWORK_ERROR = 23 + INVALID_DATA = 24 + + @classmethod + def _maximum(cls): + return 65535 + + +class EDEOption(Option): # lgtm[py/missing-equals] + """Extended DNS Error (EDE, RFC8914)""" + + _preserve_case = {"DNSKEY", "DS", "DNSSEC", "RRSIGs", "NSEC", "NXDOMAIN"} + + def __init__(self, code: Union[EDECode, str], text: Optional[str] = None): + """*code*, a ``dns.edns.EDECode`` or ``str``, the info code of the + extended error. + + *text*, a ``str`` or ``None``, specifying additional information about + the error. + """ + + super().__init__(OptionType.EDE) + + self.code = EDECode.make(code) + if text is not None and not isinstance(text, str): + raise ValueError("text must be string or None") + self.text = text + + def to_text(self) -> str: + output = f"EDE {self.code}" + if self.code in EDECode: + desc = EDECode.to_text(self.code) + desc = " ".join( + word if word in self._preserve_case else word.title() + for word in desc.split("_") + ) + output += f" ({desc})" + if self.text is not None: + output += f": {self.text}" + return output + + def to_wire(self, file: Optional[Any] = None) -> Optional[bytes]: + value = struct.pack("!H", self.code) + if self.text is not None: + value += self.text.encode("utf8") + + if file: + file.write(value) + return None + else: + return value + + @classmethod + def from_wire_parser( + cls, otype: Union[OptionType, str], parser: "dns.wire.Parser" + ) -> Option: + code = EDECode.make(parser.get_uint16()) + text = parser.get_remaining() + + if text: + if text[-1] == 0: # text MAY be null-terminated + text = text[:-1] + btext = text.decode("utf8") + else: + btext = None + + return cls(code, btext) + + +class NSIDOption(Option): + def __init__(self, nsid: bytes): + super().__init__(OptionType.NSID) + self.nsid = nsid + + def to_wire(self, file: Any = None) -> Optional[bytes]: + if file: + file.write(self.nsid) + return None + else: + return self.nsid + + def to_text(self) -> str: + if all(c >= 0x20 and c <= 0x7E for c in self.nsid): + # All ASCII printable, so it's probably a string. + value = self.nsid.decode() + else: + value = binascii.hexlify(self.nsid).decode() + return f"NSID {value}" + + @classmethod + def from_wire_parser( + cls, otype: Union[OptionType, str], parser: dns.wire.Parser + ) -> Option: + return cls(parser.get_remaining()) + + +_type_to_class: Dict[OptionType, Any] = { + OptionType.ECS: ECSOption, + OptionType.EDE: EDEOption, + OptionType.NSID: NSIDOption, +} + + +def get_option_class(otype: OptionType) -> Any: + """Return the class for the specified option type. + + The GenericOption class is used if a more specific class is not + known. + """ + + cls = _type_to_class.get(otype) + if cls is None: + cls = GenericOption + return cls + + +def option_from_wire_parser( + otype: Union[OptionType, str], parser: "dns.wire.Parser" +) -> Option: + """Build an EDNS option object from wire format. + + *otype*, an ``int``, is the option type. 
+ + *parser*, a ``dns.wire.Parser``, the parser, which should be + restricted to the option length. + + Returns an instance of a subclass of ``dns.edns.Option``. + """ + otype = OptionType.make(otype) + cls = get_option_class(otype) + return cls.from_wire_parser(otype, parser) + + +def option_from_wire( + otype: Union[OptionType, str], wire: bytes, current: int, olen: int +) -> Option: + """Build an EDNS option object from wire format. + + *otype*, an ``int``, is the option type. + + *wire*, a ``bytes``, is the wire-format message. + + *current*, an ``int``, is the offset in *wire* of the beginning + of the rdata. + + *olen*, an ``int``, is the length of the wire-format option data + + Returns an instance of a subclass of ``dns.edns.Option``. + """ + parser = dns.wire.Parser(wire, current) + with parser.restrict_to(olen): + return option_from_wire_parser(otype, parser) + + +def register_type(implementation: Any, otype: OptionType) -> None: + """Register the implementation of an option type. + + *implementation*, a ``class``, is a subclass of ``dns.edns.Option``. + + *otype*, an ``int``, is the option type. + """ + + _type_to_class[otype] = implementation + + +### BEGIN generated OptionType constants + +NSID = OptionType.NSID +DAU = OptionType.DAU +DHU = OptionType.DHU +N3U = OptionType.N3U +ECS = OptionType.ECS +EXPIRE = OptionType.EXPIRE +COOKIE = OptionType.COOKIE +KEEPALIVE = OptionType.KEEPALIVE +PADDING = OptionType.PADDING +CHAIN = OptionType.CHAIN +EDE = OptionType.EDE + +### END generated OptionType constants diff --git a/venv/Lib/site-packages/dns/entropy.py b/venv/Lib/site-packages/dns/entropy.py new file mode 100644 index 00000000..4dcdc627 --- /dev/null +++ b/venv/Lib/site-packages/dns/entropy.py @@ -0,0 +1,130 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2009-2017 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +import hashlib +import os +import random +import threading +import time +from typing import Any, Optional + + +class EntropyPool: + # This is an entropy pool for Python implementations that do not + # have a working SystemRandom. I'm not sure there are any, but + # leaving this code doesn't hurt anything as the library code + # is used if present. 
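+    #
+    # A minimal sketch of direct use (normally only the module-level
+    # helpers defined below are called):
+    #
+    #     pool = EntropyPool()
+    #     pool.stir(os.urandom(16))   # optionally mix in extra entropy
+    #     value = pool.random_16()    # an int in [0, 65535]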
+ + def __init__(self, seed: Optional[bytes] = None): + self.pool_index = 0 + self.digest: Optional[bytearray] = None + self.next_byte = 0 + self.lock = threading.Lock() + self.hash = hashlib.sha1() + self.hash_len = 20 + self.pool = bytearray(b"\0" * self.hash_len) + if seed is not None: + self._stir(seed) + self.seeded = True + self.seed_pid = os.getpid() + else: + self.seeded = False + self.seed_pid = 0 + + def _stir(self, entropy: bytes) -> None: + for c in entropy: + if self.pool_index == self.hash_len: + self.pool_index = 0 + b = c & 0xFF + self.pool[self.pool_index] ^= b + self.pool_index += 1 + + def stir(self, entropy: bytes) -> None: + with self.lock: + self._stir(entropy) + + def _maybe_seed(self) -> None: + if not self.seeded or self.seed_pid != os.getpid(): + try: + seed = os.urandom(16) + except Exception: # pragma: no cover + try: + with open("/dev/urandom", "rb", 0) as r: + seed = r.read(16) + except Exception: + seed = str(time.time()).encode() + self.seeded = True + self.seed_pid = os.getpid() + self.digest = None + seed = bytearray(seed) + self._stir(seed) + + def random_8(self) -> int: + with self.lock: + self._maybe_seed() + if self.digest is None or self.next_byte == self.hash_len: + self.hash.update(bytes(self.pool)) + self.digest = bytearray(self.hash.digest()) + self._stir(self.digest) + self.next_byte = 0 + value = self.digest[self.next_byte] + self.next_byte += 1 + return value + + def random_16(self) -> int: + return self.random_8() * 256 + self.random_8() + + def random_32(self) -> int: + return self.random_16() * 65536 + self.random_16() + + def random_between(self, first: int, last: int) -> int: + size = last - first + 1 + if size > 4294967296: + raise ValueError("too big") + if size > 65536: + rand = self.random_32 + max = 4294967295 + elif size > 256: + rand = self.random_16 + max = 65535 + else: + rand = self.random_8 + max = 255 + return first + size * rand() // (max + 1) + + +pool = EntropyPool() + +system_random: Optional[Any] +try: + system_random = random.SystemRandom() +except Exception: # pragma: no cover + system_random = None + + +def random_16() -> int: + if system_random is not None: + return system_random.randrange(0, 65536) + else: + return pool.random_16() + + +def between(first: int, last: int) -> int: + if system_random is not None: + return system_random.randrange(first, last + 1) + else: + return pool.random_between(first, last) diff --git a/venv/Lib/site-packages/dns/enum.py b/venv/Lib/site-packages/dns/enum.py new file mode 100644 index 00000000..71461f17 --- /dev/null +++ b/venv/Lib/site-packages/dns/enum.py @@ -0,0 +1,116 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2017 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
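The module-level ``random_16()`` and ``between()`` helpers above are the public surface of ``dns.entropy``; a minimal usage sketch, assuming only that dnspython is importable:

    import dns.entropy

    query_id = dns.entropy.random_16()   # e.g. a random DNS message id
    delay = dns.entropy.between(1, 30)   # inclusive on both ends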
+ +import enum +from typing import Type, TypeVar, Union + +TIntEnum = TypeVar("TIntEnum", bound="IntEnum") + + +class IntEnum(enum.IntEnum): + @classmethod + def _missing_(cls, value): + cls._check_value(value) + val = int.__new__(cls, value) + val._name_ = cls._extra_to_text(value, None) or f"{cls._prefix()}{value}" + val._value_ = value + return val + + @classmethod + def _check_value(cls, value): + max = cls._maximum() + if not isinstance(value, int): + raise TypeError + if value < 0 or value > max: + name = cls._short_name() + raise ValueError(f"{name} must be an int between >= 0 and <= {max}") + + @classmethod + def from_text(cls: Type[TIntEnum], text: str) -> TIntEnum: + text = text.upper() + try: + return cls[text] + except KeyError: + pass + value = cls._extra_from_text(text) + if value: + return value + prefix = cls._prefix() + if text.startswith(prefix) and text[len(prefix) :].isdigit(): + value = int(text[len(prefix) :]) + cls._check_value(value) + try: + return cls(value) + except ValueError: + return value + raise cls._unknown_exception_class() + + @classmethod + def to_text(cls: Type[TIntEnum], value: int) -> str: + cls._check_value(value) + try: + text = cls(value).name + except ValueError: + text = None + text = cls._extra_to_text(value, text) + if text is None: + text = f"{cls._prefix()}{value}" + return text + + @classmethod + def make(cls: Type[TIntEnum], value: Union[int, str]) -> TIntEnum: + """Convert text or a value into an enumerated type, if possible. + + *value*, the ``int`` or ``str`` to convert. + + Raises a class-specific exception if a ``str`` is provided that + cannot be converted. + + Raises ``ValueError`` if the value is out of range. + + Returns an enumeration from the calling class corresponding to the + value, if one is defined, or an ``int`` otherwise. + """ + + if isinstance(value, str): + return cls.from_text(value) + cls._check_value(value) + return cls(value) + + @classmethod + def _maximum(cls): + raise NotImplementedError # pragma: no cover + + @classmethod + def _short_name(cls): + return cls.__name__.lower() + + @classmethod + def _prefix(cls): + return "" + + @classmethod + def _extra_from_text(cls, text): # pylint: disable=W0613 + return None + + @classmethod + def _extra_to_text(cls, value, current_text): # pylint: disable=W0613 + return current_text + + @classmethod + def _unknown_exception_class(cls): + return ValueError diff --git a/venv/Lib/site-packages/dns/exception.py b/venv/Lib/site-packages/dns/exception.py new file mode 100644 index 00000000..6982373d --- /dev/null +++ b/venv/Lib/site-packages/dns/exception.py @@ -0,0 +1,169 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2017 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +"""Common DNS Exceptions. 
+ +Dnspython modules may also define their own exceptions, which will +always be subclasses of ``DNSException``. +""" + + +from typing import Optional, Set + + +class DNSException(Exception): + """Abstract base class shared by all dnspython exceptions. + + It supports two basic modes of operation: + + a) Old/compatible mode is used if ``__init__`` was called with + empty *kwargs*. In compatible mode all *args* are passed + to the standard Python Exception class as before and all *args* are + printed by the standard ``__str__`` implementation. Class variable + ``msg`` (or doc string if ``msg`` is ``None``) is returned from ``str()`` + if *args* is empty. + + b) New/parametrized mode is used if ``__init__`` was called with + non-empty *kwargs*. + In the new mode *args* must be empty and all kwargs must match + those set in class variable ``supp_kwargs``. All kwargs are stored inside + ``self.kwargs`` and used in a new ``__str__`` implementation to construct + a formatted message based on the ``fmt`` class variable, a ``string``. + + In the simplest case it is enough to override the ``supp_kwargs`` + and ``fmt`` class variables to get nice parametrized messages. + """ + + msg: Optional[str] = None # non-parametrized message + supp_kwargs: Set[str] = set() # accepted parameters for _fmt_kwargs (sanity check) + fmt: Optional[str] = None # message parametrized with results from _fmt_kwargs + + def __init__(self, *args, **kwargs): + self._check_params(*args, **kwargs) + if kwargs: + # This call to a virtual method from __init__ is ok in our usage + self.kwargs = self._check_kwargs(**kwargs) # lgtm[py/init-calls-subclass] + self.msg = str(self) + else: + self.kwargs = dict() # defined but empty for old mode exceptions + if self.msg is None: + # doc string is better implicit message than empty string + self.msg = self.__doc__ + if args: + super().__init__(*args) + else: + super().__init__(self.msg) + + def _check_params(self, *args, **kwargs): + """Old exceptions supported only args and not kwargs. + + For sanity we do not allow to mix old and new behavior.""" + if args or kwargs: + assert bool(args) != bool( + kwargs + ), "keyword arguments are mutually exclusive with positional args" + + def _check_kwargs(self, **kwargs): + if kwargs: + assert ( + set(kwargs.keys()) == self.supp_kwargs + ), "following set of keyword args is required: %s" % (self.supp_kwargs) + return kwargs + + def _fmt_kwargs(self, **kwargs): + """Format kwargs before printing them. + + Resulting dictionary has to have keys necessary for str.format call + on fmt class variable. 
+ """ + fmtargs = {} + for kw, data in kwargs.items(): + if isinstance(data, (list, set)): + # convert list of to list of str() + fmtargs[kw] = list(map(str, data)) + if len(fmtargs[kw]) == 1: + # remove list brackets [] from single-item lists + fmtargs[kw] = fmtargs[kw].pop() + else: + fmtargs[kw] = data + return fmtargs + + def __str__(self): + if self.kwargs and self.fmt: + # provide custom message constructed from keyword arguments + fmtargs = self._fmt_kwargs(**self.kwargs) + return self.fmt.format(**fmtargs) + else: + # print *args directly in the same way as old DNSException + return super().__str__() + + +class FormError(DNSException): + """DNS message is malformed.""" + + +class SyntaxError(DNSException): + """Text input is malformed.""" + + +class UnexpectedEnd(SyntaxError): + """Text input ended unexpectedly.""" + + +class TooBig(DNSException): + """The DNS message is too big.""" + + +class Timeout(DNSException): + """The DNS operation timed out.""" + + supp_kwargs = {"timeout"} + fmt = "The DNS operation timed out after {timeout:.3f} seconds" + + # We do this as otherwise mypy complains about unexpected keyword argument + # idna_exception + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + +class UnsupportedAlgorithm(DNSException): + """The DNSSEC algorithm is not supported.""" + + +class AlgorithmKeyMismatch(UnsupportedAlgorithm): + """The DNSSEC algorithm is not supported for the given key type.""" + + +class ValidationFailure(DNSException): + """The DNSSEC signature is invalid.""" + + +class DeniedByPolicy(DNSException): + """Denied by DNSSEC policy.""" + + +class ExceptionWrapper: + def __init__(self, exception_class): + self.exception_class = exception_class + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if exc_type is not None and not isinstance(exc_val, self.exception_class): + raise self.exception_class(str(exc_val)) from exc_val + return False diff --git a/venv/Lib/site-packages/dns/flags.py b/venv/Lib/site-packages/dns/flags.py new file mode 100644 index 00000000..4c60be13 --- /dev/null +++ b/venv/Lib/site-packages/dns/flags.py @@ -0,0 +1,123 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2001-2017 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
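``Timeout`` above is the canonical user of the parametrized mode: ``supp_kwargs`` declares the accepted keyword and ``fmt`` renders it. A minimal sketch of how that plays out:

    import dns.exception

    try:
        raise dns.exception.Timeout(timeout=2.5)
    except dns.exception.Timeout as exc:
        print(exc)  # The DNS operation timed out after 2.500 seconds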
+
+"""DNS Message Flags."""
+
+import enum
+from typing import Any
+
+# Standard DNS flags
+
+
+class Flag(enum.IntFlag):
+    #: Query Response
+    QR = 0x8000
+    #: Authoritative Answer
+    AA = 0x0400
+    #: Truncated Response
+    TC = 0x0200
+    #: Recursion Desired
+    RD = 0x0100
+    #: Recursion Available
+    RA = 0x0080
+    #: Authentic Data
+    AD = 0x0020
+    #: Checking Disabled
+    CD = 0x0010
+
+
+# EDNS flags
+
+
+class EDNSFlag(enum.IntFlag):
+    #: DNSSEC answer OK
+    DO = 0x8000
+
+
+def _from_text(text: str, enum_class: Any) -> int:
+    flags = 0
+    tokens = text.split()
+    for t in tokens:
+        flags |= enum_class[t.upper()]
+    return flags
+
+
+def _to_text(flags: int, enum_class: Any) -> str:
+    text_flags = []
+    for k, v in enum_class.__members__.items():
+        if flags & v != 0:
+            text_flags.append(k)
+    return " ".join(text_flags)
+
+
+def from_text(text: str) -> int:
+    """Convert a space-separated list of flag text values into a flags
+    value.
+
+    Returns an ``int``.
+    """
+
+    return _from_text(text, Flag)
+
+
+def to_text(flags: int) -> str:
+    """Convert a flags value into a space-separated list of flag text
+    values.
+
+    Returns a ``str``.
+    """
+
+    return _to_text(flags, Flag)
+
+
+def edns_from_text(text: str) -> int:
+    """Convert a space-separated list of EDNS flag text values into an EDNS
+    flags value.
+
+    Returns an ``int``.
+    """
+
+    return _from_text(text, EDNSFlag)
+
+
+def edns_to_text(flags: int) -> str:
+    """Convert an EDNS flags value into a space-separated list of EDNS flag
+    text values.
+
+    Returns a ``str``.
+    """
+
+    return _to_text(flags, EDNSFlag)
+
+
+### BEGIN generated Flag constants
+
+QR = Flag.QR
+AA = Flag.AA
+TC = Flag.TC
+RD = Flag.RD
+RA = Flag.RA
+AD = Flag.AD
+CD = Flag.CD
+
+### END generated Flag constants
+
+### BEGIN generated EDNSFlag constants
+
+DO = EDNSFlag.DO
+
+### END generated EDNSFlag constants
diff --git a/venv/Lib/site-packages/dns/grange.py b/venv/Lib/site-packages/dns/grange.py
new file mode 100644
index 00000000..3a52278f
--- /dev/null
+++ b/venv/Lib/site-packages/dns/grange.py
@@ -0,0 +1,72 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2012-2017 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS.  IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
+"""DNS GENERATE range conversion."""
+
+from typing import Tuple
+
+import dns
+
+
+def from_text(text: str) -> Tuple[int, int, int]:
+    """Convert the text form of a range in a ``$GENERATE`` statement to a
+    start, stop, and step tuple of integers.
+
+    *text*, a ``str``, the textual range in ``$GENERATE`` form.
+
+    Returns a tuple of three ``int`` values ``(start, stop, step)``.
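+
+    For example::
+
+        from_text('1-10')     # (1, 10, 1)
+        from_text('1-100/2')  # (1, 100, 2)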
+ """ + + start = -1 + stop = -1 + step = 1 + cur = "" + state = 0 + # state 0 1 2 + # x - y / z + + if text and text[0] == "-": + raise dns.exception.SyntaxError("Start cannot be a negative number") + + for c in text: + if c == "-" and state == 0: + start = int(cur) + cur = "" + state = 1 + elif c == "/": + stop = int(cur) + cur = "" + state = 2 + elif c.isdigit(): + cur += c + else: + raise dns.exception.SyntaxError("Could not parse %s" % (c)) + + if state == 0: + raise dns.exception.SyntaxError("no stop value specified") + elif state == 1: + stop = int(cur) + else: + assert state == 2 + step = int(cur) + + assert step >= 1 + assert start >= 0 + if start > stop: + raise dns.exception.SyntaxError("start must be <= stop") + + return (start, stop, step) diff --git a/venv/Lib/site-packages/dns/immutable.py b/venv/Lib/site-packages/dns/immutable.py new file mode 100644 index 00000000..36b0362c --- /dev/null +++ b/venv/Lib/site-packages/dns/immutable.py @@ -0,0 +1,68 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +import collections.abc +from typing import Any, Callable + +from dns._immutable_ctx import immutable + + +@immutable +class Dict(collections.abc.Mapping): # lgtm[py/missing-equals] + def __init__( + self, + dictionary: Any, + no_copy: bool = False, + map_factory: Callable[[], collections.abc.MutableMapping] = dict, + ): + """Make an immutable dictionary from the specified dictionary. + + If *no_copy* is `True`, then *dictionary* will be wrapped instead + of copied. Only set this if you are sure there will be no external + references to the dictionary. + """ + if no_copy and isinstance(dictionary, collections.abc.MutableMapping): + self._odict = dictionary + else: + self._odict = map_factory() + self._odict.update(dictionary) + self._hash = None + + def __getitem__(self, key): + return self._odict.__getitem__(key) + + def __hash__(self): # pylint: disable=invalid-hash-returned + if self._hash is None: + h = 0 + for key in sorted(self._odict.keys()): + h ^= hash(key) + object.__setattr__(self, "_hash", h) + # this does return an int, but pylint doesn't figure that out + return self._hash + + def __len__(self): + return len(self._odict) + + def __iter__(self): + return iter(self._odict) + + +def constify(o: Any) -> Any: + """ + Convert mutable types to immutable types. + """ + if isinstance(o, bytearray): + return bytes(o) + if isinstance(o, tuple): + try: + hash(o) + return o + except Exception: + return tuple(constify(elt) for elt in o) + if isinstance(o, list): + return tuple(constify(elt) for elt in o) + if isinstance(o, dict): + cdict = dict() + for k, v in o.items(): + cdict[k] = constify(v) + return Dict(cdict, True) + return o diff --git a/venv/Lib/site-packages/dns/inet.py b/venv/Lib/site-packages/dns/inet.py new file mode 100644 index 00000000..4a03f996 --- /dev/null +++ b/venv/Lib/site-packages/dns/inet.py @@ -0,0 +1,197 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2017 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. 
IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +"""Generic Internet address helper functions.""" + +import socket +from typing import Any, Optional, Tuple + +import dns.ipv4 +import dns.ipv6 + +# We assume that AF_INET and AF_INET6 are always defined. We keep +# these here for the benefit of any old code (unlikely though that +# is!). +AF_INET = socket.AF_INET +AF_INET6 = socket.AF_INET6 + + +def inet_pton(family: int, text: str) -> bytes: + """Convert the textual form of a network address into its binary form. + + *family* is an ``int``, the address family. + + *text* is a ``str``, the textual address. + + Raises ``NotImplementedError`` if the address family specified is not + implemented. + + Returns a ``bytes``. + """ + + if family == AF_INET: + return dns.ipv4.inet_aton(text) + elif family == AF_INET6: + return dns.ipv6.inet_aton(text, True) + else: + raise NotImplementedError + + +def inet_ntop(family: int, address: bytes) -> str: + """Convert the binary form of a network address into its textual form. + + *family* is an ``int``, the address family. + + *address* is a ``bytes``, the network address in binary form. + + Raises ``NotImplementedError`` if the address family specified is not + implemented. + + Returns a ``str``. + """ + + if family == AF_INET: + return dns.ipv4.inet_ntoa(address) + elif family == AF_INET6: + return dns.ipv6.inet_ntoa(address) + else: + raise NotImplementedError + + +def af_for_address(text: str) -> int: + """Determine the address family of a textual-form network address. + + *text*, a ``str``, the textual address. + + Raises ``ValueError`` if the address family cannot be determined + from the input. + + Returns an ``int``. + """ + + try: + dns.ipv4.inet_aton(text) + return AF_INET + except Exception: + try: + dns.ipv6.inet_aton(text, True) + return AF_INET6 + except Exception: + raise ValueError + + +def is_multicast(text: str) -> bool: + """Is the textual-form network address a multicast address? + + *text*, a ``str``, the textual address. + + Raises ``ValueError`` if the address family cannot be determined + from the input. + + Returns a ``bool``. + """ + + try: + first = dns.ipv4.inet_aton(text)[0] + return first >= 224 and first <= 239 + except Exception: + try: + first = dns.ipv6.inet_aton(text, True)[0] + return first == 255 + except Exception: + raise ValueError + + +def is_address(text: str) -> bool: + """Is the specified string an IPv4 or IPv6 address? + + *text*, a ``str``, the textual address. + + Returns a ``bool``. + """ + + try: + dns.ipv4.inet_aton(text) + return True + except Exception: + try: + dns.ipv6.inet_aton(text, True) + return True + except Exception: + return False + + +def low_level_address_tuple( + high_tuple: Tuple[str, int], af: Optional[int] = None +) -> Any: + """Given a "high-level" address tuple, i.e. + an (address, port) return the appropriate "low-level" address tuple + suitable for use in socket calls. + + If an *af* other than ``None`` is provided, it is assumed the + address in the high-level tuple is valid and has that af. If af + is ``None``, then af_for_address will be called. 
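+
+    For example::
+
+        low_level_address_tuple(('127.0.0.1', 53))    # ('127.0.0.1', 53)
+        low_level_address_tuple(('2001:db8::1', 53))  # ('2001:db8::1', 53, 0, 0)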
+ """ + address, port = high_tuple + if af is None: + af = af_for_address(address) + if af == AF_INET: + return (address, port) + elif af == AF_INET6: + i = address.find("%") + if i < 0: + # no scope, shortcut! + return (address, port, 0, 0) + # try to avoid getaddrinfo() + addrpart = address[:i] + scope = address[i + 1 :] + if scope.isdigit(): + return (addrpart, port, 0, int(scope)) + try: + return (addrpart, port, 0, socket.if_nametoindex(scope)) + except AttributeError: # pragma: no cover (we can't really test this) + ai_flags = socket.AI_NUMERICHOST + ((*_, tup), *_) = socket.getaddrinfo(address, port, flags=ai_flags) + return tup + else: + raise NotImplementedError(f"unknown address family {af}") + + +def any_for_af(af): + """Return the 'any' address for the specified address family.""" + if af == socket.AF_INET: + return "0.0.0.0" + elif af == socket.AF_INET6: + return "::" + raise NotImplementedError(f"unknown address family {af}") + + +def canonicalize(text: str) -> str: + """Verify that *address* is a valid text form IPv4 or IPv6 address and return its + canonical text form. IPv6 addresses with scopes are rejected. + + *text*, a ``str``, the address in textual form. + + Raises ``ValueError`` if the text is not valid. + """ + try: + return dns.ipv6.canonicalize(text) + except Exception: + try: + return dns.ipv4.canonicalize(text) + except Exception: + raise ValueError diff --git a/venv/Lib/site-packages/dns/ipv4.py b/venv/Lib/site-packages/dns/ipv4.py new file mode 100644 index 00000000..65ee69c0 --- /dev/null +++ b/venv/Lib/site-packages/dns/ipv4.py @@ -0,0 +1,77 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2017 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +"""IPv4 helper functions.""" + +import struct +from typing import Union + +import dns.exception + + +def inet_ntoa(address: bytes) -> str: + """Convert an IPv4 address in binary form to text form. + + *address*, a ``bytes``, the IPv4 address in binary form. + + Returns a ``str``. + """ + + if len(address) != 4: + raise dns.exception.SyntaxError + return "%u.%u.%u.%u" % (address[0], address[1], address[2], address[3]) + + +def inet_aton(text: Union[str, bytes]) -> bytes: + """Convert an IPv4 address in text form to binary form. + + *text*, a ``str`` or ``bytes``, the IPv4 address in textual form. + + Returns a ``bytes``. 
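+
+    For example::
+
+        inet_aton('127.0.0.1')  # b'\x7f\x00\x00\x01'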
+
+    """
+
+    if not isinstance(text, bytes):
+        btext = text.encode()
+    else:
+        btext = text
+    parts = btext.split(b".")
+    if len(parts) != 4:
+        raise dns.exception.SyntaxError
+    for part in parts:
+        if not part.isdigit():
+            raise dns.exception.SyntaxError
+        if len(part) > 1 and part[0] == ord("0"):
+            # No leading zeros
+            raise dns.exception.SyntaxError
+    try:
+        b = [int(part) for part in parts]
+        return struct.pack("BBBB", *b)
+    except Exception:
+        raise dns.exception.SyntaxError
+
+
+def canonicalize(text: Union[str, bytes]) -> str:
+    """Verify that *address* is a valid text form IPv4 address and return its
+    canonical text form.
+
+    *text*, a ``str`` or ``bytes``, the IPv4 address in textual form.
+
+    Raises ``dns.exception.SyntaxError`` if the text is not valid.
+    """
+    # Note that inet_aton() only accepts canonical form, but we still run
+    # through inet_ntoa() to ensure the output is a str.
+    return dns.ipv4.inet_ntoa(dns.ipv4.inet_aton(text))
diff --git a/venv/Lib/site-packages/dns/ipv6.py b/venv/Lib/site-packages/dns/ipv6.py
new file mode 100644
index 00000000..44a10639
--- /dev/null
+++ b/venv/Lib/site-packages/dns/ipv6.py
@@ -0,0 +1,219 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2003-2017 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS.  IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
+"""IPv6 helper functions."""
+
+import binascii
+import re
+from typing import List, Union
+
+import dns.exception
+import dns.ipv4
+
+_leading_zero = re.compile(r"0+([0-9a-f]+)")
+
+
+def inet_ntoa(address: bytes) -> str:
+    """Convert an IPv6 address in binary form to text form.
+
+    *address*, a ``bytes``, the IPv6 address in binary form.
+
+    Raises ``ValueError`` if the address isn't 16 bytes long.
+    Returns a ``str``.
+    """
+
+    if len(address) != 16:
+        raise ValueError("IPv6 addresses are 16 bytes long")
+    hex = binascii.hexlify(address)
+    chunks = []
+    i = 0
+    l = len(hex)
+    while i < l:
+        chunk = hex[i : i + 4].decode()
+        # strip leading zeros.
we do this with an re instead of
+        # with lstrip() because lstrip() didn't support chars until
+        # python 2.2.2
+        m = _leading_zero.match(chunk)
+        if m is not None:
+            chunk = m.group(1)
+        chunks.append(chunk)
+        i += 4
+    #
+    # Compress the longest subsequence of 0-value chunks to ::
+    #
+    best_start = 0
+    best_len = 0
+    start = -1
+    last_was_zero = False
+    for i in range(8):
+        if chunks[i] != "0":
+            if last_was_zero:
+                end = i
+                current_len = end - start
+                if current_len > best_len:
+                    best_start = start
+                    best_len = current_len
+                last_was_zero = False
+        elif not last_was_zero:
+            start = i
+            last_was_zero = True
+    if last_was_zero:
+        end = 8
+        current_len = end - start
+        if current_len > best_len:
+            best_start = start
+            best_len = current_len
+    if best_len > 1:
+        if best_start == 0 and (best_len == 6 or best_len == 5 and chunks[5] == "ffff"):
+            # We have an embedded IPv4 address
+            if best_len == 6:
+                prefix = "::"
+            else:
+                prefix = "::ffff:"
+            thex = prefix + dns.ipv4.inet_ntoa(address[12:])
+        else:
+            thex = (
+                ":".join(chunks[:best_start])
+                + "::"
+                + ":".join(chunks[best_start + best_len :])
+            )
+    else:
+        thex = ":".join(chunks)
+    return thex
+
+
+_v4_ending = re.compile(rb"(.*):(\d+\.\d+\.\d+\.\d+)$")
+_colon_colon_start = re.compile(rb"::.*")
+_colon_colon_end = re.compile(rb".*::$")
+
+
+def inet_aton(text: Union[str, bytes], ignore_scope: bool = False) -> bytes:
+    """Convert an IPv6 address in text form to binary form.
+
+    *text*, a ``str`` or ``bytes``, the IPv6 address in textual form.
+
+    *ignore_scope*, a ``bool``.  If ``True``, a scope will be ignored.
+    If ``False``, the default, it is an error for a scope to be present.
+
+    Returns a ``bytes``.
+    """
+
+    #
+    # Our aim here is not something fast; we just want something that works.
+    #
+    if not isinstance(text, bytes):
+        btext = text.encode()
+    else:
+        btext = text
+
+    if ignore_scope:
+        parts = btext.split(b"%")
+        l = len(parts)
+        if l == 2:
+            btext = parts[0]
+        elif l > 2:
+            raise dns.exception.SyntaxError
+
+    if btext == b"":
+        raise dns.exception.SyntaxError
+    elif btext.endswith(b":") and not btext.endswith(b"::"):
+        raise dns.exception.SyntaxError
+    elif btext.startswith(b":") and not btext.startswith(b"::"):
+        raise dns.exception.SyntaxError
+    elif btext == b"::":
+        btext = b"0::"
+    #
+    # Get rid of the icky dot-quad syntax if we have it.
+    #
+    m = _v4_ending.match(btext)
+    if m is not None:
+        b = dns.ipv4.inet_aton(m.group(2))
+        btext = (
+            "{}:{:02x}{:02x}:{:02x}{:02x}".format(
+                m.group(1).decode(), b[0], b[1], b[2], b[3]
+            )
+        ).encode()
+    #
+    # Try to turn '::<whatever>' into ':<whatever>'; if no match try to
+    # turn '<whatever>::' into '<whatever>:'
+    #
+    m = _colon_colon_start.match(btext)
+    if m is not None:
+        btext = btext[1:]
+    else:
+        m = _colon_colon_end.match(btext)
+        if m is not None:
+            btext = btext[:-1]
+    #
+    # Now canonicalize into 8 chunks of 4 hex digits each
+    #
+    chunks = btext.split(b":")
+    l = len(chunks)
+    if l > 8:
+        raise dns.exception.SyntaxError
+    seen_empty = False
+    canonical: List[bytes] = []
+    for c in chunks:
+        if c == b"":
+            if seen_empty:
+                raise dns.exception.SyntaxError
+            seen_empty = True
+            for _ in range(0, 8 - l + 1):
+                canonical.append(b"0000")
+        else:
+            lc = len(c)
+            if lc > 4:
+                raise dns.exception.SyntaxError
+            if lc != 4:
+                c = (b"0" * (4 - lc)) + c
+            canonical.append(c)
+    if l < 8 and not seen_empty:
+        raise dns.exception.SyntaxError
+    btext = b"".join(canonical)
+
+    #
+    # Finally we can go to binary.
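+    # At this point btext is exactly 32 hexadecimal digits; for example,
+    # an input of '::1' has been expanded to b'0000' * 7 + b'0001'.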
+ # + try: + return binascii.unhexlify(btext) + except (binascii.Error, TypeError): + raise dns.exception.SyntaxError + + +_mapped_prefix = b"\x00" * 10 + b"\xff\xff" + + +def is_mapped(address: bytes) -> bool: + """Is the specified address a mapped IPv4 address? + + *address*, a ``bytes`` is an IPv6 address in binary form. + + Returns a ``bool``. + """ + + return address.startswith(_mapped_prefix) + + +def canonicalize(text: Union[str, bytes]) -> str: + """Verify that *address* is a valid text form IPv6 address and return its + canonical text form. Addresses with scopes are rejected. + + *text*, a ``str`` or ``bytes``, the IPv6 address in textual form. + + Raises ``dns.exception.SyntaxError`` if the text is not valid. + """ + return dns.ipv6.inet_ntoa(dns.ipv6.inet_aton(text)) diff --git a/venv/Lib/site-packages/dns/message.py b/venv/Lib/site-packages/dns/message.py new file mode 100644 index 00000000..44cacbd9 --- /dev/null +++ b/venv/Lib/site-packages/dns/message.py @@ -0,0 +1,1888 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2001-2017 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +"""DNS Messages""" + +import contextlib +import io +import time +from typing import Any, Dict, List, Optional, Tuple, Union + +import dns.edns +import dns.entropy +import dns.enum +import dns.exception +import dns.flags +import dns.name +import dns.opcode +import dns.rcode +import dns.rdata +import dns.rdataclass +import dns.rdatatype +import dns.rdtypes.ANY.OPT +import dns.rdtypes.ANY.TSIG +import dns.renderer +import dns.rrset +import dns.tsig +import dns.ttl +import dns.wire + + +class ShortHeader(dns.exception.FormError): + """The DNS packet passed to from_wire() is too short.""" + + +class TrailingJunk(dns.exception.FormError): + """The DNS packet passed to from_wire() has extra junk at the end of it.""" + + +class UnknownHeaderField(dns.exception.DNSException): + """The header field name was not recognized when converting from text + into a message.""" + + +class BadEDNS(dns.exception.FormError): + """An OPT record occurred somewhere other than + the additional data section.""" + + +class BadTSIG(dns.exception.FormError): + """A TSIG record occurred somewhere other than the end of + the additional data section.""" + + +class UnknownTSIGKey(dns.exception.DNSException): + """A TSIG with an unknown key was received.""" + + +class Truncated(dns.exception.DNSException): + """The truncated flag is set.""" + + supp_kwargs = {"message"} + + # We do this as otherwise mypy complains about unexpected keyword argument + # idna_exception + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def message(self): + """As much of the message as could be processed. + + Returns a ``dns.message.Message``. 
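+
+        For example (a sketch; ``q`` is a query message)::
+
+            try:
+                r = dns.query.udp(q, "192.0.2.1", raise_on_truncation=True)
+            except dns.message.Truncated as exc:
+                r = exc.message()  # the partial result; typically retry over TCP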
+
+        """
+        return self.kwargs["message"]
+
+
+class NotQueryResponse(dns.exception.DNSException):
+    """Message is not a response to a query."""
+
+
+class ChainTooLong(dns.exception.DNSException):
+    """The CNAME chain is too long."""
+
+
+class AnswerForNXDOMAIN(dns.exception.DNSException):
+    """The rcode is NXDOMAIN but an answer was found."""
+
+
+class NoPreviousName(dns.exception.SyntaxError):
+    """No previous name was known."""
+
+
+class MessageSection(dns.enum.IntEnum):
+    """Message sections"""

+    QUESTION = 0
+    ANSWER = 1
+    AUTHORITY = 2
+    ADDITIONAL = 3
+
+    @classmethod
+    def _maximum(cls):
+        return 3
+
+
+class MessageError:
+    def __init__(self, exception: Exception, offset: int):
+        self.exception = exception
+        self.offset = offset
+
+
+DEFAULT_EDNS_PAYLOAD = 1232
+MAX_CHAIN = 16
+
+IndexKeyType = Tuple[
+    int,
+    dns.name.Name,
+    dns.rdataclass.RdataClass,
+    dns.rdatatype.RdataType,
+    Optional[dns.rdatatype.RdataType],
+    Optional[dns.rdataclass.RdataClass],
+]
+IndexType = Dict[IndexKeyType, dns.rrset.RRset]
+SectionType = Union[int, str, List[dns.rrset.RRset]]
+
+
+class Message:
+    """A DNS message."""
+
+    _section_enum = MessageSection
+
+    def __init__(self, id: Optional[int] = None):
+        if id is None:
+            self.id = dns.entropy.random_16()
+        else:
+            self.id = id
+        self.flags = 0
+        self.sections: List[List[dns.rrset.RRset]] = [[], [], [], []]
+        self.opt: Optional[dns.rrset.RRset] = None
+        self.request_payload = 0
+        self.pad = 0
+        self.keyring: Any = None
+        self.tsig: Optional[dns.rrset.RRset] = None
+        self.request_mac = b""
+        self.xfr = False
+        self.origin: Optional[dns.name.Name] = None
+        self.tsig_ctx: Optional[Any] = None
+        self.index: IndexType = {}
+        self.errors: List[MessageError] = []
+        self.time = 0.0
+
+    @property
+    def question(self) -> List[dns.rrset.RRset]:
+        """The question section."""
+        return self.sections[0]
+
+    @question.setter
+    def question(self, v):
+        self.sections[0] = v
+
+    @property
+    def answer(self) -> List[dns.rrset.RRset]:
+        """The answer section."""
+        return self.sections[1]
+
+    @answer.setter
+    def answer(self, v):
+        self.sections[1] = v
+
+    @property
+    def authority(self) -> List[dns.rrset.RRset]:
+        """The authority section."""
+        return self.sections[2]
+
+    @authority.setter
+    def authority(self, v):
+        self.sections[2] = v
+
+    @property
+    def additional(self) -> List[dns.rrset.RRset]:
+        """The additional data section."""
+        return self.sections[3]
+
+    @additional.setter
+    def additional(self, v):
+        self.sections[3] = v
+
+    def __repr__(self):
+        return "<DNS message, ID " + repr(self.id) + ">"
+
+    def __str__(self):
+        return self.to_text()
+
+    def to_text(
+        self,
+        origin: Optional[dns.name.Name] = None,
+        relativize: bool = True,
+        **kw: Dict[str, Any],
+    ) -> str:
+        """Convert the message to text.
+
+        The *origin*, *relativize*, and any other keyword
+        arguments are passed to the RRset ``to_text()`` method.
+
+        Returns a ``str``.
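+
+        For example, with ``my_origin`` a ``dns.name.Name``::
+
+            text = my_message.to_text()
+            absolute = my_message.to_text(origin=my_origin, relativize=False)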
+ """ + + s = io.StringIO() + s.write("id %d\n" % self.id) + s.write("opcode %s\n" % dns.opcode.to_text(self.opcode())) + s.write("rcode %s\n" % dns.rcode.to_text(self.rcode())) + s.write("flags %s\n" % dns.flags.to_text(self.flags)) + if self.edns >= 0: + s.write("edns %s\n" % self.edns) + if self.ednsflags != 0: + s.write("eflags %s\n" % dns.flags.edns_to_text(self.ednsflags)) + s.write("payload %d\n" % self.payload) + for opt in self.options: + s.write("option %s\n" % opt.to_text()) + for name, which in self._section_enum.__members__.items(): + s.write(f";{name}\n") + for rrset in self.section_from_number(which): + s.write(rrset.to_text(origin, relativize, **kw)) + s.write("\n") + # + # We strip off the final \n so the caller can print the result without + # doing weird things to get around eccentricities in Python print + # formatting + # + return s.getvalue()[:-1] + + def __eq__(self, other): + """Two messages are equal if they have the same content in the + header, question, answer, and authority sections. + + Returns a ``bool``. + """ + + if not isinstance(other, Message): + return False + if self.id != other.id: + return False + if self.flags != other.flags: + return False + for i, section in enumerate(self.sections): + other_section = other.sections[i] + for n in section: + if n not in other_section: + return False + for n in other_section: + if n not in section: + return False + return True + + def __ne__(self, other): + return not self.__eq__(other) + + def is_response(self, other: "Message") -> bool: + """Is *other*, also a ``dns.message.Message``, a response to this + message? + + Returns a ``bool``. + """ + + if ( + other.flags & dns.flags.QR == 0 + or self.id != other.id + or dns.opcode.from_flags(self.flags) != dns.opcode.from_flags(other.flags) + ): + return False + if other.rcode() in { + dns.rcode.FORMERR, + dns.rcode.SERVFAIL, + dns.rcode.NOTIMP, + dns.rcode.REFUSED, + }: + # We don't check the question section in these cases if + # the other question section is empty, even though they + # still really ought to have a question section. + if len(other.question) == 0: + return True + if dns.opcode.is_update(self.flags): + # This is assuming the "sender doesn't include anything + # from the update", but we don't care to check the other + # case, which is that all the sections are returned and + # identical. + return True + for n in self.question: + if n not in other.question: + return False + for n in other.question: + if n not in self.question: + return False + return True + + def section_number(self, section: List[dns.rrset.RRset]) -> int: + """Return the "section number" of the specified section for use + in indexing. + + *section* is one of the section attributes of this message. + + Raises ``ValueError`` if the section isn't known. + + Returns an ``int``. + """ + + for i, our_section in enumerate(self.sections): + if section is our_section: + return self._section_enum(i) + raise ValueError("unknown section") + + def section_from_number(self, number: int) -> List[dns.rrset.RRset]: + """Return the section list associated with the specified section + number. + + *number* is a section number `int` or the text form of a section + name. + + Raises ``ValueError`` if the section isn't known. + + Returns a ``list``. 
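+
+        For example::
+
+            my_message.section_from_number(dns.message.ANSWER)
+            my_message.section_from_number("AUTHORITY")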
+
+        """
+
+        section = self._section_enum.make(number)
+        return self.sections[section]
+
+    def find_rrset(
+        self,
+        section: SectionType,
+        name: dns.name.Name,
+        rdclass: dns.rdataclass.RdataClass,
+        rdtype: dns.rdatatype.RdataType,
+        covers: dns.rdatatype.RdataType = dns.rdatatype.NONE,
+        deleting: Optional[dns.rdataclass.RdataClass] = None,
+        create: bool = False,
+        force_unique: bool = False,
+        idna_codec: Optional[dns.name.IDNACodec] = None,
+    ) -> dns.rrset.RRset:
+        """Find the RRset with the given attributes in the specified section.
+
+        *section*, an ``int`` section number, a ``str`` section name, or one of
+        the section attributes of this message.  This specifies the
+        section of the message to search.  For example::
+
+            my_message.find_rrset(my_message.answer, name, rdclass, rdtype)
+            my_message.find_rrset(dns.message.ANSWER, name, rdclass, rdtype)
+            my_message.find_rrset("ANSWER", name, rdclass, rdtype)
+
+        *name*, a ``dns.name.Name`` or ``str``, the name of the RRset.
+
+        *rdclass*, an ``int`` or ``str``, the class of the RRset.
+
+        *rdtype*, an ``int`` or ``str``, the type of the RRset.
+
+        *covers*, an ``int`` or ``str``, the covers value of the RRset.
+        The default is ``dns.rdatatype.NONE``.
+
+        *deleting*, an ``int``, ``str``, or ``None``, the deleting value of the
+        RRset.  The default is ``None``.
+
+        *create*, a ``bool``.  If ``True``, create the RRset if it is not found.
+        The created RRset is appended to *section*.
+
+        *force_unique*, a ``bool``.  If ``True`` and *create* is also ``True``,
+        create a new RRset regardless of whether a matching RRset exists
+        already.  The default is ``False``.  This is useful when creating
+        DDNS Update messages, as order matters for them.
+
+        *idna_codec*, a ``dns.name.IDNACodec``, specifies the IDNA
+        encoder/decoder.  If ``None``, the default IDNA 2003 encoder/decoder
+        is used.
+
+        Raises ``KeyError`` if the RRset was not found and create was
+        ``False``.
+
+        Returns a ``dns.rrset.RRset`` object.
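+
+        For example, fetching an answer RRset and creating it if absent::
+
+            rrset = my_message.find_rrset(dns.message.ANSWER, name,
+                                          dns.rdataclass.IN, dns.rdatatype.A,
+                                          create=True)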
+
+        """
+
+        if isinstance(section, int):
+            section_number = section
+            section = self.section_from_number(section_number)
+        elif isinstance(section, str):
+            section_number = self._section_enum.from_text(section)
+            section = self.section_from_number(section_number)
+        else:
+            section_number = self.section_number(section)
+        if isinstance(name, str):
+            name = dns.name.from_text(name, idna_codec=idna_codec)
+        rdtype = dns.rdatatype.RdataType.make(rdtype)
+        rdclass = dns.rdataclass.RdataClass.make(rdclass)
+        covers = dns.rdatatype.RdataType.make(covers)
+        if deleting is not None:
+            deleting = dns.rdataclass.RdataClass.make(deleting)
+        key = (section_number, name, rdclass, rdtype, covers, deleting)
+        if not force_unique:
+            if self.index is not None:
+                rrset = self.index.get(key)
+                if rrset is not None:
+                    return rrset
+            else:
+                for rrset in section:
+                    if rrset.full_match(name, rdclass, rdtype, covers, deleting):
+                        return rrset
+        if not create:
+            raise KeyError
+        rrset = dns.rrset.RRset(name, rdclass, rdtype, covers, deleting)
+        section.append(rrset)
+        if self.index is not None:
+            self.index[key] = rrset
+        return rrset
+
+    def get_rrset(
+        self,
+        section: SectionType,
+        name: dns.name.Name,
+        rdclass: dns.rdataclass.RdataClass,
+        rdtype: dns.rdatatype.RdataType,
+        covers: dns.rdatatype.RdataType = dns.rdatatype.NONE,
+        deleting: Optional[dns.rdataclass.RdataClass] = None,
+        create: bool = False,
+        force_unique: bool = False,
+        idna_codec: Optional[dns.name.IDNACodec] = None,
+    ) -> Optional[dns.rrset.RRset]:
+        """Get the RRset with the given attributes in the specified section.
+
+        If the RRset is not found, None is returned.
+
+        *section*, an ``int`` section number, a ``str`` section name, or one of
+        the section attributes of this message.  This specifies the
+        section of the message to search.  For example::
+
+            my_message.get_rrset(my_message.answer, name, rdclass, rdtype)
+            my_message.get_rrset(dns.message.ANSWER, name, rdclass, rdtype)
+            my_message.get_rrset("ANSWER", name, rdclass, rdtype)
+
+        *name*, a ``dns.name.Name`` or ``str``, the name of the RRset.
+
+        *rdclass*, an ``int`` or ``str``, the class of the RRset.
+
+        *rdtype*, an ``int`` or ``str``, the type of the RRset.
+
+        *covers*, an ``int`` or ``str``, the covers value of the RRset.
+        The default is ``dns.rdatatype.NONE``.
+
+        *deleting*, an ``int``, ``str``, or ``None``, the deleting value of the
+        RRset.  The default is ``None``.
+
+        *create*, a ``bool``.  If ``True``, create the RRset if it is not found.
+        The created RRset is appended to *section*.
+
+        *force_unique*, a ``bool``.  If ``True`` and *create* is also ``True``,
+        create a new RRset regardless of whether a matching RRset exists
+        already.  The default is ``False``.  This is useful when creating
+        DDNS Update messages, as order matters for them.
+
+        *idna_codec*, a ``dns.name.IDNACodec``, specifies the IDNA
+        encoder/decoder.  If ``None``, the default IDNA 2003 encoder/decoder
+        is used.
+
+        Returns a ``dns.rrset.RRset`` object or ``None``.
+        """
+
+        try:
+            rrset = self.find_rrset(
+                section,
+                name,
+                rdclass,
+                rdtype,
+                covers,
+                deleting,
+                create,
+                force_unique,
+                idna_codec,
+            )
+        except KeyError:
+            rrset = None
+        return rrset
+
+    def section_count(self, section: SectionType) -> int:
+        """Returns the number of records in the specified section.
+
+        *section*, an ``int`` section number, a ``str`` section name, or one of
+        the section attributes of this message.  This specifies the
+        section of the message to count.
For example:: + + my_message.section_count(my_message.answer) + my_message.section_count(dns.message.ANSWER) + my_message.section_count("ANSWER") + """ + + if isinstance(section, int): + section_number = section + section = self.section_from_number(section_number) + elif isinstance(section, str): + section_number = self._section_enum.from_text(section) + section = self.section_from_number(section_number) + else: + section_number = self.section_number(section) + count = sum(max(1, len(rrs)) for rrs in section) + if section_number == MessageSection.ADDITIONAL: + if self.opt is not None: + count += 1 + if self.tsig is not None: + count += 1 + return count + + def _compute_opt_reserve(self) -> int: + """Compute the size required for the OPT RR, padding excluded""" + if not self.opt: + return 0 + # 1 byte for the root name, 10 for the standard RR fields + size = 11 + # This would be more efficient if options had a size() method, but we won't + # worry about that for now. We also don't worry if there is an existing padding + # option, as it is unlikely and probably harmless, as the worst case is that we + # may add another, and this seems to be legal. + for option in self.opt[0].options: + wire = option.to_wire() + # We add 4 here to account for the option type and length + size += len(wire) + 4 + if self.pad: + # Padding will be added, so again add the option type and length. + size += 4 + return size + + def _compute_tsig_reserve(self) -> int: + """Compute the size required for the TSIG RR""" + # This would be more efficient if TSIGs had a size method, but we won't + # worry about for now. Also, we can't really cope with the potential + # compressibility of the TSIG owner name, so we estimate with the uncompressed + # size. We will disable compression when TSIG and padding are both is active + # so that the padding comes out right. + if not self.tsig: + return 0 + f = io.BytesIO() + self.tsig.to_wire(f) + return len(f.getvalue()) + + def to_wire( + self, + origin: Optional[dns.name.Name] = None, + max_size: int = 0, + multi: bool = False, + tsig_ctx: Optional[Any] = None, + prepend_length: bool = False, + prefer_truncation: bool = False, + **kw: Dict[str, Any], + ) -> bytes: + """Return a string containing the message in DNS compressed wire + format. + + Additional keyword arguments are passed to the RRset ``to_wire()`` + method. + + *origin*, a ``dns.name.Name`` or ``None``, the origin to be appended + to any relative names. If ``None``, and the message has an origin + attribute that is not ``None``, then it will be used. + + *max_size*, an ``int``, the maximum size of the wire format + output; default is 0, which means "the message's request + payload, if nonzero, or 65535". + + *multi*, a ``bool``, should be set to ``True`` if this message is + part of a multiple message sequence. + + *tsig_ctx*, a ``dns.tsig.HMACTSig`` or ``dns.tsig.GSSTSig`` object, the + ongoing TSIG context, used when signing zone transfers. + + *prepend_length*, a ``bool``, should be set to ``True`` if the caller + wants the message length prepended to the message itself. This is + useful for messages sent over TCP, TLS (DoT), or QUIC (DoQ). + + *prefer_truncation*, a ``bool``, should be set to ``True`` if the caller + wants the message to be truncated if it would otherwise exceed the + maximum length. If the truncation occurs before the additional section, + the TC bit will be set. + + Raises ``dns.exception.TooBig`` if *max_size* was exceeded. + + Returns a ``bytes``. 
+ """ + + if origin is None and self.origin is not None: + origin = self.origin + if max_size == 0: + if self.request_payload != 0: + max_size = self.request_payload + else: + max_size = 65535 + if max_size < 512: + max_size = 512 + elif max_size > 65535: + max_size = 65535 + r = dns.renderer.Renderer(self.id, self.flags, max_size, origin) + opt_reserve = self._compute_opt_reserve() + r.reserve(opt_reserve) + tsig_reserve = self._compute_tsig_reserve() + r.reserve(tsig_reserve) + try: + for rrset in self.question: + r.add_question(rrset.name, rrset.rdtype, rrset.rdclass) + for rrset in self.answer: + r.add_rrset(dns.renderer.ANSWER, rrset, **kw) + for rrset in self.authority: + r.add_rrset(dns.renderer.AUTHORITY, rrset, **kw) + for rrset in self.additional: + r.add_rrset(dns.renderer.ADDITIONAL, rrset, **kw) + except dns.exception.TooBig: + if prefer_truncation: + if r.section < dns.renderer.ADDITIONAL: + r.flags |= dns.flags.TC + else: + raise + r.release_reserved() + if self.opt is not None: + r.add_opt(self.opt, self.pad, opt_reserve, tsig_reserve) + r.write_header() + if self.tsig is not None: + (new_tsig, ctx) = dns.tsig.sign( + r.get_wire(), + self.keyring, + self.tsig[0], + int(time.time()), + self.request_mac, + tsig_ctx, + multi, + ) + self.tsig.clear() + self.tsig.add(new_tsig) + r.add_rrset(dns.renderer.ADDITIONAL, self.tsig) + r.write_header() + if multi: + self.tsig_ctx = ctx + wire = r.get_wire() + if prepend_length: + wire = len(wire).to_bytes(2, "big") + wire + return wire + + @staticmethod + def _make_tsig( + keyname, algorithm, time_signed, fudge, mac, original_id, error, other + ): + tsig = dns.rdtypes.ANY.TSIG.TSIG( + dns.rdataclass.ANY, + dns.rdatatype.TSIG, + algorithm, + time_signed, + fudge, + mac, + original_id, + error, + other, + ) + return dns.rrset.from_rdata(keyname, 0, tsig) + + def use_tsig( + self, + keyring: Any, + keyname: Optional[Union[dns.name.Name, str]] = None, + fudge: int = 300, + original_id: Optional[int] = None, + tsig_error: int = 0, + other_data: bytes = b"", + algorithm: Union[dns.name.Name, str] = dns.tsig.default_algorithm, + ) -> None: + """When sending, a TSIG signature using the specified key + should be added. + + *key*, a ``dns.tsig.Key`` is the key to use. If a key is specified, + the *keyring* and *algorithm* fields are not used. + + *keyring*, a ``dict``, ``callable`` or ``dns.tsig.Key``, is either + the TSIG keyring or key to use. + + The format of a keyring dict is a mapping from TSIG key name, as + ``dns.name.Name`` to ``dns.tsig.Key`` or a TSIG secret, a ``bytes``. + If a ``dict`` *keyring* is specified but a *keyname* is not, the key + used will be the first key in the *keyring*. Note that the order of + keys in a dictionary is not defined, so applications should supply a + keyname when a ``dict`` keyring is used, unless they know the keyring + contains only one key. If a ``callable`` keyring is specified, the + callable will be called with the message and the keyname, and is + expected to return a key. + + *keyname*, a ``dns.name.Name``, ``str`` or ``None``, the name of + this TSIG key to use; defaults to ``None``. If *keyring* is a + ``dict``, the key must be defined in it. If *keyring* is a + ``dns.tsig.Key``, this is ignored. + + *fudge*, an ``int``, the TSIG time fudge. + + *original_id*, an ``int``, the TSIG original id. If ``None``, + the message's id is used. + + *tsig_error*, an ``int``, the TSIG error code. + + *other_data*, a ``bytes``, the TSIG other data. 
+ + *algorithm*, a ``dns.name.Name`` or ``str``, the TSIG algorithm to use. This is + only used if *keyring* is a ``dict``, and the key entry is a ``bytes``. + """ + + if isinstance(keyring, dns.tsig.Key): + key = keyring + keyname = key.name + elif callable(keyring): + key = keyring(self, keyname) + else: + if isinstance(keyname, str): + keyname = dns.name.from_text(keyname) + if keyname is None: + keyname = next(iter(keyring)) + key = keyring[keyname] + if isinstance(key, bytes): + key = dns.tsig.Key(keyname, key, algorithm) + self.keyring = key + if original_id is None: + original_id = self.id + self.tsig = self._make_tsig( + keyname, + self.keyring.algorithm, + 0, + fudge, + b"\x00" * dns.tsig.mac_sizes[self.keyring.algorithm], + original_id, + tsig_error, + other_data, + ) + + @property + def keyname(self) -> Optional[dns.name.Name]: + if self.tsig: + return self.tsig.name + else: + return None + + @property + def keyalgorithm(self) -> Optional[dns.name.Name]: + if self.tsig: + return self.tsig[0].algorithm + else: + return None + + @property + def mac(self) -> Optional[bytes]: + if self.tsig: + return self.tsig[0].mac + else: + return None + + @property + def tsig_error(self) -> Optional[int]: + if self.tsig: + return self.tsig[0].error + else: + return None + + @property + def had_tsig(self) -> bool: + return bool(self.tsig) + + @staticmethod + def _make_opt(flags=0, payload=DEFAULT_EDNS_PAYLOAD, options=None): + opt = dns.rdtypes.ANY.OPT.OPT(payload, dns.rdatatype.OPT, options or ()) + return dns.rrset.from_rdata(dns.name.root, int(flags), opt) + + def use_edns( + self, + edns: Optional[Union[int, bool]] = 0, + ednsflags: int = 0, + payload: int = DEFAULT_EDNS_PAYLOAD, + request_payload: Optional[int] = None, + options: Optional[List[dns.edns.Option]] = None, + pad: int = 0, + ) -> None: + """Configure EDNS behavior. + + *edns*, an ``int``, is the EDNS level to use. Specifying ``None``, ``False``, + or ``-1`` means "do not use EDNS", and in this case the other parameters are + ignored. Specifying ``True`` is equivalent to specifying 0, i.e. "use EDNS0". + + *ednsflags*, an ``int``, the EDNS flag values. + + *payload*, an ``int``, is the EDNS sender's payload field, which is the maximum + size of UDP datagram the sender can handle. I.e. how big a response to this + message can be. + + *request_payload*, an ``int``, is the EDNS payload size to use when sending this + message. If not specified, defaults to the value of *payload*. + + *options*, a list of ``dns.edns.Option`` objects or ``None``, the EDNS options. + + *pad*, a non-negative ``int``. If 0, the default, do not pad; otherwise add + padding bytes to make the message size a multiple of *pad*. Note that if + padding is non-zero, an EDNS PADDING option will always be added to the + message. 
+ """ + + if edns is None or edns is False: + edns = -1 + elif edns is True: + edns = 0 + if edns < 0: + self.opt = None + self.request_payload = 0 + else: + # make sure the EDNS version in ednsflags agrees with edns + ednsflags &= 0xFF00FFFF + ednsflags |= edns << 16 + if options is None: + options = [] + self.opt = self._make_opt(ednsflags, payload, options) + if request_payload is None: + request_payload = payload + self.request_payload = request_payload + if pad < 0: + raise ValueError("pad must be non-negative") + self.pad = pad + + @property + def edns(self) -> int: + if self.opt: + return (self.ednsflags & 0xFF0000) >> 16 + else: + return -1 + + @property + def ednsflags(self) -> int: + if self.opt: + return self.opt.ttl + else: + return 0 + + @ednsflags.setter + def ednsflags(self, v): + if self.opt: + self.opt.ttl = v + elif v: + self.opt = self._make_opt(v) + + @property + def payload(self) -> int: + if self.opt: + return self.opt[0].payload + else: + return 0 + + @property + def options(self) -> Tuple: + if self.opt: + return self.opt[0].options + else: + return () + + def want_dnssec(self, wanted: bool = True) -> None: + """Enable or disable 'DNSSEC desired' flag in requests. + + *wanted*, a ``bool``. If ``True``, then DNSSEC data is + desired in the response, EDNS is enabled if required, and then + the DO bit is set. If ``False``, the DO bit is cleared if + EDNS is enabled. + """ + + if wanted: + self.ednsflags |= dns.flags.DO + elif self.opt: + self.ednsflags &= ~int(dns.flags.DO) + + def rcode(self) -> dns.rcode.Rcode: + """Return the rcode. + + Returns a ``dns.rcode.Rcode``. + """ + return dns.rcode.from_flags(int(self.flags), int(self.ednsflags)) + + def set_rcode(self, rcode: dns.rcode.Rcode) -> None: + """Set the rcode. + + *rcode*, a ``dns.rcode.Rcode``, is the rcode to set. + """ + (value, evalue) = dns.rcode.to_flags(rcode) + self.flags &= 0xFFF0 + self.flags |= value + self.ednsflags &= 0x00FFFFFF + self.ednsflags |= evalue + + def opcode(self) -> dns.opcode.Opcode: + """Return the opcode. + + Returns a ``dns.opcode.Opcode``. + """ + return dns.opcode.from_flags(int(self.flags)) + + def set_opcode(self, opcode: dns.opcode.Opcode) -> None: + """Set the opcode. + + *opcode*, a ``dns.opcode.Opcode``, is the opcode to set. + """ + self.flags &= 0x87FF + self.flags |= dns.opcode.to_flags(opcode) + + def _get_one_rr_per_rrset(self, value): + # What the caller picked is fine. + return value + + # pylint: disable=unused-argument + + def _parse_rr_header(self, section, name, rdclass, rdtype): + return (rdclass, rdtype, None, False) + + # pylint: enable=unused-argument + + def _parse_special_rr_header(self, section, count, position, name, rdclass, rdtype): + if rdtype == dns.rdatatype.OPT: + if ( + section != MessageSection.ADDITIONAL + or self.opt + or name != dns.name.root + ): + raise BadEDNS + elif rdtype == dns.rdatatype.TSIG: + if ( + section != MessageSection.ADDITIONAL + or rdclass != dns.rdatatype.ANY + or position != count - 1 + ): + raise BadTSIG + return (rdclass, rdtype, None, False) + + +class ChainingResult: + """The result of a call to dns.message.QueryMessage.resolve_chaining(). + + The ``answer`` attribute is the answer RRSet, or ``None`` if it doesn't + exist. + + The ``canonical_name`` attribute is the canonical name after all + chaining has been applied (this is the same name as ``rrset.name`` in cases + where rrset is not ``None``). + + The ``minimum_ttl`` attribute is the minimum TTL, i.e. the TTL to + use if caching the data. 
It is the smallest of all the CNAME TTLs + and either the answer TTL if it exists or the SOA TTL and SOA + minimum values for negative answers. + + The ``cnames`` attribute is a list of all the CNAME RRSets followed to + get to the canonical name. + """ + + def __init__( + self, + canonical_name: dns.name.Name, + answer: Optional[dns.rrset.RRset], + minimum_ttl: int, + cnames: List[dns.rrset.RRset], + ): + self.canonical_name = canonical_name + self.answer = answer + self.minimum_ttl = minimum_ttl + self.cnames = cnames + + +class QueryMessage(Message): + def resolve_chaining(self) -> ChainingResult: + """Follow the CNAME chain in the response to determine the answer + RRset. + + Raises ``dns.message.NotQueryResponse`` if the message is not + a response. + + Raises ``dns.message.ChainTooLong`` if the CNAME chain is too long. + + Raises ``dns.message.AnswerForNXDOMAIN`` if the rcode is NXDOMAIN + but an answer was found. + + Raises ``dns.exception.FormError`` if the question count is not 1. + + Returns a ChainingResult object. + """ + if self.flags & dns.flags.QR == 0: + raise NotQueryResponse + if len(self.question) != 1: + raise dns.exception.FormError + question = self.question[0] + qname = question.name + min_ttl = dns.ttl.MAX_TTL + answer = None + count = 0 + cnames = [] + while count < MAX_CHAIN: + try: + answer = self.find_rrset( + self.answer, qname, question.rdclass, question.rdtype + ) + min_ttl = min(min_ttl, answer.ttl) + break + except KeyError: + if question.rdtype != dns.rdatatype.CNAME: + try: + crrset = self.find_rrset( + self.answer, qname, question.rdclass, dns.rdatatype.CNAME + ) + cnames.append(crrset) + min_ttl = min(min_ttl, crrset.ttl) + for rd in crrset: + qname = rd.target + break + count += 1 + continue + except KeyError: + # Exit the chaining loop + break + else: + # Exit the chaining loop + break + if count >= MAX_CHAIN: + raise ChainTooLong + if self.rcode() == dns.rcode.NXDOMAIN and answer is not None: + raise AnswerForNXDOMAIN + if answer is None: + # Further minimize the TTL with NCACHE. + auname = qname + while True: + # Look for an SOA RR whose owner name is a superdomain + # of qname. + try: + srrset = self.find_rrset( + self.authority, auname, question.rdclass, dns.rdatatype.SOA + ) + min_ttl = min(min_ttl, srrset.ttl, srrset[0].minimum) + break + except KeyError: + try: + auname = auname.parent() + except dns.name.NoParent: + break + return ChainingResult(qname, answer, min_ttl, cnames) + + def canonical_name(self) -> dns.name.Name: + """Return the canonical name of the first name in the question + section. + + Raises ``dns.message.NotQueryResponse`` if the message is not + a response. + + Raises ``dns.message.ChainTooLong`` if the CNAME chain is too long. + + Raises ``dns.message.AnswerForNXDOMAIN`` if the rcode is NXDOMAIN + but an answer was found. + + Raises ``dns.exception.FormError`` if the question count is not 1. + """ + return self.resolve_chaining().canonical_name + + +def _maybe_import_update(): + # We avoid circular imports by doing this here. 
We do it in another + # function as doing it in _message_factory_from_opcode() makes "dns" + # a local symbol, and the first line fails :) + + # pylint: disable=redefined-outer-name,import-outside-toplevel,unused-import + import dns.update # noqa: F401 + + +def _message_factory_from_opcode(opcode): + if opcode == dns.opcode.QUERY: + return QueryMessage + elif opcode == dns.opcode.UPDATE: + _maybe_import_update() + return dns.update.UpdateMessage + else: + return Message + + +class _WireReader: + """Wire format reader. + + parser: the binary parser + message: The message object being built + initialize_message: Callback to set message parsing options + question_only: Are we only reading the question? + one_rr_per_rrset: Put each RR into its own RRset? + keyring: TSIG keyring + ignore_trailing: Ignore trailing junk at end of request? + multi: Is this message part of a multi-message sequence? + DNS dynamic updates. + continue_on_error: try to extract as much information as possible from + the message, accumulating MessageErrors in the *errors* attribute instead of + raising them. + """ + + def __init__( + self, + wire, + initialize_message, + question_only=False, + one_rr_per_rrset=False, + ignore_trailing=False, + keyring=None, + multi=False, + continue_on_error=False, + ): + self.parser = dns.wire.Parser(wire) + self.message = None + self.initialize_message = initialize_message + self.question_only = question_only + self.one_rr_per_rrset = one_rr_per_rrset + self.ignore_trailing = ignore_trailing + self.keyring = keyring + self.multi = multi + self.continue_on_error = continue_on_error + self.errors = [] + + def _get_question(self, section_number, qcount): + """Read the next *qcount* records from the wire data and add them to + the question section. + """ + assert self.message is not None + section = self.message.sections[section_number] + for _ in range(qcount): + qname = self.parser.get_name(self.message.origin) + (rdtype, rdclass) = self.parser.get_struct("!HH") + (rdclass, rdtype, _, _) = self.message._parse_rr_header( + section_number, qname, rdclass, rdtype + ) + self.message.find_rrset( + section, qname, rdclass, rdtype, create=True, force_unique=True + ) + + def _add_error(self, e): + self.errors.append(MessageError(e, self.parser.current)) + + def _get_section(self, section_number, count): + """Read the next I{count} records from the wire data and add them to + the specified section. 
+ + section_number: the section of the message to which to add records + count: the number of records to read + """ + assert self.message is not None + section = self.message.sections[section_number] + force_unique = self.one_rr_per_rrset + for i in range(count): + rr_start = self.parser.current + absolute_name = self.parser.get_name() + if self.message.origin is not None: + name = absolute_name.relativize(self.message.origin) + else: + name = absolute_name + (rdtype, rdclass, ttl, rdlen) = self.parser.get_struct("!HHIH") + if rdtype in (dns.rdatatype.OPT, dns.rdatatype.TSIG): + ( + rdclass, + rdtype, + deleting, + empty, + ) = self.message._parse_special_rr_header( + section_number, count, i, name, rdclass, rdtype + ) + else: + (rdclass, rdtype, deleting, empty) = self.message._parse_rr_header( + section_number, name, rdclass, rdtype + ) + rdata_start = self.parser.current + try: + if empty: + if rdlen > 0: + raise dns.exception.FormError + rd = None + covers = dns.rdatatype.NONE + else: + with self.parser.restrict_to(rdlen): + rd = dns.rdata.from_wire_parser( + rdclass, rdtype, self.parser, self.message.origin + ) + covers = rd.covers() + if self.message.xfr and rdtype == dns.rdatatype.SOA: + force_unique = True + if rdtype == dns.rdatatype.OPT: + self.message.opt = dns.rrset.from_rdata(name, ttl, rd) + elif rdtype == dns.rdatatype.TSIG: + if self.keyring is None: + raise UnknownTSIGKey("got signed message without keyring") + if isinstance(self.keyring, dict): + key = self.keyring.get(absolute_name) + if isinstance(key, bytes): + key = dns.tsig.Key(absolute_name, key, rd.algorithm) + elif callable(self.keyring): + key = self.keyring(self.message, absolute_name) + else: + key = self.keyring + if key is None: + raise UnknownTSIGKey("key '%s' unknown" % name) + self.message.keyring = key + self.message.tsig_ctx = dns.tsig.validate( + self.parser.wire, + key, + absolute_name, + rd, + int(time.time()), + self.message.request_mac, + rr_start, + self.message.tsig_ctx, + self.multi, + ) + self.message.tsig = dns.rrset.from_rdata(absolute_name, 0, rd) + else: + rrset = self.message.find_rrset( + section, + name, + rdclass, + rdtype, + covers, + deleting, + True, + force_unique, + ) + if rd is not None: + if ttl > 0x7FFFFFFF: + ttl = 0 + rrset.add(rd, ttl) + except Exception as e: + if self.continue_on_error: + self._add_error(e) + self.parser.seek(rdata_start + rdlen) + else: + raise + + def read(self): + """Read a wire format DNS message and build a dns.message.Message + object.""" + + if self.parser.remaining() < 12: + raise ShortHeader + (id, flags, qcount, ancount, aucount, adcount) = self.parser.get_struct( + "!HHHHHH" + ) + factory = _message_factory_from_opcode(dns.opcode.from_flags(flags)) + self.message = factory(id=id) + self.message.flags = dns.flags.Flag(flags) + self.initialize_message(self.message) + self.one_rr_per_rrset = self.message._get_one_rr_per_rrset( + self.one_rr_per_rrset + ) + try: + self._get_question(MessageSection.QUESTION, qcount) + if self.question_only: + return self.message + self._get_section(MessageSection.ANSWER, ancount) + self._get_section(MessageSection.AUTHORITY, aucount) + self._get_section(MessageSection.ADDITIONAL, adcount) + if not self.ignore_trailing and self.parser.remaining() != 0: + raise TrailingJunk + if self.multi and self.message.tsig_ctx and not self.message.had_tsig: + self.message.tsig_ctx.update(self.parser.wire) + except Exception as e: + if self.continue_on_error: + self._add_error(e) + else: + raise + return self.message + + +def 
from_wire( + wire: bytes, + keyring: Optional[Any] = None, + request_mac: Optional[bytes] = b"", + xfr: bool = False, + origin: Optional[dns.name.Name] = None, + tsig_ctx: Optional[Union[dns.tsig.HMACTSig, dns.tsig.GSSTSig]] = None, + multi: bool = False, + question_only: bool = False, + one_rr_per_rrset: bool = False, + ignore_trailing: bool = False, + raise_on_truncation: bool = False, + continue_on_error: bool = False, +) -> Message: + """Convert a DNS wire format message into a message object. + + *keyring*, a ``dns.tsig.Key`` or ``dict``, the key or keyring to use if the message + is signed. + + *request_mac*, a ``bytes`` or ``None``. If the message is a response to a + TSIG-signed request, *request_mac* should be set to the MAC of that request. + + *xfr*, a ``bool``, should be set to ``True`` if this message is part of a zone + transfer. + + *origin*, a ``dns.name.Name`` or ``None``. If the message is part of a zone + transfer, *origin* should be the origin name of the zone. If not ``None``, names + will be relativized to the origin. + + *tsig_ctx*, a ``dns.tsig.HMACTSig`` or ``dns.tsig.GSSTSig`` object, the ongoing TSIG + context, used when validating zone transfers. + + *multi*, a ``bool``, should be set to ``True`` if this message is part of a multiple + message sequence. + + *question_only*, a ``bool``. If ``True``, read only up to the end of the question + section. + + *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own RRset. + + *ignore_trailing*, a ``bool``. If ``True``, ignore trailing junk at end of the + message. + + *raise_on_truncation*, a ``bool``. If ``True``, raise an exception if the TC bit is + set. + + *continue_on_error*, a ``bool``. If ``True``, try to continue parsing even if + errors occur. Erroneous rdata will be ignored. Errors will be accumulated as a + list of MessageError objects in the message's ``errors`` attribute. This option is + recommended only for DNS analysis tools, or for use in a server as part of an error + handling path. The default is ``False``. + + Raises ``dns.message.ShortHeader`` if the message is less than 12 octets long. + + Raises ``dns.message.TrailingJunk`` if there were octets in the message past the end + of the proper DNS message, and *ignore_trailing* is ``False``. + + Raises ``dns.message.BadEDNS`` if an OPT record was in the wrong section, or + occurred more than once. + + Raises ``dns.message.BadTSIG`` if a TSIG record was not the last record of the + additional data section. + + Raises ``dns.message.Truncated`` if the TC flag is set and *raise_on_truncation* is + ``True``. + + Returns a ``dns.message.Message``. + """ + + # We permit None for request_mac solely for backwards compatibility + if request_mac is None: + request_mac = b"" + + def initialize_message(message): + message.request_mac = request_mac + message.xfr = xfr + message.origin = origin + message.tsig_ctx = tsig_ctx + + reader = _WireReader( + wire, + initialize_message, + question_only, + one_rr_per_rrset, + ignore_trailing, + keyring, + multi, + continue_on_error, + ) + try: + m = reader.read() + except dns.exception.FormError: + if ( + reader.message + and (reader.message.flags & dns.flags.TC) + and raise_on_truncation + ): + raise Truncated(message=reader.message) + else: + raise + # Reading a truncated message might not have any errors, so we + # have to do this check here too. 
+ if m.flags & dns.flags.TC and raise_on_truncation: + raise Truncated(message=m) + if continue_on_error: + m.errors = reader.errors + + return m + + +class _TextReader: + """Text format reader. + + tok: the tokenizer. + message: The message object being built. + DNS dynamic updates. + last_name: The most recently read name when building a message object. + one_rr_per_rrset: Put each RR into its own RRset? + origin: The origin for relative names + relativize: relativize names? + relativize_to: the origin to relativize to. + """ + + def __init__( + self, + text, + idna_codec, + one_rr_per_rrset=False, + origin=None, + relativize=True, + relativize_to=None, + ): + self.message = None + self.tok = dns.tokenizer.Tokenizer(text, idna_codec=idna_codec) + self.last_name = None + self.one_rr_per_rrset = one_rr_per_rrset + self.origin = origin + self.relativize = relativize + self.relativize_to = relativize_to + self.id = None + self.edns = -1 + self.ednsflags = 0 + self.payload = DEFAULT_EDNS_PAYLOAD + self.rcode = None + self.opcode = dns.opcode.QUERY + self.flags = 0 + + def _header_line(self, _): + """Process one line from the text format header section.""" + + token = self.tok.get() + what = token.value + if what == "id": + self.id = self.tok.get_int() + elif what == "flags": + while True: + token = self.tok.get() + if not token.is_identifier(): + self.tok.unget(token) + break + self.flags = self.flags | dns.flags.from_text(token.value) + elif what == "edns": + self.edns = self.tok.get_int() + self.ednsflags = self.ednsflags | (self.edns << 16) + elif what == "eflags": + if self.edns < 0: + self.edns = 0 + while True: + token = self.tok.get() + if not token.is_identifier(): + self.tok.unget(token) + break + self.ednsflags = self.ednsflags | dns.flags.edns_from_text(token.value) + elif what == "payload": + self.payload = self.tok.get_int() + if self.edns < 0: + self.edns = 0 + elif what == "opcode": + text = self.tok.get_string() + self.opcode = dns.opcode.from_text(text) + self.flags = self.flags | dns.opcode.to_flags(self.opcode) + elif what == "rcode": + text = self.tok.get_string() + self.rcode = dns.rcode.from_text(text) + else: + raise UnknownHeaderField + self.tok.get_eol() + + def _question_line(self, section_number): + """Process one line from the text format question section.""" + + section = self.message.sections[section_number] + token = self.tok.get(want_leading=True) + if not token.is_whitespace(): + self.last_name = self.tok.as_name( + token, self.message.origin, self.relativize, self.relativize_to + ) + name = self.last_name + if name is None: + raise NoPreviousName + token = self.tok.get() + if not token.is_identifier(): + raise dns.exception.SyntaxError + # Class + try: + rdclass = dns.rdataclass.from_text(token.value) + token = self.tok.get() + if not token.is_identifier(): + raise dns.exception.SyntaxError + except dns.exception.SyntaxError: + raise dns.exception.SyntaxError + except Exception: + rdclass = dns.rdataclass.IN + # Type + rdtype = dns.rdatatype.from_text(token.value) + (rdclass, rdtype, _, _) = self.message._parse_rr_header( + section_number, name, rdclass, rdtype + ) + self.message.find_rrset( + section, name, rdclass, rdtype, create=True, force_unique=True + ) + self.tok.get_eol() + + def _rr_line(self, section_number): + """Process one line from the text format answer, authority, or + additional data sections. 
+ """ + + section = self.message.sections[section_number] + # Name + token = self.tok.get(want_leading=True) + if not token.is_whitespace(): + self.last_name = self.tok.as_name( + token, self.message.origin, self.relativize, self.relativize_to + ) + name = self.last_name + if name is None: + raise NoPreviousName + token = self.tok.get() + if not token.is_identifier(): + raise dns.exception.SyntaxError + # TTL + try: + ttl = int(token.value, 0) + token = self.tok.get() + if not token.is_identifier(): + raise dns.exception.SyntaxError + except dns.exception.SyntaxError: + raise dns.exception.SyntaxError + except Exception: + ttl = 0 + # Class + try: + rdclass = dns.rdataclass.from_text(token.value) + token = self.tok.get() + if not token.is_identifier(): + raise dns.exception.SyntaxError + except dns.exception.SyntaxError: + raise dns.exception.SyntaxError + except Exception: + rdclass = dns.rdataclass.IN + # Type + rdtype = dns.rdatatype.from_text(token.value) + (rdclass, rdtype, deleting, empty) = self.message._parse_rr_header( + section_number, name, rdclass, rdtype + ) + token = self.tok.get() + if empty and not token.is_eol_or_eof(): + raise dns.exception.SyntaxError + if not empty and token.is_eol_or_eof(): + raise dns.exception.UnexpectedEnd + if not token.is_eol_or_eof(): + self.tok.unget(token) + rd = dns.rdata.from_text( + rdclass, + rdtype, + self.tok, + self.message.origin, + self.relativize, + self.relativize_to, + ) + covers = rd.covers() + else: + rd = None + covers = dns.rdatatype.NONE + rrset = self.message.find_rrset( + section, + name, + rdclass, + rdtype, + covers, + deleting, + True, + self.one_rr_per_rrset, + ) + if rd is not None: + rrset.add(rd, ttl) + + def _make_message(self): + factory = _message_factory_from_opcode(self.opcode) + message = factory(id=self.id) + message.flags = self.flags + if self.edns >= 0: + message.use_edns(self.edns, self.ednsflags, self.payload) + if self.rcode: + message.set_rcode(self.rcode) + if self.origin: + message.origin = self.origin + return message + + def read(self): + """Read a text format DNS message and build a dns.message.Message + object.""" + + line_method = self._header_line + section_number = None + while 1: + token = self.tok.get(True, True) + if token.is_eol_or_eof(): + break + if token.is_comment(): + u = token.value.upper() + if u == "HEADER": + line_method = self._header_line + + if self.message: + message = self.message + else: + # If we don't have a message, create one with the current + # opcode, so that we know which section names to parse. + message = self._make_message() + try: + section_number = message._section_enum.from_text(u) + # We found a section name. If we don't have a message, + # use the one we just created. + if not self.message: + self.message = message + self.one_rr_per_rrset = message._get_one_rr_per_rrset( + self.one_rr_per_rrset + ) + if section_number == MessageSection.QUESTION: + line_method = self._question_line + else: + line_method = self._rr_line + except Exception: + # It's just a comment. + pass + self.tok.get_eol() + continue + self.tok.unget(token) + line_method(section_number) + if not self.message: + self.message = self._make_message() + return self.message + + +def from_text( + text: str, + idna_codec: Optional[dns.name.IDNACodec] = None, + one_rr_per_rrset: bool = False, + origin: Optional[dns.name.Name] = None, + relativize: bool = True, + relativize_to: Optional[dns.name.Name] = None, +) -> Message: + """Convert the text format message into a message object. 
+ + The reader stops after reading the first blank line in the input to + facilitate reading multiple messages from a single file with + ``dns.message.from_file()``. + + *text*, a ``str``, the text format message. + + *idna_codec*, a ``dns.name.IDNACodec``, specifies the IDNA + encoder/decoder. If ``None``, the default IDNA 2003 encoder/decoder + is used. + + *one_rr_per_rrset*, a ``bool``. If ``True``, then each RR is put + into its own rrset. The default is ``False``. + + *origin*, a ``dns.name.Name`` (or ``None``), the + origin to use for relative names. + + *relativize*, a ``bool``. If true, name will be relativized. + + *relativize_to*, a ``dns.name.Name`` (or ``None``), the origin to use + when relativizing names. If not set, the *origin* value will be used. + + Raises ``dns.message.UnknownHeaderField`` if a header is unknown. + + Raises ``dns.exception.SyntaxError`` if the text is badly formed. + + Returns a ``dns.message.Message object`` + """ + + # 'text' can also be a file, but we don't publish that fact + # since it's an implementation detail. The official file + # interface is from_file(). + + reader = _TextReader( + text, idna_codec, one_rr_per_rrset, origin, relativize, relativize_to + ) + return reader.read() + + +def from_file( + f: Any, + idna_codec: Optional[dns.name.IDNACodec] = None, + one_rr_per_rrset: bool = False, +) -> Message: + """Read the next text format message from the specified file. + + Message blocks are separated by a single blank line. + + *f*, a ``file`` or ``str``. If *f* is text, it is treated as the + pathname of a file to open. + + *idna_codec*, a ``dns.name.IDNACodec``, specifies the IDNA + encoder/decoder. If ``None``, the default IDNA 2003 encoder/decoder + is used. + + *one_rr_per_rrset*, a ``bool``. If ``True``, then each RR is put + into its own rrset. The default is ``False``. + + Raises ``dns.message.UnknownHeaderField`` if a header is unknown. + + Raises ``dns.exception.SyntaxError`` if the text is badly formed. + + Returns a ``dns.message.Message object`` + """ + + if isinstance(f, str): + cm: contextlib.AbstractContextManager = open(f) + else: + cm = contextlib.nullcontext(f) + with cm as f: + return from_text(f, idna_codec, one_rr_per_rrset) + assert False # for mypy lgtm[py/unreachable-statement] + + +def make_query( + qname: Union[dns.name.Name, str], + rdtype: Union[dns.rdatatype.RdataType, str], + rdclass: Union[dns.rdataclass.RdataClass, str] = dns.rdataclass.IN, + use_edns: Optional[Union[int, bool]] = None, + want_dnssec: bool = False, + ednsflags: Optional[int] = None, + payload: Optional[int] = None, + request_payload: Optional[int] = None, + options: Optional[List[dns.edns.Option]] = None, + idna_codec: Optional[dns.name.IDNACodec] = None, + id: Optional[int] = None, + flags: int = dns.flags.RD, + pad: int = 0, +) -> QueryMessage: + """Make a query message. + + The query name, type, and class may all be specified either + as objects of the appropriate type, or as strings. + + The query will have a randomly chosen query id, and its DNS flags + will be set to dns.flags.RD. + + qname, a ``dns.name.Name`` or ``str``, the query name. + + *rdtype*, an ``int`` or ``str``, the desired rdata type. + + *rdclass*, an ``int`` or ``str``, the desired rdata class; the default + is class IN. + + *use_edns*, an ``int``, ``bool`` or ``None``. The EDNS level to use; the + default is ``None``. If ``None``, EDNS will be enabled only if other + parameters (*ednsflags*, *payload*, *request_payload*, or *options*) are + set. 
+ See the description of dns.message.Message.use_edns() for the possible + values for use_edns and their meanings. + + *want_dnssec*, a ``bool``. If ``True``, DNSSEC data is desired. + + *ednsflags*, an ``int``, the EDNS flag values. + + *payload*, an ``int``, is the EDNS sender's payload field, which is the + maximum size of UDP datagram the sender can handle. I.e. how big + a response to this message can be. + + *request_payload*, an ``int``, is the EDNS payload size to use when + sending this message. If not specified, defaults to the value of + *payload*. + + *options*, a list of ``dns.edns.Option`` objects or ``None``, the EDNS + options. + + *idna_codec*, a ``dns.name.IDNACodec``, specifies the IDNA + encoder/decoder. If ``None``, the default IDNA 2003 encoder/decoder + is used. + + *id*, an ``int`` or ``None``, the desired query id. The default is + ``None``, which generates a random query id. + + *flags*, an ``int``, the desired query flags. The default is + ``dns.flags.RD``. + + *pad*, a non-negative ``int``. If 0, the default, do not pad; otherwise add + padding bytes to make the message size a multiple of *pad*. Note that if + padding is non-zero, an EDNS PADDING option will always be added to the + message. + + Returns a ``dns.message.QueryMessage`` + """ + + if isinstance(qname, str): + qname = dns.name.from_text(qname, idna_codec=idna_codec) + rdtype = dns.rdatatype.RdataType.make(rdtype) + rdclass = dns.rdataclass.RdataClass.make(rdclass) + m = QueryMessage(id=id) + m.flags = dns.flags.Flag(flags) + m.find_rrset(m.question, qname, rdclass, rdtype, create=True, force_unique=True) + # only pass keywords on to use_edns if they have been set to a + # non-None value. Setting a field will turn EDNS on if it hasn't + # been configured. + kwargs: Dict[str, Any] = {} + if ednsflags is not None: + kwargs["ednsflags"] = ednsflags + if payload is not None: + kwargs["payload"] = payload + if request_payload is not None: + kwargs["request_payload"] = request_payload + if options is not None: + kwargs["options"] = options + if kwargs and use_edns is None: + use_edns = 0 + kwargs["edns"] = use_edns + kwargs["pad"] = pad + m.use_edns(**kwargs) + m.want_dnssec(want_dnssec) + return m + + +def make_response( + query: Message, + recursion_available: bool = False, + our_payload: int = 8192, + fudge: int = 300, + tsig_error: int = 0, + pad: Optional[int] = None, +) -> Message: + """Make a message which is a response for the specified query. + The message returned is really a response skeleton; it has all of the infrastructure + required of a response, but none of the content. + + The response's question section is a shallow copy of the query's question section, + so the query's question RRsets should not be changed. + + *query*, a ``dns.message.Message``, the query to respond to. + + *recursion_available*, a ``bool``, should RA be set in the response? + + *our_payload*, an ``int``, the payload size to advertise in EDNS responses. + + *fudge*, an ``int``, the TSIG time fudge. + + *tsig_error*, an ``int``, the TSIG error. + + *pad*, a non-negative ``int`` or ``None``. If 0, the default, do not pad; otherwise + if not ``None`` add padding bytes to make the message size a multiple of *pad*. + Note that if padding is non-zero, an EDNS PADDING option will always be added to the + message. If ``None``, add padding following RFC 8467, namely if the request is + padded, pad the response to 468 otherwise do not pad. 
+ + Returns a ``dns.message.Message`` object whose specific class is appropriate for the + query. For example, if query is a ``dns.update.UpdateMessage``, response will be + too. + """ + + if query.flags & dns.flags.QR: + raise dns.exception.FormError("specified query message is not a query") + factory = _message_factory_from_opcode(query.opcode()) + response = factory(id=query.id) + response.flags = dns.flags.QR | (query.flags & dns.flags.RD) + if recursion_available: + response.flags |= dns.flags.RA + response.set_opcode(query.opcode()) + response.question = list(query.question) + if query.edns >= 0: + if pad is None: + # Set response padding per RFC 8467 + pad = 0 + for option in query.options: + if option.otype == dns.edns.OptionType.PADDING: + pad = 468 + response.use_edns(0, 0, our_payload, query.payload, pad=pad) + if query.had_tsig: + response.use_tsig( + query.keyring, + query.keyname, + fudge, + None, + tsig_error, + b"", + query.keyalgorithm, + ) + response.request_mac = query.mac + return response + + +### BEGIN generated MessageSection constants + +QUESTION = MessageSection.QUESTION +ANSWER = MessageSection.ANSWER +AUTHORITY = MessageSection.AUTHORITY +ADDITIONAL = MessageSection.ADDITIONAL + +### END generated MessageSection constants diff --git a/venv/Lib/site-packages/dns/name.py b/venv/Lib/site-packages/dns/name.py new file mode 100644 index 00000000..22ccb392 --- /dev/null +++ b/venv/Lib/site-packages/dns/name.py @@ -0,0 +1,1283 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2001-2017 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +"""DNS Names. +""" + +import copy +import encodings.idna # type: ignore +import functools +import struct +from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union + +import dns._features +import dns.enum +import dns.exception +import dns.immutable +import dns.wire + +if dns._features.have("idna"): + import idna # type: ignore + + have_idna_2008 = True +else: # pragma: no cover + have_idna_2008 = False + +CompressType = Dict["Name", int] + + +class NameRelation(dns.enum.IntEnum): + """Name relation result from fullcompare().""" + + # This is an IntEnum for backwards compatibility in case anyone + # has hardwired the constants. + + #: The compared names have no relationship to each other. + NONE = 0 + #: the first name is a superdomain of the second. + SUPERDOMAIN = 1 + #: The first name is a subdomain of the second. + SUBDOMAIN = 2 + #: The compared names are equal. + EQUAL = 3 + #: The compared names have a common ancestor. 
+ COMMONANCESTOR = 4 + + @classmethod + def _maximum(cls): + return cls.COMMONANCESTOR + + @classmethod + def _short_name(cls): + return cls.__name__ + + +# Backwards compatibility +NAMERELN_NONE = NameRelation.NONE +NAMERELN_SUPERDOMAIN = NameRelation.SUPERDOMAIN +NAMERELN_SUBDOMAIN = NameRelation.SUBDOMAIN +NAMERELN_EQUAL = NameRelation.EQUAL +NAMERELN_COMMONANCESTOR = NameRelation.COMMONANCESTOR + + +class EmptyLabel(dns.exception.SyntaxError): + """A DNS label is empty.""" + + +class BadEscape(dns.exception.SyntaxError): + """An escaped code in a text format of DNS name is invalid.""" + + +class BadPointer(dns.exception.FormError): + """A DNS compression pointer points forward instead of backward.""" + + +class BadLabelType(dns.exception.FormError): + """The label type in DNS name wire format is unknown.""" + + +class NeedAbsoluteNameOrOrigin(dns.exception.DNSException): + """An attempt was made to convert a non-absolute name to + wire when there was also a non-absolute (or missing) origin.""" + + +class NameTooLong(dns.exception.FormError): + """A DNS name is > 255 octets long.""" + + +class LabelTooLong(dns.exception.SyntaxError): + """A DNS label is > 63 octets long.""" + + +class AbsoluteConcatenation(dns.exception.DNSException): + """An attempt was made to append anything other than the + empty name to an absolute DNS name.""" + + +class NoParent(dns.exception.DNSException): + """An attempt was made to get the parent of the root name + or the empty name.""" + + +class NoIDNA2008(dns.exception.DNSException): + """IDNA 2008 processing was requested but the idna module is not + available.""" + + +class IDNAException(dns.exception.DNSException): + """IDNA processing raised an exception.""" + + supp_kwargs = {"idna_exception"} + fmt = "IDNA processing exception: {idna_exception}" + + # We do this as otherwise mypy complains about unexpected keyword argument + # idna_exception + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + +class NeedSubdomainOfOrigin(dns.exception.DNSException): + """An absolute name was provided that is not a subdomain of the specified origin.""" + + +_escaped = b'"().;\\@$' +_escaped_text = '"().;\\@$' + + +def _escapify(label: Union[bytes, str]) -> str: + """Escape the characters in label which need it. + @returns: the escaped string + @rtype: string""" + if isinstance(label, bytes): + # Ordinary DNS label mode. Escape special characters and values + # < 0x20 or > 0x7f. + text = "" + for c in label: + if c in _escaped: + text += "\\" + chr(c) + elif c > 0x20 and c < 0x7F: + text += chr(c) + else: + text += "\\%03d" % c + return text + + # Unicode label mode. Escape only special characters and values < 0x20 + text = "" + for uc in label: + if uc in _escaped_text: + text += "\\" + uc + elif uc <= "\x20": + text += "\\%03d" % ord(uc) + else: + text += uc + return text + + +class IDNACodec: + """Abstract base class for IDNA encoder/decoders.""" + + def __init__(self): + pass + + def is_idna(self, label: bytes) -> bool: + return label.lower().startswith(b"xn--") + + def encode(self, label: str) -> bytes: + raise NotImplementedError # pragma: no cover + + def decode(self, label: bytes) -> str: + # We do not apply any IDNA policy on decode. 
+ if self.is_idna(label): + try: + slabel = label[4:].decode("punycode") + return _escapify(slabel) + except Exception as e: + raise IDNAException(idna_exception=e) + else: + return _escapify(label) + + +class IDNA2003Codec(IDNACodec): + """IDNA 2003 encoder/decoder.""" + + def __init__(self, strict_decode: bool = False): + """Initialize the IDNA 2003 encoder/decoder. + + *strict_decode* is a ``bool``. If `True`, then IDNA2003 checking + is done when decoding. This can cause failures if the name + was encoded with IDNA2008. The default is `False`. + """ + + super().__init__() + self.strict_decode = strict_decode + + def encode(self, label: str) -> bytes: + """Encode *label*.""" + + if label == "": + return b"" + try: + return encodings.idna.ToASCII(label) + except UnicodeError: + raise LabelTooLong + + def decode(self, label: bytes) -> str: + """Decode *label*.""" + if not self.strict_decode: + return super().decode(label) + if label == b"": + return "" + try: + return _escapify(encodings.idna.ToUnicode(label)) + except Exception as e: + raise IDNAException(idna_exception=e) + + +class IDNA2008Codec(IDNACodec): + """IDNA 2008 encoder/decoder.""" + + def __init__( + self, + uts_46: bool = False, + transitional: bool = False, + allow_pure_ascii: bool = False, + strict_decode: bool = False, + ): + """Initialize the IDNA 2008 encoder/decoder. + + *uts_46* is a ``bool``. If True, apply Unicode IDNA + compatibility processing as described in Unicode Technical + Standard #46 (https://unicode.org/reports/tr46/). + If False, do not apply the mapping. The default is False. + + *transitional* is a ``bool``: If True, use the + "transitional" mode described in Unicode Technical Standard + #46. The default is False. + + *allow_pure_ascii* is a ``bool``. If True, then a label which + consists of only ASCII characters is allowed. This is less + strict than regular IDNA 2008, but is also necessary for mixed + names, e.g. a name with starting with "_sip._tcp." and ending + in an IDN suffix which would otherwise be disallowed. The + default is False. + + *strict_decode* is a ``bool``: If True, then IDNA2008 checking + is done when decoding. This can cause failures if the name + was encoded with IDNA2003. The default is False. 
+ """ + super().__init__() + self.uts_46 = uts_46 + self.transitional = transitional + self.allow_pure_ascii = allow_pure_ascii + self.strict_decode = strict_decode + + def encode(self, label: str) -> bytes: + if label == "": + return b"" + if self.allow_pure_ascii and is_all_ascii(label): + encoded = label.encode("ascii") + if len(encoded) > 63: + raise LabelTooLong + return encoded + if not have_idna_2008: + raise NoIDNA2008 + try: + if self.uts_46: + label = idna.uts46_remap(label, False, self.transitional) + return idna.alabel(label) + except idna.IDNAError as e: + if e.args[0] == "Label too long": + raise LabelTooLong + else: + raise IDNAException(idna_exception=e) + + def decode(self, label: bytes) -> str: + if not self.strict_decode: + return super().decode(label) + if label == b"": + return "" + if not have_idna_2008: + raise NoIDNA2008 + try: + ulabel = idna.ulabel(label) + if self.uts_46: + ulabel = idna.uts46_remap(ulabel, False, self.transitional) + return _escapify(ulabel) + except (idna.IDNAError, UnicodeError) as e: + raise IDNAException(idna_exception=e) + + +IDNA_2003_Practical = IDNA2003Codec(False) +IDNA_2003_Strict = IDNA2003Codec(True) +IDNA_2003 = IDNA_2003_Practical +IDNA_2008_Practical = IDNA2008Codec(True, False, True, False) +IDNA_2008_UTS_46 = IDNA2008Codec(True, False, False, False) +IDNA_2008_Strict = IDNA2008Codec(False, False, False, True) +IDNA_2008_Transitional = IDNA2008Codec(True, True, False, False) +IDNA_2008 = IDNA_2008_Practical + + +def _validate_labels(labels: Tuple[bytes, ...]) -> None: + """Check for empty labels in the middle of a label sequence, + labels that are too long, and for too many labels. + + Raises ``dns.name.NameTooLong`` if the name as a whole is too long. + + Raises ``dns.name.EmptyLabel`` if a label is empty (i.e. the root + label) and appears in a position other than the end of the label + sequence + + """ + + l = len(labels) + total = 0 + i = -1 + j = 0 + for label in labels: + ll = len(label) + total += ll + 1 + if ll > 63: + raise LabelTooLong + if i < 0 and label == b"": + i = j + j += 1 + if total > 255: + raise NameTooLong + if i >= 0 and i != l - 1: + raise EmptyLabel + + +def _maybe_convert_to_binary(label: Union[bytes, str]) -> bytes: + """If label is ``str``, convert it to ``bytes``. If it is already + ``bytes`` just return it. + + """ + + if isinstance(label, bytes): + return label + if isinstance(label, str): + return label.encode() + raise ValueError # pragma: no cover + + +@dns.immutable.immutable +class Name: + """A DNS name. + + The dns.name.Name class represents a DNS name as a tuple of + labels. Each label is a ``bytes`` in DNS wire format. Instances + of the class are immutable. + """ + + __slots__ = ["labels"] + + def __init__(self, labels: Iterable[Union[bytes, str]]): + """*labels* is any iterable whose values are ``str`` or ``bytes``.""" + + blabels = [_maybe_convert_to_binary(x) for x in labels] + self.labels = tuple(blabels) + _validate_labels(self.labels) + + def __copy__(self): + return Name(self.labels) + + def __deepcopy__(self, memo): + return Name(copy.deepcopy(self.labels, memo)) + + def __getstate__(self): + # Names can be pickled + return {"labels": self.labels} + + def __setstate__(self, state): + super().__setattr__("labels", state["labels"]) + _validate_labels(self.labels) + + def is_absolute(self) -> bool: + """Is the most significant label of this name the root label? + + Returns a ``bool``. 
+ """ + + return len(self.labels) > 0 and self.labels[-1] == b"" + + def is_wild(self) -> bool: + """Is this name wild? (I.e. Is the least significant label '*'?) + + Returns a ``bool``. + """ + + return len(self.labels) > 0 and self.labels[0] == b"*" + + def __hash__(self) -> int: + """Return a case-insensitive hash of the name. + + Returns an ``int``. + """ + + h = 0 + for label in self.labels: + for c in label.lower(): + h += (h << 3) + c + return h + + def fullcompare(self, other: "Name") -> Tuple[NameRelation, int, int]: + """Compare two names, returning a 3-tuple + ``(relation, order, nlabels)``. + + *relation* describes the relation ship between the names, + and is one of: ``dns.name.NameRelation.NONE``, + ``dns.name.NameRelation.SUPERDOMAIN``, ``dns.name.NameRelation.SUBDOMAIN``, + ``dns.name.NameRelation.EQUAL``, or ``dns.name.NameRelation.COMMONANCESTOR``. + + *order* is < 0 if *self* < *other*, > 0 if *self* > *other*, and == + 0 if *self* == *other*. A relative name is always less than an + absolute name. If both names have the same relativity, then + the DNSSEC order relation is used to order them. + + *nlabels* is the number of significant labels that the two names + have in common. + + Here are some examples. Names ending in "." are absolute names, + those not ending in "." are relative names. + + ============= ============= =========== ===== ======= + self other relation order nlabels + ============= ============= =========== ===== ======= + www.example. www.example. equal 0 3 + www.example. example. subdomain > 0 2 + example. www.example. superdomain < 0 2 + example1.com. example2.com. common anc. < 0 2 + example1 example2. none < 0 0 + example1. example2 none > 0 0 + ============= ============= =========== ===== ======= + """ + + sabs = self.is_absolute() + oabs = other.is_absolute() + if sabs != oabs: + if sabs: + return (NameRelation.NONE, 1, 0) + else: + return (NameRelation.NONE, -1, 0) + l1 = len(self.labels) + l2 = len(other.labels) + ldiff = l1 - l2 + if ldiff < 0: + l = l1 + else: + l = l2 + + order = 0 + nlabels = 0 + namereln = NameRelation.NONE + while l > 0: + l -= 1 + l1 -= 1 + l2 -= 1 + label1 = self.labels[l1].lower() + label2 = other.labels[l2].lower() + if label1 < label2: + order = -1 + if nlabels > 0: + namereln = NameRelation.COMMONANCESTOR + return (namereln, order, nlabels) + elif label1 > label2: + order = 1 + if nlabels > 0: + namereln = NameRelation.COMMONANCESTOR + return (namereln, order, nlabels) + nlabels += 1 + order = ldiff + if ldiff < 0: + namereln = NameRelation.SUPERDOMAIN + elif ldiff > 0: + namereln = NameRelation.SUBDOMAIN + else: + namereln = NameRelation.EQUAL + return (namereln, order, nlabels) + + def is_subdomain(self, other: "Name") -> bool: + """Is self a subdomain of other? + + Note that the notion of subdomain includes equality, e.g. + "dnspython.org" is a subdomain of itself. + + Returns a ``bool``. + """ + + (nr, _, _) = self.fullcompare(other) + if nr == NameRelation.SUBDOMAIN or nr == NameRelation.EQUAL: + return True + return False + + def is_superdomain(self, other: "Name") -> bool: + """Is self a superdomain of other? + + Note that the notion of superdomain includes equality, e.g. + "dnspython.org" is a superdomain of itself. + + Returns a ``bool``. + """ + + (nr, _, _) = self.fullcompare(other) + if nr == NameRelation.SUPERDOMAIN or nr == NameRelation.EQUAL: + return True + return False + + def canonicalize(self) -> "Name": + """Return a name which is equal to the current name, but is in + DNSSEC canonical form. 
+ """ + + return Name([x.lower() for x in self.labels]) + + def __eq__(self, other): + if isinstance(other, Name): + return self.fullcompare(other)[1] == 0 + else: + return False + + def __ne__(self, other): + if isinstance(other, Name): + return self.fullcompare(other)[1] != 0 + else: + return True + + def __lt__(self, other): + if isinstance(other, Name): + return self.fullcompare(other)[1] < 0 + else: + return NotImplemented + + def __le__(self, other): + if isinstance(other, Name): + return self.fullcompare(other)[1] <= 0 + else: + return NotImplemented + + def __ge__(self, other): + if isinstance(other, Name): + return self.fullcompare(other)[1] >= 0 + else: + return NotImplemented + + def __gt__(self, other): + if isinstance(other, Name): + return self.fullcompare(other)[1] > 0 + else: + return NotImplemented + + def __repr__(self): + return "" + + def __str__(self): + return self.to_text(False) + + def to_text(self, omit_final_dot: bool = False) -> str: + """Convert name to DNS text format. + + *omit_final_dot* is a ``bool``. If True, don't emit the final + dot (denoting the root label) for absolute names. The default + is False. + + Returns a ``str``. + """ + + if len(self.labels) == 0: + return "@" + if len(self.labels) == 1 and self.labels[0] == b"": + return "." + if omit_final_dot and self.is_absolute(): + l = self.labels[:-1] + else: + l = self.labels + s = ".".join(map(_escapify, l)) + return s + + def to_unicode( + self, omit_final_dot: bool = False, idna_codec: Optional[IDNACodec] = None + ) -> str: + """Convert name to Unicode text format. + + IDN ACE labels are converted to Unicode. + + *omit_final_dot* is a ``bool``. If True, don't emit the final + dot (denoting the root label) for absolute names. The default + is False. + *idna_codec* specifies the IDNA encoder/decoder. If None, the + dns.name.IDNA_2003_Practical encoder/decoder is used. + The IDNA_2003_Practical decoder does + not impose any policy, it just decodes punycode, so if you + don't want checking for compliance, you can use this decoder + for IDNA2008 as well. + + Returns a ``str``. + """ + + if len(self.labels) == 0: + return "@" + if len(self.labels) == 1 and self.labels[0] == b"": + return "." + if omit_final_dot and self.is_absolute(): + l = self.labels[:-1] + else: + l = self.labels + if idna_codec is None: + idna_codec = IDNA_2003_Practical + return ".".join([idna_codec.decode(x) for x in l]) + + def to_digestable(self, origin: Optional["Name"] = None) -> bytes: + """Convert name to a format suitable for digesting in hashes. + + The name is canonicalized and converted to uncompressed wire + format. All names in wire format are absolute. If the name + is a relative name, then an origin must be supplied. + + *origin* is a ``dns.name.Name`` or ``None``. If the name is + relative and origin is not ``None``, then origin will be appended + to the name. + + Raises ``dns.name.NeedAbsoluteNameOrOrigin`` if the name is + relative and no origin was provided. + + Returns a ``bytes``. + """ + + digest = self.to_wire(origin=origin, canonicalize=True) + assert digest is not None + return digest + + def to_wire( + self, + file: Optional[Any] = None, + compress: Optional[CompressType] = None, + origin: Optional["Name"] = None, + canonicalize: bool = False, + ) -> Optional[bytes]: + """Convert name to wire format, possibly compressing it. + + *file* is the file where the name is emitted (typically an + io.BytesIO file). If ``None`` (the default), a ``bytes`` + containing the wire name will be returned. 
+ + *compress*, a ``dict``, is the compression table to use. If + ``None`` (the default), names will not be compressed. Note that + the compression code assumes that compression offset 0 is the + start of *file*, and thus compression will not be correct + if this is not the case. + + *origin* is a ``dns.name.Name`` or ``None``. If the name is + relative and origin is not ``None``, then *origin* will be appended + to it. + + *canonicalize*, a ``bool``, indicates whether the name should + be canonicalized; that is, converted to a format suitable for + digesting in hashes. + + Raises ``dns.name.NeedAbsoluteNameOrOrigin`` if the name is + relative and no origin was provided. + + Returns a ``bytes`` or ``None``. + """ + + if file is None: + out = bytearray() + for label in self.labels: + out.append(len(label)) + if canonicalize: + out += label.lower() + else: + out += label + if not self.is_absolute(): + if origin is None or not origin.is_absolute(): + raise NeedAbsoluteNameOrOrigin + for label in origin.labels: + out.append(len(label)) + if canonicalize: + out += label.lower() + else: + out += label + return bytes(out) + + labels: Iterable[bytes] + if not self.is_absolute(): + if origin is None or not origin.is_absolute(): + raise NeedAbsoluteNameOrOrigin + labels = list(self.labels) + labels.extend(list(origin.labels)) + else: + labels = self.labels + i = 0 + for label in labels: + n = Name(labels[i:]) + i += 1 + if compress is not None: + pos = compress.get(n) + else: + pos = None + if pos is not None: + value = 0xC000 + pos + s = struct.pack("!H", value) + file.write(s) + break + else: + if compress is not None and len(n) > 1: + pos = file.tell() + if pos <= 0x3FFF: + compress[n] = pos + l = len(label) + file.write(struct.pack("!B", l)) + if l > 0: + if canonicalize: + file.write(label.lower()) + else: + file.write(label) + return None + + def __len__(self) -> int: + """The length of the name (in labels). + + Returns an ``int``. + """ + + return len(self.labels) + + def __getitem__(self, index): + return self.labels[index] + + def __add__(self, other): + return self.concatenate(other) + + def __sub__(self, other): + return self.relativize(other) + + def split(self, depth: int) -> Tuple["Name", "Name"]: + """Split a name into a prefix and suffix names at the specified depth. + + *depth* is an ``int`` specifying the number of labels in the suffix + + Raises ``ValueError`` if *depth* was not >= 0 and <= the length of the + name. + + Returns the tuple ``(prefix, suffix)``. + """ + + l = len(self.labels) + if depth == 0: + return (self, dns.name.empty) + elif depth == l: + return (dns.name.empty, self) + elif depth < 0 or depth > l: + raise ValueError("depth must be >= 0 and <= the length of the name") + return (Name(self[:-depth]), Name(self[-depth:])) + + def concatenate(self, other: "Name") -> "Name": + """Return a new name which is the concatenation of self and other. + + Raises ``dns.name.AbsoluteConcatenation`` if the name is + absolute and *other* is not the empty name. + + Returns a ``dns.name.Name``. + """ + + if self.is_absolute() and len(other) > 0: + raise AbsoluteConcatenation + labels = list(self.labels) + labels.extend(list(other.labels)) + return Name(labels) + + def relativize(self, origin: "Name") -> "Name": + """If the name is a subdomain of *origin*, return a new name which is + the name relative to origin. Otherwise return the name. + + For example, relativizing ``www.dnspython.org.`` to origin + ``dnspython.org.`` returns the name ``www``. 
Relativizing ``example.``
+        to origin ``dnspython.org.`` returns ``example.``.
+
+        Returns a ``dns.name.Name``.
+        """
+
+        if origin is not None and self.is_subdomain(origin):
+            return Name(self[: -len(origin)])
+        else:
+            return self
+
+    def derelativize(self, origin: "Name") -> "Name":
+        """If the name is a relative name, return a new name which is the
+        concatenation of the name and origin.  Otherwise return the name.
+
+        For example, derelativizing ``www`` to origin ``dnspython.org.``
+        returns the name ``www.dnspython.org.``.  Derelativizing ``example.``
+        to origin ``dnspython.org.`` returns ``example.``.
+
+        Returns a ``dns.name.Name``.
+        """
+
+        if not self.is_absolute():
+            return self.concatenate(origin)
+        else:
+            return self
+
+    def choose_relativity(
+        self, origin: Optional["Name"] = None, relativize: bool = True
+    ) -> "Name":
+        """Return a name with the relativity desired by the caller.
+
+        If *origin* is ``None``, then the name is returned.
+        Otherwise, if *relativize* is ``True`` the name is
+        relativized, and if *relativize* is ``False`` the name is
+        derelativized.
+
+        Returns a ``dns.name.Name``.
+        """
+
+        if origin:
+            if relativize:
+                return self.relativize(origin)
+            else:
+                return self.derelativize(origin)
+        else:
+            return self
+
+    def parent(self) -> "Name":
+        """Return the parent of the name.
+
+        For example, the parent of ``www.dnspython.org.`` is ``dnspython.org.``.
+
+        Raises ``dns.name.NoParent`` if the name is either the root name or the
+        empty name, and thus has no parent.
+
+        Returns a ``dns.name.Name``.
+        """
+
+        if self == root or self == empty:
+            raise NoParent
+        return Name(self.labels[1:])
+
+    def predecessor(self, origin: "Name", prefix_ok: bool = True) -> "Name":
+        """Return the maximal predecessor of *name* in the DNSSEC ordering in the zone
+        whose origin is *origin*, or return the longest name under *origin* if the
+        name is *origin* (i.e. wrap around to the longest name, which may still be
+        *origin* due to length considerations).
+
+        The relativity of the name is preserved, so if this name is relative
+        then the method will return a relative name, and likewise if this name
+        is absolute then the predecessor will be absolute.
+
+        *prefix_ok* indicates if prefixing labels is allowed, and
+        defaults to ``True``.  Normally it is good to allow this, but if computing
+        a maximal predecessor at a zone cut point then ``False`` must be specified.
+        """
+        return _handle_relativity_and_call(
+            _absolute_predecessor, self, origin, prefix_ok
+        )
+
+    def successor(self, origin: "Name", prefix_ok: bool = True) -> "Name":
+        """Return the minimal successor of *name* in the DNSSEC ordering in the zone
+        whose origin is *origin*, or return *origin* if the successor cannot be
+        computed due to name length limitations.
+
+        Note that *origin* is returned in the "too long" cases because wrapping
+        around to the origin is how NSEC records express "end of the zone".
+
+        The relativity of the name is preserved, so if this name is relative
+        then the method will return a relative name, and likewise if this name
+        is absolute then the successor will be absolute.
+
+        *prefix_ok* indicates if prefixing a new minimal label is allowed, and
+        defaults to ``True``.  Normally it is good to allow this, but if computing
+        a minimal successor at a zone cut point then ``False`` must be specified.
+        """
+        return _handle_relativity_and_call(_absolute_successor, self, origin, prefix_ok)
+
+
+#: The root name, '.'
+root = Name([b""])
+
+#: The empty name.
+empty = Name([]) + + +def from_unicode( + text: str, origin: Optional[Name] = root, idna_codec: Optional[IDNACodec] = None +) -> Name: + """Convert unicode text into a Name object. + + Labels are encoded in IDN ACE form according to rules specified by + the IDNA codec. + + *text*, a ``str``, is the text to convert into a name. + + *origin*, a ``dns.name.Name``, specifies the origin to + append to non-absolute names. The default is the root name. + + *idna_codec*, a ``dns.name.IDNACodec``, specifies the IDNA + encoder/decoder. If ``None``, the default IDNA 2003 encoder/decoder + is used. + + Returns a ``dns.name.Name``. + """ + + if not isinstance(text, str): + raise ValueError("input to from_unicode() must be a unicode string") + if not (origin is None or isinstance(origin, Name)): + raise ValueError("origin must be a Name or None") + labels = [] + label = "" + escaping = False + edigits = 0 + total = 0 + if idna_codec is None: + idna_codec = IDNA_2003 + if text == "@": + text = "" + if text: + if text in [".", "\u3002", "\uff0e", "\uff61"]: + return Name([b""]) # no Unicode "u" on this constant! + for c in text: + if escaping: + if edigits == 0: + if c.isdigit(): + total = int(c) + edigits += 1 + else: + label += c + escaping = False + else: + if not c.isdigit(): + raise BadEscape + total *= 10 + total += int(c) + edigits += 1 + if edigits == 3: + escaping = False + label += chr(total) + elif c in [".", "\u3002", "\uff0e", "\uff61"]: + if len(label) == 0: + raise EmptyLabel + labels.append(idna_codec.encode(label)) + label = "" + elif c == "\\": + escaping = True + edigits = 0 + total = 0 + else: + label += c + if escaping: + raise BadEscape + if len(label) > 0: + labels.append(idna_codec.encode(label)) + else: + labels.append(b"") + + if (len(labels) == 0 or labels[-1] != b"") and origin is not None: + labels.extend(list(origin.labels)) + return Name(labels) + + +def is_all_ascii(text: str) -> bool: + for c in text: + if ord(c) > 0x7F: + return False + return True + + +def from_text( + text: Union[bytes, str], + origin: Optional[Name] = root, + idna_codec: Optional[IDNACodec] = None, +) -> Name: + """Convert text into a Name object. + + *text*, a ``bytes`` or ``str``, is the text to convert into a name. + + *origin*, a ``dns.name.Name``, specifies the origin to + append to non-absolute names. The default is the root name. + + *idna_codec*, a ``dns.name.IDNACodec``, specifies the IDNA + encoder/decoder. If ``None``, the default IDNA 2003 encoder/decoder + is used. + + Returns a ``dns.name.Name``. + """ + + if isinstance(text, str): + if not is_all_ascii(text): + # Some codepoint in the input text is > 127, so IDNA applies. + return from_unicode(text, origin, idna_codec) + # The input is all ASCII, so treat this like an ordinary non-IDNA + # domain name. Note that "all ASCII" is about the input text, + # not the codepoints in the domain name. E.g. if text has value + # + # r'\150\151\152\153\154\155\156\157\158\159' + # + # then it's still "all ASCII" even though the domain name has + # codepoints > 127. 
+ text = text.encode("ascii") + if not isinstance(text, bytes): + raise ValueError("input to from_text() must be a string") + if not (origin is None or isinstance(origin, Name)): + raise ValueError("origin must be a Name or None") + labels = [] + label = b"" + escaping = False + edigits = 0 + total = 0 + if text == b"@": + text = b"" + if text: + if text == b".": + return Name([b""]) + for c in text: + byte_ = struct.pack("!B", c) + if escaping: + if edigits == 0: + if byte_.isdigit(): + total = int(byte_) + edigits += 1 + else: + label += byte_ + escaping = False + else: + if not byte_.isdigit(): + raise BadEscape + total *= 10 + total += int(byte_) + edigits += 1 + if edigits == 3: + escaping = False + label += struct.pack("!B", total) + elif byte_ == b".": + if len(label) == 0: + raise EmptyLabel + labels.append(label) + label = b"" + elif byte_ == b"\\": + escaping = True + edigits = 0 + total = 0 + else: + label += byte_ + if escaping: + raise BadEscape + if len(label) > 0: + labels.append(label) + else: + labels.append(b"") + if (len(labels) == 0 or labels[-1] != b"") and origin is not None: + labels.extend(list(origin.labels)) + return Name(labels) + + +# we need 'dns.wire.Parser' quoted as dns.name and dns.wire depend on each other. + + +def from_wire_parser(parser: "dns.wire.Parser") -> Name: + """Convert possibly compressed wire format into a Name. + + *parser* is a dns.wire.Parser. + + Raises ``dns.name.BadPointer`` if a compression pointer did not + point backwards in the message. + + Raises ``dns.name.BadLabelType`` if an invalid label type was encountered. + + Returns a ``dns.name.Name`` + """ + + labels = [] + biggest_pointer = parser.current + with parser.restore_furthest(): + count = parser.get_uint8() + while count != 0: + if count < 64: + labels.append(parser.get_bytes(count)) + elif count >= 192: + current = (count & 0x3F) * 256 + parser.get_uint8() + if current >= biggest_pointer: + raise BadPointer + biggest_pointer = current + parser.seek(current) + else: + raise BadLabelType + count = parser.get_uint8() + labels.append(b"") + return Name(labels) + + +def from_wire(message: bytes, current: int) -> Tuple[Name, int]: + """Convert possibly compressed wire format into a Name. + + *message* is a ``bytes`` containing an entire DNS message in DNS + wire form. + + *current*, an ``int``, is the offset of the beginning of the name + from the start of the message + + Raises ``dns.name.BadPointer`` if a compression pointer did not + point backwards in the message. + + Raises ``dns.name.BadLabelType`` if an invalid label type was encountered. + + Returns a ``(dns.name.Name, int)`` tuple consisting of the name + that was read and the number of bytes of the wire format message + which were consumed reading it. 
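+
+    A minimal usage sketch (*wire* is assumed to hold a complete DNS
+    message; 12 is the size of the fixed DNS header, i.e. the offset of
+    the question section)::
+
+        qname, consumed = dns.name.from_wire(wire, 12)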
+ """ + + if not isinstance(message, bytes): + raise ValueError("input to from_wire() must be a byte string") + parser = dns.wire.Parser(message, current) + name = from_wire_parser(parser) + return (name, parser.current - current) + + +# RFC 4471 Support + +_MINIMAL_OCTET = b"\x00" +_MINIMAL_OCTET_VALUE = ord(_MINIMAL_OCTET) +_SUCCESSOR_PREFIX = Name([_MINIMAL_OCTET]) +_MAXIMAL_OCTET = b"\xff" +_MAXIMAL_OCTET_VALUE = ord(_MAXIMAL_OCTET) +_AT_SIGN_VALUE = ord("@") +_LEFT_SQUARE_BRACKET_VALUE = ord("[") + + +def _wire_length(labels): + return functools.reduce(lambda v, x: v + len(x) + 1, labels, 0) + + +def _pad_to_max_name(name): + needed = 255 - _wire_length(name.labels) + new_labels = [] + while needed > 64: + new_labels.append(_MAXIMAL_OCTET * 63) + needed -= 64 + if needed >= 2: + new_labels.append(_MAXIMAL_OCTET * (needed - 1)) + # Note we're already maximal in the needed == 1 case as while we'd like + # to add one more byte as a new label, we can't, as adding a new non-empty + # label requires at least 2 bytes. + new_labels = list(reversed(new_labels)) + new_labels.extend(name.labels) + return Name(new_labels) + + +def _pad_to_max_label(label, suffix_labels): + length = len(label) + # We have to subtract one here to account for the length byte of label. + remaining = 255 - _wire_length(suffix_labels) - length - 1 + if remaining <= 0: + # Shouldn't happen! + return label + needed = min(63 - length, remaining) + return label + _MAXIMAL_OCTET * needed + + +def _absolute_predecessor(name: Name, origin: Name, prefix_ok: bool) -> Name: + # This is the RFC 4471 predecessor algorithm using the "absolute method" of section + # 3.1.1. + # + # Our caller must ensure that the name and origin are absolute, and that name is a + # subdomain of origin. + if name == origin: + return _pad_to_max_name(name) + least_significant_label = name[0] + if least_significant_label == _MINIMAL_OCTET: + return name.parent() + least_octet = least_significant_label[-1] + suffix_labels = name.labels[1:] + if least_octet == _MINIMAL_OCTET_VALUE: + new_labels = [least_significant_label[:-1]] + else: + octets = bytearray(least_significant_label) + octet = octets[-1] + if octet == _LEFT_SQUARE_BRACKET_VALUE: + octet = _AT_SIGN_VALUE + else: + octet -= 1 + octets[-1] = octet + least_significant_label = bytes(octets) + new_labels = [_pad_to_max_label(least_significant_label, suffix_labels)] + new_labels.extend(suffix_labels) + name = Name(new_labels) + if prefix_ok: + return _pad_to_max_name(name) + else: + return name + + +def _absolute_successor(name: Name, origin: Name, prefix_ok: bool) -> Name: + # This is the RFC 4471 successor algorithm using the "absolute method" of section + # 3.1.2. + # + # Our caller must ensure that the name and origin are absolute, and that name is a + # subdomain of origin. + if prefix_ok: + # Try prefixing \000 as new label + try: + return _SUCCESSOR_PREFIX.concatenate(name) + except NameTooLong: + pass + while name != origin: + # Try extending the least significant label. + least_significant_label = name[0] + if len(least_significant_label) < 63: + # We may be able to extend the least label with a minimal additional byte. + # This is only "may" because we could have a maximal length name even though + # the least significant label isn't maximally long. 
+                new_labels = [least_significant_label + _MINIMAL_OCTET]
+                new_labels.extend(name.labels[1:])
+                try:
+                    return dns.name.Name(new_labels)
+                except dns.name.NameTooLong:
+                    pass
+            # We can't extend the label either, so we'll try to increment the least
+            # significant non-maximal byte in it.
+            octets = bytearray(least_significant_label)
+            # We do this reversed iteration with an explicit indexing variable because
+            # if we find something to increment, we're going to want to truncate everything
+            # to the right of it.
+            for i in range(len(octets) - 1, -1, -1):
+                octet = octets[i]
+                if octet == _MAXIMAL_OCTET_VALUE:
+                    # We can't increment this, so keep looking.
+                    continue
+                # Finally, something we can increment.  We have to apply a special rule for
+                # incrementing "@", sending it to "[", because RFC 4034 6.1 says that when
+                # comparing names, uppercase letters compare as if they were their
+                # lower-case equivalents.  If we increment "@" to "A", then it would compare
+                # as "a", which is after "[", "\", "]", "^", "_", and "`", so we would have
+                # skipped the most minimal successor, namely "[".
+                if octet == _AT_SIGN_VALUE:
+                    octet = _LEFT_SQUARE_BRACKET_VALUE
+                else:
+                    octet += 1
+                octets[i] = octet
+                # We can now truncate all of the maximal values we skipped (if any)
+                new_labels = [bytes(octets[: i + 1])]
+                new_labels.extend(name.labels[1:])
+                # We haven't changed the length of the name, so the Name constructor will
+                # always work.
+                return Name(new_labels)
+            # We couldn't increment, so chop off the least significant label and try
+            # again.
+            name = name.parent()
+
+    # We couldn't increment at all, so return the origin, as wrapping around is the
+    # DNSSEC way.
+    return origin
+
+
+def _handle_relativity_and_call(
+    function: Callable[[Name, Name, bool], Name],
+    name: Name,
+    origin: Name,
+    prefix_ok: bool,
+) -> Name:
+    # Make "name" absolute if needed, ensure that the origin is absolute,
+    # call function(), and then relativize the result if needed.
+    if not origin.is_absolute():
+        raise NeedAbsoluteNameOrOrigin
+    relative = not name.is_absolute()
+    if relative:
+        name = name.derelativize(origin)
+    elif not name.is_subdomain(origin):
+        raise NeedSubdomainOfOrigin
+    result_name = function(name, origin, prefix_ok)
+    if relative:
+        result_name = result_name.relativize(origin)
+    return result_name
diff --git a/venv/Lib/site-packages/dns/namedict.py b/venv/Lib/site-packages/dns/namedict.py
new file mode 100644
index 00000000..ca8b1978
--- /dev/null
+++ b/venv/Lib/site-packages/dns/namedict.py
@@ -0,0 +1,109 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2003-2017 Nominum, Inc.
+# Copyright (C) 2016 Coresec Systems AB
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS.  IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+# +# THE SOFTWARE IS PROVIDED "AS IS" AND CORESEC SYSTEMS AB DISCLAIMS ALL +# WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL CORESEC +# SYSTEMS AB BE LIABLE FOR ANY SPECIAL, DIRECT, INDIRECT, OR +# CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS +# OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, +# NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION +# WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +"""DNS name dictionary""" + +# pylint seems to be confused about this one! +from collections.abc import MutableMapping # pylint: disable=no-name-in-module + +import dns.name + + +class NameDict(MutableMapping): + """A dictionary whose keys are dns.name.Name objects. + + In addition to being like a regular Python dictionary, this + dictionary can also get the deepest match for a given key. + """ + + __slots__ = ["max_depth", "max_depth_items", "__store"] + + def __init__(self, *args, **kwargs): + super().__init__() + self.__store = dict() + #: the maximum depth of the keys that have ever been added + self.max_depth = 0 + #: the number of items of maximum depth + self.max_depth_items = 0 + self.update(dict(*args, **kwargs)) + + def __update_max_depth(self, key): + if len(key) == self.max_depth: + self.max_depth_items = self.max_depth_items + 1 + elif len(key) > self.max_depth: + self.max_depth = len(key) + self.max_depth_items = 1 + + def __getitem__(self, key): + return self.__store[key] + + def __setitem__(self, key, value): + if not isinstance(key, dns.name.Name): + raise ValueError("NameDict key must be a name") + self.__store[key] = value + self.__update_max_depth(key) + + def __delitem__(self, key): + self.__store.pop(key) + if len(key) == self.max_depth: + self.max_depth_items = self.max_depth_items - 1 + if self.max_depth_items == 0: + self.max_depth = 0 + for k in self.__store: + self.__update_max_depth(k) + + def __iter__(self): + return iter(self.__store) + + def __len__(self): + return len(self.__store) + + def has_key(self, key): + return key in self.__store + + def get_deepest_match(self, name): + """Find the deepest match to *name* in the dictionary. + + The deepest match is the longest name in the dictionary which is + a superdomain of *name*. Note that *superdomain* includes matching + *name* itself. + + *name*, a ``dns.name.Name``, the name to find. + + Returns a ``(key, value)`` where *key* is the deepest + ``dns.name.Name``, and *value* is the value associated with *key*. 
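+
+        A small illustrative sketch (the name and value are arbitrary)::
+
+            d = NameDict()
+            d[dns.name.from_text("example.")] = "zone data"
+            key, value = d.get_deepest_match(dns.name.from_text("www.example."))
+            # key is the name "example." and value is "zone data"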
+ """ + + depth = len(name) + if depth > self.max_depth: + depth = self.max_depth + for i in range(-depth, 0): + n = dns.name.Name(name[i:]) + if n in self: + return (n, self[n]) + v = self[dns.name.empty] + return (dns.name.empty, v) diff --git a/venv/Lib/site-packages/dns/nameserver.py b/venv/Lib/site-packages/dns/nameserver.py new file mode 100644 index 00000000..5dbb4e8b --- /dev/null +++ b/venv/Lib/site-packages/dns/nameserver.py @@ -0,0 +1,359 @@ +from typing import Optional, Union +from urllib.parse import urlparse + +import dns.asyncbackend +import dns.asyncquery +import dns.inet +import dns.message +import dns.query + + +class Nameserver: + def __init__(self): + pass + + def __str__(self): + raise NotImplementedError + + def kind(self) -> str: + raise NotImplementedError + + def is_always_max_size(self) -> bool: + raise NotImplementedError + + def answer_nameserver(self) -> str: + raise NotImplementedError + + def answer_port(self) -> int: + raise NotImplementedError + + def query( + self, + request: dns.message.QueryMessage, + timeout: float, + source: Optional[str], + source_port: int, + max_size: bool, + one_rr_per_rrset: bool = False, + ignore_trailing: bool = False, + ) -> dns.message.Message: + raise NotImplementedError + + async def async_query( + self, + request: dns.message.QueryMessage, + timeout: float, + source: Optional[str], + source_port: int, + max_size: bool, + backend: dns.asyncbackend.Backend, + one_rr_per_rrset: bool = False, + ignore_trailing: bool = False, + ) -> dns.message.Message: + raise NotImplementedError + + +class AddressAndPortNameserver(Nameserver): + def __init__(self, address: str, port: int): + super().__init__() + self.address = address + self.port = port + + def kind(self) -> str: + raise NotImplementedError + + def is_always_max_size(self) -> bool: + return False + + def __str__(self): + ns_kind = self.kind() + return f"{ns_kind}:{self.address}@{self.port}" + + def answer_nameserver(self) -> str: + return self.address + + def answer_port(self) -> int: + return self.port + + +class Do53Nameserver(AddressAndPortNameserver): + def __init__(self, address: str, port: int = 53): + super().__init__(address, port) + + def kind(self): + return "Do53" + + def query( + self, + request: dns.message.QueryMessage, + timeout: float, + source: Optional[str], + source_port: int, + max_size: bool, + one_rr_per_rrset: bool = False, + ignore_trailing: bool = False, + ) -> dns.message.Message: + if max_size: + response = dns.query.tcp( + request, + self.address, + timeout=timeout, + port=self.port, + source=source, + source_port=source_port, + one_rr_per_rrset=one_rr_per_rrset, + ignore_trailing=ignore_trailing, + ) + else: + response = dns.query.udp( + request, + self.address, + timeout=timeout, + port=self.port, + source=source, + source_port=source_port, + raise_on_truncation=True, + one_rr_per_rrset=one_rr_per_rrset, + ignore_trailing=ignore_trailing, + ignore_errors=True, + ignore_unexpected=True, + ) + return response + + async def async_query( + self, + request: dns.message.QueryMessage, + timeout: float, + source: Optional[str], + source_port: int, + max_size: bool, + backend: dns.asyncbackend.Backend, + one_rr_per_rrset: bool = False, + ignore_trailing: bool = False, + ) -> dns.message.Message: + if max_size: + response = await dns.asyncquery.tcp( + request, + self.address, + timeout=timeout, + port=self.port, + source=source, + source_port=source_port, + backend=backend, + one_rr_per_rrset=one_rr_per_rrset, + ignore_trailing=ignore_trailing, + ) + else: 
+ response = await dns.asyncquery.udp( + request, + self.address, + timeout=timeout, + port=self.port, + source=source, + source_port=source_port, + raise_on_truncation=True, + backend=backend, + one_rr_per_rrset=one_rr_per_rrset, + ignore_trailing=ignore_trailing, + ignore_errors=True, + ignore_unexpected=True, + ) + return response + + +class DoHNameserver(Nameserver): + def __init__( + self, + url: str, + bootstrap_address: Optional[str] = None, + verify: Union[bool, str] = True, + want_get: bool = False, + ): + super().__init__() + self.url = url + self.bootstrap_address = bootstrap_address + self.verify = verify + self.want_get = want_get + + def kind(self): + return "DoH" + + def is_always_max_size(self) -> bool: + return True + + def __str__(self): + return self.url + + def answer_nameserver(self) -> str: + return self.url + + def answer_port(self) -> int: + port = urlparse(self.url).port + if port is None: + port = 443 + return port + + def query( + self, + request: dns.message.QueryMessage, + timeout: float, + source: Optional[str], + source_port: int, + max_size: bool = False, + one_rr_per_rrset: bool = False, + ignore_trailing: bool = False, + ) -> dns.message.Message: + return dns.query.https( + request, + self.url, + timeout=timeout, + source=source, + source_port=source_port, + bootstrap_address=self.bootstrap_address, + one_rr_per_rrset=one_rr_per_rrset, + ignore_trailing=ignore_trailing, + verify=self.verify, + post=(not self.want_get), + ) + + async def async_query( + self, + request: dns.message.QueryMessage, + timeout: float, + source: Optional[str], + source_port: int, + max_size: bool, + backend: dns.asyncbackend.Backend, + one_rr_per_rrset: bool = False, + ignore_trailing: bool = False, + ) -> dns.message.Message: + return await dns.asyncquery.https( + request, + self.url, + timeout=timeout, + source=source, + source_port=source_port, + bootstrap_address=self.bootstrap_address, + one_rr_per_rrset=one_rr_per_rrset, + ignore_trailing=ignore_trailing, + verify=self.verify, + post=(not self.want_get), + ) + + +class DoTNameserver(AddressAndPortNameserver): + def __init__( + self, + address: str, + port: int = 853, + hostname: Optional[str] = None, + verify: Union[bool, str] = True, + ): + super().__init__(address, port) + self.hostname = hostname + self.verify = verify + + def kind(self): + return "DoT" + + def query( + self, + request: dns.message.QueryMessage, + timeout: float, + source: Optional[str], + source_port: int, + max_size: bool = False, + one_rr_per_rrset: bool = False, + ignore_trailing: bool = False, + ) -> dns.message.Message: + return dns.query.tls( + request, + self.address, + port=self.port, + timeout=timeout, + one_rr_per_rrset=one_rr_per_rrset, + ignore_trailing=ignore_trailing, + server_hostname=self.hostname, + verify=self.verify, + ) + + async def async_query( + self, + request: dns.message.QueryMessage, + timeout: float, + source: Optional[str], + source_port: int, + max_size: bool, + backend: dns.asyncbackend.Backend, + one_rr_per_rrset: bool = False, + ignore_trailing: bool = False, + ) -> dns.message.Message: + return await dns.asyncquery.tls( + request, + self.address, + port=self.port, + timeout=timeout, + one_rr_per_rrset=one_rr_per_rrset, + ignore_trailing=ignore_trailing, + server_hostname=self.hostname, + verify=self.verify, + ) + + +class DoQNameserver(AddressAndPortNameserver): + def __init__( + self, + address: str, + port: int = 853, + verify: Union[bool, str] = True, + server_hostname: Optional[str] = None, + ): + 
super().__init__(address, port) + self.verify = verify + self.server_hostname = server_hostname + + def kind(self): + return "DoQ" + + def query( + self, + request: dns.message.QueryMessage, + timeout: float, + source: Optional[str], + source_port: int, + max_size: bool = False, + one_rr_per_rrset: bool = False, + ignore_trailing: bool = False, + ) -> dns.message.Message: + return dns.query.quic( + request, + self.address, + port=self.port, + timeout=timeout, + one_rr_per_rrset=one_rr_per_rrset, + ignore_trailing=ignore_trailing, + verify=self.verify, + server_hostname=self.server_hostname, + ) + + async def async_query( + self, + request: dns.message.QueryMessage, + timeout: float, + source: Optional[str], + source_port: int, + max_size: bool, + backend: dns.asyncbackend.Backend, + one_rr_per_rrset: bool = False, + ignore_trailing: bool = False, + ) -> dns.message.Message: + return await dns.asyncquery.quic( + request, + self.address, + port=self.port, + timeout=timeout, + one_rr_per_rrset=one_rr_per_rrset, + ignore_trailing=ignore_trailing, + verify=self.verify, + server_hostname=self.server_hostname, + ) diff --git a/venv/Lib/site-packages/dns/node.py b/venv/Lib/site-packages/dns/node.py new file mode 100644 index 00000000..de85a82d --- /dev/null +++ b/venv/Lib/site-packages/dns/node.py @@ -0,0 +1,359 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2001-2017 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +"""DNS nodes. A node is a set of rdatasets.""" + +import enum +import io +from typing import Any, Dict, Optional + +import dns.immutable +import dns.name +import dns.rdataclass +import dns.rdataset +import dns.rdatatype +import dns.renderer +import dns.rrset + +_cname_types = { + dns.rdatatype.CNAME, +} + +# "neutral" types can coexist with a CNAME and thus are not "other data" +_neutral_types = { + dns.rdatatype.NSEC, # RFC 4035 section 2.5 + dns.rdatatype.NSEC3, # This is not likely to happen, but not impossible! 
+    dns.rdatatype.KEY,  # RFC 4035 section 2.5, RFC 3007
+}
+
+
+def _matches_type_or_its_signature(rdtypes, rdtype, covers):
+    return rdtype in rdtypes or (rdtype == dns.rdatatype.RRSIG and covers in rdtypes)
+
+
+@enum.unique
+class NodeKind(enum.Enum):
+    """Rdatasets in nodes"""
+
+    REGULAR = 0  # a.k.a "other data"
+    NEUTRAL = 1
+    CNAME = 2
+
+    @classmethod
+    def classify(
+        cls, rdtype: dns.rdatatype.RdataType, covers: dns.rdatatype.RdataType
+    ) -> "NodeKind":
+        if _matches_type_or_its_signature(_cname_types, rdtype, covers):
+            return NodeKind.CNAME
+        elif _matches_type_or_its_signature(_neutral_types, rdtype, covers):
+            return NodeKind.NEUTRAL
+        else:
+            return NodeKind.REGULAR
+
+    @classmethod
+    def classify_rdataset(cls, rdataset: dns.rdataset.Rdataset) -> "NodeKind":
+        return cls.classify(rdataset.rdtype, rdataset.covers)
+
+
+class Node:
+    """A Node is a set of rdatasets.
+
+    A node is either a CNAME node or an "other data" node.  A CNAME
+    node contains only CNAME, KEY, NSEC, and NSEC3 rdatasets along with their
+    covering RRSIG rdatasets.  An "other data" node contains any
+    rdataset other than a CNAME or RRSIG(CNAME) rdataset.  When
+    changes are made to a node, the CNAME or "other data" state is
+    always consistent with the update, i.e. the most recent change
+    wins.  For example, if you have a node which contains a CNAME
+    rdataset, and then add an MX rdataset to it, then the CNAME
+    rdataset will be deleted.  Likewise if you have a node containing
+    an MX rdataset and add a CNAME rdataset, the MX rdataset will be
+    deleted.
+    """
+
+    __slots__ = ["rdatasets"]
+
+    def __init__(self):
+        # the set of rdatasets, represented as a list.
+        self.rdatasets = []
+
+    def to_text(self, name: dns.name.Name, **kw: Dict[str, Any]) -> str:
+        """Convert a node to text format.
+
+        Each rdataset at the node is printed.  Any keyword arguments
+        to this method are passed on to the rdataset's to_text() method.
+
+        *name*, a ``dns.name.Name``, the owner name of the
+        rdatasets.
+
+        Returns a ``str``.
+        """
+
+        s = io.StringIO()
+        for rds in self.rdatasets:
+            if len(rds) > 0:
+                s.write(rds.to_text(name, **kw))  # type: ignore[arg-type]
+                s.write("\n")
+        return s.getvalue()[:-1]
+
+    def __repr__(self):
+        return "<DNS node " + str(id(self)) + ">"
+
+    def __eq__(self, other):
+        #
+        # This is inefficient.  Good thing we don't need to do it much.
+        #
+        for rd in self.rdatasets:
+            if rd not in other.rdatasets:
+                return False
+        for rd in other.rdatasets:
+            if rd not in self.rdatasets:
+                return False
+        return True
+
+    def __ne__(self, other):
+        return not self.__eq__(other)
+
+    def __len__(self):
+        return len(self.rdatasets)
+
+    def __iter__(self):
+        return iter(self.rdatasets)
+
+    def _append_rdataset(self, rdataset):
+        """Append rdataset to the node with special handling for CNAME and
+        other data conditions.
+
+        Specifically, if the rdataset being appended has ``NodeKind.CNAME``,
+        then all rdatasets other than KEY, NSEC, NSEC3, and their covering
+        RRSIGs are deleted.  If the rdataset being appended has
+        ``NodeKind.REGULAR`` then CNAME and RRSIG(CNAME) are deleted.
+        """
+        # Make having just one rdataset at the node fast.
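+        # (An empty node cannot already hold a conflicting CNAME or
+        # "other data" rdataset, so in that case no scan is needed before
+        # appending.)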
+ if len(self.rdatasets) > 0: + kind = NodeKind.classify_rdataset(rdataset) + if kind == NodeKind.CNAME: + self.rdatasets = [ + rds + for rds in self.rdatasets + if NodeKind.classify_rdataset(rds) != NodeKind.REGULAR + ] + elif kind == NodeKind.REGULAR: + self.rdatasets = [ + rds + for rds in self.rdatasets + if NodeKind.classify_rdataset(rds) != NodeKind.CNAME + ] + # Otherwise the rdataset is NodeKind.NEUTRAL and we do not need to + # edit self.rdatasets. + self.rdatasets.append(rdataset) + + def find_rdataset( + self, + rdclass: dns.rdataclass.RdataClass, + rdtype: dns.rdatatype.RdataType, + covers: dns.rdatatype.RdataType = dns.rdatatype.NONE, + create: bool = False, + ) -> dns.rdataset.Rdataset: + """Find an rdataset matching the specified properties in the + current node. + + *rdclass*, a ``dns.rdataclass.RdataClass``, the class of the rdataset. + + *rdtype*, a ``dns.rdatatype.RdataType``, the type of the rdataset. + + *covers*, a ``dns.rdatatype.RdataType``, the covered type. + Usually this value is ``dns.rdatatype.NONE``, but if the + rdtype is ``dns.rdatatype.SIG`` or ``dns.rdatatype.RRSIG``, + then the covers value will be the rdata type the SIG/RRSIG + covers. The library treats the SIG and RRSIG types as if they + were a family of types, e.g. RRSIG(A), RRSIG(NS), RRSIG(SOA). + This makes RRSIGs much easier to work with than if RRSIGs + covering different rdata types were aggregated into a single + RRSIG rdataset. + + *create*, a ``bool``. If True, create the rdataset if it is not found. + + Raises ``KeyError`` if an rdataset of the desired type and class does + not exist and *create* is not ``True``. + + Returns a ``dns.rdataset.Rdataset``. + """ + + for rds in self.rdatasets: + if rds.match(rdclass, rdtype, covers): + return rds + if not create: + raise KeyError + rds = dns.rdataset.Rdataset(rdclass, rdtype, covers) + self._append_rdataset(rds) + return rds + + def get_rdataset( + self, + rdclass: dns.rdataclass.RdataClass, + rdtype: dns.rdatatype.RdataType, + covers: dns.rdatatype.RdataType = dns.rdatatype.NONE, + create: bool = False, + ) -> Optional[dns.rdataset.Rdataset]: + """Get an rdataset matching the specified properties in the + current node. + + None is returned if an rdataset of the specified type and + class does not exist and *create* is not ``True``. + + *rdclass*, an ``int``, the class of the rdataset. + + *rdtype*, an ``int``, the type of the rdataset. + + *covers*, an ``int``, the covered type. Usually this value is + dns.rdatatype.NONE, but if the rdtype is dns.rdatatype.SIG or + dns.rdatatype.RRSIG, then the covers value will be the rdata + type the SIG/RRSIG covers. The library treats the SIG and RRSIG + types as if they were a family of + types, e.g. RRSIG(A), RRSIG(NS), RRSIG(SOA). This makes RRSIGs much + easier to work with than if RRSIGs covering different rdata + types were aggregated into a single RRSIG rdataset. + + *create*, a ``bool``. If True, create the rdataset if it is not found. + + Returns a ``dns.rdataset.Rdataset`` or ``None``. + """ + + try: + rds = self.find_rdataset(rdclass, rdtype, covers, create) + except KeyError: + rds = None + return rds + + def delete_rdataset( + self, + rdclass: dns.rdataclass.RdataClass, + rdtype: dns.rdatatype.RdataType, + covers: dns.rdatatype.RdataType = dns.rdatatype.NONE, + ) -> None: + """Delete the rdataset matching the specified properties in the + current node. + + If a matching rdataset does not exist, it is not an error. + + *rdclass*, an ``int``, the class of the rdataset. 
+ + *rdtype*, an ``int``, the type of the rdataset. + + *covers*, an ``int``, the covered type. + """ + + rds = self.get_rdataset(rdclass, rdtype, covers) + if rds is not None: + self.rdatasets.remove(rds) + + def replace_rdataset(self, replacement: dns.rdataset.Rdataset) -> None: + """Replace an rdataset. + + It is not an error if there is no rdataset matching *replacement*. + + Ownership of the *replacement* object is transferred to the node; + in other words, this method does not store a copy of *replacement* + at the node, it stores *replacement* itself. + + *replacement*, a ``dns.rdataset.Rdataset``. + + Raises ``ValueError`` if *replacement* is not a + ``dns.rdataset.Rdataset``. + """ + + if not isinstance(replacement, dns.rdataset.Rdataset): + raise ValueError("replacement is not an rdataset") + if isinstance(replacement, dns.rrset.RRset): + # RRsets are not good replacements as the match() method + # is not compatible. + replacement = replacement.to_rdataset() + self.delete_rdataset( + replacement.rdclass, replacement.rdtype, replacement.covers + ) + self._append_rdataset(replacement) + + def classify(self) -> NodeKind: + """Classify a node. + + A node which contains a CNAME or RRSIG(CNAME) is a + ``NodeKind.CNAME`` node. + + A node which contains only "neutral" types, i.e. types allowed to + co-exist with a CNAME, is a ``NodeKind.NEUTRAL`` node. The neutral + types are NSEC, NSEC3, KEY, and their associated RRSIGS. An empty node + is also considered neutral. + + A node which contains some rdataset which is not a CNAME, RRSIG(CNAME), + or a neutral type is a a ``NodeKind.REGULAR`` node. Regular nodes are + also commonly referred to as "other data". + """ + for rdataset in self.rdatasets: + kind = NodeKind.classify(rdataset.rdtype, rdataset.covers) + if kind != NodeKind.NEUTRAL: + return kind + return NodeKind.NEUTRAL + + def is_immutable(self) -> bool: + return False + + +@dns.immutable.immutable +class ImmutableNode(Node): + def __init__(self, node): + super().__init__() + self.rdatasets = tuple( + [dns.rdataset.ImmutableRdataset(rds) for rds in node.rdatasets] + ) + + def find_rdataset( + self, + rdclass: dns.rdataclass.RdataClass, + rdtype: dns.rdatatype.RdataType, + covers: dns.rdatatype.RdataType = dns.rdatatype.NONE, + create: bool = False, + ) -> dns.rdataset.Rdataset: + if create: + raise TypeError("immutable") + return super().find_rdataset(rdclass, rdtype, covers, False) + + def get_rdataset( + self, + rdclass: dns.rdataclass.RdataClass, + rdtype: dns.rdatatype.RdataType, + covers: dns.rdatatype.RdataType = dns.rdatatype.NONE, + create: bool = False, + ) -> Optional[dns.rdataset.Rdataset]: + if create: + raise TypeError("immutable") + return super().get_rdataset(rdclass, rdtype, covers, False) + + def delete_rdataset( + self, + rdclass: dns.rdataclass.RdataClass, + rdtype: dns.rdatatype.RdataType, + covers: dns.rdatatype.RdataType = dns.rdatatype.NONE, + ) -> None: + raise TypeError("immutable") + + def replace_rdataset(self, replacement: dns.rdataset.Rdataset) -> None: + raise TypeError("immutable") + + def is_immutable(self) -> bool: + return True diff --git a/venv/Lib/site-packages/dns/opcode.py b/venv/Lib/site-packages/dns/opcode.py new file mode 100644 index 00000000..78b43d2c --- /dev/null +++ b/venv/Lib/site-packages/dns/opcode.py @@ -0,0 +1,117 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2001-2017 Nominum, Inc. 
+# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +"""DNS Opcodes.""" + +import dns.enum +import dns.exception + + +class Opcode(dns.enum.IntEnum): + #: Query + QUERY = 0 + #: Inverse Query (historical) + IQUERY = 1 + #: Server Status (unspecified and unimplemented anywhere) + STATUS = 2 + #: Notify + NOTIFY = 4 + #: Dynamic Update + UPDATE = 5 + + @classmethod + def _maximum(cls): + return 15 + + @classmethod + def _unknown_exception_class(cls): + return UnknownOpcode + + +class UnknownOpcode(dns.exception.DNSException): + """An DNS opcode is unknown.""" + + +def from_text(text: str) -> Opcode: + """Convert text into an opcode. + + *text*, a ``str``, the textual opcode + + Raises ``dns.opcode.UnknownOpcode`` if the opcode is unknown. + + Returns an ``int``. + """ + + return Opcode.from_text(text) + + +def from_flags(flags: int) -> Opcode: + """Extract an opcode from DNS message flags. + + *flags*, an ``int``, the DNS flags. + + Returns an ``int``. + """ + + return Opcode((flags & 0x7800) >> 11) + + +def to_flags(value: Opcode) -> int: + """Convert an opcode to a value suitable for ORing into DNS message + flags. + + *value*, an ``int``, the DNS opcode value. + + Returns an ``int``. + """ + + return (value << 11) & 0x7800 + + +def to_text(value: Opcode) -> str: + """Convert an opcode to text. + + *value*, an ``int`` the opcode value, + + Raises ``dns.opcode.UnknownOpcode`` if the opcode is unknown. + + Returns a ``str``. + """ + + return Opcode.to_text(value) + + +def is_update(flags: int) -> bool: + """Is the opcode in flags UPDATE? + + *flags*, an ``int``, the DNS message flags. + + Returns a ``bool``. + """ + + return from_flags(flags) == Opcode.UPDATE + + +### BEGIN generated Opcode constants + +QUERY = Opcode.QUERY +IQUERY = Opcode.IQUERY +STATUS = Opcode.STATUS +NOTIFY = Opcode.NOTIFY +UPDATE = Opcode.UPDATE + +### END generated Opcode constants diff --git a/venv/Lib/site-packages/dns/py.typed b/venv/Lib/site-packages/dns/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/venv/Lib/site-packages/dns/query.py b/venv/Lib/site-packages/dns/query.py new file mode 100644 index 00000000..f0ee9161 --- /dev/null +++ b/venv/Lib/site-packages/dns/query.py @@ -0,0 +1,1578 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2017 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. 
IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +"""Talk to a DNS server.""" + +import base64 +import contextlib +import enum +import errno +import os +import os.path +import selectors +import socket +import struct +import time +from typing import Any, Dict, Optional, Tuple, Union + +import dns._features +import dns.exception +import dns.inet +import dns.message +import dns.name +import dns.quic +import dns.rcode +import dns.rdataclass +import dns.rdatatype +import dns.serial +import dns.transaction +import dns.tsig +import dns.xfr + + +def _remaining(expiration): + if expiration is None: + return None + timeout = expiration - time.time() + if timeout <= 0.0: + raise dns.exception.Timeout + return timeout + + +def _expiration_for_this_attempt(timeout, expiration): + if expiration is None: + return None + return min(time.time() + timeout, expiration) + + +_have_httpx = dns._features.have("doh") +if _have_httpx: + import httpcore._backends.sync + import httpx + + _CoreNetworkBackend = httpcore.NetworkBackend + _CoreSyncStream = httpcore._backends.sync.SyncStream + + class _NetworkBackend(_CoreNetworkBackend): + def __init__(self, resolver, local_port, bootstrap_address, family): + super().__init__() + self._local_port = local_port + self._resolver = resolver + self._bootstrap_address = bootstrap_address + self._family = family + + def connect_tcp( + self, host, port, timeout, local_address, socket_options=None + ): # pylint: disable=signature-differs + addresses = [] + _, expiration = _compute_times(timeout) + if dns.inet.is_address(host): + addresses.append(host) + elif self._bootstrap_address is not None: + addresses.append(self._bootstrap_address) + else: + timeout = _remaining(expiration) + family = self._family + if local_address: + family = dns.inet.af_for_address(local_address) + answers = self._resolver.resolve_name( + host, family=family, lifetime=timeout + ) + addresses = answers.addresses() + for address in addresses: + af = dns.inet.af_for_address(address) + if local_address is not None or self._local_port != 0: + source = dns.inet.low_level_address_tuple( + (local_address, self._local_port), af + ) + else: + source = None + sock = _make_socket(af, socket.SOCK_STREAM, source) + attempt_expiration = _expiration_for_this_attempt(2.0, expiration) + try: + _connect( + sock, + dns.inet.low_level_address_tuple((address, port), af), + attempt_expiration, + ) + return _CoreSyncStream(sock) + except Exception: + pass + raise httpcore.ConnectError + + def connect_unix_socket( + self, path, timeout, socket_options=None + ): # pylint: disable=signature-differs + raise NotImplementedError + + class _HTTPTransport(httpx.HTTPTransport): + def __init__( + self, + *args, + local_port=0, + bootstrap_address=None, + resolver=None, + family=socket.AF_UNSPEC, + **kwargs, + ): + if resolver is None: + # pylint: disable=import-outside-toplevel,redefined-outer-name + import dns.resolver + + resolver = dns.resolver.Resolver() + super().__init__(*args, **kwargs) + self._pool._network_backend = _NetworkBackend( + resolver, local_port, bootstrap_address, family + ) + +else: + + class _HTTPTransport: # type: ignore + def connect_tcp(self, host, port, timeout, local_address): + raise NotImplementedError + + +have_doh = _have_httpx + 
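+# A rough usage sketch (illustrative only, not part of the module API; the
+# endpoint URL is an assumption, and httpx must be importable for have_doh
+# to be true):
+#
+#     q = dns.message.make_query("dnspython.org", dns.rdatatype.A)
+#     r = https(q, "https://dns.google/dns-query", timeout=5.0)
+#
+# See the https() function below for the full parameter list.
+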
+try: + import ssl +except ImportError: # pragma: no cover + + class ssl: # type: ignore + CERT_NONE = 0 + + class WantReadException(Exception): + pass + + class WantWriteException(Exception): + pass + + class SSLContext: + pass + + class SSLSocket: + pass + + @classmethod + def create_default_context(cls, *args, **kwargs): + raise Exception("no ssl support") # pylint: disable=broad-exception-raised + + +# Function used to create a socket. Can be overridden if needed in special +# situations. +socket_factory = socket.socket + + +class UnexpectedSource(dns.exception.DNSException): + """A DNS query response came from an unexpected address or port.""" + + +class BadResponse(dns.exception.FormError): + """A DNS query response does not respond to the question asked.""" + + +class NoDOH(dns.exception.DNSException): + """DNS over HTTPS (DOH) was requested but the httpx module is not + available.""" + + +class NoDOQ(dns.exception.DNSException): + """DNS over QUIC (DOQ) was requested but the aioquic module is not + available.""" + + +# for backwards compatibility +TransferError = dns.xfr.TransferError + + +def _compute_times(timeout): + now = time.time() + if timeout is None: + return (now, None) + else: + return (now, now + timeout) + + +def _wait_for(fd, readable, writable, _, expiration): + # Use the selected selector class to wait for any of the specified + # events. An "expiration" absolute time is converted into a relative + # timeout. + # + # The unused parameter is 'error', which is always set when + # selecting for read or write, and we have no error-only selects. + + if readable and isinstance(fd, ssl.SSLSocket) and fd.pending() > 0: + return True + sel = _selector_class() + events = 0 + if readable: + events |= selectors.EVENT_READ + if writable: + events |= selectors.EVENT_WRITE + if events: + sel.register(fd, events) + if expiration is None: + timeout = None + else: + timeout = expiration - time.time() + if timeout <= 0.0: + raise dns.exception.Timeout + if not sel.select(timeout): + raise dns.exception.Timeout + + +def _set_selector_class(selector_class): + # Internal API. Do not use. + + global _selector_class + + _selector_class = selector_class + + +if hasattr(selectors, "PollSelector"): + # Prefer poll() on platforms that support it because it has no + # limits on the maximum value of a file descriptor (plus it will + # be more efficient for high values). + # + # We ignore typing here as we can't say _selector_class is Any + # on python < 3.8 due to a bug. + _selector_class = selectors.PollSelector # type: ignore +else: + _selector_class = selectors.SelectSelector # type: ignore + + +def _wait_for_readable(s, expiration): + _wait_for(s, True, False, True, expiration) + + +def _wait_for_writable(s, expiration): + _wait_for(s, False, True, True, expiration) + + +def _addresses_equal(af, a1, a2): + # Convert the first value of the tuple, which is a textual format + # address into binary form, so that we are not confused by different + # textual representations of the same address + try: + n1 = dns.inet.inet_pton(af, a1[0]) + n2 = dns.inet.inet_pton(af, a2[0]) + except dns.exception.SyntaxError: + return False + return n1 == n2 and a1[1:] == a2[1:] + + +def _matches_destination(af, from_address, destination, ignore_unexpected): + # Check that from_address is appropriate for a response to a query + # sent to destination. 
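+    # For multicast queries the reply legitimately arrives from a unicast
+    # address, so in that case only the non-address parts of the tuple
+    # (port, etc.) are compared.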
+    if not destination:
+        return True
+    if _addresses_equal(af, from_address, destination) or (
+        dns.inet.is_multicast(destination[0]) and from_address[1:] == destination[1:]
+    ):
+        return True
+    elif ignore_unexpected:
+        return False
+    raise UnexpectedSource(
+        f"got a response from {from_address} instead of {destination}"
+    )
+
+
+def _destination_and_source(
+    where, port, source, source_port, where_must_be_address=True
+):
+    # Apply defaults and compute destination and source tuples
+    # suitable for use in connect(), sendto(), or bind().
+    af = None
+    destination = None
+    try:
+        af = dns.inet.af_for_address(where)
+        destination = where
+    except Exception:
+        if where_must_be_address:
+            raise
+        # URLs are ok so eat the exception
+    if source:
+        saf = dns.inet.af_for_address(source)
+        if af:
+            # We know the destination af, so source had better agree!
+            if saf != af:
+                raise ValueError(
+                    "different address families for source and destination"
+                )
+        else:
+            # We didn't know the destination af, but we know the source,
+            # so that's our af.
+            af = saf
+    if source_port and not source:
+        # Caller has specified a source_port but not an address, so we
+        # need to return a source, and we need to use the appropriate
+        # wildcard address as the address.
+        try:
+            source = dns.inet.any_for_af(af)
+        except Exception:
+            # we catch this and raise ValueError for backwards compatibility
+            raise ValueError("source_port specified but address family is unknown")
+    # Convert high-level (address, port) tuples into low-level address
+    # tuples.
+    if destination:
+        destination = dns.inet.low_level_address_tuple((destination, port), af)
+    if source:
+        source = dns.inet.low_level_address_tuple((source, source_port), af)
+    return (af, destination, source)
+
+
+def _make_socket(af, type, source, ssl_context=None, server_hostname=None):
+    s = socket_factory(af, type)
+    try:
+        s.setblocking(False)
+        if source is not None:
+            s.bind(source)
+        if ssl_context:
+            # LGTM gets a false positive here, as our default context is OK
+            return ssl_context.wrap_socket(
+                s,
+                do_handshake_on_connect=False,  # lgtm[py/insecure-protocol]
+                server_hostname=server_hostname,
+            )
+        else:
+            return s
+    except Exception:
+        s.close()
+        raise
+
+
+def https(
+    q: dns.message.Message,
+    where: str,
+    timeout: Optional[float] = None,
+    port: int = 443,
+    source: Optional[str] = None,
+    source_port: int = 0,
+    one_rr_per_rrset: bool = False,
+    ignore_trailing: bool = False,
+    session: Optional[Any] = None,
+    path: str = "/dns-query",
+    post: bool = True,
+    bootstrap_address: Optional[str] = None,
+    verify: Union[bool, str] = True,
+    resolver: Optional["dns.resolver.Resolver"] = None,
+    family: Optional[int] = socket.AF_UNSPEC,
+) -> dns.message.Message:
+    """Return the response obtained after sending a query via DNS-over-HTTPS.
+
+    *q*, a ``dns.message.Message``, the query to send.
+
+    *where*, a ``str``, the nameserver IP address or the full URL.  If an IP address is
+    given, the URL will be constructed using the following schema:
+    https://<IP-address>:<port>/<path>.
+
+    *timeout*, a ``float`` or ``None``, the number of seconds to wait before the query
+    times out.  If ``None``, the default, wait forever.
+
+    *port*, an ``int``, the port to send the query to.  The default is 443.
+
+    *source*, a ``str`` containing an IPv4 or IPv6 address, specifying the source
+    address.  The default is the wildcard address.
+
+    *source_port*, an ``int``, the port from which to send the message.  The default is
+    0.
+
+    *one_rr_per_rrset*, a ``bool``. 
If ``True``, put each RR into its own RRset. + + *ignore_trailing*, a ``bool``. If ``True``, ignore trailing junk at end of the + received message. + + *session*, an ``httpx.Client``. If provided, the client session to use to send the + queries. + + *path*, a ``str``. If *where* is an IP address, then *path* will be used to + construct the URL to send the DNS query to. + + *post*, a ``bool``. If ``True``, the default, POST method will be used. + + *bootstrap_address*, a ``str``, the IP address to use to bypass resolution. + + *verify*, a ``bool`` or ``str``. If a ``True``, then TLS certificate verification + of the server is done using the default CA bundle; if ``False``, then no + verification is done; if a `str` then it specifies the path to a certificate file or + directory which will be used for verification. + + *resolver*, a ``dns.resolver.Resolver`` or ``None``, the resolver to use for + resolution of hostnames in URLs. If not specified, a new resolver with a default + configuration will be used; note this is *not* the default resolver as that resolver + might have been configured to use DoH causing a chicken-and-egg problem. This + parameter only has an effect if the HTTP library is httpx. + + *family*, an ``int``, the address family. If socket.AF_UNSPEC (the default), both A + and AAAA records will be retrieved. + + Returns a ``dns.message.Message``. + """ + + if not have_doh: + raise NoDOH # pragma: no cover + if session and not isinstance(session, httpx.Client): + raise ValueError("session parameter must be an httpx.Client") + + wire = q.to_wire() + (af, _, the_source) = _destination_and_source( + where, port, source, source_port, False + ) + transport = None + headers = {"accept": "application/dns-message"} + if af is not None and dns.inet.is_address(where): + if af == socket.AF_INET: + url = "https://{}:{}{}".format(where, port, path) + elif af == socket.AF_INET6: + url = "https://[{}]:{}{}".format(where, port, path) + else: + url = where + + # set source port and source address + + if the_source is None: + local_address = None + local_port = 0 + else: + local_address = the_source[0] + local_port = the_source[1] + transport = _HTTPTransport( + local_address=local_address, + http1=True, + http2=True, + verify=verify, + local_port=local_port, + bootstrap_address=bootstrap_address, + resolver=resolver, + family=family, + ) + + if session: + cm: contextlib.AbstractContextManager = contextlib.nullcontext(session) + else: + cm = httpx.Client(http1=True, http2=True, verify=verify, transport=transport) + with cm as session: + # see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH + # GET and POST examples + if post: + headers.update( + { + "content-type": "application/dns-message", + "content-length": str(len(wire)), + } + ) + response = session.post(url, headers=headers, content=wire, timeout=timeout) + else: + wire = base64.urlsafe_b64encode(wire).rstrip(b"=") + twire = wire.decode() # httpx does a repr() if we give it bytes + response = session.get( + url, headers=headers, timeout=timeout, params={"dns": twire} + ) + + # see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH + # status codes + if response.status_code < 200 or response.status_code > 299: + raise ValueError( + "{} responded with status code {}" + "\nResponse body: {}".format(where, response.status_code, response.content) + ) + r = dns.message.from_wire( + response.content, + keyring=q.keyring, + request_mac=q.request_mac, + one_rr_per_rrset=one_rr_per_rrset, + 
+        ignore_trailing=ignore_trailing,
+    )
+    r.time = response.elapsed.total_seconds()
+    if not q.is_response(r):
+        raise BadResponse
+    return r
+
+
+def _udp_recv(sock, max_size, expiration):
+    """Reads a datagram from the socket.
+    A Timeout exception will be raised if the operation is not completed
+    by the expiration time.
+    """
+    while True:
+        try:
+            return sock.recvfrom(max_size)
+        except BlockingIOError:
+            _wait_for_readable(sock, expiration)
+
+
+def _udp_send(sock, data, destination, expiration):
+    """Sends the specified datagram to destination over the socket.
+    A Timeout exception will be raised if the operation is not completed
+    by the expiration time.
+    """
+    while True:
+        try:
+            if destination:
+                return sock.sendto(data, destination)
+            else:
+                return sock.send(data)
+        except BlockingIOError:  # pragma: no cover
+            _wait_for_writable(sock, expiration)
+
+
+def send_udp(
+    sock: Any,
+    what: Union[dns.message.Message, bytes],
+    destination: Any,
+    expiration: Optional[float] = None,
+) -> Tuple[int, float]:
+    """Send a DNS message to the specified UDP socket.
+
+    *sock*, a ``socket``.
+
+    *what*, a ``bytes`` or ``dns.message.Message``, the message to send.
+
+    *destination*, a destination tuple appropriate for the address family
+    of the socket, specifying where to send the query.
+
+    *expiration*, a ``float`` or ``None``, the absolute time at which
+    a timeout exception should be raised. If ``None``, no timeout will
+    occur.
+
+    Returns an ``(int, float)`` tuple of bytes sent and the sent time.
+    """
+
+    if isinstance(what, dns.message.Message):
+        what = what.to_wire()
+    sent_time = time.time()
+    n = _udp_send(sock, what, destination, expiration)
+    return (n, sent_time)
+
+
+def receive_udp(
+    sock: Any,
+    destination: Optional[Any] = None,
+    expiration: Optional[float] = None,
+    ignore_unexpected: bool = False,
+    one_rr_per_rrset: bool = False,
+    keyring: Optional[Dict[dns.name.Name, dns.tsig.Key]] = None,
+    request_mac: Optional[bytes] = b"",
+    ignore_trailing: bool = False,
+    raise_on_truncation: bool = False,
+    ignore_errors: bool = False,
+    query: Optional[dns.message.Message] = None,
+) -> Any:
+    """Read a DNS message from a UDP socket.
+
+    *sock*, a ``socket``.
+
+    *destination*, a destination tuple appropriate for the address family
+    of the socket, specifying where the message is expected to arrive from.
+    When receiving a response, this would be where the associated query was
+    sent.
+
+    *expiration*, a ``float`` or ``None``, the absolute time at which
+    a timeout exception should be raised. If ``None``, no timeout will
+    occur.
+
+    *ignore_unexpected*, a ``bool``. If ``True``, ignore responses from
+    unexpected sources.
+
+    *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own
+    RRset.
+
+    *keyring*, a ``dict``, the keyring to use for TSIG.
+
+    *request_mac*, a ``bytes`` or ``None``, the MAC of the request (for TSIG).
+
+    *ignore_trailing*, a ``bool``. If ``True``, ignore trailing
+    junk at end of the received message.
+
+    *raise_on_truncation*, a ``bool``. If ``True``, raise an exception if
+    the TC bit is set.
+
+    Raises if the message is malformed, if network errors occur, or if
+    there is a timeout.
+
+    If *destination* is not ``None``, returns a ``(dns.message.Message, float)``
+    tuple of the received message and the received time.
+
+    If *destination* is ``None``, returns a
+    ``(dns.message.Message, float, tuple)``
+    tuple of the received message, the received time, and the address where
+    the message arrived from.
+
+    *ignore_errors*, a ``bool``. If various format errors or response
+    mismatches occur, ignore them and keep listening for a valid response.
+    The default is ``False``.
+
+    *query*, a ``dns.message.Message`` or ``None``. If not ``None`` and
+    *ignore_errors* is ``True``, check that the received message is a response
+    to this query, and if not keep listening for a valid response.
+    """
+
+    wire = b""
+    while True:
+        (wire, from_address) = _udp_recv(sock, 65535, expiration)
+        if not _matches_destination(
+            sock.family, from_address, destination, ignore_unexpected
+        ):
+            continue
+        received_time = time.time()
+        try:
+            r = dns.message.from_wire(
+                wire,
+                keyring=keyring,
+                request_mac=request_mac,
+                one_rr_per_rrset=one_rr_per_rrset,
+                ignore_trailing=ignore_trailing,
+                raise_on_truncation=raise_on_truncation,
+            )
+        except dns.message.Truncated as e:
+            # If we got Truncated and not FORMERR, we at least got the header with TC
+            # set, and very likely the question section, so we'll re-raise if the
+            # message seems to be a response as we need to know when truncation happens.
+            # We need to check that it seems to be a response as we don't want a random
+            # injected message with TC set to cause us to bail out.
+            if (
+                ignore_errors
+                and query is not None
+                and not query.is_response(e.message())
+            ):
+                continue
+            else:
+                raise
+        except Exception:
+            if ignore_errors:
+                continue
+            else:
+                raise
+        if ignore_errors and query is not None and not query.is_response(r):
+            continue
+        if destination:
+            return (r, received_time)
+        else:
+            return (r, received_time, from_address)
+
+
+def udp(
+    q: dns.message.Message,
+    where: str,
+    timeout: Optional[float] = None,
+    port: int = 53,
+    source: Optional[str] = None,
+    source_port: int = 0,
+    ignore_unexpected: bool = False,
+    one_rr_per_rrset: bool = False,
+    ignore_trailing: bool = False,
+    raise_on_truncation: bool = False,
+    sock: Optional[Any] = None,
+    ignore_errors: bool = False,
+) -> dns.message.Message:
+    """Return the response obtained after sending a query via UDP.
+
+    *q*, a ``dns.message.Message``, the query to send.
+
+    *where*, a ``str`` containing an IPv4 or IPv6 address, where
+    to send the message.
+
+    *timeout*, a ``float`` or ``None``, the number of seconds to wait before the
+    query times out. If ``None``, the default, wait forever.
+
+    *port*, an ``int``, the port to send the message to. The default is 53.
+
+    *source*, a ``str`` containing an IPv4 or IPv6 address, specifying
+    the source address. The default is the wildcard address.
+
+    *source_port*, an ``int``, the port from which to send the message.
+    The default is 0.
+
+    *ignore_unexpected*, a ``bool``. If ``True``, ignore responses from
+    unexpected sources.
+
+    *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own
+    RRset.
+
+    *ignore_trailing*, a ``bool``. If ``True``, ignore trailing
+    junk at end of the received message.
+
+    *raise_on_truncation*, a ``bool``. If ``True``, raise an exception if
+    the TC bit is set.
+
+    *sock*, a ``socket.socket``, or ``None``, the socket to use for the
+    query. If ``None``, the default, a socket is created. Note that
+    if a socket is provided, it must be a nonblocking datagram socket,
+    and the *source* and *source_port* are ignored.
+
+    *ignore_errors*, a ``bool``. If various format errors or response
+    mismatches occur, ignore them and keep listening for a valid response.
+    The default is ``False``.
+
+    Returns a ``dns.message.Message``.
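+
+    A minimal usage sketch (the nameserver address 203.0.113.53 below is just
+    a placeholder)::
+
+        import dns.message
+        import dns.query
+
+        q = dns.message.make_query("example.com", "A")
+        r = dns.query.udp(q, "203.0.113.53", timeout=5.0)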
+ """ + + wire = q.to_wire() + (af, destination, source) = _destination_and_source( + where, port, source, source_port + ) + (begin_time, expiration) = _compute_times(timeout) + if sock: + cm: contextlib.AbstractContextManager = contextlib.nullcontext(sock) + else: + cm = _make_socket(af, socket.SOCK_DGRAM, source) + with cm as s: + send_udp(s, wire, destination, expiration) + (r, received_time) = receive_udp( + s, + destination, + expiration, + ignore_unexpected, + one_rr_per_rrset, + q.keyring, + q.mac, + ignore_trailing, + raise_on_truncation, + ignore_errors, + q, + ) + r.time = received_time - begin_time + # We don't need to check q.is_response() if we are in ignore_errors mode + # as receive_udp() will have checked it. + if not (ignore_errors or q.is_response(r)): + raise BadResponse + return r + assert ( + False # help mypy figure out we can't get here lgtm[py/unreachable-statement] + ) + + +def udp_with_fallback( + q: dns.message.Message, + where: str, + timeout: Optional[float] = None, + port: int = 53, + source: Optional[str] = None, + source_port: int = 0, + ignore_unexpected: bool = False, + one_rr_per_rrset: bool = False, + ignore_trailing: bool = False, + udp_sock: Optional[Any] = None, + tcp_sock: Optional[Any] = None, + ignore_errors: bool = False, +) -> Tuple[dns.message.Message, bool]: + """Return the response to the query, trying UDP first and falling back + to TCP if UDP results in a truncated response. + + *q*, a ``dns.message.Message``, the query to send + + *where*, a ``str`` containing an IPv4 or IPv6 address, where to send the message. + + *timeout*, a ``float`` or ``None``, the number of seconds to wait before the query + times out. If ``None``, the default, wait forever. + + *port*, an ``int``, the port send the message to. The default is 53. + + *source*, a ``str`` containing an IPv4 or IPv6 address, specifying the source + address. The default is the wildcard address. + + *source_port*, an ``int``, the port from which to send the message. The default is + 0. + + *ignore_unexpected*, a ``bool``. If ``True``, ignore responses from unexpected + sources. + + *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own RRset. + + *ignore_trailing*, a ``bool``. If ``True``, ignore trailing junk at end of the + received message. + + *udp_sock*, a ``socket.socket``, or ``None``, the socket to use for the UDP query. + If ``None``, the default, a socket is created. Note that if a socket is provided, + it must be a nonblocking datagram socket, and the *source* and *source_port* are + ignored for the UDP query. + + *tcp_sock*, a ``socket.socket``, or ``None``, the connected socket to use for the + TCP query. If ``None``, the default, a socket is created. Note that if a socket is + provided, it must be a nonblocking connected stream socket, and *where*, *source* + and *source_port* are ignored for the TCP query. + + *ignore_errors*, a ``bool``. If various format errors or response mismatches occur + while listening for UDP, ignore them and keep listening for a valid response. The + default is ``False``. + + Returns a (``dns.message.Message``, tcp) tuple where tcp is ``True`` if and only if + TCP was used. 
+ """ + try: + response = udp( + q, + where, + timeout, + port, + source, + source_port, + ignore_unexpected, + one_rr_per_rrset, + ignore_trailing, + True, + udp_sock, + ignore_errors, + ) + return (response, False) + except dns.message.Truncated: + response = tcp( + q, + where, + timeout, + port, + source, + source_port, + one_rr_per_rrset, + ignore_trailing, + tcp_sock, + ) + return (response, True) + + +def _net_read(sock, count, expiration): + """Read the specified number of bytes from sock. Keep trying until we + either get the desired amount, or we hit EOF. + A Timeout exception will be raised if the operation is not completed + by the expiration time. + """ + s = b"" + while count > 0: + try: + n = sock.recv(count) + if n == b"": + raise EOFError + count -= len(n) + s += n + except (BlockingIOError, ssl.SSLWantReadError): + _wait_for_readable(sock, expiration) + except ssl.SSLWantWriteError: # pragma: no cover + _wait_for_writable(sock, expiration) + return s + + +def _net_write(sock, data, expiration): + """Write the specified data to the socket. + A Timeout exception will be raised if the operation is not completed + by the expiration time. + """ + current = 0 + l = len(data) + while current < l: + try: + current += sock.send(data[current:]) + except (BlockingIOError, ssl.SSLWantWriteError): + _wait_for_writable(sock, expiration) + except ssl.SSLWantReadError: # pragma: no cover + _wait_for_readable(sock, expiration) + + +def send_tcp( + sock: Any, + what: Union[dns.message.Message, bytes], + expiration: Optional[float] = None, +) -> Tuple[int, float]: + """Send a DNS message to the specified TCP socket. + + *sock*, a ``socket``. + + *what*, a ``bytes`` or ``dns.message.Message``, the message to send. + + *expiration*, a ``float`` or ``None``, the absolute time at which + a timeout exception should be raised. If ``None``, no timeout will + occur. + + Returns an ``(int, float)`` tuple of bytes sent and the sent time. + """ + + if isinstance(what, dns.message.Message): + tcpmsg = what.to_wire(prepend_length=True) + else: + # copying the wire into tcpmsg is inefficient, but lets us + # avoid writev() or doing a short write that would get pushed + # onto the net + tcpmsg = len(what).to_bytes(2, "big") + what + sent_time = time.time() + _net_write(sock, tcpmsg, expiration) + return (len(tcpmsg), sent_time) + + +def receive_tcp( + sock: Any, + expiration: Optional[float] = None, + one_rr_per_rrset: bool = False, + keyring: Optional[Dict[dns.name.Name, dns.tsig.Key]] = None, + request_mac: Optional[bytes] = b"", + ignore_trailing: bool = False, +) -> Tuple[dns.message.Message, float]: + """Read a DNS message from a TCP socket. + + *sock*, a ``socket``. + + *expiration*, a ``float`` or ``None``, the absolute time at which + a timeout exception should be raised. If ``None``, no timeout will + occur. + + *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own + RRset. + + *keyring*, a ``dict``, the keyring to use for TSIG. + + *request_mac*, a ``bytes`` or ``None``, the MAC of the request (for TSIG). + + *ignore_trailing*, a ``bool``. If ``True``, ignore trailing + junk at end of the received message. + + Raises if the message is malformed, if network errors occur, of if + there is a timeout. + + Returns a ``(dns.message.Message, float)`` tuple of the received message + and the received time. 
+ """ + + ldata = _net_read(sock, 2, expiration) + (l,) = struct.unpack("!H", ldata) + wire = _net_read(sock, l, expiration) + received_time = time.time() + r = dns.message.from_wire( + wire, + keyring=keyring, + request_mac=request_mac, + one_rr_per_rrset=one_rr_per_rrset, + ignore_trailing=ignore_trailing, + ) + return (r, received_time) + + +def _connect(s, address, expiration): + err = s.connect_ex(address) + if err == 0: + return + if err in (errno.EINPROGRESS, errno.EWOULDBLOCK, errno.EALREADY): + _wait_for_writable(s, expiration) + err = s.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) + if err != 0: + raise OSError(err, os.strerror(err)) + + +def tcp( + q: dns.message.Message, + where: str, + timeout: Optional[float] = None, + port: int = 53, + source: Optional[str] = None, + source_port: int = 0, + one_rr_per_rrset: bool = False, + ignore_trailing: bool = False, + sock: Optional[Any] = None, +) -> dns.message.Message: + """Return the response obtained after sending a query via TCP. + + *q*, a ``dns.message.Message``, the query to send + + *where*, a ``str`` containing an IPv4 or IPv6 address, where + to send the message. + + *timeout*, a ``float`` or ``None``, the number of seconds to wait before the + query times out. If ``None``, the default, wait forever. + + *port*, an ``int``, the port send the message to. The default is 53. + + *source*, a ``str`` containing an IPv4 or IPv6 address, specifying + the source address. The default is the wildcard address. + + *source_port*, an ``int``, the port from which to send the message. + The default is 0. + + *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own + RRset. + + *ignore_trailing*, a ``bool``. If ``True``, ignore trailing + junk at end of the received message. + + *sock*, a ``socket.socket``, or ``None``, the connected socket to use for the + query. If ``None``, the default, a socket is created. Note that + if a socket is provided, it must be a nonblocking connected stream + socket, and *where*, *port*, *source* and *source_port* are ignored. + + Returns a ``dns.message.Message``. 
+ """ + + wire = q.to_wire() + (begin_time, expiration) = _compute_times(timeout) + if sock: + cm: contextlib.AbstractContextManager = contextlib.nullcontext(sock) + else: + (af, destination, source) = _destination_and_source( + where, port, source, source_port + ) + cm = _make_socket(af, socket.SOCK_STREAM, source) + with cm as s: + if not sock: + _connect(s, destination, expiration) + send_tcp(s, wire, expiration) + (r, received_time) = receive_tcp( + s, expiration, one_rr_per_rrset, q.keyring, q.mac, ignore_trailing + ) + r.time = received_time - begin_time + if not q.is_response(r): + raise BadResponse + return r + assert ( + False # help mypy figure out we can't get here lgtm[py/unreachable-statement] + ) + + +def _tls_handshake(s, expiration): + while True: + try: + s.do_handshake() + return + except ssl.SSLWantReadError: + _wait_for_readable(s, expiration) + except ssl.SSLWantWriteError: # pragma: no cover + _wait_for_writable(s, expiration) + + +def _make_dot_ssl_context( + server_hostname: Optional[str], verify: Union[bool, str] +) -> ssl.SSLContext: + cafile: Optional[str] = None + capath: Optional[str] = None + if isinstance(verify, str): + if os.path.isfile(verify): + cafile = verify + elif os.path.isdir(verify): + capath = verify + else: + raise ValueError("invalid verify string") + ssl_context = ssl.create_default_context(cafile=cafile, capath=capath) + ssl_context.minimum_version = ssl.TLSVersion.TLSv1_2 + if server_hostname is None: + ssl_context.check_hostname = False + ssl_context.set_alpn_protocols(["dot"]) + if verify is False: + ssl_context.verify_mode = ssl.CERT_NONE + return ssl_context + + +def tls( + q: dns.message.Message, + where: str, + timeout: Optional[float] = None, + port: int = 853, + source: Optional[str] = None, + source_port: int = 0, + one_rr_per_rrset: bool = False, + ignore_trailing: bool = False, + sock: Optional[ssl.SSLSocket] = None, + ssl_context: Optional[ssl.SSLContext] = None, + server_hostname: Optional[str] = None, + verify: Union[bool, str] = True, +) -> dns.message.Message: + """Return the response obtained after sending a query via TLS. + + *q*, a ``dns.message.Message``, the query to send + + *where*, a ``str`` containing an IPv4 or IPv6 address, where + to send the message. + + *timeout*, a ``float`` or ``None``, the number of seconds to wait before the + query times out. If ``None``, the default, wait forever. + + *port*, an ``int``, the port send the message to. The default is 853. + + *source*, a ``str`` containing an IPv4 or IPv6 address, specifying + the source address. The default is the wildcard address. + + *source_port*, an ``int``, the port from which to send the message. + The default is 0. + + *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own + RRset. + + *ignore_trailing*, a ``bool``. If ``True``, ignore trailing + junk at end of the received message. + + *sock*, an ``ssl.SSLSocket``, or ``None``, the socket to use for + the query. If ``None``, the default, a socket is created. Note + that if a socket is provided, it must be a nonblocking connected + SSL stream socket, and *where*, *port*, *source*, *source_port*, + and *ssl_context* are ignored. + + *ssl_context*, an ``ssl.SSLContext``, the context to use when establishing + a TLS connection. If ``None``, the default, creates one with the default + configuration. + + *server_hostname*, a ``str`` containing the server's hostname. 
+    *server_hostname*, a ``str`` containing the server's hostname. The
+    default is ``None``, which means that no hostname is known, and if an
+    SSL context is created, hostname checking will be disabled.
+
+    *verify*, a ``bool`` or ``str``. If ``True``, then TLS certificate verification
+    of the server is done using the default CA bundle; if ``False``, then no
+    verification is done; if a ``str`` then it specifies the path to a certificate
+    file or directory which will be used for verification.
+
+    Returns a ``dns.message.Message``.
+
+    """
+
+    if sock:
+        #
+        # If a socket was provided, there's no special TLS handling needed.
+        #
+        return tcp(
+            q,
+            where,
+            timeout,
+            port,
+            source,
+            source_port,
+            one_rr_per_rrset,
+            ignore_trailing,
+            sock,
+        )
+
+    wire = q.to_wire()
+    (begin_time, expiration) = _compute_times(timeout)
+    (af, destination, source) = _destination_and_source(
+        where, port, source, source_port
+    )
+    if ssl_context is None and not sock:
+        ssl_context = _make_dot_ssl_context(server_hostname, verify)
+
+    with _make_socket(
+        af,
+        socket.SOCK_STREAM,
+        source,
+        ssl_context=ssl_context,
+        server_hostname=server_hostname,
+    ) as s:
+        _connect(s, destination, expiration)
+        _tls_handshake(s, expiration)
+        send_tcp(s, wire, expiration)
+        (r, received_time) = receive_tcp(
+            s, expiration, one_rr_per_rrset, q.keyring, q.mac, ignore_trailing
+        )
+        r.time = received_time - begin_time
+        if not q.is_response(r):
+            raise BadResponse
+        return r
+    assert (
+        False  # help mypy figure out we can't get here lgtm[py/unreachable-statement]
+    )
+
+
+def quic(
+    q: dns.message.Message,
+    where: str,
+    timeout: Optional[float] = None,
+    port: int = 853,
+    source: Optional[str] = None,
+    source_port: int = 0,
+    one_rr_per_rrset: bool = False,
+    ignore_trailing: bool = False,
+    connection: Optional[dns.quic.SyncQuicConnection] = None,
+    verify: Union[bool, str] = True,
+    server_hostname: Optional[str] = None,
+) -> dns.message.Message:
+    """Return the response obtained after sending a query via DNS-over-QUIC.
+
+    *q*, a ``dns.message.Message``, the query to send.
+
+    *where*, a ``str``, the nameserver IP address.
+
+    *timeout*, a ``float`` or ``None``, the number of seconds to wait before the query
+    times out. If ``None``, the default, wait forever.
+
+    *port*, an ``int``, the port to send the query to. The default is 853.
+
+    *source*, a ``str`` containing an IPv4 or IPv6 address, specifying the source
+    address. The default is the wildcard address.
+
+    *source_port*, an ``int``, the port from which to send the message. The default is
+    0.
+
+    *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own RRset.
+
+    *ignore_trailing*, a ``bool``. If ``True``, ignore trailing junk at end of the
+    received message.
+
+    *connection*, a ``dns.quic.SyncQuicConnection``. If provided, the
+    connection to use to send the query.
+
+    *verify*, a ``bool`` or ``str``. If ``True``, then TLS certificate verification
+    of the server is done using the default CA bundle; if ``False``, then no
+    verification is done; if a ``str`` then it specifies the path to a certificate
+    file or directory which will be used for verification.
+
+    *server_hostname*, a ``str`` containing the server's hostname. The
+    default is ``None``, which means that no hostname is known, and if an
+    SSL context is created, hostname checking will be disabled.
+
+    Returns a ``dns.message.Message``.
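+
+    A minimal usage sketch (placeholder server address; DNS-over-QUIC also
+    requires the optional aioquic dependency)::
+
+        q = dns.message.make_query("example.com", "A")
+        r = dns.query.quic(q, "203.0.113.53")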
+ """ + + if not dns.quic.have_quic: + raise NoDOQ("DNS-over-QUIC is not available.") # pragma: no cover + + q.id = 0 + wire = q.to_wire() + the_connection: dns.quic.SyncQuicConnection + the_manager: dns.quic.SyncQuicManager + if connection: + manager: contextlib.AbstractContextManager = contextlib.nullcontext(None) + the_connection = connection + else: + manager = dns.quic.SyncQuicManager( + verify_mode=verify, server_name=server_hostname + ) + the_manager = manager # for type checking happiness + + with manager: + if not connection: + the_connection = the_manager.connect(where, port, source, source_port) + (start, expiration) = _compute_times(timeout) + with the_connection.make_stream(timeout) as stream: + stream.send(wire, True) + wire = stream.receive(_remaining(expiration)) + finish = time.time() + r = dns.message.from_wire( + wire, + keyring=q.keyring, + request_mac=q.request_mac, + one_rr_per_rrset=one_rr_per_rrset, + ignore_trailing=ignore_trailing, + ) + r.time = max(finish - start, 0.0) + if not q.is_response(r): + raise BadResponse + return r + + +def xfr( + where: str, + zone: Union[dns.name.Name, str], + rdtype: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.AXFR, + rdclass: Union[dns.rdataclass.RdataClass, str] = dns.rdataclass.IN, + timeout: Optional[float] = None, + port: int = 53, + keyring: Optional[Dict[dns.name.Name, dns.tsig.Key]] = None, + keyname: Optional[Union[dns.name.Name, str]] = None, + relativize: bool = True, + lifetime: Optional[float] = None, + source: Optional[str] = None, + source_port: int = 0, + serial: int = 0, + use_udp: bool = False, + keyalgorithm: Union[dns.name.Name, str] = dns.tsig.default_algorithm, +) -> Any: + """Return a generator for the responses to a zone transfer. + + *where*, a ``str`` containing an IPv4 or IPv6 address, where + to send the message. + + *zone*, a ``dns.name.Name`` or ``str``, the name of the zone to transfer. + + *rdtype*, an ``int`` or ``str``, the type of zone transfer. The + default is ``dns.rdatatype.AXFR``. ``dns.rdatatype.IXFR`` can be + used to do an incremental transfer instead. + + *rdclass*, an ``int`` or ``str``, the class of the zone transfer. + The default is ``dns.rdataclass.IN``. + + *timeout*, a ``float``, the number of seconds to wait for each + response message. If None, the default, wait forever. + + *port*, an ``int``, the port send the message to. The default is 53. + + *keyring*, a ``dict``, the keyring to use for TSIG. + + *keyname*, a ``dns.name.Name`` or ``str``, the name of the TSIG + key to use. + + *relativize*, a ``bool``. If ``True``, all names in the zone will be + relativized to the zone origin. It is essential that the + relativize setting matches the one specified to + ``dns.zone.from_xfr()`` if using this generator to make a zone. + + *lifetime*, a ``float``, the total number of seconds to spend + doing the transfer. If ``None``, the default, then there is no + limit on the time the transfer may take. + + *source*, a ``str`` containing an IPv4 or IPv6 address, specifying + the source address. The default is the wildcard address. + + *source_port*, an ``int``, the port from which to send the message. + The default is 0. + + *serial*, an ``int``, the SOA serial number to use as the base for + an IXFR diff sequence (only meaningful if *rdtype* is + ``dns.rdatatype.IXFR``). + + *use_udp*, a ``bool``. If ``True``, use UDP (only meaningful for IXFR). + + *keyalgorithm*, a ``dns.name.Name`` or ``str``, the TSIG algorithm to use. + + Raises on errors, and so does the generator. 
+ + Returns a generator of ``dns.message.Message`` objects. + """ + + if isinstance(zone, str): + zone = dns.name.from_text(zone) + rdtype = dns.rdatatype.RdataType.make(rdtype) + q = dns.message.make_query(zone, rdtype, rdclass) + if rdtype == dns.rdatatype.IXFR: + rrset = dns.rrset.from_text(zone, 0, "IN", "SOA", ". . %u 0 0 0 0" % serial) + q.authority.append(rrset) + if keyring is not None: + q.use_tsig(keyring, keyname, algorithm=keyalgorithm) + wire = q.to_wire() + (af, destination, source) = _destination_and_source( + where, port, source, source_port + ) + if use_udp and rdtype != dns.rdatatype.IXFR: + raise ValueError("cannot do a UDP AXFR") + sock_type = socket.SOCK_DGRAM if use_udp else socket.SOCK_STREAM + with _make_socket(af, sock_type, source) as s: + (_, expiration) = _compute_times(lifetime) + _connect(s, destination, expiration) + l = len(wire) + if use_udp: + _udp_send(s, wire, None, expiration) + else: + tcpmsg = struct.pack("!H", l) + wire + _net_write(s, tcpmsg, expiration) + done = False + delete_mode = True + expecting_SOA = False + soa_rrset = None + if relativize: + origin = zone + oname = dns.name.empty + else: + origin = None + oname = zone + tsig_ctx = None + while not done: + (_, mexpiration) = _compute_times(timeout) + if mexpiration is None or ( + expiration is not None and mexpiration > expiration + ): + mexpiration = expiration + if use_udp: + (wire, _) = _udp_recv(s, 65535, mexpiration) + else: + ldata = _net_read(s, 2, mexpiration) + (l,) = struct.unpack("!H", ldata) + wire = _net_read(s, l, mexpiration) + is_ixfr = rdtype == dns.rdatatype.IXFR + r = dns.message.from_wire( + wire, + keyring=q.keyring, + request_mac=q.mac, + xfr=True, + origin=origin, + tsig_ctx=tsig_ctx, + multi=True, + one_rr_per_rrset=is_ixfr, + ) + rcode = r.rcode() + if rcode != dns.rcode.NOERROR: + raise TransferError(rcode) + tsig_ctx = r.tsig_ctx + answer_index = 0 + if soa_rrset is None: + if not r.answer or r.answer[0].name != oname: + raise dns.exception.FormError("No answer or RRset not for qname") + rrset = r.answer[0] + if rrset.rdtype != dns.rdatatype.SOA: + raise dns.exception.FormError("first RRset is not an SOA") + answer_index = 1 + soa_rrset = rrset.copy() + if rdtype == dns.rdatatype.IXFR: + if dns.serial.Serial(soa_rrset[0].serial) <= serial: + # + # We're already up-to-date. + # + done = True + else: + expecting_SOA = True + # + # Process SOAs in the answer section (other than the initial + # SOA in the first message). + # + for rrset in r.answer[answer_index:]: + if done: + raise dns.exception.FormError("answers after final SOA") + if rrset.rdtype == dns.rdatatype.SOA and rrset.name == oname: + if expecting_SOA: + if rrset[0].serial != serial: + raise dns.exception.FormError("IXFR base serial mismatch") + expecting_SOA = False + elif rdtype == dns.rdatatype.IXFR: + delete_mode = not delete_mode + # + # If this SOA RRset is equal to the first we saw then we're + # finished. If this is an IXFR we also check that we're + # seeing the record in the expected part of the response. + # + if rrset == soa_rrset and ( + rdtype == dns.rdatatype.AXFR + or (rdtype == dns.rdatatype.IXFR and delete_mode) + ): + done = True + elif expecting_SOA: + # + # We made an IXFR request and are expecting another + # SOA RR, but saw something else, so this must be an + # AXFR response. 
+                    #
+                    rdtype = dns.rdatatype.AXFR
+                    expecting_SOA = False
+            if done and q.keyring and not r.had_tsig:
+                raise dns.exception.FormError("missing TSIG")
+            yield r
+
+
+class UDPMode(enum.IntEnum):
+    """How should UDP be used in an IXFR from :py:func:`inbound_xfr()`?
+
+    NEVER means "never use UDP; always use TCP"
+    TRY_FIRST means "try to use UDP but fall back to TCP if needed"
+    ONLY means "raise ``dns.xfr.UseTCP`` if trying UDP does not succeed"
+    """
+
+    NEVER = 0
+    TRY_FIRST = 1
+    ONLY = 2
+
+
+def inbound_xfr(
+    where: str,
+    txn_manager: dns.transaction.TransactionManager,
+    query: Optional[dns.message.Message] = None,
+    port: int = 53,
+    timeout: Optional[float] = None,
+    lifetime: Optional[float] = None,
+    source: Optional[str] = None,
+    source_port: int = 0,
+    udp_mode: UDPMode = UDPMode.NEVER,
+) -> None:
+    """Conduct an inbound transfer and apply it via a transaction from the
+    txn_manager.
+
+    *where*, a ``str`` containing an IPv4 or IPv6 address, where
+    to send the message.
+
+    *txn_manager*, a ``dns.transaction.TransactionManager``, the txn_manager
+    for this transfer (typically a ``dns.zone.Zone``).
+
+    *query*, the query to send. If not supplied, a default query is
+    constructed using information from the *txn_manager*.
+
+    *port*, an ``int``, the port to send the message to. The default is 53.
+
+    *timeout*, a ``float``, the number of seconds to wait for each
+    response message. If ``None``, the default, wait forever.
+
+    *lifetime*, a ``float``, the total number of seconds to spend
+    doing the transfer. If ``None``, the default, then there is no
+    limit on the time the transfer may take.
+
+    *source*, a ``str`` containing an IPv4 or IPv6 address, specifying
+    the source address. The default is the wildcard address.
+
+    *source_port*, an ``int``, the port from which to send the message.
+    The default is 0.
+
+    *udp_mode*, a ``dns.query.UDPMode``, determines how UDP is used
+    for IXFRs. The default is ``dns.query.UDPMode.NEVER``, i.e. only use
+    TCP. Other possibilities are ``dns.query.UDPMode.TRY_FIRST``, which
+    means "try UDP but fall back to TCP if needed", and
+    ``dns.query.UDPMode.ONLY``, which means "try UDP and raise
+    ``dns.xfr.UseTCP`` if it does not succeed".
+
+    Raises on errors.
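+
+    A minimal usage sketch (placeholder primary server address; with an
+    empty zone as the transaction manager this typically performs an AXFR)::
+
+        import dns.query
+        import dns.zone
+
+        zone = dns.zone.Zone("example.com.")
+        dns.query.inbound_xfr("203.0.113.53", zone)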
+ """ + if query is None: + (query, serial) = dns.xfr.make_query(txn_manager) + else: + serial = dns.xfr.extract_serial_from_query(query) + rdtype = query.question[0].rdtype + is_ixfr = rdtype == dns.rdatatype.IXFR + origin = txn_manager.from_wire_origin() + wire = query.to_wire() + (af, destination, source) = _destination_and_source( + where, port, source, source_port + ) + (_, expiration) = _compute_times(lifetime) + retry = True + while retry: + retry = False + if is_ixfr and udp_mode != UDPMode.NEVER: + sock_type = socket.SOCK_DGRAM + is_udp = True + else: + sock_type = socket.SOCK_STREAM + is_udp = False + with _make_socket(af, sock_type, source) as s: + _connect(s, destination, expiration) + if is_udp: + _udp_send(s, wire, None, expiration) + else: + tcpmsg = struct.pack("!H", len(wire)) + wire + _net_write(s, tcpmsg, expiration) + with dns.xfr.Inbound(txn_manager, rdtype, serial, is_udp) as inbound: + done = False + tsig_ctx = None + while not done: + (_, mexpiration) = _compute_times(timeout) + if mexpiration is None or ( + expiration is not None and mexpiration > expiration + ): + mexpiration = expiration + if is_udp: + (rwire, _) = _udp_recv(s, 65535, mexpiration) + else: + ldata = _net_read(s, 2, mexpiration) + (l,) = struct.unpack("!H", ldata) + rwire = _net_read(s, l, mexpiration) + r = dns.message.from_wire( + rwire, + keyring=query.keyring, + request_mac=query.mac, + xfr=True, + origin=origin, + tsig_ctx=tsig_ctx, + multi=(not is_udp), + one_rr_per_rrset=is_ixfr, + ) + try: + done = inbound.process_message(r) + except dns.xfr.UseTCP: + assert is_udp # should not happen if we used TCP! + if udp_mode == UDPMode.ONLY: + raise + done = True + retry = True + udp_mode = UDPMode.NEVER + continue + tsig_ctx = r.tsig_ctx + if not retry and query.keyring and not r.had_tsig: + raise dns.exception.FormError("missing TSIG") diff --git a/venv/Lib/site-packages/dns/quic/__init__.py b/venv/Lib/site-packages/dns/quic/__init__.py new file mode 100644 index 00000000..20aff345 --- /dev/null +++ b/venv/Lib/site-packages/dns/quic/__init__.py @@ -0,0 +1,75 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +import dns._features +import dns.asyncbackend + +if dns._features.have("doq"): + import aioquic.quic.configuration # type: ignore + + from dns._asyncbackend import NullContext + from dns.quic._asyncio import ( + AsyncioQuicConnection, + AsyncioQuicManager, + AsyncioQuicStream, + ) + from dns.quic._common import AsyncQuicConnection, AsyncQuicManager + from dns.quic._sync import SyncQuicConnection, SyncQuicManager, SyncQuicStream + + have_quic = True + + def null_factory( + *args, # pylint: disable=unused-argument + **kwargs, # pylint: disable=unused-argument + ): + return NullContext(None) + + def _asyncio_manager_factory( + context, *args, **kwargs # pylint: disable=unused-argument + ): + return AsyncioQuicManager(*args, **kwargs) + + # We have a context factory and a manager factory as for trio we need to have + # a nursery. 
+ + _async_factories = {"asyncio": (null_factory, _asyncio_manager_factory)} + + if dns._features.have("trio"): + import trio + + from dns.quic._trio import ( # pylint: disable=ungrouped-imports + TrioQuicConnection, + TrioQuicManager, + TrioQuicStream, + ) + + def _trio_context_factory(): + return trio.open_nursery() + + def _trio_manager_factory(context, *args, **kwargs): + return TrioQuicManager(context, *args, **kwargs) + + _async_factories["trio"] = (_trio_context_factory, _trio_manager_factory) + + def factories_for_backend(backend=None): + if backend is None: + backend = dns.asyncbackend.get_default_backend() + return _async_factories[backend.name()] + +else: # pragma: no cover + have_quic = False + + from typing import Any + + class AsyncQuicStream: # type: ignore + pass + + class AsyncQuicConnection: # type: ignore + async def make_stream(self) -> Any: + raise NotImplementedError + + class SyncQuicStream: # type: ignore + pass + + class SyncQuicConnection: # type: ignore + def make_stream(self) -> Any: + raise NotImplementedError diff --git a/venv/Lib/site-packages/dns/quic/__pycache__/__init__.cpython-312.pyc b/venv/Lib/site-packages/dns/quic/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 00000000..189878e1 Binary files /dev/null and b/venv/Lib/site-packages/dns/quic/__pycache__/__init__.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/quic/__pycache__/_asyncio.cpython-312.pyc b/venv/Lib/site-packages/dns/quic/__pycache__/_asyncio.cpython-312.pyc new file mode 100644 index 00000000..918200d7 Binary files /dev/null and b/venv/Lib/site-packages/dns/quic/__pycache__/_asyncio.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/quic/__pycache__/_common.cpython-312.pyc b/venv/Lib/site-packages/dns/quic/__pycache__/_common.cpython-312.pyc new file mode 100644 index 00000000..ca52141b Binary files /dev/null and b/venv/Lib/site-packages/dns/quic/__pycache__/_common.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/quic/__pycache__/_sync.cpython-312.pyc b/venv/Lib/site-packages/dns/quic/__pycache__/_sync.cpython-312.pyc new file mode 100644 index 00000000..9288aea0 Binary files /dev/null and b/venv/Lib/site-packages/dns/quic/__pycache__/_sync.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/quic/__pycache__/_trio.cpython-312.pyc b/venv/Lib/site-packages/dns/quic/__pycache__/_trio.cpython-312.pyc new file mode 100644 index 00000000..91631dee Binary files /dev/null and b/venv/Lib/site-packages/dns/quic/__pycache__/_trio.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/quic/_asyncio.py b/venv/Lib/site-packages/dns/quic/_asyncio.py new file mode 100644 index 00000000..0f44331f --- /dev/null +++ b/venv/Lib/site-packages/dns/quic/_asyncio.py @@ -0,0 +1,228 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +import asyncio +import socket +import ssl +import struct +import time + +import aioquic.quic.configuration # type: ignore +import aioquic.quic.connection # type: ignore +import aioquic.quic.events # type: ignore + +import dns.asyncbackend +import dns.exception +import dns.inet +from dns.quic._common import ( + QUIC_MAX_DATAGRAM, + AsyncQuicConnection, + AsyncQuicManager, + BaseQuicStream, + UnexpectedEOF, +) + + +class AsyncioQuicStream(BaseQuicStream): + def __init__(self, connection, stream_id): + super().__init__(connection, stream_id) + self._wake_up = asyncio.Condition() + + async def _wait_for_wake_up(self): + async with self._wake_up: + await self._wake_up.wait() + + async 
def wait_for(self, amount, expiration): + while True: + timeout = self._timeout_from_expiration(expiration) + if self._buffer.have(amount): + return + self._expecting = amount + try: + await asyncio.wait_for(self._wait_for_wake_up(), timeout) + except TimeoutError: + raise dns.exception.Timeout + self._expecting = 0 + + async def receive(self, timeout=None): + expiration = self._expiration_from_timeout(timeout) + await self.wait_for(2, expiration) + (size,) = struct.unpack("!H", self._buffer.get(2)) + await self.wait_for(size, expiration) + return self._buffer.get(size) + + async def send(self, datagram, is_end=False): + data = self._encapsulate(datagram) + await self._connection.write(self._stream_id, data, is_end) + + async def _add_input(self, data, is_end): + if self._common_add_input(data, is_end): + async with self._wake_up: + self._wake_up.notify() + + async def close(self): + self._close() + + # Streams are async context managers + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.close() + async with self._wake_up: + self._wake_up.notify() + return False + + +class AsyncioQuicConnection(AsyncQuicConnection): + def __init__(self, connection, address, port, source, source_port, manager=None): + super().__init__(connection, address, port, source, source_port, manager) + self._socket = None + self._handshake_complete = asyncio.Event() + self._socket_created = asyncio.Event() + self._wake_timer = asyncio.Condition() + self._receiver_task = None + self._sender_task = None + + async def _receiver(self): + try: + af = dns.inet.af_for_address(self._address) + backend = dns.asyncbackend.get_backend("asyncio") + # Note that peer is a low-level address tuple, but make_socket() wants + # a high-level address tuple, so we convert. + self._socket = await backend.make_socket( + af, socket.SOCK_DGRAM, 0, self._source, (self._peer[0], self._peer[1]) + ) + self._socket_created.set() + async with self._socket: + while not self._done: + (datagram, address) = await self._socket.recvfrom( + QUIC_MAX_DATAGRAM, None + ) + if address[0] != self._peer[0] or address[1] != self._peer[1]: + continue + self._connection.receive_datagram(datagram, address, time.time()) + # Wake up the timer in case the sender is sleeping, as there may be + # stuff to send now. 
+ async with self._wake_timer: + self._wake_timer.notify_all() + except Exception: + pass + finally: + self._done = True + async with self._wake_timer: + self._wake_timer.notify_all() + self._handshake_complete.set() + + async def _wait_for_wake_timer(self): + async with self._wake_timer: + await self._wake_timer.wait() + + async def _sender(self): + await self._socket_created.wait() + while not self._done: + datagrams = self._connection.datagrams_to_send(time.time()) + for datagram, address in datagrams: + assert address == self._peer + await self._socket.sendto(datagram, self._peer, None) + (expiration, interval) = self._get_timer_values() + try: + await asyncio.wait_for(self._wait_for_wake_timer(), interval) + except Exception: + pass + self._handle_timer(expiration) + await self._handle_events() + + async def _handle_events(self): + count = 0 + while True: + event = self._connection.next_event() + if event is None: + return + if isinstance(event, aioquic.quic.events.StreamDataReceived): + stream = self._streams.get(event.stream_id) + if stream: + await stream._add_input(event.data, event.end_stream) + elif isinstance(event, aioquic.quic.events.HandshakeCompleted): + self._handshake_complete.set() + elif isinstance(event, aioquic.quic.events.ConnectionTerminated): + self._done = True + self._receiver_task.cancel() + elif isinstance(event, aioquic.quic.events.StreamReset): + stream = self._streams.get(event.stream_id) + if stream: + await stream._add_input(b"", True) + + count += 1 + if count > 10: + # yield + count = 0 + await asyncio.sleep(0) + + async def write(self, stream, data, is_end=False): + self._connection.send_stream_data(stream, data, is_end) + async with self._wake_timer: + self._wake_timer.notify_all() + + def run(self): + if self._closed: + return + self._receiver_task = asyncio.Task(self._receiver()) + self._sender_task = asyncio.Task(self._sender()) + + async def make_stream(self, timeout=None): + try: + await asyncio.wait_for(self._handshake_complete.wait(), timeout) + except TimeoutError: + raise dns.exception.Timeout + if self._done: + raise UnexpectedEOF + stream_id = self._connection.get_next_available_stream_id(False) + stream = AsyncioQuicStream(self, stream_id) + self._streams[stream_id] = stream + return stream + + async def close(self): + if not self._closed: + self._manager.closed(self._peer[0], self._peer[1]) + self._closed = True + self._connection.close() + # sender might be blocked on this, so set it + self._socket_created.set() + async with self._wake_timer: + self._wake_timer.notify_all() + try: + await self._receiver_task + except asyncio.CancelledError: + pass + try: + await self._sender_task + except asyncio.CancelledError: + pass + await self._socket.close() + + +class AsyncioQuicManager(AsyncQuicManager): + def __init__(self, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None): + super().__init__(conf, verify_mode, AsyncioQuicConnection, server_name) + + def connect( + self, address, port=853, source=None, source_port=0, want_session_ticket=True + ): + (connection, start) = self._connect( + address, port, source, source_port, want_session_ticket + ) + if start: + connection.run() + return connection + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + # Copy the iterator into a list as exiting things will mutate the connections + # table. 
+        connections = list(self._connections.values())
+        for connection in connections:
+            await connection.close()
+        return False
diff --git a/venv/Lib/site-packages/dns/quic/_common.py b/venv/Lib/site-packages/dns/quic/_common.py
new file mode 100644
index 00000000..0eacc691
--- /dev/null
+++ b/venv/Lib/site-packages/dns/quic/_common.py
@@ -0,0 +1,224 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+import copy
+import functools
+import socket
+import struct
+import time
+from typing import Any, Optional
+
+import aioquic.quic.configuration  # type: ignore
+import aioquic.quic.connection  # type: ignore
+
+import dns.inet
+
+QUIC_MAX_DATAGRAM = 2048
+MAX_SESSION_TICKETS = 8
+# If we hit the max sessions limit we will delete this many of the oldest connections.
+# The value must be an integer > 0 and <= MAX_SESSION_TICKETS.
+SESSIONS_TO_DELETE = MAX_SESSION_TICKETS // 4
+
+
+class UnexpectedEOF(Exception):
+    pass
+
+
+class Buffer:
+    def __init__(self):
+        self._buffer = b""
+        self._seen_end = False
+
+    def put(self, data, is_end):
+        if self._seen_end:
+            return
+        self._buffer += data
+        if is_end:
+            self._seen_end = True
+
+    def have(self, amount):
+        if len(self._buffer) >= amount:
+            return True
+        if self._seen_end:
+            raise UnexpectedEOF
+        return False
+
+    def seen_end(self):
+        return self._seen_end
+
+    def get(self, amount):
+        assert self.have(amount)
+        data = self._buffer[:amount]
+        self._buffer = self._buffer[amount:]
+        return data
+
+
+class BaseQuicStream:
+    def __init__(self, connection, stream_id):
+        self._connection = connection
+        self._stream_id = stream_id
+        self._buffer = Buffer()
+        self._expecting = 0
+
+    def id(self):
+        return self._stream_id
+
+    def _expiration_from_timeout(self, timeout):
+        if timeout is not None:
+            expiration = time.time() + timeout
+        else:
+            expiration = None
+        return expiration
+
+    def _timeout_from_expiration(self, expiration):
+        if expiration is not None:
+            timeout = max(expiration - time.time(), 0.0)
+        else:
+            timeout = None
+        return timeout
+
+    # Subclasses must implement receive() as sync / async, returning a message
+    # or raising UnexpectedEOF.
+
+    def _encapsulate(self, datagram):
+        l = len(datagram)
+        return struct.pack("!H", l) + datagram
+
+    def _common_add_input(self, data, is_end):
+        self._buffer.put(data, is_end)
+        try:
+            return self._expecting > 0 and self._buffer.have(self._expecting)
+        except UnexpectedEOF:
+            return True
+
+    def _close(self):
+        self._connection.close_stream(self._stream_id)
+        self._buffer.put(b"", True)  # send EOF in case we haven't seen it.
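+
+
+# A minimal sketch (not part of the library) of the framing that
+# _encapsulate() above and the stream receive() methods implement: as with
+# DNS over TCP, each DoQ message is prefixed with its length as a
+# network-order unsigned 16-bit integer.
+#
+#     import struct
+#
+#     def frame(message: bytes) -> bytes:
+#         return struct.pack("!H", len(message)) + message
+#
+#     def unframe(wire: bytes) -> bytes:
+#         (size,) = struct.unpack("!H", wire[:2])
+#         return wire[2 : 2 + size]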
+
+
+class BaseQuicConnection:
+    def __init__(
+        self, connection, address, port, source=None, source_port=0, manager=None
+    ):
+        self._done = False
+        self._connection = connection
+        self._address = address
+        self._port = port
+        self._closed = False
+        self._manager = manager
+        self._streams = {}
+        self._af = dns.inet.af_for_address(address)
+        self._peer = dns.inet.low_level_address_tuple((address, port))
+        if source is None and source_port != 0:
+            if self._af == socket.AF_INET:
+                source = "0.0.0.0"
+            elif self._af == socket.AF_INET6:
+                source = "::"
+            else:
+                raise NotImplementedError
+        if source:
+            self._source = (source, source_port)
+        else:
+            self._source = None
+
+    def close_stream(self, stream_id):
+        del self._streams[stream_id]
+
+    def _get_timer_values(self, closed_is_special=True):
+        now = time.time()
+        expiration = self._connection.get_timer()
+        if expiration is None:
+            expiration = now + 3600  # arbitrary "big" value
+        interval = max(expiration - now, 0)
+        if self._closed and closed_is_special:
+            # lower sleep interval to avoid a race in the closing process
+            # which can lead to higher latency closing due to sleeping when
+            # we have events.
+            interval = min(interval, 0.05)
+        return (expiration, interval)
+
+    def _handle_timer(self, expiration):
+        now = time.time()
+        if expiration <= now:
+            self._connection.handle_timer(now)
+
+
+class AsyncQuicConnection(BaseQuicConnection):
+    async def make_stream(self, timeout: Optional[float] = None) -> Any:
+        pass
+
+
+class BaseQuicManager:
+    def __init__(self, conf, verify_mode, connection_factory, server_name=None):
+        self._connections = {}
+        self._connection_factory = connection_factory
+        self._session_tickets = {}
+        if conf is None:
+            verify_path = None
+            if isinstance(verify_mode, str):
+                verify_path = verify_mode
+                verify_mode = True
+            conf = aioquic.quic.configuration.QuicConfiguration(
+                alpn_protocols=["doq", "doq-i03"],
+                verify_mode=verify_mode,
+                server_name=server_name,
+            )
+            if verify_path is not None:
+                conf.load_verify_locations(verify_path)
+        self._conf = conf
+
+    def _connect(
+        self, address, port=853, source=None, source_port=0, want_session_ticket=True
+    ):
+        connection = self._connections.get((address, port))
+        if connection is not None:
+            return (connection, False)
+        conf = self._conf
+        if want_session_ticket:
+            try:
+                session_ticket = self._session_tickets.pop((address, port))
+                # We found a session ticket, so make a configuration that uses it.
+                conf = copy.copy(conf)
+                conf.session_ticket = session_ticket
+            except KeyError:
+                # No session ticket.
+                pass
+            # Whether or not we found a session ticket, we want a handler to save
+            # one.
+            session_ticket_handler = functools.partial(
+                self.save_session_ticket, address, port
+            )
+        else:
+            session_ticket_handler = None
+        qconn = aioquic.quic.connection.QuicConnection(
+            configuration=conf,
+            session_ticket_handler=session_ticket_handler,
+        )
+        lladdress = dns.inet.low_level_address_tuple((address, port))
+        qconn.connect(lladdress, time.time())
+        connection = self._connection_factory(
+            qconn, address, port, source, source_port, self
+        )
+        self._connections[(address, port)] = connection
+        return (connection, True)
+
+    def closed(self, address, port):
+        try:
+            del self._connections[(address, port)]
+        except KeyError:
+            pass
+
+    def save_session_ticket(self, address, port, ticket):
+        # We rely on dictionary keys() being in insertion order here.  We
+        # can't just popitem() as that would be LIFO which is the opposite of
+        # what we want.
+ l = len(self._session_tickets) + if l >= MAX_SESSION_TICKETS: + keys_to_delete = list(self._session_tickets.keys())[0:SESSIONS_TO_DELETE] + for key in keys_to_delete: + del self._session_tickets[key] + self._session_tickets[(address, port)] = ticket + + +class AsyncQuicManager(BaseQuicManager): + def connect(self, address, port=853, source=None, source_port=0): + raise NotImplementedError diff --git a/venv/Lib/site-packages/dns/quic/_sync.py b/venv/Lib/site-packages/dns/quic/_sync.py new file mode 100644 index 00000000..120cb5f3 --- /dev/null +++ b/venv/Lib/site-packages/dns/quic/_sync.py @@ -0,0 +1,238 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +import selectors +import socket +import ssl +import struct +import threading +import time + +import aioquic.quic.configuration # type: ignore +import aioquic.quic.connection # type: ignore +import aioquic.quic.events # type: ignore + +import dns.exception +import dns.inet +from dns.quic._common import ( + QUIC_MAX_DATAGRAM, + BaseQuicConnection, + BaseQuicManager, + BaseQuicStream, + UnexpectedEOF, +) + +# Avoid circularity with dns.query +if hasattr(selectors, "PollSelector"): + _selector_class = selectors.PollSelector # type: ignore +else: + _selector_class = selectors.SelectSelector # type: ignore + + +class SyncQuicStream(BaseQuicStream): + def __init__(self, connection, stream_id): + super().__init__(connection, stream_id) + self._wake_up = threading.Condition() + self._lock = threading.Lock() + + def wait_for(self, amount, expiration): + while True: + timeout = self._timeout_from_expiration(expiration) + with self._lock: + if self._buffer.have(amount): + return + self._expecting = amount + with self._wake_up: + if not self._wake_up.wait(timeout): + raise dns.exception.Timeout + self._expecting = 0 + + def receive(self, timeout=None): + expiration = self._expiration_from_timeout(timeout) + self.wait_for(2, expiration) + with self._lock: + (size,) = struct.unpack("!H", self._buffer.get(2)) + self.wait_for(size, expiration) + with self._lock: + return self._buffer.get(size) + + def send(self, datagram, is_end=False): + data = self._encapsulate(datagram) + self._connection.write(self._stream_id, data, is_end) + + def _add_input(self, data, is_end): + if self._common_add_input(data, is_end): + with self._wake_up: + self._wake_up.notify() + + def close(self): + with self._lock: + self._close() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + with self._wake_up: + self._wake_up.notify() + return False + + +class SyncQuicConnection(BaseQuicConnection): + def __init__(self, connection, address, port, source, source_port, manager): + super().__init__(connection, address, port, source, source_port, manager) + self._socket = socket.socket(self._af, socket.SOCK_DGRAM, 0) + if self._source is not None: + try: + self._socket.bind( + dns.inet.low_level_address_tuple(self._source, self._af) + ) + except Exception: + self._socket.close() + raise + self._socket.connect(self._peer) + (self._send_wakeup, self._receive_wakeup) = socket.socketpair() + self._receive_wakeup.setblocking(False) + self._socket.setblocking(False) + self._handshake_complete = threading.Event() + self._worker_thread = None + self._lock = threading.Lock() + + def _read(self): + count = 0 + while count < 10: + count += 1 + try: + datagram = self._socket.recv(QUIC_MAX_DATAGRAM) + except BlockingIOError: + return + with self._lock: + self._connection.receive_datagram(datagram, self._peer, 
time.time()) + + def _drain_wakeup(self): + while True: + try: + self._receive_wakeup.recv(32) + except BlockingIOError: + return + + def _worker(self): + try: + sel = _selector_class() + sel.register(self._socket, selectors.EVENT_READ, self._read) + sel.register(self._receive_wakeup, selectors.EVENT_READ, self._drain_wakeup) + while not self._done: + (expiration, interval) = self._get_timer_values(False) + items = sel.select(interval) + for key, _ in items: + key.data() + with self._lock: + self._handle_timer(expiration) + self._handle_events() + with self._lock: + datagrams = self._connection.datagrams_to_send(time.time()) + for datagram, _ in datagrams: + try: + self._socket.send(datagram) + except BlockingIOError: + # we let QUIC handle any lossage + pass + finally: + with self._lock: + self._done = True + # Ensure anyone waiting for this gets woken up. + self._handshake_complete.set() + + def _handle_events(self): + while True: + with self._lock: + event = self._connection.next_event() + if event is None: + return + if isinstance(event, aioquic.quic.events.StreamDataReceived): + with self._lock: + stream = self._streams.get(event.stream_id) + if stream: + stream._add_input(event.data, event.end_stream) + elif isinstance(event, aioquic.quic.events.HandshakeCompleted): + self._handshake_complete.set() + elif isinstance(event, aioquic.quic.events.ConnectionTerminated): + with self._lock: + self._done = True + elif isinstance(event, aioquic.quic.events.StreamReset): + with self._lock: + stream = self._streams.get(event.stream_id) + if stream: + stream._add_input(b"", True) + + def write(self, stream, data, is_end=False): + with self._lock: + self._connection.send_stream_data(stream, data, is_end) + self._send_wakeup.send(b"\x01") + + def run(self): + if self._closed: + return + self._worker_thread = threading.Thread(target=self._worker) + self._worker_thread.start() + + def make_stream(self, timeout=None): + if not self._handshake_complete.wait(timeout): + raise dns.exception.Timeout + with self._lock: + if self._done: + raise UnexpectedEOF + stream_id = self._connection.get_next_available_stream_id(False) + stream = SyncQuicStream(self, stream_id) + self._streams[stream_id] = stream + return stream + + def close_stream(self, stream_id): + with self._lock: + super().close_stream(stream_id) + + def close(self): + with self._lock: + if self._closed: + return + self._manager.closed(self._peer[0], self._peer[1]) + self._closed = True + self._connection.close() + self._send_wakeup.send(b"\x01") + self._worker_thread.join() + + +class SyncQuicManager(BaseQuicManager): + def __init__(self, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None): + super().__init__(conf, verify_mode, SyncQuicConnection, server_name) + self._lock = threading.Lock() + + def connect( + self, address, port=853, source=None, source_port=0, want_session_ticket=True + ): + with self._lock: + (connection, start) = self._connect( + address, port, source, source_port, want_session_ticket + ) + if start: + connection.run() + return connection + + def closed(self, address, port): + with self._lock: + super().closed(address, port) + + def save_session_ticket(self, address, port, ticket): + with self._lock: + super().save_session_ticket(address, port, ticket) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + # Copy the iterator into a list as exiting things will mutate the connections + # table. 
+ connections = list(self._connections.values()) + for connection in connections: + connection.close() + return False diff --git a/venv/Lib/site-packages/dns/quic/_trio.py b/venv/Lib/site-packages/dns/quic/_trio.py new file mode 100644 index 00000000..35e36b98 --- /dev/null +++ b/venv/Lib/site-packages/dns/quic/_trio.py @@ -0,0 +1,210 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +import socket +import ssl +import struct +import time + +import aioquic.quic.configuration # type: ignore +import aioquic.quic.connection # type: ignore +import aioquic.quic.events # type: ignore +import trio + +import dns.exception +import dns.inet +from dns._asyncbackend import NullContext +from dns.quic._common import ( + QUIC_MAX_DATAGRAM, + AsyncQuicConnection, + AsyncQuicManager, + BaseQuicStream, + UnexpectedEOF, +) + + +class TrioQuicStream(BaseQuicStream): + def __init__(self, connection, stream_id): + super().__init__(connection, stream_id) + self._wake_up = trio.Condition() + + async def wait_for(self, amount): + while True: + if self._buffer.have(amount): + return + self._expecting = amount + async with self._wake_up: + await self._wake_up.wait() + self._expecting = 0 + + async def receive(self, timeout=None): + if timeout is None: + context = NullContext(None) + else: + context = trio.move_on_after(timeout) + with context: + await self.wait_for(2) + (size,) = struct.unpack("!H", self._buffer.get(2)) + await self.wait_for(size) + return self._buffer.get(size) + raise dns.exception.Timeout + + async def send(self, datagram, is_end=False): + data = self._encapsulate(datagram) + await self._connection.write(self._stream_id, data, is_end) + + async def _add_input(self, data, is_end): + if self._common_add_input(data, is_end): + async with self._wake_up: + self._wake_up.notify() + + async def close(self): + self._close() + + # Streams are async context managers + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.close() + async with self._wake_up: + self._wake_up.notify() + return False + + +class TrioQuicConnection(AsyncQuicConnection): + def __init__(self, connection, address, port, source, source_port, manager=None): + super().__init__(connection, address, port, source, source_port, manager) + self._socket = trio.socket.socket(self._af, socket.SOCK_DGRAM, 0) + self._handshake_complete = trio.Event() + self._run_done = trio.Event() + self._worker_scope = None + self._send_pending = False + + async def _worker(self): + try: + if self._source: + await self._socket.bind( + dns.inet.low_level_address_tuple(self._source, self._af) + ) + await self._socket.connect(self._peer) + while not self._done: + (expiration, interval) = self._get_timer_values(False) + if self._send_pending: + # Do not block forever if sends are pending. Even though we + # have a wake-up mechanism if we've already started the blocking + # read, the possibility of context switching in send means that + # more writes can happen while we have no wake up context, so + # we need self._send_pending to avoid (effectively) a "lost wakeup" + # race. 
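+                    # A zero interval below makes the cancel scope around
+                    # the recv() expire immediately, so this pass falls
+                    # through to the send loop instead of sleeping out the
+                    # full timer.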
+ interval = 0.0 + with trio.CancelScope( + deadline=trio.current_time() + interval + ) as self._worker_scope: + datagram = await self._socket.recv(QUIC_MAX_DATAGRAM) + self._connection.receive_datagram(datagram, self._peer, time.time()) + self._worker_scope = None + self._handle_timer(expiration) + await self._handle_events() + # We clear this now, before sending anything, as sending can cause + # context switches that do more sends. We want to know if that + # happens so we don't block a long time on the recv() above. + self._send_pending = False + datagrams = self._connection.datagrams_to_send(time.time()) + for datagram, _ in datagrams: + await self._socket.send(datagram) + finally: + self._done = True + self._handshake_complete.set() + + async def _handle_events(self): + count = 0 + while True: + event = self._connection.next_event() + if event is None: + return + if isinstance(event, aioquic.quic.events.StreamDataReceived): + stream = self._streams.get(event.stream_id) + if stream: + await stream._add_input(event.data, event.end_stream) + elif isinstance(event, aioquic.quic.events.HandshakeCompleted): + self._handshake_complete.set() + elif isinstance(event, aioquic.quic.events.ConnectionTerminated): + self._done = True + self._socket.close() + elif isinstance(event, aioquic.quic.events.StreamReset): + stream = self._streams.get(event.stream_id) + if stream: + await stream._add_input(b"", True) + count += 1 + if count > 10: + # yield + count = 0 + await trio.sleep(0) + + async def write(self, stream, data, is_end=False): + self._connection.send_stream_data(stream, data, is_end) + self._send_pending = True + if self._worker_scope is not None: + self._worker_scope.cancel() + + async def run(self): + if self._closed: + return + async with trio.open_nursery() as nursery: + nursery.start_soon(self._worker) + self._run_done.set() + + async def make_stream(self, timeout=None): + if timeout is None: + context = NullContext(None) + else: + context = trio.move_on_after(timeout) + with context: + await self._handshake_complete.wait() + if self._done: + raise UnexpectedEOF + stream_id = self._connection.get_next_available_stream_id(False) + stream = TrioQuicStream(self, stream_id) + self._streams[stream_id] = stream + return stream + raise dns.exception.Timeout + + async def close(self): + if not self._closed: + self._manager.closed(self._peer[0], self._peer[1]) + self._closed = True + self._connection.close() + self._send_pending = True + if self._worker_scope is not None: + self._worker_scope.cancel() + await self._run_done.wait() + + +class TrioQuicManager(AsyncQuicManager): + def __init__( + self, nursery, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None + ): + super().__init__(conf, verify_mode, TrioQuicConnection, server_name) + self._nursery = nursery + + def connect( + self, address, port=853, source=None, source_port=0, want_session_ticket=True + ): + (connection, start) = self._connect( + address, port, source, source_port, want_session_ticket + ) + if start: + self._nursery.start_soon(connection.run) + return connection + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + # Copy the iterator into a list as exiting things will mutate the connections + # table. 
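+        # (Each close() awaits _run_done, so every connection's worker task
+        # has fully exited before this method returns.)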
+        connections = list(self._connections.values())
+        for connection in connections:
+            await connection.close()
+        return False
diff --git a/venv/Lib/site-packages/dns/rcode.py b/venv/Lib/site-packages/dns/rcode.py
new file mode 100644
index 00000000..8e6386f8
--- /dev/null
+++ b/venv/Lib/site-packages/dns/rcode.py
@@ -0,0 +1,168 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2001-2017 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
+"""DNS Result Codes."""
+
+from typing import Tuple
+
+import dns.enum
+import dns.exception
+
+
+class Rcode(dns.enum.IntEnum):
+    #: No error
+    NOERROR = 0
+    #: Format error
+    FORMERR = 1
+    #: Server failure
+    SERVFAIL = 2
+    #: Name does not exist ("Name Error" in RFC 1035 terminology).
+    NXDOMAIN = 3
+    #: Not implemented
+    NOTIMP = 4
+    #: Refused
+    REFUSED = 5
+    #: Name exists.
+    YXDOMAIN = 6
+    #: RRset exists.
+    YXRRSET = 7
+    #: RRset does not exist.
+    NXRRSET = 8
+    #: Not authoritative.
+    NOTAUTH = 9
+    #: Name not in zone.
+    NOTZONE = 10
+    #: DSO-TYPE Not Implemented
+    DSOTYPENI = 11
+    #: Bad EDNS version.
+    BADVERS = 16
+    #: TSIG Signature Failure
+    BADSIG = 16
+    #: Key not recognized.
+    BADKEY = 17
+    #: Signature out of time window.
+    BADTIME = 18
+    #: Bad TKEY Mode.
+    BADMODE = 19
+    #: Duplicate key name.
+    BADNAME = 20
+    #: Algorithm not supported.
+    BADALG = 21
+    #: Bad Truncation
+    BADTRUNC = 22
+    #: Bad/missing Server Cookie
+    BADCOOKIE = 23
+
+    @classmethod
+    def _maximum(cls):
+        return 4095
+
+    @classmethod
+    def _unknown_exception_class(cls):
+        return UnknownRcode
+
+
+class UnknownRcode(dns.exception.DNSException):
+    """A DNS rcode is unknown."""
+
+
+def from_text(text: str) -> Rcode:
+    """Convert text into an rcode.
+
+    *text*, a ``str``, the textual rcode or an integer in textual form.
+
+    Raises ``dns.rcode.UnknownRcode`` if the rcode mnemonic is unknown.
+
+    Returns a ``dns.rcode.Rcode``.
+    """
+
+    return Rcode.from_text(text)
+
+
+def from_flags(flags: int, ednsflags: int) -> Rcode:
+    """Return the rcode value encoded by flags and ednsflags.
+
+    *flags*, an ``int``, the DNS flags field.
+
+    *ednsflags*, an ``int``, the EDNS flags field.
+
+    Raises ``ValueError`` if rcode is < 0 or > 4095.
+
+    Returns a ``dns.rcode.Rcode``.
+    """
+
+    value = (flags & 0x000F) | ((ednsflags >> 20) & 0xFF0)
+    return Rcode.make(value)
+
+
+def to_flags(value: Rcode) -> Tuple[int, int]:
+    """Return a (flags, ednsflags) tuple which encodes the rcode.
+
+    *value*, a ``dns.rcode.Rcode``, the rcode.
+
+    Raises ``ValueError`` if rcode is < 0 or > 4095.
+
+    Returns an ``(int, int)`` tuple.
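+
+    For example (a sketch): ``to_flags(Rcode.BADVERS)`` returns
+    ``(0, 0x1000000)``, since BADVERS (16) has its low four bits clear and
+    the remainder carried in the upper bits of the EDNS flags;
+    ``from_flags(0, 0x1000000)`` recovers ``Rcode.BADVERS``.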
+ """ + + if value < 0 or value > 4095: + raise ValueError("rcode must be >= 0 and <= 4095") + v = value & 0xF + ev = (value & 0xFF0) << 20 + return (v, ev) + + +def to_text(value: Rcode, tsig: bool = False) -> str: + """Convert rcode into text. + + *value*, a ``dns.rcode.Rcode``, the rcode. + + Raises ``ValueError`` if rcode is < 0 or > 4095. + + Returns a ``str``. + """ + + if tsig and value == Rcode.BADVERS: + return "BADSIG" + return Rcode.to_text(value) + + +### BEGIN generated Rcode constants + +NOERROR = Rcode.NOERROR +FORMERR = Rcode.FORMERR +SERVFAIL = Rcode.SERVFAIL +NXDOMAIN = Rcode.NXDOMAIN +NOTIMP = Rcode.NOTIMP +REFUSED = Rcode.REFUSED +YXDOMAIN = Rcode.YXDOMAIN +YXRRSET = Rcode.YXRRSET +NXRRSET = Rcode.NXRRSET +NOTAUTH = Rcode.NOTAUTH +NOTZONE = Rcode.NOTZONE +DSOTYPENI = Rcode.DSOTYPENI +BADVERS = Rcode.BADVERS +BADSIG = Rcode.BADSIG +BADKEY = Rcode.BADKEY +BADTIME = Rcode.BADTIME +BADMODE = Rcode.BADMODE +BADNAME = Rcode.BADNAME +BADALG = Rcode.BADALG +BADTRUNC = Rcode.BADTRUNC +BADCOOKIE = Rcode.BADCOOKIE + +### END generated Rcode constants diff --git a/venv/Lib/site-packages/dns/rdata.py b/venv/Lib/site-packages/dns/rdata.py new file mode 100644 index 00000000..024fd8f6 --- /dev/null +++ b/venv/Lib/site-packages/dns/rdata.py @@ -0,0 +1,884 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2001-2017 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +"""DNS rdata.""" + +import base64 +import binascii +import inspect +import io +import itertools +import random +from importlib import import_module +from typing import Any, Dict, Optional, Tuple, Union + +import dns.exception +import dns.immutable +import dns.ipv4 +import dns.ipv6 +import dns.name +import dns.rdataclass +import dns.rdatatype +import dns.tokenizer +import dns.ttl +import dns.wire + +_chunksize = 32 + +# We currently allow comparisons for rdata with relative names for backwards +# compatibility, but in the future we will not, as these kinds of comparisons +# can lead to subtle bugs if code is not carefully written. +# +# This switch allows the future behavior to be turned on so code can be +# tested with it. +_allow_relative_comparisons = True + + +class NoRelativeRdataOrdering(dns.exception.DNSException): + """An attempt was made to do an ordered comparison of one or more + rdata with relative names. The only reliable way of sorting rdata + is to use non-relativized rdata. + + """ + + +def _wordbreak(data, chunksize=_chunksize, separator=b" "): + """Break a binary string into chunks of chunksize characters separated by + a space. 
+ """ + + if not chunksize: + return data.decode() + return separator.join( + [data[i : i + chunksize] for i in range(0, len(data), chunksize)] + ).decode() + + +# pylint: disable=unused-argument + + +def _hexify(data, chunksize=_chunksize, separator=b" ", **kw): + """Convert a binary string into its hex encoding, broken up into chunks + of chunksize characters separated by a separator. + """ + + return _wordbreak(binascii.hexlify(data), chunksize, separator) + + +def _base64ify(data, chunksize=_chunksize, separator=b" ", **kw): + """Convert a binary string into its base64 encoding, broken up into chunks + of chunksize characters separated by a separator. + """ + + return _wordbreak(base64.b64encode(data), chunksize, separator) + + +# pylint: enable=unused-argument + +__escaped = b'"\\' + + +def _escapify(qstring): + """Escape the characters in a quoted string which need it.""" + + if isinstance(qstring, str): + qstring = qstring.encode() + if not isinstance(qstring, bytearray): + qstring = bytearray(qstring) + + text = "" + for c in qstring: + if c in __escaped: + text += "\\" + chr(c) + elif c >= 0x20 and c < 0x7F: + text += chr(c) + else: + text += "\\%03d" % c + return text + + +def _truncate_bitmap(what): + """Determine the index of greatest byte that isn't all zeros, and + return the bitmap that contains all the bytes less than that index. + """ + + for i in range(len(what) - 1, -1, -1): + if what[i] != 0: + return what[0 : i + 1] + return what[0:1] + + +# So we don't have to edit all the rdata classes... +_constify = dns.immutable.constify + + +@dns.immutable.immutable +class Rdata: + """Base class for all DNS rdata types.""" + + __slots__ = ["rdclass", "rdtype", "rdcomment"] + + def __init__(self, rdclass, rdtype): + """Initialize an rdata. + + *rdclass*, an ``int`` is the rdataclass of the Rdata. + + *rdtype*, an ``int`` is the rdatatype of the Rdata. + """ + + self.rdclass = self._as_rdataclass(rdclass) + self.rdtype = self._as_rdatatype(rdtype) + self.rdcomment = None + + def _get_all_slots(self): + return itertools.chain.from_iterable( + getattr(cls, "__slots__", []) for cls in self.__class__.__mro__ + ) + + def __getstate__(self): + # We used to try to do a tuple of all slots here, but it + # doesn't work as self._all_slots isn't available at + # __setstate__() time. Before that we tried to store a tuple + # of __slots__, but that didn't work as it didn't store the + # slots defined by ancestors. This older way didn't fail + # outright, but ended up with partially broken objects, e.g. + # if you unpickled an A RR it wouldn't have rdclass and rdtype + # attributes, and would compare badly. + state = {} + for slot in self._get_all_slots(): + state[slot] = getattr(self, slot) + return state + + def __setstate__(self, state): + for slot, val in state.items(): + object.__setattr__(self, slot, val) + if not hasattr(self, "rdcomment"): + # Pickled rdata from 2.0.x might not have a rdcomment, so add + # it if needed. + object.__setattr__(self, "rdcomment", None) + + def covers(self) -> dns.rdatatype.RdataType: + """Return the type a Rdata covers. + + DNS SIG/RRSIG rdatas apply to a specific type; this type is + returned by the covers() function. If the rdata type is not + SIG or RRSIG, dns.rdatatype.NONE is returned. This is useful when + creating rdatasets, allowing the rdataset to contain only RRSIGs + of a particular type, e.g. RRSIG(NS). + + Returns a ``dns.rdatatype.RdataType``. 
+ """ + + return dns.rdatatype.NONE + + def extended_rdatatype(self) -> int: + """Return a 32-bit type value, the least significant 16 bits of + which are the ordinary DNS type, and the upper 16 bits of which are + the "covered" type, if any. + + Returns an ``int``. + """ + + return self.covers() << 16 | self.rdtype + + def to_text( + self, + origin: Optional[dns.name.Name] = None, + relativize: bool = True, + **kw: Dict[str, Any], + ) -> str: + """Convert an rdata to text format. + + Returns a ``str``. + """ + + raise NotImplementedError # pragma: no cover + + def _to_wire( + self, + file: Optional[Any], + compress: Optional[dns.name.CompressType] = None, + origin: Optional[dns.name.Name] = None, + canonicalize: bool = False, + ) -> bytes: + raise NotImplementedError # pragma: no cover + + def to_wire( + self, + file: Optional[Any] = None, + compress: Optional[dns.name.CompressType] = None, + origin: Optional[dns.name.Name] = None, + canonicalize: bool = False, + ) -> bytes: + """Convert an rdata to wire format. + + Returns a ``bytes`` or ``None``. + """ + + if file: + return self._to_wire(file, compress, origin, canonicalize) + else: + f = io.BytesIO() + self._to_wire(f, compress, origin, canonicalize) + return f.getvalue() + + def to_generic( + self, origin: Optional[dns.name.Name] = None + ) -> "dns.rdata.GenericRdata": + """Creates a dns.rdata.GenericRdata equivalent of this rdata. + + Returns a ``dns.rdata.GenericRdata``. + """ + return dns.rdata.GenericRdata( + self.rdclass, self.rdtype, self.to_wire(origin=origin) + ) + + def to_digestable(self, origin: Optional[dns.name.Name] = None) -> bytes: + """Convert rdata to a format suitable for digesting in hashes. This + is also the DNSSEC canonical form. + + Returns a ``bytes``. + """ + + return self.to_wire(origin=origin, canonicalize=True) + + def __repr__(self): + covers = self.covers() + if covers == dns.rdatatype.NONE: + ctext = "" + else: + ctext = "(" + dns.rdatatype.to_text(covers) + ")" + return ( + "" + ) + + def __str__(self): + return self.to_text() + + def _cmp(self, other): + """Compare an rdata with another rdata of the same rdtype and + rdclass. + + For rdata with only absolute names: + Return < 0 if self < other in the DNSSEC ordering, 0 if self + == other, and > 0 if self > other. + For rdata with at least one relative names: + The rdata sorts before any rdata with only absolute names. + When compared with another relative rdata, all names are + made absolute as if they were relative to the root, as the + proper origin is not available. While this creates a stable + ordering, it is NOT guaranteed to be the DNSSEC ordering. + In the future, all ordering comparisons for rdata with + relative names will be disallowed. + """ + try: + our = self.to_digestable() + our_relative = False + except dns.name.NeedAbsoluteNameOrOrigin: + if _allow_relative_comparisons: + our = self.to_digestable(dns.name.root) + our_relative = True + try: + their = other.to_digestable() + their_relative = False + except dns.name.NeedAbsoluteNameOrOrigin: + if _allow_relative_comparisons: + their = other.to_digestable(dns.name.root) + their_relative = True + if _allow_relative_comparisons: + if our_relative != their_relative: + # For the purpose of comparison, all rdata with at least one + # relative name is less than an rdata with only absolute names. 
+ if our_relative: + return -1 + else: + return 1 + elif our_relative or their_relative: + raise NoRelativeRdataOrdering + if our == their: + return 0 + elif our > their: + return 1 + else: + return -1 + + def __eq__(self, other): + if not isinstance(other, Rdata): + return False + if self.rdclass != other.rdclass or self.rdtype != other.rdtype: + return False + our_relative = False + their_relative = False + try: + our = self.to_digestable() + except dns.name.NeedAbsoluteNameOrOrigin: + our = self.to_digestable(dns.name.root) + our_relative = True + try: + their = other.to_digestable() + except dns.name.NeedAbsoluteNameOrOrigin: + their = other.to_digestable(dns.name.root) + their_relative = True + if our_relative != their_relative: + return False + return our == their + + def __ne__(self, other): + if not isinstance(other, Rdata): + return True + if self.rdclass != other.rdclass or self.rdtype != other.rdtype: + return True + return not self.__eq__(other) + + def __lt__(self, other): + if ( + not isinstance(other, Rdata) + or self.rdclass != other.rdclass + or self.rdtype != other.rdtype + ): + return NotImplemented + return self._cmp(other) < 0 + + def __le__(self, other): + if ( + not isinstance(other, Rdata) + or self.rdclass != other.rdclass + or self.rdtype != other.rdtype + ): + return NotImplemented + return self._cmp(other) <= 0 + + def __ge__(self, other): + if ( + not isinstance(other, Rdata) + or self.rdclass != other.rdclass + or self.rdtype != other.rdtype + ): + return NotImplemented + return self._cmp(other) >= 0 + + def __gt__(self, other): + if ( + not isinstance(other, Rdata) + or self.rdclass != other.rdclass + or self.rdtype != other.rdtype + ): + return NotImplemented + return self._cmp(other) > 0 + + def __hash__(self): + return hash(self.to_digestable(dns.name.root)) + + @classmethod + def from_text( + cls, + rdclass: dns.rdataclass.RdataClass, + rdtype: dns.rdatatype.RdataType, + tok: dns.tokenizer.Tokenizer, + origin: Optional[dns.name.Name] = None, + relativize: bool = True, + relativize_to: Optional[dns.name.Name] = None, + ) -> "Rdata": + raise NotImplementedError # pragma: no cover + + @classmethod + def from_wire_parser( + cls, + rdclass: dns.rdataclass.RdataClass, + rdtype: dns.rdatatype.RdataType, + parser: dns.wire.Parser, + origin: Optional[dns.name.Name] = None, + ) -> "Rdata": + raise NotImplementedError # pragma: no cover + + def replace(self, **kwargs: Any) -> "Rdata": + """ + Create a new Rdata instance based on the instance replace was + invoked on. It is possible to pass different parameters to + override the corresponding properties of the base Rdata. + + Any field specific to the Rdata type can be replaced, but the + *rdtype* and *rdclass* fields cannot. + + Returns an instance of the same Rdata subclass as *self*. + """ + + # Get the constructor parameters. + parameters = inspect.signature(self.__init__).parameters # type: ignore + + # Ensure that all of the arguments correspond to valid fields. + # Don't allow rdclass or rdtype to be changed, though. + for key in kwargs: + if key == "rdcomment": + continue + if key not in parameters: + raise AttributeError( + "'{}' object has no attribute '{}'".format( + self.__class__.__name__, key + ) + ) + if key in ("rdclass", "rdtype"): + raise AttributeError( + "Cannot overwrite '{}' attribute '{}'".format( + self.__class__.__name__, key + ) + ) + + # Construct the parameter list. For each field, use the value in + # kwargs if present, and the current value otherwise. 
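+        # (inspect.signature() on the bound __init__ omits "self" and
+        # preserves declaration order, so the positional construction below
+        # is safe.)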
+ args = (kwargs.get(key, getattr(self, key)) for key in parameters) + + # Create, validate, and return the new object. + rd = self.__class__(*args) + # The comment is not set in the constructor, so give it special + # handling. + rdcomment = kwargs.get("rdcomment", self.rdcomment) + if rdcomment is not None: + object.__setattr__(rd, "rdcomment", rdcomment) + return rd + + # Type checking and conversion helpers. These are class methods as + # they don't touch object state and may be useful to others. + + @classmethod + def _as_rdataclass(cls, value): + return dns.rdataclass.RdataClass.make(value) + + @classmethod + def _as_rdatatype(cls, value): + return dns.rdatatype.RdataType.make(value) + + @classmethod + def _as_bytes( + cls, + value: Any, + encode: bool = False, + max_length: Optional[int] = None, + empty_ok: bool = True, + ) -> bytes: + if encode and isinstance(value, str): + bvalue = value.encode() + elif isinstance(value, bytearray): + bvalue = bytes(value) + elif isinstance(value, bytes): + bvalue = value + else: + raise ValueError("not bytes") + if max_length is not None and len(bvalue) > max_length: + raise ValueError("too long") + if not empty_ok and len(bvalue) == 0: + raise ValueError("empty bytes not allowed") + return bvalue + + @classmethod + def _as_name(cls, value): + # Note that proper name conversion (e.g. with origin and IDNA + # awareness) is expected to be done via from_text. This is just + # a simple thing for people invoking the constructor directly. + if isinstance(value, str): + return dns.name.from_text(value) + elif not isinstance(value, dns.name.Name): + raise ValueError("not a name") + return value + + @classmethod + def _as_uint8(cls, value): + if not isinstance(value, int): + raise ValueError("not an integer") + if value < 0 or value > 255: + raise ValueError("not a uint8") + return value + + @classmethod + def _as_uint16(cls, value): + if not isinstance(value, int): + raise ValueError("not an integer") + if value < 0 or value > 65535: + raise ValueError("not a uint16") + return value + + @classmethod + def _as_uint32(cls, value): + if not isinstance(value, int): + raise ValueError("not an integer") + if value < 0 or value > 4294967295: + raise ValueError("not a uint32") + return value + + @classmethod + def _as_uint48(cls, value): + if not isinstance(value, int): + raise ValueError("not an integer") + if value < 0 or value > 281474976710655: + raise ValueError("not a uint48") + return value + + @classmethod + def _as_int(cls, value, low=None, high=None): + if not isinstance(value, int): + raise ValueError("not an integer") + if low is not None and value < low: + raise ValueError("value too small") + if high is not None and value > high: + raise ValueError("value too large") + return value + + @classmethod + def _as_ipv4_address(cls, value): + if isinstance(value, str): + return dns.ipv4.canonicalize(value) + elif isinstance(value, bytes): + return dns.ipv4.inet_ntoa(value) + else: + raise ValueError("not an IPv4 address") + + @classmethod + def _as_ipv6_address(cls, value): + if isinstance(value, str): + return dns.ipv6.canonicalize(value) + elif isinstance(value, bytes): + return dns.ipv6.inet_ntoa(value) + else: + raise ValueError("not an IPv6 address") + + @classmethod + def _as_bool(cls, value): + if isinstance(value, bool): + return value + else: + raise ValueError("not a boolean") + + @classmethod + def _as_ttl(cls, value): + if isinstance(value, int): + return cls._as_int(value, 0, dns.ttl.MAX_TTL) + elif isinstance(value, str): + return 
dns.ttl.from_text(value) + else: + raise ValueError("not a TTL") + + @classmethod + def _as_tuple(cls, value, as_value): + try: + # For user convenience, if value is a singleton of the list + # element type, wrap it in a tuple. + return (as_value(value),) + except Exception: + # Otherwise, check each element of the iterable *value* + # against *as_value*. + return tuple(as_value(v) for v in value) + + # Processing order + + @classmethod + def _processing_order(cls, iterable): + items = list(iterable) + random.shuffle(items) + return items + + +@dns.immutable.immutable +class GenericRdata(Rdata): + """Generic Rdata Class + + This class is used for rdata types for which we have no better + implementation. It implements the DNS "unknown RRs" scheme. + """ + + __slots__ = ["data"] + + def __init__(self, rdclass, rdtype, data): + super().__init__(rdclass, rdtype) + self.data = data + + def to_text( + self, + origin: Optional[dns.name.Name] = None, + relativize: bool = True, + **kw: Dict[str, Any], + ) -> str: + return r"\# %d " % len(self.data) + _hexify(self.data, **kw) + + @classmethod + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): + token = tok.get() + if not token.is_identifier() or token.value != r"\#": + raise dns.exception.SyntaxError(r"generic rdata does not start with \#") + length = tok.get_int() + hex = tok.concatenate_remaining_identifiers(True).encode() + data = binascii.unhexlify(hex) + if len(data) != length: + raise dns.exception.SyntaxError("generic rdata hex data has wrong length") + return cls(rdclass, rdtype, data) + + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): + file.write(self.data) + + @classmethod + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + return cls(rdclass, rdtype, parser.get_remaining()) + + +_rdata_classes: Dict[Tuple[dns.rdataclass.RdataClass, dns.rdatatype.RdataType], Any] = ( + {} +) +_module_prefix = "dns.rdtypes" + + +def get_rdata_class(rdclass, rdtype): + cls = _rdata_classes.get((rdclass, rdtype)) + if not cls: + cls = _rdata_classes.get((dns.rdatatype.ANY, rdtype)) + if not cls: + rdclass_text = dns.rdataclass.to_text(rdclass) + rdtype_text = dns.rdatatype.to_text(rdtype) + rdtype_text = rdtype_text.replace("-", "_") + try: + mod = import_module( + ".".join([_module_prefix, rdclass_text, rdtype_text]) + ) + cls = getattr(mod, rdtype_text) + _rdata_classes[(rdclass, rdtype)] = cls + except ImportError: + try: + mod = import_module(".".join([_module_prefix, "ANY", rdtype_text])) + cls = getattr(mod, rdtype_text) + _rdata_classes[(dns.rdataclass.ANY, rdtype)] = cls + _rdata_classes[(rdclass, rdtype)] = cls + except ImportError: + pass + if not cls: + cls = GenericRdata + _rdata_classes[(rdclass, rdtype)] = cls + return cls + + +def from_text( + rdclass: Union[dns.rdataclass.RdataClass, str], + rdtype: Union[dns.rdatatype.RdataType, str], + tok: Union[dns.tokenizer.Tokenizer, str], + origin: Optional[dns.name.Name] = None, + relativize: bool = True, + relativize_to: Optional[dns.name.Name] = None, + idna_codec: Optional[dns.name.IDNACodec] = None, +) -> Rdata: + """Build an rdata object from text format. + + This function attempts to dynamically load a class which + implements the specified rdata class and type. If there is no + class-and-type-specific implementation, the GenericRdata class + is used. + + Once a class is chosen, its from_text() class method is called + with the parameters to this function. 
+ + If *tok* is a ``str``, then a tokenizer is created and the string + is used as its input. + + *rdclass*, a ``dns.rdataclass.RdataClass`` or ``str``, the rdataclass. + + *rdtype*, a ``dns.rdatatype.RdataType`` or ``str``, the rdatatype. + + *tok*, a ``dns.tokenizer.Tokenizer`` or a ``str``. + + *origin*, a ``dns.name.Name`` (or ``None``), the + origin to use for relative names. + + *relativize*, a ``bool``. If true, name will be relativized. + + *relativize_to*, a ``dns.name.Name`` (or ``None``), the origin to use + when relativizing names. If not set, the *origin* value will be used. + + *idna_codec*, a ``dns.name.IDNACodec``, specifies the IDNA + encoder/decoder to use if a tokenizer needs to be created. If + ``None``, the default IDNA 2003 encoder/decoder is used. If a + tokenizer is not created, then the codec associated with the tokenizer + is the one that is used. + + Returns an instance of the chosen Rdata subclass. + + """ + if isinstance(tok, str): + tok = dns.tokenizer.Tokenizer(tok, idna_codec=idna_codec) + rdclass = dns.rdataclass.RdataClass.make(rdclass) + rdtype = dns.rdatatype.RdataType.make(rdtype) + cls = get_rdata_class(rdclass, rdtype) + with dns.exception.ExceptionWrapper(dns.exception.SyntaxError): + rdata = None + if cls != GenericRdata: + # peek at first token + token = tok.get() + tok.unget(token) + if token.is_identifier() and token.value == r"\#": + # + # Known type using the generic syntax. Extract the + # wire form from the generic syntax, and then run + # from_wire on it. + # + grdata = GenericRdata.from_text( + rdclass, rdtype, tok, origin, relativize, relativize_to + ) + rdata = from_wire( + rdclass, rdtype, grdata.data, 0, len(grdata.data), origin + ) + # + # If this comparison isn't equal, then there must have been + # compressed names in the wire format, which is an error, + # there being no reasonable context to decompress with. + # + rwire = rdata.to_wire() + if rwire != grdata.data: + raise dns.exception.SyntaxError( + "compressed data in " + "generic syntax form " + "of known rdatatype" + ) + if rdata is None: + rdata = cls.from_text( + rdclass, rdtype, tok, origin, relativize, relativize_to + ) + token = tok.get_eol_as_token() + if token.comment is not None: + object.__setattr__(rdata, "rdcomment", token.comment) + return rdata + + +def from_wire_parser( + rdclass: Union[dns.rdataclass.RdataClass, str], + rdtype: Union[dns.rdatatype.RdataType, str], + parser: dns.wire.Parser, + origin: Optional[dns.name.Name] = None, +) -> Rdata: + """Build an rdata object from wire format + + This function attempts to dynamically load a class which + implements the specified rdata class and type. If there is no + class-and-type-specific implementation, the GenericRdata class + is used. + + Once a class is chosen, its from_wire() class method is called + with the parameters to this function. + + *rdclass*, a ``dns.rdataclass.RdataClass`` or ``str``, the rdataclass. + + *rdtype*, a ``dns.rdatatype.RdataType`` or ``str``, the rdatatype. + + *parser*, a ``dns.wire.Parser``, the parser, which should be + restricted to the rdata length. + + *origin*, a ``dns.name.Name`` (or ``None``). If not ``None``, + then names will be relativized to this origin. + + Returns an instance of the chosen Rdata subclass. 
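+    For example (a sketch; the name is hypothetical):
+    ``dns.rdata.from_text("IN", "MX", "10 mail.example.")`` returns an MX
+    rdata with ``preference`` 10 and ``exchange`` ``mail.example.``.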
+ """ + + rdclass = dns.rdataclass.RdataClass.make(rdclass) + rdtype = dns.rdatatype.RdataType.make(rdtype) + cls = get_rdata_class(rdclass, rdtype) + with dns.exception.ExceptionWrapper(dns.exception.FormError): + return cls.from_wire_parser(rdclass, rdtype, parser, origin) + + +def from_wire( + rdclass: Union[dns.rdataclass.RdataClass, str], + rdtype: Union[dns.rdatatype.RdataType, str], + wire: bytes, + current: int, + rdlen: int, + origin: Optional[dns.name.Name] = None, +) -> Rdata: + """Build an rdata object from wire format + + This function attempts to dynamically load a class which + implements the specified rdata class and type. If there is no + class-and-type-specific implementation, the GenericRdata class + is used. + + Once a class is chosen, its from_wire() class method is called + with the parameters to this function. + + *rdclass*, an ``int``, the rdataclass. + + *rdtype*, an ``int``, the rdatatype. + + *wire*, a ``bytes``, the wire-format message. + + *current*, an ``int``, the offset in wire of the beginning of + the rdata. + + *rdlen*, an ``int``, the length of the wire-format rdata + + *origin*, a ``dns.name.Name`` (or ``None``). If not ``None``, + then names will be relativized to this origin. + + Returns an instance of the chosen Rdata subclass. + """ + parser = dns.wire.Parser(wire, current) + with parser.restrict_to(rdlen): + return from_wire_parser(rdclass, rdtype, parser, origin) + + +class RdatatypeExists(dns.exception.DNSException): + """DNS rdatatype already exists.""" + + supp_kwargs = {"rdclass", "rdtype"} + fmt = ( + "The rdata type with class {rdclass:d} and rdtype {rdtype:d} " + + "already exists." + ) + + +def register_type( + implementation: Any, + rdtype: int, + rdtype_text: str, + is_singleton: bool = False, + rdclass: dns.rdataclass.RdataClass = dns.rdataclass.IN, +) -> None: + """Dynamically register a module to handle an rdatatype. + + *implementation*, a module implementing the type in the usual dnspython + way. + + *rdtype*, an ``int``, the rdatatype to register. + + *rdtype_text*, a ``str``, the textual form of the rdatatype. + + *is_singleton*, a ``bool``, indicating if the type is a singleton (i.e. + RRsets of the type can have only one member.) + + *rdclass*, the rdataclass of the type, or ``dns.rdataclass.ANY`` if + it applies to all classes. + """ + + rdtype = dns.rdatatype.RdataType.make(rdtype) + existing_cls = get_rdata_class(rdclass, rdtype) + if existing_cls != GenericRdata or dns.rdatatype.is_metatype(rdtype): + raise RdatatypeExists(rdclass=rdclass, rdtype=rdtype) + _rdata_classes[(rdclass, rdtype)] = getattr( + implementation, rdtype_text.replace("-", "_") + ) + dns.rdatatype.register_type(rdtype, rdtype_text, is_singleton) diff --git a/venv/Lib/site-packages/dns/rdataclass.py b/venv/Lib/site-packages/dns/rdataclass.py new file mode 100644 index 00000000..89b85a79 --- /dev/null +++ b/venv/Lib/site-packages/dns/rdataclass.py @@ -0,0 +1,118 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2001-2017 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. 
IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
+"""DNS Rdata Classes."""
+
+import dns.enum
+import dns.exception
+
+
+class RdataClass(dns.enum.IntEnum):
+    """DNS Rdata Class"""
+
+    RESERVED0 = 0
+    IN = 1
+    INTERNET = IN
+    CH = 3
+    CHAOS = CH
+    HS = 4
+    HESIOD = HS
+    NONE = 254
+    ANY = 255
+
+    @classmethod
+    def _maximum(cls):
+        return 65535
+
+    @classmethod
+    def _short_name(cls):
+        return "class"
+
+    @classmethod
+    def _prefix(cls):
+        return "CLASS"
+
+    @classmethod
+    def _unknown_exception_class(cls):
+        return UnknownRdataclass
+
+
+_metaclasses = {RdataClass.NONE, RdataClass.ANY}
+
+
+class UnknownRdataclass(dns.exception.DNSException):
+    """A DNS class is unknown."""
+
+
+def from_text(text: str) -> RdataClass:
+    """Convert text into a DNS rdata class value.
+
+    The input text can be a defined DNS RR class mnemonic or
+    instance of the DNS generic class syntax.
+
+    For example, "IN" and "CLASS1" will both result in a value of 1.
+
+    Raises ``dns.rdataclass.UnknownRdataclass`` if the class is unknown.
+
+    Raises ``ValueError`` if the rdata class value is not >= 0 and <= 65535.
+
+    Returns a ``dns.rdataclass.RdataClass``.
+    """
+
+    return RdataClass.from_text(text)
+
+
+def to_text(value: RdataClass) -> str:
+    """Convert a DNS rdata class value to text.
+
+    If the value has a known mnemonic, it will be used, otherwise the
+    DNS generic class syntax will be used.
+
+    Raises ``ValueError`` if the rdata class value is not >= 0 and <= 65535.
+
+    Returns a ``str``.
+    """
+
+    return RdataClass.to_text(value)
+
+
+def is_metaclass(rdclass: RdataClass) -> bool:
+    """True if the specified class is a metaclass.
+
+    The currently defined metaclasses are ANY and NONE.
+
+    *rdclass* is a ``dns.rdataclass.RdataClass``.
+    """
+
+    if rdclass in _metaclasses:
+        return True
+    return False
+
+
+### BEGIN generated RdataClass constants
+
+RESERVED0 = RdataClass.RESERVED0
+IN = RdataClass.IN
+INTERNET = RdataClass.INTERNET
+CH = RdataClass.CH
+CHAOS = RdataClass.CHAOS
+HS = RdataClass.HS
+HESIOD = RdataClass.HESIOD
+NONE = RdataClass.NONE
+ANY = RdataClass.ANY
+
+### END generated RdataClass constants
diff --git a/venv/Lib/site-packages/dns/rdataset.py b/venv/Lib/site-packages/dns/rdataset.py
new file mode 100644
index 00000000..8bff58d7
--- /dev/null
+++ b/venv/Lib/site-packages/dns/rdataset.py
@@ -0,0 +1,516 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2001-2017 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
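+
+# A minimal usage sketch (values are hypothetical) of the helpers defined
+# near the bottom of this module:
+#
+#     rds = dns.rdataset.from_text("IN", "A", 300, "10.0.0.1", "10.0.0.2")
+#     rds.add(dns.rdata.from_text("IN", "A", "10.0.0.3"), ttl=60)
+#     rds.ttl         # now 60; update_ttl() keeps the minimum TTL
+#     rds.to_text()   # one zone-file style line per rdata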
+ +"""DNS rdatasets (an rdataset is a set of rdatas of a given type and class)""" + +import io +import random +import struct +from typing import Any, Collection, Dict, List, Optional, Union, cast + +import dns.exception +import dns.immutable +import dns.name +import dns.rdata +import dns.rdataclass +import dns.rdatatype +import dns.renderer +import dns.set +import dns.ttl + +# define SimpleSet here for backwards compatibility +SimpleSet = dns.set.Set + + +class DifferingCovers(dns.exception.DNSException): + """An attempt was made to add a DNS SIG/RRSIG whose covered type + is not the same as that of the other rdatas in the rdataset.""" + + +class IncompatibleTypes(dns.exception.DNSException): + """An attempt was made to add DNS RR data of an incompatible type.""" + + +class Rdataset(dns.set.Set): + """A DNS rdataset.""" + + __slots__ = ["rdclass", "rdtype", "covers", "ttl"] + + def __init__( + self, + rdclass: dns.rdataclass.RdataClass, + rdtype: dns.rdatatype.RdataType, + covers: dns.rdatatype.RdataType = dns.rdatatype.NONE, + ttl: int = 0, + ): + """Create a new rdataset of the specified class and type. + + *rdclass*, a ``dns.rdataclass.RdataClass``, the rdataclass. + + *rdtype*, an ``dns.rdatatype.RdataType``, the rdatatype. + + *covers*, an ``dns.rdatatype.RdataType``, the covered rdatatype. + + *ttl*, an ``int``, the TTL. + """ + + super().__init__() + self.rdclass = rdclass + self.rdtype: dns.rdatatype.RdataType = rdtype + self.covers: dns.rdatatype.RdataType = covers + self.ttl = ttl + + def _clone(self): + obj = super()._clone() + obj.rdclass = self.rdclass + obj.rdtype = self.rdtype + obj.covers = self.covers + obj.ttl = self.ttl + return obj + + def update_ttl(self, ttl: int) -> None: + """Perform TTL minimization. + + Set the TTL of the rdataset to be the lesser of the set's current + TTL or the specified TTL. If the set contains no rdatas, set the TTL + to the specified TTL. + + *ttl*, an ``int`` or ``str``. + """ + ttl = dns.ttl.make(ttl) + if len(self) == 0: + self.ttl = ttl + elif ttl < self.ttl: + self.ttl = ttl + + def add( # pylint: disable=arguments-differ,arguments-renamed + self, rd: dns.rdata.Rdata, ttl: Optional[int] = None + ) -> None: + """Add the specified rdata to the rdataset. + + If the optional *ttl* parameter is supplied, then + ``self.update_ttl(ttl)`` will be called prior to adding the rdata. + + *rd*, a ``dns.rdata.Rdata``, the rdata + + *ttl*, an ``int``, the TTL. + + Raises ``dns.rdataset.IncompatibleTypes`` if the type and class + do not match the type and class of the rdataset. + + Raises ``dns.rdataset.DifferingCovers`` if the type is a signature + type and the covered type does not match that of the rdataset. + """ + + # + # If we're adding a signature, do some special handling to + # check that the signature covers the same type as the + # other rdatas in this rdataset. If this is the first rdata + # in the set, initialize the covers field. 
+        #
+        if self.rdclass != rd.rdclass or self.rdtype != rd.rdtype:
+            raise IncompatibleTypes
+        if ttl is not None:
+            self.update_ttl(ttl)
+        if self.rdtype == dns.rdatatype.RRSIG or self.rdtype == dns.rdatatype.SIG:
+            covers = rd.covers()
+            if len(self) == 0 and self.covers == dns.rdatatype.NONE:
+                self.covers = covers
+            elif self.covers != covers:
+                raise DifferingCovers
+        if dns.rdatatype.is_singleton(rd.rdtype) and len(self) > 0:
+            self.clear()
+        super().add(rd)
+
+    def union_update(self, other):
+        self.update_ttl(other.ttl)
+        super().union_update(other)
+
+    def intersection_update(self, other):
+        self.update_ttl(other.ttl)
+        super().intersection_update(other)
+
+    def update(self, other):
+        """Add all rdatas in other to self.
+
+        *other*, a ``dns.rdataset.Rdataset``, the rdataset from which
+        to update.
+        """
+
+        self.update_ttl(other.ttl)
+        super().update(other)
+
+    def _rdata_repr(self):
+        def maybe_truncate(s):
+            if len(s) > 100:
+                return s[:100] + "..."
+            return s
+
+        return "[%s]" % ", ".join("<%s>" % maybe_truncate(str(rr)) for rr in self)
+
+    def __repr__(self):
+        if self.covers == 0:
+            ctext = ""
+        else:
+            ctext = "(" + dns.rdatatype.to_text(self.covers) + ")"
+        return (
+            "<DNS "
+            + dns.rdataclass.to_text(self.rdclass)
+            + " "
+            + dns.rdatatype.to_text(self.rdtype)
+            + ctext
+            + " rdataset: "
+            + self._rdata_repr()
+            + ">"
+        )
+
+    def __str__(self):
+        return self.to_text()
+
+    def __eq__(self, other):
+        if not isinstance(other, Rdataset):
+            return False
+        if (
+            self.rdclass != other.rdclass
+            or self.rdtype != other.rdtype
+            or self.covers != other.covers
+        ):
+            return False
+        return super().__eq__(other)
+
+    def __ne__(self, other):
+        return not self.__eq__(other)
+
+    def to_text(
+        self,
+        name: Optional[dns.name.Name] = None,
+        origin: Optional[dns.name.Name] = None,
+        relativize: bool = True,
+        override_rdclass: Optional[dns.rdataclass.RdataClass] = None,
+        want_comments: bool = False,
+        **kw: Dict[str, Any],
+    ) -> str:
+        """Convert the rdataset into DNS zone file format.
+
+        See ``dns.name.Name.choose_relativity`` for more information
+        on how *origin* and *relativize* determine the way names
+        are emitted.
+
+        Any additional keyword arguments are passed on to the rdata
+        ``to_text()`` method.
+
+        *name*, a ``dns.name.Name``. If name is not ``None``, emit RRs with
+        *name* as the owner name.
+
+        *origin*, a ``dns.name.Name`` or ``None``, the origin for relative
+        names.
+
+        *relativize*, a ``bool``. If ``True``, names will be relativized
+        to *origin*.
+
+        *override_rdclass*, a ``dns.rdataclass.RdataClass`` or ``None``.
+        If not ``None``, use this class instead of the Rdataset's class.
+
+        *want_comments*, a ``bool``. If ``True``, emit comments for rdata
+        which have them. The default is ``False``.
+        """
+
+        if name is not None:
+            name = name.choose_relativity(origin, relativize)
+            ntext = str(name)
+            pad = " "
+        else:
+            ntext = ""
+            pad = ""
+        s = io.StringIO()
+        if override_rdclass is not None:
+            rdclass = override_rdclass
+        else:
+            rdclass = self.rdclass
+        if len(self) == 0:
+            #
+            # Empty rdatasets are used for the question section, and in
+            # some dynamic updates, so we don't need to print out the TTL
+            # (which is meaningless anyway).
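+            # The line written below is, e.g., "name IN A" (owner name,
+            # class, and type only; the name here is hypothetical).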
+ # + s.write( + "{}{}{} {}\n".format( + ntext, + pad, + dns.rdataclass.to_text(rdclass), + dns.rdatatype.to_text(self.rdtype), + ) + ) + else: + for rd in self: + extra = "" + if want_comments: + if rd.rdcomment: + extra = f" ;{rd.rdcomment}" + s.write( + "%s%s%d %s %s %s%s\n" + % ( + ntext, + pad, + self.ttl, + dns.rdataclass.to_text(rdclass), + dns.rdatatype.to_text(self.rdtype), + rd.to_text(origin=origin, relativize=relativize, **kw), + extra, + ) + ) + # + # We strip off the final \n for the caller's convenience in printing + # + return s.getvalue()[:-1] + + def to_wire( + self, + name: dns.name.Name, + file: Any, + compress: Optional[dns.name.CompressType] = None, + origin: Optional[dns.name.Name] = None, + override_rdclass: Optional[dns.rdataclass.RdataClass] = None, + want_shuffle: bool = True, + ) -> int: + """Convert the rdataset to wire format. + + *name*, a ``dns.name.Name`` is the owner name to use. + + *file* is the file where the name is emitted (typically a + BytesIO file). + + *compress*, a ``dict``, is the compression table to use. If + ``None`` (the default), names will not be compressed. + + *origin* is a ``dns.name.Name`` or ``None``. If the name is + relative and origin is not ``None``, then *origin* will be appended + to it. + + *override_rdclass*, an ``int``, is used as the class instead of the + class of the rdataset. This is useful when rendering rdatasets + associated with dynamic updates. + + *want_shuffle*, a ``bool``. If ``True``, then the order of the + Rdatas within the Rdataset will be shuffled before rendering. + + Returns an ``int``, the number of records emitted. + """ + + if override_rdclass is not None: + rdclass = override_rdclass + want_shuffle = False + else: + rdclass = self.rdclass + if len(self) == 0: + name.to_wire(file, compress, origin) + file.write(struct.pack("!HHIH", self.rdtype, rdclass, 0, 0)) + return 1 + else: + l: Union[Rdataset, List[dns.rdata.Rdata]] + if want_shuffle: + l = list(self) + random.shuffle(l) + else: + l = self + for rd in l: + name.to_wire(file, compress, origin) + file.write(struct.pack("!HHI", self.rdtype, rdclass, self.ttl)) + with dns.renderer.prefixed_length(file, 2): + rd.to_wire(file, compress, origin) + return len(self) + + def match( + self, + rdclass: dns.rdataclass.RdataClass, + rdtype: dns.rdatatype.RdataType, + covers: dns.rdatatype.RdataType, + ) -> bool: + """Returns ``True`` if this rdataset matches the specified class, + type, and covers. + """ + if self.rdclass == rdclass and self.rdtype == rdtype and self.covers == covers: + return True + return False + + def processing_order(self) -> List[dns.rdata.Rdata]: + """Return rdatas in a valid processing order according to the type's + specification. For example, MX records are in preference order from + lowest to highest preferences, with items of the same preference + shuffled. + + For types that do not define a processing order, the rdatas are + simply shuffled. 
+ """ + if len(self) == 0: + return [] + else: + return self[0]._processing_order(iter(self)) + + +@dns.immutable.immutable +class ImmutableRdataset(Rdataset): # lgtm[py/missing-equals] + """An immutable DNS rdataset.""" + + _clone_class = Rdataset + + def __init__(self, rdataset: Rdataset): + """Create an immutable rdataset from the specified rdataset.""" + + super().__init__( + rdataset.rdclass, rdataset.rdtype, rdataset.covers, rdataset.ttl + ) + self.items = dns.immutable.Dict(rdataset.items) + + def update_ttl(self, ttl): + raise TypeError("immutable") + + def add(self, rd, ttl=None): + raise TypeError("immutable") + + def union_update(self, other): + raise TypeError("immutable") + + def intersection_update(self, other): + raise TypeError("immutable") + + def update(self, other): + raise TypeError("immutable") + + def __delitem__(self, i): + raise TypeError("immutable") + + # lgtm complains about these not raising ArithmeticError, but there is + # precedent for overrides of these methods in other classes to raise + # TypeError, and it seems like the better exception. + + def __ior__(self, other): # lgtm[py/unexpected-raise-in-special-method] + raise TypeError("immutable") + + def __iand__(self, other): # lgtm[py/unexpected-raise-in-special-method] + raise TypeError("immutable") + + def __iadd__(self, other): # lgtm[py/unexpected-raise-in-special-method] + raise TypeError("immutable") + + def __isub__(self, other): # lgtm[py/unexpected-raise-in-special-method] + raise TypeError("immutable") + + def clear(self): + raise TypeError("immutable") + + def __copy__(self): + return ImmutableRdataset(super().copy()) + + def copy(self): + return ImmutableRdataset(super().copy()) + + def union(self, other): + return ImmutableRdataset(super().union(other)) + + def intersection(self, other): + return ImmutableRdataset(super().intersection(other)) + + def difference(self, other): + return ImmutableRdataset(super().difference(other)) + + def symmetric_difference(self, other): + return ImmutableRdataset(super().symmetric_difference(other)) + + +def from_text_list( + rdclass: Union[dns.rdataclass.RdataClass, str], + rdtype: Union[dns.rdatatype.RdataType, str], + ttl: int, + text_rdatas: Collection[str], + idna_codec: Optional[dns.name.IDNACodec] = None, + origin: Optional[dns.name.Name] = None, + relativize: bool = True, + relativize_to: Optional[dns.name.Name] = None, +) -> Rdataset: + """Create an rdataset with the specified class, type, and TTL, and with + the specified list of rdatas in text format. + + *idna_codec*, a ``dns.name.IDNACodec``, specifies the IDNA + encoder/decoder to use; if ``None``, the default IDNA 2003 + encoder/decoder is used. + + *origin*, a ``dns.name.Name`` (or ``None``), the + origin to use for relative names. + + *relativize*, a ``bool``. If true, name will be relativized. + + *relativize_to*, a ``dns.name.Name`` (or ``None``), the origin to use + when relativizing names. If not set, the *origin* value will be used. + + Returns a ``dns.rdataset.Rdataset`` object. 
+ """ + + rdclass = dns.rdataclass.RdataClass.make(rdclass) + rdtype = dns.rdatatype.RdataType.make(rdtype) + r = Rdataset(rdclass, rdtype) + r.update_ttl(ttl) + for t in text_rdatas: + rd = dns.rdata.from_text( + r.rdclass, r.rdtype, t, origin, relativize, relativize_to, idna_codec + ) + r.add(rd) + return r + + +def from_text( + rdclass: Union[dns.rdataclass.RdataClass, str], + rdtype: Union[dns.rdatatype.RdataType, str], + ttl: int, + *text_rdatas: Any, +) -> Rdataset: + """Create an rdataset with the specified class, type, and TTL, and with + the specified rdatas in text format. + + Returns a ``dns.rdataset.Rdataset`` object. + """ + + return from_text_list(rdclass, rdtype, ttl, cast(Collection[str], text_rdatas)) + + +def from_rdata_list(ttl: int, rdatas: Collection[dns.rdata.Rdata]) -> Rdataset: + """Create an rdataset with the specified TTL, and with + the specified list of rdata objects. + + Returns a ``dns.rdataset.Rdataset`` object. + """ + + if len(rdatas) == 0: + raise ValueError("rdata list must not be empty") + r = None + for rd in rdatas: + if r is None: + r = Rdataset(rd.rdclass, rd.rdtype) + r.update_ttl(ttl) + r.add(rd) + assert r is not None + return r + + +def from_rdata(ttl: int, *rdatas: Any) -> Rdataset: + """Create an rdataset with the specified TTL, and with + the specified rdata objects. + + Returns a ``dns.rdataset.Rdataset`` object. + """ + + return from_rdata_list(ttl, cast(Collection[dns.rdata.Rdata], rdatas)) diff --git a/venv/Lib/site-packages/dns/rdatatype.py b/venv/Lib/site-packages/dns/rdatatype.py new file mode 100644 index 00000000..e6c58186 --- /dev/null +++ b/venv/Lib/site-packages/dns/rdatatype.py @@ -0,0 +1,332 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2001-2017 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
+ +"""DNS Rdata Types.""" + +from typing import Dict + +import dns.enum +import dns.exception + + +class RdataType(dns.enum.IntEnum): + """DNS Rdata Type""" + + TYPE0 = 0 + NONE = 0 + A = 1 + NS = 2 + MD = 3 + MF = 4 + CNAME = 5 + SOA = 6 + MB = 7 + MG = 8 + MR = 9 + NULL = 10 + WKS = 11 + PTR = 12 + HINFO = 13 + MINFO = 14 + MX = 15 + TXT = 16 + RP = 17 + AFSDB = 18 + X25 = 19 + ISDN = 20 + RT = 21 + NSAP = 22 + NSAP_PTR = 23 + SIG = 24 + KEY = 25 + PX = 26 + GPOS = 27 + AAAA = 28 + LOC = 29 + NXT = 30 + SRV = 33 + NAPTR = 35 + KX = 36 + CERT = 37 + A6 = 38 + DNAME = 39 + OPT = 41 + APL = 42 + DS = 43 + SSHFP = 44 + IPSECKEY = 45 + RRSIG = 46 + NSEC = 47 + DNSKEY = 48 + DHCID = 49 + NSEC3 = 50 + NSEC3PARAM = 51 + TLSA = 52 + SMIMEA = 53 + HIP = 55 + NINFO = 56 + CDS = 59 + CDNSKEY = 60 + OPENPGPKEY = 61 + CSYNC = 62 + ZONEMD = 63 + SVCB = 64 + HTTPS = 65 + SPF = 99 + UNSPEC = 103 + NID = 104 + L32 = 105 + L64 = 106 + LP = 107 + EUI48 = 108 + EUI64 = 109 + TKEY = 249 + TSIG = 250 + IXFR = 251 + AXFR = 252 + MAILB = 253 + MAILA = 254 + ANY = 255 + URI = 256 + CAA = 257 + AVC = 258 + AMTRELAY = 260 + TA = 32768 + DLV = 32769 + + @classmethod + def _maximum(cls): + return 65535 + + @classmethod + def _short_name(cls): + return "type" + + @classmethod + def _prefix(cls): + return "TYPE" + + @classmethod + def _extra_from_text(cls, text): + if text.find("-") >= 0: + try: + return cls[text.replace("-", "_")] + except KeyError: + pass + return _registered_by_text.get(text) + + @classmethod + def _extra_to_text(cls, value, current_text): + if current_text is None: + return _registered_by_value.get(value) + if current_text.find("_") >= 0: + return current_text.replace("_", "-") + return current_text + + @classmethod + def _unknown_exception_class(cls): + return UnknownRdatatype + + +_registered_by_text: Dict[str, RdataType] = {} +_registered_by_value: Dict[RdataType, str] = {} + +_metatypes = {RdataType.OPT} + +_singletons = { + RdataType.SOA, + RdataType.NXT, + RdataType.DNAME, + RdataType.NSEC, + RdataType.CNAME, +} + + +class UnknownRdatatype(dns.exception.DNSException): + """DNS resource record type is unknown.""" + + +def from_text(text: str) -> RdataType: + """Convert text into a DNS rdata type value. + + The input text can be a defined DNS RR type mnemonic or + instance of the DNS generic type syntax. + + For example, "NS" and "TYPE2" will both result in a value of 2. + + Raises ``dns.rdatatype.UnknownRdatatype`` if the type is unknown. + + Raises ``ValueError`` if the rdata type value is not >= 0 and <= 65535. + + Returns a ``dns.rdatatype.RdataType``. + """ + + return RdataType.from_text(text) + + +def to_text(value: RdataType) -> str: + """Convert a DNS rdata type value to text. + + If the value has a known mnemonic, it will be used, otherwise the + DNS generic type syntax will be used. + + Raises ``ValueError`` if the rdata type value is not >= 0 and <= 65535. + + Returns a ``str``. + """ + + return RdataType.to_text(value) + + +def is_metatype(rdtype: RdataType) -> bool: + """True if the specified type is a metatype. + + *rdtype* is a ``dns.rdatatype.RdataType``. + + The currently defined metatypes are TKEY, TSIG, IXFR, AXFR, MAILA, + MAILB, ANY, and OPT. + + Returns a ``bool``. + """ + + return (256 > rdtype >= 128) or rdtype in _metatypes + + +def is_singleton(rdtype: RdataType) -> bool: + """Is the specified type a singleton type? + + Singleton types can only have a single rdata in an rdataset, or a single + RR in an RRset. 
+ + The currently defined singleton types are CNAME, DNAME, NSEC, NXT, and + SOA. + + *rdtype* is an ``int``. + + Returns a ``bool``. + """ + + if rdtype in _singletons: + return True + return False + + +# pylint: disable=redefined-outer-name +def register_type( + rdtype: RdataType, rdtype_text: str, is_singleton: bool = False +) -> None: + """Dynamically register an rdatatype. + + *rdtype*, a ``dns.rdatatype.RdataType``, the rdatatype to register. + + *rdtype_text*, a ``str``, the textual form of the rdatatype. + + *is_singleton*, a ``bool``, indicating if the type is a singleton (i.e. + RRsets of the type can have only one member.) + """ + + _registered_by_text[rdtype_text] = rdtype + _registered_by_value[rdtype] = rdtype_text + if is_singleton: + _singletons.add(rdtype) + + +### BEGIN generated RdataType constants + +TYPE0 = RdataType.TYPE0 +NONE = RdataType.NONE +A = RdataType.A +NS = RdataType.NS +MD = RdataType.MD +MF = RdataType.MF +CNAME = RdataType.CNAME +SOA = RdataType.SOA +MB = RdataType.MB +MG = RdataType.MG +MR = RdataType.MR +NULL = RdataType.NULL +WKS = RdataType.WKS +PTR = RdataType.PTR +HINFO = RdataType.HINFO +MINFO = RdataType.MINFO +MX = RdataType.MX +TXT = RdataType.TXT +RP = RdataType.RP +AFSDB = RdataType.AFSDB +X25 = RdataType.X25 +ISDN = RdataType.ISDN +RT = RdataType.RT +NSAP = RdataType.NSAP +NSAP_PTR = RdataType.NSAP_PTR +SIG = RdataType.SIG +KEY = RdataType.KEY +PX = RdataType.PX +GPOS = RdataType.GPOS +AAAA = RdataType.AAAA +LOC = RdataType.LOC +NXT = RdataType.NXT +SRV = RdataType.SRV +NAPTR = RdataType.NAPTR +KX = RdataType.KX +CERT = RdataType.CERT +A6 = RdataType.A6 +DNAME = RdataType.DNAME +OPT = RdataType.OPT +APL = RdataType.APL +DS = RdataType.DS +SSHFP = RdataType.SSHFP +IPSECKEY = RdataType.IPSECKEY +RRSIG = RdataType.RRSIG +NSEC = RdataType.NSEC +DNSKEY = RdataType.DNSKEY +DHCID = RdataType.DHCID +NSEC3 = RdataType.NSEC3 +NSEC3PARAM = RdataType.NSEC3PARAM +TLSA = RdataType.TLSA +SMIMEA = RdataType.SMIMEA +HIP = RdataType.HIP +NINFO = RdataType.NINFO +CDS = RdataType.CDS +CDNSKEY = RdataType.CDNSKEY +OPENPGPKEY = RdataType.OPENPGPKEY +CSYNC = RdataType.CSYNC +ZONEMD = RdataType.ZONEMD +SVCB = RdataType.SVCB +HTTPS = RdataType.HTTPS +SPF = RdataType.SPF +UNSPEC = RdataType.UNSPEC +NID = RdataType.NID +L32 = RdataType.L32 +L64 = RdataType.L64 +LP = RdataType.LP +EUI48 = RdataType.EUI48 +EUI64 = RdataType.EUI64 +TKEY = RdataType.TKEY +TSIG = RdataType.TSIG +IXFR = RdataType.IXFR +AXFR = RdataType.AXFR +MAILB = RdataType.MAILB +MAILA = RdataType.MAILA +ANY = RdataType.ANY +URI = RdataType.URI +CAA = RdataType.CAA +AVC = RdataType.AVC +AMTRELAY = RdataType.AMTRELAY +TA = RdataType.TA +DLV = RdataType.DLV + +### END generated RdataType constants diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/AFSDB.py b/venv/Lib/site-packages/dns/rdtypes/ANY/AFSDB.py new file mode 100644 index 00000000..06a3b970 --- /dev/null +++ b/venv/Lib/site-packages/dns/rdtypes/ANY/AFSDB.py @@ -0,0 +1,45 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. 
IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +import dns.immutable +import dns.rdtypes.mxbase + + +@dns.immutable.immutable +class AFSDB(dns.rdtypes.mxbase.UncompressedDowncasingMX): + """AFSDB record""" + + # Use the property mechanism to make "subtype" an alias for the + # "preference" attribute, and "hostname" an alias for the "exchange" + # attribute. + # + # This lets us inherit the UncompressedMX implementation but lets + # the caller use appropriate attribute names for the rdata type. + # + # We probably lose some performance vs. a cut-and-paste + # implementation, but this way we don't copy code, and that's + # good. + + @property + def subtype(self): + "the AFSDB subtype" + return self.preference + + @property + def hostname(self): + "the AFSDB hostname" + return self.exchange diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/AMTRELAY.py b/venv/Lib/site-packages/dns/rdtypes/ANY/AMTRELAY.py new file mode 100644 index 00000000..ed2b072b --- /dev/null +++ b/venv/Lib/site-packages/dns/rdtypes/ANY/AMTRELAY.py @@ -0,0 +1,91 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2006, 2007, 2009-2011 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
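A short sketch of the alias properties defined on AFSDB above; the record text ("1 bigbird.example.") is sample data.

import dns.rdata

rd = dns.rdata.from_text("IN", "AFSDB", "1 bigbird.example.")
# subtype/hostname are read-only aliases for the inherited MX attributes.
assert rd.subtype == rd.preference == 1
assert rd.hostname == rd.exchange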
+ +import struct + +import dns.exception +import dns.immutable +import dns.rdtypes.util + + +class Relay(dns.rdtypes.util.Gateway): + name = "AMTRELAY relay" + + @property + def relay(self): + return self.gateway + + +@dns.immutable.immutable +class AMTRELAY(dns.rdata.Rdata): + """AMTRELAY record""" + + # see: RFC 8777 + + __slots__ = ["precedence", "discovery_optional", "relay_type", "relay"] + + def __init__( + self, rdclass, rdtype, precedence, discovery_optional, relay_type, relay + ): + super().__init__(rdclass, rdtype) + relay = Relay(relay_type, relay) + self.precedence = self._as_uint8(precedence) + self.discovery_optional = self._as_bool(discovery_optional) + self.relay_type = relay.type + self.relay = relay.relay + + def to_text(self, origin=None, relativize=True, **kw): + relay = Relay(self.relay_type, self.relay).to_text(origin, relativize) + return "%d %d %d %s" % ( + self.precedence, + self.discovery_optional, + self.relay_type, + relay, + ) + + @classmethod + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): + precedence = tok.get_uint8() + discovery_optional = tok.get_uint8() + if discovery_optional > 1: + raise dns.exception.SyntaxError("expecting 0 or 1") + discovery_optional = bool(discovery_optional) + relay_type = tok.get_uint8() + if relay_type > 0x7F: + raise dns.exception.SyntaxError("expecting an integer <= 127") + relay = Relay.from_text(relay_type, tok, origin, relativize, relativize_to) + return cls( + rdclass, rdtype, precedence, discovery_optional, relay_type, relay.relay + ) + + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): + relay_type = self.relay_type | (self.discovery_optional << 7) + header = struct.pack("!BB", self.precedence, relay_type) + file.write(header) + Relay(self.relay_type, self.relay).to_wire(file, compress, origin, canonicalize) + + @classmethod + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + (precedence, relay_type) = parser.get_struct("!BB") + discovery_optional = bool(relay_type >> 7) + relay_type &= 0x7F + relay = Relay.from_wire_parser(relay_type, parser, origin) + return cls( + rdclass, rdtype, precedence, discovery_optional, relay_type, relay.relay + ) diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/AVC.py b/venv/Lib/site-packages/dns/rdtypes/ANY/AVC.py new file mode 100644 index 00000000..a27ae2d6 --- /dev/null +++ b/venv/Lib/site-packages/dns/rdtypes/ANY/AVC.py @@ -0,0 +1,26 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2016 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
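To make the AMTRELAY field layout above concrete, a hedged parsing sketch; the values are illustrative, with relay type 1 meaning an IPv4 relay per RFC 8777.

import dns.rdata

# precedence 10, D-bit 0, relay type 1 (IPv4), then the relay itself.
rd = dns.rdata.from_text("IN", "AMTRELAY", "10 0 1 203.0.113.15")
assert rd.precedence == 10
assert rd.discovery_optional is False
assert rd.relay_type == 1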
+ +import dns.immutable +import dns.rdtypes.txtbase + + +@dns.immutable.immutable +class AVC(dns.rdtypes.txtbase.TXTBase): + """AVC record""" + + # See: IANA dns parameters for AVC diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/CAA.py b/venv/Lib/site-packages/dns/rdtypes/ANY/CAA.py new file mode 100644 index 00000000..2e6a7e7e --- /dev/null +++ b/venv/Lib/site-packages/dns/rdtypes/ANY/CAA.py @@ -0,0 +1,71 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +import struct + +import dns.exception +import dns.immutable +import dns.rdata +import dns.tokenizer + + +@dns.immutable.immutable +class CAA(dns.rdata.Rdata): + """CAA (Certification Authority Authorization) record""" + + # see: RFC 6844 + + __slots__ = ["flags", "tag", "value"] + + def __init__(self, rdclass, rdtype, flags, tag, value): + super().__init__(rdclass, rdtype) + self.flags = self._as_uint8(flags) + self.tag = self._as_bytes(tag, True, 255) + if not tag.isalnum(): + raise ValueError("tag is not alphanumeric") + self.value = self._as_bytes(value) + + def to_text(self, origin=None, relativize=True, **kw): + return '%u %s "%s"' % ( + self.flags, + dns.rdata._escapify(self.tag), + dns.rdata._escapify(self.value), + ) + + @classmethod + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): + flags = tok.get_uint8() + tag = tok.get_string().encode() + value = tok.get_string().encode() + return cls(rdclass, rdtype, flags, tag, value) + + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): + file.write(struct.pack("!B", self.flags)) + l = len(self.tag) + assert l < 256 + file.write(struct.pack("!B", l)) + file.write(self.tag) + file.write(self.value) + + @classmethod + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + flags = parser.get_uint8() + tag = parser.get_counted_bytes() + value = parser.get_remaining() + return cls(rdclass, rdtype, flags, tag, value) diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/CDNSKEY.py b/venv/Lib/site-packages/dns/rdtypes/ANY/CDNSKEY.py new file mode 100644 index 00000000..b613409f --- /dev/null +++ b/venv/Lib/site-packages/dns/rdtypes/ANY/CDNSKEY.py @@ -0,0 +1,33 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2004-2007, 2009-2011 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. 
+# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +import dns.immutable +import dns.rdtypes.dnskeybase # lgtm[py/import-and-import-from] + +# pylint: disable=unused-import +from dns.rdtypes.dnskeybase import ( # noqa: F401 lgtm[py/unused-import] + REVOKE, + SEP, + ZONE, +) + +# pylint: enable=unused-import + + +@dns.immutable.immutable +class CDNSKEY(dns.rdtypes.dnskeybase.DNSKEYBase): + """CDNSKEY record""" diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/CDS.py b/venv/Lib/site-packages/dns/rdtypes/ANY/CDS.py new file mode 100644 index 00000000..8312b972 --- /dev/null +++ b/venv/Lib/site-packages/dns/rdtypes/ANY/CDS.py @@ -0,0 +1,29 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +import dns.immutable +import dns.rdtypes.dsbase + + +@dns.immutable.immutable +class CDS(dns.rdtypes.dsbase.DSBase): + """CDS record""" + + _digest_length_by_type = { + **dns.rdtypes.dsbase.DSBase._digest_length_by_type, + 0: 1, # delete, RFC 8078 Sec. 4 (including Errata ID 5049) + } diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/CERT.py b/venv/Lib/site-packages/dns/rdtypes/ANY/CERT.py new file mode 100644 index 00000000..f369cc85 --- /dev/null +++ b/venv/Lib/site-packages/dns/rdtypes/ANY/CERT.py @@ -0,0 +1,116 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
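The extra digest-type-0 entry in CDS above exists so the RFC 8078 "delete" record round-trips; a minimal sketch, assuming the standard delete form:

import dns.rdata

# "0 0 0 00" is the RFC 8078 delete CDS: digest type 0 with the one-byte digest 0x00.
rd = dns.rdata.from_text("IN", "CDS", "0 0 0 00")
assert (rd.key_tag, rd.algorithm, rd.digest_type) == (0, 0, 0)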
+ +import base64 +import struct + +import dns.dnssectypes +import dns.exception +import dns.immutable +import dns.rdata +import dns.tokenizer + +_ctype_by_value = { + 1: "PKIX", + 2: "SPKI", + 3: "PGP", + 4: "IPKIX", + 5: "ISPKI", + 6: "IPGP", + 7: "ACPKIX", + 8: "IACPKIX", + 253: "URI", + 254: "OID", +} + +_ctype_by_name = { + "PKIX": 1, + "SPKI": 2, + "PGP": 3, + "IPKIX": 4, + "ISPKI": 5, + "IPGP": 6, + "ACPKIX": 7, + "IACPKIX": 8, + "URI": 253, + "OID": 254, +} + + +def _ctype_from_text(what): + v = _ctype_by_name.get(what) + if v is not None: + return v + return int(what) + + +def _ctype_to_text(what): + v = _ctype_by_value.get(what) + if v is not None: + return v + return str(what) + + +@dns.immutable.immutable +class CERT(dns.rdata.Rdata): + """CERT record""" + + # see RFC 4398 + + __slots__ = ["certificate_type", "key_tag", "algorithm", "certificate"] + + def __init__( + self, rdclass, rdtype, certificate_type, key_tag, algorithm, certificate + ): + super().__init__(rdclass, rdtype) + self.certificate_type = self._as_uint16(certificate_type) + self.key_tag = self._as_uint16(key_tag) + self.algorithm = self._as_uint8(algorithm) + self.certificate = self._as_bytes(certificate) + + def to_text(self, origin=None, relativize=True, **kw): + certificate_type = _ctype_to_text(self.certificate_type) + return "%s %d %s %s" % ( + certificate_type, + self.key_tag, + dns.dnssectypes.Algorithm.to_text(self.algorithm), + dns.rdata._base64ify(self.certificate, **kw), + ) + + @classmethod + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): + certificate_type = _ctype_from_text(tok.get_string()) + key_tag = tok.get_uint16() + algorithm = dns.dnssectypes.Algorithm.from_text(tok.get_string()) + b64 = tok.concatenate_remaining_identifiers().encode() + certificate = base64.b64decode(b64) + return cls(rdclass, rdtype, certificate_type, key_tag, algorithm, certificate) + + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): + prefix = struct.pack( + "!HHB", self.certificate_type, self.key_tag, self.algorithm + ) + file.write(prefix) + file.write(self.certificate) + + @classmethod + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + (certificate_type, key_tag, algorithm) = parser.get_struct("!HHB") + certificate = parser.get_remaining() + return cls(rdclass, rdtype, certificate_type, key_tag, algorithm, certificate) diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/CNAME.py b/venv/Lib/site-packages/dns/rdtypes/ANY/CNAME.py new file mode 100644 index 00000000..665e407c --- /dev/null +++ b/venv/Lib/site-packages/dns/rdtypes/ANY/CNAME.py @@ -0,0 +1,28 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
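The certificate-type helpers above fall back to plain integers when no mnemonic applies; illustratively (these names are module-private, visible only inside this file, not public dnspython API):

# As seen from inside dns/rdtypes/ANY/CERT.py:
assert _ctype_from_text("PGP") == 3
assert _ctype_from_text("65280") == 65280  # no mnemonic, falls back to int()
assert _ctype_to_text(253) == "URI"
assert _ctype_to_text(9) == "9"  # unknown value renders as a decimal string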
+ +import dns.immutable +import dns.rdtypes.nsbase + + +@dns.immutable.immutable +class CNAME(dns.rdtypes.nsbase.NSBase): + """CNAME record + + Note: although CNAME is officially a singleton type, dnspython allows + non-singleton CNAME rdatasets because such sets have been commonly + used by BIND and other nameservers for load balancing.""" diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/CSYNC.py b/venv/Lib/site-packages/dns/rdtypes/ANY/CSYNC.py new file mode 100644 index 00000000..2f972f6e --- /dev/null +++ b/venv/Lib/site-packages/dns/rdtypes/ANY/CSYNC.py @@ -0,0 +1,68 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2004-2007, 2009-2011, 2016 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +import struct + +import dns.exception +import dns.immutable +import dns.name +import dns.rdata +import dns.rdatatype +import dns.rdtypes.util + + +@dns.immutable.immutable +class Bitmap(dns.rdtypes.util.Bitmap): + type_name = "CSYNC" + + +@dns.immutable.immutable +class CSYNC(dns.rdata.Rdata): + """CSYNC record""" + + __slots__ = ["serial", "flags", "windows"] + + def __init__(self, rdclass, rdtype, serial, flags, windows): + super().__init__(rdclass, rdtype) + self.serial = self._as_uint32(serial) + self.flags = self._as_uint16(flags) + if not isinstance(windows, Bitmap): + windows = Bitmap(windows) + self.windows = tuple(windows.windows) + + def to_text(self, origin=None, relativize=True, **kw): + text = Bitmap(self.windows).to_text() + return "%d %d%s" % (self.serial, self.flags, text) + + @classmethod + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): + serial = tok.get_uint32() + flags = tok.get_uint16() + bitmap = Bitmap.from_text(tok) + return cls(rdclass, rdtype, serial, flags, bitmap) + + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): + file.write(struct.pack("!IH", self.serial, self.flags)) + Bitmap(self.windows).to_wire(file) + + @classmethod + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + (serial, flags) = parser.get_struct("!IH") + bitmap = Bitmap.from_wire_parser(parser) + return cls(rdclass, rdtype, serial, flags, bitmap) diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/DLV.py b/venv/Lib/site-packages/dns/rdtypes/ANY/DLV.py new file mode 100644 index 00000000..6c134f18 --- /dev/null +++ b/venv/Lib/site-packages/dns/rdtypes/ANY/DLV.py @@ -0,0 +1,24 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2009-2011 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. 
+# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +import dns.immutable +import dns.rdtypes.dsbase + + +@dns.immutable.immutable +class DLV(dns.rdtypes.dsbase.DSBase): + """DLV record""" diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/DNAME.py b/venv/Lib/site-packages/dns/rdtypes/ANY/DNAME.py new file mode 100644 index 00000000..bbf9186c --- /dev/null +++ b/venv/Lib/site-packages/dns/rdtypes/ANY/DNAME.py @@ -0,0 +1,27 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +import dns.immutable +import dns.rdtypes.nsbase + + +@dns.immutable.immutable +class DNAME(dns.rdtypes.nsbase.UncompressedNS): + """DNAME record""" + + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): + self.target.to_wire(file, None, origin, canonicalize) diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/DNSKEY.py b/venv/Lib/site-packages/dns/rdtypes/ANY/DNSKEY.py new file mode 100644 index 00000000..6d961a9f --- /dev/null +++ b/venv/Lib/site-packages/dns/rdtypes/ANY/DNSKEY.py @@ -0,0 +1,33 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2004-2007, 2009-2011 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
+ +import dns.immutable +import dns.rdtypes.dnskeybase # lgtm[py/import-and-import-from] + +# pylint: disable=unused-import +from dns.rdtypes.dnskeybase import ( # noqa: F401 lgtm[py/unused-import] + REVOKE, + SEP, + ZONE, +) + +# pylint: enable=unused-import + + +@dns.immutable.immutable +class DNSKEY(dns.rdtypes.dnskeybase.DNSKEYBase): + """DNSKEY record""" diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/DS.py b/venv/Lib/site-packages/dns/rdtypes/ANY/DS.py new file mode 100644 index 00000000..58b3108d --- /dev/null +++ b/venv/Lib/site-packages/dns/rdtypes/ANY/DS.py @@ -0,0 +1,24 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +import dns.immutable +import dns.rdtypes.dsbase + + +@dns.immutable.immutable +class DS(dns.rdtypes.dsbase.DSBase): + """DS record""" diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/EUI48.py b/venv/Lib/site-packages/dns/rdtypes/ANY/EUI48.py new file mode 100644 index 00000000..c843be50 --- /dev/null +++ b/venv/Lib/site-packages/dns/rdtypes/ANY/EUI48.py @@ -0,0 +1,30 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2015 Red Hat, Inc. +# Author: Petr Spacek +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED 'AS IS' AND RED HAT DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +import dns.immutable +import dns.rdtypes.euibase + + +@dns.immutable.immutable +class EUI48(dns.rdtypes.euibase.EUIBase): + """EUI48 record""" + + # see: rfc7043.txt + + byte_len = 6 # 0123456789ab (in hex) + text_len = byte_len * 3 - 1 # 01-23-45-67-89-ab diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/EUI64.py b/venv/Lib/site-packages/dns/rdtypes/ANY/EUI64.py new file mode 100644 index 00000000..f6d7e257 --- /dev/null +++ b/venv/Lib/site-packages/dns/rdtypes/ANY/EUI64.py @@ -0,0 +1,30 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2015 Red Hat, Inc. 
+# Author: Petr Spacek +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED 'AS IS' AND RED HAT DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +import dns.immutable +import dns.rdtypes.euibase + + +@dns.immutable.immutable +class EUI64(dns.rdtypes.euibase.EUIBase): + """EUI64 record""" + + # see: rfc7043.txt + + byte_len = 8 # 0123456789abcdef (in hex) + text_len = byte_len * 3 - 1 # 01-23-45-67-89-ab-cd-ef diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/GPOS.py b/venv/Lib/site-packages/dns/rdtypes/ANY/GPOS.py new file mode 100644 index 00000000..312338f9 --- /dev/null +++ b/venv/Lib/site-packages/dns/rdtypes/ANY/GPOS.py @@ -0,0 +1,125 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
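A small sketch of the EUI text format implied by byte_len and text_len above; the address is a sample value.

import dns.rdata

# Six hex octets joined by "-" (text_len = 6 * 3 - 1 = 17 characters).
rd = dns.rdata.from_text("IN", "EUI48", "01-23-45-67-89-ab")
assert rd.to_text() == "01-23-45-67-89-ab"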
+ +import struct + +import dns.exception +import dns.immutable +import dns.rdata +import dns.tokenizer + + +def _validate_float_string(what): + if len(what) == 0: + raise dns.exception.FormError + if what[0] == b"-"[0] or what[0] == b"+"[0]: + what = what[1:] + if what.isdigit(): + return + try: + (left, right) = what.split(b".") + except ValueError: + raise dns.exception.FormError + if left == b"" and right == b"": + raise dns.exception.FormError + if not left == b"" and not left.decode().isdigit(): + raise dns.exception.FormError + if not right == b"" and not right.decode().isdigit(): + raise dns.exception.FormError + + +@dns.immutable.immutable +class GPOS(dns.rdata.Rdata): + """GPOS record""" + + # see: RFC 1712 + + __slots__ = ["latitude", "longitude", "altitude"] + + def __init__(self, rdclass, rdtype, latitude, longitude, altitude): + super().__init__(rdclass, rdtype) + if isinstance(latitude, float) or isinstance(latitude, int): + latitude = str(latitude) + if isinstance(longitude, float) or isinstance(longitude, int): + longitude = str(longitude) + if isinstance(altitude, float) or isinstance(altitude, int): + altitude = str(altitude) + latitude = self._as_bytes(latitude, True, 255) + longitude = self._as_bytes(longitude, True, 255) + altitude = self._as_bytes(altitude, True, 255) + _validate_float_string(latitude) + _validate_float_string(longitude) + _validate_float_string(altitude) + self.latitude = latitude + self.longitude = longitude + self.altitude = altitude + flat = self.float_latitude + if flat < -90.0 or flat > 90.0: + raise dns.exception.FormError("bad latitude") + flong = self.float_longitude + if flong < -180.0 or flong > 180.0: + raise dns.exception.FormError("bad longitude") + + def to_text(self, origin=None, relativize=True, **kw): + return "{} {} {}".format( + self.latitude.decode(), self.longitude.decode(), self.altitude.decode() + ) + + @classmethod + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): + latitude = tok.get_string() + longitude = tok.get_string() + altitude = tok.get_string() + return cls(rdclass, rdtype, latitude, longitude, altitude) + + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): + l = len(self.latitude) + assert l < 256 + file.write(struct.pack("!B", l)) + file.write(self.latitude) + l = len(self.longitude) + assert l < 256 + file.write(struct.pack("!B", l)) + file.write(self.longitude) + l = len(self.altitude) + assert l < 256 + file.write(struct.pack("!B", l)) + file.write(self.altitude) + + @classmethod + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + latitude = parser.get_counted_bytes() + longitude = parser.get_counted_bytes() + altitude = parser.get_counted_bytes() + return cls(rdclass, rdtype, latitude, longitude, altitude) + + @property + def float_latitude(self): + "latitude as a floating point value" + return float(self.latitude) + + @property + def float_longitude(self): + "longitude as a floating point value" + return float(self.longitude) + + @property + def float_altitude(self): + "altitude as a floating point value" + return float(self.altitude) diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/HINFO.py b/venv/Lib/site-packages/dns/rdtypes/ANY/HINFO.py new file mode 100644 index 00000000..c2c45de0 --- /dev/null +++ b/venv/Lib/site-packages/dns/rdtypes/ANY/HINFO.py @@ -0,0 +1,66 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. 
+# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +import struct + +import dns.exception +import dns.immutable +import dns.rdata +import dns.tokenizer + + +@dns.immutable.immutable +class HINFO(dns.rdata.Rdata): + """HINFO record""" + + # see: RFC 1035 + + __slots__ = ["cpu", "os"] + + def __init__(self, rdclass, rdtype, cpu, os): + super().__init__(rdclass, rdtype) + self.cpu = self._as_bytes(cpu, True, 255) + self.os = self._as_bytes(os, True, 255) + + def to_text(self, origin=None, relativize=True, **kw): + return '"{}" "{}"'.format( + dns.rdata._escapify(self.cpu), dns.rdata._escapify(self.os) + ) + + @classmethod + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): + cpu = tok.get_string(max_length=255) + os = tok.get_string(max_length=255) + return cls(rdclass, rdtype, cpu, os) + + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): + l = len(self.cpu) + assert l < 256 + file.write(struct.pack("!B", l)) + file.write(self.cpu) + l = len(self.os) + assert l < 256 + file.write(struct.pack("!B", l)) + file.write(self.os) + + @classmethod + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + cpu = parser.get_counted_bytes() + os = parser.get_counted_bytes() + return cls(rdclass, rdtype, cpu, os) diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/HIP.py b/venv/Lib/site-packages/dns/rdtypes/ANY/HIP.py new file mode 100644 index 00000000..91669139 --- /dev/null +++ b/venv/Lib/site-packages/dns/rdtypes/ANY/HIP.py @@ -0,0 +1,85 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2010, 2011 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
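A minimal HINFO round-trip sketch using the class above; the CPU and OS strings are sample data.

import dns.rdata

# Both fields are single <character-string>s of at most 255 bytes.
rd = dns.rdata.from_text("IN", "HINFO", '"AMD64" "Linux"')
assert rd.to_text() == '"AMD64" "Linux"'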
+ +import base64 +import binascii +import struct + +import dns.exception +import dns.immutable +import dns.rdata +import dns.rdatatype + + +@dns.immutable.immutable +class HIP(dns.rdata.Rdata): + """HIP record""" + + # see: RFC 5205 + + __slots__ = ["hit", "algorithm", "key", "servers"] + + def __init__(self, rdclass, rdtype, hit, algorithm, key, servers): + super().__init__(rdclass, rdtype) + self.hit = self._as_bytes(hit, True, 255) + self.algorithm = self._as_uint8(algorithm) + self.key = self._as_bytes(key, True) + self.servers = self._as_tuple(servers, self._as_name) + + def to_text(self, origin=None, relativize=True, **kw): + hit = binascii.hexlify(self.hit).decode() + key = base64.b64encode(self.key).replace(b"\n", b"").decode() + text = "" + servers = [] + for server in self.servers: + servers.append(server.choose_relativity(origin, relativize)) + if len(servers) > 0: + text += " " + " ".join((x.to_unicode() for x in servers)) + return "%u %s %s%s" % (self.algorithm, hit, key, text) + + @classmethod + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): + algorithm = tok.get_uint8() + hit = binascii.unhexlify(tok.get_string().encode()) + key = base64.b64decode(tok.get_string().encode()) + servers = [] + for token in tok.get_remaining(): + server = tok.as_name(token, origin, relativize, relativize_to) + servers.append(server) + return cls(rdclass, rdtype, hit, algorithm, key, servers) + + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): + lh = len(self.hit) + lk = len(self.key) + file.write(struct.pack("!BBH", lh, self.algorithm, lk)) + file.write(self.hit) + file.write(self.key) + for server in self.servers: + server.to_wire(file, None, origin, False) + + @classmethod + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + (lh, algorithm, lk) = parser.get_struct("!BBH") + hit = parser.get_bytes(lh) + key = parser.get_bytes(lk) + servers = [] + while parser.remaining() > 0: + server = parser.get_name(origin) + servers.append(server) + return cls(rdclass, rdtype, hit, algorithm, key, servers) diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/ISDN.py b/venv/Lib/site-packages/dns/rdtypes/ANY/ISDN.py new file mode 100644 index 00000000..fb01eab3 --- /dev/null +++ b/venv/Lib/site-packages/dns/rdtypes/ANY/ISDN.py @@ -0,0 +1,77 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
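A hedged sketch of the HIP text layout parsed above (algorithm, hex HIT, base64 key, optional rendezvous servers); the HIT and key here are placeholders, not a real host identity.

import dns.rdata

rd = dns.rdata.from_text(
    "IN", "HIP", "2 200100107b1a74df365639cc39f1d578 AwEAAQ=="
)
assert rd.algorithm == 2
assert len(rd.hit) == 16  # the HIT is stored as raw bytes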
+ +import struct + +import dns.exception +import dns.immutable +import dns.rdata +import dns.tokenizer + + +@dns.immutable.immutable +class ISDN(dns.rdata.Rdata): + """ISDN record""" + + # see: RFC 1183 + + __slots__ = ["address", "subaddress"] + + def __init__(self, rdclass, rdtype, address, subaddress): + super().__init__(rdclass, rdtype) + self.address = self._as_bytes(address, True, 255) + self.subaddress = self._as_bytes(subaddress, True, 255) + + def to_text(self, origin=None, relativize=True, **kw): + if self.subaddress: + return '"{}" "{}"'.format( + dns.rdata._escapify(self.address), dns.rdata._escapify(self.subaddress) + ) + else: + return '"%s"' % dns.rdata._escapify(self.address) + + @classmethod + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): + address = tok.get_string() + tokens = tok.get_remaining(max_tokens=1) + if len(tokens) >= 1: + subaddress = tokens[0].unescape().value + else: + subaddress = "" + return cls(rdclass, rdtype, address, subaddress) + + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): + l = len(self.address) + assert l < 256 + file.write(struct.pack("!B", l)) + file.write(self.address) + l = len(self.subaddress) + if l > 0: + assert l < 256 + file.write(struct.pack("!B", l)) + file.write(self.subaddress) + + @classmethod + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + address = parser.get_counted_bytes() + if parser.remaining() > 0: + subaddress = parser.get_counted_bytes() + else: + subaddress = b"" + return cls(rdclass, rdtype, address, subaddress) diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/L32.py b/venv/Lib/site-packages/dns/rdtypes/ANY/L32.py new file mode 100644 index 00000000..09804c2d --- /dev/null +++ b/venv/Lib/site-packages/dns/rdtypes/ANY/L32.py @@ -0,0 +1,42 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +import struct + +import dns.immutable +import dns.ipv4 +import dns.rdata + + +@dns.immutable.immutable +class L32(dns.rdata.Rdata): + """L32 record""" + + # see: rfc6742.txt + + __slots__ = ["preference", "locator32"] + + def __init__(self, rdclass, rdtype, preference, locator32): + super().__init__(rdclass, rdtype) + self.preference = self._as_uint16(preference) + self.locator32 = self._as_ipv4_address(locator32) + + def to_text(self, origin=None, relativize=True, **kw): + return f"{self.preference} {self.locator32}" + + @classmethod + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): + preference = tok.get_uint16() + locator32 = tok.get_identifier() + return cls(rdclass, rdtype, preference, locator32) + + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): + file.write(struct.pack("!H", self.preference)) + file.write(dns.ipv4.inet_aton(self.locator32)) + + @classmethod + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + preference = parser.get_uint16() + locator32 = parser.get_remaining() + return cls(rdclass, rdtype, preference, locator32) diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/L64.py b/venv/Lib/site-packages/dns/rdtypes/ANY/L64.py new file mode 100644 index 00000000..fb76808e --- /dev/null +++ b/venv/Lib/site-packages/dns/rdtypes/ANY/L64.py @@ -0,0 +1,47 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +import struct + +import dns.immutable +import dns.rdtypes.util + + +@dns.immutable.immutable +class L64(dns.rdata.Rdata): + """L64 record""" + + # see: rfc6742.txt + + __slots__ =
["preference", "locator64"] + + def __init__(self, rdclass, rdtype, preference, locator64): + super().__init__(rdclass, rdtype) + self.preference = self._as_uint16(preference) + if isinstance(locator64, bytes): + if len(locator64) != 8: + raise ValueError("invalid locator64") + self.locator64 = dns.rdata._hexify(locator64, 4, b":") + else: + dns.rdtypes.util.parse_formatted_hex(locator64, 4, 4, ":") + self.locator64 = locator64 + + def to_text(self, origin=None, relativize=True, **kw): + return f"{self.preference} {self.locator64}" + + @classmethod + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): + preference = tok.get_uint16() + locator64 = tok.get_identifier() + return cls(rdclass, rdtype, preference, locator64) + + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): + file.write(struct.pack("!H", self.preference)) + file.write(dns.rdtypes.util.parse_formatted_hex(self.locator64, 4, 4, ":")) + + @classmethod + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + preference = parser.get_uint16() + locator64 = parser.get_remaining() + return cls(rdclass, rdtype, preference, locator64) diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/LOC.py b/venv/Lib/site-packages/dns/rdtypes/ANY/LOC.py new file mode 100644 index 00000000..a36a2c10 --- /dev/null +++ b/venv/Lib/site-packages/dns/rdtypes/ANY/LOC.py @@ -0,0 +1,354 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
+ +import struct + +import dns.exception +import dns.immutable +import dns.rdata + +_pows = tuple(10**i for i in range(0, 11)) + +# default values are in centimeters +_default_size = 100.0 +_default_hprec = 1000000.0 +_default_vprec = 1000.0 + +# for use by from_wire() +_MAX_LATITUDE = 0x80000000 + 90 * 3600000 +_MIN_LATITUDE = 0x80000000 - 90 * 3600000 +_MAX_LONGITUDE = 0x80000000 + 180 * 3600000 +_MIN_LONGITUDE = 0x80000000 - 180 * 3600000 + + +def _exponent_of(what, desc): + if what == 0: + return 0 + exp = None + for i, pow in enumerate(_pows): + if what < pow: + exp = i - 1 + break + if exp is None or exp < 0: + raise dns.exception.SyntaxError("%s value out of bounds" % desc) + return exp + + +def _float_to_tuple(what): + if what < 0: + sign = -1 + what *= -1 + else: + sign = 1 + what = round(what * 3600000) + degrees = int(what // 3600000) + what -= degrees * 3600000 + minutes = int(what // 60000) + what -= minutes * 60000 + seconds = int(what // 1000) + what -= int(seconds * 1000) + what = int(what) + return (degrees, minutes, seconds, what, sign) + + +def _tuple_to_float(what): + value = float(what[0]) + value += float(what[1]) / 60.0 + value += float(what[2]) / 3600.0 + value += float(what[3]) / 3600000.0 + return float(what[4]) * value + + +def _encode_size(what, desc): + what = int(what) + exponent = _exponent_of(what, desc) & 0xF + base = what // pow(10, exponent) & 0xF + return base * 16 + exponent + + +def _decode_size(what, desc): + exponent = what & 0x0F + if exponent > 9: + raise dns.exception.FormError("bad %s exponent" % desc) + base = (what & 0xF0) >> 4 + if base > 9: + raise dns.exception.FormError("bad %s base" % desc) + return base * pow(10, exponent) + + +def _check_coordinate_list(value, low, high): + if value[0] < low or value[0] > high: + raise ValueError(f"not in range [{low}, {high}]") + if value[1] < 0 or value[1] > 59: + raise ValueError("bad minutes value") + if value[2] < 0 or value[2] > 59: + raise ValueError("bad seconds value") + if value[3] < 0 or value[3] > 999: + raise ValueError("bad milliseconds value") + if value[4] != 1 and value[4] != -1: + raise ValueError("bad hemisphere value") + + +@dns.immutable.immutable +class LOC(dns.rdata.Rdata): + """LOC record""" + + # see: RFC 1876 + + __slots__ = [ + "latitude", + "longitude", + "altitude", + "size", + "horizontal_precision", + "vertical_precision", + ] + + def __init__( + self, + rdclass, + rdtype, + latitude, + longitude, + altitude, + size=_default_size, + hprec=_default_hprec, + vprec=_default_vprec, + ): + """Initialize a LOC record instance. + + The parameters I{latitude} and I{longitude} may be either a 4-tuple + of integers specifying (degrees, minutes, seconds, milliseconds), + or they may be floating point values specifying the number of + degrees. The other parameters are floats. 
Size, horizontal precision, + and vertical precision are specified in centimeters.""" + + super().__init__(rdclass, rdtype) + if isinstance(latitude, int): + latitude = float(latitude) + if isinstance(latitude, float): + latitude = _float_to_tuple(latitude) + _check_coordinate_list(latitude, -90, 90) + self.latitude = tuple(latitude) + if isinstance(longitude, int): + longitude = float(longitude) + if isinstance(longitude, float): + longitude = _float_to_tuple(longitude) + _check_coordinate_list(longitude, -180, 180) + self.longitude = tuple(longitude) + self.altitude = float(altitude) + self.size = float(size) + self.horizontal_precision = float(hprec) + self.vertical_precision = float(vprec) + + def to_text(self, origin=None, relativize=True, **kw): + if self.latitude[4] > 0: + lat_hemisphere = "N" + else: + lat_hemisphere = "S" + if self.longitude[4] > 0: + long_hemisphere = "E" + else: + long_hemisphere = "W" + text = "%d %d %d.%03d %s %d %d %d.%03d %s %0.2fm" % ( + self.latitude[0], + self.latitude[1], + self.latitude[2], + self.latitude[3], + lat_hemisphere, + self.longitude[0], + self.longitude[1], + self.longitude[2], + self.longitude[3], + long_hemisphere, + self.altitude / 100.0, + ) + + # do not print default values + if ( + self.size != _default_size + or self.horizontal_precision != _default_hprec + or self.vertical_precision != _default_vprec + ): + text += " {:0.2f}m {:0.2f}m {:0.2f}m".format( + self.size / 100.0, + self.horizontal_precision / 100.0, + self.vertical_precision / 100.0, + ) + return text + + @classmethod + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): + latitude = [0, 0, 0, 0, 1] + longitude = [0, 0, 0, 0, 1] + size = _default_size + hprec = _default_hprec + vprec = _default_vprec + + latitude[0] = tok.get_int() + t = tok.get_string() + if t.isdigit(): + latitude[1] = int(t) + t = tok.get_string() + if "." in t: + (seconds, milliseconds) = t.split(".") + if not seconds.isdigit(): + raise dns.exception.SyntaxError("bad latitude seconds value") + latitude[2] = int(seconds) + l = len(milliseconds) + if l == 0 or l > 3 or not milliseconds.isdigit(): + raise dns.exception.SyntaxError("bad latitude milliseconds value") + if l == 1: + m = 100 + elif l == 2: + m = 10 + else: + m = 1 + latitude[3] = m * int(milliseconds) + t = tok.get_string() + elif t.isdigit(): + latitude[2] = int(t) + t = tok.get_string() + if t == "S": + latitude[4] = -1 + elif t != "N": + raise dns.exception.SyntaxError("bad latitude hemisphere value") + + longitude[0] = tok.get_int() + t = tok.get_string() + if t.isdigit(): + longitude[1] = int(t) + t = tok.get_string() + if "." 
in t: + (seconds, milliseconds) = t.split(".") + if not seconds.isdigit(): + raise dns.exception.SyntaxError("bad longitude seconds value") + longitude[2] = int(seconds) + l = len(milliseconds) + if l == 0 or l > 3 or not milliseconds.isdigit(): + raise dns.exception.SyntaxError("bad longitude milliseconds value") + if l == 1: + m = 100 + elif l == 2: + m = 10 + else: + m = 1 + longitude[3] = m * int(milliseconds) + t = tok.get_string() + elif t.isdigit(): + longitude[2] = int(t) + t = tok.get_string() + if t == "W": + longitude[4] = -1 + elif t != "E": + raise dns.exception.SyntaxError("bad longitude hemisphere value") + + t = tok.get_string() + if t[-1] == "m": + t = t[0:-1] + altitude = float(t) * 100.0 # m -> cm + + tokens = tok.get_remaining(max_tokens=3) + if len(tokens) >= 1: + value = tokens[0].unescape().value + if value[-1] == "m": + value = value[0:-1] + size = float(value) * 100.0 # m -> cm + if len(tokens) >= 2: + value = tokens[1].unescape().value + if value[-1] == "m": + value = value[0:-1] + hprec = float(value) * 100.0 # m -> cm + if len(tokens) >= 3: + value = tokens[2].unescape().value + if value[-1] == "m": + value = value[0:-1] + vprec = float(value) * 100.0 # m -> cm + + # Try encoding these now so we raise if they are bad + _encode_size(size, "size") + _encode_size(hprec, "horizontal precision") + _encode_size(vprec, "vertical precision") + + return cls(rdclass, rdtype, latitude, longitude, altitude, size, hprec, vprec) + + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): + milliseconds = ( + self.latitude[0] * 3600000 + + self.latitude[1] * 60000 + + self.latitude[2] * 1000 + + self.latitude[3] + ) * self.latitude[4] + latitude = 0x80000000 + milliseconds + milliseconds = ( + self.longitude[0] * 3600000 + + self.longitude[1] * 60000 + + self.longitude[2] * 1000 + + self.longitude[3] + ) * self.longitude[4] + longitude = 0x80000000 + milliseconds + altitude = int(self.altitude) + 10000000 + size = _encode_size(self.size, "size") + hprec = _encode_size(self.horizontal_precision, "horizontal precision") + vprec = _encode_size(self.vertical_precision, "vertical precision") + wire = struct.pack( + "!BBBBIII", 0, size, hprec, vprec, latitude, longitude, altitude + ) + file.write(wire) + + @classmethod + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + ( + version, + size, + hprec, + vprec, + latitude, + longitude, + altitude, + ) = parser.get_struct("!BBBBIII") + if version != 0: + raise dns.exception.FormError("LOC version not zero") + if latitude < _MIN_LATITUDE or latitude > _MAX_LATITUDE: + raise dns.exception.FormError("bad latitude") + if latitude > 0x80000000: + latitude = (latitude - 0x80000000) / 3600000 + else: + latitude = -1 * (0x80000000 - latitude) / 3600000 + if longitude < _MIN_LONGITUDE or longitude > _MAX_LONGITUDE: + raise dns.exception.FormError("bad longitude") + if longitude > 0x80000000: + longitude = (longitude - 0x80000000) / 3600000 + else: + longitude = -1 * (0x80000000 - longitude) / 3600000 + altitude = float(altitude) - 10000000.0 + size = _decode_size(size, "size") + hprec = _decode_size(hprec, "horizontal precision") + vprec = _decode_size(vprec, "vertical precision") + return cls(rdclass, rdtype, latitude, longitude, altitude, size, hprec, vprec) + + @property + def float_latitude(self): + "latitude as a floating point value" + return _tuple_to_float(self.latitude) + + @property + def float_longitude(self): + "longitude as a floating point value" + return _tuple_to_float(self.longitude) diff --git 
a/venv/Lib/site-packages/dns/rdtypes/ANY/LP.py b/venv/Lib/site-packages/dns/rdtypes/ANY/LP.py new file mode 100644 index 00000000..312663f1 --- /dev/null +++ b/venv/Lib/site-packages/dns/rdtypes/ANY/LP.py @@ -0,0 +1,42 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +import struct + +import dns.immutable +import dns.rdata + + +@dns.immutable.immutable +class LP(dns.rdata.Rdata): + """LP record""" + + # see: rfc6742.txt + + __slots__ = ["preference", "fqdn"] + + def __init__(self, rdclass, rdtype, preference, fqdn): + super().__init__(rdclass, rdtype) + self.preference = self._as_uint16(preference) + self.fqdn = self._as_name(fqdn) + + def to_text(self, origin=None, relativize=True, **kw): + fqdn = self.fqdn.choose_relativity(origin, relativize) + return "%d %s" % (self.preference, fqdn) + + @classmethod + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): + preference = tok.get_uint16() + fqdn = tok.get_name(origin, relativize, relativize_to) + return cls(rdclass, rdtype, preference, fqdn) + + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): + file.write(struct.pack("!H", self.preference)) + self.fqdn.to_wire(file, compress, origin, canonicalize) + + @classmethod + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + preference = parser.get_uint16() + fqdn = parser.get_name(origin) + return cls(rdclass, rdtype, preference, fqdn) diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/MX.py b/venv/Lib/site-packages/dns/rdtypes/ANY/MX.py new file mode 100644 index 00000000..0c300c5a --- /dev/null +++ b/venv/Lib/site-packages/dns/rdtypes/ANY/MX.py @@ -0,0 +1,24 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
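Minimal parsing sketches for the LOC and LP classes above; the coordinates and target name are illustrative values.

import dns.rdata

# LOC: degrees minutes seconds hemisphere, twice, then altitude in meters.
loc = dns.rdata.from_text("IN", "LOC", "60 9 0.000 N 24 39 0.000 E 10.00m")
assert abs(loc.float_latitude - 60.15) < 1e-9

# LP: a preference and a fully qualified target name.
lp = dns.rdata.from_text("IN", "LP", "10 l64-subnet1.example.com.")
assert lp.preference == 10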
+ +import dns.immutable +import dns.rdtypes.mxbase + + +@dns.immutable.immutable +class MX(dns.rdtypes.mxbase.MXBase): + """MX record""" diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/NID.py b/venv/Lib/site-packages/dns/rdtypes/ANY/NID.py new file mode 100644 index 00000000..2f649178 --- /dev/null +++ b/venv/Lib/site-packages/dns/rdtypes/ANY/NID.py @@ -0,0 +1,47 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +import struct + +import dns.immutable +import dns.rdtypes.util + + +@dns.immutable.immutable +class NID(dns.rdata.Rdata): + """NID record""" + + # see: rfc6742.txt + + __slots__ = ["preference", "nodeid"] + + def __init__(self, rdclass, rdtype, preference, nodeid): + super().__init__(rdclass, rdtype) + self.preference = self._as_uint16(preference) + if isinstance(nodeid, bytes): + if len(nodeid) != 8: + raise ValueError("invalid nodeid") + self.nodeid = dns.rdata._hexify(nodeid, 4, b":") + else: + dns.rdtypes.util.parse_formatted_hex(nodeid, 4, 4, ":") + self.nodeid = nodeid + + def to_text(self, origin=None, relativize=True, **kw): + return f"{self.preference} {self.nodeid}" + + @classmethod + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): + preference = tok.get_uint16() + nodeid = tok.get_identifier() + return cls(rdclass, rdtype, preference, nodeid) + + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): + file.write(struct.pack("!H", self.preference)) + file.write(dns.rdtypes.util.parse_formatted_hex(self.nodeid, 4, 4, ":")) + + @classmethod + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + preference = parser.get_uint16() + nodeid = parser.get_remaining() + return cls(rdclass, rdtype, preference, nodeid) diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/NINFO.py b/venv/Lib/site-packages/dns/rdtypes/ANY/NINFO.py new file mode 100644 index 00000000..b177bddb --- /dev/null +++ b/venv/Lib/site-packages/dns/rdtypes/ANY/NINFO.py @@ -0,0 +1,26 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2006, 2007, 2009-2011 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +import dns.immutable +import dns.rdtypes.txtbase + + +@dns.immutable.immutable +class NINFO(dns.rdtypes.txtbase.TXTBase): + """NINFO record""" + + # see: draft-reid-dnsext-zs-01 diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/NS.py b/venv/Lib/site-packages/dns/rdtypes/ANY/NS.py new file mode 100644 index 00000000..c3f34ce9 --- /dev/null +++ b/venv/Lib/site-packages/dns/rdtypes/ANY/NS.py @@ -0,0 +1,24 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. 
+# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +import dns.immutable +import dns.rdtypes.nsbase + + +@dns.immutable.immutable +class NS(dns.rdtypes.nsbase.NSBase): + """NS record""" diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/NSEC.py b/venv/Lib/site-packages/dns/rdtypes/ANY/NSEC.py new file mode 100644 index 00000000..340525a6 --- /dev/null +++ b/venv/Lib/site-packages/dns/rdtypes/ANY/NSEC.py @@ -0,0 +1,67 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2004-2007, 2009-2011 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +import dns.exception +import dns.immutable +import dns.name +import dns.rdata +import dns.rdatatype +import dns.rdtypes.util + + +@dns.immutable.immutable +class Bitmap(dns.rdtypes.util.Bitmap): + type_name = "NSEC" + + +@dns.immutable.immutable +class NSEC(dns.rdata.Rdata): + """NSEC record""" + + __slots__ = ["next", "windows"] + + def __init__(self, rdclass, rdtype, next, windows): + super().__init__(rdclass, rdtype) + self.next = self._as_name(next) + if not isinstance(windows, Bitmap): + windows = Bitmap(windows) + self.windows = tuple(windows.windows) + + def to_text(self, origin=None, relativize=True, **kw): + next = self.next.choose_relativity(origin, relativize) + text = Bitmap(self.windows).to_text() + return "{}{}".format(next, text) + + @classmethod + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): + next = tok.get_name(origin, relativize, relativize_to) + windows = Bitmap.from_text(tok) + return cls(rdclass, rdtype, next, windows) + + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): + # Note that NSEC downcasing, originally mandated by RFC 4034 + # section 6.2 was removed by RFC 6840 section 5.1. 
+ self.next.to_wire(file, None, origin, False) + Bitmap(self.windows).to_wire(file) + + @classmethod + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + next = parser.get_name(origin) + bitmap = Bitmap.from_wire_parser(parser) + return cls(rdclass, rdtype, next, bitmap) diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/NSEC3.py b/venv/Lib/site-packages/dns/rdtypes/ANY/NSEC3.py new file mode 100644 index 00000000..d71302b7 --- /dev/null +++ b/venv/Lib/site-packages/dns/rdtypes/ANY/NSEC3.py @@ -0,0 +1,126 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2004-2017 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +import base64 +import binascii +import struct + +import dns.exception +import dns.immutable +import dns.rdata +import dns.rdatatype +import dns.rdtypes.util + +b32_hex_to_normal = bytes.maketrans( + b"0123456789ABCDEFGHIJKLMNOPQRSTUV", b"ABCDEFGHIJKLMNOPQRSTUVWXYZ234567" +) +b32_normal_to_hex = bytes.maketrans( + b"ABCDEFGHIJKLMNOPQRSTUVWXYZ234567", b"0123456789ABCDEFGHIJKLMNOPQRSTUV" +) + +# hash algorithm constants +SHA1 = 1 + +# flag constants +OPTOUT = 1 + + +@dns.immutable.immutable +class Bitmap(dns.rdtypes.util.Bitmap): + type_name = "NSEC3" + + +@dns.immutable.immutable +class NSEC3(dns.rdata.Rdata): + """NSEC3 record""" + + __slots__ = ["algorithm", "flags", "iterations", "salt", "next", "windows"] + + def __init__( + self, rdclass, rdtype, algorithm, flags, iterations, salt, next, windows + ): + super().__init__(rdclass, rdtype) + self.algorithm = self._as_uint8(algorithm) + self.flags = self._as_uint8(flags) + self.iterations = self._as_uint16(iterations) + self.salt = self._as_bytes(salt, True, 255) + self.next = self._as_bytes(next, True, 255) + if not isinstance(windows, Bitmap): + windows = Bitmap(windows) + self.windows = tuple(windows.windows) + + def _next_text(self): + next = base64.b32encode(self.next).translate(b32_normal_to_hex).lower().decode() + next = next.rstrip("=") + return next + + def to_text(self, origin=None, relativize=True, **kw): + next = self._next_text() + if self.salt == b"": + salt = "-" + else: + salt = binascii.hexlify(self.salt).decode() + text = Bitmap(self.windows).to_text() + return "%u %u %u %s %s%s" % ( + self.algorithm, + self.flags, + self.iterations, + salt, + next, + text, + ) + + @classmethod + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): + algorithm = tok.get_uint8() + flags = tok.get_uint8() + iterations = tok.get_uint16() + salt = tok.get_string() + if salt == "-": + salt = b"" + else: + salt = binascii.unhexlify(salt.encode("ascii")) + next = tok.get_string().encode("ascii").upper().translate(b32_hex_to_normal) + if next.endswith(b"="): + raise binascii.Error("Incorrect padding") + if len(next) % 8 != 0: + 
next += b"=" * (8 - len(next) % 8) + next = base64.b32decode(next) + bitmap = Bitmap.from_text(tok) + return cls(rdclass, rdtype, algorithm, flags, iterations, salt, next, bitmap) + + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): + l = len(self.salt) + file.write(struct.pack("!BBHB", self.algorithm, self.flags, self.iterations, l)) + file.write(self.salt) + l = len(self.next) + file.write(struct.pack("!B", l)) + file.write(self.next) + Bitmap(self.windows).to_wire(file) + + @classmethod + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + (algorithm, flags, iterations) = parser.get_struct("!BBH") + salt = parser.get_counted_bytes() + next = parser.get_counted_bytes() + bitmap = Bitmap.from_wire_parser(parser) + return cls(rdclass, rdtype, algorithm, flags, iterations, salt, next, bitmap) + + def next_name(self, origin=None): + return dns.name.from_text(self._next_text(), origin) diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/NSEC3PARAM.py b/venv/Lib/site-packages/dns/rdtypes/ANY/NSEC3PARAM.py new file mode 100644 index 00000000..d1e62ebc --- /dev/null +++ b/venv/Lib/site-packages/dns/rdtypes/ANY/NSEC3PARAM.py @@ -0,0 +1,69 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2004-2007, 2009-2011 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
+ +import binascii +import struct + +import dns.exception +import dns.immutable +import dns.rdata + + +@dns.immutable.immutable +class NSEC3PARAM(dns.rdata.Rdata): + """NSEC3PARAM record""" + + __slots__ = ["algorithm", "flags", "iterations", "salt"] + + def __init__(self, rdclass, rdtype, algorithm, flags, iterations, salt): + super().__init__(rdclass, rdtype) + self.algorithm = self._as_uint8(algorithm) + self.flags = self._as_uint8(flags) + self.iterations = self._as_uint16(iterations) + self.salt = self._as_bytes(salt, True, 255) + + def to_text(self, origin=None, relativize=True, **kw): + if self.salt == b"": + salt = "-" + else: + salt = binascii.hexlify(self.salt).decode() + return "%u %u %u %s" % (self.algorithm, self.flags, self.iterations, salt) + + @classmethod + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): + algorithm = tok.get_uint8() + flags = tok.get_uint8() + iterations = tok.get_uint16() + salt = tok.get_string() + if salt == "-": + salt = "" + else: + salt = binascii.unhexlify(salt.encode()) + return cls(rdclass, rdtype, algorithm, flags, iterations, salt) + + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): + l = len(self.salt) + file.write(struct.pack("!BBHB", self.algorithm, self.flags, self.iterations, l)) + file.write(self.salt) + + @classmethod + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + (algorithm, flags, iterations) = parser.get_struct("!BBH") + salt = parser.get_counted_bytes() + return cls(rdclass, rdtype, algorithm, flags, iterations, salt) diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/OPENPGPKEY.py b/venv/Lib/site-packages/dns/rdtypes/ANY/OPENPGPKEY.py new file mode 100644 index 00000000..4d7a4b6c --- /dev/null +++ b/venv/Lib/site-packages/dns/rdtypes/ANY/OPENPGPKEY.py @@ -0,0 +1,53 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2016 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
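The NSEC3 next-owner hash above is rendered in the RFC 4648 "extended hex" Base32 alphabet; since base64.b32decode only speaks the standard alphabet, the module translates between the two and restores padding by hand. The same round trip in isolation, reusing the translation tables defined in NSEC3.py:

    import base64

    from dns.rdtypes.ANY.NSEC3 import b32_hex_to_normal, b32_normal_to_hex

    digest = bytes.fromhex("1234567890abcdef1234567890abcdef12345678")  # e.g. a SHA-1 value

    # wire -> presentation, as in _next_text()
    text = base64.b32encode(digest).translate(b32_normal_to_hex).lower().decode().rstrip("=")

    # presentation -> wire, as in from_text()
    raw = text.encode("ascii").upper().translate(b32_hex_to_normal)
    raw += b"=" * (-len(raw) % 8)  # restore Base32 padding
    assert base64.b32decode(raw) == digest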
+ +import base64 + +import dns.exception +import dns.immutable +import dns.rdata +import dns.tokenizer + + +@dns.immutable.immutable +class OPENPGPKEY(dns.rdata.Rdata): + """OPENPGPKEY record""" + + # see: RFC 7929 + + def __init__(self, rdclass, rdtype, key): + super().__init__(rdclass, rdtype) + self.key = self._as_bytes(key) + + def to_text(self, origin=None, relativize=True, **kw): + return dns.rdata._base64ify(self.key, chunksize=None, **kw) + + @classmethod + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): + b64 = tok.concatenate_remaining_identifiers().encode() + key = base64.b64decode(b64) + return cls(rdclass, rdtype, key) + + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): + file.write(self.key) + + @classmethod + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + key = parser.get_remaining() + return cls(rdclass, rdtype, key) diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/OPT.py b/venv/Lib/site-packages/dns/rdtypes/ANY/OPT.py new file mode 100644 index 00000000..d343dfa5 --- /dev/null +++ b/venv/Lib/site-packages/dns/rdtypes/ANY/OPT.py @@ -0,0 +1,77 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2001-2017 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +import struct + +import dns.edns +import dns.exception +import dns.immutable +import dns.rdata + +# We don't implement from_text, and that's ok. +# pylint: disable=abstract-method + + +@dns.immutable.immutable +class OPT(dns.rdata.Rdata): + """OPT record""" + + __slots__ = ["options"] + + def __init__(self, rdclass, rdtype, options): + """Initialize an OPT rdata. + + *rdclass*, an ``int`` is the rdataclass of the Rdata, + which is also the payload size. + + *rdtype*, an ``int`` is the rdatatype of the Rdata. 
+ + *options*, a tuple of ``dns.edns.Option`` objects + """ + + super().__init__(rdclass, rdtype) + + def as_option(option): + if not isinstance(option, dns.edns.Option): + raise ValueError("option is not a dns.edns.option") + return option + + self.options = self._as_tuple(options, as_option) + + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): + for opt in self.options: + owire = opt.to_wire() + file.write(struct.pack("!HH", opt.otype, len(owire))) + file.write(owire) + + def to_text(self, origin=None, relativize=True, **kw): + return " ".join(opt.to_text() for opt in self.options) + + @classmethod + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + options = [] + while parser.remaining() > 0: + (otype, olen) = parser.get_struct("!HH") + with parser.restrict_to(olen): + opt = dns.edns.option_from_wire_parser(otype, parser) + options.append(opt) + return cls(rdclass, rdtype, options) + + @property + def payload(self): + "payload size" + return self.rdclass diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/PTR.py b/venv/Lib/site-packages/dns/rdtypes/ANY/PTR.py new file mode 100644 index 00000000..98c36167 --- /dev/null +++ b/venv/Lib/site-packages/dns/rdtypes/ANY/PTR.py @@ -0,0 +1,24 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +import dns.immutable +import dns.rdtypes.nsbase + + +@dns.immutable.immutable +class PTR(dns.rdtypes.nsbase.NSBase): + """PTR record""" diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/RP.py b/venv/Lib/site-packages/dns/rdtypes/ANY/RP.py new file mode 100644 index 00000000..9b74549d --- /dev/null +++ b/venv/Lib/site-packages/dns/rdtypes/ANY/RP.py @@ -0,0 +1,58 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
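OPT above frames every EDNS option as a TLV: a 16-bit option code, a 16-bit length, then the option data, and from_wire_parser walks the rdata the same way. A minimal sketch of that framing over plain (otype, data) pairs rather than dns.edns.Option objects:

    import struct

    def encode_options(options):
        out = bytearray()
        for otype, data in options:
            out += struct.pack("!HH", otype, len(data))  # code, length
            out += data                                  # value
        return bytes(out)

    def decode_options(wire):
        pos, options = 0, []
        while pos < len(wire):
            otype, olen = struct.unpack_from("!HH", wire, pos)
            pos += 4
            options.append((otype, wire[pos : pos + olen]))
            pos += olen
        return options

    # NSID (code 3) with no payload, plus 4 bytes of padding (code 12)
    wire = encode_options([(3, b""), (12, b"\x00" * 4)])
    assert decode_options(wire) == [(3, b""), (12, b"\x00" * 4)]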
+ +import dns.exception +import dns.immutable +import dns.name +import dns.rdata + + +@dns.immutable.immutable +class RP(dns.rdata.Rdata): + """RP record""" + + # see: RFC 1183 + + __slots__ = ["mbox", "txt"] + + def __init__(self, rdclass, rdtype, mbox, txt): + super().__init__(rdclass, rdtype) + self.mbox = self._as_name(mbox) + self.txt = self._as_name(txt) + + def to_text(self, origin=None, relativize=True, **kw): + mbox = self.mbox.choose_relativity(origin, relativize) + txt = self.txt.choose_relativity(origin, relativize) + return "{} {}".format(str(mbox), str(txt)) + + @classmethod + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): + mbox = tok.get_name(origin, relativize, relativize_to) + txt = tok.get_name(origin, relativize, relativize_to) + return cls(rdclass, rdtype, mbox, txt) + + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): + self.mbox.to_wire(file, None, origin, canonicalize) + self.txt.to_wire(file, None, origin, canonicalize) + + @classmethod + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + mbox = parser.get_name(origin) + txt = parser.get_name(origin) + return cls(rdclass, rdtype, mbox, txt) diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/RRSIG.py b/venv/Lib/site-packages/dns/rdtypes/ANY/RRSIG.py new file mode 100644 index 00000000..8beb4237 --- /dev/null +++ b/venv/Lib/site-packages/dns/rdtypes/ANY/RRSIG.py @@ -0,0 +1,157 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2004-2007, 2009-2011 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
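RP above is just two names: mbox holds the responsible person's mailbox with the local part as the first label (RFC 1183 writes user@example.com as user.example.com.), and txt names a TXT record with further contact detail (the root name when there is none). A sketch of the mailbox convention (hypothetical helper, not part of dnspython):

    def mailbox_to_name(addr):
        # "hostmaster@example.com" -> "hostmaster.example.com."
        # A dotted local part would have to stay one escaped label; ignored here.
        local, _, domain = addr.partition("@")
        return f"{local}.{domain}."

    assert mailbox_to_name("hostmaster@example.com") == "hostmaster.example.com."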
+ +import base64 +import calendar +import struct +import time + +import dns.dnssectypes +import dns.exception +import dns.immutable +import dns.rdata +import dns.rdatatype + + +class BadSigTime(dns.exception.DNSException): + """Time in DNS SIG or RRSIG resource record cannot be parsed.""" + + +def sigtime_to_posixtime(what): + if len(what) <= 10 and what.isdigit(): + return int(what) + if len(what) != 14: + raise BadSigTime + year = int(what[0:4]) + month = int(what[4:6]) + day = int(what[6:8]) + hour = int(what[8:10]) + minute = int(what[10:12]) + second = int(what[12:14]) + return calendar.timegm((year, month, day, hour, minute, second, 0, 0, 0)) + + +def posixtime_to_sigtime(what): + return time.strftime("%Y%m%d%H%M%S", time.gmtime(what)) + + +@dns.immutable.immutable +class RRSIG(dns.rdata.Rdata): + """RRSIG record""" + + __slots__ = [ + "type_covered", + "algorithm", + "labels", + "original_ttl", + "expiration", + "inception", + "key_tag", + "signer", + "signature", + ] + + def __init__( + self, + rdclass, + rdtype, + type_covered, + algorithm, + labels, + original_ttl, + expiration, + inception, + key_tag, + signer, + signature, + ): + super().__init__(rdclass, rdtype) + self.type_covered = self._as_rdatatype(type_covered) + self.algorithm = dns.dnssectypes.Algorithm.make(algorithm) + self.labels = self._as_uint8(labels) + self.original_ttl = self._as_ttl(original_ttl) + self.expiration = self._as_uint32(expiration) + self.inception = self._as_uint32(inception) + self.key_tag = self._as_uint16(key_tag) + self.signer = self._as_name(signer) + self.signature = self._as_bytes(signature) + + def covers(self): + return self.type_covered + + def to_text(self, origin=None, relativize=True, **kw): + return "%s %d %d %d %s %s %d %s %s" % ( + dns.rdatatype.to_text(self.type_covered), + self.algorithm, + self.labels, + self.original_ttl, + posixtime_to_sigtime(self.expiration), + posixtime_to_sigtime(self.inception), + self.key_tag, + self.signer.choose_relativity(origin, relativize), + dns.rdata._base64ify(self.signature, **kw), + ) + + @classmethod + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): + type_covered = dns.rdatatype.from_text(tok.get_string()) + algorithm = dns.dnssectypes.Algorithm.from_text(tok.get_string()) + labels = tok.get_int() + original_ttl = tok.get_ttl() + expiration = sigtime_to_posixtime(tok.get_string()) + inception = sigtime_to_posixtime(tok.get_string()) + key_tag = tok.get_int() + signer = tok.get_name(origin, relativize, relativize_to) + b64 = tok.concatenate_remaining_identifiers().encode() + signature = base64.b64decode(b64) + return cls( + rdclass, + rdtype, + type_covered, + algorithm, + labels, + original_ttl, + expiration, + inception, + key_tag, + signer, + signature, + ) + + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): + header = struct.pack( + "!HBBIIIH", + self.type_covered, + self.algorithm, + self.labels, + self.original_ttl, + self.expiration, + self.inception, + self.key_tag, + ) + file.write(header) + self.signer.to_wire(file, None, origin, canonicalize) + file.write(self.signature) + + @classmethod + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + header = parser.get_struct("!HBBIIIH") + signer = parser.get_name(origin) + signature = parser.get_remaining() + return cls(rdclass, rdtype, *header, signer, signature) diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/RT.py b/venv/Lib/site-packages/dns/rdtypes/ANY/RT.py new file mode 100644 index 
00000000..5a4d45cf --- /dev/null +++ b/venv/Lib/site-packages/dns/rdtypes/ANY/RT.py @@ -0,0 +1,24 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +import dns.immutable +import dns.rdtypes.mxbase + + +@dns.immutable.immutable +class RT(dns.rdtypes.mxbase.UncompressedDowncasingMX): + """RT record""" diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/SMIMEA.py b/venv/Lib/site-packages/dns/rdtypes/ANY/SMIMEA.py new file mode 100644 index 00000000..55d87bf8 --- /dev/null +++ b/venv/Lib/site-packages/dns/rdtypes/ANY/SMIMEA.py @@ -0,0 +1,9 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +import dns.immutable +import dns.rdtypes.tlsabase + + +@dns.immutable.immutable +class SMIMEA(dns.rdtypes.tlsabase.TLSABase): + """SMIMEA record""" diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/SOA.py b/venv/Lib/site-packages/dns/rdtypes/ANY/SOA.py new file mode 100644 index 00000000..09aa8321 --- /dev/null +++ b/venv/Lib/site-packages/dns/rdtypes/ANY/SOA.py @@ -0,0 +1,86 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
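The RRSIG inception and expiration fields above accept either a bare epoch count (up to 10 digits) or the 14-digit YYYYMMDDHHMMSS form; sigtime_to_posixtime converts the latter with calendar.timegm, and posixtime_to_sigtime formats the reverse. A quick round trip using those module-level functions:

    from dns.rdtypes.ANY.RRSIG import posixtime_to_sigtime, sigtime_to_posixtime

    t = sigtime_to_posixtime("20240101000000")       # 14-digit form -> epoch seconds
    assert t == 1704067200
    assert sigtime_to_posixtime("1704067200") == t   # <= 10 digits passes through
    assert posixtime_to_sigtime(t) == "20240101000000"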
+ +import struct + +import dns.exception +import dns.immutable +import dns.name +import dns.rdata + + +@dns.immutable.immutable +class SOA(dns.rdata.Rdata): + """SOA record""" + + # see: RFC 1035 + + __slots__ = ["mname", "rname", "serial", "refresh", "retry", "expire", "minimum"] + + def __init__( + self, rdclass, rdtype, mname, rname, serial, refresh, retry, expire, minimum + ): + super().__init__(rdclass, rdtype) + self.mname = self._as_name(mname) + self.rname = self._as_name(rname) + self.serial = self._as_uint32(serial) + self.refresh = self._as_ttl(refresh) + self.retry = self._as_ttl(retry) + self.expire = self._as_ttl(expire) + self.minimum = self._as_ttl(minimum) + + def to_text(self, origin=None, relativize=True, **kw): + mname = self.mname.choose_relativity(origin, relativize) + rname = self.rname.choose_relativity(origin, relativize) + return "%s %s %d %d %d %d %d" % ( + mname, + rname, + self.serial, + self.refresh, + self.retry, + self.expire, + self.minimum, + ) + + @classmethod + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): + mname = tok.get_name(origin, relativize, relativize_to) + rname = tok.get_name(origin, relativize, relativize_to) + serial = tok.get_uint32() + refresh = tok.get_ttl() + retry = tok.get_ttl() + expire = tok.get_ttl() + minimum = tok.get_ttl() + return cls( + rdclass, rdtype, mname, rname, serial, refresh, retry, expire, minimum + ) + + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): + self.mname.to_wire(file, compress, origin, canonicalize) + self.rname.to_wire(file, compress, origin, canonicalize) + five_ints = struct.pack( + "!IIIII", self.serial, self.refresh, self.retry, self.expire, self.minimum + ) + file.write(five_ints) + + @classmethod + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + mname = parser.get_name(origin) + rname = parser.get_name(origin) + return cls(rdclass, rdtype, mname, rname, *parser.get_struct("!IIIII")) diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/SPF.py b/venv/Lib/site-packages/dns/rdtypes/ANY/SPF.py new file mode 100644 index 00000000..1df3b705 --- /dev/null +++ b/venv/Lib/site-packages/dns/rdtypes/ANY/SPF.py @@ -0,0 +1,26 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2006, 2007, 2009-2011 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
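SOA above serializes as two (possibly compressed) names followed by five unsigned 32-bit integers packed with "!IIIII". A usage sketch through dnspython's generic parser (the zone values are arbitrary examples):

    import dns.rdata
    import dns.rdataclass
    import dns.rdatatype

    soa = dns.rdata.from_text(
        dns.rdataclass.IN,
        dns.rdatatype.SOA,
        "ns1.example.com. hostmaster.example.com. 2024010101 7200 900 1209600 86400",
    )
    assert (soa.serial, soa.refresh, soa.minimum) == (2024010101, 7200, 86400)
    wire = soa.to_wire()  # both names, then struct.pack("!IIIII", ...)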
+ +import dns.immutable +import dns.rdtypes.txtbase + + +@dns.immutable.immutable +class SPF(dns.rdtypes.txtbase.TXTBase): + """SPF record""" + + # see: RFC 4408 diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/SSHFP.py b/venv/Lib/site-packages/dns/rdtypes/ANY/SSHFP.py new file mode 100644 index 00000000..d2c4b073 --- /dev/null +++ b/venv/Lib/site-packages/dns/rdtypes/ANY/SSHFP.py @@ -0,0 +1,68 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2005-2007, 2009-2011 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +import binascii +import struct + +import dns.immutable +import dns.rdata +import dns.rdatatype + + +@dns.immutable.immutable +class SSHFP(dns.rdata.Rdata): + """SSHFP record""" + + # See RFC 4255 + + __slots__ = ["algorithm", "fp_type", "fingerprint"] + + def __init__(self, rdclass, rdtype, algorithm, fp_type, fingerprint): + super().__init__(rdclass, rdtype) + self.algorithm = self._as_uint8(algorithm) + self.fp_type = self._as_uint8(fp_type) + self.fingerprint = self._as_bytes(fingerprint, True) + + def to_text(self, origin=None, relativize=True, **kw): + kw = kw.copy() + chunksize = kw.pop("chunksize", 128) + return "%d %d %s" % ( + self.algorithm, + self.fp_type, + dns.rdata._hexify(self.fingerprint, chunksize=chunksize, **kw), + ) + + @classmethod + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): + algorithm = tok.get_uint8() + fp_type = tok.get_uint8() + fingerprint = tok.concatenate_remaining_identifiers().encode() + fingerprint = binascii.unhexlify(fingerprint) + return cls(rdclass, rdtype, algorithm, fp_type, fingerprint) + + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): + header = struct.pack("!BB", self.algorithm, self.fp_type) + file.write(header) + file.write(self.fingerprint) + + @classmethod + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + header = parser.get_struct("BB") + fingerprint = parser.get_remaining() + return cls(rdclass, rdtype, header[0], header[1], fingerprint) diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/TKEY.py b/venv/Lib/site-packages/dns/rdtypes/ANY/TKEY.py new file mode 100644 index 00000000..5b490b82 --- /dev/null +++ b/venv/Lib/site-packages/dns/rdtypes/ANY/TKEY.py @@ -0,0 +1,142 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2004-2007, 2009-2011 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. 
+# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +import base64 +import struct + +import dns.exception +import dns.immutable +import dns.rdata + + +@dns.immutable.immutable +class TKEY(dns.rdata.Rdata): + """TKEY Record""" + + __slots__ = [ + "algorithm", + "inception", + "expiration", + "mode", + "error", + "key", + "other", + ] + + def __init__( + self, + rdclass, + rdtype, + algorithm, + inception, + expiration, + mode, + error, + key, + other=b"", + ): + super().__init__(rdclass, rdtype) + self.algorithm = self._as_name(algorithm) + self.inception = self._as_uint32(inception) + self.expiration = self._as_uint32(expiration) + self.mode = self._as_uint16(mode) + self.error = self._as_uint16(error) + self.key = self._as_bytes(key) + self.other = self._as_bytes(other) + + def to_text(self, origin=None, relativize=True, **kw): + _algorithm = self.algorithm.choose_relativity(origin, relativize) + text = "%s %u %u %u %u %s" % ( + str(_algorithm), + self.inception, + self.expiration, + self.mode, + self.error, + dns.rdata._base64ify(self.key, 0), + ) + if len(self.other) > 0: + text += " %s" % (dns.rdata._base64ify(self.other, 0)) + + return text + + @classmethod + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): + algorithm = tok.get_name(relativize=False) + inception = tok.get_uint32() + expiration = tok.get_uint32() + mode = tok.get_uint16() + error = tok.get_uint16() + key_b64 = tok.get_string().encode() + key = base64.b64decode(key_b64) + other_b64 = tok.concatenate_remaining_identifiers(True).encode() + other = base64.b64decode(other_b64) + + return cls( + rdclass, rdtype, algorithm, inception, expiration, mode, error, key, other + ) + + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): + self.algorithm.to_wire(file, compress, origin) + file.write( + struct.pack("!IIHH", self.inception, self.expiration, self.mode, self.error) + ) + file.write(struct.pack("!H", len(self.key))) + file.write(self.key) + file.write(struct.pack("!H", len(self.other))) + if len(self.other) > 0: + file.write(self.other) + + @classmethod + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + algorithm = parser.get_name(origin) + inception, expiration, mode, error = parser.get_struct("!IIHH") + key = parser.get_counted_bytes(2) + other = parser.get_counted_bytes(2) + + return cls( + rdclass, rdtype, algorithm, inception, expiration, mode, error, key, other + ) + + # Constants for the mode field - from RFC 2930: + # 2.5 The Mode Field + # + # The mode field specifies the general scheme for key agreement or + # the purpose of the TKEY DNS message. Servers and resolvers + # supporting this specification MUST implement the Diffie-Hellman key + # agreement mode and the key deletion mode for queries. All other + # modes are OPTIONAL. A server supporting TKEY that receives a TKEY + # request with a mode it does not support returns the BADMODE error. 
+ # The following values of the Mode octet are defined, available, or + # reserved: + # + # Value Description + # ----- ----------- + # 0 - reserved, see section 7 + # 1 server assignment + # 2 Diffie-Hellman exchange + # 3 GSS-API negotiation + # 4 resolver assignment + # 5 key deletion + # 6-65534 - available, see section 7 + # 65535 - reserved, see section 7 + SERVER_ASSIGNMENT = 1 + DIFFIE_HELLMAN_EXCHANGE = 2 + GSSAPI_NEGOTIATION = 3 + RESOLVER_ASSIGNMENT = 4 + KEY_DELETION = 5 diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/TLSA.py b/venv/Lib/site-packages/dns/rdtypes/ANY/TLSA.py new file mode 100644 index 00000000..4dffc553 --- /dev/null +++ b/venv/Lib/site-packages/dns/rdtypes/ANY/TLSA.py @@ -0,0 +1,9 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +import dns.immutable +import dns.rdtypes.tlsabase + + +@dns.immutable.immutable +class TLSA(dns.rdtypes.tlsabase.TLSABase): + """TLSA record""" diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/TSIG.py b/venv/Lib/site-packages/dns/rdtypes/ANY/TSIG.py new file mode 100644 index 00000000..79423826 --- /dev/null +++ b/venv/Lib/site-packages/dns/rdtypes/ANY/TSIG.py @@ -0,0 +1,160 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2001-2017 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +import base64 +import struct + +import dns.exception +import dns.immutable +import dns.rcode +import dns.rdata + + +@dns.immutable.immutable +class TSIG(dns.rdata.Rdata): + """TSIG record""" + + __slots__ = [ + "algorithm", + "time_signed", + "fudge", + "mac", + "original_id", + "error", + "other", + ] + + def __init__( + self, + rdclass, + rdtype, + algorithm, + time_signed, + fudge, + mac, + original_id, + error, + other, + ): + """Initialize a TSIG rdata. + + *rdclass*, an ``int`` is the rdataclass of the Rdata. + + *rdtype*, an ``int`` is the rdatatype of the Rdata. + + *algorithm*, a ``dns.name.Name``. + + *time_signed*, an ``int``. + + *fudge*, an ``int``.
+ + *mac*, a ``bytes`` + + *original_id*, an ``int`` + + *error*, an ``int`` + + *other*, a ``bytes`` + """ + + super().__init__(rdclass, rdtype) + self.algorithm = self._as_name(algorithm) + self.time_signed = self._as_uint48(time_signed) + self.fudge = self._as_uint16(fudge) + self.mac = self._as_bytes(mac) + self.original_id = self._as_uint16(original_id) + self.error = dns.rcode.Rcode.make(error) + self.other = self._as_bytes(other) + + def to_text(self, origin=None, relativize=True, **kw): + algorithm = self.algorithm.choose_relativity(origin, relativize) + error = dns.rcode.to_text(self.error, True) + text = ( + f"{algorithm} {self.time_signed} {self.fudge} " + + f"{len(self.mac)} {dns.rdata._base64ify(self.mac, 0)} " + + f"{self.original_id} {error} {len(self.other)}" + ) + if self.other: + text += f" {dns.rdata._base64ify(self.other, 0)}" + return text + + @classmethod + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): + algorithm = tok.get_name(relativize=False) + time_signed = tok.get_uint48() + fudge = tok.get_uint16() + mac_len = tok.get_uint16() + mac = base64.b64decode(tok.get_string()) + if len(mac) != mac_len: + raise SyntaxError("invalid MAC") + original_id = tok.get_uint16() + error = dns.rcode.from_text(tok.get_string()) + other_len = tok.get_uint16() + if other_len > 0: + other = base64.b64decode(tok.get_string()) + if len(other) != other_len: + raise SyntaxError("invalid other data") + else: + other = b"" + return cls( + rdclass, + rdtype, + algorithm, + time_signed, + fudge, + mac, + original_id, + error, + other, + ) + + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): + self.algorithm.to_wire(file, None, origin, False) + file.write( + struct.pack( + "!HIHH", + (self.time_signed >> 32) & 0xFFFF, + self.time_signed & 0xFFFFFFFF, + self.fudge, + len(self.mac), + ) + ) + file.write(self.mac) + file.write(struct.pack("!HHH", self.original_id, self.error, len(self.other))) + file.write(self.other) + + @classmethod + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + algorithm = parser.get_name() + time_signed = parser.get_uint48() + fudge = parser.get_uint16() + mac = parser.get_counted_bytes(2) + (original_id, error) = parser.get_struct("!HH") + other = parser.get_counted_bytes(2) + return cls( + rdclass, + rdtype, + algorithm, + time_signed, + fudge, + mac, + original_id, + error, + other, + ) diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/TXT.py b/venv/Lib/site-packages/dns/rdtypes/ANY/TXT.py new file mode 100644 index 00000000..6d4dae27 --- /dev/null +++ b/venv/Lib/site-packages/dns/rdtypes/ANY/TXT.py @@ -0,0 +1,24 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. 
IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +import dns.immutable +import dns.rdtypes.txtbase + + +@dns.immutable.immutable +class TXT(dns.rdtypes.txtbase.TXTBase): + """TXT record""" diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/URI.py b/venv/Lib/site-packages/dns/rdtypes/ANY/URI.py new file mode 100644 index 00000000..2efbb305 --- /dev/null +++ b/venv/Lib/site-packages/dns/rdtypes/ANY/URI.py @@ -0,0 +1,79 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. +# Copyright (C) 2015 Red Hat, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +import struct + +import dns.exception +import dns.immutable +import dns.name +import dns.rdata +import dns.rdtypes.util + + +@dns.immutable.immutable +class URI(dns.rdata.Rdata): + """URI record""" + + # see RFC 7553 + + __slots__ = ["priority", "weight", "target"] + + def __init__(self, rdclass, rdtype, priority, weight, target): + super().__init__(rdclass, rdtype) + self.priority = self._as_uint16(priority) + self.weight = self._as_uint16(weight) + self.target = self._as_bytes(target, True) + if len(self.target) == 0: + raise dns.exception.SyntaxError("URI target cannot be empty") + + def to_text(self, origin=None, relativize=True, **kw): + return '%d %d "%s"' % (self.priority, self.weight, self.target.decode()) + + @classmethod + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): + priority = tok.get_uint16() + weight = tok.get_uint16() + target = tok.get().unescape() + if not (target.is_quoted_string() or target.is_identifier()): + raise dns.exception.SyntaxError("URI target must be a string") + return cls(rdclass, rdtype, priority, weight, target.value) + + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): + two_ints = struct.pack("!HH", self.priority, self.weight) + file.write(two_ints) + file.write(self.target) + + @classmethod + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + (priority, weight) = parser.get_struct("!HH") + target = parser.get_remaining() + if len(target) == 0: + raise dns.exception.FormError("URI target may not be empty") + return cls(rdclass, rdtype, priority, weight, target) + + def _processing_priority(self): + return self.priority + + def _processing_weight(self): + return self.weight + + @classmethod + def _processing_order(cls, iterable): + return dns.rdtypes.util.weighted_processing_order(iterable) diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/X25.py 
b/venv/Lib/site-packages/dns/rdtypes/ANY/X25.py new file mode 100644 index 00000000..8375611d --- /dev/null +++ b/venv/Lib/site-packages/dns/rdtypes/ANY/X25.py @@ -0,0 +1,57 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +import struct + +import dns.exception +import dns.immutable +import dns.rdata +import dns.tokenizer + + +@dns.immutable.immutable +class X25(dns.rdata.Rdata): + """X25 record""" + + # see RFC 1183 + + __slots__ = ["address"] + + def __init__(self, rdclass, rdtype, address): + super().__init__(rdclass, rdtype) + self.address = self._as_bytes(address, True, 255) + + def to_text(self, origin=None, relativize=True, **kw): + return '"%s"' % dns.rdata._escapify(self.address) + + @classmethod + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): + address = tok.get_string() + return cls(rdclass, rdtype, address) + + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): + l = len(self.address) + assert l < 256 + file.write(struct.pack("!B", l)) + file.write(self.address) + + @classmethod + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + address = parser.get_counted_bytes() + return cls(rdclass, rdtype, address) diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/ZONEMD.py b/venv/Lib/site-packages/dns/rdtypes/ANY/ZONEMD.py new file mode 100644 index 00000000..c90e3ee1 --- /dev/null +++ b/venv/Lib/site-packages/dns/rdtypes/ANY/ZONEMD.py @@ -0,0 +1,66 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +import binascii +import struct + +import dns.immutable +import dns.rdata +import dns.rdatatype +import dns.zonetypes + + +@dns.immutable.immutable +class ZONEMD(dns.rdata.Rdata): + """ZONEMD record""" + + # See RFC 8976 + + __slots__ = ["serial", "scheme", "hash_algorithm", "digest"] + + def __init__(self, rdclass, rdtype, serial, scheme, hash_algorithm, digest): + super().__init__(rdclass, rdtype) + self.serial = self._as_uint32(serial) + self.scheme = dns.zonetypes.DigestScheme.make(scheme) + self.hash_algorithm = dns.zonetypes.DigestHashAlgorithm.make(hash_algorithm) + self.digest = self._as_bytes(digest) + + if self.scheme == 0: # reserved, RFC 8976 Sec. 5.2 + raise ValueError("scheme 0 is reserved") + if self.hash_algorithm == 0: # reserved, RFC 8976 Sec. 
5.3 + raise ValueError("hash_algorithm 0 is reserved") + + hasher = dns.zonetypes._digest_hashers.get(self.hash_algorithm) + if hasher and hasher().digest_size != len(self.digest): + raise ValueError("digest length inconsistent with hash algorithm") + + def to_text(self, origin=None, relativize=True, **kw): + kw = kw.copy() + chunksize = kw.pop("chunksize", 128) + return "%d %d %d %s" % ( + self.serial, + self.scheme, + self.hash_algorithm, + dns.rdata._hexify(self.digest, chunksize=chunksize, **kw), + ) + + @classmethod + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): + serial = tok.get_uint32() + scheme = tok.get_uint8() + hash_algorithm = tok.get_uint8() + digest = tok.concatenate_remaining_identifiers().encode() + digest = binascii.unhexlify(digest) + return cls(rdclass, rdtype, serial, scheme, hash_algorithm, digest) + + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): + header = struct.pack("!IBB", self.serial, self.scheme, self.hash_algorithm) + file.write(header) + file.write(self.digest) + + @classmethod + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + header = parser.get_struct("!IBB") + digest = parser.get_remaining() + return cls(rdclass, rdtype, header[0], header[1], header[2], digest) diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/__init__.py b/venv/Lib/site-packages/dns/rdtypes/ANY/__init__.py new file mode 100644 index 00000000..3824a0a0 --- /dev/null +++ b/venv/Lib/site-packages/dns/rdtypes/ANY/__init__.py @@ -0,0 +1,68 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
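ZONEMD's constructor above rejects the reserved scheme and hash_algorithm value 0 and, when the algorithm is known, cross-checks the digest length via dns.zonetypes._digest_hashers. The same length check sketched with hashlib (RFC 8976 assigns SHA-384 to hash algorithm 1 and SHA-512 to 2):

    import hashlib

    DIGEST_SIZES = {1: hashlib.sha384().digest_size, 2: hashlib.sha512().digest_size}

    def check_zonemd_digest(hash_algorithm, digest):
        expected = DIGEST_SIZES.get(hash_algorithm)
        if expected is not None and expected != len(digest):
            raise ValueError("digest length inconsistent with hash algorithm")

    check_zonemd_digest(1, b"\x00" * 48)  # SHA-384 digests are 48 bytes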
+ +"""Class ANY (generic) rdata type classes.""" + +__all__ = [ + "AFSDB", + "AMTRELAY", + "AVC", + "CAA", + "CDNSKEY", + "CDS", + "CERT", + "CNAME", + "CSYNC", + "DLV", + "DNAME", + "DNSKEY", + "DS", + "EUI48", + "EUI64", + "GPOS", + "HINFO", + "HIP", + "ISDN", + "L32", + "L64", + "LOC", + "LP", + "MX", + "NID", + "NINFO", + "NS", + "NSEC", + "NSEC3", + "NSEC3PARAM", + "OPENPGPKEY", + "OPT", + "PTR", + "RP", + "RRSIG", + "RT", + "SMIMEA", + "SOA", + "SPF", + "SSHFP", + "TKEY", + "TLSA", + "TSIG", + "TXT", + "URI", + "X25", + "ZONEMD", +] diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/AFSDB.cpython-312.pyc b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/AFSDB.cpython-312.pyc new file mode 100644 index 00000000..f4d74204 Binary files /dev/null and b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/AFSDB.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/AMTRELAY.cpython-312.pyc b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/AMTRELAY.cpython-312.pyc new file mode 100644 index 00000000..48020511 Binary files /dev/null and b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/AMTRELAY.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/AVC.cpython-312.pyc b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/AVC.cpython-312.pyc new file mode 100644 index 00000000..b65da7db Binary files /dev/null and b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/AVC.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/CAA.cpython-312.pyc b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/CAA.cpython-312.pyc new file mode 100644 index 00000000..faa7bac6 Binary files /dev/null and b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/CAA.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/CDNSKEY.cpython-312.pyc b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/CDNSKEY.cpython-312.pyc new file mode 100644 index 00000000..d779fe93 Binary files /dev/null and b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/CDNSKEY.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/CDS.cpython-312.pyc b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/CDS.cpython-312.pyc new file mode 100644 index 00000000..f4895994 Binary files /dev/null and b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/CDS.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/CERT.cpython-312.pyc b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/CERT.cpython-312.pyc new file mode 100644 index 00000000..4a6db3a7 Binary files /dev/null and b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/CERT.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/CNAME.cpython-312.pyc b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/CNAME.cpython-312.pyc new file mode 100644 index 00000000..d774f87e Binary files /dev/null and b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/CNAME.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/CSYNC.cpython-312.pyc b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/CSYNC.cpython-312.pyc new file mode 100644 index 00000000..0fe33b42 Binary files /dev/null and b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/CSYNC.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/DLV.cpython-312.pyc b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/DLV.cpython-312.pyc new file mode 
100644 index 00000000..3ca19935 Binary files /dev/null and b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/DLV.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/DNAME.cpython-312.pyc b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/DNAME.cpython-312.pyc new file mode 100644 index 00000000..ff71f9e7 Binary files /dev/null and b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/DNAME.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/DNSKEY.cpython-312.pyc b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/DNSKEY.cpython-312.pyc new file mode 100644 index 00000000..72fdcb5e Binary files /dev/null and b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/DNSKEY.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/DS.cpython-312.pyc b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/DS.cpython-312.pyc new file mode 100644 index 00000000..53cf563d Binary files /dev/null and b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/DS.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/EUI48.cpython-312.pyc b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/EUI48.cpython-312.pyc new file mode 100644 index 00000000..be5ea43f Binary files /dev/null and b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/EUI48.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/EUI64.cpython-312.pyc b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/EUI64.cpython-312.pyc new file mode 100644 index 00000000..bd26fea4 Binary files /dev/null and b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/EUI64.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/GPOS.cpython-312.pyc b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/GPOS.cpython-312.pyc new file mode 100644 index 00000000..5072e455 Binary files /dev/null and b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/GPOS.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/HINFO.cpython-312.pyc b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/HINFO.cpython-312.pyc new file mode 100644 index 00000000..c01cb119 Binary files /dev/null and b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/HINFO.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/HIP.cpython-312.pyc b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/HIP.cpython-312.pyc new file mode 100644 index 00000000..66de5156 Binary files /dev/null and b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/HIP.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/ISDN.cpython-312.pyc b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/ISDN.cpython-312.pyc new file mode 100644 index 00000000..900b791b Binary files /dev/null and b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/ISDN.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/L32.cpython-312.pyc b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/L32.cpython-312.pyc new file mode 100644 index 00000000..fd10c74e Binary files /dev/null and b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/L32.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/L64.cpython-312.pyc b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/L64.cpython-312.pyc new file mode 100644 index 00000000..681d2f1e Binary files /dev/null and 
b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/L64.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/LOC.cpython-312.pyc b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/LOC.cpython-312.pyc new file mode 100644 index 00000000..e3eb6253 Binary files /dev/null and b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/LOC.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/LP.cpython-312.pyc b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/LP.cpython-312.pyc new file mode 100644 index 00000000..a4abd7df Binary files /dev/null and b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/LP.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/MX.cpython-312.pyc b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/MX.cpython-312.pyc new file mode 100644 index 00000000..1265d0e5 Binary files /dev/null and b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/MX.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/NID.cpython-312.pyc b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/NID.cpython-312.pyc new file mode 100644 index 00000000..91908aa6 Binary files /dev/null and b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/NID.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/NINFO.cpython-312.pyc b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/NINFO.cpython-312.pyc new file mode 100644 index 00000000..ba26183d Binary files /dev/null and b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/NINFO.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/NS.cpython-312.pyc b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/NS.cpython-312.pyc new file mode 100644 index 00000000..344a740d Binary files /dev/null and b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/NS.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/NSEC.cpython-312.pyc b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/NSEC.cpython-312.pyc new file mode 100644 index 00000000..df7aec70 Binary files /dev/null and b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/NSEC.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/NSEC3.cpython-312.pyc b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/NSEC3.cpython-312.pyc new file mode 100644 index 00000000..310e9bdb Binary files /dev/null and b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/NSEC3.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/NSEC3PARAM.cpython-312.pyc b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/NSEC3PARAM.cpython-312.pyc new file mode 100644 index 00000000..ac64cb88 Binary files /dev/null and b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/NSEC3PARAM.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/OPENPGPKEY.cpython-312.pyc b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/OPENPGPKEY.cpython-312.pyc new file mode 100644 index 00000000..9b31fbf2 Binary files /dev/null and b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/OPENPGPKEY.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/OPT.cpython-312.pyc b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/OPT.cpython-312.pyc new file mode 100644 index 00000000..c7fc32cd Binary files /dev/null and b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/OPT.cpython-312.pyc differ diff --git 
a/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/PTR.cpython-312.pyc b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/PTR.cpython-312.pyc new file mode 100644 index 00000000..8ae47a03 Binary files /dev/null and b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/PTR.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/RP.cpython-312.pyc b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/RP.cpython-312.pyc new file mode 100644 index 00000000..7879b6a9 Binary files /dev/null and b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/RP.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/RRSIG.cpython-312.pyc b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/RRSIG.cpython-312.pyc new file mode 100644 index 00000000..9ee1010a Binary files /dev/null and b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/RRSIG.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/RT.cpython-312.pyc b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/RT.cpython-312.pyc new file mode 100644 index 00000000..7b758e71 Binary files /dev/null and b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/RT.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/SMIMEA.cpython-312.pyc b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/SMIMEA.cpython-312.pyc new file mode 100644 index 00000000..6a86672f Binary files /dev/null and b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/SMIMEA.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/SOA.cpython-312.pyc b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/SOA.cpython-312.pyc new file mode 100644 index 00000000..93bad4e4 Binary files /dev/null and b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/SOA.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/SPF.cpython-312.pyc b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/SPF.cpython-312.pyc new file mode 100644 index 00000000..1bc4a79a Binary files /dev/null and b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/SPF.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/SSHFP.cpython-312.pyc b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/SSHFP.cpython-312.pyc new file mode 100644 index 00000000..f43c3490 Binary files /dev/null and b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/SSHFP.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/TKEY.cpython-312.pyc b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/TKEY.cpython-312.pyc new file mode 100644 index 00000000..8e9250e9 Binary files /dev/null and b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/TKEY.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/TLSA.cpython-312.pyc b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/TLSA.cpython-312.pyc new file mode 100644 index 00000000..f47f534e Binary files /dev/null and b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/TLSA.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/TSIG.cpython-312.pyc b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/TSIG.cpython-312.pyc new file mode 100644 index 00000000..eb65cbef Binary files /dev/null and b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/TSIG.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/TXT.cpython-312.pyc 
b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/TXT.cpython-312.pyc new file mode 100644 index 00000000..644dbaca Binary files /dev/null and b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/TXT.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/URI.cpython-312.pyc b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/URI.cpython-312.pyc new file mode 100644 index 00000000..f69a7a67 Binary files /dev/null and b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/URI.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/X25.cpython-312.pyc b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/X25.cpython-312.pyc new file mode 100644 index 00000000..e4827fa1 Binary files /dev/null and b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/X25.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/ZONEMD.cpython-312.pyc b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/ZONEMD.cpython-312.pyc new file mode 100644 index 00000000..51d922a0 Binary files /dev/null and b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/ZONEMD.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/__init__.cpython-312.pyc b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 00000000..06a091fe Binary files /dev/null and b/venv/Lib/site-packages/dns/rdtypes/ANY/__pycache__/__init__.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/rdtypes/CH/A.py b/venv/Lib/site-packages/dns/rdtypes/CH/A.py new file mode 100644 index 00000000..583a88ac --- /dev/null +++ b/venv/Lib/site-packages/dns/rdtypes/CH/A.py @@ -0,0 +1,59 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
+ +import struct + +import dns.immutable +import dns.rdtypes.mxbase + + +@dns.immutable.immutable +class A(dns.rdata.Rdata): + """A record for Chaosnet""" + + # domain: the domain of the address + # address: the 16-bit address + + __slots__ = ["domain", "address"] + + def __init__(self, rdclass, rdtype, domain, address): + super().__init__(rdclass, rdtype) + self.domain = self._as_name(domain) + self.address = self._as_uint16(address) + + def to_text(self, origin=None, relativize=True, **kw): + domain = self.domain.choose_relativity(origin, relativize) + return "%s %o" % (domain, self.address) + + @classmethod + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): + domain = tok.get_name(origin, relativize, relativize_to) + address = tok.get_uint16(base=8) + return cls(rdclass, rdtype, domain, address) + + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): + self.domain.to_wire(file, compress, origin, canonicalize) + pref = struct.pack("!H", self.address) + file.write(pref) + + @classmethod + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + domain = parser.get_name(origin) + address = parser.get_uint16() + return cls(rdclass, rdtype, domain, address) diff --git a/venv/Lib/site-packages/dns/rdtypes/CH/__init__.py b/venv/Lib/site-packages/dns/rdtypes/CH/__init__.py new file mode 100644 index 00000000..0760c26c --- /dev/null +++ b/venv/Lib/site-packages/dns/rdtypes/CH/__init__.py @@ -0,0 +1,22 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +"""Class CH rdata type classes.""" + +__all__ = [ + "A", +] diff --git a/venv/Lib/site-packages/dns/rdtypes/CH/__pycache__/A.cpython-312.pyc b/venv/Lib/site-packages/dns/rdtypes/CH/__pycache__/A.cpython-312.pyc new file mode 100644 index 00000000..c542f846 Binary files /dev/null and b/venv/Lib/site-packages/dns/rdtypes/CH/__pycache__/A.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/rdtypes/CH/__pycache__/__init__.cpython-312.pyc b/venv/Lib/site-packages/dns/rdtypes/CH/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 00000000..c11a1d1a Binary files /dev/null and b/venv/Lib/site-packages/dns/rdtypes/CH/__pycache__/__init__.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/rdtypes/IN/A.py b/venv/Lib/site-packages/dns/rdtypes/IN/A.py new file mode 100644 index 00000000..e09d6110 --- /dev/null +++ b/venv/Lib/site-packages/dns/rdtypes/IN/A.py @@ -0,0 +1,51 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. 
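A quick sketch of the octal presentation format implemented by the Chaosnet A class above (a hedged example assuming dnspython is importable; the name `ns.example.` is illustrative):

```python
import dns.rdata

# CH-class A rdata is "<domain> <16-bit address>", with the address written
# in octal: from_text() calls tok.get_uint16(base=8) and to_text() uses "%o".
rdata = dns.rdata.from_text("CH", "A", "ns.example. 3104")
print(rdata.address)    # 1604 (decimal value of octal 3104)
print(rdata.to_text())  # "ns.example. 3104"
```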
+# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +import dns.exception +import dns.immutable +import dns.ipv4 +import dns.rdata +import dns.tokenizer + + +@dns.immutable.immutable +class A(dns.rdata.Rdata): + """A record.""" + + __slots__ = ["address"] + + def __init__(self, rdclass, rdtype, address): + super().__init__(rdclass, rdtype) + self.address = self._as_ipv4_address(address) + + def to_text(self, origin=None, relativize=True, **kw): + return self.address + + @classmethod + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): + address = tok.get_identifier() + return cls(rdclass, rdtype, address) + + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): + file.write(dns.ipv4.inet_aton(self.address)) + + @classmethod + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + address = parser.get_remaining() + return cls(rdclass, rdtype, address) diff --git a/venv/Lib/site-packages/dns/rdtypes/IN/AAAA.py b/venv/Lib/site-packages/dns/rdtypes/IN/AAAA.py new file mode 100644 index 00000000..0cd139e7 --- /dev/null +++ b/venv/Lib/site-packages/dns/rdtypes/IN/AAAA.py @@ -0,0 +1,51 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
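For comparison, the IN-class A record above validates its dotted-quad at construction; a minimal round-trip sketch (assuming a standard dnspython install):

```python
import dns.rdata

# _as_ipv4_address() rejects malformed addresses when the rdata is built,
# so bad data never reaches the wire encoder.
a = dns.rdata.from_text("IN", "A", "192.0.2.1")
print(a.address)    # "192.0.2.1"
print(a.to_wire())  # b"\xc0\x00\x02\x01"
```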
+ +import dns.exception +import dns.immutable +import dns.ipv6 +import dns.rdata +import dns.tokenizer + + +@dns.immutable.immutable +class AAAA(dns.rdata.Rdata): + """AAAA record.""" + + __slots__ = ["address"] + + def __init__(self, rdclass, rdtype, address): + super().__init__(rdclass, rdtype) + self.address = self._as_ipv6_address(address) + + def to_text(self, origin=None, relativize=True, **kw): + return self.address + + @classmethod + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): + address = tok.get_identifier() + return cls(rdclass, rdtype, address) + + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): + file.write(dns.ipv6.inet_aton(self.address)) + + @classmethod + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + address = parser.get_remaining() + return cls(rdclass, rdtype, address) diff --git a/venv/Lib/site-packages/dns/rdtypes/IN/APL.py b/venv/Lib/site-packages/dns/rdtypes/IN/APL.py new file mode 100644 index 00000000..44cb3fef --- /dev/null +++ b/venv/Lib/site-packages/dns/rdtypes/IN/APL.py @@ -0,0 +1,150 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2017 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +import binascii +import codecs +import struct + +import dns.exception +import dns.immutable +import dns.ipv4 +import dns.ipv6 +import dns.rdata +import dns.tokenizer + + +@dns.immutable.immutable +class APLItem: + """An APL list item.""" + + __slots__ = ["family", "negation", "address", "prefix"] + + def __init__(self, family, negation, address, prefix): + self.family = dns.rdata.Rdata._as_uint16(family) + self.negation = dns.rdata.Rdata._as_bool(negation) + if self.family == 1: + self.address = dns.rdata.Rdata._as_ipv4_address(address) + self.prefix = dns.rdata.Rdata._as_int(prefix, 0, 32) + elif self.family == 2: + self.address = dns.rdata.Rdata._as_ipv6_address(address) + self.prefix = dns.rdata.Rdata._as_int(prefix, 0, 128) + else: + self.address = dns.rdata.Rdata._as_bytes(address, max_length=127) + self.prefix = dns.rdata.Rdata._as_uint8(prefix) + + def __str__(self): + if self.negation: + return "!%d:%s/%s" % (self.family, self.address, self.prefix) + else: + return "%d:%s/%s" % (self.family, self.address, self.prefix) + + def to_wire(self, file): + if self.family == 1: + address = dns.ipv4.inet_aton(self.address) + elif self.family == 2: + address = dns.ipv6.inet_aton(self.address) + else: + address = binascii.unhexlify(self.address) + # + # Truncate least significant zero bytes. 
+ # + last = 0 + for i in range(len(address) - 1, -1, -1): + if address[i] != 0: + last = i + 1 + break + address = address[0:last] + l = len(address) + assert l < 128 + if self.negation: + l |= 0x80 + header = struct.pack("!HBB", self.family, self.prefix, l) + file.write(header) + file.write(address) + + +@dns.immutable.immutable +class APL(dns.rdata.Rdata): + """APL record.""" + + # see: RFC 3123 + + __slots__ = ["items"] + + def __init__(self, rdclass, rdtype, items): + super().__init__(rdclass, rdtype) + for item in items: + if not isinstance(item, APLItem): + raise ValueError("item not an APLItem") + self.items = tuple(items) + + def to_text(self, origin=None, relativize=True, **kw): + return " ".join(map(str, self.items)) + + @classmethod + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): + items = [] + for token in tok.get_remaining(): + item = token.unescape().value + if item[0] == "!": + negation = True + item = item[1:] + else: + negation = False + (family, rest) = item.split(":", 1) + family = int(family) + (address, prefix) = rest.split("/", 1) + prefix = int(prefix) + item = APLItem(family, negation, address, prefix) + items.append(item) + + return cls(rdclass, rdtype, items) + + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): + for item in self.items: + item.to_wire(file) + + @classmethod + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + items = [] + while parser.remaining() > 0: + header = parser.get_struct("!HBB") + afdlen = header[2] + if afdlen > 127: + negation = True + afdlen -= 128 + else: + negation = False + address = parser.get_bytes(afdlen) + l = len(address) + if header[0] == 1: + if l < 4: + address += b"\x00" * (4 - l) + elif header[0] == 2: + if l < 16: + address += b"\x00" * (16 - l) + else: + # + # This isn't really right according to the RFC, but it + # seems better than throwing an exception + # + address = codecs.encode(address, "hex_codec") + item = APLItem(header[0], negation, address, header[1]) + items.append(item) + return cls(rdclass, rdtype, items) diff --git a/venv/Lib/site-packages/dns/rdtypes/IN/DHCID.py b/venv/Lib/site-packages/dns/rdtypes/IN/DHCID.py new file mode 100644 index 00000000..723492fa --- /dev/null +++ b/venv/Lib/site-packages/dns/rdtypes/IN/DHCID.py @@ -0,0 +1,54 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2006, 2007, 2009-2011 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
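Before the DHCID body below, a short sketch of the RFC 3123 presentation syntax handled by the APL class above; "!" marks negation, exactly as APLItem.__str__ emits it (addresses are illustrative):

```python
import dns.rdata

# Each item is [!]<family>:<address>/<prefix>; family 1 is IPv4, 2 is IPv6.
apl = dns.rdata.from_text("IN", "APL", "1:10.0.0.0/8 !2:2001:db8::/32")
for item in apl.items:
    print(item.family, item.negation, item.address, item.prefix)
# to_wire() truncates trailing zero octets of each address, per APLItem.to_wire().
wire = apl.to_wire()
```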
+ +import base64 + +import dns.exception +import dns.immutable +import dns.rdata + + +@dns.immutable.immutable +class DHCID(dns.rdata.Rdata): + """DHCID record""" + + # see: RFC 4701 + + __slots__ = ["data"] + + def __init__(self, rdclass, rdtype, data): + super().__init__(rdclass, rdtype) + self.data = self._as_bytes(data) + + def to_text(self, origin=None, relativize=True, **kw): + return dns.rdata._base64ify(self.data, **kw) + + @classmethod + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): + b64 = tok.concatenate_remaining_identifiers().encode() + data = base64.b64decode(b64) + return cls(rdclass, rdtype, data) + + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): + file.write(self.data) + + @classmethod + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + data = parser.get_remaining() + return cls(rdclass, rdtype, data) diff --git a/venv/Lib/site-packages/dns/rdtypes/IN/HTTPS.py b/venv/Lib/site-packages/dns/rdtypes/IN/HTTPS.py new file mode 100644 index 00000000..15464cbd --- /dev/null +++ b/venv/Lib/site-packages/dns/rdtypes/IN/HTTPS.py @@ -0,0 +1,9 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +import dns.immutable +import dns.rdtypes.svcbbase + + +@dns.immutable.immutable +class HTTPS(dns.rdtypes.svcbbase.SVCBBase): + """HTTPS record""" diff --git a/venv/Lib/site-packages/dns/rdtypes/IN/IPSECKEY.py b/venv/Lib/site-packages/dns/rdtypes/IN/IPSECKEY.py new file mode 100644 index 00000000..e3a66157 --- /dev/null +++ b/venv/Lib/site-packages/dns/rdtypes/IN/IPSECKEY.py @@ -0,0 +1,91 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2006, 2007, 2009-2011 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
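The DHCID type above treats its rdata as opaque bytes whose presentation form is plain base64 (RFC 4701); a sketch with a made-up payload:

```python
import base64

import dns.rdata

# The payload below is fabricated purely for illustration.
payload = base64.b64encode(b"\x00\x01\x02example-dhcid").decode()
rdata = dns.rdata.from_text("IN", "DHCID", payload)
assert rdata.data == base64.b64decode(payload)
```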
+ +import base64 +import struct + +import dns.exception +import dns.immutable +import dns.rdtypes.util + + +class Gateway(dns.rdtypes.util.Gateway): + name = "IPSECKEY gateway" + + +@dns.immutable.immutable +class IPSECKEY(dns.rdata.Rdata): + """IPSECKEY record""" + + # see: RFC 4025 + + __slots__ = ["precedence", "gateway_type", "algorithm", "gateway", "key"] + + def __init__( + self, rdclass, rdtype, precedence, gateway_type, algorithm, gateway, key + ): + super().__init__(rdclass, rdtype) + gateway = Gateway(gateway_type, gateway) + self.precedence = self._as_uint8(precedence) + self.gateway_type = gateway.type + self.algorithm = self._as_uint8(algorithm) + self.gateway = gateway.gateway + self.key = self._as_bytes(key) + + def to_text(self, origin=None, relativize=True, **kw): + gateway = Gateway(self.gateway_type, self.gateway).to_text(origin, relativize) + return "%d %d %d %s %s" % ( + self.precedence, + self.gateway_type, + self.algorithm, + gateway, + dns.rdata._base64ify(self.key, **kw), + ) + + @classmethod + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): + precedence = tok.get_uint8() + gateway_type = tok.get_uint8() + algorithm = tok.get_uint8() + gateway = Gateway.from_text( + gateway_type, tok, origin, relativize, relativize_to + ) + b64 = tok.concatenate_remaining_identifiers().encode() + key = base64.b64decode(b64) + return cls( + rdclass, rdtype, precedence, gateway_type, algorithm, gateway.gateway, key + ) + + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): + header = struct.pack("!BBB", self.precedence, self.gateway_type, self.algorithm) + file.write(header) + Gateway(self.gateway_type, self.gateway).to_wire( + file, compress, origin, canonicalize + ) + file.write(self.key) + + @classmethod + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + header = parser.get_struct("!BBB") + gateway_type = header[1] + gateway = Gateway.from_wire_parser(gateway_type, parser, origin) + key = parser.get_remaining() + return cls( + rdclass, rdtype, header[0], gateway_type, header[2], gateway.gateway, key + ) diff --git a/venv/Lib/site-packages/dns/rdtypes/IN/KX.py b/venv/Lib/site-packages/dns/rdtypes/IN/KX.py new file mode 100644 index 00000000..6073df47 --- /dev/null +++ b/venv/Lib/site-packages/dns/rdtypes/IN/KX.py @@ -0,0 +1,24 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
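Exercising the IPSECKEY class above with the sample rdata from RFC 4025 (gateway-type 1 is an IPv4 gateway; 0 means none, 2 IPv6, 3 a domain name):

```python
import dns.rdata

rdata = dns.rdata.from_text(
    "IN",
    "IPSECKEY",
    "10 1 2 192.0.2.38 AQNRU3mG7TVTO2BkR47usntb102uFJtugbo6BSGvgqt4AQ==",
)
# precedence, gateway type, and the parsed gateway address
print(rdata.precedence, rdata.gateway_type, rdata.gateway)
```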
+ +import dns.immutable +import dns.rdtypes.mxbase + + +@dns.immutable.immutable +class KX(dns.rdtypes.mxbase.UncompressedDowncasingMX): + """KX record""" diff --git a/venv/Lib/site-packages/dns/rdtypes/IN/NAPTR.py b/venv/Lib/site-packages/dns/rdtypes/IN/NAPTR.py new file mode 100644 index 00000000..195d1cba --- /dev/null +++ b/venv/Lib/site-packages/dns/rdtypes/IN/NAPTR.py @@ -0,0 +1,110 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +import struct + +import dns.exception +import dns.immutable +import dns.name +import dns.rdata +import dns.rdtypes.util + + +def _write_string(file, s): + l = len(s) + assert l < 256 + file.write(struct.pack("!B", l)) + file.write(s) + + +@dns.immutable.immutable +class NAPTR(dns.rdata.Rdata): + """NAPTR record""" + + # see: RFC 3403 + + __slots__ = ["order", "preference", "flags", "service", "regexp", "replacement"] + + def __init__( + self, rdclass, rdtype, order, preference, flags, service, regexp, replacement + ): + super().__init__(rdclass, rdtype) + self.flags = self._as_bytes(flags, True, 255) + self.service = self._as_bytes(service, True, 255) + self.regexp = self._as_bytes(regexp, True, 255) + self.order = self._as_uint16(order) + self.preference = self._as_uint16(preference) + self.replacement = self._as_name(replacement) + + def to_text(self, origin=None, relativize=True, **kw): + replacement = self.replacement.choose_relativity(origin, relativize) + return '%d %d "%s" "%s" "%s" %s' % ( + self.order, + self.preference, + dns.rdata._escapify(self.flags), + dns.rdata._escapify(self.service), + dns.rdata._escapify(self.regexp), + replacement, + ) + + @classmethod + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): + order = tok.get_uint16() + preference = tok.get_uint16() + flags = tok.get_string() + service = tok.get_string() + regexp = tok.get_string() + replacement = tok.get_name(origin, relativize, relativize_to) + return cls( + rdclass, rdtype, order, preference, flags, service, regexp, replacement + ) + + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): + two_ints = struct.pack("!HH", self.order, self.preference) + file.write(two_ints) + _write_string(file, self.flags) + _write_string(file, self.service) + _write_string(file, self.regexp) + self.replacement.to_wire(file, compress, origin, canonicalize) + + @classmethod + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + (order, preference) = parser.get_struct("!HH") + strings = [] + for _ in range(3): + s = parser.get_counted_bytes() + strings.append(s) + replacement = parser.get_name(origin) + return cls( + rdclass, + rdtype, + order, + preference, + strings[0], + strings[1], + 
strings[2], + replacement, + ) + + def _processing_priority(self): + return (self.order, self.preference) + + @classmethod + def _processing_order(cls, iterable): + return dns.rdtypes.util.priority_processing_order(iterable) diff --git a/venv/Lib/site-packages/dns/rdtypes/IN/NSAP.py b/venv/Lib/site-packages/dns/rdtypes/IN/NSAP.py new file mode 100644 index 00000000..a4854b3f --- /dev/null +++ b/venv/Lib/site-packages/dns/rdtypes/IN/NSAP.py @@ -0,0 +1,60 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +import binascii + +import dns.exception +import dns.immutable +import dns.rdata +import dns.tokenizer + + +@dns.immutable.immutable +class NSAP(dns.rdata.Rdata): + """NSAP record.""" + + # see: RFC 1706 + + __slots__ = ["address"] + + def __init__(self, rdclass, rdtype, address): + super().__init__(rdclass, rdtype) + self.address = self._as_bytes(address) + + def to_text(self, origin=None, relativize=True, **kw): + return "0x%s" % binascii.hexlify(self.address).decode() + + @classmethod + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): + address = tok.get_string() + if address[0:2] != "0x": + raise dns.exception.SyntaxError("string does not start with 0x") + address = address[2:].replace(".", "") + if len(address) % 2 != 0: + raise dns.exception.SyntaxError("hexstring has odd length") + address = binascii.unhexlify(address.encode()) + return cls(rdclass, rdtype, address) + + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): + file.write(self.address) + + @classmethod + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + address = parser.get_remaining() + return cls(rdclass, rdtype, address) diff --git a/venv/Lib/site-packages/dns/rdtypes/IN/NSAP_PTR.py b/venv/Lib/site-packages/dns/rdtypes/IN/NSAP_PTR.py new file mode 100644 index 00000000..ce1c6632 --- /dev/null +++ b/venv/Lib/site-packages/dns/rdtypes/IN/NSAP_PTR.py @@ -0,0 +1,24 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. 
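A sketch of the NSAP parsing above, using the RFC 1706 example address; the leading "0x" is mandatory, and the interior dots are cosmetic separators stripped before unhexlify():

```python
import dns.rdata

nsap = dns.rdata.from_text(
    "IN", "NSAP", "0x47.0005.80.005a00.0000.0001.e133.ffffff000162.00"
)
print(nsap.to_text())  # re-hexified without the separating dots
```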
IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +import dns.immutable +import dns.rdtypes.nsbase + + +@dns.immutable.immutable +class NSAP_PTR(dns.rdtypes.nsbase.UncompressedNS): + """NSAP-PTR record""" diff --git a/venv/Lib/site-packages/dns/rdtypes/IN/PX.py b/venv/Lib/site-packages/dns/rdtypes/IN/PX.py new file mode 100644 index 00000000..cdca1532 --- /dev/null +++ b/venv/Lib/site-packages/dns/rdtypes/IN/PX.py @@ -0,0 +1,73 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +import struct + +import dns.exception +import dns.immutable +import dns.name +import dns.rdata +import dns.rdtypes.util + + +@dns.immutable.immutable +class PX(dns.rdata.Rdata): + """PX record.""" + + # see: RFC 2163 + + __slots__ = ["preference", "map822", "mapx400"] + + def __init__(self, rdclass, rdtype, preference, map822, mapx400): + super().__init__(rdclass, rdtype) + self.preference = self._as_uint16(preference) + self.map822 = self._as_name(map822) + self.mapx400 = self._as_name(mapx400) + + def to_text(self, origin=None, relativize=True, **kw): + map822 = self.map822.choose_relativity(origin, relativize) + mapx400 = self.mapx400.choose_relativity(origin, relativize) + return "%d %s %s" % (self.preference, map822, mapx400) + + @classmethod + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): + preference = tok.get_uint16() + map822 = tok.get_name(origin, relativize, relativize_to) + mapx400 = tok.get_name(origin, relativize, relativize_to) + return cls(rdclass, rdtype, preference, map822, mapx400) + + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): + pref = struct.pack("!H", self.preference) + file.write(pref) + self.map822.to_wire(file, None, origin, canonicalize) + self.mapx400.to_wire(file, None, origin, canonicalize) + + @classmethod + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + preference = parser.get_uint16() + map822 = parser.get_name(origin) + mapx400 = parser.get_name(origin) + return cls(rdclass, rdtype, preference, map822, mapx400) + + def _processing_priority(self): + return self.preference + + @classmethod + def _processing_order(cls, iterable): + return dns.rdtypes.util.priority_processing_order(iterable) diff --git a/venv/Lib/site-packages/dns/rdtypes/IN/SRV.py b/venv/Lib/site-packages/dns/rdtypes/IN/SRV.py new file mode 100644 index 00000000..5adef98f --- /dev/null +++ 
b/venv/Lib/site-packages/dns/rdtypes/IN/SRV.py @@ -0,0 +1,75 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +import struct + +import dns.exception +import dns.immutable +import dns.name +import dns.rdata +import dns.rdtypes.util + + +@dns.immutable.immutable +class SRV(dns.rdata.Rdata): + """SRV record""" + + # see: RFC 2782 + + __slots__ = ["priority", "weight", "port", "target"] + + def __init__(self, rdclass, rdtype, priority, weight, port, target): + super().__init__(rdclass, rdtype) + self.priority = self._as_uint16(priority) + self.weight = self._as_uint16(weight) + self.port = self._as_uint16(port) + self.target = self._as_name(target) + + def to_text(self, origin=None, relativize=True, **kw): + target = self.target.choose_relativity(origin, relativize) + return "%d %d %d %s" % (self.priority, self.weight, self.port, target) + + @classmethod + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): + priority = tok.get_uint16() + weight = tok.get_uint16() + port = tok.get_uint16() + target = tok.get_name(origin, relativize, relativize_to) + return cls(rdclass, rdtype, priority, weight, port, target) + + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): + three_ints = struct.pack("!HHH", self.priority, self.weight, self.port) + file.write(three_ints) + self.target.to_wire(file, compress, origin, canonicalize) + + @classmethod + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + (priority, weight, port) = parser.get_struct("!HHH") + target = parser.get_name(origin) + return cls(rdclass, rdtype, priority, weight, port, target) + + def _processing_priority(self): + return self.priority + + def _processing_weight(self): + return self.weight + + @classmethod + def _processing_order(cls, iterable): + return dns.rdtypes.util.weighted_processing_order(iterable) diff --git a/venv/Lib/site-packages/dns/rdtypes/IN/SVCB.py b/venv/Lib/site-packages/dns/rdtypes/IN/SVCB.py new file mode 100644 index 00000000..ff3e9327 --- /dev/null +++ b/venv/Lib/site-packages/dns/rdtypes/IN/SVCB.py @@ -0,0 +1,9 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +import dns.immutable +import dns.rdtypes.svcbbase + + +@dns.immutable.immutable +class SVCB(dns.rdtypes.svcbbase.SVCBBase): + """SVCB record""" diff --git a/venv/Lib/site-packages/dns/rdtypes/IN/WKS.py b/venv/Lib/site-packages/dns/rdtypes/IN/WKS.py new file mode 100644 index 00000000..881a7849 --- /dev/null +++ b/venv/Lib/site-packages/dns/rdtypes/IN/WKS.py @@ -0,0 +1,100 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. 
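A minimal round-trip through the SRV class above, using the RFC 2782 text form "<priority> <weight> <port> <target>" (the host name is illustrative); _processing_order() later uses the weights for randomized server selection:

```python
import dns.rdata

srv = dns.rdata.from_text("IN", "SRV", "10 60 5060 sip.example.com.")
print(srv.priority, srv.weight, srv.port, srv.target)  # 10 60 5060 sip.example.com.
```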
+# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +import socket +import struct + +import dns.immutable +import dns.ipv4 +import dns.rdata + +try: + _proto_tcp = socket.getprotobyname("tcp") + _proto_udp = socket.getprotobyname("udp") +except OSError: + # Fall back to defaults in case /etc/protocols is unavailable. + _proto_tcp = 6 + _proto_udp = 17 + + +@dns.immutable.immutable +class WKS(dns.rdata.Rdata): + """WKS record""" + + # see: RFC 1035 + + __slots__ = ["address", "protocol", "bitmap"] + + def __init__(self, rdclass, rdtype, address, protocol, bitmap): + super().__init__(rdclass, rdtype) + self.address = self._as_ipv4_address(address) + self.protocol = self._as_uint8(protocol) + self.bitmap = self._as_bytes(bitmap) + + def to_text(self, origin=None, relativize=True, **kw): + bits = [] + for i, byte in enumerate(self.bitmap): + for j in range(0, 8): + if byte & (0x80 >> j): + bits.append(str(i * 8 + j)) + text = " ".join(bits) + return "%s %d %s" % (self.address, self.protocol, text) + + @classmethod + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): + address = tok.get_string() + protocol = tok.get_string() + if protocol.isdigit(): + protocol = int(protocol) + else: + protocol = socket.getprotobyname(protocol) + bitmap = bytearray() + for token in tok.get_remaining(): + value = token.unescape().value + if value.isdigit(): + serv = int(value) + else: + if protocol != _proto_udp and protocol != _proto_tcp: + raise NotImplementedError("protocol must be TCP or UDP") + if protocol == _proto_udp: + protocol_text = "udp" + else: + protocol_text = "tcp" + serv = socket.getservbyname(value, protocol_text) + i = serv // 8 + l = len(bitmap) + if l < i + 1: + for _ in range(l, i + 1): + bitmap.append(0) + bitmap[i] = bitmap[i] | (0x80 >> (serv % 8)) + bitmap = dns.rdata._truncate_bitmap(bitmap) + return cls(rdclass, rdtype, address, protocol, bitmap) + + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): + file.write(dns.ipv4.inet_aton(self.address)) + protocol = struct.pack("!B", self.protocol) + file.write(protocol) + file.write(self.bitmap) + + @classmethod + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + address = parser.get_bytes(4) + protocol = parser.get_uint8() + bitmap = parser.get_remaining() + return cls(rdclass, rdtype, address, protocol, bitmap) diff --git a/venv/Lib/site-packages/dns/rdtypes/IN/__init__.py b/venv/Lib/site-packages/dns/rdtypes/IN/__init__.py new file mode 100644 index 00000000..dcec4dd2 --- /dev/null +++ b/venv/Lib/site-packages/dns/rdtypes/IN/__init__.py @@ -0,0 +1,35 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. 
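The WKS bitmap above sets bit (0x80 >> (port % 8)) of byte (port // 8) for each listed service; a sketch with SMTP and HTTP over TCP (address illustrative):

```python
import dns.rdata

wks = dns.rdata.from_text("IN", "WKS", "192.0.2.1 tcp 25 80")
# to_text() renders the protocol numerically (tcp == 6) and expands the bitmap.
print(wks.to_text())  # "192.0.2.1 6 25 80"
```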
+# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +"""Class IN rdata type classes.""" + +__all__ = [ + "A", + "AAAA", + "APL", + "DHCID", + "HTTPS", + "IPSECKEY", + "KX", + "NAPTR", + "NSAP", + "NSAP_PTR", + "PX", + "SRV", + "SVCB", + "WKS", +] diff --git a/venv/Lib/site-packages/dns/rdtypes/IN/__pycache__/A.cpython-312.pyc b/venv/Lib/site-packages/dns/rdtypes/IN/__pycache__/A.cpython-312.pyc new file mode 100644 index 00000000..9c7889ac Binary files /dev/null and b/venv/Lib/site-packages/dns/rdtypes/IN/__pycache__/A.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/rdtypes/IN/__pycache__/AAAA.cpython-312.pyc b/venv/Lib/site-packages/dns/rdtypes/IN/__pycache__/AAAA.cpython-312.pyc new file mode 100644 index 00000000..a1d38345 Binary files /dev/null and b/venv/Lib/site-packages/dns/rdtypes/IN/__pycache__/AAAA.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/rdtypes/IN/__pycache__/APL.cpython-312.pyc b/venv/Lib/site-packages/dns/rdtypes/IN/__pycache__/APL.cpython-312.pyc new file mode 100644 index 00000000..78cf9dbb Binary files /dev/null and b/venv/Lib/site-packages/dns/rdtypes/IN/__pycache__/APL.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/rdtypes/IN/__pycache__/DHCID.cpython-312.pyc b/venv/Lib/site-packages/dns/rdtypes/IN/__pycache__/DHCID.cpython-312.pyc new file mode 100644 index 00000000..4319c465 Binary files /dev/null and b/venv/Lib/site-packages/dns/rdtypes/IN/__pycache__/DHCID.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/rdtypes/IN/__pycache__/HTTPS.cpython-312.pyc b/venv/Lib/site-packages/dns/rdtypes/IN/__pycache__/HTTPS.cpython-312.pyc new file mode 100644 index 00000000..2e6f1730 Binary files /dev/null and b/venv/Lib/site-packages/dns/rdtypes/IN/__pycache__/HTTPS.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/rdtypes/IN/__pycache__/IPSECKEY.cpython-312.pyc b/venv/Lib/site-packages/dns/rdtypes/IN/__pycache__/IPSECKEY.cpython-312.pyc new file mode 100644 index 00000000..cb1af504 Binary files /dev/null and b/venv/Lib/site-packages/dns/rdtypes/IN/__pycache__/IPSECKEY.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/rdtypes/IN/__pycache__/KX.cpython-312.pyc b/venv/Lib/site-packages/dns/rdtypes/IN/__pycache__/KX.cpython-312.pyc new file mode 100644 index 00000000..79a966a3 Binary files /dev/null and b/venv/Lib/site-packages/dns/rdtypes/IN/__pycache__/KX.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/rdtypes/IN/__pycache__/NAPTR.cpython-312.pyc b/venv/Lib/site-packages/dns/rdtypes/IN/__pycache__/NAPTR.cpython-312.pyc new file mode 100644 index 00000000..e2c0b30a Binary files /dev/null and b/venv/Lib/site-packages/dns/rdtypes/IN/__pycache__/NAPTR.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/rdtypes/IN/__pycache__/NSAP.cpython-312.pyc 
b/venv/Lib/site-packages/dns/rdtypes/IN/__pycache__/NSAP.cpython-312.pyc new file mode 100644 index 00000000..29385757 Binary files /dev/null and b/venv/Lib/site-packages/dns/rdtypes/IN/__pycache__/NSAP.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/rdtypes/IN/__pycache__/NSAP_PTR.cpython-312.pyc b/venv/Lib/site-packages/dns/rdtypes/IN/__pycache__/NSAP_PTR.cpython-312.pyc new file mode 100644 index 00000000..d2368cdd Binary files /dev/null and b/venv/Lib/site-packages/dns/rdtypes/IN/__pycache__/NSAP_PTR.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/rdtypes/IN/__pycache__/PX.cpython-312.pyc b/venv/Lib/site-packages/dns/rdtypes/IN/__pycache__/PX.cpython-312.pyc new file mode 100644 index 00000000..d9c12109 Binary files /dev/null and b/venv/Lib/site-packages/dns/rdtypes/IN/__pycache__/PX.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/rdtypes/IN/__pycache__/SRV.cpython-312.pyc b/venv/Lib/site-packages/dns/rdtypes/IN/__pycache__/SRV.cpython-312.pyc new file mode 100644 index 00000000..d43d0e7d Binary files /dev/null and b/venv/Lib/site-packages/dns/rdtypes/IN/__pycache__/SRV.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/rdtypes/IN/__pycache__/SVCB.cpython-312.pyc b/venv/Lib/site-packages/dns/rdtypes/IN/__pycache__/SVCB.cpython-312.pyc new file mode 100644 index 00000000..82bacae5 Binary files /dev/null and b/venv/Lib/site-packages/dns/rdtypes/IN/__pycache__/SVCB.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/rdtypes/IN/__pycache__/WKS.cpython-312.pyc b/venv/Lib/site-packages/dns/rdtypes/IN/__pycache__/WKS.cpython-312.pyc new file mode 100644 index 00000000..7c04dc58 Binary files /dev/null and b/venv/Lib/site-packages/dns/rdtypes/IN/__pycache__/WKS.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/rdtypes/IN/__pycache__/__init__.cpython-312.pyc b/venv/Lib/site-packages/dns/rdtypes/IN/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 00000000..cb2d9ab3 Binary files /dev/null and b/venv/Lib/site-packages/dns/rdtypes/IN/__pycache__/__init__.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/rdtypes/__init__.py b/venv/Lib/site-packages/dns/rdtypes/__init__.py new file mode 100644 index 00000000..3997f84c --- /dev/null +++ b/venv/Lib/site-packages/dns/rdtypes/__init__.py @@ -0,0 +1,33 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
+ +"""DNS rdata type classes""" + +__all__ = [ + "ANY", + "IN", + "CH", + "dnskeybase", + "dsbase", + "euibase", + "mxbase", + "nsbase", + "svcbbase", + "tlsabase", + "txtbase", + "util", +] diff --git a/venv/Lib/site-packages/dns/rdtypes/__pycache__/__init__.cpython-312.pyc b/venv/Lib/site-packages/dns/rdtypes/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 00000000..f73a7bff Binary files /dev/null and b/venv/Lib/site-packages/dns/rdtypes/__pycache__/__init__.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/rdtypes/__pycache__/dnskeybase.cpython-312.pyc b/venv/Lib/site-packages/dns/rdtypes/__pycache__/dnskeybase.cpython-312.pyc new file mode 100644 index 00000000..171fdcbf Binary files /dev/null and b/venv/Lib/site-packages/dns/rdtypes/__pycache__/dnskeybase.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/rdtypes/__pycache__/dsbase.cpython-312.pyc b/venv/Lib/site-packages/dns/rdtypes/__pycache__/dsbase.cpython-312.pyc new file mode 100644 index 00000000..1542e209 Binary files /dev/null and b/venv/Lib/site-packages/dns/rdtypes/__pycache__/dsbase.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/rdtypes/__pycache__/euibase.cpython-312.pyc b/venv/Lib/site-packages/dns/rdtypes/__pycache__/euibase.cpython-312.pyc new file mode 100644 index 00000000..eb6e39ae Binary files /dev/null and b/venv/Lib/site-packages/dns/rdtypes/__pycache__/euibase.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/rdtypes/__pycache__/mxbase.cpython-312.pyc b/venv/Lib/site-packages/dns/rdtypes/__pycache__/mxbase.cpython-312.pyc new file mode 100644 index 00000000..da0f89a7 Binary files /dev/null and b/venv/Lib/site-packages/dns/rdtypes/__pycache__/mxbase.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/rdtypes/__pycache__/nsbase.cpython-312.pyc b/venv/Lib/site-packages/dns/rdtypes/__pycache__/nsbase.cpython-312.pyc new file mode 100644 index 00000000..9add3356 Binary files /dev/null and b/venv/Lib/site-packages/dns/rdtypes/__pycache__/nsbase.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/rdtypes/__pycache__/svcbbase.cpython-312.pyc b/venv/Lib/site-packages/dns/rdtypes/__pycache__/svcbbase.cpython-312.pyc new file mode 100644 index 00000000..9c9187d5 Binary files /dev/null and b/venv/Lib/site-packages/dns/rdtypes/__pycache__/svcbbase.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/rdtypes/__pycache__/tlsabase.cpython-312.pyc b/venv/Lib/site-packages/dns/rdtypes/__pycache__/tlsabase.cpython-312.pyc new file mode 100644 index 00000000..831c8333 Binary files /dev/null and b/venv/Lib/site-packages/dns/rdtypes/__pycache__/tlsabase.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/rdtypes/__pycache__/txtbase.cpython-312.pyc b/venv/Lib/site-packages/dns/rdtypes/__pycache__/txtbase.cpython-312.pyc new file mode 100644 index 00000000..875557f3 Binary files /dev/null and b/venv/Lib/site-packages/dns/rdtypes/__pycache__/txtbase.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/rdtypes/__pycache__/util.cpython-312.pyc b/venv/Lib/site-packages/dns/rdtypes/__pycache__/util.cpython-312.pyc new file mode 100644 index 00000000..d91827aa Binary files /dev/null and b/venv/Lib/site-packages/dns/rdtypes/__pycache__/util.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/dns/rdtypes/dnskeybase.py b/venv/Lib/site-packages/dns/rdtypes/dnskeybase.py new file mode 100644 index 00000000..db300f8b --- /dev/null +++ b/venv/Lib/site-packages/dns/rdtypes/dnskeybase.py @@ -0,0 +1,87 @@ +# Copyright (C) 
Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2004-2007, 2009-2011 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +import base64 +import enum +import struct + +import dns.dnssectypes +import dns.exception +import dns.immutable +import dns.rdata + +# wildcard import +__all__ = ["SEP", "REVOKE", "ZONE"] # noqa: F822 + + +class Flag(enum.IntFlag): + SEP = 0x0001 + REVOKE = 0x0080 + ZONE = 0x0100 + + +@dns.immutable.immutable +class DNSKEYBase(dns.rdata.Rdata): + """Base class for rdata that is like a DNSKEY record""" + + __slots__ = ["flags", "protocol", "algorithm", "key"] + + def __init__(self, rdclass, rdtype, flags, protocol, algorithm, key): + super().__init__(rdclass, rdtype) + self.flags = Flag(self._as_uint16(flags)) + self.protocol = self._as_uint8(protocol) + self.algorithm = dns.dnssectypes.Algorithm.make(algorithm) + self.key = self._as_bytes(key) + + def to_text(self, origin=None, relativize=True, **kw): + return "%d %d %d %s" % ( + self.flags, + self.protocol, + self.algorithm, + dns.rdata._base64ify(self.key, **kw), + ) + + @classmethod + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): + flags = tok.get_uint16() + protocol = tok.get_uint8() + algorithm = tok.get_string() + b64 = tok.concatenate_remaining_identifiers().encode() + key = base64.b64decode(b64) + return cls(rdclass, rdtype, flags, protocol, algorithm, key) + + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): + header = struct.pack("!HBB", self.flags, self.protocol, self.algorithm) + file.write(header) + file.write(self.key) + + @classmethod + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + header = parser.get_struct("!HBB") + key = parser.get_remaining() + return cls(rdclass, rdtype, header[0], header[1], header[2], key) + + +### BEGIN generated Flag constants + +SEP = Flag.SEP +REVOKE = Flag.REVOKE +ZONE = Flag.ZONE + +### END generated Flag constants diff --git a/venv/Lib/site-packages/dns/rdtypes/dsbase.py b/venv/Lib/site-packages/dns/rdtypes/dsbase.py new file mode 100644 index 00000000..cd21f026 --- /dev/null +++ b/venv/Lib/site-packages/dns/rdtypes/dsbase.py @@ -0,0 +1,85 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2010, 2011 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. 
IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +import binascii +import struct + +import dns.dnssectypes +import dns.immutable +import dns.rdata +import dns.rdatatype + + +@dns.immutable.immutable +class DSBase(dns.rdata.Rdata): + """Base class for rdata that is like a DS record""" + + __slots__ = ["key_tag", "algorithm", "digest_type", "digest"] + + # Digest types registry: + # https://www.iana.org/assignments/ds-rr-types/ds-rr-types.xhtml + _digest_length_by_type = { + 1: 20, # SHA-1, RFC 3658 Sec. 2.4 + 2: 32, # SHA-256, RFC 4509 Sec. 2.2 + 3: 32, # GOST R 34.11-94, RFC 5933 Sec. 4 in conjunction with RFC 4490 Sec. 2.1 + 4: 48, # SHA-384, RFC 6605 Sec. 2 + } + + def __init__(self, rdclass, rdtype, key_tag, algorithm, digest_type, digest): + super().__init__(rdclass, rdtype) + self.key_tag = self._as_uint16(key_tag) + self.algorithm = dns.dnssectypes.Algorithm.make(algorithm) + self.digest_type = dns.dnssectypes.DSDigest.make(self._as_uint8(digest_type)) + self.digest = self._as_bytes(digest) + try: + if len(self.digest) != self._digest_length_by_type[self.digest_type]: + raise ValueError("digest length inconsistent with digest type") + except KeyError: + if self.digest_type == 0: # reserved, RFC 3658 Sec. 2.4 + raise ValueError("digest type 0 is reserved") + + def to_text(self, origin=None, relativize=True, **kw): + kw = kw.copy() + chunksize = kw.pop("chunksize", 128) + return "%d %d %d %s" % ( + self.key_tag, + self.algorithm, + self.digest_type, + dns.rdata._hexify(self.digest, chunksize=chunksize, **kw), + ) + + @classmethod + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): + key_tag = tok.get_uint16() + algorithm = tok.get_string() + digest_type = tok.get_uint8() + digest = tok.concatenate_remaining_identifiers().encode() + digest = binascii.unhexlify(digest) + return cls(rdclass, rdtype, key_tag, algorithm, digest_type, digest) + + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): + header = struct.pack("!HBB", self.key_tag, self.algorithm, self.digest_type) + file.write(header) + file.write(self.digest) + + @classmethod + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + header = parser.get_struct("!HBB") + digest = parser.get_remaining() + return cls(rdclass, rdtype, header[0], header[1], header[2], digest) diff --git a/venv/Lib/site-packages/dns/rdtypes/euibase.py b/venv/Lib/site-packages/dns/rdtypes/euibase.py new file mode 100644 index 00000000..751087b4 --- /dev/null +++ b/venv/Lib/site-packages/dns/rdtypes/euibase.py @@ -0,0 +1,70 @@ +# Copyright (C) 2015 Red Hat, Inc. +# Author: Petr Spacek +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED 'AS IS' AND RED HAT DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. 
IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +import binascii + +import dns.immutable +import dns.rdata + + +@dns.immutable.immutable +class EUIBase(dns.rdata.Rdata): + """EUIxx record""" + + # see: rfc7043.txt + + __slots__ = ["eui"] + # define these in subclasses + # byte_len = 6 # 0123456789ab (in hex) + # text_len = byte_len * 3 - 1 # 01-23-45-67-89-ab + + def __init__(self, rdclass, rdtype, eui): + super().__init__(rdclass, rdtype) + self.eui = self._as_bytes(eui) + if len(self.eui) != self.byte_len: + raise dns.exception.FormError( + "EUI%s rdata has to have %s bytes" % (self.byte_len * 8, self.byte_len) + ) + + def to_text(self, origin=None, relativize=True, **kw): + return dns.rdata._hexify(self.eui, chunksize=2, separator=b"-", **kw) + + @classmethod + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): + text = tok.get_string() + if len(text) != cls.text_len: + raise dns.exception.SyntaxError( + "Input text must have %s characters" % cls.text_len + ) + for i in range(2, cls.byte_len * 3 - 1, 3): + if text[i] != "-": + raise dns.exception.SyntaxError("Dash expected at position %s" % i) + text = text.replace("-", "") + try: + data = binascii.unhexlify(text.encode()) + except (ValueError, TypeError) as ex: + raise dns.exception.SyntaxError("Hex decoding error: %s" % str(ex)) + return cls(rdclass, rdtype, data) + + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): + file.write(self.eui) + + @classmethod + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + eui = parser.get_bytes(cls.byte_len) + return cls(rdclass, rdtype, eui) diff --git a/venv/Lib/site-packages/dns/rdtypes/mxbase.py b/venv/Lib/site-packages/dns/rdtypes/mxbase.py new file mode 100644 index 00000000..6d5e3d87 --- /dev/null +++ b/venv/Lib/site-packages/dns/rdtypes/mxbase.py @@ -0,0 +1,87 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
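The EUIBase class just defined pins down both representations: raw bytes on the wire and dash-separated hex pairs in text. A minimal round-trip sketch (assuming this vendored dnspython is importable and that the EUI48 subclass sets byte_len = 6, as the comments above suggest):

import dns.rdata
import dns.rdataclass
import dns.rdatatype

# from_text() checks the dash positions and hex digits before handing
# raw bytes to the constructor.
rd = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.EUI48, "01-23-45-67-89-ab")
assert rd.eui == bytes.fromhex("0123456789ab")
# to_text() re-hexifies the bytes with a "-" separator every two hex digits.
assert rd.to_text() == "01-23-45-67-89-ab"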
+ +"""MX-like base classes.""" + +import struct + +import dns.exception +import dns.immutable +import dns.name +import dns.rdata +import dns.rdtypes.util + + +@dns.immutable.immutable +class MXBase(dns.rdata.Rdata): + """Base class for rdata that is like an MX record.""" + + __slots__ = ["preference", "exchange"] + + def __init__(self, rdclass, rdtype, preference, exchange): + super().__init__(rdclass, rdtype) + self.preference = self._as_uint16(preference) + self.exchange = self._as_name(exchange) + + def to_text(self, origin=None, relativize=True, **kw): + exchange = self.exchange.choose_relativity(origin, relativize) + return "%d %s" % (self.preference, exchange) + + @classmethod + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): + preference = tok.get_uint16() + exchange = tok.get_name(origin, relativize, relativize_to) + return cls(rdclass, rdtype, preference, exchange) + + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): + pref = struct.pack("!H", self.preference) + file.write(pref) + self.exchange.to_wire(file, compress, origin, canonicalize) + + @classmethod + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + preference = parser.get_uint16() + exchange = parser.get_name(origin) + return cls(rdclass, rdtype, preference, exchange) + + def _processing_priority(self): + return self.preference + + @classmethod + def _processing_order(cls, iterable): + return dns.rdtypes.util.priority_processing_order(iterable) + + +@dns.immutable.immutable +class UncompressedMX(MXBase): + """Base class for rdata that is like an MX record, but whose name + is not compressed when converted to DNS wire format, and whose + digestable form is not downcased.""" + + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): + super()._to_wire(file, None, origin, False) + + +@dns.immutable.immutable +class UncompressedDowncasingMX(MXBase): + """Base class for rdata that is like an MX record, but whose name + is not compressed when convert to DNS wire format.""" + + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): + super()._to_wire(file, None, origin, canonicalize) diff --git a/venv/Lib/site-packages/dns/rdtypes/nsbase.py b/venv/Lib/site-packages/dns/rdtypes/nsbase.py new file mode 100644 index 00000000..904224f0 --- /dev/null +++ b/venv/Lib/site-packages/dns/rdtypes/nsbase.py @@ -0,0 +1,63 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
+ +"""NS-like base classes.""" + +import dns.exception +import dns.immutable +import dns.name +import dns.rdata + + +@dns.immutable.immutable +class NSBase(dns.rdata.Rdata): + """Base class for rdata that is like an NS record.""" + + __slots__ = ["target"] + + def __init__(self, rdclass, rdtype, target): + super().__init__(rdclass, rdtype) + self.target = self._as_name(target) + + def to_text(self, origin=None, relativize=True, **kw): + target = self.target.choose_relativity(origin, relativize) + return str(target) + + @classmethod + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): + target = tok.get_name(origin, relativize, relativize_to) + return cls(rdclass, rdtype, target) + + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): + self.target.to_wire(file, compress, origin, canonicalize) + + @classmethod + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + target = parser.get_name(origin) + return cls(rdclass, rdtype, target) + + +@dns.immutable.immutable +class UncompressedNS(NSBase): + """Base class for rdata that is like an NS record, but whose name + is not compressed when convert to DNS wire format, and whose + digestable form is not downcased.""" + + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): + self.target.to_wire(file, None, origin, False) diff --git a/venv/Lib/site-packages/dns/rdtypes/svcbbase.py b/venv/Lib/site-packages/dns/rdtypes/svcbbase.py new file mode 100644 index 00000000..05652413 --- /dev/null +++ b/venv/Lib/site-packages/dns/rdtypes/svcbbase.py @@ -0,0 +1,553 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +import base64 +import enum +import struct + +import dns.enum +import dns.exception +import dns.immutable +import dns.ipv4 +import dns.ipv6 +import dns.name +import dns.rdata +import dns.rdtypes.util +import dns.renderer +import dns.tokenizer +import dns.wire + +# Until there is an RFC, this module is experimental and may be changed in +# incompatible ways. + + +class UnknownParamKey(dns.exception.DNSException): + """Unknown SVCB ParamKey""" + + +class ParamKey(dns.enum.IntEnum): + """SVCB ParamKey""" + + MANDATORY = 0 + ALPN = 1 + NO_DEFAULT_ALPN = 2 + PORT = 3 + IPV4HINT = 4 + ECH = 5 + IPV6HINT = 6 + DOHPATH = 7 + + @classmethod + def _maximum(cls): + return 65535 + + @classmethod + def _short_name(cls): + return "SVCBParamKey" + + @classmethod + def _prefix(cls): + return "KEY" + + @classmethod + def _unknown_exception_class(cls): + return UnknownParamKey + + +class Emptiness(enum.IntEnum): + NEVER = 0 + ALWAYS = 1 + ALLOWED = 2 + + +def _validate_key(key): + force_generic = False + if isinstance(key, bytes): + # We decode to latin-1 so we get 0-255 as valid and do NOT interpret + # UTF-8 sequences + key = key.decode("latin-1") + if isinstance(key, str): + if key.lower().startswith("key"): + force_generic = True + if key[3:].startswith("0") and len(key) != 4: + # key has leading zeros + raise ValueError("leading zeros in key") + key = key.replace("-", "_") + return (ParamKey.make(key), force_generic) + + +def key_to_text(key): + return ParamKey.to_text(key).replace("_", "-").lower() + + +# Like rdata escapify, but escapes ',' too. 
+ +_escaped = b'",\\' + + +def _escapify(qstring): + text = "" + for c in qstring: + if c in _escaped: + text += "\\" + chr(c) + elif c >= 0x20 and c < 0x7F: + text += chr(c) + else: + text += "\\%03d" % c + return text + + +def _unescape(value): + if value == "": + return value + unescaped = b"" + l = len(value) + i = 0 + while i < l: + c = value[i] + i += 1 + if c == "\\": + if i >= l: # pragma: no cover (can't happen via tokenizer get()) + raise dns.exception.UnexpectedEnd + c = value[i] + i += 1 + if c.isdigit(): + if i >= l: + raise dns.exception.UnexpectedEnd + c2 = value[i] + i += 1 + if i >= l: + raise dns.exception.UnexpectedEnd + c3 = value[i] + i += 1 + if not (c2.isdigit() and c3.isdigit()): + raise dns.exception.SyntaxError + codepoint = int(c) * 100 + int(c2) * 10 + int(c3) + if codepoint > 255: + raise dns.exception.SyntaxError + unescaped += b"%c" % (codepoint) + continue + unescaped += c.encode() + return unescaped + + +def _split(value): + l = len(value) + i = 0 + items = [] + unescaped = b"" + while i < l: + c = value[i] + i += 1 + if c == ord("\\"): + if i >= l: # pragma: no cover (can't happen via tokenizer get()) + raise dns.exception.UnexpectedEnd + c = value[i] + i += 1 + unescaped += b"%c" % (c) + elif c == ord(","): + items.append(unescaped) + unescaped = b"" + else: + unescaped += b"%c" % (c) + items.append(unescaped) + return items + + +@dns.immutable.immutable +class Param: + """Abstract base class for SVCB parameters""" + + @classmethod + def emptiness(cls): + return Emptiness.NEVER + + +@dns.immutable.immutable +class GenericParam(Param): + """Generic SVCB parameter""" + + def __init__(self, value): + self.value = dns.rdata.Rdata._as_bytes(value, True) + + @classmethod + def emptiness(cls): + return Emptiness.ALLOWED + + @classmethod + def from_value(cls, value): + if value is None or len(value) == 0: + return None + else: + return cls(_unescape(value)) + + def to_text(self): + return '"' + dns.rdata._escapify(self.value) + '"' + + @classmethod + def from_wire_parser(cls, parser, origin=None): # pylint: disable=W0613 + value = parser.get_bytes(parser.remaining()) + if len(value) == 0: + return None + else: + return cls(value) + + def to_wire(self, file, origin=None): # pylint: disable=W0613 + file.write(self.value) + + +@dns.immutable.immutable +class MandatoryParam(Param): + def __init__(self, keys): + # check for duplicates + keys = sorted([_validate_key(key)[0] for key in keys]) + prior_k = None + for k in keys: + if k == prior_k: + raise ValueError(f"duplicate key {k:d}") + prior_k = k + if k == ParamKey.MANDATORY: + raise ValueError("listed the mandatory key as mandatory") + self.keys = tuple(keys) + + @classmethod + def from_value(cls, value): + keys = [k.encode() for k in value.split(",")] + return cls(keys) + + def to_text(self): + return '"' + ",".join([key_to_text(key) for key in self.keys]) + '"' + + @classmethod + def from_wire_parser(cls, parser, origin=None): # pylint: disable=W0613 + keys = [] + last_key = -1 + while parser.remaining() > 0: + key = parser.get_uint16() + if key < last_key: + raise dns.exception.FormError("mandatory keys not ascending") + last_key = key + keys.append(key) + return cls(keys) + + def to_wire(self, file, origin=None): # pylint: disable=W0613 + for key in self.keys: + file.write(struct.pack("!H", key)) + + +@dns.immutable.immutable +class ALPNParam(Param): + def __init__(self, ids): + self.ids = dns.rdata.Rdata._as_tuple( + ids, lambda x: dns.rdata.Rdata._as_bytes(x, True, 255, False) + ) + + @classmethod + def
from_value(cls, value): + return cls(_split(_unescape(value))) + + def to_text(self): + value = ",".join([_escapify(id) for id in self.ids]) + return '"' + dns.rdata._escapify(value.encode()) + '"' + + @classmethod + def from_wire_parser(cls, parser, origin=None): # pylint: disable=W0613 + ids = [] + while parser.remaining() > 0: + id = parser.get_counted_bytes() + ids.append(id) + return cls(ids) + + def to_wire(self, file, origin=None): # pylint: disable=W0613 + for id in self.ids: + file.write(struct.pack("!B", len(id))) + file.write(id) + + +@dns.immutable.immutable +class NoDefaultALPNParam(Param): + # We don't ever expect to instantiate this class, but we need + # a from_value() and a from_wire_parser(), so we just return None + # from the class methods when things are OK. + + @classmethod + def emptiness(cls): + return Emptiness.ALWAYS + + @classmethod + def from_value(cls, value): + if value is None or value == "": + return None + else: + raise ValueError("no-default-alpn with non-empty value") + + def to_text(self): + raise NotImplementedError # pragma: no cover + + @classmethod + def from_wire_parser(cls, parser, origin=None): # pylint: disable=W0613 + if parser.remaining() != 0: + raise dns.exception.FormError + return None + + def to_wire(self, file, origin=None): # pylint: disable=W0613 + raise NotImplementedError # pragma: no cover + + +@dns.immutable.immutable +class PortParam(Param): + def __init__(self, port): + self.port = dns.rdata.Rdata._as_uint16(port) + + @classmethod + def from_value(cls, value): + value = int(value) + return cls(value) + + def to_text(self): + return f'"{self.port}"' + + @classmethod + def from_wire_parser(cls, parser, origin=None): # pylint: disable=W0613 + port = parser.get_uint16() + return cls(port) + + def to_wire(self, file, origin=None): # pylint: disable=W0613 + file.write(struct.pack("!H", self.port)) + + +@dns.immutable.immutable +class IPv4HintParam(Param): + def __init__(self, addresses): + self.addresses = dns.rdata.Rdata._as_tuple( + addresses, dns.rdata.Rdata._as_ipv4_address + ) + + @classmethod + def from_value(cls, value): + addresses = value.split(",") + return cls(addresses) + + def to_text(self): + return '"' + ",".join(self.addresses) + '"' + + @classmethod + def from_wire_parser(cls, parser, origin=None): # pylint: disable=W0613 + addresses = [] + while parser.remaining() > 0: + ip = parser.get_bytes(4) + addresses.append(dns.ipv4.inet_ntoa(ip)) + return cls(addresses) + + def to_wire(self, file, origin=None): # pylint: disable=W0613 + for address in self.addresses: + file.write(dns.ipv4.inet_aton(address)) + + +@dns.immutable.immutable +class IPv6HintParam(Param): + def __init__(self, addresses): + self.addresses = dns.rdata.Rdata._as_tuple( + addresses, dns.rdata.Rdata._as_ipv6_address + ) + + @classmethod + def from_value(cls, value): + addresses = value.split(",") + return cls(addresses) + + def to_text(self): + return '"' + ",".join(self.addresses) + '"' + + @classmethod + def from_wire_parser(cls, parser, origin=None): # pylint: disable=W0613 + addresses = [] + while parser.remaining() > 0: + ip = parser.get_bytes(16) + addresses.append(dns.ipv6.inet_ntoa(ip)) + return cls(addresses) + + def to_wire(self, file, origin=None): # pylint: disable=W0613 + for address in self.addresses: + file.write(dns.ipv6.inet_aton(address)) + + +@dns.immutable.immutable +class ECHParam(Param): + def __init__(self, ech): + self.ech = dns.rdata.Rdata._as_bytes(ech, True) + + @classmethod + def from_value(cls, value): + if "\\" in value: + 
raise ValueError("escape in ECH value") + value = base64.b64decode(value.encode()) + return cls(value) + + def to_text(self): + b64 = base64.b64encode(self.ech).decode("ascii") + return f'"{b64}"' + + @classmethod + def from_wire_parser(cls, parser, origin=None): # pylint: disable=W0613 + value = parser.get_bytes(parser.remaining()) + return cls(value) + + def to_wire(self, file, origin=None): # pylint: disable=W0613 + file.write(self.ech) + + +_class_for_key = { + ParamKey.MANDATORY: MandatoryParam, + ParamKey.ALPN: ALPNParam, + ParamKey.NO_DEFAULT_ALPN: NoDefaultALPNParam, + ParamKey.PORT: PortParam, + ParamKey.IPV4HINT: IPv4HintParam, + ParamKey.ECH: ECHParam, + ParamKey.IPV6HINT: IPv6HintParam, +} + + +def _validate_and_define(params, key, value): + (key, force_generic) = _validate_key(_unescape(key)) + if key in params: + raise SyntaxError(f'duplicate key "{key:d}"') + cls = _class_for_key.get(key, GenericParam) + emptiness = cls.emptiness() + if value is None: + if emptiness == Emptiness.NEVER: + raise SyntaxError("value cannot be empty") + value = cls.from_value(value) + else: + if force_generic: + value = cls.from_wire_parser(dns.wire.Parser(_unescape(value))) + else: + value = cls.from_value(value) + params[key] = value + + +@dns.immutable.immutable +class SVCBBase(dns.rdata.Rdata): + """Base class for SVCB-like records""" + + # see: draft-ietf-dnsop-svcb-https-11 + + __slots__ = ["priority", "target", "params"] + + def __init__(self, rdclass, rdtype, priority, target, params): + super().__init__(rdclass, rdtype) + self.priority = self._as_uint16(priority) + self.target = self._as_name(target) + for k, v in params.items(): + k = ParamKey.make(k) + if not isinstance(v, Param) and v is not None: + raise ValueError(f"{k:d} not a Param") + self.params = dns.immutable.Dict(params) + # Make sure any parameter listed as mandatory is present in the + # record. + mandatory = params.get(ParamKey.MANDATORY) + if mandatory: + for key in mandatory.keys: + # Note we have to say "not in" as we have None as a value + # so a get() and a not None test would be wrong. + if key not in params: + raise ValueError(f"key {key:d} declared mandatory but not present") + # The no-default-alpn parameter requires the alpn parameter. 
+ if ParamKey.NO_DEFAULT_ALPN in params: + if ParamKey.ALPN not in params: + raise ValueError("no-default-alpn present, but alpn missing") + + def to_text(self, origin=None, relativize=True, **kw): + target = self.target.choose_relativity(origin, relativize) + params = [] + for key in sorted(self.params.keys()): + value = self.params[key] + if value is None: + params.append(key_to_text(key)) + else: + kv = key_to_text(key) + "=" + value.to_text() + params.append(kv) + if len(params) > 0: + space = " " + else: + space = "" + return "%d %s%s%s" % (self.priority, target, space, " ".join(params)) + + @classmethod + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): + priority = tok.get_uint16() + target = tok.get_name(origin, relativize, relativize_to) + if priority == 0: + token = tok.get() + if not token.is_eol_or_eof(): + raise SyntaxError("parameters in AliasMode") + tok.unget(token) + params = {} + while True: + token = tok.get() + if token.is_eol_or_eof(): + tok.unget(token) + break + if token.ttype != dns.tokenizer.IDENTIFIER: + raise SyntaxError("parameter is not an identifier") + equals = token.value.find("=") + if equals == len(token.value) - 1: + # 'key=', so next token should be a quoted string without + # any intervening whitespace. + key = token.value[:-1] + token = tok.get(want_leading=True) + if token.ttype != dns.tokenizer.QUOTED_STRING: + raise SyntaxError("whitespace after =") + value = token.value + elif equals > 0: + # key=value + key = token.value[:equals] + value = token.value[equals + 1 :] + elif equals == 0: + # =key + raise SyntaxError('parameter cannot start with "="') + else: + # key + key = token.value + value = None + _validate_and_define(params, key, value) + return cls(rdclass, rdtype, priority, target, params) + + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): + file.write(struct.pack("!H", self.priority)) + self.target.to_wire(file, None, origin, False) + for key in sorted(self.params): + file.write(struct.pack("!H", key)) + value = self.params[key] + with dns.renderer.prefixed_length(file, 2): + # Note that we're still writing a length of zero if the value is None + if value is not None: + value.to_wire(file, origin) + + @classmethod + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + priority = parser.get_uint16() + target = parser.get_name(origin) + if priority == 0 and parser.remaining() != 0: + raise dns.exception.FormError("parameters in AliasMode") + params = {} + prior_key = -1 + while parser.remaining() > 0: + key = parser.get_uint16() + if key < prior_key: + raise dns.exception.FormError("keys not in order") + prior_key = key + vlen = parser.get_uint16() + pcls = _class_for_key.get(key, GenericParam) + with parser.restrict_to(vlen): + value = pcls.from_wire_parser(parser, origin) + params[key] = value + return cls(rdclass, rdtype, priority, target, params) + + def _processing_priority(self): + return self.priority + + @classmethod + def _processing_order(cls, iterable): + return dns.rdtypes.util.priority_processing_order(iterable) diff --git a/venv/Lib/site-packages/dns/rdtypes/tlsabase.py b/venv/Lib/site-packages/dns/rdtypes/tlsabase.py new file mode 100644 index 00000000..a059d2c4 --- /dev/null +++ b/venv/Lib/site-packages/dns/rdtypes/tlsabase.py @@ -0,0 +1,71 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2005-2007, 2009-2011 Nominum, Inc. 
+# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +import binascii +import struct + +import dns.immutable +import dns.rdata +import dns.rdatatype + + +@dns.immutable.immutable +class TLSABase(dns.rdata.Rdata): + """Base class for TLSA and SMIMEA records""" + + # see: RFC 6698 + + __slots__ = ["usage", "selector", "mtype", "cert"] + + def __init__(self, rdclass, rdtype, usage, selector, mtype, cert): + super().__init__(rdclass, rdtype) + self.usage = self._as_uint8(usage) + self.selector = self._as_uint8(selector) + self.mtype = self._as_uint8(mtype) + self.cert = self._as_bytes(cert) + + def to_text(self, origin=None, relativize=True, **kw): + kw = kw.copy() + chunksize = kw.pop("chunksize", 128) + return "%d %d %d %s" % ( + self.usage, + self.selector, + self.mtype, + dns.rdata._hexify(self.cert, chunksize=chunksize, **kw), + ) + + @classmethod + def from_text( + cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None + ): + usage = tok.get_uint8() + selector = tok.get_uint8() + mtype = tok.get_uint8() + cert = tok.concatenate_remaining_identifiers().encode() + cert = binascii.unhexlify(cert) + return cls(rdclass, rdtype, usage, selector, mtype, cert) + + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): + header = struct.pack("!BBB", self.usage, self.selector, self.mtype) + file.write(header) + file.write(self.cert) + + @classmethod + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + header = parser.get_struct("BBB") + cert = parser.get_remaining() + return cls(rdclass, rdtype, header[0], header[1], header[2], cert) diff --git a/venv/Lib/site-packages/dns/rdtypes/txtbase.py b/venv/Lib/site-packages/dns/rdtypes/txtbase.py new file mode 100644 index 00000000..44d6df57 --- /dev/null +++ b/venv/Lib/site-packages/dns/rdtypes/txtbase.py @@ -0,0 +1,104 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2006-2017 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
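TLSABase above packs three one-byte fields followed by the raw certificate-association data, so a "3 1 1" record carries a 32-byte SHA-256 digest. A hedged round-trip sketch (assuming this vendored dnspython is importable; the digest value is arbitrary):

import dns.rdata
import dns.rdataclass
import dns.rdatatype

digest = "ab" * 32  # hypothetical SHA-256 digest, hex-encoded
tlsa = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.TLSA, f"3 1 1 {digest}")
assert (tlsa.usage, tlsa.selector, tlsa.mtype) == (3, 1, 1)
# The certificate-association data is stored as raw bytes, not hex text.
assert tlsa.cert == bytes.fromhex(digest)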
+ +"""TXT-like base class.""" + +from typing import Any, Dict, Iterable, Optional, Tuple, Union + +import dns.exception +import dns.immutable +import dns.rdata +import dns.renderer +import dns.tokenizer + + +@dns.immutable.immutable +class TXTBase(dns.rdata.Rdata): + """Base class for rdata that is like a TXT record (see RFC 1035).""" + + __slots__ = ["strings"] + + def __init__( + self, + rdclass: dns.rdataclass.RdataClass, + rdtype: dns.rdatatype.RdataType, + strings: Iterable[Union[bytes, str]], + ): + """Initialize a TXT-like rdata. + + *rdclass*, an ``int`` is the rdataclass of the Rdata. + + *rdtype*, an ``int`` is the rdatatype of the Rdata. + + *strings*, a tuple of ``bytes`` + """ + super().__init__(rdclass, rdtype) + self.strings: Tuple[bytes] = self._as_tuple( + strings, lambda x: self._as_bytes(x, True, 255) + ) + + def to_text( + self, + origin: Optional[dns.name.Name] = None, + relativize: bool = True, + **kw: Dict[str, Any], + ) -> str: + txt = "" + prefix = "" + for s in self.strings: + txt += '{}"{}"'.format(prefix, dns.rdata._escapify(s)) + prefix = " " + return txt + + @classmethod + def from_text( + cls, + rdclass: dns.rdataclass.RdataClass, + rdtype: dns.rdatatype.RdataType, + tok: dns.tokenizer.Tokenizer, + origin: Optional[dns.name.Name] = None, + relativize: bool = True, + relativize_to: Optional[dns.name.Name] = None, + ) -> dns.rdata.Rdata: + strings = [] + for token in tok.get_remaining(): + token = token.unescape_to_bytes() + # The 'if' below is always true in the current code, but we + # are leaving this check in in case things change some day. + if not ( + token.is_quoted_string() or token.is_identifier() + ): # pragma: no cover + raise dns.exception.SyntaxError("expected a string") + if len(token.value) > 255: + raise dns.exception.SyntaxError("string too long") + strings.append(token.value) + if len(strings) == 0: + raise dns.exception.UnexpectedEnd + return cls(rdclass, rdtype, strings) + + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): + for s in self.strings: + with dns.renderer.prefixed_length(file, 1): + file.write(s) + + @classmethod + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + strings = [] + while parser.remaining() > 0: + s = parser.get_counted_bytes() + strings.append(s) + return cls(rdclass, rdtype, strings) diff --git a/venv/Lib/site-packages/dns/rdtypes/util.py b/venv/Lib/site-packages/dns/rdtypes/util.py new file mode 100644 index 00000000..54908fdc --- /dev/null +++ b/venv/Lib/site-packages/dns/rdtypes/util.py @@ -0,0 +1,257 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2006, 2007, 2009-2011 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
+ +import collections +import random +import struct +from typing import Any, List + +import dns.exception +import dns.ipv4 +import dns.ipv6 +import dns.name +import dns.rdata + + +class Gateway: + """A helper class for the IPSECKEY gateway and AMTRELAY relay fields""" + + name = "" + + def __init__(self, type, gateway=None): + self.type = dns.rdata.Rdata._as_uint8(type) + self.gateway = gateway + self._check() + + @classmethod + def _invalid_type(cls, gateway_type): + return f"invalid {cls.name} type: {gateway_type}" + + def _check(self): + if self.type == 0: + if self.gateway not in (".", None): + raise SyntaxError(f"invalid {self.name} for type 0") + self.gateway = None + elif self.type == 1: + # check that it's OK + dns.ipv4.inet_aton(self.gateway) + elif self.type == 2: + # check that it's OK + dns.ipv6.inet_aton(self.gateway) + elif self.type == 3: + if not isinstance(self.gateway, dns.name.Name): + raise SyntaxError(f"invalid {self.name}; not a name") + else: + raise SyntaxError(self._invalid_type(self.type)) + + def to_text(self, origin=None, relativize=True): + if self.type == 0: + return "." + elif self.type in (1, 2): + return self.gateway + elif self.type == 3: + return str(self.gateway.choose_relativity(origin, relativize)) + else: + raise ValueError(self._invalid_type(self.type)) # pragma: no cover + + @classmethod + def from_text( + cls, gateway_type, tok, origin=None, relativize=True, relativize_to=None + ): + if gateway_type in (0, 1, 2): + gateway = tok.get_string() + elif gateway_type == 3: + gateway = tok.get_name(origin, relativize, relativize_to) + else: + raise dns.exception.SyntaxError( + cls._invalid_type(gateway_type) + ) # pragma: no cover + return cls(gateway_type, gateway) + + # pylint: disable=unused-argument + def to_wire(self, file, compress=None, origin=None, canonicalize=False): + if self.type == 0: + pass + elif self.type == 1: + file.write(dns.ipv4.inet_aton(self.gateway)) + elif self.type == 2: + file.write(dns.ipv6.inet_aton(self.gateway)) + elif self.type == 3: + self.gateway.to_wire(file, None, origin, False) + else: + raise ValueError(self._invalid_type(self.type)) # pragma: no cover + + # pylint: enable=unused-argument + + @classmethod + def from_wire_parser(cls, gateway_type, parser, origin=None): + if gateway_type == 0: + gateway = None + elif gateway_type == 1: + gateway = dns.ipv4.inet_ntoa(parser.get_bytes(4)) + elif gateway_type == 2: + gateway = dns.ipv6.inet_ntoa(parser.get_bytes(16)) + elif gateway_type == 3: + gateway = parser.get_name(origin) + else: + raise dns.exception.FormError(cls._invalid_type(gateway_type)) + return cls(gateway_type, gateway) + + +class Bitmap: + """A helper class for the NSEC/NSEC3/CSYNC type bitmaps""" + + type_name = "" + + def __init__(self, windows=None): + last_window = -1 + self.windows = windows + for window, bitmap in self.windows: + if not isinstance(window, int): + raise ValueError(f"bad {self.type_name} window type") + if window <= last_window: + raise ValueError(f"bad {self.type_name} window order") + if window > 256: + raise ValueError(f"bad {self.type_name} window number") + last_window = window + if not isinstance(bitmap, bytes): + raise ValueError(f"bad {self.type_name} octets type") + if len(bitmap) == 0 or len(bitmap) > 32: + raise ValueError(f"bad {self.type_name} octets") + + def to_text(self) -> str: + text = "" + for window, bitmap in self.windows: + bits = [] + for i, byte in enumerate(bitmap): + for j in range(0, 8): + if byte & (0x80 >> j): + rdtype = window * 256 + i * 8 + j + 
bits.append(dns.rdatatype.to_text(rdtype)) + text += " " + " ".join(bits) + return text + + @classmethod + def from_text(cls, tok: "dns.tokenizer.Tokenizer") -> "Bitmap": + rdtypes = [] + for token in tok.get_remaining(): + rdtype = dns.rdatatype.from_text(token.unescape().value) + if rdtype == 0: + raise dns.exception.SyntaxError(f"{cls.type_name} with bit 0") + rdtypes.append(rdtype) + return cls.from_rdtypes(rdtypes) + + @classmethod + def from_rdtypes(cls, rdtypes: List[dns.rdatatype.RdataType]) -> "Bitmap": + rdtypes = sorted(rdtypes) + window = 0 + octets = 0 + prior_rdtype = 0 + bitmap = bytearray(b"\0" * 32) + windows = [] + for rdtype in rdtypes: + if rdtype == prior_rdtype: + continue + prior_rdtype = rdtype + new_window = rdtype // 256 + if new_window != window: + if octets != 0: + windows.append((window, bytes(bitmap[0:octets]))) + bitmap = bytearray(b"\0" * 32) + window = new_window + offset = rdtype % 256 + byte = offset // 8 + bit = offset % 8 + octets = byte + 1 + bitmap[byte] = bitmap[byte] | (0x80 >> bit) + if octets != 0: + windows.append((window, bytes(bitmap[0:octets]))) + return cls(windows) + + def to_wire(self, file: Any) -> None: + for window, bitmap in self.windows: + file.write(struct.pack("!BB", window, len(bitmap))) + file.write(bitmap) + + @classmethod + def from_wire_parser(cls, parser: "dns.wire.Parser") -> "Bitmap": + windows = [] + while parser.remaining() > 0: + window = parser.get_uint8() + bitmap = parser.get_counted_bytes() + windows.append((window, bitmap)) + return cls(windows) + + +def _priority_table(items): + by_priority = collections.defaultdict(list) + for rdata in items: + by_priority[rdata._processing_priority()].append(rdata) + return by_priority + + +def priority_processing_order(iterable): + items = list(iterable) + if len(items) == 1: + return items + by_priority = _priority_table(items) + ordered = [] + for k in sorted(by_priority.keys()): + rdatas = by_priority[k] + random.shuffle(rdatas) + ordered.extend(rdatas) + return ordered + + +_no_weight = 0.1 + + +def weighted_processing_order(iterable): + items = list(iterable) + if len(items) == 1: + return items + by_priority = _priority_table(items) + ordered = [] + for k in sorted(by_priority.keys()): + rdatas = by_priority[k] + total = sum(rdata._processing_weight() or _no_weight for rdata in rdatas) + while len(rdatas) > 1: + r = random.uniform(0, total) + for n, rdata in enumerate(rdatas): + weight = rdata._processing_weight() or _no_weight + if weight > r: + break + r -= weight + total -= weight + ordered.append(rdata) # pylint: disable=undefined-loop-variable + del rdatas[n] # pylint: disable=undefined-loop-variable + ordered.append(rdatas[0]) + return ordered + + +def parse_formatted_hex(formatted, num_chunks, chunk_size, separator): + if len(formatted) != num_chunks * (chunk_size + 1) - 1: + raise ValueError("invalid formatted hex string") + value = b"" + for _ in range(num_chunks): + chunk = formatted[0:chunk_size] + value += int(chunk, 16).to_bytes(chunk_size // 2, "big") + formatted = formatted[chunk_size:] + if len(formatted) > 0 and formatted[0] != separator: + raise ValueError("invalid formatted hex string") + formatted = formatted[1:] + return value diff --git a/venv/Lib/site-packages/dns/renderer.py b/venv/Lib/site-packages/dns/renderer.py new file mode 100644 index 00000000..a77481f6 --- /dev/null +++ b/venv/Lib/site-packages/dns/renderer.py @@ -0,0 +1,346 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2001-2017 Nominum, 
Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +"""Help for building DNS wire format messages""" + +import contextlib +import io +import random +import struct +import time + +import dns.exception +import dns.tsig + +QUESTION = 0 +ANSWER = 1 +AUTHORITY = 2 +ADDITIONAL = 3 + + +@contextlib.contextmanager +def prefixed_length(output, length_length): + output.write(b"\00" * length_length) + start = output.tell() + yield + end = output.tell() + length = end - start + if length > 0: + try: + output.seek(start - length_length) + try: + output.write(length.to_bytes(length_length, "big")) + except OverflowError: + raise dns.exception.FormError + finally: + output.seek(end) + + +class Renderer: + """Helper class for building DNS wire-format messages. + + Most applications can use the higher-level L{dns.message.Message} + class and its to_wire() method to generate wire-format messages. + This class is for those applications which need finer control + over the generation of messages. + + Typical use:: + + r = dns.renderer.Renderer(id=1, flags=0x80, max_size=512) + r.add_question(qname, qtype, qclass) + r.add_rrset(dns.renderer.ANSWER, rrset_1) + r.add_rrset(dns.renderer.ANSWER, rrset_2) + r.add_rrset(dns.renderer.AUTHORITY, ns_rrset) + r.add_rrset(dns.renderer.ADDITIONAL, ad_rrset_1) + r.add_rrset(dns.renderer.ADDITIONAL, ad_rrset_2) + r.add_edns(0, 0, 4096) + r.write_header() + r.add_tsig(keyname, secret, 300, 1, 0, '', request_mac) + wire = r.get_wire() + + If padding is going to be used, then the OPT record MUST be + written after everything else in the additional section except for + the TSIG (if any). + + output, an io.BytesIO, where rendering is written + + id: the message id + + flags: the message flags + + max_size: the maximum size of the message + + origin: the origin to use when rendering relative names + + compress: the compression table + + section: an int, the section currently being rendered + + counts: list of the number of RRs in each section + + mac: the MAC of the rendered message (if TSIG was used) + """ + + def __init__(self, id=None, flags=0, max_size=65535, origin=None): + """Initialize a new renderer.""" + + self.output = io.BytesIO() + if id is None: + self.id = random.randint(0, 65535) + else: + self.id = id + self.flags = flags + self.max_size = max_size + self.origin = origin + self.compress = {} + self.section = QUESTION + self.counts = [0, 0, 0, 0] + self.output.write(b"\x00" * 12) + self.mac = "" + self.reserved = 0 + self.was_padded = False + + def _rollback(self, where): + """Truncate the output buffer at offset *where*, and remove any + compression table entries that pointed beyond the truncation + point. 
+ """ + + self.output.seek(where) + self.output.truncate() + keys_to_delete = [] + for k, v in self.compress.items(): + if v >= where: + keys_to_delete.append(k) + for k in keys_to_delete: + del self.compress[k] + + def _set_section(self, section): + """Set the renderer's current section. + + Sections must be rendered order: QUESTION, ANSWER, AUTHORITY, + ADDITIONAL. Sections may be empty. + + Raises dns.exception.FormError if an attempt was made to set + a section value less than the current section. + """ + + if self.section != section: + if self.section > section: + raise dns.exception.FormError + self.section = section + + @contextlib.contextmanager + def _track_size(self): + start = self.output.tell() + yield start + if self.output.tell() > self.max_size: + self._rollback(start) + raise dns.exception.TooBig + + @contextlib.contextmanager + def _temporarily_seek_to(self, where): + current = self.output.tell() + try: + self.output.seek(where) + yield + finally: + self.output.seek(current) + + def add_question(self, qname, rdtype, rdclass=dns.rdataclass.IN): + """Add a question to the message.""" + + self._set_section(QUESTION) + with self._track_size(): + qname.to_wire(self.output, self.compress, self.origin) + self.output.write(struct.pack("!HH", rdtype, rdclass)) + self.counts[QUESTION] += 1 + + def add_rrset(self, section, rrset, **kw): + """Add the rrset to the specified section. + + Any keyword arguments are passed on to the rdataset's to_wire() + routine. + """ + + self._set_section(section) + with self._track_size(): + n = rrset.to_wire(self.output, self.compress, self.origin, **kw) + self.counts[section] += n + + def add_rdataset(self, section, name, rdataset, **kw): + """Add the rdataset to the specified section, using the specified + name as the owner name. + + Any keyword arguments are passed on to the rdataset's to_wire() + routine. + """ + + self._set_section(section) + with self._track_size(): + n = rdataset.to_wire(name, self.output, self.compress, self.origin, **kw) + self.counts[section] += n + + def add_opt(self, opt, pad=0, opt_size=0, tsig_size=0): + """Add *opt* to the additional section, applying padding if desired. The + padding will take the specified precomputed OPT size and TSIG size into + account. 
+ + Note that we don't have reliable way of knowing how big a GSS-TSIG digest + might be, so we we might not get an even multiple of the pad in that case.""" + if pad: + ttl = opt.ttl + assert opt_size >= 11 + opt_rdata = opt[0] + size_without_padding = self.output.tell() + opt_size + tsig_size + remainder = size_without_padding % pad + if remainder: + pad = b"\x00" * (pad - remainder) + else: + pad = b"" + options = list(opt_rdata.options) + options.append(dns.edns.GenericOption(dns.edns.OptionType.PADDING, pad)) + opt = dns.message.Message._make_opt(ttl, opt_rdata.rdclass, options) + self.was_padded = True + self.add_rrset(ADDITIONAL, opt) + + def add_edns(self, edns, ednsflags, payload, options=None): + """Add an EDNS OPT record to the message.""" + + # make sure the EDNS version in ednsflags agrees with edns + ednsflags &= 0xFF00FFFF + ednsflags |= edns << 16 + opt = dns.message.Message._make_opt(ednsflags, payload, options) + self.add_opt(opt) + + def add_tsig( + self, + keyname, + secret, + fudge, + id, + tsig_error, + other_data, + request_mac, + algorithm=dns.tsig.default_algorithm, + ): + """Add a TSIG signature to the message.""" + + s = self.output.getvalue() + + if isinstance(secret, dns.tsig.Key): + key = secret + else: + key = dns.tsig.Key(keyname, secret, algorithm) + tsig = dns.message.Message._make_tsig( + keyname, algorithm, 0, fudge, b"", id, tsig_error, other_data + ) + (tsig, _) = dns.tsig.sign(s, key, tsig[0], int(time.time()), request_mac) + self._write_tsig(tsig, keyname) + + def add_multi_tsig( + self, + ctx, + keyname, + secret, + fudge, + id, + tsig_error, + other_data, + request_mac, + algorithm=dns.tsig.default_algorithm, + ): + """Add a TSIG signature to the message. Unlike add_tsig(), this can be + used for a series of consecutive DNS envelopes, e.g. for a zone + transfer over TCP [RFC2845, 4.4]. + + For the first message in the sequence, give ctx=None. For each + subsequent message, give the ctx that was returned from the + add_multi_tsig() call for the previous message.""" + + s = self.output.getvalue() + + if isinstance(secret, dns.tsig.Key): + key = secret + else: + key = dns.tsig.Key(keyname, secret, algorithm) + tsig = dns.message.Message._make_tsig( + keyname, algorithm, 0, fudge, b"", id, tsig_error, other_data + ) + (tsig, ctx) = dns.tsig.sign( + s, key, tsig[0], int(time.time()), request_mac, ctx, True + ) + self._write_tsig(tsig, keyname) + return ctx + + def _write_tsig(self, tsig, keyname): + if self.was_padded: + compress = None + else: + compress = self.compress + self._set_section(ADDITIONAL) + with self._track_size(): + keyname.to_wire(self.output, compress, self.origin) + self.output.write( + struct.pack("!HHI", dns.rdatatype.TSIG, dns.rdataclass.ANY, 0) + ) + with prefixed_length(self.output, 2): + tsig.to_wire(self.output) + + self.counts[ADDITIONAL] += 1 + with self._temporarily_seek_to(10): + self.output.write(struct.pack("!H", self.counts[ADDITIONAL])) + + def write_header(self): + """Write the DNS message header. + + Writing the DNS message header is done after all sections + have been rendered, but before the optional TSIG signature + is added. 
+ """ + + with self._temporarily_seek_to(0): + self.output.write( + struct.pack( + "!HHHHHH", + self.id, + self.flags, + self.counts[0], + self.counts[1], + self.counts[2], + self.counts[3], + ) + ) + + def get_wire(self): + """Return the wire format message.""" + + return self.output.getvalue() + + def reserve(self, size: int) -> None: + """Reserve *size* bytes.""" + if size < 0: + raise ValueError("reserved amount must be non-negative") + if size > self.max_size: + raise ValueError("cannot reserve more than the maximum size") + self.reserved += size + self.max_size -= size + + def release_reserved(self) -> None: + """Release the reserved bytes.""" + self.max_size += self.reserved + self.reserved = 0 diff --git a/venv/Lib/site-packages/dns/resolver.py b/venv/Lib/site-packages/dns/resolver.py new file mode 100644 index 00000000..f08f824d --- /dev/null +++ b/venv/Lib/site-packages/dns/resolver.py @@ -0,0 +1,2054 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2017 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
+ +"""DNS stub resolver.""" + +import contextlib +import random +import socket +import sys +import threading +import time +import warnings +from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple, Union +from urllib.parse import urlparse + +import dns._ddr +import dns.edns +import dns.exception +import dns.flags +import dns.inet +import dns.ipv4 +import dns.ipv6 +import dns.message +import dns.name +import dns.nameserver +import dns.query +import dns.rcode +import dns.rdataclass +import dns.rdatatype +import dns.rdtypes.svcbbase +import dns.reversename +import dns.tsig + +if sys.platform == "win32": + import dns.win32util + + +class NXDOMAIN(dns.exception.DNSException): + """The DNS query name does not exist.""" + + supp_kwargs = {"qnames", "responses"} + fmt = None # we have our own __str__ implementation + + # pylint: disable=arguments-differ + + # We do this as otherwise mypy complains about unexpected keyword argument + # idna_exception + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def _check_kwargs(self, qnames, responses=None): + if not isinstance(qnames, (list, tuple, set)): + raise AttributeError("qnames must be a list, tuple or set") + if len(qnames) == 0: + raise AttributeError("qnames must contain at least one element") + if responses is None: + responses = {} + elif not isinstance(responses, dict): + raise AttributeError("responses must be a dict(qname=response)") + kwargs = dict(qnames=qnames, responses=responses) + return kwargs + + def __str__(self) -> str: + if "qnames" not in self.kwargs: + return super().__str__() + qnames = self.kwargs["qnames"] + if len(qnames) > 1: + msg = "None of DNS query names exist" + else: + msg = "The DNS query name does not exist" + qnames = ", ".join(map(str, qnames)) + return "{}: {}".format(msg, qnames) + + @property + def canonical_name(self): + """Return the unresolved canonical name.""" + if "qnames" not in self.kwargs: + raise TypeError("parametrized exception required") + for qname in self.kwargs["qnames"]: + response = self.kwargs["responses"][qname] + try: + cname = response.canonical_name() + if cname != qname: + return cname + except Exception: + # We can just eat this exception as it means there was + # something wrong with the response. + pass + return self.kwargs["qnames"][0] + + def __add__(self, e_nx): + """Augment by results from another NXDOMAIN exception.""" + qnames0 = list(self.kwargs.get("qnames", [])) + responses0 = dict(self.kwargs.get("responses", {})) + responses1 = e_nx.kwargs.get("responses", {}) + for qname1 in e_nx.kwargs.get("qnames", []): + if qname1 not in qnames0: + qnames0.append(qname1) + if qname1 in responses1: + responses0[qname1] = responses1[qname1] + return NXDOMAIN(qnames=qnames0, responses=responses0) + + def qnames(self): + """All of the names that were tried. + + Returns a list of ``dns.name.Name``. + """ + return self.kwargs["qnames"] + + def responses(self): + """A map from queried names to their NXDOMAIN responses. + + Returns a dict mapping a ``dns.name.Name`` to a + ``dns.message.Message``. + """ + return self.kwargs["responses"] + + def response(self, qname): + """The response for query *qname*. + + Returns a ``dns.message.Message``. 
+        """
+        return self.kwargs["responses"][qname]
+
+
+class YXDOMAIN(dns.exception.DNSException):
+    """The DNS query name is too long after DNAME substitution."""
+
+
+ErrorTuple = Tuple[
+    Optional[str],
+    bool,
+    int,
+    Union[Exception, str],
+    Optional[dns.message.Message],
+]
+
+
+def _errors_to_text(errors: List[ErrorTuple]) -> List[str]:
+    """Turn a resolution errors trace into a list of text."""
+    texts = []
+    for err in errors:
+        texts.append("Server {} answered {}".format(err[0], err[3]))
+    return texts
+
+
+class LifetimeTimeout(dns.exception.Timeout):
+    """The resolution lifetime expired."""
+
+    msg = "The resolution lifetime expired."
+    fmt = "%s after {timeout:.3f} seconds: {errors}" % msg[:-1]
+    supp_kwargs = {"timeout", "errors"}
+
+    # We do this as otherwise mypy complains about unexpected keyword argument
+    # idna_exception
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def _fmt_kwargs(self, **kwargs):
+        srv_msgs = _errors_to_text(kwargs["errors"])
+        return super()._fmt_kwargs(
+            timeout=kwargs["timeout"], errors="; ".join(srv_msgs)
+        )
+
+
+# We added more detail to resolution timeouts, but they are still
+# subclasses of dns.exception.Timeout for backwards compatibility. We also
+# keep dns.resolver.Timeout defined for backwards compatibility.
+Timeout = LifetimeTimeout
+
+
+class NoAnswer(dns.exception.DNSException):
+    """The DNS response does not contain an answer to the question."""
+
+    fmt = "The DNS response does not contain an answer to the question: {query}"
+    supp_kwargs = {"response"}
+
+    # We do this as otherwise mypy complains about unexpected keyword argument
+    # idna_exception
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def _fmt_kwargs(self, **kwargs):
+        return super()._fmt_kwargs(query=kwargs["response"].question)
+
+    def response(self):
+        return self.kwargs["response"]
+
+
+class NoNameservers(dns.exception.DNSException):
+    """All nameservers failed to answer the query.
+
+    errors: list of servers and respective errors
+    The type of errors is
+    [(server IP address, any object convertible to string)].
+    A non-empty errors list adds an explanatory message to the
+    formatted exception text.
+    """
+
+    msg = "All nameservers failed to answer the query."
+    fmt = "%s {query}: {errors}" % msg[:-1]
+    supp_kwargs = {"request", "errors"}
+
+    # We do this as otherwise mypy complains about unexpected keyword argument
+    # idna_exception
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def _fmt_kwargs(self, **kwargs):
+        srv_msgs = _errors_to_text(kwargs["errors"])
+        return super()._fmt_kwargs(
+            query=kwargs["request"].question, errors="; ".join(srv_msgs)
+        )
+
+
+class NotAbsolute(dns.exception.DNSException):
+    """An absolute domain name is required but a relative name was provided."""
+
+
+class NoRootSOA(dns.exception.DNSException):
+    """There is no SOA RR at the DNS root name. This should never happen!"""
+
+
+class NoMetaqueries(dns.exception.DNSException):
+    """DNS metaqueries are not allowed."""
+
+
+class NoResolverConfiguration(dns.exception.DNSException):
+    """Resolver configuration could not be read or specified no nameservers."""
+
+
+class Answer:
+    """DNS stub resolver answer.
+
+    Instances of this class bundle up the result of a successful DNS
+    resolution.
+
+    For convenience, the answer object implements much of the sequence
+    protocol, forwarding to its ``rrset`` attribute. E.g.
+    ``for a in answer`` is equivalent to ``for a in answer.rrset``.
+    ``answer[i]`` is equivalent to ``answer.rrset[i]``, and
+    ``answer[i:j]`` is equivalent to ``answer.rrset[i:j]``.
+
+    Note that CNAMEs or DNAMEs in the response may mean that the answer
+    RRset's name is not the query name.
+    """
+
+    def __init__(
+        self,
+        qname: dns.name.Name,
+        rdtype: dns.rdatatype.RdataType,
+        rdclass: dns.rdataclass.RdataClass,
+        response: dns.message.QueryMessage,
+        nameserver: Optional[str] = None,
+        port: Optional[int] = None,
+    ) -> None:
+        self.qname = qname
+        self.rdtype = rdtype
+        self.rdclass = rdclass
+        self.response = response
+        self.nameserver = nameserver
+        self.port = port
+        self.chaining_result = response.resolve_chaining()
+        # Copy some attributes out of chaining_result for backwards
+        # compatibility and convenience.
+        self.canonical_name = self.chaining_result.canonical_name
+        self.rrset = self.chaining_result.answer
+        self.expiration = time.time() + self.chaining_result.minimum_ttl
+
+    def __getattr__(self, attr):  # pragma: no cover
+        if attr == "name":
+            return self.rrset.name
+        elif attr == "ttl":
+            return self.rrset.ttl
+        elif attr == "covers":
+            return self.rrset.covers
+        elif attr == "rdclass":
+            return self.rrset.rdclass
+        elif attr == "rdtype":
+            return self.rrset.rdtype
+        else:
+            raise AttributeError(attr)
+
+    def __len__(self) -> int:
+        return self.rrset and len(self.rrset) or 0
+
+    def __iter__(self):
+        return self.rrset and iter(self.rrset) or iter(tuple())
+
+    def __getitem__(self, i):
+        if self.rrset is None:
+            raise IndexError
+        return self.rrset[i]
+
+    def __delitem__(self, i):
+        if self.rrset is None:
+            raise IndexError
+        del self.rrset[i]
+
+
+class Answers(dict):
+    """A dict of DNS stub resolver answers, indexed by type."""
+
+
+class HostAnswers(Answers):
+    """A dict of DNS stub resolver answers to a host name lookup, indexed by
+    type.
+    """
+
+    @classmethod
+    def make(
+        cls,
+        v6: Optional[Answer] = None,
+        v4: Optional[Answer] = None,
+        add_empty: bool = True,
+    ) -> "HostAnswers":
+        answers = HostAnswers()
+        if v6 is not None and (add_empty or v6.rrset):
+            answers[dns.rdatatype.AAAA] = v6
+        if v4 is not None and (add_empty or v4.rrset):
+            answers[dns.rdatatype.A] = v4
+        return answers
+
+    # Returns pairs of (address, family) from this result, potentially
+    # filtering by address family.
+    def addresses_and_families(
+        self, family: int = socket.AF_UNSPEC
+    ) -> Iterator[Tuple[str, int]]:
+        if family == socket.AF_UNSPEC:
+            yield from self.addresses_and_families(socket.AF_INET6)
+            yield from self.addresses_and_families(socket.AF_INET)
+            return
+        elif family == socket.AF_INET6:
+            answer = self.get(dns.rdatatype.AAAA)
+        elif family == socket.AF_INET:
+            answer = self.get(dns.rdatatype.A)
+        else:
+            raise NotImplementedError(f"unknown address family {family}")
+        if answer:
+            for rdata in answer:
+                yield (rdata.address, family)
+
+    # Returns addresses from this result, potentially filtering by
+    # address family.
+    def addresses(self, family: int = socket.AF_UNSPEC) -> Iterator[str]:
+        return (pair[0] for pair in self.addresses_and_families(family))
+
+    # Returns the canonical name from this result.
+ def canonical_name(self) -> dns.name.Name: + answer = self.get(dns.rdatatype.AAAA, self.get(dns.rdatatype.A)) + return answer.canonical_name + + +class CacheStatistics: + """Cache Statistics""" + + def __init__(self, hits: int = 0, misses: int = 0) -> None: + self.hits = hits + self.misses = misses + + def reset(self) -> None: + self.hits = 0 + self.misses = 0 + + def clone(self) -> "CacheStatistics": + return CacheStatistics(self.hits, self.misses) + + +class CacheBase: + def __init__(self) -> None: + self.lock = threading.Lock() + self.statistics = CacheStatistics() + + def reset_statistics(self) -> None: + """Reset all statistics to zero.""" + with self.lock: + self.statistics.reset() + + def hits(self) -> int: + """How many hits has the cache had?""" + with self.lock: + return self.statistics.hits + + def misses(self) -> int: + """How many misses has the cache had?""" + with self.lock: + return self.statistics.misses + + def get_statistics_snapshot(self) -> CacheStatistics: + """Return a consistent snapshot of all the statistics. + + If running with multiple threads, it's better to take a + snapshot than to call statistics methods such as hits() and + misses() individually. + """ + with self.lock: + return self.statistics.clone() + + +CacheKey = Tuple[dns.name.Name, dns.rdatatype.RdataType, dns.rdataclass.RdataClass] + + +class Cache(CacheBase): + """Simple thread-safe DNS answer cache.""" + + def __init__(self, cleaning_interval: float = 300.0) -> None: + """*cleaning_interval*, a ``float`` is the number of seconds between + periodic cleanings. + """ + + super().__init__() + self.data: Dict[CacheKey, Answer] = {} + self.cleaning_interval = cleaning_interval + self.next_cleaning: float = time.time() + self.cleaning_interval + + def _maybe_clean(self) -> None: + """Clean the cache if it's time to do so.""" + + now = time.time() + if self.next_cleaning <= now: + keys_to_delete = [] + for k, v in self.data.items(): + if v.expiration <= now: + keys_to_delete.append(k) + for k in keys_to_delete: + del self.data[k] + now = time.time() + self.next_cleaning = now + self.cleaning_interval + + def get(self, key: CacheKey) -> Optional[Answer]: + """Get the answer associated with *key*. + + Returns None if no answer is cached for the key. + + *key*, a ``(dns.name.Name, dns.rdatatype.RdataType, dns.rdataclass.RdataClass)`` + tuple whose values are the query name, rdtype, and rdclass respectively. + + Returns a ``dns.resolver.Answer`` or ``None``. + """ + + with self.lock: + self._maybe_clean() + v = self.data.get(key) + if v is None or v.expiration <= time.time(): + self.statistics.misses += 1 + return None + self.statistics.hits += 1 + return v + + def put(self, key: CacheKey, value: Answer) -> None: + """Associate key and value in the cache. + + *key*, a ``(dns.name.Name, dns.rdatatype.RdataType, dns.rdataclass.RdataClass)`` + tuple whose values are the query name, rdtype, and rdclass respectively. + + *value*, a ``dns.resolver.Answer``, the answer. + """ + + with self.lock: + self._maybe_clean() + self.data[key] = value + + def flush(self, key: Optional[CacheKey] = None) -> None: + """Flush the cache. + + If *key* is not ``None``, only that item is flushed. Otherwise the entire cache + is flushed. + + *key*, a ``(dns.name.Name, dns.rdatatype.RdataType, dns.rdataclass.RdataClass)`` + tuple whose values are the query name, rdtype, and rdclass respectively. 
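+
+        A minimal sketch (illustrative, not part of the original docs)
+        of flushing a single cached entry::
+
+            cache = dns.resolver.Cache()
+            key = (dns.name.from_text("example.com."),
+                   dns.rdatatype.A, dns.rdataclass.IN)
+            cache.flush(key)  # drop just this entry
+            cache.flush()     # or drop everything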
+ """ + + with self.lock: + if key is not None: + if key in self.data: + del self.data[key] + else: + self.data = {} + self.next_cleaning = time.time() + self.cleaning_interval + + +class LRUCacheNode: + """LRUCache node.""" + + def __init__(self, key, value): + self.key = key + self.value = value + self.hits = 0 + self.prev = self + self.next = self + + def link_after(self, node: "LRUCacheNode") -> None: + self.prev = node + self.next = node.next + node.next.prev = self + node.next = self + + def unlink(self) -> None: + self.next.prev = self.prev + self.prev.next = self.next + + +class LRUCache(CacheBase): + """Thread-safe, bounded, least-recently-used DNS answer cache. + + This cache is better than the simple cache (above) if you're + running a web crawler or other process that does a lot of + resolutions. The LRUCache has a maximum number of nodes, and when + it is full, the least-recently used node is removed to make space + for a new one. + """ + + def __init__(self, max_size: int = 100000) -> None: + """*max_size*, an ``int``, is the maximum number of nodes to cache; + it must be greater than 0. + """ + + super().__init__() + self.data: Dict[CacheKey, LRUCacheNode] = {} + self.set_max_size(max_size) + self.sentinel: LRUCacheNode = LRUCacheNode(None, None) + self.sentinel.prev = self.sentinel + self.sentinel.next = self.sentinel + + def set_max_size(self, max_size: int) -> None: + if max_size < 1: + max_size = 1 + self.max_size = max_size + + def get(self, key: CacheKey) -> Optional[Answer]: + """Get the answer associated with *key*. + + Returns None if no answer is cached for the key. + + *key*, a ``(dns.name.Name, dns.rdatatype.RdataType, dns.rdataclass.RdataClass)`` + tuple whose values are the query name, rdtype, and rdclass respectively. + + Returns a ``dns.resolver.Answer`` or ``None``. + """ + + with self.lock: + node = self.data.get(key) + if node is None: + self.statistics.misses += 1 + return None + # Unlink because we're either going to move the node to the front + # of the LRU list or we're going to free it. + node.unlink() + if node.value.expiration <= time.time(): + del self.data[node.key] + self.statistics.misses += 1 + return None + node.link_after(self.sentinel) + self.statistics.hits += 1 + node.hits += 1 + return node.value + + def get_hits_for_key(self, key: CacheKey) -> int: + """Return the number of cache hits associated with the specified key.""" + with self.lock: + node = self.data.get(key) + if node is None or node.value.expiration <= time.time(): + return 0 + else: + return node.hits + + def put(self, key: CacheKey, value: Answer) -> None: + """Associate key and value in the cache. + + *key*, a ``(dns.name.Name, dns.rdatatype.RdataType, dns.rdataclass.RdataClass)`` + tuple whose values are the query name, rdtype, and rdclass respectively. + + *value*, a ``dns.resolver.Answer``, the answer. + """ + + with self.lock: + node = self.data.get(key) + if node is not None: + node.unlink() + del self.data[node.key] + while len(self.data) >= self.max_size: + gnode = self.sentinel.prev + gnode.unlink() + del self.data[gnode.key] + node = LRUCacheNode(key, value) + node.link_after(self.sentinel) + self.data[key] = node + + def flush(self, key: Optional[CacheKey] = None) -> None: + """Flush the cache. + + If *key* is not ``None``, only that item is flushed. Otherwise the entire cache + is flushed. + + *key*, a ``(dns.name.Name, dns.rdatatype.RdataType, dns.rdataclass.RdataClass)`` + tuple whose values are the query name, rdtype, and rdclass respectively. 
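+
+        A short sketch (illustrative, not part of the original docs);
+        attaching a bounded cache to a resolver is the typical use::
+
+            resolver = dns.resolver.Resolver()
+            resolver.cache = dns.resolver.LRUCache(max_size=1000)
+            # ... after some queries ...
+            resolver.cache.flush()  # discard every cached answer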
+ """ + + with self.lock: + if key is not None: + node = self.data.get(key) + if node is not None: + node.unlink() + del self.data[node.key] + else: + gnode = self.sentinel.next + while gnode != self.sentinel: + next = gnode.next + gnode.unlink() + gnode = next + self.data = {} + + +class _Resolution: + """Helper class for dns.resolver.Resolver.resolve(). + + All of the "business logic" of resolution is encapsulated in this + class, allowing us to have multiple resolve() implementations + using different I/O schemes without copying all of the + complicated logic. + + This class is a "friend" to dns.resolver.Resolver and manipulates + resolver data structures directly. + """ + + def __init__( + self, + resolver: "BaseResolver", + qname: Union[dns.name.Name, str], + rdtype: Union[dns.rdatatype.RdataType, str], + rdclass: Union[dns.rdataclass.RdataClass, str], + tcp: bool, + raise_on_no_answer: bool, + search: Optional[bool], + ) -> None: + if isinstance(qname, str): + qname = dns.name.from_text(qname, None) + rdtype = dns.rdatatype.RdataType.make(rdtype) + if dns.rdatatype.is_metatype(rdtype): + raise NoMetaqueries + rdclass = dns.rdataclass.RdataClass.make(rdclass) + if dns.rdataclass.is_metaclass(rdclass): + raise NoMetaqueries + self.resolver = resolver + self.qnames_to_try = resolver._get_qnames_to_try(qname, search) + self.qnames = self.qnames_to_try[:] + self.rdtype = rdtype + self.rdclass = rdclass + self.tcp = tcp + self.raise_on_no_answer = raise_on_no_answer + self.nxdomain_responses: Dict[dns.name.Name, dns.message.QueryMessage] = {} + # Initialize other things to help analysis tools + self.qname = dns.name.empty + self.nameservers: List[dns.nameserver.Nameserver] = [] + self.current_nameservers: List[dns.nameserver.Nameserver] = [] + self.errors: List[ErrorTuple] = [] + self.nameserver: Optional[dns.nameserver.Nameserver] = None + self.tcp_attempt = False + self.retry_with_tcp = False + self.request: Optional[dns.message.QueryMessage] = None + self.backoff = 0.0 + + def next_request( + self, + ) -> Tuple[Optional[dns.message.QueryMessage], Optional[Answer]]: + """Get the next request to send, and check the cache. + + Returns a (request, answer) tuple. At most one of request or + answer will not be None. + """ + + # We return a tuple instead of Union[Message,Answer] as it lets + # the caller avoid isinstance(). + + while len(self.qnames) > 0: + self.qname = self.qnames.pop(0) + + # Do we know the answer? + if self.resolver.cache: + answer = self.resolver.cache.get( + (self.qname, self.rdtype, self.rdclass) + ) + if answer is not None: + if answer.rrset is None and self.raise_on_no_answer: + raise NoAnswer(response=answer.response) + else: + return (None, answer) + answer = self.resolver.cache.get( + (self.qname, dns.rdatatype.ANY, self.rdclass) + ) + if answer is not None and answer.response.rcode() == dns.rcode.NXDOMAIN: + # cached NXDOMAIN; record it and continue to next + # name. 
+ self.nxdomain_responses[self.qname] = answer.response + continue + + # Build the request + request = dns.message.make_query(self.qname, self.rdtype, self.rdclass) + if self.resolver.keyname is not None: + request.use_tsig( + self.resolver.keyring, + self.resolver.keyname, + algorithm=self.resolver.keyalgorithm, + ) + request.use_edns( + self.resolver.edns, + self.resolver.ednsflags, + self.resolver.payload, + options=self.resolver.ednsoptions, + ) + if self.resolver.flags is not None: + request.flags = self.resolver.flags + + self.nameservers = self.resolver._enrich_nameservers( + self.resolver._nameservers, + self.resolver.nameserver_ports, + self.resolver.port, + ) + if self.resolver.rotate: + random.shuffle(self.nameservers) + self.current_nameservers = self.nameservers[:] + self.errors = [] + self.nameserver = None + self.tcp_attempt = False + self.retry_with_tcp = False + self.request = request + self.backoff = 0.10 + + return (request, None) + + # + # We've tried everything and only gotten NXDOMAINs. (We know + # it's only NXDOMAINs as anything else would have returned + # before now.) + # + raise NXDOMAIN(qnames=self.qnames_to_try, responses=self.nxdomain_responses) + + def next_nameserver(self) -> Tuple[dns.nameserver.Nameserver, bool, float]: + if self.retry_with_tcp: + assert self.nameserver is not None + assert not self.nameserver.is_always_max_size() + self.tcp_attempt = True + self.retry_with_tcp = False + return (self.nameserver, True, 0) + + backoff = 0.0 + if not self.current_nameservers: + if len(self.nameservers) == 0: + # Out of things to try! + raise NoNameservers(request=self.request, errors=self.errors) + self.current_nameservers = self.nameservers[:] + backoff = self.backoff + self.backoff = min(self.backoff * 2, 2) + + self.nameserver = self.current_nameservers.pop(0) + self.tcp_attempt = self.tcp or self.nameserver.is_always_max_size() + return (self.nameserver, self.tcp_attempt, backoff) + + def query_result( + self, response: Optional[dns.message.Message], ex: Optional[Exception] + ) -> Tuple[Optional[Answer], bool]: + # + # returns an (answer: Answer, end_loop: bool) tuple. + # + assert self.nameserver is not None + if ex: + # Exception during I/O or from_wire() + assert response is None + self.errors.append( + ( + str(self.nameserver), + self.tcp_attempt, + self.nameserver.answer_port(), + ex, + response, + ) + ) + if ( + isinstance(ex, dns.exception.FormError) + or isinstance(ex, EOFError) + or isinstance(ex, OSError) + or isinstance(ex, NotImplementedError) + ): + # This nameserver is no good, take it out of the mix. + self.nameservers.remove(self.nameserver) + elif isinstance(ex, dns.message.Truncated): + if self.tcp_attempt: + # Truncation with TCP is no good! + self.nameservers.remove(self.nameserver) + else: + self.retry_with_tcp = True + return (None, False) + # We got an answer! + assert response is not None + assert isinstance(response, dns.message.QueryMessage) + rcode = response.rcode() + if rcode == dns.rcode.NOERROR: + try: + answer = Answer( + self.qname, + self.rdtype, + self.rdclass, + response, + self.nameserver.answer_nameserver(), + self.nameserver.answer_port(), + ) + except Exception as e: + self.errors.append( + ( + str(self.nameserver), + self.tcp_attempt, + self.nameserver.answer_port(), + e, + response, + ) + ) + # The nameserver is no good, take it out of the mix. 
+ self.nameservers.remove(self.nameserver) + return (None, False) + if self.resolver.cache: + self.resolver.cache.put((self.qname, self.rdtype, self.rdclass), answer) + if answer.rrset is None and self.raise_on_no_answer: + raise NoAnswer(response=answer.response) + return (answer, True) + elif rcode == dns.rcode.NXDOMAIN: + # Further validate the response by making an Answer, even + # if we aren't going to cache it. + try: + answer = Answer( + self.qname, dns.rdatatype.ANY, dns.rdataclass.IN, response + ) + except Exception as e: + self.errors.append( + ( + str(self.nameserver), + self.tcp_attempt, + self.nameserver.answer_port(), + e, + response, + ) + ) + # The nameserver is no good, take it out of the mix. + self.nameservers.remove(self.nameserver) + return (None, False) + self.nxdomain_responses[self.qname] = response + if self.resolver.cache: + self.resolver.cache.put( + (self.qname, dns.rdatatype.ANY, self.rdclass), answer + ) + # Make next_nameserver() return None, so caller breaks its + # inner loop and calls next_request(). + return (None, True) + elif rcode == dns.rcode.YXDOMAIN: + yex = YXDOMAIN() + self.errors.append( + ( + str(self.nameserver), + self.tcp_attempt, + self.nameserver.answer_port(), + yex, + response, + ) + ) + raise yex + else: + # + # We got a response, but we're not happy with the + # rcode in it. + # + if rcode != dns.rcode.SERVFAIL or not self.resolver.retry_servfail: + self.nameservers.remove(self.nameserver) + self.errors.append( + ( + str(self.nameserver), + self.tcp_attempt, + self.nameserver.answer_port(), + dns.rcode.to_text(rcode), + response, + ) + ) + return (None, False) + + +class BaseResolver: + """DNS stub resolver.""" + + # We initialize in reset() + # + # pylint: disable=attribute-defined-outside-init + + domain: dns.name.Name + nameserver_ports: Dict[str, int] + port: int + search: List[dns.name.Name] + use_search_by_default: bool + timeout: float + lifetime: float + keyring: Optional[Any] + keyname: Optional[Union[dns.name.Name, str]] + keyalgorithm: Union[dns.name.Name, str] + edns: int + ednsflags: int + ednsoptions: Optional[List[dns.edns.Option]] + payload: int + cache: Any + flags: Optional[int] + retry_servfail: bool + rotate: bool + ndots: Optional[int] + _nameservers: Sequence[Union[str, dns.nameserver.Nameserver]] + + def __init__( + self, filename: str = "/etc/resolv.conf", configure: bool = True + ) -> None: + """*filename*, a ``str`` or file object, specifying a file + in standard /etc/resolv.conf format. This parameter is meaningful + only when *configure* is true and the platform is POSIX. + + *configure*, a ``bool``. If True (the default), the resolver + instance is configured in the normal fashion for the operating + system the resolver is running on. (I.e. by reading a + /etc/resolv.conf file on POSIX systems and from the registry + on Windows systems.) 
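+
+        A minimal sketch (illustrative, not part of the original docs)
+        of building an unconfigured resolver by hand; the nameserver
+        address is a placeholder::
+
+            res = dns.resolver.Resolver(configure=False)
+            res.nameservers = ["192.0.2.53"]
+            res.timeout = 2.0
+            res.lifetime = 5.0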
+        """
+
+        self.reset()
+        if configure:
+            if sys.platform == "win32":
+                self.read_registry()
+            elif filename:
+                self.read_resolv_conf(filename)
+
+    def reset(self) -> None:
+        """Reset all resolver configuration to the defaults."""
+
+        self.domain = dns.name.Name(dns.name.from_text(socket.gethostname())[1:])
+        if len(self.domain) == 0:
+            self.domain = dns.name.root
+        self._nameservers = []
+        self.nameserver_ports = {}
+        self.port = 53
+        self.search = []
+        self.use_search_by_default = False
+        self.timeout = 2.0
+        self.lifetime = 5.0
+        self.keyring = None
+        self.keyname = None
+        self.keyalgorithm = dns.tsig.default_algorithm
+        self.edns = -1
+        self.ednsflags = 0
+        self.ednsoptions = None
+        self.payload = 0
+        self.cache = None
+        self.flags = None
+        self.retry_servfail = False
+        self.rotate = False
+        self.ndots = None
+
+    def read_resolv_conf(self, f: Any) -> None:
+        """Process *f* as a file in the /etc/resolv.conf format. If f is
+        a ``str``, it is used as the name of the file to open; otherwise it
+        is treated as the file itself.
+
+        Interprets the following items:
+
+        - nameserver - name server IP address
+
+        - domain - local domain name
+
+        - search - search list for host-name lookup
+
+        - options - supported options are rotate, timeout, edns0, and ndots
+
+        """
+
+        nameservers = []
+        if isinstance(f, str):
+            try:
+                cm: contextlib.AbstractContextManager = open(f)
+            except OSError:
+                # /etc/resolv.conf doesn't exist, can't be read, etc.
+                raise NoResolverConfiguration(f"cannot open {f}")
+        else:
+            cm = contextlib.nullcontext(f)
+        with cm as f:
+            for l in f:
+                if len(l) == 0 or l[0] == "#" or l[0] == ";":
+                    continue
+                tokens = l.split()
+
+                # Any line containing fewer than 2 tokens is malformed
+                if len(tokens) < 2:
+                    continue
+
+                if tokens[0] == "nameserver":
+                    nameservers.append(tokens[1])
+                elif tokens[0] == "domain":
+                    self.domain = dns.name.from_text(tokens[1])
+                    # domain and search are exclusive
+                    self.search = []
+                elif tokens[0] == "search":
+                    # the last search wins
+                    self.search = []
+                    for suffix in tokens[1:]:
+                        self.search.append(dns.name.from_text(suffix))
+                    # We don't set domain as it is not used if
+                    # len(self.search) > 0
+                elif tokens[0] == "options":
+                    for opt in tokens[1:]:
+                        if opt == "rotate":
+                            self.rotate = True
+                        elif opt == "edns0":
+                            self.use_edns()
+                        elif "timeout" in opt:
+                            try:
+                                self.timeout = int(opt.split(":")[1])
+                            except (ValueError, IndexError):
+                                pass
+                        elif "ndots" in opt:
+                            try:
+                                self.ndots = int(opt.split(":")[1])
+                            except (ValueError, IndexError):
+                                pass
+        if len(nameservers) == 0:
+            raise NoResolverConfiguration("no nameservers")
+        # Assigning directly instead of appending means we invoke the
+        # setter logic, with additional checking and enrichment.
+        self.nameservers = nameservers
+
+    def read_registry(self) -> None:
+        """Extract resolver configuration from the Windows registry."""
+        try:
+            info = dns.win32util.get_dns_info()  # type: ignore
+            if info.domain is not None:
+                self.domain = info.domain
+            self.nameservers = info.nameservers
+            self.search = info.search
+        except AttributeError:
+            raise NotImplementedError
+
+    def _compute_timeout(
+        self,
+        start: float,
+        lifetime: Optional[float] = None,
+        errors: Optional[List[ErrorTuple]] = None,
+    ) -> float:
+        lifetime = self.lifetime if lifetime is None else lifetime
+        now = time.time()
+        duration = now - start
+        if errors is None:
+            errors = []
+        if duration < 0:
+            if duration < -1:
+                # Time going backwards is bad. Just give up.
+ raise LifetimeTimeout(timeout=duration, errors=errors) + else: + # Time went backwards, but only a little. This can + # happen, e.g. under vmware with older linux kernels. + # Pretend it didn't happen. + duration = 0 + if duration >= lifetime: + raise LifetimeTimeout(timeout=duration, errors=errors) + return min(lifetime - duration, self.timeout) + + def _get_qnames_to_try( + self, qname: dns.name.Name, search: Optional[bool] + ) -> List[dns.name.Name]: + # This is a separate method so we can unit test the search + # rules without requiring the Internet. + if search is None: + search = self.use_search_by_default + qnames_to_try = [] + if qname.is_absolute(): + qnames_to_try.append(qname) + else: + abs_qname = qname.concatenate(dns.name.root) + if search: + if len(self.search) > 0: + # There is a search list, so use it exclusively + search_list = self.search[:] + elif self.domain != dns.name.root and self.domain is not None: + # We have some notion of a domain that isn't the root, so + # use it as the search list. + search_list = [self.domain] + else: + search_list = [] + # Figure out the effective ndots (default is 1) + if self.ndots is None: + ndots = 1 + else: + ndots = self.ndots + for suffix in search_list: + qnames_to_try.append(qname + suffix) + if len(qname) > ndots: + # The name has at least ndots dots, so we should try an + # absolute query first. + qnames_to_try.insert(0, abs_qname) + else: + # The name has less than ndots dots, so we should search + # first, then try the absolute name. + qnames_to_try.append(abs_qname) + else: + qnames_to_try.append(abs_qname) + return qnames_to_try + + def use_tsig( + self, + keyring: Any, + keyname: Optional[Union[dns.name.Name, str]] = None, + algorithm: Union[dns.name.Name, str] = dns.tsig.default_algorithm, + ) -> None: + """Add a TSIG signature to each query. + + The parameters are passed to ``dns.message.Message.use_tsig()``; + see its documentation for details. + """ + + self.keyring = keyring + self.keyname = keyname + self.keyalgorithm = algorithm + + def use_edns( + self, + edns: Optional[Union[int, bool]] = 0, + ednsflags: int = 0, + payload: int = dns.message.DEFAULT_EDNS_PAYLOAD, + options: Optional[List[dns.edns.Option]] = None, + ) -> None: + """Configure EDNS behavior. + + *edns*, an ``int``, is the EDNS level to use. Specifying + ``None``, ``False``, or ``-1`` means "do not use EDNS", and in this case + the other parameters are ignored. Specifying ``True`` is + equivalent to specifying 0, i.e. "use EDNS0". + + *ednsflags*, an ``int``, the EDNS flag values. + + *payload*, an ``int``, is the EDNS sender's payload field, which is the + maximum size of UDP datagram the sender can handle. I.e. how big + a response to this message can be. + + *options*, a list of ``dns.edns.Option`` objects or ``None``, the EDNS + options. + """ + + if edns is None or edns is False: + edns = -1 + elif edns is True: + edns = 0 + self.edns = edns + self.ednsflags = ednsflags + self.payload = payload + self.ednsoptions = options + + def set_flags(self, flags: int) -> None: + """Overrides the default flags with your own. + + *flags*, an ``int``, the message flags to use. 
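+
+        A one-line sketch (illustrative, not part of the original docs;
+        it assumes ``resolver`` is a ``Resolver`` instance);
+        ``dns.flags.RD`` and ``dns.flags.CD`` are the recursion-desired
+        and checking-disabled bits::
+
+            resolver.set_flags(dns.flags.RD | dns.flags.CD)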
+ """ + + self.flags = flags + + @classmethod + def _enrich_nameservers( + cls, + nameservers: Sequence[Union[str, dns.nameserver.Nameserver]], + nameserver_ports: Dict[str, int], + default_port: int, + ) -> List[dns.nameserver.Nameserver]: + enriched_nameservers = [] + if isinstance(nameservers, list): + for nameserver in nameservers: + enriched_nameserver: dns.nameserver.Nameserver + if isinstance(nameserver, dns.nameserver.Nameserver): + enriched_nameserver = nameserver + elif dns.inet.is_address(nameserver): + port = nameserver_ports.get(nameserver, default_port) + enriched_nameserver = dns.nameserver.Do53Nameserver( + nameserver, port + ) + else: + try: + if urlparse(nameserver).scheme != "https": + raise NotImplementedError + except Exception: + raise ValueError( + f"nameserver {nameserver} is not a " + "dns.nameserver.Nameserver instance or text form, " + "IP address, nor a valid https URL" + ) + enriched_nameserver = dns.nameserver.DoHNameserver(nameserver) + enriched_nameservers.append(enriched_nameserver) + else: + raise ValueError( + "nameservers must be a list or tuple (not a {})".format( + type(nameservers) + ) + ) + return enriched_nameservers + + @property + def nameservers( + self, + ) -> Sequence[Union[str, dns.nameserver.Nameserver]]: + return self._nameservers + + @nameservers.setter + def nameservers( + self, nameservers: Sequence[Union[str, dns.nameserver.Nameserver]] + ) -> None: + """ + *nameservers*, a ``list`` of nameservers, where a nameserver is either + a string interpretable as a nameserver, or a ``dns.nameserver.Nameserver`` + instance. + + Raises ``ValueError`` if *nameservers* is not a list of nameservers. + """ + # We just call _enrich_nameservers() for checking + self._enrich_nameservers(nameservers, self.nameserver_ports, self.port) + self._nameservers = nameservers + + +class Resolver(BaseResolver): + """DNS stub resolver.""" + + def resolve( + self, + qname: Union[dns.name.Name, str], + rdtype: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.A, + rdclass: Union[dns.rdataclass.RdataClass, str] = dns.rdataclass.IN, + tcp: bool = False, + source: Optional[str] = None, + raise_on_no_answer: bool = True, + source_port: int = 0, + lifetime: Optional[float] = None, + search: Optional[bool] = None, + ) -> Answer: # pylint: disable=arguments-differ + """Query nameservers to find the answer to the question. + + The *qname*, *rdtype*, and *rdclass* parameters may be objects + of the appropriate type, or strings that can be converted into objects + of the appropriate type. + + *qname*, a ``dns.name.Name`` or ``str``, the query name. + + *rdtype*, an ``int`` or ``str``, the query type. + + *rdclass*, an ``int`` or ``str``, the query class. + + *tcp*, a ``bool``. If ``True``, use TCP to make the query. + + *source*, a ``str`` or ``None``. If not ``None``, bind to this IP + address when making queries. + + *raise_on_no_answer*, a ``bool``. If ``True``, raise + ``dns.resolver.NoAnswer`` if there's no answer to the question. + + *source_port*, an ``int``, the port from which to send the message. + + *lifetime*, a ``float``, how many seconds a query should run + before timing out. + + *search*, a ``bool`` or ``None``, determines whether the + search list configured in the system's resolver configuration + are used for relative names, and whether the resolver's domain + may be added to relative names. The default is ``None``, + which causes the value of the resolver's + ``use_search_by_default`` attribute to be used. 
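+
+        A typical call looks like the following sketch (illustrative,
+        not part of the original docs; it assumes ``resolver`` is a
+        ``Resolver`` instance and the name is a placeholder)::
+
+            answer = resolver.resolve("example.com.", "MX", lifetime=10.0)
+            for rdata in answer:
+                print(rdata.preference, rdata.exchange)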
+ + Raises ``dns.resolver.LifetimeTimeout`` if no answers could be found + in the specified lifetime. + + Raises ``dns.resolver.NXDOMAIN`` if the query name does not exist. + + Raises ``dns.resolver.YXDOMAIN`` if the query name is too long after + DNAME substitution. + + Raises ``dns.resolver.NoAnswer`` if *raise_on_no_answer* is + ``True`` and the query name exists but has no RRset of the + desired type and class. + + Raises ``dns.resolver.NoNameservers`` if no non-broken + nameservers are available to answer the question. + + Returns a ``dns.resolver.Answer`` instance. + + """ + + resolution = _Resolution( + self, qname, rdtype, rdclass, tcp, raise_on_no_answer, search + ) + start = time.time() + while True: + (request, answer) = resolution.next_request() + # Note we need to say "if answer is not None" and not just + # "if answer" because answer implements __len__, and python + # will call that. We want to return if we have an answer + # object, including in cases where its length is 0. + if answer is not None: + # cache hit! + return answer + assert request is not None # needed for type checking + done = False + while not done: + (nameserver, tcp, backoff) = resolution.next_nameserver() + if backoff: + time.sleep(backoff) + timeout = self._compute_timeout(start, lifetime, resolution.errors) + try: + response = nameserver.query( + request, + timeout=timeout, + source=source, + source_port=source_port, + max_size=tcp, + ) + except Exception as ex: + (_, done) = resolution.query_result(None, ex) + continue + (answer, done) = resolution.query_result(response, None) + # Note we need to say "if answer is not None" and not just + # "if answer" because answer implements __len__, and python + # will call that. We want to return if we have an answer + # object, including in cases where its length is 0. + if answer is not None: + return answer + + def query( + self, + qname: Union[dns.name.Name, str], + rdtype: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.A, + rdclass: Union[dns.rdataclass.RdataClass, str] = dns.rdataclass.IN, + tcp: bool = False, + source: Optional[str] = None, + raise_on_no_answer: bool = True, + source_port: int = 0, + lifetime: Optional[float] = None, + ) -> Answer: # pragma: no cover + """Query nameservers to find the answer to the question. + + This method calls resolve() with ``search=True``, and is + provided for backwards compatibility with prior versions of + dnspython. See the documentation for the resolve() method for + further details. + """ + warnings.warn( + "please use dns.resolver.Resolver.resolve() instead", + DeprecationWarning, + stacklevel=2, + ) + return self.resolve( + qname, + rdtype, + rdclass, + tcp, + source, + raise_on_no_answer, + source_port, + lifetime, + True, + ) + + def resolve_address(self, ipaddr: str, *args: Any, **kwargs: Any) -> Answer: + """Use a resolver to run a reverse query for PTR records. + + This utilizes the resolve() method to perform a PTR lookup on the + specified IP address. + + *ipaddr*, a ``str``, the IPv4 or IPv6 address you want to get + the PTR record for. + + All other arguments that can be passed to the resolve() function + except for rdtype and rdclass are also supported by this + function. + """ + # We make a modified kwargs for type checking happiness, as otherwise + # we get a legit warning about possibly having rdtype and rdclass + # in the kwargs more than once. 
+        modified_kwargs: Dict[str, Any] = {}
+        modified_kwargs.update(kwargs)
+        modified_kwargs["rdtype"] = dns.rdatatype.PTR
+        modified_kwargs["rdclass"] = dns.rdataclass.IN
+        return self.resolve(
+            dns.reversename.from_address(ipaddr), *args, **modified_kwargs
+        )
+
+    def resolve_name(
+        self,
+        name: Union[dns.name.Name, str],
+        family: int = socket.AF_UNSPEC,
+        **kwargs: Any,
+    ) -> HostAnswers:
+        """Use a resolver to query for address records.
+
+        This utilizes the resolve() method to perform A and/or AAAA lookups on
+        the specified name.
+
+        *name*, a ``dns.name.Name`` or ``str``, the name to resolve.
+
+        *family*, an ``int``, the address family. If socket.AF_UNSPEC
+        (the default), both A and AAAA records will be retrieved.
+
+        All other arguments that can be passed to the resolve() function
+        except for rdtype and rdclass are also supported by this
+        function.
+        """
+        # We make a modified kwargs for type checking happiness, as otherwise
+        # we get a legit warning about possibly having rdtype and rdclass
+        # in the kwargs more than once.
+        modified_kwargs: Dict[str, Any] = {}
+        modified_kwargs.update(kwargs)
+        modified_kwargs.pop("rdtype", None)
+        modified_kwargs["rdclass"] = dns.rdataclass.IN
+
+        if family == socket.AF_INET:
+            v4 = self.resolve(name, dns.rdatatype.A, **modified_kwargs)
+            return HostAnswers.make(v4=v4)
+        elif family == socket.AF_INET6:
+            v6 = self.resolve(name, dns.rdatatype.AAAA, **modified_kwargs)
+            return HostAnswers.make(v6=v6)
+        elif family != socket.AF_UNSPEC:
+            raise NotImplementedError(f"unknown address family {family}")
+
+        raise_on_no_answer = modified_kwargs.pop("raise_on_no_answer", True)
+        lifetime = modified_kwargs.pop("lifetime", None)
+        start = time.time()
+        v6 = self.resolve(
+            name,
+            dns.rdatatype.AAAA,
+            raise_on_no_answer=False,
+            lifetime=self._compute_timeout(start, lifetime),
+            **modified_kwargs,
+        )
+        # Note that setting name ensures we query the same name
+        # for A as we did for AAAA. (This is just in case search lists
+        # are active by default in the resolver configuration and
+        # we might be talking to a server that says NXDOMAIN when it
+        # wants to say NOERROR no data.)
+        name = v6.qname
+        v4 = self.resolve(
+            name,
+            dns.rdatatype.A,
+            raise_on_no_answer=False,
+            lifetime=self._compute_timeout(start, lifetime),
+            **modified_kwargs,
+        )
+        answers = HostAnswers.make(v6=v6, v4=v4, add_empty=not raise_on_no_answer)
+        if not answers:
+            raise NoAnswer(response=v6.response)
+        return answers
+
+    # pylint: disable=redefined-outer-name
+
+    def canonical_name(self, name: Union[dns.name.Name, str]) -> dns.name.Name:
+        """Determine the canonical name of *name*.
+
+        The canonical name is the name the resolver uses for queries
+        after all CNAME and DNAME renamings have been applied.
+
+        *name*, a ``dns.name.Name`` or ``str``, the query name.
+
+        This method can raise any exception that ``resolve()`` can
+        raise, other than ``dns.resolver.NoAnswer`` and
+        ``dns.resolver.NXDOMAIN``.
+
+        Returns a ``dns.name.Name``.
+        """
+        try:
+            answer = self.resolve(name, raise_on_no_answer=False)
+            canonical_name = answer.canonical_name
+        except dns.resolver.NXDOMAIN as e:
+            canonical_name = e.canonical_name
+        return canonical_name
+
+    # pylint: enable=redefined-outer-name
+
+    def try_ddr(self, lifetime: float = 5.0) -> None:
+        """Try to update the resolver's nameservers using Discovery of Designated
+        Resolvers (DDR). If successful, the resolver will subsequently use
+        DNS-over-HTTPS or DNS-over-TLS for future queries.
+
+        *lifetime*, a float, is the maximum time to spend attempting DDR. The default
+        is 5 seconds.
+
+        If the SVCB query is successful and results in a non-empty list of nameservers,
+        then the resolver's nameservers are set to the returned servers in priority
+        order.
+
+        The current implementation does not use any address hints from the SVCB record,
+        nor does it resolve addresses for the SVCB target name, rather it assumes that
+        the bootstrap nameserver will always be one of the addresses and uses it.
+        A future revision to the code may offer fuller support. The code verifies that
+        the bootstrap nameserver is in the Subject Alternative Name field of the
+        TLS certificate.
+        """
+        try:
+            expiration = time.time() + lifetime
+            answer = self.resolve(
+                dns._ddr._local_resolver_name, "SVCB", lifetime=lifetime
+            )
+            timeout = dns.query._remaining(expiration)
+            nameservers = dns._ddr._get_nameservers_sync(answer, timeout)
+            if len(nameservers) > 0:
+                self.nameservers = nameservers
+        except Exception:
+            pass
+
+
+#: The default resolver.
+default_resolver: Optional[Resolver] = None
+
+
+def get_default_resolver() -> Resolver:
+    """Get the default resolver, initializing it if necessary."""
+    if default_resolver is None:
+        reset_default_resolver()
+    assert default_resolver is not None
+    return default_resolver
+
+
+def reset_default_resolver() -> None:
+    """Re-initialize default resolver.
+
+    Note that the resolver configuration (i.e. /etc/resolv.conf on UNIX
+    systems) will be re-read immediately.
+    """
+
+    global default_resolver
+    default_resolver = Resolver()
+
+
+def resolve(
+    qname: Union[dns.name.Name, str],
+    rdtype: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.A,
+    rdclass: Union[dns.rdataclass.RdataClass, str] = dns.rdataclass.IN,
+    tcp: bool = False,
+    source: Optional[str] = None,
+    raise_on_no_answer: bool = True,
+    source_port: int = 0,
+    lifetime: Optional[float] = None,
+    search: Optional[bool] = None,
+) -> Answer:  # pragma: no cover
+    """Query nameservers to find the answer to the question.
+
+    This is a convenience function that uses the default resolver
+    object to make the query.
+
+    See ``dns.resolver.Resolver.resolve`` for more information on the
+    parameters.
+    """
+
+    return get_default_resolver().resolve(
+        qname,
+        rdtype,
+        rdclass,
+        tcp,
+        source,
+        raise_on_no_answer,
+        source_port,
+        lifetime,
+        search,
+    )
+
+
+def query(
+    qname: Union[dns.name.Name, str],
+    rdtype: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.A,
+    rdclass: Union[dns.rdataclass.RdataClass, str] = dns.rdataclass.IN,
+    tcp: bool = False,
+    source: Optional[str] = None,
+    raise_on_no_answer: bool = True,
+    source_port: int = 0,
+    lifetime: Optional[float] = None,
+) -> Answer:  # pragma: no cover
+    """Query nameservers to find the answer to the question.
+
+    This method calls resolve() with ``search=True``, and is
+    provided for backwards compatibility with prior versions of
+    dnspython. See the documentation for the resolve() method for
+    further details.
+    """
+    warnings.warn(
+        "please use dns.resolver.resolve() instead", DeprecationWarning, stacklevel=2
+    )
+    return resolve(
+        qname,
+        rdtype,
+        rdclass,
+        tcp,
+        source,
+        raise_on_no_answer,
+        source_port,
+        lifetime,
+        True,
+    )
+
+
+def resolve_address(ipaddr: str, *args: Any, **kwargs: Any) -> Answer:
+    """Use a resolver to run a reverse query for PTR records.
+
+    See ``dns.resolver.Resolver.resolve_address`` for more information on the
+    parameters.
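+
+    A brief sketch (illustrative, not part of the original docs; the
+    address is a placeholder)::
+
+        answer = dns.resolver.resolve_address("192.0.2.1")
+        name = answer[0].target  # the PTR target, a dns.name.Name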
+ """ + + return get_default_resolver().resolve_address(ipaddr, *args, **kwargs) + + +def resolve_name( + name: Union[dns.name.Name, str], family: int = socket.AF_UNSPEC, **kwargs: Any +) -> HostAnswers: + """Use a resolver to query for address records. + + See ``dns.resolver.Resolver.resolve_name`` for more information on the + parameters. + """ + + return get_default_resolver().resolve_name(name, family, **kwargs) + + +def canonical_name(name: Union[dns.name.Name, str]) -> dns.name.Name: + """Determine the canonical name of *name*. + + See ``dns.resolver.Resolver.canonical_name`` for more information on the + parameters and possible exceptions. + """ + + return get_default_resolver().canonical_name(name) + + +def try_ddr(lifetime: float = 5.0) -> None: + """Try to update the default resolver's nameservers using Discovery of Designated + Resolvers (DDR). If successful, the resolver will subsequently use + DNS-over-HTTPS or DNS-over-TLS for future queries. + + See :py:func:`dns.resolver.Resolver.try_ddr` for more information. + """ + return get_default_resolver().try_ddr(lifetime) + + +def zone_for_name( + name: Union[dns.name.Name, str], + rdclass: dns.rdataclass.RdataClass = dns.rdataclass.IN, + tcp: bool = False, + resolver: Optional[Resolver] = None, + lifetime: Optional[float] = None, +) -> dns.name.Name: + """Find the name of the zone which contains the specified name. + + *name*, an absolute ``dns.name.Name`` or ``str``, the query name. + + *rdclass*, an ``int``, the query class. + + *tcp*, a ``bool``. If ``True``, use TCP to make the query. + + *resolver*, a ``dns.resolver.Resolver`` or ``None``, the resolver to use. + If ``None``, the default, then the default resolver is used. + + *lifetime*, a ``float``, the total time to allow for the queries needed + to determine the zone. If ``None``, the default, then only the individual + query limits of the resolver apply. + + Raises ``dns.resolver.NoRootSOA`` if there is no SOA RR at the DNS + root. (This is only likely to happen if you're using non-default + root servers in your network and they are misconfigured.) + + Raises ``dns.resolver.LifetimeTimeout`` if the answer could not be + found in the allotted lifetime. + + Returns a ``dns.name.Name``. + """ + + if isinstance(name, str): + name = dns.name.from_text(name, dns.name.root) + if resolver is None: + resolver = get_default_resolver() + if not name.is_absolute(): + raise NotAbsolute(name) + start = time.time() + expiration: Optional[float] + if lifetime is not None: + expiration = start + lifetime + else: + expiration = None + while 1: + try: + rlifetime: Optional[float] + if expiration is not None: + rlifetime = expiration - time.time() + if rlifetime <= 0: + rlifetime = 0 + else: + rlifetime = None + answer = resolver.resolve( + name, dns.rdatatype.SOA, rdclass, tcp, lifetime=rlifetime + ) + assert answer.rrset is not None + if answer.rrset.name == name: + return name + # otherwise we were CNAMEd or DNAMEd and need to look higher + except (dns.resolver.NXDOMAIN, dns.resolver.NoAnswer) as e: + if isinstance(e, dns.resolver.NXDOMAIN): + response = e.responses().get(name) + else: + response = e.response() # pylint: disable=no-value-for-parameter + if response: + for rrs in response.authority: + if rrs.rdtype == dns.rdatatype.SOA and rrs.rdclass == rdclass: + (nr, _, _) = rrs.name.fullcompare(name) + if nr == dns.name.NAMERELN_SUPERDOMAIN: + # We're doing a proper superdomain check as + # if the name were equal we ought to have gotten + # it in the answer section! 
We are ignoring the + # possibility that the authority is insane and + # is including multiple SOA RRs for different + # authorities. + return rrs.name + # we couldn't extract anything useful from the response (e.g. it's + # a type 3 NXDOMAIN) + try: + name = name.parent() + except dns.name.NoParent: + raise NoRootSOA + + +def make_resolver_at( + where: Union[dns.name.Name, str], + port: int = 53, + family: int = socket.AF_UNSPEC, + resolver: Optional[Resolver] = None, +) -> Resolver: + """Make a stub resolver using the specified destination as the full resolver. + + *where*, a ``dns.name.Name`` or ``str`` the domain name or IP address of the + full resolver. + + *port*, an ``int``, the port to use. If not specified, the default is 53. + + *family*, an ``int``, the address family to use. This parameter is used if + *where* is not an address. The default is ``socket.AF_UNSPEC`` in which case + the first address returned by ``resolve_name()`` will be used, otherwise the + first address of the specified family will be used. + + *resolver*, a ``dns.resolver.Resolver`` or ``None``, the resolver to use for + resolution of hostnames. If not specified, the default resolver will be used. + + Returns a ``dns.resolver.Resolver`` or raises an exception. + """ + if resolver is None: + resolver = get_default_resolver() + nameservers: List[Union[str, dns.nameserver.Nameserver]] = [] + if isinstance(where, str) and dns.inet.is_address(where): + nameservers.append(dns.nameserver.Do53Nameserver(where, port)) + else: + for address in resolver.resolve_name(where, family).addresses(): + nameservers.append(dns.nameserver.Do53Nameserver(address, port)) + res = dns.resolver.Resolver(configure=False) + res.nameservers = nameservers + return res + + +def resolve_at( + where: Union[dns.name.Name, str], + qname: Union[dns.name.Name, str], + rdtype: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.A, + rdclass: Union[dns.rdataclass.RdataClass, str] = dns.rdataclass.IN, + tcp: bool = False, + source: Optional[str] = None, + raise_on_no_answer: bool = True, + source_port: int = 0, + lifetime: Optional[float] = None, + search: Optional[bool] = None, + port: int = 53, + family: int = socket.AF_UNSPEC, + resolver: Optional[Resolver] = None, +) -> Answer: + """Query nameservers to find the answer to the question. + + This is a convenience function that calls ``dns.resolver.make_resolver_at()`` to + make a resolver, and then uses it to resolve the query. + + See ``dns.resolver.Resolver.resolve`` for more information on the resolution + parameters, and ``dns.resolver.make_resolver_at`` for information about the resolver + parameters *where*, *port*, *family*, and *resolver*. + + If making more than one query, it is more efficient to call + ``dns.resolver.make_resolver_at()`` and then use that resolver for the queries + instead of calling ``resolve_at()`` multiple times. + """ + return make_resolver_at(where, port, family, resolver).resolve( + qname, + rdtype, + rdclass, + tcp, + source, + raise_on_no_answer, + source_port, + lifetime, + search, + ) + + +# +# Support for overriding the system resolver for all python code in the +# running process. 
+# + +_protocols_for_socktype = { + socket.SOCK_DGRAM: [socket.SOL_UDP], + socket.SOCK_STREAM: [socket.SOL_TCP], +} + +_resolver = None +_original_getaddrinfo = socket.getaddrinfo +_original_getnameinfo = socket.getnameinfo +_original_getfqdn = socket.getfqdn +_original_gethostbyname = socket.gethostbyname +_original_gethostbyname_ex = socket.gethostbyname_ex +_original_gethostbyaddr = socket.gethostbyaddr + + +def _getaddrinfo( + host=None, service=None, family=socket.AF_UNSPEC, socktype=0, proto=0, flags=0 +): + if flags & socket.AI_NUMERICHOST != 0: + # Short circuit directly into the system's getaddrinfo(). We're + # not adding any value in this case, and this avoids infinite loops + # because dns.query.* needs to call getaddrinfo() for IPv6 scoping + # reasons. We will also do this short circuit below if we + # discover that the host is an address literal. + return _original_getaddrinfo(host, service, family, socktype, proto, flags) + if flags & (socket.AI_ADDRCONFIG | socket.AI_V4MAPPED) != 0: + # Not implemented. We raise a gaierror as opposed to a + # NotImplementedError as it helps callers handle errors more + # appropriately. [Issue #316] + # + # We raise EAI_FAIL as opposed to EAI_SYSTEM because there is + # no EAI_SYSTEM on Windows [Issue #416]. We didn't go for + # EAI_BADFLAGS as the flags aren't bad, we just don't + # implement them. + raise socket.gaierror( + socket.EAI_FAIL, "Non-recoverable failure in name resolution" + ) + if host is None and service is None: + raise socket.gaierror(socket.EAI_NONAME, "Name or service not known") + addrs = [] + canonical_name = None # pylint: disable=redefined-outer-name + # Is host None or an address literal? If so, use the system's + # getaddrinfo(). + if host is None: + return _original_getaddrinfo(host, service, family, socktype, proto, flags) + try: + # We don't care about the result of af_for_address(), we're just + # calling it so it raises an exception if host is not an IPv4 or + # IPv6 address. + dns.inet.af_for_address(host) + return _original_getaddrinfo(host, service, family, socktype, proto, flags) + except Exception: + pass + # Something needs resolution! + try: + answers = _resolver.resolve_name(host, family) + addrs = answers.addresses_and_families() + canonical_name = answers.canonical_name().to_text(True) + except dns.resolver.NXDOMAIN: + raise socket.gaierror(socket.EAI_NONAME, "Name or service not known") + except Exception: + # We raise EAI_AGAIN here as the failure may be temporary + # (e.g. a timeout) and EAI_SYSTEM isn't defined on Windows. + # [Issue #416] + raise socket.gaierror(socket.EAI_AGAIN, "Temporary failure in name resolution") + port = None + try: + # Is it a port literal? 
+ if service is None: + port = 0 + else: + port = int(service) + except Exception: + if flags & socket.AI_NUMERICSERV == 0: + try: + port = socket.getservbyname(service) + except Exception: + pass + if port is None: + raise socket.gaierror(socket.EAI_NONAME, "Name or service not known") + tuples = [] + if socktype == 0: + socktypes = [socket.SOCK_DGRAM, socket.SOCK_STREAM] + else: + socktypes = [socktype] + if flags & socket.AI_CANONNAME != 0: + cname = canonical_name + else: + cname = "" + for addr, af in addrs: + for socktype in socktypes: + for proto in _protocols_for_socktype[socktype]: + addr_tuple = dns.inet.low_level_address_tuple((addr, port), af) + tuples.append((af, socktype, proto, cname, addr_tuple)) + if len(tuples) == 0: + raise socket.gaierror(socket.EAI_NONAME, "Name or service not known") + return tuples + + +def _getnameinfo(sockaddr, flags=0): + host = sockaddr[0] + port = sockaddr[1] + if len(sockaddr) == 4: + scope = sockaddr[3] + family = socket.AF_INET6 + else: + scope = None + family = socket.AF_INET + tuples = _getaddrinfo(host, port, family, socket.SOCK_STREAM, socket.SOL_TCP, 0) + if len(tuples) > 1: + raise socket.error("sockaddr resolved to multiple addresses") + addr = tuples[0][4][0] + if flags & socket.NI_DGRAM: + pname = "udp" + else: + pname = "tcp" + qname = dns.reversename.from_address(addr) + if flags & socket.NI_NUMERICHOST == 0: + try: + answer = _resolver.resolve(qname, "PTR") + hostname = answer.rrset[0].target.to_text(True) + except (dns.resolver.NXDOMAIN, dns.resolver.NoAnswer): + if flags & socket.NI_NAMEREQD: + raise socket.gaierror(socket.EAI_NONAME, "Name or service not known") + hostname = addr + if scope is not None: + hostname += "%" + str(scope) + else: + hostname = addr + if scope is not None: + hostname += "%" + str(scope) + if flags & socket.NI_NUMERICSERV: + service = str(port) + else: + service = socket.getservbyport(port, pname) + return (hostname, service) + + +def _getfqdn(name=None): + if name is None: + name = socket.gethostname() + try: + (name, _, _) = _gethostbyaddr(name) + # Python's version checks aliases too, but our gethostbyname + # ignores them, so we do so here as well. + except Exception: + pass + return name + + +def _gethostbyname(name): + return _gethostbyname_ex(name)[2][0] + + +def _gethostbyname_ex(name): + aliases = [] + addresses = [] + tuples = _getaddrinfo( + name, 0, socket.AF_INET, socket.SOCK_STREAM, socket.SOL_TCP, socket.AI_CANONNAME + ) + canonical = tuples[0][3] + for item in tuples: + addresses.append(item[4][0]) + # XXX we just ignore aliases + return (canonical, aliases, addresses) + + +def _gethostbyaddr(ip): + try: + dns.ipv6.inet_aton(ip) + sockaddr = (ip, 80, 0, 0) + family = socket.AF_INET6 + except Exception: + try: + dns.ipv4.inet_aton(ip) + except Exception: + raise socket.gaierror(socket.EAI_NONAME, "Name or service not known") + sockaddr = (ip, 80) + family = socket.AF_INET + (name, _) = _getnameinfo(sockaddr, socket.NI_NAMEREQD) + aliases = [] + addresses = [] + tuples = _getaddrinfo( + name, 0, family, socket.SOCK_STREAM, socket.SOL_TCP, socket.AI_CANONNAME + ) + canonical = tuples[0][3] + # We only want to include an address from the tuples if it's the + # same as the one we asked about. We do this comparison in binary + # to avoid any differences in text representations. 
+ bin_ip = dns.inet.inet_pton(family, ip) + for item in tuples: + addr = item[4][0] + bin_addr = dns.inet.inet_pton(family, addr) + if bin_ip == bin_addr: + addresses.append(addr) + # XXX we just ignore aliases + return (canonical, aliases, addresses) + + +def override_system_resolver(resolver: Optional[Resolver] = None) -> None: + """Override the system resolver routines in the socket module with + versions which use dnspython's resolver. + + This can be useful in testing situations where you want to control + the resolution behavior of python code without having to change + the system's resolver settings (e.g. /etc/resolv.conf). + + The resolver to use may be specified; if it's not, the default + resolver will be used. + + resolver, a ``dns.resolver.Resolver`` or ``None``, the resolver to use. + """ + + if resolver is None: + resolver = get_default_resolver() + global _resolver + _resolver = resolver + socket.getaddrinfo = _getaddrinfo + socket.getnameinfo = _getnameinfo + socket.getfqdn = _getfqdn + socket.gethostbyname = _gethostbyname + socket.gethostbyname_ex = _gethostbyname_ex + socket.gethostbyaddr = _gethostbyaddr + + +def restore_system_resolver() -> None: + """Undo the effects of prior override_system_resolver().""" + + global _resolver + _resolver = None + socket.getaddrinfo = _original_getaddrinfo + socket.getnameinfo = _original_getnameinfo + socket.getfqdn = _original_getfqdn + socket.gethostbyname = _original_gethostbyname + socket.gethostbyname_ex = _original_gethostbyname_ex + socket.gethostbyaddr = _original_gethostbyaddr diff --git a/venv/Lib/site-packages/dns/reversename.py b/venv/Lib/site-packages/dns/reversename.py new file mode 100644 index 00000000..8236c711 --- /dev/null +++ b/venv/Lib/site-packages/dns/reversename.py @@ -0,0 +1,105 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2006-2017 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +"""DNS Reverse Map Names.""" + +import binascii + +import dns.ipv4 +import dns.ipv6 +import dns.name + +ipv4_reverse_domain = dns.name.from_text("in-addr.arpa.") +ipv6_reverse_domain = dns.name.from_text("ip6.arpa.") + + +def from_address( + text: str, + v4_origin: dns.name.Name = ipv4_reverse_domain, + v6_origin: dns.name.Name = ipv6_reverse_domain, +) -> dns.name.Name: + """Convert an IPv4 or IPv6 address in textual form into a Name object whose + value is the reverse-map domain name of the address. + + *text*, a ``str``, is an IPv4 or IPv6 address in textual form + (e.g. '127.0.0.1', '::1') + + *v4_origin*, a ``dns.name.Name`` to append to the labels corresponding to + the address if the address is an IPv4 address, instead of the default + (in-addr.arpa.) 
+
+    *v6_origin*, a ``dns.name.Name`` to append to the labels corresponding to
+    the address if the address is an IPv6 address, instead of the default
+    (ip6.arpa.)
+
+    Raises ``dns.exception.SyntaxError`` if the address is badly formed.
+
+    Returns a ``dns.name.Name``.
+    """
+
+    try:
+        v6 = dns.ipv6.inet_aton(text)
+        if dns.ipv6.is_mapped(v6):
+            parts = ["%d" % byte for byte in v6[12:]]
+            origin = v4_origin
+        else:
+            parts = [x for x in str(binascii.hexlify(v6).decode())]
+            origin = v6_origin
+    except Exception:
+        parts = ["%d" % byte for byte in dns.ipv4.inet_aton(text)]
+        origin = v4_origin
+    return dns.name.from_text(".".join(reversed(parts)), origin=origin)
+
+
+def to_address(
+    name: dns.name.Name,
+    v4_origin: dns.name.Name = ipv4_reverse_domain,
+    v6_origin: dns.name.Name = ipv6_reverse_domain,
+) -> str:
+    """Convert a reverse map domain name into textual address form.
+
+    *name*, a ``dns.name.Name``, an IPv4 or IPv6 address in reverse-map name
+    form.
+
+    *v4_origin*, a ``dns.name.Name`` representing the top-level domain for
+    IPv4 addresses, instead of the default (in-addr.arpa.)
+
+    *v6_origin*, a ``dns.name.Name`` representing the top-level domain for
+    IPv6 addresses, instead of the default (ip6.arpa.)
+
+    Raises ``dns.exception.SyntaxError`` if the name does not have a
+    reverse-map form.
+
+    Returns a ``str``.
+    """
+
+    if name.is_subdomain(v4_origin):
+        name = name.relativize(v4_origin)
+        text = b".".join(reversed(name.labels))
+        # run through inet_ntoa() to check syntax and make pretty.
+        return dns.ipv4.inet_ntoa(dns.ipv4.inet_aton(text))
+    elif name.is_subdomain(v6_origin):
+        name = name.relativize(v6_origin)
+        labels = list(reversed(name.labels))
+        parts = []
+        for i in range(0, len(labels), 4):
+            parts.append(b"".join(labels[i : i + 4]))
+        text = b":".join(parts)
+        # run through inet_ntoa() to check syntax and make pretty.
+        return dns.ipv6.inet_ntoa(dns.ipv6.inet_aton(text))
+    else:
+        raise dns.exception.SyntaxError("unknown reverse-map address family")
diff --git a/venv/Lib/site-packages/dns/rrset.py b/venv/Lib/site-packages/dns/rrset.py
new file mode 100644
index 00000000..6f39b108
--- /dev/null
+++ b/venv/Lib/site-packages/dns/rrset.py
@@ -0,0 +1,285 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2003-2017 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
+"""DNS RRsets (an RRset is a named rdataset)"""
+
+from typing import Any, Collection, Dict, Optional, Union, cast
+
+import dns.name
+import dns.rdataclass
+import dns.rdataset
+import dns.renderer
+
+
+class RRset(dns.rdataset.Rdataset):
+    """A DNS RRset (named rdataset).
+
+    RRset inherits from Rdataset, and RRsets can be treated as
+    Rdatasets in most cases. There are, however, a few notable
+    exceptions.
RRsets have different to_wire() and to_text() method
+    arguments, reflecting the fact that RRsets always have an owner
+    name.
+    """
+
+    __slots__ = ["name", "deleting"]
+
+    def __init__(
+        self,
+        name: dns.name.Name,
+        rdclass: dns.rdataclass.RdataClass,
+        rdtype: dns.rdatatype.RdataType,
+        covers: dns.rdatatype.RdataType = dns.rdatatype.NONE,
+        deleting: Optional[dns.rdataclass.RdataClass] = None,
+    ):
+        """Create a new RRset."""
+
+        super().__init__(rdclass, rdtype, covers)
+        self.name = name
+        self.deleting = deleting
+
+    def _clone(self):
+        obj = super()._clone()
+        obj.name = self.name
+        obj.deleting = self.deleting
+        return obj
+
+    def __repr__(self):
+        if self.covers == 0:
+            ctext = ""
+        else:
+            ctext = "(" + dns.rdatatype.to_text(self.covers) + ")"
+        if self.deleting is not None:
+            dtext = " delete=" + dns.rdataclass.to_text(self.deleting)
+        else:
+            dtext = ""
+        return (
+            "<DNS "
+            + str(self.name)
+            + " "
+            + dns.rdataclass.to_text(self.rdclass)
+            + " "
+            + dns.rdatatype.to_text(self.rdtype)
+            + ctext
+            + dtext
+            + " RRset: "
+            + self._rdata_repr()
+            + ">"
+        )
+
+    def __str__(self):
+        return self.to_text()
+
+    def __eq__(self, other):
+        if isinstance(other, RRset):
+            if self.name != other.name:
+                return False
+        elif not isinstance(other, dns.rdataset.Rdataset):
+            return False
+        return super().__eq__(other)
+
+    def match(self, *args: Any, **kwargs: Any) -> bool:  # type: ignore[override]
+        """Does this rrset match the specified attributes?
+
+        Behaves as :py:func:`full_match()` if the first argument is a
+        ``dns.name.Name``, and as :py:func:`dns.rdataset.Rdataset.match()`
+        otherwise.
+
+        (This behavior fixes a design mistake where the signature of this
+        method became incompatible with that of its superclass. The fix
+        makes RRsets matchable as Rdatasets while preserving backwards
+        compatibility.)
+        """
+        if isinstance(args[0], dns.name.Name):
+            return self.full_match(*args, **kwargs)  # type: ignore[arg-type]
+        else:
+            return super().match(*args, **kwargs)  # type: ignore[arg-type]
+
+    def full_match(
+        self,
+        name: dns.name.Name,
+        rdclass: dns.rdataclass.RdataClass,
+        rdtype: dns.rdatatype.RdataType,
+        covers: dns.rdatatype.RdataType,
+        deleting: Optional[dns.rdataclass.RdataClass] = None,
+    ) -> bool:
+        """Returns ``True`` if this rrset matches the specified name, class,
+        type, covers, and deletion state.
+        """
+        if not super().match(rdclass, rdtype, covers):
+            return False
+        if self.name != name or self.deleting != deleting:
+            return False
+        return True
+
+    # pylint: disable=arguments-differ
+
+    def to_text(  # type: ignore[override]
+        self,
+        origin: Optional[dns.name.Name] = None,
+        relativize: bool = True,
+        **kw: Dict[str, Any],
+    ) -> str:
+        """Convert the RRset into DNS zone file format.
+
+        See ``dns.name.Name.choose_relativity`` for more information
+        on how *origin* and *relativize* determine the way names
+        are emitted.
+
+        Any additional keyword arguments are passed on to the rdata
+        ``to_text()`` method.
+
+        *origin*, a ``dns.name.Name`` or ``None``, the origin for relative
+        names.
+
+        *relativize*, a ``bool``. If ``True``, names will be relativized
+        to *origin*.
+        """
+
+        return super().to_text(
+            self.name, origin, relativize, self.deleting, **kw  # type: ignore
+        )
+
+    def to_wire(  # type: ignore[override]
+        self,
+        file: Any,
+        compress: Optional[dns.name.CompressType] = None,  # type: ignore
+        origin: Optional[dns.name.Name] = None,
+        **kw: Dict[str, Any],
+    ) -> int:
+        """Convert the RRset to wire format.
+
+        All keyword arguments are passed to ``dns.rdataset.to_wire()``; see
+        that function for details.
+
+        Returns an ``int``, the number of records emitted.
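+
+        For example (a sketch; ``rrs`` stands for any existing RRset)::
+
+            import io
+
+            out = io.BytesIO()
+            count = rrs.to_wire(out)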
+ """ + + return super().to_wire( + self.name, file, compress, origin, self.deleting, **kw # type:ignore + ) + + # pylint: enable=arguments-differ + + def to_rdataset(self) -> dns.rdataset.Rdataset: + """Convert an RRset into an Rdataset. + + Returns a ``dns.rdataset.Rdataset``. + """ + return dns.rdataset.from_rdata_list(self.ttl, list(self)) + + +def from_text_list( + name: Union[dns.name.Name, str], + ttl: int, + rdclass: Union[dns.rdataclass.RdataClass, str], + rdtype: Union[dns.rdatatype.RdataType, str], + text_rdatas: Collection[str], + idna_codec: Optional[dns.name.IDNACodec] = None, + origin: Optional[dns.name.Name] = None, + relativize: bool = True, + relativize_to: Optional[dns.name.Name] = None, +) -> RRset: + """Create an RRset with the specified name, TTL, class, and type, and with + the specified list of rdatas in text format. + + *idna_codec*, a ``dns.name.IDNACodec``, specifies the IDNA + encoder/decoder to use; if ``None``, the default IDNA 2003 + encoder/decoder is used. + + *origin*, a ``dns.name.Name`` (or ``None``), the + origin to use for relative names. + + *relativize*, a ``bool``. If true, name will be relativized. + + *relativize_to*, a ``dns.name.Name`` (or ``None``), the origin to use + when relativizing names. If not set, the *origin* value will be used. + + Returns a ``dns.rrset.RRset`` object. + """ + + if isinstance(name, str): + name = dns.name.from_text(name, None, idna_codec=idna_codec) + rdclass = dns.rdataclass.RdataClass.make(rdclass) + rdtype = dns.rdatatype.RdataType.make(rdtype) + r = RRset(name, rdclass, rdtype) + r.update_ttl(ttl) + for t in text_rdatas: + rd = dns.rdata.from_text( + r.rdclass, r.rdtype, t, origin, relativize, relativize_to, idna_codec + ) + r.add(rd) + return r + + +def from_text( + name: Union[dns.name.Name, str], + ttl: int, + rdclass: Union[dns.rdataclass.RdataClass, str], + rdtype: Union[dns.rdatatype.RdataType, str], + *text_rdatas: Any, +) -> RRset: + """Create an RRset with the specified name, TTL, class, and type and with + the specified rdatas in text format. + + Returns a ``dns.rrset.RRset`` object. + """ + + return from_text_list( + name, ttl, rdclass, rdtype, cast(Collection[str], text_rdatas) + ) + + +def from_rdata_list( + name: Union[dns.name.Name, str], + ttl: int, + rdatas: Collection[dns.rdata.Rdata], + idna_codec: Optional[dns.name.IDNACodec] = None, +) -> RRset: + """Create an RRset with the specified name and TTL, and with + the specified list of rdata objects. + + *idna_codec*, a ``dns.name.IDNACodec``, specifies the IDNA + encoder/decoder to use; if ``None``, the default IDNA 2003 + encoder/decoder is used. + + Returns a ``dns.rrset.RRset`` object. + + """ + + if isinstance(name, str): + name = dns.name.from_text(name, None, idna_codec=idna_codec) + + if len(rdatas) == 0: + raise ValueError("rdata list must not be empty") + r = None + for rd in rdatas: + if r is None: + r = RRset(name, rd.rdclass, rd.rdtype) + r.update_ttl(ttl) + r.add(rd) + assert r is not None + return r + + +def from_rdata(name: Union[dns.name.Name, str], ttl: int, *rdatas: Any) -> RRset: + """Create an RRset with the specified name and TTL, and with + the specified rdata objects. + + Returns a ``dns.rrset.RRset`` object. 
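+
+    For example (an illustrative sketch)::
+
+        import dns.rdata
+        import dns.rrset
+
+        a = dns.rdata.from_text("IN", "A", "10.0.0.1")
+        rrs = dns.rrset.from_rdata("www.example.", 300, a)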
+    """
+
+    return from_rdata_list(name, ttl, cast(Collection[dns.rdata.Rdata], rdatas))
diff --git a/venv/Lib/site-packages/dns/serial.py b/venv/Lib/site-packages/dns/serial.py
new file mode 100644
index 00000000..3417299b
--- /dev/null
+++ b/venv/Lib/site-packages/dns/serial.py
@@ -0,0 +1,118 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+"""Serial Number Arithmetic from RFC 1982"""
+
+
+class Serial:
+    def __init__(self, value: int, bits: int = 32):
+        self.value = value % 2**bits
+        self.bits = bits
+
+    def __repr__(self):
+        return f"dns.serial.Serial({self.value}, {self.bits})"
+
+    def __eq__(self, other):
+        if isinstance(other, int):
+            other = Serial(other, self.bits)
+        elif not isinstance(other, Serial) or other.bits != self.bits:
+            return NotImplemented
+        return self.value == other.value
+
+    def __ne__(self, other):
+        if isinstance(other, int):
+            other = Serial(other, self.bits)
+        elif not isinstance(other, Serial) or other.bits != self.bits:
+            return NotImplemented
+        return self.value != other.value
+
+    def __lt__(self, other):
+        if isinstance(other, int):
+            other = Serial(other, self.bits)
+        elif not isinstance(other, Serial) or other.bits != self.bits:
+            return NotImplemented
+        if self.value < other.value and other.value - self.value < 2 ** (self.bits - 1):
+            return True
+        elif self.value > other.value and self.value - other.value > 2 ** (
+            self.bits - 1
+        ):
+            return True
+        else:
+            return False
+
+    def __le__(self, other):
+        return self == other or self < other
+
+    def __gt__(self, other):
+        if isinstance(other, int):
+            other = Serial(other, self.bits)
+        elif not isinstance(other, Serial) or other.bits != self.bits:
+            return NotImplemented
+        if self.value < other.value and other.value - self.value > 2 ** (self.bits - 1):
+            return True
+        elif self.value > other.value and self.value - other.value < 2 ** (
+            self.bits - 1
+        ):
+            return True
+        else:
+            return False
+
+    def __ge__(self, other):
+        return self == other or self > other
+
+    def __add__(self, other):
+        v = self.value
+        if isinstance(other, Serial):
+            delta = other.value
+        elif isinstance(other, int):
+            delta = other
+        else:
+            raise ValueError
+        if abs(delta) > (2 ** (self.bits - 1) - 1):
+            raise ValueError
+        v += delta
+        v = v % 2**self.bits
+        return Serial(v, self.bits)
+
+    def __iadd__(self, other):
+        v = self.value
+        if isinstance(other, Serial):
+            delta = other.value
+        elif isinstance(other, int):
+            delta = other
+        else:
+            raise ValueError
+        if abs(delta) > (2 ** (self.bits - 1) - 1):
+            raise ValueError
+        v += delta
+        v = v % 2**self.bits
+        self.value = v
+        return self
+
+    def __sub__(self, other):
+        v = self.value
+        if isinstance(other, Serial):
+            delta = other.value
+        elif isinstance(other, int):
+            delta = other
+        else:
+            raise ValueError
+        if abs(delta) > (2 ** (self.bits - 1) - 1):
+            raise ValueError
+        v -= delta
+        v = v % 2**self.bits
+        return Serial(v, self.bits)
+
+    def __isub__(self, other):
+        v = self.value
+        if isinstance(other, Serial):
+            delta = other.value
+        elif isinstance(other, int):
+            delta = other
+        else:
+            raise ValueError
+        if abs(delta) > (2 ** (self.bits - 1) - 1):
+            raise ValueError
+        v -= delta
+        v = v % 2**self.bits
+        self.value = v
+        return self
diff --git a/venv/Lib/site-packages/dns/set.py b/venv/Lib/site-packages/dns/set.py
new file mode 100644
index 00000000..f0fb0d50
--- /dev/null
+++ b/venv/Lib/site-packages/dns/set.py
@@ -0,0 +1,307 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2003-2017 Nominum, Inc.
+# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +import itertools + + +class Set: + """A simple set class. + + This class was originally used to deal with sets being missing in + ancient versions of python, but dnspython will continue to use it + as these sets are based on lists and are thus indexable, and this + ability is widely used in dnspython applications. + """ + + __slots__ = ["items"] + + def __init__(self, items=None): + """Initialize the set. + + *items*, an iterable or ``None``, the initial set of items. + """ + + self.items = dict() + if items is not None: + for item in items: + # This is safe for how we use set, but if other code + # subclasses it could be a legitimate issue. + self.add(item) # lgtm[py/init-calls-subclass] + + def __repr__(self): + return "dns.set.Set(%s)" % repr(list(self.items.keys())) + + def add(self, item): + """Add an item to the set.""" + + if item not in self.items: + self.items[item] = None + + def remove(self, item): + """Remove an item from the set.""" + + try: + del self.items[item] + except KeyError: + raise ValueError + + def discard(self, item): + """Remove an item from the set if present.""" + + self.items.pop(item, None) + + def pop(self): + """Remove an arbitrary item from the set.""" + (k, _) = self.items.popitem() + return k + + def _clone(self) -> "Set": + """Make a (shallow) copy of the set. + + There is a 'clone protocol' that subclasses of this class + should use. To make a copy, first call your super's _clone() + method, and use the object returned as the new instance. Then + make shallow copies of the attributes defined in the subclass. + + This protocol allows us to write the set algorithms that + return new instances (e.g. union) once, and keep using them in + subclasses. + """ + + if hasattr(self, "_clone_class"): + cls = self._clone_class # type: ignore + else: + cls = self.__class__ + obj = cls.__new__(cls) + obj.items = dict() + obj.items.update(self.items) + return obj + + def __copy__(self): + """Make a (shallow) copy of the set.""" + + return self._clone() + + def copy(self): + """Make a (shallow) copy of the set.""" + + return self._clone() + + def union_update(self, other): + """Update the set, adding any elements from other which are not + already in the set. + """ + + if not isinstance(other, Set): + raise ValueError("other must be a Set instance") + if self is other: # lgtm[py/comparison-using-is] + return + for item in other.items: + self.add(item) + + def intersection_update(self, other): + """Update the set, removing any elements from other which are not + in both sets. + """ + + if not isinstance(other, Set): + raise ValueError("other must be a Set instance") + if self is other: # lgtm[py/comparison-using-is] + return + # we make a copy of the list so that we can remove items from + # the list without breaking the iterator. 
+        for item in list(self.items):
+            if item not in other.items:
+                del self.items[item]
+
+    def difference_update(self, other):
+        """Update the set, removing any elements from other which are in
+        the set.
+        """
+
+        if not isinstance(other, Set):
+            raise ValueError("other must be a Set instance")
+        if self is other:  # lgtm[py/comparison-using-is]
+            self.items.clear()
+        else:
+            for item in other.items:
+                self.discard(item)
+
+    def symmetric_difference_update(self, other):
+        """Update the set, retaining only elements unique to both sets."""
+
+        if not isinstance(other, Set):
+            raise ValueError("other must be a Set instance")
+        if self is other:  # lgtm[py/comparison-using-is]
+            self.items.clear()
+        else:
+            overlap = self.intersection(other)
+            self.union_update(other)
+            self.difference_update(overlap)
+
+    def union(self, other):
+        """Return a new set which is the union of ``self`` and ``other``.
+
+        Returns the same Set type as this set.
+        """
+
+        obj = self._clone()
+        obj.union_update(other)
+        return obj
+
+    def intersection(self, other):
+        """Return a new set which is the intersection of ``self`` and
+        ``other``.
+
+        Returns the same Set type as this set.
+        """
+
+        obj = self._clone()
+        obj.intersection_update(other)
+        return obj
+
+    def difference(self, other):
+        """Return a new set which is ``self`` - ``other``, i.e. the items
+        in ``self`` which are not also in ``other``.
+
+        Returns the same Set type as this set.
+        """
+
+        obj = self._clone()
+        obj.difference_update(other)
+        return obj
+
+    def symmetric_difference(self, other):
+        """Return a new set which is (``self`` - ``other``) | (``other``
+        - ``self``), i.e. the items in either ``self`` or ``other`` which
+        are not contained in their intersection.
+
+        Returns the same Set type as this set.
+        """
+
+        obj = self._clone()
+        obj.symmetric_difference_update(other)
+        return obj
+
+    def __or__(self, other):
+        return self.union(other)
+
+    def __and__(self, other):
+        return self.intersection(other)
+
+    def __add__(self, other):
+        return self.union(other)
+
+    def __sub__(self, other):
+        return self.difference(other)
+
+    def __xor__(self, other):
+        return self.symmetric_difference(other)
+
+    def __ior__(self, other):
+        self.union_update(other)
+        return self
+
+    def __iand__(self, other):
+        self.intersection_update(other)
+        return self
+
+    def __iadd__(self, other):
+        self.union_update(other)
+        return self
+
+    def __isub__(self, other):
+        self.difference_update(other)
+        return self
+
+    def __ixor__(self, other):
+        self.symmetric_difference_update(other)
+        return self
+
+    def update(self, other):
+        """Update the set, adding any elements from other which are not
+        already in the set.
+
+        *other*, the collection of items with which to update the set, which
+        may be any iterable type.
+        """
+
+        for item in other:
+            self.add(item)
+
+    def clear(self):
+        """Make the set empty."""
+        self.items.clear()
+
+    def __eq__(self, other):
+        return self.items == other.items
+
+    def __ne__(self, other):
+        return not self.__eq__(other)
+
+    def __len__(self):
+        return len(self.items)
+
+    def __iter__(self):
+        return iter(self.items)
+
+    def __getitem__(self, i):
+        if isinstance(i, slice):
+            return list(itertools.islice(self.items, i.start, i.stop, i.step))
+        else:
+            return next(itertools.islice(self.items, i, i + 1))
+
+    def __delitem__(self, i):
+        if isinstance(i, slice):
+            for elt in list(self[i]):
+                del self.items[elt]
+        else:
+            del self.items[self[i]]
+
+    def issubset(self, other):
+        """Is this set a subset of *other*?
+
+        Returns a ``bool``.
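+
+        For example (a sketch)::
+
+            a = Set([1, 2])
+            b = Set([1, 2, 3])
+            assert a.issubset(b) and not b.issubset(a)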
+ """ + + if not isinstance(other, Set): + raise ValueError("other must be a Set instance") + for item in self.items: + if item not in other.items: + return False + return True + + def issuperset(self, other): + """Is this set a superset of *other*? + + Returns a ``bool``. + """ + + if not isinstance(other, Set): + raise ValueError("other must be a Set instance") + for item in other.items: + if item not in self.items: + return False + return True + + def isdisjoint(self, other): + if not isinstance(other, Set): + raise ValueError("other must be a Set instance") + for item in other.items: + if item in self.items: + return False + return True diff --git a/venv/Lib/site-packages/dns/tokenizer.py b/venv/Lib/site-packages/dns/tokenizer.py new file mode 100644 index 00000000..454cac4a --- /dev/null +++ b/venv/Lib/site-packages/dns/tokenizer.py @@ -0,0 +1,708 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2017 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +"""Tokenize DNS zone file format""" + +import io +import sys +from typing import Any, List, Optional, Tuple + +import dns.exception +import dns.name +import dns.ttl + +_DELIMITERS = {" ", "\t", "\n", ";", "(", ")", '"'} +_QUOTING_DELIMITERS = {'"'} + +EOF = 0 +EOL = 1 +WHITESPACE = 2 +IDENTIFIER = 3 +QUOTED_STRING = 4 +COMMENT = 5 +DELIMITER = 6 + + +class UngetBufferFull(dns.exception.DNSException): + """An attempt was made to unget a token when the unget buffer was full.""" + + +class Token: + """A DNS zone file format token. + + ttype: The token type + value: The token value + has_escape: Does the token value contain escapes? 
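+    comment: An associated comment, if any, or None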
+ """ + + def __init__( + self, + ttype: int, + value: Any = "", + has_escape: bool = False, + comment: Optional[str] = None, + ): + """Initialize a token instance.""" + + self.ttype = ttype + self.value = value + self.has_escape = has_escape + self.comment = comment + + def is_eof(self) -> bool: + return self.ttype == EOF + + def is_eol(self) -> bool: + return self.ttype == EOL + + def is_whitespace(self) -> bool: + return self.ttype == WHITESPACE + + def is_identifier(self) -> bool: + return self.ttype == IDENTIFIER + + def is_quoted_string(self) -> bool: + return self.ttype == QUOTED_STRING + + def is_comment(self) -> bool: + return self.ttype == COMMENT + + def is_delimiter(self) -> bool: # pragma: no cover (we don't return delimiters yet) + return self.ttype == DELIMITER + + def is_eol_or_eof(self) -> bool: + return self.ttype == EOL or self.ttype == EOF + + def __eq__(self, other): + if not isinstance(other, Token): + return False + return self.ttype == other.ttype and self.value == other.value + + def __ne__(self, other): + if not isinstance(other, Token): + return True + return self.ttype != other.ttype or self.value != other.value + + def __str__(self): + return '%d "%s"' % (self.ttype, self.value) + + def unescape(self) -> "Token": + if not self.has_escape: + return self + unescaped = "" + l = len(self.value) + i = 0 + while i < l: + c = self.value[i] + i += 1 + if c == "\\": + if i >= l: # pragma: no cover (can't happen via get()) + raise dns.exception.UnexpectedEnd + c = self.value[i] + i += 1 + if c.isdigit(): + if i >= l: + raise dns.exception.UnexpectedEnd + c2 = self.value[i] + i += 1 + if i >= l: + raise dns.exception.UnexpectedEnd + c3 = self.value[i] + i += 1 + if not (c2.isdigit() and c3.isdigit()): + raise dns.exception.SyntaxError + codepoint = int(c) * 100 + int(c2) * 10 + int(c3) + if codepoint > 255: + raise dns.exception.SyntaxError + c = chr(codepoint) + unescaped += c + return Token(self.ttype, unescaped) + + def unescape_to_bytes(self) -> "Token": + # We used to use unescape() for TXT-like records, but this + # caused problems as we'd process DNS escapes into Unicode code + # points instead of byte values, and then a to_text() of the + # processed data would not equal the original input. For + # example, \226 in the TXT record would have a to_text() of + # \195\162 because we applied UTF-8 encoding to Unicode code + # point 226. + # + # We now apply escapes while converting directly to bytes, + # avoiding this double encoding. + # + # This code also handles cases where the unicode input has + # non-ASCII code-points in it by converting it to UTF-8. TXT + # records aren't defined for Unicode, but this is the best we + # can do to preserve meaning. 
For example, + # + # foo\u200bbar + # + # (where \u200b is Unicode code point 0x200b) will be treated + # as if the input had been the UTF-8 encoding of that string, + # namely: + # + # foo\226\128\139bar + # + unescaped = b"" + l = len(self.value) + i = 0 + while i < l: + c = self.value[i] + i += 1 + if c == "\\": + if i >= l: # pragma: no cover (can't happen via get()) + raise dns.exception.UnexpectedEnd + c = self.value[i] + i += 1 + if c.isdigit(): + if i >= l: + raise dns.exception.UnexpectedEnd + c2 = self.value[i] + i += 1 + if i >= l: + raise dns.exception.UnexpectedEnd + c3 = self.value[i] + i += 1 + if not (c2.isdigit() and c3.isdigit()): + raise dns.exception.SyntaxError + codepoint = int(c) * 100 + int(c2) * 10 + int(c3) + if codepoint > 255: + raise dns.exception.SyntaxError + unescaped += b"%c" % (codepoint) + else: + # Note that as mentioned above, if c is a Unicode + # code point outside of the ASCII range, then this + # += is converting that code point to its UTF-8 + # encoding and appending multiple bytes to + # unescaped. + unescaped += c.encode() + else: + unescaped += c.encode() + return Token(self.ttype, bytes(unescaped)) + + +class Tokenizer: + """A DNS zone file format tokenizer. + + A token object is basically a (type, value) tuple. The valid + types are EOF, EOL, WHITESPACE, IDENTIFIER, QUOTED_STRING, + COMMENT, and DELIMITER. + + file: The file to tokenize + + ungotten_char: The most recently ungotten character, or None. + + ungotten_token: The most recently ungotten token, or None. + + multiline: The current multiline level. This value is increased + by one every time a '(' delimiter is read, and decreased by one every time + a ')' delimiter is read. + + quoting: This variable is true if the tokenizer is currently + reading a quoted string. + + eof: This variable is true if the tokenizer has encountered EOF. + + delimiters: The current delimiter dictionary. + + line_number: The current line number + + filename: A filename that will be returned by the where() method. + + idna_codec: A dns.name.IDNACodec, specifies the IDNA + encoder/decoder. If None, the default IDNA 2003 + encoder/decoder is used. + """ + + def __init__( + self, + f: Any = sys.stdin, + filename: Optional[str] = None, + idna_codec: Optional[dns.name.IDNACodec] = None, + ): + """Initialize a tokenizer instance. + + f: The file to tokenize. The default is sys.stdin. + This parameter may also be a string, in which case the tokenizer + will take its input from the contents of the string. + + filename: the name of the filename that the where() method + will return. + + idna_codec: A dns.name.IDNACodec, specifies the IDNA + encoder/decoder. If None, the default IDNA 2003 + encoder/decoder is used. 
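+
+        For example (an illustrative sketch)::
+
+            tok = Tokenizer("example. 3600 IN NS ns.example.")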
+        """
+
+        if isinstance(f, str):
+            f = io.StringIO(f)
+            if filename is None:
+                filename = "<string>"
+        elif isinstance(f, bytes):
+            f = io.StringIO(f.decode())
+            if filename is None:
+                filename = "<string>"
+        else:
+            if filename is None:
+                if f is sys.stdin:
+                    filename = "<stdin>"
+                else:
+                    filename = "<file>"
+        self.file = f
+        self.ungotten_char: Optional[str] = None
+        self.ungotten_token: Optional[Token] = None
+        self.multiline = 0
+        self.quoting = False
+        self.eof = False
+        self.delimiters = _DELIMITERS
+        self.line_number = 1
+        assert filename is not None
+        self.filename = filename
+        if idna_codec is None:
+            self.idna_codec: dns.name.IDNACodec = dns.name.IDNA_2003
+        else:
+            self.idna_codec = idna_codec
+
+    def _get_char(self) -> str:
+        """Read a character from input."""
+
+        if self.ungotten_char is None:
+            if self.eof:
+                c = ""
+            else:
+                c = self.file.read(1)
+                if c == "":
+                    self.eof = True
+                elif c == "\n":
+                    self.line_number += 1
+        else:
+            c = self.ungotten_char
+            self.ungotten_char = None
+        return c
+
+    def where(self) -> Tuple[str, int]:
+        """Return the current location in the input.
+
+        Returns a (string, int) tuple. The first item is the filename of
+        the input, the second is the current line number.
+        """
+
+        return (self.filename, self.line_number)
+
+    def _unget_char(self, c: str) -> None:
+        """Unget a character.
+
+        The unget buffer for characters is only one character large; it is
+        an error to try to unget a character when the unget buffer is not
+        empty.
+
+        c: the character to unget
+        raises UngetBufferFull: there is already an ungotten char
+        """
+
+        if self.ungotten_char is not None:
+            # this should never happen!
+            raise UngetBufferFull  # pragma: no cover
+        self.ungotten_char = c
+
+    def skip_whitespace(self) -> int:
+        """Consume input until a non-whitespace character is encountered.
+
+        The non-whitespace character is then ungotten, and the number of
+        whitespace characters consumed is returned.
+
+        If the tokenizer is in multiline mode, then newlines are whitespace.
+
+        Returns the number of characters skipped.
+        """
+
+        skipped = 0
+        while True:
+            c = self._get_char()
+            if c != " " and c != "\t":
+                if (c != "\n") or not self.multiline:
+                    self._unget_char(c)
+                    return skipped
+            skipped += 1
+
+    def get(self, want_leading: bool = False, want_comment: bool = False) -> Token:
+        """Get the next token.
+
+        want_leading: If True, return a WHITESPACE token if the
+        first character read is whitespace. The default is False.
+
+        want_comment: If True, return a COMMENT token if the
+        first token read is a comment. The default is False.
+
+        Raises dns.exception.UnexpectedEnd: input ended prematurely
+
+        Raises dns.exception.SyntaxError: input was badly formed
+
+        Returns a Token.
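+
+        A sketch of a typical read loop::
+
+            tok = Tokenizer("one two three")
+            while True:
+                token = tok.get()
+                if token.is_eol_or_eof():
+                    break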
+ """ + + if self.ungotten_token is not None: + utoken = self.ungotten_token + self.ungotten_token = None + if utoken.is_whitespace(): + if want_leading: + return utoken + elif utoken.is_comment(): + if want_comment: + return utoken + else: + return utoken + skipped = self.skip_whitespace() + if want_leading and skipped > 0: + return Token(WHITESPACE, " ") + token = "" + ttype = IDENTIFIER + has_escape = False + while True: + c = self._get_char() + if c == "" or c in self.delimiters: + if c == "" and self.quoting: + raise dns.exception.UnexpectedEnd + if token == "" and ttype != QUOTED_STRING: + if c == "(": + self.multiline += 1 + self.skip_whitespace() + continue + elif c == ")": + if self.multiline <= 0: + raise dns.exception.SyntaxError + self.multiline -= 1 + self.skip_whitespace() + continue + elif c == '"': + if not self.quoting: + self.quoting = True + self.delimiters = _QUOTING_DELIMITERS + ttype = QUOTED_STRING + continue + else: + self.quoting = False + self.delimiters = _DELIMITERS + self.skip_whitespace() + continue + elif c == "\n": + return Token(EOL, "\n") + elif c == ";": + while 1: + c = self._get_char() + if c == "\n" or c == "": + break + token += c + if want_comment: + self._unget_char(c) + return Token(COMMENT, token) + elif c == "": + if self.multiline: + raise dns.exception.SyntaxError( + "unbalanced parentheses" + ) + return Token(EOF, comment=token) + elif self.multiline: + self.skip_whitespace() + token = "" + continue + else: + return Token(EOL, "\n", comment=token) + else: + # This code exists in case we ever want a + # delimiter to be returned. It never produces + # a token currently. + token = c + ttype = DELIMITER + else: + self._unget_char(c) + break + elif self.quoting and c == "\n": + raise dns.exception.SyntaxError("newline in quoted string") + elif c == "\\": + # + # It's an escape. Put it and the next character into + # the token; it will be checked later for goodness. + # + token += c + has_escape = True + c = self._get_char() + if c == "" or (c == "\n" and not self.quoting): + raise dns.exception.UnexpectedEnd + token += c + if token == "" and ttype != QUOTED_STRING: + if self.multiline: + raise dns.exception.SyntaxError("unbalanced parentheses") + ttype = EOF + return Token(ttype, token, has_escape) + + def unget(self, token: Token) -> None: + """Unget a token. + + The unget buffer for tokens is only one token large; it is + an error to try to unget a token when the unget buffer is not + empty. + + token: the token to unget + + Raises UngetBufferFull: there is already an ungotten token + """ + + if self.ungotten_token is not None: + raise UngetBufferFull + self.ungotten_token = token + + def next(self): + """Return the next item in an iteration. + + Returns a Token. + """ + + token = self.get() + if token.is_eof(): + raise StopIteration + return token + + __next__ = next + + def __iter__(self): + return self + + # Helpers + + def get_int(self, base: int = 10) -> int: + """Read the next token and interpret it as an unsigned integer. + + Raises dns.exception.SyntaxError if not an unsigned integer. + + Returns an int. + """ + + token = self.get().unescape() + if not token.is_identifier(): + raise dns.exception.SyntaxError("expecting an identifier") + if not token.value.isdigit(): + raise dns.exception.SyntaxError("expecting an integer") + return int(token.value, base) + + def get_uint8(self) -> int: + """Read the next token and interpret it as an 8-bit unsigned + integer. + + Raises dns.exception.SyntaxError if not an 8-bit unsigned integer. 
+ + Returns an int. + """ + + value = self.get_int() + if value < 0 or value > 255: + raise dns.exception.SyntaxError( + "%d is not an unsigned 8-bit integer" % value + ) + return value + + def get_uint16(self, base: int = 10) -> int: + """Read the next token and interpret it as a 16-bit unsigned + integer. + + Raises dns.exception.SyntaxError if not a 16-bit unsigned integer. + + Returns an int. + """ + + value = self.get_int(base=base) + if value < 0 or value > 65535: + if base == 8: + raise dns.exception.SyntaxError( + "%o is not an octal unsigned 16-bit integer" % value + ) + else: + raise dns.exception.SyntaxError( + "%d is not an unsigned 16-bit integer" % value + ) + return value + + def get_uint32(self, base: int = 10) -> int: + """Read the next token and interpret it as a 32-bit unsigned + integer. + + Raises dns.exception.SyntaxError if not a 32-bit unsigned integer. + + Returns an int. + """ + + value = self.get_int(base=base) + if value < 0 or value > 4294967295: + raise dns.exception.SyntaxError( + "%d is not an unsigned 32-bit integer" % value + ) + return value + + def get_uint48(self, base: int = 10) -> int: + """Read the next token and interpret it as a 48-bit unsigned + integer. + + Raises dns.exception.SyntaxError if not a 48-bit unsigned integer. + + Returns an int. + """ + + value = self.get_int(base=base) + if value < 0 or value > 281474976710655: + raise dns.exception.SyntaxError( + "%d is not an unsigned 48-bit integer" % value + ) + return value + + def get_string(self, max_length: Optional[int] = None) -> str: + """Read the next token and interpret it as a string. + + Raises dns.exception.SyntaxError if not a string. + Raises dns.exception.SyntaxError if token value length + exceeds max_length (if specified). + + Returns a string. + """ + + token = self.get().unescape() + if not (token.is_identifier() or token.is_quoted_string()): + raise dns.exception.SyntaxError("expecting a string") + if max_length and len(token.value) > max_length: + raise dns.exception.SyntaxError("string too long") + return token.value + + def get_identifier(self) -> str: + """Read the next token, which should be an identifier. + + Raises dns.exception.SyntaxError if not an identifier. + + Returns a string. + """ + + token = self.get().unescape() + if not token.is_identifier(): + raise dns.exception.SyntaxError("expecting an identifier") + return token.value + + def get_remaining(self, max_tokens: Optional[int] = None) -> List[Token]: + """Return the remaining tokens on the line, until an EOL or EOF is seen. + + max_tokens: If not None, stop after this number of tokens. + + Returns a list of tokens. + """ + + tokens = [] + while True: + token = self.get() + if token.is_eol_or_eof(): + self.unget(token) + break + tokens.append(token) + if len(tokens) == max_tokens: + break + return tokens + + def concatenate_remaining_identifiers(self, allow_empty: bool = False) -> str: + """Read the remaining tokens on the line, which should be identifiers. + + Raises dns.exception.SyntaxError if there are no remaining tokens, + unless `allow_empty=True` is given. + + Raises dns.exception.SyntaxError if a token is seen that is not an + identifier. + + Returns a string containing a concatenation of the remaining + identifiers. 
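+
+        For example (a sketch)::
+
+            tok = Tokenizer("abc def")
+            assert tok.concatenate_remaining_identifiers() == "abcdef"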
+        """
+        s = ""
+        while True:
+            token = self.get().unescape()
+            if token.is_eol_or_eof():
+                self.unget(token)
+                break
+            if not token.is_identifier():
+                raise dns.exception.SyntaxError
+            s += token.value
+        if not (allow_empty or s):
+            raise dns.exception.SyntaxError("expecting another identifier")
+        return s
+
+    def as_name(
+        self,
+        token: Token,
+        origin: Optional[dns.name.Name] = None,
+        relativize: bool = False,
+        relativize_to: Optional[dns.name.Name] = None,
+    ) -> dns.name.Name:
+        """Try to interpret the token as a DNS name.
+
+        Raises dns.exception.SyntaxError if not a name.
+
+        Returns a dns.name.Name.
+        """
+        if not token.is_identifier():
+            raise dns.exception.SyntaxError("expecting an identifier")
+        name = dns.name.from_text(token.value, origin, self.idna_codec)
+        return name.choose_relativity(relativize_to or origin, relativize)
+
+    def get_name(
+        self,
+        origin: Optional[dns.name.Name] = None,
+        relativize: bool = False,
+        relativize_to: Optional[dns.name.Name] = None,
+    ) -> dns.name.Name:
+        """Read the next token and interpret it as a DNS name.
+
+        Raises dns.exception.SyntaxError if not a name.
+
+        Returns a dns.name.Name.
+        """
+
+        token = self.get()
+        return self.as_name(token, origin, relativize, relativize_to)
+
+    def get_eol_as_token(self) -> Token:
+        """Read the next token and raise an exception if it isn't EOL or
+        EOF.
+
+        Returns a Token.
+        """
+
+        token = self.get()
+        if not token.is_eol_or_eof():
+            raise dns.exception.SyntaxError(
+                'expected EOL or EOF, got %d "%s"' % (token.ttype, token.value)
+            )
+        return token
+
+    def get_eol(self) -> str:
+        return self.get_eol_as_token().value
+
+    def get_ttl(self) -> int:
+        """Read the next token and interpret it as a DNS TTL.
+
+        Raises dns.exception.SyntaxError or dns.ttl.BadTTL if not an
+        identifier or badly formed.
+
+        Returns an int.
+        """
+
+        token = self.get().unescape()
+        if not token.is_identifier():
+            raise dns.exception.SyntaxError("expecting an identifier")
+        return dns.ttl.from_text(token.value)
diff --git a/venv/Lib/site-packages/dns/transaction.py b/venv/Lib/site-packages/dns/transaction.py
new file mode 100644
index 00000000..84e54f7d
--- /dev/null
+++ b/venv/Lib/site-packages/dns/transaction.py
@@ -0,0 +1,651 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+import collections
+from typing import Any, Callable, Iterator, List, Optional, Tuple, Union
+
+import dns.exception
+import dns.name
+import dns.node
+import dns.rdataclass
+import dns.rdataset
+import dns.rdatatype
+import dns.rrset
+import dns.serial
+import dns.ttl
+
+
+class TransactionManager:
+    def reader(self) -> "Transaction":
+        """Begin a read-only transaction."""
+        raise NotImplementedError  # pragma: no cover
+
+    def writer(self, replacement: bool = False) -> "Transaction":
+        """Begin a writable transaction.
+
+        *replacement*, a ``bool``. If `True`, the content of the
+        transaction completely replaces any prior content. If False,
+        the default, then the content of the transaction updates the
+        existing content.
+        """
+        raise NotImplementedError  # pragma: no cover
+
+    def origin_information(
+        self,
+    ) -> Tuple[Optional[dns.name.Name], bool, Optional[dns.name.Name]]:
+        """Returns a tuple
+
+        (absolute_origin, relativize, effective_origin)
+
+        giving the absolute name of the default origin for any
+        relative domain names, whether names should be relativized,
+        and the "effective origin". The "effective origin" is the
+        absolute origin if relativize is False, and the empty name if
+        relativize is true.
(The effective origin is provided even + though it can be computed from the absolute_origin and + relativize setting because it avoids a lot of code + duplication.) + + If the returned names are `None`, then no origin information is + available. + + This information is used by code working with transactions to + allow it to coordinate relativization. The transaction code + itself takes what it gets (i.e. does not change name + relativity). + + """ + raise NotImplementedError # pragma: no cover + + def get_class(self) -> dns.rdataclass.RdataClass: + """The class of the transaction manager.""" + raise NotImplementedError # pragma: no cover + + def from_wire_origin(self) -> Optional[dns.name.Name]: + """Origin to use in from_wire() calls.""" + (absolute_origin, relativize, _) = self.origin_information() + if relativize: + return absolute_origin + else: + return None + + +class DeleteNotExact(dns.exception.DNSException): + """Existing data did not match data specified by an exact delete.""" + + +class ReadOnly(dns.exception.DNSException): + """Tried to write to a read-only transaction.""" + + +class AlreadyEnded(dns.exception.DNSException): + """Tried to use an already-ended transaction.""" + + +def _ensure_immutable_rdataset(rdataset): + if rdataset is None or isinstance(rdataset, dns.rdataset.ImmutableRdataset): + return rdataset + return dns.rdataset.ImmutableRdataset(rdataset) + + +def _ensure_immutable_node(node): + if node is None or node.is_immutable(): + return node + return dns.node.ImmutableNode(node) + + +CheckPutRdatasetType = Callable[ + ["Transaction", dns.name.Name, dns.rdataset.Rdataset], None +] +CheckDeleteRdatasetType = Callable[ + ["Transaction", dns.name.Name, dns.rdatatype.RdataType, dns.rdatatype.RdataType], + None, +] +CheckDeleteNameType = Callable[["Transaction", dns.name.Name], None] + + +class Transaction: + def __init__( + self, + manager: TransactionManager, + replacement: bool = False, + read_only: bool = False, + ): + self.manager = manager + self.replacement = replacement + self.read_only = read_only + self._ended = False + self._check_put_rdataset: List[CheckPutRdatasetType] = [] + self._check_delete_rdataset: List[CheckDeleteRdatasetType] = [] + self._check_delete_name: List[CheckDeleteNameType] = [] + + # + # This is the high level API + # + # Note that we currently use non-immutable types in the return type signature to + # avoid covariance problems, e.g. if the caller has a List[Rdataset], mypy will be + # unhappy if we return an ImmutableRdataset. + + def get( + self, + name: Optional[Union[dns.name.Name, str]], + rdtype: Union[dns.rdatatype.RdataType, str], + covers: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.NONE, + ) -> dns.rdataset.Rdataset: + """Return the rdataset associated with *name*, *rdtype*, and *covers*, + or `None` if not found. + + Note that the returned rdataset is immutable. + """ + self._check_ended() + if isinstance(name, str): + name = dns.name.from_text(name, None) + rdtype = dns.rdatatype.RdataType.make(rdtype) + covers = dns.rdatatype.RdataType.make(covers) + rdataset = self._get_rdataset(name, rdtype, covers) + return _ensure_immutable_rdataset(rdataset) + + def get_node(self, name: dns.name.Name) -> Optional[dns.node.Node]: + """Return the node at *name*, if any. + + Returns an immutable node or ``None``. + """ + return _ensure_immutable_node(self._get_node(name)) + + def _check_read_only(self) -> None: + if self.read_only: + raise ReadOnly + + def add(self, *args: Any) -> None: + """Add records. 
+ + The arguments may be: + + - rrset + + - name, rdataset... + + - name, ttl, rdata... + """ + self._check_ended() + self._check_read_only() + self._add(False, args) + + def replace(self, *args: Any) -> None: + """Replace the existing rdataset at the name with the specified + rdataset, or add the specified rdataset if there was no existing + rdataset. + + The arguments may be: + + - rrset + + - name, rdataset... + + - name, ttl, rdata... + + Note that if you want to replace the entire node, you should do + a delete of the name followed by one or more calls to add() or + replace(). + """ + self._check_ended() + self._check_read_only() + self._add(True, args) + + def delete(self, *args: Any) -> None: + """Delete records. + + It is not an error if some of the records are not in the existing + set. + + The arguments may be: + + - rrset + + - name + + - name, rdatatype, [covers] + + - name, rdataset... + + - name, rdata... + """ + self._check_ended() + self._check_read_only() + self._delete(False, args) + + def delete_exact(self, *args: Any) -> None: + """Delete records. + + The arguments may be: + + - rrset + + - name + + - name, rdatatype, [covers] + + - name, rdataset... + + - name, rdata... + + Raises dns.transaction.DeleteNotExact if some of the records + are not in the existing set. + + """ + self._check_ended() + self._check_read_only() + self._delete(True, args) + + def name_exists(self, name: Union[dns.name.Name, str]) -> bool: + """Does the specified name exist?""" + self._check_ended() + if isinstance(name, str): + name = dns.name.from_text(name, None) + return self._name_exists(name) + + def update_serial( + self, + value: int = 1, + relative: bool = True, + name: dns.name.Name = dns.name.empty, + ) -> None: + """Update the serial number. + + *value*, an `int`, is an increment if *relative* is `True`, or the + actual value to set if *relative* is `False`. + + Raises `KeyError` if there is no SOA rdataset at *name*. + + Raises `ValueError` if *value* is negative or if the increment is + so large that it would cause the new serial to be less than the + prior value. + """ + self._check_ended() + if value < 0: + raise ValueError("negative update_serial() value") + if isinstance(name, str): + name = dns.name.from_text(name, None) + rdataset = self._get_rdataset(name, dns.rdatatype.SOA, dns.rdatatype.NONE) + if rdataset is None or len(rdataset) == 0: + raise KeyError + if relative: + serial = dns.serial.Serial(rdataset[0].serial) + value + else: + serial = dns.serial.Serial(value) + serial = serial.value # convert back to int + if serial == 0: + serial = 1 + rdata = rdataset[0].replace(serial=serial) + new_rdataset = dns.rdataset.from_rdata(rdataset.ttl, rdata) + self.replace(name, new_rdataset) + + def __iter__(self): + self._check_ended() + return self._iterate_rdatasets() + + def changed(self) -> bool: + """Has this transaction changed anything? + + For read-only transactions, the result is always `False`. + + For writable transactions, the result is `True` if at some time + during the life of the transaction, the content was changed. + """ + self._check_ended() + return self._changed() + + def commit(self) -> None: + """Commit the transaction. + + Normally transactions are used as context managers and commit + or rollback automatically, but it may be done explicitly if needed. + A ``dns.transaction.Ended`` exception will be raised if you try + to use a transaction after it has been committed or rolled back. 
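+
+        A sketch of the usual context-manager pattern (here ``manager``,
+        ``name``, and ``rdata`` are assumed to already exist)::
+
+            with manager.writer() as txn:
+                txn.add(name, 300, rdata)
+            # Leaving the block normally commits; an exception rolls back.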
+
+        Raises an exception if the commit fails (in which case the transaction
+        is also rolled back).
+        """
+        self._end(True)
+
+    def rollback(self) -> None:
+        """Rollback the transaction.
+
+        Normally transactions are used as context managers and commit
+        or rollback automatically, but it may be done explicitly if needed.
+        A ``dns.transaction.AlreadyEnded`` exception will be raised if you try
+        to use a transaction after it has been committed or rolled back.
+
+        Rollback cannot otherwise fail.
+        """
+        self._end(False)
+
+    def check_put_rdataset(self, check: CheckPutRdatasetType) -> None:
+        """Call *check* before putting (storing) an rdataset.
+
+        The function is called with the transaction, the name, and the rdataset.
+
+        The check function may safely make non-mutating transaction method
+        calls, but behavior is undefined if mutating transaction methods are
+        called. The check function should raise an exception if it objects to
+        the put, and otherwise should return ``None``.
+        """
+        self._check_put_rdataset.append(check)
+
+    def check_delete_rdataset(self, check: CheckDeleteRdatasetType) -> None:
+        """Call *check* before deleting an rdataset.
+
+        The function is called with the transaction, the name, the rdatatype,
+        and the covered rdatatype.
+
+        The check function may safely make non-mutating transaction method
+        calls, but behavior is undefined if mutating transaction methods are
+        called. The check function should raise an exception if it objects to
+        the deletion, and otherwise should return ``None``.
+        """
+        self._check_delete_rdataset.append(check)
+
+    def check_delete_name(self, check: CheckDeleteNameType) -> None:
+        """Call *check* before deleting a name.
+
+        The function is called with the transaction and the name.
+
+        The check function may safely make non-mutating transaction method
+        calls, but behavior is undefined if mutating transaction methods are
+        called. The check function should raise an exception if it objects to
+        the deletion, and otherwise should return ``None``.
+        """
+        self._check_delete_name.append(check)
+
+    def iterate_rdatasets(
+        self,
+    ) -> Iterator[Tuple[dns.name.Name, dns.rdataset.Rdataset]]:
+        """Iterate all the rdatasets in the transaction, returning
+        (`dns.name.Name`, `dns.rdataset.Rdataset`) tuples.
+
+        Note that as is usual with python iterators, adding or removing items
+        while iterating will invalidate the iterator and may raise `RuntimeError`
+        or fail to iterate over all entries."""
+        self._check_ended()
+        return self._iterate_rdatasets()
+
+    def iterate_names(self) -> Iterator[dns.name.Name]:
+        """Iterate all the names in the transaction.
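+
+        Yields ``dns.name.Name`` objects.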
+ + Note that as is usual with python iterators, adding or removing names + while iterating will invalidate the iterator and may raise `RuntimeError` + or fail to iterate over all entries.""" + self._check_ended() + return self._iterate_names() + + # + # Helper methods + # + + def _raise_if_not_empty(self, method, args): + if len(args) != 0: + raise TypeError(f"extra parameters to {method}") + + def _rdataset_from_args(self, method, deleting, args): + try: + arg = args.popleft() + if isinstance(arg, dns.rrset.RRset): + rdataset = arg.to_rdataset() + elif isinstance(arg, dns.rdataset.Rdataset): + rdataset = arg + else: + if deleting: + ttl = 0 + else: + if isinstance(arg, int): + ttl = arg + if ttl > dns.ttl.MAX_TTL: + raise ValueError(f"{method}: TTL value too big") + else: + raise TypeError(f"{method}: expected a TTL") + arg = args.popleft() + if isinstance(arg, dns.rdata.Rdata): + rdataset = dns.rdataset.from_rdata(ttl, arg) + else: + raise TypeError(f"{method}: expected an Rdata") + return rdataset + except IndexError: + if deleting: + return None + else: + # reraise + raise TypeError(f"{method}: expected more arguments") + + def _add(self, replace, args): + try: + args = collections.deque(args) + if replace: + method = "replace()" + else: + method = "add()" + arg = args.popleft() + if isinstance(arg, str): + arg = dns.name.from_text(arg, None) + if isinstance(arg, dns.name.Name): + name = arg + rdataset = self._rdataset_from_args(method, False, args) + elif isinstance(arg, dns.rrset.RRset): + rrset = arg + name = rrset.name + # rrsets are also rdatasets, but they don't print the + # same and can't be stored in nodes, so convert. + rdataset = rrset.to_rdataset() + else: + raise TypeError( + f"{method} requires a name or RRset as the first argument" + ) + if rdataset.rdclass != self.manager.get_class(): + raise ValueError(f"{method} has objects of wrong RdataClass") + if rdataset.rdtype == dns.rdatatype.SOA: + (_, _, origin) = self._origin_information() + if name != origin: + raise ValueError(f"{method} has non-origin SOA") + self._raise_if_not_empty(method, args) + if not replace: + existing = self._get_rdataset(name, rdataset.rdtype, rdataset.covers) + if existing is not None: + if isinstance(existing, dns.rdataset.ImmutableRdataset): + trds = dns.rdataset.Rdataset( + existing.rdclass, existing.rdtype, existing.covers + ) + trds.update(existing) + existing = trds + rdataset = existing.union(rdataset) + self._checked_put_rdataset(name, rdataset) + except IndexError: + raise TypeError(f"not enough parameters to {method}") + + def _delete(self, exact, args): + try: + args = collections.deque(args) + if exact: + method = "delete_exact()" + else: + method = "delete()" + arg = args.popleft() + if isinstance(arg, str): + arg = dns.name.from_text(arg, None) + if isinstance(arg, dns.name.Name): + name = arg + if len(args) > 0 and ( + isinstance(args[0], int) or isinstance(args[0], str) + ): + # deleting by type and (optionally) covers + rdtype = dns.rdatatype.RdataType.make(args.popleft()) + if len(args) > 0: + covers = dns.rdatatype.RdataType.make(args.popleft()) + else: + covers = dns.rdatatype.NONE + self._raise_if_not_empty(method, args) + existing = self._get_rdataset(name, rdtype, covers) + if existing is None: + if exact: + raise DeleteNotExact(f"{method}: missing rdataset") + else: + self._delete_rdataset(name, rdtype, covers) + return + else: + rdataset = self._rdataset_from_args(method, True, args) + elif isinstance(arg, dns.rrset.RRset): + rdataset = arg # rrsets are also rdatasets + 
name = rdataset.name + else: + raise TypeError( + f"{method} requires a name or RRset as the first argument" + ) + self._raise_if_not_empty(method, args) + if rdataset: + if rdataset.rdclass != self.manager.get_class(): + raise ValueError(f"{method} has objects of wrong RdataClass") + existing = self._get_rdataset(name, rdataset.rdtype, rdataset.covers) + if existing is not None: + if exact: + intersection = existing.intersection(rdataset) + if intersection != rdataset: + raise DeleteNotExact(f"{method}: missing rdatas") + rdataset = existing.difference(rdataset) + if len(rdataset) == 0: + self._checked_delete_rdataset( + name, rdataset.rdtype, rdataset.covers + ) + else: + self._checked_put_rdataset(name, rdataset) + elif exact: + raise DeleteNotExact(f"{method}: missing rdataset") + else: + if exact and not self._name_exists(name): + raise DeleteNotExact(f"{method}: name not known") + self._checked_delete_name(name) + except IndexError: + raise TypeError(f"not enough parameters to {method}") + + def _check_ended(self): + if self._ended: + raise AlreadyEnded + + def _end(self, commit): + self._check_ended() + if self._ended: + raise AlreadyEnded + try: + self._end_transaction(commit) + finally: + self._ended = True + + def _checked_put_rdataset(self, name, rdataset): + for check in self._check_put_rdataset: + check(self, name, rdataset) + self._put_rdataset(name, rdataset) + + def _checked_delete_rdataset(self, name, rdtype, covers): + for check in self._check_delete_rdataset: + check(self, name, rdtype, covers) + self._delete_rdataset(name, rdtype, covers) + + def _checked_delete_name(self, name): + for check in self._check_delete_name: + check(self, name) + self._delete_name(name) + + # + # Transactions are context managers. + # + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if not self._ended: + if exc_type is None: + self.commit() + else: + self.rollback() + return False + + # + # This is the low level API, which must be implemented by subclasses + # of Transaction. + # + + def _get_rdataset(self, name, rdtype, covers): + """Return the rdataset associated with *name*, *rdtype*, and *covers*, + or `None` if not found. + """ + raise NotImplementedError # pragma: no cover + + def _put_rdataset(self, name, rdataset): + """Store the rdataset.""" + raise NotImplementedError # pragma: no cover + + def _delete_name(self, name): + """Delete all data associated with *name*. + + It is not an error if the name does not exist. + """ + raise NotImplementedError # pragma: no cover + + def _delete_rdataset(self, name, rdtype, covers): + """Delete all data associated with *name*, *rdtype*, and *covers*. + + It is not an error if the rdataset does not exist. + """ + raise NotImplementedError # pragma: no cover + + def _name_exists(self, name): + """Does name exist? + + Returns a bool. + """ + raise NotImplementedError # pragma: no cover + + def _changed(self): + """Has this transaction changed anything?""" + raise NotImplementedError # pragma: no cover + + def _end_transaction(self, commit): + """End the transaction. + + *commit*, a bool. If ``True``, commit the transaction, otherwise + roll it back. + + If committing and the commit fails, then roll back and raise an + exception. + """ + raise NotImplementedError # pragma: no cover + + def _set_origin(self, origin): + """Set the origin. + + This method is called when reading a possibly relativized + source, and an origin setting operation occurs (e.g. $ORIGIN + in a zone file). 
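+
+        For example, a zone reader that encounters ``$ORIGIN example.``
+        would call this method with the corresponding ``dns.name.Name``.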
+ """ + raise NotImplementedError # pragma: no cover + + def _iterate_rdatasets(self): + """Return an iterator that yields (name, rdataset) tuples.""" + raise NotImplementedError # pragma: no cover + + def _iterate_names(self): + """Return an iterator that yields a name.""" + raise NotImplementedError # pragma: no cover + + def _get_node(self, name): + """Return the node at *name*, if any. + + Returns a node or ``None``. + """ + raise NotImplementedError # pragma: no cover + + # + # Low-level API with a default implementation, in case a subclass needs + # to override. + # + + def _origin_information(self): + # This is only used by _add() + return self.manager.origin_information() diff --git a/venv/Lib/site-packages/dns/tsig.py b/venv/Lib/site-packages/dns/tsig.py new file mode 100644 index 00000000..780852e8 --- /dev/null +++ b/venv/Lib/site-packages/dns/tsig.py @@ -0,0 +1,352 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2001-2007, 2009-2011 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +"""DNS TSIG support.""" + +import base64 +import hashlib +import hmac +import struct + +import dns.exception +import dns.name +import dns.rcode +import dns.rdataclass + + +class BadTime(dns.exception.DNSException): + """The current time is not within the TSIG's validity time.""" + + +class BadSignature(dns.exception.DNSException): + """The TSIG signature fails to verify.""" + + +class BadKey(dns.exception.DNSException): + """The TSIG record owner name does not match the key.""" + + +class BadAlgorithm(dns.exception.DNSException): + """The TSIG algorithm does not match the key.""" + + +class PeerError(dns.exception.DNSException): + """Base class for all TSIG errors generated by the remote peer""" + + +class PeerBadKey(PeerError): + """The peer didn't know the key we used""" + + +class PeerBadSignature(PeerError): + """The peer didn't like the signature we sent""" + + +class PeerBadTime(PeerError): + """The peer didn't like the time we sent""" + + +class PeerBadTruncation(PeerError): + """The peer didn't like amount of truncation in the TSIG we sent""" + + +# TSIG Algorithms + +HMAC_MD5 = dns.name.from_text("HMAC-MD5.SIG-ALG.REG.INT") +HMAC_SHA1 = dns.name.from_text("hmac-sha1") +HMAC_SHA224 = dns.name.from_text("hmac-sha224") +HMAC_SHA256 = dns.name.from_text("hmac-sha256") +HMAC_SHA256_128 = dns.name.from_text("hmac-sha256-128") +HMAC_SHA384 = dns.name.from_text("hmac-sha384") +HMAC_SHA384_192 = dns.name.from_text("hmac-sha384-192") +HMAC_SHA512 = dns.name.from_text("hmac-sha512") +HMAC_SHA512_256 = dns.name.from_text("hmac-sha512-256") +GSS_TSIG = dns.name.from_text("gss-tsig") + +default_algorithm = HMAC_SHA256 + +mac_sizes = { + HMAC_SHA1: 20, + HMAC_SHA224: 28, + HMAC_SHA256: 32, + HMAC_SHA256_128: 16, + HMAC_SHA384: 48, + HMAC_SHA384_192: 24, + 
HMAC_SHA512: 64, + HMAC_SHA512_256: 32, + HMAC_MD5: 16, + GSS_TSIG: 128, # This is what we assume to be the worst case! +} + + +class GSSTSig: + """ + GSS-TSIG TSIG implementation. This uses the GSS-API context established + in the TKEY message handshake to sign messages using GSS-API message + integrity codes, per the RFC. + + In order to avoid a direct GSSAPI dependency, the keyring holds a ref + to the GSSAPI object required, rather than the key itself. + """ + + def __init__(self, gssapi_context): + self.gssapi_context = gssapi_context + self.data = b"" + self.name = "gss-tsig" + + def update(self, data): + self.data += data + + def sign(self): + # defer to the GSSAPI function to sign + return self.gssapi_context.get_signature(self.data) + + def verify(self, expected): + try: + # defer to the GSSAPI function to verify + return self.gssapi_context.verify_signature(self.data, expected) + except Exception: + # note the usage of a bare exception + raise BadSignature + + +class GSSTSigAdapter: + def __init__(self, keyring): + self.keyring = keyring + + def __call__(self, message, keyname): + if keyname in self.keyring: + key = self.keyring[keyname] + if isinstance(key, Key) and key.algorithm == GSS_TSIG: + if message: + GSSTSigAdapter.parse_tkey_and_step(key, message, keyname) + return key + else: + return None + + @classmethod + def parse_tkey_and_step(cls, key, message, keyname): + # if the message is a TKEY type, absorb the key material + # into the context using step(); this is used to allow the + # client to complete the GSSAPI negotiation before attempting + # to verify the signed response to a TKEY message exchange + try: + rrset = message.find_rrset( + message.answer, keyname, dns.rdataclass.ANY, dns.rdatatype.TKEY + ) + if rrset: + token = rrset[0].key + gssapi_context = key.secret + return gssapi_context.step(token) + except KeyError: + pass + + +class HMACTSig: + """ + HMAC TSIG implementation. This uses the HMAC python module to handle the + sign/verify operations. 
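+
+    For illustration (a sketch; the key bytes are made up)::
+
+        ctx = HMACTSig(b'0123456789abcdef', HMAC_SHA256)
+        ctx.update(b'wire-format data')
+        mac = ctx.sign()
+        peer = HMACTSig(b'0123456789abcdef', HMAC_SHA256)
+        peer.update(b'wire-format data')
+        peer.verify(mac)   # raises BadSignature on mismatch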
+ """ + + _hashes = { + HMAC_SHA1: hashlib.sha1, + HMAC_SHA224: hashlib.sha224, + HMAC_SHA256: hashlib.sha256, + HMAC_SHA256_128: (hashlib.sha256, 128), + HMAC_SHA384: hashlib.sha384, + HMAC_SHA384_192: (hashlib.sha384, 192), + HMAC_SHA512: hashlib.sha512, + HMAC_SHA512_256: (hashlib.sha512, 256), + HMAC_MD5: hashlib.md5, + } + + def __init__(self, key, algorithm): + try: + hashinfo = self._hashes[algorithm] + except KeyError: + raise NotImplementedError(f"TSIG algorithm {algorithm} is not supported") + + # create the HMAC context + if isinstance(hashinfo, tuple): + self.hmac_context = hmac.new(key, digestmod=hashinfo[0]) + self.size = hashinfo[1] + else: + self.hmac_context = hmac.new(key, digestmod=hashinfo) + self.size = None + self.name = self.hmac_context.name + if self.size: + self.name += f"-{self.size}" + + def update(self, data): + return self.hmac_context.update(data) + + def sign(self): + # defer to the HMAC digest() function for that digestmod + digest = self.hmac_context.digest() + if self.size: + digest = digest[: (self.size // 8)] + return digest + + def verify(self, expected): + # re-digest and compare the results + mac = self.sign() + if not hmac.compare_digest(mac, expected): + raise BadSignature + + +def _digest(wire, key, rdata, time=None, request_mac=None, ctx=None, multi=None): + """Return a context containing the TSIG rdata for the input parameters + @rtype: dns.tsig.HMACTSig or dns.tsig.GSSTSig object + @raises ValueError: I{other_data} is too long + @raises NotImplementedError: I{algorithm} is not supported + """ + + first = not (ctx and multi) + if first: + ctx = get_context(key) + if request_mac: + ctx.update(struct.pack("!H", len(request_mac))) + ctx.update(request_mac) + ctx.update(struct.pack("!H", rdata.original_id)) + ctx.update(wire[2:]) + if first: + ctx.update(key.name.to_digestable()) + ctx.update(struct.pack("!H", dns.rdataclass.ANY)) + ctx.update(struct.pack("!I", 0)) + if time is None: + time = rdata.time_signed + upper_time = (time >> 32) & 0xFFFF + lower_time = time & 0xFFFFFFFF + time_encoded = struct.pack("!HIH", upper_time, lower_time, rdata.fudge) + other_len = len(rdata.other) + if other_len > 65535: + raise ValueError("TSIG Other Data is > 65535 bytes") + if first: + ctx.update(key.algorithm.to_digestable() + time_encoded) + ctx.update(struct.pack("!HH", rdata.error, other_len) + rdata.other) + else: + ctx.update(time_encoded) + return ctx + + +def _maybe_start_digest(key, mac, multi): + """If this is the first message in a multi-message sequence, + start a new context. + @rtype: dns.tsig.HMACTSig or dns.tsig.GSSTSig object + """ + if multi: + ctx = get_context(key) + ctx.update(struct.pack("!H", len(mac))) + ctx.update(mac) + return ctx + else: + return None + + +def sign(wire, key, rdata, time=None, request_mac=None, ctx=None, multi=False): + """Return a (tsig_rdata, mac, ctx) tuple containing the HMAC TSIG rdata + for the input parameters, the HMAC MAC calculated by applying the + TSIG signature algorithm, and the TSIG digest context. 
+    @rtype: (string, dns.tsig.HMACTSig or dns.tsig.GSSTSig object)
+    @raises ValueError: I{other_data} is too long
+    @raises NotImplementedError: I{algorithm} is not supported
+    """
+
+    ctx = _digest(wire, key, rdata, time, request_mac, ctx, multi)
+    mac = ctx.sign()
+    tsig = rdata.replace(time_signed=time, mac=mac)
+
+    return (tsig, _maybe_start_digest(key, mac, multi))
+
+
+def validate(
+    wire, key, owner, rdata, now, request_mac, tsig_start, ctx=None, multi=False
+):
+    """Validate the specified TSIG rdata against the other input parameters.
+
+    @raises FormError: The TSIG is badly formed.
+    @raises BadTime: There is too much time skew between the client and the
+    server.
+    @raises BadSignature: The TSIG signature did not validate
+    @rtype: dns.tsig.HMACTSig or dns.tsig.GSSTSig object"""
+
+    (adcount,) = struct.unpack("!H", wire[10:12])
+    if adcount == 0:
+        raise dns.exception.FormError
+    adcount -= 1
+    new_wire = wire[0:10] + struct.pack("!H", adcount) + wire[12:tsig_start]
+    if rdata.error != 0:
+        if rdata.error == dns.rcode.BADSIG:
+            raise PeerBadSignature
+        elif rdata.error == dns.rcode.BADKEY:
+            raise PeerBadKey
+        elif rdata.error == dns.rcode.BADTIME:
+            raise PeerBadTime
+        elif rdata.error == dns.rcode.BADTRUNC:
+            raise PeerBadTruncation
+        else:
+            raise PeerError("unknown TSIG error code %d" % rdata.error)
+    if abs(rdata.time_signed - now) > rdata.fudge:
+        raise BadTime
+    if key.name != owner:
+        raise BadKey
+    if key.algorithm != rdata.algorithm:
+        raise BadAlgorithm
+    ctx = _digest(new_wire, key, rdata, None, request_mac, ctx, multi)
+    ctx.verify(rdata.mac)
+    return _maybe_start_digest(key, rdata.mac, multi)
+
+
+def get_context(key):
+    """Returns an HMAC context for the specified key.
+
+    @rtype: HMAC context
+    @raises NotImplementedError: I{algorithm} is not supported
+    """
+
+    if key.algorithm == GSS_TSIG:
+        return GSSTSig(key.secret)
+    else:
+        return HMACTSig(key.secret, key.algorithm)
+
+
+class Key:
+    def __init__(self, name, secret, algorithm=default_algorithm):
+        if isinstance(name, str):
+            name = dns.name.from_text(name)
+        self.name = name
+        if isinstance(secret, str):
+            secret = base64.decodebytes(secret.encode())
+        self.secret = secret
+        if isinstance(algorithm, str):
+            algorithm = dns.name.from_text(algorithm)
+        self.algorithm = algorithm
+
+    def __eq__(self, other):
+        return (
+            isinstance(other, Key)
+            and self.name == other.name
+            and self.secret == other.secret
+            and self.algorithm == other.algorithm
+        )
+
+    def __repr__(self):
+        r = f"<DNS key name='{self.name}', " + f"algorithm='{self.algorithm}'"
+        if self.algorithm != GSS_TSIG:
+            r += f", secret='{base64.b64encode(self.secret).decode()}'"
+        r += ">"
+        return r
diff --git a/venv/Lib/site-packages/dns/tsigkeyring.py b/venv/Lib/site-packages/dns/tsigkeyring.py
new file mode 100644
--- /dev/null
+++ b/venv/Lib/site-packages/dns/tsigkeyring.py
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2003-2017 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
+"""A place to store TSIG keys."""
+
+import base64
+from typing import Any, Dict
+
+import dns.name
+import dns.tsig
+
+
+def from_text(textring: Dict[str, Any]) -> Dict[dns.name.Name, dns.tsig.Key]:
+    """Convert a dictionary containing (textual DNS name, base64 secret)
+    pairs into a binary keyring which has (dns.name.Name, bytes) pairs, or
+    a dictionary containing (textual DNS name, (algorithm, base64 secret))
+    pairs into a binary keyring which has (dns.name.Name, dns.tsig.Key) pairs.
+    @rtype: dict"""
+
+    keyring = {}
+    for name, value in textring.items():
+        kname = dns.name.from_text(name)
+        if isinstance(value, str):
+            keyring[kname] = dns.tsig.Key(kname, value).secret
+        else:
+            (algorithm, secret) = value
+            keyring[kname] = dns.tsig.Key(kname, secret, algorithm)
+    return keyring
+
+
+def to_text(keyring: Dict[dns.name.Name, Any]) -> Dict[str, Any]:
+    """Convert a dictionary containing (dns.name.Name, dns.tsig.Key) pairs
+    into a text keyring which has (textual DNS name, (textual algorithm,
+    base64 secret)) pairs, or a dictionary containing (dns.name.Name, bytes)
+    pairs into a text keyring which has (textual DNS name, base64 secret) pairs.
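+
+    For illustration (a sketch), a text ring survives the round trip::
+
+        textring = {'keyname.': 'MTIzNDU2Nzg5MA=='}
+        assert to_text(from_text(textring)) == textring
+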
+ @rtype: dict""" + + textring = {} + + def b64encode(secret): + return base64.encodebytes(secret).decode().rstrip() + + for name, key in keyring.items(): + tname = name.to_text() + if isinstance(key, bytes): + textring[tname] = b64encode(key) + else: + if isinstance(key.secret, bytes): + text_secret = b64encode(key.secret) + else: + text_secret = str(key.secret) + + textring[tname] = (key.algorithm.to_text(), text_secret) + return textring diff --git a/venv/Lib/site-packages/dns/ttl.py b/venv/Lib/site-packages/dns/ttl.py new file mode 100644 index 00000000..264b0338 --- /dev/null +++ b/venv/Lib/site-packages/dns/ttl.py @@ -0,0 +1,92 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2017 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +"""DNS TTL conversion.""" + +from typing import Union + +import dns.exception + +# Technically TTLs are supposed to be between 0 and 2**31 - 1, with values +# greater than that interpreted as 0, but we do not impose this policy here +# as values > 2**31 - 1 occur in real world data. +# +# We leave it to applications to impose tighter bounds if desired. +MAX_TTL = 2**32 - 1 + + +class BadTTL(dns.exception.SyntaxError): + """DNS TTL value is not well-formed.""" + + +def from_text(text: str) -> int: + """Convert the text form of a TTL to an integer. + + The BIND 8 units syntax for TTLs (e.g. '1w6d4h3m10s') is supported. + + *text*, a ``str``, the textual TTL. + + Raises ``dns.ttl.BadTTL`` if the TTL is not well-formed. + + Returns an ``int``. 
+ """ + + if text.isdigit(): + total = int(text) + elif len(text) == 0: + raise BadTTL + else: + total = 0 + current = 0 + need_digit = True + for c in text: + if c.isdigit(): + current *= 10 + current += int(c) + need_digit = False + else: + if need_digit: + raise BadTTL + c = c.lower() + if c == "w": + total += current * 604800 + elif c == "d": + total += current * 86400 + elif c == "h": + total += current * 3600 + elif c == "m": + total += current * 60 + elif c == "s": + total += current + else: + raise BadTTL("unknown unit '%s'" % c) + current = 0 + need_digit = True + if not current == 0: + raise BadTTL("trailing integer") + if total < 0 or total > MAX_TTL: + raise BadTTL("TTL should be between 0 and 2**32 - 1 (inclusive)") + return total + + +def make(value: Union[int, str]) -> int: + if isinstance(value, int): + return value + elif isinstance(value, str): + return dns.ttl.from_text(value) + else: + raise ValueError("cannot convert value to TTL") diff --git a/venv/Lib/site-packages/dns/update.py b/venv/Lib/site-packages/dns/update.py new file mode 100644 index 00000000..bf1157ac --- /dev/null +++ b/venv/Lib/site-packages/dns/update.py @@ -0,0 +1,386 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +"""DNS Dynamic Update Support""" + +from typing import Any, List, Optional, Union + +import dns.message +import dns.name +import dns.opcode +import dns.rdata +import dns.rdataclass +import dns.rdataset +import dns.rdatatype +import dns.tsig + + +class UpdateSection(dns.enum.IntEnum): + """Update sections""" + + ZONE = 0 + PREREQ = 1 + UPDATE = 2 + ADDITIONAL = 3 + + @classmethod + def _maximum(cls): + return 3 + + +class UpdateMessage(dns.message.Message): # lgtm[py/missing-equals] + # ignore the mypy error here as we mean to use a different enum + _section_enum = UpdateSection # type: ignore + + def __init__( + self, + zone: Optional[Union[dns.name.Name, str]] = None, + rdclass: dns.rdataclass.RdataClass = dns.rdataclass.IN, + keyring: Optional[Any] = None, + keyname: Optional[dns.name.Name] = None, + keyalgorithm: Union[dns.name.Name, str] = dns.tsig.default_algorithm, + id: Optional[int] = None, + ): + """Initialize a new DNS Update object. + + See the documentation of the Message class for a complete + description of the keyring dictionary. + + *zone*, a ``dns.name.Name``, ``str``, or ``None``, the zone + which is being updated. ``None`` should only be used by dnspython's + message constructors, as a zone is required for the convenience + methods like ``add()``, ``replace()``, etc. + + *rdclass*, an ``int`` or ``str``, the class of the zone. + + The *keyring*, *keyname*, and *keyalgorithm* parameters are passed to + ``use_tsig()``; see its documentation for details. 
+ """ + super().__init__(id=id) + self.flags |= dns.opcode.to_flags(dns.opcode.UPDATE) + if isinstance(zone, str): + zone = dns.name.from_text(zone) + self.origin = zone + rdclass = dns.rdataclass.RdataClass.make(rdclass) + self.zone_rdclass = rdclass + if self.origin: + self.find_rrset( + self.zone, + self.origin, + rdclass, + dns.rdatatype.SOA, + create=True, + force_unique=True, + ) + if keyring is not None: + self.use_tsig(keyring, keyname, algorithm=keyalgorithm) + + @property + def zone(self) -> List[dns.rrset.RRset]: + """The zone section.""" + return self.sections[0] + + @zone.setter + def zone(self, v): + self.sections[0] = v + + @property + def prerequisite(self) -> List[dns.rrset.RRset]: + """The prerequisite section.""" + return self.sections[1] + + @prerequisite.setter + def prerequisite(self, v): + self.sections[1] = v + + @property + def update(self) -> List[dns.rrset.RRset]: + """The update section.""" + return self.sections[2] + + @update.setter + def update(self, v): + self.sections[2] = v + + def _add_rr(self, name, ttl, rd, deleting=None, section=None): + """Add a single RR to the update section.""" + + if section is None: + section = self.update + covers = rd.covers() + rrset = self.find_rrset( + section, name, self.zone_rdclass, rd.rdtype, covers, deleting, True, True + ) + rrset.add(rd, ttl) + + def _add(self, replace, section, name, *args): + """Add records. + + *replace* is the replacement mode. If ``False``, + RRs are added to an existing RRset; if ``True``, the RRset + is replaced with the specified contents. The second + argument is the section to add to. The third argument + is always a name. The other arguments can be: + + - rdataset... + + - ttl, rdata... + + - ttl, rdtype, string... + """ + + if isinstance(name, str): + name = dns.name.from_text(name, None) + if isinstance(args[0], dns.rdataset.Rdataset): + for rds in args: + if replace: + self.delete(name, rds.rdtype) + for rd in rds: + self._add_rr(name, rds.ttl, rd, section=section) + else: + args = list(args) + ttl = int(args.pop(0)) + if isinstance(args[0], dns.rdata.Rdata): + if replace: + self.delete(name, args[0].rdtype) + for rd in args: + self._add_rr(name, ttl, rd, section=section) + else: + rdtype = dns.rdatatype.RdataType.make(args.pop(0)) + if replace: + self.delete(name, rdtype) + for s in args: + rd = dns.rdata.from_text(self.zone_rdclass, rdtype, s, self.origin) + self._add_rr(name, ttl, rd, section=section) + + def add(self, name: Union[dns.name.Name, str], *args: Any) -> None: + """Add records. + + The first argument is always a name. The other + arguments can be: + + - rdataset... + + - ttl, rdata... + + - ttl, rdtype, string... + """ + + self._add(False, self.update, name, *args) + + def delete(self, name: Union[dns.name.Name, str], *args: Any) -> None: + """Delete records. + + The first argument is always a name. The other + arguments can be: + + - *empty* + + - rdataset... + + - rdata... + + - rdtype, [string...] 
+ """ + + if isinstance(name, str): + name = dns.name.from_text(name, None) + if len(args) == 0: + self.find_rrset( + self.update, + name, + dns.rdataclass.ANY, + dns.rdatatype.ANY, + dns.rdatatype.NONE, + dns.rdataclass.ANY, + True, + True, + ) + elif isinstance(args[0], dns.rdataset.Rdataset): + for rds in args: + for rd in rds: + self._add_rr(name, 0, rd, dns.rdataclass.NONE) + else: + largs = list(args) + if isinstance(largs[0], dns.rdata.Rdata): + for rd in largs: + self._add_rr(name, 0, rd, dns.rdataclass.NONE) + else: + rdtype = dns.rdatatype.RdataType.make(largs.pop(0)) + if len(largs) == 0: + self.find_rrset( + self.update, + name, + self.zone_rdclass, + rdtype, + dns.rdatatype.NONE, + dns.rdataclass.ANY, + True, + True, + ) + else: + for s in largs: + rd = dns.rdata.from_text( + self.zone_rdclass, + rdtype, + s, # type: ignore[arg-type] + self.origin, + ) + self._add_rr(name, 0, rd, dns.rdataclass.NONE) + + def replace(self, name: Union[dns.name.Name, str], *args: Any) -> None: + """Replace records. + + The first argument is always a name. The other + arguments can be: + + - rdataset... + + - ttl, rdata... + + - ttl, rdtype, string... + + Note that if you want to replace the entire node, you should do + a delete of the name followed by one or more calls to add. + """ + + self._add(True, self.update, name, *args) + + def present(self, name: Union[dns.name.Name, str], *args: Any) -> None: + """Require that an owner name (and optionally an rdata type, + or specific rdataset) exists as a prerequisite to the + execution of the update. + + The first argument is always a name. + The other arguments can be: + + - rdataset... + + - rdata... + + - rdtype, string... + """ + + if isinstance(name, str): + name = dns.name.from_text(name, None) + if len(args) == 0: + self.find_rrset( + self.prerequisite, + name, + dns.rdataclass.ANY, + dns.rdatatype.ANY, + dns.rdatatype.NONE, + None, + True, + True, + ) + elif ( + isinstance(args[0], dns.rdataset.Rdataset) + or isinstance(args[0], dns.rdata.Rdata) + or len(args) > 1 + ): + if not isinstance(args[0], dns.rdataset.Rdataset): + # Add a 0 TTL + largs = list(args) + largs.insert(0, 0) # type: ignore[arg-type] + self._add(False, self.prerequisite, name, *largs) + else: + self._add(False, self.prerequisite, name, *args) + else: + rdtype = dns.rdatatype.RdataType.make(args[0]) + self.find_rrset( + self.prerequisite, + name, + dns.rdataclass.ANY, + rdtype, + dns.rdatatype.NONE, + None, + True, + True, + ) + + def absent( + self, + name: Union[dns.name.Name, str], + rdtype: Optional[Union[dns.rdatatype.RdataType, str]] = None, + ) -> None: + """Require that an owner name (and optionally an rdata type) does + not exist as a prerequisite to the execution of the update.""" + + if isinstance(name, str): + name = dns.name.from_text(name, None) + if rdtype is None: + self.find_rrset( + self.prerequisite, + name, + dns.rdataclass.NONE, + dns.rdatatype.ANY, + dns.rdatatype.NONE, + None, + True, + True, + ) + else: + rdtype = dns.rdatatype.RdataType.make(rdtype) + self.find_rrset( + self.prerequisite, + name, + dns.rdataclass.NONE, + rdtype, + dns.rdatatype.NONE, + None, + True, + True, + ) + + def _get_one_rr_per_rrset(self, value): + # Updates are always one_rr_per_rrset + return True + + def _parse_rr_header(self, section, name, rdclass, rdtype): + deleting = None + empty = False + if section == UpdateSection.ZONE: + if ( + dns.rdataclass.is_metaclass(rdclass) + or rdtype != dns.rdatatype.SOA + or self.zone + ): + raise dns.exception.FormError + else: + if not 
self.zone: + raise dns.exception.FormError + if rdclass in (dns.rdataclass.ANY, dns.rdataclass.NONE): + deleting = rdclass + rdclass = self.zone[0].rdclass + empty = ( + deleting == dns.rdataclass.ANY or section == UpdateSection.PREREQ + ) + return (rdclass, rdtype, deleting, empty) + + +# backwards compatibility +Update = UpdateMessage + +### BEGIN generated UpdateSection constants + +ZONE = UpdateSection.ZONE +PREREQ = UpdateSection.PREREQ +UPDATE = UpdateSection.UPDATE +ADDITIONAL = UpdateSection.ADDITIONAL + +### END generated UpdateSection constants diff --git a/venv/Lib/site-packages/dns/version.py b/venv/Lib/site-packages/dns/version.py new file mode 100644 index 00000000..251f2583 --- /dev/null +++ b/venv/Lib/site-packages/dns/version.py @@ -0,0 +1,58 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2017 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +"""dnspython release version information.""" + +#: MAJOR +MAJOR = 2 +#: MINOR +MINOR = 6 +#: MICRO +MICRO = 1 +#: RELEASELEVEL +RELEASELEVEL = 0x0F +#: SERIAL +SERIAL = 0 + +if RELEASELEVEL == 0x0F: # pragma: no cover lgtm[py/unreachable-statement] + #: version + version = "%d.%d.%d" % (MAJOR, MINOR, MICRO) # lgtm[py/unreachable-statement] +elif RELEASELEVEL == 0x00: # pragma: no cover lgtm[py/unreachable-statement] + version = "%d.%d.%ddev%d" % ( + MAJOR, + MINOR, + MICRO, + SERIAL, + ) # lgtm[py/unreachable-statement] +elif RELEASELEVEL == 0x0C: # pragma: no cover lgtm[py/unreachable-statement] + version = "%d.%d.%drc%d" % ( + MAJOR, + MINOR, + MICRO, + SERIAL, + ) # lgtm[py/unreachable-statement] +else: # pragma: no cover lgtm[py/unreachable-statement] + version = "%d.%d.%d%x%d" % ( + MAJOR, + MINOR, + MICRO, + RELEASELEVEL, + SERIAL, + ) # lgtm[py/unreachable-statement] + +#: hexversion +hexversion = MAJOR << 24 | MINOR << 16 | MICRO << 8 | RELEASELEVEL << 4 | SERIAL diff --git a/venv/Lib/site-packages/dns/versioned.py b/venv/Lib/site-packages/dns/versioned.py new file mode 100644 index 00000000..fd78e674 --- /dev/null +++ b/venv/Lib/site-packages/dns/versioned.py @@ -0,0 +1,318 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +"""DNS Versioned Zones.""" + +import collections +import threading +from typing import Callable, Deque, Optional, Set, Union + +import dns.exception +import dns.immutable +import dns.name +import dns.node +import dns.rdataclass +import dns.rdataset +import dns.rdatatype +import dns.rdtypes.ANY.SOA +import dns.zone + + +class UseTransaction(dns.exception.DNSException): + """To alter a versioned zone, use a transaction.""" + + +# Backwards compatibility +Node = dns.zone.VersionedNode +ImmutableNode = dns.zone.ImmutableVersionedNode +Version = dns.zone.Version +WritableVersion = dns.zone.WritableVersion +ImmutableVersion 
= dns.zone.ImmutableVersion +Transaction = dns.zone.Transaction + + +class Zone(dns.zone.Zone): # lgtm[py/missing-equals] + __slots__ = [ + "_versions", + "_versions_lock", + "_write_txn", + "_write_waiters", + "_write_event", + "_pruning_policy", + "_readers", + ] + + node_factory = Node + + def __init__( + self, + origin: Optional[Union[dns.name.Name, str]], + rdclass: dns.rdataclass.RdataClass = dns.rdataclass.IN, + relativize: bool = True, + pruning_policy: Optional[Callable[["Zone", Version], Optional[bool]]] = None, + ): + """Initialize a versioned zone object. + + *origin* is the origin of the zone. It may be a ``dns.name.Name``, + a ``str``, or ``None``. If ``None``, then the zone's origin will + be set by the first ``$ORIGIN`` line in a zone file. + + *rdclass*, an ``int``, the zone's rdata class; the default is class IN. + + *relativize*, a ``bool``, determine's whether domain names are + relativized to the zone's origin. The default is ``True``. + + *pruning policy*, a function taking a ``Zone`` and a ``Version`` and returning + a ``bool``, or ``None``. Should the version be pruned? If ``None``, + the default policy, which retains one version is used. + """ + super().__init__(origin, rdclass, relativize) + self._versions: Deque[Version] = collections.deque() + self._version_lock = threading.Lock() + if pruning_policy is None: + self._pruning_policy = self._default_pruning_policy + else: + self._pruning_policy = pruning_policy + self._write_txn: Optional[Transaction] = None + self._write_event: Optional[threading.Event] = None + self._write_waiters: Deque[threading.Event] = collections.deque() + self._readers: Set[Transaction] = set() + self._commit_version_unlocked( + None, WritableVersion(self, replacement=True), origin + ) + + def reader( + self, id: Optional[int] = None, serial: Optional[int] = None + ) -> Transaction: # pylint: disable=arguments-differ + if id is not None and serial is not None: + raise ValueError("cannot specify both id and serial") + with self._version_lock: + if id is not None: + version = None + for v in reversed(self._versions): + if v.id == id: + version = v + break + if version is None: + raise KeyError("version not found") + elif serial is not None: + if self.relativize: + oname = dns.name.empty + else: + assert self.origin is not None + oname = self.origin + version = None + for v in reversed(self._versions): + n = v.nodes.get(oname) + if n: + rds = n.get_rdataset(self.rdclass, dns.rdatatype.SOA) + if rds and rds[0].serial == serial: + version = v + break + if version is None: + raise KeyError("serial not found") + else: + version = self._versions[-1] + txn = Transaction(self, False, version) + self._readers.add(txn) + return txn + + def writer(self, replacement: bool = False) -> Transaction: + event = None + while True: + with self._version_lock: + # Checking event == self._write_event ensures that either + # no one was waiting before we got lucky and found no write + # txn, or we were the one who was waiting and got woken up. + # This prevents "taking cuts" when creating a write txn. + if self._write_txn is None and event == self._write_event: + # Creating the transaction defers version setup + # (i.e. copying the nodes dictionary) until we + # give up the lock, so that we hold the lock as + # short a time as possible. This is why we call + # _setup_version() below. 
+ self._write_txn = Transaction( + self, replacement, make_immutable=True + ) + # give up our exclusive right to make a Transaction + self._write_event = None + break + # Someone else is writing already, so we will have to + # wait, but we want to do the actual wait outside the + # lock. + event = threading.Event() + self._write_waiters.append(event) + # wait (note we gave up the lock!) + # + # We only wake one sleeper at a time, so it's important + # that no event waiter can exit this method (e.g. via + # cancellation) without returning a transaction or waking + # someone else up. + # + # This is not a problem with Threading module threads as + # they cannot be canceled, but could be an issue with trio + # tasks when we do the async version of writer(). + # I.e. we'd need to do something like: + # + # try: + # event.wait() + # except trio.Cancelled: + # with self._version_lock: + # self._maybe_wakeup_one_waiter_unlocked() + # raise + # + event.wait() + # Do the deferred version setup. + self._write_txn._setup_version() + return self._write_txn + + def _maybe_wakeup_one_waiter_unlocked(self): + if len(self._write_waiters) > 0: + self._write_event = self._write_waiters.popleft() + self._write_event.set() + + # pylint: disable=unused-argument + def _default_pruning_policy(self, zone, version): + return True + + # pylint: enable=unused-argument + + def _prune_versions_unlocked(self): + assert len(self._versions) > 0 + # Don't ever prune a version greater than or equal to one that + # a reader has open. This pins versions in memory while the + # reader is open, and importantly lets the reader open a txn on + # a successor version (e.g. if generating an IXFR). + # + # Note our definition of least_kept also ensures we do not try to + # delete the greatest version. + if len(self._readers) > 0: + least_kept = min(txn.version.id for txn in self._readers) + else: + least_kept = self._versions[-1].id + while self._versions[0].id < least_kept and self._pruning_policy( + self, self._versions[0] + ): + self._versions.popleft() + + def set_max_versions(self, max_versions: Optional[int]) -> None: + """Set a pruning policy that retains up to the specified number + of versions + """ + if max_versions is not None and max_versions < 1: + raise ValueError("max versions must be at least 1") + if max_versions is None: + + def policy(zone, _): # pylint: disable=unused-argument + return False + + else: + + def policy(zone, _): + return len(zone._versions) > max_versions + + self.set_pruning_policy(policy) + + def set_pruning_policy( + self, policy: Optional[Callable[["Zone", Version], Optional[bool]]] + ) -> None: + """Set the pruning policy for the zone. + + The *policy* function takes a `Version` and returns `True` if + the version should be pruned, and `False` otherwise. `None` + may also be specified for policy, in which case the default policy + is used. + + Pruning checking proceeds from the least version and the first + time the function returns `False`, the checking stops. I.e. the + retained versions are always a consecutive sequence. 
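+
+        For example, a policy equivalent to ``set_max_versions(10)``
+        (a sketch)::
+
+            zone.set_pruning_policy(lambda zone, _: len(zone._versions) > 10)
+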
+ """ + if policy is None: + policy = self._default_pruning_policy + with self._version_lock: + self._pruning_policy = policy + self._prune_versions_unlocked() + + def _end_read(self, txn): + with self._version_lock: + self._readers.remove(txn) + self._prune_versions_unlocked() + + def _end_write_unlocked(self, txn): + assert self._write_txn == txn + self._write_txn = None + self._maybe_wakeup_one_waiter_unlocked() + + def _end_write(self, txn): + with self._version_lock: + self._end_write_unlocked(txn) + + def _commit_version_unlocked(self, txn, version, origin): + self._versions.append(version) + self._prune_versions_unlocked() + self.nodes = version.nodes + if self.origin is None: + self.origin = origin + # txn can be None in __init__ when we make the empty version. + if txn is not None: + self._end_write_unlocked(txn) + + def _commit_version(self, txn, version, origin): + with self._version_lock: + self._commit_version_unlocked(txn, version, origin) + + def _get_next_version_id(self): + if len(self._versions) > 0: + id = self._versions[-1].id + 1 + else: + id = 1 + return id + + def find_node( + self, name: Union[dns.name.Name, str], create: bool = False + ) -> dns.node.Node: + if create: + raise UseTransaction + return super().find_node(name) + + def delete_node(self, name: Union[dns.name.Name, str]) -> None: + raise UseTransaction + + def find_rdataset( + self, + name: Union[dns.name.Name, str], + rdtype: Union[dns.rdatatype.RdataType, str], + covers: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.NONE, + create: bool = False, + ) -> dns.rdataset.Rdataset: + if create: + raise UseTransaction + rdataset = super().find_rdataset(name, rdtype, covers) + return dns.rdataset.ImmutableRdataset(rdataset) + + def get_rdataset( + self, + name: Union[dns.name.Name, str], + rdtype: Union[dns.rdatatype.RdataType, str], + covers: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.NONE, + create: bool = False, + ) -> Optional[dns.rdataset.Rdataset]: + if create: + raise UseTransaction + rdataset = super().get_rdataset(name, rdtype, covers) + if rdataset is not None: + return dns.rdataset.ImmutableRdataset(rdataset) + else: + return None + + def delete_rdataset( + self, + name: Union[dns.name.Name, str], + rdtype: Union[dns.rdatatype.RdataType, str], + covers: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.NONE, + ) -> None: + raise UseTransaction + + def replace_rdataset( + self, name: Union[dns.name.Name, str], replacement: dns.rdataset.Rdataset + ) -> None: + raise UseTransaction diff --git a/venv/Lib/site-packages/dns/win32util.py b/venv/Lib/site-packages/dns/win32util.py new file mode 100644 index 00000000..aaa7e93e --- /dev/null +++ b/venv/Lib/site-packages/dns/win32util.py @@ -0,0 +1,252 @@ +import sys + +import dns._features + +if sys.platform == "win32": + from typing import Any + + import dns.name + + _prefer_wmi = True + + import winreg # pylint: disable=import-error + + # Keep pylint quiet on non-windows. + try: + WindowsError is None # pylint: disable=used-before-assignment + except KeyError: + WindowsError = Exception + + if dns._features.have("wmi"): + import threading + + import pythoncom # pylint: disable=import-error + import wmi # pylint: disable=import-error + + _have_wmi = True + else: + _have_wmi = False + + def _config_domain(domain): + # Sometimes DHCP servers add a '.' prefix to the default domain, and + # Windows just stores such values in the registry (see #687). + # Check for this and fix it. 
+ if domain.startswith("."): + domain = domain[1:] + return dns.name.from_text(domain) + + class DnsInfo: + def __init__(self): + self.domain = None + self.nameservers = [] + self.search = [] + + if _have_wmi: + + class _WMIGetter(threading.Thread): + def __init__(self): + super().__init__() + self.info = DnsInfo() + + def run(self): + pythoncom.CoInitialize() + try: + system = wmi.WMI() + for interface in system.Win32_NetworkAdapterConfiguration(): + if interface.IPEnabled and interface.DNSServerSearchOrder: + self.info.nameservers = list(interface.DNSServerSearchOrder) + if interface.DNSDomain: + self.info.domain = _config_domain(interface.DNSDomain) + if interface.DNSDomainSuffixSearchOrder: + self.info.search = [ + _config_domain(x) + for x in interface.DNSDomainSuffixSearchOrder + ] + break + finally: + pythoncom.CoUninitialize() + + def get(self): + # We always run in a separate thread to avoid any issues with + # the COM threading model. + self.start() + self.join() + return self.info + + else: + + class _WMIGetter: # type: ignore + pass + + class _RegistryGetter: + def __init__(self): + self.info = DnsInfo() + + def _determine_split_char(self, entry): + # + # The windows registry irritatingly changes the list element + # delimiter in between ' ' and ',' (and vice-versa) in various + # versions of windows. + # + if entry.find(" ") >= 0: + split_char = " " + elif entry.find(",") >= 0: + split_char = "," + else: + # probably a singleton; treat as a space-separated list. + split_char = " " + return split_char + + def _config_nameservers(self, nameservers): + split_char = self._determine_split_char(nameservers) + ns_list = nameservers.split(split_char) + for ns in ns_list: + if ns not in self.info.nameservers: + self.info.nameservers.append(ns) + + def _config_search(self, search): + split_char = self._determine_split_char(search) + search_list = search.split(split_char) + for s in search_list: + s = _config_domain(s) + if s not in self.info.search: + self.info.search.append(s) + + def _config_fromkey(self, key, always_try_domain): + try: + servers, _ = winreg.QueryValueEx(key, "NameServer") + except WindowsError: + servers = None + if servers: + self._config_nameservers(servers) + if servers or always_try_domain: + try: + dom, _ = winreg.QueryValueEx(key, "Domain") + if dom: + self.info.domain = _config_domain(dom) + except WindowsError: + pass + else: + try: + servers, _ = winreg.QueryValueEx(key, "DhcpNameServer") + except WindowsError: + servers = None + if servers: + self._config_nameservers(servers) + try: + dom, _ = winreg.QueryValueEx(key, "DhcpDomain") + if dom: + self.info.domain = _config_domain(dom) + except WindowsError: + pass + try: + search, _ = winreg.QueryValueEx(key, "SearchList") + except WindowsError: + search = None + if search is None: + try: + search, _ = winreg.QueryValueEx(key, "DhcpSearchList") + except WindowsError: + search = None + if search: + self._config_search(search) + + def _is_nic_enabled(self, lm, guid): + # Look in the Windows Registry to determine whether the network + # interface corresponding to the given guid is enabled. + # + # (Code contributed by Paul Marks, thanks!) + # + try: + # This hard-coded location seems to be consistent, at least + # from Windows 2000 through Vista. 
+ connection_key = winreg.OpenKey( + lm, + r"SYSTEM\CurrentControlSet\Control\Network" + r"\{4D36E972-E325-11CE-BFC1-08002BE10318}" + r"\%s\Connection" % guid, + ) + + try: + # The PnpInstanceID points to a key inside Enum + (pnp_id, ttype) = winreg.QueryValueEx( + connection_key, "PnpInstanceID" + ) + + if ttype != winreg.REG_SZ: + raise ValueError # pragma: no cover + + device_key = winreg.OpenKey( + lm, r"SYSTEM\CurrentControlSet\Enum\%s" % pnp_id + ) + + try: + # Get ConfigFlags for this device + (flags, ttype) = winreg.QueryValueEx(device_key, "ConfigFlags") + + if ttype != winreg.REG_DWORD: + raise ValueError # pragma: no cover + + # Based on experimentation, bit 0x1 indicates that the + # device is disabled. + # + # XXXRTH I suspect we really want to & with 0x03 so + # that CONFIGFLAGS_REMOVED devices are also ignored, + # but we're shifting to WMI as ConfigFlags is not + # supposed to be used. + return not flags & 0x1 + + finally: + device_key.Close() + finally: + connection_key.Close() + except Exception: # pragma: no cover + return False + + def get(self): + """Extract resolver configuration from the Windows registry.""" + + lm = winreg.ConnectRegistry(None, winreg.HKEY_LOCAL_MACHINE) + try: + tcp_params = winreg.OpenKey( + lm, r"SYSTEM\CurrentControlSet\Services\Tcpip\Parameters" + ) + try: + self._config_fromkey(tcp_params, True) + finally: + tcp_params.Close() + interfaces = winreg.OpenKey( + lm, + r"SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces", + ) + try: + i = 0 + while True: + try: + guid = winreg.EnumKey(interfaces, i) + i += 1 + key = winreg.OpenKey(interfaces, guid) + try: + if not self._is_nic_enabled(lm, guid): + continue + self._config_fromkey(key, False) + finally: + key.Close() + except EnvironmentError: + break + finally: + interfaces.Close() + finally: + lm.Close() + return self.info + + _getter_class: Any + if _have_wmi and _prefer_wmi: + _getter_class = _WMIGetter + else: + _getter_class = _RegistryGetter + + def get_dns_info(): + """Extract resolver configuration.""" + getter = _getter_class() + return getter.get() diff --git a/venv/Lib/site-packages/dns/wire.py b/venv/Lib/site-packages/dns/wire.py new file mode 100644 index 00000000..9f9b1573 --- /dev/null +++ b/venv/Lib/site-packages/dns/wire.py @@ -0,0 +1,89 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +import contextlib +import struct +from typing import Iterator, Optional, Tuple + +import dns.exception +import dns.name + + +class Parser: + def __init__(self, wire: bytes, current: int = 0): + self.wire = wire + self.current = 0 + self.end = len(self.wire) + if current: + self.seek(current) + self.furthest = current + + def remaining(self) -> int: + return self.end - self.current + + def get_bytes(self, size: int) -> bytes: + assert size >= 0 + if size > self.remaining(): + raise dns.exception.FormError + output = self.wire[self.current : self.current + size] + self.current += size + self.furthest = max(self.furthest, self.current) + return output + + def get_counted_bytes(self, length_size: int = 1) -> bytes: + length = int.from_bytes(self.get_bytes(length_size), "big") + return self.get_bytes(length) + + def get_remaining(self) -> bytes: + return self.get_bytes(self.remaining()) + + def get_uint8(self) -> int: + return struct.unpack("!B", self.get_bytes(1))[0] + + def get_uint16(self) -> int: + return struct.unpack("!H", self.get_bytes(2))[0] + + def get_uint32(self) -> int: + return struct.unpack("!I", self.get_bytes(4))[0] + + def get_uint48(self) -> 
int: + return int.from_bytes(self.get_bytes(6), "big") + + def get_struct(self, format: str) -> Tuple: + return struct.unpack(format, self.get_bytes(struct.calcsize(format))) + + def get_name(self, origin: Optional["dns.name.Name"] = None) -> "dns.name.Name": + name = dns.name.from_wire_parser(self) + if origin: + name = name.relativize(origin) + return name + + def seek(self, where: int) -> None: + # Note that seeking to the end is OK! (If you try to read + # after such a seek, you'll get an exception as expected.) + if where < 0 or where > self.end: + raise dns.exception.FormError + self.current = where + + @contextlib.contextmanager + def restrict_to(self, size: int) -> Iterator: + assert size >= 0 + if size > self.remaining(): + raise dns.exception.FormError + saved_end = self.end + try: + self.end = self.current + size + yield + # We make this check here and not in the finally as we + # don't want to raise if we're already raising for some + # other reason. + if self.current != self.end: + raise dns.exception.FormError + finally: + self.end = saved_end + + @contextlib.contextmanager + def restore_furthest(self) -> Iterator: + try: + yield None + finally: + self.current = self.furthest diff --git a/venv/Lib/site-packages/dns/xfr.py b/venv/Lib/site-packages/dns/xfr.py new file mode 100644 index 00000000..dd247d33 --- /dev/null +++ b/venv/Lib/site-packages/dns/xfr.py @@ -0,0 +1,343 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2017 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +from typing import Any, List, Optional, Tuple, Union + +import dns.exception +import dns.message +import dns.name +import dns.rcode +import dns.rdataset +import dns.rdatatype +import dns.serial +import dns.transaction +import dns.tsig +import dns.zone + + +class TransferError(dns.exception.DNSException): + """A zone transfer response got a non-zero rcode.""" + + def __init__(self, rcode): + message = "Zone transfer error: %s" % dns.rcode.to_text(rcode) + super().__init__(message) + self.rcode = rcode + + +class SerialWentBackwards(dns.exception.FormError): + """The current serial number is less than the serial we know.""" + + +class UseTCP(dns.exception.DNSException): + """This IXFR cannot be completed with UDP.""" + + +class Inbound: + """ + State machine for zone transfers. + """ + + def __init__( + self, + txn_manager: dns.transaction.TransactionManager, + rdtype: dns.rdatatype.RdataType = dns.rdatatype.AXFR, + serial: Optional[int] = None, + is_udp: bool = False, + ): + """Initialize an inbound zone transfer. + + *txn_manager* is a :py:class:`dns.transaction.TransactionManager`. + + *rdtype* can be `dns.rdatatype.AXFR` or `dns.rdatatype.IXFR` + + *serial* is the base serial number for IXFRs, and is required in + that case. 
+ + *is_udp*, a ``bool`` indidicates if UDP is being used for this + XFR. + """ + self.txn_manager = txn_manager + self.txn: Optional[dns.transaction.Transaction] = None + self.rdtype = rdtype + if rdtype == dns.rdatatype.IXFR: + if serial is None: + raise ValueError("a starting serial must be supplied for IXFRs") + elif is_udp: + raise ValueError("is_udp specified for AXFR") + self.serial = serial + self.is_udp = is_udp + (_, _, self.origin) = txn_manager.origin_information() + self.soa_rdataset: Optional[dns.rdataset.Rdataset] = None + self.done = False + self.expecting_SOA = False + self.delete_mode = False + + def process_message(self, message: dns.message.Message) -> bool: + """Process one message in the transfer. + + The message should have the same relativization as was specified when + the `dns.xfr.Inbound` was created. The message should also have been + created with `one_rr_per_rrset=True` because order matters. + + Returns `True` if the transfer is complete, and `False` otherwise. + """ + if self.txn is None: + replacement = self.rdtype == dns.rdatatype.AXFR + self.txn = self.txn_manager.writer(replacement) + rcode = message.rcode() + if rcode != dns.rcode.NOERROR: + raise TransferError(rcode) + # + # We don't require a question section, but if it is present is + # should be correct. + # + if len(message.question) > 0: + if message.question[0].name != self.origin: + raise dns.exception.FormError("wrong question name") + if message.question[0].rdtype != self.rdtype: + raise dns.exception.FormError("wrong question rdatatype") + answer_index = 0 + if self.soa_rdataset is None: + # + # This is the first message. We're expecting an SOA at + # the origin. + # + if not message.answer or message.answer[0].name != self.origin: + raise dns.exception.FormError("No answer or RRset not for zone origin") + rrset = message.answer[0] + rdataset = rrset + if rdataset.rdtype != dns.rdatatype.SOA: + raise dns.exception.FormError("first RRset is not an SOA") + answer_index = 1 + self.soa_rdataset = rdataset.copy() + if self.rdtype == dns.rdatatype.IXFR: + if self.soa_rdataset[0].serial == self.serial: + # + # We're already up-to-date. + # + self.done = True + elif dns.serial.Serial(self.soa_rdataset[0].serial) < self.serial: + # It went backwards! + raise SerialWentBackwards + else: + if self.is_udp and len(message.answer[answer_index:]) == 0: + # + # There are no more records, so this is the + # "truncated" response. Say to use TCP + # + raise UseTCP + # + # Note we're expecting another SOA so we can detect + # if this IXFR response is an AXFR-style response. + # + self.expecting_SOA = True + # + # Process the answer section (other than the initial SOA in + # the first message). + # + for rrset in message.answer[answer_index:]: + name = rrset.name + rdataset = rrset + if self.done: + raise dns.exception.FormError("answers after final SOA") + assert self.txn is not None # for mypy + if rdataset.rdtype == dns.rdatatype.SOA and name == self.origin: + # + # Every time we see an origin SOA delete_mode inverts + # + if self.rdtype == dns.rdatatype.IXFR: + self.delete_mode = not self.delete_mode + # + # If this SOA Rdataset is equal to the first we saw + # then we're finished. If this is an IXFR we also + # check that we're seeing the record in the expected + # part of the response. 
+ # + if rdataset == self.soa_rdataset and ( + self.rdtype == dns.rdatatype.AXFR + or (self.rdtype == dns.rdatatype.IXFR and self.delete_mode) + ): + # + # This is the final SOA + # + if self.expecting_SOA: + # We got an empty IXFR sequence! + raise dns.exception.FormError("empty IXFR sequence") + if ( + self.rdtype == dns.rdatatype.IXFR + and self.serial != rdataset[0].serial + ): + raise dns.exception.FormError("unexpected end of IXFR sequence") + self.txn.replace(name, rdataset) + self.txn.commit() + self.txn = None + self.done = True + else: + # + # This is not the final SOA + # + self.expecting_SOA = False + if self.rdtype == dns.rdatatype.IXFR: + if self.delete_mode: + # This is the start of an IXFR deletion set + if rdataset[0].serial != self.serial: + raise dns.exception.FormError( + "IXFR base serial mismatch" + ) + else: + # This is the start of an IXFR addition set + self.serial = rdataset[0].serial + self.txn.replace(name, rdataset) + else: + # We saw a non-final SOA for the origin in an AXFR. + raise dns.exception.FormError("unexpected origin SOA in AXFR") + continue + if self.expecting_SOA: + # + # We made an IXFR request and are expecting another + # SOA RR, but saw something else, so this must be an + # AXFR response. + # + self.rdtype = dns.rdatatype.AXFR + self.expecting_SOA = False + self.delete_mode = False + self.txn.rollback() + self.txn = self.txn_manager.writer(True) + # + # Note we are falling through into the code below + # so whatever rdataset this was gets written. + # + # Add or remove the data + if self.delete_mode: + self.txn.delete_exact(name, rdataset) + else: + self.txn.add(name, rdataset) + if self.is_udp and not self.done: + # + # This is a UDP IXFR and we didn't get to done, and we didn't + # get the proper "truncated" response + # + raise dns.exception.FormError("unexpected end of UDP IXFR") + return self.done + + # + # Inbounds are context managers. + # + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.txn: + self.txn.rollback() + return False + + +def make_query( + txn_manager: dns.transaction.TransactionManager, + serial: Optional[int] = 0, + use_edns: Optional[Union[int, bool]] = None, + ednsflags: Optional[int] = None, + payload: Optional[int] = None, + request_payload: Optional[int] = None, + options: Optional[List[dns.edns.Option]] = None, + keyring: Any = None, + keyname: Optional[dns.name.Name] = None, + keyalgorithm: Union[dns.name.Name, str] = dns.tsig.default_algorithm, +) -> Tuple[dns.message.QueryMessage, Optional[int]]: + """Make an AXFR or IXFR query. + + *txn_manager* is a ``dns.transaction.TransactionManager``, typically a + ``dns.zone.Zone``. + + *serial* is an ``int`` or ``None``. If 0, then IXFR will be + attempted using the most recent serial number from the + *txn_manager*; it is the caller's responsibility to ensure there + are no write transactions active that could invalidate the + retrieved serial. If a serial cannot be determined, AXFR will be + forced. Other integer values are the starting serial to use. + ``None`` forces an AXFR. + + Please see the documentation for :py:func:`dns.message.make_query` and + :py:func:`dns.message.Message.use_tsig` for details on the other parameters + to this function. + + Returns a `(query, serial)` tuple. 
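+
+    For illustration (a sketch; the actual network I/O is elided)::
+
+        zone = dns.zone.Zone('example.')
+        (query, serial) = make_query(zone)
+        inbound = dns.xfr.Inbound(zone, query.question[0].rdtype, serial)
+        # send `query` to the primary, then feed each response message
+        # to inbound.process_message() until it returns True
+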
+ """ + (zone_origin, _, origin) = txn_manager.origin_information() + if zone_origin is None: + raise ValueError("no zone origin") + if serial is None: + rdtype = dns.rdatatype.AXFR + elif not isinstance(serial, int): + raise ValueError("serial is not an integer") + elif serial == 0: + with txn_manager.reader() as txn: + rdataset = txn.get(origin, "SOA") + if rdataset: + serial = rdataset[0].serial + rdtype = dns.rdatatype.IXFR + else: + serial = None + rdtype = dns.rdatatype.AXFR + elif serial > 0 and serial < 4294967296: + rdtype = dns.rdatatype.IXFR + else: + raise ValueError("serial out-of-range") + rdclass = txn_manager.get_class() + q = dns.message.make_query( + zone_origin, + rdtype, + rdclass, + use_edns, + False, + ednsflags, + payload, + request_payload, + options, + ) + if serial is not None: + rdata = dns.rdata.from_text(rdclass, "SOA", f". . {serial} 0 0 0 0") + rrset = q.find_rrset( + q.authority, zone_origin, rdclass, dns.rdatatype.SOA, create=True + ) + rrset.add(rdata, 0) + if keyring is not None: + q.use_tsig(keyring, keyname, algorithm=keyalgorithm) + return (q, serial) + + +def extract_serial_from_query(query: dns.message.Message) -> Optional[int]: + """Extract the SOA serial number from query if it is an IXFR and return + it, otherwise return None. + + *query* is a dns.message.QueryMessage that is an IXFR or AXFR request. + + Raises if the query is not an IXFR or AXFR, or if an IXFR doesn't have + an appropriate SOA RRset in the authority section. + """ + if not isinstance(query, dns.message.QueryMessage): + raise ValueError("query not a QueryMessage") + question = query.question[0] + if question.rdtype == dns.rdatatype.AXFR: + return None + elif question.rdtype != dns.rdatatype.IXFR: + raise ValueError("query is not an AXFR or IXFR") + soa = query.find_rrset( + query.authority, question.name, question.rdclass, dns.rdatatype.SOA + ) + return soa[0].serial diff --git a/venv/Lib/site-packages/dns/zone.py b/venv/Lib/site-packages/dns/zone.py new file mode 100644 index 00000000..844919e4 --- /dev/null +++ b/venv/Lib/site-packages/dns/zone.py @@ -0,0 +1,1434 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
+ +"""DNS Zones.""" + +import contextlib +import io +import os +import struct +from typing import ( + Any, + Callable, + Iterable, + Iterator, + List, + MutableMapping, + Optional, + Set, + Tuple, + Union, +) + +import dns.exception +import dns.grange +import dns.immutable +import dns.name +import dns.node +import dns.rdata +import dns.rdataclass +import dns.rdataset +import dns.rdatatype +import dns.rdtypes.ANY.SOA +import dns.rdtypes.ANY.ZONEMD +import dns.rrset +import dns.tokenizer +import dns.transaction +import dns.ttl +import dns.zonefile +from dns.zonetypes import DigestHashAlgorithm, DigestScheme, _digest_hashers + + +class BadZone(dns.exception.DNSException): + """The DNS zone is malformed.""" + + +class NoSOA(BadZone): + """The DNS zone has no SOA RR at its origin.""" + + +class NoNS(BadZone): + """The DNS zone has no NS RRset at its origin.""" + + +class UnknownOrigin(BadZone): + """The DNS zone's origin is unknown.""" + + +class UnsupportedDigestScheme(dns.exception.DNSException): + """The zone digest's scheme is unsupported.""" + + +class UnsupportedDigestHashAlgorithm(dns.exception.DNSException): + """The zone digest's origin is unsupported.""" + + +class NoDigest(dns.exception.DNSException): + """The DNS zone has no ZONEMD RRset at its origin.""" + + +class DigestVerificationFailure(dns.exception.DNSException): + """The ZONEMD digest failed to verify.""" + + +def _validate_name( + name: dns.name.Name, + origin: Optional[dns.name.Name], + relativize: bool, +) -> dns.name.Name: + # This name validation code is shared by Zone and Version + if origin is None: + # This should probably never happen as other code (e.g. + # _rr_line) will notice the lack of an origin before us, but + # we check just in case! + raise KeyError("no zone origin is defined") + if name.is_absolute(): + if not name.is_subdomain(origin): + raise KeyError("name parameter must be a subdomain of the zone origin") + if relativize: + name = name.relativize(origin) + else: + # We have a relative name. Make sure that the derelativized name is + # not too long. + try: + abs_name = name.derelativize(origin) + except dns.name.NameTooLong: + # We map dns.name.NameTooLong to KeyError to be consistent with + # the other exceptions above. + raise KeyError("relative name too long for zone") + if not relativize: + # We have a relative name in a non-relative zone, so use the + # derelativized name. + name = abs_name + return name + + +class Zone(dns.transaction.TransactionManager): + """A DNS zone. + + A ``Zone`` is a mapping from names to nodes. The zone object may be + treated like a Python dictionary, e.g. ``zone[name]`` will retrieve + the node associated with that name. The *name* may be a + ``dns.name.Name object``, or it may be a string. In either case, + if the name is relative it is treated as relative to the origin of + the zone. + """ + + node_factory: Callable[[], dns.node.Node] = dns.node.Node + map_factory: Callable[[], MutableMapping[dns.name.Name, dns.node.Node]] = dict + writable_version_factory: Optional[Callable[[], "WritableVersion"]] = None + immutable_version_factory: Optional[Callable[[], "ImmutableVersion"]] = None + + __slots__ = ["rdclass", "origin", "nodes", "relativize"] + + def __init__( + self, + origin: Optional[Union[dns.name.Name, str]], + rdclass: dns.rdataclass.RdataClass = dns.rdataclass.IN, + relativize: bool = True, + ): + """Initialize a zone object. + + *origin* is the origin of the zone. It may be a ``dns.name.Name``, + a ``str``, or ``None``. 
If ``None``, then the zone's origin will + be set by the first ``$ORIGIN`` line in a zone file. + + *rdclass*, an ``int``, the zone's rdata class; the default is class IN. + + *relativize*, a ``bool``, determine's whether domain names are + relativized to the zone's origin. The default is ``True``. + """ + + if origin is not None: + if isinstance(origin, str): + origin = dns.name.from_text(origin) + elif not isinstance(origin, dns.name.Name): + raise ValueError("origin parameter must be convertible to a DNS name") + if not origin.is_absolute(): + raise ValueError("origin parameter must be an absolute name") + self.origin = origin + self.rdclass = rdclass + self.nodes: MutableMapping[dns.name.Name, dns.node.Node] = self.map_factory() + self.relativize = relativize + + def __eq__(self, other): + """Two zones are equal if they have the same origin, class, and + nodes. + + Returns a ``bool``. + """ + + if not isinstance(other, Zone): + return False + if ( + self.rdclass != other.rdclass + or self.origin != other.origin + or self.nodes != other.nodes + ): + return False + return True + + def __ne__(self, other): + """Are two zones not equal? + + Returns a ``bool``. + """ + + return not self.__eq__(other) + + def _validate_name(self, name: Union[dns.name.Name, str]) -> dns.name.Name: + # Note that any changes in this method should have corresponding changes + # made in the Version _validate_name() method. + if isinstance(name, str): + name = dns.name.from_text(name, None) + elif not isinstance(name, dns.name.Name): + raise KeyError("name parameter must be convertible to a DNS name") + return _validate_name(name, self.origin, self.relativize) + + def __getitem__(self, key): + key = self._validate_name(key) + return self.nodes[key] + + def __setitem__(self, key, value): + key = self._validate_name(key) + self.nodes[key] = value + + def __delitem__(self, key): + key = self._validate_name(key) + del self.nodes[key] + + def __iter__(self): + return self.nodes.__iter__() + + def keys(self): + return self.nodes.keys() + + def values(self): + return self.nodes.values() + + def items(self): + return self.nodes.items() + + def get(self, key): + key = self._validate_name(key) + return self.nodes.get(key) + + def __contains__(self, key): + key = self._validate_name(key) + return key in self.nodes + + def find_node( + self, name: Union[dns.name.Name, str], create: bool = False + ) -> dns.node.Node: + """Find a node in the zone, possibly creating it. + + *name*: the name of the node to find. + The value may be a ``dns.name.Name`` or a ``str``. If absolute, the + name must be a subdomain of the zone's origin. If ``zone.relativize`` + is ``True``, then the name will be relativized. + + *create*, a ``bool``. If true, the node will be created if it does + not exist. + + Raises ``KeyError`` if the name is not known and create was + not specified, or if the name was not a subdomain of the origin. + + Returns a ``dns.node.Node``. + """ + + name = self._validate_name(name) + node = self.nodes.get(name) + if node is None: + if not create: + raise KeyError + node = self.node_factory() + self.nodes[name] = node + return node + + def get_node( + self, name: Union[dns.name.Name, str], create: bool = False + ) -> Optional[dns.node.Node]: + """Get a node in the zone, possibly creating it. + + This method is like ``find_node()``, except it returns None instead + of raising an exception if the node does not exist and creation + has not been requested. + + *name*: the name of the node to find. 
+        The value may be a ``dns.name.Name`` or a ``str``.  If absolute, the
+        name must be a subdomain of the zone's origin.  If ``zone.relativize``
+        is ``True``, then the name will be relativized.
+
+        *create*, a ``bool``.  If true, the node will be created if it does
+        not exist.
+
+        Returns a ``dns.node.Node`` or ``None``.
+        """
+
+        try:
+            node = self.find_node(name, create)
+        except KeyError:
+            node = None
+        return node
+
+    def delete_node(self, name: Union[dns.name.Name, str]) -> None:
+        """Delete the specified node if it exists.
+
+        *name*: the name of the node to find.
+        The value may be a ``dns.name.Name`` or a ``str``.  If absolute, the
+        name must be a subdomain of the zone's origin.  If ``zone.relativize``
+        is ``True``, then the name will be relativized.
+
+        It is not an error if the node does not exist.
+        """
+
+        name = self._validate_name(name)
+        if name in self.nodes:
+            del self.nodes[name]
+
+    def find_rdataset(
+        self,
+        name: Union[dns.name.Name, str],
+        rdtype: Union[dns.rdatatype.RdataType, str],
+        covers: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.NONE,
+        create: bool = False,
+    ) -> dns.rdataset.Rdataset:
+        """Look for an rdataset with the specified name and type in the zone,
+        and return an rdataset encapsulating it.
+
+        The rdataset returned is not a copy; changes to it will change
+        the zone.
+
+        KeyError is raised if the name or type are not found.
+
+        *name*: the name of the node to find.
+        The value may be a ``dns.name.Name`` or a ``str``.  If absolute, the
+        name must be a subdomain of the zone's origin.  If ``zone.relativize``
+        is ``True``, then the name will be relativized.
+
+        *rdtype*, a ``dns.rdatatype.RdataType`` or ``str``, the rdata type desired.
+
+        *covers*, a ``dns.rdatatype.RdataType`` or ``str``, the covered type.
+        Usually this value is ``dns.rdatatype.NONE``, but if the
+        rdtype is ``dns.rdatatype.SIG`` or ``dns.rdatatype.RRSIG``,
+        then the covers value will be the rdata type the SIG/RRSIG
+        covers.  The library treats the SIG and RRSIG types as if they
+        were a family of types, e.g. RRSIG(A), RRSIG(NS), RRSIG(SOA).
+        This makes RRSIGs much easier to work with than if RRSIGs
+        covering different rdata types were aggregated into a single
+        RRSIG rdataset.
+
+        *create*, a ``bool``.  If true, the node will be created if it does
+        not exist.
+
+        Raises ``KeyError`` if the name is not known and create was
+        not specified, or if the name was not a subdomain of the origin.
+
+        Returns a ``dns.rdataset.Rdataset``.
+        """
+
+        name = self._validate_name(name)
+        rdtype = dns.rdatatype.RdataType.make(rdtype)
+        covers = dns.rdatatype.RdataType.make(covers)
+        node = self.find_node(name, create)
+        return node.find_rdataset(self.rdclass, rdtype, covers, create)
+
+    def get_rdataset(
+        self,
+        name: Union[dns.name.Name, str],
+        rdtype: Union[dns.rdatatype.RdataType, str],
+        covers: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.NONE,
+        create: bool = False,
+    ) -> Optional[dns.rdataset.Rdataset]:
+        """Look for an rdataset with the specified name and type in the zone.
+
+        This method is like ``find_rdataset()``, except it returns None instead
+        of raising an exception if the rdataset does not exist and creation
+        has not been requested.
+
+        The rdataset returned is not a copy; changes to it will change
+        the zone.
+
+        *name*: the name of the node to find.
+        The value may be a ``dns.name.Name`` or a ``str``.  If absolute, the
+        name must be a subdomain of the zone's origin.  If ``zone.relativize``
+        is ``True``, then the name will be relativized.
+
+        *rdtype*, a ``dns.rdatatype.RdataType`` or ``str``, the rdata type desired.
+
+        *covers*, a ``dns.rdatatype.RdataType`` or ``str``, the covered type.
+        Usually this value is ``dns.rdatatype.NONE``, but if the
+        rdtype is ``dns.rdatatype.SIG`` or ``dns.rdatatype.RRSIG``,
+        then the covers value will be the rdata type the SIG/RRSIG
+        covers.  The library treats the SIG and RRSIG types as if they
+        were a family of types, e.g. RRSIG(A), RRSIG(NS), RRSIG(SOA).
+        This makes RRSIGs much easier to work with than if RRSIGs
+        covering different rdata types were aggregated into a single
+        RRSIG rdataset.
+
+        *create*, a ``bool``.  If true, the node will be created if it does
+        not exist.
+
+        Returns a ``dns.rdataset.Rdataset`` or ``None``.
+        """
+
+        try:
+            rdataset = self.find_rdataset(name, rdtype, covers, create)
+        except KeyError:
+            rdataset = None
+        return rdataset
+
+    def delete_rdataset(
+        self,
+        name: Union[dns.name.Name, str],
+        rdtype: Union[dns.rdatatype.RdataType, str],
+        covers: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.NONE,
+    ) -> None:
+        """Delete the rdataset matching *rdtype* and *covers*, if it
+        exists at the node specified by *name*.
+
+        It is not an error if the node does not exist, or if there is no matching
+        rdataset at the node.
+
+        If the node has no rdatasets after the deletion, it will itself be deleted.
+
+        *name*: the name of the node to find.  The value may be a ``dns.name.Name`` or a
+        ``str``.  If absolute, the name must be a subdomain of the zone's origin.  If
+        ``zone.relativize`` is ``True``, then the name will be relativized.
+
+        *rdtype*, a ``dns.rdatatype.RdataType`` or ``str``, the rdata type desired.
+
+        *covers*, a ``dns.rdatatype.RdataType`` or ``str`` or ``None``, the covered
+        type.  Usually this value is ``dns.rdatatype.NONE``, but if the rdtype is
+        ``dns.rdatatype.SIG`` or ``dns.rdatatype.RRSIG``, then the covers value will be
+        the rdata type the SIG/RRSIG covers.  The library treats the SIG and RRSIG types
+        as if they were a family of types, e.g. RRSIG(A), RRSIG(NS), RRSIG(SOA).  This
+        makes RRSIGs much easier to work with than if RRSIGs covering different rdata
+        types were aggregated into a single RRSIG rdataset.
+        """
+
+        name = self._validate_name(name)
+        rdtype = dns.rdatatype.RdataType.make(rdtype)
+        covers = dns.rdatatype.RdataType.make(covers)
+        node = self.get_node(name)
+        if node is not None:
+            node.delete_rdataset(self.rdclass, rdtype, covers)
+            if len(node) == 0:
+                self.delete_node(name)
+
+    def replace_rdataset(
+        self, name: Union[dns.name.Name, str], replacement: dns.rdataset.Rdataset
+    ) -> None:
+        """Replace an rdataset at name.
+
+        It is not an error if there is no rdataset matching *replacement*.
+
+        Ownership of the *replacement* object is transferred to the zone;
+        in other words, this method does not store a copy of *replacement*
+        at the node, it stores *replacement* itself.
+
+        If the node does not exist, it is created.
+
+        *name*: the name of the node to find.
+        The value may be a ``dns.name.Name`` or a ``str``.  If absolute, the
+        name must be a subdomain of the zone's origin.  If ``zone.relativize``
+        is ``True``, then the name will be relativized.
+
+        *replacement*, a ``dns.rdataset.Rdataset``, the replacement rdataset.
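+
+        Example (a minimal sketch, assuming *zone* is a relativized
+        ``dns.zone.Zone`` with an origin)::
+
+            rds = dns.rdataset.from_text("IN", "A", 300, "10.0.0.1", "10.0.0.2")
+            zone.replace_rdataset("www", rds)  # creates the "www" node if needed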
+ """ + + if replacement.rdclass != self.rdclass: + raise ValueError("replacement.rdclass != zone.rdclass") + node = self.find_node(name, True) + node.replace_rdataset(replacement) + + def find_rrset( + self, + name: Union[dns.name.Name, str], + rdtype: Union[dns.rdatatype.RdataType, str], + covers: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.NONE, + ) -> dns.rrset.RRset: + """Look for an rdataset with the specified name and type in the zone, + and return an RRset encapsulating it. + + This method is less efficient than the similar + ``find_rdataset()`` because it creates an RRset instead of + returning the matching rdataset. It may be more convenient + for some uses since it returns an object which binds the owner + name to the rdataset. + + This method may not be used to create new nodes or rdatasets; + use ``find_rdataset`` instead. + + *name*: the name of the node to find. + The value may be a ``dns.name.Name`` or a ``str``. If absolute, the + name must be a subdomain of the zone's origin. If ``zone.relativize`` + is ``True``, then the name will be relativized. + + *rdtype*, a ``dns.rdatatype.RdataType`` or ``str``, the rdata type desired. + + *covers*, a ``dns.rdatatype.RdataType`` or ``str``, the covered type. + Usually this value is ``dns.rdatatype.NONE``, but if the + rdtype is ``dns.rdatatype.SIG`` or ``dns.rdatatype.RRSIG``, + then the covers value will be the rdata type the SIG/RRSIG + covers. The library treats the SIG and RRSIG types as if they + were a family of types, e.g. RRSIG(A), RRSIG(NS), RRSIG(SOA). + This makes RRSIGs much easier to work with than if RRSIGs + covering different rdata types were aggregated into a single + RRSIG rdataset. + + *create*, a ``bool``. If true, the node will be created if it does + not exist. + + Raises ``KeyError`` if the name is not known and create was + not specified, or if the name was not a subdomain of the origin. + + Returns a ``dns.rrset.RRset`` or ``None``. + """ + + vname = self._validate_name(name) + rdtype = dns.rdatatype.RdataType.make(rdtype) + covers = dns.rdatatype.RdataType.make(covers) + rdataset = self.nodes[vname].find_rdataset(self.rdclass, rdtype, covers) + rrset = dns.rrset.RRset(vname, self.rdclass, rdtype, covers) + rrset.update(rdataset) + return rrset + + def get_rrset( + self, + name: Union[dns.name.Name, str], + rdtype: Union[dns.rdatatype.RdataType, str], + covers: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.NONE, + ) -> Optional[dns.rrset.RRset]: + """Look for an rdataset with the specified name and type in the zone, + and return an RRset encapsulating it. + + This method is less efficient than the similar ``get_rdataset()`` + because it creates an RRset instead of returning the matching + rdataset. It may be more convenient for some uses since it + returns an object which binds the owner name to the rdataset. + + This method may not be used to create new nodes or rdatasets; + use ``get_rdataset()`` instead. + + *name*: the name of the node to find. + The value may be a ``dns.name.Name`` or a ``str``. If absolute, the + name must be a subdomain of the zone's origin. If ``zone.relativize`` + is ``True``, then the name will be relativized. + + *rdtype*, a ``dns.rdataset.Rdataset`` or ``str``, the rdata type desired. + + *covers*, a ``dns.rdataset.Rdataset`` or ``str``, the covered type. + Usually this value is ``dns.rdatatype.NONE``, but if the + rdtype is ``dns.rdatatype.SIG`` or ``dns.rdatatype.RRSIG``, + then the covers value will be the rdata type the SIG/RRSIG + covers. 
The library treats the SIG and RRSIG types as if they + were a family of types, e.g. RRSIG(A), RRSIG(NS), RRSIG(SOA). + This makes RRSIGs much easier to work with than if RRSIGs + covering different rdata types were aggregated into a single + RRSIG rdataset. + + *create*, a ``bool``. If true, the node will be created if it does + not exist. + + Returns a ``dns.rrset.RRset`` or ``None``. + """ + + try: + rrset = self.find_rrset(name, rdtype, covers) + except KeyError: + rrset = None + return rrset + + def iterate_rdatasets( + self, + rdtype: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.ANY, + covers: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.NONE, + ) -> Iterator[Tuple[dns.name.Name, dns.rdataset.Rdataset]]: + """Return a generator which yields (name, rdataset) tuples for + all rdatasets in the zone which have the specified *rdtype* + and *covers*. If *rdtype* is ``dns.rdatatype.ANY``, the default, + then all rdatasets will be matched. + + *rdtype*, a ``dns.rdataset.Rdataset`` or ``str``, the rdata type desired. + + *covers*, a ``dns.rdataset.Rdataset`` or ``str``, the covered type. + Usually this value is ``dns.rdatatype.NONE``, but if the + rdtype is ``dns.rdatatype.SIG`` or ``dns.rdatatype.RRSIG``, + then the covers value will be the rdata type the SIG/RRSIG + covers. The library treats the SIG and RRSIG types as if they + were a family of types, e.g. RRSIG(A), RRSIG(NS), RRSIG(SOA). + This makes RRSIGs much easier to work with than if RRSIGs + covering different rdata types were aggregated into a single + RRSIG rdataset. + """ + + rdtype = dns.rdatatype.RdataType.make(rdtype) + covers = dns.rdatatype.RdataType.make(covers) + for name, node in self.items(): + for rds in node: + if rdtype == dns.rdatatype.ANY or ( + rds.rdtype == rdtype and rds.covers == covers + ): + yield (name, rds) + + def iterate_rdatas( + self, + rdtype: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.ANY, + covers: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.NONE, + ) -> Iterator[Tuple[dns.name.Name, int, dns.rdata.Rdata]]: + """Return a generator which yields (name, ttl, rdata) tuples for + all rdatas in the zone which have the specified *rdtype* + and *covers*. If *rdtype* is ``dns.rdatatype.ANY``, the default, + then all rdatas will be matched. + + *rdtype*, a ``dns.rdataset.Rdataset`` or ``str``, the rdata type desired. + + *covers*, a ``dns.rdataset.Rdataset`` or ``str``, the covered type. + Usually this value is ``dns.rdatatype.NONE``, but if the + rdtype is ``dns.rdatatype.SIG`` or ``dns.rdatatype.RRSIG``, + then the covers value will be the rdata type the SIG/RRSIG + covers. The library treats the SIG and RRSIG types as if they + were a family of types, e.g. RRSIG(A), RRSIG(NS), RRSIG(SOA). + This makes RRSIGs much easier to work with than if RRSIGs + covering different rdata types were aggregated into a single + RRSIG rdataset. + """ + + rdtype = dns.rdatatype.RdataType.make(rdtype) + covers = dns.rdatatype.RdataType.make(covers) + for name, node in self.items(): + for rds in node: + if rdtype == dns.rdatatype.ANY or ( + rds.rdtype == rdtype and rds.covers == covers + ): + for rdata in rds: + yield (name, rds.ttl, rdata) + + def to_file( + self, + f: Any, + sorted: bool = True, + relativize: bool = True, + nl: Optional[str] = None, + want_comments: bool = False, + want_origin: bool = False, + ) -> None: + """Write a zone to a file. + + *f*, a file or `str`. If *f* is a string, it is treated + as the name of a file to open. + + *sorted*, a ``bool``. 
If True, the default, then the file + will be written with the names sorted in DNSSEC order from + least to greatest. Otherwise the names will be written in + whatever order they happen to have in the zone's dictionary. + + *relativize*, a ``bool``. If True, the default, then domain + names in the output will be relativized to the zone's origin + if possible. + + *nl*, a ``str`` or None. The end of line string. If not + ``None``, the output will use the platform's native + end-of-line marker (i.e. LF on POSIX, CRLF on Windows). + + *want_comments*, a ``bool``. If ``True``, emit end-of-line comments + as part of writing the file. If ``False``, the default, do not + emit them. + + *want_origin*, a ``bool``. If ``True``, emit a $ORIGIN line at + the start of the file. If ``False``, the default, do not emit + one. + """ + + if isinstance(f, str): + cm: contextlib.AbstractContextManager = open(f, "wb") + else: + cm = contextlib.nullcontext(f) + with cm as f: + # must be in this way, f.encoding may contain None, or even + # attribute may not be there + file_enc = getattr(f, "encoding", None) + if file_enc is None: + file_enc = "utf-8" + + if nl is None: + # binary mode, '\n' is not enough + nl_b = os.linesep.encode(file_enc) + nl = "\n" + elif isinstance(nl, str): + nl_b = nl.encode(file_enc) + else: + nl_b = nl + nl = nl.decode() + + if want_origin: + assert self.origin is not None + l = "$ORIGIN " + self.origin.to_text() + l_b = l.encode(file_enc) + try: + f.write(l_b) + f.write(nl_b) + except TypeError: # textual mode + f.write(l) + f.write(nl) + + if sorted: + names = list(self.keys()) + names.sort() + else: + names = self.keys() + for n in names: + l = self[n].to_text( + n, + origin=self.origin, + relativize=relativize, + want_comments=want_comments, + ) + l_b = l.encode(file_enc) + + try: + f.write(l_b) + f.write(nl_b) + except TypeError: # textual mode + f.write(l) + f.write(nl) + + def to_text( + self, + sorted: bool = True, + relativize: bool = True, + nl: Optional[str] = None, + want_comments: bool = False, + want_origin: bool = False, + ) -> str: + """Return a zone's text as though it were written to a file. + + *sorted*, a ``bool``. If True, the default, then the file + will be written with the names sorted in DNSSEC order from + least to greatest. Otherwise the names will be written in + whatever order they happen to have in the zone's dictionary. + + *relativize*, a ``bool``. If True, the default, then domain + names in the output will be relativized to the zone's origin + if possible. + + *nl*, a ``str`` or None. The end of line string. If not + ``None``, the output will use the platform's native + end-of-line marker (i.e. LF on POSIX, CRLF on Windows). + + *want_comments*, a ``bool``. If ``True``, emit end-of-line comments + as part of writing the file. If ``False``, the default, do not + emit them. + + *want_origin*, a ``bool``. If ``True``, emit a $ORIGIN line at + the start of the output. If ``False``, the default, do not emit + one. + + Returns a ``str``. + """ + temp_buffer = io.StringIO() + self.to_file(temp_buffer, sorted, relativize, nl, want_comments, want_origin) + return_value = temp_buffer.getvalue() + temp_buffer.close() + return return_value + + def check_origin(self) -> None: + """Do some simple checking of the zone's origin. + + Raises ``dns.zone.NoSOA`` if there is no SOA RRset. + + Raises ``dns.zone.NoNS`` if there is no NS RRset. + + Raises ``KeyError`` if there is no origin node. 
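+
+        Example (a minimal sketch; ``zone_text`` is a placeholder for a
+        zone file string)::
+
+            zone = dns.zone.from_text(zone_text, "example.", check_origin=False)
+            zone.check_origin()  # NoSOA or NoNS is raised if either is missing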
+ """ + if self.relativize: + name = dns.name.empty + else: + assert self.origin is not None + name = self.origin + if self.get_rdataset(name, dns.rdatatype.SOA) is None: + raise NoSOA + if self.get_rdataset(name, dns.rdatatype.NS) is None: + raise NoNS + + def get_soa( + self, txn: Optional[dns.transaction.Transaction] = None + ) -> dns.rdtypes.ANY.SOA.SOA: + """Get the zone SOA rdata. + + Raises ``dns.zone.NoSOA`` if there is no SOA RRset. + + Returns a ``dns.rdtypes.ANY.SOA.SOA`` Rdata. + """ + if self.relativize: + origin_name = dns.name.empty + else: + if self.origin is None: + # get_soa() has been called very early, and there must not be + # an SOA if there is no origin. + raise NoSOA + origin_name = self.origin + soa: Optional[dns.rdataset.Rdataset] + if txn: + soa = txn.get(origin_name, dns.rdatatype.SOA) + else: + soa = self.get_rdataset(origin_name, dns.rdatatype.SOA) + if soa is None: + raise NoSOA + return soa[0] + + def _compute_digest( + self, + hash_algorithm: DigestHashAlgorithm, + scheme: DigestScheme = DigestScheme.SIMPLE, + ) -> bytes: + hashinfo = _digest_hashers.get(hash_algorithm) + if not hashinfo: + raise UnsupportedDigestHashAlgorithm + if scheme != DigestScheme.SIMPLE: + raise UnsupportedDigestScheme + + if self.relativize: + origin_name = dns.name.empty + else: + assert self.origin is not None + origin_name = self.origin + hasher = hashinfo() + for name, node in sorted(self.items()): + rrnamebuf = name.to_digestable(self.origin) + for rdataset in sorted(node, key=lambda rds: (rds.rdtype, rds.covers)): + if name == origin_name and dns.rdatatype.ZONEMD in ( + rdataset.rdtype, + rdataset.covers, + ): + continue + rrfixed = struct.pack( + "!HHI", rdataset.rdtype, rdataset.rdclass, rdataset.ttl + ) + rdatas = [rdata.to_digestable(self.origin) for rdata in rdataset] + for rdata in sorted(rdatas): + rrlen = struct.pack("!H", len(rdata)) + hasher.update(rrnamebuf + rrfixed + rrlen + rdata) + return hasher.digest() + + def compute_digest( + self, + hash_algorithm: DigestHashAlgorithm, + scheme: DigestScheme = DigestScheme.SIMPLE, + ) -> dns.rdtypes.ANY.ZONEMD.ZONEMD: + serial = self.get_soa().serial + digest = self._compute_digest(hash_algorithm, scheme) + return dns.rdtypes.ANY.ZONEMD.ZONEMD( + self.rdclass, dns.rdatatype.ZONEMD, serial, scheme, hash_algorithm, digest + ) + + def verify_digest( + self, zonemd: Optional[dns.rdtypes.ANY.ZONEMD.ZONEMD] = None + ) -> None: + digests: Union[dns.rdataset.Rdataset, List[dns.rdtypes.ANY.ZONEMD.ZONEMD]] + if zonemd: + digests = [zonemd] + else: + assert self.origin is not None + rds = self.get_rdataset(self.origin, dns.rdatatype.ZONEMD) + if rds is None: + raise NoDigest + digests = rds + for digest in digests: + try: + computed = self._compute_digest(digest.hash_algorithm, digest.scheme) + if computed == digest.digest: + return + except Exception: + pass + raise DigestVerificationFailure + + # TransactionManager methods + + def reader(self) -> "Transaction": + return Transaction(self, False, Version(self, 1, self.nodes, self.origin)) + + def writer(self, replacement: bool = False) -> "Transaction": + txn = Transaction(self, replacement) + txn._setup_version() + return txn + + def origin_information( + self, + ) -> Tuple[Optional[dns.name.Name], bool, Optional[dns.name.Name]]: + effective: Optional[dns.name.Name] + if self.relativize: + effective = dns.name.empty + else: + effective = self.origin + return (self.origin, self.relativize, effective) + + def get_class(self): + return self.rdclass + + # Transaction methods + + def 
_end_read(self, txn): + pass + + def _end_write(self, txn): + pass + + def _commit_version(self, _, version, origin): + self.nodes = version.nodes + if self.origin is None: + self.origin = origin + + def _get_next_version_id(self): + # Versions are ephemeral and all have id 1 + return 1 + + +# These classes used to be in dns.versioned, but have moved here so we can use +# the copy-on-write transaction mechanism for both kinds of zones. In a +# regular zone, the version only exists during the transaction, and the nodes +# are regular dns.node.Nodes. + +# A node with a version id. + + +class VersionedNode(dns.node.Node): # lgtm[py/missing-equals] + __slots__ = ["id"] + + def __init__(self): + super().__init__() + # A proper id will get set by the Version + self.id = 0 + + +@dns.immutable.immutable +class ImmutableVersionedNode(VersionedNode): + def __init__(self, node): + super().__init__() + self.id = node.id + self.rdatasets = tuple( + [dns.rdataset.ImmutableRdataset(rds) for rds in node.rdatasets] + ) + + def find_rdataset( + self, + rdclass: dns.rdataclass.RdataClass, + rdtype: dns.rdatatype.RdataType, + covers: dns.rdatatype.RdataType = dns.rdatatype.NONE, + create: bool = False, + ) -> dns.rdataset.Rdataset: + if create: + raise TypeError("immutable") + return super().find_rdataset(rdclass, rdtype, covers, False) + + def get_rdataset( + self, + rdclass: dns.rdataclass.RdataClass, + rdtype: dns.rdatatype.RdataType, + covers: dns.rdatatype.RdataType = dns.rdatatype.NONE, + create: bool = False, + ) -> Optional[dns.rdataset.Rdataset]: + if create: + raise TypeError("immutable") + return super().get_rdataset(rdclass, rdtype, covers, False) + + def delete_rdataset( + self, + rdclass: dns.rdataclass.RdataClass, + rdtype: dns.rdatatype.RdataType, + covers: dns.rdatatype.RdataType = dns.rdatatype.NONE, + ) -> None: + raise TypeError("immutable") + + def replace_rdataset(self, replacement: dns.rdataset.Rdataset) -> None: + raise TypeError("immutable") + + def is_immutable(self) -> bool: + return True + + +class Version: + def __init__( + self, + zone: Zone, + id: int, + nodes: Optional[MutableMapping[dns.name.Name, dns.node.Node]] = None, + origin: Optional[dns.name.Name] = None, + ): + self.zone = zone + self.id = id + if nodes is not None: + self.nodes = nodes + else: + self.nodes = zone.map_factory() + self.origin = origin + + def _validate_name(self, name: dns.name.Name) -> dns.name.Name: + return _validate_name(name, self.origin, self.zone.relativize) + + def get_node(self, name: dns.name.Name) -> Optional[dns.node.Node]: + name = self._validate_name(name) + return self.nodes.get(name) + + def get_rdataset( + self, + name: dns.name.Name, + rdtype: dns.rdatatype.RdataType, + covers: dns.rdatatype.RdataType, + ) -> Optional[dns.rdataset.Rdataset]: + node = self.get_node(name) + if node is None: + return None + return node.get_rdataset(self.zone.rdclass, rdtype, covers) + + def keys(self): + return self.nodes.keys() + + def items(self): + return self.nodes.items() + + +class WritableVersion(Version): + def __init__(self, zone: Zone, replacement: bool = False): + # The zone._versions_lock must be held by our caller in a versioned + # zone. + id = zone._get_next_version_id() + super().__init__(zone, id) + if not replacement: + # We copy the map, because that gives us a simple and thread-safe + # way of doing versions, and we have a garbage collector to help + # us. We only make new node objects if we actually change the + # node. 
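+            # The copy is shallow: node objects are shared with the prior
+            # version until _maybe_cow() below replaces a node the first
+            # time it is written to.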
+ self.nodes.update(zone.nodes) + # We have to copy the zone origin as it may be None in the first + # version, and we don't want to mutate the zone until we commit. + self.origin = zone.origin + self.changed: Set[dns.name.Name] = set() + + def _maybe_cow(self, name: dns.name.Name) -> dns.node.Node: + name = self._validate_name(name) + node = self.nodes.get(name) + if node is None or name not in self.changed: + new_node = self.zone.node_factory() + if hasattr(new_node, "id"): + # We keep doing this for backwards compatibility, as earlier + # code used new_node.id != self.id for the "do we need to CoW?" + # test. Now we use the changed set as this works with both + # regular zones and versioned zones. + # + # We ignore the mypy error as this is safe but it doesn't see it. + new_node.id = self.id # type: ignore + if node is not None: + # moo! copy on write! + new_node.rdatasets.extend(node.rdatasets) + self.nodes[name] = new_node + self.changed.add(name) + return new_node + else: + return node + + def delete_node(self, name: dns.name.Name) -> None: + name = self._validate_name(name) + if name in self.nodes: + del self.nodes[name] + self.changed.add(name) + + def put_rdataset( + self, name: dns.name.Name, rdataset: dns.rdataset.Rdataset + ) -> None: + node = self._maybe_cow(name) + node.replace_rdataset(rdataset) + + def delete_rdataset( + self, + name: dns.name.Name, + rdtype: dns.rdatatype.RdataType, + covers: dns.rdatatype.RdataType, + ) -> None: + node = self._maybe_cow(name) + node.delete_rdataset(self.zone.rdclass, rdtype, covers) + if len(node) == 0: + del self.nodes[name] + + +@dns.immutable.immutable +class ImmutableVersion(Version): + def __init__(self, version: WritableVersion): + # We tell super() that it's a replacement as we don't want it + # to copy the nodes, as we're about to do that with an + # immutable Dict. + super().__init__(version.zone, True) + # set the right id! + self.id = version.id + # keep the origin + self.origin = version.origin + # Make changed nodes immutable + for name in version.changed: + node = version.nodes.get(name) + # it might not exist if we deleted it in the version + if node: + version.nodes[name] = ImmutableVersionedNode(node) + # We're changing the type of the nodes dictionary here on purpose, so + # we ignore the mypy error. 
+ self.nodes = dns.immutable.Dict( + version.nodes, True, self.zone.map_factory + ) # type: ignore + + +class Transaction(dns.transaction.Transaction): + def __init__(self, zone, replacement, version=None, make_immutable=False): + read_only = version is not None + super().__init__(zone, replacement, read_only) + self.version = version + self.make_immutable = make_immutable + + @property + def zone(self): + return self.manager + + def _setup_version(self): + assert self.version is None + factory = self.manager.writable_version_factory + if factory is None: + factory = WritableVersion + self.version = factory(self.zone, self.replacement) + + def _get_rdataset(self, name, rdtype, covers): + return self.version.get_rdataset(name, rdtype, covers) + + def _put_rdataset(self, name, rdataset): + assert not self.read_only + self.version.put_rdataset(name, rdataset) + + def _delete_name(self, name): + assert not self.read_only + self.version.delete_node(name) + + def _delete_rdataset(self, name, rdtype, covers): + assert not self.read_only + self.version.delete_rdataset(name, rdtype, covers) + + def _name_exists(self, name): + return self.version.get_node(name) is not None + + def _changed(self): + if self.read_only: + return False + else: + return len(self.version.changed) > 0 + + def _end_transaction(self, commit): + if self.read_only: + self.zone._end_read(self) + elif commit and len(self.version.changed) > 0: + if self.make_immutable: + factory = self.manager.immutable_version_factory + if factory is None: + factory = ImmutableVersion + version = factory(self.version) + else: + version = self.version + self.zone._commit_version(self, version, self.version.origin) + else: + # rollback + self.zone._end_write(self) + + def _set_origin(self, origin): + if self.version.origin is None: + self.version.origin = origin + + def _iterate_rdatasets(self): + for name, node in self.version.items(): + for rdataset in node: + yield (name, rdataset) + + def _iterate_names(self): + return self.version.keys() + + def _get_node(self, name): + return self.version.get_node(name) + + def _origin_information(self): + (absolute, relativize, effective) = self.manager.origin_information() + if absolute is None and self.version.origin is not None: + # No origin has been committed yet, but we've learned one as part of + # this txn. Use it. + absolute = self.version.origin + if relativize: + effective = dns.name.empty + else: + effective = absolute + return (absolute, relativize, effective) + + +def _from_text( + text: Any, + origin: Optional[Union[dns.name.Name, str]] = None, + rdclass: dns.rdataclass.RdataClass = dns.rdataclass.IN, + relativize: bool = True, + zone_factory: Any = Zone, + filename: Optional[str] = None, + allow_include: bool = False, + check_origin: bool = True, + idna_codec: Optional[dns.name.IDNACodec] = None, + allow_directives: Union[bool, Iterable[str]] = True, +) -> Zone: + # See the comments for the public APIs from_text() and from_file() for + # details. + + # 'text' can also be a file, but we don't publish that fact + # since it's an implementation detail. The official file + # interface is from_file(). 
+
+    if filename is None:
+        filename = "<string>"
+    zone = zone_factory(origin, rdclass, relativize=relativize)
+    with zone.writer(True) as txn:
+        tok = dns.tokenizer.Tokenizer(text, filename, idna_codec=idna_codec)
+        reader = dns.zonefile.Reader(
+            tok,
+            rdclass,
+            txn,
+            allow_include=allow_include,
+            allow_directives=allow_directives,
+        )
+        try:
+            reader.read()
+        except dns.zonefile.UnknownOrigin:
+            # for backwards compatibility
+            raise dns.zone.UnknownOrigin
+    # Now that we're done reading, do some basic checking of the zone.
+    if check_origin:
+        zone.check_origin()
+    return zone
+
+
+def from_text(
+    text: str,
+    origin: Optional[Union[dns.name.Name, str]] = None,
+    rdclass: dns.rdataclass.RdataClass = dns.rdataclass.IN,
+    relativize: bool = True,
+    zone_factory: Any = Zone,
+    filename: Optional[str] = None,
+    allow_include: bool = False,
+    check_origin: bool = True,
+    idna_codec: Optional[dns.name.IDNACodec] = None,
+    allow_directives: Union[bool, Iterable[str]] = True,
+) -> Zone:
+    """Build a zone object from a zone file format string.
+
+    *text*, a ``str``, the zone file format input.
+
+    *origin*, a ``dns.name.Name``, a ``str``, or ``None``.  The origin
+    of the zone; if not specified, the first ``$ORIGIN`` statement in the
+    zone file will determine the origin of the zone.
+
+    *rdclass*, a ``dns.rdataclass.RdataClass``, the zone's rdata class; the default is
+    class IN.
+
+    *relativize*, a ``bool``, determines whether domain names are
+    relativized to the zone's origin.  The default is ``True``.
+
+    *zone_factory*, the zone factory to use or ``None``.  If ``None``, then
+    ``dns.zone.Zone`` will be used.  The value may be any class or callable
+    that returns a subclass of ``dns.zone.Zone``.
+
+    *filename*, a ``str`` or ``None``, the filename to emit when
+    describing where an error occurred; the default is ``'<string>'``.
+
+    *allow_include*, a ``bool``.  If ``True``, then ``$INCLUDE``
+    directives are permitted.  If ``False``, the default, then encountering
+    a ``$INCLUDE`` will raise a ``SyntaxError`` exception.
+
+    *check_origin*, a ``bool``.  If ``True``, the default, then sanity
+    checks of the origin node will be made by calling the zone's
+    ``check_origin()`` method.
+
+    *idna_codec*, a ``dns.name.IDNACodec``, specifies the IDNA
+    encoder/decoder.  If ``None``, the default IDNA 2003 encoder/decoder
+    is used.
+
+    *allow_directives*, a ``bool`` or an iterable of `str`.  If ``True``, the default,
+    then directives are permitted, and the *allow_include* parameter controls whether
+    ``$INCLUDE`` is permitted.  If ``False`` or an empty iterable, then no directive
+    processing is done and any directive-like text will be treated as a regular owner
+    name.  If a non-empty iterable, then only the listed directives (including the
+    ``$``) are allowed.
+
+    Raises ``dns.zone.NoSOA`` if there is no SOA RRset.
+
+    Raises ``dns.zone.NoNS`` if there is no NS RRset.
+
+    Raises ``KeyError`` if there is no origin node.
+
+    Returns a subclass of ``dns.zone.Zone``.
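+
+    Example (a minimal sketch)::
+
+        zone = dns.zone.from_text(
+            "@ 3600 IN SOA ns1 hostmaster 1 7200 900 1209600 86400\n"
+            "@ 3600 IN NS ns1\n"
+            "ns1 3600 IN A 10.0.0.1\n",
+            origin="example.",
+        )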
+ """ + return _from_text( + text, + origin, + rdclass, + relativize, + zone_factory, + filename, + allow_include, + check_origin, + idna_codec, + allow_directives, + ) + + +def from_file( + f: Any, + origin: Optional[Union[dns.name.Name, str]] = None, + rdclass: dns.rdataclass.RdataClass = dns.rdataclass.IN, + relativize: bool = True, + zone_factory: Any = Zone, + filename: Optional[str] = None, + allow_include: bool = True, + check_origin: bool = True, + idna_codec: Optional[dns.name.IDNACodec] = None, + allow_directives: Union[bool, Iterable[str]] = True, +) -> Zone: + """Read a zone file and build a zone object. + + *f*, a file or ``str``. If *f* is a string, it is treated + as the name of a file to open. + + *origin*, a ``dns.name.Name``, a ``str``, or ``None``. The origin + of the zone; if not specified, the first ``$ORIGIN`` statement in the + zone file will determine the origin of the zone. + + *rdclass*, an ``int``, the zone's rdata class; the default is class IN. + + *relativize*, a ``bool``, determine's whether domain names are + relativized to the zone's origin. The default is ``True``. + + *zone_factory*, the zone factory to use or ``None``. If ``None``, then + ``dns.zone.Zone`` will be used. The value may be any class or callable + that returns a subclass of ``dns.zone.Zone``. + + *filename*, a ``str`` or ``None``, the filename to emit when + describing where an error occurred; the default is ``''``. + + *allow_include*, a ``bool``. If ``True``, the default, then ``$INCLUDE`` + directives are permitted. If ``False``, then encoutering a ``$INCLUDE`` + will raise a ``SyntaxError`` exception. + + *check_origin*, a ``bool``. If ``True``, the default, then sanity + checks of the origin node will be made by calling the zone's + ``check_origin()`` method. + + *idna_codec*, a ``dns.name.IDNACodec``, specifies the IDNA + encoder/decoder. If ``None``, the default IDNA 2003 encoder/decoder + is used. + + *allow_directives*, a ``bool`` or an iterable of `str`. If ``True``, the default, + then directives are permitted, and the *allow_include* parameter controls whether + ``$INCLUDE`` is permitted. If ``False`` or an empty iterable, then no directive + processing is done and any directive-like text will be treated as a regular owner + name. If a non-empty iterable, then only the listed directives (including the + ``$``) are allowed. + + Raises ``dns.zone.NoSOA`` if there is no SOA RRset. + + Raises ``dns.zone.NoNS`` if there is no NS RRset. + + Raises ``KeyError`` if there is no origin node. + + Returns a subclass of ``dns.zone.Zone``. + """ + + if isinstance(f, str): + if filename is None: + filename = f + cm: contextlib.AbstractContextManager = open(f) + else: + cm = contextlib.nullcontext(f) + with cm as f: + return _from_text( + f, + origin, + rdclass, + relativize, + zone_factory, + filename, + allow_include, + check_origin, + idna_codec, + allow_directives, + ) + assert False # make mypy happy lgtm[py/unreachable-statement] + + +def from_xfr( + xfr: Any, + zone_factory: Any = Zone, + relativize: bool = True, + check_origin: bool = True, +) -> Zone: + """Convert the output of a zone transfer generator into a zone object. + + *xfr*, a generator of ``dns.message.Message`` objects, typically + ``dns.query.xfr()``. + + *relativize*, a ``bool``, determine's whether domain names are + relativized to the zone's origin. The default is ``True``. + It is essential that the relativize setting matches the one specified + to the generator. + + *check_origin*, a ``bool``. 
If ``True``, the default, then sanity + checks of the origin node will be made by calling the zone's + ``check_origin()`` method. + + Raises ``dns.zone.NoSOA`` if there is no SOA RRset. + + Raises ``dns.zone.NoNS`` if there is no NS RRset. + + Raises ``KeyError`` if there is no origin node. + + Raises ``ValueError`` if no messages are yielded by the generator. + + Returns a subclass of ``dns.zone.Zone``. + """ + + z = None + for r in xfr: + if z is None: + if relativize: + origin = r.origin + else: + origin = r.answer[0].name + rdclass = r.answer[0].rdclass + z = zone_factory(origin, rdclass, relativize=relativize) + for rrset in r.answer: + znode = z.nodes.get(rrset.name) + if not znode: + znode = z.node_factory() + z.nodes[rrset.name] = znode + zrds = znode.find_rdataset(rrset.rdclass, rrset.rdtype, rrset.covers, True) + zrds.update_ttl(rrset.ttl) + for rd in rrset: + zrds.add(rd) + if z is None: + raise ValueError("empty transfer") + if check_origin: + z.check_origin() + return z diff --git a/venv/Lib/site-packages/dns/zonefile.py b/venv/Lib/site-packages/dns/zonefile.py new file mode 100644 index 00000000..af064e73 --- /dev/null +++ b/venv/Lib/site-packages/dns/zonefile.py @@ -0,0 +1,746 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +"""DNS Zones.""" + +import re +import sys +from typing import Any, Iterable, List, Optional, Set, Tuple, Union + +import dns.exception +import dns.grange +import dns.name +import dns.node +import dns.rdata +import dns.rdataclass +import dns.rdatatype +import dns.rdtypes.ANY.SOA +import dns.rrset +import dns.tokenizer +import dns.transaction +import dns.ttl + + +class UnknownOrigin(dns.exception.DNSException): + """Unknown origin""" + + +class CNAMEAndOtherData(dns.exception.DNSException): + """A node has a CNAME and other data""" + + +def _check_cname_and_other_data(txn, name, rdataset): + rdataset_kind = dns.node.NodeKind.classify_rdataset(rdataset) + node = txn.get_node(name) + if node is None: + # empty nodes are neutral. 
+ return + node_kind = node.classify() + if ( + node_kind == dns.node.NodeKind.CNAME + and rdataset_kind == dns.node.NodeKind.REGULAR + ): + raise CNAMEAndOtherData("rdataset type is not compatible with a CNAME node") + elif ( + node_kind == dns.node.NodeKind.REGULAR + and rdataset_kind == dns.node.NodeKind.CNAME + ): + raise CNAMEAndOtherData( + "CNAME rdataset is not compatible with a regular data node" + ) + # Otherwise at least one of the node and the rdataset is neutral, so + # adding the rdataset is ok + + +SavedStateType = Tuple[ + dns.tokenizer.Tokenizer, + Optional[dns.name.Name], # current_origin + Optional[dns.name.Name], # last_name + Optional[Any], # current_file + int, # last_ttl + bool, # last_ttl_known + int, # default_ttl + bool, +] # default_ttl_known + + +def _upper_dollarize(s): + s = s.upper() + if not s.startswith("$"): + s = "$" + s + return s + + +class Reader: + """Read a DNS zone file into a transaction.""" + + def __init__( + self, + tok: dns.tokenizer.Tokenizer, + rdclass: dns.rdataclass.RdataClass, + txn: dns.transaction.Transaction, + allow_include: bool = False, + allow_directives: Union[bool, Iterable[str]] = True, + force_name: Optional[dns.name.Name] = None, + force_ttl: Optional[int] = None, + force_rdclass: Optional[dns.rdataclass.RdataClass] = None, + force_rdtype: Optional[dns.rdatatype.RdataType] = None, + default_ttl: Optional[int] = None, + ): + self.tok = tok + (self.zone_origin, self.relativize, _) = txn.manager.origin_information() + self.current_origin = self.zone_origin + self.last_ttl = 0 + self.last_ttl_known = False + if force_ttl is not None: + default_ttl = force_ttl + if default_ttl is None: + self.default_ttl = 0 + self.default_ttl_known = False + else: + self.default_ttl = default_ttl + self.default_ttl_known = True + self.last_name = self.current_origin + self.zone_rdclass = rdclass + self.txn = txn + self.saved_state: List[SavedStateType] = [] + self.current_file: Optional[Any] = None + self.allowed_directives: Set[str] + if allow_directives is True: + self.allowed_directives = {"$GENERATE", "$ORIGIN", "$TTL"} + if allow_include: + self.allowed_directives.add("$INCLUDE") + elif allow_directives is False: + # allow_include was ignored in earlier releases if allow_directives was + # False, so we continue that. + self.allowed_directives = set() + else: + # Note that if directives are explicitly specified, then allow_include + # is ignored. + self.allowed_directives = set(_upper_dollarize(d) for d in allow_directives) + self.force_name = force_name + self.force_ttl = force_ttl + self.force_rdclass = force_rdclass + self.force_rdtype = force_rdtype + self.txn.check_put_rdataset(_check_cname_and_other_data) + + def _eat_line(self): + while 1: + token = self.tok.get() + if token.is_eol_or_eof(): + break + + def _get_identifier(self): + token = self.tok.get() + if not token.is_identifier(): + raise dns.exception.SyntaxError + return token + + def _rr_line(self): + """Process one line from a DNS zone file.""" + token = None + # Name + if self.force_name is not None: + name = self.force_name + else: + if self.current_origin is None: + raise UnknownOrigin + token = self.tok.get(want_leading=True) + if not token.is_whitespace(): + self.last_name = self.tok.as_name(token, self.current_origin) + else: + token = self.tok.get() + if token.is_eol_or_eof(): + # treat leading WS followed by EOL/EOF as if they were EOL/EOF. 
+                    return
+                self.tok.unget(token)
+                name = self.last_name
+        if not name.is_subdomain(self.zone_origin):
+            self._eat_line()
+            return
+        if self.relativize:
+            name = name.relativize(self.zone_origin)
+
+        # TTL
+        if self.force_ttl is not None:
+            ttl = self.force_ttl
+            self.last_ttl = ttl
+            self.last_ttl_known = True
+        else:
+            token = self._get_identifier()
+            ttl = None
+            try:
+                ttl = dns.ttl.from_text(token.value)
+                self.last_ttl = ttl
+                self.last_ttl_known = True
+                token = None
+            except dns.ttl.BadTTL:
+                self.tok.unget(token)
+
+        # Class
+        if self.force_rdclass is not None:
+            rdclass = self.force_rdclass
+        else:
+            token = self._get_identifier()
+            try:
+                rdclass = dns.rdataclass.from_text(token.value)
+            except dns.exception.SyntaxError:
+                raise
+            except Exception:
+                rdclass = self.zone_rdclass
+                self.tok.unget(token)
+        if rdclass != self.zone_rdclass:
+            raise dns.exception.SyntaxError("RR class is not zone's class")
+
+        if ttl is None:
+            # support for <class> <ttl> syntax
+            token = self._get_identifier()
+            ttl = None
+            try:
+                ttl = dns.ttl.from_text(token.value)
+                self.last_ttl = ttl
+                self.last_ttl_known = True
+                token = None
+            except dns.ttl.BadTTL:
+                if self.default_ttl_known:
+                    ttl = self.default_ttl
+                elif self.last_ttl_known:
+                    ttl = self.last_ttl
+                self.tok.unget(token)
+
+        # Type
+        if self.force_rdtype is not None:
+            rdtype = self.force_rdtype
+        else:
+            token = self._get_identifier()
+            try:
+                rdtype = dns.rdatatype.from_text(token.value)
+            except Exception:
+                raise dns.exception.SyntaxError("unknown rdatatype '%s'" % token.value)
+
+        try:
+            rd = dns.rdata.from_text(
+                rdclass,
+                rdtype,
+                self.tok,
+                self.current_origin,
+                self.relativize,
+                self.zone_origin,
+            )
+        except dns.exception.SyntaxError:
+            # Catch and reraise.
+            raise
+        except Exception:
+            # All exceptions that occur in the processing of rdata
+            # are treated as syntax errors.  This is not strictly
+            # correct, but it is correct almost all of the time.
+            # We convert them to syntax errors so that we can emit
+            # helpful filename:line info.
+            (ty, va) = sys.exc_info()[:2]
+            raise dns.exception.SyntaxError(
+                "caught exception {}: {}".format(str(ty), str(va))
+            )
+
+        if not self.default_ttl_known and rdtype == dns.rdatatype.SOA:
+            # The pre-RFC2308 and pre-BIND9 behavior inherits the zone default
+            # TTL from the SOA minttl if no $TTL statement is present before the
+            # SOA is parsed.
+            self.default_ttl = rd.minimum
+            self.default_ttl_known = True
+            if ttl is None:
+                # if we didn't have a TTL on the SOA, set it!
+                ttl = rd.minimum
+
+        # TTL check.  We had to wait until now to do this as the SOA RR's
+        # own TTL can be inferred from its minimum.
+        if ttl is None:
+            raise dns.exception.SyntaxError("Missing default TTL value")
+
+        self.txn.add(name, ttl, rd)
+
+    def _parse_modify(self, side: str) -> Tuple[str, str, int, int, str]:
+        # Here we catch everything in '{' '}' in a group so we can replace it
+        # with ''.
+        is_generate1 = re.compile(r"^.*\$({(\+|-?)(\d+),(\d+),(.)}).*$")
+        is_generate2 = re.compile(r"^.*\$({(\+|-?)(\d+)}).*$")
+        is_generate3 = re.compile(r"^.*\$({(\+|-?)(\d+),(\d+)}).*$")
+        # Sometimes there are modifiers in the hostname.  These come after
+        # the dollar sign.  They are in the form: ${offset[,width[,base]]}.
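+        # For example, with a hypothetical lhs of "dhcp-${0,3,d}" the
+        # modifier is offset 0, width 3, base "d", so a counter value of
+        # 32 expands the name to "dhcp-032".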
+ # Make names + g1 = is_generate1.match(side) + if g1: + mod, sign, offset, width, base = g1.groups() + if sign == "": + sign = "+" + g2 = is_generate2.match(side) + if g2: + mod, sign, offset = g2.groups() + if sign == "": + sign = "+" + width = 0 + base = "d" + g3 = is_generate3.match(side) + if g3: + mod, sign, offset, width = g3.groups() + if sign == "": + sign = "+" + base = "d" + + if not (g1 or g2 or g3): + mod = "" + sign = "+" + offset = 0 + width = 0 + base = "d" + + offset = int(offset) + width = int(width) + + if sign not in ["+", "-"]: + raise dns.exception.SyntaxError("invalid offset sign %s" % sign) + if base not in ["d", "o", "x", "X", "n", "N"]: + raise dns.exception.SyntaxError("invalid type %s" % base) + + return mod, sign, offset, width, base + + def _generate_line(self): + # range lhs [ttl] [class] type rhs [ comment ] + """Process one line containing the GENERATE statement from a DNS + zone file.""" + if self.current_origin is None: + raise UnknownOrigin + + token = self.tok.get() + # Range (required) + try: + start, stop, step = dns.grange.from_text(token.value) + token = self.tok.get() + if not token.is_identifier(): + raise dns.exception.SyntaxError + except Exception: + raise dns.exception.SyntaxError + + # lhs (required) + try: + lhs = token.value + token = self.tok.get() + if not token.is_identifier(): + raise dns.exception.SyntaxError + except Exception: + raise dns.exception.SyntaxError + + # TTL + try: + ttl = dns.ttl.from_text(token.value) + self.last_ttl = ttl + self.last_ttl_known = True + token = self.tok.get() + if not token.is_identifier(): + raise dns.exception.SyntaxError + except dns.ttl.BadTTL: + if not (self.last_ttl_known or self.default_ttl_known): + raise dns.exception.SyntaxError("Missing default TTL value") + if self.default_ttl_known: + ttl = self.default_ttl + elif self.last_ttl_known: + ttl = self.last_ttl + # Class + try: + rdclass = dns.rdataclass.from_text(token.value) + token = self.tok.get() + if not token.is_identifier(): + raise dns.exception.SyntaxError + except dns.exception.SyntaxError: + raise dns.exception.SyntaxError + except Exception: + rdclass = self.zone_rdclass + if rdclass != self.zone_rdclass: + raise dns.exception.SyntaxError("RR class is not zone's class") + # Type + try: + rdtype = dns.rdatatype.from_text(token.value) + token = self.tok.get() + if not token.is_identifier(): + raise dns.exception.SyntaxError + except Exception: + raise dns.exception.SyntaxError("unknown rdatatype '%s'" % token.value) + + # rhs (required) + rhs = token.value + + def _calculate_index(counter: int, offset_sign: str, offset: int) -> int: + """Calculate the index from the counter and offset.""" + if offset_sign == "-": + offset *= -1 + return counter + offset + + def _format_index(index: int, base: str, width: int) -> str: + """Format the index with the given base, and zero-fill it + to the given width.""" + if base in ["d", "o", "x", "X"]: + return format(index, base).zfill(width) + + # base can only be n or N here + hexa = _format_index(index, "x", width) + nibbles = ".".join(hexa[::-1])[:width] + if base == "N": + nibbles = nibbles.upper() + return nibbles + + lmod, lsign, loffset, lwidth, lbase = self._parse_modify(lhs) + rmod, rsign, roffset, rwidth, rbase = self._parse_modify(rhs) + for i in range(start, stop + 1, step): + # +1 because bind is inclusive and python is exclusive + + lindex = _calculate_index(i, lsign, loffset) + rindex = _calculate_index(i, rsign, roffset) + + lzfindex = _format_index(lindex, lbase, lwidth) + rzfindex 
= _format_index(rindex, rbase, rwidth) + + name = lhs.replace("$%s" % (lmod), lzfindex) + rdata = rhs.replace("$%s" % (rmod), rzfindex) + + self.last_name = dns.name.from_text( + name, self.current_origin, self.tok.idna_codec + ) + name = self.last_name + if not name.is_subdomain(self.zone_origin): + self._eat_line() + return + if self.relativize: + name = name.relativize(self.zone_origin) + + try: + rd = dns.rdata.from_text( + rdclass, + rdtype, + rdata, + self.current_origin, + self.relativize, + self.zone_origin, + ) + except dns.exception.SyntaxError: + # Catch and reraise. + raise + except Exception: + # All exceptions that occur in the processing of rdata + # are treated as syntax errors. This is not strictly + # correct, but it is correct almost all of the time. + # We convert them to syntax errors so that we can emit + # helpful filename:line info. + (ty, va) = sys.exc_info()[:2] + raise dns.exception.SyntaxError( + "caught exception %s: %s" % (str(ty), str(va)) + ) + + self.txn.add(name, ttl, rd) + + def read(self) -> None: + """Read a DNS zone file and build a zone object. + + @raises dns.zone.NoSOA: No SOA RR was found at the zone origin + @raises dns.zone.NoNS: No NS RRset was found at the zone origin + """ + + try: + while 1: + token = self.tok.get(True, True) + if token.is_eof(): + if self.current_file is not None: + self.current_file.close() + if len(self.saved_state) > 0: + ( + self.tok, + self.current_origin, + self.last_name, + self.current_file, + self.last_ttl, + self.last_ttl_known, + self.default_ttl, + self.default_ttl_known, + ) = self.saved_state.pop(-1) + continue + break + elif token.is_eol(): + continue + elif token.is_comment(): + self.tok.get_eol() + continue + elif token.value[0] == "$" and len(self.allowed_directives) > 0: + # Note that we only run directive processing code if at least + # one directive is allowed in order to be backwards compatible + c = token.value.upper() + if c not in self.allowed_directives: + raise dns.exception.SyntaxError( + f"zone file directive '{c}' is not allowed" + ) + if c == "$TTL": + token = self.tok.get() + if not token.is_identifier(): + raise dns.exception.SyntaxError("bad $TTL") + self.default_ttl = dns.ttl.from_text(token.value) + self.default_ttl_known = True + self.tok.get_eol() + elif c == "$ORIGIN": + self.current_origin = self.tok.get_name() + self.tok.get_eol() + if self.zone_origin is None: + self.zone_origin = self.current_origin + self.txn._set_origin(self.current_origin) + elif c == "$INCLUDE": + token = self.tok.get() + filename = token.value + token = self.tok.get() + new_origin: Optional[dns.name.Name] + if token.is_identifier(): + new_origin = dns.name.from_text( + token.value, self.current_origin, self.tok.idna_codec + ) + self.tok.get_eol() + elif not token.is_eol_or_eof(): + raise dns.exception.SyntaxError("bad origin in $INCLUDE") + else: + new_origin = self.current_origin + self.saved_state.append( + ( + self.tok, + self.current_origin, + self.last_name, + self.current_file, + self.last_ttl, + self.last_ttl_known, + self.default_ttl, + self.default_ttl_known, + ) + ) + self.current_file = open(filename, "r") + self.tok = dns.tokenizer.Tokenizer(self.current_file, filename) + self.current_origin = new_origin + elif c == "$GENERATE": + self._generate_line() + else: + raise dns.exception.SyntaxError( + f"Unknown zone file directive '{c}'" + ) + continue + self.tok.unget(token) + self._rr_line() + except dns.exception.SyntaxError as detail: + (filename, line_number) = self.tok.where() + if detail is None: 
+ detail = "syntax error" + ex = dns.exception.SyntaxError( + "%s:%d: %s" % (filename, line_number, detail) + ) + tb = sys.exc_info()[2] + raise ex.with_traceback(tb) from None + + +class RRsetsReaderTransaction(dns.transaction.Transaction): + def __init__(self, manager, replacement, read_only): + assert not read_only + super().__init__(manager, replacement, read_only) + self.rdatasets = {} + + def _get_rdataset(self, name, rdtype, covers): + return self.rdatasets.get((name, rdtype, covers)) + + def _get_node(self, name): + rdatasets = [] + for (rdataset_name, _, _), rdataset in self.rdatasets.items(): + if name == rdataset_name: + rdatasets.append(rdataset) + if len(rdatasets) == 0: + return None + node = dns.node.Node() + node.rdatasets = rdatasets + return node + + def _put_rdataset(self, name, rdataset): + self.rdatasets[(name, rdataset.rdtype, rdataset.covers)] = rdataset + + def _delete_name(self, name): + # First remove any changes involving the name + remove = [] + for key in self.rdatasets: + if key[0] == name: + remove.append(key) + if len(remove) > 0: + for key in remove: + del self.rdatasets[key] + + def _delete_rdataset(self, name, rdtype, covers): + try: + del self.rdatasets[(name, rdtype, covers)] + except KeyError: + pass + + def _name_exists(self, name): + for n, _, _ in self.rdatasets: + if n == name: + return True + return False + + def _changed(self): + return len(self.rdatasets) > 0 + + def _end_transaction(self, commit): + if commit and self._changed(): + rrsets = [] + for (name, _, _), rdataset in self.rdatasets.items(): + rrset = dns.rrset.RRset( + name, rdataset.rdclass, rdataset.rdtype, rdataset.covers + ) + rrset.update(rdataset) + rrsets.append(rrset) + self.manager.set_rrsets(rrsets) + + def _set_origin(self, origin): + pass + + def _iterate_rdatasets(self): + raise NotImplementedError # pragma: no cover + + def _iterate_names(self): + raise NotImplementedError # pragma: no cover + + +class RRSetsReaderManager(dns.transaction.TransactionManager): + def __init__( + self, origin=dns.name.root, relativize=False, rdclass=dns.rdataclass.IN + ): + self.origin = origin + self.relativize = relativize + self.rdclass = rdclass + self.rrsets = [] + + def reader(self): # pragma: no cover + raise NotImplementedError + + def writer(self, replacement=False): + assert replacement is True + return RRsetsReaderTransaction(self, True, False) + + def get_class(self): + return self.rdclass + + def origin_information(self): + if self.relativize: + effective = dns.name.empty + else: + effective = self.origin + return (self.origin, self.relativize, effective) + + def set_rrsets(self, rrsets): + self.rrsets = rrsets + + +def read_rrsets( + text: Any, + name: Optional[Union[dns.name.Name, str]] = None, + ttl: Optional[int] = None, + rdclass: Optional[Union[dns.rdataclass.RdataClass, str]] = dns.rdataclass.IN, + default_rdclass: Union[dns.rdataclass.RdataClass, str] = dns.rdataclass.IN, + rdtype: Optional[Union[dns.rdatatype.RdataType, str]] = None, + default_ttl: Optional[Union[int, str]] = None, + idna_codec: Optional[dns.name.IDNACodec] = None, + origin: Optional[Union[dns.name.Name, str]] = dns.name.root, + relativize: bool = False, +) -> List[dns.rrset.RRset]: + """Read one or more rrsets from the specified text, possibly subject + to restrictions. + + *text*, a file object or a string, is the input to process. + + *name*, a string, ``dns.name.Name``, or ``None``, is the owner name of + the rrset. 
If not ``None``, then the owner name is "forced", and the
+    input must not specify an owner name. If ``None``, then any owner names
+    are allowed and must be present in the input.
+
+    *ttl*, an ``int``, string, or None. If not ``None``, then the TTL is
+    forced to be the specified value and the input must not specify a TTL.
+    If ``None``, then a TTL may be specified in the input. If it is not
+    specified, then the *default_ttl* will be used.
+
+    *rdclass*, a ``dns.rdataclass.RdataClass``, string, or ``None``. If
+    not ``None``, then the class is forced to the specified value, and the
+    input must not specify a class. If ``None``, then the input may specify
+    a class that matches *default_rdclass*. Note that it is not possible to
+    return rrsets with differing classes; specifying ``None`` for the class
+    simply allows the user to optionally type a class, which may be convenient
+    when cutting and pasting.
+
+    *default_rdclass*, a ``dns.rdataclass.RdataClass`` or string. The class
+    of the returned rrsets.
+
+    *rdtype*, a ``dns.rdatatype.RdataType``, string, or ``None``. If not
+    ``None``, then the type is forced to the specified value, and the
+    input must not specify a type. If ``None``, then a type must be present
+    for each RR.
+
+    *default_ttl*, an ``int``, string, or ``None``. If not ``None``, then if
+    the TTL is not forced and is not specified, then this value will be used.
+    If ``None``, then an error will occur if the TTL is neither forced nor
+    specified in the input.
+
+    *idna_codec*, a ``dns.name.IDNACodec``, specifies the IDNA
+    encoder/decoder. If ``None``, the default IDNA 2003 encoder/decoder
+    is used. Note that codecs only apply to the owner name; dnspython does
+    not do IDNA for names in rdata, as there is no IDNA zonefile format.
+
+    *origin*, a string, ``dns.name.Name``, or ``None``, is the origin for any
+    relative names in the input, and also the origin to relativize to if
+    *relativize* is ``True``.
+
+    *relativize*, a bool. If ``True``, names are relativized to the *origin*;
+    if ``False``, then any relative names in the input are made absolute by
+    appending the *origin*.
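+
+    For example, a minimal sketch (the owner name, TTL, and address below are
+    purely illustrative; ``rdclass=None`` is passed so that the input text may
+    spell out the ``IN`` class)::
+
+        import dns.zonefile
+
+        rrsets = dns.zonefile.read_rrsets(
+            "www 3600 IN A 10.0.0.1",
+            rdclass=None,
+            origin="example.",
+        )
+        # With relativize=False (the default), this should yield one RRset
+        # whose owner name is the absolute name www.example.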
+ """ + if isinstance(origin, str): + origin = dns.name.from_text(origin, dns.name.root, idna_codec) + if isinstance(name, str): + name = dns.name.from_text(name, origin, idna_codec) + if isinstance(ttl, str): + ttl = dns.ttl.from_text(ttl) + if isinstance(default_ttl, str): + default_ttl = dns.ttl.from_text(default_ttl) + if rdclass is not None: + rdclass = dns.rdataclass.RdataClass.make(rdclass) + else: + rdclass = None + default_rdclass = dns.rdataclass.RdataClass.make(default_rdclass) + if rdtype is not None: + rdtype = dns.rdatatype.RdataType.make(rdtype) + else: + rdtype = None + manager = RRSetsReaderManager(origin, relativize, default_rdclass) + with manager.writer(True) as txn: + tok = dns.tokenizer.Tokenizer(text, "", idna_codec=idna_codec) + reader = Reader( + tok, + default_rdclass, + txn, + allow_directives=False, + force_name=name, + force_ttl=ttl, + force_rdclass=rdclass, + force_rdtype=rdtype, + default_ttl=default_ttl, + ) + reader.read() + return manager.rrsets diff --git a/venv/Lib/site-packages/dns/zonetypes.py b/venv/Lib/site-packages/dns/zonetypes.py new file mode 100644 index 00000000..195ee2ec --- /dev/null +++ b/venv/Lib/site-packages/dns/zonetypes.py @@ -0,0 +1,37 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +"""Common zone-related types.""" + +# This is a separate file to avoid import circularity between dns.zone and +# the implementation of the ZONEMD type. + +import hashlib + +import dns.enum + + +class DigestScheme(dns.enum.IntEnum): + """ZONEMD Scheme""" + + SIMPLE = 1 + + @classmethod + def _maximum(cls): + return 255 + + +class DigestHashAlgorithm(dns.enum.IntEnum): + """ZONEMD Hash Algorithm""" + + SHA384 = 1 + SHA512 = 2 + + @classmethod + def _maximum(cls): + return 255 + + +_digest_hashers = { + DigestHashAlgorithm.SHA384: hashlib.sha384, + DigestHashAlgorithm.SHA512: hashlib.sha512, +} diff --git a/venv/Lib/site-packages/dnspython-2.6.1.dist-info/INSTALLER b/venv/Lib/site-packages/dnspython-2.6.1.dist-info/INSTALLER new file mode 100644 index 00000000..a1b589e3 --- /dev/null +++ b/venv/Lib/site-packages/dnspython-2.6.1.dist-info/INSTALLER @@ -0,0 +1 @@ +pip diff --git a/venv/Lib/site-packages/dnspython-2.6.1.dist-info/METADATA b/venv/Lib/site-packages/dnspython-2.6.1.dist-info/METADATA new file mode 100644 index 00000000..129184e3 --- /dev/null +++ b/venv/Lib/site-packages/dnspython-2.6.1.dist-info/METADATA @@ -0,0 +1,147 @@ +Metadata-Version: 2.1 +Name: dnspython +Version: 2.6.1 +Summary: DNS toolkit +Project-URL: homepage, https://www.dnspython.org +Project-URL: repository, https://github.com/rthalley/dnspython.git +Project-URL: documentation, https://dnspython.readthedocs.io/en/stable/ +Project-URL: issues, https://github.com/rthalley/dnspython/issues +Author-email: Bob Halley +License-Expression: ISC +License-File: LICENSE +Classifier: Development Status :: 5 - Production/Stable +Classifier: Intended Audience :: Developers +Classifier: Intended Audience :: System Administrators +Classifier: License :: OSI Approved :: ISC License (ISCL) +Classifier: Operating System :: Microsoft :: Windows +Classifier: Operating System :: POSIX +Classifier: Programming Language :: Python +Classifier: Programming Language :: Python :: 3 +Classifier: Programming Language :: Python :: 3.8 +Classifier: Programming Language :: Python :: 3.9 +Classifier: Programming Language :: Python :: 3.10 +Classifier: Programming Language :: Python :: 3.11 +Classifier: Programming Language :: Python :: 3.12 +Classifier: Topic :: Internet :: Name 
Service (DNS)
+Classifier: Topic :: Software Development :: Libraries :: Python Modules
+Requires-Python: >=3.8
+Provides-Extra: dev
+Requires-Dist: black>=23.1.0; extra == 'dev'
+Requires-Dist: coverage>=7.0; extra == 'dev'
+Requires-Dist: flake8>=7; extra == 'dev'
+Requires-Dist: mypy>=1.8; extra == 'dev'
+Requires-Dist: pylint>=3; extra == 'dev'
+Requires-Dist: pytest-cov>=4.1.0; extra == 'dev'
+Requires-Dist: pytest>=7.4; extra == 'dev'
+Requires-Dist: sphinx>=7.2.0; extra == 'dev'
+Requires-Dist: twine>=4.0.0; extra == 'dev'
+Requires-Dist: wheel>=0.42.0; extra == 'dev'
+Provides-Extra: dnssec
+Requires-Dist: cryptography>=41; extra == 'dnssec'
+Provides-Extra: doh
+Requires-Dist: h2>=4.1.0; extra == 'doh'
+Requires-Dist: httpcore>=1.0.0; extra == 'doh'
+Requires-Dist: httpx>=0.26.0; extra == 'doh'
+Provides-Extra: doq
+Requires-Dist: aioquic>=0.9.25; extra == 'doq'
+Provides-Extra: idna
+Requires-Dist: idna>=3.6; extra == 'idna'
+Provides-Extra: trio
+Requires-Dist: trio>=0.23; extra == 'trio'
+Provides-Extra: wmi
+Requires-Dist: wmi>=1.5.1; extra == 'wmi'
+Description-Content-Type: text/markdown
+
+# dnspython
+
+[![Build Status](https://github.com/rthalley/dnspython/actions/workflows/python-package.yml/badge.svg)](https://github.com/rthalley/dnspython/actions/)
+[![Documentation Status](https://readthedocs.org/projects/dnspython/badge/?version=latest)](https://dnspython.readthedocs.io/en/latest/?badge=latest)
+[![PyPI version](https://badge.fury.io/py/dnspython.svg)](https://badge.fury.io/py/dnspython)
+[![License: ISC](https://img.shields.io/badge/License-ISC-brightgreen.svg)](https://opensource.org/licenses/ISC)
+[![Coverage](https://codecov.io/github/rthalley/dnspython/coverage.svg?branch=master)](https://codecov.io/github/rthalley/dnspython)
+[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
+
+## INTRODUCTION
+
+dnspython is a DNS toolkit for Python. It supports almost all record types. It
+can be used for queries, zone transfers, and dynamic updates. It supports TSIG
+authenticated messages and EDNS0.
+
+dnspython provides both high and low level access to DNS. The high level classes
+perform queries for data of a given name, type, and class, and return an answer
+set. The low level classes allow direct manipulation of DNS zones, messages,
+names, and records.
+
+To see a few of the ways dnspython can be used, look in the `examples/`
+directory.
+
+dnspython is a utility for working with DNS; it does not use `/etc/hosts`. For
+simple forward DNS lookups, it's better to use `socket.getaddrinfo()` or
+`socket.gethostbyname()`.
+
+dnspython originated at Nominum, where it was developed
+to facilitate the testing of DNS software.
+
+## ABOUT THIS RELEASE
+
+This is dnspython 2.6.1.
+Please read
+[What's New](https://dnspython.readthedocs.io/en/stable/whatsnew.html) for
+information about the changes in this release.
+
+## INSTALLATION
+
+* Many distributions have dnspython packaged for you, so you should
+  check there first.
+* To use a wheel downloaded from PyPI, run:
+
+    pip install dnspython
+
+* To install from the source code, go into the top-level of the source code
+  and run:
+
+```
+    pip install --upgrade pip build
+    python -m build
+    pip install dist/*.whl
+```
+
+* To install the latest from the master branch, run `pip install git+https://github.com/rthalley/dnspython.git`
+
+Dnspython's default installation does not depend on any modules other than
+those in the Python standard library.
To use some features, additional modules
+must be installed. For convenience, pip options are defined for the
+requirements.
+
+If you want to use DNS-over-HTTPS, run
+`pip install dnspython[doh]`.
+
+If you want to use DNSSEC functionality, run
+`pip install dnspython[dnssec]`.
+
+If you want to use internationalized domain names (IDNA)
+functionality, run
+`pip install dnspython[idna]`.
+
+If you want to use the Trio asynchronous I/O package, run
+`pip install dnspython[trio]`.
+
+If you want to use WMI on Windows to determine the active DNS settings
+instead of the default registry scanning method, run
+`pip install dnspython[wmi]`.
+
+If you want to try the experimental DNS-over-QUIC code, run
+`pip install dnspython[doq]`.
+
+Note that you can install any combination of the above, e.g.:
+`pip install dnspython[doh,dnssec,idna]`
+
+### Notices
+
+Python 2.x support ended with the release of 1.16.0. Dnspython 2.0.0 through
+2.2.x support Python 3.6 and later. For dnspython 2.3.x, the minimum
+supported Python version is 3.7, and for 2.4.x the minimum supported version is 3.8.
+We plan to align future support with the lifetime of the Python 3 versions.
+
+Documentation has moved to
+[dnspython.readthedocs.io](https://dnspython.readthedocs.io).
diff --git a/venv/Lib/site-packages/dnspython-2.6.1.dist-info/RECORD b/venv/Lib/site-packages/dnspython-2.6.1.dist-info/RECORD
new file mode 100644
index 00000000..c6529618
--- /dev/null
+++ b/venv/Lib/site-packages/dnspython-2.6.1.dist-info/RECORD
@@ -0,0 +1,290 @@
+dns/__init__.py,sha256=YJZtDG14Idw5ui3h1nWooSwPM9gsxQgB8M0GBZ3aly0,1663
+dns/__pycache__/__init__.cpython-312.pyc,,
+dns/__pycache__/_asyncbackend.cpython-312.pyc,,
+dns/__pycache__/_asyncio_backend.cpython-312.pyc,,
+dns/__pycache__/_ddr.cpython-312.pyc,,
+dns/__pycache__/_features.cpython-312.pyc,,
+dns/__pycache__/_immutable_ctx.cpython-312.pyc,,
+dns/__pycache__/_trio_backend.cpython-312.pyc,,
+dns/__pycache__/asyncbackend.cpython-312.pyc,,
+dns/__pycache__/asyncquery.cpython-312.pyc,,
+dns/__pycache__/asyncresolver.cpython-312.pyc,,
+dns/__pycache__/dnssec.cpython-312.pyc,,
+dns/__pycache__/dnssectypes.cpython-312.pyc,,
+dns/__pycache__/e164.cpython-312.pyc,,
+dns/__pycache__/edns.cpython-312.pyc,,
+dns/__pycache__/entropy.cpython-312.pyc,,
+dns/__pycache__/enum.cpython-312.pyc,,
+dns/__pycache__/exception.cpython-312.pyc,,
+dns/__pycache__/flags.cpython-312.pyc,,
+dns/__pycache__/grange.cpython-312.pyc,,
+dns/__pycache__/immutable.cpython-312.pyc,,
+dns/__pycache__/inet.cpython-312.pyc,,
+dns/__pycache__/ipv4.cpython-312.pyc,,
+dns/__pycache__/ipv6.cpython-312.pyc,,
+dns/__pycache__/message.cpython-312.pyc,,
+dns/__pycache__/name.cpython-312.pyc,,
+dns/__pycache__/namedict.cpython-312.pyc,,
+dns/__pycache__/nameserver.cpython-312.pyc,,
+dns/__pycache__/node.cpython-312.pyc,,
+dns/__pycache__/opcode.cpython-312.pyc,,
+dns/__pycache__/query.cpython-312.pyc,,
+dns/__pycache__/rcode.cpython-312.pyc,,
+dns/__pycache__/rdata.cpython-312.pyc,,
+dns/__pycache__/rdataclass.cpython-312.pyc,,
+dns/__pycache__/rdataset.cpython-312.pyc,,
+dns/__pycache__/rdatatype.cpython-312.pyc,,
+dns/__pycache__/renderer.cpython-312.pyc,,
+dns/__pycache__/resolver.cpython-312.pyc,,
+dns/__pycache__/reversename.cpython-312.pyc,,
+dns/__pycache__/rrset.cpython-312.pyc,,
+dns/__pycache__/serial.cpython-312.pyc,,
+dns/__pycache__/set.cpython-312.pyc,,
+dns/__pycache__/tokenizer.cpython-312.pyc,,
+dns/__pycache__/transaction.cpython-312.pyc,,
+dns/__pycache__/tsig.cpython-312.pyc,,
+dns/__pycache__/tsigkeyring.cpython-312.pyc,, +dns/__pycache__/ttl.cpython-312.pyc,, +dns/__pycache__/update.cpython-312.pyc,, +dns/__pycache__/version.cpython-312.pyc,, +dns/__pycache__/versioned.cpython-312.pyc,, +dns/__pycache__/win32util.cpython-312.pyc,, +dns/__pycache__/wire.cpython-312.pyc,, +dns/__pycache__/xfr.cpython-312.pyc,, +dns/__pycache__/zone.cpython-312.pyc,, +dns/__pycache__/zonefile.cpython-312.pyc,, +dns/__pycache__/zonetypes.cpython-312.pyc,, +dns/_asyncbackend.py,sha256=Ny0kGesm9wbLBnt-0u-tANOKsxcYt2jbMuRoRz_JZUA,2360 +dns/_asyncio_backend.py,sha256=q58xPdqAOLmOYOux8GFRyiH-fSZ7jiwZF-Jg2vHjYSU,8971 +dns/_ddr.py,sha256=rHXKC8kncCTT9N4KBh1flicl79nyDjQ-DDvq30MJ3B8,5247 +dns/_features.py,sha256=MUeyfM_nMYAYkasGfbY7I_15JmwftaZjseuP1L43MT0,2384 +dns/_immutable_ctx.py,sha256=gtoCLMmdHXI23zt5lRSIS3A4Ca3jZJngebdoFFOtiwU,2459 +dns/_trio_backend.py,sha256=Vab_wR2CxDgy2Jz3iM_64FZmP_kMUN9j8LS4eNl-Oig,8269 +dns/asyncbackend.py,sha256=82fXTFls_m7F_ekQbgUGOkoBbs4BI-GBLDZAWNGUvJ0,2796 +dns/asyncquery.py,sha256=Q7u04mbbqCoe9VxsqRcsWTPxgH2Cx49eWWgi2wUyZHU,26850 +dns/asyncresolver.py,sha256=GD86dCyW9YGKs6SggWXwBKEXifW7Qdx4cEAGFKY6fA4,17852 +dns/dnssec.py,sha256=xyYW1cf6eeFNXROrEs1pyY4TgC8jlmUiiootaPbVjjY,40693 +dns/dnssecalgs/__init__.py,sha256=DcnGIbL6m-USPSiLWHSw511awB7dytlljvCOOmzchS0,4279 +dns/dnssecalgs/__pycache__/__init__.cpython-312.pyc,, +dns/dnssecalgs/__pycache__/base.cpython-312.pyc,, +dns/dnssecalgs/__pycache__/cryptography.cpython-312.pyc,, +dns/dnssecalgs/__pycache__/dsa.cpython-312.pyc,, +dns/dnssecalgs/__pycache__/ecdsa.cpython-312.pyc,, +dns/dnssecalgs/__pycache__/eddsa.cpython-312.pyc,, +dns/dnssecalgs/__pycache__/rsa.cpython-312.pyc,, +dns/dnssecalgs/base.py,sha256=hsFHFr_eCYeDcI0eU6_WiLlOYL0GR4QJ__sXoMrIAfE,2446 +dns/dnssecalgs/cryptography.py,sha256=3uqMfRm-zCkJPOrxUqlu9CmdxIMy71dVor9eAHi0wZM,2425 +dns/dnssecalgs/dsa.py,sha256=hklh_HkT_ZffQBHQ7t6pKUStTH4x5nXlz8R9RUP72aY,3497 +dns/dnssecalgs/ecdsa.py,sha256=GWrJgEXAK08MCdbLk7LQcD2ajKqW_dbONWXh3wieLzw,3016 +dns/dnssecalgs/eddsa.py,sha256=9lQQZ92f2PiIhhylieInO-19aSTDQiyoY8X2kTkGlcs,1914 +dns/dnssecalgs/rsa.py,sha256=jWkhWKByylIo7Y9gAiiO8t8bowF8IZ0siVjgZpdhLSE,3555 +dns/dnssectypes.py,sha256=CyeuGTS_rM3zXr8wD9qMT9jkzvVfTY2JWckUcogG83E,1799 +dns/e164.py,sha256=EsK8cnOtOx7kQ0DmSwibcwkzp6efMWjbRiTyHZO8Q-M,3978 +dns/edns.py,sha256=d8QWhmRd6qlaGfO-tY6iDQZt9XUiyfJfKdjoGjvwOU4,15263 +dns/entropy.py,sha256=qkG8hXDLzrJS6R5My26iA59c0RhPwJNzuOhOCAZU5Bw,4242 +dns/enum.py,sha256=EepaunPKixTSrascy7iAe9UQEXXxP_MB5Gx4jUpHIhg,3691 +dns/exception.py,sha256=FphWy-JLRG06UUUq2VmUGwdPA1xWja_8YfrcffRFlQs,5957 +dns/flags.py,sha256=cQ3kTFyvcKiWHAxI5AwchNqxVOrsIrgJ6brgrH42Wq8,2750 +dns/grange.py,sha256=HA623Mv2mZDmOK_BZNDDakT0L6EHsMQU9lFFkE8dKr0,2148 +dns/immutable.py,sha256=InrtpKvPxl-74oYbzsyneZwAuX78hUqeG22f2aniZbk,2017 +dns/inet.py,sha256=j6jQs3K_ehVhDv-i4jwCKePr5HpEiSzvOXQ4uhgn1sU,5772 +dns/ipv4.py,sha256=qEUXtlqWDH_blicj6VMvyQhfX7-BF0gB_lWJliV-2FI,2552 +dns/ipv6.py,sha256=EyiF5T8t2oww9-W4ZA5Zk2GGnOjTy_uZ50CI7maed_8,6600 +dns/message.py,sha256=DyUtBHArPX-WGj_AtcngyIXZNpLppLZX-6q9TryL_wI,65993 +dns/name.py,sha256=eaR1wVR0rErnD3EPANquCuyqpbxy5VfFVhMenWlBPDE,42672 +dns/namedict.py,sha256=hJRYpKeQv6Bd2LaUOPV0L_a0eXEIuqgggPXaH4c3Tow,4000 +dns/nameserver.py,sha256=VkYRnX5wQ7RihAD6kYqidI_hb9NgKJSAE0GaYulNpHY,9909 +dns/node.py,sha256=NGZa0AUMq-CNledJ6wn1Rx6TFYc703cH2OraLysoNWM,12663 +dns/opcode.py,sha256=I6JyuFUL0msja_BYm6bzXHfbbfqUod_69Ss4xcv8xWQ,2730 +dns/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 
+dns/query.py,sha256=vB8C5u6HyjPWrEx9kUdTSg3kxrOoWbPGu7brC0eetIM,54832 +dns/quic/__init__.py,sha256=F6BybmRKnMGc4W8nX7K98PeyXiSwy1FHb_bJeA2lQSw,2202 +dns/quic/__pycache__/__init__.cpython-312.pyc,, +dns/quic/__pycache__/_asyncio.cpython-312.pyc,, +dns/quic/__pycache__/_common.cpython-312.pyc,, +dns/quic/__pycache__/_sync.cpython-312.pyc,, +dns/quic/__pycache__/_trio.cpython-312.pyc,, +dns/quic/_asyncio.py,sha256=vv4RR3Ol0Y1ZOj7rPAzXxy1UcWjPvhTGQvVkMidPs-o,8159 +dns/quic/_common.py,sha256=06TfauL2VciPYSfrL4gif1eR1rm-TRkQhS2Puuk5URU,7282 +dns/quic/_sync.py,sha256=kE0PRavzd27GPQ9UgYApXZ6SGSW2LwCt8k6XWUvrbVE,8133 +dns/quic/_trio.py,sha256=9zCCBtDs6GAtY_b8ck-A17QMiLZ0njjhVtfFT5qMP7s,7670 +dns/rcode.py,sha256=N6JjrIQjCdJy0boKIp8Hcky5tm__LSDscpDz3rE_sgU,4156 +dns/rdata.py,sha256=9cXM9Y9MK2hy9w5mYqmP-r7_aKjHosigfNn_SfqfGGw,29456 +dns/rdataclass.py,sha256=TK4W4ywB1L_X7EZqk2Gmwnu7vdQpolQF5DtQWyNk5xo,2984 +dns/rdataset.py,sha256=96gTaEIcYEL348VKtTOMAazXBVNtk7m0Xez0mF1eg4I,16756 +dns/rdatatype.py,sha256=gIdYZ0iHRlgiTEO-ftobUANmaAmjTnNc4JljMaP1OnQ,7339 +dns/rdtypes/ANY/AFSDB.py,sha256=k75wMwreF1DAfDymu4lHh16BUx7ulVP3PLeQBZnkurY,1661 +dns/rdtypes/ANY/AMTRELAY.py,sha256=19jfS61mT1CQT-8vf67ZylhDS9JVRVp4WCbFE-7l0jM,3381 +dns/rdtypes/ANY/AVC.py,sha256=SpsXYzlBirRWN0mGnQe0MdN6H8fvlgXPJX5PjOHnEak,1024 +dns/rdtypes/ANY/CAA.py,sha256=AHh59Is-4WiVWd26yovnPM3hXqKS-yx7IWfXSS0NZhE,2511 +dns/rdtypes/ANY/CDNSKEY.py,sha256=bJAdrBMsFHIJz8TF1AxZoNbdxVWBCRTG-bR_uR_r_G4,1225 +dns/rdtypes/ANY/CDS.py,sha256=Y9nIRUCAabztVLbxm2SXAdYapFemCOUuGh5JqroCDUs,1163 +dns/rdtypes/ANY/CERT.py,sha256=2Cu2LQM6-K4darqhHv1EM_blmpYpnrBIIX1GnL_rxKE,3533 +dns/rdtypes/ANY/CNAME.py,sha256=IHGGq2BDpeKUahTr1pvyBQgm0NGBI_vQ3Vs5mKTXO4w,1206 +dns/rdtypes/ANY/CSYNC.py,sha256=KkZ_rG6PfeL14il97nmJGWWmUGGS5o9nd2EqbJqOuYo,2439 +dns/rdtypes/ANY/DLV.py,sha256=J-pOrw5xXsDoaB9G0r6znlYXJtqtcqhsl1OXs6CPRU4,986 +dns/rdtypes/ANY/DNAME.py,sha256=yqXRtx4dAWwB4YCCv-qW6uaxeGhg2LPQ2uyKwWaMdXs,1150 +dns/rdtypes/ANY/DNSKEY.py,sha256=MD8HUVH5XXeAGOnFWg5aVz_w-2tXYwCeVXmzExhiIeQ,1223 +dns/rdtypes/ANY/DS.py,sha256=_gf8vk1O_uY8QXFjsfUw-bny-fm6e-QpCk3PT0JCyoM,995 +dns/rdtypes/ANY/EUI48.py,sha256=x0BkK0sY_tgzuCwfDYpw6tyuChHjjtbRpAgYhO0Y44o,1151 +dns/rdtypes/ANY/EUI64.py,sha256=1jCff2-SXHJLDnNDnMW8Cd_o-ok0P3x6zKy_bcCU5h4,1161 +dns/rdtypes/ANY/GPOS.py,sha256=pM3i6Tn4qwHWOGOuIuW9FENPlSXT_R4xsNJeGrrABc8,4433 +dns/rdtypes/ANY/HINFO.py,sha256=vYGCHGZmYOhtmxHlvPqrK7m4pBg3MSY5herBsKJTbKQ,2249 +dns/rdtypes/ANY/HIP.py,sha256=Ucrnndu3xDyHFB93AVUA3xW-r61GR50kpRHLyLacvZY,3228 +dns/rdtypes/ANY/ISDN.py,sha256=uymYB-ayZSBob6jQgXe4EefNB8-JMLW6VfxXn7ncwPg,2713 +dns/rdtypes/ANY/L32.py,sha256=TMz2kdGCd0siiQZyiocVDCSnvkOdjhUuYRFyf8o622M,1286 +dns/rdtypes/ANY/L64.py,sha256=sb2BjuPA0PQt67nEyT9rBt759C9e6lH71d3EJHGGnww,1592 +dns/rdtypes/ANY/LOC.py,sha256=hLkzgCxqEhg6fn5Uf-DJigKEIE6oavQ8rLpajp3HDLs,12024 +dns/rdtypes/ANY/LP.py,sha256=wTsKIjtK6vh66qZRLSsiE0k54GO8ieVBGZH8dzVvFnE,1338 +dns/rdtypes/ANY/MX.py,sha256=qQk83idY0-SbRMDmB15JOpJi7cSyiheF-ALUD0Ev19E,995 +dns/rdtypes/ANY/NID.py,sha256=N7Xx4kXf3yVAocTlCXQeJ3BtiQNPFPQVdL1iMuyl5W4,1544 +dns/rdtypes/ANY/NINFO.py,sha256=bdL_-6Bejb2EH-xwR1rfSr_9E3SDXLTAnov7x2924FI,1041 +dns/rdtypes/ANY/NS.py,sha256=ThfaPalUlhbyZyNyvBM3k-7onl3eJKq5wCORrOGtkMM,995 +dns/rdtypes/ANY/NSEC.py,sha256=6uRn1SxNuLRNumeoc76BkpECF8ztuqyaYviLjFe7FkQ,2475 +dns/rdtypes/ANY/NSEC3.py,sha256=696h-Zz30bmcT0n1rqoEtS5wqE6jIgsVGzaw5TfdGJo,4331 +dns/rdtypes/ANY/NSEC3PARAM.py,sha256=08p6NWS4DiLav1wOuPbxUxB9MtY2IPjfOMCtJwzzMuA,2635 
+dns/rdtypes/ANY/OPENPGPKEY.py,sha256=Va0FGo_8vm1OeX62N5iDTWukAdLwrjTXIZeQ6oanE78,1851 +dns/rdtypes/ANY/OPT.py,sha256=W36RslT_Psp95OPUC70knumOYjKpaRHvGT27I-NV2qc,2561 +dns/rdtypes/ANY/PTR.py,sha256=5HcR1D77Otyk91vVY4tmqrfZfSxSXWyWvwIW-rIH5gc,997 +dns/rdtypes/ANY/RP.py,sha256=5Dgaava9mbLKr87XgbfKZPrunYPBaN8ejNzpmbW6r4s,2184 +dns/rdtypes/ANY/RRSIG.py,sha256=O8vwzS7ldfaj_x8DypvEGFsDSb7al-D7OEnprA3QQoo,4922 +dns/rdtypes/ANY/RT.py,sha256=2t9q3FZQ28iEyceeU25KU2Ur0T5JxELAu8BTwfOUgVw,1013 +dns/rdtypes/ANY/SMIMEA.py,sha256=6yjHuVDfIEodBU9wxbCGCDZ5cWYwyY6FCk-aq2VNU0s,222 +dns/rdtypes/ANY/SOA.py,sha256=Cn8yrag1YvrvwivQgWg-KXmOCaVQVdFHSkFF77w-CE0,3145 +dns/rdtypes/ANY/SPF.py,sha256=rA3Srs9ECQx-37lqm7Zf7aYmMpp_asv4tGS8_fSQ-CU,1022 +dns/rdtypes/ANY/SSHFP.py,sha256=l6TZH2R0kytiZGWez_g-Lq94o5a2xMuwLKwUwsPMx5w,2530 +dns/rdtypes/ANY/TKEY.py,sha256=HjJMIMl4Qb1Nt1JXS6iAymzd2nv_zdLWTt887PJU_5w,4931 +dns/rdtypes/ANY/TLSA.py,sha256=cytzebS3W7FFr9qeJ9gFSHq_bOwUk9aRVlXWHfnVrRs,218 +dns/rdtypes/ANY/TSIG.py,sha256=4fNQJSNWZXUKZejCciwQuUJtTw2g-YbPmqHrEj_pitg,4750 +dns/rdtypes/ANY/TXT.py,sha256=F1U9gIAhwXIV4UVT7CwOCEn_su6G1nJIdgWJsLktk20,1000 +dns/rdtypes/ANY/URI.py,sha256=dpcS8KwcJ2WJ7BkOp4CZYaUyRuw7U2S9GzvVwKUihQg,2921 +dns/rdtypes/ANY/X25.py,sha256=PxjYTKIuoq44LT2S2JHWOV8BOFD0ASqjq0S5VBeGkFM,1944 +dns/rdtypes/ANY/ZONEMD.py,sha256=JQicv69EvUxh4FCT7eZSLzzU5L5brw_dSM65Um2t5lQ,2393 +dns/rdtypes/ANY/__init__.py,sha256=Pox71HfsEnGGB1PGU44pwrrmjxPLQlA-IbX6nQRoA2M,1497 +dns/rdtypes/ANY/__pycache__/AFSDB.cpython-312.pyc,, +dns/rdtypes/ANY/__pycache__/AMTRELAY.cpython-312.pyc,, +dns/rdtypes/ANY/__pycache__/AVC.cpython-312.pyc,, +dns/rdtypes/ANY/__pycache__/CAA.cpython-312.pyc,, +dns/rdtypes/ANY/__pycache__/CDNSKEY.cpython-312.pyc,, +dns/rdtypes/ANY/__pycache__/CDS.cpython-312.pyc,, +dns/rdtypes/ANY/__pycache__/CERT.cpython-312.pyc,, +dns/rdtypes/ANY/__pycache__/CNAME.cpython-312.pyc,, +dns/rdtypes/ANY/__pycache__/CSYNC.cpython-312.pyc,, +dns/rdtypes/ANY/__pycache__/DLV.cpython-312.pyc,, +dns/rdtypes/ANY/__pycache__/DNAME.cpython-312.pyc,, +dns/rdtypes/ANY/__pycache__/DNSKEY.cpython-312.pyc,, +dns/rdtypes/ANY/__pycache__/DS.cpython-312.pyc,, +dns/rdtypes/ANY/__pycache__/EUI48.cpython-312.pyc,, +dns/rdtypes/ANY/__pycache__/EUI64.cpython-312.pyc,, +dns/rdtypes/ANY/__pycache__/GPOS.cpython-312.pyc,, +dns/rdtypes/ANY/__pycache__/HINFO.cpython-312.pyc,, +dns/rdtypes/ANY/__pycache__/HIP.cpython-312.pyc,, +dns/rdtypes/ANY/__pycache__/ISDN.cpython-312.pyc,, +dns/rdtypes/ANY/__pycache__/L32.cpython-312.pyc,, +dns/rdtypes/ANY/__pycache__/L64.cpython-312.pyc,, +dns/rdtypes/ANY/__pycache__/LOC.cpython-312.pyc,, +dns/rdtypes/ANY/__pycache__/LP.cpython-312.pyc,, +dns/rdtypes/ANY/__pycache__/MX.cpython-312.pyc,, +dns/rdtypes/ANY/__pycache__/NID.cpython-312.pyc,, +dns/rdtypes/ANY/__pycache__/NINFO.cpython-312.pyc,, +dns/rdtypes/ANY/__pycache__/NS.cpython-312.pyc,, +dns/rdtypes/ANY/__pycache__/NSEC.cpython-312.pyc,, +dns/rdtypes/ANY/__pycache__/NSEC3.cpython-312.pyc,, +dns/rdtypes/ANY/__pycache__/NSEC3PARAM.cpython-312.pyc,, +dns/rdtypes/ANY/__pycache__/OPENPGPKEY.cpython-312.pyc,, +dns/rdtypes/ANY/__pycache__/OPT.cpython-312.pyc,, +dns/rdtypes/ANY/__pycache__/PTR.cpython-312.pyc,, +dns/rdtypes/ANY/__pycache__/RP.cpython-312.pyc,, +dns/rdtypes/ANY/__pycache__/RRSIG.cpython-312.pyc,, +dns/rdtypes/ANY/__pycache__/RT.cpython-312.pyc,, +dns/rdtypes/ANY/__pycache__/SMIMEA.cpython-312.pyc,, +dns/rdtypes/ANY/__pycache__/SOA.cpython-312.pyc,, +dns/rdtypes/ANY/__pycache__/SPF.cpython-312.pyc,, +dns/rdtypes/ANY/__pycache__/SSHFP.cpython-312.pyc,, 
+dns/rdtypes/ANY/__pycache__/TKEY.cpython-312.pyc,, +dns/rdtypes/ANY/__pycache__/TLSA.cpython-312.pyc,, +dns/rdtypes/ANY/__pycache__/TSIG.cpython-312.pyc,, +dns/rdtypes/ANY/__pycache__/TXT.cpython-312.pyc,, +dns/rdtypes/ANY/__pycache__/URI.cpython-312.pyc,, +dns/rdtypes/ANY/__pycache__/X25.cpython-312.pyc,, +dns/rdtypes/ANY/__pycache__/ZONEMD.cpython-312.pyc,, +dns/rdtypes/ANY/__pycache__/__init__.cpython-312.pyc,, +dns/rdtypes/CH/A.py,sha256=3S3OhOkSc7_ZsZBVB4GhTS19LPrrZ-yQ8sAp957qEgI,2216 +dns/rdtypes/CH/__init__.py,sha256=GD9YeDKb9VBDo-J5rrChX1MWEGyQXuR9Htnbhg_iYLc,923 +dns/rdtypes/CH/__pycache__/A.cpython-312.pyc,, +dns/rdtypes/CH/__pycache__/__init__.cpython-312.pyc,, +dns/rdtypes/IN/A.py,sha256=FfFn3SqbpneL9Ky63COP50V2ZFxqS1ldCKJh39Enwug,1814 +dns/rdtypes/IN/AAAA.py,sha256=AxrOlYy-1TTTWeQypDKeXrDCrdHGor0EKCE4fxzSQGo,1820 +dns/rdtypes/IN/APL.py,sha256=ppyFwn0KYMdyDzphxd0BUhgTmZv0QnDMRLjzQQM793U,5097 +dns/rdtypes/IN/DHCID.py,sha256=zRUh_EOxUPVpJjWY5m7taX8q4Oz5K70785ZtKv5OTCU,1856 +dns/rdtypes/IN/HTTPS.py,sha256=P-IjwcvDQMmtoBgsDHglXF7KgLX73G6jEDqCKsnaGpQ,220 +dns/rdtypes/IN/IPSECKEY.py,sha256=RyIy9K0Yt0uJRjdr6cj5S95ELHHbl--0xV-Qq9O3QQk,3290 +dns/rdtypes/IN/KX.py,sha256=K1JwItL0n5G-YGFCjWeh0C9DyDD8G8VzicsBeQiNAv0,1013 +dns/rdtypes/IN/NAPTR.py,sha256=SaOK-0hIYImwLtb5Hqewi-e49ykJaQiLNvk8ZzNoG7Q,3750 +dns/rdtypes/IN/NSAP.py,sha256=3OUpPOSOxU8fcdi0Oe6Ex2ERXcQ-U3iNf6FftZMtNOw,2165 +dns/rdtypes/IN/NSAP_PTR.py,sha256=iTxlV6fr_Y9lqivLLncSHxEhmFqz5UEElDW3HMBtuCU,1015 +dns/rdtypes/IN/PX.py,sha256=vHDNN2rfLObuUKwpYDIvpPB482BqXlHA-ZQpQn9Sb_E,2756 +dns/rdtypes/IN/SRV.py,sha256=a0zGaUwzvih_a4Q9BViUTFs7NZaCqgl7mls3-KRVHm8,2769 +dns/rdtypes/IN/SVCB.py,sha256=HeFmi2v01F00Hott8FlvQ4R7aPxFmT7RF-gt45R5K_M,218 +dns/rdtypes/IN/WKS.py,sha256=kErSG5AO2qIuot_hkMHnQuZB1_uUzUirNdqBoCp97rk,3652 +dns/rdtypes/IN/__init__.py,sha256=HbI8aw9HWroI6SgEvl8Sx6FdkDswCCXMbSRuJy5o8LQ,1083 +dns/rdtypes/IN/__pycache__/A.cpython-312.pyc,, +dns/rdtypes/IN/__pycache__/AAAA.cpython-312.pyc,, +dns/rdtypes/IN/__pycache__/APL.cpython-312.pyc,, +dns/rdtypes/IN/__pycache__/DHCID.cpython-312.pyc,, +dns/rdtypes/IN/__pycache__/HTTPS.cpython-312.pyc,, +dns/rdtypes/IN/__pycache__/IPSECKEY.cpython-312.pyc,, +dns/rdtypes/IN/__pycache__/KX.cpython-312.pyc,, +dns/rdtypes/IN/__pycache__/NAPTR.cpython-312.pyc,, +dns/rdtypes/IN/__pycache__/NSAP.cpython-312.pyc,, +dns/rdtypes/IN/__pycache__/NSAP_PTR.cpython-312.pyc,, +dns/rdtypes/IN/__pycache__/PX.cpython-312.pyc,, +dns/rdtypes/IN/__pycache__/SRV.cpython-312.pyc,, +dns/rdtypes/IN/__pycache__/SVCB.cpython-312.pyc,, +dns/rdtypes/IN/__pycache__/WKS.cpython-312.pyc,, +dns/rdtypes/IN/__pycache__/__init__.cpython-312.pyc,, +dns/rdtypes/__init__.py,sha256=NYizfGglJfhqt_GMtSSXf7YQXIEHHCiJ_Y_qaLVeiOI,1073 +dns/rdtypes/__pycache__/__init__.cpython-312.pyc,, +dns/rdtypes/__pycache__/dnskeybase.cpython-312.pyc,, +dns/rdtypes/__pycache__/dsbase.cpython-312.pyc,, +dns/rdtypes/__pycache__/euibase.cpython-312.pyc,, +dns/rdtypes/__pycache__/mxbase.cpython-312.pyc,, +dns/rdtypes/__pycache__/nsbase.cpython-312.pyc,, +dns/rdtypes/__pycache__/svcbbase.cpython-312.pyc,, +dns/rdtypes/__pycache__/tlsabase.cpython-312.pyc,, +dns/rdtypes/__pycache__/txtbase.cpython-312.pyc,, +dns/rdtypes/__pycache__/util.cpython-312.pyc,, +dns/rdtypes/dnskeybase.py,sha256=FoDllfa9Pz2j2rf45VyUUYUsIt3kjjrwDy6LxrlPb5s,2856 +dns/rdtypes/dsbase.py,sha256=I85Aps1lBsiItdqGpsNY1O8icosfPtkWjiUn1J1lLUQ,3427 +dns/rdtypes/euibase.py,sha256=umN9A3VNw1TziAVtePvUses2jWPcynxINvjgyndPCdQ,2630 
+dns/rdtypes/mxbase.py,sha256=DzjbiKoAAgpqbhwMBIFGA081jR5_doqGAq-kLvy2mns,3196 +dns/rdtypes/nsbase.py,sha256=tueXVV6E8lelebOmrmoOPq47eeRvOpsxHVXH4cOFxcs,2323 +dns/rdtypes/svcbbase.py,sha256=TQRT52m8F2NpSJsHUkTFS-hrkyhcIoAodW6bBHED4CY,16674 +dns/rdtypes/tlsabase.py,sha256=pIiWem6sF4IwyyKmyqx5xg55IG0w3K9r502Yx8PdziA,2596 +dns/rdtypes/txtbase.py,sha256=K4v2ulFu0DxPjxyf_Ul7YRjfBpUO-Ay_ChnR_Wx-ywA,3601 +dns/rdtypes/util.py,sha256=6AGQ-k3mLNlx4Ep_FiDABj1WVumUUGs3zQ6X-2iISec,9003 +dns/renderer.py,sha256=5THf1iKql2JPL2sKZt2-b4zqHKfk_vlx0FEfPtMJysY,11254 +dns/resolver.py,sha256=wagpUIu8Oh12O-zk48U30A6VQQOspjfibU4Ls2So-kM,73552 +dns/reversename.py,sha256=zoqXEbMZXm6R13nXbJHgTsf6L2C6uReODj6mqSHrTiE,3828 +dns/rrset.py,sha256=J-oQPEPJuKueLLiz1FN08P-ys9fjHhPWuwpDdrL4UTQ,9170 +dns/serial.py,sha256=-t5rPW-TcJwzBMfIJo7Tl-uDtaYtpqOfCVYx9dMaDCY,3606 +dns/set.py,sha256=Lr1qhyqywoobNkj9sAfdovoFy9vBfkz2eHdTCc7sZRs,9088 +dns/tokenizer.py,sha256=Dcc3lQgEIHCVZBuO6FaKWEojtPSd3EuaUC4vQA-spnk,23583 +dns/transaction.py,sha256=ZlnDT-V4W01J3cS501GaRLVhE9t1jZdnEZxPyZ0Cvg4,22636 +dns/tsig.py,sha256=I-Y-c3WMBX11bVioy5puFly2BhlpptUz82ikahxuh1c,11413 +dns/tsigkeyring.py,sha256=Z0xZemcU3XjZ9HlxBYv2E2PSuIhaFreqLDlD7HcmZDA,2633 +dns/ttl.py,sha256=fWFkw8qfk6saTp7lAPxZOuD3U3TRxVRvIpljQnG-01I,2979 +dns/update.py,sha256=y9d6LOO8xrUaH2UrZhy3ssnx8bJEsxqTArw5V8XqBRs,12243 +dns/version.py,sha256=sRMqE5tzPhXEzz-SEvdN82pP77xF_i1iELxaJN0roDE,1926 +dns/versioned.py,sha256=3YQj8mzGmZEsjnuVJJjcWopVmDKYLhEj4hEGTLEwzco,11765 +dns/win32util.py,sha256=NEjd5RXQU2aV1WsBMoIGZmXyqqKCxS4WYq9HqFQoVig,9107 +dns/wire.py,sha256=vy0SolgECbO1UXB4dnhXhDeFKOJT29nQxXvSfKOgA5s,2830 +dns/xfr.py,sha256=FKkKO-kSpyE1vHU5mnoPIP4YxiCl5gG7E5wOgY_4GO8,13273 +dns/zone.py,sha256=lLAarSxPtpx4Sw29OQ0ifPshD4QauGu8RnPh2dEropA,52086 +dns/zonefile.py,sha256=9pgkO0pV8Js53Oq9ZKOSbpFkGS5r_orU-25tmufGP9M,27929 +dns/zonetypes.py,sha256=HrQNZxZ_gWLWI9dskix71msi9wkYK5pgrBBbPb1T74Y,690 +dnspython-2.6.1.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 +dnspython-2.6.1.dist-info/METADATA,sha256=2GJFv-NqkwIytog5VQe0wPtZKoS016uyYfG76lqftto,5808 +dnspython-2.6.1.dist-info/RECORD,, +dnspython-2.6.1.dist-info/WHEEL,sha256=TJPnKdtrSue7xZ_AVGkp9YXcvDrobsjBds1du3Nx6dc,87 +dnspython-2.6.1.dist-info/licenses/LICENSE,sha256=w-o_9WVLMpwZ07xfdIGvYjw93tSmFFWFSZ-EOtPXQc0,1526 diff --git a/venv/Lib/site-packages/dnspython-2.6.1.dist-info/WHEEL b/venv/Lib/site-packages/dnspython-2.6.1.dist-info/WHEEL new file mode 100644 index 00000000..5998f3aa --- /dev/null +++ b/venv/Lib/site-packages/dnspython-2.6.1.dist-info/WHEEL @@ -0,0 +1,4 @@ +Wheel-Version: 1.0 +Generator: hatchling 1.21.1 +Root-Is-Purelib: true +Tag: py3-none-any diff --git a/venv/Lib/site-packages/dnspython-2.6.1.dist-info/licenses/LICENSE b/venv/Lib/site-packages/dnspython-2.6.1.dist-info/licenses/LICENSE new file mode 100644 index 00000000..390a726d --- /dev/null +++ b/venv/Lib/site-packages/dnspython-2.6.1.dist-info/licenses/LICENSE @@ -0,0 +1,35 @@ +ISC License + +Copyright (C) Dnspython Contributors + +Permission to use, copy, modify, and/or distribute this software for +any purpose with or without fee is hereby granted, provided that the +above copyright notice and this permission notice appear in all +copies. + +THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL +WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS. 
IN NO EVENT SHALL THE +AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL +DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR +PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER +TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR +PERFORMANCE OF THIS SOFTWARE. + + + +Copyright (C) 2001-2017 Nominum, Inc. +Copyright (C) Google Inc. + +Permission to use, copy, modify, and distribute this software and its +documentation for any purpose with or without fee is hereby granted, +provided that the above copyright notice and this permission notice +appear in all copies. + +THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. diff --git a/venv/Lib/site-packages/gridfs/__init__.py b/venv/Lib/site-packages/gridfs/__init__.py new file mode 100644 index 00000000..8d01fefc --- /dev/null +++ b/venv/Lib/site-packages/gridfs/__init__.py @@ -0,0 +1,1000 @@ +# Copyright 2009-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""GridFS is a specification for storing large objects in Mongo. + +The :mod:`gridfs` package is an implementation of GridFS on top of +:mod:`pymongo`, exposing a file-like interface. + +.. seealso:: The MongoDB documentation on `gridfs `_. +""" +from __future__ import annotations + +from collections import abc +from typing import Any, Mapping, Optional, cast + +from bson.objectid import ObjectId +from gridfs.errors import NoFile +from gridfs.grid_file import ( + DEFAULT_CHUNK_SIZE, + GridIn, + GridOut, + GridOutCursor, + _clear_entity_type_registry, + _disallow_transactions, +) +from pymongo import ASCENDING, DESCENDING, _csot +from pymongo.client_session import ClientSession +from pymongo.collection import Collection +from pymongo.common import validate_string +from pymongo.database import Database +from pymongo.errors import ConfigurationError +from pymongo.read_preferences import _ServerMode +from pymongo.write_concern import WriteConcern + +__all__ = [ + "GridFS", + "GridFSBucket", + "NoFile", + "DEFAULT_CHUNK_SIZE", + "GridIn", + "GridOut", + "GridOutCursor", +] + + +class GridFS: + """An instance of GridFS on top of a single Database.""" + + def __init__(self, database: Database, collection: str = "fs"): + """Create a new instance of :class:`GridFS`. + + Raises :class:`TypeError` if `database` is not an instance of + :class:`~pymongo.database.Database`. + + :param database: database to use + :param collection: root collection to use + + .. versionchanged:: 4.0 + Removed the `disable_md5` parameter. See + :ref:`removed-gridfs-checksum` for details. + + .. 
versionchanged:: 3.11 + Running a GridFS operation in a transaction now always raises an + error. GridFS does not support multi-document transactions. + + .. versionchanged:: 3.7 + Added the `disable_md5` parameter. + + .. versionchanged:: 3.1 + Indexes are only ensured on the first write to the DB. + + .. versionchanged:: 3.0 + `database` must use an acknowledged + :attr:`~pymongo.database.Database.write_concern` + + .. seealso:: The MongoDB documentation on `gridfs `_. + """ + if not isinstance(database, Database): + raise TypeError("database must be an instance of Database") + + database = _clear_entity_type_registry(database) + + if not database.write_concern.acknowledged: + raise ConfigurationError("database must use acknowledged write_concern") + + self.__collection = database[collection] + self.__files = self.__collection.files + self.__chunks = self.__collection.chunks + + def new_file(self, **kwargs: Any) -> GridIn: + """Create a new file in GridFS. + + Returns a new :class:`~gridfs.grid_file.GridIn` instance to + which data can be written. Any keyword arguments will be + passed through to :meth:`~gridfs.grid_file.GridIn`. + + If the ``"_id"`` of the file is manually specified, it must + not already exist in GridFS. Otherwise + :class:`~gridfs.errors.FileExists` is raised. + + :param kwargs: keyword arguments for file creation + """ + return GridIn(self.__collection, **kwargs) + + def put(self, data: Any, **kwargs: Any) -> Any: + """Put data in GridFS as a new file. + + Equivalent to doing:: + + with fs.new_file(**kwargs) as f: + f.write(data) + + `data` can be either an instance of :class:`bytes` or a file-like + object providing a :meth:`read` method. If an `encoding` keyword + argument is passed, `data` can also be a :class:`str` instance, which + will be encoded as `encoding` before being written. Any keyword + arguments will be passed through to the created file - see + :meth:`~gridfs.grid_file.GridIn` for possible arguments. Returns the + ``"_id"`` of the created file. + + If the ``"_id"`` of the file is manually specified, it must + not already exist in GridFS. Otherwise + :class:`~gridfs.errors.FileExists` is raised. + + :param data: data to be written as a file. + :param kwargs: keyword arguments for file creation + + .. versionchanged:: 3.0 + w=0 writes to GridFS are now prohibited. + """ + with GridIn(self.__collection, **kwargs) as grid_file: + grid_file.write(data) + return grid_file._id + + def get(self, file_id: Any, session: Optional[ClientSession] = None) -> GridOut: + """Get a file from GridFS by ``"_id"``. + + Returns an instance of :class:`~gridfs.grid_file.GridOut`, + which provides a file-like interface for reading. + + :param file_id: ``"_id"`` of the file to get + :param session: a + :class:`~pymongo.client_session.ClientSession` + + .. versionchanged:: 3.6 + Added ``session`` parameter. + """ + gout = GridOut(self.__collection, file_id, session=session) + + # Raise NoFile now, instead of on first attribute access. + gout._ensure_file() + return gout + + def get_version( + self, + filename: Optional[str] = None, + version: Optional[int] = -1, + session: Optional[ClientSession] = None, + **kwargs: Any, + ) -> GridOut: + """Get a file from GridFS by ``"filename"`` or metadata fields. + + Returns a version of the file in GridFS whose filename matches + `filename` and whose metadata fields match the supplied keyword + arguments, as an instance of :class:`~gridfs.grid_file.GridOut`. 
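+
+        For example, a short sketch (``fs`` is assumed to be an existing
+        :class:`GridFS` instance, and the filename is illustrative)::
+
+            oldest = fs.get_version("lisa.txt", version=0)
+            newest = fs.get_version("lisa.txt")  # the default, version=-1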
+ + Version numbering is a convenience atop the GridFS API provided + by MongoDB. If more than one file matches the query (either by + `filename` alone, by metadata fields, or by a combination of + both), then version ``-1`` will be the most recently uploaded + matching file, ``-2`` the second most recently + uploaded, etc. Version ``0`` will be the first version + uploaded, ``1`` the second version, etc. So if three versions + have been uploaded, then version ``0`` is the same as version + ``-3``, version ``1`` is the same as version ``-2``, and + version ``2`` is the same as version ``-1``. + + Raises :class:`~gridfs.errors.NoFile` if no such version of + that file exists. + + :param filename: ``"filename"`` of the file to get, or `None` + :param version: version of the file to get (defaults + to -1, the most recent version uploaded) + :param session: a + :class:`~pymongo.client_session.ClientSession` + :param kwargs: find files by custom metadata. + + .. versionchanged:: 3.6 + Added ``session`` parameter. + + .. versionchanged:: 3.1 + ``get_version`` no longer ensures indexes. + """ + query = kwargs + if filename is not None: + query["filename"] = filename + + _disallow_transactions(session) + cursor = self.__files.find(query, session=session) + if version is None: + version = -1 + if version < 0: + skip = abs(version) - 1 + cursor.limit(-1).skip(skip).sort("uploadDate", DESCENDING) + else: + cursor.limit(-1).skip(version).sort("uploadDate", ASCENDING) + try: + doc = next(cursor) + return GridOut(self.__collection, file_document=doc, session=session) + except StopIteration: + raise NoFile("no version %d for filename %r" % (version, filename)) from None + + def get_last_version( + self, filename: Optional[str] = None, session: Optional[ClientSession] = None, **kwargs: Any + ) -> GridOut: + """Get the most recent version of a file in GridFS by ``"filename"`` + or metadata fields. + + Equivalent to calling :meth:`get_version` with the default + `version` (``-1``). + + :param filename: ``"filename"`` of the file to get, or `None` + :param session: a + :class:`~pymongo.client_session.ClientSession` + :param kwargs: find files by custom metadata. + + .. versionchanged:: 3.6 + Added ``session`` parameter. + """ + return self.get_version(filename=filename, session=session, **kwargs) + + # TODO add optional safe mode for chunk removal? + def delete(self, file_id: Any, session: Optional[ClientSession] = None) -> None: + """Delete a file from GridFS by ``"_id"``. + + Deletes all data belonging to the file with ``"_id"``: + `file_id`. + + .. warning:: Any processes/threads reading from the file while + this method is executing will likely see an invalid/corrupt + file. Care should be taken to avoid concurrent reads to a file + while it is being deleted. + + .. note:: Deletes of non-existent files are considered successful + since the end result is the same: no file with that _id remains. + + :param file_id: ``"_id"`` of the file to delete + :param session: a + :class:`~pymongo.client_session.ClientSession` + + .. versionchanged:: 3.6 + Added ``session`` parameter. + + .. versionchanged:: 3.1 + ``delete`` no longer ensures indexes. + """ + _disallow_transactions(session) + self.__files.delete_one({"_id": file_id}, session=session) + self.__chunks.delete_many({"files_id": file_id}, session=session) + + def list(self, session: Optional[ClientSession] = None) -> list[str]: + """List the names of all files stored in this instance of + :class:`GridFS`. 
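+
+        For example, a small sketch (``fs`` is assumed to be an existing
+        :class:`GridFS` instance)::
+
+            for name in fs.list():
+                print(name)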
+
+        :param session: a
+            :class:`~pymongo.client_session.ClientSession`
+
+        .. versionchanged:: 3.6
+           Added ``session`` parameter.
+
+        .. versionchanged:: 3.1
+           ``list`` no longer ensures indexes.
+        """
+        _disallow_transactions(session)
+        # With an index, distinct includes documents with no filename
+        # as None.
+        return [
+            name for name in self.__files.distinct("filename", session=session) if name is not None
+        ]
+
+    def find_one(
+        self,
+        filter: Optional[Any] = None,
+        session: Optional[ClientSession] = None,
+        *args: Any,
+        **kwargs: Any,
+    ) -> Optional[GridOut]:
+        """Get a single file from gridfs.
+
+        All arguments to :meth:`find` are also valid arguments for
+        :meth:`find_one`, although any `limit` argument will be
+        ignored. Returns a single :class:`~gridfs.grid_file.GridOut`,
+        or ``None`` if no matching file is found. For example:
+
+        .. code-block:: python
+
+            file = fs.find_one({"filename": "lisa.txt"})
+
+        :param filter: a dictionary specifying
+            the query to be performed OR any other type to be used as
+            the value for a query for ``"_id"`` in the file collection.
+        :param args: any additional positional arguments are
+            the same as the arguments to :meth:`find`.
+        :param session: a
+            :class:`~pymongo.client_session.ClientSession`
+        :param kwargs: any additional keyword arguments
+            are the same as the arguments to :meth:`find`.
+
+        .. versionchanged:: 3.6
+           Added ``session`` parameter.
+        """
+        if filter is not None and not isinstance(filter, abc.Mapping):
+            filter = {"_id": filter}
+
+        _disallow_transactions(session)
+        for f in self.find(filter, *args, session=session, **kwargs):
+            return f
+
+        return None
+
+    def find(self, *args: Any, **kwargs: Any) -> GridOutCursor:
+        """Query GridFS for files.
+
+        Returns a cursor that iterates across files matching
+        arbitrary queries on the files collection. Can be combined
+        with other modifiers for additional control. For example::
+
+            for grid_out in fs.find({"filename": "lisa.txt"},
+                                    no_cursor_timeout=True):
+                data = grid_out.read()
+
+        would iterate through all versions of "lisa.txt" stored in GridFS.
+        Note that setting no_cursor_timeout to True may be important to
+        prevent the cursor from timing out during long multi-file processing
+        work.
+
+        As another example, the call::
+
+            most_recent_three = fs.find().sort("uploadDate", -1).limit(3)
+
+        would return a cursor to the three most recently uploaded files
+        in GridFS.
+
+        Follows a similar interface to
+        :meth:`~pymongo.collection.Collection.find`
+        in :class:`~pymongo.collection.Collection`.
+
+        If a :class:`~pymongo.client_session.ClientSession` is passed to
+        :meth:`find`, all returned :class:`~gridfs.grid_file.GridOut` instances
+        are associated with that session.
+
+        :param filter: A query document that selects which files
+            to include in the result set. Can be an empty document to include
+            all files.
+        :param skip: the number of files to omit (from
+            the start of the result set) when returning the results
+        :param limit: the maximum number of results to
+            return
+        :param no_cursor_timeout: if False (the default), any
+            returned cursor is closed by the server after 10 minutes of
+            inactivity. If set to True, the returned cursor will never
+            time out on the server. Care should be taken to ensure that
+            cursors with no_cursor_timeout turned on are properly closed.
+        :param sort: a list of (key, direction) pairs
+            specifying the sort order for this query. See
+            :meth:`~pymongo.cursor.Cursor.sort` for details.
+ + Raises :class:`TypeError` if any of the arguments are of + improper type. Returns an instance of + :class:`~gridfs.grid_file.GridOutCursor` + corresponding to this query. + + .. versionchanged:: 3.0 + Removed the read_preference, tag_sets, and + secondary_acceptable_latency_ms options. + .. versionadded:: 2.7 + .. seealso:: The MongoDB documentation on `find `_. + """ + return GridOutCursor(self.__collection, *args, **kwargs) + + def exists( + self, + document_or_id: Optional[Any] = None, + session: Optional[ClientSession] = None, + **kwargs: Any, + ) -> bool: + """Check if a file exists in this instance of :class:`GridFS`. + + The file to check for can be specified by the value of its + ``_id`` key, or by passing in a query document. A query + document can be passed in as dictionary, or by using keyword + arguments. Thus, the following three calls are equivalent: + + >>> fs.exists(file_id) + >>> fs.exists({"_id": file_id}) + >>> fs.exists(_id=file_id) + + As are the following two calls: + + >>> fs.exists({"filename": "mike.txt"}) + >>> fs.exists(filename="mike.txt") + + And the following two: + + >>> fs.exists({"foo": {"$gt": 12}}) + >>> fs.exists(foo={"$gt": 12}) + + Returns ``True`` if a matching file exists, ``False`` + otherwise. Calls to :meth:`exists` will not automatically + create appropriate indexes; application developers should be + sure to create indexes if needed and as appropriate. + + :param document_or_id: query document, or _id of the + document to check for + :param session: a + :class:`~pymongo.client_session.ClientSession` + :param kwargs: keyword arguments are used as a + query document, if they're present. + + .. versionchanged:: 3.6 + Added ``session`` parameter. + """ + _disallow_transactions(session) + if kwargs: + f = self.__files.find_one(kwargs, ["_id"], session=session) + else: + f = self.__files.find_one(document_or_id, ["_id"], session=session) + + return f is not None + + +class GridFSBucket: + """An instance of GridFS on top of a single Database.""" + + def __init__( + self, + db: Database, + bucket_name: str = "fs", + chunk_size_bytes: int = DEFAULT_CHUNK_SIZE, + write_concern: Optional[WriteConcern] = None, + read_preference: Optional[_ServerMode] = None, + ) -> None: + """Create a new instance of :class:`GridFSBucket`. + + Raises :exc:`TypeError` if `database` is not an instance of + :class:`~pymongo.database.Database`. + + Raises :exc:`~pymongo.errors.ConfigurationError` if `write_concern` + is not acknowledged. + + :param database: database to use. + :param bucket_name: The name of the bucket. Defaults to 'fs'. + :param chunk_size_bytes: The chunk size in bytes. Defaults + to 255KB. + :param write_concern: The + :class:`~pymongo.write_concern.WriteConcern` to use. If ``None`` + (the default) db.write_concern is used. + :param read_preference: The read preference to use. If + ``None`` (the default) db.read_preference is used. + + .. versionchanged:: 4.0 + Removed the `disable_md5` parameter. See + :ref:`removed-gridfs-checksum` for details. + + .. versionchanged:: 3.11 + Running a GridFSBucket operation in a transaction now always raises + an error. GridFSBucket does not support multi-document transactions. + + .. versionchanged:: 3.7 + Added the `disable_md5` parameter. + + .. versionadded:: 3.1 + + .. seealso:: The MongoDB documentation on `gridfs `_. 
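+
+        For example, a minimal construction sketch (the database and bucket
+        names are illustrative)::
+
+            from pymongo import MongoClient
+            from gridfs import GridFSBucket
+
+            db = MongoClient().test
+            bucket = GridFSBucket(db, bucket_name="media")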
+        """
+        if not isinstance(db, Database):
+            raise TypeError("database must be an instance of Database")
+
+        db = _clear_entity_type_registry(db)
+
+        wtc = write_concern if write_concern is not None else db.write_concern
+        if not wtc.acknowledged:
+            raise ConfigurationError("write concern must be acknowledged")
+
+        self._bucket_name = bucket_name
+        self._collection = db[bucket_name]
+        self._chunks: Collection = self._collection.chunks.with_options(
+            write_concern=write_concern, read_preference=read_preference
+        )
+
+        self._files: Collection = self._collection.files.with_options(
+            write_concern=write_concern, read_preference=read_preference
+        )
+
+        self._chunk_size_bytes = chunk_size_bytes
+        self._timeout = db.client.options.timeout
+
+    def open_upload_stream(
+        self,
+        filename: str,
+        chunk_size_bytes: Optional[int] = None,
+        metadata: Optional[Mapping[str, Any]] = None,
+        session: Optional[ClientSession] = None,
+    ) -> GridIn:
+        """Opens a stream that the application can write the contents of the
+        file to.
+
+        The user must specify the filename, and can choose to add any
+        additional information in the metadata field of the file document or
+        modify the chunk size.
+        For example::
+
+            my_db = MongoClient().test
+            fs = GridFSBucket(my_db)
+            with fs.open_upload_stream(
+                    "test_file", chunk_size_bytes=4,
+                    metadata={"contentType": "text/plain"}) as grid_in:
+                grid_in.write("data I want to store!")
+            # uploaded on close
+
+        Returns an instance of :class:`~gridfs.grid_file.GridIn`.
+
+        Raises :exc:`~ValueError` if `filename` is not a string.
+
+        :param filename: The name of the file to upload.
+        :param chunk_size_bytes: The number of bytes per chunk of this
+            file. Defaults to the chunk_size_bytes in :class:`GridFSBucket`.
+        :param metadata: User data for the 'metadata' field of the
+            files collection document. If not provided the metadata field will
+            be omitted from the files collection document.
+        :param session: a
+            :class:`~pymongo.client_session.ClientSession`
+
+        .. versionchanged:: 3.6
+           Added ``session`` parameter.
+        """
+        validate_string("filename", filename)
+
+        opts = {
+            "filename": filename,
+            "chunk_size": (
+                chunk_size_bytes if chunk_size_bytes is not None else self._chunk_size_bytes
+            ),
+        }
+        if metadata is not None:
+            opts["metadata"] = metadata
+
+        return GridIn(self._collection, session=session, **opts)
+
+    def open_upload_stream_with_id(
+        self,
+        file_id: Any,
+        filename: str,
+        chunk_size_bytes: Optional[int] = None,
+        metadata: Optional[Mapping[str, Any]] = None,
+        session: Optional[ClientSession] = None,
+    ) -> GridIn:
+        """Opens a stream that the application can write the contents of the
+        file to.
+
+        The user must specify the file id and filename, and can choose to add
+        any additional information in the metadata field of the file document
+        or modify the chunk size.
+        For example::
+
+            my_db = MongoClient().test
+            fs = GridFSBucket(my_db)
+            with fs.open_upload_stream_with_id(
+                    ObjectId(),
+                    "test_file",
+                    chunk_size_bytes=4,
+                    metadata={"contentType": "text/plain"}) as grid_in:
+                grid_in.write("data I want to store!")
+            # uploaded on close
+
+        Returns an instance of :class:`~gridfs.grid_file.GridIn`.
+
+        Raises :exc:`~gridfs.errors.FileExists` if the ``"_id"`` is already
+        in use by another file.
+        Raises :exc:`~ValueError` if `filename` is not a string.
+
+        :param file_id: The id to use for this file. The id must not have
+            already been used for another file.
+        :param filename: The name of the file to upload.
+        :param chunk_size_bytes: The number of bytes per chunk of this
+            file. Defaults to the chunk_size_bytes in :class:`GridFSBucket`.
+        :param metadata: User data for the 'metadata' field of the
+            files collection document. If not provided the metadata field will
+            be omitted from the files collection document.
+        :param session: a
+            :class:`~pymongo.client_session.ClientSession`
+
+        .. versionchanged:: 3.6
+           Added ``session`` parameter.
+        """
+        validate_string("filename", filename)
+
+        opts = {
+            "_id": file_id,
+            "filename": filename,
+            "chunk_size": (
+                chunk_size_bytes if chunk_size_bytes is not None else self._chunk_size_bytes
+            ),
+        }
+        if metadata is not None:
+            opts["metadata"] = metadata
+
+        return GridIn(self._collection, session=session, **opts)
+
+    @_csot.apply
+    def upload_from_stream(
+        self,
+        filename: str,
+        source: Any,
+        chunk_size_bytes: Optional[int] = None,
+        metadata: Optional[Mapping[str, Any]] = None,
+        session: Optional[ClientSession] = None,
+    ) -> ObjectId:
+        """Uploads a user file to a GridFS bucket.
+
+        Reads the contents of the user file from `source` and uploads
+        it to the file `filename`. Source can be a string or file-like object.
+        For example::
+
+            my_db = MongoClient().test
+            fs = GridFSBucket(my_db)
+            file_id = fs.upload_from_stream(
+                "test_file",
+                "data I want to store!",
+                chunk_size_bytes=4,
+                metadata={"contentType": "text/plain"})
+
+        Returns the _id of the uploaded file.
+
+        Raises :exc:`~ValueError` if `filename` is not a string.
+
+        :param filename: The name of the file to upload.
+        :param source: The source stream of the content to be uploaded. Must be
+            a file-like object that implements :meth:`read` or a string.
+        :param chunk_size_bytes: The number of bytes per chunk of this
+            file. Defaults to the chunk_size_bytes of :class:`GridFSBucket`.
+        :param metadata: User data for the 'metadata' field of the
+            files collection document. If not provided the metadata field will
+            be omitted from the files collection document.
+        :param session: a
+            :class:`~pymongo.client_session.ClientSession`
+
+        .. versionchanged:: 3.6
+           Added ``session`` parameter.
+        """
+        with self.open_upload_stream(filename, chunk_size_bytes, metadata, session=session) as gin:
+            gin.write(source)
+
+        return cast(ObjectId, gin._id)
+
+    @_csot.apply
+    def upload_from_stream_with_id(
+        self,
+        file_id: Any,
+        filename: str,
+        source: Any,
+        chunk_size_bytes: Optional[int] = None,
+        metadata: Optional[Mapping[str, Any]] = None,
+        session: Optional[ClientSession] = None,
+    ) -> None:
+        """Uploads a user file to a GridFS bucket with a custom file id.
+
+        Reads the contents of the user file from `source` and uploads
+        it to the file `filename`. Source can be a string or file-like object.
+        For example::
+
+            my_db = MongoClient().test
+            fs = GridFSBucket(my_db)
+            fs.upload_from_stream_with_id(
+                ObjectId(),
+                "test_file",
+                "data I want to store!",
+                chunk_size_bytes=4,
+                metadata={"contentType": "text/plain"})
+
+        Raises :exc:`~gridfs.errors.FileExists` if the ``"_id"`` is already
+        in use by another file.
+        Raises :exc:`~ValueError` if `filename` is not a string.
+
+        :param file_id: The id to use for this file. The id must not have
+            already been used for another file.
+        :param filename: The name of the file to upload.
+        :param source: The source stream of the content to be uploaded.
+            a file-like object that implements :meth:`read` or a string.
+        :param chunk_size_bytes: The number of bytes per chunk of this
+            file. Defaults to the chunk_size_bytes of :class:`GridFSBucket`.
+        :param metadata: User data for the 'metadata' field of the
+            files collection document. If not provided the metadata field will
+            be omitted from the files collection document.
+        :param session: a
+            :class:`~pymongo.client_session.ClientSession`
+
+        .. versionchanged:: 3.6
+           Added ``session`` parameter.
+        """
+        with self.open_upload_stream_with_id(
+            file_id, filename, chunk_size_bytes, metadata, session=session
+        ) as gin:
+            gin.write(source)
+
+    def open_download_stream(
+        self, file_id: Any, session: Optional[ClientSession] = None
+    ) -> GridOut:
+        """Opens a stream from which the application can read the contents of
+        the stored file specified by file_id.
+
+        For example::
+
+          my_db = MongoClient().test
+          fs = GridFSBucket(my_db)
+          # get _id of file to read.
+          file_id = fs.upload_from_stream("test_file", b"data I want to store!")
+          grid_out = fs.open_download_stream(file_id)
+          contents = grid_out.read()
+
+        Returns an instance of :class:`~gridfs.grid_file.GridOut`.
+
+        Raises :exc:`~gridfs.errors.NoFile` if no file with file_id exists.
+
+        :param file_id: The _id of the file to be downloaded.
+        :param session: a
+            :class:`~pymongo.client_session.ClientSession`
+
+        .. versionchanged:: 3.6
+           Added ``session`` parameter.
+        """
+        gout = GridOut(self._collection, file_id, session=session)
+
+        # Raise NoFile now, instead of on first attribute access.
+        gout._ensure_file()
+        return gout
+
+    @_csot.apply
+    def download_to_stream(
+        self, file_id: Any, destination: Any, session: Optional[ClientSession] = None
+    ) -> None:
+        """Downloads the contents of the stored file specified by file_id and
+        writes the contents to `destination`.
+
+        For example::
+
+          my_db = MongoClient().test
+          fs = GridFSBucket(my_db)
+          # Get _id of file to read
+          file_id = fs.upload_from_stream("test_file", b"data I want to store!")
+          # Get file to write to
+          file = open('myfile','wb+')
+          fs.download_to_stream(file_id, file)
+          file.seek(0)
+          contents = file.read()
+
+        Raises :exc:`~gridfs.errors.NoFile` if no file with file_id exists.
+
+        :param file_id: The _id of the file to be downloaded.
+        :param destination: a file-like object implementing :meth:`write`.
+        :param session: a
+            :class:`~pymongo.client_session.ClientSession`
+
+        .. versionchanged:: 3.6
+           Added ``session`` parameter.
+        """
+        with self.open_download_stream(file_id, session=session) as gout:
+            while True:
+                chunk = gout.readchunk()
+                if not len(chunk):
+                    break
+                destination.write(chunk)
+
+    @_csot.apply
+    def delete(self, file_id: Any, session: Optional[ClientSession] = None) -> None:
+        """Given a file_id, delete this stored file's files collection document
+        and associated chunks from a GridFS bucket.
+
+        For example::
+
+          my_db = MongoClient().test
+          fs = GridFSBucket(my_db)
+          # Get _id of file to delete
+          file_id = fs.upload_from_stream("test_file", b"data I want to store!")
+          fs.delete(file_id)
+
+        Raises :exc:`~gridfs.errors.NoFile` if no file with file_id exists.
+
+        :param file_id: The _id of the file to be deleted.
+        :param session: a
+            :class:`~pymongo.client_session.ClientSession`
+
+        .. versionchanged:: 3.6
+           Added ``session`` parameter.
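+
+        The files document and any matching chunks are both removed before
+        :exc:`~gridfs.errors.NoFile` is raised, so repeating a delete is
+        safe and still sweeps up orphaned chunks. A defensive sketch
+        (assuming ``fs`` is a :class:`GridFSBucket` as above)::
+
+          from gridfs.errors import NoFile
+
+          try:
+              fs.delete(file_id)
+          except NoFile:
+              pass  # nothing matched; any leftover chunks were still removed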
+        """
+        _disallow_transactions(session)
+        res = self._files.delete_one({"_id": file_id}, session=session)
+        self._chunks.delete_many({"files_id": file_id}, session=session)
+        if not res.deleted_count:
+            raise NoFile("no file could be deleted because none matched %s" % file_id)
+
+    def find(self, *args: Any, **kwargs: Any) -> GridOutCursor:
+        """Find and return the files collection documents that match ``filter``.
+
+        Returns a cursor that iterates across files matching
+        arbitrary queries on the files collection. Can be combined
+        with other modifiers for additional control.
+
+        For example::
+
+          for grid_data in fs.find({"filename": "lisa.txt"},
+                                   no_cursor_timeout=True):
+              data = grid_data.read()
+
+        would iterate through all versions of "lisa.txt" stored in GridFS.
+        Note that setting no_cursor_timeout to True may be important to
+        prevent the cursor from timing out during long multi-file processing
+        work.
+
+        As another example, the call::
+
+          most_recent_three = fs.find().sort("uploadDate", -1).limit(3)
+
+        would return a cursor to the three most recently uploaded files
+        in GridFS.
+
+        Follows a similar interface to
+        :meth:`~pymongo.collection.Collection.find`
+        in :class:`~pymongo.collection.Collection`.
+
+        If a :class:`~pymongo.client_session.ClientSession` is passed to
+        :meth:`find`, all returned :class:`~gridfs.grid_file.GridOut` instances
+        are associated with that session.
+
+        :param filter: Search query.
+        :param batch_size: The number of documents to return per
+            batch.
+        :param limit: The maximum number of documents to return.
+        :param no_cursor_timeout: The server normally times out idle
+            cursors after an inactivity period (10 minutes) to prevent excess
+            memory use. Set this option to True to prevent that.
+        :param skip: The number of documents to skip before
+            returning.
+        :param sort: The order by which to sort results. Defaults to
+            None.
+        """
+        return GridOutCursor(self._collection, *args, **kwargs)
+
+    def open_download_stream_by_name(
+        self, filename: str, revision: int = -1, session: Optional[ClientSession] = None
+    ) -> GridOut:
+        """Opens a stream from which the application can read the contents of
+        `filename` and optional `revision`.
+
+        For example::
+
+          my_db = MongoClient().test
+          fs = GridFSBucket(my_db)
+          grid_out = fs.open_download_stream_by_name("test_file")
+          contents = grid_out.read()
+
+        Returns an instance of :class:`~gridfs.grid_file.GridOut`.
+
+        Raises :exc:`~gridfs.errors.NoFile` if no such version of
+        that file exists.
+
+        Raises :exc:`ValueError` if `filename` is not a string.
+
+        :param filename: The name of the file to read from.
+        :param revision: Which revision (documents with the same
+            filename and different uploadDate) of the file to retrieve.
+            Defaults to -1 (the most recent revision).
+        :param session: a
+            :class:`~pymongo.client_session.ClientSession`
+
+        :Note: Revision numbers are defined as follows:
+
+          - 0 = the original stored file
+          - 1 = the first revision
+          - 2 = the second revision
+          - etc...
+          - -2 = the second most recent revision
+          - -1 = the most recent revision
+
+        .. versionchanged:: 3.6
+           Added ``session`` parameter.
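+
+        For instance, to read back the original (oldest) stored version of a
+        file that has been uploaded more than once (a sketch, assuming ``fs``
+        is a :class:`GridFSBucket` as above)::
+
+          original = fs.open_download_stream_by_name("test_file", revision=0)
+          first_contents = original.read()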
+        """
+        validate_string("filename", filename)
+        query = {"filename": filename}
+        _disallow_transactions(session)
+        cursor = self._files.find(query, session=session)
+        if revision < 0:
+            skip = abs(revision) - 1
+            cursor.limit(-1).skip(skip).sort("uploadDate", DESCENDING)
+        else:
+            cursor.limit(-1).skip(revision).sort("uploadDate", ASCENDING)
+        try:
+            grid_file = next(cursor)
+            return GridOut(self._collection, file_document=grid_file, session=session)
+        except StopIteration:
+            raise NoFile("no version %d for filename %r" % (revision, filename)) from None
+
+    @_csot.apply
+    def download_to_stream_by_name(
+        self,
+        filename: str,
+        destination: Any,
+        revision: int = -1,
+        session: Optional[ClientSession] = None,
+    ) -> None:
+        """Write the contents of `filename` (with optional `revision`) to
+        `destination`.
+
+        For example::
+
+          my_db = MongoClient().test
+          fs = GridFSBucket(my_db)
+          # Get file to write to
+          file = open('myfile','wb')
+          fs.download_to_stream_by_name("test_file", file)
+
+        Raises :exc:`~gridfs.errors.NoFile` if no such version of
+        that file exists.
+
+        Raises :exc:`ValueError` if `filename` is not a string.
+
+        :param filename: The name of the file to read from.
+        :param destination: A file-like object that implements :meth:`write`.
+        :param revision: Which revision (documents with the same
+            filename and different uploadDate) of the file to retrieve.
+            Defaults to -1 (the most recent revision).
+        :param session: a
+            :class:`~pymongo.client_session.ClientSession`
+
+        :Note: Revision numbers are defined as follows:
+
+          - 0 = the original stored file
+          - 1 = the first revision
+          - 2 = the second revision
+          - etc...
+          - -2 = the second most recent revision
+          - -1 = the most recent revision
+
+        .. versionchanged:: 3.6
+           Added ``session`` parameter.
+        """
+        with self.open_download_stream_by_name(filename, revision, session=session) as gout:
+            while True:
+                chunk = gout.readchunk()
+                if not len(chunk):
+                    break
+                destination.write(chunk)
+
+    def rename(
+        self, file_id: Any, new_filename: str, session: Optional[ClientSession] = None
+    ) -> None:
+        """Renames the stored file with the specified file_id.
+
+        For example::
+
+          my_db = MongoClient().test
+          fs = GridFSBucket(my_db)
+          # Get _id of file to rename
+          file_id = fs.upload_from_stream("test_file", b"data I want to store!")
+          fs.rename(file_id, "new_test_name")
+
+        Raises :exc:`~gridfs.errors.NoFile` if no file with file_id exists.
+
+        :param file_id: The _id of the file to be renamed.
+        :param new_filename: The new name of the file.
+        :param session: a
+            :class:`~pymongo.client_session.ClientSession`
+
+        .. versionchanged:: 3.6
+           Added ``session`` parameter.
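+
+        Renaming only rewrites the ``filename`` field of the files
+        collection document; no chunks are touched, so the cost is the same
+        for any file size. A sketch with a fallback for a missing file
+        (assuming ``fs`` is a :class:`GridFSBucket` as above)::
+
+          from gridfs.errors import NoFile
+
+          try:
+              fs.rename(file_id, "new_test_name")
+          except NoFile:
+              print("no file to rename with _id %s" % file_id)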
+        """
+        _disallow_transactions(session)
+        result = self._files.update_one(
+            {"_id": file_id}, {"$set": {"filename": new_filename}}, session=session
+        )
+        if not result.matched_count:
+            raise NoFile(
+                "no files could be renamed %r because none "
+                "matched file_id %s" % (new_filename, file_id)
+            )
diff --git a/venv/Lib/site-packages/gridfs/__pycache__/__init__.cpython-312.pyc b/venv/Lib/site-packages/gridfs/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 00000000..8cdfcd07
Binary files /dev/null and b/venv/Lib/site-packages/gridfs/__pycache__/__init__.cpython-312.pyc differ
diff --git a/venv/Lib/site-packages/gridfs/__pycache__/errors.cpython-312.pyc b/venv/Lib/site-packages/gridfs/__pycache__/errors.cpython-312.pyc
new file mode 100644
index 00000000..f0b27447
Binary files /dev/null and b/venv/Lib/site-packages/gridfs/__pycache__/errors.cpython-312.pyc differ
diff --git a/venv/Lib/site-packages/gridfs/__pycache__/grid_file.cpython-312.pyc b/venv/Lib/site-packages/gridfs/__pycache__/grid_file.cpython-312.pyc
new file mode 100644
index 00000000..d3ee67b7
Binary files /dev/null and b/venv/Lib/site-packages/gridfs/__pycache__/grid_file.cpython-312.pyc differ
diff --git a/venv/Lib/site-packages/gridfs/errors.py b/venv/Lib/site-packages/gridfs/errors.py
new file mode 100644
index 00000000..e8c02cef
--- /dev/null
+++ b/venv/Lib/site-packages/gridfs/errors.py
@@ -0,0 +1,34 @@
+# Copyright 2009-2015 MongoDB, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Exceptions raised by the :mod:`gridfs` package"""
+from __future__ import annotations
+
+from pymongo.errors import PyMongoError
+
+
+class GridFSError(PyMongoError):
+    """Base class for all GridFS exceptions."""
+
+
+class CorruptGridFile(GridFSError):
+    """Raised when a file in :class:`~gridfs.GridFS` is malformed."""
+
+
+class NoFile(GridFSError):
+    """Raised when trying to read from a non-existent file."""
+
+
+class FileExists(GridFSError):
+    """Raised when trying to create a file that already exists."""
diff --git a/venv/Lib/site-packages/gridfs/grid_file.py b/venv/Lib/site-packages/gridfs/grid_file.py
new file mode 100644
index 00000000..ac72c144
--- /dev/null
+++ b/venv/Lib/site-packages/gridfs/grid_file.py
@@ -0,0 +1,964 @@
+# Copyright 2009-present MongoDB, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
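+
+# GridIn and GridOut below are the low-level pieces that GridFSBucket
+# composes; applications should normally go through GridFSBucket, but the
+# classes can be driven directly. A minimal round-trip sketch (assumes a
+# MongoDB server on localhost; ``db.fs`` is the conventional root
+# collection):
+#
+#   from pymongo import MongoClient
+#   from gridfs.grid_file import GridIn, GridOut
+#
+#   db = MongoClient().test
+#   with GridIn(db.fs, filename="hello.txt") as grid_in:
+#       grid_in.write(b"hello world")  # buffered; flushed on close
+#   grid_out = GridOut(db.fs, grid_in._id)
+#   assert grid_out.read() == b"hello world"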
+ +"""Tools for representing files stored in GridFS.""" +from __future__ import annotations + +import datetime +import io +import math +import os +import warnings +from typing import Any, Iterable, Mapping, NoReturn, Optional + +from bson.int64 import Int64 +from bson.objectid import ObjectId +from gridfs.errors import CorruptGridFile, FileExists, NoFile +from pymongo import ASCENDING +from pymongo.client_session import ClientSession +from pymongo.collection import Collection +from pymongo.common import MAX_MESSAGE_SIZE +from pymongo.cursor import Cursor +from pymongo.errors import ( + BulkWriteError, + ConfigurationError, + CursorNotFound, + DuplicateKeyError, + InvalidOperation, + OperationFailure, +) +from pymongo.helpers import _check_write_command_response +from pymongo.read_preferences import ReadPreference + +_SEEK_SET = os.SEEK_SET +_SEEK_CUR = os.SEEK_CUR +_SEEK_END = os.SEEK_END + +EMPTY = b"" +NEWLN = b"\n" + +"""Default chunk size, in bytes.""" +# Slightly under a power of 2, to work well with server's record allocations. +DEFAULT_CHUNK_SIZE = 255 * 1024 +# The number of chunked bytes to buffer before calling insert_many. +_UPLOAD_BUFFER_SIZE = MAX_MESSAGE_SIZE +# The number of chunk documents to buffer before calling insert_many. +_UPLOAD_BUFFER_CHUNKS = 100000 +# Rough BSON overhead of a chunk document not including the chunk data itself. +# Essentially len(encode({"_id": ObjectId(), "files_id": ObjectId(), "n": 1, "data": ""})) +_CHUNK_OVERHEAD = 60 + +_C_INDEX: dict[str, Any] = {"files_id": ASCENDING, "n": ASCENDING} +_F_INDEX: dict[str, Any] = {"filename": ASCENDING, "uploadDate": ASCENDING} + + +def _grid_in_property( + field_name: str, + docstring: str, + read_only: Optional[bool] = False, + closed_only: Optional[bool] = False, +) -> Any: + """Create a GridIn property.""" + warn_str = "" + if docstring.startswith("DEPRECATED,"): + warn_str = ( + f"GridIn property '{field_name}' is deprecated and will be removed in PyMongo 5.0" + ) + + def getter(self: Any) -> Any: + if warn_str: + warnings.warn(warn_str, stacklevel=2, category=DeprecationWarning) + if closed_only and not self._closed: + raise AttributeError("can only get %r on a closed file" % field_name) + # Protect against PHP-237 + if field_name == "length": + return self._file.get(field_name, 0) + return self._file.get(field_name, None) + + def setter(self: Any, value: Any) -> Any: + if warn_str: + warnings.warn(warn_str, stacklevel=2, category=DeprecationWarning) + if self._closed: + self._coll.files.update_one({"_id": self._file["_id"]}, {"$set": {field_name: value}}) + self._file[field_name] = value + + if read_only: + docstring += "\n\nThis attribute is read-only." 
+ elif closed_only: + docstring = "{}\n\n{}".format( + docstring, + "This attribute is read-only and " + "can only be read after :meth:`close` " + "has been called.", + ) + + if not read_only and not closed_only: + return property(getter, setter, doc=docstring) + return property(getter, doc=docstring) + + +def _grid_out_property(field_name: str, docstring: str) -> Any: + """Create a GridOut property.""" + warn_str = "" + if docstring.startswith("DEPRECATED,"): + warn_str = ( + f"GridOut property '{field_name}' is deprecated and will be removed in PyMongo 5.0" + ) + + def getter(self: Any) -> Any: + if warn_str: + warnings.warn(warn_str, stacklevel=2, category=DeprecationWarning) + self._ensure_file() + + # Protect against PHP-237 + if field_name == "length": + return self._file.get(field_name, 0) + return self._file.get(field_name, None) + + docstring += "\n\nThis attribute is read-only." + return property(getter, doc=docstring) + + +def _clear_entity_type_registry(entity: Any, **kwargs: Any) -> Any: + """Clear the given database/collection object's type registry.""" + codecopts = entity.codec_options.with_options(type_registry=None) + return entity.with_options(codec_options=codecopts, **kwargs) + + +def _disallow_transactions(session: Optional[ClientSession]) -> None: + if session and session.in_transaction: + raise InvalidOperation("GridFS does not support multi-document transactions") + + +class GridIn: + """Class to write data to GridFS.""" + + def __init__( + self, root_collection: Collection, session: Optional[ClientSession] = None, **kwargs: Any + ) -> None: + """Write a file to GridFS + + Application developers should generally not need to + instantiate this class directly - instead see the methods + provided by :class:`~gridfs.GridFS`. + + Raises :class:`TypeError` if `root_collection` is not an + instance of :class:`~pymongo.collection.Collection`. + + Any of the file level options specified in the `GridFS Spec + `_ may be passed as + keyword arguments. Any additional keyword arguments will be + set as additional fields on the file document. Valid keyword + arguments include: + + - ``"_id"``: unique ID for this file (default: + :class:`~bson.objectid.ObjectId`) - this ``"_id"`` must + not have already been used for another file + + - ``"filename"``: human name for the file + + - ``"contentType"`` or ``"content_type"``: valid mime-type + for the file + + - ``"chunkSize"`` or ``"chunk_size"``: size of each of the + chunks, in bytes (default: 255 kb) + + - ``"encoding"``: encoding used for this file. Any :class:`str` + that is written to the file will be converted to :class:`bytes`. + + :param root_collection: root collection to write to + :param session: a + :class:`~pymongo.client_session.ClientSession` to use for all + commands + :param kwargs: Any: file level options (see above) + + .. versionchanged:: 4.0 + Removed the `disable_md5` parameter. See + :ref:`removed-gridfs-checksum` for details. + + .. versionchanged:: 3.7 + Added the `disable_md5` parameter. + + .. versionchanged:: 3.6 + Added ``session`` parameter. + + .. 
versionchanged:: 3.0 + `root_collection` must use an acknowledged + :attr:`~pymongo.collection.Collection.write_concern` + """ + if not isinstance(root_collection, Collection): + raise TypeError("root_collection must be an instance of Collection") + + if not root_collection.write_concern.acknowledged: + raise ConfigurationError("root_collection must use acknowledged write_concern") + _disallow_transactions(session) + + # Handle alternative naming + if "content_type" in kwargs: + kwargs["contentType"] = kwargs.pop("content_type") + if "chunk_size" in kwargs: + kwargs["chunkSize"] = kwargs.pop("chunk_size") + + coll = _clear_entity_type_registry(root_collection, read_preference=ReadPreference.PRIMARY) + + # Defaults + kwargs["_id"] = kwargs.get("_id", ObjectId()) + kwargs["chunkSize"] = kwargs.get("chunkSize", DEFAULT_CHUNK_SIZE) + object.__setattr__(self, "_session", session) + object.__setattr__(self, "_coll", coll) + object.__setattr__(self, "_chunks", coll.chunks) + object.__setattr__(self, "_file", kwargs) + object.__setattr__(self, "_buffer", io.BytesIO()) + object.__setattr__(self, "_position", 0) + object.__setattr__(self, "_chunk_number", 0) + object.__setattr__(self, "_closed", False) + object.__setattr__(self, "_ensured_index", False) + object.__setattr__(self, "_buffered_docs", []) + object.__setattr__(self, "_buffered_docs_size", 0) + + def __create_index(self, collection: Collection, index_key: Any, unique: bool) -> None: + doc = collection.find_one(projection={"_id": 1}, session=self._session) + if doc is None: + try: + index_keys = [ + index_spec["key"] + for index_spec in collection.list_indexes(session=self._session) + ] + except OperationFailure: + index_keys = [] + if index_key not in index_keys: + collection.create_index(index_key.items(), unique=unique, session=self._session) + + def __ensure_indexes(self) -> None: + if not object.__getattribute__(self, "_ensured_index"): + _disallow_transactions(self._session) + self.__create_index(self._coll.files, _F_INDEX, False) + self.__create_index(self._coll.chunks, _C_INDEX, True) + object.__setattr__(self, "_ensured_index", True) + + def abort(self) -> None: + """Remove all chunks/files that may have been uploaded and close.""" + self._coll.chunks.delete_many({"files_id": self._file["_id"]}, session=self._session) + self._coll.files.delete_one({"_id": self._file["_id"]}, session=self._session) + object.__setattr__(self, "_closed", True) + + @property + def closed(self) -> bool: + """Is this file closed?""" + return self._closed + + _id: Any = _grid_in_property("_id", "The ``'_id'`` value for this file.", read_only=True) + filename: Optional[str] = _grid_in_property("filename", "Name of this file.") + name: Optional[str] = _grid_in_property("filename", "Alias for `filename`.") + content_type: Optional[str] = _grid_in_property( + "contentType", "DEPRECATED, will be removed in PyMongo 5.0. Mime-type for this file." + ) + length: int = _grid_in_property("length", "Length (in bytes) of this file.", closed_only=True) + chunk_size: int = _grid_in_property("chunkSize", "Chunk size for this file.", read_only=True) + upload_date: datetime.datetime = _grid_in_property( + "uploadDate", "Date that this file was uploaded.", closed_only=True + ) + md5: Optional[str] = _grid_in_property( + "md5", + "DEPRECATED, will be removed in PyMongo 5.0. 
MD5 of the contents of this file if an md5 sum was created.", + closed_only=True, + ) + + _buffer: io.BytesIO + _closed: bool + _buffered_docs: list[dict[str, Any]] + _buffered_docs_size: int + + def __getattr__(self, name: str) -> Any: + if name in self._file: + return self._file[name] + raise AttributeError("GridIn object has no attribute '%s'" % name) + + def __setattr__(self, name: str, value: Any) -> None: + # For properties of this instance like _buffer, or descriptors set on + # the class like filename, use regular __setattr__ + if name in self.__dict__ or name in self.__class__.__dict__: + object.__setattr__(self, name, value) + else: + # All other attributes are part of the document in db.fs.files. + # Store them to be sent to server on close() or if closed, send + # them now. + self._file[name] = value + if self._closed: + self._coll.files.update_one({"_id": self._file["_id"]}, {"$set": {name: value}}) + + def __flush_data(self, data: Any, force: bool = False) -> None: + """Flush `data` to a chunk.""" + self.__ensure_indexes() + assert len(data) <= self.chunk_size + if data: + self._buffered_docs.append( + {"files_id": self._file["_id"], "n": self._chunk_number, "data": data} + ) + self._buffered_docs_size += len(data) + _CHUNK_OVERHEAD + if not self._buffered_docs: + return + # Limit to 100,000 chunks or 32MB (+1 chunk) of data. + if ( + force + or self._buffered_docs_size >= _UPLOAD_BUFFER_SIZE + or len(self._buffered_docs) >= _UPLOAD_BUFFER_CHUNKS + ): + try: + self._chunks.insert_many(self._buffered_docs, session=self._session) + except BulkWriteError as exc: + # For backwards compatibility, raise an insert_one style exception. + write_errors = exc.details["writeErrors"] + for err in write_errors: + if err.get("code") in (11000, 11001, 12582): # Duplicate key errors + self._raise_file_exists(self._file["_id"]) + result = {"writeErrors": write_errors} + wces = exc.details["writeConcernErrors"] + if wces: + result["writeConcernError"] = wces[-1] + _check_write_command_response(result) + raise + self._buffered_docs = [] + self._buffered_docs_size = 0 + self._chunk_number += 1 + self._position += len(data) + + def __flush_buffer(self, force: bool = False) -> None: + """Flush the buffer contents out to a chunk.""" + self.__flush_data(self._buffer.getvalue(), force=force) + self._buffer.close() + self._buffer = io.BytesIO() + + def __flush(self) -> Any: + """Flush the file to the database.""" + try: + self.__flush_buffer(force=True) + # The GridFS spec says length SHOULD be an Int64. + self._file["length"] = Int64(self._position) + self._file["uploadDate"] = datetime.datetime.now(tz=datetime.timezone.utc) + + return self._coll.files.insert_one(self._file, session=self._session) + except DuplicateKeyError: + self._raise_file_exists(self._id) + + def _raise_file_exists(self, file_id: Any) -> NoReturn: + """Raise a FileExists exception for the given file_id.""" + raise FileExists("file with _id %r already exists" % file_id) + + def close(self) -> None: + """Flush the file and close it. + + A closed file cannot be written any more. Calling + :meth:`close` more than once is allowed. + """ + if not self._closed: + self.__flush() + object.__setattr__(self, "_closed", True) + + def read(self, size: int = -1) -> NoReturn: + raise io.UnsupportedOperation("read") + + def readable(self) -> bool: + return False + + def seekable(self) -> bool: + return False + + def write(self, data: Any) -> None: + """Write data to the file. There is no return value. 
+ + `data` can be either a string of bytes or a file-like object + (implementing :meth:`read`). If the file has an + :attr:`encoding` attribute, `data` can also be a + :class:`str` instance, which will be encoded as + :attr:`encoding` before being written. + + Due to buffering, the data may not actually be written to the + database until the :meth:`close` method is called. Raises + :class:`ValueError` if this file is already closed. Raises + :class:`TypeError` if `data` is not an instance of + :class:`bytes`, a file-like object, or an instance of :class:`str`. + Unicode data is only allowed if the file has an :attr:`encoding` + attribute. + + :param data: string of bytes or file-like object to be written + to the file + """ + if self._closed: + raise ValueError("cannot write to a closed file") + + try: + # file-like + read = data.read + except AttributeError: + # string + if not isinstance(data, (str, bytes)): + raise TypeError("can only write strings or file-like objects") from None + if isinstance(data, str): + try: + data = data.encode(self.encoding) + except AttributeError: + raise TypeError( + "must specify an encoding for file in order to write str" + ) from None + read = io.BytesIO(data).read + + if self._buffer.tell() > 0: + # Make sure to flush only when _buffer is complete + space = self.chunk_size - self._buffer.tell() + if space: + try: + to_write = read(space) + except BaseException: + self.abort() + raise + self._buffer.write(to_write) + if len(to_write) < space: + return # EOF or incomplete + self.__flush_buffer() + to_write = read(self.chunk_size) + while to_write and len(to_write) == self.chunk_size: + self.__flush_data(to_write) + to_write = read(self.chunk_size) + self._buffer.write(to_write) + + def writelines(self, sequence: Iterable[Any]) -> None: + """Write a sequence of strings to the file. + + Does not add separators. + """ + for line in sequence: + self.write(line) + + def writeable(self) -> bool: + return True + + def __enter__(self) -> GridIn: + """Support for the context manager protocol.""" + return self + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> Any: + """Support for the context manager protocol. + + Close the file if no exceptions occur and allow exceptions to propagate. + """ + if exc_type is None: + # No exceptions happened. + self.close() + else: + # Something happened, at minimum mark as closed. + object.__setattr__(self, "_closed", True) + + # propagate exceptions + return False + + +class GridOut(io.IOBase): + """Class to read data out of GridFS.""" + + def __init__( + self, + root_collection: Collection, + file_id: Optional[int] = None, + file_document: Optional[Any] = None, + session: Optional[ClientSession] = None, + ) -> None: + """Read a file from GridFS + + Application developers should generally not need to + instantiate this class directly - instead see the methods + provided by :class:`~gridfs.GridFS`. + + Either `file_id` or `file_document` must be specified, + `file_document` will be given priority if present. Raises + :class:`TypeError` if `root_collection` is not an instance of + :class:`~pymongo.collection.Collection`. + + :param root_collection: root collection to read from + :param file_id: value of ``"_id"`` for the file to read + :param file_document: file document from + `root_collection.files` + :param session: a + :class:`~pymongo.client_session.ClientSession` to use for all + commands + + .. 
versionchanged:: 3.8 + For better performance and to better follow the GridFS spec, + :class:`GridOut` now uses a single cursor to read all the chunks in + the file. + + .. versionchanged:: 3.6 + Added ``session`` parameter. + + .. versionchanged:: 3.0 + Creating a GridOut does not immediately retrieve the file metadata + from the server. Metadata is fetched when first needed. + """ + if not isinstance(root_collection, Collection): + raise TypeError("root_collection must be an instance of Collection") + _disallow_transactions(session) + + root_collection = _clear_entity_type_registry(root_collection) + + super().__init__() + + self.__chunks = root_collection.chunks + self.__files = root_collection.files + self.__file_id = file_id + self.__buffer = EMPTY + # Start position within the current buffered chunk. + self.__buffer_pos = 0 + self.__chunk_iter = None + # Position within the total file. + self.__position = 0 + self._file = file_document + self._session = session + + _id: Any = _grid_out_property("_id", "The ``'_id'`` value for this file.") + filename: str = _grid_out_property("filename", "Name of this file.") + name: str = _grid_out_property("filename", "Alias for `filename`.") + content_type: Optional[str] = _grid_out_property( + "contentType", "DEPRECATED, will be removed in PyMongo 5.0. Mime-type for this file." + ) + length: int = _grid_out_property("length", "Length (in bytes) of this file.") + chunk_size: int = _grid_out_property("chunkSize", "Chunk size for this file.") + upload_date: datetime.datetime = _grid_out_property( + "uploadDate", "Date that this file was first uploaded." + ) + aliases: Optional[list[str]] = _grid_out_property( + "aliases", "DEPRECATED, will be removed in PyMongo 5.0. List of aliases for this file." + ) + metadata: Optional[Mapping[str, Any]] = _grid_out_property( + "metadata", "Metadata attached to this file." + ) + md5: Optional[str] = _grid_out_property( + "md5", + "DEPRECATED, will be removed in PyMongo 5.0. MD5 of the contents of this file if an md5 sum was created.", + ) + + _file: Any + __chunk_iter: Any + + def _ensure_file(self) -> None: + if not self._file: + _disallow_transactions(self._session) + self._file = self.__files.find_one({"_id": self.__file_id}, session=self._session) + if not self._file: + raise NoFile( + f"no file in gridfs collection {self.__files!r} with _id {self.__file_id!r}" + ) + + def __getattr__(self, name: str) -> Any: + self._ensure_file() + if name in self._file: + return self._file[name] + raise AttributeError("GridOut object has no attribute '%s'" % name) + + def readable(self) -> bool: + return True + + def readchunk(self) -> bytes: + """Reads a chunk at a time. If the current position is within a + chunk the remainder of the chunk is returned. 
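+
+        Reading chunk by chunk keeps at most one chunk's worth of data in
+        memory, which is how the bucket's ``download_to_stream`` helpers
+        stream files. A sketch (assuming ``fs`` is a
+        :class:`~gridfs.GridFSBucket` and ``destination`` is a writable
+        file-like object)::
+
+          with fs.open_download_stream(file_id) as grid_out:
+              while True:
+                  chunk = grid_out.readchunk()
+                  if not chunk:
+                      break
+                  destination.write(chunk)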
+ """ + received = len(self.__buffer) - self.__buffer_pos + chunk_data = EMPTY + chunk_size = int(self.chunk_size) + + if received > 0: + chunk_data = self.__buffer[self.__buffer_pos :] + elif self.__position < int(self.length): + chunk_number = int((received + self.__position) / chunk_size) + if self.__chunk_iter is None: + self.__chunk_iter = _GridOutChunkIterator( + self, self.__chunks, self._session, chunk_number + ) + + chunk = self.__chunk_iter.next() + chunk_data = chunk["data"][self.__position % chunk_size :] + + if not chunk_data: + raise CorruptGridFile("truncated chunk") + + self.__position += len(chunk_data) + self.__buffer = EMPTY + self.__buffer_pos = 0 + return chunk_data + + def _read_size_or_line(self, size: int = -1, line: bool = False) -> bytes: + """Internal read() and readline() helper.""" + self._ensure_file() + remainder = int(self.length) - self.__position + if size < 0 or size > remainder: + size = remainder + + if size == 0: + return EMPTY + + received = 0 + data = [] + while received < size: + needed = size - received + if self.__buffer: + # Optimization: Read the buffer with zero byte copies. + buf = self.__buffer + chunk_start = self.__buffer_pos + chunk_data = memoryview(buf)[self.__buffer_pos :] + self.__buffer = EMPTY + self.__buffer_pos = 0 + self.__position += len(chunk_data) + else: + buf = self.readchunk() + chunk_start = 0 + chunk_data = memoryview(buf) + if line: + pos = buf.find(NEWLN, chunk_start, chunk_start + needed) - chunk_start + if pos >= 0: + # Decrease size to exit the loop. + size = received + pos + 1 + needed = pos + 1 + if len(chunk_data) > needed: + data.append(chunk_data[:needed]) + # Optimization: Save the buffer with zero byte copies. + self.__buffer = buf + self.__buffer_pos = chunk_start + needed + self.__position -= len(self.__buffer) - self.__buffer_pos + else: + data.append(chunk_data) + received += len(chunk_data) + + # Detect extra chunks after reading the entire file. + if size == remainder and self.__chunk_iter: + try: + self.__chunk_iter.next() + except StopIteration: + pass + + return b"".join(data) + + def read(self, size: int = -1) -> bytes: + """Read at most `size` bytes from the file (less if there + isn't enough data). + + The bytes are returned as an instance of :class:`bytes` + If `size` is negative or omitted all data is read. + + :param size: the number of bytes to read + + .. versionchanged:: 3.8 + This method now only checks for extra chunks after reading the + entire file. Previously, this method would check for extra chunks + on every call. + """ + return self._read_size_or_line(size=size) + + def readline(self, size: int = -1) -> bytes: # type: ignore[override] + """Read one line or up to `size` bytes from the file. + + :param size: the maximum number of bytes to read + """ + return self._read_size_or_line(size=size, line=True) + + def tell(self) -> int: + """Return the current position of this file.""" + return self.__position + + def seek(self, pos: int, whence: int = _SEEK_SET) -> int: + """Set the current position of this file. + + :param pos: the position (or offset if using relative + positioning) to seek to + :param whence: where to seek + from. :attr:`os.SEEK_SET` (``0``) for absolute file + positioning, :attr:`os.SEEK_CUR` (``1``) to seek relative + to the current position, :attr:`os.SEEK_END` (``2``) to + seek relative to the file's end. + + .. versionchanged:: 4.1 + The method now returns the new position in the file, to + conform to the behavior of :meth:`io.IOBase.seek`. 
+ """ + if whence == _SEEK_SET: + new_pos = pos + elif whence == _SEEK_CUR: + new_pos = self.__position + pos + elif whence == _SEEK_END: + new_pos = int(self.length) + pos + else: + raise OSError(22, "Invalid value for `whence`") + + if new_pos < 0: + raise OSError(22, "Invalid value for `pos` - must be positive") + + # Optimization, continue using the same buffer and chunk iterator. + if new_pos == self.__position: + return new_pos + + self.__position = new_pos + self.__buffer = EMPTY + self.__buffer_pos = 0 + if self.__chunk_iter: + self.__chunk_iter.close() + self.__chunk_iter = None + return new_pos + + def seekable(self) -> bool: + return True + + def __iter__(self) -> GridOut: + """Return an iterator over all of this file's data. + + The iterator will return lines (delimited by ``b'\\n'``) of + :class:`bytes`. This can be useful when serving files + using a webserver that handles such an iterator efficiently. + + .. versionchanged:: 3.8 + The iterator now raises :class:`CorruptGridFile` when encountering + any truncated, missing, or extra chunk in a file. The previous + behavior was to only raise :class:`CorruptGridFile` on a missing + chunk. + + .. versionchanged:: 4.0 + The iterator now iterates over *lines* in the file, instead + of chunks, to conform to the base class :py:class:`io.IOBase`. + Use :meth:`GridOut.readchunk` to read chunk by chunk instead + of line by line. + """ + return self + + def close(self) -> None: + """Make GridOut more generically file-like.""" + if self.__chunk_iter: + self.__chunk_iter.close() + self.__chunk_iter = None + super().close() + + def write(self, value: Any) -> NoReturn: + raise io.UnsupportedOperation("write") + + def writelines(self, lines: Any) -> NoReturn: + raise io.UnsupportedOperation("writelines") + + def writable(self) -> bool: + return False + + def __enter__(self) -> GridOut: + """Makes it possible to use :class:`GridOut` files + with the context manager protocol. + """ + return self + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> Any: + """Makes it possible to use :class:`GridOut` files + with the context manager protocol. + """ + self.close() + return False + + def fileno(self) -> NoReturn: + raise io.UnsupportedOperation("fileno") + + def flush(self) -> None: + # GridOut is read-only, so flush does nothing. + pass + + def isatty(self) -> bool: + return False + + def truncate(self, size: Optional[int] = None) -> NoReturn: + # See https://docs.python.org/3/library/io.html#io.IOBase.writable + # for why truncate has to raise. + raise io.UnsupportedOperation("truncate") + + # Override IOBase.__del__ otherwise it will lead to __getattr__ on + # __IOBase_closed which calls _ensure_file and potentially performs I/O. + # We cannot do I/O in __del__ since it can lead to a deadlock. + def __del__(self) -> None: + pass + + +class _GridOutChunkIterator: + """Iterates over a file's chunks using a single cursor. + + Raises CorruptGridFile when encountering any truncated, missing, or extra + chunk in a file. 
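+
+    The iterator expects exactly ``ceil(length / chunk_size)`` chunks, each
+    of ``chunk_size`` bytes except possibly the last. For example, a
+    600 KiB file stored with the default 255 KiB chunk size must arrive as
+    three chunks of 255 KiB, 255 KiB, and 90 KiB; any other count or sizing
+    raises CorruptGridFile.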
+ """ + + def __init__( + self, + grid_out: GridOut, + chunks: Collection, + session: Optional[ClientSession], + next_chunk: Any, + ) -> None: + self._id = grid_out._id + self._chunk_size = int(grid_out.chunk_size) + self._length = int(grid_out.length) + self._chunks = chunks + self._session = session + self._next_chunk = next_chunk + self._num_chunks = math.ceil(float(self._length) / self._chunk_size) + self._cursor = None + + _cursor: Optional[Cursor] + + def expected_chunk_length(self, chunk_n: int) -> int: + if chunk_n < self._num_chunks - 1: + return self._chunk_size + return self._length - (self._chunk_size * (self._num_chunks - 1)) + + def __iter__(self) -> _GridOutChunkIterator: + return self + + def _create_cursor(self) -> None: + filter = {"files_id": self._id} + if self._next_chunk > 0: + filter["n"] = {"$gte": self._next_chunk} + _disallow_transactions(self._session) + self._cursor = self._chunks.find(filter, sort=[("n", 1)], session=self._session) + + def _next_with_retry(self) -> Mapping[str, Any]: + """Return the next chunk and retry once on CursorNotFound. + + We retry on CursorNotFound to maintain backwards compatibility in + cases where two calls to read occur more than 10 minutes apart (the + server's default cursor timeout). + """ + if self._cursor is None: + self._create_cursor() + assert self._cursor is not None + try: + return self._cursor.next() + except CursorNotFound: + self._cursor.close() + self._create_cursor() + return self._cursor.next() + + def next(self) -> Mapping[str, Any]: + try: + chunk = self._next_with_retry() + except StopIteration: + if self._next_chunk >= self._num_chunks: + raise + raise CorruptGridFile("no chunk #%d" % self._next_chunk) from None + + if chunk["n"] != self._next_chunk: + self.close() + raise CorruptGridFile( + "Missing chunk: expected chunk #%d but found " + "chunk with n=%d" % (self._next_chunk, chunk["n"]) + ) + + if chunk["n"] >= self._num_chunks: + # According to spec, ignore extra chunks if they are empty. + if len(chunk["data"]): + self.close() + raise CorruptGridFile( + "Extra chunk found: expected %d chunks but found " + "chunk with n=%d" % (self._num_chunks, chunk["n"]) + ) + + expected_length = self.expected_chunk_length(chunk["n"]) + if len(chunk["data"]) != expected_length: + self.close() + raise CorruptGridFile( + "truncated chunk #%d: expected chunk length to be %d but " + "found chunk with length %d" % (chunk["n"], expected_length, len(chunk["data"])) + ) + + self._next_chunk += 1 + return chunk + + __next__ = next + + def close(self) -> None: + if self._cursor: + self._cursor.close() + self._cursor = None + + +class GridOutIterator: + def __init__(self, grid_out: GridOut, chunks: Collection, session: ClientSession): + self.__chunk_iter = _GridOutChunkIterator(grid_out, chunks, session, 0) + + def __iter__(self) -> GridOutIterator: + return self + + def next(self) -> bytes: + chunk = self.__chunk_iter.next() + return bytes(chunk["data"]) + + __next__ = next + + +class GridOutCursor(Cursor): + """A cursor / iterator for returning GridOut objects as the result + of an arbitrary query against the GridFS files collection. + """ + + def __init__( + self, + collection: Collection, + filter: Optional[Mapping[str, Any]] = None, + skip: int = 0, + limit: int = 0, + no_cursor_timeout: bool = False, + sort: Optional[Any] = None, + batch_size: int = 0, + session: Optional[ClientSession] = None, + ) -> None: + """Create a new cursor, similar to the normal + :class:`~pymongo.cursor.Cursor`. 
+
+        Should not be called directly by application developers - see
+        the :class:`~gridfs.GridFS` method :meth:`~gridfs.GridFS.find` instead.
+
+        .. versionadded:: 2.7
+
+        .. seealso:: The MongoDB documentation on cursors.
+        """
+        _disallow_transactions(session)
+        collection = _clear_entity_type_registry(collection)
+
+        # Hold on to the base "fs" collection to create GridOut objects later.
+        self.__root_collection = collection
+
+        super().__init__(
+            collection.files,
+            filter,
+            skip=skip,
+            limit=limit,
+            no_cursor_timeout=no_cursor_timeout,
+            sort=sort,
+            batch_size=batch_size,
+            session=session,
+        )
+
+    def next(self) -> GridOut:
+        """Get next GridOut object from cursor."""
+        _disallow_transactions(self.session)
+        next_file = super().next()
+        return GridOut(self.__root_collection, file_document=next_file, session=self.session)
+
+    __next__ = next
+
+    def add_option(self, *args: Any, **kwargs: Any) -> NoReturn:
+        raise NotImplementedError("Method does not exist for GridOutCursor")
+
+    def remove_option(self, *args: Any, **kwargs: Any) -> NoReturn:
+        raise NotImplementedError("Method does not exist for GridOutCursor")
+
+    def _clone_base(self, session: Optional[ClientSession]) -> GridOutCursor:
+        """Creates an empty GridOutCursor for information to be copied into."""
+        return GridOutCursor(self.__root_collection, session=session)
diff --git a/venv/Lib/site-packages/gridfs/py.typed b/venv/Lib/site-packages/gridfs/py.typed
new file mode 100644
index 00000000..0f405706
--- /dev/null
+++ b/venv/Lib/site-packages/gridfs/py.typed
@@ -0,0 +1,2 @@
+# PEP-561 Support File.
+# "Package maintainers who wish to support type checking of their code MUST add a marker file named py.typed to their package supporting typing".
diff --git a/venv/Lib/site-packages/pymongo-4.7.2.dist-info/INSTALLER b/venv/Lib/site-packages/pymongo-4.7.2.dist-info/INSTALLER
new file mode 100644
index 00000000..a1b589e3
--- /dev/null
+++ b/venv/Lib/site-packages/pymongo-4.7.2.dist-info/INSTALLER
@@ -0,0 +1 @@
+pip
diff --git a/venv/Lib/site-packages/pymongo-4.7.2.dist-info/LICENSE b/venv/Lib/site-packages/pymongo-4.7.2.dist-info/LICENSE
new file mode 100644
index 00000000..261eeb9e
--- /dev/null
+++ b/venv/Lib/site-packages/pymongo-4.7.2.dist-info/LICENSE
@@ -0,0 +1,201 @@
+                                 Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+   1. Definitions.
+
+      "License" shall mean the terms and conditions for use, reproduction,
+      and distribution as defined by Sections 1 through 9 of this document.
+
+      "Licensor" shall mean the copyright owner or entity authorized by
+      the copyright owner that is granting the License.
+
+      "Legal Entity" shall mean the union of the acting entity and all
+      other entities that control, are controlled by, or are under common
+      control with that entity. For the purposes of this definition,
+      "control" means (i) the power, direct or indirect, to cause the
+      direction or management of such entity, whether by contract or
+      otherwise, or (ii) ownership of fifty percent (50%) or more of the
+      outstanding shares, or (iii) beneficial ownership of such entity.
+
+      "You" (or "Your") shall mean an individual or Legal Entity
+      exercising permissions granted by this License.
+
+      "Source" form shall mean the preferred form for making modifications,
+      including but not limited to software source code, documentation
+      source, and configuration files.
+ + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/venv/Lib/site-packages/pymongo-4.7.2.dist-info/METADATA b/venv/Lib/site-packages/pymongo-4.7.2.dist-info/METADATA new file mode 100644 index 00000000..da000892 --- /dev/null +++ b/venv/Lib/site-packages/pymongo-4.7.2.dist-info/METADATA @@ -0,0 +1,485 @@ +Metadata-Version: 2.1 +Name: pymongo +Version: 4.7.2 +Summary: Python driver for MongoDB +Author: The MongoDB Python Team +License: Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. 
For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. 
This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+
+Project-URL: Homepage, https://www.mongodb.org
+Project-URL: Documentation, https://pymongo.readthedocs.io
+Project-URL: Source, https://github.com/mongodb/mongo-python-driver
+Project-URL: Tracker, https://jira.mongodb.org/projects/PYTHON/issues
+Keywords: bson,gridfs,mongo,mongodb,pymongo
+Classifier: Development Status :: 5 - Production/Stable
+Classifier: Intended Audience :: Developers
+Classifier: License :: OSI Approved :: Apache Software License
+Classifier: Operating System :: MacOS :: MacOS X
+Classifier: Operating System :: Microsoft :: Windows
+Classifier: Operating System :: POSIX
+Classifier: Programming Language :: Python :: Implementation :: CPython
+Classifier: Programming Language :: Python :: Implementation :: PyPy
+Classifier: Programming Language :: Python :: 3
+Classifier: Programming Language :: Python :: 3 :: Only
+Classifier: Programming Language :: Python :: 3.7
+Classifier: Programming Language :: Python :: 3.8
+Classifier: Programming Language :: Python :: 3.9
+Classifier: Programming Language :: Python :: 3.10
+Classifier: Programming Language :: Python :: 3.11
+Classifier: Programming Language :: Python :: 3.12
+Classifier: Topic :: Database
+Classifier: Typing :: Typed
+Requires-Python: >=3.7
+Description-Content-Type: text/markdown
+License-File: LICENSE
+Requires-Dist: dnspython <3.0.0,>=1.16.0
+Provides-Extra: aws
+Requires-Dist: pymongo-auth-aws <2.0.0,>=1.1.0 ; extra == 'aws'
+Provides-Extra: encryption
+Requires-Dist: pymongo-auth-aws <2.0.0,>=1.1.0 ; extra == 'encryption'
+Requires-Dist: pymongocrypt <2.0.0,>=1.6.0 ; extra == 'encryption'
+Requires-Dist: certifi ; (os_name == "nt" or sys_platform == "darwin") and extra == 'encryption'
+Provides-Extra: gssapi
+Requires-Dist: pykerberos ; (os_name != "nt") and extra == 'gssapi'
+Requires-Dist: winkerberos >=0.5.0 ; (os_name == "nt") and extra == 'gssapi'
+Provides-Extra: ocsp
+Requires-Dist: pyopenssl >=17.2.0 ; extra == 'ocsp'
+Requires-Dist: requests <3.0.0 ; extra == 'ocsp'
+Requires-Dist: cryptography >=2.5 ; extra == 'ocsp'
+Requires-Dist: service-identity >=18.1.0 ; extra == 'ocsp'
+Requires-Dist: certifi ; (os_name == "nt" or sys_platform == "darwin") and extra == 'ocsp'
+Provides-Extra: snappy
+Requires-Dist: python-snappy ; extra == 'snappy'
+Provides-Extra: srv
+Provides-Extra: test
+Requires-Dist: pytest >=7 ; extra == 'test'
+Provides-Extra: tls
+Provides-Extra: zstd
+Requires-Dist: zstandard ; extra == 'zstd'
+
+# PyMongo
+
+[![PyPI Version](https://img.shields.io/pypi/v/pymongo)](https://pypi.org/project/pymongo)
+[![Python Versions](https://img.shields.io/pypi/pyversions/pymongo)](https://pypi.org/project/pymongo)
+[![Monthly Downloads](https://static.pepy.tech/badge/pymongo/month)](https://pepy.tech/project/pymongo)
+[![Documentation Status](https://readthedocs.org/projects/pymongo/badge/?version=stable)](http://pymongo.readthedocs.io/en/stable/?badge=stable)
+
+## About
+
+The PyMongo distribution contains tools for interacting with a MongoDB
+database from Python. The `bson` package is an implementation of the
+[BSON format](http://bsonspec.org) for Python. The `pymongo` package is
+a native Python driver for MongoDB. The `gridfs` package is a
+[gridfs](https://github.com/mongodb/specifications/blob/master/source/gridfs/gridfs-spec.rst/)
+implementation on top of `pymongo`.
+
+PyMongo supports MongoDB 3.6, 4.0, 4.2, 4.4, 5.0, 6.0, and 7.0.
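+
+As a minimal sketch of how the three packages fit together (this example
+assumes a local `mongod` listening on the default port; the payload and
+filename are purely illustrative):
+
+```python
+import bson
+import gridfs
+import pymongo
+
+# pymongo: connect to a running MongoDB server.
+client = pymongo.MongoClient("localhost", 27017)
+db = client.test
+
+# bson: standalone BSON encoding and decoding.
+raw = bson.encode({"x": 1})
+assert bson.decode(raw) == {"x": 1}
+
+# gridfs: store and retrieve a binary blob on top of pymongo.
+fs = gridfs.GridFS(db)
+file_id = fs.put(b"example payload", filename="example.bin")
+assert fs.get(file_id).read() == b"example payload"
+```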
+ +## Support / Feedback + +For issues with, questions about, or feedback for PyMongo, please look +into our [support channels](https://support.mongodb.com/welcome). Please +do not email any of the PyMongo developers directly with issues or +questions - you're more likely to get an answer on +[StackOverflow](https://stackoverflow.com/questions/tagged/mongodb) +(using a "mongodb" tag). + +## Bugs / Feature Requests + +Think you've found a bug? Want to see a new feature in PyMongo? Please +open a case in our issue management tool, JIRA: + +- [Create an account and login](https://jira.mongodb.org). +- Navigate to [the PYTHON + project](https://jira.mongodb.org/browse/PYTHON). +- Click **Create Issue** - Please provide as much information as + possible about the issue type and how to reproduce it. + +Bug reports in JIRA for all driver projects (i.e. PYTHON, CSHARP, JAVA) +and the Core Server (i.e. SERVER) project are **public**. + +### How To Ask For Help + +Please include all of the following information when opening an issue: + +- Detailed steps to reproduce the problem, including full traceback, + if possible. + +- The exact python version used, with patch level: + +```bash +python -c "import sys; print(sys.version)" +``` + +- The exact version of PyMongo used, with patch level: + +```bash +python -c "import pymongo; print(pymongo.version); print(pymongo.has_c())" +``` + +- The operating system and version (e.g. Windows 7, OSX 10.8, ...) + +- Web framework or asynchronous network library used, if any, with + version (e.g. Django 1.7, mod_wsgi 4.3.0, gevent 1.0.1, Tornado + 4.0.2, ...) + +### Security Vulnerabilities + +If you've identified a security vulnerability in a driver or any other +MongoDB project, please report it according to the [instructions +here](https://www.mongodb.com/docs/manual/tutorial/create-a-vulnerability-report/). + +## Installation + +PyMongo can be installed with [pip](http://pypi.python.org/pypi/pip): + +```bash +python -m pip install pymongo +``` + +Or `easy_install` from [setuptools](http://pypi.python.org/pypi/setuptools): + +```bash +python -m easy_install pymongo +``` + +You can also download the project source and do: + +```bash +pip install . +``` + +Do **not** install the "bson" package from pypi. PyMongo comes with +its own bson package; running "pip install bson" installs a third-party +package that is incompatible with PyMongo. + +## Dependencies + +PyMongo supports CPython 3.7+ and PyPy3.7+. + +Required dependencies: + +Support for `mongodb+srv://` URIs requires [dnspython](https://pypi.python.org/pypi/dnspython) + +Optional dependencies: + +GSSAPI authentication requires +[pykerberos](https://pypi.python.org/pypi/pykerberos) on Unix or +[WinKerberos](https://pypi.python.org/pypi/winkerberos) on Windows. 
The +correct dependency can be installed automatically along with PyMongo: + +```bash +python -m pip install "pymongo[gssapi]" +``` + +MONGODB-AWS authentication requires +[pymongo-auth-aws](https://pypi.org/project/pymongo-auth-aws/): + +```bash +python -m pip install "pymongo[aws]" +``` + +OCSP (Online Certificate Status Protocol) requires +[PyOpenSSL](https://pypi.org/project/pyOpenSSL/), +[requests](https://pypi.org/project/requests/), +[service_identity](https://pypi.org/project/service_identity/) and may +require [certifi](https://pypi.python.org/pypi/certifi): + +```bash +python -m pip install "pymongo[ocsp]" +``` + +Wire protocol compression with snappy requires +[python-snappy](https://pypi.org/project/python-snappy): + +```bash +python -m pip install "pymongo[snappy]" +``` + +Wire protocol compression with zstandard requires +[zstandard](https://pypi.org/project/zstandard): + +```bash +python -m pip install "pymongo[zstd]" +``` + +Client-Side Field Level Encryption requires +[pymongocrypt](https://pypi.org/project/pymongocrypt/) and +[pymongo-auth-aws](https://pypi.org/project/pymongo-auth-aws/): + +```bash +python -m pip install "pymongo[encryption]" +``` +You can install all dependencies automatically with the following +command: + +```bash +python -m pip install "pymongo[gssapi,aws,ocsp,snappy,zstd,encryption]" +``` + +Additional dependencies are: + +- (to generate documentation or run tests) + [tox](https://tox.wiki/en/latest/index.html) + +## Examples + +Here's a basic example (for more see the *examples* section of the +docs): + +```pycon +>>> import pymongo +>>> client = pymongo.MongoClient("localhost", 27017) +>>> db = client.test +>>> db.name +'test' +>>> db.my_collection +Collection(Database(MongoClient('localhost', 27017), 'test'), 'my_collection') +>>> db.my_collection.insert_one({"x": 10}).inserted_id +ObjectId('4aba15ebe23f6b53b0000000') +>>> db.my_collection.insert_one({"x": 8}).inserted_id +ObjectId('4aba160ee23f6b543e000000') +>>> db.my_collection.insert_one({"x": 11}).inserted_id +ObjectId('4aba160ee23f6b543e000002') +>>> db.my_collection.find_one() +{'x': 10, '_id': ObjectId('4aba15ebe23f6b53b0000000')} +>>> for item in db.my_collection.find(): +... print(item["x"]) +... +10 +8 +11 +>>> db.my_collection.create_index("x") +'x_1' +>>> for item in db.my_collection.find().sort("x", pymongo.ASCENDING): +... print(item["x"]) +... +8 +10 +11 +>>> [item["x"] for item in db.my_collection.find().limit(2).skip(1)] +[8, 11] +``` + +## Documentation + +Documentation is available at +[pymongo.readthedocs.io](https://pymongo.readthedocs.io/en/stable/). + +Documentation can be generated by running **tox -m doc**. Generated +documentation can be found in the `doc/build/html/` directory. + +## Learning Resources + +- MongoDB Learn - [Python +courses](https://learn.mongodb.com/catalog?labels=%5B%22Language%22%5D&values=%5B%22Python%22%5D). +- [Python Articles on Developer +Center](https://www.mongodb.com/developer/languages/python/). + +## Testing + +The easiest way to run the tests is to run **tox -m test** in the root +of the distribution. 
For example, + +```bash +tox -e test +``` diff --git a/venv/Lib/site-packages/pymongo-4.7.2.dist-info/RECORD b/venv/Lib/site-packages/pymongo-4.7.2.dist-info/RECORD new file mode 100644 index 00000000..b13ab34c --- /dev/null +++ b/venv/Lib/site-packages/pymongo-4.7.2.dist-info/RECORD @@ -0,0 +1,194 @@ +bson/__init__.py,sha256=hHfukHTrBEVUPnw0Qxs2F-7IpRD64Ita0MYYBMyv8nM,51326 +bson/__pycache__/__init__.cpython-312.pyc,, +bson/__pycache__/_helpers.cpython-312.pyc,, +bson/__pycache__/binary.cpython-312.pyc,, +bson/__pycache__/code.cpython-312.pyc,, +bson/__pycache__/codec_options.cpython-312.pyc,, +bson/__pycache__/datetime_ms.cpython-312.pyc,, +bson/__pycache__/dbref.cpython-312.pyc,, +bson/__pycache__/decimal128.cpython-312.pyc,, +bson/__pycache__/errors.cpython-312.pyc,, +bson/__pycache__/int64.cpython-312.pyc,, +bson/__pycache__/json_util.cpython-312.pyc,, +bson/__pycache__/max_key.cpython-312.pyc,, +bson/__pycache__/min_key.cpython-312.pyc,, +bson/__pycache__/objectid.cpython-312.pyc,, +bson/__pycache__/raw_bson.cpython-312.pyc,, +bson/__pycache__/regex.cpython-312.pyc,, +bson/__pycache__/son.cpython-312.pyc,, +bson/__pycache__/timestamp.cpython-312.pyc,, +bson/__pycache__/typings.cpython-312.pyc,, +bson/__pycache__/tz_util.cpython-312.pyc,, +bson/_cbson.cp312-win_amd64.pyd,sha256=s3QJ5f3_AzDJSJvJUbREExnjRuHRpv2NmszSuF_3dGk,46080 +bson/_cbsonmodule.c,sha256=BFegVXVaYhj1HnZ61GPGFjtAbNsFpl9iRVLU51jTfjE,105766 +bson/_cbsonmodule.h,sha256=va4oA4jASv494cr68ewG6kXDWFX4C6wVm5dUryCtSVU,8263 +bson/_helpers.py,sha256=MqOBDkyRx78lR1RYii4kNZcifMNPIEYceWSKIaU7f-w,1409 +bson/binary.py,sha256=H68Zpp9wxQkUShNOG3z3RDIj7PR7JbelOSHHw1Jfojg,12703 +bson/bson-endian.h,sha256=RoU1Fkefn_pUMdmCIG13o-hoIwFtj2qZnopOzJgXZM8,6806 +bson/buffer.c,sha256=b_r58Ua6ad6nTuQh9WCWm9-0SUY3__QCZ4P4aus9x_Y,4607 +bson/buffer.h,sha256=bLqJy7Jxdl_TeJUljzB0Up2fuBvyvUH1CbhpoU8NBok,1879 +bson/code.py,sha256=_Q1_tzdHdWzp3QWT-53xmYX9j4tmA4yj0uxGoXsVZk4,3533 +bson/codec_options.py,sha256=rrYBl0gHlaTFlybPCSOcboFG-wJyzsR4U3NKCWqmIQQ,20156 +bson/datetime_ms.py,sha256=Fow-BeU3LMmL8W-bhMMpDwv8BzRdtZQ5CyTOXhShOVY,6732 +bson/dbref.py,sha256=SSnerWSEhsgRnwz95ydZi6kvkMizGseSmFaZS8YPrLQ,4861 +bson/decimal128.py,sha256=F8MxNstyb0hMelanANkbvaBqViZ0smhw9DUBczzpkFM,10525 +bson/errors.py,sha256=buM2qPmur-9_rLKM_xkXVGJw0RkryGSUWhjbKHKhTfI,1205 +bson/int64.py,sha256=uJO22QL4tW0eX8m3lPnpVbZ2NhdodOQlEdd7kSf6NwY,1217 +bson/json_util.py,sha256=LFwl1apeEuZ8mcqZ3U_sZj-VbPI0QaSVMzim7EuVHa8,43794 +bson/max_key.py,sha256=XxkXM0i2eRlIqb8Wt8Z2bJZBd-lY8KKSWaS2aotb5kY,1560 +bson/min_key.py,sha256=35AkN147moOGGi4eMyE8zPeIqQ0aeKCaPvFhk4KAhpw,1560 +bson/objectid.py,sha256=L7WPv1WBbSaixdYG2lGxwieK-MjaD9j5fMc8pTEc2a0,9419 +bson/py.typed,sha256=SEaNgPmH3E8kUVMaKTOYBxODVTUDutfGVGupZE0IkZQ,172 +bson/raw_bson.py,sha256=b0vyQTv98LnulIfu-ArW08q6ZMHXpUtgAYpbsGBfH5o,7497 +bson/regex.py,sha256=eP0mvQi1G4XkV2UwSa8iDC4xxHQjk1yyipihd7GkQj8,4721 +bson/son.py,sha256=GwLUl_4SryoRXnP_WqN1SX4ItSItbSpBj8k9-3JV2Dk,6722 +bson/time64.c,sha256=HBaC09Oz721fxpeFTHm2xR9opxadFDih2LCkPtuXsm8,22308 +bson/time64.h,sha256=RaXMBNtMoFQyaJUijGIXkbM5oLkgGL4l2VSfZGEkhEQ,1628 +bson/time64_config.h,sha256=jOirHsEcXTlAaGF8r2iY5Tgrq0TN3xEtG45uVk8NCmY,1760 +bson/time64_limits.h,sha256=UfzyW78wagp1puS6YjKBDFmPI4gdQGZceHfxMTguqRI,1587 +bson/timestamp.py,sha256=T1LmQDvbTWZMAv8RZgoWpHghCRzRbzVYmKOEc3BtJ0Q,4356 +bson/typings.py,sha256=0zJOM3KQ7oDhbOZwVTosWEfysNIDWITO8Pwjk8Mh4Ek,1169 +bson/tz_util.py,sha256=qlv0MvZox__cI0armrEAaKqDIaf5YIcLeviQA28-VlE,1814 
+gridfs/__init__.py,sha256=vW_U2MBgttj4RvXP-XS05sdrYIcRXAV8eap5NoIEqd0,38774 +gridfs/__pycache__/__init__.cpython-312.pyc,, +gridfs/__pycache__/errors.cpython-312.pyc,, +gridfs/__pycache__/grid_file.cpython-312.pyc,, +gridfs/errors.py,sha256=YIIhfEMrQV9SuAS-nNfxQtCxHf_4Bj4_yQ52uL67vaw,1125 +gridfs/grid_file.py,sha256=psYUwIhBYPEqCJc6cI3wo6gJ911zbiuMTe5N7i3MVXk,36921 +gridfs/py.typed,sha256=SEaNgPmH3E8kUVMaKTOYBxODVTUDutfGVGupZE0IkZQ,172 +pymongo-4.7.2.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 +pymongo-4.7.2.dist-info/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558 +pymongo-4.7.2.dist-info/METADATA,sha256=uBVZsvS9irxPFB8GqtxPbdPkqdYgFqLzY7CSu7QDnnA,22681 +pymongo-4.7.2.dist-info/RECORD,, +pymongo-4.7.2.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +pymongo-4.7.2.dist-info/WHEEL,sha256=fZWyj_84lK0cA-ZNCsdwhbJl0OTrpWkxInEn424qrSs,102 +pymongo-4.7.2.dist-info/top_level.txt,sha256=OinVojDdOfo1Dsp-NRfrZdp6gcJJ4bPRq61vSg5vyAs,20 +pymongo/__init__.py,sha256=vOV6Mo8Xo3y_ATcLsPEqk1c_iWh_0zVKbH4jfOV65TA,5402 +pymongo/__pycache__/__init__.cpython-312.pyc,, +pymongo/__pycache__/_azure_helpers.cpython-312.pyc,, +pymongo/__pycache__/_csot.cpython-312.pyc,, +pymongo/__pycache__/_gcp_helpers.cpython-312.pyc,, +pymongo/__pycache__/_lazy_import.cpython-312.pyc,, +pymongo/__pycache__/_version.cpython-312.pyc,, +pymongo/__pycache__/aggregation.cpython-312.pyc,, +pymongo/__pycache__/auth.cpython-312.pyc,, +pymongo/__pycache__/auth_aws.cpython-312.pyc,, +pymongo/__pycache__/auth_oidc.cpython-312.pyc,, +pymongo/__pycache__/bulk.cpython-312.pyc,, +pymongo/__pycache__/change_stream.cpython-312.pyc,, +pymongo/__pycache__/client_options.cpython-312.pyc,, +pymongo/__pycache__/client_session.cpython-312.pyc,, +pymongo/__pycache__/collation.cpython-312.pyc,, +pymongo/__pycache__/collection.cpython-312.pyc,, +pymongo/__pycache__/command_cursor.cpython-312.pyc,, +pymongo/__pycache__/common.cpython-312.pyc,, +pymongo/__pycache__/compression_support.cpython-312.pyc,, +pymongo/__pycache__/cursor.cpython-312.pyc,, +pymongo/__pycache__/daemon.cpython-312.pyc,, +pymongo/__pycache__/database.cpython-312.pyc,, +pymongo/__pycache__/driver_info.cpython-312.pyc,, +pymongo/__pycache__/encryption.cpython-312.pyc,, +pymongo/__pycache__/encryption_options.cpython-312.pyc,, +pymongo/__pycache__/errors.cpython-312.pyc,, +pymongo/__pycache__/event_loggers.cpython-312.pyc,, +pymongo/__pycache__/hello.cpython-312.pyc,, +pymongo/__pycache__/helpers.cpython-312.pyc,, +pymongo/__pycache__/lock.cpython-312.pyc,, +pymongo/__pycache__/logger.cpython-312.pyc,, +pymongo/__pycache__/max_staleness_selectors.cpython-312.pyc,, +pymongo/__pycache__/message.cpython-312.pyc,, +pymongo/__pycache__/mongo_client.cpython-312.pyc,, +pymongo/__pycache__/monitor.cpython-312.pyc,, +pymongo/__pycache__/monitoring.cpython-312.pyc,, +pymongo/__pycache__/network.cpython-312.pyc,, +pymongo/__pycache__/ocsp_cache.cpython-312.pyc,, +pymongo/__pycache__/ocsp_support.cpython-312.pyc,, +pymongo/__pycache__/operations.cpython-312.pyc,, +pymongo/__pycache__/periodic_executor.cpython-312.pyc,, +pymongo/__pycache__/pool.cpython-312.pyc,, +pymongo/__pycache__/pyopenssl_context.cpython-312.pyc,, +pymongo/__pycache__/read_concern.cpython-312.pyc,, +pymongo/__pycache__/read_preferences.cpython-312.pyc,, +pymongo/__pycache__/response.cpython-312.pyc,, +pymongo/__pycache__/results.cpython-312.pyc,, +pymongo/__pycache__/saslprep.cpython-312.pyc,, +pymongo/__pycache__/server.cpython-312.pyc,, 
+pymongo/__pycache__/server_api.cpython-312.pyc,, +pymongo/__pycache__/server_description.cpython-312.pyc,, +pymongo/__pycache__/server_selectors.cpython-312.pyc,, +pymongo/__pycache__/server_type.cpython-312.pyc,, +pymongo/__pycache__/settings.cpython-312.pyc,, +pymongo/__pycache__/socket_checker.cpython-312.pyc,, +pymongo/__pycache__/srv_resolver.cpython-312.pyc,, +pymongo/__pycache__/ssl_context.cpython-312.pyc,, +pymongo/__pycache__/ssl_support.cpython-312.pyc,, +pymongo/__pycache__/topology.cpython-312.pyc,, +pymongo/__pycache__/topology_description.cpython-312.pyc,, +pymongo/__pycache__/typings.cpython-312.pyc,, +pymongo/__pycache__/uri_parser.cpython-312.pyc,, +pymongo/__pycache__/write_concern.cpython-312.pyc,, +pymongo/_azure_helpers.py,sha256=XtnVfYnNlKpzrGGU3k7t1GzIOaKa9JYG46W9O8c9VNk,2062 +pymongo/_cmessage.cp312-win_amd64.pyd,sha256=J4LvTujxC5Uv9QbvSkCJmhgyy21ZhXIm5tU6XVrStTA,56832 +pymongo/_cmessagemodule.c,sha256=y0-IuLinu01Ob43cAATu95HT9Q1lNT2Dzc-v3y_GEz4,33527 +pymongo/_csot.py,sha256=R2qx1vm2486VLBDXoIHZ6kJbissKiVf9FOLncIpSh4s,4810 +pymongo/_gcp_helpers.py,sha256=N2nVDeOJKPiKuUOD4GD8gU6rZfU6eRG15Z5v18-uB48,1487 +pymongo/_lazy_import.py,sha256=BAKpfey3JukKjFoMWpGORlNtmPbfgNmn7nqd5Rfffgg,1596 +pymongo/_version.py,sha256=d1i5IPn8OYjtaympWfdkeHnpW92jPO9EUxF3bWHFRZ0,1031 +pymongo/aggregation.py,sha256=R1fsDf_qVHYsyyTsryZnoE-knLm7U3ndlcL8q_iBPKw,9588 +pymongo/auth.py,sha256=-5JnaEvBOro9IwuZA3rc3W7dpLuCBhVZiipJKiJ2_yg,24973 +pymongo/auth_aws.py,sha256=OYDzwWO3wgw6LwbKBW4QggMmW9e2SpcaGKSuyJA-lnM,3923 +pymongo/auth_oidc.py,sha256=OgyWbtNkRmPXJnFz4iYyCnNMv3wTfc8BEvQnRa1qN0k,14427 +pymongo/bulk.py,sha256=ZNL96EC7MB_AIPjiwkYSCXIqVeiJ-JPC_MbLBDMb8bQ,21840 +pymongo/change_stream.py,sha256=BzCkg2JtGCeS4AuIfPS1QMXXPqKRlU3aY1Rqb6_Jrz0,19186 +pymongo/client_options.py,sha256=8OmbOU4HDiPGsxflouttHQggiSmFyreYHN8PUK7vmAU,12788 +pymongo/client_session.py,sha256=AJ6pPgjZS717XlIYbNkeCqxBjEff4zf07fYkvudmbt4,45652 +pymongo/collation.py,sha256=qEfmvs6u5SdkiAB_0UndlSeownqd-mIpDg3OVtFDxdA,8129 +pymongo/collection.py,sha256=RrTTVsIXfmdznrtvzLY8qaDKphLPwBrswKmIKbhjt2U,143913 +pymongo/command_cursor.py,sha256=EeKLTkmfhULY7TzByMJVIRqin2mCur3GjRmqCHN_Ht0,14737 +pymongo/common.py,sha256=-Z6zZAPeO78jGMlDFIdMx_DMZEPmxvvUw7gYABbEuJc,38365 +pymongo/compression_support.py,sha256=6dF0HGN3BcTyqO6X1N_mfW2m716wVXM7x93xbSjisJU,5378 +pymongo/cursor.py,sha256=F4ptj3gIxg9op4xTb38YEnoVKuwIa0m5u0rhwqW-AIQ,51762 +pymongo/daemon.py,sha256=CKFmDOP47KU9LxazaN3rv6YN2bN_mnNB3qWx1db_ELE,6039 +pymongo/database.py,sha256=oiaiN7aNsKcWfn2scvbrwH2oX3ioy6hzOM3SRQc85TM,57266 +pymongo/driver_info.py,sha256=yc4DzOrIzSD9-HGvvzy6wWqQ1oxaapk5cH4y3sxL51g,1747 +pymongo/encryption.py,sha256=vgKjPRNET9nKTlemqzhX1H_f-6jAweW_jU9ccTr2cK8,45695 +pymongo/encryption_options.py,sha256=HJaohNu8FGOkNQxFjiFGOmBOH_na6ARFNMAn0nJ8y7o,13277 +pymongo/errors.py,sha256=2ntDbN7SV5bbLH5WwEDL4pLbtuWCujmF5yspyEhqav0,12034 +pymongo/event_loggers.py,sha256=TLA8rLQdvZ5SmaSFQGZnRY7CM-6TuDjnjYf8HIjmwEU,9356 +pymongo/hello.py,sha256=zWh_dwooFhgIYSz4rmGCoflUF1jqyKbUTUBkHACsd7w,6936 +pymongo/helpers.py,sha256=XjTdc_3IT41boldhUi0CShXhqMY1goXY3ksfz4aSoY4,11983 +pymongo/lock.py,sha256=atG82ip9uANlz9ZNU9quQoaVQuQTOOd17JWDgJbBBAU,1317 +pymongo/logger.py,sha256=F1_MXxc4qApoMNVZZeBWcS3mI8qMaALfvNO49oa2U5w,6536 +pymongo/max_staleness_selectors.py,sha256=NIjaxeoXSI-I-MTw_AGm3SscgdUe7VTMJBDCsHsVXjE,4795 +pymongo/message.py,sha256=Ab-1RJCHklpKIwyJsONjQ3f-RvR8bTAYGgrTDRMCMns,60690 +pymongo/mongo_client.py,sha256=MhIs4JRiwoeyD1IZf21zaWr8UH0O99rcIWXaz3_-a1s,111914 
+pymongo/monitor.py,sha256=3r-opKIYC9KndGxOuBAWLFeDnnpAPCcvhnWHRpwBnys,18125 +pymongo/monitoring.py,sha256=ZttTuK3M3_RblVF-SQKen5fP0E9jAKwqHH0KmvX4PV8,66499 +pymongo/network.py,sha256=TOS1vw1MYgldse0DKGSSDAtFovD184L1HeWXAAFMmeI,16523 +pymongo/ocsp_cache.py,sha256=auQMBkffkewFDCRjJv9SQ8tvb1iLcu3H17ewr7DbVfc,3946 +pymongo/ocsp_support.py,sha256=Wr2qDb1h-_lPTGOBfaxiXCu3QBs3MSwIkH8w8dI8X5A,18233 +pymongo/operations.py,sha256=zr28SZQjPGrCtm33tpI6HSOKTfh67NZDGHs0FirGwEI,22543 +pymongo/periodic_executor.py,sha256=_u_coW3VpfV6IdptWCJMdOPYlGFxh7RNakXvhWvbTUE,6905 +pymongo/pool.py,sha256=bll23VpjGi1Z0UVfPn2Wywp_72I3RVibuD3oJsMzrrA,84845 +pymongo/py.typed,sha256=SEaNgPmH3E8kUVMaKTOYBxODVTUDutfGVGupZE0IkZQ,172 +pymongo/pyopenssl_context.py,sha256=xsL8JbrzpVJMjqN5rYsP6jUh-BBb2RyzjuaE-6e_tM4,17050 +pymongo/read_concern.py,sha256=U68-UKvYXIl7l9t-EgyKuCbUpslLsxvkl3kxJjWMiBM,2488 +pymongo/read_preferences.py,sha256=GQxqa0T337wkYGN0eIAc4S3C-LVxPKUvTWy8p81ATUA,21995 +pymongo/response.py,sha256=Q1uF4jtcVM1AgPNQnjYV7Bn9qvjzL2dZIGAayrrjw4E,4424 +pymongo/results.py,sha256=SmIJ9pZIhJ_qpHXCghVWl7DSmLhrKDtulYm3-JVoQMU,8757 +pymongo/saslprep.py,sha256=VIfmL_K5ml0BDZ1vSb7DrnGoqAzFKdWHAjTzvm69UlI,4499 +pymongo/server.py,sha256=jw6sOb4rUE8uGOy7fZzwH7ZX4MCSgBM0zUM0gtd7RyY,13194 +pymongo/server_api.py,sha256=4t_Ob9AFZnAx6EPsTyXC8iWlUtZrHcRMXvi8weMjfpM,6248 +pymongo/server_description.py,sha256=LROXcelRbnwXnw17e97tMRPvVVpG81AQ6ASteSuLmV0,9905 +pymongo/server_selectors.py,sha256=hTD9FsvpPXy5bZtm0ACt52TOUtTTlsvDHqHSkEVsPc8,6252 +pymongo/server_type.py,sha256=Iz7XtKaIyxm2TbXYMBqFD7FkEnxvo3p-IOUFqYh4iSc,956 +pymongo/settings.py,sha256=oIHdlJNLnKHImqChYLjpZTdkY5qcWHg8I6VyHA8DhUI,6249 +pymongo/socket_checker.py,sha256=mjM1ImCkdMl6BEho5q5sXsadoI1fIKTZJODseIfY0cA,4329 +pymongo/srv_resolver.py,sha256=cwo9MRNzuvFqTE-r1kAuIMxsxCikCCTeYdYu6X6sbzs,4956 +pymongo/ssl_context.py,sha256=FWsu0TZjZ8_0LQ3F1L7Hp1ifGEdmnhZvJq5j0VicuXI,1465 +pymongo/ssl_support.py,sha256=20oziqw2n2V31nqpAzpmvbyNDH4I_IVurHk_MaiRfxw,4013 +pymongo/topology.py,sha256=ECjaNLsVz9GEzgR6Gpdli-tSOVrKmpLRTGlvDVK6NFw,42423 +pymongo/topology_description.py,sha256=md8xx_GJTNqVif46ONejw36DTl7Yqt16SZABuIyM-JM,27542 +pymongo/typings.py,sha256=22e-oAWcZCUkgpbHUTeXQoQ9MF35zHQD5X1ztsH_Hck,1580 +pymongo/uri_parser.py,sha256=bx9F8KVcPnu1B8XGih09bjWrFYFVH_IfqhrvYfXkMFM,24300 +pymongo/write_concern.py,sha256=nRG1kVZmGgdjw8ZIYaYvciElsc6vQL61RQo2FBhRfHA,5441 diff --git a/venv/Lib/site-packages/pymongo-4.7.2.dist-info/REQUESTED b/venv/Lib/site-packages/pymongo-4.7.2.dist-info/REQUESTED new file mode 100644 index 00000000..e69de29b diff --git a/venv/Lib/site-packages/pymongo-4.7.2.dist-info/WHEEL b/venv/Lib/site-packages/pymongo-4.7.2.dist-info/WHEEL new file mode 100644 index 00000000..8e45f0d7 --- /dev/null +++ b/venv/Lib/site-packages/pymongo-4.7.2.dist-info/WHEEL @@ -0,0 +1,5 @@ +Wheel-Version: 1.0 +Generator: bdist_wheel (0.43.0) +Root-Is-Purelib: false +Tag: cp312-cp312-win_amd64 + diff --git a/venv/Lib/site-packages/pymongo-4.7.2.dist-info/top_level.txt b/venv/Lib/site-packages/pymongo-4.7.2.dist-info/top_level.txt new file mode 100644 index 00000000..7b660e26 --- /dev/null +++ b/venv/Lib/site-packages/pymongo-4.7.2.dist-info/top_level.txt @@ -0,0 +1,3 @@ +bson +gridfs +pymongo diff --git a/venv/Lib/site-packages/pymongo/__init__.py b/venv/Lib/site-packages/pymongo/__init__.py new file mode 100644 index 00000000..758bb33a --- /dev/null +++ b/venv/Lib/site-packages/pymongo/__init__.py @@ -0,0 +1,176 @@ +# Copyright 2009-present MongoDB, Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Python driver for MongoDB.""" +from __future__ import annotations + +from typing import ContextManager, Optional + +__all__ = [ + "ASCENDING", + "DESCENDING", + "GEO2D", + "GEOSPHERE", + "HASHED", + "TEXT", + "version_tuple", + "get_version_string", + "__version__", + "version", + "ReturnDocument", + "MAX_SUPPORTED_WIRE_VERSION", + "MIN_SUPPORTED_WIRE_VERSION", + "CursorType", + "MongoClient", + "DeleteMany", + "DeleteOne", + "IndexModel", + "InsertOne", + "ReplaceOne", + "UpdateMany", + "UpdateOne", + "ReadPreference", + "WriteConcern", + "has_c", + "timeout", +] + +ASCENDING = 1 +"""Ascending sort order.""" +DESCENDING = -1 +"""Descending sort order.""" + +GEO2D = "2d" +"""Index specifier for a 2-dimensional `geospatial index`_. + +.. _geospatial index: http://mongodb.com/docs/manual/core/2d/ +""" + +GEOSPHERE = "2dsphere" +"""Index specifier for a `spherical geospatial index`_. + +.. versionadded:: 2.5 + +.. _spherical geospatial index: http://mongodb.com/docs/manual/core/2dsphere/ +""" + +HASHED = "hashed" +"""Index specifier for a `hashed index`_. + +.. versionadded:: 2.5 + +.. _hashed index: http://mongodb.com/docs/manual/core/index-hashed/ +""" + +TEXT = "text" +"""Index specifier for a `text index`_. + +.. seealso:: MongoDB's `Atlas Search + `_ which offers more advanced + text search functionality. + +.. versionadded:: 2.7.1 + +.. _text index: http://mongodb.com/docs/manual/core/index-text/ +""" + +from pymongo import _csot +from pymongo._version import __version__, get_version_string, version_tuple +from pymongo.collection import ReturnDocument +from pymongo.common import MAX_SUPPORTED_WIRE_VERSION, MIN_SUPPORTED_WIRE_VERSION +from pymongo.cursor import CursorType +from pymongo.mongo_client import MongoClient +from pymongo.operations import ( + DeleteMany, + DeleteOne, + IndexModel, + InsertOne, + ReplaceOne, + UpdateMany, + UpdateOne, +) +from pymongo.read_preferences import ReadPreference +from pymongo.write_concern import WriteConcern + +version = __version__ +"""Current version of PyMongo.""" + + +def has_c() -> bool: + """Is the C extension installed?""" + try: + from pymongo import _cmessage # type: ignore[attr-defined] # noqa: F401 + + return True + except ImportError: + return False + + +def timeout(seconds: Optional[float]) -> ContextManager[None]: + """**(Provisional)** Apply the given timeout for a block of operations. + + .. note:: :func:`~pymongo.timeout` is currently provisional. Backwards + incompatible changes may occur before becoming officially supported. + + Use :func:`~pymongo.timeout` in a with-statement:: + + with pymongo.timeout(5): + client.db.coll.insert_one({}) + client.db.coll2.insert_one({}) + + When the with-statement is entered, a deadline is set for the entire + block. When that deadline is exceeded, any blocking pymongo operation + will raise a timeout exception. 
For example:: + + try: + with pymongo.timeout(5): + client.db.coll.insert_one({}) + time.sleep(5) + # The deadline has now expired, the next operation will raise + # a timeout exception. + client.db.coll2.insert_one({}) + except PyMongoError as exc: + if exc.timeout: + print(f"block timed out: {exc!r}") + else: + print(f"failed with non-timeout error: {exc!r}") + + When nesting :func:`~pymongo.timeout`, the nested deadline is capped by + the outer deadline. The deadline can only be shortened, not extended. + When exiting the block, the previous deadline is restored:: + + with pymongo.timeout(5): + coll.find_one() # Uses the 5 second deadline. + with pymongo.timeout(3): + coll.find_one() # Uses the 3 second deadline. + coll.find_one() # Uses the original 5 second deadline. + with pymongo.timeout(10): + coll.find_one() # Still uses the original 5 second deadline. + coll.find_one() # Uses the original 5 second deadline. + + :param seconds: A non-negative floating point number expressing seconds, or None. + + :raises: :py:class:`ValueError`: When `seconds` is negative. + + See :ref:`timeout-example` for more examples. + + .. versionadded:: 4.2 + """ + if not isinstance(seconds, (int, float, type(None))): + raise TypeError("timeout must be None, an int, or a float") + if seconds and seconds < 0: + raise ValueError("timeout cannot be negative") + if seconds is not None: + seconds = float(seconds) + return _csot._TimeoutContext(seconds) diff --git a/venv/Lib/site-packages/pymongo/__pycache__/__init__.cpython-312.pyc b/venv/Lib/site-packages/pymongo/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 00000000..556af1ee Binary files /dev/null and b/venv/Lib/site-packages/pymongo/__pycache__/__init__.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/pymongo/__pycache__/_azure_helpers.cpython-312.pyc b/venv/Lib/site-packages/pymongo/__pycache__/_azure_helpers.cpython-312.pyc new file mode 100644 index 00000000..249b22fc Binary files /dev/null and b/venv/Lib/site-packages/pymongo/__pycache__/_azure_helpers.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/pymongo/__pycache__/_csot.cpython-312.pyc b/venv/Lib/site-packages/pymongo/__pycache__/_csot.cpython-312.pyc new file mode 100644 index 00000000..54c3428d Binary files /dev/null and b/venv/Lib/site-packages/pymongo/__pycache__/_csot.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/pymongo/__pycache__/_gcp_helpers.cpython-312.pyc b/venv/Lib/site-packages/pymongo/__pycache__/_gcp_helpers.cpython-312.pyc new file mode 100644 index 00000000..c5ebe883 Binary files /dev/null and b/venv/Lib/site-packages/pymongo/__pycache__/_gcp_helpers.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/pymongo/__pycache__/_lazy_import.cpython-312.pyc b/venv/Lib/site-packages/pymongo/__pycache__/_lazy_import.cpython-312.pyc new file mode 100644 index 00000000..2e706e4a Binary files /dev/null and b/venv/Lib/site-packages/pymongo/__pycache__/_lazy_import.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/pymongo/__pycache__/_version.cpython-312.pyc b/venv/Lib/site-packages/pymongo/__pycache__/_version.cpython-312.pyc new file mode 100644 index 00000000..81032932 Binary files /dev/null and b/venv/Lib/site-packages/pymongo/__pycache__/_version.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/pymongo/__pycache__/aggregation.cpython-312.pyc b/venv/Lib/site-packages/pymongo/__pycache__/aggregation.cpython-312.pyc new file mode 100644 index 00000000..b454b465 Binary files /dev/null and 
b/venv/Lib/site-packages/pymongo/__pycache__/aggregation.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/pymongo/__pycache__/auth.cpython-312.pyc b/venv/Lib/site-packages/pymongo/__pycache__/auth.cpython-312.pyc new file mode 100644 index 00000000..b5dddd4a Binary files /dev/null and b/venv/Lib/site-packages/pymongo/__pycache__/auth.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/pymongo/__pycache__/auth_aws.cpython-312.pyc b/venv/Lib/site-packages/pymongo/__pycache__/auth_aws.cpython-312.pyc new file mode 100644 index 00000000..ee2b1964 Binary files /dev/null and b/venv/Lib/site-packages/pymongo/__pycache__/auth_aws.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/pymongo/__pycache__/auth_oidc.cpython-312.pyc b/venv/Lib/site-packages/pymongo/__pycache__/auth_oidc.cpython-312.pyc new file mode 100644 index 00000000..f6d4f31b Binary files /dev/null and b/venv/Lib/site-packages/pymongo/__pycache__/auth_oidc.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/pymongo/__pycache__/bulk.cpython-312.pyc b/venv/Lib/site-packages/pymongo/__pycache__/bulk.cpython-312.pyc new file mode 100644 index 00000000..0c558e1b Binary files /dev/null and b/venv/Lib/site-packages/pymongo/__pycache__/bulk.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/pymongo/__pycache__/change_stream.cpython-312.pyc b/venv/Lib/site-packages/pymongo/__pycache__/change_stream.cpython-312.pyc new file mode 100644 index 00000000..fb2e78c3 Binary files /dev/null and b/venv/Lib/site-packages/pymongo/__pycache__/change_stream.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/pymongo/__pycache__/client_options.cpython-312.pyc b/venv/Lib/site-packages/pymongo/__pycache__/client_options.cpython-312.pyc new file mode 100644 index 00000000..f92fd7a3 Binary files /dev/null and b/venv/Lib/site-packages/pymongo/__pycache__/client_options.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/pymongo/__pycache__/client_session.cpython-312.pyc b/venv/Lib/site-packages/pymongo/__pycache__/client_session.cpython-312.pyc new file mode 100644 index 00000000..4c08d191 Binary files /dev/null and b/venv/Lib/site-packages/pymongo/__pycache__/client_session.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/pymongo/__pycache__/collation.cpython-312.pyc b/venv/Lib/site-packages/pymongo/__pycache__/collation.cpython-312.pyc new file mode 100644 index 00000000..e22a4845 Binary files /dev/null and b/venv/Lib/site-packages/pymongo/__pycache__/collation.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/pymongo/__pycache__/collection.cpython-312.pyc b/venv/Lib/site-packages/pymongo/__pycache__/collection.cpython-312.pyc new file mode 100644 index 00000000..b977eafe Binary files /dev/null and b/venv/Lib/site-packages/pymongo/__pycache__/collection.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/pymongo/__pycache__/command_cursor.cpython-312.pyc b/venv/Lib/site-packages/pymongo/__pycache__/command_cursor.cpython-312.pyc new file mode 100644 index 00000000..6740abc1 Binary files /dev/null and b/venv/Lib/site-packages/pymongo/__pycache__/command_cursor.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/pymongo/__pycache__/common.cpython-312.pyc b/venv/Lib/site-packages/pymongo/__pycache__/common.cpython-312.pyc new file mode 100644 index 00000000..04882b74 Binary files /dev/null and b/venv/Lib/site-packages/pymongo/__pycache__/common.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/pymongo/__pycache__/compression_support.cpython-312.pyc 
b/venv/Lib/site-packages/pymongo/__pycache__/compression_support.cpython-312.pyc new file mode 100644 index 00000000..48ec575d Binary files /dev/null and b/venv/Lib/site-packages/pymongo/__pycache__/compression_support.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/pymongo/__pycache__/cursor.cpython-312.pyc b/venv/Lib/site-packages/pymongo/__pycache__/cursor.cpython-312.pyc new file mode 100644 index 00000000..5b6548e0 Binary files /dev/null and b/venv/Lib/site-packages/pymongo/__pycache__/cursor.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/pymongo/__pycache__/daemon.cpython-312.pyc b/venv/Lib/site-packages/pymongo/__pycache__/daemon.cpython-312.pyc new file mode 100644 index 00000000..93c9edd2 Binary files /dev/null and b/venv/Lib/site-packages/pymongo/__pycache__/daemon.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/pymongo/__pycache__/database.cpython-312.pyc b/venv/Lib/site-packages/pymongo/__pycache__/database.cpython-312.pyc new file mode 100644 index 00000000..e1bb1f3b Binary files /dev/null and b/venv/Lib/site-packages/pymongo/__pycache__/database.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/pymongo/__pycache__/driver_info.cpython-312.pyc b/venv/Lib/site-packages/pymongo/__pycache__/driver_info.cpython-312.pyc new file mode 100644 index 00000000..deedf504 Binary files /dev/null and b/venv/Lib/site-packages/pymongo/__pycache__/driver_info.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/pymongo/__pycache__/encryption.cpython-312.pyc b/venv/Lib/site-packages/pymongo/__pycache__/encryption.cpython-312.pyc new file mode 100644 index 00000000..7d0eeab9 Binary files /dev/null and b/venv/Lib/site-packages/pymongo/__pycache__/encryption.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/pymongo/__pycache__/encryption_options.cpython-312.pyc b/venv/Lib/site-packages/pymongo/__pycache__/encryption_options.cpython-312.pyc new file mode 100644 index 00000000..e0404718 Binary files /dev/null and b/venv/Lib/site-packages/pymongo/__pycache__/encryption_options.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/pymongo/__pycache__/errors.cpython-312.pyc b/venv/Lib/site-packages/pymongo/__pycache__/errors.cpython-312.pyc new file mode 100644 index 00000000..0d3fbe50 Binary files /dev/null and b/venv/Lib/site-packages/pymongo/__pycache__/errors.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/pymongo/__pycache__/event_loggers.cpython-312.pyc b/venv/Lib/site-packages/pymongo/__pycache__/event_loggers.cpython-312.pyc new file mode 100644 index 00000000..1d1c3e52 Binary files /dev/null and b/venv/Lib/site-packages/pymongo/__pycache__/event_loggers.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/pymongo/__pycache__/hello.cpython-312.pyc b/venv/Lib/site-packages/pymongo/__pycache__/hello.cpython-312.pyc new file mode 100644 index 00000000..52ba083f Binary files /dev/null and b/venv/Lib/site-packages/pymongo/__pycache__/hello.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/pymongo/__pycache__/helpers.cpython-312.pyc b/venv/Lib/site-packages/pymongo/__pycache__/helpers.cpython-312.pyc new file mode 100644 index 00000000..beb73b3a Binary files /dev/null and b/venv/Lib/site-packages/pymongo/__pycache__/helpers.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/pymongo/__pycache__/lock.cpython-312.pyc b/venv/Lib/site-packages/pymongo/__pycache__/lock.cpython-312.pyc new file mode 100644 index 00000000..ebfdf17e Binary files /dev/null and 
b/venv/Lib/site-packages/pymongo/__pycache__/lock.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/pymongo/__pycache__/logger.cpython-312.pyc b/venv/Lib/site-packages/pymongo/__pycache__/logger.cpython-312.pyc new file mode 100644 index 00000000..19898027 Binary files /dev/null and b/venv/Lib/site-packages/pymongo/__pycache__/logger.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/pymongo/__pycache__/max_staleness_selectors.cpython-312.pyc b/venv/Lib/site-packages/pymongo/__pycache__/max_staleness_selectors.cpython-312.pyc new file mode 100644 index 00000000..0b5383fa Binary files /dev/null and b/venv/Lib/site-packages/pymongo/__pycache__/max_staleness_selectors.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/pymongo/__pycache__/message.cpython-312.pyc b/venv/Lib/site-packages/pymongo/__pycache__/message.cpython-312.pyc new file mode 100644 index 00000000..ed70efa7 Binary files /dev/null and b/venv/Lib/site-packages/pymongo/__pycache__/message.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/pymongo/__pycache__/mongo_client.cpython-312.pyc b/venv/Lib/site-packages/pymongo/__pycache__/mongo_client.cpython-312.pyc new file mode 100644 index 00000000..641ce1b9 Binary files /dev/null and b/venv/Lib/site-packages/pymongo/__pycache__/mongo_client.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/pymongo/__pycache__/monitor.cpython-312.pyc b/venv/Lib/site-packages/pymongo/__pycache__/monitor.cpython-312.pyc new file mode 100644 index 00000000..168c61ad Binary files /dev/null and b/venv/Lib/site-packages/pymongo/__pycache__/monitor.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/pymongo/__pycache__/monitoring.cpython-312.pyc b/venv/Lib/site-packages/pymongo/__pycache__/monitoring.cpython-312.pyc new file mode 100644 index 00000000..48209924 Binary files /dev/null and b/venv/Lib/site-packages/pymongo/__pycache__/monitoring.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/pymongo/__pycache__/network.cpython-312.pyc b/venv/Lib/site-packages/pymongo/__pycache__/network.cpython-312.pyc new file mode 100644 index 00000000..43277a11 Binary files /dev/null and b/venv/Lib/site-packages/pymongo/__pycache__/network.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/pymongo/__pycache__/ocsp_cache.cpython-312.pyc b/venv/Lib/site-packages/pymongo/__pycache__/ocsp_cache.cpython-312.pyc new file mode 100644 index 00000000..548c80e3 Binary files /dev/null and b/venv/Lib/site-packages/pymongo/__pycache__/ocsp_cache.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/pymongo/__pycache__/ocsp_support.cpython-312.pyc b/venv/Lib/site-packages/pymongo/__pycache__/ocsp_support.cpython-312.pyc new file mode 100644 index 00000000..348d0bf3 Binary files /dev/null and b/venv/Lib/site-packages/pymongo/__pycache__/ocsp_support.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/pymongo/__pycache__/operations.cpython-312.pyc b/venv/Lib/site-packages/pymongo/__pycache__/operations.cpython-312.pyc new file mode 100644 index 00000000..1aec0ea6 Binary files /dev/null and b/venv/Lib/site-packages/pymongo/__pycache__/operations.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/pymongo/__pycache__/periodic_executor.cpython-312.pyc b/venv/Lib/site-packages/pymongo/__pycache__/periodic_executor.cpython-312.pyc new file mode 100644 index 00000000..39702d07 Binary files /dev/null and b/venv/Lib/site-packages/pymongo/__pycache__/periodic_executor.cpython-312.pyc differ diff --git 
a/venv/Lib/site-packages/pymongo/__pycache__/pool.cpython-312.pyc b/venv/Lib/site-packages/pymongo/__pycache__/pool.cpython-312.pyc new file mode 100644 index 00000000..4e170e91 Binary files /dev/null and b/venv/Lib/site-packages/pymongo/__pycache__/pool.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/pymongo/__pycache__/pyopenssl_context.cpython-312.pyc b/venv/Lib/site-packages/pymongo/__pycache__/pyopenssl_context.cpython-312.pyc new file mode 100644 index 00000000..d5447bb4 Binary files /dev/null and b/venv/Lib/site-packages/pymongo/__pycache__/pyopenssl_context.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/pymongo/__pycache__/read_concern.cpython-312.pyc b/venv/Lib/site-packages/pymongo/__pycache__/read_concern.cpython-312.pyc new file mode 100644 index 00000000..d895b615 Binary files /dev/null and b/venv/Lib/site-packages/pymongo/__pycache__/read_concern.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/pymongo/__pycache__/read_preferences.cpython-312.pyc b/venv/Lib/site-packages/pymongo/__pycache__/read_preferences.cpython-312.pyc new file mode 100644 index 00000000..5aa08606 Binary files /dev/null and b/venv/Lib/site-packages/pymongo/__pycache__/read_preferences.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/pymongo/__pycache__/response.cpython-312.pyc b/venv/Lib/site-packages/pymongo/__pycache__/response.cpython-312.pyc new file mode 100644 index 00000000..d44a0f67 Binary files /dev/null and b/venv/Lib/site-packages/pymongo/__pycache__/response.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/pymongo/__pycache__/results.cpython-312.pyc b/venv/Lib/site-packages/pymongo/__pycache__/results.cpython-312.pyc new file mode 100644 index 00000000..20764e4b Binary files /dev/null and b/venv/Lib/site-packages/pymongo/__pycache__/results.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/pymongo/__pycache__/saslprep.cpython-312.pyc b/venv/Lib/site-packages/pymongo/__pycache__/saslprep.cpython-312.pyc new file mode 100644 index 00000000..f7ac0ab2 Binary files /dev/null and b/venv/Lib/site-packages/pymongo/__pycache__/saslprep.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/pymongo/__pycache__/server.cpython-312.pyc b/venv/Lib/site-packages/pymongo/__pycache__/server.cpython-312.pyc new file mode 100644 index 00000000..31960c19 Binary files /dev/null and b/venv/Lib/site-packages/pymongo/__pycache__/server.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/pymongo/__pycache__/server_api.cpython-312.pyc b/venv/Lib/site-packages/pymongo/__pycache__/server_api.cpython-312.pyc new file mode 100644 index 00000000..4d1e3929 Binary files /dev/null and b/venv/Lib/site-packages/pymongo/__pycache__/server_api.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/pymongo/__pycache__/server_description.cpython-312.pyc b/venv/Lib/site-packages/pymongo/__pycache__/server_description.cpython-312.pyc new file mode 100644 index 00000000..054f7ed0 Binary files /dev/null and b/venv/Lib/site-packages/pymongo/__pycache__/server_description.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/pymongo/__pycache__/server_selectors.cpython-312.pyc b/venv/Lib/site-packages/pymongo/__pycache__/server_selectors.cpython-312.pyc new file mode 100644 index 00000000..e67b8fe8 Binary files /dev/null and b/venv/Lib/site-packages/pymongo/__pycache__/server_selectors.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/pymongo/__pycache__/server_type.cpython-312.pyc 
b/venv/Lib/site-packages/pymongo/__pycache__/server_type.cpython-312.pyc new file mode 100644 index 00000000..82a50e2d Binary files /dev/null and b/venv/Lib/site-packages/pymongo/__pycache__/server_type.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/pymongo/__pycache__/settings.cpython-312.pyc b/venv/Lib/site-packages/pymongo/__pycache__/settings.cpython-312.pyc new file mode 100644 index 00000000..beeadf5a Binary files /dev/null and b/venv/Lib/site-packages/pymongo/__pycache__/settings.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/pymongo/__pycache__/socket_checker.cpython-312.pyc b/venv/Lib/site-packages/pymongo/__pycache__/socket_checker.cpython-312.pyc new file mode 100644 index 00000000..b523d61c Binary files /dev/null and b/venv/Lib/site-packages/pymongo/__pycache__/socket_checker.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/pymongo/__pycache__/srv_resolver.cpython-312.pyc b/venv/Lib/site-packages/pymongo/__pycache__/srv_resolver.cpython-312.pyc new file mode 100644 index 00000000..4ecb5d4b Binary files /dev/null and b/venv/Lib/site-packages/pymongo/__pycache__/srv_resolver.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/pymongo/__pycache__/ssl_context.cpython-312.pyc b/venv/Lib/site-packages/pymongo/__pycache__/ssl_context.cpython-312.pyc new file mode 100644 index 00000000..540c4eee Binary files /dev/null and b/venv/Lib/site-packages/pymongo/__pycache__/ssl_context.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/pymongo/__pycache__/ssl_support.cpython-312.pyc b/venv/Lib/site-packages/pymongo/__pycache__/ssl_support.cpython-312.pyc new file mode 100644 index 00000000..0be9d3b8 Binary files /dev/null and b/venv/Lib/site-packages/pymongo/__pycache__/ssl_support.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/pymongo/__pycache__/topology.cpython-312.pyc b/venv/Lib/site-packages/pymongo/__pycache__/topology.cpython-312.pyc new file mode 100644 index 00000000..00473085 Binary files /dev/null and b/venv/Lib/site-packages/pymongo/__pycache__/topology.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/pymongo/__pycache__/topology_description.cpython-312.pyc b/venv/Lib/site-packages/pymongo/__pycache__/topology_description.cpython-312.pyc new file mode 100644 index 00000000..68eb4fee Binary files /dev/null and b/venv/Lib/site-packages/pymongo/__pycache__/topology_description.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/pymongo/__pycache__/typings.cpython-312.pyc b/venv/Lib/site-packages/pymongo/__pycache__/typings.cpython-312.pyc new file mode 100644 index 00000000..45897eac Binary files /dev/null and b/venv/Lib/site-packages/pymongo/__pycache__/typings.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/pymongo/__pycache__/uri_parser.cpython-312.pyc b/venv/Lib/site-packages/pymongo/__pycache__/uri_parser.cpython-312.pyc new file mode 100644 index 00000000..efd57248 Binary files /dev/null and b/venv/Lib/site-packages/pymongo/__pycache__/uri_parser.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/pymongo/__pycache__/write_concern.cpython-312.pyc b/venv/Lib/site-packages/pymongo/__pycache__/write_concern.cpython-312.pyc new file mode 100644 index 00000000..d09598a4 Binary files /dev/null and b/venv/Lib/site-packages/pymongo/__pycache__/write_concern.cpython-312.pyc differ diff --git a/venv/Lib/site-packages/pymongo/_azure_helpers.py b/venv/Lib/site-packages/pymongo/_azure_helpers.py new file mode 100644 index 00000000..704c561c --- /dev/null +++ 
b/venv/Lib/site-packages/pymongo/_azure_helpers.py @@ -0,0 +1,57 @@ +# Copyright 2023-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Azure helpers.""" +from __future__ import annotations + +import json +from typing import Any, Optional + + +def _get_azure_response( + resource: str, client_id: Optional[str] = None, timeout: float = 5 +) -> dict[str, Any]: + # Deferred import to save overall import time. + from urllib.request import Request, urlopen + + url = "http://169.254.169.254/metadata/identity/oauth2/token" + url += "?api-version=2018-02-01" + url += f"&resource={resource}" + if client_id: + url += f"&client_id={client_id}" + headers = {"Metadata": "true", "Accept": "application/json"} + request = Request(url, headers=headers) # noqa: S310 + try: + with urlopen(request, timeout=timeout) as response: # noqa: S310 + status = response.status + body = response.read().decode("utf8") + except Exception as e: + msg = "Failed to acquire IMDS access token: %s" % e + raise ValueError(msg) from None + + if status != 200: + msg = "Failed to acquire IMDS access token." + raise ValueError(msg) + try: + data = json.loads(body) + except Exception: + raise ValueError("Azure IMDS response must be in JSON format.") from None + + for key in ["access_token", "expires_in"]: + if not data.get(key): + msg = "Azure IMDS response must contain %s, but was %s." + msg = msg % (key, body) + raise ValueError(msg) + + return data diff --git a/venv/Lib/site-packages/pymongo/_cmessage.cp312-win_amd64.pyd b/venv/Lib/site-packages/pymongo/_cmessage.cp312-win_amd64.pyd new file mode 100644 index 00000000..ab0666d0 Binary files /dev/null and b/venv/Lib/site-packages/pymongo/_cmessage.cp312-win_amd64.pyd differ diff --git a/venv/Lib/site-packages/pymongo/_cmessagemodule.c b/venv/Lib/site-packages/pymongo/_cmessagemodule.c new file mode 100644 index 00000000..f95b9493 --- /dev/null +++ b/venv/Lib/site-packages/pymongo/_cmessagemodule.c @@ -0,0 +1,1046 @@ +/* + * Copyright 2009-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * This file contains C implementations of some of the functions + * needed by the message module. If possible, these implementations + * should be used to speed up message creation. 
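+ * When this extension cannot be imported, the driver falls back to the
+ * pure-Python implementations in the message module; pymongo.has_c()
+ * reports whether the compiled module was loaded.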
+ */ + +#define PY_SSIZE_T_CLEAN +#include "Python.h" + +#include "_cbsonmodule.h" +#include "buffer.h" + +struct module_state { + PyObject* _cbson; + PyObject* _max_bson_size_str; + PyObject* _max_message_size_str; + PyObject* _max_write_batch_size_str; + PyObject* _max_split_size_str; +}; + +/* See comments about module initialization in _cbsonmodule.c */ +#define GETSTATE(m) ((struct module_state*)PyModule_GetState(m)) + +#define DOC_TOO_LARGE_FMT "BSON document too large (%d bytes)" \ + " - the connected server supports" \ + " BSON document sizes up to %ld bytes." + +/* Get an error class from the pymongo.errors module. + * + * Returns a new ref */ +static PyObject* _error(char* name) { + PyObject* error; + PyObject* errors = PyImport_ImportModule("pymongo.errors"); + if (!errors) { + return NULL; + } + error = PyObject_GetAttrString(errors, name); + Py_DECREF(errors); + return error; +} + +/* The same as buffer_write_bytes except that it also validates + * "size" will fit in an int. + * Returns 0 on failure */ +static int buffer_write_bytes_ssize_t(buffer_t buffer, const char* data, Py_ssize_t size) { + int downsize = _downcast_and_check(size, 0); + if (downsize == -1) { + return 0; + } + return buffer_write_bytes(buffer, data, downsize); +} + +static PyObject* _cbson_query_message(PyObject* self, PyObject* args) { + /* NOTE just using a random number as the request_id */ + int request_id = rand(); + unsigned int flags; + char* collection_name = NULL; + Py_ssize_t collection_name_length; + int begin, cur_size, max_size = 0; + int num_to_skip; + int num_to_return; + PyObject* query; + PyObject* field_selector; + PyObject* options_obj; + codec_options_t options; + buffer_t buffer = NULL; + int length_location, message_length; + PyObject* result = NULL; + struct module_state *state = GETSTATE(self); + if (!state) { + return NULL; + } + + if (!(PyArg_ParseTuple(args, "Iet#iiOOO", + &flags, + "utf-8", + &collection_name, + &collection_name_length, + &num_to_skip, &num_to_return, + &query, &field_selector, + &options_obj) && + convert_codec_options(state->_cbson, options_obj, &options))) { + return NULL; + } + buffer = pymongo_buffer_new(); + if (!buffer) { + goto fail; + } + + // save space for message length + length_location = pymongo_buffer_save_space(buffer, 4); + if (length_location == -1) { + goto fail; + } + + if (!buffer_write_int32(buffer, (int32_t)request_id) || + !buffer_write_bytes(buffer, "\x00\x00\x00\x00\xd4\x07\x00\x00", 8) || + !buffer_write_int32(buffer, (int32_t)flags) || + !buffer_write_bytes_ssize_t(buffer, collection_name, + collection_name_length + 1) || + !buffer_write_int32(buffer, (int32_t)num_to_skip) || + !buffer_write_int32(buffer, (int32_t)num_to_return)) { + goto fail; + } + + begin = pymongo_buffer_get_position(buffer); + if (!write_dict(state->_cbson, buffer, query, 0, &options, 1)) { + goto fail; + } + + max_size = pymongo_buffer_get_position(buffer) - begin; + + if (field_selector != Py_None) { + begin = pymongo_buffer_get_position(buffer); + if (!write_dict(state->_cbson, buffer, field_selector, 0, + &options, 1)) { + goto fail; + } + cur_size = pymongo_buffer_get_position(buffer) - begin; + max_size = (cur_size > max_size) ?
cur_size : max_size; + } + + message_length = pymongo_buffer_get_position(buffer) - length_location; + buffer_write_int32_at_position( + buffer, length_location, (int32_t)message_length); + + /* objectify buffer */ + result = Py_BuildValue("iy#i", request_id, + pymongo_buffer_get_buffer(buffer), + (Py_ssize_t)pymongo_buffer_get_position(buffer), + max_size); +fail: + PyMem_Free(collection_name); + destroy_codec_options(&options); + if (buffer) { + pymongo_buffer_free(buffer); + } + return result; +} + +static PyObject* _cbson_get_more_message(PyObject* self, PyObject* args) { + /* NOTE just using a random number as the request_id */ + int request_id = rand(); + char* collection_name = NULL; + Py_ssize_t collection_name_length; + int num_to_return; + long long cursor_id; + buffer_t buffer = NULL; + int length_location, message_length; + PyObject* result = NULL; + + if (!PyArg_ParseTuple(args, "et#iL", + "utf-8", + &collection_name, + &collection_name_length, + &num_to_return, + &cursor_id)) { + return NULL; + } + buffer = pymongo_buffer_new(); + if (!buffer) { + goto fail; + } + + // save space for message length + length_location = pymongo_buffer_save_space(buffer, 4); + if (length_location == -1) { + goto fail; + } + if (!buffer_write_int32(buffer, (int32_t)request_id) || + !buffer_write_bytes(buffer, + "\x00\x00\x00\x00" + "\xd5\x07\x00\x00" + "\x00\x00\x00\x00", 12) || + !buffer_write_bytes_ssize_t(buffer, + collection_name, + collection_name_length + 1) || + !buffer_write_int32(buffer, (int32_t)num_to_return) || + !buffer_write_int64(buffer, (int64_t)cursor_id)) { + goto fail; + } + + message_length = pymongo_buffer_get_position(buffer) - length_location; + buffer_write_int32_at_position( + buffer, length_location, (int32_t)message_length); + + /* objectify buffer */ + result = Py_BuildValue("iy#", request_id, + pymongo_buffer_get_buffer(buffer), + (Py_ssize_t)pymongo_buffer_get_position(buffer)); +fail: + PyMem_Free(collection_name); + if (buffer) { + pymongo_buffer_free(buffer); + } + return result; +} + +/* + * NOTE this method handles multiple documents in a type one payload but + * it does not perform batch splitting and the total message size is + * only checked *after* generating the entire message. 
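To make the byte strings in these builders easier to read, here is a sketch of the 16-byte MongoDB wire-protocol header that each of them assembles: four little-endian int32 fields (messageLength, requestID, responseTo, opCode). The opcode literals written above decode as shown; the helper function name is ours, not the driver's:

import struct

def message_header(length: int, request_id: int, response_to: int, op_code: int) -> bytes:
    # messageLength, requestID, responseTo, opCode: four little-endian int32s.
    return struct.pack("<iiii", length, request_id, response_to, op_code)

assert struct.unpack("<i", b"\xd4\x07\x00\x00")[0] == 2004  # OP_QUERY
assert struct.unpack("<i", b"\xd5\x07\x00\x00")[0] == 2005  # OP_GET_MORE
assert struct.unpack("<i", b"\xdd\x07\x00\x00")[0] == 2013  # OP_MSG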
+ */ +static PyObject* _cbson_op_msg(PyObject* self, PyObject* args) { + /* NOTE just using a random number as the request_id */ + int request_id = rand(); + unsigned int flags; + PyObject* command; + char* identifier = NULL; + Py_ssize_t identifier_length = 0; + PyObject* docs; + PyObject* doc; + PyObject* options_obj; + codec_options_t options; + buffer_t buffer = NULL; + int length_location, message_length; + int total_size = 0; + int max_doc_size = 0; + PyObject* result = NULL; + PyObject* iterator = NULL; + struct module_state *state = GETSTATE(self); + if (!state) { + return NULL; + } + + /*flags, command, identifier, docs, opts*/ + if (!(PyArg_ParseTuple(args, "IOet#OO", + &flags, + &command, + "utf-8", + &identifier, + &identifier_length, + &docs, + &options_obj) && + convert_codec_options(state->_cbson, options_obj, &options))) { + return NULL; + } + buffer = pymongo_buffer_new(); + if (!buffer) { + goto fail; + } + + // save space for message length + length_location = pymongo_buffer_save_space(buffer, 4); + if (length_location == -1) { + goto fail; + } + if (!buffer_write_int32(buffer, (int32_t)request_id) || + !buffer_write_bytes(buffer, + "\x00\x00\x00\x00" /* responseTo */ + "\xdd\x07\x00\x00" /* 2013 */, 8)) { + goto fail; + } + + if (!buffer_write_int32(buffer, (int32_t)flags) || + !buffer_write_bytes(buffer, "\x00", 1) /* Payload type 0 */) { + goto fail; + } + total_size = write_dict(state->_cbson, buffer, command, 0, + &options, 1); + if (!total_size) { + goto fail; + } + + if (identifier_length) { + int payload_one_length_location, payload_length; + /* Payload type 1 */ + if (!buffer_write_bytes(buffer, "\x01", 1)) { + goto fail; + } + /* save space for payload 0 length */ + payload_one_length_location = pymongo_buffer_save_space(buffer, 4); + /* C string identifier */ + if (!buffer_write_bytes_ssize_t(buffer, identifier, identifier_length + 1)) { + goto fail; + } + iterator = PyObject_GetIter(docs); + if (iterator == NULL) { + goto fail; + } + while ((doc = PyIter_Next(iterator)) != NULL) { + int encoded_doc_size = write_dict( + state->_cbson, buffer, doc, 0, &options, 1); + if (!encoded_doc_size) { + Py_CLEAR(doc); + goto fail; + } + if (encoded_doc_size > max_doc_size) { + max_doc_size = encoded_doc_size; + } + Py_CLEAR(doc); + } + + payload_length = pymongo_buffer_get_position(buffer) - payload_one_length_location; + buffer_write_int32_at_position( + buffer, payload_one_length_location, (int32_t)payload_length); + total_size += payload_length; + } + + message_length = pymongo_buffer_get_position(buffer) - length_location; + buffer_write_int32_at_position( + buffer, length_location, (int32_t)message_length); + + /* objectify buffer */ + result = Py_BuildValue("iy#ii", request_id, + pymongo_buffer_get_buffer(buffer), + (Py_ssize_t)pymongo_buffer_get_position(buffer), + total_size, + max_doc_size); +fail: + Py_XDECREF(iterator); + if (buffer) { + pymongo_buffer_free(buffer); + } + PyMem_Free(identifier); + destroy_codec_options(&options); + return result; +} + + +static void +_set_document_too_large(int size, long max) { + PyObject* DocumentTooLarge = _error("DocumentTooLarge"); + if (DocumentTooLarge) { + PyObject* error = PyUnicode_FromFormat(DOC_TOO_LARGE_FMT, size, max); + if (error) { + PyErr_SetObject(DocumentTooLarge, error); + Py_DECREF(error); + } + Py_DECREF(DocumentTooLarge); + } +} + +#define _INSERT 0 +#define _UPDATE 1 +#define _DELETE 2 + +/* OP_MSG ----------------------------------------------- */ + +static int +_batched_op_msg( + unsigned char op, 
unsigned char ack, + PyObject* command, PyObject* docs, PyObject* ctx, + PyObject* to_publish, codec_options_t options, + buffer_t buffer, struct module_state *state) { + + long max_bson_size; + long max_write_batch_size; + long max_message_size; + int idx = 0; + int size_location; + int position; + int length; + PyObject* max_bson_size_obj = NULL; + PyObject* max_write_batch_size_obj = NULL; + PyObject* max_message_size_obj = NULL; + PyObject* doc = NULL; + PyObject* iterator = NULL; + char* flags = ack ? "\x00\x00\x00\x00" : "\x02\x00\x00\x00"; + + max_bson_size_obj = PyObject_GetAttr(ctx, state->_max_bson_size_str); + max_bson_size = PyLong_AsLong(max_bson_size_obj); + Py_XDECREF(max_bson_size_obj); + if (max_bson_size == -1) { + return 0; + } + + max_write_batch_size_obj = PyObject_GetAttr(ctx, state->_max_write_batch_size_str); + max_write_batch_size = PyLong_AsLong(max_write_batch_size_obj); + Py_XDECREF(max_write_batch_size_obj); + if (max_write_batch_size == -1) { + return 0; + } + + max_message_size_obj = PyObject_GetAttr(ctx, state->_max_message_size_str); + max_message_size = PyLong_AsLong(max_message_size_obj); + Py_XDECREF(max_message_size_obj); + if (max_message_size == -1) { + return 0; + } + + if (!buffer_write_bytes(buffer, flags, 4)) { + return 0; + } + /* Type 0 Section */ + if (!buffer_write_bytes(buffer, "\x00", 1)) { + return 0; + } + if (!write_dict(state->_cbson, buffer, command, 0, + &options, 0)) { + return 0; + } + + /* Type 1 Section */ + if (!buffer_write_bytes(buffer, "\x01", 1)) { + return 0; + } + /* Save space for size */ + size_location = pymongo_buffer_save_space(buffer, 4); + if (size_location == -1) { + return 0; + } + + switch (op) { + case _INSERT: + { + if (!buffer_write_bytes(buffer, "documents\x00", 10)) + goto fail; + break; + } + case _UPDATE: + { + if (!buffer_write_bytes(buffer, "updates\x00", 8)) + goto fail; + break; + } + case _DELETE: + { + if (!buffer_write_bytes(buffer, "deletes\x00", 8)) + goto fail; + break; + } + default: + { + PyObject* InvalidOperation = _error("InvalidOperation"); + if (InvalidOperation) { + PyErr_SetString(InvalidOperation, "Unknown command"); + Py_DECREF(InvalidOperation); + } + return 0; + } + } + + iterator = PyObject_GetIter(docs); + if (iterator == NULL) { + PyObject* InvalidOperation = _error("InvalidOperation"); + if (InvalidOperation) { + PyErr_SetString(InvalidOperation, "input is not iterable"); + Py_DECREF(InvalidOperation); + } + return 0; + } + while ((doc = PyIter_Next(iterator)) != NULL) { + int cur_doc_begin = pymongo_buffer_get_position(buffer); + int cur_size; + int doc_too_large = 0; + int unacked_doc_too_large = 0; + if (!write_dict(state->_cbson, buffer, doc, 0, &options, 1)) { + goto fail; + } + cur_size = pymongo_buffer_get_position(buffer) - cur_doc_begin; + + /* Does the first document exceed max_message_size? */ + doc_too_large = (idx == 0 && (pymongo_buffer_get_position(buffer) > max_message_size)); + /* When OP_MSG is used unacknowledged we have to check + * document size client side or applications won't be notified. + * Otherwise we let the server deal with documents that are too large + * since ordered=False causes those documents to be skipped instead of + * halting the bulk write operation. 
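A Python rendering of those two size checks, for readers skimming the C. This is a sketch with our own function name, mirroring the locals used below, not the implementation itself:

def rejects_document(idx: int, buffer_position: int, cur_size: int, ack: bool,
                     max_message_size: int, max_bson_size: int) -> bool:
    # The first document must keep the whole message under max_message_size.
    doc_too_large = idx == 0 and buffer_position > max_message_size
    # For unacknowledged (w=0) writes the client enforces max_bson_size
    # itself, because the server has no reply in which to report the error.
    unacked_doc_too_large = not ack and cur_size > max_bson_size
    return doc_too_large or unacked_doc_too_large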
+ * */ + unacked_doc_too_large = (!ack && cur_size > max_bson_size); + if (doc_too_large || unacked_doc_too_large) { + if (op == _INSERT) { + _set_document_too_large(cur_size, max_bson_size); + } else { + PyObject* DocumentTooLarge = _error("DocumentTooLarge"); + if (DocumentTooLarge) { + /* + * There's nothing intelligent we can say + * about size for update and delete. + */ + PyErr_Format( + DocumentTooLarge, + "%s command document too large", + (op == _UPDATE) ? "update": "delete"); + Py_DECREF(DocumentTooLarge); + } + } + goto fail; + } + /* We have enough data, return this batch. */ + if (pymongo_buffer_get_position(buffer) > max_message_size) { + /* + * Roll the existing buffer back to the beginning + * of the last document encoded. + */ + pymongo_buffer_update_position(buffer, cur_doc_begin); + Py_CLEAR(doc); + break; + } + if (PyList_Append(to_publish, doc) < 0) { + goto fail; + } + Py_CLEAR(doc); + idx += 1; + /* We have enough documents, return this batch. */ + if (idx == max_write_batch_size) { + break; + } + } + Py_CLEAR(iterator); + + if (PyErr_Occurred()) { + goto fail; + } + + position = pymongo_buffer_get_position(buffer); + length = position - size_location; + buffer_write_int32_at_position(buffer, size_location, (int32_t)length); + return 1; + +fail: + Py_XDECREF(doc); + Py_XDECREF(iterator); + return 0; +} + +static PyObject* +_cbson_encode_batched_op_msg(PyObject* self, PyObject* args) { + unsigned char op; + unsigned char ack; + PyObject* command; + PyObject* docs; + PyObject* ctx = NULL; + PyObject* to_publish = NULL; + PyObject* result = NULL; + PyObject* options_obj; + codec_options_t options; + buffer_t buffer; + struct module_state *state = GETSTATE(self); + if (!state) { + return NULL; + } + + if (!(PyArg_ParseTuple(args, "bOObOO", + &op, &command, &docs, &ack, + &options_obj, &ctx) && + convert_codec_options(state->_cbson, options_obj, &options))) { + return NULL; + } + if (!(buffer = pymongo_buffer_new())) { + destroy_codec_options(&options); + return NULL; + } + if (!(to_publish = PyList_New(0))) { + goto fail; + } + + if (!_batched_op_msg( + op, + ack, + command, + docs, + ctx, + to_publish, + options, + buffer, + state)) { + goto fail; + } + + result = Py_BuildValue("y#O", + pymongo_buffer_get_buffer(buffer), + (Py_ssize_t)pymongo_buffer_get_position(buffer), + to_publish); +fail: + destroy_codec_options(&options); + pymongo_buffer_free(buffer); + Py_XDECREF(to_publish); + return result; +} + +static PyObject* +_cbson_batched_op_msg(PyObject* self, PyObject* args) { + unsigned char op; + unsigned char ack; + int request_id; + int position; + PyObject* command; + PyObject* docs; + PyObject* ctx = NULL; + PyObject* to_publish = NULL; + PyObject* result = NULL; + PyObject* options_obj; + codec_options_t options; + buffer_t buffer; + struct module_state *state = GETSTATE(self); + if (!state) { + return NULL; + } + + if (!(PyArg_ParseTuple(args, "bOObOO", + &op, &command, &docs, &ack, + &options_obj, &ctx) && + convert_codec_options(state->_cbson, options_obj, &options))) { + return NULL; + } + if (!(buffer = pymongo_buffer_new())) { + destroy_codec_options(&options); + return NULL; + } + /* Save space for message length and request id */ + if ((pymongo_buffer_save_space(buffer, 8)) == -1) { + goto fail; + } + if (!buffer_write_bytes(buffer, + "\x00\x00\x00\x00" /* responseTo */ + "\xdd\x07\x00\x00", /* opcode */ + 8)) { + goto fail; + } + if (!(to_publish = PyList_New(0))) { + goto fail; + } + + if (!_batched_op_msg( + op, + ack, + command, + docs, + ctx, + 
to_publish, + options, + buffer, + state)) { + goto fail; + } + + request_id = rand(); + position = pymongo_buffer_get_position(buffer); + buffer_write_int32_at_position(buffer, 0, (int32_t)position); + buffer_write_int32_at_position(buffer, 4, (int32_t)request_id); + result = Py_BuildValue("iy#O", request_id, + pymongo_buffer_get_buffer(buffer), + (Py_ssize_t)pymongo_buffer_get_position(buffer), + to_publish); +fail: + destroy_codec_options(&options); + pymongo_buffer_free(buffer); + Py_XDECREF(to_publish); + return result; +} + +/* End OP_MSG -------------------------------------------- */ + +static int +_batched_write_command( + char* ns, Py_ssize_t ns_len, unsigned char op, + PyObject* command, PyObject* docs, PyObject* ctx, + PyObject* to_publish, codec_options_t options, + buffer_t buffer, struct module_state *state) { + + long max_bson_size; + long max_cmd_size; + long max_write_batch_size; + long max_split_size; + int idx = 0; + int cmd_len_loc; + int lst_len_loc; + int position; + int length; + PyObject* max_bson_size_obj = NULL; + PyObject* max_write_batch_size_obj = NULL; + PyObject* max_split_size_obj = NULL; + PyObject* doc = NULL; + PyObject* iterator = NULL; + + max_bson_size_obj = PyObject_GetAttr(ctx, state->_max_bson_size_str); + max_bson_size = PyLong_AsLong(max_bson_size_obj); + Py_XDECREF(max_bson_size_obj); + if (max_bson_size == -1) { + return 0; + } + /* + * Max BSON object size + 16k - 2 bytes for ending NUL bytes + * XXX: This should come from the server - SERVER-10643 + */ + max_cmd_size = max_bson_size + 16382; + + max_write_batch_size_obj = PyObject_GetAttr(ctx, state->_max_write_batch_size_str); + max_write_batch_size = PyLong_AsLong(max_write_batch_size_obj); + Py_XDECREF(max_write_batch_size_obj); + if (max_write_batch_size == -1) { + return 0; + } + + // max_split_size is the size at which to perform a batch split. + // Normally this value is equal to max_bson_size (16MiB). However, + // when auto encryption is enabled max_split_size is reduced to 2MiB.
+ max_split_size_obj = PyObject_GetAttr(ctx, state->_max_split_size_str); + max_split_size = PyLong_AsLong(max_split_size_obj); + Py_XDECREF(max_split_size_obj); + if (max_split_size == -1) { + return 0; + } + + if (!buffer_write_bytes(buffer, + "\x00\x00\x00\x00", /* flags */ + 4) || + !buffer_write_bytes_ssize_t(buffer, ns, ns_len + 1) || /* namespace */ + !buffer_write_bytes(buffer, + "\x00\x00\x00\x00" /* skip */ + "\xFF\xFF\xFF\xFF", /* limit (-1) */ + 8)) { + return 0; + } + + /* Position of command document length */ + cmd_len_loc = pymongo_buffer_get_position(buffer); + if (!write_dict(state->_cbson, buffer, command, 0, + &options, 0)) { + return 0; + } + + /* Write type byte for array */ + *(pymongo_buffer_get_buffer(buffer) + (pymongo_buffer_get_position(buffer) - 1)) = 0x4; + + switch (op) { + case _INSERT: + { + if (!buffer_write_bytes(buffer, "documents\x00", 10)) + goto fail; + break; + } + case _UPDATE: + { + if (!buffer_write_bytes(buffer, "updates\x00", 8)) + goto fail; + break; + } + case _DELETE: + { + if (!buffer_write_bytes(buffer, "deletes\x00", 8)) + goto fail; + break; + } + default: + { + PyObject* InvalidOperation = _error("InvalidOperation"); + if (InvalidOperation) { + PyErr_SetString(InvalidOperation, "Unknown command"); + Py_DECREF(InvalidOperation); + } + return 0; + } + } + + /* Save space for list document */ + lst_len_loc = pymongo_buffer_save_space(buffer, 4); + if (lst_len_loc == -1) { + return 0; + } + + iterator = PyObject_GetIter(docs); + if (iterator == NULL) { + PyObject* InvalidOperation = _error("InvalidOperation"); + if (InvalidOperation) { + PyErr_SetString(InvalidOperation, "input is not iterable"); + Py_DECREF(InvalidOperation); + } + return 0; + } + while ((doc = PyIter_Next(iterator)) != NULL) { + int sub_doc_begin = pymongo_buffer_get_position(buffer); + int cur_doc_begin; + int cur_size; + int enough_data = 0; + char key[BUF_SIZE]; + int res = LL2STR(key, (long long)idx); + if (res == -1) { + return 0; + } + if (!buffer_write_bytes(buffer, "\x03", 1) || + !buffer_write_bytes(buffer, key, (int)strlen(key) + 1)) { + goto fail; + } + cur_doc_begin = pymongo_buffer_get_position(buffer); + if (!write_dict(state->_cbson, buffer, doc, 0, &options, 1)) { + goto fail; + } + + /* We have enough data, return this batch. + * max_cmd_size accounts for the two trailing null bytes. + */ + cur_size = pymongo_buffer_get_position(buffer) - cur_doc_begin; + /* This single document is too large for the command. */ + if (cur_size > max_cmd_size) { + if (op == _INSERT) { + _set_document_too_large(cur_size, max_bson_size); + } else { + PyObject* DocumentTooLarge = _error("DocumentTooLarge"); + if (DocumentTooLarge) { + /* + * There's nothing intelligent we can say + * about size for update and delete. + */ + PyErr_Format( + DocumentTooLarge, + "%s command document too large", + (op == _UPDATE) ? "update": "delete"); + Py_DECREF(DocumentTooLarge); + } + } + goto fail; + } + enough_data = (idx >= 1 && + (pymongo_buffer_get_position(buffer) > max_split_size)); + if (enough_data) { + /* + * Roll the existing buffer back to the beginning + * of the last document encoded. + */ + pymongo_buffer_update_position(buffer, sub_doc_begin); + Py_CLEAR(doc); + break; + } + if (PyList_Append(to_publish, doc) < 0) { + goto fail; + } + Py_CLEAR(doc); + idx += 1; + /* We have enough documents, return this batch. 
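The batch-split rules this loop implements reduce to two conditions: end a batch once the encoded payload passes max_split_size (always keeping at least one document per batch) or once it holds max_write_batch_size documents. An illustrative Python equivalent, not the C implementation:

from typing import Iterator, List

def split_batches(encoded_docs: List[bytes], max_split_size: int,
                  max_write_batch_size: int) -> Iterator[List[bytes]]:
    batch: List[bytes] = []
    size = 0
    for doc in encoded_docs:
        if batch and (size + len(doc) > max_split_size
                      or len(batch) == max_write_batch_size):
            yield batch  # split: payload size or document count reached
            batch, size = [], 0
        batch.append(doc)  # a batch always keeps at least one document
        size += len(doc)
    if batch:
        yield batch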
*/ + if (idx == max_write_batch_size) { + break; + } + } + Py_CLEAR(iterator); + + if (PyErr_Occurred()) { + goto fail; + } + + if (!buffer_write_bytes(buffer, "\x00\x00", 2)) { + goto fail; + } + + position = pymongo_buffer_get_position(buffer); + length = position - lst_len_loc - 1; + buffer_write_int32_at_position(buffer, lst_len_loc, (int32_t)length); + length = position - cmd_len_loc; + buffer_write_int32_at_position(buffer, cmd_len_loc, (int32_t)length); + return 1; + +fail: + Py_XDECREF(doc); + Py_XDECREF(iterator); + return 0; +} + +static PyObject* +_cbson_encode_batched_write_command(PyObject* self, PyObject* args) { + char *ns = NULL; + unsigned char op; + Py_ssize_t ns_len; + PyObject* command; + PyObject* docs; + PyObject* ctx = NULL; + PyObject* to_publish = NULL; + PyObject* result = NULL; + PyObject* options_obj; + codec_options_t options; + buffer_t buffer; + struct module_state *state = GETSTATE(self); + if (!state) { + return NULL; + } + + if (!(PyArg_ParseTuple(args, "et#bOOOO", "utf-8", + &ns, &ns_len, &op, &command, &docs, + &options_obj, &ctx) && + convert_codec_options(state->_cbson, options_obj, &options))) { + return NULL; + } + if (!(buffer = pymongo_buffer_new())) { + PyMem_Free(ns); + destroy_codec_options(&options); + return NULL; + } + if (!(to_publish = PyList_New(0))) { + goto fail; + } + + if (!_batched_write_command( + ns, + ns_len, + op, + command, + docs, + ctx, + to_publish, + options, + buffer, + state)) { + goto fail; + } + + result = Py_BuildValue("y#O", + pymongo_buffer_get_buffer(buffer), + (Py_ssize_t)pymongo_buffer_get_position(buffer), + to_publish); +fail: + PyMem_Free(ns); + destroy_codec_options(&options); + pymongo_buffer_free(buffer); + Py_XDECREF(to_publish); + return result; +} + +static PyMethodDef _CMessageMethods[] = { + {"_query_message", _cbson_query_message, METH_VARARGS, + "create a query message to be sent to MongoDB"}, + {"_get_more_message", _cbson_get_more_message, METH_VARARGS, + "create a get more message to be sent to MongoDB"}, + {"_op_msg", _cbson_op_msg, METH_VARARGS, + "create an OP_MSG message to be sent to MongoDB"}, + {"_encode_batched_write_command", _cbson_encode_batched_write_command, METH_VARARGS, + "Encode the next batched insert, update, or delete command"}, + {"_batched_op_msg", _cbson_batched_op_msg, METH_VARARGS, + "Create the next batched insert, update, or delete using OP_MSG"}, + {"_encode_batched_op_msg", _cbson_encode_batched_op_msg, METH_VARARGS, + "Encode the next batched insert, update, or delete using OP_MSG"}, + {NULL, NULL, 0, NULL} +}; + +#define INITERROR return -1; +static int _cmessage_traverse(PyObject *m, visitproc visit, void *arg) { + struct module_state *state = GETSTATE(m); + if (!state) { + return 0; + } + Py_VISIT(state->_cbson); + Py_VISIT(state->_max_bson_size_str); + Py_VISIT(state->_max_message_size_str); + Py_VISIT(state->_max_split_size_str); + Py_VISIT(state->_max_write_batch_size_str); + return 0; +} + +static int _cmessage_clear(PyObject *m) { + struct module_state *state = GETSTATE(m); + if (!state) { + return 0; + } + Py_CLEAR(state->_cbson); + Py_CLEAR(state->_max_bson_size_str); + Py_CLEAR(state->_max_message_size_str); + Py_CLEAR(state->_max_split_size_str); + Py_CLEAR(state->_max_write_batch_size_str); + return 0; +} + +/* Multi-phase extension module initialization code. + * See https://peps.python.org/pep-0489/. 
+*/ +static int +_cmessage_exec(PyObject *m) +{ + PyObject *_cbson = NULL; + PyObject *c_api_object = NULL; + struct module_state* state = NULL; + + /* Store a reference to the _cbson module since it's needed to call some + * of its functions + */ + _cbson = PyImport_ImportModule("bson._cbson"); + if (_cbson == NULL) { + goto fail; + } + + /* Import C API of _cbson + * The header file accesses _cbson_API to call the functions + */ + c_api_object = PyObject_GetAttrString(_cbson, "_C_API"); + if (c_api_object == NULL) { + goto fail; + } + _cbson_API = (void **)PyCapsule_GetPointer(c_api_object, "_cbson._C_API"); + if (_cbson_API == NULL) { + goto fail; + } + + state = GETSTATE(m); + if (state == NULL) { + goto fail; + } + state->_cbson = _cbson; + if (!((state->_max_bson_size_str = PyUnicode_FromString("max_bson_size")) && + (state->_max_message_size_str = PyUnicode_FromString("max_message_size")) && + (state->_max_write_batch_size_str = PyUnicode_FromString("max_write_batch_size")) && + (state->_max_split_size_str = PyUnicode_FromString("max_split_size")))) { + goto fail; + } + + Py_DECREF(c_api_object); + return 0; + +fail: + Py_XDECREF(m); + Py_XDECREF(c_api_object); + Py_XDECREF(_cbson); + INITERROR; +} + + +static PyModuleDef_Slot _cmessage_slots[] = { + {Py_mod_exec, _cmessage_exec}, +#ifdef Py_MOD_MULTIPLE_INTERPRETERS_SUPPORTED + {Py_mod_multiple_interpreters, Py_MOD_MULTIPLE_INTERPRETERS_SUPPORTED}, +#endif + {0, NULL}, +}; + + +static struct PyModuleDef moduledef = { + PyModuleDef_HEAD_INIT, + "_cmessage", + NULL, + sizeof(struct module_state), + _CMessageMethods, + _cmessage_slots, + _cmessage_traverse, + _cmessage_clear, + NULL +}; + +PyMODINIT_FUNC +PyInit__cmessage(void) +{ + return PyModuleDef_Init(&moduledef); +} diff --git a/venv/Lib/site-packages/pymongo/_csot.py b/venv/Lib/site-packages/pymongo/_csot.py new file mode 100644 index 00000000..194cbad4 --- /dev/null +++ b/venv/Lib/site-packages/pymongo/_csot.py @@ -0,0 +1,153 @@ +# Copyright 2022-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you +# may not use this file except in compliance with the License. You +# may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. 
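The module below implements client-side operation timeout (CSOT) bookkeeping with context variables. The supported entry point is pymongo.timeout, per the _TimeoutContext docstring; a usage sketch with an illustrative connection and namespace:

import pymongo

client = pymongo.MongoClient()  # illustrative default connection
with pymongo.timeout(0.5):
    # Nested timeouts can only shrink the deadline: _TimeoutContext sets
    # DEADLINE to min(previous deadline, now + new timeout).
    with pymongo.timeout(5):
        client.test.test.insert_one({})  # still bounded by the outer 0.5s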
+ +"""Internal helpers for CSOT.""" + +from __future__ import annotations + +import functools +import time +from collections import deque +from contextlib import AbstractContextManager +from contextvars import ContextVar, Token +from typing import TYPE_CHECKING, Any, Callable, Deque, MutableMapping, Optional, TypeVar, cast + +if TYPE_CHECKING: + from pymongo.write_concern import WriteConcern + +TIMEOUT: ContextVar[Optional[float]] = ContextVar("TIMEOUT", default=None) +RTT: ContextVar[float] = ContextVar("RTT", default=0.0) +DEADLINE: ContextVar[float] = ContextVar("DEADLINE", default=float("inf")) + + +def get_timeout() -> Optional[float]: + return TIMEOUT.get(None) + + +def get_rtt() -> float: + return RTT.get() + + +def get_deadline() -> float: + return DEADLINE.get() + + +def set_rtt(rtt: float) -> None: + RTT.set(rtt) + + +def remaining() -> Optional[float]: + if not get_timeout(): + return None + return DEADLINE.get() - time.monotonic() + + +def clamp_remaining(max_timeout: float) -> float: + """Return the remaining timeout clamped to a max value.""" + timeout = remaining() + if timeout is None: + return max_timeout + return min(timeout, max_timeout) + + +class _TimeoutContext(AbstractContextManager): + """Internal timeout context manager. + + Use :func:`pymongo.timeout` instead:: + + with pymongo.timeout(0.5): + client.test.test.insert_one({}) + """ + + def __init__(self, timeout: Optional[float]): + self._timeout = timeout + self._tokens: Optional[tuple[Token[Optional[float]], Token[float], Token[float]]] = None + + def __enter__(self) -> _TimeoutContext: + timeout_token = TIMEOUT.set(self._timeout) + prev_deadline = DEADLINE.get() + next_deadline = time.monotonic() + self._timeout if self._timeout else float("inf") + deadline_token = DEADLINE.set(min(prev_deadline, next_deadline)) + rtt_token = RTT.set(0.0) + self._tokens = (timeout_token, deadline_token, rtt_token) + return self + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + if self._tokens: + timeout_token, deadline_token, rtt_token = self._tokens + TIMEOUT.reset(timeout_token) + DEADLINE.reset(deadline_token) + RTT.reset(rtt_token) + + +# See https://mypy.readthedocs.io/en/stable/generics.html?#decorator-factories +F = TypeVar("F", bound=Callable[..., Any]) + + +def apply(func: F) -> F: + """Apply the client's timeoutMS to this operation.""" + + @functools.wraps(func) + def csot_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: + if get_timeout() is None: + timeout = self._timeout + if timeout is not None: + with _TimeoutContext(timeout): + return func(self, *args, **kwargs) + return func(self, *args, **kwargs) + + return cast(F, csot_wrapper) + + +def apply_write_concern( + cmd: MutableMapping[str, Any], write_concern: Optional[WriteConcern] +) -> None: + """Apply the given write concern to a command.""" + if not write_concern or write_concern.is_server_default: + return + wc = write_concern.document + if get_timeout() is not None: + wc.pop("wtimeout", None) + if wc: + cmd["writeConcern"] = wc + + +_MAX_RTT_SAMPLES: int = 10 +_MIN_RTT_SAMPLES: int = 2 + + +class MovingMinimum: + """Tracks a minimum RTT within the last 10 RTT samples.""" + + samples: Deque[float] + + def __init__(self) -> None: + self.samples = deque(maxlen=_MAX_RTT_SAMPLES) + + def add_sample(self, sample: float) -> None: + if sample < 0: + # Likely system time change while waiting for hello response + # and not using time.monotonic. Ignore it, the next one will + # probably be valid. 
+ return + self.samples.append(sample) + + def get(self) -> float: + """Get the min, or 0.0 if there aren't enough samples yet.""" + if len(self.samples) >= _MIN_RTT_SAMPLES: + return min(self.samples) + return 0.0 + + def reset(self) -> None: + self.samples.clear() diff --git a/venv/Lib/site-packages/pymongo/_gcp_helpers.py b/venv/Lib/site-packages/pymongo/_gcp_helpers.py new file mode 100644 index 00000000..46f02ba1 --- /dev/null +++ b/venv/Lib/site-packages/pymongo/_gcp_helpers.py @@ -0,0 +1,39 @@ +# Copyright 2024-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""GCP helpers.""" +from __future__ import annotations + +from typing import Any +from urllib.request import Request, urlopen + + +def _get_gcp_response(resource: str, timeout: float = 5) -> dict[str, Any]: + url = "http://metadata/computeMetadata/v1/instance/service-accounts/default/identity" + url += f"?audience={resource}" + headers = {"Metadata-Flavor": "Google"} + request = Request(url, headers=headers) # noqa: S310 + try: + with urlopen(request, timeout=timeout) as response: # noqa: S310 + status = response.status + body = response.read().decode("utf8") + except Exception as e: + msg = "Failed to acquire IMDS access token: %s" % e + raise ValueError(msg) from None + + if status != 200: + msg = "Failed to acquire IMDS access token." + raise ValueError(msg) + + return dict(access_token=body) diff --git a/venv/Lib/site-packages/pymongo/_lazy_import.py b/venv/Lib/site-packages/pymongo/_lazy_import.py new file mode 100644 index 00000000..888339d0 --- /dev/null +++ b/venv/Lib/site-packages/pymongo/_lazy_import.py @@ -0,0 +1,43 @@ +# Copyright 2024-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you +# may not use this file except in compliance with the License. You +# may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. +from __future__ import annotations + +import importlib.util +import sys +from types import ModuleType + + +def lazy_import(name: str) -> ModuleType: + """Lazily import a module by name + + From https://docs.python.org/3/library/importlib.html#implementing-lazy-imports + """ + # Workaround for PYTHON-4424. + if "__compiled__" in globals(): + return importlib.import_module(name) + try: + spec = importlib.util.find_spec(name) + except ValueError: + # Note: this cannot be ModuleNotFoundError, see PYTHON-4424. + raise ImportError(name=name) from None + if spec is None: + # Note: this cannot be ModuleNotFoundError, see PYTHON-4424. 
+ raise ImportError(name=name) + assert spec is not None + loader = importlib.util.LazyLoader(spec.loader) # type:ignore[arg-type] + spec.loader = loader + module = importlib.util.module_from_spec(spec) + sys.modules[name] = module + loader.exec_module(module) + return module diff --git a/venv/Lib/site-packages/pymongo/_version.py b/venv/Lib/site-packages/pymongo/_version.py new file mode 100644 index 00000000..65caa084 --- /dev/null +++ b/venv/Lib/site-packages/pymongo/_version.py @@ -0,0 +1,30 @@ +# Copyright 2022-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Current version of PyMongo.""" +from __future__ import annotations + +from typing import Tuple, Union + +version_tuple: Tuple[Union[int, str], ...] = (4, 7, 2) + + +def get_version_string() -> str: + if isinstance(version_tuple[-1], str): + return ".".join(map(str, version_tuple[:-1])) + version_tuple[-1] + return ".".join(map(str, version_tuple)) + + +__version__: str = get_version_string() +version = __version__ diff --git a/venv/Lib/site-packages/pymongo/aggregation.py b/venv/Lib/site-packages/pymongo/aggregation.py new file mode 100644 index 00000000..574db10a --- /dev/null +++ b/venv/Lib/site-packages/pymongo/aggregation.py @@ -0,0 +1,255 @@ +# Copyright 2019-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you +# may not use this file except in compliance with the License. You +# may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. + +"""Perform aggregation operations on a collection or database.""" +from __future__ import annotations + +from collections.abc import Callable, Mapping, MutableMapping +from typing import TYPE_CHECKING, Any, Optional, Union + +from pymongo import common +from pymongo.collation import validate_collation_or_none +from pymongo.errors import ConfigurationError +from pymongo.read_preferences import ReadPreference, _AggWritePref + +if TYPE_CHECKING: + from pymongo.client_session import ClientSession + from pymongo.collection import Collection + from pymongo.command_cursor import CommandCursor + from pymongo.database import Database + from pymongo.pool import Connection + from pymongo.read_preferences import _ServerMode + from pymongo.server import Server + from pymongo.typings import _DocumentType, _Pipeline + + +class _AggregationCommand: + """The internal abstract base class for aggregation cursors. + + Should not be called directly by application developers. Use + :meth:`pymongo.collection.Collection.aggregate`, or + :meth:`pymongo.database.Database.aggregate` instead. 
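For orientation, a typical application call that is routed through a subclass of this helper; the collection name and pipeline are illustrative:

pipeline = [
    {"$match": {"status": "A"}},
    {"$group": {"_id": "$cust_id", "total": {"$sum": "$amount"}}},
]
for doc in client.db.orders.aggregate(pipeline, batchSize=100):
    print(doc)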
+ """ + + def __init__( + self, + target: Union[Database, Collection], + cursor_class: type[CommandCursor], + pipeline: _Pipeline, + options: MutableMapping[str, Any], + explicit_session: bool, + let: Optional[Mapping[str, Any]] = None, + user_fields: Optional[MutableMapping[str, Any]] = None, + result_processor: Optional[Callable[[Mapping[str, Any], Connection], None]] = None, + comment: Any = None, + ) -> None: + if "explain" in options: + raise ConfigurationError( + "The explain option is not supported. Use Database.command instead." + ) + + self._target = target + + pipeline = common.validate_list("pipeline", pipeline) + self._pipeline = pipeline + self._performs_write = False + if pipeline and ("$out" in pipeline[-1] or "$merge" in pipeline[-1]): + self._performs_write = True + + common.validate_is_mapping("options", options) + if let is not None: + common.validate_is_mapping("let", let) + options["let"] = let + if comment is not None: + options["comment"] = comment + + self._options = options + + # This is the batchSize that will be used for setting the initial + # batchSize for the cursor, as well as the subsequent getMores. + self._batch_size = common.validate_non_negative_integer_or_none( + "batchSize", self._options.pop("batchSize", None) + ) + + # If the cursor option is already specified, avoid overriding it. + self._options.setdefault("cursor", {}) + # If the pipeline performs a write, we ignore the initial batchSize + # since the server doesn't return results in this case. + if self._batch_size is not None and not self._performs_write: + self._options["cursor"]["batchSize"] = self._batch_size + + self._cursor_class = cursor_class + self._explicit_session = explicit_session + self._user_fields = user_fields + self._result_processor = result_processor + + self._collation = validate_collation_or_none(options.pop("collation", None)) + + self._max_await_time_ms = options.pop("maxAwaitTimeMS", None) + self._write_preference: Optional[_AggWritePref] = None + + @property + def _aggregation_target(self) -> Union[str, int]: + """The argument to pass to the aggregate command.""" + raise NotImplementedError + + @property + def _cursor_namespace(self) -> str: + """The namespace in which the aggregate command is run.""" + raise NotImplementedError + + def _cursor_collection(self, cursor_doc: Mapping[str, Any]) -> Collection: + """The Collection used for the aggregate command cursor.""" + raise NotImplementedError + + @property + def _database(self) -> Database: + """The database against which the aggregation command is run.""" + raise NotImplementedError + + def get_read_preference( + self, session: Optional[ClientSession] + ) -> Union[_AggWritePref, _ServerMode]: + if self._write_preference: + return self._write_preference + pref = self._target._read_preference_for(session) + if self._performs_write and pref != ReadPreference.PRIMARY: + self._write_preference = pref = _AggWritePref(pref) # type: ignore[assignment] + return pref + + def get_cursor( + self, + session: Optional[ClientSession], + server: Server, + conn: Connection, + read_preference: _ServerMode, + ) -> CommandCursor[_DocumentType]: + # Serialize command. 
+ cmd = {"aggregate": self._aggregation_target, "pipeline": self._pipeline} + cmd.update(self._options) + + # Apply this target's read concern if: + # readConcern has not been specified as a kwarg and either + # - server version is >= 4.2 or + # - server version is >= 3.2 and pipeline doesn't use $out + if ("readConcern" not in cmd) and ( + not self._performs_write or (conn.max_wire_version >= 8) + ): + read_concern = self._target.read_concern + else: + read_concern = None + + # Apply this target's write concern if: + # writeConcern has not been specified as a kwarg and pipeline doesn't + # perform a write operation + if "writeConcern" not in cmd and self._performs_write: + write_concern = self._target._write_concern_for(session) + else: + write_concern = None + + # Run command. + result = conn.command( + self._database.name, + cmd, + read_preference, + self._target.codec_options, + parse_write_concern_error=True, + read_concern=read_concern, + write_concern=write_concern, + collation=self._collation, + session=session, + client=self._database.client, + user_fields=self._user_fields, + ) + + if self._result_processor: + self._result_processor(result, conn) + + # Extract cursor from result or mock/fake one if necessary. + if "cursor" in result: + cursor = result["cursor"] + else: + # Unacknowledged $out/$merge write. Fake a cursor. + cursor = { + "id": 0, + "firstBatch": result.get("result", []), + "ns": self._cursor_namespace, + } + + # Create and return cursor instance. + cmd_cursor = self._cursor_class( + self._cursor_collection(cursor), + cursor, + conn.address, + batch_size=self._batch_size or 0, + max_await_time_ms=self._max_await_time_ms, + session=session, + explicit_session=self._explicit_session, + comment=self._options.get("comment"), + ) + cmd_cursor._maybe_pin_connection(conn) + return cmd_cursor + + +class _CollectionAggregationCommand(_AggregationCommand): + _target: Collection + + @property + def _aggregation_target(self) -> str: + return self._target.name + + @property + def _cursor_namespace(self) -> str: + return self._target.full_name + + def _cursor_collection(self, cursor: Mapping[str, Any]) -> Collection: + """The Collection used for the aggregate command cursor.""" + return self._target + + @property + def _database(self) -> Database: + return self._target.database + + +class _CollectionRawAggregationCommand(_CollectionAggregationCommand): + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + # For raw-batches, we set the initial batchSize for the cursor to 0. + if not self._performs_write: + self._options["cursor"]["batchSize"] = 0 + + +class _DatabaseAggregationCommand(_AggregationCommand): + _target: Database + + @property + def _aggregation_target(self) -> int: + return 1 + + @property + def _cursor_namespace(self) -> str: + return f"{self._target.name}.$cmd.aggregate" + + @property + def _database(self) -> Database: + return self._target + + def _cursor_collection(self, cursor: Mapping[str, Any]) -> Collection: + """The Collection used for the aggregate command cursor.""" + # Collection level aggregate may not always return the "ns" field + # according to our MockupDB tests. Let's handle that case for db level + # aggregate too by defaulting to the .$cmd.aggregate namespace. 
+ _, collname = cursor.get("ns", self._cursor_namespace).split(".", 1) + return self._database[collname] diff --git a/venv/Lib/site-packages/pymongo/auth.py b/venv/Lib/site-packages/pymongo/auth.py new file mode 100644 index 00000000..8bc4145a --- /dev/null +++ b/venv/Lib/site-packages/pymongo/auth.py @@ -0,0 +1,656 @@ +# Copyright 2013-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Authentication helpers.""" +from __future__ import annotations + +import functools +import hashlib +import hmac +import os +import socket +import typing +from base64 import standard_b64decode, standard_b64encode +from collections import namedtuple +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Mapping, + MutableMapping, + Optional, + cast, +) +from urllib.parse import quote + +from bson.binary import Binary +from pymongo.auth_aws import _authenticate_aws +from pymongo.auth_oidc import ( + _authenticate_oidc, + _get_authenticator, + _OIDCAzureCallback, + _OIDCGCPCallback, + _OIDCProperties, + _OIDCTestCallback, +) +from pymongo.errors import ConfigurationError, OperationFailure +from pymongo.saslprep import saslprep + +if TYPE_CHECKING: + from pymongo.hello import Hello + from pymongo.pool import Connection + +HAVE_KERBEROS = True +_USE_PRINCIPAL = False +try: + import winkerberos as kerberos # type:ignore[import] + + if tuple(map(int, kerberos.__version__.split(".")[:2])) >= (0, 5): + _USE_PRINCIPAL = True +except ImportError: + try: + import kerberos # type:ignore[import] + except ImportError: + HAVE_KERBEROS = False + + +MECHANISMS = frozenset( + [ + "GSSAPI", + "MONGODB-CR", + "MONGODB-OIDC", + "MONGODB-X509", + "MONGODB-AWS", + "PLAIN", + "SCRAM-SHA-1", + "SCRAM-SHA-256", + "DEFAULT", + ] +) +"""The authentication mechanisms supported by PyMongo.""" + + +class _Cache: + __slots__ = ("data",) + + _hash_val = hash("_Cache") + + def __init__(self) -> None: + self.data = None + + def __eq__(self, other: object) -> bool: + # Two instances must always compare equal. 
+ if isinstance(other, _Cache): + return True + return NotImplemented + + def __ne__(self, other: object) -> bool: + if isinstance(other, _Cache): + return False + return NotImplemented + + def __hash__(self) -> int: + return self._hash_val + + +MongoCredential = namedtuple( + "MongoCredential", + ["mechanism", "source", "username", "password", "mechanism_properties", "cache"], +) +"""A hashable namedtuple of values used for authentication.""" + + +GSSAPIProperties = namedtuple( + "GSSAPIProperties", ["service_name", "canonicalize_host_name", "service_realm"] +) +"""Mechanism properties for GSSAPI authentication.""" + + +_AWSProperties = namedtuple("_AWSProperties", ["aws_session_token"]) +"""Mechanism properties for MONGODB-AWS authentication.""" + + +def _build_credentials_tuple( + mech: str, + source: Optional[str], + user: str, + passwd: str, + extra: Mapping[str, Any], + database: Optional[str], +) -> MongoCredential: + """Build and return a mechanism specific credentials tuple.""" + if mech not in ("MONGODB-X509", "MONGODB-AWS", "MONGODB-OIDC") and user is None: + raise ConfigurationError(f"{mech} requires a username.") + if mech == "GSSAPI": + if source is not None and source != "$external": + raise ValueError("authentication source must be $external or None for GSSAPI") + properties = extra.get("authmechanismproperties", {}) + service_name = properties.get("SERVICE_NAME", "mongodb") + canonicalize = bool(properties.get("CANONICALIZE_HOST_NAME", False)) + service_realm = properties.get("SERVICE_REALM") + props = GSSAPIProperties( + service_name=service_name, + canonicalize_host_name=canonicalize, + service_realm=service_realm, + ) + # Source is always $external. + return MongoCredential(mech, "$external", user, passwd, props, None) + elif mech == "MONGODB-X509": + if passwd is not None: + raise ConfigurationError("Passwords are not supported by MONGODB-X509") + if source is not None and source != "$external": + raise ValueError("authentication source must be $external or None for MONGODB-X509") + # Source is always $external, user can be None. + return MongoCredential(mech, "$external", user, None, None, None) + elif mech == "MONGODB-AWS": + if user is not None and passwd is None: + raise ConfigurationError("username without a password is not supported by MONGODB-AWS") + if source is not None and source != "$external": + raise ConfigurationError( + "authentication source must be $external or None for MONGODB-AWS" + ) + + properties = extra.get("authmechanismproperties", {}) + aws_session_token = properties.get("AWS_SESSION_TOKEN") + aws_props = _AWSProperties(aws_session_token=aws_session_token) + # user can be None for temporary link-local EC2 credentials. 
+ return MongoCredential(mech, "$external", user, passwd, aws_props, None) + elif mech == "MONGODB-OIDC": + properties = extra.get("authmechanismproperties", {}) + callback = properties.get("OIDC_CALLBACK") + human_callback = properties.get("OIDC_HUMAN_CALLBACK") + environ = properties.get("ENVIRONMENT") + token_resource = properties.get("TOKEN_RESOURCE", "") + default_allowed = [ + "*.mongodb.net", + "*.mongodb-dev.net", + "*.mongodb-qa.net", + "*.mongodbgov.net", + "localhost", + "127.0.0.1", + "::1", + ] + allowed_hosts = properties.get("ALLOWED_HOSTS", default_allowed) + msg = ( + "authentication with MONGODB-OIDC requires providing either a callback or an environment" + ) + if passwd is not None: + msg = "password is not supported by MONGODB-OIDC" + raise ConfigurationError(msg) + if callback or human_callback: + if environ is not None: + raise ConfigurationError(msg) + if callback and human_callback: + msg = "cannot set both OIDC_CALLBACK and OIDC_HUMAN_CALLBACK" + raise ConfigurationError(msg) + elif environ is not None: + if environ == "test": + if user is not None: + msg = "test environment for MONGODB-OIDC does not support username" + raise ConfigurationError(msg) + callback = _OIDCTestCallback() + elif environ == "azure": + passwd = None + if not token_resource: + raise ConfigurationError( + "Azure environment for MONGODB-OIDC requires a TOKEN_RESOURCE auth mechanism property" + ) + callback = _OIDCAzureCallback(token_resource) + elif environ == "gcp": + passwd = None + if not token_resource: + raise ConfigurationError( + "GCP provider for MONGODB-OIDC requires a TOKEN_RESOURCE auth mechanism property" + ) + callback = _OIDCGCPCallback(token_resource) + else: + raise ConfigurationError(f"unrecognized ENVIRONMENT for MONGODB-OIDC: {environ}") + else: + raise ConfigurationError(msg) + + oidc_props = _OIDCProperties( + callback=callback, + human_callback=human_callback, + environment=environ, + allowed_hosts=allowed_hosts, + token_resource=token_resource, + username=user, + ) + return MongoCredential(mech, "$external", user, passwd, oidc_props, _Cache()) + + elif mech == "PLAIN": + source_database = source or database or "$external" + return MongoCredential(mech, source_database, user, passwd, None, None) + else: + source_database = source or database or "admin" + if passwd is None: + raise ConfigurationError("A password is required.") + return MongoCredential(mech, source_database, user, passwd, None, _Cache()) + + +def _xor(fir: bytes, sec: bytes) -> bytes: + """XOR two byte strings together.""" + return b"".join([bytes([x ^ y]) for x, y in zip(fir, sec)]) + + +def _parse_scram_response(response: bytes) -> Dict[bytes, bytes]: + """Split a scram response into key, value pairs.""" + return dict( + typing.cast(typing.Tuple[bytes, bytes], item.split(b"=", 1)) + for item in response.split(b",") + ) + + +def _authenticate_scram_start( + credentials: MongoCredential, mechanism: str +) -> tuple[bytes, bytes, MutableMapping[str, Any]]: + username = credentials.username + user = username.encode("utf-8").replace(b"=", b"=3D").replace(b",", b"=2C") + nonce = standard_b64encode(os.urandom(32)) + first_bare = b"n=" + user + b",r=" + nonce + + cmd = { + "saslStart": 1, + "mechanism": mechanism, + "payload": Binary(b"n,," + first_bare), + "autoAuthorize": 1, + "options": {"skipEmptyExchange": True}, + } + return nonce, first_bare, cmd + + +def _authenticate_scram(credentials: MongoCredential, conn: Connection, mechanism: str) -> None: + """Authenticate using SCRAM.""" + username =
credentials.username + if mechanism == "SCRAM-SHA-256": + digest = "sha256" + digestmod = hashlib.sha256 + data = saslprep(credentials.password).encode("utf-8") + else: + digest = "sha1" + digestmod = hashlib.sha1 + data = _password_digest(username, credentials.password).encode("utf-8") + source = credentials.source + cache = credentials.cache + + # Make local + _hmac = hmac.HMAC + + ctx = conn.auth_ctx + if ctx and ctx.speculate_succeeded(): + assert isinstance(ctx, _ScramContext) + assert ctx.scram_data is not None + nonce, first_bare = ctx.scram_data + res = ctx.speculative_authenticate + else: + nonce, first_bare, cmd = _authenticate_scram_start(credentials, mechanism) + res = conn.command(source, cmd) + + assert res is not None + server_first = res["payload"] + parsed = _parse_scram_response(server_first) + iterations = int(parsed[b"i"]) + if iterations < 4096: + raise OperationFailure("Server returned an invalid iteration count.") + salt = parsed[b"s"] + rnonce = parsed[b"r"] + if not rnonce.startswith(nonce): + raise OperationFailure("Server returned an invalid nonce.") + + without_proof = b"c=biws,r=" + rnonce + if cache.data: + client_key, server_key, csalt, citerations = cache.data + else: + client_key, server_key, csalt, citerations = None, None, None, None + + # Salt and / or iterations could change for a number of different + # reasons. Either changing invalidates the cache. + if not client_key or salt != csalt or iterations != citerations: + salted_pass = hashlib.pbkdf2_hmac(digest, data, standard_b64decode(salt), iterations) + client_key = _hmac(salted_pass, b"Client Key", digestmod).digest() + server_key = _hmac(salted_pass, b"Server Key", digestmod).digest() + cache.data = (client_key, server_key, salt, iterations) + stored_key = digestmod(client_key).digest() + auth_msg = b",".join((first_bare, server_first, without_proof)) + client_sig = _hmac(stored_key, auth_msg, digestmod).digest() + client_proof = b"p=" + standard_b64encode(_xor(client_key, client_sig)) + client_final = b",".join((without_proof, client_proof)) + + server_sig = standard_b64encode(_hmac(server_key, auth_msg, digestmod).digest()) + + cmd = { + "saslContinue": 1, + "conversationId": res["conversationId"], + "payload": Binary(client_final), + } + res = conn.command(source, cmd) + + parsed = _parse_scram_response(res["payload"]) + if not hmac.compare_digest(parsed[b"v"], server_sig): + raise OperationFailure("Server returned an invalid signature.") + + # A third empty challenge may be required if the server does not support + # skipEmptyExchange: SERVER-44857. 
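As a self-contained recap of the proof derivation performed above (RFC 5802 / RFC 7677): the salted password comes from PBKDF2, the client key and signature from HMAC, and the proof is their XOR, base64-encoded for the "p=" field. The function name and argument shapes here are ours; the inputs mirror the values the driver computes:

import hashlib
import hmac
from base64 import standard_b64decode, standard_b64encode

def scram_sha256_proof(data: bytes, salt_b64: bytes, iterations: int,
                       auth_msg: bytes) -> bytes:
    # SaltedPassword := PBKDF2-HMAC(password, base64-decoded salt, i)
    salted = hashlib.pbkdf2_hmac("sha256", data, standard_b64decode(salt_b64), iterations)
    client_key = hmac.new(salted, b"Client Key", hashlib.sha256).digest()
    stored_key = hashlib.sha256(client_key).digest()
    client_sig = hmac.new(stored_key, auth_msg, hashlib.sha256).digest()
    # ClientProof := ClientKey XOR ClientSignature, base64-encoded for "p="
    return standard_b64encode(bytes(x ^ y for x, y in zip(client_key, client_sig)))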
+ if not res["done"]: + cmd = { + "saslContinue": 1, + "conversationId": res["conversationId"], + "payload": Binary(b""), + } + res = conn.command(source, cmd) + if not res["done"]: + raise OperationFailure("SASL conversation failed to complete.") + + +def _password_digest(username: str, password: str) -> str: + """Get a password digest to use for authentication.""" + if not isinstance(password, str): + raise TypeError("password must be an instance of str") + if len(password) == 0: + raise ValueError("password can't be empty") + if not isinstance(username, str): + raise TypeError("username must be an instance of str") + + md5hash = hashlib.md5() # noqa: S324 + data = f"{username}:mongo:{password}" + md5hash.update(data.encode("utf-8")) + return md5hash.hexdigest() + + +def _auth_key(nonce: str, username: str, password: str) -> str: + """Get an auth key to use for authentication.""" + digest = _password_digest(username, password) + md5hash = hashlib.md5() # noqa: S324 + data = f"{nonce}{username}{digest}" + md5hash.update(data.encode("utf-8")) + return md5hash.hexdigest() + + +def _canonicalize_hostname(hostname: str) -> str: + """Canonicalize hostname following MIT-krb5 behavior.""" + # https://github.com/krb5/krb5/blob/d406afa363554097ac48646a29249c04f498c88e/src/util/k5test.py#L505-L520 + af, socktype, proto, canonname, sockaddr = socket.getaddrinfo( + hostname, None, 0, 0, socket.IPPROTO_TCP, socket.AI_CANONNAME + )[0] + + try: + name = socket.getnameinfo(sockaddr, socket.NI_NAMEREQD) + except socket.gaierror: + return canonname.lower() + + return name[0].lower() + + +def _authenticate_gssapi(credentials: MongoCredential, conn: Connection) -> None: + """Authenticate using GSSAPI.""" + if not HAVE_KERBEROS: + raise ConfigurationError( + 'The "kerberos" module must be installed to use GSSAPI authentication.' + ) + + try: + username = credentials.username + password = credentials.password + props = credentials.mechanism_properties + # Starting here and continuing through the while loop below - establish + # the security context. See RFC 4752, Section 3.1, first paragraph. + host = conn.address[0] + if props.canonicalize_host_name: + host = _canonicalize_hostname(host) + service = props.service_name + "@" + host + if props.service_realm is not None: + service = service + "@" + props.service_realm + + if password is not None: + if _USE_PRINCIPAL: + # Note that, though we use unquote_plus for unquoting URI + # options, we use quote here. Microsoft's UrlUnescape (used + # by WinKerberos) doesn't support +. + principal = ":".join((quote(username), quote(password))) + result, ctx = kerberos.authGSSClientInit( + service, principal, gssflags=kerberos.GSS_C_MUTUAL_FLAG + ) + else: + if "@" in username: + user, domain = username.split("@", 1) + else: + user, domain = username, None + result, ctx = kerberos.authGSSClientInit( + service, + gssflags=kerberos.GSS_C_MUTUAL_FLAG, + user=user, + domain=domain, + password=password, + ) + else: + result, ctx = kerberos.authGSSClientInit(service, gssflags=kerberos.GSS_C_MUTUAL_FLAG) + + if result != kerberos.AUTH_GSS_COMPLETE: + raise OperationFailure("Kerberos context failed to initialize.") + + try: + # pykerberos uses a weird mix of exceptions and return values + # to indicate errors. + # 0 == continue, 1 == complete, -1 == error + # Only authGSSClientStep can return 0. 
+ if kerberos.authGSSClientStep(ctx, "") != 0: + raise OperationFailure("Unknown kerberos failure in step function.") + + # Start a SASL conversation with mongod/s + # Note: pykerberos deals with base64 encoded byte strings. + # Since mongo accepts base64 strings as the payload we don't + # have to use bson.binary.Binary. + payload = kerberos.authGSSClientResponse(ctx) + cmd = { + "saslStart": 1, + "mechanism": "GSSAPI", + "payload": payload, + "autoAuthorize": 1, + } + response = conn.command("$external", cmd) + + # Limit how many times we loop to catch protocol / library issues + for _ in range(10): + result = kerberos.authGSSClientStep(ctx, str(response["payload"])) + if result == -1: + raise OperationFailure("Unknown kerberos failure in step function.") + + payload = kerberos.authGSSClientResponse(ctx) or "" + + cmd = { + "saslContinue": 1, + "conversationId": response["conversationId"], + "payload": payload, + } + response = conn.command("$external", cmd) + + if result == kerberos.AUTH_GSS_COMPLETE: + break + else: + raise OperationFailure("Kerberos authentication failed to complete.") + + # Once the security context is established actually authenticate. + # See RFC 4752, Section 3.1, last two paragraphs. + if kerberos.authGSSClientUnwrap(ctx, str(response["payload"])) != 1: + raise OperationFailure("Unknown kerberos failure during GSS_Unwrap step.") + + if kerberos.authGSSClientWrap(ctx, kerberos.authGSSClientResponse(ctx), username) != 1: + raise OperationFailure("Unknown kerberos failure during GSS_Wrap step.") + + payload = kerberos.authGSSClientResponse(ctx) + cmd = { + "saslContinue": 1, + "conversationId": response["conversationId"], + "payload": payload, + } + conn.command("$external", cmd) + + finally: + kerberos.authGSSClientClean(ctx) + + except kerberos.KrbError as exc: + raise OperationFailure(str(exc)) from None + + +def _authenticate_plain(credentials: MongoCredential, conn: Connection) -> None: + """Authenticate using SASL PLAIN (RFC 4616)""" + source = credentials.source + username = credentials.username + password = credentials.password + payload = (f"\x00{username}\x00{password}").encode() + cmd = { + "saslStart": 1, + "mechanism": "PLAIN", + "payload": Binary(payload), + "autoAuthorize": 1, + } + conn.command(source, cmd) + + +def _authenticate_x509(credentials: MongoCredential, conn: Connection) -> None: + """Authenticate using MONGODB-X509.""" + ctx = conn.auth_ctx + if ctx and ctx.speculate_succeeded(): + # MONGODB-X509 is done after the speculative auth step. + return + + cmd = _X509Context(credentials, conn.address).speculate_command() + conn.command("$external", cmd) + + +def _authenticate_mongo_cr(credentials: MongoCredential, conn: Connection) -> None: + """Authenticate using MONGODB-CR.""" + source = credentials.source + username = credentials.username + password = credentials.password + # Get a nonce + response = conn.command(source, {"getnonce": 1}) + nonce = response["nonce"] + key = _auth_key(nonce, username, password) + + # Actually authenticate + query = {"authenticate": 1, "user": username, "nonce": nonce, "key": key} + conn.command(source, query) + + +def _authenticate_default(credentials: MongoCredential, conn: Connection) -> None: + if conn.max_wire_version >= 7: + if conn.negotiated_mechs: + mechs = conn.negotiated_mechs + else: + source = credentials.source + cmd = conn.hello_cmd() + cmd["saslSupportedMechs"] = source + "." 
+ credentials.username + mechs = conn.command(source, cmd, publish_events=False).get("saslSupportedMechs", []) + if "SCRAM-SHA-256" in mechs: + return _authenticate_scram(credentials, conn, "SCRAM-SHA-256") + else: + return _authenticate_scram(credentials, conn, "SCRAM-SHA-1") + else: + return _authenticate_scram(credentials, conn, "SCRAM-SHA-1") + + +_AUTH_MAP: Mapping[str, Callable[..., None]] = { + "GSSAPI": _authenticate_gssapi, + "MONGODB-CR": _authenticate_mongo_cr, + "MONGODB-X509": _authenticate_x509, + "MONGODB-AWS": _authenticate_aws, + "MONGODB-OIDC": _authenticate_oidc, # type:ignore[dict-item] + "PLAIN": _authenticate_plain, + "SCRAM-SHA-1": functools.partial(_authenticate_scram, mechanism="SCRAM-SHA-1"), + "SCRAM-SHA-256": functools.partial(_authenticate_scram, mechanism="SCRAM-SHA-256"), + "DEFAULT": _authenticate_default, +} + + +class _AuthContext: + def __init__(self, credentials: MongoCredential, address: tuple[str, int]) -> None: + self.credentials = credentials + self.speculative_authenticate: Optional[Mapping[str, Any]] = None + self.address = address + + @staticmethod + def from_credentials( + creds: MongoCredential, address: tuple[str, int] + ) -> Optional[_AuthContext]: + spec_cls = _SPECULATIVE_AUTH_MAP.get(creds.mechanism) + if spec_cls: + return cast(_AuthContext, spec_cls(creds, address)) + return None + + def speculate_command(self) -> Optional[MutableMapping[str, Any]]: + raise NotImplementedError + + def parse_response(self, hello: Hello[Mapping[str, Any]]) -> None: + self.speculative_authenticate = hello.speculative_authenticate + + def speculate_succeeded(self) -> bool: + return bool(self.speculative_authenticate) + + +class _ScramContext(_AuthContext): + def __init__( + self, credentials: MongoCredential, address: tuple[str, int], mechanism: str + ) -> None: + super().__init__(credentials, address) + self.scram_data: Optional[tuple[bytes, bytes]] = None + self.mechanism = mechanism + + def speculate_command(self) -> Optional[MutableMapping[str, Any]]: + nonce, first_bare, cmd = _authenticate_scram_start(self.credentials, self.mechanism) + # The 'db' field is included only on the speculative command. + cmd["db"] = self.credentials.source + # Save for later use. 
+ self.scram_data = (nonce, first_bare) + return cmd + + +class _X509Context(_AuthContext): + def speculate_command(self) -> MutableMapping[str, Any]: + cmd = {"authenticate": 1, "mechanism": "MONGODB-X509"} + if self.credentials.username is not None: + cmd["user"] = self.credentials.username + return cmd + + +class _OIDCContext(_AuthContext): + def speculate_command(self) -> Optional[MutableMapping[str, Any]]: + authenticator = _get_authenticator(self.credentials, self.address) + cmd = authenticator.get_spec_auth_cmd() + if cmd is None: + return None + cmd["db"] = self.credentials.source + return cmd + + +_SPECULATIVE_AUTH_MAP: Mapping[str, Any] = { + "MONGODB-X509": _X509Context, + "SCRAM-SHA-1": functools.partial(_ScramContext, mechanism="SCRAM-SHA-1"), + "SCRAM-SHA-256": functools.partial(_ScramContext, mechanism="SCRAM-SHA-256"), + "MONGODB-OIDC": _OIDCContext, + "DEFAULT": functools.partial(_ScramContext, mechanism="SCRAM-SHA-256"), +} + + +def authenticate( + credentials: MongoCredential, conn: Connection, reauthenticate: bool = False +) -> None: + """Authenticate connection.""" + mechanism = credentials.mechanism + auth_func = _AUTH_MAP[mechanism] + if mechanism == "MONGODB-OIDC": + _authenticate_oidc(credentials, conn, reauthenticate) + else: + auth_func(credentials, conn) diff --git a/venv/Lib/site-packages/pymongo/auth_aws.py b/venv/Lib/site-packages/pymongo/auth_aws.py new file mode 100644 index 00000000..0d253cea --- /dev/null +++ b/venv/Lib/site-packages/pymongo/auth_aws.py @@ -0,0 +1,106 @@ +# Copyright 2020-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""MONGODB-AWS Authentication helpers.""" +from __future__ import annotations + +from pymongo._lazy_import import lazy_import + +try: + pymongo_auth_aws = lazy_import("pymongo_auth_aws") + _HAVE_MONGODB_AWS = True +except ImportError: + _HAVE_MONGODB_AWS = False + + +from typing import TYPE_CHECKING, Any, Mapping, Type + +import bson +from bson.binary import Binary +from pymongo.errors import ConfigurationError, OperationFailure + +if TYPE_CHECKING: + from bson.typings import _ReadableBuffer + from pymongo.auth import MongoCredential + from pymongo.pool import Connection + + +def _authenticate_aws(credentials: MongoCredential, conn: Connection) -> None: + """Authenticate using MONGODB-AWS.""" + if not _HAVE_MONGODB_AWS: + raise ConfigurationError( + "MONGODB-AWS authentication requires pymongo-auth-aws: " + "install with: python -m pip install 'pymongo[aws]'" + ) + + # Delayed import. 
+ from pymongo_auth_aws.auth import ( # type:ignore[import] + set_cached_credentials, + set_use_cached_credentials, + ) + + set_use_cached_credentials(True) + + if conn.max_wire_version < 9: + raise ConfigurationError("MONGODB-AWS authentication requires MongoDB version 4.4 or later") + + class AwsSaslContext(pymongo_auth_aws.AwsSaslContext): # type: ignore + # Dependency injection: + def binary_type(self) -> Type[Binary]: + """Return the bson.binary.Binary type.""" + return Binary + + def bson_encode(self, doc: Mapping[str, Any]) -> bytes: + """Encode a dictionary to BSON.""" + return bson.encode(doc) + + def bson_decode(self, data: _ReadableBuffer) -> Mapping[str, Any]: + """Decode BSON to a dictionary.""" + return bson.decode(data) + + try: + ctx = AwsSaslContext( + pymongo_auth_aws.AwsCredential( + credentials.username, + credentials.password, + credentials.mechanism_properties.aws_session_token, + ) + ) + client_payload = ctx.step(None) + client_first = {"saslStart": 1, "mechanism": "MONGODB-AWS", "payload": client_payload} + server_first = conn.command("$external", client_first) + res = server_first + # Limit how many times we loop to catch protocol / library issues + for _ in range(10): + client_payload = ctx.step(res["payload"]) + cmd = { + "saslContinue": 1, + "conversationId": server_first["conversationId"], + "payload": client_payload, + } + res = conn.command("$external", cmd) + if res["done"]: + # SASL complete. + break + except pymongo_auth_aws.PyMongoAuthAwsError as exc: + # Clear the cached credentials if we hit a failure in auth. + set_cached_credentials(None) + # Convert to OperationFailure and include pymongo-auth-aws version. + raise OperationFailure( + f"{exc} (pymongo-auth-aws version {pymongo_auth_aws.__version__})" + ) from None + except Exception: + # Clear the cached credentials if we hit a failure in auth. + set_cached_credentials(None) + raise diff --git a/venv/Lib/site-packages/pymongo/auth_oidc.py b/venv/Lib/site-packages/pymongo/auth_oidc.py new file mode 100644 index 00000000..bfe2340f --- /dev/null +++ b/venv/Lib/site-packages/pymongo/auth_oidc.py @@ -0,0 +1,365 @@ +# Copyright 2023-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
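
The module that follows defines the MONGODB-OIDC callback interface. As orientation, here is a minimal, hypothetical usage sketch, not part of the vendored file: the token file path and server address are placeholders, and `OIDC_CALLBACK` is the mechanism-property name documented for recent PyMongo releases.

```python
from pymongo import MongoClient
from pymongo.auth_oidc import OIDCCallback, OIDCCallbackContext, OIDCCallbackResult


class FileTokenCallback(OIDCCallback):
    def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
        # Placeholder token source; real deployments obtain a JWT from
        # their identity provider.
        with open("/var/run/secrets/jwt") as f:
            return OIDCCallbackResult(access_token=f.read().strip())


client = MongoClient(
    "mongodb://localhost:27017",  # placeholder address
    authMechanism="MONGODB-OIDC",
    authMechanismProperties={"OIDC_CALLBACK": FileTokenCallback()},
)
```
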
+
+"""MONGODB-OIDC Authentication helpers."""
+from __future__ import annotations
+
+import abc
+import os
+import threading
+import time
+from dataclasses import dataclass, field
+from typing import TYPE_CHECKING, Any, Mapping, MutableMapping, Optional, Union
+from urllib.parse import quote
+
+import bson
+from bson.binary import Binary
+from pymongo._azure_helpers import _get_azure_response
+from pymongo._csot import remaining
+from pymongo._gcp_helpers import _get_gcp_response
+from pymongo.errors import ConfigurationError, OperationFailure
+from pymongo.helpers import _AUTHENTICATION_FAILURE_CODE
+
+if TYPE_CHECKING:
+    from pymongo.auth import MongoCredential
+    from pymongo.pool import Connection
+
+
+@dataclass
+class OIDCIdPInfo:
+    issuer: str
+    clientId: Optional[str] = field(default=None)
+    requestScopes: Optional[list[str]] = field(default=None)
+
+
+@dataclass
+class OIDCCallbackContext:
+    timeout_seconds: float
+    username: str
+    version: int
+    refresh_token: Optional[str] = field(default=None)
+    idp_info: Optional[OIDCIdPInfo] = field(default=None)
+
+
+@dataclass
+class OIDCCallbackResult:
+    access_token: str
+    expires_in_seconds: Optional[float] = field(default=None)
+    refresh_token: Optional[str] = field(default=None)
+
+
+class OIDCCallback(abc.ABC):
+    """A base class for defining OIDC callbacks."""
+
+    @abc.abstractmethod
+    def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
+        """Fetch an OIDC access token for the given context."""
+
+
+@dataclass
+class _OIDCProperties:
+    """Mechanism properties for MONGODB-OIDC authentication."""
+
+    callback: Optional[OIDCCallback] = field(default=None)
+    human_callback: Optional[OIDCCallback] = field(default=None)
+    environment: Optional[str] = field(default=None)
+    allowed_hosts: list[str] = field(default_factory=list)
+    token_resource: Optional[str] = field(default=None)
+    username: str = ""
+
+
+TOKEN_BUFFER_MINUTES = 5
+HUMAN_CALLBACK_TIMEOUT_SECONDS = 5 * 60
+CALLBACK_VERSION = 1
+MACHINE_CALLBACK_TIMEOUT_SECONDS = 60
+TIME_BETWEEN_CALLS_SECONDS = 0.1
+
+
+def _get_authenticator(
+    credentials: MongoCredential, address: tuple[str, int]
+) -> _OIDCAuthenticator:
+    if credentials.cache.data:
+        return credentials.cache.data
+
+    # Extract values.
+    principal_name = credentials.username
+    properties = credentials.mechanism_properties
+
+    # Validate that the address is allowed.
+    if not properties.environment:
+        found = False
+        allowed_hosts = properties.allowed_hosts
+        for patt in allowed_hosts:
+            if patt == address[0]:
+                found = True
+            elif patt.startswith("*.") and address[0].endswith(patt[1:]):
+                found = True
+        if not found:
+            raise ConfigurationError(
+                f"Refusing to connect to {address[0]}, which is not in authOIDCAllowedHosts: {allowed_hosts}"
+            )
+
+    # Get or create the cache data.
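+    # The authenticator is cached on the credential so that every connection
+    # from the same client shares one token / IdP-info state.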
+    credentials.cache.data = _OIDCAuthenticator(username=principal_name, properties=properties)
+    return credentials.cache.data
+
+
+class _OIDCTestCallback(OIDCCallback):
+    def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
+        token_file = os.environ.get("OIDC_TOKEN_FILE")
+        if not token_file:
+            raise RuntimeError(
+                'MONGODB-OIDC with a "test" provider requires "OIDC_TOKEN_FILE" to be set'
+            )
+        with open(token_file) as fid:
+            return OIDCCallbackResult(access_token=fid.read().strip())
+
+
+class _OIDCAzureCallback(OIDCCallback):
+    def __init__(self, token_resource: str) -> None:
+        self.token_resource = quote(token_resource)
+
+    def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
+        resp = _get_azure_response(self.token_resource, context.username, context.timeout_seconds)
+        return OIDCCallbackResult(
+            access_token=resp["access_token"], expires_in_seconds=resp["expires_in"]
+        )
+
+
+class _OIDCGCPCallback(OIDCCallback):
+    def __init__(self, token_resource: str) -> None:
+        self.token_resource = quote(token_resource)
+
+    def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
+        resp = _get_gcp_response(self.token_resource, context.timeout_seconds)
+        return OIDCCallbackResult(access_token=resp["access_token"])
+
+
+@dataclass
+class _OIDCAuthenticator:
+    username: str
+    properties: _OIDCProperties
+    refresh_token: Optional[str] = field(default=None)
+    access_token: Optional[str] = field(default=None)
+    idp_info: Optional[OIDCIdPInfo] = field(default=None)
+    token_gen_id: int = field(default=0)
+    lock: threading.Lock = field(default_factory=threading.Lock)
+    last_call_time: float = field(default=0)
+
+    def reauthenticate(self, conn: Connection) -> Optional[Mapping[str, Any]]:
+        """Handle a reauthenticate from the server."""
+        # Invalidate the token for the connection.
+        self._invalidate(conn)
+        # Call the appropriate auth logic for the callback type.
+        if self.properties.callback:
+            return self._authenticate_machine(conn)
+        return self._authenticate_human(conn)
+
+    def authenticate(self, conn: Connection) -> Optional[Mapping[str, Any]]:
+        """Handle an initial authenticate request."""
+        # First handle speculative auth.
+        # If it succeeded, we are done.
+        ctx = conn.auth_ctx
+        if ctx and ctx.speculate_succeeded():
+            resp = ctx.speculative_authenticate
+            if resp and resp["done"]:
+                conn.oidc_token_gen_id = self.token_gen_id
+                return resp
+
+        # If spec auth failed, call the appropriate auth logic for the callback type.
+        # We cannot assume that the token is invalid, because a proxy may have been
+        # involved that stripped the speculative auth information.
+        if self.properties.callback:
+            return self._authenticate_machine(conn)
+        return self._authenticate_human(conn)
+
+    def get_spec_auth_cmd(self) -> Optional[MutableMapping[str, Any]]:
+        """Get the appropriate speculative auth command."""
+        if not self.access_token:
+            return None
+        return self._get_start_command({"jwt": self.access_token})
+
+    def _authenticate_machine(self, conn: Connection) -> Mapping[str, Any]:
+        # If there is a cached access token, try to authenticate with it. If
+        # authentication fails with error code 18, invalidate the access token,
+        # fetch a new access token, and try to authenticate again. If authentication
+        # fails for any other reason, raise the error to the user.
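+        # (Code 18 is _AUTHENTICATION_FAILURE_CODE, the value _is_auth_error
+        # checks for below.)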
+        if self.access_token:
+            try:
+                return self._sasl_start_jwt(conn)
+            except OperationFailure as e:
+                if self._is_auth_error(e):
+                    return self._authenticate_machine(conn)
+                raise
+        return self._sasl_start_jwt(conn)
+
+    def _authenticate_human(self, conn: Connection) -> Optional[Mapping[str, Any]]:
+        # If we have a cached access token, try a JwtStepRequest. If
+        # authentication fails with error code 18, invalidate the access token,
+        # and try to authenticate again. If authentication fails for any other
+        # reason, raise the error to the user.
+        if self.access_token:
+            try:
+                return self._sasl_start_jwt(conn)
+            except OperationFailure as e:
+                if self._is_auth_error(e):
+                    return self._authenticate_human(conn)
+                raise
+
+        # If we have a cached refresh token, try a JwtStepRequest with that.
+        # If authentication fails with error code 18, invalidate the access and
+        # refresh tokens, and try to authenticate again. If authentication fails for
+        # any other reason, raise the error to the user.
+        if self.refresh_token:
+            try:
+                return self._sasl_start_jwt(conn)
+            except OperationFailure as e:
+                if self._is_auth_error(e):
+                    self.refresh_token = None
+                    return self._authenticate_human(conn)
+                raise
+
+        # Start a new Two-Step SASL conversation.
+        # Run a PrincipalStepRequest to get the IdpInfo.
+        cmd = self._get_start_command(None)
+        start_resp = self._run_command(conn, cmd)
+        # Attempt to authenticate with a JwtStepRequest.
+        return self._sasl_continue_jwt(conn, start_resp)
+
+    def _get_access_token(self) -> Optional[str]:
+        properties = self.properties
+        cb: Union[None, OIDCCallback]
+        resp: OIDCCallbackResult
+
+        is_human = properties.human_callback is not None
+        if is_human and self.idp_info is None:
+            return None
+
+        if properties.callback:
+            cb = properties.callback
+        if properties.human_callback:
+            cb = properties.human_callback
+
+        prev_token = self.access_token
+        if prev_token:
+            return prev_token
+
+        if cb is None and not prev_token:
+            return None
+
+        if not prev_token and cb is not None:
+            with self.lock:
+                # See if the token was changed while we were waiting for the
+                # lock.
+                new_token = self.access_token
+                if new_token != prev_token:
+                    return new_token
+
+                # Ensure that we are waiting a min time between callback invocations.
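+                # (TIME_BETWEEN_CALLS_SECONDS is 0.1, i.e. 100ms; see the
+                # module constants above.)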
+ delta = time.time() - self.last_call_time + if delta < TIME_BETWEEN_CALLS_SECONDS: + time.sleep(TIME_BETWEEN_CALLS_SECONDS - delta) + self.last_call_time = time.time() + + if is_human: + timeout = HUMAN_CALLBACK_TIMEOUT_SECONDS + assert self.idp_info is not None + else: + timeout = int(remaining() or MACHINE_CALLBACK_TIMEOUT_SECONDS) + context = OIDCCallbackContext( + timeout_seconds=timeout, + version=CALLBACK_VERSION, + refresh_token=self.refresh_token, + idp_info=self.idp_info, + username=self.properties.username, + ) + resp = cb.fetch(context) + if not isinstance(resp, OIDCCallbackResult): + raise ValueError("Callback result must be of type OIDCCallbackResult") + self.refresh_token = resp.refresh_token + self.access_token = resp.access_token + self.token_gen_id += 1 + + return self.access_token + + def _run_command(self, conn: Connection, cmd: MutableMapping[str, Any]) -> Mapping[str, Any]: + try: + return conn.command("$external", cmd, no_reauth=True) # type: ignore[call-arg] + except OperationFailure as e: + if self._is_auth_error(e): + self._invalidate(conn) + raise + + def _is_auth_error(self, err: Exception) -> bool: + if not isinstance(err, OperationFailure): + return False + return err.code == _AUTHENTICATION_FAILURE_CODE + + def _invalidate(self, conn: Connection) -> None: + # Ignore the invalidation if a token gen id is given and is less than our + # current token gen id. + token_gen_id = conn.oidc_token_gen_id or 0 + if token_gen_id is not None and token_gen_id < self.token_gen_id: + return + self.access_token = None + + def _sasl_continue_jwt( + self, conn: Connection, start_resp: Mapping[str, Any] + ) -> Mapping[str, Any]: + self.access_token = None + self.refresh_token = None + start_payload: dict = bson.decode(start_resp["payload"]) + if "issuer" in start_payload: + self.idp_info = OIDCIdPInfo(**start_payload) + access_token = self._get_access_token() + conn.oidc_token_gen_id = self.token_gen_id + cmd = self._get_continue_command({"jwt": access_token}, start_resp) + return self._run_command(conn, cmd) + + def _sasl_start_jwt(self, conn: Connection) -> Mapping[str, Any]: + access_token = self._get_access_token() + conn.oidc_token_gen_id = self.token_gen_id + cmd = self._get_start_command({"jwt": access_token}) + return self._run_command(conn, cmd) + + def _get_start_command(self, payload: Optional[Mapping[str, Any]]) -> MutableMapping[str, Any]: + if payload is None: + principal_name = self.username + if principal_name: + payload = {"n": principal_name} + else: + payload = {} + bin_payload = Binary(bson.encode(payload)) + return {"saslStart": 1, "mechanism": "MONGODB-OIDC", "payload": bin_payload} + + def _get_continue_command( + self, payload: Mapping[str, Any], start_resp: Mapping[str, Any] + ) -> MutableMapping[str, Any]: + bin_payload = Binary(bson.encode(payload)) + return { + "saslContinue": 1, + "payload": bin_payload, + "conversationId": start_resp["conversationId"], + } + + +def _authenticate_oidc( + credentials: MongoCredential, conn: Connection, reauthenticate: bool +) -> Optional[Mapping[str, Any]]: + """Authenticate using MONGODB-OIDC.""" + authenticator = _get_authenticator(credentials, conn.address) + if reauthenticate: + return authenticator.reauthenticate(conn) + else: + return authenticator.authenticate(conn) diff --git a/venv/Lib/site-packages/pymongo/bulk.py b/venv/Lib/site-packages/pymongo/bulk.py new file mode 100644 index 00000000..e1c46105 --- /dev/null +++ b/venv/Lib/site-packages/pymongo/bulk.py @@ -0,0 +1,595 @@ +# Copyright 2014-present 
MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""The bulk write operations interface. + +.. versionadded:: 2.7 +""" +from __future__ import annotations + +import copy +from collections.abc import MutableMapping +from itertools import islice +from typing import ( + TYPE_CHECKING, + Any, + Iterator, + Mapping, + NoReturn, + Optional, + Type, + Union, +) + +from bson.objectid import ObjectId +from bson.raw_bson import RawBSONDocument +from pymongo import _csot, common +from pymongo.client_session import ClientSession, _validate_session_write_concern +from pymongo.common import ( + validate_is_document_type, + validate_ok_for_replace, + validate_ok_for_update, +) +from pymongo.errors import ( + BulkWriteError, + ConfigurationError, + InvalidOperation, + OperationFailure, +) +from pymongo.helpers import _RETRYABLE_ERROR_CODES, _get_wce_doc +from pymongo.message import ( + _DELETE, + _INSERT, + _UPDATE, + _BulkWriteContext, + _EncryptedBulkWriteContext, + _randint, +) +from pymongo.read_preferences import ReadPreference +from pymongo.write_concern import WriteConcern + +if TYPE_CHECKING: + from pymongo.collection import Collection + from pymongo.pool import Connection + from pymongo.typings import _DocumentOut, _DocumentType, _Pipeline + +_DELETE_ALL: int = 0 +_DELETE_ONE: int = 1 + +# For backwards compatibility. See MongoDB src/mongo/base/error_codes.err +_BAD_VALUE: int = 2 +_UNKNOWN_ERROR: int = 8 +_WRITE_CONCERN_ERROR: int = 64 + +_COMMANDS: tuple[str, str, str] = ("insert", "update", "delete") + + +class _Run: + """Represents a batch of write operations.""" + + def __init__(self, op_type: int) -> None: + """Initialize a new Run object.""" + self.op_type: int = op_type + self.index_map: list[int] = [] + self.ops: list[Any] = [] + self.idx_offset: int = 0 + + def index(self, idx: int) -> int: + """Get the original index of an operation in this run. + + :param idx: The Run index that maps to the original index. + """ + return self.index_map[idx] + + def add(self, original_index: int, operation: Any) -> None: + """Add an operation to this Run instance. + + :param original_index: The original index of this operation + within a larger bulk operation. + :param operation: The operation document. 
+ """ + self.index_map.append(original_index) + self.ops.append(operation) + + +def _merge_command( + run: _Run, + full_result: MutableMapping[str, Any], + offset: int, + result: Mapping[str, Any], +) -> None: + """Merge a write command result into the full bulk result.""" + affected = result.get("n", 0) + + if run.op_type == _INSERT: + full_result["nInserted"] += affected + + elif run.op_type == _DELETE: + full_result["nRemoved"] += affected + + elif run.op_type == _UPDATE: + upserted = result.get("upserted") + if upserted: + n_upserted = len(upserted) + for doc in upserted: + doc["index"] = run.index(doc["index"] + offset) + full_result["upserted"].extend(upserted) + full_result["nUpserted"] += n_upserted + full_result["nMatched"] += affected - n_upserted + else: + full_result["nMatched"] += affected + full_result["nModified"] += result["nModified"] + + write_errors = result.get("writeErrors") + if write_errors: + for doc in write_errors: + # Leave the server response intact for APM. + replacement = doc.copy() + idx = doc["index"] + offset + replacement["index"] = run.index(idx) + # Add the failed operation to the error document. + replacement["op"] = run.ops[idx] + full_result["writeErrors"].append(replacement) + + wce = _get_wce_doc(result) + if wce: + full_result["writeConcernErrors"].append(wce) + + +def _raise_bulk_write_error(full_result: _DocumentOut) -> NoReturn: + """Raise a BulkWriteError from the full bulk api result.""" + # retryWrites on MMAPv1 should raise an actionable error. + if full_result["writeErrors"]: + full_result["writeErrors"].sort(key=lambda error: error["index"]) + err = full_result["writeErrors"][0] + code = err["code"] + msg = err["errmsg"] + if code == 20 and msg.startswith("Transaction numbers"): + errmsg = ( + "This MongoDB deployment does not support " + "retryable writes. Please add retryWrites=false " + "to your connection string." + ) + raise OperationFailure(errmsg, code, full_result) + raise BulkWriteError(full_result) + + +class _Bulk: + """The private guts of the bulk write API.""" + + def __init__( + self, + collection: Collection[_DocumentType], + ordered: bool, + bypass_document_validation: bool, + comment: Optional[str] = None, + let: Optional[Any] = None, + ) -> None: + """Initialize a _Bulk instance.""" + self.collection = collection.with_options( + codec_options=collection.codec_options._replace( + unicode_decode_error_handler="replace", document_class=dict + ) + ) + self.let = let + if self.let is not None: + common.validate_is_document_type("let", self.let) + self.comment: Optional[str] = comment + self.ordered = ordered + self.ops: list[tuple[int, Mapping[str, Any]]] = [] + self.executed = False + self.bypass_doc_val = bypass_document_validation + self.uses_collation = False + self.uses_array_filters = False + self.uses_hint_update = False + self.uses_hint_delete = False + self.is_retryable = True + self.retrying = False + self.started_retryable_write = False + # Extra state so that we know where to pick up on a retry attempt. + self.current_run = None + self.next_run = None + + @property + def bulk_ctx_class(self) -> Type[_BulkWriteContext]: + encrypter = self.collection.database.client._encrypter + if encrypter and not encrypter._bypass_auto_encryption: + return _EncryptedBulkWriteContext + else: + return _BulkWriteContext + + def add_insert(self, document: _DocumentOut) -> None: + """Add an insert document to the list of ops.""" + validate_is_document_type("document", document) + # Generate ObjectId client side. 
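+        # (RawBSONDocument is immutable and is left untouched; for ordinary
+        # mappings this makes the generated _id visible to the caller.)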
+ if not (isinstance(document, RawBSONDocument) or "_id" in document): + document["_id"] = ObjectId() + self.ops.append((_INSERT, document)) + + def add_update( + self, + selector: Mapping[str, Any], + update: Union[Mapping[str, Any], _Pipeline], + multi: bool = False, + upsert: bool = False, + collation: Optional[Mapping[str, Any]] = None, + array_filters: Optional[list[Mapping[str, Any]]] = None, + hint: Union[str, dict[str, Any], None] = None, + ) -> None: + """Create an update document and add it to the list of ops.""" + validate_ok_for_update(update) + cmd: dict[str, Any] = dict( # noqa: C406 + [("q", selector), ("u", update), ("multi", multi), ("upsert", upsert)] + ) + if collation is not None: + self.uses_collation = True + cmd["collation"] = collation + if array_filters is not None: + self.uses_array_filters = True + cmd["arrayFilters"] = array_filters + if hint is not None: + self.uses_hint_update = True + cmd["hint"] = hint + if multi: + # A bulk_write containing an update_many is not retryable. + self.is_retryable = False + self.ops.append((_UPDATE, cmd)) + + def add_replace( + self, + selector: Mapping[str, Any], + replacement: Mapping[str, Any], + upsert: bool = False, + collation: Optional[Mapping[str, Any]] = None, + hint: Union[str, dict[str, Any], None] = None, + ) -> None: + """Create a replace document and add it to the list of ops.""" + validate_ok_for_replace(replacement) + cmd = {"q": selector, "u": replacement, "multi": False, "upsert": upsert} + if collation is not None: + self.uses_collation = True + cmd["collation"] = collation + if hint is not None: + self.uses_hint_update = True + cmd["hint"] = hint + self.ops.append((_UPDATE, cmd)) + + def add_delete( + self, + selector: Mapping[str, Any], + limit: int, + collation: Optional[Mapping[str, Any]] = None, + hint: Union[str, dict[str, Any], None] = None, + ) -> None: + """Create a delete document and add it to the list of ops.""" + cmd = {"q": selector, "limit": limit} + if collation is not None: + self.uses_collation = True + cmd["collation"] = collation + if hint is not None: + self.uses_hint_delete = True + cmd["hint"] = hint + if limit == _DELETE_ALL: + # A bulk_write containing a delete_many is not retryable. + self.is_retryable = False + self.ops.append((_DELETE, cmd)) + + def gen_ordered(self) -> Iterator[Optional[_Run]]: + """Generate batches of operations, batched by type of + operation, in the order **provided**. + """ + run = None + for idx, (op_type, operation) in enumerate(self.ops): + if run is None: + run = _Run(op_type) + elif run.op_type != op_type: + yield run + run = _Run(op_type) + run.add(idx, operation) + yield run + + def gen_unordered(self) -> Iterator[_Run]: + """Generate batches of operations, batched by type of + operation, in arbitrary order. 
+ """ + operations = [_Run(_INSERT), _Run(_UPDATE), _Run(_DELETE)] + for idx, (op_type, operation) in enumerate(self.ops): + operations[op_type].add(idx, operation) + + for run in operations: + if run.ops: + yield run + + def _execute_command( + self, + generator: Iterator[Any], + write_concern: WriteConcern, + session: Optional[ClientSession], + conn: Connection, + op_id: int, + retryable: bool, + full_result: MutableMapping[str, Any], + final_write_concern: Optional[WriteConcern] = None, + ) -> None: + db_name = self.collection.database.name + client = self.collection.database.client + listeners = client._event_listeners + + if not self.current_run: + self.current_run = next(generator) + self.next_run = None + run = self.current_run + + # Connection.command validates the session, but we use + # Connection.write_command + conn.validate_session(client, session) + last_run = False + + while run: + if not self.retrying: + self.next_run = next(generator, None) + if self.next_run is None: + last_run = True + + cmd_name = _COMMANDS[run.op_type] + bwc = self.bulk_ctx_class( + db_name, + cmd_name, + conn, + op_id, + listeners, + session, + run.op_type, + self.collection.codec_options, + ) + + while run.idx_offset < len(run.ops): + # If this is the last possible operation, use the + # final write concern. + if last_run and (len(run.ops) - run.idx_offset) == 1: + write_concern = final_write_concern or write_concern + + cmd = {cmd_name: self.collection.name, "ordered": self.ordered} + if self.comment: + cmd["comment"] = self.comment + _csot.apply_write_concern(cmd, write_concern) + if self.bypass_doc_val: + cmd["bypassDocumentValidation"] = True + if self.let is not None and run.op_type in (_DELETE, _UPDATE): + cmd["let"] = self.let + if session: + # Start a new retryable write unless one was already + # started for this command. + if retryable and not self.started_retryable_write: + session._start_retryable_write() + self.started_retryable_write = True + session._apply_to(cmd, retryable, ReadPreference.PRIMARY, conn) + conn.send_cluster_time(cmd, session, client) + conn.add_server_api(cmd) + # CSOT: apply timeout before encoding the command. + conn.apply_timeout(client, cmd) + ops = islice(run.ops, run.idx_offset, None) + + # Run as many ops as possible in one command. + if write_concern.acknowledged: + result, to_send = bwc.execute(cmd, ops, client) + + # Retryable writeConcernErrors halt the execution of this run. + wce = result.get("writeConcernError", {}) + if wce.get("code", 0) in _RETRYABLE_ERROR_CODES: + # Synthesize the full bulk result without modifying the + # current one because this write operation may be retried. + full = copy.deepcopy(full_result) + _merge_command(run, full, run.idx_offset, result) + _raise_bulk_write_error(full) + + _merge_command(run, full_result, run.idx_offset, result) + + # We're no longer in a retry once a command succeeds. + self.retrying = False + self.started_retryable_write = False + + if self.ordered and "writeErrors" in result: + break + else: + to_send = bwc.execute_unack(cmd, ops, client) + + run.idx_offset += len(to_send) + + # We're supposed to continue if errors are + # at the write concern level (e.g. 
wtimeout) + if self.ordered and full_result["writeErrors"]: + break + # Reset our state + self.current_run = run = self.next_run + + def execute_command( + self, + generator: Iterator[Any], + write_concern: WriteConcern, + session: Optional[ClientSession], + operation: str, + ) -> dict[str, Any]: + """Execute using write commands.""" + # nModified is only reported for write commands, not legacy ops. + full_result = { + "writeErrors": [], + "writeConcernErrors": [], + "nInserted": 0, + "nUpserted": 0, + "nMatched": 0, + "nModified": 0, + "nRemoved": 0, + "upserted": [], + } + op_id = _randint() + + def retryable_bulk( + session: Optional[ClientSession], conn: Connection, retryable: bool + ) -> None: + self._execute_command( + generator, + write_concern, + session, + conn, + op_id, + retryable, + full_result, + ) + + client = self.collection.database.client + client._retryable_write( + self.is_retryable, + retryable_bulk, + session, + operation, + bulk=self, + operation_id=op_id, + ) + + if full_result["writeErrors"] or full_result["writeConcernErrors"]: + _raise_bulk_write_error(full_result) + return full_result + + def execute_op_msg_no_results(self, conn: Connection, generator: Iterator[Any]) -> None: + """Execute write commands with OP_MSG and w=0 writeConcern, unordered.""" + db_name = self.collection.database.name + client = self.collection.database.client + listeners = client._event_listeners + op_id = _randint() + + if not self.current_run: + self.current_run = next(generator) + run = self.current_run + + while run: + cmd_name = _COMMANDS[run.op_type] + bwc = self.bulk_ctx_class( + db_name, + cmd_name, + conn, + op_id, + listeners, + None, + run.op_type, + self.collection.codec_options, + ) + + while run.idx_offset < len(run.ops): + cmd = { + cmd_name: self.collection.name, + "ordered": False, + "writeConcern": {"w": 0}, + } + conn.add_server_api(cmd) + ops = islice(run.ops, run.idx_offset, None) + # Run as many ops as possible. + to_send = bwc.execute_unack(cmd, ops, client) + run.idx_offset += len(to_send) + self.current_run = run = next(generator, None) + + def execute_command_no_results( + self, + conn: Connection, + generator: Iterator[Any], + write_concern: WriteConcern, + ) -> None: + """Execute write commands with OP_MSG and w=0 WriteConcern, ordered.""" + full_result = { + "writeErrors": [], + "writeConcernErrors": [], + "nInserted": 0, + "nUpserted": 0, + "nMatched": 0, + "nModified": 0, + "nRemoved": 0, + "upserted": [], + } + # Ordered bulk writes have to be acknowledged so that we stop + # processing at the first error, even when the application + # specified unacknowledged writeConcern. + initial_write_concern = WriteConcern() + op_id = _randint() + try: + self._execute_command( + generator, + initial_write_concern, + None, + conn, + op_id, + False, + full_result, + write_concern, + ) + except OperationFailure: + pass + + def execute_no_results( + self, + conn: Connection, + generator: Iterator[Any], + write_concern: WriteConcern, + ) -> None: + """Execute all operations, returning no results (w=0).""" + if self.uses_collation: + raise ConfigurationError("Collation is unsupported for unacknowledged writes.") + if self.uses_array_filters: + raise ConfigurationError("arrayFilters is unsupported for unacknowledged writes.") + # Guard against unsupported unacknowledged writes. 
+ unack = write_concern and not write_concern.acknowledged + if unack and self.uses_hint_delete and conn.max_wire_version < 9: + raise ConfigurationError( + "Must be connected to MongoDB 4.4+ to use hint on unacknowledged delete commands." + ) + if unack and self.uses_hint_update and conn.max_wire_version < 8: + raise ConfigurationError( + "Must be connected to MongoDB 4.2+ to use hint on unacknowledged update commands." + ) + # Cannot have both unacknowledged writes and bypass document validation. + if self.bypass_doc_val: + raise OperationFailure( + "Cannot set bypass_document_validation with unacknowledged write concern" + ) + + if self.ordered: + return self.execute_command_no_results(conn, generator, write_concern) + return self.execute_op_msg_no_results(conn, generator) + + def execute( + self, + write_concern: WriteConcern, + session: Optional[ClientSession], + operation: str, + ) -> Any: + """Execute operations.""" + if not self.ops: + raise InvalidOperation("No operations to execute") + if self.executed: + raise InvalidOperation("Bulk operations can only be executed once.") + self.executed = True + write_concern = write_concern or self.collection.write_concern + session = _validate_session_write_concern(session, write_concern) + + if self.ordered: + generator = self.gen_ordered() + else: + generator = self.gen_unordered() + + client = self.collection.database.client + if not write_concern.acknowledged: + with client._conn_for_writes(session, operation) as connection: + self.execute_no_results(connection, generator, write_concern) + return None + else: + return self.execute_command(generator, write_concern, session, operation) diff --git a/venv/Lib/site-packages/pymongo/change_stream.py b/venv/Lib/site-packages/pymongo/change_stream.py new file mode 100644 index 00000000..dc2f6bf2 --- /dev/null +++ b/venv/Lib/site-packages/pymongo/change_stream.py @@ -0,0 +1,490 @@ +# Copyright 2017 MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you +# may not use this file except in compliance with the License. You +# may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. + +"""Watch changes on a collection, a database, or the entire cluster.""" +from __future__ import annotations + +import copy +from typing import TYPE_CHECKING, Any, Generic, Mapping, Optional, Type, Union + +from bson import CodecOptions, _bson_to_dict +from bson.raw_bson import RawBSONDocument +from bson.timestamp import Timestamp +from pymongo import _csot, common +from pymongo.aggregation import ( + _AggregationCommand, + _CollectionAggregationCommand, + _DatabaseAggregationCommand, +) +from pymongo.collation import validate_collation_or_none +from pymongo.command_cursor import CommandCursor +from pymongo.errors import ( + ConnectionFailure, + CursorNotFound, + InvalidOperation, + OperationFailure, + PyMongoError, +) +from pymongo.operations import _Op +from pymongo.typings import _CollationIn, _DocumentType, _Pipeline + +# The change streams spec considers the following server errors from the +# getMore command non-resumable. All other getMore errors are resumable. 
+_RESUMABLE_GETMORE_ERRORS = frozenset(
+    [
+        6,  # HostUnreachable
+        7,  # HostNotFound
+        89,  # NetworkTimeout
+        91,  # ShutdownInProgress
+        189,  # PrimarySteppedDown
+        262,  # ExceededTimeLimit
+        9001,  # SocketException
+        10107,  # NotWritablePrimary
+        11600,  # InterruptedAtShutdown
+        11602,  # InterruptedDueToReplStateChange
+        13435,  # NotPrimaryNoSecondaryOk
+        13436,  # NotPrimaryOrSecondary
+        63,  # StaleShardVersion
+        150,  # StaleEpoch
+        13388,  # StaleConfig
+        234,  # RetryChangeStream
+        133,  # FailedToSatisfyReadPreference
+    ]
+)
+
+
+if TYPE_CHECKING:
+    from pymongo.client_session import ClientSession
+    from pymongo.collection import Collection
+    from pymongo.database import Database
+    from pymongo.mongo_client import MongoClient
+    from pymongo.pool import Connection
+
+
+def _resumable(exc: PyMongoError) -> bool:
+    """Return True if given a resumable change stream error."""
+    if isinstance(exc, (ConnectionFailure, CursorNotFound)):
+        return True
+    if isinstance(exc, OperationFailure):
+        if exc._max_wire_version is None:
+            return False
+        return (
+            exc._max_wire_version >= 9 and exc.has_error_label("ResumableChangeStreamError")
+        ) or (exc._max_wire_version < 9 and exc.code in _RESUMABLE_GETMORE_ERRORS)
+    return False
+
+
+class ChangeStream(Generic[_DocumentType]):
+    """The internal abstract base class for change stream cursors.
+
+    Should not be called directly by application developers. Use
+    :meth:`pymongo.collection.Collection.watch`,
+    :meth:`pymongo.database.Database.watch`, or
+    :meth:`pymongo.mongo_client.MongoClient.watch` instead.
+
+    .. versionadded:: 3.6
+    .. seealso:: The MongoDB documentation on `changeStreams <https://mongodb.com/docs/manual/changeStreams/>`_.
+    """
+
+    def __init__(
+        self,
+        target: Union[
+            MongoClient[_DocumentType], Database[_DocumentType], Collection[_DocumentType]
+        ],
+        pipeline: Optional[_Pipeline],
+        full_document: Optional[str],
+        resume_after: Optional[Mapping[str, Any]],
+        max_await_time_ms: Optional[int],
+        batch_size: Optional[int],
+        collation: Optional[_CollationIn],
+        start_at_operation_time: Optional[Timestamp],
+        session: Optional[ClientSession],
+        start_after: Optional[Mapping[str, Any]],
+        comment: Optional[Any] = None,
+        full_document_before_change: Optional[str] = None,
+        show_expanded_events: Optional[bool] = None,
+    ) -> None:
+        if pipeline is None:
+            pipeline = []
+        pipeline = common.validate_list("pipeline", pipeline)
+        common.validate_string_or_none("full_document", full_document)
+        validate_collation_or_none(collation)
+        common.validate_non_negative_integer_or_none("batchSize", batch_size)
+
+        self._decode_custom = False
+        self._orig_codec_options: CodecOptions[_DocumentType] = target.codec_options
+        if target.codec_options.type_registry._decoder_map:
+            self._decode_custom = True
+            # Keep the type registry so that we support encoding custom types
+            # in the pipeline.
+ self._target = target.with_options( # type: ignore + codec_options=target.codec_options.with_options(document_class=RawBSONDocument) + ) + else: + self._target = target + + self._pipeline = copy.deepcopy(pipeline) + self._full_document = full_document + self._full_document_before_change = full_document_before_change + self._uses_start_after = start_after is not None + self._uses_resume_after = resume_after is not None + self._resume_token = copy.deepcopy(start_after or resume_after) + self._max_await_time_ms = max_await_time_ms + self._batch_size = batch_size + self._collation = collation + self._start_at_operation_time = start_at_operation_time + self._session = session + self._comment = comment + self._closed = False + self._timeout = self._target._timeout + self._show_expanded_events = show_expanded_events + # Initialize cursor. + self._cursor = self._create_cursor() + + @property + def _aggregation_command_class(self) -> Type[_AggregationCommand]: + """The aggregation command class to be used.""" + raise NotImplementedError + + @property + def _client(self) -> MongoClient: + """The client against which the aggregation commands for + this ChangeStream will be run. + """ + raise NotImplementedError + + def _change_stream_options(self) -> dict[str, Any]: + """Return the options dict for the $changeStream pipeline stage.""" + options: dict[str, Any] = {} + if self._full_document is not None: + options["fullDocument"] = self._full_document + + if self._full_document_before_change is not None: + options["fullDocumentBeforeChange"] = self._full_document_before_change + + resume_token = self.resume_token + if resume_token is not None: + if self._uses_start_after: + options["startAfter"] = resume_token + else: + options["resumeAfter"] = resume_token + + if self._start_at_operation_time is not None: + options["startAtOperationTime"] = self._start_at_operation_time + + if self._show_expanded_events: + options["showExpandedEvents"] = self._show_expanded_events + + return options + + def _command_options(self) -> dict[str, Any]: + """Return the options dict for the aggregation command.""" + options = {} + if self._max_await_time_ms is not None: + options["maxAwaitTimeMS"] = self._max_await_time_ms + if self._batch_size is not None: + options["batchSize"] = self._batch_size + return options + + def _aggregation_pipeline(self) -> list[dict[str, Any]]: + """Return the full aggregation pipeline for this ChangeStream.""" + options = self._change_stream_options() + full_pipeline: list = [{"$changeStream": options}] + full_pipeline.extend(self._pipeline) + return full_pipeline + + def _process_result(self, result: Mapping[str, Any], conn: Connection) -> None: + """Callback that caches the postBatchResumeToken or + startAtOperationTime from a changeStream aggregate command response + containing an empty batch of change documents. + + This is implemented as a callback because we need access to the wire + version in order to determine whether to cache this value. + """ + if not result["cursor"]["firstBatch"]: + if "postBatchResumeToken" in result["cursor"]: + self._resume_token = result["cursor"]["postBatchResumeToken"] + elif ( + self._start_at_operation_time is None + and self._uses_resume_after is False + and self._uses_start_after is False + and conn.max_wire_version >= 7 + ): + self._start_at_operation_time = result.get("operationTime") + # PYTHON-2181: informative error on missing operationTime. 
+            if self._start_at_operation_time is None:
+                raise OperationFailure(
+                    "Expected field 'operationTime' missing from command "
+                    f"response: {result!r}"
+                )
+
+    def _run_aggregation_cmd(
+        self, session: Optional[ClientSession], explicit_session: bool
+    ) -> CommandCursor:
+        """Run the full aggregation pipeline for this ChangeStream and return
+        the corresponding CommandCursor.
+        """
+        cmd = self._aggregation_command_class(
+            self._target,
+            CommandCursor,
+            self._aggregation_pipeline(),
+            self._command_options(),
+            explicit_session,
+            result_processor=self._process_result,
+            comment=self._comment,
+        )
+        return self._client._retryable_read(
+            cmd.get_cursor,
+            self._target._read_preference_for(session),
+            session,
+            operation=_Op.AGGREGATE,
+        )
+
+    def _create_cursor(self) -> CommandCursor:
+        with self._client._tmp_session(self._session, close=False) as s:
+            return self._run_aggregation_cmd(session=s, explicit_session=self._session is not None)
+
+    def _resume(self) -> None:
+        """Reestablish this change stream after a resumable error."""
+        try:
+            self._cursor.close()
+        except PyMongoError:
+            pass
+        self._cursor = self._create_cursor()
+
+    def close(self) -> None:
+        """Close this ChangeStream."""
+        self._closed = True
+        self._cursor.close()
+
+    def __iter__(self) -> ChangeStream[_DocumentType]:
+        return self
+
+    @property
+    def resume_token(self) -> Optional[Mapping[str, Any]]:
+        """The cached resume token that will be used to resume after the most
+        recently returned change.
+
+        .. versionadded:: 3.9
+        """
+        return copy.deepcopy(self._resume_token)
+
+    @_csot.apply
+    def next(self) -> _DocumentType:
+        """Advance the cursor.
+
+        This method blocks until the next change document is returned or an
+        unrecoverable error is raised. This method is used when iterating over
+        all changes in the cursor. For example::
+
+            try:
+                resume_token = None
+                pipeline = [{'$match': {'operationType': 'insert'}}]
+                with db.collection.watch(pipeline) as stream:
+                    for insert_change in stream:
+                        print(insert_change)
+                        resume_token = stream.resume_token
+            except pymongo.errors.PyMongoError:
+                # The ChangeStream encountered an unrecoverable error or the
+                # resume attempt failed to recreate the cursor.
+                if resume_token is None:
+                    # There is no usable resume token because there was a
+                    # failure during ChangeStream initialization.
+                    logging.error('...')
+                else:
+                    # Use the interrupted ChangeStream's resume token to create
+                    # a new ChangeStream. The new stream will continue from the
+                    # last seen insert change without missing any events.
+                    with db.collection.watch(
+                            pipeline, resume_after=resume_token) as stream:
+                        for insert_change in stream:
+                            print(insert_change)
+
+        Raises :exc:`StopIteration` if this ChangeStream is closed.
+        """
+        while self.alive:
+            doc = self.try_next()
+            if doc is not None:
+                return doc
+
+        raise StopIteration
+
+    __next__ = next
+
+    @property
+    def alive(self) -> bool:
+        """Does this cursor have the potential to return more data?
+
+        .. note:: Even if :attr:`alive` is ``True``, :meth:`next` can raise
+          :exc:`StopIteration` and :meth:`try_next` can return ``None``.
+
+        .. versionadded:: 3.8
+        """
+        return not self._closed
+
+    @_csot.apply
+    def try_next(self) -> Optional[_DocumentType]:
+        """Advance the cursor without blocking indefinitely.
+
+        This method returns the next change document without waiting
+        indefinitely for the next change.
For example:: + + with db.collection.watch() as stream: + while stream.alive: + change = stream.try_next() + # Note that the ChangeStream's resume token may be updated + # even when no changes are returned. + print("Current resume token: %r" % (stream.resume_token,)) + if change is not None: + print("Change document: %r" % (change,)) + continue + # We end up here when there are no recent changes. + # Sleep for a while before trying again to avoid flooding + # the server with getMore requests when no changes are + # available. + time.sleep(10) + + If no change document is cached locally then this method runs a single + getMore command. If the getMore yields any documents, the next + document is returned, otherwise, if the getMore returns no documents + (because there have been no changes) then ``None`` is returned. + + :return: The next change document or ``None`` when no document is available + after running a single getMore or when the cursor is closed. + + .. versionadded:: 3.8 + """ + if not self._closed and not self._cursor.alive: + self._resume() + + # Attempt to get the next change with at most one getMore and at most + # one resume attempt. + try: + try: + change = self._cursor._try_next(True) + except PyMongoError as exc: + if not _resumable(exc): + raise + self._resume() + change = self._cursor._try_next(False) + except PyMongoError as exc: + # Close the stream after a fatal error. + if not _resumable(exc) and not exc.timeout: + self.close() + raise + except Exception: + self.close() + raise + + # Check if the cursor was invalidated. + if not self._cursor.alive: + self._closed = True + + # If no changes are available. + if change is None: + # We have either iterated over all documents in the cursor, + # OR the most-recently returned batch is empty. In either case, + # update the cached resume token with the postBatchResumeToken if + # one was returned. We also clear the startAtOperationTime. + if self._cursor._post_batch_resume_token is not None: + self._resume_token = self._cursor._post_batch_resume_token + self._start_at_operation_time = None + return change + + # Else, changes are available. + try: + resume_token = change["_id"] + except KeyError: + self.close() + raise InvalidOperation( + "Cannot provide resume functionality when the resume token is missing." + ) from None + + # If this is the last change document from the current batch, cache the + # postBatchResumeToken. + if not self._cursor._has_next() and self._cursor._post_batch_resume_token: + resume_token = self._cursor._post_batch_resume_token + + # Hereafter, don't use startAfter; instead use resumeAfter. + self._uses_start_after = False + self._uses_resume_after = True + + # Cache the resume token and clear startAtOperationTime. + self._resume_token = resume_token + self._start_at_operation_time = None + + if self._decode_custom: + return _bson_to_dict(change.raw, self._orig_codec_options) + return change + + def __enter__(self) -> ChangeStream[_DocumentType]: + return self + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + self.close() + + +class CollectionChangeStream(ChangeStream[_DocumentType]): + """A change stream that watches changes on a single collection. + + Should not be called directly by application developers. Use + helper method :meth:`pymongo.collection.Collection.watch` instead. + + .. 
versionadded:: 3.7 + """ + + _target: Collection[_DocumentType] + + @property + def _aggregation_command_class(self) -> Type[_CollectionAggregationCommand]: + return _CollectionAggregationCommand + + @property + def _client(self) -> MongoClient[_DocumentType]: + return self._target.database.client + + +class DatabaseChangeStream(ChangeStream[_DocumentType]): + """A change stream that watches changes on all collections in a database. + + Should not be called directly by application developers. Use + helper method :meth:`pymongo.database.Database.watch` instead. + + .. versionadded:: 3.7 + """ + + _target: Database[_DocumentType] + + @property + def _aggregation_command_class(self) -> Type[_DatabaseAggregationCommand]: + return _DatabaseAggregationCommand + + @property + def _client(self) -> MongoClient[_DocumentType]: + return self._target.client + + +class ClusterChangeStream(DatabaseChangeStream[_DocumentType]): + """A change stream that watches changes on all collections in the cluster. + + Should not be called directly by application developers. Use + helper method :meth:`pymongo.mongo_client.MongoClient.watch` instead. + + .. versionadded:: 3.7 + """ + + def _change_stream_options(self) -> dict[str, Any]: + options = super()._change_stream_options() + options["allChangesForCluster"] = True + return options diff --git a/venv/Lib/site-packages/pymongo/client_options.py b/venv/Lib/site-packages/pymongo/client_options.py new file mode 100644 index 00000000..60332605 --- /dev/null +++ b/venv/Lib/site-packages/pymongo/client_options.py @@ -0,0 +1,330 @@ +# Copyright 2014-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you +# may not use this file except in compliance with the License. You +# may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. 
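
The three `ChangeStream` subclasses defined in change_stream.py above are what the corresponding `watch()` helpers return. A minimal orientation sketch, assuming a reachable server at the placeholder address:

```python
from pymongo import MongoClient

client = MongoClient("mongodb://localhost:27017")  # placeholder address

coll_stream = client.test.coll.watch()  # CollectionChangeStream
db_stream = client.test.watch()         # DatabaseChangeStream
cluster_stream = client.watch()         # ClusterChangeStream (adds
                                        # allChangesForCluster=True)
```
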
+ +"""Tools to parse mongo client options.""" +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence, cast + +from bson.codec_options import _parse_codec_options +from pymongo import common +from pymongo.auth import MongoCredential, _build_credentials_tuple +from pymongo.compression_support import CompressionSettings +from pymongo.errors import ConfigurationError +from pymongo.monitoring import _EventListener, _EventListeners +from pymongo.pool import PoolOptions +from pymongo.read_concern import ReadConcern +from pymongo.read_preferences import ( + _ServerMode, + make_read_preference, + read_pref_mode_from_name, +) +from pymongo.server_selectors import any_server_selector +from pymongo.ssl_support import get_ssl_context +from pymongo.write_concern import WriteConcern, validate_boolean + +if TYPE_CHECKING: + from bson.codec_options import CodecOptions + from pymongo.encryption_options import AutoEncryptionOpts + from pymongo.pyopenssl_context import SSLContext + from pymongo.topology_description import _ServerSelector + + +def _parse_credentials( + username: str, password: str, database: Optional[str], options: Mapping[str, Any] +) -> Optional[MongoCredential]: + """Parse authentication credentials.""" + mechanism = options.get("authmechanism", "DEFAULT" if username else None) + source = options.get("authsource") + if username or mechanism: + return _build_credentials_tuple(mechanism, source, username, password, options, database) + return None + + +def _parse_read_preference(options: Mapping[str, Any]) -> _ServerMode: + """Parse read preference options.""" + if "read_preference" in options: + return options["read_preference"] + + name = options.get("readpreference", "primary") + mode = read_pref_mode_from_name(name) + tags = options.get("readpreferencetags") + max_staleness = options.get("maxstalenessseconds", -1) + return make_read_preference(mode, tags, max_staleness) + + +def _parse_write_concern(options: Mapping[str, Any]) -> WriteConcern: + """Parse write concern options.""" + concern = options.get("w") + wtimeout = options.get("wtimeoutms") + j = options.get("journal") + fsync = options.get("fsync") + return WriteConcern(concern, wtimeout, j, fsync) + + +def _parse_read_concern(options: Mapping[str, Any]) -> ReadConcern: + """Parse read concern options.""" + concern = options.get("readconcernlevel") + return ReadConcern(concern) + + +def _parse_ssl_options(options: Mapping[str, Any]) -> tuple[Optional[SSLContext], bool]: + """Parse ssl options.""" + use_tls = options.get("tls") + if use_tls is not None: + validate_boolean("tls", use_tls) + + certfile = options.get("tlscertificatekeyfile") + passphrase = options.get("tlscertificatekeyfilepassword") + ca_certs = options.get("tlscafile") + crlfile = options.get("tlscrlfile") + allow_invalid_certificates = options.get("tlsallowinvalidcertificates", False) + allow_invalid_hostnames = options.get("tlsallowinvalidhostnames", False) + disable_ocsp_endpoint_check = options.get("tlsdisableocspendpointcheck", False) + + enabled_tls_opts = [] + for opt in ( + "tlscertificatekeyfile", + "tlscertificatekeyfilepassword", + "tlscafile", + "tlscrlfile", + ): + # Any non-null value of these options implies tls=True. + if opt in options and options[opt]: + enabled_tls_opts.append(opt) + for opt in ( + "tlsallowinvalidcertificates", + "tlsallowinvalidhostnames", + "tlsdisableocspendpointcheck", + ): + # A value of False for these options implies tls=True. 
+ if opt in options and not options[opt]: + enabled_tls_opts.append(opt) + + if enabled_tls_opts: + if use_tls is None: + # Implicitly enable TLS when one of the tls* options is set. + use_tls = True + elif not use_tls: + # Error since tls is explicitly disabled but a tls option is set. + raise ConfigurationError( + "TLS has not been enabled but the " + "following tls parameters have been set: " + "%s. Please set `tls=True` or remove." % ", ".join(enabled_tls_opts) + ) + + if use_tls: + ctx = get_ssl_context( + certfile, + passphrase, + ca_certs, + crlfile, + allow_invalid_certificates, + allow_invalid_hostnames, + disable_ocsp_endpoint_check, + ) + return ctx, allow_invalid_hostnames + return None, allow_invalid_hostnames + + +def _parse_pool_options( + username: str, password: str, database: Optional[str], options: Mapping[str, Any] +) -> PoolOptions: + """Parse connection pool options.""" + credentials = _parse_credentials(username, password, database, options) + max_pool_size = options.get("maxpoolsize", common.MAX_POOL_SIZE) + min_pool_size = options.get("minpoolsize", common.MIN_POOL_SIZE) + max_idle_time_seconds = options.get("maxidletimems", common.MAX_IDLE_TIME_SEC) + if max_pool_size is not None and min_pool_size > max_pool_size: + raise ValueError("minPoolSize must be smaller or equal to maxPoolSize") + connect_timeout = options.get("connecttimeoutms", common.CONNECT_TIMEOUT) + socket_timeout = options.get("sockettimeoutms") + wait_queue_timeout = options.get("waitqueuetimeoutms", common.WAIT_QUEUE_TIMEOUT) + event_listeners = cast(Optional[Sequence[_EventListener]], options.get("event_listeners")) + appname = options.get("appname") + driver = options.get("driver") + server_api = options.get("server_api") + compression_settings = CompressionSettings( + options.get("compressors", []), options.get("zlibcompressionlevel", -1) + ) + ssl_context, tls_allow_invalid_hostnames = _parse_ssl_options(options) + load_balanced = options.get("loadbalanced") + max_connecting = options.get("maxconnecting", common.MAX_CONNECTING) + return PoolOptions( + max_pool_size, + min_pool_size, + max_idle_time_seconds, + connect_timeout, + socket_timeout, + wait_queue_timeout, + ssl_context, + tls_allow_invalid_hostnames, + _EventListeners(event_listeners), + appname, + driver, + compression_settings, + max_connecting=max_connecting, + server_api=server_api, + load_balanced=load_balanced, + credentials=credentials, + ) + + +class ClientOptions: + """Read only configuration options for a MongoClient. + + Should not be instantiated directly by application developers. Access + a client's options via :attr:`pymongo.mongo_client.MongoClient.options` + instead. + """ + + def __init__( + self, username: str, password: str, database: Optional[str], options: Mapping[str, Any] + ): + self.__options = options + self.__codec_options = _parse_codec_options(options) + self.__direct_connection = options.get("directconnection") + self.__local_threshold_ms = options.get("localthresholdms", common.LOCAL_THRESHOLD_MS) + # self.__server_selection_timeout is in seconds. Must use full name for + # common.SERVER_SELECTION_TIMEOUT because it is set directly by tests. 
+ self.__server_selection_timeout = options.get( + "serverselectiontimeoutms", common.SERVER_SELECTION_TIMEOUT + ) + self.__pool_options = _parse_pool_options(username, password, database, options) + self.__read_preference = _parse_read_preference(options) + self.__replica_set_name = options.get("replicaset") + self.__write_concern = _parse_write_concern(options) + self.__read_concern = _parse_read_concern(options) + self.__connect = options.get("connect") + self.__heartbeat_frequency = options.get("heartbeatfrequencyms", common.HEARTBEAT_FREQUENCY) + self.__retry_writes = options.get("retrywrites", common.RETRY_WRITES) + self.__retry_reads = options.get("retryreads", common.RETRY_READS) + self.__server_selector = options.get("server_selector", any_server_selector) + self.__auto_encryption_opts = options.get("auto_encryption_opts") + self.__load_balanced = options.get("loadbalanced") + self.__timeout = options.get("timeoutms") + self.__server_monitoring_mode = options.get( + "servermonitoringmode", common.SERVER_MONITORING_MODE + ) + + @property + def _options(self) -> Mapping[str, Any]: + """The original options used to create this ClientOptions.""" + return self.__options + + @property + def connect(self) -> Optional[bool]: + """Whether to begin discovering a MongoDB topology automatically.""" + return self.__connect + + @property + def codec_options(self) -> CodecOptions: + """A :class:`~bson.codec_options.CodecOptions` instance.""" + return self.__codec_options + + @property + def direct_connection(self) -> Optional[bool]: + """Whether to connect to the deployment in 'Single' topology.""" + return self.__direct_connection + + @property + def local_threshold_ms(self) -> int: + """The local threshold for this instance.""" + return self.__local_threshold_ms + + @property + def server_selection_timeout(self) -> int: + """The server selection timeout for this instance in seconds.""" + return self.__server_selection_timeout + + @property + def server_selector(self) -> _ServerSelector: + return self.__server_selector + + @property + def heartbeat_frequency(self) -> int: + """The monitoring frequency in seconds.""" + return self.__heartbeat_frequency + + @property + def pool_options(self) -> PoolOptions: + """A :class:`~pymongo.pool.PoolOptions` instance.""" + return self.__pool_options + + @property + def read_preference(self) -> _ServerMode: + """A read preference instance.""" + return self.__read_preference + + @property + def replica_set_name(self) -> Optional[str]: + """Replica set name or None.""" + return self.__replica_set_name + + @property + def write_concern(self) -> WriteConcern: + """A :class:`~pymongo.write_concern.WriteConcern` instance.""" + return self.__write_concern + + @property + def read_concern(self) -> ReadConcern: + """A :class:`~pymongo.read_concern.ReadConcern` instance.""" + return self.__read_concern + + @property + def timeout(self) -> Optional[float]: + """The configured timeoutMS converted to seconds, or None. + + .. 
versionadded:: 4.2 + """ + return self.__timeout + + @property + def retry_writes(self) -> bool: + """If this instance should retry supported write operations.""" + return self.__retry_writes + + @property + def retry_reads(self) -> bool: + """If this instance should retry supported read operations.""" + return self.__retry_reads + + @property + def auto_encryption_opts(self) -> Optional[AutoEncryptionOpts]: + """A :class:`~pymongo.encryption.AutoEncryptionOpts` or None.""" + return self.__auto_encryption_opts + + @property + def load_balanced(self) -> Optional[bool]: + """True if the client was configured to connect to a load balancer.""" + return self.__load_balanced + + @property + def event_listeners(self) -> list[_EventListeners]: + """The event listeners registered for this client. + + See :mod:`~pymongo.monitoring` for details. + + .. versionadded:: 4.0 + """ + assert self.__pool_options._event_listeners is not None + return self.__pool_options._event_listeners.event_listeners() + + @property + def server_monitoring_mode(self) -> str: + """The configured serverMonitoringMode option. + + .. versionadded:: 4.5 + """ + return self.__server_monitoring_mode diff --git a/venv/Lib/site-packages/pymongo/client_session.py b/venv/Lib/site-packages/pymongo/client_session.py new file mode 100644 index 00000000..3efc624c --- /dev/null +++ b/venv/Lib/site-packages/pymongo/client_session.py @@ -0,0 +1,1155 @@ +# Copyright 2017 MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Logical sessions for ordering sequential operations. + +.. versionadded:: 3.6 + +Causally Consistent Reads +========================= + +.. code-block:: python + + with client.start_session(causal_consistency=True) as session: + collection = client.db.collection + collection.update_one({"_id": 1}, {"$set": {"x": 10}}, session=session) + secondary_c = collection.with_options(read_preference=ReadPreference.SECONDARY) + + # A secondary read waits for replication of the write. + secondary_c.find_one({"_id": 1}, session=session) + +If `causal_consistency` is True (the default), read operations that use +the session are causally after previous read and write operations. Using a +causally consistent session, an application can read its own writes and is +guaranteed monotonic reads, even when reading from replica set secondaries. + +.. seealso:: The MongoDB documentation on `causal-consistency `_. + +.. _transactions-ref: + +Transactions +============ + +.. versionadded:: 3.7 + +MongoDB 4.0 adds support for transactions on replica set primaries. A +transaction is associated with a :class:`ClientSession`. To start a transaction +on a session, use :meth:`ClientSession.start_transaction` in a with-statement. +Then, execute an operation within the transaction by passing the session to the +operation: + +.. 
code-block:: python + + orders = client.db.orders + inventory = client.db.inventory + with client.start_session() as session: + with session.start_transaction(): + orders.insert_one({"sku": "abc123", "qty": 100}, session=session) + inventory.update_one( + {"sku": "abc123", "qty": {"$gte": 100}}, + {"$inc": {"qty": -100}}, + session=session, + ) + +Upon normal completion of the ``with session.start_transaction()`` block, the +transaction automatically calls :meth:`ClientSession.commit_transaction`. +If the block exits with an exception, the transaction automatically calls +:meth:`ClientSession.abort_transaction`. + +In general, multi-document transactions only support read/write (CRUD) +operations on existing collections. However, MongoDB 4.4 adds limited support +for creating collections and indexes inside a transaction, including an +insert operation that results in the implicit creation of a new collection. +For a complete description of all the supported and unsupported operations +see the `MongoDB server's documentation for transactions +`_. + +A session may only have a single active transaction at a time; multiple +transactions on the same session can be executed in sequence. + +Sharded Transactions +^^^^^^^^^^^^^^^^^^^^ + +.. versionadded:: 3.9 + +PyMongo 3.9 adds support for transactions on sharded clusters running MongoDB +>=4.2. Sharded transactions have the same API as replica set transactions. +When running a transaction against a sharded cluster, the session is +pinned to the mongos server selected for the first operation in the +transaction. All subsequent operations that are part of the same transaction +are routed to the same mongos server. When the transaction is completed by +running either commitTransaction or abortTransaction, the session is unpinned. + +.. seealso:: The MongoDB documentation on `transactions `_. + +.. _snapshot-reads-ref: + +Snapshot Reads +============== + +.. versionadded:: 3.12 + +MongoDB 5.0 adds support for snapshot reads. Snapshot reads are requested by +passing the ``snapshot`` option to +:meth:`~pymongo.mongo_client.MongoClient.start_session`. +If ``snapshot`` is True, all read operations that use this session read data +from the same snapshot timestamp. The server chooses the latest +majority-committed snapshot timestamp when executing the first read operation +using the session. Subsequent reads on this session read from the same +snapshot timestamp. Snapshot reads are also supported when reading from +replica set secondaries. + +.. code-block:: python + + # Each read using this session reads data from the same point in time. + with client.start_session(snapshot=True) as session: + order = orders.find_one({"sku": "abc123"}, session=session) + inventory = inventory.find_one({"sku": "abc123"}, session=session) + +Snapshot Reads Limitations +^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Snapshot reads sessions are incompatible with ``causal_consistency=True``.
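+For example, requesting both raises an error when the session is created +(the message below comes from :class:`SessionOptions`): + +.. code-block:: python + + # Raises ConfigurationError: snapshot reads do not support + # causal_consistency=True + session = client.start_session(snapshot=True, causal_consistency=True) +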
+Only the following read operations are supported in a snapshot reads session: + +- :meth:`~pymongo.collection.Collection.find` +- :meth:`~pymongo.collection.Collection.find_one` +- :meth:`~pymongo.collection.Collection.aggregate` +- :meth:`~pymongo.collection.Collection.count_documents` +- :meth:`~pymongo.collection.Collection.distinct` (on unsharded collections) + +Classes +======= +""" + +from __future__ import annotations + +import collections +import time +import uuid +from collections.abc import Mapping as _Mapping +from typing import ( + TYPE_CHECKING, + Any, + Callable, + ContextManager, + Mapping, + MutableMapping, + NoReturn, + Optional, + Type, + TypeVar, +) + +from bson.binary import Binary +from bson.int64 import Int64 +from bson.timestamp import Timestamp +from pymongo import _csot +from pymongo.cursor import _ConnectionManager +from pymongo.errors import ( + ConfigurationError, + ConnectionFailure, + InvalidOperation, + OperationFailure, + PyMongoError, + WTimeoutError, +) +from pymongo.helpers import _RETRYABLE_ERROR_CODES +from pymongo.operations import _Op +from pymongo.read_concern import ReadConcern +from pymongo.read_preferences import ReadPreference, _ServerMode +from pymongo.server_type import SERVER_TYPE +from pymongo.write_concern import WriteConcern + +if TYPE_CHECKING: + from types import TracebackType + + from pymongo.pool import Connection + from pymongo.server import Server + from pymongo.typings import ClusterTime, _Address + + +class SessionOptions: + """Options for a new :class:`ClientSession`. + + :param causal_consistency: If True, read operations are causally + ordered within the session. Defaults to True when the ``snapshot`` + option is ``False``. + :param default_transaction_options: The default + TransactionOptions to use for transactions started on this session. + :param snapshot: If True, then all reads performed using this + session will read from the same snapshot. This option is incompatible + with ``causal_consistency=True``. Defaults to ``False``. + + .. versionchanged:: 3.12 + Added the ``snapshot`` parameter. + """ + + def __init__( + self, + causal_consistency: Optional[bool] = None, + default_transaction_options: Optional[TransactionOptions] = None, + snapshot: Optional[bool] = False, + ) -> None: + if snapshot: + if causal_consistency: + raise ConfigurationError("snapshot reads do not support causal_consistency=True") + causal_consistency = False + elif causal_consistency is None: + causal_consistency = True + self._causal_consistency = causal_consistency + if default_transaction_options is not None: + if not isinstance(default_transaction_options, TransactionOptions): + raise TypeError( + "default_transaction_options must be an instance of " + "pymongo.client_session.TransactionOptions, not: {!r}".format( + default_transaction_options + ) + ) + self._default_transaction_options = default_transaction_options + self._snapshot = snapshot + + @property + def causal_consistency(self) -> bool: + """Whether causal consistency is configured.""" + return self._causal_consistency + + @property + def default_transaction_options(self) -> Optional[TransactionOptions]: + """The default TransactionOptions to use for transactions started on + this session. + + .. versionadded:: 3.7 + """ + return self._default_transaction_options + + @property + def snapshot(self) -> Optional[bool]: + """Whether snapshot reads are configured. + + .. 
versionadded:: 3.12 + """ + return self._snapshot + + +class TransactionOptions: + """Options for :meth:`ClientSession.start_transaction`. + + :param read_concern: The + :class:`~pymongo.read_concern.ReadConcern` to use for this transaction. + If ``None`` (the default) the :attr:`read_concern` of + the :class:`MongoClient` is used. + :param write_concern: The + :class:`~pymongo.write_concern.WriteConcern` to use for this + transaction. If ``None`` (the default) the :attr:`write_concern` of + the :class:`MongoClient` is used. + :param read_preference: The read preference to use. If + ``None`` (the default) the :attr:`read_preference` of this + :class:`MongoClient` is used. See :mod:`~pymongo.read_preferences` + for options. Transactions which read must use + :attr:`~pymongo.read_preferences.ReadPreference.PRIMARY`. + :param max_commit_time_ms: The maximum amount of time to allow a + single commitTransaction command to run. This option is an alias for + the maxTimeMS option on the commitTransaction command. If ``None`` (the + default) maxTimeMS is not used. + + .. versionchanged:: 3.9 + Added the ``max_commit_time_ms`` option. + + .. versionadded:: 3.7 + """ + + def __init__( + self, + read_concern: Optional[ReadConcern] = None, + write_concern: Optional[WriteConcern] = None, + read_preference: Optional[_ServerMode] = None, + max_commit_time_ms: Optional[int] = None, + ) -> None: + self._read_concern = read_concern + self._write_concern = write_concern + self._read_preference = read_preference + self._max_commit_time_ms = max_commit_time_ms + if read_concern is not None: + if not isinstance(read_concern, ReadConcern): + raise TypeError( + "read_concern must be an instance of " + f"pymongo.read_concern.ReadConcern, not: {read_concern!r}" + ) + if write_concern is not None: + if not isinstance(write_concern, WriteConcern): + raise TypeError( + "write_concern must be an instance of " + f"pymongo.write_concern.WriteConcern, not: {write_concern!r}" + ) + if not write_concern.acknowledged: + raise ConfigurationError( + "transactions do not support unacknowledged write concern" + f": {write_concern!r}" + ) + if read_preference is not None: + if not isinstance(read_preference, _ServerMode): + raise TypeError( + f"{read_preference!r} is not valid for read_preference. See " + "pymongo.read_preferences for valid " + "options." + ) + if max_commit_time_ms is not None: + if not isinstance(max_commit_time_ms, int): + raise TypeError("max_commit_time_ms must be an integer or None") + + @property + def read_concern(self) -> Optional[ReadConcern]: + """This transaction's :class:`~pymongo.read_concern.ReadConcern`.""" + return self._read_concern + + @property + def write_concern(self) -> Optional[WriteConcern]: + """This transaction's :class:`~pymongo.write_concern.WriteConcern`.""" + return self._write_concern + + @property + def read_preference(self) -> Optional[_ServerMode]: + """This transaction's :class:`~pymongo.read_preferences.ReadPreference`.""" + return self._read_preference + + @property + def max_commit_time_ms(self) -> Optional[int]: + """The maxTimeMS to use when running a commitTransaction command. + + .. versionadded:: 3.9 + """ + return self._max_commit_time_ms + + +def _validate_session_write_concern( + session: Optional[ClientSession], write_concern: Optional[WriteConcern] +) -> Optional[ClientSession]: + """Validate that an explicit session is not used with an unack'ed write. + + Returns the session to use for the next operation.
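+    For an unacknowledged write, an implicit session is dropped (``None`` is +    returned), while an explicit session raises +    :class:`~pymongo.errors.ConfigurationError`.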
+ """ + if session: + if write_concern is not None and not write_concern.acknowledged: + # For unacknowledged writes without an explicit session, + # drivers SHOULD NOT use an implicit session. If a driver + # creates an implicit session for unacknowledged writes + # without an explicit session, the driver MUST NOT send the + # session ID. + if session._implicit: + return None + else: + raise ConfigurationError( + "Explicit sessions are incompatible with " + f"unacknowledged write concern: {write_concern!r}" + ) + return session + + +class _TransactionContext: + """Internal transaction context manager for start_transaction.""" + + def __init__(self, session: ClientSession): + self.__session = session + + def __enter__(self) -> _TransactionContext: + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + if self.__session.in_transaction: + if exc_val is None: + self.__session.commit_transaction() + else: + self.__session.abort_transaction() + + +class _TxnState: + NONE = 1 + STARTING = 2 + IN_PROGRESS = 3 + COMMITTED = 4 + COMMITTED_EMPTY = 5 + ABORTED = 6 + + +class _Transaction: + """Internal class to hold transaction information in a ClientSession.""" + + def __init__(self, opts: Optional[TransactionOptions], client: MongoClient): + self.opts = opts + self.state = _TxnState.NONE + self.sharded = False + self.pinned_address: Optional[_Address] = None + self.conn_mgr: Optional[_ConnectionManager] = None + self.recovery_token = None + self.attempt = 0 + self.client = client + + def active(self) -> bool: + return self.state in (_TxnState.STARTING, _TxnState.IN_PROGRESS) + + def starting(self) -> bool: + return self.state == _TxnState.STARTING + + @property + def pinned_conn(self) -> Optional[Connection]: + if self.active() and self.conn_mgr: + return self.conn_mgr.conn + return None + + def pin(self, server: Server, conn: Connection) -> None: + self.sharded = True + self.pinned_address = server.description.address + if server.description.server_type == SERVER_TYPE.LoadBalancer: + conn.pin_txn() + self.conn_mgr = _ConnectionManager(conn, False) + + def unpin(self) -> None: + self.pinned_address = None + if self.conn_mgr: + self.conn_mgr.close() + self.conn_mgr = None + + def reset(self) -> None: + self.unpin() + self.state = _TxnState.NONE + self.sharded = False + self.recovery_token = None + self.attempt = 0 + + def __del__(self) -> None: + if self.conn_mgr: + # Reuse the cursor closing machinery to return the socket to the + # pool soon. + self.client._close_cursor_soon(0, None, self.conn_mgr) + self.conn_mgr = None + + +def _reraise_with_unknown_commit(exc: Any) -> NoReturn: + """Re-raise an exception with the UnknownTransactionCommitResult label.""" + exc._add_error_label("UnknownTransactionCommitResult") + raise + + +def _max_time_expired_error(exc: PyMongoError) -> bool: + """Return true if exc is a MaxTimeMSExpired error.""" + return isinstance(exc, OperationFailure) and exc.code == 50 + + +# From the transactions spec, all the retryable writes errors plus +# WriteConcernFailed. +_UNKNOWN_COMMIT_ERROR_CODES: frozenset = _RETRYABLE_ERROR_CODES | frozenset( + [ + 64, # WriteConcernFailed + 50, # MaxTimeMSExpired + ] +) + +# From the Convenient API for Transactions spec, with_transaction must +# halt retries after 120 seconds. +# This limit is non-configurable and was chosen to be twice the 60 second +# default value of MongoDB's `transactionLifetimeLimitSeconds` parameter. 
+_WITH_TRANSACTION_RETRY_TIME_LIMIT = 120 + + +def _within_time_limit(start_time: float) -> bool: + """Are we within the with_transaction retry limit?""" + return time.monotonic() - start_time < _WITH_TRANSACTION_RETRY_TIME_LIMIT + + +_T = TypeVar("_T") + +if TYPE_CHECKING: + from pymongo.mongo_client import MongoClient + + +class ClientSession: + """A session for ordering sequential operations. + + :class:`ClientSession` instances are **not thread-safe or fork-safe**. + They can only be used by one thread or process at a time. A single + :class:`ClientSession` cannot be used to run multiple operations + concurrently. + + Should not be initialized directly by application developers - to create a + :class:`ClientSession`, call + :meth:`~pymongo.mongo_client.MongoClient.start_session`. + """ + + def __init__( + self, + client: MongoClient, + server_session: Any, + options: SessionOptions, + implicit: bool, + ) -> None: + # A MongoClient, a _ServerSession, a SessionOptions, and a set. + self._client: MongoClient = client + self._server_session = server_session + self._options = options + self._cluster_time: Optional[Mapping[str, Any]] = None + self._operation_time: Optional[Timestamp] = None + self._snapshot_time = None + # Is this an implicitly created session? + self._implicit = implicit + self._transaction = _Transaction(None, client) + + def end_session(self) -> None: + """Finish this session. If a transaction has started, abort it. + + It is an error to use the session after the session has ended. + """ + self._end_session(lock=True) + + def _end_session(self, lock: bool) -> None: + if self._server_session is not None: + try: + if self.in_transaction: + self.abort_transaction() + # It's possible we're still pinned here when the transaction + # is in the committed state when the session is discarded. + self._unpin() + finally: + self._client._return_server_session(self._server_session, lock) + self._server_session = None + + def _check_ended(self) -> None: + if self._server_session is None: + raise InvalidOperation("Cannot use ended session") + + def __enter__(self) -> ClientSession: + return self + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + self._end_session(lock=True) + + @property + def client(self) -> MongoClient: + """The :class:`~pymongo.mongo_client.MongoClient` this session was + created from. + """ + return self._client + + @property + def options(self) -> SessionOptions: + """The :class:`SessionOptions` this session was created with.""" + return self._options + + @property + def session_id(self) -> Mapping[str, Any]: + """A BSON document, the opaque server session identifier.""" + self._check_ended() + self._materialize(self._client.topology_description.logical_session_timeout_minutes) + return self._server_session.session_id + + @property + def _transaction_id(self) -> Int64: + """The current transaction id for the underlying server session.""" + self._materialize(self._client.topology_description.logical_session_timeout_minutes) + return self._server_session.transaction_id + + @property + def cluster_time(self) -> Optional[ClusterTime]: + """The cluster time returned by the last operation executed + in this session. + """ + return self._cluster_time + + @property + def operation_time(self) -> Optional[Timestamp]: + """The operation time returned by the last operation executed + in this session. 
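+        It can be passed to :meth:`advance_operation_time` on another +        session to causally order that session's operations after this one's.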
+ """ + return self._operation_time + + def _inherit_option(self, name: str, val: _T) -> _T: + """Return the inherited TransactionOption value.""" + if val: + return val + txn_opts = self.options.default_transaction_options + parent_val = txn_opts and getattr(txn_opts, name) + if parent_val: + return parent_val + return getattr(self.client, name) + + def with_transaction( + self, + callback: Callable[[ClientSession], _T], + read_concern: Optional[ReadConcern] = None, + write_concern: Optional[WriteConcern] = None, + read_preference: Optional[_ServerMode] = None, + max_commit_time_ms: Optional[int] = None, + ) -> _T: + """Execute a callback in a transaction. + + This method starts a transaction on this session, executes ``callback`` + once, and then commits the transaction. For example:: + + def callback(session): + orders = session.client.db.orders + inventory = session.client.db.inventory + orders.insert_one({"sku": "abc123", "qty": 100}, session=session) + inventory.update_one({"sku": "abc123", "qty": {"$gte": 100}}, + {"$inc": {"qty": -100}}, session=session) + + with client.start_session() as session: + session.with_transaction(callback) + + To pass arbitrary arguments to the ``callback``, wrap your callable + with a ``lambda`` like this:: + + def callback(session, custom_arg, custom_kwarg=None): + # Transaction operations... + + with client.start_session() as session: + session.with_transaction( + lambda s: callback(s, "custom_arg", custom_kwarg=1)) + + In the event of an exception, ``with_transaction`` may retry the commit + or the entire transaction, therefore ``callback`` may be invoked + multiple times by a single call to ``with_transaction``. Developers + should be mindful of this possibility when writing a ``callback`` that + modifies application state or has any other side-effects. + Note that even when the ``callback`` is invoked multiple times, + ``with_transaction`` ensures that the transaction will be committed + at-most-once on the server. + + The ``callback`` should not attempt to start new transactions, but + should simply run operations meant to be contained within a + transaction. The ``callback`` should also not commit the transaction; + this is handled automatically by ``with_transaction``. If the + ``callback`` does commit or abort the transaction without error, + however, ``with_transaction`` will return without taking further + action. + + :class:`ClientSession` instances are **not thread-safe or fork-safe**. + Consequently, the ``callback`` must not attempt to execute multiple + operations concurrently. + + When ``callback`` raises an exception, ``with_transaction`` + automatically aborts the current transaction. When ``callback`` or + :meth:`~ClientSession.commit_transaction` raises an exception that + includes the ``"TransientTransactionError"`` error label, + ``with_transaction`` starts a new transaction and re-executes + the ``callback``. + + When :meth:`~ClientSession.commit_transaction` raises an exception with + the ``"UnknownTransactionCommitResult"`` error label, + ``with_transaction`` retries the commit until the result of the + transaction is known. + + This method will cease retrying after 120 seconds has elapsed. This + timeout is not configurable and any exception raised by the + ``callback`` or by :meth:`ClientSession.commit_transaction` after the + timeout is reached will be re-raised. Applications that desire a + different timeout duration should not use this method. + + :param callback: The callable ``callback`` to run inside a transaction. 
+ The callable must accept a single argument, this session. Note that + under certain error conditions the callback may be run multiple + times. + :param read_concern: The + :class:`~pymongo.read_concern.ReadConcern` to use for this + transaction. + :param write_concern: The + :class:`~pymongo.write_concern.WriteConcern` to use for this + transaction. + :param read_preference: The read preference to use for this + transaction. If ``None`` (the default) the :attr:`read_preference` + of this session's :class:`MongoClient` is used. See + :mod:`~pymongo.read_preferences` for options. + :param max_commit_time_ms: The maximum amount of time to allow a + single commitTransaction command to run, in milliseconds. + + :return: The return value of the ``callback``. + + .. versionadded:: 3.9 + """ + start_time = time.monotonic() + while True: + self.start_transaction(read_concern, write_concern, read_preference, max_commit_time_ms) + try: + ret = callback(self) + except Exception as exc: + if self.in_transaction: + self.abort_transaction() + if ( + isinstance(exc, PyMongoError) + and exc.has_error_label("TransientTransactionError") + and _within_time_limit(start_time) + ): + # Retry the entire transaction. + continue + raise + + if not self.in_transaction: + # Assume callback intentionally ended the transaction. + return ret + + while True: + try: + self.commit_transaction() + except PyMongoError as exc: + if ( + exc.has_error_label("UnknownTransactionCommitResult") + and _within_time_limit(start_time) + and not _max_time_expired_error(exc) + ): + # Retry the commit. + continue + + if exc.has_error_label("TransientTransactionError") and _within_time_limit( + start_time + ): + # Retry the entire transaction. + break + raise + + # Commit succeeded. + return ret + + def start_transaction( + self, + read_concern: Optional[ReadConcern] = None, + write_concern: Optional[WriteConcern] = None, + read_preference: Optional[_ServerMode] = None, + max_commit_time_ms: Optional[int] = None, + ) -> ContextManager: + """Start a multi-statement transaction. + + Takes the same arguments as :class:`TransactionOptions`. + + .. versionchanged:: 3.9 + Added the ``max_commit_time_ms`` option. + + .. versionadded:: 3.7 + """ + self._check_ended() + + if self.options.snapshot: + raise InvalidOperation("Transactions are not supported in snapshot sessions") + + if self.in_transaction: + raise InvalidOperation("Transaction already in progress") + + read_concern = self._inherit_option("read_concern", read_concern) + write_concern = self._inherit_option("write_concern", write_concern) + read_preference = self._inherit_option("read_preference", read_preference) + if max_commit_time_ms is None: + opts = self.options.default_transaction_options + if opts: + max_commit_time_ms = opts.max_commit_time_ms + + self._transaction.opts = TransactionOptions( + read_concern, write_concern, read_preference, max_commit_time_ms + ) + self._transaction.reset() + self._transaction.state = _TxnState.STARTING + self._start_retryable_write() + return _TransactionContext(self) + + def commit_transaction(self) -> None: + """Commit a multi-statement transaction. + + .. versionadded:: 3.7 + """ + self._check_ended() + state = self._transaction.state + if state is _TxnState.NONE: + raise InvalidOperation("No transaction started") + elif state in (_TxnState.STARTING, _TxnState.COMMITTED_EMPTY): + # Server transaction was never started; no need to send a command.
+ self._transaction.state = _TxnState.COMMITTED_EMPTY + return + elif state is _TxnState.ABORTED: + raise InvalidOperation("Cannot call commitTransaction after calling abortTransaction") + elif state is _TxnState.COMMITTED: + # We're explicitly retrying the commit; move the state back to + # "in progress" so that in_transaction returns true. + self._transaction.state = _TxnState.IN_PROGRESS + + try: + self._finish_transaction_with_retry("commitTransaction") + except ConnectionFailure as exc: + # We do not know if the commit was successfully applied on the + # server or if it satisfied the provided write concern; set the + # unknown commit error label. + exc._remove_error_label("TransientTransactionError") + _reraise_with_unknown_commit(exc) + except WTimeoutError as exc: + # We do not know if the commit has satisfied the provided write + # concern; add the unknown commit error label. + _reraise_with_unknown_commit(exc) + except OperationFailure as exc: + if exc.code not in _UNKNOWN_COMMIT_ERROR_CODES: + # The server reports errorLabels in this case. + raise + # We do not know if the commit was successfully applied on the + # server or if it satisfied the provided write concern; set the + # unknown commit error label. + _reraise_with_unknown_commit(exc) + finally: + self._transaction.state = _TxnState.COMMITTED + + def abort_transaction(self) -> None: + """Abort a multi-statement transaction. + + .. versionadded:: 3.7 + """ + self._check_ended() + + state = self._transaction.state + if state is _TxnState.NONE: + raise InvalidOperation("No transaction started") + elif state is _TxnState.STARTING: + # Server transaction was never started; no need to send a command. + self._transaction.state = _TxnState.ABORTED + return + elif state is _TxnState.ABORTED: + raise InvalidOperation("Cannot call abortTransaction twice") + elif state in (_TxnState.COMMITTED, _TxnState.COMMITTED_EMPTY): + raise InvalidOperation("Cannot call abortTransaction after calling commitTransaction") + + try: + self._finish_transaction_with_retry("abortTransaction") + except (OperationFailure, ConnectionFailure): + # The transactions spec says to ignore abortTransaction errors. + pass + finally: + self._transaction.state = _TxnState.ABORTED + self._unpin() + + def _finish_transaction_with_retry(self, command_name: str) -> dict[str, Any]: + """Run commit or abort with one retry after any retryable error. + + :param command_name: Either "commitTransaction" or "abortTransaction". + """ + + def func( + _session: Optional[ClientSession], conn: Connection, _retryable: bool + ) -> dict[str, Any]: + return self._finish_transaction(conn, command_name) + + return self._client._retry_internal(func, self, None, retryable=True, operation=_Op.ABORT) + + def _finish_transaction(self, conn: Connection, command_name: str) -> dict[str, Any]: + self._transaction.attempt += 1 + opts = self._transaction.opts + assert opts + wc = opts.write_concern + cmd = {command_name: 1} + if command_name == "commitTransaction": + if opts.max_commit_time_ms and _csot.get_timeout() is None: + cmd["maxTimeMS"] = opts.max_commit_time_ms + + # Transaction spec says that after the initial commit attempt, + # subsequent commitTransaction commands should be upgraded to use + # w:"majority" and set a default value of 10 seconds for wtimeout.
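+            # For example, a transaction committed with WriteConcern(w=1) +            # retries its commit with WriteConcern(w="majority", +            # wtimeout=10000).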
+ if self._transaction.attempt > 1: + assert wc + wc_doc = wc.document + wc_doc["w"] = "majority" + wc_doc.setdefault("wtimeout", 10000) + wc = WriteConcern(**wc_doc) + + if self._transaction.recovery_token: + cmd["recoveryToken"] = self._transaction.recovery_token + + return self._client.admin._command( + conn, cmd, session=self, write_concern=wc, parse_write_concern_error=True + ) + + def _advance_cluster_time(self, cluster_time: Optional[Mapping[str, Any]]) -> None: + """Internal cluster time helper.""" + if self._cluster_time is None: + self._cluster_time = cluster_time + elif cluster_time is not None: + if cluster_time["clusterTime"] > self._cluster_time["clusterTime"]: + self._cluster_time = cluster_time + + def advance_cluster_time(self, cluster_time: Mapping[str, Any]) -> None: + """Update the cluster time for this session. + + :param cluster_time: The + :data:`~pymongo.client_session.ClientSession.cluster_time` from + another `ClientSession` instance. + """ + if not isinstance(cluster_time, _Mapping): + raise TypeError("cluster_time must be an instance of collections.abc.Mapping") + if not isinstance(cluster_time.get("clusterTime"), Timestamp): + raise ValueError("Invalid cluster_time") + self._advance_cluster_time(cluster_time) + + def _advance_operation_time(self, operation_time: Optional[Timestamp]) -> None: + """Internal operation time helper.""" + if self._operation_time is None: + self._operation_time = operation_time + elif operation_time is not None: + if operation_time > self._operation_time: + self._operation_time = operation_time + + def advance_operation_time(self, operation_time: Timestamp) -> None: + """Update the operation time for this session. + + :param operation_time: The + :data:`~pymongo.client_session.ClientSession.operation_time` from + another `ClientSession` instance. + """ + if not isinstance(operation_time, Timestamp): + raise TypeError("operation_time must be an instance of bson.timestamp.Timestamp") + self._advance_operation_time(operation_time) + + def _process_response(self, reply: Mapping[str, Any]) -> None: + """Process a response to a command that was run with this session.""" + self._advance_cluster_time(reply.get("$clusterTime")) + self._advance_operation_time(reply.get("operationTime")) + if self._options.snapshot and self._snapshot_time is None: + if "cursor" in reply: + ct = reply["cursor"].get("atClusterTime") + else: + ct = reply.get("atClusterTime") + self._snapshot_time = ct + if self.in_transaction and self._transaction.sharded: + recovery_token = reply.get("recoveryToken") + if recovery_token: + self._transaction.recovery_token = recovery_token + + @property + def has_ended(self) -> bool: + """True if this session is finished.""" + return self._server_session is None + + @property + def in_transaction(self) -> bool: + """True if this session has an active multi-statement transaction. + + ..
versionadded:: 3.10 + """ + return self._transaction.active() + + @property + def _starting_transaction(self) -> bool: + """True if this session is starting a multi-statement transaction.""" + return self._transaction.starting() + + @property + def _pinned_address(self) -> Optional[_Address]: + """The mongos address this transaction was created on.""" + if self._transaction.active(): + return self._transaction.pinned_address + return None + + @property + def _pinned_connection(self) -> Optional[Connection]: + """The connection this transaction was started on.""" + return self._transaction.pinned_conn + + def _pin(self, server: Server, conn: Connection) -> None: + """Pin this session to the given Server or to the given connection.""" + self._transaction.pin(server, conn) + + def _unpin(self) -> None: + """Unpin this session from any pinned Server.""" + self._transaction.unpin() + + def _txn_read_preference(self) -> Optional[_ServerMode]: + """Return read preference of this transaction or None.""" + if self.in_transaction: + assert self._transaction.opts + return self._transaction.opts.read_preference + return None + + def _materialize(self, logical_session_timeout_minutes: Optional[int] = None) -> None: + if isinstance(self._server_session, _EmptyServerSession): + old = self._server_session + self._server_session = self._client._topology.get_server_session( + logical_session_timeout_minutes + ) + if old.started_retryable_write: + self._server_session.inc_transaction_id() + + def _apply_to( + self, + command: MutableMapping[str, Any], + is_retryable: bool, + read_preference: _ServerMode, + conn: Connection, + ) -> None: + if not conn.supports_sessions: + if not self._implicit: + raise ConfigurationError("Sessions are not supported by this MongoDB deployment") + return + self._check_ended() + self._materialize(conn.logical_session_timeout_minutes) + if self.options.snapshot: + self._update_read_concern(command, conn) + + self._server_session.last_use = time.monotonic() + command["lsid"] = self._server_session.session_id + + if is_retryable: + command["txnNumber"] = self._server_session.transaction_id + return + + if self.in_transaction: + if read_preference != ReadPreference.PRIMARY: + raise InvalidOperation( + f"read preference in a transaction must be primary, not: {read_preference!r}" + ) + + if self._transaction.state == _TxnState.STARTING: + # First command begins a new transaction. 
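+                # startTransaction is attached only to this first command; +                # once the state moves to IN_PROGRESS, later commands in the +                # transaction skip this branch.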
+ self._transaction.state = _TxnState.IN_PROGRESS + command["startTransaction"] = True + + assert self._transaction.opts + if self._transaction.opts.read_concern: + rc = self._transaction.opts.read_concern.document + if rc: + command["readConcern"] = rc + self._update_read_concern(command, conn) + + command["txnNumber"] = self._server_session.transaction_id + command["autocommit"] = False + + def _start_retryable_write(self) -> None: + self._check_ended() + self._server_session.inc_transaction_id() + + def _update_read_concern(self, cmd: MutableMapping[str, Any], conn: Connection) -> None: + if self.options.causal_consistency and self.operation_time is not None: + cmd.setdefault("readConcern", {})["afterClusterTime"] = self.operation_time + if self.options.snapshot: + if conn.max_wire_version < 13: + raise ConfigurationError("Snapshot reads require MongoDB 5.0 or later") + rc = cmd.setdefault("readConcern", {}) + rc["level"] = "snapshot" + if self._snapshot_time is not None: + rc["atClusterTime"] = self._snapshot_time + + def __copy__(self) -> NoReturn: + raise TypeError("A ClientSession cannot be copied, create a new session instead") + + +class _EmptyServerSession: + __slots__ = "dirty", "started_retryable_write" + + def __init__(self) -> None: + self.dirty = False + self.started_retryable_write = False + + def mark_dirty(self) -> None: + self.dirty = True + + def inc_transaction_id(self) -> None: + self.started_retryable_write = True + + +class _ServerSession: + def __init__(self, generation: int): + # Ensure id is type 4, regardless of CodecOptions.uuid_representation. + self.session_id = {"id": Binary(uuid.uuid4().bytes, 4)} + self.last_use = time.monotonic() + self._transaction_id = 0 + self.dirty = False + self.generation = generation + + def mark_dirty(self) -> None: + """Mark this session as dirty. + + A server session is marked dirty when a command fails with a network + error. Dirty sessions are later discarded from the server session pool. + """ + self.dirty = True + + def timed_out(self, session_timeout_minutes: Optional[int]) -> bool: + if session_timeout_minutes is None: + return False + + idle_seconds = time.monotonic() - self.last_use + + # Timed out if we have less than a minute to live. + return idle_seconds > (session_timeout_minutes - 1) * 60 + + @property + def transaction_id(self) -> Int64: + """Positive 64-bit integer.""" + return Int64(self._transaction_id) + + def inc_transaction_id(self) -> None: + self._transaction_id += 1 + + +class _ServerSessionPool(collections.deque): + """Pool of _ServerSession objects. + + This class is not thread-safe, access it while holding the Topology lock. + """ + + def __init__(self, *args: Any, **kwargs: Any): + super().__init__(*args, **kwargs) + self.generation = 0 + + def reset(self) -> None: + self.generation += 1 + self.clear() + + def pop_all(self) -> list[_ServerSession]: + ids = [] + while self: + ids.append(self.pop().session_id) + return ids + + def get_server_session(self, session_timeout_minutes: Optional[int]) -> _ServerSession: + # Although the Driver Sessions Spec says we only clear stale sessions + # in return_server_session, PyMongo can't take a lock when returning + # sessions from a __del__ method (like in Cursor.__die), so it can't + # clear stale sessions there. In case many sessions were returned via + # __del__, check for stale sessions here too. + self._clear_stale(session_timeout_minutes) + + # The most recently used sessions are on the left. 
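+        # popleft() therefore returns the most recently used session, the +        # one least likely to be near its timeout.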
+ while self: + s = self.popleft() + if not s.timed_out(session_timeout_minutes): + return s + + return _ServerSession(self.generation) + + def return_server_session( + self, server_session: _ServerSession, session_timeout_minutes: Optional[int] + ) -> None: + if session_timeout_minutes is not None: + self._clear_stale(session_timeout_minutes) + if server_session.timed_out(session_timeout_minutes): + return + self.return_server_session_no_lock(server_session) + + def return_server_session_no_lock(self, server_session: _ServerSession) -> None: + # Discard sessions from an old pool to avoid duplicate sessions in the + # child process after a fork. + if server_session.generation == self.generation and not server_session.dirty: + self.appendleft(server_session) + + def _clear_stale(self, session_timeout_minutes: Optional[int]) -> None: + # Clear stale sessions. The least recently used are on the right. + while self: + if self[-1].timed_out(session_timeout_minutes): + self.pop() + else: + # The remaining sessions also haven't timed out. + break diff --git a/venv/Lib/site-packages/pymongo/collation.py b/venv/Lib/site-packages/pymongo/collation.py new file mode 100644 index 00000000..971628f4 --- /dev/null +++ b/venv/Lib/site-packages/pymongo/collation.py @@ -0,0 +1,224 @@ +# Copyright 2016 MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tools for working with `collations`_. + +.. _collations: https://www.mongodb.com/docs/manual/reference/collation/ +""" +from __future__ import annotations + +from typing import Any, Mapping, Optional, Union + +from pymongo import common +from pymongo.write_concern import validate_boolean + + +class CollationStrength: + """ + An enum that defines values for `strength` on a + :class:`~pymongo.collation.Collation`. + """ + + PRIMARY = 1 + """Differentiate base (unadorned) characters.""" + + SECONDARY = 2 + """Differentiate character accents.""" + + TERTIARY = 3 + """Differentiate character case.""" + + QUATERNARY = 4 + """Differentiate words with and without punctuation.""" + + IDENTICAL = 5 + """Differentiate unicode code point (characters are exactly identical).""" + + +class CollationAlternate: + """ + An enum that defines values for `alternate` on a + :class:`~pymongo.collation.Collation`. + """ + + NON_IGNORABLE = "non-ignorable" + """Spaces and punctuation are treated as base characters.""" + + SHIFTED = "shifted" + """Spaces and punctuation are *not* considered base characters. + + Spaces and punctuation are distinguished regardless when the + :class:`~pymongo.collation.Collation` strength is at least + :data:`~pymongo.collation.CollationStrength.QUATERNARY`. + + """ + + +class CollationMaxVariable: + """ + An enum that defines values for `max_variable` on a + :class:`~pymongo.collation.Collation`. 
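+    Only applies when `alternate` is :data:`~CollationAlternate.SHIFTED`.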
+ """ + + PUNCT = "punct" + """Both punctuation and spaces are ignored.""" + + SPACE = "space" + """Spaces alone are ignored.""" + + +class CollationCaseFirst: + """ + An enum that defines values for `case_first` on a + :class:`~pymongo.collation.Collation`. + """ + + UPPER = "upper" + """Sort uppercase characters first.""" + + LOWER = "lower" + """Sort lowercase characters first.""" + + OFF = "off" + """Default for locale or collation strength.""" + + +class Collation: + """Collation + + :param locale: (string) The locale of the collation. This should be a string + that identifies an `ICU locale ID` exactly. For example, ``en_US`` is + valid, but ``en_us`` and ``en-US`` are not. Consult the MongoDB + documentation for a list of supported locales. + :param caseLevel: (optional) If ``True``, turn on case sensitivity if + `strength` is 1 or 2 (case sensitivity is implied if `strength` is + greater than 2). Defaults to ``False``. + :param caseFirst: (optional) Specify that either uppercase or lowercase + characters take precedence. Must be one of the following values: + + * :data:`~CollationCaseFirst.UPPER` + * :data:`~CollationCaseFirst.LOWER` + * :data:`~CollationCaseFirst.OFF` (the default) + + :param strength: Specify the comparison strength. This is also + known as the ICU comparison level. This must be one of the following + values: + + * :data:`~CollationStrength.PRIMARY` + * :data:`~CollationStrength.SECONDARY` + * :data:`~CollationStrength.TERTIARY` (the default) + * :data:`~CollationStrength.QUATERNARY` + * :data:`~CollationStrength.IDENTICAL` + + Each successive level builds upon the previous. For example, a + `strength` of :data:`~CollationStrength.SECONDARY` differentiates + characters based both on the unadorned base character and its accents. + + :param numericOrdering: If ``True``, order numbers numerically + instead of in collation order (defaults to ``False``). + :param alternate: Specify whether spaces and punctuation are + considered base characters. This must be one of the following values: + + * :data:`~CollationAlternate.NON_IGNORABLE` (the default) + * :data:`~CollationAlternate.SHIFTED` + + :param maxVariable: When `alternate` is + :data:`~CollationAlternate.SHIFTED`, this option specifies what + characters may be ignored. This must be one of the following values: + + * :data:`~CollationMaxVariable.PUNCT` (the default) + * :data:`~CollationMaxVariable.SPACE` + + :param normalization: If ``True``, normalizes text into Unicode + NFD. Defaults to ``False``. + :param backwards: If ``True``, accents on characters are + considered from the back of the word to the front, as it is done in some + French dictionary ordering traditions. Defaults to ``False``. + :param kwargs: Keyword arguments supplying any additional options + to be sent with this Collation object. + + .. 
versionadded:: 3.4 + + """ + + __slots__ = ("__document",) + + def __init__( + self, + locale: str, + caseLevel: Optional[bool] = None, + caseFirst: Optional[str] = None, + strength: Optional[int] = None, + numericOrdering: Optional[bool] = None, + alternate: Optional[str] = None, + maxVariable: Optional[str] = None, + normalization: Optional[bool] = None, + backwards: Optional[bool] = None, + **kwargs: Any, + ) -> None: + locale = common.validate_string("locale", locale) + self.__document: dict[str, Any] = {"locale": locale} + if caseLevel is not None: + self.__document["caseLevel"] = validate_boolean("caseLevel", caseLevel) + if caseFirst is not None: + self.__document["caseFirst"] = common.validate_string("caseFirst", caseFirst) + if strength is not None: + self.__document["strength"] = common.validate_integer("strength", strength) + if numericOrdering is not None: + self.__document["numericOrdering"] = validate_boolean( + "numericOrdering", numericOrdering + ) + if alternate is not None: + self.__document["alternate"] = common.validate_string("alternate", alternate) + if maxVariable is not None: + self.__document["maxVariable"] = common.validate_string("maxVariable", maxVariable) + if normalization is not None: + self.__document["normalization"] = validate_boolean("normalization", normalization) + if backwards is not None: + self.__document["backwards"] = validate_boolean("backwards", backwards) + self.__document.update(kwargs) + + @property + def document(self) -> dict[str, Any]: + """The document representation of this collation. + + .. note:: + :class:`Collation` is immutable. Mutating the value of + :attr:`document` does not mutate this :class:`Collation`. + """ + return self.__document.copy() + + def __repr__(self) -> str: + document = self.document + return "Collation({})".format(", ".join(f"{key}={document[key]!r}" for key in document)) + + def __eq__(self, other: Any) -> bool: + if isinstance(other, Collation): + return self.document == other.document + return NotImplemented + + def __ne__(self, other: Any) -> bool: + return not self == other + + +def validate_collation_or_none( + value: Optional[Union[Mapping[str, Any], Collation]] +) -> Optional[dict[str, Any]]: + if value is None: + return None + if isinstance(value, Collation): + return value.document + if isinstance(value, dict): + return value + raise TypeError("collation must be a dict, an instance of collation.Collation, or None.") diff --git a/venv/Lib/site-packages/pymongo/collection.py b/venv/Lib/site-packages/pymongo/collection.py new file mode 100644 index 00000000..ddfe9f1d --- /dev/null +++ b/venv/Lib/site-packages/pymongo/collection.py @@ -0,0 +1,3483 @@ +# Copyright 2009-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+ +"""Collection level utilities for Mongo.""" +from __future__ import annotations + +from collections import abc +from typing import ( + TYPE_CHECKING, + Any, + Callable, + ContextManager, + Generic, + Iterable, + Iterator, + Mapping, + MutableMapping, + NoReturn, + Optional, + Sequence, + Type, + TypeVar, + Union, + cast, +) + +from bson.codec_options import DEFAULT_CODEC_OPTIONS, CodecOptions +from bson.objectid import ObjectId +from bson.raw_bson import RawBSONDocument +from bson.son import SON +from bson.timestamp import Timestamp +from pymongo import ASCENDING, _csot, common, helpers, message +from pymongo.aggregation import ( + _CollectionAggregationCommand, + _CollectionRawAggregationCommand, +) +from pymongo.bulk import _Bulk +from pymongo.change_stream import CollectionChangeStream +from pymongo.collation import validate_collation_or_none +from pymongo.command_cursor import CommandCursor, RawBatchCommandCursor +from pymongo.common import _ecoc_coll_name, _esc_coll_name +from pymongo.cursor import Cursor, RawBatchCursor +from pymongo.errors import ( + ConfigurationError, + InvalidName, + InvalidOperation, + OperationFailure, +) +from pymongo.helpers import _check_write_command_response +from pymongo.message import _UNICODE_REPLACE_CODEC_OPTIONS +from pymongo.operations import ( + DeleteMany, + DeleteOne, + IndexModel, + InsertOne, + ReplaceOne, + SearchIndexModel, + UpdateMany, + UpdateOne, + _IndexKeyHint, + _IndexList, + _Op, +) +from pymongo.read_concern import DEFAULT_READ_CONCERN, ReadConcern +from pymongo.read_preferences import ReadPreference, _ServerMode +from pymongo.results import ( + BulkWriteResult, + DeleteResult, + InsertManyResult, + InsertOneResult, + UpdateResult, +) +from pymongo.typings import _CollationIn, _DocumentType, _DocumentTypeArg, _Pipeline +from pymongo.write_concern import DEFAULT_WRITE_CONCERN, WriteConcern, validate_boolean + +T = TypeVar("T") + +_FIND_AND_MODIFY_DOC_FIELDS = {"value": 1} + + +_WriteOp = Union[ + InsertOne[_DocumentType], + DeleteOne, + DeleteMany, + ReplaceOne[_DocumentType], + UpdateOne, + UpdateMany, +] + + +class ReturnDocument: + """An enum used with + :meth:`~pymongo.collection.Collection.find_one_and_replace` and + :meth:`~pymongo.collection.Collection.find_one_and_update`. + """ + + BEFORE = False + """Return the original document before it was updated/replaced, or + ``None`` if no document matches the query. + """ + AFTER = True + """Return the updated/replaced or inserted document.""" + + +if TYPE_CHECKING: + from pymongo.aggregation import _AggregationCommand + from pymongo.client_session import ClientSession + from pymongo.collation import Collation + from pymongo.database import Database + from pymongo.pool import Connection + from pymongo.server import Server + + +class Collection(common.BaseObject, Generic[_DocumentType]): + """A Mongo collection.""" + + def __init__( + self, + database: Database[_DocumentType], + name: str, + create: Optional[bool] = False, + codec_options: Optional[CodecOptions[_DocumentTypeArg]] = None, + read_preference: Optional[_ServerMode] = None, + write_concern: Optional[WriteConcern] = None, + read_concern: Optional[ReadConcern] = None, + session: Optional[ClientSession] = None, + **kwargs: Any, + ) -> None: + """Get / create a Mongo collection. + + Raises :class:`TypeError` if `name` is not an instance of + :class:`str`. Raises :class:`~pymongo.errors.InvalidName` if `name` is + not a valid collection name. 
Any additional keyword arguments will be used + as options passed to the create command. See + :meth:`~pymongo.database.Database.create_collection` for valid + options. + + If `create` is ``True``, `collation` is specified, or any additional + keyword arguments are present, a ``create`` command will be + sent, using ``session`` if specified. Otherwise, a ``create`` command + will not be sent and the collection will be created implicitly on first + use. The optional ``session`` argument is *only* used for the ``create`` + command, it is not associated with the collection afterward. + + :param database: the database to get a collection from + :param name: the name of the collection to get + :param create: if ``True``, force collection + creation even without options being set + :param codec_options: An instance of + :class:`~bson.codec_options.CodecOptions`. If ``None`` (the + default) database.codec_options is used. + :param read_preference: The read preference to use. If + ``None`` (the default) database.read_preference is used. + :param write_concern: An instance of + :class:`~pymongo.write_concern.WriteConcern`. If ``None`` (the + default) database.write_concern is used. + :param read_concern: An instance of + :class:`~pymongo.read_concern.ReadConcern`. If ``None`` (the + default) database.read_concern is used. + :param collation: An instance of + :class:`~pymongo.collation.Collation`. If a collation is provided, + it will be passed to the create collection command. + :param session: a + :class:`~pymongo.client_session.ClientSession` that is used with + the create collection command + :param kwargs: additional keyword arguments will + be passed as options for the create collection command + + .. versionchanged:: 4.2 + Added the ``clusteredIndex`` and ``encryptedFields`` parameters. + + .. versionchanged:: 4.0 + Removed the reindex, map_reduce, inline_map_reduce, + parallel_scan, initialize_unordered_bulk_op, + initialize_ordered_bulk_op, group, count, insert, save, + update, remove, find_and_modify, and ensure_index methods. See the + :ref:`pymongo4-migration-guide`. + + .. versionchanged:: 3.6 + Added ``session`` parameter. + + .. versionchanged:: 3.4 + Support the `collation` option. + + .. versionchanged:: 3.2 + Added the read_concern option. + + .. versionchanged:: 3.0 + Added the codec_options, read_preference, and write_concern options. + Removed the uuid_subtype attribute. + :class:`~pymongo.collection.Collection` no longer returns an + instance of :class:`~pymongo.collection.Collection` for attribute + names with leading underscores. You must use dict-style lookups + instead:: + + collection['__my_collection__'] + + Not: + + collection.__my_collection__ + + .. seealso:: The MongoDB documentation on `collections `_. + """ + super().__init__( + codec_options or database.codec_options, + read_preference or database.read_preference, + write_concern or database.write_concern, + read_concern or database.read_concern, + ) + if not isinstance(name, str): + raise TypeError("name must be an instance of str") + + if not name or ".." in name: + raise InvalidName("collection names cannot be empty") + if "$" in name and not (name.startswith(("oplog.$main", "$cmd"))): + raise InvalidName("collection names must not contain '$': %r" % name) + if name[0] == "." 
or name[-1] == ".":
+            raise InvalidName("collection names must not start or end with '.': %r" % name)
+        if "\x00" in name:
+            raise InvalidName("collection names must not contain the null character")
+        collation = validate_collation_or_none(kwargs.pop("collation", None))
+
+        self.__database: Database[_DocumentType] = database
+        self.__name = name
+        self.__full_name = f"{self.__database.name}.{self.__name}"
+        self.__write_response_codec_options = self.codec_options._replace(
+            unicode_decode_error_handler="replace", document_class=dict
+        )
+        self._timeout = database.client.options.timeout
+        encrypted_fields = kwargs.pop("encryptedFields", None)
+        if create or kwargs or collation:
+            if encrypted_fields:
+                common.validate_is_mapping("encrypted_fields", encrypted_fields)
+                opts = {"clusteredIndex": {"key": {"_id": 1}, "unique": True}}
+                self.__create(
+                    _esc_coll_name(encrypted_fields, name), opts, None, session, qev2_required=True
+                )
+                self.__create(_ecoc_coll_name(encrypted_fields, name), opts, None, session)
+                self.__create(name, kwargs, collation, session, encrypted_fields=encrypted_fields)
+                self.create_index([("__safeContent__", ASCENDING)], session)
+            else:
+                self.__create(name, kwargs, collation, session)
+
+    def _conn_for_writes(
+        self, session: Optional[ClientSession], operation: str
+    ) -> ContextManager[Connection]:
+        return self.__database.client._conn_for_writes(session, operation)
+
+    def _command(
+        self,
+        conn: Connection,
+        command: MutableMapping[str, Any],
+        read_preference: Optional[_ServerMode] = None,
+        codec_options: Optional[CodecOptions] = None,
+        check: bool = True,
+        allowable_errors: Optional[Sequence[Union[str, int]]] = None,
+        read_concern: Optional[ReadConcern] = None,
+        write_concern: Optional[WriteConcern] = None,
+        collation: Optional[_CollationIn] = None,
+        session: Optional[ClientSession] = None,
+        retryable_write: bool = False,
+        user_fields: Optional[Any] = None,
+    ) -> Mapping[str, Any]:
+        """Internal command helper.
+
+        :param conn: A Connection instance.
+        :param command: The command itself, as a :class:`~bson.son.SON` instance.
+        :param read_preference: (optional) The read preference to use.
+        :param codec_options: (optional) An instance of
+            :class:`~bson.codec_options.CodecOptions`.
+        :param check: raise OperationFailure if there are errors
+        :param allowable_errors: errors to ignore if `check` is True
+        :param read_concern: (optional) An instance of
+            :class:`~pymongo.read_concern.ReadConcern`.
+        :param write_concern: An instance of
+            :class:`~pymongo.write_concern.WriteConcern`.
+        :param collation: (optional) An instance of
+            :class:`~pymongo.collation.Collation`.
+        :param session: a
+            :class:`~pymongo.client_session.ClientSession`.
+        :param retryable_write: True if this command is a retryable
+            write.
+        :param user_fields: Response fields that should be decoded
+            using the TypeDecoders from codec_options, passed to
+            bson._decode_all_selective.
+
+        :return: The result document.
+ """ + with self.__database.client._tmp_session(session) as s: + return conn.command( + self.__database.name, + command, + read_preference or self._read_preference_for(session), + codec_options or self.codec_options, + check, + allowable_errors, + read_concern=read_concern, + write_concern=write_concern, + parse_write_concern_error=True, + collation=collation, + session=s, + client=self.__database.client, + retryable_write=retryable_write, + user_fields=user_fields, + ) + + def __create( + self, + name: str, + options: MutableMapping[str, Any], + collation: Optional[_CollationIn], + session: Optional[ClientSession], + encrypted_fields: Optional[Mapping[str, Any]] = None, + qev2_required: bool = False, + ) -> None: + """Sends a create command with the given options.""" + cmd: dict[str, Any] = {"create": name} + if encrypted_fields: + cmd["encryptedFields"] = encrypted_fields + + if options: + if "size" in options: + options["size"] = float(options["size"]) + cmd.update(options) + with self._conn_for_writes(session, operation=_Op.CREATE) as conn: + if qev2_required and conn.max_wire_version < 21: + raise ConfigurationError( + "Driver support of Queryable Encryption is incompatible with server. " + "Upgrade server to use Queryable Encryption. " + f"Got maxWireVersion {conn.max_wire_version} but need maxWireVersion >= 21 (MongoDB >=7.0)" + ) + + self._command( + conn, + cmd, + read_preference=ReadPreference.PRIMARY, + write_concern=self._write_concern_for(session), + collation=collation, + session=session, + ) + + def __getattr__(self, name: str) -> Collection[_DocumentType]: + """Get a sub-collection of this collection by name. + + Raises InvalidName if an invalid collection name is used. + + :param name: the name of the collection to get + """ + if name.startswith("_"): + full_name = f"{self.__name}.{name}" + raise AttributeError( + f"Collection has no attribute {name!r}. To access the {full_name}" + f" collection, use database['{full_name}']." + ) + return self.__getitem__(name) + + def __getitem__(self, name: str) -> Collection[_DocumentType]: + return Collection( + self.__database, + f"{self.__name}.{name}", + False, + self.codec_options, + self.read_preference, + self.write_concern, + self.read_concern, + ) + + def __repr__(self) -> str: + return f"Collection({self.__database!r}, {self.__name!r})" + + def __eq__(self, other: Any) -> bool: + if isinstance(other, Collection): + return self.__database == other.database and self.__name == other.name + return NotImplemented + + def __ne__(self, other: Any) -> bool: + return not self == other + + def __hash__(self) -> int: + return hash((self.__database, self.__name)) + + def __bool__(self) -> NoReturn: + raise NotImplementedError( + "Collection objects do not implement truth " + "value testing or bool(). Please compare " + "with None instead: collection is not None" + ) + + @property + def full_name(self) -> str: + """The full name of this :class:`Collection`. + + The full name is of the form `database_name.collection_name`. + """ + return self.__full_name + + @property + def name(self) -> str: + """The name of this :class:`Collection`.""" + return self.__name + + @property + def database(self) -> Database[_DocumentType]: + """The :class:`~pymongo.database.Database` that this + :class:`Collection` is a part of. 
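+
+        For example (the ``db`` variable is illustrative):
+
+        >>> db.test.database is db
+        True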
+ """ + return self.__database + + def with_options( + self, + codec_options: Optional[CodecOptions[_DocumentTypeArg]] = None, + read_preference: Optional[_ServerMode] = None, + write_concern: Optional[WriteConcern] = None, + read_concern: Optional[ReadConcern] = None, + ) -> Collection[_DocumentType]: + """Get a clone of this collection changing the specified settings. + + >>> coll1.read_preference + Primary() + >>> from pymongo import ReadPreference + >>> coll2 = coll1.with_options(read_preference=ReadPreference.SECONDARY) + >>> coll1.read_preference + Primary() + >>> coll2.read_preference + Secondary(tag_sets=None) + + :param codec_options: An instance of + :class:`~bson.codec_options.CodecOptions`. If ``None`` (the + default) the :attr:`codec_options` of this :class:`Collection` + is used. + :param read_preference: The read preference to use. If + ``None`` (the default) the :attr:`read_preference` of this + :class:`Collection` is used. See :mod:`~pymongo.read_preferences` + for options. + :param write_concern: An instance of + :class:`~pymongo.write_concern.WriteConcern`. If ``None`` (the + default) the :attr:`write_concern` of this :class:`Collection` + is used. + :param read_concern: An instance of + :class:`~pymongo.read_concern.ReadConcern`. If ``None`` (the + default) the :attr:`read_concern` of this :class:`Collection` + is used. + """ + return Collection( + self.__database, + self.__name, + False, + codec_options or self.codec_options, + read_preference or self.read_preference, + write_concern or self.write_concern, + read_concern or self.read_concern, + ) + + @_csot.apply + def bulk_write( + self, + requests: Sequence[_WriteOp[_DocumentType]], + ordered: bool = True, + bypass_document_validation: bool = False, + session: Optional[ClientSession] = None, + comment: Optional[Any] = None, + let: Optional[Mapping] = None, + ) -> BulkWriteResult: + """Send a batch of write operations to the server. + + Requests are passed as a list of write operation instances ( + :class:`~pymongo.operations.InsertOne`, + :class:`~pymongo.operations.UpdateOne`, + :class:`~pymongo.operations.UpdateMany`, + :class:`~pymongo.operations.ReplaceOne`, + :class:`~pymongo.operations.DeleteOne`, or + :class:`~pymongo.operations.DeleteMany`). + + >>> for doc in db.test.find({}): + ... print(doc) + ... + {'x': 1, '_id': ObjectId('54f62e60fba5226811f634ef')} + {'x': 1, '_id': ObjectId('54f62e60fba5226811f634f0')} + >>> # DeleteMany, UpdateOne, and UpdateMany are also available. + ... + >>> from pymongo import InsertOne, DeleteOne, ReplaceOne + >>> requests = [InsertOne({'y': 1}), DeleteOne({'x': 1}), + ... ReplaceOne({'w': 1}, {'z': 1}, upsert=True)] + >>> result = db.test.bulk_write(requests) + >>> result.inserted_count + 1 + >>> result.deleted_count + 1 + >>> result.modified_count + 0 + >>> result.upserted_ids + {2: ObjectId('54f62ee28891e756a6e1abd5')} + >>> for doc in db.test.find({}): + ... print(doc) + ... + {'x': 1, '_id': ObjectId('54f62e60fba5226811f634f0')} + {'y': 1, '_id': ObjectId('54f62ee2fba5226811f634f1')} + {'z': 1, '_id': ObjectId('54f62ee28891e756a6e1abd5')} + + :param requests: A list of write operations (see examples above). + :param ordered: If ``True`` (the default) requests will be + performed on the server serially, in the order provided. If an error + occurs all remaining operations are aborted. If ``False`` requests + will be performed on the server in arbitrary order, possibly in + parallel, and all operations will be attempted. 
+ :param bypass_document_validation: (optional) If ``True``, allows the + write to opt-out of document level validation. Default is + ``False``. + :param session: a + :class:`~pymongo.client_session.ClientSession`. + :param comment: A user-provided comment to attach to this + command. + :param let: Map of parameter names and values. Values must be + constant or closed expressions that do not reference document + fields. Parameters can then be accessed as variables in an + aggregate expression context (e.g. "$$var"). + + :return: An instance of :class:`~pymongo.results.BulkWriteResult`. + + .. seealso:: :ref:`writes-and-ids` + + .. note:: `bypass_document_validation` requires server version + **>= 3.2** + + .. versionchanged:: 4.1 + Added ``comment`` parameter. + Added ``let`` parameter. + + .. versionchanged:: 3.6 + Added ``session`` parameter. + + .. versionchanged:: 3.2 + Added bypass_document_validation support + + .. versionadded:: 3.0 + """ + common.validate_list("requests", requests) + + blk = _Bulk(self, ordered, bypass_document_validation, comment=comment, let=let) + for request in requests: + try: + request._add_to_bulk(blk) + except AttributeError: + raise TypeError(f"{request!r} is not a valid request") from None + + write_concern = self._write_concern_for(session) + bulk_api_result = blk.execute(write_concern, session, _Op.INSERT) + if bulk_api_result is not None: + return BulkWriteResult(bulk_api_result, True) + return BulkWriteResult({}, False) + + def _insert_one( + self, + doc: Mapping[str, Any], + ordered: bool, + write_concern: WriteConcern, + op_id: Optional[int], + bypass_doc_val: bool, + session: Optional[ClientSession], + comment: Optional[Any] = None, + ) -> Any: + """Internal helper for inserting a single document.""" + write_concern = write_concern or self.write_concern + acknowledged = write_concern.acknowledged + command = {"insert": self.name, "ordered": ordered, "documents": [doc]} + if comment is not None: + command["comment"] = comment + + def _insert_command( + session: Optional[ClientSession], conn: Connection, retryable_write: bool + ) -> None: + if bypass_doc_val: + command["bypassDocumentValidation"] = True + + result = conn.command( + self.__database.name, + command, + write_concern=write_concern, + codec_options=self.__write_response_codec_options, + session=session, + client=self.__database.client, + retryable_write=retryable_write, + ) + + _check_write_command_response(result) + + self.__database.client._retryable_write( + acknowledged, _insert_command, session, operation=_Op.INSERT + ) + + if not isinstance(doc, RawBSONDocument): + return doc.get("_id") + return None + + def insert_one( + self, + document: Union[_DocumentType, RawBSONDocument], + bypass_document_validation: bool = False, + session: Optional[ClientSession] = None, + comment: Optional[Any] = None, + ) -> InsertOneResult: + """Insert a single document. + + >>> db.test.count_documents({'x': 1}) + 0 + >>> result = db.test.insert_one({'x': 1}) + >>> result.inserted_id + ObjectId('54f112defba522406c9cc208') + >>> db.test.find_one({'x': 1}) + {'x': 1, '_id': ObjectId('54f112defba522406c9cc208')} + + :param document: The document to insert. Must be a mutable mapping + type. If the document does not have an _id field one will be + added automatically. + :param bypass_document_validation: (optional) If ``True``, allows the + write to opt-out of document level validation. Default is + ``False``. + :param session: a + :class:`~pymongo.client_session.ClientSession`. 
+        :param comment: A user-provided comment to attach to this
+            command.
+
+        :return: An instance of :class:`~pymongo.results.InsertOneResult`.
+
+        .. seealso:: :ref:`writes-and-ids`
+
+        .. note:: `bypass_document_validation` requires server version
+          **>= 3.2**
+
+        .. versionchanged:: 4.1
+           Added ``comment`` parameter.
+
+        .. versionchanged:: 3.6
+           Added ``session`` parameter.
+
+        .. versionchanged:: 3.2
+           Added bypass_document_validation support
+
+        .. versionadded:: 3.0
+        """
+        common.validate_is_document_type("document", document)
+        if not (isinstance(document, RawBSONDocument) or "_id" in document):
+            document["_id"] = ObjectId()  # type: ignore[index]
+
+        write_concern = self._write_concern_for(session)
+        return InsertOneResult(
+            self._insert_one(
+                document,
+                ordered=True,
+                write_concern=write_concern,
+                op_id=None,
+                bypass_doc_val=bypass_document_validation,
+                session=session,
+                comment=comment,
+            ),
+            write_concern.acknowledged,
+        )
+
+    @_csot.apply
+    def insert_many(
+        self,
+        documents: Iterable[Union[_DocumentType, RawBSONDocument]],
+        ordered: bool = True,
+        bypass_document_validation: bool = False,
+        session: Optional[ClientSession] = None,
+        comment: Optional[Any] = None,
+    ) -> InsertManyResult:
+        """Insert an iterable of documents.
+
+        >>> db.test.count_documents({})
+        0
+        >>> result = db.test.insert_many([{'x': i} for i in range(2)])
+        >>> result.inserted_ids
+        [ObjectId('54f113fffba522406c9cc20e'), ObjectId('54f113fffba522406c9cc20f')]
+        >>> db.test.count_documents({})
+        2
+
+        :param documents: An iterable of documents to insert.
+        :param ordered: If ``True`` (the default) documents will be
+            inserted on the server serially, in the order provided. If an error
+            occurs all remaining inserts are aborted. If ``False``, documents
+            will be inserted on the server in arbitrary order, possibly in
+            parallel, and all document inserts will be attempted.
+        :param bypass_document_validation: (optional) If ``True``, allows the
+            write to opt-out of document level validation. Default is
+            ``False``.
+        :param session: a
+            :class:`~pymongo.client_session.ClientSession`.
+        :param comment: A user-provided comment to attach to this
+            command.
+
+        :return: An instance of :class:`~pymongo.results.InsertManyResult`.
+
+        .. seealso:: :ref:`writes-and-ids`
+
+        .. note:: `bypass_document_validation` requires server version
+          **>= 3.2**
+
+        .. versionchanged:: 4.1
+           Added ``comment`` parameter.
+
+        .. versionchanged:: 3.6
+           Added ``session`` parameter.
+
+        .. versionchanged:: 3.2
+           Added bypass_document_validation support
+
+        ..
versionadded:: 3.0 + """ + if ( + not isinstance(documents, abc.Iterable) + or isinstance(documents, abc.Mapping) + or not documents + ): + raise TypeError("documents must be a non-empty list") + inserted_ids: list[ObjectId] = [] + + def gen() -> Iterator[tuple[int, Mapping[str, Any]]]: + """A generator that validates documents and handles _ids.""" + for document in documents: + common.validate_is_document_type("document", document) + if not isinstance(document, RawBSONDocument): + if "_id" not in document: + document["_id"] = ObjectId() # type: ignore[index] + inserted_ids.append(document["_id"]) + yield (message._INSERT, document) + + write_concern = self._write_concern_for(session) + blk = _Bulk(self, ordered, bypass_document_validation, comment=comment) + blk.ops = list(gen()) + blk.execute(write_concern, session, _Op.INSERT) + return InsertManyResult(inserted_ids, write_concern.acknowledged) + + def _update( + self, + conn: Connection, + criteria: Mapping[str, Any], + document: Union[Mapping[str, Any], _Pipeline], + upsert: bool = False, + multi: bool = False, + write_concern: Optional[WriteConcern] = None, + op_id: Optional[int] = None, + ordered: bool = True, + bypass_doc_val: Optional[bool] = False, + collation: Optional[_CollationIn] = None, + array_filters: Optional[Sequence[Mapping[str, Any]]] = None, + hint: Optional[_IndexKeyHint] = None, + session: Optional[ClientSession] = None, + retryable_write: bool = False, + let: Optional[Mapping[str, Any]] = None, + comment: Optional[Any] = None, + ) -> Optional[Mapping[str, Any]]: + """Internal update / replace helper.""" + validate_boolean("upsert", upsert) + collation = validate_collation_or_none(collation) + write_concern = write_concern or self.write_concern + acknowledged = write_concern.acknowledged + update_doc: dict[str, Any] = { + "q": criteria, + "u": document, + "multi": multi, + "upsert": upsert, + } + if collation is not None: + if not acknowledged: + raise ConfigurationError("Collation is unsupported for unacknowledged writes.") + else: + update_doc["collation"] = collation + if array_filters is not None: + if not acknowledged: + raise ConfigurationError("arrayFilters is unsupported for unacknowledged writes.") + else: + update_doc["arrayFilters"] = array_filters + if hint is not None: + if not acknowledged and conn.max_wire_version < 8: + raise ConfigurationError( + "Must be connected to MongoDB 4.2+ to use hint on unacknowledged update commands." + ) + if not isinstance(hint, str): + hint = helpers._index_document(hint) + update_doc["hint"] = hint + command = {"update": self.name, "ordered": ordered, "updates": [update_doc]} + if let is not None: + common.validate_is_mapping("let", let) + command["let"] = let + + if comment is not None: + command["comment"] = comment + # Update command. + if bypass_doc_val: + command["bypassDocumentValidation"] = True + + # The command result has to be published for APM unmodified + # so we make a shallow copy here before adding updatedExisting. + result = conn.command( + self.__database.name, + command, + write_concern=write_concern, + codec_options=self.__write_response_codec_options, + session=session, + client=self.__database.client, + retryable_write=retryable_write, + ).copy() + _check_write_command_response(result) + # Add the updatedExisting field for compatibility. + if result.get("n") and "upserted" not in result: + result["updatedExisting"] = True + else: + result["updatedExisting"] = False + # MongoDB >= 2.6.0 returns the upsert _id in an array + # element. 
Break it out for backward compatibility. + if "upserted" in result: + result["upserted"] = result["upserted"][0]["_id"] + + if not acknowledged: + return None + return result + + def _update_retryable( + self, + criteria: Mapping[str, Any], + document: Union[Mapping[str, Any], _Pipeline], + operation: str, + upsert: bool = False, + multi: bool = False, + write_concern: Optional[WriteConcern] = None, + op_id: Optional[int] = None, + ordered: bool = True, + bypass_doc_val: Optional[bool] = False, + collation: Optional[_CollationIn] = None, + array_filters: Optional[Sequence[Mapping[str, Any]]] = None, + hint: Optional[_IndexKeyHint] = None, + session: Optional[ClientSession] = None, + let: Optional[Mapping[str, Any]] = None, + comment: Optional[Any] = None, + ) -> Optional[Mapping[str, Any]]: + """Internal update / replace helper.""" + + def _update( + session: Optional[ClientSession], conn: Connection, retryable_write: bool + ) -> Optional[Mapping[str, Any]]: + return self._update( + conn, + criteria, + document, + upsert=upsert, + multi=multi, + write_concern=write_concern, + op_id=op_id, + ordered=ordered, + bypass_doc_val=bypass_doc_val, + collation=collation, + array_filters=array_filters, + hint=hint, + session=session, + retryable_write=retryable_write, + let=let, + comment=comment, + ) + + return self.__database.client._retryable_write( + (write_concern or self.write_concern).acknowledged and not multi, + _update, + session, + operation, + ) + + def replace_one( + self, + filter: Mapping[str, Any], + replacement: Mapping[str, Any], + upsert: bool = False, + bypass_document_validation: bool = False, + collation: Optional[_CollationIn] = None, + hint: Optional[_IndexKeyHint] = None, + session: Optional[ClientSession] = None, + let: Optional[Mapping[str, Any]] = None, + comment: Optional[Any] = None, + ) -> UpdateResult: + """Replace a single document matching the filter. + + >>> for doc in db.test.find({}): + ... print(doc) + ... + {'x': 1, '_id': ObjectId('54f4c5befba5220aa4d6dee7')} + >>> result = db.test.replace_one({'x': 1}, {'y': 1}) + >>> result.matched_count + 1 + >>> result.modified_count + 1 + >>> for doc in db.test.find({}): + ... print(doc) + ... + {'y': 1, '_id': ObjectId('54f4c5befba5220aa4d6dee7')} + + The *upsert* option can be used to insert a new document if a matching + document does not exist. + + >>> result = db.test.replace_one({'x': 1}, {'x': 1}, True) + >>> result.matched_count + 0 + >>> result.modified_count + 0 + >>> result.upserted_id + ObjectId('54f11e5c8891e756a6e1abd4') + >>> db.test.find_one({'x': 1}) + {'x': 1, '_id': ObjectId('54f11e5c8891e756a6e1abd4')} + + :param filter: A query that matches the document to replace. + :param replacement: The new document. + :param upsert: If ``True``, perform an insert if no documents + match the filter. + :param bypass_document_validation: (optional) If ``True``, allows the + write to opt-out of document level validation. Default is + ``False``. + :param collation: An instance of + :class:`~pymongo.collation.Collation`. + :param hint: An index to use to support the query + predicate specified either by its string name, or in the same + format as passed to + :meth:`~pymongo.collection.Collection.create_index` (e.g. + ``[('field', ASCENDING)]``). This option is only supported on + MongoDB 4.2 and above. + :param session: a + :class:`~pymongo.client_session.ClientSession`. + :param let: Map of parameter names and values. Values must be + constant or closed expressions that do not reference document + fields. 
Parameters can then be accessed as variables in an
+            aggregate expression context (e.g. "$$var").
+        :param comment: A user-provided comment to attach to this
+            command.
+        :return: An instance of :class:`~pymongo.results.UpdateResult`.
+
+        .. versionchanged:: 4.1
+           Added ``let`` parameter.
+           Added ``comment`` parameter.
+        .. versionchanged:: 3.11
+           Added ``hint`` parameter.
+        .. versionchanged:: 3.6
+           Added ``session`` parameter.
+        .. versionchanged:: 3.4
+           Added the `collation` option.
+        .. versionchanged:: 3.2
+           Added bypass_document_validation support.
+
+        .. versionadded:: 3.0
+        """
+        common.validate_is_mapping("filter", filter)
+        common.validate_ok_for_replace(replacement)
+        if let is not None:
+            common.validate_is_mapping("let", let)
+        write_concern = self._write_concern_for(session)
+        return UpdateResult(
+            self._update_retryable(
+                filter,
+                replacement,
+                _Op.UPDATE,
+                upsert,
+                write_concern=write_concern,
+                bypass_doc_val=bypass_document_validation,
+                collation=collation,
+                hint=hint,
+                session=session,
+                let=let,
+                comment=comment,
+            ),
+            write_concern.acknowledged,
+        )
+
+    def update_one(
+        self,
+        filter: Mapping[str, Any],
+        update: Union[Mapping[str, Any], _Pipeline],
+        upsert: bool = False,
+        bypass_document_validation: bool = False,
+        collation: Optional[_CollationIn] = None,
+        array_filters: Optional[Sequence[Mapping[str, Any]]] = None,
+        hint: Optional[_IndexKeyHint] = None,
+        session: Optional[ClientSession] = None,
+        let: Optional[Mapping[str, Any]] = None,
+        comment: Optional[Any] = None,
+    ) -> UpdateResult:
+        """Update a single document matching the filter.
+
+        >>> for doc in db.test.find():
+        ...     print(doc)
+        ...
+        {'x': 1, '_id': 0}
+        {'x': 1, '_id': 1}
+        {'x': 1, '_id': 2}
+        >>> result = db.test.update_one({'x': 1}, {'$inc': {'x': 3}})
+        >>> result.matched_count
+        1
+        >>> result.modified_count
+        1
+        >>> for doc in db.test.find():
+        ...     print(doc)
+        ...
+        {'x': 4, '_id': 0}
+        {'x': 1, '_id': 1}
+        {'x': 1, '_id': 2}
+
+        If ``upsert=True`` and no documents match the filter, create a
+        new document based on the filter criteria and update modifications.
+
+        >>> result = db.test.update_one({'x': -10}, {'$inc': {'x': 3}}, upsert=True)
+        >>> result.matched_count
+        0
+        >>> result.modified_count
+        0
+        >>> result.upserted_id
+        ObjectId('626a678eeaa80587d4bb3fb7')
+        >>> db.test.find_one(result.upserted_id)
+        {'_id': ObjectId('626a678eeaa80587d4bb3fb7'), 'x': -7}
+
+        :param filter: A query that matches the document to update.
+        :param update: The modifications to apply.
+        :param upsert: If ``True``, perform an insert if no documents
+            match the filter.
+        :param bypass_document_validation: (optional) If ``True``, allows the
+            write to opt-out of document level validation. Default is
+            ``False``.
+        :param collation: An instance of
+            :class:`~pymongo.collation.Collation`.
+        :param array_filters: A list of filters specifying which
+            array elements an update should apply.
+        :param hint: An index to use to support the query
+            predicate specified either by its string name, or in the same
+            format as passed to
+            :meth:`~pymongo.collection.Collection.create_index` (e.g.
+            ``[('field', ASCENDING)]``). This option is only supported on
+            MongoDB 4.2 and above.
+        :param session: a
+            :class:`~pymongo.client_session.ClientSession`.
+        :param let: Map of parameter names and values. Values must be
+            constant or closed expressions that do not reference document
+            fields. Parameters can then be accessed as variables in an
+            aggregate expression context (e.g. "$$var").
+        :param comment: A user-provided comment to attach to this
+            command.
+
+        :return: An instance of :class:`~pymongo.results.UpdateResult`.
+
+        .. versionchanged:: 4.1
+           Added ``let`` parameter.
+           Added ``comment`` parameter.
+        .. versionchanged:: 3.11
+           Added ``hint`` parameter.
+        .. versionchanged:: 3.9
+           Added the ability to accept a pipeline as the ``update``.
+        .. versionchanged:: 3.6
+           Added the ``array_filters`` and ``session`` parameters.
+        .. versionchanged:: 3.4
+           Added the ``collation`` option.
+        .. versionchanged:: 3.2
+           Added ``bypass_document_validation`` support.
+
+        .. versionadded:: 3.0
+        """
+        common.validate_is_mapping("filter", filter)
+        common.validate_ok_for_update(update)
+        common.validate_list_or_none("array_filters", array_filters)
+
+        write_concern = self._write_concern_for(session)
+        return UpdateResult(
+            self._update_retryable(
+                filter,
+                update,
+                _Op.UPDATE,
+                upsert,
+                write_concern=write_concern,
+                bypass_doc_val=bypass_document_validation,
+                collation=collation,
+                array_filters=array_filters,
+                hint=hint,
+                session=session,
+                let=let,
+                comment=comment,
+            ),
+            write_concern.acknowledged,
+        )
+
+    def update_many(
+        self,
+        filter: Mapping[str, Any],
+        update: Union[Mapping[str, Any], _Pipeline],
+        upsert: bool = False,
+        array_filters: Optional[Sequence[Mapping[str, Any]]] = None,
+        bypass_document_validation: Optional[bool] = None,
+        collation: Optional[_CollationIn] = None,
+        hint: Optional[_IndexKeyHint] = None,
+        session: Optional[ClientSession] = None,
+        let: Optional[Mapping[str, Any]] = None,
+        comment: Optional[Any] = None,
+    ) -> UpdateResult:
+        """Update one or more documents that match the filter.
+
+        >>> for doc in db.test.find():
+        ...     print(doc)
+        ...
+        {'x': 1, '_id': 0}
+        {'x': 1, '_id': 1}
+        {'x': 1, '_id': 2}
+        >>> result = db.test.update_many({'x': 1}, {'$inc': {'x': 3}})
+        >>> result.matched_count
+        3
+        >>> result.modified_count
+        3
+        >>> for doc in db.test.find():
+        ...     print(doc)
+        ...
+        {'x': 4, '_id': 0}
+        {'x': 4, '_id': 1}
+        {'x': 4, '_id': 2}
+
+        :param filter: A query that matches the documents to update.
+        :param update: The modifications to apply.
+        :param upsert: If ``True``, perform an insert if no documents
+            match the filter.
+        :param bypass_document_validation: If ``True``, allows the
+            write to opt-out of document level validation. Default is
+            ``False``.
+        :param collation: An instance of
+            :class:`~pymongo.collation.Collation`.
+        :param array_filters: A list of filters specifying which
+            array elements an update should apply.
+        :param hint: An index to use to support the query
+            predicate specified either by its string name, or in the same
+            format as passed to
+            :meth:`~pymongo.collection.Collection.create_index` (e.g.
+            ``[('field', ASCENDING)]``). This option is only supported on
+            MongoDB 4.2 and above.
+        :param session: a
+            :class:`~pymongo.client_session.ClientSession`.
+        :param let: Map of parameter names and values. Values must be
+            constant or closed expressions that do not reference document
+            fields. Parameters can then be accessed as variables in an
+            aggregate expression context (e.g. "$$var").
+        :param comment: A user-provided comment to attach to this
+            command.
+
+        :return: An instance of :class:`~pymongo.results.UpdateResult`.
+
+        .. versionchanged:: 4.1
+           Added ``let`` parameter.
+           Added ``comment`` parameter.
+        .. versionchanged:: 3.11
+           Added ``hint`` parameter.
+        .. versionchanged:: 3.9
+           Added the ability to accept a pipeline as the `update`.
+        ..
versionchanged:: 3.6 + Added ``array_filters`` and ``session`` parameters. + .. versionchanged:: 3.4 + Added the `collation` option. + .. versionchanged:: 3.2 + Added bypass_document_validation support. + + .. versionadded:: 3.0 + """ + common.validate_is_mapping("filter", filter) + common.validate_ok_for_update(update) + common.validate_list_or_none("array_filters", array_filters) + + write_concern = self._write_concern_for(session) + return UpdateResult( + self._update_retryable( + filter, + update, + _Op.UPDATE, + upsert, + multi=True, + write_concern=write_concern, + bypass_doc_val=bypass_document_validation, + collation=collation, + array_filters=array_filters, + hint=hint, + session=session, + let=let, + comment=comment, + ), + write_concern.acknowledged, + ) + + def drop( + self, + session: Optional[ClientSession] = None, + comment: Optional[Any] = None, + encrypted_fields: Optional[Mapping[str, Any]] = None, + ) -> None: + """Alias for :meth:`~pymongo.database.Database.drop_collection`. + + :param session: a + :class:`~pymongo.client_session.ClientSession`. + :param comment: A user-provided comment to attach to this + command. + :param encrypted_fields: **(BETA)** Document that describes the encrypted fields for + Queryable Encryption. + + The following two calls are equivalent: + + >>> db.foo.drop() + >>> db.drop_collection("foo") + + .. versionchanged:: 4.2 + Added ``encrypted_fields`` parameter. + + .. versionchanged:: 4.1 + Added ``comment`` parameter. + + .. versionchanged:: 3.7 + :meth:`drop` now respects this :class:`Collection`'s :attr:`write_concern`. + + .. versionchanged:: 3.6 + Added ``session`` parameter. + """ + dbo = self.__database.client.get_database( + self.__database.name, + self.codec_options, + self.read_preference, + self.write_concern, + self.read_concern, + ) + dbo.drop_collection( + self.__name, session=session, comment=comment, encrypted_fields=encrypted_fields + ) + + def _delete( + self, + conn: Connection, + criteria: Mapping[str, Any], + multi: bool, + write_concern: Optional[WriteConcern] = None, + op_id: Optional[int] = None, + ordered: bool = True, + collation: Optional[_CollationIn] = None, + hint: Optional[_IndexKeyHint] = None, + session: Optional[ClientSession] = None, + retryable_write: bool = False, + let: Optional[Mapping[str, Any]] = None, + comment: Optional[Any] = None, + ) -> Mapping[str, Any]: + """Internal delete helper.""" + common.validate_is_mapping("filter", criteria) + write_concern = write_concern or self.write_concern + acknowledged = write_concern.acknowledged + delete_doc = {"q": criteria, "limit": int(not multi)} + collation = validate_collation_or_none(collation) + if collation is not None: + if not acknowledged: + raise ConfigurationError("Collation is unsupported for unacknowledged writes.") + else: + delete_doc["collation"] = collation + if hint is not None: + if not acknowledged and conn.max_wire_version < 9: + raise ConfigurationError( + "Must be connected to MongoDB 4.4+ to use hint on unacknowledged delete commands." + ) + if not isinstance(hint, str): + hint = helpers._index_document(hint) + delete_doc["hint"] = hint + command = {"delete": self.name, "ordered": ordered, "deletes": [delete_doc]} + + if let is not None: + common.validate_is_document_type("let", let) + command["let"] = let + + if comment is not None: + command["comment"] = comment + + # Delete command. 
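+        # At this point the assembled command document is shaped roughly
+        # like the following (values are illustrative only):
+        #     {"delete": "test", "ordered": True,
+        #      "deletes": [{"q": {"x": 1}, "limit": 1}]}
+        # "limit": 1 removes at most one matching document, while
+        # "limit": 0 (the multi case above) removes every match.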
+        result = conn.command(
+            self.__database.name,
+            command,
+            write_concern=write_concern,
+            codec_options=self.__write_response_codec_options,
+            session=session,
+            client=self.__database.client,
+            retryable_write=retryable_write,
+        )
+        _check_write_command_response(result)
+        return result
+
+    def _delete_retryable(
+        self,
+        criteria: Mapping[str, Any],
+        multi: bool,
+        write_concern: Optional[WriteConcern] = None,
+        op_id: Optional[int] = None,
+        ordered: bool = True,
+        collation: Optional[_CollationIn] = None,
+        hint: Optional[_IndexKeyHint] = None,
+        session: Optional[ClientSession] = None,
+        let: Optional[Mapping[str, Any]] = None,
+        comment: Optional[Any] = None,
+    ) -> Mapping[str, Any]:
+        """Internal delete helper."""
+
+        def _delete(
+            session: Optional[ClientSession], conn: Connection, retryable_write: bool
+        ) -> Mapping[str, Any]:
+            return self._delete(
+                conn,
+                criteria,
+                multi,
+                write_concern=write_concern,
+                op_id=op_id,
+                ordered=ordered,
+                collation=collation,
+                hint=hint,
+                session=session,
+                retryable_write=retryable_write,
+                let=let,
+                comment=comment,
+            )
+
+        return self.__database.client._retryable_write(
+            (write_concern or self.write_concern).acknowledged and not multi,
+            _delete,
+            session,
+            operation=_Op.DELETE,
+        )
+
+    def delete_one(
+        self,
+        filter: Mapping[str, Any],
+        collation: Optional[_CollationIn] = None,
+        hint: Optional[_IndexKeyHint] = None,
+        session: Optional[ClientSession] = None,
+        let: Optional[Mapping[str, Any]] = None,
+        comment: Optional[Any] = None,
+    ) -> DeleteResult:
+        """Delete a single document matching the filter.
+
+        >>> db.test.count_documents({'x': 1})
+        3
+        >>> result = db.test.delete_one({'x': 1})
+        >>> result.deleted_count
+        1
+        >>> db.test.count_documents({'x': 1})
+        2
+
+        :param filter: A query that matches the document to delete.
+        :param collation: An instance of
+            :class:`~pymongo.collation.Collation`.
+        :param hint: An index to use to support the query
+            predicate specified either by its string name, or in the same
+            format as passed to
+            :meth:`~pymongo.collection.Collection.create_index` (e.g.
+            ``[('field', ASCENDING)]``). This option is only supported on
+            MongoDB 4.4 and above.
+        :param session: a
+            :class:`~pymongo.client_session.ClientSession`.
+        :param let: Map of parameter names and values. Values must be
+            constant or closed expressions that do not reference document
+            fields. Parameters can then be accessed as variables in an
+            aggregate expression context (e.g. "$$var").
+        :param comment: A user-provided comment to attach to this
+            command.
+
+        :return: An instance of :class:`~pymongo.results.DeleteResult`.
+
+        .. versionchanged:: 4.1
+           Added ``let`` parameter.
+           Added ``comment`` parameter.
+        .. versionchanged:: 3.11
+           Added ``hint`` parameter.
+        .. versionchanged:: 3.6
+           Added ``session`` parameter.
+        .. versionchanged:: 3.4
+           Added the `collation` option.
+        .. versionadded:: 3.0
+        """
+        write_concern = self._write_concern_for(session)
+        return DeleteResult(
+            self._delete_retryable(
+                filter,
+                False,
+                write_concern=write_concern,
+                collation=collation,
+                hint=hint,
+                session=session,
+                let=let,
+                comment=comment,
+            ),
+            write_concern.acknowledged,
+        )
+
+    def delete_many(
+        self,
+        filter: Mapping[str, Any],
+        collation: Optional[_CollationIn] = None,
+        hint: Optional[_IndexKeyHint] = None,
+        session: Optional[ClientSession] = None,
+        let: Optional[Mapping[str, Any]] = None,
+        comment: Optional[Any] = None,
+    ) -> DeleteResult:
+        """Delete one or more documents matching the filter.
+
+        >>> db.test.count_documents({'x': 1})
+        3
+        >>> result = db.test.delete_many({'x': 1})
+        >>> result.deleted_count
+        3
+        >>> db.test.count_documents({'x': 1})
+        0
+
+        :param filter: A query that matches the documents to delete.
+        :param collation: An instance of
+            :class:`~pymongo.collation.Collation`.
+        :param hint: An index to use to support the query
+            predicate specified either by its string name, or in the same
+            format as passed to
+            :meth:`~pymongo.collection.Collection.create_index` (e.g.
+            ``[('field', ASCENDING)]``). This option is only supported on
+            MongoDB 4.4 and above.
+        :param session: a
+            :class:`~pymongo.client_session.ClientSession`.
+        :param let: Map of parameter names and values. Values must be
+            constant or closed expressions that do not reference document
+            fields. Parameters can then be accessed as variables in an
+            aggregate expression context (e.g. "$$var").
+        :param comment: A user-provided comment to attach to this
+            command.
+
+        :return: An instance of :class:`~pymongo.results.DeleteResult`.
+
+        .. versionchanged:: 4.1
+           Added ``let`` parameter.
+           Added ``comment`` parameter.
+        .. versionchanged:: 3.11
+           Added ``hint`` parameter.
+        .. versionchanged:: 3.6
+           Added ``session`` parameter.
+        .. versionchanged:: 3.4
+           Added the `collation` option.
+        .. versionadded:: 3.0
+        """
+        write_concern = self._write_concern_for(session)
+        return DeleteResult(
+            self._delete_retryable(
+                filter,
+                True,
+                write_concern=write_concern,
+                collation=collation,
+                hint=hint,
+                session=session,
+                let=let,
+                comment=comment,
+            ),
+            write_concern.acknowledged,
+        )
+
+    def find_one(
+        self, filter: Optional[Any] = None, *args: Any, **kwargs: Any
+    ) -> Optional[_DocumentType]:
+        """Get a single document from the database.
+
+        All arguments to :meth:`find` are also valid arguments for
+        :meth:`find_one`, although any `limit` argument will be
+        ignored. Returns a single document, or ``None`` if no matching
+        document is found.
+
+        The :meth:`find_one` method obeys the :attr:`read_preference` of
+        this :class:`Collection`.
+
+        :param filter: a dictionary specifying
+            the query to be performed OR any other type to be used as
+            the value for a query for ``"_id"``.
+
+        :param args: any additional positional arguments
+            are the same as the arguments to :meth:`find`.
+
+        :param kwargs: any additional keyword arguments
+            are the same as the arguments to :meth:`find`.
+
+        .. code-block:: python
+
+            >>> collection.find_one(max_time_ms=100)
+
+        """
+        if filter is not None and not isinstance(filter, abc.Mapping):
+            filter = {"_id": filter}
+        cursor = self.find(filter, *args, **kwargs)
+        for result in cursor.limit(-1):
+            return result
+        return None
+
+    def find(self, *args: Any, **kwargs: Any) -> Cursor[_DocumentType]:
+        """Query the database.
+
+        The `filter` argument is a query document that all results
+        must match. For example:
+
+        >>> db.test.find({"hello": "world"})
+
+        only matches documents that have a key "hello" with value
+        "world". Matches can have other keys *in addition* to
+        "hello". The `projection` argument is used to specify a subset
+        of fields that should be included in the result documents. By
+        limiting results to a certain subset of fields you can cut
+        down on network traffic and decoding time.
+
+        Raises :class:`TypeError` if any of the arguments are of
+        improper type. Returns an instance of
+        :class:`~pymongo.cursor.Cursor` corresponding to this query.
+
+        The :meth:`find` method obeys the :attr:`read_preference` of
+        this :class:`Collection`.
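+
+        For example, a query that combines a filter with a projection,
+        sort, and limit might look like this (the collection and field
+        names are illustrative only):
+
+        >>> from pymongo import ASCENDING
+        >>> cursor = db.test.find(
+        ...     {"status": "active"},
+        ...     projection={"status": True},
+        ...     sort=[("status", ASCENDING)],
+        ...     limit=10)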
+ + :param filter: A query document that selects which documents + to include in the result set. Can be an empty document to include + all documents. + :param projection: a list of field names that should be + returned in the result set or a dict specifying the fields + to include or exclude. If `projection` is a list "_id" will + always be returned. Use a dict to exclude fields from + the result (e.g. projection={'_id': False}). + :param session: a + :class:`~pymongo.client_session.ClientSession`. + :param skip: the number of documents to omit (from + the start of the result set) when returning the results + :param limit: the maximum number of results to + return. A limit of 0 (the default) is equivalent to setting no + limit. + :param no_cursor_timeout: if False (the default), any + returned cursor is closed by the server after 10 minutes of + inactivity. If set to True, the returned cursor will never + time out on the server. Care should be taken to ensure that + cursors with no_cursor_timeout turned on are properly closed. + :param cursor_type: the type of cursor to return. The valid + options are defined by :class:`~pymongo.cursor.CursorType`: + + - :attr:`~pymongo.cursor.CursorType.NON_TAILABLE` - the result of + this find call will return a standard cursor over the result set. + - :attr:`~pymongo.cursor.CursorType.TAILABLE` - the result of this + find call will be a tailable cursor - tailable cursors are only + for use with capped collections. They are not closed when the + last data is retrieved but are kept open and the cursor location + marks the final document position. If more data is received + iteration of the cursor will continue from the last document + received. For details, see the `tailable cursor documentation + `_. + - :attr:`~pymongo.cursor.CursorType.TAILABLE_AWAIT` - the result + of this find call will be a tailable cursor with the await flag + set. The server will wait for a few seconds after returning the + full result set so that it can capture and return additional data + added during the query. + - :attr:`~pymongo.cursor.CursorType.EXHAUST` - the result of this + find call will be an exhaust cursor. MongoDB will stream batched + results to the client without waiting for the client to request + each batch, reducing latency. See notes on compatibility below. + + :param sort: a list of (key, direction) pairs + specifying the sort order for this query. See + :meth:`~pymongo.cursor.Cursor.sort` for details. + :param allow_partial_results: if True, mongos will return + partial results if some shards are down instead of returning an + error. + :param oplog_replay: **DEPRECATED** - if True, set the + oplogReplay query flag. Default: False. + :param batch_size: Limits the number of documents returned in + a single batch. + :param collation: An instance of + :class:`~pymongo.collation.Collation`. + :param return_key: If True, return only the index keys in + each document. + :param show_record_id: If True, adds a field ``$recordId`` in + each document with the storage engine's internal record identifier. + :param snapshot: **DEPRECATED** - If True, prevents the + cursor from returning a document more than once because of an + intervening write operation. + :param hint: An index, in the same format as passed to + :meth:`~pymongo.collection.Collection.create_index` (e.g. + ``[('field', ASCENDING)]``). Pass this as an alternative to calling + :meth:`~pymongo.cursor.Cursor.hint` on the cursor to tell Mongo the + proper index to use for the query. 
+ :param max_time_ms: Specifies a time limit for a query + operation. If the specified time is exceeded, the operation will be + aborted and :exc:`~pymongo.errors.ExecutionTimeout` is raised. Pass + this as an alternative to calling + :meth:`~pymongo.cursor.Cursor.max_time_ms` on the cursor. + :param max_scan: **DEPRECATED** - The maximum number of + documents to scan. Pass this as an alternative to calling + :meth:`~pymongo.cursor.Cursor.max_scan` on the cursor. + :param min: A list of field, limit pairs specifying the + inclusive lower bound for all keys of a specific index in order. + Pass this as an alternative to calling + :meth:`~pymongo.cursor.Cursor.min` on the cursor. ``hint`` must + also be passed to ensure the query utilizes the correct index. + :param max: A list of field, limit pairs specifying the + exclusive upper bound for all keys of a specific index in order. + Pass this as an alternative to calling + :meth:`~pymongo.cursor.Cursor.max` on the cursor. ``hint`` must + also be passed to ensure the query utilizes the correct index. + :param comment: A string to attach to the query to help + interpret and trace the operation in the server logs and in profile + data. Pass this as an alternative to calling + :meth:`~pymongo.cursor.Cursor.comment` on the cursor. + :param allow_disk_use: if True, MongoDB may use temporary + disk files to store data exceeding the system memory limit while + processing a blocking sort operation. The option has no effect if + MongoDB can satisfy the specified sort using an index, or if the + blocking sort requires less memory than the 100 MiB limit. This + option is only supported on MongoDB 4.4 and above. + + .. note:: There are a number of caveats to using + :attr:`~pymongo.cursor.CursorType.EXHAUST` as cursor_type: + + - The `limit` option can not be used with an exhaust cursor. + + - Exhaust cursors are not supported by mongos and can not be + used with a sharded cluster. + + - A :class:`~pymongo.cursor.Cursor` instance created with the + :attr:`~pymongo.cursor.CursorType.EXHAUST` cursor_type requires an + exclusive :class:`~socket.socket` connection to MongoDB. If the + :class:`~pymongo.cursor.Cursor` is discarded without being + completely iterated the underlying :class:`~socket.socket` + connection will be closed and discarded without being returned to + the connection pool. + + .. versionchanged:: 4.0 + Removed the ``modifiers`` option. + Empty projections (eg {} or []) are passed to the server as-is, + rather than the previous behavior which substituted in a + projection of ``{"_id": 1}``. This means that an empty projection + will now return the entire document, not just the ``"_id"`` field. + + .. versionchanged:: 3.11 + Added the ``allow_disk_use`` option. + Deprecated the ``oplog_replay`` option. Support for this option is + deprecated in MongoDB 4.4. The query engine now automatically + optimizes queries against the oplog without requiring this + option to be set. + + .. versionchanged:: 3.7 + Deprecated the ``snapshot`` option, which is deprecated in MongoDB + 3.6 and removed in MongoDB 4.0. + Deprecated the ``max_scan`` option. Support for this option is + deprecated in MongoDB 4.0. Use ``max_time_ms`` instead to limit + server-side execution time. + + .. versionchanged:: 3.6 + Added ``session`` parameter. + + .. versionchanged:: 3.5 + Added the options ``return_key``, ``show_record_id``, ``snapshot``, + ``hint``, ``max_time_ms``, ``max_scan``, ``min``, ``max``, and + ``comment``. + Deprecated the ``modifiers`` option. + + .. 
versionchanged:: 3.4 + Added support for the ``collation`` option. + + .. versionchanged:: 3.0 + Changed the parameter names ``spec``, ``fields``, ``timeout``, and + ``partial`` to ``filter``, ``projection``, ``no_cursor_timeout``, + and ``allow_partial_results`` respectively. + Added the ``cursor_type``, ``oplog_replay``, and ``modifiers`` + options. + Removed the ``network_timeout``, ``read_preference``, ``tag_sets``, + ``secondary_acceptable_latency_ms``, ``max_scan``, ``snapshot``, + ``tailable``, ``await_data``, ``exhaust``, ``as_class``, and + slave_okay parameters. + Removed ``compile_re`` option: PyMongo now always + represents BSON regular expressions as :class:`~bson.regex.Regex` + objects. Use :meth:`~bson.regex.Regex.try_compile` to attempt to + convert from a BSON regular expression to a Python regular + expression object. + Soft deprecated the ``manipulate`` option. + + .. seealso:: The MongoDB documentation on `find `_. + """ + return Cursor(self, *args, **kwargs) + + def find_raw_batches(self, *args: Any, **kwargs: Any) -> RawBatchCursor[_DocumentType]: + """Query the database and retrieve batches of raw BSON. + + Similar to the :meth:`find` method but returns a + :class:`~pymongo.cursor.RawBatchCursor`. + + This example demonstrates how to work with raw batches, but in practice + raw batches should be passed to an external library that can decode + BSON into another data type, rather than used with PyMongo's + :mod:`bson` module. + + >>> import bson + >>> cursor = db.test.find_raw_batches() + >>> for batch in cursor: + ... print(bson.decode_all(batch)) + + .. note:: find_raw_batches does not support auto encryption. + + .. versionchanged:: 3.12 + Instead of ignoring the user-specified read concern, this method + now sends it to the server when connected to MongoDB 3.6+. + + Added session support. + + .. versionadded:: 3.6 + """ + # OP_MSG is required to support encryption. + if self.__database.client._encrypter: + raise InvalidOperation("find_raw_batches does not support auto encryption") + return RawBatchCursor(self, *args, **kwargs) + + def _count_cmd( + self, + session: Optional[ClientSession], + conn: Connection, + read_preference: Optional[_ServerMode], + cmd: dict[str, Any], + collation: Optional[Collation], + ) -> int: + """Internal count command helper.""" + # XXX: "ns missing" checks can be removed when we drop support for + # MongoDB 3.0, see SERVER-17051. + res = self._command( + conn, + cmd, + read_preference=read_preference, + allowable_errors=["ns missing"], + codec_options=self.__write_response_codec_options, + read_concern=self.read_concern, + collation=collation, + session=session, + ) + if res.get("errmsg", "") == "ns missing": + return 0 + return int(res["n"]) + + def _aggregate_one_result( + self, + conn: Connection, + read_preference: Optional[_ServerMode], + cmd: dict[str, Any], + collation: Optional[_CollationIn], + session: Optional[ClientSession], + ) -> Optional[Mapping[str, Any]]: + """Internal helper to run an aggregate that returns a single result.""" + result = self._command( + conn, + cmd, + read_preference, + allowable_errors=[26], # Ignore NamespaceNotFound. + codec_options=self.__write_response_codec_options, + read_concern=self.read_concern, + collation=collation, + session=session, + ) + # cursor will not be present for NamespaceNotFound errors. 
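+        # For reference, a successful server reply is shaped roughly like
+        # (values illustrative):
+        #     {"cursor": {"id": 0, "ns": "db.coll", "firstBatch": [{"n": 3}]},
+        #      "ok": 1}
+        # so the single result, when present, is firstBatch[0].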
+ if "cursor" not in result: + return None + batch = result["cursor"]["firstBatch"] + return batch[0] if batch else None + + def estimated_document_count(self, comment: Optional[Any] = None, **kwargs: Any) -> int: + """Get an estimate of the number of documents in this collection using + collection metadata. + + The :meth:`estimated_document_count` method is **not** supported in a + transaction. + + All optional parameters should be passed as keyword arguments + to this method. Valid options include: + + - `maxTimeMS` (int): The maximum amount of time to allow this + operation to run, in milliseconds. + + :param comment: A user-provided comment to attach to this + command. + :param kwargs: See list of options above. + + .. versionchanged:: 4.2 + This method now always uses the `count`_ command. Due to an oversight in versions + 5.0.0-5.0.8 of MongoDB, the count command was not included in V1 of the + :ref:`versioned-api-ref`. Users of the Stable API with estimated_document_count are + recommended to upgrade their server version to 5.0.9+ or set + :attr:`pymongo.server_api.ServerApi.strict` to ``False`` to avoid encountering errors. + + .. versionadded:: 3.7 + .. _count: https://mongodb.com/docs/manual/reference/command/count/ + """ + if "session" in kwargs: + raise ConfigurationError("estimated_document_count does not support sessions") + if comment is not None: + kwargs["comment"] = comment + + def _cmd( + session: Optional[ClientSession], + _server: Server, + conn: Connection, + read_preference: Optional[_ServerMode], + ) -> int: + cmd: dict[str, Any] = {"count": self.__name} + cmd.update(kwargs) + return self._count_cmd(session, conn, read_preference, cmd, collation=None) + + return self._retryable_non_cursor_read(_cmd, None, operation=_Op.COUNT) + + def count_documents( + self, + filter: Mapping[str, Any], + session: Optional[ClientSession] = None, + comment: Optional[Any] = None, + **kwargs: Any, + ) -> int: + """Count the number of documents in this collection. + + .. note:: For a fast count of the total documents in a collection see + :meth:`estimated_document_count`. + + The :meth:`count_documents` method is supported in a transaction. + + All optional parameters should be passed as keyword arguments + to this method. Valid options include: + + - `skip` (int): The number of matching documents to skip before + returning results. + - `limit` (int): The maximum number of documents to count. Must be + a positive integer. If not provided, no limit is imposed. + - `maxTimeMS` (int): The maximum amount of time to allow this + operation to run, in milliseconds. + - `collation` (optional): An instance of + :class:`~pymongo.collation.Collation`. + - `hint` (string or list of tuples): The index to use. Specify either + the index name as a string or the index specification as a list of + tuples (e.g. [('a', pymongo.ASCENDING), ('b', pymongo.ASCENDING)]). + + The :meth:`count_documents` method obeys the :attr:`read_preference` of + this :class:`Collection`. + + .. 
note:: When migrating from :meth:`count` to :meth:`count_documents` + the following query operators must be replaced: + + +-------------+-------------------------------------+ + | Operator | Replacement | + +=============+=====================================+ + | $where | `$expr`_ | + +-------------+-------------------------------------+ + | $near | `$geoWithin`_ with `$center`_ | + +-------------+-------------------------------------+ + | $nearSphere | `$geoWithin`_ with `$centerSphere`_ | + +-------------+-------------------------------------+ + + :param filter: A query document that selects which documents + to count in the collection. Can be an empty document to count all + documents. + :param session: a + :class:`~pymongo.client_session.ClientSession`. + :param comment: A user-provided comment to attach to this + command. + :param kwargs: See list of options above. + + + .. versionadded:: 3.7 + + .. _$expr: https://mongodb.com/docs/manual/reference/operator/query/expr/ + .. _$geoWithin: https://mongodb.com/docs/manual/reference/operator/query/geoWithin/ + .. _$center: https://mongodb.com/docs/manual/reference/operator/query/center/ + .. _$centerSphere: https://mongodb.com/docs/manual/reference/operator/query/centerSphere/ + """ + pipeline = [{"$match": filter}] + if "skip" in kwargs: + pipeline.append({"$skip": kwargs.pop("skip")}) + if "limit" in kwargs: + pipeline.append({"$limit": kwargs.pop("limit")}) + if comment is not None: + kwargs["comment"] = comment + pipeline.append({"$group": {"_id": 1, "n": {"$sum": 1}}}) + cmd = {"aggregate": self.__name, "pipeline": pipeline, "cursor": {}} + if "hint" in kwargs and not isinstance(kwargs["hint"], str): + kwargs["hint"] = helpers._index_document(kwargs["hint"]) + collation = validate_collation_or_none(kwargs.pop("collation", None)) + cmd.update(kwargs) + + def _cmd( + session: Optional[ClientSession], + _server: Server, + conn: Connection, + read_preference: Optional[_ServerMode], + ) -> int: + result = self._aggregate_one_result(conn, read_preference, cmd, collation, session) + if not result: + return 0 + return result["n"] + + return self._retryable_non_cursor_read(_cmd, session, _Op.COUNT) + + def _retryable_non_cursor_read( + self, + func: Callable[[Optional[ClientSession], Server, Connection, Optional[_ServerMode]], T], + session: Optional[ClientSession], + operation: str, + ) -> T: + """Non-cursor read helper to handle implicit session creation.""" + client = self.__database.client + with client._tmp_session(session) as s: + return client._retryable_read(func, self._read_preference_for(s), s, operation) + + def create_indexes( + self, + indexes: Sequence[IndexModel], + session: Optional[ClientSession] = None, + comment: Optional[Any] = None, + **kwargs: Any, + ) -> list[str]: + """Create one or more indexes on this collection. + + >>> from pymongo import IndexModel, ASCENDING, DESCENDING + >>> index1 = IndexModel([("hello", DESCENDING), + ... ("world", ASCENDING)], name="hello_world") + >>> index2 = IndexModel([("goodbye", DESCENDING)]) + >>> db.test.create_indexes([index1, index2]) + ["hello_world", "goodbye_-1"] + + :param indexes: A list of :class:`~pymongo.operations.IndexModel` + instances. + :param session: a + :class:`~pymongo.client_session.ClientSession`. + :param comment: A user-provided comment to attach to this + command. + :param kwargs: optional arguments to the createIndexes + command (like maxTimeMS) can be passed as keyword arguments. + + + + + .. 
note:: The :attr:`~pymongo.collection.Collection.write_concern` of + this collection is automatically applied to this operation. + + .. versionchanged:: 3.6 + Added ``session`` parameter. Added support for arbitrary keyword + arguments. + + .. versionchanged:: 3.4 + Apply this collection's write concern automatically to this operation + when connected to MongoDB >= 3.4. + .. versionadded:: 3.0 + + .. _createIndexes: https://mongodb.com/docs/manual/reference/command/createIndexes/ + """ + common.validate_list("indexes", indexes) + if comment is not None: + kwargs["comment"] = comment + return self.__create_indexes(indexes, session, **kwargs) + + @_csot.apply + def __create_indexes( + self, indexes: Sequence[IndexModel], session: Optional[ClientSession], **kwargs: Any + ) -> list[str]: + """Internal createIndexes helper. + + :param indexes: A list of :class:`~pymongo.operations.IndexModel` + instances. + :param session: a + :class:`~pymongo.client_session.ClientSession`. + :param kwargs: optional arguments to the createIndexes + command (like maxTimeMS) can be passed as keyword arguments. + """ + names = [] + with self._conn_for_writes(session, operation=_Op.CREATE_INDEXES) as conn: + supports_quorum = conn.max_wire_version >= 9 + + def gen_indexes() -> Iterator[Mapping[str, Any]]: + for index in indexes: + if not isinstance(index, IndexModel): + raise TypeError( + f"{index!r} is not an instance of pymongo.operations.IndexModel" + ) + document = index.document + names.append(document["name"]) + yield document + + cmd = {"createIndexes": self.name, "indexes": list(gen_indexes())} + cmd.update(kwargs) + if "commitQuorum" in kwargs and not supports_quorum: + raise ConfigurationError( + "Must be connected to MongoDB 4.4+ to use the " + "commitQuorum option for createIndexes" + ) + + self._command( + conn, + cmd, + read_preference=ReadPreference.PRIMARY, + codec_options=_UNICODE_REPLACE_CODEC_OPTIONS, + write_concern=self._write_concern_for(session), + session=session, + ) + return names + + def create_index( + self, + keys: _IndexKeyHint, + session: Optional[ClientSession] = None, + comment: Optional[Any] = None, + **kwargs: Any, + ) -> str: + """Creates an index on this collection. + + Takes either a single key or a list containing (key, direction) pairs + or keys. If no direction is given, :data:`~pymongo.ASCENDING` will + be assumed. + The key(s) must be an instance of :class:`str` and the direction(s) must + be one of (:data:`~pymongo.ASCENDING`, :data:`~pymongo.DESCENDING`, + :data:`~pymongo.GEO2D`, :data:`~pymongo.GEOSPHERE`, + :data:`~pymongo.HASHED`, :data:`~pymongo.TEXT`). + + To create a single key ascending index on the key ``'mike'`` we just + use a string argument:: + + >>> my_collection.create_index("mike") + + For a compound index on ``'mike'`` descending and ``'eliot'`` + ascending we need to use a list of tuples:: + + >>> my_collection.create_index([("mike", pymongo.DESCENDING), + ... "eliot"]) + + All optional index creation parameters should be passed as + keyword arguments to this method. For example:: + + >>> my_collection.create_index([("mike", pymongo.DESCENDING)], + ... background=True) + + Valid options include, but are not limited to: + + - `name`: custom name to use for this index - if none is + given, a name will be generated. + - `unique`: if ``True``, creates a uniqueness constraint on the + index. + - `background`: if ``True``, this index should be created in the + background. 
+        - `sparse`: if ``True``, omit from the index any documents that lack
+          the indexed field.
+        - `bucketSize`: for use with geoHaystack indexes.
+          Number of documents to group together within a certain proximity
+          to a given longitude and latitude.
+        - `min`: minimum value for keys in a :data:`~pymongo.GEO2D`
+          index.
+        - `max`: maximum value for keys in a :data:`~pymongo.GEO2D`
+          index.
+        - `expireAfterSeconds`: Used to create an expiring (TTL)
+          collection. MongoDB will automatically delete documents from
+          this collection after ``<int>`` seconds. The indexed field must
+          be a UTC datetime or the data will not expire.
+        - `partialFilterExpression`: A document that specifies a filter for
+          a partial index.
+        - `collation` (optional): An instance of
+          :class:`~pymongo.collation.Collation`.
+        - `wildcardProjection`: Allows users to include or exclude specific
+          field paths from a `wildcard index`_ using the {"$**" : 1} key
+          pattern. Requires MongoDB >= 4.2.
+        - `hidden`: if ``True``, this index will be hidden from the query
+          planner and will not be evaluated as part of query plan
+          selection. Requires MongoDB >= 4.4.
+
+        See the MongoDB documentation for a full list of supported options by
+        server version.
+
+        .. warning:: `dropDups` is not supported by MongoDB 3.0 or newer. The
+          option is silently ignored by the server and unique index builds
+          using the option will fail if a duplicate value is detected.
+
+        .. note:: The :attr:`~pymongo.collection.Collection.write_concern` of
+           this collection is automatically applied to this operation.
+
+        :param keys: a single key or a list of (key, direction)
+            pairs specifying the index to create
+        :param session: a
+            :class:`~pymongo.client_session.ClientSession`.
+        :param comment: A user-provided comment to attach to this
+            command.
+        :param kwargs: any additional index creation
+            options (see the above list) should be passed as keyword
+            arguments.
+
+        .. versionchanged:: 4.4
+           Allow passing a list containing (key, direction) pairs
+           or keys for the ``keys`` parameter.
+        .. versionchanged:: 4.1
+           Added ``comment`` parameter.
+        .. versionchanged:: 3.11
+           Added the ``hidden`` option.
+        .. versionchanged:: 3.6
+           Added ``session`` parameter. Added support for passing maxTimeMS
+           in kwargs.
+        .. versionchanged:: 3.4
+           Apply this collection's write concern automatically to this operation
+           when connected to MongoDB >= 3.4. Support the `collation` option.
+        .. versionchanged:: 3.2
+           Added partialFilterExpression to support partial indexes.
+        .. versionchanged:: 3.0
+           Renamed `key_or_list` to `keys`. Removed the `cache_for` option.
+           :meth:`create_index` no longer caches index names. Removed support
+           for the drop_dups and bucket_size aliases.
+
+        .. seealso:: The MongoDB documentation on `indexes `_.
+
+        .. _wildcard index: https://dochub.mongodb.org/core/index-wildcard/
+        """
+        cmd_options = {}
+        if "maxTimeMS" in kwargs:
+            cmd_options["maxTimeMS"] = kwargs.pop("maxTimeMS")
+        if comment is not None:
+            cmd_options["comment"] = comment
+        index = IndexModel(keys, **kwargs)
+        return self.__create_indexes([index], session, **cmd_options)[0]
+
+    def drop_indexes(
+        self,
+        session: Optional[ClientSession] = None,
+        comment: Optional[Any] = None,
+        **kwargs: Any,
+    ) -> None:
+        """Drops all indexes on this collection.
+
+        Can be used on non-existent collections or collections with no indexes.
+        Raises OperationFailure on an error.
+
+        :param session: a
+            :class:`~pymongo.client_session.ClientSession`.
+        :param comment: A user-provided comment to attach to this
+            command.
+        :param kwargs: optional arguments to the dropIndexes
+            command (like maxTimeMS) can be passed as keyword arguments.
+
+        .. note:: The :attr:`~pymongo.collection.Collection.write_concern` of
+           this collection is automatically applied to this operation.
+
+        .. versionchanged:: 3.6
+           Added ``session`` parameter. Added support for arbitrary keyword
+           arguments.
+
+        .. versionchanged:: 3.4
+           Apply this collection's write concern automatically to this operation
+           when connected to MongoDB >= 3.4.
+        """
+        if comment is not None:
+            kwargs["comment"] = comment
+        self.drop_index("*", session=session, **kwargs)
+
+    @_csot.apply
+    def drop_index(
+        self,
+        index_or_name: _IndexKeyHint,
+        session: Optional[ClientSession] = None,
+        comment: Optional[Any] = None,
+        **kwargs: Any,
+    ) -> None:
+        """Drops the specified index on this collection.
+
+        Can be used on non-existent collections or collections with no
+        indexes. Raises OperationFailure on an error (e.g. trying to
+        drop an index that does not exist). `index_or_name`
+        can be either an index name (as returned by `create_index`),
+        or an index specifier (as passed to `create_index`). An index
+        specifier should be a list of (key, direction) pairs. Raises
+        TypeError if index is not an instance of (str, list).
+
+        .. warning::
+
+          if a custom name was used on index creation (by
+          passing the `name` parameter to :meth:`create_index`) the index
+          **must** be dropped by name.
+
+        :param index_or_name: index (or name of index) to drop
+        :param session: a
+            :class:`~pymongo.client_session.ClientSession`.
+        :param comment: A user-provided comment to attach to this
+            command.
+        :param kwargs: optional arguments to the dropIndexes
+            command (like maxTimeMS) can be passed as keyword arguments.
+
+        .. note:: The :attr:`~pymongo.collection.Collection.write_concern` of
+           this collection is automatically applied to this operation.
+
+        .. versionchanged:: 3.6
+           Added ``session`` parameter. Added support for arbitrary keyword
+           arguments.
+
+        .. versionchanged:: 3.4
+           Apply this collection's write concern automatically to this operation
+           when connected to MongoDB >= 3.4.
+
+        """
+        name = index_or_name
+        if isinstance(index_or_name, list):
+            name = helpers._gen_index_name(index_or_name)
+
+        if not isinstance(name, str):
+            raise TypeError("index_or_name must be an instance of str or list")
+
+        cmd = {"dropIndexes": self.__name, "index": name}
+        cmd.update(kwargs)
+        if comment is not None:
+            cmd["comment"] = comment
+        with self._conn_for_writes(session, operation=_Op.DROP_INDEXES) as conn:
+            self._command(
+                conn,
+                cmd,
+                read_preference=ReadPreference.PRIMARY,
+                allowable_errors=["ns not found", 26],
+                write_concern=self._write_concern_for(session),
+                session=session,
+            )
+
+    def list_indexes(
+        self,
+        session: Optional[ClientSession] = None,
+        comment: Optional[Any] = None,
+    ) -> CommandCursor[MutableMapping[str, Any]]:
+        """Get a cursor over the index documents for this collection.
+
+        >>> for index in db.test.list_indexes():
+        ...     print(index)
+        ...
+        SON([('v', 2), ('key', SON([('_id', 1)])), ('name', '_id_')])
+
+        :param session: a
+            :class:`~pymongo.client_session.ClientSession`.
+        :param comment: A user-provided comment to attach to this
+            command.
+
+        :return: An instance of :class:`~pymongo.command_cursor.CommandCursor`.
+
+        .. versionchanged:: 4.1
+           Added ``comment`` parameter.
+
+        .. versionchanged:: 3.6
+           Added ``session`` parameter.
+
+        ..
versionadded:: 3.0 + """ + codec_options: CodecOptions = CodecOptions(SON) + coll = cast( + Collection[MutableMapping[str, Any]], + self.with_options(codec_options=codec_options, read_preference=ReadPreference.PRIMARY), + ) + read_pref = (session and session._txn_read_preference()) or ReadPreference.PRIMARY + explicit_session = session is not None + + def _cmd( + session: Optional[ClientSession], + _server: Server, + conn: Connection, + read_preference: _ServerMode, + ) -> CommandCursor[MutableMapping[str, Any]]: + cmd = {"listIndexes": self.__name, "cursor": {}} + if comment is not None: + cmd["comment"] = comment + + try: + cursor = self._command(conn, cmd, read_preference, codec_options, session=session)[ + "cursor" + ] + except OperationFailure as exc: + # Ignore NamespaceNotFound errors to match the behavior + # of reading from *.system.indexes. + if exc.code != 26: + raise + cursor = {"id": 0, "firstBatch": []} + cmd_cursor = CommandCursor( + coll, + cursor, + conn.address, + session=session, + explicit_session=explicit_session, + comment=cmd.get("comment"), + ) + cmd_cursor._maybe_pin_connection(conn) + return cmd_cursor + + with self.__database.client._tmp_session(session, False) as s: + return self.__database.client._retryable_read( + _cmd, read_pref, s, operation=_Op.LIST_INDEXES + ) + + def index_information( + self, + session: Optional[ClientSession] = None, + comment: Optional[Any] = None, + ) -> MutableMapping[str, Any]: + """Get information on this collection's indexes. + + Returns a dictionary where the keys are index names (as + returned by create_index()) and the values are dictionaries + containing information about each index. The dictionary is + guaranteed to contain at least a single key, ``"key"`` which + is a list of (key, direction) pairs specifying the index (as + passed to create_index()). It will also contain any other + metadata about the indexes, except for the ``"ns"`` and + ``"name"`` keys, which are cleaned. Example output might look + like this: + + >>> db.test.create_index("x", unique=True) + 'x_1' + >>> db.test.index_information() + {'_id_': {'key': [('_id', 1)]}, + 'x_1': {'unique': True, 'key': [('x', 1)]}} + + :param session: a + :class:`~pymongo.client_session.ClientSession`. + :param comment: A user-provided comment to attach to this + command. + + .. versionchanged:: 4.1 + Added ``comment`` parameter. + + .. versionchanged:: 3.6 + Added ``session`` parameter. + """ + cursor = self.list_indexes(session=session, comment=comment) + info = {} + for index in cursor: + index["key"] = list(index["key"].items()) + index = dict(index) # noqa: PLW2901 + info[index.pop("name")] = index + return info + + def list_search_indexes( + self, + name: Optional[str] = None, + session: Optional[ClientSession] = None, + comment: Optional[Any] = None, + **kwargs: Any, + ) -> CommandCursor[Mapping[str, Any]]: + """Return a cursor over search indexes for the current collection. + + :param name: If given, the name of the index to search + for. Only indexes with matching index names will be returned. + If not given, all search indexes for the current collection + will be returned. + :param session: a :class:`~pymongo.client_session.ClientSession`. + :param comment: A user-provided comment to attach to this + command. + + :return: A :class:`~pymongo.command_cursor.CommandCursor` over the result + set. + + .. note:: requires a MongoDB server version 7.0+ Atlas cluster. + + .. 
versionadded:: 4.5 + """ + if name is None: + pipeline: _Pipeline = [{"$listSearchIndexes": {}}] + else: + pipeline = [{"$listSearchIndexes": {"name": name}}] + + coll = self.with_options( + codec_options=DEFAULT_CODEC_OPTIONS, + read_preference=ReadPreference.PRIMARY, + write_concern=DEFAULT_WRITE_CONCERN, + read_concern=DEFAULT_READ_CONCERN, + ) + cmd = _CollectionAggregationCommand( + coll, + CommandCursor, + pipeline, + kwargs, + explicit_session=session is not None, + comment=comment, + user_fields={"cursor": {"firstBatch": 1}}, + ) + + return self.__database.client._retryable_read( + cmd.get_cursor, + cmd.get_read_preference(session), # type: ignore[arg-type] + session, + retryable=not cmd._performs_write, + operation=_Op.LIST_SEARCH_INDEX, + ) + + def create_search_index( + self, + model: Union[Mapping[str, Any], SearchIndexModel], + session: Optional[ClientSession] = None, + comment: Any = None, + **kwargs: Any, + ) -> str: + """Create a single search index for the current collection. + + :param model: The model for the new search index. + It can be given as a :class:`~pymongo.operations.SearchIndexModel` + instance or a dictionary with a model "definition" and optional + "name". + :param session: a + :class:`~pymongo.client_session.ClientSession`. + :param comment: A user-provided comment to attach to this + command. + :param kwargs: optional arguments to the createSearchIndexes + command (like maxTimeMS) can be passed as keyword arguments. + + :return: The name of the new search index. + + .. note:: requires a MongoDB server version 7.0+ Atlas cluster. + + .. versionadded:: 4.5 + """ + if not isinstance(model, SearchIndexModel): + model = SearchIndexModel(**model) + return self.create_search_indexes([model], session, comment, **kwargs)[0] + + def create_search_indexes( + self, + models: list[SearchIndexModel], + session: Optional[ClientSession] = None, + comment: Optional[Any] = None, + **kwargs: Any, + ) -> list[str]: + """Create multiple search indexes for the current collection. + + :param models: A list of :class:`~pymongo.operations.SearchIndexModel` instances. + :param session: a :class:`~pymongo.client_session.ClientSession`. + :param comment: A user-provided comment to attach to this + command. + :param kwargs: optional arguments to the createSearchIndexes + command (like maxTimeMS) can be passed as keyword arguments. + + :return: A list of the newly created search index names. + + .. note:: requires a MongoDB server version 7.0+ Atlas cluster. + + .. versionadded:: 4.5 + """ + if comment is not None: + kwargs["comment"] = comment + + def gen_indexes() -> Iterator[Mapping[str, Any]]: + for index in models: + if not isinstance(index, SearchIndexModel): + raise TypeError( + f"{index!r} is not an instance of pymongo.operations.SearchIndexModel" + ) + yield index.document + + cmd = {"createSearchIndexes": self.name, "indexes": list(gen_indexes())} + cmd.update(kwargs) + + with self._conn_for_writes(session, operation=_Op.CREATE_SEARCH_INDEXES) as conn: + resp = self._command( + conn, + cmd, + read_preference=ReadPreference.PRIMARY, + codec_options=_UNICODE_REPLACE_CODEC_OPTIONS, + ) + return [index["name"] for index in resp["indexesCreated"]] + + def drop_search_index( + self, + name: str, + session: Optional[ClientSession] = None, + comment: Optional[Any] = None, + **kwargs: Any, + ) -> None: + """Delete a search index by index name. + + :param name: The name of the search index to be deleted. + :param session: a + :class:`~pymongo.client_session.ClientSession`. 
+        :param comment: A user-provided comment to attach to this
+            command.
+        :param kwargs: optional arguments to the dropSearchIndex
+            command (like maxTimeMS) can be passed as keyword arguments.
+
+        .. note:: requires a MongoDB server version 7.0+ Atlas cluster.
+
+        .. versionadded:: 4.5
+        """
+        cmd = {"dropSearchIndex": self.__name, "name": name}
+        cmd.update(kwargs)
+        if comment is not None:
+            cmd["comment"] = comment
+        with self._conn_for_writes(session, operation=_Op.DROP_SEARCH_INDEXES) as conn:
+            self._command(
+                conn,
+                cmd,
+                read_preference=ReadPreference.PRIMARY,
+                allowable_errors=["ns not found", 26],
+                codec_options=_UNICODE_REPLACE_CODEC_OPTIONS,
+            )
+
+    def update_search_index(
+        self,
+        name: str,
+        definition: Mapping[str, Any],
+        session: Optional[ClientSession] = None,
+        comment: Optional[Any] = None,
+        **kwargs: Any,
+    ) -> None:
+        """Update a search index by replacing the existing index definition
+        with the provided definition.
+
+        :param name: The name of the search index to be updated.
+        :param definition: The new search index definition.
+        :param session: a
+            :class:`~pymongo.client_session.ClientSession`.
+        :param comment: A user-provided comment to attach to this
+            command.
+        :param kwargs: optional arguments to the updateSearchIndex
+            command (like maxTimeMS) can be passed as keyword arguments.
+
+        .. note:: requires a MongoDB server version 7.0+ Atlas cluster.
+
+        .. versionadded:: 4.5
+        """
+        cmd = {"updateSearchIndex": self.__name, "name": name, "definition": definition}
+        cmd.update(kwargs)
+        if comment is not None:
+            cmd["comment"] = comment
+        with self._conn_for_writes(session, operation=_Op.UPDATE_SEARCH_INDEX) as conn:
+            self._command(
+                conn,
+                cmd,
+                read_preference=ReadPreference.PRIMARY,
+                allowable_errors=["ns not found", 26],
+                codec_options=_UNICODE_REPLACE_CODEC_OPTIONS,
+            )
+
+    def options(
+        self,
+        session: Optional[ClientSession] = None,
+        comment: Optional[Any] = None,
+    ) -> MutableMapping[str, Any]:
+        """Get the options set on this collection.
+
+        Returns a dictionary of options and their values - see
+        :meth:`~pymongo.database.Database.create_collection` for more
+        information on the possible options. Returns an empty
+        dictionary if the collection has not been created yet.
+
+        :param session: a
+            :class:`~pymongo.client_session.ClientSession`.
+        :param comment: A user-provided comment to attach to this
+            command.
+
+        .. versionchanged:: 3.6
+           Added ``session`` parameter.
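+
+        For example, for a capped collection created with ``capped=True`` and
+        ``size=4096`` (an illustrative setup), the output might look like:
+
+        >>> db.log.options()
+        {'capped': True, 'size': 4096}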
+ """ + dbo = self.__database.client.get_database( + self.__database.name, + self.codec_options, + self.read_preference, + self.write_concern, + self.read_concern, + ) + cursor = dbo.list_collections( + session=session, filter={"name": self.__name}, comment=comment + ) + + result = None + for doc in cursor: + result = doc + break + + if not result: + return {} + + options = result.get("options", {}) + assert options is not None + if "create" in options: + del options["create"] + + return options + + @_csot.apply + def _aggregate( + self, + aggregation_command: Type[_AggregationCommand], + pipeline: _Pipeline, + cursor_class: Type[CommandCursor], + session: Optional[ClientSession], + explicit_session: bool, + let: Optional[Mapping[str, Any]] = None, + comment: Optional[Any] = None, + **kwargs: Any, + ) -> CommandCursor[_DocumentType]: + if comment is not None: + kwargs["comment"] = comment + cmd = aggregation_command( + self, + cursor_class, + pipeline, + kwargs, + explicit_session, + let, + user_fields={"cursor": {"firstBatch": 1}}, + ) + + return self.__database.client._retryable_read( + cmd.get_cursor, + cmd.get_read_preference(session), # type: ignore[arg-type] + session, + retryable=not cmd._performs_write, + operation=_Op.AGGREGATE, + ) + + def aggregate( + self, + pipeline: _Pipeline, + session: Optional[ClientSession] = None, + let: Optional[Mapping[str, Any]] = None, + comment: Optional[Any] = None, + **kwargs: Any, + ) -> CommandCursor[_DocumentType]: + """Perform an aggregation using the aggregation framework on this + collection. + + The :meth:`aggregate` method obeys the :attr:`read_preference` of this + :class:`Collection`, except when ``$out`` or ``$merge`` are used on + MongoDB <5.0, in which case + :attr:`~pymongo.read_preferences.ReadPreference.PRIMARY` is used. + + .. note:: This method does not support the 'explain' option. Please + use `PyMongoExplain `_ + instead. An example is included in the :ref:`aggregate-examples` + documentation. + + .. note:: The :attr:`~pymongo.collection.Collection.write_concern` of + this collection is automatically applied to this operation. + + :param pipeline: a list of aggregation pipeline stages + :param session: a + :class:`~pymongo.client_session.ClientSession`. + :param let: A dict of parameter names and values. Values must be + constant or closed expressions that do not reference document + fields. Parameters can then be accessed as variables in an + aggregate expression context (e.g. ``"$$var"``). This option is + only supported on MongoDB >= 5.0. + :param comment: A user-provided comment to attach to this + command. + :param kwargs: extra `aggregate command`_ parameters. + + All optional `aggregate command`_ parameters should be passed as + keyword arguments to this method. Valid options include, but are not + limited to: + + - `allowDiskUse` (bool): Enables writing to temporary files. When set + to True, aggregation stages can write data to the _tmp subdirectory + of the --dbpath directory. The default is False. + - `maxTimeMS` (int): The maximum amount of time to allow the operation + to run in milliseconds. + - `batchSize` (int): The maximum number of documents to return per + batch. Ignored if the connected mongod or mongos does not support + returning aggregate results using a cursor. + - `collation` (optional): An instance of + :class:`~pymongo.collation.Collation`. + + + :return: A :class:`~pymongo.command_cursor.CommandCursor` over the result + set. + + .. versionchanged:: 4.1 + Added ``comment`` parameter. 
+ Added ``let`` parameter. + Support $merge and $out executing on secondaries according to the + collection's :attr:`read_preference`. + .. versionchanged:: 4.0 + Removed the ``useCursor`` option. + .. versionchanged:: 3.9 + Apply this collection's read concern to pipelines containing the + `$out` stage when connected to MongoDB >= 4.2. + Added support for the ``$merge`` pipeline stage. + Aggregations that write always use read preference + :attr:`~pymongo.read_preferences.ReadPreference.PRIMARY`. + .. versionchanged:: 3.6 + Added the `session` parameter. Added the `maxAwaitTimeMS` option. + Deprecated the `useCursor` option. + .. versionchanged:: 3.4 + Apply this collection's write concern automatically to this operation + when connected to MongoDB >= 3.4. Support the `collation` option. + .. versionchanged:: 3.0 + The :meth:`aggregate` method always returns a CommandCursor. The + pipeline argument must be a list. + + .. seealso:: :doc:`/examples/aggregation` + + .. _aggregate command: + https://mongodb.com/docs/manual/reference/command/aggregate + """ + with self.__database.client._tmp_session(session, close=False) as s: + return self._aggregate( + _CollectionAggregationCommand, + pipeline, + CommandCursor, + session=s, + explicit_session=session is not None, + let=let, + comment=comment, + **kwargs, + ) + + def aggregate_raw_batches( + self, + pipeline: _Pipeline, + session: Optional[ClientSession] = None, + comment: Optional[Any] = None, + **kwargs: Any, + ) -> RawBatchCursor[_DocumentType]: + """Perform an aggregation and retrieve batches of raw BSON. + + Similar to the :meth:`aggregate` method but returns a + :class:`~pymongo.cursor.RawBatchCursor`. + + This example demonstrates how to work with raw batches, but in practice + raw batches should be passed to an external library that can decode + BSON into another data type, rather than used with PyMongo's + :mod:`bson` module. + + >>> import bson + >>> cursor = db.test.aggregate_raw_batches([ + ... {'$project': {'x': {'$multiply': [2, '$x']}}}]) + >>> for batch in cursor: + ... print(bson.decode_all(batch)) + + .. note:: aggregate_raw_batches does not support auto encryption. + + .. versionchanged:: 3.12 + Added session support. + + .. versionadded:: 3.6 + """ + # OP_MSG is required to support encryption. + if self.__database.client._encrypter: + raise InvalidOperation("aggregate_raw_batches does not support auto encryption") + if comment is not None: + kwargs["comment"] = comment + with self.__database.client._tmp_session(session, close=False) as s: + return cast( + RawBatchCursor[_DocumentType], + self._aggregate( + _CollectionRawAggregationCommand, + pipeline, + RawBatchCommandCursor, + session=s, + explicit_session=session is not None, + **kwargs, + ), + ) + + def watch( + self, + pipeline: Optional[_Pipeline] = None, + full_document: Optional[str] = None, + resume_after: Optional[Mapping[str, Any]] = None, + max_await_time_ms: Optional[int] = None, + batch_size: Optional[int] = None, + collation: Optional[_CollationIn] = None, + start_at_operation_time: Optional[Timestamp] = None, + session: Optional[ClientSession] = None, + start_after: Optional[Mapping[str, Any]] = None, + comment: Optional[Any] = None, + full_document_before_change: Optional[str] = None, + show_expanded_events: Optional[bool] = None, + ) -> CollectionChangeStream[_DocumentType]: + """Watch changes on this collection. 
+ + Performs an aggregation with an implicit initial ``$changeStream`` + stage and returns a + :class:`~pymongo.change_stream.CollectionChangeStream` cursor which + iterates over changes on this collection. + + .. code-block:: python + + with db.collection.watch() as stream: + for change in stream: + print(change) + + The :class:`~pymongo.change_stream.CollectionChangeStream` iterable + blocks until the next change document is returned or an error is + raised. If the + :meth:`~pymongo.change_stream.CollectionChangeStream.next` method + encounters a network error when retrieving a batch from the server, + it will automatically attempt to recreate the cursor such that no + change events are missed. Any error encountered during the resume + attempt indicates there may be an outage and will be raised. + + .. code-block:: python + + try: + with db.collection.watch([{"$match": {"operationType": "insert"}}]) as stream: + for insert_change in stream: + print(insert_change) + except pymongo.errors.PyMongoError: + # The ChangeStream encountered an unrecoverable error or the + # resume attempt failed to recreate the cursor. + logging.error("...") + + For a precise description of the resume process see the + `change streams specification`_. + + .. note:: Using this helper method is preferred to directly calling + :meth:`~pymongo.collection.Collection.aggregate` with a + ``$changeStream`` stage, for the purpose of supporting + resumability. + + .. warning:: This Collection's :attr:`read_concern` must be + ``ReadConcern("majority")`` in order to use the ``$changeStream`` + stage. + + :param pipeline: A list of aggregation pipeline stages to + append to an initial ``$changeStream`` stage. Not all + pipeline stages are valid after a ``$changeStream`` stage, see the + MongoDB documentation on change streams for the supported stages. + :param full_document: The fullDocument to pass as an option + to the ``$changeStream`` stage. Allowed values: 'updateLookup', + 'whenAvailable', 'required'. When set to 'updateLookup', the + change notification for partial updates will include both a delta + describing the changes to the document, as well as a copy of the + entire document that was changed from some time after the change + occurred. + :param full_document_before_change: Allowed values: 'whenAvailable' + and 'required'. Change events may now result in a + 'fullDocumentBeforeChange' response field. + :param resume_after: A resume token. If provided, the + change stream will start returning changes that occur directly + after the operation specified in the resume token. A resume token + is the _id value of a change document. + :param max_await_time_ms: The maximum time in milliseconds + for the server to wait for changes before responding to a getMore + operation. + :param batch_size: The maximum number of documents to return + per batch. + :param collation: The :class:`~pymongo.collation.Collation` + to use for the aggregation. + :param start_at_operation_time: If provided, the resulting + change stream will only return changes that occurred at or after + the specified :class:`~bson.timestamp.Timestamp`. Requires + MongoDB >= 4.0. + :param session: a + :class:`~pymongo.client_session.ClientSession`. + :param start_after: The same as `resume_after` except that + `start_after` can resume notifications after an invalidate event. + This option and `resume_after` are mutually exclusive. + :param comment: A user-provided comment to attach to this + command. 
+ :param show_expanded_events: Include expanded events such as DDL events like `dropIndexes`. + + :return: A :class:`~pymongo.change_stream.CollectionChangeStream` cursor. + + .. versionchanged:: 4.3 + Added `show_expanded_events` parameter. + + .. versionchanged:: 4.2 + Added ``full_document_before_change`` parameter. + + .. versionchanged:: 4.1 + Added ``comment`` parameter. + + .. versionchanged:: 3.9 + Added the ``start_after`` parameter. + + .. versionchanged:: 3.7 + Added the ``start_at_operation_time`` parameter. + + .. versionadded:: 3.6 + + .. seealso:: The MongoDB documentation on `changeStreams `_. + + .. _change streams specification: + https://github.com/mongodb/specifications/blob/master/source/change-streams/change-streams.md + """ + return CollectionChangeStream( + self, + pipeline, + full_document, + resume_after, + max_await_time_ms, + batch_size, + collation, + start_at_operation_time, + session, + start_after, + comment, + full_document_before_change, + show_expanded_events, + ) + + @_csot.apply + def rename( + self, + new_name: str, + session: Optional[ClientSession] = None, + comment: Optional[Any] = None, + **kwargs: Any, + ) -> MutableMapping[str, Any]: + """Rename this collection. + + If operating in auth mode, client must be authorized as an + admin to perform this operation. Raises :class:`TypeError` if + `new_name` is not an instance of :class:`str`. + Raises :class:`~pymongo.errors.InvalidName` + if `new_name` is not a valid collection name. + + :param new_name: new name for this collection + :param session: a + :class:`~pymongo.client_session.ClientSession`. + :param comment: A user-provided comment to attach to this + command. + :param kwargs: additional arguments to the rename command + may be passed as keyword arguments to this helper method + (i.e. ``dropTarget=True``) + + .. note:: The :attr:`~pymongo.collection.Collection.write_concern` of + this collection is automatically applied to this operation. + + .. versionchanged:: 3.6 + Added ``session`` parameter. + + .. versionchanged:: 3.4 + Apply this collection's write concern automatically to this operation + when connected to MongoDB >= 3.4. + + """ + if not isinstance(new_name, str): + raise TypeError("new_name must be an instance of str") + + if not new_name or ".." in new_name: + raise InvalidName("collection names cannot be empty") + if new_name[0] == "." or new_name[-1] == ".": + raise InvalidName("collection names must not start or end with '.'") + if "$" in new_name and not new_name.startswith("oplog.$main"): + raise InvalidName("collection names must not contain '$'") + + new_name = f"{self.__database.name}.{new_name}" + cmd = {"renameCollection": self.__full_name, "to": new_name} + cmd.update(kwargs) + if comment is not None: + cmd["comment"] = comment + write_concern = self._write_concern_for_cmd(cmd, session) + + with self._conn_for_writes(session, operation=_Op.RENAME) as conn: + with self.__database.client._tmp_session(session) as s: + return conn.command( + "admin", + cmd, + write_concern=write_concern, + parse_write_concern_error=True, + session=s, + client=self.__database.client, + ) + + def distinct( + self, + key: str, + filter: Optional[Mapping[str, Any]] = None, + session: Optional[ClientSession] = None, + comment: Optional[Any] = None, + **kwargs: Any, + ) -> list: + """Get a list of distinct values for `key` among all documents + in this collection. + + Raises :class:`TypeError` if `key` is not an instance of + :class:`str`. 
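+
+        For example, if the collection contains the documents ``{'x': 1}``,
+        ``{'x': 1}`` and ``{'x': 2}`` (illustrative data), then:
+
+        >>> db.test.distinct('x')
+        [1, 2]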
+
+        All optional distinct parameters should be passed as keyword arguments
+        to this method. Valid options include:
+
+        - `maxTimeMS` (int): The maximum amount of time to allow the distinct
+          command to run, in milliseconds.
+        - `collation` (optional): An instance of
+          :class:`~pymongo.collation.Collation`.
+
+        The :meth:`distinct` method obeys the :attr:`read_preference` of
+        this :class:`Collection`.
+
+        :param key: name of the field for which we want to get the distinct
+            values
+        :param filter: A query document that specifies the documents
+            from which to retrieve the distinct values.
+        :param session: a
+            :class:`~pymongo.client_session.ClientSession`.
+        :param comment: A user-provided comment to attach to this
+            command.
+        :param kwargs: See list of options above.
+
+        .. versionchanged:: 3.6
+           Added ``session`` parameter.
+
+        .. versionchanged:: 3.4
+           Support the `collation` option.
+
+        """
+        if not isinstance(key, str):
+            raise TypeError("key must be an instance of str")
+        cmd = {"distinct": self.__name, "key": key}
+        if filter is not None:
+            if "query" in kwargs:
+                raise ConfigurationError("can't pass both filter and query")
+            kwargs["query"] = filter
+        collation = validate_collation_or_none(kwargs.pop("collation", None))
+        cmd.update(kwargs)
+        if comment is not None:
+            cmd["comment"] = comment
+
+        def _cmd(
+            session: Optional[ClientSession],
+            _server: Server,
+            conn: Connection,
+            read_preference: Optional[_ServerMode],
+        ) -> list:
+            return self._command(
+                conn,
+                cmd,
+                read_preference=read_preference,
+                read_concern=self.read_concern,
+                collation=collation,
+                session=session,
+                user_fields={"values": 1},
+            )["values"]
+
+        return self._retryable_non_cursor_read(_cmd, session, operation=_Op.DISTINCT)
+
+    def _write_concern_for_cmd(
+        self, cmd: Mapping[str, Any], session: Optional[ClientSession]
+    ) -> WriteConcern:
+        raw_wc = cmd.get("writeConcern")
+        if raw_wc is not None:
+            return WriteConcern(**raw_wc)
+        else:
+            return self._write_concern_for(session)
+
+    def __find_and_modify(
+        self,
+        filter: Mapping[str, Any],
+        projection: Optional[Union[Mapping[str, Any], Iterable[str]]],
+        sort: Optional[_IndexList],
+        upsert: Optional[bool] = None,
+        return_document: bool = ReturnDocument.BEFORE,
+        array_filters: Optional[Sequence[Mapping[str, Any]]] = None,
+        hint: Optional[_IndexKeyHint] = None,
+        session: Optional[ClientSession] = None,
+        let: Optional[Mapping] = None,
+        **kwargs: Any,
+    ) -> Any:
+        """Internal findAndModify helper."""
+        common.validate_is_mapping("filter", filter)
+        if not isinstance(return_document, bool):
+            raise ValueError(
+                "return_document must be ReturnDocument.BEFORE or ReturnDocument.AFTER"
+            )
+        collation = validate_collation_or_none(kwargs.pop("collation", None))
+        cmd = {"findAndModify": self.__name, "query": filter, "new": return_document}
+        if let is not None:
+            common.validate_is_mapping("let", let)
+            cmd["let"] = let
+        cmd.update(kwargs)
+        if projection is not None:
+            cmd["fields"] = helpers._fields_list_to_dict(projection, "projection")
+        if sort is not None:
+            cmd["sort"] = helpers._index_document(sort)
+        if upsert is not None:
+            validate_boolean("upsert", upsert)
+            cmd["upsert"] = upsert
+        if hint is not None:
+            if not isinstance(hint, str):
+                hint = helpers._index_document(hint)
+
+        write_concern = self._write_concern_for_cmd(cmd, session)
+
+        def _find_and_modify(
+            session: Optional[ClientSession], conn: Connection, retryable_write: bool
+        ) -> Any:
+            acknowledged = write_concern.acknowledged
+            if array_filters is not None:
+                if not
acknowledged: + raise ConfigurationError( + "arrayFilters is unsupported for unacknowledged writes." + ) + cmd["arrayFilters"] = list(array_filters) + if hint is not None: + if conn.max_wire_version < 8: + raise ConfigurationError( + "Must be connected to MongoDB 4.2+ to use hint on find and modify commands." + ) + elif not acknowledged and conn.max_wire_version < 9: + raise ConfigurationError( + "Must be connected to MongoDB 4.4+ to use hint on unacknowledged find and modify commands." + ) + cmd["hint"] = hint + out = self._command( + conn, + cmd, + read_preference=ReadPreference.PRIMARY, + write_concern=write_concern, + collation=collation, + session=session, + retryable_write=retryable_write, + user_fields=_FIND_AND_MODIFY_DOC_FIELDS, + ) + _check_write_command_response(out) + + return out.get("value") + + return self.__database.client._retryable_write( + write_concern.acknowledged, + _find_and_modify, + session, + operation=_Op.FIND_AND_MODIFY, + ) + + def find_one_and_delete( + self, + filter: Mapping[str, Any], + projection: Optional[Union[Mapping[str, Any], Iterable[str]]] = None, + sort: Optional[_IndexList] = None, + hint: Optional[_IndexKeyHint] = None, + session: Optional[ClientSession] = None, + let: Optional[Mapping[str, Any]] = None, + comment: Optional[Any] = None, + **kwargs: Any, + ) -> _DocumentType: + """Finds a single document and deletes it, returning the document. + + >>> db.test.count_documents({'x': 1}) + 2 + >>> db.test.find_one_and_delete({'x': 1}) + {'x': 1, '_id': ObjectId('54f4e12bfba5220aa4d6dee8')} + >>> db.test.count_documents({'x': 1}) + 1 + + If multiple documents match *filter*, a *sort* can be applied. + + >>> for doc in db.test.find({'x': 1}): + ... print(doc) + ... + {'x': 1, '_id': 0} + {'x': 1, '_id': 1} + {'x': 1, '_id': 2} + >>> db.test.find_one_and_delete( + ... {'x': 1}, sort=[('_id', pymongo.DESCENDING)]) + {'x': 1, '_id': 2} + + The *projection* option can be used to limit the fields returned. + + >>> db.test.find_one_and_delete({'x': 1}, projection={'_id': False}) + {'x': 1} + + :param filter: A query that matches the document to delete. + :param projection: a list of field names that should be + returned in the result document or a mapping specifying the fields + to include or exclude. If `projection` is a list "_id" will + always be returned. Use a mapping to exclude fields from + the result (e.g. projection={'_id': False}). + :param sort: a list of (key, direction) pairs + specifying the sort order for the query. If multiple documents + match the query, they are sorted and the first is deleted. + :param hint: An index to use to support the query predicate + specified either by its string name, or in the same format as + passed to :meth:`~pymongo.collection.Collection.create_index` + (e.g. ``[('field', ASCENDING)]``). This option is only supported + on MongoDB 4.4 and above. + :param session: a + :class:`~pymongo.client_session.ClientSession`. + :param let: Map of parameter names and values. Values must be + constant or closed expressions that do not reference document + fields. Parameters can then be accessed as variables in an + aggregate expression context (e.g. "$$var"). + :param comment: A user-provided comment to attach to this + command. + :param kwargs: additional command arguments can be passed + as keyword arguments (for example maxTimeMS can be used with + recent server versions). + + .. versionchanged:: 4.1 + Added ``let`` parameter. + .. versionchanged:: 3.11 + Added ``hint`` parameter. + .. 
versionchanged:: 3.6 + Added ``session`` parameter. + .. versionchanged:: 3.2 + Respects write concern. + + .. warning:: Starting in PyMongo 3.2, this command uses the + :class:`~pymongo.write_concern.WriteConcern` of this + :class:`~pymongo.collection.Collection` when connected to MongoDB >= + 3.2. Note that using an elevated write concern with this command may + be slower compared to using the default write concern. + + .. versionchanged:: 3.4 + Added the `collation` option. + .. versionadded:: 3.0 + """ + kwargs["remove"] = True + if comment is not None: + kwargs["comment"] = comment + return self.__find_and_modify( + filter, projection, sort, let=let, hint=hint, session=session, **kwargs + ) + + def find_one_and_replace( + self, + filter: Mapping[str, Any], + replacement: Mapping[str, Any], + projection: Optional[Union[Mapping[str, Any], Iterable[str]]] = None, + sort: Optional[_IndexList] = None, + upsert: bool = False, + return_document: bool = ReturnDocument.BEFORE, + hint: Optional[_IndexKeyHint] = None, + session: Optional[ClientSession] = None, + let: Optional[Mapping[str, Any]] = None, + comment: Optional[Any] = None, + **kwargs: Any, + ) -> _DocumentType: + """Finds a single document and replaces it, returning either the + original or the replaced document. + + The :meth:`find_one_and_replace` method differs from + :meth:`find_one_and_update` by replacing the document matched by + *filter*, rather than modifying the existing document. + + >>> for doc in db.test.find({}): + ... print(doc) + ... + {'x': 1, '_id': 0} + {'x': 1, '_id': 1} + {'x': 1, '_id': 2} + >>> db.test.find_one_and_replace({'x': 1}, {'y': 1}) + {'x': 1, '_id': 0} + >>> for doc in db.test.find({}): + ... print(doc) + ... + {'y': 1, '_id': 0} + {'x': 1, '_id': 1} + {'x': 1, '_id': 2} + + :param filter: A query that matches the document to replace. + :param replacement: The replacement document. + :param projection: A list of field names that should be + returned in the result document or a mapping specifying the fields + to include or exclude. If `projection` is a list "_id" will + always be returned. Use a mapping to exclude fields from + the result (e.g. projection={'_id': False}). + :param sort: a list of (key, direction) pairs + specifying the sort order for the query. If multiple documents + match the query, they are sorted and the first is replaced. + :param upsert: When ``True``, inserts a new document if no + document matches the query. Defaults to ``False``. + :param return_document: If + :attr:`ReturnDocument.BEFORE` (the default), + returns the original document before it was replaced, or ``None`` + if no document matches. If + :attr:`ReturnDocument.AFTER`, returns the replaced + or inserted document. + :param hint: An index to use to support the query + predicate specified either by its string name, or in the same + format as passed to + :meth:`~pymongo.collection.Collection.create_index` (e.g. + ``[('field', ASCENDING)]``). This option is only supported on + MongoDB 4.4 and above. + :param session: a + :class:`~pymongo.client_session.ClientSession`. + :param let: Map of parameter names and values. Values must be + constant or closed expressions that do not reference document + fields. Parameters can then be accessed as variables in an + aggregate expression context (e.g. "$$var"). + :param comment: A user-provided comment to attach to this + command. + :param kwargs: additional command arguments can be passed + as keyword arguments (for example maxTimeMS can be used with + recent server versions). 
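+
+        To receive the replacement document instead of the original, pass
+        ``return_document=ReturnDocument.AFTER`` (an illustrative example,
+        continuing the sample data shown above):
+
+        >>> from pymongo import ReturnDocument
+        >>> db.test.find_one_and_replace(
+        ...     {'x': 1}, {'y': 1}, return_document=ReturnDocument.AFTER)
+        {'y': 1, '_id': 1}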
+
+        .. versionchanged:: 4.1
+           Added ``let`` parameter.
+        .. versionchanged:: 3.11
+           Added the ``hint`` option.
+        .. versionchanged:: 3.6
+           Added ``session`` parameter.
+        .. versionchanged:: 3.4
+           Added the ``collation`` option.
+        .. versionchanged:: 3.2
+           Respects write concern.
+
+        .. warning:: Starting in PyMongo 3.2, this command uses the
+           :class:`~pymongo.write_concern.WriteConcern` of this
+           :class:`~pymongo.collection.Collection` when connected to MongoDB >=
+           3.2. Note that using an elevated write concern with this command may
+           be slower compared to using the default write concern.
+
+        .. versionadded:: 3.0
+        """
+        common.validate_ok_for_replace(replacement)
+        kwargs["update"] = replacement
+        if comment is not None:
+            kwargs["comment"] = comment
+        return self.__find_and_modify(
+            filter,
+            projection,
+            sort,
+            upsert,
+            return_document,
+            let=let,
+            hint=hint,
+            session=session,
+            **kwargs,
+        )
+
+    def find_one_and_update(
+        self,
+        filter: Mapping[str, Any],
+        update: Union[Mapping[str, Any], _Pipeline],
+        projection: Optional[Union[Mapping[str, Any], Iterable[str]]] = None,
+        sort: Optional[_IndexList] = None,
+        upsert: bool = False,
+        return_document: bool = ReturnDocument.BEFORE,
+        array_filters: Optional[Sequence[Mapping[str, Any]]] = None,
+        hint: Optional[_IndexKeyHint] = None,
+        session: Optional[ClientSession] = None,
+        let: Optional[Mapping[str, Any]] = None,
+        comment: Optional[Any] = None,
+        **kwargs: Any,
+    ) -> _DocumentType:
+        """Finds a single document and updates it, returning either the
+        original or the updated document.
+
+        >>> db.test.find_one_and_update(
+        ...    {'_id': 665}, {'$inc': {'count': 1}, '$set': {'done': True}})
+        {'_id': 665, 'done': False, 'count': 25}
+
+        Returns ``None`` if no document matches the filter.
+
+        >>> db.test.find_one_and_update(
+        ...    {'_exists': False}, {'$inc': {'count': 1}})
+
+        When the filter matches, by default :meth:`find_one_and_update`
+        returns the original version of the document before the update was
+        applied. To return the updated (or inserted in the case of
+        *upsert*) version of the document instead, use the *return_document*
+        option.
+
+        >>> from pymongo import ReturnDocument
+        >>> db.example.find_one_and_update(
+        ...     {'_id': 'userid'},
+        ...     {'$inc': {'seq': 1}},
+        ...     return_document=ReturnDocument.AFTER)
+        {'_id': 'userid', 'seq': 1}
+
+        You can limit the fields returned with the *projection* option.
+
+        >>> db.example.find_one_and_update(
+        ...     {'_id': 'userid'},
+        ...     {'$inc': {'seq': 1}},
+        ...     projection={'seq': True, '_id': False},
+        ...     return_document=ReturnDocument.AFTER)
+        {'seq': 2}
+
+        The *upsert* option can be used to create the document if it doesn't
+        already exist.
+
+        >>> db.example.delete_many({}).deleted_count
+        1
+        >>> db.example.find_one_and_update(
+        ...     {'_id': 'userid'},
+        ...     {'$inc': {'seq': 1}},
+        ...     projection={'seq': True, '_id': False},
+        ...     upsert=True,
+        ...     return_document=ReturnDocument.AFTER)
+        {'seq': 1}
+
+        If multiple documents match *filter*, a *sort* can be applied.
+
+        >>> for doc in db.test.find({'done': True}):
+        ...     print(doc)
+        ...
+        {'_id': 665, 'done': True, 'result': {'count': 26}}
+        {'_id': 701, 'done': True, 'result': {'count': 17}}
+        >>> db.test.find_one_and_update(
+        ...     {'done': True},
+        ...     {'$set': {'final': True}},
+        ...     sort=[('_id', pymongo.DESCENDING)])
+        {'_id': 701, 'done': True, 'result': {'count': 17}}
+
+        :param filter: A query that matches the document to update.
+        :param update: The update operations to apply.
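+            For example, a document of update operators such as
+            ``{'$set': {'done': True}}`` or, on MongoDB 4.2+, an aggregation
+            pipeline such as ``[{'$set': {'done': True}}]``.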
+ :param projection: A list of field names that should be + returned in the result document or a mapping specifying the fields + to include or exclude. If `projection` is a list "_id" will + always be returned. Use a dict to exclude fields from + the result (e.g. projection={'_id': False}). + :param sort: a list of (key, direction) pairs + specifying the sort order for the query. If multiple documents + match the query, they are sorted and the first is updated. + :param upsert: When ``True``, inserts a new document if no + document matches the query. Defaults to ``False``. + :param return_document: If + :attr:`ReturnDocument.BEFORE` (the default), + returns the original document before it was updated. If + :attr:`ReturnDocument.AFTER`, returns the updated + or inserted document. + :param array_filters: A list of filters specifying which + array elements an update should apply. + :param hint: An index to use to support the query + predicate specified either by its string name, or in the same + format as passed to + :meth:`~pymongo.collection.Collection.create_index` (e.g. + ``[('field', ASCENDING)]``). This option is only supported on + MongoDB 4.4 and above. + :param session: a + :class:`~pymongo.client_session.ClientSession`. + :param let: Map of parameter names and values. Values must be + constant or closed expressions that do not reference document + fields. Parameters can then be accessed as variables in an + aggregate expression context (e.g. "$$var"). + :param comment: A user-provided comment to attach to this + command. + :param kwargs: additional command arguments can be passed + as keyword arguments (for example maxTimeMS can be used with + recent server versions). + + .. versionchanged:: 3.11 + Added the ``hint`` option. + .. versionchanged:: 3.9 + Added the ability to accept a pipeline as the ``update``. + .. versionchanged:: 3.6 + Added the ``array_filters`` and ``session`` options. + .. versionchanged:: 3.4 + Added the ``collation`` option. + .. versionchanged:: 3.2 + Respects write concern. + + .. warning:: Starting in PyMongo 3.2, this command uses the + :class:`~pymongo.write_concern.WriteConcern` of this + :class:`~pymongo.collection.Collection` when connected to MongoDB >= + 3.2. Note that using an elevated write concern with this command may + be slower compared to using the default write concern. + + .. versionadded:: 3.0 + """ + common.validate_ok_for_update(update) + common.validate_list_or_none("array_filters", array_filters) + kwargs["update"] = update + if comment is not None: + kwargs["comment"] = comment + return self.__find_and_modify( + filter, + projection, + sort, + upsert, + return_document, + array_filters, + hint=hint, + let=let, + session=session, + **kwargs, + ) + + # See PYTHON-3084. + __iter__ = None + + def __next__(self) -> NoReturn: + raise TypeError("'Collection' object is not iterable") + + next = __next__ + + def __call__(self, *args: Any, **kwargs: Any) -> NoReturn: + """This is only here so that some API misusages are easier to debug.""" + if "." not in self.__name: + raise TypeError( + "'Collection' object is not callable. If you " + "meant to call the '%s' method on a 'Database' " + "object it is failing because no such method " + "exists." % self.__name + ) + raise TypeError( + "'Collection' object is not callable. If you meant to " + "call the '%s' method on a 'Collection' object it is " + "failing because no such method exists." 
% self.__name.split(".")[-1] + ) diff --git a/venv/Lib/site-packages/pymongo/command_cursor.py b/venv/Lib/site-packages/pymongo/command_cursor.py new file mode 100644 index 00000000..0411a45a --- /dev/null +++ b/venv/Lib/site-packages/pymongo/command_cursor.py @@ -0,0 +1,401 @@ +# Copyright 2014-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""CommandCursor class to iterate over command results.""" +from __future__ import annotations + +from collections import deque +from typing import ( + TYPE_CHECKING, + Any, + Generic, + Iterator, + Mapping, + NoReturn, + Optional, + Sequence, + Union, +) + +from bson import CodecOptions, _convert_raw_document_lists_to_streams +from pymongo.cursor import _CURSOR_CLOSED_ERRORS, _ConnectionManager +from pymongo.errors import ConnectionFailure, InvalidOperation, OperationFailure +from pymongo.message import _CursorAddress, _GetMore, _OpMsg, _OpReply, _RawBatchGetMore +from pymongo.response import PinnedResponse +from pymongo.typings import _Address, _DocumentOut, _DocumentType + +if TYPE_CHECKING: + from pymongo.client_session import ClientSession + from pymongo.collection import Collection + from pymongo.pool import Connection + + +class CommandCursor(Generic[_DocumentType]): + """A cursor / iterator over command cursors.""" + + _getmore_class = _GetMore + + def __init__( + self, + collection: Collection[_DocumentType], + cursor_info: Mapping[str, Any], + address: Optional[_Address], + batch_size: int = 0, + max_await_time_ms: Optional[int] = None, + session: Optional[ClientSession] = None, + explicit_session: bool = False, + comment: Any = None, + ) -> None: + """Create a new command cursor.""" + self.__sock_mgr: Any = None + self.__collection: Collection[_DocumentType] = collection + self.__id = cursor_info["id"] + self.__data = deque(cursor_info["firstBatch"]) + self.__postbatchresumetoken: Optional[Mapping[str, Any]] = cursor_info.get( + "postBatchResumeToken" + ) + self.__address = address + self.__batch_size = batch_size + self.__max_await_time_ms = max_await_time_ms + self.__session = session + self.__explicit_session = explicit_session + self.__killed = self.__id == 0 + self.__comment = comment + if self.__killed: + self.__end_session(True) + + if "ns" in cursor_info: # noqa: SIM401 + self.__ns = cursor_info["ns"] + else: + self.__ns = collection.full_name + + self.batch_size(batch_size) + + if not isinstance(max_await_time_ms, int) and max_await_time_ms is not None: + raise TypeError("max_await_time_ms must be an integer or None") + + def __del__(self) -> None: + self.__die() + + def __die(self, synchronous: bool = False) -> None: + """Closes this cursor.""" + already_killed = self.__killed + self.__killed = True + if self.__id and not already_killed: + cursor_id = self.__id + assert self.__address is not None + address = _CursorAddress(self.__address, self.__ns) + else: + # Skip killCursors. 
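+            # Nothing to kill server-side: either the cursor id is 0 (the
+            # server already exhausted it) or killCursors was handled before.
+            # _cleanup_cursor below still returns the implicit session and
+            # any pinned connection to their pools.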
+ cursor_id = 0 + address = None + self.__collection.database.client._cleanup_cursor( + synchronous, + cursor_id, + address, + self.__sock_mgr, + self.__session, + self.__explicit_session, + ) + if not self.__explicit_session: + self.__session = None + self.__sock_mgr = None + + def __end_session(self, synchronous: bool) -> None: + if self.__session and not self.__explicit_session: + self.__session._end_session(lock=synchronous) + self.__session = None + + def close(self) -> None: + """Explicitly close / kill this cursor.""" + self.__die(True) + + def batch_size(self, batch_size: int) -> CommandCursor[_DocumentType]: + """Limits the number of documents returned in one batch. Each batch + requires a round trip to the server. It can be adjusted to optimize + performance and limit data transfer. + + .. note:: batch_size can not override MongoDB's internal limits on the + amount of data it will return to the client in a single batch (i.e + if you set batch size to 1,000,000,000, MongoDB will currently only + return 4-16MB of results per batch). + + Raises :exc:`TypeError` if `batch_size` is not an integer. + Raises :exc:`ValueError` if `batch_size` is less than ``0``. + + :param batch_size: The size of each batch of results requested. + """ + if not isinstance(batch_size, int): + raise TypeError("batch_size must be an integer") + if batch_size < 0: + raise ValueError("batch_size must be >= 0") + + self.__batch_size = batch_size == 1 and 2 or batch_size + return self + + def _has_next(self) -> bool: + """Returns `True` if the cursor has documents remaining from the + previous batch. + """ + return len(self.__data) > 0 + + @property + def _post_batch_resume_token(self) -> Optional[Mapping[str, Any]]: + """Retrieve the postBatchResumeToken from the response to a + changeStream aggregate or getMore. + """ + return self.__postbatchresumetoken + + def _maybe_pin_connection(self, conn: Connection) -> None: + client = self.__collection.database.client + if not client._should_pin_cursor(self.__session): + return + if not self.__sock_mgr: + conn.pin_cursor() + conn_mgr = _ConnectionManager(conn, False) + # Ensure the connection gets returned when the entire result is + # returned in the first batch. + if self.__id == 0: + conn_mgr.close() + else: + self.__sock_mgr = conn_mgr + + def __send_message(self, operation: _GetMore) -> None: + """Send a getmore message and handle the response.""" + client = self.__collection.database.client + try: + response = client._run_operation( + operation, self._unpack_response, address=self.__address + ) + except OperationFailure as exc: + if exc.code in _CURSOR_CLOSED_ERRORS: + # Don't send killCursors because the cursor is already closed. + self.__killed = True + if exc.timeout: + self.__die(False) + else: + # Return the session and pinned connection, if necessary. + self.close() + raise + except ConnectionFailure: + # Don't send killCursors because the cursor is already closed. + self.__killed = True + # Return the session and pinned connection, if necessary. 
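+            # close() only releases local resources at this point: __killed is
+            # already True, so __die() skips the killCursors round trip.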
+ self.close() + raise + except Exception: + self.close() + raise + + if isinstance(response, PinnedResponse): + if not self.__sock_mgr: + self.__sock_mgr = _ConnectionManager(response.conn, response.more_to_come) + if response.from_command: + cursor = response.docs[0]["cursor"] + documents = cursor["nextBatch"] + self.__postbatchresumetoken = cursor.get("postBatchResumeToken") + self.__id = cursor["id"] + else: + documents = response.docs + assert isinstance(response.data, _OpReply) + self.__id = response.data.cursor_id + + if self.__id == 0: + self.close() + self.__data = deque(documents) + + def _unpack_response( + self, + response: Union[_OpReply, _OpMsg], + cursor_id: Optional[int], + codec_options: CodecOptions[Mapping[str, Any]], + user_fields: Optional[Mapping[str, Any]] = None, + legacy_response: bool = False, + ) -> Sequence[_DocumentOut]: + return response.unpack_response(cursor_id, codec_options, user_fields, legacy_response) + + def _refresh(self) -> int: + """Refreshes the cursor with more data from the server. + + Returns the length of self.__data after refresh. Will exit early if + self.__data is already non-empty. Raises OperationFailure when the + cursor cannot be refreshed due to an error on the query. + """ + if len(self.__data) or self.__killed: + return len(self.__data) + + if self.__id: # Get More + dbname, collname = self.__ns.split(".", 1) + read_pref = self.__collection._read_preference_for(self.session) + self.__send_message( + self._getmore_class( + dbname, + collname, + self.__batch_size, + self.__id, + self.__collection.codec_options, + read_pref, + self.__session, + self.__collection.database.client, + self.__max_await_time_ms, + self.__sock_mgr, + False, + self.__comment, + ) + ) + else: # Cursor id is zero nothing else to return + self.__die(True) + + return len(self.__data) + + @property + def alive(self) -> bool: + """Does this cursor have the potential to return more data? + + Even if :attr:`alive` is ``True``, :meth:`next` can raise + :exc:`StopIteration`. Best to use a for loop:: + + for doc in collection.aggregate(pipeline): + print(doc) + + .. note:: :attr:`alive` can be True while iterating a cursor from + a failed server. In this case :attr:`alive` will return False after + :meth:`next` fails to retrieve the next batch of results from the + server. + """ + return bool(len(self.__data) or (not self.__killed)) + + @property + def cursor_id(self) -> int: + """Returns the id of the cursor.""" + return self.__id + + @property + def address(self) -> Optional[_Address]: + """The (host, port) of the server used, or None. + + .. versionadded:: 3.0 + """ + return self.__address + + @property + def session(self) -> Optional[ClientSession]: + """The cursor's :class:`~pymongo.client_session.ClientSession`, or None. + + .. versionadded:: 3.6 + """ + if self.__explicit_session: + return self.__session + return None + + def __iter__(self) -> Iterator[_DocumentType]: + return self + + def next(self) -> _DocumentType: + """Advance the cursor.""" + # Block until a document is returnable. 
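+        # Each pass through the loop runs at most one getMore; a tailable
+        # await cursor may receive empty batches, so this can iterate more
+        # than once before a document becomes available.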
+ while self.alive: + doc = self._try_next(True) + if doc is not None: + return doc + + raise StopIteration + + __next__ = next + + def _try_next(self, get_more_allowed: bool) -> Optional[_DocumentType]: + """Advance the cursor blocking for at most one getMore command.""" + if not len(self.__data) and not self.__killed and get_more_allowed: + self._refresh() + if len(self.__data): + return self.__data.popleft() + else: + return None + + def try_next(self) -> Optional[_DocumentType]: + """Advance the cursor without blocking indefinitely. + + This method returns the next document without waiting + indefinitely for data. + + If no document is cached locally then this method runs a single + getMore command. If the getMore yields any documents, the next + document is returned, otherwise, if the getMore returns no documents + (because there is no additional data) then ``None`` is returned. + + :return: The next document or ``None`` when no document is available + after running a single getMore or when the cursor is closed. + + .. versionadded:: 4.5 + """ + return self._try_next(get_more_allowed=True) + + def __enter__(self) -> CommandCursor[_DocumentType]: + return self + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + self.close() + + +class RawBatchCommandCursor(CommandCursor, Generic[_DocumentType]): + _getmore_class = _RawBatchGetMore + + def __init__( + self, + collection: Collection[_DocumentType], + cursor_info: Mapping[str, Any], + address: Optional[_Address], + batch_size: int = 0, + max_await_time_ms: Optional[int] = None, + session: Optional[ClientSession] = None, + explicit_session: bool = False, + comment: Any = None, + ) -> None: + """Create a new cursor / iterator over raw batches of BSON data. + + Should not be called directly by application developers - + see :meth:`~pymongo.collection.Collection.aggregate_raw_batches` + instead. + + .. seealso:: The MongoDB documentation on `cursors `_. + """ + assert not cursor_info.get("firstBatch") + super().__init__( + collection, + cursor_info, + address, + batch_size, + max_await_time_ms, + session, + explicit_session, + comment, + ) + + def _unpack_response( # type: ignore[override] + self, + response: Union[_OpReply, _OpMsg], + cursor_id: Optional[int], + codec_options: CodecOptions, + user_fields: Optional[Mapping[str, Any]] = None, + legacy_response: bool = False, + ) -> list[Mapping[str, Any]]: + raw_response = response.raw_response(cursor_id, user_fields=user_fields) + if not legacy_response: + # OP_MSG returns firstBatch/nextBatch documents as a BSON array + # Re-assemble the array of documents into a document stream + _convert_raw_document_lists_to_streams(raw_response[0]) + return raw_response # type: ignore[return-value] + + def __getitem__(self, index: int) -> NoReturn: + raise InvalidOperation("Cannot call __getitem__ on RawBatchCursor") diff --git a/venv/Lib/site-packages/pymongo/common.py b/venv/Lib/site-packages/pymongo/common.py new file mode 100644 index 00000000..7f1245b7 --- /dev/null +++ b/venv/Lib/site-packages/pymongo/common.py @@ -0,0 +1,1055 @@ +# Copyright 2011-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you +# may not use this file except in compliance with the License. 
You +# may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. + + +"""Functions and classes common to multiple pymongo modules.""" +from __future__ import annotations + +import datetime +import warnings +from collections import OrderedDict, abc +from difflib import get_close_matches +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Iterator, + Mapping, + MutableMapping, + NoReturn, + Optional, + Sequence, + Type, + Union, + overload, +) +from urllib.parse import unquote_plus + +from bson import SON +from bson.binary import UuidRepresentation +from bson.codec_options import CodecOptions, DatetimeConversion, TypeRegistry +from bson.raw_bson import RawBSONDocument +from pymongo.auth import MECHANISMS +from pymongo.auth_oidc import OIDCCallback +from pymongo.compression_support import ( + validate_compressors, + validate_zlib_compression_level, +) +from pymongo.driver_info import DriverInfo +from pymongo.errors import ConfigurationError +from pymongo.monitoring import _validate_event_listeners +from pymongo.read_concern import ReadConcern +from pymongo.read_preferences import _MONGOS_MODES, _ServerMode +from pymongo.server_api import ServerApi +from pymongo.write_concern import DEFAULT_WRITE_CONCERN, WriteConcern, validate_boolean + +if TYPE_CHECKING: + from pymongo.client_session import ClientSession + +ORDERED_TYPES: Sequence[Type] = (SON, OrderedDict) + +# Defaults until we connect to a server and get updated limits. +MAX_BSON_SIZE = 16 * (1024**2) +MAX_MESSAGE_SIZE: int = 2 * MAX_BSON_SIZE +MIN_WIRE_VERSION = 0 +MAX_WIRE_VERSION = 0 +MAX_WRITE_BATCH_SIZE = 1000 + +# What this version of PyMongo supports. +MIN_SUPPORTED_SERVER_VERSION = "3.6" +MIN_SUPPORTED_WIRE_VERSION = 6 +MAX_SUPPORTED_WIRE_VERSION = 21 + +# Frequency to call hello on servers, in seconds. +HEARTBEAT_FREQUENCY = 10 + +# Frequency to clean up unclosed cursors, in seconds. +# See MongoClient._process_kill_cursors. +KILL_CURSOR_FREQUENCY = 1 + +# Frequency to process events queue, in seconds. +EVENTS_QUEUE_FREQUENCY = 1 + +# How long to wait, in seconds, for a suitable server to be found before +# aborting an operation. For example, if the client attempts an insert +# during a replica set election, SERVER_SELECTION_TIMEOUT governs the +# longest it is willing to wait for a new primary to be found. +SERVER_SELECTION_TIMEOUT = 30 + +# Spec requires at least 500ms between hello calls. +MIN_HEARTBEAT_INTERVAL = 0.5 + +# Spec requires at least 60s between SRV rescans. +MIN_SRV_RESCAN_INTERVAL = 60 + +# Default connectTimeout in seconds. +CONNECT_TIMEOUT = 20.0 + +# Default value for maxPoolSize. +MAX_POOL_SIZE = 100 + +# Default value for minPoolSize. +MIN_POOL_SIZE = 0 + +# The maximum number of concurrent connection creation attempts per pool. +MAX_CONNECTING = 2 + +# Default value for maxIdleTimeMS. +MAX_IDLE_TIME_MS: Optional[int] = None + +# Default value for maxIdleTimeMS in seconds. +MAX_IDLE_TIME_SEC: Optional[int] = None + +# Default value for waitQueueTimeoutMS in seconds. +WAIT_QUEUE_TIMEOUT: Optional[int] = None + +# Default value for localThresholdMS. +LOCAL_THRESHOLD_MS = 15 + +# Default value for retryWrites. 
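+# Note: writes are only retried against replica sets and sharded clusters
+# (MongoDB 3.6+); the option has no effect on standalone servers.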
+RETRY_WRITES = True + +# Default value for retryReads. +RETRY_READS = True + +# The error code returned when a command doesn't exist. +COMMAND_NOT_FOUND_CODES: Sequence[int] = (59,) + +# Error codes to ignore if GridFS calls createIndex on a secondary +UNAUTHORIZED_CODES: Sequence[int] = (13, 16547, 16548) + +# Maximum number of sessions to send in a single endSessions command. +# From the driver sessions spec. +_MAX_END_SESSIONS = 10000 + +# Default value for srvServiceName +SRV_SERVICE_NAME = "mongodb" + +# Default value for serverMonitoringMode +SERVER_MONITORING_MODE = "auto" # poll/stream/auto + + +def partition_node(node: str) -> tuple[str, int]: + """Split a host:port string into (host, int(port)) pair.""" + host = node + port = 27017 + idx = node.rfind(":") + if idx != -1: + host, port = node[:idx], int(node[idx + 1 :]) + if host.startswith("["): + host = host[1:-1] + return host, port + + +def clean_node(node: str) -> tuple[str, int]: + """Split and normalize a node name from a hello response.""" + host, port = partition_node(node) + + # Normalize hostname to lowercase, since DNS is case-insensitive: + # http://tools.ietf.org/html/rfc4343 + # This prevents useless rediscovery if "foo.com" is in the seed list but + # "FOO.com" is in the hello response. + return host.lower(), port + + +def raise_config_error(key: str, suggestions: Optional[list] = None) -> NoReturn: + """Raise ConfigurationError with the given key name.""" + msg = f"Unknown option: {key}." + if suggestions: + msg += f" Did you mean one of ({', '.join(suggestions)}) or maybe a camelCase version of one? Refer to docstring." + raise ConfigurationError(msg) + + +# Mapping of URI uuid representation options to valid subtypes. +_UUID_REPRESENTATIONS = { + "unspecified": UuidRepresentation.UNSPECIFIED, + "standard": UuidRepresentation.STANDARD, + "pythonLegacy": UuidRepresentation.PYTHON_LEGACY, + "javaLegacy": UuidRepresentation.JAVA_LEGACY, + "csharpLegacy": UuidRepresentation.CSHARP_LEGACY, +} + + +def validate_boolean_or_string(option: str, value: Any) -> bool: + """Validates that value is True, False, 'true', or 'false'.""" + if isinstance(value, str): + if value not in ("true", "false"): + raise ValueError(f"The value of {option} must be 'true' or 'false'") + return value == "true" + return validate_boolean(option, value) + + +def validate_integer(option: str, value: Any) -> int: + """Validates that 'value' is an integer (or basestring representation).""" + if isinstance(value, int): + return value + elif isinstance(value, str): + try: + return int(value) + except ValueError: + raise ValueError(f"The value of {option} must be an integer") from None + raise TypeError(f"Wrong type for {option}, value must be an integer") + + +def validate_positive_integer(option: str, value: Any) -> int: + """Validate that 'value' is a positive integer, which does not include 0.""" + val = validate_integer(option, value) + if val <= 0: + raise ValueError(f"The value of {option} must be a positive integer") + return val + + +def validate_non_negative_integer(option: str, value: Any) -> int: + """Validate that 'value' is a positive integer or 0.""" + val = validate_integer(option, value) + if val < 0: + raise ValueError(f"The value of {option} must be a non negative integer") + return val + + +def validate_readable(option: str, value: Any) -> Optional[str]: + """Validates that 'value' is file-like and readable.""" + if value is None: + return value + # First make sure its a string py3.3 open(True, 'r') succeeds + # Used in ssl cert 
checking due to poor ssl module error reporting + value = validate_string(option, value) + open(value).close() + return value + + +def validate_positive_integer_or_none(option: str, value: Any) -> Optional[int]: + """Validate that 'value' is a positive integer or None.""" + if value is None: + return value + return validate_positive_integer(option, value) + + +def validate_non_negative_integer_or_none(option: str, value: Any) -> Optional[int]: + """Validate that 'value' is a positive integer or 0 or None.""" + if value is None: + return value + return validate_non_negative_integer(option, value) + + +def validate_string(option: str, value: Any) -> str: + """Validates that 'value' is an instance of `str`.""" + if isinstance(value, str): + return value + raise TypeError(f"Wrong type for {option}, value must be an instance of str") + + +def validate_string_or_none(option: str, value: Any) -> Optional[str]: + """Validates that 'value' is an instance of `basestring` or `None`.""" + if value is None: + return value + return validate_string(option, value) + + +def validate_int_or_basestring(option: str, value: Any) -> Union[int, str]: + """Validates that 'value' is an integer or string.""" + if isinstance(value, int): + return value + elif isinstance(value, str): + try: + return int(value) + except ValueError: + return value + raise TypeError(f"Wrong type for {option}, value must be an integer or a string") + + +def validate_non_negative_int_or_basestring(option: Any, value: Any) -> Union[int, str]: + """Validates that 'value' is an integer or string.""" + if isinstance(value, int): + return value + elif isinstance(value, str): + try: + val = int(value) + except ValueError: + return value + return validate_non_negative_integer(option, val) + raise TypeError(f"Wrong type for {option}, value must be an non negative integer or a string") + + +def validate_positive_float(option: str, value: Any) -> float: + """Validates that 'value' is a float, or can be converted to one, and is + positive. + """ + errmsg = f"{option} must be an integer or float" + try: + value = float(value) + except ValueError: + raise ValueError(errmsg) from None + except TypeError: + raise TypeError(errmsg) from None + + # float('inf') doesn't work in 2.4 or 2.5 on Windows, so just cap floats at + # one billion - this is a reasonable approximation for infinity + if not 0 < value < 1e9: + raise ValueError(f"{option} must be greater than 0 and less than one billion") + return value + + +def validate_positive_float_or_zero(option: str, value: Any) -> float: + """Validates that 'value' is 0 or a positive float, or can be converted to + 0 or a positive float. + """ + if value == 0 or value == "0": + return 0 + return validate_positive_float(option, value) + + +def validate_timeout_or_none(option: str, value: Any) -> Optional[float]: + """Validates a timeout specified in milliseconds returning + a value in floating point seconds. + """ + if value is None: + return value + return validate_positive_float(option, value) / 1000.0 + + +def validate_timeout_or_zero(option: str, value: Any) -> float: + """Validates a timeout specified in milliseconds returning + a value in floating point seconds for the case where None is an error + and 0 is valid. Setting the timeout to nothing in the URI string is a + config error. 
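+    For example, a value of ``500`` (milliseconds) validates to ``0.5``
+    (seconds), while ``0`` and ``"0"`` both validate to ``0``.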
+ """ + if value is None: + raise ConfigurationError(f"{option} cannot be None") + if value == 0 or value == "0": + return 0 + return validate_positive_float(option, value) / 1000.0 + + +def validate_timeout_or_none_or_zero(option: Any, value: Any) -> Optional[float]: + """Validates a timeout specified in milliseconds returning + a value in floating point seconds. value=0 and value="0" are treated the + same as value=None which means unlimited timeout. + """ + if value is None or value == 0 or value == "0": + return None + return validate_positive_float(option, value) / 1000.0 + + +def validate_timeoutms(option: Any, value: Any) -> Optional[float]: + """Validates a timeout specified in milliseconds returning + a value in floating point seconds. + """ + if value is None: + return None + return validate_positive_float_or_zero(option, value) / 1000.0 + + +def validate_max_staleness(option: str, value: Any) -> int: + """Validates maxStalenessSeconds according to the Max Staleness Spec.""" + if value == -1 or value == "-1": + # Default: No maximum staleness. + return -1 + return validate_positive_integer(option, value) + + +def validate_read_preference(dummy: Any, value: Any) -> _ServerMode: + """Validate a read preference.""" + if not isinstance(value, _ServerMode): + raise TypeError(f"{value!r} is not a read preference.") + return value + + +def validate_read_preference_mode(dummy: Any, value: Any) -> _ServerMode: + """Validate read preference mode for a MongoClient. + + .. versionchanged:: 3.5 + Returns the original ``value`` instead of the validated read preference + mode. + """ + if value not in _MONGOS_MODES: + raise ValueError(f"{value} is not a valid read preference") + return value + + +def validate_auth_mechanism(option: str, value: Any) -> str: + """Validate the authMechanism URI option.""" + if value not in MECHANISMS: + raise ValueError(f"{option} must be in {tuple(MECHANISMS)}") + return value + + +def validate_uuid_representation(dummy: Any, value: Any) -> int: + """Validate the uuid representation option selected in the URI.""" + try: + return _UUID_REPRESENTATIONS[value] + except KeyError: + raise ValueError( + f"{value} is an invalid UUID representation. 
" + "Must be one of " + f"{tuple(_UUID_REPRESENTATIONS)}" + ) from None + + +def validate_read_preference_tags(name: str, value: Any) -> list[dict[str, str]]: + """Parse readPreferenceTags if passed as a client kwarg.""" + if not isinstance(value, list): + value = [value] + + tag_sets: list = [] + for tag_set in value: + if tag_set == "": + tag_sets.append({}) + continue + try: + tags = {} + for tag in tag_set.split(","): + key, val = tag.split(":") + tags[unquote_plus(key)] = unquote_plus(val) + tag_sets.append(tags) + except Exception: + raise ValueError(f"{tag_set!r} not a valid value for {name}") from None + return tag_sets + + +_MECHANISM_PROPS = frozenset( + [ + "SERVICE_NAME", + "CANONICALIZE_HOST_NAME", + "SERVICE_REALM", + "AWS_SESSION_TOKEN", + "ENVIRONMENT", + "TOKEN_RESOURCE", + ] +) + + +def validate_auth_mechanism_properties(option: str, value: Any) -> dict[str, Union[bool, str]]: + """Validate authMechanismProperties.""" + props: dict[str, Any] = {} + if not isinstance(value, str): + if not isinstance(value, dict): + raise ValueError("Auth mechanism properties must be given as a string or a dictionary") + for key, value in value.items(): # noqa: B020 + if isinstance(value, str): + props[key] = value + elif isinstance(value, bool): + props[key] = str(value).lower() + elif key in ["ALLOWED_HOSTS"] and isinstance(value, list): + props[key] = value + elif key in ["OIDC_CALLBACK", "OIDC_HUMAN_CALLBACK"]: + if not isinstance(value, OIDCCallback): + raise ValueError("callback must be an OIDCCallback object") + props[key] = value + else: + raise ValueError(f"Invalid type for auth mechanism property {key}, {type(value)}") + return props + + value = validate_string(option, value) + for opt in value.split(","): + key, _, val = opt.partition(":") + if key not in _MECHANISM_PROPS: + # Try not to leak the token. + if "AWS_SESSION_TOKEN" in key: + raise ValueError( + "auth mechanism properties must be " + "key:value pairs like AWS_SESSION_TOKEN:" + ) + + raise ValueError( + f"{key} is not a supported auth " + "mechanism property. Must be one of " + f"{tuple(_MECHANISM_PROPS)}." + ) + + if key == "CANONICALIZE_HOST_NAME": + props[key] = validate_boolean_or_string(key, val) + else: + props[key] = unquote_plus(val) + + return props + + +def validate_document_class( + option: str, value: Any +) -> Union[Type[MutableMapping], Type[RawBSONDocument]]: + """Validate the document_class option.""" + # issubclass can raise TypeError for generic aliases like SON[str, Any]. + # In that case we can use the base class for the comparison. 
+ is_mapping = False + try: + is_mapping = issubclass(value, abc.MutableMapping) + except TypeError: + if hasattr(value, "__origin__"): + is_mapping = issubclass(value.__origin__, abc.MutableMapping) + if not is_mapping and not issubclass(value, RawBSONDocument): + raise TypeError( + f"{option} must be dict, bson.son.SON, " + "bson.raw_bson.RawBSONDocument, or a " + "subclass of collections.MutableMapping" + ) + return value + + +def validate_type_registry(option: Any, value: Any) -> Optional[TypeRegistry]: + """Validate the type_registry option.""" + if value is not None and not isinstance(value, TypeRegistry): + raise TypeError(f"{option} must be an instance of {TypeRegistry}") + return value + + +def validate_list(option: str, value: Any) -> list: + """Validates that 'value' is a list.""" + if not isinstance(value, list): + raise TypeError(f"{option} must be a list") + return value + + +def validate_list_or_none(option: Any, value: Any) -> Optional[list]: + """Validates that 'value' is a list or None.""" + if value is None: + return value + return validate_list(option, value) + + +def validate_list_or_mapping(option: Any, value: Any) -> None: + """Validates that 'value' is a list or a document.""" + if not isinstance(value, (abc.Mapping, list)): + raise TypeError( + f"{option} must either be a list or an instance of dict, " + "bson.son.SON, or any other type that inherits from " + "collections.Mapping" + ) + + +def validate_is_mapping(option: str, value: Any) -> None: + """Validate the type of method arguments that expect a document.""" + if not isinstance(value, abc.Mapping): + raise TypeError( + f"{option} must be an instance of dict, bson.son.SON, or " + "any other type that inherits from " + "collections.Mapping" + ) + + +def validate_is_document_type(option: str, value: Any) -> None: + """Validate the type of method arguments that expect a MongoDB document.""" + if not isinstance(value, (abc.MutableMapping, RawBSONDocument)): + raise TypeError( + f"{option} must be an instance of dict, bson.son.SON, " + "bson.raw_bson.RawBSONDocument, or " + "a type that inherits from " + "collections.MutableMapping" + ) + + +def validate_appname_or_none(option: str, value: Any) -> Optional[str]: + """Validate the appname option.""" + if value is None: + return value + validate_string(option, value) + # We need length in bytes, so encode utf8 first. 
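+    # (a single non-ASCII character can occupy several bytes, e.g. "€" is
+    # three bytes in UTF-8, so len(value) alone would under-count)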
+ if len(value.encode("utf-8")) > 128: + raise ValueError(f"{option} must be <= 128 bytes") + return value + + +def validate_driver_or_none(option: Any, value: Any) -> Optional[DriverInfo]: + """Validate the driver keyword arg.""" + if value is None: + return value + if not isinstance(value, DriverInfo): + raise TypeError(f"{option} must be an instance of DriverInfo") + return value + + +def validate_server_api_or_none(option: Any, value: Any) -> Optional[ServerApi]: + """Validate the server_api keyword arg.""" + if value is None: + return value + if not isinstance(value, ServerApi): + raise TypeError(f"{option} must be an instance of ServerApi") + return value + + +def validate_is_callable_or_none(option: Any, value: Any) -> Optional[Callable]: + """Validates that 'value' is a callable.""" + if value is None: + return value + if not callable(value): + raise ValueError(f"{option} must be a callable") + return value + + +def validate_ok_for_replace(replacement: Mapping[str, Any]) -> None: + """Validate a replacement document.""" + validate_is_mapping("replacement", replacement) + # Replacement can be {} + if replacement and not isinstance(replacement, RawBSONDocument): + first = next(iter(replacement)) + if first.startswith("$"): + raise ValueError("replacement can not include $ operators") + + +def validate_ok_for_update(update: Any) -> None: + """Validate an update document.""" + validate_list_or_mapping("update", update) + # Update cannot be {}. + if not update: + raise ValueError("update cannot be empty") + + is_document = not isinstance(update, list) + first = next(iter(update)) + if is_document and not first.startswith("$"): + raise ValueError("update only works with $ operators") + + +_UNICODE_DECODE_ERROR_HANDLERS = frozenset(["strict", "replace", "ignore"]) + + +def validate_unicode_decode_error_handler(dummy: Any, value: str) -> str: + """Validate the Unicode decode error handler option of CodecOptions.""" + if value not in _UNICODE_DECODE_ERROR_HANDLERS: + raise ValueError( + f"{value} is an invalid Unicode decode error handler. " + "Must be one of " + f"{tuple(_UNICODE_DECODE_ERROR_HANDLERS)}" + ) + return value + + +def validate_tzinfo(dummy: Any, value: Any) -> Optional[datetime.tzinfo]: + """Validate the tzinfo option""" + if value is not None and not isinstance(value, datetime.tzinfo): + raise TypeError("%s must be an instance of datetime.tzinfo" % value) + return value + + +def validate_auto_encryption_opts_or_none(option: Any, value: Any) -> Optional[Any]: + """Validate the driver keyword arg.""" + if value is None: + return value + from pymongo.encryption_options import AutoEncryptionOpts + + if not isinstance(value, AutoEncryptionOpts): + raise TypeError(f"{option} must be an instance of AutoEncryptionOpts") + + return value + + +def validate_datetime_conversion(option: Any, value: Any) -> Optional[DatetimeConversion]: + """Validate a DatetimeConversion string.""" + if value is None: + return DatetimeConversion.DATETIME + + if isinstance(value, str): + if value.isdigit(): + return DatetimeConversion(int(value)) + return DatetimeConversion[value] + elif isinstance(value, int): + return DatetimeConversion(value) + + raise TypeError(f"{option} must be a str or int representing DatetimeConversion") + + +def validate_server_monitoring_mode(option: str, value: str) -> str: + """Validate the serverMonitoringMode option.""" + if value not in {"auto", "stream", "poll"}: + raise ValueError( + f'{option}={value!r} is invalid. 
Must be one of "auto", "stream", or "poll"' + ) + return value + + +# Dictionary where keys are the names of public URI options, and values +# are lists of aliases for that option. +URI_OPTIONS_ALIAS_MAP: dict[str, list[str]] = { + "tls": ["ssl"], +} + +# Dictionary where keys are the names of URI options, and values +# are functions that validate user-input values for that option. If an option +# alias uses a different validator than its public counterpart, it should be +# included here as a key, value pair. +URI_OPTIONS_VALIDATOR_MAP: dict[str, Callable[[Any, Any], Any]] = { + "appname": validate_appname_or_none, + "authmechanism": validate_auth_mechanism, + "authmechanismproperties": validate_auth_mechanism_properties, + "authsource": validate_string, + "compressors": validate_compressors, + "connecttimeoutms": validate_timeout_or_none_or_zero, + "directconnection": validate_boolean_or_string, + "heartbeatfrequencyms": validate_timeout_or_none, + "journal": validate_boolean_or_string, + "localthresholdms": validate_positive_float_or_zero, + "maxidletimems": validate_timeout_or_none, + "maxconnecting": validate_positive_integer, + "maxpoolsize": validate_non_negative_integer_or_none, + "maxstalenessseconds": validate_max_staleness, + "readconcernlevel": validate_string_or_none, + "readpreference": validate_read_preference_mode, + "readpreferencetags": validate_read_preference_tags, + "replicaset": validate_string_or_none, + "retryreads": validate_boolean_or_string, + "retrywrites": validate_boolean_or_string, + "loadbalanced": validate_boolean_or_string, + "serverselectiontimeoutms": validate_timeout_or_zero, + "sockettimeoutms": validate_timeout_or_none_or_zero, + "tls": validate_boolean_or_string, + "tlsallowinvalidcertificates": validate_boolean_or_string, + "tlsallowinvalidhostnames": validate_boolean_or_string, + "tlscafile": validate_readable, + "tlscertificatekeyfile": validate_readable, + "tlscertificatekeyfilepassword": validate_string_or_none, + "tlsdisableocspendpointcheck": validate_boolean_or_string, + "tlsinsecure": validate_boolean_or_string, + "w": validate_non_negative_int_or_basestring, + "wtimeoutms": validate_non_negative_integer, + "zlibcompressionlevel": validate_zlib_compression_level, + "srvservicename": validate_string, + "srvmaxhosts": validate_non_negative_integer, + "timeoutms": validate_timeoutms, + "servermonitoringmode": validate_server_monitoring_mode, +} + +# Dictionary where keys are the names of URI options specific to pymongo, +# and values are functions that validate user-input values for those options. +NONSPEC_OPTIONS_VALIDATOR_MAP: dict[str, Callable[[Any, Any], Any]] = { + "connect": validate_boolean_or_string, + "driver": validate_driver_or_none, + "server_api": validate_server_api_or_none, + "fsync": validate_boolean_or_string, + "minpoolsize": validate_non_negative_integer, + "tlscrlfile": validate_readable, + "tz_aware": validate_boolean_or_string, + "unicode_decode_error_handler": validate_unicode_decode_error_handler, + "uuidrepresentation": validate_uuid_representation, + "waitqueuemultiple": validate_non_negative_integer_or_none, + "waitqueuetimeoutms": validate_timeout_or_none, + "datetime_conversion": validate_datetime_conversion, +} + +# Dictionary where keys are the names of keyword-only options for the +# MongoClient constructor, and values are functions that validate user-input +# values for those options. 
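+# For example, MongoClient(document_class=RawBSONDocument) is checked by
+# validate_document_class before the client stores the option.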
+KW_VALIDATORS: dict[str, Callable[[Any, Any], Any]] = { + "document_class": validate_document_class, + "type_registry": validate_type_registry, + "read_preference": validate_read_preference, + "event_listeners": _validate_event_listeners, + "tzinfo": validate_tzinfo, + "username": validate_string_or_none, + "password": validate_string_or_none, + "server_selector": validate_is_callable_or_none, + "auto_encryption_opts": validate_auto_encryption_opts_or_none, + "authoidcallowedhosts": validate_list, +} + +# Dictionary where keys are any URI option name, and values are the +# internally-used names of that URI option. Options with only one name +# variant need not be included here. Options whose public and internal +# names are the same need not be included here. +INTERNAL_URI_OPTION_NAME_MAP: dict[str, str] = { + "ssl": "tls", +} + +# Map from deprecated URI option names to a tuple indicating the method of +# their deprecation and any additional information that may be needed to +# construct the warning message. +URI_OPTIONS_DEPRECATION_MAP: dict[str, tuple[str, str]] = { + # format: : (, ), + # Supported values: + # - 'renamed': should be the new option name. Note that case is + # preserved for renamed options as they are part of user warnings. + # - 'removed': may suggest the rationale for deprecating the + # option and/or recommend remedial action. + # For example: + # 'wtimeout': ('renamed', 'wTimeoutMS'), +} + +# Augment the option validator map with pymongo-specific option information. +URI_OPTIONS_VALIDATOR_MAP.update(NONSPEC_OPTIONS_VALIDATOR_MAP) +for optname, aliases in URI_OPTIONS_ALIAS_MAP.items(): + for alias in aliases: + if alias not in URI_OPTIONS_VALIDATOR_MAP: + URI_OPTIONS_VALIDATOR_MAP[alias] = URI_OPTIONS_VALIDATOR_MAP[optname] + +# Map containing all URI option and keyword argument validators. +VALIDATORS: dict[str, Callable[[Any, Any], Any]] = URI_OPTIONS_VALIDATOR_MAP.copy() +VALIDATORS.update(KW_VALIDATORS) + +# List of timeout-related options. +TIMEOUT_OPTIONS: list[str] = [ + "connecttimeoutms", + "heartbeatfrequencyms", + "maxidletimems", + "maxstalenessseconds", + "serverselectiontimeoutms", + "sockettimeoutms", + "waitqueuetimeoutms", +] + +_AUTH_OPTIONS = frozenset(["authmechanismproperties"]) + + +def validate_auth_option(option: str, value: Any) -> tuple[str, Any]: + """Validate optional authentication parameters.""" + lower, value = validate(option, value) + if lower not in _AUTH_OPTIONS: + raise ConfigurationError(f"Unknown option: {option}. Must be in {_AUTH_OPTIONS}") + return option, value + + +def _get_validator( + key: str, validators: dict[str, Callable[[Any, Any], Any]], normed_key: Optional[str] = None +) -> Callable: + normed_key = normed_key or key + try: + return validators[normed_key] + except KeyError: + suggestions = get_close_matches(normed_key, validators, cutoff=0.2) + raise_config_error(key, suggestions) + + +def validate(option: str, value: Any) -> tuple[str, Any]: + """Generic validation function.""" + validator = _get_validator(option, VALIDATORS, normed_key=option.lower()) + value = validator(option, value) + return option, value + + +def get_validated_options( + options: Mapping[str, Any], warn: bool = True +) -> MutableMapping[str, Any]: + """Validate each entry in options and raise a warning if it is not valid. + Returns a copy of options with invalid entries removed. + + :param opts: A dict containing MongoDB URI options. + :param warn: If ``True`` then warnings will be logged and + invalid options will be ignored. 
Otherwise, invalid options will + cause errors. + """ + validated_options: MutableMapping[str, Any] + if isinstance(options, _CaseInsensitiveDictionary): + validated_options = _CaseInsensitiveDictionary() + + def get_normed_key(x: str) -> str: + return x + + def get_setter_key(x: str) -> str: + return options.cased_key(x) # type: ignore[attr-defined] + + else: + validated_options = {} + + def get_normed_key(x: str) -> str: + return x.lower() + + def get_setter_key(x: str) -> str: + return x + + for opt, value in options.items(): + normed_key = get_normed_key(opt) + try: + validator = _get_validator(opt, URI_OPTIONS_VALIDATOR_MAP, normed_key=normed_key) + validated = validator(opt, value) + except (ValueError, TypeError, ConfigurationError) as exc: + if warn: + warnings.warn(str(exc), stacklevel=2) + else: + raise + else: + validated_options[get_setter_key(normed_key)] = validated + return validated_options + + +def _esc_coll_name(encrypted_fields: Mapping[str, Any], name: str) -> Any: + return encrypted_fields.get("escCollection", f"enxcol_.{name}.esc") + + +def _ecoc_coll_name(encrypted_fields: Mapping[str, Any], name: str) -> Any: + return encrypted_fields.get("ecocCollection", f"enxcol_.{name}.ecoc") + + +# List of write-concern-related options. +WRITE_CONCERN_OPTIONS = frozenset(["w", "wtimeout", "wtimeoutms", "fsync", "j", "journal"]) + + +class BaseObject: + """A base class that provides attributes and methods common + to multiple pymongo classes. + + SHOULD NOT BE USED BY DEVELOPERS EXTERNAL TO MONGODB. + """ + + def __init__( + self, + codec_options: CodecOptions, + read_preference: _ServerMode, + write_concern: WriteConcern, + read_concern: ReadConcern, + ) -> None: + if not isinstance(codec_options, CodecOptions): + raise TypeError("codec_options must be an instance of bson.codec_options.CodecOptions") + self.__codec_options = codec_options + + if not isinstance(read_preference, _ServerMode): + raise TypeError( + f"{read_preference!r} is not valid for read_preference. See " + "pymongo.read_preferences for valid " + "options." + ) + self.__read_preference = read_preference + + if not isinstance(write_concern, WriteConcern): + raise TypeError( + "write_concern must be an instance of pymongo.write_concern.WriteConcern" + ) + self.__write_concern = write_concern + + if not isinstance(read_concern, ReadConcern): + raise TypeError("read_concern must be an instance of pymongo.read_concern.ReadConcern") + self.__read_concern = read_concern + + @property + def codec_options(self) -> CodecOptions: + """Read only access to the :class:`~bson.codec_options.CodecOptions` + of this instance. + """ + return self.__codec_options + + @property + def write_concern(self) -> WriteConcern: + """Read only access to the :class:`~pymongo.write_concern.WriteConcern` + of this instance. + + .. versionchanged:: 3.0 + The :attr:`write_concern` attribute is now read only. + """ + return self.__write_concern + + def _write_concern_for(self, session: Optional[ClientSession]) -> WriteConcern: + """Read only access to the write concern of this instance or session.""" + # Override this operation's write concern with the transaction's. + if session and session.in_transaction: + return DEFAULT_WRITE_CONCERN + return self.write_concern + + @property + def read_preference(self) -> _ServerMode: + """Read only access to the read preference of this instance. + + .. versionchanged:: 3.0 + The :attr:`read_preference` attribute is now read only. 
+ """ + return self.__read_preference + + def _read_preference_for(self, session: Optional[ClientSession]) -> _ServerMode: + """Read only access to the read preference of this instance or session.""" + # Override this operation's read preference with the transaction's. + if session: + return session._txn_read_preference() or self.__read_preference + return self.__read_preference + + @property + def read_concern(self) -> ReadConcern: + """Read only access to the :class:`~pymongo.read_concern.ReadConcern` + of this instance. + + .. versionadded:: 3.2 + """ + return self.__read_concern + + +class _CaseInsensitiveDictionary(MutableMapping[str, Any]): + def __init__(self, *args: Any, **kwargs: Any): + self.__casedkeys: dict[str, Any] = {} + self.__data: dict[str, Any] = {} + self.update(dict(*args, **kwargs)) + + def __contains__(self, key: str) -> bool: # type: ignore[override] + return key.lower() in self.__data + + def __len__(self) -> int: + return len(self.__data) + + def __iter__(self) -> Iterator[str]: + return (key for key in self.__casedkeys) + + def __repr__(self) -> str: + return str({self.__casedkeys[k]: self.__data[k] for k in self}) + + def __setitem__(self, key: str, value: Any) -> None: + lc_key = key.lower() + self.__casedkeys[lc_key] = key + self.__data[lc_key] = value + + def __getitem__(self, key: str) -> Any: + return self.__data[key.lower()] + + def __delitem__(self, key: str) -> None: + lc_key = key.lower() + del self.__casedkeys[lc_key] + del self.__data[lc_key] + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, abc.Mapping): + return NotImplemented + if len(self) != len(other): + return False + for key in other: # noqa: SIM110 + if self[key] != other[key]: + return False + + return True + + def get(self, key: str, default: Optional[Any] = None) -> Any: + return self.__data.get(key.lower(), default) + + def pop(self, key: str, *args: Any, **kwargs: Any) -> Any: + lc_key = key.lower() + self.__casedkeys.pop(lc_key, None) + return self.__data.pop(lc_key, *args, **kwargs) + + def popitem(self) -> tuple[str, Any]: + lc_key, cased_key = self.__casedkeys.popitem() + value = self.__data.pop(lc_key) + return cased_key, value + + def clear(self) -> None: + self.__casedkeys.clear() + self.__data.clear() + + @overload + def setdefault(self, key: str, default: None = None) -> Optional[Any]: + ... + + @overload + def setdefault(self, key: str, default: Any) -> Any: + ... + + def setdefault(self, key: str, default: Optional[Any] = None) -> Optional[Any]: + lc_key = key.lower() + if key in self: + return self.__data[lc_key] + else: + self.__casedkeys[lc_key] = key + self.__data[lc_key] = default + return default + + def update(self, other: Mapping[str, Any]) -> None: # type: ignore[override] + if isinstance(other, _CaseInsensitiveDictionary): + for key in other: + self[other.cased_key(key)] = other[key] + else: + for key in other: + self[key] = other[key] + + def cased_key(self, key: str) -> Any: + return self.__casedkeys[key.lower()] diff --git a/venv/Lib/site-packages/pymongo/compression_support.py b/venv/Lib/site-packages/pymongo/compression_support.py new file mode 100644 index 00000000..7daad210 --- /dev/null +++ b/venv/Lib/site-packages/pymongo/compression_support.py @@ -0,0 +1,157 @@ +# Copyright 2018 MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import warnings +from typing import Any, Iterable, Optional, Union + +from pymongo._lazy_import import lazy_import +from pymongo.hello import HelloCompat +from pymongo.monitoring import _SENSITIVE_COMMANDS + +try: + snappy = lazy_import("snappy") + _HAVE_SNAPPY = True +except ImportError: + # python-snappy isn't available. + _HAVE_SNAPPY = False + +try: + zlib = lazy_import("zlib") + + _HAVE_ZLIB = True +except ImportError: + # Python built without zlib support. + _HAVE_ZLIB = False + +try: + zstandard = lazy_import("zstandard") + _HAVE_ZSTD = True +except ImportError: + _HAVE_ZSTD = False + +_SUPPORTED_COMPRESSORS = {"snappy", "zlib", "zstd"} +_NO_COMPRESSION = {HelloCompat.CMD, HelloCompat.LEGACY_CMD} +_NO_COMPRESSION.update(_SENSITIVE_COMMANDS) + + +def validate_compressors(dummy: Any, value: Union[str, Iterable[str]]) -> list[str]: + try: + # `value` is string. + compressors = value.split(",") # type: ignore[union-attr] + except AttributeError: + # `value` is an iterable. + compressors = list(value) + + for compressor in compressors[:]: + if compressor not in _SUPPORTED_COMPRESSORS: + compressors.remove(compressor) + warnings.warn(f"Unsupported compressor: {compressor}", stacklevel=2) + elif compressor == "snappy" and not _HAVE_SNAPPY: + compressors.remove(compressor) + warnings.warn( + "Wire protocol compression with snappy is not available. " + "You must install the python-snappy module for snappy support.", + stacklevel=2, + ) + elif compressor == "zlib" and not _HAVE_ZLIB: + compressors.remove(compressor) + warnings.warn( + "Wire protocol compression with zlib is not available. " + "The zlib module is not available.", + stacklevel=2, + ) + elif compressor == "zstd" and not _HAVE_ZSTD: + compressors.remove(compressor) + warnings.warn( + "Wire protocol compression with zstandard is not available. " + "You must install the zstandard module for zstandard support.", + stacklevel=2, + ) + return compressors + + +def validate_zlib_compression_level(option: str, value: Any) -> int: + try: + level = int(value) + except Exception: + raise TypeError(f"{option} must be an integer, not {value!r}.") from None + if level < -1 or level > 9: + raise ValueError("%s must be between -1 and 9, not %d." 
% (option, level)) + return level + + +class CompressionSettings: + def __init__(self, compressors: list[str], zlib_compression_level: int): + self.compressors = compressors + self.zlib_compression_level = zlib_compression_level + + def get_compression_context( + self, compressors: Optional[list[str]] + ) -> Union[SnappyContext, ZlibContext, ZstdContext, None]: + if compressors: + chosen = compressors[0] + if chosen == "snappy": + return SnappyContext() + elif chosen == "zlib": + return ZlibContext(self.zlib_compression_level) + elif chosen == "zstd": + return ZstdContext() + return None + return None + + +class SnappyContext: + compressor_id = 1 + + @staticmethod + def compress(data: bytes) -> bytes: + return snappy.compress(data) + + +class ZlibContext: + compressor_id = 2 + + def __init__(self, level: int): + self.level = level + + def compress(self, data: bytes) -> bytes: + return zlib.compress(data, self.level) + + +class ZstdContext: + compressor_id = 3 + + @staticmethod + def compress(data: bytes) -> bytes: + # ZstdCompressor is not thread safe. + # TODO: Use a pool? + return zstandard.ZstdCompressor().compress(data) + + +def decompress(data: bytes, compressor_id: int) -> bytes: + if compressor_id == SnappyContext.compressor_id: + # python-snappy doesn't support the buffer interface. + # https://github.com/andrix/python-snappy/issues/65 + # This only matters when data is a memoryview since + # id(bytes(data)) == id(data) when data is a bytes. + return snappy.uncompress(bytes(data)) + elif compressor_id == ZlibContext.compressor_id: + return zlib.decompress(data) + elif compressor_id == ZstdContext.compressor_id: + # ZstdDecompressor is not thread safe. + # TODO: Use a pool? + return zstandard.ZstdDecompressor().decompress(data) + else: + raise ValueError("Unknown compressorId %d" % (compressor_id,)) diff --git a/venv/Lib/site-packages/pymongo/cursor.py b/venv/Lib/site-packages/pymongo/cursor.py new file mode 100644 index 00000000..3151fcaf --- /dev/null +++ b/venv/Lib/site-packages/pymongo/cursor.py @@ -0,0 +1,1357 @@ +# Copyright 2009-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Cursor class to iterate over Mongo query results.""" +from __future__ import annotations + +import copy +import warnings +from collections import deque +from typing import ( + TYPE_CHECKING, + Any, + Generic, + Iterable, + List, + Mapping, + NoReturn, + Optional, + Sequence, + Tuple, + Union, + cast, + overload, +) + +from bson import RE_TYPE, _convert_raw_document_lists_to_streams +from bson.code import Code +from bson.son import SON +from pymongo import helpers +from pymongo.collation import validate_collation_or_none +from pymongo.common import ( + validate_is_document_type, + validate_is_mapping, +) +from pymongo.errors import ConnectionFailure, InvalidOperation, OperationFailure +from pymongo.lock import _create_lock +from pymongo.message import ( + _CursorAddress, + _GetMore, + _OpMsg, + _OpReply, + _Query, + _RawBatchGetMore, + _RawBatchQuery, +) +from pymongo.response import PinnedResponse +from pymongo.typings import _Address, _CollationIn, _DocumentOut, _DocumentType +from pymongo.write_concern import validate_boolean + +if TYPE_CHECKING: + from _typeshed import SupportsItems + + from bson.codec_options import CodecOptions + from pymongo.client_session import ClientSession + from pymongo.collection import Collection + from pymongo.pool import Connection + from pymongo.read_preferences import _ServerMode + + +# These errors mean that the server has already killed the cursor so there is +# no need to send killCursors. +_CURSOR_CLOSED_ERRORS = frozenset( + [ + 43, # CursorNotFound + 175, # QueryPlanKilled + 237, # CursorKilled + # On a tailable cursor, the following errors mean the capped collection + # rolled over. + # MongoDB 2.6: + # {'$err': 'Runner killed during getMore', 'code': 28617, 'ok': 0} + 28617, + # MongoDB 3.0: + # {'$err': 'getMore executor error: UnknownError no details available', + # 'code': 17406, 'ok': 0} + 17406, + # MongoDB 3.2 + 3.4: + # {'ok': 0.0, 'errmsg': 'GetMore command executor error: + # CappedPositionLost: CollectionScan died due to failure to restore + # tailable cursor position. Last seen record id: RecordId(3)', + # 'code': 96} + 96, + # MongoDB 3.6+: + # {'ok': 0.0, 'errmsg': 'errmsg: "CollectionScan died due to failure to + # restore tailable cursor position. Last seen record id: RecordId(3)"', + # 'code': 136, 'codeName': 'CappedPositionLost'} + 136, + ] +) + +_QUERY_OPTIONS = { + "tailable_cursor": 2, + "secondary_okay": 4, + "oplog_replay": 8, + "no_timeout": 16, + "await_data": 32, + "exhaust": 64, + "partial": 128, +} + + +class CursorType: + NON_TAILABLE = 0 + """The standard cursor type.""" + + TAILABLE = _QUERY_OPTIONS["tailable_cursor"] + """The tailable cursor type. + + Tailable cursors are only for use with capped collections. They are not + closed when the last data is retrieved but are kept open and the cursor + location marks the final document position. If more data is received + iteration of the cursor will continue from the last document received. + """ + + TAILABLE_AWAIT = TAILABLE | _QUERY_OPTIONS["await_data"] + """A tailable cursor with the await option set. + + Creates a tailable cursor that will wait for a few seconds after returning + the full result set so that it can capture and return additional data added + during the query. + """ + + EXHAUST = _QUERY_OPTIONS["exhaust"] + """An exhaust cursor. + + MongoDB will stream batched results to the client without waiting for the + client to request each batch, reducing latency. 
+ """ + + +class _ConnectionManager: + """Used with exhaust cursors to ensure the connection is returned.""" + + def __init__(self, conn: Connection, more_to_come: bool): + self.conn: Optional[Connection] = conn + self.more_to_come = more_to_come + self.lock = _create_lock() + + def update_exhaust(self, more_to_come: bool) -> None: + self.more_to_come = more_to_come + + def close(self) -> None: + """Return this instance's connection to the connection pool.""" + if self.conn: + self.conn.unpin() + self.conn = None + + +_Sort = Union[ + Sequence[Union[str, Tuple[str, Union[int, str, Mapping[str, Any]]]]], Mapping[str, Any] +] +_Hint = Union[str, _Sort] + + +class Cursor(Generic[_DocumentType]): + """A cursor / iterator over Mongo query results.""" + + _query_class = _Query + _getmore_class = _GetMore + + def __init__( + self, + collection: Collection[_DocumentType], + filter: Optional[Mapping[str, Any]] = None, + projection: Optional[Union[Mapping[str, Any], Iterable[str]]] = None, + skip: int = 0, + limit: int = 0, + no_cursor_timeout: bool = False, + cursor_type: int = CursorType.NON_TAILABLE, + sort: Optional[_Sort] = None, + allow_partial_results: bool = False, + oplog_replay: bool = False, + batch_size: int = 0, + collation: Optional[_CollationIn] = None, + hint: Optional[_Hint] = None, + max_scan: Optional[int] = None, + max_time_ms: Optional[int] = None, + max: Optional[_Sort] = None, + min: Optional[_Sort] = None, + return_key: Optional[bool] = None, + show_record_id: Optional[bool] = None, + snapshot: Optional[bool] = None, + comment: Optional[Any] = None, + session: Optional[ClientSession] = None, + allow_disk_use: Optional[bool] = None, + let: Optional[bool] = None, + ) -> None: + """Create a new cursor. + + Should not be called directly by application developers - see + :meth:`~pymongo.collection.Collection.find` instead. + + .. seealso:: The MongoDB documentation on `cursors `_. + """ + # Initialize all attributes used in __del__ before possibly raising + # an error to avoid attribute errors during garbage collection. 
+ self.__collection: Collection[_DocumentType] = collection + self.__id: Any = None + self.__exhaust = False + self.__sock_mgr: Any = None + self.__killed = False + self.__session: Optional[ClientSession] + + if session: + self.__session = session + self.__explicit_session = True + else: + self.__session = None + self.__explicit_session = False + + spec: Mapping[str, Any] = filter or {} + validate_is_mapping("filter", spec) + if not isinstance(skip, int): + raise TypeError("skip must be an instance of int") + if not isinstance(limit, int): + raise TypeError("limit must be an instance of int") + validate_boolean("no_cursor_timeout", no_cursor_timeout) + if no_cursor_timeout and not self.__explicit_session: + warnings.warn( + "use an explicit session with no_cursor_timeout=True " + "otherwise the cursor may still timeout after " + "30 minutes, for more info see " + "https://mongodb.com/docs/v4.4/reference/method/" + "cursor.noCursorTimeout/" + "#session-idle-timeout-overrides-nocursortimeout", + UserWarning, + stacklevel=2, + ) + if cursor_type not in ( + CursorType.NON_TAILABLE, + CursorType.TAILABLE, + CursorType.TAILABLE_AWAIT, + CursorType.EXHAUST, + ): + raise ValueError("not a valid value for cursor_type") + validate_boolean("allow_partial_results", allow_partial_results) + validate_boolean("oplog_replay", oplog_replay) + if not isinstance(batch_size, int): + raise TypeError("batch_size must be an integer") + if batch_size < 0: + raise ValueError("batch_size must be >= 0") + # Only set if allow_disk_use is provided by the user, else None. + if allow_disk_use is not None: + allow_disk_use = validate_boolean("allow_disk_use", allow_disk_use) + + if projection is not None: + projection = helpers._fields_list_to_dict(projection, "projection") + + if let is not None: + validate_is_document_type("let", let) + + self.__let = let + self.__spec = spec + self.__has_filter = filter is not None + self.__projection = projection + self.__skip = skip + self.__limit = limit + self.__batch_size = batch_size + self.__ordering = sort and helpers._index_document(sort) or None + self.__max_scan = max_scan + self.__explain = False + self.__comment = comment + self.__max_time_ms = max_time_ms + self.__max_await_time_ms: Optional[int] = None + self.__max: Optional[Union[dict[Any, Any], _Sort]] = max + self.__min: Optional[Union[dict[Any, Any], _Sort]] = min + self.__collation = validate_collation_or_none(collation) + self.__return_key = return_key + self.__show_record_id = show_record_id + self.__allow_disk_use = allow_disk_use + self.__snapshot = snapshot + self.__hint: Union[str, dict[str, Any], None] + self.__set_hint(hint) + + # Exhaust cursor support + if cursor_type == CursorType.EXHAUST: + if self.__collection.database.client.is_mongos: + raise InvalidOperation("Exhaust cursors are not supported by mongos") + if limit: + raise InvalidOperation("Can't use limit and exhaust together.") + self.__exhaust = True + + # This is ugly. People want to be able to do cursor[5:5] and + # get an empty result set (old behavior was an + # exception). It's hard to do that right, though, because the + # server uses limit(0) to mean 'no limit'. So we set __empty + # in that case and check for it when iterating. We also unset + # it anytime we change __limit. + self.__empty = False + + self.__data: deque = deque() + self.__address: Optional[_Address] = None + self.__retrieved = 0 + + self.__codec_options = collection.codec_options + # Read preference is set when the initial find is sent. 
+ self.__read_preference: Optional[_ServerMode] = None + self.__read_concern = collection.read_concern + + self.__query_flags = cursor_type + if no_cursor_timeout: + self.__query_flags |= _QUERY_OPTIONS["no_timeout"] + if allow_partial_results: + self.__query_flags |= _QUERY_OPTIONS["partial"] + if oplog_replay: + self.__query_flags |= _QUERY_OPTIONS["oplog_replay"] + + # The namespace to use for find/getMore commands. + self.__dbname = collection.database.name + self.__collname = collection.name + + @property + def collection(self) -> Collection[_DocumentType]: + """The :class:`~pymongo.collection.Collection` that this + :class:`Cursor` is iterating. + """ + return self.__collection + + @property + def retrieved(self) -> int: + """The number of documents retrieved so far.""" + return self.__retrieved + + def __del__(self) -> None: + self.__die() + + def rewind(self) -> Cursor[_DocumentType]: + """Rewind this cursor to its unevaluated state. + + Reset this cursor if it has been partially or completely evaluated. + Any options that are present on the cursor will remain in effect. + Future iterating performed on this cursor will cause new queries to + be sent to the server, even if the resultant data has already been + retrieved by this cursor. + """ + self.close() + self.__data = deque() + self.__id = None + self.__address = None + self.__retrieved = 0 + self.__killed = False + + return self + + def clone(self) -> Cursor[_DocumentType]: + """Get a clone of this cursor. + + Returns a new Cursor instance with options matching those that have + been set on the current instance. The clone will be completely + unevaluated, even if the current instance has been partially or + completely evaluated. + """ + return self._clone(True) + + def _clone(self, deepcopy: bool = True, base: Optional[Cursor] = None) -> Cursor: + """Internal clone helper.""" + if not base: + if self.__explicit_session: + base = self._clone_base(self.__session) + else: + base = self._clone_base(None) + + values_to_clone = ( + "spec", + "projection", + "skip", + "limit", + "max_time_ms", + "max_await_time_ms", + "comment", + "max", + "min", + "ordering", + "explain", + "hint", + "batch_size", + "max_scan", + "query_flags", + "collation", + "empty", + "show_record_id", + "return_key", + "allow_disk_use", + "snapshot", + "exhaust", + "has_filter", + ) + data = { + k: v + for k, v in self.__dict__.items() + if k.startswith("_Cursor__") and k[9:] in values_to_clone + } + if deepcopy: + data = self._deepcopy(data) + base.__dict__.update(data) + return base + + def _clone_base(self, session: Optional[ClientSession]) -> Cursor: + """Creates an empty Cursor object for information to be copied into.""" + return self.__class__(self.__collection, session=session) + + def __die(self, synchronous: bool = False) -> None: + """Closes this cursor.""" + try: + already_killed = self.__killed + except AttributeError: + # __init__ did not run to completion (or at all). + return + + self.__killed = True + if self.__id and not already_killed: + cursor_id = self.__id + assert self.__address is not None + address = _CursorAddress(self.__address, f"{self.__dbname}.{self.__collname}") + else: + # Skip killCursors. 
+ cursor_id = 0 + address = None + self.__collection.database.client._cleanup_cursor( + synchronous, + cursor_id, + address, + self.__sock_mgr, + self.__session, + self.__explicit_session, + ) + if not self.__explicit_session: + self.__session = None + self.__sock_mgr = None + + def close(self) -> None: + """Explicitly close / kill this cursor.""" + self.__die(True) + + def __query_spec(self) -> Mapping[str, Any]: + """Get the spec to use for a query.""" + operators: dict[str, Any] = {} + if self.__ordering: + operators["$orderby"] = self.__ordering + if self.__explain: + operators["$explain"] = True + if self.__hint: + operators["$hint"] = self.__hint + if self.__let: + operators["let"] = self.__let + if self.__comment: + operators["$comment"] = self.__comment + if self.__max_scan: + operators["$maxScan"] = self.__max_scan + if self.__max_time_ms is not None: + operators["$maxTimeMS"] = self.__max_time_ms + if self.__max: + operators["$max"] = self.__max + if self.__min: + operators["$min"] = self.__min + if self.__return_key is not None: + operators["$returnKey"] = self.__return_key + if self.__show_record_id is not None: + # This is upgraded to showRecordId for MongoDB 3.2+ "find" command. + operators["$showDiskLoc"] = self.__show_record_id + if self.__snapshot is not None: + operators["$snapshot"] = self.__snapshot + + if operators: + # Make a shallow copy so we can cleanly rewind or clone. + spec = dict(self.__spec) + + # Allow-listed commands must be wrapped in $query. + if "$query" not in spec: + # $query has to come first + spec = {"$query": spec} + + spec.update(operators) + return spec + # Have to wrap with $query if "query" is the first key. + # We can't just use $query anytime "query" is a key as + # that breaks commands like count and find_and_modify. + # Checking spec.keys()[0] covers the case that the spec + # was passed as an instance of SON or OrderedDict. + elif "query" in self.__spec and ( + len(self.__spec) == 1 or next(iter(self.__spec)) == "query" + ): + return {"$query": self.__spec} + + return self.__spec + + def __check_okay_to_chain(self) -> None: + """Check if it is okay to chain more options onto this cursor.""" + if self.__retrieved or self.__id is not None: + raise InvalidOperation("cannot set options after executing query") + + def add_option(self, mask: int) -> Cursor[_DocumentType]: + """Set arbitrary query flags using a bitmask. + + To set the tailable flag: + cursor.add_option(2) + """ + if not isinstance(mask, int): + raise TypeError("mask must be an int") + self.__check_okay_to_chain() + + if mask & _QUERY_OPTIONS["exhaust"]: + if self.__limit: + raise InvalidOperation("Can't use limit and exhaust together.") + if self.__collection.database.client.is_mongos: + raise InvalidOperation("Exhaust cursors are not supported by mongos") + self.__exhaust = True + + self.__query_flags |= mask + return self + + def remove_option(self, mask: int) -> Cursor[_DocumentType]: + """Unset arbitrary query flags using a bitmask. + + To unset the tailable flag: + cursor.remove_option(2) + """ + if not isinstance(mask, int): + raise TypeError("mask must be an int") + self.__check_okay_to_chain() + + if mask & _QUERY_OPTIONS["exhaust"]: + self.__exhaust = False + + self.__query_flags &= ~mask + return self + + def allow_disk_use(self, allow_disk_use: bool) -> Cursor[_DocumentType]: + """Specifies whether MongoDB can use temporary disk files while + processing a blocking sort operation. + + Raises :exc:`TypeError` if `allow_disk_use` is not a boolean. + + .. 
note:: `allow_disk_use` requires server version **>= 4.4** + + :param allow_disk_use: if True, MongoDB may use temporary + disk files to store data exceeding the system memory limit while + processing a blocking sort operation. + + .. versionadded:: 3.11 + """ + if not isinstance(allow_disk_use, bool): + raise TypeError("allow_disk_use must be a bool") + self.__check_okay_to_chain() + + self.__allow_disk_use = allow_disk_use + return self + + def limit(self, limit: int) -> Cursor[_DocumentType]: + """Limits the number of results to be returned by this cursor. + + Raises :exc:`TypeError` if `limit` is not an integer. Raises + :exc:`~pymongo.errors.InvalidOperation` if this :class:`Cursor` + has already been used. The last `limit` applied to this cursor + takes precedence. A limit of ``0`` is equivalent to no limit. + + :param limit: the number of results to return + + .. seealso:: The MongoDB documentation on `limit `_. + """ + if not isinstance(limit, int): + raise TypeError("limit must be an integer") + if self.__exhaust: + raise InvalidOperation("Can't use limit and exhaust together.") + self.__check_okay_to_chain() + + self.__empty = False + self.__limit = limit + return self + + def batch_size(self, batch_size: int) -> Cursor[_DocumentType]: + """Limits the number of documents returned in one batch. Each batch + requires a round trip to the server. It can be adjusted to optimize + performance and limit data transfer. + + .. note:: batch_size can not override MongoDB's internal limits on the + amount of data it will return to the client in a single batch (i.e + if you set batch size to 1,000,000,000, MongoDB will currently only + return 4-16MB of results per batch). + + Raises :exc:`TypeError` if `batch_size` is not an integer. + Raises :exc:`ValueError` if `batch_size` is less than ``0``. + Raises :exc:`~pymongo.errors.InvalidOperation` if this + :class:`Cursor` has already been used. The last `batch_size` + applied to this cursor takes precedence. + + :param batch_size: The size of each batch of results requested. + """ + if not isinstance(batch_size, int): + raise TypeError("batch_size must be an integer") + if batch_size < 0: + raise ValueError("batch_size must be >= 0") + self.__check_okay_to_chain() + + self.__batch_size = batch_size + return self + + def skip(self, skip: int) -> Cursor[_DocumentType]: + """Skips the first `skip` results of this cursor. + + Raises :exc:`TypeError` if `skip` is not an integer. Raises + :exc:`ValueError` if `skip` is less than ``0``. Raises + :exc:`~pymongo.errors.InvalidOperation` if this :class:`Cursor` has + already been used. The last `skip` applied to this cursor takes + precedence. + + :param skip: the number of results to skip + """ + if not isinstance(skip, int): + raise TypeError("skip must be an integer") + if skip < 0: + raise ValueError("skip must be >= 0") + self.__check_okay_to_chain() + + self.__skip = skip + return self + + def max_time_ms(self, max_time_ms: Optional[int]) -> Cursor[_DocumentType]: + """Specifies a time limit for a query operation. If the specified + time is exceeded, the operation will be aborted and + :exc:`~pymongo.errors.ExecutionTimeout` is raised. If `max_time_ms` + is ``None`` no limit is applied. + + Raises :exc:`TypeError` if `max_time_ms` is not an integer or ``None``. + Raises :exc:`~pymongo.errors.InvalidOperation` if this :class:`Cursor` + has already been used. 
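For example, the chainable option methods above each return the cursor
itself, so options compose left to right; a minimal sketch, assuming a
local mongod and a ``test.numbers`` collection::

    from pymongo import MongoClient

    client = MongoClient()
    cursor = (
        client.test.numbers.find({})
        .skip(20)           # start at the 21st matching document
        .limit(10)          # return at most 10 documents in total
        .batch_size(5)      # fetch them from the server 5 at a time
        .max_time_ms(2000)  # abort server-side work after 2 seconds
    )
    for doc in cursor:
        print(doc)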
+ + :param max_time_ms: the time limit after which the operation is aborted + """ + if not isinstance(max_time_ms, int) and max_time_ms is not None: + raise TypeError("max_time_ms must be an integer or None") + self.__check_okay_to_chain() + + self.__max_time_ms = max_time_ms + return self + + def max_await_time_ms(self, max_await_time_ms: Optional[int]) -> Cursor[_DocumentType]: + """Specifies a time limit for a getMore operation on a + :attr:`~pymongo.cursor.CursorType.TAILABLE_AWAIT` cursor. For all other + types of cursor max_await_time_ms is ignored. + + Raises :exc:`TypeError` if `max_await_time_ms` is not an integer or + ``None``. Raises :exc:`~pymongo.errors.InvalidOperation` if this + :class:`Cursor` has already been used. + + .. note:: `max_await_time_ms` requires server version **>= 3.2** + + :param max_await_time_ms: the time limit after which the operation is + aborted + + .. versionadded:: 3.2 + """ + if not isinstance(max_await_time_ms, int) and max_await_time_ms is not None: + raise TypeError("max_await_time_ms must be an integer or None") + self.__check_okay_to_chain() + + # Ignore max_await_time_ms if not tailable or await_data is False. + if self.__query_flags & CursorType.TAILABLE_AWAIT: + self.__max_await_time_ms = max_await_time_ms + + return self + + @overload + def __getitem__(self, index: int) -> _DocumentType: + ... + + @overload + def __getitem__(self, index: slice) -> Cursor[_DocumentType]: + ... + + def __getitem__(self, index: Union[int, slice]) -> Union[_DocumentType, Cursor[_DocumentType]]: + """Get a single document or a slice of documents from this cursor. + + .. warning:: A :class:`~Cursor` is not a Python :class:`list`. Each + index access or slice requires that a new query be run using skip + and limit. Do not iterate the cursor using index accesses. + The following example is **extremely inefficient** and may return + surprising results:: + + cursor = db.collection.find() + # Warning: This runs a new query for each document. + # Don't do this! + for idx in range(10): + print(cursor[idx]) + + Raises :class:`~pymongo.errors.InvalidOperation` if this + cursor has already been used. + + To get a single document use an integral index, e.g.:: + + >>> db.test.find()[50] + + An :class:`IndexError` will be raised if the index is negative + or greater than the amount of documents in this cursor. Any + limit previously applied to this cursor will be ignored. + + To get a slice of documents use a slice index, e.g.:: + + >>> db.test.find()[20:25] + + This will return this cursor with a limit of ``5`` and skip of + ``20`` applied. Using a slice index will override any prior + limits or skips applied to this cursor (including those + applied through previous calls to this method). Raises + :class:`IndexError` when the slice has a step, a negative + start value, or a stop value less than or equal to the start + value. 
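A minimal sketch of these semantics (assuming a local mongod and a
``test.numbers`` collection with enough documents)::

    from pymongo import MongoClient

    client = MongoClient()
    coll = client.test.numbers

    page = coll.find()[20:25]  # applies skip=20 and limit=5 to this cursor
    for doc in page:
        print(doc)

    one = coll.find()[50]      # new query: skip=50 with a hard limit of -1;
    print(one)                 # raises IndexError if no 51st document exists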
+ + :param index: An integer or slice index to be applied to this cursor + """ + self.__check_okay_to_chain() + self.__empty = False + if isinstance(index, slice): + if index.step is not None: + raise IndexError("Cursor instances do not support slice steps") + + skip = 0 + if index.start is not None: + if index.start < 0: + raise IndexError("Cursor instances do not support negative indices") + skip = index.start + + if index.stop is not None: + limit = index.stop - skip + if limit < 0: + raise IndexError( + "stop index must be greater than start index for slice %r" % index + ) + if limit == 0: + self.__empty = True + else: + limit = 0 + + self.__skip = skip + self.__limit = limit + return self + + if isinstance(index, int): + if index < 0: + raise IndexError("Cursor instances do not support negative indices") + clone = self.clone() + clone.skip(index + self.__skip) + clone.limit(-1) # use a hard limit + clone.__query_flags &= ~CursorType.TAILABLE_AWAIT # PYTHON-1371 + for doc in clone: + return doc + raise IndexError("no such item for Cursor instance") + raise TypeError("index %r cannot be applied to Cursor instances" % index) + + def max_scan(self, max_scan: Optional[int]) -> Cursor[_DocumentType]: + """**DEPRECATED** - Limit the number of documents to scan when + performing the query. + + Raises :class:`~pymongo.errors.InvalidOperation` if this + cursor has already been used. Only the last :meth:`max_scan` + applied to this cursor has any effect. + + :param max_scan: the maximum number of documents to scan + + .. versionchanged:: 3.7 + Deprecated :meth:`max_scan`. Support for this option is deprecated in + MongoDB 4.0. Use :meth:`max_time_ms` instead to limit server side + execution time. + """ + self.__check_okay_to_chain() + self.__max_scan = max_scan + return self + + def max(self, spec: _Sort) -> Cursor[_DocumentType]: + """Adds ``max`` operator that specifies upper bound for specific index. + + When using ``max``, :meth:`~hint` should also be configured to ensure + the query uses the expected index and starting in MongoDB 4.2 + :meth:`~hint` will be required. + + :param spec: a list of field, limit pairs specifying the exclusive + upper bound for all keys of a specific index in order. + + .. versionchanged:: 3.8 + Deprecated cursors that use ``max`` without a :meth:`~hint`. + + .. versionadded:: 2.7 + """ + if not isinstance(spec, (list, tuple)): + raise TypeError("spec must be an instance of list or tuple") + + self.__check_okay_to_chain() + self.__max = dict(spec) + return self + + def min(self, spec: _Sort) -> Cursor[_DocumentType]: + """Adds ``min`` operator that specifies lower bound for specific index. + + When using ``min``, :meth:`~hint` should also be configured to ensure + the query uses the expected index and starting in MongoDB 4.2 + :meth:`~hint` will be required. + + :param spec: a list of field, limit pairs specifying the inclusive + lower bound for all keys of a specific index in order. + + .. versionchanged:: 3.8 + Deprecated cursors that use ``min`` without a :meth:`~hint`. + + .. versionadded:: 2.7 + """ + if not isinstance(spec, (list, tuple)): + raise TypeError("spec must be an instance of list or tuple") + + self.__check_okay_to_chain() + self.__min = dict(spec) + return self + + def sort( + self, key_or_list: _Hint, direction: Optional[Union[int, str]] = None + ) -> Cursor[_DocumentType]: + """Sorts this cursor's results. 
+ + Pass a field name and a direction, either + :data:`~pymongo.ASCENDING` or :data:`~pymongo.DESCENDING`.:: + + for doc in collection.find().sort('field', pymongo.ASCENDING): + print(doc) + + To sort by multiple fields, pass a list of (key, direction) pairs. + If just a name is given, :data:`~pymongo.ASCENDING` will be inferred:: + + for doc in collection.find().sort([ + 'field1', + ('field2', pymongo.DESCENDING)]): + print(doc) + + Text search results can be sorted by relevance:: + + cursor = db.test.find( + {'$text': {'$search': 'some words'}}, + {'score': {'$meta': 'textScore'}}) + + # Sort by 'score' field. + cursor.sort([('score', {'$meta': 'textScore'})]) + + for doc in cursor: + print(doc) + + For more advanced text search functionality, see MongoDB's + `Atlas Search `_. + + Raises :class:`~pymongo.errors.InvalidOperation` if this cursor has + already been used. Only the last :meth:`sort` applied to this + cursor has any effect. + + :param key_or_list: a single key or a list of (key, direction) + pairs specifying the keys to sort on + :param direction: only used if `key_or_list` is a single + key, if not given :data:`~pymongo.ASCENDING` is assumed + """ + self.__check_okay_to_chain() + keys = helpers._index_list(key_or_list, direction) + self.__ordering = helpers._index_document(keys) + return self + + def distinct(self, key: str) -> list: + """Get a list of distinct values for `key` among all documents + in the result set of this query. + + Raises :class:`TypeError` if `key` is not an instance of + :class:`str`. + + The :meth:`distinct` method obeys the + :attr:`~pymongo.collection.Collection.read_preference` of the + :class:`~pymongo.collection.Collection` instance on which + :meth:`~pymongo.collection.Collection.find` was called. + + :param key: name of key for which we want to get the distinct values + + .. seealso:: :meth:`pymongo.collection.Collection.distinct` + """ + options: dict[str, Any] = {} + if self.__spec: + options["query"] = self.__spec + if self.__max_time_ms is not None: + options["maxTimeMS"] = self.__max_time_ms + if self.__comment: + options["comment"] = self.__comment + if self.__collation is not None: + options["collation"] = self.__collation + + return self.__collection.distinct(key, session=self.__session, **options) + + def explain(self) -> _DocumentType: + """Returns an explain plan record for this cursor. + + .. note:: This method uses the default verbosity mode of the + `explain command + `_, + ``allPlansExecution``. To use a different verbosity use + :meth:`~pymongo.database.Database.command` to run the explain + command directly. + + .. seealso:: The MongoDB documentation on `explain `_. + """ + c = self.clone() + c.__explain = True + + # always use a hard limit for explains + if c.__limit: + c.__limit = -abs(c.__limit) + return next(c) + + def __set_hint(self, index: Optional[_Hint]) -> None: + if index is None: + self.__hint = None + return + + if isinstance(index, str): + self.__hint = index + else: + self.__hint = helpers._index_document(index) + + def hint(self, index: Optional[_Hint]) -> Cursor[_DocumentType]: + """Adds a 'hint', telling Mongo the proper index to use for the query. + + Judicious use of hints can greatly improve query + performance. When doing a query on multiple fields (at least + one of which is indexed) pass the indexed field as a hint to + the query. 
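As shown below, :meth:`sort` and :meth:`hint` chain naturally; a minimal
sketch, assuming a local mongod and an ascending index on
``test.users.age``::

    from pymongo import ASCENDING, DESCENDING, MongoClient

    client = MongoClient()
    users = client.test.users

    cursor = (
        users.find({"age": {"$gte": 21}})
        .sort([("age", ASCENDING), ("name", DESCENDING)])
        .hint([("age", ASCENDING)])  # OperationFailure if no such index
    )
    print(cursor.explain())  # explain() clones internally; cursor stays usable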
Raises :class:`~pymongo.errors.OperationFailure` if the + provided hint requires an index that does not exist on this collection, + and raises :class:`~pymongo.errors.InvalidOperation` if this cursor has + already been used. + + `index` should be an index as passed to + :meth:`~pymongo.collection.Collection.create_index` + (e.g. ``[('field', ASCENDING)]``) or the name of the index. + If `index` is ``None`` any existing hint for this query is + cleared. The last hint applied to this cursor takes precedence + over all others. + + :param index: index to hint on (as an index specifier) + """ + self.__check_okay_to_chain() + self.__set_hint(index) + return self + + def comment(self, comment: Any) -> Cursor[_DocumentType]: + """Adds a 'comment' to the cursor. + + http://mongodb.com/docs/manual/reference/operator/comment/ + + :param comment: A string to attach to the query to help interpret and + trace the operation in the server logs and in profile data. + + .. versionadded:: 2.7 + """ + self.__check_okay_to_chain() + self.__comment = comment + return self + + def where(self, code: Union[str, Code]) -> Cursor[_DocumentType]: + """Adds a `$where`_ clause to this query. + + The `code` argument must be an instance of :class:`str` or + :class:`~bson.code.Code` containing a JavaScript expression. + This expression will be evaluated for each document scanned. + Only those documents for which the expression evaluates to + *true* will be returned as results. The keyword *this* refers + to the object currently being scanned. For example:: + + # Find all documents where field "a" is less than "b" plus "c". + for doc in db.test.find().where('this.a < (this.b + this.c)'): + print(doc) + + Raises :class:`TypeError` if `code` is not an instance of + :class:`str`. Raises :class:`~pymongo.errors.InvalidOperation` if this + :class:`Cursor` has already been used. Only the last call to + :meth:`where` applied to a :class:`Cursor` has any effect. + + .. note:: MongoDB 4.4 drops support for :class:`~bson.code.Code` + with scope variables. Consider using `$expr`_ instead. + + :param code: JavaScript expression to use as a filter + + .. _$expr: https://mongodb.com/docs/manual/reference/operator/query/expr/ + .. _$where: https://mongodb.com/docs/manual/reference/operator/query/where/ + """ + self.__check_okay_to_chain() + if not isinstance(code, Code): + code = Code(code) + + # Avoid overwriting a filter argument that was given by the user + # when updating the spec. + spec: dict[str, Any] + if self.__has_filter: + spec = dict(self.__spec) + else: + spec = cast(dict, self.__spec) + spec["$where"] = code + self.__spec = spec + return self + + def collation(self, collation: Optional[_CollationIn]) -> Cursor[_DocumentType]: + """Adds a :class:`~pymongo.collation.Collation` to this query. + + Raises :exc:`TypeError` if `collation` is not an instance of + :class:`~pymongo.collation.Collation` or a ``dict``. Raises + :exc:`~pymongo.errors.InvalidOperation` if this :class:`Cursor` has + already been used. Only the last collation applied to this cursor has + any effect. + + :param collation: An instance of :class:`~pymongo.collation.Collation`. + """ + self.__check_okay_to_chain() + self.__collation = validate_collation_or_none(collation) + return self + + def __send_message(self, operation: Union[_Query, _GetMore]) -> None: + """Send a query or getmore operation and handles the response. 
+ + If operation is ``None`` this is an exhaust cursor, which reads + the next result batch off the exhaust socket instead of + sending getMore messages to the server. + + Can raise ConnectionFailure. + """ + client = self.__collection.database.client + # OP_MSG is required to support exhaust cursors with encryption. + if client._encrypter and self.__exhaust: + raise InvalidOperation("exhaust cursors do not support auto encryption") + + try: + response = client._run_operation( + operation, self._unpack_response, address=self.__address + ) + except OperationFailure as exc: + if exc.code in _CURSOR_CLOSED_ERRORS or self.__exhaust: + # Don't send killCursors because the cursor is already closed. + self.__killed = True + if exc.timeout: + self.__die(False) + else: + self.close() + # If this is a tailable cursor the error is likely + # due to capped collection roll over. Setting + # self.__killed to True ensures Cursor.alive will be + # False. No need to re-raise. + if ( + exc.code in _CURSOR_CLOSED_ERRORS + and self.__query_flags & _QUERY_OPTIONS["tailable_cursor"] + ): + return + raise + except ConnectionFailure: + self.__killed = True + self.close() + raise + except Exception: + self.close() + raise + + self.__address = response.address + if isinstance(response, PinnedResponse): + if not self.__sock_mgr: + self.__sock_mgr = _ConnectionManager(response.conn, response.more_to_come) + + cmd_name = operation.name + docs = response.docs + if response.from_command: + if cmd_name != "explain": + cursor = docs[0]["cursor"] + self.__id = cursor["id"] + if cmd_name == "find": + documents = cursor["firstBatch"] + # Update the namespace used for future getMore commands. + ns = cursor.get("ns") + if ns: + self.__dbname, self.__collname = ns.split(".", 1) + else: + documents = cursor["nextBatch"] + self.__data = deque(documents) + self.__retrieved += len(documents) + else: + self.__id = 0 + self.__data = deque(docs) + self.__retrieved += len(docs) + else: + assert isinstance(response.data, _OpReply) + self.__id = response.data.cursor_id + self.__data = deque(docs) + self.__retrieved += response.data.number_returned + + if self.__id == 0: + # Don't wait for garbage collection to call __del__, return the + # socket and the session to the pool now. + self.close() + + if self.__limit and self.__id and self.__limit <= self.__retrieved: + self.close() + + def _unpack_response( + self, + response: Union[_OpReply, _OpMsg], + cursor_id: Optional[int], + codec_options: CodecOptions, + user_fields: Optional[Mapping[str, Any]] = None, + legacy_response: bool = False, + ) -> Sequence[_DocumentOut]: + return response.unpack_response(cursor_id, codec_options, user_fields, legacy_response) + + def _read_preference(self) -> _ServerMode: + if self.__read_preference is None: + # Save the read preference for getMore commands. + self.__read_preference = self.__collection._read_preference_for(self.session) + return self.__read_preference + + def _refresh(self) -> int: + """Refreshes the cursor with more data from Mongo. + + Returns the length of self.__data after refresh. Will exit early if + self.__data is already non-empty. Raises OperationFailure when the + cursor cannot be refreshed due to an error on the query. 
+ """ + if len(self.__data) or self.__killed: + return len(self.__data) + + if not self.__session: + self.__session = self.__collection.database.client._ensure_session() + + if self.__id is None: # Query + if (self.__min or self.__max) and not self.__hint: + raise InvalidOperation( + "Passing a 'hint' is required when using the min/max query" + " option to ensure the query utilizes the correct index" + ) + q = self._query_class( + self.__query_flags, + self.__collection.database.name, + self.__collection.name, + self.__skip, + self.__query_spec(), + self.__projection, + self.__codec_options, + self._read_preference(), + self.__limit, + self.__batch_size, + self.__read_concern, + self.__collation, + self.__session, + self.__collection.database.client, + self.__allow_disk_use, + self.__exhaust, + ) + self.__send_message(q) + elif self.__id: # Get More + if self.__limit: + limit = self.__limit - self.__retrieved + if self.__batch_size: + limit = min(limit, self.__batch_size) + else: + limit = self.__batch_size + # Exhaust cursors don't send getMore messages. + g = self._getmore_class( + self.__dbname, + self.__collname, + limit, + self.__id, + self.__codec_options, + self._read_preference(), + self.__session, + self.__collection.database.client, + self.__max_await_time_ms, + self.__sock_mgr, + self.__exhaust, + self.__comment, + ) + self.__send_message(g) + + return len(self.__data) + + @property + def alive(self) -> bool: + """Does this cursor have the potential to return more data? + + This is mostly useful with `tailable cursors + `_ + since they will stop iterating even though they *may* return more + results in the future. + + With regular cursors, simply use a for loop instead of :attr:`alive`:: + + for doc in collection.find(): + print(doc) + + .. note:: Even if :attr:`alive` is True, :meth:`next` can raise + :exc:`StopIteration`. :attr:`alive` can also be True while iterating + a cursor from a failed server. In this case :attr:`alive` will + return False after :meth:`next` fails to retrieve the next batch + of results from the server. + """ + return bool(len(self.__data) or (not self.__killed)) + + @property + def cursor_id(self) -> Optional[int]: + """Returns the id of the cursor + + .. versionadded:: 2.2 + """ + return self.__id + + @property + def address(self) -> Optional[tuple[str, Any]]: + """The (host, port) of the server used, or None. + + .. versionchanged:: 3.0 + Renamed from "conn_id". + """ + return self.__address + + @property + def session(self) -> Optional[ClientSession]: + """The cursor's :class:`~pymongo.client_session.ClientSession`, or None. + + .. versionadded:: 3.6 + """ + if self.__explicit_session: + return self.__session + return None + + def __iter__(self) -> Cursor[_DocumentType]: + return self + + def next(self) -> _DocumentType: + """Advance the cursor.""" + if self.__empty: + raise StopIteration + if len(self.__data) or self._refresh(): + return self.__data.popleft() + else: + raise StopIteration + + __next__ = next + + def __enter__(self) -> Cursor[_DocumentType]: + return self + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + self.close() + + def __copy__(self) -> Cursor[_DocumentType]: + """Support function for `copy.copy()`. + + .. versionadded:: 2.4 + """ + return self._clone(deepcopy=False) + + def __deepcopy__(self, memo: Any) -> Any: + """Support function for `copy.deepcopy()`. + + .. 
versionadded:: 2.4 + """ + return self._clone(deepcopy=True) + + @overload + def _deepcopy(self, x: Iterable, memo: Optional[dict[int, Union[list, dict]]] = None) -> list: + ... + + @overload + def _deepcopy( + self, x: SupportsItems, memo: Optional[dict[int, Union[list, dict]]] = None + ) -> dict: + ... + + def _deepcopy( + self, x: Union[Iterable, SupportsItems], memo: Optional[dict[int, Union[list, dict]]] = None + ) -> Union[list, dict]: + """Deepcopy helper for the data dictionary or list. + + Regular expressions cannot be deep copied but as they are immutable we + don't have to copy them when cloning. + """ + y: Union[list, dict] + iterator: Iterable[tuple[Any, Any]] + if not hasattr(x, "items"): + y, is_list, iterator = [], True, enumerate(x) + else: + y, is_list, iterator = {}, False, cast("SupportsItems", x).items() + if memo is None: + memo = {} + val_id = id(x) + if val_id in memo: + return memo[val_id] + memo[val_id] = y + + for key, value in iterator: + if isinstance(value, (dict, list)) and not isinstance(value, SON): + value = self._deepcopy(value, memo) # noqa: PLW2901 + elif not isinstance(value, RE_TYPE): + value = copy.deepcopy(value, memo) # noqa: PLW2901 + + if is_list: + y.append(value) # type: ignore[union-attr] + else: + if not isinstance(key, RE_TYPE): + key = copy.deepcopy(key, memo) # noqa: PLW2901 + y[key] = value + return y + + +class RawBatchCursor(Cursor, Generic[_DocumentType]): + """A cursor / iterator over raw batches of BSON data from a query result.""" + + _query_class = _RawBatchQuery + _getmore_class = _RawBatchGetMore + + def __init__(self, collection: Collection[_DocumentType], *args: Any, **kwargs: Any) -> None: + """Create a new cursor / iterator over raw batches of BSON data. + + Should not be called directly by application developers - + see :meth:`~pymongo.collection.Collection.find_raw_batches` + instead. + + .. seealso:: The MongoDB documentation on `cursors `_. + """ + super().__init__(collection, *args, **kwargs) + + def _unpack_response( + self, + response: Union[_OpReply, _OpMsg], + cursor_id: Optional[int], + codec_options: CodecOptions[Mapping[str, Any]], + user_fields: Optional[Mapping[str, Any]] = None, + legacy_response: bool = False, + ) -> list[_DocumentOut]: + raw_response = response.raw_response(cursor_id, user_fields=user_fields) + if not legacy_response: + # OP_MSG returns firstBatch/nextBatch documents as a BSON array + # Re-assemble the array of documents into a document stream + _convert_raw_document_lists_to_streams(raw_response[0]) + return cast(List["_DocumentOut"], raw_response) + + def explain(self) -> _DocumentType: + """Returns an explain plan record for this cursor. + + .. seealso:: The MongoDB documentation on `explain `_. + """ + clone = self._clone(deepcopy=True, base=Cursor(self.collection)) + return clone.explain() + + def __getitem__(self, index: Any) -> NoReturn: + raise InvalidOperation("Cannot call __getitem__ on RawBatchCursor") diff --git a/venv/Lib/site-packages/pymongo/daemon.py b/venv/Lib/site-packages/pymongo/daemon.py new file mode 100644 index 00000000..b40384df --- /dev/null +++ b/venv/Lib/site-packages/pymongo/daemon.py @@ -0,0 +1,148 @@ +# Copyright 2019-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Support for spawning a daemon process. + +PyMongo only attempts to spawn the mongocryptd daemon process when automatic +client-side field level encryption is enabled. See +:ref:`automatic-client-side-encryption` for more info. +""" +from __future__ import annotations + +import os +import subprocess +import sys +import warnings +from typing import Any, Optional, Sequence + +# The maximum amount of time to wait for the intermediate subprocess. +_WAIT_TIMEOUT = 10 +_THIS_FILE = os.path.realpath(__file__) + + +def _popen_wait(popen: subprocess.Popen[Any], timeout: Optional[float]) -> Optional[int]: + """Implement wait timeout support for Python 3.""" + try: + return popen.wait(timeout=timeout) + except subprocess.TimeoutExpired: + # Silence TimeoutExpired errors. + return None + + +def _silence_resource_warning(popen: Optional[subprocess.Popen[Any]]) -> None: + """Silence Popen's ResourceWarning. + + Note this should only be used if the process was created as a daemon. + """ + # Set the returncode to avoid this warning when popen is garbage collected: + # "ResourceWarning: subprocess XXX is still running". + # See https://bugs.python.org/issue38890 and + # https://bugs.python.org/issue26741. + # popen is None when mongocryptd spawning fails + if popen is not None: + popen.returncode = 0 + + +if sys.platform == "win32": + # On Windows we spawn the daemon process simply by using DETACHED_PROCESS. + _DETACHED_PROCESS = getattr(subprocess, "DETACHED_PROCESS", 0x00000008) + + def _spawn_daemon(args: Sequence[str]) -> None: + """Spawn a daemon process (Windows).""" + try: + with open(os.devnull, "r+b") as devnull: + popen = subprocess.Popen( + args, # noqa: S603 + creationflags=_DETACHED_PROCESS, + stdin=devnull, + stderr=devnull, + stdout=devnull, + ) + _silence_resource_warning(popen) + except FileNotFoundError as exc: + warnings.warn( + f"Failed to start {args[0]}: is it on your $PATH?\nOriginal exception: {exc}", + RuntimeWarning, + stacklevel=2, + ) + +else: + # On Unix we spawn the daemon process with a double Popen. + # 1) The first Popen runs this file as a Python script using the current + # interpreter. + # 2) The script then decouples itself and performs the second Popen to + # spawn the daemon process. + # 3) The original process waits up to 10 seconds for the script to exit. + # + # Note that we do not call fork() directly because we want this procedure + # to be safe to call from any thread. Using Popen instead of fork also + # avoids triggering the application's os.register_at_fork() callbacks when + # we spawn the mongocryptd daemon process. 
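To make the procedure above concrete, here is a simplified, standalone
analogue of the double-Popen trick (editor-added; not pymongo's code).
``sleep 60`` stands in for the daemon and the ``--spawner`` flag is an
invented marker for the intermediate child::

    import os
    import subprocess
    import sys

    if __name__ == "__main__" and sys.argv[1:2] == ["--spawner"]:
        # Intermediate child: detach from the parent's session if we can,
        # start the long-running process, and exit right away so the
        # parent can reap us and no zombie is left behind.
        if hasattr(os, "setsid"):
            try:
                os.setsid()
            except OSError:
                pass
        subprocess.Popen(sys.argv[2:], close_fds=True)
        os._exit(0)

    if __name__ == "__main__":
        # Parent: re-run this script as the intermediate and reap it with
        # a bounded wait, mirroring _spawn_daemon_double_popen below.
        child = subprocess.Popen([sys.executable, __file__, "--spawner", "sleep", "60"])
        try:
            child.wait(timeout=10)
        except subprocess.TimeoutExpired:
            pass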
+ def _spawn(args: Sequence[str]) -> Optional[subprocess.Popen[Any]]: + """Spawn the process and silence stdout/stderr.""" + try: + with open(os.devnull, "r+b") as devnull: + return subprocess.Popen( + args, # noqa: S603 + close_fds=True, + stdin=devnull, + stderr=devnull, + stdout=devnull, + ) + except FileNotFoundError as exc: + warnings.warn( + f"Failed to start {args[0]}: is it on your $PATH?\nOriginal exception: {exc}", + RuntimeWarning, + stacklevel=2, + ) + return None + + def _spawn_daemon_double_popen(args: Sequence[str]) -> None: + """Spawn a daemon process using a double subprocess.Popen.""" + spawner_args = [sys.executable, _THIS_FILE] + spawner_args.extend(args) + temp_proc = subprocess.Popen(spawner_args, close_fds=True) # noqa: S603 + # Reap the intermediate child process to avoid creating zombie + # processes. + _popen_wait(temp_proc, _WAIT_TIMEOUT) + + def _spawn_daemon(args: Sequence[str]) -> None: + """Spawn a daemon process (Unix).""" + # "If Python is unable to retrieve the real path to its executable, + # sys.executable will be an empty string or None". + if sys.executable: + _spawn_daemon_double_popen(args) + else: + # Fallback to spawn a non-daemon process without silencing the + # resource warning. We do not use fork here because it is not + # safe to call from a thread on all systems. + # Unfortunately, this means that: + # 1) If the parent application is killed via Ctrl-C, the + # non-daemon process will also be killed. + # 2) Each non-daemon process will hang around as a zombie process + # until the main application exits. + _spawn(args) + + if __name__ == "__main__": + # Attempt to start a new session to decouple from the parent. + if hasattr(os, "setsid"): + try: + os.setsid() + except OSError: + pass + + # We are performing a double fork (Popen) to spawn the process as a + # daemon so it is safe to ignore the resource warning. + _silence_resource_warning(_spawn(sys.argv[1:])) + os._exit(0) diff --git a/venv/Lib/site-packages/pymongo/database.py b/venv/Lib/site-packages/pymongo/database.py new file mode 100644 index 00000000..70580694 --- /dev/null +++ b/venv/Lib/site-packages/pymongo/database.py @@ -0,0 +1,1388 @@ +# Copyright 2009-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Database level operations.""" +from __future__ import annotations + +from copy import deepcopy +from typing import ( + TYPE_CHECKING, + Any, + Generic, + Mapping, + MutableMapping, + NoReturn, + Optional, + Sequence, + TypeVar, + Union, + cast, + overload, +) + +from bson.codec_options import DEFAULT_CODEC_OPTIONS, CodecOptions +from bson.dbref import DBRef +from bson.timestamp import Timestamp +from pymongo import _csot, common +from pymongo.aggregation import _DatabaseAggregationCommand +from pymongo.change_stream import DatabaseChangeStream +from pymongo.collection import Collection +from pymongo.command_cursor import CommandCursor +from pymongo.common import _ecoc_coll_name, _esc_coll_name +from pymongo.errors import CollectionInvalid, InvalidName, InvalidOperation +from pymongo.operations import _Op +from pymongo.read_preferences import ReadPreference, _ServerMode +from pymongo.typings import _CollationIn, _DocumentType, _DocumentTypeArg, _Pipeline + +if TYPE_CHECKING: + import bson + import bson.codec_options + from pymongo.client_session import ClientSession + from pymongo.mongo_client import MongoClient + from pymongo.pool import Connection + from pymongo.read_concern import ReadConcern + from pymongo.server import Server + from pymongo.write_concern import WriteConcern + + +def _check_name(name: str) -> None: + """Check if a database name is valid.""" + if not name: + raise InvalidName("database name cannot be the empty string") + + for invalid_char in [" ", ".", "$", "/", "\\", "\x00", '"']: + if invalid_char in name: + raise InvalidName("database names cannot contain the character %r" % invalid_char) + + +_CodecDocumentType = TypeVar("_CodecDocumentType", bound=Mapping[str, Any]) + + +class Database(common.BaseObject, Generic[_DocumentType]): + """A Mongo database.""" + + def __init__( + self, + client: MongoClient[_DocumentType], + name: str, + codec_options: Optional[CodecOptions[_DocumentTypeArg]] = None, + read_preference: Optional[_ServerMode] = None, + write_concern: Optional[WriteConcern] = None, + read_concern: Optional[ReadConcern] = None, + ) -> None: + """Get a database by client and name. + + Raises :class:`TypeError` if `name` is not an instance of + :class:`str`. Raises :class:`~pymongo.errors.InvalidName` if + `name` is not a valid database name. + + :param client: A :class:`~pymongo.mongo_client.MongoClient` instance. + :param name: The database name. + :param codec_options: An instance of + :class:`~bson.codec_options.CodecOptions`. If ``None`` (the + default) client.codec_options is used. + :param read_preference: The read preference to use. If + ``None`` (the default) client.read_preference is used. + :param write_concern: An instance of + :class:`~pymongo.write_concern.WriteConcern`. If ``None`` (the + default) client.write_concern is used. + :param read_concern: An instance of + :class:`~pymongo.read_concern.ReadConcern`. If ``None`` (the + default) client.read_concern is used. + + .. seealso:: The MongoDB documentation on `databases `_. + + .. versionchanged:: 4.0 + Removed the eval, system_js, error, last_status, previous_error, + reset_error_history, authenticate, logout, collection_names, + current_op, add_user, remove_user, profiling_level, + set_profiling_level, and profiling_info methods. + See the :ref:`pymongo4-migration-guide`. + + .. versionchanged:: 3.2 + Added the read_concern option. + + .. versionchanged:: 3.0 + Added the codec_options, read_preference, and write_concern options. 
+ :class:`~pymongo.database.Database` no longer returns an instance + of :class:`~pymongo.collection.Collection` for attribute names + with leading underscores. You must use dict-style lookups instead:: + + db['__my_collection__'] + + Not: + + db.__my_collection__ + """ + super().__init__( + codec_options or client.codec_options, + read_preference or client.read_preference, + write_concern or client.write_concern, + read_concern or client.read_concern, + ) + + if not isinstance(name, str): + raise TypeError("name must be an instance of str") + + if name != "$external": + _check_name(name) + + self.__name = name + self.__client: MongoClient[_DocumentType] = client + self._timeout = client.options.timeout + + @property + def client(self) -> MongoClient[_DocumentType]: + """The client instance for this :class:`Database`.""" + return self.__client + + @property + def name(self) -> str: + """The name of this :class:`Database`.""" + return self.__name + + def with_options( + self, + codec_options: Optional[CodecOptions[_DocumentTypeArg]] = None, + read_preference: Optional[_ServerMode] = None, + write_concern: Optional[WriteConcern] = None, + read_concern: Optional[ReadConcern] = None, + ) -> Database[_DocumentType]: + """Get a clone of this database changing the specified settings. + + >>> db1.read_preference + Primary() + >>> from pymongo.read_preferences import Secondary + >>> db2 = db1.with_options(read_preference=Secondary([{'node': 'analytics'}])) + >>> db1.read_preference + Primary() + >>> db2.read_preference + Secondary(tag_sets=[{'node': 'analytics'}], max_staleness=-1, hedge=None) + + :param codec_options: An instance of + :class:`~bson.codec_options.CodecOptions`. If ``None`` (the + default) the :attr:`codec_options` of this :class:`Collection` + is used. + :param read_preference: The read preference to use. If + ``None`` (the default) the :attr:`read_preference` of this + :class:`Collection` is used. See :mod:`~pymongo.read_preferences` + for options. + :param write_concern: An instance of + :class:`~pymongo.write_concern.WriteConcern`. If ``None`` (the + default) the :attr:`write_concern` of this :class:`Collection` + is used. + :param read_concern: An instance of + :class:`~pymongo.read_concern.ReadConcern`. If ``None`` (the + default) the :attr:`read_concern` of this :class:`Collection` + is used. + + .. versionadded:: 3.8 + """ + return Database( + self.client, + self.__name, + codec_options or self.codec_options, + read_preference or self.read_preference, + write_concern or self.write_concern, + read_concern or self.read_concern, + ) + + def __eq__(self, other: Any) -> bool: + if isinstance(other, Database): + return self.__client == other.client and self.__name == other.name + return NotImplemented + + def __ne__(self, other: Any) -> bool: + return not self == other + + def __hash__(self) -> int: + return hash((self.__client, self.__name)) + + def __repr__(self) -> str: + return f"Database({self.__client!r}, {self.__name!r})" + + def __getattr__(self, name: str) -> Collection[_DocumentType]: + """Get a collection of this database by name. + + Raises InvalidName if an invalid collection name is used. + + :param name: the name of the collection to get + """ + if name.startswith("_"): + raise AttributeError( + f"Database has no attribute {name!r}. To access the {name}" + f" collection, use database[{name!r}]." + ) + return self.__getitem__(name) + + def __getitem__(self, name: str) -> Collection[_DocumentType]: + """Get a collection of this database by name. 
+ + Raises InvalidName if an invalid collection name is used. + + :param name: the name of the collection to get + """ + return Collection(self, name) + + def get_collection( + self, + name: str, + codec_options: Optional[CodecOptions[_DocumentTypeArg]] = None, + read_preference: Optional[_ServerMode] = None, + write_concern: Optional[WriteConcern] = None, + read_concern: Optional[ReadConcern] = None, + ) -> Collection[_DocumentType]: + """Get a :class:`~pymongo.collection.Collection` with the given name + and options. + + Useful for creating a :class:`~pymongo.collection.Collection` with + different codec options, read preference, and/or write concern from + this :class:`Database`. + + >>> db.read_preference + Primary() + >>> coll1 = db.test + >>> coll1.read_preference + Primary() + >>> from pymongo import ReadPreference + >>> coll2 = db.get_collection( + ... 'test', read_preference=ReadPreference.SECONDARY) + >>> coll2.read_preference + Secondary(tag_sets=None) + + :param name: The name of the collection - a string. + :param codec_options: An instance of + :class:`~bson.codec_options.CodecOptions`. If ``None`` (the + default) the :attr:`codec_options` of this :class:`Database` is + used. + :param read_preference: The read preference to use. If + ``None`` (the default) the :attr:`read_preference` of this + :class:`Database` is used. See :mod:`~pymongo.read_preferences` + for options. + :param write_concern: An instance of + :class:`~pymongo.write_concern.WriteConcern`. If ``None`` (the + default) the :attr:`write_concern` of this :class:`Database` is + used. + :param read_concern: An instance of + :class:`~pymongo.read_concern.ReadConcern`. If ``None`` (the + default) the :attr:`read_concern` of this :class:`Database` is + used. + """ + return Collection( + self, + name, + False, + codec_options, + read_preference, + write_concern, + read_concern, + ) + + def _get_encrypted_fields( + self, kwargs: Mapping[str, Any], coll_name: str, ask_db: bool + ) -> Optional[Mapping[str, Any]]: + encrypted_fields = kwargs.get("encryptedFields") + if encrypted_fields: + return cast(Mapping[str, Any], deepcopy(encrypted_fields)) + if ( + self.client.options.auto_encryption_opts + and self.client.options.auto_encryption_opts._encrypted_fields_map + and self.client.options.auto_encryption_opts._encrypted_fields_map.get( + f"{self.name}.{coll_name}" + ) + ): + return cast( + Mapping[str, Any], + deepcopy( + self.client.options.auto_encryption_opts._encrypted_fields_map[ + f"{self.name}.{coll_name}" + ] + ), + ) + if ask_db and self.client.options.auto_encryption_opts: + options = self[coll_name].options() + if options.get("encryptedFields"): + return cast(Mapping[str, Any], deepcopy(options["encryptedFields"])) + return None + + @_csot.apply + def create_collection( + self, + name: str, + codec_options: Optional[CodecOptions[_DocumentTypeArg]] = None, + read_preference: Optional[_ServerMode] = None, + write_concern: Optional[WriteConcern] = None, + read_concern: Optional[ReadConcern] = None, + session: Optional[ClientSession] = None, + check_exists: Optional[bool] = True, + **kwargs: Any, + ) -> Collection[_DocumentType]: + """Create a new :class:`~pymongo.collection.Collection` in this + database. + + Normally collection creation is automatic. This method should + only be used to specify options on + creation. :class:`~pymongo.errors.CollectionInvalid` will be + raised if the collection already exists. 
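For example (a minimal sketch, assuming a local mongod; the collection
name ``log`` is arbitrary)::

    from pymongo import MongoClient
    from pymongo.errors import CollectionInvalid

    client = MongoClient()
    try:
        log = client.test.create_collection("log", capped=True, size=1024 * 1024)
    except CollectionInvalid:
        log = client.test["log"]  # it already exists; reuse it as-is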
+ + :param name: the name of the collection to create + :param codec_options: An instance of + :class:`~bson.codec_options.CodecOptions`. If ``None`` (the + default) the :attr:`codec_options` of this :class:`Database` is + used. + :param read_preference: The read preference to use. If + ``None`` (the default) the :attr:`read_preference` of this + :class:`Database` is used. + :param write_concern: An instance of + :class:`~pymongo.write_concern.WriteConcern`. If ``None`` (the + default) the :attr:`write_concern` of this :class:`Database` is + used. + :param read_concern: An instance of + :class:`~pymongo.read_concern.ReadConcern`. If ``None`` (the + default) the :attr:`read_concern` of this :class:`Database` is + used. + :param collation: An instance of + :class:`~pymongo.collation.Collation`. + :param session: a + :class:`~pymongo.client_session.ClientSession`. + :param `check_exists`: if True (the default), send a listCollections command to + check if the collection already exists before creation. + :param kwargs: additional keyword arguments will + be passed as options for the `create collection command`_ + + All optional `create collection command`_ parameters should be passed + as keyword arguments to this method. Valid options include, but are not + limited to: + + - ``size`` (int): desired initial size for the collection (in + bytes). For capped collections this size is the max + size of the collection. + - ``capped`` (bool): if True, this is a capped collection + - ``max`` (int): maximum number of objects if capped (optional) + - ``timeseries`` (dict): a document specifying configuration options for + timeseries collections + - ``expireAfterSeconds`` (int): the number of seconds after which a + document in a timeseries collection expires + - ``validator`` (dict): a document specifying validation rules or expressions + for the collection + - ``validationLevel`` (str): how strictly to apply the + validation rules to existing documents during an update. The default level + is "strict" + - ``validationAction`` (str): whether to "error" on invalid documents + (the default) or just "warn" about the violations but allow invalid + documents to be inserted + - ``indexOptionDefaults`` (dict): a document specifying a default configuration + for indexes when creating a collection + - ``viewOn`` (str): the name of the source collection or view from which + to create the view + - ``pipeline`` (list): a list of aggregation pipeline stages + - ``comment`` (str): a user-provided comment to attach to this command. + This option is only supported on MongoDB >= 4.4. + - ``encryptedFields`` (dict): **(BETA)** Document that describes the encrypted fields for + Queryable Encryption. For example:: + + { + "escCollection": "enxcol_.encryptedCollection.esc", + "ecocCollection": "enxcol_.encryptedCollection.ecoc", + "fields": [ + { + "path": "firstName", + "keyId": Binary.from_uuid(UUID('00000000-0000-0000-0000-000000000000')), + "bsonType": "string", + "queries": {"queryType": "equality"} + }, + { + "path": "ssn", + "keyId": Binary.from_uuid(UUID('04104104-1041-0410-4104-104104104104')), + "bsonType": "string" + } + ] + } + - ``clusteredIndex`` (dict): Document that specifies the clustered index + configuration. 
It must have the following form:: + + { + // key pattern must be {_id: 1} + key: <key pattern>, // required + unique: <bool>, // required, must be `true` + name: <string>, // optional, otherwise automatically generated + v: <int>, // optional, must be `2` if provided + } + - ``changeStreamPreAndPostImages`` (dict): a document with a boolean field ``enabled`` for + enabling pre- and post-images. + + .. versionchanged:: 4.2 + Added the ``check_exists``, ``clusteredIndex``, and ``encryptedFields`` parameters. + + .. versionchanged:: 3.11 + This method is now supported inside multi-document transactions + with MongoDB 4.4+. + + .. versionchanged:: 3.6 + Added ``session`` parameter. + + .. versionchanged:: 3.4 + Added the collation option. + + .. versionchanged:: 3.0 + Added the codec_options, read_preference, and write_concern options. + + .. _create collection command: + https://mongodb.com/docs/manual/reference/command/create + """ + encrypted_fields = self._get_encrypted_fields(kwargs, name, False) + if encrypted_fields: + common.validate_is_mapping("encryptedFields", encrypted_fields) + kwargs["encryptedFields"] = encrypted_fields + + clustered_index = kwargs.get("clusteredIndex") + if clustered_index: + common.validate_is_mapping("clusteredIndex", clustered_index) + + with self.__client._tmp_session(session) as s: + # Skip this check in a transaction where listCollections is not + # supported. + if ( + check_exists + and (not s or not s.in_transaction) + and name in self.list_collection_names(filter={"name": name}, session=s) + ): + raise CollectionInvalid("collection %s already exists" % name) + return Collection( + self, + name, + True, + codec_options, + read_preference, + write_concern, + read_concern, + session=s, + **kwargs, + ) + + def aggregate( + self, pipeline: _Pipeline, session: Optional[ClientSession] = None, **kwargs: Any + ) -> CommandCursor[_DocumentType]: + """Perform a database-level aggregation. + + See the `aggregation pipeline`_ documentation for a list of stages + that are supported. + + .. code-block:: python + + # Lists all operations currently running on the server. + with client.admin.aggregate([{"$currentOp": {}}]) as cursor: + for operation in cursor: + print(operation) + + The :meth:`aggregate` method obeys the :attr:`read_preference` of this + :class:`Database`, except when ``$out`` or ``$merge`` are used, in + which case :attr:`~pymongo.read_preferences.ReadPreference.PRIMARY` + is used. + + .. note:: This method does not support the 'explain' option. Please + use :meth:`~pymongo.database.Database.command` instead. + + .. note:: The :attr:`~pymongo.database.Database.write_concern` of + this collection is automatically applied to this operation. + + :param pipeline: a list of aggregation pipeline stages + :param session: a + :class:`~pymongo.client_session.ClientSession`. + :param kwargs: extra `aggregate command`_ parameters. + + All optional `aggregate command`_ parameters should be passed as + keyword arguments to this method. Valid options include, but are not + limited to: + + - `allowDiskUse` (bool): Enables writing to temporary files. When set + to True, aggregation stages can write data to the _tmp subdirectory + of the --dbpath directory. The default is False. + - `maxTimeMS` (int): The maximum amount of time to allow the operation + to run in milliseconds. + - `batchSize` (int): The maximum number of documents to return per + batch. Ignored if the connected mongod or mongos does not support + returning aggregate results using a cursor. 
+ - `collation` (optional): An instance of + :class:`~pymongo.collation.Collation`. + - `let` (dict): A dict of parameter names and values. Values must be + constant or closed expressions that do not reference document + fields. Parameters can then be accessed as variables in an + aggregate expression context (e.g. ``"$$var"``). This option is + only supported on MongoDB >= 5.0. + + :return: A :class:`~pymongo.command_cursor.CommandCursor` over the result + set. + + .. versionadded:: 3.9 + + .. _aggregation pipeline: + https://mongodb.com/docs/manual/reference/operator/aggregation-pipeline + + .. _aggregate command: + https://mongodb.com/docs/manual/reference/command/aggregate + """ + with self.client._tmp_session(session, close=False) as s: + cmd = _DatabaseAggregationCommand( + self, + CommandCursor, + pipeline, + kwargs, + session is not None, + user_fields={"cursor": {"firstBatch": 1}}, + ) + return self.client._retryable_read( + cmd.get_cursor, + cmd.get_read_preference(s), # type: ignore[arg-type] + s, + retryable=not cmd._performs_write, + operation=_Op.AGGREGATE, + ) + + def watch( + self, + pipeline: Optional[_Pipeline] = None, + full_document: Optional[str] = None, + resume_after: Optional[Mapping[str, Any]] = None, + max_await_time_ms: Optional[int] = None, + batch_size: Optional[int] = None, + collation: Optional[_CollationIn] = None, + start_at_operation_time: Optional[Timestamp] = None, + session: Optional[ClientSession] = None, + start_after: Optional[Mapping[str, Any]] = None, + comment: Optional[Any] = None, + full_document_before_change: Optional[str] = None, + show_expanded_events: Optional[bool] = None, + ) -> DatabaseChangeStream[_DocumentType]: + """Watch changes on this database. + + Performs an aggregation with an implicit initial ``$changeStream`` + stage and returns a + :class:`~pymongo.change_stream.DatabaseChangeStream` cursor which + iterates over changes on all collections in this database. + + Introduced in MongoDB 4.0. + + .. code-block:: python + + with db.watch() as stream: + for change in stream: + print(change) + + The :class:`~pymongo.change_stream.DatabaseChangeStream` iterable + blocks until the next change document is returned or an error is + raised. If the + :meth:`~pymongo.change_stream.DatabaseChangeStream.next` method + encounters a network error when retrieving a batch from the server, + it will automatically attempt to recreate the cursor such that no + change events are missed. Any error encountered during the resume + attempt indicates there may be an outage and will be raised. + + .. code-block:: python + + try: + with db.watch([{"$match": {"operationType": "insert"}}]) as stream: + for insert_change in stream: + print(insert_change) + except pymongo.errors.PyMongoError: + # The ChangeStream encountered an unrecoverable error or the + # resume attempt failed to recreate the cursor. + logging.error("...") + + For a precise description of the resume process see the + `change streams specification`_. + + :param pipeline: A list of aggregation pipeline stages to + append to an initial ``$changeStream`` stage. Not all + pipeline stages are valid after a ``$changeStream`` stage, see the + MongoDB documentation on change streams for the supported stages. + :param full_document: The fullDocument to pass as an option + to the ``$changeStream`` stage. Allowed values: 'updateLookup', + 'whenAvailable', 'required'. 
When set to 'updateLookup', the + change notification for partial updates will include both a delta + describing the changes to the document, as well as a copy of the + entire document that was changed from some time after the change + occurred. + :param full_document_before_change: Allowed values: 'whenAvailable' + and 'required'. Change events may now result in a + 'fullDocumentBeforeChange' response field. + :param resume_after: A resume token. If provided, the + change stream will start returning changes that occur directly + after the operation specified in the resume token. A resume token + is the _id value of a change document. + :param max_await_time_ms: The maximum time in milliseconds + for the server to wait for changes before responding to a getMore + operation. + :param batch_size: The maximum number of documents to return + per batch. + :param collation: The :class:`~pymongo.collation.Collation` + to use for the aggregation. + :param start_at_operation_time: If provided, the resulting + change stream will only return changes that occurred at or after + the specified :class:`~bson.timestamp.Timestamp`. Requires + MongoDB >= 4.0. + :param session: a + :class:`~pymongo.client_session.ClientSession`. + :param start_after: The same as `resume_after` except that + `start_after` can resume notifications after an invalidate event. + This option and `resume_after` are mutually exclusive. + :param comment: A user-provided comment to attach to this + command. + :param show_expanded_events: Include expanded events such as DDL events like `dropIndexes`. + + :return: A :class:`~pymongo.change_stream.DatabaseChangeStream` cursor. + + .. versionchanged:: 4.3 + Added `show_expanded_events` parameter. + + .. versionchanged:: 4.2 + Added ``full_document_before_change`` parameter. + + .. versionchanged:: 4.1 + Added ``comment`` parameter. + + .. versionchanged:: 3.9 + Added the ``start_after`` parameter. + + .. versionadded:: 3.7 + + .. seealso:: The MongoDB documentation on `changeStreams <https://mongodb.com/docs/manual/changeStreams/>`_. + + .. _change streams specification: + https://github.com/mongodb/specifications/blob/master/source/change-streams/change-streams.md + """ + return DatabaseChangeStream( + self, + pipeline, + full_document, + resume_after, + max_await_time_ms, + batch_size, + collation, + start_at_operation_time, + session, + start_after, + comment, + full_document_before_change, + show_expanded_events=show_expanded_events, + ) + + @overload + def _command( + self, + conn: Connection, + command: Union[str, MutableMapping[str, Any]], + value: int = 1, + check: bool = True, + allowable_errors: Optional[Sequence[Union[str, int]]] = None, + read_preference: _ServerMode = ReadPreference.PRIMARY, + codec_options: CodecOptions[dict[str, Any]] = DEFAULT_CODEC_OPTIONS, + write_concern: Optional[WriteConcern] = None, + parse_write_concern_error: bool = False, + session: Optional[ClientSession] = None, + **kwargs: Any, + ) -> dict[str, Any]: + ... + + @overload + def _command( + self, + conn: Connection, + command: Union[str, MutableMapping[str, Any]], + value: int = 1, + check: bool = True, + allowable_errors: Optional[Sequence[Union[str, int]]] = None, + read_preference: _ServerMode = ReadPreference.PRIMARY, + codec_options: CodecOptions[_CodecDocumentType] = ..., + write_concern: Optional[WriteConcern] = None, + parse_write_concern_error: bool = False, + session: Optional[ClientSession] = None, + **kwargs: Any, + ) -> _CodecDocumentType: + ...
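The paired ``@overload`` declarations above exist for static typing: with the default codec options ``_command`` (and the public ``command`` below) returns a plain ``dict[str, Any]``, while a typed :class:`~bson.codec_options.CodecOptions` selects the overload returning ``_CodecDocumentType``. A minimal caller-side sketch of the pattern, assuming a reachable deployment at the default address::

    from bson.codec_options import CodecOptions
    from bson.raw_bson import RawBSONDocument
    from pymongo import MongoClient

    client = MongoClient()  # assumes a local mongod
    db = client.admin

    # Default codec options: type checkers infer dict[str, Any].
    info = db.command("ping")

    # Typed codec options: the generic overload applies, so the reply is
    # inferred (and returned) as RawBSONDocument, left undecoded.
    raw = db.command("ping", codec_options=CodecOptions(document_class=RawBSONDocument))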
+ + def _command( + self, + conn: Connection, + command: Union[str, MutableMapping[str, Any]], + value: int = 1, + check: bool = True, + allowable_errors: Optional[Sequence[Union[str, int]]] = None, + read_preference: _ServerMode = ReadPreference.PRIMARY, + codec_options: Union[ + CodecOptions[dict[str, Any]], CodecOptions[_CodecDocumentType] + ] = DEFAULT_CODEC_OPTIONS, + write_concern: Optional[WriteConcern] = None, + parse_write_concern_error: bool = False, + session: Optional[ClientSession] = None, + **kwargs: Any, + ) -> Union[dict[str, Any], _CodecDocumentType]: + """Internal command helper.""" + if isinstance(command, str): + command = {command: value} + + command.update(kwargs) + with self.__client._tmp_session(session) as s: + return conn.command( + self.__name, + command, + read_preference, + codec_options, + check, + allowable_errors, + write_concern=write_concern, + parse_write_concern_error=parse_write_concern_error, + session=s, + client=self.__client, + ) + + @overload + def command( + self, + command: Union[str, MutableMapping[str, Any]], + value: Any = 1, + check: bool = True, + allowable_errors: Optional[Sequence[Union[str, int]]] = None, + read_preference: Optional[_ServerMode] = None, + codec_options: None = None, + session: Optional[ClientSession] = None, + comment: Optional[Any] = None, + **kwargs: Any, + ) -> dict[str, Any]: + ... + + @overload + def command( + self, + command: Union[str, MutableMapping[str, Any]], + value: Any = 1, + check: bool = True, + allowable_errors: Optional[Sequence[Union[str, int]]] = None, + read_preference: Optional[_ServerMode] = None, + codec_options: CodecOptions[_CodecDocumentType] = ..., + session: Optional[ClientSession] = None, + comment: Optional[Any] = None, + **kwargs: Any, + ) -> _CodecDocumentType: + ... + + @_csot.apply + def command( + self, + command: Union[str, MutableMapping[str, Any]], + value: Any = 1, + check: bool = True, + allowable_errors: Optional[Sequence[Union[str, int]]] = None, + read_preference: Optional[_ServerMode] = None, + codec_options: Optional[bson.codec_options.CodecOptions[_CodecDocumentType]] = None, + session: Optional[ClientSession] = None, + comment: Optional[Any] = None, + **kwargs: Any, + ) -> Union[dict[str, Any], _CodecDocumentType]: + """Issue a MongoDB command. + + Send command `command` to the database and return the + response. If `command` is an instance of :class:`str` + then the command {`command`: `value`} will be sent. + Otherwise, `command` must be an instance of + :class:`dict` and will be sent as is. + + Any additional keyword arguments will be added to the final + command document before it is sent. + + For example, a command like ``{buildinfo: 1}`` can be sent + using: + + >>> db.command("buildinfo") + OR + >>> db.command({"buildinfo": 1}) + + For a command where the value matters, like ``{count: + collection_name}`` we can do: + + >>> db.command("count", collection_name) + OR + >>> db.command({"count": collection_name}) + + For commands that take additional arguments we can use + kwargs. So ``{count: collection_name, query: query}`` becomes: + + >>> db.command("count", collection_name, query=query) + OR + >>> db.command({"count": collection_name, "query": query}) + + :param command: document representing the command to be issued, + or the name of the command (for simple commands only). + + .. note:: the order of keys in the `command` document is + significant (the "verb" must come first), so commands + which require multiple keys (e.g. 
`findandmodify`) + should be done with this in mind. + + :param value: value to use for the command verb when + `command` is passed as a string + :param check: check the response for errors, raising + :class:`~pymongo.errors.OperationFailure` if there are any + :param allowable_errors: if `check` is ``True``, error messages + in this list will be ignored by error-checking + :param read_preference: The read preference for this + operation. See :mod:`~pymongo.read_preferences` for options. + If the provided `session` is in a transaction, defaults to the + read preference configured for the transaction. + Otherwise, defaults to + :attr:`~pymongo.read_preferences.ReadPreference.PRIMARY`. + :param codec_options: A :class:`~bson.codec_options.CodecOptions` + instance. + :param session: A + :class:`~pymongo.client_session.ClientSession`. + :param comment: A user-provided comment to attach to this + command. + :param kwargs: additional keyword arguments will + be added to the command document before it is sent + + + .. note:: :meth:`command` does **not** obey this Database's + :attr:`read_preference` or :attr:`codec_options`. You must use the + ``read_preference`` and ``codec_options`` parameters instead. + + .. note:: :meth:`command` does **not** apply any custom TypeDecoders + when decoding the command response. + + .. note:: If this client has been configured to use MongoDB Stable + API (see :ref:`versioned-api-ref`), then :meth:`command` will + automatically add API versioning options to the given command. + Explicitly adding API versioning options in the command and + declaring an API version on the client is not supported. + + .. versionchanged:: 3.6 + Added ``session`` parameter. + + .. versionchanged:: 3.0 + Removed the `as_class`, `fields`, `uuid_subtype`, `tag_sets`, + and `secondary_acceptable_latency_ms` options. + Removed `compile_re` option: PyMongo now always represents BSON + regular expressions as :class:`~bson.regex.Regex` objects. Use + :meth:`~bson.regex.Regex.try_compile` to attempt to convert from a + BSON regular expression to a Python regular expression object. + Added the ``codec_options`` parameter. + + .. seealso:: The MongoDB documentation on `commands <https://mongodb.com/docs/manual/reference/command/>`_. + """ + opts = codec_options or DEFAULT_CODEC_OPTIONS + if comment is not None: + kwargs["comment"] = comment + + if isinstance(command, str): + command_name = command + else: + command_name = next(iter(command)) + + if read_preference is None: + read_preference = (session and session._txn_read_preference()) or ReadPreference.PRIMARY + with self.__client._conn_for_reads(read_preference, session, operation=command_name) as ( + connection, + read_preference, + ): + return self._command( + connection, + command, + value, + check, + allowable_errors, + read_preference, + opts, + session=session, + **kwargs, + ) + + @_csot.apply + def cursor_command( + self, + command: Union[str, MutableMapping[str, Any]], + value: Any = 1, + read_preference: Optional[_ServerMode] = None, + codec_options: Optional[bson.codec_options.CodecOptions[_CodecDocumentType]] = None, + session: Optional[ClientSession] = None, + comment: Optional[Any] = None, + max_await_time_ms: Optional[int] = None, + **kwargs: Any, + ) -> CommandCursor[_DocumentType]: + """Issue a MongoDB command and parse the response as a cursor. + + If the response from the server does not include a cursor field, an error will be raised. + + Otherwise, behaves identically to issuing a normal MongoDB command.
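Because ``cursor_command`` requires the server reply to carry a ``cursor`` field, it suits commands such as ``find`` or ``aggregate`` issued in raw form. A hedged sketch, assuming ``db`` is a :class:`Database` and ``events`` is an existing collection::

    # Run a raw `find` command and iterate the reply as a cursor.
    with db.cursor_command("find", "events", filter={"level": "error"}) as cursor:
        for doc in cursor:
            print(doc)

A reply with no cursor field (for example from ``ping``) raises :class:`~pymongo.errors.InvalidOperation`, as the implementation below shows.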
+ + :param command: document representing the command to be issued, + or the name of the command (for simple commands only). + + .. note:: the order of keys in the `command` document is + significant (the "verb" must come first), so commands + which require multiple keys (e.g. `findandmodify`) + should use an instance of :class:`~bson.son.SON` or + a string and kwargs instead of a Python `dict`. + + :param value: value to use for the command verb when + `command` is passed as a string + :param read_preference: The read preference for this + operation. See :mod:`~pymongo.read_preferences` for options. + If the provided `session` is in a transaction, defaults to the + read preference configured for the transaction. + Otherwise, defaults to + :attr:`~pymongo.read_preferences.ReadPreference.PRIMARY`. + :param codec_options: A :class:`~bson.codec_options.CodecOptions` + instance. + :param session: A + :class:`~pymongo.client_session.ClientSession`. + :param comment: A user-provided comment to attach to future getMores for this + command. + :param max_await_time_ms: The number of ms to wait for more data on future getMores for this command. + :param kwargs: additional keyword arguments will + be added to the command document before it is sent + + .. note:: :meth:`command` does **not** obey this Database's + :attr:`read_preference` or :attr:`codec_options`. You must use the + ``read_preference`` and ``codec_options`` parameters instead. + + .. note:: :meth:`command` does **not** apply any custom TypeDecoders + when decoding the command response. + + .. note:: If this client has been configured to use MongoDB Stable + API (see :ref:`versioned-api-ref`), then :meth:`command` will + automatically add API versioning options to the given command. + Explicitly adding API versioning options in the command and + declaring an API version on the client is not supported. + + .. seealso:: The MongoDB documentation on `commands <https://mongodb.com/docs/manual/reference/command/>`_.
+ """ + if isinstance(command, str): + command_name = command + else: + command_name = next(iter(command)) + + with self.__client._tmp_session(session, close=False) as tmp_session: + opts = codec_options or DEFAULT_CODEC_OPTIONS + + if read_preference is None: + read_preference = ( + tmp_session and tmp_session._txn_read_preference() + ) or ReadPreference.PRIMARY + with self.__client._conn_for_reads(read_preference, tmp_session, command_name) as ( + conn, + read_preference, + ): + response = self._command( + conn, + command, + value, + True, + None, + read_preference, + opts, + session=tmp_session, + **kwargs, + ) + coll = self.get_collection("$cmd", read_preference=read_preference) + if response.get("cursor"): + cmd_cursor = CommandCursor( + coll, + response["cursor"], + conn.address, + max_await_time_ms=max_await_time_ms, + session=tmp_session, + explicit_session=session is not None, + comment=comment, + ) + cmd_cursor._maybe_pin_connection(conn) + return cmd_cursor + else: + raise InvalidOperation("Command does not return a cursor.") + + def _retryable_read_command( + self, + command: Union[str, MutableMapping[str, Any]], + operation: str, + session: Optional[ClientSession] = None, + ) -> dict[str, Any]: + """Same as command but used for retryable read commands.""" + read_preference = (session and session._txn_read_preference()) or ReadPreference.PRIMARY + + def _cmd( + session: Optional[ClientSession], + _server: Server, + conn: Connection, + read_preference: _ServerMode, + ) -> dict[str, Any]: + return self._command( + conn, + command, + read_preference=read_preference, + session=session, + ) + + return self.__client._retryable_read(_cmd, read_preference, session, operation) + + def _list_collections( + self, + conn: Connection, + session: Optional[ClientSession], + read_preference: _ServerMode, + **kwargs: Any, + ) -> CommandCursor[MutableMapping[str, Any]]: + """Internal listCollections helper.""" + coll = cast( + Collection[MutableMapping[str, Any]], + self.get_collection("$cmd", read_preference=read_preference), + ) + cmd = {"listCollections": 1, "cursor": {}} + cmd.update(kwargs) + with self.__client._tmp_session(session, close=False) as tmp_session: + cursor = self._command(conn, cmd, read_preference=read_preference, session=tmp_session)[ + "cursor" + ] + cmd_cursor = CommandCursor( + coll, + cursor, + conn.address, + session=tmp_session, + explicit_session=session is not None, + comment=cmd.get("comment"), + ) + cmd_cursor._maybe_pin_connection(conn) + return cmd_cursor + + def list_collections( + self, + session: Optional[ClientSession] = None, + filter: Optional[Mapping[str, Any]] = None, + comment: Optional[Any] = None, + **kwargs: Any, + ) -> CommandCursor[MutableMapping[str, Any]]: + """Get a cursor over the collections of this database. + + :param session: a + :class:`~pymongo.client_session.ClientSession`. + :param filter: A query document to filter the list of + collections returned from the listCollections command. + :param comment: A user-provided comment to attach to this + command. + :param kwargs: Optional parameters of the + `listCollections command + `_ + can be passed as keyword arguments to this method. The supported + options differ by server version. + + + :return: An instance of :class:`~pymongo.command_cursor.CommandCursor`. + + .. 
versionadded:: 3.6 + """ + if filter is not None: + kwargs["filter"] = filter + read_pref = (session and session._txn_read_preference()) or ReadPreference.PRIMARY + if comment is not None: + kwargs["comment"] = comment + + def _cmd( + session: Optional[ClientSession], + _server: Server, + conn: Connection, + read_preference: _ServerMode, + ) -> CommandCursor[MutableMapping[str, Any]]: + return self._list_collections(conn, session, read_preference=read_preference, **kwargs) + + return self.__client._retryable_read( + _cmd, read_pref, session, operation=_Op.LIST_COLLECTIONS + ) + + def list_collection_names( + self, + session: Optional[ClientSession] = None, + filter: Optional[Mapping[str, Any]] = None, + comment: Optional[Any] = None, + **kwargs: Any, + ) -> list[str]: + """Get a list of all the collection names in this database. + + For example, to list all non-system collections:: + + filter = {"name": {"$regex": r"^(?!system\\.)"}} + db.list_collection_names(filter=filter) + + :param session: a + :class:`~pymongo.client_session.ClientSession`. + :param filter: A query document to filter the list of + collections returned from the listCollections command. + :param comment: A user-provided comment to attach to this + command. + :param kwargs: Optional parameters of the + `listCollections command + <https://mongodb.com/docs/manual/reference/command/listCollections/>`_ + can be passed as keyword arguments to this method. The supported + options differ by server version. + + + .. versionchanged:: 3.8 + Added the ``filter`` and ``**kwargs`` parameters. + + .. versionadded:: 3.6 + """ + if comment is not None: + kwargs["comment"] = comment + if filter is None: + kwargs["nameOnly"] = True + else: + # The enumerate collections spec states that "drivers MUST NOT set + # nameOnly if a filter specifies any keys other than name." + common.validate_is_mapping("filter", filter) + kwargs["filter"] = filter + if not filter or (len(filter) == 1 and "name" in filter): + kwargs["nameOnly"] = True + + return [result["name"] for result in self.list_collections(session=session, **kwargs)] + + def _drop_helper( + self, name: str, session: Optional[ClientSession] = None, comment: Optional[Any] = None + ) -> dict[str, Any]: + command = {"drop": name} + if comment is not None: + command["comment"] = comment + + with self.__client._conn_for_writes(session, operation=_Op.DROP) as connection: + return self._command( + connection, + command, + allowable_errors=["ns not found", 26], + write_concern=self._write_concern_for(session), + parse_write_concern_error=True, + session=session, + ) + + @_csot.apply + def drop_collection( + self, + name_or_collection: Union[str, Collection[_DocumentTypeArg]], + session: Optional[ClientSession] = None, + comment: Optional[Any] = None, + encrypted_fields: Optional[Mapping[str, Any]] = None, + ) -> dict[str, Any]: + """Drop a collection. + + :param name_or_collection: the name of a collection to drop or the + collection object itself + :param session: a + :class:`~pymongo.client_session.ClientSession`. + :param comment: A user-provided comment to attach to this + command. + :param encrypted_fields: **(BETA)** Document that describes the encrypted fields for + Queryable Encryption.
For example:: + + { + "escCollection": "enxcol_.encryptedCollection.esc", + "ecocCollection": "enxcol_.encryptedCollection.ecoc", + "fields": [ + { + "path": "firstName", + "keyId": Binary.from_uuid(UUID('00000000-0000-0000-0000-000000000000')), + "bsonType": "string", + "queries": {"queryType": "equality"} + }, + { + "path": "ssn", + "keyId": Binary.from_uuid(UUID('04104104-1041-0410-4104-104104104104')), + "bsonType": "string" + } + ] + + } + + + .. note:: The :attr:`~pymongo.database.Database.write_concern` of + this database is automatically applied to this operation. + + .. versionchanged:: 4.2 + Added ``encrypted_fields`` parameter. + + .. versionchanged:: 4.1 + Added ``comment`` parameter. + + .. versionchanged:: 3.6 + Added ``session`` parameter. + + .. versionchanged:: 3.4 + Apply this database's write concern automatically to this operation + when connected to MongoDB >= 3.4. + + """ + name = name_or_collection + if isinstance(name, Collection): + name = name.name + + if not isinstance(name, str): + raise TypeError("name_or_collection must be an instance of str") + encrypted_fields = self._get_encrypted_fields( + {"encryptedFields": encrypted_fields}, + name, + True, + ) + if encrypted_fields: + common.validate_is_mapping("encrypted_fields", encrypted_fields) + self._drop_helper( + _esc_coll_name(encrypted_fields, name), session=session, comment=comment + ) + self._drop_helper( + _ecoc_coll_name(encrypted_fields, name), session=session, comment=comment + ) + + return self._drop_helper(name, session, comment) + + def validate_collection( + self, + name_or_collection: Union[str, Collection[_DocumentTypeArg]], + scandata: bool = False, + full: bool = False, + session: Optional[ClientSession] = None, + background: Optional[bool] = None, + comment: Optional[Any] = None, + ) -> dict[str, Any]: + """Validate a collection. + + Returns a dict of validation info. Raises CollectionInvalid if + validation fails. + + See also the MongoDB documentation on the `validate command`_. + + :param name_or_collection: A Collection object or the name of a + collection to validate. + :param scandata: Do extra checks beyond checking the overall + structure of the collection. + :param full: Have the server do a more thorough scan of the + collection. Use with `scandata` for a thorough scan + of the structure of the collection and the individual + documents. + :param session: a + :class:`~pymongo.client_session.ClientSession`. + :param background: A boolean flag that determines whether + the command runs in the background. Requires MongoDB 4.4+. + :param comment: A user-provided comment to attach to this + command. + + .. versionchanged:: 4.1 + Added ``comment`` parameter. + + .. versionchanged:: 3.11 + Added ``background`` parameter. + + .. versionchanged:: 3.6 + Added ``session`` parameter. + + .. 
_validate command: https://mongodb.com/docs/manual/reference/command/validate/ + """ + name = name_or_collection + if isinstance(name, Collection): + name = name.name + + if not isinstance(name, str): + raise TypeError("name_or_collection must be an instance of str or Collection") + cmd = {"validate": name, "scandata": scandata, "full": full} + if comment is not None: + cmd["comment"] = comment + + if background is not None: + cmd["background"] = background + + result = self.command(cmd, session=session) + + valid = True + # Pre 1.9 results + if "result" in result: + info = result["result"] + if info.find("exception") != -1 or info.find("corrupt") != -1: + raise CollectionInvalid(f"{name} invalid: {info}") + # Sharded results + elif "raw" in result: + for _, res in result["raw"].items(): + if "result" in res: + info = res["result"] + if info.find("exception") != -1 or info.find("corrupt") != -1: + raise CollectionInvalid(f"{name} invalid: {info}") + elif not res.get("valid", False): + valid = False + break + # Post 1.9 non-sharded results. + elif not result.get("valid", False): + valid = False + + if not valid: + raise CollectionInvalid(f"{name} invalid: {result!r}") + + return result + + # See PYTHON-3084. + __iter__ = None + + def __next__(self) -> NoReturn: + raise TypeError("'Database' object is not iterable") + + next = __next__ + + def __bool__(self) -> NoReturn: + raise NotImplementedError( + "Database objects do not implement truth " + "value testing or bool(). Please compare " + "with None instead: database is not None" + ) + + def dereference( + self, + dbref: DBRef, + session: Optional[ClientSession] = None, + comment: Optional[Any] = None, + **kwargs: Any, + ) -> Optional[_DocumentType]: + """Dereference a :class:`~bson.dbref.DBRef`, getting the + document it points to. + + Raises :class:`TypeError` if `dbref` is not an instance of + :class:`~bson.dbref.DBRef`. Returns a document, or ``None`` if + the reference does not point to a valid document. Raises + :class:`ValueError` if `dbref` has a database specified that + is different from the current database. + + :param dbref: the reference + :param session: a + :class:`~pymongo.client_session.ClientSession`. + :param comment: A user-provided comment to attach to this + command. + :param kwargs: any additional keyword arguments + are the same as the arguments to + :meth:`~pymongo.collection.Collection.find`. + + + .. versionchanged:: 4.1 + Added ``comment`` parameter. + .. versionchanged:: 3.6 + Added ``session`` parameter. + """ + if not isinstance(dbref, DBRef): + raise TypeError("cannot dereference a %s" % type(dbref)) + if dbref.database is not None and dbref.database != self.__name: + raise ValueError( + "trying to dereference a DBRef that points to " + f"another database ({dbref.database!r} not {self.__name!r})" + ) + return self[dbref.collection].find_one( + {"_id": dbref.id}, session=session, comment=comment, **kwargs + ) diff --git a/venv/Lib/site-packages/pymongo/driver_info.py b/venv/Lib/site-packages/pymongo/driver_info.py new file mode 100644 index 00000000..9e7cfbda --- /dev/null +++ b/venv/Lib/site-packages/pymongo/driver_info.py @@ -0,0 +1,42 @@ +# Copyright 2018-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you +# may not use this file except in compliance with the License. 
You +# may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. + +"""Advanced options for MongoDB drivers implemented on top of PyMongo.""" +from __future__ import annotations + +from collections import namedtuple +from typing import Optional + + +class DriverInfo(namedtuple("DriverInfo", ["name", "version", "platform"])): + """Info about a driver wrapping PyMongo. + + The MongoDB server logs PyMongo's name, version, and platform whenever + PyMongo establishes a connection. A driver implemented on top of PyMongo + can add its own info to this log message. Initialize with three strings + like 'MyDriver', '1.2.3', 'some platform info'. Any of these strings may be + None to accept PyMongo's default. + """ + + def __new__( + cls, name: str, version: Optional[str] = None, platform: Optional[str] = None + ) -> DriverInfo: + self = super().__new__(cls, name, version, platform) + for key, value in self._asdict().items(): + if value is not None and not isinstance(value, str): + raise TypeError( + f"Wrong type for DriverInfo {key} option, value must be an instance of str" + ) + + return self diff --git a/venv/Lib/site-packages/pymongo/encryption.py b/venv/Lib/site-packages/pymongo/encryption.py new file mode 100644 index 00000000..c7f02766 --- /dev/null +++ b/venv/Lib/site-packages/pymongo/encryption.py @@ -0,0 +1,1112 @@ +# Copyright 2019-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
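As a usage note for the ``DriverInfo`` class defined above: a wrapping driver passes it once, at client construction time, through the ``driver`` option of :class:`~pymongo.mongo_client.MongoClient`, and its name, version, and platform are appended to PyMongo's own values in the connection handshake. A short sketch with illustrative values::

    from pymongo import MongoClient
    from pymongo.driver_info import DriverInfo

    # "MyDriver" and "1.2.3" are placeholders for the wrapping driver.
    client = MongoClient(driver=DriverInfo(name="MyDriver", version="1.2.3"))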
+ +"""Support for explicit client-side field level encryption.""" +from __future__ import annotations + +import contextlib +import enum +import socket +import uuid +import weakref +from copy import deepcopy +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Generic, + Iterator, + Mapping, + MutableMapping, + Optional, + Sequence, + Union, + cast, +) + +try: + from pymongocrypt.auto_encrypter import AutoEncrypter # type:ignore[import] + from pymongocrypt.errors import MongoCryptError # type:ignore[import] + from pymongocrypt.explicit_encrypter import ExplicitEncrypter # type:ignore[import] + from pymongocrypt.mongocrypt import MongoCryptOptions # type:ignore[import] + from pymongocrypt.state_machine import MongoCryptCallback # type:ignore[import] + + _HAVE_PYMONGOCRYPT = True +except ImportError: + _HAVE_PYMONGOCRYPT = False + MongoCryptCallback = object + +from bson import _dict_to_bson, decode, encode +from bson.binary import STANDARD, UUID_SUBTYPE, Binary +from bson.codec_options import CodecOptions +from bson.errors import BSONError +from bson.raw_bson import DEFAULT_RAW_BSON_OPTIONS, RawBSONDocument, _inflate_bson +from pymongo import _csot +from pymongo.collection import Collection +from pymongo.common import CONNECT_TIMEOUT +from pymongo.cursor import Cursor +from pymongo.daemon import _spawn_daemon +from pymongo.database import Database +from pymongo.encryption_options import AutoEncryptionOpts, RangeOpts +from pymongo.errors import ( + ConfigurationError, + EncryptedCollectionError, + EncryptionError, + InvalidOperation, + PyMongoError, + ServerSelectionTimeoutError, +) +from pymongo.mongo_client import MongoClient +from pymongo.network import BLOCKING_IO_ERRORS +from pymongo.operations import UpdateOne +from pymongo.pool import PoolOptions, _configured_socket, _raise_connection_failure +from pymongo.read_concern import ReadConcern +from pymongo.results import BulkWriteResult, DeleteResult +from pymongo.ssl_support import get_ssl_context +from pymongo.typings import _DocumentType, _DocumentTypeArg +from pymongo.uri_parser import parse_host +from pymongo.write_concern import WriteConcern + +if TYPE_CHECKING: + from pymongocrypt.mongocrypt import MongoCryptKmsContext + +_HTTPS_PORT = 443 +_KMS_CONNECT_TIMEOUT = CONNECT_TIMEOUT # CDRIVER-3262 redefined this value to CONNECT_TIMEOUT +_MONGOCRYPTD_TIMEOUT_MS = 10000 + +_DATA_KEY_OPTS: CodecOptions[dict[str, Any]] = CodecOptions( + document_class=Dict[str, Any], uuid_representation=STANDARD +) +# Use RawBSONDocument codec options to avoid needlessly decoding +# documents from the key vault. +_KEY_VAULT_OPTS = CodecOptions(document_class=RawBSONDocument) + + +@contextlib.contextmanager +def _wrap_encryption_errors() -> Iterator[None]: + """Context manager to wrap encryption related errors.""" + try: + yield + except BSONError: + # BSON encoding/decoding errors are unrelated to encryption so + # we should propagate them unchanged. + raise + except Exception as exc: + raise EncryptionError(exc) from exc + + +class _EncryptionIO(MongoCryptCallback): # type: ignore[misc] + def __init__( + self, + client: Optional[MongoClient[_DocumentTypeArg]], + key_vault_coll: Collection[_DocumentTypeArg], + mongocryptd_client: Optional[MongoClient[_DocumentTypeArg]], + opts: AutoEncryptionOpts, + ): + """Internal class to perform I/O on behalf of pymongocrypt.""" + self.client_ref: Any + # Use a weak ref to break reference cycle. 
+ if client is not None: + self.client_ref = weakref.ref(client) + else: + self.client_ref = None + self.key_vault_coll: Optional[Collection[RawBSONDocument]] = cast( + Collection[RawBSONDocument], + key_vault_coll.with_options( + codec_options=_KEY_VAULT_OPTS, + read_concern=ReadConcern(level="majority"), + write_concern=WriteConcern(w="majority"), + ), + ) + self.mongocryptd_client = mongocryptd_client + self.opts = opts + self._spawned = False + + def kms_request(self, kms_context: MongoCryptKmsContext) -> None: + """Complete a KMS request. + + :param kms_context: A :class:`MongoCryptKmsContext`. + + :return: None + """ + endpoint = kms_context.endpoint + message = kms_context.message + provider = kms_context.kms_provider + ctx = self.opts._kms_ssl_contexts.get(provider) + if ctx is None: + # Enable strict certificate verification, OCSP, match hostname, and + # SNI using the system default CA certificates. + ctx = get_ssl_context( + None, # certfile + None, # passphrase + None, # ca_certs + None, # crlfile + False, # allow_invalid_certificates + False, # allow_invalid_hostnames + False, + ) # disable_ocsp_endpoint_check + # CSOT: set timeout for socket creation. + connect_timeout = max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0.001) + opts = PoolOptions( + connect_timeout=connect_timeout, + socket_timeout=connect_timeout, + ssl_context=ctx, + ) + host, port = parse_host(endpoint, _HTTPS_PORT) + try: + conn = _configured_socket((host, port), opts) + try: + conn.sendall(message) + while kms_context.bytes_needed > 0: + # CSOT: update timeout. + conn.settimeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0)) + data = conn.recv(kms_context.bytes_needed) + if not data: + raise OSError("KMS connection closed") + kms_context.feed(data) + except BLOCKING_IO_ERRORS: + raise socket.timeout("timed out") from None + finally: + conn.close() + except (PyMongoError, MongoCryptError): + raise # Propagate pymongo errors directly. + except Exception as error: + # Wrap I/O errors in PyMongo exceptions. + _raise_connection_failure((host, port), error) + + def collection_info( + self, database: Database[Mapping[str, Any]], filter: bytes + ) -> Optional[bytes]: + """Get the collection info for a namespace. + + The returned collection info is passed to libmongocrypt which reads + the JSON schema. + + :param database: The database on which to run listCollections. + :param filter: The filter to pass to listCollections. + + :return: The first document from the listCollections command response as BSON. + """ + with self.client_ref()[database].list_collections(filter=RawBSONDocument(filter)) as cursor: + for doc in cursor: + return _dict_to_bson(doc, False, _DATA_KEY_OPTS) + return None + + def spawn(self) -> None: + """Spawn mongocryptd. + + Note this method is thread safe; at most one mongocryptd will start + successfully. + """ + self._spawned = True + args = [self.opts._mongocryptd_spawn_path or "mongocryptd"] + args.extend(self.opts._mongocryptd_spawn_args) + _spawn_daemon(args) + + def mark_command(self, database: str, cmd: bytes) -> bytes: + """Mark a command for encryption. + + :param database: The database on which to run this command. + :param cmd: The BSON command to run. + + :return: The marked command response from mongocryptd. + """ + if not self._spawned and not self.opts._mongocryptd_bypass_spawn: + self.spawn() + # Database.command only supports mutable mappings so we need to decode + # the raw BSON command first. 
+ inflated_cmd = _inflate_bson(cmd, DEFAULT_RAW_BSON_OPTIONS) + assert self.mongocryptd_client is not None + try: + res = self.mongocryptd_client[database].command( + inflated_cmd, codec_options=DEFAULT_RAW_BSON_OPTIONS + ) + except ServerSelectionTimeoutError: + if self.opts._mongocryptd_bypass_spawn: + raise + self.spawn() + res = self.mongocryptd_client[database].command( + inflated_cmd, codec_options=DEFAULT_RAW_BSON_OPTIONS + ) + return res.raw + + def fetch_keys(self, filter: bytes) -> Iterator[bytes]: + """Yields one or more keys from the key vault. + + :param filter: The filter to pass to find. + + :return: A generator which yields the requested keys from the key vault. + """ + assert self.key_vault_coll is not None + with self.key_vault_coll.find(RawBSONDocument(filter)) as cursor: + for key in cursor: + yield key.raw + + def insert_data_key(self, data_key: bytes) -> Binary: + """Insert a data key into the key vault. + + :param data_key: The data key document to insert. + + :return: The _id of the inserted data key document. + """ + raw_doc = RawBSONDocument(data_key, _KEY_VAULT_OPTS) + data_key_id = raw_doc.get("_id") + if not isinstance(data_key_id, Binary) or data_key_id.subtype != UUID_SUBTYPE: + raise TypeError("data_key _id must be Binary with a UUID subtype") + + assert self.key_vault_coll is not None + self.key_vault_coll.insert_one(raw_doc) + return data_key_id + + def bson_encode(self, doc: MutableMapping[str, Any]) -> bytes: + """Encode a document to BSON. + + A document can be any mapping type (like :class:`dict`). + + :param doc: mapping type representing a document + + :return: The encoded BSON bytes. + """ + return encode(doc) + + def close(self) -> None: + """Release resources. + + Note it is not safe to call this method from __del__ or any GC hooks. + """ + self.client_ref = None + self.key_vault_coll = None + if self.mongocryptd_client: + self.mongocryptd_client.close() + self.mongocryptd_client = None + + +class RewrapManyDataKeyResult: + """Result object returned by a :meth:`~ClientEncryption.rewrap_many_data_key` operation. + + .. versionadded:: 4.2 + """ + + def __init__(self, bulk_write_result: Optional[BulkWriteResult] = None) -> None: + self._bulk_write_result = bulk_write_result + + @property + def bulk_write_result(self) -> Optional[BulkWriteResult]: + """The result of the bulk write operation used to update the key vault + collection with one or more rewrapped data keys. If + :meth:`~ClientEncryption.rewrap_many_data_key` does not find any matching keys to rewrap, + no bulk write operation will be executed and this field will be + ``None``. + """ + return self._bulk_write_result + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self._bulk_write_result!r})" + + +class _Encrypter: + """Encrypts and decrypts MongoDB commands. + + This class is used to support automatic encryption and decryption of + MongoDB commands. + """ + + def __init__(self, client: MongoClient[_DocumentTypeArg], opts: AutoEncryptionOpts): + """Create a _Encrypter for a client. + + :param client: The encrypted MongoClient. + :param opts: The encrypted client's :class:`AutoEncryptionOpts`. 
+ """ + if opts._schema_map is None: + schema_map = None + else: + schema_map = _dict_to_bson(opts._schema_map, False, _DATA_KEY_OPTS) + + if opts._encrypted_fields_map is None: + encrypted_fields_map = None + else: + encrypted_fields_map = _dict_to_bson(opts._encrypted_fields_map, False, _DATA_KEY_OPTS) + self._bypass_auto_encryption = opts._bypass_auto_encryption + self._internal_client = None + + def _get_internal_client( + encrypter: _Encrypter, mongo_client: MongoClient[_DocumentTypeArg] + ) -> MongoClient[_DocumentTypeArg]: + if mongo_client.options.pool_options.max_pool_size is None: + # Unlimited pool size, use the same client. + return mongo_client + # Else - limited pool size, use an internal client. + if encrypter._internal_client is not None: + return encrypter._internal_client + internal_client = mongo_client._duplicate(minPoolSize=0, auto_encryption_opts=None) + encrypter._internal_client = internal_client + return internal_client + + if opts._key_vault_client is not None: + key_vault_client = opts._key_vault_client + else: + key_vault_client = _get_internal_client(self, client) + + if opts._bypass_auto_encryption: + metadata_client = None + else: + metadata_client = _get_internal_client(self, client) + + db, coll = opts._key_vault_namespace.split(".", 1) + key_vault_coll = key_vault_client[db][coll] + + mongocryptd_client: MongoClient[Mapping[str, Any]] = MongoClient( + opts._mongocryptd_uri, connect=False, serverSelectionTimeoutMS=_MONGOCRYPTD_TIMEOUT_MS + ) + + io_callbacks = _EncryptionIO( # type:ignore[misc] + metadata_client, key_vault_coll, mongocryptd_client, opts + ) + self._auto_encrypter = AutoEncrypter( + io_callbacks, + MongoCryptOptions( + opts._kms_providers, + schema_map, + crypt_shared_lib_path=opts._crypt_shared_lib_path, + crypt_shared_lib_required=opts._crypt_shared_lib_required, + bypass_encryption=opts._bypass_auto_encryption, + encrypted_fields_map=encrypted_fields_map, + bypass_query_analysis=opts._bypass_query_analysis, + ), + ) + self._closed = False + + def encrypt( + self, database: str, cmd: Mapping[str, Any], codec_options: CodecOptions[_DocumentTypeArg] + ) -> dict[str, Any]: + """Encrypt a MongoDB command. + + :param database: The database for this command. + :param cmd: A command document. + :param codec_options: The CodecOptions to use while encoding `cmd`. + + :return: The encrypted command to execute. + """ + self._check_closed() + encoded_cmd = _dict_to_bson(cmd, False, codec_options) + with _wrap_encryption_errors(): + encrypted_cmd = self._auto_encrypter.encrypt(database, encoded_cmd) + # TODO: PYTHON-1922 avoid decoding the encrypted_cmd. + return _inflate_bson(encrypted_cmd, DEFAULT_RAW_BSON_OPTIONS) + + def decrypt(self, response: bytes) -> Optional[bytes]: + """Decrypt a MongoDB command response. + + :param response: A MongoDB command response as BSON. + + :return: The decrypted command response. 
+ """ + self._check_closed() + with _wrap_encryption_errors(): + return cast(bytes, self._auto_encrypter.decrypt(response)) + + def _check_closed(self) -> None: + if self._closed: + raise InvalidOperation("Cannot use MongoClient after close") + + def close(self) -> None: + """Cleanup resources.""" + self._closed = True + self._auto_encrypter.close() + if self._internal_client: + self._internal_client.close() + self._internal_client = None + + +class Algorithm(str, enum.Enum): + """An enum that defines the supported encryption algorithms.""" + + AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic = "AEAD_AES_256_CBC_HMAC_SHA_512-Deterministic" + """AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic.""" + AEAD_AES_256_CBC_HMAC_SHA_512_Random = "AEAD_AES_256_CBC_HMAC_SHA_512-Random" + """AEAD_AES_256_CBC_HMAC_SHA_512_Random.""" + INDEXED = "Indexed" + """Indexed. + + .. versionadded:: 4.2 + """ + UNINDEXED = "Unindexed" + """Unindexed. + + .. versionadded:: 4.2 + """ + RANGEPREVIEW = "RangePreview" + """RangePreview. + + .. note:: Support for Range queries is in beta. + Backwards-breaking changes may be made before the final release. + + .. versionadded:: 4.4 + """ + + +class QueryType(str, enum.Enum): + """An enum that defines the supported values for explicit encryption query_type. + + .. versionadded:: 4.2 + """ + + EQUALITY = "equality" + """Used to encrypt a value for an equality query.""" + + RANGEPREVIEW = "rangePreview" + """Used to encrypt a value for a range query. + + .. note:: Support for Range queries is in beta. + Backwards-breaking changes may be made before the final release. +""" + + +class ClientEncryption(Generic[_DocumentType]): + """Explicit client-side field level encryption.""" + + def __init__( + self, + kms_providers: Mapping[str, Any], + key_vault_namespace: str, + key_vault_client: MongoClient[_DocumentTypeArg], + codec_options: CodecOptions[_DocumentTypeArg], + kms_tls_options: Optional[Mapping[str, Any]] = None, + ) -> None: + """Explicit client-side field level encryption. + + The ClientEncryption class encapsulates explicit operations on a key + vault collection that cannot be done directly on a MongoClient. Similar + to configuring auto encryption on a MongoClient, it is constructed with + a MongoClient (to a MongoDB cluster containing the key vault + collection), KMS provider configuration, and keyVaultNamespace. It + provides an API for explicitly encrypting and decrypting values, and + creating data keys. It does not provide an API to query keys from the + key vault collection, as this can be done directly on the MongoClient. + + See :ref:`explicit-client-side-encryption` for an example. + + :param kms_providers: Map of KMS provider options. The `kms_providers` + map values differ by provider: + + - `aws`: Map with "accessKeyId" and "secretAccessKey" as strings. + These are the AWS access key ID and AWS secret access key used + to generate KMS messages. An optional "sessionToken" may be + included to support temporary AWS credentials. + - `azure`: Map with "tenantId", "clientId", and "clientSecret" as + strings. Additionally, "identityPlatformEndpoint" may also be + specified as a string (defaults to 'login.microsoftonline.com'). + These are the Azure Active Directory credentials used to + generate Azure Key Vault messages. + - `gcp`: Map with "email" as a string and "privateKey" + as `bytes` or a base64 encoded string. + Additionally, "endpoint" may also be specified as a string + (defaults to 'oauth2.googleapis.com'). 
These are the + credentials used to generate Google Cloud KMS messages. + - `kmip`: Map with "endpoint" as a host with required port. + For example: ``{"endpoint": "example.com:443"}``. + - `local`: Map with "key" as `bytes` (96 bytes in length) or + a base64 encoded string which decodes + to 96 bytes. "key" is the master key used to encrypt/decrypt + data keys. This key should be generated and stored as securely + as possible. + + KMS providers may be specified with an optional name suffix + separated by a colon, for example "kmip:name" or "aws:name". + Named KMS providers do not support :ref:`CSFLE on-demand credentials`. + :param key_vault_namespace: The namespace for the key vault collection. + The key vault collection contains all data keys used for encryption + and decryption. Data keys are stored as documents in this MongoDB + collection. Data keys are protected with encryption by a KMS + provider. + :param key_vault_client: A MongoClient connected to a MongoDB cluster + containing the `key_vault_namespace` collection. + :param codec_options: An instance of + :class:`~bson.codec_options.CodecOptions` to use when encoding a + value for encryption and decoding the decrypted BSON value. This + should be the same CodecOptions instance configured on the + MongoClient, Database, or Collection used to access application + data. + :param kms_tls_options: A map of KMS provider names to TLS + options to use when creating secure connections to KMS providers. + Accepts the same TLS options as + :class:`pymongo.mongo_client.MongoClient`. For example, to + override the system default CA file:: + + kms_tls_options={'kmip': {'tlsCAFile': certifi.where()}} + + Or to supply a client certificate:: + + kms_tls_options={'kmip': {'tlsCertificateKeyFile': 'client.pem'}} + + .. versionchanged:: 4.0 + Added the `kms_tls_options` parameter and the "kmip" KMS provider. + + .. versionadded:: 3.9 + """ + if not _HAVE_PYMONGOCRYPT: + raise ConfigurationError( + "client-side field level encryption requires the pymongocrypt " + "library: install a compatible version with: " + "python -m pip install 'pymongo[encryption]'" + ) + + if not isinstance(codec_options, CodecOptions): + raise TypeError("codec_options must be an instance of bson.codec_options.CodecOptions") + + self._kms_providers = kms_providers + self._key_vault_namespace = key_vault_namespace + self._key_vault_client = key_vault_client + self._codec_options = codec_options + + db, coll = key_vault_namespace.split(".", 1) + key_vault_coll = key_vault_client[db][coll] + + opts = AutoEncryptionOpts( + kms_providers, key_vault_namespace, kms_tls_options=kms_tls_options + ) + self._io_callbacks: Optional[_EncryptionIO] = _EncryptionIO( + None, key_vault_coll, None, opts + ) + self._encryption = ExplicitEncrypter( + self._io_callbacks, MongoCryptOptions(kms_providers, None) + ) + # Use the same key vault collection as the callback. + assert self._io_callbacks.key_vault_coll is not None + self._key_vault_coll = self._io_callbacks.key_vault_coll + + def create_encrypted_collection( + self, + database: Database[_DocumentTypeArg], + name: str, + encrypted_fields: Mapping[str, Any], + kms_provider: Optional[str] = None, + master_key: Optional[Mapping[str, Any]] = None, + **kwargs: Any, + ) -> tuple[Collection[_DocumentTypeArg], Mapping[str, Any]]: + """Create a collection with encryptedFields. + + .. 
warning:: + This function does not update the encryptedFieldsMap in the client's + AutoEncryptionOpts; the user must therefore create a new client, using the + encryptedFields returned, after calling this function. + + Normally collection creation is automatic. This method should + only be used to specify options on + creation. :class:`~pymongo.errors.EncryptionError` will be + raised if the collection already exists. + + :param name: the name of the collection to create + :param encrypted_fields: Document that describes the encrypted fields for + Queryable Encryption. The "keyId" may be set to ``None`` to auto-generate the data keys. For example: + + .. code-block:: python + + { + "escCollection": "enxcol_.encryptedCollection.esc", + "ecocCollection": "enxcol_.encryptedCollection.ecoc", + "fields": [ + { + "path": "firstName", + "keyId": Binary.from_uuid(UUID('00000000-0000-0000-0000-000000000000')), + "bsonType": "string", + "queries": {"queryType": "equality"} + }, + { + "path": "ssn", + "keyId": Binary.from_uuid(UUID('04104104-1041-0410-4104-104104104104')), + "bsonType": "string" + } + ] + } + + :param kms_provider: the KMS provider to be used + :param master_key: Identifies a KMS-specific key used to encrypt the + new data key. If the kmsProvider is "local" the `master_key` is + not applicable and may be omitted. + :param kwargs: additional keyword arguments are the same as "create_collection". + + All optional `create collection command`_ parameters should be passed + as keyword arguments to this method. + See the documentation for :meth:`~pymongo.database.Database.create_collection` for all valid options. + + :raises: - :class:`~pymongo.errors.EncryptedCollectionError`: When either data-key creation or creating the collection fails. + + .. versionadded:: 4.4 + + .. _create collection command: + https://mongodb.com/docs/manual/reference/command/create + + """ + encrypted_fields = deepcopy(encrypted_fields) + for i, field in enumerate(encrypted_fields["fields"]): + if isinstance(field, dict) and field.get("keyId") is None: + try: + encrypted_fields["fields"][i]["keyId"] = self.create_data_key( + kms_provider=kms_provider, # type:ignore[arg-type] + master_key=master_key, + ) + except EncryptionError as exc: + raise EncryptedCollectionError(exc, encrypted_fields) from exc + kwargs["encryptedFields"] = encrypted_fields + kwargs["check_exists"] = False + try: + return ( + database.create_collection(name=name, **kwargs), + encrypted_fields, + ) + except Exception as exc: + raise EncryptedCollectionError(exc, encrypted_fields) from exc + + def create_data_key( + self, + kms_provider: str, + master_key: Optional[Mapping[str, Any]] = None, + key_alt_names: Optional[Sequence[str]] = None, + key_material: Optional[bytes] = None, + ) -> Binary: + """Create and insert a new data key into the key vault collection. + + :param kms_provider: The KMS provider to use. Supported values are + "aws", "azure", "gcp", "kmip", "local", or a named provider like + "kmip:name". + :param master_key: Identifies a KMS-specific key used to encrypt the + new data key. If the kmsProvider is "local" the `master_key` is + not applicable and may be omitted. + + If the `kms_provider` type is "aws" it is required and has the + following fields:: + + - `region` (string): Required. The AWS region, e.g. "us-east-1". + - `key` (string): Required. The Amazon Resource Name (ARN) to + the AWS customer master key (CMK). + - `endpoint` (string): Optional. An alternate host to send KMS + requests to. May include port number, e.g.
+ "kms.us-east-1.amazonaws.com:443". + + If the `kms_provider` type is "azure" it is required and has the + following fields:: + + - `keyVaultEndpoint` (string): Required. Host with optional + port, e.g. "example.vault.azure.net". + - `keyName` (string): Required. Key name in the key vault. + - `keyVersion` (string): Optional. Version of the key to use. + + If the `kms_provider` type is "gcp" it is required and has the + following fields:: + + - `projectId` (string): Required. The Google cloud project ID. + - `location` (string): Required. The GCP location, e.g. "us-east1". + - `keyRing` (string): Required. Name of the key ring that contains + the key to use. + - `keyName` (string): Required. Name of the key to use. + - `keyVersion` (string): Optional. Version of the key to use. + - `endpoint` (string): Optional. Host with optional port. + Defaults to "cloudkms.googleapis.com". + + If the `kms_provider` type is "kmip" it is optional and has the + following fields:: + + - `keyId` (string): Optional. `keyId` is the KMIP Unique + Identifier to a 96 byte KMIP Secret Data managed object. If + keyId is omitted, the driver creates a random 96 byte KMIP + Secret Data managed object. + - `endpoint` (string): Optional. Host with optional + port, e.g. "example.vault.azure.net:". + + :param key_alt_names: An optional list of string alternate + names used to reference a key. If a key is created with alternate + names, then encryption may refer to the key by the unique alternate + name instead of by ``key_id``. The following example shows creating + and referring to a data key by alternate name:: + + client_encryption.create_data_key("local", key_alt_names=["name1"]) + # reference the key with the alternate name + client_encryption.encrypt("457-55-5462", key_alt_name="name1", + algorithm=Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Random) + :param key_material: Sets the custom key material to be used + by the data key for encryption and decryption. + + :return: The ``_id`` of the created data key document as a + :class:`~bson.binary.Binary` with subtype + :data:`~bson.binary.UUID_SUBTYPE`. + + .. versionchanged:: 4.2 + Added the `key_material` parameter. 
+ """ + self._check_closed() + with _wrap_encryption_errors(): + return cast( + Binary, + self._encryption.create_data_key( + kms_provider, + master_key=master_key, + key_alt_names=key_alt_names, + key_material=key_material, + ), + ) + + def _encrypt_helper( + self, + value: Any, + algorithm: str, + key_id: Optional[Union[Binary, uuid.UUID]] = None, + key_alt_name: Optional[str] = None, + query_type: Optional[str] = None, + contention_factor: Optional[int] = None, + range_opts: Optional[RangeOpts] = None, + is_expression: bool = False, + ) -> Any: + self._check_closed() + if isinstance(key_id, uuid.UUID): + key_id = Binary.from_uuid(key_id) + if key_id is not None and not ( + isinstance(key_id, Binary) and key_id.subtype == UUID_SUBTYPE + ): + raise TypeError("key_id must be a bson.binary.Binary with subtype 4") + + doc = encode( + {"v": value}, + codec_options=self._codec_options, + ) + range_opts_bytes = None + if range_opts: + range_opts_bytes = encode( + range_opts.document, + codec_options=self._codec_options, + ) + with _wrap_encryption_errors(): + encrypted_doc = self._encryption.encrypt( + value=doc, + algorithm=algorithm, + key_id=key_id, + key_alt_name=key_alt_name, + query_type=query_type, + contention_factor=contention_factor, + range_opts=range_opts_bytes, + is_expression=is_expression, + ) + return decode(encrypted_doc)["v"] + + def encrypt( + self, + value: Any, + algorithm: str, + key_id: Optional[Union[Binary, uuid.UUID]] = None, + key_alt_name: Optional[str] = None, + query_type: Optional[str] = None, + contention_factor: Optional[int] = None, + range_opts: Optional[RangeOpts] = None, + ) -> Binary: + """Encrypt a BSON value with a given key and algorithm. + + Note that exactly one of ``key_id`` or ``key_alt_name`` must be + provided. + + :param value: The BSON value to encrypt. + :param algorithm` (string): The encryption algorithm to use. See + :class:`Algorithm` for some valid options. + :param key_id: Identifies a data key by ``_id`` which must be a + :class:`~bson.binary.Binary` with subtype 4 ( + :attr:`~bson.binary.UUID_SUBTYPE`). + :param key_alt_name: Identifies a key vault document by 'keyAltName'. + :param query_type` (str): The query type to execute. See :class:`QueryType` for valid options. + :param contention_factor` (int): The contention factor to use + when the algorithm is :attr:`Algorithm.INDEXED`. An integer value + *must* be given when the :attr:`Algorithm.INDEXED` algorithm is + used. + :param range_opts: Experimental only, not intended for public use. + + :return: The encrypted value, a :class:`~bson.binary.Binary` with subtype 6. + + .. versionchanged:: 4.7 + ``key_id`` can now be passed in as a :class:`uuid.UUID`. + + .. versionchanged:: 4.2 + Added the `query_type` and `contention_factor` parameters. + """ + return cast( + Binary, + self._encrypt_helper( + value=value, + algorithm=algorithm, + key_id=key_id, + key_alt_name=key_alt_name, + query_type=query_type, + contention_factor=contention_factor, + range_opts=range_opts, + is_expression=False, + ), + ) + + def encrypt_expression( + self, + expression: Mapping[str, Any], + algorithm: str, + key_id: Optional[Union[Binary, uuid.UUID]] = None, + key_alt_name: Optional[str] = None, + query_type: Optional[str] = None, + contention_factor: Optional[int] = None, + range_opts: Optional[RangeOpts] = None, + ) -> RawBSONDocument: + """Encrypt a BSON expression with a given key and algorithm. + + Note that exactly one of ``key_id`` or ``key_alt_name`` must be + provided. 
+
+        :param expression: The BSON aggregate or match expression to encrypt.
+        :param algorithm: The encryption algorithm to use. See
+            :class:`Algorithm` for some valid options.
+        :param key_id: Identifies a data key by ``_id`` which must be a
+            :class:`~bson.binary.Binary` with subtype 4
+            (:attr:`~bson.binary.UUID_SUBTYPE`).
+        :param key_alt_name: Identifies a key vault document by 'keyAltName'.
+        :param query_type: The query type to execute. See
+            :class:`QueryType` for valid options.
+        :param contention_factor: The contention factor to use
+            when the algorithm is :attr:`Algorithm.INDEXED`. An integer value
+            *must* be given when the :attr:`Algorithm.INDEXED` algorithm is
+            used.
+        :param range_opts: Experimental only, not intended for public use.
+
+        :return: The encrypted expression, a :class:`~bson.RawBSONDocument`.
+
+        .. versionchanged:: 4.7
+           ``key_id`` can now be passed in as a :class:`uuid.UUID`.
+
+        .. versionadded:: 4.4
+        """
+        return cast(
+            RawBSONDocument,
+            self._encrypt_helper(
+                value=expression,
+                algorithm=algorithm,
+                key_id=key_id,
+                key_alt_name=key_alt_name,
+                query_type=query_type,
+                contention_factor=contention_factor,
+                range_opts=range_opts,
+                is_expression=True,
+            ),
+        )
+
+    def decrypt(self, value: Binary) -> Any:
+        """Decrypt an encrypted value.
+
+        :param value: The encrypted value, a
+            :class:`~bson.binary.Binary` with subtype 6.
+
+        :return: The decrypted BSON value.
+        """
+        self._check_closed()
+        if not (isinstance(value, Binary) and value.subtype == 6):
+            raise TypeError("value to decrypt must be a bson.binary.Binary with subtype 6")
+
+        with _wrap_encryption_errors():
+            doc = encode({"v": value})
+            decrypted_doc = self._encryption.decrypt(doc)
+            return decode(decrypted_doc, codec_options=self._codec_options)["v"]
+
+    def get_key(self, id: Binary) -> Optional[RawBSONDocument]:
+        """Get a data key by id.
+
+        :param id: The UUID of a key, which must be a
+            :class:`~bson.binary.Binary` with subtype 4
+            (:attr:`~bson.binary.UUID_SUBTYPE`).
+
+        :return: The key document.
+
+        .. versionadded:: 4.2
+        """
+        self._check_closed()
+        assert self._key_vault_coll is not None
+        return self._key_vault_coll.find_one({"_id": id})
+
+    def get_keys(self) -> Cursor[RawBSONDocument]:
+        """Get all of the data keys.
+
+        :return: An instance of :class:`~pymongo.cursor.Cursor` over the data key
+            documents.
+
+        .. versionadded:: 4.2
+        """
+        self._check_closed()
+        assert self._key_vault_coll is not None
+        return self._key_vault_coll.find({})
+
+    def delete_key(self, id: Binary) -> DeleteResult:
+        """Delete a key document in the key vault collection that has the given ``key_id``.
+
+        :param id: The UUID of a key, which must be a
+            :class:`~bson.binary.Binary` with subtype 4
+            (:attr:`~bson.binary.UUID_SUBTYPE`).
+
+        :return: The delete result.
+
+        .. versionadded:: 4.2
+        """
+        self._check_closed()
+        assert self._key_vault_coll is not None
+        return self._key_vault_coll.delete_one({"_id": id})
+
+    def add_key_alt_name(self, id: Binary, key_alt_name: str) -> Any:
+        """Add ``key_alt_name`` to the set of alternate names in the key document with UUID ``key_id``.
+
+        :param id: The UUID of a key, which must be a
+            :class:`~bson.binary.Binary` with subtype 4
+            (:attr:`~bson.binary.UUID_SUBTYPE`).
+        :param key_alt_name: The key alternate name to add.
+
+        :return: The previous version of the key document.
+
+        .. versionadded:: 4.2
+        """
+        self._check_closed()
+        update = {"$addToSet": {"keyAltNames": key_alt_name}}
+        assert self._key_vault_coll is not None
+        return self._key_vault_coll.find_one_and_update({"_id": id}, update)
+
+    def get_key_by_alt_name(self, key_alt_name: str) -> Optional[RawBSONDocument]:
+        """Get a key document in the key vault collection that has the given ``key_alt_name``.
+
+        :param key_alt_name: The key alternate name of the key to get.
+
+        :return: The key document.
+
+        .. versionadded:: 4.2
+        """
+        self._check_closed()
+        assert self._key_vault_coll is not None
+        return self._key_vault_coll.find_one({"keyAltNames": key_alt_name})
+
+    def remove_key_alt_name(self, id: Binary, key_alt_name: str) -> Optional[RawBSONDocument]:
+        """Remove ``key_alt_name`` from the set of keyAltNames in the key document with UUID ``id``.
+
+        Also removes the ``keyAltNames`` field from the key document if it would otherwise be empty.
+
+        :param id: The UUID of a key, which must be a
+            :class:`~bson.binary.Binary` with subtype 4
+            (:attr:`~bson.binary.UUID_SUBTYPE`).
+        :param key_alt_name: The key alternate name to remove.
+
+        :return: Returns the previous version of the key document.
+
+        .. versionadded:: 4.2
+        """
+        self._check_closed()
+        pipeline = [
+            {
+                "$set": {
+                    "keyAltNames": {
+                        "$cond": [
+                            {"$eq": ["$keyAltNames", [key_alt_name]]},
+                            "$$REMOVE",
+                            {
+                                "$filter": {
+                                    "input": "$keyAltNames",
+                                    "cond": {"$ne": ["$$this", key_alt_name]},
+                                }
+                            },
+                        ]
+                    }
+                }
+            }
+        ]
+        assert self._key_vault_coll is not None
+        return self._key_vault_coll.find_one_and_update({"_id": id}, pipeline)
+
+    def rewrap_many_data_key(
+        self,
+        filter: Mapping[str, Any],
+        provider: Optional[str] = None,
+        master_key: Optional[Mapping[str, Any]] = None,
+    ) -> RewrapManyDataKeyResult:
+        """Decrypts and encrypts all matching data keys in the key vault with a possibly new `master_key` value.
+
+        :param filter: A document used to filter the data keys.
+        :param provider: The new KMS provider to use to encrypt the data keys,
+            or ``None`` to use the current KMS provider(s).
+        :param master_key: The master key fields corresponding to the new KMS
+            provider when ``provider`` is not ``None``.
+
+        :return: A :class:`RewrapManyDataKeyResult`.
+
+        This method allows you to re-encrypt all of your data keys with a new CMK, or master key.
+        Note that this does *not* require re-encrypting any of the data in your encrypted collections,
+        but rather refreshes the key that protects the keys that encrypt the data:
+
+        .. code-block:: python
+
+           client_encryption.rewrap_many_data_key(
+               filter={"keyAltNames": "optional filter for which keys you want to update"},
+               master_key={
+                   "provider": "azure",  # replace with your cloud provider
+                   "master_key": {
+                       # put the rest of your master_key options here
+                       "key": "<your new key>"
+                   },
+               },
+           )
+
+        ..
versionadded:: 4.2 + """ + if master_key is not None and provider is None: + raise ConfigurationError("A provider must be given if a master_key is given") + self._check_closed() + with _wrap_encryption_errors(): + raw_result = self._encryption.rewrap_many_data_key(filter, provider, master_key) + if raw_result is None: + return RewrapManyDataKeyResult() + + raw_doc = RawBSONDocument(raw_result, DEFAULT_RAW_BSON_OPTIONS) + replacements = [] + for key in raw_doc["v"]: + update_model = { + "$set": {"keyMaterial": key["keyMaterial"], "masterKey": key["masterKey"]}, + "$currentDate": {"updateDate": True}, + } + op = UpdateOne({"_id": key["_id"]}, update_model) + replacements.append(op) + if not replacements: + return RewrapManyDataKeyResult() + assert self._key_vault_coll is not None + result = self._key_vault_coll.bulk_write(replacements) + return RewrapManyDataKeyResult(result) + + def __enter__(self) -> ClientEncryption[_DocumentType]: + return self + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + self.close() + + def _check_closed(self) -> None: + if self._encryption is None: + raise InvalidOperation("Cannot use closed ClientEncryption") + + def close(self) -> None: + """Release resources. + + Note that using this class in a with-statement will automatically call + :meth:`close`:: + + with ClientEncryption(...) as client_encryption: + encrypted = client_encryption.encrypt(value, ...) + decrypted = client_encryption.decrypt(encrypted) + + """ + if self._io_callbacks: + self._io_callbacks.close() + self._encryption.close() + self._io_callbacks = None + self._encryption = None diff --git a/venv/Lib/site-packages/pymongo/encryption_options.py b/venv/Lib/site-packages/pymongo/encryption_options.py new file mode 100644 index 00000000..1d536997 --- /dev/null +++ b/venv/Lib/site-packages/pymongo/encryption_options.py @@ -0,0 +1,268 @@ +# Copyright 2019-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Support for automatic client-side field level encryption.""" +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Mapping, Optional + +try: + import pymongocrypt # type:ignore[import] # noqa: F401 + + _HAVE_PYMONGOCRYPT = True +except ImportError: + _HAVE_PYMONGOCRYPT = False +from bson import int64 +from pymongo.common import validate_is_mapping +from pymongo.errors import ConfigurationError +from pymongo.uri_parser import _parse_kms_tls_options + +if TYPE_CHECKING: + from pymongo.mongo_client import MongoClient + from pymongo.typings import _DocumentTypeArg + + +class AutoEncryptionOpts: + """Options to configure automatic client-side field level encryption.""" + + def __init__( + self, + kms_providers: Mapping[str, Any], + key_vault_namespace: str, + key_vault_client: Optional[MongoClient[_DocumentTypeArg]] = None, + schema_map: Optional[Mapping[str, Any]] = None, + bypass_auto_encryption: bool = False, + mongocryptd_uri: str = "mongodb://localhost:27020", + mongocryptd_bypass_spawn: bool = False, + mongocryptd_spawn_path: str = "mongocryptd", + mongocryptd_spawn_args: Optional[list[str]] = None, + kms_tls_options: Optional[Mapping[str, Any]] = None, + crypt_shared_lib_path: Optional[str] = None, + crypt_shared_lib_required: bool = False, + bypass_query_analysis: bool = False, + encrypted_fields_map: Optional[Mapping[str, Any]] = None, + ) -> None: + """Options to configure automatic client-side field level encryption. + + Automatic client-side field level encryption requires MongoDB >=4.2 + enterprise or a MongoDB >=4.2 Atlas cluster. Automatic encryption is not + supported for operations on a database or view and will result in + error. + + Although automatic encryption requires MongoDB >=4.2 enterprise or a + MongoDB >=4.2 Atlas cluster, automatic *decryption* is supported for all + users. To configure automatic *decryption* without automatic + *encryption* set ``bypass_auto_encryption=True``. Explicit + encryption and explicit decryption is also supported for all users + with the :class:`~pymongo.encryption.ClientEncryption` class. + + See :ref:`automatic-client-side-encryption` for an example. + + :param kms_providers: Map of KMS provider options. The `kms_providers` + map values differ by provider: + + - `aws`: Map with "accessKeyId" and "secretAccessKey" as strings. + These are the AWS access key ID and AWS secret access key used + to generate KMS messages. An optional "sessionToken" may be + included to support temporary AWS credentials. + - `azure`: Map with "tenantId", "clientId", and "clientSecret" as + strings. Additionally, "identityPlatformEndpoint" may also be + specified as a string (defaults to 'login.microsoftonline.com'). + These are the Azure Active Directory credentials used to + generate Azure Key Vault messages. + - `gcp`: Map with "email" as a string and "privateKey" + as `bytes` or a base64 encoded string. + Additionally, "endpoint" may also be specified as a string + (defaults to 'oauth2.googleapis.com'). These are the + credentials used to generate Google Cloud KMS messages. + - `kmip`: Map with "endpoint" as a host with required port. + For example: ``{"endpoint": "example.com:443"}``. + - `local`: Map with "key" as `bytes` (96 bytes in length) or + a base64 encoded string which decodes + to 96 bytes. "key" is the master key used to encrypt/decrypt + data keys. This key should be generated and stored as securely + as possible. 
+
+            KMS providers may be specified with an optional name suffix
+            separated by a colon, for example "kmip:name" or "aws:name".
+            Named KMS providers do not support :ref:`CSFLE on-demand credentials`.
+            Named KMS providers enable more than one of each KMS provider type to be configured.
+            For example, to configure multiple local KMS providers::
+
+              kms_providers = {
+                  "local": {"key": local_kek1},  # Unnamed KMS provider.
+                  "local:myname": {"key": local_kek2},  # Named KMS provider with name "myname".
+              }
+
+        :param key_vault_namespace: The namespace for the key vault collection.
+            The key vault collection contains all data keys used for encryption
+            and decryption. Data keys are stored as documents in this MongoDB
+            collection. Data keys are protected with encryption by a KMS
+            provider.
+        :param key_vault_client: By default, the key vault collection
+            is assumed to reside in the same MongoDB cluster as the encrypted
+            MongoClient. Use this option to route data key queries to a
+            separate MongoDB cluster.
+        :param schema_map: Map of collection namespace ("db.coll") to
+            JSON Schema. By default, a collection's JSONSchema is periodically
+            polled with the listCollections command. But a JSONSchema may be
+            specified locally with the schemaMap option.
+
+            **Supplying a `schema_map` provides more security than relying on
+            JSON Schemas obtained from the server. It protects against a
+            malicious server advertising a false JSON Schema, which could trick
+            the client into sending unencrypted data that should be
+            encrypted.**
+
+            Schemas supplied in the schemaMap only apply to configuring
+            automatic encryption for client side encryption. Other validation
+            rules in the JSON schema will not be enforced by the driver and
+            will result in an error.
+        :param bypass_auto_encryption: If ``True``, automatic
+            encryption will be disabled but automatic decryption will still be
+            enabled. Defaults to ``False``.
+        :param mongocryptd_uri: The MongoDB URI used to connect
+            to the *local* mongocryptd process. Defaults to
+            ``'mongodb://localhost:27020'``.
+        :param mongocryptd_bypass_spawn: If ``True``, the encrypted
+            MongoClient will not attempt to spawn the mongocryptd process.
+            Defaults to ``False``.
+        :param mongocryptd_spawn_path: Used for spawning the
+            mongocryptd process. Defaults to ``'mongocryptd'`` and spawns
+            mongocryptd from the system path.
+        :param mongocryptd_spawn_args: A list of string arguments to
+            use when spawning the mongocryptd process. Defaults to
+            ``['--idleShutdownTimeoutSecs=60']``. If the list does not include
+            the ``idleShutdownTimeoutSecs`` option, then
+            ``'--idleShutdownTimeoutSecs=60'`` will be added.
+        :param kms_tls_options: A map of KMS provider names to TLS
+            options to use when creating secure connections to KMS providers.
+            Accepts the same TLS options as
+            :class:`pymongo.mongo_client.MongoClient`. For example, to
+            override the system default CA file::
+
+              kms_tls_options={'kmip': {'tlsCAFile': certifi.where()}}
+
+            Or to supply a client certificate::
+
+              kms_tls_options={'kmip': {'tlsCertificateKeyFile': 'client.pem'}}
+        :param crypt_shared_lib_path: Override the path to load the crypt_shared library.
+        :param crypt_shared_lib_required: If True, raise an error if libmongocrypt is
+            unable to load the crypt_shared library.
+        :param bypass_query_analysis: If ``True``, disable automatic analysis
+            of outgoing commands. Set `bypass_query_analysis` to use explicit
+            encryption on indexed fields without the MongoDB Enterprise Advanced
+            licensed crypt_shared library.
+        :param encrypted_fields_map: Map of collection namespace ("db.coll") to documents
+            that describe the encrypted fields for Queryable Encryption. For example::
+
+              {
+                  "db.encryptedCollection": {
+                      "escCollection": "enxcol_.encryptedCollection.esc",
+                      "ecocCollection": "enxcol_.encryptedCollection.ecoc",
+                      "fields": [
+                          {
+                              "path": "firstName",
+                              "keyId": Binary.from_uuid(UUID('00000000-0000-0000-0000-000000000000')),
+                              "bsonType": "string",
+                              "queries": {"queryType": "equality"}
+                          },
+                          {
+                              "path": "ssn",
+                              "keyId": Binary.from_uuid(UUID('04104104-1041-0410-4104-104104104104')),
+                              "bsonType": "string"
+                          }
+                      ]
+                  }
+              }
+
+        .. versionchanged:: 4.2
+           Added the `encrypted_fields_map`, `crypt_shared_lib_path`,
+           `crypt_shared_lib_required`, and `bypass_query_analysis` parameters.
+
+        .. versionchanged:: 4.0
+           Added the `kms_tls_options` parameter and the "kmip" KMS provider.
+
+        .. versionadded:: 3.9
+        """
+        if not _HAVE_PYMONGOCRYPT:
+            raise ConfigurationError(
+                "client side encryption requires the pymongocrypt library: "
+                "install a compatible version with: "
+                "python -m pip install 'pymongo[encryption]'"
+            )
+        if encrypted_fields_map:
+            validate_is_mapping("encrypted_fields_map", encrypted_fields_map)
+        self._encrypted_fields_map = encrypted_fields_map
+        self._bypass_query_analysis = bypass_query_analysis
+        self._crypt_shared_lib_path = crypt_shared_lib_path
+        self._crypt_shared_lib_required = crypt_shared_lib_required
+        self._kms_providers = kms_providers
+        self._key_vault_namespace = key_vault_namespace
+        self._key_vault_client = key_vault_client
+        self._schema_map = schema_map
+        self._bypass_auto_encryption = bypass_auto_encryption
+        self._mongocryptd_uri = mongocryptd_uri
+        self._mongocryptd_bypass_spawn = mongocryptd_bypass_spawn
+        self._mongocryptd_spawn_path = mongocryptd_spawn_path
+        if mongocryptd_spawn_args is None:
+            mongocryptd_spawn_args = ["--idleShutdownTimeoutSecs=60"]
+        self._mongocryptd_spawn_args = mongocryptd_spawn_args
+        if not isinstance(self._mongocryptd_spawn_args, list):
+            raise TypeError("mongocryptd_spawn_args must be a list")
+        if not any("idleShutdownTimeoutSecs" in s for s in self._mongocryptd_spawn_args):
+            self._mongocryptd_spawn_args.append("--idleShutdownTimeoutSecs=60")
+        # Maps KMS provider name to an SSLContext.
+        self._kms_ssl_contexts = _parse_kms_tls_options(kms_tls_options)
+
+
+class RangeOpts:
+    """Options to configure encrypted queries using the rangePreview algorithm."""
+
+    def __init__(
+        self,
+        sparsity: int,
+        min: Optional[Any] = None,
+        max: Optional[Any] = None,
+        precision: Optional[int] = None,
+    ) -> None:
+        """Options to configure encrypted queries using the rangePreview algorithm.
+
+        .. note:: This feature is experimental only, and not intended for public use.
+
+        :param sparsity: An integer.
+        :param min: A BSON scalar value corresponding to the type being queried.
+        :param max: A BSON scalar value corresponding to the type being queried.
+        :param precision: An integer, may only be set for double or decimal128 types.
+
+        ..
versionadded:: 4.4 + """ + self.min = min + self.max = max + self.sparsity = sparsity + self.precision = precision + + @property + def document(self) -> dict[str, Any]: + doc = {} + for k, v in [ + ("sparsity", int64.Int64(self.sparsity)), + ("precision", self.precision), + ("min", self.min), + ("max", self.max), + ]: + if v is not None: + doc[k] = v + return doc diff --git a/venv/Lib/site-packages/pymongo/errors.py b/venv/Lib/site-packages/pymongo/errors.py new file mode 100644 index 00000000..a781e4a0 --- /dev/null +++ b/venv/Lib/site-packages/pymongo/errors.py @@ -0,0 +1,376 @@ +# Copyright 2009-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Exceptions raised by PyMongo.""" +from __future__ import annotations + +from ssl import SSLCertVerificationError as _CertificateError # noqa: F401 +from typing import TYPE_CHECKING, Any, Iterable, Mapping, Optional, Sequence, Union + +from bson.errors import InvalidDocument + +if TYPE_CHECKING: + from pymongo.typings import _DocumentOut + + +class PyMongoError(Exception): + """Base class for all PyMongo exceptions.""" + + def __init__(self, message: str = "", error_labels: Optional[Iterable[str]] = None) -> None: + super().__init__(message) + self._message = message + self._error_labels = set(error_labels or []) + + def has_error_label(self, label: str) -> bool: + """Return True if this error contains the given label. + + .. versionadded:: 3.7 + """ + return label in self._error_labels + + def _add_error_label(self, label: str) -> None: + """Add the given label to this error.""" + self._error_labels.add(label) + + def _remove_error_label(self, label: str) -> None: + """Remove the given label from this error.""" + self._error_labels.discard(label) + + @property + def timeout(self) -> bool: + """True if this error was caused by a timeout. + + .. versionadded:: 4.2 + """ + return False + + +class ProtocolError(PyMongoError): + """Raised for failures related to the wire protocol.""" + + +class ConnectionFailure(PyMongoError): + """Raised when a connection to the database cannot be made or is lost.""" + + +class WaitQueueTimeoutError(ConnectionFailure): + """Raised when an operation times out waiting to checkout a connection from the pool. + + Subclass of :exc:`~pymongo.errors.ConnectionFailure`. + + .. versionadded:: 4.2 + """ + + @property + def timeout(self) -> bool: + return True + + +class AutoReconnect(ConnectionFailure): + """Raised when a connection to the database is lost and an attempt to + auto-reconnect will be made. + + In order to auto-reconnect you must handle this exception, recognizing that + the operation which caused it has not necessarily succeeded. Future + operations will attempt to open a new connection to the database (and + will continue to raise this exception until the first successful + connection is made). + + Subclass of :exc:`~pymongo.errors.ConnectionFailure`. 
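+
+    A minimal handling sketch (illustrative; assumes an existing
+    ``collection`` and an idempotent read)::
+
+        from pymongo.errors import AutoReconnect
+
+        for _ in range(3):
+            try:
+                doc = collection.find_one({"_id": 1})
+                break
+            except AutoReconnect:
+                # The read may or may not have completed; idempotent
+                # operations are safe to retry.
+                continue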
+ """ + + errors: Union[Mapping[str, Any], Sequence[Any]] + details: Union[Mapping[str, Any], Sequence[Any]] + + def __init__( + self, message: str = "", errors: Optional[Union[Mapping[str, Any], Sequence[Any]]] = None + ) -> None: + error_labels = None + if errors is not None: + if isinstance(errors, dict): + error_labels = errors.get("errorLabels") + super().__init__(message, error_labels) + self.errors = self.details = errors or [] + + +class NetworkTimeout(AutoReconnect): + """An operation on an open connection exceeded socketTimeoutMS. + + The remaining connections in the pool stay open. In the case of a write + operation, you cannot know whether it succeeded or failed. + + Subclass of :exc:`~pymongo.errors.AutoReconnect`. + """ + + @property + def timeout(self) -> bool: + return True + + +def _format_detailed_error( + message: str, details: Optional[Union[Mapping[str, Any], list[Any]]] +) -> str: + if details is not None: + message = f"{message}, full error: {details}" + return message + + +class NotPrimaryError(AutoReconnect): + """The server responded "not primary" or "node is recovering". + + These errors result from a query, write, or command. The operation failed + because the client thought it was using the primary but the primary has + stepped down, or the client thought it was using a healthy secondary but + the secondary is stale and trying to recover. + + The client launches a refresh operation on a background thread, to update + its view of the server as soon as possible after throwing this exception. + + Subclass of :exc:`~pymongo.errors.AutoReconnect`. + + .. versionadded:: 3.12 + """ + + def __init__( + self, message: str = "", errors: Optional[Union[Mapping[str, Any], list[Any]]] = None + ) -> None: + super().__init__(_format_detailed_error(message, errors), errors=errors) + + +class ServerSelectionTimeoutError(AutoReconnect): + """Thrown when no MongoDB server is available for an operation + + If there is no suitable server for an operation PyMongo tries for + ``serverSelectionTimeoutMS`` (default 30 seconds) to find one, then + throws this exception. For example, it is thrown after attempting an + operation when PyMongo cannot connect to any server, or if you attempt + an insert into a replica set that has no primary and does not elect one + within the timeout window, or if you attempt to query with a Read + Preference that the replica set cannot satisfy. + """ + + @property + def timeout(self) -> bool: + return True + + +class ConfigurationError(PyMongoError): + """Raised when something is incorrectly configured.""" + + +class OperationFailure(PyMongoError): + """Raised when a database operation fails. + + .. versionadded:: 2.7 + The :attr:`details` attribute. + """ + + def __init__( + self, + error: str, + code: Optional[int] = None, + details: Optional[Mapping[str, Any]] = None, + max_wire_version: Optional[int] = None, + ) -> None: + error_labels = None + if details is not None: + error_labels = details.get("errorLabels") + super().__init__(_format_detailed_error(error, details), error_labels=error_labels) + self.__code = code + self.__details = details + self.__max_wire_version = max_wire_version + + @property + def _max_wire_version(self) -> Optional[int]: + return self.__max_wire_version + + @property + def code(self) -> Optional[int]: + """The error code returned by the server, if any.""" + return self.__code + + @property + def details(self) -> Optional[Mapping[str, Any]]: + """The complete error document returned by the server. 
+
+        Depending on the error that occurred, the error document
+        may include useful information beyond just the error
+        message. When connected to a mongos the error document
+        may contain one or more subdocuments if errors occurred
+        on multiple shards.
+        """
+        return self.__details
+
+    @property
+    def timeout(self) -> bool:
+        return self.__code in (50,)
+
+
+class CursorNotFound(OperationFailure):
+    """Raised while iterating query results if the cursor is
+    invalidated on the server.
+
+    .. versionadded:: 2.7
+    """
+
+
+class ExecutionTimeout(OperationFailure):
+    """Raised when a database operation times out, exceeding the $maxTimeMS
+    set in the query or command option.
+
+    .. note:: Requires server version **>= 2.6.0**
+
+    .. versionadded:: 2.7
+    """
+
+    @property
+    def timeout(self) -> bool:
+        return True
+
+
+class WriteConcernError(OperationFailure):
+    """Base exception type for errors raised due to write concern.
+
+    .. versionadded:: 3.0
+    """
+
+
+class WriteError(OperationFailure):
+    """Base exception type for errors raised during write operations.
+
+    .. versionadded:: 3.0
+    """
+
+
+class WTimeoutError(WriteConcernError):
+    """Raised when a database operation times out (i.e. wtimeout expires)
+    before replication completes.
+
+    With newer versions of MongoDB the `details` attribute may include
+    write concern fields like 'n', 'updatedExisting', or 'writtenTo'.
+
+    .. versionadded:: 2.7
+    """
+
+    @property
+    def timeout(self) -> bool:
+        return True
+
+
+class DuplicateKeyError(WriteError):
+    """Raised when an insert or update fails due to a duplicate key error."""
+
+
+def _wtimeout_error(error: Any) -> bool:
+    """Return True if this writeConcernError doc is caused by a timeout."""
+    return error.get("code") == 50 or ("errInfo" in error and error["errInfo"].get("wtimeout"))
+
+
+class BulkWriteError(OperationFailure):
+    """Exception class for bulk write errors.
+
+    .. versionadded:: 2.7
+    """
+
+    details: _DocumentOut
+
+    def __init__(self, results: _DocumentOut) -> None:
+        super().__init__("batch op errors occurred", 65, results)
+
+    def __reduce__(self) -> tuple[Any, Any]:
+        return self.__class__, (self.details,)
+
+    @property
+    def timeout(self) -> bool:
+        # Check the last writeConcernError and last writeError to determine if this
+        # BulkWriteError was caused by a timeout.
+        wces = self.details.get("writeConcernErrors", [])
+        if wces and _wtimeout_error(wces[-1]):
+            return True
+
+        werrs = self.details.get("writeErrors", [])
+        if werrs and werrs[-1].get("code") == 50:
+            return True
+        return False
+
+
+class InvalidOperation(PyMongoError):
+    """Raised when a client attempts to perform an invalid operation."""
+
+
+class InvalidName(PyMongoError):
+    """Raised when an invalid name is used."""
+
+
+class CollectionInvalid(PyMongoError):
+    """Raised when collection validation fails."""
+
+
+class InvalidURI(ConfigurationError):
+    """Raised when trying to parse an invalid mongodb URI."""
+
+
+class DocumentTooLarge(InvalidDocument):
+    """Raised when an encoded document is too large for the connected server."""
+
+
+class EncryptionError(PyMongoError):
+    """Raised when encryption or decryption fails.
+
+    This error always wraps another exception which can be retrieved via the
+    :attr:`cause` property.
+
+    ..
versionadded:: 3.9 + """ + + def __init__(self, cause: Exception) -> None: + super().__init__(str(cause)) + self.__cause = cause + + @property + def cause(self) -> Exception: + """The exception that caused this encryption or decryption error.""" + return self.__cause + + @property + def timeout(self) -> bool: + if isinstance(self.__cause, PyMongoError): + return self.__cause.timeout + return False + + +class EncryptedCollectionError(EncryptionError): + """Raised when creating a collection with encrypted_fields fails. + + .. versionadded:: 4.4 + """ + + def __init__(self, cause: Exception, encrypted_fields: Mapping[str, Any]) -> None: + super().__init__(cause) + self.__encrypted_fields = encrypted_fields + + @property + def encrypted_fields(self) -> Mapping[str, Any]: + """The encrypted_fields document that allows inferring which data keys are *known* to be created. + + Note that the returned document is not guaranteed to contain information about *all* of the data keys that + were created, for example in the case of an indefinite error like a timeout. Use the `cause` property to + determine whether a definite or indefinite error caused this error, and only rely on the accuracy of the + encrypted_fields if the error is definite. + """ + return self.__encrypted_fields + + +class _OperationCancelled(AutoReconnect): + """Internal error raised when a socket operation is cancelled.""" diff --git a/venv/Lib/site-packages/pymongo/event_loggers.py b/venv/Lib/site-packages/pymongo/event_loggers.py new file mode 100644 index 00000000..287db3fc --- /dev/null +++ b/venv/Lib/site-packages/pymongo/event_loggers.py @@ -0,0 +1,223 @@ +# Copyright 2020-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Example event logger classes. + +.. versionadded:: 3.11 + +These loggers can be registered using :func:`register` or +:class:`~pymongo.mongo_client.MongoClient`. + +``monitoring.register(CommandLogger())`` + +or + +``MongoClient(event_listeners=[CommandLogger()])`` +""" +from __future__ import annotations + +import logging + +from pymongo import monitoring + + +class CommandLogger(monitoring.CommandListener): + """A simple listener that logs command events. + + Listens for :class:`~pymongo.monitoring.CommandStartedEvent`, + :class:`~pymongo.monitoring.CommandSucceededEvent` and + :class:`~pymongo.monitoring.CommandFailedEvent` events and + logs them at the `INFO` severity level using :mod:`logging`. + .. 
versionadded:: 3.11 + """ + + def started(self, event: monitoring.CommandStartedEvent) -> None: + logging.info( + f"Command {event.command_name} with request id " + f"{event.request_id} started on server " + f"{event.connection_id}" + ) + + def succeeded(self, event: monitoring.CommandSucceededEvent) -> None: + logging.info( + f"Command {event.command_name} with request id " + f"{event.request_id} on server {event.connection_id} " + f"succeeded in {event.duration_micros} " + "microseconds" + ) + + def failed(self, event: monitoring.CommandFailedEvent) -> None: + logging.info( + f"Command {event.command_name} with request id " + f"{event.request_id} on server {event.connection_id} " + f"failed in {event.duration_micros} " + "microseconds" + ) + + +class ServerLogger(monitoring.ServerListener): + """A simple listener that logs server discovery events. + + Listens for :class:`~pymongo.monitoring.ServerOpeningEvent`, + :class:`~pymongo.monitoring.ServerDescriptionChangedEvent`, + and :class:`~pymongo.monitoring.ServerClosedEvent` + events and logs them at the `INFO` severity level using :mod:`logging`. + + .. versionadded:: 3.11 + """ + + def opened(self, event: monitoring.ServerOpeningEvent) -> None: + logging.info(f"Server {event.server_address} added to topology {event.topology_id}") + + def description_changed(self, event: monitoring.ServerDescriptionChangedEvent) -> None: + previous_server_type = event.previous_description.server_type + new_server_type = event.new_description.server_type + if new_server_type != previous_server_type: + # server_type_name was added in PyMongo 3.4 + logging.info( + f"Server {event.server_address} changed type from " + f"{event.previous_description.server_type_name} to " + f"{event.new_description.server_type_name}" + ) + + def closed(self, event: monitoring.ServerClosedEvent) -> None: + logging.warning(f"Server {event.server_address} removed from topology {event.topology_id}") + + +class HeartbeatLogger(monitoring.ServerHeartbeatListener): + """A simple listener that logs server heartbeat events. + + Listens for :class:`~pymongo.monitoring.ServerHeartbeatStartedEvent`, + :class:`~pymongo.monitoring.ServerHeartbeatSucceededEvent`, + and :class:`~pymongo.monitoring.ServerHeartbeatFailedEvent` + events and logs them at the `INFO` severity level using :mod:`logging`. + + .. versionadded:: 3.11 + """ + + def started(self, event: monitoring.ServerHeartbeatStartedEvent) -> None: + logging.info(f"Heartbeat sent to server {event.connection_id}") + + def succeeded(self, event: monitoring.ServerHeartbeatSucceededEvent) -> None: + # The reply.document attribute was added in PyMongo 3.4. + logging.info( + f"Heartbeat to server {event.connection_id} " + "succeeded with reply " + f"{event.reply.document}" + ) + + def failed(self, event: monitoring.ServerHeartbeatFailedEvent) -> None: + logging.warning( + f"Heartbeat to server {event.connection_id} failed with error {event.reply}" + ) + + +class TopologyLogger(monitoring.TopologyListener): + """A simple listener that logs server topology events. + + Listens for :class:`~pymongo.monitoring.TopologyOpenedEvent`, + :class:`~pymongo.monitoring.TopologyDescriptionChangedEvent`, + and :class:`~pymongo.monitoring.TopologyClosedEvent` + events and logs them at the `INFO` severity level using :mod:`logging`. + + .. 
versionadded:: 3.11
+    """
+
+    def opened(self, event: monitoring.TopologyOpenedEvent) -> None:
+        logging.info(f"Topology with id {event.topology_id} opened")
+
+    def description_changed(self, event: monitoring.TopologyDescriptionChangedEvent) -> None:
+        logging.info(f"Topology description updated for topology id {event.topology_id}")
+        previous_topology_type = event.previous_description.topology_type
+        new_topology_type = event.new_description.topology_type
+        if new_topology_type != previous_topology_type:
+            # topology_type_name was added in PyMongo 3.4
+            logging.info(
+                f"Topology {event.topology_id} changed type from "
+                f"{event.previous_description.topology_type_name} to "
+                f"{event.new_description.topology_type_name}"
+            )
+        # The has_writable_server and has_readable_server methods
+        # were added in PyMongo 3.4.
+        if not event.new_description.has_writable_server():
+            logging.warning("No writable servers available.")
+        if not event.new_description.has_readable_server():
+            logging.warning("No readable servers available.")
+
+    def closed(self, event: monitoring.TopologyClosedEvent) -> None:
+        logging.info(f"Topology with id {event.topology_id} closed")
+
+
+class ConnectionPoolLogger(monitoring.ConnectionPoolListener):
+    """A simple listener that logs server connection pool events.
+
+    Listens for :class:`~pymongo.monitoring.PoolCreatedEvent`,
+    :class:`~pymongo.monitoring.PoolClearedEvent`,
+    :class:`~pymongo.monitoring.PoolClosedEvent`,
+    :class:`~pymongo.monitoring.ConnectionCreatedEvent`,
+    :class:`~pymongo.monitoring.ConnectionReadyEvent`,
+    :class:`~pymongo.monitoring.ConnectionClosedEvent`,
+    :class:`~pymongo.monitoring.ConnectionCheckOutStartedEvent`,
+    :class:`~pymongo.monitoring.ConnectionCheckOutFailedEvent`,
+    :class:`~pymongo.monitoring.ConnectionCheckedOutEvent`,
+    and :class:`~pymongo.monitoring.ConnectionCheckedInEvent`
+    events and logs them at the `INFO` severity level using :mod:`logging`.
+
+    ..
versionadded:: 3.11 + """ + + def pool_created(self, event: monitoring.PoolCreatedEvent) -> None: + logging.info(f"[pool {event.address}] pool created") + + def pool_ready(self, event: monitoring.PoolReadyEvent) -> None: + logging.info(f"[pool {event.address}] pool ready") + + def pool_cleared(self, event: monitoring.PoolClearedEvent) -> None: + logging.info(f"[pool {event.address}] pool cleared") + + def pool_closed(self, event: monitoring.PoolClosedEvent) -> None: + logging.info(f"[pool {event.address}] pool closed") + + def connection_created(self, event: monitoring.ConnectionCreatedEvent) -> None: + logging.info(f"[pool {event.address}][conn #{event.connection_id}] connection created") + + def connection_ready(self, event: monitoring.ConnectionReadyEvent) -> None: + logging.info( + f"[pool {event.address}][conn #{event.connection_id}] connection setup succeeded" + ) + + def connection_closed(self, event: monitoring.ConnectionClosedEvent) -> None: + logging.info( + f"[pool {event.address}][conn #{event.connection_id}] " + f'connection closed, reason: "{event.reason}"' + ) + + def connection_check_out_started( + self, event: monitoring.ConnectionCheckOutStartedEvent + ) -> None: + logging.info(f"[pool {event.address}] connection check out started") + + def connection_check_out_failed(self, event: monitoring.ConnectionCheckOutFailedEvent) -> None: + logging.info(f"[pool {event.address}] connection check out failed, reason: {event.reason}") + + def connection_checked_out(self, event: monitoring.ConnectionCheckedOutEvent) -> None: + logging.info( + f"[pool {event.address}][conn #{event.connection_id}] connection checked out of pool" + ) + + def connection_checked_in(self, event: monitoring.ConnectionCheckedInEvent) -> None: + logging.info( + f"[pool {event.address}][conn #{event.connection_id}] connection checked into pool" + ) diff --git a/venv/Lib/site-packages/pymongo/hello.py b/venv/Lib/site-packages/pymongo/hello.py new file mode 100644 index 00000000..0f6d7a39 --- /dev/null +++ b/venv/Lib/site-packages/pymongo/hello.py @@ -0,0 +1,224 @@ +# Copyright 2021-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
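+
+# A quick illustration of the _get_server_type helper below (the response
+# document is abridged and illustrative):
+#
+#     doc = {"ok": 1, "setName": "rs0", "isWritablePrimary": True}
+#     _get_server_type(doc) == SERVER_TYPE.RSPrimary  # True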
+ +"""Helpers for the 'hello' and legacy hello commands.""" +from __future__ import annotations + +import copy +import datetime +import itertools +from typing import Any, Generic, Mapping, Optional + +from bson.objectid import ObjectId +from pymongo import common +from pymongo.server_type import SERVER_TYPE +from pymongo.typings import ClusterTime, _DocumentType + + +class HelloCompat: + CMD = "hello" + LEGACY_CMD = "ismaster" + PRIMARY = "isWritablePrimary" + LEGACY_PRIMARY = "ismaster" + LEGACY_ERROR = "not master" + + +def _get_server_type(doc: Mapping[str, Any]) -> int: + """Determine the server type from a hello response.""" + if not doc.get("ok"): + return SERVER_TYPE.Unknown + + if doc.get("serviceId"): + return SERVER_TYPE.LoadBalancer + elif doc.get("isreplicaset"): + return SERVER_TYPE.RSGhost + elif doc.get("setName"): + if doc.get("hidden"): + return SERVER_TYPE.RSOther + elif doc.get(HelloCompat.PRIMARY): + return SERVER_TYPE.RSPrimary + elif doc.get(HelloCompat.LEGACY_PRIMARY): + return SERVER_TYPE.RSPrimary + elif doc.get("secondary"): + return SERVER_TYPE.RSSecondary + elif doc.get("arbiterOnly"): + return SERVER_TYPE.RSArbiter + else: + return SERVER_TYPE.RSOther + elif doc.get("msg") == "isdbgrid": + return SERVER_TYPE.Mongos + else: + return SERVER_TYPE.Standalone + + +class Hello(Generic[_DocumentType]): + """Parse a hello response from the server. + + .. versionadded:: 3.12 + """ + + __slots__ = ("_doc", "_server_type", "_is_writable", "_is_readable", "_awaitable") + + def __init__(self, doc: _DocumentType, awaitable: bool = False) -> None: + self._server_type = _get_server_type(doc) + self._doc: _DocumentType = doc + self._is_writable = self._server_type in ( + SERVER_TYPE.RSPrimary, + SERVER_TYPE.Standalone, + SERVER_TYPE.Mongos, + SERVER_TYPE.LoadBalancer, + ) + + self._is_readable = self.server_type == SERVER_TYPE.RSSecondary or self._is_writable + self._awaitable = awaitable + + @property + def document(self) -> _DocumentType: + """The complete hello command response document. + + .. 
versionadded:: 3.4 + """ + return copy.copy(self._doc) + + @property + def server_type(self) -> int: + return self._server_type + + @property + def all_hosts(self) -> set[tuple[str, int]]: + """List of hosts, passives, and arbiters known to this server.""" + return set( + map( + common.clean_node, + itertools.chain( + self._doc.get("hosts", []), + self._doc.get("passives", []), + self._doc.get("arbiters", []), + ), + ) + ) + + @property + def tags(self) -> Mapping[str, Any]: + """Replica set member tags or empty dict.""" + return self._doc.get("tags", {}) + + @property + def primary(self) -> Optional[tuple[str, int]]: + """This server's opinion about who the primary is, or None.""" + if self._doc.get("primary"): + return common.partition_node(self._doc["primary"]) + else: + return None + + @property + def replica_set_name(self) -> Optional[str]: + """Replica set name or None.""" + return self._doc.get("setName") + + @property + def max_bson_size(self) -> int: + return self._doc.get("maxBsonObjectSize", common.MAX_BSON_SIZE) + + @property + def max_message_size(self) -> int: + return self._doc.get("maxMessageSizeBytes", 2 * self.max_bson_size) + + @property + def max_write_batch_size(self) -> int: + return self._doc.get("maxWriteBatchSize", common.MAX_WRITE_BATCH_SIZE) + + @property + def min_wire_version(self) -> int: + return self._doc.get("minWireVersion", common.MIN_WIRE_VERSION) + + @property + def max_wire_version(self) -> int: + return self._doc.get("maxWireVersion", common.MAX_WIRE_VERSION) + + @property + def set_version(self) -> Optional[int]: + return self._doc.get("setVersion") + + @property + def election_id(self) -> Optional[ObjectId]: + return self._doc.get("electionId") + + @property + def cluster_time(self) -> Optional[ClusterTime]: + return self._doc.get("$clusterTime") + + @property + def logical_session_timeout_minutes(self) -> Optional[int]: + return self._doc.get("logicalSessionTimeoutMinutes") + + @property + def is_writable(self) -> bool: + return self._is_writable + + @property + def is_readable(self) -> bool: + return self._is_readable + + @property + def me(self) -> Optional[tuple[str, int]]: + me = self._doc.get("me") + if me: + return common.clean_node(me) + return None + + @property + def last_write_date(self) -> Optional[datetime.datetime]: + return self._doc.get("lastWrite", {}).get("lastWriteDate") + + @property + def compressors(self) -> Optional[list[str]]: + return self._doc.get("compression") + + @property + def sasl_supported_mechs(self) -> list[str]: + """Supported authentication mechanisms for the current user. 
+ + For example:: + + >>> hello.sasl_supported_mechs + ["SCRAM-SHA-1", "SCRAM-SHA-256"] + + """ + return self._doc.get("saslSupportedMechs", []) + + @property + def speculative_authenticate(self) -> Optional[Mapping[str, Any]]: + """The speculativeAuthenticate field.""" + return self._doc.get("speculativeAuthenticate") + + @property + def topology_version(self) -> Optional[Mapping[str, Any]]: + return self._doc.get("topologyVersion") + + @property + def awaitable(self) -> bool: + return self._awaitable + + @property + def service_id(self) -> Optional[ObjectId]: + return self._doc.get("serviceId") + + @property + def hello_ok(self) -> bool: + return self._doc.get("helloOk", False) + + @property + def connection_id(self) -> Optional[int]: + return self._doc.get("connectionId") diff --git a/venv/Lib/site-packages/pymongo/helpers.py b/venv/Lib/site-packages/pymongo/helpers.py new file mode 100644 index 00000000..916d78a3 --- /dev/null +++ b/venv/Lib/site-packages/pymongo/helpers.py @@ -0,0 +1,350 @@ +# Copyright 2009-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Bits and pieces used by the driver that don't really fit elsewhere.""" +from __future__ import annotations + +import sys +import traceback +from collections import abc +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Container, + Iterable, + Mapping, + NoReturn, + Optional, + Sequence, + TypeVar, + Union, + cast, +) + +from pymongo import ASCENDING +from pymongo.errors import ( + CursorNotFound, + DuplicateKeyError, + ExecutionTimeout, + NotPrimaryError, + OperationFailure, + WriteConcernError, + WriteError, + WTimeoutError, + _wtimeout_error, +) +from pymongo.hello import HelloCompat + +if TYPE_CHECKING: + from pymongo.cursor import _Hint + from pymongo.operations import _IndexList + from pymongo.typings import _DocumentOut + +# From the SDAM spec, the "node is shutting down" codes. +_SHUTDOWN_CODES: frozenset = frozenset( + [ + 11600, # InterruptedAtShutdown + 91, # ShutdownInProgress + ] +) +# From the SDAM spec, the "not primary" error codes are combined with the +# "node is recovering" error codes (of which the "node is shutting down" +# errors are a subset). +_NOT_PRIMARY_CODES: frozenset = ( + frozenset( + [ + 10058, # LegacyNotPrimary <=3.2 "not primary" error code + 10107, # NotWritablePrimary + 13435, # NotPrimaryNoSecondaryOk + 11602, # InterruptedDueToReplStateChange + 13436, # NotPrimaryOrSecondary + 189, # PrimarySteppedDown + ] + ) + | _SHUTDOWN_CODES +) +# From the retryable writes spec. +_RETRYABLE_ERROR_CODES: frozenset = _NOT_PRIMARY_CODES | frozenset( + [ + 7, # HostNotFound + 6, # HostUnreachable + 89, # NetworkTimeout + 9001, # SocketException + 262, # ExceededTimeLimit + 134, # ReadConcernMajorityNotAvailableYet + ] +) + +# Server code raised when re-authentication is required +_REAUTHENTICATION_REQUIRED_CODE: int = 391 + +# Server code raised when authentication fails. 
+_AUTHENTICATION_FAILURE_CODE: int = 18 + + +def _gen_index_name(keys: _IndexList) -> str: + """Generate an index name from the set of fields it is over.""" + return "_".join(["{}_{}".format(*item) for item in keys]) + + +def _index_list( + key_or_list: _Hint, direction: Optional[Union[int, str]] = None +) -> Sequence[tuple[str, Union[int, str, Mapping[str, Any]]]]: + """Helper to generate a list of (key, direction) pairs. + + Takes such a list, or a single key, or a single key and direction. + """ + if direction is not None: + if not isinstance(key_or_list, str): + raise TypeError("Expected a string and a direction") + return [(key_or_list, direction)] + else: + if isinstance(key_or_list, str): + return [(key_or_list, ASCENDING)] + elif isinstance(key_or_list, abc.ItemsView): + return list(key_or_list) # type: ignore[arg-type] + elif isinstance(key_or_list, abc.Mapping): + return list(key_or_list.items()) + elif not isinstance(key_or_list, (list, tuple)): + raise TypeError("if no direction is specified, key_or_list must be an instance of list") + values: list[tuple[str, int]] = [] + for item in key_or_list: + if isinstance(item, str): + item = (item, ASCENDING) # noqa: PLW2901 + values.append(item) + return values + + +def _index_document(index_list: _IndexList) -> dict[str, Any]: + """Helper to generate an index specifying document. + + Takes a list of (key, direction) pairs. + """ + if not isinstance(index_list, (list, tuple, abc.Mapping)): + raise TypeError( + "must use a dictionary or a list of (key, direction) pairs, not: " + repr(index_list) + ) + if not len(index_list): + raise ValueError("key_or_list must not be empty") + + index: dict[str, Any] = {} + + if isinstance(index_list, abc.Mapping): + for key in index_list: + value = index_list[key] + _validate_index_key_pair(key, value) + index[key] = value + else: + for item in index_list: + if isinstance(item, str): + item = (item, ASCENDING) # noqa: PLW2901 + key, value = item + _validate_index_key_pair(key, value) + index[key] = value + return index + + +def _validate_index_key_pair(key: Any, value: Any) -> None: + if not isinstance(key, str): + raise TypeError("first item in each key pair must be an instance of str") + if not isinstance(value, (str, int, abc.Mapping)): + raise TypeError( + "second item in each key pair must be 1, -1, " + "'2d', or another valid MongoDB index specifier." + ) + + +def _check_command_response( + response: _DocumentOut, + max_wire_version: Optional[int], + allowable_errors: Optional[Container[Union[int, str]]] = None, + parse_write_concern_error: bool = False, +) -> None: + """Check the response to a command for errors.""" + if "ok" not in response: + # Server didn't recognize our message as a command. + raise OperationFailure( + response.get("$err"), # type: ignore[arg-type] + response.get("code"), + response, + max_wire_version, + ) + + if parse_write_concern_error and "writeConcernError" in response: + _error = response["writeConcernError"] + _labels = response.get("errorLabels") + if _labels: + _error.update({"errorLabels": _labels}) + _raise_write_concern_error(_error) + + if response["ok"]: + return + + details = response + # Mongos returns the error details in a 'raw' object + # for some errors. + if "raw" in response: + for shard in response["raw"].values(): + # Grab the first non-empty raw error from a shard. 
+ if shard.get("errmsg") and not shard.get("ok"): + details = shard + break + + errmsg = details["errmsg"] + code = details.get("code") + + # For allowable errors, only check for error messages when the code is not + # included. + if allowable_errors: + if code is not None: + if code in allowable_errors: + return + elif errmsg in allowable_errors: + return + + # Server is "not primary" or "recovering" + if code is not None: + if code in _NOT_PRIMARY_CODES: + raise NotPrimaryError(errmsg, response) + elif HelloCompat.LEGACY_ERROR in errmsg or "node is recovering" in errmsg: + raise NotPrimaryError(errmsg, response) + + # Other errors + # findAndModify with upsert can raise duplicate key error + if code in (11000, 11001, 12582): + raise DuplicateKeyError(errmsg, code, response, max_wire_version) + elif code == 50: + raise ExecutionTimeout(errmsg, code, response, max_wire_version) + elif code == 43: + raise CursorNotFound(errmsg, code, response, max_wire_version) + + raise OperationFailure(errmsg, code, response, max_wire_version) + + +def _raise_last_write_error(write_errors: list[Any]) -> NoReturn: + # If the last batch had multiple errors only report + # the last error to emulate continue_on_error. + error = write_errors[-1] + if error.get("code") == 11000: + raise DuplicateKeyError(error.get("errmsg"), 11000, error) + raise WriteError(error.get("errmsg"), error.get("code"), error) + + +def _raise_write_concern_error(error: Any) -> NoReturn: + if _wtimeout_error(error): + # Make sure we raise WTimeoutError + raise WTimeoutError(error.get("errmsg"), error.get("code"), error) + raise WriteConcernError(error.get("errmsg"), error.get("code"), error) + + +def _get_wce_doc(result: Mapping[str, Any]) -> Optional[Mapping[str, Any]]: + """Return the writeConcernError or None.""" + wce = result.get("writeConcernError") + if wce: + # The server reports errorLabels at the top level but it's more + # convenient to attach it to the writeConcernError doc itself. + error_labels = result.get("errorLabels") + if error_labels: + # Copy to avoid changing the original document. + wce = wce.copy() + wce["errorLabels"] = error_labels + return wce + + +def _check_write_command_response(result: Mapping[str, Any]) -> None: + """Backward compatibility helper for write command error handling.""" + # Prefer write errors over write concern errors + write_errors = result.get("writeErrors") + if write_errors: + _raise_last_write_error(write_errors) + + wce = _get_wce_doc(result) + if wce: + _raise_write_concern_error(wce) + + +def _fields_list_to_dict( + fields: Union[Mapping[str, Any], Iterable[str]], option_name: str +) -> Mapping[str, Any]: + """Takes a sequence of field names and returns a matching dictionary. + + ["a", "b"] becomes {"a": 1, "b": 1} + + and + + ["a.b.c", "d", "a.c"] becomes {"a.b.c": 1, "d": 1, "a.c": 1} + """ + if isinstance(fields, abc.Mapping): + return fields + + if isinstance(fields, (abc.Sequence, abc.Set)): + if not all(isinstance(field, str) for field in fields): + raise TypeError(f"{option_name} must be a list of key names, each an instance of str") + return dict.fromkeys(fields, 1) + + raise TypeError(f"{option_name} must be a mapping or list of key names") + + +def _handle_exception() -> None: + """Print exceptions raised by subscribers to stderr.""" + # Heavily influenced by logging.Handler.handleError. 
+ + # See note here: + # https://docs.python.org/3.4/library/sys.html#sys.__stderr__ + if sys.stderr: + einfo = sys.exc_info() + try: + traceback.print_exception(einfo[0], einfo[1], einfo[2], None, sys.stderr) + except OSError: + pass + finally: + del einfo + + +# See https://mypy.readthedocs.io/en/stable/generics.html?#decorator-factories +F = TypeVar("F", bound=Callable[..., Any]) + + +def _handle_reauth(func: F) -> F: + def inner(*args: Any, **kwargs: Any) -> Any: + no_reauth = kwargs.pop("no_reauth", False) + from pymongo.message import _BulkWriteContext + from pymongo.pool import Connection + + try: + return func(*args, **kwargs) + except OperationFailure as exc: + if no_reauth: + raise + if exc.code == _REAUTHENTICATION_REQUIRED_CODE: + # Look for an argument that either is a Connection + # or has a connection attribute, so we can trigger + # a reauth. + conn = None + for arg in args: + if isinstance(arg, Connection): + conn = arg + break + if isinstance(arg, _BulkWriteContext): + conn = arg.conn + break + if conn: + conn.authenticate(reauthenticate=True) + else: + raise + return func(*args, **kwargs) + raise + + return cast(F, inner) diff --git a/venv/Lib/site-packages/pymongo/lock.py b/venv/Lib/site-packages/pymongo/lock.py new file mode 100644 index 00000000..e3747850 --- /dev/null +++ b/venv/Lib/site-packages/pymongo/lock.py @@ -0,0 +1,40 @@ +# Copyright 2022-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import os +import threading +import weakref + +_HAS_REGISTER_AT_FORK = hasattr(os, "register_at_fork") + +# References to instances of _create_lock +_forkable_locks: weakref.WeakSet[threading.Lock] = weakref.WeakSet() + + +def _create_lock() -> threading.Lock: + """Represents a lock that is tracked upon instantiation using a WeakSet and + reset by pymongo upon forking. + """ + lock = threading.Lock() + if _HAS_REGISTER_AT_FORK: + _forkable_locks.add(lock) + return lock + + +def _release_locks() -> None: + # Completed the fork, reset all the locks in the child. + for lock in _forkable_locks: + if lock.locked(): + lock.release() diff --git a/venv/Lib/site-packages/pymongo/logger.py b/venv/Lib/site-packages/pymongo/logger.py new file mode 100644 index 00000000..2caafa77 --- /dev/null +++ b/venv/Lib/site-packages/pymongo/logger.py @@ -0,0 +1,169 @@ +# Copyright 2023-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
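+
+# For example, an application can surface the messages produced through this
+# module with the standard logging machinery (illustrative configuration):
+#
+#     import logging
+#     logging.basicConfig()
+#     logging.getLogger("pymongo.command").setLevel(logging.DEBUG)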
+from __future__ import annotations + +import enum +import logging +import os +import warnings +from typing import Any + +from bson import UuidRepresentation, json_util +from bson.json_util import JSONOptions, _truncate_documents +from pymongo.monitoring import ConnectionCheckOutFailedReason, ConnectionClosedReason + + +class _CommandStatusMessage(str, enum.Enum): + STARTED = "Command started" + SUCCEEDED = "Command succeeded" + FAILED = "Command failed" + + +class _ServerSelectionStatusMessage(str, enum.Enum): + STARTED = "Server selection started" + SUCCEEDED = "Server selection succeeded" + FAILED = "Server selection failed" + WAITING = "Waiting for suitable server to become available" + + +class _ConnectionStatusMessage(str, enum.Enum): + POOL_CREATED = "Connection pool created" + POOL_READY = "Connection pool ready" + POOL_CLOSED = "Connection pool closed" + POOL_CLEARED = "Connection pool cleared" + + CONN_CREATED = "Connection created" + CONN_READY = "Connection ready" + CONN_CLOSED = "Connection closed" + + CHECKOUT_STARTED = "Connection checkout started" + CHECKOUT_SUCCEEDED = "Connection checked out" + CHECKOUT_FAILED = "Connection checkout failed" + CHECKEDIN = "Connection checked in" + + +_DEFAULT_DOCUMENT_LENGTH = 1000 +_SENSITIVE_COMMANDS = [ + "authenticate", + "saslStart", + "saslContinue", + "getnonce", + "createUser", + "updateUser", + "copydbgetnonce", + "copydbsaslstart", + "copydb", +] +_HELLO_COMMANDS = ["hello", "ismaster", "isMaster"] +_REDACTED_FAILURE_FIELDS = ["code", "codeName", "errorLabels"] +_DOCUMENT_NAMES = ["command", "reply", "failure"] +_JSON_OPTIONS = JSONOptions(uuid_representation=UuidRepresentation.STANDARD) +_COMMAND_LOGGER = logging.getLogger("pymongo.command") +_CONNECTION_LOGGER = logging.getLogger("pymongo.connection") +_SERVER_SELECTION_LOGGER = logging.getLogger("pymongo.serverSelection") +_CLIENT_LOGGER = logging.getLogger("pymongo.client") +_VERBOSE_CONNECTION_ERROR_REASONS = { + ConnectionClosedReason.POOL_CLOSED: "Connection pool was closed", + ConnectionCheckOutFailedReason.POOL_CLOSED: "Connection pool was closed", + ConnectionClosedReason.STALE: "Connection pool was stale", + ConnectionClosedReason.ERROR: "An error occurred while using the connection", + ConnectionCheckOutFailedReason.CONN_ERROR: "An error occurred while trying to establish a new connection", + ConnectionClosedReason.IDLE: "Connection was idle too long", + ConnectionCheckOutFailedReason.TIMEOUT: "Connection exceeded the specified timeout", +} + + +def _debug_log(logger: logging.Logger, **fields: Any) -> None: + logger.debug(LogMessage(**fields)) + + +def _verbose_connection_error_reason(reason: str) -> str: + return _VERBOSE_CONNECTION_ERROR_REASONS.get(reason, reason) + + +def _info_log(logger: logging.Logger, **fields: Any) -> None: + logger.info(LogMessage(**fields)) + + +def _log_or_warn(logger: logging.Logger, message: str) -> None: + if logger.isEnabledFor(logging.INFO): + logger.info(message) + else: + # stacklevel=4 ensures that the warning is for the user's code. 
+ warnings.warn(message, UserWarning, stacklevel=4) + + +class LogMessage: +    __slots__ = ("_kwargs", "_redacted") + +    def __init__(self, **kwargs: Any): +        self._kwargs = kwargs +        self._redacted = False + +    def __str__(self) -> str: +        self._redact() +        return "%s" % ( +            json_util.dumps( +                self._kwargs, json_options=_JSON_OPTIONS, default=lambda o: o.__repr__() +            ) +        ) + +    def _is_sensitive(self, doc_name: str) -> bool: +        is_speculative_authenticate = ( +            self._kwargs.pop("speculative_authenticate", False) +            or "speculativeAuthenticate" in self._kwargs[doc_name] +        ) +        is_sensitive_command = ( +            "commandName" in self._kwargs and self._kwargs["commandName"] in _SENSITIVE_COMMANDS +        ) + +        is_sensitive_hello = ( +            self._kwargs["commandName"] in _HELLO_COMMANDS and is_speculative_authenticate +        ) + +        return is_sensitive_command or is_sensitive_hello + +    def _redact(self) -> None: +        if self._redacted: +            return +        self._kwargs = {k: v for k, v in self._kwargs.items() if v is not None} +        if "durationMS" in self._kwargs and hasattr(self._kwargs["durationMS"], "total_seconds"): +            self._kwargs["durationMS"] = self._kwargs["durationMS"].total_seconds() * 1000 +        if "serviceId" in self._kwargs: +            self._kwargs["serviceId"] = str(self._kwargs["serviceId"]) +        document_length = int(os.getenv("MONGODB_LOG_MAX_DOCUMENT_LENGTH", _DEFAULT_DOCUMENT_LENGTH)) +        if document_length < 0: +            document_length = _DEFAULT_DOCUMENT_LENGTH +        is_server_side_error = self._kwargs.pop("isServerSideError", False) + +        for doc_name in _DOCUMENT_NAMES: +            doc = self._kwargs.get(doc_name) +            if doc: +                if doc_name == "failure" and is_server_side_error: +                    doc = {k: v for k, v in doc.items() if k in _REDACTED_FAILURE_FIELDS} +                if doc_name != "failure" and self._is_sensitive(doc_name): +                    doc = json_util.dumps({}) +                else: +                    truncated_doc = _truncate_documents(doc, document_length)[0] +                    doc = json_util.dumps( +                        truncated_doc, +                        json_options=_JSON_OPTIONS, +                        default=lambda o: o.__repr__(), +                    ) +                    if len(doc) > document_length: +                        doc = ( +                            doc.encode()[:document_length].decode("unicode-escape", "ignore") +                        ) + "..." +                self._kwargs[doc_name] = doc +        self._redacted = True diff --git a/venv/Lib/site-packages/pymongo/max_staleness_selectors.py b/venv/Lib/site-packages/pymongo/max_staleness_selectors.py new file mode 100644 index 00000000..72edf555 --- /dev/null +++ b/venv/Lib/site-packages/pymongo/max_staleness_selectors.py @@ -0,0 +1,122 @@ +# Copyright 2016 MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you +# may not use this file except in compliance with the License. You +# may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. + +"""Criteria to select ServerDescriptions based on maxStalenessSeconds. + +The Max Staleness Spec says: When there is a known primary P, +a secondary S's staleness is estimated with this formula: + +  (S.lastUpdateTime - S.lastWriteDate) - (P.lastUpdateTime - P.lastWriteDate) +  + heartbeatFrequencyMS + +When there is no known primary, a secondary S's staleness is estimated with: + +  SMax.lastWriteDate - S.lastWriteDate + heartbeatFrequencyMS + +where "SMax" is the secondary with the greatest lastWriteDate.
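+ +For example (illustrative numbers only): with a heartbeat frequency of 10 +seconds, a secondary whose (lastUpdateTime - lastWriteDate) gap is 95 seconds +while the primary's gap is 5 seconds has an estimated staleness of +95 - 5 + 10 = 100 seconds, so it is filtered out when maxStalenessSeconds is +90 but kept when it is 120.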
+""" +from __future__ import annotations + +from typing import TYPE_CHECKING + +from pymongo.errors import ConfigurationError +from pymongo.server_type import SERVER_TYPE + +if TYPE_CHECKING: + from pymongo.server_selectors import Selection +# Constant defined in Max Staleness Spec: An idle primary writes a no-op every +# 10 seconds to refresh secondaries' lastWriteDate values. +IDLE_WRITE_PERIOD = 10 +SMALLEST_MAX_STALENESS = 90 + + +def _validate_max_staleness(max_staleness: int, heartbeat_frequency: int) -> None: + # We checked for max staleness -1 before this, it must be positive here. + if max_staleness < heartbeat_frequency + IDLE_WRITE_PERIOD: + raise ConfigurationError( + "maxStalenessSeconds must be at least heartbeatFrequencyMS +" + " %d seconds. maxStalenessSeconds is set to %d," + " heartbeatFrequencyMS is set to %d." + % (IDLE_WRITE_PERIOD, max_staleness, heartbeat_frequency * 1000) + ) + + if max_staleness < SMALLEST_MAX_STALENESS: + raise ConfigurationError( + "maxStalenessSeconds must be at least %d. " + "maxStalenessSeconds is set to %d." % (SMALLEST_MAX_STALENESS, max_staleness) + ) + + +def _with_primary(max_staleness: int, selection: Selection) -> Selection: + """Apply max_staleness, in seconds, to a Selection with a known primary.""" + primary = selection.primary + assert primary + sds = [] + + for s in selection.server_descriptions: + if s.server_type == SERVER_TYPE.RSSecondary: + # See max-staleness.rst for explanation of this formula. + assert s.last_write_date and primary.last_write_date # noqa: PT018 + staleness = ( + (s.last_update_time - s.last_write_date) + - (primary.last_update_time - primary.last_write_date) + + selection.heartbeat_frequency + ) + + if staleness <= max_staleness: + sds.append(s) + else: + sds.append(s) + + return selection.with_server_descriptions(sds) + + +def _no_primary(max_staleness: int, selection: Selection) -> Selection: + """Apply max_staleness, in seconds, to a Selection with no known primary.""" + # Secondary that's replicated the most recent writes. + smax = selection.secondary_with_max_last_write_date() + if not smax: + # No secondaries and no primary, short-circuit out of here. + return selection.with_server_descriptions([]) + + sds = [] + + for s in selection.server_descriptions: + if s.server_type == SERVER_TYPE.RSSecondary: + # See max-staleness.rst for explanation of this formula. + assert smax.last_write_date and s.last_write_date # noqa: PT018 + staleness = smax.last_write_date - s.last_write_date + selection.heartbeat_frequency + + if staleness <= max_staleness: + sds.append(s) + else: + sds.append(s) + + return selection.with_server_descriptions(sds) + + +def select(max_staleness: int, selection: Selection) -> Selection: + """Apply max_staleness, in seconds, to a Selection.""" + if max_staleness == -1: + return selection + + # Server Selection Spec: If the TopologyType is ReplicaSetWithPrimary or + # ReplicaSetNoPrimary, a client MUST raise an error if maxStaleness < + # heartbeatFrequency + IDLE_WRITE_PERIOD, or if maxStaleness < 90. + _validate_max_staleness(max_staleness, selection.heartbeat_frequency) + + if selection.primary: + return _with_primary(max_staleness, selection) + else: + return _no_primary(max_staleness, selection) diff --git a/venv/Lib/site-packages/pymongo/message.py b/venv/Lib/site-packages/pymongo/message.py new file mode 100644 index 00000000..9412dc91 --- /dev/null +++ b/venv/Lib/site-packages/pymongo/message.py @@ -0,0 +1,1753 @@ +# Copyright 2009-present MongoDB, Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tools for creating `messages +<https://www.mongodb.com/docs/manual/reference/mongodb-wire-protocol/>`_ to be sent to +MongoDB. + +.. note:: This module is for internal use and is generally not needed by +   application developers. +""" +from __future__ import annotations + +import datetime +import logging +import random +import struct +from io import BytesIO as _BytesIO +from typing import ( +    TYPE_CHECKING, +    Any, +    Callable, +    Iterable, +    Mapping, +    MutableMapping, +    NoReturn, +    Optional, +    Union, +) + +import bson +from bson import CodecOptions, _decode_selective, _dict_to_bson, _make_c_string, encode +from bson.int64 import Int64 +from bson.raw_bson import ( +    _RAW_ARRAY_BSON_OPTIONS, +    DEFAULT_RAW_BSON_OPTIONS, +    RawBSONDocument, +    _inflate_bson, +) + +try: +    from pymongo import _cmessage # type: ignore[attr-defined] + +    _use_c = True +except ImportError: +    _use_c = False +from pymongo.errors import ( +    ConfigurationError, +    CursorNotFound, +    DocumentTooLarge, +    ExecutionTimeout, +    InvalidOperation, +    NotPrimaryError, +    OperationFailure, +    ProtocolError, +) +from pymongo.hello import HelloCompat +from pymongo.helpers import _handle_reauth +from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log +from pymongo.read_preferences import ReadPreference +from pymongo.write_concern import WriteConcern + +if TYPE_CHECKING: +    from datetime import timedelta + +    from pymongo.client_session import ClientSession +    from pymongo.compression_support import SnappyContext, ZlibContext, ZstdContext +    from pymongo.mongo_client import MongoClient +    from pymongo.monitoring import _EventListeners +    from pymongo.pool import Connection +    from pymongo.read_concern import ReadConcern +    from pymongo.read_preferences import _ServerMode +    from pymongo.typings import _Address, _DocumentOut + +MAX_INT32 = 2147483647 +MIN_INT32 = -2147483648 + +# Overhead allowed for encoded command documents.
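+# For instance (illustrative, assuming the common server default +# maxBsonObjectSize of 16 MiB = 16777216 bytes), _batched_write_command_impl +# below accepts a single encoded operation of up to 16777216 + 16382 bytes +# before rejecting it as too large.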
+_COMMAND_OVERHEAD = 16382 + +_INSERT = 0 +_UPDATE = 1 +_DELETE = 2 + +_EMPTY = b"" +_BSONOBJ = b"\x03" +_ZERO_8 = b"\x00" +_ZERO_16 = b"\x00\x00" +_ZERO_32 = b"\x00\x00\x00\x00" +_ZERO_64 = b"\x00\x00\x00\x00\x00\x00\x00\x00" +_SKIPLIM = b"\x00\x00\x00\x00\xff\xff\xff\xff" +_OP_MAP = { + _INSERT: b"\x04documents\x00\x00\x00\x00\x00", + _UPDATE: b"\x04updates\x00\x00\x00\x00\x00", + _DELETE: b"\x04deletes\x00\x00\x00\x00\x00", +} +_FIELD_MAP = {"insert": "documents", "update": "updates", "delete": "deletes"} + +_UNICODE_REPLACE_CODEC_OPTIONS: CodecOptions[Mapping[str, Any]] = CodecOptions( + unicode_decode_error_handler="replace" +) + + +def _randint() -> int: + """Generate a pseudo random 32 bit integer.""" + return random.randint(MIN_INT32, MAX_INT32) # noqa: S311 + + +def _maybe_add_read_preference( + spec: MutableMapping[str, Any], read_preference: _ServerMode +) -> MutableMapping[str, Any]: + """Add $readPreference to spec when appropriate.""" + mode = read_preference.mode + document = read_preference.document + # Only add $readPreference if it's something other than primary to avoid + # problems with mongos versions that don't support read preferences. Also, + # for maximum backwards compatibility, don't add $readPreference for + # secondaryPreferred unless tags or maxStalenessSeconds are in use (setting + # the secondaryOkay bit has the same effect). + if mode and (mode != ReadPreference.SECONDARY_PREFERRED.mode or len(document) > 1): + if "$query" not in spec: + spec = {"$query": spec} + spec["$readPreference"] = document + return spec + + +def _convert_exception(exception: Exception) -> dict[str, Any]: + """Convert an Exception into a failure document for publishing.""" + return {"errmsg": str(exception), "errtype": exception.__class__.__name__} + + +def _convert_write_result( + operation: str, command: Mapping[str, Any], result: Mapping[str, Any] +) -> dict[str, Any]: + """Convert a legacy write result to write command format.""" + # Based on _merge_legacy from bulk.py + affected = result.get("n", 0) + res = {"ok": 1, "n": affected} + errmsg = result.get("errmsg", result.get("err", "")) + if errmsg: + # The write was successful on at least the primary so don't return. + if result.get("wtimeout"): + res["writeConcernError"] = {"errmsg": errmsg, "code": 64, "errInfo": {"wtimeout": True}} + else: + # The write failed. + error = {"index": 0, "code": result.get("code", 8), "errmsg": errmsg} + if "errInfo" in result: + error["errInfo"] = result["errInfo"] + res["writeErrors"] = [error] + return res + if operation == "insert": + # GLE result for insert is always 0 in most MongoDB versions. + res["n"] = len(command["documents"]) + elif operation == "update": + if "upserted" in result: + res["upserted"] = [{"index": 0, "_id": result["upserted"]}] + # Versions of MongoDB before 2.6 don't return the _id for an + # upsert if _id is not an ObjectId. + elif result.get("updatedExisting") is False and affected == 1: + # If _id is in both the update document *and* the query spec + # the update document _id takes precedence. 
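+            # Illustrative example: converting an upsert of +            # {"q": {"_id": 1}, "u": {"_id": 2}} reports upserted _id 2, +            # because the update document's _id wins below.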
+ update = command["updates"][0] + _id = update["u"].get("_id", update["q"].get("_id")) + res["upserted"] = [{"index": 0, "_id": _id}] + return res + + +_OPTIONS = { + "tailable": 2, + "oplogReplay": 8, + "noCursorTimeout": 16, + "awaitData": 32, + "allowPartialResults": 128, +} + + +_MODIFIERS = { + "$query": "filter", + "$orderby": "sort", + "$hint": "hint", + "$comment": "comment", + "$maxScan": "maxScan", + "$maxTimeMS": "maxTimeMS", + "$max": "max", + "$min": "min", + "$returnKey": "returnKey", + "$showRecordId": "showRecordId", + "$showDiskLoc": "showRecordId", # <= MongoDb 3.0 + "$snapshot": "snapshot", +} + + +def _gen_find_command( + coll: str, + spec: Mapping[str, Any], + projection: Optional[Union[Mapping[str, Any], Iterable[str]]], + skip: int, + limit: int, + batch_size: Optional[int], + options: Optional[int], + read_concern: ReadConcern, + collation: Optional[Mapping[str, Any]] = None, + session: Optional[ClientSession] = None, + allow_disk_use: Optional[bool] = None, +) -> dict[str, Any]: + """Generate a find command document.""" + cmd: dict[str, Any] = {"find": coll} + if "$query" in spec: + cmd.update( + [ + (_MODIFIERS[key], val) if key in _MODIFIERS else (key, val) + for key, val in spec.items() + ] + ) + if "$explain" in cmd: + cmd.pop("$explain") + if "$readPreference" in cmd: + cmd.pop("$readPreference") + else: + cmd["filter"] = spec + + if projection: + cmd["projection"] = projection + if skip: + cmd["skip"] = skip + if limit: + cmd["limit"] = abs(limit) + if limit < 0: + cmd["singleBatch"] = True + if batch_size: + cmd["batchSize"] = batch_size + if read_concern.level and not (session and session.in_transaction): + cmd["readConcern"] = read_concern.document + if collation: + cmd["collation"] = collation + if allow_disk_use is not None: + cmd["allowDiskUse"] = allow_disk_use + if options: + cmd.update([(opt, True) for opt, val in _OPTIONS.items() if options & val]) + + return cmd + + +def _gen_get_more_command( + cursor_id: Optional[int], + coll: str, + batch_size: Optional[int], + max_await_time_ms: Optional[int], + comment: Optional[Any], + conn: Connection, +) -> dict[str, Any]: + """Generate a getMore command document.""" + cmd: dict[str, Any] = {"getMore": cursor_id, "collection": coll} + if batch_size: + cmd["batchSize"] = batch_size + if max_await_time_ms is not None: + cmd["maxTimeMS"] = max_await_time_ms + if comment is not None and conn.max_wire_version >= 9: + cmd["comment"] = comment + return cmd + + +class _Query: + """A query operation.""" + + __slots__ = ( + "flags", + "db", + "coll", + "ntoskip", + "spec", + "fields", + "codec_options", + "read_preference", + "limit", + "batch_size", + "name", + "read_concern", + "collation", + "session", + "client", + "allow_disk_use", + "_as_command", + "exhaust", + ) + + # For compatibility with the _GetMore class. 
+ conn_mgr = None + cursor_id = None + + def __init__( + self, + flags: int, + db: str, + coll: str, + ntoskip: int, + spec: Mapping[str, Any], + fields: Optional[Mapping[str, Any]], + codec_options: CodecOptions, + read_preference: _ServerMode, + limit: int, + batch_size: int, + read_concern: ReadConcern, + collation: Optional[Mapping[str, Any]], + session: Optional[ClientSession], + client: MongoClient, + allow_disk_use: Optional[bool], + exhaust: bool, + ): + self.flags = flags + self.db = db + self.coll = coll + self.ntoskip = ntoskip + self.spec = spec + self.fields = fields + self.codec_options = codec_options + self.read_preference = read_preference + self.read_concern = read_concern + self.limit = limit + self.batch_size = batch_size + self.collation = collation + self.session = session + self.client = client + self.allow_disk_use = allow_disk_use + self.name = "find" + self._as_command: Optional[tuple[dict[str, Any], str]] = None + self.exhaust = exhaust + + def reset(self) -> None: + self._as_command = None + + def namespace(self) -> str: + return f"{self.db}.{self.coll}" + + def use_command(self, conn: Connection) -> bool: + use_find_cmd = False + if not self.exhaust: + use_find_cmd = True + elif conn.max_wire_version >= 8: + # OP_MSG supports exhaust on MongoDB 4.2+ + use_find_cmd = True + elif not self.read_concern.ok_for_legacy: + raise ConfigurationError( + "read concern level of %s is not valid " + "with a max wire version of %d." % (self.read_concern.level, conn.max_wire_version) + ) + + conn.validate_session(self.client, self.session) + return use_find_cmd + + def as_command( + self, conn: Connection, apply_timeout: bool = False + ) -> tuple[dict[str, Any], str]: + """Return a find command document for this query.""" + # We use the command twice: on the wire and for command monitoring. + # Generate it once, for speed and to avoid repeating side-effects. + if self._as_command is not None: + return self._as_command + + explain = "$explain" in self.spec + cmd: dict[str, Any] = _gen_find_command( + self.coll, + self.spec, + self.fields, + self.ntoskip, + self.limit, + self.batch_size, + self.flags, + self.read_concern, + self.collation, + self.session, + self.allow_disk_use, + ) + if explain: + self.name = "explain" + cmd = {"explain": cmd} + session = self.session + conn.add_server_api(cmd) + if session: + session._apply_to(cmd, False, self.read_preference, conn) + # Explain does not support readConcern. + if not explain and not session.in_transaction: + session._update_read_concern(cmd, conn) + conn.send_cluster_time(cmd, session, self.client) + # Support auto encryption + client = self.client + if client._encrypter and not client._encrypter._bypass_auto_encryption: + cmd = client._encrypter.encrypt(self.db, cmd, self.codec_options) + # Support CSOT + if apply_timeout: + conn.apply_timeout(client, cmd) + self._as_command = cmd, self.db + return self._as_command + + def get_message( + self, read_preference: _ServerMode, conn: Connection, use_cmd: bool = False + ) -> tuple[int, bytes, int]: + """Get a query message, possibly setting the secondaryOk bit.""" + # Use the read_preference decided by _socket_from_server. + self.read_preference = read_preference + if read_preference.mode: + # Set the secondaryOk bit. 
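+            # In the OP_QUERY flags field, bit 2 (value 4) is the +            # secondaryOk (historically slaveOk) bit.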
+ flags = self.flags | 4 + else: + flags = self.flags + + ns = self.namespace() + spec = self.spec + + if use_cmd: + spec = self.as_command(conn, apply_timeout=True)[0] + request_id, msg, size, _ = _op_msg( + 0, + spec, + self.db, + read_preference, + self.codec_options, + ctx=conn.compression_context, + ) + return request_id, msg, size + + # OP_QUERY treats ntoreturn of -1 and 1 the same, return + # one document and close the cursor. We have to use 2 for + # batch size if 1 is specified. + ntoreturn = self.batch_size == 1 and 2 or self.batch_size + if self.limit: + if ntoreturn: + ntoreturn = min(self.limit, ntoreturn) + else: + ntoreturn = self.limit + + if conn.is_mongos: + assert isinstance(spec, MutableMapping) + spec = _maybe_add_read_preference(spec, read_preference) + + return _query( + flags, + ns, + self.ntoskip, + ntoreturn, + spec, + None if use_cmd else self.fields, + self.codec_options, + ctx=conn.compression_context, + ) + + +class _GetMore: + """A getmore operation.""" + + __slots__ = ( + "db", + "coll", + "ntoreturn", + "cursor_id", + "max_await_time_ms", + "codec_options", + "read_preference", + "session", + "client", + "conn_mgr", + "_as_command", + "exhaust", + "comment", + ) + + name = "getMore" + + def __init__( + self, + db: str, + coll: str, + ntoreturn: int, + cursor_id: int, + codec_options: CodecOptions, + read_preference: _ServerMode, + session: Optional[ClientSession], + client: MongoClient, + max_await_time_ms: Optional[int], + conn_mgr: Any, + exhaust: bool, + comment: Any, + ): + self.db = db + self.coll = coll + self.ntoreturn = ntoreturn + self.cursor_id = cursor_id + self.codec_options = codec_options + self.read_preference = read_preference + self.session = session + self.client = client + self.max_await_time_ms = max_await_time_ms + self.conn_mgr = conn_mgr + self._as_command: Optional[tuple[dict[str, Any], str]] = None + self.exhaust = exhaust + self.comment = comment + + def reset(self) -> None: + self._as_command = None + + def namespace(self) -> str: + return f"{self.db}.{self.coll}" + + def use_command(self, conn: Connection) -> bool: + use_cmd = False + if not self.exhaust: + use_cmd = True + elif conn.max_wire_version >= 8: + # OP_MSG supports exhaust on MongoDB 4.2+ + use_cmd = True + + conn.validate_session(self.client, self.session) + return use_cmd + + def as_command( + self, conn: Connection, apply_timeout: bool = False + ) -> tuple[dict[str, Any], str]: + """Return a getMore command document for this query.""" + # See _Query.as_command for an explanation of this caching. 
+ if self._as_command is not None: +            return self._as_command + +        cmd: dict[str, Any] = _gen_get_more_command( +            self.cursor_id, +            self.coll, +            self.ntoreturn, +            self.max_await_time_ms, +            self.comment, +            conn, +        ) +        if self.session: +            self.session._apply_to(cmd, False, self.read_preference, conn) +        conn.add_server_api(cmd) +        conn.send_cluster_time(cmd, self.session, self.client) +        # Support auto encryption +        client = self.client +        if client._encrypter and not client._encrypter._bypass_auto_encryption: +            cmd = client._encrypter.encrypt(self.db, cmd, self.codec_options) +        # Support CSOT +        if apply_timeout: +            conn.apply_timeout(client, cmd=None) +        self._as_command = cmd, self.db +        return self._as_command + +    def get_message( +        self, dummy0: Any, conn: Connection, use_cmd: bool = False +    ) -> Union[tuple[int, bytes, int], tuple[int, bytes]]: +        """Get a getmore message.""" +        ns = self.namespace() +        ctx = conn.compression_context + +        if use_cmd: +            spec = self.as_command(conn, apply_timeout=True)[0] +            if self.conn_mgr and self.exhaust: +                flags = _OpMsg.EXHAUST_ALLOWED +            else: +                flags = 0 +            request_id, msg, size, _ = _op_msg( +                flags, spec, self.db, None, self.codec_options, ctx=conn.compression_context +            ) +            return request_id, msg, size + +        return _get_more(ns, self.ntoreturn, self.cursor_id, ctx) + + +class _RawBatchQuery(_Query): +    def use_command(self, conn: Connection) -> bool: +        # Compatibility checks. +        super().use_command(conn) +        if conn.max_wire_version >= 8: +            # MongoDB 4.2+ supports exhaust over OP_MSG +            return True +        elif not self.exhaust: +            return True +        return False + + +class _RawBatchGetMore(_GetMore): +    def use_command(self, conn: Connection) -> bool: +        # Compatibility checks. +        super().use_command(conn) +        if conn.max_wire_version >= 8: +            # MongoDB 4.2+ supports exhaust over OP_MSG +            return True +        elif not self.exhaust: +            return True +        return False + + +class _CursorAddress(tuple): +    """The server address (host, port) of a cursor, with namespace property.""" + +    __namespace: Any + +    def __new__(cls, address: _Address, namespace: str) -> _CursorAddress: +        self = tuple.__new__(cls, address) +        self.__namespace = namespace +        return self + +    @property +    def namespace(self) -> str: +        """The namespace of this cursor.""" +        return self.__namespace + +    def __hash__(self) -> int: +        # Two _CursorAddress instances with different namespaces +        # must not hash the same. +        return ((*self, self.__namespace)).__hash__() + +    def __eq__(self, other: object) -> bool: +        if isinstance(other, _CursorAddress): +            return tuple(self) == tuple(other) and self.namespace == other.namespace +        return NotImplemented + +    def __ne__(self, other: object) -> bool: +        return not self == other + + +_pack_compression_header = struct.Struct("<iiiiiiB").pack +_COMPRESSION_HEADER_SIZE = 25 + + +def _compress( +    operation: int, data: bytes, ctx: Union[SnappyContext, ZlibContext, ZstdContext] +) -> tuple[int, bytes]: +    """Takes message data, compresses it, and adds an OP_COMPRESSED header.""" +    compressed = ctx.compress(data) +    request_id = _randint() + +    header = _pack_compression_header( +        _COMPRESSION_HEADER_SIZE + len(compressed), # Total message length +        request_id, # Request id +        0, # responseTo +        2012, # operation id +        operation, # original operation id +        len(data), # uncompressed message length +        ctx.compressor_id, +    ) # compressor id +    return request_id, header + compressed + + +_pack_header = struct.Struct("<iiii").pack + + +def __pack_message(operation: int, data: bytes) -> tuple[int, bytes]: +    """Takes message data and adds a message header based on the operation. + +    Returns the request id and the resultant message bytes.
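+ +    For reference (MongoDB wire protocol, common to every opcode): the header +    packed below is four little-endian int32s: messageLength, requestID, +    responseTo, and opCode.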
+ """ + rid = _randint() + message = _pack_header(16 + len(data), rid, 0, operation) + return rid, message + data + + +_pack_int = struct.Struct(" tuple[bytes, int, int]: + """Get a OP_MSG message. + + Note: this method handles multiple documents in a type one payload but + it does not perform batch splitting and the total message size is + only checked *after* generating the entire message. + """ + # Encode the command document in payload 0 without checking keys. + encoded = _dict_to_bson(command, False, opts) + flags_type = _pack_op_msg_flags_type(flags, 0) + total_size = len(encoded) + max_doc_size = 0 + if identifier and docs is not None: + type_one = _pack_byte(1) + cstring = _make_c_string(identifier) + encoded_docs = [_dict_to_bson(doc, False, opts) for doc in docs] + size = len(cstring) + sum(len(doc) for doc in encoded_docs) + 4 + encoded_size = _pack_int(size) + total_size += size + max_doc_size = max(len(doc) for doc in encoded_docs) + data = [flags_type, encoded, type_one, encoded_size, cstring, *encoded_docs] + else: + data = [flags_type, encoded] + return b"".join(data), total_size, max_doc_size + + +def _op_msg_compressed( + flags: int, + command: Mapping[str, Any], + identifier: str, + docs: Optional[list[Mapping[str, Any]]], + opts: CodecOptions, + ctx: Union[SnappyContext, ZlibContext, ZstdContext], +) -> tuple[int, bytes, int, int]: + """Internal OP_MSG message helper.""" + msg, total_size, max_bson_size = _op_msg_no_header(flags, command, identifier, docs, opts) + rid, msg = _compress(2013, msg, ctx) + return rid, msg, total_size, max_bson_size + + +def _op_msg_uncompressed( + flags: int, + command: Mapping[str, Any], + identifier: str, + docs: Optional[list[Mapping[str, Any]]], + opts: CodecOptions, +) -> tuple[int, bytes, int, int]: + """Internal compressed OP_MSG message helper.""" + data, total_size, max_bson_size = _op_msg_no_header(flags, command, identifier, docs, opts) + request_id, op_message = __pack_message(2013, data) + return request_id, op_message, total_size, max_bson_size + + +if _use_c: + _op_msg_uncompressed = _cmessage._op_msg + + +def _op_msg( + flags: int, + command: MutableMapping[str, Any], + dbname: str, + read_preference: Optional[_ServerMode], + opts: CodecOptions, + ctx: Union[SnappyContext, ZlibContext, ZstdContext, None] = None, +) -> tuple[int, bytes, int, int]: + """Get a OP_MSG message.""" + command["$db"] = dbname + # getMore commands do not send $readPreference. + if read_preference is not None and "$readPreference" not in command: + # Only send $readPreference if it's not primary (the default). + if read_preference.mode: + command["$readPreference"] = read_preference.document + name = next(iter(command)) + try: + identifier = _FIELD_MAP[name] + docs = command.pop(identifier) + except KeyError: + identifier = "" + docs = None + try: + if ctx: + return _op_msg_compressed(flags, command, identifier, docs, opts, ctx) + return _op_msg_uncompressed(flags, command, identifier, docs, opts) + finally: + # Add the field back to the command. 
+ if identifier: +            command[identifier] = docs + + +def _query_impl( +    options: int, +    collection_name: str, +    num_to_skip: int, +    num_to_return: int, +    query: Mapping[str, Any], +    field_selector: Optional[Mapping[str, Any]], +    opts: CodecOptions, +) -> tuple[bytes, int]: +    """Get an OP_QUERY message.""" +    encoded = _dict_to_bson(query, False, opts) +    if field_selector: +        efs = _dict_to_bson(field_selector, False, opts) +    else: +        efs = b"" +    max_bson_size = max(len(encoded), len(efs)) +    return ( +        b"".join( +            [ +                _pack_int(options), +                _make_c_string(collection_name), +                _pack_int(num_to_skip), +                _pack_int(num_to_return), +                encoded, +                efs, +            ] +        ), +        max_bson_size, +    ) + + +def _query_compressed( +    options: int, +    collection_name: str, +    num_to_skip: int, +    num_to_return: int, +    query: Mapping[str, Any], +    field_selector: Optional[Mapping[str, Any]], +    opts: CodecOptions, +    ctx: Union[SnappyContext, ZlibContext, ZstdContext], +) -> tuple[int, bytes, int]: +    """Internal compressed query message helper.""" +    op_query, max_bson_size = _query_impl( +        options, collection_name, num_to_skip, num_to_return, query, field_selector, opts +    ) +    rid, msg = _compress(2004, op_query, ctx) +    return rid, msg, max_bson_size + + +def _query_uncompressed( +    options: int, +    collection_name: str, +    num_to_skip: int, +    num_to_return: int, +    query: Mapping[str, Any], +    field_selector: Optional[Mapping[str, Any]], +    opts: CodecOptions, +) -> tuple[int, bytes, int]: +    """Internal query message helper.""" +    op_query, max_bson_size = _query_impl( +        options, collection_name, num_to_skip, num_to_return, query, field_selector, opts +    ) +    rid, msg = __pack_message(2004, op_query) +    return rid, msg, max_bson_size + + +if _use_c: +    _query_uncompressed = _cmessage._query_message + + +def _query( +    options: int, +    collection_name: str, +    num_to_skip: int, +    num_to_return: int, +    query: Mapping[str, Any], +    field_selector: Optional[Mapping[str, Any]], +    opts: CodecOptions, +    ctx: Union[SnappyContext, ZlibContext, ZstdContext, None] = None, +) -> tuple[int, bytes, int]: +    """Get a **query** message.""" +    if ctx: +        return _query_compressed( +            options, collection_name, num_to_skip, num_to_return, query, field_selector, opts, ctx +        ) +    return _query_uncompressed( +        options, collection_name, num_to_skip, num_to_return, query, field_selector, opts +    ) + + +_pack_long_long = struct.Struct("<q").pack + + +def _get_more_impl(collection_name: str, num_to_return: int, cursor_id: int) -> bytes: +    """Get an OP_GET_MORE message.""" +    return b"".join( +        [ +            _ZERO_32, +            _make_c_string(collection_name), +            _pack_int(num_to_return), +            _pack_long_long(cursor_id), +        ] +    ) + + +def _get_more_compressed( +    collection_name: str, +    num_to_return: int, +    cursor_id: int, +    ctx: Union[SnappyContext, ZlibContext, ZstdContext], +) -> tuple[int, bytes]: +    """Internal compressed getMore message helper.""" +    return _compress(2005, _get_more_impl(collection_name, num_to_return, cursor_id), ctx) + + +def _get_more_uncompressed( +    collection_name: str, num_to_return: int, cursor_id: int +) -> tuple[int, bytes]: +    """Internal getMore message helper.""" +    return __pack_message(2005, _get_more_impl(collection_name, num_to_return, cursor_id)) + + +if _use_c: +    _get_more_uncompressed = _cmessage._get_more_message + + +def _get_more( +    collection_name: str, +    num_to_return: int, +    cursor_id: int, +    ctx: Union[SnappyContext, ZlibContext, ZstdContext, None] = None, +) -> tuple[int, bytes]: +    """Get a **getMore** message.""" +    if ctx: +        return _get_more_compressed(collection_name, num_to_return, cursor_id, ctx) +    return _get_more_uncompressed(collection_name,
num_to_return, cursor_id) + + +class _BulkWriteContext: + """A wrapper around Connection for use with write splitting functions.""" + + __slots__ = ( + "db_name", + "conn", + "op_id", + "name", + "field", + "publish", + "start_time", + "listeners", + "session", + "compress", + "op_type", + "codec", + ) + + def __init__( + self, + database_name: str, + cmd_name: str, + conn: Connection, + operation_id: int, + listeners: _EventListeners, + session: ClientSession, + op_type: int, + codec: CodecOptions, + ): + self.db_name = database_name + self.conn = conn + self.op_id = operation_id + self.listeners = listeners + self.publish = listeners.enabled_for_commands + self.name = cmd_name + self.field = _FIELD_MAP[self.name] + self.start_time = datetime.datetime.now() + self.session = session + self.compress = bool(conn.compression_context) + self.op_type = op_type + self.codec = codec + + def __batch_command( + self, cmd: MutableMapping[str, Any], docs: list[Mapping[str, Any]] + ) -> tuple[int, bytes, list[Mapping[str, Any]]]: + namespace = self.db_name + ".$cmd" + request_id, msg, to_send = _do_batched_op_msg( + namespace, self.op_type, cmd, docs, self.codec, self + ) + if not to_send: + raise InvalidOperation("cannot do an empty bulk write") + return request_id, msg, to_send + + def execute( + self, cmd: MutableMapping[str, Any], docs: list[Mapping[str, Any]], client: MongoClient + ) -> tuple[Mapping[str, Any], list[Mapping[str, Any]]]: + request_id, msg, to_send = self.__batch_command(cmd, docs) + result = self.write_command(cmd, request_id, msg, to_send, client) + client._process_response(result, self.session) + return result, to_send + + def execute_unack( + self, cmd: MutableMapping[str, Any], docs: list[Mapping[str, Any]], client: MongoClient + ) -> list[Mapping[str, Any]]: + request_id, msg, to_send = self.__batch_command(cmd, docs) + # Though this isn't strictly a "legacy" write, the helper + # handles publishing commands and sending our message + # without receiving a result. Send 0 for max_doc_size + # to disable size checking. Size checking is handled while + # the documents are encoded to BSON. + self.unack_write(cmd, request_id, msg, 0, to_send, client) + return to_send + + @property + def max_bson_size(self) -> int: + """A proxy for SockInfo.max_bson_size.""" + return self.conn.max_bson_size + + @property + def max_message_size(self) -> int: + """A proxy for SockInfo.max_message_size.""" + if self.compress: + # Subtract 16 bytes for the message header. 
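+            # (After compression, _compress prepends its own 16-byte +            # header, so the pre-compression payload must leave room for it.)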
+ return self.conn.max_message_size - 16 + return self.conn.max_message_size + + @property + def max_write_batch_size(self) -> int: + """A proxy for SockInfo.max_write_batch_size.""" + return self.conn.max_write_batch_size + + @property + def max_split_size(self) -> int: + """The maximum size of a BSON command before batch splitting.""" + return self.max_bson_size + + def unack_write( + self, + cmd: MutableMapping[str, Any], + request_id: int, + msg: bytes, + max_doc_size: int, + docs: list[Mapping[str, Any]], + client: MongoClient, + ) -> Optional[Mapping[str, Any]]: + """A proxy for Connection.unack_write that handles event publishing.""" + if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + clientId=client._topology_settings._topology_id, + message=_CommandStatusMessage.STARTED, + command=cmd, + commandName=next(iter(cmd)), + databaseName=self.db_name, + requestId=request_id, + operationId=request_id, + driverConnectionId=self.conn.id, + serverConnectionId=self.conn.server_connection_id, + serverHost=self.conn.address[0], + serverPort=self.conn.address[1], + serviceId=self.conn.service_id, + ) + if self.publish: + cmd = self._start(cmd, request_id, docs) + try: + result = self.conn.unack_write(msg, max_doc_size) # type: ignore[func-returns-value] + duration = datetime.datetime.now() - self.start_time + if result is not None: + reply = _convert_write_result(self.name, cmd, result) + else: + # Comply with APM spec. + reply = {"ok": 1} + if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + clientId=client._topology_settings._topology_id, + message=_CommandStatusMessage.SUCCEEDED, + durationMS=duration, + reply=reply, + commandName=next(iter(cmd)), + databaseName=self.db_name, + requestId=request_id, + operationId=request_id, + driverConnectionId=self.conn.id, + serverConnectionId=self.conn.server_connection_id, + serverHost=self.conn.address[0], + serverPort=self.conn.address[1], + serviceId=self.conn.service_id, + ) + if self.publish: + self._succeed(request_id, reply, duration) + except Exception as exc: + duration = datetime.datetime.now() - self.start_time + if isinstance(exc, OperationFailure): + failure: _DocumentOut = _convert_write_result(self.name, cmd, exc.details) # type: ignore[arg-type] + elif isinstance(exc, NotPrimaryError): + failure = exc.details # type: ignore[assignment] + else: + failure = _convert_exception(exc) + if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + clientId=client._topology_settings._topology_id, + message=_CommandStatusMessage.FAILED, + durationMS=duration, + failure=failure, + commandName=next(iter(cmd)), + databaseName=self.db_name, + requestId=request_id, + operationId=request_id, + driverConnectionId=self.conn.id, + serverConnectionId=self.conn.server_connection_id, + serverHost=self.conn.address[0], + serverPort=self.conn.address[1], + serviceId=self.conn.service_id, + isServerSideError=isinstance(exc, OperationFailure), + ) + if self.publish: + assert self.start_time is not None + self._fail(request_id, failure, duration) + raise + finally: + self.start_time = datetime.datetime.now() + return result + + @_handle_reauth + def write_command( + self, + cmd: MutableMapping[str, Any], + request_id: int, + msg: bytes, + docs: list[Mapping[str, Any]], + client: MongoClient, + ) -> dict[str, Any]: + """A proxy for SocketInfo.write_command that handles event publishing.""" + cmd[self.field] = docs + if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + 
_COMMAND_LOGGER, + clientId=client._topology_settings._topology_id, + message=_CommandStatusMessage.STARTED, + command=cmd, + commandName=next(iter(cmd)), + databaseName=self.db_name, + requestId=request_id, + operationId=request_id, + driverConnectionId=self.conn.id, + serverConnectionId=self.conn.server_connection_id, + serverHost=self.conn.address[0], + serverPort=self.conn.address[1], + serviceId=self.conn.service_id, + ) + if self.publish: + self._start(cmd, request_id, docs) + try: + reply = self.conn.write_command(request_id, msg, self.codec) + duration = datetime.datetime.now() - self.start_time + if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + clientId=client._topology_settings._topology_id, + message=_CommandStatusMessage.SUCCEEDED, + durationMS=duration, + reply=reply, + commandName=next(iter(cmd)), + databaseName=self.db_name, + requestId=request_id, + operationId=request_id, + driverConnectionId=self.conn.id, + serverConnectionId=self.conn.server_connection_id, + serverHost=self.conn.address[0], + serverPort=self.conn.address[1], + serviceId=self.conn.service_id, + ) + if self.publish: + self._succeed(request_id, reply, duration) + except Exception as exc: + duration = datetime.datetime.now() - self.start_time + if isinstance(exc, (NotPrimaryError, OperationFailure)): + failure: _DocumentOut = exc.details # type: ignore[assignment] + else: + failure = _convert_exception(exc) + if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + clientId=client._topology_settings._topology_id, + message=_CommandStatusMessage.FAILED, + durationMS=duration, + failure=failure, + commandName=next(iter(cmd)), + databaseName=self.db_name, + requestId=request_id, + operationId=request_id, + driverConnectionId=self.conn.id, + serverConnectionId=self.conn.server_connection_id, + serverHost=self.conn.address[0], + serverPort=self.conn.address[1], + serviceId=self.conn.service_id, + isServerSideError=isinstance(exc, OperationFailure), + ) + + if self.publish: + self._fail(request_id, failure, duration) + raise + finally: + self.start_time = datetime.datetime.now() + return reply + + def _start( + self, cmd: MutableMapping[str, Any], request_id: int, docs: list[Mapping[str, Any]] + ) -> MutableMapping[str, Any]: + """Publish a CommandStartedEvent.""" + cmd[self.field] = docs + self.listeners.publish_command_start( + cmd, + self.db_name, + request_id, + self.conn.address, + self.conn.server_connection_id, + self.op_id, + self.conn.service_id, + ) + return cmd + + def _succeed(self, request_id: int, reply: _DocumentOut, duration: timedelta) -> None: + """Publish a CommandSucceededEvent.""" + self.listeners.publish_command_success( + duration, + reply, + self.name, + request_id, + self.conn.address, + self.conn.server_connection_id, + self.op_id, + self.conn.service_id, + database_name=self.db_name, + ) + + def _fail(self, request_id: int, failure: _DocumentOut, duration: timedelta) -> None: + """Publish a CommandFailedEvent.""" + self.listeners.publish_command_failure( + duration, + failure, + self.name, + request_id, + self.conn.address, + self.conn.server_connection_id, + self.op_id, + self.conn.service_id, + database_name=self.db_name, + ) + + +# From the Client Side Encryption spec: +# Because automatic encryption increases the size of commands, the driver +# MUST split bulk writes at a reduced size limit before undergoing automatic +# encryption. The write payload MUST be split at 2MiB (2097152). 
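+# For example (illustrative): with auto encryption enabled, a bulk insert of +# 900 KiB documents goes two documents per batch, since adding a third would +# push the batch past this limit (see _EncryptedBulkWriteContext.max_split_size +# below).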
+_MAX_SPLIT_SIZE_ENC = 2097152 + + +class _EncryptedBulkWriteContext(_BulkWriteContext): + __slots__ = () + + def __batch_command( + self, cmd: MutableMapping[str, Any], docs: list[Mapping[str, Any]] + ) -> tuple[dict[str, Any], list[Mapping[str, Any]]]: + namespace = self.db_name + ".$cmd" + msg, to_send = _encode_batched_write_command( + namespace, self.op_type, cmd, docs, self.codec, self + ) + if not to_send: + raise InvalidOperation("cannot do an empty bulk write") + + # Chop off the OP_QUERY header to get a properly batched write command. + cmd_start = msg.index(b"\x00", 4) + 9 + outgoing = _inflate_bson(memoryview(msg)[cmd_start:], DEFAULT_RAW_BSON_OPTIONS) + return outgoing, to_send + + def execute( + self, cmd: MutableMapping[str, Any], docs: list[Mapping[str, Any]], client: MongoClient + ) -> tuple[Mapping[str, Any], list[Mapping[str, Any]]]: + batched_cmd, to_send = self.__batch_command(cmd, docs) + result: Mapping[str, Any] = self.conn.command( + self.db_name, batched_cmd, codec_options=self.codec, session=self.session, client=client + ) + return result, to_send + + def execute_unack( + self, cmd: MutableMapping[str, Any], docs: list[Mapping[str, Any]], client: MongoClient + ) -> list[Mapping[str, Any]]: + batched_cmd, to_send = self.__batch_command(cmd, docs) + self.conn.command( + self.db_name, + batched_cmd, + write_concern=WriteConcern(w=0), + session=self.session, + client=client, + ) + return to_send + + @property + def max_split_size(self) -> int: + """Reduce the batch splitting size.""" + return _MAX_SPLIT_SIZE_ENC + + +def _raise_document_too_large(operation: str, doc_size: int, max_size: int) -> NoReturn: + """Internal helper for raising DocumentTooLarge.""" + if operation == "insert": + raise DocumentTooLarge( + "BSON document too large (%d bytes)" + " - the connected server supports" + " BSON document sizes up to %d" + " bytes." % (doc_size, max_size) + ) + else: + # There's nothing intelligent we can say + # about size for update and delete + raise DocumentTooLarge(f"{operation!r} command document too large") + + +# OP_MSG ------------------------------------------------------------- + + +_OP_MSG_MAP = { + _INSERT: b"documents\x00", + _UPDATE: b"updates\x00", + _DELETE: b"deletes\x00", +} + + +def _batched_op_msg_impl( + operation: int, + command: Mapping[str, Any], + docs: list[Mapping[str, Any]], + ack: bool, + opts: CodecOptions, + ctx: _BulkWriteContext, + buf: _BytesIO, +) -> tuple[list[Mapping[str, Any]], int]: + """Create a batched OP_MSG write.""" + max_bson_size = ctx.max_bson_size + max_write_batch_size = ctx.max_write_batch_size + max_message_size = ctx.max_message_size + + flags = b"\x00\x00\x00\x00" if ack else b"\x02\x00\x00\x00" + # Flags + buf.write(flags) + + # Type 0 Section + buf.write(b"\x00") + buf.write(_dict_to_bson(command, False, opts)) + + # Type 1 Section + buf.write(b"\x01") + size_location = buf.tell() + # Save space for size + buf.write(b"\x00\x00\x00\x00") + try: + buf.write(_OP_MSG_MAP[operation]) + except KeyError: + raise InvalidOperation("Unknown command") from None + + to_send = [] + idx = 0 + for doc in docs: + # Encode the current operation + value = _dict_to_bson(doc, False, opts) + doc_length = len(value) + new_message_size = buf.tell() + doc_length + # Does first document exceed max_message_size? + doc_too_large = idx == 0 and (new_message_size > max_message_size) + # When OP_MSG is used unacknowledged we have to check + # document size client side or applications won't be notified. 
+ # Otherwise we let the server deal with documents that are too large + # since ordered=False causes those documents to be skipped instead of + # halting the bulk write operation. + unacked_doc_too_large = not ack and (doc_length > max_bson_size) + if doc_too_large or unacked_doc_too_large: + write_op = list(_FIELD_MAP.keys())[operation] + _raise_document_too_large(write_op, len(value), max_bson_size) + # We have enough data, return this batch. + if new_message_size > max_message_size: + break + buf.write(value) + to_send.append(doc) + idx += 1 + # We have enough documents, return this batch. + if idx == max_write_batch_size: + break + + # Write type 1 section size + length = buf.tell() + buf.seek(size_location) + buf.write(_pack_int(length - size_location)) + + return to_send, length + + +def _encode_batched_op_msg( + operation: int, + command: Mapping[str, Any], + docs: list[Mapping[str, Any]], + ack: bool, + opts: CodecOptions, + ctx: _BulkWriteContext, +) -> tuple[bytes, list[Mapping[str, Any]]]: + """Encode the next batched insert, update, or delete operation + as OP_MSG. + """ + buf = _BytesIO() + + to_send, _ = _batched_op_msg_impl(operation, command, docs, ack, opts, ctx, buf) + return buf.getvalue(), to_send + + +if _use_c: + _encode_batched_op_msg = _cmessage._encode_batched_op_msg + + +def _batched_op_msg_compressed( + operation: int, + command: Mapping[str, Any], + docs: list[Mapping[str, Any]], + ack: bool, + opts: CodecOptions, + ctx: _BulkWriteContext, +) -> tuple[int, bytes, list[Mapping[str, Any]]]: + """Create the next batched insert, update, or delete operation + with OP_MSG, compressed. + """ + data, to_send = _encode_batched_op_msg(operation, command, docs, ack, opts, ctx) + + assert ctx.conn.compression_context is not None + request_id, msg = _compress(2013, data, ctx.conn.compression_context) + return request_id, msg, to_send + + +def _batched_op_msg( + operation: int, + command: Mapping[str, Any], + docs: list[Mapping[str, Any]], + ack: bool, + opts: CodecOptions, + ctx: _BulkWriteContext, +) -> tuple[int, bytes, list[Mapping[str, Any]]]: + """OP_MSG implementation entry point.""" + buf = _BytesIO() + + # Save space for message length and request id + buf.write(_ZERO_64) + # responseTo, opCode + buf.write(b"\x00\x00\x00\x00\xdd\x07\x00\x00") + + to_send, length = _batched_op_msg_impl(operation, command, docs, ack, opts, ctx, buf) + + # Header - request id and message length + buf.seek(4) + request_id = _randint() + buf.write(_pack_int(request_id)) + buf.seek(0) + buf.write(_pack_int(length)) + + return request_id, buf.getvalue(), to_send + + +if _use_c: + _batched_op_msg = _cmessage._batched_op_msg + + +def _do_batched_op_msg( + namespace: str, + operation: int, + command: MutableMapping[str, Any], + docs: list[Mapping[str, Any]], + opts: CodecOptions, + ctx: _BulkWriteContext, +) -> tuple[int, bytes, list[Mapping[str, Any]]]: + """Create the next batched insert, update, or delete operation + using OP_MSG. 
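+ +    An unacknowledged write concern (e.g. {"w": 0}) makes ack False below, +    which _batched_op_msg_impl encodes as the OP_MSG moreToCome flag (the +    0x02 flags value).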
+ """ + command["$db"] = namespace.split(".", 1)[0] + if "writeConcern" in command: + ack = bool(command["writeConcern"].get("w", 1)) + else: + ack = True + if ctx.conn.compression_context: + return _batched_op_msg_compressed(operation, command, docs, ack, opts, ctx) + return _batched_op_msg(operation, command, docs, ack, opts, ctx) + + +# End OP_MSG ----------------------------------------------------- + + +def _encode_batched_write_command( + namespace: str, + operation: int, + command: MutableMapping[str, Any], + docs: list[Mapping[str, Any]], + opts: CodecOptions, + ctx: _BulkWriteContext, +) -> tuple[bytes, list[Mapping[str, Any]]]: + """Encode the next batched insert, update, or delete command.""" + buf = _BytesIO() + + to_send, _ = _batched_write_command_impl(namespace, operation, command, docs, opts, ctx, buf) + return buf.getvalue(), to_send + + +if _use_c: + _encode_batched_write_command = _cmessage._encode_batched_write_command + + +def _batched_write_command_impl( + namespace: str, + operation: int, + command: MutableMapping[str, Any], + docs: list[Mapping[str, Any]], + opts: CodecOptions, + ctx: _BulkWriteContext, + buf: _BytesIO, +) -> tuple[list[Mapping[str, Any]], int]: + """Create a batched OP_QUERY write command.""" + max_bson_size = ctx.max_bson_size + max_write_batch_size = ctx.max_write_batch_size + # Max BSON object size + 16k - 2 bytes for ending NUL bytes. + # Server guarantees there is enough room: SERVER-10643. + max_cmd_size = max_bson_size + _COMMAND_OVERHEAD + max_split_size = ctx.max_split_size + + # No options + buf.write(_ZERO_32) + # Namespace as C string + buf.write(namespace.encode("utf8")) + buf.write(_ZERO_8) + # Skip: 0, Limit: -1 + buf.write(_SKIPLIM) + + # Where to write command document length + command_start = buf.tell() + buf.write(encode(command)) + + # Start of payload + buf.seek(-1, 2) + # Work around some Jython weirdness. + buf.truncate() + try: + buf.write(_OP_MAP[operation]) + except KeyError: + raise InvalidOperation("Unknown command") from None + + # Where to write list document length + list_start = buf.tell() - 4 + to_send = [] + idx = 0 + for doc in docs: + # Encode the current operation + key = str(idx).encode("utf8") + value = _dict_to_bson(doc, False, opts) + # Is there enough room to add this document? max_cmd_size accounts for + # the two trailing null bytes. + doc_too_large = len(value) > max_cmd_size + if doc_too_large: + write_op = list(_FIELD_MAP.keys())[operation] + _raise_document_too_large(write_op, len(value), max_bson_size) + enough_data = idx >= 1 and (buf.tell() + len(key) + len(value)) >= max_split_size + enough_documents = idx >= max_write_batch_size + if enough_data or enough_documents: + break + buf.write(_BSONOBJ) + buf.write(key) + buf.write(_ZERO_8) + buf.write(value) + to_send.append(doc) + idx += 1 + + # Finalize the current OP_QUERY message. + # Close list and command documents + buf.write(_ZERO_16) + + # Write document lengths and request id + length = buf.tell() + buf.seek(list_start) + buf.write(_pack_int(length - list_start - 1)) + buf.seek(command_start) + buf.write(_pack_int(length - command_start)) + + return to_send, length + + +class _OpReply: + """A MongoDB OP_REPLY response message.""" + + __slots__ = ("flags", "cursor_id", "number_returned", "documents") + + UNPACK_FROM = struct.Struct(" list[bytes]: + """Check the response header from the database, without decoding BSON. + + Check the response for errors and unpack. 
+ + Can raise CursorNotFound, NotPrimaryError, ExecutionTimeout, or + OperationFailure. + + :param cursor_id: cursor_id we sent to get this response - + used for raising an informative exception when we get cursor id not + valid at server response. + """ + if self.flags & 1: + # Shouldn't get this response if we aren't doing a getMore + if cursor_id is None: + raise ProtocolError("No cursor id for getMore operation") + + # Fake a getMore command response. OP_GET_MORE provides no + # document. + msg = "Cursor not found, cursor id: %d" % (cursor_id,) + errobj = {"ok": 0, "errmsg": msg, "code": 43} + raise CursorNotFound(msg, 43, errobj) + elif self.flags & 2: + error_object: dict = bson.BSON(self.documents).decode() + # Fake the ok field if it doesn't exist. + error_object.setdefault("ok", 0) + if error_object["$err"].startswith(HelloCompat.LEGACY_ERROR): + raise NotPrimaryError(error_object["$err"], error_object) + elif error_object.get("code") == 50: + default_msg = "operation exceeded time limit" + raise ExecutionTimeout( + error_object.get("$err", default_msg), error_object.get("code"), error_object + ) + raise OperationFailure( + "database error: %s" % error_object.get("$err"), + error_object.get("code"), + error_object, + ) + if self.documents: + return [self.documents] + return [] + + def unpack_response( + self, + cursor_id: Optional[int] = None, + codec_options: CodecOptions = _UNICODE_REPLACE_CODEC_OPTIONS, + user_fields: Optional[Mapping[str, Any]] = None, + legacy_response: bool = False, + ) -> list[dict[str, Any]]: + """Unpack a response from the database and decode the BSON document(s). + + Check the response for errors and unpack, returning a dictionary + containing the response data. + + Can raise CursorNotFound, NotPrimaryError, ExecutionTimeout, or + OperationFailure. + + :param cursor_id: cursor_id we sent to get this response - + used for raising an informative exception when we get cursor id not + valid at server response + :param codec_options: an instance of + :class:`~bson.codec_options.CodecOptions` + :param user_fields: Response fields that should be decoded + using the TypeDecoders from codec_options, passed to + bson._decode_all_selective. + """ + self.raw_response(cursor_id) + if legacy_response: + return bson.decode_all(self.documents, codec_options) + return bson._decode_all_selective(self.documents, codec_options, user_fields) + + def command_response(self, codec_options: CodecOptions) -> dict[str, Any]: + """Unpack a command response.""" + docs = self.unpack_response(codec_options=codec_options) + assert self.number_returned == 1 + return docs[0] + + def raw_command_response(self) -> NoReturn: + """Return the bytes of the command response.""" + # This should never be called on _OpReply. + raise NotImplementedError + + @property + def more_to_come(self) -> bool: + """Is the moreToCome bit set on this response?""" + return False + + @classmethod + def unpack(cls, msg: bytes) -> _OpReply: + """Construct an _OpReply from raw bytes.""" + # PYTHON-945: ignore starting_from field. 
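+        # The OP_REPLY reply header unpacked here is responseFlags (int32), +        # cursorID (int64), startingFrom (int32) and numberReturned (int32): +        # 20 bytes in all, which is why the documents begin at msg[20:].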
+        flags, cursor_id, _, number_returned = cls.UNPACK_FROM(msg)
+
+        documents = msg[20:]
+        return cls(flags, cursor_id, number_returned, documents)
+
+
+class _OpMsg:
+    """A MongoDB OP_MSG response message."""
+
+    __slots__ = ("flags", "cursor_id", "number_returned", "payload_document")
+
+    UNPACK_FROM = struct.Struct("<IBi")
+    OP_CODE = 2013
+
+    # Flag bits.
+    CHECKSUM_PRESENT = 1
+    MORE_TO_COME = 1 << 1
+    EXHAUST_ALLOWED = 1 << 16  # Only present on requests.
+
+    def __init__(self, flags: int, payload_document: bytes):
+        self.flags = flags
+        self.payload_document = payload_document
+
+    def raw_response(
+        self,
+        cursor_id: Optional[int] = None,
+        user_fields: Optional[Mapping[str, Any]] = {},  # noqa: B006
+    ) -> list[Mapping[str, Any]]:
+        """
+        cursor_id is ignored
+        user_fields is used to determine which fields must not be decoded
+        """
+        inflated_response = _decode_selective(
+            RawBSONDocument(self.payload_document), user_fields, _RAW_ARRAY_BSON_OPTIONS
+        )
+        return [inflated_response]
+
+    def unpack_response(
+        self,
+        cursor_id: Optional[int] = None,
+        codec_options: CodecOptions = _UNICODE_REPLACE_CODEC_OPTIONS,
+        user_fields: Optional[Mapping[str, Any]] = None,
+        legacy_response: bool = False,
+    ) -> list[dict[str, Any]]:
+        """Unpack an OP_MSG command response.
+
+        :param cursor_id: Ignored, for compatibility with _OpReply.
+        :param codec_options: an instance of
+            :class:`~bson.codec_options.CodecOptions`
+        :param user_fields: Response fields that should be decoded
+            using the TypeDecoders from codec_options, passed to
+            bson._decode_all_selective.
+        """
+        # If _OpMsg is in-use, this cannot be a legacy response.
+        assert not legacy_response
+        return bson._decode_all_selective(self.payload_document, codec_options, user_fields)
+
+    def command_response(self, codec_options: CodecOptions) -> dict[str, Any]:
+        """Unpack a command response."""
+        return self.unpack_response(codec_options=codec_options)[0]
+
+    def raw_command_response(self) -> bytes:
+        """Return the bytes of the command response."""
+        return self.payload_document
+
+    @property
+    def more_to_come(self) -> bool:
+        """Is the moreToCome bit set on this response?"""
+        return bool(self.flags & self.MORE_TO_COME)
+
+    @classmethod
+    def unpack(cls, msg: bytes) -> _OpMsg:
+        """Construct an _OpMsg from raw bytes."""
+        flags, first_payload_type, first_payload_size = cls.UNPACK_FROM(msg)
+        if flags != 0:
+            if flags & cls.CHECKSUM_PRESENT:
+                raise ProtocolError(f"Unsupported OP_MSG flag checksumPresent: 0x{flags:x}")
+
+            if flags ^ cls.MORE_TO_COME:
+                raise ProtocolError(f"Unsupported OP_MSG flags: 0x{flags:x}")
+        if first_payload_type != 0:
+            raise ProtocolError(f"Unsupported OP_MSG payload type: 0x{first_payload_type:x}")
+
+        if len(msg) != first_payload_size + 5:
+            raise ProtocolError("Unsupported OP_MSG reply: >1 section")
+
+        payload_document = msg[5:]
+        return cls(flags, payload_document)
+
+
+_UNPACK_REPLY: dict[int, Callable[[bytes], Union[_OpReply, _OpMsg]]] = {
+    _OpReply.OP_CODE: _OpReply.unpack,
+    _OpMsg.OP_CODE: _OpMsg.unpack,
+}
diff --git a/venv/Lib/site-packages/pymongo/mongo_client.py b/venv/Lib/site-packages/pymongo/mongo_client.py
new file mode 100644
index 00000000..f2076b08
--- /dev/null
+++ b/venv/Lib/site-packages/pymongo/mongo_client.py
@@ -0,0 +1,2529 @@
+# Copyright 2009-present MongoDB, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you
+# may not use this file except in compliance with the License. You
+# may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied. See the License for the specific language governing
+# permissions and limitations under the License.
+
+"""Tools for connecting to MongoDB.
+
+..
seealso:: :doc:`/examples/high_availability` for examples of connecting + to replica sets or sets of mongos servers. + +To get a :class:`~pymongo.database.Database` instance from a +:class:`MongoClient` use either dictionary-style or attribute-style +access: + +.. doctest:: + + >>> from pymongo import MongoClient + >>> c = MongoClient() + >>> c.test_database + Database(MongoClient(host=['localhost:27017'], document_class=dict, tz_aware=False, connect=True), 'test_database') + >>> c["test-database"] + Database(MongoClient(host=['localhost:27017'], document_class=dict, tz_aware=False, connect=True), 'test-database') +""" +from __future__ import annotations + +import contextlib +import os +import weakref +from collections import defaultdict +from typing import ( + TYPE_CHECKING, + Any, + Callable, + ContextManager, + FrozenSet, + Generic, + Iterator, + Mapping, + MutableMapping, + NoReturn, + Optional, + Sequence, + Type, + TypeVar, + Union, + cast, +) + +from bson.codec_options import DEFAULT_CODEC_OPTIONS, CodecOptions, TypeRegistry +from bson.timestamp import Timestamp +from pymongo import ( + _csot, + client_session, + common, + database, + helpers, + message, + periodic_executor, + uri_parser, +) +from pymongo.change_stream import ChangeStream, ClusterChangeStream +from pymongo.client_options import ClientOptions +from pymongo.client_session import _EmptyServerSession +from pymongo.command_cursor import CommandCursor +from pymongo.errors import ( + AutoReconnect, + BulkWriteError, + ConfigurationError, + ConnectionFailure, + InvalidOperation, + NotPrimaryError, + OperationFailure, + PyMongoError, + ServerSelectionTimeoutError, + WaitQueueTimeoutError, + WriteConcernError, +) +from pymongo.lock import _HAS_REGISTER_AT_FORK, _create_lock, _release_locks +from pymongo.logger import _CLIENT_LOGGER, _log_or_warn +from pymongo.monitoring import ConnectionClosedReason +from pymongo.operations import _Op +from pymongo.read_preferences import ReadPreference, _ServerMode +from pymongo.server_selectors import writable_server_selector +from pymongo.server_type import SERVER_TYPE +from pymongo.settings import TopologySettings +from pymongo.topology import Topology, _ErrorContext +from pymongo.topology_description import TOPOLOGY_TYPE, TopologyDescription +from pymongo.typings import ( + ClusterTime, + _Address, + _CollationIn, + _DocumentType, + _DocumentTypeArg, + _Pipeline, +) +from pymongo.uri_parser import ( + _check_options, + _handle_option_deprecations, + _handle_security_options, + _normalize_options, +) +from pymongo.write_concern import DEFAULT_WRITE_CONCERN, WriteConcern + +if TYPE_CHECKING: + import sys + from types import TracebackType + + from bson.objectid import ObjectId + from pymongo.bulk import _Bulk + from pymongo.client_session import ClientSession, _ServerSession + from pymongo.cursor import _ConnectionManager + from pymongo.database import Database + from pymongo.message import _CursorAddress, _GetMore, _Query + from pymongo.pool import Connection + from pymongo.read_concern import ReadConcern + from pymongo.response import Response + from pymongo.server import Server + from pymongo.server_selectors import Selection + + if sys.version_info[:2] >= (3, 9): + from collections.abc import Generator + else: + # Deprecated since version 3.9: collections.abc.Generator now supports []. 
+        from typing import Generator
+
+T = TypeVar("T")
+
+_WriteCall = Callable[[Optional["ClientSession"], "Connection", bool], T]
+_ReadCall = Callable[[Optional["ClientSession"], "Server", "Connection", _ServerMode], T]
+
+
+class MongoClient(common.BaseObject, Generic[_DocumentType]):
+    """
+    A client-side representation of a MongoDB cluster.
+
+    Instances can represent either a standalone MongoDB server, a replica
+    set, or a sharded cluster. Instances of this class are responsible for
+    maintaining up-to-date state of the cluster, and they may cache
+    resources related to it, including background threads for monitoring
+    and connection pools.
+    """
+
+    HOST = "localhost"
+    PORT = 27017
+    # Define order to retrieve options from ClientOptions for __repr__.
+    # No host/port; these are retrieved from TopologySettings.
+    _constructor_args = ("document_class", "tz_aware", "connect")
+    _clients: weakref.WeakValueDictionary = weakref.WeakValueDictionary()
+
+    def __init__(
+        self,
+        host: Optional[Union[str, Sequence[str]]] = None,
+        port: Optional[int] = None,
+        document_class: Optional[Type[_DocumentType]] = None,
+        tz_aware: Optional[bool] = None,
+        connect: Optional[bool] = None,
+        type_registry: Optional[TypeRegistry] = None,
+        **kwargs: Any,
+    ) -> None:
+        """Client for a MongoDB instance, a replica set, or a set of mongoses.
+
+        .. warning:: Starting in PyMongo 4.0, ``directConnection`` now has a default value of
+          False instead of None.
+          For more details, see the relevant section of the PyMongo 4.x migration guide:
+          :ref:`pymongo4-migration-direct-connection`.
+
+        The client object is thread-safe and has connection-pooling built in.
+        If an operation fails because of a network error,
+        :class:`~pymongo.errors.ConnectionFailure` is raised and the client
+        reconnects in the background. Application code should handle this
+        exception (recognizing that the operation failed) and then continue to
+        execute.
+
+        The `host` parameter can be a full `mongodb URI
+        <http://dochub.mongodb.org/core/connections>`_, in addition to
+        a simple hostname. It can also be a list of hostnames but no more
+        than one URI. Any port specified in the host string(s) will override
+        the `port` parameter. For usernames and passwords, reserved
+        characters like ':', '/', '+' and '@' must be
+        percent encoded following RFC 2396::
+
+            from urllib.parse import quote_plus
+
+            uri = "mongodb://%s:%s@%s" % (
+                quote_plus(user), quote_plus(password), host)
+            client = MongoClient(uri)
+
+        Unix domain sockets are also supported. The socket path must be percent
+        encoded in the URI::
+
+            uri = "mongodb://%s:%s@%s" % (
+                quote_plus(user), quote_plus(password), quote_plus(socket_path))
+            client = MongoClient(uri)
+
+        But not when passed as a simple hostname::
+
+            client = MongoClient('/tmp/mongodb-27017.sock')
+
+        Starting with version 3.6, PyMongo supports mongodb+srv:// URIs. The
+        URI must include one, and only one, hostname. The hostname will be
+        resolved to one or more DNS `SRV records
+        <https://en.wikipedia.org/wiki/SRV_record>`_ which will be used
+        as the seed list for connecting to the MongoDB deployment. When using
+        SRV URIs, the `authSource` and `replicaSet` configuration options can
+        be specified using `TXT records
+        <https://en.wikipedia.org/wiki/TXT_record>`_. See the
+        `Initial DNS Seedlist Discovery spec
+        <https://github.com/mongodb/specifications/blob/master/source/initial-dns-seedlist-discovery/initial-dns-seedlist-discovery.rst>`_
+        for more details. Note that the use of SRV URIs implicitly enables
+        TLS support. Pass tls=false in the URI to override.
+
+        .. note:: MongoClient creation will block waiting for answers from
+          DNS when mongodb+srv:// URIs are used.
+
+        ..
note:: Starting with version 3.0 the :class:`MongoClient` + constructor no longer blocks while connecting to the server or + servers, and it no longer raises + :class:`~pymongo.errors.ConnectionFailure` if they are + unavailable, nor :class:`~pymongo.errors.ConfigurationError` + if the user's credentials are wrong. Instead, the constructor + returns immediately and launches the connection process on + background threads. You can check if the server is available + like this:: + + from pymongo.errors import ConnectionFailure + client = MongoClient() + try: + # The ping command is cheap and does not require auth. + client.admin.command('ping') + except ConnectionFailure: + print("Server not available") + + .. warning:: When using PyMongo in a multiprocessing context, please + read :ref:`multiprocessing` first. + + .. note:: Many of the following options can be passed using a MongoDB + URI or keyword parameters. If the same option is passed in a URI and + as a keyword parameter the keyword parameter takes precedence. + + :param host: hostname or IP address or Unix domain socket + path of a single mongod or mongos instance to connect to, or a + mongodb URI, or a list of hostnames (but no more than one mongodb + URI). If `host` is an IPv6 literal it must be enclosed in '[' + and ']' characters + following the RFC2732 URL syntax (e.g. '[::1]' for localhost). + Multihomed and round robin DNS addresses are **not** supported. + :param port: port number on which to connect + :param document_class: default class to use for + documents returned from queries on this client + :param tz_aware: if ``True``, + :class:`~datetime.datetime` instances returned as values + in a document by this :class:`MongoClient` will be timezone + aware (otherwise they will be naive) + :param connect: if ``True`` (the default), immediately + begin connecting to MongoDB in the background. Otherwise connect + on the first operation. + :param type_registry: instance of + :class:`~bson.codec_options.TypeRegistry` to enable encoding + and decoding of custom types. + :param datetime_conversion: Specifies how UTC datetimes should be decoded + within BSON. Valid options include 'datetime_ms' to return as a + DatetimeMS, 'datetime' to return as a datetime.datetime and + raising a ValueError for out-of-range values, 'datetime_auto' to + return DatetimeMS objects when the underlying datetime is + out-of-range and 'datetime_clamp' to clamp to the minimum and + maximum possible datetimes. Defaults to 'datetime'. See + :ref:`handling-out-of-range-datetimes` for details. + + | **Other optional parameters can be passed as keyword arguments:** + + - `directConnection` (optional): if ``True``, forces this client to + connect directly to the specified MongoDB host as a standalone. + If ``false``, the client connects to the entire replica set of + which the given MongoDB host(s) is a part. If this is ``True`` + and a mongodb+srv:// URI or a URI containing multiple seeds is + provided, an exception will be raised. + - `maxPoolSize` (optional): The maximum allowable number of + concurrent connections to each connected server. Requests to a + server will block if there are `maxPoolSize` outstanding + connections to the requested server. Defaults to 100. Can be + either 0 or None, in which case there is no limit on the number + of concurrent connections. + - `minPoolSize` (optional): The minimum required number of concurrent + connections that the pool will maintain to each connected server. + Default is 0. 
+ - `maxIdleTimeMS` (optional): The maximum number of milliseconds that + a connection can remain idle in the pool before being removed and + replaced. Defaults to `None` (no limit). + - `maxConnecting` (optional): The maximum number of connections that + each pool can establish concurrently. Defaults to `2`. + - `timeoutMS`: (integer or None) Controls how long (in + milliseconds) the driver will wait when executing an operation + (including retry attempts) before raising a timeout error. + ``0`` or ``None`` means no timeout. + - `socketTimeoutMS`: (integer or None) Controls how long (in + milliseconds) the driver will wait for a response after sending an + ordinary (non-monitoring) database operation before concluding that + a network error has occurred. ``0`` or ``None`` means no timeout. + Defaults to ``None`` (no timeout). + - `connectTimeoutMS`: (integer or None) Controls how long (in + milliseconds) the driver will wait during server monitoring when + connecting a new socket to a server before concluding the server + is unavailable. ``0`` or ``None`` means no timeout. + Defaults to ``20000`` (20 seconds). + - `server_selector`: (callable or None) Optional, user-provided + function that augments server selection rules. The function should + accept as an argument a list of + :class:`~pymongo.server_description.ServerDescription` objects and + return a list of server descriptions that should be considered + suitable for the desired operation. + - `serverSelectionTimeoutMS`: (integer) Controls how long (in + milliseconds) the driver will wait to find an available, + appropriate server to carry out a database operation; while it is + waiting, multiple server monitoring operations may be carried out, + each controlled by `connectTimeoutMS`. Defaults to ``30000`` (30 + seconds). + - `waitQueueTimeoutMS`: (integer or None) How long (in milliseconds) + a thread will wait for a socket from the pool if the pool has no + free sockets. Defaults to ``None`` (no timeout). + - `heartbeatFrequencyMS`: (optional) The number of milliseconds + between periodic server checks, or None to accept the default + frequency of 10 seconds. + - `serverMonitoringMode`: (optional) The server monitoring mode to use. + Valid values are the strings: "auto", "stream", "poll". Defaults to "auto". + - `appname`: (string or None) The name of the application that + created this MongoClient instance. The server will log this value + upon establishing each connection. It is also recorded in the slow + query log and profile collections. + - `driver`: (pair or None) A driver implemented on top of PyMongo can + pass a :class:`~pymongo.driver_info.DriverInfo` to add its name, + version, and platform to the message printed in the server log when + establishing a connection. + - `event_listeners`: a list or tuple of event listeners. See + :mod:`~pymongo.monitoring` for details. + - `retryWrites`: (boolean) Whether supported write operations + executed within this MongoClient will be retried once after a + network error. Defaults to ``True``. + The supported write operations are: + + - :meth:`~pymongo.collection.Collection.bulk_write`, as long as + :class:`~pymongo.operations.UpdateMany` or + :class:`~pymongo.operations.DeleteMany` are not included. 
+          - :meth:`~pymongo.collection.Collection.delete_one`
+          - :meth:`~pymongo.collection.Collection.insert_one`
+          - :meth:`~pymongo.collection.Collection.insert_many`
+          - :meth:`~pymongo.collection.Collection.replace_one`
+          - :meth:`~pymongo.collection.Collection.update_one`
+          - :meth:`~pymongo.collection.Collection.find_one_and_delete`
+          - :meth:`~pymongo.collection.Collection.find_one_and_replace`
+          - :meth:`~pymongo.collection.Collection.find_one_and_update`
+
+          Unsupported write operations include, but are not limited to,
+          :meth:`~pymongo.collection.Collection.aggregate` using the ``$out``
+          pipeline operator and any operation with an unacknowledged write
+          concern (e.g. {w: 0}). See
+          https://github.com/mongodb/specifications/blob/master/source/retryable-writes/retryable-writes.rst
+        - `retryReads`: (boolean) Whether supported read operations
+          executed within this MongoClient will be retried once after a
+          network error. Defaults to ``True``.
+          The supported read operations are:
+          :meth:`~pymongo.collection.Collection.find`,
+          :meth:`~pymongo.collection.Collection.find_one`,
+          :meth:`~pymongo.collection.Collection.aggregate` without ``$out``,
+          :meth:`~pymongo.collection.Collection.distinct`,
+          :meth:`~pymongo.collection.Collection.count`,
+          :meth:`~pymongo.collection.Collection.estimated_document_count`,
+          :meth:`~pymongo.collection.Collection.count_documents`,
+          :meth:`pymongo.collection.Collection.watch`,
+          :meth:`~pymongo.collection.Collection.list_indexes`,
+          :meth:`pymongo.database.Database.watch`,
+          :meth:`~pymongo.database.Database.list_collections`,
+          :meth:`pymongo.mongo_client.MongoClient.watch`,
+          and :meth:`~pymongo.mongo_client.MongoClient.list_databases`.
+
+          Unsupported read operations include, but are not limited to
+          :meth:`~pymongo.database.Database.command` and any getMore
+          operation on a cursor.
+
+          Enabling retryable reads makes applications more resilient to
+          transient errors such as network failures, database upgrades, and
+          replica set failovers. For an exact definition of which errors
+          trigger a retry, see the `retryable reads specification
+          <https://github.com/mongodb/specifications/blob/master/source/retryable-reads/retryable-reads.rst>`_.
+
+        - `compressors`: Comma separated list of compressors for wire
+          protocol compression. The list is used to negotiate a compressor
+          with the server. Currently supported options are "snappy", "zlib"
+          and "zstd". Support for snappy requires the
+          `python-snappy <https://pypi.org/project/python-snappy/>`_ package.
+          zlib support requires the Python standard library zlib module. zstd
+          requires the `zstandard <https://pypi.org/project/zstandard/>`_
+          package. By default no compression is used. Compression support
+          must also be enabled on the server. MongoDB 3.6+ supports snappy
+          and zlib compression. MongoDB 4.2+ adds support for zstd.
+          See :ref:`network-compression-example` for details.
+        - `zlibCompressionLevel`: (int) The zlib compression level to use
+          when zlib is used as the wire protocol compressor. Supported values
+          are -1 through 9. -1 tells the zlib library to use its default
+          compression level (usually 6). 0 means no compression. 1 is best
+          speed. 9 is best compression. Defaults to -1.
+        - `uuidRepresentation`: The BSON representation to use when encoding
+          from and decoding to instances of :class:`~uuid.UUID`. Valid
+          values are the strings: "standard", "pythonLegacy", "javaLegacy",
+          "csharpLegacy", and "unspecified" (the default). New applications
+          should consider setting this to "standard" for cross language
+          compatibility. See :ref:`handling-uuid-data-example` for details.
+ - `unicode_decode_error_handler`: The error handler to apply when + a Unicode-related error occurs during BSON decoding that would + otherwise raise :exc:`UnicodeDecodeError`. Valid options include + 'strict', 'replace', 'backslashreplace', 'surrogateescape', and + 'ignore'. Defaults to 'strict'. + - `srvServiceName`: (string) The SRV service name to use for + "mongodb+srv://" URIs. Defaults to "mongodb". Use it like so:: + + MongoClient("mongodb+srv://example.com/?srvServiceName=customname") + - `srvMaxHosts`: (int) limits the number of mongos-like hosts a client will + connect to. More specifically, when a "mongodb+srv://" connection string + resolves to more than srvMaxHosts number of hosts, the client will randomly + choose an srvMaxHosts sized subset of hosts. + + + | **Write Concern options:** + | (Only set if passed. No default values.) + + - `w`: (integer or string) If this is a replica set, write operations + will block until they have been replicated to the specified number + or tagged set of servers. `w=` always includes the replica set + primary (e.g. w=3 means write to the primary and wait until + replicated to **two** secondaries). Passing w=0 **disables write + acknowledgement** and all other write concern options. + - `wTimeoutMS`: **DEPRECATED** (integer) Used in conjunction with `w`. + Specify a value in milliseconds to control how long to wait for write propagation + to complete. If replication does not complete in the given + timeframe, a timeout exception is raised. Passing wTimeoutMS=0 + will cause **write operations to wait indefinitely**. + - `journal`: If ``True`` block until write operations have been + committed to the journal. Cannot be used in combination with + `fsync`. Write operations will fail with an exception if this + option is used when the server is running without journaling. + - `fsync`: If ``True`` and the server is running without journaling, + blocks until the server has synced all data files to disk. If the + server is running with journaling, this acts the same as the `j` + option, blocking until write operations have been committed to the + journal. Cannot be used in combination with `j`. + + | **Replica set keyword arguments for connecting with a replica set + - either directly or via a mongos:** + + - `replicaSet`: (string or None) The name of the replica set to + connect to. The driver will verify that all servers it connects to + match this name. Implies that the hosts specified are a seed list + and the driver should attempt to find all members of the set. + Defaults to ``None``. + + | **Read Preference:** + + - `readPreference`: The replica set read preference for this client. + One of ``primary``, ``primaryPreferred``, ``secondary``, + ``secondaryPreferred``, or ``nearest``. Defaults to ``primary``. + - `readPreferenceTags`: Specifies a tag set as a comma-separated list + of colon-separated key-value pairs. For example ``dc:ny,rack:1``. + Defaults to ``None``. + - `maxStalenessSeconds`: (integer) The maximum estimated + length of time a replica set secondary can fall behind the primary + in replication before it will no longer be selected for operations. + Defaults to ``-1``, meaning no maximum. If maxStalenessSeconds + is set, it must be a positive integer greater than or equal to + 90 seconds. + + .. seealso:: :doc:`/examples/server_selection` + + | **Authentication:** + + - `username`: A string. + - `password`: A string. 
+
+            Although username and password must be percent-escaped in a MongoDB
+            URI, they must not be percent-escaped when passed as parameters. In
+            this example, both the space and slash special characters are passed
+            as-is::
+
+              MongoClient(username="user name", password="pass/word")
+
+        - `authSource`: The database to authenticate on. Defaults to the
+          database specified in the URI, if provided, or to "admin".
+        - `authMechanism`: See :data:`~pymongo.auth.MECHANISMS` for options.
+          If no mechanism is specified, PyMongo automatically uses SCRAM-SHA-1
+          when connected to MongoDB 3.6 and negotiates the mechanism to use
+          (SCRAM-SHA-1 or SCRAM-SHA-256) when connected to MongoDB 4.0+.
+        - `authMechanismProperties`: Used to specify authentication mechanism
+          specific options. To specify the service name for GSSAPI
+          authentication pass authMechanismProperties='SERVICE_NAME:<service name>'.
+          To specify the session token for MONGODB-AWS authentication pass
+          ``authMechanismProperties='AWS_SESSION_TOKEN:<session token>'``.
+
+        .. seealso:: :doc:`/examples/authentication`
+
+        | **TLS/SSL configuration:**
+
+        - `tls`: (boolean) If ``True``, create the connection to the server
+          using transport layer security. Defaults to ``False``.
+        - `tlsInsecure`: (boolean) Specify whether TLS constraints should be
+          relaxed as much as possible. Setting ``tlsInsecure=True`` implies
+          ``tlsAllowInvalidCertificates=True`` and
+          ``tlsAllowInvalidHostnames=True``. Defaults to ``False``. Think
+          very carefully before setting this to ``True`` as it dramatically
+          reduces the security of TLS.
+        - `tlsAllowInvalidCertificates`: (boolean) If ``True``, continues
+          the TLS handshake regardless of the outcome of the certificate
+          verification process. If this is ``False``, and a value is not
+          provided for ``tlsCAFile``, PyMongo will attempt to load system
+          provided CA certificates. If the python version in use does not
+          support loading system CA certificates then the ``tlsCAFile``
+          parameter must point to a file of CA certificates.
+          ``tlsAllowInvalidCertificates=False`` implies ``tls=True``.
+          Defaults to ``False``. Think very carefully before setting this
+          to ``True`` as that could make your application vulnerable to
+          on-path attackers.
+        - `tlsAllowInvalidHostnames`: (boolean) If ``True``, disables TLS
+          hostname verification. ``tlsAllowInvalidHostnames=False`` implies
+          ``tls=True``. Defaults to ``False``. Think very carefully before
+          setting this to ``True`` as that could make your application
+          vulnerable to on-path attackers.
+        - `tlsCAFile`: A file containing a single or a bundle of
+          "certification authority" certificates, which are used to validate
+          certificates passed from the other end of the connection.
+          Implies ``tls=True``. Defaults to ``None``.
+        - `tlsCertificateKeyFile`: A file containing the client certificate
+          and private key. Implies ``tls=True``. Defaults to ``None``.
+        - `tlsCRLFile`: A file containing a PEM or DER formatted
+          certificate revocation list. Implies ``tls=True``. Defaults to
+          ``None``.
+        - `tlsCertificateKeyFilePassword`: The password or passphrase for
+          decrypting the private key in ``tlsCertificateKeyFile``. Only
+          necessary if the private key is encrypted. Defaults to ``None``.
+        - `tlsDisableOCSPEndpointCheck`: (boolean) If ``True``, disables
+          certificate revocation status checking via the OCSP responder
+          specified on the server certificate.
+          ``tlsDisableOCSPEndpointCheck=False`` implies ``tls=True``.
+          Defaults to ``False``.
+        - `ssl`: (boolean) Alias for ``tls``.
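+
+        For example, a client that enables TLS and validates the server
+        certificate against a custom CA bundle might be built like this
+        (the URI and file path are illustrative)::
+
+            client = MongoClient(
+                "mongodb://db.example.com:27017/?replicaSet=rs0",
+                tls=True,
+                tlsCAFile="/etc/ssl/mongodb-ca.pem",
+            )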
+
+        | **Read Concern options:**
+        | (If not set explicitly, this will use the server default)
+
+        - `readConcernLevel`: (string) The read concern level specifies the
+          level of isolation for read operations. For example, a read
+          operation using a read concern level of ``majority`` will only
+          return data that has been written to a majority of nodes. If the
+          level is left unspecified, the server default will be used.
+
+        | **Client side encryption options:**
+        | (If not set explicitly, client side encryption will not be enabled.)
+
+        - `auto_encryption_opts`: A
+          :class:`~pymongo.encryption_options.AutoEncryptionOpts` which
+          configures this client to automatically encrypt collection commands
+          and automatically decrypt results. See
+          :ref:`automatic-client-side-encryption` for an example.
+          If a :class:`MongoClient` is configured with
+          ``auto_encryption_opts`` and a non-None ``maxPoolSize``, a
+          separate internal ``MongoClient`` is created if any of the
+          following are true:
+
+          - A ``key_vault_client`` is not passed to
+            :class:`~pymongo.encryption_options.AutoEncryptionOpts`
+          - ``bypass_auto_encryption=False`` is passed to
+            :class:`~pymongo.encryption_options.AutoEncryptionOpts`
+
+        | **Stable API options:**
+        | (If not set explicitly, Stable API will not be enabled.)
+
+        - `server_api`: A
+          :class:`~pymongo.server_api.ServerApi` which configures this
+          client to use Stable API. See :ref:`versioned-api-ref` for
+          details.
+
+        .. seealso:: The MongoDB documentation on `connections
+          <https://dochub.mongodb.org/core/connections>`_.
+
+        .. versionchanged:: 4.5
+          Added the ``serverMonitoringMode`` keyword argument.
+
+        .. versionchanged:: 4.2
+          Added the ``timeoutMS`` keyword argument.
+
+        .. versionchanged:: 4.0
+
+          - Removed the fsync, unlock, is_locked, database_names, and
+            close_cursor methods.
+            See the :ref:`pymongo4-migration-guide`.
+          - Removed the ``waitQueueMultiple`` and ``socketKeepAlive``
+            keyword arguments.
+          - The default for `uuidRepresentation` was changed from
+            ``pythonLegacy`` to ``unspecified``.
+          - Added the ``srvServiceName``, ``maxConnecting``, and ``srvMaxHosts`` URI and
+            keyword arguments.
+
+        .. versionchanged:: 3.12
+          Added the ``server_api`` keyword argument.
+          The following keyword arguments were deprecated:
+
+          - ``ssl_certfile`` and ``ssl_keyfile`` were deprecated in favor
+            of ``tlsCertificateKeyFile``.
+
+        .. versionchanged:: 3.11
+          Added the following keyword arguments and URI options:
+
+          - ``tlsDisableOCSPEndpointCheck``
+          - ``directConnection``
+
+        .. versionchanged:: 3.9
+          Added the ``retryReads`` keyword argument and URI option.
+          Added the ``tlsInsecure`` keyword argument and URI option.
+          The following keyword arguments and URI options were deprecated:
+
+          - ``wTimeout`` was deprecated in favor of ``wTimeoutMS``.
+          - ``j`` was deprecated in favor of ``journal``.
+          - ``ssl_cert_reqs`` was deprecated in favor of
+            ``tlsAllowInvalidCertificates``.
+          - ``ssl_match_hostname`` was deprecated in favor of
+            ``tlsAllowInvalidHostnames``.
+          - ``ssl_ca_certs`` was deprecated in favor of ``tlsCAFile``.
+          - ``ssl_certfile`` was deprecated in favor of
+            ``tlsCertificateKeyFile``.
+          - ``ssl_crlfile`` was deprecated in favor of ``tlsCRLFile``.
+          - ``ssl_pem_passphrase`` was deprecated in favor of
+            ``tlsCertificateKeyFilePassword``.
+
+        .. versionchanged:: 3.9
+          ``retryWrites`` now defaults to ``True``.
+
+        .. versionchanged:: 3.8
+          Added the ``server_selector`` keyword argument.
+          Added the ``type_registry`` keyword argument.
+
+        .. versionchanged:: 3.7
+          Added the ``driver`` keyword argument.
+
+        .. versionchanged:: 3.6
+          Added support for mongodb+srv:// URIs.
+          Added the ``retryWrites`` keyword argument and URI option.
+
+        .. versionchanged:: 3.5
+          Add ``username`` and ``password`` options. Document the
+          ``authSource``, ``authMechanism``, and ``authMechanismProperties``
+          options.
+          Deprecated the ``socketKeepAlive`` keyword argument and URI option.
+          ``socketKeepAlive`` now defaults to ``True``.
+
+        .. versionchanged:: 3.0
+          :class:`~pymongo.mongo_client.MongoClient` is now the one and only
+          client class for a standalone server, mongos, or replica set.
+          It includes the functionality that had been split into
+          :class:`~pymongo.mongo_client.MongoReplicaSetClient`: it can connect
+          to a replica set, discover all its members, and monitor the set for
+          stepdowns, elections, and reconfigs.
+
+          The :class:`~pymongo.mongo_client.MongoClient` constructor no
+          longer blocks while connecting to the server or servers, and it no
+          longer raises :class:`~pymongo.errors.ConnectionFailure` if they
+          are unavailable, nor :class:`~pymongo.errors.ConfigurationError`
+          if the user's credentials are wrong. Instead, the constructor
+          returns immediately and launches the connection process on
+          background threads.
+
+          Therefore the ``alive`` method is removed since it no longer
+          provides meaningful information; even if the client is disconnected,
+          it may discover a server in time to fulfill the next operation.
+
+          In PyMongo 2.x, :class:`~pymongo.MongoClient` accepted a list of
+          standalone MongoDB servers and used the first it could connect to::
+
+              MongoClient(['host1.com:27017', 'host2.com:27017'])
+
+          A list of multiple standalones is no longer supported; if multiple
+          servers are listed they must be members of the same replica set, or
+          mongoses in the same sharded cluster.
+
+          The behavior for a list of mongoses is changed from "high
+          availability" to "load balancing". Before, the client connected to
+          the lowest-latency mongos in the list, and used it until a network
+          error prompted it to re-evaluate all mongoses' latencies and
+          reconnect to one of them. In PyMongo 3, the client monitors its
+          network latency to all the mongoses continuously, and distributes
+          operations evenly among those with the lowest latency. See
+          :ref:`mongos-load-balancing` for more information.
+
+          The ``connect`` option is added.
+
+          The ``start_request``, ``in_request``, and ``end_request`` methods
+          are removed, as well as the ``auto_start_request`` option.
+
+          The ``copy_database`` method is removed, see the
+          :doc:`copy_database examples </examples/copydb>` for alternatives.
+
+          The :meth:`MongoClient.disconnect` method is removed; it was a
+          synonym for :meth:`~pymongo.MongoClient.close`.
+
+          :class:`~pymongo.mongo_client.MongoClient` no longer returns an
+          instance of :class:`~pymongo.database.Database` for attribute names
+          with leading underscores. You must use dict-style lookups instead::
+
+              client['__my_database__']
+
+          Not::
+
+              client.__my_database__
+
+        .. versionchanged:: 4.7
+          Deprecated parameter ``wTimeoutMS``, use :meth:`~pymongo.timeout`.
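+
+        For example, instead of ``wTimeoutMS`` an operation-level timeout can
+        be applied with :meth:`~pymongo.timeout` (a minimal sketch)::
+
+            import pymongo
+
+            client = pymongo.MongoClient()
+            with pymongo.timeout(5):
+                client.db.coll.insert_one({"x": 1})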
+ """ + doc_class = document_class or dict + self.__init_kwargs: dict[str, Any] = { + "host": host, + "port": port, + "document_class": doc_class, + "tz_aware": tz_aware, + "connect": connect, + "type_registry": type_registry, + **kwargs, + } + + if host is None: + host = self.HOST + if isinstance(host, str): + host = [host] + if port is None: + port = self.PORT + if not isinstance(port, int): + raise TypeError("port must be an instance of int") + + # _pool_class, _monitor_class, and _condition_class are for deep + # customization of PyMongo, e.g. Motor. + pool_class = kwargs.pop("_pool_class", None) + monitor_class = kwargs.pop("_monitor_class", None) + condition_class = kwargs.pop("_condition_class", None) + + # Parse options passed as kwargs. + keyword_opts = common._CaseInsensitiveDictionary(kwargs) + keyword_opts["document_class"] = doc_class + + seeds = set() + username = None + password = None + dbase = None + opts = common._CaseInsensitiveDictionary() + fqdn = None + srv_service_name = keyword_opts.get("srvservicename") + srv_max_hosts = keyword_opts.get("srvmaxhosts") + if len([h for h in host if "/" in h]) > 1: + raise ConfigurationError("host must not contain multiple MongoDB URIs") + for entity in host: + # A hostname can only include a-z, 0-9, '-' and '.'. If we find a '/' + # it must be a URI, + # https://en.wikipedia.org/wiki/Hostname#Restrictions_on_valid_host_names + if "/" in entity: + # Determine connection timeout from kwargs. + timeout = keyword_opts.get("connecttimeoutms") + if timeout is not None: + timeout = common.validate_timeout_or_none_or_zero( + keyword_opts.cased_key("connecttimeoutms"), timeout + ) + res = uri_parser.parse_uri( + entity, + port, + validate=True, + warn=True, + normalize=False, + connect_timeout=timeout, + srv_service_name=srv_service_name, + srv_max_hosts=srv_max_hosts, + ) + seeds.update(res["nodelist"]) + username = res["username"] or username + password = res["password"] or password + dbase = res["database"] or dbase + opts = res["options"] + fqdn = res["fqdn"] + else: + seeds.update(uri_parser.split_hosts(entity, port)) + if not seeds: + raise ConfigurationError("need to specify at least one host") + + for hostname in [node[0] for node in seeds]: + if _detect_external_db(hostname): + break + + # Add options with named keyword arguments to the parsed kwarg options. + if type_registry is not None: + keyword_opts["type_registry"] = type_registry + if tz_aware is None: + tz_aware = opts.get("tz_aware", False) + if connect is None: + connect = opts.get("connect", True) + keyword_opts["tz_aware"] = tz_aware + keyword_opts["connect"] = connect + + # Handle deprecated options in kwarg options. + keyword_opts = _handle_option_deprecations(keyword_opts) + # Validate kwarg options. + keyword_opts = common._CaseInsensitiveDictionary( + dict(common.validate(keyword_opts.cased_key(k), v) for k, v in keyword_opts.items()) + ) + + # Override connection string options with kwarg options. + opts.update(keyword_opts) + + if srv_service_name is None: + srv_service_name = opts.get("srvServiceName", common.SRV_SERVICE_NAME) + + srv_max_hosts = srv_max_hosts or opts.get("srvmaxhosts") + # Handle security-option conflicts in combined options. + opts = _handle_security_options(opts) + # Normalize combined options. + opts = _normalize_options(opts) + _check_options(seeds, opts) + + # Username and password passed as kwargs override user info in URI. 
+ username = opts.get("username", username) + password = opts.get("password", password) + self.__options = options = ClientOptions(username, password, dbase, opts) + + self.__default_database_name = dbase + self.__lock = _create_lock() + self.__kill_cursors_queue: list = [] + + self._event_listeners = options.pool_options._event_listeners + super().__init__( + options.codec_options, + options.read_preference, + options.write_concern, + options.read_concern, + ) + + self._topology_settings = TopologySettings( + seeds=seeds, + replica_set_name=options.replica_set_name, + pool_class=pool_class, + pool_options=options.pool_options, + monitor_class=monitor_class, + condition_class=condition_class, + local_threshold_ms=options.local_threshold_ms, + server_selection_timeout=options.server_selection_timeout, + server_selector=options.server_selector, + heartbeat_frequency=options.heartbeat_frequency, + fqdn=fqdn, + direct_connection=options.direct_connection, + load_balanced=options.load_balanced, + srv_service_name=srv_service_name, + srv_max_hosts=srv_max_hosts, + server_monitoring_mode=options.server_monitoring_mode, + ) + + self._init_background() + + if connect: + self._get_topology() + + self._encrypter = None + if self.__options.auto_encryption_opts: + from pymongo.encryption import _Encrypter + + self._encrypter = _Encrypter(self, self.__options.auto_encryption_opts) + self._timeout = self.__options.timeout + + if _HAS_REGISTER_AT_FORK: + # Add this client to the list of weakly referenced items. + # This will be used later if we fork. + MongoClient._clients[self._topology._topology_id] = self + + def _init_background(self, old_pid: Optional[int] = None) -> None: + self._topology = Topology(self._topology_settings) + # Seed the topology with the old one's pid so we can detect clients + # that are opened before a fork and used after. + self._topology._pid = old_pid + + def target() -> bool: + client = self_ref() + if client is None: + return False # Stop the executor. + MongoClient._process_periodic_tasks(client) + return True + + executor = periodic_executor.PeriodicExecutor( + interval=common.KILL_CURSOR_FREQUENCY, + min_interval=common.MIN_HEARTBEAT_INTERVAL, + target=target, + name="pymongo_kill_cursors_thread", + ) + + # We strongly reference the executor and it weakly references us via + # this closure. When the client is freed, stop the executor soon. + self_ref: Any = weakref.ref(self, executor.close) + self._kill_cursors_executor = executor + + def _after_fork(self) -> None: + """Resets topology in a child after successfully forking.""" + self._init_background(self._topology._pid) + + def _duplicate(self, **kwargs: Any) -> MongoClient: + args = self.__init_kwargs.copy() + args.update(kwargs) + return MongoClient(**args) + + def _server_property(self, attr_name: str) -> Any: + """An attribute of the current server's description. + + If the client is not connected, this will block until a connection is + established or raise ServerSelectionTimeoutError if no server is + available. + + Not threadsafe if used multiple times in a single method, since + the server may change. In such cases, store a local reference to a + ServerDescription first, then use its properties. 
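+
+        For example (a sketch of the pattern this warning suggests)::
+
+            server = client._get_topology().select_server(writable_server_selector, _Op.TEST)
+            description = server.description  # one immutable snapshot
+            address = description.address
+            server_type = description.server_type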
+ """ + server = self._get_topology().select_server(writable_server_selector, _Op.TEST) + + return getattr(server.description, attr_name) + + def watch( + self, + pipeline: Optional[_Pipeline] = None, + full_document: Optional[str] = None, + resume_after: Optional[Mapping[str, Any]] = None, + max_await_time_ms: Optional[int] = None, + batch_size: Optional[int] = None, + collation: Optional[_CollationIn] = None, + start_at_operation_time: Optional[Timestamp] = None, + session: Optional[client_session.ClientSession] = None, + start_after: Optional[Mapping[str, Any]] = None, + comment: Optional[Any] = None, + full_document_before_change: Optional[str] = None, + show_expanded_events: Optional[bool] = None, + ) -> ChangeStream[_DocumentType]: + """Watch changes on this cluster. + + Performs an aggregation with an implicit initial ``$changeStream`` + stage and returns a + :class:`~pymongo.change_stream.ClusterChangeStream` cursor which + iterates over changes on all databases on this cluster. + + Introduced in MongoDB 4.0. + + .. code-block:: python + + with client.watch() as stream: + for change in stream: + print(change) + + The :class:`~pymongo.change_stream.ClusterChangeStream` iterable + blocks until the next change document is returned or an error is + raised. If the + :meth:`~pymongo.change_stream.ClusterChangeStream.next` method + encounters a network error when retrieving a batch from the server, + it will automatically attempt to recreate the cursor such that no + change events are missed. Any error encountered during the resume + attempt indicates there may be an outage and will be raised. + + .. code-block:: python + + try: + with client.watch([{"$match": {"operationType": "insert"}}]) as stream: + for insert_change in stream: + print(insert_change) + except pymongo.errors.PyMongoError: + # The ChangeStream encountered an unrecoverable error or the + # resume attempt failed to recreate the cursor. + logging.error("...") + + For a precise description of the resume process see the + `change streams specification`_. + + :param pipeline: A list of aggregation pipeline stages to + append to an initial ``$changeStream`` stage. Not all + pipeline stages are valid after a ``$changeStream`` stage, see the + MongoDB documentation on change streams for the supported stages. + :param full_document: The fullDocument to pass as an option + to the ``$changeStream`` stage. Allowed values: 'updateLookup', + 'whenAvailable', 'required'. When set to 'updateLookup', the + change notification for partial updates will include both a delta + describing the changes to the document, as well as a copy of the + entire document that was changed from some time after the change + occurred. + :param full_document_before_change: Allowed values: 'whenAvailable' + and 'required'. Change events may now result in a + 'fullDocumentBeforeChange' response field. + :param resume_after: A resume token. If provided, the + change stream will start returning changes that occur directly + after the operation specified in the resume token. A resume token + is the _id value of a change document. + :param max_await_time_ms: The maximum time in milliseconds + for the server to wait for changes before responding to a getMore + operation. + :param batch_size: The maximum number of documents to return + per batch. + :param collation: The :class:`~pymongo.collation.Collation` + to use for the aggregation. 
+        :param start_at_operation_time: If provided, the resulting
+            change stream will only return changes that occurred at or after
+            the specified :class:`~bson.timestamp.Timestamp`. Requires
+            MongoDB >= 4.0.
+        :param session: a
+            :class:`~pymongo.client_session.ClientSession`.
+        :param start_after: The same as `resume_after` except that
+            `start_after` can resume notifications after an invalidate event.
+            This option and `resume_after` are mutually exclusive.
+        :param comment: A user-provided comment to attach to this
+            command.
+        :param show_expanded_events: Include expanded events such as DDL events like `dropIndexes`.
+
+        :return: A :class:`~pymongo.change_stream.ClusterChangeStream` cursor.
+
+        .. versionchanged:: 4.3
+           Added `show_expanded_events` parameter.
+
+        .. versionchanged:: 4.2
+           Added ``full_document_before_change`` parameter.
+
+        .. versionchanged:: 4.1
+           Added ``comment`` parameter.
+
+        .. versionchanged:: 3.9
+           Added the ``start_after`` parameter.
+
+        .. versionadded:: 3.7
+
+        .. seealso:: The MongoDB documentation on `changeStreams
+           <https://mongodb.com/docs/manual/changeStreams/>`_.
+
+        .. _change streams specification:
+            https://github.com/mongodb/specifications/blob/master/source/change-streams/change-streams.md
+        """
+        return ClusterChangeStream(
+            self.admin,
+            pipeline,
+            full_document,
+            resume_after,
+            max_await_time_ms,
+            batch_size,
+            collation,
+            start_at_operation_time,
+            session,
+            start_after,
+            comment,
+            full_document_before_change,
+            show_expanded_events=show_expanded_events,
+        )
+
+    @property
+    def topology_description(self) -> TopologyDescription:
+        """The description of the connected MongoDB deployment.
+
+        >>> client.topology_description
+        <TopologyDescription id: ..., topology_type: ReplicaSetWithPrimary, servers: [<ServerDescription ('localhost', 27017) server_type: RSPrimary, rtt: ...>, <ServerDescription ('localhost', 27018) server_type: RSSecondary, rtt: ...>, <ServerDescription ('localhost', 27019) server_type: RSSecondary, rtt: ...>]>
+        >>> client.topology_description.topology_type_name
+        'ReplicaSetWithPrimary'
+
+        Note that the description is periodically updated in the background
+        but the returned object itself is immutable. Access this property again
+        to get a more recent
+        :class:`~pymongo.topology_description.TopologyDescription`.
+
+        :return: An instance of
+            :class:`~pymongo.topology_description.TopologyDescription`.
+
+        .. versionadded:: 4.0
+        """
+        return self._topology.description
+
+    @property
+    def address(self) -> Optional[tuple[str, int]]:
+        """(host, port) of the current standalone, primary, or mongos, or None.
+
+        Accessing :attr:`address` raises :exc:`~.errors.InvalidOperation` if
+        the client is load-balancing among mongoses, since there is no single
+        address. Use :attr:`nodes` instead.
+
+        If the client is not connected, this will block until a connection is
+        established or raise ServerSelectionTimeoutError if no server is
+        available.
+
+        .. versionadded:: 3.0
+        """
+        topology_type = self._topology._description.topology_type
+        if (
+            topology_type == TOPOLOGY_TYPE.Sharded
+            and len(self.topology_description.server_descriptions()) > 1
+        ):
+            raise InvalidOperation(
+                'Cannot use "address" property when load balancing among'
+                ' mongoses, use "nodes" instead.'
+            )
+        if topology_type not in (
+            TOPOLOGY_TYPE.ReplicaSetWithPrimary,
+            TOPOLOGY_TYPE.Single,
+            TOPOLOGY_TYPE.LoadBalanced,
+            TOPOLOGY_TYPE.Sharded,
+        ):
+            return None
+        return self._server_property("address")
+
+    @property
+    def primary(self) -> Optional[tuple[str, int]]:
+        """The (host, port) of the current primary of the replica set.
+
+        Returns ``None`` if this client is not connected to a replica set,
+        there is no primary, or this client was created without the
+        `replicaSet` option.
+
+        .. versionadded:: 3.0
+           MongoClient gained this property in version 3.0.
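+
+        For a replica set running on localhost this might return, for
+        example::
+
+            >>> client.primary
+            ('localhost', 27017)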
+ """ + return self._topology.get_primary() # type: ignore[return-value] + + @property + def secondaries(self) -> set[_Address]: + """The secondary members known to this client. + + A sequence of (host, port) pairs. Empty if this client is not + connected to a replica set, there are no visible secondaries, or this + client was created without the `replicaSet` option. + + .. versionadded:: 3.0 + MongoClient gained this property in version 3.0. + """ + return self._topology.get_secondaries() + + @property + def arbiters(self) -> set[_Address]: + """Arbiters in the replica set. + + A sequence of (host, port) pairs. Empty if this client is not + connected to a replica set, there are no arbiters, or this client was + created without the `replicaSet` option. + """ + return self._topology.get_arbiters() + + @property + def is_primary(self) -> bool: + """If this client is connected to a server that can accept writes. + + True if the current server is a standalone, mongos, or the primary of + a replica set. If the client is not connected, this will block until a + connection is established or raise ServerSelectionTimeoutError if no + server is available. + """ + return self._server_property("is_writable") + + @property + def is_mongos(self) -> bool: + """If this client is connected to mongos. If the client is not + connected, this will block until a connection is established or raise + ServerSelectionTimeoutError if no server is available. + """ + return self._server_property("server_type") == SERVER_TYPE.Mongos + + @property + def nodes(self) -> FrozenSet[_Address]: + """Set of all currently connected servers. + + .. warning:: When connected to a replica set the value of :attr:`nodes` + can change over time as :class:`MongoClient`'s view of the replica + set changes. :attr:`nodes` can also be an empty set when + :class:`MongoClient` is first instantiated and hasn't yet connected + to any servers, or a network partition causes it to lose connection + to all servers. + """ + description = self._topology.description + return frozenset(s.address for s in description.known_servers) + + @property + def options(self) -> ClientOptions: + """The configuration options for this client. + + :return: An instance of :class:`~pymongo.client_options.ClientOptions`. + + .. versionadded:: 4.0 + """ + return self.__options + + def _end_sessions(self, session_ids: list[_ServerSession]) -> None: + """Send endSessions command(s) with the given session ids.""" + try: + # Use Connection.command directly to avoid implicitly creating + # another session. + with self._conn_for_reads( + ReadPreference.PRIMARY_PREFERRED, None, operation=_Op.END_SESSIONS + ) as ( + conn, + read_pref, + ): + if not conn.supports_sessions: + return + + for i in range(0, len(session_ids), common._MAX_END_SESSIONS): + spec = {"endSessions": session_ids[i : i + common._MAX_END_SESSIONS]} + conn.command("admin", spec, read_preference=read_pref, client=self) + except PyMongoError: + # Drivers MUST ignore any errors returned by the endSessions + # command. + pass + + def close(self) -> None: + """Cleanup client resources and disconnect from MongoDB. + + End all server sessions created by this client by sending one or more + endSessions commands. + + Close all sockets in the connection pools and stop the monitor threads. + + .. versionchanged:: 4.0 + Once closed, the client cannot be used again and any attempt will + raise :exc:`~pymongo.errors.InvalidOperation`. + + .. versionchanged:: 3.6 + End all server sessions created by this client. 
+ """ + session_ids = self._topology.pop_all_sessions() + if session_ids: + self._end_sessions(session_ids) + # Stop the periodic task thread and then send pending killCursor + # requests before closing the topology. + self._kill_cursors_executor.close() + self._process_kill_cursors() + self._topology.close() + if self._encrypter: + # TODO: PYTHON-1921 Encrypted MongoClients cannot be re-opened. + self._encrypter.close() + + def _get_topology(self) -> Topology: + """Get the internal :class:`~pymongo.topology.Topology` object. + + If this client was created with "connect=False", calling _get_topology + launches the connection process in the background. + """ + self._topology.open() + with self.__lock: + self._kill_cursors_executor.open() + return self._topology + + @contextlib.contextmanager + def _checkout(self, server: Server, session: Optional[ClientSession]) -> Iterator[Connection]: + in_txn = session and session.in_transaction + with _MongoClientErrorHandler(self, server, session) as err_handler: + # Reuse the pinned connection, if it exists. + if in_txn and session and session._pinned_connection: + err_handler.contribute_socket(session._pinned_connection) + yield session._pinned_connection + return + with server.checkout(handler=err_handler) as conn: + # Pin this session to the selected server or connection. + if ( + in_txn + and session + and server.description.server_type + in ( + SERVER_TYPE.Mongos, + SERVER_TYPE.LoadBalancer, + ) + ): + session._pin(server, conn) + err_handler.contribute_socket(conn) + if ( + self._encrypter + and not self._encrypter._bypass_auto_encryption + and conn.max_wire_version < 8 + ): + raise ConfigurationError( + "Auto-encryption requires a minimum MongoDB version of 4.2" + ) + yield conn + + def _select_server( + self, + server_selector: Callable[[Selection], Selection], + session: Optional[ClientSession], + operation: str, + address: Optional[_Address] = None, + deprioritized_servers: Optional[list[Server]] = None, + operation_id: Optional[int] = None, + ) -> Server: + """Select a server to run an operation on this client. + + :param server_selector: The server selector to use if the session is + not pinned and no address is given. + :param session: The ClientSession for the next operation, or None. May + be pinned to a mongos server address. + :param operation: The name of the operation that the server is being selected for. + :param address: Address when sending a message + to a specific server, used for getMore. + """ + try: + topology = self._get_topology() + if session and not session.in_transaction: + session._transaction.reset() + if not address and session: + address = session._pinned_address + if address: + # We're running a getMore or this session is pinned to a mongos. + server = topology.select_server_by_address( + address, operation, operation_id=operation_id + ) + if not server: + raise AutoReconnect("server %s:%s no longer available" % address) # noqa: UP031 + else: + server = topology.select_server( + server_selector, + operation, + deprioritized_servers=deprioritized_servers, + operation_id=operation_id, + ) + return server + except PyMongoError as exc: + # Server selection errors in a transaction are transient. 
+ if session and session.in_transaction: + exc._add_error_label("TransientTransactionError") + session._unpin() + raise + + def _conn_for_writes( + self, session: Optional[ClientSession], operation: str + ) -> ContextManager[Connection]: + server = self._select_server(writable_server_selector, session, operation) + return self._checkout(server, session) + + @contextlib.contextmanager + def _conn_from_server( + self, read_preference: _ServerMode, server: Server, session: Optional[ClientSession] + ) -> Iterator[tuple[Connection, _ServerMode]]: + assert read_preference is not None, "read_preference must not be None" + # Get a connection for a server matching the read preference, and yield + # conn with the effective read preference. The Server Selection + # Spec says not to send any $readPreference to standalones and to + # always send primaryPreferred when directly connected to a repl set + # member. + # Thread safe: if the type is single it cannot change. + topology = self._get_topology() + single = topology.description.topology_type == TOPOLOGY_TYPE.Single + + with self._checkout(server, session) as conn: + if single: + if conn.is_repl and not (session and session.in_transaction): + # Use primary preferred to ensure any repl set member + # can handle the request. + read_preference = ReadPreference.PRIMARY_PREFERRED + elif conn.is_standalone: + # Don't send read preference to standalones. + read_preference = ReadPreference.PRIMARY + yield conn, read_preference + + def _conn_for_reads( + self, + read_preference: _ServerMode, + session: Optional[ClientSession], + operation: str, + ) -> ContextManager[tuple[Connection, _ServerMode]]: + assert read_preference is not None, "read_preference must not be None" + _ = self._get_topology() + server = self._select_server(read_preference, session, operation) + return self._conn_from_server(read_preference, server, session) + + def _should_pin_cursor(self, session: Optional[ClientSession]) -> Optional[bool]: + return self.__options.load_balanced and not (session and session.in_transaction) + + @_csot.apply + def _run_operation( + self, + operation: Union[_Query, _GetMore], + unpack_res: Callable, + address: Optional[_Address] = None, + ) -> Response: + """Run a _Query/_GetMore operation and return a Response. + + :param operation: a _Query or _GetMore object. + :param unpack_res: A callable that decodes the wire protocol response. + :param address: Optional address when sending a message + to a specific server, used for getMore. + """ + if operation.conn_mgr: + server = self._select_server( + operation.read_preference, + operation.session, + operation.name, + address=address, + ) + + with operation.conn_mgr.lock: + with _MongoClientErrorHandler(self, server, operation.session) as err_handler: + err_handler.contribute_socket(operation.conn_mgr.conn) + return server.run_operation( + operation.conn_mgr.conn, + operation, + operation.read_preference, + self._event_listeners, + unpack_res, + self, + ) + + def _cmd( + _session: Optional[ClientSession], + server: Server, + conn: Connection, + read_preference: _ServerMode, + ) -> Response: + operation.reset() # Reset op in case of retry. 
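+            # reset() clears state cached by a previous attempt (e.g. the
+            # encoded command form) so a retried attempt is rebuilt cleanly.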
+            return server.run_operation(
+                conn,
+                operation,
+                read_preference,
+                self._event_listeners,
+                unpack_res,
+                self,
+            )
+
+        return self._retryable_read(
+            _cmd,
+            operation.read_preference,
+            operation.session,
+            address=address,
+            retryable=isinstance(operation, message._Query),
+            operation=operation.name,
+        )
+
+    def _retry_with_session(
+        self,
+        retryable: bool,
+        func: _WriteCall[T],
+        session: Optional[ClientSession],
+        bulk: Optional[_Bulk],
+        operation: str,
+        operation_id: Optional[int] = None,
+    ) -> T:
+        """Execute an operation with at most one retry.
+
+        Returns func()'s return value on success. On error retries the same
+        command.
+
+        Re-raises any exception thrown by func().
+        """
+        # Ensure that the options support retry_writes and that there is a valid
+        # session not in a transaction; otherwise we will not retry this write.
+        retryable = bool(
+            retryable and self.options.retry_writes and session and not session.in_transaction
+        )
+        return self._retry_internal(
+            func=func,
+            session=session,
+            bulk=bulk,
+            operation=operation,
+            retryable=retryable,
+            operation_id=operation_id,
+        )
+
+    @_csot.apply
+    def _retry_internal(
+        self,
+        func: _WriteCall[T] | _ReadCall[T],
+        session: Optional[ClientSession],
+        bulk: Optional[_Bulk],
+        operation: str,
+        is_read: bool = False,
+        address: Optional[_Address] = None,
+        read_pref: Optional[_ServerMode] = None,
+        retryable: bool = False,
+        operation_id: Optional[int] = None,
+    ) -> T:
+        """Internal retryable helper for all client operations.
+
+        :param func: Callback function we want to retry
+        :param session: Client Session on which the operation should occur
+        :param bulk: Abstraction to handle bulk write operations
+        :param operation: The name of the operation that the server is being selected for
+        :param is_read: True if this is a read operation, defaults to False
+        :param address: Server Address, defaults to None
+        :param read_pref: Read preference for the operation, defaults to None
+        :param retryable: If the operation should be retried once, defaults to False
+
+        :return: Output of the calling func()
+        """
+        return _ClientConnectionRetryable(
+            mongo_client=self,
+            func=func,
+            bulk=bulk,
+            operation=operation,
+            is_read=is_read,
+            session=session,
+            read_pref=read_pref,
+            address=address,
+            retryable=retryable,
+            operation_id=operation_id,
+        ).run()
+
+    def _retryable_read(
+        self,
+        func: _ReadCall[T],
+        read_pref: _ServerMode,
+        session: Optional[ClientSession],
+        operation: str,
+        address: Optional[_Address] = None,
+        retryable: bool = True,
+        operation_id: Optional[int] = None,
+    ) -> T:
+        """Execute an operation with consecutive retries if possible
+
+        Returns func()'s return value on success. On error retries the same
+        command.
+
+        Re-raises any exception thrown by func().
+
+        :param func: Read call we want to execute
+        :param read_pref: The read preference for the operation
+        :param session: Client session we should use to execute operation
+        :param operation: The name of the operation that the server is being selected for
+        :param address: Optional address when sending a message, defaults to None
+        :param retryable: if we should attempt retries
+            (may not always be supported even if supplied), defaults to True
+        """
+
+        # Ensure that the client supports retrying on reads and there is no session
+        # in a transaction; otherwise we will not retry this call.
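+        # For example, a find() issued inside an active transaction is never
+        # retried here; transaction-level retry is handled separately (e.g. by
+        # ClientSession.with_transaction).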
+ retryable = bool( + retryable and self.options.retry_reads and not (session and session.in_transaction) + ) + return self._retry_internal( + func, + session, + None, + operation, + is_read=True, + address=address, + read_pref=read_pref, + retryable=retryable, + operation_id=operation_id, + ) + + def _retryable_write( + self, + retryable: bool, + func: _WriteCall[T], + session: Optional[ClientSession], + operation: str, + bulk: Optional[_Bulk] = None, + operation_id: Optional[int] = None, + ) -> T: + """Execute an operation with consecutive retries if possible + + Returns func()'s return value on success. On error retries the same + command. + + Re-raises any exception thrown by func(). + + :param retryable: if we should attempt retries (may not always be supported) + :param func: write call we want to execute during a session + :param session: Client session we will use to execute write operation + :param operation: The name of the operation that the server is being selected for + :param bulk: bulk abstraction to execute operations in bulk, defaults to None + """ + with self._tmp_session(session) as s: + return self._retry_with_session(retryable, func, s, bulk, operation, operation_id) + + def __eq__(self, other: Any) -> bool: + if isinstance(other, self.__class__): + return self._topology == other._topology + return NotImplemented + + def __ne__(self, other: Any) -> bool: + return not self == other + + def __hash__(self) -> int: + return hash(self._topology) + + def _repr_helper(self) -> str: + def option_repr(option: str, value: Any) -> str: + """Fix options whose __repr__ isn't usable in a constructor.""" + if option == "document_class": + if value is dict: + return "document_class=dict" + else: + return f"document_class={value.__module__}.{value.__name__}" + if option in common.TIMEOUT_OPTIONS and value is not None: + return f"{option}={int(value * 1000)}" + + return f"{option}={value!r}" + + # Host first... + options = [ + "host=%r" + % [ + "%s:%d" % (host, port) if port is not None else host + for host, port in self._topology_settings.seeds + ] + ] + # ... then everything in self._constructor_args... + options.extend( + option_repr(key, self.__options._options[key]) for key in self._constructor_args + ) + # ... then everything else. + options.extend( + option_repr(key, self.__options._options[key]) + for key in self.__options._options + if key not in set(self._constructor_args) and key != "username" and key != "password" + ) + return ", ".join(options) + + def __repr__(self) -> str: + return f"MongoClient({self._repr_helper()})" + + def __getattr__(self, name: str) -> database.Database[_DocumentType]: + """Get a database by name. + + Raises :class:`~pymongo.errors.InvalidName` if an invalid + database name is used. + + :param name: the name of the database to get + """ + if name.startswith("_"): + raise AttributeError( + f"MongoClient has no attribute {name!r}. To access the {name}" + f" database, use client[{name!r}]." + ) + return self.__getitem__(name) + + def __getitem__(self, name: str) -> database.Database[_DocumentType]: + """Get a database by name. + + Raises :class:`~pymongo.errors.InvalidName` if an invalid + database name is used. 
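+
+        For example (illustrative)::
+
+            db = client["my-database"]  # attribute access cannot express this name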
+ + :param name: the name of the database to get + """ + return database.Database(self, name) + + def _cleanup_cursor( + self, + locks_allowed: bool, + cursor_id: int, + address: Optional[_CursorAddress], + conn_mgr: _ConnectionManager, + session: Optional[ClientSession], + explicit_session: bool, + ) -> None: + """Cleanup a cursor from cursor.close() or __del__. + + This method handles cleanup for Cursors/CommandCursors including any + pinned connection or implicit session attached at the time the cursor + was closed or garbage collected. + + :param locks_allowed: True if we are allowed to acquire locks. + :param cursor_id: The cursor id which may be 0. + :param address: The _CursorAddress. + :param conn_mgr: The _ConnectionManager for the pinned connection or None. + :param session: The cursor's session. + :param explicit_session: True if the session was passed explicitly. + """ + if locks_allowed: + if cursor_id: + if conn_mgr and conn_mgr.more_to_come: + # If this is an exhaust cursor and we haven't completely + # exhausted the result set we *must* close the socket + # to stop the server from sending more data. + assert conn_mgr.conn is not None + conn_mgr.conn.close_conn(ConnectionClosedReason.ERROR) + else: + self._close_cursor_now(cursor_id, address, session=session, conn_mgr=conn_mgr) + if conn_mgr: + conn_mgr.close() + else: + # The cursor will be closed later in a different session. + if cursor_id or conn_mgr: + self._close_cursor_soon(cursor_id, address, conn_mgr) + if session and not explicit_session: + session._end_session(lock=locks_allowed) + + def _close_cursor_soon( + self, + cursor_id: int, + address: Optional[_CursorAddress], + conn_mgr: Optional[_ConnectionManager] = None, + ) -> None: + """Request that a cursor and/or connection be cleaned up soon.""" + self.__kill_cursors_queue.append((address, cursor_id, conn_mgr)) + + def _close_cursor_now( + self, + cursor_id: int, + address: Optional[_CursorAddress], + session: Optional[ClientSession] = None, + conn_mgr: Optional[_ConnectionManager] = None, + ) -> None: + """Send a kill cursors message with the given id. + + The cursor is closed synchronously on the current thread. + """ + if not isinstance(cursor_id, int): + raise TypeError("cursor_id must be an instance of int") + + try: + if conn_mgr: + with conn_mgr.lock: + # Cursor is pinned to LB outside of a transaction. + assert address is not None + assert conn_mgr.conn is not None + self._kill_cursor_impl([cursor_id], address, session, conn_mgr.conn) + else: + self._kill_cursors([cursor_id], address, self._get_topology(), session) + except PyMongoError: + # Make another attempt to kill the cursor later. + self._close_cursor_soon(cursor_id, address) + + def _kill_cursors( + self, + cursor_ids: Sequence[int], + address: Optional[_CursorAddress], + topology: Topology, + session: Optional[ClientSession], + ) -> None: + """Send a kill cursors message with the given ids.""" + if address: + # address could be a tuple or _CursorAddress, but + # select_server_by_address needs (host, port). + server = topology.select_server_by_address(tuple(address), _Op.KILL_CURSORS) # type: ignore[arg-type] + else: + # Application called close_cursor() with no address. 
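+            # Any writable server may process killCursors in that case, hence
+            # the writable-server selector below.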
+ server = topology.select_server(writable_server_selector, _Op.KILL_CURSORS) + + with self._checkout(server, session) as conn: + assert address is not None + self._kill_cursor_impl(cursor_ids, address, session, conn) + + def _kill_cursor_impl( + self, + cursor_ids: Sequence[int], + address: _CursorAddress, + session: Optional[ClientSession], + conn: Connection, + ) -> None: + namespace = address.namespace + db, coll = namespace.split(".", 1) + spec = {"killCursors": coll, "cursors": cursor_ids} + conn.command(db, spec, session=session, client=self) + + def _process_kill_cursors(self) -> None: + """Process any pending kill cursors requests.""" + address_to_cursor_ids = defaultdict(list) + pinned_cursors = [] + + # Other threads or the GC may append to the queue concurrently. + while True: + try: + address, cursor_id, conn_mgr = self.__kill_cursors_queue.pop() + except IndexError: + break + + if conn_mgr: + pinned_cursors.append((address, cursor_id, conn_mgr)) + else: + address_to_cursor_ids[address].append(cursor_id) + + for address, cursor_id, conn_mgr in pinned_cursors: + try: + self._cleanup_cursor(True, cursor_id, address, conn_mgr, None, False) + except Exception as exc: + if isinstance(exc, InvalidOperation) and self._topology._closed: + # Raise the exception when client is closed so that it + # can be caught in _process_periodic_tasks + raise + else: + helpers._handle_exception() + + # Don't re-open topology if it's closed and there's no pending cursors. + if address_to_cursor_ids: + topology = self._get_topology() + for address, cursor_ids in address_to_cursor_ids.items(): + try: + self._kill_cursors(cursor_ids, address, topology, session=None) + except Exception as exc: + if isinstance(exc, InvalidOperation) and self._topology._closed: + raise + else: + helpers._handle_exception() + + # This method is run periodically by a background thread. + def _process_periodic_tasks(self) -> None: + """Process any pending kill cursors requests and + maintain connection pool parameters. + """ + try: + self._process_kill_cursors() + self._topology.update_pool() + except Exception as exc: + if isinstance(exc, InvalidOperation) and self._topology._closed: + return + else: + helpers._handle_exception() + + def __start_session(self, implicit: bool, **kwargs: Any) -> ClientSession: + server_session = _EmptyServerSession() + opts = client_session.SessionOptions(**kwargs) + return client_session.ClientSession(self, server_session, opts, implicit) + + def start_session( + self, + causal_consistency: Optional[bool] = None, + default_transaction_options: Optional[client_session.TransactionOptions] = None, + snapshot: Optional[bool] = False, + ) -> client_session.ClientSession: + """Start a logical session. + + This method takes the same parameters as + :class:`~pymongo.client_session.SessionOptions`. See the + :mod:`~pymongo.client_session` module for details and examples. + + A :class:`~pymongo.client_session.ClientSession` may only be used with + the MongoClient that started it. :class:`ClientSession` instances are + **not thread-safe or fork-safe**. They can only be used by one thread + or process at a time. A single :class:`ClientSession` cannot be used + to run multiple operations concurrently. + + :return: An instance of :class:`~pymongo.client_session.ClientSession`. + + .. 
versionadded:: 3.6
+        """
+        return self.__start_session(
+            False,
+            causal_consistency=causal_consistency,
+            default_transaction_options=default_transaction_options,
+            snapshot=snapshot,
+        )
+
+    def _return_server_session(
+        self, server_session: Union[_ServerSession, _EmptyServerSession], lock: bool
+    ) -> None:
+        """Internal: return a _ServerSession to the pool."""
+        if isinstance(server_session, _EmptyServerSession):
+            return None
+        return self._topology.return_server_session(server_session, lock)
+
+    def _ensure_session(self, session: Optional[ClientSession] = None) -> Optional[ClientSession]:
+        """If provided session is None, lend a temporary session."""
+        if session:
+            return session
+
+        try:
+            # Don't make implicit sessions causally consistent. Applications
+            # should always opt-in.
+            return self.__start_session(True, causal_consistency=False)
+        except (ConfigurationError, InvalidOperation):
+            # Sessions not supported.
+            return None
+
+    @contextlib.contextmanager
+    def _tmp_session(
+        self, session: Optional[client_session.ClientSession], close: bool = True
+    ) -> Generator[Optional[client_session.ClientSession], None, None]:
+        """If provided session is None, lend a temporary session."""
+        if session is not None:
+            if not isinstance(session, client_session.ClientSession):
+                raise ValueError("'session' argument must be a ClientSession or None.")
+            # Don't call end_session.
+            yield session
+            return
+
+        s = self._ensure_session(session)
+        if s:
+            try:
+                yield s
+            except Exception as exc:
+                if isinstance(exc, ConnectionFailure):
+                    s._server_session.mark_dirty()
+
+                # Always call end_session on error.
+                s.end_session()
+                raise
+            finally:
+                # Call end_session when we exit this scope.
+                if close:
+                    s.end_session()
+        else:
+            yield None
+
+    def _send_cluster_time(
+        self, command: MutableMapping[str, Any], session: Optional[ClientSession]
+    ) -> None:
+        topology_time = self._topology.max_cluster_time()
+        session_time = session.cluster_time if session else None
+        if topology_time and session_time:
+            if topology_time["clusterTime"] > session_time["clusterTime"]:
+                cluster_time: Optional[ClusterTime] = topology_time
+            else:
+                cluster_time = session_time
+        else:
+            cluster_time = topology_time or session_time
+        if cluster_time:
+            command["$clusterTime"] = cluster_time
+
+    def _process_response(self, reply: Mapping[str, Any], session: Optional[ClientSession]) -> None:
+        self._topology.receive_cluster_time(reply.get("$clusterTime"))
+        if session is not None:
+            session._process_response(reply)
+
+    def server_info(self, session: Optional[client_session.ClientSession] = None) -> dict[str, Any]:
+        """Get information about the MongoDB server we're connected to.
+
+        :param session: a
+            :class:`~pymongo.client_session.ClientSession`.
+
+        .. versionchanged:: 3.6
+           Added ``session`` parameter.
+        """
+        return cast(
+            dict,
+            self.admin.command(
+                "buildinfo", read_preference=ReadPreference.PRIMARY, session=session
+            ),
+        )
+
+    def list_databases(
+        self,
+        session: Optional[client_session.ClientSession] = None,
+        comment: Optional[Any] = None,
+        **kwargs: Any,
+    ) -> CommandCursor[dict[str, Any]]:
+        """Get a cursor over the databases of the connected server.
+
+        :param session: a
+            :class:`~pymongo.client_session.ClientSession`.
+        :param comment: A user-provided comment to attach to this
+            command.
+        :param kwargs: Optional parameters of the
+            `listDatabases command
+            <https://mongodb.com/docs/manual/reference/command/listDatabases>`_
+            can be passed as keyword arguments to this method. The supported
+            options differ by server version.
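+
+        For example (an illustrative sketch; ``filter`` and ``nameOnly`` are
+        documented ``listDatabases`` options)::
+
+            for spec in client.list_databases(filter={"name": {"$regex": "^my"}}):
+                print(spec["name"])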
+ + + :return: An instance of :class:`~pymongo.command_cursor.CommandCursor`. + + .. versionadded:: 3.6 + """ + cmd = {"listDatabases": 1} + cmd.update(kwargs) + if comment is not None: + cmd["comment"] = comment + admin = self._database_default_options("admin") + res = admin._retryable_read_command(cmd, session=session, operation=_Op.LIST_DATABASES) + # listDatabases doesn't return a cursor (yet). Fake one. + cursor = { + "id": 0, + "firstBatch": res["databases"], + "ns": "admin.$cmd", + } + return CommandCursor(admin["$cmd"], cursor, None, comment=comment) + + def list_database_names( + self, + session: Optional[client_session.ClientSession] = None, + comment: Optional[Any] = None, + ) -> list[str]: + """Get a list of the names of all databases on the connected server. + + :param session: a + :class:`~pymongo.client_session.ClientSession`. + :param comment: A user-provided comment to attach to this + command. + + .. versionchanged:: 4.1 + Added ``comment`` parameter. + + .. versionadded:: 3.6 + """ + return [doc["name"] for doc in self.list_databases(session, nameOnly=True, comment=comment)] + + @_csot.apply + def drop_database( + self, + name_or_database: Union[str, database.Database[_DocumentTypeArg]], + session: Optional[client_session.ClientSession] = None, + comment: Optional[Any] = None, + ) -> None: + """Drop a database. + + Raises :class:`TypeError` if `name_or_database` is not an instance of + :class:`str` or :class:`~pymongo.database.Database`. + + :param name_or_database: the name of a database to drop, or a + :class:`~pymongo.database.Database` instance representing the + database to drop + :param session: a + :class:`~pymongo.client_session.ClientSession`. + :param comment: A user-provided comment to attach to this + command. + + .. versionchanged:: 4.1 + Added ``comment`` parameter. + + .. versionchanged:: 3.6 + Added ``session`` parameter. + + .. note:: The :attr:`~pymongo.mongo_client.MongoClient.write_concern` of + this client is automatically applied to this operation. + + .. versionchanged:: 3.4 + Apply this client's write concern automatically to this operation + when connected to MongoDB >= 3.4. + + """ + name = name_or_database + if isinstance(name, database.Database): + name = name.name + + if not isinstance(name, str): + raise TypeError("name_or_database must be an instance of str or a Database") + + with self._conn_for_writes(session, operation=_Op.DROP_DATABASE) as conn: + self[name]._command( + conn, + {"dropDatabase": 1, "comment": comment}, + read_preference=ReadPreference.PRIMARY, + write_concern=self._write_concern_for(session), + parse_write_concern_error=True, + session=session, + ) + + def get_default_database( + self, + default: Optional[str] = None, + codec_options: Optional[CodecOptions[_DocumentTypeArg]] = None, + read_preference: Optional[_ServerMode] = None, + write_concern: Optional[WriteConcern] = None, + read_concern: Optional[ReadConcern] = None, + ) -> database.Database[_DocumentType]: + """Get the database named in the MongoDB connection URI. + + >>> uri = 'mongodb://host/my_database' + >>> client = MongoClient(uri) + >>> db = client.get_default_database() + >>> assert db.name == 'my_database' + >>> db = client.get_database() + >>> assert db.name == 'my_database' + + Useful in scripts where you want to choose which database to use + based only on the URI in a configuration file. + + :param default: the database name to use if no database name + was provided in the URI. 
+ :param codec_options: An instance of + :class:`~bson.codec_options.CodecOptions`. If ``None`` (the + default) the :attr:`codec_options` of this :class:`MongoClient` is + used. + :param read_preference: The read preference to use. If + ``None`` (the default) the :attr:`read_preference` of this + :class:`MongoClient` is used. See :mod:`~pymongo.read_preferences` + for options. + :param write_concern: An instance of + :class:`~pymongo.write_concern.WriteConcern`. If ``None`` (the + default) the :attr:`write_concern` of this :class:`MongoClient` is + used. + :param read_concern: An instance of + :class:`~pymongo.read_concern.ReadConcern`. If ``None`` (the + default) the :attr:`read_concern` of this :class:`MongoClient` is + used. + :param comment: A user-provided comment to attach to this + command. + + .. versionchanged:: 4.1 + Added ``comment`` parameter. + + .. versionchanged:: 3.8 + Undeprecated. Added the ``default``, ``codec_options``, + ``read_preference``, ``write_concern`` and ``read_concern`` + parameters. + + .. versionchanged:: 3.5 + Deprecated, use :meth:`get_database` instead. + """ + if self.__default_database_name is None and default is None: + raise ConfigurationError("No default database name defined or provided.") + + name = cast(str, self.__default_database_name or default) + return database.Database( + self, name, codec_options, read_preference, write_concern, read_concern + ) + + def get_database( + self, + name: Optional[str] = None, + codec_options: Optional[CodecOptions[_DocumentTypeArg]] = None, + read_preference: Optional[_ServerMode] = None, + write_concern: Optional[WriteConcern] = None, + read_concern: Optional[ReadConcern] = None, + ) -> database.Database[_DocumentType]: + """Get a :class:`~pymongo.database.Database` with the given name and + options. + + Useful for creating a :class:`~pymongo.database.Database` with + different codec options, read preference, and/or write concern from + this :class:`MongoClient`. + + >>> client.read_preference + Primary() + >>> db1 = client.test + >>> db1.read_preference + Primary() + >>> from pymongo import ReadPreference + >>> db2 = client.get_database( + ... 'test', read_preference=ReadPreference.SECONDARY) + >>> db2.read_preference + Secondary(tag_sets=None) + + :param name: The name of the database - a string. If ``None`` + (the default) the database named in the MongoDB connection URI is + returned. + :param codec_options: An instance of + :class:`~bson.codec_options.CodecOptions`. If ``None`` (the + default) the :attr:`codec_options` of this :class:`MongoClient` is + used. + :param read_preference: The read preference to use. If + ``None`` (the default) the :attr:`read_preference` of this + :class:`MongoClient` is used. See :mod:`~pymongo.read_preferences` + for options. + :param write_concern: An instance of + :class:`~pymongo.write_concern.WriteConcern`. If ``None`` (the + default) the :attr:`write_concern` of this :class:`MongoClient` is + used. + :param read_concern: An instance of + :class:`~pymongo.read_concern.ReadConcern`. If ``None`` (the + default) the :attr:`read_concern` of this :class:`MongoClient` is + used. + + .. versionchanged:: 3.5 + The `name` parameter is now optional, defaulting to the database + named in the MongoDB connection URI. 
+ """ + if name is None: + if self.__default_database_name is None: + raise ConfigurationError("No default database defined") + name = self.__default_database_name + + return database.Database( + self, name, codec_options, read_preference, write_concern, read_concern + ) + + def _database_default_options(self, name: str) -> Database: + """Get a Database instance with the default settings.""" + return self.get_database( + name, + codec_options=DEFAULT_CODEC_OPTIONS, + read_preference=ReadPreference.PRIMARY, + write_concern=DEFAULT_WRITE_CONCERN, + ) + + def __enter__(self) -> MongoClient[_DocumentType]: + return self + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + self.close() + + # See PYTHON-3084. + __iter__ = None + + def __next__(self) -> NoReturn: + raise TypeError("'MongoClient' object is not iterable") + + next = __next__ + + +def _retryable_error_doc(exc: PyMongoError) -> Optional[Mapping[str, Any]]: + """Return the server response from PyMongo exception or None.""" + if isinstance(exc, BulkWriteError): + # Check the last writeConcernError to determine if this + # BulkWriteError is retryable. + wces = exc.details["writeConcernErrors"] + return wces[-1] if wces else None + if isinstance(exc, (NotPrimaryError, OperationFailure)): + return cast(Mapping[str, Any], exc.details) + return None + + +def _add_retryable_write_error(exc: PyMongoError, max_wire_version: int, is_mongos: bool) -> None: + doc = _retryable_error_doc(exc) + if doc: + code = doc.get("code", 0) + # retryWrites on MMAPv1 should raise an actionable error. + if code == 20 and str(exc).startswith("Transaction numbers"): + errmsg = ( + "This MongoDB deployment does not support " + "retryable writes. Please add retryWrites=false " + "to your connection string." + ) + raise OperationFailure(errmsg, code, exc.details) # type: ignore[attr-defined] + if max_wire_version >= 9: + # In MongoDB 4.4+, the server reports the error labels. + for label in doc.get("errorLabels", []): + exc._add_error_label(label) + else: + # Do not consult writeConcernError for pre-4.4 mongos. + if isinstance(exc, WriteConcernError) and is_mongos: + pass + elif code in helpers._RETRYABLE_ERROR_CODES: + exc._add_error_label("RetryableWriteError") + + # Connection errors are always retryable except NotPrimaryError and WaitQueueTimeoutError which is + # handled above. + if isinstance(exc, ConnectionFailure) and not isinstance( + exc, (NotPrimaryError, WaitQueueTimeoutError) + ): + exc._add_error_label("RetryableWriteError") + + +class _MongoClientErrorHandler: + """Handle errors raised when executing an operation.""" + + __slots__ = ( + "client", + "server_address", + "session", + "max_wire_version", + "sock_generation", + "completed_handshake", + "service_id", + "handled", + ) + + def __init__(self, client: MongoClient, server: Server, session: Optional[ClientSession]): + self.client = client + self.server_address = server.description.address + self.session = session + self.max_wire_version = common.MIN_WIRE_VERSION + # XXX: When get_socket fails, this generation could be out of date: + # "Note that when a network error occurs before the handshake + # completes then the error's generation number is the generation + # of the pool at the time the connection attempt was started." 
+ self.sock_generation = server.pool.gen.get_overall() + self.completed_handshake = False + self.service_id: Optional[ObjectId] = None + self.handled = False + + def contribute_socket(self, conn: Connection, completed_handshake: bool = True) -> None: + """Provide socket information to the error handler.""" + self.max_wire_version = conn.max_wire_version + self.sock_generation = conn.generation + self.service_id = conn.service_id + self.completed_handshake = completed_handshake + + def handle( + self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException] + ) -> None: + if self.handled or exc_val is None: + return + self.handled = True + if self.session: + if isinstance(exc_val, ConnectionFailure): + if self.session.in_transaction: + exc_val._add_error_label("TransientTransactionError") + self.session._server_session.mark_dirty() + + if isinstance(exc_val, PyMongoError): + if exc_val.has_error_label("TransientTransactionError") or exc_val.has_error_label( + "RetryableWriteError" + ): + self.session._unpin() + err_ctx = _ErrorContext( + exc_val, + self.max_wire_version, + self.sock_generation, + self.completed_handshake, + self.service_id, + ) + self.client._topology.handle_error(self.server_address, err_ctx) + + def __enter__(self) -> _MongoClientErrorHandler: + return self + + def __exit__( + self, + exc_type: Optional[Type[Exception]], + exc_val: Optional[Exception], + exc_tb: Optional[TracebackType], + ) -> None: + return self.handle(exc_type, exc_val) + + +class _ClientConnectionRetryable(Generic[T]): + """Responsible for executing retryable connections on read or write operations""" + + def __init__( + self, + mongo_client: MongoClient, + func: _WriteCall[T] | _ReadCall[T], + bulk: Optional[_Bulk], + operation: str, + is_read: bool = False, + session: Optional[ClientSession] = None, + read_pref: Optional[_ServerMode] = None, + address: Optional[_Address] = None, + retryable: bool = False, + operation_id: Optional[int] = None, + ): + self._last_error: Optional[Exception] = None + self._retrying = False + self._multiple_retries = _csot.get_timeout() is not None + self._client = mongo_client + + self._func = func + self._bulk = bulk + self._session = session + self._is_read = is_read + self._retryable = retryable + self._read_pref = read_pref + self._server_selector: Callable[[Selection], Selection] = ( + read_pref if is_read else writable_server_selector # type: ignore + ) + self._address = address + self._server: Server = None # type: ignore + self._deprioritized_servers: list[Server] = [] + self._operation = operation + self._operation_id = operation_id + + def run(self) -> T: + """Runs the supplied func() and attempts a retry + + :raises: self._last_error: Last exception raised + + :return: Result of the func() call + """ + # Increment the transaction id up front to ensure any retry attempt + # will use the proper txnNumber, even if server or socket selection + # fails before the command can be sent. + if self._is_session_state_retryable() and self._retryable and not self._is_read: + self._session._start_retryable_write() # type: ignore + if self._bulk: + self._bulk.started_retryable_write = True + + while True: + self._check_last_error(check_csot=True) + try: + return self._read() if self._is_read else self._write() + except ServerSelectionTimeoutError: + # The application may think the write was never attempted + # if we raise ServerSelectionTimeoutError on the retry + # attempt. Raise the original exception instead. 
+                self._check_last_error()
+                # A ServerSelectionTimeoutError error indicates that there may
+                # be a persistent outage. Attempting to retry in this case will
+                # most likely be a waste of time.
+                raise
+            except PyMongoError as exc:
+                # Execute specialized catch on read
+                if self._is_read:
+                    if isinstance(exc, (ConnectionFailure, OperationFailure)):
+                        # ConnectionFailures do not supply a code property
+                        exc_code = getattr(exc, "code", None)
+                        if self._is_not_eligible_for_retry() or (
+                            isinstance(exc, OperationFailure)
+                            and exc_code not in helpers._RETRYABLE_ERROR_CODES
+                        ):
+                            raise
+                        self._retrying = True
+                        self._last_error = exc
+                    else:
+                        raise
+
+                # Specialized catch on write operation
+                if not self._is_read:
+                    if not self._retryable:
+                        raise
+                    retryable_write_error_exc = exc.has_error_label("RetryableWriteError")
+                    if retryable_write_error_exc:
+                        assert self._session
+                        self._session._unpin()
+                    if not retryable_write_error_exc or self._is_not_eligible_for_retry():
+                        if exc.has_error_label("NoWritesPerformed") and self._last_error:
+                            raise self._last_error from exc
+                        else:
+                            raise
+                    if self._bulk:
+                        self._bulk.retrying = True
+                    else:
+                        self._retrying = True
+                    if not exc.has_error_label("NoWritesPerformed"):
+                        self._last_error = exc
+                    if self._last_error is None:
+                        self._last_error = exc
+
+                    if self._client.topology_description.topology_type == TOPOLOGY_TYPE.Sharded:
+                        self._deprioritized_servers.append(self._server)
+
+    def _is_not_eligible_for_retry(self) -> bool:
+        """Checks if the exchange is not eligible for retry"""
+        return not self._retryable or (self._is_retrying() and not self._multiple_retries)
+
+    def _is_retrying(self) -> bool:
+        """Checks if the exchange is currently undergoing a retry"""
+        return self._bulk.retrying if self._bulk else self._retrying
+
+    def _is_session_state_retryable(self) -> bool:
+        """Checks if provided session is eligible for retry
+
+        reads: Make sure there is no ongoing transaction (if provided a session)
+        writes: Make sure there is a session without an active transaction
+        """
+        if self._is_read:
+            return not (self._session and self._session.in_transaction)
+        return bool(self._session and not self._session.in_transaction)
+
+    def _check_last_error(self, check_csot: bool = False) -> None:
+        """Checks if the ongoing client exchange experienced an exception previously.
+        If so, raise the last error.
+
+        :param check_csot: Checks CSOT to ensure we are retrying with time remaining;
+            defaults to False
+        """
+        if self._is_retrying():
+            remaining = _csot.remaining()
+            if not check_csot or (remaining is not None and remaining <= 0):
+                assert self._last_error is not None
+                raise self._last_error
+
+    def _get_server(self) -> Server:
+        """Retrieves a server object based on the provided context
+
+        :return: Abstraction to connect to server
+        """
+        return self._client._select_server(
+            self._server_selector,
+            self._session,
+            self._operation,
+            address=self._address,
+            deprioritized_servers=self._deprioritized_servers,
+            operation_id=self._operation_id,
+        )
+
+    def _write(self) -> T:
+        """Wrapper method for write-type retryable client executions
+
+        :return: Output for func()'s call
+        """
+        try:
+            max_wire_version = 0
+            is_mongos = False
+            self._server = self._get_server()
+            with self._client._checkout(self._server, self._session) as conn:
+                max_wire_version = conn.max_wire_version
+                sessions_supported = (
+                    self._session
+                    and self._server.description.retryable_writes_supported
+                    and conn.supports_sessions
+                )
+                is_mongos = conn.is_mongos
+                if not sessions_supported:
+                    # A retry is not possible because this server does not
+                    # support sessions; raise the last error.
+                    self._check_last_error()
+                    self._retryable = False
+                return self._func(self._session, conn, self._retryable)  # type: ignore
+        except PyMongoError as exc:
+            if not self._retryable:
+                raise
+            # Add the RetryableWriteError label, if applicable.
+            _add_retryable_write_error(exc, max_wire_version, is_mongos)
+            raise
+
+    def _read(self) -> T:
+        """Wrapper method for read-type retryable client executions
+
+        :return: Output for func()'s call
+        """
+        self._server = self._get_server()
+        assert self._read_pref is not None, "Read Preference required on read calls"
+        with self._client._conn_from_server(self._read_pref, self._server, self._session) as (
+            conn,
+            read_pref,
+        ):
+            if self._retrying and not self._retryable:
+                self._check_last_error()
+            return self._func(self._session, self._server, conn, read_pref)  # type: ignore
+
+
+def _after_fork_child() -> None:
+    """Releases the locks in child process and resets the
+    topologies in all MongoClients.
+    """
+    # Reinitialize locks
+    _release_locks()
+
+    # Perform cleanup in clients (i.e. get rid of topology)
+    for _, client in MongoClient._clients.items():
+        client._after_fork()
+
+
+def _detect_external_db(entity: str) -> bool:
+    """Detects external database hosts and logs an informational message at the INFO level."""
+    entity = entity.lower()
+    cosmos_db_hosts = [".cosmos.azure.com"]
+    document_db_hosts = [".docdb.amazonaws.com", ".docdb-elastic.amazonaws.com"]
+
+    for host in cosmos_db_hosts:
+        if entity.endswith(host):
+            _log_or_warn(
+                _CLIENT_LOGGER,
+                "You appear to be connected to a CosmosDB cluster. For more information regarding feature "
+                "compatibility and support please visit https://www.mongodb.com/supportability/cosmosdb",
+            )
+            return True
+    for host in document_db_hosts:
+        if entity.endswith(host):
+            _log_or_warn(
+                _CLIENT_LOGGER,
+                "You appear to be connected to a DocumentDB cluster. For more information regarding feature "
+                "compatibility and support please visit https://www.mongodb.com/supportability/documentdb",
+            )
+            return True
+    return False
+
+
+if _HAS_REGISTER_AT_FORK:
+    # This will run in the same thread that called fork.
+    # If we fork in a critical region on the same thread, it should break.
+ # This is fine since we would never call fork directly from a critical region. + os.register_at_fork(after_in_child=_after_fork_child) diff --git a/venv/Lib/site-packages/pymongo/monitor.py b/venv/Lib/site-packages/pymongo/monitor.py new file mode 100644 index 00000000..64945dd1 --- /dev/null +++ b/venv/Lib/site-packages/pymongo/monitor.py @@ -0,0 +1,485 @@ +# Copyright 2014-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you +# may not use this file except in compliance with the License. You +# may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. + +"""Class to monitor a MongoDB server on a background thread.""" + +from __future__ import annotations + +import atexit +import time +import weakref +from typing import TYPE_CHECKING, Any, Mapping, Optional, cast + +from pymongo import common, periodic_executor +from pymongo._csot import MovingMinimum +from pymongo.errors import NetworkTimeout, NotPrimaryError, OperationFailure, _OperationCancelled +from pymongo.hello import Hello +from pymongo.lock import _create_lock +from pymongo.periodic_executor import _shutdown_executors +from pymongo.pool import _is_faas +from pymongo.read_preferences import MovingAverage +from pymongo.server_description import ServerDescription +from pymongo.srv_resolver import _SrvResolver + +if TYPE_CHECKING: + from pymongo.pool import Connection, Pool, _CancellationContext + from pymongo.settings import TopologySettings + from pymongo.topology import Topology + + +def _sanitize(error: Exception) -> None: + """PYTHON-2433 Clear error traceback info.""" + error.__traceback__ = None + error.__context__ = None + error.__cause__ = None + + +class MonitorBase: + def __init__(self, topology: Topology, name: str, interval: int, min_interval: float): + """Base class to do periodic work on a background thread. + + The background thread is signaled to stop when the Topology or + this instance is freed. + """ + + # We strongly reference the executor and it weakly references us via + # this closure. When the monitor is freed, stop the executor soon. + def target() -> bool: + monitor = self_ref() + if monitor is None: + return False # Stop the executor. + monitor._run() # type:ignore[attr-defined] + return True + + executor = periodic_executor.PeriodicExecutor( + interval=interval, min_interval=min_interval, target=target, name=name + ) + + self._executor = executor + + def _on_topology_gc(dummy: Optional[Topology] = None) -> None: + # This prevents GC from waiting 10 seconds for hello to complete + # See test_cleanup_executors_on_client_del. + monitor = self_ref() + if monitor: + monitor.gc_safe_close() + + # Avoid cycles. When self or topology is freed, stop executor soon. + self_ref = weakref.ref(self, executor.close) + self._topology = weakref.proxy(topology, _on_topology_gc) + _register(self) + + def open(self) -> None: + """Start monitoring, or restart after a fork. + + Multiple calls have no effect. + """ + self._executor.open() + + def gc_safe_close(self) -> None: + """GC safe close.""" + self._executor.close() + + def close(self) -> None: + """Close and stop monitoring. + + open() restarts the monitor after closing. 
+ """ + self.gc_safe_close() + + def join(self, timeout: Optional[int] = None) -> None: + """Wait for the monitor to stop.""" + self._executor.join(timeout) + + def request_check(self) -> None: + """If the monitor is sleeping, wake it soon.""" + self._executor.wake() + + +class Monitor(MonitorBase): + def __init__( + self, + server_description: ServerDescription, + topology: Topology, + pool: Pool, + topology_settings: TopologySettings, + ): + """Class to monitor a MongoDB server on a background thread. + + Pass an initial ServerDescription, a Topology, a Pool, and + TopologySettings. + + The Topology is weakly referenced. The Pool must be exclusive to this + Monitor. + """ + super().__init__( + topology, + "pymongo_server_monitor_thread", + topology_settings.heartbeat_frequency, + common.MIN_HEARTBEAT_INTERVAL, + ) + self._server_description = server_description + self._pool = pool + self._settings = topology_settings + self._listeners = self._settings._pool_options._event_listeners + self._publish = self._listeners is not None and self._listeners.enabled_for_server_heartbeat + self._cancel_context: Optional[_CancellationContext] = None + self._rtt_monitor = _RttMonitor( + topology, + topology_settings, + topology._create_pool_for_monitor(server_description.address), + ) + if topology_settings.server_monitoring_mode == "stream": + self._stream = True + elif topology_settings.server_monitoring_mode == "poll": + self._stream = False + else: + self._stream = not _is_faas() + + def cancel_check(self) -> None: + """Cancel any concurrent hello check. + + Note: this is called from a weakref.proxy callback and MUST NOT take + any locks. + """ + context = self._cancel_context + if context: + # Note: we cannot close the socket because doing so may cause + # concurrent reads/writes to hang until a timeout occurs + # (depending on the platform). + context.cancel() + + def _start_rtt_monitor(self) -> None: + """Start an _RttMonitor that periodically runs ping.""" + # If this monitor is closed directly before (or during) this open() + # call, the _RttMonitor will not be closed. Checking if this monitor + # was closed directly after resolves the race. + self._rtt_monitor.open() + if self._executor._stopped: + self._rtt_monitor.close() + + def gc_safe_close(self) -> None: + self._executor.close() + self._rtt_monitor.gc_safe_close() + self.cancel_check() + + def close(self) -> None: + self.gc_safe_close() + self._rtt_monitor.close() + # Increment the generation and maybe close the socket. If the executor + # thread has the socket checked out, it will be closed when checked in. + self._reset_connection() + + def _reset_connection(self) -> None: + # Clear our pooled connection. + self._pool.reset() + + def _run(self) -> None: + try: + prev_sd = self._server_description + try: + self._server_description = self._check_server() + except _OperationCancelled as exc: + _sanitize(exc) + # Already closed the connection, wait for the next check. + self._server_description = ServerDescription( + self._server_description.address, error=exc + ) + if prev_sd.is_server_type_known: + # Immediately retry since we've already waited 500ms to + # discover that we've been cancelled. + self._executor.skip_sleep() + return + + # Update the Topology and clear the server pool on error. 
+ self._topology.on_change( + self._server_description, + reset_pool=self._server_description.error, + interrupt_connections=isinstance(self._server_description.error, NetworkTimeout), + ) + + if self._stream and ( + self._server_description.is_server_type_known + and self._server_description.topology_version + ): + self._start_rtt_monitor() + # Immediately check for the next streaming response. + self._executor.skip_sleep() + + if self._server_description.error and prev_sd.is_server_type_known: + # Immediately retry on network errors. + self._executor.skip_sleep() + except ReferenceError: + # Topology was garbage-collected. + self.close() + + def _check_server(self) -> ServerDescription: + """Call hello or read the next streaming response. + + Returns a ServerDescription. + """ + start = time.monotonic() + try: + try: + return self._check_once() + except (OperationFailure, NotPrimaryError) as exc: + # Update max cluster time even when hello fails. + details = cast(Mapping[str, Any], exc.details) + self._topology.receive_cluster_time(details.get("$clusterTime")) + raise + except ReferenceError: + raise + except Exception as error: + _sanitize(error) + sd = self._server_description + address = sd.address + duration = time.monotonic() - start + if self._publish: + awaited = bool(self._stream and sd.is_server_type_known and sd.topology_version) + assert self._listeners is not None + self._listeners.publish_server_heartbeat_failed(address, duration, error, awaited) + self._reset_connection() + if isinstance(error, _OperationCancelled): + raise + self._rtt_monitor.reset() + # Server type defaults to Unknown. + return ServerDescription(address, error=error) + + def _check_once(self) -> ServerDescription: + """A single attempt to call hello. + + Returns a ServerDescription, or raises an exception. + """ + address = self._server_description.address + if self._publish: + assert self._listeners is not None + sd = self._server_description + # XXX: "awaited" could be incorrectly set to True in the rare case + # the pool checkout closes and recreates a connection. + awaited = bool( + self._pool.conns + and self._stream + and sd.is_server_type_known + and sd.topology_version + ) + self._listeners.publish_server_heartbeat_started(address, awaited) + + if self._cancel_context and self._cancel_context.cancelled: + self._reset_connection() + with self._pool.checkout() as conn: + self._cancel_context = conn.cancel_context + response, round_trip_time = self._check_with_socket(conn) + if not response.awaitable: + self._rtt_monitor.add_sample(round_trip_time) + + avg_rtt, min_rtt = self._rtt_monitor.get() + sd = ServerDescription(address, response, avg_rtt, min_round_trip_time=min_rtt) + if self._publish: + assert self._listeners is not None + self._listeners.publish_server_heartbeat_succeeded( + address, round_trip_time, response, response.awaitable + ) + return sd + + def _check_with_socket(self, conn: Connection) -> tuple[Hello, float]: + """Return (Hello, round_trip_time). + + Can raise ConnectionFailure or OperationFailure. + """ + cluster_time = self._topology.max_cluster_time() + start = time.monotonic() + if conn.more_to_come: + # Read the next streaming hello (MongoDB 4.4+). + response = Hello(conn._next_reply(), awaitable=True) + elif ( + self._stream and conn.performed_handshake and self._server_description.topology_version + ): + # Initiate streaming hello (MongoDB 4.4+). 
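+            # Sending the last known topologyVersion together with a
+            # maxAwaitTimeMS (the heartbeat frequency) asks the server to hold
+            # its reply until the server's description changes (awaitable hello).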
+            response = conn._hello(
+                cluster_time,
+                self._server_description.topology_version,
+                self._settings.heartbeat_frequency,
+            )
+        else:
+            # New connection handshake or polling hello (MongoDB <4.4).
+            response = conn._hello(cluster_time, None, None)
+        return response, time.monotonic() - start
+
+
+class SrvMonitor(MonitorBase):
+    def __init__(self, topology: Topology, topology_settings: TopologySettings):
+        """Class to poll SRV records on a background thread.
+
+        Pass a Topology and a TopologySettings.
+
+        The Topology is weakly referenced.
+        """
+        super().__init__(
+            topology,
+            "pymongo_srv_polling_thread",
+            common.MIN_SRV_RESCAN_INTERVAL,
+            topology_settings.heartbeat_frequency,
+        )
+        self._settings = topology_settings
+        self._seedlist = self._settings._seeds
+        assert isinstance(self._settings.fqdn, str)
+        self._fqdn: str = self._settings.fqdn
+        self._startup_time = time.monotonic()
+
+    def _run(self) -> None:
+        # Don't poll right after creation; wait 60 seconds first.
+        if time.monotonic() < self._startup_time + common.MIN_SRV_RESCAN_INTERVAL:
+            return
+        seedlist = self._get_seedlist()
+        if seedlist:
+            self._seedlist = seedlist
+            try:
+                self._topology.on_srv_update(self._seedlist)
+            except ReferenceError:
+                # Topology was garbage-collected.
+                self.close()
+
+    def _get_seedlist(self) -> Optional[list[tuple[str, Any]]]:
+        """Poll SRV records for a seedlist.
+
+        Returns a list of (host, port) seeds, or None if the poll failed.
+        """
+        try:
+            resolver = _SrvResolver(
+                self._fqdn,
+                self._settings.pool_options.connect_timeout,
+                self._settings.srv_service_name,
+            )
+            seedlist, ttl = resolver.get_hosts_and_min_ttl()
+            if len(seedlist) == 0:
+                # As per the spec: this should be treated as a failure.
+                raise Exception
+        except Exception:
+            # As per the spec, upon encountering an error:
+            # - An error must not be raised
+            # - SRV records must be rescanned every heartbeatFrequencyMS
+            # - Topology must be left unchanged
+            self.request_check()
+            return None
+        else:
+            self._executor.update_interval(max(ttl, common.MIN_SRV_RESCAN_INTERVAL))
+            return seedlist
+
+
+class _RttMonitor(MonitorBase):
+    def __init__(self, topology: Topology, topology_settings: TopologySettings, pool: Pool):
+        """Maintain round trip times for a server.
+
+        The Topology is weakly referenced.
+        """
+        super().__init__(
+            topology,
+            "pymongo_server_rtt_thread",
+            topology_settings.heartbeat_frequency,
+            common.MIN_HEARTBEAT_INTERVAL,
+        )
+
+        self._pool = pool
+        self._moving_average = MovingAverage()
+        self._moving_min = MovingMinimum()
+        self._lock = _create_lock()
+
+    def close(self) -> None:
+        self.gc_safe_close()
+        # Increment the generation and maybe close the socket. If the executor
+        # thread has the socket checked out, it will be closed when checked in.
+        self._pool.reset()
+
+    def add_sample(self, sample: float) -> None:
+        """Add an RTT sample."""
+        with self._lock:
+            self._moving_average.add_sample(sample)
+            self._moving_min.add_sample(sample)
+
+    def get(self) -> tuple[Optional[float], float]:
+        """Get the moving average RTT (or None if there are no samples yet) and the moving minimum RTT."""
+        with self._lock:
+            return self._moving_average.get(), self._moving_min.get()
+
+    def reset(self) -> None:
+        """Reset the average RTT."""
+        with self._lock:
+            self._moving_average.reset()
+            self._moving_min.reset()
+
+    def _run(self) -> None:
+        try:
+            # NOTE: This thread is only run when using the streaming
+            # heartbeat protocol (MongoDB 4.4+).
+            # XXX: Skip check if the server is unknown?
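+            # Samples recorded here feed the moving average/minimum that
+            # Monitor._check_once attaches to new ServerDescriptions.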
+ rtt = self._ping() + self.add_sample(rtt) + except ReferenceError: + # Topology was garbage-collected. + self.close() + except Exception: + self._pool.reset() + + def _ping(self) -> float: + """Run a "hello" command and return the RTT.""" + with self._pool.checkout() as conn: + if self._executor._stopped: + raise Exception("_RttMonitor closed") + start = time.monotonic() + conn.hello() + return time.monotonic() - start + + +# Close monitors to cancel any in progress streaming checks before joining +# executor threads. For an explanation of how this works see the comment +# about _EXECUTORS in periodic_executor.py. +_MONITORS = set() + + +def _register(monitor: MonitorBase) -> None: + ref = weakref.ref(monitor, _unregister) + _MONITORS.add(ref) + + +def _unregister(monitor_ref: weakref.ReferenceType[MonitorBase]) -> None: + _MONITORS.remove(monitor_ref) + + +def _shutdown_monitors() -> None: + if _MONITORS is None: + return + + # Copy the set. Closing monitors removes them. + monitors = list(_MONITORS) + + # Close all monitors. + for ref in monitors: + monitor = ref() + if monitor: + monitor.gc_safe_close() + + monitor = None + + +def _shutdown_resources() -> None: + # _shutdown_monitors/_shutdown_executors may already be GC'd at shutdown. + shutdown = _shutdown_monitors + if shutdown: # type:ignore[truthy-function] + shutdown() + shutdown = _shutdown_executors + if shutdown: # type:ignore[truthy-function] + shutdown() + + +atexit.register(_shutdown_resources) diff --git a/venv/Lib/site-packages/pymongo/monitoring.py b/venv/Lib/site-packages/pymongo/monitoring.py new file mode 100644 index 00000000..aff11a9f --- /dev/null +++ b/venv/Lib/site-packages/pymongo/monitoring.py @@ -0,0 +1,1916 @@ +# Copyright 2015-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you +# may not use this file except in compliance with the License. You +# may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. + +"""Tools to monitor driver events. + +.. versionadded:: 3.1 + +.. attention:: Starting in PyMongo 3.11, the monitoring classes outlined below + are included in the PyMongo distribution under the + :mod:`~pymongo.event_loggers` submodule. + +Use :func:`register` to register global listeners for specific events. +Listeners must inherit from one of the abstract classes below and implement +the correct functions for that class. 
+ +For example, a simple command logger might be implemented like this:: + + import logging + + from pymongo import monitoring + + class CommandLogger(monitoring.CommandListener): + + def started(self, event): + logging.info("Command {0.command_name} with request id " + "{0.request_id} started on server " + "{0.connection_id}".format(event)) + + def succeeded(self, event): + logging.info("Command {0.command_name} with request id " + "{0.request_id} on server {0.connection_id} " + "succeeded in {0.duration_micros} " + "microseconds".format(event)) + + def failed(self, event): + logging.info("Command {0.command_name} with request id " + "{0.request_id} on server {0.connection_id} " + "failed in {0.duration_micros} " + "microseconds".format(event)) + + monitoring.register(CommandLogger()) + +Server discovery and monitoring events are also available. For example:: + + class ServerLogger(monitoring.ServerListener): + + def opened(self, event): + logging.info("Server {0.server_address} added to topology " + "{0.topology_id}".format(event)) + + def description_changed(self, event): + previous_server_type = event.previous_description.server_type + new_server_type = event.new_description.server_type + if new_server_type != previous_server_type: + # server_type_name was added in PyMongo 3.4 + logging.info( + "Server {0.server_address} changed type from " + "{0.previous_description.server_type_name} to " + "{0.new_description.server_type_name}".format(event)) + + def closed(self, event): + logging.warning("Server {0.server_address} removed from topology " + "{0.topology_id}".format(event)) + + + class HeartbeatLogger(monitoring.ServerHeartbeatListener): + + def started(self, event): + logging.info("Heartbeat sent to server " + "{0.connection_id}".format(event)) + + def succeeded(self, event): + # The reply.document attribute was added in PyMongo 3.4. + logging.info("Heartbeat to server {0.connection_id} " + "succeeded with reply " + "{0.reply.document}".format(event)) + + def failed(self, event): + logging.warning("Heartbeat to server {0.connection_id} " + "failed with error {0.reply}".format(event)) + + class TopologyLogger(monitoring.TopologyListener): + + def opened(self, event): + logging.info("Topology with id {0.topology_id} " + "opened".format(event)) + + def description_changed(self, event): + logging.info("Topology description updated for " + "topology id {0.topology_id}".format(event)) + previous_topology_type = event.previous_description.topology_type + new_topology_type = event.new_description.topology_type + if new_topology_type != previous_topology_type: + # topology_type_name was added in PyMongo 3.4 + logging.info( + "Topology {0.topology_id} changed type from " + "{0.previous_description.topology_type_name} to " + "{0.new_description.topology_type_name}".format(event)) + # The has_writable_server and has_readable_server methods + # were added in PyMongo 3.4. + if not event.new_description.has_writable_server(): + logging.warning("No writable servers available.") + if not event.new_description.has_readable_server(): + logging.warning("No readable servers available.") + + def closed(self, event): + logging.info("Topology with id {0.topology_id} " + "closed".format(event)) + +Connection monitoring and pooling events are also available. 
For example:: + + class ConnectionPoolLogger(ConnectionPoolListener): + + def pool_created(self, event): + logging.info("[pool {0.address}] pool created".format(event)) + + def pool_ready(self, event): + logging.info("[pool {0.address}] pool is ready".format(event)) + + def pool_cleared(self, event): + logging.info("[pool {0.address}] pool cleared".format(event)) + + def pool_closed(self, event): + logging.info("[pool {0.address}] pool closed".format(event)) + + def connection_created(self, event): + logging.info("[pool {0.address}][connection #{0.connection_id}] " + "connection created".format(event)) + + def connection_ready(self, event): + logging.info("[pool {0.address}][connection #{0.connection_id}] " + "connection setup succeeded".format(event)) + + def connection_closed(self, event): + logging.info("[pool {0.address}][connection #{0.connection_id}] " + "connection closed, reason: " + "{0.reason}".format(event)) + + def connection_check_out_started(self, event): + logging.info("[pool {0.address}] connection check out " + "started".format(event)) + + def connection_check_out_failed(self, event): + logging.info("[pool {0.address}] connection check out " + "failed, reason: {0.reason}".format(event)) + + def connection_checked_out(self, event): + logging.info("[pool {0.address}][connection #{0.connection_id}] " + "connection checked out of pool".format(event)) + + def connection_checked_in(self, event): + logging.info("[pool {0.address}][connection #{0.connection_id}] " + "connection checked into pool".format(event)) + + +Event listeners can also be registered per instance of +:class:`~pymongo.mongo_client.MongoClient`:: + + client = MongoClient(event_listeners=[CommandLogger()]) + +Note that previously registered global listeners are automatically included +when configuring per client event listeners. Registering a new global listener +will not add that listener to existing client instances. + +.. note:: Events are delivered **synchronously**. Application threads block + waiting for event handlers (e.g. :meth:`~CommandListener.started`) to + return. Care must be taken to ensure that your event handlers are efficient + enough to not adversely affect overall application performance. + +.. warning:: The command documents published through this API are *not* copies. + If you intend to modify them in any way you must copy them in your event + handler first. +""" + +from __future__ import annotations + +import datetime +from collections import abc, namedtuple +from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence + +from bson.objectid import ObjectId +from pymongo.hello import Hello, HelloCompat +from pymongo.helpers import _handle_exception +from pymongo.typings import _Address, _DocumentOut + +if TYPE_CHECKING: + from datetime import timedelta + + from pymongo.server_description import ServerDescription + from pymongo.topology_description import TopologyDescription + + +_Listeners = namedtuple( + "_Listeners", + ( + "command_listeners", + "server_listeners", + "server_heartbeat_listeners", + "topology_listeners", + "cmap_listeners", + ), +) + +_LISTENERS = _Listeners([], [], [], [], []) + + +class _EventListener: + """Abstract base class for all event listeners.""" + + +class CommandListener(_EventListener): + """Abstract base class for command listeners. + + Handles `CommandStartedEvent`, `CommandSucceededEvent`, + and `CommandFailedEvent`. + """ + + def started(self, event: CommandStartedEvent) -> None: + """Abstract method to handle a `CommandStartedEvent`. 
+
+        :param event: An instance of :class:`CommandStartedEvent`.
+        """
+        raise NotImplementedError
+
+    def succeeded(self, event: CommandSucceededEvent) -> None:
+        """Abstract method to handle a `CommandSucceededEvent`.
+
+        :param event: An instance of :class:`CommandSucceededEvent`.
+        """
+        raise NotImplementedError
+
+    def failed(self, event: CommandFailedEvent) -> None:
+        """Abstract method to handle a `CommandFailedEvent`.
+
+        :param event: An instance of :class:`CommandFailedEvent`.
+        """
+        raise NotImplementedError
+
+
+class ConnectionPoolListener(_EventListener):
+    """Abstract base class for connection pool listeners.
+
+    Handles all of the connection pool events defined in the Connection
+    Monitoring and Pooling Specification:
+    :class:`PoolCreatedEvent`, :class:`PoolReadyEvent`,
+    :class:`PoolClearedEvent`, :class:`PoolClosedEvent`,
+    :class:`ConnectionCreatedEvent`, :class:`ConnectionReadyEvent`,
+    :class:`ConnectionClosedEvent`,
+    :class:`ConnectionCheckOutStartedEvent`,
+    :class:`ConnectionCheckOutFailedEvent`,
+    :class:`ConnectionCheckedOutEvent`,
+    and :class:`ConnectionCheckedInEvent`.
+
+    .. versionadded:: 3.9
+    """
+
+    def pool_created(self, event: PoolCreatedEvent) -> None:
+        """Abstract method to handle a :class:`PoolCreatedEvent`.
+
+        Emitted when a connection Pool is created.
+
+        :param event: An instance of :class:`PoolCreatedEvent`.
+        """
+        raise NotImplementedError
+
+    def pool_ready(self, event: PoolReadyEvent) -> None:
+        """Abstract method to handle a :class:`PoolReadyEvent`.
+
+        Emitted when a connection Pool is marked ready.
+
+        :param event: An instance of :class:`PoolReadyEvent`.
+
+        .. versionadded:: 4.0
+        """
+        raise NotImplementedError
+
+    def pool_cleared(self, event: PoolClearedEvent) -> None:
+        """Abstract method to handle a `PoolClearedEvent`.
+
+        Emitted when a connection Pool is cleared.
+
+        :param event: An instance of :class:`PoolClearedEvent`.
+        """
+        raise NotImplementedError
+
+    def pool_closed(self, event: PoolClosedEvent) -> None:
+        """Abstract method to handle a `PoolClosedEvent`.
+
+        Emitted when a connection Pool is closed.
+
+        :param event: An instance of :class:`PoolClosedEvent`.
+        """
+        raise NotImplementedError
+
+    def connection_created(self, event: ConnectionCreatedEvent) -> None:
+        """Abstract method to handle a :class:`ConnectionCreatedEvent`.
+
+        Emitted when a connection Pool creates a Connection object.
+
+        :param event: An instance of :class:`ConnectionCreatedEvent`.
+        """
+        raise NotImplementedError
+
+    def connection_ready(self, event: ConnectionReadyEvent) -> None:
+        """Abstract method to handle a :class:`ConnectionReadyEvent`.
+
+        Emitted when a connection has finished its setup, and is now ready to
+        use.
+
+        :param event: An instance of :class:`ConnectionReadyEvent`.
+        """
+        raise NotImplementedError
+
+    def connection_closed(self, event: ConnectionClosedEvent) -> None:
+        """Abstract method to handle a :class:`ConnectionClosedEvent`.
+
+        Emitted when a connection Pool closes a connection.
+
+        :param event: An instance of :class:`ConnectionClosedEvent`.
+        """
+        raise NotImplementedError
+
+    def connection_check_out_started(self, event: ConnectionCheckOutStartedEvent) -> None:
+        """Abstract method to handle a :class:`ConnectionCheckOutStartedEvent`.
+
+        Emitted when the driver starts attempting to check out a connection.
+
+        :param event: An instance of :class:`ConnectionCheckOutStartedEvent`.
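+
+        Every check out that starts is concluded by either a
+        :class:`ConnectionCheckedOutEvent` or a
+        :class:`ConnectionCheckOutFailedEvent`.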
+ """ + raise NotImplementedError + + def connection_check_out_failed(self, event: ConnectionCheckOutFailedEvent) -> None: + """Abstract method to handle a :class:`ConnectionCheckOutFailedEvent`. + + Emitted when the driver's attempt to check out a connection fails. + + :param event: An instance of :class:`ConnectionCheckOutFailedEvent`. + """ + raise NotImplementedError + + def connection_checked_out(self, event: ConnectionCheckedOutEvent) -> None: + """Abstract method to handle a :class:`ConnectionCheckedOutEvent`. + + Emitted when the driver successfully checks out a connection. + + :param event: An instance of :class:`ConnectionCheckedOutEvent`. + """ + raise NotImplementedError + + def connection_checked_in(self, event: ConnectionCheckedInEvent) -> None: + """Abstract method to handle a :class:`ConnectionCheckedInEvent`. + + Emitted when the driver checks in a connection back to the connection + Pool. + + :param event: An instance of :class:`ConnectionCheckedInEvent`. + """ + raise NotImplementedError + + +class ServerHeartbeatListener(_EventListener): + """Abstract base class for server heartbeat listeners. + + Handles `ServerHeartbeatStartedEvent`, `ServerHeartbeatSucceededEvent`, + and `ServerHeartbeatFailedEvent`. + + .. versionadded:: 3.3 + """ + + def started(self, event: ServerHeartbeatStartedEvent) -> None: + """Abstract method to handle a `ServerHeartbeatStartedEvent`. + + :param event: An instance of :class:`ServerHeartbeatStartedEvent`. + """ + raise NotImplementedError + + def succeeded(self, event: ServerHeartbeatSucceededEvent) -> None: + """Abstract method to handle a `ServerHeartbeatSucceededEvent`. + + :param event: An instance of :class:`ServerHeartbeatSucceededEvent`. + """ + raise NotImplementedError + + def failed(self, event: ServerHeartbeatFailedEvent) -> None: + """Abstract method to handle a `ServerHeartbeatFailedEvent`. + + :param event: An instance of :class:`ServerHeartbeatFailedEvent`. + """ + raise NotImplementedError + + +class TopologyListener(_EventListener): + """Abstract base class for topology monitoring listeners. + Handles `TopologyOpenedEvent`, `TopologyDescriptionChangedEvent`, and + `TopologyClosedEvent`. + + .. versionadded:: 3.3 + """ + + def opened(self, event: TopologyOpenedEvent) -> None: + """Abstract method to handle a `TopologyOpenedEvent`. + + :param event: An instance of :class:`TopologyOpenedEvent`. + """ + raise NotImplementedError + + def description_changed(self, event: TopologyDescriptionChangedEvent) -> None: + """Abstract method to handle a `TopologyDescriptionChangedEvent`. + + :param event: An instance of :class:`TopologyDescriptionChangedEvent`. + """ + raise NotImplementedError + + def closed(self, event: TopologyClosedEvent) -> None: + """Abstract method to handle a `TopologyClosedEvent`. + + :param event: An instance of :class:`TopologyClosedEvent`. + """ + raise NotImplementedError + + +class ServerListener(_EventListener): + """Abstract base class for server listeners. + Handles `ServerOpeningEvent`, `ServerDescriptionChangedEvent`, and + `ServerClosedEvent`. + + .. versionadded:: 3.3 + """ + + def opened(self, event: ServerOpeningEvent) -> None: + """Abstract method to handle a `ServerOpeningEvent`. + + :param event: An instance of :class:`ServerOpeningEvent`. + """ + raise NotImplementedError + + def description_changed(self, event: ServerDescriptionChangedEvent) -> None: + """Abstract method to handle a `ServerDescriptionChangedEvent`. + + :param event: An instance of :class:`ServerDescriptionChangedEvent`. 
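+
+        See the ``ServerLogger`` example in the module docstring for a
+        typical implementation.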
+        """
+        raise NotImplementedError
+
+    def closed(self, event: ServerClosedEvent) -> None:
+        """Abstract method to handle a `ServerClosedEvent`.
+
+        :param event: An instance of :class:`ServerClosedEvent`.
+        """
+        raise NotImplementedError
+
+
+def _to_micros(dur: timedelta) -> int:
+    """Convert duration 'dur' to microseconds."""
+    return int(dur.total_seconds() * 1e6)
+
+
+def _validate_event_listeners(
+    option: str, listeners: Sequence[_EventListener]
+) -> Sequence[_EventListener]:
+    """Validate event listeners"""
+    if not isinstance(listeners, abc.Sequence):
+        raise TypeError(f"{option} must be a list or tuple")
+    for listener in listeners:
+        if not isinstance(listener, _EventListener):
+            raise TypeError(
+                f"Listeners for {option} must be a "
+                "CommandListener, ServerHeartbeatListener, "
+                "ServerListener, TopologyListener, or "
+                "ConnectionPoolListener."
+            )
+    return listeners
+
+
+def register(listener: _EventListener) -> None:
+    """Register a global event listener.
+
+    :param listener: A subclass of :class:`CommandListener`,
+        :class:`ServerHeartbeatListener`, :class:`ServerListener`,
+        :class:`TopologyListener`, or :class:`ConnectionPoolListener`.
+    """
+    if not isinstance(listener, _EventListener):
+        raise TypeError(
+            f"Listener {listener!r} must be a "
+            "CommandListener, ServerHeartbeatListener, "
+            "ServerListener, TopologyListener, or "
+            "ConnectionPoolListener."
+        )
+    if isinstance(listener, CommandListener):
+        _LISTENERS.command_listeners.append(listener)
+    if isinstance(listener, ServerHeartbeatListener):
+        _LISTENERS.server_heartbeat_listeners.append(listener)
+    if isinstance(listener, ServerListener):
+        _LISTENERS.server_listeners.append(listener)
+    if isinstance(listener, TopologyListener):
+        _LISTENERS.topology_listeners.append(listener)
+    if isinstance(listener, ConnectionPoolListener):
+        _LISTENERS.cmap_listeners.append(listener)
+
+
+# Note - to avoid bugs from forgetting which of these are all lowercase and
+# which are camelCase, and at the same time avoid having to add a test for
+# every command, use all lowercase here and test against command_name.lower().
+_SENSITIVE_COMMANDS: set = {
+    "authenticate",
+    "saslstart",
+    "saslcontinue",
+    "getnonce",
+    "createuser",
+    "updateuser",
+    "copydbgetnonce",
+    "copydbsaslstart",
+    "copydb",
+}
+
+
+# The "hello" command is also deemed sensitive when attempting speculative
+# authentication.
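+# For example, a hello command shaped like {"hello": 1,
+# "speculativeAuthenticate": {...}} (an illustrative sketch, not a complete
+# document) has its command and reply redacted just like the commands above.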
+def _is_speculative_authenticate(command_name: str, doc: Mapping[str, Any]) -> bool: + if ( + command_name.lower() in ("hello", HelloCompat.LEGACY_CMD) + and "speculativeAuthenticate" in doc + ): + return True + return False + + +class _CommandEvent: + """Base class for command events.""" + + __slots__ = ( + "__cmd_name", + "__rqst_id", + "__conn_id", + "__op_id", + "__service_id", + "__db", + "__server_conn_id", + ) + + def __init__( + self, + command_name: str, + request_id: int, + connection_id: _Address, + operation_id: Optional[int], + service_id: Optional[ObjectId] = None, + database_name: str = "", + server_connection_id: Optional[int] = None, + ) -> None: + self.__cmd_name = command_name + self.__rqst_id = request_id + self.__conn_id = connection_id + self.__op_id = operation_id + self.__service_id = service_id + self.__db = database_name + self.__server_conn_id = server_connection_id + + @property + def command_name(self) -> str: + """The command name.""" + return self.__cmd_name + + @property + def request_id(self) -> int: + """The request id for this operation.""" + return self.__rqst_id + + @property + def connection_id(self) -> _Address: + """The address (host, port) of the server this command was sent to.""" + return self.__conn_id + + @property + def service_id(self) -> Optional[ObjectId]: + """The service_id this command was sent to, or ``None``. + + .. versionadded:: 3.12 + """ + return self.__service_id + + @property + def operation_id(self) -> Optional[int]: + """An id for this series of events or None.""" + return self.__op_id + + @property + def database_name(self) -> str: + """The database_name this command was sent to, or ``""``. + + .. versionadded:: 4.6 + """ + return self.__db + + @property + def server_connection_id(self) -> Optional[int]: + """The server-side connection id for the connection this command was sent on, or ``None``. + + .. versionadded:: 4.7 + """ + return self.__server_conn_id + + +class CommandStartedEvent(_CommandEvent): + """Event published when a command starts. + + :param command: The command document. + :param database_name: The name of the database this command was run against. + :param request_id: The request id for this operation. + :param connection_id: The address (host, port) of the server this command + was sent to. + :param operation_id: An optional identifier for a series of related events. + :param service_id: The service_id this command was sent to, or ``None``. + """ + + __slots__ = ("__cmd",) + + def __init__( + self, + command: _DocumentOut, + database_name: str, + request_id: int, + connection_id: _Address, + operation_id: Optional[int], + service_id: Optional[ObjectId] = None, + server_connection_id: Optional[int] = None, + ) -> None: + if not command: + raise ValueError(f"{command!r} is not a valid command") + # Command name must be first key. 
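+        # (dict, SON, and RawBSONDocument all preserve key order, so
+        # next(iter(command)) reliably yields the command name.)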
+        command_name = next(iter(command))
+        super().__init__(
+            command_name,
+            request_id,
+            connection_id,
+            operation_id,
+            service_id=service_id,
+            database_name=database_name,
+            server_connection_id=server_connection_id,
+        )
+        cmd_name = command_name.lower()
+        if cmd_name in _SENSITIVE_COMMANDS or _is_speculative_authenticate(cmd_name, command):
+            self.__cmd: _DocumentOut = {}
+        else:
+            self.__cmd = command
+
+    @property
+    def command(self) -> _DocumentOut:
+        """The command document."""
+        return self.__cmd
+
+    @property
+    def database_name(self) -> str:
+        """The name of the database this command was run against."""
+        return super().database_name
+
+    def __repr__(self) -> str:
+        return (
+            "<{} {} db: {!r}, command: {!r}, operation_id: {}, service_id: {}, server_connection_id: {}>"
+        ).format(
+            self.__class__.__name__,
+            self.connection_id,
+            self.database_name,
+            self.command_name,
+            self.operation_id,
+            self.service_id,
+            self.server_connection_id,
+        )
+
+
+class CommandSucceededEvent(_CommandEvent):
+    """Event published when a command succeeds.
+
+    :param duration: The command duration as a datetime.timedelta.
+    :param reply: The server reply document.
+    :param command_name: The command name.
+    :param request_id: The request id for this operation.
+    :param connection_id: The address (host, port) of the server this command
+        was sent to.
+    :param operation_id: An optional identifier for a series of related events.
+    :param service_id: The service_id this command was sent to, or ``None``.
+    :param database_name: The database this command was sent to, or ``""``.
+    """
+
+    __slots__ = ("__duration_micros", "__reply")
+
+    def __init__(
+        self,
+        duration: datetime.timedelta,
+        reply: _DocumentOut,
+        command_name: str,
+        request_id: int,
+        connection_id: _Address,
+        operation_id: Optional[int],
+        service_id: Optional[ObjectId] = None,
+        database_name: str = "",
+        server_connection_id: Optional[int] = None,
+    ) -> None:
+        super().__init__(
+            command_name,
+            request_id,
+            connection_id,
+            operation_id,
+            service_id=service_id,
+            database_name=database_name,
+            server_connection_id=server_connection_id,
+        )
+        self.__duration_micros = _to_micros(duration)
+        cmd_name = command_name.lower()
+        if cmd_name in _SENSITIVE_COMMANDS or _is_speculative_authenticate(cmd_name, reply):
+            self.__reply: _DocumentOut = {}
+        else:
+            self.__reply = reply
+
+    @property
+    def duration_micros(self) -> int:
+        """The duration of this operation in microseconds."""
+        return self.__duration_micros
+
+    @property
+    def reply(self) -> _DocumentOut:
+        """The server reply document for this operation."""
+        return self.__reply
+
+    def __repr__(self) -> str:
+        return (
+            "<{} {} db: {!r}, command: {!r}, operation_id: {}, duration_micros: {}, service_id: {}, server_connection_id: {}>"
+        ).format(
+            self.__class__.__name__,
+            self.connection_id,
+            self.database_name,
+            self.command_name,
+            self.operation_id,
+            self.duration_micros,
+            self.service_id,
+            self.server_connection_id,
+        )
+
+
+class CommandFailedEvent(_CommandEvent):
+    """Event published when a command fails.
+
+    :param duration: The command duration as a datetime.timedelta.
+    :param failure: The server reply document or failure description
+        document.
+    :param command_name: The command name.
+    :param request_id: The request id for this operation.
+    :param connection_id: The address (host, port) of the server this command
+        was sent to.
+    :param operation_id: An optional identifier for a series of related events.
+    :param service_id: The service_id this command was sent to, or ``None``.
+    :param database_name: The database this command was sent to, or ``""``.
+    """
+
+    __slots__ = ("__duration_micros", "__failure")
+
+    def __init__(
+        self,
+        duration: datetime.timedelta,
+        failure: _DocumentOut,
+        command_name: str,
+        request_id: int,
+        connection_id: _Address,
+        operation_id: Optional[int],
+        service_id: Optional[ObjectId] = None,
+        database_name: str = "",
+        server_connection_id: Optional[int] = None,
+    ) -> None:
+        super().__init__(
+            command_name,
+            request_id,
+            connection_id,
+            operation_id,
+            service_id=service_id,
+            database_name=database_name,
+            server_connection_id=server_connection_id,
+        )
+        self.__duration_micros = _to_micros(duration)
+        self.__failure = failure
+
+    @property
+    def duration_micros(self) -> int:
+        """The duration of this operation in microseconds."""
+        return self.__duration_micros
+
+    @property
+    def failure(self) -> _DocumentOut:
+        """The server failure document for this operation."""
+        return self.__failure
+
+    def __repr__(self) -> str:
+        return (
+            "<{} {} db: {!r}, command: {!r}, operation_id: {}, duration_micros: {}, "
+            "failure: {!r}, service_id: {}, server_connection_id: {}>"
+        ).format(
+            self.__class__.__name__,
+            self.connection_id,
+            self.database_name,
+            self.command_name,
+            self.operation_id,
+            self.duration_micros,
+            self.failure,
+            self.service_id,
+            self.server_connection_id,
+        )
+
+
+class _PoolEvent:
+    """Base class for pool events."""
+
+    __slots__ = ("__address",)
+
+    def __init__(self, address: _Address) -> None:
+        self.__address = address
+
+    @property
+    def address(self) -> _Address:
+        """The address (host, port) pair of the server the pool is attempting
+        to connect to.
+        """
+        return self.__address
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}({self.__address!r})"
+
+
+class PoolCreatedEvent(_PoolEvent):
+    """Published when a Connection Pool is created.
+
+    :param address: The address (host, port) pair of the server this Pool is
+        attempting to connect to.
+
+    .. versionadded:: 3.9
+    """
+
+    __slots__ = ("__options",)
+
+    def __init__(self, address: _Address, options: dict[str, Any]) -> None:
+        super().__init__(address)
+        self.__options = options
+
+    @property
+    def options(self) -> dict[str, Any]:
+        """Any non-default pool options that were set on this Connection Pool."""
+        return self.__options
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}({self.address!r}, {self.__options!r})"
+
+
+class PoolReadyEvent(_PoolEvent):
+    """Published when a Connection Pool is marked ready.
+
+    :param address: The address (host, port) pair of the server this Pool is
+        attempting to connect to.
+
+    .. versionadded:: 4.0
+    """
+
+    __slots__ = ()
+
+
+class PoolClearedEvent(_PoolEvent):
+    """Published when a Connection Pool is cleared.
+
+    :param address: The address (host, port) pair of the server this Pool is
+        attempting to connect to.
+    :param service_id: The service_id of the connections to clear, or ``None``
+        to clear all connections in the pool.
+    :param interrupt_connections: True if all active connections were interrupted by the Pool during clearing.
+
+    .. versionadded:: 3.9
+    """
+
+    __slots__ = ("__service_id", "__interrupt_connections")
+
+    def __init__(
+        self,
+        address: _Address,
+        service_id: Optional[ObjectId] = None,
+        interrupt_connections: bool = False,
+    ) -> None:
+        super().__init__(address)
+        self.__service_id = service_id
+        self.__interrupt_connections = interrupt_connections
+
+    @property
+    def service_id(self) -> Optional[ObjectId]:
+        """Connections with this service_id are cleared.
+ + When service_id is ``None``, all connections in the pool are cleared. + + .. versionadded:: 3.12 + """ + return self.__service_id + + @property + def interrupt_connections(self) -> bool: + """If True, active connections are interrupted during clearing. + + .. versionadded:: 4.7 + """ + return self.__interrupt_connections + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.address!r}, {self.__service_id!r}, {self.__interrupt_connections!r})" + + +class PoolClosedEvent(_PoolEvent): + """Published when a Connection Pool is closed. + + :param address: The address (host, port) pair of the server this Pool is + attempting to connect to. + + .. versionadded:: 3.9 + """ + + __slots__ = () + + +class ConnectionClosedReason: + """An enum that defines values for `reason` on a + :class:`ConnectionClosedEvent`. + + .. versionadded:: 3.9 + """ + + STALE = "stale" + """The pool was cleared, making the connection no longer valid.""" + + IDLE = "idle" + """The connection became stale by being idle for too long (maxIdleTimeMS). + """ + + ERROR = "error" + """The connection experienced an error, making it no longer valid.""" + + POOL_CLOSED = "poolClosed" + """The pool was closed, making the connection no longer valid.""" + + +class ConnectionCheckOutFailedReason: + """An enum that defines values for `reason` on a + :class:`ConnectionCheckOutFailedEvent`. + + .. versionadded:: 3.9 + """ + + TIMEOUT = "timeout" + """The connection check out attempt exceeded the specified timeout.""" + + POOL_CLOSED = "poolClosed" + """The pool was previously closed, and cannot provide new connections.""" + + CONN_ERROR = "connectionError" + """The connection check out attempt experienced an error while setting up + a new connection. + """ + + +class _ConnectionEvent: + """Private base class for connection events.""" + + __slots__ = ("__address",) + + def __init__(self, address: _Address) -> None: + self.__address = address + + @property + def address(self) -> _Address: + """The address (host, port) pair of the server this connection is + attempting to connect to. + """ + return self.__address + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.__address!r})" + + +class _ConnectionIdEvent(_ConnectionEvent): + """Private base class for connection events with an id.""" + + __slots__ = ("__connection_id",) + + def __init__(self, address: _Address, connection_id: int) -> None: + super().__init__(address) + self.__connection_id = connection_id + + @property + def connection_id(self) -> int: + """The ID of the connection.""" + return self.__connection_id + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.address!r}, {self.__connection_id!r})" + + +class _ConnectionDurationEvent(_ConnectionIdEvent): + """Private base class for connection events with a duration.""" + + __slots__ = ("__duration",) + + def __init__(self, address: _Address, connection_id: int, duration: Optional[float]) -> None: + super().__init__(address, connection_id) + self.__duration = duration + + @property + def duration(self) -> Optional[float]: + """The duration of the connection event. + + .. versionadded:: 4.7 + """ + return self.__duration + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.address!r}, {self.connection_id!r}, {self.__duration!r})" + + +class ConnectionCreatedEvent(_ConnectionIdEvent): + """Published when a Connection Pool creates a Connection object. + + NOTE: This connection is not ready for use until the + :class:`ConnectionReadyEvent` is published. 
+ + :param address: The address (host, port) pair of the server this + Connection is attempting to connect to. + :param connection_id: The integer ID of the Connection in this Pool. + + .. versionadded:: 3.9 + """ + + __slots__ = () + + +class ConnectionReadyEvent(_ConnectionDurationEvent): + """Published when a Connection has finished its setup, and is ready to use. + + :param address: The address (host, port) pair of the server this + Connection is attempting to connect to. + :param connection_id: The integer ID of the Connection in this Pool. + + .. versionadded:: 3.9 + """ + + __slots__ = () + + +class ConnectionClosedEvent(_ConnectionIdEvent): + """Published when a Connection is closed. + + :param address: The address (host, port) pair of the server this + Connection is attempting to connect to. + :param connection_id: The integer ID of the Connection in this Pool. + :param reason: A reason explaining why this connection was closed. + + .. versionadded:: 3.9 + """ + + __slots__ = ("__reason",) + + def __init__(self, address: _Address, connection_id: int, reason: str): + super().__init__(address, connection_id) + self.__reason = reason + + @property + def reason(self) -> str: + """A reason explaining why this connection was closed. + + The reason must be one of the strings from the + :class:`ConnectionClosedReason` enum. + """ + return self.__reason + + def __repr__(self) -> str: + return "{}({!r}, {!r}, {!r})".format( + self.__class__.__name__, + self.address, + self.connection_id, + self.__reason, + ) + + +class ConnectionCheckOutStartedEvent(_ConnectionEvent): + """Published when the driver starts attempting to check out a connection. + + :param address: The address (host, port) pair of the server this + Connection is attempting to connect to. + + .. versionadded:: 3.9 + """ + + __slots__ = () + + +class ConnectionCheckOutFailedEvent(_ConnectionDurationEvent): + """Published when the driver's attempt to check out a connection fails. + + :param address: The address (host, port) pair of the server this + Connection is attempting to connect to. + :param reason: A reason explaining why connection check out failed. + + .. versionadded:: 3.9 + """ + + __slots__ = ("__reason",) + + def __init__(self, address: _Address, reason: str, duration: Optional[float]) -> None: + super().__init__(address=address, connection_id=0, duration=duration) + self.__reason = reason + + @property + def reason(self) -> str: + """A reason explaining why connection check out failed. + + The reason must be one of the strings from the + :class:`ConnectionCheckOutFailedReason` enum. + """ + return self.__reason + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.address!r}, {self.__reason!r}, {self.duration!r})" + + +class ConnectionCheckedOutEvent(_ConnectionDurationEvent): + """Published when the driver successfully checks out a connection. + + :param address: The address (host, port) pair of the server this + Connection is attempting to connect to. + :param connection_id: The integer ID of the Connection in this Pool. + + .. versionadded:: 3.9 + """ + + __slots__ = () + + +class ConnectionCheckedInEvent(_ConnectionIdEvent): + """Published when the driver checks in a Connection into the Pool. + + :param address: The address (host, port) pair of the server this + Connection is attempting to connect to. + :param connection_id: The integer ID of the Connection in this Pool. + + .. 
versionadded:: 3.9 + """ + + __slots__ = () + + +class _ServerEvent: + """Base class for server events.""" + + __slots__ = ("__server_address", "__topology_id") + + def __init__(self, server_address: _Address, topology_id: ObjectId) -> None: + self.__server_address = server_address + self.__topology_id = topology_id + + @property + def server_address(self) -> _Address: + """The address (host, port) pair of the server""" + return self.__server_address + + @property + def topology_id(self) -> ObjectId: + """A unique identifier for the topology this server is a part of.""" + return self.__topology_id + + def __repr__(self) -> str: + return f"<{self.__class__.__name__} {self.server_address} topology_id: {self.topology_id}>" + + +class ServerDescriptionChangedEvent(_ServerEvent): + """Published when server description changes. + + .. versionadded:: 3.3 + """ + + __slots__ = ("__previous_description", "__new_description") + + def __init__( + self, + previous_description: ServerDescription, + new_description: ServerDescription, + *args: Any, + ) -> None: + super().__init__(*args) + self.__previous_description = previous_description + self.__new_description = new_description + + @property + def previous_description(self) -> ServerDescription: + """The previous + :class:`~pymongo.server_description.ServerDescription`. + """ + return self.__previous_description + + @property + def new_description(self) -> ServerDescription: + """The new + :class:`~pymongo.server_description.ServerDescription`. + """ + return self.__new_description + + def __repr__(self) -> str: + return "<{} {} changed from: {}, to: {}>".format( + self.__class__.__name__, + self.server_address, + self.previous_description, + self.new_description, + ) + + +class ServerOpeningEvent(_ServerEvent): + """Published when server is initialized. + + .. versionadded:: 3.3 + """ + + __slots__ = () + + +class ServerClosedEvent(_ServerEvent): + """Published when server is closed. + + .. versionadded:: 3.3 + """ + + __slots__ = () + + +class TopologyEvent: + """Base class for topology description events.""" + + __slots__ = ("__topology_id",) + + def __init__(self, topology_id: ObjectId) -> None: + self.__topology_id = topology_id + + @property + def topology_id(self) -> ObjectId: + """A unique identifier for the topology this server is a part of.""" + return self.__topology_id + + def __repr__(self) -> str: + return f"<{self.__class__.__name__} topology_id: {self.topology_id}>" + + +class TopologyDescriptionChangedEvent(TopologyEvent): + """Published when the topology description changes. + + .. versionadded:: 3.3 + """ + + __slots__ = ("__previous_description", "__new_description") + + def __init__( + self, + previous_description: TopologyDescription, + new_description: TopologyDescription, + *args: Any, + ) -> None: + super().__init__(*args) + self.__previous_description = previous_description + self.__new_description = new_description + + @property + def previous_description(self) -> TopologyDescription: + """The previous + :class:`~pymongo.topology_description.TopologyDescription`. + """ + return self.__previous_description + + @property + def new_description(self) -> TopologyDescription: + """The new + :class:`~pymongo.topology_description.TopologyDescription`. 
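+
+        The new description can be inspected with helpers such as
+        ``has_writable_server()`` and ``has_readable_server()``, as shown by
+        the ``TopologyLogger`` example in the module docstring.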
+ """ + return self.__new_description + + def __repr__(self) -> str: + return "<{} topology_id: {} changed from: {}, to: {}>".format( + self.__class__.__name__, + self.topology_id, + self.previous_description, + self.new_description, + ) + + +class TopologyOpenedEvent(TopologyEvent): + """Published when the topology is initialized. + + .. versionadded:: 3.3 + """ + + __slots__ = () + + +class TopologyClosedEvent(TopologyEvent): + """Published when the topology is closed. + + .. versionadded:: 3.3 + """ + + __slots__ = () + + +class _ServerHeartbeatEvent: + """Base class for server heartbeat events.""" + + __slots__ = ("__connection_id", "__awaited") + + def __init__(self, connection_id: _Address, awaited: bool = False) -> None: + self.__connection_id = connection_id + self.__awaited = awaited + + @property + def connection_id(self) -> _Address: + """The address (host, port) of the server this heartbeat was sent + to. + """ + return self.__connection_id + + @property + def awaited(self) -> bool: + """Whether the heartbeat was issued as an awaitable hello command. + + .. versionadded:: 4.6 + """ + return self.__awaited + + def __repr__(self) -> str: + return f"<{self.__class__.__name__} {self.connection_id} awaited: {self.awaited}>" + + +class ServerHeartbeatStartedEvent(_ServerHeartbeatEvent): + """Published when a heartbeat is started. + + .. versionadded:: 3.3 + """ + + __slots__ = () + + +class ServerHeartbeatSucceededEvent(_ServerHeartbeatEvent): + """Fired when the server heartbeat succeeds. + + .. versionadded:: 3.3 + """ + + __slots__ = ("__duration", "__reply") + + def __init__( + self, duration: float, reply: Hello, connection_id: _Address, awaited: bool = False + ) -> None: + super().__init__(connection_id, awaited) + self.__duration = duration + self.__reply = reply + + @property + def duration(self) -> float: + """The duration of this heartbeat in microseconds.""" + return self.__duration + + @property + def reply(self) -> Hello: + """An instance of :class:`~pymongo.hello.Hello`.""" + return self.__reply + + @property + def awaited(self) -> bool: + """Whether the heartbeat was awaited. + + If true, then :meth:`duration` reflects the sum of the round trip time + to the server and the time that the server waited before sending a + response. + + .. versionadded:: 3.11 + """ + return super().awaited + + def __repr__(self) -> str: + return "<{} {} duration: {}, awaited: {}, reply: {}>".format( + self.__class__.__name__, + self.connection_id, + self.duration, + self.awaited, + self.reply, + ) + + +class ServerHeartbeatFailedEvent(_ServerHeartbeatEvent): + """Fired when the server heartbeat fails, either with an "ok: 0" + or a socket exception. + + .. versionadded:: 3.3 + """ + + __slots__ = ("__duration", "__reply") + + def __init__( + self, duration: float, reply: Exception, connection_id: _Address, awaited: bool = False + ) -> None: + super().__init__(connection_id, awaited) + self.__duration = duration + self.__reply = reply + + @property + def duration(self) -> float: + """The duration of this heartbeat in microseconds.""" + return self.__duration + + @property + def reply(self) -> Exception: + """A subclass of :exc:`Exception`.""" + return self.__reply + + @property + def awaited(self) -> bool: + """Whether the heartbeat was awaited. + + If true, then :meth:`duration` reflects the sum of the round trip time + to the server and the time that the server waited before sending a + response. + + .. 
versionadded:: 3.11
+        """
+        return super().awaited
+
+    def __repr__(self) -> str:
+        return "<{} {} duration: {}, awaited: {}, reply: {!r}>".format(
+            self.__class__.__name__,
+            self.connection_id,
+            self.duration,
+            self.awaited,
+            self.reply,
+        )
+
+
+class _EventListeners:
+    """Configure event listeners for a client instance.
+
+    Any event listeners registered globally are included by default.
+
+    :param listeners: A list of event listeners.
+    """
+
+    def __init__(self, listeners: Optional[Sequence[_EventListener]]):
+        self.__command_listeners = _LISTENERS.command_listeners[:]
+        self.__server_listeners = _LISTENERS.server_listeners[:]
+        self.__server_heartbeat_listeners = _LISTENERS.server_heartbeat_listeners[:]
+        self.__topology_listeners = _LISTENERS.topology_listeners[:]
+        self.__cmap_listeners = _LISTENERS.cmap_listeners[:]
+        if listeners is not None:
+            for listener in listeners:
+                if isinstance(listener, CommandListener):
+                    self.__command_listeners.append(listener)
+                if isinstance(listener, ServerListener):
+                    self.__server_listeners.append(listener)
+                if isinstance(listener, ServerHeartbeatListener):
+                    self.__server_heartbeat_listeners.append(listener)
+                if isinstance(listener, TopologyListener):
+                    self.__topology_listeners.append(listener)
+                if isinstance(listener, ConnectionPoolListener):
+                    self.__cmap_listeners.append(listener)
+        self.__enabled_for_commands = bool(self.__command_listeners)
+        self.__enabled_for_server = bool(self.__server_listeners)
+        self.__enabled_for_server_heartbeat = bool(self.__server_heartbeat_listeners)
+        self.__enabled_for_topology = bool(self.__topology_listeners)
+        self.__enabled_for_cmap = bool(self.__cmap_listeners)
+
+    @property
+    def enabled_for_commands(self) -> bool:
+        """Are any CommandListener instances registered?"""
+        return self.__enabled_for_commands
+
+    @property
+    def enabled_for_server(self) -> bool:
+        """Are any ServerListener instances registered?"""
+        return self.__enabled_for_server
+
+    @property
+    def enabled_for_server_heartbeat(self) -> bool:
+        """Are any ServerHeartbeatListener instances registered?"""
+        return self.__enabled_for_server_heartbeat
+
+    @property
+    def enabled_for_topology(self) -> bool:
+        """Are any TopologyListener instances registered?"""
+        return self.__enabled_for_topology
+
+    @property
+    def enabled_for_cmap(self) -> bool:
+        """Are any ConnectionPoolListener instances registered?"""
+        return self.__enabled_for_cmap
+
+    def event_listeners(self) -> list[_EventListener]:
+        """List of registered event listeners."""
+        return (
+            self.__command_listeners
+            + self.__server_heartbeat_listeners
+            + self.__server_listeners
+            + self.__topology_listeners
+            + self.__cmap_listeners
+        )
+
+    def publish_command_start(
+        self,
+        command: _DocumentOut,
+        database_name: str,
+        request_id: int,
+        connection_id: _Address,
+        server_connection_id: Optional[int],
+        op_id: Optional[int] = None,
+        service_id: Optional[ObjectId] = None,
+    ) -> None:
+        """Publish a CommandStartedEvent to all command listeners.
+
+        :param command: The command document.
+        :param database_name: The name of the database this command was run
+            against.
+        :param request_id: The request id for this operation.
+        :param connection_id: The address (host, port) of the server this
+            command was sent to.
+        :param server_connection_id: The server-side connection id for the
+            connection this command was sent on, or ``None``.
+        :param op_id: The (optional) operation id for this operation.
+        :param service_id: The service_id this command was sent to, or ``None``.
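+
+        Listener exceptions are caught and passed to ``_handle_exception``
+        so that a misbehaving listener cannot break the operation being
+        monitored.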
+ """ + if op_id is None: + op_id = request_id + event = CommandStartedEvent( + command, + database_name, + request_id, + connection_id, + op_id, + service_id=service_id, + server_connection_id=server_connection_id, + ) + for subscriber in self.__command_listeners: + try: + subscriber.started(event) + except Exception: + _handle_exception() + + def publish_command_success( + self, + duration: timedelta, + reply: _DocumentOut, + command_name: str, + request_id: int, + connection_id: _Address, + server_connection_id: Optional[int], + op_id: Optional[int] = None, + service_id: Optional[ObjectId] = None, + speculative_hello: bool = False, + database_name: str = "", + ) -> None: + """Publish a CommandSucceededEvent to all command listeners. + + :param duration: The command duration as a datetime.timedelta. + :param reply: The server reply document. + :param command_name: The command name. + :param request_id: The request id for this operation. + :param connection_id: The address (host, port) of the server this + command was sent to. + :param op_id: The (optional) operation id for this operation. + :param service_id: The service_id this command was sent to, or ``None``. + :param speculative_hello: Was the command sent with speculative auth? + :param database_name: The database this command was sent to, or ``""``. + """ + if op_id is None: + op_id = request_id + if speculative_hello: + # Redact entire response when the command started contained + # speculativeAuthenticate. + reply = {} + event = CommandSucceededEvent( + duration, + reply, + command_name, + request_id, + connection_id, + op_id, + service_id, + database_name=database_name, + server_connection_id=server_connection_id, + ) + for subscriber in self.__command_listeners: + try: + subscriber.succeeded(event) + except Exception: + _handle_exception() + + def publish_command_failure( + self, + duration: timedelta, + failure: _DocumentOut, + command_name: str, + request_id: int, + connection_id: _Address, + server_connection_id: Optional[int], + op_id: Optional[int] = None, + service_id: Optional[ObjectId] = None, + database_name: str = "", + ) -> None: + """Publish a CommandFailedEvent to all command listeners. + + :param duration: The command duration as a datetime.timedelta. + :param failure: The server reply document or failure description + document. + :param command_name: The command name. + :param request_id: The request id for this operation. + :param connection_id: The address (host, port) of the server this + command was sent to. + :param op_id: The (optional) operation id for this operation. + :param service_id: The service_id this command was sent to, or ``None``. + :param database_name: The database this command was sent to, or ``""``. + """ + if op_id is None: + op_id = request_id + event = CommandFailedEvent( + duration, + failure, + command_name, + request_id, + connection_id, + op_id, + service_id=service_id, + database_name=database_name, + server_connection_id=server_connection_id, + ) + for subscriber in self.__command_listeners: + try: + subscriber.failed(event) + except Exception: + _handle_exception() + + def publish_server_heartbeat_started(self, connection_id: _Address, awaited: bool) -> None: + """Publish a ServerHeartbeatStartedEvent to all server heartbeat + listeners. + + :param connection_id: The address (host, port) pair of the connection. + :param awaited: True if this heartbeat is part of an awaitable hello command. 
+        """
+        event = ServerHeartbeatStartedEvent(connection_id, awaited)
+        for subscriber in self.__server_heartbeat_listeners:
+            try:
+                subscriber.started(event)
+            except Exception:
+                _handle_exception()
+
+    def publish_server_heartbeat_succeeded(
+        self, connection_id: _Address, duration: float, reply: Hello, awaited: bool
+    ) -> None:
+        """Publish a ServerHeartbeatSucceededEvent to all server heartbeat
+        listeners.
+
+        :param connection_id: The address (host, port) pair of the connection.
+        :param duration: The execution time of the event in the highest possible
+            resolution for the platform.
+        :param reply: The command reply.
+        :param awaited: True if the response was awaited.
+        """
+        event = ServerHeartbeatSucceededEvent(duration, reply, connection_id, awaited)
+        for subscriber in self.__server_heartbeat_listeners:
+            try:
+                subscriber.succeeded(event)
+            except Exception:
+                _handle_exception()
+
+    def publish_server_heartbeat_failed(
+        self, connection_id: _Address, duration: float, reply: Exception, awaited: bool
+    ) -> None:
+        """Publish a ServerHeartbeatFailedEvent to all server heartbeat
+        listeners.
+
+        :param connection_id: The address (host, port) pair of the connection.
+        :param duration: The execution time of the event in the highest possible
+            resolution for the platform.
+        :param reply: The exception which caused the heartbeat to fail.
+        :param awaited: True if the response was awaited.
+        """
+        event = ServerHeartbeatFailedEvent(duration, reply, connection_id, awaited)
+        for subscriber in self.__server_heartbeat_listeners:
+            try:
+                subscriber.failed(event)
+            except Exception:
+                _handle_exception()
+
+    def publish_server_opened(self, server_address: _Address, topology_id: ObjectId) -> None:
+        """Publish a ServerOpeningEvent to all server listeners.
+
+        :param server_address: The address (host, port) pair of the server.
+        :param topology_id: A unique identifier for the topology this server
+            is a part of.
+        """
+        event = ServerOpeningEvent(server_address, topology_id)
+        for subscriber in self.__server_listeners:
+            try:
+                subscriber.opened(event)
+            except Exception:
+                _handle_exception()
+
+    def publish_server_closed(self, server_address: _Address, topology_id: ObjectId) -> None:
+        """Publish a ServerClosedEvent to all server listeners.
+
+        :param server_address: The address (host, port) pair of the server.
+        :param topology_id: A unique identifier for the topology this server
+            is a part of.
+        """
+        event = ServerClosedEvent(server_address, topology_id)
+        for subscriber in self.__server_listeners:
+            try:
+                subscriber.closed(event)
+            except Exception:
+                _handle_exception()
+
+    def publish_server_description_changed(
+        self,
+        previous_description: ServerDescription,
+        new_description: ServerDescription,
+        server_address: _Address,
+        topology_id: ObjectId,
+    ) -> None:
+        """Publish a ServerDescriptionChangedEvent to all server listeners.
+
+        :param previous_description: The previous server description.
+        :param new_description: The new server description.
+        :param server_address: The address (host, port) pair of the server.
+        :param topology_id: A unique identifier for the topology this server
+            is a part of.
+        """
+        event = ServerDescriptionChangedEvent(
+            previous_description, new_description, server_address, topology_id
+        )
+        for subscriber in self.__server_listeners:
+            try:
+                subscriber.description_changed(event)
+            except Exception:
+                _handle_exception()
+
+    def publish_topology_opened(self, topology_id: ObjectId) -> None:
+        """Publish a TopologyOpenedEvent to all topology listeners.
+ + :param topology_id: A unique identifier for the topology this server + is a part of. + """ + event = TopologyOpenedEvent(topology_id) + for subscriber in self.__topology_listeners: + try: + subscriber.opened(event) + except Exception: + _handle_exception() + + def publish_topology_closed(self, topology_id: ObjectId) -> None: + """Publish a TopologyClosedEvent to all topology listeners. + + :param topology_id: A unique identifier for the topology this server + is a part of. + """ + event = TopologyClosedEvent(topology_id) + for subscriber in self.__topology_listeners: + try: + subscriber.closed(event) + except Exception: + _handle_exception() + + def publish_topology_description_changed( + self, + previous_description: TopologyDescription, + new_description: TopologyDescription, + topology_id: ObjectId, + ) -> None: + """Publish a TopologyDescriptionChangedEvent to all topology listeners. + + :param previous_description: The previous topology description. + :param new_description: The new topology description. + :param topology_id: A unique identifier for the topology this server + is a part of. + """ + event = TopologyDescriptionChangedEvent(previous_description, new_description, topology_id) + for subscriber in self.__topology_listeners: + try: + subscriber.description_changed(event) + except Exception: + _handle_exception() + + def publish_pool_created(self, address: _Address, options: dict[str, Any]) -> None: + """Publish a :class:`PoolCreatedEvent` to all pool listeners.""" + event = PoolCreatedEvent(address, options) + for subscriber in self.__cmap_listeners: + try: + subscriber.pool_created(event) + except Exception: + _handle_exception() + + def publish_pool_ready(self, address: _Address) -> None: + """Publish a :class:`PoolReadyEvent` to all pool listeners.""" + event = PoolReadyEvent(address) + for subscriber in self.__cmap_listeners: + try: + subscriber.pool_ready(event) + except Exception: + _handle_exception() + + def publish_pool_cleared( + self, + address: _Address, + service_id: Optional[ObjectId], + interrupt_connections: bool = False, + ) -> None: + """Publish a :class:`PoolClearedEvent` to all pool listeners.""" + event = PoolClearedEvent(address, service_id, interrupt_connections) + for subscriber in self.__cmap_listeners: + try: + subscriber.pool_cleared(event) + except Exception: + _handle_exception() + + def publish_pool_closed(self, address: _Address) -> None: + """Publish a :class:`PoolClosedEvent` to all pool listeners.""" + event = PoolClosedEvent(address) + for subscriber in self.__cmap_listeners: + try: + subscriber.pool_closed(event) + except Exception: + _handle_exception() + + def publish_connection_created(self, address: _Address, connection_id: int) -> None: + """Publish a :class:`ConnectionCreatedEvent` to all connection + listeners. 
+ """ + event = ConnectionCreatedEvent(address, connection_id) + for subscriber in self.__cmap_listeners: + try: + subscriber.connection_created(event) + except Exception: + _handle_exception() + + def publish_connection_ready( + self, address: _Address, connection_id: int, duration: float + ) -> None: + """Publish a :class:`ConnectionReadyEvent` to all connection listeners.""" + event = ConnectionReadyEvent(address, connection_id, duration) + for subscriber in self.__cmap_listeners: + try: + subscriber.connection_ready(event) + except Exception: + _handle_exception() + + def publish_connection_closed(self, address: _Address, connection_id: int, reason: str) -> None: + """Publish a :class:`ConnectionClosedEvent` to all connection + listeners. + """ + event = ConnectionClosedEvent(address, connection_id, reason) + for subscriber in self.__cmap_listeners: + try: + subscriber.connection_closed(event) + except Exception: + _handle_exception() + + def publish_connection_check_out_started(self, address: _Address) -> None: + """Publish a :class:`ConnectionCheckOutStartedEvent` to all connection + listeners. + """ + event = ConnectionCheckOutStartedEvent(address) + for subscriber in self.__cmap_listeners: + try: + subscriber.connection_check_out_started(event) + except Exception: + _handle_exception() + + def publish_connection_check_out_failed( + self, address: _Address, reason: str, duration: float + ) -> None: + """Publish a :class:`ConnectionCheckOutFailedEvent` to all connection + listeners. + """ + event = ConnectionCheckOutFailedEvent(address, reason, duration) + for subscriber in self.__cmap_listeners: + try: + subscriber.connection_check_out_failed(event) + except Exception: + _handle_exception() + + def publish_connection_checked_out( + self, address: _Address, connection_id: int, duration: float + ) -> None: + """Publish a :class:`ConnectionCheckedOutEvent` to all connection + listeners. + """ + event = ConnectionCheckedOutEvent(address, connection_id, duration) + for subscriber in self.__cmap_listeners: + try: + subscriber.connection_checked_out(event) + except Exception: + _handle_exception() + + def publish_connection_checked_in(self, address: _Address, connection_id: int) -> None: + """Publish a :class:`ConnectionCheckedInEvent` to all connection + listeners. + """ + event = ConnectionCheckedInEvent(address, connection_id) + for subscriber in self.__cmap_listeners: + try: + subscriber.connection_checked_in(event) + except Exception: + _handle_exception() diff --git a/venv/Lib/site-packages/pymongo/network.py b/venv/Lib/site-packages/pymongo/network.py new file mode 100644 index 00000000..76afbe13 --- /dev/null +++ b/venv/Lib/site-packages/pymongo/network.py @@ -0,0 +1,412 @@ +# Copyright 2015-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+"""Internal network layer helper methods."""
+from __future__ import annotations
+
+import datetime
+import errno
+import logging
+import socket
+import struct
+import time
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Mapping,
+    MutableMapping,
+    Optional,
+    Sequence,
+    Union,
+    cast,
+)
+
+from bson import _decode_all_selective
+from pymongo import _csot, helpers, message, ssl_support
+from pymongo.common import MAX_MESSAGE_SIZE
+from pymongo.compression_support import _NO_COMPRESSION, decompress
+from pymongo.errors import (
+    NotPrimaryError,
+    OperationFailure,
+    ProtocolError,
+    _OperationCancelled,
+)
+from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log
+from pymongo.message import _UNPACK_REPLY, _OpMsg, _OpReply
+from pymongo.monitoring import _is_speculative_authenticate
+from pymongo.socket_checker import _errno_from_exception
+
+if TYPE_CHECKING:
+    from bson import CodecOptions
+    from pymongo.client_session import ClientSession
+    from pymongo.compression_support import SnappyContext, ZlibContext, ZstdContext
+    from pymongo.mongo_client import MongoClient
+    from pymongo.monitoring import _EventListeners
+    from pymongo.pool import Connection
+    from pymongo.read_concern import ReadConcern
+    from pymongo.read_preferences import _ServerMode
+    from pymongo.typings import _Address, _CollationIn, _DocumentOut, _DocumentType
+    from pymongo.write_concern import WriteConcern
+
+_UNPACK_HEADER = struct.Struct("<iiii").unpack
+
+
+def command(
+    conn: Connection,
+    dbname: str,
+    spec: MutableMapping[str, Any],
+    is_mongos: bool,
+    read_preference: Optional[_ServerMode],
+    codec_options: CodecOptions[_DocumentType],
+    session: Optional[ClientSession],
+    client: Optional[MongoClient],
+    check: bool = True,
+    allowable_errors: Optional[Sequence[Union[str, int]]] = None,
+    address: Optional[_Address] = None,
+    listeners: Optional[_EventListeners] = None,
+    max_bson_size: Optional[int] = None,
+    read_concern: Optional[ReadConcern] = None,
+    parse_write_concern_error: bool = False,
+    collation: Optional[_CollationIn] = None,
+    compression_ctx: Union[SnappyContext, ZlibContext, ZstdContext, None] = None,
+    use_op_msg: bool = False,
+    unacknowledged: bool = False,
+    user_fields: Optional[Mapping[str, Any]] = None,
+    exhaust_allowed: bool = False,
+    write_concern: Optional[WriteConcern] = None,
+) -> _DocumentType:
+    """Execute a command over the socket, or raise socket.error.
+
+    :param conn: a Connection instance
+    :param dbname: name of the database on which to run the command
+    :param spec: a command document as an ordered dict type, e.g. SON.
+    :param is_mongos: are we connected to a mongos?
+    :param read_preference: a read preference
+    :param codec_options: a CodecOptions instance
+    :param session: optional ClientSession instance.
+    :param client: optional MongoClient instance for updating $clusterTime.
+    :param check: raise OperationFailure if there are errors
+    :param allowable_errors: errors to ignore if `check` is True
+    :param address: the (host, port) of `conn`
+    :param listeners: An instance of :class:`~pymongo.monitoring._EventListeners`
+    :param max_bson_size: The maximum encoded bson size for this server
+    :param read_concern: The read concern for this command.
+    :param parse_write_concern_error: Whether to parse the ``writeConcernError``
+        field in the command response.
+    :param collation: The collation for this command.
+    :param compression_ctx: optional compression Context.
+    :param use_op_msg: True if we should use OP_MSG.
+    :param unacknowledged: True if this is an unacknowledged command.
+    :param user_fields: Response fields that should be decoded
+        using the TypeDecoders from codec_options, passed to
+        bson._decode_all_selective.
+    :param exhaust_allowed: True if we should enable OP_MSG exhaustAllowed.
+    :param write_concern: The write concern for this command, if any.
+    """
+    name = next(iter(spec))
+    ns = dbname + ".$cmd"
+    speculative_hello = False
+
+    # Publish the original command document, perhaps with lsid and $clusterTime.
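+    # ("orig" keeps a reference to the caller's document for listeners and
+    # logging; "spec" may be rebound below, e.g. by read preference wrapping
+    # or auto-encryption.)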
+ orig = spec + if is_mongos and not use_op_msg: + assert read_preference is not None + spec = message._maybe_add_read_preference(spec, read_preference) + if read_concern and not (session and session.in_transaction): + if read_concern.level: + spec["readConcern"] = read_concern.document + if session: + session._update_read_concern(spec, conn) + if collation is not None: + spec["collation"] = collation + + publish = listeners is not None and listeners.enabled_for_commands + start = datetime.datetime.now() + if publish: + speculative_hello = _is_speculative_authenticate(name, spec) + + if compression_ctx and name.lower() in _NO_COMPRESSION: + compression_ctx = None + + if client and client._encrypter and not client._encrypter._bypass_auto_encryption: + spec = orig = client._encrypter.encrypt(dbname, spec, codec_options) + + # Support CSOT + if client: + conn.apply_timeout(client, spec) + _csot.apply_write_concern(spec, write_concern) + + if use_op_msg: + flags = _OpMsg.MORE_TO_COME if unacknowledged else 0 + flags |= _OpMsg.EXHAUST_ALLOWED if exhaust_allowed else 0 + request_id, msg, size, max_doc_size = message._op_msg( + flags, spec, dbname, read_preference, codec_options, ctx=compression_ctx + ) + # If this is an unacknowledged write then make sure the encoded doc(s) + # are small enough, otherwise rely on the server to return an error. + if unacknowledged and max_bson_size is not None and max_doc_size > max_bson_size: + message._raise_document_too_large(name, size, max_bson_size) + else: + request_id, msg, size = message._query( + 0, ns, 0, -1, spec, None, codec_options, compression_ctx + ) + + if max_bson_size is not None and size > max_bson_size + message._COMMAND_OVERHEAD: + message._raise_document_too_large(name, size, max_bson_size + message._COMMAND_OVERHEAD) + if client is not None: + if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + clientId=client._topology_settings._topology_id, + message=_CommandStatusMessage.STARTED, + command=spec, + commandName=next(iter(spec)), + databaseName=dbname, + requestId=request_id, + operationId=request_id, + driverConnectionId=conn.id, + serverConnectionId=conn.server_connection_id, + serverHost=conn.address[0], + serverPort=conn.address[1], + serviceId=conn.service_id, + ) + if publish: + assert listeners is not None + assert address is not None + listeners.publish_command_start( + orig, + dbname, + request_id, + address, + conn.server_connection_id, + service_id=conn.service_id, + ) + + try: + conn.conn.sendall(msg) + if use_op_msg and unacknowledged: + # Unacknowledged, fake a successful command response. 
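+            # (OP_MSG moreToCome was set above, so the server sends no reply.)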
+            reply = None
+            response_doc: _DocumentOut = {"ok": 1}
+        else:
+            reply = receive_message(conn, request_id)
+            conn.more_to_come = reply.more_to_come
+            unpacked_docs = reply.unpack_response(
+                codec_options=codec_options, user_fields=user_fields
+            )
+
+            response_doc = unpacked_docs[0]
+            if client:
+                client._process_response(response_doc, session)
+            if check:
+                helpers._check_command_response(
+                    response_doc,
+                    conn.max_wire_version,
+                    allowable_errors,
+                    parse_write_concern_error=parse_write_concern_error,
+                )
+    except Exception as exc:
+        duration = datetime.datetime.now() - start
+        if isinstance(exc, (NotPrimaryError, OperationFailure)):
+            failure: _DocumentOut = exc.details  # type: ignore[assignment]
+        else:
+            failure = message._convert_exception(exc)
+        if client is not None:
+            if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
+                _debug_log(
+                    _COMMAND_LOGGER,
+                    clientId=client._topology_settings._topology_id,
+                    message=_CommandStatusMessage.FAILED,
+                    durationMS=duration,
+                    failure=failure,
+                    commandName=next(iter(spec)),
+                    databaseName=dbname,
+                    requestId=request_id,
+                    operationId=request_id,
+                    driverConnectionId=conn.id,
+                    serverConnectionId=conn.server_connection_id,
+                    serverHost=conn.address[0],
+                    serverPort=conn.address[1],
+                    serviceId=conn.service_id,
+                    isServerSideError=isinstance(exc, OperationFailure),
+                )
+        if publish:
+            assert listeners is not None
+            assert address is not None
+            listeners.publish_command_failure(
+                duration,
+                failure,
+                name,
+                request_id,
+                address,
+                conn.server_connection_id,
+                service_id=conn.service_id,
+                database_name=dbname,
+            )
+        raise
+    duration = datetime.datetime.now() - start
+    if client is not None:
+        if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
+            _debug_log(
+                _COMMAND_LOGGER,
+                clientId=client._topology_settings._topology_id,
+                message=_CommandStatusMessage.SUCCEEDED,
+                durationMS=duration,
+                reply=response_doc,
+                commandName=next(iter(spec)),
+                databaseName=dbname,
+                requestId=request_id,
+                operationId=request_id,
+                driverConnectionId=conn.id,
+                serverConnectionId=conn.server_connection_id,
+                serverHost=conn.address[0],
+                serverPort=conn.address[1],
+                serviceId=conn.service_id,
+                speculative_authenticate="speculativeAuthenticate" in orig,
+            )
+    if publish:
+        assert listeners is not None
+        assert address is not None
+        listeners.publish_command_success(
+            duration,
+            response_doc,
+            name,
+            request_id,
+            address,
+            conn.server_connection_id,
+            service_id=conn.service_id,
+            speculative_hello=speculative_hello,
+            database_name=dbname,
+        )
+
+    if client and client._encrypter and reply:
+        decrypted = client._encrypter.decrypt(reply.raw_command_response())
+        response_doc = cast(
+            "_DocumentOut", _decode_all_selective(decrypted, codec_options, user_fields)[0]
+        )
+
+    return response_doc  # type: ignore[return-value]
+
+
+_UNPACK_COMPRESSION_HEADER = struct.Struct("<iiB").unpack
+
+
+def receive_message(
+    conn: Connection, request_id: Optional[int], max_message_size: int = MAX_MESSAGE_SIZE
+) -> Union[_OpReply, _OpMsg]:
+    """Receive a raw BSON message or raise socket.error."""
+    if _csot.get_timeout():
+        deadline = _csot.get_deadline()
+    else:
+        timeout = conn.conn.gettimeout()
+        if timeout:
+            deadline = time.monotonic() + timeout
+        else:
+            deadline = None
+    # Ignore the response's request id.
+    length, _, response_to, op_code = _UNPACK_HEADER(_receive_data_on_socket(conn, 16, deadline))
+    # No request_id for exhaust cursor "getMore".
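+    # (Exhaust replies stream in without a request of their own, so the
+    # response_to check below is skipped when request_id is None.)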
+ if request_id is not None: + if request_id != response_to: + raise ProtocolError(f"Got response id {response_to!r} but expected {request_id!r}") + if length <= 16: + raise ProtocolError( + f"Message length ({length!r}) not longer than standard message header size (16)" + ) + if length > max_message_size: + raise ProtocolError( + f"Message length ({length!r}) is larger than server max " + f"message size ({max_message_size!r})" + ) + if op_code == 2012: + op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER( + _receive_data_on_socket(conn, 9, deadline) + ) + data = decompress(_receive_data_on_socket(conn, length - 25, deadline), compressor_id) + else: + data = _receive_data_on_socket(conn, length - 16, deadline) + + try: + unpack_reply = _UNPACK_REPLY[op_code] + except KeyError: + raise ProtocolError( + f"Got opcode {op_code!r} but expected {_UNPACK_REPLY.keys()!r}" + ) from None + return unpack_reply(data) + + +_POLL_TIMEOUT = 0.5 + + +def wait_for_read(conn: Connection, deadline: Optional[float]) -> None: + """Block until at least one byte is read, or a timeout, or a cancel.""" + sock = conn.conn + timed_out = False + # Check if the connection's socket has been manually closed + if sock.fileno() == -1: + return + while True: + # SSLSocket can have buffered data which won't be caught by select. + if hasattr(sock, "pending") and sock.pending() > 0: + readable = True + else: + # Wait up to 500ms for the socket to become readable and then + # check for cancellation. + if deadline: + remaining = deadline - time.monotonic() + # When the timeout has expired perform one final check to + # see if the socket is readable. This helps avoid spurious + # timeouts on AWS Lambda and other FaaS environments. + if remaining <= 0: + timed_out = True + timeout = max(min(remaining, _POLL_TIMEOUT), 0) + else: + timeout = _POLL_TIMEOUT + readable = conn.socket_checker.select(sock, read=True, timeout=timeout) + if conn.cancel_context.cancelled: + raise _OperationCancelled("operation cancelled") + if readable: + return + if timed_out: + raise socket.timeout("timed out") + + +# Errors raised by sockets (and TLS sockets) when in non-blocking mode. +BLOCKING_IO_ERRORS = (BlockingIOError, *ssl_support.BLOCKING_IO_ERRORS) + + +def _receive_data_on_socket(conn: Connection, length: int, deadline: Optional[float]) -> memoryview: + buf = bytearray(length) + mv = memoryview(buf) + bytes_read = 0 + while bytes_read < length: + try: + wait_for_read(conn, deadline) + # CSOT: Update timeout. When the timeout has expired perform one + # final non-blocking recv. This helps avoid spurious timeouts when + # the response is actually already buffered on the client. + if _csot.get_timeout() and deadline is not None: + conn.set_conn_timeout(max(deadline - time.monotonic(), 0)) + chunk_length = conn.conn.recv_into(mv[bytes_read:]) + except BLOCKING_IO_ERRORS: + raise socket.timeout("timed out") from None + except OSError as exc: + if _errno_from_exception(exc) == errno.EINTR: + continue + raise + if chunk_length == 0: + raise OSError("connection closed") + + bytes_read += chunk_length + + return mv diff --git a/venv/Lib/site-packages/pymongo/ocsp_cache.py b/venv/Lib/site-packages/pymongo/ocsp_cache.py new file mode 100644 index 00000000..74257931 --- /dev/null +++ b/venv/Lib/site-packages/pymongo/ocsp_cache.py @@ -0,0 +1,108 @@ +# Copyright 2020-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for caching OCSP responses.""" + +from __future__ import annotations + +from collections import namedtuple +from datetime import datetime as _datetime +from datetime import timezone +from typing import TYPE_CHECKING, Any + +from pymongo.lock import _create_lock + +if TYPE_CHECKING: + from cryptography.x509.ocsp import OCSPRequest, OCSPResponse + + +class _OCSPCache: + """A cache for OCSP responses.""" + + CACHE_KEY_TYPE = namedtuple( # type: ignore + "OcspResponseCacheKey", + ["hash_algorithm", "issuer_name_hash", "issuer_key_hash", "serial_number"], + ) + + def __init__(self) -> None: + self._data: dict[Any, OCSPResponse] = {} + # Hold this lock when accessing _data. + self._lock = _create_lock() + + def _get_cache_key(self, ocsp_request: OCSPRequest) -> CACHE_KEY_TYPE: + return self.CACHE_KEY_TYPE( + hash_algorithm=ocsp_request.hash_algorithm.name.lower(), + issuer_name_hash=ocsp_request.issuer_name_hash, + issuer_key_hash=ocsp_request.issuer_key_hash, + serial_number=ocsp_request.serial_number, + ) + + def __setitem__(self, key: OCSPRequest, value: OCSPResponse) -> None: + """Add/update a cache entry. + + 'key' is of type cryptography.x509.ocsp.OCSPRequest + 'value' is of type cryptography.x509.ocsp.OCSPResponse + + Validity of the OCSP response must be checked by caller. + """ + with self._lock: + cache_key = self._get_cache_key(key) + + # As per the OCSP protocol, if the response's nextUpdate field is + # not set, the responder is indicating that newer revocation + # information is available all the time. + if value.next_update is None: + self._data.pop(cache_key, None) + return + + # Do nothing if the response is invalid. + if not ( + value.this_update + <= _datetime.now(tz=timezone.utc).replace(tzinfo=None) + < value.next_update + ): + return + + # Cache new response OR update cached response if new response + # has longer validity. + cached_value = self._data.get(cache_key, None) + if cached_value is None or ( + cached_value.next_update is not None + and cached_value.next_update < value.next_update + ): + self._data[cache_key] = value + + def __getitem__(self, item: OCSPRequest) -> OCSPResponse: + """Get a cache entry if it exists. + + 'item' is of type cryptography.x509.ocsp.OCSPRequest + + Raises KeyError if the item is not in the cache. + """ + with self._lock: + cache_key = self._get_cache_key(item) + value = self._data[cache_key] + + # Return cached response if it is still valid. + assert value.this_update is not None + assert value.next_update is not None + if ( + value.this_update + <= _datetime.now(tz=timezone.utc).replace(tzinfo=None) + < value.next_update + ): + return value + + self._data.pop(cache_key, None) + raise KeyError(cache_key) diff --git a/venv/Lib/site-packages/pymongo/ocsp_support.py b/venv/Lib/site-packages/pymongo/ocsp_support.py new file mode 100644 index 00000000..1bda3b4d --- /dev/null +++ b/venv/Lib/site-packages/pymongo/ocsp_support.py @@ -0,0 +1,432 @@ +# Copyright 2020-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you +# may not use this file except in compliance with the License. 
You +# may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. + +"""Support for requesting and verifying OCSP responses.""" +from __future__ import annotations + +import logging as _logging +import re as _re +from datetime import datetime as _datetime +from datetime import timezone +from typing import TYPE_CHECKING, Iterable, Optional, Type, Union + +from cryptography.exceptions import InvalidSignature as _InvalidSignature +from cryptography.hazmat.backends import default_backend as _default_backend +from cryptography.hazmat.primitives.asymmetric.dsa import DSAPublicKey as _DSAPublicKey +from cryptography.hazmat.primitives.asymmetric.ec import ECDSA as _ECDSA +from cryptography.hazmat.primitives.asymmetric.ec import ( + EllipticCurvePublicKey as _EllipticCurvePublicKey, +) +from cryptography.hazmat.primitives.asymmetric.padding import PKCS1v15 as _PKCS1v15 +from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey as _RSAPublicKey +from cryptography.hazmat.primitives.asymmetric.x448 import ( + X448PublicKey as _X448PublicKey, +) +from cryptography.hazmat.primitives.asymmetric.x25519 import ( + X25519PublicKey as _X25519PublicKey, +) +from cryptography.hazmat.primitives.hashes import SHA1 as _SHA1 +from cryptography.hazmat.primitives.hashes import Hash as _Hash +from cryptography.hazmat.primitives.serialization import Encoding as _Encoding +from cryptography.hazmat.primitives.serialization import PublicFormat as _PublicFormat +from cryptography.x509 import AuthorityInformationAccess as _AuthorityInformationAccess +from cryptography.x509 import ExtendedKeyUsage as _ExtendedKeyUsage +from cryptography.x509 import ExtensionNotFound as _ExtensionNotFound +from cryptography.x509 import TLSFeature as _TLSFeature +from cryptography.x509 import TLSFeatureType as _TLSFeatureType +from cryptography.x509 import load_pem_x509_certificate as _load_pem_x509_certificate +from cryptography.x509.ocsp import OCSPCertStatus as _OCSPCertStatus +from cryptography.x509.ocsp import OCSPRequestBuilder as _OCSPRequestBuilder +from cryptography.x509.ocsp import OCSPResponseStatus as _OCSPResponseStatus +from cryptography.x509.ocsp import load_der_ocsp_response as _load_der_ocsp_response +from cryptography.x509.oid import ( + AuthorityInformationAccessOID as _AuthorityInformationAccessOID, +) +from cryptography.x509.oid import ExtendedKeyUsageOID as _ExtendedKeyUsageOID +from requests import post as _post +from requests.exceptions import RequestException as _RequestException + +from pymongo import _csot + +if TYPE_CHECKING: + from cryptography.hazmat.primitives.asymmetric import ( + dsa, + ec, + ed448, + ed25519, + rsa, + x448, + x25519, + ) + from cryptography.hazmat.primitives.asymmetric.utils import Prehashed + from cryptography.hazmat.primitives.hashes import HashAlgorithm + from cryptography.x509 import Certificate, Name + from cryptography.x509.extensions import Extension, ExtensionTypeVar + from cryptography.x509.ocsp import OCSPRequest, OCSPResponse + from OpenSSL.SSL import Connection + + from pymongo.ocsp_cache import _OCSPCache + from pymongo.pyopenssl_context import _CallbackData + + CertificateIssuerPublicKeyTypes = Union[ + dsa.DSAPublicKey, + 
rsa.RSAPublicKey, + ec.EllipticCurvePublicKey, + ed25519.Ed25519PublicKey, + ed448.Ed448PublicKey, + x25519.X25519PublicKey, + x448.X448PublicKey, + ] + +# Note: the functions in this module generally return 1 or 0. The reason +# is simple. The entry point, ocsp_callback, is registered as a callback +# with OpenSSL through PyOpenSSL. The callback must return 1 (success) or +# 0 (failure). + +_LOGGER = _logging.getLogger(__name__) + +_CERT_REGEX = _re.compile( + b"-----BEGIN CERTIFICATE[^\r\n]+.+?-----END CERTIFICATE[^\r\n]+", _re.DOTALL +) + + +def _load_trusted_ca_certs(cafile: str) -> list[Certificate]: + """Parse the tlsCAFile into a list of certificates.""" + with open(cafile, "rb") as f: + data = f.read() + + # Load all the certs in the file. + trusted_ca_certs = [] + backend = _default_backend() + for cert_data in _re.findall(_CERT_REGEX, data): + trusted_ca_certs.append(_load_pem_x509_certificate(cert_data, backend)) + return trusted_ca_certs + + +def _get_issuer_cert( + cert: Certificate, chain: Iterable[Certificate], trusted_ca_certs: Optional[list[Certificate]] +) -> Optional[Certificate]: + issuer_name = cert.issuer + for candidate in chain: + if candidate.subject == issuer_name: + return candidate + + # Depending on the server's TLS library, the peer's cert chain may not + # include the self signed root CA. In this case we check the user + # provided tlsCAFile for the issuer. + # Remove once we use the verified peer cert chain in PYTHON-2147. + if trusted_ca_certs: + for candidate in trusted_ca_certs: + if candidate.subject == issuer_name: + return candidate + return None + + +def _verify_signature( + key: CertificateIssuerPublicKeyTypes, + signature: bytes, + algorithm: Union[Prehashed, HashAlgorithm, None], + data: bytes, +) -> int: + # See cryptography.x509.Certificate.public_key + # for the public key types. 
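+    # Dispatch on the concrete key class: each algorithm's verify() takes
+    # different padding/algorithm arguments, and raises InvalidSignature
+    # (caught below) when verification fails.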
+ try: + if isinstance(key, _RSAPublicKey): + key.verify(signature, data, _PKCS1v15(), algorithm) # type: ignore[arg-type] + elif isinstance(key, _DSAPublicKey): + key.verify(signature, data, algorithm) # type: ignore[arg-type] + elif isinstance(key, _EllipticCurvePublicKey): + key.verify(signature, data, _ECDSA(algorithm)) # type: ignore[arg-type] + elif isinstance( + key, (_X25519PublicKey, _X448PublicKey) + ): # Curve25519 and Curve448 keys do not require verification + return 1 + else: + key.verify(signature, data) + except _InvalidSignature: + return 0 + return 1 + + +def _get_extension( + cert: Certificate, klass: Type[ExtensionTypeVar] +) -> Optional[Extension[ExtensionTypeVar]]: + try: + return cert.extensions.get_extension_for_class(klass) + except _ExtensionNotFound: + return None + + +def _public_key_hash(cert: Certificate) -> bytes: + public_key = cert.public_key() + # https://tools.ietf.org/html/rfc2560#section-4.2.1 + # "KeyHash ::= OCTET STRING -- SHA-1 hash of responder's public key + # (excluding the tag and length fields)" + # https://stackoverflow.com/a/46309453/600498 + if isinstance(public_key, _RSAPublicKey): + pbytes = public_key.public_bytes(_Encoding.DER, _PublicFormat.PKCS1) + elif isinstance(public_key, _EllipticCurvePublicKey): + pbytes = public_key.public_bytes(_Encoding.X962, _PublicFormat.UncompressedPoint) + else: + pbytes = public_key.public_bytes(_Encoding.DER, _PublicFormat.SubjectPublicKeyInfo) + digest = _Hash(_SHA1(), backend=_default_backend()) # noqa: S303 + digest.update(pbytes) + return digest.finalize() + + +def _get_certs_by_key_hash( + certificates: Iterable[Certificate], issuer: Certificate, responder_key_hash: Optional[bytes] +) -> list[Certificate]: + return [ + cert + for cert in certificates + if _public_key_hash(cert) == responder_key_hash and cert.issuer == issuer.subject + ] + + +def _get_certs_by_name( + certificates: Iterable[Certificate], issuer: Certificate, responder_name: Optional[Name] +) -> list[Certificate]: + return [ + cert + for cert in certificates + if cert.subject == responder_name and cert.issuer == issuer.subject + ] + + +def _verify_response_signature(issuer: Certificate, response: OCSPResponse) -> int: + # Response object will have a responder_name or responder_key_hash + # not both. + name = response.responder_name + rkey_hash = response.responder_key_hash + ikey_hash = response.issuer_key_hash + if name is not None and name == issuer.subject or rkey_hash == ikey_hash: + _LOGGER.debug("Responder is issuer") + # Responder is the issuer + responder_cert = issuer + else: + _LOGGER.debug("Responder is a delegate") + # Responder is a delegate + # https://tools.ietf.org/html/rfc6960#section-2.6 + # RFC6960, Section 3.2, Number 3 + certs = response.certificates + if response.responder_name is not None: + responder_certs = _get_certs_by_name(certs, issuer, name) + _LOGGER.debug("Using responder name") + else: + responder_certs = _get_certs_by_key_hash(certs, issuer, rkey_hash) + _LOGGER.debug("Using key hash") + if not responder_certs: + _LOGGER.debug("No matching or valid responder certs.") + return 0 + # XXX: Can there be more than one? If so, should we try each one + # until we find one that passes signature verification? 
+ responder_cert = responder_certs[0] + + # RFC6960, Section 3.2, Number 4 + ext = _get_extension(responder_cert, _ExtendedKeyUsage) + if not ext or _ExtendedKeyUsageOID.OCSP_SIGNING not in ext.value: + _LOGGER.debug("Delegate not authorized for OCSP signing") + return 0 + if not _verify_signature( + issuer.public_key(), + responder_cert.signature, + responder_cert.signature_hash_algorithm, + responder_cert.tbs_certificate_bytes, + ): + _LOGGER.debug("Delegate signature verification failed") + return 0 + # RFC6960, Section 3.2, Number 2 + ret = _verify_signature( + responder_cert.public_key(), + response.signature, + response.signature_hash_algorithm, + response.tbs_response_bytes, + ) + if not ret: + _LOGGER.debug("Response signature verification failed") + return ret + + +def _build_ocsp_request(cert: Certificate, issuer: Certificate) -> OCSPRequest: + # https://cryptography.io/en/latest/x509/ocsp/#creating-requests + builder = _OCSPRequestBuilder() + builder = builder.add_certificate(cert, issuer, _SHA1()) # noqa: S303 + return builder.build() + + +def _verify_response(issuer: Certificate, response: OCSPResponse) -> int: + _LOGGER.debug("Verifying response") + # RFC6960, Section 3.2, Number 2, 3 and 4 happen here. + res = _verify_response_signature(issuer, response) + if not res: + return 0 + + # Note that we are not using a "tolerance period" as discussed in + # https://tools.ietf.org/rfc/rfc5019.txt? + now = _datetime.now(tz=timezone.utc).replace(tzinfo=None) + # RFC6960, Section 3.2, Number 5 + if response.this_update > now: + _LOGGER.debug("thisUpdate is in the future") + return 0 + # RFC6960, Section 3.2, Number 6 + if response.next_update and response.next_update < now: + _LOGGER.debug("nextUpdate is in the past") + return 0 + return 1 + + +def _get_ocsp_response( + cert: Certificate, issuer: Certificate, uri: Union[str, bytes], ocsp_response_cache: _OCSPCache +) -> Optional[OCSPResponse]: + ocsp_request = _build_ocsp_request(cert, issuer) + try: + ocsp_response = ocsp_response_cache[ocsp_request] + _LOGGER.debug("Using cached OCSP response.") + except KeyError: + # CSOT: use the configured timeout or 5 seconds, whichever is smaller. + # Note that request's timeout works differently and does not imply an absolute + # deadline: https://requests.readthedocs.io/en/stable/user/quickstart/#timeouts + timeout = max(_csot.clamp_remaining(5), 0.001) + try: + response = _post( + uri, + data=ocsp_request.public_bytes(_Encoding.DER), + headers={"Content-Type": "application/ocsp-request"}, + timeout=timeout, + ) + except _RequestException as exc: + _LOGGER.debug("HTTP request failed: %s", exc) + return None + if response.status_code != 200: + _LOGGER.debug("HTTP request returned %d", response.status_code) + return None + ocsp_response = _load_der_ocsp_response(response.content) + _LOGGER.debug("OCSP response status: %r", ocsp_response.response_status) + if ocsp_response.response_status != _OCSPResponseStatus.SUCCESSFUL: + return None + # RFC6960, Section 3.2, Number 1. Only relevant if we need to + # talk to the responder directly. + # Accessing response.serial_number raises if response status is not + # SUCCESSFUL. + if ocsp_response.serial_number != ocsp_request.serial_number: + _LOGGER.debug("Response serial number does not match request") + return None + if not _verify_response(issuer, ocsp_response): + # The response failed verification. 
+ return None + _LOGGER.debug("Caching OCSP response.") + ocsp_response_cache[ocsp_request] = ocsp_response + + return ocsp_response + + +def _ocsp_callback(conn: Connection, ocsp_bytes: bytes, user_data: Optional[_CallbackData]) -> bool: + """Callback for use with OpenSSL.SSL.Context.set_ocsp_client_callback.""" + # always pass in user_data but OpenSSL requires it be optional + assert user_data + pycert = conn.get_peer_certificate() + if pycert is None: + _LOGGER.debug("No peer cert?") + return False + cert = pycert.to_cryptography() + # Use the verified chain when available (pyopenssl>=20.0). + if hasattr(conn, "get_verified_chain"): + pychain = conn.get_verified_chain() + trusted_ca_certs = None + else: + pychain = conn.get_peer_cert_chain() + trusted_ca_certs = user_data.trusted_ca_certs + if not pychain: + _LOGGER.debug("No peer cert chain?") + return False + chain = [cer.to_cryptography() for cer in pychain] + issuer = _get_issuer_cert(cert, chain, trusted_ca_certs) + must_staple = False + # https://tools.ietf.org/html/rfc7633#section-4.2.3.1 + ext_tls = _get_extension(cert, _TLSFeature) + if ext_tls is not None: + for feature in ext_tls.value: + if feature == _TLSFeatureType.status_request: + _LOGGER.debug("Peer presented a must-staple cert") + must_staple = True + break + ocsp_response_cache = user_data.ocsp_response_cache + + # No stapled OCSP response + if ocsp_bytes == b"": + _LOGGER.debug("Peer did not staple an OCSP response") + if must_staple: + _LOGGER.debug("Must-staple cert with no stapled response, hard fail.") + return False + if not user_data.check_ocsp_endpoint: + _LOGGER.debug("OCSP endpoint checking is disabled, soft fail.") + # No stapled OCSP response, checking responder URI disabled, soft fail. + return True + # https://tools.ietf.org/html/rfc6960#section-3.1 + ext_aia = _get_extension(cert, _AuthorityInformationAccess) + if ext_aia is None: + _LOGGER.debug("No authority access information, soft fail") + # No stapled OCSP response, no responder URI, soft fail. + return True + uris = [ + desc.access_location.value + for desc in ext_aia.value + if desc.access_method == _AuthorityInformationAccessOID.OCSP + ] + if not uris: + _LOGGER.debug("No OCSP URI, soft fail") + # No responder URI, soft fail. + return True + if issuer is None: + _LOGGER.debug("No issuer cert?") + return False + _LOGGER.debug("Requesting OCSP data") + # When requesting data from an OCSP endpoint we only fail on + # successful, valid responses with a certificate status of REVOKED. + for uri in uris: + _LOGGER.debug("Trying %s", uri) + response = _get_ocsp_response(cert, issuer, uri, ocsp_response_cache) + if response is None: + # The endpoint didn't respond in time, or the response was + # unsuccessful or didn't match the request, or the response + # failed verification. + continue + _LOGGER.debug("OCSP cert status: %r", response.certificate_status) + if response.certificate_status == _OCSPCertStatus.GOOD: + return True + if response.certificate_status == _OCSPCertStatus.REVOKED: + return False + # Soft fail if we couldn't get a definitive status. 
+        _LOGGER.debug("No definitive OCSP cert status, soft fail")
+        return True
+
+    _LOGGER.debug("Peer stapled an OCSP response")
+    if issuer is None:
+        _LOGGER.debug("No issuer cert?")
+        return False
+    response = _load_der_ocsp_response(ocsp_bytes)
+    _LOGGER.debug("OCSP response status: %r", response.response_status)
+    # Serial numbers are only compared in _get_ocsp_response, where we made
+    # the OCSP request ourselves; a stapled response has no request of ours
+    # to compare against.
+    if response.response_status != _OCSPResponseStatus.SUCCESSFUL:
+        return False
+    if not _verify_response(issuer, response):
+        return False
+    # Cache the verified, stapled response.
+    ocsp_response_cache[_build_ocsp_request(cert, issuer)] = response
+    _LOGGER.debug("OCSP cert status: %r", response.certificate_status)
+    if response.certificate_status == _OCSPCertStatus.REVOKED:
+        return False
+    return True
diff --git a/venv/Lib/site-packages/pymongo/operations.py b/venv/Lib/site-packages/pymongo/operations.py
new file mode 100644
index 00000000..4872afa9
--- /dev/null
+++ b/venv/Lib/site-packages/pymongo/operations.py
@@ -0,0 +1,623 @@
+# Copyright 2015-present MongoDB, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Operation class definitions."""
+from __future__ import annotations
+
+import enum
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Generic,
+    Mapping,
+    Optional,
+    Sequence,
+    Tuple,
+    Union,
+)
+
+from bson.raw_bson import RawBSONDocument
+from pymongo import helpers
+from pymongo.collation import validate_collation_or_none
+from pymongo.common import validate_is_mapping, validate_list
+from pymongo.helpers import _gen_index_name, _index_document, _index_list
+from pymongo.typings import _CollationIn, _DocumentType, _Pipeline
+from pymongo.write_concern import validate_boolean
+
+if TYPE_CHECKING:
+    from pymongo.bulk import _Bulk
+
+# A hint may be an index name ("myIndex"), a list of strings and/or
+# (key, direction) pairs ([('x', 1), ('y', -1), 'z']), or a dictionary.
+_IndexList = Union[
+    Sequence[Union[str, Tuple[str, Union[int, str, Mapping[str, Any]]]]], Mapping[str, Any]
+]
+_IndexKeyHint = Union[str, _IndexList]
+
+
+class _Op(str, enum.Enum):
+    ABORT = "abortTransaction"
+    AGGREGATE = "aggregate"
+    COMMIT = "commitTransaction"
+    COUNT = "count"
+    CREATE = "create"
+    CREATE_INDEXES = "createIndexes"
+    CREATE_SEARCH_INDEXES = "createSearchIndexes"
+    DELETE = "delete"
+    DISTINCT = "distinct"
+    DROP = "drop"
+    DROP_DATABASE = "dropDatabase"
+    DROP_INDEXES = "dropIndexes"
+    DROP_SEARCH_INDEXES = "dropSearchIndexes"
+    END_SESSIONS = "endSessions"
+    FIND_AND_MODIFY = "findAndModify"
+    FIND = "find"
+    INSERT = "insert"
+    LIST_COLLECTIONS = "listCollections"
+    LIST_INDEXES = "listIndexes"
+    LIST_SEARCH_INDEX = "listSearchIndexes"
+    LIST_DATABASES = "listDatabases"
+    UPDATE = "update"
+    UPDATE_INDEX = "updateIndex"
+    UPDATE_SEARCH_INDEX = "updateSearchIndex"
+    RENAME = "rename"
+    GETMORE = "getMore"
+    KILL_CURSORS = "killCursors"
+    TEST = "testOperation"
+
+
+class InsertOne(Generic[_DocumentType]):
+    """Represents an insert_one
operation.""" + + __slots__ = ("_doc",) + + def __init__(self, document: _DocumentType) -> None: + """Create an InsertOne instance. + + For use with :meth:`~pymongo.collection.Collection.bulk_write`. + + :param document: The document to insert. If the document is missing an + _id field one will be added. + """ + self._doc = document + + def _add_to_bulk(self, bulkobj: _Bulk) -> None: + """Add this operation to the _Bulk instance `bulkobj`.""" + bulkobj.add_insert(self._doc) # type: ignore[arg-type] + + def __repr__(self) -> str: + return f"InsertOne({self._doc!r})" + + def __eq__(self, other: Any) -> bool: + if type(other) == type(self): + return other._doc == self._doc + return NotImplemented + + def __ne__(self, other: Any) -> bool: + return not self == other + + +class DeleteOne: + """Represents a delete_one operation.""" + + __slots__ = ("_filter", "_collation", "_hint") + + def __init__( + self, + filter: Mapping[str, Any], + collation: Optional[_CollationIn] = None, + hint: Optional[_IndexKeyHint] = None, + ) -> None: + """Create a DeleteOne instance. + + For use with :meth:`~pymongo.collection.Collection.bulk_write`. + + :param filter: A query that matches the document to delete. + :param collation: An instance of + :class:`~pymongo.collation.Collation`. + :param hint: An index to use to support the query + predicate specified either by its string name, or in the same + format as passed to + :meth:`~pymongo.collection.Collection.create_index` (e.g. + ``[('field', ASCENDING)]``). This option is only supported on + MongoDB 4.4 and above. + + .. versionchanged:: 3.11 + Added the ``hint`` option. + .. versionchanged:: 3.5 + Added the `collation` option. + """ + if filter is not None: + validate_is_mapping("filter", filter) + if hint is not None and not isinstance(hint, str): + self._hint: Union[str, dict[str, Any], None] = helpers._index_document(hint) + else: + self._hint = hint + self._filter = filter + self._collation = collation + + def _add_to_bulk(self, bulkobj: _Bulk) -> None: + """Add this operation to the _Bulk instance `bulkobj`.""" + bulkobj.add_delete( + self._filter, + 1, + collation=validate_collation_or_none(self._collation), + hint=self._hint, + ) + + def __repr__(self) -> str: + return f"DeleteOne({self._filter!r}, {self._collation!r}, {self._hint!r})" + + def __eq__(self, other: Any) -> bool: + if type(other) == type(self): + return (other._filter, other._collation, other._hint) == ( + self._filter, + self._collation, + self._hint, + ) + return NotImplemented + + def __ne__(self, other: Any) -> bool: + return not self == other + + +class DeleteMany: + """Represents a delete_many operation.""" + + __slots__ = ("_filter", "_collation", "_hint") + + def __init__( + self, + filter: Mapping[str, Any], + collation: Optional[_CollationIn] = None, + hint: Optional[_IndexKeyHint] = None, + ) -> None: + """Create a DeleteMany instance. + + For use with :meth:`~pymongo.collection.Collection.bulk_write`. + + :param filter: A query that matches the documents to delete. + :param collation: An instance of + :class:`~pymongo.collation.Collation`. + :param hint: An index to use to support the query + predicate specified either by its string name, or in the same + format as passed to + :meth:`~pymongo.collection.Collection.create_index` (e.g. + ``[('field', ASCENDING)]``). This option is only supported on + MongoDB 4.4 and above. + + .. versionchanged:: 3.11 + Added the ``hint`` option. + .. versionchanged:: 3.5 + Added the `collation` option. 
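+
+        A brief, illustrative example (collection and field names are
+        hypothetical)::
+
+            collection.bulk_write([DeleteMany({"status": "archived"})])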
+ """ + if filter is not None: + validate_is_mapping("filter", filter) + if hint is not None and not isinstance(hint, str): + self._hint: Union[str, dict[str, Any], None] = helpers._index_document(hint) + else: + self._hint = hint + self._filter = filter + self._collation = collation + + def _add_to_bulk(self, bulkobj: _Bulk) -> None: + """Add this operation to the _Bulk instance `bulkobj`.""" + bulkobj.add_delete( + self._filter, + 0, + collation=validate_collation_or_none(self._collation), + hint=self._hint, + ) + + def __repr__(self) -> str: + return f"DeleteMany({self._filter!r}, {self._collation!r}, {self._hint!r})" + + def __eq__(self, other: Any) -> bool: + if type(other) == type(self): + return (other._filter, other._collation, other._hint) == ( + self._filter, + self._collation, + self._hint, + ) + return NotImplemented + + def __ne__(self, other: Any) -> bool: + return not self == other + + +class ReplaceOne(Generic[_DocumentType]): + """Represents a replace_one operation.""" + + __slots__ = ("_filter", "_doc", "_upsert", "_collation", "_hint") + + def __init__( + self, + filter: Mapping[str, Any], + replacement: Union[_DocumentType, RawBSONDocument], + upsert: bool = False, + collation: Optional[_CollationIn] = None, + hint: Optional[_IndexKeyHint] = None, + ) -> None: + """Create a ReplaceOne instance. + + For use with :meth:`~pymongo.collection.Collection.bulk_write`. + + :param filter: A query that matches the document to replace. + :param replacement: The new document. + :param upsert: If ``True``, perform an insert if no documents + match the filter. + :param collation: An instance of + :class:`~pymongo.collation.Collation`. + :param hint: An index to use to support the query + predicate specified either by its string name, or in the same + format as passed to + :meth:`~pymongo.collection.Collection.create_index` (e.g. + ``[('field', ASCENDING)]``). This option is only supported on + MongoDB 4.2 and above. + + .. versionchanged:: 3.11 + Added the ``hint`` option. + .. versionchanged:: 3.5 + Added the ``collation`` option. 
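+
+        A brief, illustrative example (field names are hypothetical)::
+
+            collection.bulk_write(
+                [ReplaceOne({"_id": 1}, {"_id": 1, "name": "new"}, upsert=True)]
+            )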
+        """
+        if filter is not None:
+            validate_is_mapping("filter", filter)
+        if upsert is not None:
+            validate_boolean("upsert", upsert)
+        if hint is not None and not isinstance(hint, str):
+            self._hint: Union[str, dict[str, Any], None] = helpers._index_document(hint)
+        else:
+            self._hint = hint
+        self._filter = filter
+        self._doc = replacement
+        self._upsert = upsert
+        self._collation = collation
+
+    def _add_to_bulk(self, bulkobj: _Bulk) -> None:
+        """Add this operation to the _Bulk instance `bulkobj`."""
+        bulkobj.add_replace(
+            self._filter,
+            self._doc,
+            self._upsert,
+            collation=validate_collation_or_none(self._collation),
+            hint=self._hint,
+        )
+
+    def __eq__(self, other: Any) -> bool:
+        if type(other) == type(self):
+            return (
+                other._filter,
+                other._doc,
+                other._upsert,
+                other._collation,
+                other._hint,
+            ) == (
+                self._filter,
+                self._doc,
+                self._upsert,
+                self._collation,
+                self._hint,
+            )
+        return NotImplemented
+
+    def __ne__(self, other: Any) -> bool:
+        return not self == other
+
+    def __repr__(self) -> str:
+        return "{}({!r}, {!r}, {!r}, {!r}, {!r})".format(
+            self.__class__.__name__,
+            self._filter,
+            self._doc,
+            self._upsert,
+            self._collation,
+            self._hint,
+        )
+
+
+class _UpdateOp:
+    """Private base class for update operations."""
+
+    __slots__ = ("_filter", "_doc", "_upsert", "_collation", "_array_filters", "_hint")
+
+    def __init__(
+        self,
+        filter: Mapping[str, Any],
+        doc: Union[Mapping[str, Any], _Pipeline],
+        upsert: bool,
+        collation: Optional[_CollationIn],
+        array_filters: Optional[list[Mapping[str, Any]]],
+        hint: Optional[_IndexKeyHint],
+    ):
+        if filter is not None:
+            validate_is_mapping("filter", filter)
+        if upsert is not None:
+            validate_boolean("upsert", upsert)
+        if array_filters is not None:
+            validate_list("array_filters", array_filters)
+        if hint is not None and not isinstance(hint, str):
+            self._hint: Union[str, dict[str, Any], None] = helpers._index_document(hint)
+        else:
+            self._hint = hint
+
+        self._filter = filter
+        self._doc = doc
+        self._upsert = upsert
+        self._collation = collation
+        self._array_filters = array_filters
+
+    def __eq__(self, other: object) -> bool:
+        if isinstance(other, type(self)):
+            return (
+                other._filter,
+                other._doc,
+                other._upsert,
+                other._collation,
+                other._array_filters,
+                other._hint,
+            ) == (
+                self._filter,
+                self._doc,
+                self._upsert,
+                self._collation,
+                self._array_filters,
+                self._hint,
+            )
+        return NotImplemented
+
+    def __repr__(self) -> str:
+        return "{}({!r}, {!r}, {!r}, {!r}, {!r}, {!r})".format(
+            self.__class__.__name__,
+            self._filter,
+            self._doc,
+            self._upsert,
+            self._collation,
+            self._array_filters,
+            self._hint,
+        )
+
+
+class UpdateOne(_UpdateOp):
+    """Represents an update_one operation."""
+
+    __slots__ = ()
+
+    def __init__(
+        self,
+        filter: Mapping[str, Any],
+        update: Union[Mapping[str, Any], _Pipeline],
+        upsert: bool = False,
+        collation: Optional[_CollationIn] = None,
+        array_filters: Optional[list[Mapping[str, Any]]] = None,
+        hint: Optional[_IndexKeyHint] = None,
+    ) -> None:
+        """Represents an update_one operation.
+
+        For use with :meth:`~pymongo.collection.Collection.bulk_write`.
+
+        :param filter: A query that matches the document to update.
+        :param update: The modifications to apply.
+        :param upsert: If ``True``, perform an insert if no documents
+            match the filter.
+        :param collation: An instance of
+            :class:`~pymongo.collation.Collation`.
+        :param array_filters: A list of filters specifying which
+            array elements an update should apply.
+ :param hint: An index to use to support the query + predicate specified either by its string name, or in the same + format as passed to + :meth:`~pymongo.collection.Collection.create_index` (e.g. + ``[('field', ASCENDING)]``). This option is only supported on + MongoDB 4.2 and above. + + .. versionchanged:: 3.11 + Added the `hint` option. + .. versionchanged:: 3.9 + Added the ability to accept a pipeline as the `update`. + .. versionchanged:: 3.6 + Added the `array_filters` option. + .. versionchanged:: 3.5 + Added the `collation` option. + """ + super().__init__(filter, update, upsert, collation, array_filters, hint) + + def _add_to_bulk(self, bulkobj: _Bulk) -> None: + """Add this operation to the _Bulk instance `bulkobj`.""" + bulkobj.add_update( + self._filter, + self._doc, + False, + self._upsert, + collation=validate_collation_or_none(self._collation), + array_filters=self._array_filters, + hint=self._hint, + ) + + +class UpdateMany(_UpdateOp): + """Represents an update_many operation.""" + + __slots__ = () + + def __init__( + self, + filter: Mapping[str, Any], + update: Union[Mapping[str, Any], _Pipeline], + upsert: bool = False, + collation: Optional[_CollationIn] = None, + array_filters: Optional[list[Mapping[str, Any]]] = None, + hint: Optional[_IndexKeyHint] = None, + ) -> None: + """Create an UpdateMany instance. + + For use with :meth:`~pymongo.collection.Collection.bulk_write`. + + :param filter: A query that matches the documents to update. + :param update: The modifications to apply. + :param upsert: If ``True``, perform an insert if no documents + match the filter. + :param collation: An instance of + :class:`~pymongo.collation.Collation`. + :param array_filters: A list of filters specifying which + array elements an update should apply. + :param hint: An index to use to support the query + predicate specified either by its string name, or in the same + format as passed to + :meth:`~pymongo.collection.Collection.create_index` (e.g. + ``[('field', ASCENDING)]``). This option is only supported on + MongoDB 4.2 and above. + + .. versionchanged:: 3.11 + Added the `hint` option. + .. versionchanged:: 3.9 + Added the ability to accept a pipeline as the `update`. + .. versionchanged:: 3.6 + Added the `array_filters` option. + .. versionchanged:: 3.5 + Added the `collation` option. + """ + super().__init__(filter, update, upsert, collation, array_filters, hint) + + def _add_to_bulk(self, bulkobj: _Bulk) -> None: + """Add this operation to the _Bulk instance `bulkobj`.""" + bulkobj.add_update( + self._filter, + self._doc, + True, + self._upsert, + collation=validate_collation_or_none(self._collation), + array_filters=self._array_filters, + hint=self._hint, + ) + + +class IndexModel: + """Represents an index to create.""" + + __slots__ = ("__document",) + + def __init__(self, keys: _IndexKeyHint, **kwargs: Any) -> None: + """Create an Index instance. + + For use with :meth:`~pymongo.collection.Collection.create_indexes`. + + Takes either a single key or a list containing (key, direction) pairs + or keys. If no direction is given, :data:`~pymongo.ASCENDING` will + be assumed. + The key(s) must be an instance of :class:`str`, and the direction(s) must + be one of (:data:`~pymongo.ASCENDING`, :data:`~pymongo.DESCENDING`, + :data:`~pymongo.GEO2D`, :data:`~pymongo.GEOSPHERE`, + :data:`~pymongo.HASHED`, :data:`~pymongo.TEXT`). + + Valid options include, but are not limited to: + + - `name`: custom name to use for this index - if none is + given, a name will be generated. 
+ - `unique`: if ``True``, creates a uniqueness constraint on the index. + - `background`: if ``True``, this index should be created in the + background. + - `sparse`: if ``True``, omit from the index any documents that lack + the indexed field. + - `bucketSize`: for use with geoHaystack indexes. + Number of documents to group together within a certain proximity + to a given longitude and latitude. + - `min`: minimum value for keys in a :data:`~pymongo.GEO2D` + index. + - `max`: maximum value for keys in a :data:`~pymongo.GEO2D` + index. + - `expireAfterSeconds`: Used to create an expiring (TTL) + collection. MongoDB will automatically delete documents from + this collection after seconds. The indexed field must + be a UTC datetime or the data will not expire. + - `partialFilterExpression`: A document that specifies a filter for + a partial index. + - `collation`: An instance of :class:`~pymongo.collation.Collation` + that specifies the collation to use. + - `wildcardProjection`: Allows users to include or exclude specific + field paths from a `wildcard index`_ using the { "$**" : 1} key + pattern. Requires MongoDB >= 4.2. + - `hidden`: if ``True``, this index will be hidden from the query + planner and will not be evaluated as part of query plan + selection. Requires MongoDB >= 4.4. + + See the MongoDB documentation for a full list of supported options by + server version. + + :param keys: a single key or a list containing (key, direction) pairs + or keys specifying the index to create. + :param kwargs: any additional index creation + options (see the above list) should be passed as keyword + arguments. + + .. versionchanged:: 3.11 + Added the ``hidden`` option. + .. versionchanged:: 3.2 + Added the ``partialFilterExpression`` option to support partial + indexes. + + .. _wildcard index: https://mongodb.com/docs/master/core/index-wildcard/ + """ + keys = _index_list(keys) + if kwargs.get("name") is None: + kwargs["name"] = _gen_index_name(keys) + kwargs["key"] = _index_document(keys) + collation = validate_collation_or_none(kwargs.pop("collation", None)) + self.__document = kwargs + if collation is not None: + self.__document["collation"] = collation + + @property + def document(self) -> dict[str, Any]: + """An index document suitable for passing to the createIndexes + command. + """ + return self.__document + + +class SearchIndexModel: + """Represents a search index to create.""" + + __slots__ = ("__document",) + + def __init__( + self, + definition: Mapping[str, Any], + name: Optional[str] = None, + type: Optional[str] = None, + **kwargs: Any, + ) -> None: + """Create a Search Index instance. + + For use with :meth:`~pymongo.collection.Collection.create_search_index` and :meth:`~pymongo.collection.Collection.create_search_indexes`. + + :param definition: The definition for this index. + :param name: The name for this index, if present. + :param type: The type for this index which defaults to "search". Alternative values include "vectorSearch". + :param kwargs: Keyword arguments supplying any additional options. + + .. note:: Search indexes require a MongoDB server version 7.0+ Atlas cluster. + .. versionadded:: 4.5 + .. versionchanged:: 4.7 + Added the type and kwargs arguments. 
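+
+        A minimal, illustrative example (the dynamic-mapping definition is
+        just one possible shape)::
+
+            SearchIndexModel({"mappings": {"dynamic": True}}, name="default")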
+        """
+        self.__document: dict[str, Any] = {}
+        if name is not None:
+            self.__document["name"] = name
+        self.__document["definition"] = definition
+        if type is not None:
+            self.__document["type"] = type
+        self.__document.update(kwargs)
+
+    @property
+    def document(self) -> Mapping[str, Any]:
+        """The document for this index."""
+        return self.__document
diff --git a/venv/Lib/site-packages/pymongo/periodic_executor.py b/venv/Lib/site-packages/pymongo/periodic_executor.py
new file mode 100644
index 00000000..9e9ead61
--- /dev/null
+++ b/venv/Lib/site-packages/pymongo/periodic_executor.py
@@ -0,0 +1,200 @@
+# Copyright 2014-present MongoDB, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you
+# may not use this file except in compliance with the License. You
+# may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied. See the License for the specific language governing
+# permissions and limitations under the License.
+
+"""Run a target function on a background thread."""
+
+from __future__ import annotations
+
+import sys
+import threading
+import time
+import weakref
+from typing import Any, Callable, Optional
+
+from pymongo.lock import _create_lock
+
+
+class PeriodicExecutor:
+    def __init__(
+        self,
+        interval: float,
+        min_interval: float,
+        target: Callable[[], bool],
+        name: Optional[str] = None,
+    ):
+        """Run a target function periodically on a background thread.
+
+        If the target's return value is false, the executor stops.
+
+        :param interval: Seconds between calls to `target`.
+        :param min_interval: Minimum seconds between calls if `wake` is
+            called very often.
+        :param target: A function.
+        :param name: A name to give the underlying thread.
+        """
+        # threading.Event and its internal condition variable are expensive
+        # in Python 2, see PYTHON-983. Use a boolean to know when to wake.
+        # The executor's design is constrained by several Python issues, see
+        # "periodic_executor.rst" in this repository.
+        self._event = False
+        self._interval = interval
+        self._min_interval = min_interval
+        self._target = target
+        self._stopped = False
+        self._thread: Optional[threading.Thread] = None
+        self._name = name
+        self._skip_sleep = False
+        self._thread_will_exit = False
+        self._lock = _create_lock()
+
+    def __repr__(self) -> str:
+        return f"<{self.__class__.__name__}(name={self._name}) object at 0x{id(self):x}>"
+
+    def open(self) -> None:
+        """Start. Multiple calls have no effect.
+
+        Not safe to call from multiple threads at once.
+        """
+        with self._lock:
+            if self._thread_will_exit:
+                # If the background thread has read self._stopped as True
+                # there is a chance that it has not yet exited. The call to
+                # join should not block indefinitely because there is no
+                # other work done outside the while loop in self._run.
+                try:
+                    assert self._thread is not None
+                    self._thread.join()
+                except ReferenceError:
+                    # Thread terminated.
+                    pass
+            self._thread_will_exit = False
+            self._stopped = False
+        started: Any = False
+        try:
+            started = self._thread and self._thread.is_alive()
+        except ReferenceError:
+            # Thread terminated.
+ pass + + if not started: + thread = threading.Thread(target=self._run, name=self._name) + thread.daemon = True + self._thread = weakref.proxy(thread) + _register_executor(self) + # Mitigation to RuntimeError firing when thread starts on shutdown + # https://github.com/python/cpython/issues/114570 + try: + thread.start() + except RuntimeError as e: + if "interpreter shutdown" in str(e) or sys.is_finalizing(): + self._thread = None + return + raise + + def close(self, dummy: Any = None) -> None: + """Stop. To restart, call open(). + + The dummy parameter allows an executor's close method to be a weakref + callback; see monitor.py. + """ + self._stopped = True + + def join(self, timeout: Optional[int] = None) -> None: + if self._thread is not None: + try: + self._thread.join(timeout) + except (ReferenceError, RuntimeError): + # Thread already terminated, or not yet started. + pass + + def wake(self) -> None: + """Execute the target function soon.""" + self._event = True + + def update_interval(self, new_interval: int) -> None: + self._interval = new_interval + + def skip_sleep(self) -> None: + self._skip_sleep = True + + def __should_stop(self) -> bool: + with self._lock: + if self._stopped: + self._thread_will_exit = True + return True + return False + + def _run(self) -> None: + while not self.__should_stop(): + try: + if not self._target(): + self._stopped = True + break + except BaseException: + with self._lock: + self._stopped = True + self._thread_will_exit = True + + raise + + if self._skip_sleep: + self._skip_sleep = False + else: + deadline = time.monotonic() + self._interval + while not self._stopped and time.monotonic() < deadline: + time.sleep(self._min_interval) + if self._event: + break # Early wake. + + self._event = False + + +# _EXECUTORS has a weakref to each running PeriodicExecutor. Once started, +# an executor is kept alive by a strong reference from its thread and perhaps +# from other objects. When the thread dies and all other referrers are freed, +# the executor is freed and removed from _EXECUTORS. If any threads are +# running when the interpreter begins to shut down, we try to halt and join +# them to avoid spurious errors. +_EXECUTORS = set() + + +def _register_executor(executor: PeriodicExecutor) -> None: + ref = weakref.ref(executor, _on_executor_deleted) + _EXECUTORS.add(ref) + + +def _on_executor_deleted(ref: weakref.ReferenceType[PeriodicExecutor]) -> None: + _EXECUTORS.remove(ref) + + +def _shutdown_executors() -> None: + if _EXECUTORS is None: + return + + # Copy the set. Stopping threads has the side effect of removing executors. + executors = list(_EXECUTORS) + + # First signal all executors to close... + for ref in executors: + executor = ref() + if executor: + executor.close() + + # ...then try to join them. + for ref in executors: + executor = ref() + if executor: + executor.join(1) + + executor = None diff --git a/venv/Lib/site-packages/pymongo/pool.py b/venv/Lib/site-packages/pymongo/pool.py new file mode 100644 index 00000000..6a8cb54b --- /dev/null +++ b/venv/Lib/site-packages/pymongo/pool.py @@ -0,0 +1,2105 @@ +# Copyright 2011-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you +# may not use this file except in compliance with the License. 
You +# may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. + +from __future__ import annotations + +import collections +import contextlib +import copy +import logging +import os +import platform +import socket +import ssl +import sys +import threading +import time +import weakref +from pathlib import Path +from typing import ( + TYPE_CHECKING, + Any, + Iterator, + Mapping, + MutableMapping, + NoReturn, + Optional, + Sequence, + Union, +) + +import bson +from bson import DEFAULT_CODEC_OPTIONS +from pymongo import __version__, _csot, auth, helpers +from pymongo.client_session import _validate_session_write_concern +from pymongo.common import ( + MAX_BSON_SIZE, + MAX_CONNECTING, + MAX_IDLE_TIME_SEC, + MAX_MESSAGE_SIZE, + MAX_POOL_SIZE, + MAX_WIRE_VERSION, + MAX_WRITE_BATCH_SIZE, + MIN_POOL_SIZE, + ORDERED_TYPES, + WAIT_QUEUE_TIMEOUT, +) +from pymongo.errors import ( # type:ignore[attr-defined] + AutoReconnect, + ConfigurationError, + ConnectionFailure, + DocumentTooLarge, + ExecutionTimeout, + InvalidOperation, + NetworkTimeout, + NotPrimaryError, + OperationFailure, + PyMongoError, + WaitQueueTimeoutError, + _CertificateError, +) +from pymongo.hello import Hello, HelloCompat +from pymongo.helpers import _handle_reauth +from pymongo.lock import _create_lock +from pymongo.logger import ( + _CONNECTION_LOGGER, + _ConnectionStatusMessage, + _debug_log, + _verbose_connection_error_reason, +) +from pymongo.monitoring import ( + ConnectionCheckOutFailedReason, + ConnectionClosedReason, + _EventListeners, +) +from pymongo.network import command, receive_message +from pymongo.read_preferences import ReadPreference +from pymongo.server_api import _add_to_command +from pymongo.server_type import SERVER_TYPE +from pymongo.socket_checker import SocketChecker +from pymongo.ssl_support import HAS_SNI, SSLError + +if TYPE_CHECKING: + from bson import CodecOptions + from bson.objectid import ObjectId + from pymongo.auth import MongoCredential, _AuthContext + from pymongo.client_session import ClientSession + from pymongo.compression_support import ( + CompressionSettings, + SnappyContext, + ZlibContext, + ZstdContext, + ) + from pymongo.driver_info import DriverInfo + from pymongo.message import _OpMsg, _OpReply + from pymongo.mongo_client import MongoClient, _MongoClientErrorHandler + from pymongo.pyopenssl_context import SSLContext, _sslConn + from pymongo.read_concern import ReadConcern + from pymongo.read_preferences import _ServerMode + from pymongo.server_api import ServerApi + from pymongo.typings import ClusterTime, _Address, _CollationIn + from pymongo.write_concern import WriteConcern + +try: + from fcntl import F_GETFD, F_SETFD, FD_CLOEXEC, fcntl + + def _set_non_inheritable_non_atomic(fd: int) -> None: + """Set the close-on-exec flag on the given file descriptor.""" + flags = fcntl(fd, F_GETFD) + fcntl(fd, F_SETFD, flags | FD_CLOEXEC) + +except ImportError: + # Windows, various platforms we don't claim to support + # (Jython, IronPython, ..), systems that don't provide + # everything we need from fcntl, etc. 
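+    # Fall back to a no-op: on these platforms descriptors simply stay
+    # inheritable across exec.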
+ def _set_non_inheritable_non_atomic(fd: int) -> None: # noqa: ARG001 + """Dummy function for platforms that don't provide fcntl.""" + + +_MAX_TCP_KEEPIDLE = 120 +_MAX_TCP_KEEPINTVL = 10 +_MAX_TCP_KEEPCNT = 9 + +if sys.platform == "win32": + try: + import _winreg as winreg + except ImportError: + import winreg + + def _query(key, name, default): + try: + value, _ = winreg.QueryValueEx(key, name) + # Ensure the value is a number or raise ValueError. + return int(value) + except (OSError, ValueError): + # QueryValueEx raises OSError when the key does not exist (i.e. + # the system is using the Windows default value). + return default + + try: + with winreg.OpenKey( + winreg.HKEY_LOCAL_MACHINE, r"SYSTEM\CurrentControlSet\Services\Tcpip\Parameters" + ) as key: + _WINDOWS_TCP_IDLE_MS = _query(key, "KeepAliveTime", 7200000) + _WINDOWS_TCP_INTERVAL_MS = _query(key, "KeepAliveInterval", 1000) + except OSError: + # We could not check the default values because winreg.OpenKey failed. + # Assume the system is using the default values. + _WINDOWS_TCP_IDLE_MS = 7200000 + _WINDOWS_TCP_INTERVAL_MS = 1000 + + def _set_keepalive_times(sock): + idle_ms = min(_WINDOWS_TCP_IDLE_MS, _MAX_TCP_KEEPIDLE * 1000) + interval_ms = min(_WINDOWS_TCP_INTERVAL_MS, _MAX_TCP_KEEPINTVL * 1000) + if idle_ms < _WINDOWS_TCP_IDLE_MS or interval_ms < _WINDOWS_TCP_INTERVAL_MS: + sock.ioctl(socket.SIO_KEEPALIVE_VALS, (1, idle_ms, interval_ms)) + +else: + + def _set_tcp_option(sock: socket.socket, tcp_option: str, max_value: int) -> None: + if hasattr(socket, tcp_option): + sockopt = getattr(socket, tcp_option) + try: + # PYTHON-1350 - NetBSD doesn't implement getsockopt for + # TCP_KEEPIDLE and friends. Don't attempt to set the + # values there. + default = sock.getsockopt(socket.IPPROTO_TCP, sockopt) + if default > max_value: + sock.setsockopt(socket.IPPROTO_TCP, sockopt, max_value) + except OSError: + pass + + def _set_keepalive_times(sock: socket.socket) -> None: + _set_tcp_option(sock, "TCP_KEEPIDLE", _MAX_TCP_KEEPIDLE) + _set_tcp_option(sock, "TCP_KEEPINTVL", _MAX_TCP_KEEPINTVL) + _set_tcp_option(sock, "TCP_KEEPCNT", _MAX_TCP_KEEPCNT) + + +_METADATA: dict[str, Any] = {"driver": {"name": "PyMongo", "version": __version__}} + +if sys.platform.startswith("linux"): + # platform.linux_distribution was deprecated in Python 3.5 + # and removed in Python 3.8. Starting in Python 3.5 it + # raises DeprecationWarning + # DeprecationWarning: dist() and linux_distribution() functions are deprecated in Python 3.5 + _name = platform.system() + _METADATA["os"] = { + "type": _name, + "name": _name, + "architecture": platform.machine(), + # Kernel version (e.g. 4.4.0-17-generic). + "version": platform.release(), + } +elif sys.platform == "darwin": + _METADATA["os"] = { + "type": platform.system(), + "name": platform.system(), + "architecture": platform.machine(), + # (mac|i|tv)OS(X) version (e.g. 10.11.6) instead of darwin + # kernel version. + "version": platform.mac_ver()[0], + } +elif sys.platform == "win32": + _METADATA["os"] = { + "type": platform.system(), + # "Windows XP", "Windows 7", "Windows 10", etc. + "name": " ".join((platform.system(), platform.release())), + "architecture": platform.machine(), + # Windows patch level (e.g. 5.1.2600-SP3) + "version": "-".join(platform.win32_ver()[1:3]), + } +elif sys.platform.startswith("java"): + _name, _ver, _arch = platform.java_ver()[-1] + _METADATA["os"] = { + # Linux, Windows 7, Mac OS X, etc. + "type": _name, + "name": _name, + # x86, x86_64, AMD64, etc. 
+ "architecture": _arch, + # Linux kernel version, OSX version, etc. + "version": _ver, + } +else: + # Get potential alias (e.g. SunOS 5.11 becomes Solaris 2.11) + _aliased = platform.system_alias(platform.system(), platform.release(), platform.version()) + _METADATA["os"] = { + "type": platform.system(), + "name": " ".join([part for part in _aliased[:2] if part]), + "architecture": platform.machine(), + "version": _aliased[2], + } + +if platform.python_implementation().startswith("PyPy"): + _METADATA["platform"] = " ".join( + ( + platform.python_implementation(), + ".".join(map(str, sys.pypy_version_info)), # type: ignore + "(Python %s)" % ".".join(map(str, sys.version_info)), + ) + ) +elif sys.platform.startswith("java"): + _METADATA["platform"] = " ".join( + ( + platform.python_implementation(), + ".".join(map(str, sys.version_info)), + "(%s)" % " ".join((platform.system(), platform.release())), + ) + ) +else: + _METADATA["platform"] = " ".join( + (platform.python_implementation(), ".".join(map(str, sys.version_info))) + ) + +DOCKER_ENV_PATH = "/.dockerenv" +ENV_VAR_K8S = "KUBERNETES_SERVICE_HOST" + +RUNTIME_NAME_DOCKER = "docker" +ORCHESTRATOR_NAME_K8S = "kubernetes" + + +def get_container_env_info() -> dict[str, str]: + """Returns the runtime and orchestrator of a container. + If neither value is present, the metadata client.env.container field will be omitted.""" + container = {} + + if Path(DOCKER_ENV_PATH).exists(): + container["runtime"] = RUNTIME_NAME_DOCKER + if os.getenv(ENV_VAR_K8S): + container["orchestrator"] = ORCHESTRATOR_NAME_K8S + + return container + + +def _is_lambda() -> bool: + if os.getenv("AWS_LAMBDA_RUNTIME_API"): + return True + env = os.getenv("AWS_EXECUTION_ENV") + if env: + return env.startswith("AWS_Lambda_") + return False + + +def _is_azure_func() -> bool: + return bool(os.getenv("FUNCTIONS_WORKER_RUNTIME")) + + +def _is_gcp_func() -> bool: + return bool(os.getenv("K_SERVICE") or os.getenv("FUNCTION_NAME")) + + +def _is_vercel() -> bool: + return bool(os.getenv("VERCEL")) + + +def _is_faas() -> bool: + return _is_lambda() or _is_azure_func() or _is_gcp_func() or _is_vercel() + + +def _getenv_int(key: str) -> Optional[int]: + """Like os.getenv but returns an int, or None if the value is missing/malformed.""" + val = os.getenv(key) + if not val: + return None + try: + return int(val) + except ValueError: + return None + + +def _metadata_env() -> dict[str, Any]: + env: dict[str, Any] = {} + container = get_container_env_info() + if container: + env["container"] = container + # Skip if multiple (or no) envs are matched. 
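+    # count(True) over all detectors requires exactly one FaaS match;
+    # anything else is ambiguous, so only the container info (if any)
+    # collected above is reported.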
+ if (_is_lambda(), _is_azure_func(), _is_gcp_func(), _is_vercel()).count(True) != 1: + return env + if _is_lambda(): + env["name"] = "aws.lambda" + region = os.getenv("AWS_REGION") + if region: + env["region"] = region + memory_mb = _getenv_int("AWS_LAMBDA_FUNCTION_MEMORY_SIZE") + if memory_mb is not None: + env["memory_mb"] = memory_mb + elif _is_azure_func(): + env["name"] = "azure.func" + elif _is_gcp_func(): + env["name"] = "gcp.func" + region = os.getenv("FUNCTION_REGION") + if region: + env["region"] = region + memory_mb = _getenv_int("FUNCTION_MEMORY_MB") + if memory_mb is not None: + env["memory_mb"] = memory_mb + timeout_sec = _getenv_int("FUNCTION_TIMEOUT_SEC") + if timeout_sec is not None: + env["timeout_sec"] = timeout_sec + elif _is_vercel(): + env["name"] = "vercel" + region = os.getenv("VERCEL_REGION") + if region: + env["region"] = region + return env + + +_MAX_METADATA_SIZE = 512 + + +# See: https://github.com/mongodb/specifications/blob/5112bcc/source/mongodb-handshake/handshake.rst#limitations +def _truncate_metadata(metadata: MutableMapping[str, Any]) -> None: + """Perform metadata truncation.""" + if len(bson.encode(metadata)) <= _MAX_METADATA_SIZE: + return + # 1. Omit fields from env except env.name. + env_name = metadata.get("env", {}).get("name") + if env_name: + metadata["env"] = {"name": env_name} + if len(bson.encode(metadata)) <= _MAX_METADATA_SIZE: + return + # 2. Omit fields from os except os.type. + os_type = metadata.get("os", {}).get("type") + if os_type: + metadata["os"] = {"type": os_type} + if len(bson.encode(metadata)) <= _MAX_METADATA_SIZE: + return + # 3. Omit the env document entirely. + metadata.pop("env", None) + encoded_size = len(bson.encode(metadata)) + if encoded_size <= _MAX_METADATA_SIZE: + return + # 4. Truncate platform. + overflow = encoded_size - _MAX_METADATA_SIZE + plat = metadata.get("platform", "") + if plat: + plat = plat[:-overflow] + if plat: + metadata["platform"] = plat + else: + metadata.pop("platform", None) + + +# If the first getaddrinfo call of this interpreter's life is on a thread, +# while the main thread holds the import lock, getaddrinfo deadlocks trying +# to import the IDNA codec. Import it here, where presumably we're on the +# main thread, to avoid the deadlock. See PYTHON-607. +"foo".encode("idna") + + +def _raise_connection_failure( + address: Any, + error: Exception, + msg_prefix: Optional[str] = None, + timeout_details: Optional[dict[str, float]] = None, +) -> NoReturn: + """Convert a socket.error to ConnectionFailure and raise it.""" + host, port = address + # If connecting to a Unix socket, port will be None. + if port is not None: + msg = "%s:%d: %s" % (host, port, error) + else: + msg = f"{host}: {error}" + if msg_prefix: + msg = msg_prefix + msg + if "configured timeouts" not in msg: + msg += format_timeout_details(timeout_details) + if isinstance(error, socket.timeout): + raise NetworkTimeout(msg) from error + elif isinstance(error, SSLError) and "timed out" in str(error): + # Eventlet does not distinguish TLS network timeouts from other + # SSLErrors (https://github.com/eventlet/eventlet/issues/692). + # Luckily, we can work around this limitation because the phrase + # 'timed out' appears in all the timeout related SSLErrors raised. 
+ raise NetworkTimeout(msg) from error + else: + raise AutoReconnect(msg) from error + + +def _cond_wait(condition: threading.Condition, deadline: Optional[float]) -> bool: + timeout = deadline - time.monotonic() if deadline else None + return condition.wait(timeout) + + +def _get_timeout_details(options: PoolOptions) -> dict[str, float]: + details = {} + timeout = _csot.get_timeout() + socket_timeout = options.socket_timeout + connect_timeout = options.connect_timeout + if timeout: + details["timeoutMS"] = timeout * 1000 + if socket_timeout and not timeout: + details["socketTimeoutMS"] = socket_timeout * 1000 + if connect_timeout: + details["connectTimeoutMS"] = connect_timeout * 1000 + return details + + +def format_timeout_details(details: Optional[dict[str, float]]) -> str: + result = "" + if details: + result += " (configured timeouts:" + for timeout in ["socketTimeoutMS", "timeoutMS", "connectTimeoutMS"]: + if timeout in details: + result += f" {timeout}: {details[timeout]}ms," + result = result[:-1] + result += ")" + return result + + +class PoolOptions: + """Read only connection pool options for a MongoClient. + + Should not be instantiated directly by application developers. Access + a client's pool options via + :attr:`~pymongo.client_options.ClientOptions.pool_options` instead:: + + pool_opts = client.options.pool_options + pool_opts.max_pool_size + pool_opts.min_pool_size + + """ + + __slots__ = ( + "__max_pool_size", + "__min_pool_size", + "__max_idle_time_seconds", + "__connect_timeout", + "__socket_timeout", + "__wait_queue_timeout", + "__ssl_context", + "__tls_allow_invalid_hostnames", + "__event_listeners", + "__appname", + "__driver", + "__metadata", + "__compression_settings", + "__max_connecting", + "__pause_enabled", + "__server_api", + "__load_balanced", + "__credentials", + ) + + def __init__( + self, + max_pool_size: int = MAX_POOL_SIZE, + min_pool_size: int = MIN_POOL_SIZE, + max_idle_time_seconds: Optional[int] = MAX_IDLE_TIME_SEC, + connect_timeout: Optional[float] = None, + socket_timeout: Optional[float] = None, + wait_queue_timeout: Optional[int] = WAIT_QUEUE_TIMEOUT, + ssl_context: Optional[SSLContext] = None, + tls_allow_invalid_hostnames: bool = False, + event_listeners: Optional[_EventListeners] = None, + appname: Optional[str] = None, + driver: Optional[DriverInfo] = None, + compression_settings: Optional[CompressionSettings] = None, + max_connecting: int = MAX_CONNECTING, + pause_enabled: bool = True, + server_api: Optional[ServerApi] = None, + load_balanced: Optional[bool] = None, + credentials: Optional[MongoCredential] = None, + ): + self.__max_pool_size = max_pool_size + self.__min_pool_size = min_pool_size + self.__max_idle_time_seconds = max_idle_time_seconds + self.__connect_timeout = connect_timeout + self.__socket_timeout = socket_timeout + self.__wait_queue_timeout = wait_queue_timeout + self.__ssl_context = ssl_context + self.__tls_allow_invalid_hostnames = tls_allow_invalid_hostnames + self.__event_listeners = event_listeners + self.__appname = appname + self.__driver = driver + self.__compression_settings = compression_settings + self.__max_connecting = max_connecting + self.__pause_enabled = pause_enabled + self.__server_api = server_api + self.__load_balanced = load_balanced + self.__credentials = credentials + self.__metadata = copy.deepcopy(_METADATA) + if appname: + self.__metadata["application"] = {"name": appname} + + # Combine the "driver" MongoClient option with PyMongo's info, like: + # { + # 'driver': { + # 'name': 
'PyMongo|MyDriver', + # 'version': '4.2.0|1.2.3', + # }, + # 'platform': 'CPython 3.7.0|MyPlatform' + # } + if driver: + if driver.name: + self.__metadata["driver"]["name"] = "{}|{}".format( + _METADATA["driver"]["name"], + driver.name, + ) + if driver.version: + self.__metadata["driver"]["version"] = "{}|{}".format( + _METADATA["driver"]["version"], + driver.version, + ) + if driver.platform: + self.__metadata["platform"] = "{}|{}".format(_METADATA["platform"], driver.platform) + + env = _metadata_env() + if env: + self.__metadata["env"] = env + + _truncate_metadata(self.__metadata) + + @property + def _credentials(self) -> Optional[MongoCredential]: + """A :class:`~pymongo.auth.MongoCredentials` instance or None.""" + return self.__credentials + + @property + def non_default_options(self) -> dict[str, Any]: + """The non-default options this pool was created with. + + Added for CMAP's :class:`PoolCreatedEvent`. + """ + opts = {} + if self.__max_pool_size != MAX_POOL_SIZE: + opts["maxPoolSize"] = self.__max_pool_size + if self.__min_pool_size != MIN_POOL_SIZE: + opts["minPoolSize"] = self.__min_pool_size + if self.__max_idle_time_seconds != MAX_IDLE_TIME_SEC: + assert self.__max_idle_time_seconds is not None + opts["maxIdleTimeMS"] = self.__max_idle_time_seconds * 1000 + if self.__wait_queue_timeout != WAIT_QUEUE_TIMEOUT: + assert self.__wait_queue_timeout is not None + opts["waitQueueTimeoutMS"] = self.__wait_queue_timeout * 1000 + if self.__max_connecting != MAX_CONNECTING: + opts["maxConnecting"] = self.__max_connecting + return opts + + @property + def max_pool_size(self) -> float: + """The maximum allowable number of concurrent connections to each + connected server. Requests to a server will block if there are + `maxPoolSize` outstanding connections to the requested server. + Defaults to 100. Cannot be 0. + + When a server's pool has reached `max_pool_size`, operations for that + server block waiting for a socket to be returned to the pool. If + ``waitQueueTimeoutMS`` is set, a blocked operation will raise + :exc:`~pymongo.errors.ConnectionFailure` after a timeout. + By default ``waitQueueTimeoutMS`` is not set. + """ + return self.__max_pool_size + + @property + def min_pool_size(self) -> int: + """The minimum required number of concurrent connections that the pool + will maintain to each connected server. Default is 0. + """ + return self.__min_pool_size + + @property + def max_connecting(self) -> int: + """The maximum number of concurrent connection creation attempts per + pool. Defaults to 2. + """ + return self.__max_connecting + + @property + def pause_enabled(self) -> bool: + return self.__pause_enabled + + @property + def max_idle_time_seconds(self) -> Optional[int]: + """The maximum number of seconds that a connection can remain + idle in the pool before being removed and replaced. Defaults to + `None` (no limit). + """ + return self.__max_idle_time_seconds + + @property + def connect_timeout(self) -> Optional[float]: + """How long a connection can take to be opened before timing out.""" + return self.__connect_timeout + + @property + def socket_timeout(self) -> Optional[float]: + """How long a send or receive on a socket can take before timing out.""" + return self.__socket_timeout + + @property + def wait_queue_timeout(self) -> Optional[int]: + """How long a thread will wait for a socket from the pool if the pool + has no free sockets. 
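+
+        Corresponds to the ``waitQueueTimeoutMS`` URI option expressed in
+        seconds; ``None`` means a checkout waits indefinitely (unless a CSOT
+        deadline is in effect).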
+ """ + return self.__wait_queue_timeout + + @property + def _ssl_context(self) -> Optional[SSLContext]: + """An SSLContext instance or None.""" + return self.__ssl_context + + @property + def tls_allow_invalid_hostnames(self) -> bool: + """If True skip ssl.match_hostname.""" + return self.__tls_allow_invalid_hostnames + + @property + def _event_listeners(self) -> Optional[_EventListeners]: + """An instance of pymongo.monitoring._EventListeners.""" + return self.__event_listeners + + @property + def appname(self) -> Optional[str]: + """The application name, for sending with hello in server handshake.""" + return self.__appname + + @property + def driver(self) -> Optional[DriverInfo]: + """Driver name and version, for sending with hello in handshake.""" + return self.__driver + + @property + def _compression_settings(self) -> Optional[CompressionSettings]: + return self.__compression_settings + + @property + def metadata(self) -> dict[str, Any]: + """A dict of metadata about the application, driver, os, and platform.""" + return self.__metadata.copy() + + @property + def server_api(self) -> Optional[ServerApi]: + """A pymongo.server_api.ServerApi or None.""" + return self.__server_api + + @property + def load_balanced(self) -> Optional[bool]: + """True if this Pool is configured in load balanced mode.""" + return self.__load_balanced + + +class _CancellationContext: + def __init__(self) -> None: + self._cancelled = False + + def cancel(self) -> None: + """Cancel this context.""" + self._cancelled = True + + @property + def cancelled(self) -> bool: + """Was cancel called?""" + return self._cancelled + + +class Connection: + """Store a connection with some metadata. + + :param conn: a raw connection object + :param pool: a Pool instance + :param address: the server's (host, port) + :param id: the id of this socket in it's pool + """ + + def __init__( + self, conn: Union[socket.socket, _sslConn], pool: Pool, address: tuple[str, int], id: int + ): + self.pool_ref = weakref.ref(pool) + self.conn = conn + self.address = address + self.id = id + self.closed = False + self.last_checkin_time = time.monotonic() + self.performed_handshake = False + self.is_writable: bool = False + self.max_wire_version = MAX_WIRE_VERSION + self.max_bson_size = MAX_BSON_SIZE + self.max_message_size = MAX_MESSAGE_SIZE + self.max_write_batch_size = MAX_WRITE_BATCH_SIZE + self.supports_sessions = False + self.hello_ok: bool = False + self.is_mongos = False + self.op_msg_enabled = False + self.listeners = pool.opts._event_listeners + self.enabled_for_cmap = pool.enabled_for_cmap + self.compression_settings = pool.opts._compression_settings + self.compression_context: Union[SnappyContext, ZlibContext, ZstdContext, None] = None + self.socket_checker: SocketChecker = SocketChecker() + self.oidc_token_gen_id: Optional[int] = None + # Support for mechanism negotiation on the initial handshake. + self.negotiated_mechs: Optional[list[str]] = None + self.auth_ctx: Optional[_AuthContext] = None + + # The pool's generation changes with each reset() so we can close + # sockets created before the last reset. + self.pool_gen = pool.gen + self.generation = self.pool_gen.get_overall() + self.ready = False + self.cancel_context: _CancellationContext = _CancellationContext() + self.opts = pool.opts + self.more_to_come: bool = False + # For load balancer support. 
+ self.service_id: Optional[ObjectId] = None + self.server_connection_id: Optional[int] = None + # When executing a transaction in load balancing mode, this flag is + # set to true to indicate that the session now owns the connection. + self.pinned_txn = False + self.pinned_cursor = False + self.active = False + self.last_timeout = self.opts.socket_timeout + self.connect_rtt = 0.0 + self._client_id = pool._client_id + self.creation_time = time.monotonic() + + def set_conn_timeout(self, timeout: Optional[float]) -> None: + """Cache last timeout to avoid duplicate calls to conn.settimeout.""" + if timeout == self.last_timeout: + return + self.last_timeout = timeout + self.conn.settimeout(timeout) + + def apply_timeout( + self, client: MongoClient, cmd: Optional[MutableMapping[str, Any]] + ) -> Optional[float]: + # CSOT: use remaining timeout when set. + timeout = _csot.remaining() + if timeout is None: + # Reset the socket timeout unless we're performing a streaming monitor check. + if not self.more_to_come: + self.set_conn_timeout(self.opts.socket_timeout) + return None + # RTT validation. + rtt = _csot.get_rtt() + if rtt is None: + rtt = self.connect_rtt + max_time_ms = timeout - rtt + if max_time_ms < 0: + timeout_details = _get_timeout_details(self.opts) + formatted = format_timeout_details(timeout_details) + # CSOT: raise an error without running the command since we know it will time out. + errmsg = f"operation would exceed time limit, remaining timeout:{timeout:.5f} <= network round trip time:{rtt:.5f} {formatted}" + raise ExecutionTimeout( + errmsg, + 50, + {"ok": 0, "errmsg": errmsg, "code": 50}, + self.max_wire_version, + ) + if cmd is not None: + cmd["maxTimeMS"] = int(max_time_ms * 1000) + self.set_conn_timeout(timeout) + return timeout + + def pin_txn(self) -> None: + self.pinned_txn = True + assert not self.pinned_cursor + + def pin_cursor(self) -> None: + self.pinned_cursor = True + assert not self.pinned_txn + + def unpin(self) -> None: + pool = self.pool_ref() + if pool: + pool.checkin(self) + else: + self.close_conn(ConnectionClosedReason.STALE) + + def hello_cmd(self) -> dict[str, Any]: + # Handshake spec requires us to use OP_MSG+hello command for the + # initial handshake in load balanced or stable API mode. + if self.opts.server_api or self.hello_ok or self.opts.load_balanced: + self.op_msg_enabled = True + return {HelloCompat.CMD: 1} + else: + return {HelloCompat.LEGACY_CMD: 1, "helloOk": True} + + def hello(self) -> Hello[dict[str, Any]]: + return self._hello(None, None, None) + + def _hello( + self, + cluster_time: Optional[ClusterTime], + topology_version: Optional[Any], + heartbeat_frequency: Optional[int], + ) -> Hello[dict[str, Any]]: + cmd = self.hello_cmd() + performing_handshake = not self.performed_handshake + awaitable = False + if performing_handshake: + self.performed_handshake = True + cmd["client"] = self.opts.metadata + if self.compression_settings: + cmd["compression"] = self.compression_settings.compressors + if self.opts.load_balanced: + cmd["loadBalanced"] = True + elif topology_version is not None: + cmd["topologyVersion"] = topology_version + assert heartbeat_frequency is not None + cmd["maxAwaitTimeMS"] = int(heartbeat_frequency * 1000) + awaitable = True + # If connect_timeout is None there is no timeout. 
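+            # An awaitable hello may be held by the server for up to
+            # maxAwaitTimeMS, so the socket timeout must cover the await
+            # window in addition to the connect timeout.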
+ if self.opts.connect_timeout: + self.set_conn_timeout(self.opts.connect_timeout + heartbeat_frequency) + + if not performing_handshake and cluster_time is not None: + cmd["$clusterTime"] = cluster_time + + creds = self.opts._credentials + if creds: + if creds.mechanism == "DEFAULT" and creds.username: + cmd["saslSupportedMechs"] = creds.source + "." + creds.username + auth_ctx = auth._AuthContext.from_credentials(creds, self.address) + if auth_ctx: + speculative_authenticate = auth_ctx.speculate_command() + if speculative_authenticate is not None: + cmd["speculativeAuthenticate"] = speculative_authenticate + else: + auth_ctx = None + + if performing_handshake: + start = time.monotonic() + doc = self.command("admin", cmd, publish_events=False, exhaust_allowed=awaitable) + if performing_handshake: + self.connect_rtt = time.monotonic() - start + hello = Hello(doc, awaitable=awaitable) + self.is_writable = hello.is_writable + self.max_wire_version = hello.max_wire_version + self.max_bson_size = hello.max_bson_size + self.max_message_size = hello.max_message_size + self.max_write_batch_size = hello.max_write_batch_size + self.supports_sessions = ( + hello.logical_session_timeout_minutes is not None and hello.is_readable + ) + self.logical_session_timeout_minutes: Optional[int] = hello.logical_session_timeout_minutes + self.hello_ok = hello.hello_ok + self.is_repl = hello.server_type in ( + SERVER_TYPE.RSPrimary, + SERVER_TYPE.RSSecondary, + SERVER_TYPE.RSArbiter, + SERVER_TYPE.RSOther, + SERVER_TYPE.RSGhost, + ) + self.is_standalone = hello.server_type == SERVER_TYPE.Standalone + self.is_mongos = hello.server_type == SERVER_TYPE.Mongos + if performing_handshake and self.compression_settings: + ctx = self.compression_settings.get_compression_context(hello.compressors) + self.compression_context = ctx + + self.op_msg_enabled = True + self.server_connection_id = hello.connection_id + if creds: + self.negotiated_mechs = hello.sasl_supported_mechs + if auth_ctx: + auth_ctx.parse_response(hello) # type:ignore[arg-type] + if auth_ctx.speculate_succeeded(): + self.auth_ctx = auth_ctx + if self.opts.load_balanced: + if not hello.service_id: + raise ConfigurationError( + "Driver attempted to initialize in load balancing mode," + " but the server does not support this mode" + ) + self.service_id = hello.service_id + self.generation = self.pool_gen.get(self.service_id) + return hello + + def _next_reply(self) -> dict[str, Any]: + reply = self.receive_message(None) + self.more_to_come = reply.more_to_come + unpacked_docs = reply.unpack_response() + response_doc = unpacked_docs[0] + helpers._check_command_response(response_doc, self.max_wire_version) + return response_doc + + @_handle_reauth + def command( + self, + dbname: str, + spec: MutableMapping[str, Any], + read_preference: _ServerMode = ReadPreference.PRIMARY, + codec_options: CodecOptions = DEFAULT_CODEC_OPTIONS, + check: bool = True, + allowable_errors: Optional[Sequence[Union[str, int]]] = None, + read_concern: Optional[ReadConcern] = None, + write_concern: Optional[WriteConcern] = None, + parse_write_concern_error: bool = False, + collation: Optional[_CollationIn] = None, + session: Optional[ClientSession] = None, + client: Optional[MongoClient] = None, + retryable_write: bool = False, + publish_events: bool = True, + user_fields: Optional[Mapping[str, Any]] = None, + exhaust_allowed: bool = False, + ) -> dict[str, Any]: + """Execute a command or raise an error. 
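+
+        A minimal internal-usage sketch (hypothetical; assumes ``conn`` is an
+        established :class:`Connection`)::
+
+            reply = conn.command("admin", {"ping": 1})
+            assert reply["ok"] == 1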
+ + :param dbname: name of the database on which to run the command + :param spec: a command document as a dict, SON, or mapping object + :param read_preference: a read preference + :param codec_options: a CodecOptions instance + :param check: raise OperationFailure if there are errors + :param allowable_errors: errors to ignore if `check` is True + :param read_concern: The read concern for this command. + :param write_concern: The write concern for this command. + :param parse_write_concern_error: Whether to parse the + ``writeConcernError`` field in the command response. + :param collation: The collation for this command. + :param session: optional ClientSession instance. + :param client: optional MongoClient for gossipping $clusterTime. + :param retryable_write: True if this command is a retryable write. + :param publish_events: Should we publish events for this command? + :param user_fields: Response fields that should be decoded + using the TypeDecoders from codec_options, passed to + bson._decode_all_selective. + """ + self.validate_session(client, session) + session = _validate_session_write_concern(session, write_concern) + + # Ensure command name remains in first place. + if not isinstance(spec, ORDERED_TYPES): # type:ignore[arg-type] + spec = dict(spec) + + if not (write_concern is None or write_concern.acknowledged or collation is None): + raise ConfigurationError("Collation is unsupported for unacknowledged writes.") + + self.add_server_api(spec) + if session: + session._apply_to(spec, retryable_write, read_preference, self) + self.send_cluster_time(spec, session, client) + listeners = self.listeners if publish_events else None + unacknowledged = bool(write_concern and not write_concern.acknowledged) + if self.op_msg_enabled: + self._raise_if_not_writable(unacknowledged) + try: + return command( + self, + dbname, + spec, + self.is_mongos, + read_preference, + codec_options, + session, + client, + check, + allowable_errors, + self.address, + listeners, + self.max_bson_size, + read_concern, + parse_write_concern_error=parse_write_concern_error, + collation=collation, + compression_ctx=self.compression_context, + use_op_msg=self.op_msg_enabled, + unacknowledged=unacknowledged, + user_fields=user_fields, + exhaust_allowed=exhaust_allowed, + write_concern=write_concern, + ) + except (OperationFailure, NotPrimaryError): + raise + # Catch socket.error, KeyboardInterrupt, etc. and close ourselves. + except BaseException as error: + self._raise_connection_failure(error) + + def send_message(self, message: bytes, max_doc_size: int) -> None: + """Send a raw BSON message or raise ConnectionFailure. + + If a network exception is raised, the socket is closed. + """ + if self.max_bson_size is not None and max_doc_size > self.max_bson_size: + raise DocumentTooLarge( + "BSON document too large (%d bytes) - the connected server " + "supports BSON document sizes up to %d bytes." % (max_doc_size, self.max_bson_size) + ) + + try: + self.conn.sendall(message) + except BaseException as error: + self._raise_connection_failure(error) + + def receive_message(self, request_id: Optional[int]) -> Union[_OpReply, _OpMsg]: + """Receive a raw BSON message or raise ConnectionFailure. + + If any exception is raised, the socket is closed. 
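+
+        Pass ``request_id=None`` to skip request/reply id matching, as the
+        streaming (more-to-come) path in :meth:`_next_reply` does.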
+ """ + try: + return receive_message(self, request_id, self.max_message_size) + except BaseException as error: + self._raise_connection_failure(error) + + def _raise_if_not_writable(self, unacknowledged: bool) -> None: + """Raise NotPrimaryError on unacknowledged write if this socket is not + writable. + """ + if unacknowledged and not self.is_writable: + # Write won't succeed, bail as if we'd received a not primary error. + raise NotPrimaryError("not primary", {"ok": 0, "errmsg": "not primary", "code": 10107}) + + def unack_write(self, msg: bytes, max_doc_size: int) -> None: + """Send unack OP_MSG. + + Can raise ConnectionFailure or InvalidDocument. + + :param msg: bytes, an OP_MSG message. + :param max_doc_size: size in bytes of the largest document in `msg`. + """ + self._raise_if_not_writable(True) + self.send_message(msg, max_doc_size) + + def write_command( + self, request_id: int, msg: bytes, codec_options: CodecOptions + ) -> dict[str, Any]: + """Send "insert" etc. command, returning response as a dict. + + Can raise ConnectionFailure or OperationFailure. + + :param request_id: an int. + :param msg: bytes, the command message. + """ + self.send_message(msg, 0) + reply = self.receive_message(request_id) + result = reply.command_response(codec_options) + + # Raises NotPrimaryError or OperationFailure. + helpers._check_command_response(result, self.max_wire_version) + return result + + def authenticate(self, reauthenticate: bool = False) -> None: + """Authenticate to the server if needed. + + Can raise ConnectionFailure or OperationFailure. + """ + # CMAP spec says to publish the ready event only after authenticating + # the connection. + if reauthenticate: + if self.performed_handshake: + # Existing auth_ctx is stale, remove it. + self.auth_ctx = None + self.ready = False + if not self.ready: + creds = self.opts._credentials + if creds: + auth.authenticate(creds, self, reauthenticate=reauthenticate) + self.ready = True + if self.enabled_for_cmap: + assert self.listeners is not None + duration = time.monotonic() - self.creation_time + self.listeners.publish_connection_ready(self.address, self.id, duration) + if _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _CONNECTION_LOGGER, + clientId=self._client_id, + message=_ConnectionStatusMessage.CONN_READY, + serverHost=self.address[0], + serverPort=self.address[1], + driverConnectionId=self.id, + durationMS=duration, + ) + + def validate_session( + self, client: Optional[MongoClient], session: Optional[ClientSession] + ) -> None: + """Validate this session before use with client. + + Raises error if the client is not the one that created the session. 
+ """ + if session: + if session._client is not client: + raise InvalidOperation("Can only use session with the MongoClient that started it") + + def close_conn(self, reason: Optional[str]) -> None: + """Close this connection with a reason.""" + if self.closed: + return + self._close_conn() + if reason and self.enabled_for_cmap: + assert self.listeners is not None + self.listeners.publish_connection_closed(self.address, self.id, reason) + if _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _CONNECTION_LOGGER, + clientId=self._client_id, + message=_ConnectionStatusMessage.CONN_CLOSED, + serverHost=self.address[0], + serverPort=self.address[1], + driverConnectionId=self.id, + reason=_verbose_connection_error_reason(reason), + error=reason, + ) + + def _close_conn(self) -> None: + """Close this connection.""" + if self.closed: + return + self.closed = True + self.cancel_context.cancel() + # Note: We catch exceptions to avoid spurious errors on interpreter + # shutdown. + try: + self.conn.close() + except Exception: # noqa: S110 + pass + + def conn_closed(self) -> bool: + """Return True if we know socket has been closed, False otherwise.""" + return self.socket_checker.socket_closed(self.conn) + + def send_cluster_time( + self, + command: MutableMapping[str, Any], + session: Optional[ClientSession], + client: Optional[MongoClient], + ) -> None: + """Add $clusterTime.""" + if client: + client._send_cluster_time(command, session) + + def add_server_api(self, command: MutableMapping[str, Any]) -> None: + """Add server_api parameters.""" + if self.opts.server_api: + _add_to_command(command, self.opts.server_api) + + def update_last_checkin_time(self) -> None: + self.last_checkin_time = time.monotonic() + + def update_is_writable(self, is_writable: bool) -> None: + self.is_writable = is_writable + + def idle_time_seconds(self) -> float: + """Seconds since this socket was last checked into its pool.""" + return time.monotonic() - self.last_checkin_time + + def _raise_connection_failure(self, error: BaseException) -> NoReturn: + # Catch *all* exceptions from socket methods and close the socket. In + # regular Python, socket operations only raise socket.error, even if + # the underlying cause was a Ctrl-C: a signal raised during socket.recv + # is expressed as an EINTR error from poll. See internal_select_ex() in + # socketmodule.c. All error codes from poll become socket.error at + # first. Eventually in PyEval_EvalFrameEx the interpreter checks for + # signals and throws KeyboardInterrupt into the current frame on the + # main thread. + # + # But in Gevent and Eventlet, the polling mechanism (epoll, kqueue, + # ..) is called in Python code, which experiences the signal as a + # KeyboardInterrupt from the start, rather than as an initial + # socket.error, so we catch that, close the socket, and reraise it. + # + # The connection closed event will be emitted later in checkin. + if self.ready: + reason = None + else: + reason = ConnectionClosedReason.ERROR + self.close_conn(reason) + # SSLError from PyOpenSSL inherits directly from Exception. 
+ if isinstance(error, (IOError, OSError, SSLError)): + details = _get_timeout_details(self.opts) + _raise_connection_failure(self.address, error, timeout_details=details) + else: + raise + + def __eq__(self, other: Any) -> bool: + return self.conn == other.conn + + def __ne__(self, other: Any) -> bool: + return not self == other + + def __hash__(self) -> int: + return hash(self.conn) + + def __repr__(self) -> str: + return "Connection({}){} at {}".format( + repr(self.conn), + self.closed and " CLOSED" or "", + id(self), + ) + + +def _create_connection(address: _Address, options: PoolOptions) -> socket.socket: + """Given (host, port) and PoolOptions, connect and return a socket object. + + Can raise socket.error. + + This is a modified version of create_connection from CPython >= 2.7. + """ + host, port = address + + # Check if dealing with a unix domain socket + if host.endswith(".sock"): + if not hasattr(socket, "AF_UNIX"): + raise ConnectionFailure("UNIX-sockets are not supported on this system") + sock = socket.socket(socket.AF_UNIX) + # SOCK_CLOEXEC not supported for Unix sockets. + _set_non_inheritable_non_atomic(sock.fileno()) + try: + sock.connect(host) + return sock + except OSError: + sock.close() + raise + + # Don't try IPv6 if we don't support it. Also skip it if host + # is 'localhost' (::1 is fine). Avoids slow connect issues + # like PYTHON-356. + family = socket.AF_INET + if socket.has_ipv6 and host != "localhost": + family = socket.AF_UNSPEC + + err = None + for res in socket.getaddrinfo(host, port, family, socket.SOCK_STREAM): + af, socktype, proto, dummy, sa = res + # SOCK_CLOEXEC was new in CPython 3.2, and only available on a limited + # number of platforms (newer Linux and *BSD). Starting with CPython 3.4 + # all file descriptors are created non-inheritable. See PEP 446. + try: + sock = socket.socket(af, socktype | getattr(socket, "SOCK_CLOEXEC", 0), proto) + except OSError: + # Can SOCK_CLOEXEC be defined even if the kernel doesn't support + # it? + sock = socket.socket(af, socktype, proto) + # Fallback when SOCK_CLOEXEC isn't available. + _set_non_inheritable_non_atomic(sock.fileno()) + try: + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + # CSOT: apply timeout to socket connect. + timeout = _csot.remaining() + if timeout is None: + timeout = options.connect_timeout + elif timeout <= 0: + raise socket.timeout("timed out") + sock.settimeout(timeout) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, True) + _set_keepalive_times(sock) + sock.connect(sa) + return sock + except OSError as e: + err = e + sock.close() + + if err is not None: + raise err + else: + # This likely means we tried to connect to an IPv6 only + # host with an OS/kernel or Python interpreter that doesn't + # support IPv6. The test case is Jython2.5.1 which doesn't + # support IPv6 at all. + raise OSError("getaddrinfo failed") + + +def _configured_socket(address: _Address, options: PoolOptions) -> Union[socket.socket, _sslConn]: + """Given (host, port) and PoolOptions, return a configured socket. + + Can raise socket.error, ConnectionFailure, or _CertificateError. + + Sets socket's SSL and timeout options. + """ + sock = _create_connection(address, options) + ssl_context = options._ssl_context + + if ssl_context is None: + sock.settimeout(options.socket_timeout) + return sock + + host = address[0] + try: + # We have to pass hostname / ip address to wrap_socket + # to use SSLContext.check_hostname. 
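+        # If check_hostname cannot be used (or is disabled), the explicit
+        # ssl.match_hostname fallback below performs the hostname check
+        # instead.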
+ if HAS_SNI: + ssl_sock = ssl_context.wrap_socket(sock, server_hostname=host) + else: + ssl_sock = ssl_context.wrap_socket(sock) + except _CertificateError: + sock.close() + # Raise _CertificateError directly like we do after match_hostname + # below. + raise + except (OSError, SSLError) as exc: + sock.close() + # We raise AutoReconnect for transient and permanent SSL handshake + # failures alike. Permanent handshake failures, like protocol + # mismatch, will be turned into ServerSelectionTimeoutErrors later. + details = _get_timeout_details(options) + _raise_connection_failure(address, exc, "SSL handshake failed: ", timeout_details=details) + if ( + ssl_context.verify_mode + and not ssl_context.check_hostname + and not options.tls_allow_invalid_hostnames + ): + try: + ssl.match_hostname(ssl_sock.getpeercert(), hostname=host) + except _CertificateError: + ssl_sock.close() + raise + + ssl_sock.settimeout(options.socket_timeout) + return ssl_sock + + +class _PoolClosedError(PyMongoError): + """Internal error raised when a thread tries to get a connection from a + closed pool. + """ + + +class _PoolGeneration: + def __init__(self) -> None: + # Maps service_id to generation. + self._generations: dict[ObjectId, int] = collections.defaultdict(int) + # Overall pool generation. + self._generation = 0 + + def get(self, service_id: Optional[ObjectId]) -> int: + """Get the generation for the given service_id.""" + if service_id is None: + return self._generation + return self._generations[service_id] + + def get_overall(self) -> int: + """Get the Pool's overall generation.""" + return self._generation + + def inc(self, service_id: Optional[ObjectId]) -> None: + """Increment the generation for the given service_id.""" + self._generation += 1 + if service_id is None: + for service_id in self._generations: + self._generations[service_id] += 1 + else: + self._generations[service_id] += 1 + + def stale(self, gen: int, service_id: Optional[ObjectId]) -> bool: + """Return if the given generation for a given service_id is stale.""" + return gen != self.get(service_id) + + +class PoolState: + PAUSED = 1 + READY = 2 + CLOSED = 3 + + +# Do *not* explicitly inherit from object or Jython won't call __del__ +# http://bugs.jython.org/issue1057 +class Pool: + def __init__( + self, + address: _Address, + options: PoolOptions, + handshake: bool = True, + client_id: Optional[ObjectId] = None, + ): + """ + :param address: a (hostname, port) tuple + :param options: a PoolOptions instance + :param handshake: whether to call hello for each new Connection + """ + if options.pause_enabled: + self.state = PoolState.PAUSED + else: + self.state = PoolState.READY + # Check a socket's health with socket_closed() every once in a while. + # Can override for testing: 0 to always check, None to never check. + self._check_interval_seconds = 1 + # LIFO pool. Sockets are ordered on idle time. Sockets claimed + # and returned to pool from the left side. Stale sockets removed + # from the right side. + self.conns: collections.deque = collections.deque() + self.active_contexts: set[_CancellationContext] = set() + self.lock = _create_lock() + self.active_sockets = 0 + # Monotonically increasing connection ID required for CMAP Events. + self.next_connection_id = 1 + # Track whether the sockets in this pool are writeable or not. + self.is_writable: Optional[bool] = None + + # Keep track of resets, so we notice sockets created before the most + # recent reset and close them. 
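+        # Each Connection records the pool generation at creation time;
+        # checkin() and _perished() compare it against the current generation
+        # to detect and discard stale sockets.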
+ # self.generation = 0 + self.gen = _PoolGeneration() + self.pid = os.getpid() + self.address = address + self.opts = options + self.handshake = handshake + # Don't publish events in Monitor pools. + self.enabled_for_cmap = ( + self.handshake + and self.opts._event_listeners is not None + and self.opts._event_listeners.enabled_for_cmap + ) + + # The first portion of the wait queue. + # Enforces: maxPoolSize + # Also used for: clearing the wait queue + self.size_cond = threading.Condition(self.lock) + self.requests = 0 + self.max_pool_size = self.opts.max_pool_size + if not self.max_pool_size: + self.max_pool_size = float("inf") + # The second portion of the wait queue. + # Enforces: maxConnecting + # Also used for: clearing the wait queue + self._max_connecting_cond = threading.Condition(self.lock) + self._max_connecting = self.opts.max_connecting + self._pending = 0 + self._client_id = client_id + if self.enabled_for_cmap: + assert self.opts._event_listeners is not None + self.opts._event_listeners.publish_pool_created( + self.address, self.opts.non_default_options + ) + if _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _CONNECTION_LOGGER, + clientId=self._client_id, + message=_ConnectionStatusMessage.POOL_CREATED, + serverHost=self.address[0], + serverPort=self.address[1], + **self.opts.non_default_options, + ) + # Similar to active_sockets but includes threads in the wait queue. + self.operation_count: int = 0 + # Retain references to pinned connections to prevent the CPython GC + # from thinking that a cursor's pinned connection can be GC'd when the + # cursor is GC'd (see PYTHON-2751). + self.__pinned_sockets: set[Connection] = set() + self.ncursors = 0 + self.ntxns = 0 + + def ready(self) -> None: + # Take the lock to avoid the race condition described in PYTHON-2699. 
+ with self.lock: + if self.state != PoolState.READY: + self.state = PoolState.READY + if self.enabled_for_cmap: + assert self.opts._event_listeners is not None + self.opts._event_listeners.publish_pool_ready(self.address) + if _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _CONNECTION_LOGGER, + clientId=self._client_id, + message=_ConnectionStatusMessage.POOL_READY, + serverHost=self.address[0], + serverPort=self.address[1], + ) + + @property + def closed(self) -> bool: + return self.state == PoolState.CLOSED + + def _reset( + self, + close: bool, + pause: bool = True, + service_id: Optional[ObjectId] = None, + interrupt_connections: bool = False, + ) -> None: + old_state = self.state + with self.size_cond: + if self.closed: + return + if self.opts.pause_enabled and pause and not self.opts.load_balanced: + old_state, self.state = self.state, PoolState.PAUSED + self.gen.inc(service_id) + newpid = os.getpid() + if self.pid != newpid: + self.pid = newpid + self.active_sockets = 0 + self.operation_count = 0 + if service_id is None: + sockets, self.conns = self.conns, collections.deque() + else: + discard: collections.deque = collections.deque() + keep: collections.deque = collections.deque() + for conn in self.conns: + if conn.service_id == service_id: + discard.append(conn) + else: + keep.append(conn) + sockets = discard + self.conns = keep + + if close: + self.state = PoolState.CLOSED + # Clear the wait queue + self._max_connecting_cond.notify_all() + self.size_cond.notify_all() + + if interrupt_connections: + for context in self.active_contexts: + context.cancel() + + listeners = self.opts._event_listeners + # CMAP spec says that close() MUST close sockets before publishing the + # PoolClosedEvent but that reset() SHOULD close sockets *after* + # publishing the PoolClearedEvent. + if close: + for conn in sockets: + conn.close_conn(ConnectionClosedReason.POOL_CLOSED) + if self.enabled_for_cmap: + assert listeners is not None + listeners.publish_pool_closed(self.address) + if _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _CONNECTION_LOGGER, + clientId=self._client_id, + message=_ConnectionStatusMessage.POOL_CLOSED, + serverHost=self.address[0], + serverPort=self.address[1], + ) + else: + if old_state != PoolState.PAUSED and self.enabled_for_cmap: + assert listeners is not None + listeners.publish_pool_cleared( + self.address, + service_id=service_id, + interrupt_connections=interrupt_connections, + ) + if _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _CONNECTION_LOGGER, + clientId=self._client_id, + message=_ConnectionStatusMessage.POOL_CLEARED, + serverHost=self.address[0], + serverPort=self.address[1], + serviceId=service_id, + ) + for conn in sockets: + conn.close_conn(ConnectionClosedReason.STALE) + + def update_is_writable(self, is_writable: Optional[bool]) -> None: + """Updates the is_writable attribute on all sockets currently in the + Pool. 
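+
+        Typically invoked after a topology update changes whether this
+        server can accept writes.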
+ """ + self.is_writable = is_writable + with self.lock: + for _socket in self.conns: + _socket.update_is_writable(self.is_writable) + + def reset( + self, service_id: Optional[ObjectId] = None, interrupt_connections: bool = False + ) -> None: + self._reset(close=False, service_id=service_id, interrupt_connections=interrupt_connections) + + def reset_without_pause(self) -> None: + self._reset(close=False, pause=False) + + def close(self) -> None: + self._reset(close=True) + + def stale_generation(self, gen: int, service_id: Optional[ObjectId]) -> bool: + return self.gen.stale(gen, service_id) + + def remove_stale_sockets(self, reference_generation: int) -> None: + """Removes stale sockets then adds new ones if pool is too small and + has not been reset. The `reference_generation` argument specifies the + `generation` at the point in time this operation was requested on the + pool. + """ + # Take the lock to avoid the race condition described in PYTHON-2699. + with self.lock: + if self.state != PoolState.READY: + return + + if self.opts.max_idle_time_seconds is not None: + with self.lock: + while ( + self.conns + and self.conns[-1].idle_time_seconds() > self.opts.max_idle_time_seconds + ): + conn = self.conns.pop() + conn.close_conn(ConnectionClosedReason.IDLE) + + while True: + with self.size_cond: + # There are enough sockets in the pool. + if len(self.conns) + self.active_sockets >= self.opts.min_pool_size: + return + if self.requests >= self.opts.min_pool_size: + return + self.requests += 1 + incremented = False + try: + with self._max_connecting_cond: + # If maxConnecting connections are already being created + # by this pool then try again later instead of waiting. + if self._pending >= self._max_connecting: + return + self._pending += 1 + incremented = True + conn = self.connect() + with self.lock: + # Close connection and return if the pool was reset during + # socket creation or while acquiring the pool lock. + if self.gen.get_overall() != reference_generation: + conn.close_conn(ConnectionClosedReason.STALE) + return + self.conns.appendleft(conn) + self.active_contexts.discard(conn.cancel_context) + finally: + if incremented: + # Notify after adding the socket to the pool. + with self._max_connecting_cond: + self._pending -= 1 + self._max_connecting_cond.notify() + + with self.size_cond: + self.requests -= 1 + self.size_cond.notify() + + def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> Connection: + """Connect to Mongo and return a new Connection. + + Can raise ConnectionFailure. + + Note that the pool does not keep a reference to the socket -- you + must call checkin() when you're done with it. 
+ """ + with self.lock: + conn_id = self.next_connection_id + self.next_connection_id += 1 + + listeners = self.opts._event_listeners + if self.enabled_for_cmap: + assert listeners is not None + listeners.publish_connection_created(self.address, conn_id) + if _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _CONNECTION_LOGGER, + clientId=self._client_id, + message=_ConnectionStatusMessage.CONN_CREATED, + serverHost=self.address[0], + serverPort=self.address[1], + driverConnectionId=conn_id, + ) + + try: + sock = _configured_socket(self.address, self.opts) + except BaseException as error: + if self.enabled_for_cmap: + assert listeners is not None + listeners.publish_connection_closed( + self.address, conn_id, ConnectionClosedReason.ERROR + ) + if _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _CONNECTION_LOGGER, + clientId=self._client_id, + message=_ConnectionStatusMessage.CONN_CLOSED, + serverHost=self.address[0], + serverPort=self.address[1], + driverConnectionId=conn_id, + reason=_verbose_connection_error_reason(ConnectionClosedReason.ERROR), + error=ConnectionClosedReason.ERROR, + ) + if isinstance(error, (IOError, OSError, SSLError)): + details = _get_timeout_details(self.opts) + _raise_connection_failure(self.address, error, timeout_details=details) + + raise + + conn = Connection(sock, self, self.address, conn_id) # type: ignore[arg-type] + with self.lock: + self.active_contexts.add(conn.cancel_context) + try: + if self.handshake: + conn.hello() + self.is_writable = conn.is_writable + if handler: + handler.contribute_socket(conn, completed_handshake=False) + + conn.authenticate() + except BaseException: + conn.close_conn(ConnectionClosedReason.ERROR) + raise + + return conn + + @contextlib.contextmanager + def checkout(self, handler: Optional[_MongoClientErrorHandler] = None) -> Iterator[Connection]: + """Get a connection from the pool. Use with a "with" statement. + + Returns a :class:`Connection` object wrapping a connected + :class:`socket.socket`. + + This method should always be used in a with-statement:: + + with pool.get_conn() as connection: + connection.send_message(msg) + data = connection.receive_message(op_code, request_id) + + Can raise ConnectionFailure or OperationFailure. + + :param handler: A _MongoClientErrorHandler. + """ + listeners = self.opts._event_listeners + checkout_started_time = time.monotonic() + if self.enabled_for_cmap: + assert listeners is not None + listeners.publish_connection_check_out_started(self.address) + if _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _CONNECTION_LOGGER, + clientId=self._client_id, + message=_ConnectionStatusMessage.CHECKOUT_STARTED, + serverHost=self.address[0], + serverPort=self.address[1], + ) + + conn = self._get_conn(checkout_started_time, handler=handler) + + if self.enabled_for_cmap: + assert listeners is not None + duration = time.monotonic() - checkout_started_time + listeners.publish_connection_checked_out(self.address, conn.id, duration) + if _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _CONNECTION_LOGGER, + clientId=self._client_id, + message=_ConnectionStatusMessage.CHECKOUT_SUCCEEDED, + serverHost=self.address[0], + serverPort=self.address[1], + driverConnectionId=conn.id, + durationMS=duration, + ) + try: + with self.lock: + self.active_contexts.add(conn.cancel_context) + yield conn + except BaseException: + # Exception in caller. Ensure the connection gets returned. 
+ # Note that when pinned is True, the session owns the + # connection and it is responsible for checking the connection + # back into the pool. + pinned = conn.pinned_txn or conn.pinned_cursor + if handler: + # Perform SDAM error handling rules while the connection is + # still checked out. + exc_type, exc_val, _ = sys.exc_info() + handler.handle(exc_type, exc_val) + if not pinned and conn.active: + self.checkin(conn) + raise + if conn.pinned_txn: + with self.lock: + self.__pinned_sockets.add(conn) + self.ntxns += 1 + elif conn.pinned_cursor: + with self.lock: + self.__pinned_sockets.add(conn) + self.ncursors += 1 + elif conn.active: + self.checkin(conn) + + def _raise_if_not_ready(self, checkout_started_time: float, emit_event: bool) -> None: + if self.state != PoolState.READY: + if self.enabled_for_cmap and emit_event: + assert self.opts._event_listeners is not None + duration = time.monotonic() - checkout_started_time + self.opts._event_listeners.publish_connection_check_out_failed( + self.address, ConnectionCheckOutFailedReason.CONN_ERROR, duration + ) + if _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _CONNECTION_LOGGER, + clientId=self._client_id, + message=_ConnectionStatusMessage.CHECKOUT_FAILED, + serverHost=self.address[0], + serverPort=self.address[1], + reason="An error occurred while trying to establish a new connection", + error=ConnectionCheckOutFailedReason.CONN_ERROR, + durationMS=duration, + ) + + details = _get_timeout_details(self.opts) + _raise_connection_failure( + self.address, AutoReconnect("connection pool paused"), timeout_details=details + ) + + def _get_conn( + self, checkout_started_time: float, handler: Optional[_MongoClientErrorHandler] = None + ) -> Connection: + """Get or create a Connection. Can raise ConnectionFailure.""" + # We use the pid here to avoid issues with fork / multiprocessing. + # See test.test_client:TestClient.test_fork for an example of + # what could go wrong otherwise + if self.pid != os.getpid(): + self.reset_without_pause() + + if self.closed: + if self.enabled_for_cmap: + assert self.opts._event_listeners is not None + duration = time.monotonic() - checkout_started_time + self.opts._event_listeners.publish_connection_check_out_failed( + self.address, ConnectionCheckOutFailedReason.POOL_CLOSED, duration + ) + if _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _CONNECTION_LOGGER, + clientId=self._client_id, + message=_ConnectionStatusMessage.CHECKOUT_FAILED, + serverHost=self.address[0], + serverPort=self.address[1], + reason="Connection pool was closed", + error=ConnectionCheckOutFailedReason.POOL_CLOSED, + durationMS=duration, + ) + raise _PoolClosedError( + "Attempted to check out a connection from closed connection pool" + ) + + with self.lock: + self.operation_count += 1 + + # Get a free socket or create one. + if _csot.get_timeout(): + deadline = _csot.get_deadline() + elif self.opts.wait_queue_timeout: + deadline = time.monotonic() + self.opts.wait_queue_timeout + else: + deadline = None + + with self.size_cond: + self._raise_if_not_ready(checkout_started_time, emit_event=True) + while not (self.requests < self.max_pool_size): + if not _cond_wait(self.size_cond, deadline): + # Timed out, notify the next thread to ensure a + # timeout doesn't consume the condition. 
+ if self.requests < self.max_pool_size: + self.size_cond.notify() + self._raise_wait_queue_timeout(checkout_started_time) + self._raise_if_not_ready(checkout_started_time, emit_event=True) + self.requests += 1 + + # We've now acquired the semaphore and must release it on error. + conn = None + incremented = False + emitted_event = False + try: + with self.lock: + self.active_sockets += 1 + incremented = True + while conn is None: + # CMAP: we MUST wait for either maxConnecting OR for a socket + # to be checked back into the pool. + with self._max_connecting_cond: + self._raise_if_not_ready(checkout_started_time, emit_event=False) + while not (self.conns or self._pending < self._max_connecting): + if not _cond_wait(self._max_connecting_cond, deadline): + # Timed out, notify the next thread to ensure a + # timeout doesn't consume the condition. + if self.conns or self._pending < self._max_connecting: + self._max_connecting_cond.notify() + emitted_event = True + self._raise_wait_queue_timeout(checkout_started_time) + self._raise_if_not_ready(checkout_started_time, emit_event=False) + + try: + conn = self.conns.popleft() + except IndexError: + self._pending += 1 + if conn: # We got a socket from the pool + if self._perished(conn): + conn = None + continue + else: # We need to create a new connection + try: + conn = self.connect(handler=handler) + finally: + with self._max_connecting_cond: + self._pending -= 1 + self._max_connecting_cond.notify() + except BaseException: + if conn: + # We checked out a socket but authentication failed. + conn.close_conn(ConnectionClosedReason.ERROR) + with self.size_cond: + self.requests -= 1 + if incremented: + self.active_sockets -= 1 + self.size_cond.notify() + + if self.enabled_for_cmap and not emitted_event: + assert self.opts._event_listeners is not None + duration = time.monotonic() - checkout_started_time + self.opts._event_listeners.publish_connection_check_out_failed( + self.address, ConnectionCheckOutFailedReason.CONN_ERROR, duration + ) + if _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _CONNECTION_LOGGER, + clientId=self._client_id, + message=_ConnectionStatusMessage.CHECKOUT_FAILED, + serverHost=self.address[0], + serverPort=self.address[1], + reason="An error occurred while trying to establish a new connection", + error=ConnectionCheckOutFailedReason.CONN_ERROR, + durationMS=duration, + ) + raise + + conn.active = True + return conn + + def checkin(self, conn: Connection) -> None: + """Return the connection to the pool, or if it's closed discard it. + + :param conn: The connection to check into the pool. + """ + txn = conn.pinned_txn + cursor = conn.pinned_cursor + conn.active = False + conn.pinned_txn = False + conn.pinned_cursor = False + self.__pinned_sockets.discard(conn) + listeners = self.opts._event_listeners + with self.lock: + self.active_contexts.discard(conn.cancel_context) + if self.enabled_for_cmap: + assert listeners is not None + listeners.publish_connection_checked_in(self.address, conn.id) + if _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _CONNECTION_LOGGER, + clientId=self._client_id, + message=_ConnectionStatusMessage.CHECKEDIN, + serverHost=self.address[0], + serverPort=self.address[1], + driverConnectionId=conn.id, + ) + if self.pid != os.getpid(): + self.reset_without_pause() + else: + if self.closed: + conn.close_conn(ConnectionClosedReason.POOL_CLOSED) + elif conn.closed: + # CMAP requires the closed event be emitted after the check in. 
+                if self.enabled_for_cmap:
+                    assert listeners is not None
+                    listeners.publish_connection_closed(
+                        self.address, conn.id, ConnectionClosedReason.ERROR
+                    )
+                if _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG):
+                    _debug_log(
+                        _CONNECTION_LOGGER,
+                        clientId=self._client_id,
+                        message=_ConnectionStatusMessage.CONN_CLOSED,
+                        serverHost=self.address[0],
+                        serverPort=self.address[1],
+                        driverConnectionId=conn.id,
+                        reason=_verbose_connection_error_reason(ConnectionClosedReason.ERROR),
+                        error=ConnectionClosedReason.ERROR,
+                    )
+            else:
+                with self.lock:
+                    # Hold the lock to ensure this section does not race with
+                    # Pool.reset().
+                    if self.stale_generation(conn.generation, conn.service_id):
+                        conn.close_conn(ConnectionClosedReason.STALE)
+                    else:
+                        conn.update_last_checkin_time()
+                        conn.update_is_writable(bool(self.is_writable))
+                        self.conns.appendleft(conn)
+                        # Notify any threads waiting to create a connection.
+                        self._max_connecting_cond.notify()
+
+        with self.size_cond:
+            if txn:
+                self.ntxns -= 1
+            elif cursor:
+                self.ncursors -= 1
+            self.requests -= 1
+            self.active_sockets -= 1
+            self.operation_count -= 1
+            self.size_cond.notify()
+
+    def _perished(self, conn: Connection) -> bool:
+        """Return True and close the connection if it is "perished".
+
+        This side-effecty function checks if this socket has been idle for
+        longer than the max idle time, or if the socket has been closed by
+        some external network error, or if the socket's generation is outdated.
+
+        Checking sockets lets us avoid seeing *some*
+        :class:`~pymongo.errors.AutoReconnect` exceptions on server
+        hiccups, etc. We only check if the socket was closed by an external
+        error if it has been > 1 second since the socket was checked into the
+        pool, to keep performance reasonable - we can't avoid AutoReconnects
+        completely anyway.
+        """
+        idle_time_seconds = conn.idle_time_seconds()
+        # If socket is idle, open a new one.
+        if (
+            self.opts.max_idle_time_seconds is not None
+            and idle_time_seconds > self.opts.max_idle_time_seconds
+        ):
+            conn.close_conn(ConnectionClosedReason.IDLE)
+            return True
+
+        if self._check_interval_seconds is not None and (
+            self._check_interval_seconds == 0 or idle_time_seconds > self._check_interval_seconds
+        ):
+            if conn.conn_closed():
+                conn.close_conn(ConnectionClosedReason.ERROR)
+                return True
+
+        if self.stale_generation(conn.generation, conn.service_id):
+            conn.close_conn(ConnectionClosedReason.STALE)
+            return True
+
+        return False
+
+    def _raise_wait_queue_timeout(self, checkout_started_time: float) -> NoReturn:
+        listeners = self.opts._event_listeners
+        if self.enabled_for_cmap:
+            assert listeners is not None
+            duration = time.monotonic() - checkout_started_time
+            listeners.publish_connection_check_out_failed(
+                self.address, ConnectionCheckOutFailedReason.TIMEOUT, duration
+            )
+            if _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG):
+                _debug_log(
+                    _CONNECTION_LOGGER,
+                    clientId=self._client_id,
+                    message=_ConnectionStatusMessage.CHECKOUT_FAILED,
+                    serverHost=self.address[0],
+                    serverPort=self.address[1],
+                    reason="Wait queue timeout elapsed without a connection becoming available",
+                    error=ConnectionCheckOutFailedReason.TIMEOUT,
+                    durationMS=duration,
+                )
+        timeout = _csot.get_timeout() or self.opts.wait_queue_timeout
+        if self.opts.load_balanced:
+            other_ops = self.active_sockets - self.ncursors - self.ntxns
+            raise WaitQueueTimeoutError(
+                "Timeout waiting for connection from the connection pool. 
" + "maxPoolSize: {}, connections in use by cursors: {}, " + "connections in use by transactions: {}, connections in use " + "by other operations: {}, timeout: {}".format( + self.opts.max_pool_size, + self.ncursors, + self.ntxns, + other_ops, + timeout, + ) + ) + raise WaitQueueTimeoutError( + "Timed out while checking out a connection from connection pool. " + f"maxPoolSize: {self.opts.max_pool_size}, timeout: {timeout}" + ) + + def __del__(self) -> None: + # Avoid ResourceWarnings in Python 3 + # Close all sockets without calling reset() or close() because it is + # not safe to acquire a lock in __del__. + for conn in self.conns: + conn.close_conn(None) diff --git a/venv/Lib/site-packages/pymongo/py.typed b/venv/Lib/site-packages/pymongo/py.typed new file mode 100644 index 00000000..0f405706 --- /dev/null +++ b/venv/Lib/site-packages/pymongo/py.typed @@ -0,0 +1,2 @@ +# PEP-561 Support File. +# "Package maintainers who wish to support type checking of their code MUST add a marker file named py.typed to their package supporting typing". diff --git a/venv/Lib/site-packages/pymongo/pyopenssl_context.py b/venv/Lib/site-packages/pymongo/pyopenssl_context.py new file mode 100644 index 00000000..fb007135 --- /dev/null +++ b/venv/Lib/site-packages/pymongo/pyopenssl_context.py @@ -0,0 +1,417 @@ +# Copyright 2019-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you +# may not use this file except in compliance with the License. You +# may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. + +"""A CPython compatible SSLContext implementation wrapping PyOpenSSL's +context. 
+""" +from __future__ import annotations + +import socket as _socket +import ssl as _stdlibssl +import sys as _sys +import time as _time +from errno import EINTR as _EINTR +from ipaddress import ip_address as _ip_address +from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union + +from OpenSSL import SSL as _SSL +from OpenSSL import crypto as _crypto + +from pymongo._lazy_import import lazy_import +from pymongo.errors import ConfigurationError as _ConfigurationError +from pymongo.errors import _CertificateError # type:ignore[attr-defined] +from pymongo.ocsp_cache import _OCSPCache +from pymongo.ocsp_support import _load_trusted_ca_certs, _ocsp_callback +from pymongo.socket_checker import SocketChecker as _SocketChecker +from pymongo.socket_checker import _errno_from_exception +from pymongo.write_concern import validate_boolean + +_x509 = lazy_import("cryptography.x509") +_service_identity = lazy_import("service_identity") +_service_identity_pyopenssl = lazy_import("service_identity.pyopenssl") + +if TYPE_CHECKING: + from ssl import VerifyMode + + from cryptography.x509 import Certificate + +_T = TypeVar("_T") + +try: + import certifi + + _HAVE_CERTIFI = True +except ImportError: + _HAVE_CERTIFI = False + +PROTOCOL_SSLv23 = _SSL.SSLv23_METHOD +# Always available +OP_NO_SSLv2 = _SSL.OP_NO_SSLv2 +OP_NO_SSLv3 = _SSL.OP_NO_SSLv3 +OP_NO_COMPRESSION = _SSL.OP_NO_COMPRESSION +# This isn't currently documented for PyOpenSSL +OP_NO_RENEGOTIATION = getattr(_SSL, "OP_NO_RENEGOTIATION", 0) + +# Always available +HAS_SNI = True +IS_PYOPENSSL = True + +# Base Exception class +SSLError = _SSL.Error + +# https://github.com/python/cpython/blob/v3.8.0/Modules/_ssl.c#L2995-L3002 +_VERIFY_MAP = { + _stdlibssl.CERT_NONE: _SSL.VERIFY_NONE, + _stdlibssl.CERT_OPTIONAL: _SSL.VERIFY_PEER, + _stdlibssl.CERT_REQUIRED: _SSL.VERIFY_PEER | _SSL.VERIFY_FAIL_IF_NO_PEER_CERT, +} + +_REVERSE_VERIFY_MAP = {value: key for key, value in _VERIFY_MAP.items()} + + +# For SNI support. According to RFC6066, section 3, IPv4 and IPv6 literals are +# not permitted for SNI hostname. +def _is_ip_address(address: Any) -> bool: + try: + _ip_address(address) + return True + except (ValueError, UnicodeError): + return False + + +# According to the docs for socket.send it can raise +# WantX509LookupError and should be retried. +BLOCKING_IO_ERRORS = (_SSL.WantReadError, _SSL.WantWriteError, _SSL.WantX509LookupError) + + +def _ragged_eof(exc: BaseException) -> bool: + """Return True if the OpenSSL.SSL.SysCallError is a ragged EOF.""" + return exc.args == (-1, "Unexpected EOF") + + +# https://github.com/pyca/pyopenssl/issues/168 +# https://github.com/pyca/pyopenssl/issues/176 +# https://docs.python.org/3/library/ssl.html#notes-on-non-blocking-sockets +class _sslConn(_SSL.Connection): + def __init__( + self, ctx: _SSL.Context, sock: Optional[_socket.socket], suppress_ragged_eofs: bool + ): + self.socket_checker = _SocketChecker() + self.suppress_ragged_eofs = suppress_ragged_eofs + super().__init__(ctx, sock) + + def _call(self, call: Callable[..., _T], *args: Any, **kwargs: Any) -> _T: + timeout = self.gettimeout() + if timeout: + start = _time.monotonic() + while True: + try: + return call(*args, **kwargs) + except BLOCKING_IO_ERRORS as exc: + # Check for closed socket. 
+ if self.fileno() == -1: + if timeout and _time.monotonic() - start > timeout: + raise _socket.timeout("timed out") from None + raise SSLError("Underlying socket has been closed") from None + if isinstance(exc, _SSL.WantReadError): + want_read = True + want_write = False + elif isinstance(exc, _SSL.WantWriteError): + want_read = False + want_write = True + else: + want_read = True + want_write = True + self.socket_checker.select(self, want_read, want_write, timeout) + if timeout and _time.monotonic() - start > timeout: + raise _socket.timeout("timed out") from None + continue + + def do_handshake(self, *args: Any, **kwargs: Any) -> None: + return self._call(super().do_handshake, *args, **kwargs) + + def recv(self, *args: Any, **kwargs: Any) -> bytes: + try: + return self._call(super().recv, *args, **kwargs) + except _SSL.SysCallError as exc: + # Suppress ragged EOFs to match the stdlib. + if self.suppress_ragged_eofs and _ragged_eof(exc): + return b"" + raise + + def recv_into(self, *args: Any, **kwargs: Any) -> int: + try: + return self._call(super().recv_into, *args, **kwargs) + except _SSL.SysCallError as exc: + # Suppress ragged EOFs to match the stdlib. + if self.suppress_ragged_eofs and _ragged_eof(exc): + return 0 + raise + + def sendall(self, buf: bytes, flags: int = 0) -> None: # type: ignore[override] + view = memoryview(buf) + total_length = len(buf) + total_sent = 0 + while total_sent < total_length: + try: + sent = self._call(super().send, view[total_sent:], flags) + # XXX: It's not clear if this can actually happen. PyOpenSSL + # doesn't appear to have any interrupt handling, nor any interrupt + # errors for OpenSSL connections. + except OSError as exc: + if _errno_from_exception(exc) == _EINTR: + continue + raise + # https://github.com/pyca/pyopenssl/blob/19.1.0/src/OpenSSL/SSL.py#L1756 + # https://www.openssl.org/docs/man1.0.2/man3/SSL_write.html + if sent <= 0: + raise OSError("connection closed") + total_sent += sent + + +class _CallbackData: + """Data class which is passed to the OCSP callback.""" + + def __init__(self) -> None: + self.trusted_ca_certs: Optional[list[Certificate]] = None + self.check_ocsp_endpoint: Optional[bool] = None + self.ocsp_response_cache = _OCSPCache() + + +class SSLContext: + """A CPython compatible SSLContext implementation wrapping PyOpenSSL's + context. + """ + + __slots__ = ("_protocol", "_ctx", "_callback_data", "_check_hostname") + + def __init__(self, protocol: int): + self._protocol = protocol + self._ctx = _SSL.Context(self._protocol) + self._callback_data = _CallbackData() + self._check_hostname = True + # OCSP + # XXX: Find a better place to do this someday, since this is client + # side configuration and wrap_socket tries to support both client and + # server side sockets. + self._callback_data.check_ocsp_endpoint = True + self._ctx.set_ocsp_client_callback(callback=_ocsp_callback, data=self._callback_data) + + @property + def protocol(self) -> int: + """The protocol version chosen when constructing the context. + This attribute is read-only. + """ + return self._protocol + + def __get_verify_mode(self) -> VerifyMode: + """Whether to try to verify other peers' certificates and how to + behave if verification fails. This attribute must be one of + ssl.CERT_NONE, ssl.CERT_OPTIONAL or ssl.CERT_REQUIRED. 
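+
+        For example (an illustrative sketch; ``ssl`` is the stdlib module)::
+
+            context.verify_mode = ssl.CERT_REQUIRED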
+ """ + return _REVERSE_VERIFY_MAP[self._ctx.get_verify_mode()] + + def __set_verify_mode(self, value: VerifyMode) -> None: + """Setter for verify_mode.""" + + def _cb( + _connobj: _SSL.Connection, + _x509obj: _crypto.X509, + _errnum: int, + _errdepth: int, + retcode: int, + ) -> bool: + # It seems we don't need to do anything here. Twisted doesn't, + # and OpenSSL's SSL_CTX_set_verify let's you pass NULL + # for the callback option. It's weird that PyOpenSSL requires + # this. + # This is optional in pyopenssl >= 20 and can be removed once minimum + # supported version is bumped + # See: pyopenssl.org/en/latest/changelog.html#id47 + return bool(retcode) + + self._ctx.set_verify(_VERIFY_MAP[value], _cb) + + verify_mode = property(__get_verify_mode, __set_verify_mode) + + def __get_check_hostname(self) -> bool: + return self._check_hostname + + def __set_check_hostname(self, value: Any) -> None: + validate_boolean("check_hostname", value) + self._check_hostname = value + + check_hostname = property(__get_check_hostname, __set_check_hostname) + + def __get_check_ocsp_endpoint(self) -> Optional[bool]: + return self._callback_data.check_ocsp_endpoint + + def __set_check_ocsp_endpoint(self, value: bool) -> None: + validate_boolean("check_ocsp", value) + self._callback_data.check_ocsp_endpoint = value + + check_ocsp_endpoint = property(__get_check_ocsp_endpoint, __set_check_ocsp_endpoint) + + def __get_options(self) -> None: + # Calling set_options adds the option to the existing bitmask and + # returns the new bitmask. + # https://www.pyopenssl.org/en/stable/api/ssl.html#OpenSSL.SSL.Context.set_options + return self._ctx.set_options(0) + + def __set_options(self, value: int) -> None: + # Explicitly convert to int, since newer CPython versions + # use enum.IntFlag for options. The values are the same + # regardless of implementation. + self._ctx.set_options(int(value)) + + options = property(__get_options, __set_options) + + def load_cert_chain( + self, + certfile: Union[str, bytes], + keyfile: Union[str, bytes, None] = None, + password: Optional[str] = None, + ) -> None: + """Load a private key and the corresponding certificate. The certfile + string must be the path to a single file in PEM format containing the + certificate as well as any number of CA certificates needed to + establish the certificate's authenticity. The keyfile string, if + present, must point to a file containing the private key. Otherwise + the private key will be taken from certfile as well. + """ + # Match CPython behavior + # https://github.com/python/cpython/blob/v3.8.0/Modules/_ssl.c#L3930-L3971 + # Password callback MUST be set first or it will be ignored. + if password: + + def _pwcb(_max_length: int, _prompt_twice: bool, _user_data: bytes) -> bytes: + # XXX:We could check the password length against what OpenSSL + # tells us is the max, but we can't raise an exception, so... + # warn? + assert password is not None + return password.encode("utf-8") + + self._ctx.set_passwd_cb(_pwcb) + self._ctx.use_certificate_chain_file(certfile) + self._ctx.use_privatekey_file(keyfile or certfile) + self._ctx.check_privatekey() + + def load_verify_locations( + self, cafile: Optional[str] = None, capath: Optional[str] = None + ) -> None: + """Load a set of "certification authority"(CA) certificates used to + validate other peers' certificates when `~verify_mode` is other than + ssl.CERT_NONE. 
+ """ + self._ctx.load_verify_locations(cafile, capath) + # Manually load the CA certs when get_verified_chain is not available (pyopenssl<20). + if not hasattr(_SSL.Connection, "get_verified_chain"): + assert cafile is not None + self._callback_data.trusted_ca_certs = _load_trusted_ca_certs(cafile) + + def _load_certifi(self) -> None: + """Attempt to load CA certs from certifi.""" + if _HAVE_CERTIFI: + self.load_verify_locations(certifi.where()) + else: + raise _ConfigurationError( + "tlsAllowInvalidCertificates is False but no system " + "CA certificates could be loaded. Please install the " + "certifi package, or provide a path to a CA file using " + "the tlsCAFile option" + ) + + def _load_wincerts(self, store: str) -> None: + """Attempt to load CA certs from Windows trust store.""" + cert_store = self._ctx.get_cert_store() + oid = _stdlibssl.Purpose.SERVER_AUTH.oid + for cert, encoding, trust in _stdlibssl.enum_certificates(store): # type: ignore + if encoding == "x509_asn": + if trust is True or oid in trust: + cert_store.add_cert( + _crypto.X509.from_cryptography(_x509.load_der_x509_certificate(cert)) + ) + + def load_default_certs(self) -> None: + """A PyOpenSSL version of load_default_certs from CPython.""" + # PyOpenSSL is incapable of loading CA certs from Windows, and mostly + # incapable on macOS. + # https://www.pyopenssl.org/en/stable/api/ssl.html#OpenSSL.SSL.Context.set_default_verify_paths + if _sys.platform == "win32": + try: + for storename in ("CA", "ROOT"): + self._load_wincerts(storename) + except PermissionError: + # Fall back to certifi + self._load_certifi() + elif _sys.platform == "darwin": + self._load_certifi() + self._ctx.set_default_verify_paths() + + def set_default_verify_paths(self) -> None: + """Specify that the platform provided CA certificates are to be used + for verification purposes. + """ + # Note: See PyOpenSSL's docs for limitations, which are similar + # but not that same as CPython's. + self._ctx.set_default_verify_paths() + + def wrap_socket( + self, + sock: _socket.socket, + server_side: bool = False, + do_handshake_on_connect: bool = True, + suppress_ragged_eofs: bool = True, + server_hostname: Optional[str] = None, + session: Optional[_SSL.Session] = None, + ) -> _sslConn: + """Wrap an existing Python socket connection and return a TLS socket + object. + """ + ssl_conn = _sslConn(self._ctx, sock, suppress_ragged_eofs) + if session: + ssl_conn.set_session(session) + if server_side is True: + ssl_conn.set_accept_state() + else: + # SNI + if server_hostname and not _is_ip_address(server_hostname): + # XXX: Do this in a callback registered with + # SSLContext.set_info_callback? See Twisted for an example. + ssl_conn.set_tlsext_host_name(server_hostname.encode("idna")) + if self.verify_mode != _stdlibssl.CERT_NONE: + # Request a stapled OCSP response. + ssl_conn.request_ocsp() + ssl_conn.set_connect_state() + # If this wasn't true the caller of wrap_socket would call + # do_handshake() + if do_handshake_on_connect: + # XXX: If we do hostname checking in a callback we can get rid + # of this call to do_handshake() since the handshake + # will happen automatically later. + ssl_conn.do_handshake() + # XXX: Do this in a callback registered with + # SSLContext.set_info_callback? See Twisted for an example. 
+ if self.check_hostname and server_hostname is not None: + try: + if _is_ip_address(server_hostname): + _service_identity_pyopenssl.verify_ip_address(ssl_conn, server_hostname) + else: + _service_identity_pyopenssl.verify_hostname(ssl_conn, server_hostname) + except ( + _service_identity.SICertificateError, + _service_identity.SIVerificationError, + ) as exc: + raise _CertificateError(str(exc)) from None + return ssl_conn diff --git a/venv/Lib/site-packages/pymongo/read_concern.py b/venv/Lib/site-packages/pymongo/read_concern.py new file mode 100644 index 00000000..eda715f7 --- /dev/null +++ b/venv/Lib/site-packages/pymongo/read_concern.py @@ -0,0 +1,76 @@ +# Copyright 2015 MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License", +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tools for working with read concerns.""" +from __future__ import annotations + +from typing import Any, Optional + + +class ReadConcern: + """ReadConcern + + :param level: (string) The read concern level specifies the level of + isolation for read operations. For example, a read operation using a + read concern level of ``majority`` will only return data that has been + written to a majority of nodes. If the level is left unspecified, the + server default will be used. + + .. versionadded:: 3.2 + + """ + + def __init__(self, level: Optional[str] = None) -> None: + if level is None or isinstance(level, str): + self.__level = level + else: + raise TypeError("level must be a string or None.") + + @property + def level(self) -> Optional[str]: + """The read concern level.""" + return self.__level + + @property + def ok_for_legacy(self) -> bool: + """Return ``True`` if this read concern is compatible with + old wire protocol versions. + """ + return self.level is None or self.level == "local" + + @property + def document(self) -> dict[str, Any]: + """The document representation of this read concern. + + .. note:: + :class:`ReadConcern` is immutable. Mutating the value of + :attr:`document` does not mutate this :class:`ReadConcern`. + """ + doc = {} + if self.__level: + doc["level"] = self.level + return doc + + def __eq__(self, other: Any) -> bool: + if isinstance(other, ReadConcern): + return self.document == other.document + return NotImplemented + + def __repr__(self) -> str: + if self.level: + return "ReadConcern(%s)" % self.level + return "ReadConcern()" + + +DEFAULT_READ_CONCERN = ReadConcern() diff --git a/venv/Lib/site-packages/pymongo/read_preferences.py b/venv/Lib/site-packages/pymongo/read_preferences.py new file mode 100644 index 00000000..7752750c --- /dev/null +++ b/venv/Lib/site-packages/pymongo/read_preferences.py @@ -0,0 +1,622 @@ +# Copyright 2012-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License", +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utilities for choosing which member of a replica set to read from."""
+
+from __future__ import annotations
+
+from collections import abc
+from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence
+
+from pymongo import max_staleness_selectors
+from pymongo.errors import ConfigurationError
+from pymongo.server_selectors import (
+    member_with_tags_server_selector,
+    secondary_with_tags_server_selector,
+)
+
+if TYPE_CHECKING:
+    from pymongo.server_selectors import Selection
+    from pymongo.topology_description import TopologyDescription
+
+_PRIMARY = 0
+_PRIMARY_PREFERRED = 1
+_SECONDARY = 2
+_SECONDARY_PREFERRED = 3
+_NEAREST = 4
+
+
+_MONGOS_MODES = (
+    "primary",
+    "primaryPreferred",
+    "secondary",
+    "secondaryPreferred",
+    "nearest",
+)
+
+_Hedge = Mapping[str, Any]
+_TagSets = Sequence[Mapping[str, Any]]
+
+
+def _validate_tag_sets(tag_sets: Optional[_TagSets]) -> Optional[_TagSets]:
+    """Validate tag sets for a MongoClient."""
+    if tag_sets is None:
+        return tag_sets
+
+    if not isinstance(tag_sets, (list, tuple)):
+        raise TypeError(f"Tag sets {tag_sets!r} invalid, must be a sequence")
+    if len(tag_sets) == 0:
+        raise ValueError(
+            f"Tag sets {tag_sets!r} invalid, must be None or contain at least one set of tags"
+        )
+
+    for tags in tag_sets:
+        if not isinstance(tags, abc.Mapping):
+            raise TypeError(
+                f"Tag set {tags!r} invalid, must be an instance of dict, "
+                "bson.son.SON or other type that inherits from "
+                "collections.abc.Mapping"
+            )
+
+    return list(tag_sets)
+
+
+def _invalid_max_staleness_msg(max_staleness: Any) -> str:
+    return "maxStalenessSeconds must be a positive integer, not %s" % max_staleness
+
+
+# Some duplication with common.py to avoid import cycle. 
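+# -1 is the "no maxStalenessSeconds" sentinel; any other value must be a
+# positive integer number of seconds.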
+def _validate_max_staleness(max_staleness: Any) -> int: + """Validate max_staleness.""" + if max_staleness == -1: + return -1 + + if not isinstance(max_staleness, int): + raise TypeError(_invalid_max_staleness_msg(max_staleness)) + + if max_staleness <= 0: + raise ValueError(_invalid_max_staleness_msg(max_staleness)) + + return max_staleness + + +def _validate_hedge(hedge: Optional[_Hedge]) -> Optional[_Hedge]: + """Validate hedge.""" + if hedge is None: + return None + + if not isinstance(hedge, dict): + raise TypeError(f"hedge must be a dictionary, not {hedge!r}") + + return hedge + + +class _ServerMode: + """Base class for all read preferences.""" + + __slots__ = ("__mongos_mode", "__mode", "__tag_sets", "__max_staleness", "__hedge") + + def __init__( + self, + mode: int, + tag_sets: Optional[_TagSets] = None, + max_staleness: int = -1, + hedge: Optional[_Hedge] = None, + ) -> None: + self.__mongos_mode = _MONGOS_MODES[mode] + self.__mode = mode + self.__tag_sets = _validate_tag_sets(tag_sets) + self.__max_staleness = _validate_max_staleness(max_staleness) + self.__hedge = _validate_hedge(hedge) + + @property + def name(self) -> str: + """The name of this read preference.""" + return self.__class__.__name__ + + @property + def mongos_mode(self) -> str: + """The mongos mode of this read preference.""" + return self.__mongos_mode + + @property + def document(self) -> dict[str, Any]: + """Read preference as a document.""" + doc: dict[str, Any] = {"mode": self.__mongos_mode} + if self.__tag_sets not in (None, [{}]): + doc["tags"] = self.__tag_sets + if self.__max_staleness != -1: + doc["maxStalenessSeconds"] = self.__max_staleness + if self.__hedge not in (None, {}): + doc["hedge"] = self.__hedge + return doc + + @property + def mode(self) -> int: + """The mode of this read preference instance.""" + return self.__mode + + @property + def tag_sets(self) -> _TagSets: + """Set ``tag_sets`` to a list of dictionaries like [{'dc': 'ny'}] to + read only from members whose ``dc`` tag has the value ``"ny"``. + To specify a priority-order for tag sets, provide a list of + tag sets: ``[{'dc': 'ny'}, {'dc': 'la'}, {}]``. A final, empty tag + set, ``{}``, means "read from any member that matches the mode, + ignoring tags." MongoClient tries each set of tags in turn + until it finds a set of tags with at least one matching member. + For example, to only send a query to an analytic node:: + + Nearest(tag_sets=[{"node":"analytics"}]) + + Or using :class:`SecondaryPreferred`:: + + SecondaryPreferred(tag_sets=[{"node":"analytics"}]) + + .. seealso:: `Data-Center Awareness + `_ + """ + return list(self.__tag_sets) if self.__tag_sets else [{}] + + @property + def max_staleness(self) -> int: + """The maximum estimated length of time (in seconds) a replica set + secondary can fall behind the primary in replication before it will + no longer be selected for operations, or -1 for no maximum. + """ + return self.__max_staleness + + @property + def hedge(self) -> Optional[_Hedge]: + """The read preference ``hedge`` parameter. + + A dictionary that configures how the server will perform hedged reads. + It consists of the following keys: + + - ``enabled``: Enables or disables hedged reads in sharded clusters. + + Hedged reads are automatically enabled in MongoDB 4.4+ when using a + ``nearest`` read preference. 
To explicitly enable hedged reads, set
+        the ``enabled`` key to ``True``::
+
+            >>> Nearest(hedge={'enabled': True})
+
+        To explicitly disable hedged reads, set the ``enabled`` key to
+        ``False``::
+
+            >>> Nearest(hedge={'enabled': False})
+
+        .. versionadded:: 3.11
+        """
+        return self.__hedge
+
+    @property
+    def min_wire_version(self) -> int:
+        """The wire protocol version the server must support.
+
+        Some read preferences impose version requirements on all servers (e.g.
+        maxStalenessSeconds requires MongoDB 3.4 / maxWireVersion 5).
+
+        All servers' maxWireVersion must be at least this read preference's
+        `min_wire_version`, or the driver raises
+        :exc:`~pymongo.errors.ConfigurationError`.
+        """
+        return 0 if self.__max_staleness == -1 else 5
+
+    def __repr__(self) -> str:
+        return "{}(tag_sets={!r}, max_staleness={!r}, hedge={!r})".format(
+            self.name,
+            self.__tag_sets,
+            self.__max_staleness,
+            self.__hedge,
+        )
+
+    def __eq__(self, other: Any) -> bool:
+        if isinstance(other, _ServerMode):
+            return (
+                self.mode == other.mode
+                and self.tag_sets == other.tag_sets
+                and self.max_staleness == other.max_staleness
+                and self.hedge == other.hedge
+            )
+        return NotImplemented
+
+    def __ne__(self, other: Any) -> bool:
+        return not self == other
+
+    def __getstate__(self) -> dict[str, Any]:
+        """Return value of object for pickling.
+
+        Needed explicitly because __slots__ is defined.
+        """
+        return {
+            "mode": self.__mode,
+            "tag_sets": self.__tag_sets,
+            "max_staleness": self.__max_staleness,
+            "hedge": self.__hedge,
+        }
+
+    def __setstate__(self, value: Mapping[str, Any]) -> None:
+        """Restore from pickling."""
+        self.__mode = value["mode"]
+        self.__mongos_mode = _MONGOS_MODES[self.__mode]
+        self.__tag_sets = _validate_tag_sets(value["tag_sets"])
+        self.__max_staleness = _validate_max_staleness(value["max_staleness"])
+        self.__hedge = _validate_hedge(value["hedge"])
+
+    def __call__(self, selection: Selection) -> Selection:
+        return selection
+
+
+class Primary(_ServerMode):
+    """Primary read preference.
+
+    * When directly connected to one mongod queries are allowed if the server
+      is standalone or a replica set primary.
+    * When connected to a mongos queries are sent to the primary of a shard.
+    * When connected to a replica set queries are sent to the primary of
+      the replica set.
+    """
+
+    __slots__ = ()
+
+    def __init__(self) -> None:
+        super().__init__(_PRIMARY)
+
+    def __call__(self, selection: Selection) -> Selection:
+        """Apply this read preference to a Selection."""
+        return selection.primary_selection
+
+    def __repr__(self) -> str:
+        return "Primary()"
+
+    def __eq__(self, other: Any) -> bool:
+        if isinstance(other, _ServerMode):
+            return other.mode == _PRIMARY
+        return NotImplemented
+
+
+class PrimaryPreferred(_ServerMode):
+    """PrimaryPreferred read preference.
+
+    * When directly connected to one mongod queries are allowed to standalone
+      servers, to a replica set primary, or to replica set secondaries.
+    * When connected to a mongos queries are sent to the primary of a shard if
+      available, otherwise a shard secondary.
+    * When connected to a replica set queries are sent to the primary if
+      available, otherwise a secondary.
+
+    .. note:: When a :class:`~pymongo.mongo_client.MongoClient` is first
+      created reads will be routed to an available secondary until the
+      primary of the replica set is discovered.
+
+    :param tag_sets: The :attr:`~tag_sets` to use if the primary is not
+        available. 
+ :param max_staleness: (integer, in seconds) The maximum estimated + length of time a replica set secondary can fall behind the primary in + replication before it will no longer be selected for operations. + Default -1, meaning no maximum. If it is set, it must be at least + 90 seconds. + :param hedge: The :attr:`~hedge` to use if the primary is not available. + + .. versionchanged:: 3.11 + Added ``hedge`` parameter. + """ + + __slots__ = () + + def __init__( + self, + tag_sets: Optional[_TagSets] = None, + max_staleness: int = -1, + hedge: Optional[_Hedge] = None, + ) -> None: + super().__init__(_PRIMARY_PREFERRED, tag_sets, max_staleness, hedge) + + def __call__(self, selection: Selection) -> Selection: + """Apply this read preference to Selection.""" + if selection.primary: + return selection.primary_selection + else: + return secondary_with_tags_server_selector( + self.tag_sets, max_staleness_selectors.select(self.max_staleness, selection) + ) + + +class Secondary(_ServerMode): + """Secondary read preference. + + * When directly connected to one mongod queries are allowed to standalone + servers, to a replica set primary, or to replica set secondaries. + * When connected to a mongos queries are distributed among shard + secondaries. An error is raised if no secondaries are available. + * When connected to a replica set queries are distributed among + secondaries. An error is raised if no secondaries are available. + + :param tag_sets: The :attr:`~tag_sets` for this read preference. + :param max_staleness: (integer, in seconds) The maximum estimated + length of time a replica set secondary can fall behind the primary in + replication before it will no longer be selected for operations. + Default -1, meaning no maximum. If it is set, it must be at least + 90 seconds. + :param hedge: The :attr:`~hedge` for this read preference. + + .. versionchanged:: 3.11 + Added ``hedge`` parameter. + """ + + __slots__ = () + + def __init__( + self, + tag_sets: Optional[_TagSets] = None, + max_staleness: int = -1, + hedge: Optional[_Hedge] = None, + ) -> None: + super().__init__(_SECONDARY, tag_sets, max_staleness, hedge) + + def __call__(self, selection: Selection) -> Selection: + """Apply this read preference to Selection.""" + return secondary_with_tags_server_selector( + self.tag_sets, max_staleness_selectors.select(self.max_staleness, selection) + ) + + +class SecondaryPreferred(_ServerMode): + """SecondaryPreferred read preference. + + * When directly connected to one mongod queries are allowed to standalone + servers, to a replica set primary, or to replica set secondaries. + * When connected to a mongos queries are distributed among shard + secondaries, or the shard primary if no secondary is available. + * When connected to a replica set queries are distributed among + secondaries, or the primary if no secondary is available. + + .. note:: When a :class:`~pymongo.mongo_client.MongoClient` is first + created reads will be routed to the primary of the replica set until + an available secondary is discovered. + + :param tag_sets: The :attr:`~tag_sets` for this read preference. + :param max_staleness: (integer, in seconds) The maximum estimated + length of time a replica set secondary can fall behind the primary in + replication before it will no longer be selected for operations. + Default -1, meaning no maximum. If it is set, it must be at least + 90 seconds. + :param hedge: The :attr:`~hedge` for this read preference. + + .. versionchanged:: 3.11 + Added ``hedge`` parameter. 
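+
+    For example, to prefer secondaries tagged for analytics while keeping the
+    primary as a fallback (mirroring the ``tag_sets`` docs above)::
+
+        SecondaryPreferred(tag_sets=[{"node": "analytics"}])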
+ """ + + __slots__ = () + + def __init__( + self, + tag_sets: Optional[_TagSets] = None, + max_staleness: int = -1, + hedge: Optional[_Hedge] = None, + ) -> None: + super().__init__(_SECONDARY_PREFERRED, tag_sets, max_staleness, hedge) + + def __call__(self, selection: Selection) -> Selection: + """Apply this read preference to Selection.""" + secondaries = secondary_with_tags_server_selector( + self.tag_sets, max_staleness_selectors.select(self.max_staleness, selection) + ) + + if secondaries: + return secondaries + else: + return selection.primary_selection + + +class Nearest(_ServerMode): + """Nearest read preference. + + * When directly connected to one mongod queries are allowed to standalone + servers, to a replica set primary, or to replica set secondaries. + * When connected to a mongos queries are distributed among all members of + a shard. + * When connected to a replica set queries are distributed among all + members. + + :param tag_sets: The :attr:`~tag_sets` for this read preference. + :param max_staleness: (integer, in seconds) The maximum estimated + length of time a replica set secondary can fall behind the primary in + replication before it will no longer be selected for operations. + Default -1, meaning no maximum. If it is set, it must be at least + 90 seconds. + :param hedge: The :attr:`~hedge` for this read preference. + + .. versionchanged:: 3.11 + Added ``hedge`` parameter. + """ + + __slots__ = () + + def __init__( + self, + tag_sets: Optional[_TagSets] = None, + max_staleness: int = -1, + hedge: Optional[_Hedge] = None, + ) -> None: + super().__init__(_NEAREST, tag_sets, max_staleness, hedge) + + def __call__(self, selection: Selection) -> Selection: + """Apply this read preference to Selection.""" + return member_with_tags_server_selector( + self.tag_sets, max_staleness_selectors.select(self.max_staleness, selection) + ) + + +class _AggWritePref: + """Agg $out/$merge write preference. + + * If there are readable servers and there is any pre-5.0 server, use + primary read preference. + * Otherwise use `pref` read preference. + + :param pref: The read preference to use on MongoDB 5.0+. + """ + + __slots__ = ("pref", "effective_pref") + + def __init__(self, pref: _ServerMode): + self.pref = pref + self.effective_pref: _ServerMode = ReadPreference.PRIMARY + + def selection_hook(self, topology_description: TopologyDescription) -> None: + common_wv = topology_description.common_wire_version + if ( + topology_description.has_readable_server(ReadPreference.PRIMARY_PREFERRED) + and common_wv + and common_wv < 13 + ): + self.effective_pref = ReadPreference.PRIMARY + else: + self.effective_pref = self.pref + + def __call__(self, selection: Selection) -> Selection: + """Apply this read preference to a Selection.""" + return self.effective_pref(selection) + + def __repr__(self) -> str: + return f"_AggWritePref(pref={self.pref!r})" + + # Proxy other calls to the effective_pref so that _AggWritePref can be + # used in place of an actual read preference. 
+ def __getattr__(self, name: str) -> Any: + return getattr(self.effective_pref, name) + + +_ALL_READ_PREFERENCES = (Primary, PrimaryPreferred, Secondary, SecondaryPreferred, Nearest) + + +def make_read_preference( + mode: int, tag_sets: Optional[_TagSets], max_staleness: int = -1 +) -> _ServerMode: + if mode == _PRIMARY: + if tag_sets not in (None, [{}]): + raise ConfigurationError("Read preference primary cannot be combined with tags") + if max_staleness != -1: + raise ConfigurationError( + "Read preference primary cannot be combined with maxStalenessSeconds" + ) + return Primary() + return _ALL_READ_PREFERENCES[mode](tag_sets, max_staleness) # type: ignore + + +_MODES = ( + "PRIMARY", + "PRIMARY_PREFERRED", + "SECONDARY", + "SECONDARY_PREFERRED", + "NEAREST", +) + + +class ReadPreference: + """An enum that defines some commonly used read preference modes. + + Apps can also create a custom read preference, for example:: + + Nearest(tag_sets=[{"node":"analytics"}]) + + See :doc:`/examples/high_availability` for code examples. + + A read preference is used in three cases: + + :class:`~pymongo.mongo_client.MongoClient` connected to a single mongod: + + - ``PRIMARY``: Queries are allowed if the server is standalone or a replica + set primary. + - All other modes allow queries to standalone servers, to a replica set + primary, or to replica set secondaries. + + :class:`~pymongo.mongo_client.MongoClient` initialized with the + ``replicaSet`` option: + + - ``PRIMARY``: Read from the primary. This is the default, and provides the + strongest consistency. If no primary is available, raise + :class:`~pymongo.errors.AutoReconnect`. + + - ``PRIMARY_PREFERRED``: Read from the primary if available, or if there is + none, read from a secondary. + + - ``SECONDARY``: Read from a secondary. If no secondary is available, + raise :class:`~pymongo.errors.AutoReconnect`. + + - ``SECONDARY_PREFERRED``: Read from a secondary if available, otherwise + from the primary. + + - ``NEAREST``: Read from any member. + + :class:`~pymongo.mongo_client.MongoClient` connected to a mongos, with a + sharded cluster of replica sets: + + - ``PRIMARY``: Read from the primary of the shard, or raise + :class:`~pymongo.errors.OperationFailure` if there is none. + This is the default. + + - ``PRIMARY_PREFERRED``: Read from the primary of the shard, or if there is + none, read from a secondary of the shard. + + - ``SECONDARY``: Read from a secondary of the shard, or raise + :class:`~pymongo.errors.OperationFailure` if there is none. + + - ``SECONDARY_PREFERRED``: Read from a secondary of the shard if available, + otherwise from the shard primary. + + - ``NEAREST``: Read from any shard member. + """ + + PRIMARY = Primary() + PRIMARY_PREFERRED = PrimaryPreferred() + SECONDARY = Secondary() + SECONDARY_PREFERRED = SecondaryPreferred() + NEAREST = Nearest() + + +def read_pref_mode_from_name(name: str) -> int: + """Get the read preference mode from mongos/uri name.""" + return _MONGOS_MODES.index(name) + + +class MovingAverage: + """Tracks an exponentially-weighted moving average.""" + + average: Optional[float] + + def __init__(self) -> None: + self.average = None + + def add_sample(self, sample: float) -> None: + if sample < 0: + # Likely system time change while waiting for hello response + # and not using time.monotonic. Ignore it, the next one will + # probably be valid. + return + if self.average is None: + self.average = sample + else: + # The Server Selection Spec requires an exponentially weighted + # average with alpha = 0.2. 
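+                # Equivalently: new = (1 - alpha) * old + alpha * sample with
+                # alpha = 0.2; e.g. samples 10 then 20 give 10.0 -> 12.0.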
+ self.average = 0.8 * self.average + 0.2 * sample + + def get(self) -> Optional[float]: + """Get the calculated average, or None if no samples yet.""" + return self.average + + def reset(self) -> None: + self.average = None diff --git a/venv/Lib/site-packages/pymongo/response.py b/venv/Lib/site-packages/pymongo/response.py new file mode 100644 index 00000000..5cdd3e7e --- /dev/null +++ b/venv/Lib/site-packages/pymongo/response.py @@ -0,0 +1,131 @@ +# Copyright 2014-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Represent a response from the server.""" +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence, Union + +if TYPE_CHECKING: + from datetime import timedelta + + from pymongo.message import _OpMsg, _OpReply + from pymongo.pool import Connection + from pymongo.typings import _Address, _DocumentOut + + +class Response: + __slots__ = ("_data", "_address", "_request_id", "_duration", "_from_command", "_docs") + + def __init__( + self, + data: Union[_OpMsg, _OpReply], + address: _Address, + request_id: int, + duration: Optional[timedelta], + from_command: bool, + docs: Sequence[Mapping[str, Any]], + ): + """Represent a response from the server. + + :param data: A network response message. + :param address: (host, port) of the source server. + :param request_id: The request id of this operation. + :param duration: The duration of the operation. + :param from_command: if the response is the result of a db command. + """ + self._data = data + self._address = address + self._request_id = request_id + self._duration = duration + self._from_command = from_command + self._docs = docs + + @property + def data(self) -> Union[_OpMsg, _OpReply]: + """Server response's raw BSON bytes.""" + return self._data + + @property + def address(self) -> _Address: + """(host, port) of the source server.""" + return self._address + + @property + def request_id(self) -> int: + """The request id of this operation.""" + return self._request_id + + @property + def duration(self) -> Optional[timedelta]: + """The duration of the operation.""" + return self._duration + + @property + def from_command(self) -> bool: + """If the response is a result from a db command.""" + return self._from_command + + @property + def docs(self) -> Sequence[Mapping[str, Any]]: + """The decoded document(s).""" + return self._docs + + +class PinnedResponse(Response): + __slots__ = ("_conn", "_more_to_come") + + def __init__( + self, + data: Union[_OpMsg, _OpReply], + address: _Address, + conn: Connection, + request_id: int, + duration: Optional[timedelta], + from_command: bool, + docs: list[_DocumentOut], + more_to_come: bool, + ): + """Represent a response to an exhaust cursor's initial query. + + :param data: A network response message. + :param address: (host, port) of the source server. + :param conn: The Connection used for the initial query. + :param request_id: The request id of this operation. + :param duration: The duration of the operation. 
+ :param from_command: If the response is the result of a db command. + :param docs: List of documents. + :param more_to_come: Bool indicating whether cursor is ready to be + exhausted. + """ + super().__init__(data, address, request_id, duration, from_command, docs) + self._conn = conn + self._more_to_come = more_to_come + + @property + def conn(self) -> Connection: + """The Connection used for the initial query. + + The server will send batches on this socket, without waiting for + getMores from the client, until the result set is exhausted or there + is an error. + """ + return self._conn + + @property + def more_to_come(self) -> bool: + """If true, server is ready to send batches on the socket until the + result set is exhausted or there is an error. + """ + return self._more_to_come diff --git a/venv/Lib/site-packages/pymongo/results.py b/venv/Lib/site-packages/pymongo/results.py new file mode 100644 index 00000000..f5728656 --- /dev/null +++ b/venv/Lib/site-packages/pymongo/results.py @@ -0,0 +1,242 @@ +# Copyright 2015-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Result class definitions.""" +from __future__ import annotations + +from typing import Any, Mapping, Optional, cast + +from pymongo.errors import InvalidOperation + + +class _WriteResult: + """Base class for write result classes.""" + + __slots__ = ("__acknowledged",) + + def __init__(self, acknowledged: bool) -> None: + self.__acknowledged = acknowledged + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.__acknowledged})" + + def _raise_if_unacknowledged(self, property_name: str) -> None: + """Raise an exception on property access if unacknowledged.""" + if not self.__acknowledged: + raise InvalidOperation( + f"A value for {property_name} is not available when " + "the write is unacknowledged. Check the " + "acknowledged attribute to avoid this " + "error." + ) + + @property + def acknowledged(self) -> bool: + """Is this the result of an acknowledged write operation? + + The :attr:`acknowledged` attribute will be ``False`` when using + ``WriteConcern(w=0)``, otherwise ``True``. + + .. note:: + If the :attr:`acknowledged` attribute is ``False`` all other + attributes of this class will raise + :class:`~pymongo.errors.InvalidOperation` when accessed. Values for + other attributes cannot be determined if the write operation was + unacknowledged. + + .. 
seealso:: + :class:`~pymongo.write_concern.WriteConcern` + """ + return self.__acknowledged + + +class InsertOneResult(_WriteResult): + """The return type for :meth:`~pymongo.collection.Collection.insert_one`.""" + + __slots__ = ("__inserted_id",) + + def __init__(self, inserted_id: Any, acknowledged: bool) -> None: + self.__inserted_id = inserted_id + super().__init__(acknowledged) + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}({self.__inserted_id!r}, acknowledged={self.acknowledged})" + ) + + @property + def inserted_id(self) -> Any: + """The inserted document's _id.""" + return self.__inserted_id + + +class InsertManyResult(_WriteResult): + """The return type for :meth:`~pymongo.collection.Collection.insert_many`.""" + + __slots__ = ("__inserted_ids",) + + def __init__(self, inserted_ids: list[Any], acknowledged: bool) -> None: + self.__inserted_ids = inserted_ids + super().__init__(acknowledged) + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}({self.__inserted_ids!r}, acknowledged={self.acknowledged})" + ) + + @property + def inserted_ids(self) -> list[Any]: + """A list of _ids of the inserted documents, in the order provided. + + .. note:: If ``False`` is passed for the `ordered` parameter to + :meth:`~pymongo.collection.Collection.insert_many` the server + may have inserted the documents in a different order than what + is presented here. + """ + return self.__inserted_ids + + +class UpdateResult(_WriteResult): + """The return type for :meth:`~pymongo.collection.Collection.update_one`, + :meth:`~pymongo.collection.Collection.update_many`, and + :meth:`~pymongo.collection.Collection.replace_one`. + """ + + __slots__ = ("__raw_result",) + + def __init__(self, raw_result: Optional[Mapping[str, Any]], acknowledged: bool): + self.__raw_result = raw_result + super().__init__(acknowledged) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.__raw_result!r}, acknowledged={self.acknowledged})" + + @property + def raw_result(self) -> Optional[Mapping[str, Any]]: + """The raw result document returned by the server.""" + return self.__raw_result + + @property + def matched_count(self) -> int: + """The number of documents matched for this update.""" + self._raise_if_unacknowledged("matched_count") + if self.upserted_id is not None: + return 0 + assert self.__raw_result is not None + return self.__raw_result.get("n", 0) + + @property + def modified_count(self) -> int: + """The number of documents modified.""" + self._raise_if_unacknowledged("modified_count") + assert self.__raw_result is not None + return cast(int, self.__raw_result.get("nModified")) + + @property + def upserted_id(self) -> Any: + """The _id of the inserted document if an upsert took place. Otherwise + ``None``. 
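+
+        Like the other attributes, accessing this on an unacknowledged result
+        raises :class:`~pymongo.errors.InvalidOperation`.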
+ """ + self._raise_if_unacknowledged("upserted_id") + assert self.__raw_result is not None + return self.__raw_result.get("upserted") + + +class DeleteResult(_WriteResult): + """The return type for :meth:`~pymongo.collection.Collection.delete_one` + and :meth:`~pymongo.collection.Collection.delete_many` + """ + + __slots__ = ("__raw_result",) + + def __init__(self, raw_result: Mapping[str, Any], acknowledged: bool) -> None: + self.__raw_result = raw_result + super().__init__(acknowledged) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.__raw_result!r}, acknowledged={self.acknowledged})" + + @property + def raw_result(self) -> Mapping[str, Any]: + """The raw result document returned by the server.""" + return self.__raw_result + + @property + def deleted_count(self) -> int: + """The number of documents deleted.""" + self._raise_if_unacknowledged("deleted_count") + return self.__raw_result.get("n", 0) + + +class BulkWriteResult(_WriteResult): + """An object wrapper for bulk API write results.""" + + __slots__ = ("__bulk_api_result",) + + def __init__(self, bulk_api_result: dict[str, Any], acknowledged: bool) -> None: + """Create a BulkWriteResult instance. + + :param bulk_api_result: A result dict from the bulk API + :param acknowledged: Was this write result acknowledged? If ``False`` + then all properties of this object will raise + :exc:`~pymongo.errors.InvalidOperation`. + """ + self.__bulk_api_result = bulk_api_result + super().__init__(acknowledged) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.__bulk_api_result!r}, acknowledged={self.acknowledged})" + + @property + def bulk_api_result(self) -> dict[str, Any]: + """The raw bulk API result.""" + return self.__bulk_api_result + + @property + def inserted_count(self) -> int: + """The number of documents inserted.""" + self._raise_if_unacknowledged("inserted_count") + return cast(int, self.__bulk_api_result.get("nInserted")) + + @property + def matched_count(self) -> int: + """The number of documents matched for an update.""" + self._raise_if_unacknowledged("matched_count") + return cast(int, self.__bulk_api_result.get("nMatched")) + + @property + def modified_count(self) -> int: + """The number of documents modified.""" + self._raise_if_unacknowledged("modified_count") + return cast(int, self.__bulk_api_result.get("nModified")) + + @property + def deleted_count(self) -> int: + """The number of documents deleted.""" + self._raise_if_unacknowledged("deleted_count") + return cast(int, self.__bulk_api_result.get("nRemoved")) + + @property + def upserted_count(self) -> int: + """The number of documents upserted.""" + self._raise_if_unacknowledged("upserted_count") + return cast(int, self.__bulk_api_result.get("nUpserted")) + + @property + def upserted_ids(self) -> Optional[dict[int, Any]]: + """A map of operation index to the _id of the upserted document.""" + self._raise_if_unacknowledged("upserted_ids") + if self.__bulk_api_result: + return {upsert["index"]: upsert["_id"] for upsert in self.bulk_api_result["upserted"]} + return None diff --git a/venv/Lib/site-packages/pymongo/saslprep.py b/venv/Lib/site-packages/pymongo/saslprep.py new file mode 100644 index 00000000..7fb546f6 --- /dev/null +++ b/venv/Lib/site-packages/pymongo/saslprep.py @@ -0,0 +1,116 @@ +# Copyright 2016-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""An implementation of RFC4013 SASLprep.""" +from __future__ import annotations + +from typing import Any, Optional + +try: + import stringprep +except ImportError: + HAVE_STRINGPREP = False + + def saslprep( + data: Any, + prohibit_unassigned_code_points: Optional[bool] = True, # noqa: ARG001 + ) -> Any: + """SASLprep dummy""" + if isinstance(data, str): + raise TypeError( + "The stringprep module is not available. Usernames and " + "passwords must be instances of bytes." + ) + return data + +else: + HAVE_STRINGPREP = True + import unicodedata + + # RFC4013 section 2.3 prohibited output. + _PROHIBITED = ( + # A strict reading of RFC 4013 requires table c12 here, but + # characters from it are mapped to SPACE in the Map step. Can + # normalization reintroduce them somehow? + stringprep.in_table_c12, + stringprep.in_table_c21_c22, + stringprep.in_table_c3, + stringprep.in_table_c4, + stringprep.in_table_c5, + stringprep.in_table_c6, + stringprep.in_table_c7, + stringprep.in_table_c8, + stringprep.in_table_c9, + ) + + def saslprep(data: Any, prohibit_unassigned_code_points: Optional[bool] = True) -> Any: + """An implementation of RFC4013 SASLprep. + + :param data: The string to SASLprep. Unicode strings + (:class:`str`) are supported. Byte strings + (:class:`bytes`) are ignored. + :param prohibit_unassigned_code_points: True / False. RFC 3454 + and RFCs for various SASL mechanisms distinguish between + `queries` (unassigned code points allowed) and + `stored strings` (unassigned code points prohibited). Defaults + to ``True`` (unassigned code points are prohibited). + + :return: The SASLprep'ed version of `data`. + """ + prohibited: Any + + if not isinstance(data, str): + return data + + if prohibit_unassigned_code_points: + prohibited = (*_PROHIBITED, stringprep.in_table_a1) + else: + prohibited = _PROHIBITED + + # RFC3454 section 2, step 1 - Map + # RFC4013 section 2.1 mappings + # Map Non-ASCII space characters to SPACE (U+0020). Map + # commonly mapped to nothing characters to, well, nothing. + in_table_c12 = stringprep.in_table_c12 + in_table_b1 = stringprep.in_table_b1 + data = "".join( + ["\u0020" if in_table_c12(elt) else elt for elt in data if not in_table_b1(elt)] + ) + + # RFC3454 section 2, step 2 - Normalize + # RFC4013 section 2.2 normalization + data = unicodedata.ucd_3_2_0.normalize("NFKC", data) + + in_table_d1 = stringprep.in_table_d1 + if in_table_d1(data[0]): + if not in_table_d1(data[-1]): + # RFC3454, Section 6, #3. If a string contains any + # RandALCat character, the first and last characters + # MUST be RandALCat characters. + raise ValueError("SASLprep: failed bidirectional check") + # RFC3454, Section 6, #2. If a string contains any RandALCat + # character, it MUST NOT contain any LCat character. + prohibited = (*prohibited, stringprep.in_table_d2) + else: + # RFC3454, Section 6, #3. Following the logic of #3, if + # the first character is not a RandALCat, no other character + # can be either. 
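+            # (Adding table D.1 itself to the prohibited set enforces that:
+            # any RandALCat character later in the string now fails the
+            # prohibit step below.)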
+ prohibited = (*prohibited, in_table_d1) + + # RFC3454 section 2, step 3 and 4 - Prohibit and check bidi + for char in data: + if any(in_table(char) for in_table in prohibited): + raise ValueError("SASLprep: failed prohibited character check") + + return data diff --git a/venv/Lib/site-packages/pymongo/server.py b/venv/Lib/site-packages/pymongo/server.py new file mode 100644 index 00000000..1c437a7e --- /dev/null +++ b/venv/Lib/site-packages/pymongo/server.py @@ -0,0 +1,346 @@ +# Copyright 2014-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you +# may not use this file except in compliance with the License. You +# may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. + +"""Communicate with one MongoDB server in a topology.""" +from __future__ import annotations + +import logging +from datetime import datetime +from typing import TYPE_CHECKING, Any, Callable, ContextManager, Optional, Union + +from bson import _decode_all_selective +from pymongo.errors import NotPrimaryError, OperationFailure +from pymongo.helpers import _check_command_response, _handle_reauth +from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log +from pymongo.message import _convert_exception, _GetMore, _OpMsg, _Query +from pymongo.response import PinnedResponse, Response + +if TYPE_CHECKING: + from queue import Queue + from weakref import ReferenceType + + from bson.objectid import ObjectId + from pymongo.mongo_client import MongoClient, _MongoClientErrorHandler + from pymongo.monitor import Monitor + from pymongo.monitoring import _EventListeners + from pymongo.pool import Connection, Pool + from pymongo.read_preferences import _ServerMode + from pymongo.server_description import ServerDescription + from pymongo.typings import _DocumentOut + +_CURSOR_DOC_FIELDS = {"cursor": {"firstBatch": 1, "nextBatch": 1}} + + +class Server: + def __init__( + self, + server_description: ServerDescription, + pool: Pool, + monitor: Monitor, + topology_id: Optional[ObjectId] = None, + listeners: Optional[_EventListeners] = None, + events: Optional[ReferenceType[Queue]] = None, + ) -> None: + """Represent one MongoDB server.""" + self._description = server_description + self._pool = pool + self._monitor = monitor + self._topology_id = topology_id + self._publish = listeners is not None and listeners.enabled_for_server + self._listener = listeners + self._events = None + if self._publish: + self._events = events() # type: ignore[misc] + + def open(self) -> None: + """Start monitoring, or restart after a fork. + + Multiple calls have no effect. + """ + if not self._pool.opts.load_balanced: + self._monitor.open() + + def reset(self, service_id: Optional[ObjectId] = None) -> None: + """Clear the connection pool.""" + self.pool.reset(service_id) + + def close(self) -> None: + """Clear the connection pool and stop the monitor. + + Reconnect with open(). 
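+
+        (The monitor is stopped and the connection pool closed; a subsequent
+        call to ``open()`` restarts monitoring.)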
+ """ + if self._publish: + assert self._listener is not None + assert self._events is not None + self._events.put( + ( + self._listener.publish_server_closed, + (self._description.address, self._topology_id), + ) + ) + self._monitor.close() + self._pool.close() + + def request_check(self) -> None: + """Check the server's state soon.""" + self._monitor.request_check() + + @_handle_reauth + def run_operation( + self, + conn: Connection, + operation: Union[_Query, _GetMore], + read_preference: _ServerMode, + listeners: Optional[_EventListeners], + unpack_res: Callable[..., list[_DocumentOut]], + client: MongoClient, + ) -> Response: + """Run a _Query or _GetMore operation and return a Response object. + + This method is used only to run _Query/_GetMore operations from + cursors. + Can raise ConnectionFailure, OperationFailure, etc. + + :param conn: A Connection instance. + :param operation: A _Query or _GetMore object. + :param read_preference: The read preference to use. + :param listeners: Instance of _EventListeners or None. + :param unpack_res: A callable that decodes the wire protocol response. + """ + duration = None + assert listeners is not None + publish = listeners.enabled_for_commands + start = datetime.now() + + use_cmd = operation.use_command(conn) + more_to_come = operation.conn_mgr and operation.conn_mgr.more_to_come + if more_to_come: + request_id = 0 + else: + message = operation.get_message(read_preference, conn, use_cmd) + request_id, data, max_doc_size = self._split_message(message) + + cmd, dbn = operation.as_command(conn) + if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + clientId=client._topology_settings._topology_id, + message=_CommandStatusMessage.STARTED, + command=cmd, + commandName=next(iter(cmd)), + databaseName=dbn, + requestId=request_id, + operationId=request_id, + driverConnectionId=conn.id, + serverConnectionId=conn.server_connection_id, + serverHost=conn.address[0], + serverPort=conn.address[1], + serviceId=conn.service_id, + ) + + if publish: + cmd, dbn = operation.as_command(conn) + if "$db" not in cmd: + cmd["$db"] = dbn + assert listeners is not None + listeners.publish_command_start( + cmd, + dbn, + request_id, + conn.address, + conn.server_connection_id, + service_id=conn.service_id, + ) + + try: + if more_to_come: + reply = conn.receive_message(None) + else: + conn.send_message(data, max_doc_size) + reply = conn.receive_message(request_id) + + # Unpack and check for command errors. 
+ if use_cmd: + user_fields = _CURSOR_DOC_FIELDS + legacy_response = False + else: + user_fields = None + legacy_response = True + docs = unpack_res( + reply, + operation.cursor_id, + operation.codec_options, + legacy_response=legacy_response, + user_fields=user_fields, + ) + if use_cmd: + first = docs[0] + operation.client._process_response(first, operation.session) + _check_command_response(first, conn.max_wire_version) + except Exception as exc: + duration = datetime.now() - start + if isinstance(exc, (NotPrimaryError, OperationFailure)): + failure: _DocumentOut = exc.details # type: ignore[assignment] + else: + failure = _convert_exception(exc) + if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + clientId=client._topology_settings._topology_id, + message=_CommandStatusMessage.FAILED, + durationMS=duration, + failure=failure, + commandName=next(iter(cmd)), + databaseName=dbn, + requestId=request_id, + operationId=request_id, + driverConnectionId=conn.id, + serverConnectionId=conn.server_connection_id, + serverHost=conn.address[0], + serverPort=conn.address[1], + serviceId=conn.service_id, + isServerSideError=isinstance(exc, OperationFailure), + ) + if publish: + assert listeners is not None + listeners.publish_command_failure( + duration, + failure, + operation.name, + request_id, + conn.address, + conn.server_connection_id, + service_id=conn.service_id, + database_name=dbn, + ) + raise + duration = datetime.now() - start + # Must publish in find / getMore / explain command response + # format. + if use_cmd: + res = docs[0] + elif operation.name == "explain": + res = docs[0] if docs else {} + else: + res = {"cursor": {"id": reply.cursor_id, "ns": operation.namespace()}, "ok": 1} # type: ignore[union-attr] + if operation.name == "find": + res["cursor"]["firstBatch"] = docs + else: + res["cursor"]["nextBatch"] = docs + if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + clientId=client._topology_settings._topology_id, + message=_CommandStatusMessage.SUCCEEDED, + durationMS=duration, + reply=res, + commandName=next(iter(cmd)), + databaseName=dbn, + requestId=request_id, + operationId=request_id, + driverConnectionId=conn.id, + serverConnectionId=conn.server_connection_id, + serverHost=conn.address[0], + serverPort=conn.address[1], + serviceId=conn.service_id, + ) + if publish: + assert listeners is not None + listeners.publish_command_success( + duration, + res, + operation.name, + request_id, + conn.address, + conn.server_connection_id, + service_id=conn.service_id, + database_name=dbn, + ) + + # Decrypt response. + client = operation.client + if client and client._encrypter: + if use_cmd: + decrypted = client._encrypter.decrypt(reply.raw_command_response()) + docs = _decode_all_selective(decrypted, operation.codec_options, user_fields) + + response: Response + + if client._should_pin_cursor(operation.session) or operation.exhaust: + conn.pin_cursor() + if isinstance(reply, _OpMsg): + # In OP_MSG, the server keeps sending only if the + # more_to_come flag is set. + more_to_come = reply.more_to_come + else: + # In OP_REPLY, the server keeps sending until cursor_id is 0. 
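+                # Illustrative example: an exhaust "find" pins this
+                # connection, and each subsequent receive_message() call
+                # reads the next batch without sending another getMore,
+                # until the server returns cursor_id 0.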
+ more_to_come = bool(operation.exhaust and reply.cursor_id) + if operation.conn_mgr: + operation.conn_mgr.update_exhaust(more_to_come) + response = PinnedResponse( + data=reply, + address=self._description.address, + conn=conn, + duration=duration, + request_id=request_id, + from_command=use_cmd, + docs=docs, + more_to_come=more_to_come, + ) + else: + response = Response( + data=reply, + address=self._description.address, + duration=duration, + request_id=request_id, + from_command=use_cmd, + docs=docs, + ) + + return response + + def checkout( + self, handler: Optional[_MongoClientErrorHandler] = None + ) -> ContextManager[Connection]: + return self.pool.checkout(handler) + + @property + def description(self) -> ServerDescription: + return self._description + + @description.setter + def description(self, server_description: ServerDescription) -> None: + assert server_description.address == self._description.address + self._description = server_description + + @property + def pool(self) -> Pool: + return self._pool + + def _split_message( + self, message: Union[tuple[int, Any], tuple[int, Any, int]] + ) -> tuple[int, Any, int]: + """Return request_id, data, max_doc_size. + + :param message: (request_id, data, max_doc_size) or (request_id, data) + """ + if len(message) == 3: + return message # type: ignore[return-value] + else: + # get_more and kill_cursors messages don't include BSON documents. + request_id, data = message # type: ignore[misc] + return request_id, data, 0 + + def __repr__(self) -> str: + return f"<{self.__class__.__name__} {self._description!r}>" diff --git a/venv/Lib/site-packages/pymongo/server_api.py b/venv/Lib/site-packages/pymongo/server_api.py new file mode 100644 index 00000000..4a746008 --- /dev/null +++ b/venv/Lib/site-packages/pymongo/server_api.py @@ -0,0 +1,173 @@ +# Copyright 2020-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you +# may not use this file except in compliance with the License. You +# may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. + +"""Support for MongoDB Stable API. + +.. _versioned-api-ref: + +MongoDB Stable API +===================== + +Starting in MongoDB 5.0, applications can specify the server API version +to use when creating a :class:`~pymongo.mongo_client.MongoClient`. Doing so +ensures that the driver behaves in a manner compatible with that server API +version, regardless of the server's actual release version. + +Declaring an API Version +```````````````````````` + +.. attention:: Stable API requires MongoDB >=5.0. + +To configure MongoDB Stable API, pass the ``server_api`` keyword option to +:class:`~pymongo.mongo_client.MongoClient`:: + + >>> from pymongo.mongo_client import MongoClient + >>> from pymongo.server_api import ServerApi + >>> + >>> # Declare API version "1" for MongoClient "client" + >>> server_api = ServerApi('1') + >>> client = MongoClient(server_api=server_api) + +The declared API version is applied to all commands run through ``client``, +including those sent through the generic +:meth:`~pymongo.database.Database.command` helper. + +.. 
note:: Declaring an API version on the
+   :class:`~pymongo.mongo_client.MongoClient` **and** specifying stable
+   API options in :meth:`~pymongo.database.Database.command` command document
+   is not supported and will lead to undefined behaviour.
+
+To run any command without declaring a server API version or using a different
+API version, create a separate :class:`~pymongo.mongo_client.MongoClient`
+instance.
+
+Strict Mode
+```````````
+
+Configuring ``strict`` mode will cause the MongoDB server to reject all
+commands that are not part of the declared :attr:`ServerApi.version`. This
+includes command options and aggregation pipeline stages.
+
+For example::
+
+    >>> server_api = ServerApi('1', strict=True)
+    >>> client = MongoClient(server_api=server_api)
+    >>> client.test.command('count', 'test')
+    Traceback (most recent call last):
+    ...
+    pymongo.errors.OperationFailure: Provided apiStrict:true, but the command count is not in API Version 1, full error: {'ok': 0.0, 'errmsg': 'Provided apiStrict:true, but the command count is not in API Version 1', 'code': 323, 'codeName': 'APIStrictError'}
+
+Detecting API Deprecations
+``````````````````````````
+
+The ``deprecationErrors`` option can be used to enable command failures
+when using functionality that is deprecated from the configured
+:attr:`ServerApi.version`. For example::
+
+    >>> server_api = ServerApi('1', deprecation_errors=True)
+    >>> client = MongoClient(server_api=server_api)
+
+Note that at the time of this writing, no deprecated APIs exist.
+
+Classes
+=======
+"""
+from __future__ import annotations
+
+from typing import Any, MutableMapping, Optional
+
+
+class ServerApiVersion:
+    """An enum that defines values for :attr:`ServerApi.version`.
+
+    .. versionadded:: 3.12
+    """
+
+    V1 = "1"
+    """Server API version "1"."""
+
+
+class ServerApi:
+    """MongoDB Stable API."""
+
+    def __init__(
+        self, version: str, strict: Optional[bool] = None, deprecation_errors: Optional[bool] = None
+    ):
+        """Options to configure MongoDB Stable API.
+
+        :param version: The API version string. Must be one of the values in
+            :class:`ServerApiVersion`.
+        :param strict: Set to ``True`` to enable API strict mode.
+            Defaults to ``None`` which means "use the server's default".
+        :param deprecation_errors: Set to ``True`` to enable
+            deprecation errors. Defaults to ``None`` which means "use the
+            server's default".
+
+        .. versionadded:: 3.12
+        """
+        if version != ServerApiVersion.V1:
+            raise ValueError(f"Unknown ServerApi version: {version}")
+        if strict is not None and not isinstance(strict, bool):
+            raise TypeError(
+                "Wrong type for ServerApi strict, value must be an instance "
+                f"of bool, not {type(strict)}"
+            )
+        if deprecation_errors is not None and not isinstance(deprecation_errors, bool):
+            raise TypeError(
+                "Wrong type for ServerApi deprecation_errors, value must be "
+                f"an instance of bool, not {type(deprecation_errors)}"
+            )
+        self._version = version
+        self._strict = strict
+        self._deprecation_errors = deprecation_errors
+
+    @property
+    def version(self) -> str:
+        """The API version setting.
+
+        This value is sent to the server in the "apiVersion" field.
+        """
+        return self._version
+
+    @property
+    def strict(self) -> Optional[bool]:
+        """The API strict mode setting.
+
+        When set, this value is sent to the server in the "apiStrict" field.
+        """
+        return self._strict
+
+    @property
+    def deprecation_errors(self) -> Optional[bool]:
+        """The API deprecation errors setting.
+ + When set, this value is sent to the server in the + "apiDeprecationErrors" field. + """ + return self._deprecation_errors + + +def _add_to_command(cmd: MutableMapping[str, Any], server_api: Optional[ServerApi]) -> None: + """Internal helper which adds API versioning options to a command. + + :param cmd: The command. + :param server_api: A :class:`ServerApi` or ``None``. + """ + if not server_api: + return + cmd["apiVersion"] = server_api.version + if server_api.strict is not None: + cmd["apiStrict"] = server_api.strict + if server_api.deprecation_errors is not None: + cmd["apiDeprecationErrors"] = server_api.deprecation_errors diff --git a/venv/Lib/site-packages/pymongo/server_description.py b/venv/Lib/site-packages/pymongo/server_description.py new file mode 100644 index 00000000..6393fce0 --- /dev/null +++ b/venv/Lib/site-packages/pymongo/server_description.py @@ -0,0 +1,299 @@ +# Copyright 2014-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Represent one server the driver is connected to.""" +from __future__ import annotations + +import time +import warnings +from typing import Any, Mapping, Optional + +from bson import EPOCH_NAIVE +from bson.objectid import ObjectId +from pymongo.hello import Hello +from pymongo.server_type import SERVER_TYPE +from pymongo.typings import ClusterTime, _Address + + +class ServerDescription: + """Immutable representation of one server. 
+
+    :param address: A (host, port) pair
+    :param hello: Optional Hello instance
+    :param round_trip_time: Optional float
+    :param error: Optional, the last error attempting to connect to the server
+    :param min_round_trip_time: Optional float, the min latency from the most recent samples
+    """
+
+    __slots__ = (
+        "_address",
+        "_server_type",
+        "_all_hosts",
+        "_tags",
+        "_replica_set_name",
+        "_primary",
+        "_max_bson_size",
+        "_max_message_size",
+        "_max_write_batch_size",
+        "_min_wire_version",
+        "_max_wire_version",
+        "_round_trip_time",
+        "_min_round_trip_time",
+        "_me",
+        "_is_writable",
+        "_is_readable",
+        "_ls_timeout_minutes",
+        "_error",
+        "_set_version",
+        "_election_id",
+        "_cluster_time",
+        "_last_write_date",
+        "_last_update_time",
+        "_topology_version",
+    )
+
+    def __init__(
+        self,
+        address: _Address,
+        hello: Optional[Hello] = None,
+        round_trip_time: Optional[float] = None,
+        error: Optional[Exception] = None,
+        min_round_trip_time: float = 0.0,
+    ) -> None:
+        self._address = address
+        if not hello:
+            hello = Hello({})
+
+        self._server_type = hello.server_type
+        self._all_hosts = hello.all_hosts
+        self._tags = hello.tags
+        self._replica_set_name = hello.replica_set_name
+        self._primary = hello.primary
+        self._max_bson_size = hello.max_bson_size
+        self._max_message_size = hello.max_message_size
+        self._max_write_batch_size = hello.max_write_batch_size
+        self._min_wire_version = hello.min_wire_version
+        self._max_wire_version = hello.max_wire_version
+        self._set_version = hello.set_version
+        self._election_id = hello.election_id
+        self._cluster_time = hello.cluster_time
+        self._is_writable = hello.is_writable
+        self._is_readable = hello.is_readable
+        self._ls_timeout_minutes = hello.logical_session_timeout_minutes
+        self._round_trip_time = round_trip_time
+        self._min_round_trip_time = min_round_trip_time
+        self._me = hello.me
+        self._last_update_time = time.monotonic()
+        self._error = error
+        self._topology_version = hello.topology_version
+        if error:
+            details = getattr(error, "details", None)
+            if isinstance(details, dict):
+                self._topology_version = details.get("topologyVersion")
+
+        self._last_write_date: Optional[float]
+        if hello.last_write_date:
+            # Convert from datetime to seconds.
+            delta = hello.last_write_date - EPOCH_NAIVE
+            self._last_write_date = delta.total_seconds()
+        else:
+            self._last_write_date = None
+
+    @property
+    def address(self) -> _Address:
+        """The address (host, port) of this server."""
+        return self._address
+
+    @property
+    def server_type(self) -> int:
+        """The type of this server."""
+        return self._server_type
+
+    @property
+    def server_type_name(self) -> str:
+        """The server type as a human readable string.
+
+        ..
versionadded:: 3.4 + """ + return SERVER_TYPE._fields[self._server_type] + + @property + def all_hosts(self) -> set[tuple[str, int]]: + """List of hosts, passives, and arbiters known to this server.""" + return self._all_hosts + + @property + def tags(self) -> Mapping[str, Any]: + return self._tags + + @property + def replica_set_name(self) -> Optional[str]: + """Replica set name or None.""" + return self._replica_set_name + + @property + def primary(self) -> Optional[tuple[str, int]]: + """This server's opinion about who the primary is, or None.""" + return self._primary + + @property + def max_bson_size(self) -> int: + return self._max_bson_size + + @property + def max_message_size(self) -> int: + return self._max_message_size + + @property + def max_write_batch_size(self) -> int: + return self._max_write_batch_size + + @property + def min_wire_version(self) -> int: + return self._min_wire_version + + @property + def max_wire_version(self) -> int: + return self._max_wire_version + + @property + def set_version(self) -> Optional[int]: + return self._set_version + + @property + def election_id(self) -> Optional[ObjectId]: + return self._election_id + + @property + def cluster_time(self) -> Optional[ClusterTime]: + return self._cluster_time + + @property + def election_tuple(self) -> tuple[Optional[int], Optional[ObjectId]]: + warnings.warn( + "'election_tuple' is deprecated, use 'set_version' and 'election_id' instead", + DeprecationWarning, + stacklevel=2, + ) + return self._set_version, self._election_id + + @property + def me(self) -> Optional[tuple[str, int]]: + return self._me + + @property + def logical_session_timeout_minutes(self) -> Optional[int]: + return self._ls_timeout_minutes + + @property + def last_write_date(self) -> Optional[float]: + return self._last_write_date + + @property + def last_update_time(self) -> float: + return self._last_update_time + + @property + def round_trip_time(self) -> Optional[float]: + """The current average latency or None.""" + # This override is for unittesting only! 
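+        # (Note: tests may populate the class-level _host_to_round_trip_time
+        # dict, defined at the bottom of this class, to fake a measured RTT
+        # for a given address.)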
+        if self._address in self._host_to_round_trip_time:
+            return self._host_to_round_trip_time[self._address]
+
+        return self._round_trip_time
+
+    @property
+    def min_round_trip_time(self) -> float:
+        """The min latency from the most recent samples."""
+        return self._min_round_trip_time
+
+    @property
+    def error(self) -> Optional[Exception]:
+        """The last error attempting to connect to the server, or None."""
+        return self._error
+
+    @property
+    def is_writable(self) -> bool:
+        return self._is_writable
+
+    @property
+    def is_readable(self) -> bool:
+        return self._is_readable
+
+    @property
+    def mongos(self) -> bool:
+        return self._server_type == SERVER_TYPE.Mongos
+
+    @property
+    def is_server_type_known(self) -> bool:
+        return self.server_type != SERVER_TYPE.Unknown
+
+    @property
+    def retryable_writes_supported(self) -> bool:
+        """Checks if this server supports retryable writes."""
+        return (
+            self._server_type in (SERVER_TYPE.Mongos, SERVER_TYPE.RSPrimary)
+        ) or self._server_type == SERVER_TYPE.LoadBalancer
+
+    @property
+    def retryable_reads_supported(self) -> bool:
+        """Checks if this server supports retryable reads."""
+        return self._max_wire_version >= 6
+
+    @property
+    def topology_version(self) -> Optional[Mapping[str, Any]]:
+        return self._topology_version
+
+    def to_unknown(self, error: Optional[Exception] = None) -> ServerDescription:
+        unknown = ServerDescription(self.address, error=error)
+        unknown._topology_version = self.topology_version
+        return unknown
+
+    def __eq__(self, other: Any) -> bool:
+        if isinstance(other, ServerDescription):
+            return (
+                (self._address == other.address)
+                and (self._server_type == other.server_type)
+                and (self._min_wire_version == other.min_wire_version)
+                and (self._max_wire_version == other.max_wire_version)
+                and (self._me == other.me)
+                and (self._all_hosts == other.all_hosts)
+                and (self._tags == other.tags)
+                and (self._replica_set_name == other.replica_set_name)
+                and (self._set_version == other.set_version)
+                and (self._election_id == other.election_id)
+                and (self._primary == other.primary)
+                and (self._ls_timeout_minutes == other.logical_session_timeout_minutes)
+                and (self._error == other.error)
+            )
+
+        return NotImplemented
+
+    def __ne__(self, other: Any) -> bool:
+        return not self == other
+
+    def __repr__(self) -> str:
+        errmsg = ""
+        if self.error:
+            errmsg = f", error={self.error!r}"
+        return "<{} {} server_type: {}, rtt: {}{}>".format(
+            self.__class__.__name__,
+            self.address,
+            self.server_type_name,
+            self.round_trip_time,
+            errmsg,
+        )
+
+    # For unittesting only. Use under no circumstances!
+    _host_to_round_trip_time: dict = {}
diff --git a/venv/Lib/site-packages/pymongo/server_selectors.py b/venv/Lib/site-packages/pymongo/server_selectors.py
new file mode 100644
index 00000000..c22ad599
--- /dev/null
+++ b/venv/Lib/site-packages/pymongo/server_selectors.py
@@ -0,0 +1,174 @@
+# Copyright 2014-2016 MongoDB, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you
+# may not use this file except in compliance with the License. You
+# may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied. See the License for the specific language governing
+# permissions and limitations under the License.
+ +"""Criteria to select some ServerDescriptions from a TopologyDescription.""" +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence, TypeVar, cast + +from pymongo.server_type import SERVER_TYPE + +if TYPE_CHECKING: + from pymongo.server_description import ServerDescription + from pymongo.topology_description import TopologyDescription + + +T = TypeVar("T") +TagSet = Mapping[str, Any] +TagSets = Sequence[TagSet] + + +class Selection: + """Input or output of a server selector function.""" + + @classmethod + def from_topology_description(cls, topology_description: TopologyDescription) -> Selection: + known_servers = topology_description.known_servers + primary = None + for sd in known_servers: + if sd.server_type == SERVER_TYPE.RSPrimary: + primary = sd + break + + return Selection( + topology_description, + topology_description.known_servers, + topology_description.common_wire_version, + primary, + ) + + def __init__( + self, + topology_description: TopologyDescription, + server_descriptions: list[ServerDescription], + common_wire_version: Optional[int], + primary: Optional[ServerDescription], + ): + self.topology_description = topology_description + self.server_descriptions = server_descriptions + self.primary = primary + self.common_wire_version = common_wire_version + + def with_server_descriptions(self, server_descriptions: list[ServerDescription]) -> Selection: + return Selection( + self.topology_description, server_descriptions, self.common_wire_version, self.primary + ) + + def secondary_with_max_last_write_date(self) -> Optional[ServerDescription]: + secondaries = secondary_server_selector(self) + if secondaries.server_descriptions: + return max( + secondaries.server_descriptions, key=lambda sd: cast(float, sd.last_write_date) + ) + return None + + @property + def primary_selection(self) -> Selection: + primaries = [self.primary] if self.primary else [] + return self.with_server_descriptions(primaries) + + @property + def heartbeat_frequency(self) -> int: + return self.topology_description.heartbeat_frequency + + @property + def topology_type(self) -> int: + return self.topology_description.topology_type + + def __bool__(self) -> bool: + return bool(self.server_descriptions) + + def __getitem__(self, item: int) -> ServerDescription: + return self.server_descriptions[item] + + +def any_server_selector(selection: T) -> T: + return selection + + +def readable_server_selector(selection: Selection) -> Selection: + return selection.with_server_descriptions( + [s for s in selection.server_descriptions if s.is_readable] + ) + + +def writable_server_selector(selection: Selection) -> Selection: + return selection.with_server_descriptions( + [s for s in selection.server_descriptions if s.is_writable] + ) + + +def secondary_server_selector(selection: Selection) -> Selection: + return selection.with_server_descriptions( + [s for s in selection.server_descriptions if s.server_type == SERVER_TYPE.RSSecondary] + ) + + +def arbiter_server_selector(selection: Selection) -> Selection: + return selection.with_server_descriptions( + [s for s in selection.server_descriptions if s.server_type == SERVER_TYPE.RSArbiter] + ) + + +def writable_preferred_server_selector(selection: Selection) -> Selection: + """Like PrimaryPreferred but doesn't use tags or latency.""" + return writable_server_selector(selection) or secondary_server_selector(selection) + + +def apply_single_tag_set(tag_set: TagSet, selection: Selection) -> Selection: + """All servers matching 
one tag set. + + A tag set is a dict. A server matches if its tags are a superset: + A server tagged {'a': '1', 'b': '2'} matches the tag set {'a': '1'}. + + The empty tag set {} matches any server. + """ + + def tags_match(server_tags: Mapping[str, Any]) -> bool: + for key, value in tag_set.items(): + if key not in server_tags or server_tags[key] != value: + return False + + return True + + return selection.with_server_descriptions( + [s for s in selection.server_descriptions if tags_match(s.tags)] + ) + + +def apply_tag_sets(tag_sets: TagSets, selection: Selection) -> Selection: + """All servers match a list of tag sets. + + tag_sets is a list of dicts. The empty tag set {} matches any server, + and may be provided at the end of the list as a fallback. So + [{'a': 'value'}, {}] expresses a preference for servers tagged + {'a': 'value'}, but accepts any server if none matches the first + preference. + """ + for tag_set in tag_sets: + with_tag_set = apply_single_tag_set(tag_set, selection) + if with_tag_set: + return with_tag_set + + return selection.with_server_descriptions([]) + + +def secondary_with_tags_server_selector(tag_sets: TagSets, selection: Selection) -> Selection: + """All near-enough secondaries matching the tag sets.""" + return apply_tag_sets(tag_sets, secondary_server_selector(selection)) + + +def member_with_tags_server_selector(tag_sets: TagSets, selection: Selection) -> Selection: + """All near-enough members matching the tag sets.""" + return apply_tag_sets(tag_sets, readable_server_selector(selection)) diff --git a/venv/Lib/site-packages/pymongo/server_type.py b/venv/Lib/site-packages/pymongo/server_type.py new file mode 100644 index 00000000..937855cc --- /dev/null +++ b/venv/Lib/site-packages/pymongo/server_type.py @@ -0,0 +1,33 @@ +# Copyright 2014-2015 MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Type codes for MongoDB servers.""" +from __future__ import annotations + +from typing import NamedTuple + + +class _ServerType(NamedTuple): + Unknown: int + Mongos: int + RSPrimary: int + RSSecondary: int + RSArbiter: int + RSOther: int + RSGhost: int + Standalone: int + LoadBalancer: int + + +SERVER_TYPE = _ServerType(*range(9)) diff --git a/venv/Lib/site-packages/pymongo/settings.py b/venv/Lib/site-packages/pymongo/settings.py new file mode 100644 index 00000000..4a3e7be4 --- /dev/null +++ b/venv/Lib/site-packages/pymongo/settings.py @@ -0,0 +1,168 @@ +# Copyright 2014-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you +# may not use this file except in compliance with the License. You +# may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. 
+ +"""Represent MongoClient's configuration.""" +from __future__ import annotations + +import threading +import traceback +from typing import Any, Collection, Optional, Type, Union + +from bson.objectid import ObjectId +from pymongo import common, monitor, pool +from pymongo.common import LOCAL_THRESHOLD_MS, SERVER_SELECTION_TIMEOUT +from pymongo.errors import ConfigurationError +from pymongo.pool import Pool, PoolOptions +from pymongo.server_description import ServerDescription +from pymongo.topology_description import TOPOLOGY_TYPE, _ServerSelector + + +class TopologySettings: + def __init__( + self, + seeds: Optional[Collection[tuple[str, int]]] = None, + replica_set_name: Optional[str] = None, + pool_class: Optional[Type[Pool]] = None, + pool_options: Optional[PoolOptions] = None, + monitor_class: Optional[Type[monitor.Monitor]] = None, + condition_class: Optional[Type[threading.Condition]] = None, + local_threshold_ms: int = LOCAL_THRESHOLD_MS, + server_selection_timeout: int = SERVER_SELECTION_TIMEOUT, + heartbeat_frequency: int = common.HEARTBEAT_FREQUENCY, + server_selector: Optional[_ServerSelector] = None, + fqdn: Optional[str] = None, + direct_connection: Optional[bool] = False, + load_balanced: Optional[bool] = None, + srv_service_name: str = common.SRV_SERVICE_NAME, + srv_max_hosts: int = 0, + server_monitoring_mode: str = common.SERVER_MONITORING_MODE, + ): + """Represent MongoClient's configuration. + + Take a list of (host, port) pairs and optional replica set name. + """ + if heartbeat_frequency < common.MIN_HEARTBEAT_INTERVAL: + raise ConfigurationError( + "heartbeatFrequencyMS cannot be less than %d" + % (common.MIN_HEARTBEAT_INTERVAL * 1000,) + ) + + self._seeds: Collection[tuple[str, int]] = seeds or [("localhost", 27017)] + self._replica_set_name = replica_set_name + self._pool_class: Type[Pool] = pool_class or pool.Pool + self._pool_options: PoolOptions = pool_options or PoolOptions() + self._monitor_class: Type[monitor.Monitor] = monitor_class or monitor.Monitor + self._condition_class: Type[threading.Condition] = condition_class or threading.Condition + self._local_threshold_ms = local_threshold_ms + self._server_selection_timeout = server_selection_timeout + self._server_selector = server_selector + self._fqdn = fqdn + self._heartbeat_frequency = heartbeat_frequency + self._direct = direct_connection + self._load_balanced = load_balanced + self._srv_service_name = srv_service_name + self._srv_max_hosts = srv_max_hosts or 0 + self._server_monitoring_mode = server_monitoring_mode + + self._topology_id = ObjectId() + # Store the allocation traceback to catch unclosed clients in the + # test suite. 
+ self._stack = "".join(traceback.format_stack()) + + @property + def seeds(self) -> Collection[tuple[str, int]]: + """List of server addresses.""" + return self._seeds + + @property + def replica_set_name(self) -> Optional[str]: + return self._replica_set_name + + @property + def pool_class(self) -> Type[Pool]: + return self._pool_class + + @property + def pool_options(self) -> PoolOptions: + return self._pool_options + + @property + def monitor_class(self) -> Type[monitor.Monitor]: + return self._monitor_class + + @property + def condition_class(self) -> Type[threading.Condition]: + return self._condition_class + + @property + def local_threshold_ms(self) -> int: + return self._local_threshold_ms + + @property + def server_selection_timeout(self) -> int: + return self._server_selection_timeout + + @property + def server_selector(self) -> Optional[_ServerSelector]: + return self._server_selector + + @property + def heartbeat_frequency(self) -> int: + return self._heartbeat_frequency + + @property + def fqdn(self) -> Optional[str]: + return self._fqdn + + @property + def direct(self) -> Optional[bool]: + """Connect directly to a single server, or use a set of servers? + + True if there is one seed and no replica_set_name. + """ + return self._direct + + @property + def load_balanced(self) -> Optional[bool]: + """True if the client was configured to connect to a load balancer.""" + return self._load_balanced + + @property + def srv_service_name(self) -> str: + """The srvServiceName.""" + return self._srv_service_name + + @property + def srv_max_hosts(self) -> int: + """The srvMaxHosts.""" + return self._srv_max_hosts + + @property + def server_monitoring_mode(self) -> str: + """The serverMonitoringMode.""" + return self._server_monitoring_mode + + def get_topology_type(self) -> int: + if self.load_balanced: + return TOPOLOGY_TYPE.LoadBalanced + elif self.direct: + return TOPOLOGY_TYPE.Single + elif self.replica_set_name is not None: + return TOPOLOGY_TYPE.ReplicaSetNoPrimary + else: + return TOPOLOGY_TYPE.Unknown + + def get_server_descriptions(self) -> dict[Union[tuple[str, int], Any], ServerDescription]: + """Initial dict of (address, ServerDescription) for all seeds.""" + return {address: ServerDescription(address) for address in self.seeds} diff --git a/venv/Lib/site-packages/pymongo/socket_checker.py b/venv/Lib/site-packages/pymongo/socket_checker.py new file mode 100644 index 00000000..78861854 --- /dev/null +++ b/venv/Lib/site-packages/pymongo/socket_checker.py @@ -0,0 +1,105 @@ +# Copyright 2020-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Select / poll helper""" +from __future__ import annotations + +import errno +import select +import sys +from typing import Any, Optional, cast + +# PYTHON-2320: Jython does not fully support poll on SSL sockets, +# https://bugs.jython.org/issue2900 +_HAVE_POLL = hasattr(select, "poll") and not sys.platform.startswith("java") +_SelectError = getattr(select, "error", OSError) + + +def _errno_from_exception(exc: BaseException) -> Optional[int]: + if hasattr(exc, "errno"): + return cast(int, exc.errno) + if exc.args: + return cast(int, exc.args[0]) + return None + + +class SocketChecker: + def __init__(self) -> None: + self._poller: Optional[select.poll] + if _HAVE_POLL: + self._poller = select.poll() + else: + self._poller = None + + def select( + self, sock: Any, read: bool = False, write: bool = False, timeout: Optional[float] = 0 + ) -> bool: + """Select for reads or writes with a timeout in seconds (or None). + + Returns True if the socket is readable/writable, False on timeout. + """ + res: Any + while True: + try: + if self._poller: + mask = select.POLLERR | select.POLLHUP + if read: + mask = mask | select.POLLIN | select.POLLPRI + if write: + mask = mask | select.POLLOUT + self._poller.register(sock, mask) + try: + # poll() timeout is in milliseconds. select() + # timeout is in seconds. + timeout_ = None if timeout is None else timeout * 1000 + res = self._poller.poll(timeout_) + # poll returns a possibly-empty list containing + # (fd, event) 2-tuples for the descriptors that have + # events or errors to report. Return True if the list + # is not empty. + return bool(res) + finally: + self._poller.unregister(sock) + else: + rlist = [sock] if read else [] + wlist = [sock] if write else [] + res = select.select(rlist, wlist, [sock], timeout) + # select returns a 3-tuple of lists of objects that are + # ready: subsets of the first three arguments. Return + # True if any of the lists are not empty. + return any(res) + except (_SelectError, OSError) as exc: # type: ignore + if _errno_from_exception(exc) in (errno.EINTR, errno.EAGAIN): + continue + raise + + def socket_closed(self, sock: Any) -> bool: + """Return True if we know socket has been closed, False otherwise.""" + try: + return self.select(sock, read=True) + except (RuntimeError, KeyError): + # RuntimeError is raised during a concurrent poll. KeyError + # is raised by unregister if the socket is not in the poller. + # These errors should not be possible since we protect the + # poller with a mutex. + raise + except ValueError: + # ValueError is raised by register/unregister/select if the + # socket file descriptor is negative or outside the range for + # select (> 1023). + return True + except Exception: + # Any other exceptions should be attributed to a closed + # or invalid socket. + return True diff --git a/venv/Lib/site-packages/pymongo/srv_resolver.py b/venv/Lib/site-packages/pymongo/srv_resolver.py new file mode 100644 index 00000000..4ee1b1f5 --- /dev/null +++ b/venv/Lib/site-packages/pymongo/srv_resolver.py @@ -0,0 +1,138 @@ +# Copyright 2019-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you +# may not use this file except in compliance with the License. You +# may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. 
See the License for the specific language governing +# permissions and limitations under the License. + +"""Support for resolving hosts and options from mongodb+srv:// URIs.""" +from __future__ import annotations + +import ipaddress +import random +from typing import Any, Optional, Union + +from pymongo.common import CONNECT_TIMEOUT +from pymongo.errors import ConfigurationError + +try: + from dns import resolver + + _HAVE_DNSPYTHON = True +except ImportError: + _HAVE_DNSPYTHON = False + + +# dnspython can return bytes or str from various parts +# of its API depending on version. We always want str. +def maybe_decode(text: Union[str, bytes]) -> str: + if isinstance(text, bytes): + return text.decode() + return text + + +# PYTHON-2667 Lazily call dns.resolver methods for compatibility with eventlet. +def _resolve(*args: Any, **kwargs: Any) -> resolver.Answer: + if hasattr(resolver, "resolve"): + # dnspython >= 2 + return resolver.resolve(*args, **kwargs) + # dnspython 1.X + return resolver.query(*args, **kwargs) + + +_INVALID_HOST_MSG = ( + "Invalid URI host: %s is not a valid hostname for 'mongodb+srv://'. " + "Did you mean to use 'mongodb://'?" +) + + +class _SrvResolver: + def __init__( + self, + fqdn: str, + connect_timeout: Optional[float], + srv_service_name: str, + srv_max_hosts: int = 0, + ): + self.__fqdn = fqdn + self.__srv = srv_service_name + self.__connect_timeout = connect_timeout or CONNECT_TIMEOUT + self.__srv_max_hosts = srv_max_hosts or 0 + # Validate the fully qualified domain name. + try: + ipaddress.ip_address(fqdn) + raise ConfigurationError(_INVALID_HOST_MSG % ("an IP address",)) + except ValueError: + pass + + try: + self.__plist = self.__fqdn.split(".")[1:] + except Exception: + raise ConfigurationError(_INVALID_HOST_MSG % (fqdn,)) from None + self.__slen = len(self.__plist) + if self.__slen < 2: + raise ConfigurationError(_INVALID_HOST_MSG % (fqdn,)) + + def get_options(self) -> Optional[str]: + try: + results = _resolve(self.__fqdn, "TXT", lifetime=self.__connect_timeout) + except (resolver.NoAnswer, resolver.NXDOMAIN): + # No TXT records + return None + except Exception as exc: + raise ConfigurationError(str(exc)) from None + if len(results) > 1: + raise ConfigurationError("Only one TXT record is supported") + return (b"&".join([b"".join(res.strings) for res in results])).decode("utf-8") + + def _resolve_uri(self, encapsulate_errors: bool) -> resolver.Answer: + try: + results = _resolve( + "_" + self.__srv + "._tcp." + self.__fqdn, "SRV", lifetime=self.__connect_timeout + ) + except Exception as exc: + if not encapsulate_errors: + # Raise the original error. + raise + # Else, raise all errors as ConfigurationError. 
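+            # (Both behaviours are exercised by _get_srv_response_and_hosts
+            # below: get_hosts() passes encapsulate_errors=True, while
+            # get_hosts_and_min_ttl() passes False and lets the original
+            # DNS exception propagate.)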
+ raise ConfigurationError(str(exc)) from None + return results + + def _get_srv_response_and_hosts( + self, encapsulate_errors: bool + ) -> tuple[resolver.Answer, list[tuple[str, Any]]]: + results = self._resolve_uri(encapsulate_errors) + + # Construct address tuples + nodes = [ + (maybe_decode(res.target.to_text(omit_final_dot=True)), res.port) for res in results + ] + + # Validate hosts + for node in nodes: + try: + nlist = node[0].lower().split(".")[1:][-self.__slen :] + except Exception: + raise ConfigurationError(f"Invalid SRV host: {node[0]}") from None + if self.__plist != nlist: + raise ConfigurationError(f"Invalid SRV host: {node[0]}") + if self.__srv_max_hosts: + nodes = random.sample(nodes, min(self.__srv_max_hosts, len(nodes))) + return results, nodes + + def get_hosts(self) -> list[tuple[str, Any]]: + _, nodes = self._get_srv_response_and_hosts(True) + return nodes + + def get_hosts_and_min_ttl(self) -> tuple[list[tuple[str, Any]], int]: + results, nodes = self._get_srv_response_and_hosts(False) + rrset = results.rrset + ttl = rrset.ttl if rrset else 0 + return nodes, ttl diff --git a/venv/Lib/site-packages/pymongo/ssl_context.py b/venv/Lib/site-packages/pymongo/ssl_context.py new file mode 100644 index 00000000..1a042420 --- /dev/null +++ b/venv/Lib/site-packages/pymongo/ssl_context.py @@ -0,0 +1,40 @@ +# Copyright 2014-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you +# may not use this file except in compliance with the License. You +# may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. + +"""A fake SSLContext implementation.""" +from __future__ import annotations + +import ssl as _ssl + +# PROTOCOL_TLS_CLIENT is Python 3.6+ +PROTOCOL_SSLv23 = getattr(_ssl, "PROTOCOL_TLS_CLIENT", _ssl.PROTOCOL_SSLv23) +OP_NO_SSLv2 = getattr(_ssl, "OP_NO_SSLv2", 0) +OP_NO_SSLv3 = getattr(_ssl, "OP_NO_SSLv3", 0) +OP_NO_COMPRESSION = getattr(_ssl, "OP_NO_COMPRESSION", 0) +# Python 3.7+, OpenSSL 1.1.0h+ +OP_NO_RENEGOTIATION = getattr(_ssl, "OP_NO_RENEGOTIATION", 0) + +HAS_SNI = getattr(_ssl, "HAS_SNI", False) +IS_PYOPENSSL = False + +# Errors raised by SSL sockets when in non-blocking mode. +BLOCKING_IO_ERRORS = (_ssl.SSLWantReadError, _ssl.SSLWantWriteError) + +# Base Exception class +SSLError = _ssl.SSLError + +from ssl import SSLContext # noqa: F401,E402 + +if hasattr(_ssl, "VERIFY_CRL_CHECK_LEAF"): + from ssl import VERIFY_CRL_CHECK_LEAF # noqa: F401 diff --git a/venv/Lib/site-packages/pymongo/ssl_support.py b/venv/Lib/site-packages/pymongo/ssl_support.py new file mode 100644 index 00000000..849fbf70 --- /dev/null +++ b/venv/Lib/site-packages/pymongo/ssl_support.py @@ -0,0 +1,104 @@ +# Copyright 2014-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you +# may not use this file except in compliance with the License. You +# may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. 
See the License for the specific language governing +# permissions and limitations under the License. + +"""Support for SSL in PyMongo.""" +from __future__ import annotations + +from typing import Optional + +from pymongo.errors import ConfigurationError + +HAVE_SSL = True + +try: + import pymongo.pyopenssl_context as _ssl +except ImportError: + try: + import pymongo.ssl_context as _ssl # type: ignore[no-redef] + except ImportError: + HAVE_SSL = False + + +if HAVE_SSL: + # Note: The validate* functions below deal with users passing + # CPython ssl module constants to configure certificate verification + # at a high level. This is legacy behavior, but requires us to + # import the ssl module even if we're only using it for this purpose. + import ssl as _stdlibssl # noqa: F401 + from ssl import CERT_NONE, CERT_REQUIRED + + HAS_SNI = _ssl.HAS_SNI + IPADDR_SAFE = True + SSLError = _ssl.SSLError + BLOCKING_IO_ERRORS = _ssl.BLOCKING_IO_ERRORS + + def get_ssl_context( + certfile: Optional[str], + passphrase: Optional[str], + ca_certs: Optional[str], + crlfile: Optional[str], + allow_invalid_certificates: bool, + allow_invalid_hostnames: bool, + disable_ocsp_endpoint_check: bool, + ) -> _ssl.SSLContext: + """Create and return an SSLContext object.""" + verify_mode = CERT_NONE if allow_invalid_certificates else CERT_REQUIRED + ctx = _ssl.SSLContext(_ssl.PROTOCOL_SSLv23) + if verify_mode != CERT_NONE: + ctx.check_hostname = not allow_invalid_hostnames + else: + ctx.check_hostname = False + if hasattr(ctx, "check_ocsp_endpoint"): + ctx.check_ocsp_endpoint = not disable_ocsp_endpoint_check + if hasattr(ctx, "options"): + # Explicitly disable SSLv2, SSLv3 and TLS compression. Note that + # up to date versions of MongoDB 2.4 and above already disable + # SSLv2 and SSLv3, python disables SSLv2 by default in >= 2.7.7 + # and >= 3.3.4 and SSLv3 in >= 3.4.3. + ctx.options |= _ssl.OP_NO_SSLv2 + ctx.options |= _ssl.OP_NO_SSLv3 + ctx.options |= _ssl.OP_NO_COMPRESSION + ctx.options |= _ssl.OP_NO_RENEGOTIATION + if certfile is not None: + try: + ctx.load_cert_chain(certfile, None, passphrase) + except _ssl.SSLError as exc: + raise ConfigurationError(f"Private key doesn't match certificate: {exc}") from None + if crlfile is not None: + if _ssl.IS_PYOPENSSL: + raise ConfigurationError("tlsCRLFile cannot be used with PyOpenSSL") + # Match the server's behavior. + ctx.verify_flags = getattr( # type:ignore[attr-defined] + _ssl, "VERIFY_CRL_CHECK_LEAF", 0 + ) + ctx.load_verify_locations(crlfile) + if ca_certs is not None: + ctx.load_verify_locations(ca_certs) + elif verify_mode != CERT_NONE: + ctx.load_default_certs() + ctx.verify_mode = verify_mode + return ctx + +else: + + class SSLError(Exception): # type: ignore + pass + + HAS_SNI = False + IPADDR_SAFE = False + BLOCKING_IO_ERRORS = () # type:ignore[assignment] + + def get_ssl_context(*dummy): # type: ignore + """No ssl module, raise ConfigurationError.""" + raise ConfigurationError("The ssl module is not available.") diff --git a/venv/Lib/site-packages/pymongo/topology.py b/venv/Lib/site-packages/pymongo/topology.py new file mode 100644 index 00000000..99adcae6 --- /dev/null +++ b/venv/Lib/site-packages/pymongo/topology.py @@ -0,0 +1,1027 @@ +# Copyright 2014-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you +# may not use this file except in compliance with the License. 
You +# may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. + +"""Internal class to monitor a topology of one or more servers.""" + +from __future__ import annotations + +import logging +import os +import queue +import random +import sys +import time +import warnings +import weakref +from pathlib import Path +from typing import TYPE_CHECKING, Any, Callable, Mapping, Optional, cast + +from pymongo import _csot, common, helpers, periodic_executor +from pymongo.client_session import _ServerSession, _ServerSessionPool +from pymongo.errors import ( + ConnectionFailure, + InvalidOperation, + NetworkTimeout, + NotPrimaryError, + OperationFailure, + PyMongoError, + ServerSelectionTimeoutError, + WriteError, +) +from pymongo.hello import Hello +from pymongo.lock import _create_lock +from pymongo.logger import ( + _SERVER_SELECTION_LOGGER, + _debug_log, + _info_log, + _ServerSelectionStatusMessage, +) +from pymongo.monitor import SrvMonitor +from pymongo.pool import Pool, PoolOptions +from pymongo.server import Server +from pymongo.server_description import ServerDescription +from pymongo.server_selectors import ( + Selection, + any_server_selector, + arbiter_server_selector, + secondary_server_selector, + writable_server_selector, +) +from pymongo.topology_description import ( + SRV_POLLING_TOPOLOGIES, + TOPOLOGY_TYPE, + TopologyDescription, + _updated_topology_description_srv_polling, + updated_topology_description, +) + +if TYPE_CHECKING: + from bson import ObjectId + from pymongo.settings import TopologySettings + from pymongo.typings import ClusterTime, _Address + + +_pymongo_dir = str(Path(__file__).parent) + + +def process_events_queue(queue_ref: weakref.ReferenceType[queue.Queue]) -> bool: + q = queue_ref() + if not q: + return False # Cancel PeriodicExecutor. + + while True: + try: + event = q.get_nowait() + except queue.Empty: + break + else: + fn, args = event + fn(*args) + + return True # Continue PeriodicExecutor. + + +class Topology: + """Monitor a topology of one or more servers.""" + + def __init__(self, topology_settings: TopologySettings): + self._topology_id = topology_settings._topology_id + self._listeners = topology_settings._pool_options._event_listeners + self._publish_server = self._listeners is not None and self._listeners.enabled_for_server + self._publish_tp = self._listeners is not None and self._listeners.enabled_for_topology + + # Create events queue if there are publishers. 
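+        # (The queue decouples listener callbacks from the threads that
+        # detect topology changes: events are enqueued while holding the
+        # topology lock and drained later by the events executor created
+        # below, so user callbacks never run under the lock.)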
+ self._events = None + self.__events_executor: Any = None + + if self._publish_server or self._publish_tp: + self._events = queue.Queue(maxsize=100) + + if self._publish_tp: + assert self._events is not None + self._events.put((self._listeners.publish_topology_opened, (self._topology_id,))) + self._settings = topology_settings + topology_description = TopologyDescription( + topology_settings.get_topology_type(), + topology_settings.get_server_descriptions(), + topology_settings.replica_set_name, + None, + None, + topology_settings, + ) + + self._description = topology_description + if self._publish_tp: + assert self._events is not None + initial_td = TopologyDescription( + TOPOLOGY_TYPE.Unknown, {}, None, None, None, self._settings + ) + self._events.put( + ( + self._listeners.publish_topology_description_changed, + (initial_td, self._description, self._topology_id), + ) + ) + + for seed in topology_settings.seeds: + if self._publish_server: + assert self._events is not None + self._events.put((self._listeners.publish_server_opened, (seed, self._topology_id))) + + # Store the seed list to help diagnose errors in _error_message(). + self._seed_addresses = list(topology_description.server_descriptions()) + self._opened = False + self._closed = False + self._lock = _create_lock() + self._condition = self._settings.condition_class(self._lock) + self._servers: dict[_Address, Server] = {} + self._pid: Optional[int] = None + self._max_cluster_time: Optional[ClusterTime] = None + self._session_pool = _ServerSessionPool() + + if self._publish_server or self._publish_tp: + assert self._events is not None + weak: weakref.ReferenceType[queue.Queue] + + def target() -> bool: + return process_events_queue(weak) + + executor = periodic_executor.PeriodicExecutor( + interval=common.EVENTS_QUEUE_FREQUENCY, + min_interval=common.MIN_HEARTBEAT_INTERVAL, + target=target, + name="pymongo_events_thread", + ) + + # We strongly reference the executor and it weakly references + # the queue via this closure. When the topology is freed, stop + # the executor soon. + weak = weakref.ref(self._events, executor.close) + self.__events_executor = executor + executor.open() + + self._srv_monitor = None + if self._settings.fqdn is not None and not self._settings.load_balanced: + self._srv_monitor = SrvMonitor(self, self._settings) + + def open(self) -> None: + """Start monitoring, or restart after a fork. + + No effect if called multiple times. + + .. warning:: Topology is shared among multiple threads and is protected + by mutual exclusion. Using Topology from a process other than the one + that initialized it will emit a warning and may result in deadlock. To + prevent this from happening, MongoClient must be created after any + forking. + + """ + pid = os.getpid() + if self._pid is None: + self._pid = pid + elif pid != self._pid: + self._pid = pid + if sys.version_info[:2] >= (3, 12): + kwargs = {"skip_file_prefixes": (_pymongo_dir,)} + else: + kwargs = {"stacklevel": 6} + # Ignore B028 warning for missing stacklevel. + warnings.warn( # type: ignore[call-overload] # noqa: B028 + "MongoClient opened before fork. May not be entirely fork-safe, " + "proceed with caution. See PyMongo's documentation for details: " + "https://pymongo.readthedocs.io/en/stable/faq.html#" + "is-pymongo-fork-safe", + **kwargs, + ) + with self._lock: + # Close servers and clear the pools. + for server in self._servers.values(): + server.close() + # Reset the session pool to avoid duplicate sessions in + # the child process. 
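+                # (Server sessions are identified by lsid; if the parent and
+                # child processes both kept the pooled sessions after a fork
+                # they could reuse the same lsids concurrently, so the child
+                # discards them all.)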
+ self._session_pool.reset() + + with self._lock: + self._ensure_opened() + + def get_server_selection_timeout(self) -> float: + # CSOT: use remaining timeout when set. + timeout = _csot.remaining() + if timeout is None: + return self._settings.server_selection_timeout + return timeout + + def select_servers( + self, + selector: Callable[[Selection], Selection], + operation: str, + server_selection_timeout: Optional[float] = None, + address: Optional[_Address] = None, + operation_id: Optional[int] = None, + ) -> list[Server]: + """Return a list of Servers matching selector, or time out. + + :param selector: function that takes a list of Servers and returns + a subset of them. + :param operation: The name of the operation that the server is being selected for. + :param server_selection_timeout: maximum seconds to wait. + If not provided, the default value common.SERVER_SELECTION_TIMEOUT + is used. + :param address: optional server address to select. + + Calls self.open() if needed. + + Raises exc:`ServerSelectionTimeoutError` after + `server_selection_timeout` if no matching servers are found. + """ + if server_selection_timeout is None: + server_timeout = self.get_server_selection_timeout() + else: + server_timeout = server_selection_timeout + + with self._lock: + server_descriptions = self._select_servers_loop( + selector, server_timeout, operation, operation_id, address + ) + + return [ + cast(Server, self.get_server_by_address(sd.address)) for sd in server_descriptions + ] + + def _select_servers_loop( + self, + selector: Callable[[Selection], Selection], + timeout: float, + operation: str, + operation_id: Optional[int], + address: Optional[_Address], + ) -> list[ServerDescription]: + """select_servers() guts. Hold the lock when calling this.""" + now = time.monotonic() + end_time = now + timeout + logged_waiting = False + + if _SERVER_SELECTION_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _SERVER_SELECTION_LOGGER, + message=_ServerSelectionStatusMessage.STARTED, + selector=selector, + operation=operation, + operationId=operation_id, + topologyDescription=self.description, + clientId=self.description._topology_settings._topology_id, + ) + + server_descriptions = self._description.apply_selector( + selector, address, custom_selector=self._settings.server_selector + ) + + while not server_descriptions: + # No suitable servers. + if timeout == 0 or now > end_time: + if _SERVER_SELECTION_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _SERVER_SELECTION_LOGGER, + message=_ServerSelectionStatusMessage.FAILED, + selector=selector, + operation=operation, + operationId=operation_id, + topologyDescription=self.description, + clientId=self.description._topology_settings._topology_id, + failure=self._error_message(selector), + ) + raise ServerSelectionTimeoutError( + f"{self._error_message(selector)}, Timeout: {timeout}s, Topology Description: {self.description!r}" + ) + + if not logged_waiting: + _info_log( + _SERVER_SELECTION_LOGGER, + message=_ServerSelectionStatusMessage.WAITING, + selector=selector, + operation=operation, + operationId=operation_id, + topologyDescription=self.description, + clientId=self.description._topology_settings._topology_id, + remainingTimeMS=int(end_time - time.monotonic()), + ) + logged_waiting = True + + self._ensure_opened() + self._request_check_all() + + # Release the lock and wait for the topology description to + # change, or for a timeout. 
+            # came after our most recent apply_selector call, since we've
+            # held the lock until now.
+            self._condition.wait(common.MIN_HEARTBEAT_INTERVAL)
+            self._description.check_compatible()
+            now = time.monotonic()
+            server_descriptions = self._description.apply_selector(
+                selector, address, custom_selector=self._settings.server_selector
+            )
+
+        self._description.check_compatible()
+        return server_descriptions
+
+    def _select_server(
+        self,
+        selector: Callable[[Selection], Selection],
+        operation: str,
+        server_selection_timeout: Optional[float] = None,
+        address: Optional[_Address] = None,
+        deprioritized_servers: Optional[list[Server]] = None,
+        operation_id: Optional[int] = None,
+    ) -> Server:
+        servers = self.select_servers(
+            selector, operation, server_selection_timeout, address, operation_id
+        )
+        servers = _filter_servers(servers, deprioritized_servers)
+        if len(servers) == 1:
+            return servers[0]
+        server1, server2 = random.sample(servers, 2)
+        if server1.pool.operation_count <= server2.pool.operation_count:
+            return server1
+        else:
+            return server2
+
+    def select_server(
+        self,
+        selector: Callable[[Selection], Selection],
+        operation: str,
+        server_selection_timeout: Optional[float] = None,
+        address: Optional[_Address] = None,
+        deprioritized_servers: Optional[list[Server]] = None,
+        operation_id: Optional[int] = None,
+    ) -> Server:
+        """Like select_servers, but choose a random server if several match."""
+        server = self._select_server(
+            selector,
+            operation,
+            server_selection_timeout,
+            address,
+            deprioritized_servers,
+            operation_id=operation_id,
+        )
+        if _csot.get_timeout():
+            _csot.set_rtt(server.description.min_round_trip_time)
+        if _SERVER_SELECTION_LOGGER.isEnabledFor(logging.DEBUG):
+            _debug_log(
+                _SERVER_SELECTION_LOGGER,
+                message=_ServerSelectionStatusMessage.SUCCEEDED,
+                selector=selector,
+                operation=operation,
+                operationId=operation_id,
+                topologyDescription=self.description,
+                clientId=self.description._topology_settings._topology_id,
+                serverHost=server.description.address[0],
+                serverPort=server.description.address[1],
+            )
+        return server
+
+    def select_server_by_address(
+        self,
+        address: _Address,
+        operation: str,
+        server_selection_timeout: Optional[int] = None,
+        operation_id: Optional[int] = None,
+    ) -> Server:
+        """Return a Server for "address", reconnecting if necessary.
+
+        If the server's type is not known, request an immediate check of all
+        servers. Time out after "server_selection_timeout" if the server
+        cannot be reached.
+
+        :param address: A (host, port) pair.
+        :param operation: The name of the operation that the server is being selected for.
+        :param server_selection_timeout: maximum seconds to wait.
+            If not provided, the default value
+            common.SERVER_SELECTION_TIMEOUT is used.
+        :param operation_id: The unique id of the current operation being performed. Defaults to None if not provided.
+
+        Calls self.open() if needed.
+
+        Raises :exc:`ServerSelectionTimeoutError` after
+        `server_selection_timeout` if no matching servers are found.
+        """
+        return self.select_server(
+            any_server_selector,
+            operation,
+            server_selection_timeout,
+            address,
+            operation_id=operation_id,
+        )
+
+    def _process_change(
+        self,
+        server_description: ServerDescription,
+        reset_pool: bool = False,
+        interrupt_connections: bool = False,
+    ) -> None:
+        """Process a new ServerDescription on an opened topology.
+
+        Hold the lock when calling this.
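+
+        :param server_description: The new ServerDescription to install.
+        :param reset_pool: If True, clear the server's connection pool.
+        :param interrupt_connections: If True, also close the pool's active
+            connections when it is cleared.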
+ """ + td_old = self._description + sd_old = td_old._server_descriptions[server_description.address] + if _is_stale_server_description(sd_old, server_description): + # This is a stale hello response. Ignore it. + return + + new_td = updated_topology_description(self._description, server_description) + # CMAP: Ensure the pool is "ready" when the server is selectable. + if server_description.is_readable or ( + server_description.is_server_type_known and new_td.topology_type == TOPOLOGY_TYPE.Single + ): + server = self._servers.get(server_description.address) + if server: + server.pool.ready() + + suppress_event = (self._publish_server or self._publish_tp) and sd_old == server_description + if self._publish_server and not suppress_event: + assert self._events is not None + self._events.put( + ( + self._listeners.publish_server_description_changed, + (sd_old, server_description, server_description.address, self._topology_id), + ) + ) + + self._description = new_td + self._update_servers() + self._receive_cluster_time_no_lock(server_description.cluster_time) + + if self._publish_tp and not suppress_event: + assert self._events is not None + self._events.put( + ( + self._listeners.publish_topology_description_changed, + (td_old, self._description, self._topology_id), + ) + ) + + # Shutdown SRV polling for unsupported cluster types. + # This is only applicable if the old topology was Unknown, and the + # new one is something other than Unknown or Sharded. + if self._srv_monitor and ( + td_old.topology_type == TOPOLOGY_TYPE.Unknown + and self._description.topology_type not in SRV_POLLING_TOPOLOGIES + ): + self._srv_monitor.close() + + # Clear the pool from a failed heartbeat. + if reset_pool: + server = self._servers.get(server_description.address) + if server: + server.pool.reset(interrupt_connections=interrupt_connections) + + # Wake waiters in select_servers(). + self._condition.notify_all() + + def on_change( + self, + server_description: ServerDescription, + reset_pool: bool = False, + interrupt_connections: bool = False, + ) -> None: + """Process a new ServerDescription after an hello call completes.""" + # We do no I/O holding the lock. + with self._lock: + # Monitors may continue working on hello calls for some time + # after a call to Topology.close, so this method may be called at + # any time. Ensure the topology is open before processing the + # change. + # Any monitored server was definitely in the topology description + # once. Check if it's still in the description or if some state- + # change removed it. E.g., we got a host list from the primary + # that didn't include this server. + if self._opened and self._description.has_server(server_description.address): + self._process_change(server_description, reset_pool, interrupt_connections) + + def _process_srv_update(self, seedlist: list[tuple[str, Any]]) -> None: + """Process a new seedlist on an opened topology. + Hold the lock when calling this. + """ + td_old = self._description + if td_old.topology_type not in SRV_POLLING_TOPOLOGIES: + return + self._description = _updated_topology_description_srv_polling(self._description, seedlist) + + self._update_servers() + + if self._publish_tp: + assert self._events is not None + self._events.put( + ( + self._listeners.publish_topology_description_changed, + (td_old, self._description, self._topology_id), + ) + ) + + def on_srv_update(self, seedlist: list[tuple[str, Any]]) -> None: + """Process a new list of nodes obtained from scanning SRV records.""" + # We do no I/O holding the lock. 
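+        # _process_srv_update() only rewrites the in-memory topology
+        # description and the Server map; any newly added hosts are
+        # contacted later by their monitors, not while the lock is held.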
+ with self._lock: + if self._opened: + self._process_srv_update(seedlist) + + def get_server_by_address(self, address: _Address) -> Optional[Server]: + """Get a Server or None. + + Returns the current version of the server immediately, even if it's + Unknown or absent from the topology. Only use this in unittests. + In driver code, use select_server_by_address, since then you're + assured a recent view of the server's type and wire protocol version. + """ + return self._servers.get(address) + + def has_server(self, address: _Address) -> bool: + return address in self._servers + + def get_primary(self) -> Optional[_Address]: + """Return primary's address or None.""" + # Implemented here in Topology instead of MongoClient, so it can lock. + with self._lock: + topology_type = self._description.topology_type + if topology_type != TOPOLOGY_TYPE.ReplicaSetWithPrimary: + return None + + return writable_server_selector(self._new_selection())[0].address + + def _get_replica_set_members(self, selector: Callable[[Selection], Selection]) -> set[_Address]: + """Return set of replica set member addresses.""" + # Implemented here in Topology instead of MongoClient, so it can lock. + with self._lock: + topology_type = self._description.topology_type + if topology_type not in ( + TOPOLOGY_TYPE.ReplicaSetWithPrimary, + TOPOLOGY_TYPE.ReplicaSetNoPrimary, + ): + return set() + + return {sd.address for sd in iter(selector(self._new_selection()))} + + def get_secondaries(self) -> set[_Address]: + """Return set of secondary addresses.""" + return self._get_replica_set_members(secondary_server_selector) + + def get_arbiters(self) -> set[_Address]: + """Return set of arbiter addresses.""" + return self._get_replica_set_members(arbiter_server_selector) + + def max_cluster_time(self) -> Optional[ClusterTime]: + """Return a document, the highest seen $clusterTime.""" + return self._max_cluster_time + + def _receive_cluster_time_no_lock(self, cluster_time: Optional[Mapping[str, Any]]) -> None: + # Driver Sessions Spec: "Whenever a driver receives a cluster time from + # a server it MUST compare it to the current highest seen cluster time + # for the deployment. If the new cluster time is higher than the + # highest seen cluster time it MUST become the new highest seen cluster + # time. Two cluster times are compared using only the BsonTimestamp + # value of the clusterTime embedded field." + if cluster_time: + # ">" uses bson.timestamp.Timestamp's comparison operator. + if ( + not self._max_cluster_time + or cluster_time["clusterTime"] > self._max_cluster_time["clusterTime"] + ): + self._max_cluster_time = cluster_time + + def receive_cluster_time(self, cluster_time: Optional[Mapping[str, Any]]) -> None: + with self._lock: + self._receive_cluster_time_no_lock(cluster_time) + + def request_check_all(self, wait_time: int = 5) -> None: + """Wake all monitors, wait for at least one to check its server.""" + with self._lock: + self._request_check_all() + self._condition.wait(wait_time) + + def data_bearing_servers(self) -> list[ServerDescription]: + """Return a list of all data-bearing servers. + + This includes any server that might be selected for an operation. + """ + if self._description.topology_type == TOPOLOGY_TYPE.Single: + return self._description.known_servers + return self._description.readable_servers + + def update_pool(self) -> None: + # Remove any stale sockets and add new sockets if pool is too small. + servers = [] + with self._lock: + # Only update pools for data-bearing servers. 
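+            # Collect (server, pool generation) pairs while holding the
+            # lock; the potentially slow socket maintenance below runs
+            # after the lock has been released.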
+ for sd in self.data_bearing_servers(): + server = self._servers[sd.address] + servers.append((server, server.pool.gen.get_overall())) + + for server, generation in servers: + try: + server.pool.remove_stale_sockets(generation) + except PyMongoError as exc: + ctx = _ErrorContext(exc, 0, generation, False, None) + self.handle_error(server.description.address, ctx) + raise + + def close(self) -> None: + """Clear pools and terminate monitors. Topology does not reopen on + demand. Any further operations will raise + :exc:`~.errors.InvalidOperation`. + """ + with self._lock: + for server in self._servers.values(): + server.close() + + # Mark all servers Unknown. + self._description = self._description.reset() + for address, sd in self._description.server_descriptions().items(): + if address in self._servers: + self._servers[address].description = sd + + # Stop SRV polling thread. + if self._srv_monitor: + self._srv_monitor.close() + + self._opened = False + self._closed = True + + # Publish only after releasing the lock. + if self._publish_tp: + assert self._events is not None + self._events.put((self._listeners.publish_topology_closed, (self._topology_id,))) + if self._publish_server or self._publish_tp: + self.__events_executor.close() + + @property + def description(self) -> TopologyDescription: + return self._description + + def pop_all_sessions(self) -> list[_ServerSession]: + """Pop all session ids from the pool.""" + with self._lock: + return self._session_pool.pop_all() + + def get_server_session(self, session_timeout_minutes: Optional[int]) -> _ServerSession: + """Start or resume a server session, or raise ConfigurationError.""" + with self._lock: + return self._session_pool.get_server_session(session_timeout_minutes) + + def return_server_session(self, server_session: _ServerSession, lock: bool) -> None: + if lock: + with self._lock: + self._session_pool.return_server_session( + server_session, self._description.logical_session_timeout_minutes + ) + else: + # Called from a __del__ method, can't use a lock. + self._session_pool.return_server_session_no_lock(server_session) + + def _new_selection(self) -> Selection: + """A Selection object, initially including all known servers. + + Hold the lock when calling this. + """ + return Selection.from_topology_description(self._description) + + def _ensure_opened(self) -> None: + """Start monitors, or restart after a fork. + + Hold the lock when calling this. + """ + if self._closed: + raise InvalidOperation("Cannot use MongoClient after close") + + if not self._opened: + self._opened = True + self._update_servers() + + # Start or restart the events publishing thread. + if self._publish_tp or self._publish_server: + self.__events_executor.open() + + # Start the SRV polling thread. + if self._srv_monitor and (self.description.topology_type in SRV_POLLING_TOPOLOGIES): + self._srv_monitor.open() + + if self._settings.load_balanced: + # Emit initial SDAM events for load balancer mode. + self._process_change( + ServerDescription( + self._seed_addresses[0], + Hello({"ok": 1, "serviceId": self._topology_id, "maxWireVersion": 13}), + ) + ) + + # Ensure that the monitors are open. + for server in self._servers.values(): + server.open() + + def _is_stale_error(self, address: _Address, err_ctx: _ErrorContext) -> bool: + server = self._servers.get(address) + if server is None: + # Another thread removed this server from the topology. 
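+            # With no Server object left there is nothing to reset, so
+            # treat the error as stale and let the caller discard it.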
+ return True + + if server._pool.stale_generation(err_ctx.sock_generation, err_ctx.service_id): + # This is an outdated error from a previous pool version. + return True + + # topologyVersion check, ignore error when cur_tv >= error_tv: + cur_tv = server.description.topology_version + error = err_ctx.error + error_tv = None + if error and hasattr(error, "details"): + if isinstance(error.details, dict): + error_tv = error.details.get("topologyVersion") + + return _is_stale_error_topology_version(cur_tv, error_tv) + + def _handle_error(self, address: _Address, err_ctx: _ErrorContext) -> None: + if self._is_stale_error(address, err_ctx): + return + + server = self._servers[address] + error = err_ctx.error + service_id = err_ctx.service_id + + # Ignore a handshake error if the server is behind a load balancer but + # the service ID is unknown. This indicates that the error happened + # when dialing the connection or during the MongoDB handshake, so we + # don't know the service ID to use for clearing the pool. + if self._settings.load_balanced and not service_id and not err_ctx.completed_handshake: + return + + if isinstance(error, NetworkTimeout) and err_ctx.completed_handshake: + # The socket has been closed. Don't reset the server. + # Server Discovery And Monitoring Spec: "When an application + # operation fails because of any network error besides a socket + # timeout...." + return + elif isinstance(error, WriteError): + # Ignore writeErrors. + return + elif isinstance(error, (NotPrimaryError, OperationFailure)): + # As per the SDAM spec if: + # - the server sees a "not primary" error, and + # - the server is not shutting down, and + # - the server version is >= 4.2, then + # we keep the existing connection pool, but mark the server type + # as Unknown and request an immediate check of the server. + # Otherwise, we clear the connection pool, mark the server as + # Unknown and request an immediate check of the server. + if hasattr(error, "code"): + err_code = error.code + else: + # Default error code if one does not exist. + default = 10107 if isinstance(error, NotPrimaryError) else None + err_code = error.details.get("code", default) # type: ignore[union-attr] + if err_code in helpers._NOT_PRIMARY_CODES: + is_shutting_down = err_code in helpers._SHUTDOWN_CODES + # Mark server Unknown, clear the pool, and request check. + if not self._settings.load_balanced: + self._process_change(ServerDescription(address, error=error)) + if is_shutting_down or (err_ctx.max_wire_version <= 7): + # Clear the pool. + server.reset(service_id) + server.request_check() + elif not err_ctx.completed_handshake: + # Unknown command error during the connection handshake. + if not self._settings.load_balanced: + self._process_change(ServerDescription(address, error=error)) + # Clear the pool. + server.reset(service_id) + elif isinstance(error, ConnectionFailure): + # "Client MUST replace the server's description with type Unknown + # ... MUST NOT request an immediate check of the server." + if not self._settings.load_balanced: + self._process_change(ServerDescription(address, error=error)) + # Clear the pool. + server.reset(service_id) + # "When a client marks a server Unknown from `Network error when + # reading or writing`_, clients MUST cancel the hello check on + # that server and close the current monitoring connection." + server._monitor.cancel_check() + + def handle_error(self, address: _Address, err_ctx: _ErrorContext) -> None: + """Handle an application error. 
+ + May reset the server to Unknown, clear the pool, and request an + immediate check depending on the error and the context. + """ + with self._lock: + self._handle_error(address, err_ctx) + + def _request_check_all(self) -> None: + """Wake all monitors. Hold the lock when calling this.""" + for server in self._servers.values(): + server.request_check() + + def _update_servers(self) -> None: + """Sync our Servers from TopologyDescription.server_descriptions. + + Hold the lock while calling this. + """ + for address, sd in self._description.server_descriptions().items(): + if address not in self._servers: + monitor = self._settings.monitor_class( + server_description=sd, + topology=self, + pool=self._create_pool_for_monitor(address), + topology_settings=self._settings, + ) + + weak = None + if self._publish_server and self._events is not None: + weak = weakref.ref(self._events) + server = Server( + server_description=sd, + pool=self._create_pool_for_server(address), + monitor=monitor, + topology_id=self._topology_id, + listeners=self._listeners, + events=weak, + ) + + self._servers[address] = server + server.open() + else: + # Cache old is_writable value. + was_writable = self._servers[address].description.is_writable + # Update server description. + self._servers[address].description = sd + # Update is_writable value of the pool, if it changed. + if was_writable != sd.is_writable: + self._servers[address].pool.update_is_writable(sd.is_writable) + + for address, server in list(self._servers.items()): + if not self._description.has_server(address): + server.close() + self._servers.pop(address) + + def _create_pool_for_server(self, address: _Address) -> Pool: + return self._settings.pool_class( + address, self._settings.pool_options, client_id=self._topology_id + ) + + def _create_pool_for_monitor(self, address: _Address) -> Pool: + options = self._settings.pool_options + + # According to the Server Discovery And Monitoring Spec, monitors use + # connect_timeout for both connect_timeout and socket_timeout. The + # pool only has one socket so maxPoolSize and so on aren't needed. + monitor_pool_options = PoolOptions( + connect_timeout=options.connect_timeout, + socket_timeout=options.connect_timeout, + ssl_context=options._ssl_context, + tls_allow_invalid_hostnames=options.tls_allow_invalid_hostnames, + event_listeners=options._event_listeners, + appname=options.appname, + driver=options.driver, + pause_enabled=False, + server_api=options.server_api, + ) + + return self._settings.pool_class( + address, monitor_pool_options, handshake=False, client_id=self._topology_id + ) + + def _error_message(self, selector: Callable[[Selection], Selection]) -> str: + """Format an error message if server selection fails. + + Hold the lock when calling this. + """ + is_replica_set = self._description.topology_type in ( + TOPOLOGY_TYPE.ReplicaSetWithPrimary, + TOPOLOGY_TYPE.ReplicaSetNoPrimary, + ) + + if is_replica_set: + server_plural = "replica set members" + elif self._description.topology_type == TOPOLOGY_TYPE.Sharded: + server_plural = "mongoses" + else: + server_plural = "servers" + + if self._description.known_servers: + # We've connected, but no servers match the selector. 
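+            # Distinguish "no writable server" (typically: no primary) from
+            # a selector that matched nothing, so the message names the
+            # actual problem.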
+ if selector is writable_server_selector: + if is_replica_set: + return "No primary available for writes" + else: + return "No %s available for writes" % server_plural + else: + return f'No {server_plural} match selector "{selector}"' + else: + addresses = list(self._description.server_descriptions()) + servers = list(self._description.server_descriptions().values()) + if not servers: + if is_replica_set: + # We removed all servers because of the wrong setName? + return 'No {} available for replica set name "{}"'.format( + server_plural, + self._settings.replica_set_name, + ) + else: + return "No %s available" % server_plural + + # 1 or more servers, all Unknown. Are they unknown for one reason? + error = servers[0].error + same = all(server.error == error for server in servers[1:]) + if same: + if error is None: + # We're still discovering. + return "No %s found yet" % server_plural + + if is_replica_set and not set(addresses).intersection(self._seed_addresses): + # We replaced our seeds with new hosts but can't reach any. + return ( + "Could not reach any servers in %s. Replica set is" + " configured with internal hostnames or IPs?" % addresses + ) + + return str(error) + else: + return ",".join(str(server.error) for server in servers if server.error) + + def __repr__(self) -> str: + msg = "" + if not self._opened: + msg = "CLOSED " + return f"<{self.__class__.__name__} {msg}{self._description!r}>" + + def eq_props(self) -> tuple[tuple[_Address, ...], Optional[str], Optional[str], str]: + """The properties to use for MongoClient/Topology equality checks.""" + ts = self._settings + return (tuple(sorted(ts.seeds)), ts.replica_set_name, ts.fqdn, ts.srv_service_name) + + def __eq__(self, other: object) -> bool: + if isinstance(other, self.__class__): + return self.eq_props() == other.eq_props() + return NotImplemented + + def __hash__(self) -> int: + return hash(self.eq_props()) + + +class _ErrorContext: + """An error with context for SDAM error handling.""" + + def __init__( + self, + error: BaseException, + max_wire_version: int, + sock_generation: int, + completed_handshake: bool, + service_id: Optional[ObjectId], + ): + self.error = error + self.max_wire_version = max_wire_version + self.sock_generation = sock_generation + self.completed_handshake = completed_handshake + self.service_id = service_id + + +def _is_stale_error_topology_version( + current_tv: Optional[Mapping[str, Any]], error_tv: Optional[Mapping[str, Any]] +) -> bool: + """Return True if the error's topologyVersion is <= current.""" + if current_tv is None or error_tv is None: + return False + if current_tv["processId"] != error_tv["processId"]: + return False + return current_tv["counter"] >= error_tv["counter"] + + +def _is_stale_server_description(current_sd: ServerDescription, new_sd: ServerDescription) -> bool: + """Return True if the new topologyVersion is < current.""" + current_tv, new_tv = current_sd.topology_version, new_sd.topology_version + if current_tv is None or new_tv is None: + return False + if current_tv["processId"] != new_tv["processId"]: + return False + return current_tv["counter"] > new_tv["counter"] + + +def _filter_servers( + candidates: list[Server], deprioritized_servers: Optional[list[Server]] = None +) -> list[Server]: + """Filter out deprioritized servers from a list of server candidates.""" + if not deprioritized_servers: + return candidates + + filtered = [server for server in candidates if server not in deprioritized_servers] + + # If not possible to pick a prioritized server, return the 
original list + return filtered or candidates diff --git a/venv/Lib/site-packages/pymongo/topology_description.py b/venv/Lib/site-packages/pymongo/topology_description.py new file mode 100644 index 00000000..cc2330cb --- /dev/null +++ b/venv/Lib/site-packages/pymongo/topology_description.py @@ -0,0 +1,676 @@ +# Copyright 2014-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you +# may not use this file except in compliance with the License. You +# may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. + +"""Represent a deployment of MongoDB servers.""" +from __future__ import annotations + +from random import sample +from typing import ( + Any, + Callable, + List, + Mapping, + MutableMapping, + NamedTuple, + Optional, + cast, +) + +from bson.min_key import MinKey +from bson.objectid import ObjectId +from pymongo import common +from pymongo.errors import ConfigurationError +from pymongo.read_preferences import ReadPreference, _AggWritePref, _ServerMode +from pymongo.server_description import ServerDescription +from pymongo.server_selectors import Selection +from pymongo.server_type import SERVER_TYPE +from pymongo.typings import _Address + + +# Enumeration for various kinds of MongoDB cluster topologies. +class _TopologyType(NamedTuple): + Single: int + ReplicaSetNoPrimary: int + ReplicaSetWithPrimary: int + Sharded: int + Unknown: int + LoadBalanced: int + + +TOPOLOGY_TYPE = _TopologyType(*range(6)) + +# Topologies compatible with SRV record polling. +SRV_POLLING_TOPOLOGIES: tuple[int, int] = (TOPOLOGY_TYPE.Unknown, TOPOLOGY_TYPE.Sharded) + + +_ServerSelector = Callable[[List[ServerDescription]], List[ServerDescription]] + + +class TopologyDescription: + def __init__( + self, + topology_type: int, + server_descriptions: dict[_Address, ServerDescription], + replica_set_name: Optional[str], + max_set_version: Optional[int], + max_election_id: Optional[ObjectId], + topology_settings: Any, + ) -> None: + """Representation of a deployment of MongoDB servers. + + :param topology_type: initial type + :param server_descriptions: dict of (address, ServerDescription) for + all seeds + :param replica_set_name: replica set name or None + :param max_set_version: greatest setVersion seen from a primary, or None + :param max_election_id: greatest electionId seen from a primary, or None + :param topology_settings: a TopologySettings + """ + self._topology_type = topology_type + self._replica_set_name = replica_set_name + self._server_descriptions = server_descriptions + self._max_set_version = max_set_version + self._max_election_id = max_election_id + + # The heartbeat_frequency is used in staleness estimates. + self._topology_settings = topology_settings + + # Is PyMongo compatible with all servers' wire protocols? 
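+        # Load balanced mode is exempt: its ServerDescription is synthesized
+        # (see Topology._ensure_opened), so its wire versions do not reflect
+        # a real server.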
+ self._incompatible_err = None + if self._topology_type != TOPOLOGY_TYPE.LoadBalanced: + self._init_incompatible_err() + + # Server Discovery And Monitoring Spec: Whenever a client updates the + # TopologyDescription from an hello response, it MUST set + # TopologyDescription.logicalSessionTimeoutMinutes to the smallest + # logicalSessionTimeoutMinutes value among ServerDescriptions of all + # data-bearing server types. If any have a null + # logicalSessionTimeoutMinutes, then + # TopologyDescription.logicalSessionTimeoutMinutes MUST be set to null. + readable_servers = self.readable_servers + if not readable_servers: + self._ls_timeout_minutes = None + elif any(s.logical_session_timeout_minutes is None for s in readable_servers): + self._ls_timeout_minutes = None + else: + self._ls_timeout_minutes = min( # type: ignore[type-var] + s.logical_session_timeout_minutes for s in readable_servers + ) + + def _init_incompatible_err(self) -> None: + """Internal compatibility check for non-load balanced topologies.""" + for s in self._server_descriptions.values(): + if not s.is_server_type_known: + continue + + # s.min/max_wire_version is the server's wire protocol. + # MIN/MAX_SUPPORTED_WIRE_VERSION is what PyMongo supports. + server_too_new = ( + # Server too new. + s.min_wire_version is not None + and s.min_wire_version > common.MAX_SUPPORTED_WIRE_VERSION + ) + + server_too_old = ( + # Server too old. + s.max_wire_version is not None + and s.max_wire_version < common.MIN_SUPPORTED_WIRE_VERSION + ) + + if server_too_new: + self._incompatible_err = ( + "Server at %s:%d requires wire version %d, but this " # type: ignore + "version of PyMongo only supports up to %d." + % ( + s.address[0], + s.address[1] or 0, + s.min_wire_version, + common.MAX_SUPPORTED_WIRE_VERSION, + ) + ) + + elif server_too_old: + self._incompatible_err = ( + "Server at %s:%d reports wire version %d, but this " # type: ignore + "version of PyMongo requires at least %d (MongoDB %s)." + % ( + s.address[0], + s.address[1] or 0, + s.max_wire_version, + common.MIN_SUPPORTED_WIRE_VERSION, + common.MIN_SUPPORTED_SERVER_VERSION, + ) + ) + + break + + def check_compatible(self) -> None: + """Raise ConfigurationError if any server is incompatible. + + A server is incompatible if its wire protocol version range does not + overlap with PyMongo's. + """ + if self._incompatible_err: + raise ConfigurationError(self._incompatible_err) + + def has_server(self, address: _Address) -> bool: + return address in self._server_descriptions + + def reset_server(self, address: _Address) -> TopologyDescription: + """A copy of this description, with one server marked Unknown.""" + unknown_sd = self._server_descriptions[address].to_unknown() + return updated_topology_description(self, unknown_sd) + + def reset(self) -> TopologyDescription: + """A copy of this description, with all servers marked Unknown.""" + if self._topology_type == TOPOLOGY_TYPE.ReplicaSetWithPrimary: + topology_type = TOPOLOGY_TYPE.ReplicaSetNoPrimary + else: + topology_type = self._topology_type + + # The default ServerDescription's type is Unknown. + sds = {address: ServerDescription(address) for address in self._server_descriptions} + + return TopologyDescription( + topology_type, + sds, + self._replica_set_name, + self._max_set_version, + self._max_election_id, + self._topology_settings, + ) + + def server_descriptions(self) -> dict[_Address, ServerDescription]: + """dict of (address, + :class:`~pymongo.server_description.ServerDescription`). 
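+
+        The returned dict is a shallow copy; mutating it does not affect
+        this TopologyDescription.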
+ """ + return self._server_descriptions.copy() + + @property + def topology_type(self) -> int: + """The type of this topology.""" + return self._topology_type + + @property + def topology_type_name(self) -> str: + """The topology type as a human readable string. + + .. versionadded:: 3.4 + """ + return TOPOLOGY_TYPE._fields[self._topology_type] + + @property + def replica_set_name(self) -> Optional[str]: + """The replica set name.""" + return self._replica_set_name + + @property + def max_set_version(self) -> Optional[int]: + """Greatest setVersion seen from a primary, or None.""" + return self._max_set_version + + @property + def max_election_id(self) -> Optional[ObjectId]: + """Greatest electionId seen from a primary, or None.""" + return self._max_election_id + + @property + def logical_session_timeout_minutes(self) -> Optional[int]: + """Minimum logical session timeout, or None.""" + return self._ls_timeout_minutes + + @property + def known_servers(self) -> list[ServerDescription]: + """List of Servers of types besides Unknown.""" + return [s for s in self._server_descriptions.values() if s.is_server_type_known] + + @property + def has_known_servers(self) -> bool: + """Whether there are any Servers of types besides Unknown.""" + return any(s for s in self._server_descriptions.values() if s.is_server_type_known) + + @property + def readable_servers(self) -> list[ServerDescription]: + """List of readable Servers.""" + return [s for s in self._server_descriptions.values() if s.is_readable] + + @property + def common_wire_version(self) -> Optional[int]: + """Minimum of all servers' max wire versions, or None.""" + servers = self.known_servers + if servers: + return min(s.max_wire_version for s in self.known_servers) + + return None + + @property + def heartbeat_frequency(self) -> int: + return self._topology_settings.heartbeat_frequency + + @property + def srv_max_hosts(self) -> int: + return self._topology_settings._srv_max_hosts + + def _apply_local_threshold(self, selection: Optional[Selection]) -> list[ServerDescription]: + if not selection: + return [] + round_trip_times: list[float] = [] + for server in selection.server_descriptions: + if server.round_trip_time is None: + config_err_msg = f"round_trip_time for server {server.address} is unexpectedly None: {self}, servers: {selection.server_descriptions}" + raise ConfigurationError(config_err_msg) + round_trip_times.append(server.round_trip_time) + # Round trip time in seconds. + fastest = min(round_trip_times) + threshold = self._topology_settings.local_threshold_ms / 1000.0 + return [ + s + for s in selection.server_descriptions + if (cast(float, s.round_trip_time) - fastest) <= threshold + ] + + def apply_selector( + self, + selector: Any, + address: Optional[_Address] = None, + custom_selector: Optional[_ServerSelector] = None, + ) -> list[ServerDescription]: + """List of servers matching the provided selector(s). + + :param selector: a callable that takes a Selection as input and returns + a Selection as output. For example, an instance of a read + preference from :mod:`~pymongo.read_preferences`. + :param address: A server address to select. + :param custom_selector: A callable that augments server + selection rules. Accepts a list of + :class:`~pymongo.server_description.ServerDescription` objects and + return a list of server descriptions that should be considered + suitable for the desired operation. + + .. 
versionadded:: 3.4 + """ + if getattr(selector, "min_wire_version", 0): + common_wv = self.common_wire_version + if common_wv and common_wv < selector.min_wire_version: + raise ConfigurationError( + "%s requires min wire version %d, but topology's min" + " wire version is %d" % (selector, selector.min_wire_version, common_wv) + ) + + if isinstance(selector, _AggWritePref): + selector.selection_hook(self) + + if self.topology_type == TOPOLOGY_TYPE.Unknown: + return [] + elif self.topology_type in (TOPOLOGY_TYPE.Single, TOPOLOGY_TYPE.LoadBalanced): + # Ignore selectors for standalone and load balancer mode. + return self.known_servers + if address: + # Ignore selectors when explicit address is requested. + description = self.server_descriptions().get(address) + return [description] if description else [] + + selection = Selection.from_topology_description(self) + # Ignore read preference for sharded clusters. + if self.topology_type != TOPOLOGY_TYPE.Sharded: + selection = selector(selection) + + # Apply custom selector followed by localThresholdMS. + if custom_selector is not None and selection: + selection = selection.with_server_descriptions( + custom_selector(selection.server_descriptions) + ) + return self._apply_local_threshold(selection) + + def has_readable_server(self, read_preference: _ServerMode = ReadPreference.PRIMARY) -> bool: + """Does this topology have any readable servers available matching the + given read preference? + + :param read_preference: an instance of a read preference from + :mod:`~pymongo.read_preferences`. Defaults to + :attr:`~pymongo.read_preferences.ReadPreference.PRIMARY`. + + .. note:: When connected directly to a single server this method + always returns ``True``. + + .. versionadded:: 3.4 + """ + common.validate_read_preference("read_preference", read_preference) + return any(self.apply_selector(read_preference)) + + def has_writable_server(self) -> bool: + """Does this topology have a writable server available? + + .. note:: When connected directly to a single server this method + always returns ``True``. + + .. versionadded:: 3.4 + """ + return self.has_readable_server(ReadPreference.PRIMARY) + + def __repr__(self) -> str: + # Sort the servers by address. + servers = sorted(self._server_descriptions.values(), key=lambda sd: sd.address) + return "<{} id: {}, topology_type: {}, servers: {!r}>".format( + self.__class__.__name__, + self._topology_settings._topology_id, + self.topology_type_name, + servers, + ) + + +# If topology type is Unknown and we receive a hello response, what should +# the new topology type be? +_SERVER_TYPE_TO_TOPOLOGY_TYPE = { + SERVER_TYPE.Mongos: TOPOLOGY_TYPE.Sharded, + SERVER_TYPE.RSPrimary: TOPOLOGY_TYPE.ReplicaSetWithPrimary, + SERVER_TYPE.RSSecondary: TOPOLOGY_TYPE.ReplicaSetNoPrimary, + SERVER_TYPE.RSArbiter: TOPOLOGY_TYPE.ReplicaSetNoPrimary, + SERVER_TYPE.RSOther: TOPOLOGY_TYPE.ReplicaSetNoPrimary, + # Note: SERVER_TYPE.LoadBalancer and Unknown are intentionally left out. +} + + +def updated_topology_description( + topology_description: TopologyDescription, server_description: ServerDescription +) -> TopologyDescription: + """Return an updated copy of a TopologyDescription. + + :param topology_description: the current TopologyDescription + :param server_description: a new ServerDescription that resulted from + a hello call + + Called after attempting (successfully or not) to call hello on the + server at server_description.address. Does not modify topology_description. 
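+
+    The topology type transitions implemented here follow the Server
+    Discovery And Monitoring spec.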
+ """ + address = server_description.address + + # These values will be updated, if necessary, to form the new + # TopologyDescription. + topology_type = topology_description.topology_type + set_name = topology_description.replica_set_name + max_set_version = topology_description.max_set_version + max_election_id = topology_description.max_election_id + server_type = server_description.server_type + + # Don't mutate the original dict of server descriptions; copy it. + sds = topology_description.server_descriptions() + + # Replace this server's description with the new one. + sds[address] = server_description + + if topology_type == TOPOLOGY_TYPE.Single: + # Set server type to Unknown if replica set name does not match. + if set_name is not None and set_name != server_description.replica_set_name: + error = ConfigurationError( + "client is configured to connect to a replica set named " + "'{}' but this node belongs to a set named '{}'".format( + set_name, server_description.replica_set_name + ) + ) + sds[address] = server_description.to_unknown(error=error) + # Single type never changes. + return TopologyDescription( + TOPOLOGY_TYPE.Single, + sds, + set_name, + max_set_version, + max_election_id, + topology_description._topology_settings, + ) + + if topology_type == TOPOLOGY_TYPE.Unknown: + if server_type in (SERVER_TYPE.Standalone, SERVER_TYPE.LoadBalancer): + if len(topology_description._topology_settings.seeds) == 1: + topology_type = TOPOLOGY_TYPE.Single + else: + # Remove standalone from Topology when given multiple seeds. + sds.pop(address) + elif server_type not in (SERVER_TYPE.Unknown, SERVER_TYPE.RSGhost): + topology_type = _SERVER_TYPE_TO_TOPOLOGY_TYPE[server_type] + + if topology_type == TOPOLOGY_TYPE.Sharded: + if server_type not in (SERVER_TYPE.Mongos, SERVER_TYPE.Unknown): + sds.pop(address) + + elif topology_type == TOPOLOGY_TYPE.ReplicaSetNoPrimary: + if server_type in (SERVER_TYPE.Standalone, SERVER_TYPE.Mongos): + sds.pop(address) + + elif server_type == SERVER_TYPE.RSPrimary: + (topology_type, set_name, max_set_version, max_election_id) = _update_rs_from_primary( + sds, set_name, server_description, max_set_version, max_election_id + ) + + elif server_type in (SERVER_TYPE.RSSecondary, SERVER_TYPE.RSArbiter, SERVER_TYPE.RSOther): + topology_type, set_name = _update_rs_no_primary_from_member( + sds, set_name, server_description + ) + + elif topology_type == TOPOLOGY_TYPE.ReplicaSetWithPrimary: + if server_type in (SERVER_TYPE.Standalone, SERVER_TYPE.Mongos): + sds.pop(address) + topology_type = _check_has_primary(sds) + + elif server_type == SERVER_TYPE.RSPrimary: + (topology_type, set_name, max_set_version, max_election_id) = _update_rs_from_primary( + sds, set_name, server_description, max_set_version, max_election_id + ) + + elif server_type in (SERVER_TYPE.RSSecondary, SERVER_TYPE.RSArbiter, SERVER_TYPE.RSOther): + topology_type = _update_rs_with_primary_from_member(sds, set_name, server_description) + + else: + # Server type is Unknown or RSGhost: did we just lose the primary? + topology_type = _check_has_primary(sds) + + # Return updated copy. + return TopologyDescription( + topology_type, + sds, + set_name, + max_set_version, + max_election_id, + topology_description._topology_settings, + ) + + +def _updated_topology_description_srv_polling( + topology_description: TopologyDescription, seedlist: list[tuple[str, Any]] +) -> TopologyDescription: + """Return an updated copy of a TopologyDescription. 
+
+    :param topology_description: the current TopologyDescription
+    :param seedlist: a list of new seeds obtained from polling SRV records
+    """
+    assert topology_description.topology_type in SRV_POLLING_TOPOLOGIES
+    # Create a copy of the server descriptions.
+    sds = topology_description.server_descriptions()
+
+    # If seeds haven't changed, don't do anything.
+    if set(sds.keys()) == set(seedlist):
+        return topology_description
+
+    # Remove SDs corresponding to servers no longer part of the SRV record.
+    for address in list(sds.keys()):
+        if address not in seedlist:
+            sds.pop(address)
+
+    if topology_description.srv_max_hosts != 0:
+        new_hosts = set(seedlist) - set(sds.keys())
+        n_to_add = topology_description.srv_max_hosts - len(sds)
+        if n_to_add > 0:
+            seedlist = sample(sorted(new_hosts), min(n_to_add, len(new_hosts)))
+        else:
+            seedlist = []
+    # Add SDs corresponding to servers recently added to the SRV record.
+    for address in seedlist:
+        if address not in sds:
+            sds[address] = ServerDescription(address)
+    return TopologyDescription(
+        topology_description.topology_type,
+        sds,
+        topology_description.replica_set_name,
+        topology_description.max_set_version,
+        topology_description.max_election_id,
+        topology_description._topology_settings,
+    )
+
+
+def _update_rs_from_primary(
+    sds: MutableMapping[_Address, ServerDescription],
+    replica_set_name: Optional[str],
+    server_description: ServerDescription,
+    max_set_version: Optional[int],
+    max_election_id: Optional[ObjectId],
+) -> tuple[int, Optional[str], Optional[int], Optional[ObjectId]]:
+    """Update topology description from a primary's hello response.
+
+    Pass in a dict of ServerDescriptions, current replica set name, the
+    ServerDescription we are processing, and the TopologyDescription's
+    max_set_version and max_election_id if any.
+
+    Returns (new topology type, new replica_set_name, new max_set_version,
+    new max_election_id).
+    """
+    if replica_set_name is None:
+        replica_set_name = server_description.replica_set_name
+
+    elif replica_set_name != server_description.replica_set_name:
+        # We found a primary but it doesn't have the replica_set_name
+        # provided by the user.
+        sds.pop(server_description.address)
+        return _check_has_primary(sds), replica_set_name, max_set_version, max_election_id
+
+    if server_description.max_wire_version is None or server_description.max_wire_version < 17:
+        new_election_tuple: tuple = (server_description.set_version, server_description.election_id)
+        max_election_tuple: tuple = (max_set_version, max_election_id)
+        if None not in new_election_tuple:
+            if None not in max_election_tuple and new_election_tuple < max_election_tuple:
+                # Stale primary, set to type Unknown.
+                sds[server_description.address] = server_description.to_unknown()
+                return _check_has_primary(sds), replica_set_name, max_set_version, max_election_id
+            max_election_id = server_description.election_id
+
+        if server_description.set_version is not None and (
+            max_set_version is None or server_description.set_version > max_set_version
+        ):
+            max_set_version = server_description.set_version
+    else:
+        new_election_tuple = server_description.election_id, server_description.set_version
+        max_election_tuple = max_election_id, max_set_version
+        new_election_safe = tuple(MinKey() if i is None else i for i in new_election_tuple)
+        max_election_safe = tuple(MinKey() if i is None else i for i in max_election_tuple)
+        if new_election_safe < max_election_safe:
+            # Stale primary, set to type Unknown.
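+            # Its (electionId, setVersion) pair is older than one already
+            # observed, so this hello response likely predates a newer
+            # election.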
+ sds[server_description.address] = server_description.to_unknown() + return _check_has_primary(sds), replica_set_name, max_set_version, max_election_id + else: + max_election_id = server_description.election_id + max_set_version = server_description.set_version + + # We've heard from the primary. Is it the same primary as before? + for server in sds.values(): + if ( + server.server_type is SERVER_TYPE.RSPrimary + and server.address != server_description.address + ): + # Reset old primary's type to Unknown. + sds[server.address] = server.to_unknown() + + # There can be only one prior primary. + break + + # Discover new hosts from this primary's response. + for new_address in server_description.all_hosts: + if new_address not in sds: + sds[new_address] = ServerDescription(new_address) + + # Remove hosts not in the response. + for addr in set(sds) - server_description.all_hosts: + sds.pop(addr) + + # If the host list differs from the seed list, we may not have a primary + # after all. + return (_check_has_primary(sds), replica_set_name, max_set_version, max_election_id) + + +def _update_rs_with_primary_from_member( + sds: MutableMapping[_Address, ServerDescription], + replica_set_name: Optional[str], + server_description: ServerDescription, +) -> int: + """RS with known primary. Process a response from a non-primary. + + Pass in a dict of ServerDescriptions, current replica set name, and the + ServerDescription we are processing. + + Returns new topology type. + """ + assert replica_set_name is not None + + if replica_set_name != server_description.replica_set_name: + sds.pop(server_description.address) + elif server_description.me and server_description.address != server_description.me: + sds.pop(server_description.address) + + # Had this member been the primary? + return _check_has_primary(sds) + + +def _update_rs_no_primary_from_member( + sds: MutableMapping[_Address, ServerDescription], + replica_set_name: Optional[str], + server_description: ServerDescription, +) -> tuple[int, Optional[str]]: + """RS without known primary. Update from a non-primary's response. + + Pass in a dict of ServerDescriptions, current replica set name, and the + ServerDescription we are processing. + + Returns (new topology type, new replica_set_name). + """ + topology_type = TOPOLOGY_TYPE.ReplicaSetNoPrimary + if replica_set_name is None: + replica_set_name = server_description.replica_set_name + + elif replica_set_name != server_description.replica_set_name: + sds.pop(server_description.address) + return topology_type, replica_set_name + + # This isn't the primary's response, so don't remove any servers + # it doesn't report. Only add new servers. + for address in server_description.all_hosts: + if address not in sds: + sds[address] = ServerDescription(address) + + if server_description.me and server_description.address != server_description.me: + sds.pop(server_description.address) + + return topology_type, replica_set_name + + +def _check_has_primary(sds: Mapping[_Address, ServerDescription]) -> int: + """Current topology type is ReplicaSetWithPrimary. Is primary still known? + + Pass in a dict of ServerDescriptions. + + Returns new topology type. 
+ """ + for s in sds.values(): + if s.server_type == SERVER_TYPE.RSPrimary: + return TOPOLOGY_TYPE.ReplicaSetWithPrimary + else: # noqa: PLW0120 + return TOPOLOGY_TYPE.ReplicaSetNoPrimary diff --git a/venv/Lib/site-packages/pymongo/typings.py b/venv/Lib/site-packages/pymongo/typings.py new file mode 100644 index 00000000..174a0e36 --- /dev/null +++ b/venv/Lib/site-packages/pymongo/typings.py @@ -0,0 +1,60 @@ +# Copyright 2022-Present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Type aliases used by PyMongo""" +from __future__ import annotations + +from typing import ( + TYPE_CHECKING, + Any, + Mapping, + Optional, + Sequence, + Tuple, + TypeVar, + Union, +) + +from bson.typings import _DocumentOut, _DocumentType, _DocumentTypeArg + +if TYPE_CHECKING: + from pymongo.collation import Collation + + +# Common Shared Types. +_Address = Tuple[str, Optional[int]] +_CollationIn = Union[Mapping[str, Any], "Collation"] +_Pipeline = Sequence[Mapping[str, Any]] +ClusterTime = Mapping[str, Any] + +_T = TypeVar("_T") + + +def strip_optional(elem: Optional[_T]) -> _T: + """This function is to allow us to cast all of the elements of an iterator from Optional[_T] to _T + while inside a list comprehension. + """ + assert elem is not None + return elem + + +__all__ = [ + "_DocumentOut", + "_DocumentType", + "_DocumentTypeArg", + "_Address", + "_CollationIn", + "_Pipeline", + "strip_optional", +] diff --git a/venv/Lib/site-packages/pymongo/uri_parser.py b/venv/Lib/site-packages/pymongo/uri_parser.py new file mode 100644 index 00000000..7f4ef57f --- /dev/null +++ b/venv/Lib/site-packages/pymongo/uri_parser.py @@ -0,0 +1,628 @@ +# Copyright 2011-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you +# may not use this file except in compliance with the License. You +# may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. 
+
+
+"""Tools to parse and validate a MongoDB URI."""
+from __future__ import annotations
+
+import re
+import sys
+import warnings
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Mapping,
+    MutableMapping,
+    Optional,
+    Sized,
+    Union,
+    cast,
+)
+from urllib.parse import unquote_plus
+
+from pymongo.client_options import _parse_ssl_options
+from pymongo.common import (
+    INTERNAL_URI_OPTION_NAME_MAP,
+    SRV_SERVICE_NAME,
+    URI_OPTIONS_DEPRECATION_MAP,
+    _CaseInsensitiveDictionary,
+    get_validated_options,
+)
+from pymongo.errors import ConfigurationError, InvalidURI
+from pymongo.srv_resolver import _HAVE_DNSPYTHON, _SrvResolver
+from pymongo.typings import _Address
+
+if TYPE_CHECKING:
+    from pymongo.pyopenssl_context import SSLContext
+
+SCHEME = "mongodb://"
+SCHEME_LEN = len(SCHEME)
+SRV_SCHEME = "mongodb+srv://"
+SRV_SCHEME_LEN = len(SRV_SCHEME)
+DEFAULT_PORT = 27017
+
+
+def _unquoted_percent(s: str) -> bool:
+    """Check for unescaped percent signs.
+
+    :param s: A string. `s` can have things like '%25', '%2525',
+        and '%E2%85%A8' but cannot have unquoted percent like '%foo'.
+    """
+    for i in range(len(s)):
+        if s[i] == "%":
+            sub = s[i : i + 3]
+            # If unquoting yields the same string this means there was an
+            # unquoted %.
+            if unquote_plus(sub) == sub:
+                return True
+    return False
+
+
+def parse_userinfo(userinfo: str) -> tuple[str, str]:
+    """Validates the format of user information in a MongoDB URI.
+    Reserved characters that are gen-delimiters (":", "/", "?", "#", "[",
+    "]", "@") as per RFC 3986 must be escaped.
+
+    Returns a 2-tuple containing the unescaped username followed
+    by the unescaped password.
+
+    :param userinfo: A string of the form <username>:<password>
+    """
+    if "@" in userinfo or userinfo.count(":") > 1 or _unquoted_percent(userinfo):
+        raise InvalidURI(
+            "Username and password must be escaped according to "
+            "RFC 3986, use urllib.parse.quote_plus"
+        )
+
+    user, _, passwd = userinfo.partition(":")
+    # No password is expected with GSSAPI authentication.
+    if not user:
+        raise InvalidURI("The empty string is not a valid username.")
+
+    return unquote_plus(user), unquote_plus(passwd)
+
+
+def parse_ipv6_literal_host(
+    entity: str, default_port: Optional[int]
+) -> tuple[str, Optional[Union[str, int]]]:
+    """Validates an IPv6 literal host:port string.
+
+    Returns a 2-tuple of IPv6 literal followed by port where
+    port is default_port if it wasn't specified in entity.
+
+    :param entity: A string that represents an IPv6 literal enclosed
+        in braces (e.g. '[::1]' or '[::1]:27017').
+    :param default_port: The port number to use when one wasn't
+        specified in entity.
+    """
+    if entity.find("]") == -1:
+        raise ValueError(
+            "an IPv6 address literal must be enclosed in '[' and ']' according to RFC 2732."
+        )
+    i = entity.find("]:")
+    if i == -1:
+        return entity[1:-1], default_port
+    return entity[1:i], entity[i + 2 :]
+
+
+def parse_host(entity: str, default_port: Optional[int] = DEFAULT_PORT) -> _Address:
+    """Validates a host string.
+
+    Returns a 2-tuple of host followed by port where port is default_port
+    if it wasn't specified in the string.
+
+    :param entity: A host or host:port string where host could be a
+        hostname or IP address.
+    :param default_port: The port number to use when one wasn't
+        specified in entity.
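+
+    For example, "example.com" parses to ("example.com", 27017) and
+    "[::1]:27018" parses to ("::1", 27018).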
+    """
+    host = entity
+    port: Optional[Union[str, int]] = default_port
+    if entity[0] == "[":
+        host, port = parse_ipv6_literal_host(entity, default_port)
+    elif entity.endswith(".sock"):
+        return entity, default_port
+    elif entity.find(":") != -1:
+        if entity.count(":") > 1:
+            raise ValueError(
+                "Reserved characters such as ':' must be "
+                "escaped according to RFC 2396. An IPv6 "
+                "address literal must be enclosed in '[' "
+                "and ']' according to RFC 2732."
+            )
+        host, port = host.split(":", 1)
+    if isinstance(port, str):
+        if not port.isdigit() or int(port) > 65535 or int(port) <= 0:
+            raise ValueError(f"Port must be an integer between 0 and 65535: {port!r}")
+        port = int(port)
+
+    # Normalize hostname to lowercase, since DNS is case-insensitive:
+    # http://tools.ietf.org/html/rfc4343
+    # This prevents useless rediscovery if "foo.com" is in the seed list but
+    # "FOO.com" is in the hello response.
+    return host.lower(), port
+
+
+# Options whose values are implicitly determined by tlsInsecure.
+_IMPLICIT_TLSINSECURE_OPTS = {
+    "tlsallowinvalidcertificates",
+    "tlsallowinvalidhostnames",
+    "tlsdisableocspendpointcheck",
+}
+
+
+def _parse_options(opts: str, delim: Optional[str]) -> _CaseInsensitiveDictionary:
+    """Helper method for split_options which creates the options dict.
+    Also handles the creation of a list for the URI tag_sets/
+    readpreferencetags portion, and the use of a unicode options string.
+    """
+    options = _CaseInsensitiveDictionary()
+    for uriopt in opts.split(delim):
+        key, value = uriopt.split("=")
+        if key.lower() == "readpreferencetags":
+            options.setdefault(key, []).append(value)
+        else:
+            if key in options:
+                warnings.warn(f"Duplicate URI option '{key}'.", stacklevel=2)
+            if key.lower() == "authmechanismproperties":
+                val = value
+            else:
+                val = unquote_plus(value)
+            options[key] = val
+
+    return options
+
+
+def _handle_security_options(options: _CaseInsensitiveDictionary) -> _CaseInsensitiveDictionary:
+    """Raise appropriate errors when conflicting TLS options are present in
+    the options dictionary.
+
+    :param options: Instance of _CaseInsensitiveDictionary containing
+        MongoDB URI options.
+    """
+    # Implicitly defined options must not be explicitly specified.
+    tlsinsecure = options.get("tlsinsecure")
+    if tlsinsecure is not None:
+        for opt in _IMPLICIT_TLSINSECURE_OPTS:
+            if opt in options:
+                err_msg = "URI options %s and %s cannot be specified simultaneously."
+                raise InvalidURI(
+                    err_msg % (options.cased_key("tlsinsecure"), options.cased_key(opt))
+                )
+
+    # Handle co-occurrence of OCSP & tlsAllowInvalidCertificates options.
+    tlsallowinvalidcerts = options.get("tlsallowinvalidcertificates")
+    if tlsallowinvalidcerts is not None:
+        if "tlsdisableocspendpointcheck" in options:
+            err_msg = "URI options %s and %s cannot be specified simultaneously."
+            raise InvalidURI(
+                err_msg
+                % ("tlsallowinvalidcertificates", options.cased_key("tlsdisableocspendpointcheck"))
+            )
+        if tlsallowinvalidcerts is True:
+            options["tlsdisableocspendpointcheck"] = True
+
+    # Handle co-occurrence of CRL and OCSP-related options.
+    tlscrlfile = options.get("tlscrlfile")
+    if tlscrlfile is not None:
+        for opt in ("tlsinsecure", "tlsallowinvalidcertificates", "tlsdisableocspendpointcheck"):
+            if options.get(opt) is True:
+                err_msg = "URI option %s=True cannot be specified when CRL checking is enabled."
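+                # e.g. tlsCRLFile combined with tlsInsecure=true would both
+                # require and disable certificate revocation checking.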
+                raise InvalidURI(err_msg % (opt,))
+
+    if "ssl" in options and "tls" in options:
+
+        def truth_value(val: Any) -> Any:
+            # Normalize the strings "true"/"false" to booleans so that, e.g.,
+            # ssl=true and tls=True compare as equal; any other value
+            # (including an actual boolean) passes through unchanged.
+            if val in ("true", "false"):
+                return val == "true"
+            return val
+
+        if truth_value(options.get("ssl")) != truth_value(options.get("tls")):
+            err_msg = "Can not specify conflicting values for URI options %s and %s."
+            raise InvalidURI(err_msg % (options.cased_key("ssl"), options.cased_key("tls")))
+
+    return options
+
+
+def _handle_option_deprecations(options: _CaseInsensitiveDictionary) -> _CaseInsensitiveDictionary:
+    """Issue appropriate warnings when deprecated options are present in the
+    options dictionary. Removes deprecated option key, value pairs if the
+    options dictionary is found to also have the renamed option.
+
+    :param options: Instance of _CaseInsensitiveDictionary containing
+        MongoDB URI options.
+    """
+    for optname in list(options):
+        if optname in URI_OPTIONS_DEPRECATION_MAP:
+            mode, message = URI_OPTIONS_DEPRECATION_MAP[optname]
+            if mode == "renamed":
+                newoptname = message
+                if newoptname in options:
+                    warn_msg = "Deprecated option '%s' ignored in favor of '%s'."
+                    warnings.warn(
+                        warn_msg % (options.cased_key(optname), options.cased_key(newoptname)),
+                        DeprecationWarning,
+                        stacklevel=2,
+                    )
+                    options.pop(optname)
+                    continue
+                warn_msg = "Option '%s' is deprecated, use '%s' instead."
+                warnings.warn(
+                    warn_msg % (options.cased_key(optname), newoptname),
+                    DeprecationWarning,
+                    stacklevel=2,
+                )
+            elif mode == "removed":
+                warn_msg = "Option '%s' is deprecated. %s."
+                warnings.warn(
+                    warn_msg % (options.cased_key(optname), message),
+                    DeprecationWarning,
+                    stacklevel=2,
+                )
+
+    return options
+
+
+def _normalize_options(options: _CaseInsensitiveDictionary) -> _CaseInsensitiveDictionary:
+    """Normalizes option names in the options dictionary by converting them to
+    their internally-used names.
+
+    :param options: Instance of _CaseInsensitiveDictionary containing
+        MongoDB URI options.
+    """
+    # Expand the tlsInsecure option.
+    tlsinsecure = options.get("tlsinsecure")
+    if tlsinsecure is not None:
+        for opt in _IMPLICIT_TLSINSECURE_OPTS:
+            # Implicit options are logically the same as tlsInsecure.
+            options[opt] = tlsinsecure
+
+    for optname in list(options):
+        intname = INTERNAL_URI_OPTION_NAME_MAP.get(optname, None)
+        if intname is not None:
+            options[intname] = options.pop(optname)
+
+    return options
+
+
+def validate_options(opts: Mapping[str, Any], warn: bool = False) -> MutableMapping[str, Any]:
+    """Validates and normalizes options passed in a MongoDB URI.
+
+    Returns a new dictionary of validated and normalized options. If warn is
+    False then errors will be thrown for invalid options, otherwise they will
+    be ignored and a warning will be issued.
+
+    :param opts: A dict of MongoDB URI options.
+    :param warn: If ``True`` then warnings will be logged and
+        invalid options will be ignored. Otherwise invalid options will
+        cause errors.
+    """
+    return get_validated_options(opts, warn)
+
+
+def split_options(
+    opts: str, validate: bool = True, warn: bool = False, normalize: bool = True
+) -> MutableMapping[str, Any]:
+    """Takes the options portion of a MongoDB URI, validates each option
+    and returns the options in a dictionary.
+
+    :param opts: A string representing MongoDB URI options.
+    :param validate: If ``True`` (the default), validate and normalize all
+        options.
+    :param warn: If ``True``, warn about invalid options and ignore them
+        instead of raising an error (the default is ``False``, which raises).
+    :param normalize: If ``True`` (the default), renames all options to their
+        internally-used names.
+    """
+    and_idx = opts.find("&")
+    semi_idx = opts.find(";")
+    try:
+        if and_idx >= 0 and semi_idx >= 0:
+            raise InvalidURI("Can not mix '&' and ';' for option separators.")
+        elif and_idx >= 0:
+            options = _parse_options(opts, "&")
+        elif semi_idx >= 0:
+            options = _parse_options(opts, ";")
+        elif opts.find("=") != -1:
+            options = _parse_options(opts, None)
+        else:
+            raise ValueError
+    except ValueError:
+        raise InvalidURI("MongoDB URI options are key=value pairs.") from None
+
+    options = _handle_security_options(options)
+
+    options = _handle_option_deprecations(options)
+
+    if normalize:
+        options = _normalize_options(options)
+
+    if validate:
+        options = cast(_CaseInsensitiveDictionary, validate_options(options, warn))
+        if options.get("authsource") == "":
+            raise InvalidURI("the authSource database cannot be an empty string")
+
+    return options
+
+
+def split_hosts(hosts: str, default_port: Optional[int] = DEFAULT_PORT) -> list[_Address]:
+    """Takes a string of the form host1[:port],host2[:port]... and
+    splits it into (host, port) tuples. If [:port] isn't present the
+    default_port is used.
+
+    Returns a list of 2-tuples containing the host name (or IP) followed by
+    port number.
+
+    :param hosts: A string of the form host1[:port],host2[:port],...
+    :param default_port: The port number to use when one wasn't specified
+        for a host.
+    """
+    nodes = []
+    for entity in hosts.split(","):
+        if not entity:
+            raise ConfigurationError("Empty host (or extra comma in host list).")
+        port = default_port
+        # Unix socket entities don't have ports
+        if entity.endswith(".sock"):
+            port = None
+        nodes.append(parse_host(entity, port))
+    return nodes
+
+
+# Prohibited characters in database name. DB names also can't have ".", but for
+# backward-compat we allow "db.collection" in URI.
+_BAD_DB_CHARS = re.compile("[" + re.escape(r'/ "$') + "]")
+
+_ALLOWED_TXT_OPTS = frozenset(
+    ["authsource", "authSource", "replicaset", "replicaSet", "loadbalanced", "loadBalanced"]
+)
+
+
+def _check_options(nodes: Sized, options: Mapping[str, Any]) -> None:
+    # Ensure directConnection was not True if there are multiple seeds.
+    if len(nodes) > 1 and options.get("directconnection"):
+        raise ConfigurationError("Cannot specify multiple hosts with directConnection=true")
+
+    if options.get("loadbalanced"):
+        if len(nodes) > 1:
+            raise ConfigurationError("Cannot specify multiple hosts with loadBalanced=true")
+        if options.get("directconnection"):
+            raise ConfigurationError("Cannot specify directConnection=true with loadBalanced=true")
+        if options.get("replicaset"):
+            raise ConfigurationError("Cannot specify replicaSet with loadBalanced=true")
+
+
+def parse_uri(
+    uri: str,
+    default_port: Optional[int] = DEFAULT_PORT,
+    validate: bool = True,
+    warn: bool = False,
+    normalize: bool = True,
+    connect_timeout: Optional[float] = None,
+    srv_service_name: Optional[str] = None,
+    srv_max_hosts: Optional[int] = None,
+) -> dict[str, Any]:
+    """Parse and validate a MongoDB URI.
+
+    Returns a dict of the form::
+
+        {
+            'nodelist': <list of (host, port) tuples>,
+            'username': <username> or None,
+            'password': <password> or None,
+            'database': <database name> or None,
+            'collection': <collection name> or None,
+            'options': <dict of MongoDB URI options>,
+            'fqdn': <fqdn of the MongoDB+SRV URI> or None
+        }
+
+    If the URI scheme is "mongodb+srv://" DNS SRV and TXT lookups will be done
+    to build nodelist and options.
+
+    :param uri: The MongoDB URI to parse.
+    :param default_port: The port number to use when one wasn't specified
+        for a host in the URI.
+    :param validate: If ``True`` (the default), validate and normalize all
+        options.
+    :param warn: When validating, if ``True``, warn the user and then ignore
+        any invalid options or values. If ``False``, validation will error
+        when options are unsupported or values are invalid.
+        Default: ``False``.
+    :param normalize: If ``True``, convert names of URI options
+        to their internally-used names. Default: ``True``.
+    :param connect_timeout: The maximum time in milliseconds to
+        wait for a response from the DNS server.
+    :param srv_service_name: A custom SRV service name.
+
+    .. versionchanged:: 4.6
+        The delimiting slash (``/``) between hosts and connection options is
+        now optional. For example, "mongodb://example.com?tls=true" is now a
+        valid URI.
+
+    .. versionchanged:: 4.0
+        To better follow RFC 3986, unquoted percent signs ("%") are no longer
+        supported.
+
+    .. versionchanged:: 3.9
+        Added the ``normalize`` parameter.
+
+    .. versionchanged:: 3.6
+        Added support for mongodb+srv:// URIs.
+
+    .. versionchanged:: 3.5
+        Return the original value of the ``readPreference`` MongoDB URI option
+        instead of the validated read preference mode.
+
+    .. versionchanged:: 3.1
+        ``warn`` added so invalid options can be ignored.
+    """
+    if uri.startswith(SCHEME):
+        is_srv = False
+        scheme_free = uri[SCHEME_LEN:]
+    elif uri.startswith(SRV_SCHEME):
+        if not _HAVE_DNSPYTHON:
+            python_path = sys.executable or "python"
+            raise ConfigurationError(
+                'The "dnspython" module must be '
+                "installed to use mongodb+srv:// URIs. "
+                "To fix this error install pymongo again:\n "
+                "%s -m pip install pymongo>=4.3" % (python_path)
+            )
+        is_srv = True
+        scheme_free = uri[SRV_SCHEME_LEN:]
+    else:
+        raise InvalidURI(f"Invalid URI scheme: URI must begin with '{SCHEME}' or '{SRV_SCHEME}'")
+
+    if not scheme_free:
+        raise InvalidURI("Must provide at least one hostname or IP.")
+
+    user = None
+    passwd = None
+    dbase = None
+    collection = None
+    options = _CaseInsensitiveDictionary()
+
+    host_part, _, path_part = scheme_free.partition("/")
+    if not host_part:
+        host_part = path_part
+        path_part = ""
+
+    if path_part:
+        dbase, _, opts = path_part.partition("?")
+    else:
+        # There was no slash in scheme_free, check for a sole "?".
+        host_part, _, opts = host_part.partition("?")
+
+    if dbase:
+        dbase = unquote_plus(dbase)
+        if "." in dbase:
+            dbase, collection = dbase.split(".", 1)
+        if _BAD_DB_CHARS.search(dbase):
+            raise InvalidURI('Bad database name "%s"' % dbase)
+    else:
+        dbase = None
+
+    if opts:
+        options.update(split_options(opts, validate, warn, normalize))
+    if srv_service_name is None:
+        srv_service_name = options.get("srvServiceName", SRV_SERVICE_NAME)
+    if "@" in host_part:
+        userinfo, _, hosts = host_part.rpartition("@")
+        user, passwd = parse_userinfo(userinfo)
+    else:
+        hosts = host_part
+
+    if "/" in hosts:
+        raise InvalidURI("Any '/' in a unix domain socket must be percent-encoded: %s" % host_part)
+
+    hosts = unquote_plus(hosts)
+    fqdn = None
+    srv_max_hosts = srv_max_hosts or options.get("srvMaxHosts")
+    if is_srv:
+        if options.get("directConnection"):
+            raise ConfigurationError(f"Cannot specify directConnection=true with {SRV_SCHEME} URIs")
+        nodes = split_hosts(hosts, default_port=None)
+        if len(nodes) != 1:
+            raise InvalidURI(f"{SRV_SCHEME} URIs must include one, and only one, hostname")
+        fqdn, port = nodes[0]
+        if port is not None:
+            raise InvalidURI(f"{SRV_SCHEME} URIs must not include a port number")
+
+        # Use the connection timeout. connectTimeoutMS passed as a keyword
+        # argument overrides the same option passed in the connection string.
+        connect_timeout = connect_timeout or options.get("connectTimeoutMS")
+        dns_resolver = _SrvResolver(fqdn, connect_timeout, srv_service_name, srv_max_hosts)
+        nodes = dns_resolver.get_hosts()
+        dns_options = dns_resolver.get_options()
+        if dns_options:
+            parsed_dns_options = split_options(dns_options, validate, warn, normalize)
+            if set(parsed_dns_options) - _ALLOWED_TXT_OPTS:
+                raise ConfigurationError(
+                    "Only authSource, replicaSet, and loadBalanced are supported from DNS"
+                )
+            for opt, val in parsed_dns_options.items():
+                if opt not in options:
+                    options[opt] = val
+        if options.get("loadBalanced") and srv_max_hosts:
+            raise InvalidURI("You cannot specify loadBalanced with srvMaxHosts")
+        if options.get("replicaSet") and srv_max_hosts:
+            raise InvalidURI("You cannot specify replicaSet with srvMaxHosts")
+        if "tls" not in options and "ssl" not in options:
+            options["tls"] = True if validate else "true"
+    elif not is_srv and options.get("srvServiceName") is not None:
+        raise ConfigurationError(
+            "The srvServiceName option is only allowed with 'mongodb+srv://' URIs"
+        )
+    elif not is_srv and srv_max_hosts:
+        raise ConfigurationError(
+            "The srvMaxHosts option is only allowed with 'mongodb+srv://' URIs"
+        )
+    else:
+        nodes = split_hosts(hosts, default_port=default_port)
+
+    _check_options(nodes, options)
+
+    return {
+        "nodelist": nodes,
+        "username": user,
+        "password": passwd,
+        "database": dbase,
+        "collection": collection,
+        "options": options,
+        "fqdn": fqdn,
+    }
+
+
+def _parse_kms_tls_options(kms_tls_options: Optional[Mapping[str, Any]]) -> dict[str, SSLContext]:
+    """Parse KMS TLS connection options."""
+    if not kms_tls_options:
+        return {}
+    if not isinstance(kms_tls_options, dict):
+        raise TypeError("kms_tls_options must be a dict")
+    contexts = {}
+    for provider, options in kms_tls_options.items():
+        if not isinstance(options, dict):
+            raise TypeError(f'kms_tls_options["{provider}"] must be a dict')
+        options.setdefault("tls", True)
+        opts = _CaseInsensitiveDictionary(options)
+        opts = _handle_security_options(opts)
+        opts = _normalize_options(opts)
+        opts = cast(_CaseInsensitiveDictionary, validate_options(opts))
+        ssl_context, allow_invalid_hostnames = _parse_ssl_options(opts)
+        if ssl_context is None:
+            raise ConfigurationError("TLS is required for KMS providers")
+        if allow_invalid_hostnames:
+            raise ConfigurationError("Insecure TLS options prohibited")
+
+        for n in [
+            "tlsInsecure",
+            "tlsAllowInvalidCertificates",
+            "tlsAllowInvalidHostnames",
+            "tlsDisableCertificateRevocationCheck",
+        ]:
+            if n in opts:
+                raise ConfigurationError(f"Insecure TLS options prohibited: {n}")
+        contexts[provider] = ssl_context
+    return contexts
+
+
+if __name__ == "__main__":
+    import pprint
+
+    try:
+        pprint.pprint(parse_uri(sys.argv[1]))  # noqa: T203
+    except InvalidURI as exc:
+        print(exc)  # noqa: T201
+    sys.exit(0)
diff --git a/venv/Lib/site-packages/pymongo/write_concern.py b/venv/Lib/site-packages/pymongo/write_concern.py
new file mode 100644
index 00000000..591a126f
--- /dev/null
+++ b/venv/Lib/site-packages/pymongo/write_concern.py
@@ -0,0 +1,141 @@
+# Copyright 2014-present MongoDB, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tools for working with write concerns."""
+from __future__ import annotations
+
+from typing import Any, Optional, Union
+
+from pymongo.errors import ConfigurationError
+
+
+# Duplicated here to avoid a circular import.
+def validate_boolean(option: str, value: Any) -> bool:
+    """Validates that 'value' is True or False."""
+    if isinstance(value, bool):
+        return value
+    raise TypeError(f"{option} must be True or False, was: {option}={value}")
+
+
+class WriteConcern:
+    """WriteConcern
+
+    :param w: (integer or string) Used with replication, write operations
+        will block until they have been replicated to the specified number
+        or tagged set of servers. `w=<integer>` always includes the replica
+        set primary (e.g. w=3 means write to the primary and wait until
+        replicated to **two** secondaries). **w=0 disables acknowledgement
+        of write operations and cannot be used with other write concern
+        options.**
+    :param wtimeout: (integer) **DEPRECATED** Used in conjunction with `w`.
+        Specify a value in milliseconds to control how long to wait for write
+        propagation to complete. If replication does not complete in the given
+        timeframe, a timeout exception is raised.
+    :param j: If ``True`` block until write operations have been committed
+        to the journal. Cannot be used in combination with `fsync`. Write
+        operations will fail with an exception if this option is used when
+        the server is running without journaling.
+    :param fsync: If ``True`` and the server is running without journaling,
+        blocks until the server has synced all data files to disk. If the
+        server is running with journaling, this acts the same as the `j`
+        option, blocking until write operations have been committed to the
+        journal. Cannot be used in combination with `j`.
+
+    .. versionchanged:: 4.7
+        Deprecated parameter ``wtimeout``, use :meth:`~pymongo.timeout`.
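+
+    An illustrative example (added editorially, not part of the upstream
+    docstring): ``WriteConcern(w="majority", j=True)`` blocks until a
+    majority of data-bearing members have acknowledged the write and it has
+    been committed to the on-disk journal, while ``WriteConcern(w=0)``
+    issues unacknowledged writes.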
+    """
+
+    __slots__ = ("__document", "__acknowledged", "__server_default")
+
+    def __init__(
+        self,
+        w: Optional[Union[int, str]] = None,
+        wtimeout: Optional[int] = None,
+        j: Optional[bool] = None,
+        fsync: Optional[bool] = None,
+    ) -> None:
+        self.__document: dict[str, Any] = {}
+        self.__acknowledged = True
+
+        if wtimeout is not None:
+            if not isinstance(wtimeout, int):
+                raise TypeError("wtimeout must be an integer")
+            if wtimeout < 0:
+                raise ValueError("wtimeout cannot be less than 0")
+            self.__document["wtimeout"] = wtimeout
+
+        if j is not None:
+            validate_boolean("j", j)
+            self.__document["j"] = j
+
+        if fsync is not None:
+            validate_boolean("fsync", fsync)
+            if j and fsync:
+                raise ConfigurationError("Can't set both j and fsync at the same time")
+            self.__document["fsync"] = fsync
+
+        if w == 0 and j is True:
+            raise ConfigurationError("Cannot set w to 0 and j to True")
+
+        if w is not None:
+            if isinstance(w, int):
+                if w < 0:
+                    raise ValueError("w cannot be less than 0")
+                self.__acknowledged = w > 0
+            elif not isinstance(w, str):
+                raise TypeError("w must be an integer or string")
+            self.__document["w"] = w
+
+        # An empty document means this write concern defers to the server's
+        # default write concern.
+        self.__server_default = not self.__document
+
+    @property
+    def is_server_default(self) -> bool:
+        """Does this WriteConcern match the server default?"""
+        return self.__server_default
+
+    @property
+    def document(self) -> dict[str, Any]:
+        """The document representation of this write concern.
+
+        .. note::
+            :class:`WriteConcern` is immutable. Mutating the value of
+            :attr:`document` does not mutate this :class:`WriteConcern`.
+        """
+        return self.__document.copy()
+
+    @property
+    def acknowledged(self) -> bool:
+        """If ``True``, write operations will wait for acknowledgement before
+        returning.
+        """
+        return self.__acknowledged
+
+    def __repr__(self) -> str:
+        return "WriteConcern({})".format(
+            ", ".join("{}={}".format(*kvt) for kvt in self.__document.items())
+        )
+
+    def __eq__(self, other: Any) -> bool:
+        if isinstance(other, WriteConcern):
+            return self.__document == other.document
+        return NotImplemented
+
+    def __ne__(self, other: Any) -> bool:
+        if isinstance(other, WriteConcern):
+            return self.__document != other.document
+        return NotImplemented
+
+
+DEFAULT_WRITE_CONCERN = WriteConcern()
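+
+# Editorial usage sketch (comments added by the editor, not upstream code),
+# showing how the pieces above combine:
+#
+#     >>> wc = WriteConcern(w="majority", j=True)
+#     >>> wc.document
+#     {'j': True, 'w': 'majority'}
+#     >>> wc.acknowledged, wc.is_server_default
+#     (True, False)
+#     >>> DEFAULT_WRITE_CONCERN.is_server_default
+#     True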