Skip to content

Commit 0049dc8

Browse files
authored
PYTHON-2390 - Retryable reads use the same implicit session (#2544)
1 parent 51f7b40 commit 0049dc8

20 files changed

+198
-130
lines changed

pymongo/asynchronous/aggregation.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,6 @@ def __init__(
5050
cursor_class: type[AsyncCommandCursor[Any]],
5151
pipeline: _Pipeline,
5252
options: MutableMapping[str, Any],
53-
explicit_session: bool,
5453
let: Optional[Mapping[str, Any]] = None,
5554
user_fields: Optional[MutableMapping[str, Any]] = None,
5655
result_processor: Optional[Callable[[Mapping[str, Any], AsyncConnection], None]] = None,
@@ -92,7 +91,6 @@ def __init__(
9291
self._options["cursor"]["batchSize"] = self._batch_size
9392

9493
self._cursor_class = cursor_class
95-
self._explicit_session = explicit_session
9694
self._user_fields = user_fields
9795
self._result_processor = result_processor
9896

@@ -197,7 +195,6 @@ async def get_cursor(
197195
batch_size=self._batch_size or 0,
198196
max_await_time_ms=self._max_await_time_ms,
199197
session=session,
200-
explicit_session=self._explicit_session,
201198
comment=self._options.get("comment"),
202199
)
203200
await cmd_cursor._maybe_pin_connection(conn)

pymongo/asynchronous/change_stream.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ def _process_result(self, result: Mapping[str, Any], conn: AsyncConnection) -> N
236236
)
237237

238238
async def _run_aggregation_cmd(
239-
self, session: Optional[AsyncClientSession], explicit_session: bool
239+
self, session: Optional[AsyncClientSession]
240240
) -> AsyncCommandCursor: # type: ignore[type-arg]
241241
"""Run the full aggregation pipeline for this AsyncChangeStream and return
242242
the corresponding AsyncCommandCursor.
@@ -246,7 +246,6 @@ async def _run_aggregation_cmd(
246246
AsyncCommandCursor,
247247
self._aggregation_pipeline(),
248248
self._command_options(),
249-
explicit_session,
250249
result_processor=self._process_result,
251250
comment=self._comment,
252251
)
@@ -258,10 +257,8 @@ async def _run_aggregation_cmd(
258257
)
259258

260259
async def _create_cursor(self) -> AsyncCommandCursor: # type: ignore[type-arg]
261-
async with self._client._tmp_session(self._session, close=False) as s:
262-
return await self._run_aggregation_cmd(
263-
session=s, explicit_session=self._session is not None
264-
)
260+
async with self._client._tmp_session(self._session) as s:
261+
return await self._run_aggregation_cmd(session=s)
265262

266263
async def _resume(self) -> None:
267264
"""Reestablish this change stream after a resumable error."""

pymongo/asynchronous/client_bulk.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -440,6 +440,8 @@ async def _process_results_cursor(
440440
) -> None:
441441
"""Internal helper for processing the server reply command cursor."""
442442
if result.get("cursor"):
443+
if session:
444+
session._leave_alive = True
443445
coll = AsyncCollection(
444446
database=AsyncDatabase(self.client, "admin"),
445447
name="$cmd.bulkWrite",
@@ -449,7 +451,6 @@ async def _process_results_cursor(
449451
result["cursor"],
450452
conn.address,
451453
session=session,
452-
explicit_session=session is not None,
453454
comment=self.comment,
454455
)
455456
await cmd_cursor._maybe_pin_connection(conn)

pymongo/asynchronous/client_session.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -513,6 +513,10 @@ def __init__(
513513
# Is this an implicitly created session?
514514
self._implicit = implicit
515515
self._transaction = _Transaction(None, client)
516+
# Is this session attached to a cursor?
517+
self._attached_to_cursor = False
518+
# Should we leave the session alive when the cursor is closed?
519+
self._leave_alive = False
516520

517521
async def end_session(self) -> None:
518522
"""Finish this session. If a transaction has started, abort it.
@@ -535,7 +539,7 @@ async def _end_session(self, lock: bool) -> None:
535539

536540
def _end_implicit_session(self) -> None:
537541
# Implicit sessions can't be part of transactions or pinned connections
538-
if self._server_session is not None:
542+
if not self._leave_alive and self._server_session is not None:
539543
self._client._return_server_session(self._server_session)
540544
self._server_session = None
541545

pymongo/asynchronous/collection.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2549,7 +2549,6 @@ async def _list_indexes(
25492549
self.with_options(codec_options=codec_options, read_preference=ReadPreference.PRIMARY),
25502550
)
25512551
read_pref = (session and session._txn_read_preference()) or ReadPreference.PRIMARY
2552-
explicit_session = session is not None
25532552

25542553
async def _cmd(
25552554
session: Optional[AsyncClientSession],
@@ -2576,13 +2575,12 @@ async def _cmd(
25762575
cursor,
25772576
conn.address,
25782577
session=session,
2579-
explicit_session=explicit_session,
25802578
comment=cmd.get("comment"),
25812579
)
25822580
await cmd_cursor._maybe_pin_connection(conn)
25832581
return cmd_cursor
25842582

2585-
async with self._database.client._tmp_session(session, False) as s:
2583+
async with self._database.client._tmp_session(session) as s:
25862584
return await self._database.client._retryable_read(
25872585
_cmd, read_pref, s, operation=_Op.LIST_INDEXES
25882586
)
@@ -2678,7 +2676,6 @@ async def list_search_indexes(
26782676
AsyncCommandCursor,
26792677
pipeline,
26802678
kwargs,
2681-
explicit_session=session is not None,
26822679
comment=comment,
26832680
user_fields={"cursor": {"firstBatch": 1}},
26842681
)
@@ -2900,7 +2897,6 @@ async def _aggregate(
29002897
pipeline: _Pipeline,
29012898
cursor_class: Type[AsyncCommandCursor], # type: ignore[type-arg]
29022899
session: Optional[AsyncClientSession],
2903-
explicit_session: bool,
29042900
let: Optional[Mapping[str, Any]] = None,
29052901
comment: Optional[Any] = None,
29062902
**kwargs: Any,
@@ -2912,7 +2908,6 @@ async def _aggregate(
29122908
cursor_class,
29132909
pipeline,
29142910
kwargs,
2915-
explicit_session,
29162911
let,
29172912
user_fields={"cursor": {"firstBatch": 1}},
29182913
)
@@ -3018,13 +3013,12 @@ async def aggregate(
30183013
.. _aggregate command:
30193014
https://mongodb.com/docs/manual/reference/command/aggregate
30203015
"""
3021-
async with self._database.client._tmp_session(session, close=False) as s:
3016+
async with self._database.client._tmp_session(session) as s:
30223017
return await self._aggregate(
30233018
_CollectionAggregationCommand,
30243019
pipeline,
30253020
AsyncCommandCursor,
30263021
session=s,
3027-
explicit_session=session is not None,
30283022
let=let,
30293023
comment=comment,
30303024
**kwargs,
@@ -3065,15 +3059,14 @@ async def aggregate_raw_batches(
30653059
raise InvalidOperation("aggregate_raw_batches does not support auto encryption")
30663060
if comment is not None:
30673061
kwargs["comment"] = comment
3068-
async with self._database.client._tmp_session(session, close=False) as s:
3062+
async with self._database.client._tmp_session(session) as s:
30693063
return cast(
30703064
AsyncRawBatchCursor[_DocumentType],
30713065
await self._aggregate(
30723066
_CollectionRawAggregationCommand,
30733067
pipeline,
30743068
AsyncRawBatchCommandCursor,
30753069
session=s,
3076-
explicit_session=session is not None,
30773070
**kwargs,
30783071
),
30793072
)

pymongo/asynchronous/command_cursor.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,6 @@ def __init__(
6464
batch_size: int = 0,
6565
max_await_time_ms: Optional[int] = None,
6666
session: Optional[AsyncClientSession] = None,
67-
explicit_session: bool = False,
6867
comment: Any = None,
6968
) -> None:
7069
"""Create a new command cursor."""
@@ -80,7 +79,8 @@ def __init__(
8079
self._max_await_time_ms = max_await_time_ms
8180
self._timeout = self._collection.database.client.options.timeout
8281
self._session = session
83-
self._explicit_session = explicit_session
82+
if self._session is not None:
83+
self._session._attached_to_cursor = True
8484
self._killed = self._id == 0
8585
self._comment = comment
8686
if self._killed:
@@ -197,7 +197,7 @@ def session(self) -> Optional[AsyncClientSession]:
197197
198198
.. versionadded:: 3.6
199199
"""
200-
if self._explicit_session:
200+
if self._session and not self._session._implicit:
201201
return self._session
202202
return None
203203

@@ -218,9 +218,10 @@ def _die_no_lock(self) -> None:
218218
"""Closes this cursor without acquiring a lock."""
219219
cursor_id, address = self._prepare_to_die()
220220
self._collection.database.client._cleanup_cursor_no_lock(
221-
cursor_id, address, self._sock_mgr, self._session, self._explicit_session
221+
cursor_id, address, self._sock_mgr, self._session
222222
)
223-
if not self._explicit_session:
223+
if self._session and self._session._implicit:
224+
self._session._attached_to_cursor = False
224225
self._session = None
225226
self._sock_mgr = None
226227

@@ -232,14 +233,15 @@ async def _die_lock(self) -> None:
232233
address,
233234
self._sock_mgr,
234235
self._session,
235-
self._explicit_session,
236236
)
237-
if not self._explicit_session:
237+
if self._session and self._session._implicit:
238+
self._session._attached_to_cursor = False
238239
self._session = None
239240
self._sock_mgr = None
240241

241242
def _end_session(self) -> None:
242-
if self._session and not self._explicit_session:
243+
if self._session and self._session._implicit:
244+
self._session._attached_to_cursor = False
243245
self._session._end_implicit_session()
244246
self._session = None
245247

@@ -430,7 +432,6 @@ def __init__(
430432
batch_size: int = 0,
431433
max_await_time_ms: Optional[int] = None,
432434
session: Optional[AsyncClientSession] = None,
433-
explicit_session: bool = False,
434435
comment: Any = None,
435436
) -> None:
436437
"""Create a new cursor / iterator over raw batches of BSON data.
@@ -449,7 +450,6 @@ def __init__(
449450
batch_size,
450451
max_await_time_ms,
451452
session,
452-
explicit_session,
453453
comment,
454454
)
455455

pymongo/asynchronous/cursor.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -138,10 +138,9 @@ def __init__(
138138

139139
if session:
140140
self._session = session
141-
self._explicit_session = True
141+
self._session._attached_to_cursor = True
142142
else:
143143
self._session = None
144-
self._explicit_session = False
145144

146145
spec: Mapping[str, Any] = filter or {}
147146
validate_is_mapping("filter", spec)
@@ -150,7 +149,7 @@ def __init__(
150149
if not isinstance(limit, int):
151150
raise TypeError(f"limit must be an instance of int, not {type(limit)}")
152151
validate_boolean("no_cursor_timeout", no_cursor_timeout)
153-
if no_cursor_timeout and not self._explicit_session:
152+
if no_cursor_timeout and self._session and self._session._implicit:
154153
warnings.warn(
155154
"use an explicit session with no_cursor_timeout=True "
156155
"otherwise the cursor may still timeout after "
@@ -283,7 +282,7 @@ def clone(self) -> AsyncCursor[_DocumentType]:
283282
def _clone(self, deepcopy: bool = True, base: Optional[AsyncCursor] = None) -> AsyncCursor: # type: ignore[type-arg]
284283
"""Internal clone helper."""
285284
if not base:
286-
if self._explicit_session:
285+
if self._session and not self._session._implicit:
287286
base = self._clone_base(self._session)
288287
else:
289288
base = self._clone_base(None)
@@ -945,7 +944,7 @@ def session(self) -> Optional[AsyncClientSession]:
945944
946945
.. versionadded:: 3.6
947946
"""
948-
if self._explicit_session:
947+
if self._session and not self._session._implicit:
949948
return self._session
950949
return None
951950

@@ -1034,9 +1033,10 @@ def _die_no_lock(self) -> None:
10341033

10351034
cursor_id, address = self._prepare_to_die(already_killed)
10361035
self._collection.database.client._cleanup_cursor_no_lock(
1037-
cursor_id, address, self._sock_mgr, self._session, self._explicit_session
1036+
cursor_id, address, self._sock_mgr, self._session
10381037
)
1039-
if not self._explicit_session:
1038+
if self._session and self._session._implicit:
1039+
self._session._attached_to_cursor = False
10401040
self._session = None
10411041
self._sock_mgr = None
10421042

@@ -1054,9 +1054,9 @@ async def _die_lock(self) -> None:
10541054
address,
10551055
self._sock_mgr,
10561056
self._session,
1057-
self._explicit_session,
10581057
)
1059-
if not self._explicit_session:
1058+
if self._session and self._session._implicit:
1059+
self._session._attached_to_cursor = False
10601060
self._session = None
10611061
self._sock_mgr = None
10621062

pymongo/asynchronous/database.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -611,6 +611,8 @@ async def create_collection(
611611
common.validate_is_mapping("clusteredIndex", clustered_index)
612612

613613
async with self._client._tmp_session(session) as s:
614+
if s and not s.in_transaction:
615+
s._leave_alive = True
614616
# Skip this check in a transaction where listCollections is not
615617
# supported.
616618
if (
@@ -619,6 +621,8 @@ async def create_collection(
619621
and name in await self._list_collection_names(filter={"name": name}, session=s)
620622
):
621623
raise CollectionInvalid("collection %s already exists" % name)
624+
if s:
625+
s._leave_alive = False
622626
coll = AsyncCollection(
623627
self,
624628
name,
@@ -699,13 +703,12 @@ async def aggregate(
699703
.. _aggregate command:
700704
https://mongodb.com/docs/manual/reference/command/aggregate
701705
"""
702-
async with self.client._tmp_session(session, close=False) as s:
706+
async with self.client._tmp_session(session) as s:
703707
cmd = _DatabaseAggregationCommand(
704708
self,
705709
AsyncCommandCursor,
706710
pipeline,
707711
kwargs,
708-
session is not None,
709712
user_fields={"cursor": {"firstBatch": 1}},
710713
)
711714
return await self.client._retryable_read(
@@ -1011,7 +1014,7 @@ async def cursor_command(
10111014
else:
10121015
command_name = next(iter(command))
10131016

1014-
async with self._client._tmp_session(session, close=False) as tmp_session:
1017+
async with self._client._tmp_session(session) as tmp_session:
10151018
opts = codec_options or DEFAULT_CODEC_OPTIONS
10161019

10171020
if read_preference is None:
@@ -1043,7 +1046,6 @@ async def cursor_command(
10431046
conn.address,
10441047
max_await_time_ms=max_await_time_ms,
10451048
session=tmp_session,
1046-
explicit_session=session is not None,
10471049
comment=comment,
10481050
)
10491051
await cmd_cursor._maybe_pin_connection(conn)
@@ -1089,7 +1091,7 @@ async def _list_collections(
10891091
)
10901092
cmd = {"listCollections": 1, "cursor": {}}
10911093
cmd.update(kwargs)
1092-
async with self._client._tmp_session(session, close=False) as tmp_session:
1094+
async with self._client._tmp_session(session) as tmp_session:
10931095
cursor = (
10941096
await self._command(conn, cmd, read_preference=read_preference, session=tmp_session)
10951097
)["cursor"]
@@ -1098,7 +1100,6 @@ async def _list_collections(
10981100
cursor,
10991101
conn.address,
11001102
session=tmp_session,
1101-
explicit_session=session is not None,
11021103
comment=cmd.get("comment"),
11031104
)
11041105
await cmd_cursor._maybe_pin_connection(conn)

0 commit comments

Comments
 (0)