Skip to content

refactor _document_registry + log a warning when user register multip… #2861

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 4, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
@@ -13,6 +13,8 @@ Development
- make sure to read https://www.mongodb.com/docs/manual/core/transactions-in-applications/#callback-api-vs-core-api
- run_in_transaction context manager relies on Pymongo coreAPI, it will retry automatically in case of `UnknownTransactionCommitResult` but not `TransientTransactionError` exceptions
- Using .count() in a transaction will always use Collection.count_document (as estimated_document_count is not supported in transactions)
- BREAKING CHANGE: wrap _document_registry (normally not used by end users) with _DocumentRegistry which acts as a singleton to access the registry
- Log a warning in case users creates multiple Document classes with the same name as it can lead to unexpected behavior #1778
- Fix use of $geoNear or $collStats in aggregate #2493
- BREAKING CHANGE: Further to the deprecation warning, remove ability to use an unpacked list to `Queryset.aggregate(*pipeline)`, a plain list must be provided instead `Queryset.aggregate(pipeline)`, as it's closer to pymongo interface
- BREAKING CHANGE: Further to the deprecation warning, remove `full_response` from `QuerySet.modify` as it wasn't supported with Pymongo 3+
@@ -21,6 +23,7 @@ Development
- BREAKING CHANGE: Remove LongField as it's equivalent to IntField since we drop support to Python2 long time ago (User should simply switch to IntField) #2309
- BugFix - Calling .clear on a ListField wasn't being marked as changed (and flushed to db upon .save()) #2858


