Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion pymongo/asynchronous/encryption.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,12 @@
from pymongo.asynchronous.mongo_client import AsyncMongoClient
from pymongo.common import CONNECT_TIMEOUT
from pymongo.daemon import _spawn_daemon
from pymongo.encryption_options import AutoEncryptionOpts, RangeOpts, TextOpts
from pymongo.encryption_options import (
AutoEncryptionOpts,
RangeOpts,
TextOpts,
check_min_pymongocrypt,
)
from pymongo.errors import (
ConfigurationError,
EncryptedCollectionError,
Expand Down Expand Up @@ -675,6 +680,8 @@ def __init__(
"python -m pip install --upgrade 'pymongo[encryption]'"
)

check_min_pymongocrypt()

if not isinstance(codec_options, CodecOptions):
raise TypeError(
f"codec_options must be an instance of bson.codec_options.CodecOptions, not {type(codec_options)}"
Expand Down
12 changes: 11 additions & 1 deletion pymongo/asynchronous/srv_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import random
from typing import TYPE_CHECKING, Any, Optional, Union

from pymongo.common import CONNECT_TIMEOUT
from pymongo.common import CONNECT_TIMEOUT, check_for_min_version
from pymongo.errors import ConfigurationError

if TYPE_CHECKING:
Expand All @@ -32,6 +32,14 @@ def _have_dnspython() -> bool:
try:
import dns # noqa: F401

dns_version, required_version, is_valid = check_for_min_version("dnspython")
if not is_valid:
raise RuntimeError(
f"pymongo requires dnspython>={required_version}, "
f"found version {dns_version}. "
"Install a compatible version with pip"
)

return True
except ImportError:
return False
Expand Down Expand Up @@ -80,6 +88,8 @@ def __init__(
srv_service_name: str,
srv_max_hosts: int = 0,
):
# Ensure the version of dnspython is compatible.
_have_dnspython()
self.__fqdn = fqdn
self.__srv = srv_service_name
self.__connect_timeout = connect_timeout or CONNECT_TIMEOUT
Expand Down
89 changes: 89 additions & 0 deletions pymongo/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import warnings
from collections import OrderedDict, abc
from difflib import get_close_matches
from importlib.metadata import requires, version
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -1092,3 +1093,91 @@ def has_c() -> bool:
return True
except ImportError:
return False


class Version(tuple[int, ...]):
"""A class that can be used to compare version strings."""

def __new__(cls, *version: int) -> Version:
padded_version = cls._padded(version, 4)
return super().__new__(cls, tuple(padded_version))

@classmethod
def _padded(cls, iter: Any, length: int, padding: int = 0) -> list[int]:
as_list = list(iter)
if len(as_list) < length:
for _ in range(length - len(as_list)):
as_list.append(padding)
return as_list

@classmethod
def from_string(cls, version_string: str) -> Version:
mod = 0
bump_patch_level = False
if version_string.endswith("+"):
version_string = version_string[0:-1]
mod = 1
elif version_string.endswith("-pre-"):
version_string = version_string[0:-5]
mod = -1
elif version_string.endswith("-"):
version_string = version_string[0:-1]
mod = -1
# Deal with .devX substrings
if ".dev" in version_string:
version_string = version_string[0 : version_string.find(".dev")]
mod = -1
# Deal with '-rcX' substrings
if "-rc" in version_string:
version_string = version_string[0 : version_string.find("-rc")]
mod = -1
# Deal with git describe generated substrings
elif "-" in version_string:
version_string = version_string[0 : version_string.find("-")]
mod = -1
bump_patch_level = True

version = [int(part) for part in version_string.split(".")]
version = cls._padded(version, 3)
# Make from_string and from_version_array agree. For example:
# MongoDB Enterprise > db.runCommand('buildInfo').versionArray
# [ 3, 2, 1, -100 ]
# MongoDB Enterprise > db.runCommand('buildInfo').version
# 3.2.0-97-g1ef94fe
if bump_patch_level:
version[-1] += 1
version.append(mod)

return Version(*version)

@classmethod
def from_version_array(cls, version_array: Any) -> Version:
version = list(version_array)
if version[-1] < 0:
version[-1] = -1
version = cls._padded(version, 3)
return Version(*version)

def at_least(self, *other_version: Any) -> bool:
return self >= Version(*other_version)

def __str__(self) -> str:
return ".".join(map(str, self))


def check_for_min_version(package_name: str) -> tuple[str, str, bool]:
"""Test whether an installed package is of the desired version."""
package_version_str = version(package_name)
package_version = Version.from_string(package_version_str)
# Dependency is expected to be in one of the forms:
# "pymongocrypt<2.0.0,>=1.13.0; extra == 'encryption'"
# 'dnspython<3.0.0,>=1.16.0'
#
requirements = requires("pymongo")
assert requirements is not None
requirement = [i for i in requirements if i.startswith(package_name)][0] # noqa: RUF015
if ";" in requirement:
requirement = requirement.split(";")[0]
required_version = requirement[requirement.find(">=") + 2 :]
is_valid = package_version >= Version.from_string(required_version)
return package_version_str, required_version, is_valid
17 changes: 15 additions & 2 deletions pymongo/encryption_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from pymongo.uri_parser_shared import _parse_kms_tls_options

try:
import pymongocrypt # type:ignore[import-untyped] # noqa: F401
import pymongocrypt # type:ignore[import-untyped] # noqa: F401

# Check for pymongocrypt>=1.10.
from pymongocrypt import synchronous as _ # noqa: F401
Expand All @@ -32,14 +32,26 @@
except ImportError:
_HAVE_PYMONGOCRYPT = False
from bson import int64
from pymongo.common import validate_is_mapping
from pymongo.common import check_for_min_version, validate_is_mapping
from pymongo.errors import ConfigurationError

if TYPE_CHECKING:
from pymongo.pyopenssl_context import SSLContext
from pymongo.typings import _AgnosticMongoClient


def check_min_pymongocrypt() -> None:
"""Raise an appropriate error if the min pymongocrypt is not installed."""
pymongocrypt_version, required_version, is_valid = check_for_min_version("pymongocrypt")
if not is_valid:
raise ConfigurationError(
f"client side encryption requires pymongocrypt>={required_version}, "
f"found version {pymongocrypt_version}. "
"Install a compatible version with: "
"python -m pip install 'pymongo[encryption]'"
)


class AutoEncryptionOpts:
"""Options to configure automatic client-side field level encryption."""

Expand Down Expand Up @@ -215,6 +227,7 @@ def __init__(
"install a compatible version with: "
"python -m pip install 'pymongo[encryption]'"
)
check_min_pymongocrypt()
if encrypted_fields_map:
validate_is_mapping("encrypted_fields_map", encrypted_fields_map)
self._encrypted_fields_map = encrypted_fields_map
Expand Down
9 changes: 8 additions & 1 deletion pymongo/synchronous/encryption.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,12 @@
from pymongo import _csot
from pymongo.common import CONNECT_TIMEOUT
from pymongo.daemon import _spawn_daemon
from pymongo.encryption_options import AutoEncryptionOpts, RangeOpts, TextOpts
from pymongo.encryption_options import (
AutoEncryptionOpts,
RangeOpts,
TextOpts,
check_min_pymongocrypt,
)
from pymongo.errors import (
ConfigurationError,
EncryptedCollectionError,
Expand Down Expand Up @@ -672,6 +677,8 @@ def __init__(
"python -m pip install --upgrade 'pymongo[encryption]'"
)

check_min_pymongocrypt()

if not isinstance(codec_options, CodecOptions):
raise TypeError(
f"codec_options must be an instance of bson.codec_options.CodecOptions, not {type(codec_options)}"
Expand Down
12 changes: 11 additions & 1 deletion pymongo/synchronous/srv_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import random
from typing import TYPE_CHECKING, Any, Optional, Union

from pymongo.common import CONNECT_TIMEOUT
from pymongo.common import CONNECT_TIMEOUT, check_for_min_version
from pymongo.errors import ConfigurationError

if TYPE_CHECKING:
Expand All @@ -32,6 +32,14 @@ def _have_dnspython() -> bool:
try:
import dns # noqa: F401

dns_version, required_version, is_valid = check_for_min_version("dnspython")
if not is_valid:
raise RuntimeError(
f"pymongo requires dnspython>={required_version}, "
f"found version {dns_version}. "
"Install a compatible version with pip"
)

return True
except ImportError:
return False
Expand Down Expand Up @@ -80,6 +88,8 @@ def __init__(
srv_service_name: str,
srv_max_hosts: int = 0,
):
# Ensure the version of dnspython is compatible.
_have_dnspython()
self.__fqdn = fqdn
self.__srv = srv_service_name
self.__connect_timeout = connect_timeout or CONNECT_TIMEOUT
Expand Down
64 changes: 2 additions & 62 deletions test/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,64 +15,10 @@
"""Some tools for running tests based on MongoDB server version."""
from __future__ import annotations

from pymongo.common import Version as BaseVersion

class Version(tuple):
def __new__(cls, *version):
padded_version = cls._padded(version, 4)
return super().__new__(cls, tuple(padded_version))

@classmethod
def _padded(cls, iter, length, padding=0):
l = list(iter)
if len(l) < length:
for _ in range(length - len(l)):
l.append(padding)
return l

@classmethod
def from_string(cls, version_string):
mod = 0
bump_patch_level = False
if version_string.endswith("+"):
version_string = version_string[0:-1]
mod = 1
elif version_string.endswith("-pre-"):
version_string = version_string[0:-5]
mod = -1
elif version_string.endswith("-"):
version_string = version_string[0:-1]
mod = -1
# Deal with '-rcX' substrings
if "-rc" in version_string:
version_string = version_string[0 : version_string.find("-rc")]
mod = -1
# Deal with git describe generated substrings
elif "-" in version_string:
version_string = version_string[0 : version_string.find("-")]
mod = -1
bump_patch_level = True

version = [int(part) for part in version_string.split(".")]
version = cls._padded(version, 3)
# Make from_string and from_version_array agree. For example:
# MongoDB Enterprise > db.runCommand('buildInfo').versionArray
# [ 3, 2, 1, -100 ]
# MongoDB Enterprise > db.runCommand('buildInfo').version
# 3.2.0-97-g1ef94fe
if bump_patch_level:
version[-1] += 1
version.append(mod)

return Version(*version)

@classmethod
def from_version_array(cls, version_array):
version = list(version_array)
if version[-1] < 0:
version[-1] = -1
version = cls._padded(version, 3)
return Version(*version)

class Version(BaseVersion):
@classmethod
def from_client(cls, client):
info = client.server_info()
Expand All @@ -86,9 +32,3 @@ async def async_from_client(cls, client):
if "versionArray" in info:
return cls.from_version_array(info["versionArray"])
return cls.from_string(info["version"])

def at_least(self, *other_version):
return self >= Version(*other_version)

def __str__(self):
return ".".join(map(str, self))
Loading