Changes in 0.29.0
=================
- Fix weakref in EmbeddedDocumentListField (causing brief mem leak in certain circumstances) #2827
3 changes: 1 addition & 2 deletions mongoengine/base/__init__.py
Original file line number Diff line number Diff line change
@@ -13,8 +13,7 @@
__all__ = (
# common
"UPDATE_OPERATORS",
"_document_registry",
"get_document",
"_DocumentRegistry",
# datastructures
"BaseDict",
"BaseList",
77 changes: 54 additions & 23 deletions mongoengine/base/common.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import warnings

from mongoengine.errors import NotRegistered

__all__ = ("UPDATE_OPERATORS", "get_document", "_document_registry")
__all__ = ("UPDATE_OPERATORS", "_DocumentRegistry")


UPDATE_OPERATORS = {
@@ -25,28 +27,57 @@
_document_registry = {}


def get_document(name):
"""Get a registered Document class by name."""
doc = _document_registry.get(name, None)
if not doc:
# Possible old style name
single_end = name.split(".")[-1]
compound_end = ".%s" % single_end
possible_match = [
k for k in _document_registry if k.endswith(compound_end) or k == single_end
]
if len(possible_match) == 1:
doc = _document_registry.get(possible_match.pop(), None)
if not doc:
raise NotRegistered(
"""
`%s` has not been registered in the document registry.
Importing the document class automatically registers it, has it
been imported?
""".strip()
% name
)
return doc
class _DocumentRegistry:
"""Wrapper for the document registry (providing a singleton pattern).
This is part of MongoEngine's internals, not meant to be used directly by end-users
"""

@staticmethod
def get(name):
doc = _document_registry.get(name, None)
if not doc:
# Possible old style name
single_end = name.split(".")[-1]
compound_end = ".%s" % single_end
possible_match = [
k
for k in _document_registry
if k.endswith(compound_end) or k == single_end
]
if len(possible_match) == 1:
doc = _document_registry.get(possible_match.pop(), None)
if not doc:
raise NotRegistered(
"""
`%s` has not been registered in the document registry.
Importing the document class automatically registers it, has it
been imported?
""".strip()
% name
)
return doc

@staticmethod
def register(DocCls):
ExistingDocCls = _document_registry.get(DocCls._class_name)
if (
ExistingDocCls is not None
and ExistingDocCls.__module__ != DocCls.__module__
):
# A sign that a codebase may have named two different classes with the same name accidentally,
# this could cause issues with dereferencing because MongoEngine makes the assumption that a Document
# class name is unique.
warnings.warn(
f"Multiple Document classes named `{DocCls._class_name}` were registered, "
f"first from: `{ExistingDocCls.__module__}`, then from: `{DocCls.__module__}`. "
"this may lead to unexpected behavior during dereferencing.",
stacklevel=4,
)
_document_registry[DocCls._class_name] = DocCls

@staticmethod
def unregister(doc_cls_name):
_document_registry.pop(doc_cls_name)


def _get_documents_by_db(connection_alias, default_connection_alias):
6 changes: 3 additions & 3 deletions mongoengine/base/document.py
Original file line number Diff line number Diff line change
@@ -7,7 +7,7 @@
from bson import SON, DBRef, ObjectId, json_util

from mongoengine import signals
from mongoengine.base.common import get_document
from mongoengine.base.common import _DocumentRegistry
from mongoengine.base.datastructures import (
BaseDict,
BaseList,
@@ -500,7 +500,7 @@ def __expand_dynamic_values(self, name, value):
# If the value is a dict with '_cls' in it, turn it into a document
is_dict = isinstance(value, dict)
if is_dict and "_cls" in value:
cls = get_document(value["_cls"])
cls = _DocumentRegistry.get(value["_cls"])
return cls(**value)

if is_dict:
@@ -802,7 +802,7 @@ def _from_son(cls, son, _auto_dereference=True, created=False):

# Return correct subclass for document type
if class_name != cls._class_name:
cls = get_document(class_name)
cls = _DocumentRegistry.get(class_name)

errors_dict = {}

4 changes: 2 additions & 2 deletions mongoengine/base/metaclasses.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import itertools
import warnings

from mongoengine.base.common import _document_registry
from mongoengine.base.common import _DocumentRegistry
from mongoengine.base.fields import (
BaseField,
ComplexBaseField,
@@ -169,7 +169,7 @@ def __new__(mcs, name, bases, attrs):
new_class._collection = None

# Add class to the _document_registry
_document_registry[new_class._class_name] = new_class
_DocumentRegistry.register(new_class)

# Handle delete rules
for field in new_class._fields.values():
20 changes: 10 additions & 10 deletions mongoengine/dereference.py
Original file line number Diff line number Diff line change
@@ -5,7 +5,7 @@
BaseList,
EmbeddedDocumentList,
TopLevelDocumentMetaclass,
get_document,
_DocumentRegistry,
)
from mongoengine.base.datastructures import LazyReference
from mongoengine.connection import _get_session, get_db
@@ -131,9 +131,9 @@ def _find_references(self, items, depth=0):
elif isinstance(v, DBRef):
reference_map.setdefault(field.document_type, set()).add(v.id)
elif isinstance(v, (dict, SON)) and "_ref" in v:
reference_map.setdefault(get_document(v["_cls"]), set()).add(
v["_ref"].id
)
reference_map.setdefault(
_DocumentRegistry.get(v["_cls"]), set()
).add(v["_ref"].id)
elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth:
field_cls = getattr(
getattr(field, "field", None), "document_type", None
@@ -151,9 +151,9 @@ def _find_references(self, items, depth=0):
elif isinstance(item, DBRef):
reference_map.setdefault(item.collection, set()).add(item.id)
elif isinstance(item, (dict, SON)) and "_ref" in item:
reference_map.setdefault(get_document(item["_cls"]), set()).add(
item["_ref"].id
)
reference_map.setdefault(
_DocumentRegistry.get(item["_cls"]), set()
).add(item["_ref"].id)
elif isinstance(item, (dict, list, tuple)) and depth - 1 <= self.max_depth:
references = self._find_references(item, depth - 1)
for key, refs in references.items():
@@ -198,9 +198,9 @@ def _fetch_objects(self, doc_type=None):
)
for ref in references:
if "_cls" in ref:
doc = get_document(ref["_cls"])._from_son(ref)
doc = _DocumentRegistry.get(ref["_cls"])._from_son(ref)
elif doc_type is None:
doc = get_document(
doc = _DocumentRegistry.get(
"".join(x.capitalize() for x in collection.split("_"))
)._from_son(ref)
else:
@@ -235,7 +235,7 @@ def _attach_objects(self, items, depth=0, instance=None, name=None):
(items["_ref"].collection, items["_ref"].id), items
)
elif "_cls" in items:
doc = get_document(items["_cls"])._from_son(items)
doc = _DocumentRegistry.get(items["_cls"])._from_son(items)
_cls = doc._data.pop("_cls", None)
del items["_cls"]
doc._data = self._attach_objects(doc._data, depth, doc, None)
6 changes: 3 additions & 3 deletions mongoengine/document.py
Original file line number Diff line number Diff line change
@@ -12,7 +12,7 @@
DocumentMetaclass,
EmbeddedDocumentList,
TopLevelDocumentMetaclass,
get_document,
_DocumentRegistry,
)
from mongoengine.base.utils import NonOrderedList
from mongoengine.common import _import_class
@@ -851,12 +851,12 @@ def register_delete_rule(cls, document_cls, field_name, rule):
object.
"""
classes = [
get_document(class_name)
_DocumentRegistry.get(class_name)
for class_name in cls._subclasses
if class_name != cls.__name__
] + [cls]
documents = [
get_document(class_name)
_DocumentRegistry.get(class_name)
for class_name in document_cls._subclasses
if class_name != document_cls.__name__
] + [document_cls]
20 changes: 10 additions & 10 deletions mongoengine/fields.py
Original file line number Diff line number Diff line change
@@ -30,7 +30,7 @@
GeoJsonBaseField,
LazyReference,
ObjectIdField,
get_document,
_DocumentRegistry,
)
from mongoengine.base.utils import LazyRegexCompiler
from mongoengine.common import _import_class
@@ -725,7 +725,7 @@ def document_type(self):
if self.document_type_obj == RECURSIVE_REFERENCE_CONSTANT:
resolved_document_type = self.owner_document
else:
resolved_document_type = get_document(self.document_type_obj)
resolved_document_type = _DocumentRegistry.get(self.document_type_obj)

if not issubclass(resolved_document_type, EmbeddedDocument):
# Due to the late resolution of the document_type
@@ -801,7 +801,7 @@ def prepare_query_value(self, op, value):

def to_python(self, value):
if isinstance(value, dict):
doc_cls = get_document(value["_cls"])
doc_cls = _DocumentRegistry.get(value["_cls"])
value = doc_cls._from_son(value)

return value
@@ -879,7 +879,7 @@ def to_mongo(self, value, use_db_field=True, fields=None):

def to_python(self, value):
if isinstance(value, dict) and "_cls" in value:
doc_cls = get_document(value["_cls"])
doc_cls = _DocumentRegistry.get(value["_cls"])
if "_ref" in value:
value = doc_cls._get_db().dereference(
value["_ref"], session=_get_session()
@@ -1171,7 +1171,7 @@ def document_type(self):
if self.document_type_obj == RECURSIVE_REFERENCE_CONSTANT:
self.document_type_obj = self.owner_document
else:
self.document_type_obj = get_document(self.document_type_obj)
self.document_type_obj = _DocumentRegistry.get(self.document_type_obj)
return self.document_type_obj

@staticmethod
@@ -1195,7 +1195,7 @@ def __get__(self, instance, owner):
if auto_dereference and isinstance(ref_value, DBRef):
if hasattr(ref_value, "cls"):
# Dereference using the class type specified in the reference
cls = get_document(ref_value.cls)
cls = _DocumentRegistry.get(ref_value.cls)
else:
cls = self.document_type

@@ -1335,7 +1335,7 @@ def document_type(self):
if self.document_type_obj == RECURSIVE_REFERENCE_CONSTANT:
self.document_type_obj = self.owner_document
else:
self.document_type_obj = get_document(self.document_type_obj)
self.document_type_obj = _DocumentRegistry.get(self.document_type_obj)
return self.document_type_obj

@staticmethod
@@ -1498,7 +1498,7 @@ def __get__(self, instance, owner):

auto_dereference = instance._fields[self.name]._auto_dereference
if auto_dereference and isinstance(value, dict):
doc_cls = get_document(value["_cls"])
doc_cls = _DocumentRegistry.get(value["_cls"])
instance._data[self.name] = self._lazy_load_ref(doc_cls, value["_ref"])

return super().__get__(instance, owner)
@@ -2443,7 +2443,7 @@ def document_type(self):
if self.document_type_obj == RECURSIVE_REFERENCE_CONSTANT:
self.document_type_obj = self.owner_document
else:
self.document_type_obj = get_document(self.document_type_obj)
self.document_type_obj = _DocumentRegistry.get(self.document_type_obj)
return self.document_type_obj

def build_lazyref(self, value):
@@ -2584,7 +2584,7 @@ def build_lazyref(self, value):
elif value is not None:
if isinstance(value, (dict, SON)):
value = LazyReference(
get_document(value["_cls"]),
_DocumentRegistry.get(value["_cls"]),
value["_ref"].id,
passthrough=self.passthrough,
)
6 changes: 4 additions & 2 deletions mongoengine/queryset/base.py
Original file line number Diff line number Diff line change
@@ -13,7 +13,7 @@
from pymongo.read_concern import ReadConcern

from mongoengine import signals
from mongoengine.base import get_document
from mongoengine.base import _DocumentRegistry
from mongoengine.common import _import_class
from mongoengine.connection import _get_session, get_db
from mongoengine.context_managers import (
@@ -1956,7 +1956,9 @@ def _fields_to_dbfields(self, fields):
"""Translate fields' paths to their db equivalents."""
subclasses = []
if self._document._meta["allow_inheritance"]:
subclasses = [get_document(x) for x in self._document._subclasses][1:]
subclasses = [_DocumentRegistry.get(x) for x in self._document._subclasses][
1:
]

db_field_paths = []
for field in fields:
8 changes: 4 additions & 4 deletions tests/document/test_instance.py
Original file line number Diff line number Diff line change
@@ -13,7 +13,7 @@

from mongoengine import *
from mongoengine import signals
from mongoengine.base import _document_registry, get_document
from mongoengine.base import _DocumentRegistry
from mongoengine.connection import get_db
from mongoengine.context_managers import query_counter, switch_db
from mongoengine.errors import (
@@ -392,7 +392,7 @@ class NicePlace(Place):

# Mimic Place and NicePlace definitions being in a different file
# and the NicePlace model not being imported in at query time.
del _document_registry["Place.NicePlace"]
_DocumentRegistry.unregister("Place.NicePlace")

with pytest.raises(NotRegistered):
list(Place.objects.all())
@@ -407,8 +407,8 @@ class Area(Location):

Location.drop_collection()

assert Area == get_document("Area")
assert Area == get_document("Location.Area")
assert Area == _DocumentRegistry.get("Area")
assert Area == _DocumentRegistry.get("Location.Area")

def test_creation(self):
"""Ensure that document may be created using keyword arguments."""
4 changes: 2 additions & 2 deletions tests/fields/test_fields.py
Original file line number Diff line number Diff line change
@@ -37,7 +37,7 @@
from mongoengine.base import (
BaseField,
EmbeddedDocumentList,
_document_registry,
_DocumentRegistry,
)
from mongoengine.base.fields import _no_dereference_for_fields
from mongoengine.errors import DeprecatedError
@@ -1678,7 +1678,7 @@ class User(Document):

# Mimic User and Link definitions being in a different file
# and the Link model not being imported in the User file.
del _document_registry["Link"]
_DocumentRegistry.unregister("Link")

user = User.objects.first()
try: